aimnet 0.0.1__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aimnet/__init__.py +7 -0
- aimnet/base.py +24 -8
- aimnet/calculators/__init__.py +4 -4
- aimnet/calculators/aimnet2ase.py +19 -6
- aimnet/calculators/calculator.py +868 -108
- aimnet/calculators/model_registry.py +2 -5
- aimnet/calculators/model_registry.yaml +55 -17
- aimnet/cli.py +62 -6
- aimnet/config.py +8 -9
- aimnet/data/sgdataset.py +23 -22
- aimnet/kernels/__init__.py +66 -0
- aimnet/kernels/conv_sv_2d_sp_wp.py +478 -0
- aimnet/models/__init__.py +13 -1
- aimnet/models/aimnet2.py +19 -22
- aimnet/models/base.py +183 -15
- aimnet/models/convert.py +30 -0
- aimnet/models/utils.py +735 -0
- aimnet/modules/__init__.py +1 -1
- aimnet/modules/aev.py +49 -48
- aimnet/modules/core.py +14 -13
- aimnet/modules/lr.py +520 -115
- aimnet/modules/ops.py +537 -0
- aimnet/nbops.py +105 -15
- aimnet/ops.py +90 -18
- aimnet/train/export_model.py +226 -0
- aimnet/train/loss.py +7 -7
- aimnet/train/metrics.py +5 -6
- aimnet/train/train.py +4 -1
- aimnet/train/utils.py +42 -13
- aimnet-0.1.0.dist-info/METADATA +308 -0
- aimnet-0.1.0.dist-info/RECORD +43 -0
- {aimnet-0.0.1.dist-info → aimnet-0.1.0.dist-info}/WHEEL +1 -1
- aimnet-0.1.0.dist-info/entry_points.txt +3 -0
- aimnet/calculators/nb_kernel_cpu.py +0 -222
- aimnet/calculators/nb_kernel_cuda.py +0 -217
- aimnet/calculators/nbmat.py +0 -220
- aimnet/train/pt2jpt.py +0 -81
- aimnet-0.0.1.dist-info/METADATA +0 -78
- aimnet-0.0.1.dist-info/RECORD +0 -41
- aimnet-0.0.1.dist-info/entry_points.txt +0 -5
- {aimnet-0.0.1.dist-info → aimnet-0.1.0.dist-info/licenses}/LICENSE +0 -0
aimnet/models/base.py
CHANGED
|
@@ -1,12 +1,151 @@
|
|
|
1
|
-
from
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import contextlib
|
|
4
|
+
import warnings
|
|
5
|
+
from typing import ClassVar, Final, NotRequired, TypedDict
|
|
2
6
|
|
|
3
7
|
import torch
|
|
4
8
|
from torch import Tensor, nn
|
|
5
9
|
|
|
6
10
|
from aimnet import nbops
|
|
11
|
+
from aimnet.config import build_module
|
|
12
|
+
from aimnet.models.utils import (
|
|
13
|
+
extract_d3_params,
|
|
14
|
+
extract_species,
|
|
15
|
+
has_externalizable_dftd3,
|
|
16
|
+
validate_state_dict_keys,
|
|
17
|
+
)
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class ModelMetadata(TypedDict):
|
|
21
|
+
"""Metadata returned by load_model().
|
|
22
|
+
|
|
23
|
+
This TypedDict documents the structure of the metadata dictionary.
|
|
24
|
+
"""
|
|
25
|
+
|
|
26
|
+
format_version: int # 1 = legacy .jpt, 2 = new .pt
|
|
27
|
+
cutoff: float # Model cutoff radius
|
|
28
|
+
|
|
29
|
+
# Action flags - what calculator should add externally
|
|
30
|
+
needs_coulomb: bool # Add external Coulomb?
|
|
31
|
+
needs_dispersion: bool # Add external DFTD3?
|
|
32
|
+
|
|
33
|
+
# Coulomb mode descriptor - what's in the model
|
|
34
|
+
# "sr_embedded": Model has SRCoulomb, add FULL externally
|
|
35
|
+
# "full_embedded": Full Coulomb in model (legacy JIT)
|
|
36
|
+
# "none": No Coulomb anywhere
|
|
37
|
+
coulomb_mode: str
|
|
38
|
+
coulomb_sr_rc: NotRequired[float | None] # Only if coulomb_mode="sr_embedded"
|
|
39
|
+
coulomb_sr_envelope: NotRequired[str | None] # "exp" | "cosine", only if sr_embedded
|
|
40
|
+
|
|
41
|
+
# Dispersion parameters (optional)
|
|
42
|
+
d3_params: NotRequired[dict | None] # {s8, a1, a2, s6} if needs_dispersion=True
|
|
43
|
+
|
|
44
|
+
implemented_species: list[int] # Supported atomic numbers
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def load_model(path: str, device: str = "cpu") -> tuple[nn.Module, ModelMetadata]:
    """Load model from file, supporting both new and legacy formats.

    Automatically detects format:
    - New format (v2): plain ``torch.save`` dict with embedded YAML config,
      state dict, and metadata entries.
    - Legacy format (v1): JIT-compiled TorchScript model.

    Parameters
    ----------
    path : str
        Path to the model file (.pt or .jpt).
    device : str
        Device to load the model on. Default is "cpu".

    Returns
    -------
    model : nn.Module
        The loaded model with weights.
    metadata : ModelMetadata
        Dictionary containing model metadata. See ModelMetadata TypedDict for fields.

    Raises
    ------
    ValueError
        If the loaded object is neither a v2 state-dict payload nor a
        TorchScript module.

    Notes
    -----
    For legacy JIT models (format_version=1), `needs_coulomb` and `needs_dispersion`
    are False because LR modules are already embedded in the TorchScript model.
    """
    # torch.load auto-detects TorchScript archives and dispatches to
    # torch.jit.load; silence the advisory warning it emits in that case.
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", ".*looks like a TorchScript archive.*")
        data = torch.load(path, map_location=device, weights_only=False)

    # Single annotation: both branches assign this name, and annotating it
    # twice is a redefinition error for type checkers.
    metadata: ModelMetadata

    # Check result type to determine format.
    if isinstance(data, dict) and "model_yaml" in data:
        # --- New state dict (v2) format ---
        import yaml  # local import: only the v2 path needs PyYAML

        model_config = yaml.safe_load(data["model_yaml"])
        model = build_module(model_config)

        # Use strict=False because modules may differ between formats;
        # genuine mismatches are surfaced below after filtering known keys.
        load_result = model.load_state_dict(data["state_dict"], strict=False)
        real_missing, real_unexpected = validate_state_dict_keys(
            load_result.missing_keys, load_result.unexpected_keys
        )
        if real_missing or real_unexpected:
            msg_parts = []
            if real_missing:
                msg_parts.append(f"Missing keys: {real_missing}")
            if real_unexpected:
                msg_parts.append(f"Unexpected keys: {real_unexpected}")
            warnings.warn(f"State dict mismatch during model loading. {'; '.join(msg_parts)}", stacklevel=2)

        model = model.to(device)

        # Preserve float64 precision for atomic shifts (SAE values) after
        # device transfer, which may have cast them to the default dtype.
        if hasattr(model, "outputs") and hasattr(model.outputs, "atomic_shift"):
            model.outputs.atomic_shift.shifts = model.outputs.atomic_shift.shifts.double()

        metadata = {
            "format_version": data.get("format_version", 2),  # Default 2 for early v2 files without version
            "cutoff": data["cutoff"],
            "needs_coulomb": data.get("needs_coulomb", False),
            "needs_dispersion": data.get("needs_dispersion", False),
            "coulomb_mode": data.get("coulomb_mode", "none"),
            "coulomb_sr_rc": data.get("coulomb_sr_rc"),
            "coulomb_sr_envelope": data.get("coulomb_sr_envelope"),
            "d3_params": data.get("d3_params"),
            # Kept for backward compatibility with consumers that read this flag.
            "has_embedded_lr": data.get("has_embedded_lr", False),
            "implemented_species": data.get("implemented_species", []),
        }
        # Attach metadata to model for easy access.
        model._metadata = metadata
        return model, metadata

    elif isinstance(data, torch.jit.ScriptModule):
        # --- Legacy JIT (v1) format: LR modules are already embedded ---
        model = data
        metadata = {
            "format_version": 1,  # Legacy .jpt format is v1
            "cutoff": float(model.cutoff),
            # Legacy models have LR modules embedded - don't add external ones.
            "needs_coulomb": False,
            "needs_dispersion": False,
            "coulomb_mode": "full_embedded",
            # No coulomb_sr_rc/envelope for legacy (full Coulomb is embedded).
            "d3_params": extract_d3_params(model) if has_externalizable_dftd3(model) else None,
            "implemented_species": extract_species(model),
        }
        # Attribute assignment may be rejected by ScriptModule; the metadata
        # is still returned explicitly, so a failure here is harmless.
        with contextlib.suppress(AttributeError, RuntimeError):
            model._metadata = metadata  # type: ignore[attr-defined]
        return model, metadata

    else:
        raise ValueError(f"Unknown model format: {type(data)}")
|
|
146
|
+
|
|
147
|
+
|
|
148
|
+
class AIMNet2Base(nn.Module):
|
|
10
149
|
"""Base class for AIMNet2 models. Implements pre-processing data:
|
|
11
150
|
converting to right dtype and device, setting nb mode, calculating masks.
|
|
12
151
|
"""
|
|
@@ -15,32 +154,61 @@ class AIMNet2Base(nn.Module): # pylint: disable=abstract-method
|
|
|
15
154
|
|
|
16
155
|
_required_keys: Final = ["coord", "numbers", "charge"]
|
|
17
156
|
_required_keys_dtype: Final = [__default_dtype, torch.int64, __default_dtype]
|
|
18
|
-
_optional_keys: Final = [
|
|
157
|
+
_optional_keys: Final = [
|
|
158
|
+
"mult",
|
|
159
|
+
"nbmat",
|
|
160
|
+
"nbmat_lr",
|
|
161
|
+
"mol_idx",
|
|
162
|
+
"shifts",
|
|
163
|
+
"shifts_lr",
|
|
164
|
+
"cell",
|
|
165
|
+
"nbmat_dftd3",
|
|
166
|
+
"shifts_dftd3",
|
|
167
|
+
"cutoff_dftd3",
|
|
168
|
+
"nbmat_coulomb",
|
|
169
|
+
"shifts_coulomb",
|
|
170
|
+
"cutoff_coulomb",
|
|
171
|
+
]
|
|
19
172
|
_optional_keys_dtype: Final = [
|
|
20
|
-
__default_dtype,
|
|
21
|
-
torch.int64,
|
|
22
|
-
torch.int64,
|
|
23
|
-
torch.int64,
|
|
24
|
-
__default_dtype,
|
|
25
|
-
__default_dtype,
|
|
26
|
-
__default_dtype,
|
|
173
|
+
__default_dtype, # mult
|
|
174
|
+
torch.int64, # nbmat
|
|
175
|
+
torch.int64, # nbmat_lr
|
|
176
|
+
torch.int64, # mol_idx
|
|
177
|
+
__default_dtype, # shifts
|
|
178
|
+
__default_dtype, # shifts_lr
|
|
179
|
+
__default_dtype, # cell
|
|
180
|
+
torch.int64, # nbmat_dftd3
|
|
181
|
+
__default_dtype, # shifts_dftd3
|
|
182
|
+
__default_dtype, # cutoff_dftd3
|
|
183
|
+
torch.int64, # nbmat_coulomb
|
|
184
|
+
__default_dtype, # shifts_coulomb
|
|
185
|
+
__default_dtype, # cutoff_coulomb
|
|
27
186
|
]
|
|
28
187
|
__constants__: ClassVar = ["_required_keys", "_required_keys_dtype", "_optional_keys", "_optional_keys_dtype"]
|
|
188
|
+
# TypedDict not supported in TorchScript; exclude from serialization
|
|
189
|
+
__jit_unused_properties__: ClassVar = ["metadata"]
|
|
29
190
|
|
|
30
191
|
def __init__(self):
|
|
31
192
|
super().__init__()
|
|
193
|
+
# Use object.__setattr__ to avoid TorchScript tracing this attribute
|
|
194
|
+
object.__setattr__(self, "_metadata", None)
|
|
195
|
+
|
|
196
|
+
@property
|
|
197
|
+
def metadata(self) -> ModelMetadata | None:
|
|
198
|
+
"""Return model metadata if available."""
|
|
199
|
+
return getattr(self, "_metadata", None)
|
|
32
200
|
|
|
33
|
-
def _prepare_dtype(self, data:
|
|
34
|
-
for k, d in zip(self._required_keys, self._required_keys_dtype):
|
|
201
|
+
def _prepare_dtype(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
|
|
202
|
+
for k, d in zip(self._required_keys, self._required_keys_dtype, strict=False):
|
|
35
203
|
assert k in data, f"Key {k} is required"
|
|
36
204
|
data[k] = data[k].to(d)
|
|
37
|
-
for k, d in zip(self._optional_keys, self._optional_keys_dtype):
|
|
205
|
+
for k, d in zip(self._optional_keys, self._optional_keys_dtype, strict=False):
|
|
38
206
|
if k in data:
|
|
39
207
|
data[k] = data[k].to(d)
|
|
40
208
|
return data
|
|
41
209
|
|
|
42
|
-
def prepare_input(self, data:
|
|
43
|
-
"""
|
|
210
|
+
def prepare_input(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
|
|
211
|
+
"""Common operations for input preparation."""
|
|
44
212
|
data = self._prepare_dtype(data)
|
|
45
213
|
data = nbops.set_nb_mode(data)
|
|
46
214
|
data = nbops.calc_masks(data)
|
aimnet/models/convert.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
"""Convert legacy JIT-compiled models to new state dict format.
|
|
2
|
+
|
|
3
|
+
This module provides CLI interface for converting legacy .jpt TorchScript
|
|
4
|
+
models to the new .pt format with metadata.
|
|
5
|
+
|
|
6
|
+
For programmatic use, import load_v1_model from aimnet.models.utils:
|
|
7
|
+
from aimnet.models.utils import load_v1_model
|
|
8
|
+
model, metadata = load_v1_model("model.jpt", "config.yaml", "model_new.pt")
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
import click
|
|
12
|
+
|
|
13
|
+
from aimnet.models.utils import load_v1_model
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
# CLI wrapper around load_v1_model; presumably registered as a subcommand via
# the package entry points (see entry_points.txt) -- confirm against aimnet/cli.py.
@click.command()
@click.argument("jpt", type=click.Path(exists=True))
@click.argument("yaml_config", type=click.Path(exists=True))
@click.argument("output", type=str)
def convert_legacy_jpt(jpt: str, yaml_config: str, output: str):
    """Convert legacy JIT model to new state dict format.

    JPT: Path to the input JIT-compiled model file.
    YAML_CONFIG: Path to the model YAML configuration file.
    OUTPUT: Path to the output .pt file.

    Example:
        aimnet convert model.jpt config.yaml model_new.pt
    """
    # NOTE: the docstring above doubles as the click --help text; keep its
    # wording user-facing.
    # Delegates entirely to load_v1_model, which (per the module docstring)
    # reads the JIT model plus YAML config and writes the converted model
    # to `output`.
    load_v1_model(jpt, yaml_config, output)
|