aimnet 0.0.1__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
aimnet/modules/lr.py CHANGED
@@ -1,12 +1,96 @@
-from typing import Dict, Optional
+import os
 
 import torch
 from torch import Tensor, nn
 
 from aimnet import constants, nbops, ops
+from aimnet.modules.ops import dftd3_energy
+
+
+def _calc_coulomb_sr(
+    data: dict[str, Tensor],
+    rc: Tensor,
+    envelope: str,
+    key_in: str,
+    factor: float,
+) -> Tensor:
+    """Shared short-range Coulomb energy calculation.
+
+    Computes pairwise Coulomb energy with envelope-weighted cutoff.
+
+    Parameters
+    ----------
+    data : dict
+        Data dictionary containing d_ij distances and charges.
+    rc : Tensor
+        Cutoff radius tensor.
+    envelope : str
+        Envelope function: "exp" or "cosine".
+    key_in : str
+        Key for charges in data dict.
+    factor : float
+        Unit conversion factor (half_Hartree * Bohr).
+
+    Returns
+    -------
+    Tensor
+        Short-range Coulomb energy per molecule.
+    """
+    d_ij = data["d_ij"]
+    q = data[key_in]
+    q_i, q_j = nbops.get_ij(q, data)
+    q_ij = q_i * q_j
+    if envelope == "exp":
+        fc = ops.exp_cutoff(d_ij, rc)
+    else:  # cosine
+        fc = ops.cosine_cutoff(d_ij, rc.item())
+    e_ij = fc * q_ij / d_ij
+    e_ij = nbops.mask_ij_(e_ij, data, 0.0)
+    # Accumulate in float64 for precision
+    e_i = e_ij.sum(-1, dtype=torch.float64)
+    return factor * nbops.mol_sum(e_i, data)
 
 
 class LRCoulomb(nn.Module):
+    """Long-range Coulomb energy module.
+
+    Computes electrostatic energy using one of several methods:
+    simple (all pairs), DSF (damped shifted force), or Ewald summation.
+
+    Parameters
+    ----------
+    key_in : str
+        Key for input charges in data dict. Default is "charges".
+    key_out : str
+        Key for output energy in data dict. Default is "e_h".
+    rc : float
+        Short-range cutoff radius. Default is 4.6 Angstrom.
+    method : str
+        Coulomb method: "simple", "dsf", or "ewald". Default is "simple".
+    dsf_alpha : float
+        Alpha parameter for DSF method. Default is 0.2.
+    dsf_rc : float
+        Cutoff for DSF method. Default is 15.0.
+    ewald_accuracy : float
+        Target accuracy for Ewald summation. Controls real-space and
+        reciprocal-space cutoffs. Lower values give higher accuracy.
+        Default is 1e-8.
+    subtract_sr : bool
+        Whether to subtract short-range contribution. Default is True.
+    envelope : str
+        Envelope function for SR cutoff: "exp" or "cosine". Default is "exp".
+
+    Notes
+    -----
+    Energy accumulation uses float64 for numerical precision, particularly
+    important for large systems where many small contributions can suffer
+    from floating-point error accumulation.
+
+    Neighbor list keys follow a suffix resolution pattern: methods first look
+    for module-specific keys (e.g., nbmat_coulomb, shifts_coulomb), falling
+    back to shared _lr suffix (nbmat_lr, shifts_lr) if not found.
+    """
+
     def __init__(
         self,
         key_in: str = "charges",
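
Note: both LRCoulomb and the new SRCoulomb (below) route their short-range term through the `_calc_coulomb_sr` helper added above, so the envelope choice and the float64 accumulation live in one place. A minimal self-contained sketch of the same computation on a dense (N, N) distance matrix, with plain tensor ops standing in for the nbops/ops helpers and an assumed value for the eV·Å conversion factor:

    import torch

    HALF_HARTREE_BOHR = 0.5 * 27.211386 * 0.529177  # assumed eV * Angstrom factor

    def sr_coulomb_dense(q: torch.Tensor, d_ij: torch.Tensor, rc: float) -> torch.Tensor:
        # cosine envelope: 1 at d=0, falls to 0 at d=rc, zero beyond the cutoff
        fc = torch.where(
            d_ij < rc,
            0.5 * (torch.cos(torch.pi * d_ij / rc) + 1.0),
            torch.zeros_like(d_ij),
        )
        q_ij = q.unsqueeze(-1) * q.unsqueeze(-2)
        mask = ~torch.eye(q.shape[-1], dtype=torch.bool)  # drop i == j self-pairs
        e_ij = torch.where(mask, fc * q_ij / d_ij.clamp(min=1e-9), torch.zeros_like(d_ij))
        # float64 accumulation, mirroring e_ij.sum(-1, dtype=torch.float64) in the diff
        return HALF_HARTREE_BOHR * e_ij.sum((-2, -1), dtype=torch.float64)
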
@@ -15,6 +99,9 @@ class LRCoulomb(nn.Module):
         method: str = "simple",
         dsf_alpha: float = 0.2,
         dsf_rc: float = 15.0,
+        ewald_accuracy: float = 1e-8,
+        subtract_sr: bool = True,
+        envelope: str = "exp",
     ):
         super().__init__()
         self.key_in = key_in
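
The three new constructor arguments are additive, so the old call signature keeps working. A hypothetical instantiation exercising them (assuming LRCoulomb is imported from aimnet.modules.lr):

    from aimnet.modules.lr import LRCoulomb

    # Ewald electrostatics with a looser accuracy target, cosine SR envelope,
    # and the short-range part subtracted (the default):
    lr = LRCoulomb(method="ewald", ewald_accuracy=1e-6, envelope="cosine", subtract_sr=True)
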
@@ -23,55 +110,63 @@ class LRCoulomb(nn.Module):
         self.register_buffer("rc", torch.tensor(rc))
         self.dsf_alpha = dsf_alpha
         self.dsf_rc = dsf_rc
+        self.ewald_accuracy = ewald_accuracy
+        self.subtract_sr = subtract_sr
+        if envelope not in ("exp", "cosine"):
+            raise ValueError(f"Unknown envelope {envelope}, must be 'exp' or 'cosine'")
+        self.envelope = envelope
         if method in ("simple", "dsf", "ewald"):
             self.method = method
         else:
             raise ValueError(f"Unknown method {method}")
 
-    def coul_simple(self, data: Dict[str, Tensor]) -> Tensor:
-        data = ops.lazy_calc_dij_lr(data)
-        d_ij = data["d_ij_lr"]
-        q = data[self.key_in]
-        q_i, q_j = nbops.get_ij(q, data, suffix="_lr")
-        q_ij = q_i * q_j
-        fc = 1.0 - ops.exp_cutoff(d_ij, self.rc)
-        e_ij = fc * q_ij / d_ij
-        e_ij = nbops.mask_ij_(e_ij, data, 0.0, suffix="_lr")
-        e_i = e_ij.sum(-1)
-        e = self._factor * nbops.mol_sum(e_i, data)
-        return e
+    def coul_simple(self, data: dict[str, Tensor]) -> Tensor:
+        """Compute pairwise Coulomb energy.
 
-    def coul_simple_sr(self, data: Dict[str, Tensor]) -> Tensor:
-        d_ij = data["d_ij"]
+        With subtract_sr=True (default): Returns LR only (FULL - SR)
+        With subtract_sr=False: Returns FULL pairwise Coulomb
+        """
+        suffix = nbops.resolve_suffix(data, ["_coulomb", "_lr"])
+        data = ops.lazy_calc_dij(data, suffix)
+        d_ij = data[f"d_ij{suffix}"]
         q = data[self.key_in]
-        q_i, q_j = nbops.get_ij(q, data)
+        q_i, q_j = nbops.get_ij(q, data, suffix=suffix)
         q_ij = q_i * q_j
-        fc = ops.exp_cutoff(d_ij, self.rc)
-        e_ij = fc * q_ij / d_ij
-        e_ij = nbops.mask_ij_(e_ij, data, 0.0)
-        e_i = e_ij.sum(-1)
+        # Compute FULL pairwise Coulomb (no exp_cutoff weighting)
+        e_ij = q_ij / d_ij
+        e_ij = nbops.mask_ij_(e_ij, data, 0.0, suffix=suffix)
+        e_i = e_ij.sum(-1, dtype=torch.float64)
         e = self._factor * nbops.mol_sum(e_i, data)
+        # Same pattern as dsf/ewald - subtract SR to get LR
+        if self.subtract_sr:
+            e = e - self.coul_simple_sr(data)
         return e
 
-    def coul_dsf(self, data: Dict[str, Tensor]) -> Tensor:
-        data = ops.lazy_calc_dij_lr(data)
-        d_ij = data["d_ij_lr"]
+    def coul_simple_sr(self, data: dict[str, Tensor]) -> Tensor:
+        return _calc_coulomb_sr(data, self.rc, self.envelope, self.key_in, self._factor)
+
+    def coul_dsf(self, data: dict[str, Tensor]) -> Tensor:
+        suffix = nbops.resolve_suffix(data, ["_coulomb", "_lr"])
+        data = ops.lazy_calc_dij(data, suffix)
+        d_ij = data[f"d_ij{suffix}"]
         q = data[self.key_in]
-        q_i, q_j = nbops.get_ij(q, data, suffix="_lr")
+        q_i, q_j = nbops.get_ij(q, data, suffix=suffix)
         J = ops.coulomb_matrix_dsf(d_ij, self.dsf_rc, self.dsf_alpha, data)
-        e = (q_i * q_j * J).sum(-1)
+        e = (q_i * q_j * J).sum(-1, dtype=torch.float64)
         e = self._factor * nbops.mol_sum(e, data)
-        e = e - self.coul_simple_sr(data)
+        if self.subtract_sr:
+            e = e - self.coul_simple_sr(data)
         return e
 
-    def coul_ewald(self, data: Dict[str, Tensor]) -> Tensor:
-        J = ops.coulomb_matrix_ewald(data["coord"], data["cell"])
+    def coul_ewald(self, data: dict[str, Tensor]) -> Tensor:
+        J = ops.coulomb_matrix_ewald(data["coord"], data["cell"], accuracy=self.ewald_accuracy)
         q_i, q_j = data["charges"].unsqueeze(-1), data["charges"].unsqueeze(-2)
-        e = self._factor * (q_i * q_j * J).flatten(-2, -1).sum(-1)
-        e = e - self.coul_simple_sr(data)
+        e = self._factor * (q_i * q_j * J).flatten(-2, -1).sum(-1, dtype=torch.float64)
+        if self.subtract_sr:
+            e = e - self.coul_simple_sr(data)
         return e
 
-    def forward(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]:
+    def forward(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
         if self.method == "simple":
             e = self.coul_simple(data)
         elif self.method == "dsf":
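
All three Coulomb paths now share the same two-step pattern: resolve which neighbor list to use, then optionally subtract the short-range part. nbops.resolve_suffix itself is not shown in this diff; a minimal lookalike, under the assumption that it keys on the presence of nbmat{suffix} in the data dict:

    from torch import Tensor

    def resolve_suffix(data: dict[str, Tensor], candidates: list[str]) -> str:
        # Prefer a module-specific neighbor matrix (e.g. nbmat_coulomb),
        # fall back to the shared long-range one (nbmat_lr).
        for suffix in candidates:
            if f"nbmat{suffix}" in data:
                return suffix
        raise KeyError(f"no neighbor matrix found for any of {candidates}")
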
@@ -81,46 +176,131 @@ class LRCoulomb(nn.Module):
         else:
             raise ValueError(f"Unknown method {self.method}")
         if self.key_out in data:
-            data[self.key_out] = data[self.key_out] + e
+            data[self.key_out] = data[self.key_out].double() + e
         else:
             data[self.key_out] = e
         return data
 
 
+class SRCoulomb(nn.Module):
+    """Subtract short-range Coulomb contribution from energy.
+
+    For models trained with "simple" Coulomb mode, the NN has implicitly learned
+    the short-range Coulomb interaction. When using DSF or Ewald summation for
+    the full Coulomb energy, we need to subtract this short-range contribution
+    to avoid double-counting.
+
+    Parameters
+    ----------
+    rc : float
+        Cutoff radius for short-range Coulomb. Default is 4.6 Angstrom.
+    key_in : str
+        Key for input charges in data dict. Default is "charges".
+    key_out : str
+        Key for output energy in data dict. Default is "energy".
+    envelope : str
+        Envelope function for cutoff: "exp" (mollifier) or "cosine". Default is "exp".
+    """
+
+    def __init__(
+        self,
+        rc: float = 4.6,
+        key_in: str = "charges",
+        key_out: str = "energy",
+        envelope: str = "exp",
+    ):
+        super().__init__()
+        self.key_in = key_in
+        self.key_out = key_out
+        self._factor = constants.half_Hartree * constants.Bohr
+        self.register_buffer("rc", torch.tensor(rc))
+        if envelope not in ("exp", "cosine"):
+            raise ValueError(f"Unknown envelope {envelope}, must be 'exp' or 'cosine'")
+        self.envelope = envelope
+
+    def forward(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
+        """Subtract short-range Coulomb from energy."""
+        e_sr = _calc_coulomb_sr(data, self.rc, self.envelope, self.key_in, self._factor)
+
+        # Subtract short-range Coulomb from energy (in float64)
+        if self.key_out in data:
+            data[self.key_out] = data[self.key_out].double() - e_sr
+        else:
+            data[self.key_out] = -e_sr
+        return data
+
+
 class DispParam(nn.Module):
     def __init__(
         self,
-        ref_c6: Optional[Dict[int, Tensor] | Tensor] = None,
-        ref_alpha: Optional[Dict[int, Tensor] | Tensor] = None,
-        ptfile: Optional[str] = None,
+        ref_c6: dict[int, Tensor] | Tensor | None = None,
+        ref_alpha: dict[int, Tensor] | Tensor | None = None,
+        ptfile: str | None = None,
         key_in: str = "disp_param",
         key_out: str = "disp_param",
     ):
         super().__init__()
-        if (
-            ptfile is None
-            and (ref_c6 is None or ref_alpha is None)
-            or ptfile is not None
-            and (ref_c6 is not None or ref_alpha is not None)
-        ):
-            raise ValueError("Either ptfile or ref_c6 and ref_alpha should be supplied.")
-        # load data
-        ref = torch.load(ptfile) if ptfile is not None else torch.zeros(87, 2)
-        for i, p in enumerate([ref_c6, ref_alpha]):
-            if p is not None:
-                if isinstance(p, Tensor):
-                    ref[: p.shape[0], i] = p
-                else:
-                    for k, v in p.items():
-                        ref[k, i] = v
-        # c6=0 and alpha=1 for dummy atom
+        # Validate: cannot mix ptfile with ref_c6/ref_alpha
+        if ptfile is not None and (ref_c6 is not None or ref_alpha is not None):
+            raise ValueError("Cannot specify both ptfile and ref_c6/ref_alpha.")
+
+        # Load reference data
+        if ptfile is not None:
+            ref = torch.load(ptfile, weights_only=True)
+        elif ref_c6 is not None or ref_alpha is not None:
+            ref = torch.zeros(87, 2)
+            for i, p in enumerate([ref_c6, ref_alpha]):
+                if p is not None:
+                    if isinstance(p, Tensor):
+                        ref[: p.shape[0], i] = p
+                    else:
+                        for k, v in p.items():
+                            ref[k, i] = v
+        else:
+            # Placeholder - will be populated by load_state_dict
+            ref = torch.zeros(87, 2)
+
+        # Element 0 represents dummy atoms with c6=0 and alpha=1
         ref[0, 0] = 0.0
         ref[0, 1] = 1.0
         self.register_buffer("disp_param0", ref)
         self.key_in = key_in
         self.key_out = key_out
 
-    def forward(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]:
+    def _load_from_state_dict(
+        self,
+        state_dict: dict,
+        prefix: str,
+        local_metadata: dict,
+        strict: bool,
+        missing_keys: list,
+        unexpected_keys: list,
+        error_msgs: list,
+    ) -> None:
+        # Resize placeholder buffer to match checkpoint size before loading
+        key = prefix + "disp_param0"
+        if key in state_dict:
+            buf = state_dict[key]
+            if buf.shape != self.disp_param0.shape:
+                # Resize placeholder to match checkpoint
+                self.disp_param0 = torch.zeros_like(buf)
+
+            # Validate buffer has non-zero values (safety check)
+            nonzero = (buf != 0).sum() / buf.numel()
+            if nonzero < 0.1:
+                import warnings
+
+                warnings.warn(
+                    f"DispParam buffer appears to have mostly zero values (nonzero: {nonzero:.1%}). "
+                    "This may indicate a loading issue.",
+                    stacklevel=2,
+                )
+
+        super()._load_from_state_dict(
+            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
+        )
+
+    def forward(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
         disp_param_mult = data[self.key_in].clamp(min=-4, max=4).exp()
         disp_param = self.disp_param0[data["numbers"]]
         vals = disp_param * disp_param_mult
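
DispParam.forward combines tabulated per-element (c6, alpha) reference values with a multiplicative correction predicted by the network, clamped in log space to [-4, 4]. A small numeric sketch (the reference values here are illustrative, not the shipped ones):

    import torch

    disp_param0 = torch.tensor([[0.0, 1.0],    # element 0: dummy atom, c6=0, alpha=1
                                [6.5, 4.5]])   # element 1: illustrative (c6, alpha)
    numbers = torch.tensor([1, 1])             # two atoms of element 1
    nn_out = torch.tensor([[0.1, -0.2],
                           [9.0, 0.0]])        # raw network output; 9.0 clamps to 4
    mult = nn_out.clamp(min=-4, max=4).exp()
    disp_param = disp_param0[numbers] * mult   # per-atom corrected (c6, alpha)
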
@@ -141,24 +321,28 @@ class D3TS(nn.Module):
         self.key_in = key_in
         self.key_out = key_out
 
-    def forward(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]:
+    def forward(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
+        suffix = nbops.resolve_suffix(data, ["_dftd3", "_lr"])
+
         disp_param = data[self.key_in]
-        disp_param_i, disp_param_j = nbops.get_ij(disp_param, data, suffix="_lr")
+        disp_param_i, disp_param_j = nbops.get_ij(disp_param, data, suffix=suffix)
         c6_i, alpha_i = disp_param_i.unbind(dim=-1)
         c6_j, alpha_j = disp_param_j.unbind(dim=-1)
 
         # TS combination rule
         c6ij = 2 * c6_i * c6_j / (c6_i * alpha_j / alpha_i + c6_j * alpha_i / alpha_j).clamp(min=1e-4)
+        c6ij = nbops.mask_ij_(c6ij, data, 0.0, suffix=suffix)
 
         rr = self.r4r2[data["numbers"]]
-        rr_i, rr_j = nbops.get_ij(rr, data, suffix="_lr")
+        rr_i, rr_j = nbops.get_ij(rr, data, suffix=suffix)
         rrij = 3 * rr_i * rr_j
-        rrij = nbops.mask_ij_(rrij, data, 1.0, suffix="_lr")
+        rrij = nbops.mask_ij_(rrij, data, 1.0, suffix=suffix)
         r0ij = self.a1 * rrij.sqrt() + self.a2
 
-        ops.lazy_calc_dij_lr(data)
-        d_ij = data["d_ij_lr"] * constants.Bohr_inv
+        ops.lazy_calc_dij(data, suffix)
+        d_ij = data[f"d_ij{suffix}"] * constants.Bohr_inv
         e_ij = c6ij * (self.s6 / (d_ij.pow(6) + r0ij.pow(6)) + self.s8 * rrij / (d_ij.pow(8) + r0ij.pow(8)))
+
         e = -constants.half_Hartree * nbops.mol_sum(e_ij.sum(-1), data)
 
         if self.key_out in data:
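
The TS combination rule above mixes per-atom C6 coefficients and polarizabilities into a pair coefficient, with the clamp guarding the denominator when both effective terms vanish. The same expression as in the diff, factored out as a standalone function:

    from torch import Tensor

    def c6_ts(c6_i: Tensor, c6_j: Tensor, alpha_i: Tensor, alpha_j: Tensor) -> Tensor:
        # C6_ij = 2 * C6_i * C6_j / (C6_i * a_j / a_i + C6_j * a_i / a_j)
        denom = (c6_i * alpha_j / alpha_i + c6_j * alpha_i / alpha_j).clamp(min=1e-4)
        return 2 * c6_i * c6_j / denom
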
@@ -170,74 +354,295 @@ class D3TS(nn.Module):
 
 
 class DFTD3(nn.Module):
-    """DFT-D3 implementation.
-    BJ dumping, C6 and C8 terms, without 3-body term.
+    """DFT-D3 implementation using nvalchemiops GPU-accelerated kernels.
+
+    BJ damping, C6 and C8 terms, without 3-body term.
+
+    This implementation uses nvalchemiops.interactions.dispersion.dftd3 for
+    GPU-accelerated computation of dispersion energies and forces. It is
+    differentiable through a custom autograd function.
+
+    Parameters
+    ----------
+    s8 : float
+        Scaling factor for C8 term.
+    a1 : float
+        BJ damping parameter 1.
+    a2 : float
+        BJ damping parameter 2.
+    s6 : float, optional
+        Scaling factor for C6 term. Default is 1.0.
+    cutoff : float, optional
+        Cutoff distance in Angstroms for smoothing. Default is 15.0.
+    smoothing_fraction : float, optional
+        Fraction of cutoff distance used for smoothing window width.
+        Smoothing starts at cutoff * (1 - smoothing_fraction) and ends at cutoff.
+        Example: With cutoff=15.0 and smoothing_fraction=0.2:
+        - Smoothing starts at 12.0 Å (15.0 * 0.8)
+        - Smoothing ends at 15.0 Å
+        Default is 0.2 (20% of cutoff as smoothing window).
+    key_out : str, optional
+        Key for output energy in data dict. Default is "energy".
+    compute_forces : bool, optional
+        Whether to add forces to data dict. Default is False.
+    compute_virial : bool, optional
+        Whether to compute virial for cell gradients. Default is False.
+
+    Attributes
+    ----------
+    smoothing_on : float
+        Distance where smoothing starts (Angstroms).
+    smoothing_off : float
+        Distance where smoothing ends / cutoff (Angstroms).
+    s6, s8, a1, a2 : float
+        BJ damping parameters.
+
+    Notes
+    -----
+    Neighbor list keys follow a suffix resolution pattern: methods first look
+    for module-specific keys (e.g., nbmat_dftd3, shifts_dftd3), falling back
+    to shared _lr suffix (nbmat_lr, shifts_lr) if not found.
     """
 
-    def __init__(self, s8: float, a1: float, a2: float, s6: float = 1.0, key_out="energy"):
+    def __init__(
+        self,
+        s8: float,
+        a1: float,
+        a2: float,
+        s6: float = 1.0,
+        cutoff: float = 15.0,
+        smoothing_fraction: float = 0.2,
+        key_out: str = "energy",
+        compute_forces: bool = False,
+        compute_virial: bool = False,
+    ):
         super().__init__()
         self.key_out = key_out
+        self.compute_forces = compute_forces
+        self.compute_virial = compute_virial
         # BJ damping parameters
         self.s6 = s6
         self.s8 = s8
-        self.s9 = 4.0 / 3.0
         self.a1 = a1
         self.a2 = a2
-        self.a3 = 16.0
-        # CN parameters
-        self.k1 = -16.0
-        self.k3 = -4.0
-        # data
-        self.register_buffer("c6ab", torch.zeros(95, 95, 5, 5, 3))
-        self.register_buffer("r4r2", torch.zeros(95))
-        self.register_buffer("rcov", torch.zeros(95))
-        self.register_buffer("cnmax", torch.zeros(95))
-        sd = constants.get_dftd3_param()
-        self.load_state_dict(sd)
-
-    def _calc_c6ij(self, data: Dict[str, Tensor]) -> Tensor:
-        # CN part
-        # short range for CN
-        # d_ij = data["d_ij"] * constants.Bohr_inv
-        data = ops.lazy_calc_dij_lr(data)
-        d_ij = data["d_ij_lr"] * constants.Bohr_inv
-
-        numbers = data["numbers"]
-        numbers_i, numbers_j = nbops.get_ij(numbers, data, suffix="_lr")
-        rcov_i, rcov_j = nbops.get_ij(self.rcov[numbers], data, suffix="_lr")
-        rcov_ij = rcov_i + rcov_j
-        cn_ij = 1.0 / (1.0 + torch.exp(self.k1 * (rcov_ij / d_ij - 1.0)))
-        cn_ij = nbops.mask_ij_(cn_ij, data, 0.0, suffix="_lr")
-        cn = cn_ij.sum(-1)
-        cn = torch.clamp(cn, max=self.cnmax[numbers]).unsqueeze(-1).unsqueeze(-1)
-        cn_i, cn_j = nbops.get_ij(cn, data, suffix="_lr")
-        c6ab = self.c6ab[numbers_i, numbers_j]
-        c6ref, cnref_i, cnref_j = torch.unbind(c6ab, dim=-1)
-        c6ref = nbops.mask_ij_(c6ref, data, 0.0, suffix="_lr")
-        l_ij = torch.exp(self.k3 * ((cn_i - cnref_i).pow(2) + (cn_j - cnref_j).pow(2)))
-        w = l_ij.flatten(-2, -1).sum(-1)
-        z = torch.einsum("...ij,...ij->...", c6ref, l_ij)
-        _w = w < 1e-5
-        z[_w] = 0.0
-        c6_ij = z / w.clamp(min=1e-5)
-        return c6_ij
-
-    def forward(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]:
-        c6ij = self._calc_c6ij(data)
 
-        rr = self.r4r2[data["numbers"]]
-        rr_i, rr_j = nbops.get_ij(rr, data, suffix="_lr")
-        rrij = 3 * rr_i * rr_j
-        rrij = nbops.mask_ij_(rrij, data, 1.0, suffix="_lr")
-        r0ij = self.a1 * rrij.sqrt() + self.a2
+        # Smoothing parameters as module attributes
+        self.smoothing_on: float = cutoff * (1 - smoothing_fraction)
+        self.smoothing_off: float = cutoff
 
-        ops.lazy_calc_dij_lr(data)
-        d_ij = data["d_ij_lr"] * constants.Bohr_inv
-        e_ij = c6ij * (self.s6 / (d_ij.pow(6) + r0ij.pow(6)) + self.s8 * rrij / (d_ij.pow(8) + r0ij.pow(8)))
-        e = -constants.half_Hartree * nbops.mol_sum(e_ij.sum(-1), data)
+        # Load D3 reference parameters and convert to nvalchemiops format
+        dirname = os.path.dirname(os.path.dirname(__file__))
+        filename = os.path.join(dirname, "dftd3_data.pt")
+        param = torch.load(filename, map_location="cpu", weights_only=True)
+
+        c6ab_packed = param["c6ab"]
+        c6ab = c6ab_packed[..., 0].contiguous()
+        cn_ref = c6ab_packed[..., 1].contiguous()
+
+        # Register buffers for D3 parameters
+        self.register_buffer("rcov", param["rcov"].float())
+        self.register_buffer("r4r2", param["r4r2"].float())
+        self.register_buffer("c6ab", c6ab.float())
+        self.register_buffer("cn_ref", cn_ref.float())
+
+    def set_smoothing(self, cutoff: float, smoothing_fraction: float = 0.2) -> None:
+        """Update smoothing parameters based on new cutoff and fraction.
+
+        Parameters
+        ----------
+        cutoff : float
+            Cutoff distance in Angstroms.
+        smoothing_fraction : float
+            Fraction of cutoff used as smoothing window width.
+            Smoothing occurs from cutoff * (1 - smoothing_fraction) to cutoff.
+            Example: smoothing_fraction=0.2 means smoothing over last 20%
+            of cutoff distance (from 0.8*cutoff to cutoff). Default is 0.2.
+        """
+        self.smoothing_on = cutoff * (1 - smoothing_fraction)
+        self.smoothing_off = cutoff
+
+    def _load_from_state_dict(
+        self,
+        state_dict: dict,
+        prefix: str,
+        local_metadata: dict,
+        strict: bool,
+        missing_keys: list,
+        unexpected_keys: list,
+        error_msgs: list,
+    ) -> None:
+        """Handle loading from old state dict format with packed c6ab.
+
+        Migrates from legacy format where c6ab had shape [95, 95, 5, 5, 3]
+        with last dimension containing (c6ref, cnref_i, cnref_j) to new format
+        where c6ab is [95, 95, 5, 5] and cn_ref is separate [95, 95, 5, 5].
+        Also removes deprecated cnmax parameter if present.
+        """
+        c6ab_key = prefix + "c6ab"
+        cn_ref_key = prefix + "cn_ref"
+        cnmax_key = prefix + "cnmax"
+
+        if c6ab_key in state_dict and state_dict[c6ab_key].ndim == 5:
+            c6ab_packed = state_dict[c6ab_key]
+            state_dict[c6ab_key] = c6ab_packed[..., 0].contiguous()
+            state_dict[cn_ref_key] = c6ab_packed[..., 1].contiguous()
+
+        state_dict.pop(cnmax_key, None)
+
+        super()._load_from_state_dict(
+            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
+        )
 
+    def forward(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
+        nb_mode = nbops.get_nb_mode(data)
+        coord = data["coord"]
+        numbers = data["numbers"].to(torch.int32)
+        cell = data.get("cell")
+
+        # Prepare inputs based on nb_mode
+        if nb_mode == 0:
+            # Batched mode without neighbor matrix - construct full neighbor matrix
+            # This only applies when no nbmat is provided in data
+            B, N = coord.shape[:2]
+            coord_flat = coord.flatten(0, 1)  # (B*N, 3)
+            numbers_flat = numbers.flatten()  # (B*N,)
+            batch_idx = torch.arange(B, device=coord.device, dtype=torch.int32).repeat_interleave(N)
+            num_systems = B
+            total_atoms = B * N
+            max_neighbors = N - 1
+
+            # Build dense all-to-all neighbor matrix for mode 0 (no pre-computed neighbor list)
+            # Creates (N, N-1) template where each atom connects to all others except itself
+            arange_n = torch.arange(N, device=coord.device, dtype=torch.int32)
+            all_indices = arange_n.unsqueeze(0).expand(N, -1)
+            mask = all_indices != arange_n.unsqueeze(1)
+            template = all_indices[mask].view(N, N - 1)
+            batch_offsets = torch.arange(B, device=coord.device, dtype=torch.int32).unsqueeze(1).unsqueeze(2) * N
+            neighbor_matrix = (template.unsqueeze(0) + batch_offsets).view(total_atoms, max_neighbors)
+
+            fill_value = total_atoms
+            neighbor_matrix_shifts: Tensor | None = None
+            cell_for_autograd: Tensor | None = None
+
+        elif nb_mode == 1:
+            # Flat mode with neighbor matrix
+            suffix = nbops.resolve_suffix(data, ["_dftd3", "_lr"])
+            N = coord.shape[0]
+            coord_flat = coord
+            numbers_flat = numbers
+            neighbor_matrix = data[f"nbmat{suffix}"].to(torch.int32)
+
+            mol_idx = data.get("mol_idx")
+            if mol_idx is not None:
+                batch_idx = mol_idx.to(torch.int32)
+                num_systems = int(mol_idx.max().item()) + 1
+            else:
+                batch_idx = torch.zeros(N, dtype=torch.int32, device=coord.device)
+                num_systems = 1
+
+            shifts = data.get(f"shifts{suffix}")
+            neighbor_matrix_shifts = shifts.to(torch.int32) if shifts is not None else None
+
+            fill_value = N
+            cell_for_autograd = cell
+
+        elif nb_mode == 2:
+            # Batched mode with neighbor matrix
+            suffix = nbops.resolve_suffix(data, ["_dftd3", "_lr"])
+            B, N = coord.shape[:2]
+            coord_flat = coord.flatten(0, 1)
+            numbers_flat = numbers.flatten()
+            batch_idx = torch.arange(B, device=coord.device, dtype=torch.int32).repeat_interleave(N)
+            num_systems = B
+
+            nbmat = data[f"nbmat{suffix}"]
+            offsets = torch.arange(B, device=coord.device).unsqueeze(1) * N
+            neighbor_matrix = (nbmat + offsets.unsqueeze(-1)).flatten(0, 1).to(torch.int32)
+
+            shifts = data.get(f"shifts{suffix}")
+            if shifts is not None:
+                neighbor_matrix_shifts = shifts.flatten(0, 1).to(torch.int32)
+            else:
+                neighbor_matrix_shifts = None
+
+            fill_value = B * N
+            cell_for_autograd = cell
+
+        else:
+            raise ValueError(f"Unsupported neighbor mode: {nb_mode}")
+
+        # Compute energy using autograd function
+        energy_ev = self._compute_energy_autograd(
+            coord_flat,
+            cell_for_autograd,
+            numbers_flat,
+            batch_idx,
+            neighbor_matrix,
+            neighbor_matrix_shifts,
+            num_systems,
+            fill_value,
+        )
+
+        # Add dispersion energy to output
         if self.key_out in data:
-            data[self.key_out] = data[self.key_out] + e
+            data[self.key_out] = data[self.key_out] + energy_ev.to(data[self.key_out].dtype)
         else:
-            data[self.key_out] = e
+            data[self.key_out] = energy_ev
+
+        # Optionally compute and add forces to data dict
+        # Compute forces via autograd (will use saved forces from DFTD3Function)
+        if self.compute_forces and not torch.jit.is_scripting() and coord_flat.requires_grad:
+            # Forces are -grad of energy
+            forces_flat = torch.autograd.grad(
+                energy_ev.sum(),
+                coord_flat,
+                create_graph=self.training,
+                retain_graph=True,
+            )[0]
+            forces = -forces_flat
+
+            # Reshape if needed
+            if nb_mode == 0 or nb_mode == 2:
+                B, N = coord.shape[:2]
+                forces = forces.view(B, N, 3)
+
+            if "forces" in data:
+                data["forces"] = data["forces"] + forces
+            else:
+                data["forces"] = forces
+
         return data
+
+    def _compute_energy_autograd(
+        self,
+        coord: Tensor,
+        cell: Tensor | None,
+        numbers: Tensor,
+        batch_idx: Tensor,
+        neighbor_matrix: Tensor,
+        neighbor_matrix_shifts: Tensor | None,
+        num_systems: int,
+        fill_value: int,
+    ) -> Tensor:
+        """Compute DFT-D3 energy using custom op for differentiability and TorchScript."""
+        return dftd3_energy(
+            coord=coord,
+            cell=cell,
+            numbers=numbers,
+            batch_idx=batch_idx,
+            neighbor_matrix=neighbor_matrix,
+            shifts=neighbor_matrix_shifts,
+            rcov=self.rcov,
+            r4r2=self.r4r2,
+            c6ab=self.c6ab,
+            cn_ref=self.cn_ref,
+            a1=self.a1,
+            a2=self.a2,
+            s6=self.s6,
+            s8=self.s8,
+            num_systems=num_systems,
+            fill_value=fill_value,
+            smoothing_on=self.smoothing_on,
+            smoothing_off=self.smoothing_off,
+            compute_virial=self.compute_virial,
+        )
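
The mode-0 branch of DFTD3.forward builds an all-to-all neighbor matrix when no precomputed list is available: an (N, N-1) template listing every other atom in the system, then a per-batch offset into flat atom indices. The same construction in isolation, with small toy sizes:

    import torch

    B, N = 2, 4
    arange_n = torch.arange(N, dtype=torch.int32)
    all_indices = arange_n.unsqueeze(0).expand(N, -1)  # (N, N) of column indices
    template = all_indices[all_indices != arange_n.unsqueeze(1)].view(N, N - 1)
    offsets = torch.arange(B, dtype=torch.int32).view(B, 1, 1) * N
    neighbor_matrix = (template.unsqueeze(0) + offsets).view(B * N, N - 1)
    # row k lists the flat indices of all atoms in the same system as atom k, except k itself
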