Higher-Order Functions

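1. Overview of Higher-Order Functions

1.1 What Is a Higher-Order Function

A higher-order function is a function that satisfies at least one of the following conditions:

  1. It takes one or more functions as parameters.
  2. It returns a function as its result.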
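A single function can meet both conditions at once. The hypothetical `twice` below takes a function and returns a new function that applies it two times:

```scala
// Hypothetical helper: takes a function AND returns a function
def twice(f: Int => Int): Int => Int =
  x => f(f(x))

val addThreeTwice = twice(_ + 3)
val result = addThreeTwice(10) // 16
```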

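1.2 Function Types

A function type puts the parameter types to the left of `=>` and the result type to the right:

```scala
// A function from Int to Int
val double: Int => Int = x => x * 2
// A function of two Int parameters
val add: (Int, Int) => Int = (a, b) => a + b
// => is right-associative: this is String => (Int => String)
val process: String => Int => String = s => n => s * n

process("ab")(3) // "ababab"
```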
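Function types are shorthand for the standard library traits `Function1`, `Function2`, and so on. Function values are therefore ordinary objects with methods such as `apply`, `tupled`, and `curried`. A small sketch:

```scala
// Int => Int is shorthand for Function1[Int, Int]
val double: Function1[Int, Int] = x => x * 2
// (Int, Int) => Int is shorthand for Function2[Int, Int, Int]
val add: Function2[Int, Int, Int] = (a, b) => a + b

val viaApply = double.apply(4)     // same as double(4): 8
val viaTupled = add.tupled((2, 3)) // takes one (Int, Int) tuple: 5
val viaCurried = add.curried(2)(3) // Int => Int => Int: 5
```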

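2. Functions as Parameters

2.1 Basic Example

```scala
// applyFunction is higher-order: its parameter f is a function
def applyFunction(x: Int, f: Int => Int): Int = f(x)

def double(x: Int): Int = x * 2
def square(x: Int): Int = x * x

// Methods are converted to function values automatically (eta-expansion)
applyFunction(5, double) // 10
applyFunction(5, square) // 25
```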
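The same idea generalizes with type parameters, so the input and output types may differ. `transform` is an illustrative name, not a library function:

```scala
// Hypothetical generic variant of applyFunction
def transform[A, B](x: A, f: A => B): B = f(x)

val length = transform("hello", (s: String) => s.length) // 5
val label = transform(42, (n: Int) => s"#$n")            // "#42"
```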

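2.2 Using Anonymous Functions

```scala
// Full lambda syntax
applyFunction(5, x => x * 2) // 10
// Placeholder syntax: _ * 2 expands to x => x * 2
applyFunction(5, _ * 2)      // 10
```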
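In placeholder syntax each underscore stands for a separate parameter, in order of appearance. So `_ + _` is a two-argument function, and `_ * _` does not mean squaring:

```scala
// Two underscores = two parameters: (a, b) => a + b
val add: (Int, Int) => Int = _ + _
// This is (a, b) => a * b, NOT x => x * x
val times: (Int, Int) => Int = _ * _
// To use one argument twice, name it explicitly
val square: Int => Int = x => x * x

val sum = add(3, 4)       // 7
val product = times(3, 4) // 12
val sq = square(4)        // 16
```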

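2.3 Multiple Function Parameters

```scala
// compose takes two functions and applies g first, then f
def compose[A, B, C](f: B => C, g: A => B)(x: A): C =
  f(g(x))

val addOne = (x: Int) => x + 1
val double = (x: Int) => x * 2

compose(double, addOne)(5) // double(addOne(5)) = 12
```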
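Because `A`, `B`, and `C` are independent type parameters, `compose` also chains functions whose types differ. This sketch repeats the definition so that it runs on its own:

```scala
def compose[A, B, C](f: B => C, g: A => B)(x: A): C =
  f(g(x))

val repeatX = (n: Int) => "x" * n    // Int => String
val length = (s: String) => s.length // String => Int

// A = Int, B = String, C = Int
val roundTrip = compose(length, repeatX)(3) // 3
// A = String, B = Int, C = String
val relabel = compose(repeatX, length)("abcd") // "xxxx"
```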

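3. Functions as Return Values

3.1 Returning a Function

```scala
// multiplier returns a new function that multiplies by factor
def multiplier(factor: Int): Int => Int =
  x => x * factor

val double = multiplier(2)
val triple = multiplier(3)

double(5) // 10
triple(5) // 15
```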
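A function can also decide at runtime which function to return. The `operation` function and its operator strings below are hypothetical:

```scala
// Hypothetical: pick an implementation based on an operator symbol
def operation(op: String): (Int, Int) => Int = op match
  case "+" => _ + _
  case "-" => _ - _
  case "*" => _ * _
  case other => throw new IllegalArgumentException(s"Unknown operator: $other")

val plus = operation("+")
val minus = operation("-")

val sumResult = plus(7, 3)         // 10
val differenceResult = minus(7, 3) // 4
```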

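3.2 Closures

A returned function can capture variables from its enclosing scope:

```scala
def makeCounter(): () => Int =
  var count = 0 // local state captured by the returned function
  () =>
    count += 1
    count

val counter = makeCounter()
counter() // 1
counter() // 2
counter() // 3
```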
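Each call to `makeCounter` creates a fresh `count` variable, so separate counters do not share state. This sketch repeats the definition so that it runs on its own:

```scala
def makeCounter(): () => Int =
  var count = 0
  () =>
    count += 1
    count

val a = makeCounter()
val b = makeCounter()

val a1 = a() // 1
val a2 = a() // 2
val b1 = b() // 1: b closes over its own count
```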

3.3 Configurable Functions

```scala
// createValidator captures min and max and returns a range check
def createValidator(min: Int, max: Int): Int => Boolean =
  x => x >= min && x <= max

val isValidAge = createValidator(0, 150)
val isValidPercent = createValidator(0, 100)

isValidAge(25)      // true
isValidPercent(105) // false
```
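The same factory can be made generic over any ordered type by requesting an `Ordering[T]`. This sketch uses Scala 3 `using` syntax, and the name `between` is hypothetical:

```scala
// Hypothetical generic range validator for any type with an Ordering
def between[T](min: T, max: T)(using ord: Ordering[T]): T => Boolean =
  x => ord.gteq(x, min) && ord.lteq(x, max)

val isMidAlphabet = between("h", "p") // String ordering
val isSingleDigit = between(0, 9)     // Int ordering

isMidAlphabet("kiwi") // true
isSingleDigit(10)     // false
```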

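4. Common Higher-Order Functions

4.1 map

Applies a function to every element of a collection and collects the results:

```scala
val numbers = List(1, 2, 3, 4, 5)

val doubled = numbers.map(_ * 2)          // List(2, 4, 6, 8, 10)
val squared = numbers.map(x => x * x)     // List(1, 4, 9, 16, 25)
val stringified = numbers.map(_.toString) // List("1", "2", "3", "4", "5")
```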
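`map` is itself an ordinary higher-order function. The hand-written `List` version below illustrates the idea. The name `myMap` is illustrative, and this is not how the standard library implements `map`:

```scala
// Illustrative only: build the result by folding from the right
def myMap[A, B](xs: List[A])(f: A => B): List[B] =
  xs.foldRight(List.empty[B])((a, acc) => f(a) :: acc)

val doubled = myMap(List(1, 2, 3))(_ * 2)       // List(2, 4, 6)
val lengths = myMap(List("a", "bcd"))(_.length) // List(1, 3)
```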

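4.2 filter

Keeps the elements that satisfy a predicate:

```scala
val numbers = List(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)

val evens = numbers.filter(_ % 2 == 0) // List(2, 4, 6, 8, 10)
val positives = numbers.filter(_ > 0)  // all ten elements
val large = numbers.filter(_ > 5)      // List(6, 7, 8, 9, 10)
```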
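A recursive sketch shows the same idea for `filter`. The name `myFilter` is illustrative. The function is not tail-recursive, so a very long list could overflow the stack:

```scala
// Illustrative only: keep head when the predicate holds, then recurse
def myFilter[A](xs: List[A])(p: A => Boolean): List[A] = xs match
  case Nil => Nil
  case head :: tail =>
    if p(head) then head :: myFilter(tail)(p)
    else myFilter(tail)(p)

val evens = myFilter(List(1, 2, 3, 4, 5, 6))(_ % 2 == 0) // List(2, 4, 6)
```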

4.3 filterNot

Keeps the elements that do not satisfy a predicate:

```scala
val numbers = List(1, 2, 3, 4, 5)

val odds = numbers.filterNot(_ % 2 == 0) // List(1, 3, 5)
```

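4.4 flatMap

Maps each element to a collection, then flattens the results into one collection:

```scala
val lists = List(List(1, 2), List(3, 4), List(5, 6))

// Both give List(1, 2, 3, 4, 5, 6)
val flattened = lists.flatMap(identity)
val flattened2 = lists.flatten

val words = List("Hello", "World")
// Each String becomes a List[Char]: List(H, e, l, l, o, W, o, r, l, d)
val chars = words.flatMap(_.toList)
```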
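`flatMap` also accepts functions that return an `Option`, which behaves like a collection of zero or one elements. This gives a compact way to map values and drop failures in one step. `toIntOption` is available on strings in Scala 2.13 and later:

```scala
val inputs = List("1", "two", "3")

// "two" parses to None and is dropped
val parsed = inputs.flatMap(_.toIntOption) // List(1, 3)

// flatMap behaves like map followed by flatten
val viaMapFlatten = inputs.map(_.toIntOption).flatten // List(1, 3)
```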

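4.5 fold / foldLeft / foldRight

```scala
val numbers = List(1, 2, 3, 4, 5)

// fold: the start value should be neutral and the operation associative,
// because the evaluation order is unspecified
val total = numbers.fold(0)(_ + _) // 15

// foldLeft combines from the left: ((0 + 1) + 2) + ...
val sum = numbers.foldLeft(0)(_ + _)     // 15
val product = numbers.foldLeft(1)(_ * _) // 120
val reversed = numbers.foldLeft(List.empty[Int])((acc, x) => x :: acc) // List(5, 4, 3, 2, 1)

// foldRight combines from the right: 1 + (2 + (... + 0))
val sumRight = numbers.foldRight(0)(_ + _) // 15
// acc :+ x copies the list each time (O(n^2) overall); foldLeft with :: is cheaper
val reversedRight = numbers.foldRight(List.empty[Int])((x, acc) => acc :+ x) // List(5, 4, 3, 2, 1)
```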
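With a non-associative operation such as subtraction, `foldLeft` and `foldRight` give different results, which makes the evaluation order visible:

```scala
val numbers = List(1, 2, 3, 4, 5)

// ((((0 - 1) - 2) - 3) - 4) - 5
val left = numbers.foldLeft(0)(_ - _) // -15
// 1 - (2 - (3 - (4 - (5 - 0))))
val right = numbers.foldRight(0)(_ - _) // 3
```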

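4.6 reduce / reduceLeft / reduceRight

```scala
val numbers = List(1, 2, 3, 4, 5)

// reduce uses the first element as the starting value
val sum = numbers.reduce(_ + _)   // 15
val max = numbers.reduce(_ max _) // 5
val min = numbers.reduce(_ min _) // 1

// With a non-associative operation, the direction matters
val left = numbers.reduceLeft(_ - _)   // ((((1 - 2) - 3) - 4) - 5) = -13
val right = numbers.reduceRight(_ - _) // 1 - (2 - (3 - (4 - 5))) = 3
```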
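Unlike `fold`, `reduce` has no starting value, so it throws an `UnsupportedOperationException` on an empty collection. `reduceOption` returns an `Option` instead:

```scala
val numbers = List(1, 2, 3, 4, 5)
val empty = List.empty[Int]

val maxOpt = numbers.reduceOption(_ max _) // Some(5)
val emptyMax = empty.reduceOption(_ max _) // None

// A fold with an explicit start value is also safe on empty input
val emptySum = empty.foldLeft(0)(_ + _) // 0
```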

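4.7 groupBy

Groups elements by the key that a function computes for each one:

```scala
val words = List("apple", "banana", "cherry", "date", "elderberry")

// Map[Int, List[String]]; e.g. byLength(6) is List(banana, cherry)
val byLength = words.groupBy(_.length)
// Map[Char, List[String]], keyed by first character
val byFirstChar = words.groupBy(_.head)
```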
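When you need one aggregate per group rather than the group itself, Scala 2.13 and later also provide `groupMapReduce`. It groups, maps, and reduces in a single pass. Counting words by length, for example:

```scala
val words = List("apple", "banana", "cherry", "date", "elderberry")

// key = length, value = 1 per word, combined with +
val countByLength = words.groupMapReduce(_.length)(_ => 1)(_ + _)
// Map(5 -> 1, 6 -> 2, 4 -> 1, 10 -> 1) (Map ordering is not guaranteed)

// Same result, but builds the intermediate lists first
val viaGroupBy = words.groupBy(_.length).view.mapValues(_.size).toMap
```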

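4.8 partition

Splits a collection into two groups by a predicate:

```scala
val numbers = List(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)

// (elements that satisfy the predicate, elements that do not)
val (evens, odds) = numbers.partition(_ % 2 == 0) // (List(2, 4, 6, 8, 10), List(1, 3, 5, 7, 9))
val (small, large) = numbers.partition(_ < 5)     // (List(1, 2, 3, 4), List(5, 6, 7, 8, 9, 10))
```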
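`partition(p)` produces the same pair as calling `filter(p)` and `filterNot(p)` separately, but it traverses the collection only once:

```scala
val numbers = List(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)
val isEven: Int => Boolean = _ % 2 == 0

val pair = numbers.partition(isEven)
val twoPasses = (numbers.filter(isEven), numbers.filterNot(isEven))
```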

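4.9 takeWhile / dropWhile

`takeWhile` returns the longest prefix whose elements all satisfy the predicate. `dropWhile` returns everything after that prefix:

```scala
val numbers = List(1, 2, 3, 4, 5, 1, 2, 3)

val firstPart = numbers.takeWhile(_ < 4) // List(1, 2, 3)
val rest = numbers.dropWhile(_ < 4)      // List(4, 5, 1, 2, 3)
```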
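Both stop at the first element that fails the predicate, and this is what distinguishes them from `filter`:

```scala
val numbers = List(1, 2, 3, 4, 5, 1, 2, 3)

val prefix = numbers.takeWhile(_ < 4) // List(1, 2, 3): stops at 4
val allSmall = numbers.filter(_ < 4)  // List(1, 2, 3, 1, 2, 3): checks every element
```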

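4.10 span

`span(p)` returns `(takeWhile(p), dropWhile(p))` as a pair, computed together:

```scala
val numbers = List(1, 2, 3, 4, 5, 1, 2, 3)

val (first, second) = numbers.span(_ < 4) // (List(1, 2, 3), List(4, 5, 1, 2, 3))
```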

5. Function Composition

5.1 compose

`f compose g` applies `g` first, then `f`:

```scala
val addOne = (x: Int) => x + 1
val double = (x: Int) => x * 2

val addOneThenDouble = double compose addOne // x => double(addOne(x))
val doubleThenAddOne = addOne compose double // x => addOne(double(x))

addOneThenDouble(5) // 12
doubleThenAddOne(5) // 11
```

5.2 andThen

`f andThen g` applies `f` first, then `g`:

```scala
val addOne = (x: Int) => x + 1
val double = (x: Int) => x * 2

val addOneThenDouble = addOne andThen double // x => double(addOne(x))
val doubleThenAddOne = double andThen addOne // x => addOne(double(x))

addOneThenDouble(5) // 12
doubleThenAddOne(5) // 11
```
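`andThen` and `compose` are mirror images: `f andThen g` behaves like `g compose f`. A quick check over a range of inputs:

```scala
val addOne = (x: Int) => x + 1
val double = (x: Int) => x * 2

val viaAndThen = addOne andThen double
val viaCompose = double compose addOne

val agree = (-10 to 10).forall(x => viaAndThen(x) == viaCompose(x)) // true
```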

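5.3 Chained Composition

```scala
// ((_: Int) * 2) is the lambda x => x * 2; the parentheses delimit it
val process = ((_: Int) * 2)
  .andThen(_ + 1)
  .andThen(_ * 3)

process(5) // ((5 * 2) + 1) * 3 = 33
```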
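When the steps are only known at runtime, a list of functions can be folded into one pipeline, starting from the identity function. The standard library's `Function.chain` does the same job:

```scala
val steps: List[Int => Int] = List(_ * 2, _ + 1, _ * 3)

// Start from the identity function and chain each step with andThen
val pipeline: Int => Int = steps.foldLeft((x: Int) => x)(_ andThen _)

// Library equivalent: applies the functions in list order
val viaChain: Int => Int = Function.chain(steps)

val result = pipeline(5) // 33
```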

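6. Partially Applied Functions

6.1 Fixing Some Arguments

```scala
def add(a: Int, b: Int, c: Int): Int = a + b + c

// Fix a = 5; the two placeholders become the new function's parameters
val addFive = add(5, _: Int, _: Int)
addFive(10, 20) // 35

// Fix a = 5 and b = 10
val addFiveAndTen = add(5, 10, _: Int)
addFiveAndTen(20) // 35
```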
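Conceptually, a partially applied function is just a new function that closes over the fixed arguments. The hypothetical helper `partial` below does this explicitly for two-argument functions:

```scala
// Hypothetical helper: fix the first argument of a two-argument function
def partial[A, B, C](f: (A, B) => C, a: A): B => C =
  b => f(a, b)

val greet: (String, String) => String = (greeting, name) => s"$greeting, $name!"
val hello = partial(greet, "Hello")

val message = hello("Scala") // "Hello, Scala!"
```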

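6.2 Use Case

```scala
def log(level: String, message: String): Unit =
  println(s"[$level] $message")

// Fix the level to get specialized loggers of type String => Unit
val infoLog = log("INFO", _: String)
val errorLog = log("ERROR", _: String)

infoLog("Application started")   // prints [INFO] Application started
errorLog("Something went wrong") // prints [ERROR] Something went wrong
```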
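A partially applied function is an ordinary function value, so you can pass it straight to another higher-order function. This variant returns the formatted line instead of printing it. `formatLine` is a hypothetical name:

```scala
def formatLine(level: String, message: String): String =
  s"[$level] $message"

val info = formatLine("INFO", _: String)

val lines = List("started", "stopped").map(info)
// List("[INFO] started", "[INFO] stopped")
```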

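7. Currying

7.1 Defining a Curried Function

```scala
// Each parameter list is supplied separately
def add(a: Int)(b: Int): Int = a + b

// The explicit type Int => Int lets the compiler turn add(5) into a function;
// older code writes this as add(5) _
val addFive: Int => Int = add(5)
addFive(10) // 15
```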
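Separate parameter lists also help type inference. Once the first list fixes a type parameter, lambdas in later lists need no type annotations, and a single-argument list can be passed as a block. `applyTo` is a hypothetical example:

```scala
// A is inferred from the first parameter list, so f's parameter type is known
def applyTo[A, B](x: A)(f: A => B): B = f(x)

val len = applyTo("hello")(_.length) // 5

// The last parameter list passed as a block
val described = applyTo(42) { n =>
  val parity = if n % 2 == 0 then "even" else "odd"
  s"$n is $parity"
}
// "42 is even"
```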

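7.2 Currying an Existing Function

```scala
def add(a: Int, b: Int): Int = a + b

// Eta-expand the method into a Function2, then convert it with .curried
val curriedAdd: Int => Int => Int = (add _).curried
val addFive = curriedAdd(5)
addFive(10) // 15
```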
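The conversion also works in the other direction. `Function.uncurried` in the standard library turns a curried function back into a multi-argument one:

```scala
val curriedAdd: Int => Int => Int = a => b => a + b

val add: (Int, Int) => Int = Function.uncurried(curriedAdd)
val result = add(2, 3) // 5
```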

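7.3 A Practical Example

```scala
def filterBy(predicate: Int => Boolean)(list: List[Int]): List[Int] =
  list.filter(predicate)

// Fix the predicate; each result is a reusable List[Int] => List[Int]
val filterEvens: List[Int] => List[Int] = filterBy(_ % 2 == 0)
val filterPositives: List[Int] => List[Int] = filterBy(_ > 0)

filterEvens(List(1, 2, 3, 4, 5))   // List(2, 4)
filterPositives(List(-1, 0, 1, 2)) // List(1, 2)
```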
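Because the partially applied filters are plain function values, they compose with `andThen`. A sketch:

```scala
def filterBy(predicate: Int => Boolean)(list: List[Int]): List[Int] =
  list.filter(predicate)

val positives: List[Int] => List[Int] = filterBy(_ > 0)
val evens: List[Int] => List[Int] = filterBy(_ % 2 == 0)

// Keep positive numbers, then keep the even ones among them
val positiveEvens = positives andThen evens

val result = positiveEvens(List(-4, -1, 2, 3, 6)) // List(2, 6)
```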

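8. Higher-Order Functions in Practice

8.1 A Data-Processing Pipeline

```scala
case class Person(name: String, age: Int, salary: Double)

val people = List(
  Person("Alice", 25, 50000),
  Person("Bob", 35, 75000),
  Person("Charlie", 28, 60000),
  Person("David", 45, 90000)
)

// Total salary of everyone aged 30 or older: 165000.0
val result = people
  .filter(_.age >= 30)
  .map(_.salary)
  .foldLeft(0.0)(_ + _) // unlike reduce, returns 0.0 instead of throwing if no one matches
```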
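The same data also works with `groupBy`. This sketch computes the average salary per age decade. Bucketing by `age / 10 * 10` is a choice made for illustration:

```scala
case class Person(name: String, age: Int, salary: Double)

val people = List(
  Person("Alice", 25, 50000),
  Person("Bob", 35, 75000),
  Person("Charlie", 28, 60000),
  Person("David", 45, 90000)
)

// Decade bucket -> average salary within that bucket
val avgByDecade: Map[Int, Double] = people
  .groupBy(_.age / 10 * 10)
  .view
  .mapValues(ps => ps.map(_.salary).sum / ps.size)
  .toMap
// Map(20 -> 55000.0, 30 -> 75000.0, 40 -> 90000.0)
```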

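8.2 Composing Validators

```scala
// A validator is simply a predicate
type Validator[T] = T => Boolean

// Passes only if every validator passes
def validateAll[T](validators: Validator[T]*): Validator[T] =
  x => validators.forall(_(x))

// Passes if at least one validator passes
def validateAny[T](validators: Validator[T]*): Validator[T] =
  x => validators.exists(_(x))

val isAdult: Validator[Int] = _ >= 18
val isNotTooOld: Validator[Int] = _ <= 100
val isValidAge = validateAll(isAdult, isNotTooOld)

isValidAge(25)  // true
isValidAge(150) // false

val isChildOrSenior = validateAny[Int](_ < 18, _ >= 65)
isChildOrSenior(70) // true
```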
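An alternative to variadic helpers is to give predicates their own combinators. This sketch uses a Scala 3 extension. The method names `and`, `or`, and `negate` are assumptions, not standard library API:

```scala
type Validator[T] = T => Boolean

// Hypothetical combinators on any predicate
extension [T](v: Validator[T])
  def and(other: Validator[T]): Validator[T] = x => v(x) && other(x)
  def or(other: Validator[T]): Validator[T] = x => v(x) || other(x)
  def negate: Validator[T] = x => !v(x)

val isAdult: Validator[Int] = _ >= 18
val isNotTooOld: Validator[Int] = _ <= 100

val isValidAge = isAdult.and(isNotTooOld)
val isMinor = isAdult.negate

val ok = isValidAge(25)      // true
val tooOld = isValidAge(150) // false
val minor = isMinor(12)      // true
```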

8.3 The Builder Pattern

Each step returns a `Query => Query` function, and `pipe` (from `scala.util.chaining`) applies it to the current query:

```scala
import scala.util.chaining.* // provides .pipe (Scala 2.13+)

case class Query(
  table: String,
  columns: List[String] = List("*"),
  where: Option[String] = None,
  orderBy: Option[String] = None,
  limit: Option[Int] = None
)

def select(columns: String*): Query => Query =
  q => q.copy(columns = columns.toList)

def from(table: String): Query = Query(table)

def where(condition: String): Query => Query =
  q => q.copy(where = Some(condition))

def orderBy(column: String): Query => Query =
  q => q.copy(orderBy = Some(column))

def limit(n: Int): Query => Query =
  q => q.copy(limit = Some(n))

val query = from("users")
  .pipe(select("id", "name", "email"))
  .pipe(where("age > 18"))
  .pipe(orderBy("created_at"))
  .pipe(limit(10))
```
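Since every step is a `Query => Query`, you can also compose the steps with `andThen` into a reusable template before any query exists. A trimmed, self-contained sketch (the `Query` fields are reduced for brevity):

```scala
case class Query(
  table: String,
  columns: List[String] = List("*"),
  where: Option[String] = None,
  limit: Option[Int] = None
)

def select(columns: String*): Query => Query = q => q.copy(columns = columns.toList)
def where(condition: String): Query => Query = q => q.copy(where = Some(condition))
def limit(n: Int): Query => Query = q => q.copy(limit = Some(n))

// A reusable template, defined before any table is chosen
val firstTenAdults: Query => Query =
  select("id", "name") andThen where("age > 18") andThen limit(10)

val users = firstTenAdults(Query("users"))
val admins = firstTenAdults(Query("admins"))
```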

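9. Performance Considerations

9.1 Avoiding Repeated Traversals

Each strict `map` or `filter` on a `List` traverses the list and builds a new intermediate list. `collect` combines filtering and mapping in a single pass:

```scala
val list = List("  hello ", "   ", "world ") // sample input

// Three traversals, two intermediate lists
val result = list
  .map(_.trim)
  .filter(_.nonEmpty)
  .map(_.toUpperCase)

// One traversal (note: trim runs twice for each kept element)
val result2 = list.collect {
  case s if s.trim.nonEmpty => s.trim.toUpperCase
}
// both: List(HELLO, WORLD)
```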
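Two other single-pass options avoid trimming twice. One runs the original chain on an `iterator`, whose steps are lazy, so each element flows through all three stages in turn. The other does all the work inside one `flatMap`. A sketch with the same sample input:

```scala
val list = List("  hello ", "   ", "world ")

// Iterator stages are lazy; no intermediate lists are built
val viaIterator = list.iterator
  .map(_.trim)
  .filter(_.nonEmpty)
  .map(_.toUpperCase)
  .toList

// Trim once, keep the element only if it is non-empty
val viaFlatMap = list.flatMap { s =>
  val t = s.trim
  if t.nonEmpty then Some(t.toUpperCase) else None
}
// both: List("HELLO", "WORLD")
```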

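9.2 Lazy Evaluation with view

A `view` defers `map` and `filter` until the result is forced. No intermediate lists are built, and `take(5)` stops the work once it has five elements:

```scala
val list = (1 to 100).toList // sample input

val result = list.view
  .map(_ * 2)
  .filter(_ > 10)
  .take(5)
  .toList // List(12, 14, 16, 18, 20)
```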
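One way to observe the laziness is to count how often the mapping function runs. This sketch uses mutable counters purely for demonstration:

```scala
val list = (1 to 100).toList

// Strict: map runs over the whole list before filter and take
var strictCalls = 0
val strict = list.map { x => strictCalls += 1; x * 2 }.filter(_ > 10).take(5)

// Lazy: only the prefix needed to produce five results is mapped
var lazyCalls = 0
val lazyResult = list.view.map { x => lazyCalls += 1; x * 2 }.filter(_ > 10).take(5).toList

// strictCalls == 100, while lazyCalls is far smaller
```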

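9.3 Parallel Processing

Since Scala 2.13 (including Scala 3), `.par` is not part of the standard library. It requires the separate `scala-parallel-collections` module plus an import. Parallelism pays off only when the per-element work outweighs the scheduling overhead:

```scala
import scala.collection.parallel.CollectionConverters.*

def expensiveOperation(x: Int): Int = x * x // hypothetical stand-in for costly work

val list = (1 to 1000).toList
val result = list.par
  .map(expensiveOperation)
  .toList
```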

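10. Summary

Common Higher-Order Functions

| Function | Purpose | Returns |
|---|---|---|
| map | Transform each element | New collection |
| filter | Keep matching elements | New collection |
| flatMap | Map, then flatten | New collection |
| fold | Accumulate from a start value | Single value |
| reduce | Accumulate without a start value (throws on an empty collection) | Single value |
| groupBy | Group by key | Map |
| partition | Split in two by a predicate | Tuple of two collections |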

Function Composition

| Method | Order of application |
|---|---|
| f compose g | g, then f |
| f andThen g | f, then g |

Next, let's look at anonymous functions and lambdas.

Last updated: 2026-03-27