Source code for kliff.transforms.configuration_transforms.descriptors

import os
from pathlib import Path
from typing import Dict, List, Tuple, Union

import numpy as np
from loguru import logger
from monty.dev import requires

from kliff.dataset import Configuration
from kliff.neighbor import NeighborList

from .configuration_transform import ConfigurationTransform

try:
    import libdescriptor as lds
except ImportError:
    lds = None
    logger.warning(
        "descriptors module depends on libdescriptor, "
        "which is not found. please install it first."
    )

from .default_hyperparams import (
    bispectrum_default,
    soap_default,
    symmetry_functions_set30,
    symmetry_functions_set51,
)
from .descriptor_initializers import (
    initialize_bispectrum_functions,
    initialize_symmetry_functions,
)


@requires(lds, "libdescriptor is needed for Descriptors")
class AvailableDescriptors:
    """
    Attribute-style listing of the descriptors known to libdescriptor.

    libdescriptor exposes its descriptor kinds as an enum; on construction this
    wrapper copies every enum member onto the instance under the member's name,
    so a kind can be looked up with ``getattr(instance, name)``.
    """

    def __init__(self):
        # Probe the enum by consecutive integer value, starting from 0. The
        # first value whose name is the placeholder "???" marks the end of the
        # valid members and is not stored.
        index = 0
        desc_type = lds.AvailableDescriptors(index)
        while desc_type.name != "???":
            setattr(self, desc_type.name, desc_type)
            index += 1
            desc_type = lds.AvailableDescriptors(index)
@requires(lds, "libdescriptor is needed for Descriptors")
def show_available_descriptors():
    """
    Print the names of all descriptors available in libdescriptor.

    The names are the attributes collected by :class:`AvailableDescriptors`,
    printed one per line below a short banner.
    """
    separator = "-" * 80
    print(separator)
    print(
        "Descriptors below are currently available, select them by `descriptor: str` attribute:"
    )
    print(separator)

    # Every instance attribute of AvailableDescriptors is one descriptor name.
    for descriptor_name in vars(AvailableDescriptors()):
        print(f"{descriptor_name}")
@requires(lds, "libdescriptor is needed for Descriptors")
class Descriptor(ConfigurationTransform):
    """
    Descriptor class provides an interface to the libdescriptor library.

    It provides a unified interface to all the descriptors in libdescriptor;
    however, every descriptor needs a corresponding initializer routine to deal
    with its hyperparameters. The descriptor is initialized with a cutoff
    radius, a list of species, a descriptor type and an ordered collection of
    hyperparameters. The descriptor type is a string, which can be obtained
    with the `show_available_descriptors()` function. Some sane default values
    for the hyperparameters are provided in the
    ~:kliff.transforms.configuration_transforms.default_hyperparams module.

    This class also provides methods to compute derivatives of the descriptor
    with respect to atomic positions. The functions generating descriptors and
    their derivatives are implemented as `forward()` and `backward()`,
    respectively, to match the PyTorch nomenclature.
    """

    # Per symmetry-function type: the hyperparameter keys written for each row
    # of the descriptor file, in output order. "g1" has no parameters and is
    # handled separately in `_write_symmetry_function_state`.
    _SYMMETRY_FUNCTION_FIELDS = {
        "g2": ("eta", "Rs"),
        "g3": ("kappa",),
        "g4": ("zeta", "lambda", "eta"),
        "g5": ("zeta", "lambda", "eta"),
    }

    def __init__(
        self,
        cutoff: float,
        species: List[str],
        descriptor: str,
        hyperparameters: Union[Dict, str],
        cutoff_function: str = "cos",
        copy_to_config: bool = False,
    ):
        """
        Args:
            cutoff (float): Cutoff radius.
            species (list): List of strings, each string is a species (atomic
                symbols).
            descriptor (str): String of descriptor type, can be obtained by
                `show_available_descriptors()`.
            hyperparameters (Dict or str): Ordered dictionary of
                hyperparameters, or the name of a default hyperparameter set
                (see :meth:`get_default_hyperparams`).
            cutoff_function (str): Cut-off function, currently only "cos" is
                supported.
            copy_to_config (bool): If True, the fingerprint will be copied to
                the Configuration object's fingerprint attribute.

        Raises:
            AttributeError: If `descriptor` is not an attribute of
                `AvailableDescriptors`.
            DescriptorsError: If the hyperparameters are invalid or the
                descriptor kind is not supported.
        """
        super().__init__(copy_to_config)
        self.cutoff = cutoff
        self.species = species
        self.descriptor_name = descriptor
        # Resolve the descriptor name to the libdescriptor enum member.
        self.descriptor_kind = getattr(AvailableDescriptors(), descriptor)
        # Placeholder until the descriptor has been initialised below.
        self.width = -1
        self.hyperparameters = self.get_default_hyperparams(hyperparameters)
        self.cutoff_function = cutoff_function
        self._cdesc, self.width = self._init_descriptor_from_kind()

    @staticmethod
    def get_default_hyperparams(hyperparameters: Union[Dict, str]) -> Dict:
        """
        Set hyperparameters for the descriptor.

        If a string is provided, it is used to select a set of default
        hyperparameters ("set51", "set30", "bs_defaults" or "soap_defaults").
        If a dict is provided, it is returned as is.

        Args:
            hyperparameters (str or Dict): Hyperparameters for the descriptor.

        Returns:
            Dict: dictionary of hyperparameters.

        Raises:
            DescriptorsError: If the string does not name a known set, or if
                the argument is neither a string nor a dict.
        """
        if isinstance(hyperparameters, str):
            # Factories are only called for the selected set.
            default_sets = {
                "set51": symmetry_functions_set51,
                "set30": symmetry_functions_set30,
                "bs_defaults": bispectrum_default,
                "soap_defaults": soap_default,
            }
            factory = default_sets.get(hyperparameters)
            if factory is None:
                raise DescriptorsError("Hyperparameter set not found")
            return factory()
        if isinstance(hyperparameters, dict):
            return hyperparameters
        raise DescriptorsError("Hyperparameters must be either a string or an Dict")

    def _init_descriptor_from_kind(self) -> Tuple["lds.DescriptorKind", int]:
        """
        Initialize descriptor from descriptor kind.

        Currently only Symmetry Functions, Bispectrum, and SOAP descriptors are
        supported.

        Returns:
            tuple: Tuple of descriptor object and width of the descriptor.

        Raises:
            DescriptorsError: If the descriptor kind is not one of the three
                supported kinds.
        """
        n_species = len(self.species)
        # The same cutoff is used for every species pair.
        cutoff_array = np.ones((n_species, n_species)) * self.cutoff

        # Symmetry Functions
        if self.descriptor_kind == lds.AvailableDescriptors(0):
            input_args, width = initialize_symmetry_functions(self.hyperparameters)
            return (
                lds.DescriptorKind.init_descriptor(
                    self.descriptor_kind,
                    self.species,
                    self.cutoff_function,
                    cutoff_array,
                    *input_args,
                ),
                width,
            )

        # Bispectrum
        if self.descriptor_kind == lds.AvailableDescriptors(1):
            # A missing "weights" entry is treated like an explicit None:
            # every species gets unit weight.
            weights = self.hyperparameters.get("weights")
            if weights is None:
                weights = np.ones(n_species)
            input_args, width = initialize_bispectrum_functions(self.hyperparameters)
            return (
                lds.DescriptorKind.init_descriptor(
                    self.descriptor_kind,
                    *input_args,
                    cutoff_array,
                    self.species,
                    weights,
                ),
                width,
            )

        # SOAP
        # nothing to initialize here, just return the descriptor
        if self.descriptor_kind == lds.AvailableDescriptors(2):
            n_max = self.hyperparameters["n_max"]
            l_max = self.hyperparameters["l_max"]
            # width = (((n_species + 1) * n_species) / 2
            #          * (n_max * (n_max + 1)) * (l_max + 1)) / 2
            width = int(
                (((n_species + 1) * n_species) / 2)
                * (n_max * (n_max + 1))
                * (l_max + 1)
                / 2
            )
            return (
                lds.DescriptorKind.init_descriptor(
                    self.descriptor_kind,
                    n_max,
                    l_max,
                    # NOTE(review): SOAP reads the cutoff from the
                    # hyperparameters rather than from self.cutoff - confirm
                    # this is intended.
                    self.hyperparameters["cutoff"],
                    self.species,
                    self.hyperparameters["radial_basis"],
                    self.hyperparameters["eta"],
                ),
                width,
            )

        raise DescriptorsError(
            f"Descriptor kind: {self.descriptor_kind} not supported yet"
        )

    def _map_species_to_int(self, species: List[str]) -> List[int]:
        """
        Map species to integers, which is required by the C++ implementation of
        the descriptors. Each species is mapped to its position in
        `self.species`.

        TODO: Unify all instances of species -> Z and Z -> species mapping
        functions. I think currently there are 3 of them. Also sort the atomic
        numbers and species codes conversions. Perhaps use the ASE function for
        this.

        Args:
            species (list): List of species.

        Returns:
            list: List of integers corresponding to the species.

        Raises:
            ValueError: If a species is not present in `self.species`.
        """
        return [self.species.index(s) for s in species]

    def forward(self, configuration: Configuration) -> np.ndarray:
        """
        Compute the descriptors for a given configuration, by calling the C++
        implementation, :py:func:`libdescriptor.compute_single_atom`, once per
        atom.

        This function accepts a ~:class:`kliff.dataset.Configuration` object
        and returns a numpy array of shape (n_atoms, width), where n_atoms is
        the number of atoms in the configuration and width is the width of the
        descriptor.

        TODO: Use the :py:func:`libdescriptor.compute` function for faster
        evaluation. Which loops in C++.

        Args:
            configuration: :py:class:`kliff.dataset.Configuration` object to
                compute descriptors for.

        Returns:
            numpy.ndarray: Array of shape (n_atoms, width).
        """
        nl_ctx = NeighborList(configuration, self.cutoff)
        n_atoms = configuration.get_num_atoms()
        descriptors = np.zeros((n_atoms, self.width))
        species = np.array(self._map_species_to_int(nl_ctx.species), np.intc)
        for i in range(n_atoms):
            neigh_list, _, _ = nl_ctx.get_neigh(i)
            descriptors[i, :] = lds.compute_single_atom(
                self._cdesc,
                i,
                species,
                np.array(neigh_list, dtype=np.intc),
                nl_ctx.coords,
            )
        return descriptors

    def backward(
        self, configuration: Configuration, dE_dZeta: np.ndarray
    ) -> np.ndarray:
        """
        Compute the gradients with respect to the atomic coordinates.

        It takes in an array of shape (n_atoms, width) and the configuration,
        and performs the vector-Jacobian product (reverse mode automatic
        differentiation). The per-atom contributions are accumulated over the
        neighbor-list coordinates and then folded back onto the atoms of the
        configuration through the neighbor list's image mapping.

        Args:
            configuration: :py:class:`kliff.dataset.Configuration` object to
                compute descriptors for.
            dE_dZeta: :numpy:`ndarray` of shape (n_atoms, width), usually this
                is the gradient of the energy (the ML model output) with
                respect to the input descriptors.

        Returns:
            :numpy:`ndarray` with the same shape as `configuration.coords`,
            i.e. (n_atoms, 3).
        """
        nl_ctx = NeighborList(configuration, self.cutoff)
        n_atoms = configuration.get_num_atoms()
        # One row per entry of the neighbor-list coordinates.
        derivatives_unrolled = np.zeros(nl_ctx.coords.shape)
        species = np.array(self._map_species_to_int(nl_ctx.species), dtype=np.intc)
        # Array of length `width` handed to libdescriptor on every call.
        descriptor = np.zeros(self.width)

        for i in range(n_atoms):
            neigh_list, _, _ = nl_ctx.get_neigh(i)
            descriptors_derivative = lds.gradient_single_atom(
                self._cdesc,
                i,
                species,
                np.array(neigh_list, dtype=np.intc),
                nl_ctx.coords,
                descriptor,
                dE_dZeta[i, :],
            )
            derivatives_unrolled += descriptors_derivative.reshape(-1, 3)

        # Accumulate every neighbor-list row onto the atom index it maps to.
        derivatives = np.zeros(configuration.coords.shape)
        neigh_images = nl_ctx.get_image()
        for i, atom in enumerate(neigh_images):
            derivatives[atom, :] += derivatives_unrolled[i, :]

        return derivatives

    def __call__(
        self, configuration: Configuration, return_extended_state: bool = False
    ) -> Union[np.ndarray, Dict]:
        """
        Map a configuration to a descriptor, but more importantly store all the
        information needed to compute the reverse pass easily on the batched
        configuration.

        This __call__ method specifically stores the neighbor lists and species
        information as a dictionary and copies it to the fingerprint attribute
        (when `copy_to_config` is set). This will be used by the descriptor
        dataset collate function to attach neighbor lists in sequential
        edge-index like format, where then the gradient function can be called
        on a stacked batch of configurations. To get this full state
        dictionary, set `return_extended_state` to True. Otherwise, the
        descriptor numpy array will be returned.

        Args:
            configuration: ~:class:`kliff.dataset.Configuration` object.
            return_extended_state: If True, the full state dictionary will be
                returned. Otherwise, the descriptor numpy array will be
                returned.

        Returns:
            Union[numpy.ndarray, Dict]: Descriptor numpy array or full state
            dictionary.
        """
        nl_ctx = NeighborList(configuration, self.cutoff)
        n_atoms = configuration.get_num_atoms()
        species = np.array(self._map_species_to_int(nl_ctx.species), np.intc)
        num_neigh, neigh_list = nl_ctx.get_numneigh_and_neighlist_1D()
        coords = nl_ctx.get_coords()
        descriptors = lds.compute(
            self._cdesc, n_atoms, species, neigh_list, num_neigh, coords
        )
        # -1 marks a configuration without an "index" metadata entry.
        index = configuration.metadata.get("index", -1)

        if return_extended_state:
            output = {
                "n_atoms": n_atoms,
                "species": species,
                "neigh_list": neigh_list,
                "num_neigh": num_neigh,
                "image": nl_ctx.get_image(),
                "coords": coords,
                "descriptor": descriptors,
                "index": index,
                "weight": configuration.weight.to_dict(),
            }
        else:
            output = descriptors

        if self.copy_to_config:
            configuration.fingerprint = output
        return output

    def save_descriptor_state(
        self, path: Union[str, Path], fname: str = "descriptor.dat"
    ):
        """
        Write the descriptor parameters to a file, which can be used by
        libdescriptor to re-initialize the descriptor.

        The file starts with a common header (cutoff type, number of species
        and the pairwise cutoff table) followed by a section specific to the
        descriptor kind. Nothing beyond the common header is written for a
        descriptor kind other than the three supported ones.

        Args:
            path (str): Path to the directory where the file will be saved.
            fname (str): Name of the descriptor file.
        """
        with open(Path(path) / fname, "w") as fout:
            # header
            fout.write("#" + "=" * 80 + "\n")
            fout.write("# Descriptor parameters file generated by KLIFF.\n")
            fout.write("#" + "=" * 80 + "\n\n")

            # cutoff and species
            fout.write("{} # cutoff type\n\n".format(self.cutoff_function))
            fout.write("{} # number of species\n\n".format(len(self.species)))
            fout.write("# species 1 species 2 cutoff\n")
            for species1 in self.species:
                for species2 in self.species:
                    fout.write("{} {} {}\n".format(species1, species2, self.cutoff))
            fout.write("\n")

            if self.descriptor_kind == lds.AvailableDescriptors(0):
                self._write_symmetry_function_state(fout)
            elif self.descriptor_kind == lds.AvailableDescriptors(1):
                self._write_bispectrum_state(fout)
            elif self.descriptor_kind == lds.AvailableDescriptors(2):
                self._write_soap_state(fout)

    def _write_symmetry_function_state(self, fout) -> None:
        """Write the symmetry-function section of the descriptor file to the
        open text file `fout`."""
        # header
        fout.write("#" + "=" * 80 + "\n")
        fout.write("# symmetry functions\n")
        fout.write("#" + "=" * 80 + "\n\n")

        num_sym_func = len(self.hyperparameters)
        fout.write("{} # number of symmetry functions types\n\n".format(num_sym_func))

        # descriptor values
        fout.write("# sym_function rows cols\n")
        for name, values in self.hyperparameters.items():
            if name == "g1":
                # g1 takes no parameters, only its name is recorded.
                fout.write("g1\n\n")
                continue

            fout.write("{} {} {}\n".format(name, len(values), len(values[0])))
            fields = self._SYMMETRY_FUNCTION_FIELDS.get(name)
            if fields is None:
                # Unknown type: only the "name rows cols" line is written.
                continue
            trailer = " # " + " ".join(fields) + "\n"
            for val in values:
                fout.write(" ".join(f"{val[key]}" for key in fields))
                fout.write(trailer)
            fout.write("\n")

        # header
        fout.write("#" + "=" * 80 + "\n")
        fout.write("# Preprocessing data to center and normalize\n")
        fout.write("#" + "=" * 80 + "\n")

        # mean and stdev: fixed placeholder values are always written
        mean = [0.0]
        stdev = [1.0]
        fout.write("center_and_normalize True\n\n")
        fout.write("{} # descriptor size\n".format(self.width))
        fout.write("# mean\n")
        for value in mean:
            fout.write("{} \n".format(value))
        fout.write("\n# standard deviation\n")
        for value in stdev:
            fout.write("{} \n".format(value))
        fout.write("\n")

    def _write_bispectrum_state(self, fout) -> None:
        """Write the bispectrum section of the descriptor file to the open text
        file `fout`."""
        for key in (
            "jmax",
            "rfac0",
            "diagonalstyle",
            "rmin0",
            "switch_flag",
            "bzero_flag",
        ):
            fout.write(f"# {key}\n{self.hyperparameters[key]}\n\n")

        fout.write("# weights\n")
        # A missing or None "weights" entry means unit weight for each species.
        weights = self.hyperparameters.get("weights")
        if weights is None:
            fout.write("1.0 " * len(self.species))
        else:
            for weight in weights:
                fout.write(f"{weight} ")
        fout.write("\n\n")

    def _write_soap_state(self, fout) -> None:
        """Write the SOAP section of the descriptor file to the open text file
        `fout`."""
        for key in ("n_max", "l_max", "cutoff", "radial_basis", "eta"):
            fout.write(f"# {key}\n{self.hyperparameters[key]}\n\n")

    def export_kim_model(self, path: Union[str, Path], model: str):
        """
        Saves the descriptor model in the KIM format, for re-use in KIM model
        driver. The parameters are written to `kim_model.param` inside `path`.

        Args:
            path (str): Path to the directory where the file will be saved.
            model (str): Name of the model.
        """
        # Path joining keeps an empty `path` relative to the working directory
        # instead of resolving to "/kim_model.param".
        with open(Path(path) / "kim_model.param", "w") as f:
            n_elements = len(self.species)
            f.write("# Number of elements\n")
            f.write(f"{n_elements}\n")
            f.write(f"{' '.join(self.species)}\n\n")

            f.write("# Preprocessing kind\n")
            f.write("Descriptor\n\n")

            f.write("# Cutoff distance\n")
            f.write(f"{self.cutoff}\n\n")

            f.write("# Model\n")
            f.write(f"{model}\n\n")

            f.write("# Returns Forces\n")
            f.write("False\n\n")

            f.write("# Number of inputs\n")
            f.write("1\n\n")

            f.write("# Any descriptors?\n")
            f.write(f"{self.descriptor_name}\n")
class DescriptorsError(Exception):
    """
    Exception raised for errors in the descriptors module.

    The message is passed to :class:`Exception` and is also kept on the
    instance as the ``msg`` attribute.
    """

    def __init__(self, msg):
        self.msg = msg
        super().__init__(msg)