eval
dreem.inference.eval
¶
Script to evaluate model.
Functions:
Name | Description |
---|---|
run | Run inference based on config file. |
run(cfg)
¶
Run inference based on config file.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
cfg | DictConfig | A DictConfig loaded from Hydra containing the checkpoint path and data. | required |
Source code in dreem/inference/eval.py
@hydra.main(config_path=None, config_name=None, version_base=None)
def run(cfg: DictConfig) -> dict[int, sio.Labels]:
    """Run inference based on config file.

    Picks a checkpoint, loads a ``GTRRunner`` from it, attaches a tracker built
    from the ``tracker`` section of the config, and calls ``trainer.test`` once
    for every (labels file, video file) pair of the test dataset.

    Checkpoint selection:
        * If ``cfg`` has a ``checkpoints`` key, it is read as a CSV file and the
          row at the pod index is used. The pod index comes from the
          ``POD_INDEX`` environment variable, or from an interactive prompt
          when that variable is not set.
        * Otherwise the ``ckpt_path`` config entry is used.

    Args:
        cfg: A dictconfig loaded from hydra containing checkpoint path and data

    Raises:
        ValueError: If ``ckpt_path`` is needed but missing/``None``, or if the
            pod index (environment variable or prompt answer) is not an
            integer.
    """
    eval_cfg = Config(cfg)

    if "checkpoints" in cfg.keys():
        try:
            index = int(os.environ["POD_INDEX"])
        # For testing without deploying a job on runai
        except KeyError:
            # input() returns a str; positional indexing below needs an int,
            # so convert here exactly like the environment-variable branch.
            index = int(input("Pod Index Not found! Please choose a pod index: "))

        logger.info(f"Pod Index: {index}")

        checkpoints = pd.read_csv(cfg.checkpoints)
        # NOTE(review): iloc[index] selects a whole CSV row, not a single cell;
        # confirm that this is what load_from_checkpoint expects.
        checkpoint = checkpoints.iloc[index]
    else:
        checkpoint = eval_cfg.get("ckpt_path", None)
        if checkpoint is None:
            raise ValueError("Checkpoint path not found in config")

    logging.getLogger().setLevel(level=cfg.get("log_level", "INFO").upper())

    model = GTRRunner.load_from_checkpoint(checkpoint, strict=False)

    # Replace the tracker with one configured from the eval config.
    model.tracker_cfg = eval_cfg.cfg.tracker
    if model.tracker_cfg.get("tracker_type", "standard") == "batch":
        model.tracker = BatchTracker(**model.tracker_cfg)
    else:
        model.tracker = Tracker(**model.tracker_cfg)

    logger.info("Using the following tracker:")
    logger.info(model.tracker)

    model.metrics["test"] = eval_cfg.get("metrics", {}).get("test", "all")
    model.persistent_tracking["test"] = True

    logger.info("Computing the following metrics:")
    logger.info(model.metrics["test"])

    model.test_results["save_path"] = eval_cfg.get("outdir", ".")
    os.makedirs(model.test_results["save_path"], exist_ok=True)
    logger.info(
        f"Saving tracking results and metrics to {model.test_results['save_path']}"
    )

    labels_files, vid_files = eval_cfg.get_data_paths(
        "test", eval_cfg.cfg.dataset.test_dataset
    )
    trainer = eval_cfg.get_trainer()

    # One dataset/dataloader per (labels, video) pair, each tested separately.
    for label_file, vid_file in zip(labels_files, vid_files):
        dataset = eval_cfg.get_dataset(
            label_files=[label_file], vid_files=[vid_file], mode="test"
        )
        dataloader = eval_cfg.get_dataloader(dataset, mode="test")
        # NOTE(review): the signature advertises a dict return value, but no
        # return statement is visible here and `metrics` is not used further.
        metrics = trainer.test(model, dataloader)