空飛ぶロボットのつくりかた

ロボットをつくるために必要な技術をまとめます。ロボットの未来についても考えたりします。

機械学習のお勉強(chainerのTrainerについて)

公式Docs

How to write a training loop in Chainer — Chainer 3.0.0rc1 documentation

わかりやすいまとめ

Chainer の Trainer 解説と NStepLSTM について

MNIST分類コードをChainer-v1.11.0のTrainerで書き換える - Monthly Hacker's Blog

Chainer: ビギナー向けチュートリアル Vol.1 - Qiita

使い方

  1. Trainerオブジェクト(trainer)をつくる

  2. trainer.run()で実行

Trainerオブジェクトの例

ChainerのTrainerを使ってみた - のんびりしているエンジニアの日記さんわかりやすい記事ありがとうございます👇

# coding:utf-8
from __future__ import absolute_import
from __future__ import unicode_literals
import chainer
import chainer.datasets
from chainer import training
from chainer.training import extensions
import chainer.links as L
import chainer.functions as F


class MLP(chainer.Chain):
    def __init__(self, n_units, n_out):
        super(MLP, self).__init__(
            l1=L.Linear(None, n_units),
            l2=L.Linear(None, n_units),
            l3=L.Linear(None, n_out),
        )

    def __call__(self, x):
        h1 = F.relu(self.l1(x))
        h2 = F.relu(self.l2(h1))
        return self.l3(h2)


train, test = chainer.datasets.get_mnist()
train_iter = chainer.iterators.SerialIterator(train, 32)
test_iter = chainer.iterators.SerialIterator(test, 32,
                                             repeat=False, shuffle=False)
model = L.Classifier(MLP(784, 10))
optimizer = chainer.optimizers.SGD()
optimizer.setup(model)
updater = training.StandardUpdater(train_iter, optimizer, device=-1)
trainer = training.Trainer(updater, (10, 'epoch'), out="result")  # epoch数の指定

trainer.extend(extensions.Evaluator(test_iter, model, device=10)) # 評価
trainer.extend(extensions.dump_graph('main/loss'))
trainer.extend(extensions.snapshot(), trigger=(10, 'epoch'))
trainer.extend(extensions.LogReport())
trainer.extend(extensions.PrintReport(
    ['epoch', 'main/loss', 'validation/main/loss',
     'main/accuracy', 'validation/main/accuracy']))
trainer.extend(extensions.ProgressBar())
trainer.run()

参考:

chainer/train_mnist.py at master · chainer/chainer · GitHub

chainer 1.11.0のMNISTサンプルを例にtrainerを読み解く - Monthly Hacker's Blog

Extension

Trainer extensions — Chainer 2.0.2 documentation

Trainerオブジェクトから学習済みモデルを取り出す方法

model = trainer.updater.get_optimizer('main').target.predictor
prediction = model(img)

chainer全体のわかりやすい記事

【機械学習】ディープラーニング フレームワークChainerを試しながら解説してみる。 - Qiita