线性回归实战 #
概述 #
本节使用 JAX 从零实现一个完整的线性回归项目,涵盖数据生成、模型定义、训练和评估。
数据准备 #
生成模拟数据 #
python
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
def generate_data(key, n_samples=100, noise_std=0.1):
    """Create a synthetic linear-regression dataset: y = x @ w + b + noise.

    Args:
        key: JAX PRNG key; split internally for features and noise.
        n_samples: number of rows to generate.
        noise_std: standard deviation of the additive Gaussian noise.

    Returns:
        Tuple (x, y, true_w, true_b) where x has shape (n_samples, 2)
        and y has shape (n_samples,).
    """
    feat_key, noise_key = jax.random.split(key)
    true_w = jnp.array([2.0, -3.0])
    true_b = 1.0
    x = jax.random.uniform(feat_key, (n_samples, 2), minval=-1, maxval=1)
    noise = noise_std * jax.random.normal(noise_key, (n_samples,))
    y = x @ true_w + true_b + noise
    return x, y, true_w, true_b


key = jax.random.PRNGKey(42)
x_train, y_train, true_w, true_b = generate_data(key, n_samples=200)
print(f"训练数据形状: x={x_train.shape}, y={y_train.shape}")
print(f"真实参数: w={true_w}, b={true_b}")
数据可视化 #
python
import matplotlib.pyplot as plt
# Scatter each input feature against the target to eyeball the linear trends.
fig, axes = plt.subplots(1, 2, figsize=(10, 4))
axes[0].scatter(x_train[:, 0], y_train, alpha=0.5)
axes[0].set_xlabel('x1')
axes[0].set_ylabel('y')
axes[0].set_title('x1 vs y')
axes[1].scatter(x_train[:, 1], y_train, alpha=0.5)
axes[1].set_xlabel('x2')
axes[1].set_ylabel('y')
axes[1].set_title('x2 vs y')
fig.tight_layout()
fig.savefig('data_visualization.png')
plt.close(fig)
模型定义 #
线性模型 #
python
import jax
import jax.numpy as jnp
def init_params(key, input_dim):
    """Draw small random initial weights and bias for the linear model.

    Returns a dict {'w': (input_dim,) array, 'b': scalar array}.
    """
    w_key, b_key = jax.random.split(key)
    return {
        'w': 0.1 * jax.random.normal(w_key, (input_dim,)),
        'b': 0.1 * jax.random.normal(b_key, ()),
    }


def predict(params, x):
    """Linear forward pass: x @ w + b."""
    return x @ params['w'] + params['b']


def mse_loss(params, x, y):
    """Mean squared error between model predictions and targets."""
    residual = predict(params, x) - y
    return jnp.mean(residual ** 2)


key = jax.random.PRNGKey(0)
params = init_params(key, input_dim=2)
print(f"初始参数: {params}")
训练 #
训练步骤 #
python
import jax
@jax.jit
def train_step(params, x, y, lr=0.1):
    """One full-batch gradient-descent update.

    Args:
        params: dict with 'w' and 'b' arrays.
        x: feature matrix, shape (n_samples, input_dim).
        y: target vector, shape (n_samples,).
        lr: learning rate (step size).

    Returns:
        (new_params, loss): updated parameters and the pre-update loss.
    """
    loss, grads = jax.value_and_grad(mse_loss)(params, x, y)
    # jax.tree_map was deprecated and removed in recent JAX releases;
    # jax.tree_util.tree_map is the stable API across versions.
    new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return new_params, loss


def train(params, x, y, epochs=100, lr=0.1):
    """Run full-batch gradient descent for `epochs` steps.

    Returns:
        (params, losses): final parameters and the per-epoch loss history.
    """
    losses = []
    for epoch in range(epochs):
        params, loss = train_step(params, x, y, lr)
        losses.append(loss)
        if epoch % 20 == 0:
            print(f"Epoch {epoch}: loss={loss:.6f}")
    return params, losses
# Re-initialize the model, fit it on the synthetic training set, then
# compare the learned parameters against the generating ones.
key = jax.random.PRNGKey(0)
params = init_params(key, input_dim=2)
params, losses = train(params, x_train, y_train, epochs=100, lr=0.1)
print(f"\n训练后参数: w={params['w']}, b={params['b']}")
print(f"真实参数: w={true_w}, b={true_b}")
训练曲线 #
python
import matplotlib.pyplot as plt
# Plot the loss trajectory recorded during training.
fig, ax = plt.subplots(figsize=(8, 5))
ax.plot(losses)
ax.set_xlabel('Epoch')
ax.set_ylabel('MSE Loss')
ax.set_title('Training Loss')
ax.grid(True)
fig.savefig('training_loss.png')
plt.close(fig)
评估 #
预测与可视化 #
python
import jax.numpy as jnp
# Regression metrics on the training set.
y_pred = predict(params, x_train)
residual = y_pred - y_train
mse = jnp.mean(residual ** 2)
mae = jnp.mean(jnp.abs(residual))
ss_res = jnp.sum((y_train - y_pred) ** 2)
ss_tot = jnp.sum((y_train - jnp.mean(y_train)) ** 2)
r2 = 1 - ss_res / ss_tot
print(f"MSE: {mse:.6f}")
print(f"MAE: {mae:.6f}")
print(f"R²: {r2:.6f}")

# True-vs-predicted scatter with the ideal y = x reference line.
fig, ax = plt.subplots(figsize=(8, 6))
ax.scatter(y_train, y_pred, alpha=0.5)
lo, hi = y_train.min(), y_train.max()
ax.plot([lo, hi], [lo, hi], 'r--')
ax.set_xlabel('True Values')
ax.set_ylabel('Predictions')
ax.set_title('True vs Predicted')
ax.grid(True)
fig.savefig('predictions.png')
plt.close(fig)
完整代码 #
python
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
def main():
    """End-to-end run: generate data, train the model, report final metrics."""
    root_key = jax.random.PRNGKey(42)

    # Consume fresh subkeys for each random operation; same split order as
    # before, so the PRNG stream (and every result) is unchanged.
    root_key, data_key = jax.random.split(root_key)
    x_train, y_train, true_w, true_b = generate_data(data_key, n_samples=200)

    root_key, init_key = jax.random.split(root_key)
    params = init_params(init_key, input_dim=2)
    params, losses = train(params, x_train, y_train, epochs=100, lr=0.1)

    print(f"\n最终参数: w={params['w']}, b={params['b']}")
    print(f"真实参数: w={true_w}, b={true_b}")

    y_pred = predict(params, x_train)
    mse = jnp.mean((y_pred - y_train) ** 2)
    print(f"最终 MSE: {mse:.6f}")


if __name__ == "__main__":
    main()
下一步 #
现在你已经完成了线性回归实战,接下来学习 图像分类,构建更复杂的神经网络!
最后更新:2026-04-04