296 lines
12 KiB
Python
296 lines
12 KiB
Python
import torch
|
||
import soundfile as sf
|
||
import yaml
|
||
import os
|
||
import time
|
||
import datetime
|
||
import numpy as np
|
||
import gc
|
||
from pathlib import Path
|
||
from typing import Optional, List, Dict, Tuple
|
||
from huggingface_hub import snapshot_download
|
||
|
||
# ---
# Model import block: fall back to a lightweight mock when qwen_tts is absent.
try:
    from qwen_tts import Qwen3TTSModel
except ImportError:
    print("WARNING: qwen_tts not found. Using Mock Model for testing.")

    class Qwen3TTSModel:
        """Drop-in stand-in for the real TTS model, used for local testing.

        Mirrors the subset of the real API used by TTSEngine and returns
        white-noise audio whose length scales with the input text.
        """

        @staticmethod
        def from_pretrained(path, **kwargs):
            # The mock ignores all loading options.
            print(f"[Mock] Loading model from {path}")
            return Qwen3TTSModel()

        def create_voice_clone_prompt(self, **kwargs):
            return "mock_prompt"

        def generate_voice_clone(self, text, **kwargs):
            # Fake audio: 0.1 s of random samples per character at 24 kHz.
            print(f"[Mock] Generating voice clone for: {text[:30]}...")
            sample_rate = 24000
            seconds = len(text) * 0.1
            noise = np.random.rand(1, int(sample_rate * seconds))
            return noise.astype(np.float32), sample_rate

        def generate_custom_voice(self, text, **kwargs):
            # All generation modes share the same mock implementation.
            return self.generate_voice_clone(text, **kwargs)

        def generate_voice_design(self, text, **kwargs):
            return self.generate_voice_clone(text, **kwargs)

        def get_supported_speakers(self):
            return ["Chelsie", "Dylan", "Eric", "Serena", "Vivian", "Aiden", "Ryan"]
# ---
class TTSEngine:
    """On-demand loader and front-end for Qwen3-TTS model variants.

    Keeps at most one model variant ('base', 'voice_design', 'custom_voice')
    resident at a time; switching variants unloads the previous one to free
    GPU/CPU memory before the new one is loaded.
    """

    def __init__(self, config_path: str = "config.yaml"):
        """Read the YAML config, resolve the torch dtype and create storage folders."""
        with open(config_path, 'r', encoding='utf-8') as f:
            self.config = yaml.safe_load(f)

        # Loaded models keyed by model type; see _get_model().
        self.models: Dict[str, object] = {}
        self.current_model_type: Optional[str] = None

        # Resolve the configured dtype name; fall back to FP16 when the name
        # is unknown or the config key is missing/malformed.
        try:
            self.dtype = getattr(torch, self.config['generation']['dtype'])
        except (AttributeError, KeyError, TypeError):
            self.dtype = torch.float16  # default: FP16

        if torch.cuda.is_available():
            gpu_mem = torch.cuda.get_device_properties(0).total_memory / (1024**3)
            print(f"📊 GPU: {torch.cuda.get_device_name(0)}")
            print(f"📊 GPU Memory: {gpu_mem:.1f} GB")

        # Make sure all storage folders exist.
        for key in ('model_path', 'sample_dir', 'output_dir'):
            Path(self.config['storage'][key]).mkdir(parents=True, exist_ok=True)

    def _resolve_model(self, model_key: str) -> str:
        """Return a local path for *model_key*, downloading the model if needed.

        The config value may be an existing local path (used as-is) or a HF
        repo id, which is downloaded under storage.model_path on first use.

        Raises:
            RuntimeError: if the download fails.
        """
        model_cfg_value = self.config['models'][model_key]
        base_model_path = self.config['storage']['model_path']

        # Already a usable local path.
        if os.path.isabs(model_cfg_value) or os.path.exists(model_cfg_value):
            return model_cfg_value

        folder_name = model_cfg_value.split('/')[-1]
        local_path = os.path.join(base_model_path, folder_name)

        # Previously downloaded and non-empty — reuse it.
        if os.path.exists(local_path) and os.listdir(local_path):
            return local_path

        print(f"⬇️ Downloading {model_key}...")
        try:
            snapshot_download(repo_id=model_cfg_value, local_dir=local_path, local_dir_use_symlinks=False)
            return local_path
        except Exception as e:
            # Chain the original cause for easier debugging.
            raise RuntimeError(f"Failed to load model {model_key}: {e}") from e

    def _unload_other_models(self, keep_model_type: str):
        """Unload every cached model except *keep_model_type* and free memory."""
        for mtype in list(self.models.keys()):
            if mtype != keep_model_type:
                print(f"🗑️ Unloading model [{mtype}] to free memory...")
                del self.models[mtype]
                gc.collect()
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

    def _get_model(self, model_type: str):
        """Return the model for *model_type*, loading (and caching) it if needed.

        Tries full-GPU FP16 first; on CUDA OOM falls back to accelerate-style
        CPU offloading with a capped GPU memory budget.
        """
        # Already loaded — return the cached instance.
        if model_type in self.models:
            return self.models[model_type]

        # Free memory held by other variants before loading a new one.
        self._unload_other_models(model_type)

        model_path = self._resolve_model(model_type)
        print(f"🚀 Loading model [{model_type}]...")

        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            free_before = (torch.cuda.get_device_properties(0).total_memory - torch.cuda.memory_allocated()) / (1024**3)
            print(f"📊 Free GPU memory before load: {free_before:.2f} GB")

        try:
            # Strategy 1: try everything on the GPU in FP16.
            print(f"⚙️ Trying FP16 on GPU...")
            self.models[model_type] = Qwen3TTSModel.from_pretrained(
                model_path,
                dtype=torch.float16,
                device_map="cuda:0",
                low_cpu_mem_usage=True
            )
            print(f"✅ Model [{model_type}] loaded on GPU (FP16)")
        except RuntimeError as e:
            # Only an OOM is recoverable; re-raise anything else as-is.
            if "out of memory" not in str(e).lower():
                raise
            print(f"⚠️ GPU OOM, trying CPU offloading...")
            torch.cuda.empty_cache()
            gc.collect()

            # Strategy 2: accelerate with offloading. Avoids bitsandbytes,
            # which triggers a pickle error here.
            print(f"⚙️ Using accelerate with CPU offloading (FP16)...")

            # Cap GPU memory to force offloading; leave 3 GB for one model.
            max_memory = {0: "3GiB", "cpu": "28GiB"}

            self.models[model_type] = Qwen3TTSModel.from_pretrained(
                model_path,
                dtype=torch.float16,
                device_map="auto",
                max_memory=max_memory,
                low_cpu_mem_usage=True
            )
            print(f"✅ Model [{model_type}] loaded with CPU offloading")

        if torch.cuda.is_available():
            allocated = torch.cuda.memory_allocated() / (1024**3)
            print(f"📊 GPU memory allocated: {allocated:.2f} GB")

        # Record which variant is active (attribute existed but was never set).
        self.current_model_type = model_type
        return self.models[model_type]

    def get_available_samples(self) -> List[Dict[str, str]]:
        """List voice-sample folders: each must contain audio.wav, optionally prompt.txt."""
        samples: List[Dict[str, str]] = []
        sample_dir = self.config['storage']['sample_dir']
        if not os.path.exists(sample_dir):
            return samples

        for name in sorted(os.listdir(sample_dir)):
            full_path = os.path.join(sample_dir, name)
            if not os.path.isdir(full_path):
                continue
            audio_path = os.path.join(full_path, "audio.wav")
            prompt_path = os.path.join(full_path, "prompt.txt")
            if not os.path.exists(audio_path):
                continue
            prompt = ""
            if os.path.exists(prompt_path):
                with open(prompt_path, 'r', encoding='utf-8') as f:
                    prompt = f.read().strip()
            samples.append({"name": name, "path": full_path, "prompt": prompt})
        return samples

    def generate_with_sample(self, text: str, sample_path: str) -> Tuple[np.ndarray, int]:
        """Clone the voice found in *sample_path* and speak *text* with it."""
        model = self._get_model('base')

        audio_file = os.path.join(sample_path, "audio.wav")
        prompt_file = os.path.join(sample_path, "prompt.txt")

        # The reference transcript is optional.
        ref_text = None
        if os.path.exists(prompt_file):
            with open(prompt_file, 'r', encoding='utf-8') as f:
                ref_text = f.read().strip()

        print(f"🎤 Cloning voice...")

        prompt = model.create_voice_clone_prompt(ref_audio=audio_file, ref_text=ref_text)
        wavs, sr = model.generate_voice_clone(
            text=text,
            language=self.config['generation']['default_language'],
            voice_clone_prompt=prompt
        )
        return wavs, sr

    def generate_with_description(self, text: str, description: str) -> Tuple[np.ndarray, int]:
        """Design a voice from *description* with VoiceDesign, then clone it to speak *text*."""
        print(f"🎨 Designing voice: '{description}'")

        # Produce a short reference clip with the VoiceDesign model.
        vd_model = self._get_model('voice_design')
        ref_text = text[:100]  # slicing already handles texts shorter than 100 chars

        ref_wavs, ref_sr = vd_model.generate_voice_design(
            text=ref_text,
            language=self.config['generation']['default_language'],
            instruct=description
        )

        # Switch to Base (VoiceDesign is unloaded automatically by _get_model).
        base_model = self._get_model('base')

        prompt = base_model.create_voice_clone_prompt(
            ref_audio=(ref_wavs[0], ref_sr),
            ref_text=ref_text
        )

        wavs, sr = base_model.generate_voice_clone(
            text=text,
            language=self.config['generation']['default_language'],
            voice_clone_prompt=prompt
        )
        return wavs, sr

    def generate_voice_design_only(self, text: str, description: str) -> Tuple[np.ndarray, int]:
        """Preview mode: VoiceDesign only, without the cloning pass."""
        print(f"🎨 VoiceDesign preview: '{description}'")

        model = self._get_model('voice_design')

        wavs, sr = model.generate_voice_design(
            text=text,
            language=self.config['generation']['default_language'],
            instruct=description
        )
        return wavs, sr

    def generate_standard(self, text: str, speaker: Optional[str] = None) -> Tuple[np.ndarray, int]:
        """Speak *text* with a built-in speaker (config default when omitted)."""
        model = self._get_model('custom_voice')
        speaker = speaker or self.config['generation']['default_speaker']

        print(f"🗣️ Using speaker: {speaker}")
        wavs, sr = model.generate_custom_voice(
            text=text,
            language=self.config['generation']['default_language'],
            speaker=speaker
        )
        return wavs, sr

    def download_all_models(self):
        """Ensure all three model variants are present locally (best-effort)."""
        print("\n--- Checking models ---")
        for key in ['base', 'voice_design', 'custom_voice']:
            try:
                self._resolve_model(key)
                print(f"✅ {key}: OK")
            except Exception as e:
                print(f"❌ {key}: {e}")

    def get_custom_speakers_list(self):
        """Return speakers supported by the CustomVoice model, with a static fallback."""
        try:
            model = self._get_model('custom_voice')
            speakers = model.get_supported_speakers()
            return list(speakers) if hasattr(speakers, '__iter__') else speakers
        except Exception as e:
            print(f"Error: {e}")
            return ["Chelsie", "Dylan", "Eric", "Serena", "Vivian", "Aiden", "Ryan"]

    def save_result(self, text: str, wavs: np.ndarray, sr: int) -> str:
        """Write the first channel of *wavs* plus *text* to timestamped output files.

        Returns:
            Path of the written .wav file.
        """
        out_dir = self.config['storage']['output_dir']
        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        filename = f"speech_{timestamp}"

        # BUG FIX: the computed timestamped name was previously unused and a
        # literal "(unknown)" placeholder was written instead, so every
        # generation overwrote the same pair of files.
        wav_path = os.path.join(out_dir, f"{filename}.wav")
        txt_path = os.path.join(out_dir, f"{filename}.txt")

        sf.write(wav_path, wavs[0], sr)
        with open(txt_path, 'w', encoding='utf-8') as f:
            f.write(text)
        return wav_path

    def get_history(self) -> List[Dict[str, str]]:
        """Return past generations (newest first), pairing each .wav with its .txt."""
        out_dir = self.config['storage']['output_dir']
        history: List[Dict[str, str]] = []
        if not os.path.exists(out_dir):
            return history

        for f in sorted(os.listdir(out_dir), reverse=True):
            if not f.endswith(".wav"):
                continue
            base_name = f[:-4]
            txt_path = os.path.join(out_dir, f"{base_name}.txt")
            wav_path = os.path.join(out_dir, f)

            text_content = "(Текст не найден)"
            if os.path.exists(txt_path):
                with open(txt_path, 'r', encoding='utf-8') as file:
                    text_content = file.read()

            history.append({"filename": f, "wav_path": wav_path, "txt_path": txt_path, "text": text_content})
        return history