图像分类实战 #

概述 #

本节使用 JAX 实现一个卷积神经网络(CNN)进行图像分类。

数据准备 #

模拟数据 #

python
import jax
import jax.numpy as jnp

def generate_image_data(key, n_samples=1000, img_size=28, n_classes=10):
    """Create a synthetic image-classification dataset.

    Images are standard-normal noise; labels are uniform random integers
    in [0, n_classes).

    Args:
        key: JAX PRNG key.
        n_samples: number of examples to generate.
        img_size: height/width of the square single-channel images.
        n_classes: number of distinct label values.

    Returns:
        (images, labels) with shapes (n_samples, img_size, img_size, 1)
        and (n_samples,).
    """
    img_key, label_key = jax.random.split(key)
    images = jax.random.normal(img_key, (n_samples, img_size, img_size, 1))
    labels = jax.random.randint(label_key, (n_samples,), 0, n_classes)
    return images, labels

# Fixed seeds so the synthetic train/test splits are reproducible across runs.
key = jax.random.PRNGKey(42)
x_train, y_train = generate_image_data(key, n_samples=1000)
# Separate key for the test split so it does not overlap the training data.
x_test, y_test = generate_image_data(jax.random.PRNGKey(1), n_samples=200)

print(f"训练数据: {x_train.shape}, {y_train.shape}")
print(f"测试数据: {x_test.shape}, {y_test.shape}")

模型定义 #

CNN 模型 #

python
import jax
import jax.numpy as jnp
import jax.nn as nn

def init_cnn_params(key):
    """Initialize CNN parameters as a nested dict pytree.

    Architecture: two 3x3 conv layers (1->32->64 channels) followed by two
    dense layers (3136->128->10). Weights are small Gaussian draws scaled
    by 0.01; biases start at zero.
    """
    keys = jax.random.split(key, 6)
    # (layer name, rng key, weight shape). 3136 = 7*7*64, the flattened
    # feature size after two stride-2 pools on a 28x28 input.
    layer_specs = [
        ('conv1', keys[0], (3, 3, 1, 32)),
        ('conv2', keys[1], (3, 3, 32, 64)),
        ('fc1', keys[2], (3136, 128)),
        ('fc2', keys[3], (128, 10)),
    ]
    return {
        name: {
            'w': jax.random.normal(k, shape) * 0.01,
            'b': jnp.zeros(shape[-1]),
        }
        for name, k, shape in layer_specs
    }

def conv2d(x, kernel, stride=1, padding='SAME'):
    """2-D convolution for NHWC inputs with an HWIO kernel."""
    dims = ('NHWC', 'HWIO', 'NHWC')
    strides = (stride, stride)
    out = jax.lax.conv_general_dilated(
        x,
        kernel,
        window_strides=strides,
        padding=padding,
        dimension_numbers=dims,
    )
    return out

def max_pool(x, window_shape=(2, 2), strides=(2, 2)):
    """Spatial max pooling over NHWC input with VALID padding."""
    win_h, win_w = window_shape
    stride_h, stride_w = strides
    # Window and stride span only the spatial axes; batch and channel
    # dimensions are left untouched (size/stride 1).
    return jax.lax.reduce_window(
        x,
        -jnp.inf,
        jax.lax.max,
        window_dimensions=(1, win_h, win_w, 1),
        window_strides=(1, stride_h, stride_w, 1),
        padding='VALID',
    )

def forward(params, x):
    """Run the CNN on a batch of NHWC images and return class logits."""
    # First conv block: conv -> ReLU -> 2x2 max-pool.
    h = nn.relu(conv2d(x, params['conv1']['w']) + params['conv1']['b'])
    h = max_pool(h)

    # Second conv block.
    h = nn.relu(conv2d(h, params['conv2']['w']) + params['conv2']['b'])
    h = max_pool(h)

    # Flatten spatial/channel dims, then apply the dense head.
    h = h.reshape((h.shape[0], -1))
    h = nn.relu(h @ params['fc1']['w'] + params['fc1']['b'])
    return h @ params['fc2']['w'] + params['fc2']['b']

训练 #

损失函数和训练步骤 #

python
import jax

def cross_entropy_loss(params, x, y):
    """Mean softmax cross-entropy between model logits and integer labels."""
    log_probs = nn.log_softmax(forward(params, x))
    # One-hot targets select the log-probability of the true class.
    targets = nn.one_hot(y, 10)
    per_example = jnp.sum(targets * log_probs, axis=-1)
    return -jnp.mean(per_example)

@jax.jit
def train_step(params, x, y, lr=0.001):
    """Perform one SGD update on a mini-batch.

    Args:
        params: pytree of model parameters.
        x: batch of input images.
        y: batch of integer class labels.
        lr: learning rate (a traced value under jit, so it may vary per call).

    Returns:
        (new_params, loss): updated parameter pytree and the batch loss.
    """
    loss, grads = jax.value_and_grad(cross_entropy_loss)(params, x, y)
    # jax.tree_map was deprecated in JAX 0.4.25 and later removed;
    # jax.tree_util.tree_map is the stable, supported spelling.
    new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return new_params, loss

def accuracy(params, x, y):
    """Fraction of examples whose argmax logit matches the label."""
    preds = jnp.argmax(forward(params, x), axis=-1)
    return jnp.mean(preds == y)

训练循环 #

python
def train(params, x_train, y_train, x_test, y_test, epochs=20, batch_size=32, lr=0.001):
    """Mini-batch SGD training loop with per-epoch shuffling.

    Shuffles the training set each epoch, runs `train_step` over
    consecutive batches, and prints loss/accuracy after every epoch.
    Returns the final parameter pytree.
    """
    n_samples = x_train.shape[0]

    for epoch in range(epochs):
        # Reshuffle every epoch; seeding by epoch index keeps runs reproducible.
        perm = jax.random.permutation(jax.random.PRNGKey(epoch), n_samples)
        x_train, y_train = x_train[perm], y_train[perm]

        epoch_loss = 0.0
        batch_count = 0

        for start in range(0, n_samples, batch_size):
            batch_x = x_train[start:start + batch_size]
            batch_y = y_train[start:start + batch_size]
            params, batch_loss = train_step(params, batch_x, batch_y, lr)
            epoch_loss += batch_loss
            batch_count += 1

        avg_loss = epoch_loss / batch_count
        train_acc = accuracy(params, x_train, y_train)
        test_acc = accuracy(params, x_test, y_test)

        print(f"Epoch {epoch}: loss={avg_loss:.4f}, train_acc={train_acc:.4f}, test_acc={test_acc:.4f}")

    return params

# Initialize the model with a fixed seed, then run the full training loop
# on the synthetic data prepared above.
key = jax.random.PRNGKey(0)
params = init_cnn_params(key)

params = train(params, x_train, y_train, x_test, y_test, epochs=20, batch_size=32, lr=0.001)

下一步 #

现在你已经完成了图像分类实战,接下来学习 文本生成,构建序列模型!

最后更新:2026-04-04