TensorFlow.js 迁移学习 #
迁移学习概述 #
迁移学习是利用预训练模型的知识来解决新问题的技术,可以大幅减少训练时间和数据需求。
text
┌─────────────────────────────────────────────────────────────┐
│ 迁移学习流程 │
├─────────────────────────────────────────────────────────────┤
│ │
│ 预训练模型 迁移学习 │
│ ┌─────────┐ ┌─────────┐ │
│ │ ImageNet │ ---> │ 目标 │ │
│ │ 大数据集 │ │ 任务 │ │
│ └─────────┘ └─────────┘ │
│ │ │ │
│ ↓ ↓ │
│ ┌─────────┐ ┌─────────┐ │
│ │ 特征提取 │ │ 微调 │ │
│ │ 能力 │ │ 训练 │ │
│ └─────────┘ └─────────┘ │
│ │
│ 优势: │
│ - 减少训练数据需求 │
│ - 加速模型收敛 │
│ - 提高模型性能 │
│ │
└─────────────────────────────────────────────────────────────┘
预训练模型加载 #
加载 TensorFlow.js 模型 #
javascript
// Load a model through tf.loadLayersModel from the URL of its model.json.
// The top-level `await` means this line must run in an ES module or inside an async function.
const model = await tf.loadLayersModel('https://example.com/model.json');
加载 GraphModel #
javascript
// Load a model through tf.loadGraphModel from the URL of its model.json.
// The top-level `await` means this line must run in an ES module or inside an async function.
const model = await tf.loadGraphModel('https://example.com/model.json');
从本地存储加载 #
javascript
// Two alternative storage locations, selected by the URL scheme passed to
// tf.loadLayersModel. Each result gets its own binding: declaring `const model`
// twice in the same scope is a SyntaxError.
const localStorageModel = await tf.loadLayersModel('localstorage://my-model');
const indexedDbModel = await tf.loadLayersModel('indexeddb://my-model');
从文件加载 #
html
<input type="file" id="model-upload" multiple>
<script>
// Wire the #model-upload file picker: whenever the selection changes, hand the
// chosen files to tf.io.browserFiles and load a model from them.
const input = document.getElementById('model-upload');

async function onModelFilesChosen(changeEvent) {
  const chosenFiles = Array.from(changeEvent.target.files);
  const fileHandler = tf.io.browserFiles(chosenFiles);
  await tf.loadLayersModel(fileHandler);
  console.log('模型加载成功');
}

input.addEventListener('change', onModelFilesChosen);
</script>
使用预训练模型 #
MobileNet 图像分类 #
javascript
import * as mobilenet from '@tensorflow-models/mobilenet';
/**
 * Load MobileNet, classify the element with id "image", and log every
 * prediction as "<className>: <probability as a percentage>%".
 *
 * @returns {Promise<void>} Resolves after all predictions have been logged.
 */
async function classifyImage() {
  const classifier = await mobilenet.load();
  const imageElement = document.getElementById('image');
  const results = await classifier.classify(imageElement);

  for (const result of results) {
    const percent = (result.probability * 100).toFixed(2);
    console.log(`${result.className}: ${percent}%`);
  }
}
COCO-SSD 目标检测 #
javascript
import * as cocoSsd from '@tensorflow-models/coco-ssd';
/**
 * Load COCO-SSD, run detection on the element with id "image", and log two
 * lines per detection: the class with its score as a percentage, then the
 * bounding box.
 *
 * @returns {Promise<void>} Resolves after all detections have been logged.
 */
async function detectObjects() {
  const detectorModel = await cocoSsd.load();
  const imageElement = document.getElementById('image');
  const detections = await detectorModel.detect(imageElement);

  for (const detection of detections) {
    const percent = (detection.score * 100).toFixed(2);
    console.log(`${detection.class}: ${percent}%`);
    console.log(`位置: [${detection.bbox}]`);
  }
}
姿态估计 #
javascript
import * as poseDetection from '@tensorflow-models/pose-detection';
/**
 * Create a MoveNet pose detector, estimate poses on the element with id
 * "video", and log every keypoint as "<name>: (<x>, <y>)".
 *
 * @returns {Promise<void>} Resolves after all keypoints have been logged.
 */
async function estimatePose() {
  const moveNet = poseDetection.SupportedModels.MoveNet;
  const poseDetector = await poseDetection.createDetector(moveNet);
  const videoElement = document.getElementById('video');
  const detectedPoses = await poseDetector.estimatePoses(videoElement);

  for (const { keypoints } of detectedPoses) {
    for (const point of keypoints) {
      console.log(`${point.name}: (${point.x}, ${point.y})`);
    }
  }
}
文本嵌入 #
javascript
import * as use from '@tensorflow-models/universal-sentence-encoder';
/**
 * Load the Universal Sentence Encoder, embed two fixed example sentences, and
 * log the shape of the resulting embeddings.
 *
 * @returns {Promise<void>} Resolves after the shape has been logged.
 */
async function embedText() {
  const encoder = await use.load();
  const vectors = await encoder.embed([
    'Hello TensorFlow.js',
    'Machine learning in JavaScript',
  ]);
  console.log('嵌入维度:', vectors.shape);
}
迁移学习实战 #
特征提取 #
冻结预训练模型,只训练新添加的分类层。
javascript
/**
 * Build a feature-extraction model on top of a pretrained MobileNet: every
 * pretrained layer is frozen and a new softmax classification layer is
 * attached to the output of the second-to-last layer.
 *
 * @param {number} [numClasses=10] Number of units in the new softmax layer.
 * @returns {Promise<tf.LayersModel>} The compiled model (Adam, lr 0.001,
 *   categorical cross-entropy, accuracy metric).
 */
async function transferLearning(numClasses = 10) {
  const baseModel = await tf.loadLayersModel(
    'https://storage.googleapis.com/tfjs-models/tfjs/mobilenet_v1_0.25_224/model.json'
  );

  // Freeze the pretrained layers so that only the new head receives updates.
  for (const layer of baseModel.layers) {
    layer.trainable = false;
  }

  // Tap the second-to-last layer, i.e. skip the base model's final layer.
  let features = baseModel.layers[baseModel.layers.length - 2].output;

  // A dense layer applied to a tensor of rank > 2 only transforms the last
  // axis, so its output would keep the extra axes and could not be compared
  // with [batch, numClasses] one-hot labels. Collapse them first.
  if (features.shape.length > 2) {
    features = tf.layers.flatten().apply(features);
  }

  const newHead = tf.layers.dense({
    units: numClasses,
    activation: 'softmax',
    name: 'new_output'
  }).apply(features);

  const model = tf.model({
    inputs: baseModel.input,
    outputs: newHead
  });

  model.compile({
    optimizer: tf.train.adam(0.001),
    loss: 'categoricalCrossentropy',
    metrics: ['accuracy']
  });

  return model;
}
微调 #
解冻部分层进行微调。
javascript
/**
 * Two-phase training: first fit the model returned by transferLearning(),
 * then mark its last (up to) five layers trainable, recompile with a lower
 * learning rate, and fit again.
 *
 * @param {*} [trainXs] Training inputs passed to model.fit. When omitted, the
 *   `xs` binding visible from the enclosing scope is used.
 * @param {*} [trainYs] Training targets passed to model.fit. When omitted, the
 *   `ys` binding visible from the enclosing scope is used.
 * @returns {Promise<tf.LayersModel>} The model after both training phases.
 */
async function fineTuning(trainXs = xs, trainYs = ys) {
  const model = await transferLearning();

  // Phase 1: train with the layer trainability chosen by transferLearning().
  await model.fit(trainXs, trainYs, {
    epochs: 10,
    batchSize: 32
  });

  // Phase 2: unfreeze the tail of the network. Clamp the start index so that a
  // model with fewer than five layers does not index before the first layer.
  const firstUnfrozen = Math.max(0, model.layers.length - 5);
  for (const layer of model.layers.slice(firstUnfrozen)) {
    layer.trainable = true;
  }

  // Recompile after changing `trainable`, with a 10x smaller learning rate
  // than the 0.001 used for the first phase.
  model.compile({
    optimizer: tf.train.adam(0.0001),
    loss: 'categoricalCrossentropy',
    metrics: ['accuracy']
  });

  await model.fit(trainXs, trainYs, {
    epochs: 10,
    batchSize: 32
  });

  return model;
}
自定义分类头 #
javascript
/**
 * Attach a custom classification head to a pretrained MobileNet.
 *
 * The head is fed from the output of the layer named
 * 'global_average_pooling2d_1' and consists of dense(256, relu),
 * dropout(0.5), dense(128, relu), dropout(0.3) and a final softmax layer.
 * The returned model is not compiled.
 *
 * @param {number} numClasses Number of units in the final softmax layer.
 * @returns {Promise<tf.LayersModel>} Model mapping the base input to the new head.
 */
async function createCustomClassifier(numClasses) {
  const baseModel = await tf.loadLayersModel(
    'https://storage.googleapis.com/tfjs-models/tfjs/mobilenet_v1_0.25_224/model.json'
  );
  const pooled = baseModel.getLayer('global_average_pooling2d_1').output;

  // Each factory is invoked right before its layer is applied, so layers are
  // created and connected one at a time, in order.
  const headFactories = [
    () => tf.layers.dense({ units: 256, activation: 'relu' }),
    () => tf.layers.dropout({ rate: 0.5 }),
    () => tf.layers.dense({ units: 128, activation: 'relu' }),
    () => tf.layers.dropout({ rate: 0.3 }),
    () => tf.layers.dense({ units: numClasses, activation: 'softmax' }),
  ];
  const headOutput = headFactories.reduce(
    (tensor, makeLayer) => makeLayer().apply(tensor),
    pooled
  );

  return tf.model({
    inputs: baseModel.input,
    outputs: headOutput,
  });
}
模型转换 #
Python 模型转换 #
bash
pip install tensorflowjs
Keras 模型转换 #
bash
tensorflowjs_converter --input_format keras \
model.h5 \
./tfjs_model
SavedModel 转换 #
bash
tensorflowjs_converter --input_format tf_saved_model \
./saved_model \
./tfjs_model
转换选项 #
bash
tensorflowjs_converter \
--input_format keras \
--output_format tfjs_graph_model \
--weight_shard_size_bytes 4000000 \
--quantization_bytes 2 \
model.h5 \
./tfjs_model
量化选项 #
| 选项 | 说明 |
|---|---|
| --quantization_bytes 1 | 1字节量化(最大压缩) |
| --quantization_bytes 2 | 2字节量化(平衡) |
| 无量化 | 4字节浮点(最大精度) |
模型优化 #
模型压缩 #
javascript
/**
 * Load 'model.json' and pass it through a custom save handler.
 *
 * NOTE(review): despite the name, the handler forwards the topology, weight
 * specs and weight data unchanged. No quantization is applied and the save
 * result is discarded, so this is only a skeleton for a real implementation.
 *
 * @returns {Promise<void>} Resolves once model.save has completed.
 */
async function quantizeModel() {
  const loadedModel = await tf.loadLayersModel('model.json');

  const passThroughHandler = tf.io.withSaveHandler(async (artifacts) => {
    const { modelTopology, weightSpecs, weightData } = artifacts;
    return { modelTopology, weightSpecs, weightData };
  });

  await loadedModel.save(passThroughHandler);
}
权重剪枝 #
javascript
/**
 * Magnitude pruning: overwrite the model's weights with copies in which every
 * value whose absolute value is not greater than `threshold` is set to zero.
 *
 * @param {{getWeights: Function, setWeights: Function}} model Model whose
 *   weights are read with getWeights() and replaced with setWeights().
 * @param {number} [threshold=0.01] Magnitudes at or below this value are zeroed.
 * @returns {void}
 */
function pruneWeights(model, threshold = 0.01) {
  // Each weight produces several temporaries (abs, comparison mask, cast,
  // product). Running inside tf.tidy releases them once the new values have
  // been handed to the model, instead of leaking them on every call.
  tf.tidy(() => {
    const prunedWeights = model.getWeights().map((weight) => {
      const keepMask = weight.abs().greater(threshold);
      return weight.mul(keepMask.cast('float32'));
    });
    model.setWeights(prunedWeights);
  });
}
知识蒸馏 #
javascript
/**
 * Compute a distillation loss between a teacher and a student on one batch.
 *
 * Both models' outputs are divided by `temperature` and passed through
 * softmax, and the mean squared error between the two softened distributions
 * is returned. Softening only the teacher side would compare a probability
 * distribution against un-normalized values, so the student receives the same
 * transform.
 *
 * NOTE(review): this treats both models' predict() outputs as un-normalized
 * scores (logits); confirm the models do not already end in a softmax.
 *
 * @param {{predict: Function}} teacher Model providing the soft targets.
 * @param {{predict: Function}} student Model being trained to match them.
 * @param {*} xs Input batch passed to both predict() calls.
 * @param {number} [temperature=3] Divisor applied to outputs before softmax.
 * @returns {Promise<tf.Tensor>} The loss tensor; the caller owns and must dispose it.
 */
async function knowledgeDistillation(teacher, student, xs, temperature = 3) {
  // tf.tidy frees the intermediate predictions and softened outputs and keeps
  // only the returned loss tensor alive.
  return tf.tidy(() => {
    const softTargets = tf.softmax(teacher.predict(xs).div(temperature));
    const softPredictions = tf.softmax(student.predict(xs).div(temperature));
    return tf.losses.meanSquaredError(softTargets, softPredictions);
  });
}
预训练模型库 #
TensorFlow.js 官方模型 #
| 模型 | 用途 | 安装 |
|---|---|---|
| MobileNet | 图像分类 | @tensorflow-models/mobilenet |
| COCO-SSD | 目标检测 | @tensorflow-models/coco-ssd |
| PoseNet | 姿态估计 | @tensorflow-models/posenet |
| MoveNet | 姿态估计 | @tensorflow-models/pose-detection |
| BodyPix | 身体分割 | @tensorflow-models/body-pix |
| DeepLab | 语义分割 | @tensorflow-models/deeplab |
| USE | 文本嵌入 | @tensorflow-models/universal-sentence-encoder |
| Toxicity | 文本分类 | @tensorflow-models/toxicity |
| Speech Commands | 语音识别 | @tensorflow-models/speech-commands |
| Face Landmarks | 面部关键点 | @tensorflow-models/face-landmarks-detection |
| HandPose | 手部关键点 | @tensorflow-models/handpose |
安装预训练模型 #
bash
npm install @tensorflow-models/mobilenet
npm install @tensorflow-models/coco-ssd
npm install @tensorflow-models/pose-detection
npm install @tensorflow-models/universal-sentence-encoder
完整示例 #
图像分类迁移学习 #
javascript
import * as tf from '@tensorflow/tfjs';
import * as mobilenet from '@tensorflow-models/mobilenet';
/**
 * Train a small dense classifier on features extracted by MobileNet.
 *
 * Every image is run through mobilenetModel.infer(), the per-image feature
 * rows are joined into one batch, and a dense(256, relu) -> dropout(0.5) ->
 * dense(numClasses, softmax) network is fitted on them for 20 epochs.
 *
 * @param {Array} [images=[]] Training images accepted by mobilenetModel.infer().
 * @param {number[]} [labels=[]] Integer class index for each image, in the
 *   same order as `images`.
 * @param {number} [numClasses=10] Number of output classes (one-hot depth).
 * @returns {Promise<{mobilenetModel: object, classifier: tf.Sequential}>}
 *   The feature extractor and the trained classifier.
 * @throws {Error} If `images` is empty or `images` and `labels` differ in length.
 */
async function trainCustomClassifier(images = [], labels = [], numClasses = 10) {
  // Validate before loading the network so bad input fails fast.
  if (images.length === 0) {
    throw new Error('trainCustomClassifier needs at least one training image');
  }
  if (images.length !== labels.length) {
    throw new Error('images and labels must have the same length');
  }

  const mobilenetModel = await mobilenet.load({ version: 2, alpha: 0.5 });

  const featuresArray = [];
  for (const img of images) {
    featuresArray.push(mobilenetModel.infer(img));
  }

  // Assumes infer() returns one [1, featureSize] row per image (see the
  // @tensorflow-models/mobilenet API), so the rows are joined along the batch
  // axis; stacking them would insert an extra axis.
  const xs = tf.concat(featuresArray, 0);
  for (const rowTensor of featuresArray) {
    rowTensor.dispose();
  }
  // tidy releases the intermediate index tensor and keeps the one-hot result.
  const ys = tf.tidy(() => tf.oneHot(tf.tensor1d(labels, 'int32'), numClasses));

  // The input size is taken from the extracted features themselves rather
  // than from internals of the loaded MobileNet wrapper.
  const model = tf.sequential();
  model.add(tf.layers.dense({
    units: 256,
    activation: 'relu',
    inputShape: [xs.shape[1]]
  }));
  model.add(tf.layers.dropout({ rate: 0.5 }));
  model.add(tf.layers.dense({
    units: numClasses,
    activation: 'softmax'
  }));
  model.compile({
    optimizer: tf.train.adam(0.001),
    loss: 'categoricalCrossentropy',
    metrics: ['accuracy']
  });

  try {
    await model.fit(xs, ys, {
      epochs: 20,
      batchSize: 32,
      validationSplit: 0.2
    });
  } finally {
    // The training tensors are no longer needed whether or not fit succeeded.
    xs.dispose();
    ys.dispose();
  }

  return { mobilenetModel, classifier: model };
}
/**
 * Run one image through the feature extractor and then the classifier.
 *
 * @param {{infer: Function}} mobilenetModel Feature extractor; its infer() output is fed to the classifier.
 * @param {{predict: Function}} classifier Classifier applied to the extracted features.
 * @param {*} image Input passed to mobilenetModel.infer().
 * @returns {Promise<*>} Whatever classifier.predict() returns for the features.
 */
async function predict(mobilenetModel, classifier, image) {
  const extracted = await mobilenetModel.infer(image);
  return classifier.predict(extracted);
}
目标检测微调 #
javascript
/**
 * Load the pretrained COCO-SSD model and return it.
 *
 * NOTE(review): despite the name, no fine-tuning happens here. The model is
 * returned exactly as loaded.
 *
 * @returns {Promise<*>} The model produced by cocoSsd.load().
 */
async function fineTuneObjectDetection() {
  return cocoSsd.load();
}
最佳实践 #
选择预训练模型 #
text
┌─────────────────────────────────────────────────────────────┐
│ 预训练模型选择指南 │
├─────────────────────────────────────────────────────────────┤
│ │
│ 图像分类: │
│ - MobileNet:轻量级,移动端优先 │
│ - ResNet:高精度,服务器端 │
│ - EfficientNet:平衡精度和效率 │
│ │
│ 目标检测: │
│ - COCO-SSD:实时检测 │
│ - YOLO:快速检测 │
│ │
│ 自然语言处理: │
│ - USE:句子嵌入 │
│ - BERT:深度理解 │
│ │
│ 选择考虑: │
│ 1. 任务类型匹配 │
│ 2. 数据相似度 │
│ 3. 计算资源限制 │
│ 4. 延迟要求 │
│ │
└─────────────────────────────────────────────────────────────┘
微调策略 #
javascript
// Hyperparameter presets for three ways of training on top of a pretrained base.
// NOTE(review): the `progressive` stages spell the learning rate `lr`, while the
// other two presets use `learningRate`; code reading these presets must handle
// both keys.
const strategies = {
// Base frozen, learning rate 0.001, 10 epochs.
featureExtraction: {
freezeBase: true,
learningRate: 0.001,
epochs: 10
},
// Base not frozen, 10x smaller learning rate, 20 epochs.
fineTuning: {
freezeBase: false,
learningRate: 0.0001,
epochs: 20
},
// Two stages: a frozen-base stage followed by an unfrozen stage with a smaller learning rate.
progressive: {
stage1: { freezeBase: true, lr: 0.001, epochs: 5 },
stage2: { freezeBase: false, lr: 0.0001, epochs: 10 }
}
};
下一步 #
现在你已经掌握了迁移学习,接下来学习 浏览器推理,了解如何在浏览器中高效运行模型!
最后更新:2026-03-29