aimnet 0.0.1__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
aimnet/models/base.py CHANGED
@@ -1,12 +1,151 @@
1
- from typing import ClassVar, Dict, Final
1
+ from __future__ import annotations
2
+
3
+ import contextlib
4
+ import warnings
5
+ from typing import ClassVar, Final, NotRequired, TypedDict
2
6
 
3
7
  import torch
4
8
  from torch import Tensor, nn
5
9
 
6
10
  from aimnet import nbops
11
+ from aimnet.config import build_module
12
+ from aimnet.models.utils import (
13
+ extract_d3_params,
14
+ extract_species,
15
+ has_externalizable_dftd3,
16
+ validate_state_dict_keys,
17
+ )
18
+
19
+
20
+ class ModelMetadata(TypedDict):
21
+ """Metadata returned by load_model().
22
+
23
+ This TypedDict documents the structure of the metadata dictionary.
24
+ """
25
+
26
+ format_version: int # 1 = legacy .jpt, 2 = new .pt
27
+ cutoff: float # Model cutoff radius
28
+
29
+ # Action flags - what calculator should add externally
30
+ needs_coulomb: bool # Add external Coulomb?
31
+ needs_dispersion: bool # Add external DFTD3?
32
+
33
+ # Coulomb mode descriptor - what's in the model
34
+ # "sr_embedded": Model has SRCoulomb, add FULL externally
35
+ # "full_embedded": Full Coulomb in model (legacy JIT)
36
+ # "none": No Coulomb anywhere
37
+ coulomb_mode: str
38
+ coulomb_sr_rc: NotRequired[float | None] # Only if coulomb_mode="sr_embedded"
39
+ coulomb_sr_envelope: NotRequired[str | None] # "exp" | "cosine", only if sr_embedded
40
+
41
+ # Dispersion parameters (optional)
42
+ d3_params: NotRequired[dict | None] # {s8, a1, a2, s6} if needs_dispersion=True
43
+
44
+ implemented_species: list[int] # Supported atomic numbers
45
+
46
+
47
+ def load_model(path: str, device: str = "cpu") -> tuple[nn.Module, ModelMetadata]:
48
+ """Load model from file, supporting both new and legacy formats.
49
+
50
+ Automatically detects format:
51
+ - New format: state dict with embedded YAML config and metadata
52
+ - Legacy format: JIT-compiled TorchScript model
53
+
54
+ Parameters
55
+ ----------
56
+ path : str
57
+ Path to the model file (.pt or .jpt).
58
+ device : str
59
+ Device to load the model on. Default is "cpu".
60
+
61
+ Returns
62
+ -------
63
+ model : nn.Module
64
+ The loaded model with weights.
65
+ metadata : ModelMetadata
66
+ Dictionary containing model metadata. See ModelMetadata TypedDict for fields.
67
+
68
+ Notes
69
+ -----
70
+ For legacy JIT models (format_version=1), `needs_coulomb` and `needs_dispersion`
71
+ are False because LR modules are already embedded in the TorchScript model.
72
+ """
73
+ import yaml
74
+
75
+ # torch.load auto-detects TorchScript and dispatches to torch.jit.load
76
+ with warnings.catch_warnings():
77
+ warnings.filterwarnings("ignore", ".*looks like a TorchScript archive.*")
78
+ data = torch.load(path, map_location=device, weights_only=False)
79
+
80
+ # Check result type to determine format
81
+ if isinstance(data, dict) and "model_yaml" in data:
82
+ # New state dict format
83
+ model_config = yaml.safe_load(data["model_yaml"])
84
+ model = build_module(model_config)
85
+
86
+ # Use strict=False because modules may differ between formats
87
+ load_result = model.load_state_dict(data["state_dict"], strict=False)
88
+
89
+ # Check for unexpected missing/extra keys
90
+ real_missing, real_unexpected = validate_state_dict_keys(load_result.missing_keys, load_result.unexpected_keys)
91
+ if real_missing or real_unexpected:
92
+ msg_parts = []
93
+ if real_missing:
94
+ msg_parts.append(f"Missing keys: {real_missing}")
95
+ if real_unexpected:
96
+ msg_parts.append(f"Unexpected keys: {real_unexpected}")
97
+ warnings.warn(f"State dict mismatch during model loading. {'; '.join(msg_parts)}", stacklevel=2)
98
+
99
+ model = model.to(device)
100
+
101
+ # Preserve float64 precision for atomic shifts (SAE values) after device transfer
102
+ if hasattr(model, "outputs") and hasattr(model.outputs, "atomic_shift"):
103
+ model.outputs.atomic_shift.shifts = model.outputs.atomic_shift.shifts.double()
104
+
105
+ metadata: ModelMetadata = {
106
+ "format_version": data.get("format_version", 2), # Default 2 for early v2 files without version
107
+ "cutoff": data["cutoff"],
108
+ "needs_coulomb": data.get("needs_coulomb", False),
109
+ "needs_dispersion": data.get("needs_dispersion", False),
110
+ "coulomb_mode": data.get("coulomb_mode", "none"),
111
+ "coulomb_sr_rc": data.get("coulomb_sr_rc"),
112
+ "coulomb_sr_envelope": data.get("coulomb_sr_envelope"),
113
+ "d3_params": data.get("d3_params"),
114
+ "has_embedded_lr": data.get("has_embedded_lr", False),
115
+ "implemented_species": data.get("implemented_species", []),
116
+ }
7
117
 
118
+ # Attach metadata to model for easy access
119
+ model._metadata = metadata
8
120
 
9
- class AIMNet2Base(nn.Module): # pylint: disable=abstract-method
121
+ return model, metadata
122
+
123
+ elif isinstance(data, torch.jit.ScriptModule):
124
+ # Legacy JIT format - LR modules are already embedded in the TorchScript model
125
+ model = data
126
+ metadata: ModelMetadata = {
127
+ "format_version": 1, # Legacy .jpt format is v1
128
+ "cutoff": float(model.cutoff),
129
+ # Legacy models have LR modules embedded - don't add external ones
130
+ "needs_coulomb": False,
131
+ "needs_dispersion": False,
132
+ "coulomb_mode": "full_embedded",
133
+ # No coulomb_sr_rc/envelope for legacy (full Coulomb is embedded)
134
+ "d3_params": extract_d3_params(model) if has_externalizable_dftd3(model) else None,
135
+ "implemented_species": extract_species(model),
136
+ }
137
+
138
+ # Attempt metadata assignment; silently fails for JIT models
139
+ with contextlib.suppress(AttributeError, RuntimeError):
140
+ model._metadata = metadata # type: ignore[attr-defined]
141
+
142
+ return model, metadata
143
+
144
+ else:
145
+ raise ValueError(f"Unknown model format: {type(data)}")
146
+
147
+
148
+ class AIMNet2Base(nn.Module):
10
149
  """Base class for AIMNet2 models. Implements pre-processing data:
11
150
  converting to right dtype and device, setting nb mode, calculating masks.
12
151
  """
@@ -15,32 +154,61 @@ class AIMNet2Base(nn.Module): # pylint: disable=abstract-method
    # Keys that every input dict must contain, and the dtype each is cast to.
    # NOTE: __default_dtype is a name-mangled class attribute defined just
    # above this span (outside the visible hunk) — presumably the model's
    # working float dtype; confirm against the full file.
    _required_keys: Final = ["coord", "numbers", "charge"]
    _required_keys_dtype: Final = [__default_dtype, torch.int64, __default_dtype]
    # Optional input keys; positionally paired with _optional_keys_dtype below.
    _optional_keys: Final = [
        "mult",
        "nbmat",
        "nbmat_lr",
        "mol_idx",
        "shifts",
        "shifts_lr",
        "cell",
        "nbmat_dftd3",
        "shifts_dftd3",
        "cutoff_dftd3",
        "nbmat_coulomb",
        "shifts_coulomb",
        "cutoff_coulomb",
    ]
    _optional_keys_dtype: Final = [
        __default_dtype,  # mult
        torch.int64,  # nbmat
        torch.int64,  # nbmat_lr
        torch.int64,  # mol_idx
        __default_dtype,  # shifts
        __default_dtype,  # shifts_lr
        __default_dtype,  # cell
        torch.int64,  # nbmat_dftd3
        __default_dtype,  # shifts_dftd3
        __default_dtype,  # cutoff_dftd3
        torch.int64,  # nbmat_coulomb
        __default_dtype,  # shifts_coulomb
        __default_dtype,  # cutoff_coulomb
    ]
    # Mark the key tables as TorchScript compile-time constants.
    __constants__: ClassVar = ["_required_keys", "_required_keys_dtype", "_optional_keys", "_optional_keys_dtype"]
    # TypedDict not supported in TorchScript; exclude from serialization
    __jit_unused_properties__: ClassVar = ["metadata"]
29
190
 
30
191
  def __init__(self):
31
192
  super().__init__()
193
+ # Use object.__setattr__ to avoid TorchScript tracing this attribute
194
+ object.__setattr__(self, "_metadata", None)
195
+
196
+ @property
197
+ def metadata(self) -> ModelMetadata | None:
198
+ """Return model metadata if available."""
199
+ return getattr(self, "_metadata", None)
32
200
 
33
- def _prepare_dtype(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]:
34
- for k, d in zip(self._required_keys, self._required_keys_dtype):
201
+ def _prepare_dtype(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
202
+ for k, d in zip(self._required_keys, self._required_keys_dtype, strict=False):
35
203
  assert k in data, f"Key {k} is required"
36
204
  data[k] = data[k].to(d)
37
- for k, d in zip(self._optional_keys, self._optional_keys_dtype):
205
+ for k, d in zip(self._optional_keys, self._optional_keys_dtype, strict=False):
38
206
  if k in data:
39
207
  data[k] = data[k].to(d)
40
208
  return data
41
209
 
42
- def prepare_input(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]:
43
- """Some sommon operations"""
210
+ def prepare_input(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
211
+ """Common operations for input preparation."""
44
212
  data = self._prepare_dtype(data)
45
213
  data = nbops.set_nb_mode(data)
46
214
  data = nbops.calc_masks(data)
@@ -0,0 +1,30 @@
1
+ """Convert legacy JIT-compiled models to new state dict format.
2
+
3
+ This module provides CLI interface for converting legacy .jpt TorchScript
4
+ models to the new .pt format with metadata.
5
+
6
+ For programmatic use, import load_v1_model from aimnet.models.utils:
7
+ from aimnet.models.utils import load_v1_model
8
+ model, metadata = load_v1_model("model.jpt", "config.yaml", "model_new.pt")
9
+ """
10
+
11
+ import click
12
+
13
+ from aimnet.models.utils import load_v1_model
14
+
15
+
16
+ @click.command()
17
+ @click.argument("jpt", type=click.Path(exists=True))
18
+ @click.argument("yaml_config", type=click.Path(exists=True))
19
+ @click.argument("output", type=str)
20
+ def convert_legacy_jpt(jpt: str, yaml_config: str, output: str):
21
+ """Convert legacy JIT model to new state dict format.
22
+
23
+ JPT: Path to the input JIT-compiled model file.
24
+ YAML_CONFIG: Path to the model YAML configuration file.
25
+ OUTPUT: Path to the output .pt file.
26
+
27
+ Example:
28
+ aimnet convert model.jpt config.yaml model_new.pt
29
+ """
30
+ load_v1_model(jpt, yaml_config, output)