# gtr_runner

`dreem.models.gtr_runner`

Module containing training, validation, and inference logic.
Classes:

| Name | Description |
| --- | --- |
| `GTRRunner` | A lightning wrapper around the GTR model. |
## GTRRunner

Bases: `LightningModule`

A lightning wrapper around the GTR model. Used for training, validation, and inference.
Methods:

| Name | Description |
| --- | --- |
| `__init__` | Initialize a lightning module for GTR. |
| `configure_optimizers` | Get optimizers and schedulers for training. |
| `forward` | Execute forward pass of the lightning module. |
| `log_metrics` | Log metrics computed during evaluation. |
| `on_test_end` | Run inference and metrics pipeline to compute metrics for test set. |
| `on_validation_epoch_end` | Execute hook for validation end. |
| `predict_step` | Run inference for model. |
| `test_step` | Execute single test step for model. |
| `training_step` | Execute single training step for model. |
| `validation_step` | Execute single val step for model. |
Source code in dreem/models/gtr_runner.py

```python
class GTRRunner(LightningModule):
    """A lightning wrapper around the GTR model.

    Used for training, validation and inference.
    """

    DEFAULT_METRICS = {
        "train": [],
        "val": [],
        "test": ["num_switches", "global_tracking_accuracy"],
    }
    DEFAULT_TRACKING = {
        "train": False,
        "val": False,
        "test": True,
    }
    DEFAULT_SAVE = {"train": False, "val": False, "test": False}

    def __init__(
        self,
        model_cfg: dict | None = None,
        tracker_cfg: dict | None = None,
        loss_cfg: dict | None = None,
        optimizer_cfg: dict | None = None,
        scheduler_cfg: dict | None = None,
        metrics: dict[str, list[str]] | None = None,
        persistent_tracking: dict[str, bool] | None = None,
        test_save_path: str = "./test_results.h5",
    ):
        """Initialize a lightning module for GTR.

        Args:
            model_cfg: hyperparameters for GlobalTrackingTransformer
            tracker_cfg: the parameters used for the tracker post-processing
            loss_cfg: hyperparameters for AssoLoss
            optimizer_cfg: hyperparameters for the optimizer;
                only used to overwrite `configure_optimizers`
            scheduler_cfg: hyperparameters for the lr_scheduler;
                only used to overwrite `configure_optimizers`
            metrics: a dict containing the metrics to be computed during train, val, and test
            persistent_tracking: a dict containing whether to use persistent tracking
                during train, val and test inference
            test_save_path: path to a directory to save the eval and tracking results to
        """
        super().__init__()
        self.save_hyperparameters()

        self.model_cfg = model_cfg if model_cfg else {}
        self.loss_cfg = loss_cfg if loss_cfg else {}
        self.tracker_cfg = tracker_cfg if tracker_cfg else {}

        self.model = GlobalTrackingTransformer(**self.model_cfg)
        self.loss = AssoLoss(**self.loss_cfg)
        if self.tracker_cfg.get("tracker_type", "standard") == "batch":
            self.tracker = BatchTracker(**self.tracker_cfg)
        else:
            self.tracker = Tracker(**self.tracker_cfg)

        self.optimizer_cfg = optimizer_cfg
        self.scheduler_cfg = scheduler_cfg
        self.metrics = metrics if metrics is not None else self.DEFAULT_METRICS
        self.persistent_tracking = (
            persistent_tracking
            if persistent_tracking is not None
            else self.DEFAULT_TRACKING
        )
        self.test_results = {"preds": [], "save_path": test_save_path}

    def forward(
        self,
        ref_instances: list["dreem.io.Instance"],
        query_instances: list["dreem.io.Instance"] | None = None,
    ) -> list["AssociationMatrix"]:
        """Execute forward pass of the lightning module.

        Args:
            ref_instances: a list of `Instance` objects containing crops and other
                data needed for the transformer model
            query_instances: a list of `Instance` objects used as queries in the
                decoder. Mostly used for inference.

        Returns:
            A list of association matrices between objects
        """
        asso_preds = self.model(ref_instances, query_instances)
        return asso_preds

    def training_step(
        self, train_batch: list[list["dreem.io.Frame"]], batch_idx: int
    ) -> dict[str, float]:
        """Execute single training step for model.

        Args:
            train_batch: A single batch from the dataset which is a list of `Frame`
                objects with length `clip_length` containing Instances and other metadata.
            batch_idx: the batch number used by lightning

        Returns:
            A dict containing the train loss plus any other metrics specified
        """
        result = self._shared_eval_step(train_batch[0], mode="train")
        self.log_metrics(result, len(train_batch[0]), "train")
        return result

    def validation_step(
        self, val_batch: list[list["dreem.io.Frame"]], batch_idx: int
    ) -> dict[str, float]:
        """Execute single val step for model.

        Args:
            val_batch: A single batch from the dataset which is a list of `Frame`
                objects with length `clip_length` containing Instances and other metadata.
            batch_idx: the batch number used by lightning

        Returns:
            A dict containing the val loss plus any other metrics specified
        """
        result = self._shared_eval_step(val_batch[0], mode="val")
        self.log_metrics(result, len(val_batch[0]), "val")
        return result

    def test_step(
        self, test_batch: list[list["dreem.io.Frame"]], batch_idx: int
    ) -> dict[str, float]:
        """Execute single test step for model.

        Args:
            test_batch: A single batch from the dataset which is a list of `Frame`
                objects with length `clip_length` containing Instances and other metadata.
            batch_idx: the batch number used by lightning

        Returns:
            A dict containing the test loss plus any other metrics specified
        """
        result = self._shared_eval_step(test_batch[0], mode="test")
        self.log_metrics(result, len(test_batch[0]), "test")
        return result

    def predict_step(
        self, batch: list[list["dreem.io.Frame"]], batch_idx: int
    ) -> list["dreem.io.Frame"]:
        """Run inference for model.

        Computes association + assignment.

        Args:
            batch: A single batch from the dataset which is a list of `Frame` objects
                with length `clip_length` containing Instances and other metadata.
            batch_idx: the batch number used by lightning

        Returns:
            A list of `Frame` objects containing the predicted track ids
        """
        frames_pred = self.tracker(self.model, batch[0])
        return frames_pred

    def _shared_eval_step(
        self, frames: list["dreem.io.Frame"], mode: str
    ) -> dict[str, float]:
        """Run evaluation used by train, test, and val steps.

        Args:
            frames: a list of `Frame` objects containing gt data
            mode: which metrics to compute and whether to use persistent tracking or not

        Returns:
            a dict containing the loss and any other metrics specified by `eval_metrics`
        """
        try:
            instances = [instance for frame in frames for instance in frame.instances]
            if len(instances) == 0:
                return None

            eval_metrics = self.metrics[mode]
            logits = self(instances)
            logits = [asso.matrix for asso in logits]
            loss = self.loss(logits, frames)
            return_metrics = {"loss": loss}

            if mode == "test":
                self.tracker.persistent_tracking = True
                frames_pred = self.tracker(self.model, frames)
                self.test_results["preds"].extend(
                    [frame.to("cpu") for frame in frames_pred]
                )
            return_metrics["batch_size"] = len(frames)
        except Exception as e:
            logger.exception(
                f"Failed on frame {frames[0].frame_id} of video {frames[0].video_id}"
            )
            logger.exception(e)
            raise e
        return return_metrics

    def configure_optimizers(self) -> dict:
        """Get optimizers and schedulers for training.

        Is overridden by config but defaults to Adam + ReduceLROnPlateau.

        Returns:
            an optimizer config dict containing the optimizer, scheduler, and scheduler params
        """
        # todo: init from config
        if self.optimizer_cfg is None:
            optimizer = torch.optim.Adam(self.parameters(), lr=1e-4, betas=(0.9, 0.999))
        else:
            optimizer = init_optimizer(self.parameters(), self.optimizer_cfg)

        if self.scheduler_cfg is None:
            scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
                optimizer, "min", 0.5, 10
            )
        else:
            scheduler = init_scheduler(optimizer, self.scheduler_cfg)

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "monitor": "val_loss",
                "interval": "epoch",
                "frequency": 1,
            },
        }

    def log_metrics(self, result: dict, batch_size: int, mode: str) -> None:
        """Log metrics computed during evaluation.

        Args:
            result: A dict containing metrics to be logged.
            batch_size: the size of the batch used to compute the metrics
            mode: One of {'train', 'val', 'test'}. Used as prefix while logging.
        """
        if result:
            batch_size = result.pop("batch_size")
            for metric, val in result.items():
                if isinstance(val, torch.Tensor):
                    val = val.item()
                self.log(f"{mode}_{metric}", val, batch_size=batch_size)

    def on_validation_epoch_end(self):
        """Execute hook for validation end.

        Currently, we simply clear the gpu cache and do garbage collection.
        """
        gc.collect()
        torch.cuda.empty_cache()

    def on_test_end(self):
        """Run inference and metrics pipeline to compute metrics for test set.

        Consumes `self.test_results` (predictions accumulated during test steps,
        plus the save path) and `self.metrics["test"]` (the metrics to compute).
        """
        # input validation
        metrics_to_compute = self.metrics["test"]  # list of metrics, or "all"
        if metrics_to_compute == "all":
            metrics_to_compute = ["motmetrics", "global_tracking_accuracy"]
        if isinstance(metrics_to_compute, str):
            metrics_to_compute = [metrics_to_compute]
        for metric in metrics_to_compute:
            if metric not in ["motmetrics", "global_tracking_accuracy"]:
                raise ValueError(
                    f"Metric {metric} not supported. Please select from "
                    "'motmetrics' or 'global_tracking_accuracy'"
                )

        preds = self.test_results["preds"]
        # results is a dict keyed by metric name with the computed value
        results = metrics.evaluate(preds, metrics_to_compute)

        # save metrics and frame metadata to hdf5;
        # get the video name from the first frame
        vid_name = Path(preds[0].vid_name).stem
        fname = os.path.join(
            self.test_results["save_path"], f"{vid_name}.dreem_metrics.h5"
        )
        logger.info(f"Saving metrics to {fname}")

        # if the h5 file exists, add a suffix to prevent a name collision
        suffix_counter = 0
        original_fname = fname
        while os.path.exists(fname):
            suffix_counter += 1
            fname = original_fname.replace(
                ".dreem_metrics.h5", f"_{suffix_counter}.dreem_metrics.h5"
            )
        if suffix_counter > 0:
            logger.info(f"File already exists. Saving to {fname} instead")

        with h5py.File(fname, "a") as results_file:
            # create a group for this video
            vid_group = results_file.require_group(vid_name)
            # save each metric
            for metric_name, value in results.items():
                if metric_name == "motmetrics":
                    # save mot_summary and mot_events separately
                    mot_summary = value[0]
                    mot_events = value[1]
                    frame_switch_map = value[2]
                    mot_summary_group = vid_group.require_group("mot_summary")
                    # save each row of mot_summary as an attribute
                    for _, row in mot_summary.iterrows():
                        mot_summary_group.attrs[row.name] = row["acc"]
                    # save extra metadata for frames in which there is a switch
                    for frame_id, switch in frame_switch_map.items():
                        frame = preds[frame_id]
                        frame = frame.to("cpu")
                        if switch:
                            _ = frame.to_h5(
                                vid_group,
                                frame.get_gt_track_ids().cpu().numpy(),
                                save={
                                    "crop": True,
                                    "features": True,
                                    "embeddings": True,
                                },
                            )
                        else:
                            _ = frame.to_h5(
                                vid_group, frame.get_gt_track_ids().cpu().numpy()
                            )
                    # save the motevents log to csv
                    motevents_path = os.path.join(
                        self.test_results["save_path"], f"{vid_name}.motevents.csv"
                    )
                    logger.info(f"Saving motevents log to {motevents_path}")
                    mot_events.to_csv(motevents_path, index=False)
                elif metric_name == "global_tracking_accuracy":
                    gta_by_gt_track = value
                    gta_group = vid_group.require_group("global_tracking_accuracy")
                    # save as key-value pairs of gt track id: gta
                    for gt_track_id, gta in gta_by_gt_track.items():
                        gta_group.attrs[f"track_{gt_track_id}"] = gta

        # save the tracking results to a slp file
        pred_slp = []
        outpath = os.path.join(
            self.test_results["save_path"],
            f"{vid_name}.dreem_inference.{datetime.now().strftime('%m-%d-%Y-%H-%M-%S')}.slp",
        )
        logger.info(f"Saving inference results to {outpath}")
        tracks = {}
        for frame in preds:
            if frame.frame_id.item() == 0:
                video = (
                    sio.Video(frame.video)
                    if isinstance(frame.video, str)
                    else frame.video  # already a Video object
                )
            lf, tracks = frame.to_slp(tracks, video=video)
            pred_slp.append(lf)
        pred_slp = sio.Labels(pred_slp)
        pred_slp.save(outpath)
        # clear the preds
        self.test_results["preds"] = []
```
### `__init__(model_cfg=None, tracker_cfg=None, loss_cfg=None, optimizer_cfg=None, scheduler_cfg=None, metrics=None, persistent_tracking=None, test_save_path='./test_results.h5')`

Initialize a lightning module for GTR.
Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `model_cfg` | `dict \| None` | hyperparameters for GlobalTrackingTransformer | `None` |
| `tracker_cfg` | `dict \| None` | the parameters used for the tracker post-processing | `None` |
| `loss_cfg` | `dict \| None` | hyperparameters for AssoLoss | `None` |
| `optimizer_cfg` | `dict \| None` | hyperparameters for the optimizer; only used to overwrite `configure_optimizers` | `None` |
| `scheduler_cfg` | `dict \| None` | hyperparameters for the lr_scheduler; only used to overwrite `configure_optimizers` | `None` |
| `metrics` | `dict[str, list[str]] \| None` | a dict containing the metrics to be computed during train, val, and test | `None` |
| `persistent_tracking` | `dict[str, bool] \| None` | a dict containing whether to use persistent tracking during train, val and test inference | `None` |
| `test_save_path` | `str` | path to a directory to save the eval and tracking results to | `'./test_results.h5'` |
Source code in dreem/models/gtr_runner.py

```python
def __init__(
    self,
    model_cfg: dict | None = None,
    tracker_cfg: dict | None = None,
    loss_cfg: dict | None = None,
    optimizer_cfg: dict | None = None,
    scheduler_cfg: dict | None = None,
    metrics: dict[str, list[str]] | None = None,
    persistent_tracking: dict[str, bool] | None = None,
    test_save_path: str = "./test_results.h5",
):
    """Initialize a lightning module for GTR.

    Args:
        model_cfg: hyperparameters for GlobalTrackingTransformer
        tracker_cfg: the parameters used for the tracker post-processing
        loss_cfg: hyperparameters for AssoLoss
        optimizer_cfg: hyperparameters for the optimizer;
            only used to overwrite `configure_optimizers`
        scheduler_cfg: hyperparameters for the lr_scheduler;
            only used to overwrite `configure_optimizers`
        metrics: a dict containing the metrics to be computed during train, val, and test
        persistent_tracking: a dict containing whether to use persistent tracking
            during train, val and test inference
        test_save_path: path to a directory to save the eval and tracking results to
    """
    super().__init__()
    self.save_hyperparameters()

    self.model_cfg = model_cfg if model_cfg else {}
    self.loss_cfg = loss_cfg if loss_cfg else {}
    self.tracker_cfg = tracker_cfg if tracker_cfg else {}

    self.model = GlobalTrackingTransformer(**self.model_cfg)
    self.loss = AssoLoss(**self.loss_cfg)
    if self.tracker_cfg.get("tracker_type", "standard") == "batch":
        self.tracker = BatchTracker(**self.tracker_cfg)
    else:
        self.tracker = Tracker(**self.tracker_cfg)

    self.optimizer_cfg = optimizer_cfg
    self.scheduler_cfg = scheduler_cfg
    self.metrics = metrics if metrics is not None else self.DEFAULT_METRICS
    self.persistent_tracking = (
        persistent_tracking
        if persistent_tracking is not None
        else self.DEFAULT_TRACKING
    )
    self.test_results = {"preds": [], "save_path": test_save_path}
```
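A sketch of a more explicit construction. The only config key shown is `tracker_type`, which `__init__` reads to choose between `Tracker` and `BatchTracker`; note that the full `tracker_cfg` dict (including `tracker_type`) is then unpacked into the chosen tracker's constructor, so any other keys must match that tracker's signature.

```python
from dreem.models.gtr_runner import GTRRunner

runner = GTRRunner(
    tracker_cfg={"tracker_type": "batch"},  # anything else maps to the standard Tracker
    metrics={"train": [], "val": [], "test": ["global_tracking_accuracy"]},
    persistent_tracking={"train": False, "val": False, "test": True},
    test_save_path="./eval_results",  # treated as a directory by on_test_end
)
```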
### `configure_optimizers()`

Get optimizers and schedulers for training. Is overridden by config but defaults to Adam + ReduceLROnPlateau.

Returns:

| Type | Description |
| --- | --- |
| `dict` | an optimizer config dict containing the optimizer, scheduler, and scheduler params |
Source code in dreem/models/gtr_runner.py

```python
def configure_optimizers(self) -> dict:
    """Get optimizers and schedulers for training.

    Is overridden by config but defaults to Adam + ReduceLROnPlateau.

    Returns:
        an optimizer config dict containing the optimizer, scheduler, and scheduler params
    """
    # todo: init from config
    if self.optimizer_cfg is None:
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-4, betas=(0.9, 0.999))
    else:
        optimizer = init_optimizer(self.parameters(), self.optimizer_cfg)

    if self.scheduler_cfg is None:
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, "min", 0.5, 10
        )
    else:
        scheduler = init_scheduler(optimizer, self.scheduler_cfg)

    return {
        "optimizer": optimizer,
        "lr_scheduler": {
            "scheduler": scheduler,
            "monitor": "val_loss",
            "interval": "epoch",
            "frequency": 1,
        },
    }
```
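For reference, the default branch above is equivalent to the following, with `ReduceLROnPlateau`'s positional arguments (`"min"`, `0.5`, `10`) written out as keywords; `params` stands in for `self.parameters()`. When `optimizer_cfg` or `scheduler_cfg` is provided, it is forwarded to `init_optimizer` / `init_scheduler` instead; consult those helpers for the expected keys.

```python
import torch

optimizer = torch.optim.Adam(params, lr=1e-4, betas=(0.9, 0.999))
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="min", factor=0.5, patience=10  # halve the lr after 10 stale epochs
)
```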
### `forward(ref_instances, query_instances=None)`

Execute forward pass of the lightning module.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `ref_instances` | `list[Instance]` | a list of `Instance` objects containing crops and other data needed for the transformer model | required |
| `query_instances` | `list[Instance] \| None` | a list of `Instance` objects used as queries in the decoder; mostly used for inference | `None` |

Returns:

| Type | Description |
| --- | --- |
| `list[AssociationMatrix]` | A list of association matrices between objects |
Source code in dreem/models/gtr_runner.py

```python
def forward(
    self,
    ref_instances: list["dreem.io.Instance"],
    query_instances: list["dreem.io.Instance"] | None = None,
) -> list["AssociationMatrix"]:
    """Execute forward pass of the lightning module.

    Args:
        ref_instances: a list of `Instance` objects containing crops and other
            data needed for the transformer model
        query_instances: a list of `Instance` objects used as queries in the
            decoder. Mostly used for inference.

    Returns:
        A list of association matrices between objects
    """
    asso_preds = self.model(ref_instances, query_instances)
    return asso_preds
```
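A usage sketch, assuming `runner` is a constructed `GTRRunner` and `frames` is one clip, i.e. a list of `dreem.io.Frame` objects taken from a dataloader batch:

```python
# Flatten the clip into instances, mirroring _shared_eval_step in the class source above.
ref_instances = [inst for frame in frames for inst in frame.instances]
asso_preds = runner(ref_instances)  # -> list[AssociationMatrix]
# Raw association logits, as consumed by AssoLoss:
logits = [asso.matrix for asso in asso_preds]
```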
### `log_metrics(result, batch_size, mode)`

Log metrics computed during evaluation.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `result` | `dict` | A dict containing metrics to be logged. | required |
| `batch_size` | `int` | the size of the batch used to compute the metrics | required |
| `mode` | `str` | One of {'train', 'val', 'test'}. Used as prefix while logging. | required |
Source code in dreem/models/gtr_runner.py

```python
def log_metrics(self, result: dict, batch_size: int, mode: str) -> None:
    """Log metrics computed during evaluation.

    Args:
        result: A dict containing metrics to be logged.
        batch_size: the size of the batch used to compute the metrics
        mode: One of {'train', 'val', 'test'}. Used as prefix while logging.
    """
    if result:
        batch_size = result.pop("batch_size")
        for metric, val in result.items():
            if isinstance(val, torch.Tensor):
                val = val.item()
            self.log(f"{mode}_{metric}", val, batch_size=batch_size)
```
### `on_test_end()`

Run inference and metrics pipeline to compute metrics for test set.

Takes no arguments. Consumes `self.test_results` (the predictions accumulated during test steps, plus the save path) and `self.metrics["test"]` (the list of metrics to pass to `metrics.evaluate`).
Source code in dreem/models/gtr_runner.py

```python
def on_test_end(self):
    """Run inference and metrics pipeline to compute metrics for test set.

    Consumes `self.test_results` (predictions accumulated during test steps,
    plus the save path) and `self.metrics["test"]` (the metrics to compute).
    """
    # input validation
    metrics_to_compute = self.metrics["test"]  # list of metrics, or "all"
    if metrics_to_compute == "all":
        metrics_to_compute = ["motmetrics", "global_tracking_accuracy"]
    if isinstance(metrics_to_compute, str):
        metrics_to_compute = [metrics_to_compute]
    for metric in metrics_to_compute:
        if metric not in ["motmetrics", "global_tracking_accuracy"]:
            raise ValueError(
                f"Metric {metric} not supported. Please select from "
                "'motmetrics' or 'global_tracking_accuracy'"
            )

    preds = self.test_results["preds"]
    # results is a dict keyed by metric name with the computed value
    results = metrics.evaluate(preds, metrics_to_compute)

    # save metrics and frame metadata to hdf5;
    # get the video name from the first frame
    vid_name = Path(preds[0].vid_name).stem
    fname = os.path.join(
        self.test_results["save_path"], f"{vid_name}.dreem_metrics.h5"
    )
    logger.info(f"Saving metrics to {fname}")

    # if the h5 file exists, add a suffix to prevent a name collision
    suffix_counter = 0
    original_fname = fname
    while os.path.exists(fname):
        suffix_counter += 1
        fname = original_fname.replace(
            ".dreem_metrics.h5", f"_{suffix_counter}.dreem_metrics.h5"
        )
    if suffix_counter > 0:
        logger.info(f"File already exists. Saving to {fname} instead")

    with h5py.File(fname, "a") as results_file:
        # create a group for this video
        vid_group = results_file.require_group(vid_name)
        # save each metric
        for metric_name, value in results.items():
            if metric_name == "motmetrics":
                # save mot_summary and mot_events separately
                mot_summary = value[0]
                mot_events = value[1]
                frame_switch_map = value[2]
                mot_summary_group = vid_group.require_group("mot_summary")
                # save each row of mot_summary as an attribute
                for _, row in mot_summary.iterrows():
                    mot_summary_group.attrs[row.name] = row["acc"]
                # save extra metadata for frames in which there is a switch
                for frame_id, switch in frame_switch_map.items():
                    frame = preds[frame_id]
                    frame = frame.to("cpu")
                    if switch:
                        _ = frame.to_h5(
                            vid_group,
                            frame.get_gt_track_ids().cpu().numpy(),
                            save={
                                "crop": True,
                                "features": True,
                                "embeddings": True,
                            },
                        )
                    else:
                        _ = frame.to_h5(
                            vid_group, frame.get_gt_track_ids().cpu().numpy()
                        )
                # save the motevents log to csv
                motevents_path = os.path.join(
                    self.test_results["save_path"], f"{vid_name}.motevents.csv"
                )
                logger.info(f"Saving motevents log to {motevents_path}")
                mot_events.to_csv(motevents_path, index=False)
            elif metric_name == "global_tracking_accuracy":
                gta_by_gt_track = value
                gta_group = vid_group.require_group("global_tracking_accuracy")
                # save as key-value pairs of gt track id: gta
                for gt_track_id, gta in gta_by_gt_track.items():
                    gta_group.attrs[f"track_{gt_track_id}"] = gta

    # save the tracking results to a slp file
    pred_slp = []
    outpath = os.path.join(
        self.test_results["save_path"],
        f"{vid_name}.dreem_inference.{datetime.now().strftime('%m-%d-%Y-%H-%M-%S')}.slp",
    )
    logger.info(f"Saving inference results to {outpath}")
    tracks = {}
    for frame in preds:
        if frame.frame_id.item() == 0:
            video = (
                sio.Video(frame.video)
                if isinstance(frame.video, str)
                else frame.video  # already a Video object
            )
        lf, tracks = frame.to_slp(tracks, video=video)
        pred_slp.append(lf)
    pred_slp = sio.Labels(pred_slp)
    pred_slp.save(outpath)
    # clear the preds
    self.test_results["preds"] = []
```
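A sketch of reading the metrics file written above, using only the layout this method establishes. The video stem `clip01` is hypothetical; files land inside `test_save_path`, which is treated as a directory despite its `./test_results.h5` default.

```python
import h5py

with h5py.File("./test_results.h5/clip01.dreem_metrics.h5", "r") as f:
    vid = f["clip01"]
    # one attribute per mot_summary row, keyed by row name with its "acc" value
    for name, value in vid["mot_summary"].attrs.items():
        print(name, value)
    # per-track accuracies saved as attrs named f"track_{gt_track_id}"
    gta = dict(vid["global_tracking_accuracy"].attrs)
```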
### `on_validation_epoch_end()`

Execute hook for validation end. Currently, we simply clear the GPU cache and do garbage collection.
### `predict_step(batch, batch_idx)`

Run inference for model. Computes association + assignment.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `batch` | `list[list[Frame]]` | A single batch from the dataset which is a list of `Frame` objects with length `clip_length` containing Instances and other metadata. | required |
| `batch_idx` | `int` | the batch number used by lightning | required |

Returns:

| Type | Description |
| --- | --- |
| `list[Frame]` | A list of `Frame` objects containing the predicted track ids |
Source code in dreem/models/gtr_runner.py

```python
def predict_step(
    self, batch: list[list["dreem.io.Frame"]], batch_idx: int
) -> list["dreem.io.Frame"]:
    """Run inference for model.

    Computes association + assignment.

    Args:
        batch: A single batch from the dataset which is a list of `Frame` objects
            with length `clip_length` containing Instances and other metadata.
        batch_idx: the batch number used by lightning

    Returns:
        A list of `Frame` objects containing the predicted track ids
    """
    frames_pred = self.tracker(self.model, batch[0])
    return frames_pred
```
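A sketch of batched inference through Lightning, assuming the `pytorch_lightning` package (adjust the import if your install uses `lightning.pytorch`). `ckpt_path` and `test_loader` are placeholders: a trained checkpoint (loadable because `__init__` calls `save_hyperparameters()`) and a dataloader yielding batches shaped as described above.

```python
import pytorch_lightning as pl

from dreem.models.gtr_runner import GTRRunner

runner = GTRRunner.load_from_checkpoint(ckpt_path)  # placeholder checkpoint path
trainer = pl.Trainer(accelerator="auto", devices=1)
# One entry per batch; each entry is a list of Frames carrying predicted track ids.
preds = trainer.predict(runner, dataloaders=test_loader)
```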
### `test_step(test_batch, batch_idx)`

Execute single test step for model.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `test_batch` | `list[list[Frame]]` | A single batch from the dataset which is a list of `Frame` objects with length `clip_length` containing Instances and other metadata. | required |
| `batch_idx` | `int` | the batch number used by lightning | required |

Returns:

| Type | Description |
| --- | --- |
| `dict[str, float]` | A dict containing the test loss plus any other metrics specified |
Source code in dreem/models/gtr_runner.py

```python
def test_step(
    self, test_batch: list[list["dreem.io.Frame"]], batch_idx: int
) -> dict[str, float]:
    """Execute single test step for model.

    Args:
        test_batch: A single batch from the dataset which is a list of `Frame`
            objects with length `clip_length` containing Instances and other metadata.
        batch_idx: the batch number used by lightning

    Returns:
        A dict containing the test loss plus any other metrics specified
    """
    result = self._shared_eval_step(test_batch[0], mode="test")
    self.log_metrics(result, len(test_batch[0]), "test")
    return result
```
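A sketch of driving the full test pipeline: `Trainer.test` calls `test_step` per batch (accumulating predictions in `self.test_results`) and then `on_test_end`, which writes the metrics `.h5`, `motevents.csv`, and `.slp` files. `runner` and `test_loader` are the same placeholders as in the prediction sketch above.

```python
import pytorch_lightning as pl

trainer = pl.Trainer(accelerator="auto", devices=1)
trainer.test(runner, dataloaders=test_loader)  # artifacts land in test_save_path
```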
### `training_step(train_batch, batch_idx)`

Execute single training step for model.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `train_batch` | `list[list[Frame]]` | A single batch from the dataset which is a list of `Frame` objects with length `clip_length` containing Instances and other metadata. | required |
| `batch_idx` | `int` | the batch number used by lightning | required |

Returns:

| Type | Description |
| --- | --- |
| `dict[str, float]` | A dict containing the train loss plus any other metrics specified |
Source code in dreem/models/gtr_runner.py

```python
def training_step(
    self, train_batch: list[list["dreem.io.Frame"]], batch_idx: int
) -> dict[str, float]:
    """Execute single training step for model.

    Args:
        train_batch: A single batch from the dataset which is a list of `Frame`
            objects with length `clip_length` containing Instances and other metadata.
        batch_idx: the batch number used by lightning

    Returns:
        A dict containing the train loss plus any other metrics specified
    """
    result = self._shared_eval_step(train_batch[0], mode="train")
    self.log_metrics(result, len(train_batch[0]), "train")
    return result
```
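A training sketch, which also exercises `validation_step` (documented next). `train_loader` and `val_loader` are placeholder dataloaders yielding one clip of `Frame` objects per batch; the `Trainer` settings are illustrative.

```python
import pytorch_lightning as pl

from dreem.models.gtr_runner import GTRRunner

runner = GTRRunner()  # pass model/loss/optimizer configs as shown in __init__
trainer = pl.Trainer(max_epochs=100, accelerator="auto", devices=1)
trainer.fit(runner, train_dataloaders=train_loader, val_dataloaders=val_loader)
```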
### `validation_step(val_batch, batch_idx)`

Execute single val step for model.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `val_batch` | `list[list[Frame]]` | A single batch from the dataset which is a list of `Frame` objects with length `clip_length` containing Instances and other metadata. | required |
| `batch_idx` | `int` | the batch number used by lightning | required |

Returns:

| Type | Description |
| --- | --- |
| `dict[str, float]` | A dict containing the val loss plus any other metrics specified |
Source code in dreem/models/gtr_runner.py

```python
def validation_step(
    self, val_batch: list[list["dreem.io.Frame"]], batch_idx: int
) -> dict[str, float]:
    """Execute single val step for model.

    Args:
        val_batch: A single batch from the dataset which is a list of `Frame`
            objects with length `clip_length` containing Instances and other metadata.
        batch_idx: the batch number used by lightning

    Returns:
        A dict containing the val loss plus any other metrics specified
    """
    result = self._shared_eval_step(val_batch[0], mode="val")
    self.log_metrics(result, len(val_batch[0]), "val")
    return result
```