"""Voice-activity-detection (VAD) gateway.

Receives raw 16 kHz mono int16 PCM over a WebSocket, runs Silero VAD on
~500 ms windows, publishes an "interrupt" message to Redis when speech
starts (so a downstream TTS can stop talking), and forwards each completed
speech segment to the ASR service as an in-memory WAV file.
"""

import asyncio
import io
import os
import wave

import numpy as np
import redis.asyncio as redis
import requests
import torch
from fastapi import FastAPI, HTTPException, WebSocket

app = FastAPI()

# Redis connection (used to publish barge-in interrupts).
redis_host = os.getenv("REDIS_HOST", "localhost")
redis_port = int(os.getenv("REDIS_PORT", 6379))
r = redis.Redis(host=redis_host, port=redis_port, decode_responses=True)

# Silero VAD model (loaded once at startup; torch.hub downloads on first run).
print("Loading Silero VAD model...")
model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad', model='silero_vad')
(get_speech_timestamps, _, _, _, _) = utils
print("VAD model loaded.")

# Configuration
SAMPLE_RATE = 16000
BYTES_PER_SAMPLE = 2                 # int16 PCM
MIN_SPEECH_DURATION = 0.5            # seconds
MIN_SILENCE_DURATION = 0.5           # seconds
BUFFER_MAX_DURATION = 30             # max speech buffer size in seconds
BUFFER_MAX_SAMPLES = BUFFER_MAX_DURATION * SAMPLE_RATE
WINDOW_BYTES = SAMPLE_RATE // 2 * BYTES_PER_SAMPLE  # 500 ms analysis window

# NOTE(review): module-level state means this service supports exactly one
# concurrent websocket client; a second client would corrupt these buffers.
audio_buffer = bytearray()           # accumulated speech segment awaiting ASR
is_speaking = False
_flush_task = None                   # pending silence-flush task, if any


@app.get("/health")
async def health():
    """Liveness probe."""
    return {"status": "ok", "service": "vad"}


@app.websocket("/audio-stream")
async def audio_stream(websocket: WebSocket):
    """Consume a raw PCM stream and segment it by voice activity.

    Protocol: the client sends binary frames of int16 little-endian mono
    audio at 16 kHz. Audio is analysed in ~500 ms windows. A speech onset
    publishes "interrupt" to Redis; after MIN_SILENCE_DURATION of silence
    the buffered segment is flushed to the ASR service.
    """
    global audio_buffer, is_speaking, _flush_task
    await websocket.accept()
    print("Client connected to VAD")

    temp_buffer = bytearray()  # bytes awaiting the next VAD window

    try:
        while True:
            chunk = await websocket.receive_bytes()
            temp_buffer.extend(chunk)

            # Wait until a full 500 ms window has accumulated.
            if len(temp_buffer) < WINDOW_BYTES:
                continue

            # Consume only whole int16 samples; a trailing odd byte stays
            # in temp_buffer so np.frombuffer never sees a misaligned
            # buffer (the original could raise on odd-length input).
            usable = len(temp_buffer) - (len(temp_buffer) % BYTES_PER_SAMPLE)
            window = bytes(temp_buffer[:usable])
            del temp_buffer[:usable]

            audio_int16 = np.frombuffer(window, dtype=np.int16)
            audio_float32 = audio_int16.astype(np.float32) / 32768.0

            speech_ts = get_speech_timestamps(
                audio_float32,
                model,
                sampling_rate=SAMPLE_RATE,
                threshold=0.5,
                min_speech_duration_ms=int(MIN_SPEECH_DURATION * 1000),
                min_silence_duration_ms=int(MIN_SILENCE_DURATION * 1000),
            )

            if speech_ts:
                # Speech present in this window.
                if not is_speaking:
                    is_speaking = True
                    # Speech resumed: cancel a pending flush, if any.
                    if _flush_task is not None and not _flush_task.done():
                        _flush_task.cancel()
                    # Barge-in: tell the TTS side to stop speaking.
                    await r.publish("interrupt", "1")
                    print("Speech started, interrupt sent")
                audio_buffer.extend(window)
            elif is_speaking:
                # Silence right after speech: keep the tail (it may hold
                # the end of a word) and schedule a flush.
                # BUG FIX: the original never reset is_speaking, so
                # flush_after_silence()'s `not is_speaking` guard was
                # permanently false and segments were never sent to ASR.
                # It also spawned a new flush task on every silent window.
                audio_buffer.extend(window)
                is_speaking = False
                if _flush_task is None or _flush_task.done():
                    _flush_task = asyncio.create_task(flush_after_silence())
            # else: silence while idle — discard the window. This also
            # bounds memory (the original accumulated up to 10 s of
            # never-used non-speech audio).

            # Cap the speech buffer to the most recent BUFFER_MAX_DURATION.
            max_bytes = BUFFER_MAX_SAMPLES * BYTES_PER_SAMPLE
            if len(audio_buffer) > max_bytes:
                del audio_buffer[:-max_bytes]
    except Exception as e:
        # receive_bytes() raises on disconnect; treat any error as EOF.
        print(f"VAD connection closed: {e}")


async def flush_after_silence():
    """Wait out the silence window, then ship the buffered speech to ASR."""
    global audio_buffer
    try:
        await asyncio.sleep(MIN_SILENCE_DURATION)
    except asyncio.CancelledError:
        return  # speech resumed before the timeout
    if audio_buffer and not is_speaking:
        # Hand off a copy and clear BEFORE awaiting, so the stream loop
        # can keep appending new speech without racing the upload.
        segment = bytes(audio_buffer)
        audio_buffer.clear()
        await send_to_asr(segment)


async def send_to_asr(audio_data: bytes):
    """POST one speech segment to the ASR service as a mono 16-bit WAV.

    Best-effort: network failures are logged, never raised.
    """
    # Wrap the raw PCM in a WAV container, entirely in memory.
    wav_io = io.BytesIO()
    with wave.open(wav_io, 'wb') as wav:
        wav.setnchannels(1)
        wav.setsampwidth(2)   # 16-bit
        wav.setframerate(SAMPLE_RATE)
        wav.writeframes(audio_data)
    wav_io.seek(0)

    try:
        # requests is blocking; run it in a worker thread so the event
        # loop (and the websocket reader) keeps running during the upload.
        loop = asyncio.get_running_loop()
        response = await loop.run_in_executor(
            None,
            lambda: requests.post(
                "http://jarvis-asr:8000/transcribe",
                files={"audio": ("audio.wav", wav_io, "audio/wav")},
            ),
        )
        if response.status_code == 200:
            text = response.json().get("text", "")
            print(f"ASR result: {text}")
            # TODO: forward the transcript to the orchestrator / publish
            # to Redis for downstream processing.
    except Exception as e:
        print(f"Error sending to ASR: {e}")