CNN(Convolutional Neural Network)を用いた画像識別の簡単な例
はじめに
CNNによる画像識別の簡単な例として、下記の波形図の1つ目のピークが高いか2つ目のピークが高いかの識別をPyTorchを使って試みる。
データの準備
画像をCNNで使える配列の形に変換する方法はいくつかあると思うが、ここではpillowを用いた。波形の画像データがフォルダ「testfig」に置いてあるものとして、下記内容でデータを準備した。 (ここでは識別のラベルを、1つ目のピークが高い→0、2つ目のピークが高い→1とし、ファイル名を例えば0-23.pngのように(ラベル)-(画像番号).pngの形式として、ファイル名からラベルを読み込む形としている)
ここで気をつけないといけないのは配列の形状である。PyTorchのCNNの入力の形状[バッチサイズ, チャンネル数, 高さ(ピクセル), 幅(ピクセル)]に合わせないといけない。チャンネル数は画像の奥行き(または深さ)であり、画像の場合は色に対応する。チャンネル数は例えばカラーならR, G, Bの3つ、モノクロなら1つ、などになる。
import glob from PIL import Image import numpy as np # フォルダ testfig においた画像ファイルのリストを取得 image_path = "./testfig" files = glob.glob(image_path + "/*.png") x = [] y = [] for filename in files: # 画像ファイルをndarrayに変換する im = Image.open(filename) im = im.convert("1") # モノクロ画像に変換 im = im.resize((28, 28)) # 28x28にリサイズ im = np.array(im).astype(int) # intのarrayに変換 x.append([im]) # CNNの入力に合うshapeとなるよう注意 # ファイル名から正解ラベルを取得 label = int(filename[len(image_path)+1]) y.append(label) x = np.array(x) # x.shape (画像数, 1, 28, 28) y = np.array(y) # y.shape (画像数,) # 訓練データとテストデータに分ける x_train, x_test, y_train, y_test = \ train_test_split(x, y, test_size=0.2, shuffle=False) # テンソルに変換 x_train = torch.FloatTensor(x_train) y_train = torch.LongTensor(y_train) x_test = torch.FloatTensor(x_test) y_test = torch.LongTensor(y_test) # Dataloaderを準備 train_dataloader = torch.utils.data.TensorDataset(x_train, y_train) train_dataloader = torch.utils.data.DataLoader(train_dataloader, batch_size=4)
ニューラルネットワークの定義
ここでは、畳み込み層→プーリング層→畳み込み層→プーリング層→全結合層→全結合層→出力層という構造とした。最初の全結合層への入力はその直前の層の出力を一次元に変換する必要がある。直前の層の出力サイズの計算は例えばこちら(定番のConvolutional Neural Networkをゼロから理解する - DeepAge)が参考になるが、ここでは高さ・幅が4x4の16チャンネルとなっており、Numpyのviewメソッドを用いて形状の変換を行っている。フィルタ数やフィルタサイズはPyTorchのチュートリアル(Training a Classifier — PyTorch Tutorials 1.4.0 documentation)を参考に適当に決めた。
from sklearn.model_selection import train_test_split import torch import torch.nn as nn import torch.nn.functional as F class Model(nn.Module): def __init__(self): super(Model, self).__init__() self.conv1 = nn.Conv2d(1, 6, 5) # チャンネル1, フィルタ6, フィルタサイズ5x5 self.pool1 = nn.MaxPool2d(2, 2) # 2x2のプーリング層 self.conv2 = nn.Conv2d(6, 16, 5) # チャンネル6, フィルタ16, フィルタサイズ5x5 self.pool2 = nn.MaxPool2d(2, 2) # これまでの畳込みとプーリングで、16チャンネルの4x4が入力サイズ self.fc1 = nn.Linear(16 * 4 * 4, 64) # fc: fully connected self.fc2 = nn.Linear(64, 16) self.fc3 = nn.Linear(16, 2) def forward(self, x): x = F.relu(self.conv1(x)) x = self.pool1(x) x = F.relu(self.conv2(x)) x = self.pool2(x) x = x.view(-1, 16 * 4 * 4) # [(バッチサイズ), (一次元配列データ)]に並び替え x = F.relu(self.fc1(x)) x = F.relu(self.fc2(x)) x = self.fc3(x) return x
学習
Dataloaderを使ってミニバッチ学習させた。学習に関してはCNNだからといって特別なところはない。
model = Model() criterion = nn.CrossEntropyLoss() optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9) num_epochs = 20 for epoch in range(num_epochs): total_loss = 0.0 for inputs, labels in train_dataloader: optimizer.zero_grad() outputs = model(inputs) loss = criterion(outputs, labels) loss.backward() optimizer.step() total_loss += loss.item() print('Epoch: {} Loss: {}'.format(epoch, total_loss))
検証(正答率の確認)
はじめに準備したテストデータを用いて正答率を確認。
outputs = model(x_test) _, predicted = torch.max(outputs.data, 1) correct = (y_test == predicted).sum().item() # Tensorの比較で正答数を確認 print("正答率: {} %".format(correct/len(predicted)*100.0))
結果
問題が単純なためか、正答率は高い。
Epoch: 0 Loss: 6.9794119000434875 Epoch: 1 Loss: 6.852706849575043 Epoch: 2 Loss: 6.788195848464966 Epoch: 3 Loss: 6.761787533760071 Epoch: 4 Loss: 6.738182067871094 Epoch: 5 Loss: 6.654882311820984 Epoch: 6 Loss: 6.485922873020172 Epoch: 7 Loss: 6.05694118142128 Epoch: 8 Loss: 5.022340148687363 Epoch: 9 Loss: 2.5453680232167244 Epoch: 10 Loss: 0.31626373156905174 Epoch: 11 Loss: 0.011260125378612429 Epoch: 12 Loss: 0.0037041376635897905 Epoch: 13 Loss: 0.00182408066757489 Epoch: 14 Loss: 0.0011753853141271975 Epoch: 15 Loss: 0.0009152144466497703 Epoch: 16 Loss: 0.0007907451872597449 Epoch: 17 Loss: 0.0007194795762188733 Epoch: 18 Loss: 0.0006725111852574628 Epoch: 19 Loss: 0.0006386453569575679 正答率: 100.0 %
参考
・シンプルな CNN | PyTorch で簡単なニューラルネットワークを構築し画像を分類する方法
シンプルな例で詳しく説明されています。
・Pytorchのニューラルネットワーク(CNN)のチュートリアル1.3.1の解説 - Qiita
畳み込みの部分が図を用いて詳しく説明されています。
・CNN(Convolutional Neural Network)を理解する - sagantaf
フィルタサイズやパディング、ストライドの影響について解説があります。
・Convolutional Neural Networkとは何なのか - Qiita
CNNが具体的に何を見ているかについてわかりやすい説明があります。