Reinforcement Learning in Practice #

Overview #

This section implements two reinforcement learning algorithms in JAX: DQN (Deep Q-Network) and PPO (Proximal Policy Optimization).

Environment Setup #

A Simple Environment #

python
import jax
import jax.numpy as jnp

class SimpleEnv:
    """A toy environment: the reward couples state and action, and the
    episode ends once the state drifts too far from the origin."""

    def __init__(self, state_dim=4, action_dim=2):
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.state = None
    
    def reset(self, key):
        # Draw a fresh starting state uniformly from [-1, 1).
        self.state = jax.random.uniform(key, (self.state_dim,), minval=-1, maxval=1)
        return self.state
    
    def step(self, action):
        # Toy reward and dynamics; `action` is an integer in {0, ..., action_dim - 1}.
        reward = jnp.sum(self.state * (action + 1))
        self.state = self.state + (action - 0.5) * 0.1
        done = jnp.abs(jnp.sum(self.state)) > 2
        return self.state, reward, done

env = SimpleEnv()
key = jax.random.PRNGKey(0)
state = env.reset(key)
print(f"初始状态: {state}")

DQN Implementation #

Q-Network #

python
import jax
import jax.numpy as jnp
import jax.nn as nn

def init_q_network(key, state_dim, action_dim, hidden_dim=64):
    keys = jax.random.split(key, 3)
    return {
        'w1': jax.random.normal(keys[0], (state_dim, hidden_dim)) * 0.01,
        'b1': jnp.zeros(hidden_dim),
        'w2': jax.random.normal(keys[1], (hidden_dim, hidden_dim)) * 0.01,
        'b2': jnp.zeros(hidden_dim),
        'w3': jax.random.normal(keys[2], (hidden_dim, action_dim)) * 0.01,
        'b3': jnp.zeros(action_dim)
    }

def q_forward(params, state):
    # Two-hidden-layer MLP; returns one Q-value per action.
    x = nn.relu(jnp.dot(state, params['w1']) + params['b1'])
    x = nn.relu(jnp.dot(x, params['w2']) + params['b2'])
    return jnp.dot(x, params['w3']) + params['b3']
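
Because q_forward is built from jnp.dot, the same function handles a single state or a batch of states. A quick shape check (a sketch; demo_params is a throwaway name used only here):

python
key = jax.random.PRNGKey(42)
demo_params = init_q_network(key, state_dim=4, action_dim=2)

single = jnp.ones(4)        # one state
batch = jnp.ones((8, 4))    # batch of states
print(q_forward(demo_params, single).shape)  # (2,)  one Q-value per action
print(q_forward(demo_params, batch).shape)   # (8, 2)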

Experience Replay #

python
from collections import deque
import random

class ReplayBuffer:
    """FIFO buffer of (state, action, reward, next_state, done) transitions."""

    def __init__(self, capacity=10000):
        self.buffer = deque(maxlen=capacity)
    
    def push(self, state, action, reward, next_state, done):
        self.buffer.append((state, action, reward, next_state, done))
    
    def sample(self, batch_size):
        batch = random.sample(self.buffer, batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)
        return (
            jnp.array(states),
            jnp.array(actions),
            jnp.array(rewards),
            jnp.array(next_states),
            jnp.array(dones)
        )
    
    def __len__(self):
        return len(self.buffer)
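
A short usage sketch: fill the buffer with synthetic transitions, then draw a mini-batch (demo_buffer and the transition values are placeholders, not part of the training code):

python
demo_buffer = ReplayBuffer(capacity=100)
key = jax.random.PRNGKey(0)
for i in range(40):
    key, k1, k2 = jax.random.split(key, 3)
    s = jax.random.uniform(k1, (4,))
    s_next = jax.random.uniform(k2, (4,))
    demo_buffer.push(s, i % 2, 0.5, s_next, False)

states, actions, rewards, next_states, dones = demo_buffer.sample(batch_size=8)
print(states.shape, actions.shape)  # (8, 4) (8,)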

DQN Training #

python
import jax

def dqn_loss(params, target_params, states, actions, rewards, next_states, dones, gamma=0.99):
    # Q-values of the actions actually taken.
    q_values = q_forward(params, states)
    q_value = q_values[jnp.arange(len(actions)), actions]
    
    # Bootstrap targets come from the frozen target network.
    next_q_values = q_forward(target_params, next_states)
    max_next_q = jnp.max(next_q_values, axis=-1)
    target_q = rewards + gamma * max_next_q * (1 - dones)
    
    return jnp.mean((q_value - target_q) ** 2)

@jax.jit
def dqn_train_step(params, target_params, states, actions, rewards, next_states, dones, lr=0.001):
    loss, grads = jax.value_and_grad(dqn_loss)(params, target_params, states, actions, rewards, next_states, dones)
    new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return new_params, loss

def select_action(params, state, key, epsilon=0.1, action_dim=2):
    # Epsilon-greedy: split the key so the explore/exploit draw and the
    # random action draw are independent.
    explore_key, action_key = jax.random.split(key)
    if jax.random.uniform(explore_key) < epsilon:
        return jax.random.randint(action_key, (), 0, action_dim)
    q_values = q_forward(params, state)
    return jnp.argmax(q_values)
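
To confirm the jitted update runs end to end, here is one gradient step on a synthetic batch (a sketch; demo_params, demo_target, and the batch contents are arbitrary placeholders):

python
key = jax.random.PRNGKey(0)
demo_params = init_q_network(key, state_dim=4, action_dim=2)
demo_target = demo_params  # target network starts as a copy

# A synthetic batch with the right shapes for the 4-state, 2-action toy setup.
states = jnp.ones((32, 4))
actions = jnp.zeros(32, dtype=jnp.int32)
rewards = jnp.ones(32)
next_states = jnp.ones((32, 4))
dones = jnp.zeros(32)

demo_params, loss = dqn_train_step(demo_params, demo_target, states, actions, rewards, next_states, dones)
print(f"loss={float(loss):.4f}")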

DQN Training Loop #

python
def train_dqn(env, params, episodes=100, batch_size=32, epsilon_start=1.0, epsilon_end=0.01, epsilon_decay=0.995):
    buffer = ReplayBuffer()
    target_params = params  # the target network starts as a copy of the online network
    epsilon = epsilon_start
    key = jax.random.PRNGKey(0)
    
    for episode in range(episodes):
        key, subkey = jax.random.split(key)
        state = env.reset(subkey)
        total_reward = 0
        
        while True:
            key, subkey = jax.random.split(key)
            action = select_action(params, state, subkey, epsilon)
            next_state, reward, done = env.step(int(action))
            
            buffer.push(state, action, reward, next_state, done)
            total_reward += reward
            
            if len(buffer) >= batch_size:
                states, actions, rewards, next_states, dones = buffer.sample(batch_size)
                params, loss = dqn_train_step(params, target_params, states, actions, rewards, next_states, dones)
            
            state = next_state
            
            if done:
                break
        
        epsilon = max(epsilon_end, epsilon * epsilon_decay)
        
        if episode % 10 == 0:
            target_params = params  # periodically sync the target network
        
        if episode % 20 == 0:
            print(f"Episode {episode}: total_reward={total_reward:.2f}, epsilon={epsilon:.3f}")
    
    return params

key = jax.random.PRNGKey(0)
params = init_q_network(key, state_dim=4, action_dim=2)
params = train_dqn(env, params, episodes=100)
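
After training, the policy can be evaluated greedily, i.e. with no exploration noise (a sketch; evaluate and the 50-step cap are choices made here, not part of the original code):

python
def evaluate(env, params, key, max_steps=50):
    # Run one greedy episode (always argmax, epsilon = 0) and return the total reward.
    key, subkey = jax.random.split(key)
    state = env.reset(subkey)
    total = 0.0
    for _ in range(max_steps):
        action = int(jnp.argmax(q_forward(params, state)))
        state, reward, done = env.step(action)
        total += float(reward)
        if done:
            break
    return total

print(f"Greedy return: {evaluate(env, params, jax.random.PRNGKey(123)):.2f}")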

PPO Implementation #

Policy Network #

python
import jax
import jax.numpy as jnp
import jax.nn as nn

def init_policy_network(key, state_dim, action_dim, hidden_dim=64):
    keys = jax.random.split(key, 4)
    return {
        'actor': {
            'w1': jax.random.normal(keys[0], (state_dim, hidden_dim)) * 0.01,
            'b1': jnp.zeros(hidden_dim),
            'w2': jax.random.normal(keys[1], (hidden_dim, action_dim)) * 0.01,
            'b2': jnp.zeros(action_dim)
        },
        'critic': {
            'w1': jax.random.normal(keys[2], (state_dim, hidden_dim)) * 0.01,
            'b1': jnp.zeros(hidden_dim),
            'w2': jax.random.normal(keys[3], (hidden_dim, 1)) * 0.01,
            'b2': jnp.zeros(1)
        }
    }

def actor_forward(params, state):
    # Policy head: a probability distribution over actions.
    x = nn.relu(jnp.dot(state, params['actor']['w1']) + params['actor']['b1'])
    logits = jnp.dot(x, params['actor']['w2']) + params['actor']['b2']
    return nn.softmax(logits)

def critic_forward(params, state):
    # Value head: a scalar state-value estimate, shape (..., 1).
    x = nn.relu(jnp.dot(state, params['critic']['w1']) + params['critic']['b1'])
    return jnp.dot(x, params['critic']['w2']) + params['critic']['b2']
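
Acting with this policy means sampling from the categorical distribution the actor outputs. jax.random.categorical expects logits, so the probabilities are passed through jnp.log here (a sketch; demo_params is a throwaway name):

python
key = jax.random.PRNGKey(0)
demo_params = init_policy_network(key, state_dim=4, action_dim=2)

state = jnp.ones(4)
probs = actor_forward(demo_params, state)   # action distribution, sums to 1
value = critic_forward(demo_params, state)  # state-value estimate, shape (1,)

key, subkey = jax.random.split(key)
action = jax.random.categorical(subkey, jnp.log(probs))  # sample an action index
print(int(action), float(value[0]))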

PPO Training #

python
import jax

def ppo_loss(params, states, actions, advantages, returns, old_probs, clip_ratio=0.2):
    probs = actor_forward(params, states)
    
    # Probability of each taken action under the new and the old policy.
    action_probs = probs[jnp.arange(len(actions)), actions]
    old_action_probs = old_probs[jnp.arange(len(actions)), actions]
    
    ratio = action_probs / (old_action_probs + 1e-8)
    
    # Clipped surrogate objective.
    clipped_ratio = jnp.clip(ratio, 1 - clip_ratio, 1 + clip_ratio)
    policy_loss = -jnp.mean(jnp.minimum(ratio * advantages, clipped_ratio * advantages))
    
    # The critic regresses to the returns (not the advantages).
    values = critic_forward(params, states).squeeze()
    value_loss = jnp.mean((values - returns) ** 2)
    
    return policy_loss + 0.5 * value_loss

@jax.jit
def ppo_train_step(params, states, actions, advantages, returns, old_probs, lr=0.001):
    loss, grads = jax.value_and_grad(ppo_loss)(params, states, actions, advantages, returns, old_probs)
    new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return new_params, loss
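
ppo_train_step assumes the advantages and returns are already computed from a collected rollout. A common way to get them is generalized advantage estimation (GAE); below is a minimal sketch, assuming one finished trajectory of per-step rewards, value estimates, and done flags as plain Python lists. compute_gae is a helper introduced here, not part of the code above.

python
def compute_gae(rewards, values, dones, gamma=0.99, lam=0.95):
    # Walk the trajectory backwards: delta_t = r_t + gamma * V(s_{t+1}) - V(s_t).
    advantages = []
    gae = 0.0
    next_value = 0.0  # bootstrap value after the final step (0 if the episode ended)
    for r, v, d in zip(reversed(rewards), reversed(values), reversed(dones)):
        delta = r + gamma * next_value * (1 - d) - v
        gae = delta + gamma * lam * (1 - d) * gae
        advantages.append(gae)
        next_value = v
    advantages = jnp.array(advantages[::-1])
    returns = advantages + jnp.array(values)  # value-function targets
    # Normalizing advantages tends to stabilize the policy update.
    advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
    return advantages, returns

# Typical usage (old_probs captured before the update):
# old_probs = actor_forward(params, states)
# advantages, returns = compute_gae(rewards, values, dones)
# params, loss = ppo_train_step(params, states, actions, advantages, returns, old_probs)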

Next Steps #

Congratulations on finishing the hands-on JAX reinforcement learning section! You now have JAX's core features under your belt and are ready to start your own projects.

Last updated: 2026-04-04