Add OmniVoice TTS module with config, API, profiles and CLI
- Create modules/omnivoice/ with VoiceAPI, VoiceProfiles, CLI - Add config manager integration with local model support - Add app/komAI.py entry point - Add tests/test_omnivoice.py - Clone OmniVoice to external/ for development - Add omnivoice config to global.yaml
This commit is contained in:
16
CHECKLIST.md
16
CHECKLIST.md
@@ -25,7 +25,19 @@
|
||||
## Pending
|
||||
|
||||
- [ ] stdout/stderr перехват при аварийном завершении (on crash)
|
||||
- [ ] Создать приложение `app/komAI.py` как точку входа
|
||||
- [ ] Реализовать систему модулей (`modules/`)
|
||||
- [ ] Настроить CI/CD
|
||||
- [ ] Написать интеграционные тесты
|
||||
|
||||
## OmniVoice Integration
|
||||
|
||||
- [x] Создан `modules/omnivoice/`
|
||||
- `__init__.py`: register_config, register_logging, get_api, get_profiles
|
||||
- `config.py`: model_name, device, dtype, num_steps, speed, profiles_dir, output_dir
|
||||
- `api.py`: VoiceAPI: clone(), design(), auto(), generate(), save_audio()
|
||||
- `profiles.py`: VoiceProfiles: add, remove, list, generate
|
||||
- `cli.py`: clone, design, auto, profiles, profile-add, profile-remove, profile-use
|
||||
- [x] Создан `app/komAI.py` - точка входа
|
||||
- [x] Создан `tests/test_omnivoice.py` (6/6 тестов проходят)
|
||||
- [x] Клонирован OmniVoice в `external/OmniVoice`
|
||||
- [x] Создан `venv` с зависимостями
|
||||
- [x] Документация: `modules/omnivoice/README.md`
|
||||
90
app/komAI.py
Normal file
90
app/komAI.py
Normal file
@@ -0,0 +1,90 @@
|
||||
import sys
|
||||
import argparse
|
||||
|
||||
import src.utils.config_manager as config
|
||||
import src.utils.log_manager as log
|
||||
|
||||
from modules.omnivoice import register_config, register_logging
|
||||
|
||||
|
||||
def init():
|
||||
cfg = config.config
|
||||
|
||||
try:
|
||||
register_config()
|
||||
except Exception as e:
|
||||
print(f"Failed to register omnivoice config: {e}")
|
||||
|
||||
cfg.load()
|
||||
|
||||
log.register_global_params()
|
||||
log.register(module="app", log_console=True, log_file="app.log")
|
||||
log.setup()
|
||||
|
||||
logger = log.get_logger("app")
|
||||
logger.print("komAI initialized")
|
||||
|
||||
try:
|
||||
register_logging()
|
||||
except Exception as e:
|
||||
logger.print(f"Failed to register omnivoice logging: {e}", level="warning")
|
||||
|
||||
return logger
|
||||
|
||||
|
||||
def init_cli():
|
||||
import src.utils.config_manager as config
|
||||
|
||||
cfg = config.config
|
||||
|
||||
try:
|
||||
from modules.omnivoice import register_config
|
||||
|
||||
register_config()
|
||||
except Exception as e:
|
||||
print(f"Failed to register omnivoice config: {e}")
|
||||
|
||||
cfg.load()
|
||||
|
||||
|
||||
def run_cli(args=None):
|
||||
from modules.omnivoice.cli import main as cli_main
|
||||
|
||||
return cli_main(args)
|
||||
|
||||
|
||||
def main(argv=None):
|
||||
argv = argv or sys.argv[1:]
|
||||
|
||||
parser = argparse.ArgumentParser(prog="komAI")
|
||||
subparsers = parser.add_subparsers()
|
||||
|
||||
p_voice = subparsers.add_parser("voice", help="OmniVoice TTS")
|
||||
p_voice.add_argument("command", help="clone|design|auto")
|
||||
p_voice.add_argument("args", nargs=argparse.REMAINDER)
|
||||
|
||||
p_shell = subparsers.add_parser("shell", help="Interactive shell")
|
||||
|
||||
args = parser.parse_args(argv)
|
||||
|
||||
if hasattr(args, "command") and args.command:
|
||||
cmd_args = [args.command] + args.args
|
||||
return run_cli(cmd_args)
|
||||
|
||||
if args.command == "shell" or (not argv):
|
||||
init()
|
||||
print("komAI shell. Use Ctrl+C to exit.")
|
||||
try:
|
||||
import code
|
||||
|
||||
code.interact(local=locals())
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
return 0
|
||||
|
||||
parser.print_help()
|
||||
return 1
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
@@ -9,3 +9,24 @@ logging.log_file: "app.log"
|
||||
#Путь к директории логов
|
||||
#logging.log_path: "./log"
|
||||
logging.log_path: "./log"
|
||||
|
||||
# OmniVoice TTS
|
||||
# Model name on HuggingFace
|
||||
omnivoice.model_name: "k2-fsa/OmniVoice"
|
||||
# Local model path (if exists, uses local model)
|
||||
omnivoice.model_path: "models/OmniVoice"
|
||||
# Device for inference
|
||||
omnivoice.device: "cpu"
|
||||
#omnivoice.device: "cuda:0"
|
||||
# Data type
|
||||
#omnivoice.dtype: "float16"
|
||||
# Diffusion steps
|
||||
#omnivoice.num_steps: 32
|
||||
# Speed factor
|
||||
#omnivoice.speed: 1.0
|
||||
# Voice profiles directory
|
||||
omnivoice.profiles_dir: "data/voice_profiles"
|
||||
# Output directory for audio
|
||||
omnivoice.output_dir: "output/voice"
|
||||
# Enable omnivoice module
|
||||
omnivoice.enabled: true
|
||||
|
||||
1
external/OmniVoice
vendored
Submodule
1
external/OmniVoice
vendored
Submodule
Submodule external/OmniVoice added at 6a3f23df5b
93
modules/omnivoice/README.md
Normal file
93
modules/omnivoice/README.md
Normal file
@@ -0,0 +1,93 @@
|
||||
# OmniVoice Module
|
||||
|
||||
TTS (Text-to-Speech) модуль на базе OmniVoice.
|
||||
|
||||
## VENV
|
||||
|
||||
Все зависимости в venv:
|
||||
```bash
|
||||
venv\Scripts\python.exe -m pip install torch==2.8.0+cpu torchaudio==2.8.0+cpu --index-url https://download.pytorch.org/whl/cpu
|
||||
venv\Scripts\python.exe -m pip install soundfile PyYAML transformers accelerate pydub gradio tensorboardX webdataset librosa
|
||||
venv\Scripts\python.exe -m pip install -e external/OmniVoice
|
||||
```
|
||||
|
||||
## Использование
|
||||
|
||||
### Python API
|
||||
|
||||
```python
|
||||
# В venv
|
||||
venv\Scripts\python.exe
|
||||
|
||||
# Инициализация модуля
|
||||
from modules.omnivoice import register_config, register_logging
|
||||
register_config()
|
||||
register_logging()
|
||||
|
||||
from modules.omnivoice import get_api, get_profiles
|
||||
|
||||
api = get_api()
|
||||
profiles = get_profiles()
|
||||
|
||||
# Voice Cloning
|
||||
audio = api.clone(
|
||||
text="Hello, this is a test.",
|
||||
ref_audio="ref.wav",
|
||||
ref_text="Reference transcription."
|
||||
)
|
||||
|
||||
# Voice Design
|
||||
audio = api.design(
|
||||
text="Hello, this is a test.",
|
||||
instruct="female, british accent"
|
||||
)
|
||||
|
||||
# Auto Voice
|
||||
audio = api.auto(text="Hello, this is a test.")
|
||||
|
||||
# Сохранение
|
||||
path = api.save_audio(audio[0], "output.wav")
|
||||
|
||||
# Профили
|
||||
profiles.save_from_generated("my_voice", "Hello", ref_audio="ref.wav")
|
||||
audio = profiles.generate("my_voice", "Generated text")
|
||||
```
|
||||
|
||||
### CLI
|
||||
|
||||
```bash
|
||||
# Скачать модель локально
|
||||
venv\Scripts\python.exe -m modules.omnivoice.cli download
|
||||
|
||||
# Voice Cloning (без --output = воспроизвести)
|
||||
venv\Scripts\python.exe -m modules.omnivoice.cli clone --text "Hello" --ref-audio ref.wav
|
||||
venv\Scripts\python.exe -m modules.omnivoice.cli clone --text "Hello" --ref-audio ref.wav --output hello.wav --profile my_voice
|
||||
|
||||
# Voice Design
|
||||
venv\Scripts\python.exe -m modules.omnivoice.cli design --text "Hello" --instruct "female, british accent"
|
||||
|
||||
# Auto Voice
|
||||
venv\Scripts\python.exe -m modules.omnivoice.cli auto --text "Hello"
|
||||
|
||||
# Профили
|
||||
venv\Scripts\python.exe -m modules.omnivoice.cli profiles
|
||||
venv\Scripts\python.exe -m modules.omnivoice.cli profile-add --name my_voice --ref-audio ref.wav
|
||||
venv\Scripts\python.exe -m modules.omnivoice.cli profile-remove --name my_voice
|
||||
venv\Scripts\python.exe -m modules.omnivoice.cli profile-use --profile my_voice --text "Hello"
|
||||
```
|
||||
|
||||
## Конфигурация
|
||||
|
||||
Параметры в `config/global.yaml`:
|
||||
|
||||
```yaml
|
||||
# omnivoice
|
||||
omnivoice.model_name: "k2-fsa/OmniVoice"
|
||||
omnivoice.device: "cuda:0"
|
||||
omnivoice.dtype: "float16"
|
||||
omnivoice.num_steps: 32
|
||||
omnivoice.speed: 1.0
|
||||
omnivoice.profiles_dir: "data/voice_profiles"
|
||||
omnivoice.output_dir: "output/voice"
|
||||
omnivoice.enabled: true
|
||||
```
|
||||
40
modules/omnivoice/__init__.py
Normal file
40
modules/omnivoice/__init__.py
Normal file
@@ -0,0 +1,40 @@
|
||||
import src.utils.config_manager as config
|
||||
import src.utils.log_manager as log
|
||||
|
||||
config = config.config
|
||||
logger = None
|
||||
|
||||
VOICE_CATEGORY = "voice"
|
||||
|
||||
|
||||
def register_config():
|
||||
from .config import register_params
|
||||
|
||||
register_params()
|
||||
|
||||
|
||||
def register_logging():
|
||||
global logger
|
||||
log.register(module=VOICE_CATEGORY, log_console=True, log_file="omnivoice.log")
|
||||
logger = log.get_logger(VOICE_CATEGORY)
|
||||
|
||||
|
||||
def get_api():
|
||||
from .api import VoiceAPI
|
||||
|
||||
return VoiceAPI()
|
||||
|
||||
|
||||
def get_profiles():
|
||||
from .profiles import VoiceProfiles
|
||||
|
||||
return VoiceProfiles()
|
||||
|
||||
|
||||
__all__ = [
|
||||
"register_config",
|
||||
"register_logging",
|
||||
"get_api",
|
||||
"get_profiles",
|
||||
"VOICE_CATEGORY",
|
||||
]
|
||||
224
modules/omnivoice/api.py
Normal file
224
modules/omnivoice/api.py
Normal file
@@ -0,0 +1,224 @@
|
||||
from pathlib import Path
|
||||
from typing import Optional, List
|
||||
import numpy as np
|
||||
|
||||
import src.utils.config_manager as config_mgr
|
||||
from .config import (
|
||||
MODEL_NAME,
|
||||
MODEL_PATH,
|
||||
DEVICE,
|
||||
DTYPE,
|
||||
NUM_STEPS,
|
||||
SPEED,
|
||||
OUTPUT_DIR,
|
||||
DEFAULT_MODEL_NAME,
|
||||
DEFAULT_MODEL_PATH,
|
||||
DEFAULT_DEVICE,
|
||||
DEFAULT_DTYPE,
|
||||
DEFAULT_NUM_STEPS,
|
||||
DEFAULT_SPEED,
|
||||
DEFAULT_OUTPUT_DIR,
|
||||
)
|
||||
|
||||
config = config_mgr.config
|
||||
|
||||
|
||||
def _cfg(key: str, default):
|
||||
val = config.get(key, cat="omnivoice")
|
||||
return val if val is not None else default
|
||||
|
||||
|
||||
class VoiceAPI:
|
||||
def __init__(self):
|
||||
self._model = None
|
||||
self._config = None
|
||||
|
||||
@property
|
||||
def _get_config(self):
|
||||
if self._config is None:
|
||||
model_path = _cfg(MODEL_PATH, DEFAULT_MODEL_PATH)
|
||||
if model_path and Path(model_path).exists():
|
||||
actual_path = model_path
|
||||
else:
|
||||
actual_path = None
|
||||
self._config = {
|
||||
"model_name": _cfg(MODEL_NAME, DEFAULT_MODEL_NAME),
|
||||
"model_path": actual_path,
|
||||
"device": _cfg(DEVICE, DEFAULT_DEVICE),
|
||||
"dtype": _cfg(DTYPE, DEFAULT_DTYPE),
|
||||
"num_steps": _cfg(NUM_STEPS, DEFAULT_NUM_STEPS),
|
||||
"speed": _cfg(SPEED, DEFAULT_SPEED),
|
||||
"output_dir": _cfg(OUTPUT_DIR, DEFAULT_OUTPUT_DIR),
|
||||
}
|
||||
return self._config
|
||||
|
||||
def _load_model(self):
|
||||
if self._model is not None:
|
||||
return self._model
|
||||
try:
|
||||
import torch
|
||||
from omnivoice import OmniVoice
|
||||
except ImportError:
|
||||
raise ImportError("omnivoice not installed. Run: pip install omnivoice")
|
||||
|
||||
cfg = self._get_config
|
||||
dtype = getattr(torch, cfg["dtype"], torch.float16)
|
||||
|
||||
if cfg["model_path"]:
|
||||
self._model = OmniVoice.from_pretrained(
|
||||
cfg["model_path"],
|
||||
device_map=cfg["device"],
|
||||
dtype=dtype,
|
||||
)
|
||||
else:
|
||||
self._model = OmniVoice.from_pretrained(
|
||||
cfg["model_name"],
|
||||
device_map=cfg["device"],
|
||||
dtype=dtype,
|
||||
)
|
||||
return self._model
|
||||
|
||||
def generate(
|
||||
self,
|
||||
text: str,
|
||||
ref_audio: Optional[str] = None,
|
||||
ref_text: Optional[str] = None,
|
||||
instruct: Optional[str] = None,
|
||||
num_steps: Optional[int] = None,
|
||||
speed: Optional[float] = None,
|
||||
duration: Optional[float] = None,
|
||||
) -> List[np.ndarray]:
|
||||
model = self._load_model()
|
||||
|
||||
kwargs = {"text": text}
|
||||
if ref_audio:
|
||||
kwargs["ref_audio"] = ref_audio
|
||||
if ref_text:
|
||||
kwargs["ref_text"] = ref_text
|
||||
if instruct:
|
||||
kwargs["instruct"] = instruct
|
||||
|
||||
cfg = self._get_config
|
||||
kwargs["num_step"] = num_steps or cfg["num_steps"] or 32
|
||||
kwargs["speed"] = speed or cfg["speed"] or 1.0
|
||||
if duration:
|
||||
kwargs["duration"] = duration
|
||||
|
||||
audio = model.generate(**kwargs)
|
||||
return audio
|
||||
|
||||
def clone(
|
||||
self,
|
||||
text: str,
|
||||
ref_audio: str,
|
||||
ref_text: Optional[str] = None,
|
||||
num_steps: Optional[int] = None,
|
||||
speed: Optional[float] = None,
|
||||
) -> List[np.ndarray]:
|
||||
return self.generate(
|
||||
text=text,
|
||||
ref_audio=ref_audio,
|
||||
ref_text=ref_text,
|
||||
num_steps=num_steps,
|
||||
speed=speed,
|
||||
)
|
||||
|
||||
def design(
|
||||
self,
|
||||
text: str,
|
||||
instruct: str,
|
||||
num_steps: Optional[int] = None,
|
||||
speed: Optional[float] = None,
|
||||
) -> List[np.ndarray]:
|
||||
return self.generate(
|
||||
text=text,
|
||||
instruct=instruct,
|
||||
num_steps=num_steps,
|
||||
speed=speed,
|
||||
)
|
||||
|
||||
def auto(
|
||||
self,
|
||||
text: str,
|
||||
num_steps: Optional[int] = None,
|
||||
speed: Optional[float] = None,
|
||||
) -> List[np.ndarray]:
|
||||
return self.generate(
|
||||
text=text,
|
||||
num_steps=num_steps,
|
||||
speed=speed,
|
||||
)
|
||||
|
||||
def save_audio(
|
||||
self,
|
||||
audio: np.ndarray,
|
||||
filename: Optional[str] = None,
|
||||
sample_rate: int = 24000,
|
||||
) -> Optional[Path]:
|
||||
if filename is None:
|
||||
return self.play_audio(audio, sample_rate)
|
||||
|
||||
import soundfile as sf
|
||||
|
||||
cfg = self._get_config
|
||||
output_dir = Path(cfg["output_dir"] or "output/voice")
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
path = output_dir / filename
|
||||
sf.write(str(path), audio, sample_rate)
|
||||
return path
|
||||
|
||||
def play_audio(self, audio: np.ndarray, sample_rate: int = 24000) -> None:
|
||||
try:
|
||||
import sounddevice as sd
|
||||
|
||||
sd.play(audio, sample_rate)
|
||||
sd.wait()
|
||||
except ImportError:
|
||||
raise ImportError("sounddevice not installed. Run: pip install sounddevice")
|
||||
|
||||
def speak(
|
||||
self,
|
||||
text: str,
|
||||
ref_audio: Optional[str] = None,
|
||||
ref_text: Optional[str] = None,
|
||||
instruct: Optional[str] = None,
|
||||
num_steps: Optional[int] = None,
|
||||
speed: Optional[float] = None,
|
||||
) -> None:
|
||||
audio = self.generate(
|
||||
text=text,
|
||||
ref_audio=ref_audio,
|
||||
ref_text=ref_text,
|
||||
instruct=instruct,
|
||||
num_steps=num_steps,
|
||||
speed=speed,
|
||||
)
|
||||
self.play_audio(audio[0])
|
||||
|
||||
def reload(self):
|
||||
self._model = None
|
||||
self._config = None
|
||||
|
||||
|
||||
api = VoiceAPI()
|
||||
|
||||
|
||||
def get_api() -> VoiceAPI:
|
||||
return api
|
||||
|
||||
|
||||
def generate(*args, **kwargs) -> List[np.ndarray]:
|
||||
return api.generate(*args, **kwargs)
|
||||
|
||||
|
||||
def clone(*args, **kwargs) -> List[np.ndarray]:
|
||||
return api.clone(*args, **kwargs)
|
||||
|
||||
|
||||
def design(*args, **kwargs) -> List[np.ndarray]:
|
||||
return api.design(*args, **kwargs)
|
||||
|
||||
|
||||
def auto(*args, **kwargs) -> List[np.ndarray]:
|
||||
return api.auto(*args, **kwargs)
|
||||
245
modules/omnivoice/cli.py
Normal file
245
modules/omnivoice/cli.py
Normal file
@@ -0,0 +1,245 @@
|
||||
import argparse
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def cmd_clone(args):
|
||||
from .api import api
|
||||
from .profiles import profiles
|
||||
|
||||
audio = api.clone(
|
||||
text=args.text,
|
||||
ref_audio=args.ref_audio,
|
||||
ref_text=args.ref_text,
|
||||
num_steps=args.steps,
|
||||
speed=args.speed,
|
||||
)
|
||||
path = api.save_audio(audio[0], args.output)
|
||||
if path:
|
||||
print(f"Saved: {path}")
|
||||
else:
|
||||
print("Playing...")
|
||||
if args.profile:
|
||||
profiles.save_from_generated(
|
||||
name=args.profile,
|
||||
text=args.text,
|
||||
ref_audio=args.ref_audio,
|
||||
ref_text=args.ref_text,
|
||||
)
|
||||
print(f"Profile saved: {args.profile}")
|
||||
return 0
|
||||
|
||||
|
||||
def cmd_design(args):
|
||||
from .api import api
|
||||
from .profiles import profiles
|
||||
|
||||
audio = api.design(
|
||||
text=args.text,
|
||||
instruct=args.instruct,
|
||||
num_steps=args.steps,
|
||||
speed=args.speed,
|
||||
)
|
||||
path = api.save_audio(audio[0], args.output)
|
||||
if path:
|
||||
print(f"Saved: {path}")
|
||||
else:
|
||||
print("Playing...")
|
||||
if args.profile:
|
||||
profiles.save_from_generated(
|
||||
name=args.profile,
|
||||
text=args.text,
|
||||
instruct=args.instruct,
|
||||
)
|
||||
print(f"Profile saved: {args.profile}")
|
||||
return 0
|
||||
|
||||
|
||||
def cmd_auto(args):
|
||||
from .api import api
|
||||
|
||||
audio = api.auto(
|
||||
text=args.text,
|
||||
num_steps=args.steps,
|
||||
speed=args.speed,
|
||||
)
|
||||
path = api.save_audio(audio[0], args.output)
|
||||
if path:
|
||||
print(f"Saved: {path}")
|
||||
else:
|
||||
print("Playing...")
|
||||
return 0
|
||||
|
||||
|
||||
def cmd_profile_list(args):
|
||||
from .profiles import profiles
|
||||
|
||||
for name in profiles.list():
|
||||
p = profiles.get(name)
|
||||
print(f"{name} [{p.mode}]: {p.description}")
|
||||
return 0
|
||||
|
||||
|
||||
def cmd_profile_add(args):
|
||||
from .profiles import profiles, VoiceProfile
|
||||
|
||||
profile = VoiceProfile(
|
||||
name=args.name,
|
||||
ref_audio=args.ref_audio,
|
||||
ref_text=args.ref_text,
|
||||
instruct=args.instruct,
|
||||
description=args.description or "",
|
||||
)
|
||||
profiles.add(profile)
|
||||
print(f"Profile added: {args.name}")
|
||||
return 0
|
||||
|
||||
|
||||
def cmd_profile_remove(args):
|
||||
from .profiles import profiles
|
||||
|
||||
if profiles.remove(args.name):
|
||||
print(f"Profile removed: {args.name}")
|
||||
return 0
|
||||
print(f"Profile not found: {args.name}", file=sys.stderr)
|
||||
return 1
|
||||
|
||||
|
||||
def cmd_profile_use(args):
|
||||
from .api import api
|
||||
from .profiles import profiles
|
||||
|
||||
audio = profiles.generate(args.profile, args.text, args.steps, args.speed)
|
||||
path = api.save_audio(audio[0], args.output)
|
||||
if path:
|
||||
print(f"Saved: {path}")
|
||||
else:
|
||||
print("Playing...")
|
||||
return 0
|
||||
|
||||
|
||||
def cmd_download(args):
|
||||
from huggingface_hub import snapshot_download
|
||||
import src.utils.config_manager as config
|
||||
from modules.omnivoice.config import (
|
||||
MODEL_PATH,
|
||||
DEFAULT_MODEL_PATH,
|
||||
MODEL_NAME,
|
||||
DEFAULT_MODEL_NAME,
|
||||
)
|
||||
|
||||
model_name = (
|
||||
args.model
|
||||
or config.config.get(MODEL_NAME, cat="omnivoice")
|
||||
or DEFAULT_MODEL_NAME
|
||||
)
|
||||
|
||||
if args.path:
|
||||
model_path = Path(args.path)
|
||||
else:
|
||||
model_path_str = (
|
||||
config.config.get(MODEL_PATH, cat="omnivoice") or DEFAULT_MODEL_PATH
|
||||
)
|
||||
model_path = Path(model_path_str)
|
||||
|
||||
if model_path.exists() and any(model_path.iterdir()):
|
||||
print(f"Model already exists: {model_path}")
|
||||
print("Delete folder to re-download")
|
||||
return 0
|
||||
|
||||
print(f"Downloading model {model_name} to {model_path}...")
|
||||
model_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
snapshot_download(
|
||||
repo_id=model_name,
|
||||
local_dir=str(model_path),
|
||||
ignore_patterns=["*.pt", "*.bin", "*.pth"],
|
||||
)
|
||||
|
||||
print(f"Model saved to: {model_path}")
|
||||
return 0
|
||||
|
||||
|
||||
def main(argv=None):
|
||||
argv = argv or sys.argv[1:]
|
||||
|
||||
import src.utils.config_manager as config
|
||||
from modules.omnivoice import register_config
|
||||
|
||||
try:
|
||||
register_config()
|
||||
except Exception as e:
|
||||
print(f"Failed to register omnivoice config: {e}")
|
||||
|
||||
config.config.load()
|
||||
|
||||
parser = argparse.ArgumentParser(prog="komai-voice")
|
||||
subparsers = parser.add_subparsers()
|
||||
|
||||
p_download = subparsers.add_parser("download", help="Download model locally")
|
||||
p_download.add_argument(
|
||||
"--model", default="k2-fsa/OmniVoice", help="Model name on HuggingFace"
|
||||
)
|
||||
p_download.add_argument("--path", help="Local path to save model")
|
||||
p_download.set_defaults(func=cmd_download)
|
||||
|
||||
p_clone = subparsers.add_parser("clone", help="Voice cloning")
|
||||
p_clone.add_argument("--text", required=True)
|
||||
p_clone.add_argument("--ref-audio", required=True)
|
||||
p_clone.add_argument("--ref-text")
|
||||
p_clone.add_argument("--output")
|
||||
p_clone.add_argument("--steps", type=int)
|
||||
p_clone.add_argument("--speed", type=float)
|
||||
p_clone.add_argument("--profile", help="Save as profile")
|
||||
p_clone.set_defaults(func=cmd_clone)
|
||||
|
||||
p_design = subparsers.add_parser("design", help="Voice design")
|
||||
p_design.add_argument("--text", required=True)
|
||||
p_design.add_argument("--instruct", required=True)
|
||||
p_design.add_argument("--output")
|
||||
p_design.add_argument("--steps", type=int)
|
||||
p_design.add_argument("--speed", type=float)
|
||||
p_design.add_argument("--profile", help="Save as profile")
|
||||
p_design.set_defaults(func=cmd_design)
|
||||
|
||||
p_auto = subparsers.add_parser("auto", help="Auto voice")
|
||||
p_auto.add_argument("--text", required=True)
|
||||
p_auto.add_argument("--output")
|
||||
p_auto.add_argument("--steps", type=int)
|
||||
p_auto.add_argument("--speed", type=float)
|
||||
p_auto.set_defaults(func=cmd_auto)
|
||||
|
||||
p_list = subparsers.add_parser("profiles", help="List profiles")
|
||||
p_list.set_defaults(func=cmd_profile_list)
|
||||
|
||||
p_add = subparsers.add_parser("profile-add", help="Add profile")
|
||||
p_add.add_argument("--name", required=True)
|
||||
p_add.add_argument("--ref-audio")
|
||||
p_add.add_argument("--ref-text")
|
||||
p_add.add_argument("--instruct")
|
||||
p_add.add_argument("--description")
|
||||
p_add.set_defaults(func=cmd_profile_add)
|
||||
|
||||
p_rm = subparsers.add_parser("profile-remove", help="Remove profile")
|
||||
p_rm.add_argument("--name", required=True)
|
||||
p_rm.set_defaults(func=cmd_profile_remove)
|
||||
|
||||
p_use = subparsers.add_parser("profile-use", help="Generate using profile")
|
||||
p_use.add_argument("--profile", required=True)
|
||||
p_use.add_argument("--text", required=True)
|
||||
p_use.add_argument("--output")
|
||||
p_use.add_argument("--steps", type=int)
|
||||
p_use.add_argument("--speed", type=float)
|
||||
p_use.set_defaults(func=cmd_profile_use)
|
||||
|
||||
args = parser.parse_args(argv)
|
||||
|
||||
if not hasattr(args, "func"):
|
||||
parser.print_help()
|
||||
return 1
|
||||
|
||||
return args.func(args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
64
modules/omnivoice/config.py
Normal file
64
modules/omnivoice/config.py
Normal file
@@ -0,0 +1,64 @@
|
||||
import src.utils.config_manager as config
|
||||
|
||||
config = config.config
|
||||
|
||||
MODEL_NAME = "model_name"
|
||||
MODEL_PATH = "model_path"
|
||||
DEVICE = "device"
|
||||
DTYPE = "dtype"
|
||||
NUM_STEPS = "num_steps"
|
||||
SPEED = "speed"
|
||||
PROFILES_DIR = "profiles_dir"
|
||||
OUTPUT_DIR = "output_dir"
|
||||
ENABLED = "enabled"
|
||||
|
||||
DEFAULT_MODEL_NAME = "k2-fsa/OmniVoice"
|
||||
DEFAULT_MODEL_PATH = "models/OmniVoice"
|
||||
DEFAULT_DEVICE = "cuda:0"
|
||||
DEFAULT_DTYPE = "float16"
|
||||
DEFAULT_NUM_STEPS = 32
|
||||
DEFAULT_SPEED = 1.0
|
||||
DEFAULT_PROFILES_DIR = "data/voice_profiles"
|
||||
DEFAULT_OUTPUT_DIR = "output/voice"
|
||||
DEFAULT_ENABLED = True
|
||||
|
||||
|
||||
def register_params():
|
||||
config.register(
|
||||
name=MODEL_NAME,
|
||||
val=DEFAULT_MODEL_NAME,
|
||||
cat="omnivoice",
|
||||
desc="Model name (HuggingFace) or use model_path for local",
|
||||
)
|
||||
config.register(
|
||||
name=MODEL_PATH,
|
||||
val=DEFAULT_MODEL_PATH,
|
||||
cat="omnivoice",
|
||||
desc="Local model path (if set, overrides model_name)",
|
||||
)
|
||||
config.register(
|
||||
name=DEVICE, val=DEFAULT_DEVICE, cat="omnivoice", desc="Device for inference"
|
||||
)
|
||||
config.register(name=DTYPE, val=DEFAULT_DTYPE, cat="omnivoice", desc="Data type")
|
||||
config.register(
|
||||
name=NUM_STEPS, val=DEFAULT_NUM_STEPS, cat="omnivoice", desc="Diffusion steps"
|
||||
)
|
||||
config.register(name=SPEED, val=DEFAULT_SPEED, cat="omnivoice", desc="Speed factor")
|
||||
config.register(
|
||||
name=PROFILES_DIR,
|
||||
val=DEFAULT_PROFILES_DIR,
|
||||
cat="omnivoice",
|
||||
desc="Voice profiles directory",
|
||||
)
|
||||
config.register(
|
||||
name=OUTPUT_DIR,
|
||||
val=DEFAULT_OUTPUT_DIR,
|
||||
cat="omnivoice",
|
||||
desc="Output directory for audio",
|
||||
)
|
||||
config.register(
|
||||
name=ENABLED,
|
||||
val=DEFAULT_ENABLED,
|
||||
cat="omnivoice",
|
||||
desc="Enable omnivoice module",
|
||||
)
|
||||
171
modules/omnivoice/profiles.py
Normal file
171
modules/omnivoice/profiles.py
Normal file
@@ -0,0 +1,171 @@
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Optional, Dict, List
|
||||
import shutil
|
||||
|
||||
import src.utils.config_manager as config_mgr
|
||||
from .config import PROFILES_DIR, DEFAULT_PROFILES_DIR
|
||||
from .api import api
|
||||
|
||||
config = config_mgr.config
|
||||
|
||||
|
||||
class VoiceProfile:
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
ref_audio: Optional[str] = None,
|
||||
ref_text: Optional[str] = None,
|
||||
instruct: Optional[str] = None,
|
||||
description: str = "",
|
||||
):
|
||||
self.name = name
|
||||
self.ref_audio = ref_audio
|
||||
self.ref_text = ref_text
|
||||
self.instruct = instruct
|
||||
self.description = description
|
||||
|
||||
@property
|
||||
def mode(self) -> str:
|
||||
if self.ref_audio:
|
||||
return "clone"
|
||||
elif self.instruct:
|
||||
return "design"
|
||||
return "auto"
|
||||
|
||||
def to_dict(self) -> Dict:
|
||||
return {
|
||||
"name": self.name,
|
||||
"ref_audio": self.ref_audio,
|
||||
"ref_text": self.ref_text,
|
||||
"instruct": self.instruct,
|
||||
"description": self.description,
|
||||
"mode": self.mode,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict) -> "VoiceProfile":
|
||||
return cls(
|
||||
name=data["name"],
|
||||
ref_audio=data.get("ref_audio"),
|
||||
ref_text=data.get("ref_text"),
|
||||
instruct=data.get("instruct"),
|
||||
description=data.get("description", ""),
|
||||
)
|
||||
|
||||
|
||||
class VoiceProfiles:
|
||||
def __init__(self, profiles_dir: Optional[str] = None):
|
||||
self._profiles_dir = (
|
||||
profiles_dir
|
||||
or config.get(PROFILES_DIR, cat="omnivoice")
|
||||
or DEFAULT_PROFILES_DIR
|
||||
)
|
||||
self._profiles_path = Path(self._profiles_dir) / "profiles.json"
|
||||
self._profiles: Dict[str, VoiceProfile] = {}
|
||||
self._load()
|
||||
|
||||
def _load(self):
|
||||
if self._profiles_path.exists():
|
||||
with open(self._profiles_path, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
for name, profile_data in data.items():
|
||||
self._profiles[name] = VoiceProfile.from_dict(profile_data)
|
||||
|
||||
def _save(self):
|
||||
self._profiles_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
data = {name: profile.to_dict() for name, profile in self._profiles.items()}
|
||||
with open(self._profiles_path, "w", encoding="utf-8") as f:
|
||||
json.dump(data, f, indent=2, ensure_ascii=False)
|
||||
|
||||
def list(self) -> List[str]:
|
||||
return sorted(self._profiles.keys())
|
||||
|
||||
def get(self, name: str) -> Optional[VoiceProfile]:
|
||||
return self._profiles.get(name)
|
||||
|
||||
def add(self, profile: VoiceProfile) -> None:
|
||||
self._profiles[profile.name] = profile
|
||||
self._save()
|
||||
|
||||
def remove(self, name: str) -> bool:
|
||||
if name in self._profiles:
|
||||
del self._profiles[name]
|
||||
self._save()
|
||||
return True
|
||||
return False
|
||||
|
||||
def save_from_generated(
|
||||
self,
|
||||
name: str,
|
||||
text: str,
|
||||
ref_audio: Optional[str] = None,
|
||||
ref_text: Optional[str] = None,
|
||||
instruct: Optional[str] = None,
|
||||
description: str = "",
|
||||
) -> VoiceProfile:
|
||||
profile = VoiceProfile(
|
||||
name=name,
|
||||
ref_audio=ref_audio,
|
||||
ref_text=ref_text,
|
||||
instruct=instruct,
|
||||
description=description or f"Generated from: {text[:50]}...",
|
||||
)
|
||||
self.add(profile)
|
||||
return profile
|
||||
|
||||
def generate(
|
||||
self,
|
||||
profile_name: str,
|
||||
text: str,
|
||||
num_steps: Optional[int] = None,
|
||||
speed: Optional[float] = None,
|
||||
) -> List:
|
||||
profile = self._profiles.get(profile_name)
|
||||
if not profile:
|
||||
raise ValueError(f"Profile '{profile_name}' not found")
|
||||
|
||||
if profile.mode == "clone":
|
||||
return api.clone(
|
||||
text=text,
|
||||
ref_audio=profile.ref_audio,
|
||||
ref_text=profile.ref_text,
|
||||
num_steps=num_steps,
|
||||
speed=speed,
|
||||
)
|
||||
elif profile.mode == "design":
|
||||
return api.design(
|
||||
text=text,
|
||||
instruct=profile.instruct,
|
||||
num_steps=num_steps,
|
||||
speed=speed,
|
||||
)
|
||||
else:
|
||||
return api.auto(
|
||||
text=text,
|
||||
num_steps=num_steps,
|
||||
speed=speed,
|
||||
)
|
||||
|
||||
|
||||
profiles = VoiceProfiles()
|
||||
|
||||
|
||||
def get_profiles() -> VoiceProfiles:
|
||||
return profiles
|
||||
|
||||
|
||||
def list_profiles() -> List[str]:
|
||||
return profiles.list()
|
||||
|
||||
|
||||
def get_profile(name: str) -> Optional[VoiceProfile]:
|
||||
return profiles.get(name)
|
||||
|
||||
|
||||
def add_profile(profile: VoiceProfile) -> None:
|
||||
profiles.add(profile)
|
||||
|
||||
|
||||
def remove_profile(name: str) -> bool:
|
||||
return profiles.remove(name)
|
||||
100
tests/test_omnivoice.py
Normal file
100
tests/test_omnivoice.py
Normal file
@@ -0,0 +1,100 @@
|
||||
import unittest
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
import src.utils.config_manager as config
|
||||
|
||||
|
||||
class TestOmniVoiceConfig(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.config = config.config
|
||||
self.config.reset()
|
||||
|
||||
def test_register_params(self):
|
||||
from modules.omnivoice.config import register_params
|
||||
from modules.omnivoice.config import (
|
||||
MODEL_NAME,
|
||||
DEVICE,
|
||||
DTYPE,
|
||||
NUM_STEPS,
|
||||
SPEED,
|
||||
PROFILES_DIR,
|
||||
OUTPUT_DIR,
|
||||
ENABLED,
|
||||
DEFAULT_MODEL_NAME,
|
||||
DEFAULT_DEVICE,
|
||||
)
|
||||
|
||||
register_params()
|
||||
|
||||
self.assertEqual(
|
||||
self.config.get(MODEL_NAME, cat="omnivoice"),
|
||||
DEFAULT_MODEL_NAME,
|
||||
)
|
||||
self.assertEqual(
|
||||
self.config.get(DEVICE, cat="omnivoice"),
|
||||
DEFAULT_DEVICE,
|
||||
)
|
||||
|
||||
|
||||
class TestOmniVoiceAPI(unittest.TestCase):
|
||||
def test_api_import(self):
|
||||
from modules.omnivoice.api import VoiceAPI
|
||||
|
||||
api = VoiceAPI()
|
||||
self.assertIsNotNone(api)
|
||||
|
||||
|
||||
class TestOmniVoiceProfiles(unittest.TestCase):
|
||||
def setUp(self):
|
||||
import tempfile
|
||||
|
||||
self.temp_dir = tempfile.mkdtemp()
|
||||
from modules.omnivoice.profiles import VoiceProfiles
|
||||
|
||||
self.profiles = VoiceProfiles(self.temp_dir)
|
||||
|
||||
def test_profile_create(self):
|
||||
from modules.omnivoice.profiles import VoiceProfile
|
||||
|
||||
profile = VoiceProfile(
|
||||
name="test",
|
||||
ref_audio="ref.wav",
|
||||
ref_text="test text",
|
||||
description="test profile",
|
||||
)
|
||||
self.assertEqual(profile.name, "test")
|
||||
self.assertEqual(profile.mode, "clone")
|
||||
|
||||
def test_profile_instruct_mode(self):
|
||||
from modules.omnivoice.profiles import VoiceProfile
|
||||
|
||||
profile = VoiceProfile(
|
||||
name="test_design",
|
||||
instruct="female, british",
|
||||
description="test design",
|
||||
)
|
||||
self.assertEqual(profile.mode, "design")
|
||||
|
||||
def test_profile_auto_mode(self):
|
||||
from modules.omnivoice.profiles import VoiceProfile
|
||||
|
||||
profile = VoiceProfile(
|
||||
name="test_auto",
|
||||
description="test auto",
|
||||
)
|
||||
self.assertEqual(profile.mode, "auto")
|
||||
|
||||
|
||||
class TestOmniVoiceCLI(unittest.TestCase):
|
||||
def test_cli_import(self):
|
||||
from modules.omnivoice import cli
|
||||
|
||||
self.assertTrue(hasattr(cli, "main"))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user