表格驱动测试 #

表格驱动测试(Table Driven Tests)是Go语言推荐的测试模式,它将测试数据和测试逻辑分离,使代码更清晰、更易维护。

基本概念 #

传统测试 vs 表格驱动测试 #

传统方式:

go
func TestAddTraditional(t *testing.T) {
    if Add(1, 2) != 3 {
        t.Error("1 + 2 should be 3")
    }
    if Add(-1, 1) != 0 {
        t.Error("-1 + 1 should be 0")
    }
    if Add(0, 0) != 0 {
        t.Error("0 + 0 should be 0")
    }
}

表格驱动方式:

go
func TestAddTableDriven(t *testing.T) {
    tests := []struct {
        name     string
        a, b     int
        expected int
    }{
        {"正数相加", 1, 2, 3},
        {"负数相加", -1, -2, -3},
        {"正负混合", -1, 1, 0},
        {"零值测试", 0, 0, 0},
    }
    
    for _, tt := range tests {
        t.Run(tt.name, func(t *testing.T) {
            if got := Add(tt.a, tt.b); got != tt.expected {
                t.Errorf("Add(%d, %d) = %d; want %d", 
                    tt.a, tt.b, got, tt.expected)
            }
        })
    }
}

基本结构 #

测试用例结构体 #

go
package myapp

import "testing"

func TestMultiply(t *testing.T) {
    tests := []struct {
        name     string
        a        int
        b        int
        expected int
    }{
        {
            name:     "正数相乘",
            a:        3,
            b:        4,
            expected: 12,
        },
        {
            name:     "负数相乘",
            a:        -3,
            b:        4,
            expected: -12,
        },
        {
            name:     "零乘任何数",
            a:        0,
            b:        100,
            expected: 0,
        },
    }
    
    for _, tt := range tests {
        t.Run(tt.name, func(t *testing.T) {
            got := Multiply(tt.a, tt.b)
            if got != tt.expected {
                t.Errorf("Multiply(%d, %d) = %d; want %d",
                    tt.a, tt.b, got, tt.expected)
            }
        })
    }
}

处理错误返回 #

测试返回错误的函数 #

go
package myapp

import (
    "errors"
    "testing"
)

func Divide(a, b float64) (float64, error) {
    if b == 0 {
        return 0, errors.New("division by zero")
    }
    return a / b, nil
}

func TestDivide(t *testing.T) {
    tests := []struct {
        name      string
        a, b      float64
        expected  float64
        expectErr bool
    }{
        {"正常除法", 10, 2, 5, false},
        {"除数为负", 10, -2, -5, false},
        {"除零错误", 10, 0, 0, true},
        {"零被除", 0, 5, 0, false},
    }
    
    for _, tt := range tests {
        t.Run(tt.name, func(t *testing.T) {
            got, err := Divide(tt.a, tt.b)
            
            if tt.expectErr {
                if err == nil {
                    t.Error("expected error but got nil")
                }
                return
            }
            
            if err != nil {
                t.Fatalf("unexpected error: %v", err)
            }
            
            if got != tt.expected {
                t.Errorf("Divide(%f, %f) = %f; want %f",
                    tt.a, tt.b, got, tt.expected)
            }
        })
    }
}

测试特定错误类型 #

go
package myapp

import (
    "errors"
    "testing"
)

var ErrNotFound = errors.New("not found")
var ErrPermission = errors.New("permission denied")

func GetData(id int) (string, error) {
    switch id {
    case 0:
        return "", ErrNotFound
    case -1:
        return "", ErrPermission
    default:
        return "data", nil
    }
}

func TestGetData(t *testing.T) {
    tests := []struct {
        name      string
        id        int
        expected  string
        expectErr error
    }{
        {"正常获取", 1, "data", nil},
        {"未找到", 0, "", ErrNotFound},
        {"权限拒绝", -1, "", ErrPermission},
    }
    
    for _, tt := range tests {
        t.Run(tt.name, func(t *testing.T) {
            got, err := GetData(tt.id)
            
            if !errors.Is(err, tt.expectErr) {
                t.Errorf("expected error %v, got %v", tt.expectErr, err)
            }
            
            if got != tt.expected {
                t.Errorf("GetData(%d) = %s; want %s", tt.id, got, tt.expected)
            }
        })
    }
}

测试复杂数据类型 #

测试结构体 #

go
package myapp

import (
    "reflect"
    "testing"
)

type Person struct {
    Name string
    Age  int
    City string
}

func NewPerson(name string, age int, city string) *Person {
    return &Person{
        Name: name,
        Age:  age,
        City: city,
    }
}

func TestNewPerson(t *testing.T) {
    tests := []struct {
        name     string
        input    []interface{}
        expected *Person
    }{
        {
            name:  "完整信息",
            input: []interface{}{"张三", 30, "北京"},
            expected: &Person{
                Name: "张三",
                Age:  30,
                City: "北京",
            },
        },
        {
            name:  "不同信息",
            input: []interface{}{"李四", 25, "上海"},
            expected: &Person{
                Name: "李四",
                Age:  25,
                City: "上海",
            },
        },
    }
    
    for _, tt := range tests {
        t.Run(tt.name, func(t *testing.T) {
            got := NewPerson(
                tt.input[0].(string),
                tt.input[1].(int),
                tt.input[2].(string),
            )
            
            if !reflect.DeepEqual(got, tt.expected) {
                t.Errorf("got %+v, want %+v", got, tt.expected)
            }
        })
    }
}

使用 cmp 比较 #

go
package myapp

import (
    "testing"
    
    "github.com/google/go-cmp/cmp"
)

func TestWithCmp(t *testing.T) {
    tests := []struct {
        name     string
        input    string
        expected Person
    }{
        {
            name:  "测试1",
            input: "张三,30,北京",
            expected: Person{
                Name: "张三",
                Age:  30,
                City: "北京",
            },
        },
    }
    
    for _, tt := range tests {
        t.Run(tt.name, func(t *testing.T) {
            got := ParsePerson(tt.input)
            
            if diff := cmp.Diff(got, tt.expected); diff != "" {
                t.Errorf("mismatch (-got +want):\n%s", diff)
            }
        })
    }
}

测试切片和映射 #

测试切片操作 #

go
package myapp

import "testing"

func FilterEven(nums []int) []int {
    var result []int
    for _, n := range nums {
        if n%2 == 0 {
            result = append(result, n)
        }
    }
    return result
}

func TestFilterEven(t *testing.T) {
    tests := []struct {
        name     string
        input    []int
        expected []int
    }{
        {"混合数字", []int{1, 2, 3, 4, 5, 6}, []int{2, 4, 6}},
        {"全是偶数", []int{2, 4, 6, 8}, []int{2, 4, 6, 8}},
        {"全是奇数", []int{1, 3, 5, 7}, []int{}},
        {"空切片", []int{}, []int{}},
    }
    
    for _, tt := range tests {
        t.Run(tt.name, func(t *testing.T) {
            got := FilterEven(tt.input)
            
            if len(got) != len(tt.expected) {
                t.Errorf("length mismatch: got %d, want %d",
                    len(got), len(tt.expected))
                return
            }
            
            for i := range got {
                if got[i] != tt.expected[i] {
                    t.Errorf("index %d: got %d, want %d",
                        i, got[i], tt.expected[i])
                }
            }
        })
    }
}

测试映射操作 #

go
package myapp

import "testing"

func MergeMaps(m1, m2 map[string]int) map[string]int {
    result := make(map[string]int)
    for k, v := range m1 {
        result[k] = v
    }
    for k, v := range m2 {
        result[k] = v
    }
    return result
}

func TestMergeMaps(t *testing.T) {
    tests := []struct {
        name     string
        m1       map[string]int
        m2       map[string]int
        expected map[string]int
    }{
        {
            name:     "无重叠",
            m1:       map[string]int{"a": 1, "b": 2},
            m2:       map[string]int{"c": 3, "d": 4},
            expected: map[string]int{"a": 1, "b": 2, "c": 3, "d": 4},
        },
        {
            name:     "有重叠",
            m1:       map[string]int{"a": 1, "b": 2},
            m2:       map[string]int{"b": 3, "c": 4},
            expected: map[string]int{"a": 1, "b": 3, "c": 4},
        },
        {
            name:     "空映射",
            m1:       map[string]int{},
            m2:       map[string]int{"a": 1},
            expected: map[string]int{"a": 1},
        },
    }
    
    for _, tt := range tests {
        t.Run(tt.name, func(t *testing.T) {
            got := MergeMaps(tt.m1, tt.m2)
            
            if len(got) != len(tt.expected) {
                t.Errorf("length mismatch: got %d, want %d",
                    len(got), len(tt.expected))
                return
            }
            
            for k, v := range tt.expected {
                if got[k] != v {
                    t.Errorf("key %s: got %d, want %d", k, got[k], v)
                }
            }
        })
    }
}

边界条件测试 #

go
package myapp

import "testing"

func GetItem(slice []int, index int) (int, error) {
    if index < 0 || index >= len(slice) {
        return 0, ErrIndexOutOfRange
    }
    return slice[index], nil
}

var ErrIndexOutOfRange = errors.New("index out of range")

func TestGetItem(t *testing.T) {
    slice := []int{10, 20, 30}
    
    tests := []struct {
        name      string
        index     int
        expected  int
        expectErr bool
    }{
        {"第一个元素", 0, 10, false},
        {"最后一个元素", 2, 30, false},
        {"中间元素", 1, 20, false},
        {"负索引", -1, 0, true},
        {"超出范围", 3, 0, true},
        {"远超范围", 100, 0, true},
    }
    
    for _, tt := range tests {
        t.Run(tt.name, func(t *testing.T) {
            got, err := GetItem(slice, tt.index)
            
            if tt.expectErr {
                if err == nil {
                    t.Error("expected error but got nil")
                }
                return
            }
            
            if err != nil {
                t.Fatalf("unexpected error: %v", err)
            }
            
            if got != tt.expected {
                t.Errorf("got %d, want %d", got, tt.expected)
            }
        })
    }
}

运行特定子测试 #

bash
go test -v -run TestDivide/正常除法
go test -v -run TestDivide/除零

小结 #

表格驱动测试的优势:

优势 说明
可读性 测试数据和逻辑分离,易于理解
可维护性 添加新测试用例只需添加数据
完整性 容易覆盖各种边界条件
可调试性 每个用例独立运行,便于定位问题

表格驱动测试是Go测试的最佳实践,掌握这种模式可以编写出高质量的测试代码。