track
dreem.inference.track
¶
Script to run inference and get out tracks.
export_trajectories(frames_pred, save_path=None)
¶
Convert trajectories to data frame and save as .csv.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
frames_pred |
list[Frame]
|
A list of Frames with predicted track ids. |
required |
save_path |
str
|
The path to save the predicted trajectories to. |
None
|
Returns:
Type | Description |
---|---|
DataFrame
|
A DataFrame containing the predicted track id and centroid coordinates for each instance in the video. |
Source code in dreem/inference/track.py
def export_trajectories(
    frames_pred: list["dreem.io.Frame"], save_path: str = None
) -> pd.DataFrame:
    """Convert predicted trajectories to a DataFrame and optionally save as .csv.

    Args:
        frames_pred: A list of Frames with predicted track ids.
        save_path: The path to save the predicted trajectories to.
            If None, nothing is written to disk.

    Returns:
        A DataFrame with one row per instance, containing the frame id,
        centroid coordinates, predicted track id and track score.
    """
    frame_ids = []
    X, Y = [], []
    pred_track_ids = []
    track_scores = []
    for frame in frames_pred:
        for instance in frame.instances:
            frame_ids.append(frame.frame_id.item())
            bbox = instance.bbox.squeeze()
            # bbox is assumed to be (y1, x1, y2, x2) — TODO confirm against
            # dreem.io.Instance; the centroid is the box midpoint.
            y = (bbox[2] + bbox[0]) / 2
            x = (bbox[3] + bbox[1]) / 2
            X.append(x.item())
            Y.append(y.item())
            track_scores.append(instance.track_score)
            pred_track_ids.append(instance.pred_track_id.item())
    save_df = pd.DataFrame(
        {
            "Frame": frame_ids,
            "X": X,
            "Y": Y,
            "Pred_track_id": pred_track_ids,
            "Track_score": track_scores,
        }
    )
    if save_path:
        save_df.to_csv(save_path, index=False)
    return save_df
run(cfg)
¶
Run inference based on config file.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
cfg |
DictConfig
|
A dictconfig loaded from hydra containing checkpoint path and data |
required |
Source code in dreem/inference/track.py
@hydra.main(config_path="configs", config_name=None, version_base=None)
def run(cfg: DictConfig) -> dict[int, sio.Labels]:
    """Run inference based on config file.

    Args:
        cfg: A dictconfig loaded from hydra containing checkpoint path and data.

    Returns:
        A dict mapping video id to the predicted labels for that video.
    """
    pred_cfg = Config(cfg)
    if "checkpoints" in cfg.keys():
        try:
            index = int(os.environ["POD_INDEX"])
        # For testing without deploying a job on runai
        except KeyError:
            # Bug fix: input() returns a str; iloc needs an integer position.
            index = int(input("Pod Index Not found! Please choose a pod index: "))
        print(f"Pod Index: {index}")
        checkpoints = pd.read_csv(cfg.checkpoints)
        checkpoint = checkpoints.iloc[index]
    else:
        checkpoint = pred_cfg.cfg.ckpt_path
    model = GTRRunner.load_from_checkpoint(checkpoint)
    tracker_cfg = pred_cfg.get_tracker_cfg()
    print("Updating tracker hparams")
    model.tracker_cfg = tracker_cfg
    print("Using the following params for tracker:")
    pprint(model.tracker_cfg)
    dataset = pred_cfg.get_dataset(mode="test")
    dataloader = pred_cfg.get_dataloader(dataset, mode="test")
    trainer = pred_cfg.get_trainer()
    preds = track(model, trainer, dataloader)
    outdir = pred_cfg.cfg.outdir if "outdir" in pred_cfg.cfg else "./results"
    os.makedirs(outdir, exist_ok=True)
    for i, pred in preds.items():
        run_num = 0
        outpath = os.path.join(
            outdir,
            f"{Path(dataloader.dataset.label_files[i]).stem}.dreem_inference.v{run_num}.slp",
        )
        # Bump the version suffix until we find an unused path, so earlier
        # runs are never overwritten (the old code only checked once).
        while os.path.exists(outpath):
            run_num += 1
            outpath = outpath.replace(f".v{run_num - 1}.slp", f".v{run_num}.slp")
        print(f"Saving predictions for video {i} to {outpath}")
        pred.save(outpath)
    return preds
track(model, trainer, dataloader)
¶
Run Inference.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model |
GTRRunner
|
GTRRunner model loaded from checkpoint used for inference |
required |
trainer |
Trainer
|
lightning Trainer object used for handling inference log. |
required |
dataloader |
DataLoader
|
dataloader containing inference data |
required |
Returns:
Dictionary mapping video id to the prediction results for that video
Source code in dreem/inference/track.py
def track(
    model: GTRRunner, trainer: pl.Trainer, dataloader: torch.utils.data.DataLoader
) -> dict[int, sio.Labels]:
    """Run inference.

    Args:
        model: GTRRunner model loaded from checkpoint used for inference.
        trainer: lightning Trainer object used for handling inference log.
        dataloader: dataloader containing inference data.

    Returns:
        Dict mapping video id to the prediction results (sio.Labels) for
        that video; videos with no predicted frames keep an empty list.
    """
    num_videos = len(dataloader.dataset.vid_files)
    preds = trainer.predict(model, dataloader)

    vid_trajectories = {i: [] for i in range(num_videos)}
    tracks = {}
    for batch in preds:
        for frame in batch:
            lf, tracks = frame.to_slp(tracks)
            if frame.frame_id.item() == 0:
                print(f"Video: {lf.video}")
            vid_trajectories[frame.video_id.item()].append(lf)

    for vid_id, video in vid_trajectories.items():
        if len(video) > 0:
            try:
                vid_trajectories[vid_id] = sio.Labels(video)
            except AttributeError:
                # Surface which video failed, then re-raise with the
                # original traceback intact.
                print(video[0].video)
                raise
    return vid_trajectories