TensorFlow.js 安装与配置 #

安装方式概览 #

TensorFlow.js 提供多种安装方式,适应不同的开发场景:

text
┌─────────────────────────────────────────────────────────────┐
│                   TensorFlow.js 安装方式                      │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  ┌─────────────┐  ┌─────────────┐  ┌─────────────┐         │
│  │   CDN 引入   │  │   NPM 安装   │  │  Node.js   │         │
│  │   (最简单)   │  │  (推荐)      │  │  (服务端)   │         │
│  └─────────────┘  └─────────────┘  └─────────────┘         │
│                                                             │
└─────────────────────────────────────────────────────────────┘

方式一:CDN 引入(浏览器) #

最简单的方式,适合快速原型开发和学习。

完整版 #

html
<!DOCTYPE html>
<html>
<head>
  <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@latest"></script>
</head>
<body>
  <script>
    const tensor = tf.tensor([1, 2, 3, 4]);
    tensor.print();
  </script>
</body>
</html>

指定版本 #

html
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@4.17.0"></script>

使用 UNPKG CDN #

html
<script src="https://unpkg.com/@tensorflow/tfjs@latest"></script>

优缺点 #

优点 缺点
无需构建工具 无法 Tree-shaking
快速开始 加载完整包,体积大
适合学习 不适合生产环境

方式二:NPM 安装(推荐) #

适合正式项目开发,支持模块化和 Tree-shaking。

安装核心包 #

bash
npm install @tensorflow/tfjs

安装 Node.js 版本 #

bash
npm install @tensorflow/tfjs-node

安装带 CUDA 支持的版本 #

bash
npm install @tensorflow/tfjs-node-gpu

在项目中使用 #

javascript
import * as tf from '@tensorflow/tfjs';

const model = tf.sequential();
model.add(tf.layers.dense({ units: 1, inputShape: [1] }));

模块化导入 #

javascript
import { tensor, sequential, layers } from '@tensorflow/tfjs';

const t = tensor([1, 2, 3]);
const model = sequential();

方式三:Node.js 环境 #

在服务端运行 TensorFlow.js,适合需要高性能计算的场景。

安装 #

bash
npm install @tensorflow/tfjs-node

使用 #

javascript
const tf = require('@tensorflow/tfjs-node');

const tensor = tf.tensor2d([[1, 2], [3, 4]]);
tensor.print();

GPU 加速版本 #

bash
npm install @tensorflow/tfjs-node-gpu
javascript
const tf = require('@tensorflow/tfjs-node-gpu');

版本对比 #

版本 加速方式 适用场景
tfjs-node CPU 通用服务端
tfjs-node-gpu CUDA GPU 高性能计算

安装预训练模型 #

TensorFlow.js 提供丰富的预训练模型:

图像分类 #

bash
npm install @tensorflow-models/mobilenet

目标检测 #

bash
npm install @tensorflow-models/coco-ssd

姿态估计 #

bash
npm install @tensorflow-models/pose-detection

文本处理 #

bash
npm install @tensorflow-models/universal-sentence-encoder
npm install @tensorflow-models/toxicity

语音识别 #

bash
npm install @tensorflow-models/speech-commands

开发环境配置 #

使用 Vite #

bash
npm create vite@latest my-tfjs-app
cd my-tfjs-app
npm install @tensorflow/tfjs
npm run dev
javascript
import * as tf from '@tensorflow/tfjs';

const tensor = tf.zeros([2, 2]);
tensor.print();

使用 Webpack #

bash
npm install webpack webpack-cli @tensorflow/tfjs
javascript
const tf = require('@tensorflow/tfjs');

const tensor = tf.tensor([1, 2, 3]);
tensor.print();

使用 TypeScript #

bash
npm install @tensorflow/tfjs typescript
typescript
import * as tf from '@tensorflow/tfjs';

const tensor: tf.Tensor = tf.tensor([1, 2, 3]);
tensor.print();

后端配置 #

TensorFlow.js 支持多种后端:

查看当前后端 #

javascript
console.log(tf.getBackend());

设置后端 #

javascript
await tf.setBackend('webgl');
await tf.setBackend('wasm');
await tf.setBackend('cpu');

后端对比 #

text
┌─────────────────────────────────────────────────────────────┐
│                    后端性能对比                               │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  性能排序:WebGL > WASM > CPU                               │
│                                                             │
│  ┌─────────────┐  ┌─────────────┐  ┌─────────────┐         │
│  │   WebGL     │  │    WASM     │  │    CPU     │         │
│  │   最快      │  │    中等     │  │   最慢     │         │
│  │   GPU 加速  │  │  SIMD 优化  │  │   回退     │         │
│  └─────────────┘  └─────────────┘  └─────────────┘         │
│                                                             │
└─────────────────────────────────────────────────────────────┘

WebGL 后端配置 #

javascript
await tf.setBackend('webgl');
await tf.ready();

console.log('WebGL backend ready');
console.log('GPU detected:', tf.ENV.get('WEBGL_VERSION'));

WASM 后端配置 #

bash
npm install @tensorflow/tfjs-backend-wasm
javascript
import '@tensorflow/tfjs-backend-wasm';
import * as tf from '@tensorflow/tfjs';

await tf.setBackend('wasm');
await tf.ready();

CPU 后端配置 #

javascript
await tf.setBackend('cpu');
await tf.ready();

内存管理 #

TensorFlow.js 使用 WebGL 纹理存储张量,需要手动管理内存:

手动释放 #

javascript
const tensor = tf.tensor([1, 2, 3]);
tensor.dispose();

使用 tf.tidy 自动清理 #

javascript
const result = tf.tidy(() => {
  const a = tf.tensor([1, 2, 3]);
  const b = tf.tensor([4, 5, 6]);
  return a.add(b);
});

内存监控 #

javascript
console.log('Memory info:', tf.memory());

模型保存与加载 #

保存模型 #

javascript
const model = tf.sequential();
model.add(tf.layers.dense({ units: 1, inputShape: [1] }));

await model.save('localstorage://my-model');
await model.save('indexeddb://my-model');
await model.save('downloads://my-model');

加载模型 #

javascript
const model = await tf.loadLayersModel('localstorage://my-model');
const model = await tf.loadLayersModel('indexeddb://my-model');
const model = await tf.loadLayersModel('https://example.com/model.json');

模型转换 #

将 Python TensorFlow 模型转换为 TensorFlow.js 格式:

安装转换工具 #

bash
pip install tensorflowjs

转换 Keras 模型 #

bash
tensorflowjs_converter --input_format keras model.h5 ./tfjs_model

转换 SavedModel #

bash
tensorflowjs_converter --input_format tf_saved_model ./saved_model ./tfjs_model

转换选项 #

bash
tensorflowjs_converter \
  --input_format keras \
  --output_format tfjs_graph_model \
  --weight_shard_size_bytes 4000000 \
  model.h5 ./tfjs_model

开发工具 #

TensorFlow.js 可视化工具 #

bash
npm install @tensorflow/tfjs-vis
javascript
import * as tfvis from '@tensorflow/tfjs-vis';

tfvis.visor().open();

tfvis.show.modelSummary({ name: 'Model Summary' }, model);
tfvis.show.layer({ name: 'Layer' }, model.layers[0]);

训练监控 #

javascript
const history = await model.fit(xs, ys, {
  epochs: 50,
  callbacks: tfvis.show.fitCallbacks(
    { name: 'Training Performance' },
    ['loss', 'val_loss']
  )
});

常见问题 #

1. WebGL 不可用 #

javascript
if (!tf.ENV.get('WEBGL_VERSION')) {
  console.warn('WebGL not available, falling back to CPU');
  await tf.setBackend('cpu');
}

2. 内存不足 #

javascript
tf.disposeVariables();
tf.tidy(() => {
  
});

3. 模型加载失败 #

javascript
try {
  const model = await tf.loadLayersModel('model.json');
} catch (error) {
  console.error('Failed to load model:', error);
}

项目结构建议 #

text
my-tfjs-project/
├── src/
│   ├── models/
│   │   └── model.js
│   ├── data/
│   │   └── data.js
│   ├── utils/
│   │   └── utils.js
│   └── index.js
├── public/
│   └── models/
│       └── my-model/
│           ├── model.json
│           └── weights.bin
├── package.json
└── vite.config.js

下一步 #

现在你已经完成了环境配置,接下来学习 张量基础,了解 TensorFlow.js 的核心数据结构!

最后更新:2026-03-29