"""FastAPI service exposing Parakeet speech-to-text behind OpenAI-style routes.

Routes: ``/`` (service info), ``/health``, ``/v1/models``,
``/v1/audio/transcriptions`` and ``/v1/audio/translations``.
"""

import logging
import os
import threading
import time
from contextlib import asynccontextmanager
from typing import Any, Optional, Tuple

import torch  # noqa: F401 - not referenced directly in this module
from fastapi import FastAPI, File, Form, UploadFile, HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import JSONResponse  # noqa: F401
from fastapi.responses import PlainTextResponse
from fastapi.middleware.cors import CORSMiddleware

from app.transcriber import transcriber, MODEL_NAME, DEVICE

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
)
logger = logging.getLogger("parakeet-api")

# Uploads are read in chunks of this size so an oversize file is rejected
# as soon as it crosses the limit instead of being buffered in full first.
_READ_CHUNK_BYTES = 1024 * 1024

# Transcription runs in a worker thread so the event loop stays responsive
# (e.g. /health keeps answering). This lock keeps calls into the shared model
# one at a time, since its thread-safety is not established in this module.
_transcribe_lock = threading.Lock()


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the ASR model once at startup, before requests are served."""
    logger.info("Loading model %s on %s", MODEL_NAME, DEVICE)
    transcriber.load_model()
    logger.info("Model ready")
    yield


app = FastAPI(title="Parakeet ASR API", version="1.1.0", lifespan=lifespan)
# NOTE(review): wildcard origins combined with allow_credentials=True lets any
# site issue credentialed cross-origin requests; tighten if this is exposed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.get("/")
async def root():
    """Describe the service: model, device and available endpoints."""
    return {
        "service": "parakeet-asr",
        "model": MODEL_NAME,
        "device": DEVICE,
        "endpoints": {
            "transcribe": "/v1/audio/transcriptions",
            "models": "/v1/models",
            "health": "/health",
        },
    }


@app.get("/health")
async def health():
    """Report ``"ready"`` once the model is loaded, otherwise ``"loading"``."""
    # Reads the transcriber's private _loaded flag; no public accessor is used here.
    return {
        "status": "ready" if transcriber._loaded else "loading",
        "model": MODEL_NAME,
        "device": DEVICE,
    }


@app.get("/v1/models")
async def list_models():
    """List accepted model ids; ``whisper-1`` is listed as an alias for clients."""
    return {
        "object": "list",
        "data": [
            {"id": "parakeet-tdt-0.6b-v3", "object": "model", "owned_by": "nvidia"},
            {"id": "whisper-1", "object": "model", "owned_by": "nvidia"},
        ],
    }


def _max_upload_mb() -> int:
    """Return the upload size limit in MB from ``MAX_UPLOAD_MB`` (default 200)."""
    return int(os.getenv("MAX_UPLOAD_MB", "200"))


async def _read_upload(file: UploadFile, max_mb: int) -> bytes:
    """Read *file* in chunks, raising HTTP 413 as soon as it exceeds *max_mb* MB."""
    max_bytes = max_mb * 1024 * 1024
    buf = bytearray()
    while True:
        chunk = await file.read(_READ_CHUNK_BYTES)
        if not chunk:
            break
        buf += chunk
        if len(buf) > max_bytes:
            raise HTTPException(
                status_code=413, detail=f"File too large (limit {max_mb} MB)"
            )
    return bytes(buf)


def _run_transcription(
    audio_bytes: bytes,
    filename: Optional[str],
    language: Optional[str],
    timestamps: bool,
) -> Tuple[Any, float]:
    """Transcribe under the model lock; return ``(result, elapsed_seconds)``.

    Intended to run in a worker thread. Timing starts after the lock is
    acquired so time spent queueing does not skew the real-time-factor log.
    """
    with _transcribe_lock:
        started = time.perf_counter()
        result = transcriber.transcribe(
            audio_bytes, filename, language, timestamps=timestamps
        )
        return result, time.perf_counter() - started


@app.post("/v1/audio/transcriptions")
async def transcribe(
    file: UploadFile = File(...),
    model: Optional[str] = Form(default="parakeet-tdt-0.6b-v3"),
    language: Optional[str] = Form(default=None),
    response_format: Optional[str] = Form(default="json"),
    temperature: Optional[float] = Form(default=0.0),
    prompt: Optional[str] = Form(default=None),
):
    """Transcribe an uploaded audio file.

    ``model``, ``temperature`` and ``prompt`` are accepted for client
    compatibility but are not used by this handler.

    Response by ``response_format``:
      * ``"text"`` - the transcript as a ``text/plain`` body.
      * ``"verbose_json"`` - task, language (``"en"`` if none was given),
        duration, text, segments and words; timestamps are requested from
        the transcriber only in this mode.
      * anything else - ``{"text": ...}``.

    Raises:
        HTTPException: 503 if the model is not loaded yet, 413 if the upload
            exceeds ``MAX_UPLOAD_MB``, 400 for an empty upload, 500 if
            transcription fails.
    """
    if not transcriber._loaded:
        raise HTTPException(status_code=503, detail="Model loading")

    audio_bytes = await _read_upload(file, _max_upload_mb())
    if not audio_bytes:
        raise HTTPException(status_code=400, detail="Empty file")

    want_timestamps = response_format == "verbose_json"
    try:
        # Offload the blocking model call so the event loop is not stalled.
        result, elapsed = await run_in_threadpool(
            _run_transcription, audio_bytes, file.filename, language, want_timestamps
        )
    except Exception as e:
        logger.exception("Transcription failed")
        raise HTTPException(status_code=500, detail=f"Failed: {e}") from e

    # Guard against a missing/None duration so the log line cannot crash
    # after an otherwise successful transcription.
    duration = result.get("duration") or 0
    rtfx = duration / elapsed if elapsed > 0 else 0
    logger.info("Done: %.1fs in %.1fs (%.0fx rt)", duration, elapsed, rtfx)

    if response_format == "text":
        # PlainTextResponse sends the raw transcript; JSONResponse would
        # JSON-encode it (adding quotes/escapes) despite a text/plain type.
        return PlainTextResponse(result["text"])
    if response_format == "verbose_json":
        return {
            "task": "transcribe",
            "language": language or "en",
            "duration": duration,
            "text": result["text"],
            "segments": result.get("segments", []),
            "words": result.get("words", []),
        }
    return {"text": result["text"]}


@app.post("/v1/audio/translations")
async def translate(
    file: UploadFile = File(...),
    model: Optional[str] = Form(default="parakeet-tdt-0.6b-v3"),
    language: Optional[str] = Form(default=None),
    response_format: Optional[str] = Form(default="json"),
):
    """Serve translation requests by transcribing the audio.

    No translation step is performed: the response is whatever
    ``/v1/audio/transcriptions`` returns for the same inputs.
    """
    # transcribe() is called as a plain coroutine, so every parameter is
    # passed explicitly; omitted ones would receive FastAPI's Form(...)
    # marker objects as defaults rather than plain values.
    return await transcribe(
        file=file,
        model=model,
        language=language,
        response_format=response_format,
        temperature=0.0,
        prompt=None,
    )