aimnet 0.0.1__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
aimnet/modules/ops.py ADDED
@@ -0,0 +1,537 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: MIT
3
+ #
4
+ # Permission is hereby granted, free of charge, to any person obtaining a
5
+ # copy of this software and associated documentation files (the "Software"),
6
+ # to deal in the Software without restriction, including without limitation
7
+ # the rights to use, copy, modify, merge, publish, distribute, sublicense,
8
+ # and/or sell copies of the Software, and to permit persons to whom the
9
+ # Software is furnished to do so, subject to the following conditions:
10
+ #
11
+ # The above copyright notice and this permission notice shall be included in
12
+ # all copies or substantial portions of the Software.
13
+ #
14
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
17
+ # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
19
+ # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
20
+ # DEALINGS IN THE SOFTWARE.
21
+ """
22
+ DFT-D3 Custom Op for PyTorch.
23
+
24
+ This module provides a TorchScript-compatible custom op for DFT-D3 dispersion
25
+ energy computation using nvalchemiops GPU-accelerated kernels.
26
+
27
+ TorchScript Compatibility
28
+ -------------------------
29
+ - torch.jit.script(): SUPPORTED - Models using this op can be scripted
30
+ - torch.jit.save(): SUPPORTED - Uses torch.autograd.Function pattern for serialization
31
+
32
+ The implementation wraps nvalchemiops calls in torch.autograd.Function classes,
33
+ which enables proper serialization with TorchScript.
34
+
35
+ Stress Sign Convention
36
+ ----------------------
37
+ This module follows the Cauchy (physical) stress convention:
38
+
39
+ - Positive stress = tensile (material being stretched)
40
+ - Negative stress = compressive (material being compressed)
41
+
42
+ The relationship between virial and Cauchy stress is:
43
+
44
+ stress = -virial / volume
45
+
46
+ where virial = -0.5 * sum_ij(F_ij outer r_ij).
47
+
48
+ The cell gradient for autograd is computed as:
49
+
50
+ dE/dcell = -virial @ inv(cell).T
51
+
52
+ This ensures the calculator returns physical Cauchy stress compatible with
53
+ ASE and standard MD conventions.
54
+ """
55
+
56
+ from typing import Any
57
+
58
+ import torch
59
+ from nvalchemiops.interactions.dispersion.dftd3 import dftd3
60
+ from torch import Tensor
61
+ from torch.autograd import Function
62
+
63
+ from aimnet import constants
64
+
65
+ # =============================================================================
66
+ # Autograd Function Classes
67
+ # =============================================================================
68
+
69
+
70
+ class _DFTD3Function(Function):
71
+ """Autograd Function for DFT-D3 dispersion energy computation.
72
+
73
+ This class wraps the nvalchemiops dftd3 implementation with proper
74
+ autograd support, enabling both gradient computation and TorchScript
75
+ serialization.
76
+
77
+ Notes
78
+ -----
79
+ Input coordinates are in Angstroms, internally converted to Bohr.
80
+ Output energies are in eV, forces in eV/Angstrom.
81
+ """
82
+
83
+ @staticmethod
84
+ def forward(
85
+ ctx: Any,
86
+ coord: Tensor,
87
+ cell: Tensor,
88
+ numbers: Tensor,
89
+ batch_idx: Tensor,
90
+ neighbor_matrix: Tensor,
91
+ shifts: Tensor,
92
+ rcov: Tensor,
93
+ r4r2: Tensor,
94
+ c6ab: Tensor,
95
+ cn_ref: Tensor,
96
+ a1: float,
97
+ a2: float,
98
+ s6: float,
99
+ s8: float,
100
+ num_systems: int,
101
+ fill_value: int,
102
+ smoothing_on: float,
103
+ smoothing_off: float,
104
+ compute_virial: bool,
105
+ has_cell: bool,
106
+ has_shifts: bool,
107
+ ) -> tuple[Tensor, Tensor, Tensor]:
108
+ """Forward pass: compute DFT-D3 dispersion energy."""
109
+ # Convert coordinates to Bohr
110
+ coord_bohr = coord * constants.Bohr_inv
111
+
112
+ # Convert cell to Bohr if present
113
+ cell_bohr = None
114
+ if has_cell:
115
+ cell_bohr = cell * constants.Bohr_inv
116
+ if cell_bohr.ndim == 2:
117
+ cell_bohr = cell_bohr.unsqueeze(0)
118
+
119
+ # Handle shifts
120
+ shifts_arg = None
121
+ if has_shifts:
122
+ shifts_arg = shifts
123
+
124
+ # Build kwargs for nvalchemiops dftd3 call
125
+ dftd3_kwargs: dict[str, Any] = {
126
+ "positions": coord_bohr,
127
+ "numbers": numbers,
128
+ "a1": a1,
129
+ "a2": a2,
130
+ "s8": s8,
131
+ "s6": s6,
132
+ "covalent_radii": rcov,
133
+ "r4r2": r4r2,
134
+ "c6_reference": c6ab,
135
+ "coord_num_ref": cn_ref,
136
+ "batch_idx": batch_idx,
137
+ "cell": cell_bohr,
138
+ "neighbor_matrix": neighbor_matrix,
139
+ "neighbor_matrix_shifts": shifts_arg,
140
+ "fill_value": fill_value,
141
+ "num_systems": num_systems,
142
+ "compute_virial": compute_virial,
143
+ "device": str(coord.device),
144
+ }
145
+
146
+ # Only pass smoothing parameters if smoothing is enabled
147
+ # When smoothing_on >= smoothing_off, omit to use nvalchemiops defaults (1e10)
148
+ if smoothing_on < smoothing_off:
149
+ dftd3_kwargs["s5_smoothing_on"] = smoothing_on * constants.Bohr_inv
150
+ dftd3_kwargs["s5_smoothing_off"] = smoothing_off * constants.Bohr_inv
151
+
152
+ # Call nvalchemiops dftd3
153
+ result = dftd3(**dftd3_kwargs)
154
+
155
+ if compute_virial:
156
+ energy, forces, _coord_num, virial = result
157
+ else:
158
+ energy, forces, _coord_num = result
159
+ virial = torch.empty(0, device=coord.device)
160
+
161
+ # Convert to eV/Angstrom units
162
+ energy_ev = energy * constants.Hartree
163
+ forces_ev_ang = forces * constants.Hartree * constants.Bohr_inv
164
+
165
+ # Save tensors for backward - convert cell_bohr for gradient computation
166
+ cell_bohr_saved = torch.empty(0, device=coord.device)
167
+ if has_cell:
168
+ cell_bohr_saved = cell * constants.Bohr_inv
169
+ if cell_bohr_saved.ndim == 2:
170
+ cell_bohr_saved = cell_bohr_saved.unsqueeze(0)
171
+
172
+ ctx.save_for_backward(forces_ev_ang, virial, batch_idx, cell_bohr_saved)
173
+ ctx.has_cell = has_cell
174
+ ctx.compute_virial = compute_virial
175
+
176
+ return energy_ev, forces_ev_ang, virial
177
+
178
+ @staticmethod
179
+ def backward(
180
+ ctx: Any,
181
+ grad_energy: Tensor,
182
+ grad_forces: Tensor,
183
+ grad_virial: Tensor,
184
+ ) -> tuple[Tensor | None, ...]:
185
+ """Backward pass: compute gradients w.r.t. coord and cell."""
186
+ forces, virial, batch_idx, cell_bohr = ctx.saved_tensors
187
+ has_cell = ctx.has_cell
188
+ compute_virial = ctx.compute_virial
189
+
190
+ # Coord gradient: forces = -dE/dR, so dE/dR = -forces
191
+ grad_coord = -forces * grad_energy[batch_idx].unsqueeze(-1)
192
+
193
+ # Cell gradient for physical Cauchy stress
194
+ # Cauchy stress = -virial / volume, so dE/dcell = -virial @ inv(cell).T
195
+ grad_cell = None
196
+ if has_cell and compute_virial and virial.numel() > 0:
197
+ cell_inv_t = torch.linalg.inv(cell_bohr).transpose(-1, -2)
198
+ dE_dcell_bohr = -virial @ cell_inv_t # Negative sign for Cauchy stress
199
+ dE_dcell_ang = dE_dcell_bohr * constants.Hartree * constants.Bohr_inv
200
+ grad_cell = dE_dcell_ang * grad_energy.view(-1, 1, 1)
201
+
202
+ # Return gradients for all 21 inputs (only coord and cell have gradients)
203
+ return (
204
+ grad_coord,
205
+ grad_cell,
206
+ None, # numbers
207
+ None, # batch_idx
208
+ None, # neighbor_matrix
209
+ None, # shifts
210
+ None, # rcov
211
+ None, # r4r2
212
+ None, # c6ab
213
+ None, # cn_ref
214
+ None, # a1
215
+ None, # a2
216
+ None, # s6
217
+ None, # s8
218
+ None, # num_systems
219
+ None, # fill_value
220
+ None, # smoothing_on
221
+ None, # smoothing_off
222
+ None, # compute_virial
223
+ None, # has_cell
224
+ None, # has_shifts
225
+ )
226
+
227
+
228
+ # =============================================================================
229
+ # PyTorch Custom Op Registration
230
+ # =============================================================================
231
+
232
+
233
@torch.library.custom_op("aimnet::dftd3_fwd", mutates_args=())
def dftd3_fwd(
    coord: Tensor,
    cell: Tensor,
    numbers: Tensor,
    batch_idx: Tensor,
    neighbor_matrix: Tensor,
    shifts: Tensor,
    rcov: Tensor,
    r4r2: Tensor,
    c6ab: Tensor,
    cn_ref: Tensor,
    a1: float,
    a2: float,
    s6: float,
    s8: float,
    num_systems: int,
    fill_value: int,
    smoothing_on: float,
    smoothing_off: float,
    compute_virial: bool,
    has_cell: bool,
    has_shifts: bool,
) -> list[Tensor]:
    """
    Custom-op entry point for the DFT-D3 forward computation.

    Every argument is handed to ``_DFTD3Function.apply`` unchanged and in the
    same order; the three tensors it produces are repackaged as a list.

    Returns
    -------
    list[Tensor]
        ``[energy, forces, virial]``.
    """
    energy, forces, virial = _DFTD3Function.apply(
        coord,
        cell,
        numbers,
        batch_idx,
        neighbor_matrix,
        shifts,
        rcov,
        r4r2,
        c6ab,
        cn_ref,
        a1,
        a2,
        s6,
        s8,
        num_systems,
        fill_value,
        smoothing_on,
        smoothing_off,
        compute_virial,
        has_cell,
        has_shifts,
    )
    return [energy, forces, virial]
286
+
287
+
288
@torch.library.register_fake("aimnet::dftd3_fwd")
def _(
    coord: Tensor,
    cell: Tensor,
    numbers: Tensor,
    batch_idx: Tensor,
    neighbor_matrix: Tensor,
    shifts: Tensor,
    rcov: Tensor,
    r4r2: Tensor,
    c6ab: Tensor,
    cn_ref: Tensor,
    a1: float,
    a2: float,
    s6: float,
    s8: float,
    num_systems: int,
    fill_value: int,
    smoothing_on: float,
    smoothing_off: float,
    compute_virial: bool,
    has_cell: bool,
    has_shifts: bool,
) -> list[Tensor]:
    """Shape-only stand-in for ``aimnet::dftd3_fwd`` used during tracing.

    No computation is performed: uninitialised tensors are allocated from
    ``coord`` (so they share its dtype and device) with these shapes:

    - energy: ``(num_systems,)``
    - forces: ``(coord.shape[0], 3)``
    - virial: ``(num_systems, 3, 3)`` if ``compute_virial`` else ``(0,)``
    """
    energy = coord.new_empty(num_systems)
    forces = coord.new_empty(coord.shape[0], 3)
    virial = coord.new_empty(num_systems, 3, 3) if compute_virial else coord.new_empty(0)
    return [energy, forces, virial]
321
+
322
+
323
+ # =============================================================================
324
+ # Autograd Registration (Method-based)
325
+ # =============================================================================
326
+
327
+
328
def _dftd3_setup_context(ctx: Any, inputs: tuple[Any, ...], output: list[Tensor]) -> None:
    """Stash on ``ctx`` what the backward formula of the custom op needs.

    Parameters
    ----------
    ctx
        Autograd context; receives the saved tensors plus the ``has_cell``
        and ``compute_virial`` flags.
    inputs
        The 21 positional arguments of ``dftd3_fwd``, in declaration order.
    output
        The ``[energy, forces, virial]`` list returned by ``dftd3_fwd``.
    """
    (
        positions, lattice, _z, atom_to_system, _nbrs, _nbr_shifts,
        _rcov, _r4r2, _c6, _cn,
        _a1, _a2, _s6, _s8,
        _n_sys, _fill, _s_on, _s_off,
        want_virial, periodic, _with_shifts,
    ) = inputs
    _, forces, virial = output

    if periodic:
        # Backward inverts the cell in Bohr; a single (3, 3) cell is promoted
        # to a batch of one.
        lattice_bohr = lattice * constants.Bohr_inv
        if lattice_bohr.ndim == 2:
            lattice_bohr = lattice_bohr.unsqueeze(0)
    else:
        # Empty placeholder so that four tensors are always saved.
        lattice_bohr = torch.empty(0, device=positions.device)

    ctx.save_for_backward(forces, virial, atom_to_system, lattice_bohr)
    ctx.has_cell = periodic
    ctx.compute_virial = want_virial
365
+
366
+
367
+ def _dftd3_backward(
368
+ ctx: Any,
369
+ grad_outputs: list[Tensor],
370
+ ) -> tuple[Tensor | None, ...]:
371
+ """Backward pass for dftd3 energy."""
372
+ grad_energy = grad_outputs[0]
373
+ # grad_outputs[1] and grad_outputs[2] are grad_forces and grad_virial (unused)
374
+ forces, virial, batch_idx, cell_bohr = ctx.saved_tensors
375
+ has_cell = ctx.has_cell
376
+ compute_virial = ctx.compute_virial
377
+
378
+ # Coord gradient: forces = -dE/dR, so dE/dR = -forces
379
+ grad_coord = -forces * grad_energy[batch_idx].unsqueeze(-1)
380
+
381
+ # Cell gradient for physical Cauchy stress
382
+ # Cauchy stress = -virial / volume, so dE/dcell = -virial @ inv(cell).T
383
+ grad_cell = None
384
+ if has_cell and compute_virial and virial.numel() > 0:
385
+ cell_inv_t = torch.linalg.inv(cell_bohr).transpose(-1, -2)
386
+ dE_dcell_bohr = -virial @ cell_inv_t # Negative sign for Cauchy stress
387
+ dE_dcell_ang = dE_dcell_bohr * constants.Hartree * constants.Bohr_inv
388
+ grad_cell = dE_dcell_ang * grad_energy.view(-1, 1, 1)
389
+
390
+ # Return gradients for all 21 inputs (only coord and cell have gradients)
391
+ return (
392
+ grad_coord,
393
+ grad_cell,
394
+ None, # numbers
395
+ None, # batch_idx
396
+ None, # neighbor_matrix
397
+ None, # shifts
398
+ None, # rcov
399
+ None, # r4r2
400
+ None, # c6ab
401
+ None, # cn_ref
402
+ None, # a1
403
+ None, # a2
404
+ None, # s6
405
+ None, # s8
406
+ None, # num_systems
407
+ None, # fill_value
408
+ None, # smoothing_on
409
+ None, # smoothing_off
410
+ None, # compute_virial
411
+ None, # has_cell
412
+ None, # has_shifts
413
+ )
414
+
415
+
416
# Attach the backward formula and its context setup to the custom op.
# Method-based registration is used for serializability.
dftd3_fwd.register_autograd(_dftd3_backward, setup_context=_dftd3_setup_context)
421
+
422
+
423
+ # =============================================================================
424
+ # Public API
425
+ # =============================================================================
426
+
427
+
428
def dftd3_energy(
    coord: Tensor,
    cell: Tensor | None,
    numbers: Tensor,
    batch_idx: Tensor,
    neighbor_matrix: Tensor,
    shifts: Tensor | None,
    rcov: Tensor,
    r4r2: Tensor,
    c6ab: Tensor,
    cn_ref: Tensor,
    a1: float,
    a2: float,
    s6: float,
    s8: float,
    num_systems: int,
    fill_value: int,
    smoothing_on: float,
    smoothing_off: float,
    compute_virial: bool = False,
) -> Tensor:
    """
    Return the DFT-D3 dispersion energy of each system, differentiable via autograd.

    Thin front end over the ``aimnet::dftd3_fwd`` custom op: optional inputs
    are replaced by empty tensors plus presence flags (the op accepts only
    real tensors), the op is invoked, and the energy entry of its result is
    returned.

    Parameters
    ----------
    coord : Tensor
        Atomic coordinates in Angstrom, shape (N, 3)
    cell : Tensor or None
        Unit cell vectors in Angstrom, shape (3, 3) or (B, 3, 3)
    numbers : Tensor
        Atomic numbers, shape (N,)
    batch_idx : Tensor
        Batch index for each atom, shape (N,)
    neighbor_matrix : Tensor
        Neighbor indices, shape (N, M)
    shifts : Tensor or None
        Periodic shift vectors as integers, shape (N, M, 3)
    rcov : Tensor
        Covalent radii, shape (95,)
    r4r2 : Tensor
        R4/R2 expectation values, shape (95,)
    c6ab : Tensor
        C6 reference values, shape (95, 95, 5, 5)
    cn_ref : Tensor
        Coordination number references, shape (95, 95, 5, 5)
    a1 : float
        BJ damping parameter a1
    a2 : float
        BJ damping parameter a2
    s6 : float
        Scaling factor for C6 term
    s8 : float
        Scaling factor for C8 term
    num_systems : int
        Number of systems in batch
    fill_value : int
        Fill value for invalid neighbor indices
    smoothing_on : float
        Distance at which smoothing starts, in Angstrom.
    smoothing_off : float
        Distance at which smoothing ends (cutoff), in Angstrom.
    compute_virial : bool
        Whether to compute virial tensor for cell gradients

    Returns
    -------
    Tensor
        Dispersion energy per system in eV, shape (num_systems,)

    Notes
    -----
    Input coordinates are in Angstroms, internally converted to Bohr.
    Output energies are in eV, forces in eV/Angstrom.
    """
    # The custom op cannot take None: pass an empty tensor and a flag instead.
    if cell is None:
        has_cell = False
        cell_arg = torch.empty(0, device=coord.device)
    else:
        has_cell = True
        cell_arg = cell

    if shifts is None:
        has_shifts = False
        shifts_arg = torch.empty(0, device=coord.device, dtype=torch.int32)
    else:
        has_shifts = True
        shifts_arg = shifts

    outputs = torch.ops.aimnet.dftd3_fwd(
        coord,
        cell_arg,
        numbers,
        batch_idx,
        neighbor_matrix,
        shifts_arg,
        rcov,
        r4r2,
        c6ab,
        cn_ref,
        a1,
        a2,
        s6,
        s8,
        num_systems,
        fill_value,
        smoothing_on,
        smoothing_off,
        compute_virial,
        has_cell,
        has_shifts,
    )

    # outputs is [energy, forces, virial]; callers only receive the energy.
    energy = outputs[0]
    return energy