Skip to content

Is it a error in loss function ? #2

@YChienHung

Description

@YChienHung
    def compute(self, data: Dict[str, torch.Tensor],
                num_objects: List[int]) -> Dict[str, torch.Tensor]:
        batch_size, num_frames = data['rgb'].shape[:2]
        losses = defaultdict(float)
        t_range = range(1, num_frames)

        for bi in range(batch_size):
            logits = torch.stack(
                [data[f'logits_{ti}'][bi, :num_objects[bi] + 1] for ti in t_range], dim=0)

            cls_gt = data['cls_gt'][bi, 1:]  # remove gt for the first frame
            soft_gt = cls_to_one_hot(cls_gt, num_objects[bi])

            loss_ce, loss_dice = self.mask_loss(logits, soft_gt)
            losses['loss_ce'] += loss_ce / batch_size
            losses['loss_dice'] += loss_dice / batch_size
        
        # start
        aux = [data[f'aux_{ti}'] for ti in t_range]
        if 'sensory_logits' in aux[0]:
            sensory_log = torch.stack(
                [a['sensory_logits'][bi, :num_objects[bi] + 1] for a in aux], dim=0)
            loss_ce, loss_dice = self.mask_loss(sensory_log, soft_gt)
            losses['aux_sensory_ce'] += loss_ce / batch_size * self.sensory_weight
            losses['aux_sensory_dice'] += loss_dice / batch_size * self.sensory_weight

        if 'q_logits' in aux[0]:
            num_levels = aux[0]['q_logits'].shape[2]

            for l in range(num_levels):
                query_log = torch.stack(
                    [a['q_logits'][bi, :num_objects[bi] + 1, l] for a in aux], dim=0)

                loss_ce, loss_dice = self.mask_loss(query_log, soft_gt)
                losses[f'aux_query_ce_l{l}'] += loss_ce / batch_size * self.query_weight
                losses[f'aux_query_dice_l{l}'] += loss_dice / batch_size * self.query_weight
         # end

        losses['total_loss'] = sum(losses.values())

        return losses

I find that the code start from aux = [data[f'aux_{ti}'] for ti in t_range] is not in the loop, so is it a error ?

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions