neuromancer.trainer module
- class neuromancer.trainer.CustomEarlyStopping(monitor, patience, warmup=0)
Bases: EarlyStopping
Custom early-stopping callback that subclasses PyTorch Lightning's EarlyStopping. It adds proper warmup support: early stopping cannot occur within the warmup grace period.
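The warmup rule can be illustrated with a small, dependency-free sketch. It is not the CustomEarlyStopping implementation. The class WarmupEarlyStopper and the helper first_stop_epoch are hypothetical. The sketch assumes that lower monitored values are better and that stagnation counts toward patience only once the warmup grace period is over.

```python
class WarmupEarlyStopper:
    """Hypothetical sketch of patience-based early stopping with a warmup
    grace period; not the neuromancer or PyTorch Lightning implementation."""

    def __init__(self, patience: int, warmup: int = 0) -> None:
        self.patience = patience
        self.warmup = warmup
        self.best = float("inf")  # assumption: lower monitored value is better
        self.bad_epochs = 0

    def step(self, epoch: int, value: float) -> bool:
        """Record the monitored value for `epoch`; return True to stop."""
        if value < self.best:
            self.best = value
            self.bad_epochs = 0
        elif epoch >= self.warmup:
            # Stagnation only counts once the warmup grace period is over.
            self.bad_epochs += 1
        # Stopping is never allowed inside the warmup grace period.
        return epoch >= self.warmup and self.bad_epochs >= self.patience


def first_stop_epoch(losses, patience, warmup=0):
    """Return the epoch at which training would stop, or None."""
    stopper = WarmupEarlyStopper(patience, warmup)
    for epoch, loss in enumerate(losses):
        if stopper.step(epoch, loss):
            return epoch
    return None
```

For the curve [1.0, 1.1, 1.2, 1.3, 0.9, 1.0, 1.0] with patience=2, warmup=0 stops at epoch 2, while warmup=3 lets training continue past the early plateau, reach the improvement at epoch 4, and stop at epoch 6.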
- class neuromancer.trainer.LitTrainer(epochs=1000, train_metric='train_loss', dev_metric='dev_loss', test_metric='test_loss', eval_metric='dev_loss', patience=None, warmup=0, clip=100.0, custom_optimizer=None, save_weights=True, weight_path='./', weight_name=None, devices='auto', strategy='auto', accelerator='auto', profiler=None, custom_training_step=None, custom_hooks=None, logger=None, hparam_config=None, automatic_optimization=True)
Bases: Trainer
- apply_custom_hooks(model)
Apply custom hooks to the model.
- Parameters:
model – The LightningModule to which the custom hooks are applied.
- fit(problem, data_setup_function, **kwargs)
Fits (trains) a base Neuromancer Problem to the data defined by a data setup function. This method also instantiates a Lightning version of the provided Problem and the LightningDataModule associated with the data setup function.
- Parameters:
problem – The Neuromancer Problem to train (fit).
data_setup_function – A function that returns the train, dev, and test Neuromancer DictDatasets, as well as the batch size to use.
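The data_setup_function contract can be sketched without any dependencies. Real code would wrap each split in a Neuromancer DictDataset; here plain dicts of lists stand in so the example runs on its own. The return order (train, dev, test, batch_size) is an assumption read from the parameter description. The helper make_toy_data, the feature key "x", and the split sizes are hypothetical.

```python
import random
from typing import Dict, List, Tuple

# Hypothetical stand-in for a Neuromancer DictDataset: feature name -> samples.
Split = Dict[str, List[float]]


def make_toy_data(n: int, seed: int) -> Split:
    """Hypothetical helper: n uniform samples of a single feature 'x'."""
    rng = random.Random(seed)
    return {"x": [rng.uniform(-1.0, 1.0) for _ in range(n)]}


def data_setup_function(n_samples: int = 300) -> Tuple[Split, Split, Split, int]:
    # Different seeds give three separate random draws for train/dev/test.
    train_data = make_toy_data(n_samples, seed=0)
    dev_data = make_toy_data(n_samples // 3, seed=1)
    test_data = make_toy_data(n_samples // 3, seed=2)
    batch_size = 32
    return train_data, dev_data, test_data, batch_size
```

With Neuromancer installed and real DictDatasets in place of the dicts, a function of this shape would be passed as `trainer.fit(problem, data_setup_function)`, matching the signature above.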
- class neuromancer.trainer.Trainer(problem: neuromancer.problem.Problem, train_data: torch.utils.data.dataloader.DataLoader, dev_data: torch.utils.data.dataloader.DataLoader | None = None, test_data: torch.utils.data.dataloader.DataLoader | None = None, optimizer: torch.optim.optimizer.Optimizer | None = None, logger: neuromancer.loggers.BasicLogger | None = None, callback=<neuromancer.callbacks.Callback object>, lr_scheduler=False, epochs=1000, epoch_verbose=1, patience=5, warmup=0, train_metric='train_loss', dev_metric='dev_loss', test_metric='test_loss', eval_metric='dev_loss', eval_mode='min', clip=100.0, multi_fidelity=False, device='cpu')
Bases: object
Class encapsulating boilerplate PyTorch training code. The training procedure can be extended through the methods of the associated Callback object, which are invoked at waypoints during training and evaluation.
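A minimal, dependency-free sketch of the kind of loop such a class wraps is shown below. It assumes eval_mode='min', patience counting that starts after warmup, and best-epoch tracking on the dev metric. The callback method names (begin_epoch, end_eval, end_train), the SketchCallback and RecordingCallback classes, and the train function are hypothetical and do not reproduce the neuromancer.callbacks.Callback API. The optimizer steps on train_data are omitted.

```python
from typing import Callable, List


class SketchCallback:
    """Hypothetical callback base: each method is a no-op waypoint."""

    def begin_epoch(self, epoch: int) -> None: ...
    def end_eval(self, epoch: int, dev_loss: float) -> None: ...
    def end_train(self, best_epoch: int) -> None: ...


class RecordingCallback(SketchCallback):
    """Overrides one waypoint to record every dev loss it sees."""

    def __init__(self) -> None:
        self.dev_losses: List[float] = []

    def end_eval(self, epoch: int, dev_loss: float) -> None:
        self.dev_losses.append(dev_loss)


def train(dev_loss_at: Callable[[int], float], callback: SketchCallback,
          epochs: int = 1000, patience: int = 5, warmup: int = 0) -> int:
    """Sketch of a 'min'-mode loop; returns the epoch with the best dev loss."""
    best_loss, best_epoch, bad_epochs = float("inf"), -1, 0
    for epoch in range(epochs):
        callback.begin_epoch(epoch)
        # A real loop would run optimizer steps over train_data here.
        dev_loss = dev_loss_at(epoch)  # stands in for evaluating on dev_data
        callback.end_eval(epoch, dev_loss)
        if dev_loss < best_loss:  # eval_mode='min': lower is better
            best_loss, best_epoch, bad_epochs = dev_loss, epoch, 0
        elif epoch >= warmup:
            bad_epochs += 1
        if epoch >= warmup and bad_epochs >= patience:
            break
    callback.end_train(best_epoch)
    return best_epoch


losses = [3.0, 2.0, 1.5, 1.6, 1.7, 1.8]
cb = RecordingCallback()
best = train(lambda e: losses[e], cb, epochs=len(losses), patience=2)
```

Keeping the loop fixed and pushing customization into callback methods is what lets extra behavior (logging, checkpointing, visualization) attach at a waypoint without editing the loop itself. In this run the best dev loss is at epoch 2, and two non-improving epochs end training after epoch 4.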