# 多分类问题
## 案例概述
本案例使用经典的鸢尾花（Iris）数据集，演示如何用 LightGBM 完成多分类任务。该数据集共 150 个样本、4 个数值特征，分为 3 个类别：
python
import lightgbm as lgb
import numpy as np
import pandas as pd
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.metrics import (
accuracy_score, classification_report,
confusion_matrix
)
import matplotlib.pyplot as plt
import seaborn as sns
# Load the bundled iris dataset and wrap the feature matrix in a DataFrame
# so that every column keeps its human-readable feature name.
data = load_iris()
y = data.target
X = pd.DataFrame(data=data.data, columns=data.feature_names)

# Quick sanity checks before training: size, number of classes,
# per-class sample counts and the class names.
class_ids = np.unique(y)
print(f"数据形状: {X.shape}")
print(f"类别数量: {class_ids.size}")
print(f"类别分布: {np.bincount(y)}")
print(f"类别名称: {data.target_names}")
## 模型训练
python
# Hold out a stratified test set that is used ONLY for the final evaluation.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

# Early stopping selects the number of boosting rounds, so it needs its own
# validation split. Monitoring the test set here would leak it into model
# selection and make the reported test accuracy optimistically biased.
X_tr, X_val, y_tr, y_val = train_test_split(
    X_train, y_train, test_size=0.2, random_state=42, stratify=y_train
)

train_data = lgb.Dataset(X_tr, label=y_tr)
# `reference=train_data` makes the validation set reuse the training bins.
valid_data = lgb.Dataset(X_val, label=y_val, reference=train_data)

params = {
    'objective': 'multiclass',
    # LightGBM expects multiclass labels in [0, num_class). Derive the count
    # from the labels (0..K-1 for load_iris) instead of hard-coding 3.
    'num_class': len(np.unique(y_train)),
    'metric': 'multi_logloss',
    'num_leaves': 31,
    'learning_rate': 0.05,
    'verbose': -1,
}

# Train for up to 500 rounds, log the validation metric every 100 rounds,
# and stop once it has not improved for 50 consecutive rounds.
model = lgb.train(
    params,
    train_data,
    num_boost_round=500,
    valid_sets=[valid_data],
    callbacks=[
        lgb.log_evaluation(100),
        lgb.early_stopping(50),
    ],
)
## 模型评估
python
# Predict class probabilities using the iteration chosen by early stopping.
# The result has one column per class, so argmax over axis 1 gives the label.
y_pred_proba = model.predict(X_test, num_iteration=model.best_iteration)
y_pred = np.argmax(y_pred_proba, axis=1)

# Fix the label order explicitly, so the report rows and the confusion-matrix
# axes always line up with data.target_names, even if some class is missing
# from y_test or y_pred.
class_ids = np.arange(len(data.target_names))

print(f"\n准确率: {accuracy_score(y_test, y_pred):.4f}")
print("\n分类报告:")
print(classification_report(y_test, y_pred, labels=class_ids,
                            target_names=data.target_names))

cm = confusion_matrix(y_test, y_pred, labels=class_ids)

# Matplotlib's default sans-serif font (DejaVu Sans) has no CJK glyphs, so the
# Chinese labels and title below would render as empty boxes. Put common CJK
# fonts in front of the existing list; the first one installed is used.
plt.rcParams['font.sans-serif'] = [
    'SimHei', 'Microsoft YaHei', 'PingFang SC',
    'Noto Sans CJK SC', 'WenQuanYi Micro Hei',
] + list(plt.rcParams['font.sans-serif'])

plt.figure(figsize=(8, 6))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=data.target_names, yticklabels=data.target_names)
plt.xlabel('预测值')
plt.ylabel('真实值')
plt.title('混淆矩阵')
plt.tight_layout()
plt.show()
## 下一步
现在你已经完成了多分类实战。接下来请学习「回归问题」，了解如何用 LightGBM 处理回归任务！
最后更新：2026-04-04