Skip to content

post_processing

dreem.inference.post_processing

Helper functions for post-processing association matrix pre-tracking.

filter_max_center_dist(asso_output, max_center_dist=0, k_boxes=None, nonk_boxes=None, id_inds=None)

Filter trajectory score by distances between objects across frames.

Parameters:

Name Type Description Default
asso_output Tensor

An N_t x N association matrix

required
max_center_dist float

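Threshold on the squared distance between box centers, normalized by the squared diagonal of the current-frame box. Values of None or <= 0 disable the filter.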

0
k_boxes Tensor

The bounding boxes in the current frame

None
nonk_boxes Tensor

the boxes not in the current frame

None
id_inds Tensor

track ids

None

Returns:

Type Description
Tensor

An N_t x N association matrix

Source code in dreem/inference/post_processing.py
def filter_max_center_dist(
    asso_output: torch.Tensor,
    max_center_dist: float = 0,
    k_boxes: torch.Tensor = None,
    nonk_boxes: torch.Tensor = None,
    id_inds: torch.Tensor = None,
) -> torch.Tensor:
    """Zero out association scores for trajectories that are too far away.

    For every (current box, other box) pair the squared distance between
    the box centers is divided by the squared diagonal of the current box.
    A pair is "valid" when that ratio is strictly below `max_center_dist`.
    A (current box, trajectory) score is kept only if at least one valid
    other box is assigned to that trajectory via `id_inds`; all remaining
    scores are set to 0.

    Args:
        asso_output: An N_t x N association matrix. It is not modified; a
            filtered copy is returned when filtering is active.
        max_center_dist: Threshold on the normalized squared center
            distance (not a plain euclidean distance). `None` or a value
            `<= 0` disables filtering.
        k_boxes: The bounding boxes in the current frame (n_k rows). The
            code reads columns `[:2]` and `[2:]` as the two corner points
            -- presumably (x1, y1, x2, y2); confirm against the caller.
        nonk_boxes: The boxes not in the current frame (Np rows), in the
            same layout as `k_boxes`.
        id_inds: Np x M matrix linking each non-current box to track ids.
            Any dtype that can be cast to float32 (bool, integer, float)
            is accepted.

    Returns:
        `asso_output` itself when filtering is disabled, otherwise a
        filtered copy of the association matrix.

    Raises:
        AssertionError: If filtering is enabled but `k_boxes`,
            `nonk_boxes` or `id_inds` is missing.
    """
    if max_center_dist is not None and max_center_dist > 0:
        assert (
            k_boxes is not None and nonk_boxes is not None and id_inds is not None
        ), "Need `k_boxes`, `nonk_boxes`, and `id_inds` to filter by `max_center_dist`"

        # Centers and squared diagonals of the current-frame boxes.
        k_ct = (k_boxes[:, :2] + k_boxes[:, 2:]) / 2  # n_k x 2
        k_s = ((k_boxes[:, 2:] - k_boxes[:, :2]) ** 2).sum(dim=1)  # n_k

        nonk_ct = (nonk_boxes[:, :2] + nonk_boxes[:, 2:]) / 2  # Np x 2
        dist = ((k_ct[:, None] - nonk_ct[None, :]) ** 2).sum(dim=2)  # n_k x Np

        # The epsilon guards against division by zero for degenerate boxes.
        norm_dist = dist / (k_s[:, None] + 1e-8)  # n_k x Np
        valid = norm_dist < max_center_dist  # n_k x Np

        # torch.mm needs both operands to share a dtype, so cast the
        # indicator matrix explicitly instead of relying on the caller.
        id_inds = id_inds.to(device=valid.device, dtype=torch.float32)  # Np x M
        valid_assn = (
            torch.mm(valid.float(), id_inds).clamp_(max=1.0).long().bool()
        )  # n_k x M

        asso_output_filtered = asso_output.clone()
        # Keep the mask on the same device as the tensor it indexes.
        asso_output_filtered[~valid_assn.to(asso_output.device)] = 0  # n_k x M
        return asso_output_filtered
    else:
        return asso_output

weight_decay_time(asso_output, decay_time=0, reid_features=None, T=None, k=None)

Weight association matrix by time.

Weights the matrix by the number of frames separating the ith object from the jth object in the association matrix.

Parameters:

Name Type Description Default
asso_output Tensor

the association matrix to be reweighted

required
decay_time float

the scale to weight the asso_output by

0
reid_features Tensor

The n x d matrix of feature vectors for each object

None
T int

The length of the window

None
k int

an integer for the query frame within the window of instances

None

Returns: The N_t x N association matrix weighted by decay time

Source code in dreem/inference/post_processing.py
def weight_decay_time(
    asso_output: torch.Tensor,
    decay_time: float = 0,
    reid_features: torch.Tensor = None,
    T: int = None,
    k: int = None,
) -> torch.Tensor:
    """Scale the association matrix by a power of `decay_time`.

    For every index `t` along the first dimension of `reid_features`
    (except `t == k`) an exponent `T - t - 2` is repeated once per row of
    `asso_output`. The concatenated exponents are moved to the CPU and
    `decay_time ** exponent` is multiplied into `asso_output` as a column
    vector.

    Args:
        asso_output: the association matrix to be reweighted
        decay_time: the base of the decay factor; `None` or a value `<= 0`
            leaves `asso_output` untouched
        reid_features: The n x d matrix of feature vectors for each object;
            only its first dimension, dtype and device are used here
        T: The length of the window
        k: an integer for the query frame within the window of instances;
            this index is skipped when building the exponents

    Returns:
        `asso_output` itself when no decay is requested, otherwise the
        association matrix weighted by decay time.

    Raises:
        AssertionError: If decay is requested but `reid_features`, `T` or
            `k` is missing.
    """
    if decay_time is None or not decay_time > 0:
        return asso_output

    assert (
        reid_features is not None and T is not None and k is not None
    ), "Need reid_features to weight traj_score by `decay_time`!"

    n_rows = asso_output.shape[0]

    # One block of `n_rows` identical exponents per non-query index.
    exponent_blocks = []
    for frame_idx, feats in enumerate(reid_features):
        if frame_idx == k:
            continue
        exponent_blocks.append(feats.new_full((n_rows,), T - frame_idx - 2))

    exponents = torch.cat(exponent_blocks, dim=0).cpu()  # Np

    # NOTE(review): the exponents are applied along dim 0 (as a column);
    # confirm this orientation matches the layout of `asso_output`.
    return asso_output * (decay_time ** exponents[:, None])

weight_iou(asso_output, method=None, last_ious=None)

Weight the association matrix by the IOU between object bboxes across frames.

Parameters:

Name Type Description Default
asso_output Tensor

An N_t x N association matrix

required
method str

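String indicating whether to use max weighting or multiplicative weighting (case-insensitive). Max weighting ("max"): take max(traj_score, iou). Multiplicative weighting ("mult"): iou*weight + traj_score, where weight = |iou - traj_score|. None or an empty string disables weighting.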

None
last_ious Tensor

torch Tensor containing the ious between current and previous frames

None

Returns:

Type Description
Tensor

An N_t x N association matrix weighted by the IOU

Source code in dreem/inference/post_processing.py
def weight_iou(
    asso_output: torch.Tensor, method: str = None, last_ious: torch.Tensor = None
) -> torch.Tensor:
    """Combine the association matrix with IOUs between boxes across frames.

    Args:
        asso_output: An N_t x N association matrix
        method: which weighting to apply (case-insensitive); `None` or an
                empty string returns `asso_output` unchanged
                "max": element-wise `max(traj_score, iou)`
                "mult": `iou*weight + traj_score` with
                `weight = |iou - traj_score|`; NaNs in the added term are
                replaced by 0
        last_ious: torch Tensor containing the ious between current and
                previous frames

    Returns:
        An N_t x N association matrix weighted by the IOU

    Raises:
        AssertionError: If a method is given but `last_ious` is missing.
        ValueError: If `method` is neither "mult" nor "max".
    """
    if method is None or method == "":
        return asso_output

    assert last_ious is not None, "Need `last_ious` to weight traj_score by `IOU`"

    mode = method.lower()

    if mode == "max":
        return torch.max(asso_output, last_ious)

    if mode == "mult":
        # The IOU contributes more where it disagrees with the score.
        disagreement = torch.abs(last_ious - asso_output)
        iou_term = torch.nan_to_num(disagreement * last_ious, 0)
        return asso_output + iou_term

    raise ValueError(
        f"`method` must be one of ['mult' or 'max'] got '{method.lower()}'"
    )