Source code for kliff.utils

import os
import pickle
import random
import subprocess
import tarfile
from ast import literal_eval
from collections.abc import Sequence
from pathlib import Path
from typing import Any, Union

import numpy as np
import requests
import yaml


def length_equal(a, b):
    """
    Check whether two objects have matching lengths.

    Lengths are compared only when both ``a`` and ``b`` are sequences. If either
    of them is not a sequence there is nothing to compare and ``True`` is returned.

    Args:
        a: First object.
        b: Second object.

    Returns:
        ``False`` if both are sequences of different lengths, ``True`` otherwise.
    """
    both_are_sequences = isinstance(a, Sequence) and isinstance(b, Sequence)
    if not both_are_sequences:
        return True
    return len(a) == len(b)
def torch_available():
    """
    Report whether PyTorch can be imported.

    Returns:
        ``True`` if ``import torch`` succeeds, ``False`` if the module is not found.
    """
    try:
        import torch  # noqa: F401
    except ModuleNotFoundError:
        return False
    else:
        return True
def torch_geometric_available():
    """
    Report whether PyTorch Geometric can be imported.

    Returns:
        ``True`` if ``import torch_geometric`` succeeds, ``False`` if the module is
        not found.
    """
    try:
        import torch_geometric  # noqa: F401
    except ModuleNotFoundError:
        return False
    else:
        return True
def split_string(string: str, length=80, starter: str = None):
    r"""
    Insert `\n` into long string such that each line has size no more than `length`.

    Lines are broken after the last space that still fits. A run of characters
    without any space that is wider than the available width is broken mid-word,
    so the function always makes progress and terminates.

    Args:
        string: The string to split.
        length: Targeted length of each line.
        starter: String to insert at the beginning of each line. It is followed by
            a single space; both count towards `length`.

    Returns:
        The wrapped text. Every line, including the last one, is terminated by
        `\n`; an empty `string` yields a single `\n`.
    """
    if starter is not None:
        target_end = length - len(starter) - 1
    else:
        target_end = length
    # A non-positive width (e.g. `starter` longer than `length`) would never consume
    # any character; always allow at least one character per line.
    target_end = max(target_end, 1)

    sub_string = []
    while string:
        end = target_end
        if len(string) > end:
            # Break right after the last space at index <= target_end.
            last_space = string.rfind(" ", 0, end + 1)
            if last_space >= 0:
                end = last_space + 1
            # Otherwise there is no space to break at: keep `end == target_end` and
            # cut the over-long word instead of looping forever on an empty slice.
        sub = string[:end].strip()
        if starter is not None:
            sub = starter + " " + sub
        sub_string.append(sub)
        string = string[end:]

    return "\n".join(sub_string) + "\n"
def seed_all(seed=35, cudnn_benchmark=False, cudnn_deterministic=False):
    """
    Seed the random number generators of Python, NumPy and, if available, PyTorch.

    Besides seeding the generators, the ``PYTHONHASHSEED`` environment variable is
    set to ``seed``.

    Args:
        seed: Seed passed to every generator.
        cudnn_benchmark: Value assigned to ``torch.backends.cudnn.benchmark``.
        cudnn_deterministic: Value assigned to
            ``torch.backends.cudnn.deterministic``.
    """
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)

    if not torch_available():
        return

    import torch

    # CPU generator, current GPU, and all GPUs (multi-GPU runs).
    for seeder in (
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    ):
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.benchmark = cudnn_benchmark
    cudnn.deterministic = cudnn_deterministic
def to_path(path: Union[str, Path]) -> Path:
    """
    Convert a str (or path object) to a resolved ``pathlib.Path``.

    A leading ``~`` is expanded to the user's home directory before the path is
    resolved.
    """
    expanded = Path(path).expanduser()
    return expanded.resolve()
def download_dataset(dataset_name: str) -> Path:
    """
    Download dataset and untar it.

    Nothing is downloaded if the resolved dataset path already exists. Otherwise
    ``<dataset_name>.tar.gz`` is fetched from the ``examples`` directory of the
    kliff GitHub repository, stored next to the dataset path and extracted into
    the parent directory of that path.

    Args:
        dataset_name: name of the dataset

    Returns:
        Path to the dataset

    Raises:
        requests.HTTPError: if the server answers with an error status.
    """
    path = to_path(dataset_name)

    if not path.exists():
        tarball = path.with_suffix(".tar.gz")

        # download
        url = (
            "https://raw.githubusercontent.com/openkim/kliff/master/examples/"
            f"{dataset_name}.tar.gz"
        )
        with requests.get(url, stream=True) as r:
            r.raise_for_status()
            # Stream to disk in chunks so the archive is never held in memory.
            with open(tarball, "wb") as f:
                for chunk in r.iter_content(chunk_size=8192):
                    f.write(chunk)

        # untar; the context manager closes the archive even if extraction fails
        # NOTE(review): members are extracted as-is from a downloaded archive;
        # consider an extraction filter once the minimum Python version allows it.
        with tarfile.open(tarball, "r:gz") as tf:
            tf.extractall(path.parent)

        # The tarball is not removed after extraction.

    return path
def create_directory(path: Union[str, Path], is_directory: bool = False):
    """
    Create the directory for ``path`` if it does not exist yet.

    Args:
        path: Path to a file or to a directory.
        is_directory: If ``True``, ``path`` itself is the directory to create.
            Otherwise ``path`` is treated as a file and its parent directory is
            created.
    """
    p = to_path(path)
    dirname = p if is_directory else p.parent

    if not dirname.exists():
        # `exist_ok=True` keeps this safe when another process creates the same
        # directory between the existence check and the call.
        os.makedirs(dirname, exist_ok=True)
def yaml_dump(data, filename: Union[Path, str]):
    """
    Write ``data`` to a yaml file in block (non-flow) style.

    The parent directory of ``filename`` is created if it does not exist.

    Args:
        data: Object to serialize.
        filename: Path of the yaml file to write.
    """
    create_directory(filename)
    target = to_path(filename)
    with open(target, "w") as stream:
        yaml.dump(data, stream, default_flow_style=False)
def yaml_load(filename: Union[Path, str]):
    """
    Read a yaml file with the safe loader.

    Args:
        filename: Path of the yaml file to read.

    Returns:
        The data parsed from the file.
    """
    source = to_path(filename)
    with open(source, "r") as stream:
        return yaml.safe_load(stream)
def pickle_dump(data, filename: Union[Path, str]):
    """
    Serialize ``data`` to a pickle file.

    The parent directory of ``filename`` is created if it does not exist.

    Args:
        data: Object to pickle.
        filename: Path of the pickle file to write.
    """
    create_directory(filename)
    target = to_path(filename)
    with open(target, "wb") as stream:
        pickle.dump(data, stream)
def pickle_load(filename: Union[Path, str]):
    """
    Load data from a pickle file.

    Args:
        filename: Path of the pickle file to read.

    Returns:
        The unpickled object.
    """
    # NOTE(review): unpickling can execute arbitrary code; only use this on files
    # from a trusted source.
    source = to_path(filename)
    with open(source, "rb") as stream:
        return pickle.load(stream)
def stress_to_voigt(input_stress: np.ndarray) -> list:
    r"""
    Convert stress from 3x3 tensor notation to 6x1 Voigt notation.

    The Voigt components are ordered as
    :math:`[\sigma_{11}, \sigma_{22}, \sigma_{33}, \sigma_{23}, \sigma_{13}, \sigma_{12}]`.

    Args:
        input_stress: Stress in tensor notation: a two-dimensional (3x3) array, or
            a nested sequence that can be converted to one.

    Returns:
        stress: Stress tensor in Voigt notation, as a list of 6 components.

    Raises:
        ValueError: if ``input_stress`` is not two-dimensional.
    """
    # Objects that already expose `ndim` are indexed as they are; plain nested
    # sequences are converted so they can be indexed with [i, j].
    if not hasattr(input_stress, "ndim"):
        input_stress = np.asarray(input_stress)

    if input_stress.ndim != 2:
        raise ValueError("input_stress must be a 2D array")

    # (row, column) of the tensor entry stored at each Voigt position.
    voigt_indices = ((0, 0), (1, 1), (2, 2), (1, 2), (0, 2), (0, 1))
    return [input_stress[i, j] for i, j in voigt_indices]
def stress_to_tensor(input_stress: list) -> np.ndarray:
    """
    Convert stress from 6x1 Voigt notation to 3x3 tensor notation.

    Off-diagonal Voigt components are written to both symmetric positions of the
    tensor.

    Args:
        input_stress: Stress tensor in Voigt notation.

    Returns:
        stress: Stress in tensor notation, a 3x3 array.
    """
    # (row, column) of the tensor entry that each Voigt component fills.
    voigt_indices = ((0, 0), (1, 1), (2, 2), (1, 2), (0, 2), (0, 1))

    tensor = np.zeros((3, 3))
    for k, (i, j) in enumerate(voigt_indices):
        tensor[i, j] = tensor[j, i] = input_stress[k]
    return tensor
def is_kim_model_installed(model_name: str) -> bool:
    """
    Check if the KIM model is installed in any collection.

    The output of ``kim-api-collections-management list`` is searched for
    ``model_name`` as a substring.

    Args:
        model_name: name of the model.

    Returns:
        ``True`` if ``model_name`` occurs in the listing, ``False`` otherwise.
    """
    listing = subprocess.run(
        ["kim-api-collections-management", "list"], capture_output=True, text=True
    ).stdout
    return model_name in listing
def install_kim_model(model_name: str, collection: str = "user") -> bool:
    """
    Install the KIM model unless it is already installed.

    Args:
        model_name: name of the model.
        collection: name of the collection to install the model into.

    Returns:
        ``True`` if the model was already installed or the install command exited
        with status 0.

    Raises:
        subprocess.CalledProcessError: if the install command exits with a
            non-zero status (the command is run with ``check=True``).
    """
    if is_kim_model_installed(model_name):
        return True

    command = ["kim-api-collections-management", "install", collection, model_name]
    completed = subprocess.run(command, check=True)
    return completed.returncode == 0
def get_n_configs_in_xyz(file_path: str) -> int:
    """
    Get the number of configurations in a xyz file.

    It uses the grep command to count the number of lines that contain only
    numbers (digits).

    Args:
        file_path: Path to the xyz file.

    Returns:
        Number of lines consisting solely of digits, taken as the number of
        configurations. ``0`` if the file has no such line.

    Raises:
        RuntimeError: if grep reports an error (e.g. the file cannot be read).
    """
    pattern = "^[0-9]+$"

    # Run the grep command and capture the output
    result = subprocess.run(
        ["grep", "-Ec", pattern, file_path],
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        text=True,
    )

    # grep exits with 0 when lines matched and 1 when none matched; with `-c` it
    # prints the count in both cases. Only other statuses are real errors.
    if result.returncode not in (0, 1):
        raise RuntimeError(result.stderr)

    num_configs = int(result.stdout.strip())
    return num_configs
def str_to_numpy(expression: str, dtype: Any) -> Union[np.ndarray, Any]:
    """
    Convert a string holding a Python literal to a numpy array.

    For reading from SQL/HF databases.

    Args:
        expression: String with the array expression, e.g. ``"[1, 2, 3]"``. Any
            object that is not a string is passed through untouched.
        dtype: dtype to convert to.

    Returns:
        Numpy array built from the string; for a non-string input, the same object.
    """
    if not isinstance(expression, str):
        return expression

    parsed = literal_eval(expression)
    return np.array(parsed, dtype)