图像分类实战 #
概述 #
本节使用 JAX 实现一个卷积神经网络(CNN)进行图像分类。
数据准备 #
模拟数据 #
python
import jax
import jax.numpy as jnp
def generate_image_data(key, n_samples=1000, img_size=28, n_classes=10):
    """Create a synthetic image-classification dataset.

    Args:
        key: JAX PRNG key used for sampling.
        n_samples: number of examples to generate.
        img_size: height/width of each square, single-channel image.
        n_classes: labels are drawn uniformly from [0, n_classes).

    Returns:
        (images, labels): images of shape (n_samples, img_size, img_size, 1)
        drawn from a standard normal, and int labels of shape (n_samples,).
    """
    img_key, label_key = jax.random.split(key)
    image_shape = (n_samples, img_size, img_size, 1)
    images = jax.random.normal(img_key, image_shape)
    labels = jax.random.randint(label_key, (n_samples,), 0, n_classes)
    return images, labels
# Build reproducible synthetic train/test splits (fixed PRNG seeds).
train_key = jax.random.PRNGKey(42)
x_train, y_train = generate_image_data(train_key, n_samples=1000)
test_key = jax.random.PRNGKey(1)
x_test, y_test = generate_image_data(test_key, n_samples=200)
print(f"训练数据: {x_train.shape}, {y_train.shape}")
print(f"测试数据: {x_test.shape}, {y_test.shape}")
模型定义 #
CNN 模型 #
python
import jax
import jax.numpy as jnp
import jax.nn as nn
def init_cnn_params(key, img_size=28, n_classes=10):
    """Initialize parameters for the 2-conv / 2-dense CNN.

    Args:
        key: JAX PRNG key.
        img_size: input image height/width (default 28, matching the data).
        n_classes: width of the final logits layer (default 10).

    Returns:
        Nested dict with 'conv1', 'conv2', 'fc1', 'fc2' entries, each
        holding a weight 'w' and bias 'b'.
    """
    # Original code split the key into 6 pieces but used only 4.
    keys = jax.random.split(key, 4)
    # Two SAME convs + two 2x2 VALID max-pools: spatial size is halved
    # (floored) twice, so fc1 sees (img_size // 2 // 2)**2 * 64 features
    # (3136 for img_size=28 — previously a hard-coded magic constant).
    flat_dim = (img_size // 2 // 2) ** 2 * 64
    params = {
        'conv1': {
            'w': jax.random.normal(keys[0], (3, 3, 1, 32)) * 0.01,
            'b': jnp.zeros(32)
        },
        'conv2': {
            'w': jax.random.normal(keys[1], (3, 3, 32, 64)) * 0.01,
            'b': jnp.zeros(64)
        },
        'fc1': {
            'w': jax.random.normal(keys[2], (flat_dim, 128)) * 0.01,
            'b': jnp.zeros(128)
        },
        'fc2': {
            'w': jax.random.normal(keys[3], (128, n_classes)) * 0.01,
            'b': jnp.zeros(n_classes)
        }
    }
    return params
def conv2d(x, kernel, stride=1, padding='SAME'):
    """2-D convolution of an NHWC batch `x` with an HWIO `kernel`.

    Args:
        x: input of shape (N, H, W, C_in).
        kernel: filters of shape (KH, KW, C_in, C_out).
        stride: spatial stride applied to both H and W.
        padding: 'SAME' (default) or 'VALID', per jax.lax conventions.

    Returns:
        Convolved output in NHWC layout.
    """
    strides = (stride, stride)
    layout = ('NHWC', 'HWIO', 'NHWC')
    return jax.lax.conv_general_dilated(
        x,
        kernel,
        window_strides=strides,
        padding=padding,
        dimension_numbers=layout,
    )
def max_pool(x, window_shape=(2, 2), strides=(2, 2)):
    """Spatial max pooling over an NHWC batch with VALID padding.

    Args:
        x: input of shape (N, H, W, C).
        window_shape: (height, width) of the pooling window.
        strides: (row, col) step of the window.

    Returns:
        Pooled output; batch and channel dims are untouched.
    """
    win_h, win_w = window_shape
    step_h, step_w = strides
    return jax.lax.reduce_window(
        x,
        -jnp.inf,
        jax.lax.max,
        window_dimensions=(1, win_h, win_w, 1),
        window_strides=(1, step_h, step_w, 1),
        padding='VALID',
    )
def forward(params, x):
    """Run the CNN on a batch of NHWC images and return class logits.

    Architecture: two (conv -> bias -> relu -> 2x2 max-pool) stages,
    flatten, dense+relu, then a final dense layer. No softmax is
    applied here — outputs are raw logits.
    """
    for conv_name in ('conv1', 'conv2'):
        layer = params[conv_name]
        x = nn.relu(conv2d(x, layer['w']) + layer['b'])
        x = max_pool(x)
    batch = x.shape[0]
    flat = x.reshape((batch, -1))
    hidden = nn.relu(flat @ params['fc1']['w'] + params['fc1']['b'])
    logits = hidden @ params['fc2']['w'] + params['fc2']['b']
    return logits
训练 #
损失函数和训练步骤 #
python
import jax
def cross_entropy_loss(params, x, y):
    """Mean softmax cross-entropy between model logits and integer labels.

    Args:
        params: CNN parameter pytree accepted by `forward`.
        x: batch of input images, shape (N, H, W, C).
        y: integer class labels, shape (N,).

    Returns:
        Scalar mean negative log-likelihood over the batch.
    """
    logits = forward(params, x)
    log_probs = nn.log_softmax(logits)
    # Infer the class count from the logits instead of hard-coding 10,
    # so the loss stays correct if the model's output width changes.
    one_hot = nn.one_hot(y, logits.shape[-1])
    return -jnp.mean(jnp.sum(one_hot * log_probs, axis=-1))
@jax.jit
def train_step(params, x, y, lr=0.001):
    """One jitted SGD step.

    Args:
        params: current parameter pytree.
        x, y: mini-batch of images and integer labels.
        lr: learning rate (traced by jit, so it may vary per call).

    Returns:
        (new_params, loss): updated parameters and the batch loss.
    """
    loss, grads = jax.value_and_grad(cross_entropy_loss)(params, x, y)
    # jax.tree_map was deprecated and has been removed in current JAX
    # releases; jax.tree_util.tree_map is the stable spelling.
    new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return new_params, loss
def accuracy(params, x, y):
    """Fraction of examples whose argmax logit matches the label `y`."""
    predicted = jnp.argmax(forward(params, x), axis=-1)
    return jnp.mean(predicted == y)
训练循环 #
python
def train(params, x_train, y_train, x_test, y_test, epochs=20, batch_size=32, lr=0.001):
    """Train with mini-batch SGD, printing loss/accuracy every epoch.

    Each epoch shuffles the training set with a deterministic,
    epoch-seeded PRNG key, walks it in batches of `batch_size`, and
    then evaluates accuracy on the full train and test sets.

    Returns:
        The final parameter pytree.
    """
    n_samples = x_train.shape[0]
    for epoch in range(epochs):
        # Deterministic per-epoch shuffle (seeded by the epoch index).
        shuffle_key = jax.random.PRNGKey(epoch)
        order = jax.random.permutation(shuffle_key, n_samples)
        x_train, y_train = x_train[order], y_train[order]

        batch_losses = []
        for start in range(0, n_samples, batch_size):
            stop = start + batch_size
            params, batch_loss = train_step(
                params, x_train[start:stop], y_train[start:stop], lr)
            batch_losses.append(batch_loss)

        avg_loss = sum(batch_losses) / len(batch_losses)
        train_acc = accuracy(params, x_train, y_train)
        test_acc = accuracy(params, x_test, y_test)
        print(f"Epoch {epoch}: loss={avg_loss:.4f}, train_acc={train_acc:.4f}, test_acc={test_acc:.4f}")
    return params
# Initialize the CNN and fit it on the synthetic dataset.
init_key = jax.random.PRNGKey(0)
params = init_cnn_params(init_key)
params = train(params, x_train, y_train, x_test, y_test, epochs=20, batch_size=32, lr=0.001)
下一步 #
现在你已经完成了图像分类实战,接下来学习 文本生成,构建序列模型!
最后更新:2026-04-04