Code of Pixel-to-Prototype Contrast
Generate CAMs
- Feature map
- Class feature map
- Score of class
- CAMs
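Written out, the pipeline above is the standard CAM formulation (consistent with the 1x1 convolution fc8 in the code below):

$$s_c = \frac{1}{HW}\sum_{x,y} w_c^{\top} F(x,y), \qquad M_c(x,y) = \mathrm{ReLU}\big(w_c^{\top} F(x,y)\big)$$

where $F$ is the backbone feature map, $w_c$ the classification weight of class $c$, $s_c$ the class score, and $M_c$ the CAM of class $c$.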
Pixel-to-Prototype Contrast
- Pseudo mask
- Pixel-wise projected feature
- Pixel-to-prototype contrast
- Prototype set
- Temperature
- Contrast: the similarity between a pixel's projected feature and the prototypes (see the formula after this list)
- Prototype set
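Written as a formula (this is what the cross-prototype NCE code at the end of this note computes, with dot-product similarity on L2-normalized features and temperature $\tau = 0.1$):

$$\mathcal{L}_{\mathrm{nce}} = -\log \frac{\exp\big(f_i \cdot p_{c_i}/\tau\big)}{\sum_{p \in \mathcal{P}} \exp\big(f_i \cdot p/\tau\big)}$$

where $f_i$ is the projected feature of pixel $i$, $c_i$ its pseudo label, $p_{c_i}$ the prototype of that class, and $\mathcal{P}$ the prototype set.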
Prototype Estimation in Batch
- Top K pixels of class c
- CAM as confidences
- Estimate prototypes from the pixel-wise feature embeddings with the top-K CAM confidences (see the formula after this list)
- Prototype
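Concretely, this is the weighted average that the prototype code below implements:

$$p_c = \mathrm{normalize}\!\left(\frac{\sum_{i \in \mathrm{TopK}(c)} M_c(i)\, f_i}{\sum_{i \in \mathrm{TopK}(c)} M_c(i)}\right)$$

where $\mathrm{TopK}(c)$ is the set of $K$ pixels with the highest CAM confidence $M_c$ for class $c$, and $f_i$ is the projected feature of pixel $i$.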
Loss
Cross Prototype Contrast
Cross CAM Contrast
Intra-view Contrast
- Strategies to address inaccurate pseudo labels [50] (a rough sketch follows this list)
- Semi-hard prototype mining
- Hard pixel sampling
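The exact mining and sampling code is not reproduced in this note; the sketch below only illustrates the two ideas, with hypothetical names, K values, and ranking rules:

```python
import torch

def semi_hard_prototype_mining(sim, pseudo_label, n_keep=10, n_skip=2):
    """Keep negatives that are hard but not the hardest.

    sim: [P, C] similarities between P pixel features and C prototypes.
    With noisy pseudo labels the very hardest negatives are likely false
    negatives, so they are skipped and the next-hardest ones are kept.
    """
    neg_sim = sim.clone()
    neg_sim.scatter_(1, pseudo_label.view(-1, 1), float('-inf'))  # mask out the positive class
    order = neg_sim.argsort(dim=1, descending=True)               # hardest negatives first
    return order[:, n_skip:n_skip + n_keep]                       # semi-hard negative prototype indices

def hard_pixel_sampling(sim, pseudo_label, n_hard=32):
    """Pick the pixels that agree least with their own (positive) prototype."""
    pos_sim = sim.gather(1, pseudo_label.view(-1, 1)).squeeze(1)
    return pos_sim.argsort()[:n_hard]                             # lowest similarity = hardest pixels
```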
Code
Normalization

L1 normalization
- Purpose
- Ensures all elements sum to 1
- Converts the vector into a probability distribution

L2 normalization
# Perform L2 normalization along the channel dimension
- Purpose
- Direction invariance: the vector keeps its direction while its length becomes 1, so the representation no longer depends on its magnitude
- Numerical stability: keeps vector magnitudes within a relatively small range
- Reduces differences in feature scale
- Makes similarity measurement easier

Max normalization
- After normalization, the maximum value of the vector is 1

Max-Min normalization
- After normalization, the values lie in [0, 1]
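A minimal, self-contained sketch of the four variants on a CAM-shaped tensor (the random tensor and the 1e-5 epsilon are illustrative; shapes follow the code below):

```python
import torch

x = torch.rand(2, 21, 32, 32)  # e.g. a batch of CAMs: [N, C, H, W]
flat = x.view(2, 21, -1)

# L1: divide by the sum, so each class map sums to 1 (a probability-like distribution)
l1 = x / (flat.sum(-1).view(2, 21, 1, 1) + 1e-5)

# L2: divide by the norm along the channel dimension, giving unit-length pixel vectors
l2 = x / (torch.norm(x, dim=1, keepdim=True) + 1e-5)

# Max: divide by the per-map maximum, so the largest value becomes 1
mx = x / (flat.max(-1)[0].view(2, 21, 1, 1) + 1e-5)

# Max-Min: rescale each class map into [0, 1]
mn = flat.min(-1)[0].view(2, 21, 1, 1)
mm = (x - mn) / (flat.max(-1)[0].view(2, 21, 1, 1) - mn + 1e-5)
```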
Forward
cam
# fea is the feature map output by the last layer of the backbone
self.fc8 = nn.Conv2d(4096, 21, 1, bias=False)
cam = self.fc8(fea)
cam = torch.nn.functional.interpolate(cam, (H, W), mode='bilinear', align_corners=True)
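For the classification loss, the class score listed in the outline is usually obtained by global average pooling the CAM. A minimal sketch (the pooling call and cls_label are assumptions; the repo's classification head may differ in detail):

```python
# Global average pooling over the spatial dimensions gives one score per class
score = torch.nn.functional.adaptive_avg_pool2d(cam, (1, 1)).view(cam.size(0), -1)  # [N, 21]
# cls_label: assumed image-level multi-hot label of shape [N, 21], channel 0 = background
loss_cls = torch.nn.functional.multilabel_soft_margin_loss(score[:, 1:], cls_label[:, 1:])
```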
cam_rv_down
Clean the CAM
with torch.no_grad():
    cam_d = torch.nn.functional.relu(cam.detach())
    # Max normalization per class channel
    cam_d_max = torch.max(cam_d.view(n, c, -1), dim=-1)[0].view(n, c, 1, 1) + 1e-5
    cam_d_norm = torch.nn.functional.relu(cam_d - 1e-5) / cam_d_max
    # Background = 1 - the highest foreground probability; at each pixel keep only
    # the most confident foreground class and zero out the other classes
    cam_d_norm[:, 0, :, :] = 1 - torch.max(cam_d_norm[:, 1:, :, :], dim=1)[0]
    cam_max = torch.max(cam_d_norm[:, 1:, :, :], dim=1, keepdim=True)[0]
    cam_d_norm[:, 1:, :, :][cam_d_norm[:, 1:, :, :] < cam_max] = 0
Enhance the CAM

# Refine the CAM according to pixel-level feature similarity
cam_rv_down = self.PCM(cam_d_norm, f)

# PCM
def PCM(self, cam, f):
    n, c, h, w = f.size()
    cam = torch.nn.functional.interpolate(cam, (h, w), mode='bilinear', align_corners=True).view(n, -1, h * w)
    # Multi-scale feature fusion
    f = self.f9(f)
    f = f.view(n, -1, h * w)
    # L2-normalize the features along the channel dimension
    f = f / (torch.norm(f, dim=1, keepdim=True) + 1e-5)
    # Pixel-to-pixel similarity (affinity) matrix
    aff = torch.nn.functional.relu(torch.matmul(f.transpose(1, 2), f), inplace=True)
    # L1-normalize the affinity matrix
    aff = aff / (torch.sum(aff, dim=1, keepdim=True) + 1e-5)
    # Reweight the CAM with the affinity matrix
    cam_rv = torch.matmul(cam, aff).view(n, -1, h, w)
    return cam_rv
cam_rv
cam_rv = torch.nn.functional.interpolate(cam_rv_down, (H, W), mode='bilinear', align_corners=True)
f_proj
self.fc_proj = torch.nn.Conv2d(4096, 128, 1, bias=False)
f_proj = torch.nn.functional.relu(self.fc_proj(fea), inplace=True)

prototype
f_proj1 = torch.nn.functional.interpolate(f_proj1, size=(128 // 8, 128 // 8), mode='bilinear', align_corners=True)
cam_rv1_down = torch.nn.functional.interpolate(cam_rv1_down, size=(128 // 8, 128 // 8), mode='bilinear', align_corners=True)
cam_rv2_down = cam_rv2_down  # the second view is already at the target resolution
with torch.no_grad():
    fea1 = f_proj1.detach()
    c_fea1 = fea1.shape[1]
    cam_rv1_down = torch.nn.functional.relu(cam_rv1_down.detach())
    # Max-Min normalize the CAM
    n1, c1, h1, w1 = cam_rv1_down.shape
    max1 = torch.max(cam_rv1_down.view(n1, c1, -1), dim=-1)[0].view(n1, c1, 1, 1)
    min1 = torch.min(cam_rv1_down.view(n1, c1, -1), dim=-1)[0].view(n1, c1, 1, 1)
    cam_rv1_down[cam_rv1_down < min1 + 1e-5] = 0.
    norm_cam1 = (cam_rv1_down - min1 - 1e-5) / (max1 - min1 + 1e-5)
    cam_rv1_down = norm_cam1
    # Set the background threshold
    cam_rv1_down[:, 0, :, :] = args.bg_threshold
    # Keep only the classes present in the image-level label
    scores1 = torch.nn.functional.softmax(cam_rv1_down * label, dim=1)
    # Pseudo label = the most confident class at each pixel
    pseudo_label1 = scores1.argmax(dim=1, keepdim=True)
    n_sc1, c_sc1, h_sc1, w_sc1 = scores1.shape
    scores1 = scores1.transpose(0, 1)
    fea1 = fea1.permute(0, 2, 3, 1).reshape(-1, c_fea1)
    # Top-K highest-confidence CAM values and their pixel indices for each class
    top_values, top_indices = torch.topk(cam_rv1_down.transpose(0, 1).reshape(c_sc1, -1), k=h_sc1 * w_sc1 // 8, dim=-1)
    prototypes1 = torch.zeros(c_sc1, c_fea1).cuda()  # [21, 128]
    # Iterate over classes
    for i in range(c_sc1):
        # Feature embeddings of the top-K pixels
        top_fea = fea1[top_indices[i]]
        # Prototype = CAM-confidence-weighted average of the top-K features
        prototypes1[i] = torch.sum(top_values[i].unsqueeze(-1) * top_fea, dim=0) / torch.sum(top_values[i])
    # L2-normalize each prototype
    prototypes1 = torch.nn.functional.normalize(prototypes1, dim=-1)
    # prototypes2 and pseudo_label2 are obtained analogously from the second view

prototype similarity
n_f, c_f, h_f, w_f = f_proj1.shape
# [N, C, H, W] -> [N * H * W, C]
f_proj1 = f_proj1.permute(0, 2, 3, 1).reshape(n_f * h_f * w_f, c_f)
# L2-normalize the pixel features
f_proj1 = torch.nn.functional.normalize(f_proj1, dim=-1)
pseudo_label1 = pseudo_label1.reshape(-1)
# Positives: the prototype (from the other view) of each pixel's pseudo class; negatives: all prototypes
positives1 = prototypes2[pseudo_label1]
negitives1 = prototypes2
# The same for the second view
n_f, c_f, h_f, w_f = f_proj2.shape
f_proj2 = f_proj2.permute(0, 2, 3, 1).reshape(n_f * h_f * w_f, c_f)
f_proj2 = torch.nn.functional.normalize(f_proj2, dim=-1)
pseudo_label2 = pseudo_label2.reshape(-1)
positives2 = prototypes1[pseudo_label2]
negitives2 = prototypes1
# InfoNCE: numerator uses the positive prototype, denominator sums over all prototypes (temperature 0.1)
A1 = torch.exp(torch.sum(f_proj1 * positives1, dim=-1) / 0.1)
A2 = torch.sum(torch.exp(torch.matmul(f_proj1, negitives1.transpose(0, 1)) / 0.1), dim=-1)
loss_nce1 = torch.mean(-1 * torch.log(A1 / A2))
A3 = torch.exp(torch.sum(f_proj2 * positives2, dim=-1) / 0.1)
A4 = torch.sum(torch.exp(torch.matmul(f_proj2, negitives2.transpose(0, 1)) / 0.1), dim=-1)
loss_nce2 = torch.mean(-1 * torch.log(A3 / A4))
# Average over the two views and weight by 0.1
loss_cross_nce = 0.1 * (loss_nce1 + loss_nce2) / 2