TensorFlow.js 张量基础 #
什么是张量? #
张量是 TensorFlow.js 中的核心数据结构,它是多维数组的泛化形式。张量可以表示标量、向量、矩阵以及更高维度的数据。
text
┌─────────────────────────────────────────────────────────────┐
│ 张量的维度 │
├─────────────────────────────────────────────────────────────┤
│ │
│ 标量 (0D) 向量 (1D) 矩阵 (2D) 3D 张量 │
│ │
│ 5 [1,2,3] [[1,2,3], [[[1,2], │
│ [4,5,6]] [3,4]], │
│ [[5,6], │
│ [7,8]]] │
│ │
│ shape: [] shape: [3] shape: [2,3] shape: [2,2,2] │
│ │
└─────────────────────────────────────────────────────────────┘
张量的属性 #
每个张量都有以下核心属性:
数据类型(dtype) #
javascript
const intTensor = tf.tensor([1, 2, 3], [3], 'int32'); // 不指定 dtype 时,数字数组默认为 float32
console.log(intTensor.dtype);
const floatTensor = tf.tensor([1.5, 2.5, 3.5]);
console.log(floatTensor.dtype);
const boolTensor = tf.tensor([true, false, true], [3], 'bool');
console.log(boolTensor.dtype);
形状(shape) #
javascript
const tensor = tf.tensor([[1, 2, 3], [4, 5, 6]]);
console.log(tensor.shape);
const tensor3d = tf.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]);
console.log(tensor3d.shape);
维度(rank) #
javascript
const scalar = tf.scalar(5);
console.log(scalar.rank);
const vector = tf.tensor1d([1, 2, 3]);
console.log(vector.rank);
const matrix = tf.tensor2d([[1, 2], [3, 4]]);
console.log(matrix.rank);
大小(size) #
javascript
const tensor = tf.tensor([[1, 2, 3], [4, 5, 6]]);
console.log(tensor.size);
创建张量 #
基本创建方法 #
tf.tensor() #
javascript
const t1 = tf.tensor([1, 2, 3, 4]);
t1.print();
const t2 = tf.tensor([1, 2, 3, 4], [2, 2]);
t2.print();
const t3 = tf.tensor([1, 2, 3, 4, 5, 6, 7, 8], [2, 2, 2]);
t3.print();
指定维度的创建方法 #
javascript
const scalar = tf.scalar(5);
scalar.print();
const vector = tf.tensor1d([1, 2, 3]);
vector.print();
const matrix = tf.tensor2d([[1, 2], [3, 4]]);
matrix.print();
const tensor3d = tf.tensor3d([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]);
tensor3d.print();
const tensor4d = tf.tensor4d([[[[1, 2], [3, 4]], [[5, 6], [7, 8]]]]);
tensor4d.print();
特殊值张量 #
零张量 #
javascript
const zeros = tf.zeros([3, 3]);
zeros.print();
const zerosLike = tf.zerosLike(tf.tensor([1, 2, 3]));
zerosLike.print();
全一张量 #
javascript
const ones = tf.ones([2, 3]);
ones.print();
const onesLike = tf.onesLike(tf.tensor([[1, 2], [3, 4]]));
onesLike.print();
填充张量 #
javascript
const filled = tf.fill([2, 3], 7);
filled.print();
单位矩阵 #
javascript
const eye = tf.eye(3);
eye.print();
随机张量 #
均匀分布随机 #
javascript
const random = tf.randomUniform([3, 3], 0, 1);
random.print();
正态分布随机 #
javascript
const normal = tf.randomNormal([3, 3], 0, 1);
normal.print();
截断正态分布 #
javascript
const truncated = tf.truncatedNormal([3, 3], 0, 1);
truncated.print();
序列张量 #
等差序列 #
javascript
const range = tf.range(0, 10, 2);
range.print();
const linspace = tf.linspace(0, 1, 5);
linspace.print();
从其他数据创建 #
从 DOM 元素 #
javascript
const img = document.getElementById('myImage');
const tensor = tf.browser.fromPixels(img);
tensor.print();
从像素数据 #
javascript
const canvas = document.getElementById('myCanvas');
const ctx = canvas.getContext('2d');
const imageData = ctx.getImageData(0, 0, 100, 100);
const tensor = tf.browser.fromPixels(imageData);
张量操作 #
形状变换 #
reshape #
javascript
const t = tf.tensor([1, 2, 3, 4, 5, 6]);
const reshaped = t.reshape([2, 3]);
reshaped.print();
flatten #
javascript
const t = tf.tensor([[1, 2, 3], [4, 5, 6]]);
const flattened = t.flatten();
flattened.print();
expandDims #
javascript
const t = tf.tensor([1, 2, 3]);
const expanded = t.expandDims(0);
console.log(expanded.shape);
squeeze #
javascript
const t = tf.tensor([[[1, 2, 3]]]);
const squeezed = t.squeeze();
console.log(squeezed.shape);
transpose #
javascript
const t = tf.tensor([[1, 2, 3], [4, 5, 6]]);
const transposed = t.transpose();
transposed.print();
切片与索引 #
slice #
javascript
const t = tf.tensor([1, 2, 3, 4, 5, 6]);
const sliced = t.slice([2], [3]);
sliced.print();
const t2d = tf.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]);
const sliced2d = t2d.slice([1, 1], [2, 2]);
sliced2d.print();
stridedSlice #
javascript
const t = tf.tensor([1, 2, 3, 4, 5, 6, 7, 8]);
const strided = t.stridedSlice([1], [7], [2]);
strided.print();
gather #
javascript
const t = tf.tensor([[1, 2], [3, 4], [5, 6]]);
const gathered = t.gather([0, 2]);
gathered.print();
拼接与分割 #
concat #
javascript
const a = tf.tensor([[1, 2]]);
const b = tf.tensor([[3, 4]]);
const concatenated = a.concat(b, 0);
concatenated.print();
split #
javascript
const t = tf.tensor([1, 2, 3, 4, 5, 6]);
const [a, b] = tf.split(t, 2);
a.print();
b.print();
stack #
javascript
const a = tf.tensor([1, 2, 3]);
const b = tf.tensor([4, 5, 6]);
const stacked = tf.stack([a, b]);
stacked.print();
unstack #
javascript
const t = tf.tensor([[1, 2, 3], [4, 5, 6]]);
const [a, b] = tf.unstack(t);
a.print();
b.print();
填充与复制 #
pad #
javascript
const t = tf.tensor([[1, 2], [3, 4]]);
const padded = t.pad([[1, 1], [1, 1]]);
padded.print();
tile #
javascript
const t = tf.tensor([1, 2]);
const tiled = t.tile([3]);
tiled.print();
repeat #
javascript
const t = tf.tensor([1, 2, 3]);
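// tf.Tensor 没有 repeat 方法;可用 expandDims + tile + flatten 实现逐元素重复,结果为 [1, 1, 2, 2, 3, 3]
const repeated = t.expandDims(1).tile([1, 2]).flatten();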
repeated.print();
张量数据获取 #
同步获取 #
javascript
const t = tf.tensor([1, 2, 3, 4]);
const data = t.dataSync();
console.log(data);
const array = t.arraySync();
console.log(array);
异步获取 #
javascript
const t = tf.tensor([1, 2, 3, 4]);
t.data().then(data => {
console.log(data);
});
t.array().then(array => {
console.log(array);
});
获取单个值 #
javascript
const t = tf.tensor([1, 2, 3, 4]);
const value = t.dataSync()[0];
console.log(value);
const scalar = tf.scalar(5);
const scalarValue = scalar.dataSync()[0];
console.log(scalarValue);
变量(Variable) #
变量是可变的张量,用于存储模型参数。
创建变量 #
javascript
const initial = tf.zeros([3, 3]);
const variable = tf.variable(initial);
variable.print();
更新变量 #
javascript
const variable = tf.variable(tf.zeros([3]));
variable.assign(tf.tensor([1, 2, 3]));
variable.print();
变量属性 #
javascript
const variable = tf.variable(tf.zeros([3, 3]));
console.log(variable.shape);
console.log(variable.dtype);
console.log(variable.name);
内存管理 #
使用 WebGL 后端时,TensorFlow.js 将张量数据存储在 GPU 纹理中,这部分内存不会被 JavaScript 垃圾回收器自动回收,因此需要手动管理内存。
手动释放 #
javascript
const t = tf.tensor([1, 2, 3]);
t.dispose();
tf.tidy 自动清理 #
javascript
const result = tf.tidy(() => {
const a = tf.tensor([1, 2, 3]);
const b = tf.tensor([4, 5, 6]);
const c = a.add(b);
return c;
});
result.print();
tf.dispose 批量释放 #
javascript
const a = tf.tensor([1, 2, 3]);
const b = tf.tensor([4, 5, 6]);
tf.dispose([a, b]);
内存监控 #
javascript
console.log(tf.memory());
内存管理最佳实践 #
text
┌─────────────────────────────────────────────────────────────┐
│ 内存管理最佳实践 │
├─────────────────────────────────────────────────────────────┤
│ │
│ 1. 使用 tf.tidy 包装中间计算 │
│ │
│ const result = tf.tidy(() => { │
│ const a = tf.tensor([1, 2, 3]); │
│ const b = a.square(); │
│ return b; │
│ }); │
│ │
│ 2. 及时释放不需要的张量 │
│ │
│ const t = expensiveOperation(); │
│ useTensor(t); │
│ t.dispose(); │
│ │
│ 3. 在循环中创建张量时用 tf.tidy 包裹,避免内存泄漏 │
│ │
│ for (let i = 0; i < 1000; i++) { │
│ tf.tidy(() => { │
│ const t = tf.tensor(i); │
│ }); │
│ } │
│ │
│ 4. 监控内存使用 │
│ │
│ setInterval(() => { │
│ console.log(tf.memory()); │
│ }, 1000); │
│ │
└─────────────────────────────────────────────────────────────┘
张量打印与可视化 #
print() #
javascript
const t = tf.tensor([[1, 2, 3], [4, 5, 6]]);
t.print();
t.print(true);
toString() #
javascript
const t = tf.tensor([1, 2, 3]);
console.log(t.toString());
浏览器渲染 #
javascript
const t = tf.randomUniform([100, 100], 0, 1); // toPixels 要求 float32 张量的取值在 [0, 1] 范围内
await tf.browser.toPixels(t, document.getElementById('canvas'));
张量类型转换 #
cast #
javascript
const t = tf.tensor([1.5, 2.7, 3.9]);
const intT = t.cast('int32');
intT.print();
asType #
javascript
const t = tf.tensor([1, 2, 3], [3], 'int32');
const floatT = t.asType('float32');
实用工具函数 #
获取张量信息 #
javascript
const t = tf.tensor([[1, 2, 3], [4, 5, 6]]);
console.log('Shape:', t.shape);
console.log('Rank:', t.rank);
console.log('Size:', t.size);
console.log('Dtype:', t.dtype);
判断是否为张量 #
javascript
const t = tf.tensor([1, 2, 3]);
console.log(t instanceof tf.Tensor);
console.log(tf.util.isTensor(t));
克隆张量 #
javascript
const original = tf.tensor([1, 2, 3]);
const cloned = original.clone();
cloned.print();
下一步 #
现在你已经掌握了张量的基础知识,接下来学习 张量运算,了解如何对张量进行数学运算!
最后更新:2026-03-29