TensorFlow.js 卷积神经网络 #

CNN 概述 #

卷积神经网络(Convolutional Neural Network,CNN)是处理图像数据的核心技术,通过卷积操作自动提取图像特征。

text
┌─────────────────────────────────────────────────────────────┐
│                    CNN 架构概览                               │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  输入图像        卷积层         池化层        全连接层       │
│  ┌─────┐       ┌─────┐       ┌─────┐       ┌─────┐        │
│  │     │  -->  │     │  -->  │     │  -->  │     │        │
│  │ 图片 │       │特征图│       │降采样│       │分类 │        │
│  │     │       │     │       │     │       │     │        │
│  └─────┘       └─────┘       └─────┘       └─────┘        │
│                                                             │
│  特点:                                                      │
│  - 局部感知:提取局部特征                                    │
│  - 权值共享:减少参数数量                                    │
│  - 层次特征:从低级到高级特征                                │
│                                                             │
└─────────────────────────────────────────────────────────────┘

卷积基础 #

卷积操作原理 #

text
┌─────────────────────────────────────────────────────────────┐
│                     卷积操作                                 │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  输入图像           卷积核            输出                   │
│  ┌─────────┐       ┌─────┐          ┌───────┐              │
│  │1 2 3 4  │       │1 0 │           │       │              │
│  │5 6 7 8  │   *   │0 1 │    =      │  7    │              │
│  │9 10 11 12│      └─────┘          │       │              │
│  │13 14 15 16│                      └───────┘              │
│  └─────────┘                                              │
│                                                             │
│  左上角输出: 1*1 + 2*0 + 5*0 + 6*1 = 7(完整输出为 3x3)       │
│                                                             │
└─────────────────────────────────────────────────────────────┘

卷积参数 #

javascript
// A 2D convolution layer: 32 learnable 3x3 kernels applied with stride 1.
// 'same' padding keeps the output at the input's 28x28 spatial size, and
// ReLU is applied to every output value. inputShape is a single
// 28x28x1 image (the batch dimension is not included).
const conv2d = tf.layers.conv2d({
  inputShape: [28, 28, 1],
  filters: 32,
  kernelSize: 3,
  strides: 1,
  padding: 'same',
  activation: 'relu'
});

参数说明 #

| 参数 | 说明 | 常用值 |
| --- | --- | --- |
| filters | 卷积核数量(即输出通道数) | 32, 64, 128 |
| kernelSize | 卷积核大小 | 3, 5, 7 |
| strides | 步长 | 1, 2 |
| padding | 填充方式 | 'same', 'valid' |
| activation | 激活函数 | 'relu' |

Padding 对比 #

text
┌─────────────────────────────────────────────────────────────┐
│                    Padding 对比                              │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  Valid (无填充)              Same (填充)                    │
│  ┌─────────────┐            ┌─────────────┐               │
│  │ ┌─────────┐ │            │ ┌─────────┐ │               │
│  │ │  输入   │ │            │ │  填充   │ │               │
│  │ │ 5x5    │ │            │ │ 7x7    │ │               │
│  │ └─────────┘ │            │ │ ┌─────┐ │ │               │
│  │  输出 3x3   │            │ │ │输入 │ │ │               │
│  └─────────────┘            │ │ │5x5 │ │ │               │
│                             │ │ └─────┘ │ │               │
│  输出尺寸:                   │ │  输出   │ │               │
│  (n-f+1) x (n-f+1)(步长 1)  │ └─────────┘ │               │
│                             └─────────────┘               │
│                             输出尺寸: n x n(步长 1)       │
│                                                             │
└─────────────────────────────────────────────────────────────┘

构建 CNN #

基础 CNN 模型 #

javascript
// A small two-stage CNN for 28x28 single-channel images and 10 classes.
// Each stage is a 3x3 convolution (default 'valid' padding) with ReLU,
// followed by 2x2 max pooling; a dense head then produces class probabilities.
const model = tf.sequential();

const cnnLayers = [
  // Stage 1: 32 feature maps.
  tf.layers.conv2d({
    inputShape: [28, 28, 1],
    filters: 32,
    kernelSize: 3,
    activation: 'relu'
  }),
  tf.layers.maxPooling2d({ poolSize: 2 }),

  // Stage 2: 64 feature maps.
  tf.layers.conv2d({
    filters: 64,
    kernelSize: 3,
    activation: 'relu'
  }),
  tf.layers.maxPooling2d({ poolSize: 2 }),

  // Classifier head.
  tf.layers.flatten(),
  tf.layers.dense({ units: 128, activation: 'relu' }),
  tf.layers.dense({ units: 10, activation: 'softmax' })
];

for (const layer of cnnLayers) {
  model.add(layer);
}

// Labels are expected one-hot encoded, matching categoricalCrossentropy.
model.compile({
  optimizer: 'adam',
  loss: 'categoricalCrossentropy',
  metrics: ['accuracy']
});

LeNet-5 架构 #

javascript
/**
 * Builds a classic LeNet-5 style network for 32x32 grayscale images and
 * 10 classes. The model is returned uncompiled.
 *
 * Every convolution uses the default 'valid' padding, which gives the
 * textbook shape flow:
 *   32x32x1 -conv5x5(6)-> 28x28x6 -avgpool-> 14x14x6
 *   -conv5x5(16)-> 10x10x16 -avgpool-> 5x5x16 -flatten-> 400
 *   -> dense(120) -> dense(84) -> dense(10, softmax)
 *
 * 'same' padding on the first layer only makes sense when feeding 28x28
 * MNIST digits directly. With a 32x32 input it yields 6x6x16 maps instead of
 * LeNet-5's 5x5x16.
 *
 * @returns {tf.Sequential} the model; compile it before training.
 */
function createLeNet5() {
  const model = tf.sequential();

  // C1: 6 feature maps with 5x5 kernels, 32x32 -> 28x28.
  model.add(tf.layers.conv2d({
    inputShape: [32, 32, 1],
    filters: 6,
    kernelSize: 5,
    activation: 'tanh'
  }));
  // S2: 2x2 average pooling, 28x28 -> 14x14.
  model.add(tf.layers.averagePooling2d({ poolSize: 2, strides: 2 }));

  // C3: 16 feature maps with 5x5 kernels, 14x14 -> 10x10.
  model.add(tf.layers.conv2d({
    filters: 16,
    kernelSize: 5,
    activation: 'tanh'
  }));
  // S4: 2x2 average pooling, 10x10 -> 5x5.
  model.add(tf.layers.averagePooling2d({ poolSize: 2, strides: 2 }));

  // Fully connected head: 400 -> 120 -> 84 -> 10.
  model.add(tf.layers.flatten());
  model.add(tf.layers.dense({ units: 120, activation: 'tanh' }));
  model.add(tf.layers.dense({ units: 84, activation: 'tanh' }));
  model.add(tf.layers.dense({ units: 10, activation: 'softmax' }));

  return model;
}

VGG 风格网络 #

javascript
/**
 * Builds a heavily shortened VGG-style classifier for 224x224 RGB images and
 * 1000 classes. The model is returned uncompiled.
 *
 * It has two VGG blocks. Each block is two 3x3 'same' ReLU convolutions
 * followed by 2x2/2 max pooling. Together they take the input from
 * 224x224x3 to 56x56x128.
 *
 * The head uses global average pooling (128 features) instead of flatten.
 * Flattening 56x56x128 = 401,408 features into a 4096-unit dense layer would
 * need about 1.64 billion weights (~6.6 GB as float32), far more than a
 * browser can allocate. Full VGG avoids this with three more pooling blocks,
 * reaching 7x7x512 before it flattens.
 *
 * @returns {tf.Sequential}
 */
function createVGGStyle() {
  const model = tf.sequential();

  // One 3x3 convolution with VGG's fixed settings; `extra` can add inputShape.
  const conv3x3 = (filters, extra = {}) => tf.layers.conv2d(Object.assign({
    filters,
    kernelSize: 3,
    padding: 'same',
    activation: 'relu'
  }, extra));

  // Block 1: 224x224x3 -> 112x112x64.
  model.add(conv3x3(64, { inputShape: [224, 224, 3] }));
  model.add(conv3x3(64));
  model.add(tf.layers.maxPooling2d({ poolSize: 2, strides: 2 }));

  // Block 2: 112x112x64 -> 56x56x128.
  model.add(conv3x3(128));
  model.add(conv3x3(128));
  model.add(tf.layers.maxPooling2d({ poolSize: 2, strides: 2 }));

  // Classifier head. 56x56x128 -> 128 features per image.
  // NOTE(review): the config object is passed explicitly because some TF.js
  // versions' GlobalPooling2D constructor reads args.dataFormat without
  // defaulting args. Passing {} is harmless either way.
  model.add(tf.layers.globalAveragePooling2d({}));
  model.add(tf.layers.dense({ units: 4096, activation: 'relu' }));
  model.add(tf.layers.dropout({ rate: 0.5 }));
  model.add(tf.layers.dense({ units: 1000, activation: 'softmax' }));

  return model;
}

ResNet 风格残差块 #

javascript
/**
 * Applies a basic two-convolution residual block to a symbolic tensor:
 *   output = relu(conv3x3(relu(conv3x3(input))) + shortcut(input))
 *
 * When the input already has `filters` channels the shortcut is the identity.
 * Otherwise a 1x1 convolution projects the input to `filters` channels first,
 * because the element-wise add needs matching shapes. Without the projection
 * the block failed whenever channel counts differed.
 *
 * @param {tf.SymbolicTensor} input - assumed channels-last
 *   [batch, height, width, channels].
 * @param {number} filters - number of channels produced by the block.
 * @returns {tf.SymbolicTensor} same spatial size as `input`, `filters` channels.
 */
function residualBlock(input, filters) {
  const inputChannels = input.shape[input.shape.length - 1];

  const shortcut = inputChannels === filters
    ? input
    : tf.layers.conv2d({
      filters: filters,
      kernelSize: 1,
      padding: 'same'
    }).apply(input);

  const conv1 = tf.layers.conv2d({
    filters: filters,
    kernelSize: 3,
    padding: 'same',
    activation: 'relu'
  }).apply(input);

  // No activation here: ReLU is applied after the residual addition.
  const conv2 = tf.layers.conv2d({
    filters: filters,
    kernelSize: 3,
    padding: 'same'
  }).apply(conv1);

  const add = tf.layers.add().apply([shortcut, conv2]);
  const output = tf.layers.activation({ activation: 'relu' }).apply(add);

  return output;
}

/**
 * Builds a small ResNet-style functional model for 224x224 RGB images and
 * 1000 classes. The model is returned uncompiled.
 *
 * Structure: a 7x7/2 convolution and a 3x3/2 max-pool stem (224 -> 112 -> 56),
 * then two 64-channel residual blocks, global average pooling, and a
 * softmax classifier.
 *
 * @returns {tf.LayersModel}
 */
function createResNetStyle() {
  const inputs = tf.input({ shape: [224, 224, 3] });

  const stem = tf.layers.conv2d({
    filters: 64,
    kernelSize: 7,
    strides: 2,
    padding: 'same',
    activation: 'relu'
  }).apply(inputs);

  let features = tf.layers.maxPooling2d({
    poolSize: 3,
    strides: 2,
    padding: 'same'
  }).apply(stem);

  for (let block = 0; block < 2; block++) {
    features = residualBlock(features, 64);
  }

  // NOTE(review): {} is passed explicitly because some TF.js versions'
  // GlobalPooling2D constructor reads args.dataFormat without defaulting args.
  const pooled = tf.layers.globalAveragePooling2d({}).apply(features);
  const outputs = tf.layers.dense({ units: 1000, activation: 'softmax' }).apply(pooled);

  return tf.model({ inputs, outputs });
}

池化层 #

最大池化 #

javascript
// Max pooling over non-overlapping 2x2 windows (stride 2, no padding).
// Each feature map's height and width are halved; an odd last row or column
// is dropped under 'valid' padding.
const maxPool = tf.layers.maxPooling2d({
  padding: 'valid',
  strides: 2,
  poolSize: 2
});

平均池化 #

javascript
// Average pooling over 2x2 windows with stride 2. Padding is left at the
// layer's default.
const avgPool = tf.layers.averagePooling2d({
  strides: 2,
  poolSize: 2
});

全局池化 #

javascript
// Global pooling reduces each channel's whole feature map to a single value
// (max or mean), turning a [batch, h, w, c] tensor into [batch, c].
// NOTE(review): an explicit (empty) config object is passed because some
// TF.js versions' GlobalPooling2D constructor reads args.dataFormat without
// defaulting args. Passing {} is harmless either way.
const globalMaxPool = tf.layers.globalMaxPooling2d({});
const globalAvgPool = tf.layers.globalAveragePooling2d({});

图像预处理 #

从图像创建张量 #

javascript
// Read the pixels of the <img id="myImage"> element into a tensor and print it.
// tf.browser.fromPixels defaults to 3 channels, giving an int32 tensor of
// shape [height, width, 3].
const img = document.getElementById('myImage');
const tensor = tf.browser.fromPixels(img);
tensor.print();
// TF.js tensors are not garbage-collected: release the memory explicitly
// once the tensor is no longer needed.
tensor.dispose();

图像归一化 #

javascript
/**
 * Converts an image element into a mean-centered float tensor of shape
 * [224, 224, 3] with channels in BGR order.
 *
 * The per-channel means below are listed in RGB order and subtracted before
 * the channel axis is reversed. This is presumably for a model trained on
 * BGR, mean-subtracted input (e.g. a Caffe-converted VGG); confirm against
 * the target model.
 *
 * @param {HTMLImageElement|HTMLCanvasElement|ImageData} img - source image.
 * @returns {tf.Tensor3D} a new tensor without a batch dimension.
 */
function preprocessImage(img) {
  return tf.tidy(() => {
    const pixels = tf.browser.fromPixels(img);
    const resized = pixels.resizeNearestNeighbor([224, 224]);
    const rgbMean = tf.tensor([123.68, 116.779, 103.939]);
    const centered = resized.toFloat().sub(rgbMean);
    // Reverse the channel axis: RGB -> BGR.
    return centered.reverse(2);
  });
}

数据增强 #

javascript
/**
 * Randomly augments an image tensor in three steps:
 *   - a horizontal flip with 50% probability,
 *   - a brightness shift by a random delta in [-0.1, 0.1),
 *   - a contrast change by a random factor in [0.9, 1.1).
 *
 * TF.js has no tf.image.adjustBrightness / tf.image.adjustContrast, so both
 * are built from basic ops. Contrast follows TensorFlow's adjust_contrast:
 * each channel's deviation from its spatial mean is scaled by the factor.
 * The result is accumulated in a `let` binding, because the earlier `const`
 * reassignment threw at runtime.
 *
 * @param {tf.Tensor3D|tf.Tensor4D} tensor - image as [h, w, c] or a batch as
 *   [batch, h, w, c]. The ±0.1 brightness delta assumes pixel values scaled
 *   to about [0, 1]; confirm against the caller's preprocessing.
 * @returns {tf.Tensor} a new float tensor with the input's shape; the caller
 *   must dispose it. Values are not clipped.
 */
function augmentImage(tensor) {
  return tf.tidy(() => {
    const isBatched = tensor.rank === 4;
    // tf.image.flipLeftRight expects a 4D batch, so lift a single image.
    let augmented = isBatched ? tensor : tensor.expandDims(0);

    if (Math.random() > 0.5) {
      augmented = tf.image.flipLeftRight(augmented);
    }

    const brightness = Math.random() * 0.2 - 0.1;
    augmented = augmented.add(brightness);

    const contrast = Math.random() * 0.2 + 0.9;
    // Per-image, per-channel mean over height and width: [batch, 1, 1, c].
    const channelMean = augmented.mean([1, 2], true);
    augmented = augmented.sub(channelMean).mul(contrast).add(channelMean);

    return isBatched ? augmented : augmented.squeeze([0]);
  });
}

批量处理 #

javascript
/**
 * Loads images from URLs and returns them as one batch tensor of shape
 * [imageUrls.length, targetSize[0], targetSize[1], 3] with float values in
 * [0, 1].
 *
 * @param {string[]} imageUrls - image URLs. Cross-origin servers must send
 *   CORS headers for the pixels to be readable.
 * @param {[number, number]} targetSize - [height, width] to resize each image to.
 * @returns {Promise<tf.Tensor4D>} the stacked batch; the caller must dispose it.
 * @throws {Error} if imageUrls is empty or any image fails to load or decode.
 */
async function loadAndPreprocessImages(imageUrls, targetSize) {
  if (imageUrls.length === 0) {
    throw new Error('loadAndPreprocessImages: imageUrls must not be empty');
  }

  // Decode every image before creating any tensor. A failed download then
  // rejects before anything is allocated, so nothing can leak.
  const images = await Promise.all(
    imageUrls.map(async (url) => {
      const img = new Image();
      // Must be set before src. Without CORS access a cross-origin image taints
      // the canvas and fromPixels fails with a security error.
      img.crossOrigin = 'anonymous';
      img.src = url;
      await img.decode();
      return img;
    })
  );

  // A single tidy around the batch: the per-image tensors are intermediates
  // and are released once tf.stack has copied them into the result.
  return tf.tidy(() => tf.stack(
    images.map((img) => tf.browser.fromPixels(img)
      .resizeNearestNeighbor(targetSize)
      .toFloat()
      .div(255))
  ));
}

图像分类实战 #

完整示例 #

javascript
import * as tf from '@tensorflow/tfjs';
import * as tfvis from '@tensorflow/tfjs-vis';

/**
 * Builds a small CNN classifier (conv / batch-norm / pool / dropout stages
 * and a dense head) and trains it. tfjs-vis shows the model summary and
 * the training curves.
 *
 * The training data is synthetic: 1000 random-normal 28x28x1 "images" with
 * random labels in [0, 10), one-hot encoded. The reported accuracy is
 * therefore meaningless. Substitute a real dataset with the same shapes.
 *
 * @returns {Promise<tf.Sequential>} the trained model.
 */
async function trainImageClassifier() {
  const xs = tf.randomNormal([1000, 28, 28, 1]);
  // tidy() frees the intermediate int32 label tensor once it has been one-hot encoded.
  const ysOneHot = tf.tidy(() => {
    const ys = tf.randomUniform([1000], 0, 10, 'int32');
    return tf.oneHot(ys, 10);
  });

  try {
    const model = tf.sequential();

    // Stage 1: 32 feature maps.
    model.add(tf.layers.conv2d({
      inputShape: [28, 28, 1],
      filters: 32,
      kernelSize: 3,
      activation: 'relu'
    }));
    model.add(tf.layers.batchNormalization());
    model.add(tf.layers.maxPooling2d({ poolSize: 2 }));
    model.add(tf.layers.dropout({ rate: 0.25 }));

    // Stage 2: 64 feature maps.
    model.add(tf.layers.conv2d({
      filters: 64,
      kernelSize: 3,
      activation: 'relu'
    }));
    model.add(tf.layers.batchNormalization());
    model.add(tf.layers.maxPooling2d({ poolSize: 2 }));
    model.add(tf.layers.dropout({ rate: 0.25 }));

    // Classifier head.
    model.add(tf.layers.flatten());
    model.add(tf.layers.dense({ units: 128, activation: 'relu' }));
    model.add(tf.layers.batchNormalization());
    model.add(tf.layers.dropout({ rate: 0.5 }));
    model.add(tf.layers.dense({ units: 10, activation: 'softmax' }));

    model.compile({
      optimizer: tf.train.adam(0.001),
      loss: 'categoricalCrossentropy',
      metrics: ['accuracy']
    });

    tfvis.show.modelSummary({ name: 'Model Summary' }, model);

    // TF.js reports the 'accuracy' metric under the log keys 'acc' / 'val_acc'.
    await model.fit(xs, ysOneHot, {
      epochs: 20,
      batchSize: 32,
      validationSplit: 0.2,
      callbacks: tfvis.show.fitCallbacks(
        { name: 'Training Performance' },
        ['loss', 'val_loss', 'acc', 'val_acc']
      )
    });

    return model;
  } finally {
    // The training tensors are not needed once fit() finishes or fails.
    xs.dispose();
    ysOneHot.dispose();
  }
}

实时预测 #

javascript
/**
 * Classifies one image with a model that expects [1, 28, 28, 1] input.
 *
 * Preprocessing steps:
 *   - resize to 28x28,
 *   - average the color channels into grayscale,
 *   - add batch and channel dimensions,
 *   - scale to [0, 1].
 *
 * @param {tf.LayersModel} model - assumed to have a single probability output.
 * @param {HTMLImageElement|HTMLCanvasElement|HTMLVideoElement} imageElement
 * @returns {Promise<{class: number, probabilities: Float32Array}>} the index
 *   of the highest-probability class and the raw output values.
 */
async function predictImage(model, imageElement) {
  const tensor = tf.tidy(() => {
    return tf.browser.fromPixels(imageElement)
      .resizeNearestNeighbor([28, 28])
      .mean(2)
      .expandDims(0)
      .expandDims(-1)
      .toFloat()
      .div(255);
  });

  let prediction = null;
  try {
    prediction = model.predict(tensor);
    const probabilities = await prediction.data();
    const predictedClass = probabilities.indexOf(Math.max(...probabilities));

    return {
      class: predictedClass,
      probabilities: probabilities
    };
  } finally {
    // Release GPU/CPU memory even if predict() or data() throws.
    tensor.dispose();
    if (prediction !== null) {
      prediction.dispose();
    }
  }
}

特征可视化 #

卷积核可视化 #

javascript
/**
 * Draws up to `numFilters` convolution kernels of a layer. Each kernel gets
 * its own <canvas>, appended to document.body.
 *
 * Each kernel is min-max scaled to [0, 1], because tf.browser.toPixels only
 * accepts float values in that range. A constant kernel renders as black
 * instead of NaN. Kernels whose input depth is neither 1 nor 3 (any layer
 * after the first, typically) are averaged over their input channels and
 * drawn as grayscale, since toPixels cannot draw arbitrary depths.
 *
 * The function is async so each kernel tensor can be disposed once drawn.
 * Callers that ignored the previous (undefined) return value still work.
 *
 * @param {tf.layers.Layer} layer - assumed to be a conv2d layer whose first
 *   weight is the kernel [kernelH, kernelW, inChannels, outChannels]. Confirm
 *   this before using other layer types.
 * @param {number} [numFilters=16] - maximum number of kernels to draw.
 * @returns {Promise<void>}
 */
async function visualizeFilters(layer, numFilters = 16) {
  // getWeights() may return the layer's live weight value, so it is read
  // here but deliberately never disposed.
  const kernel = layer.getWeights()[0];
  const [kernelH, kernelW, inChannels, outChannels] = kernel.shape;
  const count = Math.min(numFilters, outChannels);

  for (let i = 0; i < count; i++) {
    const image = tf.tidy(() => {
      // Kernel i across its input channels: [kernelH, kernelW, inChannels].
      let filter = kernel.slice([0, 0, 0, i], [-1, -1, -1, 1]).squeeze([3]);
      if (inChannels !== 1 && inChannels !== 3) {
        filter = filter.mean(2);
      }
      const min = filter.min();
      // maximum() keeps a zero range (constant kernel) from producing 0/0 = NaN.
      const range = filter.max().sub(min).maximum(1e-8);
      return filter.sub(min).div(range);
    });

    const canvas = document.createElement('canvas');
    // Width is the column count and height the row count.
    canvas.width = kernelW;
    canvas.height = kernelH;
    try {
      await tf.browser.toPixels(image, canvas);
    } finally {
      image.dispose();
    }
    document.body.appendChild(canvas);
  }
}

特征图可视化 #

javascript
/**
 * Runs `image` through `model` up to layer `layerIndex`. It then draws up to
 * 16 channels of that layer's output for the first image in the batch, one
 * <canvas> per channel, appended to document.body.
 *
 * Each map is min-max scaled to [0, 1], because tf.browser.toPixels rejects
 * float values outside that range, and activations such as ReLU outputs
 * routinely exceed 1.
 *
 * @param {tf.LayersModel} model - source model; its weights are shared, not copied.
 * @param {tf.Tensor4D} image - input batch in the model's input shape.
 * @param {number} layerIndex - index into model.layers. That layer's output
 *   is assumed to be rank 4, [batch, h, w, channels] (a conv/pool layer).
 *   Confirm this for other layers.
 * @returns {Promise<void>}
 */
async function visualizeFeatureMaps(model, image, layerIndex) {
  const layer = model.layers[layerIndex];
  const featureMapModel = tf.model({
    inputs: model.input,
    outputs: layer.output
  });

  const featureMaps = featureMapModel.predict(image);
  try {
    // JS arrays have no negative indexing. shape[-1] is undefined, which made
    // the loop bound NaN and rendered nothing.
    const numMaps = featureMaps.shape[featureMaps.shape.length - 1];

    for (let i = 0; i < Math.min(16, numMaps); i++) {
      const map = tf.tidy(() => {
        // Channel i of batch item 0, squeezed to [h, w].
        const raw = featureMaps.slice([0, 0, 0, i], [1, -1, -1, 1]).squeeze([0, 3]);
        const min = raw.min();
        // maximum() keeps a constant map from producing 0/0 = NaN.
        const range = raw.max().sub(min).maximum(1e-8);
        return raw.sub(min).div(range);
      });

      const canvas = document.createElement('canvas');
      try {
        await tf.browser.toPixels(map, canvas);
      } finally {
        map.dispose();
      }
      document.body.appendChild(canvas);
    }
  } finally {
    featureMaps.dispose();
  }
}

常见 CNN 架构 #

MobileNet 风格 #

javascript
/**
 * MobileNet-style depthwise separable convolution on a symbolic tensor.
 *
 * The depthwise step applies one kernelSize x kernelSize filter per input
 * channel ('same' padding keeps the spatial size). The pointwise step is a
 * 1x1 convolution that mixes channels into `filters` outputs. Both steps
 * use ReLU.
 *
 * @param {tf.SymbolicTensor} x - input feature maps.
 * @param {number} filters - output channel count.
 * @param {number} [kernelSize=3] - depthwise kernel size.
 * @returns {tf.SymbolicTensor}
 */
function depthwiseSeparableConv(x, filters, kernelSize = 3) {
  const depthwise = tf.layers.depthwiseConv2d({
    kernelSize,
    padding: 'same',
    activation: 'relu'
  });
  const pointwise = tf.layers.conv2d({
    filters,
    kernelSize: 1,
    activation: 'relu'
  });

  return pointwise.apply(depthwise.apply(x));
}

Inception 模块 #

javascript
/**
 * GoogLeNet-style Inception module on a symbolic tensor. Four parallel
 * branches are concatenated along the last (channel) axis:
 *   1. a 1x1 convolution (f1 filters),
 *   2. a 1x1 reduction (f2[0]) followed by a 3x3 convolution (f2[1]),
 *   3. a 1x1 reduction (f3[0]) followed by a 5x5 convolution (f3[1]),
 *   4. 3x3/1 max pooling followed by a 1x1 projection (f4).
 * All branches keep x's spatial size, so they can be concatenated.
 *
 * The original redeclared `const branch2`/`branch3`/`branch4`, which is a
 * SyntaxError. The intermediate results now have their own names.
 *
 * @param {tf.SymbolicTensor} x - input feature maps.
 * @param {Array} filters - [f1, [f2Reduce, f2], [f3Reduce, f3], f4].
 * @returns {tf.SymbolicTensor} with f1 + f2[1] + f3[1] + f4 channels.
 */
function inceptionModule(x, filters) {
  const [f1, f2, f3, f4] = filters;

  const branch1 = tf.layers.conv2d({
    filters: f1,
    kernelSize: 1,
    activation: 'relu'
  }).apply(x);

  const branch2Reduced = tf.layers.conv2d({
    filters: f2[0],
    kernelSize: 1,
    activation: 'relu'
  }).apply(x);
  const branch2 = tf.layers.conv2d({
    filters: f2[1],
    kernelSize: 3,
    padding: 'same',
    activation: 'relu'
  }).apply(branch2Reduced);

  const branch3Reduced = tf.layers.conv2d({
    filters: f3[0],
    kernelSize: 1,
    activation: 'relu'
  }).apply(x);
  const branch3 = tf.layers.conv2d({
    filters: f3[1],
    kernelSize: 5,
    padding: 'same',
    activation: 'relu'
  }).apply(branch3Reduced);

  const branch4Pooled = tf.layers.maxPooling2d({
    poolSize: 3,
    strides: 1,
    padding: 'same'
  }).apply(x);
  const branch4 = tf.layers.conv2d({
    filters: f4,
    kernelSize: 1,
    activation: 'relu'
  }).apply(branch4Pooled);

  return tf.layers.concatenate().apply([branch1, branch2, branch3, branch4]);
}

目标检测基础 #

简单边界框预测 #

javascript
/**
 * Builds a minimal single-object detector. One shared convolutional backbone
 * feeds two heads:
 *   - 'bbox': 4 linear (unbounded) outputs for the box. Their meaning, e.g.
 *     (x, y, w, h) or corner coordinates, depends on the training labels.
 *   - 'classes': a softmax over `numClasses` categories.
 *
 * @param {number[]} inputShape - per-example input shape, e.g. [h, w, c].
 * @param {number} numClasses - number of object categories.
 * @returns {tf.LayersModel} uncompiled model with outputs [bbox, classes].
 */
function createDetectionModel(inputShape, numClasses) {
  const input = tf.input({ shape: inputShape });

  // Backbone: two stages of 3x3 'same' conv + ReLU, then 2x2 max pooling.
  let features = input;
  for (const filterCount of [32, 64]) {
    features = tf.layers.conv2d({
      filters: filterCount,
      kernelSize: 3,
      padding: 'same',
      activation: 'relu'
    }).apply(features);
    features = tf.layers.maxPooling2d({ poolSize: 2 }).apply(features);
  }

  const flattened = tf.layers.flatten().apply(features);
  const shared = tf.layers.dense({ units: 256, activation: 'relu' }).apply(flattened);

  const bboxHead = tf.layers.dense({ units: 4, name: 'bbox' });
  const classHead = tf.layers.dense({
    units: numClasses,
    activation: 'softmax',
    name: 'classes'
  });

  return tf.model({
    inputs: input,
    outputs: [bboxHead.apply(shared), classHead.apply(shared)]
  });
}

下一步 #

现在你已经掌握了 CNN,接下来学习循环神经网络(RNN),了解序列数据的处理技术!

最后更新:2026-03-29