TensorFlow.js In-Browser Inference #
Overview of In-Browser Inference #
TensorFlow.js can run machine learning models directly in the browser, with no server required.
text
┌─────────────────────────────────────────────────────────┐
│           Advantages of in-browser inference            │
├─────────────────────────────────────────────────────────┤
│  Low latency      no network round trip per prediction  │
│  Privacy          data stays on the device              │
│  Offline support  works without a server                │
│  Cross-platform   any device with a browser             │
│  Real-time        immediate feedback for interaction    │
│  Low cost         no server-side inference bills        │
└─────────────────────────────────────────────────────────┘
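A minimal sketch of what end-to-end in-browser inference looks like; the model URL and input shape below are placeholders for your own model.
javascript
import * as tf from '@tensorflow/tfjs';

// Hypothetical model URL and 4-feature input; substitute your own.
// (Top-level await assumes an ES module / async context.)
const model = await tf.loadLayersModel('https://example.com/model/model.json');
const input = tf.tensor2d([[1, 2, 3, 4]]);   // shape [1, 4]
const output = model.predict(input);
console.log(await output.data());            // prediction computed entirely in the browser

input.dispose();
output.dispose();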
Backend Selection #
Checking and Setting the Backend #
javascript
// Check which backend is currently active.
console.log('Current backend:', tf.getBackend());

// Switch backends explicitly (pick one of the following).
await tf.setBackend('webgl');    // GPU-accelerated, fastest on most devices
// await tf.setBackend('wasm');  // CPU via WebAssembly
// await tf.setBackend('cpu');   // plain JavaScript fallback

// Wait for the backend to finish initializing before running ops.
await tf.ready();
console.log('Backend ready:', tf.getBackend());
Backend Comparison #
| Backend | Performance | Compatibility | Typical use |
|---|---|---|---|
| WebGL | Fastest | Modern browsers | Production |
| WASM | Moderate | Broadly supported | When compatibility matters |
| CPU | Slowest | Universal | Fallback |
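To check how these trade-offs play out on a specific device, one option is to time a representative operation under each backend that initializes. This is a rough sketch; the matrix size and iteration count are arbitrary choices.
javascript
import * as tf from '@tensorflow/tfjs';
import '@tensorflow/tfjs-backend-wasm'; // optional: register the wasm backend too

// Rough micro-benchmark: average matMul time per backend (in ms).
async function benchmarkBackends(backends = ['webgl', 'wasm', 'cpu']) {
  const results = {};
  for (const name of backends) {
    try {
      if (!(await tf.setBackend(name))) continue; // backend failed to initialize
      await tf.ready();
    } catch (e) {
      continue; // backend not registered in this build
    }
    const a = tf.randomNormal([512, 512]);
    const b = tf.randomNormal([512, 512]);
    const warm = tf.matMul(a, b);   // warm-up so shader compilation is not timed
    await warm.data();
    warm.dispose();
    const start = performance.now();
    for (let i = 0; i < 10; i++) {
      const c = tf.matMul(a, b);
      await c.data();               // forces the computation to finish
      c.dispose();
    }
    results[name] = (performance.now() - start) / 10;
    a.dispose();
    b.dispose();
  }
  return results; // e.g. { webgl: 3.5, wasm: 30.2, cpu: 180.4 }
}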
WebGL Backend #
javascript
import '@tensorflow/tfjs-backend-webgl';
import * as tf from '@tensorflow/tfjs';
await tf.setBackend('webgl');
await tf.ready();
console.log('WebGL version:', tf.env().get('WEBGL_VERSION'));
console.log('Max texture size:', tf.env().get('WEBGL_MAX_TEXTURE_SIZE'));
WebAssembly Backend #
bash
npm install @tensorflow/tfjs-backend-wasm
javascript
import '@tensorflow/tfjs-backend-wasm';
import * as tf from '@tensorflow/tfjs';

// If the .wasm binaries are not served alongside your bundle, point the backend
// at them with setWasmPaths() from '@tensorflow/tfjs-backend-wasm'.
await tf.setBackend('wasm');
await tf.ready();
Automatically Selecting the Best Backend #
javascript
async function getBestBackend() {
  const backends = ['webgl', 'wasm', 'cpu'];
  for (const backend of backends) {
    try {
      // setBackend resolves to false if the backend fails to initialize.
      const ok = await tf.setBackend(backend);
      if (!ok) continue;
      await tf.ready();
      console.log(`Using backend: ${backend}`);
      return backend;
    } catch (e) {
      console.log(`Backend ${backend} is not available`);
    }
  }
  return 'cpu';
}
Model Loading Optimization #
Preloading the Model #
javascript
let modelPromise = null;

// Kick off loading once and reuse the same promise for every caller.
function getModel() {
  if (!modelPromise) {
    modelPromise = tf.loadLayersModel('model.json');
  }
  return modelPromise;
}

async function predict(input) {
  const model = await getModel();
  return model.predict(input);
}
Model Caching #
javascript
async function loadModelWithCache(url, cacheKey = 'indexeddb://cached-model') {
  // TensorFlow.js ships an IndexedDB IOHandler ('indexeddb://') that stores both
  // the topology and the weights; the Cache API alone will not intercept the
  // fetches made internally by tf.loadLayersModel.
  try {
    const model = await tf.loadLayersModel(cacheKey);
    console.log('Loaded model from the IndexedDB cache');
    return model;
  } catch (e) {
    console.log('Loading model from the network');
    const model = await tf.loadLayersModel(url);
    await model.save(cacheKey); // cache it for subsequent page loads
    return model;
  }
}
Progressive Loading #
javascript
async function loadModelProgressively(url, onProgress) {
  // Stream model.json manually so we can report download progress.
  const response = await fetch(url);
  const contentLength = response.headers.get('content-length');
  const total = parseInt(contentLength, 10);
  let loaded = 0;
  const reader = response.body.getReader();
  while (true) {
    const { done, value } = await reader.read();
    if (done) break;
    loaded += value.length;
    if (onProgress && total) {
      onProgress(loaded / total);
    }
  }
  // Load from the original URL so the relative weight-shard paths in model.json
  // still resolve (loading from a blob: URL would break them); the browser's
  // HTTP cache avoids downloading model.json twice.
  // Note: tf.loadLayersModel also accepts an onProgress callback in its load options.
  return tf.loadLayersModel(url);
}
Performance Optimization #
Memory Management #
javascript
function predictWithCleanup(model, input) {
  // tf.tidy() disposes every intermediate tensor created inside the callback,
  // so the prediction does not leak GPU/CPU memory.
  return tf.tidy(() => {
    const output = model.predict(input);
    return output.dataSync();
  });
}
Batch Processing #
javascript
// `inputs` is an array of tensors with identical shapes; each slice is stacked
// into one batch tensor so the model runs once per batch instead of per item.
async function batchPredict(model, inputs, batchSize = 32) {
  const results = [];
  for (let i = 0; i < inputs.length; i += batchSize) {
    const batch = inputs.slice(i, i + batchSize);
    const batchTensor = tf.stack(batch);
    const predictions = model.predict(batchTensor);
    results.push(...(await predictions.array()));
    batchTensor.dispose();
    predictions.dispose();
  }
  return results;
}
Warming Up the Model #
javascript
async function warmupModel(model, inputShape) {
  // One throwaway prediction triggers shader compilation and memory allocation
  // up front, so the first real request is not slowed down.
  const dummyInput = tf.zeros(inputShape);
  const output = model.predict(dummyInput);
  await output.data();
  dummyInput.dispose();
  output.dispose();
  console.log('Model warm-up complete');
}
Async Inference #
javascript
// predictAsync() is a graph-model API (tf.loadGraphModel) and is needed for
// models that contain control-flow ops; LayersModel.predict() is synchronous.
async function predictAsync(graphModel, input) {
  const output = await graphModel.predictAsync(input);
  const result = await output.data();
  output.dispose();
  return result;
}
Web Worker Multithreading #
Creating the Worker #
javascript
const workerCode = `
importScripts('https://cdn.jsdelivr.net/npm/@tensorflow/tfjs');
let model = null;
self.onmessage = async (e) => {
  const { type, data } = e.data;
  if (type === 'init') {
    model = await tf.loadLayersModel(data.modelUrl);
    await model.predict(tf.zeros(data.inputShape)).data();
    self.postMessage({ type: 'ready' });
  }
  if (type === 'predict') {
    const input = tf.tensor(data.input);
    const output = model.predict(input);
    const result = await output.data();
    input.dispose();
    output.dispose();
    self.postMessage({ type: 'result', data: result });
  }
};
`;
const blob = new Blob([workerCode], { type: 'application/javascript' });
const workerUrl = URL.createObjectURL(blob);
const worker = new Worker(workerUrl);
Using the Worker #
javascript
function initWorker(modelUrl, inputShape) {
  return new Promise((resolve) => {
    worker.onmessage = (e) => {
      if (e.data.type === 'ready') {
        resolve();
      }
    };
    worker.postMessage({
      type: 'init',
      data: { modelUrl, inputShape }
    });
  });
}

function predictInWorker(input) {
  return new Promise((resolve) => {
    worker.onmessage = (e) => {
      if (e.data.type === 'result') {
        resolve(e.data.data);
      }
    };
    worker.postMessage({
      type: 'predict',
      data: { input }
    });
  });
}
Real-Time Inference #
Processing a Video Stream #
javascript
async function processVideoStream(model, video, callback) {
  const canvas = document.createElement('canvas');
  const ctx = canvas.getContext('2d');
  canvas.width = 224;
  canvas.height = 224;

  async function processFrame() {
    if (video.paused || video.ended) return;
    ctx.drawImage(video, 0, 0, 224, 224);
    const imageData = ctx.getImageData(0, 0, 224, 224);
    const tensor = tf.tidy(() => {
      return tf.browser.fromPixels(imageData)
        .toFloat()
        .div(255)
        .expandDims(0);
    });
    const prediction = model.predict(tensor);
    const result = await prediction.data();
    callback(result);
    tensor.dispose();
    prediction.dispose();
    requestAnimationFrame(processFrame);
  }

  processFrame();
}
Real-Time Camera Inference #
javascript
async function setupCamera() {
  const video = document.getElementById('video');
  const stream = await navigator.mediaDevices.getUserMedia({
    video: { width: 640, height: 480 }
  });
  video.srcObject = stream;
  await video.play();
  return video;
}

async function realtimeInference() {
  const model = await tf.loadLayersModel('model.json');
  const video = await setupCamera();
  const fps = 30;
  const interval = 1000 / fps;
  let lastTime = 0;

  async function processFrame(timestamp) {
    if (timestamp - lastTime >= interval) {
      lastTime = timestamp;
      const tensor = tf.tidy(() => {
        return tf.browser.fromPixels(video)
          .resizeNearestNeighbor([224, 224])
          .toFloat()
          .div(255)
          .expandDims(0);
      });
      const prediction = model.predict(tensor);
      const result = await prediction.data();
      updateUI(result); // updateUI() is the application's own rendering callback
      tensor.dispose();
      prediction.dispose();
    }
    requestAnimationFrame(processFrame);
  }

  requestAnimationFrame(processFrame);
}
Model Quantization #
Using a Quantized Model #
javascript
// A quantized model is loaded exactly like an unquantized one; the weights are
// dequantized to float32 when the model is loaded.
const model = await tf.loadLayersModel(
  'https://example.com/model_quantized.json'
);
Quantizing During Conversion #
bash
# 2-byte (float16) weight quantization; newer converter releases also expose
# --quantize_float16 / --quantize_uint8 in place of --quantization_bytes.
tensorflowjs_converter \
  --input_format keras \
  --quantization_bytes 2 \
  model.h5 \
  ./quantized_model
Quantization Trade-Offs #
| Precision | Model size | Inference speed | Accuracy loss |
|---|---|---|---|
| float32 | 100% | baseline | none |
| float16 | ~50% | largely unchanged | usually < 1% |
| uint8 | ~25% | largely unchanged | typically 1-3%, model-dependent |
Weight quantization mainly shrinks the download: the weights are dequantized back to float32 at load time, so compute time is mostly unaffected.
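If you want to confirm what a converted model was actually quantized to, you can inspect its weights manifest: in the standard model.json layout, each weight entry may carry a quantization descriptor. The sketch below relies on that assumption, and the URL is a placeholder.
javascript
// Report which dtypes the converter stored the weights in.
async function inspectQuantization(modelJsonUrl) {
  const manifest = await (await fetch(modelJsonUrl)).json();
  const counts = {};
  for (const group of manifest.weightsManifest || []) {
    for (const weight of group.weights) {
      // Quantized weights carry a `quantization` field such as { dtype: 'uint8', scale, min }.
      const dtype = weight.quantization ? weight.quantization.dtype : weight.dtype;
      counts[dtype] = (counts[dtype] || 0) + 1;
    }
  }
  console.log(counts); // e.g. { uint8: 42 } for 1-byte quantization
  return counts;
}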
Model Sharding #
Loading Shards #
javascript
async function loadShardedModel(baseUrl) {
  // model.json contains the topology plus a weightsManifest that lists the
  // binary weight shards and the specs of the weights stored inside them.
  const manifest = await (await fetch(`${baseUrl}/model.json`)).json();
  const weightSpecs = [];
  const shardBuffers = [];
  for (const group of manifest.weightsManifest) {
    weightSpecs.push(...group.weights);
    for (const path of group.paths) {
      const response = await fetch(`${baseUrl}/${path}`);
      shardBuffers.push(await response.arrayBuffer()); // shards are binary, not JSON
    }
  }
  // Concatenate the shards into a single weight buffer.
  const totalBytes = shardBuffers.reduce((n, b) => n + b.byteLength, 0);
  const weightData = new Uint8Array(totalBytes);
  let offset = 0;
  for (const buffer of shardBuffers) {
    weightData.set(new Uint8Array(buffer), offset);
    offset += buffer.byteLength;
  }
  return tf.loadLayersModel(tf.io.fromMemory({
    modelTopology: manifest.modelTopology,
    weightSpecs,
    weightData: weightData.buffer
  }));
}
Error Handling #
Backend Error Handling #
javascript
async function safeSetBackend(backend) {
  try {
    // setBackend resolves to false (rather than throwing) if initialization fails.
    const ok = await tf.setBackend(backend);
    if (!ok) return false;
    await tf.ready();
    return true;
  } catch (error) {
    console.error(`Backend ${backend} failed to initialize:`, error);
    return false;
  }
}
Model Loading Error Handling #
javascript
async function loadModelWithFallback(primaryUrl, fallbackUrl) {
  try {
    return await tf.loadLayersModel(primaryUrl);
  } catch (error) {
    console.error('Primary model failed to load, trying the fallback model:', error);
    return await tf.loadLayersModel(fallbackUrl);
  }
}
Memory Error Handling #
javascript
async function predictWithMemoryCheck(model, input, retried = false) {
  try {
    const output = model.predict(input);
    const result = await output.data();
    output.dispose();
    return result;
  } catch (error) {
    // Retry at most once after freeing what we can, to avoid unbounded recursion.
    if (!retried && error.message.includes('memory')) {
      console.error('Out of memory, cleaning up and retrying once');
      tf.disposeVariables();
      return predictWithMemoryCheck(model, input, true);
    }
    throw error;
  }
}
Performance Monitoring #
Memory Monitoring #
javascript
function monitorMemory() {
  const mem = tf.memory();
  console.log({
    numTensors: mem.numTensors,
    megabytes: (mem.numBytes / 1024 / 1024).toFixed(2) + ' MB',
    numDataBuffers: mem.numDataBuffers
  });
}
setInterval(monitorMemory, 1000);
Measuring Inference Time #
javascript
async function measureInferenceTime(model, input, runs = 100) {
  const times = [];
  for (let i = 0; i < runs; i++) {
    const start = performance.now();
    const output = model.predict(input);
    await output.data(); // wait for the computation to actually finish
    const end = performance.now();
    times.push(end - start);
    output.dispose();
  }
  const avg = times.reduce((a, b) => a + b) / times.length;
  const min = Math.min(...times);
  const max = Math.max(...times);
  console.log(`avg: ${avg.toFixed(2)} ms, min: ${min.toFixed(2)} ms, max: ${max.toFixed(2)} ms`);
  return { avg, min, max };
}
FPS Monitoring #
javascript
class FPSMonitor {
  constructor(windowSize = 60) {
    this.frames = [];
    this.windowSize = windowSize;
    this.lastTime = performance.now();
  }
  tick() {
    const now = performance.now();
    const delta = now - this.lastTime;
    this.lastTime = now;
    this.frames.push(delta);
    if (this.frames.length > this.windowSize) {
      this.frames.shift();
    }
  }
  getFPS() {
    if (this.frames.length === 0) return 0; // no frames recorded yet
    const avgDelta = this.frames.reduce((a, b) => a + b) / this.frames.length;
    return 1000 / avgDelta;
  }
}
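One possible way to wire the monitor into a render loop (a sketch; where the FPS value is displayed is up to your UI):
javascript
const fpsMonitor = new FPSMonitor();

function renderLoop() {
  fpsMonitor.tick(); // record the time since the previous frame
  if (fpsMonitor.frames.length >= 10) {
    document.title = `FPS: ${fpsMonitor.getFPS().toFixed(1)}`;
  }
  requestAnimationFrame(renderLoop);
}
requestAnimationFrame(renderLoop);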
Complete Example #
javascript
import * as tf from '@tensorflow/tfjs';
import '@tensorflow/tfjs-backend-webgl';
class BrowserInference {
  constructor() {
    this.model = null;
    this.backend = null;
    this.isReady = false;
  }

  async init(modelUrl, preferredBackend = 'webgl') {
    await this.setBestBackend(preferredBackend);
    this.model = await tf.loadLayersModel(modelUrl);
    await this.warmup();
    this.isReady = true;
  }

  async setBestBackend(preferred) {
    // Try the preferred backend first, then fall back to wasm and cpu.
    const backends = [...new Set([preferred, 'wasm', 'cpu'])];
    for (const backend of backends) {
      try {
        const ok = await tf.setBackend(backend);
        if (!ok) continue; // setBackend resolves to false if init fails
        await tf.ready();
        this.backend = backend;
        console.log(`Using backend: ${backend}`);
        return;
      } catch (e) {
        console.log(`Backend ${backend} is not available`);
      }
    }
  }

  async warmup() {
    // inputs[0].shape is [batch, ...dims]; run one dummy prediction with batch size 1.
    // Dynamic (null) dimensions would need concrete sizes here.
    const inputShape = this.model.inputs[0].shape.slice(1);
    const dummyInput = tf.zeros([1, ...inputShape]);
    const output = this.model.predict(dummyInput);
    await output.data();
    dummyInput.dispose();
    output.dispose();
  }

  async predict(input) {
    if (!this.isReady) {
      throw new Error('Model is not initialized');
    }
    return tf.tidy(() => {
      const tensor = tf.tensor(input).expandDims(0);
      const output = this.model.predict(tensor);
      return output.dataSync();
    });
  }

  getMemoryInfo() {
    return tf.memory();
  }

  dispose() {
    if (this.model) {
      this.model.dispose();
    }
  }
}
const inference = new BrowserInference();
await inference.init('model.json');
const result = await inference.predict([1, 2, 3, 4]);
console.log(result);
Next Steps #
Now that you have in-browser inference under your belt, move on to Node.js deployment to learn about server-side machine learning!
Last updated: 2026-03-29