线性代数 #

概述 #

JAX 通过 jax.numpy.linalg 模块提供了与 NumPy 兼容的线性代数功能,可以执行各种矩阵运算、矩阵分解和线性方程组求解操作;更底层的例程(如三角方程求解、LU 分解)位于 jax.scipy.linalg 模块中。

基本运算 #

矩阵乘法 #

python
import jax.numpy as jnp

# Two small integer matrices shared by the products below.
A = jnp.array([[1, 2], [3, 4]])
B = jnp.array([[5, 6], [7, 8]])

# Three equivalent spellings of the same 2-D matrix product.
C = jnp.dot(A, B)      # function form
D = jnp.matmul(A, B)   # NumPy-compatible matmul
E = A @ B              # Python's matrix-multiplication operator

for label, product in (("dot", C), ("matmul", D), ("@ 运算符", E)):
    print(f"{label}:\n{product}")

向量运算 #

python
import jax.numpy as jnp

a, b = jnp.array([1, 2, 3]), jnp.array([4, 5, 6])

dot_product = jnp.dot(a, b)        # scalar: sum of a_i * b_i
outer_product = jnp.outer(a, b)    # 3x3 matrix with entries a_i * b_j
cross_product = jnp.cross(a, b)    # 3-vector orthogonal to both a and b

print(f"点积: {dot_product}")
print(f"外积:\n{outer_product}")
print(f"叉积: {cross_product}")

矩阵转置 #

python
import jax.numpy as jnp

A = jnp.array([
    [1, 2, 3],
    [4, 5, 6],
])  # shape (2, 3)

# The .T property and jnp.transpose both swap the two axes, giving shape (3, 2).
AT, AT2 = A.T, jnp.transpose(A)

print(f"转置:\n{AT}")
print(f"transpose:\n{AT2}")

矩阵分解 #

特征值分解 #

python
import jax.numpy as jnp

# General (non-symmetric) matrix: use eig, whose results may be complex.
# Its characteristic polynomial is (l - 2)(l - 3), so the eigenvalues are 2 and 3.
A = jnp.array([[4, -2], [1, 1]], dtype=jnp.float32)

# NOTE: jnp.linalg.eig is not available on every backend (historically CPU only).
eigenvalues, eigenvectors = jnp.linalg.eig(A)
print(f"特征值: {eigenvalues}")
print(f"特征向量:\n{eigenvectors}")

# eigh assumes a symmetric (Hermitian) input. A above is NOT symmetric, so
# eigh(A) would silently return the eigenvalues of a different matrix instead
# of 2 and 3. Demonstrate eigh on a genuinely symmetric matrix (eigenvalues 1, 3).
A_sym = jnp.array([[2, 1], [1, 2]], dtype=jnp.float32)
eigenvalues_real, eigenvectors_real = jnp.linalg.eigh(A_sym)
print(f"实特征值: {eigenvalues_real}")

奇异值分解 (SVD) #

python
import jax.numpy as jnp

A = jnp.array([[1, 2, 3], [4, 5, 6]], dtype=jnp.float32)  # shape (2, 3)

# full_matrices=False ("thin" SVD): U is (2, 2), S is (2,), Vt is (2, 3), so
# U @ diag(S) @ Vt is shape-compatible. With the default full_matrices=True,
# Vt would be (3, 3) and the reconstruction below would fail with a shape error.
U, S, Vt = jnp.linalg.svd(A, full_matrices=False)
print(f"U 形状: {U.shape}")
print(f"S (奇异值): {S}")
print(f"Vt 形状: {Vt.shape}")

A_reconstructed = U @ jnp.diag(S) @ Vt
print(f"重构矩阵:\n{A_reconstructed}")

QR 分解 #

python
import jax.numpy as jnp

A = jnp.array([[1, 2], [3, 4], [5, 6]])  # tall 3x2 matrix

# Default (reduced) mode: Q has orthonormal columns, R is upper-triangular.
Q, R = jnp.linalg.qr(A)
A_reconstructed = Q @ R  # recovers A up to floating-point rounding

for caption, value in (
    ("Q (正交矩阵)", Q),
    ("R (上三角矩阵)", R),
    ("重构矩阵", A_reconstructed),
):
    print(f"{caption}:\n{value}")

Cholesky 分解 #

python
import jax.numpy as jnp

# Symmetric positive-definite input matrix.
A = jnp.array([
    [4, 2],
    [2, 3],
], dtype=jnp.float32)

L = jnp.linalg.cholesky(A)  # lower-triangular factor with A == L @ L^T
print(f"Cholesky 分解 L:\n{L}")

A_reconstructed = L @ jnp.transpose(L)
print(f"重构矩阵:\n{A_reconstructed}")

矩阵求逆 #

逆矩阵 #

python
import jax.numpy as jnp

A = jnp.array([[1, 2], [3, 4]], dtype=jnp.float32)  # det = -2, so A is invertible

A_inv = jnp.linalg.inv(A)
identity_check = A @ A_inv  # should be (numerically) the 2x2 identity

print(f"逆矩阵:\n{A_inv}")
print(f"A @ A_inv:\n{identity_check}")

伪逆 #

python
import jax.numpy as jnp

# Tall 3x2 matrix: it is not square, so it has no ordinary inverse.
A = jnp.array([
    [1, 2],
    [3, 4],
    [5, 6],
], dtype=jnp.float32)

# Moore-Penrose pseudo-inverse; its shape is the transpose of A's, i.e. (2, 3).
A_pinv = jnp.linalg.pinv(A)
print(f"伪逆形状: {A_pinv.shape}")

线性方程组 #

求解 Ax = b #

python
import jax.numpy as jnp

# System: 3x + y = 9 and x + 2y = 8.
A = jnp.array([[3, 1], [1, 2]], dtype=jnp.float32)
b = jnp.array([9, 8], dtype=jnp.float32)

x = jnp.linalg.solve(A, b)
b_check = jnp.matmul(A, x)  # substituting x back should reproduce b

print(f"解: {x}")
print(f"验证: {b_check}")

最小二乘解 #

python
import jax.numpy as jnp

# Over-determined system: three equations, two unknowns.
A = jnp.array([[1, 1], [1, 2], [1, 3]], dtype=jnp.float32)
b = jnp.array([1, 2, 2], dtype=jnp.float32)

# lstsq returns a 4-tuple: (solution, residuals, rank, singular values).
fit = jnp.linalg.lstsq(A, b)
x, residuals, rank, s = fit

print(f"最小二乘解: {x}")
print(f"残差: {residuals}")

行列式与迹 #

行列式 #

python
import jax.numpy as jnp

A = jnp.array([
    [1, 2],
    [3, 4],
], dtype=jnp.float32)

# For a 2x2 matrix this equals a*d - b*c = 1*4 - 2*3 = -2.
det = jnp.linalg.det(A)
print(f"行列式: {det}")

矩阵的迹 #

python
import jax.numpy as jnp

A = jnp.array([
    [1, 2],
    [3, 4],
])

# Sum of the main diagonal: 1 + 4.
trace = jnp.trace(A)
print(f"迹: {trace}")

范数 #

向量范数 #

python
import jax.numpy as jnp

v = jnp.array([3, 4])

# ord=None (default) is the Euclidean L2 norm, ord=1 sums |v_i|,
# and ord=inf takes the largest |v_i|.
norm_2, norm_1, norm_inf = (
    jnp.linalg.norm(v),
    jnp.linalg.norm(v, ord=1),
    jnp.linalg.norm(v, ord=jnp.inf),
)

print(f"L2 范数: {norm_2}")
print(f"L1 范数: {norm_1}")
print(f"无穷范数: {norm_inf}")

矩阵范数 #

python
import jax.numpy as jnp

A = jnp.array([
    [1, 2],
    [3, 4],
], dtype=jnp.float32)

# For a 2-D input, ord=None is the Frobenius norm (sqrt of the sum of squares)
# and ord=2 is the spectral norm (largest singular value).
frobenius, spectral = jnp.linalg.norm(A), jnp.linalg.norm(A, ord=2)

print(f"Frobenius 范数: {frobenius}")
print(f"谱范数: {spectral}")

条件数 #

python
import jax.numpy as jnp

A = jnp.array([
    [1, 2],
    [3, 4],
], dtype=jnp.float32)

# Default condition number: ratio of the largest to the smallest singular value.
# Larger values mean solving A x = b amplifies errors in the input more.
cond = jnp.linalg.cond(A)
print(f"条件数: {cond}")

矩阵幂 #

python
import jax.numpy as jnp

A = jnp.array([[1, 2], [3, 4]], dtype=jnp.float32)

# Repeated matrix multiplication (not element-wise powers): A @ A and A @ A @ A.
A_squared, A_cubed = (jnp.linalg.matrix_power(A, k) for k in (2, 3))

print(f"A^2:\n{A_squared}")
print(f"A^3:\n{A_cubed}")

批处理操作 #

批量矩阵乘法 #

python
import jax.numpy as jnp

batch, rows, inner, cols = 10, 3, 4, 5

# The leading axis acts as a batch dimension: matmul multiplies the trailing
# (3, 4) and (4, 5) matrices pairwise for each of the 10 batch entries.
A = jnp.ones((batch, rows, inner))
B = jnp.ones((batch, inner, cols))

C = jnp.matmul(A, B)
print(f"批量矩阵乘法结果形状: {C.shape}")

批量求逆 #

python
import jax.numpy as jnp
import jax

def batch_inv(A):
    """Invert each matrix in a stack of square matrices.

    ``jax.vmap`` maps ``jnp.linalg.inv`` over the leading axis of ``A``, so
    every slice ``A[i]`` is inverted independently and the results are
    stacked back along that axis.
    """
    invert_each = jax.vmap(jnp.linalg.inv)
    return invert_each(A)

first = [[1, 2], [3, 4]]
second = [[2, 1], [1, 3]]
A_batch = jnp.array([first, second], dtype=jnp.float32)

A_inv_batch = batch_inv(A_batch)
print(f"批量逆矩阵形状: {A_inv_batch.shape}")

实际应用 #

主成分分析 (PCA) #

python
import jax.numpy as jnp
import jax

def pca(X, n_components):
    """Return the leading principal directions of ``X``.

    Args:
        X: data matrix with samples along axis 0 and features along axis 1
            (the mean is removed over axis 0).
        n_components: how many leading directions to keep.

    Returns:
        An array of shape (n_features, n_components) whose columns are
        eigenvectors of the feature covariance matrix, ordered by decreasing
        eigenvalue.
    """
    centered = X - jnp.mean(X, axis=0)
    # rowvar=False treats columns as variables, same as passing the transpose.
    covariance = jnp.cov(centered, rowvar=False)
    # A covariance matrix is symmetric, so the symmetric solver eigh applies.
    variances, directions = jnp.linalg.eigh(covariance)
    order = jnp.argsort(variances)[::-1]  # largest variance first
    return jnp.take(directions, order[:n_components], axis=1)

key = jax.random.PRNGKey(0)
X = jax.random.normal(key, (100, 10))  # 100 samples, 10 features

components = pca(X, 3)
print(f"主成分形状: {components.shape}")

线性回归 #

python
import jax.numpy as jnp
import jax

def linear_regression(X, y):
    """Return the least-squares weights ``w`` minimizing ``||X @ w - y||^2``.

    Args:
        X: design matrix, one row per sample (no intercept column is added).
        y: target values, one per row of ``X``.

    Returns:
        The solution vector from ``jnp.linalg.lstsq`` (its first output).
    """
    w, *_ = jnp.linalg.lstsq(X, y)
    return w

# Use independent subkeys for the features and the noise. Reusing a single
# PRNG key for two draws gives statistically dependent values (JAX keys must
# not be reused), so the "noise" would not be independent of X.
key = jax.random.PRNGKey(0)
key_X, key_noise = jax.random.split(key)
X = jax.random.normal(key_X, (100, 5))
true_w = jnp.array([1, 2, 3, 4, 5], dtype=jnp.float32)
y = X @ true_w + jax.random.normal(key_noise, (100,)) * 0.1

w = linear_regression(X, y)
print(f"估计权重: {w}")
print(f"真实权重: {true_w}")

性能优化 #

JIT 编译 #

python
import jax
import jax.numpy as jnp

@jax.jit
def fast_matrix_operations(A, B):
    """Return the eigenvalues (ascending) of the symmetric matrix C @ C.T, where C = A @ B.

    ``jnp.linalg.eigvalsh`` is only valid for symmetric/Hermitian input, and the
    raw product ``A @ B`` is generally not symmetric. Passing it in directly
    would therefore not give its eigenvalues. The Gram matrix ``C @ C.T`` is
    symmetric positive semi-definite, and its eigenvalues are the squared
    singular values of C.

    ``jax.jit`` traces and compiles the whole function, and reuses the compiled
    version for later calls with the same input shapes and dtypes.
    """
    C = A @ B
    gram = C @ C.T
    return jnp.linalg.eigvalsh(gram)

A = jax.random.normal(jax.random.PRNGKey(0), (100, 100))
B = jax.random.normal(jax.random.PRNGKey(1), (100, 100))

eigenvalues = fast_matrix_operations(A, B)

下一步 #

现在你已经了解了 JAX 中常用的线性代数操作,接下来学习「随机数生成」,了解 JAX 的随机数系统!

最后更新:2026-04-04