kliff.trainer.torch_trainer
- class kliff.trainer.torch_trainer.DNNTrainer(configuration, model=None)
This class trains a descriptor-based dense neural network (DNN) and deploys it using the TorchML framework. To use the DUNN model driver, use the legacy Loss optimizer for now. In the future, DNNTrainer will also support loading and exporting NN models for the DUNN model driver.
- Parameters:
configuration (dict) – A dictionary containing the configuration for the trainer.
model (torch.nn.Module) – A torch model to be trained. If not provided, the model is loaded from the model manifest; in that case it must be a TorchScript model or a valid TorchML model (a tar archive or a directory).
- setup_optimizer()
Set up the optimizer for the model, as defined in the optimizer manifest. This method obtains the optimizer object through _get_optimizer(), so that optimizer construction can be customized in the future.
TODO: Add support for custom optimizers, starting with CoRE.
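A common way to turn a declarative manifest entry into an optimizer object is to resolve the class by name from torch.optim. The sketch below assumes a hypothetical manifest layout ({"name": ..., "lr": ..., "kwargs": {...}}) and a hypothetical helper build_optimizer; it is not KLIFF's actual schema or the real _get_optimizer().

```python
import torch


def build_optimizer(model: torch.nn.Module, optimizer_manifest: dict) -> torch.optim.Optimizer:
    """Create a torch.optim optimizer from a manifest-style dict.

    The manifest layout {"name": ..., "lr": ..., "kwargs": {...}} is a
    hypothetical example, not KLIFF's actual schema.
    """
    name = optimizer_manifest["name"]
    # Resolve the class by name, e.g. "Adam" -> torch.optim.Adam.
    optimizer_cls = getattr(torch.optim, name, None)
    if not (isinstance(optimizer_cls, type) and issubclass(optimizer_cls, torch.optim.Optimizer)):
        raise ValueError(f"Unknown optimizer: {name!r}")
    lr = optimizer_manifest.get("lr", 1e-3)
    extra_kwargs = optimizer_manifest.get("kwargs", {})
    return optimizer_cls(model.parameters(), lr=lr, **extra_kwargs)


model = torch.nn.Linear(4, 1)
optimizer = build_optimizer(model, {"name": "Adam", "lr": 1e-3, "kwargs": {"weight_decay": 0.0}})
```

Keeping the manifest as plain data (a name plus keyword arguments) means a new optimizer can be selected without code changes, as long as its constructor accepts the given arguments.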
- loss(x, y, weight=1.0)
Compute the loss between the predicted and target values.
- Parameters:
x (torch.Tensor) – Predicted values.
y (torch.Tensor) – Target values.
weight (Union[float, torch.Tensor]) – Weight to apply to the loss. Default is 1.0.
- Returns:
Loss value
- Return type:
torch.Tensor
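This page does not state which loss function loss() computes. Assuming a weighted mean squared error, a minimal sketch of how a float or tensor weight combines with predictions and targets could look like this (weighted_mse_loss is a hypothetical name):

```python
import torch


def weighted_mse_loss(x: torch.Tensor, y: torch.Tensor, weight=1.0) -> torch.Tensor:
    """Weighted mean squared error between predictions x and targets y.

    Assumption: the loss is MSE. `weight` may be a Python float or a
    tensor broadcastable to x (e.g. one weight per sample).
    """
    return torch.mean(weight * (x - y) ** 2)


pred = torch.tensor([1.0, 2.0, 3.0])
target = torch.tensor([1.0, 2.0, 5.0])
uniform = weighted_mse_loss(pred, target)  # mean of [0, 0, 4] = 4/3
per_sample = weighted_mse_loss(pred, target, weight=torch.tensor([1.0, 1.0, 0.5]))  # mean of [0, 0, 2] = 2/3
```

A tensor weight lets individual samples (for example energies versus forces, or unreliable configurations) contribute less to the total loss.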
- checkpoint()
Save the model and optimizer state to disk and append the training and validation losses to the log file. The best and last models are also saved, together with the scheduler and early-stopping state when these are present.
- load_checkpoint(path)
Load the model and optimizer state from a checkpoint file.
- Parameters:
path (str) – Path to the checkpoint file.
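Saving and restoring state in PyTorch is usually done with state_dict(), torch.save, torch.load and load_state_dict(). The round trip below is a sketch under that assumption; the dictionary keys and the helper names save_checkpoint and restore_checkpoint are hypothetical and do not describe the actual layout of KLIFF's checkpoint files.

```python
import os
import tempfile

import torch


def save_checkpoint(path, model, optimizer, step, scheduler=None):
    """Write model, optimizer and (optionally) scheduler state to `path`."""
    state = {
        "step": step,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
    }
    if scheduler is not None:
        state["scheduler_state_dict"] = scheduler.state_dict()
    torch.save(state, path)


def restore_checkpoint(path, model, optimizer, scheduler=None):
    """Load state from `path` into existing objects and return the saved step."""
    state = torch.load(path, map_location="cpu")
    model.load_state_dict(state["model_state_dict"])
    optimizer.load_state_dict(state["optimizer_state_dict"])
    if scheduler is not None and "scheduler_state_dict" in state:
        scheduler.load_state_dict(state["scheduler_state_dict"])
    return state["step"]


model = torch.nn.Linear(3, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
with tempfile.TemporaryDirectory() as run_dir:
    path = os.path.join(run_dir, "checkpoint_10.pkl")
    save_checkpoint(path, model, optimizer, step=10)

    # A freshly initialized model picks up the saved weights.
    fresh_model = torch.nn.Linear(3, 1)
    fresh_optimizer = torch.optim.SGD(fresh_model.parameters(), lr=0.1)
    restored_step = restore_checkpoint(path, fresh_model, fresh_optimizer)
```

Storing the optimizer state alongside the weights matters for stateful optimizers such as Adam, whose running moment estimates would otherwise restart from zero after a resume.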
- get_last_checkpoint()
Get the most recent checkpoint file in the run directory. Checkpoint files are assumed to be named checkpoint_{step}.pkl.
- Returns:
Path to the last checkpoint file.
- Return type:
str
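Given the checkpoint_{step}.pkl naming convention, the latest checkpoint is the one with the largest numeric step. A sketch of that lookup, using a hypothetical find_last_checkpoint helper:

```python
import re
import tempfile
from pathlib import Path
from typing import Optional

# Matches names such as "checkpoint_10.pkl" and captures the step number.
_CHECKPOINT_RE = re.compile(r"^checkpoint_(\d+)\.pkl$")


def find_last_checkpoint(run_dir: str) -> Optional[str]:
    """Return the path of checkpoint_{step}.pkl with the largest step, or None."""
    best_step, best_path = -1, None
    for entry in Path(run_dir).iterdir():
        match = _CHECKPOINT_RE.match(entry.name)
        if match and int(match.group(1)) > best_step:
            best_step, best_path = int(match.group(1)), str(entry)
    return best_path


with tempfile.TemporaryDirectory() as run_dir:
    for step in (2, 9, 10):
        Path(run_dir, f"checkpoint_{step}.pkl").touch()
    Path(run_dir, "log.txt").touch()  # unrelated files are ignored
    last = find_last_checkpoint(run_dir)
```

Comparing integer steps rather than sorting file names matters: a plain string sort would place checkpoint_9.pkl after checkpoint_10.pkl.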
- train_step(batch)
Train the model for one step. The train() method calls this function for each batch in the training data loader.
- Parameters:
batch (dict) – Dictionary containing the batch data.
- Returns:
Loss value for the batch.
- Return type:
torch.Tensor
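A single PyTorch training step typically zeroes the gradients, runs the forward pass, computes the loss, backpropagates and updates the weights. The sketch below shows that sequence; the batch keys "x" and "y" and the explicit model, optimizer and loss_fn arguments are assumptions for illustration, since the real method reads these from the trainer and its batches carry descriptor data.

```python
import torch
import torch.nn.functional as F


def train_step(model, optimizer, batch, loss_fn):
    """One optimization step on a single batch (hypothetical batch keys)."""
    model.train()
    optimizer.zero_grad()             # clear gradients from the previous step
    pred = model(batch["x"])          # forward pass
    loss = loss_fn(pred, batch["y"])  # compare with targets
    loss.backward()                   # backpropagate
    optimizer.step()                  # update the weights
    return loss.detach()              # loss before this update, detached from the graph


torch.manual_seed(0)
model = torch.nn.Linear(2, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
batch = {"x": torch.randn(8, 2), "y": torch.randn(8, 1)}
losses = [train_step(model, optimizer, batch, F.mse_loss).item() for _ in range(20)]
```

Repeating the step on the same batch should drive the loss down, which is a quick sanity check for a new model or loss.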
- validation_step(batch)
Evaluate the model on one validation batch. The train() method calls this function for each batch in the validation data loader.
- Parameters:
batch (dict) – Dictionary containing the batch data.
- Returns:
Loss value for the batch.
- Return type:
torch.Tensor
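A validation step differs from a training step in that it neither builds a gradient graph nor updates weights. A minimal sketch, again assuming hypothetical batch keys "x" and "y" and an explicit loss_fn argument:

```python
import torch
import torch.nn.functional as F


def validation_step(model, batch, loss_fn):
    """Evaluate the loss on one batch without tracking gradients or updating weights."""
    model.eval()           # e.g. disables dropout if the model uses it
    with torch.no_grad():  # no autograd graph is recorded
        pred = model(batch["x"])
        loss = loss_fn(pred, batch["y"])
    return loss


model = torch.nn.Linear(2, 1)
batch = {"x": torch.randn(4, 2), "y": torch.randn(4, 1)}
weights_before = model.weight.clone()
val_loss = validation_step(model, batch, F.mse_loss)
```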
- train()
Train the model for the specified number of epochs; the training loop is defined in this method. In each epoch, the training and validation losses are computed and logged, and the model is checkpointed at the end of the epoch. If early stopping is enabled, training stops once the validation loss has not improved for the specified number of epochs. When training completes, a .finished file is created in the run directory.
- setup_model()
Load the TorchScript model from the model manifest. If a model was passed to the constructor, the manifest is ignored.
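The precedence rule (an explicitly provided model wins, otherwise a TorchScript file is loaded) can be sketched with torch.jit.load. Here load_model and model_path are hypothetical; model_path stands in for whatever location the model manifest actually specifies.

```python
import os
import tempfile
from typing import Optional

import torch


def load_model(model: Optional[torch.nn.Module], model_path: str) -> torch.nn.Module:
    """Prefer an explicitly provided model; otherwise load TorchScript from disk."""
    if model is not None:
        return model  # the manifest location is ignored
    return torch.jit.load(model_path, map_location="cpu")


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "model.pt")
    torch.jit.save(torch.jit.script(torch.nn.Linear(3, 1)), path)

    loaded = load_model(None, path)       # loaded from the TorchScript file
    explicit = torch.nn.Linear(3, 1)
    chosen = load_model(explicit, path)   # the provided model is returned as-is
```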
- save_kim_model(path='kim-model')
Save the KIM model to the given path as a portable TorchML model.
TODO: Add support for the DUNN model driver.
- Parameters:
path (str) – Path to save the model. Default is 'kim-model'.
- setup_parameter_transforms()
Set up the transformed parameter space for the model. In principle this applies to any model type, but because models handle their parameters in significantly different ways, the implementation is left to subclasses. To keep the initialization sequence consistent, this base method does not raise NotImplementedError; it silently does nothing. Subclasses that need parameter transforms must therefore override it.
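The behaviour described above is an optional hook: the base class always calls it during initialization but implements it as a no-op, so subclasses opt in by overriding. The sketch below illustrates that pattern with hypothetical classes; the log-space transform is only an example of a "transformed parameter space" and is not taken from KLIFF.

```python
import math


class BaseTrainer:
    def initialize(self):
        # Always called, so every trainer follows the same setup sequence.
        self.setup_parameter_transforms()

    def setup_parameter_transforms(self):
        # Deliberately a no-op rather than raising NotImplementedError:
        # trainers that need no transforms work without overriding it.
        pass


class LogSpaceTrainer(BaseTrainer):
    """Hypothetical subclass that optimizes positive parameters in log space."""

    def __init__(self, params):
        super().__init__()
        self.params = params
        self.transformed = None

    def setup_parameter_transforms(self):
        self.transformed = {name: math.log(value) for name, value in self.params.items()}


plain = BaseTrainer()
plain.initialize()  # silently does nothing

trainer = LogSpaceTrainer({"sigma": 1.0, "epsilon": math.e})
trainer.initialize()
```

The trade-off is the one the documentation warns about: a subclass that forgets to override the hook gets no error, just untransformed parameters.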