File: ai1/services/vad/app.py — 154 lines, 6.3 KiB, Python.
import asyncio
import numpy as np
import torch
import redis.asyncio as redis
from fastapi import FastAPI, WebSocket, HTTPException
import os
import io
import wave
import requests
app = FastAPI()
# Redis connection — host/port come from the environment, defaulting to a
# local instance; decode_responses=True so pub/sub payloads are str.
redis_host = os.getenv("REDIS_HOST", "localhost")
redis_port = int(os.getenv("REDIS_PORT", 6379))
r = redis.Redis(host=redis_host, port=redis_port, decode_responses=True)
# Silero VAD model (load once) — fetched via torch.hub at import time
# (network required on first run) and shared by every connection.
print("Loading Silero VAD model...")
model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad', model='silero_vad')
# utils bundles several helpers; only get_speech_timestamps is used here.
(get_speech_timestamps, _, _, _, _) = utils
print("VAD model loaded.")
# Configuration
SAMPLE_RATE = 16000  # Hz; clients are expected to send 16 kHz mono int16 PCM
MIN_SPEECH_DURATION = 0.5 # seconds
MIN_SILENCE_DURATION = 0.5 # seconds
BUFFER_MAX_DURATION = 30 # max buffer size in seconds
BUFFER_MAX_SAMPLES = BUFFER_MAX_DURATION * SAMPLE_RATE  # a sample count — double it for int16 bytes
# Audio buffer for current speech segment (raw int16 PCM bytes)
audio_buffer = bytearray()
# State
is_speaking = False  # True while VAD currently detects speech on the stream
last_speech_end = 0  # NOTE(review): never written anywhere in this file — appears unused
@app.get("/health")
async def health():
    """Liveness probe: report that the VAD service is up and reachable."""
    payload = {"status": "ok", "service": "vad"}
    return payload
@app.websocket("/audio-stream")
async def audio_stream(websocket: WebSocket):
    """Segment a 16 kHz mono int16 PCM websocket stream with Silero VAD.

    On speech onset an "interrupt" message is published to Redis (so any
    TTS playback can stop — barge-in). Buffered speech audio is flushed to
    the ASR service after MIN_SILENCE_DURATION of silence via
    flush_after_silence().

    Args:
        websocket: accepted client connection delivering raw PCM chunks.
    """
    global audio_buffer, is_speaking, last_speech_end
    await websocket.accept()
    print("Client connected to VAD")
    # Scratch buffer: accumulates raw bytes until a full VAD window exists.
    temp_buffer = bytearray()
    try:
        while True:
            # Receive one audio chunk (bytes, int16 little-endian, 16 kHz mono).
            chunk = await websocket.receive_bytes()
            temp_buffer.extend(chunk)
            # Run VAD only once we have at least a 500 ms window
            # (8000 samples * 2 bytes per int16 sample).
            if len(temp_buffer) < SAMPLE_RATE // 2 * 2:
                continue
            # Copy the bytes before frombuffer: a view directly over the
            # bytearray would "export" its buffer and make the later
            # temp_buffer.clear() raise BufferError.
            audio_int16 = np.frombuffer(bytes(temp_buffer), dtype=np.int16)
            # int16 PCM -> float32 in [-1, 1) as the model expects.
            audio_float32 = audio_int16.astype(np.float32) / 32768.0
            speech_ts = get_speech_timestamps(
                audio_float32,
                model,
                sampling_rate=SAMPLE_RATE,
                threshold=0.5,
                min_speech_duration_ms=int(MIN_SPEECH_DURATION * 1000),
                min_silence_duration_ms=int(MIN_SILENCE_DURATION * 1000),
            )
            if speech_ts:
                # Speech present in this window.
                if not is_speaking:
                    # Speech just started: signal the pipeline to stop TTS.
                    is_speaking = True
                    await r.publish("interrupt", "1")
                    print("Speech started, interrupt sent")
                # Keep the whole window; trimming exactly to the speech
                # timestamps is deliberately skipped for simplicity.
                audio_buffer.extend(temp_buffer)
                temp_buffer.clear()
            elif is_speaking:
                # First silent window after speech: keep the tail audio,
                # mark speech as ended, and schedule one deferred flush.
                # (Previously is_speaking was never reset, so the flush's
                # `not is_speaking` guard could never pass and segments
                # were never sent; a new task was also spawned on every
                # silent window instead of once per transition.)
                audio_buffer.extend(temp_buffer)
                temp_buffer.clear()
                is_speaking = False
                asyncio.create_task(flush_after_silence())
            else:
                # Idle silence: cap the scratch buffer at ~10 s of audio.
                # Compare in bytes (2 bytes per int16 sample) — the old
                # check compared bytes against a sample count (5 s cap).
                if len(temp_buffer) > SAMPLE_RATE * 10 * 2:
                    temp_buffer.clear()
            # Cap the main buffer at BUFFER_MAX_DURATION seconds (in bytes),
            # keeping the most recent audio.
            if len(audio_buffer) > BUFFER_MAX_SAMPLES * 2:
                audio_buffer = audio_buffer[-BUFFER_MAX_SAMPLES * 2:]
    except Exception as e:
        # Covers WebSocketDisconnect as well as unexpected errors.
        print(f"VAD connection closed: {e}")
    finally:
        # No per-connection resources to release beyond the locals above.
        pass
async def flush_after_silence():
    """After MIN_SILENCE_DURATION of continued silence, send the buffered
    speech segment to the ASR service and reset the buffer.

    If speech resumed while we slept (is_speaking became True again), the
    flush is abandoned and the buffer keeps accumulating.
    """
    global audio_buffer, is_speaking
    await asyncio.sleep(MIN_SILENCE_DURATION)
    if audio_buffer and not is_speaking:
        # Snapshot and clear BEFORE the (potentially slow) awaited HTTP
        # round-trip, so any other scheduled flush task sees an empty
        # buffer and the same segment cannot be sent twice.
        segment = bytes(audio_buffer)
        audio_buffer.clear()
        await send_to_asr(segment)
async def send_to_asr(audio_data: bytes):
    """Wrap a raw PCM segment in a WAV container and POST it to the ASR
    service, logging the transcription.

    Args:
        audio_data: mono 16 kHz little-endian int16 PCM bytes.
    """
    # Build an in-memory WAV file around the raw PCM.
    wav_io = io.BytesIO()
    with wave.open(wav_io, 'wb') as wav:
        wav.setnchannels(1)
        wav.setsampwidth(2)  # 16-bit samples
        wav.setframerate(SAMPLE_RATE)
        wav.writeframes(audio_data)
    wav_io.seek(0)
    try:
        # requests is synchronous; run it in a worker thread so the event
        # loop (and the websocket handler) is not blocked for the duration
        # of the HTTP call. Timeout prevents a hung ASR from stalling us.
        response = await asyncio.to_thread(
            requests.post,
            "http://jarvis-asr:8000/transcribe",
            files={"audio": ("audio.wav", wav_io, "audio/wav")},
            timeout=30,
        )
        if response.status_code == 200:
            text = response.json().get("text", "")
            print(f"ASR result: {text}")
            # TODO: forward the transcription to the orchestrator
            # (e.g. publish on Redis) instead of only logging it.
        else:
            # Surface failures instead of silently dropping the segment.
            print(f"ASR request failed with status {response.status_code}")
    except Exception as e:
        print(f"Error sending to ASR: {e}")