FIX state: current, in working state

This commit is contained in:
2026-03-23 20:08:04 +03:00
parent 5b79cfeb71
commit 9af128ffc6
4 changed files with 364 additions and 187 deletions

View File

@@ -1,115 +1,162 @@
import torch
import torch
import soundfile as sf
import yaml
import os
import time
import datetime
import numpy as np
import gc
from pathlib import Path
from typing import Optional, List, Dict, Tuple
from huggingface_hub import snapshot_download
# ---
# Блок импорта модели.
# Если у вас установлен отдельный пакет qwen-tts:
try:
    from qwen_tts import Qwen3TTSModel
except ImportError:
    # Stub for testing without the real model (emulation).
    print("WARNING: qwen_tts not found. Using Mock Model for testing.")

    class Qwen3TTSModel:
        """Mock that mirrors the real Qwen3TTSModel API surface used below."""

        @staticmethod
        def from_pretrained(path, **kwargs):
            # Ignores all loading kwargs; returns a fresh mock instance.
            print(f"[Mock] Loading model from {path}")
            return Qwen3TTSModel()

        def create_voice_clone_prompt(self, **kwargs):
            # Opaque token standing in for a real clone prompt object.
            return "mock_prompt"

        def generate_voice_clone(self, text, **kwargs):
            print(f"[Mock] Generating voice clone for: {text[:30]}...")
            # Return noise shaped (1, samples): ~0.1 s of 24 kHz audio per char.
            sr = 24000
            duration = len(text) * 0.1
            return np.random.rand(1, int(sr * duration)).astype(np.float32), sr

        def generate_custom_voice(self, text, **kwargs):
            # Delegates: all mock generation paths produce the same noise.
            return self.generate_voice_clone(text, **kwargs)

        def generate_voice_design(self, text, **kwargs):
            return self.generate_voice_clone(text, **kwargs)

        def get_supported_speakers(self):
            return ["Chelsie", "Dylan", "Eric", "Serena", "Vivian", "Aiden", "Ryan"]
# ---
class TTSEngine:
def __init__(self, config_path: str = "config.yaml"):
    """Load the YAML config, choose a torch dtype, and create storage dirs.

    Args:
        config_path: Path to the YAML configuration file.
    """
    with open(config_path, 'r', encoding='utf-8') as f:
        self.config = yaml.safe_load(f)
    # Lazily-loaded models keyed by type ('base', 'voice_design', 'custom_voice').
    self.models = {}
    self.current_model_type = None
    # Resolve dtype from config; fall back to FP16 on a bad or missing value.
    # (The pasted diff kept both the old float32 fallback and an unconditional
    # float16 overwrite — collapsed here to one config-driven assignment.)
    try:
        self.dtype = getattr(torch, self.config['generation']['dtype'])
    except (AttributeError, KeyError, TypeError):
        self.dtype = torch.float16  # default: FP16
    if torch.cuda.is_available():
        gpu_mem = torch.cuda.get_device_properties(0).total_memory / (1024**3)
        print(f"📊 GPU: {torch.cuda.get_device_name(0)}")
        print(f"📊 GPU Memory: {gpu_mem:.1f} GB")
    # Create the storage folders up front.
    Path(self.config['storage']['model_path']).mkdir(parents=True, exist_ok=True)
    Path(self.config['storage']['sample_dir']).mkdir(parents=True, exist_ok=True)
    Path(self.config['storage']['output_dir']).mkdir(parents=True, exist_ok=True)
def _resolve_model(self, model_key: str) -> str:
    """Resolve a configured model reference to a local directory path.

    Resolution order:
      1. Absolute (or already existing) path in config -> use it directly.
      2. Already present inside storage model_path -> use it.
      3. Otherwise download the repo from Hugging Face into model_path.

    Args:
        model_key: Key in config['models'] (e.g. 'base', 'voice_design').

    Returns:
        Filesystem path to the model directory.

    Raises:
        RuntimeError: If the Hugging Face download fails.
    """
    model_cfg_value = self.config['models'][model_key]
    base_model_path = self.config['storage']['model_path']
    # 1. Direct path: absolute, or exists relative to the working directory.
    if os.path.isabs(model_cfg_value) or os.path.exists(model_cfg_value):
        print(f"📂 Model [{model_key}]: Using direct path {model_cfg_value}")
        return model_cfg_value
    # 2. Path inside storage, named after the repo id's last path segment.
    folder_name = model_cfg_value.split('/')[-1]
    local_path = os.path.join(base_model_path, folder_name)
    # Non-empty directory counts as already downloaded.
    if os.path.exists(local_path) and os.listdir(local_path):
        print(f"📂 Model [{model_key}]: Found locally at {local_path}")
        return local_path
    # 3. Download from Hugging Face.
    print(f"⬇️ Downloading {model_key}...")
    try:
        snapshot_download(repo_id=model_cfg_value, local_dir=local_path, local_dir_use_symlinks=False)
        return local_path
    except Exception as e:
        print(f"❌ Error downloading model {model_cfg_value}: {e}")
        raise RuntimeError(f"Failed to load model {model_key}: {e}")
def _unload_other_models(self, keep_model_type: str):
    """Drop every cached model except ``keep_model_type`` to reclaim memory."""
    stale_types = [name for name in list(self.models) if name != keep_model_type]
    for name in stale_types:
        print(f"🗑️ Unloading model [{name}] to free memory...")
        del self.models[name]
        # Reclaim host memory, then GPU memory if CUDA is present.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
def _get_model(self, model_type: str):
    """Return the requested model, loading it (and evicting others) on demand.

    Loading strategy:
      1. Try FP16 entirely on the GPU.
      2. On CUDA OOM, retry with accelerate-style CPU offloading
         (device_map="auto" and a capped GPU memory budget).

    Args:
        model_type: Key in config['models'] identifying the model.

    Returns:
        The loaded model instance (cached in ``self.models``).
    """
    # Already loaded -> return the cached instance.
    if model_type in self.models:
        return self.models[model_type]
    # Evict other models first: only one model is kept resident at a time.
    self._unload_other_models(model_type)
    model_path = self._resolve_model(model_type)
    print(f"🚀 Loading model [{model_type}]...")
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        free_before = (torch.cuda.get_device_properties(0).total_memory - torch.cuda.memory_allocated()) / (1024**3)
        print(f"📊 Free GPU memory before load: {free_before:.2f} GB")
    try:
        # Strategy 1: everything on GPU in FP16.
        print(f"⚙️ Trying FP16 on GPU...")
        self.models[model_type] = Qwen3TTSModel.from_pretrained(
            model_path,
            dtype=torch.float16,
            device_map="cuda:0",
            low_cpu_mem_usage=True
        )
        print(f"✅ Model [{model_type}] loaded on GPU (FP16)")
    except RuntimeError as e:
        if "out of memory" in str(e).lower():
            print(f"⚠️ GPU OOM, trying CPU offloading...")
            torch.cuda.empty_cache()
            gc.collect()
            # Strategy 2: accelerate with offloading; deliberately avoids
            # bitsandbytes, which triggered a pickle error.
            print(f"⚙️ Using accelerate with CPU offloading (FP16)...")
            # Cap GPU memory to force offloading; ~3GB is kept for one model.
            max_memory = {0: "3GiB", "cpu": "28GiB"}
            self.models[model_type] = Qwen3TTSModel.from_pretrained(
                model_path,
                dtype=torch.float16,
                device_map="auto",
                max_memory=max_memory,
                low_cpu_mem_usage=True
            )
            print(f"✅ Model [{model_type}] loaded with CPU offloading")
        else:
            raise  # preserve original traceback for non-OOM failures
    if torch.cuda.is_available():
        allocated = torch.cuda.memory_allocated() / (1024**3)
        print(f"📊 GPU memory allocated: {allocated:.2f} GB")
    return self.models[model_type]
def get_available_samples(self) -> List[Dict[str, str]]:
    """List voice-sample folders found in the configured sample_dir.

    Each sample is a subfolder; an optional ``prompt.txt`` inside it holds
    the reference transcript.

    Returns:
        List of dicts with keys 'name', 'path', 'prompt'; empty list if
        the sample directory does not exist.
    """
    samples = []
    sample_dir = self.config['storage']['sample_dir']
    if not os.path.exists(sample_dir):
        return samples
    for name in sorted(os.listdir(sample_dir)):
        full_path = os.path.join(sample_dir, name)
        if not os.path.isdir(full_path):
            continue
        # NOTE(review): a diff hunk marker hid a few lines here; presumably
        # they initialized `prompt` (and may have validated audio.wav) —
        # confirm against the full file.
        prompt = ""
        prompt_path = os.path.join(full_path, "prompt.txt")
        if os.path.exists(prompt_path):
            with open(prompt_path, 'r', encoding='utf-8') as f:
                prompt = f.read().strip()
        samples.append({"name": name, "path": full_path, "prompt": prompt})
    return samples
def generate_with_sample(self, text: str, sample_path: str) -> Tuple[np.ndarray, int]:
    """Mode 1: clone a voice from a stored sample folder.

    Args:
        text: Text to synthesize.
        sample_path: Folder containing ``audio.wav`` and optional ``prompt.txt``.

    Returns:
        Tuple of (audio batch, sample rate).
    """
    model = self._get_model('base')
    audio_file = os.path.join(sample_path, "audio.wav")
    prompt_file = os.path.join(sample_path, "prompt.txt")
    ref_text = None
    if os.path.exists(prompt_file):
        with open(prompt_file, 'r', encoding='utf-8') as f:
            ref_text = f.read().strip()
    print(f"🎤 Cloning voice...")
    prompt = model.create_voice_clone_prompt(ref_audio=audio_file, ref_text=ref_text)
    # NOTE(review): the diff hunk truncated this call's keyword arguments;
    # presumably the clone prompt is passed through — confirm the exact
    # parameter name against the qwen_tts API.
    wavs, sr = model.generate_voice_clone(
        text=text,
        language=self.config['generation']['default_language'],
        voice_clone_prompt=prompt
    )
    return wavs, sr
def generate_with_description(self, text: str, description: str) -> Tuple[np.ndarray, int]:
    """Mode 2: design a voice from a text description, then clone it.

    Step A generates a short reference sample with the VoiceDesign model;
    step B clones that generated voice with the Base model.

    Args:
        text: Text to synthesize.
        description: Natural-language description of the desired voice.

    Returns:
        Tuple of (audio batch, sample rate).
    """
    print(f"🎨 Designing voice: '{description}'")
    # Step A: generate a reference sample via VoiceDesign.
    vd_model = self._get_model('voice_design')
    # Keep the reference short: at most the first 100 characters.
    ref_text = text[:100] if len(text) > 100 else text
    ref_wavs, ref_sr = vd_model.generate_voice_design(
        text=ref_text,
        language=self.config['generation']['default_language'],
        instruct=description
    )
    # Step B: clone the designed voice with the Base model
    # (switching models auto-unloads VoiceDesign via _get_model).
    base_model = self._get_model('base')
    # ref_audio accepts a (numpy_array, sample_rate) tuple.
    prompt = base_model.create_voice_clone_prompt(
        ref_audio=(ref_wavs[0], ref_sr),
        ref_text=ref_text
    )
    # NOTE(review): trailing kwargs were lost to the diff hunk; presumably
    # the clone prompt — confirm against the qwen_tts API.
    wavs, sr = base_model.generate_voice_clone(
        text=text,
        language=self.config['generation']['default_language'],
        voice_clone_prompt=prompt
    )
    return wavs, sr
def generate_voice_design_only(self, text: str, description: str) -> Tuple[np.ndarray, int]:
    """Preview mode: synthesize directly with VoiceDesign, skipping the clone step."""
    print(f"🎨 VoiceDesign preview: '{description}'")
    design_model = self._get_model('voice_design')
    language = self.config['generation']['default_language']
    audio, sample_rate = design_model.generate_voice_design(
        text=text,
        language=language,
        instruct=description
    )
    return audio, sample_rate
def generate_standard(self, text: str, speaker: str = None) -> Tuple[np.ndarray, int]:
    """Mode 3: synthesize with a built-in speaker of the CustomVoice model.

    Args:
        text: Text to synthesize.
        speaker: Built-in speaker name; defaults to config default_speaker.

    Returns:
        Tuple of (audio batch, sample rate).
    """
    model = self._get_model('custom_voice')
    speaker = speaker or self.config['generation']['default_speaker']
    print(f"🗣️ Using speaker: {speaker}")
    # NOTE(review): trailing kwargs were lost to the diff hunk; presumably
    # the speaker selection — confirm the exact parameter name against the
    # qwen_tts API.
    wavs, sr = model.generate_custom_voice(
        text=text,
        language=self.config['generation']['default_language'],
        speaker=speaker
    )
    return wavs, sr
def download_all_models(self):
    """Resolve (and download if needed) every configured model, printing status."""
    print("\n--- Checking models ---")
    model_keys = ('base', 'voice_design', 'custom_voice')
    for key in model_keys:
        try:
            self._resolve_model(key)
            print(f"{key}: OK")
        except Exception as err:
            print(f"{key}: {err}")
def get_custom_speakers_list(self):
    """Return the speaker names supported by the custom-voice model.

    Falls back to a hard-coded default list when the model can't be queried.
    """
    try:
        cv_model = self._get_model('custom_voice')
        supported = cv_model.get_supported_speakers()
        if hasattr(supported, '__iter__'):
            return list(supported)
        return supported
    except Exception as e:
        print(f"Error: {e}")
        return ["Chelsie", "Dylan", "Eric", "Serena", "Vivian", "Aiden", "Ryan"]
def save_result(self, text: str, wavs: np.ndarray, sr: int) -> str:
    """Save generated audio and its source text into the output folder.

    Args:
        text: The synthesized text (written alongside as .txt).
        wavs: Audio batch; only wavs[0] is written.
        sr: Sample rate in Hz.

    Returns:
        Path to the written WAV file.
    """
    out_dir = self.config['storage']['output_dir']
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    # Use the timestamped base name; fixes the "(unknown)" placeholder that
    # left `filename` unused and made every generation overwrite one file.
    filename = f"speech_{timestamp}"
    wav_path = os.path.join(out_dir, f"{filename}.wav")
    txt_path = os.path.join(out_dir, f"{filename}.txt")
    # Write the audio, then the companion transcript.
    sf.write(wav_path, wavs[0], sr)
    with open(txt_path, 'w', encoding='utf-8') as f:
        f.write(text)
    return wav_path
def get_history(self) -> List[Dict[str, str]]:
    """Return previously generated files (newest first) with their texts.

    Returns:
        List of dicts with keys 'filename', 'wav_path', 'txt_path', 'text';
        empty list if the output directory does not exist.
    """
    out_dir = self.config['storage']['output_dir']
    history = []
    if not os.path.exists(out_dir):
        return history
    # The pasted diff kept both the old and new loop headers and both
    # append variants — collapsed here to a single pass.
    for f in sorted(os.listdir(out_dir), reverse=True):
        if not f.endswith(".wav"):
            continue
        base_name = f[:-4]
        txt_path = os.path.join(out_dir, f"{base_name}.txt")
        wav_path = os.path.join(out_dir, f)
        text_content = "(Текст не найден)"  # runtime fallback string, kept verbatim
        if os.path.exists(txt_path):
            with open(txt_path, 'r', encoding='utf-8') as file:
                text_content = file.read()
        history.append({"filename": f, "wav_path": wav_path, "txt_path": txt_path, "text": text_content})
    return history