losses

dreem.training.losses

Module containing different loss functions to be optimized.

AssoLoss

Bases: Module

Default association loss used for training the GTR model.

Source code in dreem/training/losses.py
class AssoLoss(nn.Module):
    """Default association loss used for training GTR model."""

    def __init__(
        self,
        neg_unmatched: bool = False,
        epsilon: float = 1e-4,
        asso_weight: float = 1.0,
    ):
        """Initialize Loss function.

        Args:
            neg_unmatched: Whether or not to set unmatched objects to background
            epsilon: small number used for numeric precision to prevent dividing by zero
            asso_weight: How much to weight the association loss by
        """
        super().__init__()

        self.neg_unmatched = neg_unmatched
        self.epsilon = epsilon
        self.asso_weight = asso_weight

    def forward(
        self, asso_preds: List[torch.Tensor], frames: List["Frame"]
    ) -> torch.Tensor:
        """Calculate association loss.

        Args:
            asso_preds: a list containing the association matrix at each frame
            frames: a list of Frames containing gt labels.

        Returns:
            the association loss between predicted association and actual
        """
        # get number of detected objects and ground truth ids
        n_t = [frame.num_detected for frame in frames]
        try:
            target_inst_id = torch.cat(
                [frame.get_gt_track_ids().to(asso_preds[-1].device) for frame in frames]
            )
        except RuntimeError as e:
            print([frame.get_gt_track_ids().device for frame in frames])
            raise (e)
        instances = [instance for frame in frames for instance in frame.instances]

        # for now set equal since detections are fixed
        pred_box = get_boxes(instances)
        pred_time, _ = get_times(instances)
        pred_box = torch.nanmean(pred_box, axis=1)
        target_box, target_time = pred_box, pred_time

        # todo: we should maybe reconsider how we label gt instances. The second
        # criterion will return true on a single instance video, for example.
        # For now we can ignore this since we train on dense labels.

        """
            # Return a 0 loss if any of the 2 criteria are met
            # 1. the video doesn’t have gt bboxes
            # 2. the maximum id is zero

        sum_instance_lengths = sum(len(x) for x in instances)
        max_instance_lengths = max(
            x["gt_track_ids"].max().item() for x in instances if len(x) > 0
        )

        if sum_instance_lengths == 0 or max_instance_lengths == 0:
            print("No bounding boxes detected, returning zero loss")
            print(f"Sum instance lengths: {sum_instance_lengths}")
            print(f"Max instance lengths: {max_instance_lengths}")
            loss = asso_preds[0].new_zeros((1,), dtype=torch.float32)[0]
            return loss
        """

        asso_gt, match_cues = self._get_asso_gt(
            pred_box, pred_time, target_box, target_time, target_inst_id, n_t
        )

        loss = sum(
            [
                self.detr_asso_loss(asso_pred, asso_gt, match_cues, n_t)
                for asso_pred in asso_preds
            ]
        )

        loss *= self.asso_weight

        return loss

    def _get_asso_gt(
        self,
        pred_box: torch.Tensor,
        pred_time: torch.Tensor,
        target_box: torch.Tensor,
        target_time: torch.Tensor,
        target_inst_id: torch.Tensor,
        n_t: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute the association ground truth for a batch.

        Args:
            pred_box: predicted bounding boxes (N x 4)
            pred_time: predicted time intervals (N,)
            target_box: target bounding boxes (N x 4)
            target_time: target time intervals (N,)
            target_inst_id: target instance IDs (N,)
            n_t: number of detections in each frame (length T)

        Returns:
            A tuple containing:
                asso_gt: Ground truth association matrix (K x T) denoting ground
                    truth instances over time
                match_cues: Index of the ground truth track assigned to each
                    detection, or -1 if unmatched (N,)
        """
        # compute ious over bboxes, ignore pairs with different time stamps
        ious = torchvision.ops.box_iou(pred_box, target_box)
        ious[pred_time[:, None] != target_time[None, :]] = -1.0

        # get unique instance ids
        inst_ids = torch.unique(target_inst_id[target_inst_id > -1])

        # initialize tensors
        K, N = len(inst_ids), len(pred_box)
        match_cues = pred_box.new_full((N,), -1, dtype=torch.long)
        T = len(n_t)
        asso_gt = pred_box.new_zeros((K, T), dtype=torch.long)

        # split ious by frames
        ious_per_frame = ious.split(n_t, dim=0)

        for k, inst_id in enumerate(inst_ids):
            # get ground truth indices, init index
            target_inds = target_inst_id == inst_id
            base_ind = 0

            for t in range(T):
                # get relevant ious
                iou_t = ious_per_frame[t][:, target_inds]

                # if there are no detections, point asso_gt to the background index (n_t[t])
                if iou_t.numel() == 0:
                    asso_gt[k, t] = n_t[t]
                else:
                    # get max iou and index, select positive ious
                    val, inds = iou_t.max(dim=0)
                    ind = inds[val > 0.0]

                    # make sure there is at most one detection
                    assert len(ind) <= 1, f"{target_inst_id} {n_t}"

                    # if there is one detection with pos IOU, select it
                    if len(ind) == 1:
                        obj_ind = ind[0].item()
                        asso_gt[k, t] = obj_ind
                        match_cues[base_ind + obj_ind] = k

                    # otherwise point to the background index (n_t[t])
                    else:
                        asso_gt[k, t] = n_t[t]

                base_ind += n_t[t]

        return asso_gt, match_cues

    def detr_asso_loss(
        self,
        asso_pred: torch.Tensor,
        asso_gt: torch.Tensor,
        match_cues: torch.Tensor,
        n_t: torch.Tensor,
    ) -> torch.Tensor:
        """Calculate association loss between predicted and gt boxes.

        Args:
            asso_pred: Association matrix output from the transformer forward
                pass denoting predicted instances over time (M x N)
            asso_gt: Ground truth association matrix (K x T) denoting ground
                truth instances over time
            match_cues: Index of the ground truth track assigned to each
                detection, or -1 if unmatched (N,)
            n_t: number of detections in each frame (length T)

        Returns:
            loss: association loss normalized by number of objects
        """
        # get matches between preds and gt
        src_inds, target_inds = self._match(asso_pred, asso_gt, match_cues, n_t)

        loss = 0
        num_objs = 0

        zero = asso_pred.new_zeros((asso_pred.shape[0], 1))  # M x 1
        asso_pred_image = asso_pred.split(n_t, dim=1)  # T x [M x n_t]

        for t in range(len(n_t)):
            # add background class
            asso_pred_with_bg = torch.cat(
                [asso_pred_image[t], zero], dim=1
            )  # M x (n_t + 1)

            if self.neg_unmatched:
                # set unmatched preds to background
                asso_gt_t = asso_gt.new_full((asso_pred.shape[0],), float(n_t[t]))  # M
                asso_gt_t[src_inds] = asso_gt[target_inds, t]  # M
            else:
                # keep only the matched predictions
                asso_pred_with_bg = asso_pred_with_bg[src_inds]  # K x (n_t + 1)
                asso_gt_t = asso_gt[target_inds, t]  # K

            num_objs += (asso_gt_t != n_t[t]).float().sum()

            loss += F.cross_entropy(asso_pred_with_bg, asso_gt_t, reduction="none")

        return loss.sum() / (num_objs + self.epsilon)

    @torch.no_grad()
    def _match(
        self,
        asso_pred: torch.Tensor,
        asso_gt: torch.Tensor,
        match_cues: torch.Tensor,
        n_t: torch.Tensor,
    ) -> torch.Tensor:
        """Match predicted scores to gt scores using match cues.

        Args:
            asso_pred: Association matrix output from the transformer forward
                pass denoting predicted instances over time (M x N)
            asso_gt: Ground truth association matrix (K x T) denoting ground
                truth instances over time
            match_cues: Index of the ground truth track assigned to each
                detection, or -1 if unmatched (N,)
            n_t: number of detections in each frame (length T)

        Returns:
            src_inds: Matched source indices (N,)
            target_inds: Matched target indices (N,)
        """
        src_inds = torch.where(match_cues >= 0)[0]
        target_inds = match_cues[src_inds]

        return (src_inds, target_inds)
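
To make the ground-truth construction concrete, here is a minimal sketch that calls _get_asso_gt on a toy window of two frames with two well-separated detections each. The boxes, times, and track ids below are made up for illustration; in forward they come from get_boxes, get_times, and the Frame ground truth.

import torch

from dreem.training.losses import AssoLoss

loss_fn = AssoLoss()

# two frames, two detections per frame (global order: frame 0, then frame 1)
n_t = [2, 2]
boxes = torch.tensor(
    [
        [0.0, 0.0, 10.0, 10.0],    # frame 0, track 0
        [20.0, 20.0, 30.0, 30.0],  # frame 0, track 1
        [0.0, 0.0, 10.0, 10.0],    # frame 1, track 0
        [20.0, 20.0, 30.0, 30.0],  # frame 1, track 1
    ]
)
times = torch.tensor([0, 0, 1, 1])
track_ids = torch.tensor([0, 1, 0, 1])

# detections are fixed, so predicted and target boxes/times are identical
asso_gt, match_cues = loss_fn._get_asso_gt(boxes, times, boxes, times, track_ids, n_t)

print(asso_gt)     # tensor([[0, 0], [1, 1]]): local detection index of each track per frame
print(match_cues)  # tensor([0, 1, 0, 1]): track index assigned to each detection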

__init__(neg_unmatched=False, epsilon=0.0001, asso_weight=1.0)

Initialize Loss function.

Parameters:

  neg_unmatched (bool, default False): Whether or not to set unmatched objects to background.
  epsilon (float, default 0.0001): Small number used for numeric precision to prevent dividing by zero.
  asso_weight (float, default 1.0): How much to weight the association loss by.
Source code in dreem/training/losses.py
def __init__(
    self,
    neg_unmatched: bool = False,
    epsilon: float = 1e-4,
    asso_weight: float = 1.0,
):
    """Initialize Loss function.

    Args:
        neg_unmatched: Whether or not to set unmatched objects to background
        epsilon: small number used for numeric precision to prevent dividing by zero
        asso_weight: How much to weight the association loss by
    """
    super().__init__()

    self.neg_unmatched = neg_unmatched
    self.epsilon = epsilon
    self.asso_weight = asso_weight
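
As a quick, hypothetical usage sketch of the constructor documented above (values chosen arbitrarily):

from dreem.training.losses import AssoLoss

loss_fn = AssoLoss(
    neg_unmatched=True,  # push unmatched predictions toward the background class
    epsilon=1e-4,        # numerical stabilizer for the normalization denominator
    asso_weight=0.5,     # scale applied to the summed association loss
)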

detr_asso_loss(asso_pred, asso_gt, match_cues, n_t)

Calculate association loss between predicted and gt boxes.

Parameters:

  asso_pred (Tensor, required): Association matrix output from the transformer forward pass denoting predicted instances over time (M x N).
  asso_gt (Tensor, required): Ground truth association matrix (K x T) denoting ground truth instances over time.
  match_cues (Tensor, required): Index of the ground truth track assigned to each detection, or -1 if unmatched (N,).
  n_t (Tensor, required): Number of detections in each frame (length T).

Returns:

  loss (Tensor): Association loss normalized by the number of matched objects.

Source code in dreem/training/losses.py
def detr_asso_loss(
    self,
    asso_pred: torch.Tensor,
    asso_gt: torch.Tensor,
    match_cues: torch.Tensor,
    n_t: torch.Tensor,
) -> torch.Tensor:
    """Calculate association loss between predicted and gt boxes.

    Args:
        asso_pred: Association matrix output from the transformer forward
            pass denoting predicted instances over time (M x N)
        asso_gt: Ground truth association matrix (K x T) denoting ground
            truth instances over time
        match_cues: Index of the ground truth track assigned to each
            detection, or -1 if unmatched (N,)
        n_t: number of detections in each frame (length T)

    Returns:
        loss: association loss normalized by number of objects
    """
    # get matches between preds and gt
    src_inds, target_inds = self._match(asso_pred, asso_gt, match_cues, n_t)

    loss = 0
    num_objs = 0

    zero = asso_pred.new_zeros((asso_pred.shape[0], 1))  # M x 1
    asso_pred_image = asso_pred.split(n_t, dim=1)  # T x [M x n_t]

    for t in range(len(n_t)):
        # add background class
        asso_pred_with_bg = torch.cat(
            [asso_pred_image[t], zero], dim=1
        )  # M x (n_t + 1)

        if self.neg_unmatched:
            # set unmatched preds to background
            asso_gt_t = asso_gt.new_full((asso_pred.shape[0],), float(n_t[t]))  # M
            asso_gt_t[src_inds] = asso_gt[target_inds, t]  # M
        else:
            # keep only the matched predictions
            asso_pred_with_bg = asso_pred_with_bg[src_inds]  # K x (n_t + 1)
            asso_gt_t = asso_gt[target_inds, t]  # K

        num_objs += (asso_gt_t != n_t[t]).float().sum()

        loss += F.cross_entropy(asso_pred_with_bg, asso_gt_t, reduction="none")

    return loss.sum() / (num_objs + self.epsilon)
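
Below is a minimal numeric sketch of detr_asso_loss on a toy window of two frames with two detections each. The asso_pred logits and the matching tensors are fabricated for illustration; in practice asso_gt, match_cues, and n_t are produced by _get_asso_gt inside forward.

import torch

from dreem.training.losses import AssoLoss

loss_fn = AssoLoss()  # neg_unmatched=False by default

n_t = [2, 2]                              # detections per frame (T = 2, N = 4)
asso_gt = torch.tensor([[0, 0], [1, 1]])  # K x T: local index of each track in each frame
match_cues = torch.tensor([0, 1, 0, 1])   # N: track assigned to each detection

# N x N association logits, high on the correct track pairings
asso_pred = torch.tensor(
    [
        [5.0, 0.0, 5.0, 0.0],
        [0.0, 5.0, 0.0, 5.0],
        [5.0, 0.0, 5.0, 0.0],
        [0.0, 5.0, 0.0, 5.0],
    ]
)

loss = loss_fn.detr_asso_loss(asso_pred, asso_gt, match_cues, n_t)
print(loss)  # small positive cross-entropy; a zero background column is appended per frame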

forward(asso_preds, frames)

Calculate association loss.

Parameters:

  asso_preds (List[Tensor], required): A list containing the association matrix at each frame.
  frames (List[Frame], required): A list of Frames containing gt labels.

Returns:

  Tensor: The association loss between the predicted and ground truth associations.

Source code in dreem/training/losses.py
def forward(
    self, asso_preds: List[torch.Tensor], frames: List["Frame"]
) -> torch.Tensor:
    """Calculate association loss.

    Args:
        asso_preds: a list containing the association matrix at each frame
        frames: a list of Frames containing gt labels.

    Returns:
        the association loss between predicted association and actual
    """
    # get number of detected objects and ground truth ids
    n_t = [frame.num_detected for frame in frames]
    try:
        target_inst_id = torch.cat(
            [frame.get_gt_track_ids().to(asso_preds[-1].device) for frame in frames]
        )
    except RuntimeError as e:
        print([frame.get_gt_track_ids().device for frame in frames])
        raise (e)
    instances = [instance for frame in frames for instance in frame.instances]

    # for now set equal since detections are fixed
    pred_box = get_boxes(instances)
    pred_time, _ = get_times(instances)
    pred_box = torch.nanmean(pred_box, axis=1)
    target_box, target_time = pred_box, pred_time

    # todo: we should maybe reconsider how we label gt instances. The second
    # criterion will return true on a single instance video, for example.
    # For now we can ignore this since we train on dense labels.

    """
        # Return a 0 loss if any of the 2 criteria are met
        # 1. the video doesn’t have gt bboxes
        # 2. the maximum id is zero

    sum_instance_lengths = sum(len(x) for x in instances)
    max_instance_lengths = max(
        x["gt_track_ids"].max().item() for x in instances if len(x) > 0
    )

    if sum_instance_lengths == 0 or max_instance_lengths == 0:
        print("No bounding boxes detected, returning zero loss")
        print(f"Sum instance lengths: {sum_instance_lengths}")
        print(f"Max instance lengths: {max_instance_lengths}")
        loss = asso_preds[0].new_zeros((1,), dtype=torch.float32)[0]
        return loss
    """

    asso_gt, match_cues = self._get_asso_gt(
        pred_box, pred_time, target_box, target_time, target_inst_id, n_t
    )

    loss = sum(
        [
            self.detr_asso_loss(asso_pred, asso_gt, match_cues, n_t)
            for asso_pred in asso_preds
        ]
    )

    loss *= self.asso_weight

    return loss
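
At the model level the loss is consumed roughly as sketched below. The model and frames here are placeholders: frames is assumed to be a list of dreem Frame objects carrying gt track ids, and the model is assumed to return a list of association matrices matching the asso_preds argument documented above.

from dreem.training.losses import AssoLoss

loss_fn = AssoLoss(asso_weight=1.0)

def training_step(model, frames):
    """One hypothetical optimization step; model and frames are placeholders."""
    asso_preds = model(frames)          # assumed: List[Tensor], each M x N
    loss = loss_fn(asso_preds, frames)  # summed over predictions, scaled by asso_weight
    return loss

# In a full loop, loss.backward() and an optimizer step would follow.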