import os
import pickle
import random
import tarfile
from collections.abc import Sequence
from pathlib import Path
from typing import Optional, Union
import numpy as np
import requests
import yaml

def length_equal(a, b):
    """
    Check whether ``a`` and ``b`` have equal length.

    If either argument is not a Sequence (e.g. a scalar), the two are
    considered length-compatible and True is returned.
    """
    if isinstance(a, Sequence) and isinstance(b, Sequence):
        return len(a) == len(b)
    else:
        return True
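
# A minimal usage sketch (values are illustrative):
#
#     >>> length_equal([1, 2, 3], (4, 5, 6))
#     True
#     >>> length_equal([1, 2], [1, 2, 3])
#     False
#     >>> length_equal([1, 2], 5)  # non-sequence argument -> compatible
#     True
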
def torch_available():
    """
    Check whether PyTorch is installed.
    """
    try:
        import torch

        return True
    except ImportError:
        return False
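
# Usage sketch: guard optional torch-dependent code paths.
#
#     >>> if torch_available():
#     ...     import torch  # only import/use torch when it is installed
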
def split_string(string: str, length: int = 80, starter: Optional[str] = None):
    r"""
    Insert `\n` into a long string such that each line has no more than
    `length` characters.

    Args:
        string: The string to split.
        length: Targeted length of each line.
        starter: String to insert at the beginning of each line.
    """
    if starter is not None:
        target_end = length - len(starter) - 1
    else:
        target_end = length

    sub_string = []
    while string:
        end = target_end
        if len(string) > end:
            # back up to the last space so that words are not split
            while end >= 0 and string[end] != " ":
                end -= 1
            end += 1
            if end == 0:
                # no space found (a single word longer than the target);
                # force a break to avoid an infinite loop
                end = target_end
        sub = string[:end].strip()
        if starter is not None:
            sub = starter + " " + sub
        sub_string.append(sub)
        string = string[end:]

    return "\n".join(sub_string) + "\n"
def seed_all(seed=35, cudnn_benchmark=False, cudnn_deterministic=False):
    """
    Seed the random number generators of python, numpy, and (if available)
    torch to make runs reproducible.
    """
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    if torch_available():
        import torch

        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)  # if using multi-GPU
        torch.backends.cudnn.benchmark = cudnn_benchmark
        torch.backends.cudnn.deterministic = cudnn_deterministic
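
# Usage sketch: call once at program start, before any random numbers are
# drawn (the seed value below is illustrative).
#
#     >>> seed_all(seed=42)
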
def to_path(path: Union[str, Path]) -> Path:
    """
    Convert a str or Path to an absolute pathlib.Path, expanding the user
    home directory (`~`) and resolving symlinks.
    """
    return Path(path).expanduser().resolve()
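
# Usage sketch (the resolved prefix depends on the machine; the value shown
# is illustrative):
#
#     >>> to_path("~/data/config.yaml")  # doctest: +SKIP
#     PosixPath('/home/user/data/config.yaml')
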
def download_dataset(dataset_name: str) -> Path:
    """
    Download a dataset tarball and untar it.

    Args:
        dataset_name: name of the dataset

    Returns:
        Path to the dataset
    """
    path = to_path(dataset_name)
    if not path.exists():
        tarball = path.with_suffix(".tar.gz")

        # download
        url = (
            f"https://raw.githubusercontent.com/openkim/kliff/master/examples/"
            f"{dataset_name}.tar.gz"
        )
        with requests.get(url, stream=True) as r:
            r.raise_for_status()
            with open(tarball, "wb") as f:
                for chunk in r.iter_content(chunk_size=8192):
                    f.write(chunk)

        # untar (context manager ensures the tarfile is closed)
        with tarfile.open(tarball, "r:gz") as tf:
            tf.extractall(path.parent)

        # # remove tarball
        # tarball.unlink()

    return path
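
# Usage sketch ("Si_training_set" is an illustrative name; the function
# fetches `<name>.tar.gz` from the kliff examples directory on GitHub and
# extracts it into the current working directory):
#
#     >>> dataset_path = download_dataset("Si_training_set")  # doctest: +SKIP
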
def create_directory(path: Union[str, Path], is_directory: bool = False):
    """
    Create the directory needed to hold `path`.

    If `is_directory` is True, `path` itself is treated as the directory to
    create; otherwise `path` is treated as a file and its parent directory
    is created.
    """
    p = to_path(path)
    if is_directory:
        dirname = p
    else:
        dirname = p.parent
    if not dirname.exists():
        os.makedirs(dirname)
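
# Usage sketch (illustrative paths):
#
#     >>> create_directory("results/model.pkl")  # creates ./results
#     >>> create_directory("results/plots", is_directory=True)  # creates ./results/plots
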
def yaml_dump(data, filename: Union[Path, str]):
    """
    Dump data to a yaml file.
    """
    create_directory(filename)
    with open(to_path(filename), "w") as f:
        yaml.dump(data, f, default_flow_style=False)

def yaml_load(filename: Union[Path, str]):
    """
    Load data from a yaml file.
    """
    with open(to_path(filename), "r") as f:
        data = yaml.safe_load(f)
    return data
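
# Round-trip usage sketch (data and filename are illustrative):
#
#     >>> yaml_dump({"cutoff": 4.5, "species": ["Si"]}, "inputs/params.yaml")
#     >>> yaml_load("inputs/params.yaml")
#     {'cutoff': 4.5, 'species': ['Si']}
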
def pickle_dump(data, filename: Union[Path, str]):
    """
    Dump data to a pickle file.
    """
    create_directory(filename)
    with open(to_path(filename), "wb") as f:
        pickle.dump(data, f)

def pickle_load(filename: Union[Path, str]):
    """
    Load data from a pickle file.
    """
    with open(to_path(filename), "rb") as f:
        data = pickle.load(f)
    return data
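
# Round-trip usage sketch (data and filename are illustrative):
#
#     >>> pickle_dump(np.arange(3), "outputs/array.pkl")
#     >>> pickle_load("outputs/array.pkl")
#     array([0, 1, 2])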