线性代数 #
概述 #
JAX 提供了完整的线性代数功能,通过 jax.numpy.linalg 模块可以执行各种矩阵运算、分解和求解操作。
基本运算 #
矩阵乘法 #
python
import jax.numpy as jnp

A = jnp.array([[1, 2], [3, 4]])
B = jnp.array([[5, 6], [7, 8]])

# Three equivalent ways to multiply 2-D matrices in JAX.
C = jnp.dot(A, B)
print(f"dot:\n{C}")

D = jnp.matmul(A, B)
print(f"matmul:\n{D}")

E = A @ B
print(f"@ 运算符:\n{E}")
向量运算 #
python
import jax.numpy as jnp

# Two 3-vectors used for all three products below.
a = jnp.array([1, 2, 3])
b = jnp.array([4, 5, 6])

# Inner product: sum of elementwise products.
dot_product = jnp.dot(a, b)
print(f"点积: {dot_product}")

# Outer product: (3, 3) matrix of all pairwise products.
outer_product = jnp.outer(a, b)
print(f"外积:\n{outer_product}")

# Cross product, defined for 3-vectors.
cross_product = jnp.cross(a, b)
print(f"叉积: {cross_product}")
矩阵转置 #
python
import jax.numpy as jnp

A = jnp.array([[1, 2, 3], [4, 5, 6]])

# The .T attribute and jnp.transpose produce the same (3, 2) result.
AT = A.T
print(f"转置:\n{AT}")

AT2 = jnp.transpose(A)
print(f"transpose:\n{AT2}")
矩阵分解 #
特征值分解 #
python
import jax.numpy as jnp

# General (possibly non-symmetric) matrix: use eig, which returns
# complex eigenvalues/eigenvectors.
# NOTE: jnp.linalg.eig is only implemented on the CPU backend in JAX.
A = jnp.array([[4, -2], [1, 1]], dtype=jnp.float32)
eigenvalues, eigenvectors = jnp.linalg.eig(A)
print(f"特征值: {eigenvalues}")
print(f"特征向量:\n{eigenvectors}")

# eigh assumes a symmetric/Hermitian input (it only reads one triangle),
# so it must be given a symmetric matrix -- A above is NOT symmetric and
# would silently yield wrong values here.
A_sym = jnp.array([[2, 1], [1, 2]], dtype=jnp.float32)
eigenvalues_real, eigenvectors_real = jnp.linalg.eigh(A_sym)
print(f"实特征值: {eigenvalues_real}")
奇异值分解 (SVD) #
python
import jax.numpy as jnp

A = jnp.array([[1, 2, 3], [4, 5, 6]], dtype=jnp.float32)

# full_matrices=False gives the reduced SVD: U (2, 2), S (2,), Vt (2, 3),
# so U @ diag(S) @ Vt is shape-compatible. With the default
# full_matrices=True, Vt would be (3, 3) and the reconstruction below
# would raise a shape-mismatch error.
U, S, Vt = jnp.linalg.svd(A, full_matrices=False)
print(f"U 形状: {U.shape}")
print(f"S (奇异值): {S}")
print(f"Vt 形状: {Vt.shape}")

A_reconstructed = U @ jnp.diag(S) @ Vt
print(f"重构矩阵:\n{A_reconstructed}")
QR 分解 #
python
import jax.numpy as jnp

# Tall (3, 2) matrix: reduced QR gives Q (3, 2) with orthonormal columns
# and R (2, 2) upper triangular.
A = jnp.array([[1, 2], [3, 4], [5, 6]])
Q, R = jnp.linalg.qr(A)
print(f"Q (正交矩阵):\n{Q}")
print(f"R (上三角矩阵):\n{R}")

# Q @ R recovers A (up to floating-point error).
A_reconstructed = Q @ R
print(f"重构矩阵:\n{A_reconstructed}")
Cholesky 分解 #
python
import jax.numpy as jnp

# Cholesky requires a symmetric positive-definite input.
A = jnp.array([[4, 2], [2, 3]], dtype=jnp.float32)
L = jnp.linalg.cholesky(A)
print(f"Cholesky 分解 L:\n{L}")

# The lower-triangular factor satisfies A = L @ L.T.
A_reconstructed = L @ L.T
print(f"重构矩阵:\n{A_reconstructed}")
矩阵求逆 #
逆矩阵 #
python
import jax.numpy as jnp

# inv requires a square, non-singular matrix (det(A) = -2 here).
A = jnp.array([[1, 2], [3, 4]], dtype=jnp.float32)
A_inv = jnp.linalg.inv(A)
print(f"逆矩阵:\n{A_inv}")

# A @ A_inv should be (numerically close to) the identity.
I = A @ A_inv
print(f"A @ A_inv:\n{I}")
伪逆 #
python
import jax.numpy as jnp

# Non-square (3, 2) matrix: use the Moore-Penrose pseudo-inverse.
A = jnp.array([[1, 2], [3, 4], [5, 6]], dtype=jnp.float32)
A_pinv = jnp.linalg.pinv(A)
# pinv of an (m, n) matrix has shape (n, m).
print(f"伪逆形状: {A_pinv.shape}")
线性方程组 #
求解 Ax = b #
python
import jax.numpy as jnp

# System: 3x + y = 9, x + 2y = 8  =>  x = 2, y = 3.
A = jnp.array([[3, 1], [1, 2]], dtype=jnp.float32)
b = jnp.array([9, 8], dtype=jnp.float32)
x = jnp.linalg.solve(A, b)
print(f"解: {x}")

# Substitute back: A @ x should equal b.
b_check = A @ x
print(f"验证: {b_check}")
最小二乘解 #
python
import jax.numpy as jnp

# Overdetermined system (3 equations, 2 unknowns): least-squares fit.
A = jnp.array([[1, 1], [1, 2], [1, 3]], dtype=jnp.float32)
b = jnp.array([1, 2, 2], dtype=jnp.float32)
x, residuals, rank, s = jnp.linalg.lstsq(A, b)
print(f"最小二乘解: {x}")
print(f"残差: {residuals}")
行列式与迹 #
行列式 #
python
import jax.numpy as jnp

A = jnp.array([[1, 2], [3, 4]], dtype=jnp.float32)
# det([[1, 2], [3, 4]]) = 1*4 - 2*3 = -2.
det = jnp.linalg.det(A)
print(f"行列式: {det}")
矩阵的迹 #
python
import jax.numpy as jnp

A = jnp.array([[1, 2], [3, 4]])
# Trace: sum of the diagonal elements (1 + 4 = 5).
trace = jnp.trace(A)
print(f"迹: {trace}")
范数 #
向量范数 #
python
import jax.numpy as jnp

v = jnp.array([3, 4])

# Default ord is the Euclidean (L2) norm.
norm_2 = jnp.linalg.norm(v)
# L1 norm: sum of absolute values.
norm_1 = jnp.linalg.norm(v, ord=1)
# Infinity norm: maximum absolute value.
norm_inf = jnp.linalg.norm(v, ord=jnp.inf)

print(f"L2 范数: {norm_2}")
print(f"L1 范数: {norm_1}")
print(f"无穷范数: {norm_inf}")
矩阵范数 #
python
import jax.numpy as jnp

A = jnp.array([[1, 2], [3, 4]], dtype=jnp.float32)

# For matrices the default is the Frobenius norm (sqrt of sum of squares).
frobenius = jnp.linalg.norm(A)
# ord=2 on a matrix is the spectral norm (largest singular value).
spectral = jnp.linalg.norm(A, ord=2)

print(f"Frobenius 范数: {frobenius}")
print(f"谱范数: {spectral}")
条件数 #
python
import jax.numpy as jnp

A = jnp.array([[1, 2], [3, 4]], dtype=jnp.float32)
# Condition number: ratio of largest to smallest singular value (default).
cond = jnp.linalg.cond(A)
print(f"条件数: {cond}")
矩阵幂 #
python
import jax.numpy as jnp

A = jnp.array([[1, 2], [3, 4]], dtype=jnp.float32)

# Integer matrix power: repeated matrix multiplication.
A_squared = jnp.linalg.matrix_power(A, 2)
print(f"A^2:\n{A_squared}")

A_cubed = jnp.linalg.matrix_power(A, 3)
print(f"A^3:\n{A_cubed}")
批处理操作 #
批量矩阵乘法 #
python
import jax.numpy as jnp

# Leading axis (size 10) is the batch dimension; matmul maps over it.
A = jnp.ones((10, 3, 4))
B = jnp.ones((10, 4, 5))
C = jnp.matmul(A, B)
print(f"批量矩阵乘法结果形状: {C.shape}")
批量求逆 #
python
import jax.numpy as jnp
import jax


def batch_inv(A):
    """Invert every matrix in a batch by vmapping jnp.linalg.inv over axis 0."""
    return jax.vmap(jnp.linalg.inv)(A)


A_batch = jnp.array(
    [[[1, 2], [3, 4]],
     [[2, 1], [1, 3]]],
    dtype=jnp.float32,
)
A_inv_batch = batch_inv(A_batch)
print(f"批量逆矩阵形状: {A_inv_batch.shape}")
实际应用 #
主成分分析 (PCA) #
python
import jax.numpy as jnp
import jax


def pca(X, n_components):
    """Return the top `n_components` principal directions of X as columns."""
    # Center the data so the covariance is computed about the mean.
    centered = X - jnp.mean(X, axis=0)
    # (n_features, n_features) covariance matrix -- symmetric, so eigh applies.
    cov = jnp.cov(centered.T)
    eigenvalues, eigenvectors = jnp.linalg.eigh(cov)
    # eigh returns ascending eigenvalues; reorder columns to descending.
    order = jnp.argsort(eigenvalues)[::-1]
    return eigenvectors[:, order][:, :n_components]


key = jax.random.PRNGKey(0)
X = jax.random.normal(key, (100, 10))
components = pca(X, 3)
print(f"主成分形状: {components.shape}")
线性回归 #
python
import jax.numpy as jnp
import jax


def linear_regression(X, y):
    """Ordinary least squares via lstsq; returns the weight vector only."""
    return jnp.linalg.lstsq(X, y)[0]


key = jax.random.PRNGKey(0)
# Split the key instead of reusing it: drawing X and the noise from the
# SAME key makes the noise a deterministic function of the same randomness
# as X, which is the canonical JAX PRNG anti-pattern.
x_key, noise_key = jax.random.split(key)
X = jax.random.normal(x_key, (100, 5))
true_w = jnp.array([1, 2, 3, 4, 5], dtype=jnp.float32)
y = X @ true_w + jax.random.normal(noise_key, (100,)) * 0.1
w = linear_regression(X, y)
print(f"估计权重: {w}")
print(f"真实权重: {true_w}")
性能优化 #
JIT 编译 #
python
import jax
import jax.numpy as jnp


@jax.jit
def fast_matrix_operations(A, B):
    """Multiply A @ B and return the eigenvalues of the symmetrized product.

    eigvalsh is only defined for symmetric/Hermitian matrices; A @ B is
    generally NOT symmetric, and eigvalsh would silently use just one
    triangle of it. Symmetrize explicitly so the result is well-defined.
    """
    C = A @ B
    C_sym = (C + C.T) / 2
    eigenvalues = jnp.linalg.eigvalsh(C_sym)
    return eigenvalues


A = jax.random.normal(jax.random.PRNGKey(0), (100, 100))
B = jax.random.normal(jax.random.PRNGKey(1), (100, 100))
eigenvalues = fast_matrix_operations(A, B)
下一步 #
现在你已经掌握了线性代数,接下来学习 随机数生成,了解 JAX 的随机数系统!
最后更新:2026-04-04