かとじゅんの技術日誌

技術の話をするところ

パラレルコレクションの性能測定

Scala Advent Calendar jp 2011 6日目 いきます。

STMの話にしようと思ったのですが、いろいろまだ調査中なんでまた後日ということで、今回はパラレルコレクションでいきます。すでにあちこちのブログで扱っているネタなので目新しさはないですが...

パラレルコレクションは2.9から使える新機能です。
早速 使い方。通常のコレクションの要素を2倍する処理は次のように記述します。

List(1,2,3).map(_ * 2)

一方、パラレルコレクションではparメソッドを使います。

List(1,2,3).par.map(_ * 2)

scala.collection.immutable.List#parはParSeq[A]型の戻り値を返します。ParSeq#mapを呼び出すだけでmapを並行に処理できるわけです*1。本来並行処理を実装する場合は、スレッドの起動や待機、スレッドプールとタスクの管理など複雑な制御が伴いますが、パラレルコレクションの場合はparメソッドを呼ぶだけ並行処理を記述できます。
実際のテストプログラム。Benchのmainメソッドがエントリポイントです。引数に応じて通常のコレクションでの処理と、パラレルコレクションでの処理を呼び分けます。

package parallel
import scala.collection.immutable.NumericRange

// プログラム本体
object Bench {

  // nまでの階乗を計算するメソッド
  def fac(n: BigInt) =
    NumericRange(BigInt(1), n, BigInt(1)).
      foldLeft(BigInt(1)) { (cur, next) =>
        cur * next
      }

  def main(args: Array[String]): Unit = {
    import parallel.BenchUtil._
    args match {
      case Array("N") =>
        bench(50, "normal") {
          (1 to 2000).map { x => fac(x) }
        }
      case Array("P") =>
        bench(50, "parallel") {
          (1 to 2000).par.map { x => fac(x) }
        }
    }
  }

}

BenchUtil#benchメソッドは計測した結果をソートして、前後20%ずつ削除した値だけを利用して平均値や標準偏差、最大最小値を計算。

package parallel
import scala.compat.Platform

// ベンチマーク用ユーティリティクラス
object BenchUtil {

  private def avg(xs: List[BigDecimal]): BigDecimal =
    xs.sum / xs.size

  private def std(xs: List[BigDecimal]): BigDecimal = {
    val a = avg(xs)
    Math.sqrt((xs.foldLeft(BigDecimal(0))((s, c) => s + (c - a) * (c + a)) / xs.size).toDouble)
  }

  private def median(xs: List[BigDecimal]) = xs.toSet.toList.sortWith(_ < _) match {
    case n :: Nil => n
    case xs if xs.size % 2 != 0 => xs(xs.size / 2)
    case xs if xs.size % 2 == 0 => {
      val a = xs(xs.size / 2 - 1)
      val b = xs(xs.size / 2)
      (a + b) / 2
    }
    case _ => throw new RuntimeException
  }

  private def mode(xs: List[BigDecimal]): BigDecimal =
    xs.foldLeft(Map[BigDecimal, Int]().withDefaultValue(0)) { (map, key) => map + (key -> (map(key) + 1)) } maxBy (_._2) _1

  def bench(n: Int, msg:String)(f: => Unit) {
    val times = for (i <- List.range(1, n + 1, 1)) yield {
      Platform.collectGarbage
      val start = System.nanoTime
      f
      val stop = System.nanoTime
      BigDecimal(stop - start) / 1000 / 1000
    }
    val truncate = n / 5
    val result = times.sortWith(_ < _).view(truncate, n - truncate).toList
    if (result.size > 0) {
      println("%s, threadId = %d, n = %d, avg = %11.2f, std = %11.2f, median = %11.2f, mode = %11.2f, min = %11.2f, max = %11.2f".
        format(msg, Thread.currentThread.getId, result.size, avg(result), std(result), median(result), mode(result), result.min, result.max))
    }
  }
}

結果は次のとおり。私の環境ではパラレルコレクションを使った処理の方が3倍ぐらい高速になりました。もっとコアが多いマシンでテストしたいところだけど、個人ではこれが限界。

動作環境: MacBook Pro 15インチ Corei7 2GHz(物理コアは4つ。仮想コアとして8つ) 
単位はmsec
normal  , threadId = 1, n = 30, avg = 3177.07, std = 10.39, median = 3179.97, mode = 3184.39, min = 3155.66, max = 3191.19
parallel, threadId = 1, n = 30, avg = 1110.68, std = 27.62, median = 1113.52, mode = 1081.60, min = 1065.08, max = 1157.97

CPU負荷は具体的な数値はとってませんが、パラレルの方はちゃんと全部使っている感じ。
通常のコレクション

パラレルコレクション

使い込んでみないと具体的にどういうところで使えるかわかりませんが、動画のエンコードとか面白そうかなと思ったりしています。

あわせて読みたい
Scala 並列コレクション メモ(Hishidama's Scala parallel collections Memo)
scala2.9のparallel collection の benchmark をしてみた - scalaとか・・・
Scalaの並列コレクションで実際に並列化されているメソッドを調べてみた - chimerastのエレガント指向プログラミング日記

*1:当然mapに渡す関数に副作用がないことが前提