train

dreem.training.train

Training script for training models.

Used to train a single model or to deploy a batch training job via the RunAI CLI.

run(cfg)

Train model based on config.

Handles all config parsing and initialization, then calls trainer.fit(). If batch_config is included in the config, the run is assumed to be part of a batch job.

Parameters:

    Name  Type        Description                       Default
    cfg   DictConfig  The config dict parsed by hydra.  required
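For illustration, the config can be composed programmatically the same way @hydra.main builds it from the command line. The snippet below is a minimal sketch, assuming a configs/ directory containing a base.yaml; both names and the model.d_model override are hypothetical placeholders, not taken from the dreem docs.

from hydra import compose, initialize

# "configs" and "base" are hypothetical; point these at your own files.
with initialize(config_path="configs", version_base=None):
    cfg = compose(config_name="base", overrides=["model.d_model=512"])

# cfg is now the DictConfig that @hydra.main would pass into run(cfg).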
Source code in dreem/training/train.py
@hydra.main(config_path=None, config_name=None, version_base=None)
def run(cfg: DictConfig):
    """Train model based on config.

    Handles all config parsing and initialization, then calls `trainer.fit()`.
    If `batch_config` is included, the run is assumed to be part of a batch job.

    Args:
        cfg: The config dict parsed by `hydra`
    """
    torch.set_float32_matmul_precision("medium")
    train_cfg = Config(cfg)

    # update with parameters for batch train job
    if "batch_config" in cfg.keys():
        try:
            index = int(os.environ["POD_INDEX"])
        except KeyError as e:
            index = int(
                input(f"{e}. Assuming single run!\nPlease input task index to run:")
            )

        hparams_df = pd.read_csv(cfg.batch_config)
        hparams = hparams_df.iloc[index].to_dict()

        if train_cfg.set_hparams(hparams):
            print("Updated the following hparams to the following values")
            pprint(hparams)
    else:
        hparams = {}
    pprint(f"Final train config: {train_cfg}")

    train_dataset = train_cfg.get_dataset(mode="train")
    train_dataloader = train_cfg.get_dataloader(train_dataset, mode="train")

    val_dataset = train_cfg.get_dataset(mode="val")
    val_dataloader = train_cfg.get_dataloader(val_dataset, mode="val")

    test_dataset = train_cfg.get_dataset(mode="test")
    test_dataloader = train_cfg.get_dataloader(test_dataset, mode="test")

    dataset = TrackingDataset(
        train_dl=train_dataloader, val_dl=val_dataloader, test_dl=test_dataloader
    )

    if cfg.view_batch.enable:
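        # render a few frames from one training batch as a sanity check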
        instances = next(iter(train_dataset))
        view_training_batch(instances, num_frames=cfg.view_batch.num_frames)

        if cfg.view_batch.no_train:
            return

    model = train_cfg.get_gtr_runner()  # TODO see if we can use torch.compile()

    logger = train_cfg.get_logger()

    callbacks = []
    callbacks.extend(train_cfg.get_checkpointing())
    callbacks.append(pl.callbacks.LearningRateMonitor())
    callbacks.append(train_cfg.get_early_stopping())

    accelerator = "gpu" if torch.cuda.is_available() else "cpu"
    devices = torch.cuda.device_count() if torch.cuda.is_available() else cpu_count()
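    # note: with no GPU visible, this falls back to one process per CPU core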

    trainer = train_cfg.get_trainer(
        callbacks,
        logger,
        accelerator=accelerator,
        devices=devices,
    )

    trainer.fit(model, dataset)
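To make the batch path concrete, the following is a hedged sketch of what a batch_config CSV and the resulting hparams dict could look like; the column names are illustrative, not taken from the dreem docs.

import io

import pandas as pd

# Hypothetical batch_config contents: one row of hyperparameter
# overrides per task; POD_INDEX selects the row at runtime.
csv_text = """model.d_model,optimizer.lr
256,0.0001
512,0.00005
"""

hparams_df = pd.read_csv(io.StringIO(csv_text))
hparams = hparams_df.iloc[1].to_dict()  # as if POD_INDEX=1
print(hparams)  # {'model.d_model': 512, 'optimizer.lr': 5e-05}

train_cfg.set_hparams(hparams) then applies these key/value pairs as overrides on top of the composed config.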