Chat Model in Practice #

Project Overview #

text
┌─────────────────────────────────────────────────────────────┐
│                       Project Goals                         │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  Task: customer-service chat assistant                      │
│  ├── Input: user question                                   │
│  ├── Output: professional answer                            │
│  └── Goal: accurate, professional, friendly                 │
│                                                             │
│  Tech stack:                                                │
│  ├── Model: Qwen2-7B                                        │
│  ├── Method: LoRA                                           │
│  ├── Data: instruction format                               │
│  └── Deployment: vLLM                                       │
│                                                             │
└─────────────────────────────────────────────────────────────┘

Project Layout #

text
chat-assistant/
├── data/
│   ├── raw/                  # raw conversation data
│   ├── processed/            # processed data
│   └── conversations.json    # training data
├── src/
│   ├── data_preparation.py   # data preparation
│   ├── train.py              # training script
│   ├── inference.py          # inference script
│   └── chat.py               # interactive chat
├── configs/
│   └── config.yaml           # configuration file
├── models/
│   └── lora/                 # LoRA weights
├── requirements.txt
└── README.md
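
requirements.txt should cover everything the scripts below import; a minimal sketch (package names only, pin the versions you have actually tested):

text
transformers
datasets
peft
torch
pyyaml
vllm
fastapi
uvicorn
pydantic
tensorboard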

Data Preparation #

Data Format #

json
{
    "conversations": [
        {
            "role": "system",
            "content": "你是一个专业的客服助手,请用友好、专业的语气回答用户问题。"
        },
        {
            "role": "user",
            "content": "我的订单什么时候能到?"
        },
        {
            "role": "assistant",
            "content": "您好!感谢您的咨询。订单配送时间通常为下单后 2-3 个工作日。您可以在订单详情页面查看实时物流信息。如有其他问题,请随时联系我们。"
        }
    ]
}
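
Multi-turn dialogues use the same schema: the conversations list simply alternates user and assistant turns. An illustrative record (the content below is made up for the example):

json
{
    "conversations": [
        {"role": "system", "content": "你是一个专业的客服助手,请用友好、专业的语气回答用户问题。"},
        {"role": "user", "content": "我的订单什么时候能到?"},
        {"role": "assistant", "content": "您好!订单配送时间通常为下单后 2-3 个工作日。"},
        {"role": "user", "content": "可以加急吗?"},
        {"role": "assistant", "content": "您好!部分商品支持加急配送,具体以订单页面显示为准。"}
    ]
}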

Data Preparation Script #

python
import json
from datasets import Dataset
from transformers import AutoTokenizer

def load_conversations(file_path):
    with open(file_path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    return data

def format_conversation(conversation, tokenizer):
    messages = conversation['conversations']
    
    # Custom chat markers; reproduce exactly the same format at inference
    # time, otherwise generation quality degrades.
    formatted_text = ""
    for msg in messages:
        role = msg['role']
        content = msg['content']
        
        if role == "system":
            formatted_text += f"<|system|>\n{content}\n"
        elif role == "user":
            formatted_text += f"<|user|>\n{content}\n"
        elif role == "assistant":
            formatted_text += f"<|assistant|>\n{content}\n"
    
    formatted_text += tokenizer.eos_token
    
    return formatted_text

def preprocess_dataset(data, tokenizer, max_length=2048):
    formatted_texts = []
    
    for conversation in data:
        text = format_conversation(conversation, tokenizer)
        formatted_texts.append(text)
    
    def tokenize_function(examples):
        model_inputs = tokenizer(
            examples['text'],
            max_length=max_length,
            truncation=True,
            padding='max_length'
        )
        
        # Mask padding positions so they are ignored by the loss. Note that
        # with pad_token set to eos_token (as in the training setup below),
        # this also masks the final EOS token.
        model_inputs['labels'] = [
            [tok if tok != tokenizer.pad_token_id else -100 for tok in ids]
            for ids in model_inputs['input_ids']
        ]
        return model_inputs
    
    dataset = Dataset.from_dict({'text': formatted_texts})
    tokenized_dataset = dataset.map(
        tokenize_function,
        batched=True,
        remove_columns=['text']
    )
    
    return tokenized_dataset
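
A sketch of how these pieces chain together (model name, file path, and pad-token handling mirror the training configuration below):

python
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-7B", trust_remote_code=True)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

data = load_conversations("data/conversations.json")
dataset = preprocess_dataset(data, tokenizer, max_length=2048)
print(dataset)  # columns: input_ids, attention_mask, labels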

Data Augmentation #

python
import copy
import random

def augment_conversation(conversation):
    system_variations = [
        "你是一个专业的客服助手,请用友好、专业的语气回答用户问题。",
        "作为客服助手,我会用专业、友好的态度为您服务。",
        "您好,我是客服助手,很高兴为您服务。"
    ]
    
    # Deep-copy so the original record is not mutated through shared nested
    # dicts (a shallow copy() would alias the 'conversations' list).
    augmented = copy.deepcopy(conversation)
    if augmented['conversations'][0]['role'] == 'system':
        augmented['conversations'][0]['content'] = random.choice(system_variations)
    
    return augmented

def create_variations(data, n_variations=2):
    augmented_data = []
    
    for conversation in data:
        augmented_data.append(conversation)
        
        for _ in range(n_variations):
            augmented = augment_conversation(conversation)
            augmented_data.append(augmented)
    
    return augmented_data
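
A quick sanity check on dataset size (illustrative): with n_variations=2, each record gains two augmented copies, so the dataset triples.

python
data = load_conversations("data/conversations.json")
augmented = create_variations(data, n_variations=2)
print(len(data), "->", len(augmented))  # N -> 3 * N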

Model Training #

Training Configuration #

python
import yaml
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer
from peft import LoraConfig, get_peft_model, TaskType
import torch

with open('configs/config.yaml', 'r') as f:
    config = yaml.safe_load(f)

model_name = config['model']['name']
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    device_map="auto",
    trust_remote_code=True
)

lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=config['lora']['r'],
    lora_alpha=config['lora']['alpha'],
    target_modules=config['lora']['target_modules'],
    lora_dropout=config['lora']['dropout'],
    bias="none"
)

model = get_peft_model(model, lora_config)
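
# Sanity check: only a small fraction of parameters should be trainable
model.print_trainable_parameters()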

Training Script #

python
def train():
    data = load_conversations(config['data']['path'])
    
    augmented_data = create_variations(data, n_variations=2)
    
    train_size = int(len(augmented_data) * 0.9)
    train_data = augmented_data[:train_size]
    val_data = augmented_data[train_size:]
    
    train_dataset = preprocess_dataset(train_data, tokenizer, config['model']['max_length'])
    val_dataset = preprocess_dataset(val_data, tokenizer, config['model']['max_length'])
    
    training_args = TrainingArguments(
        output_dir=config['output']['dir'],
        num_train_epochs=config['training']['epochs'],
        per_device_train_batch_size=config['training']['batch_size'],
        per_device_eval_batch_size=config['training']['batch_size'],
        gradient_accumulation_steps=config['training']['gradient_accumulation'],
        learning_rate=config['training']['learning_rate'],
        weight_decay=config['training']['weight_decay'],
        warmup_ratio=config['training']['warmup_ratio'],
        lr_scheduler_type=config['training']['lr_scheduler'],
        logging_dir=config['output']['log_dir'],
        logging_steps=config['training']['logging_steps'],
        eval_strategy="epoch",
        save_strategy="epoch",
        load_best_model_at_end=True,
        fp16=True,
        gradient_checkpointing=True,
        optim="adamw_8bit",
        report_to="tensorboard"
    )
    
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        tokenizer=tokenizer
    )
    
    trainer.train()
    
    trainer.save_model(config['output']['model_dir'])
    tokenizer.save_pretrained(config['output']['model_dir'])
    
    return trainer

if __name__ == "__main__":
    train()
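
Launch training from the project root and watch the curves with TensorBoard (the log directory matches output.log_dir in the config below):

text
python src/train.py
tensorboard --logdir logs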

Configuration File #

yaml
model:
  name: "Qwen/Qwen2-7B"
  max_length: 2048

lora:
  r: 64
  alpha: 128
  target_modules: ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
  dropout: 0.05

training:
  epochs: 3
  batch_size: 4
  gradient_accumulation: 8
  learning_rate: 1e-4
  weight_decay: 0.01
  warmup_ratio: 0.1
  lr_scheduler: "cosine"
  logging_steps: 10

data:
  path: "data/conversations.json"

output:
  dir: "outputs"
  log_dir: "logs"
  model_dir: "models/lora"
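
With batch_size: 4 and gradient_accumulation: 8, the effective batch size is 4 × 8 = 32 sequences per optimizer step per device.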

Model Inference #

Conversation Management #

python
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
import torch
from typing import List, Dict

class ChatAssistant:
    def __init__(self, base_model_path, lora_model_path):
        self.tokenizer = AutoTokenizer.from_pretrained(
            base_model_path,
            trust_remote_code=True
        )
        
        base_model = AutoModelForCausalLM.from_pretrained(
            base_model_path,
            torch_dtype=torch.float16,
            device_map="auto",
            trust_remote_code=True
        )
        
        self.model = PeftModel.from_pretrained(base_model, lora_model_path)
        self.model.eval()
        
        self.system_prompt = "你是一个专业的客服助手,请用友好、专业的语气回答用户问题。"
        self.history: List[Dict[str, str]] = []
    
    def clear_history(self):
        self.history = []
    
    def chat(self, user_input: str, max_new_tokens: int = 512) -> str:
        self.history.append({"role": "user", "content": user_input})
        
        messages = [{"role": "system", "content": self.system_prompt}]
        messages.extend(self.history)
        
        # NOTE: apply_chat_template uses the tokenizer's built-in template
        # (ChatML for Qwen2). If the adapter was trained on the custom
        # <|system|>/<|user|> markers from data preparation, build the
        # prompt with those same markers instead.
        text = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )
        
        inputs = self.tokenizer(text, return_tensors="pt").to(self.model.device)
        
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                temperature=0.7,
                top_p=0.9,
                top_k=50,
                do_sample=True,
                repetition_penalty=1.1,
                pad_token_id=self.tokenizer.pad_token_id,
                eos_token_id=self.tokenizer.eos_token_id
            )
        
        response = self.tokenizer.decode(
            outputs[0][inputs['input_ids'].shape[1]:],
            skip_special_tokens=True
        )
        
        self.history.append({"role": "assistant", "content": response})
        
        return response
    
    def chat_with_context(self, user_input: str, context: str, max_new_tokens: int = 512) -> str:
        enhanced_input = f"背景信息:{context}\n\n用户问题:{user_input}"
        return self.chat(enhanced_input, max_new_tokens)

assistant = ChatAssistant(
    base_model_path="Qwen/Qwen2-7B",
    lora_model_path="models/lora"
)

response = assistant.chat("我的订单什么时候能到?")
print(f"Assistant: {response}")

Interactive Chat #

python
def interactive_chat():
    assistant = ChatAssistant(
        base_model_path="Qwen/Qwen2-7B",
        lora_model_path="models/lora"
    )
    
    print("Customer service assistant ready. Type 'quit' to exit, 'clear' to reset the history.")
    print("-" * 50)
    
    while True:
        user_input = input("User: ").strip()
        
        if user_input.lower() == 'quit':
            print("Goodbye!")
            break
        
        if user_input.lower() == 'clear':
            assistant.clear_history()
            print("Chat history cleared")
            continue
        
        if not user_input:
            continue
        
        response = assistant.chat(user_input)
        print(f"Assistant: {response}")
        print("-" * 50)

if __name__ == "__main__":
    interactive_chat()

Model Evaluation #

Automatic Evaluation #

python
from typing import List, Dict
import json

def evaluate_responses(model, test_cases: List[Dict]):
    results = []
    
    for case in test_cases:
        user_input = case['user_input']
        expected_keywords = case.get('expected_keywords', [])
        expected_style = case.get('expected_style', 'professional')
        
        model.clear_history()  # score each case independently of earlier turns
        response = model.chat(user_input)
        
        keyword_score = (
            sum(1 for kw in expected_keywords if kw in response) / len(expected_keywords)
            if expected_keywords else 1.0
        )
        
        # Crude style signals: Chinese politeness markers expected in replies
        style_indicators = {
            'professional': ['您好', '感谢', '请'],
            'friendly': ['很高兴', '为您', '帮助'],
            'formal': ['尊敬的', '谨此', '特此']
        }
        
        indicators = style_indicators.get(expected_style, [])
        style_score = (
            sum(1 for ind in indicators if ind in response) / len(indicators)
            if indicators else 1.0
        )
        
        results.append({
            'user_input': user_input,
            'response': response,
            'keyword_score': keyword_score,
            'style_score': style_score,
            'overall_score': (keyword_score + style_score) / 2
        })
    
    avg_score = sum(r['overall_score'] for r in results) / len(results)
    
    print(f"Average score: {avg_score:.2f}")
    print(f"Keyword match rate: {sum(r['keyword_score'] for r in results) / len(results):.2f}")
    print(f"Style compliance rate: {sum(r['style_score'] for r in results) / len(results):.2f}")
    
    return results

test_cases = [
    {
        "user_input": "我的订单什么时候能到?",
        "expected_keywords": ["订单", "配送", "工作日"],
        "expected_style": "professional"
    },
    {
        "user_input": "产品有质量问题怎么办?",
        "expected_keywords": ["退换", "售后", "质量"],
        "expected_style": "professional"
    }
]

results = evaluate_responses(assistant, test_cases)

Human Evaluation #

python
def manual_evaluation(responses: List[Dict]):
    print("Human evaluation guide:")
    print("1. Accuracy: does the answer resolve the question correctly?")
    print("2. Professionalism: is the tone professional and friendly?")
    print("3. Completeness: is the answer complete?")
    print("4. Safety: is there any inappropriate content?")
    print("-" * 50)
    
    for i, item in enumerate(responses):
        print(f"\nCase {i+1}:")
        print(f"User: {item['user_input']}")
        print(f"Assistant: {item['response']}")
        
        scores = {}
        for dimension in ['accuracy', 'professionalism', 'completeness', 'safety']:
            while True:
                try:
                    score = int(input(f"{dimension} (1-5): "))
                    if 1 <= score <= 5:
                        scores[dimension] = score
                        break
                except ValueError:
                    pass
        
        item['manual_scores'] = scores
    
    return responses

Model Deployment #

vLLM Deployment #
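
vLLM serves plain model weights, so the LoRA adapter trained above has to be merged into the base model first. A minimal sketch that produces the models/merged directory loaded below (peft's merge_and_unload folds the adapter into the base weights):

python
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
import torch

base = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2-7B",
    torch_dtype=torch.float16,
    trust_remote_code=True
)
model = PeftModel.from_pretrained(base, "models/lora")
merged = model.merge_and_unload()  # bake LoRA deltas into the base weights
merged.save_pretrained("models/merged")

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-7B", trust_remote_code=True)
tokenizer.save_pretrained("models/merged")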

python
from vllm import LLM, SamplingParams

class VLLMChatAssistant:
    def __init__(self, model_path):
        self.llm = LLM(model=model_path, trust_remote_code=True)
        self.system_prompt = "你是一个专业的客服助手,请用友好、专业的语气回答用户问题。"
    
    def generate(self, user_input: str, max_tokens: int = 512) -> str:
        # Same custom markers as in training; stop strings keep the model
        # from running on into a hallucinated next turn.
        prompt = f"<|system|>\n{self.system_prompt}\n<|user|>\n{user_input}\n<|assistant|>\n"
        
        sampling_params = SamplingParams(
            temperature=0.7,
            top_p=0.9,
            max_tokens=max_tokens,
            stop=["<|user|>", "<|system|>"]
        )
        
        outputs = self.llm.generate([prompt], sampling_params)
        
        return outputs[0].outputs[0].text

assistant = VLLMChatAssistant("models/merged")
response = assistant.generate("我的订单什么时候能到?")
print(response)

API Service #

python
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import List, Optional
import uvicorn

app = FastAPI(title="Customer Service Assistant API")

class ChatRequest(BaseModel):
    message: str
    history: Optional[List[dict]] = []
    context: Optional[str] = None

class ChatResponse(BaseModel):
    response: str
    history: List[dict]

assistant = None

@app.on_event("startup")
async def startup():
    global assistant
    assistant = ChatAssistant(
        base_model_path="Qwen/Qwen2-7B",
        lora_model_path="models/lora"
    )

@app.post("/chat", response_model=ChatResponse)
async def chat(request: ChatRequest):
    if assistant is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    
    # NOTE: one global assistant is shared by all requests, so this
    # per-request history swap is not safe under concurrent traffic.
    assistant.history = request.history.copy()
    
    if request.context:
        response = assistant.chat_with_context(request.message, request.context)
    else:
        response = assistant.chat(request.message)
    
    return ChatResponse(
        response=response,
        history=assistant.history
    )

@app.post("/clear")
async def clear_history():
    if assistant is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    
    assistant.clear_history()
    return {"status": "success"}

@app.get("/health")
async def health_check():
    return {"status": "healthy"}

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)
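
A quick smoke test against the running service (default host/port as configured above):

text
curl -X POST http://localhost:8000/chat \
  -H "Content-Type: application/json" \
  -d '{"message": "我的订单什么时候能到?", "history": []}'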

Project Summary #

text
Project outcomes:
├── Model: Qwen2-7B + LoRA
├── Dialogue quality: professional, friendly
├── Response latency: < 1 s
└── Deployment: vLLM + FastAPI

Key techniques:
├── Conversation data format
├── Multi-turn dialogue management
├── Context understanding
└── Style control

Directions for improvement:
├── RAG augmentation
├── Multimodal support
├── Sentiment analysis
└── Knowledge graphs

Next Steps #

Next up: Production Deployment, covering best practices for running models in production!

Last updated: 2026-04-05