Java/Scala用機械学習ライブラリ「Smile」を使ってみる

手軽に使える機械学習ライブラリというとPythonのscikit-learnが有名ですが、Java/ScalaでもSmileというライブラリがあったので軽く試してみました。

github.com

まずはリリースページからzipファイルをダウンロードして適当な場所に展開します。bin/smileで対話シェルが起動します。

f:id:takezoe:20190104102332p:plain

サンプルデータがついているのでこれを使って学習してみます。

smile> val toy = read.table("data/classification/toy/toy-train.txt", response = Some(new NominalAttribute("class"), 0))
smile> val (x, y) = toy.unzipInt
smile> val model = knn(x, y, 3)

作ったモデルを使って予測してみます。

smile> model.predict(x(0))
res3: Int = 0

学習データとモデルをプロットしてみます。別ウィンドウが立ち上がって以下のようなグラフが表示されるはずです。SmileはJupyterでも使うことができるようですが、その場合はちゃんとノートブックにグラフが表示されるのだろうか…。

smile> plot(x, y, model)

f:id:takezoe:20190104103334p:plain

テストデータを使って精度を確かめてみます。

smile> val test = read.table("data/classification/toy/toy-test.txt", response = Some(new NominalAttribute("class"), 0))
smile> val (testx, testy) = test.unzipInt
smile> val pred = testx.map(model.predict(_))
smile> accuracy(testy, pred)
res9: Double = 0.81205

モデルの保存、読み込みはJava標準のバイナリシリアライゼーションで行われるようです。Xstreamを使ってXML形式で保存する機能も用意されているようです。読み込むときはちゃんとキャストしてあげないとAnyRef型になってしまうので注意が必要です。

smile> write(model, "toy.model")
smile> val model = read("toy.model").asInstanceOf[KNN[Array[Double]]]

データ加工にScalaのコレクションAPIを使えるのは便利ですが、もうちょっと高度なライブラリが欲しいかなという気もします。その辺は別のライブラリを組み合わせればよいのでしょう。以下のような感じで外部ライブラリが使えるみたいです。

smile> import $ivy.`org.scalaz::scalaz-core:7.2.7`, scalaz._, Scalaz._

思っていたよりサクッと使えましたし、対話シェルやグラフの描画などもよくできていました。Java/Scalaアプリケーションにちょっとした機械学習を利用した機能を組み込みたい場合には便利に使えるのではないかと思います。