构建神经网络 #

概述 #

JAX 本身不提供高级的神经网络 API(`jax.nn` 只包含激活函数、`one_hot` 等基础函数),但我们可以使用 JAX 的基础功能来构建神经网络。本节介绍如何从零开始构建神经网络的各个组件:参数初始化、线性层、激活函数、多层感知机、卷积与池化层、损失函数以及正则化。

参数初始化 #

Xavier/Glorot 初始化 #

python
import jax
import jax.numpy as jnp

def xavier_init(key, in_features, out_features):
    """Sample a weight matrix with Xavier/Glorot scaling.

    Entries are drawn from a standard normal distribution and multiplied by
    ``sqrt(2 / (in_features + out_features))``.

    Args:
        key: JAX PRNG key used for sampling.
        in_features: Number of rows of the returned matrix.
        out_features: Number of columns of the returned matrix.

    Returns:
        Array of shape ``(in_features, out_features)``.
    """
    fan_total = in_features + out_features
    std = jnp.sqrt(2.0 / fan_total)
    samples = jax.random.normal(key, (in_features, out_features))
    return samples * std

# Demo: sample a 784 x 256 matrix and report its shape and empirical std.
key = jax.random.PRNGKey(0)
w = xavier_init(key, 784, 256)
print(f"权重形状: {w.shape}")
print(f"权重标准差: {jnp.std(w):.4f}")

He 初始化 #

python
import jax
import jax.numpy as jnp

def he_init(key, in_features, out_features):
    """Sample a weight matrix with He scaling.

    Entries are drawn from a standard normal distribution and multiplied by
    ``sqrt(2 / in_features)``.

    Args:
        key: JAX PRNG key used for sampling.
        in_features: Number of rows of the returned matrix.
        out_features: Number of columns of the returned matrix.

    Returns:
        Array of shape ``(in_features, out_features)``.
    """
    std = jnp.sqrt(2.0 / in_features)
    samples = jax.random.normal(key, (in_features, out_features))
    return samples * std

# Demo: sample a 784 x 256 matrix and report its shape and empirical std.
key = jax.random.PRNGKey(0)
w = he_init(key, 784, 256)
print(f"权重形状: {w.shape}")
print(f"权重标准差: {jnp.std(w):.4f}")

初始化网络参数 #

python
import jax
import jax.numpy as jnp

def init_mlp_params(key, layer_sizes):
    """Create He-initialised parameters for a stack of dense layers.

    Args:
        key: JAX PRNG key; it is split once per layer.
        layer_sizes: Iterable of ``(in_size, out_size)`` pairs, one per layer.

    Returns:
        A list with one ``{'w': ..., 'b': ...}`` dict per layer, where ``w``
        has shape ``(in_size, out_size)`` and ``b`` is a zero vector of
        length ``out_size``. An empty ``layer_sizes`` gives an empty list.
    """
    params = []
    for in_size, out_size in layer_sizes:
        # Three-way split: the carried key, the weight key, and a spare
        # sub-key. Biases are deterministic zeros, so the spare is discarded.
        key, w_key, _ = jax.random.split(key, 3)
        w = jax.random.normal(w_key, (in_size, out_size)) * jnp.sqrt(2.0 / in_size)
        b = jnp.zeros(out_size)
        params.append({'w': w, 'b': b})
    return params

# Demo: a 784 -> 256 -> 128 -> 10 network.
key = jax.random.PRNGKey(0)
params = init_mlp_params(key, [(784, 256), (256, 128), (128, 10)])

for i, p in enumerate(params):
    print(f"层 {i}: w{p['w'].shape}, b{p['b'].shape}")

线性层 #

基本线性层 #

python
import jax.numpy as jnp

def linear(params, x):
    """Apply an affine layer: ``dot(x, w) + b``.

    Args:
        params: Dict holding the weight matrix under ``'w'`` and the bias
            under ``'b'``.
        x: Input array whose last axis matches the first axis of ``w``.

    Returns:
        The layer output.
    """
    weight = params['w']
    bias = params['b']
    projected = jnp.dot(x, weight)
    return projected + bias

# Demo: a batch of 32 ten-dimensional inputs mapped to 5 outputs.
params = dict(w=jnp.ones((10, 5)), b=jnp.zeros(5))
x = jnp.ones((32, 10))
output = linear(params, x)
print(f"输出形状: {output.shape}")

带激活的线性层 #

python
import jax.numpy as jnp

def linear_relu(params, x):
    """Apply an affine layer followed by ReLU: ``max(dot(x, w) + b, 0)``.

    Args:
        params: Dict holding the weight matrix under ``'w'`` and the bias
            under ``'b'``.
        x: Input array whose last axis matches the first axis of ``w``.

    Returns:
        The element-wise non-negative layer output.
    """
    weight = params['w']
    bias = params['b']
    pre_activation = jnp.dot(x, weight) + bias
    return jnp.maximum(pre_activation, 0)

# Demo: a batch of 32 ten-dimensional inputs mapped to 5 outputs.
params = dict(w=jnp.ones((10, 5)), b=jnp.zeros(5))
x = jnp.ones((32, 10))
output = linear_relu(params, x)
print(f"输出形状: {output.shape}")

激活函数 #

常用激活函数 #

python
import jax.numpy as jnp
import jax.nn as nn

# Use a floating-point input. Activation functions are defined on real values:
# with an integer array ReLU returns integers, and some jax.nn functions may
# not accept integer dtypes at all.
x = jnp.array([-2.0, -1.0, 0.0, 1.0, 2.0])

# ReLU: max(0, x)
relu = nn.relu(x)
print(f"ReLU: {relu}")

# Sigmoid: 1 / (1 + exp(-x))
sigmoid = nn.sigmoid(x)
print(f"Sigmoid: {sigmoid}")

# Tanh: hyperbolic tangent
tanh = nn.tanh(x)
print(f"Tanh: {tanh}")

# Softmax: normalises the vector into a probability distribution
softmax = nn.softmax(x)
print(f"Softmax: {softmax}")

# GELU: Gaussian error linear unit
gelu = nn.gelu(x)
print(f"GELU: {gelu}")

# SiLU (also called Swish): x * sigmoid(x)
silu = nn.silu(x)
print(f"SiLU/Swish: {silu}")

激活函数可视化 #

text
┌─────────────────────────────────────────────────────────────┐
│                    常用激活函数                              │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  ReLU:      max(0, x)                                       │
│             ─── 简单高效,可能导致神经元死亡                 │
│                                                             │
│  Sigmoid:   1 / (1 + exp(-x))                               │
│             ─── 输出 (0, 1),可能梯度消失                    │
│                                                             │
│  Tanh:      (exp(x) - exp(-x)) / (exp(x) + exp(-x))         │
│             ─── 输出 (-1, 1),零中心化                       │
│                                                             │
│  GELU:      x * Φ(x)                                        │
│             ─── 平滑的 ReLU,Transformer 常用                │
│                                                             │
│  SiLU:      x * sigmoid(x)                                  │
│             ─── 平滑非单调,现代网络常用                     │
│                                                             │
└─────────────────────────────────────────────────────────────┘

多层感知机 (MLP) #

前向传播 #

python
import jax
import jax.numpy as jnp
import jax.nn as nn

def mlp_forward(params, x):
    """Run a multi-layer perceptron and return the raw outputs of its last layer.

    Every layer except the last is an affine map followed by ReLU; the last
    layer is affine only, so the result is un-normalised.

    Args:
        params: Non-empty sequence of ``{'w': ..., 'b': ...}`` layer dicts.
        x: Input batch whose last axis matches the first layer's ``w``.

    Returns:
        The output of the final affine layer.

    Raises:
        IndexError: If ``params`` is empty.
    """
    for layer_params in params[:-1]:
        x = jnp.dot(x, layer_params['w']) + layer_params['b']
        x = nn.relu(x)

    final_params = params[-1]
    x = jnp.dot(x, final_params['w']) + final_params['b']
    return x

# Derive independent sub-keys: using one key for both the parameters and the
# input batch would make the two sets of random numbers statistically
# dependent.
key = jax.random.PRNGKey(0)
params_key, x_key = jax.random.split(key)
params = init_mlp_params(params_key, [(784, 256), (256, 128), (128, 10)])

x = jax.random.normal(x_key, (32, 784))
output = mlp_forward(params, x)
print(f"输出形状: {output.shape}")

带输出激活 #

python
import jax
import jax.numpy as jnp
import jax.nn as nn

def mlp_classifier(params, x):
    """Run a multi-layer perceptron and return class probabilities.

    Hidden layers are affine maps followed by ReLU; the last affine layer
    produces logits, which are passed through softmax over the last axis.

    Note that the result is probabilities, not logits. A loss that applies
    ``log_softmax`` itself should be given a function returning logits.

    Args:
        params: Non-empty sequence of ``{'w': ..., 'b': ...}`` layer dicts.
        x: Input batch whose last axis matches the first layer's ``w``.

    Returns:
        Array of probabilities; each row sums to 1.
    """
    for layer_params in params[:-1]:
        x = nn.relu(jnp.dot(x, layer_params['w']) + layer_params['b'])

    logits = jnp.dot(x, params[-1]['w']) + params[-1]['b']
    return nn.softmax(logits)

# Derive independent sub-keys: using one key for both the parameters and the
# input batch would make the two sets of random numbers statistically
# dependent.
key = jax.random.PRNGKey(0)
params_key, x_key = jax.random.split(key)
params = init_mlp_params(params_key, [(784, 256), (256, 10)])

x = jax.random.normal(x_key, (32, 784))
probs = mlp_classifier(params, x)
print(f"概率形状: {probs.shape}")
print(f"概率和: {jnp.sum(probs[0]):.4f}")

卷积层 #

2D 卷积 #

python
import jax
import jax.numpy as jnp

def conv2d(x, kernel, stride=1, padding='SAME'):
    """Apply a 2D convolution to a channels-last batch.

    Args:
        x: Input in ``NHWC`` layout (batch, height, width, channels).
        kernel: Filter in ``HWIO`` layout (height, width, in channels,
            out channels).
        stride: Either a single integer used for both spatial axes, or a
            ``(stride_h, stride_w)`` tuple/list.
        padding: Padding specification forwarded to
            ``jax.lax.conv_general_dilated`` (e.g. ``'SAME'`` or ``'VALID'``).

    Returns:
        The convolution result in ``NHWC`` layout.
    """
    if isinstance(stride, (tuple, list)):
        window_strides = tuple(stride)
    else:
        window_strides = (stride, stride)
    return jax.lax.conv_general_dilated(
        x,
        kernel,
        window_strides=window_strides,
        padding=padding,
        dimension_numbers=('NHWC', 'HWIO', 'NHWC')
    )

# Derive independent sub-keys so the image and the kernel are not sampled from
# the same random stream.
key = jax.random.PRNGKey(0)
x_key, kernel_key = jax.random.split(key)
x = jax.random.normal(x_key, (1, 28, 28, 3))  # one 28x28 image, 3 channels
kernel = jax.random.normal(kernel_key, (3, 3, 3, 32))  # 3x3 filter, 3 -> 32 channels

output = conv2d(x, kernel)
print(f"卷积输出形状: {output.shape}")

初始化卷积参数 #

python
import jax
import jax.numpy as jnp

def init_conv_params(key, in_channels, out_channels, kernel_size):
    """Create He-initialised parameters for a square 2D convolution.

    The kernel is standard normal, scaled by
    ``sqrt(2 / (kernel_size * kernel_size * in_channels))``.

    Args:
        key: JAX PRNG key.
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Side length of the square kernel.

    Returns:
        Dict with the kernel under ``'k'``, shaped
        ``(kernel_size, kernel_size, in_channels, out_channels)``, and a zero
        bias of length ``out_channels`` under ``'b'``.
    """
    # Three-way split of which only the middle sub-key is used for the kernel;
    # the bias is deterministic zeros and needs no randomness.
    _, k_key, _ = jax.random.split(key, 3)
    fan_in = kernel_size * kernel_size * in_channels
    k = jax.random.normal(k_key, (kernel_size, kernel_size, in_channels, out_channels))
    k = k * jnp.sqrt(2.0 / fan_in)
    b = jnp.zeros(out_channels)
    return {'k': k, 'b': b}

# Demo: a 3x3 convolution from 3 to 32 channels.
key = jax.random.PRNGKey(0)
params = init_conv_params(key, 3, 32, 3)
print(f"卷积核形状: {params['k'].shape}")
print(f"偏置形状: {params['b'].shape}")

池化层 #

python
import jax
import jax.numpy as jnp

def max_pool(x, window_shape, strides):
    """Max-pool a channels-last batch over its two spatial axes.

    The window and stride are 1 along the first (batch) and last (channel)
    axes, so pooling only happens over the two middle axes. ``'VALID'``
    padding is used, i.e. no padding is added.

    Args:
        x: 4-D floating-point input (batch, height, width, channels).
        window_shape: ``(window_h, window_w)`` pooling window.
        strides: ``(stride_h, stride_w)`` step between windows.

    Returns:
        The pooled array.
    """
    return jax.lax.reduce_window(
        x,
        init_value=-jnp.inf,  # identity element for max
        computation=jax.lax.max,
        window_dimensions=(1, window_shape[0], window_shape[1], 1),
        window_strides=(1, strides[0], strides[1], 1),
        padding='VALID'
    )

# Demo: 2x2 pooling with stride 2 on a 28x28 feature map with 32 channels.
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (1, 28, 28, 32))
pooled = max_pool(x, (2, 2), (2, 2))
print(f"池化后形状: {pooled.shape}")

简单 CNN #

python
import jax
import jax.numpy as jnp
import jax.nn as nn

def init_cnn_params(key, image_size=(28, 28), in_channels=1, num_classes=10):
    """Create parameters for a two-convolution, two-dense-layer CNN.

    The network has two 3x3 convolutions (``in_channels`` -> 32 -> 64)
    followed by two dense layers (flattened features -> 128 ->
    ``num_classes``). The size of the flattened features assumes that each
    convolution keeps the spatial size and is followed by a pooling step that
    halves it (rounding down), which is what ``cnn_forward`` does.

    Args:
        key: JAX PRNG key; split into one sub-key per layer.
        image_size: ``(height, width)`` of the input images.
        in_channels: Number of channels of the input images.
        num_classes: Number of outputs of the last dense layer.

    Returns:
        Dict with entries ``'conv1'``, ``'conv2'`` (each ``{'k', 'b'}``) and
        ``'fc1'``, ``'fc2'`` (each ``{'w', 'b'}``).
    """
    keys = jax.random.split(key, 4)

    conv2_channels = 64
    hidden_units = 128
    height, width = image_size
    # Two halving pooling steps; with the defaults this is 7 * 7 * 64 = 3136.
    flat_features = (height // 2 // 2) * (width // 2 // 2) * conv2_channels

    params = {
        'conv1': init_conv_params(keys[0], in_channels, 32, 3),
        'conv2': init_conv_params(keys[1], 32, conv2_channels, 3),
        'fc1': {'w': jax.random.normal(keys[2], (flat_features, hidden_units)) * 0.01,
                'b': jnp.zeros(hidden_units)},
        'fc2': {'w': jax.random.normal(keys[3], (hidden_units, num_classes)) * 0.01,
                'b': jnp.zeros(num_classes)}
    }
    return params

def cnn_forward(params, x):
    """Run the two-convolution CNN and return the output of its last dense layer.

    Each convolution stage is: convolution, bias add, ReLU, then 2x2 max
    pooling with stride 2. The result is flattened per example and passed
    through a ReLU dense layer and a final affine layer.

    Args:
        params: Dict with ``'conv1'``/``'conv2'`` (``{'k', 'b'}``) and
            ``'fc1'``/``'fc2'`` (``{'w', 'b'}``) entries.
        x: Input batch; its first axis is treated as the batch axis.

    Returns:
        The raw (un-normalised) output of the ``'fc2'`` layer.
    """
    def conv_stage(inputs, layer):
        # convolution -> bias -> ReLU -> 2x2 / stride-2 max pooling
        activated = nn.relu(conv2d(inputs, layer['k']) + layer['b'])
        return max_pool(activated, (2, 2), (2, 2))

    features = conv_stage(x, params['conv1'])
    features = conv_stage(features, params['conv2'])

    # Flatten everything except the batch axis.
    batch = features.shape[0]
    flat = features.reshape((batch, -1))

    hidden = nn.relu(jnp.dot(flat, params['fc1']['w']) + params['fc1']['b'])
    return jnp.dot(hidden, params['fc2']['w']) + params['fc2']['b']

# Derive independent sub-keys: using one key for both the parameters and the
# input batch would make the two sets of random numbers statistically
# dependent.
key = jax.random.PRNGKey(0)
params_key, x_key = jax.random.split(key)
params = init_cnn_params(params_key)

x = jax.random.normal(x_key, (32, 28, 28, 1))
output = cnn_forward(params, x)
print(f"CNN 输出形状: {output.shape}")

损失函数 #

均方误差 #

python
import jax.numpy as jnp

def mse_loss(params, x, y, forward_fn):
    """Mean squared error between ``forward_fn(params, x)`` and ``y``.

    Args:
        params: Model parameters, passed through to ``forward_fn``.
        x: Model inputs.
        y: Targets, broadcastable against the predictions.
        forward_fn: Callable ``(params, x) -> predictions``.

    Returns:
        Scalar mean of the squared differences over all elements.
    """
    residual = forward_fn(params, x) - y
    return jnp.mean(jnp.square(residual))

交叉熵损失 #

python
import jax.numpy as jnp
import jax.nn as nn

def cross_entropy_loss(params, x, y, forward_fn):
    """Mean softmax cross-entropy for integer class labels.

    Args:
        params: Model parameters, passed through to ``forward_fn``.
        x: Model inputs.
        y: Integer class indices, one per example.
        forward_fn: Callable ``(params, x) -> logits``; the class axis is the
            last axis. Log-softmax is applied here, so it must return logits
            rather than probabilities.

    Returns:
        Scalar mean negative log-likelihood of the labelled classes.
    """
    logits = forward_fn(params, x)
    num_classes = logits.shape[-1]
    targets = nn.one_hot(y, num_classes)
    # The one-hot mask picks the log-probability of the labelled class.
    picked = jnp.sum(targets * nn.log_softmax(logits), axis=-1)
    return -jnp.mean(picked)

二元交叉熵 #

python
import jax.numpy as jnp
import jax.nn as nn

def binary_cross_entropy_loss(params, x, y, forward_fn):
    """Mean binary cross-entropy computed directly from logits.

    Uses ``log_sigmoid`` instead of ``log(sigmoid(z) + eps)``. The epsilon
    form caps the per-element loss once the sigmoid saturates and can give a
    slightly negative loss for confident correct predictions; the
    log-sigmoid form has neither problem.

    Args:
        params: Model parameters, passed through to ``forward_fn``.
        x: Model inputs.
        y: Binary targets (0 or 1), broadcastable against the logits.
        forward_fn: Callable ``(params, x) -> logits``.

    Returns:
        Scalar mean of the per-element binary cross-entropy.
    """
    logits = forward_fn(params, x)
    # log(sigmoid(z)) and log(1 - sigmoid(z)) == log(sigmoid(-z))
    log_p = nn.log_sigmoid(logits)
    log_not_p = nn.log_sigmoid(-logits)
    return -jnp.mean(y * log_p + (1 - y) * log_not_p)

正则化 #

L2 正则化 #

python
import jax.numpy as jnp

def _sum_of_squares(node):
    """Return the sum of squared weights contained in ``node``."""
    if isinstance(node, dict):
        # A layer dict: penalise its weight matrix ('w') or convolution
        # kernel ('k'), but never its bias.
        for name in ('w', 'k'):
            if name in node:
                return jnp.sum(node[name] ** 2)
        # Otherwise treat the dict as a container of layers.
        return sum(_sum_of_squares(child) for child in node.values())
    return jnp.sum(node ** 2)


def l2_regularization(params, weight=0.01):
    """Compute an L2 penalty over the weights of a model.

    Args:
        params: Either a sequence whose items are layer dicts or raw arrays,
            or a dict (a single layer, or a mapping of layer names to layer
            dicts). For a layer dict only its ``'w'`` entry, or its ``'k'``
            entry when there is no ``'w'``, is penalised; biases are left
            out. Raw arrays are penalised in full.
        weight: Multiplier applied to the summed squares.

    Returns:
        ``weight`` times the sum of the squared weights.
    """
    if isinstance(params, dict):
        return weight * _sum_of_squares(params)
    return weight * sum(_sum_of_squares(p) for p in params)

def loss_with_reg(params, x, y, forward_fn, reg_weight=0.01):
    """Cross-entropy loss plus an L2 penalty on the model weights.

    The forward pass runs once, inside ``cross_entropy_loss``.

    Args:
        params: Model parameters.
        x: Model inputs.
        y: Integer class labels.
        forward_fn: Callable ``(params, x) -> logits``.
        reg_weight: Multiplier passed to ``l2_regularization``.

    Returns:
        Scalar sum of the cross-entropy loss and the L2 penalty.
    """
    ce_loss = cross_entropy_loss(params, x, y, forward_fn)
    reg_loss = l2_regularization(params, weight=reg_weight)
    return ce_loss + reg_loss

下一步 #

现在你已经掌握了构建神经网络的基础,接下来请学习「状态管理」一节,了解如何管理神经网络的状态!

最后更新:2026-04-04