泛型编程 #

Go 1.18引入了泛型(Generics),这是Go语言最重要的特性之一。泛型允许编写能够处理多种类型的代码,提高代码复用性。

泛型基础 #

泛型函数 #

使用类型参数定义泛型函数:

go
package main

import "fmt"

// Print writes every element of s to standard output, one element per line.
// T may be any type; each element is formatted by fmt.Println.
func Print[T any](s []T) {
    for i := 0; i < len(s); i++ {
        fmt.Println(s[i])
    }
}

// main calls the generic Print once with an int slice and once with a string
// slice, letting the compiler infer the type argument in both calls.
func main() {
    Print([]int{1, 2, 3})
    Print([]string{"a", "b", "c"})
}

类型参数语法 #

go
// FunctionName is a syntax template for a generic function: T is the type
// parameter and "constraint" is a placeholder for a real type constraint
// (for example any). It returns param unchanged.
func FunctionName[T constraint](param T) T {
    return param
}
  • T - 类型参数名称
  • constraint - 类型约束
  • any - 任意类型(等同于 interface{})

类型约束 #

内置约束 #

go
package main

import (
    "cmp"
    "fmt"
)

// Min returns the smaller of a and b; when neither is smaller, b is returned.
//
// T is constrained by cmp.Ordered from the standard library (Go 1.21+), which
// covers the integer, floating-point and string types that support the <
// operator. It replaces constraints.Ordered, which was referenced here without
// its package being imported.
func Min[T cmp.Ordered](a, b T) T {
    if a < b {
        return a
    }
    return b
}

// main computes the minimum of an int pair, a float pair and a string pair
// with the same generic Min, then prints the three results in that order.
func main() {
    smallestInt := Min(1, 2)
    smallestFloat := Min(3.14, 2.71)
    smallestWord := Min("apple", "banana")

    fmt.Println(smallestInt)
    fmt.Println(smallestFloat)
    fmt.Println(smallestWord)
}

注意: constraints.Ordered 定义在 golang.org/x/exp/constraints 包中;自 Go 1.21 起,标准库的 cmp.Ordered 提供了等价的约束。

自定义约束 #

go
package main

import "fmt"

// Number is a constraint satisfied by exactly the types int, int64 and
// float64.
type Number interface {
    int | int64 | float64
}

// Sum adds up every value in nums and returns the total. For an empty or nil
// slice it returns the zero value of T.
func Sum[T Number](nums []T) T {
    var acc T
    for i := 0; i < len(nums); i++ {
        acc = acc + nums[i]
    }
    return acc
}

// main sums an int slice and a float64 slice with the generic Sum and prints
// both totals.
func main() {
    intTotal := Sum([]int{1, 2, 3, 4, 5})
    floatTotal := Sum([]float64{1.1, 2.2, 3.3})

    fmt.Println(intTotal)
    fmt.Println(floatTotal)
}

comparable 约束 #

comparable 是内置约束,表示可以使用 == 和 != 比较的类型:

go
package main

import "fmt"

// Index reports the position of the first element of s equal to x, or -1 if
// no element matches. T must be comparable because elements are tested with ==.
func Index[T comparable](s []T, x T) int {
    for i := range s {
        if s[i] != x {
            continue
        }
        return i
    }
    return -1
}

// main looks up a value that is present in an int slice and one that is
// absent from a string slice, printing the index returned for each.
func main() {
    numbers := []int{10, 20, 30}
    letters := []string{"a", "b", "c"}

    fmt.Println(Index(numbers, 20))
    fmt.Println(Index(letters, "d"))
}

泛型类型 #

泛型结构体 #

go
package main

import "fmt"

// Stack is a generic last-in, first-out container backed by a slice.
// The zero value is an empty stack ready for use.
type Stack[T any] struct {
    elements []T
}

// Push places v on top of the stack.
func (s *Stack[T]) Push(v T) {
    s.elements = append(s.elements, v)
}

// Pop removes and returns the top element. The boolean result is false, and
// the returned value is the zero value of T, when the stack is empty.
func (s *Stack[T]) Pop() (T, bool) {
    var zero T
    last := len(s.elements) - 1
    if last < 0 {
        return zero, false
    }

    top := s.elements[last]
    // Clear the vacated slot so the backing array no longer references the
    // popped value; otherwise pointer-holding elements stay reachable.
    s.elements[last] = zero
    s.elements = s.elements[:last]
    return top, true
}

// Size returns the number of elements currently on the stack.
func (s *Stack[T]) Size() int {
    return len(s.elements)
}

// main instantiates Stack for int and for string, pushes two values onto
// each, and prints the value and ok flag returned by a single Pop per stack.
func main() {
    var numbers Stack[int]
    numbers.Push(1)
    numbers.Push(2)
    topNumber, numberOK := numbers.Pop()
    fmt.Println(topNumber, numberOK)

    var words Stack[string]
    words.Push("hello")
    words.Push("world")
    topWord, wordOK := words.Pop()
    fmt.Println(topWord, wordOK)
}

泛型映射 #

go
package main

import "fmt"

// Pair couples a key of comparable type K with a value of any type V.
type Pair[K comparable, V any] struct {
    Key   K
    Value V
}

// String renders the pair as "<key>: <value>", using fmt's default formatting
// for both parts.
func (p Pair[K, V]) String() string {
    keyText := fmt.Sprint(p.Key)
    valueText := fmt.Sprint(p.Value)
    return keyText + ": " + valueText
}

// main builds two Pair values with different type arguments and prints them;
// fmt.Println picks up Pair's String method for the output.
func main() {
    age := Pair[string, int]{Key: "age", Value: 30}
    rank := Pair[int, string]{Key: 1, Value: "first"}

    fmt.Println(age)
    fmt.Println(rank)
}

类型推断 #

Go编译器可以自动推断类型参数:

go
package main

import "fmt"

// Map applies f to each element of s, in order, and returns the results as a
// new slice with the same length as s.
func Map[T, U any](s []T, f func(T) U) []U {
    mapped := make([]U, 0, len(s))
    for _, item := range s {
        mapped = append(mapped, f(item))
    }
    return mapped
}

// main maps one int slice twice: first to doubled ints, then to formatted
// strings, printing each result. Both type arguments of Map are inferred.
func main() {
    source := []int{1, 2, 3, 4, 5}

    double := func(n int) int {
        return n * 2
    }
    label := func(n int) string {
        return fmt.Sprintf("num:%d", n)
    }

    fmt.Println(Map(source, double))
    fmt.Println(Map(source, label))
}

约束类型集 #

联合类型 #

go
package main

import "fmt"

// Signed matches every type whose underlying type is a signed integer.
type Signed interface {
    ~int | ~int8 | ~int16 | ~int32 | ~int64
}

// Unsigned matches every type whose underlying type is an unsigned integer,
// including uintptr.
type Unsigned interface {
    ~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 | ~uintptr
}

// Float matches every type whose underlying type is float32 or float64.
type Float interface {
    ~float32 | ~float64
}

// Integer is the union of the Signed and Unsigned constraints.
type Integer interface {
    Signed | Unsigned
}

// Number is the union of the Integer and Float constraints.
type Number interface {
    Integer | Float
}

// Abs returns -n when n is below zero and n itself otherwise.
func Abs[T Number](n T) T {
    magnitude := n
    if n < 0 {
        magnitude = -n
    }
    return magnitude
}

// main prints the absolute value of a negative int and of a negative float,
// both computed by the generic Abs.
func main() {
    intResult := Abs(-5)
    floatResult := Abs(-3.14)

    fmt.Println(intResult)
    fmt.Println(floatResult)
}

底层类型约束 #

使用 ~ 约束底层类型:

go
package main

import "fmt"

// MyInt is a named type whose underlying type is int.
type MyInt int

// IntLike matches every type whose underlying type is one of the signed
// integer types, so named types such as MyInt satisfy it as well.
type IntLike interface {
    ~int | ~int8 | ~int16 | ~int32 | ~int64
}

// Double returns twice the value of n, in the same type as n.
func Double[T IntLike](n T) T {
    return n + n
}

// main doubles a plain int and a MyInt value, showing that the ~ constraint
// accepts both the built-in type and a named type derived from it.
func main() {
    plain := 5
    named := MyInt(10)

    fmt.Println(Double(plain))
    fmt.Println(Double(named))
}

实用泛型函数 #

切片操作 #

go
package main

import "fmt"

// Filter returns, in their original order, the elements of s for which
// predicate reports true. The result is nil when no element is accepted.
func Filter[T any](s []T, predicate func(T) bool) []T {
    var kept []T
    for i := range s {
        item := s[i]
        if !predicate(item) {
            continue
        }
        kept = append(kept, item)
    }
    return kept
}

// Reduce folds s from left to right: starting with initial, it replaces the
// accumulator with f(accumulator, element) for each element and returns the
// final accumulator. For an empty slice it returns initial unchanged.
func Reduce[T, U any](s []T, initial U, f func(U, T) U) U {
    for i := range s {
        initial = f(initial, s[i])
    }
    return initial
}

// Contains reports whether at least one element of s is equal to x.
func Contains[T comparable](s []T, x T) bool {
    for i := 0; i < len(s); i++ {
        if s[i] == x {
            return true
        }
    }
    return false
}

// main runs Filter, Reduce and Contains over the integers 1 through 10 and
// prints each result with its label.
func main() {
    nums := []int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}

    isEven := func(n int) bool {
        return n%2 == 0
    }
    add := func(acc, n int) int {
        return acc + n
    }

    fmt.Println("偶数:", Filter(nums, isEven))
    fmt.Println("总和:", Reduce(nums, 0, add))
    fmt.Println("包含5:", Contains(nums, 5))
    fmt.Println("包含11:", Contains(nums, 11))
}

泛型队列 #

go
package main

import "fmt"

// Queue is a generic first-in, first-out container backed by a slice.
// The zero value is an empty queue ready for use.
type Queue[T any] struct {
    items []T
}

// Enqueue appends item to the back of the queue.
func (q *Queue[T]) Enqueue(item T) {
    q.items = append(q.items, item)
}

// Dequeue removes and returns the item at the front of the queue. The boolean
// result is false, and the returned value is the zero value of T, when the
// queue is empty.
func (q *Queue[T]) Dequeue() (T, bool) {
    var zero T
    if len(q.items) == 0 {
        return zero, false
    }

    front := q.items[0]
    // Clear the vacated slot before reslicing so the backing array no longer
    // references the dequeued value.
    q.items[0] = zero
    q.items = q.items[1:]
    return front, true
}

// IsEmpty reports whether the queue holds no items.
func (q *Queue[T]) IsEmpty() bool {
    return len(q.items) == 0
}

// main enqueues three strings and then drains the queue, printing the items
// in the order they were added.
func main() {
    var queue Queue[string]
    for _, name := range []string{"first", "second", "third"} {
        queue.Enqueue(name)
    }

    for {
        item, ok := queue.Dequeue()
        if !ok {
            break
        }
        fmt.Println(item)
    }
}

泛型链表 #

go
package main

import "fmt"

// Node is one element of a singly linked list: it stores a value and a
// pointer to the node that follows it (nil for the last node).
type Node[T any] struct {
    Value T
    Next  *Node[T]
}

// LinkedList is a generic singly linked list. Head points at the first node
// (nil when the list is empty) and Size counts the values added through Add.
type LinkedList[T any] struct {
    Head *Node[T]
    Size int
}

// Add appends value at the end of the list and increments Size. It walks the
// whole list to find the end.
func (l *LinkedList[T]) Add(value T) {
    // slot always addresses the pointer that should receive the new node:
    // first Head, then each successive Next field until a nil one is found.
    slot := &l.Head
    for *slot != nil {
        slot = &(*slot).Next
    }
    *slot = &Node[T]{Value: value}
    l.Size++
}

// ForEach calls f with each stored value, from the head of the list to the end.
func (l *LinkedList[T]) ForEach(f func(T)) {
    for node := l.Head; node != nil; node = node.Next {
        f(node.Value)
    }
}

// main appends 1, 2 and 3 to an int list and prints every value in list order.
func main() {
    var list LinkedList[int]
    for v := 1; v <= 3; v++ {
        list.Add(v)
    }

    printValue := func(v int) {
        fmt.Println(v)
    }
    list.ForEach(printValue)
}

泛型与接口 #

go
package main

import "fmt"

// Stringer is satisfied by any type that has a String() string method.
type Stringer interface {
    String() string
}

// PrintAll prints the String() result of every element of items, one per
// line, in slice order.
func PrintAll[T Stringer](items []T) {
    for i := range items {
        text := items[i].String()
        fmt.Println(text)
    }
}

// Person holds a name and an age.
type Person struct {
    Name string
    Age  int
}

// String formats the person as "<Name> (<Age>岁)".
func (p Person) String() string {
    // fmt.Sprint inserts no separator here because every adjacent pair of
    // operands includes a string.
    return fmt.Sprint(p.Name, " (", p.Age, "岁)")
}

func main() {
    people := []Person{
        {"张三", 30},
        {"李四", 25},
    }
    
    PrintAll(people)
}

小结 #

概念 说明
any 任意类型约束
comparable 可比较类型约束
类型参数 [T constraint]
泛型函数 func Name[T any](...)
泛型类型 type Name[T any] struct{}
类型推断 编译器自动推断类型

泛型是Go语言的强大特性,合理使用可以提高代码复用性和类型安全性。但也要注意,不是所有场景都需要泛型,有时候接口和代码生成可能是更好的选择。