Code of Pixel-to-Prototype Contrast

Generate CAMs

  • Feature map
  • Class feature map
  • Score of class
  • CAMs
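The four quantities above can be sketched in PyTorch as follows. This is a minimal illustration under stated assumptions, not the post's code:

- The classifier is a bias-free 1x1 convolution, like `fc8` later in this post.
- The class score is the global average of the class feature map.
- Each CAM is the ReLU'd class feature map divided by its spatial maximum.

The helper `generate_cams` and the tensor sizes are hypothetical.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def generate_cams(feature_map, classifier, eps=1e-5):
    """Hypothetical helper: feature map -> class feature maps -> class scores -> CAMs.

    feature_map: [N, D, H, W] backbone output
    classifier:  1x1 conv with C output channels (one weight vector per class)
    """
    # Class feature map: a per-pixel linear projection onto each class weight, [N, C, H, W]
    class_maps = classifier(feature_map)
    # Score of class: global average pooling over the spatial positions, [N, C]
    scores = F.adaptive_avg_pool2d(class_maps, 1).flatten(1)
    # CAM: ReLU, then divide by each map's spatial maximum so values lie in [0, 1]
    cams = F.relu(class_maps)
    cams_max = cams.flatten(2).max(dim=-1)[0].view(*cams.shape[:2], 1, 1)
    cams = cams / (cams_max + eps)
    return scores, cams


torch.manual_seed(0)
fc = nn.Conv2d(256, 21, 1, bias=False)
feat = torch.randn(2, 256, 8, 8)
scores, cams = generate_cams(feat, fc)
print(scores.shape, cams.shape)  # torch.Size([2, 21]) torch.Size([2, 21, 8, 8])
```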

Pixel-to-Prototype Contrast

  • Pseudo mask
  • Pixel-wise projected feature
  • Pixel-to-prototype contrast
    • Prototype set
    • Temperature
    • Contrast: similarity between pixel features and prototypes
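The contrast step can be sketched as an InfoNCE-style loss. Each pixel embedding should be most similar to the prototype of its pseudo-mask class, relative to all prototypes, with similarities divided by the temperature. The sketch makes three assumptions:

- Embeddings and prototypes are already L2-normalized.
- The temperature is 0.1, the value used in the loss code later in this post.
- The function name `pixel_to_prototype_loss` is hypothetical.

Applying `F.cross_entropy` to the pixel-to-prototype logits gives -log(exp(s_pos / τ) / Σ_c exp(s_c / τ)), averaged over pixels.

```python
import torch
import torch.nn.functional as F


def pixel_to_prototype_loss(pixel_emb, pseudo_label, prototypes, tau=0.1):
    """Hypothetical helper for an InfoNCE-style pixel-to-prototype contrast.

    pixel_emb:    [P, D] L2-normalized pixel embeddings
    pseudo_label: [P] class index of each pixel (from the pseudo mask)
    prototypes:   [C, D] L2-normalized class prototypes (the prototype set)
    """
    # Similarity of every pixel to every prototype, divided by the temperature
    logits = pixel_emb @ prototypes.t() / tau  # [P, C]
    # -log(exp(sim to own prototype) / sum_c exp(sim to prototype c)), averaged over pixels
    return F.cross_entropy(logits, pseudo_label)


torch.manual_seed(0)
emb = F.normalize(torch.randn(100, 128), dim=-1)
protos = F.normalize(torch.randn(21, 128), dim=-1)
labels = torch.randint(0, 21, (100,))
loss = pixel_to_prototype_loss(emb, labels, protos)
print(loss)
```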

Prototype Estimation in Batch

  • Top-K pixels of class c
    • Use the CAM values as confidences
    • Estimate each class prototype from the pixel-wise feature embeddings with the top-K confidences
  • Prototype
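A sketch of estimating prototypes within a batch follows. For each class, it takes the K pixels with the highest CAM confidence, averages their embeddings weighted by those confidences, and L2-normalizes the result. This mirrors the prototype code later in this post. The helper name `estimate_prototypes` and the tensor sizes are illustrative assumptions; K = H*W // 8 follows the later code.

```python
import torch
import torch.nn.functional as F


def estimate_prototypes(features, cams, k, eps=1e-5):
    """Hypothetical helper: CAM-weighted average of the top-k pixels of each class.

    features: [N, D, H, W] pixel embeddings
    cams:     [N, C, H, W] per-class confidences (e.g. normalized CAMs)
    k:        pixels per class, selected across the whole batch
    """
    d = features.shape[1]
    c = cams.shape[1]
    feats = features.permute(0, 2, 3, 1).reshape(-1, d)       # [N*H*W, D]
    conf = cams.transpose(0, 1).reshape(c, -1)                # [C, N*H*W], same pixel order
    top_values, top_indices = torch.topk(conf, k=k, dim=-1)   # [C, k] each
    top_feats = feats[top_indices]                            # [C, k, D]
    weighted = (top_values.unsqueeze(-1) * top_feats).sum(dim=1)
    protos = weighted / (top_values.sum(dim=1, keepdim=True) + eps)
    return F.normalize(protos, dim=-1)                        # [C, D], unit length


torch.manual_seed(0)
feat = torch.randn(2, 128, 16, 16)
cam = torch.rand(2, 21, 16, 16)
prototypes = estimate_prototypes(feat, cam, k=16 * 16 // 8)
print(prototypes.shape)  # torch.Size([21, 128])
```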

Loss

  • Cross Prototype Contrast

  • Cross CAM Contrast

  • Intra-view Contrast

    • Strategies to address inaccurate pseudo labels [50]
      • Semi-hard prototype mining
      • Hard pixel sampling
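These notes do not spell out the exact mining rules, so the following is a hypothetical illustration only, not the procedure from [50] or PPC:

- Hard pixel sampling keeps the pixels whose similarity to their own prototype is lowest.
- Semi-hard prototype mining skips the single most similar negative prototype, which may really be the correct class when the pseudo label is wrong, and keeps the next few.
- The loss then contrasts only the sampled pixels against their positive prototype plus the mined negatives.

All function names, ratios and counts below are arbitrary placeholders.

```python
import torch
import torch.nn.functional as F


def hard_pixel_indices(pixel_emb, pseudo_label, prototypes, ratio=0.5):
    """Hypothetical: keep the `ratio` fraction of pixels least similar to their own prototype."""
    pos_sim = (pixel_emb * prototypes[pseudo_label]).sum(dim=-1)     # [P]
    num_keep = max(1, int(ratio * pixel_emb.shape[0]))
    return torch.topk(pos_sim, k=num_keep, largest=False).indices    # hardest pixels first


def semi_hard_negative_mask(pixel_emb, pseudo_label, prototypes, skip=1, keep=5):
    """Hypothetical: per pixel, skip the `skip` most similar negatives and keep the next `keep`."""
    sim = pixel_emb @ prototypes.t()                                  # [P, C]
    # Push the positive prototype to the end of the ranking so it is never chosen
    sim = sim.scatter(1, pseudo_label.unsqueeze(1), float("-inf"))
    order = sim.argsort(dim=1, descending=True)                       # hardest negatives first
    chosen = order[:, skip:skip + keep]                               # the "semi-hard" band
    mask = torch.zeros(sim.shape, dtype=torch.bool, device=sim.device)
    rows = torch.arange(sim.shape[0], device=sim.device).unsqueeze(1)
    mask[rows, chosen] = True
    return mask


def mined_contrast_loss(pixel_emb, pseudo_label, prototypes, tau=0.1):
    """InfoNCE over hard pixels; the denominator holds the positive plus semi-hard negatives."""
    idx = hard_pixel_indices(pixel_emb, pseudo_label, prototypes)
    emb, lab = pixel_emb[idx], pseudo_label[idx]
    keep = semi_hard_negative_mask(emb, lab, prototypes)
    keep[torch.arange(lab.shape[0]), lab] = True                      # always keep the positive
    logits = (emb @ prototypes.t() / tau).masked_fill(~keep, float("-inf"))
    return F.cross_entropy(logits, lab)


torch.manual_seed(0)
emb = F.normalize(torch.randn(200, 128), dim=-1)
protos = F.normalize(torch.randn(21, 128), dim=-1)
labels = torch.randint(0, 21, (200,))
loss = mined_contrast_loss(emb, labels, protos)
```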

Code

Normalization

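L1 normalization

  • Purpose
    • Ensures that all elements sum to 1 (for a non-negative vector)
    • Turns the vector into a probability distribution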
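For a non-negative vector, L1 normalization divides each element by the sum of all elements, which is why the result is a probability distribution. The PCM code later in this post normalizes its affinity matrix this way. Softmax is included here only as an assumed point of comparison: it also yields a distribution, but it exponentiates first. The example vector is arbitrary.

```python
import torch
import torch.nn.functional as F

v = torch.tensor([1.0, 2.0, 3.0, 4.0])

# L1 normalization: divide by the sum of absolute values
l1 = v / v.abs().sum()
# Built-in equivalent (p=1 selects the L1 norm)
l1_builtin = F.normalize(v, p=1, dim=0)
# Softmax also sums to 1, but exponentiates before normalizing
sm = torch.softmax(v, dim=0)

print(l1, l1_builtin, sm)
```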

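L2 normalization

import torch

v = torch.randn(2, 128, 16, 16)  # example feature map, [N, C, H, W]
# L2-normalize along the channel dimension
v_l2 = v / (torch.norm(v, dim=1, keepdim=True) + 1e-5)
# or, equivalently (up to how the epsilon is applied)
v_l2 = torch.nn.functional.normalize(v, dim=1)
  • Purpose
    • Direction invariance: the vector keeps its direction while its length becomes 1, so the representation no longer depends on its magnitude
    • Numerical stability: vector magnitudes are confined to a small, fixed range
    • Reduces differences in scale between features
    • Simplifies similarity measurement: the dot product of two L2-normalized vectors is their cosine similarity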
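The first and last points can be checked directly. Rescaling a vector does not change its L2-normalized form, and the dot product of two L2-normalized vectors equals their cosine similarity. A small sketch with arbitrary random vectors:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
a = torch.randn(128)
b = torch.randn(128)

# Direction invariance: scaling a vector leaves its normalized form unchanged
same_direction = torch.allclose(F.normalize(a, dim=0), F.normalize(5.0 * a, dim=0), atol=1e-6)

# Similarity: the dot product of unit vectors is the cosine similarity of the originals
dot_of_unit = torch.dot(F.normalize(a, dim=0), F.normalize(b, dim=0))
cos = F.cosine_similarity(a, b, dim=0)
```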

Max normalization

  • After normalization, the maximum value of the vector is 1
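Max normalization divides each map by its own maximum, so the largest value becomes 1. For CAMs this is done per image and per class, after a ReLU, as in the CAM-cleaning code later in this post. A minimal sketch on random data (the shapes are assumptions):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
cam = F.relu(torch.randn(2, 21, 8, 8))   # [N, C, H, W], keep positive activations only
n, c = cam.shape[:2]
# Per-(image, class) maximum over all spatial positions
cam_max = cam.view(n, c, -1).max(dim=-1)[0].view(n, c, 1, 1)
# The small epsilon avoids division by zero for all-zero maps
cam_norm = cam / (cam_max + 1e-5)
```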

Max-min normalization

  • After normalization, the vector's values lie in [0, 1]
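Max-min normalization subtracts the minimum and divides by the range (max - min), which maps each map onto [0, 1]. The prototype code later in this post uses a variant of this with small epsilon offsets. A per-channel sketch on random data (the shapes are assumptions):

```python
import torch

torch.manual_seed(0)
cam = torch.rand(2, 21, 8, 8) * 5        # [N, C, H, W]
n, c = cam.shape[:2]
flat = cam.view(n, c, -1)
cam_max = flat.max(dim=-1)[0].view(n, c, 1, 1)
cam_min = flat.min(dim=-1)[0].view(n, c, 1, 1)
# (x - min) / (max - min): each map's minimum goes to 0 and its maximum to (almost) 1
cam_norm = (cam - cam_min) / (cam_max - cam_min + 1e-5)
```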

Forward

  • cam

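    # fea: feature map output by the last backbone layer, [N, 4096, h, w]
    # fc8 (normally defined in __init__): 1x1 conv mapping 4096-d features to 21 class maps (20 VOC classes + background)
    self.fc8 = nn.Conv2d(4096, 21, 1, bias=False)
    cam = self.fc8(fea)
    # upsample the CAM to the input image size (H, W)
    cam = torch.nn.functional.interpolate(cam, (H, W), mode='bilinear', align_corners=True)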
  • cam_rv_down

    • Clean the CAM

      n, c, h, w = cam.size()
      with torch.no_grad():
          cam_d = torch.nn.functional.relu(cam.detach())
          # max normalization: scale each class map so that its maximum is 1
          cam_d_max = torch.max(cam_d.view(n, c, -1), dim=-1)[0].view(n, c, 1, 1) + 1e-5
          cam_d_norm = torch.nn.functional.relu(cam_d - 1e-5) / cam_d_max
          # background score = 1 - highest foreground score
          cam_d_norm[:, 0, :, :] = 1 - torch.max(cam_d_norm[:, 1:, :, :], dim=1)[0]
          # keep only the highest-scoring foreground class at each pixel, zero out the others
          cam_max = torch.max(cam_d_norm[:, 1:, :, :], dim=1, keepdim=True)[0]
          cam_d_norm[:, 1:, :, :][cam_d_norm[:, 1:, :, :] < cam_max] = 0

    • Refine the CAM

      # refine the CAM with pixel affinities computed from features f (PCM: pixel correlation module)
      cam_rv_down = self.PCM(cam_d_norm, f)

      # PCM
      def PCM(self, cam, f):
          n, c, h, w = f.size()
          cam = torch.nn.functional.interpolate(cam, (h, w), mode='bilinear', align_corners=True).view(n, -1, h*w)
          # fuse the multi-scale features
          f = self.f9(f)
          f = f.view(n, -1, h*w)
          # L2-normalize the features along the channel dimension
          f = f / (torch.norm(f, dim=1, keepdim=True) + 1e-5)
          # pixel affinity matrix [n, hw, hw]; ReLU drops negative similarities
          aff = torch.nn.functional.relu(torch.matmul(f.transpose(1, 2), f), inplace=True)
          # L1-normalize the affinity matrix so that each column sums to 1
          aff = aff / (torch.sum(aff, dim=1, keepdim=True) + 1e-5)
          # weight the CAM: each output pixel is an affinity-weighted average of CAM values
          cam_rv = torch.matmul(cam, aff).view(n, -1, h, w)

          return cam_rv
  • cam_rv

    # upsample the refined CAM to the input image size (H, W)
    cam_rv = torch.nn.functional.interpolate(cam_rv_down, (H,W), mode='bilinear', align_corners=True)
  • f_proj

    # projection head: 1x1 conv mapping 4096-d features to 128-d pixel embeddings
    self.fc_proj = torch.nn.Conv2d(4096, 128, 1, bias=False)
    f_proj = torch.nn.functional.relu(self.fc_proj(fea), inplace=True)
  • prototype

    # resize the view-1 embeddings and refined CAMs to a (128 // 8) x (128 // 8) = 16 x 16 grid
    f_proj1 = torch.nn.functional.interpolate(f_proj1, size=(128 // 8, 128 // 8), mode='bilinear', align_corners=True)
    cam_rv1_down = torch.nn.functional.interpolate(cam_rv1_down, size=(128 // 8, 128 // 8), mode='bilinear', align_corners=True)
    cam_rv2_down = cam_rv2_down  # view 2 is used as-is

    with torch.no_grad():
        fea1 = f_proj1.detach()
        c_fea1 = fea1.shape[1]
        cam_rv1_down = torch.nn.functional.relu(cam_rv1_down.detach())
        # max-min normalize each CAM channel
        n1, c1, h1, w1 = cam_rv1_down.shape
        max1 = torch.max(cam_rv1_down.view(n1, c1, -1), dim=-1)[0].view(n1, c1, 1, 1)
        min1 = torch.min(cam_rv1_down.view(n1, c1, -1), dim=-1)[0].view(n1, c1, 1, 1)
        cam_rv1_down[cam_rv1_down < min1 + 1e-5] = 0.
        norm_cam1 = (cam_rv1_down - min1 - 1e-5) / (max1 - min1 + 1e-5)
        cam_rv1_down = norm_cam1
        # use a fixed background score (threshold)
        cam_rv1_down[:, 0, :, :] = args.bg_threshold
        # zero out classes absent from the image-level label, then softmax over classes
        scores1 = torch.nn.functional.softmax(cam_rv1_down * label, dim=1)

        # pseudo label: highest-scoring class at each pixel
        pseudo_label1 = scores1.argmax(dim=1, keepdim=True)
        n_sc1, c_sc1, h_sc1, w_sc1 = scores1.shape
        scores1 = scores1.transpose(0, 1)
        # [N, C, H, W] -> [N*H*W, C]
        fea1 = fea1.permute(0, 2, 3, 1).reshape(-1, c_fea1)

        # top-k CAM values and their pixel indices for each class, with k = H*W // 8
        top_values, top_indices = torch.topk(cam_rv1_down.transpose(0, 1).reshape(c_sc1, -1), k=h_sc1 * w_sc1 // 8, dim=-1)
        prototypes1 = torch.zeros(c_sc1, c_fea1, device=fea1.device)  # [21, 128]
        # iterate over the classes
        for i in range(c_sc1):
            # features of the k selected pixels
            top_fea = fea1[top_indices[i]]
            # class prototype = CAM-weighted average of the k features
            prototypes1[i] = torch.sum(top_values[i].unsqueeze(-1) * top_fea, dim=0) / torch.sum(top_values[i])
        # L2-normalize each prototype
        prototypes1 = torch.nn.functional.normalize(prototypes1, dim=-1)
    # prototypes2 and pseudo_label2 are computed the same way from view 2
  • prototype similarity

    # view 1: pixels of view 1 vs. prototypes of view 2
    n_f, c_f, h_f, w_f = f_proj1.shape
    # [N, C, H, W] -> [N, H, W, C] -> [N*H*W, C]
    f_proj1 = f_proj1.permute(0, 2, 3, 1).reshape(n_f * h_f * w_f, c_f)
    # L2-normalize the pixel embeddings
    f_proj1 = torch.nn.functional.normalize(f_proj1, dim=-1)
    pseudo_label1 = pseudo_label1.reshape(-1)
    # positive: view-2 prototype of each pixel's pseudo class; negatives: all view-2 prototypes
    positives1 = prototypes2[pseudo_label1]
    negatives1 = prototypes2

    # view 2 (the "target" view): pixels of view 2 vs. prototypes of view 1
    n_f, c_f, h_f, w_f = f_proj2.shape
    f_proj2 = f_proj2.permute(0, 2, 3, 1).reshape(n_f * h_f * w_f, c_f)
    f_proj2 = torch.nn.functional.normalize(f_proj2, dim=-1)
    pseudo_label2 = pseudo_label2.reshape(-1)
    positives2 = prototypes1[pseudo_label2]
    negatives2 = prototypes1

    # InfoNCE with temperature 0.1: -log(exp(s_pos / 0.1) / sum_c exp(s_c / 0.1))
    A1 = torch.exp(torch.sum(f_proj1 * positives1, dim=-1) / 0.1)
    A2 = torch.sum(torch.exp(torch.matmul(f_proj1, negatives1.transpose(0, 1)) / 0.1), dim=-1)
    loss_nce1 = torch.mean(-1 * torch.log(A1 / A2))

    A3 = torch.exp(torch.sum(f_proj2 * positives2, dim=-1) / 0.1)
    A4 = torch.sum(torch.exp(torch.matmul(f_proj2, negatives2.transpose(0, 1)) / 0.1), dim=-1)
    loss_nce2 = torch.mean(-1 * torch.log(A3 / A4))

    # cross prototype contrast loss: mean of both directions, weighted by 0.1
    loss_cross_nce = 0.1 * (loss_nce1 + loss_nce2) / 2
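The PCM step can be examined in isolation with a standalone reimplementation. This version omits the learned `self.f9` projection, and that omission, the function name `pcm_refine` and the toy shapes are assumptions made for illustration. Each output pixel is an affinity-weighted average of the input CAM, so when every pixel has the same feature, every pixel receives the spatial mean of the CAM.

```python
import torch
import torch.nn.functional as F


def pcm_refine(cam, feat, eps=1e-5):
    """Hypothetical standalone PCM without f9: refine cam [N, C, h0, w0] using feat [N, D, h, w]."""
    n, d, h, w = feat.shape
    cam = F.interpolate(cam, (h, w), mode="bilinear", align_corners=True).view(n, -1, h * w)
    f = feat.view(n, d, h * w)
    f = f / (torch.norm(f, dim=1, keepdim=True) + eps)       # unit-length pixel features
    aff = F.relu(torch.matmul(f.transpose(1, 2), f))         # [N, hw, hw], non-negative cosine affinities
    aff = aff / (torch.sum(aff, dim=1, keepdim=True) + eps)  # each column sums to ~1
    return torch.matmul(cam, aff).view(n, -1, h, w)          # affinity-weighted average of CAM values


# Identical features at every pixel: each pixel should receive the spatial mean of the CAM
torch.manual_seed(0)
cam = torch.rand(1, 21, 4, 4)
feat = torch.ones(1, 8, 4, 4)
refined = pcm_refine(cam, feat)
```

With real features, pixels that have similar embeddings exchange CAM values, which spreads activation over regions of the same object.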

Author: derolol
Posted on: 2023-11-14
Updated on: 2023-11-14
