TensorFlow.js 浏览器推理 #

浏览器推理概述 #

TensorFlow.js 可以在浏览器中直接运行机器学习模型,无需服务器支持。

text
┌─────────────────────────────────────────────────────────────┐
│                    浏览器推理优势                             │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  ┌─────────────┐  ┌─────────────┐  ┌─────────────┐         │
│  │  零延迟     │  │  隐私保护   │  │  离线可用   │         │
│  │  无网络请求 │  │  数据本地   │  │  无需服务器 │         │
│  └─────────────┘  └─────────────┘  └─────────────┘         │
│  ┌─────────────┐  ┌─────────────┐  ┌─────────────┐         │
│  │  跨平台     │  │  实时交互   │  │  低成本     │         │
│  │  任何设备   │  │  即时反馈   │  │  无服务器费 │         │
│  └─────────────┘  └─────────────┘  └─────────────┘         │
│                                                             │
└─────────────────────────────────────────────────────────────┘

后端选择 #

查看和设置后端 #

javascript
// Inspect which backend tf.js auto-selected at startup.
console.log('当前后端:', tf.getBackend());

// Each setBackend() call switches the active backend. In real code pick ONE
// of the following — running all three in sequence leaves only 'cpu' active.
await tf.setBackend('webgl');
await tf.setBackend('wasm');
await tf.setBackend('cpu');

// Wait until the chosen backend has finished initializing before use.
await tf.ready();
console.log('后端就绪:', tf.getBackend());

后端对比 #

后端 性能 兼容性 适用场景
WebGL 最快 现代浏览器 生产环境
WASM 中等 广泛支持 兼容性优先
CPU 最慢 全部支持 回退方案

WebGL 后端 #

javascript
// Registering the backend package must happen before setBackend('webgl').
import '@tensorflow/tfjs-backend-webgl';
import * as tf from '@tensorflow/tfjs';

await tf.setBackend('webgl');
await tf.ready();

// NOTE(review): tf.ENV is the legacy accessor — newer tf.js versions expose
// these flags via tf.env().get(...); confirm against the installed version.
console.log('WebGL 版本:', tf.ENV.get('WEBGL_VERSION'));
console.log('最大纹理尺寸:', tf.ENV.get('WEBGL_MAX_TEXTURE_SIZE'));

WebAssembly 后端 #

bash
npm install @tensorflow/tfjs-backend-wasm
javascript
// Register the WASM backend, then activate it.
import '@tensorflow/tfjs-backend-wasm';
import * as tf from '@tensorflow/tfjs';

await tf.setBackend('wasm');
// ready() resolves once the .wasm binary has been fetched and instantiated.
await tf.ready();

自动选择最佳后端 #

javascript
/**
 * Activate the first usable tf.js backend, preferring the fastest.
 * @returns {Promise<string>} name of the backend that was activated.
 */
async function getBestBackend() {
  // Ordered fastest-first; 'cpu' always initializes and doubles as fallback.
  const candidates = ['webgl', 'wasm', 'cpu'];

  for (const backend of candidates) {
    try {
      await tf.setBackend(backend);
      await tf.ready();
      console.log(`使用后端: ${backend}`);
      return backend;
    } catch (e) {
      // This backend failed to initialize — try the next one.
      console.log(`后端 ${backend} 不可用`);
    }
  }

  return 'cpu';
}

模型加载优化 #

预加载模型 #

javascript
// Memoized model-loading promise: concurrent callers share one download.
let modelPromise = null;

/**
 * Lazily start (once) and return the model-loading promise.
 * @returns {Promise<tf.LayersModel>}
 */
function getModel() {
  if (modelPromise === null) {
    modelPromise = tf.loadLayersModel('model.json');
  }
  return modelPromise;
}

/**
 * Run a prediction, transparently waiting for the model on first use.
 */
async function predict(input) {
  const model = await getModel();
  return model.predict(input);
}

模型缓存 #

javascript
/**
 * Load a model, serving it from the browser's IndexedDB when cached.
 *
 * The original version checked the Cache API but then called
 * tf.loadLayersModel(url) anyway (a network load), and cache.add(url)
 * downloaded the model a second time. tf.js ships its own persistent
 * model store via the indexeddb:// URL scheme, used here instead.
 *
 * @param {string} url - network URL of model.json.
 * @returns {Promise<tf.LayersModel>}
 */
async function loadModelWithCache(url) {
  // Stable per-URL key inside tf.js's IndexedDB model store.
  const localKey = `indexeddb://${encodeURIComponent(url)}`;

  try {
    const model = await tf.loadLayersModel(localKey);
    console.log('从缓存加载模型');
    return model;
  } catch (e) {
    // Not cached yet (or cache unreadable) — fall through to the network.
  }

  console.log('从网络加载模型');
  const model = await tf.loadLayersModel(url);

  // Persist topology + weights for next time.
  await model.save(localKey);

  return model;
}

渐进式加载 #

javascript
/**
 * Load a layers model while reporting download progress.
 *
 * The original hand-rolled fetch built a Blob URL for model.json, which
 * broke the relative paths to the weight shard files (they resolve against
 * the model.json URL). tf.loadLayersModel supports progress reporting
 * natively via LoadOptions.onProgress, which also covers the weight
 * downloads, so we delegate to it.
 *
 * @param {string} url - URL of model.json.
 * @param {(fraction: number) => void} [onProgress] - called with 0..1.
 * @returns {Promise<tf.LayersModel>}
 */
async function loadModelProgressively(url, onProgress) {
  return tf.loadLayersModel(url, { onProgress });
}

性能优化 #

内存管理 #

javascript
/**
 * Predict and return plain data, leaking no tensors.
 * tf.tidy() disposes every intermediate tensor created inside the callback;
 * the TypedArray produced by dataSync() is plain data and survives.
 */
function predictWithCleanup(model, input) {
  return tf.tidy(() => model.predict(input).dataSync());
}

批量处理 #

javascript
/**
 * Predict over a list of input tensors in fixed-size batches.
 *
 * Tensors are now disposed in a finally block so that a throwing
 * predict()/array() call no longer leaks the batch tensors (the original
 * skipped dispose() on any error).
 *
 * @param {tf.LayersModel} model
 * @param {tf.Tensor[]} inputs - individual samples; stacked per batch.
 * @param {number} [batchSize=32]
 * @returns {Promise<Array>} flattened per-sample prediction arrays.
 */
async function batchPredict(model, inputs, batchSize = 32) {
  const results = [];

  for (let i = 0; i < inputs.length; i += batchSize) {
    const batch = inputs.slice(i, i + batchSize);
    const batchTensor = tf.stack(batch);
    let predictions = null;

    try {
      predictions = model.predict(batchTensor);
      results.push(...await predictions.array());
    } finally {
      // Always release GPU/heap memory, even when prediction throws.
      batchTensor.dispose();
      if (predictions !== null) {
        predictions.dispose();
      }
    }
  }

  return results;
}

预热模型 #

javascript
/**
 * Run one dummy prediction so shader compilation / buffer allocation happens
 * before the first real request.
 *
 * @param {tf.LayersModel} model
 * @param {number[]} inputShape - full input shape including the batch dim.
 */
async function warmupModel(model, inputShape) {
  const dummyInput = tf.zeros(inputShape);
  const warmupOutput = model.predict(dummyInput);
  await warmupOutput.data();
  // Dispose BOTH tensors — the original leaked the prediction output.
  warmupOutput.dispose();
  dummyInput.dispose();
  console.log('模型预热完成');
}

使用 async 推理 #

javascript
// NOTE(review): LayersModel (tf.loadLayersModel) has no predictAsync();
// that API exists on GraphModel (tf.loadGraphModel), where it is required
// for ops with dynamic output shapes. Confirm which model type is in use.
async function predictAsync(model, input) {
  const output = await model.predictAsync(input);
  // data() resolves asynchronously, unlike the main-thread-blocking dataSync().
  const result = await output.data();
  output.dispose();
  return result;
}

WebWorker 多线程 #

创建 Worker #

javascript
// Inline worker source: the worker loads tf.js from a CDN, hosts the model,
// warms it up on 'init', and answers 'predict' messages from the main
// thread. (The template string below is runtime code for the worker, so it
// is kept verbatim.)
const workerCode = `
  importScripts('https://cdn.jsdelivr.net/npm/@tensorflow/tfjs');
  
  let model = null;
  
  self.onmessage = async (e) => {
    const { type, data } = e.data;
    
    if (type === 'init') {
      model = await tf.loadLayersModel(data.modelUrl);
      await model.predict(tf.zeros(data.inputShape)).data();
      self.postMessage({ type: 'ready' });
    }
    
    if (type === 'predict') {
      const input = tf.tensor(data.input);
      const output = model.predict(input);
      const result = await output.data();
      
      input.dispose();
      output.dispose();
      
      self.postMessage({ type: 'result', data: result });
    }
  };
`;

// Turn the source string into an object URL and spawn the worker from it,
// avoiding the need for a separate worker .js file.
const blob = new Blob([workerCode], { type: 'application/javascript' });
const workerUrl = URL.createObjectURL(blob);
const worker = new Worker(workerUrl);

使用 Worker #

javascript
/**
 * Initialize the worker-hosted model.
 *
 * Uses addEventListener with a self-removing handler instead of assigning
 * worker.onmessage: the original version let each call clobber the single
 * onmessage slot, so a predict issued before init finished would silently
 * drop the 'ready' message.
 *
 * @param {string} modelUrl - URL passed to tf.loadLayersModel in the worker.
 * @param {number[]} inputShape - warmup input shape (including batch dim).
 * @returns {Promise<void>} resolves once the worker reports 'ready'.
 */
function initWorker(modelUrl, inputShape) {
  return new Promise((resolve) => {
    const onMessage = (e) => {
      if (e.data.type === 'ready') {
        worker.removeEventListener('message', onMessage);
        resolve();
      }
    };
    worker.addEventListener('message', onMessage);

    worker.postMessage({
      type: 'init',
      data: { modelUrl, inputShape }
    });
  });
}

/**
 * Run one prediction in the worker.
 *
 * NOTE(review): concurrent calls still race on the shared 'result' message;
 * serialize calls or add a request id if overlapping predictions are needed.
 *
 * @param {number[]} input - plain array fed to tf.tensor() inside the worker.
 * @returns {Promise<Float32Array>} prediction output data.
 */
function predictInWorker(input) {
  return new Promise((resolve) => {
    const onMessage = (e) => {
      if (e.data.type === 'result') {
        worker.removeEventListener('message', onMessage);
        resolve(e.data.data);
      }
    };
    worker.addEventListener('message', onMessage);

    worker.postMessage({
      type: 'predict',
      data: { input }
    });
  });
}

实时推理 #

视频流处理 #

javascript
/**
 * Run the model on each video frame and hand results to the callback.
 * Stops automatically when the video pauses or ends.
 *
 * @param {tf.LayersModel} model
 * @param {HTMLVideoElement} video
 * @param {(result: Float32Array) => void} callback
 */
async function processVideoStream(model, video, callback) {
  // Offscreen canvas used to grab fixed-size 224x224 frames.
  const canvas = document.createElement('canvas');
  const context = canvas.getContext('2d');
  canvas.width = 224;
  canvas.height = 224;

  async function processFrame() {
    if (video.paused || video.ended) return;

    context.drawImage(video, 0, 0, 224, 224);
    const frame = context.getImageData(0, 0, 224, 224);

    // Normalize pixels to [0,1] and add a leading batch dimension.
    const tensor = tf.tidy(() =>
      tf.browser.fromPixels(frame).toFloat().div(255).expandDims(0)
    );

    const prediction = model.predict(tensor);
    const result = await prediction.data();

    callback(result);

    tensor.dispose();
    prediction.dispose();

    // Schedule the next frame after this one finished.
    requestAnimationFrame(processFrame);
  }

  processFrame();
}

摄像头实时推理 #

javascript
/**
 * Request a 640x480 camera stream, attach it to the page's <video> element
 * and start playback.
 * @returns {Promise<HTMLVideoElement>} the playing video element.
 */
async function setupCamera() {
  const video = document.getElementById('video');
  const constraints = { video: { width: 640, height: 480 } };

  const stream = await navigator.mediaDevices.getUserMedia(constraints);
  video.srcObject = stream;
  await video.play();

  return video;
}

/**
 * Continuously classify camera frames, throttled to ~30 inferences/sec
 * regardless of the display's refresh rate, and push results to updateUI.
 */
async function realtimeInference() {
  const model = await tf.loadLayersModel('model.json');
  const video = await setupCamera();

  // Minimum wall-clock gap between two inferences.
  const frameIntervalMs = 1000 / 30;
  let previousTimestamp = 0;

  async function processFrame(timestamp) {
    if (timestamp - previousTimestamp >= frameIntervalMs) {
      previousTimestamp = timestamp;

      // Resize to the model's 224x224 input, normalize, add batch dim.
      const tensor = tf.tidy(() =>
        tf.browser.fromPixels(video)
          .resizeNearestNeighbor([224, 224])
          .toFloat()
          .div(255)
          .expandDims(0)
      );

      const prediction = model.predict(tensor);
      const result = await prediction.data();

      updateUI(result);

      tensor.dispose();
      prediction.dispose();
    }

    requestAnimationFrame(processFrame);
  }

  requestAnimationFrame(processFrame);
}

模型量化 #

使用量化模型 #

javascript
// Quantized models load exactly like regular ones — quantization happens at
// conversion time and tf.js dequantizes the weights during loading.
const model = await tf.loadLayersModel(
  'https://example.com/model_quantized.json'
);

转换时量化 #

bash
tensorflowjs_converter \
  --input_format keras \
  --quantization_bytes 2 \
  model.h5 \
  ./quantized_model

量化效果对比 #

精度 模型大小 推理速度 准确率损失
float32 100% 基准 0%
float16 50% +10% <1%
int8 25% +20% 1-3%

模型分割 #

分片加载 #

javascript
/**
 * Assemble a model from manually fetched shards.
 *
 * model.json is parsed as JSON; every other shard is treated as a binary
 * weight file — tf.js weight shards are raw ArrayBuffers, not JSON, so the
 * original response.json() on them would throw. Binary shards are
 * concatenated in list order into a single weightData buffer, matching the
 * spec order collected from weightsManifest.
 *
 * @param {string} baseUrl - base URL the shard names are relative to.
 * @param {string[]} shards - e.g. ['model.json', 'group1-shard1of2.bin', ...]
 * @returns {Promise<tf.LayersModel>}
 */
async function loadShardedModel(baseUrl, shards) {
  let modelTopology = null;
  const weightSpecs = [];
  const binaryShards = [];

  for (const shard of shards) {
    const response = await fetch(`${baseUrl}/${shard}`);
    if (!response.ok) {
      throw new Error(`Failed to fetch shard ${shard}: ${response.status}`);
    }

    if (shard.endsWith('.json')) {
      const data = await response.json();
      modelTopology = data.modelTopology;
      // Standard model.json lists weight specs per manifest group.
      for (const group of data.weightsManifest || []) {
        weightSpecs.push(...group.weights);
      }
    } else {
      binaryShards.push(await response.arrayBuffer());
    }
  }

  // Concatenate the binary shards into one contiguous buffer.
  const totalBytes = binaryShards.reduce((sum, buf) => sum + buf.byteLength, 0);
  const weightData = new Uint8Array(totalBytes);
  let offset = 0;
  for (const buffer of binaryShards) {
    weightData.set(new Uint8Array(buffer), offset);
    offset += buffer.byteLength;
  }

  return tf.loadLayersModel(
    tf.io.fromMemory({ modelTopology, weightSpecs, weightData: weightData.buffer })
  );
}

错误处理 #

后端错误处理 #

javascript
/**
 * Try to activate a backend without throwing.
 * @param {string} backend - backend name ('webgl' | 'wasm' | 'cpu').
 * @returns {Promise<boolean>} true when activation succeeded.
 */
async function safeSetBackend(backend) {
  try {
    await tf.setBackend(backend);
    await tf.ready();
  } catch (error) {
    console.error(`后端 ${backend} 初始化失败:`, error);
    return false;
  }
  return true;
}

模型加载错误处理 #

javascript
/**
 * Load the primary model, falling back to a secondary URL on failure.
 * A failure of the fallback itself propagates to the caller.
 * @param {string} primaryUrl
 * @param {string} fallbackUrl
 * @returns {Promise<tf.LayersModel>}
 */
async function loadModelWithFallback(primaryUrl, fallbackUrl) {
  try {
    const model = await tf.loadLayersModel(primaryUrl);
    return model;
  } catch (error) {
    console.error('主模型加载失败,尝试备用模型');
    return tf.loadLayersModel(fallbackUrl);
  }
}

内存错误处理 #

javascript
/**
 * Predict, and on an out-of-memory error free tf.js variables and retry.
 *
 * Fixes over the original: the unused `memBefore` snapshot is removed, and a
 * bounded retry count replaces unbounded recursion — if the memory error
 * persists, the original would recurse forever.
 *
 * @param {tf.LayersModel} model
 * @param {tf.Tensor} input
 * @param {number} [retries=1] - remaining retry attempts after cleanup.
 * @returns {Promise<Float32Array>} prediction data.
 * @throws rethrows non-memory errors, and memory errors once retries run out.
 */
async function predictWithMemoryCheck(model, input, retries = 1) {
  try {
    const output = model.predict(input);
    const result = await output.data();
    output.dispose();
    return result;
  } catch (error) {
    if (retries > 0 && error.message.includes('memory')) {
      console.error('内存不足,清理后重试');
      tf.disposeVariables();
      return predictWithMemoryCheck(model, input, retries - 1);
    }
    throw error;
  }
}

性能监控 #

内存监控 #

javascript
/**
 * Log tf.js memory counters once per second. Useful during development to
 * spot leaks: a steadily climbing numTensors means tensors are never
 * disposed.
 */
function monitorMemory() {
  const mem = tf.memory();
  console.log({
    张量数量: mem.numTensors,
    字节数: (mem.numBytes / 1024 / 1024).toFixed(2) + ' MB',
    // numDataBuffers counts backing data buffers; the original labelled it
    // "不稳定张量" (unstable tensors), which describes tf.memory().unreliable.
    数据缓冲区: mem.numDataBuffers
  });
}

setInterval(monitorMemory, 1000);

推理时间测量 #

javascript
/**
 * Time repeated predictions and report avg/min/max latency.
 *
 * Fixes over the original: stats are returned (not just logged), reduce()
 * gets an initial value, and runs=0 no longer throws on an empty reduce.
 *
 * @param {tf.LayersModel} model
 * @param {tf.Tensor} input - reused for every run.
 * @param {number} [runs=100]
 * @returns {Promise<{avg: number, min: number, max: number}>} milliseconds.
 */
async function measureInferenceTime(model, input, runs = 100) {
  const times = [];

  for (let i = 0; i < runs; i++) {
    const start = performance.now();
    const output = model.predict(input);
    // Awaiting data() ensures the (possibly async GPU) work has finished.
    await output.data();
    times.push(performance.now() - start);
    output.dispose();
  }

  if (times.length === 0) {
    return { avg: 0, min: 0, max: 0 };
  }

  const avg = times.reduce((a, b) => a + b, 0) / times.length;
  const min = Math.min(...times);
  const max = Math.max(...times);

  console.log(`平均: ${avg.toFixed(2)}ms, 最小: ${min.toFixed(2)}ms, 最大: ${max.toFixed(2)}ms`);
  return { avg, min, max };
}

FPS 监控 #

javascript
/**
 * Tracks frame-to-frame time deltas over a sliding window and reports the
 * average frames-per-second.
 */
class FPSMonitor {
  /**
   * @param {number} [windowSize=60] - number of recent frames to average.
   */
  constructor(windowSize = 60) {
    this.frames = [];
    this.windowSize = windowSize;
    this.lastTime = performance.now();
  }

  /** Record one frame; call once per rendered frame. */
  tick() {
    const now = performance.now();
    const delta = now - this.lastTime;
    this.lastTime = now;

    this.frames.push(delta);
    // Keep only the most recent windowSize deltas.
    if (this.frames.length > this.windowSize) {
      this.frames.shift();
    }
  }

  /**
   * @returns {number} average FPS over the window, or 0 before any tick()
   *     (the original threw: reduce() on an empty array with no seed) and 0
   *     when all deltas are zero (instead of Infinity).
   */
  getFPS() {
    if (this.frames.length === 0) {
      return 0;
    }
    const avgDelta =
      this.frames.reduce((a, b) => a + b, 0) / this.frames.length;
    return avgDelta > 0 ? 1000 / avgDelta : 0;
  }
}

完整示例 #

javascript
import * as tf from '@tensorflow/tfjs';
import '@tensorflow/tfjs-backend-webgl';

/**
 * High-level browser inference wrapper: backend selection, model loading,
 * warmup and prediction on plain arrays.
 */
class BrowserInference {
  constructor() {
    this.model = null;    // tf.LayersModel once init() completes
    this.backend = null;  // name of the active tf.js backend
    this.isReady = false; // guards predict() before initialization
  }

  /**
   * Pick a backend, load the model, and warm it up.
   * @param {string} modelUrl - URL of model.json.
   * @param {string} [preferredBackend='webgl'] - tried first.
   */
  async init(modelUrl, preferredBackend = 'webgl') {
    await this.setBestBackend(preferredBackend);
    this.model = await tf.loadLayersModel(modelUrl);
    await this.warmup();
    this.isReady = true;
  }

  // Try the preferred backend first, then fall back to wasm, then cpu.
  async setBestBackend(preferred) {
    const backends = [preferred, 'wasm', 'cpu'];

    for (const backend of backends) {
      try {
        await tf.setBackend(backend);
        await tf.ready();
        this.backend = backend;
        console.log(`使用后端: ${backend}`);
        return;
      } catch (e) {
        console.log(`后端 ${backend} 不可用`);
      }
    }
  }

  // Run one dummy prediction so shader compilation happens before real input.
  async warmup() {
    // LayersModel exposes its input shape via the symbolic inputs;
    // model.inputShape does not exist, so the original threw here.
    const inputShape = this.model.inputs[0].shape.slice(1);
    const dummyInput = tf.zeros([1, ...inputShape]);
    const warmupOutput = this.model.predict(dummyInput);
    await warmupOutput.data();
    // Dispose both tensors — the original leaked the warmup output.
    warmupOutput.dispose();
    dummyInput.dispose();
  }

  /**
   * Predict on a plain numeric array (a batch dimension is added).
   * @param {number[]} input
   * @param {object} [options] - reserved for future use.
   * @returns {Promise<Float32Array>} prediction data.
   * @throws {Error} when called before init() has completed.
   */
  async predict(input, options = {}) {
    if (!this.isReady) {
      throw new Error('模型未初始化');
    }

    // tidy() disposes the intermediate tensors; dataSync() output survives.
    return tf.tidy(() => {
      const tensor = tf.tensor(input).expandDims(0);
      const output = this.model.predict(tensor);
      return output.dataSync();
    });
  }

  /** @returns {tf.MemoryInfo} current tf.js memory counters. */
  getMemoryInfo() {
    return tf.memory();
  }

  /** Release the model's weights; the instance is unusable afterwards. */
  dispose() {
    if (this.model) {
      this.model.dispose();
    }
  }
}

// Example usage — top-level await requires an ES module context.
const inference = new BrowserInference();
await inference.init('model.json');
const result = await inference.predict([1, 2, 3, 4]);
console.log(result);

下一步 #

现在你已经掌握了浏览器推理,接下来学习 Node.js 部署,了解服务端机器学习!

最后更新:2026-03-29