Add OmniVoice TTS module with config, API, profiles and CLI
- Create modules/omnivoice/ with VoiceAPI, VoiceProfiles, CLI
- Add config manager integration with local model support
- Add app/komAI.py entry point
- Add tests/test_omnivoice.py
- Clone OmniVoice to external/ for development
- Add omnivoice config to global.yaml
This commit is contained in:
224
modules/omnivoice/api.py
Normal file
224
modules/omnivoice/api.py
Normal file
@@ -0,0 +1,224 @@
|
||||
from pathlib import Path
|
||||
from typing import Optional, List
|
||||
import numpy as np
|
||||
|
||||
import src.utils.config_manager as config_mgr
|
||||
from .config import (
|
||||
MODEL_NAME,
|
||||
MODEL_PATH,
|
||||
DEVICE,
|
||||
DTYPE,
|
||||
NUM_STEPS,
|
||||
SPEED,
|
||||
OUTPUT_DIR,
|
||||
DEFAULT_MODEL_NAME,
|
||||
DEFAULT_MODEL_PATH,
|
||||
DEFAULT_DEVICE,
|
||||
DEFAULT_DTYPE,
|
||||
DEFAULT_NUM_STEPS,
|
||||
DEFAULT_SPEED,
|
||||
DEFAULT_OUTPUT_DIR,
|
||||
)
|
||||
|
||||
# Shared configuration object exposed by the project's config manager module;
# _cfg() reads the "omnivoice" category from it.
config = config_mgr.config
|
||||
|
||||
|
||||
def _cfg(key: str, default):
    """Look up ``key`` in the ``omnivoice`` config category.

    Args:
        key: Name of the setting to read.
        default: Value returned when the lookup yields ``None``.

    Returns:
        The configured value, or ``default`` if the config returns ``None``.
    """
    configured = config.get(key, cat="omnivoice")
    if configured is None:
        return default
    return configured
|
||||
|
||||
|
||||
class VoiceAPI:
    """Lazy wrapper around an OmniVoice text-to-speech model.

    Both the resolved settings and the model are created on first use and
    cached on the instance. Call :meth:`reload` to discard them so the next
    call resolves the configuration and loads the model again.
    """

    def __init__(self):
        # Loaded OmniVoice model; populated by _load_model().
        self._model = None
        # Resolved settings dict; populated by the _get_config property.
        self._config = None

    @property
    def _get_config(self):
        """Resolved settings dict (a property: read it, do not call it).

        Built once from the ``omnivoice`` config category with the module
        defaults as fallbacks, then cached in ``self._config``.

        Returns:
            dict with keys ``model_name``, ``model_path``, ``device``,
            ``dtype``, ``num_steps``, ``speed`` and ``output_dir``.
        """
        if self._config is None:
            model_path = _cfg(MODEL_PATH, DEFAULT_MODEL_PATH)
            # A local path is only honoured when it exists on disk; otherwise
            # it is recorded as None and _load_model() uses the model name.
            if model_path and Path(model_path).exists():
                actual_path = model_path
            else:
                actual_path = None
            self._config = {
                "model_name": _cfg(MODEL_NAME, DEFAULT_MODEL_NAME),
                "model_path": actual_path,
                "device": _cfg(DEVICE, DEFAULT_DEVICE),
                "dtype": _cfg(DTYPE, DEFAULT_DTYPE),
                "num_steps": _cfg(NUM_STEPS, DEFAULT_NUM_STEPS),
                "speed": _cfg(SPEED, DEFAULT_SPEED),
                "output_dir": _cfg(OUTPUT_DIR, DEFAULT_OUTPUT_DIR),
            }
        return self._config

    @staticmethod
    def _resolve_dtype(torch_module, name):
        """Map a configured dtype name to an attribute of ``torch_module``.

        Args:
            torch_module: The imported ``torch`` module (passed in so torch
                stays a lazy import).
            name: Configured dtype name, e.g. ``"float16"``.

        Returns:
            ``getattr(torch_module, name)`` when it is an instance of
            ``torch_module.dtype``; otherwise ``torch_module.float16``.
        """
        fallback = torch_module.float16
        # getattr() raises TypeError for a non-string attribute name.
        if not isinstance(name, str):
            return fallback
        candidate = getattr(torch_module, name, fallback)
        # Names such as "nn" or "load" exist on torch but are not dtypes.
        if isinstance(candidate, torch_module.dtype):
            return candidate
        return fallback

    def _load_model(self):
        """Return the cached model, loading it on first call.

        Returns:
            The object returned by ``OmniVoice.from_pretrained``.

        Raises:
            ImportError: If ``torch`` or ``omnivoice`` cannot be imported.
                The original import error is attached as ``__cause__``.
        """
        if self._model is not None:
            return self._model
        try:
            import torch
            from omnivoice import OmniVoice
        except ImportError as err:
            # Chain the cause so a failing torch import (or a broken
            # dependency of omnivoice) is still visible in the traceback.
            raise ImportError(
                "omnivoice not installed. Run: pip install omnivoice"
            ) from err

        cfg = self._get_config
        dtype = self._resolve_dtype(torch, cfg["dtype"])
        # Prefer the validated local path; fall back to the model name.
        source = cfg["model_path"] or cfg["model_name"]
        self._model = OmniVoice.from_pretrained(
            source,
            device_map=cfg["device"],
            dtype=dtype,
        )
        return self._model

    def generate(
        self,
        text: str,
        ref_audio: Optional[str] = None,
        ref_text: Optional[str] = None,
        instruct: Optional[str] = None,
        num_steps: Optional[int] = None,
        speed: Optional[float] = None,
        duration: Optional[float] = None,
    ) -> List[np.ndarray]:
        """Synthesize ``text`` by calling ``generate`` on the loaded model.

        Args:
            text: Text to synthesize.
            ref_audio: Reference audio; forwarded only when truthy.
            ref_text: Reference text; forwarded only when truthy.
            instruct: Instruction string; forwarded only when truthy.
            num_steps: Forwarded as ``num_step``. A falsy value falls back to
                the configured ``num_steps``, then to 32.
            speed: A falsy value falls back to the configured ``speed``,
                then to 1.0.
            duration: Forwarded only when truthy.

        Returns:
            Whatever ``model.generate`` returns (annotated as a list of
            arrays).

        Raises:
            ImportError: Propagated from :meth:`_load_model`.
        """
        model = self._load_model()

        kwargs = {"text": text}
        if ref_audio:
            kwargs["ref_audio"] = ref_audio
        if ref_text:
            kwargs["ref_text"] = ref_text
        if instruct:
            kwargs["instruct"] = instruct

        cfg = self._get_config
        # Note the model keyword is singular: "num_step".
        kwargs["num_step"] = num_steps or cfg["num_steps"] or 32
        kwargs["speed"] = speed or cfg["speed"] or 1.0
        if duration:
            kwargs["duration"] = duration

        return model.generate(**kwargs)

    def clone(
        self,
        text: str,
        ref_audio: str,
        ref_text: Optional[str] = None,
        num_steps: Optional[int] = None,
        speed: Optional[float] = None,
    ) -> List[np.ndarray]:
        """Call :meth:`generate` with a reference audio (and optional text)."""
        return self.generate(
            text=text,
            ref_audio=ref_audio,
            ref_text=ref_text,
            num_steps=num_steps,
            speed=speed,
        )

    def design(
        self,
        text: str,
        instruct: str,
        num_steps: Optional[int] = None,
        speed: Optional[float] = None,
    ) -> List[np.ndarray]:
        """Call :meth:`generate` with an instruction string."""
        return self.generate(
            text=text,
            instruct=instruct,
            num_steps=num_steps,
            speed=speed,
        )

    def auto(
        self,
        text: str,
        num_steps: Optional[int] = None,
        speed: Optional[float] = None,
    ) -> List[np.ndarray]:
        """Call :meth:`generate` with text only (no reference, no instruction)."""
        return self.generate(
            text=text,
            num_steps=num_steps,
            speed=speed,
        )

    def save_audio(
        self,
        audio: np.ndarray,
        filename: Optional[str] = None,
        sample_rate: int = 24000,
    ) -> Optional[Path]:
        """Write ``audio`` to the output directory, or play it if unnamed.

        Args:
            audio: Samples to write.
            filename: File name inside the output directory. When ``None``
                the audio is played via :meth:`play_audio` instead.
            sample_rate: Sample rate passed to the writer/player.

        Returns:
            Path of the written file, or ``None`` when the audio was played.

        Raises:
            ImportError: If ``soundfile`` (or ``sounddevice`` for playback)
                cannot be imported.
        """
        if filename is None:
            return self.play_audio(audio, sample_rate)

        try:
            import soundfile as sf
        except ImportError as err:
            raise ImportError(
                "soundfile not installed. Run: pip install soundfile"
            ) from err

        cfg = self._get_config
        output_dir = Path(cfg["output_dir"] or "output/voice")
        output_dir.mkdir(parents=True, exist_ok=True)

        path = output_dir / filename
        sf.write(str(path), audio, sample_rate)
        return path

    def play_audio(self, audio: np.ndarray, sample_rate: int = 24000) -> None:
        """Play ``audio`` through ``sounddevice`` and wait for it to finish.

        Raises:
            ImportError: If ``sounddevice`` cannot be imported.
        """
        # Only the import is guarded, so an error raised during playback is
        # never misreported as a missing package.
        try:
            import sounddevice as sd
        except ImportError as err:
            raise ImportError(
                "sounddevice not installed. Run: pip install sounddevice"
            ) from err

        sd.play(audio, sample_rate)
        sd.wait()

    def speak(
        self,
        text: str,
        ref_audio: Optional[str] = None,
        ref_text: Optional[str] = None,
        instruct: Optional[str] = None,
        num_steps: Optional[int] = None,
        speed: Optional[float] = None,
    ) -> None:
        """Generate speech for ``text`` and play the first returned item."""
        audio = self.generate(
            text=text,
            ref_audio=ref_audio,
            ref_text=ref_text,
            instruct=instruct,
            num_steps=num_steps,
            speed=speed,
        )
        self.play_audio(audio[0])

    def reload(self):
        """Drop the cached model and settings so they are rebuilt on next use."""
        self._model = None
        self._config = None
|
||||
|
||||
|
||||
# Module-level VoiceAPI instance returned by get_api() and used by the
# module-level generate/clone/design/auto wrappers.
api = VoiceAPI()
|
||||
|
||||
|
||||
def get_api() -> VoiceAPI:
    """Return the module-level :class:`VoiceAPI` instance."""
    shared_instance = api
    return shared_instance
|
||||
|
||||
|
||||
def generate(*args, **kwargs) -> List[np.ndarray]:
    """Forward all arguments to ``generate`` on the module-level ``api``."""
    bound_generate = api.generate
    return bound_generate(*args, **kwargs)
|
||||
|
||||
|
||||
def clone(*args, **kwargs) -> List[np.ndarray]:
    """Forward all arguments to ``clone`` on the module-level ``api``."""
    bound_clone = api.clone
    return bound_clone(*args, **kwargs)
|
||||
|
||||
|
||||
def design(*args, **kwargs) -> List[np.ndarray]:
    """Forward all arguments to ``design`` on the module-level ``api``."""
    bound_design = api.design
    return bound_design(*args, **kwargs)
|
||||
|
||||
|
||||
def auto(*args, **kwargs) -> List[np.ndarray]:
    """Forward all arguments to ``auto`` on the module-level ``api``."""
    bound_auto = api.auto
    return bound_auto(*args, **kwargs)
|
||||
Reference in New Issue
Block a user