# Source code for kliff.trainer.kim_trainer

import importlib
import tarfile
from pathlib import Path

import numpy as np
from loguru import logger

from kliff.models import KIMModel

from .base_trainer import Trainer, TrainerError
from .utils.losses import MAE_loss, MSE_loss

# Optimizer names accepted by ``KIMTrainer.setup_optimizer``; any other name
# raises ``TrainerError``. The selected name is forwarded unchanged as the
# ``method=`` argument of ``scipy.optimize.minimize`` in ``KIMTrainer.train``.
SCIPY_MINIMIZE_METHODS = [
    "Nelder-Mead",
    "Powell",
    "CG",
    "BFGS",
    "Newton-CG",
    "L-BFGS-B",
    "TNC",
    "COBYLA",
    "SLSQP",
    "trust-constr",
    "dogleg",
    "trust-ncg",
    "trust-exact",
    "trust-krylov",
]


class KIMTrainer(Trainer):
    """
    Trainer for OpenKIM physics-based models, optimized with scipy.

    The model is loaded through :class:`~kliff.models.KIMModel` (either an
    instance handed in by the caller or one built from the model manifest,
    which may point at a folder or a tarball), its parameters are fitted with
    ``scipy.optimize.minimize``, and the trained model can be exported again,
    optionally as a tarball.

    TorchML model-driver based models are meant to be handled by the Torch
    based trainers rather than by this class.

    Args:
        training_manifest (dict): The training_manifest dictionary.
        model: Optional, already constructed model. When given,
            :meth:`setup_model` is called explicitly from ``__init__``.
    """

    def __init__(self, training_manifest: dict, model=None):
        # These defaults are assigned *before* ``super().__init__``.
        # NOTE(review): presumably the base initializer runs the setup hooks
        # (setup_model, setup_optimizer, ...) which read/overwrite them, so
        # assigning afterwards would clobber their results -- confirm.
        self.model_driver_name = None
        self.parameters = None
        self.mutable_parameters_list = []
        self.use_energy = True
        self.use_forces = False
        self.use_stress = False
        self.is_model_tarfile = False

        super().__init__(training_manifest, model)

        if model:
            self.setup_model()  # call manually if model is provided

        self.loss_function = self._get_loss_fn()
        self.result = None

    def setup_model(self):
        """
        Make sure ``self.model`` is a ``KIMModel`` and cache its parameters.

        If ``self.model`` is not already a ``KIMModel`` instance, it is built
        with ``KIMModel.get_model_from_manifest``. The manifest ``path`` may be
        a folder containing the model or a tar file; whether it is a tar file
        is detected here and passed along.

        NOTE(review): rejecting unsupported (TorchML*) model drivers is not
        done in this method; presumably ``KIMModel.get_model_from_manifest``
        takes care of it -- confirm.
        """
        if not isinstance(self.model, KIMModel):
            model_path = self.model_manifest["path"]
            if model_path:
                try:
                    self.is_model_tarfile = tarfile.is_tarfile(model_path)
                except (IsADirectoryError, TypeError) as e:
                    # A directory (or a non path-like value) is simply "not a
                    # tarball"; this is not an error condition.
                    self.is_model_tarfile = False
                    logger.debug(f"Model path is not a tarfile: {e}")

            # check for unsupported model drivers
            self.model = KIMModel.get_model_from_manifest(
                self.model_manifest, self.transform_manifest, self.is_model_tarfile
            )

        self.parameters = self.model.get_model_params()

    def setup_optimizer(self):
        """
        Select ``scipy.optimize.minimize`` as the optimizer.

        Raises:
            TrainerError: If the optimizer name in the optimizer manifest is
                not one of ``SCIPY_MINIMIZE_METHODS``.
        """
        if self.optimizer_manifest["name"] not in SCIPY_MINIMIZE_METHODS:
            raise TrainerError(
                f"Optimizer not supported: {self.optimizer_manifest['name']}."
            )
        # Imported lazily so scipy is only loaded once an optimizer is set up.
        optimizer_lib = importlib.import_module("scipy.optimize")
        self.optimizer = getattr(optimizer_lib, "minimize")

    # TODO: LM-Geodesic optimizer

    def loss(self, x: np.ndarray) -> float:
        """
        Compute the total training loss for the given parameters.

        The parameters are written into the KIM model, every configuration of
        the training dataset is evaluated, and the loss of each property whose
        weight is not ``None`` is accumulated. The loss of a configuration is
        scaled by that configuration's ``config_weight`` (when set) before it
        is added to the total, so one configuration's weight never rescales
        the contribution of another.

        Each call increments ``self.current["epoch"]`` by one, i.e. the
        counter tracks loss evaluations.

        TODO: Include MPI support.

        Args:
            x (np.ndarray): The model parameters.

        Returns:
            float: The total loss.
        """
        # set the parameters
        self.model.update_model_params(x)

        total_loss = 0.0
        for configuration in self.train_dataset:
            weight = configuration.weight
            compute_energy = weight.energy_weight is not None
            compute_forces = weight.forces_weight is not None
            compute_stress = weight.stress_weight is not None

            prediction = self.model(
                configuration,
                compute_energy=compute_energy,
                compute_forces=compute_forces,
                compute_stress=compute_stress,
            )

            if self.current["log_per_atom_pred"]:
                self.log_per_atom_outputs(
                    self.current["epoch"],
                    [configuration.metadata.get("index")],
                    [prediction["forces"]],
                )

            # Accumulate this configuration's loss separately so that its
            # config_weight scales only its own contribution.
            config_loss = 0.0
            if compute_energy:
                config_loss += self.loss_function(
                    prediction["energy"],
                    configuration.energy,
                    weight.energy_weight,
                )
            if compute_forces:
                config_loss += self.loss_function(
                    prediction["forces"],
                    configuration.forces,
                    weight.forces_weight,
                )
            if compute_stress:
                config_loss += self.loss_function(
                    prediction["stress"],
                    configuration.stress,
                    weight.stress_weight,
                )
            if weight.config_weight is not None:
                config_loss *= weight.config_weight

            total_loss += config_loss

        self.current["epoch"] += 1
        return total_loss

    def checkpoint(self, *args, **kwargs):
        """Not supported by this trainer; always raises ``TrainerError``."""
        raise TrainerError("checkpoint not implemented.")

    def train_step(self, *args, **kwargs):
        """Not supported by this trainer; always raises ``TrainerError``."""
        raise TrainerError("train_step not implemented.")

    def validation_step(self, *args, **kwargs):
        """Not supported by this trainer; always raises ``TrainerError``."""
        raise TrainerError("validation_step not implemented.")

    def get_optimizer(self, *args, **kwargs):
        """Not supported by this trainer; always raises ``TrainerError``."""
        raise TrainerError("get_optimizer not implemented.")

    def train(self):
        """
        Minimize :meth:`loss` with the configured scipy method.

        The optimizer result is stored in ``self.result``. On success the
        model parameters are set to the optimum and the status message is
        logged at info level. On failure the message is logged at error level
        and no exception is raised; the model is not updated with
        ``result.x`` in that case.

        Extra keyword arguments for ``minimize`` are read from the optimizer
        manifest entry ``kwargs``. Solver ``options`` given there are kept,
        except that ``maxiter`` and ``disp`` always come from the manifest
        ``epochs`` and the current verbosity.

        TODO: Include MPI support. Log loss trajectory for KIM models.
        """
        x0 = self.model.get_opt_params()

        # Work on copies so the stored manifest is never mutated, and tolerate
        # a missing or ``None`` ``kwargs`` entry.
        minimize_kwargs = dict(self.optimizer_manifest.get("kwargs") or {})
        solver_options = dict(minimize_kwargs.get("options") or {})
        solver_options["maxiter"] = self.optimizer_manifest["epochs"]
        solver_options["disp"] = self.current["verbose"]
        minimize_kwargs["options"] = solver_options

        self.result = self.optimizer(
            self.loss,
            x0,
            method=self.optimizer_manifest["name"],
            **minimize_kwargs,
        )

        if self.result.success:
            logger.info(f"Optimization successful: {self.result.message}")
            self.model.update_model_params(self.result.x)
        else:
            logger.error(f"Optimization failed: {self.result.message}")

    def _get_loss_fn(self) -> callable:
        """
        Return the loss function named in the loss manifest.

        The name is matched case-insensitively; ``"mse"`` and ``"mae"`` are
        supported.

        Returns:
            function: The loss function.

        Raises:
            TrainerError: If the requested loss function is not supported.
        """
        loss_name = self.loss_manifest["function"].lower()
        if loss_name == "mse":
            return MSE_loss
        if loss_name == "mae":
            return MAE_loss
        raise TrainerError(
            f"Loss function {self.loss_manifest['function']} not supported."
        )

    def save_kim_model(self):
        """
        Write the KIM model to ``<model_path>/<model_name>``.

        The training environment file is written alongside the model, and if
        ``generate_tarball`` is set in the export manifest a
        ``<model_name>.tar.gz`` archive of the model folder is created next to
        it.
        """
        path = (
            Path(self.export_manifest["model_path"])
            / self.export_manifest["model_name"]
        )
        self.model.write_kim_model(path)
        self.write_training_env_edn(path)

        if self.export_manifest["generate_tarball"]:
            # Append the extension instead of using ``with_suffix``, which
            # would replace whatever follows the last dot in a dotted name.
            tarfile_path = path.with_name(f"{path.name}.tar.gz")
            with tarfile.open(tarfile_path, "w:gz") as tar:
                tar.add(path, arcname=path.name)
            logger.info(f"Model tarball saved: {tarfile_path}")

        logger.info(f"KIM model saved at {path}")
# TODO: Support for lst_sq in optimizer