"""Parakeet ASR + Sortformer diarization + TitaNet embedding HTTP API.

Exposes OpenAI-style ``/v1/audio/*`` endpoints (transcription and a
translation alias) plus speaker-diarization endpoints. Both models are loaded
once at startup by the FastAPI lifespan handler.
"""

import asyncio
import logging
import os
import time
from contextlib import asynccontextmanager
from typing import Any, Callable, Optional, Tuple

import torch  # noqa: F401  -- kept from the original module; not referenced here
from fastapi import FastAPI, File, Form, UploadFile, HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import JSONResponse, PlainTextResponse  # noqa: F401
from fastapi.middleware.cors import CORSMiddleware

from app.transcriber import transcriber, MODEL_NAME, DEVICE
from app.diarizer import diarizer, DIARIZER_MODEL, EMBEDDING_MODEL

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
)
logger = logging.getLogger("parakeet-api")

# Model inference (ASR and diarization) is blocking work. It is run in a worker
# thread so the event loop stays responsive (e.g. /health keeps answering during
# a long job). This single lock keeps model calls one-at-a-time, which is what
# happened implicitly when they ran directly on the event loop.
_inference_lock = asyncio.Lock()


def _max_upload_bytes() -> int:
    """Return the upload size limit in bytes (env ``MAX_UPLOAD_MB``, default 200)."""
    return int(os.getenv("MAX_UPLOAD_MB", "200")) * 1024 * 1024


async def _read_upload(file: UploadFile) -> bytes:
    """Read an uploaded file into memory, enforcing the empty/size limits.

    At most ``limit + 1`` bytes are read. An oversized upload is detected
    without copying the whole upload into memory.

    Raises:
        HTTPException: 400 if the file is empty, 413 if it exceeds the limit.
    """
    max_size = _max_upload_bytes()
    audio_bytes = await file.read(max_size + 1)
    if len(audio_bytes) == 0:
        raise HTTPException(status_code=400, detail="Empty file")
    if len(audio_bytes) > max_size:
        raise HTTPException(status_code=413, detail="File too large")
    return audio_bytes


async def _run_serialized(
    fn: Callable[..., Any], *args: Any, **kwargs: Any
) -> Tuple[Any, float]:
    """Run blocking ``fn(*args, **kwargs)`` in a worker thread under ``_inference_lock``.

    Returns:
        ``(result, elapsed_seconds)``. The elapsed time covers only the call
        itself, not the time spent waiting for the lock.
    """
    async with _inference_lock:
        started = time.perf_counter()
        result = await run_in_threadpool(fn, *args, **kwargs)
        return result, time.perf_counter() - started


def _rtfx(duration: float, elapsed: float) -> float:
    """Return the real-time factor: audio seconds processed per wall-clock second."""
    return duration / elapsed if elapsed > 0 else 0


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the ASR model, then the diarizer, before the app starts serving."""
    logger.info("Loading ASR model %s on %s", MODEL_NAME, DEVICE)
    transcriber.load_model()
    logger.info("ASR model ready")
    logger.info("Loading diarizer %s on %s", DIARIZER_MODEL, DEVICE)
    diarizer.load_model()
    logger.info("Diarizer ready")
    yield


app = FastAPI(
    title="Parakeet ASR + Sortformer Diarization + TitaNet Embedding API",
    version="1.3.0",
    lifespan=lifespan,
)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.get("/")
async def root():
    """Describe the service: loaded model names, device and endpoint paths."""
    return {
        "service": "parakeet-asr",
        "model": MODEL_NAME,
        "diarizer": DIARIZER_MODEL,
        "embedding": EMBEDDING_MODEL,
        "device": DEVICE,
        "endpoints": {
            "transcribe": "/v1/audio/transcriptions",
            "diarize": "/v1/audio/diarize",
            "diarize_chunk": "/v1/audio/diarize-chunk",
            "models": "/v1/models",
            "health": "/health",
        },
    }


@app.get("/health")
async def health():
    """Report ``"ready"`` once both models are loaded, otherwise ``"loading"``."""
    return {
        "status": "ready" if (transcriber._loaded and diarizer._loaded) else "loading",
        "asr_loaded": transcriber._loaded,
        "diarizer_loaded": diarizer._loaded,
        "model": MODEL_NAME,
        "diarizer_model": DIARIZER_MODEL,
        "device": DEVICE,
    }


@app.get("/v1/models")
async def list_models():
    """OpenAI-style model list.

    ``whisper-1`` is listed as an alias id for clients that hard-code it.
    """
    return {
        "object": "list",
        "data": [
            {"id": "parakeet-tdt-0.6b-v3", "object": "model", "owned_by": "nvidia", "kind": "stt"},
            {"id": "whisper-1", "object": "model", "owned_by": "nvidia", "kind": "stt"},
            {
                "id": DIARIZER_MODEL.split("/")[-1],
                "object": "model",
                "owned_by": "nvidia",
                "kind": "diarization",
            },
        ],
    }


@app.post("/v1/audio/transcriptions")
async def transcribe(
    file: UploadFile = File(...),
    model: Optional[str] = Form(default="parakeet-tdt-0.6b-v3"),
    language: Optional[str] = Form(default=None),
    response_format: Optional[str] = Form(default="json"),
    temperature: Optional[float] = Form(default=0.0),
    prompt: Optional[str] = Form(default=None),
):
    """OpenAI-style transcription endpoint.

    ``model``, ``temperature`` and ``prompt`` are accepted for client
    compatibility but are not used. ``language`` is passed to the transcriber
    and echoed in ``verbose_json`` (``"en"`` when not given).

    ``response_format``:
      * ``"text"``: plain-text body containing only the transcript.
      * ``"verbose_json"``: the transcript plus the transcriber's ``segments``
        and ``words`` (timestamps are requested from the transcriber).
      * anything else: ``{"text": ...}``.

    Raises:
        HTTPException: 503 while the model is loading, 400/413 for empty or
            oversized uploads, 500 if transcription fails.
    """
    if not transcriber._loaded:
        raise HTTPException(status_code=503, detail="Model loading")

    audio_bytes = await _read_upload(file)
    want_timestamps = response_format == "verbose_json"

    try:
        result, elapsed = await _run_serialized(
            transcriber.transcribe,
            audio_bytes,
            file.filename,
            language,
            timestamps=want_timestamps,
        )
    except Exception as e:
        logger.exception("Transcription failed")
        raise HTTPException(status_code=500, detail=f"Failed: {e}") from e

    duration = result.get("duration", 0)
    logger.info(
        "Done: %.1fs in %.1fs (%.0fx rt)", duration, elapsed, _rtfx(duration, elapsed)
    )

    if response_format == "text":
        # Raw text body. JSONResponse would JSON-encode (quote/escape) the
        # string even with a text/plain media type.
        return PlainTextResponse(result["text"])
    if response_format == "verbose_json":
        return {
            "task": "transcribe",
            "language": language or "en",
            "duration": duration,
            "text": result["text"],
            "segments": result.get("segments", []),
            "words": result.get("words", []),
        }
    return {"text": result["text"]}


@app.post("/v1/audio/translations")
async def translate(
    file: UploadFile = File(...),
    model: Optional[str] = Form(default="parakeet-tdt-0.6b-v3"),
    language: Optional[str] = Form(default=None),
    response_format: Optional[str] = Form(default="json"),
):
    """OpenAI-style translation endpoint, served as a plain transcription.

    No translation step is performed. The request is forwarded unchanged to
    the transcription handler.
    """
    # Pass every parameter explicitly. Calling the endpoint function directly
    # bypasses FastAPI's request parsing, so any omitted parameter would receive
    # the raw ``Form(...)`` default object instead of a value.
    return await transcribe(
        file=file,
        model=model,
        language=language,
        response_format=response_format,
        temperature=0.0,
        prompt=None,
    )


@app.post("/v1/audio/diarize")
async def diarize(
    file: UploadFile = File(...),
):
    """Speaker diarization via Sortformer. Returns who-spoke-when as a list of turns.

    Does NOT transcribe. Pair this output with /v1/audio/transcriptions
    (verbose_json) and merge by timestamp to produce a diarized transcript.

    Response shape:
      {
        "segments": [{"start_s": 0.00, "end_s": 4.50, "speaker": "Speaker_0"}, ...],
        "speakers_detected": ["Speaker_0", "Speaker_1"],
        "duration": 90.5,
        "model": "nvidia/diar_sortformer_4spk-v1",
        "device": "cuda"
      }
    """
    if not diarizer._loaded:
        raise HTTPException(status_code=503, detail="Diarizer loading")

    audio_bytes = await _read_upload(file)

    try:
        result, elapsed = await _run_serialized(
            diarizer.diarize, audio_bytes, file.filename or "audio.wav"
        )
    except Exception as e:
        logger.exception("Diarization failed")
        raise HTTPException(status_code=500, detail=f"Failed: {e}") from e

    duration = result.get("duration", 0)
    logger.info(
        "Diarized %.1fs in %.1fs (%.0fx rt), %d speakers, %d turns",
        duration,
        elapsed,
        _rtfx(duration, elapsed),
        len(result.get("speakers_detected") or ()),
        len(result.get("segments") or ()),
    )
    return result


@app.post("/v1/audio/diarize-chunk")
async def diarize_chunk(
    file: UploadFile = File(...),
):
    """Per-chunk worker: diarize + extract one voice fingerprint per local speaker.

    Designed to be called per audio chunk by an external orchestrator (Recap
    Relay) that handles the cross-chunk speaker clustering itself. Single audio
    decode, single set of GPU passes.

    Does NOT transcribe. Pair with /v1/audio/transcriptions on the same chunk
    if you want transcript + speakers + fingerprints in one shot.

    Response shape:
      {
        "duration": 300.0,
        "segments": [{"start_s": 1.2, "end_s": 4.8, "speaker": "Speaker_0"}, ...],
        "speakers_detected": ["Speaker_0", "Speaker_1", "Speaker_2"],
        "fingerprints": {
          "Speaker_0": [0.123, -0.045, ..., 0.211],   # 192-dim TitaNet embedding
          "Speaker_1": [0.087, 0.221, ..., -0.034],
          "Speaker_2": [-0.156, 0.078, ..., 0.144]
        },
        "models": {
          "diarization": "nvidia/diar_sortformer_4spk-v1",
          "embedding": "nvidia/speakerverification_en_titanet_large"
        }
      }

    Speaker labels are LOCAL to this chunk. Run cosine-similarity clustering
    across the fingerprints from all chunks to merge `chunkA.Speaker_0` with
    `chunkB.Speaker_2` when they're the same voice. Recommended threshold:
    treat two fingerprints as the same voice when their cosine similarity is
    >= 0.7 (the default threshold of NeMo's speaker verification).
    """
    if not diarizer._loaded:
        raise HTTPException(status_code=503, detail="Diarizer loading")

    audio_bytes = await _read_upload(file)

    try:
        result, elapsed = await _run_serialized(
            diarizer.diarize_chunk, audio_bytes, file.filename or "audio.wav"
        )
    except Exception as e:
        logger.exception("diarize_chunk failed")
        raise HTTPException(status_code=500, detail=f"Failed: {e}") from e

    duration = result.get("duration", 0)
    n_fp = len(result.get("fingerprints") or {})
    logger.info(
        "diarize_chunk %.1fs in %.1fs (%.0fx rt), %d local speakers, %d turns, %d fingerprints",
        duration,
        elapsed,
        _rtfx(duration, elapsed),
        len(result.get("speakers_detected") or ()),
        len(result.get("segments") or ()),
        n_fp,
    )
    return result