# 多分类问题

## 案例概述

本案例使用鸢尾花数据集演示 LightGBM 多分类任务:

python
import lightgbm as lgb
import numpy as np
import pandas as pd
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    accuracy_score, classification_report,
    confusion_matrix
)
import matplotlib.pyplot as plt
import seaborn as sns

# Load the iris dataset; keep the features in a DataFrame so the column
# names travel with the data, and the integer class labels in `y`.
data = load_iris()
y = data.target
X = pd.DataFrame(data=data.data, columns=data.feature_names)

# Print a short summary: feature-matrix shape, number of distinct labels,
# per-label sample counts, and the label names.
n_classes = np.unique(y).size
class_counts = np.bincount(y)
print(f"数据形状: {X.shape}")
print(f"类别数量: {n_classes}")
print(f"类别分布: {class_counts}")
print(f"类别名称: {data.target_names}")

## 模型训练

python
# Hold out 20% of the samples as the final test set. Stratifying keeps the
# class proportions identical in both parts.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

# Carve a validation split out of the training portion for early stopping.
# Early-stopping against the test set would pick the number of boosting
# rounds using the very data the final metrics are reported on, which leaks
# test information into model selection and inflates the reported score.
X_fit, X_valid, y_fit, y_valid = train_test_split(
    X_train, y_train, test_size=0.25, random_state=42, stratify=y_train
)

train_data = lgb.Dataset(X_fit, label=y_fit)
# `reference=train_data` makes the validation set reuse the training bins.
valid_data = lgb.Dataset(X_valid, label=y_valid, reference=train_data)

params = {
    'objective': 'multiclass',
    # Derive the class count from the labels instead of hard-coding it, so
    # the snippet keeps working if the dataset is swapped out.
    'num_class': len(np.unique(y)),
    'metric': 'multi_logloss',
    'num_leaves': 31,
    'learning_rate': 0.05,
    'verbose': -1
}

model = lgb.train(
    params,
    train_data,
    num_boost_round=500,
    valid_sets=[valid_data],
    callbacks=[
        # Log the validation metric every 100 rounds.
        lgb.log_evaluation(100),
        # Stop once the validation metric has not improved for 50 rounds;
        # the best round is then available as `model.best_iteration`.
        lgb.early_stopping(50)
    ]
)

## 模型评估

python
# Predict per-class scores with the best early-stopped iteration, then take
# the highest-scoring column of each row as the predicted label.
y_pred_proba = model.predict(X_test, num_iteration=model.best_iteration)
y_pred = np.asarray(y_pred_proba).argmax(axis=1)

class_names = data.target_names

# Overall accuracy on the held-out samples.
accuracy = accuracy_score(y_test, y_pred)
print(f"\n准确率: {accuracy:.4f}")

# Per-class precision / recall / F1.
report = classification_report(y_test, y_pred, target_names=class_names)
print("\n分类报告:")
print(report)

# Confusion matrix drawn as an annotated heatmap; rows are the true labels
# and columns are the predicted labels.
cm = confusion_matrix(y_test, y_pred)
plt.figure(figsize=(8, 6))
ax = sns.heatmap(
    cm,
    annot=True,
    fmt='d',
    cmap='Blues',
    xticklabels=class_names,
    yticklabels=class_names,
)
ax.set_xlabel('预测值')
ax.set_ylabel('真实值')
ax.set_title('混淆矩阵')
plt.tight_layout()
plt.show()

## 下一步

现在你已经完成了多分类实战,接下来学习「回归问题」,了解如何处理回归任务!

最后更新:2026-04-04