def compute(self, data: Dict[str, torch.Tensor],
            num_objects: List[int]) -> Dict[str, torch.Tensor]:
    """Compute cross-entropy + dice segmentation losses over all frames after the first.

    Args:
        data: batch dictionary; expects 'rgb' (batch, frames, ...), 'cls_gt'
            (batch, frames, ...), and per-frame entries 'logits_{t}' and
            'aux_{t}' for t in [1, num_frames). (Shapes inferred from the
            indexing below -- confirm against the caller.)
        num_objects: number of objects in each batch element (length = batch size).

    Returns:
        Dict of accumulated scalar losses; 'total_loss' is the sum of every
        other entry (main + auxiliary terms).
    """
    batch_size, num_frames = data['rgb'].shape[:2]
    losses = defaultdict(float)
    t_range = range(1, num_frames)

    # The per-frame aux dicts do not depend on the batch index, so build
    # the list once instead of per iteration.
    aux = [data[f'aux_{ti}'] for ti in t_range]

    for bi in range(batch_size):
        # Stack logits for frames 1..T-1, keeping only the channels for
        # background + the objects actually present in this batch element.
        logits = torch.stack(
            [data[f'logits_{ti}'][bi, :num_objects[bi] + 1] for ti in t_range], dim=0)
        cls_gt = data['cls_gt'][bi, 1:]  # remove gt for the first frame
        soft_gt = cls_to_one_hot(cls_gt, num_objects[bi])

        loss_ce, loss_dice = self.mask_loss(logits, soft_gt)
        # Divide by batch_size so the accumulated sum is a batch mean.
        losses['loss_ce'] += loss_ce / batch_size
        losses['loss_dice'] += loss_dice / batch_size

        # FIX: the auxiliary-loss section must run INSIDE the per-batch loop --
        # it indexes tensors with `bi` and reuses this element's `soft_gt`.
        # In the flattened source it appeared to sit outside the loop, which
        # would be an error (answering the question appended to this file).
        if 'sensory_logits' in aux[0]:
            sensory_log = torch.stack(
                [a['sensory_logits'][bi, :num_objects[bi] + 1] for a in aux], dim=0)
            loss_ce, loss_dice = self.mask_loss(sensory_log, soft_gt)
            losses['aux_sensory_ce'] += loss_ce / batch_size * self.sensory_weight
            losses['aux_sensory_dice'] += loss_dice / batch_size * self.sensory_weight

        if 'q_logits' in aux[0]:
            # One auxiliary loss per query level (dimension 2 of q_logits).
            num_levels = aux[0]['q_logits'].shape[2]
            for l in range(num_levels):
                query_log = torch.stack(
                    [a['q_logits'][bi, :num_objects[bi] + 1, l] for a in aux], dim=0)
                loss_ce, loss_dice = self.mask_loss(query_log, soft_gt)
                losses[f'aux_query_ce_l{l}'] += loss_ce / batch_size * self.query_weight
                losses[f'aux_query_dice_l{l}'] += loss_dice / batch_size * self.query_weight

    losses['total_loss'] = sum(losses.values())
    return losses
I notice that the section starting at
`aux = [data[f'aux_{ti}'] for ti in t_range]` does not appear to be inside the
`for bi` loop — is that an error? It looks like it must be: the lines that
follow index tensors with `bi` and reuse the per-element `soft_gt`, both of
which only exist inside the loop body.