TensorFlow.js Recurrent Neural Networks #

RNN Overview #

A recurrent neural network (RNN) is the core technique for processing sequential data: it captures the temporal dependencies within a sequence.

text
┌──────────────────────────────────────────────────────┐
│                   RNN architecture                   │
├──────────────────────────────────────────────────────┤
│                                                      │
│  Unrolled over time steps:                           │
│                                                      │
│     x1      x2      x3      x4                       │
│     ↓       ↓       ↓       ↓                        │
│  ┌─────┐ ┌─────┐ ┌─────┐ ┌─────┐                     │
│  │ RNN │→│ RNN │→│ RNN │→│ RNN │                     │
│  └──┬──┘ └──┬──┘ └──┬──┘ └──┬──┘                     │
│     ↓       ↓       ↓       ↓                        │
│     y1      y2      y3      y4                       │
│                                                      │
│  Key properties:                                     │
│  - Memory: the hidden state carries history forward  │
│  - Sequence modeling: handles variable-length input  │
│  - Weight sharing: the same parameters are reused    │
│    at every time step                                │
│                                                      │
└──────────────────────────────────────────────────────┘
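
To make the weight-sharing recurrence concrete, here is a minimal sketch of one RNN forward pass written directly with tensor ops (the sizes and weight names are illustrative, not part of any API). It is essentially what tf.layers.simpleRNN computes internally, minus trainable weights and configuration options.

javascript
// h_t = tanh(x_t·Wxh + h_{t-1}·Whh + b); the same Wxh/Whh at every step
const [batch, steps, inputDim, units] = [2, 4, 8, 16];
const Wxh = tf.randomNormal([inputDim, units]); // input-to-hidden weights
const Whh = tf.randomNormal([units, units]);    // hidden-to-hidden weights
const b = tf.zeros([units]);
const xs = tf.randomNormal([batch, steps, inputDim]);

let h = tf.zeros([batch, units]); // initial hidden state
for (let t = 0; t < steps; t++) {
  const xt = xs.slice([0, t, 0], [batch, 1, inputDim]).squeeze([1]); // [batch, inputDim]
  h = tf.tanh(tf.matMul(xt, Wxh).add(tf.matMul(h, Whh)).add(b));
}
console.log(h.shape); // [2, 16], the final hidden state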

RNN Basics #

SimpleRNN #

The most basic recurrent layer.

javascript
const rnn = tf.layers.simpleRNN({
  units: 64,              // dimensionality of the hidden state / output
  returnSequences: false, // only return the output of the last time step
  returnState: false,     // do not return the final state separately
  inputShape: [10, 32]    // 10 time steps, 32 features per step
});

Parameters #

| Parameter         | Description                     | Default      |
|-------------------|---------------------------------|--------------|
| units             | Output dimensionality           | required     |
| returnSequences   | Return the full output sequence | false        |
| returnState       | Return the final state          | false        |
| activation        | Activation function             | tanh         |
| kernelInitializer | Kernel weight initializer       | glorotNormal |
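
The effect of returnSequences is easiest to see on the output shapes; a quick check with random input:

javascript
const x = tf.randomNormal([2, 10, 32]); // [batch, timeSteps, features]

const lastOnly = tf.layers.simpleRNN({ units: 64 }).apply(x);
const allSteps = tf.layers.simpleRNN({ units: 64, returnSequences: true }).apply(x);

console.log(lastOnly.shape); // [2, 64]: only the final time step
console.log(allSteps.shape); // [2, 10, 64]: one output per time step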

Usage Example #

javascript
const model = tf.sequential();

model.add(tf.layers.simpleRNN({
  units: 64,
  inputShape: [10, 32]
}));

model.add(tf.layers.dense({ units: 10, activation: 'softmax' }));

model.compile({
  optimizer: 'adam',
  loss: 'categoricalCrossentropy',
  metrics: ['accuracy']
});
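
To sanity-check the expected input and label shapes, the model can be fit on random data (purely illustrative; random labels will not learn anything meaningful):

javascript
const xs = tf.randomNormal([8, 10, 32]); // 8 sequences of 10 steps × 32 features
const ys = tf.oneHot(tf.randomUniform([8], 0, 10, 'int32'), 10); // random one-hot labels

// (inside an async function)
await model.fit(xs, ys, { epochs: 1 });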

LSTM #

Long short-term memory networks mitigate the vanishing-gradient problem that plagues plain RNNs.

LSTM Structure #

text
┌─────────────────────────────────────────────────────────┐
│                       LSTM cell                         │
├─────────────────────────────────────────────────────────┤
│                                                         │
│                  ┌─────────────┐                        │
│                  │ input gate  │                        │
│                  │     (i)     │                        │
│                  └──────┬──────┘                        │
│                         │                               │
│  ┌───────────┐  ┌───────┴───────┐  ┌─────────────┐      │
│  │forget gate│  │  cell state   │  │ output gate │      │
│  │    (f)    │→ │      (C)      │ →│     (o)     │      │
│  └───────────┘  └───────────────┘  └─────────────┘      │
│                                                         │
│  Equations (at step t, with h = h_{t-1}, x = x_t):      │
│  f = σ(Wf·[h, x] + bf)       forget gate                │
│  i = σ(Wi·[h, x] + bi)       input gate                 │
│  C̃ = tanh(WC·[h, x] + bC)    candidate cell state       │
│  C_t = f * C_{t-1} + i * C̃   cell state update          │
│  o = σ(Wo·[h, x] + bo)       output gate                │
│  h_t = o * tanh(C_t)         hidden state               │
│                                                         │
└─────────────────────────────────────────────────────────┘

Creating an LSTM Layer #

javascript
const lstm = tf.layers.lstm({
  units: 128,
  returnSequences: true,
  returnState: false,
  inputShape: [10, 32]
});

A Complete LSTM Model #

javascript
const model = tf.sequential();

model.add(tf.layers.lstm({
  units: 128,
  returnSequences: true, // must be true so the next LSTM receives a sequence
  inputShape: [100, 32]
}));

model.add(tf.layers.lstm({
  units: 64,
  returnSequences: false
}));

model.add(tf.layers.dense({ units: 10, activation: 'softmax' }));

Returning States #

javascript
const lstm = tf.layers.lstm({
  units: 64,
  returnState: true,
  returnSequences: true,
  inputShape: [10, 32]
});

const input = tf.randomNormal([1, 10, 32]);
const [output, stateH, stateC] = lstm.apply(input);

console.log('output shape:', output.shape);
console.log('hidden state shape:', stateH.shape);
console.log('cell state shape:', stateC.shape);

GRU #

The gated recurrent unit is simpler and cheaper than the LSTM while performing comparably on many tasks.

GRU Structure #

text
┌─────────────────────────────────────────────────────────┐
│                        GRU cell                         │
├─────────────────────────────────────────────────────────┤
│                                                         │
│     ┌────────────┐             ┌─────────────┐          │
│     │ reset gate │             │ update gate │          │
│     │    (r)     │             │     (z)     │          │
│     └─────┬──────┘             └──────┬──────┘          │
│           │                           │                 │
│           └─────────────┬─────────────┘                 │
│                         ↓                               │
│               ┌──────────────────┐                      │
│               │ candidate state  │                      │
│               │       (h̃)        │                      │
│               └────────┬─────────┘                      │
│                        ↓                                │
│               ┌──────────────────┐                      │
│               │   hidden state   │                      │
│               │       (h)        │                      │
│               └──────────────────┘                      │
│                                                         │
│  Equations (at step t, with h = h_{t-1}, x = x_t):      │
│  z = σ(Wz·[h, x])         update gate                   │
│  r = σ(Wr·[h, x])         reset gate                    │
│  h̃ = tanh(W·[r*h, x])     candidate state               │
│  h_t = (1-z)*h + z*h̃      hidden state update           │
│                                                         │
└─────────────────────────────────────────────────────────┘

Creating a GRU Layer #

javascript
const gru = tf.layers.gru({
  units: 128,
  returnSequences: true,
  inputShape: [10, 32]
});

GRU Model Example #

javascript
const model = tf.sequential();

model.add(tf.layers.gru({
  units: 64,
  returnSequences: true,
  inputShape: [50, 32]
}));

model.add(tf.layers.gru({
  units: 32
}));

model.add(tf.layers.dense({ units: 1, activation: 'sigmoid' }));
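
The "simpler" claim is easy to quantify: a GRU has three gate computations to the LSTM's four, so for the same units it carries roughly three quarters of the parameters. A quick comparison:

javascript
const lstmModel = tf.sequential({
  layers: [tf.layers.lstm({ units: 64, inputShape: [50, 32] })]
});
const gruModel = tf.sequential({
  layers: [tf.layers.gru({ units: 64, inputShape: [50, 32] })]
});

console.log('LSTM params:', lstmModel.countParams());
console.log('GRU params:', gruModel.countParams()); // ≈ 3/4 of the LSTM's count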

Bidirectional RNNs #

The Bidirectional Wrapper #

javascript
const bidirectional = tf.layers.bidirectional({
  layer: tf.layers.lstm({ units: 64 }),
  mergeMode: 'concat'
});

mergeMode Options #

| Mode   | Description                                |
|--------|--------------------------------------------|
| concat | Concatenate forward and backward outputs   |
| sum    | Sum forward and backward outputs           |
| ave    | Average forward and backward outputs       |
| mul    | Multiply forward and backward outputs      |
| none   | Return the two outputs as a list, unmerged |
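
With the default 'concat' mode, the feature dimension doubles, which matters when sizing downstream layers:

javascript
const input = tf.randomNormal([2, 10, 32]);

const biLSTM = tf.layers.bidirectional({
  layer: tf.layers.lstm({ units: 64, returnSequences: true }),
  mergeMode: 'concat'
});

console.log(biLSTM.apply(input).shape); // [2, 10, 128]: 64 forward + 64 backward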

Bidirectional LSTM Example #

javascript
const model = tf.sequential();

model.add(tf.layers.bidirectional({
  layer: tf.layers.lstm({ units: 64, returnSequences: true }),
  inputShape: [100, 32]
}));

model.add(tf.layers.bidirectional({
  layer: tf.layers.lstm({ units: 32 })
}));

model.add(tf.layers.dense({ units: 10, activation: 'softmax' }));

Embedding Layers #

The Embedding Layer #

Maps integer token indices to dense vectors.

javascript
const embedding = tf.layers.embedding({
  inputDim: 10000,   // vocabulary size (largest index + 1)
  outputDim: 128,    // embedding vector dimensionality
  inputLength: 100   // length of each input sequence
});
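
Applied to a batch of integer token ids, the layer adds a feature axis of size outputDim:

javascript
const tokenIds = tf.randomUniform([4, 100], 0, 10000, 'int32'); // [batch, inputLength]
const vectors = embedding.apply(tokenIds);

console.log(vectors.shape); // [4, 100, 128]: [batch, inputLength, outputDim]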

Text Classification Model #

javascript
const model = tf.sequential();

model.add(tf.layers.embedding({
  inputDim: 10000,
  outputDim: 128,
  inputLength: 100
}));

model.add(tf.layers.lstm({ units: 64 }));

model.add(tf.layers.dense({ units: 1, activation: 'sigmoid' }));

model.compile({
  optimizer: 'adam',
  loss: 'binaryCrossentropy',
  metrics: ['accuracy']
});

Sequence Data Processing #

Text Preprocessing #

javascript
function tokenize(text, vocab, maxLen) {
  const tokens = text.toLowerCase().split(/\s+/);
  // Fall back to the <UNK> index for out-of-vocabulary words (0 if no <UNK> entry)
  const indices = tokens.map(t => vocab[t] ?? vocab['<UNK>'] ?? 0);

  // Pad with 0 up to maxLen, then truncate anything longer
  while (indices.length < maxLen) {
    indices.push(0);
  }

  return indices.slice(0, maxLen);
}

const vocab = { 'hello': 1, 'world': 2, 'tensorflow': 3 };
const sequence = tokenize('Hello TensorFlow World', vocab, 10);
console.log(sequence); // [1, 3, 2, 0, 0, 0, 0, 0, 0, 0]

Building a Vocabulary #

javascript
function buildVocabulary(texts, maxVocabSize) {
  const wordCounts = {};
  
  texts.forEach(text => {
    const words = text.toLowerCase().split(/\s+/);
    words.forEach(word => {
      wordCounts[word] = (wordCounts[word] || 0) + 1;
    });
  });

  const sortedWords = Object.entries(wordCounts)
    .sort((a, b) => b[1] - a[1])
    .slice(0, maxVocabSize);

  const vocab = { '<PAD>': 0, '<UNK>': 1 }; // reserve 0 for padding, 1 for unknowns
  sortedWords.forEach(([word], index) => {
    vocab[word] = index + 2;
  });

  return vocab;
}
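
Combining the two helpers (the exact indices depend on how ties in word frequency sort):

javascript
const texts = ['hello world', 'hello tensorflow', 'tensorflow world'];
const vocab = buildVocabulary(texts, 100);
// e.g. { '<PAD>': 0, '<UNK>': 1, 'hello': 2, 'world': 3, 'tensorflow': 4 }

console.log(tokenize('hello again', vocab, 5)); // [2, 1, 0, 0, 0]: 'again' maps to <UNK>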

Time-Series Data #

javascript
function createSequences(data, seqLength) {
  const xs = [];
  const ys = [];

  for (let i = 0; i < data.length - seqLength; i++) {
    xs.push(data.slice(i, i + seqLength));
    ys.push(data[i + seqLength]);
  }

  return {
    // LSTM layers expect 3D input: [samples, timeSteps, features]
    xs: tf.tensor2d(xs).expandDims(2),
    ys: tf.tensor2d(ys, [ys.length, 1])
  };
}

const data = Array.from({ length: 1000 }, (_, i) => Math.sin(i * 0.1));
const { xs, ys } = createSequences(data, 50); // xs: [950, 50, 1], ys: [950, 1]

Sequence-to-Sequence Models #

Encoder-Decoder Architecture #

javascript
function createSeq2SeqModel(inputVocabSize, outputVocabSize, embeddingDim, hiddenUnits) {
  const encoderInput = tf.input({ shape: [null] });
  
  const encoderEmbedding = tf.layers.embedding({
    inputDim: inputVocabSize,
    outputDim: embeddingDim
  }).apply(encoderInput);
  
  const encoderLSTM = tf.layers.lstm({
    units: hiddenUnits,
    returnState: true
  });
  
  // Discard the outputs; the final states summarize the input sequence
  const [, encoderStateH, encoderStateC] = encoderLSTM.apply(encoderEmbedding);
  
  const decoderInput = tf.input({ shape: [null] });
  
  const decoderEmbedding = tf.layers.embedding({
    inputDim: outputVocabSize,
    outputDim: embeddingDim
  }).apply(decoderInput);
  
  const decoderLSTM = tf.layers.lstm({
    units: hiddenUnits,
    returnSequences: true,
    returnState: true
  });
  
  // Seed the decoder with the encoder's final states
  const [decoderOutput] = decoderLSTM.apply(
    decoderEmbedding,
    { initialState: [encoderStateH, encoderStateC] }
  );
  
  const decoderDense = tf.layers.dense({
    units: outputVocabSize,
    activation: 'softmax'
  }).apply(decoderOutput);
  
  return tf.model({
    inputs: [encoderInput, decoderInput],
    outputs: decoderDense
  });
}
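
Training uses teacher forcing: the decoder input is the target sequence shifted right by one step. A shape-level sketch with random data (vocabulary sizes and lengths are illustrative):

javascript
const model = createSeq2SeqModel(5000, 5000, 128, 256);
model.compile({ optimizer: 'adam', loss: 'categoricalCrossentropy' });

const encoderIds = tf.randomUniform([16, 20], 0, 5000, 'int32'); // source token ids
const decoderIds = tf.randomUniform([16, 22], 0, 5000, 'int32'); // target ids, shifted right
const targets = tf.oneHot(tf.randomUniform([16, 22], 0, 5000, 'int32'), 5000);

// (inside an async function)
await model.fit([encoderIds, decoderIds], targets, { epochs: 1, batchSize: 4 });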

Attention Mechanisms #

Basic Attention #

tf.layers.attention mirrors Keras's Attention layer, but it may not be available in your TensorFlow.js build; verify that it exists on the tf.layers namespace before relying on it. The Keras-style usage looks like this:

javascript
function attention(inputs) {
  // inputs: [query, value] (optionally [query, value, key])
  const attentionLayer = tf.layers.attention({
    useScale: true // learn a scaling factor for the attention scores
  });

  return attentionLayer.apply(inputs);
}

If the layer is missing in your version, build the computation from tensor ops, as in the self-attention example below.

Self-Attention #

Scaled dot-product attention built directly from tensor ops:

javascript
function selfAttention(query, key, value) {
  // query, key, value: [batch, seqLen, dK]
  const dK = query.shape[query.shape.length - 1];

  // Scores = Q·Kᵀ / √dK; transposeB handles the batched transpose
  const scores = tf.matMul(query, key, false, true)
    .div(Math.sqrt(dK));

  // Normalize along the last axis to get attention weights
  const attentionWeights = tf.softmax(scores);

  return tf.matMul(attentionWeights, value);
}
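
For self-attention proper, the same tensor serves as query, key, and value:

javascript
const x = tf.randomNormal([2, 10, 64]); // [batch, seqLen, dK]
const out = selfAttention(x, x, x);

console.log(out.shape); // [2, 10, 64]: same shape, each position now mixes in the others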

Multi-Head Attention #

Keras provides a MultiHeadAttention layer, but it does not appear in the TensorFlow.js layers API as of this writing; the snippet below shows the Keras-style call for reference only. In TF.js, split the projections into heads manually and run the scaled dot-product attention above on each head.

javascript
// Keras-style call, for reference; check your tfjs version before using
const multiHeadAttention = tf.layers.multiHeadAttention({
  numHeads: 8, // number of parallel attention heads
  keyDim: 64   // per-head query/key dimensionality
});

Practical Examples #

Sentiment Analysis #

javascript
function createSentimentModel() {
  const vocabSize = 10000;
  const maxLen = 100;
  const embeddingDim = 128;

  const model = tf.sequential();

  model.add(tf.layers.embedding({
    inputDim: vocabSize,
    outputDim: embeddingDim,
    inputLength: maxLen
  }));

  model.add(tf.layers.bidirectional({
    layer: tf.layers.lstm({ units: 64, returnSequences: true })
  }));

  model.add(tf.layers.globalMaxPooling1d());

  model.add(tf.layers.dense({ units: 64, activation: 'relu' }));
  model.add(tf.layers.dropout({ rate: 0.5 }));
  model.add(tf.layers.dense({ units: 1, activation: 'sigmoid' }));

  model.compile({
    optimizer: tf.train.adam(0.001),
    loss: 'binaryCrossentropy',
    metrics: ['accuracy']
  });

  return model;
}
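
A quick shape check with random data before plugging in a real dataset (illustrative only; random labels will not converge to anything meaningful):

javascript
const model = createSentimentModel();

const xs = tf.randomUniform([32, 100], 0, 10000, 'int32');           // token ids
const ys = tf.randomUniform([32, 1], 0, 2, 'int32').cast('float32'); // 0/1 labels

// (inside an async function)
await model.fit(xs, ys, { epochs: 1, batchSize: 8 });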

Text Generation #

javascript
function createTextGenerationModel() {
  const vocabSize = 10000;
  const seqLength = 50;
  const embeddingDim = 256;

  const model = tf.sequential();

  model.add(tf.layers.embedding({
    inputDim: vocabSize,
    outputDim: embeddingDim,
    inputLength: seqLength
  }));

  model.add(tf.layers.lstm({
    units: 256,
    returnSequences: true
  }));

  model.add(tf.layers.dropout({ rate: 0.3 }));

  model.add(tf.layers.lstm({
    units: 256,
    returnSequences: true
  }));

  // timeDistributed applies the dense layer independently at every time step;
  // in TF.js the wrapped layer is passed in a config object
  model.add(tf.layers.timeDistributed({
    layer: tf.layers.dense({ units: vocabSize, activation: 'softmax' })
  }));

  model.compile({
    optimizer: tf.train.adam(0.001),
    loss: 'categoricalCrossentropy'
  });

  return model;
}
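
Generation then samples one token at a time from the model's last-step distribution. A minimal sampling sketch with temperature (the helper name and flow are illustrative):

javascript
function sampleNext(model, seedIds, temperature = 1.0) {
  // seedIds must have length seqLength (50 here) to match inputLength
  return tf.tidy(() => {
    const input = tf.tensor2d([seedIds], [1, seedIds.length], 'int32');
    const probs = model.predict(input); // [1, seqLength, vocabSize]

    // Take the distribution at the last position and apply temperature
    const last = probs.slice([0, seedIds.length - 1, 0], [1, 1, -1]).squeeze();
    const logits = last.log().div(temperature);

    return tf.multinomial(logits, 1).dataSync()[0]; // sampled token id
  });
}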

Time-Series Forecasting #

javascript
function createTimeSeriesModel() {
  const model = tf.sequential();

  model.add(tf.layers.lstm({
    units: 50,
    returnSequences: true,
    inputShape: [50, 1]
  }));

  model.add(tf.layers.lstm({
    units: 50
  }));

  model.add(tf.layers.dense({ units: 1 }));

  model.compile({
    optimizer: tf.train.adam(0.001),
    loss: 'meanSquaredError'
  });

  return model;
}
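
Wiring this up with the createSequences helper from earlier (sine-wave data, so a few epochs suffice to see the loss drop):

javascript
const data = Array.from({ length: 1000 }, (_, i) => Math.sin(i * 0.1));
const { xs, ys } = createSequences(data, 50); // xs: [950, 50, 1], ys: [950, 1]

const model = createTimeSeriesModel();
// (inside an async function)
await model.fit(xs, ys, { epochs: 5, batchSize: 32, validationSplit: 0.1 });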

Training Tips #

Gradient Clipping #

Keras optimizers accept clipnorm/clipvalue arguments, but TensorFlow.js optimizers do not expose a clipping option: the fifth argument to tf.train.rmsprop is the boolean centered, not a clip threshold. To clip gradients, use a custom training loop. A minimal sketch follows; it clips element-wise, which is simpler than (and not identical to) the global-norm clipping of Keras's clipnorm.
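
javascript
// Custom training step with element-wise gradient clipping.
// Sketch only: model, xs, ys are assumed to exist with compatible shapes,
// and the loss here is MSE; swap in whatever loss your task needs.
const optimizer = tf.train.rmsprop(0.001);

function trainStepWithClipping(model, xs, ys, clipValue = 1.0) {
  // Compute the loss and its gradients w.r.t. all trainable variables
  const { value, grads } = tf.variableGrads(
    () => tf.losses.meanSquaredError(ys, model.apply(xs))
  );

  // Clip every gradient tensor to [-clipValue, clipValue]
  const clipped = {};
  for (const name of Object.keys(grads)) {
    clipped[name] = tf.clipByValue(grads[name], -clipValue, clipValue);
  }

  optimizer.applyGradients(clipped);

  const loss = value.dataSync()[0];
  tf.dispose([value, grads, clipped]);
  return loss;
}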

Sequence Padding #

javascript
function padSequences(sequences, maxLen, padding = 'post') {
  return sequences.map(seq => {
    if (seq.length >= maxLen) {
      return seq.slice(0, maxLen);
    }
    
    const paddingArray = new Array(maxLen - seq.length).fill(0);
    
    if (padding === 'post') {
      return [...seq, ...paddingArray];
    } else {
      return [...paddingArray, ...seq];
    }
  });
}
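
For example (post padding to length 6, then packing into an int32 tensor):

javascript
const padded = padSequences([[5, 8, 2], [3, 1, 4, 9, 7]], 6);
// [[5, 8, 2, 0, 0, 0],
//  [3, 1, 4, 9, 7, 0]]

const batch = tf.tensor2d(padded, [padded.length, 6], 'int32');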

Masking #

javascript
model.add(tf.layers.embedding({
  inputDim: 10000,
  outputDim: 128,
  maskZero: true // treat index 0 as padding
}));

With maskZero enabled, index 0 can no longer be used as a real token, and downstream recurrent layers skip the padded time steps instead of updating on them.

Next Steps #

Now that you have RNNs under your belt, continue with transfer learning to see how to build on pretrained models!

Last updated: 2026-03-29