train

dreem.training.train

Training script for training models.

Used for training a single model or deploying a batch train job via the RunAI CLI.
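
Both modes go through the same Hydra entry point (run, documented below). Because the decorator passes config_path=None and config_name=None, the config location is supplied at launch time. As a minimal sketch, assuming a hypothetical /abs/path/to/configs directory whose base.yaml defines the keys read in the source below, a config can be composed in Python and handed straight to run:

# Sketch only: the config directory and file name are placeholders, not part of
# the dreem documentation.
from hydra import compose, initialize_config_dir

from dreem.training.train import run

with initialize_config_dir(config_dir="/abs/path/to/configs", version_base=None):
    # Compose the same DictConfig that Hydra would build at the command line.
    cfg = compose(config_name="base")

# Hydra-decorated functions also accept an explicit config, which skips CLI parsing.
run(cfg)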

Functions:

Name    Description
run     Train model based on config.

run(cfg)

Train model based on config.

Handles all config parsing and initialization, then calls trainer.fit(). If batch_config is included in the config, the run is treated as a batch job (sketched below the parameter table).

Parameters:

Name  Type        Description                       Default
cfg   DictConfig  The config dict parsed by Hydra   required
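
For batch jobs, batch_config points at a CSV in which each row holds one set of hyperparameter overrides, and the POD_INDEX environment variable selects the row for the current pod (falling back to an interactive prompt when it is unset). A sketch of that selection step, with a placeholder CSV path and hypothetical column names:

# Placeholder path and columns; mirrors the row-selection logic in the source below.
import os

import pandas as pd

os.environ.setdefault("POD_INDEX", "0")  # normally set in the pod's environment

hparams_df = pd.read_csv("/path/to/batch_runs.csv")  # e.g. columns: lr, batch_size
hparams = hparams_df.iloc[int(os.environ["POD_INDEX"])].to_dict()
print(hparams)  # e.g. {"lr": 0.0001, "batch_size": 8} for the selected row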
Source code in dreem/training/train.py
@hydra.main(config_path=None, config_name=None, version_base=None)
def run(cfg: DictConfig):
    """Train model based on config.

    Handles all config parsing and initialization then calls `trainer.train()`.
    If `batch_config` is included then run will be assumed to be a batch job.

    Args:
        cfg: The config dict parsed by `hydra`
    """
    torch.set_float32_matmul_precision("medium")
    train_cfg = Config(cfg)

    # update with parameters for batch train job
    if "batch_config" in cfg.keys():
        try:
            index = int(os.environ["POD_INDEX"])
        except KeyError as e:
            index = int(
                input(f"{e}. Assuming single run!\nPlease input task index to run:")
            )

        hparams_df = pd.read_csv(cfg.batch_config)
        hparams = hparams_df.iloc[index].to_dict()

        if train_cfg.set_hparams(hparams):
            logger.debug("Updated the following hparams to the following values")
            logger.debug(hparams)
    else:
        hparams = {}
    logging.getLogger().setLevel(level=cfg.get("log_level", "INFO").upper())
    logger.info(f"Final train config: {train_cfg}")

    model = train_cfg.get_model()

    train_dataset = train_cfg.get_dataset(mode="train")
    train_dataloader = train_cfg.get_dataloader(train_dataset, mode="train")

    val_dataset = train_cfg.get_dataset(mode="val")
    val_dataloader = train_cfg.get_dataloader(val_dataset, mode="val")

    dataset = TrackingDataset(train_dl=train_dataloader, val_dl=val_dataloader)

    if cfg.view_batch.enable:
        instances = next(iter(train_dataset))
        view_training_batch(instances, num_frames=cfg.view_batch.num_frames)

        if cfg.view_batch.no_train:
            return

    model = train_cfg.get_gtr_runner()  # TODO see if we can use torch.compile()

    run_logger = train_cfg.get_logger()

    if run_logger is not None and isinstance(run_logger, pl.loggers.wandb.WandbLogger):
        data_paths = train_cfg.data_paths
        flattened_paths = [
            [item] for sublist in data_paths.values() for item in sublist
        ]
        run_logger.log_text(
            "training_files", columns=["data_paths"], data=flattened_paths
        )

    callbacks = []
    _ = callbacks.extend(train_cfg.get_checkpointing())
    _ = callbacks.append(pl.callbacks.LearningRateMonitor())

    early_stopping = train_cfg.get_early_stopping()
    if early_stopping is not None:
        callbacks.append(early_stopping)

    accelerator = "gpu" if torch.cuda.is_available() else "cpu"
    devices = torch.cuda.device_count() if torch.cuda.is_available() else cpu_count()

    trainer = train_cfg.get_trainer(
        callbacks,
        run_logger,
        accelerator=accelerator,
        devices=devices,
    )

    trainer.fit(model, dataset)
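
To sanity-check data loading before committing to a full run, the view_batch options read above can be toggled so that a training batch is visualized and the function returns before the trainer is built. A sketch, assuming the same placeholder config location as earlier and a base config that already defines a view_batch section:

# Sketch only: the config path is a placeholder; the view_batch keys mirror the
# ones read in the source above.
from omegaconf import OmegaConf

from dreem.training.train import run

cfg = OmegaConf.load("/abs/path/to/configs/base.yaml")
cfg.view_batch.enable = True     # visualize a training batch
cfg.view_batch.num_frames = 4    # number of frames to display
cfg.view_batch.no_train = True   # return before the trainer is built

run(cfg)  # handing the config straight to run, as in the earlier sketch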