Skip to content

post_processing

dreem.inference.post_processing

Helper functions for post-processing association matrix pre-tracking.

Functions:

Name Description
filter_max_center_dist

Filter trajectory score by distances between objects across frames.

weight_decay_time

Weight association matrix by time.

weight_iou

Weight the association matrix by the IOU between object bboxes across frames.

filter_max_center_dist(asso_output, max_center_dist=0, id_inds=None, query_boxes_px=None, nonquery_boxes_px=None)

Filter trajectory score by distances between objects across frames.

Parameters:

Name Type Description Default
asso_output Tensor

An N_t x N association matrix

required
max_center_dist float

The euclidean distance threshold between bboxes

0
id_inds Tensor | None

track ids

None
query_boxes_px Tensor | None

the raw bbox coords of the current frame instances

None
nonquery_boxes_px Tensor | None

the raw bbox coords of the instances in the nonquery frames (context window)

None

Returns:

Type Description
Tensor

An N_t x N association matrix

Source code in dreem/inference/post_processing.py
def filter_max_center_dist(
    asso_output: torch.Tensor,
    max_center_dist: float = 0,
    id_inds: torch.Tensor | None = None,
    query_boxes_px: torch.Tensor | None = None,
    nonquery_boxes_px: torch.Tensor | None = None,
) -> torch.Tensor:
    """Filter trajectory score by distances between objects across frames.

    Args:
        asso_output: An N_t x N association matrix
        max_center_dist: The euclidean distance threshold between bboxes
        id_inds: track ids
        query_boxes_px: the raw bbox coords of the current frame instances
        nonquery_boxes_px: the raw bbox coords of the instances in the nonquery frames (context window)

    Returns:
        An N_t x N association matrix
    """
    if max_center_dist is not None and max_center_dist > 0:
        assert (
            query_boxes_px is not None and nonquery_boxes_px is not None
        ), "Need `query_boxes_px`, and `nonquery_boxes_px` to filter by `max_center_dist`"

        k_ct = (query_boxes_px[:, :, :2] + query_boxes_px[:, :, 2:]) / 2
        # k_s = ((curr_frame_boxes[:, :, 2:] - curr_frame_boxes[:, :, :2]) ** 2).sum(dim=2)  # n_k
        # nonk boxes are only from previous frame rather than entire window
        nonk_ct = (nonquery_boxes_px[:, :, :2] + nonquery_boxes_px[:, :, 2:]) / 2

        # pairwise euclidean distance in units of pixels
        dist = ((k_ct[:, None, :, :] - nonk_ct[None, :, :, :]) ** 2).sum(dim=-1) ** (
            1 / 2
        )  # n_k x n_nonk
        # norm_dist = dist / (k_s[:, None, :] + 1e-8)

        valid = dist.squeeze() < max_center_dist  # n_k x n_nonk
        # handle case where id_inds and valid is a single value
        # handle this better
        if valid.ndim == 0:
            valid = valid.unsqueeze(0)
        if valid.ndim == 1:
            if id_inds.shape[0] == 1:
                valid_mult = valid.float().unsqueeze(-1)
            else:
                valid_mult = valid.float().unsqueeze(0)
        else:
            valid_mult = valid.float()

        valid_assn = (
            torch.mm(valid_mult, id_inds.to(valid.device)).clamp_(max=1.0).long().bool()
        )  # n_k x M
        asso_output_filtered = asso_output.clone()
        asso_output_filtered[~valid_assn] = 0  # n_k x M
        return asso_output_filtered
    else:
        return asso_output

weight_decay_time(asso_output, decay_time=0, reid_features=None, T=None, k=None)

Weight association matrix by time.

Weighs matrix by number of frames the ith object is from the jth object in the association matrix.

Parameters:

Name Type Description Default
asso_output Tensor

the association matrix to be reweighted

required
decay_time float

the scale to weight the asso_output by

0
reid_features Tensor | None

The n x d matrix of feature vectors for each object

None
T int | None

The length of the window

None
k int | None

an integer for the query frame within the window of instances

None

Returns: The N_t x N association matrix weighted by decay time

Source code in dreem/inference/post_processing.py
def weight_decay_time(
    asso_output: torch.Tensor,
    decay_time: float = 0,
    reid_features: torch.Tensor | None = None,
    T: int | None = None,
    k: int | None = None,
) -> torch.Tensor:
    """Weight association matrix by time.

    Weighs matrix by number of frames the ith object is from the jth object
    in the association matrix.

    Args:
        asso_output: the association matrix to be reweighted
        decay_time: the scale to weight the asso_output by
        reid_features: The n x d matrix of feature vectors for each object
        T: The length of the window
        k: an integer for the query frame within the window of instances
    Returns: The N_t x N association matrix weighted by decay time
    """
    if decay_time is not None and decay_time > 0:
        assert (
            reid_features is not None and T is not None and k is not None
        ), "Need reid_features to weight traj_score by `decay_time`!"
        N_t = asso_output.shape[0]
        dts = torch.cat(
            [
                x.new_full((N_t,), T - t - 2)
                for t, x in enumerate(reid_features)
                if t != k
            ],
            dim=0,
        ).cpu()  # Np
        # asso_output = asso_output.to(self.device) * (self.decay_time ** dts[None, :])
        asso_output = asso_output * (decay_time ** dts[:, None])
    return asso_output

weight_iou(asso_output, method=None, last_ious=None)

Weight the association matrix by the IOU between object bboxes across frames.

Parameters:

Name Type Description Default
asso_output Tensor

An N_t x N association matrix

required
method str | None

string indicating whether to use a max weighting or multiplicative weighting Max weighting: take max(traj_score, iou) multiplicative weighting: iou*weight + traj_score

None
last_ious Tensor

torch Tensor containing the ious between current and previous frames

None

Returns:

Type Description
Tensor

An N_t x N association matrix weighted by the IOU

Source code in dreem/inference/post_processing.py
def weight_iou(
    asso_output: torch.Tensor, method: str | None = None, last_ious: torch.Tensor = None
) -> torch.Tensor:
    """Weight the association matrix by the IOU between object bboxes across frames.

    Args:
        asso_output: An N_t x N association matrix
        method: string indicating whether to use a max weighting or multiplicative weighting
                Max weighting: take `max(traj_score, iou)`
                multiplicative weighting: `iou*weight + traj_score`
        last_ious: torch Tensor containing the ious between current and previous frames

    Returns:
        An N_t x N association matrix weighted by the IOU
    """
    if method is not None and method != "":
        assert last_ious is not None, "Need `last_ious` to weight traj_score by `IOU`"
        if method.lower() == "mult":
            weights = torch.abs(last_ious - asso_output)
            weighted_iou = weights * last_ious
            weighted_iou = torch.nan_to_num(weighted_iou, 0)
            asso_output = asso_output + weighted_iou
        elif method.lower() == "max":
            asso_output = torch.max(asso_output, last_ious)
        else:
            raise ValueError(
                f"`method` must be one of ['mult' or 'max'] got '{method.lower()}'"
            )
    return asso_output