TensorFlow.js 卷积神经网络 #
CNN 概述 #
卷积神经网络(Convolutional Neural Network,CNN)是处理图像数据的核心技术,通过卷积操作自动提取图像特征。
text
┌─────────────────────────────────────────────────────────────┐
│ CNN 架构概览 │
├─────────────────────────────────────────────────────────────┤
│ │
│ 输入图像 卷积层 池化层 全连接层 │
│ ┌─────┐ ┌─────┐ ┌─────┐ ┌─────┐ │
│ │ │ --> │ │ --> │ │ --> │ │ │
│ │ 图片 │ │特征图│ │降采样│ │分类 │ │
│ │ │ │ │ │ │ │ │ │
│ └─────┘ └─────┘ └─────┘ └─────┘ │
│ │
│ 特点: │
│ - 局部感知:提取局部特征 │
│ - 权值共享:减少参数数量 │
│ - 层次特征:从低级到高级特征 │
│ │
└─────────────────────────────────────────────────────────────┘
卷积基础 #
卷积操作原理 #
text
┌─────────────────────────────────────────────────────────────┐
│ 卷积操作 │
├─────────────────────────────────────────────────────────────┤
│ │
│ 输入图像 卷积核 输出 │
│ ┌─────────┐ ┌─────┐ ┌───────┐ │
│ │1 2 3 4 │ │1 0 │ │ │ │
│ │5 6 7 8 │ * │0 1 │ = │ 7 │ │
│ │9 10 11 12│ └─────┘ │ │ │
│ │13 14 15 16│ └───────┘ │
│ └─────────┘ │
│ │
│ 计算: 1*1 + 2*0 + 5*0 + 6*1 = 7 │
│ │
└─────────────────────────────────────────────────────────────┘
卷积参数 #
javascript
// A 2-D convolution layer: 32 filters of size 3x3, stride 1.
// 'same' padding keeps the output spatial size equal to the input;
// inputShape [28, 28, 1] means 28x28 grayscale images (first layer only).
const conv2d = tf.layers.conv2d({
  inputShape: [28, 28, 1],
  filters: 32,
  kernelSize: 3,
  strides: 1,
  padding: 'same',
  activation: 'relu'
});
参数说明 #
| 参数 | 说明 | 常用值 |
|---|---|---|
| filters | 卷积核数量 | 32, 64, 128 |
| kernelSize | 卷积核大小 | 3, 5, 7 |
| strides | 步长 | 1, 2 |
| padding | 填充方式 | 'same', 'valid' |
| activation | 激活函数 | 'relu' |
Padding 对比 #
text
┌─────────────────────────────────────────────────────────────┐
│ Padding 对比 │
├─────────────────────────────────────────────────────────────┤
│ │
│ Valid (无填充) Same (填充) │
│ ┌─────────────┐ ┌─────────────┐ │
│ │ ┌─────────┐ │ │ ┌─────────┐ │ │
│ │ │ 输入 │ │ │ │ 填充 │ │ │
│ │ │ 5x5 │ │ │ │ 7x7 │ │ │
│ │ └─────────┘ │ │ │ ┌─────┐ │ │ │
│ │ 输出 3x3 │ │ │ │输入 │ │ │ │
│ └─────────────┘ │ │ │5x5 │ │ │ │
│ │ │ └─────┘ │ │ │
│ 输出尺寸: │ │ 输出 │ │ │
│ (n-f+1) x (n-f+1) │ └─────────┘ │ │
│ └─────────────┘ │
│ 输出尺寸: n x n │
│ │
└─────────────────────────────────────────────────────────────┘
构建 CNN #
基础 CNN 模型 #
javascript
// Basic CNN for 28x28 grayscale inputs: two conv + max-pool stages,
// then a flattened dense classifier over 10 classes.
const model = tf.sequential({
  layers: [
    tf.layers.conv2d({
      inputShape: [28, 28, 1],
      filters: 32,
      kernelSize: 3,
      activation: 'relu'
    }),
    tf.layers.maxPooling2d({ poolSize: 2 }),
    tf.layers.conv2d({ filters: 64, kernelSize: 3, activation: 'relu' }),
    tf.layers.maxPooling2d({ poolSize: 2 }),
    tf.layers.flatten(),
    tf.layers.dense({ units: 128, activation: 'relu' }),
    tf.layers.dense({ units: 10, activation: 'softmax' })
  ]
});
// One-hot labels -> categorical cross-entropy; track accuracy during training.
model.compile({
  optimizer: 'adam',
  loss: 'categoricalCrossentropy',
  metrics: ['accuracy']
});
LeNet-5 架构 #
javascript
// Classic LeNet-5 (LeCun, 1998): tanh activations and average pooling,
// for 32x32 single-channel inputs and 10 output classes.
function createLeNet5() {
  const layers = [
    tf.layers.conv2d({
      inputShape: [32, 32, 1],
      filters: 6,
      kernelSize: 5,
      padding: 'same',
      activation: 'tanh'
    }),
    tf.layers.averagePooling2d({ poolSize: 2, strides: 2 }),
    tf.layers.conv2d({ filters: 16, kernelSize: 5, activation: 'tanh' }),
    tf.layers.averagePooling2d({ poolSize: 2, strides: 2 }),
    tf.layers.flatten(),
    tf.layers.dense({ units: 120, activation: 'tanh' }),
    tf.layers.dense({ units: 84, activation: 'tanh' }),
    tf.layers.dense({ units: 10, activation: 'softmax' })
  ];
  return tf.sequential({ layers });
}
VGG 风格网络 #
javascript
// VGG-style network: stages of two same-padded 3x3 convs followed by
// 2x2 max pooling (64 then 128 filters), then a large dense head.
function createVGGStyle() {
  const model = tf.sequential();
  let needsInputShape = true;
  for (const filters of [64, 128]) {
    for (let repeat = 0; repeat < 2; repeat++) {
      const config = {
        filters,
        kernelSize: 3,
        padding: 'same',
        activation: 'relu'
      };
      if (needsInputShape) {
        // Only the very first layer declares the input shape.
        config.inputShape = [224, 224, 3];
        needsInputShape = false;
      }
      model.add(tf.layers.conv2d(config));
    }
    model.add(tf.layers.maxPooling2d({ poolSize: 2, strides: 2 }));
  }
  model.add(tf.layers.flatten());
  model.add(tf.layers.dense({ units: 4096, activation: 'relu' }));
  model.add(tf.layers.dropout({ rate: 0.5 }));
  model.add(tf.layers.dense({ units: 1000, activation: 'softmax' }));
  return model;
}
ResNet 风格残差块 #
javascript
// Identity residual block: two same-padded 3x3 convolutions plus a
// skip connection. ReLU is applied only after the addition, so the
// second conv stays linear (standard ResNet ordering).
function residualBlock(input, filters) {
  const mainPath = tf.layers.conv2d({
    filters,
    kernelSize: 3,
    padding: 'same',
    activation: 'relu'
  }).apply(input);
  const preAdd = tf.layers.conv2d({
    filters,
    kernelSize: 3,
    padding: 'same'
  }).apply(mainPath);
  const merged = tf.layers.add().apply([input, preAdd]);
  return tf.layers.activation({ activation: 'relu' }).apply(merged);
}
// ResNet-style model: strided 7x7 stem conv + 3x3 max pool, two
// identity residual blocks, then a global-average-pooled classifier.
function createResNetStyle() {
  const input = tf.input({ shape: [224, 224, 3] });
  const stem = tf.layers.conv2d({
    filters: 64,
    kernelSize: 7,
    strides: 2,
    padding: 'same',
    activation: 'relu'
  }).apply(input);
  let features = tf.layers.maxPooling2d({
    poolSize: 3,
    strides: 2,
    padding: 'same'
  }).apply(stem);
  for (let block = 0; block < 2; block++) {
    features = residualBlock(features, 64);
  }
  const pooled = tf.layers.globalAveragePooling2d().apply(features);
  const output = tf.layers.dense({ units: 1000, activation: 'softmax' }).apply(pooled);
  return tf.model({ inputs: input, outputs: output });
}
池化层 #
最大池化 #
javascript
// 2x2 max pooling with stride 2: halves each spatial dimension while
// keeping the strongest activation in every window ('valid' = no padding).
const maxPool = tf.layers.maxPooling2d({
  padding: 'valid',
  poolSize: 2,
  strides: 2
});
平均池化 #
javascript
// 2x2 average pooling with stride 2: downsamples by averaging each window.
const avgPool = tf.layers.averagePooling2d({
  strides: 2,
  poolSize: 2
});
全局池化 #
javascript
// Global pooling collapses each feature map to a single value
// (average or max over the full spatial extent).
const globalAvgPool = tf.layers.globalAveragePooling2d();
const globalMaxPool = tf.layers.globalMaxPooling2d();
图像预处理 #
从图像创建张量 #
javascript
// Convert an <img> element's pixels into an int32 tensor of shape
// [height, width, 3] (RGB, 0-255) and log it to the console.
const imageElement = document.getElementById('myImage');
const tensor = tf.browser.fromPixels(imageElement);
tensor.print();
图像归一化 #
javascript
// VGG-style preprocessing: resize to 224x224, subtract the ImageNet
// per-channel mean, then flip the channel axis (RGB -> BGR).
// tf.tidy releases every intermediate tensor except the returned one.
function preprocessImage(img) {
  return tf.tidy(() => {
    const resized = tf.browser.fromPixels(img)
      .resizeNearestNeighbor([224, 224])
      .toFloat();
    const imagenetMean = tf.tensor([123.68, 116.779, 103.939]);
    const centered = resized.sub(imagenetMean);
    return centered.reverse(2);
  });
}
数据增强 #
javascript
// Randomly augment an image tensor: 50% chance of a horizontal flip,
// then small random brightness and contrast jitter. Returns a new
// tensor; the input is not modified. tf.tidy frees the intermediates.
function augmentImage(tensor) {
  return tf.tidy(() => {
    // BUG FIX: `augmented` was declared with `const` but reassigned
    // below, which throws "Assignment to constant variable" at runtime;
    // it must be `let`.
    let augmented = tensor.clone();
    if (Math.random() > 0.5) {
      // NOTE(review): tf.image.flipLeftRight expects a rank-4 batch
      // [b, h, w, c] — confirm callers pass batched tensors.
      augmented = tf.image.flipLeftRight(augmented);
    }
    const brightness = Math.random() * 0.2 - 0.1; // in [-0.1, 0.1)
    augmented = tf.image.adjustBrightness(augmented, brightness);
    const contrast = Math.random() * 0.2 + 0.9;   // in [0.9, 1.1)
    augmented = tf.image.adjustContrast(augmented, contrast);
    return augmented;
  });
}
批量处理 #
javascript
// Fetch, decode, resize, and normalize a list of image URLs into one
// float32 batch tensor of shape [n, targetH, targetW, 3], values in [0, 1].
// All images are loaded concurrently.
async function loadAndPreprocessImages(imageUrls, targetSize) {
  const loadOne = async (url) => {
    const img = new Image();
    img.src = url;
    await img.decode(); // resolves once the image is ready to draw
    return tf.tidy(() =>
      tf.browser.fromPixels(img)
        .resizeNearestNeighbor(targetSize)
        .toFloat()
        .div(255)
    );
  };
  const tensors = await Promise.all(imageUrls.map(loadOne));
  return tf.stack(tensors);
}
图像分类实战 #
完整示例 #
javascript
import * as tf from '@tensorflow/tfjs';
import * as tfvis from '@tensorflow/tfjs-vis';
// Train a small CNN classifier on synthetic data (1000 random 28x28
// grayscale images with 10 random one-hot labels) and visualize the
// model summary and training curves with tfjs-vis.
async function trainImageClassifier() {
  const xs = tf.randomNormal([1000, 28, 28, 1]);
  const labels = tf.randomUniform([1000], 0, 10, 'int32');
  const ysOneHot = tf.oneHot(labels, 10);

  // Two conv -> batch-norm -> pool -> dropout stages, then a dense head.
  const model = tf.sequential({
    layers: [
      tf.layers.conv2d({
        inputShape: [28, 28, 1],
        filters: 32,
        kernelSize: 3,
        activation: 'relu'
      }),
      tf.layers.batchNormalization(),
      tf.layers.maxPooling2d({ poolSize: 2 }),
      tf.layers.dropout({ rate: 0.25 }),
      tf.layers.conv2d({ filters: 64, kernelSize: 3, activation: 'relu' }),
      tf.layers.batchNormalization(),
      tf.layers.maxPooling2d({ poolSize: 2 }),
      tf.layers.dropout({ rate: 0.25 }),
      tf.layers.flatten(),
      tf.layers.dense({ units: 128, activation: 'relu' }),
      tf.layers.batchNormalization(),
      tf.layers.dropout({ rate: 0.5 }),
      tf.layers.dense({ units: 10, activation: 'softmax' })
    ]
  });
  model.compile({
    optimizer: tf.train.adam(0.001),
    loss: 'categoricalCrossentropy',
    metrics: ['accuracy']
  });

  tfvis.show.modelSummary({ name: 'Model Summary' }, model);

  // 20% of the data is held out for validation; curves update per epoch.
  await model.fit(xs, ysOneHot, {
    epochs: 20,
    batchSize: 32,
    validationSplit: 0.2,
    callbacks: tfvis.show.fitCallbacks(
      { name: 'Training Performance' },
      ['loss', 'val_loss', 'acc', 'val_acc']
    )
  });
  return model;
}
实时预测 #
javascript
// Classify one image element with a 28x28 grayscale model.
// Returns the argmax class index plus the raw probability array.
async function predictImage(model, imageElement) {
  const tensor = tf.tidy(() =>
    tf.browser.fromPixels(imageElement)
      .resizeNearestNeighbor([28, 28])
      .mean(2)        // RGB -> grayscale by averaging the channel axis
      .expandDims(0)  // add batch dimension
      .expandDims(-1) // add channel dimension
      .toFloat()
      .div(255)
  );
  const prediction = model.predict(tensor);
  const probabilities = await prediction.data();
  tensor.dispose();
  prediction.dispose();
  // Argmax: index of the first maximum probability.
  let predictedClass = 0;
  for (let i = 1; i < probabilities.length; i++) {
    if (probabilities[i] > probabilities[predictedClass]) {
      predictedClass = i;
    }
  }
  return {
    class: predictedClass,
    probabilities: probabilities
  };
}
特征可视化 #
卷积核可视化 #
javascript
// Render the first `numFilters` kernels of a conv layer as canvases,
// min-max normalized to [0, 1] before drawing.
async function visualizeFilters(layer, numFilters = 16) {
  const weights = layer.getWeights()[0]; // kernel shape [kh, kw, inChannels, outChannels]
  const filters = weights.unstack(-1);   // one [kh, kw, inChannels] tensor per output filter
  const count = Math.min(numFilters, filters.length);
  for (let i = 0; i < count; i++) {
    const rendered = tf.tidy(() => {
      // BUG FIX: a kernel spanning several input channels cannot be
      // drawn directly (toPixels only accepts 1/3/4 channels) —
      // visualize the first input channel, squeezed to rank 2.
      const filter = filters[i].slice([0, 0, 0], [-1, -1, 1]).squeeze([2]);
      const min = filter.min();
      const range = filter.max().sub(min);
      return filter.sub(min).div(range);
    });
    const canvas = document.createElement('canvas');
    // BUG FIX: tensor shape is [rows, cols] = [height, width]; the
    // original swapped them (harmless only for square kernels).
    canvas.width = rendered.shape[1];
    canvas.height = rendered.shape[0];
    // BUG FIX: toPixels is async — await it before freeing the tensor.
    await tf.browser.toPixels(rendered, canvas);
    rendered.dispose();
    document.body.appendChild(canvas);
  }
  // FIX: release all unstacked kernel tensors (the original leaked them).
  filters.forEach((f) => f.dispose());
}
特征图可视化 #
javascript
// Draw up to 16 feature maps produced for `image` by the layer at
// `layerIndex`, each on its own appended canvas.
async function visualizeFeatureMaps(model, image, layerIndex) {
  const layer = model.layers[layerIndex];
  // Sub-model that exposes the chosen layer's activations as output.
  const featureMapModel = tf.model({
    inputs: model.input,
    outputs: layer.output
  });
  const featureMaps = featureMapModel.predict(image); // [1, h, w, channels]
  // BUG FIX: JS arrays have no negative indexing — `shape[-1]` is
  // undefined, so the loop never ran. Read the last element explicitly.
  const numMaps = featureMaps.shape[featureMaps.shape.length - 1];
  for (let i = 0; i < Math.min(16, numMaps); i++) {
    const map = featureMaps.slice([0, 0, 0, i], [-1, -1, -1, 1]).squeeze();
    const canvas = document.createElement('canvas');
    // NOTE(review): float inputs to toPixels are expected in [0, 1];
    // unbounded activations (e.g. ReLU) may need normalization first.
    // BUG FIX: await the async draw, then release the sliced tensor.
    await tf.browser.toPixels(map, canvas);
    map.dispose();
    document.body.appendChild(canvas);
  }
  featureMaps.dispose(); // FIX: free the activation tensor (was leaked)
}
常见 CNN 架构 #
MobileNet 风格 #
javascript
// Depthwise-separable convolution (MobileNet building block):
// a spatial depthwise conv followed by a 1x1 pointwise conv that
// mixes channels and sets the output filter count.
function depthwiseSeparableConv(x, filters, kernelSize = 3) {
  const depthwise = tf.layers.depthwiseConv2d({
    kernelSize,
    padding: 'same',
    activation: 'relu'
  }).apply(x);
  const pointwise = tf.layers.conv2d({
    filters,
    kernelSize: 1,
    activation: 'relu'
  }).apply(depthwise);
  return pointwise;
}
Inception 模块 #
javascript
// GoogLeNet Inception module: four parallel branches (1x1, 1x1->3x3,
// 1x1->5x5, maxpool->1x1) concatenated along the channel axis.
// `filters` is [f1, [f2reduce, f2], [f3reduce, f3], f4].
// BUG FIX: the original redeclared `branch2`, `branch3`, and `branch4`
// with `const` in the same scope — a SyntaxError that prevented the
// code from even parsing. Each branch now uses a single declaration.
function inceptionModule(x, filters) {
  const [f1, f2, f3, f4] = filters;
  // Branch 1: plain 1x1 conv.
  const branch1 = tf.layers.conv2d({
    filters: f1,
    kernelSize: 1,
    activation: 'relu'
  }).apply(x);
  // Branch 2: 1x1 channel reduction, then 3x3 conv.
  const branch2Reduced = tf.layers.conv2d({
    filters: f2[0],
    kernelSize: 1,
    activation: 'relu'
  }).apply(x);
  const branch2 = tf.layers.conv2d({
    filters: f2[1],
    kernelSize: 3,
    padding: 'same',
    activation: 'relu'
  }).apply(branch2Reduced);
  // Branch 3: 1x1 channel reduction, then 5x5 conv.
  const branch3Reduced = tf.layers.conv2d({
    filters: f3[0],
    kernelSize: 1,
    activation: 'relu'
  }).apply(x);
  const branch3 = tf.layers.conv2d({
    filters: f3[1],
    kernelSize: 5,
    padding: 'same',
    activation: 'relu'
  }).apply(branch3Reduced);
  // Branch 4: 3x3 max pool (stride 1 keeps spatial size), then 1x1 projection.
  const branch4Pooled = tf.layers.maxPooling2d({
    poolSize: 3,
    strides: 1,
    padding: 'same'
  }).apply(x);
  const branch4 = tf.layers.conv2d({
    filters: f4,
    kernelSize: 1,
    activation: 'relu'
  }).apply(branch4Pooled);
  return tf.layers.concatenate().apply([branch1, branch2, branch3, branch4]);
}
目标检测基础 #
简单边界框预测 #
javascript
// Minimal detection model: a shared conv backbone feeding two heads —
// a 4-number bounding-box regressor ('bbox') and a softmax classifier
// over `numClasses` ('classes').
function createDetectionModel(inputShape, numClasses) {
  const input = tf.input({ shape: inputShape });
  let features = input;
  // Backbone: two conv + 2x2 max-pool stages (32 then 64 filters).
  for (const filters of [32, 64]) {
    features = tf.layers.conv2d({
      filters,
      kernelSize: 3,
      padding: 'same',
      activation: 'relu'
    }).apply(features);
    features = tf.layers.maxPooling2d({ poolSize: 2 }).apply(features);
  }
  features = tf.layers.flatten().apply(features);
  features = tf.layers.dense({ units: 256, activation: 'relu' }).apply(features);
  const bbox = tf.layers.dense({ units: 4, name: 'bbox' }).apply(features);
  const classes = tf.layers.dense({
    units: numClasses,
    activation: 'softmax',
    name: 'classes'
  }).apply(features);
  return tf.model({ inputs: input, outputs: [bbox, classes] });
}
下一步 #
现在你已经掌握了 CNN,接下来学习 循环神经网络,了解序列数据处理技术!
最后更新:2026-03-29