0%

动手实现《stacked capsule autoencoders》

前言

本次动手实现论文《stacked capsule autoencoders》的pytorch版本。这篇论文的原作者开源了TensorFlow版本[1],其细节和工程性都挺不错,是个参考的好范本(做研究建议直接参考原项目)。关于pytorch的实现,github也开源了相关例子[2,3,4],但这些都只实现了原文第二个实验。本文是对其原文第一个实验的复现笔记,后续也计划复现第二个实验。

全部复现代码会开源在https://github.com/QiangZiBro/stacked_capsule_autoencoders.pytorch,欢迎提issue。

复现目标

第一个实验

  • Set Transformer (直接使用原论文代码)
  • CCAE模型
  • 高斯混合模型的编程实现
  • Concellation数据集生成
  • CCAE训练
  • 可视化CCAE

前期准备

环境

  • 系统:ubuntu 18.04.04
  • 显卡:GP100
  • 环境管理:miniconda3
  • 相关第三方库:pytorch1.7

为了保证工程性以及少点重复工作,我们基于一个深度学习模板项目来进行本次实现。当然,为了可解释性,也会使用notebook进行相关可视化。同时会写一些必备的test case,来帮助我更加了解一些细节。

1
2
3
4
5
git clone https://github.com/QiangZiBro/pytorch-template
cd pytorch-template
python new_project.py ../stacked_capsule_autoencoders.pytorch
cd ../stacked_capsule_autoencoders.pytorch
wget https://raw.githubusercontent.com/github/gitignore/master/Python.gitignore -O .gitignore

模型细节

概览

CCAE编码器为Set Transformer,解码器为mlp的自编码器,其输入是2维平面上的点集。

我们接下来总结CCAE模型的细节

Set Transformer


Set Transformer的编码器可以是连续的SAB或者连续的ISAB。使用ISAB的优点是其使用了诱导点$I \in \mathbb{R}^{m \times d}$,使得计算参数比SAB更少。
$$
Z=\operatorname{Encoder}(X)=\operatorname{SAB}(\operatorname{SAB}(X)) \in \mathbb{R}^{n \times d}
$$

$$
Z=\operatorname{Encoder}(X)=\operatorname{ISAB}{m}\left(\operatorname{ISAB}{m}(X)\right) \in \mathbb{R}^{n \times d}
$$

解码器
$$
O=\operatorname{Decoder}(Z ; \lambda)=\operatorname{rFF}\left(\operatorname{SAB}\left(\operatorname{PMA}_{k}(Z)\right)\right) \in \mathbb{R}^{k \times d}
$$


细节部分

  • $\operatorname{rFF}(x)$ 全连接 ,具体的讲,输入$n \times d$维,输出也是$n \times d$维。
  • 注意力机制:$\operatorname{Att}(Q, K, V ; \omega)=\omega\left(Q K^{\top}\right) V \in \mathbb{R}^{n \times d_v}$,其中,$\omega$是激活函数。
  • 多头注意力机制

$$
\operatorname{Multihead} (Q, K, V ; \lambda, \omega) = \operatorname{concat(Z_1,…,Z_h)}W^O \in \mathbb{R}^{n \times d},\ Z_{j}=\operatorname{Att}\left(Q W_{j}^{Q}, K W_{j}^{K}, V W_{j}^{V} ; \omega_{j}\right)​
$$

  • 多头注意力模块(Multihead Attention Block) 输出维度和X维度相同
    $$
    \operatorname{MAB}(X, Y)=\operatorname{LayerNorm}(H+\operatorname{rFF}(H))
    $$
    其中$H=\text { LayerNorm }(X+\operatorname{Multihead}(X, Y, Y ; \omega))$

  • 集合注意力模块(Set Attention Block ),计算复杂度$\mathcal{O}\left(n^{2}\right)$

$$
\operatorname{SAB}(X) = \operatorname{MAB}(X,X)
$$

  • 诱导集合注意力模块(Induced Set Attention Block )

$$
\operatorname{ISAB}_m(X)=\operatorname{MAB}(X, H) \in \mathbb{R}^{n \times d}​
$$

​ 其中$H=\operatorname{MAB}(I,X) \in \mathbb{R}^{m \times d}$,$I \in \mathbb{R}^{m \times d}$为可学习参数。

  • 多头注意力机制的池化(Pooling by Multihead Attention)。池化是一种常见的聚合(aggregation)操作。上面提到,池化可以是最大或是平均。这里提出的池化是应用一个MAB在一个可学习的矩阵$S \in \mathbb{R}^{k \times d}$上。在一些聚类任务上,$k$设为我们需要的类别数。使用基于注意力的池化的直觉是,每个实例对target的重要性都不一样

$$
\operatorname{PMA}_{k}(Z)=\operatorname{MAB}(S, \operatorname{rFF}(Z))
$$

$$
H=\operatorname{SAB}\left(\operatorname{PMA}_{k}(Z)\right)
$$

其中池化操作$\operatorname{PMA}_{k}(Z)=\operatorname{MAB}(S, \operatorname{rFF}(Z)) \in \mathbb{R}^{k \times d}$,$k$表示输出集合中实例的个数,$k < n$。

CCAE

对M个2维输入点组成的集合$\mathbf{x_{1:M}}$,首先使用Set Transformer将这个集合编码为$K$个$(2\times 2+n_c+1)$的object向量,这三个数分别表示OV矩阵大小、特殊向量(即特征)、存在概率。特殊向量的尺度是个超参,原文$n_c=16$。
$$
\mathrm{OV}{1: K}, \mathbf{c}{1: K}, a_{1: K}=\mathrm{h}^{\mathrm{caps}}\left(\mathbf{x}{1: M}\right) = \operatorname{SetTransformer}\left(\mathbf{x}{1: M}\right)
$$
对每个object向量,取其特殊向量,通过mlp解码出$N$个part。其中,每个part长度为$(2+1+1)$,分别为OP矩阵、存在概率、和标准差;每个object应用一个单独的mlp,mlp结构为$n_c,128,(2+1+1)\times N$。
$$
\mathrm{OP}{k, 1: N}, a{k, 1: N}, \lambda_{k, 1: N}=\mathrm{h}{\mathrm{k}}^{\mathrm{part}}\left(\mathbf{c}{k}\right) = \operatorname{mlp_k}\left(\mathbf{c}_{k}\right)
$$
在原文例子中,$M=3, N=4$。

每个解码出的part都可以表示一个高斯分量。CCAE处理的数据是2维平面点,因此表示的高斯分量的均值是2维,协方差矩阵大小是$2 \times 2$的矩阵。具体的讲,由第$i$个object产生的第$j$个part表示的高斯分量均值为
$$
\mu_{k,n} = OV_k OP_{k,n}
$$
其中$OV_k$是$2 \times 2$的矩阵,$OP_{k,n}$是长度为2的向量。而part只有一个标量的标准差$\lambda_{k,n}$,即,原文将一个高斯分量假设为各向同性,通过标准差$\lambda_{k,n}$计算到高斯模型的协方差矩阵:
$$
\Sigma_{k,n} = \begin{bmatrix}
\frac{1}{\lambda_{k,n}} & 0\
0 & \frac{1}{\lambda_{k,n}}
\end{bmatrix}
$$
对于这个高斯分量的存在概率,表示为
$$
\pi_{k,n} = \frac{a_{k} a_{k, n}}{\sum_{i} a_{i} \sum_{j} a_{i, j}}
$$
因此,给定每个高斯模型的三个参数:均值,协方差,概率。可以得到给定数据分布在整个高斯混合模型上的估计为:
$$
p\left(\mathbf{x}{1: M}\right)=\prod{m=1}^{M} \sum_{k=1}^{K} \sum_{n=1}^{N} \frac{a_{k} a_{k, n}}{\sum_{i} a_{i} \sum_{j} a_{i, j}} p\left(\mathbf{x}{m} \mid k, n\right)
$$
其中,点$\mathbf{x}
{m}$在第$i$个object产生的第$j$个part表示的高斯计算得到的似然值为
$$
p(\mathbf{x}{m}|k,n) = p(\mathbf{x}{m}|\mu_{k,n} ,\Sigma_{k,n}) = \frac{1}{(2 \pi)^{D/2} |\Sigma_{k,n}|^{1/2}}\operatorname{exp}\left(\frac{1}{2} \left(\mathbf{x}{m}-\mu{k,n} \right)^T \Sigma_{k,n}^{-1} \left(\mathbf{x}{m}-\mu{k,n} \right) \right)
$$
最大化$p\left(\mathbf{x}{1: M}\right)$,求得$\mu{k,n},\lambda_{k,n},\pi_{k,n}$,在理论上即可得到表示这个数据分布的模型。原文使用反向传播优化参数,目标是最大化$\operatorname{log }p\left(\mathbf{x}{1: M}\right)$,等价于最小化$-\operatorname{log }p\left(\mathbf{x}{1: M}\right)$。

数据集

数据(3个集群,两个正方形,一个三角形)是在线创建的,每一次创建后被随机平移、放缩、旋转到180度,最后所有点被标准化到-1到1之间。

决策

依据object $ a_{k}$和其概率最高的part $a_{k,n}$,对每个点$x_m$,其类别决策为$k^{\star}=\arg \max {k} a{k} a_{k, n} p\left(\mathbf{x}_{m} \mid k, n\right)$。

一些实现细节

  • 多维矩阵 首先要明白将所有的part、object放在一个矩阵里,每个维度的含义。笔者设定:part为(B, n_objects, n_votes, (dim_input+1+1)),object为(B, n_objects, dim_input**2+dim_speical_features+1)。搞定这些,之后可以进行矩阵拆分,对应到原论文对应的变量里。

  • BatchMLP 在计算object到part的解码时用到。每个object capsule需要一个单独的MLP来解码到对应的part capsule,也就是说,输入的object维度为[B, n_objects, n_special_features],被多个MLP计算得到结果应该是(B, n_objects, n_votes*(dim_input+1+1))。pytorch里面只有单个的MLP,我们类似原作者也实现了个BatchMLP来完成这个功能。

  • 对概率的处理 对预测的$a_k$和$a_{k,n}$使用softmax等函数进行处理,对预测的标准差加上一个$\epsilon=10^{-6}$防止分母为0.

代码部分

Set Transformer

关于Set Transformer的实现如下,笔者做了相关注释,具体每个模块实现这里不贴。简而言之,这个编码器将(B, N, dim_input)的输入转化为(B, num_outputs, dim_output)的输出。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import torch.nn as nn
from base import BaseModel
from model.modules.setmodules import ISAB,SAB,PMA


class SetTransformer(BaseModel):
"""
"""

def __init__(self, dim_input, num_outputs, dim_output,
num_inds=32, dim_hidden=128, num_heads=4, ln=True):
"""Set Transformer, An autoencoder model dealing with set data

Input set X with N elements, each `dim_input` dimensions, output
`num_outputs` elements, each `dim_output` dimensions.

In short, choose:
N --> num_outputs
dim_input --> dim_output

Hyper-parameters:
num_inds
dim_hidden
num_heads
ln
Args:
dim_input: Number of dimensions of one elem in input set X
num_outputs: Number of output elements
dim_output: Number of dimensions of one elem in output set
num_inds: inducing points number
dim_hidden: output dimension of one elem of middle layer
num_heads: heads number of multi-heads attention in MAB
ln: whether to use layer norm in MAB
"""
super(SetTransformer, self).__init__()
self.enc = nn.Sequential(
ISAB(dim_input, dim_hidden, num_heads, num_inds, ln=ln),
ISAB(dim_hidden, dim_hidden, num_heads, num_inds, ln=ln))
self.dec = nn.Sequential(
PMA(dim_hidden, num_heads, num_outputs, ln=ln),
SAB(dim_hidden, dim_hidden, num_heads, ln=ln),
SAB(dim_hidden, dim_hidden, num_heads, ln=ln),
nn.Linear(dim_hidden, dim_output))

def forward(self, X):
"""
Args:
X: (B, N, dim_input)

Returns:
output set with shape (B, num_outputs, dim_output)
"""
return self.dec(self.enc(X))

CCAE

编码器核心部分如下,可以看到可以类似17版那样用矢量表示胶囊,不过这里每个胶囊用三种不同意义的变量表示,因此后续处理也不同。

1
2
3
4
5
6
objects = self.set_transformer(x)  # (B, n_objects, dim_input**2+dim_speical_features+1)
splits = [self.dim_input**2,self.dim_input**2+self.dim_speical_features]
ov_matrix,special_features,presence=objects[:,:,:splits[0]],objects[:,:,splits[0]:splits[1]],objects[:,:,splits[1]:]

ov_matrix = ov_matrix.reshape(B, self.n_objects, self.dim_input, self.dim_input)
presence = F.softmax(presence, dim=1)

解码器,注意到这里使用了一个BatchMLP,即使用多个MLP对每个object的特殊向量进行解码,每个object都可以解码出若干个part。

1
2
3
4
5
6
7
8
9
10
x = self.bmlp(x) # (B, n_objects, n_votes*(dim_input+1+1))
x_chunk = x.chunk(self.n_votes, dim=-1)
x_object_part = torch.stack(x_chunk, dim=2) # (B, n_objects, n_votes, (dim_input+1+1))

splits = [self.dim_input, self.dim_input+1]
op_matrix = x_object_part[:,:,:,:splits[0]]
standard_deviation = x_object_part[:,:,:,splits[0]:splits[1]]
presence = x_object_part[:,:,:,splits[1]:]
presence = F.softmax(presence, dim=2)

使用无监督的决策方式,参考上文原理部分

1
2
3
4
5
6
7
8
9
# (B, 1, n_objects, 1)
object_presence = res_dict.object_presence[:, None, ...]
# (B, 1, n_objects, n_votes)
part_presence = res_dict.part_presence[:, None, ...].squeeze(-1)
# (B, M, n_objects, n_votes)
likelihood = res_dict.likelihood
a_k_n_times_p = (part_presence * likelihood).max(dim=-1, keepdim=True)[0]
expr = object_presence * a_k_n_times_p
winners = expr.max(dim=-2)[1].squeeze(-1)

数据集

这里直接复用了原本数据生成代码,搭建了一个Dataloader

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
class CCAE_Dataloader(BaseDataLoader):
def __init__(self,
# for dataloader
batch_size,
shuffle=True,
validation_split=0.0,
num_workers=1,

# for dataset
shuffle_corners=True,
gaussian_noise=0.,
max_translation=1.,
rotation_percent=0.0,
which_patterns='basic',
drop_prob=0.0,
max_scale=3.,
min_scale=.1
):
self.dataset = CCAE_Dataset(
shuffle_corners,
gaussian_noise,
max_translation,
rotation_percent,
which_patterns,
drop_prob,
max_scale,
min_scale
)
super().__init__(self.dataset, batch_size, shuffle, validation_split, num_workers)

class CCAE_Dataset(data.Dataset):
def __init__(self,
shuffle_corners=True,
gaussian_noise=0.,
max_translation=1.,
rotation_percent=0.0,
which_patterns='basic',
drop_prob=0.0,
max_scale=3.,
min_scale=.1
):
self.shuffle_corners = shuffle_corners
self.scale = max_scale
self.gaussian_noise = gaussian_noise
self.max_translation = max_translation
self.rotation_percent = rotation_percent
self.which_patterns = which_patterns
self.drop_prob = drop_prob

def __len__(self):
return 10000
def __getitem__(self, item):
data = create_numpy(
1,
self.shuffle_corners,
self.gaussian_noise,
self.max_translation,
self.rotation_percent,
self.scale,
self.which_patterns,
self.drop_prob)
return data

损失

损失函数的计算方式如下

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
def ccae_loss(res_dict, target, epsilon = 1e-6):
"""

Args:
res_dict:
target: input set with (B, k, dim_input)
epsilon: avoiding nan for reciprocal of standard deviation
Returns:
log likelihood for input dataset(here "target") , (B,)
"""
# retrieve the variable (Sorry for possible complication)
op_matrix = res_dict.op_matrix # (B, n_objects, n_votes, dim_input)
ov_matrix = res_dict.ov_matrix # (B, n_objects, dim_input, dim_input)
standard_deviation = res_dict.standard_deviation # (B, n_objects, n_votes, 1)
object_presence = res_dict.object_presence # (B, n_objects, 1)
part_presence = res_dict.part_presence # (B, n_objects, n_votes, 1)
dim_input = res_dict.dim_input
B, n_objects, n_votes, _ = standard_deviation.shape
op_matrix = op_matrix[:,:,:,:,None] # (B, n_objects, n_votes, dim_input,1)
ov_matrix = ov_matrix[:,:,None,:,:] # (B, n_objects, 1, dim_input,dim_input)

# 防止分母为0
standard_deviation = epsilon + standard_deviation[Ellipsis, None]
# 计算mu
mu = ov_matrix @ op_matrix # (B, n_objects, n_votes, dim_input,1)
# 计算协方差
identity = torch.eye(dim_input).repeat(B, n_objects, n_votes, 1, 1).to(standard_deviation.device)
sigma = identity * (1/standard_deviation) # (B, n_objects, n_votes, dim_input,dim_input)

# 计算数据集(即target)在混合模型上的似然估计
# (B, k, n_objects, n_votes)
gaussian_likelihood = gmm(mu, sigma).likelihood(target, object_presence=object_presence, part_presence=part_presence)

# 计算似然估计的对数,作为损失目标
log_likelihood = torch.log(gaussian_likelihood.sum((1,2,3))).mean()
gaussian_likelihood = gaussian_likelihood.mean()
res_dict.likelihood = -gaussian_likelihood
res_dict.log_likelihood = -log_likelihood



return res_dict

笔者又实现了一个高斯混合模型类来计算似然值,下面是计算损失的核心代码。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
mu = ov_matrix @ op_matrix  # (B, n_objects, n_votes, dim_input,1)
identity = (
torch.eye(dim_input)
.repeat(B, n_objects, n_votes, 1, 1)
.to(standard_deviation.device)
)
sigma = identity * (
1 / standard_deviation
) # (B, n_objects, n_votes, dim_input,dim_input)

# (B, k, n_objects, n_votes)
likelihood = gmm(mu, sigma).likelihood(
target, object_presence=object_presence, part_presence=part_presence
)
log_likelihood = torch.log(likelihood.sum((1, 2, 3))).mean()

后续思考,这个损失函数有点写复杂了,直接在model里算好就不需要这么多代码了。

高斯混合模型的核心实现

1
class GuassianMixture(object):    """    GMM for part capsules    """    def __init__(self, mu, sigma):        """        Args:            mu: (B, n_objects, n_votes, dim_input, 1)            sigma: (B, n_objects, n_votes, dim_input, dim_input)        After initialized:            mu:   (B, 1, n_objects, n_votes, dim_input, 1)            sigma:(B, 1, n_objects, n_votes, dim_input,dim_input)            multiplier:(B, 1, n_objects, n_votes, 1, 1)        """        #  Converse shape to        #  (Batch_size, num_of_points, num_of_objects, number_of_votes, ...)        mu = mu[:, None, ...]  # (B, 1, n_objects, n_votes, dim_input, 1)        sigma = sigma[:, None, ...]  # (B, 1, n_objects, n_votes, dim_input,dim_input)        self.sigma = sigma        self.mu = mu        self.sigma_inv = sigma.inverse()        D = self.sigma.shape[-1]        sigma_det = torch.det(sigma)  # (B, 1, n_objects, n_votes)        self.multiplier = (            1 / ((2 * math.pi) ** (D / 2) * sigma_det.sqrt())[..., None, None]        )    def likelihood(self, x, object_presence=None, part_presence=None):        diff = x - self.mu        exp_result = torch.exp(-0.5 * diff.transpose(-1, -2) @ self.sigma_inv @ diff)        denominator = object_presence.sum(dim=2, keepdim=True) * part_presence.sum(          dim=3, keepdim=True        )        exp_result = (object_presence * part_presence / denominator) * exp_result        gaussian_likelihood = self.multiplier * exp_result        return gaussian_likelihood.squeeze(-1).squeeze(-1)    def plot(self, choose):        raise NotImplemented

目前的效果

  • 正确分类

原数据

image-20201204231106857

无监督分类结果

image-20201204231147582

  • 错误分类

image-20201204231236948

总结

本文使用pytorch实现了原论文第一个toy experiment,做了一个简单的展示,损失使用的是
$$
p\left(\mathbf{x}{1: M}\right)=\prod{m=1}^{M} \sum_{k=1}^{K} \sum_{n=1}^{N} \frac{a_{k} a_{k, n}}{\sum_{i} a_{i} \sum_{j} a_{i, j}} p\left(\mathbf{x}_{m} \mid k, n\right)
$$
未使用原文提出的sparsity loss。

工程方面

  • 参数传递的过程中,形状为1的维度应该压缩掉

实验方面

  • 写了个重大BUG:BatchMLP忘记使用激活,梯度变为nan,还是对激活函数的理解程度不深,写MLP竟然忘记带了
  • 还需要对无监督效果进行评估,

TODO

  • 实现无监督评估方法

  • 尝试用这个模型做指导性学习

参考资料

[1] https://github.com/google-research/google-research/tree/master/stacked_capsule_autoencoders

[2] https://github.com/phanideepgampa/stacked-capsule-networks

[3] https://github.com/MuhammadMomin93/Stacked-Capsule-Autoencoders-PyTorch

[4] https://github.com/Axquaris/StackedCapsuleAutoencoders

[5] Fitting a generative model using standard divergences between measures http://www.math.ens.fr/~feydy/Teaching/DataScience/fitting_a_generative_model.html

前言

本次动手实现论文《stacked capsule autoencoders》的pytorch版本。这篇论文的原作者开源了TensorFlow版本[1],其细节和工程性都挺不错,是个参考的好范本(做研究建议直接参考原项目)。关于pytorch的实现,github也开源了相关例子[2,3,4],但这些都只实现了原文第二个实验。本文是对其原文第一个实验的复现笔记,后续也计划复现第二个实验。

全部复现代码会开源在https://github.com/QiangZiBro/stacked_capsule_autoencoders.pytorch,欢迎提issue。

复现目标

第一个实验

  • Set Transformer (直接使用原论文代码)
  • CCAE模型
  • 高斯混合模型的编程实现
  • Concellation数据集生成
  • CCAE训练
  • 可视化CCAE

前期准备

环境

  • 系统:ubuntu 18.04.04
  • 显卡:GP100
  • 环境管理:miniconda3
  • 相关第三方库:pytorch1.7

为了保证工程性以及少点重复工作,我们基于一个深度学习模板项目来进行本次实现。当然,为了可解释性,也会使用notebook进行相关可视化。同时会写一些必备的test case,来帮助我更加了解一些细节。

1
2
3
4
5
git clone https://github.com/QiangZiBro/pytorch-template
cd pytorch-template
python new_project.py ../stacked_capsule_autoencoders.pytorch
cd ../stacked_capsule_autoencoders.pytorch
wget https://raw.githubusercontent.com/github/gitignore/master/Python.gitignore -O .gitignore

模型细节

概览

CCAE编码器为Set Transformer,解码器为mlp的自编码器,其输入是2维平面上的点集。

我们接下来总结CCAE模型的细节

Set Transformer


Set Transformer的编码器可以是连续的SAB或者连续的ISAB。使用ISAB的优点是其使用了诱导点$I \in \mathbb{R}^{m \times d}$,使得计算参数比SAB更少。
$$
Z=\operatorname{Encoder}(X)=\operatorname{SAB}(\operatorname{SAB}(X)) \in \mathbb{R}^{n \times d}
$$

$$
Z=\operatorname{Encoder}(X)=\operatorname{ISAB}{m}\left(\operatorname{ISAB}{m}(X)\right) \in \mathbb{R}^{n \times d}
$$

解码器
$$
O=\operatorname{Decoder}(Z ; \lambda)=\operatorname{rFF}\left(\operatorname{SAB}\left(\operatorname{PMA}_{k}(Z)\right)\right) \in \mathbb{R}^{k \times d}
$$


细节部分

  • $\operatorname{rFF}(x)$ 全连接 ,具体的讲,输入$n \times d$维,输出也是$n \times d$维。
  • 注意力机制:$\operatorname{Att}(Q, K, V ; \omega)=\omega\left(Q K^{\top}\right) V \in \mathbb{R}^{n \times d_v}$,其中,$\omega$是激活函数。
  • 多头注意力机制

$$
\operatorname{Multihead} (Q, K, V ; \lambda, \omega) = \operatorname{concat(Z_1,…,Z_h)}W^O \in \mathbb{R}^{n \times d},\ Z_{j}=\operatorname{Att}\left(Q W_{j}^{Q}, K W_{j}^{K}, V W_{j}^{V} ; \omega_{j}\right)
$$

  • 多头注意力模块(Multihead Attention Block) 输出维度和X维度相同
    $$
    \operatorname{MAB}(X, Y)=\operatorname{LayerNorm}(H+\operatorname{rFF}(H))
    $$
    其中$H=\text { LayerNorm }(X+\operatorname{Multihead}(X, Y, Y ; \omega))$

  • 集合注意力模块(Set Attention Block ),计算复杂度$\mathcal{O}\left(n^{2}\right)$

$$
\operatorname{SAB}(X) = \operatorname{MAB}(X,X)
$$

  • 诱导集合注意力模块(Induced Set Attention Block )

$$
\operatorname{ISAB}_m(X)=\operatorname{MAB}(X, H) \in \mathbb{R}^{n \times d}​
$$

​ 其中$H=\operatorname{MAB}(I,X) \in \mathbb{R}^{m \times d}$,$I \in \mathbb{R}^{m \times d}$为可学习参数。

  • 多头注意力机制的池化(Pooling by Multihead Attention)。池化是一种常见的聚合(aggregation)操作。上面提到,池化可以是最大或是平均。这里提出的池化是应用一个MAB在一个可学习的矩阵$S \in \mathbb{R}^{k \times d}$上。在一些聚类任务上,$k$设为我们需要的类别数。使用基于注意力的池化的直觉是,每个实例对target的重要性都不一样

$$
\operatorname{PMA}_{k}(Z)=\operatorname{MAB}(S, \operatorname{rFF}(Z))
$$

$$
H=\operatorname{SAB}\left(\operatorname{PMA}_{k}(Z)\right)
$$

其中池化操作$\operatorname{PMA}_{k}(Z)=\operatorname{MAB}(S, \operatorname{rFF}(Z)) \in \mathbb{R}^{k \times d}$,$k$表示输出集合中实例的个数,$k < n$。

CCAE

对M个2维输入点组成的集合$\mathbf{x_{1:M}}$,首先使用Set Transformer将这个集合编码为$K$个$(2\times 2+n_c+1)$的object向量,这三个数分别表示OV矩阵大小、特殊向量(即特征)、存在概率。特殊向量的尺度是个超参,原文$n_c=16$。
$$
\mathrm{OV}{1: K}, \mathbf{c}{1: K}, a_{1: K}=\mathrm{h}^{\mathrm{caps}}\left(\mathbf{x}{1: M}\right) = \operatorname{SetTransformer}\left(\mathbf{x}{1: M}\right)
$$
对每个object向量,取其特殊向量,通过mlp解码出$N$个part。其中,每个part长度为$(2+1+1)$,分别为OP矩阵、存在概率、和标准差;每个object应用一个单独的mlp,mlp结构为$n_c,128,(2+1+1)\times N$。
$$
\mathrm{OP}{k, 1: N}, a{k, 1: N}, \lambda_{k, 1: N}=\mathrm{h}{\mathrm{k}}^{\mathrm{part}}\left(\mathbf{c}{k}\right) = \operatorname{mlp_k}\left(\mathbf{c}_{k}\right)
$$
在原文例子中,$M=3, N=4$。

每个解码出的part都可以表示一个高斯分量。CCAE处理的数据是2维平面点,因此表示的高斯分量的均值是2维,协方差矩阵大小是$2 \times 2$的矩阵。具体的讲,由第$i$个object产生的第$j$个part表示的高斯分量均值为
$$
\mu_{k,n} = OV_k OP_{k,n}
$$
其中$OV_k$是$2 \times 2$的矩阵,$OP_{k,n}$是长度为2的向量。而part只有一个标量的标准差$\lambda_{k,n}$,即,原文将一个高斯分量假设为各向同性,通过标准差$\lambda_{k,n}$计算到高斯模型的协方差矩阵:
$$
\Sigma_{k,n} = \begin{bmatrix}
\frac{1}{\lambda_{k,n}} & 0\
0 & \frac{1}{\lambda_{k,n}}
\end{bmatrix}
$$
对于这个高斯分量的存在概率,表示为
$$
\pi_{k,n} = \frac{a_{k} a_{k, n}}{\sum_{i} a_{i} \sum_{j} a_{i, j}}
$$
因此,给定每个高斯模型的三个参数:均值,协方差,概率。可以得到给定数据分布在整个高斯混合模型上的估计为:
$$
p\left(\mathbf{x}{1: M}\right)=\prod{m=1}^{M} \sum_{k=1}^{K} \sum_{n=1}^{N} \frac{a_{k} a_{k, n}}{\sum_{i} a_{i} \sum_{j} a_{i, j}} p\left(\mathbf{x}{m} \mid k, n\right)
$$
其中,点$\mathbf{x}
{m}$在第$i$个object产生的第$j$个part表示的高斯计算得到的似然值为
$$
p(\mathbf{x}{m}|k,n) = p(\mathbf{x}{m}|\mu_{k,n} ,\Sigma_{k,n}) = \frac{1}{(2 \pi)^{D/2} |\Sigma_{k,n}|^{1/2}}\operatorname{exp}\left(\frac{1}{2} \left(\mathbf{x}{m}-\mu{k,n} \right)^T \Sigma_{k,n}^{-1} \left(\mathbf{x}{m}-\mu{k,n} \right) \right)
$$
最大化$p\left(\mathbf{x}{1: M}\right)$,求得$\mu{k,n},\lambda_{k,n},\pi_{k,n}$,在理论上即可得到表示这个数据分布的模型。原文使用反向传播优化参数,目标是最大化$\operatorname{log }p\left(\mathbf{x}{1: M}\right)$,等价于最小化$-\operatorname{log }p\left(\mathbf{x}{1: M}\right)$。

数据集

数据(3个集群,两个正方形,一个三角形)是在线创建的,每一次创建后被随机平移、放缩、旋转到180度,最后所有点被标准化到-1到1之间。

决策

依据object $ a_{k}$和其概率最高的part $a_{k,n}$,对每个点$x_m$,其类别决策为$k^{\star}=\arg \max {k} a{k} a_{k, n} p\left(\mathbf{x}_{m} \mid k, n\right)$。

一些实现细节

  • 多维矩阵 首先要明白将所有的part、object放在一个矩阵里,每个维度的含义。笔者设定:part为(B, n_objects, n_votes, (dim_input+1+1)),object为(B, n_objects, dim_input**2+dim_speical_features+1)。搞定这些,之后可以进行矩阵拆分,对应到原论文对应的变量里。

  • BatchMLP 在计算object到part的解码时用到。每个object capsule需要一个单独的MLP来解码到对应的part capsule,也就是说,输入的object维度为[B, n_objects, n_special_features],被多个MLP计算得到结果应该是(B, n_objects, n_votes*(dim_input+1+1))。pytorch里面只有单个的MLP,我们类似原作者也实现了个BatchMLP来完成这个功能。

  • 对概率的处理 对预测的$a_k$和$a_{k,n}$使用softmax等函数进行处理,对预测的标准差加上一个$\epsilon=10^{-6}$防止分母为0.

代码部分

Set Transformer

关于Set Transformer的实现如下,笔者做了相关注释,具体每个模块实现这里不贴。简而言之,这个编码器将(B, N, dim_input)的输入转化为(B, num_outputs, dim_output)的输出。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import torch.nn as nn
from base import BaseModel
from model.modules.setmodules import ISAB,SAB,PMA


class SetTransformer(BaseModel):
"""
"""

def __init__(self, dim_input, num_outputs, dim_output,
num_inds=32, dim_hidden=128, num_heads=4, ln=True):
"""Set Transformer, An autoencoder model dealing with set data

Input set X with N elements, each `dim_input` dimensions, output
`num_outputs` elements, each `dim_output` dimensions.

In short, choose:
N --> num_outputs
dim_input --> dim_output

Hyper-parameters:
num_inds
dim_hidden
num_heads
ln
Args:
dim_input: Number of dimensions of one elem in input set X
num_outputs: Number of output elements
dim_output: Number of dimensions of one elem in output set
num_inds: inducing points number
dim_hidden: output dimension of one elem of middle layer
num_heads: heads number of multi-heads attention in MAB
ln: whether to use layer norm in MAB
"""
super(SetTransformer, self).__init__()
self.enc = nn.Sequential(
ISAB(dim_input, dim_hidden, num_heads, num_inds, ln=ln),
ISAB(dim_hidden, dim_hidden, num_heads, num_inds, ln=ln))
self.dec = nn.Sequential(
PMA(dim_hidden, num_heads, num_outputs, ln=ln),
SAB(dim_hidden, dim_hidden, num_heads, ln=ln),
SAB(dim_hidden, dim_hidden, num_heads, ln=ln),
nn.Linear(dim_hidden, dim_output))

def forward(self, X):
"""
Args:
X: (B, N, dim_input)

Returns:
output set with shape (B, num_outputs, dim_output)
"""
return self.dec(self.enc(X))

CCAE

编码器核心部分如下,可以看到可以类似17版那样用矢量表示胶囊,不过这里每个胶囊用三种不同意义的变量表示,因此后续处理也不同。

1
2
3
4
5
6
objects = self.set_transformer(x)  # (B, n_objects, dim_input**2+dim_speical_features+1)
splits = [self.dim_input**2,self.dim_input**2+self.dim_speical_features]
ov_matrix,special_features,presence=objects[:,:,:splits[0]],objects[:,:,splits[0]:splits[1]],objects[:,:,splits[1]:]

ov_matrix = ov_matrix.reshape(B, self.n_objects, self.dim_input, self.dim_input)
presence = F.softmax(presence, dim=1)

解码器,注意到这里使用了一个BatchMLP,即使用多个MLP对每个object的特殊向量进行解码,每个object都可以解码出若干个part。

1
2
3
4
5
6
7
8
9
10
11

x = self.bmlp(x) # (B, n_objects, n_votes*(dim_input+1+1))
x_chunk = x.chunk(self.n_votes, dim=-1)
x_object_part = torch.stack(x_chunk, dim=2) # (B, n_objects, n_votes, (dim_input+1+1))

splits = [self.dim_input, self.dim_input+1]
op_matrix = x_object_part[:,:,:,:splits[0]]
standard_deviation = x_object_part[:,:,:,splits[0]:splits[1]]
presence = x_object_part[:,:,:,splits[1]:]
presence = F.softmax(presence, dim=2)

使用无监督的决策方式,参考上文原理部分

1
2
3
4
5
6
7
8
9
# (B, 1, n_objects, 1)
object_presence = res_dict.object_presence[:, None, ...]
# (B, 1, n_objects, n_votes)
part_presence = res_dict.part_presence[:, None, ...].squeeze(-1)
# (B, M, n_objects, n_votes)
likelihood = res_dict.likelihood
a_k_n_times_p = (part_presence * likelihood).max(dim=-1, keepdim=True)[0]
expr = object_presence * a_k_n_times_p
winners = expr.max(dim=-2)[1].squeeze(-1)

数据集

这里直接复用了原本数据生成代码,搭建了一个Dataloader

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
class CCAE_Dataloader(BaseDataLoader):
def __init__(self,
# for dataloader
batch_size,
shuffle=True,
validation_split=0.0,
num_workers=1,

# for dataset
shuffle_corners=True,
gaussian_noise=0.,
max_translation=1.,
rotation_percent=0.0,
which_patterns='basic',
drop_prob=0.0,
max_scale=3.,
min_scale=.1
):
self.dataset = CCAE_Dataset(
shuffle_corners,
gaussian_noise,
max_translation,
rotation_percent,
which_patterns,
drop_prob,
max_scale,
min_scale
)
super().__init__(self.dataset, batch_size, shuffle, validation_split, num_workers)

class CCAE_Dataset(data.Dataset):
def __init__(self,
shuffle_corners=True,
gaussian_noise=0.,
max_translation=1.,
rotation_percent=0.0,
which_patterns='basic',
drop_prob=0.0,
max_scale=3.,
min_scale=.1
):
self.shuffle_corners = shuffle_corners
self.scale = max_scale
self.gaussian_noise = gaussian_noise
self.max_translation = max_translation
self.rotation_percent = rotation_percent
self.which_patterns = which_patterns
self.drop_prob = drop_prob

def __len__(self):
return 10000
def __getitem__(self, item):
data = create_numpy(
1,
self.shuffle_corners,
self.gaussian_noise,
self.max_translation,
self.rotation_percent,
self.scale,
self.which_patterns,
self.drop_prob)
return data

损失

损失函数的计算方式如下

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
def ccae_loss(res_dict, target, epsilon = 1e-6):
"""

Args:
res_dict:
target: input set with (B, k, dim_input)
epsilon: avoiding nan for reciprocal of standard deviation
Returns:
log likelihood for input dataset(here "target") , (B,)
"""
# retrieve the variable (Sorry for possible complication)
op_matrix = res_dict.op_matrix # (B, n_objects, n_votes, dim_input)
ov_matrix = res_dict.ov_matrix # (B, n_objects, dim_input, dim_input)
standard_deviation = res_dict.standard_deviation # (B, n_objects, n_votes, 1)
object_presence = res_dict.object_presence # (B, n_objects, 1)
part_presence = res_dict.part_presence # (B, n_objects, n_votes, 1)
dim_input = res_dict.dim_input
B, n_objects, n_votes, _ = standard_deviation.shape
op_matrix = op_matrix[:,:,:,:,None] # (B, n_objects, n_votes, dim_input,1)
ov_matrix = ov_matrix[:,:,None,:,:] # (B, n_objects, 1, dim_input,dim_input)

# 防止分母为0
standard_deviation = epsilon + standard_deviation[Ellipsis, None]
# 计算mu
mu = ov_matrix @ op_matrix # (B, n_objects, n_votes, dim_input,1)
# 计算协方差
identity = torch.eye(dim_input).repeat(B, n_objects, n_votes, 1, 1).to(standard_deviation.device)
sigma = identity * (1/standard_deviation) # (B, n_objects, n_votes, dim_input,dim_input)

# 计算数据集(即target)在混合模型上的似然估计
# (B, k, n_objects, n_votes)
gaussian_likelihood = gmm(mu, sigma).likelihood(target, object_presence=object_presence, part_presence=part_presence)

# 计算似然估计的对数,作为损失目标
log_likelihood = torch.log(gaussian_likelihood.sum((1,2,3))).mean()
gaussian_likelihood = gaussian_likelihood.mean()
res_dict.likelihood = -gaussian_likelihood
res_dict.log_likelihood = -log_likelihood



return res_dict

笔者又实现了一个高斯混合模型类来计算似然值,下面是计算损失的核心代码。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
mu = ov_matrix @ op_matrix  # (B, n_objects, n_votes, dim_input,1)
identity = (
torch.eye(dim_input)
.repeat(B, n_objects, n_votes, 1, 1)
.to(standard_deviation.device)
)
sigma = identity * (
1 / standard_deviation
) # (B, n_objects, n_votes, dim_input,dim_input)

# (B, k, n_objects, n_votes)
likelihood = gmm(mu, sigma).likelihood(
target, object_presence=object_presence, part_presence=part_presence
)
log_likelihood = torch.log(likelihood.sum((1, 2, 3))).mean()

后续思考,这个损失函数有点写复杂了,直接在model里算好就不需要这么多代码了。

高斯混合模型的核心实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
class GuassianMixture(object):
"""
GMM for part capsules
"""

def __init__(self, mu, sigma):
"""
Args:
mu: (B, n_objects, n_votes, dim_input, 1)
sigma: (B, n_objects, n_votes, dim_input, dim_input)

After initialized:
mu: (B, 1, n_objects, n_votes, dim_input, 1)
sigma:(B, 1, n_objects, n_votes, dim_input,dim_input)
multiplier:(B, 1, n_objects, n_votes, 1, 1)
"""
# Converse shape to
# (Batch_size, num_of_points, num_of_objects, number_of_votes, ...)

mu = mu[:, None, ...] # (B, 1, n_objects, n_votes, dim_input, 1)
sigma = sigma[:, None, ...] # (B, 1, n_objects, n_votes, dim_input,dim_input)

self.sigma = sigma
self.mu = mu
self.sigma_inv = sigma.inverse()
D = self.sigma.shape[-1]
sigma_det = torch.det(sigma) # (B, 1, n_objects, n_votes)
self.multiplier = (
1 / ((2 * math.pi) ** (D / 2) * sigma_det.sqrt())[..., None, None]
)

def likelihood(self, x, object_presence=None, part_presence=None):
diff = x - self.mu
exp_result = torch.exp(-0.5 * diff.transpose(-1, -2) @ self.sigma_inv @ diff)

denominator = object_presence.sum(dim=2, keepdim=True) * part_presence.sum(
dim=3, keepdim=True
)
exp_result = (object_presence * part_presence / denominator) * exp_result
gaussian_likelihood = self.multiplier * exp_result
return gaussian_likelihood.squeeze(-1).squeeze(-1)

def plot(self, choose):
raise NotImplemented

目前的效果

  • 正确分类

原数据

image-20201204231106857

无监督分类结果

image-20201204231147582

  • 错误分类

image-20201204231236948

总结

本文使用pytorch实现了原论文第一个toy experiment,做了一个简单的展示,损失使用的是
$$
p\left(\mathbf{x}{1: M}\right)=\prod{m=1}^{M} \sum_{k=1}^{K} \sum_{n=1}^{N} \frac{a_{k} a_{k, n}}{\sum_{i} a_{i} \sum_{j} a_{i, j}} p\left(\mathbf{x}_{m} \mid k, n\right)
$$
未使用原文提出的sparsity loss。

工程方面

  • 参数传递的过程中,形状为1的维度应该压缩掉

实验方面

  • 写了个重大BUG:BatchMLP忘记使用激活,梯度变为nan,还是对激活函数的理解程度不深,写MLP竟然忘记带了
  • 还需要对无监督效果进行评估,

TODO

  • 实现无监督评估方法

  • 尝试用这个模型做指导性学习

参考资料

[1] https://github.com/google-research/google-research/tree/master/stacked_capsule_autoencoders

[2] https://github.com/phanideepgampa/stacked-capsule-networks

[3] https://github.com/MuhammadMomin93/Stacked-Capsule-Autoencoders-PyTorch

[4] https://github.com/Axquaris/StackedCapsuleAutoencoders

[5] Fitting a generative model using standard divergences between measures http://www.math.ens.fr/~feydy/Teaching/DataScience/fitting_a_generative_model.html

欢迎关注我的其它发布渠道