eval_dataset
dreem.datasets.eval_dataset
¶
Module containing a wrapper that merges ground-truth (gt) and predicted (pred) datasets for evaluation.
EvalDataset
¶
Bases: Dataset
Wrapper around a ground-truth (gt) dataset and a predicted dataset.
Source code in dreem/datasets/eval_dataset.py
class EvalDataset(Dataset):
    """Pair a ground-truth dataset with a predicted dataset for evaluation.

    Indexing returns frames whose instances hold the ground-truth track id
    together with the track id and bounding box read from the predicted
    dataset.
    """

    def __init__(self, gt_dataset: Dataset, pred_dataset: Dataset) -> None:
        """Store the two datasets that will be merged on access.

        Args:
            gt_dataset: A Dataset object containing ground truth track ids.
            pred_dataset: A Dataset object containing predicted track ids.
        """
        self.gt_dataset, self.pred_dataset = gt_dataset, pred_dataset

    def __len__(self) -> int:
        """Return the number of batches (chunks) in the dataset.

        Returns:
            The length of the ground-truth dataset; the predicted dataset is
            not consulted.
        """
        return len(self.gt_dataset)

    def __getitem__(self, idx: int) -> List[Frame]:
        """Build the merged evaluation frames for one batch.

        Args:
            idx: The index of the batch. Note this is not the index of the
                video or the frame.

        Returns:
            A list of Frames whose instances carry the ground-truth track id,
            the predicted track id and the predicted bounding box.
        """
        merged_frames = []

        # Frames are paired positionally; zip stops at the shorter batch.
        frame_pairs = zip(self.gt_dataset[idx], self.pred_dataset[idx])
        for gt_frame, pred_frame in frame_pairs:
            merged_instances = []

            for position, gt_instance in enumerate(gt_frame.instances):
                true_id = gt_instance.gt_track_id

                # Instances are matched by position. The predicted id is read
                # from the predicted instance's `gt_track_id` attribute.
                try:
                    predicted_id = pred_frame.instances[position].gt_track_id
                    predicted_box = pred_frame.instances[position].bbox
                except IndexError:
                    # No predicted instance at this position: use placeholders.
                    predicted_id, predicted_box = -1, [-1, -1, -1, -1]

                merged_instance = Instance(
                    gt_track_id=true_id,
                    pred_track_id=predicted_id,
                    bbox=predicted_box,
                )
                merged_instances.append(merged_instance)

            # Frame metadata is always taken from the ground-truth frame.
            merged_frame = Frame(
                video_id=gt_frame.video_id,
                frame_id=gt_frame.frame_id,
                vid_file=gt_frame.video.filename,
                img_shape=gt_frame.img_shape,
                instances=merged_instances,
            )
            merged_frames.append(merged_frame)

        return merged_frames
__getitem__(idx)
¶
Get an element of the dataset.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
idx |
int
|
the index of the batch. Note this is not the index of the video or the frame. |
required |
Returns:
Type | Description |
---|---|
List[Frame]
|
A list of Frames whose instances carry both the ground-truth and predicted track ids, plus the predicted bounding boxes. |
Source code in dreem/datasets/eval_dataset.py
def __getitem__(self, idx: int) -> List[Frame]:
    """Build the merged evaluation frames for one batch.

    Args:
        idx: The index of the batch. Note this is not the index of the
            video or the frame.

    Returns:
        A list of Frames whose instances carry the ground-truth track id,
        the predicted track id and the predicted bounding box.
    """
    merged_frames = []

    # Frames are paired positionally; zip stops at the shorter batch.
    frame_pairs = zip(self.gt_dataset[idx], self.pred_dataset[idx])
    for gt_frame, pred_frame in frame_pairs:
        merged_instances = []

        for position, gt_instance in enumerate(gt_frame.instances):
            true_id = gt_instance.gt_track_id

            # Instances are matched by position. The predicted id is read
            # from the predicted instance's `gt_track_id` attribute.
            try:
                predicted_id = pred_frame.instances[position].gt_track_id
                predicted_box = pred_frame.instances[position].bbox
            except IndexError:
                # No predicted instance at this position: use placeholders.
                predicted_id, predicted_box = -1, [-1, -1, -1, -1]

            merged_instance = Instance(
                gt_track_id=true_id,
                pred_track_id=predicted_id,
                bbox=predicted_box,
            )
            merged_instances.append(merged_instance)

        # Frame metadata is always taken from the ground-truth frame.
        merged_frame = Frame(
            video_id=gt_frame.video_id,
            frame_id=gt_frame.frame_id,
            vid_file=gt_frame.video.filename,
            img_shape=gt_frame.img_shape,
            instances=merged_instances,
        )
        merged_frames.append(merged_frame)

    return merged_frames
__init__(gt_dataset, pred_dataset)
¶
Initialize EvalDataset.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
gt_dataset |
Dataset
|
A Dataset object containing ground truth track ids |
required |
pred_dataset |
Dataset
|
A Dataset object containing predicted track ids |
required |