aimnet 0.0.1__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aimnet/__init__.py +7 -0
- aimnet/base.py +24 -8
- aimnet/calculators/__init__.py +4 -4
- aimnet/calculators/aimnet2ase.py +19 -6
- aimnet/calculators/calculator.py +868 -108
- aimnet/calculators/model_registry.py +2 -5
- aimnet/calculators/model_registry.yaml +55 -17
- aimnet/cli.py +62 -6
- aimnet/config.py +8 -9
- aimnet/data/sgdataset.py +23 -22
- aimnet/kernels/__init__.py +66 -0
- aimnet/kernels/conv_sv_2d_sp_wp.py +478 -0
- aimnet/models/__init__.py +13 -1
- aimnet/models/aimnet2.py +19 -22
- aimnet/models/base.py +183 -15
- aimnet/models/convert.py +30 -0
- aimnet/models/utils.py +735 -0
- aimnet/modules/__init__.py +1 -1
- aimnet/modules/aev.py +49 -48
- aimnet/modules/core.py +14 -13
- aimnet/modules/lr.py +520 -115
- aimnet/modules/ops.py +537 -0
- aimnet/nbops.py +105 -15
- aimnet/ops.py +90 -18
- aimnet/train/export_model.py +226 -0
- aimnet/train/loss.py +7 -7
- aimnet/train/metrics.py +5 -6
- aimnet/train/train.py +4 -1
- aimnet/train/utils.py +42 -13
- aimnet-0.1.0.dist-info/METADATA +308 -0
- aimnet-0.1.0.dist-info/RECORD +43 -0
- {aimnet-0.0.1.dist-info → aimnet-0.1.0.dist-info}/WHEEL +1 -1
- aimnet-0.1.0.dist-info/entry_points.txt +3 -0
- aimnet/calculators/nb_kernel_cpu.py +0 -222
- aimnet/calculators/nb_kernel_cuda.py +0 -217
- aimnet/calculators/nbmat.py +0 -220
- aimnet/train/pt2jpt.py +0 -81
- aimnet-0.0.1.dist-info/METADATA +0 -78
- aimnet-0.0.1.dist-info/RECORD +0 -41
- aimnet-0.0.1.dist-info/entry_points.txt +0 -5
- {aimnet-0.0.1.dist-info → aimnet-0.1.0.dist-info/licenses}/LICENSE +0 -0
aimnet/modules/ops.py
ADDED
|
@@ -0,0 +1,537 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: MIT
|
|
3
|
+
#
|
|
4
|
+
# Permission is hereby granted, free of charge, to any person obtaining a
|
|
5
|
+
# copy of this software and associated documentation files (the "Software"),
|
|
6
|
+
# to deal in the Software without restriction, including without limitation
|
|
7
|
+
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
|
8
|
+
# and/or sell copies of the Software, and to permit persons to whom the
|
|
9
|
+
# Software is furnished to do so, subject to the following conditions:
|
|
10
|
+
#
|
|
11
|
+
# The above copyright notice and this permission notice shall be included in
|
|
12
|
+
# all copies or substantial portions of the Software.
|
|
13
|
+
#
|
|
14
|
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
15
|
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
16
|
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
|
|
17
|
+
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
18
|
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
|
19
|
+
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
|
20
|
+
# DEALINGS IN THE SOFTWARE.
|
|
21
|
+
"""
|
|
22
|
+
DFT-D3 Custom Op for PyTorch.
|
|
23
|
+
|
|
24
|
+
This module provides a TorchScript-compatible custom op for DFT-D3 dispersion
|
|
25
|
+
energy computation using nvalchemiops GPU-accelerated kernels.
|
|
26
|
+
|
|
27
|
+
TorchScript Compatibility
|
|
28
|
+
-------------------------
|
|
29
|
+
- torch.jit.script(): SUPPORTED - Models using this op can be scripted
|
|
30
|
+
- torch.jit.save(): SUPPORTED - Uses torch.autograd.Function pattern for serialization
|
|
31
|
+
|
|
32
|
+
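The nvalchemiops call is wrapped in a torch.autograd.Function and exposed as the
``aimnet::dftd3_fwd`` custom op (torch.library), with its backward registered on
the op itself; the public ``dftd3_energy`` wrapper calls it through
``torch.ops.aimnet.dftd3_fwd``, which is what allows it to be used from TorchScript.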
|
|
34
|
+
|
|
35
|
+
Stress Sign Convention
|
|
36
|
+
----------------------
|
|
37
|
+
This module follows the Cauchy (physical) stress convention:
|
|
38
|
+
|
|
39
|
+
- Positive stress = tensile (material being stretched)
|
|
40
|
+
- Negative stress = compressive (material being compressed)
|
|
41
|
+
|
|
42
|
+
The relationship between virial and Cauchy stress is:
|
|
43
|
+
|
|
44
|
+
stress = -virial / volume
|
|
45
|
+
|
|
46
|
+
where virial = -0.5 * sum_ij(F_ij outer r_ij).
|
|
47
|
+
|
|
48
|
+
The cell gradient for autograd is computed as:
|
|
49
|
+
|
|
50
|
+
dE/dcell = -virial @ inv(cell).T
|
|
51
|
+
|
|
52
|
+
This ensures the calculator returns physical Cauchy stress compatible with
|
|
53
|
+
ASE and standard MD conventions.
|
|
54
|
+
"""
|
|
55
|
+
|
|
56
|
+
from typing import Any
|
|
57
|
+
|
|
58
|
+
import torch
|
|
59
|
+
from nvalchemiops.interactions.dispersion.dftd3 import dftd3
|
|
60
|
+
from torch import Tensor
|
|
61
|
+
from torch.autograd import Function
|
|
62
|
+
|
|
63
|
+
from aimnet import constants
|
|
64
|
+
|
|
65
|
+
# =============================================================================
|
|
66
|
+
# Autograd Function Classes
|
|
67
|
+
# =============================================================================
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
class _DFTD3Function(Function):
|
|
71
|
+
"""Autograd Function for DFT-D3 dispersion energy computation.
|
|
72
|
+
|
|
73
|
+
This class wraps the nvalchemiops dftd3 implementation with proper
|
|
74
|
+
autograd support, enabling both gradient computation and TorchScript
|
|
75
|
+
serialization.
|
|
76
|
+
|
|
77
|
+
Notes
|
|
78
|
+
-----
|
|
79
|
+
Input coordinates are in Angstroms, internally converted to Bohr.
|
|
80
|
+
Output energies are in eV, forces in eV/Angstrom.
|
|
81
|
+
"""
|
|
82
|
+
|
|
83
|
+
@staticmethod
|
|
84
|
+
def forward(
|
|
85
|
+
ctx: Any,
|
|
86
|
+
coord: Tensor,
|
|
87
|
+
cell: Tensor,
|
|
88
|
+
numbers: Tensor,
|
|
89
|
+
batch_idx: Tensor,
|
|
90
|
+
neighbor_matrix: Tensor,
|
|
91
|
+
shifts: Tensor,
|
|
92
|
+
rcov: Tensor,
|
|
93
|
+
r4r2: Tensor,
|
|
94
|
+
c6ab: Tensor,
|
|
95
|
+
cn_ref: Tensor,
|
|
96
|
+
a1: float,
|
|
97
|
+
a2: float,
|
|
98
|
+
s6: float,
|
|
99
|
+
s8: float,
|
|
100
|
+
num_systems: int,
|
|
101
|
+
fill_value: int,
|
|
102
|
+
smoothing_on: float,
|
|
103
|
+
smoothing_off: float,
|
|
104
|
+
compute_virial: bool,
|
|
105
|
+
has_cell: bool,
|
|
106
|
+
has_shifts: bool,
|
|
107
|
+
) -> tuple[Tensor, Tensor, Tensor]:
|
|
108
|
+
"""Forward pass: compute DFT-D3 dispersion energy."""
|
|
109
|
+
# Convert coordinates to Bohr
|
|
110
|
+
coord_bohr = coord * constants.Bohr_inv
|
|
111
|
+
|
|
112
|
+
# Convert cell to Bohr if present
|
|
113
|
+
cell_bohr = None
|
|
114
|
+
if has_cell:
|
|
115
|
+
cell_bohr = cell * constants.Bohr_inv
|
|
116
|
+
if cell_bohr.ndim == 2:
|
|
117
|
+
cell_bohr = cell_bohr.unsqueeze(0)
|
|
118
|
+
|
|
119
|
+
# Handle shifts
|
|
120
|
+
shifts_arg = None
|
|
121
|
+
if has_shifts:
|
|
122
|
+
shifts_arg = shifts
|
|
123
|
+
|
|
124
|
+
# Build kwargs for nvalchemiops dftd3 call
|
|
125
|
+
dftd3_kwargs: dict[str, Any] = {
|
|
126
|
+
"positions": coord_bohr,
|
|
127
|
+
"numbers": numbers,
|
|
128
|
+
"a1": a1,
|
|
129
|
+
"a2": a2,
|
|
130
|
+
"s8": s8,
|
|
131
|
+
"s6": s6,
|
|
132
|
+
"covalent_radii": rcov,
|
|
133
|
+
"r4r2": r4r2,
|
|
134
|
+
"c6_reference": c6ab,
|
|
135
|
+
"coord_num_ref": cn_ref,
|
|
136
|
+
"batch_idx": batch_idx,
|
|
137
|
+
"cell": cell_bohr,
|
|
138
|
+
"neighbor_matrix": neighbor_matrix,
|
|
139
|
+
"neighbor_matrix_shifts": shifts_arg,
|
|
140
|
+
"fill_value": fill_value,
|
|
141
|
+
"num_systems": num_systems,
|
|
142
|
+
"compute_virial": compute_virial,
|
|
143
|
+
"device": str(coord.device),
|
|
144
|
+
}
|
|
145
|
+
|
|
146
|
+
# Only pass smoothing parameters if smoothing is enabled
|
|
147
|
+
# When smoothing_on >= smoothing_off, omit to use nvalchemiops defaults (1e10)
|
|
148
|
+
if smoothing_on < smoothing_off:
|
|
149
|
+
dftd3_kwargs["s5_smoothing_on"] = smoothing_on * constants.Bohr_inv
|
|
150
|
+
dftd3_kwargs["s5_smoothing_off"] = smoothing_off * constants.Bohr_inv
|
|
151
|
+
|
|
152
|
+
# Call nvalchemiops dftd3
|
|
153
|
+
result = dftd3(**dftd3_kwargs)
|
|
154
|
+
|
|
155
|
+
if compute_virial:
|
|
156
|
+
energy, forces, _coord_num, virial = result
|
|
157
|
+
else:
|
|
158
|
+
energy, forces, _coord_num = result
|
|
159
|
+
virial = torch.empty(0, device=coord.device)
|
|
160
|
+
|
|
161
|
+
# Convert to eV/Angstrom units
|
|
162
|
+
energy_ev = energy * constants.Hartree
|
|
163
|
+
forces_ev_ang = forces * constants.Hartree * constants.Bohr_inv
|
|
164
|
+
|
|
165
|
+
# Save tensors for backward - convert cell_bohr for gradient computation
|
|
166
|
+
cell_bohr_saved = torch.empty(0, device=coord.device)
|
|
167
|
+
if has_cell:
|
|
168
|
+
cell_bohr_saved = cell * constants.Bohr_inv
|
|
169
|
+
if cell_bohr_saved.ndim == 2:
|
|
170
|
+
cell_bohr_saved = cell_bohr_saved.unsqueeze(0)
|
|
171
|
+
|
|
172
|
+
ctx.save_for_backward(forces_ev_ang, virial, batch_idx, cell_bohr_saved)
|
|
173
|
+
ctx.has_cell = has_cell
|
|
174
|
+
ctx.compute_virial = compute_virial
|
|
175
|
+
|
|
176
|
+
return energy_ev, forces_ev_ang, virial
|
|
177
|
+
|
|
178
|
+
@staticmethod
|
|
179
|
+
def backward(
|
|
180
|
+
ctx: Any,
|
|
181
|
+
grad_energy: Tensor,
|
|
182
|
+
grad_forces: Tensor,
|
|
183
|
+
grad_virial: Tensor,
|
|
184
|
+
) -> tuple[Tensor | None, ...]:
|
|
185
|
+
"""Backward pass: compute gradients w.r.t. coord and cell."""
|
|
186
|
+
forces, virial, batch_idx, cell_bohr = ctx.saved_tensors
|
|
187
|
+
has_cell = ctx.has_cell
|
|
188
|
+
compute_virial = ctx.compute_virial
|
|
189
|
+
|
|
190
|
+
# Coord gradient: forces = -dE/dR, so dE/dR = -forces
|
|
191
|
+
grad_coord = -forces * grad_energy[batch_idx].unsqueeze(-1)
|
|
192
|
+
|
|
193
|
+
# Cell gradient for physical Cauchy stress
|
|
194
|
+
# Cauchy stress = -virial / volume, so dE/dcell = -virial @ inv(cell).T
|
|
195
|
+
grad_cell = None
|
|
196
|
+
if has_cell and compute_virial and virial.numel() > 0:
|
|
197
|
+
cell_inv_t = torch.linalg.inv(cell_bohr).transpose(-1, -2)
|
|
198
|
+
dE_dcell_bohr = -virial @ cell_inv_t # Negative sign for Cauchy stress
|
|
199
|
+
dE_dcell_ang = dE_dcell_bohr * constants.Hartree * constants.Bohr_inv
|
|
200
|
+
grad_cell = dE_dcell_ang * grad_energy.view(-1, 1, 1)
|
|
201
|
+
|
|
202
|
+
# Return gradients for all 21 inputs (only coord and cell have gradients)
|
|
203
|
+
return (
|
|
204
|
+
grad_coord,
|
|
205
|
+
grad_cell,
|
|
206
|
+
None, # numbers
|
|
207
|
+
None, # batch_idx
|
|
208
|
+
None, # neighbor_matrix
|
|
209
|
+
None, # shifts
|
|
210
|
+
None, # rcov
|
|
211
|
+
None, # r4r2
|
|
212
|
+
None, # c6ab
|
|
213
|
+
None, # cn_ref
|
|
214
|
+
None, # a1
|
|
215
|
+
None, # a2
|
|
216
|
+
None, # s6
|
|
217
|
+
None, # s8
|
|
218
|
+
None, # num_systems
|
|
219
|
+
None, # fill_value
|
|
220
|
+
None, # smoothing_on
|
|
221
|
+
None, # smoothing_off
|
|
222
|
+
None, # compute_virial
|
|
223
|
+
None, # has_cell
|
|
224
|
+
None, # has_shifts
|
|
225
|
+
)
|
|
226
|
+
|
|
227
|
+
|
|
228
|
+
# =============================================================================
|
|
229
|
+
# PyTorch Custom Op Registration
|
|
230
|
+
# =============================================================================
|
|
231
|
+
|
|
232
|
+
|
|
233
|
+
@torch.library.custom_op("aimnet::dftd3_fwd", mutates_args=())
def dftd3_fwd(
    coord: Tensor,
    cell: Tensor,
    numbers: Tensor,
    batch_idx: Tensor,
    neighbor_matrix: Tensor,
    shifts: Tensor,
    rcov: Tensor,
    r4r2: Tensor,
    c6ab: Tensor,
    cn_ref: Tensor,
    a1: float,
    a2: float,
    s6: float,
    s8: float,
    num_systems: int,
    fill_value: int,
    smoothing_on: float,
    smoothing_off: float,
    compute_virial: bool,
    has_cell: bool,
    has_shifts: bool,
) -> list[Tensor]:
    """
    Forward primitive for DFT-D3 energy computation.

    Registered as the ``aimnet::dftd3_fwd`` custom op; the actual work is
    delegated to ``_DFTD3Function.apply`` with the arguments passed through
    unchanged and in the same order.

    Returns [energy, forces, virial] tensors.
    """
    energy, forces, virial = _DFTD3Function.apply(
        coord, cell, numbers, batch_idx, neighbor_matrix, shifts,
        rcov, r4r2, c6ab, cn_ref,
        a1, a2, s6, s8,
        num_systems, fill_value, smoothing_on, smoothing_off,
        compute_virial, has_cell, has_shifts,
    )
    return [energy, forces, virial]
|
|
286
|
+
|
|
287
|
+
|
|
288
|
+
@torch.library.register_fake("aimnet::dftd3_fwd")
def _(
    coord: Tensor,
    cell: Tensor,
    numbers: Tensor,
    batch_idx: Tensor,
    neighbor_matrix: Tensor,
    shifts: Tensor,
    rcov: Tensor,
    r4r2: Tensor,
    c6ab: Tensor,
    cn_ref: Tensor,
    a1: float,
    a2: float,
    s6: float,
    s8: float,
    num_systems: int,
    fill_value: int,
    smoothing_on: float,
    smoothing_off: float,
    compute_virial: bool,
    has_cell: bool,
    has_shifts: bool,
) -> list[Tensor]:
    """Fake implementation for torch.compile tracing.

    Allocates uninitialized outputs with the shapes of the real op, all with
    ``coord``'s dtype and device: energy ``(num_systems,)``, forces
    ``(n_atoms, 3)`` and virial ``(num_systems, 3, 3)`` -- or an empty tensor
    when ``compute_virial`` is False.
    """
    virial_shape = (num_systems, 3, 3) if compute_virial else (0,)
    return [
        coord.new_empty(num_systems),
        coord.new_empty(coord.shape[0], 3),
        coord.new_empty(virial_shape),
    ]
|
|
321
|
+
|
|
322
|
+
|
|
323
|
+
# =============================================================================
|
|
324
|
+
# Autograd Registration (Method-based)
|
|
325
|
+
# =============================================================================
|
|
326
|
+
|
|
327
|
+
|
|
328
|
+
def _dftd3_setup_context(ctx: Any, inputs: tuple[Any, ...], output: list[Tensor]) -> None:
    """Setup context for backward pass.

    Stores on ``ctx`` what ``_dftd3_backward`` reads: the forces and virial
    outputs, ``batch_idx``, the cell converted to Bohr (with a leading batch
    dimension for a single (3, 3) cell, or an empty tensor when there is no
    cell), plus the ``has_cell`` and ``compute_virial`` flags.
    """
    # ``inputs`` follows the op schema (21 entries); only a few are needed.
    coord, cell, _numbers, batch_idx = inputs[:4]
    compute_virial, has_cell, _has_shifts = inputs[18:]
    _, forces, virial = output

    if has_cell:
        cell_bohr = cell * constants.Bohr_inv
        if cell_bohr.ndim == 2:
            cell_bohr = cell_bohr.unsqueeze(0)
    else:
        # save_for_backward requires a tensor; this placeholder is never read.
        cell_bohr = torch.empty(0, device=coord.device)

    ctx.save_for_backward(forces, virial, batch_idx, cell_bohr)
    ctx.has_cell = has_cell
    ctx.compute_virial = compute_virial
|
|
365
|
+
|
|
366
|
+
|
|
367
|
+
def _dftd3_backward(
|
|
368
|
+
ctx: Any,
|
|
369
|
+
grad_outputs: list[Tensor],
|
|
370
|
+
) -> tuple[Tensor | None, ...]:
|
|
371
|
+
"""Backward pass for dftd3 energy."""
|
|
372
|
+
grad_energy = grad_outputs[0]
|
|
373
|
+
# grad_outputs[1] and grad_outputs[2] are grad_forces and grad_virial (unused)
|
|
374
|
+
forces, virial, batch_idx, cell_bohr = ctx.saved_tensors
|
|
375
|
+
has_cell = ctx.has_cell
|
|
376
|
+
compute_virial = ctx.compute_virial
|
|
377
|
+
|
|
378
|
+
# Coord gradient: forces = -dE/dR, so dE/dR = -forces
|
|
379
|
+
grad_coord = -forces * grad_energy[batch_idx].unsqueeze(-1)
|
|
380
|
+
|
|
381
|
+
# Cell gradient for physical Cauchy stress
|
|
382
|
+
# Cauchy stress = -virial / volume, so dE/dcell = -virial @ inv(cell).T
|
|
383
|
+
grad_cell = None
|
|
384
|
+
if has_cell and compute_virial and virial.numel() > 0:
|
|
385
|
+
cell_inv_t = torch.linalg.inv(cell_bohr).transpose(-1, -2)
|
|
386
|
+
dE_dcell_bohr = -virial @ cell_inv_t # Negative sign for Cauchy stress
|
|
387
|
+
dE_dcell_ang = dE_dcell_bohr * constants.Hartree * constants.Bohr_inv
|
|
388
|
+
grad_cell = dE_dcell_ang * grad_energy.view(-1, 1, 1)
|
|
389
|
+
|
|
390
|
+
# Return gradients for all 21 inputs (only coord and cell have gradients)
|
|
391
|
+
return (
|
|
392
|
+
grad_coord,
|
|
393
|
+
grad_cell,
|
|
394
|
+
None, # numbers
|
|
395
|
+
None, # batch_idx
|
|
396
|
+
None, # neighbor_matrix
|
|
397
|
+
None, # shifts
|
|
398
|
+
None, # rcov
|
|
399
|
+
None, # r4r2
|
|
400
|
+
None, # c6ab
|
|
401
|
+
None, # cn_ref
|
|
402
|
+
None, # a1
|
|
403
|
+
None, # a2
|
|
404
|
+
None, # s6
|
|
405
|
+
None, # s8
|
|
406
|
+
None, # num_systems
|
|
407
|
+
None, # fill_value
|
|
408
|
+
None, # smoothing_on
|
|
409
|
+
None, # smoothing_off
|
|
410
|
+
None, # compute_virial
|
|
411
|
+
None, # has_cell
|
|
412
|
+
None, # has_shifts
|
|
413
|
+
)
|
|
414
|
+
|
|
415
|
+
|
|
416
|
+
# Use method-based autograd registration for serializability.
# `_dftd3_setup_context` saves forces, virial, batch_idx and the Bohr-converted
# cell from the forward call; `_dftd3_backward` turns the upstream energy
# gradient into gradients for `coord` and (when a virial was computed) `cell`.
dftd3_fwd.register_autograd(
    _dftd3_backward,
    setup_context=_dftd3_setup_context,
)
|
|
421
|
+
|
|
422
|
+
|
|
423
|
+
# =============================================================================
|
|
424
|
+
# Public API
|
|
425
|
+
# =============================================================================
|
|
426
|
+
|
|
427
|
+
|
|
428
|
+
def dftd3_energy(
    coord: Tensor,
    cell: Tensor | None,
    numbers: Tensor,
    batch_idx: Tensor,
    neighbor_matrix: Tensor,
    shifts: Tensor | None,
    rcov: Tensor,
    r4r2: Tensor,
    c6ab: Tensor,
    cn_ref: Tensor,
    a1: float,
    a2: float,
    s6: float,
    s8: float,
    num_systems: int,
    fill_value: int,
    smoothing_on: float,
    smoothing_off: float,
    compute_virial: bool = False,
) -> Tensor:
    """
    Compute DFT-D3 dispersion energy with automatic differentiation support.

    Calls the ``aimnet::dftd3_fwd`` custom op (which wraps the nvalchemiops
    DFT-D3 implementation) and returns only its energy output; gradients
    w.r.t. ``coord`` (and ``cell`` when ``compute_virial`` is True) are
    provided through the op's registered backward.

    Parameters
    ----------
    coord : Tensor
        Atomic coordinates in Angstrom, shape (N, 3)
    cell : Tensor or None
        Unit cell vectors in Angstrom, shape (3, 3) or (B, 3, 3)
    numbers : Tensor
        Atomic numbers, shape (N,)
    batch_idx : Tensor
        Batch index for each atom, shape (N,)
    neighbor_matrix : Tensor
        Neighbor indices, shape (N, M)
    shifts : Tensor or None
        Periodic shift vectors as integers, shape (N, M, 3)
    rcov : Tensor
        Covalent radii, shape (95,)
    r4r2 : Tensor
        R4/R2 expectation values, shape (95,)
    c6ab : Tensor
        C6 reference values, shape (95, 95, 5, 5)
    cn_ref : Tensor
        Coordination number references, shape (95, 95, 5, 5)
    a1 : float
        BJ damping parameter a1
    a2 : float
        BJ damping parameter a2
    s6 : float
        Scaling factor for C6 term
    s8 : float
        Scaling factor for C8 term
    num_systems : int
        Number of systems in batch
    fill_value : int
        Fill value for invalid neighbor indices
    smoothing_on : float
        Distance at which smoothing starts, in Angstrom.
    smoothing_off : float
        Distance at which smoothing ends (cutoff), in Angstrom.
    compute_virial : bool
        Whether to compute virial tensor for cell gradients

    Returns
    -------
    Tensor
        Dispersion energy per system in eV, shape (num_systems,)

    Notes
    -----
    Input coordinates are in Angstroms, internally converted to Bohr.
    Output energies are in eV, forces in eV/Angstrom.
    """
    has_cell = cell is not None
    has_shifts = shifts is not None

    # The custom op schema only accepts tensors: missing optional inputs are
    # replaced by empty placeholders and their absence is signalled by flags.
    if has_cell:
        cell_arg = cell
    else:
        cell_arg = torch.empty(0, device=coord.device)
    if has_shifts:
        shifts_arg = shifts
    else:
        shifts_arg = torch.empty(0, device=coord.device, dtype=torch.int32)

    energy, _forces, _virial = torch.ops.aimnet.dftd3_fwd(
        coord, cell_arg, numbers, batch_idx, neighbor_matrix, shifts_arg,
        rcov, r4r2, c6ab, cn_ref,
        a1, a2, s6, s8,
        num_systems, fill_value, smoothing_on, smoothing_off,
        compute_virial, has_cell, has_shifts,
    )
    # Only the energy is exposed; forces/virial outputs are discarded here.
    return energy
|