From 55353654b781c684aba438b31476901de52951a2 Mon Sep 17 00:00:00 2001 From: Komisar Date: Thu, 16 Apr 2026 17:51:15 +0300 Subject: [PATCH] Add OmniVoice TTS module with config, API, profiles and CLI - Create modules/omnivoice/ with VoiceAPI, VoiceProfiles, CLI - Add config manager integration with local model support - Add app/komAI.py entry point - Add tests/test_omnivoice.py - Clone OmniVoice to external/ for development - Add omnivoice config to global.yaml --- CHECKLIST.md | 18 ++- app/komAI.py | 90 +++++++++++++ config/global.yaml | 21 +++ external/OmniVoice | 1 + modules/omnivoice/README.md | 93 +++++++++++++ modules/omnivoice/__init__.py | 40 ++++++ modules/omnivoice/api.py | 224 +++++++++++++++++++++++++++++++ modules/omnivoice/cli.py | 245 ++++++++++++++++++++++++++++++++++ modules/omnivoice/config.py | 64 +++++++++ modules/omnivoice/profiles.py | 171 ++++++++++++++++++++++++ tests/test_omnivoice.py | 100 ++++++++++++++ 11 files changed, 1064 insertions(+), 3 deletions(-) create mode 100644 app/komAI.py create mode 160000 external/OmniVoice create mode 100644 modules/omnivoice/README.md create mode 100644 modules/omnivoice/__init__.py create mode 100644 modules/omnivoice/api.py create mode 100644 modules/omnivoice/cli.py create mode 100644 modules/omnivoice/config.py create mode 100644 modules/omnivoice/profiles.py create mode 100644 tests/test_omnivoice.py diff --git a/CHECKLIST.md b/CHECKLIST.md index 646b81e..a336396 100644 --- a/CHECKLIST.md +++ b/CHECKLIST.md @@ -25,7 +25,19 @@ ## Pending - [ ] stdout/stderr перехват при аварийном завершении (on crash) -- [ ] Создать приложение `app/komAI.py` как точку входа -- [ ] Реализовать систему модулей (`modules/`) - [ ] Настроить CI/CD -- [ ] Написать интеграционные тесты \ No newline at end of file +- [ ] Написать интеграционные тесты + +## OmniVoice Integration + +- [x] Создан `modules/omnivoice/` + - `__init__.py`: register_config, register_logging, get_api, get_profiles + - `config.py`: model_name, device, dtype, num_steps, speed, profiles_dir, output_dir + - `api.py`: VoiceAPI: clone(), design(), auto(), generate(), save_audio() + - `profiles.py`: VoiceProfiles: add, remove, list, generate + - `cli.py`: clone, design, auto, profiles, profile-add, profile-remove, profile-use +- [x] Создан `app/komAI.py` - точка входа +- [x] Создан `tests/test_omnivoice.py` (6/6 тестов проходят) +- [x] Клонирован OmniVoice в `external/OmniVoice` +- [x] Создан `venv` с зависимостями +- [x] Документация: `modules/omnivoice/README.md` \ No newline at end of file diff --git a/app/komAI.py b/app/komAI.py new file mode 100644 index 0000000..acbc6ea --- /dev/null +++ b/app/komAI.py @@ -0,0 +1,90 @@ +import sys +import argparse + +import src.utils.config_manager as config +import src.utils.log_manager as log + +from modules.omnivoice import register_config, register_logging + + +def init(): + cfg = config.config + + try: + register_config() + except Exception as e: + print(f"Failed to register omnivoice config: {e}") + + cfg.load() + + log.register_global_params() + log.register(module="app", log_console=True, log_file="app.log") + log.setup() + + logger = log.get_logger("app") + logger.print("komAI initialized") + + try: + register_logging() + except Exception as e: + logger.print(f"Failed to register omnivoice logging: {e}", level="warning") + + return logger + + +def init_cli(): + import src.utils.config_manager as config + + cfg = config.config + + try: + from modules.omnivoice import register_config + + register_config() + except Exception as e: + print(f"Failed to register omnivoice config: {e}") + + cfg.load() + + +def run_cli(args=None): + from modules.omnivoice.cli import main as cli_main + + return cli_main(args) + + +def main(argv=None): + argv = argv or sys.argv[1:] + + parser = argparse.ArgumentParser(prog="komAI") + subparsers = parser.add_subparsers() + + p_voice = subparsers.add_parser("voice", help="OmniVoice TTS") + p_voice.add_argument("command", help="clone|design|auto") + p_voice.add_argument("args", nargs=argparse.REMAINDER) + + p_shell = subparsers.add_parser("shell", help="Interactive shell") + + args = parser.parse_args(argv) + + if hasattr(args, "command") and args.command: + cmd_args = [args.command] + args.args + return run_cli(cmd_args) + + if args.command == "shell" or (not argv): + init() + print("komAI shell. Use Ctrl+C to exit.") + try: + import code + + code.interact(local=locals()) + except KeyboardInterrupt: + pass + return 0 + + parser.print_help() + return 1 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/config/global.yaml b/config/global.yaml index 9edbf39..c021a80 100644 --- a/config/global.yaml +++ b/config/global.yaml @@ -9,3 +9,24 @@ logging.log_file: "app.log" #Путь к директории логов #logging.log_path: "./log" logging.log_path: "./log" + +# OmniVoice TTS +# Model name on HuggingFace +omnivoice.model_name: "k2-fsa/OmniVoice" +# Local model path (if exists, uses local model) +omnivoice.model_path: "models/OmniVoice" +# Device for inference +omnivoice.device: "cpu" +#omnivoice.device: "cuda:0" +# Data type +#omnivoice.dtype: "float16" +# Diffusion steps +#omnivoice.num_steps: 32 +# Speed factor +#omnivoice.speed: 1.0 +# Voice profiles directory +omnivoice.profiles_dir: "data/voice_profiles" +# Output directory for audio +omnivoice.output_dir: "output/voice" +# Enable omnivoice module +omnivoice.enabled: true diff --git a/external/OmniVoice b/external/OmniVoice new file mode 160000 index 0000000..6a3f23d --- /dev/null +++ b/external/OmniVoice @@ -0,0 +1 @@ +Subproject commit 6a3f23df5b23c93210fcb756f3c229ad40eb84db diff --git a/modules/omnivoice/README.md b/modules/omnivoice/README.md new file mode 100644 index 0000000..0b3e531 --- /dev/null +++ b/modules/omnivoice/README.md @@ -0,0 +1,93 @@ +# OmniVoice Module + +TTS (Text-to-Speech) модуль на базе OmniVoice. + +## VENV + +Все зависимости в venv: +```bash +venv\Scripts\python.exe -m pip install torch==2.8.0+cpu torchaudio==2.8.0+cpu --index-url https://download.pytorch.org/whl/cpu +venv\Scripts\python.exe -m pip install soundfile PyYAML transformers accelerate pydub gradio tensorboardX webdataset librosa +venv\Scripts\python.exe -m pip install -e external/OmniVoice +``` + +## Использование + +### Python API + +```python +# В venv +venv\Scripts\python.exe + +# Инициализация модуля +from modules.omnivoice import register_config, register_logging +register_config() +register_logging() + +from modules.omnivoice import get_api, get_profiles + +api = get_api() +profiles = get_profiles() + +# Voice Cloning +audio = api.clone( + text="Hello, this is a test.", + ref_audio="ref.wav", + ref_text="Reference transcription." +) + +# Voice Design +audio = api.design( + text="Hello, this is a test.", + instruct="female, british accent" +) + +# Auto Voice +audio = api.auto(text="Hello, this is a test.") + +# Сохранение +path = api.save_audio(audio[0], "output.wav") + +# Профили +profiles.save_from_generated("my_voice", "Hello", ref_audio="ref.wav") +audio = profiles.generate("my_voice", "Generated text") +``` + +### CLI + +```bash +# Скачать модель локально +venv\Scripts\python.exe -m modules.omnivoice.cli download + +# Voice Cloning (без --output = воспроизвести) +venv\Scripts\python.exe -m modules.omnivoice.cli clone --text "Hello" --ref-audio ref.wav +venv\Scripts\python.exe -m modules.omnivoice.cli clone --text "Hello" --ref-audio ref.wav --output hello.wav --profile my_voice + +# Voice Design +venv\Scripts\python.exe -m modules.omnivoice.cli design --text "Hello" --instruct "female, british accent" + +# Auto Voice +venv\Scripts\python.exe -m modules.omnivoice.cli auto --text "Hello" + +# Профили +venv\Scripts\python.exe -m modules.omnivoice.cli profiles +venv\Scripts\python.exe -m modules.omnivoice.cli profile-add --name my_voice --ref-audio ref.wav +venv\Scripts\python.exe -m modules.omnivoice.cli profile-remove --name my_voice +venv\Scripts\python.exe -m modules.omnivoice.cli profile-use --profile my_voice --text "Hello" +``` + +## Конфигурация + +Параметры в `config/global.yaml`: + +```yaml +# omnivoice +omnivoice.model_name: "k2-fsa/OmniVoice" +omnivoice.device: "cuda:0" +omnivoice.dtype: "float16" +omnivoice.num_steps: 32 +omnivoice.speed: 1.0 +omnivoice.profiles_dir: "data/voice_profiles" +omnivoice.output_dir: "output/voice" +omnivoice.enabled: true +``` \ No newline at end of file diff --git a/modules/omnivoice/__init__.py b/modules/omnivoice/__init__.py new file mode 100644 index 0000000..83938f2 --- /dev/null +++ b/modules/omnivoice/__init__.py @@ -0,0 +1,40 @@ +import src.utils.config_manager as config +import src.utils.log_manager as log + +config = config.config +logger = None + +VOICE_CATEGORY = "voice" + + +def register_config(): + from .config import register_params + + register_params() + + +def register_logging(): + global logger + log.register(module=VOICE_CATEGORY, log_console=True, log_file="omnivoice.log") + logger = log.get_logger(VOICE_CATEGORY) + + +def get_api(): + from .api import VoiceAPI + + return VoiceAPI() + + +def get_profiles(): + from .profiles import VoiceProfiles + + return VoiceProfiles() + + +__all__ = [ + "register_config", + "register_logging", + "get_api", + "get_profiles", + "VOICE_CATEGORY", +] diff --git a/modules/omnivoice/api.py b/modules/omnivoice/api.py new file mode 100644 index 0000000..185c8e6 --- /dev/null +++ b/modules/omnivoice/api.py @@ -0,0 +1,224 @@ +from pathlib import Path +from typing import Optional, List +import numpy as np + +import src.utils.config_manager as config_mgr +from .config import ( + MODEL_NAME, + MODEL_PATH, + DEVICE, + DTYPE, + NUM_STEPS, + SPEED, + OUTPUT_DIR, + DEFAULT_MODEL_NAME, + DEFAULT_MODEL_PATH, + DEFAULT_DEVICE, + DEFAULT_DTYPE, + DEFAULT_NUM_STEPS, + DEFAULT_SPEED, + DEFAULT_OUTPUT_DIR, +) + +config = config_mgr.config + + +def _cfg(key: str, default): + val = config.get(key, cat="omnivoice") + return val if val is not None else default + + +class VoiceAPI: + def __init__(self): + self._model = None + self._config = None + + @property + def _get_config(self): + if self._config is None: + model_path = _cfg(MODEL_PATH, DEFAULT_MODEL_PATH) + if model_path and Path(model_path).exists(): + actual_path = model_path + else: + actual_path = None + self._config = { + "model_name": _cfg(MODEL_NAME, DEFAULT_MODEL_NAME), + "model_path": actual_path, + "device": _cfg(DEVICE, DEFAULT_DEVICE), + "dtype": _cfg(DTYPE, DEFAULT_DTYPE), + "num_steps": _cfg(NUM_STEPS, DEFAULT_NUM_STEPS), + "speed": _cfg(SPEED, DEFAULT_SPEED), + "output_dir": _cfg(OUTPUT_DIR, DEFAULT_OUTPUT_DIR), + } + return self._config + + def _load_model(self): + if self._model is not None: + return self._model + try: + import torch + from omnivoice import OmniVoice + except ImportError: + raise ImportError("omnivoice not installed. Run: pip install omnivoice") + + cfg = self._get_config + dtype = getattr(torch, cfg["dtype"], torch.float16) + + if cfg["model_path"]: + self._model = OmniVoice.from_pretrained( + cfg["model_path"], + device_map=cfg["device"], + dtype=dtype, + ) + else: + self._model = OmniVoice.from_pretrained( + cfg["model_name"], + device_map=cfg["device"], + dtype=dtype, + ) + return self._model + + def generate( + self, + text: str, + ref_audio: Optional[str] = None, + ref_text: Optional[str] = None, + instruct: Optional[str] = None, + num_steps: Optional[int] = None, + speed: Optional[float] = None, + duration: Optional[float] = None, + ) -> List[np.ndarray]: + model = self._load_model() + + kwargs = {"text": text} + if ref_audio: + kwargs["ref_audio"] = ref_audio + if ref_text: + kwargs["ref_text"] = ref_text + if instruct: + kwargs["instruct"] = instruct + + cfg = self._get_config + kwargs["num_step"] = num_steps or cfg["num_steps"] or 32 + kwargs["speed"] = speed or cfg["speed"] or 1.0 + if duration: + kwargs["duration"] = duration + + audio = model.generate(**kwargs) + return audio + + def clone( + self, + text: str, + ref_audio: str, + ref_text: Optional[str] = None, + num_steps: Optional[int] = None, + speed: Optional[float] = None, + ) -> List[np.ndarray]: + return self.generate( + text=text, + ref_audio=ref_audio, + ref_text=ref_text, + num_steps=num_steps, + speed=speed, + ) + + def design( + self, + text: str, + instruct: str, + num_steps: Optional[int] = None, + speed: Optional[float] = None, + ) -> List[np.ndarray]: + return self.generate( + text=text, + instruct=instruct, + num_steps=num_steps, + speed=speed, + ) + + def auto( + self, + text: str, + num_steps: Optional[int] = None, + speed: Optional[float] = None, + ) -> List[np.ndarray]: + return self.generate( + text=text, + num_steps=num_steps, + speed=speed, + ) + + def save_audio( + self, + audio: np.ndarray, + filename: Optional[str] = None, + sample_rate: int = 24000, + ) -> Optional[Path]: + if filename is None: + return self.play_audio(audio, sample_rate) + + import soundfile as sf + + cfg = self._get_config + output_dir = Path(cfg["output_dir"] or "output/voice") + output_dir.mkdir(parents=True, exist_ok=True) + + path = output_dir / filename + sf.write(str(path), audio, sample_rate) + return path + + def play_audio(self, audio: np.ndarray, sample_rate: int = 24000) -> None: + try: + import sounddevice as sd + + sd.play(audio, sample_rate) + sd.wait() + except ImportError: + raise ImportError("sounddevice not installed. Run: pip install sounddevice") + + def speak( + self, + text: str, + ref_audio: Optional[str] = None, + ref_text: Optional[str] = None, + instruct: Optional[str] = None, + num_steps: Optional[int] = None, + speed: Optional[float] = None, + ) -> None: + audio = self.generate( + text=text, + ref_audio=ref_audio, + ref_text=ref_text, + instruct=instruct, + num_steps=num_steps, + speed=speed, + ) + self.play_audio(audio[0]) + + def reload(self): + self._model = None + self._config = None + + +api = VoiceAPI() + + +def get_api() -> VoiceAPI: + return api + + +def generate(*args, **kwargs) -> List[np.ndarray]: + return api.generate(*args, **kwargs) + + +def clone(*args, **kwargs) -> List[np.ndarray]: + return api.clone(*args, **kwargs) + + +def design(*args, **kwargs) -> List[np.ndarray]: + return api.design(*args, **kwargs) + + +def auto(*args, **kwargs) -> List[np.ndarray]: + return api.auto(*args, **kwargs) diff --git a/modules/omnivoice/cli.py b/modules/omnivoice/cli.py new file mode 100644 index 0000000..f9d69f2 --- /dev/null +++ b/modules/omnivoice/cli.py @@ -0,0 +1,245 @@ +import argparse +import sys +from pathlib import Path + + +def cmd_clone(args): + from .api import api + from .profiles import profiles + + audio = api.clone( + text=args.text, + ref_audio=args.ref_audio, + ref_text=args.ref_text, + num_steps=args.steps, + speed=args.speed, + ) + path = api.save_audio(audio[0], args.output) + if path: + print(f"Saved: {path}") + else: + print("Playing...") + if args.profile: + profiles.save_from_generated( + name=args.profile, + text=args.text, + ref_audio=args.ref_audio, + ref_text=args.ref_text, + ) + print(f"Profile saved: {args.profile}") + return 0 + + +def cmd_design(args): + from .api import api + from .profiles import profiles + + audio = api.design( + text=args.text, + instruct=args.instruct, + num_steps=args.steps, + speed=args.speed, + ) + path = api.save_audio(audio[0], args.output) + if path: + print(f"Saved: {path}") + else: + print("Playing...") + if args.profile: + profiles.save_from_generated( + name=args.profile, + text=args.text, + instruct=args.instruct, + ) + print(f"Profile saved: {args.profile}") + return 0 + + +def cmd_auto(args): + from .api import api + + audio = api.auto( + text=args.text, + num_steps=args.steps, + speed=args.speed, + ) + path = api.save_audio(audio[0], args.output) + if path: + print(f"Saved: {path}") + else: + print("Playing...") + return 0 + + +def cmd_profile_list(args): + from .profiles import profiles + + for name in profiles.list(): + p = profiles.get(name) + print(f"{name} [{p.mode}]: {p.description}") + return 0 + + +def cmd_profile_add(args): + from .profiles import profiles, VoiceProfile + + profile = VoiceProfile( + name=args.name, + ref_audio=args.ref_audio, + ref_text=args.ref_text, + instruct=args.instruct, + description=args.description or "", + ) + profiles.add(profile) + print(f"Profile added: {args.name}") + return 0 + + +def cmd_profile_remove(args): + from .profiles import profiles + + if profiles.remove(args.name): + print(f"Profile removed: {args.name}") + return 0 + print(f"Profile not found: {args.name}", file=sys.stderr) + return 1 + + +def cmd_profile_use(args): + from .api import api + from .profiles import profiles + + audio = profiles.generate(args.profile, args.text, args.steps, args.speed) + path = api.save_audio(audio[0], args.output) + if path: + print(f"Saved: {path}") + else: + print("Playing...") + return 0 + + +def cmd_download(args): + from huggingface_hub import snapshot_download + import src.utils.config_manager as config + from modules.omnivoice.config import ( + MODEL_PATH, + DEFAULT_MODEL_PATH, + MODEL_NAME, + DEFAULT_MODEL_NAME, + ) + + model_name = ( + args.model + or config.config.get(MODEL_NAME, cat="omnivoice") + or DEFAULT_MODEL_NAME + ) + + if args.path: + model_path = Path(args.path) + else: + model_path_str = ( + config.config.get(MODEL_PATH, cat="omnivoice") or DEFAULT_MODEL_PATH + ) + model_path = Path(model_path_str) + + if model_path.exists() and any(model_path.iterdir()): + print(f"Model already exists: {model_path}") + print("Delete folder to re-download") + return 0 + + print(f"Downloading model {model_name} to {model_path}...") + model_path.mkdir(parents=True, exist_ok=True) + + snapshot_download( + repo_id=model_name, + local_dir=str(model_path), + ignore_patterns=["*.pt", "*.bin", "*.pth"], + ) + + print(f"Model saved to: {model_path}") + return 0 + + +def main(argv=None): + argv = argv or sys.argv[1:] + + import src.utils.config_manager as config + from modules.omnivoice import register_config + + try: + register_config() + except Exception as e: + print(f"Failed to register omnivoice config: {e}") + + config.config.load() + + parser = argparse.ArgumentParser(prog="komai-voice") + subparsers = parser.add_subparsers() + + p_download = subparsers.add_parser("download", help="Download model locally") + p_download.add_argument( + "--model", default="k2-fsa/OmniVoice", help="Model name on HuggingFace" + ) + p_download.add_argument("--path", help="Local path to save model") + p_download.set_defaults(func=cmd_download) + + p_clone = subparsers.add_parser("clone", help="Voice cloning") + p_clone.add_argument("--text", required=True) + p_clone.add_argument("--ref-audio", required=True) + p_clone.add_argument("--ref-text") + p_clone.add_argument("--output") + p_clone.add_argument("--steps", type=int) + p_clone.add_argument("--speed", type=float) + p_clone.add_argument("--profile", help="Save as profile") + p_clone.set_defaults(func=cmd_clone) + + p_design = subparsers.add_parser("design", help="Voice design") + p_design.add_argument("--text", required=True) + p_design.add_argument("--instruct", required=True) + p_design.add_argument("--output") + p_design.add_argument("--steps", type=int) + p_design.add_argument("--speed", type=float) + p_design.add_argument("--profile", help="Save as profile") + p_design.set_defaults(func=cmd_design) + + p_auto = subparsers.add_parser("auto", help="Auto voice") + p_auto.add_argument("--text", required=True) + p_auto.add_argument("--output") + p_auto.add_argument("--steps", type=int) + p_auto.add_argument("--speed", type=float) + p_auto.set_defaults(func=cmd_auto) + + p_list = subparsers.add_parser("profiles", help="List profiles") + p_list.set_defaults(func=cmd_profile_list) + + p_add = subparsers.add_parser("profile-add", help="Add profile") + p_add.add_argument("--name", required=True) + p_add.add_argument("--ref-audio") + p_add.add_argument("--ref-text") + p_add.add_argument("--instruct") + p_add.add_argument("--description") + p_add.set_defaults(func=cmd_profile_add) + + p_rm = subparsers.add_parser("profile-remove", help="Remove profile") + p_rm.add_argument("--name", required=True) + p_rm.set_defaults(func=cmd_profile_remove) + + p_use = subparsers.add_parser("profile-use", help="Generate using profile") + p_use.add_argument("--profile", required=True) + p_use.add_argument("--text", required=True) + p_use.add_argument("--output") + p_use.add_argument("--steps", type=int) + p_use.add_argument("--speed", type=float) + p_use.set_defaults(func=cmd_profile_use) + + args = parser.parse_args(argv) + + if not hasattr(args, "func"): + parser.print_help() + return 1 + + return args.func(args) + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/modules/omnivoice/config.py b/modules/omnivoice/config.py new file mode 100644 index 0000000..8ec4183 --- /dev/null +++ b/modules/omnivoice/config.py @@ -0,0 +1,64 @@ +import src.utils.config_manager as config + +config = config.config + +MODEL_NAME = "model_name" +MODEL_PATH = "model_path" +DEVICE = "device" +DTYPE = "dtype" +NUM_STEPS = "num_steps" +SPEED = "speed" +PROFILES_DIR = "profiles_dir" +OUTPUT_DIR = "output_dir" +ENABLED = "enabled" + +DEFAULT_MODEL_NAME = "k2-fsa/OmniVoice" +DEFAULT_MODEL_PATH = "models/OmniVoice" +DEFAULT_DEVICE = "cuda:0" +DEFAULT_DTYPE = "float16" +DEFAULT_NUM_STEPS = 32 +DEFAULT_SPEED = 1.0 +DEFAULT_PROFILES_DIR = "data/voice_profiles" +DEFAULT_OUTPUT_DIR = "output/voice" +DEFAULT_ENABLED = True + + +def register_params(): + config.register( + name=MODEL_NAME, + val=DEFAULT_MODEL_NAME, + cat="omnivoice", + desc="Model name (HuggingFace) or use model_path for local", + ) + config.register( + name=MODEL_PATH, + val=DEFAULT_MODEL_PATH, + cat="omnivoice", + desc="Local model path (if set, overrides model_name)", + ) + config.register( + name=DEVICE, val=DEFAULT_DEVICE, cat="omnivoice", desc="Device for inference" + ) + config.register(name=DTYPE, val=DEFAULT_DTYPE, cat="omnivoice", desc="Data type") + config.register( + name=NUM_STEPS, val=DEFAULT_NUM_STEPS, cat="omnivoice", desc="Diffusion steps" + ) + config.register(name=SPEED, val=DEFAULT_SPEED, cat="omnivoice", desc="Speed factor") + config.register( + name=PROFILES_DIR, + val=DEFAULT_PROFILES_DIR, + cat="omnivoice", + desc="Voice profiles directory", + ) + config.register( + name=OUTPUT_DIR, + val=DEFAULT_OUTPUT_DIR, + cat="omnivoice", + desc="Output directory for audio", + ) + config.register( + name=ENABLED, + val=DEFAULT_ENABLED, + cat="omnivoice", + desc="Enable omnivoice module", + ) diff --git a/modules/omnivoice/profiles.py b/modules/omnivoice/profiles.py new file mode 100644 index 0000000..d26ea46 --- /dev/null +++ b/modules/omnivoice/profiles.py @@ -0,0 +1,171 @@ +import json +from pathlib import Path +from typing import Optional, Dict, List +import shutil + +import src.utils.config_manager as config_mgr +from .config import PROFILES_DIR, DEFAULT_PROFILES_DIR +from .api import api + +config = config_mgr.config + + +class VoiceProfile: + def __init__( + self, + name: str, + ref_audio: Optional[str] = None, + ref_text: Optional[str] = None, + instruct: Optional[str] = None, + description: str = "", + ): + self.name = name + self.ref_audio = ref_audio + self.ref_text = ref_text + self.instruct = instruct + self.description = description + + @property + def mode(self) -> str: + if self.ref_audio: + return "clone" + elif self.instruct: + return "design" + return "auto" + + def to_dict(self) -> Dict: + return { + "name": self.name, + "ref_audio": self.ref_audio, + "ref_text": self.ref_text, + "instruct": self.instruct, + "description": self.description, + "mode": self.mode, + } + + @classmethod + def from_dict(cls, data: Dict) -> "VoiceProfile": + return cls( + name=data["name"], + ref_audio=data.get("ref_audio"), + ref_text=data.get("ref_text"), + instruct=data.get("instruct"), + description=data.get("description", ""), + ) + + +class VoiceProfiles: + def __init__(self, profiles_dir: Optional[str] = None): + self._profiles_dir = ( + profiles_dir + or config.get(PROFILES_DIR, cat="omnivoice") + or DEFAULT_PROFILES_DIR + ) + self._profiles_path = Path(self._profiles_dir) / "profiles.json" + self._profiles: Dict[str, VoiceProfile] = {} + self._load() + + def _load(self): + if self._profiles_path.exists(): + with open(self._profiles_path, "r", encoding="utf-8") as f: + data = json.load(f) + for name, profile_data in data.items(): + self._profiles[name] = VoiceProfile.from_dict(profile_data) + + def _save(self): + self._profiles_path.parent.mkdir(parents=True, exist_ok=True) + data = {name: profile.to_dict() for name, profile in self._profiles.items()} + with open(self._profiles_path, "w", encoding="utf-8") as f: + json.dump(data, f, indent=2, ensure_ascii=False) + + def list(self) -> List[str]: + return sorted(self._profiles.keys()) + + def get(self, name: str) -> Optional[VoiceProfile]: + return self._profiles.get(name) + + def add(self, profile: VoiceProfile) -> None: + self._profiles[profile.name] = profile + self._save() + + def remove(self, name: str) -> bool: + if name in self._profiles: + del self._profiles[name] + self._save() + return True + return False + + def save_from_generated( + self, + name: str, + text: str, + ref_audio: Optional[str] = None, + ref_text: Optional[str] = None, + instruct: Optional[str] = None, + description: str = "", + ) -> VoiceProfile: + profile = VoiceProfile( + name=name, + ref_audio=ref_audio, + ref_text=ref_text, + instruct=instruct, + description=description or f"Generated from: {text[:50]}...", + ) + self.add(profile) + return profile + + def generate( + self, + profile_name: str, + text: str, + num_steps: Optional[int] = None, + speed: Optional[float] = None, + ) -> List: + profile = self._profiles.get(profile_name) + if not profile: + raise ValueError(f"Profile '{profile_name}' not found") + + if profile.mode == "clone": + return api.clone( + text=text, + ref_audio=profile.ref_audio, + ref_text=profile.ref_text, + num_steps=num_steps, + speed=speed, + ) + elif profile.mode == "design": + return api.design( + text=text, + instruct=profile.instruct, + num_steps=num_steps, + speed=speed, + ) + else: + return api.auto( + text=text, + num_steps=num_steps, + speed=speed, + ) + + +profiles = VoiceProfiles() + + +def get_profiles() -> VoiceProfiles: + return profiles + + +def list_profiles() -> List[str]: + return profiles.list() + + +def get_profile(name: str) -> Optional[VoiceProfile]: + return profiles.get(name) + + +def add_profile(profile: VoiceProfile) -> None: + profiles.add(profile) + + +def remove_profile(name: str) -> bool: + return profiles.remove(name) diff --git a/tests/test_omnivoice.py b/tests/test_omnivoice.py new file mode 100644 index 0000000..ce7ce84 --- /dev/null +++ b/tests/test_omnivoice.py @@ -0,0 +1,100 @@ +import unittest +import os +import sys +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.parent)) + +import src.utils.config_manager as config + + +class TestOmniVoiceConfig(unittest.TestCase): + def setUp(self): + self.config = config.config + self.config.reset() + + def test_register_params(self): + from modules.omnivoice.config import register_params + from modules.omnivoice.config import ( + MODEL_NAME, + DEVICE, + DTYPE, + NUM_STEPS, + SPEED, + PROFILES_DIR, + OUTPUT_DIR, + ENABLED, + DEFAULT_MODEL_NAME, + DEFAULT_DEVICE, + ) + + register_params() + + self.assertEqual( + self.config.get(MODEL_NAME, cat="omnivoice"), + DEFAULT_MODEL_NAME, + ) + self.assertEqual( + self.config.get(DEVICE, cat="omnivoice"), + DEFAULT_DEVICE, + ) + + +class TestOmniVoiceAPI(unittest.TestCase): + def test_api_import(self): + from modules.omnivoice.api import VoiceAPI + + api = VoiceAPI() + self.assertIsNotNone(api) + + +class TestOmniVoiceProfiles(unittest.TestCase): + def setUp(self): + import tempfile + + self.temp_dir = tempfile.mkdtemp() + from modules.omnivoice.profiles import VoiceProfiles + + self.profiles = VoiceProfiles(self.temp_dir) + + def test_profile_create(self): + from modules.omnivoice.profiles import VoiceProfile + + profile = VoiceProfile( + name="test", + ref_audio="ref.wav", + ref_text="test text", + description="test profile", + ) + self.assertEqual(profile.name, "test") + self.assertEqual(profile.mode, "clone") + + def test_profile_instruct_mode(self): + from modules.omnivoice.profiles import VoiceProfile + + profile = VoiceProfile( + name="test_design", + instruct="female, british", + description="test design", + ) + self.assertEqual(profile.mode, "design") + + def test_profile_auto_mode(self): + from modules.omnivoice.profiles import VoiceProfile + + profile = VoiceProfile( + name="test_auto", + description="test auto", + ) + self.assertEqual(profile.mode, "auto") + + +class TestOmniVoiceCLI(unittest.TestCase): + def test_cli_import(self): + from modules.omnivoice import cli + + self.assertTrue(hasattr(cli, "main")) + + +if __name__ == "__main__": + unittest.main()