API 集成 #
API 集成概述 #
将 Whisper 集成到 API 服务中,可以为各种应用提供语音识别能力。
text
┌─────────────────────────────────────────────────────────────┐
│ API 架构 │
├─────────────────────────────────────────────────────────────┤
│ │
│ 客户端层: │
│ ├── Web 应用 │
│ ├── 移动应用 │
│ └── 桌面应用 │
│ │
│ API 层: │
│ ├── REST API │
│ ├── WebSocket │
│ └── gRPC │
│ │
│ 服务层: │
│ ├── Whisper 模型 │
│ ├── 任务队列 │
│ └── 结果存储 │
│ │
└─────────────────────────────────────────────────────────────┘
FastAPI 集成 #
基本 API 服务 #
python
import whisper
from fastapi import FastAPI, UploadFile, File, HTTPException, BackgroundTasks
from fastapi.responses import JSONResponse, FileResponse
from pydantic import BaseModel
from typing import Optional
import tempfile
import os
import uuid
import json
# FastAPI application metadata, surfaced in the generated OpenAPI docs.
app = FastAPI(
    title="Whisper API",
    description="语音识别 API 服务",
    version="1.0.0"
)
# Single shared model instance, loaded once at import time.
model = whisper.load_model("base")
# In-memory registry for async tasks; contents are lost on restart —
# TODO confirm a persistent store is not required for this deployment.
tasks_store = {}
class TranscriptionResult(BaseModel):
    """Response payload for a transcription request."""
    task_id: str                      # unique id assigned to this request
    status: str                       # e.g. "completed" (set by the endpoints)
    text: Optional[str] = None        # recognized text, once available
    language: Optional[str] = None    # language code reported by Whisper
    duration: Optional[float] = None  # audio duration in seconds
class TranscriptionOptions(BaseModel):
    """Request options schema (declared but not referenced by the endpoints below)."""
    language: Optional[str] = None  # optional source-language hint
    task: str = "transcribe"        # "transcribe" or "translate"
    model: str = "base"             # Whisper model size
@app.post("/transcribe", response_model=TranscriptionResult)
async def transcribe_audio(
    file: UploadFile = File(...),
    language: Optional[str] = None,
    task: str = "transcribe"
):
    """Synchronously transcribe (or translate to English) an uploaded audio file.

    Parameters:
        file: uploaded audio; its extension is kept on the temp file so
            ffmpeg can sniff the container format.
        language: optional source-language hint passed to Whisper.
        task: "transcribe" (default) or "translate".

    Returns:
        TranscriptionResult with the recognized text, language and duration.

    Raises:
        HTTPException: 500 with the underlying error message on failure.
    """
    # whisper's transcribe() takes a file path, so spill the upload to disk.
    with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(file.filename)[1]) as tmp:
        content = await file.read()
        tmp.write(content)
        tmp_path = tmp.name
    try:
        # BUG FIX: Whisper models have no .translate() method; translation is
        # requested via transcribe(..., task="translate").
        whisper_task = "translate" if task == "translate" else "transcribe"
        result = model.transcribe(tmp_path, language=language, task=whisper_task)
        return TranscriptionResult(
            task_id=str(uuid.uuid4()),
            status="completed",
            text=result["text"],
            language=result["language"],
            # End timestamp of the last segment approximates the duration.
            duration=result["segments"][-1]["end"] if result["segments"] else 0
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        # Always remove the temp file, even on failure.
        os.unlink(tmp_path)
@app.post("/transcribe/async")
async def transcribe_async(
    background_tasks: BackgroundTasks,
    file: UploadFile = File(...),
    language: Optional[str] = None
):
    """Accept an upload, schedule transcription in the background, return a task id.

    Poll /tasks/{task_id} for the result.
    """
    task_id = str(uuid.uuid4())
    # Register the task as in-flight before the worker is scheduled, so a
    # status query issued immediately after the response still finds it.
    tasks_store[task_id] = {"status": "processing", "text": None, "language": None}
    payload = await file.read()
    background_tasks.add_task(
        process_transcription, task_id, payload, file.filename, language
    )
    return {"task_id": task_id, "status": "processing"}
async def process_transcription(task_id, content, filename, language):
    """Background worker: persist the bytes, run Whisper, record the outcome.

    Updates tasks_store[task_id] to "completed" (with text/language) or
    "error" (with the exception message).
    """
    suffix = os.path.splitext(filename)[1]
    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
        tmp.write(content)
        audio_path = tmp.name
    try:
        outcome = model.transcribe(audio_path, language=language)
        tasks_store[task_id] = {
            "status": "completed",
            "text": outcome["text"],
            "language": outcome["language"],
        }
    except Exception as exc:
        tasks_store[task_id] = {"status": "error", "error": str(exc)}
    finally:
        # Remove the temp audio file regardless of success.
        os.unlink(audio_path)
@app.get("/tasks/{task_id}")
async def get_task_status(task_id: str):
    """Return the recorded state of an async transcription task (404 if unknown)."""
    task = tasks_store.get(task_id)
    if task is None:
        raise HTTPException(status_code=404, detail="Task not found")
    return task
@app.get("/health")
async def health_check():
    """Liveness probe: reports service health and the configured model size."""
    payload = {"status": "healthy"}
    payload["model"] = "base"
    return payload
@app.get("/")
async def root():
    """Service index: name plus a map of available endpoints."""
    endpoints = {
        "/transcribe": "POST - 同步转录",
        "/transcribe/async": "POST - 异步转录",
        "/tasks/{task_id}": "GET - 查询任务状态",
        "/health": "GET - 健康检查",
    }
    return {"message": "Whisper API", "endpoints": endpoints}
完整 API 服务 #
python
import whisper
from fastapi import FastAPI, UploadFile, File, HTTPException, Query, Depends
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from typing import Optional, List
import tempfile
import os
import uuid
import asyncio
from datetime import datetime
import json
# FastAPI application for the full-featured v2 API.
app = FastAPI(
    title="Whisper API",
    description="企业级语音识别 API 服务",
    version="2.0.0"
)
# NOTE(review): allow_origins=["*"] together with allow_credentials=True is
# very permissive — restrict origins before deploying to production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Lazily-populated cache: model size -> loaded Whisper model.
models_cache = {}


def get_model(model_size: str = "base"):
    """Return the Whisper model for *model_size*, loading it on first use."""
    try:
        return models_cache[model_size]
    except KeyError:
        loaded = whisper.load_model(model_size)
        models_cache[model_size] = loaded
        return loaded
class TranscriptionResponse(BaseModel):
    """Response payload for /api/v1/transcribe."""
    task_id: str                           # unique id for this request
    status: str                            # "completed" on success
    text: Optional[str] = None             # full recognized text
    language: Optional[str] = None         # language code reported by Whisper
    duration: Optional[float] = None       # audio length in seconds
    segments: Optional[List[dict]] = None  # per-segment detail (json format only)
    created_at: str                        # ISO-8601 creation timestamp
class ErrorResponse(BaseModel):
    """Error payload schema (declared; not wired to a handler in this file)."""
    error: str   # short error identifier
    detail: str  # human-readable description
@app.post("/api/v1/transcribe", response_model=TranscriptionResponse)
async def transcribe(
    file: UploadFile = File(...),
    model_size: str = Query("base", enum=["tiny", "base", "small", "medium", "large"]),
    language: Optional[str] = Query(None),
    task: str = Query("transcribe", enum=["transcribe", "translate"]),
    word_timestamps: bool = Query(False),
    output_format: str = Query("json", enum=["json", "text", "srt", "vtt"])
):
    """Transcribe (or translate to English) an uploaded audio file.

    Parameters:
        file: uploaded audio file; extension must be in the allow-list.
        model_size: which cached Whisper model to use.
        language: optional source-language hint.
        task: "transcribe" or "translate" (enum-validated by FastAPI).
        word_timestamps: request per-word timing from Whisper.
        output_format: "json" includes segment detail; other formats omit it.

    Raises:
        HTTPException: 400 for unsupported extensions, 500 on model errors.
    """
    task_id = str(uuid.uuid4())
    # Cheap extension allow-list; ffmpeg performs the real format validation.
    valid_extensions = [".mp3", ".wav", ".m4a", ".flac", ".ogg", ".webm"]
    file_ext = os.path.splitext(file.filename)[1].lower()
    if file_ext not in valid_extensions:
        raise HTTPException(
            status_code=400,
            detail=f"不支持的文件格式: {file_ext}"
        )
    # whisper.transcribe() needs a real path, so spill the upload to disk.
    with tempfile.NamedTemporaryFile(delete=False, suffix=file_ext) as tmp:
        content = await file.read()
        tmp.write(content)
        tmp_path = tmp.name
    try:
        model = get_model(model_size)
        # BUG FIX: Whisper models have no .translate() method; translation is
        # selected via transcribe(..., task="translate"). `task` is already
        # constrained to the two valid values by the Query enum above.
        result = model.transcribe(
            tmp_path,
            language=language,
            task=task,
            word_timestamps=word_timestamps
        )
        # End timestamp of the final segment approximates the audio duration.
        duration = result["segments"][-1]["end"] if result["segments"] else 0
        return TranscriptionResponse(
            task_id=task_id,
            status="completed",
            text=result["text"],
            language=result["language"],
            duration=duration,
            # Segment detail is only included for the JSON output format.
            segments=result["segments"] if output_format == "json" else None,
            created_at=datetime.now().isoformat()
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        os.unlink(tmp_path)
@app.post("/api/v1/detect-language")
async def detect_language(file: UploadFile = File(...)):
    """Detect the spoken language of an uploaded audio file.

    Returns the most likely language code, its probability, and the full
    probability distribution over all languages Whisper scored.
    """
    with tempfile.NamedTemporaryFile(delete=False) as tmp:
        content = await file.read()
        tmp.write(content)
        tmp_path = tmp.name
    try:
        # BUG FIX: the model must be resolved *before* the mel spectrogram is
        # moved to its device — the original referenced `model` one line
        # before assigning it, which raised NameError at runtime.
        model = get_model("base")
        audio = whisper.load_audio(tmp_path)
        # Pad/trim to Whisper's fixed input window before the spectrogram.
        audio = whisper.pad_or_trim(audio)
        mel = whisper.log_mel_spectrogram(audio).to(next(model.parameters()).device)
        _, probs = model.detect_language(mel)
        detected_lang = max(probs, key=probs.get)
        confidence = probs[detected_lang]
        return {
            "language": detected_lang,
            "confidence": confidence,
            "all_probabilities": {k: float(v) for k, v in probs.items()}
        }
    finally:
        os.unlink(tmp_path)
@app.get("/api/v1/models")
async def list_models():
    """Enumerate the Whisper model sizes this service can serve."""
    catalogue = [
        ("tiny", "39M", "32x"),
        ("base", "74M", "16x"),
        ("small", "244M", "6x"),
        ("medium", "769M", "2x"),
        ("large", "1550M", "1x"),
    ]
    return {
        "models": [
            {"name": name, "size": size, "speed": speed}
            for name, size, speed in catalogue
        ]
    }
@app.get("/api/v1/languages")
async def list_languages():
    """Enumerate the language hints accepted by the transcription endpoints."""
    supported = [
        ("zh", "中文"),
        ("en", "英语"),
        ("ja", "日语"),
        ("ko", "韩语"),
        ("fr", "法语"),
        ("de", "德语"),
        ("es", "西班牙语"),
        ("ru", "俄语"),
    ]
    return {"languages": [{"code": code, "name": name} for code, name in supported]}
Flask 集成 #
python
import whisper
from flask import Flask, request, jsonify
from werkzeug.utils import secure_filename
import os
import tempfile
app = Flask(__name__)
# Reject request bodies larger than 100 MB (Werkzeug responds with 413).
app.config['MAX_CONTENT_LENGTH'] = 100 * 1024 * 1024
# Single shared model instance, loaded once at import time.
model = whisper.load_model("base")
# Upload extensions accepted by /transcribe (compared case-insensitively).
ALLOWED_EXTENSIONS = {'mp3', 'wav', 'm4a', 'flac', 'ogg'}


def allowed_file(filename):
    """Return True if *filename* has an extension in ALLOWED_EXTENSIONS."""
    if '.' not in filename:
        return False
    extension = filename.rsplit('.', 1)[1].lower()
    return extension in ALLOWED_EXTENSIONS
@app.route('/health', methods=['GET'])
def health():
    """Liveness probe for the Flask service."""
    payload = {'status': 'healthy'}
    return jsonify(payload)
@app.route('/transcribe', methods=['POST'])
def transcribe():
    """Transcribe (or translate to English) an uploaded audio file.

    Form fields:
        file: required multipart upload; extension must pass allowed_file().
        language: optional source-language hint.
        task: "transcribe" (default) or "translate".

    Returns JSON with text/language/segments, or an error payload with
    status 400 (bad request) / 500 (model failure).
    """
    if 'file' not in request.files:
        return jsonify({'error': '没有上传文件'}), 400
    file = request.files['file']
    if file.filename == '':
        return jsonify({'error': '没有选择文件'}), 400
    if not allowed_file(file.filename):
        return jsonify({'error': '不支持的文件格式'}), 400
    language = request.form.get('language', None)
    task = request.form.get('task', 'transcribe')
    # whisper needs a real path, so spill the upload to a temp file.
    with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(file.filename)[1]) as tmp:
        file.save(tmp.name)
        tmp_path = tmp.name
    try:
        # BUG FIX: Whisper models have no .translate() method; translation is
        # selected via transcribe(..., task="translate"). Any other `task`
        # value falls back to plain transcription, as before.
        whisper_task = 'translate' if task == 'translate' else 'transcribe'
        result = model.transcribe(tmp_path, language=language, task=whisper_task)
        return jsonify({
            'text': result['text'],
            'language': result['language'],
            'segments': result['segments']
        })
    except Exception as e:
        return jsonify({'error': str(e)}), 500
    finally:
        os.unlink(tmp_path)
@app.route('/detect-language', methods=['POST'])
def detect_language():
    """Detect the dominant language in an uploaded audio file."""
    if 'file' not in request.files:
        return jsonify({'error': '没有上传文件'}), 400
    upload = request.files['file']
    with tempfile.NamedTemporaryFile(delete=False) as tmp:
        upload.save(tmp.name)
        audio_path = tmp.name
    try:
        samples = whisper.load_audio(audio_path)
        # Pad/trim to Whisper's fixed input window before the spectrogram.
        samples = whisper.pad_or_trim(samples)
        mel = whisper.log_mel_spectrogram(samples).to(model.device)
        _, probs = model.detect_language(mel)
        best = max(probs, key=probs.get)
        return jsonify({
            'language': best,
            'confidence': float(probs[best])
        })
    finally:
        # Remove the temp audio file regardless of success.
        os.unlink(audio_path)
if __name__ == '__main__':
    # NOTE(review): debug=True enables the Werkzeug interactive debugger;
    # combined with host='0.0.0.0' this allows remote code execution if the
    # port is reachable — disable debug (and use a WSGI server) in production.
    app.run(host='0.0.0.0', port=8000, debug=True)
WebSocket 实时转录 #
python
import whisper
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
import asyncio
import json
import numpy as np
app = FastAPI()
# Shared model for all WebSocket sessions. whisper.decode() below is called
# synchronously inside async handlers, so concurrent sessions will serialize
# on it and block the event loop during inference — TODO confirm acceptable.
model = whisper.load_model("base")
class ConnectionManager:
    """Tracks open WebSocket connections and pushes text frames to them."""

    def __init__(self):
        # All currently-accepted sockets, in connection order.
        self.active_connections = []

    async def connect(self, websocket: WebSocket):
        """Complete the handshake and start tracking the socket."""
        await websocket.accept()
        self.active_connections.append(websocket)

    def disconnect(self, websocket: WebSocket):
        """Stop tracking a socket that has closed."""
        self.active_connections.remove(websocket)

    async def send_message(self, message: str, websocket: WebSocket):
        """Send one text frame to a specific client."""
        await websocket.send_text(message)


# Module-level singleton shared by all WebSocket endpoints.
manager = ConnectionManager()
@app.websocket("/ws/transcribe")
async def websocket_transcribe(websocket: WebSocket):
    """Buffered transcription over a WebSocket of JSON text frames.

    Protocol:
      {"type": "audio", "data": "<hex of int16 PCM>"} — samples are queued;
        once `buffer_duration` seconds (at an assumed 16 kHz — TODO confirm)
        are buffered, one decode runs and a {"type": "transcription", ...}
        frame is sent back.
      {"type": "config", "buffer_duration": N} — adjusts the window size and
        is acknowledged with a {"type": "config_ack", ...} frame.
    """
    await manager.connect(websocket)
    pcm_queue = []       # accumulated int16 samples awaiting a full window
    buffer_duration = 5  # seconds of audio per decode window
    try:
        while True:
            frame = json.loads(await websocket.receive_text())
            kind = frame["type"]
            if kind == "audio":
                samples = np.frombuffer(bytes.fromhex(frame["data"]), dtype=np.int16)
                pcm_queue.extend(samples)
                window = 16000 * buffer_duration
                if len(pcm_queue) >= window:
                    # Normalize int16 -> float32 in [-1, 1], as Whisper expects.
                    chunk = np.array(pcm_queue[:window], dtype=np.float32)
                    chunk = chunk / 32768.0
                    chunk = whisper.pad_or_trim(chunk)
                    mel = whisper.log_mel_spectrogram(chunk).to(model.device)
                    decoded = whisper.decode(
                        model, mel, whisper.DecodingOptions(language="zh")
                    )
                    await manager.send_message(
                        json.dumps({"type": "transcription", "text": decoded.text}),
                        websocket
                    )
                    # Keep any samples beyond the decoded window.
                    pcm_queue = pcm_queue[window:]
            elif kind == "config":
                buffer_duration = frame.get("buffer_duration", 5)
                await manager.send_message(
                    json.dumps({"type": "config_ack", "buffer_duration": buffer_duration}),
                    websocket
                )
    except WebSocketDisconnect:
        manager.disconnect(websocket)
@app.websocket("/ws/stream")
async def websocket_stream(websocket: WebSocket):
    """Per-frame transcription: each binary frame is decoded independently.

    Expects raw int16 PCM in every binary frame; replies with a JSON text
    frame containing the decoded text (language fixed to "zh").
    """
    await manager.connect(websocket)
    try:
        while True:
            raw = await websocket.receive_bytes()
            # int16 -> float32 in [-1, 1], then pad/trim to Whisper's window.
            pcm = np.frombuffer(raw, dtype=np.int16).astype(np.float32) / 32768.0
            pcm = whisper.pad_or_trim(pcm)
            mel = whisper.log_mel_spectrogram(pcm).to(model.device)
            decoded = whisper.decode(model, mel, whisper.DecodingOptions(language="zh"))
            reply = json.dumps({"text": decoded.text, "language": "zh"})
            await manager.send_message(reply, websocket)
    except WebSocketDisconnect:
        manager.disconnect(websocket)
客户端示例 #
Python 客户端 #
python
import requests
import json
class WhisperClient:
    """Thin HTTP client for the Whisper REST API."""

    def __init__(self, base_url="http://localhost:8000"):
        # Root URL of the API service (no trailing slash expected).
        self.base_url = base_url

    def transcribe(self, audio_path, language=None, task="transcribe"):
        """Upload a local audio file for transcription; return parsed JSON."""
        endpoint = f"{self.base_url}/api/v1/transcribe"
        with open(audio_path, "rb") as audio_file:
            response = requests.post(
                endpoint,
                files={"file": audio_file},
                params={"language": language, "task": task},
            )
        return response.json()

    def detect_language(self, audio_path):
        """Upload a local audio file for language detection; return parsed JSON."""
        endpoint = f"{self.base_url}/api/v1/detect-language"
        with open(audio_path, "rb") as audio_file:
            response = requests.post(endpoint, files={"file": audio_file})
        return response.json()
# Example usage: requires the API service listening on localhost:8000 and a
# local "audio.mp3" file; performs real HTTP requests when executed.
client = WhisperClient("http://localhost:8000")
result = client.transcribe("audio.mp3", language="zh")
print(result["text"])
lang = client.detect_language("audio.mp3")
print(f"检测语言: {lang['language']} (置信度: {lang['confidence']:.2%})")
JavaScript 客户端 #
javascript
/**
 * Minimal browser client for the Whisper REST API.
 */
class WhisperClient {
    constructor(baseUrl = 'http://localhost:8000') {
        // Root URL of the API service (no trailing slash expected).
        this.baseUrl = baseUrl;
    }

    /**
     * Upload an audio File/Blob for transcription.
     * @param {File} audioFile - audio payload, sent as multipart form data
     * @param {{language?: string, task?: string}} [options] - query options
     * @returns {Promise<object>} parsed JSON response
     */
    async transcribe(audioFile, options = {}) {
        const body = new FormData();
        body.append('file', audioFile);
        const query = new URLSearchParams();
        if (options.language) query.append('language', options.language);
        if (options.task) query.append('task', options.task);
        const url = `${this.baseUrl}/api/v1/transcribe?${query.toString()}`;
        const response = await fetch(url, { method: 'POST', body });
        return await response.json();
    }

    /**
     * Upload an audio File/Blob for language detection.
     * @param {File} audioFile - audio payload, sent as multipart form data
     * @returns {Promise<object>} parsed JSON response
     */
    async detectLanguage(audioFile) {
        const body = new FormData();
        body.append('file', audioFile);
        const response = await fetch(`${this.baseUrl}/api/v1/detect-language`, {
            method: 'POST',
            body,
        });
        return await response.json();
    }
}
// Example wiring: transcribe whichever file the user picks in #fileInput.
const client = new WhisperClient();
document.getElementById('fileInput').addEventListener('change', async (e) => {
    const file = e.target.files[0];
    if (file) {
        // Request transcription with a Chinese language hint.
        const result = await client.transcribe(file, { language: 'zh' });
        console.log(result.text);
    }
});
下一步 #
掌握了 API 集成后,继续学习 生产环境部署 了解如何部署企业级语音识别服务!
最后更新:2026-04-05