数组操作 #
创建数组 #
基本创建 #
python
import jax.numpy as jnp
# A 1-D vector and a 2x3 matrix built from (nested) Python lists.
a = jnp.array([1, 2, 3, 4, 5])
b = jnp.array(
    [
        [1, 2, 3],
        [4, 5, 6],
    ]
)

# Constant-filled arrays; the shape is always given as a tuple.
zeros = jnp.zeros(shape=(3, 4))
ones = jnp.ones(shape=(2, 3))
# Only the shape of `empty` should be relied on, not its contents.
empty = jnp.empty(shape=(2, 2))
# 3x3 array in which every element is 7.
full = jnp.full(shape=(3, 3), fill_value=7)
序列创建 #
python
import jax.numpy as jnp
# Integers from 0 up to (but excluding) 10 in steps of 2: 0, 2, 4, 6, 8.
arange = jnp.arange(0, 10, step=2)
# Five evenly spaced samples covering [0, 1], both endpoints included.
linspace = jnp.linspace(0, 1, num=5)
# Five samples evenly spaced on a base-10 log scale, from 10**0 to 10**2.
logspace = jnp.logspace(0, 2, num=5)
特殊矩阵 #
python
import jax.numpy as jnp
# 3x3 identity matrix.
eye = jnp.eye(3)
# Unlike NumPy, jnp.diag does not accept a plain Python list (it raises
# TypeError), so the diagonal values are wrapped in a JAX array first.
diag = jnp.diag(jnp.array([1, 2, 3]))
# 3x3 matrix with ones on and below the main diagonal, zeros above it.
tri = jnp.tri(3)
随机数组 #
python
import jax
import jax.numpy as jnp
key = jax.random.PRNGKey(42)
# JAX random functions are pure: the same key always produces the same bits.
# Reusing one key for several draws therefore yields correlated samples, so
# the root key is split into one independent subkey per draw.
normal_key, uniform_key, int_key = jax.random.split(key, 3)
# Standard-normal samples, shape (3, 3).
rand_normal = jax.random.normal(normal_key, (3, 3))
# Uniform samples in [0, 1), shape (3, 3).
rand_uniform = jax.random.uniform(uniform_key, (3, 3))
# Integer samples in [0, 10), shape (3, 3).
rand_int = jax.random.randint(int_key, (3, 3), 0, 10)
数组属性 #
python
import jax.numpy as jnp
# Inspect the basic metadata of a 2x3 array.
x = jnp.array([[1, 2, 3], [4, 5, 6]])
print(f"形状: {x.shape}")  # shape: (2, 3)
print(f"维度: {x.ndim}")  # number of dimensions: 2
print(f"大小: {x.size}")  # total number of elements: 6
print(f"数据类型: {x.dtype}")  # element dtype
print(f"设备: {x.devices()}")  # device(s) holding the array; depends on the runtime
索引与切片 #
基本索引 #
python
import jax.numpy as jnp
x = jnp.array(
    [
        [1, 2, 3],
        [4, 5, 6],
        [7, 8, 9],
    ]
)
# Single element at row 0, column 0.
print(x[0, 0])
# Entire second row.
print(x[1, :])
# Entire third column.
print(x[:, 2])
# Sub-matrix: first two rows, columns 1 and 2.
print(x[:2, 1:3])
使用 .at 进行更新 #
python
import jax.numpy as jnp
# JAX arrays are immutable: each `.at[...]` call returns a new array, which is
# rebound to `x` so the steps below build on one another.
x = jnp.array([1, 2, 3, 4, 5])

# Overwrite a single element.
x = x.at[0].set(10)
print(x)

# Overwrite a slice.
x = x.at[1:3].set(0)
print(x)

# Add 1 to every second element.
x = x.at[::2].add(1)
print(x)

# Zero out every element greater than 3, selected through a boolean mask.
above_three = x > 3
x = x.at[above_three].set(0)
print(x)
高级索引 #
python
import jax.numpy as jnp
x = jnp.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])

# Index helpers used by the three lookups below.
indices = jnp.array([0, 2])
mask = x > 5
rows = jnp.array([0, 1, 2])
cols = jnp.array([2, 1, 0])

# Integer-array indexing: pick rows 0 and 2.
print(x[indices])
# Boolean-mask indexing: flat array of the elements greater than 5.
print(x[mask])
# Paired indexing: elements at (0, 2), (1, 1) and (2, 0).
print(x[rows, cols])
形状操作 #
reshape #
python
import jax.numpy as jnp
# Twelve elements, reshaped three different ways.
x = jnp.arange(12)

y = x.reshape((3, 4))
print(y.shape)

z = x.reshape((2, 2, 3))
print(z.shape)

# -1 lets the remaining dimension be inferred from the element count.
w = x.reshape((3, -1))
print(w.shape)
transpose #
python
import jax.numpy as jnp
x = jnp.arange(6).reshape((2, 3))

# `.T` reverses the axis order.
y = x.T
print(y.shape)

# Same result with an explicit axis permutation.
z = jnp.transpose(x, axes=(1, 0))
print(z.shape)
squeeze 和 expand_dims #
python
import jax.numpy as jnp
x = jnp.array([[[1, 2], [3, 4]]])
print(f"原始形状: {x.shape}")

# Drop every axis of length 1: (1, 2, 2) -> (2, 2).
y = jnp.squeeze(x)
print(f"squeeze 后: {y.shape}")

# Insert a new length-1 axis at the front: (2, 2) -> (1, 2, 2).
z = jnp.expand_dims(y, axis=0)
print(f"expand_dims 后: {z.shape}")
flatten 和 ravel #
python
import jax.numpy as jnp
x = jnp.arange(6).reshape((2, 3))

# Both calls collapse the 2x3 array into a 1-D array of 6 elements.
y = x.flatten()
print(y.shape)

z = x.ravel()
print(z.shape)
拼接与分割 #
拼接 #
python
import jax.numpy as jnp
a = jnp.array([[1, 2], [3, 4]])
b = jnp.array([[5, 6]])

# Join along the existing row axis: (2, 2) + (1, 2) -> (3, 2).
c = jnp.concatenate((a, b), axis=0)
print(f"concatenate:\n{c}")

# Stack along a new leading axis: two (2, 2) arrays -> (2, 2, 2).
d = jnp.stack((a, a), axis=0)
print(f"stack 形状: {d.shape}")

# Vertical stacking appends `b` as an extra row.
e = jnp.vstack((a, b))
print(f"vstack:\n{e}")

# Horizontal stacking appends a (2, 1) column on the right.
f = jnp.hstack((a, jnp.array([[5], [6]])))
print(f"hstack:\n{f}")
分割 #
python
import jax.numpy as jnp
x = jnp.arange(12).reshape(3, 4)
# `split` needs an exact division: 4 columns -> 2 parts of 2 columns each.
parts = jnp.split(x, 2, axis=1)
print(f"分割数量: {len(parts)}")
# `array_split` tolerates a remainder: 4 columns into 3 parts gives column
# widths 2, 1 and 1, which is the uneven case this example is meant to show.
parts = jnp.array_split(x, 3, axis=1)
print(f"不均等分割数量: {len(parts)}")
广播 #
自动广播 #
python
import jax.numpy as jnp
a = jnp.array([[1, 2, 3], [4, 5, 6]])
b = jnp.array([10, 20, 30])
# (2, 3) + (3,): `b` is added to every row of `a`.
c = a + b
print(c)
d = jnp.array([[1], [2]])
# (2, 3) + (2, 1): the single column of `d` is added to every column of `a`.
e = a + d
print(e)
broadcast_to #
python
import jax.numpy as jnp
x = jnp.array([1, 2, 3])
# Repeat the length-3 vector along a new leading axis to get a (2, 3) array.
y = jnp.broadcast_to(x, shape=(2, 3))
print(y)
broadcast_arrays #
python
import jax.numpy as jnp
a = jnp.array([1, 2, 3])
b = jnp.array([[10], [20]])
# Broadcast (3,) and (2, 1) against each other; both results have shape (2, 3).
a_broadcast, b_broadcast = jnp.broadcast_arrays(a, b)
print(f"a 广播后:\n{a_broadcast}")
print(f"b 广播后:\n{b_broadcast}")
数学运算 #
基本运算 #
python
import jax.numpy as jnp
# All operators below act element-wise on the two length-3 vectors.
a = jnp.array([1, 2, 3])
b = jnp.array([4, 5, 6])
print(a + b)  # addition
print(a - b)  # subtraction
print(a * b)  # multiplication
print(a / b)  # true division (floating-point result)
print(a ** b)  # exponentiation
矩阵运算 #
python
import jax.numpy as jnp
a = jnp.array([[1, 2], [3, 4]])
b = jnp.array([[5, 6], [7, 8]])
# For these 2-D inputs the next three lines all compute the matrix product.
print(jnp.dot(a, b))
print(jnp.matmul(a, b))
print(a @ b)
# Outer product of the two flattened length-4 vectors: a 4x4 result.
print(jnp.outer(a.flatten(), b.flatten()))
统计运算 #
python
import jax.numpy as jnp
x = jnp.array([[1, 2, 3], [4, 5, 6]])
print(jnp.sum(x))  # sum over all elements
print(jnp.sum(x, axis=0))  # column sums (reduce over rows)
print(jnp.sum(x, axis=1))  # row sums (reduce over columns)
print(jnp.mean(x))  # mean of all elements
print(jnp.std(x))  # standard deviation of all elements
print(jnp.var(x))  # variance of all elements
print(jnp.max(x))  # largest element
print(jnp.min(x))  # smallest element
条件操作 #
where #
python
import jax.numpy as jnp
x = jnp.array([1, -2, 3, -4, 5])
# Shared condition for both selections below.
is_positive = x > 0

# Keep positive entries, replace everything else with 0.
y = jnp.where(is_positive, x, 0)
print(y)

# Keep positive entries, negate the rest (element-wise absolute value here).
z = jnp.where(is_positive, x, -x)
print(z)
select #
python
import jax.numpy as jnp
x = jnp.array([1, 2, 3, 4, 5])
# Each condition is paired with the choice at the same position.
conditions = [x < 2, x > 4]
choices = [x * 10, x * 100]
# Elements matching no condition fall back to `default` (the original value).
y = jnp.select(conditions, choices, default=x)
print(y)
clip #
python
import jax.numpy as jnp
x = jnp.array([1, 5, 10, 15, 20])
# Limit every element to the closed range [5, 15].
y = jnp.clip(x, 5, 15)
print(y)
排序与搜索 #
排序 #
python
import jax.numpy as jnp
x = jnp.array([3, 1, 4, 1, 5, 9, 2, 6])
# Sorted copy in ascending order; `x` itself is left unchanged.
sorted_x = jnp.sort(x)
print(sorted_x)
# Indices that would sort `x` in ascending order.
indices = jnp.argsort(x)
print(indices)
搜索 #
python
import jax.numpy as jnp
x = jnp.array([1, 2, 3, 4, 5])
print(jnp.argmax(x))  # index of the largest element
print(jnp.argmin(x))  # index of the smallest element
# With only a condition, `where` returns a tuple of index arrays (one per axis).
print(jnp.where(x > 3))
类型转换 #
python
import jax
import jax.numpy as jnp

# JAX keeps 64-bit types disabled by default: without this flag the
# `astype(jnp.float64)` call below is silently truncated to float32.
# The flag should be set at startup, before any arrays are created.
jax.config.update("jax_enable_x64", True)

# Pin the source dtype so the float32 -> float64 conversion stays visible
# (with x64 enabled, float literals would otherwise default to float64).
x = jnp.array([1.5, 2.7, 3.9], dtype=jnp.float32)
# Float -> int conversion drops the fractional part.
y = x.astype(jnp.int32)
print(y)
z = x.astype(jnp.float64)
print(z.dtype)
下一步 #
现在你已经掌握了数组操作，接下来学习线性代数，了解 JAX 的线性代数功能！
最后更新:2026-04-04