ChatGPTやStable Diffusion等の生成AIが大流行りですが、少し原点に戻って、畳み込みニューラルネットワークな話をします。
といっても、ただの画像認識ではなくて、metric learning、すなわち「距離学習」をやります。
以下、プログラムコードを含めて参考にしたのは、以下のサイト。
【深層距離学習】Center Lossを徹底解説 -Pytorchによる実践あり-|はやぶさの技術ノート
例えば、よく画像分類の初歩で使われる「MNIST」(えむにすと、と読みます)という手書きの数値が大量に集められた画像セットがあります。
0から9までの10種類の手書きの数字データなんですが、これをいかに正確に分類するかという学習をやるために、いわゆる「畳み込みニューラルネットワーク(CNN)」という手法を用います。
が、これを通常のCNN手法で分類すると、こんな感じになります。
全部で10種類のデータがきれいに切り分けられている(つまり、分類が上手くいっている)ことを示している図なのですが、なんだか気持ち悪いところがあります。
それは、原点からしゅーっと伸びたような、妙なデータ分布をしているということ。
こういう広い許容範囲を持っていると、例えば0から9までの数値とは似ても似つかない文字(アルファベットのAなど)を入力しても、強引に0から9に解釈されてしまいます。
0から9までの数値と、それ以外の文字を分けるためには、以下のような感じの分類が望ましいです。
それぞれの数値の分類が、綺麗に分かれています。
こんな具合に、分類の距離を離してやることで分類のブレをなくしてやろうという手法が「距離学習」です。
(かなーり大雑把に説明していますが、そんなようなものだと思ってください)
例えば「1」という手書きの数字が入力されたら、機械学習モデルを通して得られた出力データは、この色のついた点群のどこかに収まるはずです。
逆に数字ではない「A」のような文字が入力されたら、この点群とはまったく別のところにプロットされるはずです。
というのを、実際に体感してみましょう、というプログラムです。
今回のプログラムですが、Jupyter notebookで作成したため、ブツ切りとなってます。
そのままコピーして一つにまとめて動かしていただいても大丈夫なはずです。
まずは、必要なライブラリから。
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torch.optim.lr_scheduler as lr_scheduler
import matplotlib.pyplot as plt
from torchinfo import summary
from torch.autograd.function import Function
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
%matplotlib inline
pytorchで動かしてます。ちなみにバージョンは1.12.1です。
pytorchのインストール方法は、以下を参照。
Previous PyTorch Versions | PyTorch
他には、numpy、torchinfo、matplotlibが必要です。
epochs = 100 # 繰り返し回数
use_cuda = torch.cuda.is_available() and True
device = torch.device("cuda" if use_cuda else "cpu") # CPUかGPUの自動判別
if not os.path.exists("images"):
# ディレクトリが存在しない場合、ディレクトリを作成する
os.makedirs("images")
初期値を設定します。GPUが使える環境か否かを、自動判別させてます。
class CenterLoss(nn.Module):
def __init__(self, num_classes, feat_dim, size_average=True):
super(CenterLoss, self).__init__()
self.centers = nn.Parameter(torch.randn(num_classes, feat_dim))
self.centerlossfunc = CenterlossFunc.apply
self.feat_dim = feat_dim
self.size_average = size_average
def forward(self, label, feat):
batch_size = feat.size(0)
feat = feat.view(batch_size, -1)
# To check the dim of centers and features
if feat.size(1) != self.feat_dim:
raise ValueError("Center's dim: {0} should be equal to input feature's \
dim: {1}".format(self.feat_dim,feat.size(1)))
batch_size_tensor = feat.new_empty(1).fill_(batch_size if self.size_average else 1)
loss = self.centerlossfunc(feat, label, self.centers, batch_size_tensor)
return loss
class CenterlossFunc(Function):
def forward(ctx, feature, label, centers, batch_size):
ctx.save_for_backward(feature, label, centers, batch_size)
centers_batch = centers.index_select(0, label.long())
return (feature - centers_batch).pow(2).sum() / 2.0 / batch_size
def backward(ctx, grad_output):
feature, label, centers, batch_size = ctx.saved_tensors
centers_batch = centers.index_select(0, label.long())
diff = centers_batch - feature
# init every iteration
counts = centers.new_ones(centers.size(0))
ones = centers.new_ones(label.size(0))
grad_centers = centers.new_zeros(centers.size())
counts = counts.scatter_add_(0, label.long(), ones)
grad_centers.scatter_add_(0, label.unsqueeze(1).expand(feature.size()).long(), diff)
grad_centers = grad_centers/counts.view(-1, 1)
return - grad_output * diff / batch_size, None, grad_centers / batch_size, None
Center Lossというものを定義します。
通常、画像分類ではSoftmax Lossが最小となるように最適化計算をしますが、これに加えてこのCenter Lossというものを最小化することで、距離学習を行います。
すごくざっくり言うと、0から9までの10種類の手書き数字の出力データ群の中心を離し、それぞれがその中心に寄せるように学習させるためのLoss値ということのようです。だからCenter Lossと言うようで。
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1_1 = nn.Conv2d(1, 32, kernel_size=5, padding=2)
self.prelu1_1 = nn.PReLU()
self.conv1_2 = nn.Conv2d(32, 32, kernel_size=5, padding=2)
self.prelu1_2 = nn.PReLU()
self.conv2_1 = nn.Conv2d(32, 64, kernel_size=5, padding=2)
self.prelu2_1 = nn.PReLU()
self.conv2_2 = nn.Conv2d(64, 64, kernel_size=5, padding=2)
self.prelu2_2 = nn.PReLU()
self.conv3_1 = nn.Conv2d(64, 128, kernel_size=5, padding=2)
self.prelu3_1 = nn.PReLU()
self.conv3_2 = nn.Conv2d(128, 128, kernel_size=5, padding=2)
self.prelu3_2 = nn.PReLU()
self.preluip1 = nn.PReLU()
self.ip1 = nn.Linear(128*3*3, 2)
self.ip2 = nn.Linear(2, 10, bias=False)
def forward(self, x):
x = self.prelu1_1(self.conv1_1(x))
x = self.prelu1_2(self.conv1_2(x))
x = F.max_pool2d(x,2)
x = self.prelu2_1(self.conv2_1(x))
x = self.prelu2_2(self.conv2_2(x))
x = F.max_pool2d(x,2)
x = self.prelu3_1(self.conv3_1(x))
x = self.prelu3_2(self.conv3_2(x))
x = F.max_pool2d(x,2)
x = x.view(-1, 128*3*3)
ip1 = self.preluip1(self.ip1(x))
ip2 = self.ip2(ip1)
return ip1, F.log_softmax(ip2, dim=1)
続いて、ネットワークです。MNISTは28 X 28のモノクロ画像なので、非常に小さいネットワークで構成されてますね。
なお、ip1というのは畳み込みニューラルネットワークから結合層となった直後に出力値を使い、それを二次元化した特徴ベクトル、ip2というのは10次元のベクトルで、Softmaxをかける前の出力ベクトルのようです。
def visualize(feat, labels, epoch):
c = ['#ff0000', '#ffff00', '#00ff00', '#00ffff', '#0000ff',
'#ff00ff', '#990000', '#999900', '#009900', '#009999']
plt.clf()
for i in range(10):
plt.plot(feat[labels == i, 0], feat[labels == i, 1], '.', c=c[i])
plt.legend(['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'], loc = 'upper right')
plt.xlim(xmin=-8,xmax=8)
plt.ylim(ymin=-8,ymax=8)
plt.text(-7.8,7.3,"epoch=%d" % epoch)
plt.savefig('./images/epoch=%d.jpg' % epoch)
そしてこのvisualizeという関数は、特徴ベクトルの出力点群をグラフ化し、画像として残すためのアルゴリズムです。最初に出したカラフルな点群グラフは、これで作られてます。
def train(epoch):
print("Training... Epoch = %d" % epoch)
ip1_loader = []
idx_loader = []
for i,(data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
ip1, pred = model(data)
loss = nllloss(pred, target) + loss_weight * centerloss(target, ip1)
optimizer4nn.zero_grad()
optimzer4center.zero_grad()
loss.backward()
optimizer4nn.step()
optimzer4center.step()
ip1_loader.append(ip1)
idx_loader.append((target))
feat = torch.cat(ip1_loader, 0)
labels = torch.cat(idx_loader, 0)
visualize(feat.data.cpu().numpy(),labels.data.cpu().numpy(),epoch)
torch.save(model.state_dict(), './mnist.pth')
return feat, labels
で、こちらは学習用の関数。最後に「mnist.pth」という名前でモデルを保存します。
# Dataset
trainset = datasets.MNIST('./MNIST', download=True,train=True, transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))]))
train_loader = DataLoader(trainset, batch_size=128, shuffle=True, num_workers=0)
# データセットの一部を可視化
data_iter = iter(train_loader)
images, labels = data_iter.next()
# matplotlib で1つ目のデータを可視化
for idx, _ in enumerate(labels):
plt.figure(idx+1)
npimg = images[idx].numpy()
npimg = npimg.reshape((28, 28))
plt.imshow(npimg, cmap='Greens')
plt.title('Label: {}'.format(labels[idx]))
ここからは、メイン処理に入ります。
初回のみMNISTのデータセットをダウンロードし、「MNIST」というフォルダに入れます。
一応、中身を確認したいと思ったので、一部を表示させてます。
こんな表示がずらずらっと出てきます。
# Model
model = Net().to(device)
summary(model)
続いてモデル定義です。summaryのところで、ネットワーク図と各層のパラメータ数を表示させてます。
=================================================================
Layer (type:depth-idx) Param #
=================================================================
Net --
├─Conv2d: 1-1 832
├─PReLU: 1-2 1
├─Conv2d: 1-3 25,632
├─PReLU: 1-4 1
├─Conv2d: 1-5 51,264
├─PReLU: 1-6 1
├─Conv2d: 1-7 102,464
├─PReLU: 1-8 1
├─Conv2d: 1-9 204,928
├─PReLU: 1-10 1
├─Conv2d: 1-11 409,728
├─PReLU: 1-12 1
├─PReLU: 1-13 1
├─Linear: 1-14 2,306
├─Linear: 1-15 20
=================================================================
Total params: 797,181
Trainable params: 797,181
Non-trainable params: 0
=================================================================
こんなのが表示されるはずです。
# NLLLoss
nllloss = nn.NLLLoss().to(device) #CrossEntropyLoss = log_softmax + NLLLoss
# CenterLoss
loss_weight = 1
centerloss = CenterLoss(10, 2).to(device)
そして、Loss関数を定義し、
# optimzer4nn
optimizer4nn = optim.SGD(model.parameters(),lr=0.001,momentum=0.9, weight_decay=0.0005)
sheduler = lr_scheduler.StepLR(optimizer4nn,20,gamma=0.8)
# optimzer4center
optimzer4center = optim.SGD(centerloss.parameters(), lr =0.5)
最適化手法を定義したら、
for epoch in range(epochs):
sheduler.step()
feat, labels = train(epoch+1)
学習開始です。
100サイクルほど回しましたが、我が家では大体30分ほどかかりました。
最後にこんな感じの絵が出てくれば、終了です。
一応、各エポックごとの画像もimagesフォルダに出力されます。
3エポック目は、ほとんど未分離だった各分類が、
だんだんと別れ始め、
そして100エポックできれいに分かれました。
これくらい綺麗に分離できていたら、学習は終了です。
学習はここまでですが、ここで自分で作った手書きの画像データを使って、それを「推論」させてみます。
といっても、それぞれの画像が上の点群グラフの中でどのあたりに来るのか?を見るだけです。
def visualize_pred(feat, labels, out_feat):
c = ['#ff0000', '#ffff00', '#00ff00', '#00ffff', '#0000ff',
'#ff00ff', '#990000', '#999900', '#009900', '#009999']
plt.clf()
for i in range(10):
plt.plot(feat[labels == i, 0], feat[labels == i, 1], '.', c=c[i])
plt.plot(out_feat[0],out_feat[1], marker="*", markersize=30, markerfacecolor="r")
plt.legend(['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'], loc = 'upper right')
plt.xlim(xmin=-11,xmax=11)
plt.ylim(ymin=-11,ymax=11)
plt.text(-10.8,10.3,"test image")
plt.savefig('./images/test_pred.jpg')
まずは、推論用の可視化関数を作ります。
test_model = Net()
test_model.load_state_dict(torch.load('./mnist.pth'))
test_model.eval()
# 画像の読み込みと前処理
image = Image.open('test_1.png').convert('L')
transform = transforms.Compose([
transforms.Resize((28, 28)),
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
])
image = transform(image)
image = image.unsqueeze(0)
# 推論の実行
with torch.no_grad():
output = test_model(image)
output_xy = (output[0].data.cpu().numpy())[0]
print(output_xy[0], output_xy[1])
visualize_pred(feat.data.cpu().numpy(),labels.data.cpu().numpy(), output_xy)
ここで「test_1.jpg」という画像を読み込み、その結果をプロットさせます。
ちなみに「test_1.jpg」という画像は以下。
はい、数字の「1」ですね。
これを推論させたら、どうなるか?
こうなりました。赤い星が見えるかと思いますが、それが先の画像の推論結果です。
ちなみに、黄色い点群が「1」で、この赤い星印はその黄色い点群の中にいます。
要するに、「1」だと認識されていることが分かります。
それじゃ、気を取り直してこんどは「あ」という文字をぶち込んでみます。
0から9の数字とは似ても似つかないこの文字、果たして結果はどうなるか?
こうなりました。どの点群にも入らず、「1」と「5」の間よりちょっと原点寄りにいますね。
要するに、「その他」と認識されたってことですね。
今度は、アルファベットの「A」を入れてみました。
うーん、「8」と「3」の間辺り?まあ、「その他」のゾーンですね。
しかし、困ったのはこの数字の「3」。
うーん、原点から見れば、「3」の集団である水色の点群の先にいるんですが、ちょっと離れすぎちゃいませんかね?
実は、いくつかの手書き数字で確かめたんですが、この「3」のような微妙な結果になることが多かったです。
やや、過学習気味でしたかね?
とにもかくにも、距離学習をイメージできるプログラムができました。
ところで、今はこのCenter Lossと言うのはあまり使われていないようで、もっぱらよく使われるのは「SphereFace」、「CosFace」、「ArcFace」あたりでしょうか。このCenter Lossとはまた違う距離の離し方をしますが、分類ごとの距離を離すよう学習させるという点では同じようなものです。
詳しくは、以下をご参照願います。
モダンな深層距離学習 (deep metric learning) 手法: SphereFace, CosFace, ArcFace - Qiita
ちなみに、この深層距離学習というのは異常検知手法や顔認証のようなもので使われてます。私の場合はArcfaceを使って異常検知させるものをよく作ってます。はい。
コメント