# Reinforcement Learning in Practice

## Overview

This section implements reinforcement learning algorithms in JAX, covering DQN (Deep Q-Network) and PPO (Proximal Policy Optimization).

## Environment Setup

### A Simple Environment
```python
import jax
import jax.numpy as jnp


class SimpleEnv:
    """A toy environment: the state drifts with the action, and the episode
    ends once the state's components sum to more than 2 in absolute value."""

    def __init__(self, state_dim=4, action_dim=2):
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.state = None

    def reset(self, key):
        # Start from a uniformly random state in [-1, 1].
        self.state = jax.random.uniform(key, (self.state_dim,), minval=-1, maxval=1)
        return self.state

    def step(self, action):
        # Reward depends on the current state; then the state drifts with the action.
        reward = jnp.sum(self.state * (action + 1))
        self.state = self.state + (action - 0.5) * 0.1
        done = jnp.abs(jnp.sum(self.state)) > 2
        return self.state, reward, done


env = SimpleEnv()
key = jax.random.PRNGKey(0)
state = env.reset(key)
print(f"Initial state: {state}")
```
## DQN Implementation

### The Q-Network
```python
import jax
import jax.numpy as jnp
import jax.nn as nn


def init_q_network(key, state_dim, action_dim, hidden_dim=64):
    """Initialize a three-layer MLP mapping states to per-action Q-values."""
    keys = jax.random.split(key, 3)
    return {
        'w1': jax.random.normal(keys[0], (state_dim, hidden_dim)) * 0.01,
        'b1': jnp.zeros(hidden_dim),
        'w2': jax.random.normal(keys[1], (hidden_dim, hidden_dim)) * 0.01,
        'b2': jnp.zeros(hidden_dim),
        'w3': jax.random.normal(keys[2], (hidden_dim, action_dim)) * 0.01,
        'b3': jnp.zeros(action_dim),
    }


def q_forward(params, state):
    """Forward pass: two ReLU hidden layers, linear output head."""
    x = nn.relu(jnp.dot(state, params['w1']) + params['b1'])
    x = nn.relu(jnp.dot(x, params['w2']) + params['b2'])
    return jnp.dot(x, params['w3']) + params['b3']
```
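A quick shape check (reusing the initializer above) shows that the same forward pass handles a single state or a whole batch, because `jnp.dot` contracts over the last axis of its first argument:

```python
# Shape check: q_forward works for a single state and for a batch of states.
key = jax.random.PRNGKey(1)
params = init_q_network(key, state_dim=4, action_dim=2)
print(q_forward(params, jnp.ones(4)).shape)        # (2,)
print(q_forward(params, jnp.ones((32, 4))).shape)  # (32, 2)
```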
### Experience Replay
```python
from collections import deque
import random

import jax.numpy as jnp


class ReplayBuffer:
    """A fixed-capacity FIFO buffer of (s, a, r, s', done) transitions."""

    def __init__(self, capacity=10000):
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size):
        # Uniformly sample a batch and stack each field into an array.
        batch = random.sample(self.buffer, batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)
        return (
            jnp.array(states),
            jnp.array(actions),
            jnp.array(rewards),
            jnp.array(next_states),
            jnp.array(dones),
        )

    def __len__(self):
        return len(self.buffer)
```
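A small usage sketch (with dummy transitions, purely for illustration) shows the round trip through the buffer:

```python
# Illustrative round trip: push dummy transitions, then sample a batch.
buffer = ReplayBuffer(capacity=100)
dummy_state = jnp.zeros(4)
for _ in range(40):
    buffer.push(dummy_state, 0, 1.0, dummy_state, False)
states, actions, rewards, next_states, dones = buffer.sample(batch_size=8)
print(states.shape, actions.shape, dones.dtype)  # (8, 4) (8,) bool
```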
### DQN Training
```python
import jax
import jax.numpy as jnp


def dqn_loss(params, target_params, states, actions, rewards, next_states, dones, gamma=0.99):
    """Mean squared TD error against a frozen target network."""
    q_values = q_forward(params, states)
    q_value = q_values[jnp.arange(len(actions)), actions]
    next_q_values = q_forward(target_params, next_states)
    max_next_q = jnp.max(next_q_values, axis=-1)
    # Bootstrapped target; (1 - dones) zeroes out the bootstrap at episode ends.
    target_q = rewards + gamma * max_next_q * (1 - dones)
    return jnp.mean((q_value - target_q) ** 2)


@jax.jit
def dqn_train_step(params, target_params, states, actions, rewards, next_states, dones, lr=0.001):
    loss, grads = jax.value_and_grad(dqn_loss)(params, target_params, states, actions, rewards, next_states, dones)
    # Plain SGD update over the parameter pytree.
    new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return new_params, loss


def select_action(params, state, key, epsilon=0.1, num_actions=2):
    """Epsilon-greedy action selection. The key is split so the explore/exploit
    decision and the random action draw use independent randomness."""
    explore_key, action_key = jax.random.split(key)
    if jax.random.uniform(explore_key) < epsilon:
        return jax.random.randint(action_key, (), 0, num_actions)
    q_values = q_forward(params, state)
    return jnp.argmax(q_values)
```
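Because `select_action` branches with a Python `if` on a traced value, it cannot be wrapped in `jax.jit` as-is. If you want a fully traceable variant, one option (a sketch, not required by the training loop below) is to compute both branches and pick one with `jnp.where`:

```python
from functools import partial


# Traceable epsilon-greedy variant (a sketch): both branches are computed,
# and jnp.where selects one, so the whole function can be jitted.
@partial(jax.jit, static_argnames=('num_actions',))
def select_action_jit(params, state, key, epsilon=0.1, num_actions=2):
    explore_key, action_key = jax.random.split(key)
    random_action = jax.random.randint(action_key, (), 0, num_actions)
    greedy_action = jnp.argmax(q_forward(params, state))
    explore = jax.random.uniform(explore_key) < epsilon
    return jnp.where(explore, random_action, greedy_action)
```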
### The DQN Training Loop
```python
def train_dqn(env, params, episodes=100, batch_size=32,
              epsilon_start=1.0, epsilon_end=0.01, epsilon_decay=0.995,
              max_steps=200):
    buffer = ReplayBuffer()
    target_params = params
    epsilon = epsilon_start
    key = jax.random.PRNGKey(0)
    for episode in range(episodes):
        key, subkey = jax.random.split(key)
        state = env.reset(subkey)
        total_reward = 0.0
        # Cap episode length so a never-terminating episode cannot stall training.
        for _ in range(max_steps):
            key, subkey = jax.random.split(key)
            action = select_action(params, state, subkey, epsilon)
            next_state, reward, done = env.step(int(action))
            buffer.push(state, action, reward, next_state, done)
            total_reward += float(reward)
            if len(buffer) >= batch_size:
                states, actions, rewards, next_states, dones = buffer.sample(batch_size)
                params, loss = dqn_train_step(params, target_params, states, actions, rewards, next_states, dones)
            state = next_state
            if done:
                break
        epsilon = max(epsilon_end, epsilon * epsilon_decay)
        # Periodically sync the frozen target network with the online network.
        if episode % 10 == 0:
            target_params = params
        if episode % 20 == 0:
            print(f"Episode {episode}: total_reward={total_reward:.2f}, epsilon={epsilon:.3f}")
    return params


key = jax.random.PRNGKey(0)
params = init_q_network(key, state_dim=4, action_dim=2)
params = train_dqn(env, params, episodes=100)
```
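After training, a greedy rollout (epsilon = 0) gives a rough read on what the agent learned. A minimal evaluation sketch, reusing `select_action` and capping the episode at an arbitrary 200 steps:

```python
# Greedy-evaluation sketch (illustrative; max_steps is an arbitrary cap).
def evaluate(env, params, key, max_steps=200):
    key, reset_key = jax.random.split(key)
    state = env.reset(reset_key)
    total_reward = 0.0
    for _ in range(max_steps):
        key, action_key = jax.random.split(key)
        action = select_action(params, state, action_key, epsilon=0.0)
        state, reward, done = env.step(int(action))
        total_reward += float(reward)
        if done:
            break
    return total_reward


print(f"greedy return: {evaluate(env, params, jax.random.PRNGKey(123)):.2f}")
```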
## PPO Implementation

### The Policy Network
```python
import jax
import jax.numpy as jnp
import jax.nn as nn


def init_policy_network(key, state_dim, action_dim, hidden_dim=64):
    """Initialize separate actor (policy) and critic (value) MLPs."""
    keys = jax.random.split(key, 4)
    return {
        'actor': {
            'w1': jax.random.normal(keys[0], (state_dim, hidden_dim)) * 0.01,
            'b1': jnp.zeros(hidden_dim),
            'w2': jax.random.normal(keys[1], (hidden_dim, action_dim)) * 0.01,
            'b2': jnp.zeros(action_dim),
        },
        'critic': {
            'w1': jax.random.normal(keys[2], (state_dim, hidden_dim)) * 0.01,
            'b1': jnp.zeros(hidden_dim),
            'w2': jax.random.normal(keys[3], (hidden_dim, 1)) * 0.01,
            'b2': jnp.zeros(1),
        },
    }


def actor_forward(params, state):
    """Return a probability distribution over actions."""
    x = nn.relu(jnp.dot(state, params['actor']['w1']) + params['actor']['b1'])
    logits = jnp.dot(x, params['actor']['w2']) + params['actor']['b2']
    return nn.softmax(logits)


def critic_forward(params, state):
    """Return the scalar state-value estimate."""
    x = nn.relu(jnp.dot(state, params['critic']['w1']) + params['critic']['b1'])
    return jnp.dot(x, params['critic']['w2']) + params['critic']['b2']
```
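During rollout collection, actions are sampled from the actor's categorical distribution. One way to do this (a sketch; note that `jax.random.categorical` expects logits, so we take the log of the probabilities):

```python
# Sampling an action from the actor's output distribution (a sketch).
key = jax.random.PRNGKey(2)
params = init_policy_network(key, state_dim=4, action_dim=2)
state = jnp.ones(4)
probs = actor_forward(params, state)
key, sample_key = jax.random.split(key)
action = jax.random.categorical(sample_key, jnp.log(probs + 1e-8))
value = critic_forward(params, state)
print(int(action), float(value.squeeze()))
```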
### PPO Training
```python
import jax
import jax.numpy as jnp


def ppo_loss(params, states, actions, advantages, returns, old_probs, clip_ratio=0.2):
    """Clipped PPO surrogate loss plus a value-function loss.

    Note: the critic regresses onto the empirical returns, not the advantages."""
    probs = actor_forward(params, states)
    action_probs = probs[jnp.arange(len(actions)), actions]
    old_action_probs = old_probs[jnp.arange(len(actions)), actions]
    # Importance ratio between the current policy and the data-collecting policy.
    ratio = action_probs / (old_action_probs + 1e-8)
    clipped_ratio = jnp.clip(ratio, 1 - clip_ratio, 1 + clip_ratio)
    policy_loss = -jnp.mean(jnp.minimum(ratio * advantages, clipped_ratio * advantages))
    values = critic_forward(params, states).squeeze()
    value_loss = jnp.mean((values - returns) ** 2)
    return policy_loss + 0.5 * value_loss


@jax.jit
def ppo_train_step(params, states, actions, advantages, returns, old_probs, lr=0.001):
    loss, grads = jax.value_and_grad(ppo_loss)(params, states, actions, advantages, returns, old_probs)
    new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return new_params, loss
```
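The `advantages` and `returns` fed into `ppo_loss` have to be computed from collected rollouts; PPO is most often paired with Generalized Advantage Estimation (GAE). A minimal sketch, assuming per-step `rewards`, `values` (critic estimates), and `dones` lists from a rollout, plus the critic's value `last_value` for the final state:

```python
# GAE sketch: walk the trajectory backwards, accumulating discounted TD errors.
def compute_gae(rewards, values, dones, last_value, gamma=0.99, lam=0.95):
    advantages = []
    gae = 0.0
    next_value = last_value
    for t in reversed(range(len(rewards))):
        mask = 1.0 - float(dones[t])          # zero the bootstrap at episode ends
        delta = rewards[t] + gamma * next_value * mask - values[t]
        gae = delta + gamma * lam * mask * gae
        advantages.append(gae)
        next_value = values[t]
    advantages = jnp.array(advantages[::-1])
    returns = advantages + jnp.array(values)  # lambda-return targets for the critic
    return advantages, returns
```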
## Next Steps

Congratulations, you have completed the JAX reinforcement learning walkthrough! With these core pieces of JAX in hand, you are ready to start building your own projects.