线性回归实战 #

概述 #

本节使用 JAX 从零实现一个完整的线性回归项目,涵盖数据生成、模型定义、训练和评估。

数据准备 #

生成模拟数据 #

python
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

def generate_data(key, n_samples=100, noise_std=0.1):
    """Generate a synthetic linear-regression dataset.

    Targets follow y = x @ true_w + true_b + Gaussian noise.

    Args:
        key: JAX PRNG key; split internally for features and noise.
        n_samples: number of samples to draw.
        noise_std: standard deviation of the additive noise.

    Returns:
        Tuple (x, y, true_w, true_b) where x has shape (n_samples, 2)
        and y has shape (n_samples,).
    """
    x_key, noise_key = jax.random.split(key)

    # Ground-truth parameters the model should recover.
    true_w = jnp.array([2.0, -3.0])
    true_b = 1.0

    x = jax.random.uniform(x_key, (n_samples, 2), minval=-1, maxval=1)
    eps = noise_std * jax.random.normal(noise_key, (n_samples,))
    y = x @ true_w + true_b + eps

    return x, y, true_w, true_b

# Fixed seed so the generated dataset is reproducible across runs.
key = jax.random.PRNGKey(42)
x_train, y_train, true_w, true_b = generate_data(key, n_samples=200)

# Report dataset dimensions and the ground-truth parameters for reference.
print(f"训练数据形状: x={x_train.shape}, y={y_train.shape}")
print(f"真实参数: w={true_w}, b={true_b}")

数据可视化 #

python
import matplotlib.pyplot as plt

# Side-by-side scatter plots: each input feature against the target,
# to eyeball the (noisy) linear relationship before training.
plt.figure(figsize=(10, 4))

# Left panel: first feature vs. y (true slope is 2.0).
plt.subplot(1, 2, 1)
plt.scatter(x_train[:, 0], y_train, alpha=0.5)
plt.xlabel('x1')
plt.ylabel('y')
plt.title('x1 vs y')

# Right panel: second feature vs. y (true slope is -3.0).
plt.subplot(1, 2, 2)
plt.scatter(x_train[:, 1], y_train, alpha=0.5)
plt.xlabel('x2')
plt.ylabel('y')
plt.title('x2 vs y')

# Save to disk instead of showing interactively.
plt.tight_layout()
plt.savefig('data_visualization.png')
plt.close()

模型定义 #

线性模型 #

python
import jax
import jax.numpy as jnp

def init_params(key, input_dim):
    """Draw small random initial weights and bias for the linear model.

    Args:
        key: JAX PRNG key; split internally for weight and bias draws.
        input_dim: number of input features.

    Returns:
        Dict with 'w' of shape (input_dim,) and scalar 'b'.
    """
    w_key, b_key = jax.random.split(key)
    return {
        'w': jax.random.normal(w_key, (input_dim,)) * 0.1,
        'b': jax.random.normal(b_key, ()) * 0.1,
    }

def predict(params, x):
    """Affine prediction x @ w + b for the linear model.

    Args:
        params: dict with weight vector 'w' and scalar bias 'b'.
        x: input batch of shape (n, d).

    Returns:
        Predictions of shape (n,).
    """
    w, b = params['w'], params['b']
    return x @ w + b

def mse_loss(params, x, y):
    """Mean squared error of the linear model's predictions against y."""
    residual = predict(params, x) - y
    return jnp.mean(residual ** 2)

# Initialize parameters with a fixed seed and show the starting point.
key = jax.random.PRNGKey(0)
params = init_params(key, input_dim=2)
print(f"初始参数: {params}")

训练 #

训练步骤 #

python
import jax

@jax.jit
def train_step(params, x, y, lr=0.1):
    """One jitted full-batch gradient-descent update.

    Args:
        params: pytree {'w', 'b'} of model parameters.
        x: input features, shape (n, d).
        y: targets, shape (n,).
        lr: learning rate (traced by jit, so different values do not
            trigger recompilation).

    Returns:
        Tuple (new_params, loss) where loss is the MSE *before* the update.
    """
    loss, grads = jax.value_and_grad(mse_loss)(params, x, y)
    # jax.tree_map is deprecated (removed in recent JAX releases);
    # jax.tree_util.tree_map is the supported equivalent.
    new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return new_params, loss

def train(params, x, y, epochs=100, lr=0.1):
    """Run full-batch gradient descent for `epochs` steps.

    Logs the loss every 20 epochs and returns the final parameters
    together with the per-epoch loss history.
    """
    history = []

    for step in range(epochs):
        params, current_loss = train_step(params, x, y, lr)
        history.append(current_loss)

        # Periodic progress report (epochs 0, 20, 40, ...).
        if step % 20 == 0:
            print(f"Epoch {step}: loss={current_loss:.6f}")

    return params, history

# Fresh initialization, then 100 epochs of full-batch gradient descent.
key = jax.random.PRNGKey(0)
params = init_params(key, input_dim=2)

params, losses = train(params, x_train, y_train, epochs=100, lr=0.1)

# Compare the learned parameters with the ground truth used to
# generate the data.
print(f"\n训练后参数: w={params['w']}, b={params['b']}")
print(f"真实参数: w={true_w}, b={true_b}")

训练曲线 #

python
import matplotlib.pyplot as plt

# Plot the per-epoch MSE recorded during training and save it to disk.
plt.figure(figsize=(8, 5))
plt.plot(losses)
plt.xlabel('Epoch')
plt.ylabel('MSE Loss')
plt.title('Training Loss')
plt.grid(True)
plt.savefig('training_loss.png')
plt.close()

评估 #

预测与可视化 #

python
import jax.numpy as jnp

# Predictions of the trained model on the training set.
y_pred = predict(params, x_train)

# Standard regression metrics.
mse = jnp.mean((y_pred - y_train) ** 2)
mae = jnp.mean(jnp.abs(y_pred - y_train))
# Coefficient of determination: 1 - SS_res / SS_tot.
r2 = 1 - jnp.sum((y_train - y_pred) ** 2) / jnp.sum((y_train - jnp.mean(y_train)) ** 2)

print(f"MSE: {mse:.6f}")
print(f"MAE: {mae:.6f}")
print(f"R²: {r2:.6f}")

# Scatter of predictions vs. ground truth; the dashed red line is the
# ideal y_pred == y_true diagonal.
plt.figure(figsize=(8, 6))
plt.scatter(y_train, y_pred, alpha=0.5)
plt.plot([y_train.min(), y_train.max()], [y_train.min(), y_train.max()], 'r--')
plt.xlabel('True Values')
plt.ylabel('Predictions')
plt.title('True vs Predicted')
plt.grid(True)
plt.savefig('predictions.png')
plt.close()

完整代码 #

python
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

def main():
    """End-to-end linear regression demo: data, init, train, evaluate."""
    root_key = jax.random.PRNGKey(42)

    # Split off independent keys for data generation and initialization.
    root_key, data_key = jax.random.split(root_key)
    x_train, y_train, true_w, true_b = generate_data(data_key, n_samples=200)

    root_key, init_key = jax.random.split(root_key)
    params = init_params(init_key, input_dim=2)

    params, losses = train(params, x_train, y_train, epochs=100, lr=0.1)

    print(f"\n最终参数: w={params['w']}, b={params['b']}")
    print(f"真实参数: w={true_w}, b={true_b}")

    # Final fit quality on the training data.
    y_pred = predict(params, x_train)
    mse = jnp.mean(jnp.square(y_pred - y_train))
    print(f"最终 MSE: {mse:.6f}")


if __name__ == "__main__":
    main()

下一步 #

现在你已经完成了线性回归实战,接下来学习 图像分类,构建更复杂的神经网络!

最后更新:2026-04-04