高级用法 #

高级功能概览 #

除了基本的转录和翻译功能,Whisper 还可以与第三方工具(如 pyannote.audio、WhisperX、LangChain、Gradio)结合,实现流式处理、说话人分离、精确时间戳等高级用法,满足更复杂的应用场景。

text
┌─────────────────────────────────────────────────────────────┐
│                    高级功能                                  │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  处理模式:                                                   │
│  ├── 流式处理                                               │
│  ├── 实时转录                                               │
│  └── 长音频处理                                             │
│                                                             │
│  增强功能:                                                   │
│  ├── 说话人分离                                             │
│  ├── 时间戳对齐                                             │
│  └── 后处理优化                                             │
│                                                             │
│  集成扩展:                                                   │
│  ├── WhisperX                                               │
│  ├── 自定义模型                                             │
│  └── 工作流集成                                             │
│                                                             │
└─────────────────────────────────────────────────────────────┘

流式处理 #

基本流式转录 #

python
import whisper
import pyaudio
import wave
import threading
import queue
import tempfile
import os

class StreamingTranscriber:
    """Record microphone audio in fixed-length chunks and transcribe each chunk with Whisper.

    ``start_recording`` (producer) and ``transcribe_stream`` (consumer) run on
    separate threads connected by ``audio_queue``. The producer enqueues raw
    16-bit mono 16 kHz PCM bytes, followed by a final ``None`` sentinel once it
    stops, so the consumer knows no more audio will arrive.
    """

    def __init__(self, model_size="base"):
        self.model = whisper.load_model(model_size)
        self.audio_queue = queue.Queue()
        self.is_running = False

    def start_recording(self, chunk_duration=5):
        """Capture microphone audio until ``stop()`` is called.

        Args:
            chunk_duration: Approximate length in seconds of each chunk put on
                ``audio_queue``.

        A shorter final chunk is flushed when stopping, and a ``None`` sentinel
        is always enqueued last (even if opening the device fails) so the
        transcription thread can exit.
        """
        self.is_running = True

        CHUNK = 1024
        FORMAT = pyaudio.paInt16
        CHANNELS = 1
        RATE = 16000  # Whisper is fed 16 kHz mono audio

        # At least one buffer per chunk, otherwise the outer loop would spin.
        frames_per_chunk = max(1, int(RATE / CHUNK * chunk_duration))

        p = pyaudio.PyAudio()
        stream = None
        try:
            stream = p.open(
                format=FORMAT,
                channels=CHANNELS,
                rate=RATE,
                input=True,
                frames_per_buffer=CHUNK
            )
            while self.is_running:
                frames = []
                # Re-check is_running per buffer (CHUNK / RATE = 64 ms) so stop()
                # does not have to wait for a whole chunk to be recorded.
                while self.is_running and len(frames) < frames_per_chunk:
                    # Ignore input overflow instead of raising and killing this thread.
                    frames.append(stream.read(CHUNK, exception_on_overflow=False))
                if frames:
                    self.audio_queue.put(b''.join(frames))
        finally:
            self.is_running = False
            if stream is not None:
                stream.stop_stream()
                stream.close()
            p.terminate()
            # Sentinel: tells transcribe_stream that no more audio is coming.
            self.audio_queue.put(None)

    def transcribe_stream(self):
        """Transcribe queued chunks, printing non-empty text, until the ``None`` sentinel arrives.

        NOTE: blocks until ``start_recording`` enqueues its sentinel, so run it
        alongside the recorder (as ``start()`` does).
        """
        import numpy as np

        while True:
            audio_data = self.audio_queue.get()
            if audio_data is None:
                break
            # int16 PCM bytes -> float32 waveform in [-1.0, 1.0). model.transcribe
            # accepts an in-memory waveform, so no temporary WAV file is needed
            # (the old temp file leaked whenever transcription raised).
            audio = np.frombuffer(audio_data, dtype=np.int16).astype(np.float32) / 32768.0
            result = self.model.transcribe(audio)

            if result["text"].strip():
                print(f"转录: {result['text']}")

    def start(self, chunk_duration=5):
        """Start the recording and transcription threads; return ``(record_thread, transcribe_thread)``."""
        record_thread = threading.Thread(target=self.start_recording, args=(chunk_duration,))
        transcribe_thread = threading.Thread(target=self.transcribe_stream)

        record_thread.start()
        transcribe_thread.start()

        return record_thread, transcribe_thread

    def stop(self):
        """Ask the recorder to stop; the transcriber exits after draining the queue."""
        self.is_running = False

# Launch both worker threads, wait for the user, then shut down and wait for them.
transcriber = StreamingTranscriber("base")
record_thread, transcribe_thread = transcriber.start()

input("按 Enter 键停止...")
transcriber.stop()
for worker in (record_thread, transcribe_thread):
    worker.join()

使用 whisper-live #

bash
pip install whisper-live
python
from whisper_live.client import TranscriptionClient

# The client connects to a WhisperLive server, which must already be running
# on localhost:9090 (start it separately; see the whisper-live README).
client = TranscriptionClient(
    "localhost",
    9090,
    lang="zh",
    translate=False,
    model="base",  # the keyword is `model`; `model_type` is not accepted
)

# Called with no arguments the client streams live input; per the whisper-live
# README this means the default microphone.
client()

说话人分离 #

使用 pyannote.audio #

bash
pip install pyannote.audio
python
import whisper
from pyannote.audio import Pipeline
import torch

whisper_model = whisper.load_model("base")

# "pyannote/speaker-diarization" is the legacy 2.x pipeline; current
# pyannote.audio 3.x releases use "pyannote/speaker-diarization-3.1".
# Accept the model's user conditions on Hugging Face before using the token.
# NOTE(review): pyannote.audio 4.x takes `token=` instead of
# `use_auth_token=` — check against your installed version.
diarization_pipeline = Pipeline.from_pretrained(
    "pyannote/speaker-diarization-3.1",
    use_auth_token="YOUR_HF_TOKEN"
)

def _best_speaker(start, end, speaker_segments):
    """Return the label of the speaker turn overlapping [start, end] the most, or None if none touch it."""
    best_label, best_overlap = None, 0.0
    for turn in speaker_segments:
        overlap = min(end, turn["end"]) - max(start, turn["start"])
        # overlap >= 0 also accepts turns that merely touch the segment, or a
        # zero-length segment lying inside a turn.
        if overlap >= 0 and (best_label is None or overlap > best_overlap):
            best_label, best_overlap = turn["speaker"], overlap
    return best_label


def transcribe_with_diarization(audio_path):
    """Transcribe *audio_path* with Whisper and label each segment with a pyannote speaker.

    Each Whisper segment is assigned the speaker turn it overlaps the most.
    Segments that overlap no turn are kept with speaker "UNKNOWN" instead of
    being dropped, so no transcript text is lost.

    Returns:
        list of {"speaker", "start", "end", "text"} dicts in Whisper segment order.
    """
    result = whisper_model.transcribe(audio_path)

    diarization = diarization_pipeline(audio_path)

    speaker_segments = [
        {"start": turn.start, "end": turn.end, "speaker": speaker}
        for turn, _, speaker in diarization.itertracks(yield_label=True)
    ]

    output = []
    for segment in result["segments"]:
        speaker = _best_speaker(segment["start"], segment["end"], speaker_segments)
        output.append({
            "speaker": speaker if speaker is not None else "UNKNOWN",
            "start": segment["start"],
            "end": segment["end"],
            "text": segment["text"]
        })

    return output

result = transcribe_with_diarization("conversation.mp3")

# One line per segment, prefixed with its speaker label.
for item in result:
    label, text = item["speaker"], item["text"]
    print(f"[{label}] {text}")

使用 WhisperX #

bash
pip install whisperx
python
import gc

import torch
import whisperx

# float16 needs a CUDA GPU; fall back to CPU with int8 so the example also
# runs on machines without one.
device = "cuda" if torch.cuda.is_available() else "cpu"
audio_file = "audio.mp3"
batch_size = 16  # lower this if the GPU runs out of memory
compute_type = "float16" if device == "cuda" else "int8"


def _release_gpu_memory():
    """Run the garbage collector and, on CUDA, free cached allocator blocks."""
    gc.collect()
    if device == "cuda":
        torch.cuda.empty_cache()


# 1. Batched transcription.
model = whisperx.load_model("base", device, compute_type=compute_type)

audio = whisperx.load_audio(audio_file)
result = model.transcribe(audio, batch_size=batch_size)

print(result["segments"])  # segment-level timestamps, before alignment

# Drop the ASR model before loading the next one to keep peak memory low.
del model
_release_gpu_memory()

# 2. Word-level alignment.
model_a, metadata = whisperx.load_align_model(language_code=result["language"], device=device)
result = whisperx.align(result["segments"], model_a, metadata, audio, device)

del model_a
_release_gpu_memory()

# 3. Speaker diarization (needs a Hugging Face token).
# NOTE(review): newer whisperx releases expose this as
# whisperx.diarize.DiarizationPipeline — adjust to the installed version.
diarize_model = whisperx.DiarizationPipeline(use_auth_token="YOUR_HF_TOKEN", device=device)
diarize_segments = diarize_model(audio)
result = whisperx.assign_word_speakers(diarize_segments, result)

for segment in result["segments"]:
    speaker = segment.get("speaker", "Unknown")
    print(f"[{speaker}] {segment['text']}")

时间戳对齐 #

强制对齐 #

python
import whisper
import dtw

def align_transcription(audio_path, reference_text, model=None):
    """Transcribe *audio_path* and return word-level timestamps, guided by *reference_text*.

    NOTE: this is not true forced alignment — Whisper still decodes its own
    transcript. *reference_text* is passed as ``initial_prompt`` to bias the
    decoder toward the expected wording and spelling. For alignment against a
    fixed transcript, use a dedicated aligner such as WhisperX.

    Args:
        audio_path: Audio file to transcribe.
        reference_text: Expected transcript; empty or None means no prompt.
        model: Optional already-loaded Whisper model. When omitted, "base" is
            loaded on every call, so pass one in when processing many files.

    Returns:
        list of {"word", "start", "end"} dicts in transcript order.
    """
    if model is None:
        model = whisper.load_model("base")

    result = model.transcribe(
        audio_path,
        word_timestamps=True,
        initial_prompt=reference_text or None,
    )

    return [
        {"word": word["word"], "start": word["start"], "end": word["end"]}
        for segment in result["segments"]
        for word in segment.get("words", [])
    ]

timestamps = align_transcription("audio.mp3", "参考文本")

# Print each word with its [start - end] window in seconds.
for entry in timestamps:
    start, end, word = entry["start"], entry["end"], entry["word"]
    print(f"[{start:.2f}s - {end:.2f}s] {word}")

精确词级时间戳 #

python
import whisper

# Word-level timing must be requested explicitly when transcribing.
model = whisper.load_model(name="base")
result = model.transcribe("audio.mp3", word_timestamps=True)

def get_word_level_timestamps(result):
    """Flatten the per-word timing in a ``transcribe(..., word_timestamps=True)`` result.

    Segments without a "words" entry are skipped. Words are stripped of
    surrounding whitespace; times and probability are rounded to 3 decimals,
    and a missing probability is reported as 1.0.
    """
    def _entry(info):
        return {
            "word": info["word"].strip(),
            "start": round(info["start"], 3),
            "end": round(info["end"], 3),
            "probability": round(info.get("probability", 1.0), 3),
        }

    return [
        _entry(info)
        for segment in result["segments"]
        if "words" in segment
        for info in segment["words"]
    ]

words = get_word_level_timestamps(result)

# word: start - end (confidence as a percentage)
for entry in words:
    print(f"{entry['word']}: {entry['start']:.3f}s - {entry['end']:.3f}s (置信度: {entry['probability']:.2%})")

后处理优化 #

文本清理 #

python
import whisper
import re

def clean_transcription(text):
    """Normalize whitespace and spacing around ASCII punctuation in a transcript.

    - Collapses each run of whitespace to one space and trims both ends.
    - Removes spaces before , . ! ?
    - Inserts a space after , . ! ? when another character follows, except
      between two digits, so numbers like "3.14" and "1,000" stay intact.

    Only ASCII punctuation is handled; full-width (CJK) punctuation is untouched.
    """
    text = re.sub(r'\s+', ' ', text).strip()

    text = re.sub(r'\s+([,.!?])', r'\1', text)
    # Add a space after punctuation followed by a non-space, non-punctuation
    # character, unless the punctuation sits between two digits.
    text = re.sub(r'(?<!\d)[,.!?](?=[^\s,.!?])|[,.!?](?=[^\s\d,.!?])', r'\g<0> ', text)

    return text

model = whisper.load_model("base")
result = model.transcribe("audio.mp3")

original_text = result["text"]
cleaned_text = clean_transcription(original_text)

# Show the raw transcript and the cleaned one.
for label, value in (("原文", original_text), ("清理后", cleaned_text)):
    print(f"{label}: {value}")

标点优化 #

python
import whisper

# Segment-level output is what optimize_punctuation works on.
model = whisper.load_model(name="base")
result = model.transcribe("audio.mp3")

def optimize_punctuation(segments):
    """Turn sentence-final punctuation into a comma when the next segment continues the sentence.

    A segment "continues" when the next segment's stripped text starts with a
    lowercase letter. Trailing . ! ? are then replaced by a comma, unless the
    text already ends in , ; or : (no doubled marks). Empty text stays empty.

    Returns:
        A new list of shallow copies of *segments*, each with "text" replaced
        by its stripped (and possibly re-punctuated) text. Inputs are not modified.
    """
    optimized = []

    for i, segment in enumerate(segments):
        text = segment["text"].strip()
        next_text = segments[i + 1]["text"].strip() if i + 1 < len(segments) else ""

        if text and next_text[:1].islower():
            body = text.rstrip('.!?')
            text = body if body.endswith((',', ';', ':')) else body + ','

        optimized.append({**segment, "text": text})

    return optimized

optimized_segments = optimize_punctuation(result["segments"])

# One optimized segment per line.
for text in (seg["text"] for seg in optimized_segments):
    print(text)

专有名词纠正 #

python
import whisper
import re

CORRECTIONS = {
    "openai": "OpenAI",
    "gpt": "GPT",
    "api": "API",
    "python": "Python",
    "javascript": "JavaScript",
    "tensorflow": "TensorFlow",
    "pytorch": "PyTorch"
}


def correct_proper_nouns(text, corrections=None):
    """Rewrite known proper nouns in *text* to their canonical spelling.

    Matching is case-insensitive and whole-word: "pytorch", "PYTORCH" and
    "Pytorch" all become "PyTorch", while "apis" is left alone. Word boundaries
    are ASCII-only, so terms embedded in Chinese text ("我用python写代码") are
    still found. Surrounding punctuation and the original whitespace are kept.

    Args:
        text: Transcript text.
        corrections: Optional mapping of term -> canonical form (keys are
            matched case-insensitively). Defaults to CORRECTIONS.

    Returns:
        The corrected text.
    """
    table = CORRECTIONS if corrections is None else corrections
    if not text or not table:
        return text

    lookup = {term.lower(): canonical for term, canonical in table.items()}
    # Longest terms first so a multi-word term wins over a shorter term it contains.
    alternation = "|".join(re.escape(term) for term in sorted(lookup, key=len, reverse=True))
    pattern = re.compile(rf"\b(?:{alternation})\b", re.IGNORECASE | re.ASCII)

    return pattern.sub(lambda match: lookup[match.group(0).lower()], text)

model = whisper.load_model("base")
result = model.transcribe("tech_talk.mp3")

original = result["text"]
corrected = correct_proper_nouns(original)

# Before / after, one per line.
print("\n".join([f"原文: {original}", f"纠正后: {corrected}"]))

自定义模型 #

加载本地模型 #

python
import whisper

# load_model is given a checkpoint file path here instead of a model name.
checkpoint_path = "/path/to/custom/model.pt"
model = whisper.load_model(checkpoint_path)

result = model.transcribe("audio.mp3")
print(result["text"])

模型微调 #

python
import whisper
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

class AudioDataset(Dataset):
    """(log-Mel spectrogram, target token ids) pairs for fine-tuning Whisper's decoder.

    Args:
        audio_paths: Audio file paths, one per sample.
        texts: Reference transcripts, aligned index-by-index with audio_paths.
        tokenizer: Optional Whisper tokenizer. Defaults to the multilingual
            tokenizer for the "transcribe" task, matching multilingual
            checkpoints such as "base". NOTE(review): pass one built with
            ``language=...`` to include a language token in the prefix.
        n_mels: Number of Mel bins; must match the model (``model.dims.n_mels``).

    Raises:
        ValueError: if audio_paths and texts differ in length.
    """

    def __init__(self, audio_paths, texts, tokenizer=None, n_mels=80):
        if len(audio_paths) != len(texts):
            raise ValueError(
                f"audio_paths ({len(audio_paths)}) and texts ({len(texts)}) must have the same length"
            )
        self.audio_paths = audio_paths
        self.texts = texts
        if tokenizer is None:
            # get_tokenizer has a required `multilingual` argument; calling it
            # with no arguments raises TypeError.
            tokenizer = whisper.tokenizer.get_tokenizer(multilingual=True, task="transcribe")
        self.tokenizer = tokenizer
        self.n_mels = n_mels

    def __len__(self):
        return len(self.audio_paths)

    def __getitem__(self, idx):
        audio = whisper.load_audio(self.audio_paths[idx])
        audio = whisper.pad_or_trim(audio)  # pad/trim to Whisper's fixed input length
        mel = whisper.log_mel_spectrogram(audio, n_mels=self.n_mels)

        tok = self.tokenizer
        # Decoder target sequence: start-of-transcript prefix, text tokens, end-of-text.
        tokens = [
            *tok.sot_sequence_including_notimestamps,
            *tok.encode(self.texts[idx]),
            tok.eot,
        ]

        return mel, torch.tensor(tokens)

model = whisper.load_model("base")

# Freeze every encoder parameter; only the decoder is optimized below.
model.encoder.requires_grad_(False)

optimizer = torch.optim.Adam(model.decoder.parameters(), lr=1e-5)

def train_step(mel, tokens):
    """Run one teacher-forced optimization step on the module-level model and return the loss.

    Args:
        mel: Log-Mel features, unbatched (n_mels, frames) or batched (batch, n_mels, frames).
        tokens: Target token ids such as AudioDataset produces, (seq,) or
            (batch, seq). Sequences in a batch must share one length.
            NOTE(review): padding tokens are not masked out of the loss.

    Returns:
        The loss as a Python float.
    """
    if mel.dim() == 2:
        mel = mel.unsqueeze(0)
    if tokens.dim() == 1:
        tokens = tokens.unsqueeze(0)
    mel = mel.to(model.device)
    tokens = tokens.to(model.device)

    # Teacher forcing: the decoder reads tokens[:, :-1] and is scored on
    # predicting the next token, tokens[:, 1:]. Scoring against the unshifted
    # input would only teach it to copy.
    decoder_input = tokens[:, :-1]
    targets = tokens[:, 1:]

    optimizer.zero_grad()

    logits = model(mel, decoder_input)

    # reshape rather than view: the shifted slice may not be contiguous.
    loss = nn.functional.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))

    loss.backward()
    optimizer.step()

    return loss.item()

工作流集成 #

与 LangChain 集成 #

bash
pip install langchain
python
import whisper
from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter

model = whisper.load_model("base")

def audio_to_documents(audio_path, chunk_size=1000, chunk_overlap=200):
    """Transcribe *audio_path* and split the transcript into LangChain Documents.

    Args:
        audio_path: Audio file to transcribe; also recorded as each document's "source".
        chunk_size: Maximum characters per chunk, passed to the splitter.
        chunk_overlap: Characters shared between neighbouring chunks.

    Returns:
        List of Documents whose metadata holds "source", "chunk" (0-based index)
        and "total_chunks".
    """
    result = model.transcribe(audio_path)

    # Put each Whisper segment on its own line instead of using result["text"]
    # (one long line), so chunks can break at segment boundaries.
    # NOTE(review): relies on "\n" being among the splitter's default
    # separators. Without it, space-free transcripts (e.g. Chinese) get cut at
    # arbitrary characters.
    text = "\n".join(
        segment["text"].strip()
        for segment in result["segments"]
        if segment["text"].strip()
    )

    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap
    )

    chunks = text_splitter.split_text(text)

    return [
        Document(
            page_content=chunk,
            metadata={
                "source": audio_path,
                "chunk": i,
                "total_chunks": len(chunks)
            }
        )
        for i, chunk in enumerate(chunks)
    ]

documents = audio_to_documents("podcast.mp3")

# Preview the first three chunks (first 100 characters each).
for doc in documents[:3]:
    preview = doc.page_content[:100]
    print(f"Chunk {doc.metadata['chunk']}: {preview}...")

与 Gradio 集成 #

bash
pip install gradio
python
import whisper
import gradio as gr

model = whisper.load_model("base")


def transcribe_audio(audio_file):
    """Gradio callback: transcribe the uploaded file path, or ask for an upload when none was given."""
    if audio_file is not None:
        return model.transcribe(audio_file)["text"]
    return "请上传音频文件"


audio_input = gr.Audio(type="filepath", label="上传音频")
text_output = gr.Textbox(label="转录结果", lines=10)

demo = gr.Interface(
    fn=transcribe_audio,
    inputs=audio_input,
    outputs=text_output,
    title="Whisper 语音转录",
    description="上传音频文件,自动转录为文本",
)

demo.launch()

与 FastAPI 集成 #

python
import whisper
from fastapi import FastAPI, UploadFile, File
from fastapi.responses import JSONResponse
import tempfile
import os

import threading

from fastapi.concurrency import run_in_threadpool

app = FastAPI(title="Whisper API")

model = whisper.load_model("base")

# Run one transcription at a time on the shared model.
# NOTE(review): Whisper does not document concurrent transcribe() calls on a
# single model instance as safe, so calls are serialized here.
_model_lock = threading.Lock()


def _run_model(path, **options):
    """Call model.transcribe on *path* while holding the model lock (blocking)."""
    with _model_lock:
        return model.transcribe(path, **options)


async def _transcribe_upload(file, **options):
    """Save *file* to a temp path, run Whisper on it off the event loop, then delete the temp file.

    Args:
        file: The uploaded audio.
        **options: Extra keyword arguments for model.transcribe (e.g. task="translate").

    Returns:
        The dict returned by model.transcribe.
    """
    contents = await file.read()
    # The upload's filename may be missing; keep its extension when present.
    suffix = os.path.splitext(file.filename or "")[1]
    fd, tmp_path = tempfile.mkstemp(suffix=suffix)
    try:
        # Write and close the file before Whisper opens it by path.
        with os.fdopen(fd, "wb") as tmp:
            tmp.write(contents)
        # transcribe() blocks; run it in a worker thread so the event loop
        # (and /health) stays responsive.
        return await run_in_threadpool(_run_model, tmp_path, **options)
    finally:
        os.unlink(tmp_path)


@app.post("/transcribe")
async def transcribe(file: UploadFile = File(...)):
    result = await _transcribe_upload(file)

    return JSONResponse({
        "text": result["text"],
        "language": result["language"],
        "segments": result["segments"]
    })


@app.post("/translate")
async def translate(file: UploadFile = File(...)):
    # Whisper models have no translate() method; translation to English is
    # requested via transcribe(..., task="translate").
    result = await _transcribe_upload(file, task="translate")

    return JSONResponse({
        "text": result["text"],
        "language": result["language"]
    })


@app.get("/health")
async def health():
    return {"status": "healthy"}

下一步 #

掌握了高级用法后,可以继续学习「批量处理」一节,了解如何高效处理大量音频文件!

最后更新:2026-04-05