2017年 03月 21日
train, test = chainer.datasets.get_mnist()
>>> type(train) <class 'chainer.datasets.tuple_dataset.tupledataset'> >>> len(train) 60000 >>> train[0] (array([ 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ,これから、trainは要素数60000個のリストで、各リストは、画像データと、数値(0から9)のタプル。
.........中略..........
0. , 0.11764707, 0.14117648, 0.36862746, 0.60392159, 0.66666669, 0.99215692, 0.99215692, 0.99215692, 0.99215692, 0.99215692, 0.88235301, 0.67450982, 0.99215692, 0.94901967, 0.76470596, 0.25098041, 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0.19215688, 0.9333334 , 0.99215692, 0.99215692, 0.99215692, 0.99215692, 0.99215692, .........中略.......... 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ], dtype=float32), 5) >>> train[0][0].shape (784,)
xtrain = train._datasets[0][:48]
ytrain = train._datasets[1][:48]
あとは、subplotsを使って、画面を分割して表示するだけである。
fig,ax = plt.subplots(nrows=6,ncols=8,sharex=True,sharey=True)
ax = ax.flatten()
for i in range(48):
img = xtrain[i].reshape(28,28)
ax[i].imshow(img,cmap='Greys',interpolation='none')
ax.flatten()により、forループ中で、表示枠を指定するのに、単にax[i]で済ますことができる。interpolation='none'
を指定している。
ax[0].set_xticks([])
ax[0].set_yticks([])
plt.tight_layout()
plt.savefig("mnistdisp48.png")
print(ytrain.reshape(6,8))
plt.show()
plt.subplots()で、分割された枠のx軸y軸の情報を共有するために、sharex、shareyをTrueにしている。そして、最後のところで、set_xticks([])、set_yticks([])により、目盛りなど目障りなものを表示しないようにしている。
#!/usr/bin/env python
# from http://nlp.dse.ibaraki.ac.jp/~shinnou/book/chainer.tgz
import numpy as np
import chainer
from chainer import cuda, Function, gradient_check, Variable
from chainer import optimizers, serializers, utils
from chainer import Link, Chain, ChainList
import chainer.functions as F
import chainer.links as L
import matplotlib.pyplot as plt
# http://yann.lecun.com/exdb/mnist/
train, test = chainer.datasets.get_mnist()
xtrain = train._datasets[0][:48]
ytrain = train._datasets[1][:48]
#xtest = test._datasets[0]
#ytest = test._datasets[1]
fig,ax = plt.subplots(nrows=6,ncols=8,sharex=True,sharey=True)
ax = ax.flatten()
for i in range(48):
img = xtrain[i].reshape(28,28)
ax[i].imshow(img,cmap='Greys',interpolation='none')
ax[0].set_xticks([])
ax[0].set_yticks([])
plt.tight_layout()
plt.savefig("mnistdisp48.png")
print(ytrain.reshape(6,8))
plt.show()