排序问题 #
案例概述 #
本案例演示如何使用 LightGBM 完成排序任务(Learning to Rank):先生成按查询分组的模拟数据,再训练 LambdaRank 模型,最后用 NDCG 评估排序效果。
python
import lightgbm as lgb
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
# Seed NumPy's global RNG so the synthetic dataset is reproducible.
np.random.seed(42)

n_queries = 100
n_docs_per_query = 10
n_features = 20

# Build the dataset one query at a time. Each query contributes a
# (n_docs_per_query, n_features) feature matrix plus one integer relevance
# label in [0, 4] per document.
# NOTE: the labels are drawn independently of the features, so there is no
# learnable signal here; the data only exercises the ranking API.
X = []
y = []
query_sizes = []
for q in range(n_queries):
    query_features = np.random.randn(n_docs_per_query, n_features)
    relevance = np.random.randint(0, 5, n_docs_per_query)
    X.append(query_features)
    y.append(relevance)
    query_sizes.append(n_docs_per_query)

# Stack the per-query pieces. Rows of the same query stay contiguous and
# appear in the same order as the entries of query_sizes.
X = np.vstack(X)
y = np.hstack(y)

print(f"数据形状: {X.shape}")
print(f"查询数量: {n_queries}")
print(f"每个查询文档数: {n_docs_per_query}")
print(f"相关性分数范围: {y.min()} - {y.max()}")
模型训练 #
python
# Split at the query level, not the row level. X and y have one row per
# document while query_sizes has one entry per query, so passing all three to
# train_test_split raises ValueError (inconsistent sample counts), and a
# row-level shuffle would also scatter the documents of a query across splits.
query_ids = np.arange(len(query_sizes))
train_queries, test_queries = train_test_split(
    query_ids, test_size=0.2, random_state=42
)
# Keep the selected queries in their original order.
train_queries = np.sort(train_queries)
test_queries = np.sort(test_queries)

# Rows [query_offsets[i], query_offsets[i + 1]) hold the documents of query i.
query_offsets = np.concatenate(([0], np.cumsum(query_sizes)))


def rows_of(queries):
    """Return the row indices of every document belonging to ``queries``."""
    return np.concatenate(
        [np.arange(query_offsets[i], query_offsets[i + 1]) for i in queries]
    )


train_rows = rows_of(train_queries)
test_rows = rows_of(test_queries)
X_train, y_train = X[train_rows], y[train_rows]
X_test, y_test = X[test_rows], y[test_rows]
# Group sizes listed in the same order as the rows gathered above.
query_train = [query_sizes[i] for i in train_queries]
query_test = [query_sizes[i] for i in test_queries]

train_data = lgb.Dataset(X_train, label=y_train, group=query_train)
valid_data = lgb.Dataset(X_test, label=y_test, group=query_test, reference=train_data)

params = {
    'objective': 'lambdarank',
    'metric': 'ndcg',
    'num_leaves': 31,
    'learning_rate': 0.05,
    'verbose': -1
}

# Train for up to 500 rounds, logging every 100 rounds and stopping early
# after 50 rounds without improvement on the validation set.
model = lgb.train(
    params,
    train_data,
    num_boost_round=500,
    valid_sets=[valid_data],
    callbacks=[
        lgb.log_evaluation(100),
        lgb.early_stopping(50)
    ]
)

print(f"\n最佳 NDCG: {model.best_score['valid_0']['ndcg@1']:.4f}")
预测和评估 #
python
# Score every test document; the scores are used below to order the documents
# within each query.
y_pred = model.predict(X_test)
def calculate_ndcg_at_k(y_true, y_pred, query_sizes, k=10):
    """Compute the mean NDCG@k over a set of queries.

    Documents of one query must occupy consecutive positions in ``y_true`` and
    ``y_pred``; ``query_sizes`` gives the number of documents of each query in
    that order. Gains are ``2**relevance - 1`` and the discount at rank ``r``
    (1-based) is ``log2(r + 1)``.

    Args:
        y_true: Relevance labels, one per document.
        y_pred: Predicted scores, one per document (higher ranks first).
        query_sizes: Iterable with the document count of every query.
        k: Rank cutoff. Queries with fewer than ``k`` documents are evaluated
            on all of their documents.

    Returns:
        The average NDCG@k over all queries whose ideal DCG is positive.
        Queries whose ideal DCG is zero are skipped; if no query remains,
        NaN is returned.
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)

    ndcg_scores = []
    start_idx = 0
    for size in query_sizes:
        end_idx = start_idx + size
        true = y_true[start_idx:end_idx]
        pred = y_pred[start_idx:end_idx]
        start_idx = end_idx

        # Truncate at the query length: a fixed-length-k discount vector cannot
        # broadcast against a shorter gain vector.
        top = min(k, len(true))
        discounts = np.log2(np.arange(2, top + 2))

        # Labels reordered by descending predicted score.
        sorted_indices = np.argsort(pred)[::-1]
        sorted_true = true[sorted_indices]
        dcg = np.sum((2.0 ** sorted_true[:top] - 1) / discounts)

        # Best achievable DCG: labels in descending order.
        ideal_sorted = np.sort(true)[::-1]
        idcg = np.sum((2.0 ** ideal_sorted[:top] - 1) / discounts)

        if idcg > 0:
            ndcg_scores.append(dcg / idcg)

    # np.mean([]) would emit a RuntimeWarning; return NaN explicitly instead.
    if not ndcg_scores:
        return float("nan")
    return np.mean(ndcg_scores)
# Evaluate the test-set ranking at three cutoffs and report each score.
ndcg_1, ndcg_5, ndcg_10 = (
    calculate_ndcg_at_k(y_test, y_pred, query_test, k=cutoff)
    for cutoff in (1, 5, 10)
)
print(
    f"\nNDCG@1: {ndcg_1:.4f}",
    f"NDCG@5: {ndcg_5:.4f}",
    f"NDCG@10: {ndcg_10:.4f}",
    sep="\n",
)
总结 #
恭喜你完成了 LightGBM 完全指南的学习!你已经掌握了:
- LightGBM 的基本概念和原理
- 数据处理和特征工程
- 参数配置和调优技巧
- 分类、回归、排序任务实战
- 分布式训练和 GPU 加速
继续实践,探索更多高级应用!
最后更新:2026-04-04