import hashlib
import importlib
import json
import os
import pickle as pkl
import random
import sys
from datetime import datetime
from glob import glob
from pathlib import Path
from typing import TYPE_CHECKING, Callable, List, Union
import numpy as np
import yaml
from loguru import logger
from kliff.dataset import Dataset
from kliff.dataset.weight import Weight
if TYPE_CHECKING:
from kliff.transforms.configuration_transforms import ConfigurationTransform
[docs]
class Trainer:
"""Base class for all trainers.
This class is the base class for all trainers. It provides the basic structure for
training a model. The derived classes should implement the required methods. This
class will provide the basic functionality for training, such as setting up the
work directory, saving the configuration, and setting up the indices for training
and validation datasets. It will save hashes of the configuration fingerprints and
training configuration to the work directory. This would ensure reproducibility of
the training process, and easy restarting.
The core trainer class will provide the following functionality:
- Set up the work directory
- Set up the dataset
- Set up the test train split
Model, parameter transform and optimizer setup are left for the derived classes to
implement.
env variables:
KLIFF_LMDB_MAP_SIZE: lmdb mmap size, defaults to 1e12
"""
def __init__(self, training_manifest: dict, model=None):
"""
Initialize the trainer with the provided manifest and model. Model can be initialized
from the manifest itself for simpler, KIM-API based models. For more complex models,
the model should be provided as an argument.
The model manifest should contain the following sections, or keys:
- workspace: Workspace configuration, folders, name of run etc
- name: Name of the workspace
- seed: Seed for random number generators
- resume: Resume from the previous run if possible
- dataset: Dataset to train on, usually an XYZ file, or ColabFit dataset. The
content of this block will be given {func}`kliff.dataset.Dataset.from_manifest`
function. Consult the documentation of the function for more information.
- model: Model to train, usually a KIM-API model. The content of this block will be given to the
{func}`kliff.models.Model.get_model_from_manifest` function. Consult the documentation of the
function for more information. The above function usually works for
physics based KIM-API models. For more complex ML models, you need to
initialize the model and provide it as an argument. Tarfiles can be used
for TorchML driver based model drivers.
Additional keys might be needed on model specific basis.
- transforms: Transformations to be applied to the dataset. This includes
parameter transforms (for KIM models), configuration transforms
(descriptors, and graphs), and property transforms (energy, forces,
stress).
- training: Training configuration, including loss function, optimizer, batch size,
number of epochs, early stopping, dataset test-train split etc.
TODO: Add more details about the training block.
- export: Export the trained model. This includes the model name and the path to
save the model.
Args:
training_manifest: Dictionary containing the training configuration
model: Model instance or None, usually a pytorch model
"""
# workspace variables
self.workspace: dict = {
"name": "kliff_workspace",
"seed": 12345,
"resume": False,
}
# dataset variables
self.dataset_manifest: dict = {
"type": "path",
"path": "./",
"save": False,
"keys": {"energy": "energy", "forces": "forces"},
"dynamic_loading": False,
"colabfit_dataset": {
"dataset_name": None,
"database_name": None,
"database_url": None,
},
}
self.dataset = None
# model variables
self.model_manifest: dict = {
"name": None,
"path": None,
"collection": "user",
}
self.model: Callable = model
# transform variables
self.transform_manifest: dict = {
"property": [
{
"name": None,
"property_key": None,
}
],
"parameter": [],
"configuration": {
"name": None,
"kwargs": None,
},
}
self.property_transforms = []
self.configuration_transform: "ConfigurationTransform" = None
# training variables
# this is too complicated to put it in singe dict, therefore the training
# block is divided into loss, optimizer, dataset_sample
self.training_manifest: dict = {}
self.loss_manifest: dict = {
"function": "mse",
"weights": {
"energy": 1.0,
"forces": None,
"stress": None,
"config": 1.0,
},
"normalize_per_atom": False,
}
self.optimizer_manifest: dict = {
"name": None,
"learning_rate": None,
"kwargs": {},
"epochs": 10000,
"num_workers": None,
"batch_size": 1,
}
self.optimizer = None
# part of current?
self.dataset_sample_manifest: dict = {
"train_size": None,
"val_size": None,
"indices_files": {"train": None, "val": None},
"val_indices": None,
"train_indices": None,
}
self.train_dataset = None
self.val_dataset = None
# export trained model
self.export_manifest: dict = {
"model_name": None,
"model_path": None,
"generate_tarball": False,
"driver_version": "000",
}
# state variables
self.current: dict = {
"run_title": None,
"run_dir": None,
"run_hash": None,
"start_time": None,
"best_loss": np.inf,
"best_model": None,
"loss": None,
"epoch": 0,
"step": 0,
"device": "cpu",
"warned_once": False,
"dataset_hash": None,
"data_dir": None,
"appending_to_previous_run": False,
"verbose": False,
"ckpt_interval": 100,
"log_per_atom_pred": False,
# "per_atom_pred": {"train": None, "val": None},
"per_atom_pred_database": None,
}
self.parse_manifest(training_manifest)
self.initialize()
[docs]
def parse_manifest(self, manifest: dict):
"""
It accepts the raw manifest dictionary, and processes it to the formatted
manifest. This includes mapping the string fields to enums, and setting sane
defaults for missing fields.
Args:
manifest: raw incoming configuration
Returns:
Processed manifest
"""
_date_time_format = "%Y-%m-%d-%H-%M-%S"
start_time = datetime.now()
date_time_str = start_time.strftime(_date_time_format)
self.current["start_time"] = start_time
# Workspace variables ################################################
workspace_block: Union[None, dict] = manifest.get("workspace", None)
if workspace_block is None:
logger.warning(
"Workspace block not found in the configuration. Using default values."
)
else:
self.workspace |= workspace_block
# Dataset manifest #################################################
dataset_manifest: Union[None, dict] = manifest.get("dataset", None)
if dataset_manifest is None:
raise TrainerError("Dataset block not found in the configuration. Exiting.")
self.dataset_manifest |= dataset_manifest
# model variables ####################################################
model_manifest: Union[None, dict] = manifest.get("model", None)
if model_manifest is None:
raise TrainerError("Model block not found in the configuration. Exiting.")
self.model_manifest |= model_manifest
if self.model_manifest.get("name", None) is None:
self.current["run_title"] = (
f"{self.model_manifest.get('type')}_{date_time_str}"
)
else:
self.current["run_title"] = (
f"{self.model_manifest.get('name')}_{date_time_str}"
)
# transform variables ####################################################
transform_manifest: Union[None, dict] = manifest.get("transforms", None)
if transform_manifest is None:
logger.warning(
"Transform block not found in the configuration. This is bit unusual."
)
else:
self.transform_manifest |= transform_manifest
# training variables ########################################################
training_manifest: Union[None, dict] = manifest.get("training", None)
if training_manifest is None:
logger.warning(
"Training block not found in the configuration."
"Will try and resume the previous run if possible."
)
self.training_manifest |= training_manifest
if self.training_manifest.get("loss", None) is None:
logger.warning(
"Loss block not found in the configuration. Using default values."
)
self.loss_manifest |= self.training_manifest.get("loss")
if self.training_manifest.get("optimizer", None) is None:
logger.warning(
"Optimizer block not found in the configuration."
"Will resume the previous run if possible."
)
self.optimizer_manifest |= self.training_manifest.get("optimizer")
self.optimizer_manifest["epochs"] = self.training_manifest.get("epochs", 10000)
self.optimizer_manifest["num_workers"] = self.training_manifest.get(
"num_workers", None
)
self.optimizer_manifest["batch_size"] = self.training_manifest.get(
"batch_size", 1
)
self.current["ckpt_interval"] = self.training_manifest.get("ckpt_interval", 100)
self.current["verbose"] = self.training_manifest.get("verbose", False)
self.current["device"] = self.training_manifest.get("device", "cpu")
# dataset sample variables will be processed in the setup_dataset method
self.export_manifest |= manifest.get("export", {})
# per save atom prediction?
self.current["log_per_atom_pred"] = training_manifest.get(
"log_per_atom_pred", False
)
[docs]
def config_to_dict(self):
"""
Convert the configuration to a dictionary.
"""
config = {
"workspace": self.workspace,
"dataset": self.dataset_manifest,
"model": self.model_manifest,
"transforms": self.transform_manifest,
"training": self.training_manifest,
}
return config
[docs]
@classmethod
def from_file(cls, filename: Path):
"""
Load the manifest from a YAML file.
Args:
filename: name of the yaml file
Returns:
Trainer instance
"""
manifest = yaml.safe_load(open(filename, "r"))
return cls(manifest)
[docs]
def get_trainer_hash(self):
"""
Get the hash of the current configuration. It will be used to create a unique
directory for the current run. It will be the hash of the configuration dictionary
string.
"""
config = self.config_to_dict()
config_immut_str = json.dumps(config, sort_keys=True)
return hashlib.md5(config_immut_str.encode()).hexdigest()
[docs]
def initialize(self):
"""
Initialize the trainer. Assigns the configuration objects, and
call setup methods.
"""
# Step 1 - Assign the processed configuration objects to the class variables
# This has been done in the __init__ method
# Step 2 - Initialize all seeds
self.seed_all()
logger.info(f"Seed set to {self.workspace['seed']}.")
# Step 3 - Set up the workspace folder
self.setup_workspace()
logger.info(f"Workspace set to {self.current['run_dir']}.")
# Step 4 - Read or load the dataset, initialize the property/configuration transforms
self.setup_dataset()
logger.info(f"Dataset loaded.")
# Step 4.5 - Set up the dataset transforms
self.setup_dataset_transforms()
# Step 5 - Set up the test and train datasets, based on the provided indices
self.setup_dataset_split()
logger.info(f"Train and validation datasets set up.")
# Step 6 - Set up the model, if not provided
if not self.model:
self.setup_model()
logger.info(f"Model loaded.")
# Step 6.5 - Setup parameter transform
self.setup_parameter_transforms()
# Step 7 - Set up the optimizer
self.setup_optimizer()
logger.info(f"Optimizer loaded.")
# Step 8 - Save the configuration for future
self.save_config()
[docs]
def seed_all(self):
"""
Seed all the random number generators.
"""
np.random.seed(self.workspace["seed"])
random.seed(self.workspace["seed"])
# torch.manual_seed(self.seed) # enable torch seed in respective children
# torch.cuda.manual_seed_all(self.seed)
# torch.backends.cudnn.deterministic = True
# torch.backends.cudnn.benchmark = False
[docs]
def setup_workspace(self):
"""
Check all the existing runs in the root directory and see if it finished the run
or not. If it is finished, it will start a new run. If it is not finished, it will
resume the training. If the resume is not requested, it will start a new run.
"""
dir_list = sorted(
glob(f"{self.workspace['name']}/{self.model_manifest['name']}*")
)
dir_list = [p for p in dir_list if os.path.isdir(p)]
if len(dir_list) == 0 or not self.workspace["resume"]:
logger.info(
"Either a fresh run or resume is not requested. Starting a new run."
)
self.current["appending_to_previous_run"] = False
self.current["run_dir"] = (
f"{self.workspace['name']}/{self.current['run_title']}"
)
os.makedirs(self.current["run_dir"], exist_ok=True)
else:
last_dir = dir_list[-1]
was_it_finished = os.path.exists(f"{last_dir}/.finished")
if was_it_finished: # start new run
logger.warning(
"Resuming from last training was requested, but it was completed. Exiting."
)
# current_run_dir = (
# f"{self.workspace['name']}/{self.current['run_title']}"
# )
# os.makedirs(current_run_dir, exist_ok=True)
# self.current["appending_to_previous_run"] = False
sys.exit()
else:
logger.info("Last trainer was not finished. Resuming the training.")
self.current["appending_to_previous_run"] = True
self.current["run_dir"] = dir_list[-1]
# make dataset directory
self.current["data_dir"] = f"{self.workspace['name']}/datasets"
os.makedirs(self.current["data_dir"], exist_ok=True)
[docs]
def setup_dataset(self):
"""
Set up the dataset based on the provided information. If the per atom prediction
logging is requested, it will also assign a sequential index to each configuration
for logging.
TODO: ColabFit integration for extreme scale datasets.
"""
self.dataset = Dataset.get_dataset_from_manifest(self.dataset_manifest)
weights = self.loss_manifest.get("weights", None)
if weights is not None:
if isinstance(weights, str):
weights = Path(weights)
Dataset.add_weights(self.dataset, weights)
elif isinstance(weights, dict):
weights = Weight(
config_weight=weights.get("config", None),
energy_weight=weights.get("energy", None),
forces_weight=weights.get("forces", None),
stress_weight=weights.get("stress", None),
)
if isinstance(weights.forces_weight, list):
weights.forces_weight = np.array(weights.forces_weight)
if isinstance(weights.config_weight, list):
weights.config_weight = np.array(weights.config_weight)
Dataset.add_weights(self.dataset, weights)
else:
raise TrainerError("Weights must be a path or a dictionary.")
# index the dataset
# TODO: Can use identifier from the configuration?
if self.current["log_per_atom_pred"]:
for idx, config in enumerate(self.dataset):
config.metadata |= {"index": idx}
# TODO: add lmdb to the requirements
import lmdb # conditional import, only needed for per-atom predictions
map_size = os.environ.get("KLIFF_LMDB_MAP_SIZE", 1e12)
map_size = int(map_size)
self.current["per_atom_pred_database"] = lmdb.open(
f"{self.current['run_dir']}/per_atom_pred_database.lmdb",
map_size=map_size,
subdir=False,
)
[docs]
def save_config(self):
"""
Hash and save the configuration to the current run directory.
"""
config_hash = self.get_trainer_hash()
config_file = f"{self.current['run_dir']}/{config_hash}.yaml"
with open(config_file, "w") as f:
yaml.dump(self.config_to_dict(), f, default_flow_style=False)
logger.info(f"Configuration saved in {config_file}.")
[docs]
def setup_model(self):
"""
Set up the model based on the provided information. If the model is not provided,
it will be loaded from the model_path. If the model_path is not provided, it will
raise an error. If the model_type is KIM, it will be loaded from the KIM model
repository. If KIM type model is installed in CWD, it will be loaded from there, and
model_path will be set to the KIM CWD. If model is of type TAR, it will be untarred
and model_path will be set to the untarred directory. Left for the derived classes
to implement.
"""
raise TrainerError("setup_model not implemented.")
[docs]
def setup_optimizer(self):
"""
Set up the optimizer based on the provided information. If the optimizer is not
provided, it will be loaded from the optimizer_name. If the optimizer_name is not
provided, it will raise an error. If the optimizer_provider is scipy, it will be
loaded from the scipy.optimize. If the optimizer_provider is torch, it will be
loaded from the torch.optim. Left for the derived classes to implement.
"""
raise TrainerError("setup_optimizer not implemented.")
[docs]
def setup_dataset_split(self):
"""
Simple test train split for now, will have more options like stratification
in the future.
"""
# test train splits
train_size = self.training_manifest.get("training_dataset", {}).get(
"train_size", len(self.dataset)
)
val_size = self.training_manifest.get("validation_dataset", {}).get(
"val_size", 0
)
# sanity checks
if not isinstance(train_size, int) or train_size < 1:
logger.warning(
"Train size is not provided or is less than 1. Using full dataset for training."
)
train_size = len(self.dataset)
else:
logger.info(f"Training dataset size: {train_size}")
if not isinstance(val_size, int) or val_size < 0:
logger.warning(
"Val size is not provided or is less than 0. Using 0 for validation."
)
val_size = 0
else:
logger.info(f"Validation dataset size: {val_size}")
if train_size + val_size > len(self.dataset):
raise TrainerError(
"Sum of train, val, and test sizes is greater than the dataset size."
)
# check if indices are provided
train_indices = self.dataset_sample_manifest.get("train_indices")
val_indices = None
if isinstance(train_indices, str):
train_indices = np.genfromtxt(train_indices, dtype=int)
if val_size > 0:
val_indices = np.genfromtxt(
self.dataset_sample_manifest.get("val_indices"), dtype=int
)
else:
TrainerError(f"Could not load indices from {train_indices}.")
if train_indices is None:
indices = np.random.permutation(len(self.dataset))
train_indices = indices[:train_size]
if val_size > 0:
val_indices = indices[-val_size:]
self.dataset_sample_manifest["train_size"] = train_size
self.dataset_sample_manifest["val_size"] = val_size
self.dataset_sample_manifest["train_indices"] = train_indices
self.dataset_sample_manifest["val_indices"] = val_indices
train_dataset = self.dataset[train_indices]
self.train_dataset = train_dataset
if val_size > 0:
val_dataset = self.dataset[val_indices]
self.val_dataset = val_dataset
else:
self.val_dataset = None
# save the indices if generated
if isinstance(train_indices, str):
self.dataset_sample_manifest["indices_files"]["train"] = train_indices
else:
self.dataset_sample_manifest["indices_files"][
"train"
] = f"{self.current['run_dir']}/train_indices.txt"
np.savetxt(
self.dataset_sample_manifest["indices_files"]["train"],
train_indices,
fmt="%d",
)
if val_size > 0:
if isinstance(val_indices, str):
self.dataset_sample_manifest["indices_files"]["val"] = val_indices
else:
self.dataset_sample_manifest["indices_files"][
"val"
] = f"{self.current['run_dir']}/val_indices.txt"
np.savetxt(
self.dataset_sample_manifest["indices_files"]["val"],
val_indices,
fmt="%d",
)
[docs]
def log_per_atom_outputs(
self,
epoch: int,
idx: Union[List[int], np.ndarray],
predictions: List[np.ndarray],
):
"""
Log the per atom outputs to the database. It saves dictionary of predictions and
n_atoms for each configuration. The key for predictions is pred_{n}, where n is
the index of the prediction. For more than one prediction, it will save pred_0,
pred_1, pred_2, etc. The key for the indices is idx
Args:
epoch: Current epoch
idx: List of indices of the configurations
predictions: List of predictions for the configurations
"""
if self.current["per_atom_pred_database"] is None:
return
with self.current["per_atom_pred_database"].begin(write=True) as txn:
for ids, pred in zip(idx, predictions):
if pred is not None:
txn.put(
f"epoch_{epoch}|index_{ids}".encode(),
pkl.dumps({"pred_0": pred, "n_atoms": pred.shape[0]}),
)
else:
logger.warning(f"Prediction for index {ids} is None. Skipping.")
[docs]
def loss(self, *args, **kwargs):
raise TrainerError("loss not implemented.")
[docs]
def checkpoint(self, *args, **kwargs):
raise TrainerError("checkpoint not implemented.")
[docs]
def train_step(self, *args, **kwargs):
raise TrainerError("train_step not implemented.")
[docs]
def validation_step(self, *args, **kwargs):
raise TrainerError("validation_step not implemented.")
[docs]
def get_optimizer(self, *args, **kwargs):
raise TrainerError("get_optimizer not implemented.")
[docs]
def train(self, *args, **kwargs):
raise TrainerError("train not implemented.")
[docs]
def save_kim_model(self, *args, **kwargs):
raise TrainerError("save_kim_model not implemented.")
@staticmethod
def _generate_kim_cmake(model_name: str, driver_name: str, file_list: List) -> str:
"""
Generate the CMakeLists.txt file for KIM API. This will be used to compile the
driver with the KIM API. The driver name is the name of the driver, and the file
list is the list of files to be included in the CMakeLists.txt file.
Private method.
Args:
driver_name: Name of the driver
file_list: List of files to be included in the CMakeLists.txt file
Returns:
CMakeLists.txt file as a string
"""
model_name = model_name.replace("-", "_")
cmake = f"""
cmake_minimum_required(VERSION 3.10)
list(APPEND CMAKE_PREFIX_PATH $ENV{{KIM_API_CMAKE_PREFIX_DIR}})
find_package(KIM-API-ITEMS 2.2 REQUIRED CONFIG)
kim_api_items_setup_before_project(ITEM_TYPE "portableModel")
project({model_name})
kim_api_items_setup_after_project(ITEM_TYPE "portableModel")
add_kim_api_model_library(
NAME ${{PROJECT_NAME}}
DRIVER_NAME "{driver_name}"
PARAMETER_FILES {" ".join(file_list)}
)
"""
return cmake
[docs]
def write_training_env_edn(self, path: str):
"""
Generate the training_env.edn file for the KIM API. This file will be used to
accurately determine the training environment . The file will be saved in the current run directory.
It saves the hash of the configuration, and list of all python dependencies from
pip freeze.
"""
env_file = f"{path}/training_env.edn"
hash = self.get_trainer_hash()
with open(env_file, "w") as f:
try:
from pip._internal.operations.freeze import freeze
from kliff import __version__
except ImportError:
logger.warning(
"Could not import kliff version or pip freeze. Skipping."
)
return
python_env = []
for module in list(freeze()):
if "@" in module:
module = module.split("@")[0]
python_env.append(module)
f.write("{\n")
f.write(f'"kliff-version" "{__version__}"\n')
f.write(f'"trainer-used" "{type(self).__name__}"\n')
f.write(f'"manifest-hash" "{hash}"\n')
f.write(f'"python-dependencies" [\n')
for module in python_env:
f.write(f' "{module}"\n')
f.write(f"]\n")
f.write("}\n")
[docs]
class TrainerError(Exception):
"""
Exceptions to be raised in Trainer and associated classes.
"""
def __init__(self, message):
super().__init__(message)
# TODO:
# 1. Test dataset
# 2. Stress
# 3. Get top k models