aimnet 0.0.1__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,5 @@
1
1
  import logging
2
2
  import os
3
- from typing import Dict, Optional
4
3
 
5
4
  import click
6
5
  import requests
@@ -9,7 +8,7 @@ import yaml
9
8
  logging.basicConfig(level=logging.INFO)
10
9
 
11
10
 
12
- def load_model_registry(registry_file: Optional[str] = None) -> Dict[str, str]:
11
+ def load_model_registry(registry_file: str | None = None) -> dict[str, str]:
13
12
  registry_file = registry_file or os.path.join(os.path.dirname(__file__), "model_registry.yaml")
14
13
  with open(os.path.join(os.path.dirname(__file__), "model_registry.yaml")) as f:
15
14
  return yaml.load(f, Loader=yaml.SafeLoader)
@@ -43,9 +42,7 @@ def _maybe_download_asset(file: str, url: str) -> str:
43
42
 
44
43
  def get_model_path(s: str) -> str:
45
44
  # direct file path
46
- if os.path.isfile(s):
47
- print("Found model file:", s)
48
- else:
45
+ if not os.path.isfile(s):
49
46
  s = get_registry_model_path(s)
50
47
  return s
51
48
 
@@ -1,33 +1,71 @@
1
1
  # map model name to its file name and download url
2
2
  models:
3
3
  aimnet2_wb97m_d3_0:
4
- file: aimnet2_wb97m_d3_0.jpt
5
- url: https://storage.googleapis.com/aimnetcentral/AIMNet2/aimnet2_wb97m_d3_0.jpt
4
+ file: aimnet2_wb97m_d3_0.pt
5
+ url: https://storage.googleapis.com/aimnetcentral/aimnet2v2/AIMNet2/aimnet2_wb97m_d3_0.pt
6
6
  aimnet2_wb97m_d3_1:
7
- file: aimnet2_wb97m_d3_1.jpt
8
- url: https://storage.googleapis.com/aimnetcentral/AIMNet2/aimnet2_wb97m_d3_1.jpt
7
+ file: aimnet2_wb97m_d3_1.pt
8
+ url: https://storage.googleapis.com/aimnetcentral/aimnet2v2/AIMNet2/aimnet2_wb97m_d3_1.pt
9
9
  aimnet2_wb97m_d3_2:
10
- file: aimnet2_wb97m_d3_2.jpt
11
- url: https://storage.googleapis.com/aimnetcentral/AIMNet2/aimnet2_wb97m_d3_2.jpt
10
+ file: aimnet2_wb97m_d3_2.pt
11
+ url: https://storage.googleapis.com/aimnetcentral/aimnet2v2/AIMNet2/aimnet2_wb97m_d3_2.pt
12
12
  aimnet2_wb97m_d3_3:
13
- file: aimnet2_wb97m_d3_3.jpt
14
- url: https://storage.googleapis.com/aimnetcentral/AIMNet2/aimnet2_wb97m_d3_3.jpt
13
+ file: aimnet2_wb97m_d3_3.pt
14
+ url: https://storage.googleapis.com/aimnetcentral/aimnet2v2/AIMNet2/aimnet2_wb97m_d3_3.pt
15
15
  aimnet2_b973c_d3_0:
16
- file: aimnet2_b973c_d3_0.jpt
17
- url: https://storage.googleapis.com/aimnetcentral/AIMNet2/aimnet2_b973c_d3_0.jpt
16
+ file: aimnet2_b973c_d3_0.pt
17
+ url: https://storage.googleapis.com/aimnetcentral/aimnet2v2/AIMNet2/aimnet2_b973c_d3_0.pt
18
18
  aimnet2_b973c_d3_1:
19
- file: aimnet2_b973c_d3_1.jpt
20
- url: https://storage.googleapis.com/aimnetcentral/AIMNet2/aimnet2_b973c_d3_1.jpt
19
+ file: aimnet2_b973c_d3_1.pt
20
+ url: https://storage.googleapis.com/aimnetcentral/aimnet2v2/AIMNet2/aimnet2_b973c_d3_1.pt
21
21
  aimnet2_b973c_d3_2:
22
- file: aimnet2_b973c_d3_2.jpt
23
- url: https://storage.googleapis.com/aimnetcentral/AIMNet2/aimnet2_b973c_d3_2.jpt
22
+ file: aimnet2_b973c_d3_2.pt
23
+ url: https://storage.googleapis.com/aimnetcentral/aimnet2v2/AIMNet2/aimnet2_b973c_d3_2.pt
24
24
  aimnet2_b973c_d3_3:
25
- file: aimnet2_b973c_d3_3.jpt
26
- url: https://storage.googleapis.com/aimnetcentral/AIMNet2/aimnet2_b973c_d3_3.jpt
25
+ file: aimnet2_b973c_d3_3.pt
26
+ url: https://storage.googleapis.com/aimnetcentral/aimnet2v2/AIMNet2/aimnet2_b973c_d3_3.pt
27
+ aimnet2_b973c_2025_d3_0:
28
+ file: aimnet2_2025_b973c_d3_0.pt
29
+ url: https://storage.googleapis.com/aimnetcentral/aimnet2v2/AIMNet2/aimnet2_2025_b973c_d3_0.pt
30
+ aimnet2_b973c_2025_d3_1:
31
+ file: aimnet2_2025_b973c_d3_1.pt
32
+ url: https://storage.googleapis.com/aimnetcentral/aimnet2v2/AIMNet2/aimnet2_2025_b973c_d3_1.pt
33
+ aimnet2_b973c_2025_d3_2:
34
+ file: aimnet2_2025_b973c_d3_2.pt
35
+ url: https://storage.googleapis.com/aimnetcentral/aimnet2v2/AIMNet2/aimnet2_2025_b973c_d3_2.pt
36
+ aimnet2_b973c_2025_d3_3:
37
+ file: aimnet2_2025_b973c_d3_3.pt
38
+ url: https://storage.googleapis.com/aimnetcentral/aimnet2v2/AIMNet2/aimnet2_2025_b973c_d3_3.pt
39
+ aimnet2nse_0:
40
+ file: aimnet2nse_wb97m_0.pt
41
+ url: https://storage.googleapis.com/aimnetcentral/aimnet2v2/AIMNet2NSE/aimnet2nse_wb97m_0.pt
42
+ aimnet2nse_1:
43
+ file: aimnet2nse_wb97m_1.pt
44
+ url: https://storage.googleapis.com/aimnetcentral/aimnet2v2/AIMNet2NSE/aimnet2nse_wb97m_1.pt
45
+ aimnet2nse_2:
46
+ file: aimnet2nse_wb97m_2.pt
47
+ url: https://storage.googleapis.com/aimnetcentral/aimnet2v2/AIMNet2NSE/aimnet2nse_wb97m_2.pt
48
+ aimnet2nse_3:
49
+ file: aimnet2nse_wb97m_3.pt
50
+ url: https://storage.googleapis.com/aimnetcentral/aimnet2v2/AIMNet2NSE/aimnet2nse_wb97m_3.pt
51
+ aimnet2-pd_0:
52
+ file: aimnet2-pd_0.pt
53
+ url: https://storage.googleapis.com/aimnetcentral/aimnet2v2/AIMNet2Pd/aimnet2-pd_0.pt
54
+ aimnet2-pd_1:
55
+ file: aimnet2-pd_1.pt
56
+ url: https://storage.googleapis.com/aimnetcentral/aimnet2v2/AIMNet2Pd/aimnet2-pd_1.pt
57
+ aimnet2-pd_2:
58
+ file: aimnet2-pd_2.pt
59
+ url: https://storage.googleapis.com/aimnetcentral/aimnet2v2/AIMNet2Pd/aimnet2-pd_2.pt
60
+ aimnet2-pd_3:
61
+ file: aimnet2-pd_3.pt
62
+ url: https://storage.googleapis.com/aimnetcentral/aimnet2v2/AIMNet2Pd/aimnet2-pd_3.pt
27
63
 
28
64
  # map model alias to a model name defined under `models`
29
65
  aliases:
30
66
  aimnet2: aimnet2_wb97m_d3_0
31
67
  aimnet2_wb97m: aimnet2_wb97m_d3_0
32
68
  aimnet2_b973c: aimnet2_b973c_d3_0
33
- aimnet2_qr: aimnet2_qr_v0
69
+ aimnet2nse: aimnet2nse_0
70
+ aimnet2pd: aimnet2-pd_0
71
+ aimnet2_2025: aimnet2_b973c_2025_d3_0
aimnet/cli.py CHANGED
@@ -1,8 +1,8 @@
1
+ import sys
2
+
1
3
  import click
2
4
 
3
- from .train.calc_sae import calc_sae
4
- from .train.pt2jpt import jitcompile
5
- from .train.train import train
5
+ from .calculators.model_registry import clear_assets
6
6
 
7
7
 
8
8
  @click.group()
@@ -10,9 +10,65 @@ def cli():
10
10
  """AIMNet2 command line tool"""
11
11
 
12
12
 
13
- cli.add_command(train, name="train")
14
- cli.add_command(jitcompile, name="jitcompile")
15
- cli.add_command(calc_sae, name="calc_sae")
13
# Always available commands
cli.add_command(clear_assets, name="clear_model_cache")


def _register_missing_deps_stub(name: str, help_text: str) -> None:
    """Register a placeholder ``name`` command on ``cli``.

    The placeholder keeps the command visible in ``--help`` (with ``help_text``)
    but, when invoked, prints installation instructions to stderr and exits
    with status 1.

    Args:
        name: Command name to register on the ``cli`` group.
        help_text: Help string shown for the placeholder command.
    """

    @cli.command(name=name, help=help_text)
    def _stub():
        click.echo(
            "Training dependencies not installed.\nInstall with: pip install aimnet[train]",
            err=True,
        )
        sys.exit(1)


# Register convert command (doesn't need heavy training dependencies)
try:
    from .models.convert import convert_legacy_jpt

    cli.add_command(convert_legacy_jpt, name="convert")
except ImportError:
    _register_missing_deps_stub("convert", "Convert legacy JIT model to new format (requires aimnet[train])")


# Try to lazily register training commands (requires ignite, etc.)
try:
    from .train.calc_sae import calc_sae
    from .train.export_model import export_model
    from .train.train import train

    cli.add_command(train, name="train")
    cli.add_command(export_model, name="export")
    cli.add_command(calc_sae, name="calc_sae")
except ImportError:
    # If training dependencies are not available, register stub commands with helpful error messages
    _register_missing_deps_stub("train", "Train AIMNet2 models (requires aimnet[train])")
    _register_missing_deps_stub("export", "Export trained model to distributable format (requires aimnet[train])")
    _register_missing_deps_stub("calc_sae", "Calculate SAE (requires aimnet[train])")
16
72
 
17
73
 
18
74
  if __name__ == "__main__":
aimnet/config.py CHANGED
@@ -1,6 +1,7 @@
1
1
  import os
2
+ from collections.abc import Callable, Iterator
2
3
  from importlib import import_module
3
- from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Union
4
+ from typing import Any
4
5
 
5
6
  import yaml
6
7
  from jinja2 import Template
@@ -27,7 +28,7 @@ def get_module(name: str) -> Callable:
27
28
  return func # type: ignore[no-any-return]
28
29
 
29
30
 
30
- def get_init_module(name: str, args: Optional[List] = None, kwargs: Optional[Dict] = None) -> Callable:
31
+ def get_init_module(name: str, args: list | None = None, kwargs: dict | None = None) -> Callable:
31
32
  """
32
33
  Get the initialized module based on the given name, arguments, and keyword arguments.
33
34
 
@@ -46,8 +47,8 @@ def get_init_module(name: str, args: Optional[List] = None, kwargs: Optional[Dic
46
47
 
47
48
 
48
49
  def load_yaml(
49
- config: Dict[str, Any] | List | str, hyperpar: Optional[Dict[str, Any] | str] = None
50
- ) -> Dict[str, Any] | List:
50
+ config: dict[str, Any] | list | str, hyperpar: dict[str, Any] | str | None = None
51
+ ) -> dict[str, Any] | list:
51
52
  """
52
53
  Load a YAML configuration file and apply optional hyperparameters.
53
54
 
@@ -88,8 +89,8 @@ def load_yaml(
88
89
 
89
90
 
90
91
  def _iter_rec_bottomup(
91
- d: Dict[str, Any] | List,
92
- ) -> Iterator[Tuple[Dict[str, Any] | List, str | int, Any]]:
92
+ d: dict[str, Any] | list,
93
+ ) -> Iterator[tuple[dict[str, Any] | list, str | int, Any]]:
93
94
  if isinstance(d, list):
94
95
  it = enumerate(d)
95
96
  elif isinstance(d, dict):
@@ -102,9 +103,7 @@ def _iter_rec_bottomup(
102
103
  yield d, k, v
103
104
 
104
105
 
105
- def build_module(
106
- config: Union[str, Dict, List], hyperpar: Union[str, Dict, None] = None
107
- ) -> Union[List, Dict, Callable]:
106
+ def build_module(config: str | dict | list, hyperpar: str | dict | None = None) -> list | dict | Callable:
108
107
  """
109
108
  Build a module based on the provided configuration.
110
109
  Every (possibly nested) dictionary with a 'class' key will be replaced by an instance initialized with
aimnet/data/sgdataset.py CHANGED
@@ -1,6 +1,7 @@
1
1
  import os
2
+ from collections.abc import Callable, Sequence
2
3
  from glob import glob
3
- from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple
4
+ from typing import Any
4
5
 
5
6
  import h5py
6
7
  import numpy as np
@@ -23,12 +24,12 @@ class DataGroup:
23
24
 
24
25
  def __init__(
25
26
  self,
26
- data: Optional[str | Dict[str, np.ndarray] | h5py.Group] = None,
27
- keys: Optional[List[str]] = None,
28
- shard: Optional[Tuple[int, int]] = None,
27
+ data: str | dict[str, np.ndarray] | h5py.Group | None = None,
28
+ keys: list[str] | None = None,
29
+ shard: tuple[int, int] | None = None,
29
30
  ):
30
31
  # main container for data
31
- self._data: Dict[str, np.ndarray] = {}
32
+ self._data: dict[str, np.ndarray] = {}
32
33
 
33
34
  if data is None:
34
35
  data = {}
@@ -184,13 +185,13 @@ class DataGroup:
184
185
  class SizeGroupedDataset:
185
186
  def __init__(
186
187
  self,
187
- data: Optional[str | List[str] | Dict[int, Dict[str, np.ndarray]] | Dict[int, DataGroup]] = None,
188
- keys: Optional[List[str]] = None,
189
- shard: Optional[Tuple[int, int]] = None,
188
+ data: str | list[str] | dict[int, dict[str, np.ndarray]] | dict[int, DataGroup] | None = None,
189
+ keys: list[str] | None = None,
190
+ shard: tuple[int, int] | None = None,
190
191
  ):
191
192
  # main containers
192
- self._data: Dict[int, DataGroup] = {}
193
- self._meta: Dict[str, str] = {}
193
+ self._data: dict[int, DataGroup] = {}
194
+ self._meta: dict[str, str] = {}
194
195
 
195
196
  # load data
196
197
  if isinstance(data, str):
@@ -203,17 +204,17 @@ class SizeGroupedDataset:
203
204
  elif isinstance(data, dict):
204
205
  self.load_dict(data)
205
206
  self.loader_mode = False
206
- self.x: List[str] = []
207
- self.y: List[str] = []
207
+ self.x: list[str] = []
208
+ self.y: list[str] = []
208
209
 
209
- def load_datadir(self, path, keys=None, shard: Optional[Tuple[int, int]] = None):
210
+ def load_datadir(self, path, keys=None, shard: tuple[int, int] | None = None):
210
211
  if not os.path.isdir(path):
211
212
  raise FileNotFoundError(f"{path} does not exist or not a directory.")
212
213
  for f in glob(os.path.join(path, "???.npz")):
213
214
  k = int(os.path.basename(f)[:3])
214
215
  self[k] = DataGroup(f, keys=keys, shard=shard)
215
216
 
216
- def load_files(self, files, keys=None, shard: Optional[Tuple[int, int]] = None):
217
+ def load_files(self, files, keys=None, shard: tuple[int, int] | None = None):
217
218
  for fil in files:
218
219
  if not os.path.isfile(fil):
219
220
  raise FileNotFoundError(f"{fil} does not exist or not a file.")
@@ -224,27 +225,27 @@ class SizeGroupedDataset:
224
225
  for k, v in data.items():
225
226
  self[k] = DataGroup(v, keys=keys)
226
227
 
227
- def load_h5(self, data, keys=None, shard: Optional[Tuple[int, int]] = None):
228
+ def load_h5(self, data, keys=None, shard: tuple[int, int] | None = None):
228
229
  with h5py.File(data, "r") as f:
229
230
  for k, g in f.items():
230
231
  k = int(k)
231
232
  self[k] = DataGroup(g, keys=keys, shard=shard)
232
233
  self._meta = dict(f.attrs) # type: ignore[attr-defined]
233
234
 
234
- def keys(self) -> List[int]:
235
+ def keys(self) -> list[int]:
235
236
  return sorted(self._data.keys())
236
237
 
237
- def values(self) -> List:
238
+ def values(self) -> list:
238
239
  return [self[k] for k in self.keys()]
239
240
 
240
- def items(self) -> List[Tuple[int, Any]]:
241
+ def items(self) -> list[tuple[int, Any]]:
241
242
  return [(k, self[k]) for k in self.keys()]
242
243
 
243
244
  def datakeys(self):
244
245
  return next(iter(self._data.values())).keys() if self._data else set()
245
246
 
246
247
  @property
247
- def groups(self) -> List[DataGroup]:
248
+ def groups(self) -> list[DataGroup]:
248
249
  return self.values()
249
250
 
250
251
  def __len__(self):
@@ -259,7 +260,7 @@ class SizeGroupedDataset:
259
260
  raise ValueError("Wrong set of data keys.")
260
261
  self._data[key] = value
261
262
 
262
- def __getitem__(self, item: int | Tuple[int, Sequence]) -> Dict | Tuple[Dict, Dict]:
263
+ def __getitem__(self, item: int | tuple[int, Sequence]) -> dict | tuple[dict, dict]:
263
264
  if isinstance(item, int):
264
265
  ret = self._data[item]
265
266
  else:
@@ -332,7 +333,7 @@ class SizeGroupedDataset:
332
333
  for v in self.values():
333
334
  v.shuffle(seed)
334
335
 
335
- def save(self, dirname, namemap_fn: Optional[Callable] = None, compress: bool = False):
336
+ def save(self, dirname, namemap_fn: Callable | None = None, compress: bool = False):
336
337
  os.makedirs(dirname, exist_ok=True)
337
338
  if namemap_fn is None:
338
339
  namemap_fn = lambda x: f"{x:03d}.npz"
@@ -441,7 +442,7 @@ class SizeGroupedDataset:
441
442
  for g in self.values():
442
443
  yield from g.iter_batched(batch_size, keys)
443
444
 
444
- def get_loader(self, sampler, x: List[str], y: Optional[List[str]] = None, **loader_kwargs):
445
+ def get_loader(self, sampler, x: list[str], y: list[str] | None = None, **loader_kwargs):
445
446
  self.loader_mode = True
446
447
  self.x = x
447
448
  self.y = y or []
@@ -0,0 +1,66 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: MIT
3
+ #
4
+ # Permission is hereby granted, free of charge, to any person obtaining a
5
+ # copy of this software and associated documentation files (the "Software"),
6
+ # to deal in the Software without restriction, including without limitation
7
+ # the rights to use, copy, modify, merge, publish, distribute, sublicense,
8
+ # and/or sell copies of the Software, and to permit persons to whom the
9
+ # Software is furnished to do so, subject to the following conditions:
10
+ #
11
+ # The above copyright notice and this permission notice shall be included in
12
+ # all copies or substantial portions of the Software.
13
+ #
14
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
17
+ # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
19
+ # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
20
+ # DEALINGS IN THE SOFTWARE.
21
+ """AIMNet Kernels Package - GPU-accelerated custom operations."""
22
+
23
+ import torch
24
+
25
+ from .conv_sv_2d_sp_wp import conv_sv_2d_sp
26
+
27
+
28
def load_ops():
    """
    Import the custom-op modules and report which ops are visible.

    Call this before using any of the custom kernels. The two imports below
    are performed only for their side effect (per the original comments,
    importing them triggers registration with PyTorch); afterwards the
    ``torch.ops.aimnet`` namespace is probed for the expected op names.

    Returns:
        list: Qualified names (``"aimnet::<op>"``) of the probed ops that are
        present, in the order they are checked.
    """
    # Import custom ops module to trigger registration
    from aimnet.modules import ops as _ops  # noqa: F401

    # Import warp kernels to trigger registration
    from . import conv_sv_2d_sp_wp  # noqa: F401

    # Nothing to report when the namespace itself is missing
    if not hasattr(torch.ops, "aimnet"):
        return []

    expected = (
        "conv_sv_2d_sp_fwd",
        "conv_sv_2d_sp_bwd",
        "conv_sv_2d_sp_bwd_bwd",
        "dftd3_fwd",
    )
    # Verify ops are available
    return [f"aimnet::{op_name}" for op_name in expected if hasattr(torch.ops.aimnet, op_name)]
61
+
62
+
63
+ __all__ = [
64
+ "conv_sv_2d_sp",
65
+ "load_ops",
66
+ ]