API 集成 #

API 集成概述 #

将 Whisper 集成到 API 服务中,可以为各种应用提供语音识别能力。

text
┌─────────────────────────────────────────────────────────────┐
│                    API 架构                                  │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  客户端层:                                                   │
│  ├── Web 应用                                               │
│  ├── 移动应用                                               │
│  └── 桌面应用                                               │
│                                                             │
│  API 层:                                                     │
│  ├── REST API                                               │
│  ├── WebSocket                                              │
│  └── gRPC                                                   │
│                                                             │
│  服务层:                                                     │
│  ├── Whisper 模型                                           │
│  ├── 任务队列                                               │
│  └── 结果存储                                               │
│                                                             │
└─────────────────────────────────────────────────────────────┘

FastAPI 集成 #

基本 API 服务 #

python
import whisper
from fastapi import FastAPI, UploadFile, File, HTTPException, BackgroundTasks
from fastapi.responses import JSONResponse, FileResponse
from pydantic import BaseModel
from typing import Optional
import tempfile
import os
import uuid
import json

# FastAPI application exposing the speech-recognition endpoints.
app = FastAPI(
    title="Whisper API",
    description="语音识别 API 服务",
    version="1.0.0"
)

# Single Whisper model, loaded once at import time and shared by all requests.
model = whisper.load_model("base")

# In-memory registry for async jobs: task_id -> status dict.
# NOTE(review): not persistent and not shared across worker processes —
# entries are lost on restart; confirm this is acceptable for deployment.
tasks_store = {}

class TranscriptionResult(BaseModel):
    """Response payload returned by the /transcribe endpoints."""
    task_id: str
    status: str
    text: Optional[str] = None
    language: Optional[str] = None
    # End timestamp of the last segment, in seconds (0 when no segments).
    duration: Optional[float] = None

class TranscriptionOptions(BaseModel):
    """Request options for a transcription job.

    NOTE(review): not referenced by the visible endpoints (they accept
    query/form parameters instead) — confirm this model is still needed.
    """
    language: Optional[str] = None
    task: str = "transcribe"
    model: str = "base"

@app.post("/transcribe", response_model=TranscriptionResult)
async def transcribe_audio(
    file: UploadFile = File(...),
    language: Optional[str] = None,
    task: str = "transcribe"
):
    """Synchronously transcribe an uploaded audio file.

    Args:
        file: Uploaded audio file.
        language: Optional source-language hint passed to Whisper.
        task: "transcribe" (default) or "translate" (translate to English).

    Returns:
        TranscriptionResult with the recognized text, detected language, and
        the end timestamp of the last segment as the duration.

    Raises:
        HTTPException(500): if transcription fails for any reason.
    """
    # Whisper needs a real file path, so persist the upload to a temp file.
    with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(file.filename)[1]) as tmp:
        content = await file.read()
        tmp.write(content)
        tmp_path = tmp.name

    try:
        # BUG FIX: Whisper models have no `translate()` method — translation
        # is requested through the `task` argument of `transcribe()`.
        if task == "translate":
            result = model.transcribe(tmp_path, task="translate", language=language)
        else:
            result = model.transcribe(tmp_path, language=language)

        return TranscriptionResult(
            task_id=str(uuid.uuid4()),
            status="completed",
            text=result["text"],
            language=result["language"],
            # End of the last segment approximates the audio duration.
            duration=result["segments"][-1]["end"] if result["segments"] else 0
        )

    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    finally:
        # Always remove the temp file, even on failure.
        os.unlink(tmp_path)

@app.post("/transcribe/async")
async def transcribe_async(
    background_tasks: BackgroundTasks,
    file: UploadFile = File(...),
    language: Optional[str] = None
):
    """Accept an upload and schedule transcription as a background task.

    Returns immediately with a task id that can be polled via /tasks/{id}.
    """
    task_id = str(uuid.uuid4())

    # Register a placeholder entry so status polling works right away.
    tasks_store[task_id] = {"status": "processing", "text": None, "language": None}

    # Read the payload now: the upload stream is closed once we return.
    audio_bytes = await file.read()

    background_tasks.add_task(
        process_transcription, task_id, audio_bytes, file.filename, language
    )

    return {"task_id": task_id, "status": "processing"}

async def process_transcription(task_id, content, filename, language):
    """Background worker: persist audio bytes, transcribe, record the result.

    Updates `tasks_store[task_id]` to a completed or error state; the temp
    file is always removed.
    """
    suffix = os.path.splitext(filename)[1]
    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
        tmp.write(content)
        tmp_path = tmp.name

    try:
        result = model.transcribe(tmp_path, language=language)
        tasks_store[task_id] = {
            "status": "completed",
            "text": result["text"],
            "language": result["language"],
        }
    except Exception as exc:
        # Surface the failure to pollers instead of losing it silently.
        tasks_store[task_id] = {"status": "error", "error": str(exc)}
    finally:
        os.unlink(tmp_path)

@app.get("/tasks/{task_id}")
async def get_task_status(task_id: str):
    """Return the stored state of an async transcription task (404 if unknown)."""
    try:
        return tasks_store[task_id]
    except KeyError:
        raise HTTPException(status_code=404, detail="Task not found") from None

@app.get("/health")
async def health_check():
    """Liveness probe: reports service status and the loaded model size."""
    payload = {"status": "healthy", "model": "base"}
    return payload

@app.get("/")
async def root():
    """Service index: lists the available endpoints."""
    endpoints = {
        "/transcribe": "POST - 同步转录",
        "/transcribe/async": "POST - 异步转录",
        "/tasks/{task_id}": "GET - 查询任务状态",
        "/health": "GET - 健康检查",
    }
    return {"message": "Whisper API", "endpoints": endpoints}

完整 API 服务 #

python
import whisper
from fastapi import FastAPI, UploadFile, File, HTTPException, Query, Depends
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from typing import Optional, List
import tempfile
import os
import uuid
import asyncio
from datetime import datetime
import json

# FastAPI application for the "enterprise" variant of the service.
app = FastAPI(
    title="Whisper API",
    description="企业级语音识别 API 服务",
    version="2.0.0"
)

# NOTE(review): wildcard origins combined with allow_credentials=True is
# rejected by browsers under the CORS spec — lock allow_origins down to the
# real frontend origins before production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Lazily populated cache: model size name -> loaded Whisper model.
models_cache = {}

def get_model(model_size: str = "base"):
    """Return a cached Whisper model, loading and caching it on first use."""
    try:
        return models_cache[model_size]
    except KeyError:
        loaded = whisper.load_model(model_size)
        models_cache[model_size] = loaded
        return loaded

class TranscriptionResponse(BaseModel):
    """Response payload for /api/v1/transcribe."""
    task_id: str
    status: str
    text: Optional[str] = None
    language: Optional[str] = None
    duration: Optional[float] = None
    # Per-segment details; only populated when output_format == "json".
    segments: Optional[List[dict]] = None
    created_at: str

class ErrorResponse(BaseModel):
    """Error payload shape.

    NOTE(review): declared but not referenced in the visible code — confirm
    it is wired into error responses elsewhere.
    """
    error: str
    detail: str

@app.post("/api/v1/transcribe", response_model=TranscriptionResponse)
async def transcribe(
    file: UploadFile = File(...),
    model_size: str = Query("base", enum=["tiny", "base", "small", "medium", "large"]),
    language: Optional[str] = Query(None),
    task: str = Query("transcribe", enum=["transcribe", "translate"]),
    word_timestamps: bool = Query(False),
    output_format: str = Query("json", enum=["json", "text", "srt", "vtt"])
):
    """Transcribe or translate an uploaded audio file.

    Args:
        file: Uploaded audio file; extension must be in the whitelist.
        model_size: Which Whisper model size to use (loaded via cache).
        language: Optional source-language hint.
        task: "transcribe" or "translate" (translate to English).
        word_timestamps: Request per-word timing from Whisper.
        output_format: Only controls whether segments are included in the
            response; "text"/"srt"/"vtt" rendering is not implemented here.

    Raises:
        HTTPException(400): unsupported file extension.
        HTTPException(500): transcription failure.
    """
    task_id = str(uuid.uuid4())

    # Cheap extension whitelist before touching the model.
    valid_extensions = [".mp3", ".wav", ".m4a", ".flac", ".ogg", ".webm"]
    file_ext = os.path.splitext(file.filename)[1].lower()

    if file_ext not in valid_extensions:
        raise HTTPException(
            status_code=400,
            detail=f"不支持的文件格式: {file_ext}"
        )

    # Whisper needs a real file path, so persist the upload to a temp file.
    with tempfile.NamedTemporaryFile(delete=False, suffix=file_ext) as tmp:
        content = await file.read()
        tmp.write(content)
        tmp_path = tmp.name

    try:
        model = get_model(model_size)

        # BUG FIX: Whisper models have no `translate()` method; translation
        # is selected via the `task` argument of `transcribe()`. This also
        # means `word_timestamps` now applies to both tasks instead of being
        # silently dropped on the translate path. `task` is constrained to
        # the two valid values by the Query enum above.
        result = model.transcribe(
            tmp_path,
            task=task,
            language=language,
            word_timestamps=word_timestamps
        )

        # End of the last segment approximates the audio duration.
        duration = result["segments"][-1]["end"] if result["segments"] else 0

        response = TranscriptionResponse(
            task_id=task_id,
            status="completed",
            text=result["text"],
            language=result["language"],
            duration=duration,
            segments=result["segments"] if output_format == "json" else None,
            created_at=datetime.now().isoformat()
        )

        return response

    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    finally:
        os.unlink(tmp_path)

@app.post("/api/v1/detect-language")
async def detect_language(file: UploadFile = File(...)):
    """Detect the dominant spoken language in an uploaded audio file.

    Returns the best-guess language code, its probability, and the full
    probability distribution over all languages.
    """
    with tempfile.NamedTemporaryFile(delete=False) as tmp:
        content = await file.read()
        tmp.write(content)
        tmp_path = tmp.name

    try:
        # BUG FIX: the model must be loaded before the mel spectrogram is
        # moved to its device — the original referenced `model` one line
        # before it was assigned, raising NameError on every request.
        model = get_model("base")

        audio = whisper.load_audio(tmp_path)
        audio = whisper.pad_or_trim(audio)
        mel = whisper.log_mel_spectrogram(audio).to(next(model.parameters()).device)

        _, probs = model.detect_language(mel)

        detected_lang = max(probs, key=probs.get)
        confidence = probs[detected_lang]

        return {
            "language": detected_lang,
            "confidence": confidence,
            "all_probabilities": {k: float(v) for k, v in probs.items()}
        }

    finally:
        os.unlink(tmp_path)

@app.get("/api/v1/models")
async def list_models():
    """Enumerate the advertised Whisper model sizes with size/speed figures."""
    specs = [
        ("tiny", "39M", "32x"),
        ("base", "74M", "16x"),
        ("small", "244M", "6x"),
        ("medium", "769M", "2x"),
        ("large", "1550M", "1x"),
    ]
    return {
        "models": [
            {"name": name, "size": size, "speed": speed}
            for name, size, speed in specs
        ]
    }

@app.get("/api/v1/languages")
async def list_languages():
    """List the language codes this API advertises support for."""
    names = {
        "zh": "中文",
        "en": "英语",
        "ja": "日语",
        "ko": "韩语",
        "fr": "法语",
        "de": "德语",
        "es": "西班牙语",
        "ru": "俄语",
    }
    # dict preserves insertion order, so the JSON array order is unchanged.
    return {"languages": [{"code": code, "name": name} for code, name in names.items()]}

Flask 集成 #

python
import whisper
from flask import Flask, request, jsonify
from werkzeug.utils import secure_filename
import os
import tempfile

app = Flask(__name__)
# Reject uploads larger than 100 MB.
app.config['MAX_CONTENT_LENGTH'] = 100 * 1024 * 1024

# Whisper model loaded once at import time and shared by all requests.
model = whisper.load_model("base")

# Upload extensions the API accepts.
ALLOWED_EXTENSIONS = {'mp3', 'wav', 'm4a', 'flac', 'ogg'}

def allowed_file(filename):
    """Return True when *filename* carries a whitelisted audio extension."""
    if '.' not in filename:
        return False
    extension = filename.rsplit('.', 1)[1].lower()
    return extension in ALLOWED_EXTENSIONS

@app.route('/health', methods=['GET'])
def health():
    """Liveness probe."""
    payload = {'status': 'healthy'}
    return jsonify(payload)

@app.route('/transcribe', methods=['POST'])
def transcribe():
    """Transcribe (or translate to English) an uploaded audio file.

    Form fields:
        file: required audio upload with a whitelisted extension.
        language: optional source-language hint passed to Whisper.
        task: 'transcribe' (default) or 'translate'.

    Returns JSON with text, language and segments, or an error with the
    appropriate 4xx/5xx status.
    """
    if 'file' not in request.files:
        return jsonify({'error': '没有上传文件'}), 400

    file = request.files['file']

    if file.filename == '':
        return jsonify({'error': '没有选择文件'}), 400

    if not allowed_file(file.filename):
        return jsonify({'error': '不支持的文件格式'}), 400

    language = request.form.get('language', None)
    task = request.form.get('task', 'transcribe')

    # Whisper needs a real file path, so persist the upload to a temp file.
    with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(file.filename)[1]) as tmp:
        file.save(tmp.name)
        tmp_path = tmp.name

    try:
        # BUG FIX: Whisper models have no `translate()` method — translation
        # is requested through the `task` argument of `transcribe()`.
        if task == 'translate':
            result = model.transcribe(tmp_path, task='translate', language=language)
        else:
            result = model.transcribe(tmp_path, language=language)

        return jsonify({
            'text': result['text'],
            'language': result['language'],
            'segments': result['segments']
        })

    except Exception as e:
        return jsonify({'error': str(e)}), 500

    finally:
        # Always remove the temp file, even on failure.
        os.unlink(tmp_path)

@app.route('/detect-language', methods=['POST'])
def detect_language():
    """Detect the spoken language of an uploaded audio file."""
    if 'file' not in request.files:
        return jsonify({'error': '没有上传文件'}), 400

    upload = request.files['file']

    with tempfile.NamedTemporaryFile(delete=False) as tmp:
        upload.save(tmp.name)
        tmp_path = tmp.name

    try:
        # Load, pad to Whisper's window, and compute the log-mel spectrogram.
        samples = whisper.pad_or_trim(whisper.load_audio(tmp_path))
        mel = whisper.log_mel_spectrogram(samples).to(model.device)
        _, probs = model.detect_language(mel)

        best = max(probs, key=probs.get)

        return jsonify({
            'language': best,
            'confidence': float(probs[best])
        })

    finally:
        os.unlink(tmp_path)

if __name__ == '__main__':
    # NOTE(review): debug=True enables the Werkzeug interactive debugger
    # (arbitrary code execution) and auto-reload — never run with debug
    # enabled outside local development.
    app.run(host='0.0.0.0', port=8000, debug=True)

WebSocket 实时转录 #

python
import whisper
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
import asyncio
import json
import numpy as np

app = FastAPI()

# Whisper model loaded once and shared by every WebSocket session.
model = whisper.load_model("base")

class ConnectionManager:
    """Tracks accepted WebSocket connections and sends text frames to them."""

    def __init__(self):
        # Sockets that have completed the accept handshake.
        self.active_connections = []
    
    async def connect(self, websocket: WebSocket):
        """Accept the handshake and register the socket."""
        await websocket.accept()
        self.active_connections.append(websocket)
    
    def disconnect(self, websocket: WebSocket):
        """Forget a socket (raises ValueError if it was never registered)."""
        self.active_connections.remove(websocket)
    
    async def send_message(self, message: str, websocket: WebSocket):
        """Send one text frame to a single client."""
        await websocket.send_text(message)

# Single shared registry of open WebSocket connections.
manager = ConnectionManager()

@app.websocket("/ws/transcribe")
async def websocket_transcribe(websocket: WebSocket):
    """Chunked real-time transcription over a WebSocket.

    Protocol (JSON text frames from the client):
      - {"type": "audio", "data": "<hex-encoded int16 PCM>"}: samples are
        buffered until roughly `buffer_duration` seconds have accumulated,
        then decoded as one window.
      - {"type": "config", "buffer_duration": <seconds>}: changes the window
        size mid-stream; acknowledged with a "config_ack" frame.

    Sends {"type": "transcription", "text": ...} frames back to the client.
    """
    await manager.connect(websocket)
    
    audio_buffer = []
    # Seconds of audio to accumulate before each decode pass.
    buffer_duration = 5
    
    try:
        while True:
            data = await websocket.receive_text()
            message = json.loads(data)
            
            if message["type"] == "audio":
                # Client sends raw int16 samples as a hex string.
                audio_data = np.frombuffer(
                    bytes.fromhex(message["data"]),
                    dtype=np.int16
                )
                audio_buffer.extend(audio_data)
                
                # 16000 is the assumed sample rate (samples per second);
                # the client must send 16 kHz mono audio for this to hold.
                if len(audio_buffer) >= 16000 * buffer_duration:
                    # int16 -> float32 normalized to [-1, 1).
                    audio_array = np.array(audio_buffer[:16000 * buffer_duration], dtype=np.float32)
                    audio_array = audio_array / 32768.0
                    
                    audio_array = whisper.pad_or_trim(audio_array)
                    mel = whisper.log_mel_spectrogram(audio_array).to(model.device)
                    
                    # NOTE(review): whisper.decode is synchronous and will
                    # block the event loop for the whole decode — consider
                    # offloading to a thread pool for multi-client use.
                    options = whisper.DecodingOptions(language="zh")
                    result = whisper.decode(model, mel, options)
                    
                    await manager.send_message(
                        json.dumps({
                            "type": "transcription",
                            "text": result.text
                        }),
                        websocket
                    )
                    
                    # Drop the consumed window; keep any remainder buffered.
                    audio_buffer = audio_buffer[16000 * buffer_duration:]
            
            elif message["type"] == "config":
                buffer_duration = message.get("buffer_duration", 5)
                await manager.send_message(
                    json.dumps({
                        "type": "config_ack",
                        "buffer_duration": buffer_duration
                    }),
                    websocket
                )
    
    except WebSocketDisconnect:
        manager.disconnect(websocket)

@app.websocket("/ws/stream")
async def websocket_stream(websocket: WebSocket):
    """Decode each binary frame of raw int16 PCM as an independent utterance."""
    await manager.connect(websocket)

    try:
        while True:
            raw = await websocket.receive_bytes()

            # int16 PCM -> float32 normalized to [-1, 1).
            samples = np.frombuffer(raw, dtype=np.int16).astype(np.float32) / 32768.0
            samples = whisper.pad_or_trim(samples)
            mel = whisper.log_mel_spectrogram(samples).to(model.device)

            decoded = whisper.decode(model, mel, whisper.DecodingOptions(language="zh"))

            payload = json.dumps({
                "text": decoded.text,
                "language": "zh"
            })
            await manager.send_message(payload, websocket)

    except WebSocketDisconnect:
        manager.disconnect(websocket)

客户端示例 #

Python 客户端 #

python
import requests
import json

class WhisperClient:
    """Thin HTTP client for the Whisper API service."""

    def __init__(self, base_url="http://localhost:8000"):
        # Base address of the API server, without a trailing slash.
        self.base_url = base_url

    def transcribe(self, audio_path, language=None, task="transcribe"):
        """Upload a file for synchronous transcription; return the JSON body."""
        endpoint = f"{self.base_url}/api/v1/transcribe"
        query = {"language": language, "task": task}

        with open(audio_path, "rb") as audio_file:
            response = requests.post(endpoint, files={"file": audio_file}, params=query)

        return response.json()

    def detect_language(self, audio_path):
        """Upload a file for language detection; return the JSON body."""
        endpoint = f"{self.base_url}/api/v1/detect-language"

        with open(audio_path, "rb") as audio_file:
            response = requests.post(endpoint, files={"file": audio_file})

        return response.json()

# Example usage — requires the API server to be running locally.
client = WhisperClient("http://localhost:8000")

# Synchronous transcription of a local file, hinting the source language.
result = client.transcribe("audio.mp3", language="zh")
print(result["text"])

# Language detection for the same file.
lang = client.detect_language("audio.mp3")
print(f"检测语言: {lang['language']} (置信度: {lang['confidence']:.2%})")

JavaScript 客户端 #

javascript
class WhisperClient {
    /**
     * Minimal browser client for the Whisper API.
     * @param {string} baseUrl - Root URL of the API server.
     */
    constructor(baseUrl = 'http://localhost:8000') {
        this.baseUrl = baseUrl;
    }

    /**
     * Upload an audio file for transcription.
     * @param {File} audioFile - File to transcribe.
     * @param {{language?: string, task?: string}} options - Whisper options.
     * @returns {Promise<object>} Parsed JSON response.
     */
    async transcribe(audioFile, options = {}) {
        const body = new FormData();
        body.append('file', audioFile);

        const query = new URLSearchParams();
        if (options.language) query.append('language', options.language);
        if (options.task) query.append('task', options.task);

        const url = `${this.baseUrl}/api/v1/transcribe?${query.toString()}`;
        const response = await fetch(url, { method: 'POST', body });
        return await response.json();
    }

    /**
     * Ask the server to detect the language of an audio file.
     * @param {File} audioFile - File to analyze.
     * @returns {Promise<object>} Parsed JSON response.
     */
    async detectLanguage(audioFile) {
        const body = new FormData();
        body.append('file', audioFile);

        const url = `${this.baseUrl}/api/v1/detect-language`;
        const response = await fetch(url, { method: 'POST', body });
        return await response.json();
    }
}

// Example usage: transcribe whichever file the user picks.
const client = new WhisperClient();

document.getElementById('fileInput').addEventListener('change', async (e) => {
    const file = e.target.files[0];
    if (file) {
        // Transcribe with a Chinese language hint and log the recognized text.
        const result = await client.transcribe(file, { language: 'zh' });
        console.log(result.text);
    }
});

下一步 #

掌握了 API 集成后,请继续学习《生产环境部署》,了解如何部署企业级语音识别服务!

最后更新:2026-04-05