PyTorch 神经网络模块 #

nn.Module 简介 #

nn.Module 是 PyTorch 中所有神经网络模块的基类。它提供了构建神经网络所需的核心功能。

text
┌─────────────────────────────────────────────────────────────┐
│                    nn.Module 核心功能                        │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  1. 参数管理                                                 │
│     - 自动注册可学习参数                                     │
│     - parameters() 获取所有参数                              │
│     - state_dict() 获取状态字典                              │
│                                                             │
│  2. 子模块管理                                               │
│     - 自动注册子模块                                         │
│     - modules() 获取所有模块                                 │
│     - children() 获取直接子模块                              │
│                                                             │
│  3. 前向传播                                                 │
│     - forward() 方法定义计算                                 │
│     - __call__() 自动调用 forward                           │
│                                                             │
│  4. 设备转移                                                 │
│     - to() 方法转移设备                                     │
│     - cuda() / cpu() 快捷方法                               │
│                                                             │
│  5. 训练/评估模式                                            │
│     - train() 训练模式                                      │
│     - eval() 评估模式                                       │
│                                                             │
└─────────────────────────────────────────────────────────────┘

自定义神经网络 #

基本结构 #

python
import torch
import torch.nn as nn

class SimpleNet(nn.Module):
    """Two-layer fully connected network with a ReLU in between."""

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        # Submodules assigned as attributes are auto-registered; the
        # assignment order is also the order shown by print(model).
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, output_size)
        self.relu = nn.ReLU()

    def forward(self, x):
        # Hidden activation, then the output projection (raw logits).
        hidden = self.relu(self.fc1(x))
        return self.fc2(hidden)

model = SimpleNet(784, 128, 10)
print(model)

使用 nn.Sequential #

python
import torch.nn as nn

# Positional style: layers are auto-named "0", "1", ... in execution order.
model = nn.Sequential(
    nn.Linear(784, 128),
    nn.ReLU(),
    nn.Linear(128, 64),
    nn.ReLU(),
    nn.Linear(64, 10)
)

print(model)

from collections import OrderedDict

# Named style: an OrderedDict gives each layer a readable name
# (e.g. model.fc1) that also appears in the printed repr.
model = nn.Sequential(OrderedDict([
    ('fc1', nn.Linear(784, 128)),
    ('relu1', nn.ReLU()),
    ('fc2', nn.Linear(128, 10))
]))

print(model)

使用 nn.ModuleList #

python
import torch.nn as nn

class MyNet(nn.Module):
    """Stack of Linear layers built from a list of sizes.

    Layer i maps sizes[i] -> sizes[i+1]; ReLU is applied between
    layers but not after the last one.
    """

    def __init__(self, sizes):
        super().__init__()
        # Pair consecutive sizes; nn.ModuleList registers each layer so
        # its parameters are visible to optimizers and state_dict.
        self.layers = nn.ModuleList(
            nn.Linear(n_in, n_out) for n_in, n_out in zip(sizes, sizes[1:])
        )
        self.relu = nn.ReLU()

    def forward(self, x):
        last = len(self.layers) - 1
        for idx, fc in enumerate(self.layers):
            x = fc(x)
            if idx != last:  # no activation after the output layer
                x = self.relu(x)
        return x

model = MyNet([784, 256, 128, 10])
print(model)

使用 nn.ModuleDict #

python
import torch.nn as nn

class MyNet(nn.Module):
    """Two Linear layers with a runtime-selectable activation in between."""

    def __init__(self):
        super().__init__()
        # nn.ModuleDict registers submodules under string keys, so they
        # can be chosen by name inside forward().
        self.layers = nn.ModuleDict({
            'fc1': nn.Linear(784, 128),
            'fc2': nn.Linear(128, 10)
        })
        self.activations = nn.ModuleDict({
            'relu': nn.ReLU(),
            'sigmoid': nn.Sigmoid()
        })

    def forward(self, x, activation='relu'):
        # fc1 -> chosen activation -> fc2
        hidden = self.layers['fc1'](x)
        activated = self.activations[activation](hidden)
        return self.layers['fc2'](activated)

model = MyNet()
print(model)

常用层 #

全连接层 nn.Linear #

python
import torch
import torch.nn as nn

# Affine map y = x @ W^T + b: 10 input features -> 5 output features.
layer = nn.Linear(10, 5)

print(layer.weight.shape)  # (out_features, in_features) = (5, 10)
print(layer.bias.shape)    # (out_features,) = (5,)

x = torch.randn(32, 10)  # batch of 32 samples
y = layer(x)
print(y.shape)  # (32, 5)

# Without bias the layer is a pure matrix multiply; .bias is None.
layer = nn.Linear(10, 5, bias=False)
print(layer.bias)

卷积层 nn.Conv2d #

python
import torch
import torch.nn as nn

# 2D convolution: 3 input channels (e.g. RGB) -> 64 feature maps.
conv = nn.Conv2d(
    in_channels=3,
    out_channels=64,
    kernel_size=3,
    stride=1,
    padding=1  # padding=1 with a 3x3 kernel preserves H and W
)

x = torch.randn(32, 3, 224, 224)  # (batch, channels, height, width)
y = conv(x)
print(y.shape)  # (32, 64, 224, 224)

# String padding: 'same' keeps the spatial size (only valid for stride 1).
conv = nn.Conv2d(3, 64, kernel_size=3, padding='same')
y = conv(x)
print(y.shape)  # (32, 64, 224, 224)

# 'valid' means no padding, so each spatial dim shrinks by kernel_size - 1.
conv = nn.Conv2d(3, 64, kernel_size=3, padding='valid')
y = conv(x)
print(y.shape)  # (32, 64, 222, 222)

池化层 #

python
import torch
import torch.nn as nn

x = torch.randn(32, 64, 112, 112)

# Max pooling: largest value in each 2x2 window, halving H and W.
maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
y = maxpool(x)
print(f"MaxPool: {y.shape}")  # (32, 64, 56, 56)

# Average pooling: mean of each 2x2 window.
avgpool = nn.AvgPool2d(kernel_size=2, stride=2)
y = avgpool(x)
print(f"AvgPool: {y.shape}")  # (32, 64, 56, 56)

# Adaptive pooling fixes the OUTPUT size regardless of the input size.
adaptive_pool = nn.AdaptiveAvgPool2d((1, 1))  # (1, 1) = global average pooling
y = adaptive_pool(x)
print(f"AdaptiveAvgPool: {y.shape}")  # (32, 64, 1, 1)

adaptive_pool = nn.AdaptiveMaxPool2d((7, 7))
y = adaptive_pool(x)
print(f"AdaptiveMaxPool: {y.shape}")  # (32, 64, 7, 7)

归一化层 #

python
import torch
import torch.nn as nn

x = torch.randn(32, 64, 112, 112)

# BatchNorm2d: normalizes each of the 64 channels over (N, H, W).
bn = nn.BatchNorm2d(64)
y = bn(x)
print(f"BatchNorm: {y.shape}")

# LayerNorm: normalizes the given trailing dims per sample; independent
# of batch statistics.
ln = nn.LayerNorm([64, 112, 112])
y = ln(x)
print(f"LayerNorm: {y.shape}")

# GroupNorm: 64 channels in 8 groups of 8, normalized per group;
# batch-size independent, so it works well for small batches.
gn = nn.GroupNorm(num_groups=8, num_channels=64)
y = gn(x)
print(f"GroupNorm: {y.shape}")

# NOTE: Dropout is a regularization layer, not a normalization layer;
# it randomly zeroes elements with probability p during training.
dropout = nn.Dropout(p=0.5)
y = dropout(x)
print(f"Dropout: {y.shape}")

激活函数 #

常用激活函数 #

text
┌─────────────────────────────────────────────────────────────┐
│                    激活函数对比                              │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  ReLU                                                       │
│  f(x) = max(0, x)                                          │
│  优点:计算简单,缓解梯度消失                                │
│  缺点:负区间梯度为0(死亡ReLU)                             │
│                                                             │
│  LeakyReLU                                                  │
│  f(x) = max(αx, x)  (α通常为0.01)                          │
│  优点:解决死亡ReLU问题                                     │
│                                                             │
│  GELU                                                       │
│  f(x) = x * Φ(x)                                           │
│  优点:平滑,Transformer常用                                │
│                                                             │
│  Sigmoid                                                    │
│  f(x) = 1 / (1 + e^(-x))                                   │
│  优点:输出(0,1),适合概率                                  │
│  缺点:梯度消失                                             │
│                                                             │
│  Tanh                                                       │
│  f(x) = (e^x - e^(-x)) / (e^x + e^(-x))                    │
│  优点:输出(-1,1),零中心                                   │
│  缺点:梯度消失                                             │
│                                                             │
│  Softmax                                                    │
│  f(x_i) = e^x_i / Σe^x_j                                   │
│  用途:多分类输出层                                         │
│                                                             │
└─────────────────────────────────────────────────────────────┘

使用示例 #

python
import torch
import torch.nn as nn

x = torch.randn(2, 3)

# ReLU: max(0, x).
relu = nn.ReLU()
print(f"ReLU: {relu(x)}")

# LeakyReLU: small negative slope instead of a hard zero for x < 0.
leaky_relu = nn.LeakyReLU(negative_slope=0.01)
print(f"LeakyReLU: {leaky_relu(x)}")

# GELU: smooth ReLU variant commonly used in Transformers.
gelu = nn.GELU()
print(f"GELU: {gelu(x)}")

# Sigmoid: squashes values into (0, 1).
sigmoid = nn.Sigmoid()
print(f"Sigmoid: {sigmoid(x)}")

# Tanh: squashes values into (-1, 1), zero-centered.
tanh = nn.Tanh()
print(f"Tanh: {tanh(x)}")

# Softmax along dim=1: each row becomes a probability distribution.
softmax = nn.Softmax(dim=1)
print(f"Softmax: {softmax(x)}")

函数式 API #

python
import torch
import torch.nn.functional as F

x = torch.randn(2, 3)

# Stateless functional equivalents of the nn.* activation modules.
y = F.relu(x)
y = F.leaky_relu(x, negative_slope=0.01)
y = F.gelu(x)
# F.sigmoid / F.tanh are deprecated; the torch namespace is the
# recommended spelling (identical results).
y = torch.sigmoid(x)
y = torch.tanh(x)
y = F.softmax(x, dim=1)

损失函数 #

回归损失 #

python
import torch
import torch.nn as nn

predictions = torch.randn(32, 1)
targets = torch.randn(32, 1)

# Mean squared error: mean((pred - target)^2); penalizes large errors heavily.
mse_loss = nn.MSELoss()
loss = mse_loss(predictions, targets)
print(f"MSE Loss: {loss}")

# Mean absolute error: more robust to outliers than MSE.
l1_loss = nn.L1Loss()
loss = l1_loss(predictions, targets)
print(f"L1 Loss: {loss}")

# Huber-style loss: quadratic near zero, linear for large errors.
smooth_l1_loss = nn.SmoothL1Loss()
loss = smooth_l1_loss(predictions, targets)
print(f"Smooth L1 Loss: {loss}")

分类损失 #

python
import torch
import torch.nn as nn

predictions = torch.randn(32, 10)       # raw logits, one row per sample
targets = torch.randint(0, 10, (32,))   # integer class indices in [0, 10)

# CrossEntropyLoss = log_softmax + NLLLoss; expects raw logits and
# integer class targets (no manual softmax).
ce_loss = nn.CrossEntropyLoss()
loss = ce_loss(predictions, targets)
print(f"CrossEntropy Loss: {loss}")

# Build one-hot targets for the binary-cross-entropy variants below.
targets_onehot = torch.zeros(32, 10)
targets_onehot.scatter_(1, targets.unsqueeze(1), 1)

# BCELoss expects probabilities, so the logits go through sigmoid first.
bce_loss = nn.BCELoss()
loss = bce_loss(torch.sigmoid(predictions), targets_onehot)
print(f"BCE Loss: {loss}")

# BCEWithLogitsLoss fuses the sigmoid for better numerical stability.
bce_with_logits = nn.BCEWithLogitsLoss()
loss = bce_with_logits(predictions, targets_onehot)
print(f"BCE with Logits Loss: {loss}")

其他损失 #

python
import torch
import torch.nn as nn

predictions = torch.randn(32, 10)
targets = torch.randint(0, 10, (32,))

# NLLLoss consumes LOG-probabilities; pairing it with log_softmax is
# equivalent to CrossEntropyLoss on raw logits.
nll_loss = nn.NLLLoss()
log_probs = torch.log_softmax(predictions, dim=1)
loss = nll_loss(log_probs, targets)
print(f"NLL Loss: {loss}")

# Triplet loss: pulls anchor/positive embeddings together and pushes
# the negative at least `margin` away.
embedding = nn.Embedding(10, 3)
inputs = torch.randint(0, 10, (5,))
targets = torch.randint(0, 10, (5,))  # NOTE: shadows the `targets` above
triplet_loss = nn.TripletMarginLoss(margin=1.0)
anchor = embedding(inputs)
positive = embedding(targets)
negative = embedding(torch.randint(0, 10, (5,)))
loss = triplet_loss(anchor, positive, negative)
print(f"Triplet Loss: {loss}")

参数管理 #

获取参数 #

python
import torch.nn as nn

model = nn.Sequential(
    nn.Linear(784, 128),
    nn.ReLU(),
    nn.Linear(128, 10)
)

# named_parameters() yields (name, tensor) pairs for every registered
# learnable parameter, recursing into submodules.
print("所有参数:")
for name, param in model.named_parameters():
    print(f"  {name}: {param.shape}")

print("\n参数数量:")
total_params = sum(p.numel() for p in model.parameters())
print(f"  总参数: {total_params}")

# Frozen parameters (requires_grad=False) are excluded from this count.
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"  可训练参数: {trainable_params}")

参数初始化 #

python
import torch
import torch.nn as nn

model = nn.Sequential(
    nn.Linear(784, 128),
    nn.ReLU(),
    nn.Linear(128, 10)
)

def init_weights(m):
    """Xavier-initialize Linear weights and zero their biases in place."""
    if isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)

# apply() walks every submodule recursively and calls the function on each.
model.apply(init_weights)

print(model[0].weight[:2, :5])

# In-place init functions (trailing underscore) can also target one layer;
# Kaiming init is the usual choice for ReLU networks.
nn.init.kaiming_normal_(model[0].weight, mode='fan_out', nonlinearity='relu')
nn.init.constant_(model[0].bias, 0.1)

print(model[0].bias[:5])

冻结参数 #

python
import torch.nn as nn

model = nn.Sequential(
    nn.Linear(784, 256),
    nn.ReLU(),
    nn.Linear(256, 128),
    nn.ReLU(),
    nn.Linear(128, 10)
)

# Freeze the first two Linear layers (indices 0 and 2): parameters with
# requires_grad=False receive no gradients during backward().
for param in model[0].parameters():
    param.requires_grad = False
for param in model[2].parameters():
    param.requires_grad = False

for name, param in model.named_parameters():
    print(f"{name}: requires_grad={param.requires_grad}")

# Compare trainable vs total parameter counts.
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"\n可训练参数: {trainable}/{total}")

模型状态 #

训练与评估模式 #

python
# BUG FIX: `torch` itself was never imported, so torch.randn / torch.allclose
# below raised NameError.
import torch
import torch.nn as nn

# Dropout is the only stochastic layer here, so train() vs eval()
# decides whether repeated forward passes agree.
model = nn.Sequential(
    nn.Linear(784, 128),
    nn.ReLU(),
    nn.Dropout(0.5),
    nn.Linear(128, 10)
)

model.train()
print(f"训练模式: {model.training}")

x = torch.randn(32, 784)
y1 = model(x)
y2 = model(x)
# In train mode dropout samples a fresh mask each call -> outputs differ.
print(f"训练模式下两次输出相同: {torch.allclose(y1, y2)}")

model.eval()
print(f"评估模式: {model.training}")

# In eval mode dropout is a no-op -> outputs are identical.
y1 = model(x)
y2 = model(x)
print(f"评估模式下两次输出相同: {torch.allclose(y1, y2)}")

state_dict #

python
import torch.nn as nn

model = nn.Sequential(
    nn.Linear(784, 128),
    nn.ReLU(),
    nn.Linear(128, 10)
)

# state_dict maps parameter/buffer names to tensors; it is what gets
# persisted with torch.save and restored with load_state_dict.
state_dict = model.state_dict()
print("state_dict keys:")
for key in state_dict.keys():
    print(f"  {key}: {state_dict[key].shape}")

# Loading requires an architecture with identically named/shaped entries.
model2 = nn.Sequential(
    nn.Linear(784, 128),
    nn.ReLU(),
    nn.Linear(128, 10)
)
model2.load_state_dict(state_dict)

完整示例:MNIST 分类器 #

python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

class MNISTNet(nn.Module):
    """MLP classifier for 28x28 MNIST digits: 784 -> 256 -> 128 -> 10."""

    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(28 * 28, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 10)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.2)

    def forward(self, x):
        # Flatten the image, run two hidden blocks (Linear -> ReLU ->
        # Dropout), and emit raw logits for CrossEntropyLoss.
        h = self.flatten(x)
        for fc in (self.fc1, self.fc2):
            h = self.dropout(self.relu(fc(h)))
        return self.fc3(h)

# Convert to tensor and standardize with the MNIST mean/std (0.1307, 0.3081).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

# Downloads MNIST to ./data on first run.
train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

# Use the GPU when available; model and batches must live on one device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = MNISTNet().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

def train(epoch):
    """Train ``model`` for one pass over ``train_loader``.

    Relies on the module-level ``model``, ``criterion``, ``optimizer``,
    ``device`` and ``train_loader`` defined above; logs the running
    loss every 100 batches.
    """
    model.train()  # enable dropout for training
    for batch_idx, batch in enumerate(train_loader):
        data, target = (t.to(device) for t in batch)

        # zero grads -> forward -> loss -> backward -> update
        optimizer.zero_grad()
        loss = criterion(model(data), target)
        loss.backward()
        optimizer.step()

        if batch_idx % 100 == 0:
            print(f'Epoch {epoch}, Batch {batch_idx}, Loss: {loss.item():.4f}')

# Train for 5 epochs (1 through 5 inclusive).
for epoch in range(1, 6):
    train(epoch)

下一步 #

现在你已经掌握了 PyTorch 神经网络模块的核心概念,接下来学习 优化器,了解如何高效训练神经网络!

最后更新:2026-03-29