MNISTの学習データは6万画像もあって、非常に多い。
もっと少ない枚数でも大丈夫かも。
ということで、学習データ画像数とテストの正解率の関係を知りたいと思ったのだが、コマンド引数にはデータ数を指定する項目がない。
それで、サンプルソース(train_mnist.py)プログラムを見たら、コマンド引数の処理をargparseで行っているようだ。
main()の最初の部分を以下に示す。
import argparse
def main():
parser = argparse.ArgumentParser(description='Chainer example: MNIST')
parser.add_argument('--batchsize', '-b', type=int, default=100,
help='Number of images in each mini-batch')
parser.add_argument('--epoch', '-e', type=int, default=20,
help='Number of sweeps over the dataset to train')
parser.add_argument('--frequency', '-f', type=int, default=-1,
help='Frequency of taking a snapshot')
parser.add_argument('--gpu', '-g', type=int, default=-1,
help='GPU ID (negative value indicates CPU)')
parser.add_argument('--out', '-o', default='result',
help='Directory to output the result')
parser.add_argument('--resume', '-r', default='',
help='Resume the training from snapshot')
parser.add_argument('--unit', '-u', type=int, default=1000,
help='Number of units')
args = parser.parse_args()
print('GPU: {}'.format(args.gpu))
print('# unit: {}'.format(args.unit))
print('# Minibatch-size: {}'.format(args.batchsize))
print('# epoch: {}'.format(args.epoch))
print('')
ということで、Pythonのマニュアルを調べたら、説明があった。
16.4. argparse — コマンドラインオプション、引数、サブコマンドのパーサー
このドキュメントを実はほとんど参考にせず、元のプログラムにちょっと手を入れてみた。
parser.add_argument('--number', '-n', type=int, default=60000,
help='Number of training data')
args = parser.parse_args()
print('GPU: {}'.format(args.gpu))
print('# number: {}'.format(args.number))
print('# unit: {}'.format(args.unit))
print('# Minibatch-size: {}'.format(args.batchsize))
print('# epoch: {}'.format(args.epoch))
print('')
これで、引数 -n が読み取られるはずだ。
引数が省略されているときは、全データ数の60000をdefault値として設定しておいた。
さて、実際に学習のデータ量を、引数で指定した値、 args.number まで減らすには、データを読み込んだところで、先頭から指定した個数だけにしてしまうことにした。
# Load the MNIST dataset
train, test = chainer.datasets.get_mnist()
train = train[:args.number]
これで実行すると、こんな感じになった。
Chainer$ python train_mnist0.py -n 1000
GPU: -1
# number: 1000
# unit: 1000
# Minibatch-size: 100
# epoch: 20
epoch main/loss validation/main/loss main/accuracy validation/main/accuracy elapsed_time
1 1.31764 0.606311 0.631 0.8063 0.632579
2 0.444206 0.490377 0.867 0.8456 1.48039
3 0.259161 0.44472 0.911 0.8633 2.32615
4 0.149083 0.41302 0.958 0.8786 3.21353
5 0.0815992 0.402668 0.983 0.8868 4.07794
6 0.0438824 0.434466 0.994 0.8755 4.95173
7 0.0247996 0.412777 0.999 0.8861 5.80761
8 0.0134032 0.398092 1 0.8901 6.66825
9 0.00692543 0.387881 1 0.8963 7.55733
10 0.00444668 0.397806 1 0.8985 8.4276
11 0.00350568 0.408884 1 0.8968 9.27244
12 0.00275337 0.409523 1 0.8982 10.1134
13 0.00231785 0.413871 1 0.8971 10.9584
14 0.00198377 0.419695 1 0.8968 11.8287
15 0.00173918 0.422949 1 0.8973 12.6778
16 0.00155182 0.426604 1 0.897 13.5328
17 0.00139003 0.42915 1 0.897 14.395
18 0.0012566 0.431308 1 0.8975 15.2749
19 0.00113922 0.434765 1 0.897 16.1276
20 0.00104006 0.436891 1 0.8973 17.0124
Chainer$
学習データ数が60000から1000に減ると、かなり速くなった。
ということで、とりあえず学習データの個数をコマンドラインから制御できるようになった。
データ数による学習成果への影響については、次回調べることにする。