闭包与高阶函数 #

一、闭包基础 #

1.1 什么是闭包 #

闭包是一个函数连同它所引用的外部局部变量(Lua 中称为 upvalue):即使外部函数已经返回,闭包仍然可以读取和修改这些变量。例如:

lua
-- Build an independent counter.
-- Each call to create_counter() yields a new function with its own private
-- tally; invoking that function adds one to the tally and returns the updated
-- value (1 on the first call).
local function create_counter()
    local tally = 0  -- captured by the closure below; outlives this call

    local function step()
        tally = tally + 1
        return tally
    end

    return step
end

local counter = create_counter()
-- Three successive calls share the same captured count.
for _ = 1, 3 do
    print(counter())  -- prints 1, then 2, then 3
end

1.2 闭包的组成 #

lua
-- Return a closure that adds its argument to two captured values:
-- the parameter x and the local y (always 10).
local function outer(x)          -- enclosing function
    local y = 10                 -- free variable from the closure's viewpoint

    -- The returned function keeps both x and y alive as upvalues.
    return function(z)
        local partial = x + y
        return partial + z
    end
end

local closure = outer(5)
local total = closure(3)
print(total)  -- 5 + 10 + 3 = 18

1.3 词法作用域 #

lua
local x = 10  -- outer-scope x; hidden inside outer() by the local declared there

-- Return a function that prints the x visible at its definition site.
local function outer()
    local x = 20  -- shadows the x = 10 above for everything defined in here

    return function()
        print(x)  -- resolves lexically to the nearest enclosing x (20)
    end
end

local f = outer()
f()  -- prints 20, not 10

二、闭包应用 #

2.1 数据封装 #

lua
-- Create a bank-account object whose balance is private.
-- The balance lives only in this function's local scope; the closures in the
-- returned table are the sole way to read or change it.
--   initial_balance: starting amount (0 when nil/false)
-- Returns a table with:
--   deposit(amount)  -> new balance, or nil + message for an invalid amount
--   withdraw(amount) -> new balance, or nil + message for an invalid amount
--                       or insufficient funds
--   get_balance()    -> current balance
local function create_account(initial_balance)
    local balance = initial_balance or 0  -- private state

    -- Normalize an amount to a number. Anything that is not a non-negative
    -- number is rejected, so a negative withdrawal cannot inflate the balance
    -- and a negative deposit cannot drain it.
    local function to_amount(amount)
        local n = tonumber(amount)
        if n == nil or n ~= n or n < 0 then  -- n ~= n filters out NaN
            return nil
        end
        return n
    end

    return {
        deposit = function(amount)
            local n = to_amount(amount)
            if not n then
                return nil, "金额无效"
            end
            balance = balance + n
            return balance
        end,
        withdraw = function(amount)
            local n = to_amount(amount)
            if not n then
                return nil, "金额无效"
            end
            if n > balance then
                return nil, "余额不足"
            end
            balance = balance - n
            return balance
        end,
        get_balance = function()
            return balance
        end
    }
end

local account = create_account(100)
local function show_balance()
    print(account.get_balance())
end

show_balance()          -- 100
account.deposit(50)
show_balance()          -- 150
account.withdraw(30)
show_balance()          -- 120
-- The balance variable itself is unreachable from here; only the account's
-- own functions can read or change it.

2.2 工厂函数 #

lua
-- Build a validator function from a rule table.
--   rules: maps field name -> { required = boolean?, min = number? }
-- The returned function takes a data table and returns true when every rule
-- passes, or false plus a message for the first failing rule it meets.
-- Rules are visited with pairs(), so when several fields are invalid the
-- one reported first is unspecified.
local function create_validator(rules)
    local function validate(data)
        for field, rule in pairs(rules) do
            local value = data[field]
            local missing = value == nil
            if rule.required and missing then
                return false, field .. " 是必填项"
            end
            -- Only measure the length when a minimum is set and a truthy
            -- value is present.
            local too_short = rule.min and value and #value < rule.min
            if too_short then
                return false, field .. " 长度不足"
            end
        end
        return true
    end

    return validate
end

local validate_user = create_validator({
    name = {required = true, min = 2},
    email = {required = true}
})

-- Supply a valid email so that only the `name` rule can fail. The validator
-- walks its rules with pairs(), whose order is unspecified; if email were
-- missing as well, either error message could be the one reported.
local ok, err = validate_user({name = "A", email = "a@example.com"})
print(ok, err)  -- false    name 长度不足

2.3 状态保持 #

lua
-- Memoization: wrap func so that repeated calls with the same argument list
-- return the stored results instead of calling func again.
--   func: the function to cache; may take any number of arguments of any
--         type and may return any number of values (nil included).
-- Returns the caching wrapper.
--
-- The cache is a tree of tables keyed by each argument in turn, rather than
-- a single concatenated string. This avoids key collisions such as
-- ("1-2", "3") vs ("1", "2-3") and works for booleans, tables and nil, which
-- table.concat rejects. Arguments are compared as table keys, so tables and
-- functions match by identity, and 1 and 1.0 share one entry.
local function memoize(func)
    local NIL_KEY = {}     -- stands in for nil, which cannot be a table key
    local NAN_KEY = {}     -- stands in for NaN, which cannot be a table key
    local RESULT_KEY = {}  -- slot within a node that holds the packed results
    local root = {}

    return function(...)
        local args = table.pack(...)  -- .n keeps the true argument count
        local node = root
        for i = 1, args.n do
            local key = args[i]
            if key == nil then
                key = NIL_KEY
            elseif key ~= key then
                key = NAN_KEY
            end
            local child = node[key]
            if child == nil then
                child = {}
                node[key] = child
            end
            node = child
        end

        local results = node[RESULT_KEY]
        if results == nil then
            results = table.pack(func(...))
            node[RESULT_KEY] = results
        end
        -- Unpack with the recorded count so nil results are preserved.
        return table.unpack(results, 1, results.n)
    end
end

-- Naive recursive Fibonacci (exponential time), shown for comparison with
-- the memoized version.
-- Declared with `local function` so the name is already in scope inside the
-- body. With `local slow_fib = function ... end` the recursive calls would
-- resolve to a global `slow_fib`, which is nil.
local function slow_fib(n)
    if n <= 1 then return n end
    return slow_fib(n - 1) + slow_fib(n - 2)
end

-- Forward-declare `fib` so the function body below captures this local as an
-- upvalue. In `local fib = memoize(function ... fib(...) end)` the new local
-- is not yet in scope inside its own initializer, so the recursive calls
-- would look up a global `fib` (nil) and raise an error.
local fib
fib = memoize(function(n)
    if n <= 1 then return n end
    return fib(n - 1) + fib(n - 2)
end)

print(fib(50))  -- fast: each fib(k) is computed only once

2.4 延迟执行 #

lua
-- Lazy evaluation: defer a computation until its value is first requested.
--   func: the function to run on first access
--   ...:  arguments to pass to func
-- Returns a zero-argument function. Its first call runs func(...) and stores
-- the first return value; later calls return the stored value without
-- running func again, even when that value is nil or false.
local function lazy(func, ...)
    -- table.pack records the exact argument count in .n, so nil arguments
    -- (trailing ones included) are passed through to func unchanged.
    local args = table.pack(...)
    local cached_result = nil
    local computed = false  -- separate flag: the cached value may be nil

    return function()
        if not computed then
            cached_result = func(table.unpack(args, 1, args.n))
            computed = true
        end
        return cached_result
    end
end

local function compute_answer()
    print("计算中...")
    return 42
end

local lazy_value = lazy(compute_answer)

print("创建完成")
print(lazy_value())  -- first access runs the computation, then prints 42
print(lazy_value())  -- prints 42 again without recomputing

三、高阶函数 #

3.1 函数作为参数 #

lua
-- Array traversal: call func(value, index) for each element of arr, in
-- order, stopping at the first nil element.
local function foreach(arr, func)
    local index = 1
    local item = arr[index]
    while item ~= nil do
        func(item, index)
        index = index + 1
        item = arr[index]
    end
end

local function show_entry(v, i)
    print(string.format("arr[%d] = %d", i, v))
end

foreach({1, 2, 3, 4, 5}, show_entry)

-- Conditional search: return the first element for which
-- predicate(value, index) is truthy, together with its index.
-- Returns nil, nil when no element matches.
local function find(arr, predicate)
    local found_value, found_index = nil, nil
    for i, v in ipairs(arr) do
        if predicate(v, i) then
            found_value, found_index = v, i
            break
        end
    end
    return found_value, found_index
end

local function above_three(v)
    return v > 3
end

local value, index = find({1, 2, 3, 4, 5}, above_three)
print(value, index)  -- 4    4

3.2 函数作为返回值 #

lua
-- Predicate generator: return a function that tests whether its argument is
-- strictly greater than n.
local function greater_than(n)
    local function exceeds(x)
        return n < x  -- same comparison as x > n
    end
    return exceeds
end

local is_adult = greater_than(17)
print(is_adult(18))  -- true
print(is_adult(16))  -- false

-- Combined use with a filter. No `filter` is defined before this point, so
-- calling it here would be a call on nil; a minimal local version is provided
-- to keep the example runnable on its own.
local function filter(arr, func)
    local kept = {}
    for i, v in ipairs(arr) do
        if func(v, i) then
            kept[#kept + 1] = v
        end
    end
    return kept
end

local numbers = {1, 5, 10, 15, 20}
local big_numbers = filter(numbers, greater_than(10))
print(table.concat(big_numbers, ", "))  -- 15, 20

3.3 函数组合 #

lua
-- compose: combine functions right to left, so compose(f, g)(x) == f(g(x)).
-- With no functions given, the result returns its argument unchanged.
local function compose(...)
    local funcs = {...}
    local count = #funcs  -- funcs is private and never modified afterwards

    return function(x)
        local result = x
        local i = count
        while i >= 1 do
            result = funcs[i](result)
            i = i - 1
        end
        return result
    end
end

-- pipe: combine functions left to right, so pipe(f, g)(x) == g(f(x)).
-- With no functions given, the result returns its argument unchanged.
local function pipe(...)
    local stages = {...}

    return function(x)
        local result = x
        local index = 1
        local stage = stages[index]
        while stage ~= nil do  -- stop at the first nil, as ipairs does
            result = stage(result)
            index = index + 1
            stage = stages[index]
        end
        return result
    end
end

local function double(x)
    return x * 2
end

local function add_one(x)
    return x + 1
end

local function square(x)
    return x * x
end

local process = pipe(double, add_one, square)
print(process(3))  -- ((3 * 2) + 1) ^ 2 = 49

3.4 柯里化 #

lua
-- Currying: turn func into a function that can be given its arguments a few
-- at a time.
--   func:  the function to curry
--   arity: how many arguments to collect before calling func (default 2).
--          Lua functions do not declare their arity to callers, so pass it
--          explicitly for anything other than two-argument functions.
-- Returns a function that accumulates arguments across calls and invokes
-- func once at least `arity` of them have been supplied. Every partial
-- application gets its own copy of the collected arguments, so intermediate
-- functions can be reused safely.
local function curry(func, arity)
    arity = arity or 2

    -- collected[1..count] holds the arguments gathered so far.
    local function accumulate(collected, count)
        return function(...)
            local incoming = table.pack(...)
            local total = count + incoming.n
            local merged = {}
            for i = 1, count do
                merged[i] = collected[i]
            end
            for i = 1, incoming.n do
                merged[count + i] = incoming[i]
            end
            if total >= arity then
                -- Pass the merged list with an explicit count. Writing
                -- func(table.unpack(a), ...) would truncate the unpacked
                -- values to one, because only the last expression in an
                -- argument list expands to multiple values.
                return func(table.unpack(merged, 1, total))
            end
            return accumulate(merged, total)
        end
    end

    return accumulate({}, 0)
end

-- The function takes three parameters, so the arity must be given as 3.
local add = curry(function(a, b, c)
    return a + b + c
end, 3)

print(add(1)(2)(3))  -- 6
print(add(1, 2)(3))  -- 6
print(add(1)(2, 3))  -- 6
print(add(1, 2, 3))  -- 6

四、常用高阶函数 #

4.1 map #

lua
-- Return a new array holding func(value, index) for each element of arr.
-- arr itself is not modified.
local function map(arr, func)
    local mapped = {}
    for index, item in ipairs(arr) do
        local out = func(item, index)
        mapped[index] = out
    end
    return mapped
end

local function times_two(x)
    return x * 2
end

local doubled = map({1, 2, 3, 4, 5}, times_two)
print(table.concat(doubled, ", "))  -- 2, 4, 6, 8, 10

-- Mapping over records: pull one field out of each table.
local users = {
    {name = "Alice", age = 25},
    {name = "Bob", age = 30},
    {name = "Charlie", age = 35}
}

local function name_of(user)
    return user.name
end

local names = map(users, name_of)
print(table.concat(names, ", "))  -- Alice, Bob, Charlie

4.2 filter #

lua
-- Return a new array with the elements of arr for which func(value, index)
-- is truthy, in their original order.
local function filter(arr, func)
    local kept = {}
    for index, item in ipairs(arr) do
        local keep = func(item, index)
        if keep then
            kept[#kept + 1] = item
        end
    end
    return kept
end

local function is_even(x)
    return x % 2 == 0
end

local evens = filter({1, 2, 3, 4, 5, 6}, is_even)
print(table.concat(evens, ", "))  -- 2, 4, 6

4.3 reduce #

lua
-- Fold arr into a single value by calling func(acc, value, index) from left
-- to right.
--   initial: starting accumulator. When it is nil, arr[1] is used as the
--            starting value and folding begins at index 2.
-- Returns the final accumulator (nil for an empty arr with no initial).
local function reduce(arr, func, initial)
    local acc, first = initial, 1
    if initial == nil then
        acc, first = arr[1], 2
    end

    local last = #arr
    for i = first, last do
        acc = func(acc, arr[i], i)
    end
    return acc
end

local function plus(acc, v)
    return acc + v
end

local function times(acc, v)
    return acc * v
end

-- No initial value: the first element seeds the accumulator.
local sum = reduce({1, 2, 3, 4, 5}, plus)
print(sum)  -- 15

-- Explicit initial value of 1.
local product = reduce({1, 2, 3, 4, 5}, times, 1)
print(product)  -- 120

4.4 some 和 every #

lua
-- Return true if func(value, index) is truthy for at least one element of
-- arr; false otherwise (including for an empty arr). Stops at the first match.
local function some(arr, func)
    local matched = false
    for index, item in ipairs(arr) do
        if func(item, index) then
            matched = true
            break
        end
    end
    return matched
end

-- Return true if func(value, index) is truthy for all elements of arr
-- (true for an empty arr); false otherwise. Stops at the first failure.
local function every(arr, func)
    local all_pass = true
    for index, item in ipairs(arr) do
        if not func(item, index) then
            all_pass = false
            break
        end
    end
    return all_pass
end

local arr = {1, 2, 3, 4, 5}
local has_large = some(arr, function(x) return x > 3 end)
local all_positive = every(arr, function(x) return x > 0 end)
print(has_large)     -- true
print(all_positive)  -- true

五、闭包陷阱 #

5.1 循环中的闭包 #

lua
-- In some languages, closures created in a loop all share one loop variable
-- and therefore all return the same value. Lua does not have this problem:
-- each iteration of a numeric `for` gets its own fresh local `i`.
local functions = {}
for i = 1, 3 do
    local function get_i()
        return i
    end
    functions[i] = get_i
end

for k = 1, 3 do
    print(functions[k]())  -- prints 1, then 2, then 3
end

-- The pitfall does appear when the closures capture a variable declared
-- outside the loop: they all share that single variable and see its final
-- value.
local j = 0
local funcs = {}
for i = 1, 3 do
    j = i
    local function get_j()
        return j
    end
    funcs[i] = get_j
end

for k = 1, 3 do
    print(funcs[k]())  -- prints 3 every time: all three closures share j
end

-- Fix: pass the current value through a function call, so each closure
-- captures its own parameter instead of a shared variable.
local function capture(n)
    return function()
        return n
    end
end

local funcs2 = {}
for i = 1, 3 do
    funcs2[i] = capture(i)
end

for k = 1, 3 do
    print(funcs2[k]())  -- prints 1, then 2, then 3
end

5.2 内存泄漏 #

lua
-- A closure keeps every outer local that it references alive for as long as
-- the closure itself is reachable.
-- Lua captures only the variables a function actually mentions: a closure
-- that never refers to big_data would not retain it. The closure below does
-- refer to it, so the whole table stays in memory even though its contents
-- never affect the result.
local function create_leak()
    local big_data = {}  -- large table
    for i = 1, 1000000 do
        big_data[i] = i
    end

    return function()
        local _ = big_data  -- this reference makes big_data an upvalue
        return "leak"
    end
end

-- Better: copy out only what the closure needs, so the large table is not
-- referenced by the closure at all.
local function create_no_leak()
    local result
    do
        -- big_data is confined to this block; nothing outside it refers to
        -- the table once the block ends.
        local big_data = {}
        for i = 1, 1000000 do
            big_data[i] = i
        end
        result = big_data[1]  -- keep just the value that is needed
    end

    return function()
        return result
    end
end

六、实用模式 #

6.1 单例模式 #

lua
-- singleton() returns the same table on every call. The table is created on
-- the first call and has a `value` field (initially 0) and an `increment`
-- method that adds 1 to self.value.
local singleton
do
    local instance  -- hidden in this block; created lazily on first access

    local function increment(self)
        self.value = self.value + 1
    end

    singleton = function()
        if instance == nil then
            instance = {
                value = 0,
                increment = increment
            }
        end
        return instance
    end
end

local s1, s2 = singleton(), singleton()
s1:increment()
print(s2.value)  -- 1 (s1 and s2 are the same instance)

6.2 发布订阅模式 #

lua
-- Create a publish/subscribe event emitter. The listener registry is private
-- to the returned object.
-- Returns a table with:
--   on(event, callback)   register callback for event
--   emit(event, ...)      call every callback registered for event, in
--                         registration order, passing the extra arguments
--   off(event, callback)  remove the first registration of callback
local function create_event_emitter()
    local events = {}  -- event name -> array of callbacks

    return {
        on = function(event, callback)
            local listeners = events[event]
            if not listeners then
                listeners = {}
                events[event] = listeners
            end
            listeners[#listeners + 1] = callback
        end,

        emit = function(event, ...)
            local listeners = events[event]
            if not listeners then
                return
            end
            -- Iterate over a snapshot. If a callback calls off() or on() for
            -- this event while it is being emitted, the live array shifts,
            -- and iterating it directly would skip the listener that follows
            -- a removed one.
            local snapshot = {}
            for i = 1, #listeners do
                snapshot[i] = listeners[i]
            end
            for i = 1, #snapshot do
                snapshot[i](...)
            end
        end,

        off = function(event, callback)
            local listeners = events[event]
            if not listeners then
                return
            end
            for i, cb in ipairs(listeners) do
                if cb == callback then
                    table.remove(listeners, i)
                    break  -- at most one registration is removed per call
                end
            end
        end
    }
end

local emitter = create_event_emitter()

local function on_message(msg)
    print("收到消息:" .. msg)
end

emitter.on("message", on_message)
emitter.emit("message", "Hello")

6.3 装饰器模式 #

lua
-- Decorator: wrap func so that each call first prints a running call number
-- and then forwards all arguments to func, returning all of its results.
-- Each wrapper has its own counter.
local function log_calls(func)
    local calls = 0  -- private to this wrapper

    local function wrapper(...)
        calls = calls + 1
        print(string.format("调用 #%d", calls))
        return func(...)
    end

    return wrapper
end

-- Return the sum of a and b.
local function add(a, b)
    local sum = a + b
    return sum
end

local logged_add = log_calls(add)

local first_sum = logged_add(1, 2)   -- prints the "#1" call line
print(first_sum)                     -- 3
local second_sum = logged_add(3, 4)  -- prints the "#2" call line
print(second_sum)                    -- 7

七、总结 #

本章介绍了 Lua 的闭包和高阶函数:

  1. 闭包基础:闭包可以访问外部作用域的变量
  2. 闭包应用:数据封装、工厂函数、状态保持、延迟执行
  3. 高阶函数:函数作为参数或返回值、函数组合、柯里化
  4. 常用函数:map、filter、reduce、some、every
  5. 闭包陷阱:循环中的闭包、内存泄漏
  6. 设计模式:单例、发布订阅、装饰器

闭包是 Lua 函数式编程的核心概念。下一章,我们将学习可变参数函数。

最后更新:2026-03-27