TensorFlow.js 高级应用 #

高级主题概述 #

本章涵盖 TensorFlow.js 的高级应用技巧，帮助你构建生产级机器学习应用。

text
┌─────────────────────────────────────────────────────────────┐
│                    高级应用主题                               │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  ┌─────────────┐  ┌─────────────┐  ┌─────────────┐         │
│  │  自定义组件   │  │  性能优化   │  │  调试技巧   │         │
│  │  层/损失函数  │  │  内存/速度  │  │  可视化     │         │
│  └─────────────┘  └─────────────┘  └─────────────┘         │
│  ┌─────────────┐  ┌─────────────┐  ┌─────────────┐         │
│  │  生产部署   │  │  安全考虑   │  │  最佳实践   │         │
│  │  容器化     │  │  数据保护   │  │  代码规范   │         │
│  └─────────────┘  └─────────────┘  └─────────────┘         │
│                                                             │
└─────────────────────────────────────────────────────────────┘

自定义组件 #

自定义层 #

javascript
/**
 * Fully-connected layer: output = activation(input · kernel + bias).
 *
 * Construct it as `new CustomDense(units, config)`, or as
 * `new CustomDense({ units, activation, ...layerArgs })`. The second form is
 * what tf.serialization uses when it rebuilds the layer from getConfig()
 * (the default fromConfig calls `new cls(config)`).
 */
class CustomDense extends tf.layers.Layer {
  /**
   * @param {number|Object} units - Number of output units, or a full config object containing `units`.
   * @param {Object} [config={}] - Base layer args; `config.activation` names the activation ('linear' by default).
   */
  constructor(units, config = {}) {
    const args = units !== null && typeof units === 'object' ? units : { ...config, units };
    super(args);
    this.units = args.units;
    this.activation = args.activation || 'linear';
    // Build the activation layer once instead of allocating a new one on every call().
    this.activationLayer = this.activation === 'linear'
      ? null
      : tf.layers.activation({ activation: this.activation });
  }

  build(inputShape) {
    const inputDim = inputShape[inputShape.length - 1];
    this.kernel = this.addWeight(
      'kernel',
      [inputDim, this.units],
      'float32',
      tf.initializers.glorotNormal()
    );
    this.bias = this.addWeight(
      'bias',
      [this.units],
      'float32',
      tf.initializers.zeros()
    );
  }

  call(inputs) {
    return tf.tidy(() => {
      const input = inputs instanceof tf.Tensor ? inputs : inputs[0];
      const inputDim = input.shape[input.shape.length - 1];
      // matMul needs a 2-D left operand. For rank > 2, flatten the leading dims
      // and restore them afterwards, matching computeOutputShape().
      const input2d = input.rank > 2 ? input.reshape([-1, inputDim]) : input;
      let output = input2d.matMul(this.kernel.read()).add(this.bias.read());
      if (input.rank > 2) {
        output = output.reshape([...input.shape.slice(0, -1), this.units]);
      }
      return this.activationLayer ? this.activationLayer.apply(output) : output;
    });
  }

  computeOutputShape(inputShape) {
    return [...inputShape.slice(0, -1), this.units];
  }

  /** Serializable config; round-trips through the single-object constructor form. */
  getConfig() {
    return { ...super.getConfig(), units: this.units, activation: this.activation };
  }

  static get className() {
    return 'CustomDense';
  }
}

tf.serialization.registerClass(CustomDense);

自定义损失函数 #

javascript
/**
 * Creates a focal-loss function for use with model.compile().
 *
 * Per element it computes alpha_t · (1 − p)^gamma · (−y · log p), where p is
 * the clipped prediction and alpha_t = alpha where y = 1, 1 − alpha where y = 0.
 * The result is summed over the last axis and then averaged.
 *
 * @param {number} [gamma=2.0] - Focusing exponent; larger values down-weight confident predictions.
 * @param {number} [alpha=0.25] - Weight applied to positive (y = 1) terms.
 * @returns {(yTrue: tf.Tensor, yPred: tf.Tensor) => tf.Scalar} Loss function.
 */
function focalLoss(gamma = 2.0, alpha = 0.25) {
  // Keeps log(p) finite at p = 0 and (1 - p) positive at p = 1.
  const epsilon = 1e-7;

  return (yTrue, yPred) => tf.tidy(() => {
    // tf.clipByValue takes plain numbers for its bounds, not tensors.
    const yPredClipped = tf.clipByValue(yPred, epsilon, 1 - epsilon);

    const crossEntropy = yTrue.mul(yPredClipped.log()).neg();
    const weight = yTrue.mul(alpha).add(tf.sub(1, yTrue).mul(1 - alpha));
    const focal = weight.mul(tf.sub(1, yPredClipped).pow(gamma));

    return focal.mul(crossEntropy).sum(-1).mean();
  });
}

// Assumes `model` is a tf.LayersModel created earlier.
model.compile({
  optimizer: 'adam',
  loss: focalLoss(2.0, 0.25)
});

自定义激活函数 #

javascript
/**
 * Swish activation, element-wise x * sigmoid(x).
 * @param {tf.Tensor} x
 * @returns {tf.Tensor}
 */
function swish(x) {
  const gate = tf.sigmoid(x);
  return x.mul(gate);
}

/** Layer wrapper around swish() so it can be placed in a model and serialized. */
class Swish extends tf.layers.Layer {
  static get className() {
    return 'Swish';
  }

  call(inputs) {
    // A layer may be handed one tensor or a list of tensors; only the first is used.
    const tensor = inputs instanceof tf.Tensor ? inputs : inputs[0];
    return swish(tensor);
  }
}

tf.serialization.registerClass(Swish);

自定义正则化器 #

javascript
/**
 * Weight regularizer contributing l1 · Σ|x| + l2 · Σx² to the loss.
 * TensorFlow.js also provides a built-in equivalent: tf.regularizers.l1l2({ l1, l2 }).
 * NOTE(review): confirm the installed tfjs build exposes tf.regularizers.Regularizer as a base class.
 */
class L1L2Regularizer extends tf.regularizers.Regularizer {
  /**
   * @param {number} [l1=0.01] - L1 coefficient; the term is skipped when not > 0.
   * @param {number} [l2=0.01] - L2 coefficient; the term is skipped when not > 0.
   */
  constructor(l1 = 0.01, l2 = 0.01) {
    super();
    this.l1 = l1;
    this.l2 = l2;
  }

  /**
   * @param {tf.Tensor} x - Weights to penalize.
   * @returns {tf.Scalar} Penalty; intermediate tensors are released by tf.tidy.
   */
  apply(x) {
    return tf.tidy(() => {
      let regularization = tf.scalar(0);
      if (this.l1 > 0) {
        regularization = regularization.add(x.abs().sum().mul(this.l1));
      }
      if (this.l2 > 0) {
        regularization = regularization.add(x.square().sum().mul(this.l2));
      }
      return regularization;
    });
  }

  /** Serializable config consumed by fromConfig(). */
  getConfig() {
    return { l1: this.l1, l2: this.l2 };
  }

  // The default fromConfig passes the whole config as the first argument,
  // which does not match this positional constructor.
  static fromConfig(cls, config) {
    return new cls(config.l1, config.l2);
  }

  static get className() {
    return 'L1L2Regularizer';
  }
}

tf.serialization.registerClass(L1L2Regularizer);

自定义优化器 #

javascript
/**
 * Stochastic gradient descent with classical momentum:
 *   v ← momentum · v + g
 *   w ← w − learningRate · v
 */
class SGDMomentum extends tf.Optimizer {
  /**
   * @param {number} learningRate - Step size.
   * @param {number} [momentum=0.9] - Factor applied to the previous velocity.
   */
  constructor(learningRate, momentum = 0.9) {
    super();
    this.learningRate = learningRate;
    this.momentum = momentum;
    // variable name -> non-trainable tf.Variable holding that variable's velocity.
    this.velocities = {};
  }

  /**
   * @param {Object<string, tf.Tensor>|Array<{name: string, tensor: tf.Tensor}>} variableGradients
   *   Gradients keyed by variable name, as a map or as a list of {name, tensor}.
   */
  applyGradients(variableGradients) {
    const entries = Array.isArray(variableGradients)
      ? variableGradients.map(({ name, tensor }) => [name, tensor])
      : Object.entries(variableGradients);

    for (const [varName, gradient] of entries) {
      if (gradient == null) {
        continue;
      }
      // Variables are looked up by name in the engine's registry.
      const variable = tf.engine().registeredVariables[varName];
      if (variable == null) {
        throw new Error(`SGDMomentum: unknown variable "${varName}"`);
      }

      if (!this.velocities[varName]) {
        this.velocities[varName] = tf.tidy(() => tf.variable(tf.zerosLike(variable), false));
      }
      const velocity = this.velocities[varName];

      // tidy frees the intermediates; assign() keeps the new values alive in the variables.
      tf.tidy(() => {
        const newVelocity = velocity.mul(this.momentum).add(gradient);
        velocity.assign(newVelocity);
        variable.assign(variable.sub(newVelocity.mul(this.learningRate)));
      });
    }
  }

  /** Releases the velocity variables (and base-class state when available). */
  dispose() {
    if (typeof super.dispose === 'function') {
      super.dispose();
    }
    Object.values(this.velocities).forEach((v) => v.dispose());
    this.velocities = {};
  }

  getConfig() {
    return { learningRate: this.learningRate, momentum: this.momentum };
  }

  static fromConfig(cls, config) {
    return new cls(config.learningRate, config.momentum);
  }

  static get className() {
    return 'SGDMomentum';
  }
}

tf.serialization.registerClass(SGDMomentum);

性能优化 #

内存优化 #

javascript
/**
 * Runs one forward pass and returns the output values.
 * The input tensor and the prediction are created inside tf.tidy, so both
 * are released once dataSync() has copied the values out.
 */
function optimizedPredict(model, input) {
  const run = () => model.predict(tf.tensor(input)).dataSync();
  return tf.tidy(run);
}

/**
 * Runs model.predict over `inputs` in chunks of `batchSize` and returns the
 * per-sample predictions as nested JS arrays, in input order.
 *
 * @param {tf.LayersModel} model
 * @param {Array} inputs - Per-sample inputs, each convertible with tf.tensor().
 * @param {number} [batchSize=32] - Positive integer chunk size.
 * @returns {Promise<Array>} One prediction per input.
 * @throws {RangeError} If batchSize is not a positive integer (0 used to loop forever).
 */
async function batchPredictOptimized(model, inputs, batchSize = 32) {
  if (!Number.isInteger(batchSize) || batchSize < 1) {
    throw new RangeError(`batchSize must be a positive integer, got ${batchSize}`);
  }

  const results = [];
  for (let start = 0; start < inputs.length; start += batchSize) {
    // Only the prediction escapes the tidy; the per-sample tensors and the stacked batch are freed.
    const predictions = tf.tidy(() => {
      const batch = inputs.slice(start, start + batchSize);
      return model.predict(tf.stack(batch.map((x) => tf.tensor(x))));
    });
    try {
      // Asynchronous download, unlike arraySync(), does not block the event loop.
      results.push(...(await predictions.array()));
    } finally {
      predictions.dispose();
    }
  }
  return results;
}

运算融合 #

javascript
/**
 * Element-wise sqrt((x² + 1) / 2).
 * Despite the name, this only chains the ops; nothing here requests kernel
 * fusion. tf.tidy is what releases the intermediate tensors, returning only
 * the final result.
 */
function fusedOperations(input) {
  const compute = () => {
    const squared = input.square();
    const shifted = squared.add(1);
    const halved = shifted.div(2);
    return halved.sqrt();
  };
  return tf.tidy(compute);
}

预分配内存 #

javascript
/**
 * Predictor that reuses one TensorBuffer for inputs instead of building a
 * fresh JS array for every call.
 */
class PreallocatedPredictor {
  /**
   * @param {tf.LayersModel} model
   * @param {number[]} inputShape - Full shape of the tensor handed to model.predict.
   * @param {number[]} outputShape - Shape of `outputBuffer`; kept for API compatibility, not used by predict().
   */
  constructor(model, inputShape, outputShape) {
    this.model = model;
    this.inputBuffer = tf.buffer(inputShape);
    this.outputBuffer = tf.buffer(outputShape);
  }

  /**
   * Copies `data` (flat, row-major) into the buffer, runs the model and returns its values.
   * Slots beyond data.length keep their values from the previous call.
   *
   * @param {ArrayLike<number>} data
   * @returns {TypedArray} Output values.
   * @throws {RangeError} If data has more elements than the buffer holds.
   */
  predict(data) {
    const { values } = this.inputBuffer;
    if (data.length > values.length) {
      throw new RangeError(`input length ${data.length} exceeds buffer size ${values.length}`);
    }
    // Write through the flat backing array so buffers of any rank work
    // (buffer.set needs one index per dimension).
    for (let i = 0; i < data.length; i++) {
      values[i] = data[i];
    }

    const input = this.inputBuffer.toTensor();
    let output = null;
    try {
      output = this.model.predict(input);
      return output.dataSync();
    } finally {
      // toTensor() and predict() allocate new tensors on every call; release both.
      input.dispose();
      if (output !== null) {
        output.dispose();
      }
    }
  }
}

异步处理 #

javascript
/**
 * Serializes predictions: requests are queued and run one at a time, in order.
 */
class AsyncPredictor {
  constructor(model) {
    this.model = model;
    this.queue = [];
    this.processing = false;
  }

  /**
   * Enqueues one input.
   * @returns {Promise<TypedArray>} Resolves with the prediction, or rejects with the error it raised.
   */
  addToQueue(input) {
    return new Promise((resolve, reject) => {
      this.queue.push({ input, resolve, reject });
      // processQueue routes every error to its caller and never rejects itself.
      void this.processQueue();
    });
  }

  async processQueue() {
    if (this.processing || this.queue.length === 0) return;

    this.processing = true;
    try {
      while (this.queue.length > 0) {
        const { input, resolve, reject } = this.queue.shift();
        try {
          resolve(await this.predict(input));
        } catch (error) {
          // Previously a failure left the caller hanging and `processing` stuck at true.
          reject(error);
        }
      }
    } finally {
      this.processing = false;
    }
  }

  async predict(input) {
    return tf.tidy(() => {
      const tensor = tf.tensor(input);
      const output = this.model.predict(tensor);
      return output.dataSync();
    });
  }
}

调试技巧 #

张量调试 #

javascript
/**
 * Prints shape, dtype and summary statistics (min / max / mean / std) of a tensor.
 * Every temporary tensor used for the statistics is released before returning.
 *
 * @param {tf.Tensor} tensor - Tensor to inspect (not disposed by this function).
 * @param {string} [name='tensor'] - Label printed in the header line.
 */
function debugTensor(tensor, name = 'tensor') {
  console.log(`=== ${name} ===`);
  console.log('Shape:', tensor.shape);
  console.log('Dtype:', tensor.dtype);

  // Reductions over an empty tensor have no meaningful value.
  if (tensor.size === 0) {
    console.log('(empty tensor)');
    return;
  }

  const stats = tf.tidy(() => {
    // tfjs has no std op: take the square root of the variance from tf.moments,
    // computed in float32 so int/bool tensors are handled too.
    const { mean, variance } = tf.moments(tensor.cast('float32'));
    return {
      min: tensor.min().dataSync()[0],
      max: tensor.max().dataSync()[0],
      mean: mean.dataSync()[0],
      std: variance.sqrt().dataSync()[0],
    };
  });

  console.log('Min:', stats.min);
  console.log('Max:', stats.max);
  console.log('Mean:', stats.mean);
  console.log('Std:', stats.std);
}

梯度检查 #

javascript
/**
 * Computes MSE-loss gradients for all trainable variables and prints the loss
 * plus per-variable gradient statistics (mean / std / max / min).
 *
 * @param {tf.LayersModel} model
 * @param {tf.Tensor} x - Model input.
 * @param {tf.Tensor} y - Targets for tf.losses.meanSquaredError.
 */
function checkGradients(model, x, y) {
  const { grads, value } = tf.variableGrads(() => {
    const pred = model.predict(x);
    return tf.losses.meanSquaredError(y, pred);
  });

  try {
    console.log('Loss:', value.dataSync()[0]);

    for (const [name, grad] of Object.entries(grads)) {
      const stats = tf.tidy(() => {
        // No std op in tfjs: derive it from the variance returned by tf.moments.
        const { mean, variance } = tf.moments(grad);
        return {
          mean: mean.dataSync()[0],
          std: variance.sqrt().dataSync()[0],
          max: grad.max().dataSync()[0],
          min: grad.min().dataSync()[0],
        };
      });
      console.log(`${name}:`);
      console.log('  Mean:', stats.mean);
      console.log('  Std:', stats.std);
      console.log('  Max:', stats.max);
      console.log('  Min:', stats.min);
    }
  } finally {
    // The loss and gradients are not created inside a tidy scope, so release them explicitly.
    value.dispose();
    tf.dispose(grads);
  }
}

层输出检查 #

javascript
/**
 * Runs `input` through the model once per layer and collects each layer's
 * output shape, mean and standard deviation.
 *
 * @param {tf.LayersModel} model - Functional/sequential model exposing `input` and `layers`.
 * @param {tf.Tensor} input - Batch fed to every per-layer sub-model.
 * @returns {Array<{name: string, shape: number[], mean: number, std: number}>}
 */
function inspectLayerOutputs(model, input) {
  const outputs = [];

  for (const layer of model.layers) {
    // Sub-model from the original input up to this layer's output, reusing the same layers.
    const layerModel = tf.model({
      inputs: model.input,
      outputs: layer.output
    });

    const output = layerModel.predict(input);
    try {
      const stats = tf.tidy(() => {
        // No std op in tfjs: derive it from the variance returned by tf.moments.
        const { mean, variance } = tf.moments(output);
        return { mean: mean.dataSync()[0], std: variance.sqrt().dataSync()[0] };
      });
      outputs.push({
        name: layer.name,
        shape: output.shape,
        mean: stats.mean,
        std: stats.std
      });
    } finally {
      output.dispose();
    }
  }

  return outputs;
}

训练监控 #

javascript
/** Returns the first value among logs[key] for the given keys that is not undefined. */
function firstDefinedMetric(logs, ...keys) {
  const key = keys.find((k) => logs[k] !== undefined);
  return key === undefined ? undefined : logs[key];
}

/** Formats a metric with four decimals, or 'N/A' when it is missing / not a number. */
function formatMetric(value) {
  return typeof value === 'number' ? value.toFixed(4) : 'N/A';
}

/**
 * Epoch-level training monitor. It records loss/accuracy history, prints a
 * per-epoch summary, and warns on NaN, exploding or steadily rising loss.
 */
class TrainingMonitor {
  constructor() {
    this.history = {
      loss: [],
      accuracy: [],
      valLoss: [],
      valAccuracy: []
    };
    // Bind so `this` survives when the method is detached and invoked as a plain function.
    this.onEpochEnd = this.onEpochEnd.bind(this);
  }

  /**
   * @param {number} epoch - Zero-based epoch index.
   * @param {Object} [logs] - Epoch metrics. Accuracy may be reported as `acc` or
   *   `accuracy`, and the `val_*` keys are absent when no validation data is used.
   */
  onEpochEnd(epoch, logs = {}) {
    const accuracy = firstDefinedMetric(logs, 'acc', 'accuracy');
    const valAccuracy = firstDefinedMetric(logs, 'val_acc', 'val_accuracy');

    this.history.loss.push(logs.loss);
    this.history.accuracy.push(accuracy);
    this.history.valLoss.push(logs.val_loss);
    this.history.valAccuracy.push(valAccuracy);

    console.log(`Epoch ${epoch + 1}:`);
    console.log(`  Loss: ${formatMetric(logs.loss)}`);
    console.log(`  Accuracy: ${formatMetric(accuracy)}`);
    console.log(`  Val Loss: ${formatMetric(logs.val_loss)}`);
    console.log(`  Val Accuracy: ${formatMetric(valAccuracy)}`);

    this.checkForIssues(logs);
  }

  checkForIssues(logs) {
    if (Number.isNaN(logs.loss)) {
      console.warn('警告: 损失为 NaN!');
    }

    if (logs.loss > 1e6) {
      console.warn('警告: 损失爆炸!');
    }

    // Warn once the last five recorded losses are each larger than the one before.
    if (this.history.loss.length >= 5) {
      const recent = this.history.loss.slice(-5);
      const increasing = recent.every((v, i) => i === 0 || v > recent[i - 1]);
      if (increasing) {
        console.warn('警告: 损失持续增加!');
      }
    }
  }
}

生产最佳实践 #

错误处理 #

javascript
/**
 * Predictor with a per-attempt timeout and bounded retries with linear backoff.
 */
class SafePredictor {
  constructor(model) {
    this.model = model;
  }

  /**
   * @param {*} input - Value accepted by tf.tensor().
   * @param {Object} [options]
   * @param {number} [options.timeout=5000] - Per-attempt timeout in ms.
   * @param {number} [options.retries=3] - Total number of attempts; values below 1 mean one attempt.
   * @returns {Promise<TypedArray>} Prediction values.
   * @throws {Error} '预测失败: …' after the last failed attempt; the original error is set as `cause`.
   */
  async predict(input, options = {}) {
    const { timeout = 5000, retries = 3 } = options;
    // Previously retries < 1 skipped the loop and silently resolved to undefined.
    const attempts = retries >= 1 ? Math.floor(retries) : 1;

    for (let attempt = 0; attempt < attempts; attempt++) {
      try {
        return await this._withTimeout(this._predict(input), timeout);
      } catch (error) {
        console.error(`预测失败 (尝试 ${attempt + 1}/${attempts}):`, error);

        if (attempt === attempts - 1) {
          throw new Error(`预测失败: ${error.message}`, { cause: error });
        }

        await this._delay(100 * (attempt + 1));
      }
    }
  }

  async _predict(input) {
    const output = tf.tidy(() => this.model.predict(tf.tensor(input)));
    try {
      // Asynchronous read, so a slow prediction can actually lose the timeout race
      // (a synchronous dataSync() would finish before the timer could ever fire).
      return await output.data();
    } finally {
      output.dispose();
    }
  }

  /** Races `promise` against a timer that is always cleared, so no stray timers keep the process alive. */
  _withTimeout(promise, ms) {
    let timer;
    const timeout = new Promise((_, reject) => {
      timer = setTimeout(() => reject(new Error('预测超时')), ms);
    });
    return Promise.race([promise, timeout]).finally(() => clearTimeout(timer));
  }

  /** Kept for backward compatibility; note its timer cannot be cancelled. */
  _timeout(ms) {
    return new Promise((_, reject) => {
      setTimeout(() => reject(new Error('预测超时')), ms);
    });
  }

  _delay(ms) {
    return new Promise(resolve => setTimeout(resolve, ms));
  }
}

资源管理 #

javascript
/**
 * Checks tf.memory() against configurable limits and warns when they are exceeded.
 *
 * It deliberately does NOT call tf.disposeVariables(). That disposes every
 * variable the engine holds, including loaded model weights, so any model in
 * use would break. Pass `onLimitExceeded` to run your own cleanup instead.
 */
class ResourceManager {
  /**
   * @param {Object} [options]
   * @param {number} [options.maxMemory=512 MiB] - Limit for tf.memory().numBytes.
   * @param {number} [options.maxTensors=1000] - Limit for tf.memory().numTensors.
   * @param {Function} [options.onLimitExceeded] - Called with the tf.memory() snapshot when a limit is exceeded.
   */
  constructor(options = {}) {
    // `!= null` rather than `||` so an explicit 0 limit is honoured.
    this.maxMemory = options.maxMemory != null ? options.maxMemory : 512 * 1024 * 1024;
    this.maxTensors = options.maxTensors != null ? options.maxTensors : 1000;
    this.onLimitExceeded = typeof options.onLimitExceeded === 'function' ? options.onLimitExceeded : null;
  }

  /** @returns {Object} The tf.memory() snapshot that was checked. */
  checkResources() {
    const mem = tf.memory();
    const overMemory = mem.numBytes > this.maxMemory;
    const overTensors = mem.numTensors > this.maxTensors;

    if (overMemory) {
      console.warn('内存使用过高');
    }
    if (overTensors) {
      console.warn('张量数量过多');
    }
    if ((overMemory || overTensors) && this.onLimitExceeded) {
      this.onLimitExceeded(mem);
    }

    return mem;
  }

  /**
   * Wraps `fn` so resources are checked before it runs and once after it
   * finishes (after settlement when it returns a Promise).
   */
  wrapOperation(fn) {
    return (...args) => {
      this.checkResources();

      let result;
      try {
        result = fn(...args);
      } catch (error) {
        this.checkResources();
        throw error;
      }

      if (result instanceof Promise) {
        return result.finally(() => this.checkResources());
      }

      this.checkResources();
      return result;
    };
  }
}

模型版本管理 #

javascript
/**
 * Loads, caches and switches between model versions stored at
 * `${baseUrl}/v${version}/model.json`.
 */
class ModelVersionManager {
  constructor(baseUrl) {
    this.baseUrl = baseUrl;
    this.models = new Map();
    // version -> in-flight load promise, so concurrent requests share one download.
    this.pending = new Map();
    this.currentVersion = null;
  }

  /** Loads (or reuses) a version without changing the active version. */
  _fetchVersion(version) {
    if (this.models.has(version)) {
      return Promise.resolve(this.models.get(version));
    }
    if (!this.pending.has(version)) {
      const modelUrl = `${this.baseUrl}/v${version}/model.json`;
      const loading = tf.loadLayersModel(modelUrl)
        .then((model) => {
          this.models.set(version, model);
          return model;
        })
        .finally(() => this.pending.delete(version));
      this.pending.set(version, loading);
    }
    return this.pending.get(version);
  }

  /** Loads `version` if needed and makes it the active version. */
  async loadVersion(version) {
    const model = await this._fetchVersion(version);
    this.currentVersion = version;
    return model;
  }

  /** @throws {Error} If no version is active (version 0 counts as a valid version). */
  getCurrentModel() {
    if (this.currentVersion === null) {
      throw new Error('没有加载任何模型版本');
    }
    return this.models.get(this.currentVersion);
  }

  /**
   * Loads several versions in parallel. The active version is left unchanged;
   * only when none is active does the last listed version become active.
   */
  async preloadVersions(versions) {
    await Promise.all(versions.map(v => this._fetchVersion(v)));
    if (this.currentVersion === null && versions.length > 0) {
      this.currentVersion = versions[versions.length - 1];
    }
  }

  disposeVersion(version) {
    const model = this.models.get(version);
    if (model) {
      model.dispose();
      this.models.delete(version);
      // Don't leave the active pointer referring to a disposed model.
      if (this.currentVersion === version) {
        this.currentVersion = null;
      }
    }
  }

  /** Disposes cached models. Loads still in flight will populate `models` when they finish. */
  disposeAll() {
    for (const model of this.models.values()) {
      model.dispose();
    }
    this.models.clear();
    this.currentVersion = null;
  }
}

日志记录 #

javascript
/**
 * Minimal structured logger. It writes one JSON line per message via
 * console.log and drops messages ranked below the configured threshold.
 */
class TFLogger {
  /** @param {string} [level='info'] - Threshold: 'debug' | 'info' | 'warn' | 'error'. */
  constructor(level = 'info') {
    this.level = level;
    this.levels = { debug: 0, info: 1, warn: 2, error: 3 };
  }

  /**
   * @param {string} level - Level of this message.
   * @param {string} message - Human-readable text.
   * @param {Object} [data={}] - Extra fields spread into the record after the standard ones.
   */
  log(level, message, data = {}) {
    const rank = this.levels[level];
    const threshold = this.levels[this.level];
    // An unknown level maps to undefined, which never passes the comparison.
    if (!(rank >= threshold)) {
      return;
    }
    const record = {
      timestamp: new Date().toISOString(),
      level,
      message,
      ...data
    };
    console.log(JSON.stringify(record));
  }

  debug(message, data) {
    this.log('debug', message, data);
  }

  info(message, data) {
    this.log('info', message, data);
  }

  warn(message, data) {
    this.log('warn', message, data);
  }

  error(message, data) {
    this.log('error', message, data);
  }

  /** Logs the current tf.memory() counters at info level. */
  logMemory() {
    const { numTensors, numBytes, numDataBuffers } = tf.memory();
    this.info('Memory status', {
      tensors: numTensors,
      bytes: numBytes,
      buffers: numDataBuffers
    });
  }
}

安全考虑 #

输入验证 #

javascript
/**
 * Validates that `input` is a nested numeric array whose element count fits
 * `expectedShape`, and that every element is a finite number.
 *
 * Only the total element count is compared, not the nesting structure.
 * Dimensions that are null/undefined or negative (e.g. the variable-size
 * entries of model.inputShape) act as wildcards: the count then only has
 * to be a multiple of the product of the known dimensions.
 *
 * @param {Array} input - Candidate input (arbitrarily nested).
 * @param {Array<number|null>} expectedShape - Expected shape without the batch dimension.
 * @returns {true}
 * @throws {Error} On non-array input, size mismatch, non-numeric or non-finite values.
 */
function validateInput(input, expectedShape) {
  if (!Array.isArray(input)) {
    throw new Error('输入必须是数组');
  }

  const flatInput = flatten(input);

  const isWildcard = (dim) => dim == null || dim < 0;
  const knownDims = expectedShape.filter((dim) => !isWildcard(dim));
  // Initial value 1 so an empty shape does not make reduce() throw.
  const knownSize = knownDims.reduce((a, b) => a * b, 1);
  const hasWildcard = knownDims.length !== expectedShape.length;

  let sizeMatches;
  if (!hasWildcard) {
    sizeMatches = flatInput.length === knownSize;
  } else if (knownSize === 0) {
    sizeMatches = flatInput.length === 0;
  } else {
    sizeMatches = flatInput.length % knownSize === 0;
  }

  if (!sizeMatches) {
    throw new Error(`输入形状不匹配: 期望 ${expectedShape}, 得到 ${flatInput.length}`);
  }

  for (const value of flatInput) {
    if (typeof value !== 'number' || Number.isNaN(value)) {
      throw new Error('输入包含非数值');
    }

    if (!Number.isFinite(value)) {
      throw new Error('输入包含无穷值');
    }
  }

  return true;
}

/**
 * Recursively flattens nested arrays into a new one-level array, skipping holes.
 * Typed arrays are not expanded. Linear-time, unlike repeated concat().
 */
function flatten(arr) {
  return arr.flat(Infinity);
}

输出限制 #

javascript
/**
 * Maps model output values into a safe numeric range.
 * NaN becomes 0, ±Infinity becomes maxValue / minValue, and finite values
 * are clamped into [minValue, maxValue].
 *
 * @param {Array<number>|TypedArray} output - Values to sanitize (not mutated).
 * @param {Object} [options]
 * @param {number} [options.maxValue=1e6]
 * @param {number} [options.minValue=-1e6]
 * @returns {Array<number>|TypedArray} Same container type as `output`.
 */
function sanitizeOutput(output, options = {}) {
  const { maxValue = 1e6, minValue = -1e6 } = options;

  const clean = (value) => {
    if (isNaN(value)) {
      return 0;
    }
    if (isFinite(value)) {
      return Math.max(minValue, Math.min(maxValue, value));
    }
    return value > 0 ? maxValue : minValue;
  };

  return output.map(clean);
}

完整项目模板 #

javascript
const tf = require('@tensorflow/tfjs-node');

/**
 * Inference service. It loads a LayersModel, warms it up, and serves
 * validated, sanitized predictions with structured logging and memory checks.
 * Depends on TFLogger, ResourceManager, validateInput and sanitizeOutput being in scope.
 */
class MLService {
  /**
   * @param {Object} config
   * @param {string} config.modelUrl - Passed to tf.loadLayersModel.
   * @param {string} [config.logLevel] - TFLogger threshold.
   * @param {Object} [config.resources] - ResourceManager options.
   */
  constructor(config) {
    this.config = config;
    this.model = null;
    this.logger = new TFLogger(config.logLevel);
    this.resourceManager = new ResourceManager(config.resources);
  }

  async init() {
    this.logger.info('初始化服务...');

    try {
      this.model = await tf.loadLayersModel(this.config.modelUrl);
      await this._warmup();
      this.logger.info('模型加载完成');
    } catch (error) {
      this.logger.error('模型加载失败', { error: error.message });
      throw error;
    }
  }

  /** Runs one dummy prediction so the first real request doesn't pay the initialization cost. */
  async _warmup() {
    // Variable-size dimensions (null) become 1 so a concrete dummy tensor can be built.
    const inputShape = this.model.inputShape.slice(1).map((dim) => (dim == null ? 1 : dim));
    const dummy = tf.zeros([1, ...inputShape]);
    let output = null;
    try {
      output = this.model.predict(dummy);
      // Multi-output models return an array of tensors.
      const outputs = Array.isArray(output) ? output : [output];
      await Promise.all(outputs.map((t) => t.data()));
    } finally {
      dummy.dispose();
      if (output !== null) {
        tf.dispose(output);
      }
    }
  }

  /**
   * @param {Array} input - A single sample (no batch dimension).
   * @returns {Promise<TypedArray>} Sanitized prediction values.
   * @throws {Error} If init() has not completed, or the input fails validation.
   */
  async predict(input) {
    if (this.model === null) {
      throw new Error('模型未加载,请先调用 init()');
    }

    return this.resourceManager.wrapOperation(async () => {
      validateInput(input, this.model.inputShape.slice(1));

      const result = tf.tidy(() => {
        const tensor = tf.tensor(input).expandDims(0);
        const output = this.model.predict(tensor);
        return output.dataSync();
      });

      return sanitizeOutput(result);
    })();
  }

  async predictBatch(inputs) {
    return Promise.all(inputs.map(input => this.predict(input)));
  }

  getHealth() {
    const mem = tf.memory();
    const modelLoaded = this.model !== null;
    return {
      status: modelLoaded ? 'healthy' : 'unhealthy',
      modelLoaded,
      memory: {
        tensors: mem.numTensors,
        bytes: mem.numBytes
      }
    };
  }

  dispose() {
    if (this.model) {
      this.model.dispose();
      // Prevent double disposal and make getHealth()/predict() reflect the shutdown.
      this.model = null;
    }
    this.logger.info('服务已关闭');
  }
}

module.exports = { MLService };

总结 #

TensorFlow.js 是一个强大的机器学习库。通过本系列教程，你已经学习了：

  1. 基础概念：张量、运算、模型构建
  2. 核心进阶：层配置、训练方法、优化算法
  3. 高级应用：CNN、RNN、迁移学习
  4. 实战部署：浏览器推理、Node.js 部署

继续实践和探索，你将成为 TensorFlow.js 专家！

最后更新：2026-03-29