c-bata web

日本語の技術ブログ by @c-bata . 英語の記事はMediumに書いています。

GoptunaのRDB storage backendを使った分散ハイパーパラメーター最適化

先月は会社の技術ブログ執筆当番となっていたため、そのネタづくりに Goptuna を実装・公開しました。記事公開時にはTPEとUniformDistributionしかありませんでしたが、その後もコツコツ開発を続け現在はCategoricalDistributionなどOptunaがサポートしているDistributionを全て実装し、枝刈りなどの機能にも一部対応しています。

adtech.cyberagent.io

今の機能なら業務にも投入できるかなと思っていたのですが、いざ使おうとすると最適化の結果を永続化できないこと・Dashboardが見れないことは結構問題でした。今回はGormを使ってがっと書いてみた ので使い方のメモも兼ねて紹介します。 OptunaのDB定義と互換性があるので、Goptunaの実行結果をOptunaのdashboardで閲覧できます。

SQLite3で動かす

Studyの作成

RDB storage backendを追加するついでにOptunaのようなCLIを定義したので、こちらでまずはstudyを作成します。

$ ./bin/goptuna 
A command line interface for Goptuna

Usage:
  goptuna [command]

Available Commands:
  create-study Create a study in your relational database storage.
  help         Help about any command

Flags:
  -h, --help      help for goptuna
      --version   version for goptuna

Use "goptuna [command] --help" for more information about a command.

$ goptuna create-study --storage sqlite:///db.sqlite3 --study rdb
rdb

optuna create-study と基本的に同じ使い方なので、Optunaに慣れた方であれば同じように使えるかと思います。 例えば --storage オプションとかはSQLAlchemyのEngine Database URL formatがそのまま使えます ( https://docs.sqlalchemy.org/en/13/core/engines.html を裏側でパースしてGoのData Source Name形式に変換しています)。 --study オプションはOptunaと同じく省略するとuuidv4を使って自動で名前を割り当てます。

実行後のdb.sqlite3の中はこんな感じです。

$ sqlite3 db.sqlite3 
SQLite version 3.28.0 2019-04-16 19:49:53
Enter ".help" for usage hints.
sqlite> .header on
sqlite> .mode column
sqlite> .tables
alembic_version          trial_params             trials                 
studies                  trial_system_attributes  version_info           
study_system_attributes  trial_user_attributes  
study_user_attributes    trial_values           
sqlite> select * from studies;
study_id    study_name  direction 
----------  ----------  ----------
1           rdb         MINIMIZE  

ちなみにGORMにはauto migrationの機能がありますが、SQLAlchemy + Alembicのようにマイグレーションファイルの差分情報からうまくSQL生成するような機能はありません。Goptuna CLIでテーブルを作成したときは基本的にあとから新しい定義にマイグレーションできないものだと思ってもらったほうが良いかと思います (ハイパーパラメータ最適化では実用上マイグレーションが問題になることはあまりないとは思いますが)。もし長期的にデータベースを管理していきたい場合は、ここだけOptuna CLIを使ってstudyを作成することも可能なのでそちらを検討してください。

StudyをLoadして最適化

最適化は次のように行います。 StudyはCLIで作成済みのため、goptuna.LoadStudy を使用します。あとはrdb.NewStorage(db)*gorm.DB を渡してそれをLoadStudyに指定します。

package main

import (
    "flag"
    "fmt"
    "math"
    "os"

    "github.com/c-bata/goptuna"
    "github.com/c-bata/goptuna/rdb"
    "github.com/c-bata/goptuna/tpe"
    "github.com/jinzhu/gorm"
    "go.uber.org/zap"

    _ "github.com/jinzhu/gorm/dialects/sqlite"
)

func objective(trial goptuna.Trial) (float64, error) {
    x1, _ := trial.SuggestUniform("x1", -10, 10)
    x2, _ := trial.SuggestUniform("x2", -10, 10)
    return math.Pow(x1-2, 2) + math.Pow(x2+5, 2), nil
}

func main() {
    logger, err := zap.NewDevelopment()
    if err != nil {
        os.Exit(1)
    }
    defer logger.Sync()

    db, err := gorm.Open("sqlite3", "db.sqlite3")
    if err != nil {
        logger.Fatal("failed to open db", zap.Error(err))
    }
    storage := rdb.NewStorage(db)
    defer db.Close()

    study, err := goptuna.LoadStudy(
        "rdb",
        goptuna.StudyOptionStorage(storage),
        goptuna.StudyOptionSampler(tpe.NewSampler()),
        goptuna.StudyOptionSetDirection(goptuna.StudyDirectionMinimize),
        goptuna.StudyOptionSetLogger(logger),
    )
    if err != nil {
        logger.Fatal("failed to create study", zap.Error(err))
    }

    if err = study.Optimize(objective, 50); err != nil {
        logger.Fatal("failed to optimize", zap.Error(err))
    }

    v, err := study.GetBestValue()
    if err != nil {
        logger.Fatal("failed to get best value", zap.Error(err))
    }
    params, err := study.GetBestParams()
    if err != nil {
        logger.Fatal("failed to get best params", zap.Error(err))
    }
    fmt.Println("Result:")
    fmt.Println("- best value", v)
    fmt.Println("- best param", params)
}

実行すると次のようになります。

$ go run main.go
2019-08-13T18:41:16.852+0900    INFO    goptuna/study.go:116    Finished trial  {"trialID": 1, "state": "Complete", "value": 47.67429274143393, "params": "map[x1:8.903922985882328 x2:-5.100698294124405]"}
2019-08-13T18:41:16.860+0900    INFO    goptuna/study.go:116    Finished trial  {"trialID": 2, "state": "Complete", "value": 16.564974686882508, "params": "map[x1:3.1191253039081026 x2:-8.913123208005992]"}
...
2019-08-13T18:41:17.281+0900    INFO    goptuna/study.go:116    Finished trial  {"trialID": 50, "state": "Complete", "value": 27.09517768964176, "params": "map[x1:-3.1410659891485793 x2:-5.81524118202008]"}
Result:
- best value 0.03832650787671685
- best param map[x1:2.1816038398071225 x2:-4.926879871143263]

Optuna Dashboardで確認

DBスキーマはOptunaと互換性があるので、Dashboardで結果を見てみましょう。

$ python3.7 -m venv venv
$ source venv/bin/activate
$ pip install optuna bokeh
$ optuna dashboard --storage sqlite:///db.sqlite3 --study rdb
[W 2019-08-13 18:41:18,756] Optuna dashboard is still highly experimental. Please use with caution!
[I 2019-08-13 18:41:18,764] Starting Bokeh server version 1.3.4 (running on Tornado 6.0.3)
[I 2019-08-13 18:41:18,768] Bokeh app running at: http://localhost:5006/dashboard
[I 2019-08-13 18:41:18,768] Starting Bokeh server with process id: 47887

f:id:nwpct1:20190813185112p:plain
optuna-dashboard

問題なく確認できました。

MySQLで動かす

GILの制約をうけてしまうOptunaとは違いRDB storage backendを使用しなくてもGoroutineですでに複数のCPUコアを効率よく使うことはできたので、SQLite3が使えるだけでは結果がファイルとして保存されていてあとから再開したりdashboardを見たりできる以外にはメリットがありません。複数台のマシンを使った分散ハイパーパラメータ最適化のためには、tcpで通信ができるMySQLなどを使うことになるかと思います。動作確認も兼ねてMySQLで動かしてみます。

DockerでMySQL 8.0を用意

$ cat mysql/my.cnf 
[mysqld]
bind-address = 0.0.0.0
default_authentication_plugin=mysql_native_password

$ docker pull mysql:8.0
$ docker run \
  -d \
  --rm \
  -p 3306:3306 \
  --mount type=volume,src=mysql,dst=/etc/mysql/conf.d \
  -e MYSQL_USER=goptuna \
  -e MYSQL_DATABASE=goptuna \
  -e MYSQL_PASSWORD=password \
  -e MYSQL_ALLOW_EMPTY_PASSWORD=yes \
  --name goptuna-mysql \
  mysql:8.0

Goptuna CLIを使ってstudyを作成

storageにはSQLAlchemyのDatabase Engine URL formatがこちらも使用できます (goptuna側でパースしてGoのData Source Nameに変換しています)。現状はまだSQLite3とMySQLしかdialectをサポートしていないので、PostgreSQLとか使いたい場合はIssueかPRをお願いします。

$ goptuna create-study --storage mysql://goptuna:password@localhost:3306/goptuna
no-name-d704a908-bda1-11e9-8f4c-acde48001122

データを確認

$ mysql --host 127.0.0.1 --port 3306 --user goptuna -ppassword
mysql> show tables from goptuna;
+-------------------------+
| Tables_in_goptuna       |
+-------------------------+
| studies                 |
| study_system_attributes |
| study_user_attributes   |
| trial_params            |
| trial_system_attributes |
| trial_user_attributes   |
| trial_values            |
| trials                  |
+-------------------------+
8 rows in set (0.00 sec)

mysql> select * from studies;
+----------+----------------------------------------------+-----------+
| study_id | study_name                                   | direction |
+----------+----------------------------------------------+-----------+
|        1 | no-name-d704a908-bda1-11e9-8f4c-acde48001122 | MINIMIZE  |
+----------+----------------------------------------------+-----------+
1 row in set (0.00 sec)

実行

コードは先程とほとんど変わらず、dialectが mysql になり、DSNをそれにあわせて変更するだけです。 DSNは、 parseTime=true を忘れずにつけてください。

package main

import (
    "flag"
    "fmt"
    "math"
    "os"

    "github.com/c-bata/goptuna"
    "github.com/c-bata/goptuna/rdb"
    "github.com/c-bata/goptuna/tpe"
    "github.com/jinzhu/gorm"
    "go.uber.org/zap"

    _ "github.com/jinzhu/gorm/dialects/mysql"
    _ "github.com/jinzhu/gorm/dialects/sqlite"
)

func objective(trial goptuna.Trial) (float64, error) {
    x1, err := trial.SuggestUniform("x1", -10, 10)
    if err != nil {
        return 0.0, err
    }
    x2, err := trial.SuggestUniform("x2", -10, 10)
    if err != nil {
        return 0.0, err
    }
    return math.Pow(x1-2, 2) + math.Pow(x2+5, 2), nil
}

func main() {
    logger, err := zap.NewDevelopment()
    if err != nil {
        os.Exit(1)
    }
    defer logger.Sync()

    db, err := gorm.Open("mysql", "goptuna:password@tcp(localhost:3306)/goptuna?parseTime=true")
    if err != nil {
        logger.Fatal("failed to open db", zap.Error(err))
    }

    storage := rdb.NewStorage(db)
    defer db.Close()

    study, err := goptuna.LoadStudy(
        "rdb",
        goptuna.StudyOptionStorage(storage),
        goptuna.StudyOptionSampler(tpe.NewSampler()),
        goptuna.StudyOptionSetDirection(goptuna.StudyDirectionMinimize),
        goptuna.StudyOptionSetLogger(logger),
    )
    if err != nil {
        logger.Fatal("failed to create study", zap.Error(err))
    }

    if err = study.Optimize(objective, 50); err != nil {
        logger.Fatal("failed to optimize", zap.Error(err))
    }

    v, err := study.GetBestValue()
    if err != nil {
        logger.Fatal("failed to get best value", zap.Error(err))
    }
    params, err := study.GetBestParams()
    if err != nil {
        logger.Fatal("failed to get best params", zap.Error(err))
    }
    fmt.Println("Result:")
    fmt.Println("- best value", v)
    fmt.Println("- best param", params)
}

Dashboardで確認

Dashboardで確認するときは次のコマンド

$ optuna dashboard --storage mysql+mysqldb://goptuna:password@127.0.0.1:3306/goptuna --study rdb

まとめ

とりあえず動かし方をメモもかねてMySQLやSQLite3を使う方法を紹介しました。GORMが思っていたより使いやすく結構いいペースで実装できました。Dashboardも使えるようになったのは実用上かなり大きいかなと思っています。

Optunaの開発チームの方でDashboardを1から作っているようなので場合によっては静的ファイルをGoのバイナリにbundleして、Goptuna CLIからdashboard立てられるようにしたい。