tracker
dreem.inference.tracker
¶
Module containing logic for going from association -> assignment.
Tracker
¶
Tracker class used for assignment based on sliding inference from GTR.
Source code in dreem/inference/tracker.py
class Tracker:
"""Tracker class used for assignment based on sliding inference from GTR."""
def __init__(
self,
window_size: int = 8,
use_vis_feats: bool = True,
overlap_thresh: float = 0.01,
mult_thresh: bool = True,
decay_time: float = None,
iou: str = None,
max_center_dist: float = None,
persistent_tracking: bool = False,
max_gap: int = inf,
max_tracks: int = inf,
verbose: bool = False,
):
"""Initialize a tracker to run inference.
Args:
window_size: the size of the window used during sliding inference.
use_vis_feats: Whether or not to use visual feature extractor.
overlap_thresh: the trajectory overlap threshold to be used for assignment.
mult_thresh: Whether or not to use weight threshold.
decay_time: weight for `decay_time` postprocessing.
iou: Either [None, '', "mult" or "max"]
Whether to use multiplicative or max iou reweighting.
max_center_dist: distance threshold for filtering trajectory score matrix.
persistent_tracking: whether to keep a buffer across chunks or not.
max_gap: the max number of frames a trajectory can be missing before termination.
max_tracks: the maximum number of tracks that can be created while tracking.
We force the tracker to assign instances to a track instead of creating a new track if max_tracks has been reached.
verbose: Whether or not to turn on debug printing after each operation.
"""
self.track_queue = TrackQueue(
window_size=window_size, max_gap=max_gap, verbose=verbose
)
self.use_vis_feats = use_vis_feats
self.overlap_thresh = overlap_thresh
self.mult_thresh = mult_thresh
self.decay_time = decay_time
self.iou = iou
self.max_center_dist = max_center_dist
self.persistent_tracking = persistent_tracking
self.verbose = verbose
self.max_tracks = max_tracks
def __call__(
self, model: GlobalTrackingTransformer, frames: list[Frame]
) -> list[Frame]:
"""Wrap around `track` to enable `tracker()` instead of `tracker.track()`.
Args:
model: the pretrained GlobalTrackingTransformer to be used for inference
frames: list of Frames to run inference on
Returns:
List of frames containing association matrix scores and instances populated with pred track ids.
"""
return self.track(model, frames)
def track(
self, model: GlobalTrackingTransformer, frames: list[Frame]
) -> list[Frame]:
"""Run tracker and get predicted trajectories.
Args:
model: the pretrained GlobalTrackingTransformer to be used for inference
frames: list of Frames to run inference on
Returns:
List of Frames populated with pred track ids and association matrix scores
"""
# Extract feature representations with pre-trained encoder.
_ = model.eval()
for frame in frames:
if frame.has_instances():
if not self.use_vis_feats:
for instance in frame.instances:
instance.features = torch.zeros(1, model.d_model)
# frame["features"] = torch.randn(
# num_frame_instances, self.model.d_model
# )
# comment out to turn encoder off
# Assuming the encoder is already trained or train encoder jointly.
elif not frame.has_features():
with torch.no_grad():
crops = frame.get_crops()
z = model.visual_encoder(crops)
for i, z_i in enumerate(z):
frame.instances[i].features = z_i
# I feel like this chunk is unnecessary:
# reid_features = torch.cat(
# [frame["features"] for frame in instances], dim=0
# ).unsqueeze(0)
# asso_preds, pred_boxes, pred_time, embeddings = self.model(
# instances, reid_features
# )
instances_pred = self.sliding_inference(model, frames)
if not self.persistent_tracking:
if self.verbose:
warnings.warn(f"Clearing Queue after tracking")
self.track_queue.end_tracks()
return instances_pred
def sliding_inference(
self, model: GlobalTrackingTransformer, frames: list[Frame]
) -> list[Frame]:
"""Perform sliding inference on the input video (instances) with a given window size.
Args:
model: the pretrained GlobalTrackingTransformer to be used for inference
frames: A list of Frames (See `dreem.io.Frame` for more info).
Returns:
frames: A list of Frames populated with pred_track_ids and asso_matrices
"""
# B: batch size.
# D: embedding dimension.
# nc: number of channels.
# H: height.
# W: width.
for batch_idx, frame_to_track in enumerate(frames):
tracked_frames = self.track_queue.collate_tracks(
device=frame_to_track.frame_id.device
)
if self.verbose:
warnings.warn(
f"Current number of tracks is {self.track_queue.n_tracks}"
)
if (
self.persistent_tracking and frame_to_track.frame_id == 0
): # check for new video and clear queue
if self.verbose:
warnings.warn("New Video! Resetting Track Queue.")
self.track_queue.end_tracks()
"""
Initialize tracks on first frame where detections appear.
"""
if len(self.track_queue) == 0:
if frame_to_track.has_instances():
if self.verbose:
warnings.warn(
f"Initializing track on clip ind {batch_idx} frame {frame_to_track.frame_id.item()}"
)
curr_track_id = 0
for i, instance in enumerate(frames[batch_idx].instances):
instance.pred_track_id = instance.gt_track_id
curr_track_id = max(curr_track_id, instance.pred_track_id)
for i, instance in enumerate(frames[batch_idx].instances):
if instance.pred_track_id == -1:
curr_track_id += 1
instance.pred_track_id = curr_track_id
else:
if (
frame_to_track.has_instances()
): # Check if there are detections. If there are none, skip and increment the gap count
frames_to_track = tracked_frames + [
frame_to_track
] # better var name?
query_ind = len(frames_to_track) - 1
frame_to_track = self._run_global_tracker(
model,
frames_to_track,
query_ind=query_ind,
)
if frame_to_track.has_instances():
self.track_queue.add_frame(frame_to_track)
else:
self.track_queue.increment_gaps([])
frames[batch_idx] = frame_to_track
return frames
def _run_global_tracker(
self, model: GlobalTrackingTransformer, frames: list[Frame], query_ind: int
) -> Frame:
"""Run global tracker performs the actual tracking.
Uses Hungarian algorithm to do track assigning.
Args:
model: the pretrained GlobalTrackingTransformer to be used for inference
frames: A list of Frames containing reid features. See `dreem.io.data_structures` for more info.
query_ind: An integer for the query frame within the window of instances.
Returns:
query_frame: The query frame now populated with the pred_track_ids.
"""
# *: each item in frames is a frame in the window. So it follows
# that each frame in the window has * detected instances.
# D: embedding dimension.
# total_instances: number of instances in the window.
# N_i: number of detected instances in i-th frame of window.
# instances_per_frame: a list of number of instances in each frame of the window.
# n_query: number of instances in current/query frame (rightmost frame of the window).
# n_nonquery: number of instances in the window excluding the current/query frame.
# window_size: length of window.
# L: number of decoder blocks.
# n_traj: number of existing tracks within the window so far.
# Number of instances in each frame of the window.
# E.g.: instances_per_frame: [4, 5, 6, 7]; window of length 4 with 4 detected instances in the first frame of the window.
_ = model.eval()
query_frame = frames[query_ind]
query_instances = query_frame.instances
all_instances = [instance for frame in frames for instance in frame.instances]
if self.verbose:
print(f"Frame {query_frame.frame_id.item()}")
instances_per_frame = [frame.num_detected for frame in frames]
total_instances, window_size = sum(instances_per_frame), len(
instances_per_frame
) # Number of instances in window; length of window.
if self.verbose:
print(f"total_instances: {total_instances}")
overlap_thresh = self.overlap_thresh
mult_thresh = self.mult_thresh
n_traj = self.track_queue.n_tracks
curr_track = self.track_queue.curr_track
reid_features = torch.cat([frame.get_features() for frame in frames], dim=0)[
None
] # (1, total_instances, D=512)
# (L=1, n_query, total_instances)
with torch.no_grad():
asso_matrix = model(all_instances, query_instances)
# if model.transformer.return_embedding:
# query_frame.embeddings = embed TODO add embedding to Instance Object
# if query_frame == 1:
# print(asso_output)
asso_output = asso_matrix[-1].matrix.split(
instances_per_frame, dim=1
) # (window_size, n_query, N_i)
asso_output = model_utils.softmax_asso(
asso_output
) # (window_size, n_query, N_i)
asso_output = torch.cat(asso_output, dim=1).cpu() # (n_query, total_instances)
asso_output_df = pd.DataFrame(
asso_output.clone().numpy(),
columns=[f"Instance {i}" for i in range(asso_output.shape[-1])],
)
asso_output_df.index.name = "Instances"
asso_output_df.columns.name = "Instances"
query_frame.add_traj_score("asso_output", asso_output_df)
query_frame.asso_output = asso_matrix
try:
n_query = (
query_frame.num_detected
) # Number of instances in the current/query frame.
except Exception as e:
print(len(frames), query_frame, frames[-1])
raise (e)
n_nonquery = (
total_instances - n_query
) # Number of instances in the window not including the current/query frame.
if self.verbose:
print(f"n_nonquery: {n_nonquery}")
print(f"n_query: {n_query}")
try:
instance_ids = torch.cat(
[
x.get_pred_track_ids()
for batch_idx, x in enumerate(frames)
if batch_idx != query_ind
],
dim=0,
).view(
n_nonquery
) # (n_nonquery,)
except Exception as e:
print(
[
[instance.pred_track_id.device for instance in frame.instances]
for frame in frames
]
)
raise (e)
query_inds = [
x
for x in range(
sum(instances_per_frame[:query_ind]),
sum(instances_per_frame[: query_ind + 1]),
)
]
nonquery_inds = [i for i in range(total_instances) if i not in query_inds]
# instead should we do model(nonquery_instances, query_instances)?
asso_nonquery = asso_output[:, nonquery_inds] # (n_query, n_nonquery)
asso_nonquery_df = pd.DataFrame(
asso_nonquery.clone().numpy(), columns=nonquery_inds
)
asso_nonquery_df.index.name = "Current Frame Instances"
asso_nonquery_df.columns.name = "Nonquery Instances"
query_frame.add_traj_score("asso_nonquery", asso_nonquery_df)
pred_boxes = model_utils.get_boxes(all_instances)
query_boxes = pred_boxes[query_inds] # n_k x 4
nonquery_boxes = pred_boxes[nonquery_inds] # n_nonquery x 4
unique_ids = torch.unique(instance_ids) # (n_nonquery,)
if self.verbose:
print(f"Instance IDs: {instance_ids}")
print(f"unique ids: {unique_ids}")
id_inds = (
unique_ids[None, :] == instance_ids[:, None]
).float() # (n_nonquery, n_traj)
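# Worked example (hypothetical values): instance_ids = [0, 0, 1] and
# unique_ids = [0, 1] give id_inds = [[1., 0.], [1., 0.], [0., 1.]];
# column j sums to the number of nonquery instances already assigned to
# track j, which later scales the multiplicative match threshold.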
################################################################################
# reweighting hyper-parameters for association -> they use 0.9
traj_score = post_processing.weight_decay_time(
asso_nonquery, self.decay_time, reid_features, window_size, query_ind
)
if self.decay_time is not None and self.decay_time > 0:
decay_time_traj_score = pd.DataFrame(
traj_score.clone().numpy(), columns=nonquery_inds
)
decay_time_traj_score.index.name = "Query Instances"
decay_time_traj_score.columns.name = "Nonquery Instances"
query_frame.add_traj_score("decay_time", decay_time_traj_score)
################################################################################
# (n_query x n_nonquery) x (n_nonquery x n_traj) --> n_query x n_traj
traj_score = torch.mm(traj_score, id_inds.cpu()) # (n_query, n_traj)
traj_score_df = pd.DataFrame(
traj_score.clone().numpy(), columns=unique_ids.cpu().numpy()
)
traj_score_df.index.name = "Current Frame Instances"
traj_score_df.columns.name = "Unique IDs"
query_frame.add_traj_score("traj_score", traj_score_df)
################################################################################
# with iou -> combining with location in tracker, they set to True
# todo -> should also work without pos_embed
if id_inds.numel() > 0:
# this throws error, think we need to slice?
# last_inds = (id_inds * torch.arange(
# n_nonquery, device=id_inds.device)[:, None]).max(dim=0)[1] # n_traj
last_inds = (
id_inds * torch.arange(n_nonquery, device=id_inds.device)[:, None]
).max(dim=0)[
1
] # M
last_boxes = nonquery_boxes[last_inds] # n_traj x 4
last_ious = post_processing._pairwise_iou(
Boxes(query_boxes), Boxes(last_boxes)
) # n_k x M
else:
last_ious = traj_score.new_zeros(traj_score.shape)
traj_score = post_processing.weight_iou(traj_score, self.iou, last_ious.cpu())
if self.iou is not None and self.iou != "":
iou_traj_score = pd.DataFrame(
traj_score.clone().numpy(), columns=unique_ids.cpu().numpy()
)
iou_traj_score.index.name = "Current Frame Instances"
iou_traj_score.columns.name = "Unique IDs"
query_frame.add_traj_score("weight_iou", iou_traj_score)
################################################################################
# threshold for continuing a tracking or starting a new track -> they use 1.0
# todo -> should also work without pos_embed
traj_score = post_processing.filter_max_center_dist(
traj_score, self.max_center_dist, query_boxes, nonquery_boxes, id_inds
)
if self.max_center_dist is not None and self.max_center_dist > 0:
max_center_dist_traj_score = pd.DataFrame(
traj_score.clone().numpy(), columns=unique_ids.cpu().numpy()
)
max_center_dist_traj_score.index.name = "Current Frame Instances"
max_center_dist_traj_score.columns.name = "Unique IDs"
query_frame.add_traj_score("max_center_dist", max_center_dist_traj_score)
################################################################################
scaled_traj_score = torch.softmax(traj_score, dim=1)
scaled_traj_score_df = pd.DataFrame(
scaled_traj_score.numpy(), columns=unique_ids.cpu().numpy()
)
scaled_traj_score_df.index.name = "Current Frame Instances"
scaled_traj_score_df.columns.name = "Unique IDs"
query_frame.add_traj_score("scaled", scaled_traj_score_df)
################################################################################
try:
match_i, match_j = linear_sum_assignment(-traj_score)
except ValueError as e:
print(reid_features.isnan().any())
print(asso_output)
print(traj_score)
raise (e)
track_ids = instance_ids.new_full((n_query,), -1)
for i, j in zip(match_i, match_j):
# The overlap threshold is multiplied by the number of times the unique track j is matched to an
# instance out of all instances in the window excluding the current frame.
#
# So if this is correct, the threshold is higher for matching an instance from the current frame
# to an existing track if that track has already been matched several times.
# So if an existing track in the window has been matched a lot, it gets harder to match to that track.
thresh = (
overlap_thresh * id_inds[:, j].sum() if mult_thresh else overlap_thresh
)
if n_traj >= self.max_tracks or traj_score[i, j] > thresh:
if self.verbose:
print(
f"Assigning instance {i} to track {j} with id {unique_ids[j]}"
)
track_ids[i] = unique_ids[j]
query_frame.instances[i].track_score = scaled_traj_score[i, j].item()
if self.verbose:
print(f"track_ids: {track_ids}")
for i in range(n_query):
if track_ids[i] < 0:
curr_track += 1
if self.verbose:
print(f"Creating new track {curr_track}")
track_ids[i] = curr_track
query_frame.matches = (match_i, match_j)
for instance, track_id in zip(query_frame.instances, track_ids):
instance.pred_track_id = track_id
final_traj_score = pd.DataFrame(
traj_score.clone().numpy(), columns=unique_ids.cpu().numpy()
)
final_traj_score.index.name = "Current Frame Instances"
final_traj_score.columns.name = "Unique IDs"
query_frame.add_traj_score("final", final_traj_score)
return query_frame
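The core of `_run_global_tracker` is a Hungarian assignment over the trajectory score matrix, gated by the (optionally multiplicative) overlap threshold. Below is a minimal, self-contained sketch of just that assignment step, with hypothetical scores and per-track match counts standing in for the tracker's real state:

```python
import numpy as np
from scipy.optimize import linear_sum_assignment

# Hypothetical trajectory score matrix: 3 query instances x 2 existing tracks.
traj_score = np.array(
    [
        [0.90, 0.05],
        [0.10, 0.80],
        [0.02, 0.03],  # weak scores -> should spawn a new track
    ]
)
# Number of window instances already assigned to each track
# (the column sums of `id_inds` in the tracker above).
id_counts = np.array([4.0, 4.0])

overlap_thresh = 0.01
mult_thresh = True

# Hungarian assignment maximizes total score (scipy minimizes, hence the negation).
match_i, match_j = linear_sum_assignment(-traj_score)

unique_ids = np.array([0, 1])
track_ids = np.full(traj_score.shape[0], -1)
curr_track_id = unique_ids.max()

for i, j in zip(match_i, match_j):
    # With mult_thresh, a track matched many times in the window needs a
    # proportionally higher score before it accepts another match.
    thresh = overlap_thresh * id_counts[j] if mult_thresh else overlap_thresh
    if traj_score[i, j] > thresh:
        track_ids[i] = unique_ids[j]

# Unmatched instances start new tracks.
for i in range(len(track_ids)):
    if track_ids[i] < 0:
        curr_track_id += 1
        track_ids[i] = curr_track_id

print(track_ids)  # [0 1 2]: the third instance starts a new track
```

With `mult_thresh=True`, a track that has already been matched to many window instances requires a proportionally higher score to accept another match, which is the behavior described in the inline comments of the source above.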
__call__(model, frames)
¶
Wrap around `track` to enable `tracker()` instead of `tracker.track()`.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `model` | `GlobalTrackingTransformer` | the pretrained GlobalTrackingTransformer to be used for inference | *required* |
| `frames` | `list[Frame]` | list of Frames to run inference on | *required* |

Returns:

| Type | Description |
|---|---|
| `list[Frame]` | List of frames containing association matrix scores and instances populated with pred track ids. |
Source code in dreem/inference/tracker.py
def __call__(
self, model: GlobalTrackingTransformer, frames: list[Frame]
) -> list[Frame]:
"""Wrap around `track` to enable `tracker()` instead of `tracker.track()`.
Args:
model: the pretrained GlobalTrackingTransformer to be used for inference
frames: list of Frames to run inference on
Returns:
List of frames containing association matrix scores and instances populated with pred track ids.
"""
return self.track(model, frames)
__init__(window_size=8, use_vis_feats=True, overlap_thresh=0.01, mult_thresh=True, decay_time=None, iou=None, max_center_dist=None, persistent_tracking=False, max_gap=inf, max_tracks=inf, verbose=False)
¶
Initialize a tracker to run inference.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `window_size` | `int` | the size of the window used during sliding inference. | `8` |
| `use_vis_feats` | `bool` | whether or not to use the visual feature extractor. | `True` |
| `overlap_thresh` | `float` | the trajectory overlap threshold to be used for assignment. | `0.01` |
| `mult_thresh` | `bool` | whether or not to use the weight threshold. | `True` |
| `decay_time` | `float` | weight for `decay_time` postprocessing. | `None` |
| `iou` | `str` | one of `None`, `''`, `"mult"`, or `"max"`; whether to use multiplicative or max IOU reweighting. | `None` |
| `max_center_dist` | `float` | distance threshold for filtering the trajectory score matrix. | `None` |
| `persistent_tracking` | `bool` | whether to keep a buffer across chunks or not. | `False` |
| `max_gap` | `int` | the max number of frames a trajectory can be missing before termination. | `inf` |
| `max_tracks` | `int` | the maximum number of tracks that can be created while tracking; once reached, the tracker is forced to assign instances to an existing track instead of creating a new one. | `inf` |
| `verbose` | `bool` | whether or not to turn on debug printing after each operation. | `False` |
Source code in dreem/inference/tracker.py
def __init__(
self,
window_size: int = 8,
use_vis_feats: bool = True,
overlap_thresh: float = 0.01,
mult_thresh: bool = True,
decay_time: float = None,
iou: str = None,
max_center_dist: float = None,
persistent_tracking: bool = False,
max_gap: int = inf,
max_tracks: int = inf,
verbose: bool = False,
):
"""Initialize a tracker to run inference.
Args:
window_size: the size of the window used during sliding inference.
use_vis_feats: Whether or not to use visual feature extractor.
overlap_thresh: the trajectory overlap threshold to be used for assignment.
mult_thresh: Whether or not to use weight threshold.
decay_time: weight for `decay_time` postprocessing.
iou: Either [None, '', "mult" or "max"]
Whether to use multiplicative or max iou reweighting.
max_center_dist: distance threshold for filtering trajectory score matrix.
persistent_tracking: whether to keep a buffer across chunks or not.
max_gap: the max number of frames a trajectory can be missing before termination.
max_tracks: the maximum number of tracks that can be created while tracking.
We force the tracker to assign instances to a track instead of creating a new track if max_tracks has been reached.
verbose: Whether or not to turn on debug printing after each operation.
"""
self.track_queue = TrackQueue(
window_size=window_size, max_gap=max_gap, verbose=verbose
)
self.use_vis_feats = use_vis_feats
self.overlap_thresh = overlap_thresh
self.mult_thresh = mult_thresh
self.decay_time = decay_time
self.iou = iou
self.max_center_dist = max_center_dist
self.persistent_tracking = persistent_tracking
self.verbose = verbose
self.max_tracks = max_tracks
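As a construction sketch, a tracker with the common postprocessing options enabled might be configured as follows. The values are illustrative, not tuned recommendations (the source comments note 0.9 for the decay weight and IOU reweighting enabled in the original GTR setup):

```python
from dreem.inference.tracker import Tracker

# Illustrative configuration; values are examples, not tuned recommendations.
tracker = Tracker(
    window_size=8,         # frames of context per sliding-inference window
    use_vis_feats=True,    # run the visual encoder on instance crops
    overlap_thresh=0.01,   # minimum trajectory score to accept a match
    mult_thresh=True,      # scale the threshold by per-track match counts
    decay_time=0.9,        # downweight older frames (source comments cite 0.9)
    iou="mult",            # multiplicative IOU reweighting
    max_center_dist=None,  # disable center-distance filtering
    persistent_tracking=False,
)
```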
sliding_inference(model, frames)
¶
Perform sliding inference on the input video (instances) with a given window size.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `model` | `GlobalTrackingTransformer` | the pretrained GlobalTrackingTransformer to be used for inference | *required* |
| `frames` | `list[Frame]` | a list of Frames (see `dreem.io.Frame` for more info). | *required* |

Returns:

| Name | Type | Description |
|---|---|---|
| `frames` | `list[Frame]` | a list of Frames populated with pred_track_ids and asso_matrices |
Source code in dreem/inference/tracker.py
def sliding_inference(
self, model: GlobalTrackingTransformer, frames: list[Frame]
) -> list[Frame]:
"""Perform sliding inference on the input video (instances) with a given window size.
Args:
model: the pretrained GlobalTrackingTransformer to be used for inference
frames: A list of Frames (See `dreem.io.Frame` for more info).
Returns:
frames: A list of Frames populated with pred_track_ids and asso_matrices
"""
# B: batch size.
# D: embedding dimension.
# nc: number of channels.
# H: height.
# W: width.
for batch_idx, frame_to_track in enumerate(frames):
tracked_frames = self.track_queue.collate_tracks(
device=frame_to_track.frame_id.device
)
if self.verbose:
warnings.warn(
f"Current number of tracks is {self.track_queue.n_tracks}"
)
if (
self.persistent_tracking and frame_to_track.frame_id == 0
): # check for new video and clear queue
if self.verbose:
warnings.warn("New Video! Resetting Track Queue.")
self.track_queue.end_tracks()
"""
Initialize tracks on first frame where detections appear.
"""
if len(self.track_queue) == 0:
if frame_to_track.has_instances():
if self.verbose:
warnings.warn(
f"Initializing track on clip ind {batch_idx} frame {frame_to_track.frame_id.item()}"
)
curr_track_id = 0
for i, instance in enumerate(frames[batch_idx].instances):
instance.pred_track_id = instance.gt_track_id
curr_track_id = max(curr_track_id, instance.pred_track_id)
for i, instance in enumerate(frames[batch_idx].instances):
if instance.pred_track_id == -1:
curr_track_id += 1
instance.pred_track_id = curr_track_id
else:
if (
frame_to_track.has_instances()
): # Check if there are detections. If there are none, skip and increment the gap count
frames_to_track = tracked_frames + [
frame_to_track
] # better var name?
query_ind = len(frames_to_track) - 1
frame_to_track = self._run_global_tracker(
model,
frames_to_track,
query_ind=query_ind,
)
if frame_to_track.has_instances():
self.track_queue.add_frame(frame_to_track)
else:
self.track_queue.increment_gaps([])
frames[batch_idx] = frame_to_track
return frames
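Abstracting away the queue bookkeeping, the control flow above follows a standard sliding-window pattern: keep a rolling buffer of already-tracked frames, append the incoming frame as the rightmost query, and assign its instances against that window. A minimal sketch, where `assign` is a hypothetical stand-in for `_run_global_tracker`:

```python
from collections import deque


def sliding_track(frames: list, window_size: int, assign) -> list:
    """Sketch only: track each frame against a rolling window of tracked frames."""
    queue = deque(maxlen=window_size - 1)  # previously tracked frames
    for i, frame in enumerate(frames):
        window = list(queue) + [frame]  # the query frame is rightmost
        frames[i] = assign(window, query_ind=len(window) - 1)
        queue.append(frames[i])
    return frames
```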
track(model, frames)
¶
Run tracker and get predicted trajectories.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `model` | `GlobalTrackingTransformer` | the pretrained GlobalTrackingTransformer to be used for inference | *required* |
| `frames` | `list[Frame]` | list of Frames to run inference on | *required* |

Returns:

| Type | Description |
|---|---|
| `list[Frame]` | List of Frames populated with pred track ids and association matrix scores |
Source code in dreem/inference/tracker.py
def track(
self, model: GlobalTrackingTransformer, frames: list[Frame]
) -> list[Frame]:
"""Run tracker and get predicted trajectories.
Args:
model: the pretrained GlobalTrackingTransformer to be used for inference
frames: list of Frames to run inference on
Returns:
List of Frames populated with pred track ids and association matrix scores
"""
# Extract feature representations with pre-trained encoder.
_ = model.eval()
for frame in frames:
if frame.has_instances():
if not self.use_vis_feats:
for instance in frame.instances:
instance.features = torch.zeros(1, model.d_model)
# frame["features"] = torch.randn(
# num_frame_instances, self.model.d_model
# )
# comment out to turn encoder off
# Assuming the encoder is already trained or train encoder jointly.
elif not frame.has_features():
with torch.no_grad():
crops = frame.get_crops()
z = model.visual_encoder(crops)
for i, z_i in enumerate(z):
frame.instances[i].features = z_i
# I feel like this chunk is unnecessary:
# reid_features = torch.cat(
# [frame["features"] for frame in instances], dim=0
# ).unsqueeze(0)
# asso_preds, pred_boxes, pred_time, embeddings = self.model(
# instances, reid_features
# )
instances_pred = self.sliding_inference(model, frames)
if not self.persistent_tracking:
if self.verbose:
warnings.warn(f"Clearing Queue after tracking")
self.track_queue.end_tracks()
return instances_pred
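Putting the pieces together, a typical end-to-end call looks like the sketch below. The model import path and no-argument construction are assumptions for illustration; in practice the model is loaded from a trained checkpoint and `frames` comes from the dreem dataloaders (see `dreem.io.Frame`).

```python
import torch

from dreem.inference.tracker import Tracker
# Import path below is an assumption; adjust to where the model lives in dreem.
from dreem.models import GlobalTrackingTransformer

model = GlobalTrackingTransformer()  # in practice, load trained weights
tracker = Tracker(window_size=8, persistent_tracking=False)

frames = []  # populate from a dreem dataloader; see dreem.io.Frame

with torch.no_grad():
    tracked = tracker(model, frames)  # equivalent to tracker.track(model, frames)

for frame in tracked:
    for instance in frame.instances:
        print(frame.frame_id.item(), int(instance.pred_track_id))
```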