from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
from loguru import logger
from torch.utils.data import DataLoader
from kliff.dataset.dataset import Configuration
from kliff.dataset.dataset_torch import FingerprintsDataset, fingerprints_collate_fn
from kliff.models.model_torch import ModelTorch
from kliff.models.neural_network import NeuralNetwork
from kliff.utils import pickle_load, to_path
class CalculatorTorch:
    """
    A calculator for torch based models.

    Args:
        model: torch models, e.g. :class:`~kliff.neuralnetwork.NeuralNetwork`.
        gpu: whether to use gpu for training. If `int` (e.g. 0), will train on this
            gpu device. If `True` will always train on gpu `0`.
    """

    implemented_property = ["energy", "forces", "stress"]

    def __init__(self, model: ModelTorch, gpu: Optional[Union[bool, int]] = None):
        device = _get_device(gpu)
        self._model = model.to(device)
        self.dtype = self.model.descriptor.dtype

        # Filled in by `create()`.
        self.fingerprints_path = None
        self.use_energy = None
        self.use_forces = None
        self.use_stress = None

        # Values produced by the most recent `compute()` call, keyed by property.
        self.results = {name: None for name in self.implemented_property}

    def create(
        self,
        configs: List[Configuration],
        use_energy: bool = True,
        use_forces: bool = True,
        use_stress: bool = False,
        fingerprints_filename: Union[Path, str] = "fingerprints.pkl",
        fingerprints_mean_stdev_filename: Optional[Union[Path, str]] = None,
        reuse: bool = False,
        use_welford_method: bool = False,
        nprocs: int = 1,
    ):
        """
        Process configs to generate fingerprints.

        Args:
            configs: atomic configurations
            use_energy: Whether to require the calculator to compute energy.
            use_forces: Whether to require the calculator to compute forces.
            use_stress: Whether to require the calculator to compute stress.
            fingerprints_filename: Path to save the generated fingerprints.
                If `reuse=True`, will not generate the fingerprints, but directly use
                the one provided via this file.
            fingerprints_mean_stdev_filename: Path to save the mean and standard
                deviation of the fingerprints. If `reuse=True`, will not generate new
                fingerprints mean and stdev, but directly use the one provided via
                this file. If `normalize` is not required by a descriptor, this is
                ignored.
            reuse: Whether to reuse provided fingerprints.
            use_welford_method: Whether to compute mean and standard deviation using
                the Welford method, which is memory efficient. See
                https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
            nprocs: Number of processes used to generate the fingerprints. If `1`, run
                in serial mode, otherwise `nprocs` processes will be forked via
                multiprocessing to do the work.

        Raises:
            CalculatorTorchError: if `reuse=True` and the fingerprints file (or, for a
                normalizing descriptor, the mean/stdev file) does not exist.
        """
        self.configs = configs
        self.use_energy = use_energy
        self.use_forces = use_forces
        self.use_stress = use_stress

        # Accept a single configuration as well as a list of them.
        if isinstance(configs, Configuration):
            configs = [configs]

        if reuse:
            # Reuse an existing fingerprints file.
            self.fingerprints_path = to_path(fingerprints_filename)
            if not self.fingerprints_path.exists():
                raise CalculatorTorchError(
                    f"You specified `reuse=True` to reuse the fingerprints stored in "
                    f"`{self.fingerprints_path}` This file does not exists."
                )
            logger.info(f"Reuse fingerprints `{self.fingerprints_path}`")

            if self.model.descriptor.normalize:
                # A normalizing descriptor also needs its stored mean and stdev.
                path = (
                    None
                    if fingerprints_mean_stdev_filename is None
                    else to_path(fingerprints_mean_stdev_filename)
                )
                if path is None or not path.exists():
                    raise CalculatorTorchError(
                        f"You specified `reuse=True` to reuse the fingerprints. The "
                        f"mean and stdev file of the fingerprints `{path}` does not "
                        "exists."
                    )
                self.model.descriptor.load_state_dict(pickle_load(path))
                logger.info(f"Reuse fingerprints mean and stdev `{path}`")
        else:
            # Generate the fingerprints and pickle them.
            self.fingerprints_path = self.model.descriptor.generate_fingerprints(
                configs,
                use_forces,
                use_stress,
                fingerprints_filename,
                fingerprints_mean_stdev_filename,
                use_welford_method,
                nprocs,
            )

        # Finally, expose the fingerprints as a FingerprintsDataset instance.
        self.fingerprints_dataset = FingerprintsDataset(self.fingerprints_path)

    def get_fingerprints(self) -> List[dict]:
        """
        Return a list of fingerprints of the configurations.
        """
        return self.fingerprints_dataset.fp

    def get_compute_arguments(self, batch_size: int = 1):
        """
        Return the dataloader with batch size set to `batch_size`.
        """
        loader = DataLoader(
            dataset=self.fingerprints_dataset,
            batch_size=batch_size,
            collate_fn=fingerprints_collate_fn,
        )
        return loader

    def set_fingerprints(self, fingerprints: List[dict]):
        """
        Update the fingerprints of the calculator.

        The fingerprints input argument should be in the same format as the output of
        :meth:`~kliff.legacy.descriptors.descriptor.load_fingerprints`, which is a
        list of dictionaries.

        Args:
            fingerprints: A list of fingerprints.
        """
        self.fingerprints_dataset.fp = fingerprints

    def fit(self):
        """
        Call the model's `fit` with the path of the fingerprints file.
        """
        path = self.fingerprints_path
        self.model.fit(path)

    def compute(self, batch):
        """
        Evaluate the model on a batch of fingerprint samples.

        Args:
            batch: A sequence of samples. Each sample is indexed by ``"zeta"`` and,
                depending on `use_forces` / `use_stress`, by ``"dzetadr_forces"``,
                ``"dzetadr_stress"`` and ``"dzetadr_volume"``; these values are
                expected to be torch tensors.

        Returns:
            A dict with keys ``"energy"``, ``"forces"`` and ``"stress"``. Each value
            is a list with one entry per sample; ``"forces"`` / ``"stress"`` are
            `None` when the corresponding property is not requested. The same values
            are stored in `self.results`.
        """
        #
        # shape N--number of atoms in a config; D--feature dim
        # zeta: (N, D)
        # dzetadr_force: (N, D, 3N)
        # dzetadr_stress: (N, D, 6)
        #
        # batching dzetadr_force seems difficult, because two axes have different size
        # this seems doable, combine N and 3N as one dim, and use einstein sum
        device = self.model.device
        grad = self.use_forces or self.use_stress

        # TODO, the batching should be moved to dataloader
        # get information from batch
        zeta_config = [sample["zeta"] for sample in batch]
        zeta_stacked = torch.cat(zeta_config, dim=0).to(device)

        # evaluate model
        if grad:
            zeta_stacked.requires_grad_(True)
        energy_atom = self.model(zeta_stacked)

        # forces and stress containers (None when the property is not requested)
        forces_config = [] if self.use_forces else None
        stress_config = [] if self.use_stress else None

        # per-configuration energy: sum of the atomic energies of each config
        natoms_config = [len(zeta) for zeta in zeta_config]
        energy_config = [e.sum() for e in torch.split(energy_atom, natoms_config)]

        if grad:
            dedzeta = torch.autograd.grad(
                energy_atom.sum(), zeta_stacked, create_graph=True
            )[0]
            zeta_stacked.requires_grad_(False)  # no need of grad any more
            dedzeta_config = torch.split(dedzeta, natoms_config)

            for dedz, sample in zip(dedzeta_config, batch):
                if self.use_forces:
                    dzetadr_forces = sample["dzetadr_forces"].to(device)
                    forces_config.append(self._compute_forces(dedz, dzetadr_forces))

                if self.use_stress:
                    dzetadr_stress = sample["dzetadr_stress"].to(device)
                    # Keep the volume on the same device as the other operands.
                    volume = sample["dzetadr_volume"].to(device)
                    stress_config.append(
                        self._compute_stress(dedz, dzetadr_stress, volume)
                    )

        self.results["energy"] = energy_config
        self.results["forces"] = forces_config
        self.results["stress"] = stress_config

        return {
            "energy": energy_config,
            "forces": forces_config,
            "stress": stress_config,
        }

    @property
    def model(self):
        """Get the underlying torch model"""
        return self._model

    def save_model(self, epoch: int, force_save: bool = False):
        """
        Save the model to disk.

        When to save a model is dependent on `epoch` and a model's metadata for save.

        Args:
            epoch: current optimization epoch.
            force_save: save the model, ignoring `epoch` and save metadata.
        """
        # save metadata
        save_prefix = self.model.save_prefix
        save_start = self.model.save_start
        save_frequency = self.model.save_frequency

        path = to_path(save_prefix).joinpath(f"model_epoch{epoch}.pkl")

        # `force_save` short-circuits, so the schedule is only consulted otherwise.
        if force_save or (
            epoch >= save_start and (epoch - save_start) % save_frequency == 0
        ):
            self.model.save(path)

    def get_energy(self, batch):
        """Return the energies stored by the last `compute` call (`batch` is unused)."""
        return self.results["energy"]

    def get_forces(self, batch):
        """Return the forces stored by the last `compute` call (`batch` is unused)."""
        return self.results["forces"]

    def get_stress(self, batch):
        """Return the stress stored by the last `compute` call (`batch` is unused)."""
        return self.results["stress"]

    @staticmethod
    def _compute_forces(denergy_dzeta, dzetadr):
        """Contract dE/dzeta with dzeta/dr over their first two axes and negate."""
        forces = -torch.tensordot(denergy_dzeta, dzetadr, dims=([0, 1], [0, 1]))
        return forces

    @staticmethod
    def _compute_stress(denergy_dzeta, dzetadr, volume):
        """Contract dE/dzeta with dzeta/dr over their first two axes, divided by volume."""
        stress = torch.tensordot(denergy_dzeta, dzetadr, dims=([0, 1], [0, 1])) / volume
        return stress

    def get_size_opt_params(self) -> Tuple[List[int], List[int], int]:
        """
        Return the size of the parameters.

        Returns:
            sizes: Each element in the list gives the shape of each type of parameter
                tensors, containing, e.g., weights and biases, for each layer.
            nelements: Number of elements of each parameter tensor.
            nparams: Total number of parameters
        """
        params = list(self.model.parameters())
        sizes = [param.size() for param in params]  # Size of each parameter tensor
        # `numel()` gives plain Python ints, also for 0-dim parameters.
        nelements = [param.numel() for param in params]
        nparams = sum(nelements)

        return sizes, nelements, nparams

    def get_num_opt_params(self) -> int:
        """
        Return the total number of parameters.
        """
        return self.get_size_opt_params()[-1]

    def get_opt_params(self, flat: bool = True) -> Union[List, np.ndarray]:
        """
        Retrieve the parameters, i.e., weights and biases.

        Args:
            flat: A flag to return a flat, 1D array.

        Returns:
            Parameters, i.e., weights and biases. If ``flat=True``, a 1D np.ndarray
            will be returned (empty if the model has no parameters). Otherwise, a list
            of the model's parameter tensors will be returned.
        """
        params = list(self.model.parameters())
        if not flat:
            return params

        # Make sure that the parameters are stored in host memory before converting.
        chunks = [param.detach().cpu().numpy().ravel() for param in params]
        # The leading empty float64 array keeps the result at least double precision
        # and makes the no-parameter case return an empty array; one concatenate
        # avoids re-copying the growing array for every parameter.
        return np.concatenate([np.empty(0, dtype=np.float64)] + chunks)

    def update_model_params(self, parameters: np.array):
        """
        Update the model parameters from a 1D array.

        The dtype and device of every model parameter are preserved.

        Args:
            parameters: New parameter values to set. It needs to be a 1D array.
        """
        # Convert to the right format
        new_values = self._convert_parameters_from_1d_array(parameters)

        # Update the weights and biases
        for param, value in zip(self.model.parameters(), new_values):
            param.data = value

    def _convert_parameters_from_1d_array(self, flat_params: np.array) -> List:
        """
        Convert the parameters from a 1D array format to a list of tensors.

        Args:
            flat_params: A 1D array containing weights and biases of the model.

        Returns:
            parameters: Parameters (weights and biases) as a list of tensors, one per
                model parameter, each with that parameter's shape, dtype and device.
        """
        flat_params = np.asarray(flat_params)

        parameters = []
        start = 0
        for param in self.model.parameters():
            stop = start + param.numel()
            chunk = flat_params[start:stop].reshape(tuple(param.size()))
            # `torch.tensor` copies the data; requesting the parameter's own dtype and
            # device avoids silently falling back to a float32 CPU tensor.
            parameters.append(
                torch.tensor(chunk, dtype=param.dtype, device=param.device)
            )
            start = stop

        return parameters
class CalculatorTorchSeparateSpecies(CalculatorTorch):
    """
    A calculator supporting models of different species.

    Args:
        models: {species: model} with species specifying the chemical symbol for the
            model.
        gpu: whether to use gpu for training. If `int` (e.g. 0), will train on this
            gpu device. If `True` will always train on gpu `0`.
    """

    def __init__(
        self, models: Dict[str, NeuralNetwork], gpu: Optional[Union[bool, int]] = None
    ):
        device = _get_device(gpu)

        self.models = models
        self.dtype = None

        # Move every model to the device and require a single descriptor dtype.
        for m in self.models.values():
            m.to(device)
            if self.dtype is None:
                self.dtype = m.descriptor.dtype
            elif self.dtype != m.descriptor.dtype:
                raise CalculatorTorchError("inconsistent `dtype` from descriptors.")

        # Single object exposing the parameters/device/descriptor of all models.
        self._model = _ModelWrapper(models)

        # Filled in by `create()`.
        self.fingerprints_path = None
        self.use_energy = None
        self.use_forces = None
        self.use_stress = None

        # Values produced by the most recent `compute()` call, keyed by property.
        self.results = {name: None for name in self.implemented_property}

    def compute(self, batch):
        """
        Evaluate the per-species models on a batch of fingerprint samples.

        Args:
            batch: A sequence of samples. Each sample is indexed by ``"zeta"`` and
                ``"configuration"`` (whose `species` gives one symbol per row of
                ``"zeta"``) and, depending on `use_forces` / `use_stress`, by
                ``"dzetadr_forces"``, ``"dzetadr_stress"`` and ``"dzetadr_volume"``.

        Returns:
            A dict with keys ``"energy"``, ``"forces"`` and ``"stress"``. Each value
            is a list with one entry per sample; ``"forces"`` / ``"stress"`` are
            `None` when the corresponding property is not requested. The same values
            are stored in `self.results`.

        Raises:
            CalculatorTorchError: if a sample contains a species without a model.
        """
        device = self.model.device
        grad = self.use_forces or self.use_stress

        # collate batch by species
        supported_species = self.models.keys()
        zeta_by_species = {s: [] for s in supported_species}
        config_id_by_species = {s: [] for s in supported_species}
        zeta_config = []

        for i, sample in enumerate(batch):
            zeta = sample["zeta"].to(device)
            species = sample["configuration"].species
            zeta.requires_grad_(True)
            zeta_config.append(zeta)

            for s, z in zip(species, zeta):
                if s not in supported_species:
                    raise CalculatorTorchError(f"No model for species: {s}")
                zeta_by_species[s].append(z)
                config_id_by_species[s].append(i)

        # evaluate model to compute energy
        energy_config = [None for _ in range(len(batch))]
        for s, zeta in zeta_by_species.items():
            # have no species "s" in this batch of data
            if not zeta:  # zeta == []
                continue

            z_tensor = torch.stack(zeta)  # convert a list of tensor to tensor
            energy = self.models[s](z_tensor)

            # accumulate each atomic energy into the config the atom belongs to
            for e_atom, i in zip(energy, config_id_by_species[s]):
                if energy_config[i] is None:
                    energy_config[i] = e_atom
                else:
                    # note cannot use +=, energy e_atom is a view
                    energy_config[i] = energy_config[i] + e_atom

        # forces and stress containers (None when the property is not requested)
        forces_config = [] if self.use_forces else None
        stress_config = [] if self.use_stress else None

        if grad:
            for i, sample in enumerate(batch):
                # derivative of energy w.r.t. zeta
                energy = energy_config[i]
                zeta = zeta_config[i]
                dedz = torch.autograd.grad(energy, zeta, create_graph=True)[0]
                zeta.requires_grad_(False)  # no need of grad any more

                if self.use_forces:
                    dzetadr_forces = sample["dzetadr_forces"].to(device)
                    forces_config.append(self._compute_forces(dedz, dzetadr_forces))

                if self.use_stress:
                    # Must live on the same device as `dedz` for the contraction.
                    dzetadr_stress = sample["dzetadr_stress"].to(device)
                    volume = sample["dzetadr_volume"].to(device)
                    stress_config.append(
                        self._compute_stress(dedz, dzetadr_stress, volume)
                    )

        self.results["energy"] = energy_config
        self.results["forces"] = forces_config
        self.results["stress"] = stress_config

        return {
            "energy": energy_config,
            "forces": forces_config,
            "stress": stress_config,
        }

    @property
    def model(self):
        """Get the wrapper holding the models of all species."""
        return self._model

    def save_model(self, epoch: int, force_save: bool = False):
        """
        Save the models to disk.

        When to save a model is dependent on `epoch` and a model's metadata for save.

        Args:
            epoch: current optimization epoch.
            force_save: save the model, ignoring `epoch` and save metadata.
        """
        # Each species model carries its own save metadata and gets its own file.
        for name, model in self.models.items():
            save_prefix = model.save_prefix
            save_start = model.save_start
            save_frequency = model.save_frequency

            path = to_path(save_prefix).joinpath(f"model_{name}_epoch{epoch}.pkl")

            # `force_save` short-circuits, so the schedule is only consulted otherwise.
            if force_save or (
                epoch >= save_start and (epoch - save_start) % save_frequency == 0
            ):
                model.save(path)
class _ModelWrapper(torch.nn.Module):
"""
A wrapper over multiple torch models.
Only add necessary properties:
- `LossNeuralNetworkModel` uses `calculator.model.parameters()` and
- `calculator.model.device`, and the model wrapper only need to provide them.
- descriptor: needed by model create
"""
def __init__(self, models: Dict[str, torch.nn.Module]):
super().__init__()
self._models = torch.nn.ModuleDict(models)
first_model = list(models.values())[0]
# Assuming all models using the same descriptor as in the example_NN_SiC.py
# example, then it's OK to set it to the descriptor of the first model.
self._descriptor = first_model.descriptor
@property
def device(self):
return next(self.parameters()).device
@property
def descriptor(self):
return self._descriptor
# class CalculatorTorchDDP(CalculatorTorch):
# def __init__(self, model, rank, world_size):
# super(self).__init__(model)
# self.set_up(rank, world_size)
#
# def set_up(self, rank, world_size):
# os.environ["MASTER_ADDR"] = "localhost"
# os.environ["MASTER_PORT"] = "12355"
# dist.init_process_group("gloo", rank=rank, world_size=world_size)
#
# def clean_up(self):
# dist.destroy_process_group()
#
# def compute(self, batch):
# grad = self.use_forces
#
# # collate batch input to NN
# zeta_config = self._collate(batch, "zeta")
# if grad:
# for zeta in zeta_config:
# zeta.requires_grad_(True)
# zeta_stacked = torch.cat(zeta_config, dim=0)
#
# # evaluate model
# model = DistributedDataParallel(self.model)
# energy_atom = model(zeta_stacked)
#
# # energy
# natoms_config = [len(zeta) for zeta in zeta_config]
# energy_config = [e.sum() for e in torch.split(energy_atom, natoms_config)]
#
# # forces
# if grad:
# dzetadr_config = self._collate(batch, "dzetadr")
# forces_config = self.compute_forces_config(
# energy_config, zeta_config, dzetadr_config
# )
# for zeta in zeta_config:
# zeta.requires_grad_(False)
# else:
# forces_config = None
#
# return {"energy": energy_config, "forces": forces_config}
#
# def __del__(self):
# self.clean_up()
class CalculatorTorchError(Exception):
    """Error raised by the torch calculators; the message is kept in `msg`."""

    def __init__(self, msg):
        super().__init__(msg)
        self.msg = msg
def _get_device(gpu):
    """
    Translate the `gpu` option into a torch device and log the choice.

    Args:
        gpu: `True` selects device index 0, a non-bool `int` selects the device with
            that index; any other value (e.g. `False` or `None`) selects no device.

    Returns:
        The selected `torch.device`, or `None` when no device is selected.
    """
    selected = None

    # `bool` is a subclass of `int`, so it has to be checked first.
    if isinstance(gpu, bool):
        if gpu:
            selected = torch.device(0)
            logger.info("Training on gpu")
    elif isinstance(gpu, int):
        selected = torch.device(gpu)
        logger.info(f"Training on gpu {gpu}")

    if selected is None:
        logger.info("Training on cpu")

    return selected