SalesforceのScala製AutoMLライブラリ「TransmogrifAI」を触ってみた

AutoMLはこれまで専門のエンジニアを必要としていたような機械学習の処理を自動化し、誰でも機械学習を利用できるようにするという分野です。PythonだとTPOTなどのライブラリが存在しますが、先日Salefsforce社からScala + SparkベースのAutoMLライブラリが発表されたとのことで軽く触ってみました。

github.com

Hello Worldの最初のサンプルがタイタニック号の乗客の生存予測だったのですが、データの加工処理なども入っていたのでそれらを除いてシンプルにした形のサンプルを紹介したいと思います。

まずはこんな感じのimport文が必要です。

import com.salesforce.op._
import com.salesforce.op.evaluators.Evaluators
import com.salesforce.op.features.FeatureBuilder
import com.salesforce.op.features.types._
import com.salesforce.op.readers.DataReaders
import com.salesforce.op.stages.impl.classification.BinaryClassificationModelSelector
import com.salesforce.op.stages.impl.classification.BinaryClassificationModelsToTry._
import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession

CSVデータに対応するケースクラスを作っておきます。

case class Passenger
(
  id: Int,
  survived: Int,
  pClass: Option[Int],
  name: Option[String],
  sex: Option[String],
  age: Option[Double],
  sibSp: Option[Int],
  parCh: Option[Int],
  ticket: Option[String],
  fare: Option[Double],
  cabin: Option[String],
  embarked: Option[String]
)

Sparkを使うのでSparkSessionも必要になります。

val conf = new SparkConf().setAppName("OpTitanicSimple")
implicit val spark = SparkSession.builder.config(conf).getOrCreate()

ここからがTransmogrifAIならではの部分になります。まずはケースクラスのフィールドに対応した特徴を定義します。予測対象は.asResponse、推定に使うフィールドは.asPredictorにしておきます。ここで型を付けておくのでいろいろ自動化が可能らしいです。

val survived = FeatureBuilder.RealNN[Passenger].extract(_.survived.toRealNN).asResponse
val pClass = FeatureBuilder.PickList[Passenger].extract(_.pClass.map(_.toString).toPickList).asPredictor
val name = FeatureBuilder.Text[Passenger].extract(_.name.toText).asPredictor
val sex = FeatureBuilder.PickList[Passenger].extract(_.sex.map(_.toString).toPickList).asPredictor
val age = FeatureBuilder.Real[Passenger].extract(_.age.toReal).asPredictor
val sibSp = FeatureBuilder.Integral[Passenger].extract(_.sibSp.toIntegral).asPredictor
val parCh = FeatureBuilder.Integral[Passenger].extract(_.parCh.toIntegral).asPredictor
val ticket = FeatureBuilder.PickList[Passenger].extract(_.ticket.map(_.toString).toPickList).asPredictor
val fare = FeatureBuilder.Real[Passenger].extract(_.fare.toReal).asPredictor
val cabin = FeatureBuilder.PickList[Passenger].extract(_.cabin.map(_.toString).toPickList).asPredictor
val embarked = FeatureBuilder.PickList[Passenger].extract(_.embarked.map(_.toString).toPickList).asPredictor

特徴をSeqにまとめてtransmogrify()でベクトルに変換し、モデルを定義します。

// 特徴ベクトルを作成
val passengerFeatures = Seq(
  pClass, name, sex, age, sibSp, parCh, ticket, fare, 
  cabin, embarked
).transmogrify()

// モデルを定義
val prediction =
  BinaryClassificationModelSelector.withTrainValidationSplit(
    modelTypesToUse = Seq(OpLogisticRegression)
  ).setInput(survived, passengerFeatures).getOutput()

学習データを使って学習を行います。

// ケースクラスをエンコードするために必要
import spark.implicits._
// CSVファイルから学習データを読み込むReader
val trainDataReader = DataReaders.Simple.csvCase[Passenger](
  path = Option("/tmp/TitanicPassengersTrainData.csv"), 
  key = _.id.toString
)

// ワークフローを定義
val workflow =
  new OpWorkflow()
    .setResultFeatures(survived, prediction)
    .setReader(trainDataReader)

// 学習
val fittedWorkflow = workflow.train()

学習済みのワークフローはファイルに保存したり、新しいデータに対して予測を行ったりできます。ちなみにscore()の戻り値はDataFrameで返ってきます。

// 学習済みのワークフローをファイルに保存可能
fittedWorkflow.save("/tmp/OpTitanicSimple.model", true)

// ファイルから読み込んで使用
val loadedWorkflow = new OpWorkflow()
    .setResultFeatures(survived, prediction)
    .loadModel("/tmp/OpTitanicSimple.model")

// 新しいデータを読み込むReader
val newDataReader = DataReaders.Simple.csvCase[Passenger](
  path = Some("/tmp/NewData.csv"), 
  key = _.id.toString
)

// 新しいデータに対して予測
loadedWorkflow.setReader(newDataReader)
val result = fittedWorkflow.score()

score()の代わりにscoreAndEvaluate()を使うと結果を評価することもできます。この場合、resultに予測結果、metricsに評価結果が入ってきます。TransmogrifAIには学習、予測、評価などをテンプレート化したラッパーも用意されており、それを使用することでこのあたりをもう少しスッキリ記述することもできます。

val evaluator = Evaluators.BinaryClassification()
    .setLabelCol(survived)
    .setPredictionCol(prediction)

val (result, metrics) = fittedWorkflow
    .scoreAndEvaluate(evaluator = evaluator)

少し手数は多いですが、基本的には特徴を渡すだけで予測ができるというコンセプトは伝わるのではないかと思います。データをちゃんと加工するともっと精度が良くなるのですが(実際にTransmogrifAIに付属しているサンプルコードでは多少加工が行われています)、このあたりや特徴の型付けまで自動化されると凄いですね。

TransmogrifAIのモチベーションや技術的背景については以下のブログ記事に書かれています。Sparkを使用しているのは、Salesforceでは大量データを扱う必要があること、バッチとストリーミングの両方の形態でモデルをserveする必要があることが理由として挙げられています。

engineering.salesforce.com

自分は機械学習エンジニアというわけではないので、内部の実装や他のライブラリの事情などよくわかっていないのですが、汎用的な用途でそこそこの精度が出ればよいという使い方であれば機械学習の専門知識がないエンジニアでもできるようになるのではという感覚があり、なかなか未来を感じますね。