化学系エンジニアがAIを学ぶ

PyTorchでディープラーニング、強化学習を学び、主に化学工学の問題に取り組みます

CNN(Convolutional Neural Network)を用いた画像識別の簡単な例

はじめに

CNNによる画像識別の簡単な例として、下記の波形図の1つ目のピークが高いか2つ目のピークが高いかの識別をPyTorchを使って試みる。

f:id:schemer1341:20200223215448p:plain:w500

データの準備

画像をCNNで使える配列の形に変換する方法はいくつかあると思うが、ここではpillowを用いた。波形の画像データがフォルダ「testfig」に置いてあるものとして、下記内容でデータを準備した。 (ここでは識別のラベルを、1つ目のピークが高い→0、2つ目のピークが高い→1とし、ファイル名を例えば0-23.pngのように(ラベル)-(画像番号).pngの形式として、ファイル名からラベルを読み込む形としている)

画像データ

ここで気をつけないといけないのは配列の形状である。PyTorchのCNNの入力の形状[バッチサイズ, チャンネル数, 高さ(ピクセル), 幅(ピクセル)]に合わせないといけない。チャンネル数は画像の奥行き(または深さ)であり、画像の場合は色に対応する。チャンネル数は例えばカラーならR, G, Bの3つ、モノクロなら1つ、などになる。

import glob
from PIL import Image
import numpy as np

# フォルダ testfig においた画像ファイルのリストを取得
image_path = "./testfig"
files = glob.glob(image_path + "/*.png")

x = []
y = []

for filename in files:
    # 画像ファイルをndarrayに変換する
    im = Image.open(filename)
    im = im.convert("1")  # モノクロ画像に変換
    im = im.resize((28, 28))  # 28x28にリサイズ
    im = np.array(im).astype(int)  # intのarrayに変換
    x.append([im])  # CNNの入力に合うshapeとなるよう注意

    # ファイル名から正解ラベルを取得
    label = int(filename[len(image_path)+1])
    y.append(label)

x = np.array(x)  # x.shape (画像数, 1, 28, 28)
y = np.array(y)  # y.shape (画像数,)


# 訓練データとテストデータに分ける
x_train, x_test, y_train, y_test = \
    train_test_split(x, y, test_size=0.2, shuffle=False)

# テンソルに変換
x_train = torch.FloatTensor(x_train)
y_train = torch.LongTensor(y_train)
x_test = torch.FloatTensor(x_test)
y_test = torch.LongTensor(y_test)

# Dataloaderを準備
train_dataloader = torch.utils.data.TensorDataset(x_train, y_train)
train_dataloader = torch.utils.data.DataLoader(train_dataloader, batch_size=4)

ニューラルネットワークの定義

ここでは、畳み込み層→プーリング層→畳み込み層→プーリング層→全結合層→全結合層→出力層という構造とした。最初の全結合層への入力はその直前の層の出力を一次元に変換する必要がある。直前の層の出力サイズの計算は例えばこちら(定番のConvolutional Neural Networkをゼロから理解する - DeepAge)が参考になるが、ここでは高さ・幅が4x4の16チャンネルとなっており、Numpyのviewメソッドを用いて形状の変換を行っている。フィルタ数やフィルタサイズはPyTorchのチュートリアル(Training a Classifier — PyTorch Tutorials 1.4.0 documentation)を参考に適当に決めた。

from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)  # チャンネル1, フィルタ6, フィルタサイズ5x5
        self.pool1 = nn.MaxPool2d(2, 2)  # 2x2のプーリング層
        self.conv2 = nn.Conv2d(6, 16, 5)  # チャンネル6, フィルタ16, フィルタサイズ5x5
        self.pool2 = nn.MaxPool2d(2, 2)

        # これまでの畳込みとプーリングで、16チャンネルの4x4が入力サイズ
        self.fc1 = nn.Linear(16 * 4 * 4, 64)  # fc: fully connected
        self.fc2 = nn.Linear(64, 16)
        self.fc3 = nn.Linear(16, 2)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.pool1(x)
        x = F.relu(self.conv2(x))
        x = self.pool2(x)
        x = x.view(-1, 16 * 4 * 4)  # [(バッチサイズ), (一次元配列データ)]に並び替え
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

学習

Dataloaderを使ってミニバッチ学習させた。学習に関してはCNNだからといって特別なところはない。

model = Model()

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

num_epochs = 20

for epoch in range(num_epochs):
    total_loss = 0.0
    for inputs, labels in train_dataloader:
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    print('Epoch: {}  Loss: {}'.format(epoch, total_loss))

検証(正答率の確認)

はじめに準備したテストデータを用いて正答率を確認。

outputs = model(x_test)
_, predicted = torch.max(outputs.data, 1)
correct = (y_test == predicted).sum().item()  # Tensorの比較で正答数を確認
print("正答率: {} %".format(correct/len(predicted)*100.0))

結果

問題が単純なためか、正答率は高い。

Epoch: 0  Loss: 6.9794119000434875
Epoch: 1  Loss: 6.852706849575043
Epoch: 2  Loss: 6.788195848464966
Epoch: 3  Loss: 6.761787533760071
Epoch: 4  Loss: 6.738182067871094
Epoch: 5  Loss: 6.654882311820984
Epoch: 6  Loss: 6.485922873020172
Epoch: 7  Loss: 6.05694118142128
Epoch: 8  Loss: 5.022340148687363
Epoch: 9  Loss: 2.5453680232167244
Epoch: 10  Loss: 0.31626373156905174
Epoch: 11  Loss: 0.011260125378612429
Epoch: 12  Loss: 0.0037041376635897905
Epoch: 13  Loss: 0.00182408066757489
Epoch: 14  Loss: 0.0011753853141271975
Epoch: 15  Loss: 0.0009152144466497703
Epoch: 16  Loss: 0.0007907451872597449
Epoch: 17  Loss: 0.0007194795762188733
Epoch: 18  Loss: 0.0006725111852574628
Epoch: 19  Loss: 0.0006386453569575679
正答率: 100.0 %

参考

シンプルな CNN | PyTorch で簡単なニューラルネットワークを構築し画像を分類する方法
シンプルな例で詳しく説明されています。

Pytorchのニューラルネットワーク(CNN)のチュートリアル1.3.1の解説 - Qiita
畳み込みの部分が図を用いて詳しく説明されています。

CNN(Convolutional Neural Network)を理解する - sagantaf
フィルタサイズやパディング、ストライドの影響について解説があります。

Convolutional Neural Networkとは何なのか - Qiita
CNNが具体的に何を見ているかについてわかりやすい説明があります。