Source code for kliff.models.linear_regression

import torch
import torch.nn as nn
from loguru import logger
from torch.utils.data import DataLoader

from kliff.dataset.dataset_torch import FingerprintsDataset, fingerprints_collate_fn
from kliff.models.model_torch import ModelTorch


class LinearRegression(ModelTorch):
    """
    Linear regression model.

    The model is a single ``nn.Linear`` layer that maps a descriptor of size
    ``descriptor.get_size()`` to one scalar.

    Args:
        descriptor: Descriptor object forwarded to ``ModelTorch.__init__``; its
            ``get_size()`` gives the number of input features.
        seed: Seed forwarded to ``ModelTorch.__init__``.
    """

    def __init__(self, descriptor, seed=35):
        super().__init__(descriptor, seed)
        desc_size = self.descriptor.get_size()
        self.layer = nn.Linear(desc_size, 1)

    def forward(self, x: torch.Tensor):
        """
        Apply the linear layer to the descriptors.

        Args:
            x: Descriptors of shape (N, M), where `N` is the number of descriptors
                and `M` is the descriptor size.

        Returns:
            Output of the linear layer, one value per descriptor (shape (N, 1)).
        """
        return self.layer(x)

    def fit(self, path):
        """
        Fit the model using the analytic least-squares solution.

        The fingerprints are loaded one configuration at a time, assembled into a
        design matrix and target vector, and the resulting linear least-squares
        problem is solved directly (no iterative optimization).

        Args:
            path: Location of the fingerprints, passed to ``FingerprintsDataset``.
        """
        fp = FingerprintsDataset(path)
        loader = DataLoader(
            dataset=fp, batch_size=1, collate_fn=fingerprints_collate_fn
        )

        X, y = self._prepare_data(loader)
        beta = self._solve_least_squares(X, y)
        self._set_params(beta)

        msg = f'Finished fitting model "{self.__class__.__name__}"'
        logger.info(msg)
        print(msg)

    @staticmethod
    def _solve_least_squares(X, y):
        """
        Return the 1D coefficient vector minimizing ``||X @ beta - y||``.

        A least-squares solver is used instead of explicitly inverting
        ``X^T X``: forming the inverse squares the condition number and fails
        when the design matrix is rank deficient (e.g. fewer configurations
        than descriptor components).

        Args:
            X: 2D design matrix.
            y: 1D target vector with one entry per row of ``X``.
        """
        # The solver needs both operands in one dtype; promote rather than
        # downcast so no precision is lost.
        dtype = torch.promote_types(X.dtype, y.dtype)
        X = X.to(dtype)
        y = y.to(dtype)
        solution = torch.linalg.lstsq(X, y.unsqueeze(1)).solution
        return solution.squeeze(1)

    def _set_params(self, beta):
        """
        Set linear weight and bias.

        Args:
            beta: 1D tensor. The first component is the bias and the remaining
                components are the weights.
        """
        # Note, self.layer.weight is a 2D tensor of shape (1, ndesc)
        # y = xW^T + b
        self.layer.weight = torch.nn.Parameter(beta[1:].reshape(1, -1))
        self.layer.bias = torch.nn.Parameter(beta[0:1])

    def _prepare_data(self, loader, use_energy=True, use_forces=False):
        """
        Assemble the design matrix and target vector from a data loader.

        Args:
            loader: Iterable of batches; the first element of each batch is a
                mapping with keys ``"zeta"`` and ``"energy"`` (read when
                ``use_energy``) and ``"dzeta_dr"`` and ``"forces"`` (read when
                ``use_forces``).
            use_energy: Whether to add one energy row per sample.
            use_forces: Whether to add force rows per sample.

        Returns:
            Tuple ``(X, y)``: ``X`` is a 2D tensor whose first column belongs to
            the intercept, ``y`` is the matching 1D target tensor.

        Raises:
            LinearRegressionError: If both ``use_energy`` and ``use_forces`` are
                ``False``.
        """
        # Fail before touching any data; nothing could be assembled anyway.
        if not (use_energy or use_forces):
            raise LinearRegressionError(
                'Both "use_energy" and "use_forces" are "False".'
            )

        X = []
        y = []
        for batch in loader:
            sample = batch[0]
            rows = []
            targets = []

            if use_energy:
                zeta = sample["zeta"]  # 2D tensor (atom, desc)
                # new_ones keeps the dtype/device of the data
                intercept = zeta.new_ones(zeta.size(0), 1)
                zeta = torch.cat((intercept, zeta), dim=1)
                # sum to get energy of the configuration; we can do this because
                # the model is linear
                rows.append(torch.sum(zeta, 0, keepdim=True))  # 2D tensor
                targets.append(torch.tensor([sample["energy"]]))  # 1D tensor

            if use_forces:
                dzeta = sample["dzeta_dr"]  # 3D tensor (atom, desc, coords)
                # zeros because the derivative of the intercept is 0; the block
                # must be 3D with a single "desc" slot so that it can be
                # concatenated along dim 1
                intercept = dzeta.new_zeros(dzeta.size(0), 1, dzeta.size(2))
                dzeta = torch.cat((intercept, dzeta), dim=1)
                dzeta = torch.sum(dzeta, 0)  # 2D tensor (desc + 1, coords)
                # one row per coordinate, columns aligned with the energy row
                rows.append(dzeta.t())
                # NOTE(review): presumably sample["forces"][0] is the flattened
                # 1D force vector with one entry per coordinate, and no sign
                # flip (F = -dE/dr) is applied here -- confirm against the
                # dataset's conventions.
                targets.append(sample["forces"][0])  # 1D tensor

            X.append(torch.cat(rows))
            y.append(torch.cat(targets))

        X = torch.cat(X)
        y = torch.cat(y)
        return X, y
class LinearRegressionError(Exception):
    """
    Error raised by the linear regression model.

    Args:
        msg: Description of the error. It is passed to ``Exception`` and also
            stored on the ``msg`` attribute.
    """

    def __init__(self, msg):
        super().__init__(msg)
        self.msg = msg

    def __str__(self):
        # __str__ must return a str even when ``msg`` is some other object.
        return str(self.msg)

    # ``__expr__`` is not a special method and is never called by Python; it is
    # kept only for callers that invoke it explicitly.
    def __expr__(self):
        return self.msg