tracking_dataset
dreem.datasets.tracking_dataset
¶
Module containing Lightning module wrapper around all other datasets.
TrackingDataset
¶
Bases: LightningDataModule
Lightning dataset used to load dataloaders for train, test and validation.
Nice for wrapping around other data formats.
Source code in dreem/datasets/tracking_dataset.py
class TrackingDataset(LightningDataModule):
"""Lightning dataset used to load dataloaders for train, test and validation.
Nice for wrapping around other data formats.
"""
def __init__(
self,
train_ds: Union[
SleapDataset, MicroscopyDataset, CellTrackingDataset, None
] = None,
train_dl: DataLoader = None,
val_ds: Union[
SleapDataset, MicroscopyDataset, CellTrackingDataset, None
] = None,
val_dl: DataLoader = None,
test_ds: Union[
SleapDataset, MicroscopyDataset, CellTrackingDataset, None
] = None,
test_dl: DataLoader = None,
):
"""Initialize tracking dataset.
Args:
train_ds: Sleap or Microscopy training Dataset
train_dl: Training dataloader. Only used for overriding `train_dataloader`.
val_ds: Sleap or Microscopy Validation set
val_dl : Validation dataloader. Only used for overriding `val_dataloader`.
test_ds: Sleap or Microscopy test set
test_dl : Test dataloader. Only used for overriding `test_dataloader`.
"""
super().__init__()
self.train_ds = train_ds
self.train_dl = train_dl
self.val_ds = val_ds
self.val_dl = val_dl
self.test_ds = test_ds
self.test_dl = test_dl
def setup(self, stage=None):
"""Set up lightning dataset.
UNUSED.
"""
pass
def train_dataloader(self) -> DataLoader:
"""Get train_dataloader.
Returns: The Training Dataloader.
"""
if self.train_dl is None and self.train_ds is None:
return None
elif self.train_dl is None:
return DataLoader(
self.train_ds,
batch_size=1,
shuffle=True,
pin_memory=False,
collate_fn=self.train_ds.no_batching_fn,
num_workers=0,
generator=(
torch.Generator(device="cuda")
if torch.cuda.is_available()
else torch.Generator()
),
)
else:
return self.train_dl
def val_dataloader(self) -> DataLoader:
"""Get val dataloader.
Returns: The validation dataloader.
"""
if self.val_dl is None and self.val_ds is None:
return None
elif self.val_dl is None:
return DataLoader(
self.val_ds,
batch_size=1,
shuffle=False,
pin_memory=0,
collate_fn=self.train_ds.no_batching_fn,
num_workers=False,
generator=None,
)
else:
return self.val_dl
def test_dataloader(self) -> DataLoader:
"""Get.
Returns: The test dataloader
"""
if self.test_dl is None and self.test_ds is None:
return None
elif self.test_dl is None:
return DataLoader(
self.test_ds,
batch_size=1,
shuffle=False,
pin_memory=0,
collate_fn=self.train_ds.no_batching_fn,
num_workers=False,
generator=None,
)
else:
return self.test_dl
__init__(train_ds=None, train_dl=None, val_ds=None, val_dl=None, test_ds=None, test_dl=None)
¶
Initialize tracking dataset.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
train_ds |
Union[SleapDataset, MicroscopyDataset, CellTrackingDataset, None]
|
Sleap or Microscopy training Dataset |
None
|
train_dl |
DataLoader
|
Training dataloader. Only used for overriding |
None
|
val_ds |
Union[SleapDataset, MicroscopyDataset, CellTrackingDataset, None]
|
Sleap or Microscopy Validation set |
None
|
val_dl |
Validation dataloader. Only used for overriding |
None
|
|
test_ds |
Union[SleapDataset, MicroscopyDataset, CellTrackingDataset, None]
|
Sleap or Microscopy test set |
None
|
test_dl |
Test dataloader. Only used for overriding |
None
|
Source code in dreem/datasets/tracking_dataset.py
def __init__(
self,
train_ds: Union[
SleapDataset, MicroscopyDataset, CellTrackingDataset, None
] = None,
train_dl: DataLoader = None,
val_ds: Union[
SleapDataset, MicroscopyDataset, CellTrackingDataset, None
] = None,
val_dl: DataLoader = None,
test_ds: Union[
SleapDataset, MicroscopyDataset, CellTrackingDataset, None
] = None,
test_dl: DataLoader = None,
):
"""Initialize tracking dataset.
Args:
train_ds: Sleap or Microscopy training Dataset
train_dl: Training dataloader. Only used for overriding `train_dataloader`.
val_ds: Sleap or Microscopy Validation set
val_dl : Validation dataloader. Only used for overriding `val_dataloader`.
test_ds: Sleap or Microscopy test set
test_dl : Test dataloader. Only used for overriding `test_dataloader`.
"""
super().__init__()
self.train_ds = train_ds
self.train_dl = train_dl
self.val_ds = val_ds
self.val_dl = val_dl
self.test_ds = test_ds
self.test_dl = test_dl
setup(stage=None)
¶
test_dataloader()
¶
Get.
Returns: The test dataloader
Source code in dreem/datasets/tracking_dataset.py
def test_dataloader(self) -> DataLoader:
"""Get.
Returns: The test dataloader
"""
if self.test_dl is None and self.test_ds is None:
return None
elif self.test_dl is None:
return DataLoader(
self.test_ds,
batch_size=1,
shuffle=False,
pin_memory=0,
collate_fn=self.train_ds.no_batching_fn,
num_workers=False,
generator=None,
)
else:
return self.test_dl
train_dataloader()
¶
Get train_dataloader.
Returns: The Training Dataloader.
Source code in dreem/datasets/tracking_dataset.py
def train_dataloader(self) -> DataLoader:
"""Get train_dataloader.
Returns: The Training Dataloader.
"""
if self.train_dl is None and self.train_ds is None:
return None
elif self.train_dl is None:
return DataLoader(
self.train_ds,
batch_size=1,
shuffle=True,
pin_memory=False,
collate_fn=self.train_ds.no_batching_fn,
num_workers=0,
generator=(
torch.Generator(device="cuda")
if torch.cuda.is_available()
else torch.Generator()
),
)
else:
return self.train_dl
val_dataloader()
¶
Get val dataloader.
Returns: The validation dataloader.
Source code in dreem/datasets/tracking_dataset.py
def val_dataloader(self) -> DataLoader:
"""Get val dataloader.
Returns: The validation dataloader.
"""
if self.val_dl is None and self.val_ds is None:
return None
elif self.val_dl is None:
return DataLoader(
self.val_ds,
batch_size=1,
shuffle=False,
pin_memory=0,
collate_fn=self.train_ds.no_batching_fn,
num_workers=False,
generator=None,
)
else:
return self.val_dl