并发模式 #

Go语言的并发原语(goroutine和channel)可以组合出各种强大的并发模式。掌握这些模式可以编写高效、清晰的并发程序。

生产者-消费者模式 #

基本实现 #

go
package main

import (
    "fmt"
    "math/rand"
    "time"
)

// producer endlessly generates random values in [0, 100). Each value is sent
// to ch, logged, and followed by a random pause of 0-499ms. It never returns.
func producer(id int, ch chan<- int) {
    for {
        v := rand.Intn(100)
        ch <- v
        fmt.Printf("生产者%d: 生产 %d\n", id, v)

        pause := time.Duration(rand.Intn(500)) * time.Millisecond
        time.Sleep(pause)
    }
}

// consumer drains ch until it is closed. It logs every value and pauses for a
// random 0-299ms between items to simulate work.
func consumer(id int, ch <-chan int) {
    for {
        v, ok := <-ch
        if !ok {
            return
        }
        fmt.Printf("  消费者%d: 消费 %d\n", id, v)
        time.Sleep(time.Millisecond * time.Duration(rand.Intn(300)))
    }
}

// main wires three producers and two consumers to a single buffered channel
// (capacity 10), then parks forever. The demo is stopped externally (e.g. Ctrl+C).
func main() {
    const numProducers, numConsumers = 3, 2
    ch := make(chan int, 10)

    for id := 1; id <= numProducers; id++ {
        go producer(id, ch)
    }
    for id := 1; id <= numConsumers; id++ {
        go consumer(id, ch)
    }

    // An empty select has no cases that can ever proceed, so main blocks here.
    select {}
}

优雅关闭 #

go
package main

import (
    "fmt"
    "sync"
)

// producer tries to send the values 0..4 to ch. Every send races against done:
// if done becomes readable first, the producer logs the stop and returns early.
// wg.Done is called on every exit path.
func producer(id int, ch chan<- int, done <-chan bool, wg *sync.WaitGroup) {
    defer wg.Done()

    for n := 0; n < 5; n++ {
        select {
        case <-done:
            fmt.Printf("生产者%d: 收到停止信号\n", id)
            return
        case ch <- n:
            fmt.Printf("生产者%d: 生产 %d\n", id, n)
        }
    }
}

// consumer logs every value received from ch until ch is closed. It then
// reports the closure and marks itself done on wg.
func consumer(id int, ch <-chan int, wg *sync.WaitGroup) {
    defer wg.Done()

    for {
        v, ok := <-ch
        if !ok {
            break
        }
        fmt.Printf("  消费者%d: 消费 %d\n", id, v)
    }
    fmt.Printf("消费者%d: 通道已关闭\n", id)
}

// main runs two producers and one consumer over a shared buffered channel and
// shuts them down in the correct order:
//
//  1. wait until every producer has finished sending;
//  2. close the data channel (only the senders' side may close it, and only
//     once no one will send again);
//  3. wait for the consumer to drain the remaining values and exit.
//
// Producers and consumers are tracked by separate WaitGroups. With a single
// shared WaitGroup, waiting before close(ch) would also wait for the
// consumer. The consumer only exits after ch is closed, so that is a deadlock.
func main() {
    ch := make(chan int, 10)
    // Closing done would ask the producers to stop early. This demo lets them
    // run to completion, so done is never closed.
    done := make(chan bool)

    var producers, consumers sync.WaitGroup

    producers.Add(2)
    go producer(1, ch, done, &producers)
    go producer(2, ch, done, &producers)

    consumers.Add(1)
    go consumer(1, ch, &consumers)

    producers.Wait()
    close(ch) // lets the consumer's range loop terminate

    consumers.Wait()
    fmt.Println("程序结束")
}

工作池模式 #

基本工作池 #

go
package main

import (
    "fmt"
    "sync"
    "time"
)

type (
    // Job is one unit of work: an identifier plus the input value to process.
    Job struct {
        ID   int
        Data int
    }

    // Result pairs a processed Job with the value computed for it.
    Result struct {
        Job    Job
        Result int
    }
)

// worker consumes jobs until the jobs channel is closed. Each job takes a
// simulated 500ms and yields Data*2 on results. wg.Done is deferred so the
// pool knows when every worker has finished.
func worker(id int, jobs <-chan Job, results chan<- Result, wg *sync.WaitGroup) {
    defer wg.Done()

    for {
        job, ok := <-jobs
        if !ok {
            return
        }
        fmt.Printf("工作者%d: 处理任务 %d\n", id, job.ID)
        time.Sleep(500 * time.Millisecond)

        doubled := job.Data * 2
        results <- Result{Job: job, Result: doubled}
    }
}

// main runs a fixed pool of workers over numJobs jobs. Both channels are
// buffered to numJobs, so neither the submission loop nor the workers block on
// a full channel. A helper goroutine closes results once every worker has
// returned, which ends the collection loop.
func main() {
    const (
        numJobs    = 10
        numWorkers = 3
    )

    jobs := make(chan Job, numJobs)
    results := make(chan Result, numJobs)

    var wg sync.WaitGroup
    wg.Add(numWorkers)
    for id := 1; id <= numWorkers; id++ {
        go worker(id, jobs, results, &wg)
    }

    for id := 1; id <= numJobs; id++ {
        jobs <- Job{ID: id, Data: id}
    }
    close(jobs)

    go func() {
        wg.Wait()
        close(results)
    }()

    for r := range results {
        fmt.Printf("结果: 任务%d -> %d\n", r.Job.ID, r.Result)
    }
}

带超时的工作池 #

go
package main

import (
    "context"
    "fmt"
    "sync"
    "time"
)

// workerWithTimeout processes jobs until the jobs channel is closed or ctx is
// cancelled, whichever comes first. Each job takes a simulated 300ms and
// produces job*2 on results.
//
// Three steps can block: receiving a job, the simulated work, and sending the
// result. Each one also watches ctx.Done(), so the worker cannot get stuck
// after cancellation, e.g. on a results channel that nobody reads any more.
func workerWithTimeout(ctx context.Context, id int, jobs <-chan int, results chan<- int) {
    for {
        var job int
        select {
        case <-ctx.Done():
            fmt.Printf("工作者%d: 收到停止信号\n", id)
            return
        case j, ok := <-jobs:
            if !ok {
                return
            }
            job = j
        }

        // Simulated work. A Timer (rather than time.After) can be stopped and
        // released immediately when the context is cancelled.
        work := time.NewTimer(300 * time.Millisecond)
        select {
        case <-work.C:
        case <-ctx.Done():
            work.Stop()
            fmt.Printf("工作者%d: 任务 %d 超时取消\n", id, job)
            return
        }

        // The send blocks if the consumer is slow or gone; don't let it
        // outlive the context.
        select {
        case results <- job * 2:
            fmt.Printf("工作者%d: 完成任务 %d\n", id, job)
        case <-ctx.Done():
            fmt.Printf("工作者%d: 任务 %d 超时取消\n", id, job)
            return
        }
    }
}

// main starts three timeout-aware workers that share a 2-second deadline. A
// separate goroutine feeds them ten jobs. main prints results until every
// worker has exited and the results channel is closed.
func main() {
    ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
    defer cancel()

    jobs := make(chan int, 10)
    results := make(chan int, 10)

    var wg sync.WaitGroup
    for id := 1; id <= 3; id++ {
        wg.Add(1)
        go func(workerID int) {
            defer wg.Done()
            workerWithTimeout(ctx, workerID, jobs, results)
        }(id)
    }

    // Feed jobs asynchronously; closing jobs lets idle workers return.
    go func() {
        defer close(jobs)
        for n := 1; n <= 10; n++ {
            jobs <- n
        }
    }()

    go func() {
        wg.Wait()
        close(results)
    }()

    for r := range results {
        fmt.Printf("结果: %d\n", r)
    }
}

Pipeline模式 #

基本Pipeline #

go
package main

import "fmt"

// generator returns an unbuffered channel that yields nums in order. The
// channel is closed after the last value has been received.
func generator(nums ...int) <-chan int {
    ch := make(chan int)
    go func() {
        defer close(ch)
        for i := range nums {
            ch <- nums[i]
        }
    }()
    return ch
}

// square is a pipeline stage that emits n*n for every n read from in. Its
// output channel is closed once in is closed and drained.
func square(in <-chan int) <-chan int {
    out := make(chan int)
    go func() {
        defer close(out)
        for {
            n, ok := <-in
            if !ok {
                return
            }
            out <- n * n
        }
    }()
    return out
}

// double is a pipeline stage that emits twice every value read from in. Its
// output channel is closed once in is closed and drained.
func double(in <-chan int) <-chan int {
    out := make(chan int)
    stage := func() {
        defer close(out)
        for v := range in {
            out <- v + v
        }
    }
    go stage()
    return out
}

func main() {
    nums := generator(1, 2, 3, 4, 5)
    squared := square(nums)
    doubled := double(squared)
    
    for result := range doubled {
        fmt.Println(result)
    }
}

可取消的Pipeline #

go
package main

import (
    "context"
    "fmt"
)

// generatorWithContext emits nums in order on an unbuffered channel. It stops
// early if ctx is cancelled while a send is pending. The channel is closed in
// either case.
func generatorWithContext(ctx context.Context, nums ...int) <-chan int {
    out := make(chan int)
    go func() {
        defer close(out)
        for i := 0; i < len(nums); i++ {
            select {
            case <-ctx.Done():
                return
            case out <- nums[i]:
            }
        }
    }()
    return out
}

// processWithContext is a cancellable pipeline stage. It applies f to every
// value read from in and forwards the result on its output channel. That
// channel is closed when in is exhausted or ctx is cancelled.
//
// Both the receive from in and the send on out watch ctx.Done(). The stage
// therefore exits promptly on cancellation even if the upstream stage never
// closes in; a plain `for n := range in` would leak this goroutine then.
func processWithContext(ctx context.Context, in <-chan int, f func(int) int) <-chan int {
    out := make(chan int)
    go func() {
        defer close(out)
        for {
            var n int
            select {
            case <-ctx.Done():
                return
            case v, ok := <-in:
                if !ok {
                    return
                }
                n = v
            }

            select {
            case out <- f(n):
            case <-ctx.Done():
                return
            }
        }
    }()
    return out
}

func main() {
    ctx, cancel := context.WithCancel(context.Background())
    defer cancel()
    
    nums := generatorWithContext(ctx, 1, 2, 3, 4, 5)
    
    squared := processWithContext(ctx, nums, func(n int) int {
        return n * n
    })
    
    for result := range squared {
        fmt.Println(result)
        if result == 9 {
            cancel()
            break
        }
    }
}

扇出扇入模式 #

扇出(Fan-out) #

多个goroutine从同一个通道读取数据:

go
package main

import (
    "fmt"
    "sync"
    "time"
)

// producer returns a channel that yields nums in order and is closed after the
// last value. It is the single source that the workers fan out from.
func producer(nums ...int) <-chan int {
    out := make(chan int)
    go func(values []int) {
        for _, v := range values {
            out <- v
        }
        close(out)
    }(nums)
    return out
}

// worker is one fan-out branch. It competes with its siblings for values on
// in, spends 100ms "processing" each one, and logs it. It emits the square on
// its own output channel, which is closed once in is drained.
func worker(id int, in <-chan int) <-chan int {
    results := make(chan int)
    go func() {
        defer close(results)
        for {
            n, ok := <-in
            if !ok {
                return
            }
            time.Sleep(100 * time.Millisecond)
            fmt.Printf("工作者%d: 处理 %d\n", id, n)
            results <- n * n
        }
    }()
    return results
}

// main fans the values 1..10 out to three workers that read the same channel.
// It then fans their outputs back into a single results channel, which is
// closed once every worker's output has been drained.
func main() {
    in := producer(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)

    const numWorkers = 3
    outputs := make([]<-chan int, 0, numWorkers)
    for id := 1; id <= numWorkers; id++ {
        outputs = append(outputs, worker(id, in))
    }

    results := make(chan int)
    var wg sync.WaitGroup
    forward := func(src <-chan int) {
        defer wg.Done()
        for v := range src {
            results <- v
        }
    }
    wg.Add(len(outputs))
    for _, src := range outputs {
        go forward(src)
    }

    go func() {
        wg.Wait()
        close(results)
    }()

    for v := range results {
        fmt.Printf("结果: %d\n", v)
    }
}

扇入(Fan-in) #

合并多个通道到一个通道:

go
package main

import (
    "fmt"
    "sync"
)

// merge fans in any number of channels. Every value received from any input is
// forwarded to the returned channel. That channel is closed only after all
// inputs have been closed and drained. Relative order across inputs is not
// preserved.
func merge(channels ...<-chan int) <-chan int {
    merged := make(chan int)

    var pending sync.WaitGroup
    pending.Add(len(channels))
    for _, ch := range channels {
        go func(src <-chan int) {
            defer pending.Done()
            for v := range src {
                merged <- v
            }
        }(ch)
    }

    // Close merged exactly once, after every forwarder has finished.
    go func() {
        pending.Wait()
        close(merged)
    }()

    return merged
}

// source emits nums in order on an unbuffered channel and closes it when done.
// Each value is logged after it has been handed to a receiver.
func source(name string, nums ...int) <-chan int {
    out := make(chan int)
    emit := func() {
        defer close(out)
        for i := range nums {
            out <- nums[i]
            fmt.Printf("%s: 发送 %d\n", name, nums[i])
        }
    }
    go emit()
    return out
}

// main merges three independent sources and prints every value as it arrives.
// How the sources interleave depends on scheduling.
func main() {
    sources := []<-chan int{
        source("A", 1, 2, 3),
        source("B", 4, 5, 6),
        source("C", 7, 8, 9),
    }
    for n := range merge(sources...) {
        fmt.Printf("收到: %d\n", n)
    }
}

超时模式 #

select超时 #

go
package main

import (
    "fmt"
    "time"
)

// slowOperation simulates a call that takes two seconds and then delivers 42.
//
// The result channel has a buffer of one. The worker goroutine can therefore
// always complete its send and exit, even if the caller has stopped waiting
// (for example after a select timeout). With an unbuffered channel the
// goroutine would block on the send forever: a goroutine leak.
func slowOperation() <-chan int {
    out := make(chan int, 1)
    go func() {
        time.Sleep(2 * time.Second)
        out <- 42
    }()
    return out
}

// main waits at most one second for slowOperation. The operation sleeps for
// two seconds, so the timeout branch wins and "操作超时" is printed.
func main() {
    result := slowOperation()
    timeout := time.After(1 * time.Second)
    select {
    case <-timeout:
        fmt.Println("操作超时")
    case v := <-result:
        fmt.Println("结果:", v)
    }
}

心跳模式 #

go
package main

import (
    "fmt"
    "time"
)

// heartbeat prints a timestamped heartbeat line every interval from a
// background goroutine, until the returned stop function is called.
//
// It returns two values:
//   - A channel that also receives each heartbeat time. It is buffered (1) and
//     fed with a non-blocking send, so a caller that never reads it does not
//     stall the heartbeat. It is closed when the heartbeat stops.
//   - stop, which ends the heartbeat and waits for the goroutine to exit, so
//     no heartbeat is printed after stop returns. stop is safe to call more
//     than once, including concurrently.
//
// The ticker's own channel is deliberately not exposed. The goroutine
// consumes it, so sharing it would make the caller and the goroutine steal
// ticks from each other.
func heartbeat(interval time.Duration) (<-chan time.Time, func()) {
    done := make(chan struct{})   // closed by stop to request shutdown
    exited := make(chan struct{}) // closed by the goroutine when it returns
    beats := make(chan time.Time, 1)
    ticker := time.NewTicker(interval)

    go func() {
        defer close(exited)
        defer close(beats)
        defer ticker.Stop() // released on every exit path

        for {
            select {
            case <-done:
                return
            case t := <-ticker.C:
                // If stop raced with this tick, prefer stopping.
                select {
                case <-done:
                    return
                default:
                }
                fmt.Println("心跳:", t.Format("15:04:05"))
                select {
                case beats <- t:
                default: // previous beat not consumed yet; drop this one
                }
            }
        }
    }()

    // A single token guards close(done). Whichever caller takes it closes the
    // channel; every later or concurrent call skips the close.
    token := make(chan struct{}, 1)
    token <- struct{}{}
    stop := func() {
        select {
        case <-token:
            close(done)
        default:
        }
        <-exited
    }
    return beats, stop
}

// main runs a 500ms heartbeat for three seconds and then stops it. Only the
// stop function is needed; the heartbeat channel is ignored.
func main() {
    const runFor = 3 * time.Second
    _, stopHeartbeat := heartbeat(500 * time.Millisecond)

    time.Sleep(runFor)
    stopHeartbeat()

    fmt.Println("心跳已停止")
}

优雅关闭模式 #

go
package main

import (
    "context"
    "fmt"
    "os/signal"
    "syscall"
    "time"
)

// worker logs "working" about once a second until ctx is cancelled. It then
// spends 500ms on simulated cleanup before returning.
//
// The pause between work rounds also waits on ctx.Done(). Cancellation is
// noticed immediately instead of only after the current one-second sleep
// ends, which keeps shutdown latency low.
func worker(ctx context.Context, id int) {
    for {
        select {
        case <-ctx.Done():
            fmt.Printf("工作者%d: 正在关闭...\n", id)
            time.Sleep(500 * time.Millisecond) // simulated cleanup
            fmt.Printf("工作者%d: 已关闭\n", id)
            return
        default:
        }

        fmt.Printf("工作者%d: 工作中...\n", id)

        // Wait for the next round, but wake up at once on cancellation.
        pause := time.NewTimer(1 * time.Second)
        select {
        case <-pause.C:
        case <-ctx.Done():
            pause.Stop()
        }
    }
}

// main starts three workers and blocks until SIGINT or SIGTERM arrives. It
// then waits for the workers to finish their cleanup. The wait is bounded at
// two seconds, so a stuck worker cannot keep the process alive, while prompt
// workers no longer pay a fixed two-second delay.
func main() {
    ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
    defer stop()

    const numWorkers = 3
    // Each worker reports on finished when it returns. The buffer lets it do
    // so without blocking even if main has already stopped waiting.
    finished := make(chan struct{}, numWorkers)
    for i := 1; i <= numWorkers; i++ {
        go func(id int) {
            worker(ctx, id)
            finished <- struct{}{}
        }(i)
    }

    <-ctx.Done()
    fmt.Println("收到停止信号,等待工作者完成...")

    // Unregister the signal handlers so that a second Ctrl+C falls back to the
    // default behaviour and terminates the process during shutdown.
    stop()

    deadline := time.After(2 * time.Second)
wait:
    for remaining := numWorkers; remaining > 0; remaining-- {
        select {
        case <-finished:
        case <-deadline:
            fmt.Println("等待工作者超时")
            break wait
        }
    }
    fmt.Println("程序退出")
}

小结 #

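| 模式 | 说明 |
| --- | --- |
| 生产者-消费者 | 数据生产和消费分离 |
| 工作池 | 用固定数量的工作者并行处理任务 |
| Pipeline | 数据依次流经多个处理阶段 |
| 扇出 | 多个goroutine处理同一数据源 |
| 扇入 | 将多个数据源合并为一个 |
| 超时 | 控制操作的最长等待时间 |
| 心跳 | 周期性发出信号,表明goroutine仍然存活 |
| 优雅关闭 | 收到信号后让工作者完成清理,再安全终止程序 |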

并发模式是构建高并发应用的基础,选择合适的模式可以大大简化并发编程的复杂度。