TensorFlow.js 模型训练 #
训练流程概述 #
模型训练是机器学习的核心环节,通过优化算法调整模型参数以最小化损失函数。
text
┌─────────────────────────────────────────────────────────────┐
│ 训练流程 │
├─────────────────────────────────────────────────────────────┤
│ │
│ ┌─────────┐ ┌─────────┐ ┌─────────┐ ┌─────────┐ │
│ │ 数据准备 │ -> │ 模型构建 │ -> │ 模型编译 │ -> │ 模型训练 │ │
│ └─────────┘ └─────────┘ └─────────┘ └─────────┘ │
│ │ │ │ │ │
│ v v v v │
│ ┌─────────┐ ┌─────────┐ ┌─────────┐ ┌─────────┐ │
│ │ 数据增强 │ │ 网络架构 │ │ 损失函数 │ │ 前向传播 │ │
│ │ 数据分割 │ │ 层配置 │ │ 优化器 │ │ 反向传播 │ │
│ │ 数据转换 │ │ 参数初始化│ │ 指标 │ │ 参数更新 │ │
│ └─────────┘ └─────────┘ └─────────┘ └─────────┘ │
│ │
└─────────────────────────────────────────────────────────────┘
数据准备 #
创建张量数据 #
javascript
// Feature matrix: 4 samples x 3 features (shape [4, 3]).
const xs = tf.tensor2d([
[1, 2, 3],
[4, 5, 6],
[7, 8, 9],
[10, 11, 12]
]);
// Binary labels, one per sample (shape [4, 1]).
const ys = tf.tensor2d([
[0],
[1],
[0],
[1]
]);
数据归一化 #
javascript
// Min-max normalization: rescales values into the [0, 1] range.
// Wrapped in tf.tidy so the intermediate min/max tensors (which the
// original version leaked on every call) are disposed automatically.
function normalize(data) {
  return tf.tidy(() => {
    const min = data.min();
    const max = data.max();
    // (x - min) / (max - min); divides by zero if all values are equal.
    return data.sub(min).div(max.sub(min));
  });
}
const normalized = normalize(tf.tensor([1, 5, 10, 15, 20]));
normalized.print();
数据标准化 #
javascript
// Z-score standardization: zero mean, unit variance.
// TF.js has no Tensor.std() method (the original called data.std(),
// which throws); tf.moments provides mean and variance instead.
// tf.tidy disposes all intermediate tensors.
function standardize(data) {
  return tf.tidy(() => {
    const { mean, variance } = tf.moments(data);
    // (x - mean) / sqrt(variance); divides by zero for constant data.
    return data.sub(mean).div(variance.sqrt());
  });
}
const standardized = standardize(tf.tensor([1, 2, 3, 4, 5]));
standardized.print();
数据分割 #
javascript
// Randomly splits (xs, ys) into train/test partitions along axis 0.
// testRatio: fraction of samples held out for the test set.
// Returns { xTrain, yTrain, xTest, yTest }; caller owns these tensors.
function trainTestSplit(xs, ys, testRatio = 0.2) {
  const numSamples = xs.shape[0];
  const numTest = Math.floor(numSamples * testRatio);
  const numTrain = numSamples - numTest;
  // createShuffledIndices returns a Uint32Array; tf.gather expects
  // int32 indices, so convert explicitly instead of passing the raw
  // typed array (whose dtype tfjs does not accept).
  const shuffled = tf.util.createShuffledIndices(numSamples);
  const trainIdx = tf.tensor1d(new Int32Array(shuffled.slice(0, numTrain)), 'int32');
  const testIdx = tf.tensor1d(new Int32Array(shuffled.slice(numTrain)), 'int32');
  const xTrain = xs.gather(trainIdx);
  const yTrain = ys.gather(trainIdx);
  const xTest = xs.gather(testIdx);
  const yTest = ys.gather(testIdx);
  // Dispose the temporary index tensors to avoid leaking memory.
  trainIdx.dispose();
  testIdx.dispose();
  return { xTrain, yTrain, xTest, yTest };
}
使用 tf.data API #
javascript
// Builds a shuffled, batched tf.data.Dataset from in-memory tensors.
// Each element of the resulting dataset has the shape { xs, ys }.
function createDataset(xs, ys, batchSize) {
  const featureDataset = tf.data.array(xs.arraySync());
  const labelDataset = tf.data.array(ys.arraySync());
  const zipped = tf.data.zip({ xs: featureDataset, ys: labelDataset });
  // Shuffle with a 1000-element buffer, then group into mini-batches.
  return zipped.shuffle(1000).batch(batchSize);
}
const dataset = createDataset(xs, ys, 32);
模型编译 #
基本编译 #
javascript
// Basic compilation using string shortcuts for optimizer and loss.
// NOTE(review): 'accuracy' is only meaningful for classification;
// with meanSquaredError (a regression loss) ['mse'] is the more
// sensible metric -- confirm the intent of this example.
model.compile({
optimizer: 'adam',
loss: 'meanSquaredError',
metrics: ['accuracy']
});
配置优化器 #
javascript
// Explicit optimizer instance so hyperparameters can be set directly
// (here: Adam with learning rate 0.001).
model.compile({
optimizer: tf.train.adam(0.001),
loss: 'categoricalCrossentropy',
metrics: ['accuracy']
});
多输出模型编译 #
javascript
// Multi-output model: losses and metrics keyed by output name.
model.compile({
optimizer: 'adam',
loss: {
output1: 'categoricalCrossentropy',
output2: 'meanSquaredError'
},
// Total optimized loss = 0.8 * loss(output1) + 0.2 * loss(output2).
lossWeights: {
output1: 0.8,
output2: 0.2
},
metrics: {
output1: ['accuracy'],
output2: ['mse']
}
});
常用损失函数 #
javascript
// Common built-in loss identifiers. model.compile() requires an
// optimizer as well; the original snippets omitted it and would throw
// at runtime, so each example now includes one.
model.compile({ optimizer: 'adam', loss: 'meanSquaredError' });              // regression
model.compile({ optimizer: 'adam', loss: 'meanAbsoluteError' });             // regression, robust to outliers
model.compile({ optimizer: 'adam', loss: 'binaryCrossentropy' });            // binary classification
model.compile({ optimizer: 'adam', loss: 'categoricalCrossentropy' });       // one-hot multi-class
model.compile({ optimizer: 'adam', loss: 'sparseCategoricalCrossentropy' }); // integer-label multi-class
自定义损失函数 #
javascript
// Custom loss: mean squared error written by hand.
// Any (yTrue, yPred) => scalar-tensor function can serve as a loss.
function customLoss(yTrue, yPred) {
  const error = yTrue.sub(yPred);
  const squaredError = error.square();
  return squaredError.mean();
}
model.compile({
  optimizer: 'adam',
  loss: customLoss
});
模型训练 #
基本训练 #
javascript
// Minimal training run: 10 passes over the data in batches of 32.
// model.fit() is async and resolves to a History object.
const history = await model.fit(xs, ys, {
epochs: 10,
batchSize: 32
});
console.log('训练完成');
// history.history.loss holds one loss value per epoch.
console.log('最终损失:', history.history.loss[history.history.loss.length - 1]);
完整训练配置 #
javascript
// Full configuration:
//   shuffle: reshuffle the training data before each epoch;
//   validationSplit: hold out the last 20% of the data for validation;
//   verbose: 1 logs progress per epoch.
const history = await model.fit(xTrain, yTrain, {
epochs: 100,
batchSize: 32,
shuffle: true,
validationSplit: 0.2,
verbose: 1
});
使用验证数据 #
javascript
// Explicit validation set (used instead of validationSplit).
const history = await model.fit(xTrain, yTrain, {
epochs: 50,
batchSize: 32,
validationData: [xVal, yVal]
});
使用 Dataset 训练 #
javascript
// Training from a tf.data.Dataset: batching is defined by the dataset
// itself, so no batchSize option is passed here.
const history = await model.fitDataset(dataset, {
epochs: 10,
validationData: valDataset
});
训练回调 #
基本回调 #
javascript
// Lifecycle hooks: train -> epoch -> batch, each with begin/end events.
const history = await model.fit(xs, ys, {
epochs: 10,
callbacks: {
onTrainBegin: () => {
console.log('训练开始');
},
onTrainEnd: () => {
console.log('训练结束');
},
// epoch is 0-based.
onEpochBegin: (epoch) => {
console.log(`Epoch ${epoch + 1} 开始`);
},
onEpochEnd: (epoch, logs) => {
console.log(`Epoch ${epoch + 1} 结束`);
console.log(`  loss: ${logs.loss.toFixed(4)}`);
// NOTE(review): logs.acc exists only when an accuracy metric was
// compiled -- confirm for this model, otherwise this line throws.
console.log(`  accuracy: ${logs.acc.toFixed(4)}`);
},
onBatchBegin: (batch) => {
console.log(`Batch ${batch} 开始`);
},
// Per-batch loss logging; very chatty for large datasets.
onBatchEnd: (batch, logs) => {
console.log(`Batch ${batch} 结束, loss: ${logs.loss.toFixed(4)}`);
}
}
});
EarlyStopping(早停) #
javascript
// Early stopping halts training when the monitored metric stops
// improving. monitor: 'val_loss' requires validation data, so
// validationSplit is set here; the original omitted it, leaving
// val_loss undefined so the callback could never trigger.
const history = await model.fit(xs, ys, {
  epochs: 100,
  validationSplit: 0.2,
  callbacks: tf.callbacks.earlyStopping({
    monitor: 'val_loss', // watch validation loss
    patience: 10,        // epochs to wait without improvement
    minDelta: 0.001,     // smallest change counted as improvement
    mode: 'min'          // lower is better
  })
});
自定义回调类 #
javascript
// Subclass-style callback: stop training once loss drops below a target.
// NOTE(review): confirm that tf.Callback is exported by the tfjs build
// in use; some versions expose only tf.CustomCallback for user hooks.
class CustomCallback extends tf.Callback {
onEpochEnd(epoch, logs) {
// this.model is attached by the training loop before callbacks run.
if (logs.loss < 0.1) {
console.log('达到目标损失,停止训练');
this.model.stopTraining = true;
}
}
}
const history = await model.fit(xs, ys, {
epochs: 100,
callbacks: [new CustomCallback()]
});
模型检查点 #
javascript
// Periodic checkpointing: save to LocalStorage every 5 epochs
// (epochs 0, 5, ...). The callback is async; fit() awaits it.
const history = await model.fit(xs, ys, {
epochs: 10,
callbacks: {
onEpochEnd: async (epoch, logs) => {
if (epoch % 5 === 0) {
await model.save(`localstorage://model-epoch-${epoch}`);
console.log(`模型已保存: epoch ${epoch}`);
}
}
}
});
学习率调度回调 #
javascript
// Step decay: halve the learning rate every 10 epochs.
// NOTE(review): assumes the optimizer exposes a writable learningRate
// property -- confirm for the tfjs version/optimizer in use.
let learningRate = 0.1;
const history = await model.fit(xs, ys, {
epochs: 50,
callbacks: {
onEpochEnd: (epoch, logs) => {
if (epoch > 0 && epoch % 10 === 0) {
learningRate *= 0.5;
model.optimizer.learningRate = learningRate;
console.log(`学习率调整为: ${learningRate}`);
}
}
}
});
tfjs-vis 可视化 #
javascript
// tfjs-vis renders live loss/metric charts in the browser during training.
import * as tfvis from '@tensorflow/tfjs-vis';
const history = await model.fit(xs, ys, {
epochs: 50,
callbacks: tfvis.show.fitCallbacks(
{ name: 'Training Performance' },
// Metrics to plot; the val_* entries require validation data.
['loss', 'val_loss', 'acc', 'val_acc']
)
});
模型评估 #
评估模型 #
javascript
// evaluate() computes the loss (and each compiled metric) on test data.
// NOTE(review): with metrics compiled the result is an array of scalar
// tensors, as indexed below; with no metrics it is a single tensor --
// confirm against how this model was compiled.
const result = model.evaluate(xTest, yTest, {
batchSize: 32,
verbose: 1
});
console.log('测试损失:', result[0].dataSync()[0]);
console.log('测试准确率:', result[1].dataSync()[0]);
混淆矩阵 #
javascript
// Builds a (numClasses x numClasses) confusion matrix: rows are true
// labels, columns are predicted labels.
// Fixes over the original:
//  - it reassigned a `const` and misused bufferSync().set() (set()
//    returns undefined and it overwrote counts with 1 instead of
//    incrementing), so it threw / produced wrong counts;
//  - arraySync() was called inside the loop (O(n^2) host transfers);
//  - unused `indices`/`updates` tensors were created and leaked.
function confusionMatrix(model, xTest, yTest, numClasses) {
  const predictions = model.predict(xTest);
  const predT = predictions.argMax(-1);
  const trueT = yTest.argMax(-1);
  // Pull labels to the CPU once, outside the loop.
  const predLabels = predT.arraySync();
  const trueLabels = trueT.arraySync();
  predictions.dispose();
  predT.dispose();
  trueT.dispose();
  // Tally counts in a plain JS 2-D array, then convert once at the end.
  const counts = Array.from({ length: numClasses }, () => new Array(numClasses).fill(0));
  for (let i = 0; i < predLabels.length; i++) {
    counts[trueLabels[i]][predLabels[i]] += 1;
  }
  return tf.tensor2d(counts, [numClasses, numClasses], 'int32');
}
分类报告 #
javascript
// Per-class precision / recall / F1 report.
// Improvements over the original: a single pass over the labels instead
// of three filter() scans per class (O(n*c) array traffic), and the
// intermediate tensors are disposed instead of leaked.
function classificationReport(model, xTest, yTest) {
  const predictions = model.predict(xTest);
  const predT = predictions.argMax(-1);
  const trueT = yTest.argMax(-1);
  const predLabels = predT.arraySync();
  const trueLabels = trueT.arraySync();
  const numClasses = predictions.shape[1];
  predictions.dispose();
  predT.dispose();
  trueT.dispose();
  // Per-class true-positive / false-positive / false-negative tallies.
  const tp = new Array(numClasses).fill(0);
  const fp = new Array(numClasses).fill(0);
  const fn = new Array(numClasses).fill(0);
  for (let i = 0; i < predLabels.length; i++) {
    const p = predLabels[i];
    const t = trueLabels[i];
    if (p === t) {
      tp[t] += 1;
    } else {
      fp[p] += 1; // predicted c but was something else
      fn[t] += 1; // was c but predicted something else
    }
  }
  const report = {};
  for (let c = 0; c < numClasses; c++) {
    // `|| 0` guards the NaN produced by 0/0 when a class is absent.
    const precision = tp[c] / (tp[c] + fp[c]) || 0;
    const recall = tp[c] / (tp[c] + fn[c]) || 0;
    const f1 = 2 * precision * recall / (precision + recall) || 0;
    report[c] = { precision, recall, f1 };
  }
  return report;
}
训练技巧 #
学习率预热 #
javascript
// Learning-rate warmup: ramp linearly up to the target rate over the
// first epochs, running one fit() call per epoch.
// NOTE(review): assumes optimizer.learningRate is writable -- confirm
// for the tfjs version in use.
const warmupEpochs = 5;
const targetLearningRate = 0.001;
for (let epoch = 0; epoch < warmupEpochs; epoch++) {
const lr = targetLearningRate * (epoch + 1) / warmupEpochs;
model.optimizer.learningRate = lr;
await model.fit(xs, ys, { epochs: 1 });
}
梯度裁剪 #
javascript
// Gradient clipping.
// NOTE: tf.train.adam(learningRate, beta1, beta2, epsilon) takes only
// four parameters -- the original passed a 5th "clip" argument that
// JavaScript silently ignores, so no clipping happened at all.
// In TF.js, clip gradients manually with a custom training step:
const optimizer = tf.train.adam(0.001);
function clippedTrainStep(lossFn, clipValue = 1.0) {
  // Compute the loss and gradients w.r.t. all trainable variables.
  const { value, grads } = tf.variableGrads(lossFn);
  const clipped = {};
  for (const name in grads) {
    clipped[name] = tf.clipByValue(grads[name], -clipValue, clipValue);
  }
  optimizer.applyGradients(clipped);
  // Free the raw and clipped gradient tensors once applied.
  tf.dispose(grads);
  tf.dispose(clipped);
  return value; // scalar loss tensor; caller is responsible for disposing
}
标签平滑 #
javascript
// Label smoothing: softens one-hot targets to discourage over-confident
// predictions. Each 1 becomes (1 - smoothing + smoothing/K) and each 0
// becomes smoothing/K, where K is the number of classes.
function labelSmoothing(labels, smoothing = 0.1) {
  const numClasses = labels.shape[1];
  const scaled = labels.mul(1 - smoothing);
  return scaled.add(smoothing / numClasses);
}
类别权重 #
javascript
// Class weights: make errors on under-represented classes count more.
const classWeights = { 0: 1.0, 1: 2.0, 2: 1.5 };
// The original multiplied the already-reduced scalar MSE by the weight
// vector and took the mean, which merely rescales the loss by the
// average weight. Weight each sample's own error instead.
const weightedLoss = (yTrue, yPred) => {
  // One weight per sample, looked up from its true class.
  const weights = tf.tensor(
    yTrue.argMax(-1).arraySync().map((c) => classWeights[c])
  );
  // Per-sample squared error averaged over features, then weighted.
  const perSample = yTrue.sub(yPred).square().mean(-1);
  return perSample.mul(weights).mean();
};
训练循环控制 #
手动训练循环 #
javascript
// Manual mini-batch training loop over (xs, ys).
// model.trainOnBatch() is async in TF.js: the original added the
// Promise itself to epochLoss, corrupting the logged loss -- it must
// be awaited.
async function trainLoop(model, xs, ys, epochs, batchSize) {
  const numBatches = Math.ceil(xs.shape[0] / batchSize);
  for (let epoch = 0; epoch < epochs; epoch++) {
    let epochLoss = 0;
    for (let batch = 0; batch < numBatches; batch++) {
      const start = batch * batchSize;
      const end = Math.min(start + batchSize, xs.shape[0]);
      // Slice the current mini-batch; size -1 means "all remaining columns".
      const xBatch = xs.slice([start, 0], [end - start, -1]);
      const yBatch = ys.slice([start, 0], [end - start, -1]);
      const loss = await model.trainOnBatch(xBatch, yBatch);
      // trainOnBatch resolves to number | number[] (when metrics are compiled).
      epochLoss += Array.isArray(loss) ? loss[0] : loss;
      xBatch.dispose();
      yBatch.dispose();
    }
    console.log(`Epoch ${epoch + 1}/${epochs}, Loss: ${epochLoss / numBatches}`);
  }
}
trainOnBatch #
javascript
// trainOnBatch runs one gradient update and resolves to the loss value;
// it is async in TF.js, so the result must be awaited (the original
// assigned the Promise itself to `loss`).
const loss = await model.trainOnBatch(xBatch, yBatch);
fitDataset #
javascript
// fitDataset with a per-epoch logging callback.
const history = await model.fitDataset(dataset, {
epochs: 10,
validationData: valDataset,
callbacks: {
onEpochEnd: (epoch, logs) => {
console.log(`Epoch ${epoch}: loss = ${logs.loss}`);
}
}
});
训练监控 #
TensorBoard 风格监控 #
javascript
// Collect per-epoch metric curves for custom charting.
const metrics = {
loss: [],
valLoss: [],
accuracy: [],
valAccuracy: []
};
const history = await model.fit(xs, ys, {
epochs: 50,
validationSplit: 0.2,
callbacks: {
onEpochEnd: (epoch, logs) => {
metrics.loss.push(logs.loss);
metrics.valLoss.push(logs.val_loss);
// NOTE(review): logs.acc / logs.val_acc exist only when an accuracy
// metric was compiled -- otherwise these push undefined.
metrics.accuracy.push(logs.acc);
metrics.valAccuracy.push(logs.val_acc);
// updateCharts is assumed to be defined elsewhere in the app.
updateCharts(metrics);
}
}
});
内存监控 #
javascript
// Poll tf.memory() once per second to watch for tensor leaks: a steadily
// growing numTensors during training usually means a missing dispose()
// or tf.tidy().
setInterval(() => {
const mem = tf.memory();
console.log(`张量数量: ${mem.numTensors}`);
console.log(`内存使用: ${(mem.numBytes / 1024 / 1024).toFixed(2)} MB`);
}, 1000);
完整训练示例 #
javascript
// End-to-end example: synthetic data -> build -> compile -> train -> save.
// Fixes over the original: the synthetic data tensors are disposed
// (they were leaked), and the top-level call handles rejections instead
// of leaving a floating promise.
async function trainModel() {
  // 1000 samples with 10 features; regression targets in [0, 1).
  const xs = tf.randomNormal([1000, 10]);
  const ys = tf.randomUniform([1000, 1], 0, 1);
  const model = tf.sequential({
    layers: [
      tf.layers.dense({ units: 64, inputShape: [10], activation: 'relu' }),
      tf.layers.dropout({ rate: 0.3 }), // regularization
      tf.layers.dense({ units: 32, activation: 'relu' }),
      tf.layers.dense({ units: 1 })     // linear output for regression
    ]
  });
  model.compile({
    optimizer: tf.train.adam(0.001),
    loss: 'meanSquaredError',
    metrics: ['mse']
  });
  try {
    await model.fit(xs, ys, {
      epochs: 50,
      batchSize: 32,
      validationSplit: 0.2,
      callbacks: {
        onEpochEnd: (epoch, logs) => {
          console.log(`Epoch ${epoch + 1}: loss = ${logs.loss.toFixed(4)}, val_loss = ${logs.val_loss.toFixed(4)}`);
        }
      }
    });
  } finally {
    // Release the synthetic training data even if fit() throws.
    xs.dispose();
    ys.dispose();
  }
  await model.save('localstorage://my-model');
  console.log('模型已保存');
  return model;
}
// Surface training failures instead of leaving the promise floating.
trainModel().catch((err) => console.error('训练失败:', err));
下一步 #
现在你已经掌握了模型训练,接下来学习 优化算法,深入了解各种优化器!
最后更新:2026-03-29