Source code for kliff.dataset.dataset_torch

from pathlib import Path
from typing import Callable, Optional

import torch
from torch.utils.data import Dataset

from kliff.legacy.descriptors.descriptor import load_fingerprints


class FingerprintsDataset(Dataset):
    """
    Atomic environment fingerprints dataset used by torch models.

    Args:
        filename: Path to the fingerprints file.
        transform: Optional callable applied to a sample each time the sample is
            retrieved. If ``None``, samples are returned as loaded.
    """

    def __init__(self, filename: Path, transform: Optional[Callable] = None):
        self.fp = load_fingerprints(filename)

        # Record each fingerprint's position in the dataset under the "index" key.
        for i, f in enumerate(self.fp):
            f["index"] = i

        self.transform = transform

    def __len__(self) -> int:
        """Return the number of loaded fingerprints."""
        return len(self.fp)

    def __getitem__(self, index):
        """
        Return the fingerprint at ``index``, passed through ``transform`` if one
        was given.
        """
        sample = self.fp[index]

        # Compare against None explicitly: a callable may legitimately be falsy
        # (e.g. it defines ``__bool__`` or ``__len__``) and must still be applied.
        if self.transform is not None:
            sample = self.transform(sample)

        return sample
def fingerprints_collate_fn(batch):
    """
    Convert the numpy arrays in a batch of samples into tensors.

    Unlike the default collate_fn(), which stacks the samples in a batch (requiring
    every sample to have the same dimension), this function does not stack: each
    sample is kept as its own dict.

    Args:
        batch: A batch of samples, each a dict-like object providing ``items()``.

    Returns:
        A list with one new dict per sample. Values that are numpy arrays are
        replaced by tensors created with ``torch.from_numpy`` (which shares memory
        with the source array); all other values are passed through unchanged.
        The input samples themselves are not modified.
    """
    # Imported locally because this module does not import numpy at file level.
    import numpy as np

    # Only real ndarrays are converted. Checking the type's module name instead
    # would also match numpy scalars (e.g. ``np.float64``), on which
    # ``torch.from_numpy`` raises TypeError.
    return [
        {
            key: torch.from_numpy(value) if isinstance(value, np.ndarray) else value
            for key, value in sample.items()
        }
        for sample in batch
    ]