import os
import time
import logging
from contextlib import asynccontextmanager
from typing import Optional

import torch
from fastapi import FastAPI, File, Form, UploadFile, HTTPException
from fastapi.responses import JSONResponse, PlainTextResponse
from fastapi.middleware.cors import CORSMiddleware

from app.transcriber import transcriber, MODEL_NAME, DEVICE
from app.diarizer import diarizer, DIARIZER_MODEL

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
)
logger = logging.getLogger("parakeet-api")


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the ASR model and the diarizer before the app starts serving."""
    logger.info("Loading ASR model %s on %s", MODEL_NAME, DEVICE)
    transcriber.load_model()
    logger.info("ASR model ready")
    logger.info("Loading diarizer %s on %s", DIARIZER_MODEL, DEVICE)
    diarizer.load_model()
    logger.info("Diarizer ready")
    # Nothing is torn down after the yield; the models live for the process.
    yield


app = FastAPI(
    title="Parakeet ASR + Sortformer Diarization API",
    version="1.2.0",
    lifespan=lifespan,
)
# NOTE(review): every origin is allowed together with credentials; confirm
# this is intended for the deployment this service runs in.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


async def _read_upload(file: UploadFile) -> bytes:
    """Read an upload fully into memory and enforce the size limits.

    The limit is read from the ``MAX_UPLOAD_MB`` environment variable on
    every call (default ``"200"``) and interpreted as mebibytes.

    Raises:
        HTTPException: 400 when the upload is empty, 413 when it is larger
            than the configured limit.
    """
    audio_bytes = await file.read()
    if not audio_bytes:
        raise HTTPException(status_code=400, detail="Empty file")
    max_size = int(os.getenv("MAX_UPLOAD_MB", "200")) * 1024 * 1024
    if len(audio_bytes) > max_size:
        raise HTTPException(status_code=413, detail="File too large")
    return audio_bytes


@app.get("/")
async def root():
    """Describe the service: model names, device and endpoint paths."""
    return {
        "service": "parakeet-asr",
        "model": MODEL_NAME,
        "diarizer": DIARIZER_MODEL,
        "device": DEVICE,
        "endpoints": {
            "transcribe": "/v1/audio/transcriptions",
            "diarize": "/v1/audio/diarize",
            "models": "/v1/models",
            "health": "/health",
        },
    }


@app.get("/health")
async def health():
    """Report whether both models are loaded ("ready") or not ("loading")."""
    asr_loaded = transcriber._loaded
    diarizer_loaded = diarizer._loaded
    return {
        "status": "ready" if (asr_loaded and diarizer_loaded) else "loading",
        "asr_loaded": asr_loaded,
        "diarizer_loaded": diarizer_loaded,
        "model": MODEL_NAME,
        "diarizer_model": DIARIZER_MODEL,
        "device": DEVICE,
    }


@app.get("/v1/models")
async def list_models():
    """List the model ids this service advertises."""
    return {
        "object": "list",
        "data": [
            {
                "id": "parakeet-tdt-0.6b-v3",
                "object": "model",
                "owned_by": "nvidia",
                "kind": "stt",
            },
            {
                "id": "whisper-1",
                "object": "model",
                "owned_by": "nvidia",
                "kind": "stt",
            },
            {
                # Only the last path component of the diarizer name is exposed.
                "id": DIARIZER_MODEL.split("/")[-1],
                "object": "model",
                "owned_by": "nvidia",
                "kind": "diarization",
            },
        ],
    }


@app.post("/v1/audio/transcriptions")
async def transcribe(
    file: UploadFile = File(...),
    model: Optional[str] = Form(default="parakeet-tdt-0.6b-v3"),
    language: Optional[str] = Form(default=None),
    response_format: Optional[str] = Form(default="json"),
    temperature: Optional[float] = Form(default=0.0),
    prompt: Optional[str] = Form(default=None),
):
    """Transcribe an uploaded audio file.

    ``response_format`` selects the reply: ``"text"`` returns the bare
    transcript as ``text/plain``; ``"verbose_json"`` returns the transcript
    with duration, segments and words; any other value returns
    ``{"text": ...}``.

    ``model``, ``temperature`` and ``prompt`` are accepted but not read by
    this handler.

    Errors: 503 while the ASR model is loading, 400 for an empty upload,
    413 for an upload over the size limit, 500 when transcription fails.
    """
    if not transcriber._loaded:
        raise HTTPException(status_code=503, detail="Model loading")

    audio_bytes = await _read_upload(file)

    # Timestamps are only requested when the verbose reply will use them.
    want_timestamps = response_format == "verbose_json"

    # perf_counter is monotonic, so the elapsed time cannot go negative or
    # jump if the wall clock is adjusted mid-request.
    started = time.perf_counter()
    try:
        # NOTE(review): this call is not awaited, so it runs inline in the
        # async handler for the whole duration of the transcription.
        result = transcriber.transcribe(
            audio_bytes, file.filename, language, timestamps=want_timestamps
        )
    except Exception as e:
        logger.exception("Transcription failed")
        raise HTTPException(status_code=500, detail=f"Failed: {e}")
    elapsed = time.perf_counter() - started

    duration = result.get("duration", 0)
    rtfx = duration / elapsed if elapsed > 0 else 0
    logger.info("Done: %.1fs in %.1fs (%.0fx rt)", duration, elapsed, rtfx)

    if response_format == "text":
        # PlainTextResponse sends the transcript verbatim; JSONResponse would
        # JSON-encode it (surrounding quotes, escapes) under text/plain.
        return PlainTextResponse(content=result["text"])
    if response_format == "verbose_json":
        return {
            "task": "transcribe",
            "language": language or "en",
            "duration": duration,
            "text": result["text"],
            "segments": result.get("segments", []),
            "words": result.get("words", []),
        }
    return {"text": result["text"]}


@app.post("/v1/audio/translations")
async def translate(
    file: UploadFile = File(...),
    model: Optional[str] = Form(default="parakeet-tdt-0.6b-v3"),
    language: Optional[str] = Form(default=None),
    response_format: Optional[str] = Form(default="json"),
):
    """Handle a translation request by delegating to ``transcribe``.

    No translation step is performed here; the reply is whatever
    ``transcribe`` returns for the same inputs.
    """
    # transcribe() is called as a plain function, so its Form(...) defaults
    # are not resolved by FastAPI; pass concrete values for every parameter.
    return await transcribe(
        file=file,
        model=model,
        language=language,
        response_format=response_format,
        temperature=0.0,
        prompt=None,
    )


@app.post("/v1/audio/diarize")
async def diarize(
    file: UploadFile = File(...),
):
    """Speaker diarization via Sortformer.

    Returns who-spoke-when as a list of turns. Does NOT transcribe — pair
    this output with /v1/audio/transcriptions (verbose_json) and merge by
    timestamp to produce a diarized transcript.

    Response shape:
    {
        "segments": [{"start_s": 0.00, "end_s": 4.50, "speaker": "Speaker_0"}, ...],
        "speakers_detected": ["Speaker_0", "Speaker_1"],
        "duration": 90.5,
        "model": "nvidia/diar_sortformer_4spk-v1",
        "device": "cuda"
    }

    Errors: 503 while the diarizer is loading, 400 for an empty upload,
    413 for an upload over the size limit, 500 when diarization fails.
    """
    if not diarizer._loaded:
        raise HTTPException(status_code=503, detail="Diarizer loading")

    audio_bytes = await _read_upload(file)

    # perf_counter is monotonic, so the elapsed time cannot go negative or
    # jump if the wall clock is adjusted mid-request.
    started = time.perf_counter()
    try:
        result = diarizer.diarize(audio_bytes, file.filename or "audio.wav")
    except Exception as e:
        logger.exception("Diarization failed")
        raise HTTPException(status_code=500, detail=f"Failed: {e}")
    elapsed = time.perf_counter() - started

    duration = result.get("duration", 0)
    rtfx = duration / elapsed if elapsed > 0 else 0
    logger.info(
        "Diarized %.1fs in %.1fs (%.0fx rt), %d speakers, %d turns",
        duration,
        elapsed,
        rtfx,
        len(result["speakers_detected"]),
        len(result["segments"]),
    )
    return result