aimnet 0.0.1__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
aimnet/train/metrics.py CHANGED
@@ -1,6 +1,5 @@
  import logging
  from collections import defaultdict
- from typing import Dict, List, Tuple

  import ignite.distributed as idist
  import numpy as np
@@ -11,7 +10,7 @@ from ignite.metrics.metric import reinit__is_reduced
  from torch import Tensor


- def regression_stats(pred: Tensor, true: Tensor) -> Dict[str, Tensor]:
+ def regression_stats(pred: Tensor, true: Tensor) -> dict[str, Tensor]:
      diff = true - pred
      diff2 = diff.pow(2)
      mae = diff.abs().mean(-1)
@@ -23,7 +22,7 @@ def regression_stats(pred: Tensor, true: Tensor) -> Dict[str, Tensor]:
      return {"mae": mae, "rmse": rmse, "r2": r2}


- def cat_flatten(y_pred: Tensor, y_true: Tensor) -> Tuple[Tensor, Tensor]:
+ def cat_flatten(y_pred: Tensor, y_true: Tensor) -> tuple[Tensor, Tensor]:
      if isinstance(y_true, (list, tuple)):
          y_true = torch.cat([x.view(-1) for x in y_true])
      if isinstance(y_pred, (list, tuple)):
@@ -69,7 +68,7 @@ def calculate_metrics(result, histogram=False, corrplot=False):
          y_pred, y_true = cat_flatten(y_pred, y_true)
          stats = regression_stats(y_pred, y_true)
          npass = stats["mae"].numel()
-         if k.split(".")[-1] in ("energy", "forces"):  # noqa: SIM108
+         if k.split(".")[-1] in ("energy", "forces"):
              f = 23.06  # eV to kcal/mol
          else:
              f = 1.0
@@ -86,7 +85,7 @@ def calculate_metrics(result, histogram=False, corrplot=False):


  class RegMultiMetric(Metric):
-     def __init__(self, cfg: List[Dict], loss_fn=None):
+     def __init__(self, cfg: list[dict], loss_fn=None):
          super().__init__()
          self.cfg = cfg
          self.loss_fn = loss_fn
@@ -139,7 +138,7 @@ class RegMultiMetric(Metric):
          loss = loss.item()
          self.loss[k] += loss * b

-     def compute(self) -> Dict[str, float]:
+     def compute(self) -> dict[str, float]:
          if self.samples == 0:
              raise NotComputableError
          # Use custom reduction
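
The typing changes above swap `typing.Dict`/`List`/`Tuple` for the built-in generics of PEP 585, which is safe given the wheel's `Requires-Python: >=3.11`. For orientation, a standalone sketch of what `regression_stats` computes; the `rmse` and `r2` lines fall outside the hunks, so standard definitions are assumed here:

```python
import torch
from torch import Tensor


def regression_stats(pred: Tensor, true: Tensor) -> dict[str, Tensor]:
    # First three lines appear in the hunk above; rmse/r2 are assumed.
    diff = true - pred
    diff2 = diff.pow(2)
    mae = diff.abs().mean(-1)
    rmse = diff2.mean(-1).sqrt()
    r2 = 1 - diff2.sum(-1) / (true - true.mean(-1, keepdim=True)).pow(2).sum(-1)
    return {"mae": mae, "rmse": rmse, "r2": r2}


pred = torch.tensor([1.0, 2.0, 3.0])
true = torch.tensor([1.1, 1.9, 3.2])
print(regression_stats(pred, true))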
aimnet/train/train.py CHANGED
@@ -70,7 +70,7 @@ def train(config, model, load=None, save=None, args=None, no_default_config=Fals
      logging.info(train_cfg)
      logging.info("--- END train.yaml ---")

-     # try load model and pring its configuration
+     # try load model and print its configuration
      logging.info("Building model")
      model = utils.build_model(model_cfg)
      logging.info(model)
@@ -113,6 +113,8 @@ def run(local_rank, world_size, model_cfg, train_cfg, load, save):
          from ignite import distributed as idist

          model = idist.auto_model(model)  # type: ignore[attr-defined]
+     elif torch.cuda.is_available():
+         model = model.cuda()  # type: ignore

      # load weights
      if load is not None:
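
The second hunk closes a gap where single-process runs never moved the model off the CPU: only the distributed branch placed it, via `idist.auto_model`. A minimal sketch of the new fallback, with the distributed branch elided and names simplified:

```python
import torch
from torch import nn


def place_single_process(model: nn.Module) -> nn.Module:
    # New in 0.1.0: without DDP, still use the GPU when one is present.
    if torch.cuda.is_available():
        model = model.cuda()
    return model  # otherwise the model stays on CPU, as before


model = place_single_process(nn.Linear(4, 4))
print(next(model.parameters()).device)  # cuda:0 on a GPU machine, else cpu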
@@ -135,6 +137,7 @@
      optimizer = idist.auto_optim(optimizer)  # type: ignore[attr-defined]
      scheduler = utils.get_scheduler(optimizer, train_cfg.scheduler) if train_cfg.scheduler is not None else None  # type: ignore[attr-defined]
      loss = utils.get_loss(train_cfg.loss)
+
      metrics = utils.get_metrics(train_cfg.metrics)
      metrics.attach_loss(loss)  # type: ignore[attr-defined]

aimnet/train/utils.py CHANGED
@@ -1,7 +1,8 @@
  import logging
  import os
  import re
- from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
+ from collections.abc import Callable, Sequence
+ from typing import Any

  import numpy as np
  import omegaconf
@@ -58,7 +59,7 @@ def apply_sae(ds: SizeGroupedDataset, cfg: omegaconf.DictConfig):
          if c is not None and k in cfg.y:
              sae = load_yaml(c.file)
              unique_numbers = set(np.unique(ds.concatenate("numbers").tolist()))
-             if not set(sae.keys()).issubset(unique_numbers):  # type: ignore[attr-defined]
+             if not unique_numbers.issubset(sae.keys()):  # type: ignore[attr-defined]
                  raise ValueError(f"Keys in SAE file {c.file} do not cover all the dataset atoms")
              if c.mode == "linreg":
                  ds.apply_peratom_shift(k, k, sap_dict=sae)
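
This hunk fixes an inverted set check: 0.0.1 verified that the SAE keys were a subset of the dataset's elements, so a dataset element with no SAE entry slipped through silently. A small illustration of why the direction matters:

```python
# Elements covered by a hypothetical SAE file (H, C, O) vs. dataset atoms.
sae_keys = {1, 6, 8}
dataset_numbers = {1, 6, 8, 16}  # dataset also contains S, which has no SAE entry

# 0.0.1 check: sae_keys is a subset of dataset_numbers, so no error was raised
assert sae_keys.issubset(dataset_numbers)
# 0.1.0 check: dataset_numbers is NOT a subset of sae_keys, so the ValueError fires
assert not dataset_numbers.issubset(sae_keys)
```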
@@ -177,7 +178,7 @@ def get_loss(cfg: omegaconf.DictConfig):
      return loss


- def set_trainable_parameters(model: nn.Module, force_train: List[str], force_no_train: List[str]) -> nn.Module:
+ def set_trainable_parameters(model: nn.Module, force_train: list[str], force_no_train: list[str]) -> nn.Module:
      for n, p in model.named_parameters():
          if any(re.search(x, n) for x in force_no_train):
              p.requires_grad_(False)
@@ -214,7 +215,7 @@ def get_metrics(cfg: omegaconf.DictConfig):
      return metrics


- def train_step(engine: Engine, batch: Sequence[torch.Tensor]) -> Union[Any, Tuple[torch.Tensor]]:
+ def train_step(engine: Engine, batch: Sequence[torch.Tensor]) -> Any | tuple[torch.Tensor]:
      global model
      global optimizer
      global prepare_batch
@@ -232,7 +233,7 @@ def train_step(engine: Engine, batch: Sequence[torch.Tensor]) -> Union[Any, Tupl
      return loss.item()


- def val_step(engine: Engine, batch: Sequence[torch.Tensor]) -> Union[Any, Tuple[torch.Tensor]]:
+ def val_step(engine: Engine, batch: Sequence[torch.Tensor]) -> Any | tuple[torch.Tensor]:
      global model
      global optimizer
      global prepare_batch
@@ -248,7 +249,7 @@ def val_step(engine: Engine, batch: Sequence[torch.Tensor]) -> Union[Any, Tuple[
      return y_pred, y


- def prepare_batch(batch: Dict[str, Tensor], device="cuda", non_blocking=True) -> Dict[str, Tensor]:  # noqa: F811
+ def prepare_batch(batch: dict[str, Tensor], device="cuda", non_blocking=True) -> dict[str, Tensor]:  # noqa: F811
      for k, v in batch.items():
          batch[k] = v.to(device, non_blocking=non_blocking)
      return batch
@@ -257,11 +258,11 @@ def prepare_batch(batch: Dict[str, Tensor], device="cuda", non_blocking=True) ->
  def default_trainer(
      model: torch.nn.Module,
      optimizer: torch.optim.Optimizer,
-     loss_fn: Union[Callable, torch.nn.Module],
-     device: Optional[Union[str, torch.device]] = None,
+     loss_fn: Callable | torch.nn.Module,
+     device: str | torch.device | None = None,
      non_blocking: bool = True,
  ) -> Engine:
-     def _update(engine: Engine, batch: Tuple[Dict[str, Tensor], Dict[str, Tensor]]) -> float:
+     def _update(engine: Engine, batch: tuple[dict[str, Tensor], dict[str, Tensor]]) -> float:
          model.train()
          optimizer.zero_grad()
          x = prepare_batch(batch[0], device=device, non_blocking=non_blocking)  # type: ignore
@@ -271,17 +272,18 @@ def default_trainer(
          loss.backward()
          torch.nn.utils.clip_grad_value_(model.parameters(), 0.4)
          optimizer.step()
+
          return loss.item()

      return Engine(_update)


  def default_evaluator(
-     model: torch.nn.Module, device: Optional[Union[str, torch.device]] = None, non_blocking: bool = True
+     model: torch.nn.Module, device: str | torch.device | None = None, non_blocking: bool = True
  ) -> Engine:
      def _inference(
-         engine: Engine, batch: Tuple[Dict[str, Tensor], Dict[str, Tensor]]
-     ) -> Tuple[Dict[str, Tensor], Dict[str, Tensor]]:
+         engine: Engine, batch: tuple[dict[str, Tensor], dict[str, Tensor]]
+     ) -> tuple[dict[str, Tensor], dict[str, Tensor]]:
          model.eval()
          x = prepare_batch(batch[0], device=device, non_blocking=non_blocking)  # type: ignore
          y = prepare_batch(batch[1], device=device, non_blocking=non_blocking)  # type: ignore
@@ -316,6 +318,23 @@ def build_engine(model, optimizer, scheduler, loss_fn, metrics, cfg, loader_val)
      logging.info(f"LR: {lr}")

      trainer.add_event_handler(Events.EPOCH_STARTED, log_lr)
+
+     # log loss weights
+     def log_loss_weights(engine):
+         s = []
+         for k, v in loss_fn.components.items():
+             s.append(f"{k}: {v[1]:.4f}")
+         s = " ".join(s)
+         logging.info(s)
+         if loss_fn.weights is not None:
+             s = []
+             for k, v in loss_fn.weights.items():
+                 s.append(f"{k}: {v:.4f}")
+             s = " ".join(s)
+             logging.info(s)
+
+     trainer.add_event_handler(Events.EPOCH_STARTED, log_loss_weights)
+
      # write TQDM progress
      if idist.get_local_rank() == 0:
          pbar = ProgressBar()
@@ -327,15 +346,22 @@
      metrics.attach(validator, "multi")
      trainer.add_event_handler(Events.EPOCH_COMPLETED(every=1), validator.run, data=loader_val)

+     # attach optimizer and loss to engines
+     trainer.state.optimizer = optimizer
+     trainer.state.loss_fn = loss_fn
+     validator.state.optimizer = optimizer
+     validator.state.loss_fn = loss_fn
+
      # scheduler
      if scheduler is not None:
+         validator.state.scheduler = scheduler
          validator.add_event_handler(Events.COMPLETED, scheduler)
          terminator = TerminateOnLowLR(optimizer, cfg.scheduler.terminate_on_low_lr)
          trainer.add_event_handler(Events.EPOCH_STARTED, terminator)

      # checkpoint after each epoch
      if cfg.checkpoint and idist.get_local_rank() == 0:
-         kwargs = OmegaConf.to_container(cfg.checkpoint.kwargs) if "kwargs" not in cfg.checkpoint else {}
+         kwargs = OmegaConf.to_container(cfg.checkpoint.kwargs) if "kwargs" in cfg.checkpoint else {}
          if not isinstance(kwargs, dict):
              raise TypeError("Checkpoint kwargs must be a dictionary.")
          kwargs["global_step_transform"] = global_step_from_engine(trainer)
@@ -379,6 +405,9 @@ def setup_wandb(cfg, model_cfg, model, trainer, validator, optimizer):
              params = {
                  f"{self.param_name}_{i}": float(g[self.param_name]) for i, g in enumerate(self.optimizer.param_groups)
              }
+             if hasattr(engine.state, "loss_fn") and hasattr(engine.state.loss_fn, "components"):  # type: ignore
+                 for name, (_, w) in engine.state.loss_fn.components.items():  # type: ignore
+                     params[f"weight/{name}"] = w
              logger.log(params, step=global_step, sync=self.sync)

      wandb_logger.attach(trainer, log_handler=EpochLRLogger(optimizer), event_name=Events.EPOCH_STARTED)
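
The `setup_wandb` addition mirrors the new epoch logging: alongside the learning rate, per-component loss weights are sent to wandb under `weight/<name>` keys. A sketch of the resulting payload, with made-up values; the `(loss_fn, weight)` pair layout matches the `v[1]` access in `log_loss_weights` above:

```python
# components maps a target name to a (loss_fn, weight) pair; values illustrative.
components = {"energy": (None, 1.0), "forces": (None, 0.5)}
params = {"lr_0": 1e-3}
for name, (_, w) in components.items():
    params[f"weight/{name}"] = w
print(params)  # {'lr_0': 0.001, 'weight/energy': 1.0, 'weight/forces': 0.5}
```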
aimnet-0.1.0.dist-info/METADATA ADDED
@@ -0,0 +1,308 @@
+ Metadata-Version: 2.4
+ Name: aimnet
+ Version: 0.1.0
+ Summary: AIMNet Machine Learned Interatomic Potential
+ Project-URL: Homepage, https://github.com/isayevlab/aimnetcentral
+ Project-URL: Documentation, https://isayevlab.github.io/aimnetcentral/
+ Project-URL: Repository, https://github.com/isayevlab/aimnetcentral
+ Author-email: Roman Zubatyuk <zubatyuk@gmail.com>
+ License: MIT License
+
+ Copyright (c) 2024, Roman Zubatyuk
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
+ License-File: LICENSE
+ Keywords: computational chemistry,deep learning,interatomic potential,machine learning,molecular dynamics
+ Classifier: Intended Audience :: Science/Research
+ Classifier: License :: OSI Approved :: MIT License
+ Classifier: Operating System :: OS Independent
+ Classifier: Programming Language :: Python :: 3
+ Classifier: Programming Language :: Python :: 3.11
+ Classifier: Programming Language :: Python :: 3.12
+ Classifier: Topic :: Scientific/Engineering :: Chemistry
+ Classifier: Topic :: Scientific/Engineering :: Physics
+ Requires-Python: >=3.11
+ Requires-Dist: click>=8.1.7
+ Requires-Dist: h5py>=3.12.1
+ Requires-Dist: jinja2>=3.1.6
+ Requires-Dist: numpy
+ Requires-Dist: nvalchemi-toolkit-ops>=0.2
+ Requires-Dist: pyyaml>=6.0.2
+ Requires-Dist: requests>=2.32.3
+ Requires-Dist: torch>=2.4
+ Requires-Dist: warp-lang>=1.11
+ Provides-Extra: ase
+ Requires-Dist: ase==3.27.0; extra == 'ase'
+ Provides-Extra: pysis
+ Requires-Dist: pysisyphus; extra == 'pysis'
+ Provides-Extra: train
+ Requires-Dist: omegaconf>=2.3.0; extra == 'train'
+ Requires-Dist: pytorch-ignite>=0.5.1; extra == 'train'
+ Requires-Dist: wandb>=0.18.5; extra == 'train'
+ Description-Content-Type: text/markdown
+
+ [![Release](https://img.shields.io/github/v/release/isayevlab/aimnetcentral)](https://github.com/isayevlab/aimnetcentral/releases)
+ [![Python](https://img.shields.io/badge/python-3.11%20%7C%203.12-blue)](https://www.python.org/)
+ [![Build status](https://img.shields.io/github/actions/workflow/status/isayevlab/aimnetcentral/main.yml?branch=main)](https://github.com/isayevlab/aimnetcentral/actions/workflows/main.yml?query=branch%3Amain)
+ [![codecov](https://codecov.io/gh/isayevlab/aimnetcentral/branch/main/graph/badge.svg)](https://codecov.io/gh/isayevlab/aimnetcentral)
+ [![License](https://img.shields.io/github/license/isayevlab/aimnetcentral)](https://github.com/isayevlab/aimnetcentral/blob/main/LICENSE)
+
+ - **Github repository**: <https://github.com/isayevlab/aimnetcentral/>
+ - **Documentation**: <https://isayevlab.github.io/aimnetcentral/>
+
+ # AIMNet2: ML interatomic potential for fast and accurate atomistic simulations
+
+ ## Key Features
+
+ - Accurate and Versatile: AIMNet2 excels at modeling neutral, charged, organic, and elemental-organic systems.
+ - Flexible Interfaces: Use AIMNet2 through convenient calculators for popular simulation packages like ASE and PySisyphus.
+ - Flexible Long-Range Interactions: Optionally employ the Damped-Shifted Force (DSF) or Ewald summation Coulomb models for accurate calculations in large or periodic systems.
+
+ ## Requirements
+
+ ### Python Version
+
+ AIMNet2 requires **Python 3.11 or 3.12**.
+
+ ### GPU Support (Optional)
+
+ AIMNet2 works on CPU out of the box. For GPU acceleration:
+
+ - **CUDA GPU**: Install PyTorch with CUDA support from [pytorch.org](https://pytorch.org/get-started/locally/)
+ - **compile_mode**: Requires CUDA for ~5x MD speedup (see Performance Optimization)
+
+ Example PyTorch installation with CUDA 12.4:
+
+ ```bash
+ pip install torch --index-url https://download.pytorch.org/whl/cu124
+ ```
+
+ ## Available Models
+
+ | Model           | Elements                                      | Description                                    |
+ | --------------- | --------------------------------------------- | ---------------------------------------------- |
+ | `aimnet2`       | H, B, C, N, O, F, Si, P, S, Cl, As, Se, Br, I | wB97M-D3 (default)                             |
+ | `aimnet2_b973c` | H, B, C, N, O, F, Si, P, S, Cl, As, Se, Br, I | B97-3c functional                              |
+ | `aimnet2_2025`  | H, B, C, N, O, F, Si, P, S, Cl, As, Se, Br, I | B97-3c + improved intermolecular interactions  |
+ | `aimnet2nse`    | H, B, C, N, O, F, Si, P, S, Cl, As, Se, Br, I | Open-shell chemistry                           |
+ | `aimnet2pd`     | H, B, C, N, O, F, Si, P, S, Cl, Se, Br, Pd, I | Palladium-containing systems                   |
+
+ _Each model has ensemble members (append \_0 to \_3). Ensemble averaging recommended for production use._
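
A hedged sketch of ensemble averaging, assuming member names follow the `aimnet2_0` … `aimnet2_3` convention implied by the note above and the input format shown under Quick Start below:

```python
import numpy as np
from aimnet.calculators import AIMNet2Calculator

data = {
    "coord": np.array([[0.0, 0.0, 0.0], [0.0, 0.0, 0.74]]),  # H2, Angstrom
    "numbers": np.array([1, 1]),
    "charge": 0.0,
}
# Member names are an assumption based on the "_0".."_3" note.
members = [AIMNet2Calculator(f"aimnet2_{i}") for i in range(4)]
energies = [float(m(data)["energy"]) for m in members]
print(sum(energies) / len(energies))  # ensemble-averaged energy, eV
```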
+
+ ## Installation
+
+ ### Basic Installation
+
+ Install from PyPI:
+
+ ```bash
+ pip install aimnet
+ ```
+
+ ### Optional Features
+
+ AIMNet2 provides optional extras for different use cases:
+
+ **ASE Calculator** (for atomistic simulations with ASE):
+
+ ```bash
+ pip install "aimnet[ase]"
+ ```
+
+ **PySisyphus Calculator** (for reaction path calculations):
+
+ ```bash
+ pip install "aimnet[pysis]"
+ ```
+
+ **Training** (for model training and development):
+
+ ```bash
+ pip install "aimnet[train]"
+ ```
+
+ **All Features**:
+
+ ```bash
+ pip install "aimnet[ase,pysis,train]"
+ ```
+
+ ### Development Installation
+
+ For contributors, use [uv](https://docs.astral.sh/uv/) for fast dependency management:
+
+ ```bash
+ git clone https://github.com/isayevlab/aimnetcentral.git
+ cd aimnetcentral
+ make install
+ source .venv/bin/activate
+ ```
+
+ ## Quick Start
+
+ ### Basic Usage (Core)
+
+ ```python
+ from aimnet.calculators import AIMNet2Calculator
+
+ # Load a pre-trained model
+ calc = AIMNet2Calculator("aimnet2")
+
+ # Prepare input
+ data = {
+     "coord": coordinates,  # Nx3 array
+     "numbers": atomic_numbers,  # N array
+     "charge": 0.0,
+ }
+
+ # Run inference
+ results = calc(data, forces=True)
+ print(results["energy"], results["forces"])
+ ```
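
The `coordinates` and `atomic_numbers` placeholders can be any array-likes. A concrete, runnable variant for a single water molecule (the geometry values and the `float(...)` conversion are illustrative assumptions):

```python
import numpy as np
from aimnet.calculators import AIMNet2Calculator

calc = AIMNet2Calculator("aimnet2")
data = {
    "coord": np.array([[0.000, 0.000, 0.119],
                       [0.000, 0.763, -0.477],
                       [0.000, -0.763, -0.477]]),  # Nx3, Angstrom (water)
    "numbers": np.array([8, 1, 1]),  # O, H, H
    "charge": 0.0,
}
results = calc(data, forces=True)
print(float(results["energy"]))  # total energy in eV
```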
+
+ ### Output Data
+
+ The calculator returns a dictionary with the following keys:
+
+ | Key       | Shape                   | Description                          |
+ | --------- | ----------------------- | ------------------------------------ |
+ | `energy`  | `(,)` or `(B,)`         | Total energy in eV                   |
+ | `charges` | `(N,)` or `(B, N)`      | Atomic partial charges in e          |
+ | `forces`  | `(N, 3)` or `(B, N, 3)` | Atomic forces in eV/A (if requested) |
+ | `hessian` | `(N, 3, N, 3)`          | Second derivatives (if requested)    |
+ | `stress`  | `(3, 3)`                | Stress tensor for PBC (if requested) |
+
+ _B = batch size, N = number of atoms_
+
+ ### ASE Integration
+
+ With `aimnet[ase]` installed:
+
+ ```python
+ from ase.io import read
+ from aimnet.calculators import AIMNet2ASE
+
+ atoms = read("molecule.xyz")
+ atoms.calc = AIMNet2ASE("aimnet2")
+
+ energy = atoms.get_potential_energy()
+ forces = atoms.get_forces()
+ ```
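
Because `AIMNet2ASE` behaves as a regular ASE calculator, standard ASE tooling should apply unchanged; for example, a geometry optimization with ASE's stock `BFGS` optimizer (plain ASE API, nothing aimnet-specific):

```python
from ase.io import read
from ase.optimize import BFGS
from aimnet.calculators import AIMNet2ASE

atoms = read("molecule.xyz")
atoms.calc = AIMNet2ASE("aimnet2")
BFGS(atoms).run(fmax=0.01)  # relax until the largest force is below 0.01 eV/A
```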
+
+ ### Periodic Boundary Conditions
+
+ For periodic systems, provide a unit cell:
+
+ ```python
+ data = {
+     "coord": coordinates,
+     "numbers": atomic_numbers,
+     "charge": 0.0,
+     "cell": cell_vectors,  # 3x3 array in Angstrom
+ }
+ results = calc(data, forces=True, stress=True)
+ ```
+
+ ### Long-Range Coulomb Methods
+
+ Configure electrostatic interactions for large or periodic systems:
+
+ ```python
+ # Damped-Shifted Force (DSF) - recommended for periodic systems
+ calc.set_lrcoulomb_method("dsf", cutoff=15.0, dsf_alpha=0.2)
+
+ # Ewald summation - for accurate periodic electrostatics
+ calc.set_lrcoulomb_method("ewald", ewald_accuracy=1e-8)
+ ```
+
+ ### Performance Optimization
+
+ For molecular dynamics simulations, use `compile_mode` for ~5x speedup:
+
+ ```python
+ calc = AIMNet2Calculator("aimnet2", compile_mode=True)
+ ```
+
+ Requirements:
+
+ - CUDA GPU required
+ - Not compatible with periodic boundary conditions
+ - Best for repeated inference on similar-sized systems
+
+ ### Training
+
+ With `aimnet[train]` installed:
+
+ ```bash
+ aimnet train --config my_config.yaml --model aimnet2.yaml
+ ```
+
+ ## Technical Details
+
+ ### Batching and Neighbor Lists
+
+ The `AIMNet2Calculator` automatically selects the optimal strategy based on system size (`nb_threshold`, default 120 atoms) and hardware:
+
+ 1. **Dense Mode (O(N²))**: Used for small molecules on GPU. Input is kept in 3D batched format `(B, N, 3)`. No neighbor list is computed; the model uses a fully connected graph for maximum parallelism.
+ 2. **Sparse Mode (O(N))**: Used for large systems or CPU execution. Input is flattened to 2D `(N_total, 3)` with an adaptive neighbor list. This ensures linear memory scaling.
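
In pseudocode, the selection rule described above amounts to roughly the following sketch (`nb_threshold` is the documented default; the function name is ours, not the package's):

```python
def select_mode(n_atoms: int, on_gpu: bool, nb_threshold: int = 120) -> str:
    # Dense O(N^2) mode only pays off for small systems on GPU;
    # everything else goes through the sparse O(N) neighbor-list path.
    if on_gpu and n_atoms <= nb_threshold:
        return "dense"
    return "sparse"


print(select_mode(50, on_gpu=True))    # dense
print(select_mode(5000, on_gpu=True))  # sparse
print(select_mode(50, on_gpu=False))   # sparse
```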
+
+ ### Adaptive Neighbor List
+
+ In sparse mode, AIMNet2 uses an `AdaptiveNeighborList` that automatically resizes its buffer to maintain efficient utilization (~75%) while preventing overflows.
+
+ - **Format**: The neighbor list is stored as a 2D integer matrix `nbmat` of shape `(N_total, max_neighbors)`. Each row `i` contains the indices of atoms neighboring atom `i`.
+ - **Padding**: Rows with fewer neighbors than `max_neighbors` are padded with the index `N_total` (a dummy atom index).
+ - **Buffer Management**: The buffer size `max_neighbors` is always a multiple of 16 for memory alignment. It dynamically expands (by 1.5x) on overflow and shrinks if utilization drops significantly below the target, ensuring robust performance during MD simulations where density fluctuates.
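
A minimal illustration of that storage scheme (our sketch, not the package's `AdaptiveNeighborList` implementation):

```python
import numpy as np

n_total = 5
neighbors = {0: [1, 2], 1: [0], 2: [0, 3, 4], 3: [2], 4: [2]}

# Round the buffer width up to a multiple of 16, as described above.
max_neighbors = max(len(v) for v in neighbors.values())
max_neighbors = ((max_neighbors + 15) // 16) * 16

# Pad unused slots with the dummy atom index n_total.
nbmat = np.full((n_total, max_neighbors), n_total, dtype=np.int64)
for i, nbrs in neighbors.items():
    nbmat[i, : len(nbrs)] = nbrs
print(nbmat.shape)  # (5, 16)
```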
+
+ ## Development
+
+ Common development tasks using `make`:
+
+ ```bash
+ make check  # Run linters and code quality checks
+ make test   # Run tests with coverage
+ make docs   # Build and serve documentation
+ make build  # Build distribution packages
+ ```
+
+ ## Citation
+
+ If you use AIMNet2 in your research, please cite the appropriate paper:
+
+ **AIMNet2 (main model):**
+
+ ```bibtex
+ @article{aimnet2,
+   title={AIMNet2: A Neural Network Potential to Meet Your Neutral, Charged, Organic, and Elemental-Organic Needs},
+   author={Anstine, Dylan M and Zubatyuk, Roman and Isayev, Olexandr},
+   journal={Chemical Science},
+   volume={16},
+   pages={10228--10244},
+   year={2025},
+   doi={10.1039/D4SC08572H}
+ }
+ ```
+
+ **AIMNet2-NSE:** [ChemRxiv preprint](https://chemrxiv.org/engage/chemrxiv/article-details/692d304c65a54c2d4a7ab3c7)
+
+ **AIMNet2-Pd:** [ChemRxiv preprint](https://chemrxiv.org/engage/chemrxiv/article-details/67d7b7f7fa469535b97c021a)
+
+ ## License
+
+ See [LICENSE](LICENSE) file for details.
aimnet-0.1.0.dist-info/RECORD ADDED
@@ -0,0 +1,43 @@
+ aimnet/__init__.py,sha256=uxi5cEnLX3g_C7nUnxGwULhF-wyNIUwcYUImDYhnoKA,185
+ aimnet/base.py,sha256=jkMYzkF8ZB3RvN2G4p3felkXZm3UXVYK6KPcj5nC0Ac,1875
+ aimnet/cli.py,sha256=d6sICpuRFoMifjE4DxPGB8YPUYOlxNnxZeQWfhqAypE,2158
+ aimnet/config.py,sha256=N9ePMjlmUKWwumi-mAcBN_SrvqAAm41hYRxYv-MDcvU,5881
+ aimnet/constants.py,sha256=T6eb_CG5dkxuUg22Xf8WiF0iGB-JLoLNzHq9R8SRzrQ,8760
+ aimnet/nbops.py,sha256=tatWTo3RyQNfTfD6dO1Cxp70AKLmup89m69rNqITJGA,8731
+ aimnet/ops.py,sha256=DJ9WqlTfP7ATYKOk0DvTSEWeYjCMnhv0-8qoj5zn2KY,10523
+ aimnet/calculators/__init__.py,sha256=xn4guycZdf_FRJenbzTTQYG7BppFDZ4K671Ser4oRRc,383
+ aimnet/calculators/aimnet2ase.py,sha256=s6M-quNxISpX9TjEthH6HNIvifcIA0lNcVOECpNGxJU,3890
+ aimnet/calculators/aimnet2pysis.py,sha256=TDoDl-TExCQkg8EXfsTKjOu-z6Kd25xXqSH98gF8iGM,2876
+ aimnet/calculators/calculator.py,sha256=lZXo7YqSLJuMPruFRqqxACzAlqL14J-rVI62zrduwT4,43938
+ aimnet/calculators/model_registry.py,sha256=Vt0zW6LBOVZOSWSMFVUHhP4twsJXEYyrOPeaK1Yc07k,1840
+ aimnet/data/__init__.py,sha256=rD84W2qazjXUbdcYBuHkMmYp44fGRFMS3qnv1OKPsQs,87
+ aimnet/data/sgdataset.py,sha256=SKdSLXhtE00dKqF96FNX2sdI8wvenHMdIeyibHeGZqM,18891
+ aimnet/kernels/__init__.py,sha256=seUBZszDGFGhRfr_WI7j8cz-rtstEDVLcFEHMaqEjbw,2548
+ aimnet/kernels/conv_sv_2d_sp_wp.py,sha256=JUXej8LYa6She_TTUda-scZ_8pi7Q6VZ97Sfu5dLvVU,15329
+ aimnet/models/__init__.py,sha256=_cbeJgPf35xQW7nib1zn0VeEvgTxK4mXRkwYF472tds,361
+ aimnet/models/aimnet2.py,sha256=nmmNJTLuNxdMxzZ7Jn0fk98UesEl-MHFL0NEvnwi2eg,6971
+ aimnet/models/base.py,sha256=1Cux585dEkzkmJ9Ne8EVzv7fO-CW3ga5SSFcJ643uC8,8290
+ aimnet/models/convert.py,sha256=vxS2sIfM-wM2BOglXtiG1k0S5ayGT8CY7nR6d2y5xgU,1018
+ aimnet/models/utils.py,sha256=Itn3VEFtWc64rpEzlYys8J4EfZwzGJSUb2-7LGiVX9U,24445
+ aimnet/modules/__init__.py,sha256=I-9heqRdnXeypMUjzeQ6Y4EATM_9n4NFyCC6Om6H2Vw,216
+ aimnet/modules/aev.py,sha256=aYCootmc3mdp9rgpvwc6dtrX96bORu1xTlgPPGobOQU,8099
+ aimnet/modules/core.py,sha256=vIPyMcrS6tfQUMdlLYx1d-I8lJ56gRkofk5Wf6sEKPU,8232
+ aimnet/modules/lr.py,sha256=1jA319NUvT-Cact7Xf4vY8tZ58IpOAg-8DDciRvLkOg,23741
+ aimnet/modules/ops.py,sha256=FQkN-2yVK4pMGdUv7xYDvDB2ZnCiRNLksYmJjX-q9Hk,16175
+ aimnet/train/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ aimnet/train/calc_sae.py,sha256=h-y8jQVU6O7Rmw8xJWX3DHOpFv1gtjFwtZzZyUOhz9g,1459
+ aimnet/train/export_model.py,sha256=mku74Bqu5OkxD-iDTnb78A202Bzu7vIvfBKjyrMgKsI,9028
+ aimnet/train/loss.py,sha256=HSSagZvO_jRvouIwX0ZHLiOMf9uLgX3rSmtrZMCcKOA,3169
+ aimnet/train/metrics.py,sha256=gCYIFAgsOfmGv2Mc2GtfOcIq2TLU6q5cvQPWLuiUdSM,6814
+ aimnet/train/train.py,sha256=aOrYbujKdy_3nLUVM0wkv1E_CLRDGUJ7uIFzgmYegW4,5793
+ aimnet/train/utils.py,sha256=8-6jvNRGaoEODjLOnmxhmUSQJYiM1GfFeGFCeytA5RQ,15666
+ aimnet/dftd3_data.pt,sha256=g9HQ_2nekMs97UnIDIFReGEG8q6n6zQlq-Yl5LGQnOI,2710647
+ aimnet/calculators/model_registry.yaml,sha256=HsreBHvQHkf0ErCuzYMqq1MsPYv6g6xRhS9yzOMrKOk,3423
+ aimnet/models/aimnet2.yaml,sha256=VmI0ub7UbiJm833zwCtCKsrBIaXZKjJsCXxXxutwdLI,1062
+ aimnet/models/aimnet2_dftd3_wb97m.yaml,sha256=Gk869_2RQIXzlHwRNbfsai4u9wPX3oAyINK1I71eBFI,1219
+ aimnet/train/default_train.yaml,sha256=FAG3f3pFC2dkFUuVO_XE_UB6rFr-R3GkACddn7FH_58,4783
+ aimnet-0.1.0.dist-info/METADATA,sha256=iPzwFcL4vxHeB-zEzUINhXEsHhSGjFQZWsD9xyMGVnQ,11083
+ aimnet-0.1.0.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
+ aimnet-0.1.0.dist-info/entry_points.txt,sha256=LEicBveXz9HdsF0UiHXr3tMeHPxofm2VP5v9ISwhe9s,99
+ aimnet-0.1.0.dist-info/licenses/LICENSE,sha256=73sk-zg2yVRrOZQDeVbPlVB7ScZc1iK0kyCBMwNwQgA,1072
+ aimnet-0.1.0.dist-info/RECORD,,
aimnet-0.1.0.dist-info/WHEEL CHANGED
@@ -1,4 +1,4 @@
  Wheel-Version: 1.0
- Generator: poetry-core 1.9.1
+ Generator: hatchling 1.28.0
  Root-Is-Purelib: true
  Tag: py3-none-any
aimnet-0.1.0.dist-info/entry_points.txt ADDED
@@ -0,0 +1,3 @@
+ [console_scripts]
+ aimnet = aimnet.cli:cli
+ aimnet2pysis = aimnet.calculators.aimnet2pysis:run_pysis