aimnet 0.0.1__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aimnet/__init__.py +7 -0
- aimnet/base.py +24 -8
- aimnet/calculators/__init__.py +4 -4
- aimnet/calculators/aimnet2ase.py +19 -6
- aimnet/calculators/calculator.py +868 -108
- aimnet/calculators/model_registry.py +2 -5
- aimnet/calculators/model_registry.yaml +55 -17
- aimnet/cli.py +62 -6
- aimnet/config.py +8 -9
- aimnet/data/sgdataset.py +23 -22
- aimnet/kernels/__init__.py +66 -0
- aimnet/kernels/conv_sv_2d_sp_wp.py +478 -0
- aimnet/models/__init__.py +13 -1
- aimnet/models/aimnet2.py +19 -22
- aimnet/models/base.py +183 -15
- aimnet/models/convert.py +30 -0
- aimnet/models/utils.py +735 -0
- aimnet/modules/__init__.py +1 -1
- aimnet/modules/aev.py +49 -48
- aimnet/modules/core.py +14 -13
- aimnet/modules/lr.py +520 -115
- aimnet/modules/ops.py +537 -0
- aimnet/nbops.py +105 -15
- aimnet/ops.py +90 -18
- aimnet/train/export_model.py +226 -0
- aimnet/train/loss.py +7 -7
- aimnet/train/metrics.py +5 -6
- aimnet/train/train.py +4 -1
- aimnet/train/utils.py +42 -13
- aimnet-0.1.0.dist-info/METADATA +308 -0
- aimnet-0.1.0.dist-info/RECORD +43 -0
- {aimnet-0.0.1.dist-info → aimnet-0.1.0.dist-info}/WHEEL +1 -1
- aimnet-0.1.0.dist-info/entry_points.txt +3 -0
- aimnet/calculators/nb_kernel_cpu.py +0 -222
- aimnet/calculators/nb_kernel_cuda.py +0 -217
- aimnet/calculators/nbmat.py +0 -220
- aimnet/train/pt2jpt.py +0 -81
- aimnet-0.0.1.dist-info/METADATA +0 -78
- aimnet-0.0.1.dist-info/RECORD +0 -41
- aimnet-0.0.1.dist-info/entry_points.txt +0 -5
- {aimnet-0.0.1.dist-info → aimnet-0.1.0.dist-info/licenses}/LICENSE +0 -0
aimnet/calculators/calculator.py
CHANGED
|
@@ -1,20 +1,205 @@
|
|
|
1
|
+
import math
|
|
1
2
|
import warnings
|
|
2
|
-
from typing import Any, ClassVar,
|
|
3
|
+
from typing import Any, ClassVar, Literal
|
|
3
4
|
|
|
4
5
|
import torch
|
|
6
|
+
from nvalchemiops.neighborlist import neighbor_list
|
|
7
|
+
from nvalchemiops.neighborlist.neighbor_utils import NeighborOverflowError
|
|
5
8
|
from torch import Tensor, nn
|
|
6
9
|
|
|
10
|
+
from aimnet.models.base import load_model
|
|
11
|
+
from aimnet.modules import DFTD3, LRCoulomb
|
|
12
|
+
|
|
7
13
|
from .model_registry import get_model_path
|
|
8
|
-
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class AdaptiveNeighborList:
|
|
17
|
+
"""Adaptive neighbor list with automatic buffer sizing.
|
|
18
|
+
|
|
19
|
+
Wraps nvalchemiops.neighborlist.neighbor_list with automatic max_neighbors adjustment.
|
|
20
|
+
Maintains ~75% utilization to balance memory and recomputation.
|
|
21
|
+
|
|
22
|
+
Parameters
|
|
23
|
+
----------
|
|
24
|
+
cutoff : float
|
|
25
|
+
Cutoff distance for neighbor detection in Angstroms.
|
|
26
|
+
density : float, optional
|
|
27
|
+
Initial atomic density estimate for allocation sizing.
|
|
28
|
+
Used to compute initial max_neighbors as density * (4/3 * pi * cutoff^3).
|
|
29
|
+
Default is 0.2.
|
|
30
|
+
target_utilization : float, optional
|
|
31
|
+
Target ratio of actual neighbors to allocated max_neighbors.
|
|
32
|
+
Default is 0.75 (75% utilization).
|
|
33
|
+
|
|
34
|
+
Attributes
|
|
35
|
+
----------
|
|
36
|
+
cutoff : float
|
|
37
|
+
Cutoff distance for neighbor detection.
|
|
38
|
+
target_utilization : float
|
|
39
|
+
Target ratio of actual to allocated neighbors.
|
|
40
|
+
max_neighbors : int
|
|
41
|
+
Current maximum neighbor allocation (rounded to 16).
|
|
42
|
+
"""
|
|
43
|
+
|
|
44
|
+
def __init__(
|
|
45
|
+
self,
|
|
46
|
+
cutoff: float,
|
|
47
|
+
density: float = 0.2,
|
|
48
|
+
target_utilization: float = 0.75,
|
|
49
|
+
) -> None:
|
|
50
|
+
self.cutoff = cutoff
|
|
51
|
+
self.target_utilization = target_utilization
|
|
52
|
+
sphere_volume = 4 / 3 * math.pi * cutoff**3
|
|
53
|
+
self.max_neighbors = self._round_to_16(int(density * sphere_volume))
|
|
54
|
+
|
|
55
|
+
@staticmethod
|
|
56
|
+
def _round_to_16(n: int) -> int:
|
|
57
|
+
"""Round up to the next multiple of 16 for memory alignment."""
|
|
58
|
+
return ((n + 15) // 16) * 16
|
|
59
|
+
|
|
60
|
+
def __call__(
|
|
61
|
+
self,
|
|
62
|
+
positions: Tensor,
|
|
63
|
+
cell: Tensor | None = None,
|
|
64
|
+
pbc: Tensor | None = None,
|
|
65
|
+
batch_idx: Tensor | None = None,
|
|
66
|
+
fill_value: int | None = None,
|
|
67
|
+
) -> tuple[Tensor, Tensor, Tensor | None]:
|
|
68
|
+
"""Compute neighbor list with automatic buffer adjustment.
|
|
69
|
+
|
|
70
|
+
Parameters
|
|
71
|
+
----------
|
|
72
|
+
positions : Tensor
|
|
73
|
+
Atomic coordinates, shape (N, 3).
|
|
74
|
+
cell : Tensor | None
|
|
75
|
+
Unit cell vectors, shape (num_systems, 3, 3). None for non-periodic.
|
|
76
|
+
pbc : Tensor | None
|
|
77
|
+
Periodic boundary conditions, shape (num_systems, 3). None for non-periodic.
|
|
78
|
+
batch_idx : Tensor | None
|
|
79
|
+
Batch index for each atom, shape (N,). None for single system.
|
|
80
|
+
fill_value : int | None
|
|
81
|
+
Fill value for padding. Default is N (number of atoms).
|
|
82
|
+
|
|
83
|
+
Returns
|
|
84
|
+
-------
|
|
85
|
+
nbmat : Tensor
|
|
86
|
+
Neighbor indices, shape (N, actual_max_neighbors).
|
|
87
|
+
num_neighbors : Tensor
|
|
88
|
+
Number of neighbors per atom, shape (N,).
|
|
89
|
+
shifts : Tensor | None
|
|
90
|
+
Integer unit cell shifts for PBC, shape (N, actual_max_neighbors, 3).
|
|
91
|
+
None for non-periodic systems.
|
|
92
|
+
"""
|
|
93
|
+
N = positions.shape[0]
|
|
94
|
+
if fill_value is None:
|
|
95
|
+
fill_value = N
|
|
96
|
+
_pbc = cell is not None
|
|
97
|
+
|
|
98
|
+
while True:
|
|
99
|
+
try:
|
|
100
|
+
if _pbc:
|
|
101
|
+
nbmat, num_neighbors, shifts = neighbor_list(
|
|
102
|
+
positions=positions,
|
|
103
|
+
cutoff=self.cutoff,
|
|
104
|
+
cell=cell,
|
|
105
|
+
pbc=pbc,
|
|
106
|
+
batch_idx=batch_idx,
|
|
107
|
+
max_neighbors=self.max_neighbors,
|
|
108
|
+
half_fill=False,
|
|
109
|
+
fill_value=fill_value,
|
|
110
|
+
)
|
|
111
|
+
else:
|
|
112
|
+
nbmat, num_neighbors = neighbor_list(
|
|
113
|
+
positions=positions,
|
|
114
|
+
cutoff=self.cutoff,
|
|
115
|
+
batch_idx=batch_idx,
|
|
116
|
+
max_neighbors=self.max_neighbors,
|
|
117
|
+
half_fill=False,
|
|
118
|
+
fill_value=fill_value,
|
|
119
|
+
method="batch_naive",
|
|
120
|
+
)
|
|
121
|
+
shifts = None
|
|
122
|
+
except NeighborOverflowError:
|
|
123
|
+
# Increase buffer by 1.5x and retry
|
|
124
|
+
self.max_neighbors = self._round_to_16(int(self.max_neighbors * 1.5))
|
|
125
|
+
continue
|
|
126
|
+
|
|
127
|
+
# Get actual max neighbors from result
|
|
128
|
+
actual_max = int(num_neighbors.max().item())
|
|
129
|
+
|
|
130
|
+
# Adjust buffer if under-utilized (shrink at 2/3 of target for hysteresis)
|
|
131
|
+
# Use 2/3 threshold to prevent thrashing from small fluctuations
|
|
132
|
+
if actual_max < (2 / 3) * self.target_utilization * self.max_neighbors:
|
|
133
|
+
new_max = self._round_to_16(int(actual_max / self.target_utilization))
|
|
134
|
+
self.max_neighbors = max(new_max, 16) # Ensure minimum of 16
|
|
135
|
+
|
|
136
|
+
# Trim to actual max neighbors
|
|
137
|
+
actual_nnb = max(1, actual_max)
|
|
138
|
+
nbmat = nbmat[:, :actual_nnb]
|
|
139
|
+
if shifts is not None:
|
|
140
|
+
shifts = shifts[:, :actual_nnb]
|
|
141
|
+
|
|
142
|
+
return nbmat, num_neighbors, shifts
|
|
9
143
|
|
|
10
144
|
|
|
11
145
|
class AIMNet2Calculator:
|
|
12
|
-
"""
|
|
146
|
+
"""Generic AIMNet2 calculator.
|
|
147
|
+
|
|
13
148
|
A helper class to load AIMNet2 models and perform inference.
|
|
149
|
+
|
|
150
|
+
Parameters
|
|
151
|
+
----------
|
|
152
|
+
model : str | nn.Module
|
|
153
|
+
Model name (from registry), path to model file, or nn.Module instance.
|
|
154
|
+
nb_threshold : int
|
|
155
|
+
Threshold for neighbor list batching. Molecules larger than this use
|
|
156
|
+
flattened processing. Default is 120.
|
|
157
|
+
needs_coulomb : bool | None
|
|
158
|
+
Whether to add external Coulomb module. If None (default), determined
|
|
159
|
+
from model metadata. If True/False, overrides metadata.
|
|
160
|
+
needs_dispersion : bool | None
|
|
161
|
+
Whether to add external DFTD3 module. If None (default), determined
|
|
162
|
+
from model metadata. If True/False, overrides metadata.
|
|
163
|
+
device : str | None
|
|
164
|
+
Device to run the model on ("cuda", "cpu", or specific like "cuda:0").
|
|
165
|
+
If None (default), auto-detects CUDA availability.
|
|
166
|
+
compile_model : bool
|
|
167
|
+
Whether to compile the model with torch.compile(). Default is False.
|
|
168
|
+
compile_kwargs : dict | None
|
|
169
|
+
Additional keyword arguments to pass to torch.compile(). Default is None.
|
|
170
|
+
train : bool
|
|
171
|
+
Whether to enable training mode. Default is False (inference mode).
|
|
172
|
+
When False, all model parameters have requires_grad=False, which
|
|
173
|
+
improves torch.compile compatibility and reduces memory usage.
|
|
174
|
+
Set to True only when training the model.
|
|
175
|
+
|
|
176
|
+
Attributes
|
|
177
|
+
----------
|
|
178
|
+
model : nn.Module
|
|
179
|
+
The loaded AIMNet2 model.
|
|
180
|
+
device : str
|
|
181
|
+
Device the model is running on ("cuda" or "cpu").
|
|
182
|
+
cutoff : float
|
|
183
|
+
Short-range cutoff distance in Angstroms.
|
|
184
|
+
cutoff_lr : float | None
|
|
185
|
+
Long-range cutoff distance, or None if no LR modules.
|
|
186
|
+
external_coulomb : LRCoulomb | None
|
|
187
|
+
External Coulomb module if attached.
|
|
188
|
+
external_dftd3 : DFTD3 | None
|
|
189
|
+
External DFTD3 module if attached.
|
|
190
|
+
|
|
191
|
+
Notes
|
|
192
|
+
-----
|
|
193
|
+
External LR module behavior:
|
|
194
|
+
|
|
195
|
+
- For file-loaded models (str): metadata is loaded from file
|
|
196
|
+
- For nn.Module: metadata is read from the model._metadata attribute if available
|
|
197
|
+
- Explicit flags (needs_coulomb, needs_dispersion) override metadata
|
|
198
|
+
- If no metadata and no explicit flags, no external LR modules are added
|
|
14
199
|
"""
|
|
15
200
|
|
|
16
|
-
keys_in: ClassVar[
|
|
17
|
-
keys_in_optional: ClassVar[
|
|
201
|
+
keys_in: ClassVar[dict[str, torch.dtype]] = {"coord": torch.float, "numbers": torch.int, "charge": torch.float}
|
|
202
|
+
keys_in_optional: ClassVar[dict[str, torch.dtype]] = {
|
|
18
203
|
"mult": torch.float,
|
|
19
204
|
"mol_idx": torch.int,
|
|
20
205
|
"nbmat": torch.int,
|
|
@@ -28,112 +213,538 @@ class AIMNet2Calculator:
|
|
|
28
213
|
keys_out: ClassVar[list[str]] = ["energy", "charges", "forces", "hessian", "stress"]
|
|
29
214
|
atom_feature_keys: ClassVar[list[str]] = ["coord", "numbers", "charges", "forces"]
|
|
30
215
|
|
|
31
|
-
def __init__(
|
|
32
|
-
self
|
|
216
|
+
def __init__(
    self,
    model: str | nn.Module = "aimnet2",
    nb_threshold: int = 120,
    needs_coulomb: bool | None = None,
    needs_dispersion: bool | None = None,
    device: str | None = None,
    compile_model: bool = False,
    compile_kwargs: dict | None = None,
    train: bool = False,
):
    """Load the model, attach external long-range modules and build neighbor lists.

    Raises
    ------
    TypeError
        If ``model`` is neither a string nor an ``nn.Module``.
    ValueError
        If dispersion is requested but the metadata has no ``d3_params``.
    """
    # An explicit device wins; otherwise prefer CUDA when it is available.
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    self.device = device

    self.external_coulomb: LRCoulomb | None = None
    self.external_dftd3: DFTD3 | None = None

    # Fallback cutoffs / smoothing for the long-range modules.
    self._default_dsf_cutoff = 15.0
    self._default_dftd3_cutoff = 15.0
    self._default_dftd3_smoothing = 0.2

    # Resolve the model and whatever metadata travels with it.
    metadata: dict | None = None
    if isinstance(model, str):
        self.model, metadata = load_model(get_model_path(model), device=self.device)
        self.cutoff = metadata["cutoff"]
    elif isinstance(model, nn.Module):
        self.model = model.to(self.device)
        self.cutoff = getattr(self.model, "cutoff", 5.0)
        metadata = getattr(self.model, "_metadata", None)
    else:
        raise TypeError("Invalid model type/name.")
    # An empty mapping stands in for missing metadata, so every lookup below
    # simply falls back to its default.
    meta = {} if metadata is None else metadata

    if compile_model:
        self.model = torch.compile(self.model, **(compile_kwargs or {}))

    # Explicit flags take precedence over the metadata.
    use_coulomb = needs_coulomb if needs_coulomb is not None else meta.get("needs_coulomb", False)
    use_dispersion = needs_dispersion if needs_dispersion is not None else meta.get("needs_dispersion", False)

    if use_coulomb:
        sr_embedded = meta.get("coulomb_mode") == "sr_embedded"
        # sr_embedded: the model already subtracts the short-range part, so the
        # external module computes the FULL term: (NN - SR) + FULL = NN + LR.
        # Otherwise the external module computes LR only (subtract_sr=True) to
        # avoid double counting. For PBC the method can later be switched via
        # set_lrcoulomb_method().
        self.external_coulomb = LRCoulomb(
            key_in="charges",
            key_out="energy",
            method="simple",
            rc=meta.get("coulomb_sr_rc", 4.6),
            envelope=meta.get("coulomb_sr_envelope", "exp"),
            subtract_sr=not sr_embedded,
        ).to(self.device)

    if use_dispersion:
        d3_params = meta.get("d3_params")
        if d3_params is None:
            raise ValueError(
                "needs_dispersion=True but d3_params not found in metadata. "
                "Provide d3_params in model metadata or set needs_dispersion=False."
            )
        self.external_dftd3 = DFTD3(
            s8=d3_params["s8"],
            a1=d3_params["a1"],
            a2=d3_params["a2"],
            s6=d3_params.get("s6", 1.0),
        ).to(self.device)

    # Long-range handling is needed for embedded as well as external modules.
    has_embedded_lr = meta.get("has_embedded_lr", False)
    model_has_cutoff_lr = hasattr(self.model, "cutoff_lr")
    self.lr = (
        model_has_cutoff_lr
        or self.external_coulomb is not None
        or self.external_dftd3 is not None
        or has_embedded_lr
    )

    # Long-range cutoff: model attribute first, then external/embedded modules.
    if model_has_cutoff_lr:
        self.cutoff_lr = getattr(self.model, "cutoff_lr", float("inf"))
    elif self.external_coulomb is not None:
        # "simple" Coulomb uses all pairs; anything else uses the DSF default.
        if self.external_coulomb.method == "simple":
            self.cutoff_lr = float("inf")
        else:
            self.cutoff_lr = self._default_dsf_cutoff
    elif self.external_dftd3 is not None or has_embedded_lr:
        # External DFTD3 or embedded LR modules (D3TS, SRCoulomb) need nbmat_lr.
        self.cutoff_lr = self._default_dftd3_cutoff
    else:
        self.cutoff_lr = None
    self.nb_threshold = nb_threshold

    # Short-range adaptive neighbor list.
    self._nblist = AdaptiveNeighborList(cutoff=self.cutoff)

    # Per-module long-range cutoffs.
    self._coulomb_cutoff: float | None = None
    self._dftd3_cutoff: float = self._default_dftd3_cutoff
    if self.external_coulomb is not None:
        coulomb_kind = self.external_coulomb.method
        if coulomb_kind == "simple":
            self._coulomb_cutoff = float("inf")
        elif coulomb_kind == "ewald":
            self._coulomb_cutoff = None  # Ewald manages its own cutoff
        else:
            self._coulomb_cutoff = self.external_coulomb.dsf_rc
    if self.external_dftd3 is not None:
        self._dftd3_cutoff = self.external_dftd3.smoothing_off

    # Long-range neighbor list(s); populated by _update_lr_nblists().
    self._nblist_lr: AdaptiveNeighborList | None = None
    self._nblist_dftd3: AdaptiveNeighborList | None = None
    self._nblist_coulomb: AdaptiveNeighborList | None = None
    self._update_lr_nblists()

    # indicator if input was flattened
    self._batch = None
    self._max_mol_size: int = 0
    # placeholder for tensors that require grad
    self._saved_for_grad = {}

    # Currently active Coulomb method.
    self._coulomb_method: str | None = None
    if self.external_coulomb is not None:
        self._coulomb_method = self.external_coulomb.method
    elif self._has_embedded_coulomb():
        # Legacy models have embedded Coulomb with "simple" method
        self._coulomb_method = "simple"

    # Training mode (False by default, i.e. inference).
    self._train = train
    self.model.train(train)
    if not train:
        # Inference: freeze every parameter of the model and attached modules.
        for module in (self.model, self.external_coulomb, self.external_dftd3):
            if module is None:
                continue
            for param in module.parameters():
                param.requires_grad_(False)
|
|
60
377
|
|
|
61
378
|
def __call__(self, *args, **kwargs):
    """Make the calculator callable by forwarding all arguments to ``eval``."""
    outputs = self.eval(*args, **kwargs)
    return outputs
|
|
63
380
|
|
|
381
|
+
@property
def has_external_coulomb(self) -> bool:
    """Whether an external Coulomb module is attached to the calculator.

    True when ``external_coulomb`` is set (new-format models with
    externalized Coulomb). Legacy models keep Coulomb inside the model,
    so this is False for them.
    """
    attached = self.external_coulomb is not None
    return attached
|
|
390
|
+
|
|
391
|
+
@property
def has_external_dftd3(self) -> bool:
    """Whether an external DFTD3 module is attached to the calculator.

    True when ``external_dftd3`` is set (new-format models trained with
    DFTD3/D3BJ dispersion that has been externalized). Legacy and D3TS
    models keep dispersion inside the model, so this is False for them.
    """
    attached = self.external_dftd3 is not None
    return attached
|
|
400
|
+
|
|
401
|
+
@property
|
|
402
|
+
def coulomb_method(self) -> str | None:
|
|
403
|
+
"""Get the current Coulomb method.
|
|
404
|
+
|
|
405
|
+
Returns
|
|
406
|
+
-------
|
|
407
|
+
str | None
|
|
408
|
+
One of "simple", "dsf", "ewald", or None if no external Coulomb.
|
|
409
|
+
For legacy models with embedded Coulomb, returns None.
|
|
410
|
+
"""
|
|
411
|
+
if self.external_coulomb is not None:
|
|
412
|
+
return self.external_coulomb.method
|
|
413
|
+
return None
|
|
414
|
+
|
|
415
|
+
@property
|
|
416
|
+
def coulomb_cutoff(self) -> float | None:
|
|
417
|
+
"""Get the current Coulomb cutoff distance.
|
|
418
|
+
|
|
419
|
+
Returns
|
|
420
|
+
-------
|
|
421
|
+
float | None
|
|
422
|
+
The cutoff distance for Coulomb calculations, or None if not applicable.
|
|
423
|
+
For "simple" method, this is inf. For "ewald", this is None.
|
|
424
|
+
Use set_lrcoulomb_method() to change.
|
|
425
|
+
"""
|
|
426
|
+
return self._coulomb_cutoff
|
|
427
|
+
|
|
428
|
+
@property
def dftd3_cutoff(self) -> float:
    """Current DFTD3 cutoff distance.

    Returns
    -------
    float
        The cutoff distance for DFTD3 calculations in Angstroms.
    """
    current = self._dftd3_cutoff
    return current
|
|
438
|
+
|
|
439
|
+
def _has_embedded_dispersion(self) -> bool:
    """Tell whether dispersion lives inside the model (not externalized).

    The decision is based solely on the model's ``_metadata`` attribute.

    Returns
    -------
    bool
        True if the metadata indicates an embedded dispersion module
        (D3TS or legacy DFTD3). False when there is no metadata, since
        nothing is known about the model then.
    """
    metadata = getattr(self.model, "_metadata", None)
    if metadata is None:
        return False

    # New format: embedded LR that is not the SR-Coulomb flavour must be D3TS.
    embedded_lr = metadata.get("has_embedded_lr", False)
    if embedded_lr and metadata.get("coulomb_mode", "none") != "sr_embedded":
        return True

    # Legacy format: dispersion is embedded when it is not requested
    # externally but D3 parameters are nevertheless recorded.
    if metadata.get("needs_dispersion", False):
        return False
    return metadata.get("d3_params") is not None
|
|
463
|
+
|
|
464
|
+
def _has_embedded_coulomb(self) -> bool:
    """Tell whether Coulomb lives inside the model (not externalized).

    The decision is based solely on the model's ``_metadata`` attribute.

    Returns
    -------
    bool
        True if the metadata indicates an embedded Coulomb module. False
        when there is no metadata, since nothing is known about the model.
    """
    metadata = getattr(self.model, "_metadata", None)
    if metadata is None:
        return False
    # Coulomb requested externally -> it is not embedded.
    if metadata.get("needs_coulomb", False):
        return False
    # Otherwise any coulomb_mode other than "none" means the model carries it
    # (legacy JIT models have full Coulomb embedded).
    return metadata.get("coulomb_mode", "none") != "none"
|
|
480
|
+
|
|
481
|
+
def _should_use_separate_nblist(self, cutoff1: float, cutoff2: float) -> bool:
    """Decide whether two cutoffs are far enough apart for separate neighbor lists.

    Parameters
    ----------
    cutoff1 : float
        First cutoff distance.
    cutoff2 : float
        Second cutoff distance.

    Returns
    -------
    bool
        True if the larger cutoff exceeds the smaller one by more than 20%.
        False otherwise, and also whenever either cutoff is non-positive
        or not finite.
    """
    pair = (cutoff1, cutoff2)
    # Degenerate inputs never justify a second list.
    if any(value <= 0 for value in pair):
        return False
    if not all(math.isfinite(value) for value in pair):
        return False
    return max(pair) / min(pair) > 1.2
|
|
503
|
+
|
|
504
|
+
def _update_lr_nblists(self) -> None:
    """Rebuild the long-range neighbor list objects from the current cutoffs.

    DFTD3 and Coulomb get separate neighbor lists when both have finite,
    positive cutoffs that differ by more than 20%; otherwise a single shared
    list with the larger cutoff is used. An infinite Coulomb cutoff
    ("simple" method, all pairs) is served by one shared list with a very
    large cutoff. A Coulomb cutoff of None (Ewald) contributes no list.
    Without long-range modules all three lists are cleared.
    """
    shared_list = None
    dftd3_list = None
    coulomb_list = None

    if self.lr:
        has_dftd3 = self.external_dftd3 is not None or self._has_embedded_dispersion()
        has_coulomb = self.external_coulomb is not None or self._has_embedded_coulomb()

        # Effective cutoffs; 0.0 marks "no neighbor list needed for this module".
        dftd3_cutoff = self._dftd3_cutoff if has_dftd3 else 0.0
        coulomb_cutoff = 0.0
        if has_coulomb and self._coulomb_cutoff is not None:
            coulomb_cutoff = self._coulomb_cutoff

        dftd3_usable = has_dftd3 and math.isfinite(dftd3_cutoff)
        coulomb_usable = has_coulomb and math.isfinite(coulomb_cutoff) and coulomb_cutoff > 0

        if (
            dftd3_usable
            and coulomb_usable
            and self._should_use_separate_nblist(dftd3_cutoff, coulomb_cutoff)
        ):
            # Cutoffs differ too much: one list per module.
            dftd3_list = AdaptiveNeighborList(cutoff=dftd3_cutoff)
            coulomb_list = AdaptiveNeighborList(cutoff=coulomb_cutoff)
        elif has_coulomb and coulomb_cutoff == float("inf"):
            # Simple Coulomb needs all pairs
            shared_list = AdaptiveNeighborList(cutoff=1e6)
        else:
            # One shared list sized for the larger of the usable cutoffs.
            max_cutoff = max(
                dftd3_cutoff if dftd3_usable else 0.0,
                coulomb_cutoff if coulomb_usable else 0.0,
            )
            if max_cutoff > 0:
                shared_list = AdaptiveNeighborList(cutoff=max_cutoff)

    self._nblist_lr = shared_list
    self._nblist_dftd3 = dftd3_list
    self._nblist_coulomb = coulomb_list
|
|
559
|
+
|
|
64
560
|
def set_lrcoulomb_method(
    self,
    method: Literal["simple", "dsf", "ewald"],
    cutoff: float = 15.0,
    dsf_alpha: float = 0.2,
    ewald_accuracy: float = 1e-8,
):
    """Select the long-range Coulomb method.

    Parameters
    ----------
    method : str
        One of "simple", "dsf", or "ewald".
    cutoff : float
        Cutoff distance for the DSF neighbor list. Default is 15.0.
        Not used for Ewald (which computes cutoffs from accuracy).
    dsf_alpha : float
        Alpha parameter for the DSF method. Default is 0.2.
    ewald_accuracy : float
        Target accuracy for Ewald summation. Controls the real-space
        and reciprocal-space cutoffs. Lower values give higher accuracy
        but require more computation. Default is 1e-8.

        The Ewald cutoffs are computed as:
        - eta = (V^2 / N)^(1/6) / sqrt(2*pi)
        - cutoff_real = sqrt(-2 * ln(accuracy)) * eta
        - cutoff_recip = sqrt(-2 * ln(accuracy)) / eta

    Raises
    ------
    ValueError
        If ``method`` is not one of the supported names.

    Notes
    -----
    Only an external Coulomb module is reconfigured. For legacy models with
    embedded Coulomb a warning is issued, because those modules cannot be
    modified at runtime; the calculator's own cutoff bookkeeping and
    neighbor lists are still updated.
    """
    if method not in ("simple", "dsf", "ewald"):
        raise ValueError(f"Invalid method: {method}")

    coulomb = self.external_coulomb
    if coulomb is None:
        # Nothing external to reconfigure; tell the user if Coulomb is embedded.
        if self._has_embedded_coulomb():
            warnings.warn(
                "Model has embedded Coulomb module (legacy format). "
                "set_lrcoulomb_method() only affects external Coulomb modules. "
                "For legacy models, the Coulomb method cannot be changed at runtime.",
                stacklevel=2,
            )
    else:
        coulomb.method = method
        if method == "dsf":
            coulomb.dsf_alpha = dsf_alpha
            coulomb.dsf_rc = cutoff
        elif method == "ewald":
            coulomb.ewald_accuracy = ewald_accuracy

    # Neighbor-list cutoff implied by the method; Ewald manages its own
    # real-space cutoff, hence None.
    cutoff_by_method = {"simple": float("inf"), "dsf": cutoff, "ewald": None}
    self._coulomb_cutoff = cutoff_by_method[method]

    # Keep cutoff_lr in sync for backward compatibility.
    if self._coulomb_cutoff is None:
        # Ewald: fall back to the DFTD3 cutoff when such a module is attached.
        self.cutoff_lr = self._dftd3_cutoff if self.external_dftd3 is not None else None
    else:
        self.cutoff_lr = self._coulomb_cutoff

    self._coulomb_method = method
    self._update_lr_nblists()
|
|
632
|
+
|
|
633
|
+
def set_lr_cutoff(self, cutoff: float) -> None:
    """Apply one long-range cutoff to all LR neighbor lists.

    Parameters
    ----------
    cutoff : float
        Cutoff distance in Angstroms for LR neighbor lists.

    Notes
    -----
    Updates ``_dftd3_cutoff`` and ``cutoff_lr`` unconditionally and
    ``_coulomb_cutoff`` unless the current Coulomb method is "ewald"
    (Ewald manages its own cutoff). The neighbor lists are rebuilt afterwards.
    """
    uses_ewald = self._coulomb_method == "ewald"
    if not uses_ewald:
        self._coulomb_cutoff = cutoff
    self._dftd3_cutoff = self.cutoff_lr = cutoff
    self._update_lr_nblists()
|
|
652
|
+
|
|
653
|
+
def set_dftd3_cutoff(self, cutoff: float | None = None, smoothing_fraction: float | None = None) -> None:
|
|
654
|
+
"""Set DFTD3 cutoff and smoothing.
|
|
655
|
+
|
|
656
|
+
Parameters
|
|
657
|
+
----------
|
|
658
|
+
cutoff : float | None
|
|
659
|
+
Cutoff distance in Angstroms for DFTD3 calculation.
|
|
660
|
+
Default is _default_dftd3_cutoff (15.0).
|
|
661
|
+
smoothing_fraction : float | None
|
|
662
|
+
Fraction of cutoff used as smoothing width.
|
|
663
|
+
Default is _default_dftd3_smoothing (0.2).
|
|
664
|
+
|
|
665
|
+
Notes
|
|
666
|
+
-----
|
|
667
|
+
This method only affects external DFTD3 modules attached to
|
|
668
|
+
new-format models. For legacy models with embedded DFTD3,
|
|
669
|
+
the smoothing is fixed.
|
|
670
|
+
|
|
671
|
+
Updates _dftd3_cutoff and rebuilds neighbor lists.
|
|
672
|
+
"""
|
|
673
|
+
if cutoff is None:
|
|
674
|
+
cutoff = self._default_dftd3_cutoff
|
|
675
|
+
if smoothing_fraction is None:
|
|
676
|
+
smoothing_fraction = self._default_dftd3_smoothing
|
|
677
|
+
|
|
678
|
+
self._dftd3_cutoff = cutoff
|
|
679
|
+
if self.external_dftd3 is not None:
|
|
680
|
+
self.external_dftd3.set_smoothing(cutoff, smoothing_fraction)
|
|
681
|
+
self._update_lr_nblists()
|
|
682
|
+
|
|
683
|
+
def eval(self, data: dict[str, Any], forces=False, stress=False, hessian=False) -> dict[str, Tensor]:
    """Run the complete evaluation pipeline on one input dictionary.

    The steps are, in order: ``prepare_input``, ``set_grad_tensors``,
    the model call (with TorchScript optimized execution disabled),
    ``_run_external_modules``, ``get_derivatives`` and ``process_output``.

    Parameters
    ----------
    data : dict[str, Any]
        Raw input passed to ``prepare_input``.
    forces, stress, hessian : bool
        Derivative flags forwarded to ``set_grad_tensors`` and
        ``get_derivatives``; ``stress`` is also forwarded to the external
        modules as ``compute_stress``.

    Returns
    -------
    dict[str, Tensor]
        Whatever ``process_output`` returns.

    Raises
    ------
    NotImplementedError
        If ``hessian`` is requested and the last ``mol_idx`` entry is
        greater than zero (more than one molecule).
    """
    derivative_flags = {"forces": forces, "stress": stress, "hessian": hessian}
    prepared = self.prepare_input(data)

    if hessian:
        # Only look at mol_idx when a Hessian was actually requested.
        if "mol_idx" in prepared and prepared["mol_idx"][-1] > 0:
            raise NotImplementedError("Hessian calculation is not supported for multiple molecules")

    prepared = self.set_grad_tensors(prepared, **derivative_flags)
    with torch.jit.optimized_execution(False):  # type: ignore
        result = self.model(prepared)
    # External modules, if any, run on the model output.
    result = self._run_external_modules(result, compute_stress=stress)
    result = self.get_derivatives(result, **derivative_flags)
    return self.process_output(result)
|
|
92
696
|
|
|
93
|
-
def
|
|
697
|
+
def _run_external_modules(self, data: dict[str, Tensor], compute_stress: bool = False) -> dict[str, Tensor]:
    """Pass ``data`` through the attached external Coulomb and DFTD3 modules.

    The Coulomb module (if attached) runs first, then the DFTD3 module
    (if attached). Before the DFTD3 call its ``compute_virial`` attribute
    is set to ``compute_stress``. With no module attached, ``data`` is
    returned unchanged.
    """
    coulomb_module = self.external_coulomb
    if coulomb_module is not None:
        data = coulomb_module(data)

    dftd3_module = self.external_dftd3
    if dftd3_module is None:
        return data

    dftd3_module.compute_virial = compute_stress
    return dftd3_module(data)
|
|
705
|
+
|
|
706
|
+
def prepare_input(self, data: dict[str, Any]) -> dict[str, Tensor]:
    """Convert raw input into the padded, flattened form used for the model call.

    Steps: ``to_input_tensors`` -> ``mol_flatten`` -> (for 2-D ``coord``)
    ``make_nbmat`` unless ``"nbmat"`` is already present, followed by
    ``pad_input``. When a cell is present and the Coulomb method is
    ``"simple"``, a warning is issued and the method is switched to
    ``"dsf"`` via ``set_lrcoulomb_method``.
    """
    data = self.mol_flatten(self.to_input_tensors(data))

    has_cell = data.get("cell") is not None
    if has_cell and self._coulomb_method == "simple":
        warnings.warn("Switching to DSF Coulomb for PBC", stacklevel=1)
        self.set_lrcoulomb_method("dsf")

    if data["coord"].ndim != 2:
        return data

    # A neighbor matrix supplied by the caller is used as is.
    if "nbmat" not in data:
        data = self.make_nbmat(data)
    return self.pad_input(data)
|
|
106
718
|
|
|
107
|
-
def process_output(self, data:
|
|
719
|
+
def process_output(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
    """Undo input preparation on the model output and filter the keys.

    ``unpad_output`` runs only when ``data["coord"]`` is 2-D; then
    ``mol_unflatten`` and ``keep_only`` are applied unconditionally.
    """
    coord_is_flat = data["coord"].ndim == 2
    if coord_is_flat:
        data = self.unpad_output(data)
    return self.keep_only(self.mol_unflatten(data))
|
|
113
725
|
|
|
114
|
-
def to_input_tensors(self, data:
|
|
726
|
+
def to_input_tensors(self, data: dict[str, Any]) -> dict[str, Tensor]:
    """Build a fresh dict of detached tensors from the recognised input keys.

    Every key of ``self.keys_in`` must be present in ``data``; keys of
    ``self.keys_in_optional`` are taken only when present and not ``None``.
    Values are converted with ``torch.as_tensor`` using ``self.device`` and
    the dtype mapped to the key, then detached. Zero-dimensional tensors
    are promoted to shape ``(1,)``. Keys outside both mappings are dropped.

    Raises
    ------
    KeyError
        If a required key is missing from ``data``.
    """
    converted = {}

    for key, dtype in self.keys_in.items():
        if key not in data:
            raise KeyError(f"Missing key {key} in the input data")
        # Detach so the inputs carry no autograd history from the caller.
        converted[key] = torch.as_tensor(data[key], device=self.device, dtype=dtype).detach()

    for key, dtype in self.keys_in_optional.items():
        raw = data.get(key)
        if raw is not None:
            converted[key] = torch.as_tensor(raw, device=self.device, dtype=dtype).detach()

    # Scalars become 1-element tensors so every entry has a leading dimension.
    return {key: (value.unsqueeze(0) if value.ndim == 0 else value) for key, value in converted.items()}
|
|
129
741
|
|
|
130
|
-
def mol_flatten(self, data:
|
|
742
|
+
def mol_flatten(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
|
|
131
743
|
"""Flatten the input data for multiple molecules.
|
|
132
744
|
Will not flatten for batched input and molecule size below threshold.
|
|
133
745
|
"""
|
|
134
746
|
ndim = data["coord"].ndim
|
|
135
747
|
if ndim == 2:
|
|
136
|
-
# single molecule or already flattened
|
|
137
748
|
self._batch = None
|
|
138
749
|
if "mol_idx" not in data:
|
|
139
750
|
data["mol_idx"] = torch.zeros(data["coord"].shape[0], dtype=torch.long, device=self.device)
|
|
@@ -144,9 +755,9 @@ class AIMNet2Calculator:
|
|
|
144
755
|
self._max_mol_size = data["mol_idx"].unique(return_counts=True)[1].max().item()
|
|
145
756
|
|
|
146
757
|
elif ndim == 3:
|
|
147
|
-
# batched input
|
|
148
758
|
B, N = data["coord"].shape[:2]
|
|
149
|
-
|
|
759
|
+
# Force flattening for PBC (cell present) to ensure make_nbmat computes proper neighbor lists with shifts
|
|
760
|
+
if self.nb_threshold < N or self.device == "cpu" or data.get("cell") is not None:
|
|
150
761
|
self._batch = B
|
|
151
762
|
data["mol_idx"] = torch.repeat_interleave(
|
|
152
763
|
torch.arange(0, B, device=self.device), torch.full((B,), N, device=self.device)
|
|
@@ -159,7 +770,7 @@ class AIMNet2Calculator:
|
|
|
159
770
|
self._max_mol_size = N
|
|
160
771
|
return data
|
|
161
772
|
|
|
162
|
-
def mol_unflatten(self, data:
|
|
773
|
+
def mol_unflatten(self, data: dict[str, Tensor], batch=None) -> dict[str, Tensor]:
|
|
163
774
|
batch = batch if batch is not None else self._batch
|
|
164
775
|
if batch is not None:
|
|
165
776
|
for k, v in data.items():
|
|
@@ -167,47 +778,121 @@ class AIMNet2Calculator:
|
|
|
167
778
|
data[k] = v.view(batch, -1, *v.shape[1:])
|
|
168
779
|
return data
|
|
169
780
|
|
|
170
|
-
def make_nbmat(self, data:
|
|
781
|
+
def make_nbmat(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
    """Compute neighbor matrices (and PBC shifts) and store them in ``data``.

    Always builds the short-range list with ``self._nblist`` and stores it
    as ``"nbmat"`` (plus ``"shifts"`` when a cell is present). Long-range
    lists are then added depending on which list builders are attached:

    * ``self._nblist_lr`` set: one list is stored under ``"nbmat_lr"``,
      ``"nbmat_coulomb"`` and ``"nbmat_dftd3"``.
    * otherwise ``self._nblist_coulomb`` set: stored under
      ``"nbmat_coulomb"`` and ``"nbmat_lr"``; if ``self._nblist_dftd3`` is
      also set, its list is stored under ``"nbmat_dftd3"`` only.
    * otherwise ``self._nblist_dftd3`` set: stored under ``"nbmat_dftd3"``
      and ``"nbmat_lr"``.

    The matching ``"shifts_*"`` keys are written only when a cell is
    present and the list builder returned shifts. When a cell is present
    the coordinates are first wrapped with ``move_coord_to_cell``. Every
    neighbor matrix gets one extra padding row via ``_add_padding_row``.
    """
    assert self._max_mol_size > 0, "Molecule size is not set"

    mol_idx = data.get("mol_idx")

    cell = data.get("cell")
    if cell is not None:
        data["coord"] = move_coord_to_cell(data["coord"], cell, mol_idx)

    N = data["coord"].shape[0]
    batch_idx = None if mol_idx is None else mol_idx.to(torch.int32)

    # Cell / pbc arguments in the batched layout passed to the list builders.
    if cell is None:
        cell_batched = None
        pbc = None
    else:
        cell_batched = cell.unsqueeze(0) if cell.ndim == 2 else cell
        num_systems = cell_batched.shape[0]
        pbc = torch.tensor([[True, True, True]] * num_systems, dtype=torch.bool, device=cell.device)

    def build(nblist, unbounded=False):
        # Run one list builder and append the padding row to its output.
        if unbounded:
            nblist.max_neighbors = N
        nbmat, _, shifts = nblist(
            positions=data["coord"],
            cell=cell_batched,
            pbc=pbc,
            batch_idx=batch_idx,
            fill_value=N,
        )
        return _add_padding_row(nbmat, shifts, N)

    def publish(nbmat, shifts, nbmat_keys, shift_keys):
        # Store one neighbor list (and its shifts, for PBC) under several keys.
        for key in nbmat_keys:
            data[key] = nbmat
        if cell is not None and shifts is not None:
            for key in shift_keys:
                data[key] = shifts

    # Short-range neighbors (always)
    nbmat1, shifts1 = build(self._nblist)
    data["nbmat"] = nbmat1
    if cell is not None:
        assert shifts1 is not None
        data["shifts"] = shifts1

    # A single LR list serves every LR consumer when it is configured.
    if self._nblist_lr is not None:
        nbmat_lr, shifts_lr = build(self._nblist_lr, self._coulomb_cutoff == float("inf"))
        publish(
            nbmat_lr,
            shifts_lr,
            ("nbmat_lr", "nbmat_coulomb", "nbmat_dftd3"),
            ("shifts_lr", "shifts_coulomb", "shifts_dftd3"),
        )
        return data

    has_coulomb = self._nblist_coulomb is not None
    if has_coulomb:
        nbmat_coulomb, shifts_coulomb = build(self._nblist_coulomb, self._coulomb_cutoff == float("inf"))
        # "nbmat_lr" mirrors the Coulomb list, for code reading the unified key.
        publish(nbmat_coulomb, shifts_coulomb, ("nbmat_coulomb", "nbmat_lr"), ("shifts_coulomb", "shifts_lr"))

    if self._nblist_dftd3 is not None:
        nbmat_dftd3, shifts_dftd3 = build(self._nblist_dftd3)
        if has_coulomb:
            publish(nbmat_dftd3, shifts_dftd3, ("nbmat_dftd3",), ("shifts_dftd3",))
        else:
            # DFTD3-only configuration: it also fills the unified "lr" keys.
            publish(nbmat_dftd3, shifts_dftd3, ("nbmat_dftd3", "nbmat_lr"), ("shifts_dftd3", "shifts_lr"))

    return data
|
|
209
894
|
|
|
210
|
-
def pad_input(self, data:
|
|
895
|
+
def pad_input(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
|
|
211
896
|
N = data["nbmat"].shape[0]
|
|
212
897
|
data["mol_idx"] = maybe_pad_dim0(data["mol_idx"], N, value=data["mol_idx"][-1].item())
|
|
213
898
|
for k in ("coord", "numbers"):
|
|
@@ -215,36 +900,49 @@ class AIMNet2Calculator:
|
|
|
215
900
|
data[k] = maybe_pad_dim0(data[k], N)
|
|
216
901
|
return data
|
|
217
902
|
|
|
218
|
-
def unpad_output(self, data:
|
|
903
|
+
def unpad_output(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
    """Strip the padding entry from per-atom outputs, in place.

    The target length is ``data["nbmat"].shape[0] - 1``. Only keys listed
    in ``self.atom_feature_keys`` are passed through ``maybe_unpad_dim0``;
    all other entries are left untouched. Returns the same dict.
    """
    n_unpadded = data["nbmat"].shape[0] - 1
    per_atom_keys = [key for key in data if key in self.atom_feature_keys]
    for key in per_atom_keys:
        data[key] = maybe_unpad_dim0(data[key], n_unpadded)
    return data
|
|
224
909
|
|
|
225
|
-
def set_grad_tensors(self, data:
|
|
910
|
+
def set_grad_tensors(self, data: dict[str, Tensor], forces=False, stress=False, hessian=False) -> dict[str, Tensor]:
    """Mark the tensors that derivatives will be taken with respect to.

    ``self._saved_for_grad`` is reset on every call. For ``forces`` or
    ``hessian`` the coordinate tensor gets ``requires_grad`` and is stored
    under ``"coord"``. For ``stress`` an identity "scaling" matrix with
    ``requires_grad`` is multiplied into both ``coord`` and ``cell`` and
    stored under ``"scaling"``; it has shape ``(3, 3)`` for a 2-D cell and
    ``(B, 3, 3)`` otherwise, where each atom uses the matrix selected by
    its ``mol_idx`` entry.

    Raises
    ------
    AssertionError
        If ``stress`` is requested without a cell in ``data``.
    """
    saved = {}
    self._saved_for_grad = saved

    if forces or hessian:
        # Save the leaf tensor before any strain scaling is applied to it.
        data["coord"].requires_grad_(True)
        saved["coord"] = data["coord"]

    if not stress:
        return data

    assert "cell" in data and data["cell"] is not None, "Stress calculation requires cell"
    cell = data["cell"]
    if cell.ndim == 2:
        # One system: a single (3, 3) identity multiplies coordinates and cell.
        scaling = torch.eye(3, requires_grad=True, dtype=cell.dtype, device=cell.device)
        data["coord"] = data["coord"] @ scaling
    else:
        # Several systems: one identity per cell, looked up per atom via mol_idx.
        n_systems = cell.shape[0]
        identity = torch.eye(3, dtype=cell.dtype, device=cell.device)
        scaling = identity.unsqueeze(0).expand(n_systems, -1, -1)
        scaling.requires_grad_(True)
        per_atom_scaling = torch.index_select(scaling, 0, data["mol_idx"])
        row_vectors = data["coord"].unsqueeze(1)
        data["coord"] = (row_vectors @ per_atom_scaling).squeeze(1)
    data["cell"] = cell @ scaling
    saved["scaling"] = scaling
    return data
|
|
237
935
|
|
|
238
|
-
def keep_only(self, data:
|
|
936
|
+
def keep_only(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
    """Return a new dict restricted to the requested output keys.

    A key is kept when it is in ``self.keys_out``, or when it ends with
    ``"_std"`` and the name without that suffix is in ``self.keys_out``.
    """
    requested = self.keys_out

    def is_requested(key):
        if key in requested:
            return True
        return key.endswith("_std") and key[:-4] in requested

    return {key: value for key, value in data.items() if is_requested(key)}
|
|
244
942
|
|
|
245
|
-
def get_derivatives(self, data:
|
|
246
|
-
|
|
247
|
-
_create_graph = hessian or
|
|
943
|
+
def get_derivatives(self, data: dict[str, Tensor], forces=False, stress=False, hessian=False) -> dict[str, Tensor]:
|
|
944
|
+
# Use stored train mode for create_graph decision
|
|
945
|
+
_create_graph = hessian or self._train
|
|
248
946
|
x = []
|
|
249
947
|
if hessian:
|
|
250
948
|
forces = True
|
|
@@ -260,21 +958,62 @@ class AIMNet2Calculator:
|
|
|
260
958
|
data["forces"] = -deriv[0]
|
|
261
959
|
if stress:
|
|
262
960
|
dedc = deriv[0] if not forces else deriv[1]
|
|
263
|
-
|
|
961
|
+
cell = data["cell"].detach()
|
|
962
|
+
if cell.ndim == 2:
|
|
963
|
+
# Single cell (3, 3)
|
|
964
|
+
volume = cell.det().abs()
|
|
965
|
+
else:
|
|
966
|
+
# Batched cells (B, 3, 3) - compute volume for each cell
|
|
967
|
+
volume = torch.linalg.det(cell).abs().unsqueeze(-1).unsqueeze(-1) # (B, 1, 1)
|
|
968
|
+
data["stress"] = dedc / volume
|
|
264
969
|
if hessian:
|
|
265
970
|
data["hessian"] = self.calculate_hessian(data["forces"], self._saved_for_grad["coord"])
|
|
266
971
|
return data
|
|
267
972
|
|
|
268
973
|
@staticmethod
|
|
269
974
|
def calculate_hessian(forces: Tensor, coord: Tensor) -> Tensor:
    """Differentiate every force component with respect to ``coord``.

    One ``torch.autograd.grad`` call is made per scalar in ``forces``
    (the graph is retained between calls). The stacked gradients are
    viewed as ``(-1, 3, coord.shape[0], 3)``, the last entry along the
    first and third axes is sliced off, and the result is negated.
    NOTE(review): the dropped entry is presumably the padding atom —
    confirm against the caller.
    """
    gradient_rows = []
    for component in forces.flatten().unbind():
        (row,) = torch.autograd.grad(component, coord, retain_graph=True)
        gradient_rows.append(row)
    stacked = torch.stack(gradient_rows).view(-1, 3, coord.shape[0], 3)
    hessian = -stacked[:-1, :, :-1, :]
    return hessian
|
|
276
981
|
|
|
277
982
|
|
|
983
|
+
def _add_padding_row(
    nbmat: Tensor,
    shifts: Tensor | None,
    N: int,
) -> tuple[Tensor, Tensor | None]:
    """Append one padding row to a neighbor matrix and, if given, its shifts.

    Parameters
    ----------
    nbmat : Tensor
        Neighbor matrix; the new row has ``nbmat.shape[1]`` entries, all
        equal to ``N``, in ``nbmat``'s dtype and device.
    shifts : Tensor | None
        Shift vectors or None. When given, a row of zeros with shape
        ``(1, nbmat.shape[1], 3)`` in ``shifts``' dtype is appended.
    N : int
        Fill value for the padding row of ``nbmat``.

    Returns
    -------
    tuple[Tensor, Tensor | None]
        The padded ``(nbmat, shifts)``; ``shifts`` stays None if it was None.
    """
    width = nbmat.shape[1]
    target_device = nbmat.device

    filler = torch.full((1, width), N, dtype=nbmat.dtype, device=target_device)
    padded_nbmat = torch.cat([nbmat, filler], dim=0)
    if shifts is None:
        return padded_nbmat, None

    zero_shifts = torch.zeros((1, width, 3), dtype=shifts.dtype, device=target_device)
    return padded_nbmat, torch.cat([shifts, zero_shifts], dim=0)
|
|
1015
|
+
|
|
1016
|
+
|
|
278
1017
|
def maybe_pad_dim0(a: Tensor, N: int, value=0.0) -> Tensor:
|
|
279
1018
|
_shape_diff = N - a.shape[0]
|
|
280
1019
|
assert _shape_diff == 0 or _shape_diff == 1, "Invalid shape"
|
|
@@ -297,24 +1036,45 @@ def maybe_unpad_dim0(a: Tensor, N: int) -> Tensor:
|
|
|
297
1036
|
return a
|
|
298
1037
|
|
|
299
1038
|
|
|
300
|
-
def move_coord_to_cell(coord, cell):
|
|
301
|
-
|
|
302
|
-
coord_f = coord_f % 1
|
|
303
|
-
return coord_f @ cell
|
|
304
|
-
|
|
305
|
-
|
|
306
|
-
def _named_children_rec(module):
|
|
307
|
-
if isinstance(module, torch.nn.Module):
|
|
308
|
-
for name, child in module.named_children():
|
|
309
|
-
yield name, child
|
|
310
|
-
yield from _named_children_rec(child)
|
|
1039
|
+
def _wrap_fractional(frac: Tensor) -> Tensor:
    """Wrap fractional coordinates into the half-open interval [0, 1)."""
    frac = frac % 1
    # A tiny negative input (e.g. -1e-12 in float32) rounds to exactly 1.0
    # after the modulo. 1.0 and 0.0 are the same periodic position, so fold
    # it back; subtraction keeps the result differentiable.
    return frac - (frac >= 1).to(frac.dtype)


def move_coord_to_cell(coord: Tensor, cell: Tensor, mol_idx: Tensor | None = None) -> Tensor:
    """Move coordinates into the periodic cell.

    Coordinates are converted to fractional coordinates with the inverse
    cell, wrapped into [0, 1), and converted back.

    Parameters
    ----------
    coord : Tensor
        Coordinates tensor, shape (N, 3) or (B, N, 3).
    cell : Tensor
        Cell tensor, shape (3, 3) or (B, 3, 3).
    mol_idx : Tensor | None
        Molecule index for each atom, shape (N,).
        Required for batched cells with flat coordinates, where it selects
        the cell of each atom.

    Returns
    -------
    Tensor
        Coordinates wrapped into the cell, same shape as ``coord``.

    Raises
    ------
    AssertionError
        If ``cell`` is batched, ``coord`` is flat and ``mol_idx`` is None.
    """
    flat_coord_batched_cell = cell.ndim != 2 and coord.ndim != 3
    if flat_coord_batched_cell:
        assert mol_idx is not None, "mol_idx required for batched cells with flat coordinates"

    cell_inv = torch.linalg.inv(cell)

    if cell.ndim == 2:
        # Single cell (3, 3): plain matmul broadcasts over any leading dims.
        return _wrap_fractional(coord @ cell_inv) @ cell

    if coord.ndim == 3:
        # Batched coords (B, N, 3) with batched cells (B, 3, 3).
        return torch.bmm(_wrap_fractional(torch.bmm(coord, cell_inv)), cell)

    # Flat coords (N_total, 3) with batched cells (B, 3, 3): pick each atom's cell.
    atom_cell = cell[mol_idx]  # (N_total, 3, 3)
    atom_cell_inv = cell_inv[mol_idx]  # (N_total, 3, 3)
    coord_f = torch.bmm(coord.unsqueeze(1), atom_cell_inv).squeeze(1)  # (N_total, 3)
    return torch.bmm(_wrap_fractional(coord_f).unsqueeze(1), atom_cell).squeeze(1)
|