Files
aiTTS/tts_engine.py

244 lines
10 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import torch
import soundfile as sf
import yaml
import os
import time
import datetime
import numpy as np
from pathlib import Path
from typing import Optional, List, Dict, Tuple
from huggingface_hub import snapshot_download
# ---
# Блок импорта модели.
# Если у вас установлен отдельный пакет qwen-tts:
try:
from qwen_tts import Qwen3TTSModel
except ImportError:
# Заглушка для тестирования без реальной модели (эмуляция)
print("WARNING: qwen_tts not found. Using Mock Model for testing.")
class Qwen3TTSModel:
@staticmethod
def from_pretrained(path, **kwargs):
print(f"[Mock] Loading model from {path}")
return Qwen3TTSModel()
def create_voice_clone_prompt(self, **kwargs):
return "mock_prompt"
def generate_voice_clone(self, text, **kwargs):
print(f"[Mock] Generating voice clone for: {text[:30]}...")
# Возвращаем пустой массив нужной размерности
sr = 24000
duration = len(text) * 0.1
return np.random.rand(1, int(sr * duration)).astype(np.float32), sr
def generate_custom_voice(self, text, **kwargs):
return self.generate_voice_clone(text, **kwargs)
def generate_voice_design(self, text, **kwargs):
return self.generate_voice_clone(text, **kwargs)
# ---
class TTSEngine:
def __init__(self, config_path: str = "config.yaml"):
with open(config_path, 'r', encoding='utf-8') as f:
self.config = yaml.safe_load(f)
self.models = {}
try:
self.dtype = getattr(torch, self.config['generation']['dtype'])
except AttributeError:
self.dtype = torch.float32
# Инициализация папок
Path(self.config['storage']['model_path']).mkdir(parents=True, exist_ok=True)
Path(self.config['storage']['sample_dir']).mkdir(parents=True, exist_ok=True)
Path(self.config['storage']['output_dir']).mkdir(parents=True, exist_ok=True)
def _resolve_model(self, model_key: str) -> str:
"""
Умная загрузка моделей:
1. Абсолютный путь -> использовать его.
2. Локальный путь внутри model_path -> использовать.
3. Скачать с HF в model_path.
"""
model_cfg_value = self.config['models'][model_key]
base_model_path = self.config['storage']['model_path']
# 1. Если это абсолютный путь или файл уже существует по этому пути
if os.path.isabs(model_cfg_value) or os.path.exists(model_cfg_value):
print(f"📂 Model [{model_key}]: Using direct path {model_cfg_value}")
return model_cfg_value
# 2. Формируем путь внутри хранилища
# Используем имя репозития как имя папки (замена / на _ если нужно, или сохранение структуры)
folder_name = model_cfg_value.split('/')[-1]
local_path = os.path.join(base_model_path, folder_name)
if os.path.exists(local_path) and os.listdir(local_path):
print(f"📂 Model [{model_key}]: Found locally at {local_path}")
return local_path
# 3. Скачивание с Hugging Face
print(f"⬇️ Model [{model_key}]: Not found. Downloading from HF to {local_path}...")
try:
snapshot_download(
repo_id=model_cfg_value,
local_dir=local_path,
local_dir_use_symlinks=False
)
print(f"✅ Model [{model_key}]: Downloaded.")
return local_path
except Exception as e:
print(f"❌ Error downloading model {model_cfg_value}: {e}")
raise RuntimeError(f"Failed to load model {model_key}")
def _get_model(self, model_type: str):
if model_type not in self.models:
model_path = self._resolve_model(model_type)
print(f"🚀 Loading model [{model_type}] into memory...")
self.models[model_type] = Qwen3TTSModel.from_pretrained(
model_path,
device_map=self.config['generation']['device'],
torch_dtype=self.dtype
)
return self.models[model_type]
def get_available_samples(self) -> List[Dict[str, str]]:
samples = []
sample_dir = self.config['storage']['sample_dir']
if not os.path.exists(sample_dir): return samples
for name in sorted(os.listdir(sample_dir)):
full_path = os.path.join(sample_dir, name)
if os.path.isdir(full_path):
audio_path = os.path.join(full_path, "audio.wav")
prompt_path = os.path.join(full_path, "prompt.txt")
if os.path.exists(audio_path):
prompt = ""
if os.path.exists(prompt_path):
with open(prompt_path, 'r', encoding='utf-8') as f:
prompt = f.read().strip()
samples.append({
"name": name,
"path": full_path,
"prompt": prompt
})
return samples
def generate_with_sample(self, text: str, sample_path: str) -> Tuple[np.ndarray, int]:
"""Режим 1: Клонирование по сэмплу"""
model = self._get_model('base')
audio_file = os.path.join(sample_path, "audio.wav")
prompt_file = os.path.join(sample_path, "prompt.txt")
ref_text = None
if os.path.exists(prompt_file):
with open(prompt_file, 'r', encoding='utf-8') as f:
ref_text = f.read().strip()
print(f"🎤 Cloning voice from: {sample_path}")
# Создаем промпт клонирования
prompt = model.create_voice_clone_prompt(
ref_audio=audio_file,
ref_text=ref_text
)
wavs, sr = model.generate_voice_clone(
text=text,
language=self.config['generation']['default_language'],
voice_clone_prompt=prompt
)
return wavs, sr
def generate_with_description(self, text: str, description: str) -> Tuple[np.ndarray, int]:
"""Режим 2: Генерация голоса по описанию (Design -> Clone)"""
print(f"🎨 Designing voice: '{description}'")
# Шаг А: Генерируем референс через VoiceDesign
vd_model = self._get_model('voice_design')
ref_text = text[:100] if len(text) > 100 else text
# Генерируем сэмпл для будущего клонирования
ref_wavs, ref_sr = vd_model.generate_voice_design(
text=ref_text,
language=self.config['generation']['default_language'],
instruct=description
)
# Шаг Б: Клонируем этот сгенерированный голос через Base модель
base_model = self._get_model('base')
# Передаем tuple (numpy_array, sr) как ref_audio
prompt = base_model.create_voice_clone_prompt(
ref_audio=(ref_wavs[0], ref_sr),
ref_text=ref_text
)
wavs, sr = base_model.generate_voice_clone(
text=text,
language=self.config['generation']['default_language'],
voice_clone_prompt=prompt
)
return wavs, sr
def generate_standard(self, text: str, speaker: str = None) -> Tuple[np.ndarray, int]:
"""Режим 3: Стандартный голос"""
model = self._get_model('custom_voice')
speaker = speaker or self.config['generation']['default_speaker']
print(f"🗣️ Using built-in speaker: {speaker}")
wavs, sr = model.generate_custom_voice(
text=text,
language=self.config['generation']['default_language'],
speaker=speaker
)
return wavs, sr
def save_result(self, text: str, wavs: np.ndarray, sr: int) -> str:
"""Сохраняет WAV и TXT в папку out"""
out_dir = self.config['storage']['output_dir']
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
filename = f"speech_{timestamp}"
wav_path = os.path.join(out_dir, f"{filename}.wav")
txt_path = os.path.join(out_dir, f"{filename}.txt")
# Сохраняем аудио
sf.write(wav_path, wavs[0], sr)
# Сохраняем текст
with open(txt_path, 'w', encoding='utf-8') as f:
f.write(text)
return wav_path
def get_history(self) -> List[Dict[str, str]]:
"""Возвращает список сгенерированных файлов"""
out_dir = self.config['storage']['output_dir']
history = []
if not os.path.exists(out_dir): return history
files = sorted(os.listdir(out_dir), reverse=True)
for f in files:
if f.endswith(".wav"):
base_name = f[:-4]
txt_path = os.path.join(out_dir, f"{base_name}.txt")
wav_path = os.path.join(out_dir, f)
text_content = "(Текст не найден)"
if os.path.exists(txt_path):
with open(txt_path, 'r', encoding='utf-8') as file:
text_content = file.read()
history.append({
"filename": f,
"wav_path": wav_path,
"txt_path": txt_path,
"text": text_content
})
return history