高级用法 #
高级功能概览 #
除了基本的转录和翻译功能,Whisper 还支持多种高级用法,可以满足更复杂的应用场景。
text
┌─────────────────────────────────────────────────────────────┐
│ 高级功能 │
├─────────────────────────────────────────────────────────────┤
│ │
│ 处理模式: │
│ ├── 流式处理 │
│ ├── 实时转录 │
│ └── 长音频处理 │
│ │
│ 增强功能: │
│ ├── 说话人分离 │
│ ├── 时间戳对齐 │
│ └── 后处理优化 │
│ │
│ 集成扩展: │
│ ├── WhisperX │
│ ├── 自定义模型 │
│ └── 工作流集成 │
│ │
└─────────────────────────────────────────────────────────────┘
流式处理 #
基本流式转录 #
python
import whisper
import pyaudio
import wave
import threading
import queue
import tempfile
import os
class StreamingTranscriber:
    """Record microphone audio in fixed-length chunks and transcribe them
    with Whisper on a background thread.

    Audio is captured as 16 kHz mono 16-bit PCM, queued as raw byte chunks,
    wrapped in a temporary WAV file, and fed to ``model.transcribe``.
    """

    def __init__(self, model_size="base"):
        self.model = whisper.load_model(model_size)
        self.audio_queue = queue.Queue()
        self.is_running = False

    def start_recording(self, chunk_duration=5):
        """Capture audio from the default input device until ``stop()`` is
        called, enqueueing one raw PCM chunk every ``chunk_duration`` seconds."""
        self.is_running = True
        CHUNK = 1024
        FORMAT = pyaudio.paInt16
        CHANNELS = 1
        RATE = 16000  # Whisper expects 16 kHz input
        p = pyaudio.PyAudio()
        stream = p.open(
            format=FORMAT,
            channels=CHANNELS,
            rate=RATE,
            input=True,
            frames_per_buffer=CHUNK
        )
        try:
            frames_per_chunk = int(RATE / CHUNK * chunk_duration)
            while self.is_running:
                frames = [stream.read(CHUNK) for _ in range(frames_per_chunk)]
                self.audio_queue.put(b''.join(frames))
        finally:
            # Always release the audio device, even if stream.read() raises
            # (the original leaked the stream/device on error).
            stream.stop_stream()
            stream.close()
            p.terminate()

    def transcribe_stream(self):
        """Drain the queue and print a transcription for each chunk.

        Runs while recording is active, then finishes any queued backlog.
        """
        while self.is_running or not self.audio_queue.empty():
            try:
                audio_data = self.audio_queue.get(timeout=1)
            except queue.Empty:
                continue
            # Reserve a temp path; the handle is closed before wave reopens
            # it, which also keeps this working on Windows.
            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
                temp_path = f.name
            try:
                # Give the raw PCM bytes a WAV header (mono, 16-bit, 16 kHz —
                # must match the capture settings above).
                with wave.open(temp_path, 'wb') as wf:
                    wf.setnchannels(1)
                    wf.setsampwidth(2)
                    wf.setframerate(16000)
                    wf.writeframes(audio_data)
                result = self.model.transcribe(temp_path)
            finally:
                # The original leaked the temp file when transcribe() raised.
                os.unlink(temp_path)
            if result["text"].strip():
                print(f"转录: {result['text']}")

    def start(self):
        """Start the recorder and transcriber threads; returns both threads."""
        record_thread = threading.Thread(target=self.start_recording)
        transcribe_thread = threading.Thread(target=self.transcribe_stream)
        record_thread.start()
        transcribe_thread.start()
        return record_thread, transcribe_thread

    def stop(self):
        """Signal both worker threads to finish."""
        self.is_running = False
# Example: run the streaming transcriber until the user presses Enter.
transcriber = StreamingTranscriber("base")
record_thread, transcribe_thread = transcriber.start()
input("按 Enter 键停止...")
transcriber.stop()
# Wait for both worker threads to drain and exit cleanly.
record_thread.join()
transcribe_thread.join()
使用 whisper-live #
bash
pip install whisper-live
python
from whisper_live.client import TranscriptionClient

# Connect to a running whisper-live server and stream microphone audio to it.
client = TranscriptionClient(
    "localhost",        # server host
    9090,               # server port
    lang="zh",          # transcription language
    translate=False,    # keep the original language (no English translation)
    model_type="base"
)
# Calling the client starts capturing and streaming audio.
client()
说话人分离 #
使用 pyannote.audio #
bash
pip install pyannote.audio
python
import whisper
from pyannote.audio import Pipeline
import torch

# Load the Whisper ASR model and the pyannote speaker-diarization pipeline.
whisper_model = whisper.load_model("base")
# Requires a Hugging Face access token with the pyannote model terms accepted.
diarization_pipeline = Pipeline.from_pretrained(
    "pyannote/speaker-diarization",
    use_auth_token="YOUR_HF_TOKEN"
)
def transcribe_with_diarization(audio_path):
    """Transcribe ``audio_path`` and label each Whisper segment with a speaker.

    Runs Whisper and pyannote diarization independently, then assigns each
    transcription segment to the diarization turn with the greatest temporal
    overlap.

    Returns a list of dicts with ``speaker``, ``start``, ``end`` and ``text``.
    Segments overlapping no diarization turn are labeled "UNKNOWN" instead of
    being dropped (the original skipped them silently).
    """
    result = whisper_model.transcribe(audio_path)
    diarization = diarization_pipeline(audio_path)

    speaker_segments = [
        {"start": turn.start, "end": turn.end, "speaker": speaker}
        for turn, _, speaker in diarization.itertracks(yield_label=True)
    ]

    output = []
    for segment in result["segments"]:
        seg_start = segment["start"]
        seg_end = segment["end"]
        # Pick the speaker whose turn overlaps this segment the most.
        # (The original matched on endpoint containment, which mislabels
        # segments that merely touch a turn's edge and misses turns lying
        # entirely inside a long segment.)
        best_speaker = "UNKNOWN"
        best_overlap = 0.0
        for spk in speaker_segments:
            overlap = min(seg_end, spk["end"]) - max(seg_start, spk["start"])
            if overlap > best_overlap:
                best_overlap = overlap
                best_speaker = spk["speaker"]
        output.append({
            "speaker": best_speaker,
            "start": seg_start,
            "end": seg_end,
            "text": segment["text"]
        })
    return output
# Print the conversation with speaker labels.
result = transcribe_with_diarization("conversation.mp3")
for item in result:
    print(f"[{item['speaker']}] {item['text']}")
使用 WhisperX #
bash
pip install whisperx
python
import whisperx
import gc

# NOTE(review): `gc` is imported but unused here; WhisperX examples often
# call gc.collect() between stages to free GPU memory — confirm intent.
device = "cuda"
audio_file = "audio.mp3"
batch_size = 16           # reduce if you run out of GPU memory
compute_type = "float16"  # use e.g. "int8" on CPU / low-memory GPUs

# 1. Transcribe with the batched WhisperX backend.
model = whisperx.load_model("base", device, compute_type=compute_type)
audio = whisperx.load_audio(audio_file)
result = model.transcribe(audio, batch_size=batch_size)
print(result["segments"])

# 2. Align the output for accurate word-level timestamps.
model_a, metadata = whisperx.load_align_model(language_code=result["language"], device=device)
result = whisperx.align(result["segments"], model_a, metadata, audio, device)

# 3. Diarize and attach speaker labels to segments/words.
diarize_model = whisperx.DiarizationPipeline(use_auth_token="YOUR_HF_TOKEN", device=device)
diarize_segments = diarize_model(audio)
result = whisperx.assign_word_speakers(diarize_segments, result)

for segment in result["segments"]:
    speaker = segment.get("speaker", "Unknown")
    print(f"[{speaker}] {segment['text']}")
时间戳对齐 #
强制对齐 #
python
import whisper
import dtw
def align_transcription(audio_path, reference_text):
    """Transcribe ``audio_path`` with word timestamps and return a flat
    list of ``{"word", "start", "end"}`` dicts.

    ``reference_text`` is accepted for interface compatibility but is not
    used by the current implementation.
    """
    model = whisper.load_model("base")
    result = model.transcribe(audio_path, word_timestamps=True)
    # Flatten segments -> words; segments without word timing are skipped.
    return [
        {"word": w["word"], "start": w["start"], "end": w["end"]}
        for seg in result["segments"]
        if "words" in seg
        for w in seg["words"]
    ]
# Example: print each word with its start/end time.
timestamps = align_transcription("audio.mp3", "参考文本")
for ts in timestamps:
    print(f"[{ts['start']:.2f}s - {ts['end']:.2f}s] {ts['word']}")
精确词级时间戳 #
python
import whisper

model = whisper.load_model("base")
# word_timestamps=True makes Whisper emit per-word timing in each segment.
result = model.transcribe("audio.mp3", word_timestamps=True)
def get_word_level_timestamps(result):
    """Flatten a Whisper result into a list of word-level entries.

    Each entry carries the stripped word text, start/end times rounded to
    milliseconds, and the model's word probability (1.0 when absent).
    Segments without a "words" key are skipped.
    """
    return [
        {
            "word": info["word"].strip(),
            "start": round(info["start"], 3),
            "end": round(info["end"], 3),
            "probability": round(info.get("probability", 1.0), 3),
        }
        for segment in result["segments"]
        if "words" in segment
        for info in segment["words"]
    ]
# Print every word with its timing and confidence.
words = get_word_level_timestamps(result)
for word in words:
    print(f"{word['word']}: {word['start']:.3f}s - {word['end']:.3f}s (置信度: {word['probability']:.2%})")
后处理优化 #
文本清理 #
python
import whisper
import re
def clean_transcription(text):
    """Normalize whitespace and punctuation spacing in transcribed text.

    Collapses runs of whitespace, trims the ends, deletes spaces that
    precede punctuation, and inserts a space after punctuation that is
    glued to the next word.
    """
    collapsed = re.sub(r'\s+', ' ', text).strip()
    no_space_before = re.sub(r'\s+([,.!?])', r'\1', collapsed)
    return re.sub(r'([,.!?])([^\s,.!?])', r'\1 \2', no_space_before)
model = whisper.load_model("base")
result = model.transcribe("audio.mp3")
original_text = result["text"]
cleaned_text = clean_transcription(original_text)
# Show the raw vs. cleaned transcription side by side.
print(f"原文: {original_text}")
print(f"清理后: {cleaned_text}")
标点优化 #
python
import whisper
# Transcribe once; the segment list is post-processed below.
model = whisper.load_model("base")
result = model.transcribe("audio.mp3")
def optimize_punctuation(segments):
    """Smooth segment boundaries in a Whisper segment list.

    When the next segment starts with a lowercase letter, the current
    sentence probably continues, so the current segment's terminal
    punctuation is replaced with a comma. Returns new dicts with stripped
    text; the input list is not mutated.
    """
    optimized = []
    total = len(segments)
    for index, segment in enumerate(segments):
        text = segment["text"].strip()
        if index + 1 < total:
            following = segments[index + 1]["text"].strip()
            if following and following[0].islower():
                # Sentence continues: swap trailing terminator for a comma.
                text = text.rstrip('.!?') + ','
        optimized.append(dict(segment, text=text))
    return optimized
# Apply the punctuation pass and print the adjusted segments.
optimized_segments = optimize_punctuation(result["segments"])
for seg in optimized_segments:
    print(seg["text"])
专有名词纠正 #
python
import whisper
import re
# Canonical casing for common tech proper nouns (keys must be lowercase).
CORRECTIONS = {
    "openai": "OpenAI",
    "gpt": "GPT",
    "api": "API",
    "python": "Python",
    "javascript": "JavaScript",
    "tensorflow": "TensorFlow",
    "pytorch": "PyTorch"
}

# One alternation of all known terms, matched case-insensitively on word
# boundaries so surrounding punctuation and position don't matter.
_CORRECTIONS_RE = re.compile(
    r'\b(' + '|'.join(re.escape(term) for term in CORRECTIONS) + r')\b',
    re.IGNORECASE
)

def correct_proper_nouns(text):
    """Rewrite known proper nouns in ``text`` to their canonical casing.

    Fixes two bugs in the original word-splitting version:
    - lowercase occurrences (e.g. "python") were lower-cased right back,
      so they were never actually corrected;
    - attached punctuation ("api.") was stripped for the lookup but then
      dropped from the output word.

    Word boundaries prevent partial matches (e.g. "apis" is untouched).
    """
    return _CORRECTIONS_RE.sub(lambda m: CORRECTIONS[m.group(0).lower()], text)
model = whisper.load_model("base")
result = model.transcribe("tech_talk.mp3")
original = result["text"]
corrected = correct_proper_nouns(original)
# Show the raw vs. corrected transcription side by side.
print(f"原文: {original}")
print(f"纠正后: {corrected}")
自定义模型 #
加载本地模型 #
python
import whisper
# Pass a filesystem path instead of a model name to load custom weights.
model = whisper.load_model("/path/to/custom/model.pt")
result = model.transcribe("audio.mp3")
print(result["text"])
模型微调 #
python
import whisper
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
class AudioDataset(Dataset):
    # Minimal (audio, text) dataset for fine-tuning: each item is a
    # log-mel spectrogram plus the tokenized transcript.
    def __init__(self, audio_paths, texts):
        self.audio_paths = audio_paths
        self.texts = texts
    def __len__(self):
        return len(self.audio_paths)
    def __getitem__(self, idx):
        # Load and pad/trim to Whisper's fixed 30-second window.
        audio = whisper.load_audio(self.audio_paths[idx])
        audio = whisper.pad_or_trim(audio)
        mel = whisper.log_mel_spectrogram(audio)
        text = self.texts[idx]
        # NOTE(review): whisper's get_tokenizer usually requires a
        # `multilingual=` argument — confirm against the installed version.
        # Variable-length token tensors also need a custom collate_fn when
        # batched through a DataLoader.
        tokens = whisper.tokenizer.get_tokenizer().encode(text)
        return mel, torch.tensor(tokens)
model = whisper.load_model("base")

# Freeze the audio encoder; only the text decoder is updated.
for param in model.encoder.parameters():
    param.requires_grad = False

optimizer = torch.optim.Adam(model.decoder.parameters(), lr=1e-5)

def train_step(mel, tokens):
    """Run one teacher-forced optimization step and return the scalar loss."""
    optimizer.zero_grad()
    logits = model(mel, tokens)
    # NOTE(review): teacher-forced training normally compares logits against
    # tokens shifted by one position; verify the target alignment here.
    loss = nn.CrossEntropyLoss()(logits.view(-1, logits.size(-1)), tokens.view(-1))
    loss.backward()
    optimizer.step()
    return loss.item()
工作流集成 #
与 LangChain 集成 #
bash
pip install langchain
python
import whisper
from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
model = whisper.load_model("base")

def audio_to_documents(audio_path, chunk_size=1000, chunk_overlap=200):
    """Transcribe an audio file and split the text into LangChain Documents.

    Each Document carries the source path, its chunk index, and the total
    chunk count in its metadata.
    """
    transcription = model.transcribe(audio_path)["text"]
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap
    )
    pieces = splitter.split_text(transcription)
    return [
        Document(
            page_content=piece,
            metadata={
                "source": audio_path,
                "chunk": index,
                "total_chunks": len(pieces)
            }
        )
        for index, piece in enumerate(pieces)
    ]
# Preview the first few chunks of the transcribed podcast.
documents = audio_to_documents("podcast.mp3")
for doc in documents[:3]:
    print(f"Chunk {doc.metadata['chunk']}: {doc.page_content[:100]}...")
与 Gradio 集成 #
bash
pip install gradio
python
import whisper
import gradio as gr
model = whisper.load_model("base")

def transcribe_audio(audio_file):
    """Gradio callback: transcribe the uploaded file path, or prompt for one."""
    # Gradio passes None when no file was provided.
    if audio_file is None:
        return "请上传音频文件"
    result = model.transcribe(audio_file)
    return result["text"]

# Simple web UI: audio upload in, transcription text out.
demo = gr.Interface(
    fn=transcribe_audio,
    inputs=gr.Audio(type="filepath", label="上传音频"),
    outputs=gr.Textbox(label="转录结果", lines=10),
    title="Whisper 语音转录",
    description="上传音频文件,自动转录为文本"
)

demo.launch()
与 FastAPI 集成 #
python
import whisper
from fastapi import FastAPI, UploadFile, File
from fastapi.responses import JSONResponse
import tempfile
import os
app = FastAPI(title="Whisper API")
# Load the model once at startup so every request reuses it.
model = whisper.load_model("base")
@app.post("/transcribe")
async def transcribe(file: UploadFile = File(...)):
    """Transcribe an uploaded audio file; returns text, language and segments."""
    # Persist the upload with its original extension so ffmpeg can sniff it.
    with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(file.filename)[1]) as tmp:
        tmp.write(await file.read())
        tmp_path = tmp.name
    try:
        result = model.transcribe(tmp_path)
        return JSONResponse({
            "text": result["text"],
            "language": result["language"],
            "segments": result["segments"]
        })
    finally:
        # Remove the temp file whether or not transcription succeeded.
        os.unlink(tmp_path)
@app.post("/translate")
async def translate(file: UploadFile = File(...)):
    """Translate an uploaded audio file to English and return the text.

    Bug fix: Whisper models expose no ``translate()`` method; translation
    is requested via ``transcribe(..., task="translate")``.
    """
    # Persist the upload with its original extension so ffmpeg can sniff it.
    with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(file.filename)[1]) as tmp:
        tmp.write(await file.read())
        tmp_path = tmp.name
    try:
        result = model.transcribe(tmp_path, task="translate")
        return JSONResponse({
            "text": result["text"],
            "language": result["language"]
        })
    finally:
        # Remove the temp file whether or not translation succeeded.
        os.unlink(tmp_path)
@app.get("/health")
async def health():
    """Liveness probe for load balancers / orchestrators."""
    return {"status": "healthy"}
下一步 #
掌握了高级用法后,继续学习 批量处理 了解如何高效处理大量音频文件!
最后更新:2026-04-05