构建神经网络 #
概述 #
JAX 本身不提供高级的神经网络 API,但我们可以使用 JAX 的基础功能来构建神经网络。本节介绍如何从零开始构建神经网络组件。
参数初始化 #
Xavier/Glorot 初始化 #
python
import jax
import jax.numpy as jnp

def xavier_init(key, in_features, out_features):
    """Sample an (in_features, out_features) weight matrix with Xavier/Glorot scaling."""
    # Variance 2 / (fan_in + fan_out) keeps activation scale roughly constant.
    std = jnp.sqrt(2.0 / (in_features + out_features))
    samples = jax.random.normal(key, (in_features, out_features))
    return samples * std

key = jax.random.PRNGKey(0)
w = xavier_init(key, 784, 256)
print(f"权重形状: {w.shape}")
print(f"权重标准差: {jnp.std(w):.4f}")
He 初始化 #
python
import jax
import jax.numpy as jnp

def he_init(key, in_features, out_features):
    """He (Kaiming) normal initialization: variance 2 / fan_in, suited to ReLU nets."""
    weights = jax.random.normal(key, (in_features, out_features))
    return weights * jnp.sqrt(2.0 / in_features)

key = jax.random.PRNGKey(0)
w = he_init(key, 784, 256)
print(f"权重形状: {w.shape}")
print(f"权重标准差: {jnp.std(w):.4f}")
初始化网络参数 #
python
import jax
import jax.numpy as jnp

def init_mlp_params(key, layer_sizes):
    """Initialize parameters for a multi-layer perceptron.

    Args:
        key: PRNG key; consumed by successive splits, one per layer.
        layer_sizes: sequence of (in_size, out_size) pairs, one per layer.

    Returns:
        List of {'w': weights, 'b': bias} dicts with He-scaled weights
        and zero biases.
    """
    params = []
    # The original looped with enumerate() but never used the index;
    # iterate over the size pairs directly.
    for in_size, out_size in layer_sizes:
        # Split off fresh subkeys so each layer gets independent randomness.
        # b_key is split for parity with w_key even though biases are zeros.
        key, w_key, b_key = jax.random.split(key, 3)
        w = jax.random.normal(w_key, (in_size, out_size)) * jnp.sqrt(2.0 / in_size)
        b = jnp.zeros(out_size)
        params.append({'w': w, 'b': b})
    return params

key = jax.random.PRNGKey(0)
params = init_mlp_params(key, [(784, 256), (256, 128), (128, 10)])
for i, p in enumerate(params):
    print(f"层 {i}: w{p['w'].shape}, b{p['b'].shape}")
线性层 #
基本线性层 #
python
import jax.numpy as jnp

def linear(params, x):
    """Affine transform: x @ w + b."""
    return x @ params['w'] + params['b']

params = {'w': jnp.ones((10, 5)), 'b': jnp.zeros(5)}
x = jnp.ones((32, 10))
output = linear(params, x)
print(f"输出形状: {output.shape}")
带激活的线性层 #
python
import jax.numpy as jnp

def linear_relu(params, x):
    """Affine transform followed by a ReLU nonlinearity."""
    pre_activation = x @ params['w'] + params['b']
    # Clamp negatives to zero (ReLU).
    return jnp.maximum(pre_activation, 0)

params = {'w': jnp.ones((10, 5)), 'b': jnp.zeros(5)}
x = jnp.ones((32, 10))
output = linear_relu(params, x)
print(f"输出形状: {output.shape}")
激活函数 #
常用激活函数 #
python
import jax.numpy as jnp
import jax.nn as nn
x = jnp.array([-2, -1, 0, 1, 2])
relu = nn.relu(x)
print(f"ReLU: {relu}")
sigmoid = nn.sigmoid(x)
print(f"Sigmoid: {sigmoid}")
tanh = nn.tanh(x)
print(f"Tanh: {tanh}")
softmax = nn.softmax(x)
print(f"Softmax: {softmax}")
gelu = nn.gelu(x)
print(f"GELU: {gelu}")
silu = nn.silu(x)
print(f"SiLU/Swish: {silu}")
激活函数可视化 #
text
┌─────────────────────────────────────────────────────────────┐
│ 常用激活函数 │
├─────────────────────────────────────────────────────────────┤
│ │
│ ReLU: max(0, x) │
│ ─── 简单高效,可能导致神经元死亡 │
│ │
│ Sigmoid: 1 / (1 + exp(-x)) │
│ ─── 输出 (0, 1),可能梯度消失 │
│ │
│ Tanh: (exp(x) - exp(-x)) / (exp(x) + exp(-x)) │
│ ─── 输出 (-1, 1),零中心化 │
│ │
│ GELU: x * Φ(x) │
│ ─── 平滑的 ReLU,Transformer 常用 │
│ │
│ SiLU: x * sigmoid(x) │
│ ─── 平滑非单调,现代网络常用 │
│ │
└─────────────────────────────────────────────────────────────┘
多层感知机 (MLP) #
前向传播 #
python
import jax
import jax.numpy as jnp
import jax.nn as nn
def mlp_forward(params, x):
    """Forward pass of an MLP: ReLU after every hidden layer, linear output.

    Args:
        params: list of {'w', 'b'} dicts (as built by init_mlp_params above).
        x: input batch of shape (batch, in_features).

    Returns:
        Logits of shape (batch, out_features) — no activation on the last
        layer, so callers can apply softmax / a loss directly.
    """
    # The original enumerated the layers but never used the index; iterate
    # over the hidden layers directly.
    for layer_params in params[:-1]:
        x = jnp.dot(x, layer_params['w']) + layer_params['b']
        x = nn.relu(x)
    # Final layer stays linear (no ReLU) to produce raw logits.
    final_params = params[-1]
    return jnp.dot(x, final_params['w']) + final_params['b']
# Build a 784→256→128→10 network and push a random batch through it.
# (init_mlp_params is defined in the earlier section.)
rng = jax.random.PRNGKey(0)
params = init_mlp_params(rng, [(784, 256), (256, 128), (128, 10)])
batch = jax.random.normal(rng, (32, 784))
logits = mlp_forward(params, batch)
print(f"输出形状: {logits.shape}")
带输出激活 #
python
import jax
import jax.numpy as jnp
import jax.nn as nn

def mlp_classifier(params, x):
    """MLP forward pass ending in a softmax over classes.

    Args:
        params: list of {'w', 'b'} dicts; all but the last layer use ReLU.
        x: input batch of shape (batch, in_features).

    Returns:
        Per-class probabilities of shape (batch, num_classes); each row
        sums to 1.
    """
    for layer_params in params[:-1]:
        x = nn.relu(jnp.dot(x, layer_params['w']) + layer_params['b'])
    logits = jnp.dot(x, params[-1]['w']) + params[-1]['b']
    return nn.softmax(logits)

# NOTE: the original snippet imported only jax.nn, yet used jax.random and
# jnp below — it raised NameError when run standalone. The missing
# `import jax` and `import jax.numpy as jnp` were added above.
# init_mlp_params comes from the earlier section.
key = jax.random.PRNGKey(0)
params = init_mlp_params(key, [(784, 256), (256, 10)])
x = jax.random.normal(key, (32, 784))
probs = mlp_classifier(params, x)
print(f"概率形状: {probs.shape}")
print(f"概率和: {jnp.sum(probs[0]):.4f}")
卷积层 #
2D 卷积 #
python
import jax
import jax.numpy as jnp

def conv2d(x, kernel, stride=1, padding='SAME'):
    """2D convolution over NHWC inputs with an HWIO kernel."""
    strides = (stride, stride)
    # Layout spec: input NHWC, kernel HWIO, output NHWC.
    dims = ('NHWC', 'HWIO', 'NHWC')
    return jax.lax.conv_general_dilated(
        x,
        kernel,
        window_strides=strides,
        padding=padding,
        dimension_numbers=dims,
    )

key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (1, 28, 28, 3))
kernel = jax.random.normal(key, (3, 3, 3, 32))
output = conv2d(x, kernel)
print(f"卷积输出形状: {output.shape}")
初始化卷积参数 #
python
import jax
import jax.numpy as jnp

def init_conv_params(key, in_channels, out_channels, kernel_size):
    """He-style init for an HWIO conv kernel plus a zero bias vector."""
    # b_key is split alongside k_key to keep key consumption stable even
    # though the bias is deterministic zeros.
    key, k_key, b_key = jax.random.split(key, 3)
    fan_in = kernel_size * kernel_size * in_channels
    kernel = jax.random.normal(
        k_key, (kernel_size, kernel_size, in_channels, out_channels)
    ) * jnp.sqrt(2.0 / fan_in)
    return {'k': kernel, 'b': jnp.zeros(out_channels)}

key = jax.random.PRNGKey(0)
params = init_conv_params(key, 3, 32, 3)
print(f"卷积核形状: {params['k'].shape}")
print(f"偏置形状: {params['b'].shape}")
池化层 #
python
import jax
import jax.numpy as jnp

def max_pool(x, window_shape, strides):
    """Max pooling over the spatial (H, W) axes of an NHWC tensor.

    Args:
        x: input of shape (N, H, W, C).
        window_shape: (pool_h, pool_w) window size.
        strides: (stride_h, stride_w) step between windows.

    Returns:
        Pooled tensor; VALID padding, so edge regions that don't fit a
        full window are dropped.
    """
    # NOTE: the original snippet used jax.lax and jax.random but only
    # imported jax.numpy, raising NameError when run standalone;
    # `import jax` was added above.
    return jax.lax.reduce_window(
        x,
        init_value=-jnp.inf,   # identity element for max
        computation=jax.lax.max,
        window_dimensions=(1, window_shape[0], window_shape[1], 1),
        window_strides=(1, strides[0], strides[1], 1),
        padding='VALID'
    )

key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (1, 28, 28, 32))
pooled = max_pool(x, (2, 2), (2, 2))
print(f"池化后形状: {pooled.shape}")
简单 CNN #
python
import jax
import jax.numpy as jnp
import jax.nn as nn
def init_cnn_params(key):
    """Parameters for a two-conv / two-dense CNN.

    Relies on init_conv_params defined in the earlier section.
    """
    conv1_key, conv2_key, fc1_key, fc2_key = jax.random.split(key, 4)
    # 3136 = 7 * 7 * 64: flattened size after two stride-2 pools on the
    # 28x28 demo input below.
    return {
        'conv1': init_conv_params(conv1_key, 1, 32, 3),
        'conv2': init_conv_params(conv2_key, 32, 64, 3),
        'fc1': {'w': jax.random.normal(fc1_key, (3136, 128)) * 0.01,
                'b': jnp.zeros(128)},
        'fc2': {'w': jax.random.normal(fc2_key, (128, 10)) * 0.01,
                'b': jnp.zeros(10)},
    }
def cnn_forward(params, x):
    """CNN forward pass: [conv → relu → pool] x 2, flatten, two dense layers.

    Expects NHWC input and returns 10-way logits. Uses conv2d and
    max_pool from the earlier sections.
    """
    # First conv block (32 channels), then 2x2 max pool.
    h = nn.relu(conv2d(x, params['conv1']['k']) + params['conv1']['b'])
    h = max_pool(h, (2, 2), (2, 2))
    # Second conv block (64 channels), then 2x2 max pool.
    h = nn.relu(conv2d(h, params['conv2']['k']) + params['conv2']['b'])
    h = max_pool(h, (2, 2), (2, 2))
    # Flatten spatial and channel dims per example for the dense head.
    h = h.reshape((h.shape[0], -1))
    h = nn.relu(jnp.dot(h, params['fc1']['w']) + params['fc1']['b'])
    return jnp.dot(h, params['fc2']['w']) + params['fc2']['b']
# Run a random batch of 28x28 single-channel images through the CNN.
rng = jax.random.PRNGKey(0)
cnn_params = init_cnn_params(rng)
images = jax.random.normal(rng, (32, 28, 28, 1))
output = cnn_forward(cnn_params, images)
print(f"CNN 输出形状: {output.shape}")
损失函数 #
均方误差 #
python
import jax.numpy as jnp

def mse_loss(params, x, y, forward_fn):
    """Mean squared error between forward_fn(params, x) and targets y."""
    residual = forward_fn(params, x) - y
    return jnp.mean(jnp.square(residual))
交叉熵损失 #
python
import jax.numpy as jnp
import jax.nn as nn

def cross_entropy_loss(params, x, y, forward_fn):
    """Softmax cross-entropy between forward_fn's logits and integer labels y,
    averaged over the batch."""
    logits = forward_fn(params, x)
    targets = nn.one_hot(y, logits.shape[-1])
    # One-hot mask selects the log-probability of the true class.
    per_example = jnp.sum(targets * nn.log_softmax(logits), axis=-1)
    return -jnp.mean(per_example)
二元交叉熵 #
python
import jax.numpy as jnp
import jax.nn as nn

def binary_cross_entropy_loss(params, x, y, forward_fn):
    """Binary cross-entropy on sigmoid(forward_fn(params, x)) vs labels y.

    A small epsilon (1e-7) guards both logs against probabilities of
    exactly 0 or 1.
    """
    probs = nn.sigmoid(forward_fn(params, x))
    eps = 1e-7
    positive_term = y * jnp.log(probs + eps)
    negative_term = (1 - y) * jnp.log(1 - probs + eps)
    return -jnp.mean(positive_term + negative_term)
正则化 #
L2 正则化 #
python
import jax.numpy as jnp

def l2_regularization(params, weight=0.01):
    """Weighted sum of squared weights across the parameter list.

    Dict entries contribute only their 'w' matrix (biases are not
    penalized); bare arrays contribute all of their elements.
    """
    def _sq_norm(p):
        arr = p['w'] if isinstance(p, dict) else p
        return jnp.sum(arr ** 2)

    return weight * sum(_sq_norm(p) for p in params)
def loss_with_reg(params, x, y, forward_fn):
    """Cross-entropy loss plus L2 regularization on the weights.

    Relies on cross_entropy_loss and l2_regularization from the sections
    above.
    """
    # The original computed pred = forward_fn(params, x) and discarded
    # it; cross_entropy_loss already runs the forward pass, so the model
    # was evaluated twice per call. The dead call was removed.
    ce_loss = cross_entropy_loss(params, x, y, forward_fn)
    reg_loss = l2_regularization(params)
    return ce_loss + reg_loss
下一步 #
现在你已经掌握了神经网络基础,接下来学习 状态管理,了解如何管理神经网络的状态!
最后更新:2026-04-04