化学系エンジニアがAIを学ぶ

PyTorchでディープラーニング、強化学習を学び、主に化学工学の問題に取り組みます

ニューラルネットワークによる手書き数字の認識

はじめに

scikit-learnライブラリに含まれる手書き数字データを用いて、ニューラルネットワークによる 手書き数字認識をやってみる。

データの準備

手書き数字データはサイズ 8 x 8で、グレースケールの階調が17段階となっている。具体的な手書き数字画像は scikit-learnのexampleで確認できる。 入力の個数は 8 x 8 = 64個で、出力は0〜9の10個となっている。データをPyTorchで扱うために、入力はFloatTensor、出力はLongTensorに 変換している。

## 必要なモジュールのインポート
import numpy
import torch
import torch.nn as nn
import torch.nn.functional as F
from matplotlib import pyplot as plt
from sklearn.datasets import load_digits # 手書き数字データ読み込み
from sklearn.model_selection import train_test_split # 手書き数字データを学習用とテスト用に分類する

## データの準備
digits = load_digits()
x_train, x_test, y_train, y_test = train_test_split(digits.data, digits.target, test_size=0.2) # 2割がテストデータ
x_train = torch.FloatTensor(x_train)
y_train = torch.LongTensor(y_train)
x_test  = torch.FloatTensor(x_test)
y_test  = torch.LongTensor(y_test)

ニューラルネットワークモデルの定義

入力64個、出力10個。中間層は2層とし、ノード数は100個とした。 中間層の活性化関数はReLUとし、出力層はl3の計算値をそのまま出力とした。

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.l1 = nn.Linear(64, 100)
        self.l2 = nn.Linear(100, 100)
        self.l3 = nn.Linear(100, 10)

    def forward(self, x):
        x = F.relu(self.l1(x))
        x = F.relu(self.l2(x))
        x = self.l3(x)
        return x

損失関数、最適化関数

損失関数にnn.CrossEntropyLossを使用した。learning_rateは試行錯誤して 良さそうなところに決めた。

model = Model()
criterion = nn.CrossEntropyLoss()
learning_rate = 1e-2
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

学習

epoch数は試行錯誤して良さそうなところに設定した。

epoch = 200
loss_log = [] # 学習状況のプロット用
for t in range(epoch):
    optimizer.zero_grad()
    y_model = model(x_train)
    loss = criterion(y_model, y_train)
    loss.backward()
    optimizer.step()
    loss_log.append(loss.item())

学習結果の確認

## 学習状況のプロット
plt.plot(loss_log)
plt.xlabel('epoch')
plt.ylabel('loss')
plt.yscale('log')
plt.show()

## 正答率の確認
outputs = model(x_test)
_, predicted = torch.max(outputs.data, 1)
correct = (y_test == predicted).sum().item() # Tensorの比較で正答数を確認
print("正答率: {} %".format(correct/len(predicted)*100.0))

結果

epoch数5000までやってみたが、学習状況はなかなか落ち着いてこなかった。いっぽうで、正答率はepoch数が100でも5000でも ほとんど変わらなかった。

学習状況

f:id:schemer1341:20190113220841p:plain:w500
図 学習状況 epoch数に対するlossの変化

正答率

>>> 正答率: 98.61111111111111 % # epoch数 100
>>> 正答率: 98.05555555555556 % # epoch数 5000

所感

人でも実際に見間違うことはあるだろうから、そんなに悪くない正答率なのではないかと思う。 畳み込みニューラルネットワークを利用したら精度が上がる?画像が 8 x 8 と小さいのであまり効果はなさそうに思われる。

参考