機械学習のお勉強(chainerのTrainerについて)
公式Docs
How to write a training loop in Chainer — Chainer 3.0.0rc1 documentation
わかりやすいまとめ
Chainer の Trainer 解説と NStepLSTM について
MNIST分類コードをChainer-v1.11.0のTrainerで書き換える - Monthly Hacker's Blog
Chainer: ビギナー向けチュートリアル Vol.1 - Qiita
使い方
Trainerオブジェクト(trainer)をつくる
trainer.run()で実行
Trainerオブジェクトの例
ChainerのTrainerを使ってみた - のんびりしているエンジニアの日記さんわかりやすい記事ありがとうございます👇
# coding:utf-8 from __future__ import absolute_import from __future__ import unicode_literals import chainer import chainer.datasets from chainer import training from chainer.training import extensions import chainer.links as L import chainer.functions as F class MLP(chainer.Chain): def __init__(self, n_units, n_out): super(MLP, self).__init__( l1=L.Linear(None, n_units), l2=L.Linear(None, n_units), l3=L.Linear(None, n_out), ) def __call__(self, x): h1 = F.relu(self.l1(x)) h2 = F.relu(self.l2(h1)) return self.l3(h2) train, test = chainer.datasets.get_mnist() train_iter = chainer.iterators.SerialIterator(train, 32) test_iter = chainer.iterators.SerialIterator(test, 32, repeat=False, shuffle=False) model = L.Classifier(MLP(784, 10)) optimizer = chainer.optimizers.SGD() optimizer.setup(model) updater = training.StandardUpdater(train_iter, optimizer, device=-1) trainer = training.Trainer(updater, (10, 'epoch'), out="result") # epoch数の指定 trainer.extend(extensions.Evaluator(test_iter, model, device=10)) # 評価 trainer.extend(extensions.dump_graph('main/loss')) trainer.extend(extensions.snapshot(), trigger=(10, 'epoch')) trainer.extend(extensions.LogReport()) trainer.extend(extensions.PrintReport( ['epoch', 'main/loss', 'validation/main/loss', 'main/accuracy', 'validation/main/accuracy'])) trainer.extend(extensions.ProgressBar()) trainer.run()
参考:
chainer/train_mnist.py at master · chainer/chainer · GitHub
chainer 1.11.0のMNISTサンプルを例にtrainerを読み解く - Monthly Hacker's Blog
Extension
Trainer extensions — Chainer 2.0.2 documentation
Trainerオブジェクトから学習済みモデルを取り出す方法
model = trainer.updater.get_optimizer('main').target.predictor
prediction = model(img)