dreem.datasets.base_dataset
Module containing logic for loading datasets.
Classes:

Name | Description
---|---
`BaseDataset` | Base Dataset for microscopy and sleap datasets to override.
BaseDataset
Bases: `Dataset`
Base Dataset for microscopy and sleap datasets to override.
Methods:

Name | Description
---|---
`__getitem__` | Get an element of the dataset.
`__init__` | Initialize Dataset.
`__len__` | Get the size of the dataset.
`create_chunks_other` | Legacy chunking logic. Does not support unannotated segments.
`create_chunks_slp` | Get indexing for data.
`get_indices` | Retrieve label and frame indices given batch index.
`get_instances` | Build chunk of frames.
`no_batching_fn` | Collate function used to overwrite dataloader batching function.
`process_segments` | Process segments to stitch. Modifies state variables chunked_frame_idx and label_idx.
Source code in dreem/datasets/base_dataset.py

```python
class BaseDataset(Dataset):
"""Base Dataset for microscopy and sleap datasets to override."""
def __init__(
self,
label_files: list[str],
vid_files: list[str],
padding: int,
crop_size: Union[int, list[int]],
chunk: bool,
clip_length: int,
mode: str,
augmentations: dict | None = None,
n_chunks: int | float = 1.0,
seed: int | None = None,
gt_list: str | None = None,
):
"""Initialize Dataset.
Args:
label_files: a list of paths to label files.
                Should contain at least detections for inference, and detections + tracks for training.
vid_files: list of paths to video files.
padding: amount of padding around object crops
crop_size: the size of the object crops
chunk: whether or not to chunk the dataset into batches
clip_length: the number of frames in each chunk
mode: `train` or `val`. Determines whether this dataset is used for
training or validation. Currently doesn't affect dataset logic
augmentations: An optional dict mapping augmentations to parameters.
See subclasses for details.
n_chunks: Number of chunks to subsample from.
                Can either be a fraction of the dataset (i.e., in (0, 1.0]) or a number of chunks.
seed: set a seed for reproducibility
gt_list: An optional path to .txt file containing ground truth for
cell tracking challenge datasets.
"""
self.vid_files = vid_files
self.label_files = label_files
self.padding = padding
self.crop_size = crop_size
self.chunk = chunk
self.clip_length = clip_length
self.mode = mode
self.n_chunks = n_chunks
self.seed = seed
if self.seed is not None:
np.random.seed(self.seed)
if augmentations and self.mode == "train":
self.instance_dropout = augmentations.pop(
"InstanceDropout", {"p": 0.0, "n": 0}
)
self.node_dropout = data_utils.NodeDropout(
**augmentations.pop("NodeDropout", {"p": 0.0, "n": 0})
)
self.augmentations = data_utils.build_augmentations(augmentations)
else:
self.instance_dropout = {"p": 0.0, "n": 0}
self.node_dropout = data_utils.NodeDropout(p=0.0, n=0)
self.augmentations = None
# Initialize in subclasses
self.frame_idx = None
self.labels = None
self.gt_list = None
def process_segments(
self, i: int, segments_to_stitch: list[torch.Tensor], clip_length: int
) -> None:
"""Process segments to stitch. Modifies state variables chunked_frame_idx and label_idx.
Args:
segments_to_stitch: list of segments to stitch
i: index of the video
clip_length: the number of frames in each chunk
Returns: None
"""
stitched_segment = torch.cat(segments_to_stitch)
frame_idx_split = torch.split(stitched_segment, clip_length)
self.chunked_frame_idx.extend(frame_idx_split)
self.label_idx.extend(len(frame_idx_split) * [i])
def create_chunks_slp(self) -> None:
"""Get indexing for data.
Creates both indexes for selecting dataset (label_idx) and frame in
dataset (chunked_frame_idx). If chunking is false, we index directly
using the frame ids. Setting chunking to true creates a list of lists
containing chunk frames for indexing. This is useful for computational
efficiency and data shuffling. To be called by subclass __init__()
"""
self.chunked_frame_idx, self.label_idx = [], []
# go through each slp file and create chunks that respect max_batching_gap
for i, slp_file in enumerate(self.label_files):
annotated_segments = self.annotated_segments[slp_file]
segments_to_stitch = []
prev_end = annotated_segments[0][1] # end of first segment
for start, end in annotated_segments:
            # check if the start of the current segment is within max_batching_gap of the end of the previous one
if (
int(start) - int(prev_end) < self.max_batching_gap
) or not self.chunk: # also takes care of first segment as start < prev_end
segments_to_stitch.append(torch.arange(start, end + 1))
prev_end = end
else:
# stitch previous set of segments before creating a new chunk
self.process_segments(i, segments_to_stitch, self.clip_length)
# reset segments_to_stitch as we are starting a new chunk
segments_to_stitch = [torch.arange(start, end + 1)]
prev_end = end
if not self.chunk:
self.process_segments(
i, segments_to_stitch, self.labels[i].video.shape[0]
)
else:
# add last chunk after the loop
if segments_to_stitch:
self.process_segments(i, segments_to_stitch, self.clip_length)
if self.n_chunks > 0 and self.n_chunks <= 1.0:
n_chunks = int(self.n_chunks * len(self.chunked_frame_idx))
elif self.n_chunks <= len(self.chunked_frame_idx):
n_chunks = int(self.n_chunks)
else:
n_chunks = len(self.chunked_frame_idx)
if n_chunks > 0 and n_chunks < len(self.chunked_frame_idx):
sample_idx = np.random.choice(
np.arange(len(self.chunked_frame_idx)), n_chunks, replace=False
)
self.chunked_frame_idx = [self.chunked_frame_idx[i] for i in sample_idx]
self.label_idx = [self.label_idx[i] for i in sample_idx]
        # workaround for empty batch bug (needs to be changed). Drop batches with at most 1/10 of the clip length; arbitrary threshold.
remove_idx = []
for i, frame_chunk in enumerate(self.chunked_frame_idx):
if (
len(frame_chunk)
<= min(int(self.clip_length / 10), 5)
# and frame_chunk[-1] % self.clip_length == 0
):
logger.warning(
f"Warning: Batch containing frames {frame_chunk} from video {self.vid_files[self.label_idx[i]]} has {len(frame_chunk)} frames. Removing to avoid empty batch possibility with failed frame loading"
)
remove_idx.append(i)
if len(remove_idx) > 0:
for i in sorted(remove_idx, reverse=True):
self.chunked_frame_idx.pop(i)
self.label_idx.pop(i)
def create_chunks_other(self) -> None:
"""Legacy chunking logic. Does not support unannotated segments.
Creates both indexes for selecting dataset (label_idx) and frame in
dataset (chunked_frame_idx). If chunking is false, we index directly
using the frame ids. Setting chunking to true creates a list of lists
containing chunk frames for indexing. This is useful for computational
efficiency and data shuffling. To be called by subclass __init__()
"""
if self.chunk:
self.chunked_frame_idx, self.label_idx = [], []
for i, frame_idx in enumerate(self.frame_idx):
frame_idx_split = torch.split(frame_idx, self.clip_length)
self.chunked_frame_idx.extend(frame_idx_split)
self.label_idx.extend(len(frame_idx_split) * [i])
if self.n_chunks > 0 and self.n_chunks <= 1.0:
n_chunks = int(self.n_chunks * len(self.chunked_frame_idx))
elif self.n_chunks <= len(self.chunked_frame_idx):
n_chunks = int(self.n_chunks)
else:
n_chunks = len(self.chunked_frame_idx)
if n_chunks > 0 and n_chunks < len(self.chunked_frame_idx):
sample_idx = np.random.choice(
np.arange(len(self.chunked_frame_idx)), n_chunks, replace=False
)
self.chunked_frame_idx = [self.chunked_frame_idx[i] for i in sample_idx]
self.label_idx = [self.label_idx[i] for i in sample_idx]
        # workaround for empty batch bug (needs to be changed). Drop batches with at most 1/10 of the clip length; arbitrary threshold.
remove_idx = []
for i, frame_chunk in enumerate(self.chunked_frame_idx):
if (
len(frame_chunk)
<= min(int(self.clip_length / 10), 5)
# and frame_chunk[-1] % self.clip_length == 0
):
logger.warning(
f"Warning: Batch containing frames {frame_chunk} from video {self.vid_files[self.label_idx[i]]} has {len(frame_chunk)} frames. Removing to avoid empty batch possibility with failed frame loading"
)
remove_idx.append(i)
if len(remove_idx) > 0:
for i in sorted(remove_idx, reverse=True):
self.chunked_frame_idx.pop(i)
self.label_idx.pop(i)
else:
self.chunked_frame_idx = self.frame_idx
self.label_idx = [i for i in range(len(self.labels))]
def __len__(self) -> int:
"""Get the size of the dataset.
Returns:
the size or the number of chunks in the dataset
"""
return len(self.chunked_frame_idx)
def no_batching_fn(self, batch: list[Frame]) -> list[Frame]:
"""Collate function used to overwrite dataloader batching function.
Args:
batch: the chunk of frames to be returned
Returns:
The batch
"""
return batch
def __getitem__(self, idx: int) -> list[Frame]:
"""Get an element of the dataset.
Args:
idx: the index of the batch. Note this is not the index of the video
or the frame.
Returns:
A list of `Frame`s in the chunk containing the metadata + instance features.
"""
label_idx, frame_idx = self.get_indices(idx)
return self.get_instances(label_idx, frame_idx)
def get_indices(self, idx: int):
"""Retrieve label and frame indices given batch index.
This method should be implemented in any subclass of the BaseDataset.
Args:
idx: the index of the batch.
Raises:
NotImplementedError: If this method is not overridden in a subclass.
"""
raise NotImplementedError("Must be implemented in subclass")
def get_instances(self, label_idx: list[int], frame_idx: list[int]):
"""Build chunk of frames.
This method should be implemented in any subclass of the BaseDataset.
Args:
label_idx: The index of the labels.
frame_idx: The index of the frames.
Raises:
NotImplementedError: If this method is not overridden in a subclass.
"""
        raise NotImplementedError("Must be implemented in subclass")
```
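To make the override contract concrete, here is a minimal sketch of what a subclass might look like; `ToyDataset` and its toy indexing are illustrative assumptions, not part of dreem:

```python
import torch

from dreem.datasets.base_dataset import BaseDataset


class ToyDataset(BaseDataset):
    """Hypothetical subclass illustrating the methods to override."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Real subclasses parse the label/video files here, then build
        # chunks via self.create_chunks_other() or self.create_chunks_slp().
        self.frame_idx = [torch.arange(100)]  # toy: one 100-frame video
        self.labels = [None]                  # toy placeholder labels
        self.create_chunks_other()

    def get_indices(self, idx: int):
        # Map a batch index to (video index, frame indices of that chunk).
        return self.label_idx[idx], self.chunked_frame_idx[idx]

    def get_instances(self, label_idx, frame_idx):
        # Real subclasses load crops/poses and return a list[Frame].
        raise NotImplementedError("dataset-specific loading goes here")
```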
__getitem__(idx)
Get an element of the dataset.
Parameters:

Name | Type | Description | Default
---|---|---|---
`idx` | `int` | the index of the batch. Note this is not the index of the video or the frame. | required
Returns:

Type | Description
---|---
`list[Frame]` | A list of `Frame`s in the chunk containing the metadata + instance features.
Source code in dreem/datasets/base_dataset.py

```python
def __getitem__(self, idx: int) -> list[Frame]:
"""Get an element of the dataset.
Args:
idx: the index of the batch. Note this is not the index of the video
or the frame.
Returns:
A list of `Frame`s in the chunk containing the metadata + instance features.
"""
label_idx, frame_idx = self.get_indices(idx)
        return self.get_instances(label_idx, frame_idx)
```
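Note that an element is an entire clip, not a single frame. A brief hypothetical usage sketch, assuming `dataset` is a concrete subclass:

```python
# Hypothetical usage; `dataset` is assumed to be a concrete subclass.
print(len(dataset))  # number of chunks, not frames or videos
clip = dataset[0]    # list[Frame] for the first chunk
for frame in clip:
    ...              # per-frame metadata + instance features
```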
__init__(label_files, vid_files, padding, crop_size, chunk, clip_length, mode, augmentations=None, n_chunks=1.0, seed=None, gt_list=None)
Initialize Dataset.
Parameters:

Name | Type | Description | Default
---|---|---|---
`label_files` | `list[str]` | a list of paths to label files. Should contain at least detections for inference, and detections + tracks for training. | required
`vid_files` | `list[str]` | list of paths to video files. | required
`padding` | `int` | amount of padding around object crops | required
`crop_size` | `Union[int, list[int]]` | the size of the object crops | required
`chunk` | `bool` | whether or not to chunk the dataset into batches | required
`clip_length` | `int` | the number of frames in each chunk | required
`mode` | `str` | `train` or `val`. Determines whether this dataset is used for training or validation. Currently doesn't affect dataset logic. | required
`augmentations` | `dict \| None` | An optional dict mapping augmentations to parameters. See subclasses for details. | `None`
`n_chunks` | `int \| float` | Number of chunks to subsample from. Can either be a fraction of the dataset (i.e., in (0, 1.0]) or a number of chunks. | `1.0`
`seed` | `int \| None` | set a seed for reproducibility | `None`
`gt_list` | `str \| None` | An optional path to a .txt file containing ground truth for cell tracking challenge datasets. | `None`
Source code in dreem/datasets/base_dataset.py

```python
def __init__(
self,
label_files: list[str],
vid_files: list[str],
padding: int,
crop_size: Union[int, list[int]],
chunk: bool,
clip_length: int,
mode: str,
augmentations: dict | None = None,
n_chunks: int | float = 1.0,
seed: int | None = None,
gt_list: str | None = None,
):
"""Initialize Dataset.
Args:
label_files: a list of paths to label files.
                Should contain at least detections for inference, and detections + tracks for training.
vid_files: list of paths to video files.
padding: amount of padding around object crops
crop_size: the size of the object crops
chunk: whether or not to chunk the dataset into batches
clip_length: the number of frames in each chunk
mode: `train` or `val`. Determines whether this dataset is used for
training or validation. Currently doesn't affect dataset logic
augmentations: An optional dict mapping augmentations to parameters.
See subclasses for details.
n_chunks: Number of chunks to subsample from.
                Can either be a fraction of the dataset (i.e., in (0, 1.0]) or a number of chunks.
seed: set a seed for reproducibility
gt_list: An optional path to .txt file containing ground truth for
cell tracking challenge datasets.
"""
self.vid_files = vid_files
self.label_files = label_files
self.padding = padding
self.crop_size = crop_size
self.chunk = chunk
self.clip_length = clip_length
self.mode = mode
self.n_chunks = n_chunks
self.seed = seed
if self.seed is not None:
np.random.seed(self.seed)
if augmentations and self.mode == "train":
self.instance_dropout = augmentations.pop(
"InstanceDropout", {"p": 0.0, "n": 0}
)
self.node_dropout = data_utils.NodeDropout(
**augmentations.pop("NodeDropout", {"p": 0.0, "n": 0})
)
self.augmentations = data_utils.build_augmentations(augmentations)
else:
self.instance_dropout = {"p": 0.0, "n": 0}
self.node_dropout = data_utils.NodeDropout(p=0.0, n=0)
self.augmentations = None
# Initialize in subclasses
self.frame_idx = None
self.labels = None
        self.gt_list = None
```
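A hypothetical construction of a concrete subclass (the class name and any augmentation keys besides `InstanceDropout`/`NodeDropout` are placeholders; see `data_utils.build_augmentations` for what is actually supported):

```python
# Hypothetical constructor call for a concrete BaseDataset subclass.
dataset = SomeConcreteDataset(
    label_files=["labels.slp"],
    vid_files=["video.mp4"],
    padding=5,
    crop_size=128,
    chunk=True,
    clip_length=32,
    mode="train",  # augmentations are only applied when mode == "train"
    augmentations={
        "InstanceDropout": {"p": 0.1, "n": 1},  # popped in __init__
        "NodeDropout": {"p": 0.1, "n": 2},      # popped in __init__
    },
    n_chunks=0.5,  # (0, 1.0] -> fraction of chunks; >1 -> absolute count
    seed=42,       # seeds numpy so chunk subsampling is reproducible
)
```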
__len__()
Get the size of the dataset.
Returns:

Type | Description
---|---
`int` | the size or the number of chunks in the dataset
create_chunks_other()
Legacy chunking logic. Does not support unannotated segments.
Creates both indexes for selecting dataset (label_idx) and frame in dataset (chunked_frame_idx). If chunking is false, we index directly using the frame ids. Setting chunking to true creates a list of lists containing chunk frames for indexing. This is useful for computational efficiency and data shuffling. To be called by subclass `__init__()`.
Source code in dreem/datasets/base_dataset.py

```python
def create_chunks_other(self) -> None:
"""Legacy chunking logic. Does not support unannotated segments.
Creates both indexes for selecting dataset (label_idx) and frame in
dataset (chunked_frame_idx). If chunking is false, we index directly
using the frame ids. Setting chunking to true creates a list of lists
containing chunk frames for indexing. This is useful for computational
efficiency and data shuffling. To be called by subclass __init__()
"""
if self.chunk:
self.chunked_frame_idx, self.label_idx = [], []
for i, frame_idx in enumerate(self.frame_idx):
frame_idx_split = torch.split(frame_idx, self.clip_length)
self.chunked_frame_idx.extend(frame_idx_split)
self.label_idx.extend(len(frame_idx_split) * [i])
if self.n_chunks > 0 and self.n_chunks <= 1.0:
n_chunks = int(self.n_chunks * len(self.chunked_frame_idx))
elif self.n_chunks <= len(self.chunked_frame_idx):
n_chunks = int(self.n_chunks)
else:
n_chunks = len(self.chunked_frame_idx)
if n_chunks > 0 and n_chunks < len(self.chunked_frame_idx):
sample_idx = np.random.choice(
np.arange(len(self.chunked_frame_idx)), n_chunks, replace=False
)
self.chunked_frame_idx = [self.chunked_frame_idx[i] for i in sample_idx]
self.label_idx = [self.label_idx[i] for i in sample_idx]
        # workaround for empty batch bug (needs to be changed). Drop batches with at most 1/10 of the clip length; arbitrary threshold.
remove_idx = []
for i, frame_chunk in enumerate(self.chunked_frame_idx):
if (
len(frame_chunk)
<= min(int(self.clip_length / 10), 5)
# and frame_chunk[-1] % self.clip_length == 0
):
logger.warning(
f"Warning: Batch containing frames {frame_chunk} from video {self.vid_files[self.label_idx[i]]} has {len(frame_chunk)} frames. Removing to avoid empty batch possibility with failed frame loading"
)
remove_idx.append(i)
if len(remove_idx) > 0:
for i in sorted(remove_idx, reverse=True):
self.chunked_frame_idx.pop(i)
self.label_idx.pop(i)
else:
self.chunked_frame_idx = self.frame_idx
            self.label_idx = [i for i in range(len(self.labels))]
```
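The chunking step is plain `torch.split` over each video's frame indices; a self-contained illustration of the shapes involved:

```python
import torch

clip_length = 4
frame_idx = torch.arange(10)  # frames 0..9 of one video
chunks = torch.split(frame_idx, clip_length)
# chunks == (tensor([0, 1, 2, 3]), tensor([4, 5, 6, 7]), tensor([8, 9]))
# The final chunk may be shorter than clip_length; very short chunks
# (<= min(clip_length // 10, 5) frames) are dropped by the workaround above.
```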
create_chunks_slp()
Get indexing for data.
Creates both indexes for selecting dataset (label_idx) and frame in dataset (chunked_frame_idx). If chunking is false, we index directly using the frame ids. Setting chunking to true creates a list of lists containing chunk frames for indexing. This is useful for computational efficiency and data shuffling. To be called by subclass `__init__()`.
Source code in dreem/datasets/base_dataset.py

```python
def create_chunks_slp(self) -> None:
"""Get indexing for data.
Creates both indexes for selecting dataset (label_idx) and frame in
dataset (chunked_frame_idx). If chunking is false, we index directly
using the frame ids. Setting chunking to true creates a list of lists
containing chunk frames for indexing. This is useful for computational
efficiency and data shuffling. To be called by subclass __init__()
"""
self.chunked_frame_idx, self.label_idx = [], []
# go through each slp file and create chunks that respect max_batching_gap
for i, slp_file in enumerate(self.label_files):
annotated_segments = self.annotated_segments[slp_file]
segments_to_stitch = []
prev_end = annotated_segments[0][1] # end of first segment
for start, end in annotated_segments:
            # check if the start of the current segment is within max_batching_gap of the end of the previous one
if (
int(start) - int(prev_end) < self.max_batching_gap
) or not self.chunk: # also takes care of first segment as start < prev_end
segments_to_stitch.append(torch.arange(start, end + 1))
prev_end = end
else:
# stitch previous set of segments before creating a new chunk
self.process_segments(i, segments_to_stitch, self.clip_length)
# reset segments_to_stitch as we are starting a new chunk
segments_to_stitch = [torch.arange(start, end + 1)]
prev_end = end
if not self.chunk:
self.process_segments(
i, segments_to_stitch, self.labels[i].video.shape[0]
)
else:
# add last chunk after the loop
if segments_to_stitch:
self.process_segments(i, segments_to_stitch, self.clip_length)
if self.n_chunks > 0 and self.n_chunks <= 1.0:
n_chunks = int(self.n_chunks * len(self.chunked_frame_idx))
elif self.n_chunks <= len(self.chunked_frame_idx):
n_chunks = int(self.n_chunks)
else:
n_chunks = len(self.chunked_frame_idx)
if n_chunks > 0 and n_chunks < len(self.chunked_frame_idx):
sample_idx = np.random.choice(
np.arange(len(self.chunked_frame_idx)), n_chunks, replace=False
)
self.chunked_frame_idx = [self.chunked_frame_idx[i] for i in sample_idx]
self.label_idx = [self.label_idx[i] for i in sample_idx]
        # workaround for empty batch bug (needs to be changed). Drop batches with at most 1/10 of the clip length; arbitrary threshold.
remove_idx = []
for i, frame_chunk in enumerate(self.chunked_frame_idx):
if (
len(frame_chunk)
<= min(int(self.clip_length / 10), 5)
# and frame_chunk[-1] % self.clip_length == 0
):
logger.warning(
f"Warning: Batch containing frames {frame_chunk} from video {self.vid_files[self.label_idx[i]]} has {len(frame_chunk)} frames. Removing to avoid empty batch possibility with failed frame loading"
)
remove_idx.append(i)
if len(remove_idx) > 0:
for i in sorted(remove_idx, reverse=True):
self.chunked_frame_idx.pop(i)
                self.label_idx.pop(i)
```
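A standalone sketch of the gap rule (toy numbers, not dreem API): segments whose gap is under `max_batching_gap` are stitched together before splitting into clips, while larger gaps start a new chunk.

```python
import torch

max_batching_gap, clip_length = 5, 4
annotated_segments = [(0, 3), (6, 9), (20, 23)]  # (start, end) per segment

chunks, current = [], [torch.arange(0, 3 + 1)]   # seed with first segment
prev_end = 3
for start, end in annotated_segments[1:]:
    if start - prev_end < max_batching_gap:
        current.append(torch.arange(start, end + 1))  # small gap: stitch
    else:
        chunks.extend(torch.split(torch.cat(current), clip_length))
        current = [torch.arange(start, end + 1)]      # large gap: new chunk
    prev_end = end
chunks.extend(torch.split(torch.cat(current), clip_length))
# (0,3) and (6,9) stitch (gap 3 < 5) -> [0 1 2 3], [6 7 8 9];
# (20,23) starts fresh (gap 11)      -> [20 21 22 23]
```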
get_indices(idx)
Retrieve label and frame indices given batch index.
This method should be implemented in any subclass of the BaseDataset.
Parameters:

Name | Type | Description | Default
---|---|---|---
`idx` | `int` | the index of the batch. | required

Raises:

Type | Description
---|---
`NotImplementedError` | If this method is not overridden in a subclass.
Source code in dreem/datasets/base_dataset.py

```python
def get_indices(self, idx: int):
"""Retrieve label and frame indices given batch index.
This method should be implemented in any subclass of the BaseDataset.
Args:
idx: the index of the batch.
Raises:
NotImplementedError: If this method is not overridden in a subclass.
"""
        raise NotImplementedError("Must be implemented in subclass")
```
get_instances(label_idx, frame_idx)
Build chunk of frames.
This method should be implemented in any subclass of the BaseDataset.
Parameters:

Name | Type | Description | Default
---|---|---|---
`label_idx` | `list[int]` | The index of the labels. | required
`frame_idx` | `list[int]` | The index of the frames. | required

Raises:

Type | Description
---|---
`NotImplementedError` | If this method is not overridden in a subclass.
Source code in dreem/datasets/base_dataset.py

```python
def get_instances(self, label_idx: list[int], frame_idx: list[int]):
"""Build chunk of frames.
This method should be implemented in any subclass of the BaseDataset.
Args:
label_idx: The index of the labels.
frame_idx: The index of the frames.
Raises:
NotImplementedError: If this method is not overridden in a subclass.
"""
        raise NotImplementedError("Must be implemented in subclass")
```
no_batching_fn(batch)

Collate function used to overwrite dataloader batching function.

Parameters:

Name | Type | Description | Default
---|---|---|---
`batch` | `list[Frame]` | the chunk of frames to be returned | required

Returns:

Type | Description
---|---
`list[Frame]` | The batch

Source code in dreem/datasets/base_dataset.py

```python
    def no_batching_fn(self, batch: list[Frame]) -> list[Frame]:
        """Collate function used to overwrite dataloader batching function.

        Args:
            batch: the chunk of frames to be returned

        Returns:
            The batch
        """
        return batch
```
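A hypothetical wiring of this collate function into a `DataLoader` (with `batch_size=1`, since each dataset element is already a full clip):

```python
from torch.utils.data import DataLoader

# Hypothetical usage; `dataset` is assumed to be a concrete subclass.
loader = DataLoader(
    dataset,
    batch_size=1,                       # one chunk per batch
    collate_fn=dataset.no_batching_fn,  # pass the chunk through unstacked
)
for batch in loader:
    ...  # `batch` arrives exactly as the dataset produced it
```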
process_segments(i, segments_to_stitch, clip_length)
Process segments to stitch. Modifies state variables chunked_frame_idx and label_idx.
Parameters:

Name | Type | Description | Default
---|---|---|---
`segments_to_stitch` | `list[Tensor]` | list of segments to stitch | required
`i` | `int` | index of the video | required
`clip_length` | `int` | the number of frames in each chunk | required
Returns: None
Source code in dreem/datasets/base_dataset.py

```python
def process_segments(
self, i: int, segments_to_stitch: list[torch.Tensor], clip_length: int
) -> None:
"""Process segments to stitch. Modifies state variables chunked_frame_idx and label_idx.
Args:
segments_to_stitch: list of segments to stitch
i: index of the video
clip_length: the number of frames in each chunk
Returns: None
"""
stitched_segment = torch.cat(segments_to_stitch)
frame_idx_split = torch.split(stitched_segment, clip_length)
self.chunked_frame_idx.extend(frame_idx_split)
        self.label_idx.extend(len(frame_idx_split) * [i])
```
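In isolation, the effect on the two state lists looks like this (toy values, not dreem API):

```python
import torch

chunked_frame_idx, label_idx = [], []   # state normally held on the dataset
segments_to_stitch = [torch.arange(0, 4), torch.arange(7, 10)]
clip_length, i = 3, 0                   # i is the video index

stitched = torch.cat(segments_to_stitch)     # tensor([0, 1, 2, 3, 7, 8, 9])
splits = torch.split(stitched, clip_length)  # ([0,1,2], [3,7,8], [9])
chunked_frame_idx.extend(splits)
label_idx.extend(len(splits) * [i])          # [0, 0, 0]: three chunks of video 0
```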