2017年,Hinton团队提出了胶囊网络(Capsule Network),首次将神经元的输出从标量扩展为矢量,并使用动态路由(dynamic routing)算法计算胶囊之间的信息传递。这种矢量神经元被认为能够保留物体的姿态信息,从而为神经网络带来等变性(equivariance)。本着learning by doing的态度,笔者尝试复现这篇论文。本文不会过多解释原论文的原理和思想,而是在保证工程性和完整性的同时,尽可能记录自己在实现过程中的总结和反思。Anyway,实现过程中也许会有一些bug,欢迎交流和提交issue~
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from torchvision.utils import save_image
def dynamic_routing(x, iterations=3):
    """Route prediction vectors u_hat to the next capsule layer by agreement.

    Args:
        x: u_hat, prediction vectors of shape (B, N1, N, D, 1), where N is the
           number of lower-level capsules and N1 the number of higher-level
           capsules, e.g. (B, 10, 32*6*6, 16, 1) for the digit-capsule layer.
        iterations: number of routing iterations; must be >= 1.

    Returns:
        v: next layer output, shape (B, N1, D), e.g. (B, 10, 16).

    Raises:
        ValueError: if ``iterations < 1`` (no output would be computed).
    """
    if iterations < 1:
        raise ValueError(f"iterations must be >= 1, got {iterations}")
    B, N1, N = x.shape[:3]
    # Routing logits b start at zero, so the first round spreads every lower
    # capsule uniformly over the higher capsules. Match x's device and dtype
    # so the matmuls below also work for non-float32 inputs.
    b = torch.zeros(B, N1, N, 1, 1, device=x.device, dtype=x.dtype)
    for r in range(iterations):
        # Coupling coefficients: softmax over the higher-layer capsules
        # (dim=1), so each lower capsule's coefficients sum to 1.
        # (B, N1, N, 1, 1)
        c = F.softmax(b, dim=1)
        # Weighted sum of predictions over the lower capsules. (B, N1, D)
        s = torch.sum(x.matmul(c), dim=2).squeeze(-1)
        # (B, N1, D)
        v = squash(s)
        if r < iterations - 1:
            # Agreement v_j . u_hat_{j|i} is added to the logits.
            # (B, N1, N, 1, 1). Skipped on the last round: b is not used again.
            b = b + v[:, :, None, None, :].matmul(x)
    return v
# Smoke check: route random predictions for a batch of one sample and
# verify the output shape (B, 10, 16).
x = torch.rand(1, 10, 32 * 6 * 6, 16, 1)
out_shape = dynamic_routing(x).shape
assert out_shape == (1, 10, 16), out_shape
print(out_shape)
def margin_loss(y, y_hat, m_plus=0.9, m_minus=0.1, lambda_=0.5):
    """Compute the CapsNet margin loss, summed over classes and batch.

    For every class k with capsule length ||v_k||:

        L_k = T_k * max(0, m_plus - ||v_k||)^2
              + lambda_ * (1 - T_k) * max(0, ||v_k|| - m_minus)^2

    where T_k is 1 for the ground-truth class and 0 otherwise.

    Args:
        y: ground-truth labels, integer tensor of shape (B,).
        y_hat: class capsules of shape (B, C, D), e.g. (B, 10, 16).
        m_plus: lower margin for the capsule length of the true class.
        m_minus: upper margin for the capsule lengths of the other classes.
        lambda_: down-weighting factor of the absent-class term.

    Returns:
        Scalar tensor: the loss summed over all samples and classes.
    """
    y_norm = y_hat.norm(dim=-1)  # capsule lengths, (B, C)
    nclasses = y_norm.shape[-1]
    # One-hot mask of the true class, used to select each term per class.
    T = F.one_hot(y, nclasses).to(y_norm.dtype)  # (B, C)
    # Present-class term: only the true class is penalized for being short.
    # The mask must multiply outside the hinge; otherwise every absent class
    # contributes a constant m_plus**2.
    right = T * F.relu(m_plus - y_norm) ** 2
    # Absent-class term: every other class is penalized for being long.
    wrong = lambda_ * (1 - T) * F.relu(y_norm - m_minus) ** 2
    return torch.sum(right + wrong)
def test_margin_loss():
    """Smoke-test margin_loss on random capsules and print the loss value."""
    batch, nclasses, capsule_dim = 20, 10, 16
    y = torch.randint(0, nclasses, (batch,))
    y_hat = torch.rand(batch, nclasses, capsule_dim)
    loss = margin_loss(y, y_hat)
    # A sum of squared hinge terms must be a finite, non-negative scalar.
    assert loss.dim() == 0
    assert torch.isfinite(loss)
    assert loss.item() >= 0
    print(loss.item())
# Test the model
def evaluate(model, test_loader):
    """Return the classification accuracy of a capsule model on test_loader.

    The predicted class is the capsule with the largest length (L2 norm over
    the last dimension of the model output). Batches are moved to the global
    ``device``. The model's train/eval mode is restored before returning.

    Args:
        model: module mapping a batch of images to class capsules; assumed
            output shape (B, num_classes, D), e.g. (B, 10, 16).
        test_loader: iterable of (images, labels) batches.

    Returns:
        Fraction of correctly classified samples, in [0, 1].
    """
    was_training = model.training
    # eval mode (batchnorm uses moving mean/variance instead of mini-batch
    # mean/variance)
    model.eval()
    correct = 0
    total = 0
    try:
        with torch.no_grad():
            for images, labels in test_loader:
                images = images.to(device)
                labels = labels.to(device)
                lengths = model(images).norm(dim=-1)  # (B, num_classes)
                predicted = lengths.argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
    finally:
        # Restore the caller's mode so a training loop that evaluates
        # periodically keeps training with dropout/batchnorm in train mode.
        model.train(was_training)
    return correct / total
# Evaluate the trained capsule encoder on the test split.
acc = evaluate(encoder, test_loader)
print(
    'Test Accuracy of the model on the 10000 test images: {:.2f}%'.format(100 * acc)
)
# Output: Test Accuracy of the model on the 10000 test images: 98.62%
# Evaluate the encoder again; presumably `encoder` now refers to the encoder
# of the trained autoencoder (TODO confirm against the preceding cells).
acc = evaluate(encoder, test_loader)
print('Test accuracy of the model on the 10000 test images of encoder in AE: {:.2f}%'.format(100 * acc))
# Output: Test Accuracy of the model on the 10000 test images of encoder in AE: 98.80%
def evaluate_class_capsule(model, x, y, delta=1, dim=0, l=5):
    """Reconstruct images after shifting one dimension of the true-class capsule.

    The capsule of class ``y`` is shifted along dimension ``dim`` by
    ``i * delta`` for every integer ``i`` in ``[-l, l]``, and each shifted
    capsule is decoded. The model's train/eval mode is restored afterwards.

    Args:
        model: autoencoder exposing ``encoder`` (images -> class capsules of
            shape (B, num_classes, D), e.g. (B, 10, 16)) and ``decoder``
            (capsule (B, D) -> image).
        x: input images, e.g. (B, 1, 28, 28).
        y: class index per sample, shape (B,); selects the capsule to shift.
        delta: step size of the shift.
        dim: which capsule dimension to perturb (0 <= dim < D).
        l: number of steps on each side of the unshifted capsule.

    Returns:
        List of length 2*l + 2: ``x`` itself followed by the 2*l + 1
        reconstructions, ordered from offset ``-l*delta`` to ``+l*delta``.
    """
    was_training = model.training
    model.eval()
    encoder, decoder = model.encoder, model.decoder
    try:
        with torch.no_grad():
            class_capsules = encoder(x)  # (B, num_classes, D)
            batch = x.shape[0]
            rows = torch.arange(batch, device=class_capsules.device)
            selected_capsules = class_capsules[rows, y]  # (B, D)
            capsule_dim = selected_capsules.shape[-1]
            # Unit vector along `dim`, shape (1, D); broadcasts over the batch.
            # Built on the capsules' own device/dtype instead of a global.
            direction = F.one_hot(
                torch.tensor([dim], device=selected_capsules.device),
                num_classes=capsule_dim,
            ).to(selected_capsules.dtype)
            reconstructed_xs = [
                decoder(selected_capsules + step * delta * direction)
                for step in range(-l, l + 1)
            ]
    finally:
        # Do not leave the caller's model stuck in eval mode.
        model.train(was_training)
    reconstructed_xs.insert(0, x)
    return reconstructed_xs
def research_for_class_capsule_for(i=0, delta=0.5):
    """Plot how shifting each class-capsule dimension changes the reconstruction.

    Takes the i-th image of the first batch of the global ``test_loader``.
    For each of the 16 capsule dimensions, it calls ``evaluate_class_capsule``
    with the global ``autoencoder``. Row k of the figure shows dimension k:
    column 0 is the input image, the remaining columns are reconstructions.

    Args:
        i: index of the test image within the first batch.
        delta: step size of the per-dimension shift.
    """
    capsule_dims = 16
    images, labels = next(iter(test_loader))
    X = images[i][None, ...].to(device)
    y = labels[i][None, ...].to(device)
    result = [
        evaluate_class_capsule(autoencoder, X, y, delta=delta, dim=dim)
        for dim in range(capsule_dims)
    ]
    fg, axs = plt.subplots(
        nrows=capsule_dims,
        ncols=len(result[0]),
        gridspec_kw={'hspace': 0, 'wspace': 0.1},
        figsize=(13, 13),
    )
    fg.suptitle(f'research for each dim in capsule, delta={delta}')
    # Separate row/col names so the image index `i` is not shadowed.
    for row, images_for_dim in enumerate(result):
        for col, img in enumerate(images_for_dim):
            axs[row, col].imshow(img.squeeze().cpu(), cmap='binary')
            axs[row, col].axis('off')
    plt.show()
部分效果
1 2
# Visualize per-dimension capsule perturbations for the first 10 test
# images with a small step (delta=0.05).
for i in range(10):
    research_for_class_capsule_for(i, delta=0.05)
1 2
# Same visualization for the first 10 test images with a larger step
# (delta=0.1).
for i in range(10):
    research_for_class_capsule_for(i, delta=0.1)