Files
aiTTS/tts_engine.py

296 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import torch
import soundfile as sf
import yaml
import os
import time
import datetime
import numpy as np
import gc
from pathlib import Path
from typing import Optional, List, Dict, Tuple
from huggingface_hub import snapshot_download
# ---
# Model import block: use the real qwen_tts model when installed, otherwise
# fall back to a lightweight mock so the engine can be exercised without it.
try:
    from qwen_tts import Qwen3TTSModel
except ImportError:
    print("WARNING: qwen_tts not found. Using Mock Model for testing.")

    class Qwen3TTSModel:
        """Stand-in for ``qwen_tts.Qwen3TTSModel`` that produces random noise.

        Every generation method returns ``(audio, sample_rate)`` where audio is
        a float32 array of shape ``(1, n)`` holding 0.1 s of noise per input
        character at 24 kHz.
        """

        @staticmethod
        def from_pretrained(path, **kwargs):
            """Pretend to load weights from *path*; all kwargs are ignored."""
            print(f"[Mock] Loading model from {path}")
            return Qwen3TTSModel()

        def create_voice_clone_prompt(self, **kwargs):
            """Return a fixed placeholder instead of a real clone prompt."""
            return "mock_prompt"

        def generate_voice_clone(self, text, **kwargs):
            """Return random noise whose length scales with ``len(text)``."""
            print(f"[Mock] Generating voice clone for: {text[:30]}...")
            sample_rate = 24000
            n_samples = int(sample_rate * (len(text) * 0.1))
            noise = np.random.rand(1, n_samples)
            return noise.astype(np.float32), sample_rate

        def generate_custom_voice(self, text, **kwargs):
            """Same output as :meth:`generate_voice_clone`."""
            return self.generate_voice_clone(text, **kwargs)

        def generate_voice_design(self, text, **kwargs):
            """Same output as :meth:`generate_voice_clone`."""
            return self.generate_voice_clone(text, **kwargs)

        def get_supported_speakers(self):
            """Return a fixed list of preset speaker names."""
            return ["Chelsie", "Dylan", "Eric", "Serena", "Vivian", "Aiden", "Ryan"]
# ---
class TTSEngine:
    """Front end for the Qwen3-TTS models used by the app.

    Models are loaded lazily, and at most one is kept resident at a time:
    requesting a different model type unloads the others first. The class
    also provides voice cloning, voice design, preset-speaker synthesis, and
    saving/listing of generated audio under the directories configured in
    the YAML config.
    """

    # Memory budget passed to ``from_pretrained(max_memory=...)`` when a full
    # GPU load runs out of memory and the model must be partially offloaded.
    # The small GPU share (about 3 GiB for a single model) forces offloading.
    OFFLOAD_MAX_MEMORY = {0: "3GiB", "cpu": "28GiB"}

    # Speaker names returned when the custom-voice model cannot be queried.
    FALLBACK_SPEAKERS = ["Chelsie", "Dylan", "Eric", "Serena", "Vivian", "Aiden", "Ryan"]

    def __init__(self, config_path: str = "config.yaml"):
        """Load the YAML config, choose the torch dtype, and create storage dirs.

        Args:
            config_path: Path to the YAML config. Keys read by this class:
                ``generation.dtype``, ``generation.default_language``,
                ``generation.default_speaker``, ``models.<type>`` and
                ``storage.{model_path,sample_dir,output_dir}``.
        """
        with open(config_path, 'r', encoding='utf-8') as f:
            # safe_load returns None for an empty file; normalise to a dict.
            self.config = yaml.safe_load(f) or {}
        self.models = {}  # model type -> loaded model (at most one entry after loads)
        self.current_model_type = None  # type of the most recently returned model
        generation_cfg = self.config.get('generation') or {}
        self.dtype = self._parse_dtype(generation_cfg.get('dtype'))
        if torch.cuda.is_available():
            gpu_mem = torch.cuda.get_device_properties(0).total_memory / (1024**3)
            print(f"📊 GPU: {torch.cuda.get_device_name(0)}")
            print(f"📊 GPU Memory: {gpu_mem:.1f} GB")
        # Make sure the storage folders exist.
        storage = self.config['storage']
        for key in ('model_path', 'sample_dir', 'output_dir'):
            Path(storage[key]).mkdir(parents=True, exist_ok=True)

    @staticmethod
    def _parse_dtype(name) -> torch.dtype:
        """Map a config string such as ``"bfloat16"`` to a ``torch.dtype``.

        Falls back to ``torch.float16`` when *name* is missing, is not a string,
        or names a torch attribute that is not a dtype (e.g. ``"cuda"``).
        """
        candidate = getattr(torch, name, None) if isinstance(name, str) else None
        return candidate if isinstance(candidate, torch.dtype) else torch.float16

    @staticmethod
    def _dtype_label(dtype: torch.dtype) -> str:
        """Short human-readable dtype name for log messages (e.g. ``FP16``)."""
        labels = {torch.float16: "FP16", torch.bfloat16: "BF16", torch.float32: "FP32"}
        return labels.get(dtype, str(dtype).replace("torch.", ""))

    def _resolve_model(self, model_key: str) -> str:
        """Return a local path for the model configured under ``models.<model_key>``.

        Resolution order:
          1. the configured value itself, if it is an absolute path or exists;
          2. ``<storage.model_path>/<last segment of the value>`` if that
             directory exists and is non-empty;
          3. otherwise the value is treated as a Hugging Face repo id and
             downloaded into that directory.

        Raises:
            RuntimeError: if the download fails (the original error is chained).
        """
        model_cfg_value = self.config['models'][model_key]
        base_model_path = self.config['storage']['model_path']
        if os.path.isabs(model_cfg_value) or os.path.exists(model_cfg_value):
            return model_cfg_value
        folder_name = model_cfg_value.split('/')[-1]
        local_path = os.path.join(base_model_path, folder_name)
        # NOTE(review): any non-empty folder counts as "already downloaded", so
        # an interrupted download is not resumed; delete the folder to retry.
        if os.path.isdir(local_path) and os.listdir(local_path):
            return local_path
        print(f"⬇️ Downloading {model_key}...")
        try:
            snapshot_download(repo_id=model_cfg_value, local_dir=local_path, local_dir_use_symlinks=False)
        except Exception as e:
            raise RuntimeError(f"Failed to load model {model_key}: {e}") from e
        return local_path

    def _unload_other_models(self, keep_model_type: str):
        """Drop every loaded model except *keep_model_type* and release cached memory.

        Only the engine's own reference is removed. A model that is still
        referenced elsewhere (e.g. by a caller's local variable) stays alive
        until that reference is dropped too.
        """
        for mtype in [m for m in self.models if m != keep_model_type]:
            print(f"🗑️ Unloading model [{mtype}] to free memory...")
            del self.models[mtype]
            if self.current_model_type == mtype:
                self.current_model_type = None
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def _get_model(self, model_type: str):
        """Return the model for *model_type*, loading it on first use.

        Other loaded models are unloaded first so that only one model occupies
        memory. With CUDA available, the model is loaded onto ``cuda:0`` in the
        configured dtype, falling back to CPU offloading on OOM. Without CUDA
        it is loaded on the CPU.

        Raises:
            RuntimeError: if the model cannot be downloaded. Loading errors
                other than CUDA out-of-memory propagate unchanged.
        """
        if model_type in self.models:
            self.current_model_type = model_type
            return self.models[model_type]

        # Free memory held by other models before loading a new one.
        self._unload_other_models(model_type)
        model_path = self._resolve_model(model_type)
        print(f"🚀 Loading model [{model_type}]...")

        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            # mem_get_info reports what the device can actually still provide.
            # total_memory - memory_allocated() ignores the allocator cache and
            # other processes.
            free_bytes, _total_bytes = torch.cuda.mem_get_info(0)
            print(f"📊 Free GPU memory before load: {free_bytes / (1024**3):.2f} GB")
            self.models[model_type] = self._load_on_gpu(model_type, model_path)
            allocated = torch.cuda.memory_allocated() / (1024**3)
            print(f"📊 GPU memory allocated: {allocated:.2f} GB")
        else:
            self.models[model_type] = self._load_on_cpu(model_type, model_path)

        self.current_model_type = model_type
        return self.models[model_type]

    def _load_on_gpu(self, model_type: str, model_path: str):
        """Load the model fully onto ``cuda:0``; on CUDA OOM, retry with CPU offloading."""
        label = self._dtype_label(self.dtype)
        try:
            # Strategy 1: everything on the GPU in the configured dtype.
            print(f"⚙️ Trying {label} on GPU...")
            model = Qwen3TTSModel.from_pretrained(
                model_path,
                dtype=self.dtype,
                device_map="cuda:0",
                low_cpu_mem_usage=True,
            )
            print(f"✅ Model [{model_type}] loaded on GPU ({label})")
            return model
        except RuntimeError as e:
            if "out of memory" not in str(e).lower():
                raise
        # Cleanup runs outside the except block. While the exception is being
        # handled, its traceback still references the frames holding the
        # partially loaded weights, so they could not be freed there.
        print(f"⚠️ GPU OOM, trying CPU offloading...")
        gc.collect()
        torch.cuda.empty_cache()
        # Strategy 2: accelerate-style device_map with CPU offloading.
        # bitsandbytes quantization is deliberately avoided here because it
        # caused a pickle error.
        print(f"⚙️ Using accelerate with CPU offloading ({label})...")
        model = Qwen3TTSModel.from_pretrained(
            model_path,
            dtype=self.dtype,
            device_map="auto",
            max_memory=dict(self.OFFLOAD_MAX_MEMORY),  # copy: never share the class default
            low_cpu_mem_usage=True,
        )
        print(f"✅ Model [{model_type}] loaded with CPU offloading")
        return model

    def _load_on_cpu(self, model_type: str, model_path: str):
        """Load the model on the CPU (used when CUDA is unavailable).

        float16 is promoted to float32 because many CPU kernels lack
        half-precision support.
        """
        cpu_dtype = torch.float32 if self.dtype == torch.float16 else self.dtype
        print(f"⚙️ CUDA not available, loading on CPU ({self._dtype_label(cpu_dtype)})...")
        model = Qwen3TTSModel.from_pretrained(
            model_path,
            dtype=cpu_dtype,
            device_map="cpu",
            low_cpu_mem_usage=True,
        )
        print(f"✅ Model [{model_type}] loaded on CPU")
        return model

    @staticmethod
    def _read_prompt(sample_path: str) -> Optional[str]:
        """Return the stripped text of ``<sample_path>/prompt.txt``, or None if absent."""
        prompt_path = os.path.join(sample_path, "prompt.txt")
        if not os.path.exists(prompt_path):
            return None
        with open(prompt_path, 'r', encoding='utf-8') as f:
            return f.read().strip()

    def get_available_samples(self) -> List[Dict[str, str]]:
        """List voice samples: subfolders of ``storage.sample_dir`` containing ``audio.wav``.

        Returns:
            Entries sorted by folder name, each with ``name``, ``path`` and
            ``prompt`` (the contents of ``prompt.txt``, or ``""`` if missing).
        """
        samples = []
        sample_dir = self.config['storage']['sample_dir']
        if not os.path.isdir(sample_dir):
            return samples
        for name in sorted(os.listdir(sample_dir)):
            full_path = os.path.join(sample_dir, name)
            if os.path.isdir(full_path) and os.path.exists(os.path.join(full_path, "audio.wav")):
                prompt = self._read_prompt(full_path) or ""
                samples.append({"name": name, "path": full_path, "prompt": prompt})
        return samples

    def generate_with_sample(self, text: str, sample_path: str) -> Tuple[np.ndarray, int]:
        """Clone the voice in ``<sample_path>/audio.wav`` and speak *text* with it.

        ``prompt.txt`` in the same folder, if present, is passed as the
        reference transcript; otherwise ``ref_text`` is None.

        Returns:
            ``(wavs, sample_rate)`` as returned by the Base model.
        """
        model = self._get_model('base')
        audio_file = os.path.join(sample_path, "audio.wav")
        ref_text = self._read_prompt(sample_path)
        print(f"🎤 Cloning voice...")
        prompt = model.create_voice_clone_prompt(ref_audio=audio_file, ref_text=ref_text)
        wavs, sr = model.generate_voice_clone(
            text=text,
            language=self.config['generation']['default_language'],
            voice_clone_prompt=prompt,
        )
        return wavs, sr

    def generate_with_description(self, text: str, description: str) -> Tuple[np.ndarray, int]:
        """Design a voice from *description*, then clone it to speak *text*.

        VoiceDesign synthesizes a reference clip from the first 100 characters
        of *text*. The Base model then clones that clip's voice for the full
        text.
        """
        print(f"🎨 Designing voice: '{description}'")
        # Generate the reference clip with VoiceDesign.
        vd_model = self._get_model('voice_design')
        ref_text = text[:100]
        ref_wavs, ref_sr = vd_model.generate_voice_design(
            text=ref_text,
            language=self.config['generation']['default_language'],
            instruct=description,
        )
        # Loading Base removes VoiceDesign from self.models, but its weights are
        # only freed once no reference remains, so drop ours first.
        del vd_model
        base_model = self._get_model('base')
        prompt = base_model.create_voice_clone_prompt(
            ref_audio=(ref_wavs[0], ref_sr),
            ref_text=ref_text,
        )
        wavs, sr = base_model.generate_voice_clone(
            text=text,
            language=self.config['generation']['default_language'],
            voice_clone_prompt=prompt,
        )
        return wavs, sr

    def generate_voice_design_only(self, text: str, description: str) -> Tuple[np.ndarray, int]:
        """Preview mode: synthesize *text* with VoiceDesign only, without cloning."""
        print(f"🎨 VoiceDesign preview: '{description}'")
        model = self._get_model('voice_design')
        wavs, sr = model.generate_voice_design(
            text=text,
            language=self.config['generation']['default_language'],
            instruct=description,
        )
        return wavs, sr

    def generate_standard(self, text: str, speaker: Optional[str] = None) -> Tuple[np.ndarray, int]:
        """Speak *text* with a preset speaker (``generation.default_speaker`` if None)."""
        model = self._get_model('custom_voice')
        speaker = speaker or self.config['generation']['default_speaker']
        print(f"🗣️ Using speaker: {speaker}")
        wavs, sr = model.generate_custom_voice(
            text=text,
            language=self.config['generation']['default_language'],
            speaker=speaker,
        )
        return wavs, sr

    def download_all_models(self):
        """Resolve (and download if needed) every known model, printing a status line each.

        Best-effort: a failure for one model is printed, not raised.
        """
        print("\n--- Checking models ---")
        for key in ['base', 'voice_design', 'custom_voice']:
            try:
                self._resolve_model(key)
                print(f"{key}: OK")
            except Exception as e:
                print(f"{key}: {e}")

    def get_custom_speakers_list(self):
        """Return the preset speakers supported by the custom-voice model.

        This loads the custom-voice model, unloading any other model. If
        loading or querying fails, or the model reports no speakers,
        a copy of ``FALLBACK_SPEAKERS`` is returned instead.
        """
        try:
            model = self._get_model('custom_voice')
            speakers = model.get_supported_speakers()
        except Exception as e:
            print(f"Error: {e}")
            return list(self.FALLBACK_SPEAKERS)
        if speakers is None:
            return list(self.FALLBACK_SPEAKERS)
        if isinstance(speakers, str):
            # A bare string is one speaker, not an iterable of characters.
            return [speakers]
        return list(speakers) if hasattr(speakers, '__iter__') else speakers

    @staticmethod
    def _unique_stem(out_dir: str, stem: str) -> str:
        """Return *stem*, or ``<stem>_N`` if a ``.wav``/``.txt`` with that name exists.

        Prevents two results generated within the same second from
        overwriting each other.
        """
        candidate, counter = stem, 1
        while any(os.path.exists(os.path.join(out_dir, candidate + ext)) for ext in (".wav", ".txt")):
            candidate = f"{stem}_{counter}"
            counter += 1
        return candidate

    def save_result(self, text: str, wavs: np.ndarray, sr: int) -> str:
        """Write the first waveform in *wavs* and its text next to it; return the WAV path.

        Files are named ``speech_<YYYYmmdd_HHMMSS>[_N].wav/.txt`` in
        ``storage.output_dir``. *wavs* may be a batch (list, or an array with
        more than one dimension), in which case ``wavs[0]`` is saved, or a
        single 1-D waveform.
        """
        out_dir = self.config['storage']['output_dir']
        os.makedirs(out_dir, exist_ok=True)
        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        filename = self._unique_stem(out_dir, f"speech_{timestamp}")
        wav_path = os.path.join(out_dir, f"{filename}.wav")
        txt_path = os.path.join(out_dir, f"{filename}.txt")
        audio = wavs if getattr(wavs, "ndim", None) == 1 else wavs[0]
        sf.write(wav_path, audio, sr)
        with open(txt_path, 'w', encoding='utf-8') as f:
            f.write(text)
        return wav_path

    def get_history(self) -> List[Dict[str, str]]:
        """List saved ``.wav`` results in reverse filename order.

        For files named by ``save_result`` this order is newest first. Each
        entry has ``filename``, ``wav_path``, ``txt_path`` and ``text``; the
        text is a placeholder when the companion ``.txt`` file is missing.
        """
        out_dir = self.config['storage']['output_dir']
        history = []
        if not os.path.isdir(out_dir):
            return history
        for name in sorted(os.listdir(out_dir), reverse=True):
            if not name.endswith(".wav"):
                continue
            base_name = name[:-4]
            txt_path = os.path.join(out_dir, f"{base_name}.txt")
            wav_path = os.path.join(out_dir, name)
            text_content = "(Текст не найден)"
            if os.path.exists(txt_path):
                with open(txt_path, 'r', encoding='utf-8') as file:
                    text_content = file.read()
            history.append({"filename": name, "wav_path": wav_path, "txt_path": txt_path, "text": text_content})
        return history