文本生成实战 #
概述 #
本节使用 JAX 实现一个简单的 Transformer 模型进行文本生成。
数据准备 #
文本数据 #
python
import jax
import jax.numpy as jnp
# Toy corpus: one short sentence used to build a character-level vocabulary.
text = "hello world this is a simple text generation example with jax"

# Sorted character vocabulary plus both lookup tables (char <-> index).
chars = sorted(set(text))
vocab_size = len(chars)
char_to_idx = {c: i for i, c in enumerate(chars)}
idx_to_char = {i: c for i, c in enumerate(chars)}

print(f"词汇表大小: {vocab_size}")
print(f"字符: {chars}")


def encode(text):
    """Map a string to its list of vocabulary indices."""
    return [char_to_idx[ch] for ch in text]


def decode(indices):
    """Map a list of vocabulary indices back to a string."""
    return ''.join(idx_to_char[i] for i in indices)


encoded = encode(text)
print(f"编码后: {encoded[:10]}...")
def create_sequences(data, seq_length):
    """Slice *data* into overlapping next-character training windows.

    Each input window is data[i : i+seq_length]; its target is the same
    window shifted one position right.  Returns a pair of arrays, each of
    shape (num_windows, seq_length).
    """
    pairs = [
        (data[start:start + seq_length], data[start + 1:start + seq_length + 1])
        for start in range(len(data) - seq_length)
    ]
    inputs, labels = zip(*pairs) if pairs else ((), ())
    return jnp.array(inputs), jnp.array(labels)
# Context-window length: each training example is 10 characters.
seq_length = 10
# Build (input, next-char target) window pairs from the encoded corpus.
x_train, y_train = create_sequences(encoded, seq_length)
print(f"训练数据形状: {x_train.shape}, {y_train.shape}")
模型定义 #
Transformer 组件 #
python
import jax
import jax.numpy as jnp
import jax.nn as nn
def init_transformer_params(key, vocab_size, embed_dim, num_heads, ff_dim, num_layers):
    """Randomly initialize every weight matrix of the toy Transformer.

    Parameter layout (all entries scaled by 0.01):
      embedding          : (vocab_size, embed_dim)
      layer_{i}_attn_w   : (embed_dim, 3*embed_dim)  fused Q/K/V projection
      layer_{i}_attn_out : (embed_dim, embed_dim)
      layer_{i}_ff1      : (embed_dim, ff_dim)
      layer_{i}_ff2      : (ff_dim, embed_dim)
      output             : (embed_dim, vocab_size)

    NOTE: num_heads is accepted for interface symmetry; no per-head
    weights are created here.
    """
    scale = 0.01
    keys = jax.random.split(key, num_layers * 4 + 2)

    def _normal(k, shape):
        # Small random init so early logits stay near zero.
        return jax.random.normal(k, shape) * scale

    params = {'embedding': _normal(keys[0], (vocab_size, embed_dim))}
    for i in range(num_layers):
        base = i * 4
        params[f'layer_{i}_attn_w'] = _normal(keys[base + 1], (embed_dim, 3 * embed_dim))
        params[f'layer_{i}_attn_out'] = _normal(keys[base + 2], (embed_dim, embed_dim))
        params[f'layer_{i}_ff1'] = _normal(keys[base + 3], (embed_dim, ff_dim))
        params[f'layer_{i}_ff2'] = _normal(keys[base + 4], (ff_dim, embed_dim))
    params['output'] = _normal(keys[-1], (embed_dim, vocab_size))
    return params
def attention(q, k, v, mask=None):
    """Scaled dot-product attention.

    q, k, v: (..., seq_len, d_k) arrays; leading batch dims are allowed.
    mask:    optional (seq_len, seq_len) array; positions where mask == 1
             receive a large negative score and thus ~zero attention weight.

    Returns the attention-weighted combination of v, same shape as q.
    """
    d_k = q.shape[-1]
    # BUG FIX: k.transpose(-1, -2) raises on batched (3-D) inputs, because
    # ndarray.transpose in JAX requires a full axis permutation.
    # jnp.swapaxes flips only the last two axes and works for any rank.
    scores = jnp.matmul(q, jnp.swapaxes(k, -1, -2)) / jnp.sqrt(d_k)
    if mask is not None:
        scores = scores + mask * -1e9
    attn_weights = nn.softmax(scores, axis=-1)
    return jnp.matmul(attn_weights, v)
def _layer_norm(x, eps=1e-6):
    """Layer normalization over the last axis (no learned scale/shift).

    BUG FIX: the original called nn.layer_norm, which does not exist in
    jax.nn and raised AttributeError at runtime.
    """
    mean = jnp.mean(x, axis=-1, keepdims=True)
    var = jnp.var(x, axis=-1, keepdims=True)
    return (x - mean) / jnp.sqrt(var + eps)


def transformer_forward(params, x, embed_dim, num_heads, num_layers):
    """Forward pass of a minimal decoder-only Transformer.

    params: dict produced by init_transformer_params.
    x:      (batch, seq_len) int array of token indices.
    Returns (batch, seq_len, vocab_size) logits.

    NOTE(review): num_heads is unused — attention here is single-headed —
    and no positional encoding is added; both kept as-is to preserve the
    original model's behavior.
    """
    seq_len = x.shape[1]
    x = params['embedding'][x]  # token-embedding lookup
    # Causal mask: 1 above the diagonal marks future positions to hide.
    mask = jnp.triu(jnp.ones((seq_len, seq_len)), k=1)
    for i in range(num_layers):
        # Fused QKV projection, then single-head causal self-attention.
        qkv = jnp.dot(x, params[f'layer_{i}_attn_w'])
        q, k, v = jnp.split(qkv, 3, axis=-1)
        attn_out = attention(q, k, v, mask)
        x = x + jnp.dot(attn_out, params[f'layer_{i}_attn_out'])  # residual
        x = _layer_norm(x)
        # Position-wise feed-forward block with residual connection.
        ff_out = nn.relu(jnp.dot(x, params[f'layer_{i}_ff1']))
        ff_out = jnp.dot(ff_out, params[f'layer_{i}_ff2'])
        x = x + ff_out
        x = _layer_norm(x)
    logits = jnp.dot(x, params['output'])
    return logits
训练 #
训练步骤 #
python
import jax
def loss_fn(params, x, y, embed_dim, num_heads, num_layers):
    """Mean per-token cross-entropy of the model's predictions.

    x, y: (batch, seq_len) int arrays; y is x shifted one step left.
    Returns a scalar loss.
    """
    logits = transformer_forward(params, x, embed_dim, num_heads, num_layers)
    log_probs = nn.log_softmax(logits)
    # Derive the class count from the logits rather than reading the
    # module-level vocab_size global, so the loss works for any model size.
    one_hot = nn.one_hot(y, logits.shape[-1])
    return -jnp.mean(jnp.sum(one_hot * log_probs, axis=-1))
def _train_step(params, x, y, embed_dim, num_heads, num_layers, lr=0.001):
    """One SGD step: returns (updated params, scalar loss)."""
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y, embed_dim, num_heads, num_layers)
    # jax.tree_util.tree_map replaces the deprecated/removed jax.tree_map.
    new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return new_params, loss


# BUG FIX: embed_dim / num_heads / num_layers drive Python control flow
# inside the model (e.g. range(num_layers)), so they must be static under
# jit; tracing them as arrays, as the bare @jax.jit did, raises an error.
train_step = jax.jit(_train_step, static_argnums=(3, 4, 5))
文本生成 #
python
def generate(params, start_text, length, embed_dim, num_heads, num_layers, temperature=1.0, key=None):
    """Autoregressively sample `length` characters after `start_text`.

    temperature < 1 sharpens the sampling distribution, > 1 flattens it.
    key: optional JAX PRNG key; defaults to PRNGKey(0) for reproducibility.
    Returns the decoded string including the prompt.
    """
    if key is None:
        key = jax.random.PRNGKey(0)
    current = encode(start_text)
    for _ in range(length):
        # Condition on at most the last seq_length characters.
        x = jnp.array([current[-seq_length:]])
        logits = transformer_forward(params, x, embed_dim, num_heads, num_layers)
        # Temperature-scaled logits of the final position only.
        next_logits = logits[0, -1, :] / temperature
        # FIX: split one base key per step instead of re-seeding
        # PRNGKey(loop_index), which reused identical seeds on every call;
        # also dropped an unused nn.softmax(probs) computation — categorical
        # samples directly from logits.
        key, subkey = jax.random.split(key)
        next_char = jax.random.categorical(subkey, next_logits)
        current.append(int(next_char))
    return decode(current)
# Build an (untrained) model and sample from it; with random weights the
# output is essentially noise — a real pipeline would train first.
key = jax.random.PRNGKey(0)
params = init_transformer_params(key, vocab_size, embed_dim=64, num_heads=2, ff_dim=128, num_layers=2)
generated = generate(params, "hello", 20, embed_dim=64, num_heads=2, num_layers=2)
print(f"生成文本: {generated}")
下一步 #
现在你已经完成了文本生成实战,接下来学习 强化学习,探索 RL 算法!
最后更新:2026-04-04