TensorFlow.js 迁移学习 #

迁移学习概述 #

迁移学习是利用预训练模型的知识来解决新问题的技术,可以大幅减少训练时间和数据需求。

text
┌─────────────────────────────────────────────────────────────┐
│                    迁移学习流程                               │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  预训练模型                迁移学习                          │
│  ┌─────────┐              ┌─────────┐                       │
│  │ ImageNet │    --->     │  目标   │                       │
│  │ 大数据集  │             │  任务   │                       │
│  └─────────┘              └─────────┘                       │
│       │                        │                            │
│       ↓                        ↓                            │
│  ┌─────────┐              ┌─────────┐                       │
│  │ 特征提取 │              │ 微调   │                       │
│  │ 能力     │              │ 训练   │                       │
│  └─────────┘              └─────────┘                       │
│                                                             │
│  优势:                                                      │
│  - 减少训练数据需求                                          │
│  - 加速模型收敛                                              │
│  - 提高模型性能                                              │
│                                                             │
└─────────────────────────────────────────────────────────────┘

预训练模型加载 #

加载 TensorFlow.js 模型 #

javascript
// A Layers model (e.g. one converted from Keras) is loaded from the URL of its model.json.
const layersModelUrl = 'https://example.com/model.json';
const model = await tf.loadLayersModel(layersModelUrl);

加载 GraphModel #

javascript
// A Graph model (e.g. one converted from a SavedModel) is loaded with loadGraphModel instead.
const graphModelUrl = 'https://example.com/model.json';
const model = await tf.loadGraphModel(graphModelUrl);

从本地存储加载 #

javascript
// Two alternative storage backends. Each binding needs its own name: redeclaring
// `const model` in the same scope is a SyntaxError.
// Model previously saved with model.save('localstorage://my-model'):
const localStorageModel = await tf.loadLayersModel('localstorage://my-model');
// Model previously saved with model.save('indexeddb://my-model'):
const indexedDbModel = await tf.loadLayersModel('indexeddb://my-model');

从文件加载 #

html
<input type="file" id="model-upload" multiple>
<script>
// Lets the user pick model.json together with its weight shard files (*.bin).
const input = document.getElementById('model-upload');
input.addEventListener('change', async (e) => {
  const files = Array.from(e.target.files);
  // The dialog was cancelled or nothing was selected.
  if (files.length === 0) {
    return;
  }

  // tf.io.browserFiles treats the FIRST file as model.json. The selection order
  // depends on the OS (e.g. "group1-shard1of1.bin" sorts before "model.json"),
  // so put the JSON file first explicitly.
  const jsonFile = files.find((file) => file.name.endsWith('.json'));
  if (jsonFile === undefined) {
    console.error('请同时选择 model.json 文件');
    return;
  }
  const weightFiles = files.filter((file) => file !== jsonFile);

  try {
    const model = await tf.loadLayersModel(tf.io.browserFiles([jsonFile, ...weightFiles]));
    console.log('模型加载成功');
    model.summary();
  } catch (err) {
    // Missing/mismatched weight files surface here instead of as an unhandled rejection.
    console.error('模型加载失败', err);
  }
});
</script>

使用预训练模型 #

MobileNet 图像分类 #

javascript
import * as mobilenet from '@tensorflow-models/mobilenet';

/**
 * Classifies the <img id="image"> element with a pretrained MobileNet and
 * logs each returned class name with its probability as a percentage.
 */
async function classifyImage() {
  const classifier = await mobilenet.load();

  const imageElement = document.getElementById('image');
  const topClasses = await classifier.classify(imageElement);

  for (const { className, probability } of topClasses) {
    const percent = (probability * 100).toFixed(2);
    console.log(`${className}: ${percent}%`);
  }
}

COCO-SSD 目标检测 #

javascript
import * as cocoSsd from '@tensorflow-models/coco-ssd';

/**
 * Runs COCO-SSD object detection on the <img id="image"> element and logs
 * each detection's class, score (as a percentage) and bounding box.
 *
 * A fresh model is loaded on every call, so it is disposed before returning
 * to avoid leaking its weights across repeated calls.
 */
async function detectObjects() {
  const model = await cocoSsd.load();
  try {
    const img = document.getElementById('image');
    const predictions = await model.detect(img);

    for (const { class: label, score, bbox } of predictions) {
      console.log(`${label}: ${(score * 100).toFixed(2)}%`);
      console.log(`位置: [${bbox}]`);
    }
  } finally {
    model.dispose();
  }
}

姿态估计 #

javascript
import * as poseDetection from '@tensorflow-models/pose-detection';

/**
 * Estimates poses in the <video id="video"> element with MoveNet and logs
 * every keypoint's name and (x, y) position.
 *
 * The detector is created per call, so it is disposed in `finally`;
 * otherwise each call would leak a detector's model resources.
 */
async function estimatePose() {
  const detector = await poseDetection.createDetector(
    poseDetection.SupportedModels.MoveNet
  );
  try {
    const video = document.getElementById('video');
    const poses = await detector.estimatePoses(video);

    for (const pose of poses) {
      for (const { name, x, y } of pose.keypoints) {
        console.log(`${name}: (${x}, ${y})`);
      }
    }
  } finally {
    detector.dispose();
  }
}

文本嵌入 #

javascript
import * as use from '@tensorflow-models/universal-sentence-encoder';

/**
 * Embeds sentences with the Universal Sentence Encoder and logs the shape
 * of the resulting embedding tensor.
 *
 * @param {string[]} [sentences] - Sentences to embed; defaults to two demo sentences.
 */
async function embedText(sentences = [
  'Hello TensorFlow.js',
  'Machine learning in JavaScript'
]) {
  const model = await use.load();

  const embeddings = await model.embed(sentences);
  try {
    console.log('嵌入维度:', embeddings.shape);
  } finally {
    // TF.js tensors are not garbage-collected; release the embedding memory explicitly.
    embeddings.dispose();
  }
}

迁移学习实战 #

特征提取 #

冻结预训练模型,只训练新添加的分类层。

javascript
/**
 * Builds a feature-extraction transfer model on MobileNet v1 (alpha 0.25, 224px).
 *
 * Every pretrained layer is frozen; only the new pooling + softmax head is trained.
 *
 * @param {number} [numClasses=10] - Number of target classes for the new head.
 * @returns {Promise<tf.LayersModel>} A compiled model (Adam 1e-3, categorical cross-entropy).
 */
async function transferLearning(numClasses = 10) {
  const baseModel = await tf.loadLayersModel(
    'https://storage.googleapis.com/tfjs-models/tfjs/mobilenet_v1_0.25_224/model.json'
  );

  for (const layer of baseModel.layers) {
    layer.trainable = false;
  }

  // Cut at the last convolutional activation rather than layers[length - 2]:
  // the tail of this checkpoint is the original 1000-class ImageNet head
  // (softmax probabilities with a 4D shape), which is neither a useful feature
  // vector nor something a Dense layer can map to [batch, numClasses].
  const features = baseModel.getLayer('conv_pw_13_relu').output;
  // Collapse the spatial feature map to one vector per image.
  const pooled = tf.layers.globalAveragePooling2d({}).apply(features);

  const newHead = tf.layers.dense({
    units: numClasses,
    activation: 'softmax',
    name: 'new_output'
  }).apply(pooled);

  const model = tf.model({
    inputs: baseModel.input,
    outputs: newHead
  });

  model.compile({
    optimizer: tf.train.adam(0.001),
    loss: 'categoricalCrossentropy',
    metrics: ['accuracy']
  });

  return model;
}

微调 #

解冻部分层进行微调:先在冻结基础模型的情况下训练新的分类头,再解冻顶部若干层,并以更小的学习率继续训练,避免破坏预训练权重。

javascript
/**
 * Two-stage fine-tuning on top of `transferLearning()`.
 *
 * Stage 1 trains only the new head while the base is frozen. Stage 2 unfreezes
 * the last `numLayersToUnfreeze` layers and trains again with a 10x smaller
 * learning rate, so the pretrained weights are only nudged.
 *
 * @param {tf.Tensor} xs - Training inputs.
 * @param {tf.Tensor} ys - One-hot training labels.
 * @param {number} [numLayersToUnfreeze=5] - How many trailing layers to make trainable in stage 2.
 * @returns {Promise<tf.LayersModel>} The fine-tuned model.
 * @throws {Error} If `xs` or `ys` is missing.
 */
async function fineTuning(xs, ys, numLayersToUnfreeze = 5) {
  if (xs == null || ys == null) {
    throw new Error('fineTuning: xs and ys training tensors are required');
  }

  const model = await transferLearning();

  // Stage 1: train the new head only.
  await model.fit(xs, ys, {
    epochs: 10,
    batchSize: 32
  });

  // Stage 2: unfreeze the top layers (clamped so small models don't index below 0).
  const firstUnfrozen = Math.max(0, model.layers.length - numLayersToUnfreeze);
  for (const layer of model.layers.slice(firstUnfrozen)) {
    layer.trainable = true;
  }

  // Recompile so the changed `trainable` flags are picked up by the optimizer.
  model.compile({
    optimizer: tf.train.adam(0.0001),
    loss: 'categoricalCrossentropy',
    metrics: ['accuracy']
  });

  await model.fit(xs, ys, {
    epochs: 10,
    batchSize: 32
  });

  return model;
}

自定义分类头 #

javascript
/**
 * Attaches a two-hidden-layer MLP classification head to MobileNet v1
 * (alpha 0.25, 224px), starting from its `global_average_pooling2d_1` output.
 *
 * Head: Dense(256, relu) → Dropout(0.5) → Dense(128, relu) → Dropout(0.3)
 * → Dense(numClasses, softmax).
 *
 * Note: the base layers are left trainable and the model is returned
 * uncompiled; the caller is expected to freeze/compile as needed.
 *
 * @param {number} numClasses - Number of output classes.
 * @returns {Promise<tf.LayersModel>}
 */
async function createCustomClassifier(numClasses) {
  const mobileNet = await tf.loadLayersModel(
    'https://storage.googleapis.com/tfjs-models/tfjs/mobilenet_v1_0.25_224/model.json'
  );

  const pooledFeatures = mobileNet.getLayer('global_average_pooling2d_1').output;

  const headLayers = [
    tf.layers.dense({ units: 256, activation: 'relu' }),
    tf.layers.dropout({ rate: 0.5 }),
    tf.layers.dense({ units: 128, activation: 'relu' }),
    tf.layers.dropout({ rate: 0.3 }),
    tf.layers.dense({ units: numClasses, activation: 'softmax' })
  ];

  // Chain the head layers: each one is applied to the previous layer's output.
  const classProbabilities = headLayers.reduce(
    (symbolic, layer) => layer.apply(symbolic),
    pooledFeatures
  );

  return tf.model({ inputs: mobileNet.input, outputs: classProbabilities });
}

模型转换 #

Python 模型转换 #

bash
# Install the TensorFlow.js Python package; the tensorflowjs_converter commands below come from it.
pip install tensorflowjs

Keras 模型转换 #

bash
# Convert a Keras HDF5 file (model.h5) into TF.js format, written to ./tfjs_model.
# Load the result with tf.loadLayersModel('.../tfjs_model/model.json').
tensorflowjs_converter --input_format keras \
  model.h5 \
  ./tfjs_model

SavedModel 转换 #

bash
# Convert a TensorFlow SavedModel directory (./saved_model) into TF.js format in ./tfjs_model.
# SavedModels convert to a graph model; load it with tf.loadGraphModel.
tensorflowjs_converter --input_format tf_saved_model \
  ./saved_model \
  ./tfjs_model

转换选项 #

bash
# Keras HDF5 → TF.js graph model, splitting weights into shard files of at most
# 4,000,000 bytes and storing weights in 2 bytes each (--quantization_bytes 2).
# NOTE(review): newer tensorflowjs releases mark --quantization_bytes as deprecated
# in favor of --quantize_uint16 / --quantize_uint8 / --quantize_float16. Those flags
# take an optional value, so pass it explicitly (e.g. --quantize_uint16='*') or the
# next positional argument may be consumed; confirm with `tensorflowjs_converter --help`.
tensorflowjs_converter \
  --input_format keras \
  --output_format tfjs_graph_model \
  --weight_shard_size_bytes 4000000 \
  --quantization_bytes 2 \
  model.h5 \
  ./tfjs_model

量化选项 #

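选项 说明
--quantization_bytes 1 1字节量化(最大压缩)
--quantization_bytes 2 2字节量化(平衡)
无量化 4字节浮点(最大精度)

注:较新版本的 tensorflowjs 转换器已将 --quantization_bytes 标记为弃用,推荐改用 --quantize_uint8、--quantize_uint16 或 --quantize_float16。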

模型优化 #

模型压缩 #

javascript
/**
 * Affine-quantizes float values to 8-bit unsigned integers.
 *
 * Uses the scheme TF.js weight loaders dequantize as `value = q * scale + min`,
 * so the result can be stored in a weight spec's `quantization` field.
 *
 * @param {ArrayLike<number>} values - Float weights to compress.
 * @returns {{ data: Uint8Array, scale: number, min: number }}
 */
function quantizeToUint8(values) {
  if (values.length === 0) {
    return { data: new Uint8Array(0), scale: 1, min: 0 };
  }

  let min = Infinity;
  let max = -Infinity;
  for (const value of values) {
    if (value < min) {
      min = value;
    }
    if (value > max) {
      max = value;
    }
  }

  // A constant tensor has zero range; any positive scale reproduces it exactly.
  const scale = max > min ? (max - min) / 255 : 1;
  // Clamp explicitly: Uint8Array wraps out-of-range values instead of saturating.
  const data = Uint8Array.from(values, (value) =>
    Math.min(255, Math.max(0, Math.round((value - min) / scale)))
  );
  return { data, scale, min };
}

/**
 * Loads a Layers model and returns in-memory model artifacts whose float32
 * weights are quantized to uint8 (roughly 4x smaller weight payload).
 *
 * The original version passed the artifacts through unchanged and returned
 * them from the save handler, which must return a SaveResult instead.
 * The returned artifacts can be loaded with
 * `tf.loadLayersModel(tf.io.fromMemory(artifacts))` or saved with any handler.
 * For production builds, prefer quantizing at conversion time with
 * tensorflowjs_converter.
 *
 * @param {string} [modelUrl='model.json'] - URL of the model to compress.
 * @returns {Promise<tf.io.ModelArtifacts>} Artifacts with quantized weights.
 */
async function quantizeModel(modelUrl = 'model.json') {
  const model = await tf.loadLayersModel(modelUrl);

  let artifacts;
  try {
    await model.save(tf.io.withSaveHandler(async (captured) => {
      artifacts = captured;
      return { modelArtifactsInfo: tf.io.getModelArtifactsInfoForJSON(captured) };
    }));
  } finally {
    model.dispose();
  }

  // Newer TF.js versions may hand weights over as several ArrayBuffers.
  const buffer = Array.isArray(artifacts.weightData)
    ? tf.io.concatenateArrayBuffers(artifacts.weightData)
    : artifacts.weightData;
  const tensorMap = tf.io.decodeWeights(buffer, artifacts.weightSpecs);

  const weightSpecs = [];
  const chunks = [];
  try {
    for (const spec of artifacts.weightSpecs) {
      const values = tensorMap[spec.name].dataSync();
      if (spec.dtype === 'float32') {
        const { data, scale, min } = quantizeToUint8(values);
        weightSpecs.push({ ...spec, quantization: { dtype: 'uint8', scale, min } });
        chunks.push(data);
      } else {
        // Non-float weights (e.g. int32/bool) are kept byte-for-byte.
        weightSpecs.push(spec);
        chunks.push(new Uint8Array(values.buffer, values.byteOffset, values.byteLength));
      }
    }
  } finally {
    tf.dispose(Object.values(tensorMap));
  }

  // Pack the per-weight chunks back-to-back in spec order.
  const totalBytes = chunks.reduce((sum, chunk) => sum + chunk.byteLength, 0);
  const weightData = new Uint8Array(totalBytes);
  let offset = 0;
  for (const chunk of chunks) {
    weightData.set(chunk, offset);
    offset += chunk.byteLength;
  }

  return { ...artifacts, weightSpecs, weightData: weightData.buffer };
}

权重剪枝 #

javascript
/**
 * Magnitude pruning: zeroes every weight whose absolute value is not greater
 * than `threshold`, updating the model in place.
 *
 * @param {tf.LayersModel} model - Model whose weights are replaced via setWeights.
 * @param {number} [threshold=0.01] - Magnitude at or below which weights are zeroed.
 */
function pruneWeights(model, threshold = 0.01) {
  // tidy() frees the abs/greater/cast/mul intermediates; only the returned
  // pruned tensors survive. Casting the mask to each weight's own dtype
  // avoids a dtype mismatch in mul() for non-float32 weights.
  const prunedWeights = tf.tidy(() =>
    model.getWeights().map((w) => w.mul(w.abs().greater(threshold).cast(w.dtype)))
  );

  model.setWeights(prunedWeights);
  // The model's variables now hold the values; release our handles.
  tf.dispose(prunedWeights);
}

知识蒸馏 #

javascript
/**
 * Computes the soft-target distillation loss of a student against a teacher.
 *
 * Both models' `predict` outputs are treated as logits. Each is divided by
 * `temperature` and compared in the same (softmax) space: cross-entropy of the
 * softened student distribution against the softened teacher distribution.
 * The loss is scaled by T^2 so its gradient magnitude stays comparable across
 * temperatures (Hinton et al., 2015). The original code compared teacher
 * probabilities with raw student logits/T, which mixes two different spaces.
 *
 * NOTE(review): models ending in a softmax activation (like the classifiers in
 * this guide) output probabilities, not logits; distill from pre-softmax outputs.
 * When minimizing this loss, restrict the optimizer's varList to the student.
 *
 * @param {tf.LayersModel} teacher - Larger, already-trained model.
 * @param {tf.LayersModel} student - Model being trained.
 * @param {tf.Tensor} xs - Input batch fed to both models.
 * @param {number} [temperature=3] - Softening temperature (> 0).
 * @returns {Promise<tf.Scalar>} Scalar loss; the caller owns and must dispose it.
 */
async function knowledgeDistillation(teacher, student, xs, temperature = 3) {
  // tidy() disposes both predictions and all intermediates, keeping only the loss.
  return tf.tidy(() => {
    const softTargets = tf.softmax(teacher.predict(xs).div(temperature));
    const softenedStudentLogits = student.predict(xs).div(temperature);
    // softmaxCrossEntropy applies the softmax to the student logits internally.
    return tf.losses
      .softmaxCrossEntropy(softTargets, softenedStudentLogits)
      .mul(temperature * temperature);
  });
}

预训练模型库 #

TensorFlow.js 官方模型 #

模型 用途 安装
MobileNet 图像分类 @tensorflow-models/mobilenet
COCO-SSD 目标检测 @tensorflow-models/coco-ssd
PoseNet 姿态估计 @tensorflow-models/posenet
MoveNet 姿态估计 @tensorflow-models/pose-detection
BodyPix 身体分割 @tensorflow-models/body-pix
DeepLab 语义分割 @tensorflow-models/deeplab
USE 文本嵌入 @tensorflow-models/universal-sentence-encoder
Toxicity 文本分类 @tensorflow-models/toxicity
Speech Commands 语音识别 @tensorflow-models/speech-commands
Face Landmarks 面部关键点 @tensorflow-models/face-landmarks-detection
HandPose 手部关键点 @tensorflow-models/handpose

安装预训练模型 #

bash
# Install the pretrained-model packages used in the examples above
# (image classification, object detection, pose detection, sentence embeddings).
npm install @tensorflow-models/mobilenet
npm install @tensorflow-models/coco-ssd
npm install @tensorflow-models/pose-detection
npm install @tensorflow-models/universal-sentence-encoder

完整示例 #

图像分类迁移学习 #

javascript
import * as tf from '@tensorflow/tfjs';
import * as mobilenet from '@tensorflow-models/mobilenet';

/**
 * Trains a small dense classifier on MobileNet v2 (alpha 0.5) image embeddings.
 *
 * @param {Array} [images=[]] - Training images accepted by `mobilenet.infer`
 *   (e.g. HTMLImageElement, HTMLCanvasElement, ImageData or a 3D tensor).
 * @param {number[]} [labels=[]] - Integer class index per image, in [0, numClasses).
 * @param {number} [numClasses=10] - Number of output classes.
 * @returns {Promise<{mobilenetModel: object, classifier: tf.Sequential}>}
 * @throws {Error} If no images are given or the images/labels lengths differ.
 */
async function trainCustomClassifier(images = [], labels = [], numClasses = 10) {
  // Validate before downloading MobileNet. The original crashed here anyway,
  // on the undefined `featuresArray`.
  if (images.length === 0) {
    throw new Error('trainCustomClassifier: at least one training image is required');
  }
  if (images.length !== labels.length) {
    throw new Error(
      `trainCustomClassifier: got ${images.length} images but ${labels.length} labels`
    );
  }

  const mobilenetModel = await mobilenet.load({ version: 2, alpha: 0.5 });

  // infer(img, true) returns the [1, D] embedding; without the flag it returns
  // ImageNet class logits. Concatenate along the batch axis to get [N, D]
  // (tf.stack would produce [N, 1, D]).
  const xs = tf.tidy(() =>
    tf.concat(images.map((img) => mobilenetModel.infer(img, true)), 0)
  );
  const ys = tf.tidy(() => tf.oneHot(tf.tensor1d(labels, 'int32'), numClasses));

  const model = tf.sequential();

  model.add(tf.layers.dense({
    units: 256,
    activation: 'relu',
    // Take the embedding size from the data: `mobilenetModel.model` is not a
    // Layers model here and does not expose a usable `outputShape`.
    inputShape: [xs.shape[1]]
  }));
  model.add(tf.layers.dropout({ rate: 0.5 }));
  model.add(tf.layers.dense({
    units: numClasses,
    activation: 'softmax'
  }));

  model.compile({
    optimizer: tf.train.adam(0.001),
    loss: 'categoricalCrossentropy',
    metrics: ['accuracy']
  });

  try {
    await model.fit(xs, ys, {
      epochs: 20,
      batchSize: 32,
      validationSplit: 0.2
    });
  } finally {
    tf.dispose([xs, ys]);
  }

  return { mobilenetModel, classifier: model };
}

/**
 * Classifies one image with a classifier produced by `trainCustomClassifier`.
 *
 * @param {object} mobilenetModel - The loaded MobileNet used during training.
 * @param {tf.LayersModel} classifier - The trained dense head.
 * @param {*} image - Image accepted by `mobilenet.infer`.
 * @returns {Promise<tf.Tensor>} [1, numClasses] class probabilities; caller must dispose.
 */
async function predict(mobilenetModel, classifier, image) {
  // Use the same embedding features the classifier was trained on; tidy()
  // frees the intermediate embedding tensor.
  return tf.tidy(() => classifier.predict(mobilenetModel.infer(image, true)));
}

目标检测微调 #

javascript
/**
 * Loads the pretrained COCO-SSD detector and resolves to it.
 *
 * Despite its name, this performs no fine-tuning; it only returns what
 * `cocoSsd.load()` produces. Assumes `cocoSsd` is imported as
 * `import * as cocoSsd from '@tensorflow-models/coco-ssd'`.
 *
 * @returns {Promise<object>} The loaded COCO-SSD model.
 */
async function fineTuneObjectDetection() {
  return cocoSsd.load();
}

最佳实践 #

选择预训练模型 #

text
┌─────────────────────────────────────────────────────────────┐
│                   预训练模型选择指南                          │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  图像分类:                                                  │
│  - MobileNet:轻量级,移动端优先                            │
│  - ResNet:高精度,服务器端                                 │
│  - EfficientNet:平衡精度和效率                            │
│                                                             │
│  目标检测:                                                  │
│  - COCO-SSD:实时检测                                      │
│  - YOLO:快速检测                                          │
│                                                             │
│  自然语言处理:                                              │
│  - USE:句子嵌入                                           │
│  - BERT:深度理解                                          │
│                                                             │
│  选择考虑:                                                  │
│  1. 任务类型匹配                                           │
│  2. 数据相似度                                             │
│  3. 计算资源限制                                           │
│  4. 延迟要求                                               │
│                                                             │
└─────────────────────────────────────────────────────────────┘

微调策略 #

javascript
// Learning rates: a larger one for training a fresh head on a frozen base,
// and a 10x smaller one once pretrained weights are unfrozen.
const HEAD_LEARNING_RATE = 1e-3;
const FINE_TUNE_LEARNING_RATE = 1e-4;

// Transfer-learning recipes. Note that the `progressive` stages use the key
// `lr`, while the single-stage strategies use `learningRate`.
const strategies = {
  featureExtraction: { freezeBase: true, learningRate: HEAD_LEARNING_RATE, epochs: 10 },
  fineTuning: { freezeBase: false, learningRate: FINE_TUNE_LEARNING_RATE, epochs: 20 },
  progressive: {
    stage1: { freezeBase: true, lr: HEAD_LEARNING_RATE, epochs: 5 },
    stage2: { freezeBase: false, lr: FINE_TUNE_LEARNING_RATE, epochs: 10 }
  }
};

下一步 #

现在你已经掌握了迁移学习,接下来学习「浏览器推理」一章,了解如何在浏览器中高效运行模型!

最后更新:2026-03-29