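TensorFlow.js Model Building #
Model Overview #
A model is the core of a neural network: it defines the network's architecture and holds its parameters. TensorFlow.js provides two main ways to build one, the Sequential API and the Functional API.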
text
┌─────────────────────────────────────────────────────────────┐
│                    Ways to Build a Model                    │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  ┌─────────────────────────┐   ┌─────────────────────────┐  │
│  │ Sequential API          │   │ Functional API          │  │
│  │                         │   │                         │  │
│  │ - Simple and intuitive  │   │ - Flexible, complex     │  │
│  │ - Linear layer stack    │   │ - Multi-input/output    │  │
│  │ - Good for beginners    │   │ - Residual connections  │  │
│  │                         │   │                         │  │
│  └─────────────────────────┘   └─────────────────────────┘  │
│                                                             │
└─────────────────────────────────────────────────────────────┘
Sequential API #
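The Sequential API is the simplest way to build a model. It fits networks in which each layer feeds directly into the next, forming a plain linear stack.
Creating a Sequential Model #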
javascript
const model = tf.sequential();
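Adding Layers #
javascript
// Only the first layer declares inputShape; later layers infer their input from the previous one.
model.add(tf.layers.dense({ units: 64, inputShape: [10], activation: 'relu' }));
model.add(tf.layers.dense({ units: 32, activation: 'relu' }));
model.add(tf.layers.dense({ units: 1, activation: 'sigmoid' }));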
Creating from a Layer List #
javascript
const model = tf.sequential({
  layers: [
    tf.layers.dense({ units: 64, inputShape: [10], activation: 'relu' }),
    tf.layers.dense({ units: 32, activation: 'relu' }),
    tf.layers.dense({ units: 1, activation: 'sigmoid' })
  ]
});
Complete Example #
javascript
const model = tf.sequential();
// Hidden layer 1: 784 input features (e.g. a flattened 28x28 image).
model.add(tf.layers.dense({
  units: 128,
  inputShape: [784],
  activation: 'relu',
  kernelInitializer: 'heNormal'
}));
// Dropout randomly zeroes half of the activations during training.
model.add(tf.layers.dropout({ rate: 0.5 }));
model.add(tf.layers.dense({
  units: 64,
  activation: 'relu'
}));
// Output layer: one probability per class.
model.add(tf.layers.dense({
  units: 10,
  activation: 'softmax'
}));
model.compile({
  optimizer: 'adam',
  loss: 'categoricalCrossentropy',
  metrics: ['accuracy']
});
model.summary();
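`model.summary()` lists a parameter count for each layer. As a rough cross-check, a dense layer with a bias holds `(inputs + 1) × units` parameters, and dropout holds none. The sketch below recomputes the totals for the 784 → 128 → 64 → 10 network above in plain JavaScript; `denseParams` is a hypothetical helper for illustration, not a TensorFlow.js function.

```javascript
// Hypothetical helper: weights (inputs × units) plus one bias per unit.
function denseParams(inputs, units) {
  return (inputs + 1) * units;
}

// Dense layer widths from the example above; dropout adds no parameters.
const widths = [784, 128, 64, 10];
const perLayer = [];
for (let i = 1; i < widths.length; i++) {
  perLayer.push(denseParams(widths[i - 1], widths[i]));
}
const totalParams = perLayer.reduce((sum, n) => sum + n, 0);

console.log(perLayer);    // [100480, 8256, 650]
console.log(totalParams); // 109386
```

If the numbers printed by `model.summary()` differ from a hand calculation like this, the layer sizes or the input shape are probably not what was intended.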
Functional API #
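The Functional API is more flexible: layers are applied to symbolic tensors one call at a time, so the network can branch, merge, and take several inputs or produce several outputs.
Basic Usage #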
javascript
const input = tf.input({ shape: [784] });
const dense1 = tf.layers.dense({ units: 128, activation: 'relu' }).apply(input);
const dropout = tf.layers.dropout({ rate: 0.5 }).apply(dense1);
const dense2 = tf.layers.dense({ units: 64, activation: 'relu' }).apply(dropout);
const output = tf.layers.dense({ units: 10, activation: 'softmax' }).apply(dense2);
const model = tf.model({ inputs: input, outputs: output });
Multi-Input Model #
javascript
const input1 = tf.input({ shape: [64] });
const input2 = tf.input({ shape: [32] });
const branch1 = tf.layers.dense({ units: 32, activation: 'relu' }).apply(input1);
const branch2 = tf.layers.dense({ units: 16, activation: 'relu' }).apply(input2);
const merged = tf.layers.concatenate().apply([branch1, branch2]);
const output = tf.layers.dense({ units: 1, activation: 'sigmoid' }).apply(merged);
const model = tf.model({ inputs: [input1, input2], outputs: output });
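`tf.layers.concatenate()` joins its inputs along the last axis by default, so the merged tensor above carries 32 + 16 = 48 features per sample. The following plain-JavaScript sketch illustrates that shape rule; `concatLastAxis` is a hypothetical helper written for this example, and `null` stands for the batch dimension.

```javascript
// Hypothetical helper: output shape when concatenating along the last axis.
// All leading dimensions must agree; only the last one is summed.
function concatLastAxis(shapes) {
  const [first, ...rest] = shapes;
  const lead = first.slice(0, -1);
  for (const shape of rest) {
    const sameLead = shape.length === first.length &&
      shape.slice(0, -1).every((dim, i) => dim === lead[i]);
    if (!sameLead) {
      throw new Error('Leading dimensions must match');
    }
  }
  const last = shapes.reduce((sum, shape) => sum + shape[shape.length - 1], 0);
  return [...lead, last];
}

// Shapes of branch1 and branch2 from the model above.
const mergedShape = concatLastAxis([[null, 32], [null, 16]]);
console.log(mergedShape); // [null, 48]
```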
Multi-Output Model #
javascript
const input = tf.input({ shape: [100] });
const shared = tf.layers.dense({ units: 64, activation: 'relu' }).apply(input);
const output1 = tf.layers.dense({ units: 10, activation: 'softmax', name: 'classification' }).apply(shared);
const output2 = tf.layers.dense({ units: 1, activation: 'sigmoid', name: 'sentiment' }).apply(shared);
const model = tf.model({ inputs: input, outputs: [output1, output2] });
Residual Connection #
javascript
const input = tf.input({ shape: [64] });
const dense1 = tf.layers.dense({ units: 64, activation: 'relu' }).apply(input);
const dense2 = tf.layers.dense({ units: 64 }).apply(dense1);
const residual = tf.layers.add().apply([input, dense2]);
const output = tf.layers.activation({ activation: 'relu' }).apply(residual);
const model = tf.model({ inputs: input, outputs: output });
Model Configuration #
Compiling the Model #
javascript
model.compile({
  optimizer: 'adam',
  loss: 'categoricalCrossentropy',
  metrics: ['accuracy']
});
Configuring the Optimizer #
javascript
model.compile({
  // Pass an optimizer instance to set the learning rate explicitly.
  optimizer: tf.train.adam(0.001),
  loss: 'meanSquaredError',
  metrics: ['mse']
});
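Custom Loss Function #
javascript
// A loss function receives the target and predicted tensors and returns a tensor.
// This one reproduces mean squared error.
function customLoss(yTrue, yPred) {
  return yTrue.sub(yPred).square().mean();
}
model.compile({
  optimizer: 'adam',
  loss: customLoss
});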
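To make the arithmetic behind `customLoss` concrete, here is the same subtract, square, and average sequence on plain arrays. The values are illustrative only and do not come from a real model.

```javascript
// Plain-array version of: yTrue.sub(yPred).square().mean()
function meanSquaredError(yTrue, yPred) {
  const squared = yTrue.map((y, i) => (y - yPred[i]) ** 2);
  return squared.reduce((sum, v) => sum + v, 0) / squared.length;
}

// Illustrative values (not from a real model).
const yTrue = [1, 0, 1, 1];
const yPred = [0.5, 0.5, 1, 0];

// Differences: 0.5, -0.5, 0, 1 -> squares: 0.25, 0.25, 0, 1 -> mean: 1.5 / 4
const mse = meanSquaredError(yTrue, yPred);
console.log(mse); // 0.375
```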
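Multi-Output Loss Configuration #
javascript
// Keys match the `name` given to each output layer.
// lossWeights is not among the documented compile() options in TensorFlow.js,
// so each loss is scaled by its weight inside the loss function instead.
model.compile({
  optimizer: 'adam',
  loss: {
    classification: (yTrue, yPred) =>
      tf.metrics.categoricalCrossentropy(yTrue, yPred).mul(0.8),
    sentiment: (yTrue, yPred) =>
      tf.metrics.binaryCrossentropy(yTrue, yPred).mul(0.2)
  },
  metrics: ['accuracy']
});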
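The 0.8 and 0.2 weights above express how much each output should contribute to the single value the optimizer minimizes: a weighted sum of the per-output losses. A minimal plain-JavaScript sketch of that combination follows; `totalLoss` is a hypothetical helper and the loss values are made up.

```javascript
// Weighted sum of per-output losses; keys mirror the output layer names above.
function totalLoss(losses, weights) {
  return Object.keys(losses)
    .reduce((sum, name) => sum + weights[name] * losses[name], 0);
}

// Illustrative per-output loss values (not from a real run).
const losses = { classification: 0.5, sentiment: 2.0 };
const weights = { classification: 0.8, sentiment: 0.2 };

// 0.8 * 0.5 + 0.2 * 2.0
const combined = totalLoss(losses, weights);
console.log(combined); // 0.8
```

With these weights, a large sentiment loss moves the total far less than an equally large classification loss, which is the point of weighting.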
Model Information #
Model Summary #
javascript
model.summary();
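Inspecting Layers #
javascript
console.log('Number of layers:', model.layers.length);
model.layers.forEach((layer, index) => {
  console.log(`Layer ${index}: ${layer.name}`);
  // JSON.stringify keeps the null batch dimension visible, e.g. [null,784].
  console.log(`  Input shape: ${JSON.stringify(layer.input.shape)}`);
  console.log(`  Output shape: ${JSON.stringify(layer.outputShape)}`);
  console.log(`  Parameter count: ${layer.countParams()}`);
});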
Getting Weights #
javascript
// getWeights() returns one tensor per kernel and bias, in layer order.
const weights = model.getWeights();
weights.forEach((w, i) => {
  console.log(`Weight ${i}:`, w.shape);
});
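Setting Weights #
javascript
// setWeights() expects tensors in the same order and with the same shapes as getWeights() returns.
const weights = model.getWeights();
model.setWeights(weights);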
Saving and Loading Models #
Saving to LocalStorage #
javascript
await model.save('localstorage://my-model');
Loading from LocalStorage #
javascript
const model = await tf.loadLayersModel('localstorage://my-model');
Saving to IndexedDB #
javascript
await model.save('indexeddb://my-model');
Loading from IndexedDB #
javascript
const model = await tf.loadLayersModel('indexeddb://my-model');
Downloading Model Files #
javascript
await model.save('downloads://my-model');
Loading from a URL #
javascript
const model = await tf.loadLayersModel('https://example.com/model.json');
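Loading from Local Files #
html
<!-- Select model.json together with its weight .bin files. -->
<input type="file" id="model-upload" accept=".json,.bin" multiple>
<script>
  const input = document.getElementById('model-upload');
  input.addEventListener('change', async (e) => {
    // tf.io.browserFiles expects the JSON file first, then the weight files.
    const files = [...e.target.files].sort((a, b) => b.name.endsWith('.json') - a.name.endsWith('.json'));
    const model = await tf.loadLayersModel(tf.io.browserFiles(files));
  });
</script>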
Saving in Node.js #
javascript
// The file:// scheme requires the @tensorflow/tfjs-node package.
await model.save('file:///path/to/model');
Model Prediction #
Single-Sample Prediction #
javascript
const input = tf.tensor([[1, 2, 3, 4]]);
const prediction = model.predict(input);
prediction.print();
Batch Prediction #
javascript
const inputs = tf.tensor([[1, 2, 3, 4], [5, 6, 7, 8]]);
const predictions = model.predict(inputs);
predictions.print();
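Asynchronous Prediction #
javascript
// predict() on a layers model is synchronous; read the result without blocking via data().
const prediction = model.predict(input);
const values = await prediction.data();
console.log(values);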
Getting Prediction Data #
javascript
const prediction = model.predict(input);
const data = prediction.dataSync();
console.log(data);
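`dataSync()` returns the tensor's values as a flat typed array. For a softmax classifier the usual next step is to pick the index of the largest probability. A minimal sketch follows, assuming a 3-class output; the probabilities are made up and `argMax` is a hypothetical helper (for tensors, TensorFlow.js also provides `tf.argMax`).

```javascript
// Hypothetical helper: index of the largest value in a typed or plain array.
function argMax(values) {
  let best = 0;
  for (let i = 1; i < values.length; i++) {
    if (values[i] > values[best]) best = i;
  }
  return best;
}

// Stand-in for prediction.dataSync() on a 3-class softmax output.
const probs = Float32Array.from([0.1, 0.7, 0.2]);
const predictedClass = argMax(probs);
console.log(predictedClass); // 1
```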
Model Evaluation #
Evaluating the Model #
javascript
const result = model.evaluate(xTest, yTest);
console.log('Loss:', result[0].dataSync()[0]);
console.log('Accuracy:', result[1].dataSync()[0]);
Evaluation with Options #
javascript
const result = model.evaluate(xTest, yTest, {
  batchSize: 32,
  verbose: 1
});
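For a classifier, the accuracy value returned by `evaluate` is the fraction of test samples whose predicted class matches the label. The sketch below shows that calculation on plain arrays of class indices; `accuracy` is a hypothetical helper and the data is illustrative.

```javascript
// Hypothetical helper: share of positions where prediction and label agree.
function accuracy(predicted, labels) {
  let correct = 0;
  predicted.forEach((p, i) => {
    if (p === labels[i]) correct++;
  });
  return correct / labels.length;
}

// Illustrative class indices (not from a real model): 3 of 4 match.
const predictedClasses = [2, 0, 1, 1];
const trueClasses = [2, 0, 0, 1];
const acc = accuracy(predictedClasses, trueClasses);
console.log(acc); // 0.75
```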
Visualizing the Model Architecture #
Using tfjs-vis #
javascript
import * as tfvis from '@tensorflow/tfjs-vis';
tfvis.show.modelSummary({ name: 'Model Summary' }, model);
Visualizing a Layer #
javascript
tfvis.show.layer({ name: 'Layer Details' }, model.layers[0]);
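Model Cloning #
Cloning a Model #
javascript
// tf.models.cloneModel is not part of the documented TensorFlow.js API: rebuild from the JSON topology, then copy the weights.
const clonedModel = await tf.models.modelFromJSON(model.toJSON(null, false));
clonedModel.setWeights(model.getWeights());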
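Cloning with a Modified Configuration #
javascript
// Edit the JSON topology before rebuilding; its keys use Keras-style snake_case.
const topology = model.toJSON(null, false);
for (const layer of topology.config.layers) {
  if (layer.class_name === 'Dense') {
    layer.config.activation = 'tanh';
  }
}
const clonedModel = await tf.models.modelFromJSON(topology);
clonedModel.setWeights(model.getWeights());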
Common Model Architectures #
Classification Model #
javascript
// `features` and `numClasses` are placeholders for your dataset's dimensions.
const classificationModel = tf.sequential({
  layers: [
    tf.layers.dense({ units: 128, inputShape: [features], activation: 'relu' }),
    tf.layers.dropout({ rate: 0.5 }),
    tf.layers.dense({ units: 64, activation: 'relu' }),
    tf.layers.dense({ units: numClasses, activation: 'softmax' })
  ]
});
Regression Model #
javascript
const regressionModel = tf.sequential({
  layers: [
    tf.layers.dense({ units: 64, inputShape: [features], activation: 'relu' }),
    tf.layers.dense({ units: 32, activation: 'relu' }),
    // No activation on the output: the layer returns an unbounded real value.
    tf.layers.dense({ units: 1 })
  ]
});
Autoencoder #
javascript
const input = tf.input({ shape: [inputDim] });
// Encoder: compress the input step by step down to a 32-unit bottleneck.
const encoded1 = tf.layers.dense({ units: 128, activation: 'relu' }).apply(input);
const encoded2 = tf.layers.dense({ units: 64, activation: 'relu' }).apply(encoded1);
const bottleneck = tf.layers.dense({ units: 32, activation: 'relu' }).apply(encoded2);
// Decoder: mirror the encoder back up to the input dimension.
const decoded1 = tf.layers.dense({ units: 64, activation: 'relu' }).apply(bottleneck);
const decoded2 = tf.layers.dense({ units: 128, activation: 'relu' }).apply(decoded1);
const decoded = tf.layers.dense({ units: inputDim, activation: 'sigmoid' }).apply(decoded2);
const autoencoder = tf.model({ inputs: input, outputs: decoded });
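Model Debugging #
Checking the Output Shape #
javascript
// Feed one dummy sample through the model to confirm the output shape.
const input = tf.zeros([1, 784]);
const output = model.predict(input);
console.log('Output shape:', output.shape);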
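Checking Gradients #
javascript
const input = tf.tensor([[1, 2, 3, 4]]);
const target = tf.tensor([[0, 1]]);
// The loss closure must return a scalar.
const f = () => {
  const pred = model.predict(input);
  return tf.losses.meanSquaredError(target, pred);
};
// tf.grad differentiates with respect to an input tensor and needs one passed in;
// tf.variableGrads returns the gradient of f for every trainable variable.
const { value, grads } = tf.variableGrads(f);
console.log('Loss:', value.dataSync()[0]);
Object.keys(grads).forEach((name) => console.log(`Gradient ${name}:`, grads[name].shape));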
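A common sanity check for any gradient is to compare it against a finite-difference estimate. The sketch below does this in plain JavaScript for a deliberately tiny case, a one-weight linear model with a mean squared error loss; the model, the data, and the helper names are assumptions made for illustration and are unrelated to the network above.

```javascript
// Mean squared error of a one-weight linear model: pred = w * x.
function loss(w, xs, ys) {
  const sum = xs.reduce((acc, x, i) => acc + (w * x - ys[i]) ** 2, 0);
  return sum / xs.length;
}

// Analytic gradient: d/dw mean((w*x - y)^2) = mean(2 * x * (w*x - y)).
function analyticGrad(w, xs, ys) {
  const sum = xs.reduce((acc, x, i) => acc + 2 * x * (w * x - ys[i]), 0);
  return sum / xs.length;
}

// Central finite difference: (f(w + h) - f(w - h)) / (2h).
function numericGrad(w, xs, ys, h = 1e-4) {
  return (loss(w + h, xs, ys) - loss(w - h, xs, ys)) / (2 * h);
}

// Illustrative data generated by y = 2x; the current weight is 0.5.
const xs = [1, 2, 3, 4];
const ys = [2, 4, 6, 8];
const w = 0.5;

const gradDiff = Math.abs(analyticGrad(w, xs, ys) - numericGrad(w, xs, ys));
console.log(gradDiff < 1e-6); // true
```

The gradient is negative here (the weight is below 2), so a gradient descent step would increase `w`, as expected.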
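Next Steps #
Now that you know how to build models, continue with Layers and Networks for a closer look at the different layer types.
Last updated: 2026-03-29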