ニューラルネットワークによる手書き数字の認識
はじめに
scikit-learnライブラリに含まれる手書き数字データを用いて、ニューラルネットワークによる 手書き数字認識をやってみる。
データの準備
手書き数字データはサイズ 8 x 8で、グレースケールの階調が17段階となっている。具体的な手書き数字画像は scikit-learnのexampleで確認できる。 入力の個数は 8 x 8 = 64個で、出力は0〜9の10個となっている。データをPyTorchで扱うために、入力はFloatTensor、出力はLongTensorに 変換している。
## 必要なモジュールのインポート import numpy import torch import torch.nn as nn import torch.nn.functional as F from matplotlib import pyplot as plt from sklearn.datasets import load_digits # 手書き数字データ読み込み from sklearn.model_selection import train_test_split # 手書き数字データを学習用とテスト用に分類する ## データの準備 digits = load_digits() x_train, x_test, y_train, y_test = train_test_split(digits.data, digits.target, test_size=0.2) # 2割がテストデータ x_train = torch.FloatTensor(x_train) y_train = torch.LongTensor(y_train) x_test = torch.FloatTensor(x_test) y_test = torch.LongTensor(y_test)
ニューラルネットワークモデルの定義
入力64個、出力10個。中間層は2層とし、ノード数は100個とした。 中間層の活性化関数はReLUとし、出力層はl3の計算値をそのまま出力とした。
class Model(nn.Module): def __init__(self): super(Model, self).__init__() self.l1 = nn.Linear(64, 100) self.l2 = nn.Linear(100, 100) self.l3 = nn.Linear(100, 10) def forward(self, x): x = F.relu(self.l1(x)) x = F.relu(self.l2(x)) x = self.l3(x) return x
損失関数、最適化関数
損失関数にnn.CrossEntropyLoss
を使用した。learning_rateは試行錯誤して
良さそうなところに決めた。
model = Model()
criterion = nn.CrossEntropyLoss()
learning_rate = 1e-2
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
学習
epoch数は試行錯誤して良さそうなところに設定した。
epoch = 200 loss_log = [] # 学習状況のプロット用 for t in range(epoch): optimizer.zero_grad() y_model = model(x_train) loss = criterion(y_model, y_train) loss.backward() optimizer.step() loss_log.append(loss.item())
学習結果の確認
## 学習状況のプロット plt.plot(loss_log) plt.xlabel('epoch') plt.ylabel('loss') plt.yscale('log') plt.show() ## 正答率の確認 outputs = model(x_test) _, predicted = torch.max(outputs.data, 1) correct = (y_test == predicted).sum().item() # Tensorの比較で正答数を確認 print("正答率: {} %".format(correct/len(predicted)*100.0))
結果
epoch数5000までやってみたが、学習状況はなかなか落ち着いてこなかった。いっぽうで、正答率はepoch数が100でも5000でも ほとんど変わらなかった。
学習状況
正答率
>>> 正答率: 98.61111111111111 % # epoch数 100 >>> 正答率: 98.05555555555556 % # epoch数 5000
所感
人でも実際に見間違うことはあるだろうから、そんなに悪くない正答率なのではないかと思う。 畳み込みニューラルネットワークを利用したら精度が上がる?画像が 8 x 8 と小さいのであまり効果はなさそうに思われる。