aimnet 0.0.1__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,20 +1,205 @@
1
+ import math
1
2
  import warnings
2
- from typing import Any, ClassVar, Dict, Literal
3
+ from typing import Any, ClassVar, Literal
3
4
 
4
5
  import torch
6
+ from nvalchemiops.neighborlist import neighbor_list
7
+ from nvalchemiops.neighborlist.neighbor_utils import NeighborOverflowError
5
8
  from torch import Tensor, nn
6
9
 
10
+ from aimnet.models.base import load_model
11
+ from aimnet.modules import DFTD3, LRCoulomb
12
+
7
13
  from .model_registry import get_model_path
8
- from .nbmat import TooManyNeighborsError, calc_nbmat
14
+
15
+
16
+ class AdaptiveNeighborList:
17
+ """Adaptive neighbor list with automatic buffer sizing.
18
+
19
+ Wraps nvalchemiops.neighborlist.neighbor_list with automatic max_neighbors adjustment.
20
+ Maintains ~75% utilization to balance memory and recomputation.
21
+
22
+ Parameters
23
+ ----------
24
+ cutoff : float
25
+ Cutoff distance for neighbor detection in Angstroms.
26
+ density : float, optional
27
+ Initial atomic density estimate for allocation sizing.
28
+ Used to compute initial max_neighbors as density * (4/3 * pi * cutoff^3).
29
+ Default is 0.2.
30
+ target_utilization : float, optional
31
+ Target ratio of actual neighbors to allocated max_neighbors.
32
+ Default is 0.75 (75% utilization).
33
+
34
+ Attributes
35
+ ----------
36
+ cutoff : float
37
+ Cutoff distance for neighbor detection.
38
+ target_utilization : float
39
+ Target ratio of actual to allocated neighbors.
40
+ max_neighbors : int
41
+ Current maximum neighbor allocation (rounded to 16).
42
+ """
43
+
44
+ def __init__(
45
+ self,
46
+ cutoff: float,
47
+ density: float = 0.2,
48
+ target_utilization: float = 0.75,
49
+ ) -> None:
50
+ self.cutoff = cutoff
51
+ self.target_utilization = target_utilization
52
+ sphere_volume = 4 / 3 * math.pi * cutoff**3
53
+ self.max_neighbors = self._round_to_16(int(density * sphere_volume))
54
+
55
+ @staticmethod
56
+ def _round_to_16(n: int) -> int:
57
+ """Round up to the next multiple of 16 for memory alignment."""
58
+ return ((n + 15) // 16) * 16
59
+
60
+ def __call__(
61
+ self,
62
+ positions: Tensor,
63
+ cell: Tensor | None = None,
64
+ pbc: Tensor | None = None,
65
+ batch_idx: Tensor | None = None,
66
+ fill_value: int | None = None,
67
+ ) -> tuple[Tensor, Tensor, Tensor | None]:
68
+ """Compute neighbor list with automatic buffer adjustment.
69
+
70
+ Parameters
71
+ ----------
72
+ positions : Tensor
73
+ Atomic coordinates, shape (N, 3).
74
+ cell : Tensor | None
75
+ Unit cell vectors, shape (num_systems, 3, 3). None for non-periodic.
76
+ pbc : Tensor | None
77
+ Periodic boundary conditions, shape (num_systems, 3). None for non-periodic.
78
+ batch_idx : Tensor | None
79
+ Batch index for each atom, shape (N,). None for single system.
80
+ fill_value : int | None
81
+ Fill value for padding. Default is N (number of atoms).
82
+
83
+ Returns
84
+ -------
85
+ nbmat : Tensor
86
+ Neighbor indices, shape (N, actual_max_neighbors).
87
+ num_neighbors : Tensor
88
+ Number of neighbors per atom, shape (N,).
89
+ shifts : Tensor | None
90
+ Integer unit cell shifts for PBC, shape (N, actual_max_neighbors, 3).
91
+ None for non-periodic systems.
92
+ """
93
+ N = positions.shape[0]
94
+ if fill_value is None:
95
+ fill_value = N
96
+ _pbc = cell is not None
97
+
98
+ while True:
99
+ try:
100
+ if _pbc:
101
+ nbmat, num_neighbors, shifts = neighbor_list(
102
+ positions=positions,
103
+ cutoff=self.cutoff,
104
+ cell=cell,
105
+ pbc=pbc,
106
+ batch_idx=batch_idx,
107
+ max_neighbors=self.max_neighbors,
108
+ half_fill=False,
109
+ fill_value=fill_value,
110
+ )
111
+ else:
112
+ nbmat, num_neighbors = neighbor_list(
113
+ positions=positions,
114
+ cutoff=self.cutoff,
115
+ batch_idx=batch_idx,
116
+ max_neighbors=self.max_neighbors,
117
+ half_fill=False,
118
+ fill_value=fill_value,
119
+ method="batch_naive",
120
+ )
121
+ shifts = None
122
+ except NeighborOverflowError:
123
+ # Increase buffer by 1.5x and retry
124
+ self.max_neighbors = self._round_to_16(int(self.max_neighbors * 1.5))
125
+ continue
126
+
127
+ # Get actual max neighbors from result
128
+ actual_max = int(num_neighbors.max().item())
129
+
130
+ # Adjust buffer if under-utilized (shrink at 2/3 of target for hysteresis)
131
+ # Use 2/3 threshold to prevent thrashing from small fluctuations
132
+ if actual_max < (2 / 3) * self.target_utilization * self.max_neighbors:
133
+ new_max = self._round_to_16(int(actual_max / self.target_utilization))
134
+ self.max_neighbors = max(new_max, 16) # Ensure minimum of 16
135
+
136
+ # Trim to actual max neighbors
137
+ actual_nnb = max(1, actual_max)
138
+ nbmat = nbmat[:, :actual_nnb]
139
+ if shifts is not None:
140
+ shifts = shifts[:, :actual_nnb]
141
+
142
+ return nbmat, num_neighbors, shifts
9
143
 
10
144
 
11
145
  class AIMNet2Calculator:
12
- """Genegic AIMNet2 calculator
146
+ """Generic AIMNet2 calculator.
147
+
13
148
  A helper class to load AIMNet2 models and perform inference.
149
+
150
+ Parameters
151
+ ----------
152
+ model : str | nn.Module
153
+ Model name (from registry), path to model file, or nn.Module instance.
154
+ nb_threshold : int
155
+ Threshold for neighbor list batching. Molecules larger than this use
156
+ flattened processing. Default is 120.
157
+ needs_coulomb : bool | None
158
+ Whether to add external Coulomb module. If None (default), determined
159
+ from model metadata. If True/False, overrides metadata.
160
+ needs_dispersion : bool | None
161
+ Whether to add external DFTD3 module. If None (default), determined
162
+ from model metadata. If True/False, overrides metadata.
163
+ device : str | None
164
+ Device to run the model on ("cuda", "cpu", or specific like "cuda:0").
165
+ If None (default), auto-detects CUDA availability.
166
+ compile_model : bool
167
+ Whether to compile the model with torch.compile(). Default is False.
168
+ compile_kwargs : dict | None
169
+ Additional keyword arguments to pass to torch.compile(). Default is None.
170
+ train : bool
171
+ Whether to enable training mode. Default is False (inference mode).
172
+ When False, all model parameters have requires_grad=False, which
173
+ improves torch.compile compatibility and reduces memory usage.
174
+ Set to True only when training the model.
175
+
176
+ Attributes
177
+ ----------
178
+ model : nn.Module
179
+ The loaded AIMNet2 model.
180
+ device : str
181
+ Device the model is running on ("cuda" or "cpu").
182
+ cutoff : float
183
+ Short-range cutoff distance in Angstroms.
184
+ cutoff_lr : float | None
185
+ Long-range cutoff distance, or None if no LR modules.
186
+ external_coulomb : LRCoulomb | None
187
+ External Coulomb module if attached.
188
+ external_dftd3 : DFTD3 | None
189
+ External DFTD3 module if attached.
190
+
191
+ Notes
192
+ -----
193
+ External LR module behavior:
194
+
195
+ - For file-loaded models (str): metadata is loaded from file
196
+ - For nn.Module: metadata is read from model.metadata attribute if available
197
+ - Explicit flags (needs_coulomb, needs_dispersion) override metadata
198
+ - If no metadata and no explicit flags, no external LR modules are added
14
199
  """
15
200
 
16
- keys_in: ClassVar[Dict[str, torch.dtype]] = {"coord": torch.float, "numbers": torch.int, "charge": torch.float}
17
- keys_in_optional: ClassVar[Dict[str, torch.dtype]] = {
201
+ keys_in: ClassVar[dict[str, torch.dtype]] = {"coord": torch.float, "numbers": torch.int, "charge": torch.float}
202
+ keys_in_optional: ClassVar[dict[str, torch.dtype]] = {
18
203
  "mult": torch.float,
19
204
  "mol_idx": torch.int,
20
205
  "nbmat": torch.int,
@@ -28,112 +213,538 @@ class AIMNet2Calculator:
28
213
  keys_out: ClassVar[list[str]] = ["energy", "charges", "forces", "hessian", "stress"]
29
214
  atom_feature_keys: ClassVar[list[str]] = ["coord", "numbers", "charges", "forces"]
30
215
 
31
- def __init__(self, model: str | nn.Module = "aimnet2", nb_threshold: int = 320):
32
- self.device = "cuda" if torch.cuda.is_available() else "cpu"
216
+ def __init__(
217
+ self,
218
+ model: str | nn.Module = "aimnet2",
219
+ nb_threshold: int = 120,
220
+ needs_coulomb: bool | None = None,
221
+ needs_dispersion: bool | None = None,
222
+ device: str | None = None,
223
+ compile_model: bool = False,
224
+ compile_kwargs: dict | None = None,
225
+ train: bool = False,
226
+ ):
227
+ # Device selection: use provided or auto-detect
228
+ if device is not None:
229
+ self.device = device
230
+ else:
231
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
232
+ self.external_coulomb: LRCoulomb | None = None
233
+ self.external_dftd3: DFTD3 | None = None
234
+ # Default cutoffs for LR modules
235
+ self._default_dsf_cutoff = 15.0
236
+ self._default_dftd3_cutoff = 15.0
237
+ self._default_dftd3_smoothing = 0.2
238
+
239
+ # Load model and get metadata
240
+ metadata: dict | None = None
33
241
  if isinstance(model, str):
34
242
  p = get_model_path(model)
35
- self.model = torch.jit.load(p, map_location=self.device)
243
+ self.model, metadata = load_model(p, device=self.device)
244
+ self.cutoff = metadata["cutoff"]
36
245
  elif isinstance(model, nn.Module):
37
246
  self.model = model.to(self.device)
247
+ self.cutoff = getattr(self.model, "cutoff", 5.0)
248
+ metadata = getattr(self.model, "_metadata", None)
38
249
  else:
39
250
  raise TypeError("Invalid model type/name.")
40
251
 
41
- self.cutoff = self.model.cutoff
42
- self.lr = hasattr(self.model, "cutoff_lr")
43
- self.cutoff_lr = getattr(self.model, "cutoff_lr", float("inf"))
44
- self.max_density = 0.2
252
+ # Compile model if requested
253
+ if compile_model:
254
+ kwargs = compile_kwargs or {}
255
+ self.model = torch.compile(self.model, **kwargs)
256
+
257
+ # Resolve final flags (explicit overrides metadata)
258
+ final_needs_coulomb = (
259
+ needs_coulomb
260
+ if needs_coulomb is not None
261
+ else (metadata.get("needs_coulomb", False) if metadata is not None else False)
262
+ )
263
+ final_needs_dispersion = (
264
+ needs_dispersion
265
+ if needs_dispersion is not None
266
+ else (metadata.get("needs_dispersion", False) if metadata is not None else False)
267
+ )
268
+
269
+ # Set up external Coulomb if needed
270
+ if final_needs_coulomb:
271
+ sr_embedded = metadata.get("coulomb_mode") == "sr_embedded" if metadata is not None else False
272
+ # For PBC, user can switch to DSF/Ewald via set_lrcoulomb_method()
273
+ # When sr_embedded=True: model has SRCoulomb which subtracts SR, so external
274
+ # should compute FULL (subtract_sr=False) to give: (NN - SR) + FULL = NN + LR
275
+ # When sr_embedded=False: model has no SR embedded, so external should compute
276
+ # LR only (subtract_sr=True) to avoid double-counting
277
+ self.external_coulomb = LRCoulomb(
278
+ key_in="charges",
279
+ key_out="energy",
280
+ method="simple",
281
+ rc=metadata.get("coulomb_sr_rc", 4.6) if metadata is not None else 4.6,
282
+ envelope=metadata.get("coulomb_sr_envelope", "exp") if metadata is not None else "exp",
283
+ subtract_sr=not sr_embedded,
284
+ )
285
+ self.external_coulomb = self.external_coulomb.to(self.device)
286
+
287
+ # Set up external DFTD3 if needed
288
+ if final_needs_dispersion:
289
+ d3_params = metadata.get("d3_params") if metadata else None
290
+ if d3_params is None:
291
+ raise ValueError(
292
+ "needs_dispersion=True but d3_params not found in metadata. "
293
+ "Provide d3_params in model metadata or set needs_dispersion=False."
294
+ )
295
+ self.external_dftd3 = DFTD3(
296
+ s8=d3_params["s8"],
297
+ a1=d3_params["a1"],
298
+ a2=d3_params["a2"],
299
+ s6=d3_params.get("s6", 1.0),
300
+ )
301
+ self.external_dftd3 = self.external_dftd3.to(self.device)
302
+
303
+ # Determine if model has long-range modules (embedded or external)
304
+ has_embedded_lr = metadata.get("has_embedded_lr", False) if metadata is not None else False
305
+ self.lr = (
306
+ hasattr(self.model, "cutoff_lr")
307
+ or self.external_coulomb is not None
308
+ or self.external_dftd3 is not None
309
+ or has_embedded_lr
310
+ )
311
+ # Set cutoff_lr based on model attribute or external modules
312
+ if hasattr(self.model, "cutoff_lr"):
313
+ self.cutoff_lr = getattr(self.model, "cutoff_lr", float("inf"))
314
+ elif self.external_coulomb is not None:
315
+ # For "simple" method, use inf (all pairs). For DSF, use dsf_cutoff.
316
+ if self.external_coulomb.method == "simple":
317
+ self.cutoff_lr = float("inf")
318
+ else:
319
+ self.cutoff_lr = self._default_dsf_cutoff
320
+ elif self.external_dftd3 is not None:
321
+ self.cutoff_lr = self._default_dftd3_cutoff
322
+ elif has_embedded_lr:
323
+ # Embedded LR modules (D3TS, SRCoulomb) need nbmat_lr
324
+ self.cutoff_lr = self._default_dftd3_cutoff
325
+ else:
326
+ self.cutoff_lr = None
45
327
  self.nb_threshold = nb_threshold
46
328
 
329
+ # Create adaptive neighbor list instances
330
+ self._nblist = AdaptiveNeighborList(cutoff=self.cutoff)
331
+
332
+ # Track separate cutoffs for LR modules
333
+ self._coulomb_cutoff: float | None = None
334
+ self._dftd3_cutoff: float = self._default_dftd3_cutoff
335
+ if self.external_coulomb is not None:
336
+ if self.external_coulomb.method == "simple":
337
+ self._coulomb_cutoff = float("inf")
338
+ elif self.external_coulomb.method == "ewald":
339
+ self._coulomb_cutoff = None # Ewald manages its own cutoff
340
+ else:
341
+ self._coulomb_cutoff = self.external_coulomb.dsf_rc
342
+ if self.external_dftd3 is not None:
343
+ self._dftd3_cutoff = self.external_dftd3.smoothing_off
344
+
345
+ # Create long-range neighbor list(s) if LR modules present
346
+ self._nblist_lr: AdaptiveNeighborList | None = None
347
+ self._nblist_dftd3: AdaptiveNeighborList | None = None
348
+ self._nblist_coulomb: AdaptiveNeighborList | None = None
349
+ self._update_lr_nblists()
350
+
47
351
  # indicator if input was flattened
48
352
  self._batch = None
49
353
  self._max_mol_size: int = 0
50
354
  # placeholder for tensors that require grad
51
355
  self._saved_for_grad = {}
52
- # set flag of current Coulomb model
53
- coul_methods = {getattr(mod, "method", None) for mod in iter_lrcoulomb_mods(self.model)}
54
- if len(coul_methods) > 1:
55
- raise ValueError("Multiple Coulomb modules found.")
56
- if len(coul_methods):
57
- self._coulomb_method = coul_methods.pop()
58
- else:
59
- self._coulomb_method = None
356
+ # set flag of current Coulomb method
357
+ self._coulomb_method: str | None = None
358
+ if self.external_coulomb is not None:
359
+ self._coulomb_method = self.external_coulomb.method
360
+ elif self._has_embedded_coulomb():
361
+ # Legacy models have embedded Coulomb with "simple" method
362
+ self._coulomb_method = "simple"
363
+
364
+ # Set training mode (default False for inference)
365
+ self._train = train
366
+ self.model.train(train)
367
+ if not train:
368
+ # Disable gradients on all parameters for inference mode
369
+ for param in self.model.parameters():
370
+ param.requires_grad_(False)
371
+ if self.external_coulomb is not None:
372
+ for param in self.external_coulomb.parameters():
373
+ param.requires_grad_(False)
374
+ if self.external_dftd3 is not None:
375
+ for param in self.external_dftd3.parameters():
376
+ param.requires_grad_(False)
60
377
 
61
378
  def __call__(self, *args, **kwargs):
62
379
  return self.eval(*args, **kwargs)
63
380
 
381
+ @property
382
+ def has_external_coulomb(self) -> bool:
383
+ """Check if calculator has external Coulomb module attached.
384
+
385
+ Returns True for new-format models that were trained with Coulomb
386
+ and have it externalized. For legacy models, Coulomb is embedded
387
+ in the model itself, so this returns False.
388
+ """
389
+ return self.external_coulomb is not None
390
+
391
+ @property
392
+ def has_external_dftd3(self) -> bool:
393
+ """Check if calculator has external DFTD3 module attached.
394
+
395
+ Returns True for new-format models that were trained with DFTD3/D3BJ
396
+ dispersion and have it externalized. For legacy models or D3TS models,
397
+ dispersion is embedded in the model itself, so this returns False.
398
+ """
399
+ return self.external_dftd3 is not None
400
+
401
+ @property
402
+ def coulomb_method(self) -> str | None:
403
+ """Get the current Coulomb method.
404
+
405
+ Returns
406
+ -------
407
+ str | None
408
+ One of "simple", "dsf", "ewald", or None if no external Coulomb.
409
+ For legacy models with embedded Coulomb, returns None.
410
+ """
411
+ if self.external_coulomb is not None:
412
+ return self.external_coulomb.method
413
+ return None
414
+
415
+ @property
416
+ def coulomb_cutoff(self) -> float | None:
417
+ """Get the current Coulomb cutoff distance.
418
+
419
+ Returns
420
+ -------
421
+ float | None
422
+ The cutoff distance for Coulomb calculations, or None if not applicable.
423
+ For "simple" method, this is inf. For "ewald", this is None.
424
+ Use set_lrcoulomb_method() to change.
425
+ """
426
+ return self._coulomb_cutoff
427
+
428
+ @property
429
+ def dftd3_cutoff(self) -> float:
430
+ """Get the current DFTD3 cutoff distance.
431
+
432
+ Returns
433
+ -------
434
+ float
435
+ The cutoff distance for DFTD3 calculations in Angstroms.
436
+ """
437
+ return self._dftd3_cutoff
438
+
439
+ def _has_embedded_dispersion(self) -> bool:
440
+ """Check if model has embedded dispersion (not externalized).
441
+
442
+ Uses model metadata when available, otherwise returns False (unknown).
443
+
444
+ Returns
445
+ -------
446
+ bool
447
+ True if model has embedded dispersion module (D3TS or legacy DFTD3).
448
+ """
449
+ meta = getattr(self.model, "_metadata", None)
450
+ if meta is None:
451
+ return False # Unknown, assume no embedded dispersion
452
+
453
+ # New format: Check for embedded D3TS via has_embedded_lr
454
+ # If has_embedded_lr=True and coulomb_mode != "sr_embedded", it's D3TS
455
+ if meta.get("has_embedded_lr", False):
456
+ coulomb_mode = meta.get("coulomb_mode", "none")
457
+ if coulomb_mode != "sr_embedded":
458
+ return True # Must be D3TS (embedded dispersion)
459
+
460
+ # Legacy format: If needs_dispersion=False and d3_params exist, dispersion is embedded
461
+ # (legacy JIT models have dispersion embedded)
462
+ return not meta.get("needs_dispersion", False) and meta.get("d3_params") is not None
463
+
464
+ def _has_embedded_coulomb(self) -> bool:
465
+ """Check if model has embedded Coulomb (not externalized).
466
+
467
+ Uses model metadata when available, otherwise returns False (unknown).
468
+
469
+ Returns
470
+ -------
471
+ bool
472
+ True if model has embedded Coulomb module.
473
+ """
474
+ meta = getattr(self.model, "_metadata", None)
475
+ if meta is None:
476
+ return False # Unknown, assume no embedded Coulomb
477
+ # If needs_coulomb=False and coulomb_mode is not "none", Coulomb is embedded
478
+ # (legacy JIT models have full Coulomb embedded)
479
+ return not meta.get("needs_coulomb", False) and meta.get("coulomb_mode", "none") != "none"
480
+
481
+ def _should_use_separate_nblist(self, cutoff1: float, cutoff2: float) -> bool:
482
+ """Check if two cutoffs differ enough to warrant separate neighbor lists.
483
+
484
+ Parameters
485
+ ----------
486
+ cutoff1 : float
487
+ First cutoff distance.
488
+ cutoff2 : float
489
+ Second cutoff distance.
490
+
491
+ Returns
492
+ -------
493
+ bool
494
+ True if cutoffs differ by more than 20%, False otherwise.
495
+ """
496
+ # Handle edge cases
497
+ if cutoff1 <= 0 or cutoff2 <= 0:
498
+ return False
499
+ if not math.isfinite(cutoff1) or not math.isfinite(cutoff2):
500
+ return False
501
+ ratio = max(cutoff1, cutoff2) / min(cutoff1, cutoff2)
502
+ return ratio > 1.2
503
+
504
+ def _update_lr_nblists(self) -> None:
505
+ """Update long-range neighbor list instances based on current cutoffs.
506
+
507
+ Creates separate neighbor lists for DFTD3 and Coulomb if their cutoffs
508
+ differ by more than 20%. Otherwise, uses a single shared neighbor list.
509
+ Ewald uses its own internal neighbor list and ignores cutoffs.
510
+ """
511
+ if not self.lr:
512
+ self._nblist_lr = None
513
+ self._nblist_dftd3 = None
514
+ self._nblist_coulomb = None
515
+ return
516
+
517
+ has_dftd3 = self.external_dftd3 is not None or self._has_embedded_dispersion()
518
+ has_coulomb = self.external_coulomb is not None or self._has_embedded_coulomb()
519
+
520
+ # Determine effective cutoffs (None means no neighbor list needed for that module)
521
+ dftd3_cutoff = self._dftd3_cutoff if has_dftd3 else 0.0
522
+ coulomb_cutoff = self._coulomb_cutoff if has_coulomb and self._coulomb_cutoff is not None else 0.0
523
+
524
+ # Check if we need separate neighbor lists (both finite and differ by >20%)
525
+ if (
526
+ has_dftd3
527
+ and has_coulomb
528
+ and math.isfinite(dftd3_cutoff)
529
+ and math.isfinite(coulomb_cutoff)
530
+ and coulomb_cutoff > 0
531
+ and self._should_use_separate_nblist(dftd3_cutoff, coulomb_cutoff)
532
+ ):
533
+ # Use separate neighbor lists
534
+ self._nblist_dftd3 = AdaptiveNeighborList(cutoff=dftd3_cutoff)
535
+ self._nblist_coulomb = AdaptiveNeighborList(cutoff=coulomb_cutoff)
536
+ self._nblist_lr = None
537
+ return
538
+
539
+ # Use single shared neighbor list with max cutoff
540
+ max_cutoff = 0.0
541
+ if has_dftd3 and math.isfinite(dftd3_cutoff):
542
+ max_cutoff = max(max_cutoff, dftd3_cutoff)
543
+ if has_coulomb:
544
+ if coulomb_cutoff == float("inf"):
545
+ # Simple Coulomb needs all pairs
546
+ self._nblist_lr = AdaptiveNeighborList(cutoff=1e6)
547
+ self._nblist_dftd3 = None
548
+ self._nblist_coulomb = None
549
+ return
550
+ if math.isfinite(coulomb_cutoff) and coulomb_cutoff > 0:
551
+ max_cutoff = max(max_cutoff, coulomb_cutoff)
552
+
553
+ if max_cutoff > 0:
554
+ self._nblist_lr = AdaptiveNeighborList(cutoff=max_cutoff)
555
+ else:
556
+ self._nblist_lr = None
557
+ self._nblist_dftd3 = None
558
+ self._nblist_coulomb = None
559
+
64
560
  def set_lrcoulomb_method(
65
- self, method: Literal["simple", "dsf", "ewald"], cutoff: float = 15.0, dsf_alpha: float = 0.2
561
+ self,
562
+ method: Literal["simple", "dsf", "ewald"],
563
+ cutoff: float = 15.0,
564
+ dsf_alpha: float = 0.2,
565
+ ewald_accuracy: float = 1e-8,
66
566
  ):
567
+ """Set the long-range Coulomb method.
568
+
569
+ Parameters
570
+ ----------
571
+ method : str
572
+ One of "simple", "dsf", or "ewald".
573
+ cutoff : float
574
+ Cutoff distance for DSF neighbor list. Default is 15.0.
575
+ Not used for Ewald (which computes cutoffs from accuracy).
576
+ dsf_alpha : float
577
+ Alpha parameter for DSF method. Default is 0.2.
578
+ ewald_accuracy : float
579
+ Target accuracy for Ewald summation. Controls the real-space
580
+ and reciprocal-space cutoffs. Lower values give higher accuracy
581
+ but require more computation. Default is 1e-8.
582
+
583
+ The Ewald cutoffs are computed as:
584
+ - eta = (V^2 / N)^(1/6) / sqrt(2*pi)
585
+ - cutoff_real = sqrt(-2 * ln(accuracy)) * eta
586
+ - cutoff_recip = sqrt(-2 * ln(accuracy)) / eta
587
+
588
+ Notes
589
+ -----
590
+ For new-format models with external Coulomb, this updates the external module.
591
+ For legacy models with embedded Coulomb, a warning is issued as those modules
592
+ cannot be modified at runtime.
593
+ """
67
594
  if method not in ("simple", "dsf", "ewald"):
68
595
  raise ValueError(f"Invalid method: {method}")
69
- for mod in iter_lrcoulomb_mods(self.model):
70
- mod.method = method # type: ignore
71
- if method == "simple":
72
- self.cutoff_lr = float("inf")
73
- elif method == "dsf":
74
- self.cutoff_lr = cutoff
75
- mod.dsf_alpha = dsf_alpha # type: ignore
76
- mod.dsf_rc = cutoff # type: ignore
596
+
597
+ # Warn if model has embedded Coulomb (legacy models)
598
+ if self._has_embedded_coulomb() and self.external_coulomb is None:
599
+ warnings.warn(
600
+ "Model has embedded Coulomb module (legacy format). "
601
+ "set_lrcoulomb_method() only affects external Coulomb modules. "
602
+ "For legacy models, the Coulomb method cannot be changed at runtime.",
603
+ stacklevel=2,
604
+ )
605
+
606
+ # Update external LRCoulomb module if present
607
+ if self.external_coulomb is not None:
608
+ self.external_coulomb.method = method
609
+ if method == "dsf":
610
+ self.external_coulomb.dsf_alpha = dsf_alpha
611
+ self.external_coulomb.dsf_rc = cutoff
77
612
  elif method == "ewald":
78
- # current implementaion of Ewald does not use nb mat
79
- self.cutoff_lr = cutoff
613
+ self.external_coulomb.ewald_accuracy = ewald_accuracy
614
+
615
+ # Update _coulomb_cutoff based on method
616
+ if method == "simple":
617
+ self._coulomb_cutoff = float("inf")
618
+ elif method == "dsf":
619
+ self._coulomb_cutoff = cutoff
620
+ elif method == "ewald":
621
+ self._coulomb_cutoff = None # Ewald manages its own real-space cutoff
622
+
623
+ # Update cutoff_lr for backward compatibility
624
+ if self._coulomb_cutoff is not None:
625
+ self.cutoff_lr = self._coulomb_cutoff
626
+ else:
627
+ # Ewald - use DFTD3 cutoff if available, else None
628
+ self.cutoff_lr = self._dftd3_cutoff if self.external_dftd3 is not None else None
629
+
80
630
  self._coulomb_method = method
631
+ self._update_lr_nblists()
632
+
633
+ def set_lr_cutoff(self, cutoff: float) -> None:
634
+ """Set the unified long-range cutoff for all LR modules.
81
635
 
82
- def eval(self, data: Dict[str, Any], forces=False, stress=False, hessian=False) -> Dict[str, Tensor]:
636
+ Parameters
637
+ ----------
638
+ cutoff : float
639
+ Cutoff distance in Angstroms for LR neighbor lists.
640
+
641
+ Notes
642
+ -----
643
+ This updates both _coulomb_cutoff and _dftd3_cutoff.
644
+ Ewald uses its own internal neighbor list and ignores this cutoff.
645
+ """
646
+ # Update both cutoffs (but not for ewald which manages its own)
647
+ if self._coulomb_method != "ewald":
648
+ self._coulomb_cutoff = cutoff
649
+ self._dftd3_cutoff = cutoff
650
+ self.cutoff_lr = cutoff
651
+ self._update_lr_nblists()
652
+
653
+ def set_dftd3_cutoff(self, cutoff: float | None = None, smoothing_fraction: float | None = None) -> None:
654
+ """Set DFTD3 cutoff and smoothing.
655
+
656
+ Parameters
657
+ ----------
658
+ cutoff : float | None
659
+ Cutoff distance in Angstroms for DFTD3 calculation.
660
+ Default is _default_dftd3_cutoff (15.0).
661
+ smoothing_fraction : float | None
662
+ Fraction of cutoff used as smoothing width.
663
+ Default is _default_dftd3_smoothing (0.2).
664
+
665
+ Notes
666
+ -----
667
+ This method only affects external DFTD3 modules attached to
668
+ new-format models. For legacy models with embedded DFTD3,
669
+ the smoothing is fixed.
670
+
671
+ Updates _dftd3_cutoff and rebuilds neighbor lists.
672
+ """
673
+ if cutoff is None:
674
+ cutoff = self._default_dftd3_cutoff
675
+ if smoothing_fraction is None:
676
+ smoothing_fraction = self._default_dftd3_smoothing
677
+
678
+ self._dftd3_cutoff = cutoff
679
+ if self.external_dftd3 is not None:
680
+ self.external_dftd3.set_smoothing(cutoff, smoothing_fraction)
681
+ self._update_lr_nblists()
682
+
683
+ def eval(self, data: dict[str, Any], forces=False, stress=False, hessian=False) -> dict[str, Tensor]:
83
684
  data = self.prepare_input(data)
685
+
84
686
  if hessian and "mol_idx" in data and data["mol_idx"][-1] > 0:
85
687
  raise NotImplementedError("Hessian calculation is not supported for multiple molecules")
86
688
  data = self.set_grad_tensors(data, forces=forces, stress=stress, hessian=hessian)
87
689
  with torch.jit.optimized_execution(False): # type: ignore
88
690
  data = self.model(data)
691
+ # Run external modules if present
692
+ data = self._run_external_modules(data, compute_stress=stress)
89
693
  data = self.get_derivatives(data, forces=forces, stress=stress, hessian=hessian)
90
694
  data = self.process_output(data)
91
695
  return data
92
696
 
93
- def prepare_input(self, data: Dict[str, Any]) -> Dict[str, Tensor]:
697
+ def _run_external_modules(self, data: dict[str, Tensor], compute_stress: bool = False) -> dict[str, Tensor]:
698
+ """Run external Coulomb and DFTD3 modules if attached."""
699
+ if self.external_coulomb is not None:
700
+ data = self.external_coulomb(data)
701
+ if self.external_dftd3 is not None:
702
+ self.external_dftd3.compute_virial = compute_stress
703
+ data = self.external_dftd3(data)
704
+ return data
705
+
706
+ def prepare_input(self, data: dict[str, Any]) -> dict[str, Tensor]:
94
707
  data = self.to_input_tensors(data)
95
708
  data = self.mol_flatten(data)
96
- if data.get("cell") is not None:
97
- if data["mol_idx"][-1] > 0:
98
- raise NotImplementedError("PBC with multiple molecules is not implemented yet.")
99
- if self._coulomb_method == "simple":
100
- warnings.warn("Switching to DSF Coulomb for PBC", stacklevel=1)
101
- self.set_lrcoulomb_method("dsf")
709
+ if data.get("cell") is not None and self._coulomb_method == "simple":
710
+ warnings.warn("Switching to DSF Coulomb for PBC", stacklevel=1)
711
+ self.set_lrcoulomb_method("dsf")
102
712
  if data["coord"].ndim == 2:
103
- data = self.make_nbmat(data)
713
+ # Skip neighbor list calculation if already provided
714
+ if "nbmat" not in data:
715
+ data = self.make_nbmat(data)
104
716
  data = self.pad_input(data)
105
717
  return data
106
718
 
107
- def process_output(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]:
719
+ def process_output(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
108
720
  if data["coord"].ndim == 2:
109
721
  data = self.unpad_output(data)
110
722
  data = self.mol_unflatten(data)
111
723
  data = self.keep_only(data)
112
724
  return data
113
725
 
114
- def to_input_tensors(self, data: Dict[str, Any]) -> Dict[str, Tensor]:
726
+ def to_input_tensors(self, data: dict[str, Any]) -> dict[str, Tensor]:
115
727
  ret = {}
116
728
  for k in self.keys_in:
117
729
  if k not in data:
118
730
  raise KeyError(f"Missing key {k} in the input data")
119
- # always detach !!
731
+ # Detach from computation graph to prevent gradient accumulation
120
732
  ret[k] = torch.as_tensor(data[k], device=self.device, dtype=self.keys_in[k]).detach()
121
733
  for k in self.keys_in_optional:
122
734
  if k in data and data[k] is not None:
123
735
  ret[k] = torch.as_tensor(data[k], device=self.device, dtype=self.keys_in_optional[k]).detach()
124
- # convert any scalar tensors to shape (1,) tensors
736
+ # Ensure all tensors have at least 1D shape for consistent batch processing
125
737
  for k, v in ret.items():
126
738
  if v.ndim == 0:
127
739
  ret[k] = v.unsqueeze(0)
128
740
  return ret
129
741
 
130
- def mol_flatten(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]:
742
+ def mol_flatten(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
131
743
  """Flatten the input data for multiple molecules.
132
744
  Will not flatten for batched input and molecule size below threshold.
133
745
  """
134
746
  ndim = data["coord"].ndim
135
747
  if ndim == 2:
136
- # single molecule or already flattened
137
748
  self._batch = None
138
749
  if "mol_idx" not in data:
139
750
  data["mol_idx"] = torch.zeros(data["coord"].shape[0], dtype=torch.long, device=self.device)
@@ -144,9 +755,9 @@ class AIMNet2Calculator:
144
755
  self._max_mol_size = data["mol_idx"].unique(return_counts=True)[1].max().item()
145
756
 
146
757
  elif ndim == 3:
147
- # batched input
148
758
  B, N = data["coord"].shape[:2]
149
- if self.nb_threshold > N or self.device == "cpu":
759
+ # Force flattening for PBC (cell present) to ensure make_nbmat computes proper neighbor lists with shifts
760
+ if self.nb_threshold < N or self.device == "cpu" or data.get("cell") is not None:
150
761
  self._batch = B
151
762
  data["mol_idx"] = torch.repeat_interleave(
152
763
  torch.arange(0, B, device=self.device), torch.full((B,), N, device=self.device)
@@ -159,7 +770,7 @@ class AIMNet2Calculator:
159
770
  self._max_mol_size = N
160
771
  return data
161
772
 
162
- def mol_unflatten(self, data: Dict[str, Tensor], batch=None) -> Dict[str, Tensor]:
773
+ def mol_unflatten(self, data: dict[str, Tensor], batch=None) -> dict[str, Tensor]:
163
774
  batch = batch if batch is not None else self._batch
164
775
  if batch is not None:
165
776
  for k, v in data.items():
@@ -167,47 +778,121 @@ class AIMNet2Calculator:
167
778
  data[k] = v.view(batch, -1, *v.shape[1:])
168
779
  return data
169
780
 
170
    def make_nbmat(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
        """Compute neighbor matrices (and PBC shift vectors) and store them in *data*.

        Populates ``data["nbmat"]`` / ``data["shifts"]`` for the short-range model
        and, depending on which long-range neighbor-list objects are configured
        (``self._nblist_lr``, ``self._nblist_coulomb``, ``self._nblist_dftd3``),
        the corresponding ``nbmat_lr`` / ``nbmat_coulomb`` / ``nbmat_dftd3`` keys.
        Each neighbor matrix gets one extra padding row (atom index ``N``) appended
        so unused neighbor slots point at a dummy padding atom.
        """
        assert self._max_mol_size > 0, "Molecule size is not set"

        # Per-atom system index; used both for wrapping coords into per-system
        # cells and as the batch index for the neighbor-list backend.
        mol_idx = data.get("mol_idx")

        if "cell" in data and data["cell"] is not None:
            data["coord"] = move_coord_to_cell(data["coord"], data["cell"], mol_idx)
            cell = data["cell"]
        else:
            cell = None

        N = data["coord"].shape[0]
        _pbc = cell is not None
        # NOTE(review): assumes the neighbor-list backend expects int32 batch indices — confirm.
        batch_idx = mol_idx.to(torch.int32) if mol_idx is not None else None

        # The backend expects a batched (num_systems, 3, 3) cell plus a per-system pbc mask.
        if _pbc:
            if cell.ndim == 2:
                cell_batched = cell.unsqueeze(0)  # (1, 3, 3)
            else:
                cell_batched = cell  # (num_systems, 3, 3)
            num_systems = cell_batched.shape[0]
            # Fully periodic in all three directions for every system.
            pbc = torch.tensor([[True, True, True]] * num_systems, dtype=torch.bool, device=cell.device)
        else:
            cell_batched = None
            pbc = None

        # Short-range neighbors (always computed). fill_value=N routes empty
        # neighbor slots to the padding atom appended below.
        nbmat1, _, shifts1 = self._nblist(
            positions=data["coord"],
            cell=cell_batched,
            pbc=pbc,
            batch_idx=batch_idx,
            fill_value=N,
        )

        nbmat1, shifts1 = _add_padding_row(nbmat1, shifts1, N)
        data["nbmat"] = nbmat1
        if cell is not None:
            assert shifts1 is not None
            data["shifts"] = shifts1

        # Unified long-range neighbor list: when LR module cutoffs are similar,
        # one list serves the Coulomb and DFT-D3 terms alike.
        if self._nblist_lr is not None:
            if self._coulomb_cutoff == float("inf"):
                # Infinite cutoff: any atom may neighbor every other atom.
                self._nblist_lr.max_neighbors = N
            nbmat_lr, _, shifts_lr = self._nblist_lr(
                positions=data["coord"],
                cell=cell_batched,
                pbc=pbc,
                batch_idx=batch_idx,
                fill_value=N,
            )
            nbmat_lr, shifts_lr = _add_padding_row(nbmat_lr, shifts_lr, N)

            # All LR modules share the same neighbor list when cutoffs are similar.
            data["nbmat_lr"] = nbmat_lr
            data["nbmat_coulomb"] = nbmat_lr
            data["nbmat_dftd3"] = nbmat_lr
            if cell is not None and shifts_lr is not None:
                data["shifts_lr"] = shifts_lr
                data["shifts_coulomb"] = shifts_lr
                data["shifts_dftd3"] = shifts_lr
        else:
            # Separate neighbor lists per LR module.
            if self._nblist_coulomb is not None:
                if self._coulomb_cutoff == float("inf"):
                    self._nblist_coulomb.max_neighbors = N
                nbmat_coulomb, _, shifts_coulomb = self._nblist_coulomb(
                    positions=data["coord"],
                    cell=cell_batched,
                    pbc=pbc,
                    batch_idx=batch_idx,
                    fill_value=N,
                )
                nbmat_coulomb, shifts_coulomb = _add_padding_row(nbmat_coulomb, shifts_coulomb, N)
                data["nbmat_coulomb"] = nbmat_coulomb
                # Set nbmat_lr for backward compatibility with code expecting a unified LR neighbor list.
                data["nbmat_lr"] = nbmat_coulomb
                if cell is not None and shifts_coulomb is not None:
                    data["shifts_coulomb"] = shifts_coulomb
                    data["shifts_lr"] = shifts_coulomb

                if self._nblist_dftd3 is not None:
                    nbmat_dftd3, _, shifts_dftd3 = self._nblist_dftd3(
                        positions=data["coord"],
                        cell=cell_batched,
                        pbc=pbc,
                        batch_idx=batch_idx,
                        fill_value=N,
                    )
                    nbmat_dftd3, shifts_dftd3 = _add_padding_row(nbmat_dftd3, shifts_dftd3, N)
                    data["nbmat_dftd3"] = nbmat_dftd3
                    if cell is not None and shifts_dftd3 is not None:
                        data["shifts_dftd3"] = shifts_dftd3

            elif self._nblist_dftd3 is not None:
                # DFTD3-only configuration: populate nbmat_lr for backward compatibility.
                nbmat_dftd3, _, shifts_dftd3 = self._nblist_dftd3(
                    positions=data["coord"],
                    cell=cell_batched,
                    pbc=pbc,
                    batch_idx=batch_idx,
                    fill_value=N,
                )
                nbmat_dftd3, shifts_dftd3 = _add_padding_row(nbmat_dftd3, shifts_dftd3, N)
                data["nbmat_dftd3"] = nbmat_dftd3
                data["nbmat_lr"] = nbmat_dftd3
                if cell is not None and shifts_dftd3 is not None:
                    data["shifts_dftd3"] = shifts_dftd3
                    data["shifts_lr"] = shifts_dftd3

        return data
 
210
- def pad_input(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]:
895
+ def pad_input(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
211
896
  N = data["nbmat"].shape[0]
212
897
  data["mol_idx"] = maybe_pad_dim0(data["mol_idx"], N, value=data["mol_idx"][-1].item())
213
898
  for k in ("coord", "numbers"):
@@ -215,36 +900,49 @@ class AIMNet2Calculator:
215
900
  data[k] = maybe_pad_dim0(data[k], N)
216
901
  return data
217
902
 
218
- def unpad_output(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]:
903
+ def unpad_output(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
219
904
  N = data["nbmat"].shape[0] - 1
220
905
  for k, v in data.items():
221
906
  if k in self.atom_feature_keys:
222
907
  data[k] = maybe_unpad_dim0(v, N)
223
908
  return data
224
909
 
225
    def set_grad_tensors(self, data: dict[str, Tensor], forces=False, stress=False, hessian=False) -> dict[str, Tensor]:
        """Mark the tensors that later derivative calculations differentiate against.

        For forces/hessian, ``coord`` is flagged with ``requires_grad``; for
        stress, coordinates and cell are multiplied by an identity "scaling"
        tensor whose gradient yields the stress. The grad roots are stashed in
        ``self._saved_for_grad`` for use by ``get_derivatives``.
        """
        self._saved_for_grad = {}
        if forces or hessian:
            data["coord"].requires_grad_(True)
            self._saved_for_grad["coord"] = data["coord"]
        if stress:
            assert "cell" in data and data["cell"] is not None, "Stress calculation requires cell"
            cell = data["cell"]
            if cell.ndim == 2:
                # Single system: one (3, 3) strain-like scaling shared by all atoms.
                scaling = torch.eye(3, requires_grad=True, dtype=cell.dtype, device=cell.device)
                data["coord"] = data["coord"] @ scaling
                data["cell"] = cell @ scaling
            else:
                # Batched systems: (B, 3, 3) scaling - each system gets independent scaling.
                B = cell.shape[0]
                # NOTE(review): expand() shares one (3, 3) storage across the batch;
                # requires_grad_ on the expanded view works because the base does not
                # require grad, but confirm per-system gradients separate correctly
                # (repeat/clone would avoid the shared-storage subtlety).
                scaling = torch.eye(3, dtype=cell.dtype, device=cell.device).unsqueeze(0).expand(B, -1, -1)
                scaling.requires_grad_(True)
                mol_idx = data["mol_idx"]
                # Apply per-atom scaling: coord[i] @ scaling[mol_idx[i]]
                atom_scaling = torch.index_select(scaling, 0, mol_idx)  # (N_total, 3, 3)
                data["coord"] = (data["coord"].unsqueeze(1) @ atom_scaling).squeeze(1)
                data["cell"] = cell @ scaling
            self._saved_for_grad["scaling"] = scaling
        return data
935
 
238
- def keep_only(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]:
936
+ def keep_only(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
239
937
  ret = {}
240
938
  for k, v in data.items():
241
939
  if k in self.keys_out or (k.endswith("_std") and k[:-4] in self.keys_out):
242
940
  ret[k] = v
243
941
  return ret
244
942
 
245
- def get_derivatives(self, data: Dict[str, Tensor], forces=False, stress=False, hessian=False) -> Dict[str, Tensor]:
246
- training = getattr(self.model, "training", False)
247
- _create_graph = hessian or training
943
+ def get_derivatives(self, data: dict[str, Tensor], forces=False, stress=False, hessian=False) -> dict[str, Tensor]:
944
+ # Use stored train mode for create_graph decision
945
+ _create_graph = hessian or self._train
248
946
  x = []
249
947
  if hessian:
250
948
  forces = True
@@ -260,21 +958,62 @@ class AIMNet2Calculator:
260
958
  data["forces"] = -deriv[0]
261
959
  if stress:
262
960
  dedc = deriv[0] if not forces else deriv[1]
263
- data["stress"] = dedc / data["cell"].detach().det().abs()
961
+ cell = data["cell"].detach()
962
+ if cell.ndim == 2:
963
+ # Single cell (3, 3)
964
+ volume = cell.det().abs()
965
+ else:
966
+ # Batched cells (B, 3, 3) - compute volume for each cell
967
+ volume = torch.linalg.det(cell).abs().unsqueeze(-1).unsqueeze(-1) # (B, 1, 1)
968
+ data["stress"] = dedc / volume
264
969
  if hessian:
265
970
  data["hessian"] = self.calculate_hessian(data["forces"], self._saved_for_grad["coord"])
266
971
  return data
267
972
 
268
973
  @staticmethod
269
974
  def calculate_hessian(forces: Tensor, coord: Tensor) -> Tensor:
270
- # here forces have shape (N, 3) and coord has shape (N+1, 3)
271
- # return hessian with shape (N, 3, N, 3)
975
+ # Coord includes padding atom (shape N+1), forces only for real atoms (shape N)
976
+ # Hessian computed only for actual atoms: (N, 3, N, 3)
272
977
  hessian = -torch.stack([
273
978
  torch.autograd.grad(_f, coord, retain_graph=True)[0] for _f in forces.flatten().unbind()
274
979
  ]).view(-1, 3, coord.shape[0], 3)[:-1, :, :-1, :]
275
980
  return hessian
276
981
 
277
982
 
983
+ def _add_padding_row(
984
+ nbmat: Tensor,
985
+ shifts: Tensor | None,
986
+ N: int,
987
+ ) -> tuple[Tensor, Tensor | None]:
988
+ """Add padding row to neighbor matrix and shifts.
989
+
990
+ Parameters
991
+ ----------
992
+ nbmat : Tensor
993
+ Neighbor matrix, shape (N, max_neighbors).
994
+ shifts : Tensor | None
995
+ Shift vectors for PBC or None, shape (N, max_neighbors, 3).
996
+ N : int
997
+ Number of atoms (used as fill value for padding row).
998
+
999
+ Returns
1000
+ -------
1001
+ tuple[Tensor, Tensor | None]
1002
+ Tuple of (nbmat, shifts) with padding row added.
1003
+ """
1004
+ device = nbmat.device
1005
+ dtype = nbmat.dtype
1006
+ nnb_max = nbmat.shape[1]
1007
+ padding_row = torch.full((1, nnb_max), N, dtype=dtype, device=device)
1008
+ nbmat = torch.cat([nbmat, padding_row], dim=0)
1009
+
1010
+ if shifts is not None:
1011
+ shifts_padding = torch.zeros((1, nnb_max, 3), dtype=shifts.dtype, device=device)
1012
+ shifts = torch.cat([shifts, shifts_padding], dim=0)
1013
+
1014
+ return nbmat, shifts
1015
+
1016
+
278
1017
  def maybe_pad_dim0(a: Tensor, N: int, value=0.0) -> Tensor:
279
1018
  _shape_diff = N - a.shape[0]
280
1019
  assert _shape_diff == 0 or _shape_diff == 1, "Invalid shape"
@@ -297,24 +1036,45 @@ def maybe_unpad_dim0(a: Tensor, N: int) -> Tensor:
297
1036
  return a
298
1037
 
299
1038
 
300
- def move_coord_to_cell(coord, cell):
301
- coord_f = coord @ cell.inverse()
302
- coord_f = coord_f % 1
303
- return coord_f @ cell
304
-
305
-
306
- def _named_children_rec(module):
307
- if isinstance(module, torch.nn.Module):
308
- for name, child in module.named_children():
309
- yield name, child
310
- yield from _named_children_rec(child)
1039
+ def move_coord_to_cell(coord: Tensor, cell: Tensor, mol_idx: Tensor | None = None) -> Tensor:
1040
+ """Move coordinates into the periodic cell.
311
1041
 
1042
+ Parameters
1043
+ ----------
1044
+ coord : Tensor
1045
+ Coordinates tensor, shape (N, 3) or (B, N, 3).
1046
+ cell : Tensor
1047
+ Cell tensor, shape (3, 3) or (B, 3, 3).
1048
+ mol_idx : Tensor | None
1049
+ Molecule index for each atom, shape (N,).
1050
+ Required for batched cells with flat coordinates.
312
1051
 
313
- def iter_lrcoulomb_mods(model):
314
- for name, module in _named_children_rec(model):
315
- if name == "lrcoulomb":
316
- yield module
317
-
318
-
319
- def calc_max_nb(cutoff: float, density: float = 0.2) -> int | float:
320
- return int(density * 4 / 3 * 3.14159 * cutoff**3) if cutoff < float("inf") else float("inf")
1052
+ Returns
1053
+ -------
1054
+ Tensor
1055
+ Coordinates wrapped into the cell.
1056
+ """
1057
+ if cell.ndim == 2:
1058
+ # Single cell (3, 3)
1059
+ cell_inv = torch.linalg.inv(cell)
1060
+ coord_f = coord @ cell_inv
1061
+ coord_f = coord_f % 1
1062
+ return coord_f @ cell
1063
+ else:
1064
+ # Batched cells (B, 3, 3)
1065
+ if coord.ndim == 3:
1066
+ # Batched coords (B, N, 3) with batched cells (B, 3, 3)
1067
+ cell_inv = torch.linalg.inv(cell) # (B, 3, 3)
1068
+ coord_f = torch.bmm(coord, cell_inv) # (B, N, 3)
1069
+ coord_f = coord_f % 1
1070
+ return torch.bmm(coord_f, cell)
1071
+ else:
1072
+ # Flat coords (N_total, 3) with batched cells (B, 3, 3) - need mol_idx
1073
+ assert mol_idx is not None, "mol_idx required for batched cells with flat coordinates"
1074
+ cell_inv = torch.linalg.inv(cell) # (B, 3, 3)
1075
+ # Get cell and cell_inv for each atom
1076
+ atom_cell = cell[mol_idx] # (N_total, 3, 3)
1077
+ atom_cell_inv = cell_inv[mol_idx] # (N_total, 3, 3)
1078
+ coord_f = torch.bmm(coord.unsqueeze(1), atom_cell_inv).squeeze(1) # (N_total, 3)
1079
+ coord_f = coord_f % 1
1080
+ return torch.bmm(coord_f.unsqueeze(1), atom_cell).squeeze(1)