TensorFlow.js 模型训练 #

训练流程概述 #

模型训练是机器学习的核心环节：通过优化算法不断调整模型参数，使损失函数最小化。

text
┌─────────────────────────────────────────────────────────────┐
│                      训练流程                                │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  ┌─────────┐    ┌─────────┐    ┌─────────┐    ┌─────────┐  │
│  │ 数据准备 │ -> │ 模型构建 │ -> │ 模型编译 │ -> │ 模型训练 │  │
│  └─────────┘    └─────────┘    └─────────┘    └─────────┘  │
│       │              │              │              │        │
│       v              v              v              v        │
│  ┌─────────┐    ┌─────────┐    ┌─────────┐    ┌─────────┐  │
│  │ 数据增强 │    │ 网络架构 │    │ 损失函数 │    │ 前向传播 │  │
│  │ 数据分割 │    │ 层配置   │    │ 优化器   │    │ 反向传播 │  │
│  │ 数据转换 │    │ 参数初始化│   │ 指标     │    │ 参数更新 │  │
│  └─────────┘    └─────────┘    └─────────┘    └─────────┘  │
│                                                             │
└─────────────────────────────────────────────────────────────┘

数据准备 #

创建张量数据 #

javascript
// Four samples, three features per sample.
const featureRows = [
  [1, 2, 3],
  [4, 5, 6],
  [7, 8, 9],
  [10, 11, 12],
];
// One 0/1 label per sample.
const labelRows = [[0], [1], [0], [1]];

const xs = tf.tensor2d(featureRows);
const ys = tf.tensor2d(labelRows);

数据归一化 #

javascript
/**
 * Min-max normalize a tensor into [0, 1].
 *
 * @param {tf.Tensor} data - Input tensor; min and max are taken over all elements.
 * @returns {tf.Tensor} float32 tensor of the same shape scaled to [0, 1].
 *   When every element is equal the result is all zeros instead of NaN (0 / 0).
 */
function normalize(data) {
  // tidy frees the intermediate min/max/range tensors; only the result survives.
  return tf.tidy(() => {
    // Cast first: tf.div on two int32 tensors performs floor division.
    const values = tf.cast(data, 'float32');
    const min = values.min();
    const range = values.max().sub(min);
    // A zero range (constant input) is replaced by 1 so the output is 0, not NaN.
    const safeRange = tf.where(range.equal(0), tf.onesLike(range), range);
    return values.sub(min).div(safeRange);
  });
}

const sampleValues = tf.tensor([1, 5, 10, 15, 20]);
const normalized = normalize(sampleValues);
normalized.print();

数据标准化 #

javascript
/**
 * Z-score standardize a tensor: subtract the mean, divide by the standard deviation.
 *
 * @param {tf.Tensor} data - Input tensor; statistics are taken over all elements.
 * @returns {tf.Tensor} float32 tensor of the same shape with zero mean and unit
 *   (population) standard deviation. Constant input yields all zeros instead of NaN.
 */
function standardize(data) {
  return tf.tidy(() => {
    const values = tf.cast(data, 'float32');
    // tfjs tensors have no std() method; derive it from tf.moments' variance.
    const { mean, variance } = tf.moments(values);
    const std = variance.sqrt();
    // Guard against division by zero when all elements are equal.
    const safeStd = tf.where(std.equal(0), tf.onesLike(std), std);
    return values.sub(mean).div(safeStd);
  });
}

const rawSeries = tf.tensor([1, 2, 3, 4, 5]);
const standardized = standardize(rawSeries);
standardized.print();

数据分割 #

javascript
/**
 * Randomly split paired feature/label tensors into train and test sets.
 *
 * @param {tf.Tensor} xs - Features; samples along axis 0.
 * @param {tf.Tensor} ys - Labels; samples along axis 0, same count as xs.
 * @param {number} [testRatio=0.2] - Fraction of samples (rounded down) used for testing.
 * @returns {{xTrain: tf.Tensor, yTrain: tf.Tensor, xTest: tf.Tensor, yTest: tf.Tensor}}
 */
function trainTestSplit(xs, ys, testRatio = 0.2) {
  const total = xs.shape[0];
  const trainCount = total - Math.floor(total * testRatio);

  // One shared permutation keeps each x paired with its y.
  const order = tf.util.createShuffledIndices(total);
  const [trainOrder, testOrder] = [order.slice(0, trainCount), order.slice(trainCount)];

  return {
    xTrain: xs.gather(trainOrder),
    yTrain: ys.gather(trainOrder),
    xTest: xs.gather(testOrder),
    yTest: ys.gather(testOrder),
  };
}

使用 tf.data API #

javascript
/**
 * Build a shuffled, batched tf.data pipeline from in-memory tensors.
 *
 * @param {tf.Tensor} xs - Features; each element along axis 0 becomes one `xs` item.
 * @param {tf.Tensor} ys - Labels; each element along axis 0 becomes one `ys` item.
 * @param {number} batchSize - Number of samples per batch.
 * @returns {tf.data.Dataset} Dataset of `{xs, ys}` batches.
 */
function createDataset(xs, ys, batchSize) {
  const features = tf.data.array(xs.arraySync());
  const labels = tf.data.array(ys.arraySync());
  const paired = tf.data.zip({ xs: features, ys: labels });
  return paired.shuffle(1000).batch(batchSize);
}

// Wrap the xs/ys tensors into a batched tf.data pipeline.
const trainBatchSize = 32;
const dataset = createDataset(xs, ys, trainBatchSize);

模型编译 #

基本编译 #

javascript
// Adam with default settings, MSE loss, and accuracy reported alongside the loss.
const basicCompileArgs = {
  optimizer: 'adam',
  loss: 'meanSquaredError',
  metrics: ['accuracy'],
};
model.compile(basicCompileArgs);

配置优化器 #

javascript
// Build the optimizer explicitly so its learning rate (0.001) can be chosen.
const adamOptimizer = tf.train.adam(0.001);
model.compile({
  optimizer: adamOptimizer,
  loss: 'categoricalCrossentropy',
  metrics: ['accuracy'],
});

多输出模型编译 #

javascript
// Per-output settings, keyed by the names of the model's output layers.
const outputLosses = {
  output1: 'categoricalCrossentropy',
  output2: 'meanSquaredError',
};
// Relative contribution of each output's loss to the total training loss.
const outputLossWeights = { output1: 0.8, output2: 0.2 };
const outputMetrics = {
  output1: ['accuracy'],
  output2: ['mse'],
};

model.compile({
  optimizer: 'adam',
  loss: outputLosses,
  lossWeights: outputLossWeights,
  metrics: outputMetrics,
});

常用损失函数 #

javascript
// Each call recompiles the model with a different built-in loss; only the last
// one stays in effect. compile() rejects a config without an optimizer, so each
// call passes 'adam' alongside the loss.
model.compile({ optimizer: 'adam', loss: 'meanSquaredError' });             // regression
model.compile({ optimizer: 'adam', loss: 'meanAbsoluteError' });            // regression, less outlier-sensitive
model.compile({ optimizer: 'adam', loss: 'binaryCrossentropy' });           // targets in [0, 1]
model.compile({ optimizer: 'adam', loss: 'categoricalCrossentropy' });      // one-hot class targets
model.compile({ optimizer: 'adam', loss: 'sparseCategoricalCrossentropy' }); // integer class-index targets

自定义损失函数 #

javascript
/**
 * Mean squared error over every element, written by hand as a custom loss.
 *
 * @param {tf.Tensor} yTrue - Targets.
 * @param {tf.Tensor} yPred - Predictions, same shape as yTrue.
 * @returns {tf.Scalar} Mean of the squared residuals.
 */
function customLoss(yTrue, yPred) {
  const residual = yTrue.sub(yPred);
  const squared = residual.square();
  return squared.mean();
}

// A loss function object can be passed wherever a loss name is accepted.
const customCompileArgs = { optimizer: 'adam', loss: customLoss };
model.compile(customCompileArgs);

模型训练 #

基本训练 #

javascript
const history = await model.fit(xs, ys, { epochs: 10, batchSize: 32 });

// history.history.loss holds one loss value per epoch.
const lossHistory = history.history.loss;
console.log('训练完成');
console.log('最终损失:', lossHistory.at(-1));

完整训练配置 #

javascript
// 100 epochs of shuffled mini-batches of 32; validationSplit holds out 20% of
// the samples for validation.
const fitOptions = {
  epochs: 100,
  batchSize: 32,
  shuffle: true,
  validationSplit: 0.2,
  verbose: 1,
};
const history = await model.fit(xTrain, yTrain, fitOptions);

使用验证数据 #

javascript
// Explicit validation set instead of carving one out of the training data.
const validationData = [xVal, yVal];
const history = await model.fit(xTrain, yTrain, {
  epochs: 50,
  batchSize: 32,
  validationData,
});

使用 Dataset 训练 #

javascript
// Batch size comes from the dataset itself, so only epochs and validation are set here.
const datasetFitArgs = { epochs: 10, validationData: valDataset };
const history = await model.fitDataset(dataset, datasetFitArgs);

训练回调 #

基本回调 #

javascript
// Log every training lifecycle event; epoch numbers are printed 1-based.
const history = await model.fit(xs, ys, {
  epochs: 10,
  callbacks: {
    onTrainBegin: () => {
      console.log('训练开始');
    },
    onTrainEnd: () => {
      console.log('训练结束');
    },
    onEpochBegin: (epoch) => {
      console.log(`Epoch ${epoch + 1} 开始`);
    },
    onEpochEnd: (epoch, logs) => {
      console.log(`Epoch ${epoch + 1} 结束`);
      console.log(`  loss: ${logs.loss.toFixed(4)}`);
      // logs.acc is only reported when an accuracy metric was configured in
      // compile(); skip it instead of crashing on undefined.toFixed().
      if (logs.acc !== undefined) {
        console.log(`  accuracy: ${logs.acc.toFixed(4)}`);
      }
    },
    onBatchBegin: (batch) => {
      console.log(`Batch ${batch} 开始`);
    },
    onBatchEnd: (batch, logs) => {
      console.log(`Batch ${batch} 结束, loss: ${logs.loss.toFixed(4)}`);
    }
  }
});

EarlyStopping(早停) #

javascript
const history = await model.fit(xs, ys, {
  epochs: 100,
  // Hold out 20% of the samples so 'val_loss' is actually reported; without
  // validation data the monitored metric never appears and early stopping has
  // nothing to compare.
  validationSplit: 0.2,
  callbacks: tf.callbacks.earlyStopping({
    monitor: 'val_loss',
    patience: 10,
    minDelta: 0.001,
    mode: 'min'
  })
});

自定义回调类 #

javascript
/**
 * Callback that stops training once an epoch's loss falls below a target.
 *
 * @param {number} [targetLoss=0.1] - Loss threshold that ends training.
 */
class CustomCallback extends tf.Callback {
  constructor(targetLoss = 0.1) {
    super();
    this.targetLoss = targetLoss;
  }

  onEpochEnd(epoch, logs) {
    if (logs.loss < this.targetLoss) {
      console.log('达到目标损失,停止训练');
      // Setting stopTraining on the model asks fit() to end after this epoch.
      this.model.stopTraining = true;
    }
  }
}

// Up to 100 epochs; the callback may end training earlier.
const stopAtTargetLoss = new CustomCallback();
const history = await model.fit(xs, ys, {
  epochs: 100,
  callbacks: [stopAtTargetLoss],
});

模型检查点 #

javascript
// Save a checkpoint after every `checkpointEvery` completed epochs.
const checkpointEvery = 5;

const history = await model.fit(xs, ys, {
  epochs: 10,
  callbacks: {
    onEpochEnd: async (epoch, logs) => {
      // `epoch` is 0-based; count completed epochs so saves happen after
      // epochs 5, 10, ... (including the last one) rather than after the 1st and 6th.
      const completedEpochs = epoch + 1;
      if (completedEpochs % checkpointEvery === 0) {
        await model.save(`localstorage://model-epoch-${completedEpochs}`);
        console.log(`模型已保存: epoch ${completedEpochs}`);
      }
    }
  }
});

学习率调度回调 #

javascript
// Step decay: halve the learning rate every 10 epochs. Start from the rate the
// optimizer was compiled with (falling back to 0.1) instead of an unrelated
// constant, so the first decay does not jump to a different scale.
let learningRate = model.optimizer.learningRate ?? 0.1;

const history = await model.fit(xs, ys, {
  epochs: 50,
  callbacks: {
    onEpochEnd: (epoch, logs) => {
      if (epoch > 0 && epoch % 10 === 0) {
        learningRate *= 0.5;
        // Prefer setLearningRate where the optimizer provides it (e.g. SGD),
        // since it may keep derived state that a plain property write misses.
        if (typeof model.optimizer.setLearningRate === 'function') {
          model.optimizer.setLearningRate(learningRate);
        } else {
          model.optimizer.learningRate = learningRate;
        }
        console.log(`学习率调整为: ${learningRate}`);
      }
    }
  }
});

tfjs-vis 可视化 #

javascript
import * as tfvis from '@tensorflow/tfjs-vis';

const history = await model.fit(xs, ys, {
  epochs: 50,
  // Validation data is required for the val_loss / val_acc curves to have values.
  validationSplit: 0.2,
  callbacks: tfvis.show.fitCallbacks(
    { name: 'Training Performance' },
    ['loss', 'val_loss', 'acc', 'val_acc']
  )
});

模型评估 #

评估模型 #

javascript
const result = model.evaluate(xTest, yTest, {
  batchSize: 32,
  verbose: 1
});

// evaluate() returns a single Scalar when no metrics are configured and an
// array [loss, ...metrics] otherwise; normalize to an array.
const scalars = Array.isArray(result) ? result : [result];
const [testLoss, testAccuracy] = scalars.map((scalar) => scalar.dataSync()[0]);
// The values have been copied out; release the scalar tensors.
tf.dispose(scalars);

console.log('测试损失:', testLoss);
if (testAccuracy !== undefined) {
  console.log('测试准确率:', testAccuracy);
}

混淆矩阵 #

javascript
/**
 * Build a confusion matrix for a classifier.
 *
 * Assumes model.predict(xTest) and yTest both carry class scores / one-hot
 * labels along their last axis (the class index is taken with argMax(-1)).
 *
 * @param {tf.LayersModel} model - Classifier to evaluate.
 * @param {tf.Tensor} xTest - Test inputs.
 * @param {tf.Tensor} yTest - Test labels.
 * @param {number} numClasses - Number of classes (matrix side length).
 * @returns {tf.Tensor2D} int32 [numClasses, numClasses] tensor where entry
 *   [t][p] counts samples of true class t predicted as class p.
 * @throws {RangeError} If a class index is >= numClasses.
 */
function confusionMatrix(model, xTest, yTest, numClasses) {
  // Read both label vectors into plain arrays once; tidy releases the
  // prediction and argMax tensors afterwards.
  const { predLabels, trueLabels } = tf.tidy(() => ({
    predLabels: model.predict(xTest).argMax(-1).arraySync(),
    trueLabels: yTest.argMax(-1).arraySync(),
  }));

  const counts = Array.from({ length: numClasses }, () => new Array(numClasses).fill(0));
  predLabels.forEach((pred, i) => {
    const actual = trueLabels[i];
    if (actual >= numClasses || pred >= numClasses) {
      throw new RangeError(`class index out of range for numClasses=${numClasses}`);
    }
    counts[actual][pred] += 1;
  });

  return tf.tensor2d(counts, [numClasses, numClasses], 'int32');
}

分类报告 #

javascript
/**
 * Per-class precision, recall and F1 for a classifier.
 *
 * Assumes model.predict(xTest) returns [numSamples, numClasses] scores and
 * yTest holds one-hot labels (class index taken with argMax(-1)).
 *
 * @param {tf.LayersModel} model - Classifier to evaluate.
 * @param {tf.Tensor} xTest - Test inputs.
 * @param {tf.Tensor} yTest - Test labels.
 * @returns {Object<number, {precision: number, recall: number, f1: number}>}
 *   Metrics keyed by class index; undefined ratios (0 / 0) are reported as 0.
 */
function classificationReport(model, xTest, yTest) {
  // Copy labels out to JS; tidy frees the prediction and argMax tensors.
  const { numClasses, predLabels, trueLabels } = tf.tidy(() => {
    const predictions = model.predict(xTest);
    return {
      numClasses: predictions.shape[1],
      predLabels: predictions.argMax(-1).arraySync(),
      trueLabels: yTest.argMax(-1).arraySync(),
    };
  });

  // Single pass: a hit is a true positive for its class; a miss is a false
  // positive for the predicted class and a false negative for the true class.
  const tp = new Array(numClasses).fill(0);
  const fp = new Array(numClasses).fill(0);
  const fn = new Array(numClasses).fill(0);
  predLabels.forEach((pred, i) => {
    const actual = trueLabels[i];
    if (pred === actual) {
      tp[pred] += 1;
    } else {
      fp[pred] += 1;
      fn[actual] += 1;
    }
  });

  const report = {};
  for (let c = 0; c < numClasses; c++) {
    // `|| 0` maps NaN from 0 / 0 to 0.
    const precision = tp[c] / (tp[c] + fp[c]) || 0;
    const recall = tp[c] / (tp[c] + fn[c]) || 0;
    const f1 = 2 * precision * recall / (precision + recall) || 0;
    report[c] = { precision, recall, f1 };
  }

  return report;
}

训练技巧 #

学习率预热 #

javascript
const warmupEpochs = 5;
const targetLearningRate = 0.001;

// Linear warmup: one epoch each at target/5, 2*target/5, ..., target.
for (let epoch = 0; epoch < warmupEpochs; epoch++) {
  const lr = targetLearningRate * (epoch + 1) / warmupEpochs;
  // Prefer setLearningRate where the optimizer provides it (e.g. SGD), since
  // it may keep derived state that a plain property write misses.
  if (typeof model.optimizer.setLearningRate === 'function') {
    model.optimizer.setLearningRate(lr);
  } else {
    model.optimizer.learningRate = lr;
  }
  await model.fit(xs, ys, { epochs: 1 });
}

梯度裁剪 #

javascript
// tf.train.adam accepts only (learningRate, beta1, beta2, epsilon); an extra
// argument is ignored, so clipping must be applied explicitly to the gradients.
const clipValue = 1.0;
const optimizer = tf.train.adam(0.001);

model.compile({
  optimizer,
  loss: 'meanSquaredError'
});

/**
 * One optimization step with element-wise gradient clipping to [-clipValue, clipValue].
 *
 * Differentiates with respect to all trainable variables (tf.variableGrads default).
 *
 * @param {tf.Tensor} xBatch - Input batch.
 * @param {tf.Tensor} yBatch - Target batch matching the model output.
 * @returns {tf.Scalar} Mean squared error of the batch, computed before the update.
 */
function trainStepWithClipping(xBatch, yBatch) {
  return tf.tidy(() => {
    const { value: loss, grads } = tf.variableGrads(
      () => tf.losses.meanSquaredError(yBatch, model.predict(xBatch))
    );
    const clippedGrads = Object.fromEntries(
      Object.entries(grads).map(([name, grad]) => [name, tf.clipByValue(grad, -clipValue, clipValue)])
    );
    optimizer.applyGradients(clippedGrads);
    return loss;
  });
}

标签平滑 #

javascript
/**
 * Apply label smoothing: y * (1 - smoothing) + smoothing / numClasses.
 *
 * @param {tf.Tensor2D} labels - [numSamples, numClasses] label tensor.
 * @param {number} [smoothing=0.1] - Probability mass spread evenly over all classes.
 * @returns {tf.Tensor2D} Smoothed labels, same shape as `labels`.
 */
function labelSmoothing(labels, smoothing = 0.1) {
  const [, numClasses] = labels.shape;
  const keepWeight = 1 - smoothing;
  const uniformShare = smoothing / numClasses;
  return labels.mul(keepWeight).add(uniformShare);
}

类别权重 #

javascript
// Weight per class index; classes not listed here default to 1.0.
const classWeights = { 0: 1.0, 1: 2.0, 2: 1.5 };

/**
 * Class-weighted mean squared error.
 *
 * Each sample's MSE (averaged over the last axis) is scaled by the weight of
 * its class, taken as argMax(-1) of yTrue (assumes one-hot labels), and the
 * weighted per-sample losses are averaged.
 *
 * @param {tf.Tensor} yTrue - [batch, numClasses] labels.
 * @param {tf.Tensor} yPred - [batch, numClasses] predictions.
 * @returns {tf.Scalar} Weighted mean loss.
 */
const weightedLoss = (yTrue, yPred) => {
  const numClasses = yTrue.shape[yTrue.shape.length - 1];
  const weightTable = tf.tensor1d(
    Array.from({ length: numClasses }, (_, c) => classWeights[c] ?? 1.0)
  );
  // Look up weights on-device instead of reading labels back with arraySync().
  const sampleWeights = weightTable.gather(yTrue.argMax(-1));
  // Reduce over the class axis only, so each sample keeps its own loss to weight.
  const perSampleLoss = yTrue.sub(yPred).square().mean(-1);
  return perSampleLoss.mul(sampleWeights).mean();
};

训练循环控制 #

手动训练循环 #

javascript
/**
 * Manual mini-batch training loop built on model.trainOnBatch.
 *
 * Batches are taken in order along axis 0 (no shuffling); the average batch
 * loss is logged after each epoch.
 *
 * @param {tf.LayersModel} model - Compiled model.
 * @param {tf.Tensor} xs - Features; samples along axis 0 (sliced as rank 2).
 * @param {tf.Tensor} ys - Labels; samples along axis 0 (sliced as rank 2).
 * @param {number} epochs - Number of passes over the data.
 * @param {number} batchSize - Samples per batch; the last batch may be smaller.
 * @returns {Promise<void>}
 */
async function trainLoop(model, xs, ys, epochs, batchSize) {
  const numSamples = xs.shape[0];
  const numBatches = Math.ceil(numSamples / batchSize);
  // Nothing to train on; avoid logging 0 / 0 = NaN losses.
  if (numBatches === 0) {
    return;
  }

  for (let epoch = 0; epoch < epochs; epoch++) {
    let epochLoss = 0;

    for (let batch = 0; batch < numBatches; batch++) {
      const start = batch * batchSize;
      const size = Math.min(batchSize, numSamples - start);

      // -1 keeps the full extent of the second axis.
      const xBatch = xs.slice([start, 0], [size, -1]);
      const yBatch = ys.slice([start, 0], [size, -1]);

      try {
        // trainOnBatch is async and resolves to the loss, or to
        // [loss, ...metrics] when metrics are configured.
        const result = await model.trainOnBatch(xBatch, yBatch);
        epochLoss += Array.isArray(result) ? result[0] : result;
      } finally {
        xBatch.dispose();
        yBatch.dispose();
      }
    }

    console.log(`Epoch ${epoch + 1}/${epochs}, Loss: ${epochLoss / numBatches}`);
  }
}

trainOnBatch #

javascript
// trainOnBatch returns a Promise; await it to get the numeric loss
// (or [loss, ...metrics] when metrics are configured).
const loss = await model.trainOnBatch(xBatch, yBatch);

fitDataset #

javascript
const history = await model.fitDataset(dataset, {
  epochs: 10,
  validationData: valDataset,
  callbacks: {
    onEpochEnd: (epoch, logs) => {
      // `epoch` is 0-based; print it 1-based like the other training logs.
      console.log(`Epoch ${epoch + 1}: loss = ${logs.loss.toFixed(4)}`);
    }
  }
});

训练监控 #

TensorBoard 风格监控 #

javascript
// Per-epoch history of each tracked quantity.
const metrics = {
  loss: [],
  valLoss: [],
  accuracy: [],
  valAccuracy: [],
};

// Which fit() log key feeds each history array.
const logKeys = {
  loss: 'loss',
  valLoss: 'val_loss',
  accuracy: 'acc',
  valAccuracy: 'val_acc',
};

const history = await model.fit(xs, ys, {
  epochs: 50,
  validationSplit: 0.2,
  callbacks: {
    onEpochEnd: (epoch, logs) => {
      for (const [name, logKey] of Object.entries(logKeys)) {
        metrics[name].push(logs[logKey]);
      }

      // updateCharts is not defined in this snippet; supply your own renderer.
      updateCharts(metrics);
    },
  },
});

内存监控 #

javascript
// Log tensor count and memory once per second. Keep the handle so the monitor
// can be stopped with clearInterval(memoryMonitor) once training is done;
// otherwise it keeps logging for the lifetime of the page/process.
const memoryMonitor = setInterval(() => {
  const { numTensors, numBytes } = tf.memory();
  console.log(`张量数量: ${numTensors}`);
  console.log(`内存使用: ${(numBytes / 1024 / 1024).toFixed(2)} MB`);
}, 1000);

完整训练示例 #

javascript
/**
 * End-to-end example: train a small regression MLP on random data and save it.
 *
 * @returns {Promise<tf.Sequential>} The trained model (also saved to
 *   localstorage://my-model).
 */
async function trainModel() {
  const xs = tf.randomNormal([1000, 10]);
  const ys = tf.randomUniform([1000, 1], 0, 1);

  try {
    const model = tf.sequential({
      layers: [
        tf.layers.dense({ units: 64, inputShape: [10], activation: 'relu' }),
        tf.layers.dropout({ rate: 0.3 }),
        tf.layers.dense({ units: 32, activation: 'relu' }),
        tf.layers.dense({ units: 1 })
      ]
    });

    model.compile({
      optimizer: tf.train.adam(0.001),
      loss: 'meanSquaredError',
      metrics: ['mse']
    });

    await model.fit(xs, ys, {
      epochs: 50,
      batchSize: 32,
      // Guarantees val_loss is present in the epoch logs below.
      validationSplit: 0.2,
      callbacks: {
        onEpochEnd: (epoch, logs) => {
          console.log(`Epoch ${epoch + 1}: loss = ${logs.loss.toFixed(4)}, val_loss = ${logs.val_loss.toFixed(4)}`);
        }
      }
    });

    await model.save('localstorage://my-model');
    console.log('模型已保存');

    return model;
  } finally {
    // The training data is only used inside this function; free it even if
    // fit() or save() throws.
    tf.dispose([xs, ys]);
  }
}

// Surface failures instead of leaving an unhandled promise rejection.
trainModel().catch((err) => {
  console.error('训练失败:', err);
});

下一步 #

现在你已经掌握了模型训练，接下来学习「优化算法」，深入了解各种优化器！

最后更新：2026-03-29