train
dreem.training.train
Training script for model training.
Used to train a single model or to deploy a batch train job via the RunAI CLI.
run(cfg)

Train a model based on a config.

Handles all config parsing and initialization, then calls `trainer.fit()`. If `batch_config` is included in the config, the run is assumed to be a batch job.
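In batch mode, each row of the `batch_config` CSV holds one task's hyperparameter overrides, and the `POD_INDEX` environment variable selects the row. A minimal sketch of that mechanism (the CSV path and column names here are hypothetical, not a shipped schema):

```python
import os

import pandas as pd

# Hypothetical hparams.csv -- one row per batch task:
#     model.d_model,trainer.max_epochs
#     256,10
#     512,20
index = int(os.environ["POD_INDEX"])  # set per task by the batch launcher
hparams = pd.read_csv("hparams.csv").iloc[index].to_dict()
# e.g. {"model.d_model": 512, "trainer.max_epochs": 20} for POD_INDEX=1;
# run() then applies these overrides via train_cfg.set_hparams(hparams)
```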
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `cfg` | `DictConfig` | The config dict parsed by `hydra` | *required* |
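Beyond the keys consumed by `Config`, `run` reads a few fields from `cfg` directly. A minimal sketch of those fields (the values shown are illustrative, not the project's shipped defaults):

```python
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        # Read directly by run(); everything else is parsed by Config.
        "view_batch": {"enable": False, "num_frames": 1, "no_train": False},
        # Uncomment to switch run() into batch mode:
        # "batch_config": "hparams.csv",
    }
)
```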
Source code in `dreem/training/train.py`
```python
@hydra.main(config_path=None, config_name=None, version_base=None)
def run(cfg: DictConfig):
    """Train a model based on a config.

    Handles all config parsing and initialization, then calls `trainer.fit()`.
    If `batch_config` is included, the run is assumed to be a batch job.

    Args:
        cfg: The config dict parsed by `hydra`.
    """
    torch.set_float32_matmul_precision("medium")
    train_cfg = Config(cfg)

    # Update with parameters for a batch train job.
    if "batch_config" in cfg.keys():
        try:
            index = int(os.environ["POD_INDEX"])
        except KeyError as e:
            index = int(
                input(f"{e}. Assuming single run!\nPlease input task index to run:")
            )

        hparams_df = pd.read_csv(cfg.batch_config)
        hparams = hparams_df.iloc[index].to_dict()
        if train_cfg.set_hparams(hparams):
            print("Updated the following hparams to the following values")
            pprint(hparams)
    else:
        hparams = {}

    pprint(f"Final train config: {train_cfg}")

    model = train_cfg.get_model()

    # Build datasets and dataloaders for each split.
    train_dataset = train_cfg.get_dataset(mode="train")
    train_dataloader = train_cfg.get_dataloader(train_dataset, mode="train")

    val_dataset = train_cfg.get_dataset(mode="val")
    val_dataloader = train_cfg.get_dataloader(val_dataset, mode="val")

    test_dataset = train_cfg.get_dataset(mode="test")
    test_dataloader = train_cfg.get_dataloader(test_dataset, mode="test")

    dataset = TrackingDataset(
        train_dl=train_dataloader, val_dl=val_dataloader, test_dl=test_dataloader
    )

    # Optionally visualize a training batch before (or instead of) training.
    if cfg.view_batch.enable:
        instances = next(iter(train_dataset))
        view_training_batch(instances, num_frames=cfg.view_batch.num_frames)

        if cfg.view_batch.no_train:
            return

    model = train_cfg.get_gtr_runner()  # TODO see if we can use torch.compile()
    logger = train_cfg.get_logger()

    callbacks = []
    callbacks.extend(train_cfg.get_checkpointing())
    callbacks.append(pl.callbacks.LearningRateMonitor())
    callbacks.append(train_cfg.get_early_stopping())

    # Use all available GPUs, or fall back to CPU cores.
    accelerator = "gpu" if torch.cuda.is_available() else "cpu"
    devices = torch.cuda.device_count() if torch.cuda.is_available() else cpu_count()

    trainer = train_cfg.get_trainer(
        callbacks,
        logger,
        accelerator=accelerator,
        devices=devices,
    )

    trainer.fit(model, dataset)
```
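Since `config_path` and `config_name` are both `None` in the decorator, the config must be supplied at launch time. A minimal sketch of both entry points (the config directory, file name, and override key are hypothetical):

```python
# From the command line, via Hydra's standard flags:
#     python dreem/training/train.py --config-dir ./configs --config-name base
#
# Programmatically, by composing a config and passing it straight through
# (Hydra's cfg passthrough calls the task function directly when an
# argument is given, skipping CLI parsing):
from hydra import compose, initialize

from dreem.training.train import run

with initialize(config_path="configs", version_base=None):
    cfg = compose(config_name="base", overrides=["trainer.max_epochs=5"])

run(cfg)
```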