TensorFlow.js Node.js 部署 #

Node.js 版本概述 #

TensorFlow.js 的 Node.js 版本通过原生绑定调用 TensorFlow C 库执行运算，通常比浏览器版本更快，并可通过 CUDA 使用 NVIDIA GPU 加速。

text
┌─────────────────────────────────────────────────────────────┐
│                  Node.js 版本优势                            │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  ┌─────────────┐  ┌─────────────┐  ┌─────────────┐         │
│  │  GPU 加速   │  │  更快速度   │  │  完整功能   │         │
│  │  CUDA 支持  │  │  原生绑定   │  │  文件系统   │         │
│  └─────────────┘  └─────────────┘  └─────────────┘         │
│  ┌─────────────┐  ┌─────────────┐  ┌─────────────┐         │
│  │  服务端渲染 │  │  批量处理   │  │  API 服务   │         │
│  │  SSR 支持   │  │  大规模     │  │  微服务     │         │
│  └─────────────┘  └─────────────┘  └─────────────┘         │
│                                                             │
└─────────────────────────────────────────────────────────────┘

安装配置 #

CPU 版本 #

bash
npm install @tensorflow/tfjs-node
javascript
const tf = require('@tensorflow/tfjs-node');

GPU 版本 #

bash
npm install @tensorflow/tfjs-node-gpu
javascript
const tf = require('@tensorflow/tfjs-node-gpu');

系统要求 #

| 版本 | 要求 |
| --- | --- |
| CPU | Node.js 12+ |
| GPU | CUDA 11.x、cuDNN 8.x |

验证安装 #

javascript
const tf = require('@tensorflow/tfjs-node');

// Report which TensorFlow.js build and backend were loaded.
const { tfjs: tfjsVersion } = tf.version;
console.log('TensorFlow.js 版本:', tfjsVersion);

const backendName = tf.getBackend();
console.log('后端:', backendName);

// A tiny tensor round-trip shows that ops actually execute on that backend.
const tensor = tf.tensor([1, 2, 3, 4]);
tensor.print();

基本使用 #

创建张量 #

javascript
const tf = require('@tensorflow/tfjs-node');

const scalar = tf.scalar(5);
const vector = tf.tensor1d([1, 2, 3]);
const matrix = tf.tensor2d([[1, 2], [3, 4]]);

// Print each tensor after a label naming its rank (scalar, vector, matrix).
const labelled = [
  ['标量:', scalar],
  ['向量:', vector],
  ['矩阵:', matrix]
];
for (const [label, t] of labelled) {
  console.log(label);
  t.print();
}

加载模型 #

javascript
const tf = require('@tensorflow/tfjs-node');

/**
 * Load the Layers model whose manifest is ./model/model.json (relative to the
 * process working directory).
 *
 * @returns {Promise<tf.LayersModel>} the loaded model
 */
async function loadModel() {
  const manifestUrl = 'file://./model/model.json';
  const loaded = await tf.loadLayersModel(manifestUrl);
  console.log('模型加载成功');
  return loaded;
}

模型预测 #

javascript
/**
 * Run `model` on a batch of feature rows and return the flattened output values.
 *
 * Both the input and output tensors are always disposed, even if prediction throws.
 *
 * @param {tf.LayersModel} model - loaded model; assumed to have a single output tensor
 * @param {number[][]} inputData - non-empty array of equal-length feature rows
 * @returns {Promise<Float32Array|Int32Array|Uint8Array>} flattened prediction values
 * @throws {Error} if inputData is not a non-empty array
 */
async function predict(model, inputData) {
  if (!Array.isArray(inputData) || inputData.length === 0) {
    throw new Error('predict: inputData must be a non-empty array of rows');
  }

  const inputTensor = tf.tensor2d(inputData, [inputData.length, inputData[0].length]);
  let prediction;
  try {
    prediction = model.predict(inputTensor);
    // data() reads the result asynchronously; dataSync() would block the event loop.
    return await prediction.data();
  } finally {
    inputTensor.dispose();
    if (prediction) {
      prediction.dispose();
    }
  }
}

文件操作 #

保存模型 #

javascript
const tf = require('@tensorflow/tfjs-node');

/**
 * Save `model` to a local directory using the file:// IO handler.
 *
 * @param {tf.LayersModel} model - the model to persist
 * @param {string} path - destination directory on disk
 */
async function saveModel(model, path) {
  const destination = `file://${path}`;
  await model.save(destination);
  console.log(`模型已保存到 ${path}`);
}

加载本地模型 #

javascript
/**
 * Load a Layers model from `<path>/model.json` on the local filesystem.
 *
 * @param {string} path - directory that contains model.json
 * @returns {Promise<tf.LayersModel>} the loaded model
 */
async function loadLocalModel(path) {
  const manifest = `file://${path}/model.json`;
  return tf.loadLayersModel(manifest);
}

图像处理 #

javascript
const tf = require('@tensorflow/tfjs-node');
const fs = require('fs');
const path = require('path');

/**
 * Read an image file and decode it into a tensor with tf.node.decodeImage.
 *
 * @param {string} imagePath - path to an image file on disk
 * @param {number} [channels=3] - channel count to decode to. 3 forces RGB, so RGBA
 *   PNGs and grayscale images yield the same channel layout the model expects.
 *   0 keeps the image's own channel count.
 * @returns {Promise<tf.Tensor>} decoded image tensor (caller must dispose it)
 */
async function loadImage(imagePath, channels = 3) {
  // Non-blocking read: this function is already async, so don't stall the event loop.
  const buffer = await fs.promises.readFile(imagePath);
  return tf.node.decodeImage(buffer, channels);
}

/**
 * Load an image, resize it, divide by 255, and add a leading batch axis.
 *
 * @param {string} imagePath - path to an image file on disk
 * @param {number[]} [targetSize=[224, 224]] - [height, width] passed to resizeBilinear
 * @param {number} [channels=3] - channel count forwarded to loadImage
 * @returns {Promise<tf.Tensor>} batch of one image (caller must dispose it)
 */
async function preprocessImage(imagePath, targetSize = [224, 224], channels = 3) {
  const image = await loadImage(imagePath, channels);
  try {
    // tidy() frees the resized/normalized intermediates even if an op throws.
    return tf.tidy(() => tf.image.resizeBilinear(image, targetSize).div(255).expandDims(0));
  } finally {
    image.dispose();
  }
}

批量图像处理 #

javascript
/**
 * Preprocess every JPEG/PNG image in `dir` and concatenate them into one batch.
 *
 * Files are processed in sorted filename order so the row order of the batch is
 * reproducible regardless of the order the filesystem lists them in.
 *
 * @param {string} dir - directory to scan (non-recursive)
 * @param {number[]} [targetSize=[224, 224]] - [height, width] forwarded to preprocessImage
 * @returns {Promise<tf.Tensor>} all images concatenated along axis 0
 * @throws {Error} if the directory contains no matching images
 */
async function loadImagesFromDirectory(dir, targetSize = [224, 224]) {
  // Case-insensitive, so photo.JPG / photo.jpeg are not silently skipped.
  const isImage = (name) => /\.(jpe?g|png)$/i.test(name);
  const files = fs.readdirSync(dir).filter(isImage).sort();
  if (files.length === 0) {
    throw new Error(`No .jpg/.jpeg/.png images found in ${dir}`);
  }

  const tensors = [];
  try {
    for (const file of files) {
      tensors.push(await preprocessImage(path.join(dir, file), targetSize));
    }
    return tf.concat(tensors, 0);
  } finally {
    // concat() produces a new tensor; the per-image batches are no longer needed.
    for (const t of tensors) {
      t.dispose();
    }
  }
}

Express API 服务 #

创建 API 服务 #

javascript
const express = require('express');
const multer = require('multer');
const tf = require('@tensorflow/tfjs-node');
// The /predict handler reads and deletes the uploaded file, so fs is required here.
const fs = require('fs');

const app = express();
const upload = multer({ dest: 'uploads/' });

let model = null;

/** Load the model once; the server only starts listening after this resolves. */
async function initModel() {
  model = await tf.loadLayersModel('file://./model/model.json');
  console.log('模型加载完成');
}

app.post('/predict', upload.single('image'), async (req, res) => {
  if (!req.file) {
    return res.status(400).json({ error: 'Missing "image" file field' });
  }

  let processed;
  let prediction;
  try {
    const buffer = await fs.promises.readFile(req.file.path);
    // channels=3 forces RGB (RGBA PNGs / grayscale would otherwise change the
    // channel count); tidy() frees the decoded/resized intermediates even on error.
    processed = tf.tidy(() =>
      tf.image
        .resizeBilinear(tf.node.decodeImage(buffer, 3), [224, 224])
        .div(255)
        .expandDims(0)
    );
    prediction = model.predict(processed);
    const result = await prediction.data();

    res.json({
      success: true,
      prediction: Array.from(result)
    });
  } catch (error) {
    res.status(500).json({ error: error.message });
  } finally {
    if (processed) processed.dispose();
    if (prediction) prediction.dispose();
    // Always remove the temporary upload, including when prediction failed.
    fs.promises.unlink(req.file.path).catch((err) => {
      console.error('删除上传文件失败:', err.message);
    });
  }
});

const PORT = process.env.PORT || 3000;
initModel()
  .then(() => {
    app.listen(PORT, () => {
      console.log(`服务运行在端口 ${PORT}`);
    });
  })
  .catch((error) => {
    // Without a model every request would fail; exit so a supervisor can restart.
    console.error('模型加载失败:', error);
    process.exit(1);
  });

批量预测 API #

javascript
// Predict each uploaded image (up to 10, field "images") one at a time.
// Every uploaded temp file is deleted afterwards, even if an image fails midway.
app.post('/predict/batch', upload.array('images', 10), async (req, res) => {
  const files = req.files || [];
  try {
    const predictions = [];

    for (const file of files) {
      const buffer = await fs.promises.readFile(file.path);
      // channels=3 forces RGB input; tidy() frees intermediates even if an op throws.
      const processed = tf.tidy(() =>
        tf.image
          .resizeBilinear(tf.node.decodeImage(buffer, 3), [224, 224])
          .div(255)
          .expandDims(0)
      );

      let prediction;
      try {
        prediction = model.predict(processed);
        predictions.push(Array.from(await prediction.data()));
      } finally {
        processed.dispose();
        if (prediction) prediction.dispose();
      }
    }

    res.json({
      success: true,
      count: predictions.length,
      predictions
    });
  } catch (error) {
    res.status(500).json({ error: error.message });
  } finally {
    // Clean up all uploads, not only the ones processed before a failure.
    await Promise.all(
      files.map((file) =>
        fs.promises.unlink(file.path).catch((err) => {
          console.error('删除上传文件失败:', err.message);
        })
      )
    );
  }
});

GPU 加速 #

检查 GPU #

javascript
const tf = require('@tensorflow/tfjs-node-gpu');

/**
 * Log the active backend, then the version information returned by
 * tf.node.queryTensorflowVersion().
 *
 * NOTE(review): confirm that tf.node.queryTensorflowVersion exists in the
 * installed @tensorflow/tfjs-node-gpu release; it is not part of the commonly
 * documented tf.node API surface.
 */
async function checkGPU() {
  const backend = tf.getBackend();
  console.log('后端:', backend);

  const versionInfo = await tf.node.queryTensorflowVersion();
  console.log('TensorFlow 版本:', versionInfo);
}

GPU 内存管理 #

javascript
/**
 * Log the number of live tensors and their total size (in MB), as reported by
 * tf.memory() for the active backend.
 *
 * @returns {object} the raw tf.memory() snapshot (numTensors, numBytes, ...)
 */
function gpuMemoryInfo() {
  const info = tf.memory();
  console.log({
    张量数量: info.numTensors,
    内存使用: (info.numBytes / 1024 / 1024).toFixed(2) + ' MB'
  });
  return info;
}

// unref() lets the process exit once other work is done; a plain setInterval
// would keep a one-off script running forever just to print memory stats.
const memoryMonitor = setInterval(gpuMemoryInfo, 5000);
memoryMonitor.unref();

GPU 配置 #

javascript
// The native TensorFlow runtime reads these as process environment variables
// when the binding loads, so set them BEFORE requiring tfjs-node-gpu.
// They cannot go through tf.env(): tf.env().set() throws for flags that
// TensorFlow.js has not registered, and WEBGL_* flags only configure the browser
// WebGL backend, not the native one used here.
process.env.TF_ENABLE_ONEDNN_OPTS = '1';
// Allocate GPU memory as needed instead of reserving most of it at startup.
process.env.TF_FORCE_GPU_ALLOW_GROWTH = 'true';

const tf = require('@tensorflow/tfjs-node-gpu');

性能优化 #

分批处理 #

javascript
/**
 * Run `model` over `inputs` in fixed-size batches and collect per-sample outputs.
 *
 * Batches run one after another, not concurrently; batching only amortizes
 * per-call overhead. Each batch's tensors are released even if prediction throws.
 *
 * @param {tf.LayersModel} model - loaded model with a single output
 * @param {tf.Tensor[]} inputs - per-sample tensors of identical shape (stacked on a new axis 0)
 * @param {number} [batchSize=32] - positive integer number of samples per predict() call
 * @returns {Promise<Array>} one nested-array prediction per input, in input order
 * @throws {RangeError} if batchSize is not a positive integer
 */
async function parallelPredict(model, inputs, batchSize = 32) {
  // A non-positive step would never advance the loop below.
  if (!Number.isInteger(batchSize) || batchSize < 1) {
    throw new RangeError('batchSize must be a positive integer');
  }

  const results = [];
  for (let start = 0; start < inputs.length; start += batchSize) {
    const batchTensor = tf.stack(inputs.slice(start, start + batchSize));
    let predictions;
    try {
      predictions = model.predict(batchTensor);
      for (const row of await predictions.array()) {
        results.push(row);
      }
    } finally {
      batchTensor.dispose();
      if (predictions) predictions.dispose();
    }
  }

  return results;
}

Worker 线程 #

javascript
const { Worker, isMainThread, parentPort, workerData } = require('worker_threads');
const tf = require('@tensorflow/tfjs-node');

/**
 * Run one chunk of inputs in a new worker thread that executes this same file.
 * Resolves with the worker's predictions. Rejects if the worker reports an error,
 * emits 'error', or exits with a non-zero code before replying.
 */
function runWorker(chunk) {
  return new Promise((resolve, reject) => {
    const worker = new Worker(__filename, { workerData: chunk });
    worker.once('message', (reply) => {
      if (reply.ok) {
        resolve(reply.predictions);
      } else {
        reject(new Error(reply.message));
      }
    });
    worker.once('error', reject);
    worker.once('exit', (code) => {
      // Has no effect if the promise already settled via 'message' or 'error'.
      if (code !== 0) {
        reject(new Error(`Worker stopped with exit code ${code}`));
      }
    });
  });
}

/**
 * Split `inputs` into contiguous chunks, predict each chunk in its own worker
 * thread (each worker loads its own copy of the model), and return all
 * predictions in input order.
 *
 * @param {Array} inputs - structured-cloneable per-sample arrays
 * @param {number} [numWorkers=4] - maximum number of worker threads to start
 * @returns {Promise<Array>} one prediction array per input
 */
async function parallelInference(inputs, numWorkers = 4) {
  // Guard: a chunk size of 0 would never advance the chunking loop.
  if (inputs.length === 0) {
    return [];
  }

  const chunkSize = Math.ceil(inputs.length / Math.max(1, numWorkers));
  const chunks = [];
  for (let start = 0; start < inputs.length; start += chunkSize) {
    chunks.push(inputs.slice(start, start + chunkSize));
  }

  const results = await Promise.all(chunks.map((chunk) => runWorker(chunk)));
  return results.flat();
}

/** Worker side: load the model and predict each sample in `inputs` one at a time. */
async function processChunk(inputs) {
  const model = await tf.loadLayersModel('file://./model/model.json');
  const predictions = [];

  for (const input of inputs) {
    // tidy() also frees the un-batched tensor that expandDims() was called on.
    const prediction = tf.tidy(() => model.predict(tf.tensor(input).expandDims(0)));
    try {
      predictions.push(await prediction.data());
    } finally {
      prediction.dispose();
    }
  }

  return predictions;
}

if (isMainThread) {
  // Export the entry point so other modules can actually call it.
  module.exports = { parallelInference };
} else {
  // Always reply to the parent, so it never waits on a worker that failed.
  processChunk(workerData)
    .then((predictions) => parentPort.postMessage({ ok: true, predictions }))
    .catch((error) => parentPort.postMessage({ ok: false, message: error.message }));
}

内存优化 #

javascript
/** Wraps a Layers model loaded from disk and exposes leak-free prediction helpers. */
class ModelPredictor {
  /**
   * @param {string} modelPath - filesystem path to the model.json manifest
   */
  constructor(modelPath) {
    this.modelPath = modelPath;
    this.model = null;
  }

  /** Load the model, then run one warm-up prediction. */
  async init() {
    this.model = await tf.loadLayersModel(`file://${this.modelPath}`);
    await this.warmup();
  }

  /**
   * Push one zero-filled batch of size 1 through the model.
   * Assumes a single-input model whose non-batch dimensions are all fixed — TODO confirm.
   */
  async warmup() {
    const inputShape = this.model.inputShape.slice(1);
    const dummy = tf.zeros([1, ...inputShape]);
    let output;
    try {
      output = this.model.predict(dummy);
      await output.data();
    } finally {
      // Release both the dummy input and the warm-up output tensor.
      dummy.dispose();
      if (output) output.dispose();
    }
  }

  /**
   * Predict a single sample.
   *
   * @param {Array} input - one sample without the batch dimension
   * @returns {Promise<Float32Array|Int32Array|Uint8Array>} flattened output values
   */
  async predict(input) {
    // tidy() frees the intermediate input tensors and keeps only the returned output.
    const output = tf.tidy(() => this.model.predict(tf.tensor(input).expandDims(0)));
    try {
      // Asynchronous read; dataSync() would block the event loop.
      return await output.data();
    } finally {
      output.dispose();
    }
  }

  /**
   * Predict several samples one after another.
   *
   * @param {Array[]} inputs - samples without the batch dimension
   * @returns {Promise<Array>} one result per input, in order
   */
  async predictBatch(inputs) {
    const results = [];

    for (const input of inputs) {
      results.push(await this.predict(input));
    }

    return results;
  }
}

数据处理 #

CSV 数据加载 #

javascript
const tf = require('@tensorflow/tfjs-node');
const fs = require('fs');
const csv = require('csv-parser');

/**
 * Parse a CSV file into an array of row objects (one object per data row).
 *
 * pipe() does not forward errors between streams, so both the file stream and
 * the parser get their own 'error' handler; otherwise a missing file or a parse
 * failure would leave this promise pending forever.
 *
 * @param {string} filePath
 * @returns {Promise<object[]>} parsed rows as emitted by csv-parser
 */
async function loadCSV(filePath) {
  return new Promise((resolve, reject) => {
    const results = [];

    fs.createReadStream(filePath)
      .on('error', reject)
      .pipe(csv())
      .on('error', reject)
      .on('data', (data) => results.push(data))
      .on('end', () => resolve(results));
  });
}

/**
 * Load a CSV file as a feature matrix and a label vector.
 *
 * Every column except `labelColumn` becomes a feature, in the column order of
 * each row object; values are converted with Number().
 *
 * @param {string} filePath
 * @param {string} labelColumn - header name of the label column
 * @returns {Promise<{xs: tf.Tensor2D, ys: tf.Tensor1D}>}
 * @throws {Error} if the file has no data rows
 */
async function loadCSVTensors(filePath, labelColumn) {
  const data = await loadCSV(filePath);
  if (data.length === 0) {
    throw new Error(`CSV file ${filePath} contains no data rows`);
  }

  const features = data.map(row => {
    const f = { ...row };
    delete f[labelColumn];
    return Object.values(f).map(Number);
  });

  const labels = data.map(row => Number(row[labelColumn]));

  return {
    xs: tf.tensor2d(features),
    ys: tf.tensor1d(labels)
  };
}

JSON 数据加载 #

javascript
/**
 * Read and parse a UTF-8 JSON file without blocking the event loop.
 *
 * @param {string} filePath
 * @returns {Promise<*>} the parsed JSON value
 */
async function loadJSON(filePath) {
  const content = await fs.promises.readFile(filePath, 'utf8');
  return JSON.parse(content);
}

/**
 * Load a JSON array of records as a feature matrix and a label vector.
 *
 * @param {string} filePath - file containing a JSON array of objects
 * @param {string[]} featureKeys - keys read from each record, in column order
 * @param {string} labelKey - key holding each record's label
 * @returns {Promise<{xs: tf.Tensor2D, ys: tf.Tensor1D}>}
 * @throws {Error} if the file does not contain a non-empty JSON array
 */
async function loadJSONTensors(filePath, featureKeys, labelKey) {
  const data = await loadJSON(filePath);
  if (!Array.isArray(data) || data.length === 0) {
    throw new Error(`Expected a non-empty JSON array of records in ${filePath}`);
  }

  const features = data.map(item =>
    featureKeys.map(key => item[key])
  );

  const labels = data.map(item => item[labelKey]);

  return {
    xs: tf.tensor2d(features),
    ys: tf.tensor1d(labels)
  };
}

生产部署 #

Docker 部署 #

dockerfile
# For @tensorflow/tfjs-node-gpu, use a CUDA/cuDNN-enabled base image instead;
# node:18 ships no CUDA libraries.
FROM node:18

WORKDIR /app

# Install production dependencies exactly as pinned in package-lock.json (npm ci
# requires the lockfile). tfjs-node sets up its native TensorFlow binding at
# install time, so it must be installed inside the image.
COPY package*.json ./
RUN npm ci --omit=dev

# Exclude the host node_modules via .dockerignore; otherwise this COPY would
# overwrite the binding installed above with one built for the host platform.
COPY . .

ENV NODE_ENV=production

EXPOSE 3000

CMD ["node", "server.js"]

Docker Compose #

yaml
version: '3.8'  # ignored by Compose v2, kept for compatibility with docker-compose v1
services:
  tfjs-api:
    build: .
    ports:
      - "3000:3000"
    environment:
      - NODE_ENV=production
      - PORT=3000
    volumes:
      # The server loads file://./model/model.json relative to /app, so the model
      # directory must be mounted at /app/model (not /app/models).
      - ./model:/app/model
    deploy:
      resources:
        reservations:
          devices:
            # Only effective with @tensorflow/tfjs-node-gpu and a CUDA-enabled image.
            - driver: nvidia
              count: 1
              capabilities: [gpu]

PM2 部署 #

javascript
// PM2 ecosystem entry: run server.js in cluster mode with instances set to 'max'.
// NOTE(review): each cluster process loads its own copy of the model (and, with
// tfjs-node-gpu, its own GPU context) — confirm memory headroom before using 'max'.
const tfjsApi = {
  name: 'tfjs-api',
  script: 'server.js',
  instances: 'max',
  exec_mode: 'cluster',
  // Applied when started with `--env production`.
  env_production: {
    NODE_ENV: 'production',
    PORT: 3000
  }
};

module.exports = { apps: [tfjsApi] };

健康检查 #

javascript
// Liveness endpoint: reports backend name, live tensor stats and whether the model is loaded.
app.get('/health', async (req, res) => {
  const { numTensors, numBytes } = tf.memory();

  const payload = {
    status: 'healthy',
    backend: tf.getBackend(),
    memory: { tensors: numTensors, bytes: numBytes },
    modelLoaded: model !== null
  };

  res.json(payload);
});

完整示例 #

javascript
const express = require('express');
const multer = require('multer');
const tf = require('@tensorflow/tfjs-node');
const fs = require('fs');
const path = require('path');

/** Owns the loaded model and turns image files into predictions without leaking tensors. */
class TensorFlowService {
  /** @param {string} modelPath - filesystem path to the model.json manifest */
  constructor(modelPath) {
    this.modelPath = modelPath;
    this.model = null;
  }

  /** Load the model, then run a warm-up prediction before serving traffic. */
  async init() {
    console.log('加载模型...');
    this.model = await tf.loadLayersModel(`file://${this.modelPath}`);
    await this.warmup();
    console.log('模型加载完成');
  }

  /**
   * Push one zero-filled batch of size 1 through the model.
   * Assumes a single-input model whose non-batch dimensions are fixed — TODO confirm.
   */
  async warmup() {
    const inputShape = this.model.inputShape.slice(1);
    const dummy = tf.zeros([1, ...inputShape]);
    let output;
    try {
      output = this.model.predict(dummy);
      await output.data();
    } finally {
      // Release both the dummy input and the warm-up output.
      dummy.dispose();
      if (output) output.dispose();
    }
  }

  /**
   * Decode an image file, resize it to 224x224, divide by 255 and run the model.
   *
   * @param {string} imagePath - path to an image file on disk
   * @returns {Promise<number[]>} flattened prediction values
   */
  async predictFromImage(imagePath) {
    const buffer = await fs.promises.readFile(imagePath);

    // channels=3 forces RGB so RGBA PNGs and grayscale images match the model input.
    // tidy() frees the decoded and resized intermediates even if an op throws.
    const processed = tf.tidy(() =>
      tf.image
        .resizeBilinear(tf.node.decodeImage(buffer, 3), [224, 224])
        .cast('float32')
        .div(255)
        .expandDims(0)
    );

    let prediction;
    try {
      prediction = this.model.predict(processed);
      return Array.from(await prediction.data());
    } finally {
      processed.dispose();
      if (prediction) prediction.dispose();
    }
  }

  getMemoryInfo() {
    return tf.memory();
  }
}

const app = express();
const upload = multer({ dest: 'uploads/' });
const tfService = new TensorFlowService('./model/model.json');

app.post('/predict', upload.single('image'), async (req, res) => {
  if (!req.file) {
    return res.status(400).json({ error: 'Missing "image" file field' });
  }

  try {
    const result = await tfService.predictFromImage(req.file.path);
    res.json({ success: true, prediction: result });
  } catch (error) {
    res.status(500).json({ error: error.message });
  } finally {
    // Remove the temporary upload whether or not prediction succeeded.
    fs.promises.unlink(req.file.path).catch((err) => {
      console.error('删除上传文件失败:', err.message);
    });
  }
});

app.get('/health', (req, res) => {
  const mem = tfService.getMemoryInfo();
  res.json({
    status: 'healthy',
    memory: mem
  });
});

const PORT = process.env.PORT || 3000;

tfService
  .init()
  .then(() => {
    app.listen(PORT, () => {
      console.log(`服务运行在端口 ${PORT}`);
    });
  })
  .catch((error) => {
    // Without a model every request would fail; exit so a supervisor can restart.
    console.error('模型加载失败:', error);
    process.exit(1);
  });

下一步 #

现在你已经掌握了 Node.js 部署，接下来学习「高级应用」，了解更多最佳实践！

最后更新:2026-03-29