Skip to content

eval

dreem.inference.eval

Script to evaluate a model.

run(cfg)

Run inference based on config file.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `cfg` | `DictConfig` | A DictConfig loaded by Hydra containing the checkpoint path and data configuration. | *required* |
Source code in dreem/inference/eval.py
@hydra.main(config_path=None, config_name=None, version_base=None)
def run(cfg: DictConfig) -> dict[int, sio.Labels]:
    """Run inference based on config file.

    Loads a ``GTRRunner`` checkpoint and overrides three things from ``cfg``:
    its tracker, its test metrics and its results save path. It then calls
    ``trainer.test`` once per (labels file, video file) pair of the test
    dataset.

    How the checkpoint is chosen:
    - If ``cfg`` has a ``checkpoints`` key, that value is read as a CSV file.
      The row is picked by the integer in the ``POD_INDEX`` environment
      variable. If that variable is unset, the user is prompted for an index.
    - Otherwise ``cfg.ckpt_path`` is used.

    Args:
        cfg: A DictConfig loaded from Hydra containing the checkpoint path and
            data configuration.

    Returns:
        Nothing, as written. NOTE(review): the annotation declares
        ``dict[int, sio.Labels]``, but the function has no ``return``
        statement. The value returned by ``trainer.test`` is also not kept.
        Confirm the intended return value.
    """
    eval_cfg = Config(cfg)

    if "checkpoints" in cfg.keys():
        try:
            index = int(os.environ["POD_INDEX"])
        # For testing without deploying a job on runai
        except KeyError:
            # input() returns a str; convert so iloc receives an integer position.
            index = int(input("Pod Index Not found! Please choose a pod index: "))

        logger.info(f"Pod Index: {index}")

        checkpoints = pd.read_csv(cfg.checkpoints)
        # NOTE(review): iloc[index] on a DataFrame returns a whole row (a
        # pandas Series), not a path string. Confirm the CSV layout and whether
        # a specific column should be selected before load_from_checkpoint.
        checkpoint = checkpoints.iloc[index]
    else:
        checkpoint = eval_cfg.cfg.ckpt_path

    model = GTRRunner.load_from_checkpoint(checkpoint)
    # Replace the tracker stored in the checkpoint with the one from cfg.
    model.tracker_cfg = eval_cfg.cfg.tracker
    model.tracker = Tracker(**model.tracker_cfg)
    logger.info("Using the following tracker:")
    logger.info(model.tracker)
    model.metrics["test"] = eval_cfg.cfg.runner.metrics.test
    logger.info("Computing the following metrics:")
    # Subscript access, matching the assignment above; model.metrics may be a
    # plain dict, which has no `.test` attribute.
    logger.info(model.metrics["test"])
    model.test_results["save_path"] = eval_cfg.cfg.runner.save_path
    logger.info(f"Saving results to {model.test_results['save_path']}")

    labels_files, vid_files = eval_cfg.get_data_paths(eval_cfg.cfg.dataset.test_dataset)
    trainer = eval_cfg.get_trainer()
    # Test each (labels, video) pair separately with its own dataset and loader.
    for label_file, vid_file in zip(labels_files, vid_files):
        dataset = eval_cfg.get_dataset(
            label_files=[label_file], vid_files=[vid_file], mode="test"
        )
        dataloader = eval_cfg.get_dataloader(dataset, mode="test")
        trainer.test(model, dataloader)