aimnet-0.0.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
aimnet/modules/core.py ADDED
@@ -0,0 +1,237 @@
+ from typing import Any, Callable, Dict, List, Optional
+
+ import torch
+ from torch import Tensor, nn
+
+ from aimnet import constants, nbops, ops
+ from aimnet.config import get_init_module, get_module
+
+
+ def MLP(
+     n_in: int,
+     n_out: int,
+     hidden: Optional[List[int]] = None,
+     activation_fn: Callable | str = "torch.nn.GELU",
+     activation_kwargs: Optional[Dict[str, Any]] = None,
+     weight_init_fn: Callable | str = "torch.nn.init.xavier_normal_",
+     bias: bool = True,
+     last_linear: bool = True,
+ ):
+     """Convenience function to build an MLP from config."""
+     if hidden is None:
+         hidden = []
+     if activation_kwargs is None:
+         activation_kwargs = {}
+     # hp search hack: drop non-positive layer sizes
+     hidden = [x for x in hidden if x > 0]
+     if isinstance(activation_fn, str):
+         activation_fn = get_init_module(activation_fn, kwargs=activation_kwargs)
+     if isinstance(weight_init_fn, str):
+         weight_init_fn = get_module(weight_init_fn)
+     sizes = [n_in, *hidden, n_out]
+     layers = []
+     for i in range(1, len(sizes)):
+         n_in, n_out = sizes[i - 1], sizes[i]
+         layer = nn.Linear(n_in, n_out, bias=bias)
+         with torch.no_grad():
+             weight_init_fn(layer.weight)
+             if bias:
+                 nn.init.zeros_(layer.bias)
+         layers.append(layer)
+         # append activation after every layer except (optionally) the last one
+         if not (last_linear and i == len(sizes) - 1):
+             layers.append(activation_fn)
+     return nn.Sequential(*layers)
+
+
+ class Embedding(nn.Embedding):
+     def __init__(self, init: Optional[Dict[int, Any]] = None, **kwargs):
+         super().__init__(**kwargs)
+         with torch.no_grad():
+             if init is not None:
+                 for i in range(self.weight.shape[0]):
+                     if self.padding_idx is not None and i == self.padding_idx:
+                         continue
+                     if i in init:
+                         self.weight[i] = init[i]
+                     else:
+                         self.weight[i].fill_(float("nan"))
+                 # explicit init values override everything, including padding_idx
+                 for k, v in init.items():
+                     self.weight[k] = v
+
+     def reset_parameters(self) -> None:
+         nn.init.orthogonal_(self.weight)
+         if self.padding_idx is not None:
+             with torch.no_grad():
+                 self.weight[self.padding_idx].fill_(0)
+
+
+ class DSequential(nn.Module):
+     """Sequential container for modules that pass a data dict through."""
+
+     def __init__(self, *modules):
+         super().__init__()
+         self.module = nn.ModuleList(modules)
+
+     def forward(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]:
+         for m in self.module:
+             data = m(data)
+         return data
+
+
+ class AtomicShift(nn.Module):
+     def __init__(
+         self,
+         key_in: str,
+         key_out: str,
+         num_types: int = 64,
+         dtype: torch.dtype = torch.float,
+         requires_grad: bool = True,
+         reduce_sum=False,
+     ):
+         super().__init__()
+         shifts = nn.Embedding(num_types, 1, padding_idx=0, dtype=dtype)
+         shifts.weight.requires_grad_(requires_grad)
+         self.shifts = shifts
+         self.key_in = key_in
+         self.key_out = key_out
+         self.reduce_sum = reduce_sum
+
+     def extra_repr(self) -> str:
+         return f"key_in: {self.key_in}, key_out: {self.key_out}"
+
+     def forward(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]:
+         shifts = self.shifts(data["numbers"]).squeeze(-1)
+         if self.reduce_sum:
+             shifts = nbops.mol_sum(shifts, data)
+         data[self.key_out] = data[self.key_in] + shifts
+         return data
+
+
+ class AtomicSum(nn.Module):
+     def __init__(self, key_in: str, key_out: str):
+         super().__init__()
+         self.key_in = key_in
+         self.key_out = key_out
+
+     def extra_repr(self) -> str:
+         return f"key_in: {self.key_in}, key_out: {self.key_out}"
+
+     def forward(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]:
+         data[self.key_out] = nbops.mol_sum(data[self.key_in], data)
+         return data
+
+
+ class Output(nn.Module):
+     def __init__(self, mlp: Dict | nn.Module, n_in: int, n_out: int, key_in: str, key_out: str):
+         super().__init__()
+         self.key_in = key_in
+         self.key_out = key_out
+         if not isinstance(mlp, nn.Module):
+             mlp = MLP(n_in=n_in, n_out=n_out, **mlp)
+         self.mlp = mlp
+
+     def extra_repr(self) -> str:
+         return f"key_in: {self.key_in}, key_out: {self.key_out}"
+
+     def forward(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]:
+         v = self.mlp(data[self.key_in]).squeeze(-1)
+         if data["_input_padded"].item():
+             v = nbops.mask_i_(v, data, mask_value=0.0)
+         data[self.key_out] = v
+         return data
+
+
+ class Forces(nn.Module):
+     def __init__(self, module: nn.Module, x: str = "coord", y: str = "energy", key_out: str = "forces"):
+         super().__init__()
+         self.module = module
+         self.x = x
+         self.y = y
+         self.key_out = key_out
+
+     def forward(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]:
+         prev = torch.is_grad_enabled()
+         torch.set_grad_enabled(True)
+         data[self.x].requires_grad_(True)
+         data = self.module(data)
+         y = data[self.y]
+         # forces are the negative gradient of y w.r.t. x
+         g = torch.autograd.grad([y.sum()], [data[self.x]], create_graph=self.training)[0]
+         assert g is not None
+         data[self.key_out] = -g
+         torch.set_grad_enabled(prev)
+         return data
+
+
+ class Dipole(nn.Module):
+     def __init__(self, key_in: str = "charges", key_out: str = "dipole", center_coord: bool = False):
+         super().__init__()
+         self.center_coord = center_coord
+         self.key_out = key_out
+         self.key_in = key_in
+         self.register_buffer("mass", constants.get_masses())
+
+     def extra_repr(self) -> str:
+         return f"key_in: {self.key_in}, key_out: {self.key_out}, center_coord: {self.center_coord}"
+
+     def forward(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]:
+         q = data[self.key_in]
+         r = data["coord"]
+         if self.center_coord:
+             r = ops.center_coordinates(r, data, self.mass[data["numbers"]])
+         data[self.key_out] = nbops.mol_sum(q.unsqueeze(-1) * r, data)
+         return data
+
+
+ class Quadrupole(Dipole):
+     def __init__(self, key_in: str = "charges", key_out: str = "quadrupole", center_coord: bool = False):
+         super().__init__(key_in=key_in, key_out=key_out, center_coord=center_coord)
+
+     def forward(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]:
+         q = data[self.key_in]
+         r = data["coord"]
+         if self.center_coord:
+             r = ops.center_coordinates(r, data, self.mass[data["numbers"]])
+         # diagonal (xx, yy, zz) and off-diagonal (xy, yz, zx) components
+         _x = torch.cat([r.pow(2), r * r.roll(-1, -1)], dim=-1)
+         quad = nbops.mol_sum(q.unsqueeze(-1) * _x, data)
+         _x1, _x2 = quad.split(3, dim=-1)
+         # make the diagonal part traceless
+         _x1 = _x1 - _x1.mean(dim=-1, keepdim=True)
+         quad = torch.cat([_x1, _x2], dim=-1)
+         data[self.key_out] = quad
+         return data
+
+
+ class SRRep(nn.Module):
+     """GFN1-style short-range repulsion function."""
+
+     def __init__(self, key_out="e_rep", cutoff_fn="none", rc=5.2, reduce_sum=True):
+         super().__init__()
+         from aimnet.constants import get_gfn1_rep
+
+         self.key_out = key_out
+         self.cutoff_fn = cutoff_fn
+         self.reduce_sum = reduce_sum
+
+         self.register_buffer("rc", torch.tensor(rc))
+         gfn1_repa, gfn1_repb = get_gfn1_rep()
+         weight = torch.stack([gfn1_repa, gfn1_repb], dim=-1)
+         self.params = nn.Embedding(87, 2, padding_idx=0, _weight=weight)
+         self.params.weight.requires_grad_(False)
+
+     def forward(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]:
+         p = self.params(data["numbers"])
+         p_i, p_j = nbops.get_ij(p, data)
+         p_ij = p_i * p_j
+         alpha_ij, zeff_ij = p_ij.unbind(-1)
+         d_ij = data["d_ij"]
+         e = torch.exp(-alpha_ij * d_ij.pow(1.5)) * zeff_ij / d_ij
+         e = nbops.mask_ij_(e, data, 0.0)
+         if self.cutoff_fn == "exp_cutoff":
+             e = e * ops.exp_cutoff(d_ij, self.rc)
+         elif self.cutoff_fn == "cosine_cutoff":
+             e = e * ops.cosine_cutoff(d_ij, self.rc)
+         e = e.sum(-1)
+         if self.reduce_sum:
+             e = nbops.mol_sum(e, data)
+         if self.key_out in data:
+             data[self.key_out] = data[self.key_out] + e
+         else:
+             data[self.key_out] = e
+         return data
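
The modules above communicate through a single Dict[str, Tensor] that DSequential threads from stage to stage. Below is a minimal usage sketch, not part of the package: ToyFeaturizer is a hypothetical stand-in for a real featurizer (so that "energy" actually depends on "coord" and Forces has something to differentiate); the dense batch layout and the "numbers"/"coord" keys follow the conventions used in this file and in aimnet/nbops.py.

import torch
from aimnet import nbops
from aimnet.modules.core import MLP, DSequential, Output, AtomicSum, Forces

class ToyFeaturizer(torch.nn.Module):
    # hypothetical stand-in: maps per-atom coordinates to a feature vector
    def __init__(self, n_feat: int = 16):
        super().__init__()
        self.lin = torch.nn.Linear(3, n_feat)

    def forward(self, data):
        data["features"] = self.lin(data["coord"])
        return data

# one molecule, 4 atoms, dense (mode-0) layout
data = {"numbers": torch.tensor([[6, 1, 1, 1]]), "coord": torch.randn(1, 4, 3)}
data = nbops.set_nb_mode(data)  # no "nbmat" present -> dense mode 0
data = nbops.calc_masks(data)

head = Output(mlp={"hidden": [32]}, n_in=16, n_out=1, key_in="features", key_out="e_atom")
model = Forces(DSequential(ToyFeaturizer(), head, AtomicSum("e_atom", "energy")))
data = model(data)  # adds "energy" (per molecule) and "forces" (per atom)

Note that Output accepts either a ready nn.Module or a config dict that it forwards to MLP, which is what makes these heads buildable from a YAML-style config.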
aimnet/modules/lr.py ADDED
@@ -0,0 +1,243 @@
+ from typing import Dict, Optional
+
+ import torch
+ from torch import Tensor, nn
+
+ from aimnet import constants, nbops, ops
+
+
+ class LRCoulomb(nn.Module):
+     def __init__(
+         self,
+         key_in: str = "charges",
+         key_out: str = "e_h",
+         rc: float = 4.6,
+         method: str = "simple",
+         dsf_alpha: float = 0.2,
+         dsf_rc: float = 15.0,
+     ):
+         super().__init__()
+         self.key_in = key_in
+         self.key_out = key_out
+         self._factor = constants.half_Hartree * constants.Bohr
+         self.register_buffer("rc", torch.tensor(rc))
+         self.dsf_alpha = dsf_alpha
+         self.dsf_rc = dsf_rc
+         if method in ("simple", "dsf", "ewald"):
+             self.method = method
+         else:
+             raise ValueError(f"Unknown method {method}")
+
+     def coul_simple(self, data: Dict[str, Tensor]) -> Tensor:
+         data = ops.lazy_calc_dij_lr(data)
+         d_ij = data["d_ij_lr"]
+         q = data[self.key_in]
+         q_i, q_j = nbops.get_ij(q, data, suffix="_lr")
+         q_ij = q_i * q_j
+         # complement of the short-range cutoff: keep only the long-range part
+         fc = 1.0 - ops.exp_cutoff(d_ij, self.rc)
+         e_ij = fc * q_ij / d_ij
+         e_ij = nbops.mask_ij_(e_ij, data, 0.0, suffix="_lr")
+         e_i = e_ij.sum(-1)
+         e = self._factor * nbops.mol_sum(e_i, data)
+         return e
+
+     def coul_simple_sr(self, data: Dict[str, Tensor]) -> Tensor:
+         d_ij = data["d_ij"]
+         q = data[self.key_in]
+         q_i, q_j = nbops.get_ij(q, data)
+         q_ij = q_i * q_j
+         fc = ops.exp_cutoff(d_ij, self.rc)
+         e_ij = fc * q_ij / d_ij
+         e_ij = nbops.mask_ij_(e_ij, data, 0.0)
+         e_i = e_ij.sum(-1)
+         e = self._factor * nbops.mol_sum(e_i, data)
+         return e
+
+     def coul_dsf(self, data: Dict[str, Tensor]) -> Tensor:
+         data = ops.lazy_calc_dij_lr(data)
+         d_ij = data["d_ij_lr"]
+         q = data[self.key_in]
+         q_i, q_j = nbops.get_ij(q, data, suffix="_lr")
+         J = ops.coulomb_matrix_dsf(d_ij, self.dsf_rc, self.dsf_alpha, data)
+         e = (q_i * q_j * J).sum(-1)
+         e = self._factor * nbops.mol_sum(e, data)
+         # subtract the short-range part, which is covered within the model cutoff
+         e = e - self.coul_simple_sr(data)
+         return e
+
+     def coul_ewald(self, data: Dict[str, Tensor]) -> Tensor:
+         J = ops.coulomb_matrix_ewald(data["coord"], data["cell"])
+         q_i, q_j = data["charges"].unsqueeze(-1), data["charges"].unsqueeze(-2)
+         e = self._factor * (q_i * q_j * J).flatten(-2, -1).sum(-1)
+         e = e - self.coul_simple_sr(data)
+         return e
+
+     def forward(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]:
+         if self.method == "simple":
+             e = self.coul_simple(data)
+         elif self.method == "dsf":
+             e = self.coul_dsf(data)
+         elif self.method == "ewald":
+             e = self.coul_ewald(data)
+         else:
+             raise ValueError(f"Unknown method {self.method}")
+         if self.key_out in data:
+             data[self.key_out] = data[self.key_out] + e
+         else:
+             data[self.key_out] = e
+         return data
+
+
+ class DispParam(nn.Module):
+     def __init__(
+         self,
+         ref_c6: Optional[Dict[int, Tensor] | Tensor] = None,
+         ref_alpha: Optional[Dict[int, Tensor] | Tensor] = None,
+         ptfile: Optional[str] = None,
+         key_in: str = "disp_param",
+         key_out: str = "disp_param",
+     ):
+         super().__init__()
+         if (ptfile is None and (ref_c6 is None or ref_alpha is None)) or (
+             ptfile is not None and (ref_c6 is not None or ref_alpha is not None)
+         ):
+             raise ValueError("Either ptfile or both ref_c6 and ref_alpha should be supplied.")
+         # load data
+         ref = torch.load(ptfile) if ptfile is not None else torch.zeros(87, 2)
+         for i, p in enumerate([ref_c6, ref_alpha]):
+             if p is not None:
+                 if isinstance(p, Tensor):
+                     ref[: p.shape[0], i] = p
+                 else:
+                     for k, v in p.items():
+                         ref[k, i] = v
+         # c6=0 and alpha=1 for dummy atom
+         ref[0, 0] = 0.0
+         ref[0, 1] = 1.0
+         self.register_buffer("disp_param0", ref)
+         self.key_in = key_in
+         self.key_out = key_out
+
+     def forward(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]:
+         # NN predicts log-scale multipliers on the reference parameters
+         disp_param_mult = data[self.key_in].clamp(min=-4, max=4).exp()
+         disp_param = self.disp_param0[data["numbers"]]
+         vals = disp_param * disp_param_mult
+         data[self.key_out] = vals
+         return data
+
+
+ class D3TS(nn.Module):
+     """DFT-D3-like pairwise dispersion with the TS combination rule."""
+
+     def __init__(self, a1: float, a2: float, s8: float, s6: float = 1.0, key_in="disp_param", key_out="energy"):
+         super().__init__()
+         self.register_buffer("r4r2", constants.get_r4r2())
+         self.a1 = a1
+         self.a2 = a2
+         self.s6 = s6
+         self.s8 = s8
+         self.key_in = key_in
+         self.key_out = key_out
+
+     def forward(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]:
+         disp_param = data[self.key_in]
+         disp_param_i, disp_param_j = nbops.get_ij(disp_param, data, suffix="_lr")
+         c6_i, alpha_i = disp_param_i.unbind(dim=-1)
+         c6_j, alpha_j = disp_param_j.unbind(dim=-1)
+
+         # TS combination rule
+         c6ij = 2 * c6_i * c6_j / (c6_i * alpha_j / alpha_i + c6_j * alpha_i / alpha_j).clamp(min=1e-4)
+
+         rr = self.r4r2[data["numbers"]]
+         rr_i, rr_j = nbops.get_ij(rr, data, suffix="_lr")
+         rrij = 3 * rr_i * rr_j
+         rrij = nbops.mask_ij_(rrij, data, 1.0, suffix="_lr")
+         r0ij = self.a1 * rrij.sqrt() + self.a2
+
+         ops.lazy_calc_dij_lr(data)
+         d_ij = data["d_ij_lr"] * constants.Bohr_inv
+         # BJ-damped C6 and C8 terms
+         e_ij = c6ij * (self.s6 / (d_ij.pow(6) + r0ij.pow(6)) + self.s8 * rrij / (d_ij.pow(8) + r0ij.pow(8)))
+         e = -constants.half_Hartree * nbops.mol_sum(e_ij.sum(-1), data)
+
+         if self.key_out in data:
+             data[self.key_out] = data[self.key_out] + e
+         else:
+             data[self.key_out] = e
+
+         return data
+
+
+ class DFTD3(nn.Module):
+     """DFT-D3 implementation.
+
+     BJ damping, C6 and C8 terms, without the 3-body term.
+     """
+
+     def __init__(self, s8: float, a1: float, a2: float, s6: float = 1.0, key_out="energy"):
+         super().__init__()
+         self.key_out = key_out
+         # BJ damping parameters
+         self.s6 = s6
+         self.s8 = s8
+         self.s9 = 4.0 / 3.0
+         self.a1 = a1
+         self.a2 = a2
+         self.a3 = 16.0
+         # CN parameters
+         self.k1 = -16.0
+         self.k3 = -4.0
+         # data
+         self.register_buffer("c6ab", torch.zeros(95, 95, 5, 5, 3))
+         self.register_buffer("r4r2", torch.zeros(95))
+         self.register_buffer("rcov", torch.zeros(95))
+         self.register_buffer("cnmax", torch.zeros(95))
+         sd = constants.get_dftd3_param()
+         self.load_state_dict(sd)
+
+     def _calc_c6ij(self, data: Dict[str, Tensor]) -> Tensor:
+         # CN part
+         # short range for CN
+         # d_ij = data["d_ij"] * constants.Bohr_inv
+         data = ops.lazy_calc_dij_lr(data)
+         d_ij = data["d_ij_lr"] * constants.Bohr_inv
+
+         numbers = data["numbers"]
+         numbers_i, numbers_j = nbops.get_ij(numbers, data, suffix="_lr")
+         rcov_i, rcov_j = nbops.get_ij(self.rcov[numbers], data, suffix="_lr")
+         rcov_ij = rcov_i + rcov_j
+         cn_ij = 1.0 / (1.0 + torch.exp(self.k1 * (rcov_ij / d_ij - 1.0)))
+         cn_ij = nbops.mask_ij_(cn_ij, data, 0.0, suffix="_lr")
+         cn = cn_ij.sum(-1)
+         cn = torch.clamp(cn, max=self.cnmax[numbers]).unsqueeze(-1).unsqueeze(-1)
+         cn_i, cn_j = nbops.get_ij(cn, data, suffix="_lr")
+         c6ab = self.c6ab[numbers_i, numbers_j]
+         c6ref, cnref_i, cnref_j = torch.unbind(c6ab, dim=-1)
+         c6ref = nbops.mask_ij_(c6ref, data, 0.0, suffix="_lr")
+         # Gaussian weights over the 5x5 grid of C6 reference points
+         l_ij = torch.exp(self.k3 * ((cn_i - cnref_i).pow(2) + (cn_j - cnref_j).pow(2)))
+         w = l_ij.flatten(-2, -1).sum(-1)
+         z = torch.einsum("...ij,...ij->...", c6ref, l_ij)
+         _w = w < 1e-5
+         z[_w] = 0.0
+         c6_ij = z / w.clamp(min=1e-5)
+         return c6_ij
+
+     def forward(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]:
+         c6ij = self._calc_c6ij(data)
+
+         rr = self.r4r2[data["numbers"]]
+         rr_i, rr_j = nbops.get_ij(rr, data, suffix="_lr")
+         rrij = 3 * rr_i * rr_j
+         rrij = nbops.mask_ij_(rrij, data, 1.0, suffix="_lr")
+         r0ij = self.a1 * rrij.sqrt() + self.a2
+
+         ops.lazy_calc_dij_lr(data)
+         d_ij = data["d_ij_lr"] * constants.Bohr_inv
+         e_ij = c6ij * (self.s6 / (d_ij.pow(6) + r0ij.pow(6)) + self.s8 * rrij / (d_ij.pow(8) + r0ij.pow(8)))
+         e = -constants.half_Hartree * nbops.mol_sum(e_ij.sum(-1), data)
+
+         if self.key_out in data:
+             data[self.key_out] = data[self.key_out] + e
+         else:
+             data[self.key_out] = e
+         return data
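
All of the energy modules in this file share one convention: if key_out already exists in the data dict, the module adds its term to it, so several physics terms (Coulomb, dispersion, repulsion) can accumulate in a single "energy" entry. A sketch of that pattern, under the assumption that ops.lazy_calc_dij_lr (not shown in this diff) can build the long-range distance matrix from "coord" alone in dense mode, as the mode-0 mask_ij_lr bookkeeping in aimnet/nbops.py suggests:

import torch
from aimnet import nbops
from aimnet.modules.lr import LRCoulomb

# toy dense-mode batch; in practice "charges" come from an upstream module
data = {
    "numbers": torch.tensor([[8, 1, 1]]),
    "coord": torch.randn(1, 3, 3) * 2.0,
    "charges": torch.tensor([[-0.8, 0.4, 0.4]]),
    "energy": torch.zeros(1),  # pre-existing key: the module adds to it
}
data = nbops.set_nb_mode(data)
data = nbops.calc_masks(data)

coul = LRCoulomb(method="simple", key_out="energy")
data = coul(data)  # "energy" now includes the long-range Coulomb term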
aimnet/nbops.py ADDED
@@ -0,0 +1,151 @@
+ from typing import Dict, Tuple
+
+ import torch
+ from torch import Tensor
+
+
+ def set_nb_mode(data: Dict[str, Tensor]) -> Dict[str, Tensor]:
+     """Logic to guess and set the neighbor mode."""
+     if "nbmat" in data:
+         if data["nbmat"].ndim == 2:
+             data["_nb_mode"] = torch.tensor(1)
+         elif data["nbmat"].ndim == 3:
+             data["_nb_mode"] = torch.tensor(2)
+         else:
+             raise ValueError(f"Invalid neighbor matrix shape: {data['nbmat'].shape}")
+     else:
+         data["_nb_mode"] = torch.tensor(0)
+     return data
+
+
+ def get_nb_mode(data: Dict[str, Tensor]) -> int:
+     """Get the neighbor mode."""
+     return int(data["_nb_mode"].item())
+
+
+ def calc_masks(data: Dict[str, Tensor]) -> Dict[str, Tensor]:
+     """Calculate neighbor masks."""
+     nb_mode = get_nb_mode(data)
+     if nb_mode == 0:
+         data["mask_i"] = data["numbers"] == 0
+         data["mask_ij"] = torch.eye(
+             data["numbers"].shape[1], device=data["numbers"].device, dtype=torch.bool
+         ).unsqueeze(0)
+         if data["mask_i"].any():
+             data["_input_padded"] = torch.tensor(True)
+             data["_natom"] = data["mask_i"].logical_not().sum(-1)
+             data["mol_sizes"] = (~data["mask_i"]).sum(-1)
+             data["mask_ij"] = data["mask_ij"] | (data["mask_i"].unsqueeze(-2) | data["mask_i"].unsqueeze(-1))
+         else:
+             data["_input_padded"] = torch.tensor(False)
+             data["_natom"] = torch.tensor(data["numbers"].shape[1], device=data["numbers"].device)
+             data["mol_sizes"] = torch.tensor(data["numbers"].shape[1], device=data["numbers"].device)
+         data["mask_ij_lr"] = data["mask_ij"]
+     elif nb_mode == 1:
+         # padding must be the last atom
+         data["mask_i"] = torch.zeros(data["numbers"].shape[0], device=data["numbers"].device, dtype=torch.bool)
+         data["mask_i"][-1] = True
+         for suffix in ("", "_lr"):
+             if f"nbmat{suffix}" in data:
+                 data[f"mask_ij{suffix}"] = data[f"nbmat{suffix}"] == data["numbers"].shape[0] - 1
+         data["_input_padded"] = torch.tensor(True)
+         data["mol_sizes"] = torch.bincount(data["mol_idx"])
+         # last atom is padding
+         data["mol_sizes"][-1] -= 1
+     elif nb_mode == 2:
+         data["mask_i"] = data["numbers"] == 0
+         w = torch.where(data["mask_i"])
+         pad_idx = w[0] * data["numbers"].shape[1] + w[1]
+         for suffix in ("", "_lr"):
+             if f"nbmat{suffix}" in data:
+                 data[f"mask_ij{suffix}"] = torch.isin(data[f"nbmat{suffix}"], pad_idx)
+         data["_input_padded"] = torch.tensor(True)
+         data["mol_sizes"] = (~data["mask_i"]).sum(-1)
+     else:
+         raise ValueError(f"Invalid neighbor mode: {nb_mode}")
+
+     return data
+
+
+ def mask_ij_(
+     x: Tensor,
+     data: Dict[str, Tensor],
+     mask_value: float = 0.0,
+     inplace: bool = True,
+     suffix: str = "",
+ ) -> Tensor:
+     mask = data[f"mask_ij{suffix}"]
+     # broadcast the mask over trailing feature dimensions
+     for _i in range(x.ndim - mask.ndim):
+         mask = mask.unsqueeze(-1)
+     if inplace:
+         x.masked_fill_(mask, mask_value)
+     else:
+         x = x.masked_fill(mask, mask_value)
+     return x
+
+
+ def mask_i_(x: Tensor, data: Dict[str, Tensor], mask_value: float = 0.0, inplace: bool = True) -> Tensor:
+     nb_mode = get_nb_mode(data)
+     if nb_mode == 0:
+         if data["_input_padded"].item():
+             mask = data["mask_i"]
+             for _i in range(x.ndim - mask.ndim):
+                 mask = mask.unsqueeze(-1)
+             if inplace:
+                 x.masked_fill_(mask, mask_value)
+             else:
+                 x = x.masked_fill(mask, mask_value)
+     elif nb_mode == 1:
+         if inplace:
+             x[-1] = mask_value
+         else:
+             x = torch.cat([x[:-1], torch.full_like(x[:1], mask_value)], dim=0)
+     elif nb_mode == 2:
+         if inplace:
+             x[:, -1] = mask_value
+         else:
+             x = torch.cat([x[:, :-1], torch.full_like(x[:, :1], mask_value)], dim=1)
+     else:
+         raise ValueError(f"Invalid neighbor mode: {nb_mode}")
+     return x
+
+
+ def get_ij(x: Tensor, data: Dict[str, Tensor], suffix: str = "") -> Tuple[Tensor, Tensor]:
+     nb_mode = get_nb_mode(data)
+     if nb_mode == 0:
+         # dense mode: all-to-all pairs via broadcasting
+         x_i = x.unsqueeze(2)
+         x_j = x.unsqueeze(1)
+     elif nb_mode == 1:
+         x_i = x.unsqueeze(1)
+         idx = data[f"nbmat{suffix}"]
+         x_j = torch.index_select(x, 0, idx.flatten()).unflatten(0, idx.shape)
+     elif nb_mode == 2:
+         x_i = x.unsqueeze(2)
+         idx = data[f"nbmat{suffix}"]
+         x_j = torch.index_select(x.flatten(0, 1), 0, idx.flatten()).unflatten(0, idx.shape)
+     else:
+         raise ValueError(f"Invalid neighbor mode: {nb_mode}")
+     return x_i, x_j
+
+
+ def mol_sum(x: Tensor, data: Dict[str, Tensor]) -> Tensor:
+     nb_mode = get_nb_mode(data)
+     if nb_mode in (0, 2):
+         res = x.sum(dim=1)
+     elif nb_mode == 1:
+         assert x.ndim in (1, 2), "Invalid tensor shape for mol_sum, ndim should be 1 or 2"
+         idx = data["mol_idx"]
+         # assuming mol_idx is sorted, replace with max if not
+         out_size = int(idx[-1].item()) + 1
+         if x.ndim == 1:
+             res = torch.zeros(out_size, device=x.device, dtype=x.dtype)
+         else:
+             idx = idx.unsqueeze(-1).expand(-1, x.shape[1])
+             res = torch.zeros(out_size, x.shape[1], device=x.device, dtype=x.dtype)
+         res.scatter_add_(0, idx, x)
+     else:
+         raise ValueError(f"Invalid neighbor mode: {nb_mode}")
+     return res
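
In the flat (mode-1) layout, all atoms of a batch live in one leading dimension and "mol_idx" maps each atom to its molecule, so mol_sum reduces per molecule with scatter_add_. A small sketch; "_nb_mode" is set directly here for brevity, whereas set_nb_mode would normally infer mode 1 from a 2-D "nbmat":

import torch
from aimnet import nbops

# flat (mode-1) layout: 5 atoms from 2 molecules; mol_idx must be sorted
data = {"mol_idx": torch.tensor([0, 0, 0, 1, 1]),
        "_nb_mode": torch.tensor(1)}  # set directly for illustration
x = torch.arange(5.0)                 # one scalar per atom
print(nbops.mol_sum(x, data))         # tensor([3., 7.])

# per-atom feature vectors reduce the same way, molecule by molecule
x2 = torch.ones(5, 4)
print(nbops.mol_sum(x2, data).shape)  # torch.Size([2, 4])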