Source code for kliff.calculators.calculator_torch

from pathlib import Path
from typing import Dict, List, Optional, Union

import torch
from loguru import logger
from torch.utils.data import DataLoader

from kliff.dataset.dataset import Configuration
from kliff.dataset.dataset_torch import FingerprintsDataset, fingerprints_collate_fn
from kliff.models.model_torch import ModelTorch
from kliff.models.neural_network import NeuralNetwork
from kliff.utils import pickle_load, to_path


[docs]class CalculatorTorch: """ A calculator for torch based models. Args: model: torch models, e.g. :class:`~kliff.neuralnetwork.NeuralNetwork`. gpu: whether to use gpu for training. If `int` (e.g. 0), will trained on this gpu device. If `True` will always train on gpu `0`. """ implemented_property = ["energy", "forces", "stress"] def __init__(self, model: ModelTorch, gpu: Union[bool, int] = None): device = _get_device(gpu) self._model = model.to(device) self.dtype = self.model.descriptor.dtype self.fingerprints_path = None self.use_energy = None self.use_forces = None self.use_stress = None self.results = dict([(i, None) for i in self.implemented_property])
[docs] def create( self, configs: List[Configuration], use_energy: bool = True, use_forces: bool = True, use_stress: bool = False, fingerprints_filename: Union[Path, str] = "fingerprints.pkl", fingerprints_mean_stdev_filename: Optional[Union[Path, str]] = None, reuse: bool = False, use_welford_method: bool = False, nprocs: int = 1, ): """ Process configs to generate fingerprints. Args: configs: atomic configurations use_energy: Whether to require the calculator to compute energy. use_forces: Whether to require the calculator to compute forces. use_stress: Whether to require the calculator to compute stress. fingerprints_filename: Path to save the generated fingerprints. If `reuse=True`, Will not generate the fingerprints, but directly use the one provided via this file. fingerprints_mean_stdev_filename: Path to save the mean and standard deviation of the fingerprints. If `reuse=True`, Will not generate new fingerprints mean and stdev, but directly use the one provided via this file. If `normalize` is not required by a descriptor, this is ignored. reuse: Whether to reuse provided fingerprints. use_welford_method: Whether to compute mean and standard deviation using the Welford method, which is memory efficient. See https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance nprocs: Number of processes used to generate the fingerprints. If `1`, run in serial mode, otherwise `nprocs` processes will be forked via multiprocessing to do the work. """ self.configs = configs self.use_energy = use_energy self.use_forces = use_forces self.use_stress = use_stress if isinstance(configs, Configuration): configs = [configs] # reuse existing file if reuse: self.fingerprints_path = to_path(fingerprints_filename) if not self.fingerprints_path.exists(): raise CalculatorTorchError( f"You specified `reuse=True` to reuse the fingerprints stored in " f"`{self.fingerprints_path}` This file does not exists." ) logger.info(f"Reuse fingerprints `{self.fingerprints_path}`") if self.model.descriptor.normalize: path = ( None if fingerprints_mean_stdev_filename is None else to_path(fingerprints_mean_stdev_filename) ) if path is None or not path.exists(): raise CalculatorTorchError( f"You specified `reuse=True` to reuse the fingerprints. The " f"mean and stdev file of the fingerprints `{path}` does not " "exists." ) self.model.descriptor.load_state_dict(pickle_load(path)) logger.info(f"Reuse fingerprints mean and stdev `{path}`") # generate fingerprints and pickle it else: self.fingerprints_path = self.model.descriptor.generate_fingerprints( configs, use_forces, use_stress, fingerprints_filename, fingerprints_mean_stdev_filename, use_welford_method, nprocs, )
[docs] def get_compute_arguments(self, batch_size: int = 1): """ Return the dataloader with batch size set to `batch_size`. """ fname = self.fingerprints_path fp = FingerprintsDataset(fname) loader = DataLoader( dataset=fp, batch_size=batch_size, collate_fn=fingerprints_collate_fn ) return loader
[docs] def fit(self): path = self.fingerprints_path self.model.fit(path)
[docs] def compute(self, batch): # # shape N--number of atoms in a config; D--feature dim # zeta: (N, D) # dzetadr_force: (N, D, 3N) # dzetadr_stress: (N, D, 6) # # batching dzetadr_force seems difficult, because two axes have different size # this seems doable, combine N and 3N as one dim, and use einstein sum device = self.model.device grad = self.use_forces or self.use_stress # TODO, the batching should be moved to dataloader # get information from batch zeta_config = [sample["zeta"] for sample in batch] zeta_stacked = torch.cat(zeta_config, dim=0).to(device) # evaluate model if grad: zeta_stacked.requires_grad_(True) energy_atom = self.model(zeta_stacked) # forces and stress if not self.use_forces: forces_config = None else: forces_config = [] if not self.use_stress: stress_config = None else: stress_config = [] natoms_config = [len(zeta) for zeta in zeta_config] energy_config = [e.sum() for e in torch.split(energy_atom, natoms_config)] if grad: dedzeta = torch.autograd.grad( energy_atom.sum(), zeta_stacked, create_graph=True )[0] zeta_stacked.requires_grad_(False) # no need of grad any more dedzeta_config = torch.split(dedzeta, natoms_config) for i, sample in enumerate(batch): dedz = dedzeta_config[i] if self.use_forces: dzetadr_forces = sample["dzetadr_forces"].to(device) f = self._compute_forces(dedz, dzetadr_forces) forces_config.append(f) if self.use_stress: dzetadr_stress = sample["dzetadr_stress"].to(device) volume = sample["dzetadr_volume"] s = self._compute_stress(dedz, dzetadr_stress, volume) stress_config.append(s) self.results["energy"] = energy_config self.results["forces"] = forces_config self.results["stress"] = stress_config return { "energy": energy_config, "forces": forces_config, "stress": stress_config, }
@property def model(self): """Get the underlying torch model""" return self._model
[docs] def save_model(self, epoch: int, force_save: bool = False): """ Save the model to disk. When to save a model is dependent on `epoch` and a model's metadata for save. Args: epoch: current optimization epoch. force_save: save the model, ignoring `epoch` and save metadata. """ # save metadata save_prefix = self.model.save_prefix save_start = self.model.save_start save_frequency = self.model.save_frequency path = to_path(save_prefix).joinpath(f"model_epoch{epoch}.pkl") if force_save: self.model.save(path) else: if epoch >= save_start and (epoch - save_start) % save_frequency == 0: self.model.save(path)
[docs] def get_energy(self, batch): return self.results["energy"]
[docs] def get_forces(self, batch): return self.results["forces"]
[docs] def get_stress(self, batch): return self.results["stress"]
@staticmethod def _compute_forces(denergy_dzeta, dzetadr): forces = -torch.tensordot(denergy_dzeta, dzetadr, dims=([0, 1], [0, 1])) return forces @staticmethod def _compute_stress(denergy_dzeta, dzetadr, volume): forces = torch.tensordot(denergy_dzeta, dzetadr, dims=([0, 1], [0, 1])) / volume return forces
class CalculatorTorchSeparateSpecies(CalculatorTorch): """ A calculator supporting models of difference species. Args: models: {species: model} with species specifying the chemical symbol for the model. gpu: whether to use gpu for training. If `int` (e.g. 0), will trained on this gpu device. If `True` will always train on gpu `0`. """ def __init__(self, models: Dict[str, NeuralNetwork], gpu: Union[bool, int] = None): device = _get_device(gpu) self.models = models self.dtype = None for s, m in self.models.items(): m.to(device) if self.dtype is None: self.dtype = m.descriptor.dtype else: if self.dtype != m.descriptor.dtype: raise CalculatorTorchError("inconsistent `dtype` from descriptors.") self._model = _ModelWrapper(models) self.fingerprints_path = None self.use_energy = None self.use_forces = None self.use_stress = None self.results = dict([(i, None) for i in self.implemented_property]) def compute(self, batch): device = self.model.device grad = self.use_forces or self.use_stress # collate batch by species supported_species = self.models.keys() zeta_by_species = {s: [] for s in supported_species} config_id_by_species = {s: [] for s in supported_species} zeta_config = [] for i, sample in enumerate(batch): zeta = sample["zeta"].to(device) species = sample["configuration"].species zeta.requires_grad_(True) zeta_config.append(zeta) for s, z in zip(species, zeta): if s not in supported_species: raise CalculatorTorchError(f"No model for species: {s}") else: zeta_by_species[s].append(z) config_id_by_species[s].append(i) # evaluate model to compute energy energy_config = [None for _ in range(len(batch))] for s, zeta in zeta_by_species.items(): # have no species "s" in this batch of data if not zeta: # zeta == [] continue z_tensor = torch.stack(zeta) # convert a list of tensor to tensor energy = self.models[s](z_tensor) for e_atom, i in zip(energy, config_id_by_species[s]): if energy_config[i] is None: energy_config[i] = e_atom else: # note cannot use +=, energy e_atom is a view energy_config[i] = energy_config[i] + e_atom # forces and stress if not self.use_forces: forces_config = None else: forces_config = [] if not self.use_stress: stress_config = None else: stress_config = [] if grad: for i, sample in enumerate(batch): # derivative of energy w.r.t. zeta energy = energy_config[i] zeta = zeta_config[i] dedz = torch.autograd.grad(energy, zeta, create_graph=True)[0] zeta.requires_grad_(False) # no need of grad any more if self.use_forces: dzetadr_forces = sample["dzetadr_forces"].to(device) f = self._compute_forces(dedz, dzetadr_forces) forces_config.append(f) if self.use_stress: dzetadr_stress = sample["dzetadr_stress"] volume = sample["dzetadr_volume"].to(device) s = self._compute_stress(dedz, dzetadr_stress, volume) stress_config.append(s) self.results["energy"] = energy_config self.results["forces"] = forces_config self.results["stress"] = stress_config return { "energy": energy_config, "forces": forces_config, "stress": stress_config, } @property def model(self): return self._model def save_model(self, epoch: int, force_save: bool = False): """ Save the models to disk. When to save a model is dependent on `epoch` and a model's metadata for save. Args: epoch: current optimization epoch. force_save: save the model, ignoring `epoch` and save metadata. """ # save metadata for name, model in self.models.items(): save_prefix = model.save_prefix save_start = model.save_start save_frequency = model.save_frequency path = to_path(save_prefix).joinpath(f"model_{name}_epoch{epoch}.pkl") if force_save: model.save(path) else: if epoch >= save_start and (epoch - save_start) % save_frequency == 0: model.save(path) class _ModelWrapper(torch.nn.Module): """ A wrapper over multiple torch models. Only add necessary properties: - `LossNeuralNetworkModel` uses `calculator.model.parameters()` and - `calculator.model.device`, and the model wrapper only need to provide them. - descriptor: needed by model create """ def __init__(self, models: Dict[str, torch.nn.Module]): super().__init__() self._models = torch.nn.ModuleDict(models) first_model = list(models.values())[0] # Assuming all models using the same descriptor as in the example_NN_SiC.py # example, then it's OK to set it to the descriptor of the first model. self._descriptor = first_model.descriptor @property def device(self): return next(self.parameters()).device @property def descriptor(self): return self._descriptor # class CalculatorTorchDDP(CalculatorTorch): # def __init__(self, model, rank, world_size): # super(self).__init__(model) # self.set_up(rank, world_size) # # def set_up(self, rank, world_size): # os.environ["MASTER_ADDR"] = "localhost" # os.environ["MASTER_PORT"] = "12355" # dist.init_process_group("gloo", rank=rank, world_size=world_size) # # def clean_up(self): # dist.destroy_process_group() # # def compute(self, batch): # grad = self.use_forces # # # collate batch input to NN # zeta_config = self._collate(batch, "zeta") # if grad: # for zeta in zeta_config: # zeta.requires_grad_(True) # zeta_stacked = torch.cat(zeta_config, dim=0) # # # evaluate model # model = DistributedDataParallel(self.model) # energy_atom = model(zeta_stacked) # # # energy # natoms_config = [len(zeta) for zeta in zeta_config] # energy_config = [e.sum() for e in torch.split(energy_atom, natoms_config)] # # # forces # if grad: # dzetadr_config = self._collate(batch, "dzetadr") # forces_config = self.compute_forces_config( # energy_config, zeta_config, dzetadr_config # ) # for zeta in zeta_config: # zeta.requires_grad_(False) # else: # forces_config = None # # return {"energy": energy_config, "forces": forces_config} # # def __del__(self): # self.clean_up() class CalculatorTorchError(Exception): def __init__(self, msg): super(CalculatorTorchError, self).__init__(msg) self.msg = msg def _get_device(gpu): device = None if isinstance(gpu, bool): if gpu: device = torch.device(0) logger.info(f"Training on gpu") elif isinstance(gpu, int): device = torch.device(gpu) logger.info(f"Training on gpu {gpu}") if device is None: logger.info("Training on cpu") return device