Config Parser

dreem.io.Config
Class handling loading components based on config params.
Methods:

| Name | Description |
|---|---|
| `__init__` | Initialize the class with config from hydra/omega conf. |
| `__repr__` | Object representation of config class. |
| `__str__` | Return a string representation of config class. |
| `from_yaml` | Load config directly from yaml. |
| `get` | Get config item. |
| `get_checkpointing` | Getter for lightning checkpointing callback. |
| `get_ctc_paths` | Get file paths from directory. Only for CTC datasets. |
| `get_data_paths` | Get file paths from directory. Only for SLEAP datasets. |
| `get_dataloader` | Getter for dataloader. |
| `get_dataset` | Getter for datasets. |
| `get_early_stopping` | Getter for lightning early stopping callback. |
| `get_gtr_runner` | Get lightning module for training, validation, and inference. |
| `get_logger` | Getter for logging callback. |
| `get_loss` | Getter for loss functions. |
| `get_model` | Getter for gtr model. |
| `get_optimizer` | Getter for optimizer. |
| `get_scheduler` | Getter for lr scheduler. |
| `get_tracker_cfg` | Getter for tracker config params. |
| `get_trainer` | Getter for the lightning trainer. |
| `set_hparams` | Setter function for overwriting specific hparams. |
Attributes:

| Name | Type | Description |
|---|---|---|
| `data_paths` | | Get data paths. |
Source code in dreem/io/config.py
class Config:
"""Class handling loading components based on config params."""
def __init__(self, cfg: DictConfig, params_cfg: DictConfig | None = None):
"""Initialize the class with config from hydra/omega conf.
First uses `base_param` file then overwrites with specific `params_config`.
Args:
cfg: The `DictConfig` containing all the hyperparameters needed for
training/evaluation.
            params_cfg: The `DictConfig` containing a subset of hyperparameters to override.
"""
base_cfg = cfg
logger.info(f"Base Config: {cfg}")
if "params_config" in cfg:
params_cfg = OmegaConf.load(cfg.params_config)
if params_cfg:
logger.info(f"Overwriting base config with {params_cfg}")
with open_dict(base_cfg):
self.cfg = OmegaConf.merge(base_cfg, params_cfg) # merge configs
else:
self.cfg = cfg
OmegaConf.set_struct(self.cfg, False)
self._vid_files = {}
def __repr__(self):
"""Object representation of config class."""
return f"Config({self.cfg})"
def __str__(self):
"""Return a string representation of config class."""
return f"Config({self.cfg})"
@classmethod
    def from_yaml(cls, base_cfg_path: str, params_cfg_path: str | None = None) -> "Config":
"""Load config directly from yaml.
Args:
base_cfg_path: path to base config file.
params_cfg_path: path to override params.
"""
base_cfg = OmegaConf.load(base_cfg_path)
params_cfg = OmegaConf.load(params_cfg_path) if params_cfg_path else None
return cls(base_cfg, params_cfg)
def set_hparams(self, hparams: dict) -> bool:
"""Setter function for overwriting specific hparams.
Useful for changing 1 or 2 hyperparameters such as dataset.
Args:
hparams: A dict containing the hyperparameter to be overwritten and
the value to be changed
Returns:
`True` if config is successfully updated, `False` otherwise
"""
if hparams == {} or hparams is None:
logger.warning("Nothing to update!")
return False
for hparam, val in hparams.items():
try:
OmegaConf.update(self.cfg, hparam, val)
except Exception as e:
logger.exception(f"Failed to update {hparam} to {val} due to {e}")
return False
return True
def get(self, key: str, default=None, cfg: dict = None):
"""Get config item.
Args:
key: key of item to return
default: default value to return if key is missing.
cfg: the config dict from which to retrieve an item
"""
if cfg is None:
cfg = self.cfg
param = cfg.get(key, default)
if isinstance(param, DictConfig):
param = OmegaConf.to_container(param, resolve=True)
return param
def get_model(self) -> "GlobalTrackingTransformer":
"""Getter for gtr model.
Returns:
A global tracking transformer with parameters indicated by cfg
"""
from dreem.models import GlobalTrackingTransformer, GTRRunner
model_params = self.get("model", {})
ckpt_path = model_params.pop("ckpt_path", None)
if ckpt_path is not None and len(ckpt_path) > 0:
return GTRRunner.load_from_checkpoint(ckpt_path).model
return GlobalTrackingTransformer(**model_params)
def get_tracker_cfg(self) -> dict:
"""Getter for tracker config params.
Returns:
A dict containing the init params for `Tracker`.
"""
return self.get("tracker", {})
def get_gtr_runner(self, ckpt_path: str | None = None) -> "GTRRunner":
"""Get lightning module for training, validation, and inference.
Args:
ckpt_path: path to checkpoint for override
Returns:
a gtr runner model
"""
from dreem.models import GTRRunner
keys = ["tracker", "optimizer", "scheduler", "loss", "runner", "model"]
args = [key + "_cfg" if key != "runner" else key for key in keys]
params = {}
for key, arg in zip(keys, args):
sub_params = self.get(key, {})
if len(sub_params) == 0:
logger.warning(
f"`{key}` not found in config or is empty. Using defaults for {arg}!"
)
if key == "runner":
runner_params = sub_params
for k, v in runner_params.items():
params[k] = v
else:
params[arg] = sub_params
ckpt_path = params["model_cfg"].pop("ckpt_path", None)
if ckpt_path is not None and ckpt_path != "":
model = GTRRunner.load_from_checkpoint(
ckpt_path, tracker_cfg=params["tracker_cfg"], **runner_params
)
else:
model = GTRRunner(**params)
return model
def get_ctc_paths(
self, list_dir_path: list[str]
) -> tuple[list[str], list[str], list[str]]:
"""Get file paths from directory. Only for CTC datasets.
Args:
list_dir_path: list of directories to search for labels and videos
Returns:
lists of labels file paths and video file paths
"""
gt_list = []
raw_img_list = []
ctc_track_meta = []
# user can specify a list of directories, each of which can contain several subdirectories that come in pairs of (dset_name, dset_name_GT/TRA)
for dir_path in list_dir_path:
for subdir in os.listdir(dir_path):
if subdir.endswith("_GT"):
gt_path = os.path.join(dir_path, subdir, "TRA")
raw_img_path = os.path.join(dir_path, subdir.replace("_GT", ""))
# get filepaths for all tif files in gt_path
gt_list.append(glob.glob(os.path.join(gt_path, "*.tif")))
# get filepaths for all tif files in raw_img_path
raw_img_list.append(glob.glob(os.path.join(raw_img_path, "*.tif")))
man_track_file = glob.glob(os.path.join(gt_path, "man_track.txt"))
if len(man_track_file) > 0:
ctc_track_meta.append(man_track_file[0])
else:
logger.debug(
f"No man_track.txt file found in {gt_path}. Continuing..."
)
else:
continue
return gt_list, raw_img_list, ctc_track_meta
def get_data_paths(self, mode: str, data_cfg: dict) -> tuple[list[str], list[str]]:
"""Get file paths from directory. Only for SLEAP datasets.
Args:
mode: [None, "train", "test", "val"]. Indicates whether to use
train, val, or test params for dataset
data_cfg: Config for the dataset containing "dir" key.
Returns:
lists of labels file paths and video file paths respectively
"""
# hack to get around the fact that for test mode, get_data_paths is called before get_dataset.
# also, for train/val mode, data_cfg has had the dir key popped through self.get() called in get_dataset()
if mode == "test":
list_dir_path = data_cfg.get("dir", {}).get("path", None)
if list_dir_path is None:
raise ValueError(
"`dir` is missing from dataset config. Please provide a path to the directory containing the labels and videos."
)
self.labels_suffix = data_cfg.get("dir", {}).get("labels_suffix")
self.vid_suffix = data_cfg.get("dir", {}).get("vid_suffix")
else:
list_dir_path = self.data_dirs
if not isinstance(list_dir_path, list):
list_dir_path = [list_dir_path]
if self.labels_suffix == ".slp":
label_files = []
vid_files = []
for dir_path in list_dir_path:
logger.debug(f"Searching `{dir_path}` directory")
labels_path = f"{dir_path}/*{self.labels_suffix}"
vid_path = f"{dir_path}/*{self.vid_suffix}"
logger.debug(f"Searching for labels matching {labels_path}")
label_files.extend(glob.glob(labels_path))
logger.debug(f"Searching for videos matching {vid_path}")
vid_files.extend(glob.glob(vid_path))
elif self.labels_suffix == ".tif":
label_files, vid_files, ctc_track_meta = self.get_ctc_paths(list_dir_path)
logger.debug(f"Found {len(label_files)} labels and {len(vid_files)} videos")
# backdoor to set label files directly in the configs (i.e. bypass dir.path)
if data_cfg.get("slp_files", None):
logger.debug("Overriding label files with user provided list")
slp_files = data_cfg.get("slp_files")
if len(slp_files) > 0:
label_files = slp_files
if data_cfg.get("video_files", None):
individual_video_files = data_cfg.get("video_files")
if len(individual_video_files) > 0:
vid_files = individual_video_files
return label_files, vid_files
def get_dataset(
self,
mode: str,
label_files: list[str] | None = None,
vid_files: list[str | list[str]] = None,
) -> "SleapDataset" | "CellTrackingDataset":
"""Getter for datasets.
Args:
mode: [None, "train", "test", "val"]. Indicates whether to use
train, val, or test params for dataset
label_files: path to label_files for override
vid_files: path to vid_files for override
Returns:
Either a `SleapDataset` or `CellTrackingDataset` with params indicated by cfg
"""
from dreem.datasets import SleapDataset, CellTrackingDataset
dataset_params = self.get("dataset")
if dataset_params is None:
raise KeyError("`dataset` key is missing from cfg!")
if mode.lower() == "train":
dataset_params = self.get("train_dataset", {}, dataset_params)
elif mode.lower() == "val":
dataset_params = self.get("val_dataset", {}, dataset_params)
elif mode.lower() == "test":
dataset_params = self.get("test_dataset", {}, dataset_params)
else:
            raise ValueError(
                f"`mode` must be one of ['train', 'val', 'test'], not '{mode}'"
            )
# input validation
self.data_dirs = dataset_params.get("dir", {}).get("path", None)
self.labels_suffix = dataset_params.get("dir", {}).get("labels_suffix")
self.vid_suffix = dataset_params.get("dir", {}).get("vid_suffix")
if self.data_dirs is None:
raise ValueError(
"`dir` is missing from dataset config. Please provide a path to the directory containing the labels and videos."
)
if self.labels_suffix is None or self.vid_suffix is None:
raise KeyError(
f"Must provide a labels suffix and vid suffix to search for but found {self.labels_suffix} and {self.vid_suffix}"
)
# infer dataset type from the user provided suffix
if self.labels_suffix == ".slp":
# during training, multiple files can be used at once, so label_files is not passed in
# during inference, a single label_files string can be passed in as get_data_paths is
# called before get_dataset, hence the check
if label_files is None or vid_files is None:
label_files, vid_files = self.get_data_paths(mode, dataset_params)
dataset_params["slp_files"] = label_files
dataset_params["video_files"] = vid_files
dataset_params["data_dirs"] = self.data_dirs
self.data_paths = (mode, vid_files)
return SleapDataset(**dataset_params)
elif self.labels_suffix == ".tif":
            # for CTC datasets, pass in a list of gt and raw image directories, each of which contains tifs
ctc_track_meta = None
list_dir_path = self.data_dirs # don't modify self.data_dirs
if not isinstance(list_dir_path, list):
list_dir_path = [list_dir_path]
if label_files is None or vid_files is None:
label_files, vid_files, ctc_track_meta = self.get_ctc_paths(
list_dir_path
)
dataset_params["data_dirs"] = self.data_dirs
# extract filepaths of all raw images and gt images (i.e. labelled masks)
dataset_params["gt_list"] = label_files
dataset_params["raw_img_list"] = vid_files
dataset_params["ctc_track_meta"] = ctc_track_meta
return CellTrackingDataset(**dataset_params)
else:
raise ValueError(
"Could not resolve dataset type from Config! Only .slp (SLEAP) and .tif (Cell Tracking Challenge) data formats are supported."
)
@property
def data_paths(self):
"""Get data paths."""
return self._vid_files
@data_paths.setter
def data_paths(self, paths: tuple[str, list[str]]):
"""Set data paths.
Args:
paths: A tuple containing (mode, vid_files)
"""
mode, vid_files = paths
self._vid_files[mode] = vid_files
def get_dataloader(
self,
dataset: "SleapDataset" | "MicroscopyDataset" | "CellTrackingDataset",
mode: str,
) -> torch.utils.data.DataLoader:
"""Getter for dataloader.
Args:
dataset: the Sleap or Microscopy Dataset used to initialize the dataloader
mode: either ["train", "val", or "test"] indicates which dataset
config to use
Returns:
A torch dataloader for `dataset` with parameters configured as specified
"""
dataloader_params = self.get("dataloader", {})
if mode.lower() == "train":
dataloader_params = self.get("train_dataloader", {}, dataloader_params)
elif mode.lower() == "val":
dataloader_params = self.get("val_dataloader", {}, dataloader_params)
elif mode.lower() == "test":
dataloader_params = self.get("test_dataloader", {}, dataloader_params)
else:
            raise ValueError(
                f"`mode` must be one of ['train', 'val', 'test'], not '{mode}'"
            )
if dataloader_params.get("num_workers", 0) > 0:
# prevent too many open files error
pin_memory = True
torch.multiprocessing.set_sharing_strategy("file_system")
else:
pin_memory = False
return torch.utils.data.DataLoader(
dataset=dataset,
batch_size=1,
pin_memory=pin_memory,
collate_fn=dataset.no_batching_fn,
**dataloader_params,
)
def get_optimizer(self, params: Iterable) -> torch.optim.Optimizer:
"""Getter for optimizer.
Args:
params: iterable of model parameters to optimize or dicts defining
parameter groups
Returns:
A torch Optimizer with specified params
"""
from dreem.models.model_utils import init_optimizer
optimizer_params = self.get("optimizer")
return init_optimizer(params, optimizer_params)
def get_scheduler(
self, optimizer: torch.optim.Optimizer
) -> torch.optim.lr_scheduler.LRScheduler | None:
"""Getter for lr scheduler.
Args:
optimizer: The optimizer to wrap the scheduler around
Returns:
A torch learning rate scheduler with specified params
"""
from dreem.models.model_utils import init_scheduler
lr_scheduler_params = self.get("scheduler")
if lr_scheduler_params is None:
logger.warning(
"`scheduler` key not found in cfg or is empty. No scheduler will be returned!"
)
return None
return init_scheduler(optimizer, lr_scheduler_params)
def get_loss(self) -> "dreem.training.losses.AssoLoss":
"""Getter for loss functions.
Returns:
An AssoLoss with specified params
"""
from dreem.training.losses import AssoLoss
loss_params = self.get("loss", {})
if len(loss_params) == 0:
logger.warning(
"`loss` key not found in cfg. Using default params for `AssoLoss`"
)
return AssoLoss(**loss_params)
def get_logger(self) -> pl.loggers.Logger:
"""Getter for logging callback.
Returns:
A Logger with specified params
"""
from dreem.models.model_utils import init_logger
logger_params = self.get("logging", {})
if len(logger_params) == 0:
logger.warning(
"`logging` key not found in cfg. No logger will be configured!"
)
return init_logger(
logger_params, OmegaConf.to_container(self.cfg, resolve=True)
)
def get_early_stopping(self) -> pl.callbacks.EarlyStopping:
"""Getter for lightning early stopping callback.
Returns:
A lightning early stopping callback with specified params
"""
early_stopping_params = self.get("early_stopping", None)
if early_stopping_params is None:
logger.warning(
"`early_stopping` was not found in cfg or was `null`. Early stopping will not be used!"
)
return None
elif len(early_stopping_params) == 0:
logger.warning("`early_stopping` cfg is empty! Using defaults")
return pl.callbacks.EarlyStopping(**early_stopping_params)
    def get_checkpointing(self) -> list[pl.callbacks.ModelCheckpoint]:
"""Getter for lightning checkpointing callback.
Returns:
            A list of lightning checkpointing callbacks (one per monitored metric) with specified params
"""
# convert to dict to enable extracting/removing params
checkpoint_params = self.get("checkpointing", {})
logging_params = self.get("logging", {})
dirpath = checkpoint_params.pop("dirpath", None)
if dirpath is None:
dirpath = f"./models/{self.get('group', '', logging_params)}/{self.get('name', '', logging_params)}"
dirpath = Path(dirpath).resolve()
if not Path(dirpath).exists():
try:
Path(dirpath).mkdir(parents=True, exist_ok=True)
except OSError as e:
                logger.exception(
                    f"Cannot create a new folder! Check the permissions to {dirpath}.\n{e}"
                )
_ = checkpoint_params.pop("dirpath", None)
monitor = checkpoint_params.pop("monitor", ["val_loss"])
checkpointers = []
logger.info(
f"Saving checkpoints to `{dirpath}` based on the following metrics: {monitor}"
)
if len(checkpoint_params) == 0:
logger.warning(
"""`checkpointing` key was not found in cfg or was empty!
Configuring checkpointing to use default params!"""
)
for metric in monitor:
checkpointer = pl.callbacks.ModelCheckpoint(
monitor=metric,
dirpath=dirpath,
filename=f"{{epoch}}-{{{metric}}}",
**checkpoint_params,
)
checkpointer.CHECKPOINT_NAME_LAST = f"{{epoch}}-final-{{{metric}}}"
checkpointers.append(checkpointer)
return checkpointers
def get_trainer(
self,
callbacks: list[pl.callbacks.Callback] | None = None,
logger: pl.loggers.WandbLogger | None = None,
devices: int = 1,
accelerator: str = "auto",
) -> pl.Trainer:
"""Getter for the lightning trainer.
Args:
callbacks: a list of lightning callbacks preconfigured to be used
for training
logger: the Wandb logger used for logging during training
devices: The number of gpus to be used. 0 means cpu
accelerator: either "gpu" or "cpu" specifies which device to use
Returns:
A lightning Trainer with specified params
"""
trainer_params = self.get("trainer", {})
profiler = trainer_params.pop("profiler", None)
if len(trainer_params) == 0:
print(
"`trainer` key was not found in cfg or was empty. Using defaults for `pl.Trainer`!"
)
if "accelerator" not in trainer_params:
trainer_params["accelerator"] = accelerator
if "devices" not in trainer_params:
trainer_params["devices"] = devices
map_profiler = {
"advanced": pl.profilers.AdvancedProfiler,
"simple": pl.profilers.SimpleProfiler,
"pytorch": pl.profilers.PyTorchProfiler,
"passthrough": pl.profilers.PassThroughProfiler,
"xla": pl.profilers.XLAProfiler,
}
if profiler:
if profiler in map_profiler:
profiler = map_profiler[profiler](filename="profile")
else:
raise ValueError(
f"Profiler {profiler} not supported! Please use one of {list(map_profiler.keys())}"
)
return pl.Trainer(
callbacks=callbacks,
logger=logger,
profiler=profiler,
**trainer_params,
)
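
A minimal usage sketch (assuming `Config` is importable from `dreem.io`; the keys below are placeholders, not the full DREEM config schema): build a `Config` from an in-memory `DictConfig` and layer overrides on top.

```python
from omegaconf import OmegaConf

from dreem.io import Config  # assumed import path for this class

# Hypothetical, minimal configs; real configs carry model/dataset/trainer sections.
base = OmegaConf.create({"model": {"ckpt_path": None}, "trainer": {"max_epochs": 10}})
overrides = OmegaConf.create({"trainer": {"max_epochs": 1}})

cfg = Config(base, overrides)  # overrides are merged on top of the base config
print(cfg.get("trainer"))      # -> {'max_epochs': 1}
```
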
data_paths (property, writable)

Get data paths.
__init__(cfg, params_cfg=None)

Initialize the class with config from hydra/omega conf.

First uses the `base_param` file, then overwrites with the specific `params_config`.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `cfg` | `DictConfig` | The `DictConfig` containing all the hyperparameters needed for training/evaluation. | required |
| `params_cfg` | `DictConfig \| None` | The `DictConfig` containing a subset of hyperparameters to override. | `None` |
Source code in dreem/io/config.py
def __init__(self, cfg: DictConfig, params_cfg: DictConfig | None = None):
"""Initialize the class with config from hydra/omega conf.
First uses `base_param` file then overwrites with specific `params_config`.
Args:
cfg: The `DictConfig` containing all the hyperparameters needed for
training/evaluation.
            params_cfg: The `DictConfig` containing a subset of hyperparameters to override.
"""
base_cfg = cfg
logger.info(f"Base Config: {cfg}")
if "params_config" in cfg:
params_cfg = OmegaConf.load(cfg.params_config)
if params_cfg:
logger.info(f"Overwriting base config with {params_cfg}")
with open_dict(base_cfg):
self.cfg = OmegaConf.merge(base_cfg, params_cfg) # merge configs
else:
self.cfg = cfg
OmegaConf.set_struct(self.cfg, False)
self._vid_files = {}
__repr__()

Object representation of config class.

__str__()

Return a string representation of config class.
from_yaml(base_cfg_path, params_cfg_path=None) (classmethod)

Load config directly from yaml.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `base_cfg_path` | `str` | path to base config file. | required |
| `params_cfg_path` | `str \| None` | path to override params. | `None` |
Source code in dreem/io/config.py
@classmethod
    def from_yaml(cls, base_cfg_path: str, params_cfg_path: str | None = None) -> "Config":
"""Load config directly from yaml.
Args:
base_cfg_path: path to base config file.
params_cfg_path: path to override params.
"""
base_cfg = OmegaConf.load(base_cfg_path)
params_cfg = OmegaConf.load(params_cfg_path) if params_cfg_path else None
return cls(base_cfg, params_cfg)
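
For configs stored on disk, `from_yaml` wraps the same merge logic. A sketch with placeholder paths (continuing to assume `Config` is imported from `dreem.io`):

```python
# Hypothetical paths; point these at your own base and override YAML files.
cfg = Config.from_yaml("configs/base.yaml", "configs/params.yaml")

# The override file is optional.
cfg = Config.from_yaml("configs/base.yaml")
```
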
get(key, default=None, cfg=None)

Get config item.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `key` | `str` | key of item to return | required |
| `default` | | default value to return if key is missing. | `None` |
| `cfg` | `dict` | the config dict from which to retrieve an item | `None` |
Source code in dreem/io/config.py
def get(self, key: str, default=None, cfg: dict = None):
"""Get config item.
Args:
key: key of item to return
default: default value to return if key is missing.
cfg: the config dict from which to retrieve an item
"""
if cfg is None:
cfg = self.cfg
param = cfg.get(key, default)
if isinstance(param, DictConfig):
param = OmegaConf.to_container(param, resolve=True)
return param
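
`get` behaves like `dict.get`, except that `DictConfig` values are converted to plain Python containers before being returned. A short sketch, reusing the hypothetical `cfg` from above:

```python
trainer_params = cfg.get("trainer", {})        # plain dict, not a DictConfig
fallback = cfg.get("missing_key", "fallback")  # -> "fallback"

# A sub-config (plain dict) can also be queried by passing it via `cfg=`.
max_epochs = cfg.get("max_epochs", None, trainer_params)
```
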
get_checkpointing()

Getter for lightning checkpointing callback.

Returns:

| Type | Description |
|---|---|
| `list[ModelCheckpoint]` | A list of lightning checkpointing callbacks (one per monitored metric) with specified params |
Source code in dreem/io/config.py
    def get_checkpointing(self) -> list[pl.callbacks.ModelCheckpoint]:
"""Getter for lightning checkpointing callback.
Returns:
            A list of lightning checkpointing callbacks (one per monitored metric) with specified params
"""
# convert to dict to enable extracting/removing params
checkpoint_params = self.get("checkpointing", {})
logging_params = self.get("logging", {})
dirpath = checkpoint_params.pop("dirpath", None)
if dirpath is None:
dirpath = f"./models/{self.get('group', '', logging_params)}/{self.get('name', '', logging_params)}"
dirpath = Path(dirpath).resolve()
if not Path(dirpath).exists():
try:
Path(dirpath).mkdir(parents=True, exist_ok=True)
except OSError as e:
                logger.exception(
                    f"Cannot create a new folder! Check the permissions to {dirpath}.\n{e}"
                )
_ = checkpoint_params.pop("dirpath", None)
monitor = checkpoint_params.pop("monitor", ["val_loss"])
checkpointers = []
logger.info(
f"Saving checkpoints to `{dirpath}` based on the following metrics: {monitor}"
)
if len(checkpoint_params) == 0:
logger.warning(
"""`checkpointing` key was not found in cfg or was empty!
Configuring checkpointing to use default params!"""
)
for metric in monitor:
checkpointer = pl.callbacks.ModelCheckpoint(
monitor=metric,
dirpath=dirpath,
filename=f"{{epoch}}-{{{metric}}}",
**checkpoint_params,
)
checkpointer.CHECKPOINT_NAME_LAST = f"{{epoch}}-final-{{{metric}}}"
checkpointers.append(checkpointer)
return checkpointers
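
Despite the singular name, `get_checkpointing` returns one `ModelCheckpoint` per monitored metric, so the result is typically extended into the trainer's callback list. A sketch:

```python
callbacks = []
callbacks.extend(cfg.get_checkpointing())  # one ModelCheckpoint per metric in `monitor`

early_stopping = cfg.get_early_stopping()
if early_stopping is not None:
    callbacks.append(early_stopping)
```
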
get_ctc_paths(list_dir_path)

Get file paths from directory. Only for CTC datasets.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `list_dir_path` | `list[str]` | list of directories to search for labels and videos | required |

Returns:

| Type | Description |
|---|---|
| `tuple[list[str], list[str], list[str]]` | lists of labels file paths and video file paths |
Source code in dreem/io/config.py
def get_ctc_paths(
self, list_dir_path: list[str]
) -> tuple[list[str], list[str], list[str]]:
"""Get file paths from directory. Only for CTC datasets.
Args:
list_dir_path: list of directories to search for labels and videos
Returns:
lists of labels file paths and video file paths
"""
gt_list = []
raw_img_list = []
ctc_track_meta = []
# user can specify a list of directories, each of which can contain several subdirectories that come in pairs of (dset_name, dset_name_GT/TRA)
for dir_path in list_dir_path:
for subdir in os.listdir(dir_path):
if subdir.endswith("_GT"):
gt_path = os.path.join(dir_path, subdir, "TRA")
raw_img_path = os.path.join(dir_path, subdir.replace("_GT", ""))
# get filepaths for all tif files in gt_path
gt_list.append(glob.glob(os.path.join(gt_path, "*.tif")))
# get filepaths for all tif files in raw_img_path
raw_img_list.append(glob.glob(os.path.join(raw_img_path, "*.tif")))
man_track_file = glob.glob(os.path.join(gt_path, "man_track.txt"))
if len(man_track_file) > 0:
ctc_track_meta.append(man_track_file[0])
else:
logger.debug(
f"No man_track.txt file found in {gt_path}. Continuing..."
)
else:
continue
return gt_list, raw_img_list, ctc_track_meta
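
The expected on-disk layout follows the Cell Tracking Challenge convention: each data directory holds sibling `<name>` and `<name>_GT/TRA` folders of tifs, with an optional `man_track.txt` in `TRA`. A sketch with hypothetical directory names:

```python
# Hypothetical layout:
#   /data/ctc/Fluo-N2DH-SIM+/          raw image tifs
#   /data/ctc/Fluo-N2DH-SIM+_GT/TRA/   ground-truth mask tifs + man_track.txt
gt_list, raw_img_list, ctc_track_meta = cfg.get_ctc_paths(["/data/ctc"])
```
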
get_data_paths(mode, data_cfg)

Get file paths from directory. Only for SLEAP datasets.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `mode` | `str` | [None, "train", "test", "val"]. Indicates whether to use train, val, or test params for dataset | required |
| `data_cfg` | `dict` | Config for the dataset containing "dir" key. | required |

Returns:

| Type | Description |
|---|---|
| `tuple[list[str], list[str]]` | lists of labels file paths and video file paths respectively |
Source code in dreem/io/config.py
def get_data_paths(self, mode: str, data_cfg: dict) -> tuple[list[str], list[str]]:
"""Get file paths from directory. Only for SLEAP datasets.
Args:
mode: [None, "train", "test", "val"]. Indicates whether to use
train, val, or test params for dataset
data_cfg: Config for the dataset containing "dir" key.
Returns:
lists of labels file paths and video file paths respectively
"""
# hack to get around the fact that for test mode, get_data_paths is called before get_dataset.
# also, for train/val mode, data_cfg has had the dir key popped through self.get() called in get_dataset()
if mode == "test":
list_dir_path = data_cfg.get("dir", {}).get("path", None)
if list_dir_path is None:
raise ValueError(
"`dir` is missing from dataset config. Please provide a path to the directory containing the labels and videos."
)
self.labels_suffix = data_cfg.get("dir", {}).get("labels_suffix")
self.vid_suffix = data_cfg.get("dir", {}).get("vid_suffix")
else:
list_dir_path = self.data_dirs
if not isinstance(list_dir_path, list):
list_dir_path = [list_dir_path]
if self.labels_suffix == ".slp":
label_files = []
vid_files = []
for dir_path in list_dir_path:
logger.debug(f"Searching `{dir_path}` directory")
labels_path = f"{dir_path}/*{self.labels_suffix}"
vid_path = f"{dir_path}/*{self.vid_suffix}"
logger.debug(f"Searching for labels matching {labels_path}")
label_files.extend(glob.glob(labels_path))
logger.debug(f"Searching for videos matching {vid_path}")
vid_files.extend(glob.glob(vid_path))
elif self.labels_suffix == ".tif":
label_files, vid_files, ctc_track_meta = self.get_ctc_paths(list_dir_path)
logger.debug(f"Found {len(label_files)} labels and {len(vid_files)} videos")
# backdoor to set label files directly in the configs (i.e. bypass dir.path)
if data_cfg.get("slp_files", None):
logger.debug("Overriding label files with user provided list")
slp_files = data_cfg.get("slp_files")
if len(slp_files) > 0:
label_files = slp_files
if data_cfg.get("video_files", None):
individual_video_files = data_cfg.get("video_files")
if len(individual_video_files) > 0:
vid_files = individual_video_files
return label_files, vid_files
get_dataloader(dataset, mode)

Getter for dataloader.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `dataset` | `'SleapDataset' \| 'MicroscopyDataset' \| 'CellTrackingDataset'` | the Sleap or Microscopy Dataset used to initialize the dataloader | required |
| `mode` | `str` | either ["train", "val", or "test"] indicates which dataset config to use | required |

Returns:

| Type | Description |
|---|---|
| `DataLoader` | A torch dataloader for `dataset` with parameters configured as specified |
Source code in dreem/io/config.py
def get_dataloader(
self,
dataset: "SleapDataset" | "MicroscopyDataset" | "CellTrackingDataset",
mode: str,
) -> torch.utils.data.DataLoader:
"""Getter for dataloader.
Args:
dataset: the Sleap or Microscopy Dataset used to initialize the dataloader
mode: either ["train", "val", or "test"] indicates which dataset
config to use
Returns:
A torch dataloader for `dataset` with parameters configured as specified
"""
dataloader_params = self.get("dataloader", {})
if mode.lower() == "train":
dataloader_params = self.get("train_dataloader", {}, dataloader_params)
elif mode.lower() == "val":
dataloader_params = self.get("val_dataloader", {}, dataloader_params)
elif mode.lower() == "test":
dataloader_params = self.get("test_dataloader", {}, dataloader_params)
else:
            raise ValueError(
                f"`mode` must be one of ['train', 'val', 'test'], not '{mode}'"
            )
if dataloader_params.get("num_workers", 0) > 0:
# prevent too many open files error
pin_memory = True
torch.multiprocessing.set_sharing_strategy("file_system")
else:
pin_memory = False
return torch.utils.data.DataLoader(
dataset=dataset,
batch_size=1,
pin_memory=pin_memory,
collate_fn=dataset.no_batching_fn,
**dataloader_params,
)
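
A sketch of the usual pairing, assuming `cfg` was built from a full training config: build the dataset for a split, then wrap it in a dataloader configured from the matching `*_dataloader` section (batch size is fixed at 1 and batching is delegated to the dataset's `no_batching_fn`).

```python
train_ds = cfg.get_dataset(mode="train")
train_dl = cfg.get_dataloader(train_ds, mode="train")

val_ds = cfg.get_dataset(mode="val")
val_dl = cfg.get_dataloader(val_ds, mode="val")
```
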
get_dataset(mode, label_files=None, vid_files=None)

Getter for datasets.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `mode` | `str` | [None, "train", "test", "val"]. Indicates whether to use train, val, or test params for dataset | required |
| `label_files` | `list[str] \| None` | path to label_files for override | `None` |
| `vid_files` | `list[str \| list[str]]` | path to vid_files for override | `None` |

Returns:

| Type | Description |
|---|---|
| `'SleapDataset' \| 'CellTrackingDataset'` | Either a `SleapDataset` or `CellTrackingDataset` with params indicated by cfg |
Source code in dreem/io/config.py
def get_dataset(
self,
mode: str,
label_files: list[str] | None = None,
vid_files: list[str | list[str]] = None,
) -> "SleapDataset" | "CellTrackingDataset":
"""Getter for datasets.
Args:
mode: [None, "train", "test", "val"]. Indicates whether to use
train, val, or test params for dataset
label_files: path to label_files for override
vid_files: path to vid_files for override
Returns:
Either a `SleapDataset` or `CellTrackingDataset` with params indicated by cfg
"""
from dreem.datasets import SleapDataset, CellTrackingDataset
dataset_params = self.get("dataset")
if dataset_params is None:
raise KeyError("`dataset` key is missing from cfg!")
if mode.lower() == "train":
dataset_params = self.get("train_dataset", {}, dataset_params)
elif mode.lower() == "val":
dataset_params = self.get("val_dataset", {}, dataset_params)
elif mode.lower() == "test":
dataset_params = self.get("test_dataset", {}, dataset_params)
else:
            raise ValueError(
                f"`mode` must be one of ['train', 'val', 'test'], not '{mode}'"
            )
# input validation
self.data_dirs = dataset_params.get("dir", {}).get("path", None)
self.labels_suffix = dataset_params.get("dir", {}).get("labels_suffix")
self.vid_suffix = dataset_params.get("dir", {}).get("vid_suffix")
if self.data_dirs is None:
raise ValueError(
"`dir` is missing from dataset config. Please provide a path to the directory containing the labels and videos."
)
if self.labels_suffix is None or self.vid_suffix is None:
raise KeyError(
f"Must provide a labels suffix and vid suffix to search for but found {self.labels_suffix} and {self.vid_suffix}"
)
# infer dataset type from the user provided suffix
if self.labels_suffix == ".slp":
# during training, multiple files can be used at once, so label_files is not passed in
# during inference, a single label_files string can be passed in as get_data_paths is
# called before get_dataset, hence the check
if label_files is None or vid_files is None:
label_files, vid_files = self.get_data_paths(mode, dataset_params)
dataset_params["slp_files"] = label_files
dataset_params["video_files"] = vid_files
dataset_params["data_dirs"] = self.data_dirs
self.data_paths = (mode, vid_files)
return SleapDataset(**dataset_params)
elif self.labels_suffix == ".tif":
            # for CTC datasets, pass in a list of gt and raw image directories, each of which contains tifs
ctc_track_meta = None
list_dir_path = self.data_dirs # don't modify self.data_dirs
if not isinstance(list_dir_path, list):
list_dir_path = [list_dir_path]
if label_files is None or vid_files is None:
label_files, vid_files, ctc_track_meta = self.get_ctc_paths(
list_dir_path
)
dataset_params["data_dirs"] = self.data_dirs
# extract filepaths of all raw images and gt images (i.e. labelled masks)
dataset_params["gt_list"] = label_files
dataset_params["raw_img_list"] = vid_files
dataset_params["ctc_track_meta"] = ctc_track_meta
return CellTrackingDataset(**dataset_params)
else:
raise ValueError(
"Could not resolve dataset type from Config! Only .slp (SLEAP) and .tif (Cell Tracking Challenge) data formats are supported."
)
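
During inference, file discovery and dataset construction are often split, since `get_data_paths` runs before `get_dataset` in test mode. A sketch of that path (mirroring how the test dataset config is resolved internally):

```python
test_params = cfg.get("test_dataset", {}, cfg.get("dataset", {}))
label_files, vid_files = cfg.get_data_paths("test", test_params)
test_ds = cfg.get_dataset("test", label_files=label_files, vid_files=vid_files)
```
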
get_early_stopping()

Getter for lightning early stopping callback.

Returns:

| Type | Description |
|---|---|
| `EarlyStopping` | A lightning early stopping callback with specified params |
Source code in dreem/io/config.py
def get_early_stopping(self) -> pl.callbacks.EarlyStopping:
"""Getter for lightning early stopping callback.
Returns:
A lightning early stopping callback with specified params
"""
early_stopping_params = self.get("early_stopping", None)
if early_stopping_params is None:
logger.warning(
"`early_stopping` was not found in cfg or was `null`. Early stopping will not be used!"
)
return None
elif len(early_stopping_params) == 0:
logger.warning("`early_stopping` cfg is empty! Using defaults")
return pl.callbacks.EarlyStopping(**early_stopping_params)
get_gtr_runner(ckpt_path=None)

Get lightning module for training, validation, and inference.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `ckpt_path` | `str \| None` | path to checkpoint for override | `None` |

Returns:

| Type | Description |
|---|---|
| `'GTRRunner'` | a gtr runner model |
Source code in dreem/io/config.py
def get_gtr_runner(self, ckpt_path: str | None = None) -> "GTRRunner":
"""Get lightning module for training, validation, and inference.
Args:
ckpt_path: path to checkpoint for override
Returns:
a gtr runner model
"""
from dreem.models import GTRRunner
keys = ["tracker", "optimizer", "scheduler", "loss", "runner", "model"]
args = [key + "_cfg" if key != "runner" else key for key in keys]
params = {}
for key, arg in zip(keys, args):
sub_params = self.get(key, {})
if len(sub_params) == 0:
logger.warning(
f"`{key}` not found in config or is empty. Using defaults for {arg}!"
)
if key == "runner":
runner_params = sub_params
for k, v in runner_params.items():
params[k] = v
else:
params[arg] = sub_params
ckpt_path = params["model_cfg"].pop("ckpt_path", None)
if ckpt_path is not None and ckpt_path != "":
model = GTRRunner.load_from_checkpoint(
ckpt_path, tracker_cfg=params["tracker_cfg"], **runner_params
)
else:
model = GTRRunner(**params)
return model
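
A sketch of pulling the lightning module: if `model.ckpt_path` is set in the config, weights are restored from that checkpoint; otherwise a fresh `GTRRunner` is built from the `tracker`, `optimizer`, `scheduler`, `loss`, `runner`, and `model` sections.

```python
gtr_runner = cfg.get_gtr_runner()
# gtr_runner is a LightningModule and can be handed to a Trainer later, e.g.
# trainer.fit(gtr_runner, train_dl, val_dl)
```
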
get_logger()

Getter for logging callback.

Returns:

| Type | Description |
|---|---|
| `Logger` | A Logger with specified params |
Source code in dreem/io/config.py
def get_logger(self) -> pl.loggers.Logger:
"""Getter for logging callback.
Returns:
A Logger with specified params
"""
from dreem.models.model_utils import init_logger
logger_params = self.get("logging", {})
if len(logger_params) == 0:
logger.warning(
"`logging` key not found in cfg. No logger will be configured!"
)
return init_logger(
logger_params, OmegaConf.to_container(self.cfg, resolve=True)
)
get_loss()

Getter for loss functions.

Returns:

| Type | Description |
|---|---|
| `'dreem.training.losses.AssoLoss'` | An AssoLoss with specified params |
Source code in dreem/io/config.py
def get_loss(self) -> "dreem.training.losses.AssoLoss":
"""Getter for loss functions.
Returns:
An AssoLoss with specified params
"""
from dreem.training.losses import AssoLoss
loss_params = self.get("loss", {})
if len(loss_params) == 0:
logger.warning(
"`loss` key not found in cfg. Using default params for `AssoLoss`"
)
return AssoLoss(**loss_params)
get_model()

Getter for gtr model.

Returns:

| Type | Description |
|---|---|
| `'GlobalTrackingTransformer'` | A global tracking transformer with parameters indicated by cfg |
Source code in dreem/io/config.py
def get_model(self) -> "GlobalTrackingTransformer":
"""Getter for gtr model.
Returns:
A global tracking transformer with parameters indicated by cfg
"""
from dreem.models import GlobalTrackingTransformer, GTRRunner
model_params = self.get("model", {})
ckpt_path = model_params.pop("ckpt_path", None)
if ckpt_path is not None and len(ckpt_path) > 0:
return GTRRunner.load_from_checkpoint(ckpt_path).model
return GlobalTrackingTransformer(**model_params)
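
For inference or analysis you may only need the underlying transformer and the tracker parameters; a short sketch:

```python
model = cfg.get_model()                 # GlobalTrackingTransformer (restored from ckpt_path if set)
tracker_params = cfg.get_tracker_cfg()  # dict of init params for `Tracker`
```
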
get_optimizer(params)

Getter for optimizer.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `params` | `Iterable` | iterable of model parameters to optimize or dicts defining parameter groups | required |

Returns:

| Type | Description |
|---|---|
| `Optimizer` | A torch Optimizer with specified params |
Source code in dreem/io/config.py
def get_optimizer(self, params: Iterable) -> torch.optim.Optimizer:
"""Getter for optimizer.
Args:
params: iterable of model parameters to optimize or dicts defining
parameter groups
Returns:
A torch Optimizer with specified params
"""
from dreem.models.model_utils import init_optimizer
optimizer_params = self.get("optimizer")
return init_optimizer(params, optimizer_params)
get_scheduler(optimizer)

Getter for lr scheduler.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `optimizer` | `Optimizer` | The optimizer to wrap the scheduler around | required |

Returns:

| Type | Description |
|---|---|
| `LRScheduler \| None` | A torch learning rate scheduler with specified params |
Source code in dreem/io/config.py
def get_scheduler(
self, optimizer: torch.optim.Optimizer
) -> torch.optim.lr_scheduler.LRScheduler | None:
"""Getter for lr scheduler.
Args:
optimizer: The optimizer to wrap the scheduler around
Returns:
A torch learning rate scheduler with specified params
"""
from dreem.models.model_utils import init_scheduler
lr_scheduler_params = self.get("scheduler")
if lr_scheduler_params is None:
logger.warning(
"`scheduler` key not found in cfg or is empty. No scheduler will be returned!"
)
return None
return init_scheduler(optimizer, lr_scheduler_params)
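
Optimizer and scheduler are usually configured together. A sketch using the runner from the earlier example (the scheduler is `None` when the `scheduler` key is absent):

```python
optimizer = cfg.get_optimizer(gtr_runner.parameters())
scheduler = cfg.get_scheduler(optimizer)
if scheduler is None:
    # No `scheduler` section in the config; training proceeds with a fixed learning rate.
    pass
```
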
get_tracker_cfg()

Getter for tracker config params.

Returns:

| Type | Description |
|---|---|
| `dict` | A dict containing the init params for `Tracker`. |
get_trainer(callbacks=None, logger=None, devices=1, accelerator='auto')

Getter for the lightning trainer.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `callbacks` | `list[Callback] \| None` | a list of lightning callbacks preconfigured to be used for training | `None` |
| `logger` | `WandbLogger \| None` | the Wandb logger used for logging during training | `None` |
| `devices` | `int` | The number of gpus to be used. 0 means cpu | `1` |
| `accelerator` | `str` | either "gpu" or "cpu" specifies which device to use | `'auto'` |

Returns:

| Type | Description |
|---|---|
| `Trainer` | A lightning Trainer with specified params |
Source code in dreem/io/config.py
def get_trainer(
self,
callbacks: list[pl.callbacks.Callback] | None = None,
logger: pl.loggers.WandbLogger | None = None,
devices: int = 1,
accelerator: str = "auto",
) -> pl.Trainer:
"""Getter for the lightning trainer.
Args:
callbacks: a list of lightning callbacks preconfigured to be used
for training
logger: the Wandb logger used for logging during training
devices: The number of gpus to be used. 0 means cpu
accelerator: either "gpu" or "cpu" specifies which device to use
Returns:
A lightning Trainer with specified params
"""
trainer_params = self.get("trainer", {})
profiler = trainer_params.pop("profiler", None)
if len(trainer_params) == 0:
print(
"`trainer` key was not found in cfg or was empty. Using defaults for `pl.Trainer`!"
)
if "accelerator" not in trainer_params:
trainer_params["accelerator"] = accelerator
if "devices" not in trainer_params:
trainer_params["devices"] = devices
map_profiler = {
"advanced": pl.profilers.AdvancedProfiler,
"simple": pl.profilers.SimpleProfiler,
"pytorch": pl.profilers.PyTorchProfiler,
"passthrough": pl.profilers.PassThroughProfiler,
"xla": pl.profilers.XLAProfiler,
}
if profiler:
if profiler in map_profiler:
profiler = map_profiler[profiler](filename="profile")
else:
raise ValueError(
f"Profiler {profiler} not supported! Please use one of {list(map_profiler.keys())}"
)
return pl.Trainer(
callbacks=callbacks,
logger=logger,
profiler=profiler,
**trainer_params,
)
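
Putting the getters together, a hedged end-to-end sketch of a typical training setup (the YAML path is a placeholder, and the exact callbacks depend on your config):

```python
cfg = Config.from_yaml("configs/train.yaml")  # hypothetical path

runner = cfg.get_gtr_runner()
train_dl = cfg.get_dataloader(cfg.get_dataset("train"), "train")
val_dl = cfg.get_dataloader(cfg.get_dataset("val"), "val")

callbacks = list(cfg.get_checkpointing())
early_stopping = cfg.get_early_stopping()
if early_stopping is not None:
    callbacks.append(early_stopping)

trainer = cfg.get_trainer(callbacks=callbacks, logger=cfg.get_logger())
trainer.fit(runner, train_dl, val_dl)
```
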
set_hparams(hparams)

Setter function for overwriting specific hparams.

Useful for changing 1 or 2 hyperparameters such as dataset.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `hparams` | `dict` | A dict containing the hyperparameter to be overwritten and the value to be changed | required |

Returns:

| Type | Description |
|---|---|
| `bool` | `True` if config is successfully updated, `False` otherwise |
Source code in dreem/io/config.py
def set_hparams(self, hparams: dict) -> bool:
"""Setter function for overwriting specific hparams.
Useful for changing 1 or 2 hyperparameters such as dataset.
Args:
hparams: A dict containing the hyperparameter to be overwritten and
the value to be changed
Returns:
`True` if config is successfully updated, `False` otherwise
"""
if hparams == {} or hparams is None:
logger.warning("Nothing to update!")
return False
for hparam, val in hparams.items():
try:
OmegaConf.update(self.cfg, hparam, val)
except Exception as e:
logger.exception(f"Failed to update {hparam} to {val} due to {e}")
return False
return True
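
A sketch of overriding a couple of values in place; dotted keys follow `OmegaConf.update` semantics, and the keys shown are assumptions about your config layout:

```python
ok = cfg.set_hparams({"trainer.max_epochs": 5, "model.ckpt_path": None})
if not ok:
    raise RuntimeError("Config update failed; see the log for the offending key.")
```
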