Source code for kliff.trainer.utils.lightning_utils

import os
import pickle as pkl
from typing import Any, Union

import pytorch_lightning as pl
import torch

from kliff.dataset import Dataset


class SaveModelCallback(pl.Callback):
    """
    Callback to save model weights at the end of every validation epoch.

    The following files are written into ``ckpt_dir``:

    - ``last_model.pth``: state dict after the most recent validation epoch.
    - ``best_model.pth``: state dict for the lowest ``val_loss`` seen so far
      by this callback instance.
    - ``epoch_{epoch}.pth``: state dict saved whenever the current epoch is a
      multiple of ``ckpt_interval``.
    - ``trainer_checkpoint.ckpt``: trainer checkpoint, rewritten on every call.

    Args:
        ckpt_dir: Directory that receives the files; created if it is missing.
        ckpt_interval: Number of epochs between the periodic
            ``epoch_{epoch}.pth`` snapshots.
    """

    def __init__(self, ckpt_dir, ckpt_interval=100):
        super().__init__()
        self.ckpt_dir = ckpt_dir
        self.best_val_loss = float("inf")
        self.ckpt_interval = ckpt_interval
        os.makedirs(self.ckpt_dir, exist_ok=True)

    def on_validation_epoch_end(
        self, trainer: pl.Trainer, pl_module: pl.LightningModule
    ):
        """
        Write the last / best / periodic model files and the trainer checkpoint.

        Args:
            trainer: Trainer whose ``callback_metrics`` are read for
                ``"val_loss"`` and whose checkpoint is saved.
            pl_module: Module whose ``state_dict()`` is saved.
        """
        # Save the last model
        last_save_path = os.path.join(self.ckpt_dir, "last_model.pth")
        torch.save(pl_module.state_dict(), last_save_path)

        # Save the best model. ``callback_metrics.get`` returns None when no
        # "val_loss" has been logged; comparing None with a float would raise
        # TypeError, so the best-model update is skipped in that case while the
        # remaining files are still written.
        val_loss = trainer.callback_metrics.get("val_loss")
        if val_loss is not None and val_loss < self.best_val_loss:
            self.best_val_loss = val_loss
            best_save_path = os.path.join(self.ckpt_dir, "best_model.pth")
            torch.save(pl_module.state_dict(), best_save_path)

        # Save the model every ckpt_interval epochs
        if pl_module.current_epoch % self.ckpt_interval == 0:
            epoch_save_path = os.path.join(
                self.ckpt_dir, f"epoch_{pl_module.current_epoch}.pth"
            )
            torch.save(pl_module.state_dict(), epoch_save_path)

        # save the trainer checkpoint as well
        trainer.save_checkpoint(os.path.join(self.ckpt_dir, "trainer_checkpoint.ckpt"))
class SavePerAtomPredictions(pl.Callback):
    """
    Callback that stores per-atom predictions through the supplied LMDB handle.

    Predictions are recorded after every training and validation batch, but
    only in epochs that are a multiple of ``ckpt_interval``. The database is
    usually named `per_atom_pred_database.lmdb` in the run dir.

    Args:
        lmdb_file_handle: Handle whose ``begin(write=True)`` yields the write
            transaction used for storing predictions.
        ckpt_interval: Number of epochs between two recordings.
    """

    def __init__(self, lmdb_file_handle, ckpt_interval):
        super().__init__()
        self.ckpt_interval = ckpt_interval
        self.lmdb_file_handle = lmdb_file_handle

    def on_validation_batch_end(
        self,
        trainer: pl.Trainer,
        pl_module: pl.LightningModule,
        outputs,
        batch,
        batch_idx,
        dataloader_idx=0,
    ):
        """Record the per-atom predictions of a validation batch when due."""
        self._record_if_due(trainer, outputs, batch)

    def on_train_batch_end(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        outputs: Any,
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """Record the per-atom predictions of a training batch when due."""
        self._record_if_due(trainer, outputs, batch)

    def _record_if_due(self, trainer, outputs, batch):
        """Shared hook body: log ``outputs["per_atom_pred"]`` on interval epochs."""
        if trainer.current_epoch % self.ckpt_interval != 0:
            return
        current_epoch = trainer.current_epoch
        self._log_per_atom_outputs(current_epoch, batch, outputs["per_atom_pred"])

    def _log_per_atom_outputs(
        self,
        epoch: Union[int, torch.Tensor],
        batch: Any,
        predicted_forces: torch.Tensor,
    ):
        """
        Split the batched predictions per configuration and store each slice.

        This function is duplicate of
        ~:class:`kliff.trainer.Trainer.log_per_atom_outputs`.

        Every configuration is written under the key
        ``epoch_{epoch}|index_{idx}`` as a pickled dict holding the prediction
        slice (``"pred_0"``) and its atom count (``"n_atoms"``).

        Args:
            epoch: Current epoch
            batch: Batch providing ``"idx"`` (configuration indices) and
                ``"contributions"`` (per-atom labels)
            predicted_forces: Predicted forces, sliced consecutively along the
                first axis
        """
        with self.lmdb_file_handle.begin(write=True) as txn:
            config_indices = batch["idx"]
            start = 0
            for position in range(len(config_indices)):
                # Configuration `position` is matched against label
                # 2 * position (every even index is contributing).
                labels = batch["contributions"]
                n_atoms = labels[labels == (2 * position)].shape[0]

                stop = start + n_atoms
                pred = predicted_forces[start:stop].detach().cpu().numpy()
                start = stop

                # save the predictions
                key = f"epoch_{epoch}|index_{config_indices[position]}".encode()
                payload = pkl.dumps({"pred_0": pred, "n_atoms": n_atoms})
                txn.put(key, payload)
# TODO: LTAU callback?