import os
import tarfile
from copy import deepcopy
from typing import TYPE_CHECKING, Any, Union
import libdescriptor as lds
import numpy as np
import torch
from loguru import logger
from monty.dev import requires
from torch.utils.data import DataLoader as TorchDataLoader
from torch_scatter import scatter_add
from .base_trainer import Trainer, TrainerError
from .utils.dataloaders import DescriptorDataset
from .utils.losses import MAE_loss, MSE_loss
if TYPE_CHECKING:
from kliff.transforms.configuration_transforms.descriptors import Descriptor
[docs]
class DNNTrainer(Trainer):
"""
This module is used to train a descriptor based dense neural network (DNN), and deploy it
using TorchML framework. For using the DUNN model driver please use the legacy Loss optimizer
for now. In the future, the ~kliff.trainer.torch_trainer.DNNTrainer will also support both
loading and exporting the NN models for DUNN model driver.
Args:
configuration (dict): A dictionary containing the configuration for the trainer.
model (torch.nn.Module): A torch model to be trained. If not provided, the model
will be loaded from the model manifest. For model manifest based loading
the model must be a torchscript model, or valid TorchML model (tar or dir).
"""
def __init__(self, configuration: dict, model=None):
    """
    Initialise the trainer state, then build the loss function, the
    dataloaders, the early-stopping callback and the learning-rate scheduler.

    Args:
        configuration (dict): A dictionary containing the configuration for the trainer.
        model (torch.nn.Module): Optional model to train. It is forwarded unchanged
            to the base ``Trainer`` constructor.
    """
    # Placeholders are created before the base constructor runs.
    # NOTE(review): presumably so that hooks invoked from ``Trainer.__init__``
    # can already rely on these attributes existing -- confirm against the base class.
    self.configuration_transform: "Descriptor" = (
        None  # for type checking the functions
    )
    self.torchscript_file = None
    self.train_dataloader = None
    # The rest of this class reads and writes ``val_dataloader``; without this
    # default the attribute would not exist when no validation set is loaded.
    self.val_dataloader = None
    # Older attribute name, kept so existing external code does not break.
    self.validation_dataloader = None
    self.lr_scheduler = None
    self.early_stopping = None
    self.dtype = torch.float64

    super().__init__(configuration, model)

    self.loss_function = self._get_loss_function()
    self.setup_dataloaders()
    self.early_stopping = self._get_early_stopping()
    # Built last: the scheduler wraps ``self.optimizer``, which must exist by now.
    self.lr_scheduler = self._get_scheduler()
# Optim Loss and Checkpoint #######################################################
[docs]
def setup_optimizer(self):
    """
    Build the optimizer and attach it to the trainer as ``self.optimizer``.

    Construction is delegated to :func:`_get_optimizer`, so that the way the
    optimizer is created can be customized independently of this hook.

    TODO: Add support for custom optimizers, starting with CoRE.
    """
    optimizer = self._get_optimizer()
    self.optimizer = optimizer
[docs]
def loss(
    self, x: torch.Tensor, y: torch.Tensor, weight: Union[float, torch.Tensor] = 1.0
) -> torch.Tensor:
    """
    Evaluate the configured loss function on a prediction/target pair.

    Args:
        x (torch.Tensor): Predicted values.
        y (torch.Tensor): Target values.
        weight (Union[float, torch.Tensor]): Weight forwarded to the loss
            function. Default is 1.0.

    Returns:
        torch.Tensor: Whatever ``self.loss_function(x, y, weight)`` returns.
    """
    criterion = self.loss_function
    return criterion(x, y, weight)
[docs]
def checkpoint(self):
    """
    Checkpoint the model and optimizer state to disk and append the current
    training/validation loss to ``loss.log``.

    Files written below ``self.current['run_dir']``:
        - ``checkpoints/checkpoint_{step}.pkl``: model, optimizer, step and loss,
          plus scheduler and early-stopping state when those are configured.
        - ``checkpoints/best_model.pth``: model weights, only when the validation
          loss improved on ``self.current['best_loss']``.
        - ``checkpoints/last_model.pth``: model weights, always.
        - ``loss.log``: CSV with one ``step,train_loss,val_loss`` row per call.
    """
    run_dir = self.current["run_dir"]
    ckpt_dir = f"{run_dir}/checkpoints"
    os.makedirs(ckpt_dir, exist_ok=True)

    state = {
        "model_state_dict": self.model.state_dict(),
        "optimizer_state_dict": self.optimizer.state_dict(),
        "current_step": self.current["step"],
        "current_loss": self.current["loss"],
    }
    if self.lr_scheduler:
        state["lr_scheduler"] = self.lr_scheduler.state_dict()
    if self.early_stopping:
        state["early_stopping"] = {
            "counter": self.early_stopping.counter,
            "best_loss": self.early_stopping.best_loss,
        }
    torch.save(state, f"{ckpt_dir}/checkpoint_{self.current['step']}.pkl")

    # save best and last model
    val_loss = self.current["loss"]["val"]
    # Compare against None explicitly: a validation loss of exactly 0.0 is a
    # valid (and the best possible) value and must not be treated as missing.
    if val_loss is not None:
        if val_loss < self.current["best_loss"]:
            self.current["best_loss"] = val_loss
            torch.save(self.model.state_dict(), f"{ckpt_dir}/best_model.pth")
    elif not self.current["warned_once"]:
        # Warn a single time per run instead of once per checkpoint.
        self.current["warned_once"] = True
        logger.warning("Validation loss not available for checkpointing.")

    torch.save(self.model.state_dict(), f"{ckpt_dir}/last_model.pth")

    log_path = f"{run_dir}/loss.log"
    # The header is only needed when this call creates the file.
    needs_header = not os.path.exists(log_path)
    with open(log_path, "a") as f:
        if needs_header:
            f.write("step,train_loss,val_loss\n")
        f.write(
            f"{self.current['step']},{self.current['loss']['train']},{val_loss}\n"
        )
[docs]
def load_checkpoint(self, path: str):
    """
    Load the model and optimizer state from a checkpoint file.

    The step counter and the last recorded loss are restored as well. Scheduler
    and early-stopping state are restored only when they are both configured on
    this trainer and present in the file; ``checkpoint`` stores them
    conditionally, so a file written without them leaves the current
    (freshly initialised) objects untouched instead of raising ``KeyError``.

    Args:
        path (str): Path to the checkpoint file.
    """
    # NOTE(security): torch.load unpickles the file -- only load trusted checkpoints.
    # map_location lets a checkpoint written on another device be restored here.
    checkpoint = torch.load(path, map_location=self.current["device"])
    self.model.load_state_dict(checkpoint["model_state_dict"])
    self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    self.current["step"] = checkpoint["current_step"]
    self.current["loss"] = checkpoint["current_loss"]

    if self.lr_scheduler and "lr_scheduler" in checkpoint:
        self.lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])

    saved_early_stopping = checkpoint.get("early_stopping")
    if self.early_stopping and saved_early_stopping:
        self.early_stopping.counter = saved_early_stopping["counter"]
        self.early_stopping.best_loss = saved_early_stopping["best_loss"]
[docs]
def get_last_checkpoint(self) -> str:
    """
    Get the most recent checkpoint file of the current run.

    Checkpoints are looked up in ``{run_dir}/checkpoints`` -- the directory
    ``checkpoint`` writes to -- and must be named ``checkpoint_{step}.pkl``.
    Any other entry in that directory (e.g. ``best_model.pth``) is ignored.

    Returns:
        str: Path to the checkpoint file with the highest step number.

    Raises:
        TrainerError: If the directory contains no ``checkpoint_{step}.pkl`` file.
    """
    ckpt_dir = f"{self.current['run_dir']}/checkpoints"
    prefix, suffix = "checkpoint_", ".pkl"

    steps = []
    for name in os.listdir(ckpt_dir):
        if not (name.startswith(prefix) and name.endswith(suffix)):
            continue
        step_text = name[len(prefix) : -len(suffix)]
        # Skip look-alike names whose step part is not a plain integer.
        if step_text.isdigit():
            steps.append(int(step_text))

    if not steps:
        raise TrainerError(f"No checkpoint files found in {ckpt_dir}.")
    return f"{ckpt_dir}/{prefix}{max(steps)}{suffix}"
def _get_optimizer(self) -> torch.optim.Optimizer:
    """
    Get the optimizer for the model. The optimizer is defined in the optimizer manifest.

    The manifest entry ``name`` selects the class from ``torch.optim``,
    ``learning_rate`` is passed as ``lr`` and the optional ``kwargs`` mapping is
    forwarded as additional keyword arguments. A missing, ``None`` or empty
    ``kwargs`` entry means "no extra arguments".

    Returns:
        torch.optim.Optimizer: The optimizer object.
    """
    optimizer_cls = getattr(torch.optim, self.optimizer_manifest["name"])
    extra_kwargs = self.optimizer_manifest.get("kwargs") or {}
    return optimizer_cls(
        self.model.parameters(),
        lr=self.optimizer_manifest["learning_rate"],
        **extra_kwargs,
    )
def _get_scheduler(self) -> torch.optim.lr_scheduler:
    """
    Build the learning-rate scheduler described by the ``lr_scheduler`` entry of
    the optimizer manifest.

    The entry's ``name`` selects the class from ``torch.optim.lr_scheduler`` and
    its ``args`` mapping is expanded into keyword arguments.

    Returns:
        The scheduler wrapping ``self.optimizer``, or None when the manifest has
        no (or an empty) ``lr_scheduler`` entry.
    """
    manifest = self.optimizer_manifest.get("lr_scheduler", None)
    if not manifest:
        return None

    scheduler_cls = getattr(torch.optim.lr_scheduler, manifest["name"])
    logger.info(f"Using scheduler: {manifest['name']}")
    return scheduler_cls(self.optimizer, **manifest["args"])
def _get_early_stopping(self) -> Union["_EarlyStopping", None]:
    """
    Build the early-stopping callback from the ``early_stopping`` entry of the
    training manifest.

    The entry's ``patience`` and ``min_delta`` values are passed, in that order,
    to ``_EarlyStopping``.

    Returns:
        _EarlyStopping: The callback, or None when the manifest has no (or an
        empty) ``early_stopping`` entry.
    """
    settings = self.training_manifest.get("early_stopping", None)
    if not settings:
        return None

    logger.info(
        f"Using early stopping with patience: {settings['patience']} and min_delta: {settings['min_delta']}"
    )
    return _EarlyStopping(settings["patience"], settings["min_delta"])
def _get_loss_function(self):
    """
    Resolve the loss function named by ``loss_manifest['function']``.

    The name is matched case-insensitively; ``"mse"`` and ``"mae"`` are the
    supported values.

    Returns:
        The loss function object (``MSE_loss`` or ``MAE_loss``).

    Raises:
        TrainerError: If the manifest names any other loss function.
    """
    # TODO: Custom loss functions
    supported = {"mse": MSE_loss, "mae": MAE_loss}
    selected = supported.get(self.loss_manifest["function"].lower())
    if selected is None:
        raise TrainerError(
            f"Loss function {self.loss_manifest['function']} not supported."
        )
    return selected
# train steps #####################################################################
[docs]
def train_step(self, batch) -> torch.Tensor:
    """
    Run one optimisation step on a single batch. Called by ``train`` for every
    batch of the training data loader.

    Only the ``descriptor`` configuration transform (matched case-insensitively)
    is handled; the work is delegated to ``_descriptor_train_step``.

    Args:
        batch: dictionary containing the batch data

    Returns:
        torch.Tensor: The loss value for the batch

    Raises:
        TrainerError: If the configured transform is not ``descriptor``.
    """
    transform_name = self.transform_manifest["configuration"]["name"]
    if transform_name.lower() != "descriptor":
        raise TrainerError(
            f"Configuration transformation type {transform_name} not supported."
        )
    return self._descriptor_train_step(batch)
def _descriptor_train_step(self, batch) -> torch.Tensor:
    """
    Perform one gradient update for a descriptor-based model: clear the
    gradients, evaluate the batch loss, back-propagate and step the optimizer.

    Args:
        batch: dictionary containing the batch data

    Returns:
        torch.Tensor: The loss value for the batch
    """
    optimizer = self.optimizer
    optimizer.zero_grad()
    batch_loss = self._descriptor_eval_batch(batch)
    batch_loss.backward()
    optimizer.step()
    return batch_loss
[docs]
def validation_step(self, batch) -> torch.Tensor:
    """
    Validate the model for one step. This function is called by the `train` function for each
    batch in the validation data loader.

    After the batch loss is evaluated, the learning-rate scheduler (if any) is
    stepped once. ``ReduceLROnPlateau`` receives the batch loss as its metric;
    every other scheduler is stepped without arguments. The scheduler type is
    checked explicitly so that a ``TypeError`` raised inside a scheduler is not
    mistaken for "this scheduler needs a metric" and answered with a second call.

    Args:
        batch: dictionary containing the batch data

    Returns:
        torch.Tensor: The loss value for the batch

    Raises:
        TrainerError: If the configured transform is not ``descriptor``.
    """
    transform_name = self.transform_manifest["configuration"]["name"]
    if transform_name.lower() != "descriptor":
        # Report the transform name, matching the message used by train_step.
        raise TrainerError(
            f"Configuration transformation type {transform_name} not supported."
        )

    loss = self._descriptor_eval_batch(batch)

    scheduler = self.lr_scheduler
    if scheduler:
        if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
            scheduler.step(loss)
        else:
            scheduler.step()
    return loss
# eval batches #####################################################################
@requires(lds, "libdescriptor is needed for descriptor training.")
def _descriptor_eval_batch(self, batch) -> torch.Tensor:
    """
    Evaluate the model for one batch of data. This function is called by the `train_step` and
    `validation_step` functions. It uses the libdescriptor library to compute the forces
    from the descriptor.

    Steps performed:
        1. Run the model on the batch descriptors and sum its per-row outputs with
           ``scatter_add`` using ``batch["contribution"]`` as the index.
        2. Compute the loss of those sums against ``property_dict["energy"]``.
        3. Differentiate the summed prediction w.r.t. the descriptors and pass the
           result to ``lds.gradient_batch`` to obtain forces.
        4. Sum the forces with ``scatter_add`` using ``batch["image"]`` as the
           index and copy per-configuration slices into ``forces_predicted``.
        5. Optionally log the per-atom force predictions.
        6. Compute a force loss.

    NOTE(review): the force loss (``loss_forces``) is computed but never added to
    the returned value -- the line that would add it is commented out. Only the
    energy loss is returned. The stress loss is not implemented.

    Args:
        batch: dictionary containing the batch data. The keys read are
            ``n_atoms``, ``species``, ``neigh_list``, ``num_neigh``, ``image``,
            ``coords``, ``descriptors``, ``property_dict``, ``contribution``,
            ``ptr``, ``index`` and ``weight``.

    Returns:
        torch.Tensor: The (energy) loss value for the batch
    """
    n_atoms = batch["n_atoms"]
    species = batch["species"]
    neigh_list = batch["neigh_list"]
    num_neigh = batch["num_neigh"]
    image = batch["image"]
    coords = batch["coords"]
    descriptors = batch["descriptors"]
    properties = batch["property_dict"]
    contribution = batch["contribution"]
    ptr = batch["ptr"]
    indexes = batch["index"]
    weights = batch["weight"]

    # requires_grad=True so that dE/d(descriptor) can be taken below.
    descriptor_tensor = torch.tensor(
        descriptors,
        dtype=self.dtype,
        device=self.current["device"],
        requires_grad=True,
    )
    predictions = self.model(descriptor_tensor)
    # Sum the model outputs that share the same ``contribution`` index.
    # presumably one index per configuration, giving total energies -- TODO confirm
    predictions = scatter_add(
        predictions,
        torch.tensor(
            contribution, device=self.current["device"], dtype=torch.int64
        ),
        dim=0,
    )

    loss = self.loss(
        predictions,
        torch.as_tensor(
            properties["energy"], dtype=self.dtype, device=self.current["device"]
        ),
        weights["energy"],
    )  # energy will always be present for conservative models
    # if self.loss_manifest["weights"]["energy"]:
    #     loss = loss * self.loss_manifest["weights"]["energy"]

    # if (
    #     self.loss_manifest["weights"]["forces"]
    #     or self.loss_manifest["weights"]["stress"]
    # ):
    # Gradient of the summed prediction w.r.t. the descriptors; retain_graph keeps
    # the graph alive for the later ``loss.backward()`` in the training step.
    dE_dzeta = torch.autograd.grad(
        predictions.sum(),
        descriptor_tensor,
        retain_graph=True,
    )[0]
    # libdescriptor works on NumPy data, so the gradient is detached and moved
    # to the CPU; the result therefore carries no autograd history.
    forces = lds.gradient_batch(
        self.configuration_transform._cdesc,
        n_atoms,
        ptr,
        species,
        neigh_list,
        num_neigh,
        coords,
        descriptors,
        dE_dzeta.double().detach().cpu().numpy(),
    )
    forces_predicted = torch.zeros(
        properties["forces"].shape,
        device=self.current["device"],
        dtype=self.dtype,
    )
    forces = torch.tensor(forces, device=self.current["device"], dtype=self.dtype)
    # Sum force rows that share the same ``image`` index.
    # presumably folds contributions of periodic images onto their parent atoms -- TODO confirm
    force_summed = scatter_add(
        forces,
        torch.tensor(image, device=self.current["device"], dtype=torch.int64),
        dim=0,
    )
    n_atoms_tensor = torch.tensor(
        n_atoms, device=self.current["device"], dtype=torch.int64
    )
    ptr_tensor = torch.tensor(ptr, device=self.current["device"], dtype=torch.int64)

    # TODO: See if we can do without the triple if condition
    # Per-atom predictions are only collected on epochs that are checkpointed.
    if self.current["log_per_atom_pred"] and (
        self.current["epoch"] % self.current["ckpt_interval"] == 0
    ):
        per_atom_pred = []

    # For configuration i, copy ``n_atoms[i]`` rows starting at ``ptr[i]`` in the
    # summed forces into the contiguous range [from_, to_) of forces_predicted.
    for i in range(len(ptr_tensor)):
        from_ = torch.sum(n_atoms_tensor[:i])
        to_ = from_ + n_atoms_tensor[i]
        forces_predicted[from_:to_] = force_summed[
            ptr_tensor[i] : ptr_tensor[i] + n_atoms_tensor[i]
        ]
        if self.current["log_per_atom_pred"] and (
            self.current["epoch"] % self.current["ckpt_interval"] == 0
        ):
            per_atom_pred.append(forces_predicted[from_:to_].detach().cpu().numpy())

    if self.current["log_per_atom_pred"] and (
        self.current["epoch"] % self.current["ckpt_interval"] == 0
    ):
        self.log_per_atom_outputs(self.current["epoch"], indexes, per_atom_pred)

    # When the number of force weights differs from the number of force rows,
    # expand the weights by indexing them with ``contribution``.
    if (
        weights["forces"] is not None
        and np.array(weights["forces"]).size != forces.shape[0]
    ):
        weights_forces = np.atleast_2d(weights["forces"])[contribution]
    else:
        weights_forces = weights["forces"]

    # NOTE(review): loss_forces is not used after this point (see the
    # commented-out line below), so it does not influence the returned loss.
    loss_forces = self.loss(
        forces_predicted,
        torch.tensor(
            properties["forces"],
            device=self.current["device"],
            dtype=self.dtype,
        ),
        weights_forces,
    )
    # loss = loss + loss_forces * self.loss_manifest["weights"]["forces"]
    # TODO: Discuss and check if this is correct
    # F = - ∂E/∂r, ℒ = f(E, F)
    # F = - ∂E/∂ζ * ∂ζ/∂r <- ∂ζ/∂r is a jacobian, vjp is computed by enzyme
    # ∂ℒ/∂θ = ∂ℒ/∂E * ∂E/∂θ + ∂ℒ/∂F * ∂F/∂θ
    # tricky part is ∂F/∂θ, as F is computed using Enzyme
    # ∂F/∂θ = ∂^2E/∂ζ∂θ * ∂ζ/∂r + ∂E/∂ζ * ∂^2ζ/∂r∂θ
    #       = ∂^2E/∂ζ∂θ * ∂ζ/∂r + 0 (∂ζ/∂r independent of θ)
    # So we do not need second derivative wrt to ζ,
    # and ∂ζ/∂r is provided by the descriptor module. So autograd should be able
    # to handle this. But need to confirm, else we need to explicitly compute
    # ∂^2E/∂ζ∂θ and then call lds.gradient again.
    # ask pytorch forum. Or use custom gradient optimization.

    # TODO: Add stress loss
    # if self.loss_manifest["weights"]["stress"]:
    #     # stress = \sum_i (f_i \otimes r_i)
    #     stress = torch.zeros(len(ptr), 6) # voigt notation
    #     for i in range(len(ptr) - 1):
    #         from_ = torch.sum(n_atoms[:i])
    #         to_ = from_ + n_atoms[i]
    #         full_stress = torch.einsum(
    #             "ij,ik->ijk", forces_predicted[from_:to_], coords[from_:to_]
    #         )
    #         summed_stress = torch.sum(full_stress, dim=0)
    #         stress[i, 0] = summed_stress[0, 0]
    #         stress[i, 1] = summed_stress[1, 1]
    #         stress[i, 2] = summed_stress[2, 2]
    #         stress[i, 3] = summed_stress[1, 2]
    #         stress[i, 4] = summed_stress[0, 2]
    #         stress[i, 5] = summed_stress[0, 1]
    #
    #     loss_stress = self.loss(stress, properties["stress"])
    #     loss = loss + loss_stress * self.loss_manifest["weights"]["stress"]
    return loss
[docs]
def train(self):
    """
    Train the model for the number of epochs given by ``optimizer_manifest['epochs']``.

    For every epoch the training loss is accumulated over all training batches.
    On every ``ckpt_interval``-th epoch the validation loss is computed (when a
    validation dataset is present), the losses are logged and the model is
    checkpointed. If early stopping is enabled, training stops when the callback
    reports that the validation loss stopped improving; early stopping is
    skipped on epochs for which no validation loss exists.
    Once the training is complete, a `.finished` file is created in the run
    directory, and the per-atom prediction database is closed if it was in use.
    """
    # TODO: granularity of train: train, train_step, train_epoch?
    # currently it is train -> train_step, so train is wrapper for train_epoch
    if self.current["appending_to_previous_run"]:
        self.load_checkpoint(self.get_last_checkpoint())

    for epoch in range(self.optimizer_manifest["epochs"]):
        self.current["epoch"] = epoch
        epoch_train_loss = 0.0
        self.model.train()
        for batch in self.train_dataloader:
            loss = self.train_step(batch)
            epoch_train_loss += loss.detach().cpu().numpy()

        if epoch % self.current["ckpt_interval"] == 0:
            # None marks "no validation loss for this epoch".
            epoch_val_loss = None
            if self.val_dataset:
                epoch_val_loss = 0.0
                self.model.eval()
                for batch in self.val_dataloader:
                    loss = self.validation_step(batch)
                    epoch_val_loss += loss.detach().cpu().numpy()

            self.current["loss"] = {
                "train": epoch_train_loss,
                "val": epoch_val_loss,
            }
            logger.info(f"Epoch {epoch} completed. val loss: {epoch_val_loss}")

            self.checkpoint()
            # Without a validation loss there is nothing to compare, and handing
            # None to the callback would fail on ``None < float``.
            if (
                self.early_stopping
                and epoch_val_loss is not None
                and self.early_stopping(epoch_val_loss)
            ):
                logger.info(f"Early stopping at epoch {epoch}")
                break

        self.current["step"] += 1
        logger.info(f"Epoch {epoch} completed. Train loss: {epoch_train_loss}")

    # create .finished file to indicate that training is done
    with open(f"{self.current['run_dir']}/.finished", "w") as f:
        f.write("")

    if self.current["log_per_atom_pred"]:
        # close LMDB file
        self.current["per_atom_pred_database"].close()
# model io #####################################################################
[docs]
def setup_model(self):
    """
    Load the TorchScript model referenced by ``model_manifest['path']`` and
    store it as ``self.model`` on the current device, cast to ``self.dtype``.

    NOTE(review): this method always loads from the manifest; skipping it when
    a model object was supplied is presumably handled by the caller -- confirm.
    """
    # TODO: precision is currently fixed by ``self.dtype``; selecting single or
    # double precision from the training manifest is not implemented yet.
    scripted_model = torch.jit.load(self.model_manifest["path"])
    target_device = self.current["device"]
    self.model = scripted_model.to(device=target_device, dtype=self.dtype)
[docs]
def save_kim_model(self, path: str = "kim-model"):
    """
    Save the KIM model to the given path. The KIM model is saved as a portable
    TorchML model.

    The model directory ``{path}/{model_name}`` receives ``model.pt`` (the
    scripted model on CPU), the files written by the configuration transform,
    ``CMakeLists.txt`` and the training-environment file. When
    ``export_manifest['generate_tarball']`` is set, the directory is also packed
    into ``{path}/{model_name}.tar.gz``.

    The exported weights are those of ``checkpoints/best_model.pth``. That file
    is only written when a validation loss exists, so ``last_model.pth`` is used
    (with a warning) when no best model is available.

    TODO: Add support for DUNN model driver.

    Args:
        path: Path to save the model. Overridden by
            ``export_manifest['model_path']`` when that entry is set.
    """
    # create folder if not already present
    if self.export_manifest["model_path"]:
        path = self.export_manifest["model_path"]

    if not self.export_manifest["model_name"]:
        qualified_model_name = f"{self.current['run_title']}_MO_000000000000_000"
    else:
        qualified_model_name = self.export_manifest["model_name"]

    # The driver version may arrive as an int (e.g. parsed from YAML) or as
    # None; normalise it to a string before inspecting its length.
    driver_version = self.export_manifest.get("driver_version", "000")
    if driver_version is None:
        driver_version = "000"
    driver_version = str(driver_version)
    if len(driver_version) != 3:  # not a fully qualified version string
        driver_version = f"{int(driver_version):03}"
    torchml_version = "TorchML__MD_173118614730_" + driver_version

    path = os.path.join(path, qualified_model_name)
    os.makedirs(path, exist_ok=True)

    checkpoint_dir = f"{self.current['run_dir']}/checkpoints"
    weights_file = f"{checkpoint_dir}/best_model.pth"
    if not os.path.exists(weights_file):
        weights_file = f"{checkpoint_dir}/last_model.pth"
        logger.warning(
            f"best_model.pth not found, exporting weights from {weights_file}"
        )

    # Work on a copy so the weights of the live model are left untouched.
    best_model = deepcopy(self.model)
    best_model.load_state_dict(torch.load(weights_file))

    model = torch.jit.script(best_model)
    model = model.cpu()
    torch.jit.save(model, f"{path}/model.pt")

    # save the configuration transform
    self.configuration_transform.export_kim_model(
        path, "model.pt"
    )  # kim_model.param
    self.configuration_transform.save_descriptor_state(path)  # descriptor.dat

    # CMakeLists.txt
    cmakefile = self._generate_kim_cmake(
        qualified_model_name,
        torchml_version,
        ["model.pt", "descriptor.dat", "kim_model.param"],
    )
    with open(f"{path}/CMakeLists.txt", "w") as f:
        f.write(cmakefile)

    # write training environment
    self.write_training_env_edn(path)

    if self.export_manifest["generate_tarball"]:
        tarball_name = f"{path}.tar.gz"
        with tarfile.open(tarball_name, "w:gz") as tar:
            tar.add(path, arcname=os.path.basename(path))
        logger.info(f"Model tarball saved: {tarball_name}")
    logger.info(f"KIM model saved at {path}")
# Data loaders #####################################################################
[docs]
def setup_dataloaders(self):
    """
    Create the training (and, if available, validation) dataloaders for the
    configured configuration transform.

    Only the ``descriptor`` transform (matched case-insensitively) is handled;
    the work is delegated to ``_setup_descriptor_dataloaders``.

    Raises:
        TrainerError: If the configured transform is not ``descriptor``.
    """
    transform_name = self.transform_manifest["configuration"]["name"]
    if transform_name.lower() != "descriptor":
        raise TrainerError(
            f"Configuration transformation type {transform_name} not supported."
        )
    self._setup_descriptor_dataloaders()
def _setup_descriptor_dataloaders(self):
    """
    Wrap the training and validation datasets in ``DescriptorDataset`` and build
    the torch dataloaders for a descriptor based model.

    With ``dataset_manifest['dynamic_loading']`` enabled, the configuration
    transform is registered on the wrapped datasets via ``add_transform``.
    Otherwise the transform is applied to every configuration up front and the
    result is stored in ``config.fingerprint`` before the datasets are wrapped.

    The training dataloader shuffles its data, the validation dataloader does
    not; both use ``optimizer_manifest['batch_size']`` and the dataset's own
    ``collate`` function. A warning is logged when there is no validation dataset.
    """
    transform = self.configuration_transform

    if not self.dataset_manifest["dynamic_loading"]:
        # Eager mode: compute all fingerprints now, then wrap the datasets.
        for configuration in self.train_dataset:
            configuration.fingerprint = transform(
                configuration, return_extended_state=True
            )
        self.train_dataset = DescriptorDataset(self.train_dataset)
        if self.val_dataset:
            for configuration in self.val_dataset:
                configuration.fingerprint = transform(
                    configuration, return_extended_state=True
                )
            self.val_dataset = DescriptorDataset(self.val_dataset)
    else:
        # Dynamic mode: wrap first and hand the transform to the datasets.
        self.train_dataset = DescriptorDataset(self.train_dataset)
        if self.val_dataset:
            self.val_dataset = DescriptorDataset(self.val_dataset)
        self.train_dataset.add_transform(transform)
        if self.val_dataset:
            self.val_dataset.add_transform(transform)

    batch_size = self.optimizer_manifest["batch_size"]
    self.train_dataloader = TorchDataLoader(
        self.train_dataset,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=self.train_dataset.collate,
    )
    if not self.val_dataset:
        logger.warning("No validation dataset loaded.")
        return
    self.val_dataloader = TorchDataLoader(
        self.val_dataset,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=self.val_dataset.collate,
    )
# Auxiliary #####################################################################
class _EarlyStopping:
    """
    Early stopping callback for the training loop.

    Calling the instance with a loss value updates the internal state and
    reports whether training should stop.

    Args:
        patience (int): Number of consecutive non-improving calls after which
            the callback asks for the training to stop.
        delta (float): Minimum decrease of the loss, relative to the best loss
            seen so far, to be considered as an improvement.
    """

    def __init__(self, patience: int, delta: float):
        self.patience = patience
        self.delta = delta
        self.counter = 0  # consecutive calls without improvement
        self.best_loss = float("inf")

    def __call__(self, loss):
        """
        Record ``loss`` and decide whether to stop.

        Args:
            loss: Current loss value, or None when no loss is available.

        Returns:
            bool: True when the loss failed to improve for ``patience``
            consecutive calls, False otherwise. A ``None`` loss carries no
            information: it leaves the state unchanged and returns False.
        """
        if loss is None:
            return False
        if loss < self.best_loss - self.delta:
            self.best_loss = loss
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                return True
        return False
# TODO:
# - Add device management
# - Add model export
# - Add restart capabilities
# - Add per component loss extraction
# - Add custom options for torchmetrics
# - Precision management
# - DDP?