aimnet 0.0.1__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aimnet/__init__.py +7 -0
- aimnet/base.py +24 -8
- aimnet/calculators/__init__.py +4 -4
- aimnet/calculators/aimnet2ase.py +19 -6
- aimnet/calculators/calculator.py +868 -108
- aimnet/calculators/model_registry.py +2 -5
- aimnet/calculators/model_registry.yaml +55 -17
- aimnet/cli.py +62 -6
- aimnet/config.py +8 -9
- aimnet/data/sgdataset.py +23 -22
- aimnet/kernels/__init__.py +66 -0
- aimnet/kernels/conv_sv_2d_sp_wp.py +478 -0
- aimnet/models/__init__.py +13 -1
- aimnet/models/aimnet2.py +19 -22
- aimnet/models/base.py +183 -15
- aimnet/models/convert.py +30 -0
- aimnet/models/utils.py +735 -0
- aimnet/modules/__init__.py +1 -1
- aimnet/modules/aev.py +49 -48
- aimnet/modules/core.py +14 -13
- aimnet/modules/lr.py +520 -115
- aimnet/modules/ops.py +537 -0
- aimnet/nbops.py +105 -15
- aimnet/ops.py +90 -18
- aimnet/train/export_model.py +226 -0
- aimnet/train/loss.py +7 -7
- aimnet/train/metrics.py +5 -6
- aimnet/train/train.py +4 -1
- aimnet/train/utils.py +42 -13
- aimnet-0.1.0.dist-info/METADATA +308 -0
- aimnet-0.1.0.dist-info/RECORD +43 -0
- {aimnet-0.0.1.dist-info → aimnet-0.1.0.dist-info}/WHEEL +1 -1
- aimnet-0.1.0.dist-info/entry_points.txt +3 -0
- aimnet/calculators/nb_kernel_cpu.py +0 -222
- aimnet/calculators/nb_kernel_cuda.py +0 -217
- aimnet/calculators/nbmat.py +0 -220
- aimnet/train/pt2jpt.py +0 -81
- aimnet-0.0.1.dist-info/METADATA +0 -78
- aimnet-0.0.1.dist-info/RECORD +0 -41
- aimnet-0.0.1.dist-info/entry_points.txt +0 -5
- {aimnet-0.0.1.dist-info → aimnet-0.1.0.dist-info/licenses}/LICENSE +0 -0
|
@@ -1,6 +1,5 @@
|
|
|
1
1
|
import logging
|
|
2
2
|
import os
|
|
3
|
-
from typing import Dict, Optional
|
|
4
3
|
|
|
5
4
|
import click
|
|
6
5
|
import requests
|
|
@@ -9,7 +8,7 @@ import yaml
|
|
|
9
8
|
logging.basicConfig(level=logging.INFO)
|
|
10
9
|
|
|
11
10
|
|
|
12
|
-
def load_model_registry(registry_file:
|
|
11
|
+
def load_model_registry(registry_file: str | None = None) -> dict[str, str]:
    """Load the model registry from a YAML file.

    Args:
        registry_file: Path to a registry YAML file. When omitted (or empty),
            the ``model_registry.yaml`` that sits next to this module is used.

    Returns:
        The parsed YAML document, loaded with ``yaml.SafeLoader``.

    Raises:
        OSError: If the registry file cannot be opened.
    """
    registry_file = registry_file or os.path.join(os.path.dirname(__file__), "model_registry.yaml")
    # Open the *resolved* path. The previous code re-built the packaged default
    # path here, so a caller-supplied ``registry_file`` was silently ignored.
    with open(registry_file, encoding="utf-8") as f:
        return yaml.load(f, Loader=yaml.SafeLoader)
|
|
@@ -43,9 +42,7 @@ def _maybe_download_asset(file: str, url: str) -> str:
|
|
|
43
42
|
|
|
44
43
|
def get_model_path(s: str) -> str:
    """Resolve ``s`` to a model file path.

    If ``s`` names an existing file it is returned unchanged; otherwise it is
    handed to ``get_registry_model_path`` and that function's result is
    returned.
    """
    if os.path.isfile(s):
        # direct file path
        return s
    return get_registry_model_path(s)
|
|
51
48
|
|
|
@@ -1,33 +1,71 @@
|
|
|
1
1
|
# map file name to url
|
|
2
2
|
models:
|
|
3
3
|
aimnet2_wb97m_d3_0:
|
|
4
|
-
file: aimnet2_wb97m_d3_0.
|
|
5
|
-
url: https://storage.googleapis.com/aimnetcentral/AIMNet2/aimnet2_wb97m_d3_0.
|
|
4
|
+
file: aimnet2_wb97m_d3_0.pt
|
|
5
|
+
url: https://storage.googleapis.com/aimnetcentral/aimnet2v2/AIMNet2/aimnet2_wb97m_d3_0.pt
|
|
6
6
|
aimnet2_wb97m_d3_1:
|
|
7
|
-
file: aimnet2_wb97m_d3_1.
|
|
8
|
-
url: https://storage.googleapis.com/aimnetcentral/AIMNet2/aimnet2_wb97m_d3_1.
|
|
7
|
+
file: aimnet2_wb97m_d3_1.pt
|
|
8
|
+
url: https://storage.googleapis.com/aimnetcentral/aimnet2v2/AIMNet2/aimnet2_wb97m_d3_1.pt
|
|
9
9
|
aimnet2_wb97m_d3_2:
|
|
10
|
-
file: aimnet2_wb97m_d3_2.
|
|
11
|
-
url: https://storage.googleapis.com/aimnetcentral/AIMNet2/aimnet2_wb97m_d3_2.
|
|
10
|
+
file: aimnet2_wb97m_d3_2.pt
|
|
11
|
+
url: https://storage.googleapis.com/aimnetcentral/aimnet2v2/AIMNet2/aimnet2_wb97m_d3_2.pt
|
|
12
12
|
aimnet2_wb97m_d3_3:
|
|
13
|
-
file: aimnet2_wb97m_d3_3.
|
|
14
|
-
url: https://storage.googleapis.com/aimnetcentral/AIMNet2/aimnet2_wb97m_d3_3.
|
|
13
|
+
file: aimnet2_wb97m_d3_3.pt
|
|
14
|
+
url: https://storage.googleapis.com/aimnetcentral/aimnet2v2/AIMNet2/aimnet2_wb97m_d3_3.pt
|
|
15
15
|
aimnet2_b973c_d3_0:
|
|
16
|
-
file: aimnet2_b973c_d3_0.
|
|
17
|
-
url: https://storage.googleapis.com/aimnetcentral/AIMNet2/aimnet2_b973c_d3_0.
|
|
16
|
+
file: aimnet2_b973c_d3_0.pt
|
|
17
|
+
url: https://storage.googleapis.com/aimnetcentral/aimnet2v2/AIMNet2/aimnet2_b973c_d3_0.pt
|
|
18
18
|
aimnet2_b973c_d3_1:
|
|
19
|
-
file: aimnet2_b973c_d3_1.
|
|
20
|
-
url: https://storage.googleapis.com/aimnetcentral/AIMNet2/aimnet2_b973c_d3_1.
|
|
19
|
+
file: aimnet2_b973c_d3_1.pt
|
|
20
|
+
url: https://storage.googleapis.com/aimnetcentral/aimnet2v2/AIMNet2/aimnet2_b973c_d3_1.pt
|
|
21
21
|
aimnet2_b973c_d3_2:
|
|
22
|
-
file: aimnet2_b973c_d3_2.
|
|
23
|
-
url: https://storage.googleapis.com/aimnetcentral/AIMNet2/aimnet2_b973c_d3_2.
|
|
22
|
+
file: aimnet2_b973c_d3_2.pt
|
|
23
|
+
url: https://storage.googleapis.com/aimnetcentral/aimnet2v2/AIMNet2/aimnet2_b973c_d3_2.pt
|
|
24
24
|
aimnet2_b973c_d3_3:
|
|
25
|
-
file: aimnet2_b973c_d3_3.
|
|
26
|
-
url: https://storage.googleapis.com/aimnetcentral/AIMNet2/aimnet2_b973c_d3_3.
|
|
25
|
+
file: aimnet2_b973c_d3_3.pt
|
|
26
|
+
url: https://storage.googleapis.com/aimnetcentral/aimnet2v2/AIMNet2/aimnet2_b973c_d3_3.pt
|
|
27
|
+
aimnet2_b973c_2025_d3_0:
|
|
28
|
+
file: aimnet2_2025_b973c_d3_0.pt
|
|
29
|
+
url: https://storage.googleapis.com/aimnetcentral/aimnet2v2/AIMNet2/aimnet2_2025_b973c_d3_0.pt
|
|
30
|
+
aimnet2_b973c_2025_d3_1:
|
|
31
|
+
file: aimnet2_2025_b973c_d3_1.pt
|
|
32
|
+
url: https://storage.googleapis.com/aimnetcentral/aimnet2v2/AIMNet2/aimnet2_2025_b973c_d3_1.pt
|
|
33
|
+
aimnet2_b973c_2025_d3_2:
|
|
34
|
+
file: aimnet2_2025_b973c_d3_2.pt
|
|
35
|
+
url: https://storage.googleapis.com/aimnetcentral/aimnet2v2/AIMNet2/aimnet2_2025_b973c_d3_2.pt
|
|
36
|
+
aimnet2_b973c_2025_d3_3:
|
|
37
|
+
file: aimnet2_2025_b973c_d3_3.pt
|
|
38
|
+
url: https://storage.googleapis.com/aimnetcentral/aimnet2v2/AIMNet2/aimnet2_2025_b973c_d3_3.pt
|
|
39
|
+
aimnet2nse_0:
|
|
40
|
+
file: aimnet2nse_wb97m_0.pt
|
|
41
|
+
url: https://storage.googleapis.com/aimnetcentral/aimnet2v2/AIMNet2NSE/aimnet2nse_wb97m_0.pt
|
|
42
|
+
aimnet2nse_1:
|
|
43
|
+
file: aimnet2nse_wb97m_1.pt
|
|
44
|
+
url: https://storage.googleapis.com/aimnetcentral/aimnet2v2/AIMNet2NSE/aimnet2nse_wb97m_1.pt
|
|
45
|
+
aimnet2nse_2:
|
|
46
|
+
file: aimnet2nse_wb97m_2.pt
|
|
47
|
+
url: https://storage.googleapis.com/aimnetcentral/aimnet2v2/AIMNet2NSE/aimnet2nse_wb97m_2.pt
|
|
48
|
+
aimnet2nse_3:
|
|
49
|
+
file: aimnet2nse_wb97m_3.pt
|
|
50
|
+
url: https://storage.googleapis.com/aimnetcentral/aimnet2v2/AIMNet2NSE/aimnet2nse_wb97m_3.pt
|
|
51
|
+
aimnet2-pd_0:
|
|
52
|
+
file: aimnet2-pd_0.pt
|
|
53
|
+
url: https://storage.googleapis.com/aimnetcentral/aimnet2v2/AIMNet2Pd/aimnet2-pd_0.pt
|
|
54
|
+
aimnet2-pd_1:
|
|
55
|
+
file: aimnet2-pd_1.pt
|
|
56
|
+
url: https://storage.googleapis.com/aimnetcentral/aimnet2v2/AIMNet2Pd/aimnet2-pd_1.pt
|
|
57
|
+
aimnet2-pd_2:
|
|
58
|
+
file: aimnet2-pd_2.pt
|
|
59
|
+
url: https://storage.googleapis.com/aimnetcentral/aimnet2v2/AIMNet2Pd/aimnet2-pd_2.pt
|
|
60
|
+
aimnet2-pd_3:
|
|
61
|
+
file: aimnet2-pd_3.pt
|
|
62
|
+
url: https://storage.googleapis.com/aimnetcentral/aimnet2v2/AIMNet2Pd/aimnet2-pd_3.pt
|
|
27
63
|
|
|
28
64
|
# map model alias to file name
|
|
29
65
|
aliases:
|
|
30
66
|
aimnet2: aimnet2_wb97m_d3_0
|
|
31
67
|
aimnet2_wb97m: aimnet2_wb97m_d3_0
|
|
32
68
|
aimnet2_b973c: aimnet2_b973c_d3_0
|
|
33
|
-
|
|
69
|
+
aimnet2nse: aimnet2nse_0
|
|
70
|
+
aimnet2pd: aimnet2-pd_0
|
|
71
|
+
aimnet2_2025: aimnet2_b973c_2025_d3_0
|
aimnet/cli.py
CHANGED
|
@@ -1,8 +1,8 @@
|
|
|
1
|
+
import sys
|
|
2
|
+
|
|
1
3
|
import click
|
|
2
4
|
|
|
3
|
-
from .
|
|
4
|
-
from .train.pt2jpt import jitcompile
|
|
5
|
-
from .train.train import train
|
|
5
|
+
from .calculators.model_registry import clear_assets
|
|
6
6
|
|
|
7
7
|
|
|
8
8
|
@click.group()
|
|
@@ -10,9 +10,65 @@ def cli():
|
|
|
10
10
|
"""AIMNet2 command line tool"""
|
|
11
11
|
|
|
12
12
|
|
|
13
|
-
|
|
14
|
-
cli.add_command(
|
|
15
|
-
|
|
13
|
+
# Always available commands
cli.add_command(clear_assets, name="clear_model_cache")

# Message printed by every placeholder command registered below.
_MISSING_TRAIN_DEPS_MESSAGE = "Training dependencies not installed.\nInstall with: pip install aimnet[train]"


def _register_missing_deps_stub(name: str, help_text: str) -> None:
    """Register a placeholder ``name`` command on ``cli``.

    The placeholder keeps the command visible in ``--help`` (with
    ``help_text``), but when invoked it only prints the install hint to
    stderr and exits with status 1.
    """

    @cli.command(name=name, help=help_text)
    def _stub():
        click.echo(_MISSING_TRAIN_DEPS_MESSAGE, err=True)
        sys.exit(1)


# Register convert command (doesn't need heavy training dependencies)
try:
    from .models.convert import convert_legacy_jpt
except ImportError:
    _register_missing_deps_stub(
        "convert",
        "Convert legacy JIT model to new format (requires aimnet[train])",
    )
else:
    # Registration lives in ``else`` so the ``try`` guards only the import.
    cli.add_command(convert_legacy_jpt, name="convert")


# Try to lazily register training commands (requires ignite, etc.)
try:
    from .train.calc_sae import calc_sae
    from .train.export_model import export_model
    from .train.train import train
except ImportError:
    # If training dependencies are not available, register stub commands with helpful error messages
    _register_missing_deps_stub("train", "Train AIMNet2 models (requires aimnet[train])")
    _register_missing_deps_stub(
        "export",
        "Export trained model to distributable format (requires aimnet[train])",
    )
    _register_missing_deps_stub("calc_sae", "Calculate SAE (requires aimnet[train])")
else:
    cli.add_command(train, name="train")
    cli.add_command(export_model, name="export")
    cli.add_command(calc_sae, name="calc_sae")
|
|
16
72
|
|
|
17
73
|
|
|
18
74
|
if __name__ == "__main__":
|
aimnet/config.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
import os
|
|
2
|
+
from collections.abc import Callable, Iterator
|
|
2
3
|
from importlib import import_module
|
|
3
|
-
from typing import Any
|
|
4
|
+
from typing import Any
|
|
4
5
|
|
|
5
6
|
import yaml
|
|
6
7
|
from jinja2 import Template
|
|
@@ -27,7 +28,7 @@ def get_module(name: str) -> Callable:
|
|
|
27
28
|
return func # type: ignore[no-any-return]
|
|
28
29
|
|
|
29
30
|
|
|
30
|
-
def get_init_module(name: str, args:
|
|
31
|
+
def get_init_module(name: str, args: list | None = None, kwargs: dict | None = None) -> Callable:
|
|
31
32
|
"""
|
|
32
33
|
Get the initialized module based on the given name, arguments, and keyword arguments.
|
|
33
34
|
|
|
@@ -46,8 +47,8 @@ def get_init_module(name: str, args: Optional[List] = None, kwargs: Optional[Dic
|
|
|
46
47
|
|
|
47
48
|
|
|
48
49
|
def load_yaml(
|
|
49
|
-
config:
|
|
50
|
-
) ->
|
|
50
|
+
config: dict[str, Any] | list | str, hyperpar: dict[str, Any] | str | None = None
|
|
51
|
+
) -> dict[str, Any] | list:
|
|
51
52
|
"""
|
|
52
53
|
Load a YAML configuration file and apply optional hyperparameters.
|
|
53
54
|
|
|
@@ -88,8 +89,8 @@ def load_yaml(
|
|
|
88
89
|
|
|
89
90
|
|
|
90
91
|
def _iter_rec_bottomup(
|
|
91
|
-
d:
|
|
92
|
-
) -> Iterator[
|
|
92
|
+
d: dict[str, Any] | list,
|
|
93
|
+
) -> Iterator[tuple[dict[str, Any] | list, str | int, Any]]:
|
|
93
94
|
if isinstance(d, list):
|
|
94
95
|
it = enumerate(d)
|
|
95
96
|
elif isinstance(d, dict):
|
|
@@ -102,9 +103,7 @@ def _iter_rec_bottomup(
|
|
|
102
103
|
yield d, k, v
|
|
103
104
|
|
|
104
105
|
|
|
105
|
-
def build_module(
|
|
106
|
-
config: Union[str, Dict, List], hyperpar: Union[str, Dict, None] = None
|
|
107
|
-
) -> Union[List, Dict, Callable]:
|
|
106
|
+
def build_module(config: str | dict | list, hyperpar: str | dict | None = None) -> list | dict | Callable:
|
|
108
107
|
"""
|
|
109
108
|
Build a module based on the provided configuration.
|
|
110
109
|
Every (possibly nested) dictionary with a 'class' key will be replaced by an instance initialized with
|
aimnet/data/sgdataset.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
import os
|
|
2
|
+
from collections.abc import Callable, Sequence
|
|
2
3
|
from glob import glob
|
|
3
|
-
from typing import Any
|
|
4
|
+
from typing import Any
|
|
4
5
|
|
|
5
6
|
import h5py
|
|
6
7
|
import numpy as np
|
|
@@ -23,12 +24,12 @@ class DataGroup:
|
|
|
23
24
|
|
|
24
25
|
def __init__(
|
|
25
26
|
self,
|
|
26
|
-
data:
|
|
27
|
-
keys:
|
|
28
|
-
shard:
|
|
27
|
+
data: str | dict[str, np.ndarray] | h5py.Group | None = None,
|
|
28
|
+
keys: list[str] | None = None,
|
|
29
|
+
shard: tuple[int, int] | None = None,
|
|
29
30
|
):
|
|
30
31
|
# main container for data
|
|
31
|
-
self._data:
|
|
32
|
+
self._data: dict[str, np.ndarray] = {}
|
|
32
33
|
|
|
33
34
|
if data is None:
|
|
34
35
|
data = {}
|
|
@@ -184,13 +185,13 @@ class DataGroup:
|
|
|
184
185
|
class SizeGroupedDataset:
|
|
185
186
|
def __init__(
|
|
186
187
|
self,
|
|
187
|
-
data:
|
|
188
|
-
keys:
|
|
189
|
-
shard:
|
|
188
|
+
data: str | list[str] | dict[int, dict[str, np.ndarray]] | dict[int, DataGroup] | None = None,
|
|
189
|
+
keys: list[str] | None = None,
|
|
190
|
+
shard: tuple[int, int] | None = None,
|
|
190
191
|
):
|
|
191
192
|
# main containers
|
|
192
|
-
self._data:
|
|
193
|
-
self._meta:
|
|
193
|
+
self._data: dict[int, DataGroup] = {}
|
|
194
|
+
self._meta: dict[str, str] = {}
|
|
194
195
|
|
|
195
196
|
# load data
|
|
196
197
|
if isinstance(data, str):
|
|
@@ -203,17 +204,17 @@ class SizeGroupedDataset:
|
|
|
203
204
|
elif isinstance(data, dict):
|
|
204
205
|
self.load_dict(data)
|
|
205
206
|
self.loader_mode = False
|
|
206
|
-
self.x:
|
|
207
|
-
self.y:
|
|
207
|
+
self.x: list[str] = []
|
|
208
|
+
self.y: list[str] = []
|
|
208
209
|
|
|
209
|
-
def load_datadir(self, path, keys=None, shard:
|
|
210
|
+
def load_datadir(self, path, keys=None, shard: tuple[int, int] | None = None):
|
|
210
211
|
if not os.path.isdir(path):
|
|
211
212
|
raise FileNotFoundError(f"{path} does not exist or not a directory.")
|
|
212
213
|
for f in glob(os.path.join(path, "???.npz")):
|
|
213
214
|
k = int(os.path.basename(f)[:3])
|
|
214
215
|
self[k] = DataGroup(f, keys=keys, shard=shard)
|
|
215
216
|
|
|
216
|
-
def load_files(self, files, keys=None, shard:
|
|
217
|
+
def load_files(self, files, keys=None, shard: tuple[int, int] | None = None):
|
|
217
218
|
for fil in files:
|
|
218
219
|
if not os.path.isfile(fil):
|
|
219
220
|
raise FileNotFoundError(f"{fil} does not exist or not a file.")
|
|
@@ -224,27 +225,27 @@ class SizeGroupedDataset:
|
|
|
224
225
|
for k, v in data.items():
|
|
225
226
|
self[k] = DataGroup(v, keys=keys)
|
|
226
227
|
|
|
227
|
-
def load_h5(self, data, keys=None, shard:
|
|
228
|
+
def load_h5(self, data, keys=None, shard: tuple[int, int] | None = None):
|
|
228
229
|
with h5py.File(data, "r") as f:
|
|
229
230
|
for k, g in f.items():
|
|
230
231
|
k = int(k)
|
|
231
232
|
self[k] = DataGroup(g, keys=keys, shard=shard)
|
|
232
233
|
self._meta = dict(f.attrs) # type: ignore[attr-defined]
|
|
233
234
|
|
|
234
|
-
def keys(self) ->
|
|
235
|
+
def keys(self) -> list[int]:
|
|
235
236
|
return sorted(self._data.keys())
|
|
236
237
|
|
|
237
|
-
def values(self) ->
|
|
238
|
+
def values(self) -> list:
|
|
238
239
|
return [self[k] for k in self.keys()]
|
|
239
240
|
|
|
240
|
-
def items(self) ->
|
|
241
|
+
def items(self) -> list[tuple[int, Any]]:
|
|
241
242
|
return [(k, self[k]) for k in self.keys()]
|
|
242
243
|
|
|
243
244
|
def datakeys(self):
|
|
244
245
|
return next(iter(self._data.values())).keys() if self._data else set()
|
|
245
246
|
|
|
246
247
|
@property
|
|
247
|
-
def groups(self) ->
|
|
248
|
+
def groups(self) -> list[DataGroup]:
|
|
248
249
|
return self.values()
|
|
249
250
|
|
|
250
251
|
def __len__(self):
|
|
@@ -259,7 +260,7 @@ class SizeGroupedDataset:
|
|
|
259
260
|
raise ValueError("Wrong set of data keys.")
|
|
260
261
|
self._data[key] = value
|
|
261
262
|
|
|
262
|
-
def __getitem__(self, item: int |
|
|
263
|
+
def __getitem__(self, item: int | tuple[int, Sequence]) -> dict | tuple[dict, dict]:
|
|
263
264
|
if isinstance(item, int):
|
|
264
265
|
ret = self._data[item]
|
|
265
266
|
else:
|
|
@@ -332,7 +333,7 @@ class SizeGroupedDataset:
|
|
|
332
333
|
for v in self.values():
|
|
333
334
|
v.shuffle(seed)
|
|
334
335
|
|
|
335
|
-
def save(self, dirname, namemap_fn:
|
|
336
|
+
def save(self, dirname, namemap_fn: Callable | None = None, compress: bool = False):
|
|
336
337
|
os.makedirs(dirname, exist_ok=True)
|
|
337
338
|
if namemap_fn is None:
|
|
338
339
|
namemap_fn = lambda x: f"{x:03d}.npz"
|
|
@@ -441,7 +442,7 @@ class SizeGroupedDataset:
|
|
|
441
442
|
for g in self.values():
|
|
442
443
|
yield from g.iter_batched(batch_size, keys)
|
|
443
444
|
|
|
444
|
-
def get_loader(self, sampler, x:
|
|
445
|
+
def get_loader(self, sampler, x: list[str], y: list[str] | None = None, **loader_kwargs):
|
|
445
446
|
self.loader_mode = True
|
|
446
447
|
self.x = x
|
|
447
448
|
self.y = y or []
|
|
@@ -0,0 +1,66 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: MIT
|
|
3
|
+
#
|
|
4
|
+
# Permission is hereby granted, free of charge, to any person obtaining a
|
|
5
|
+
# copy of this software and associated documentation files (the "Software"),
|
|
6
|
+
# to deal in the Software without restriction, including without limitation
|
|
7
|
+
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
|
8
|
+
# and/or sell copies of the Software, and to permit persons to whom the
|
|
9
|
+
# Software is furnished to do so, subject to the following conditions:
|
|
10
|
+
#
|
|
11
|
+
# The above copyright notice and this permission notice shall be included in
|
|
12
|
+
# all copies or substantial portions of the Software.
|
|
13
|
+
#
|
|
14
|
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
15
|
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
16
|
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
|
|
17
|
+
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
18
|
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
|
19
|
+
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
|
20
|
+
# DEALINGS IN THE SOFTWARE.
|
|
21
|
+
"""AIMNet Kernels Package - GPU-accelerated custom operations."""
|
|
22
|
+
|
|
23
|
+
import torch
|
|
24
|
+
|
|
25
|
+
from .conv_sv_2d_sp_wp import conv_sv_2d_sp
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def load_ops():
    """
    Import the custom-op modules and report which ops are visible.

    The two imports below are performed for their side effects (per the
    original comments, importing them triggers op registration). Afterwards
    ``torch.ops.aimnet`` is probed for a fixed set of op names.

    Call this before using any of the custom kernels.

    Returns:
        list: Qualified names (``"aimnet::<op>"``) of the probed ops that
        were found, in probe order. Empty when ``torch.ops`` exposes no
        ``aimnet`` attribute.
    """
    # Import custom ops module to trigger registration
    from aimnet.modules import ops as _ops  # noqa: F401

    # Import warp kernels to trigger registration
    from . import conv_sv_2d_sp_wp  # noqa: F401

    # Verify ops are available
    if not hasattr(torch.ops, "aimnet"):
        return []

    # (attribute looked up on torch.ops.aimnet, name reported to the caller)
    probes = (
        ("conv_sv_2d_sp_fwd", "aimnet::conv_sv_2d_sp_fwd"),
        ("conv_sv_2d_sp_bwd", "aimnet::conv_sv_2d_sp_bwd"),
        ("conv_sv_2d_sp_bwd_bwd", "aimnet::conv_sv_2d_sp_bwd_bwd"),
        ("dftd3_fwd", "aimnet::dftd3_fwd"),
    )
    return [qualified for attr, qualified in probes if hasattr(torch.ops.aimnet, attr)]
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
__all__ = [
|
|
64
|
+
"conv_sv_2d_sp",
|
|
65
|
+
"load_ops",
|
|
66
|
+
]
|