TensorFlow.js Node.js 部署 #
Node.js 版本概述 #
TensorFlow.js Node.js 版本提供了更强大的计算能力和 GPU 支持。
text
┌─────────────────────────────────────────────────────────────┐
│ Node.js 版本优势 │
├─────────────────────────────────────────────────────────────┤
│ │
│ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ │
│ │ GPU 加速 │ │ 更快速度 │ │ 完整功能 │ │
│ │ CUDA 支持 │ │ 原生绑定 │ │ 文件系统 │ │
│ └─────────────┘ └─────────────┘ └─────────────┘ │
│ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ │
│ │ 服务端渲染 │ │ 批量处理 │ │ API 服务 │ │
│ │ SSR 支持 │ │ 大规模 │ │ 微服务 │ │
│ └─────────────┘ └─────────────┘ └─────────────┘ │
│ │
└─────────────────────────────────────────────────────────────┘
安装配置 #
CPU 版本 #
bash
npm install @tensorflow/tfjs-node
javascript
const tf = require('@tensorflow/tfjs-node');
GPU 版本 #
bash
npm install @tensorflow/tfjs-node-gpu
javascript
const tf = require('@tensorflow/tfjs-node-gpu');
系统要求 #
| 版本 | 要求 |
|---|---|
| CPU | Node.js 12+ |
| GPU | CUDA 11.x, cuDNN 8.x |
验证安装 #
javascript
// Smoke-test the installation: report the tfjs version and active backend,
// then build and print a small tensor to confirm ops actually run.
const tf = require('@tensorflow/tfjs-node');
console.log('TensorFlow.js 版本:', tf.version.tfjs);
console.log('后端:', tf.getBackend());
const sample = tf.tensor([1, 2, 3, 4]);
sample.print();
基本使用 #
创建张量 #
javascript
// Build one tensor of each rank (0, 1, 2) and print each with its label.
const tf = require('@tensorflow/tfjs-node');
const examples = [
  ['标量:', tf.scalar(5)],
  ['向量:', tf.tensor1d([1, 2, 3])],
  ['矩阵:', tf.tensor2d([[1, 2], [3, 4]])],
];
for (const [label, t] of examples) {
  console.log(label);
  t.print();
}
加载模型 #
javascript
const tf = require('@tensorflow/tfjs-node');
/**
 * Load the layers-format model from ./model via a file:// URL.
 * @returns {Promise<tf.LayersModel>} the loaded model
 */
async function loadModel() {
  const loadedModel = await tf.loadLayersModel('file://./model/model.json');
  console.log('模型加载成功');
  return loadedModel;
}
模型预测 #
javascript
/**
 * Run the model on a batch of rows and return the raw output values.
 * @param {tf.LayersModel} model - a loaded model
 * @param {number[][]} inputData - 2-D array, one row per sample
 * @returns {Float32Array|Int32Array|Uint8Array} flattened prediction values
 */
async function predict(model, inputData) {
  const inputTensor = tf.tensor2d(inputData, [inputData.length, inputData[0].length]);
  let prediction = null;
  try {
    // NOTE(review): assumes a single-output model; multi-output models
    // return an array of tensors — confirm against the model being served.
    prediction = model.predict(inputTensor);
    return prediction.dataSync();
  } finally {
    // Dispose even when predict()/dataSync() throws, so no tensors leak.
    inputTensor.dispose();
    if (prediction !== null) prediction.dispose();
  }
}
文件操作 #
保存模型 #
javascript
const tf = require('@tensorflow/tfjs-node');
/**
 * Persist a model to disk in the layers format.
 * @param {tf.LayersModel} model - model to save
 * @param {string} path - target directory
 */
async function saveModel(model, path) {
  const target = `file://${path}`;
  await model.save(target);
  console.log(`模型已保存到 ${path}`);
}
加载本地模型 #
javascript
/**
 * Load a previously-saved model from a local directory.
 * @param {string} path - directory containing model.json
 * @returns {Promise<tf.LayersModel>} the loaded model
 */
async function loadLocalModel(path) {
  return tf.loadLayersModel(`file://${path}/model.json`);
}
图像处理 #
javascript
const tf = require('@tensorflow/tfjs-node');
const fs = require('fs');
const path = require('path');
/**
 * Read an image file and decode it into a pixel tensor.
 * @param {string} imagePath - path to a JPEG/PNG/GIF/BMP file
 * @returns {Promise<tf.Tensor>} decoded image tensor
 */
async function loadImage(imagePath) {
  const fileBytes = fs.readFileSync(imagePath);
  return tf.node.decodeImage(fileBytes);
}
/**
 * Load an image and turn it into a model-ready batch tensor.
 * @param {string} imagePath - path to the image file
 * @param {[number, number]} targetSize - [height, width], default [224, 224]
 * @returns {Promise<tf.Tensor4D>} float tensor [1, h, w, c], values in [0, 1]
 */
async function preprocessImage(imagePath, targetSize = [224, 224]) {
  const tensor = await loadImage(imagePath);
  try {
    // tf.tidy disposes every intermediate (resized, pre-division) tensor
    // automatically, even when one of the ops throws — the original leaked
    // all three intermediates on failure.
    return tf.tidy(() =>
      tf.image.resizeBilinear(tensor, targetSize)
        .toFloat() // normalize on floats, matching the complete example below
        .div(255)
        .expandDims(0)
    );
  } finally {
    tensor.dispose();
  }
}
批量图像处理 #
javascript
/**
 * Preprocess every .jpg/.png in a directory into one batched tensor.
 * @param {string} dir - directory to scan (non-recursive)
 * @param {[number, number]} targetSize - [height, width] per image
 * @returns {Promise<tf.Tensor4D>} tensor of shape [numImages, h, w, c]
 * @throws {Error} if the directory contains no matching images
 */
async function loadImagesFromDirectory(dir, targetSize = [224, 224]) {
  const files = fs.readdirSync(dir).filter(f =>
    f.endsWith('.jpg') || f.endsWith('.png')
  );
  if (files.length === 0) {
    // tf.concat([]) fails with a cryptic message; fail with a clear one.
    throw new Error(`no .jpg/.png images found in ${dir}`);
  }
  const tensors = [];
  for (const file of files) {
    const imagePath = path.join(dir, file);
    tensors.push(await preprocessImage(imagePath, targetSize));
  }
  const batch = tf.concat(tensors, 0);
  // concat copies the data, so the per-image tensors must be released —
  // the original leaked one [1, h, w, c] buffer per file.
  tensors.forEach(t => t.dispose());
  return batch;
}
Express API 服务 #
创建 API 服务 #
javascript
const express = require('express');
const multer = require('multer');
const fs = require('fs'); // was missing: readFileSync/unlinkSync below need it
const tf = require('@tensorflow/tfjs-node');
const app = express();
const upload = multer({ dest: 'uploads/' });
let model = null;
/** Load the layers model from disk into the module-level `model`. */
async function initModel() {
  model = await tf.loadLayersModel('file://./model/model.json');
  console.log('模型加载完成');
}
// POST /predict — accept one uploaded image and return raw model output.
app.post('/predict', upload.single('image'), async (req, res) => {
  let tensor = null;
  let processed = null;
  let prediction = null;
  try {
    if (model === null) {
      // Guard against requests that race ahead of initModel().
      return res.status(503).json({ error: 'model not loaded yet' });
    }
    const buffer = fs.readFileSync(req.file.path);
    tensor = tf.node.decodeImage(buffer);
    processed = tf.image.resizeBilinear(tensor, [224, 224])
      .div(255)
      .expandDims(0);
    prediction = model.predict(processed);
    const result = await prediction.data();
    res.json({
      success: true,
      prediction: Array.from(result)
    });
  } catch (error) {
    res.status(500).json({ error: error.message });
  } finally {
    // Release tensors and the uploaded temp file on success AND failure —
    // the original leaked both whenever decode/predict threw.
    if (tensor) tensor.dispose();
    if (processed) processed.dispose();
    if (prediction) prediction.dispose();
    if (req.file) fs.unlinkSync(req.file.path);
  }
});
const PORT = process.env.PORT || 3000;
initModel().then(() => {
  app.listen(PORT, () => {
    console.log(`服务运行在端口 ${PORT}`);
  });
});
批量预测 API #
javascript
// POST /predict/batch — predict up to 10 uploaded images in one request.
app.post('/predict/batch', upload.array('images', 10), async (req, res) => {
  try {
    const predictions = [];
    for (const file of req.files) {
      let tensor = null;
      let processed = null;
      let prediction = null;
      try {
        const buffer = fs.readFileSync(file.path);
        tensor = tf.node.decodeImage(buffer);
        processed = tf.image.resizeBilinear(tensor, [224, 224])
          .div(255)
          .expandDims(0);
        prediction = model.predict(processed);
        const result = await prediction.data();
        predictions.push(Array.from(result));
      } finally {
        // Per-file cleanup runs even when decode/predict fails, so one bad
        // upload can no longer leak tensors or temp files for the whole batch.
        if (tensor) tensor.dispose();
        if (processed) processed.dispose();
        if (prediction) prediction.dispose();
        fs.unlinkSync(file.path);
      }
    }
    res.json({
      success: true,
      count: predictions.length,
      predictions
    });
  } catch (error) {
    res.status(500).json({ error: error.message });
  }
});
GPU 加速 #
检查 GPU #
javascript
const tf = require('@tensorflow/tfjs-node-gpu');
// Print which backend is active and the underlying TensorFlow version.
// NOTE(review): `tf.node.queryTensorflowVersion()` does not appear in the
// public tfjs-node API docs — verify this call exists in the installed version.
async function checkGPU() {
console.log('后端:', tf.getBackend());
const gpuInfo = await tf.node.queryTensorflowVersion();
console.log('TensorFlow 版本:', gpuInfo);
}
GPU 内存管理 #
javascript
/** Log the current tfjs tensor count and approximate memory footprint. */
function gpuMemoryInfo() {
  const snapshot = tf.memory();
  const megabytes = (snapshot.numBytes / 1024 / 1024).toFixed(2) + ' MB';
  console.log({
    张量数量: snapshot.numTensors,
    内存使用: megabytes
  });
}
// Poll every five seconds (keeps running until the process exits).
setInterval(gpuMemoryInfo, 5000);
GPU 配置 #
javascript
const tf = require('@tensorflow/tfjs-node-gpu');
// NOTE(review): WEBGL_* flags target the browser WebGL backend; under
// tfjs-node-gpu (CUDA backend) this setting likely has no effect — verify.
tf.env().set('WEBGL_FORCE_F16_TEXTURES', true);
// NOTE(review): TF_ENABLE_ONEDNN_OPTS is normally an OS environment variable
// read by native TensorFlow, not a tf.env() flag — confirm it is honored here.
tf.env().set('TF_ENABLE_ONEDNN_OPTS', 1);
性能优化 #
并行处理 #
javascript
/**
 * Predict `inputs` in fixed-size batches. Despite the name, batches run one
 * after another; the throughput win comes from stacking samples per batch.
 * @param {tf.LayersModel} model - loaded model
 * @param {tf.Tensor[]} inputs - individual sample tensors
 * @param {number} batchSize - samples stacked per predict() call
 * @returns {Promise<Array>} per-sample prediction arrays, in input order
 */
async function parallelPredict(model, inputs, batchSize = 32) {
  const results = [];
  let offset = 0;
  while (offset < inputs.length) {
    const group = inputs.slice(offset, offset + batchSize);
    const stacked = tf.stack(group);
    const output = model.predict(stacked);
    const values = await output.array();
    for (const v of values) results.push(v);
    stacked.dispose();
    output.dispose();
    offset += batchSize;
  }
  return results;
}
Worker 线程 #
javascript
const { Worker, isMainThread, parentPort, workerData } = require('worker_threads');
const tf = require('@tensorflow/tfjs-node');
if (isMainThread) {
  /**
   * Split `inputs` across worker threads and collect their predictions.
   * Chunk order is preserved because Promise.all keeps positional order.
   * @param {Array} inputs - raw sample arrays to predict
   * @param {number} numWorkers - worker threads to spawn
   * @returns {Promise<Array>} flattened predictions from all workers
   */
  async function parallelInference(inputs, numWorkers = 4) {
    const chunkSize = Math.ceil(inputs.length / numWorkers);
    const workers = [];
    for (let i = 0; i < numWorkers; i++) {
      const chunk = inputs.slice(i * chunkSize, (i + 1) * chunkSize);
      workers.push(new Worker(__filename, { workerData: chunk }));
    }
    const results = await Promise.all(
      workers.map(worker => {
        return new Promise((resolve, reject) => {
          worker.on('message', resolve);
          // Without these handlers the promise hangs forever when a worker
          // throws or crashes — the original never settled on failure.
          worker.on('error', reject);
          worker.on('exit', code => {
            if (code !== 0) reject(new Error(`worker exited with code ${code}`));
          });
        });
      })
    );
    return results.flat();
  }
} else {
  /** Worker side: load the model and predict each input in `workerData`. */
  async function processChunk() {
    const model = await tf.loadLayersModel('file://./model/model.json');
    const predictions = [];
    for (const input of workerData) {
      const tensor = tf.tensor(input).expandDims(0);
      const prediction = model.predict(tensor);
      predictions.push(await prediction.data());
      tensor.dispose();
      prediction.dispose();
    }
    parentPort.postMessage(predictions);
  }
  processChunk();
}
内存优化 #
javascript
/** Wraps a layers model with explicit loading, warmup, and tidy predictions. */
class ModelPredictor {
  /** @param {string} modelPath - path to model.json on disk */
  constructor(modelPath) {
    this.modelPath = modelPath;
    this.model = null;
  }
  /** Load the model from disk, then run a warmup pass. */
  async init() {
    this.model = await tf.loadLayersModel(`file://${this.modelPath}`);
    await this.warmup();
  }
  /** One dummy prediction so kernel setup happens before real traffic. */
  async warmup() {
    // NOTE(review): assumes model.inputShape yields concrete (non-null)
    // dims after the batch dimension — confirm for dynamic-shape models.
    const shapeWithoutBatch = this.model.inputShape.slice(1);
    const placeholder = tf.zeros([1, ...shapeWithoutBatch]);
    await this.model.predict(placeholder).data();
    placeholder.dispose();
  }
  /** Predict a single sample; tf.tidy reclaims all intermediates. */
  async predict(input) {
    return tf.tidy(() => {
      const batch = tf.tensor(input).expandDims(0);
      return this.model.predict(batch).dataSync();
    });
  }
  /** Predict each sample sequentially and collect the results. */
  async predictBatch(inputs) {
    const collected = [];
    for (const sample of inputs) {
      collected.push(await this.predict(sample));
    }
    return collected;
  }
}
数据处理 #
CSV 数据加载 #
javascript
const tf = require('@tensorflow/tfjs-node');
const fs = require('fs');
const csv = require('csv-parser');
/**
 * Parse a CSV file into an array of row objects (one per record).
 * @param {string} filePath - path to the CSV file
 * @returns {Promise<Object[]>} resolves with all rows; rejects on I/O or parse errors
 */
async function loadCSV(filePath) {
  return new Promise((resolve, reject) => {
    const results = [];
    fs.createReadStream(filePath)
      .on('error', reject) // was missing: an unreadable file hung the promise forever
      .pipe(csv())
      .on('error', reject) // surface parser errors instead of swallowing them
      .on('data', (data) => results.push(data))
      .on('end', () => resolve(results));
  });
}
/**
 * Load a CSV and split it into feature/label tensors.
 * NOTE(review): feature column order follows the parsed rows' property
 * insertion order — verify it matches what the model expects.
 * @param {string} filePath - path to the CSV file
 * @param {string} labelColumn - column name used as the label
 * @returns {Promise<{xs: tf.Tensor2D, ys: tf.Tensor1D}>}
 */
async function loadCSVTensors(filePath, labelColumn) {
  const rows = await loadCSV(filePath);
  const featureRows = rows.map(row =>
    Object.entries(row)
      .filter(([key]) => key !== labelColumn)
      .map(([, value]) => Number(value))
  );
  const labelValues = rows.map(row => Number(row[labelColumn]));
  return {
    xs: tf.tensor2d(featureRows),
    ys: tf.tensor1d(labelValues)
  };
}
JSON 数据加载 #
javascript
/**
 * Read a UTF-8 JSON file and parse it.
 * @param {string} filePath - path to the JSON file
 * @returns {Promise<any>} the parsed value
 */
async function loadJSON(filePath) {
  const raw = fs.readFileSync(filePath, { encoding: 'utf8' });
  return JSON.parse(raw);
}
/**
 * Load a JSON array of records and build feature/label tensors from it.
 * @param {string} filePath - path to the JSON file
 * @param {string[]} featureKeys - keys selected (in order) as features
 * @param {string} labelKey - key used as the label
 * @returns {Promise<{xs: tf.Tensor2D, ys: tf.Tensor1D}>}
 */
async function loadJSONTensors(filePath, featureKeys, labelKey) {
  const records = await loadJSON(filePath);
  const featureRows = [];
  const labelValues = [];
  for (const record of records) {
    featureRows.push(featureKeys.map(key => record[key]));
    labelValues.push(record[labelKey]);
  }
  return {
    xs: tf.tensor2d(featureRows),
    ys: tf.tensor1d(labelValues)
  };
}
生产部署 #
Docker 部署 #
dockerfile
# Build a Node 18 image that serves the TensorFlow.js API on port 3000.
FROM node:18
WORKDIR /app
# Copy manifests first so the npm install layer is cached until deps change.
COPY package*.json ./
RUN npm install
# Copy application code (and any bundled model files) after dependencies.
COPY . .
EXPOSE 3000
CMD ["node", "server.js"]
Docker Compose #
yaml
# Compose service for the TF.js API; extraction flattened the original
# indentation, which makes YAML invalid — structure restored here.
version: '3.8'
services:
  tfjs-api:
    build: .
    ports:
      - "3000:3000"
    environment:
      - NODE_ENV=production
      - PORT=3000
    volumes:
      # Keep models outside the image so they can be swapped without a rebuild.
      - ./models:/app/models
    deploy:
      resources:
        reservations:
          devices:
            # NOTE(review): the GPU reservation only helps if the image
            # installs @tensorflow/tfjs-node-gpu and the host has the
            # NVIDIA container runtime — confirm both.
            - driver: nvidia
              count: 1
              capabilities: [gpu]
PM2 部署 #
javascript
// PM2 ecosystem config: run server.js in cluster mode across all CPU cores.
module.exports = {
  apps: [{
    name: 'tfjs-api',
    script: 'server.js',
    // 'max' starts one process per core; each instance loads its own copy
    // of the model, so memory use scales with the core count.
    instances: 'max',
    exec_mode: 'cluster',
    env_production: {
      NODE_ENV: 'production',
      PORT: 3000
    }
  }]
};
健康检查 #
javascript
// Liveness endpoint: reports backend, tensor memory stats, and model status.
app.get('/health', async (req, res) => {
  const { numTensors, numBytes } = tf.memory();
  res.json({
    status: 'healthy',
    backend: tf.getBackend(),
    memory: {
      tensors: numTensors,
      bytes: numBytes
    },
    modelLoaded: model !== null
  });
});
完整示例 #
javascript
const express = require('express');
const multer = require('multer');
const tf = require('@tensorflow/tfjs-node');
const fs = require('fs');
const path = require('path');
/** Serves predictions from a layers model loaded off the local filesystem. */
class TensorFlowService {
  /** @param {string} modelPath - path to model.json on disk */
  constructor(modelPath) {
    this.modelPath = modelPath;
    this.model = null;
  }
  /** Load the model and warm it up so the first real request is fast. */
  async init() {
    console.log('加载模型...');
    this.model = await tf.loadLayersModel(`file://${this.modelPath}`);
    await this.warmup();
    console.log('模型加载完成');
  }
  /** Run one dummy prediction to trigger kernel setup. */
  async warmup() {
    // NOTE(review): assumes the model exposes `inputShape` with concrete
    // (non-null) dims after the batch dim — confirm for this model format.
    const inputShape = this.model.inputShape.slice(1);
    const dummy = tf.zeros([1, ...inputShape]);
    try {
      await this.model.predict(dummy).data();
    } finally {
      dummy.dispose();
    }
  }
  /**
   * Classify one image file.
   * @param {string} imagePath - path to a JPEG/PNG image
   * @returns {Promise<number[]>} raw model output values
   */
  async predictFromImage(imagePath) {
    const buffer = fs.readFileSync(imagePath);
    const tensor = tf.node.decodeImage(buffer);
    let processed = null;
    let prediction = null;
    try {
      processed = tf.tidy(() => {
        return tf.image.resizeBilinear(tensor, [224, 224])
          .toFloat()
          .div(255)
          .expandDims(0);
      });
      prediction = this.model.predict(processed);
      const result = await prediction.data();
      return Array.from(result);
    } finally {
      // Dispose on success AND failure — the original leaked all three
      // tensors whenever resize/predict threw.
      tensor.dispose();
      if (processed) processed.dispose();
      if (prediction) prediction.dispose();
    }
  }
  /** Snapshot of tfjs tensor/byte usage for the /health endpoint. */
  getMemoryInfo() {
    return tf.memory();
  }
}
const app = express();
const upload = multer({ dest: 'uploads/' });
const tfService = new TensorFlowService('./model/model.json');
// POST /predict — single-image prediction backed by TensorFlowService.
app.post('/predict', upload.single('image'), async (req, res) => {
  try {
    const result = await tfService.predictFromImage(req.file.path);
    res.json({ success: true, prediction: result });
  } catch (error) {
    res.status(500).json({ error: error.message });
  } finally {
    // Remove the multer temp file even when prediction fails; the original
    // only deleted it on success, so errors accumulated files in uploads/.
    if (req.file) fs.unlinkSync(req.file.path);
  }
});
// GET /health — basic liveness plus tfjs memory stats.
app.get('/health', (req, res) => {
  const mem = tfService.getMemoryInfo();
  res.json({
    status: 'healthy',
    memory: mem
  });
});
const PORT = process.env.PORT || 3000;
tfService.init().then(() => {
  app.listen(PORT, () => {
    console.log(`服务运行在端口 ${PORT}`);
  });
}).catch(err => {
  // Fail fast instead of silently running a server with no model loaded.
  console.error('模型初始化失败:', err);
  process.exit(1);
});
下一步 #
现在你已经掌握了 Node.js 部署,接下来学习 高级应用,了解更多最佳实践!
最后更新:2026-03-29