代码生成 #

一、@generated基础 #

1.1 基本语法 #

julia
@generated function gen_add(a, b)
    quote
        a + b
    end
end

gen_add(1, 2)

1.2 基于类型生成 #

julia
@generated function gen_multiply(a::T, b::T) where {T}
    if T <: Integer
        quote
            a * b
        end
    else
        quote
            a * b
        end
    end
end

gen_multiply(1, 2)
gen_multiply(1.0, 2.0)

1.3 查看生成代码 #

julia
@generated function gen_example(x)
    println("Generating for type: ", x)
    :(x)
end

gen_example(1)
gen_example(1.0)

二、生成函数应用 #

2.1 类型特化 #

julia
@generated function type_specific(x::T) where {T}
    if T <: Integer
        :(x^2)
    elseif T <: AbstractFloat
        :(sqrt(x))
    else
        :(x)
    end
end

type_specific(5)
type_specific(16.0)
type_specific("hello")

2.2 循环展开 #

julia
@generated function unrolled_sum(arr::NTuple{N, T}) where {N, T}
    ex = Expr(:block)
    sum_expr = Expr(:call, :+)
    for i in 1:N
        push!(sum_expr.args, :(arr[$i]))
    end
    push!(ex.args, sum_expr)
    return ex
end

unrolled_sum((1, 2, 3, 4, 5))

2.3 编译时计算 #

julia
@generated function factorial_gen(::Val{N}) where {N}
    result = 1
    for i in 1:N
        result *= i
    end
    :($result)
end

factorial_gen(Val{5}())

三、生成函数限制 #

3.1 不能观察值 #

生成函数只能观察参数类型,不能观察参数值:

julia
@generated function bad_gen(x)
    if x > 0
        :(x)
    else
        :(-x)
    end
end

3.2 必须返回表达式 #

julia
@generated function must_return_expr(x)
    return 42
end

3.3 纯函数 #

生成函数应该是纯函数,不应有副作用。

四、实践练习 #

4.1 练习1:静态数组操作 #

julia
@generated function static_dot(a::NTuple{N, T}, b::NTuple{N, T}) where {N, T}
    ex = Expr(:call, :+)
    for i in 1:N
        push!(ex.args, :(a[$i] * b[$i]))
    end
    return ex
end

static_dot((1, 2, 3), (4, 5, 6))

4.2 练习2:类型打印 #

julia
@generated function type_info(::Type{T}) where {T}
    quote
        println("Type: ", $(string(T)))
        println("Supertype: ", $(string(supertype(T))))
    end
end

type_info(Int)
type_info(Float64)

4.3 练习3:字段访问 #

julia
@generated function get_field_names(::Type{T}) where {T}
    fieldnames_T = fieldnames(T)
    quote
        $fieldnames_T
    end
end

struct Point
    x::Float64
    y::Float64
end

get_field_names(Point)

五、总结 #

本章我们学习了:

  1. @generated语法:生成函数定义
  2. 类型特化:基于类型生成代码
  3. 循环展开:编译时优化
  4. 生成函数限制:只能观察类型

接下来让我们学习Julia的测试!

最后更新:2026-03-27