数组操作 #

创建数组 #

基本创建 #

python
import jax.numpy as jnp

# From Python sequences: a 1-D vector and a 2x3 matrix.
a = jnp.array([1, 2, 3, 4, 5])
b = jnp.array([[1, 2, 3], [4, 5, 6]])

# Filled arrays; the first argument is always the shape tuple.
zeros = jnp.zeros((3, 4))    # 3x4, every element 0.0
ones = jnp.ones((2, 3))      # 2x3, every element 1.0
empty = jnp.empty((2, 2))    # JAX has no uninitialized memory, so this is zero-filled
full = jnp.full((3, 3), 7)   # 3x3, every element 7

序列创建 #

python
import jax.numpy as jnp

# Evenly spaced values over an interval.
arange = jnp.arange(0, 10, 2)      # step 2, stop excluded -> [0 2 4 6 8]
linspace = jnp.linspace(0, 1, 5)   # 5 points, both ends included -> [0. 0.25 0.5 0.75 1.]
logspace = jnp.logspace(0, 2, 5)   # 10 ** linspace(0, 2, 5) -> [1. ~3.16 10. ~31.6 100.]

特殊矩阵 #

python
import jax.numpy as jnp

eye = jnp.eye(3)                          # 3x3 identity matrix
# jax.numpy functions reject plain Python lists (TypeError), so the diagonal
# values must be wrapped in an array first.
diag = jnp.diag(jnp.array([1, 2, 3]))     # 3x3 with 1, 2, 3 on the main diagonal
tri = jnp.tri(3)                          # 3x3 with ones on and below the diagonal

随机数组 #

python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(42)

# A PRNG key must never be reused: feeding the same key to several sampling
# functions produces correlated draws. Split it into independent subkeys and
# keep the first output as the new key for any later sampling.
key, normal_key, uniform_key, int_key = jax.random.split(key, 4)

rand_normal = jax.random.normal(normal_key, (3, 3))          # standard normal samples
rand_uniform = jax.random.uniform(uniform_key, (3, 3))       # floats in [0, 1)
rand_int = jax.random.randint(int_key, (3, 3), 0, 10)        # integers in [0, 10)

数组属性 #

python
import jax.numpy as jnp

# A 2x3 integer matrix used to inspect array metadata.
x = jnp.array([[1, 2, 3], [4, 5, 6]])

print(f"形状: {x.shape}")      # (2, 3)
print(f"维度: {x.ndim}")       # 2 axes
print(f"大小: {x.size}")       # 6 elements in total
print(f"数据类型: {x.dtype}")  # int32 by default (int64 when jax_enable_x64 is on)
print(f"设备: {x.devices()}")  # devices holding the data, e.g. {CpuDevice(id=0)}

索引与切片 #

基本索引 #

python
import jax.numpy as jnp

# A 3x3 matrix holding 1..9 in row-major order.
x = jnp.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])

print(x[0, 0])      # single element at row 0, column 0 -> 1
print(x[1, :])      # the whole of row 1 -> [4 5 6]
print(x[:, 2])      # the whole of column 2 -> [3 6 9]
print(x[0:2, 1:3])  # rows 0-1, columns 1-2 -> [[2 3] [5 6]]

使用 .at 进行更新 #

python
import jax.numpy as jnp

# JAX arrays are immutable (x[i] = v is not allowed). Each .at[...] call
# returns a NEW array with the update applied, so x is rebound every time.
x = jnp.array([1, 2, 3, 4, 5])

x = x.at[0].set(10)      # overwrite index 0
print(x)                 # [10  2  3  4  5]

x = x.at[1:3].set(0)     # overwrite a slice
print(x)                 # [10  0  0  4  5]

x = x.at[::2].add(1)     # add 1 at indices 0, 2 and 4
print(x)                 # [11  0  1  4  6]

# Boolean-mask indices must be concrete, so this form fails inside jax.jit;
# jnp.where(x > 3, 0, x) is the jit-friendly equivalent.
x = x.at[x > 3].set(0)
print(x)                 # [0 0 1 0 0]

高级索引 #

python
import jax.numpy as jnp

x = jnp.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])

# An integer index array picks whole rows.
indices = jnp.array([0, 2])
print(x[indices])  # rows 0 and 2 -> [[1 2 3] [7 8 9]]

# A boolean mask returns a flat array of the selected elements. Its length
# depends on the data, so this only works outside jax.jit.
mask = x > 5
print(x[mask])  # [6 7 8 9]

# Paired index arrays select the elements (0, 2), (1, 1) and (2, 0).
rows = jnp.array([0, 1, 2])
cols = jnp.array([2, 1, 0])
print(x[rows, cols])  # [3 5 7]

形状操作 #

reshape #

python
import jax.numpy as jnp

x = jnp.arange(12)  # 12 elements: 0 .. 11

y = x.reshape(3, 4)  # 3 rows of 4
print(y.shape)  # (3, 4)

z = x.reshape(2, 2, 3)  # any shape whose sizes multiply to 12
print(z.shape)  # (2, 2, 3)

w = x.reshape(3, -1)  # -1 lets JAX infer this axis: 12 / 3 = 4
print(w.shape)  # (3, 4)

transpose #

python
import jax.numpy as jnp

x = jnp.arange(6).reshape(2, 3)  # shape (2, 3)

y = x.T  # .T reverses the order of the axes
print(y.shape)  # (3, 2)

z = jnp.transpose(x, (1, 0))  # explicit axis permutation; same as x.T for 2-D
print(z.shape)  # (3, 2)

squeeze 和 expand_dims #

python
import jax.numpy as jnp

x = jnp.array([[[1, 2], [3, 4]]])  # note the leading axis of length 1
print(f"原始形状: {x.shape}")          # (1, 2, 2)

y = jnp.squeeze(x)  # removes every length-1 axis
print(f"squeeze 后: {y.shape}")        # (2, 2)

z = jnp.expand_dims(y, 0)  # inserts a new length-1 axis at position 0
print(f"expand_dims 后: {z.shape}")    # (1, 2, 2)

flatten 和 ravel #

python
import jax.numpy as jnp

x = jnp.arange(6).reshape(2, 3)

# Both collapse the array to 1-D in row-major order.
y = x.flatten()
print(y.shape)  # (6,)

z = x.ravel()
print(z.shape)  # (6,)

拼接与分割 #

拼接 #

python
import jax.numpy as jnp

a = jnp.array([[1, 2], [3, 4]])  # shape (2, 2)
b = jnp.array([[5, 6]])          # shape (1, 2)

# concatenate joins along an EXISTING axis -> shape (3, 2)
c = jnp.concatenate([a, b], axis=0)
print(f"concatenate:\n{c}")

# stack creates a NEW axis -> shape (2, 2, 2)
d = jnp.stack([a, a], axis=0)
print(f"stack 形状: {d.shape}")

# vstack stacks rows; for 2-D inputs it matches concatenate(axis=0)
e = jnp.vstack([a, b])
print(f"vstack:\n{e}")

# hstack stacks columns; the extra column needs the same row count -> [[1 2 5] [3 4 6]]
f = jnp.hstack([a, jnp.array([[5], [6]])])
print(f"hstack:\n{f}")

分割 #

python
import jax.numpy as jnp

x = jnp.arange(12).reshape(3, 4)

# split requires the axis length to divide evenly:
# 4 columns -> 2 parts of shape (3, 2).
parts = jnp.split(x, 2, axis=1)
print(f"分割数量: {len(parts)}")

# array_split also accepts uneven divisions: 4 columns into 3 parts gives
# widths 2, 1, 1. (Splitting the 3 rows into 3 parts would be even and would
# not show how array_split differs from split.)
parts = jnp.array_split(x, 3, axis=1)
print(f"不均等分割数量: {len(parts)}")
print(f"各部分形状: {[p.shape for p in parts]}")

广播 #

自动广播 #

python
import jax.numpy as jnp

a = jnp.array([[1, 2, 3], [4, 5, 6]])  # shape (2, 3)
b = jnp.array([10, 20, 30])            # shape (3,): repeated for every row

c = a + b  # (2, 3) + (3,) -> (2, 3)
print(c)   # [[11 22 33] [14 25 36]]

d = jnp.array([[1], [2]])  # shape (2, 1): repeated for every column
e = a + d  # (2, 3) + (2, 1) -> (2, 3)
print(e)   # [[2 3 4] [6 7 8]]

broadcast_to #

python
import jax.numpy as jnp

x = jnp.array([1, 2, 3])  # shape (3,)

# Expand x explicitly to the target shape, repeating it along the new axis.
y = jnp.broadcast_to(x, (2, 3))
print(y)  # [[1 2 3] [1 2 3]]

broadcast_arrays #

python
import jax.numpy as jnp

a = jnp.array([1, 2, 3])      # shape (3,)
b = jnp.array([[10], [20]])   # shape (2, 1)

# Broadcast the inputs against each other; both results have shape (2, 3).
a_broadcast, b_broadcast = jnp.broadcast_arrays(a, b)
print(f"a 广播后:\n{a_broadcast}")   # [[1 2 3] [1 2 3]]
print(f"b 广播后:\n{b_broadcast}")   # [[10 10 10] [20 20 20]]

数学运算 #

基本运算 #

python
import jax.numpy as jnp

a = jnp.array([1, 2, 3])
b = jnp.array([4, 5, 6])

# Arithmetic operators act element by element.
print(a + b)   # [5 7 9]
print(a - b)   # [-3 -3 -3]
print(a * b)   # [ 4 10 18]
print(a / b)   # true division yields floats: [0.25 0.4 0.5]
print(a ** b)  # [  1  32 729]

矩阵运算 #

python
import jax.numpy as jnp

a = jnp.array([[1, 2], [3, 4]])
b = jnp.array([[5, 6], [7, 8]])

# Three spellings of the same 2-D matrix product -> [[19 22] [43 50]]
print(jnp.dot(a, b))
print(jnp.matmul(a, b))
print(a @ b)

# Outer product of the flattened inputs [1 2 3 4] and [5 6 7 8] -> 4x4 matrix
print(jnp.outer(a.flatten(), b.flatten()))

统计运算 #

python
import jax.numpy as jnp

x = jnp.array([[1, 2, 3], [4, 5, 6]])

# Without axis the reduction covers every element; axis=0 sums down each
# column, axis=1 sums across each row.
print(jnp.sum(x))          # 21
print(jnp.sum(x, axis=0))  # [5 7 9]
print(jnp.sum(x, axis=1))  # [ 6 15]

print(jnp.mean(x))  # 3.5
print(jnp.std(x))   # ~1.7078
print(jnp.var(x))   # ~2.9167

print(jnp.max(x))  # 6
print(jnp.min(x))  # 1

条件操作 #

where #

python
import jax.numpy as jnp

x = jnp.array([1, -2, 3, -4, 5])

# jnp.where(cond, a, b) takes a where cond holds and b everywhere else.
y = jnp.where(x > 0, x, 0)   # negatives replaced by 0
print(y)                     # [1 0 3 0 5]

z = jnp.where(x > 0, x, -x)  # negatives flipped: the absolute value
print(z)                     # [1 2 3 4 5]

select #

python
import jax.numpy as jnp

x = jnp.array([1, 2, 3, 4, 5])

# For each element, the first condition that holds picks the matching
# choice; where no condition holds, the default is used.
conditions = [x < 2, x > 4]
choices = [x * 10, x * 100]

y = jnp.select(conditions, choices, default=x)
print(y)  # [ 10   2   3   4 500]

clip #

python
import jax.numpy as jnp

x = jnp.array([1, 5, 10, 15, 20])

# Limit every element to the closed interval [5, 15].
y = jnp.clip(x, 5, 15)
print(y)  # [ 5  5 10 15 15]

排序与搜索 #

排序 #

python
import jax.numpy as jnp

x = jnp.array([3, 1, 4, 1, 5, 9, 2, 6])

sorted_x = jnp.sort(x)  # new array in ascending order
print(sorted_x)  # [1 1 2 3 4 5 6 9]

indices = jnp.argsort(x)  # positions that sort x, so x[indices] equals sorted_x
print(indices)  # [1 3 6 0 2 4 7 5]

搜索 #

python
import jax.numpy as jnp

x = jnp.array([1, 2, 3, 4, 5])

print(jnp.argmax(x))  # index of the largest value -> 4
print(jnp.argmin(x))  # index of the smallest value -> 0

# With a single argument, where returns a tuple of index arrays (one per
# dimension). Its length depends on the data, so under jax.jit a static
# size= argument would be required.
print(jnp.where(x > 3))  # (Array([3, 4], dtype=int32),)

类型转换 #

python
import jax.numpy as jnp

x = jnp.array([1.5, 2.7, 3.9])

# Float -> int conversion truncates toward zero.
y = x.astype(jnp.int32)
print(y)  # [1 2 3]

# JAX disables 64-bit types by default: requesting float64 emits a
# UserWarning and silently falls back to float32. To really get float64,
# enable x64 at program start, e.g. jax.config.update("jax_enable_x64", True).
z = x.astype(jnp.float64)
print(z.dtype)  # float32 by default; float64 only when jax_enable_x64 is on

下一步 #

现在你已经掌握了数组操作，接下来学习线性代数，了解 JAX 的线性代数功能！

最后更新：2026-04-04