Text Generation in Practice #

Overview #

This section implements a simple Transformer model in JAX for character-level text generation.

Data Preparation #

Text Data #

python
import jax
import jax.numpy as jnp

text = "hello world this is a simple text generation example with jax"

# Character-level vocabulary and index lookup tables
chars = sorted(set(text))
vocab_size = len(chars)
char_to_idx = {c: i for i, c in enumerate(chars)}
idx_to_char = {i: c for i, c in enumerate(chars)}

print(f"词汇表大小: {vocab_size}")
print(f"字符: {chars}")

def encode(text):
    return [char_to_idx[c] for c in text]

def decode(indices):
    return ''.join([idx_to_char[i] for i in indices])

encoded = encode(text)
print(f"编码后: {encoded[:10]}...")

# Build (input, target) pairs; each target is the input shifted one character left
def create_sequences(data, seq_length):
    sequences = []
    targets = []
    for i in range(len(data) - seq_length):
        sequences.append(data[i:i+seq_length])
        targets.append(data[i+1:i+seq_length+1])
    return jnp.array(sequences), jnp.array(targets)

seq_length = 10
x_train, y_train = create_sequences(encoded, seq_length)
print(f"训练数据形状: {x_train.shape}, {y_train.shape}")

Model Definition #

Transformer Components #

python
import jax
import jax.numpy as jnp
import jax.nn as nn

def init_transformer_params(key, vocab_size, embed_dim, num_heads, ff_dim, num_layers):
    params = {}
    keys = jax.random.split(key, num_layers * 4 + 2)
    
    # Token embedding table
    params['embedding'] = jax.random.normal(keys[0], (vocab_size, embed_dim)) * 0.01
    
    for i in range(num_layers):
        # Per-layer weights: fused Q/K/V projection, attention output, feed-forward
        params[f'layer_{i}_attn_w'] = jax.random.normal(keys[i*4+1], (embed_dim, 3 * embed_dim)) * 0.01
        params[f'layer_{i}_attn_out'] = jax.random.normal(keys[i*4+2], (embed_dim, embed_dim)) * 0.01
        params[f'layer_{i}_ff1'] = jax.random.normal(keys[i*4+3], (embed_dim, ff_dim)) * 0.01
        params[f'layer_{i}_ff2'] = jax.random.normal(keys[i*4+4], (ff_dim, embed_dim)) * 0.01
    
    params['output'] = jax.random.normal(keys[-1], (embed_dim, vocab_size)) * 0.01
    
    return params

def attention(q, k, v, mask=None):
    # Scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V
    d_k = q.shape[-1]
    scores = jnp.matmul(q, jnp.swapaxes(k, -2, -1)) / jnp.sqrt(d_k)

    # Additive mask: masked positions get a large negative score before softmax
    if mask is not None:
        scores = scores + mask * -1e9

    attn_weights = nn.softmax(scores, axis=-1)
    return jnp.matmul(attn_weights, v)

def transformer_forward(params, x, embed_dim, num_heads, num_layers):
    # Note: this simplified model uses single-head attention; num_heads is
    # accepted for API symmetry but not used below.
    seq_len = x.shape[1]

    # Look up token embeddings
    x = params['embedding'][x]

    # Add sinusoidal positional encodings so the model can distinguish token order
    pos = jnp.arange(seq_len)[:, None]
    div = jnp.exp(-jnp.log(10000.0) * jnp.arange(0, embed_dim, 2) / embed_dim)
    pe = jnp.zeros((seq_len, embed_dim))
    pe = pe.at[:, 0::2].set(jnp.sin(pos * div))
    pe = pe.at[:, 1::2].set(jnp.cos(pos * div))
    x = x + pe

    # Causal mask: ones above the diagonal mark future positions
    mask = jnp.triu(jnp.ones((seq_len, seq_len)), k=1)

    for i in range(num_layers):
        # Self-attention block with residual connection
        qkv = jnp.dot(x, params[f'layer_{i}_attn_w'])
        q, k, v = jnp.split(qkv, 3, axis=-1)

        attn_out = attention(q, k, v, mask)
        x = x + jnp.dot(attn_out, params[f'layer_{i}_attn_out'])
        # jax.nn has no layer_norm; standardize is layer norm without scale/bias
        x = nn.standardize(x, axis=-1)

        # Feed-forward block with residual connection
        ff_out = nn.relu(jnp.dot(x, params[f'layer_{i}_ff1']))
        ff_out = jnp.dot(ff_out, params[f'layer_{i}_ff2'])
        x = x + ff_out
        x = nn.standardize(x, axis=-1)

    logits = jnp.dot(x, params['output'])
    return logits
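
Before training, a quick forward pass on freshly initialized parameters verifies the wiring; this sketch initializes the `params` used in the rest of the section:

python
key = jax.random.PRNGKey(0)
params = init_transformer_params(key, vocab_size, embed_dim=64, num_heads=2, ff_dim=128, num_layers=2)

logits = transformer_forward(params, x_train[:4], embed_dim=64, num_heads=2, num_layers=2)
print(logits.shape)  # expected: (4, seq_length, vocab_size)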

Training #

Training Step #

python
import jax
from functools import partial

def loss_fn(params, x, y, embed_dim, num_heads, num_layers):
    # Cross-entropy over next-character predictions
    logits = transformer_forward(params, x, embed_dim, num_heads, num_layers)
    log_probs = nn.log_softmax(logits)
    one_hot = nn.one_hot(y, vocab_size)
    return -jnp.mean(jnp.sum(one_hot * log_probs, axis=-1))

# The model dimensions must be static under jit, since num_layers drives a
# Python loop inside the forward pass; pass them as keyword arguments.
@partial(jax.jit, static_argnames=('embed_dim', 'num_heads', 'num_layers'))
def train_step(params, x, y, embed_dim, num_heads, num_layers, lr=0.001):
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y, embed_dim, num_heads, num_layers)
    new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return new_params, loss
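
With the training step defined, a minimal loop fits the model before sampling. This is a sketch that reuses the `params` initialized above; the toy dataset fits in one full batch, and the epoch count is an illustrative choice:

python
# Full-batch gradient descent on the toy dataset
for epoch in range(500):
    params, loss = train_step(params, x_train, y_train,
                              embed_dim=64, num_heads=2, num_layers=2)
    if epoch % 100 == 0:
        print(f"Epoch {epoch}, loss: {loss:.4f}")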

Text Generation #

python
def generate(params, key, start_text, length, embed_dim, num_heads, num_layers, temperature=1.0):
    current = encode(start_text)

    for _ in range(length):
        # Condition on (at most) the last seq_length characters
        x = jnp.array([current[-seq_length:]])
        logits = transformer_forward(params, x, embed_dim, num_heads, num_layers)

        # Temperature-scale the logits for the last position
        logits = logits[0, -1, :] / temperature

        # Sample the next character with a fresh subkey each step
        key, subkey = jax.random.split(key)
        next_char = jax.random.categorical(subkey, logits)
        current.append(int(next_char))

    return decode(current)

# Sample from the trained parameters
key = jax.random.PRNGKey(42)
generated = generate(params, key, "hello", 20, embed_dim=64, num_heads=2, num_layers=2)
print(f"Generated text: {generated}")

Next Steps #

Now that you have completed this text generation walkthrough, continue to Reinforcement Learning and explore RL algorithms!

Last updated: 2026-04-04