Pythonのコマンド引数パーサー


2017年 04月 07日

MNISTの学習データは6万画像もあって、非常に多い。
もっと少ない枚数でも大丈夫かも。
ということで、学習データ画像数とテストの正解率の関係を知りたいと思ったのだが、コマンド引数にはデータ数を指定する項目がない。

それで、サンプルソース(train_mnist.py)プログラムを見たら、コマンド引数の処理をargparseで行っているようだ。
main()の最初の部分を以下に示す。

import argparse

def main():
parser = argparse.ArgumentParser(description='Chainer example: MNIST')
parser.add_argument('--batchsize', '-b', type=int, default=100,
help='Number of images in each mini-batch')
parser.add_argument('--epoch', '-e', type=int, default=20,
help='Number of sweeps over the dataset to train')
parser.add_argument('--frequency', '-f', type=int, default=-1,
help='Frequency of taking a snapshot')
parser.add_argument('--gpu', '-g', type=int, default=-1,
help='GPU ID (negative value indicates CPU)')
parser.add_argument('--out', '-o', default='result',
help='Directory to output the result')
parser.add_argument('--resume', '-r', default='',
help='Resume the training from snapshot')
parser.add_argument('--unit', '-u', type=int, default=1000,
help='Number of units')
args = parser.parse_args()

print('GPU: {}'.format(args.gpu))
print('# unit: {}'.format(args.unit))
print('# Minibatch-size: {}'.format(args.batchsize))
print('# epoch: {}'.format(args.epoch))
print('')
ということで、Pythonのマニュアルを調べたら、説明があった。
16.4. argparse — コマンドラインオプション、引数、サブコマンドのパーサー
このドキュメントを実はほとんど参考にせず、元のプログラムにちょっと手を入れてみた。
    parser.add_argument('--number', '-n', type=int, default=60000,
help='Number of training data')
args = parser.parse_args()

print('GPU: {}'.format(args.gpu))
print('# number: {}'.format(args.number))
print('# unit: {}'.format(args.unit))
print('# Minibatch-size: {}'.format(args.batchsize))
print('# epoch: {}'.format(args.epoch))
print('')
これで、引数 -n が読み取られるはずだ。 引数が省略されているときは、全データ数の60000をdefault値として設定しておいた。 さて、実際に学習のデータ量を、引数で指定した値、 args.number まで減らすには、データを読み込んだところで、先頭から指定した個数だけにしてしまうことにした。
    # Load the MNIST dataset
train, test = chainer.datasets.get_mnist()
train = train[:args.number]

これで実行すると、こんな感じになった。
Chainer$ python train_mnist0.py -n 1000
GPU: -1
# number: 1000
# unit: 1000
# Minibatch-size: 100
# epoch: 20

epoch       main/loss   validation/main/loss  main/accuracy  validation/main/accuracy  elapsed_time
1           1.31764     0.606311              0.631          0.8063                    0.632579
2           0.444206    0.490377              0.867          0.8456                    1.48039
3           0.259161    0.44472               0.911          0.8633                    2.32615
4           0.149083    0.41302               0.958          0.8786                    3.21353
5           0.0815992   0.402668              0.983          0.8868                    4.07794
6           0.0438824   0.434466              0.994          0.8755                    4.95173
7           0.0247996   0.412777              0.999          0.8861                    5.80761
8           0.0134032   0.398092              1              0.8901                    6.66825
9           0.00692543  0.387881              1              0.8963                    7.55733
10          0.00444668  0.397806              1              0.8985                    8.4276
11          0.00350568  0.408884              1              0.8968                    9.27244
12          0.00275337  0.409523              1              0.8982                    10.1134
13          0.00231785  0.413871              1              0.8971                    10.9584
14          0.00198377  0.419695              1              0.8968                    11.8287
15          0.00173918  0.422949              1              0.8973                    12.6778
16          0.00155182  0.426604              1              0.897                     13.5328
17          0.00139003  0.42915               1              0.897                     14.395
18          0.0012566   0.431308              1              0.8975                    15.2749
19          0.00113922  0.434765              1              0.897                     16.1276
20          0.00104006  0.436891              1              0.8973                    17.0124
Chainer$
学習データ数が60000から1000に減ると、かなり速くなった。 ということで、とりあえず学習データの個数をコマンドラインから制御できるようになった。 データ数による学習成果への影響については、次回調べることにする。