aimnet 0.0.1__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aimnet/__init__.py +7 -0
- aimnet/base.py +24 -8
- aimnet/calculators/__init__.py +4 -4
- aimnet/calculators/aimnet2ase.py +19 -6
- aimnet/calculators/calculator.py +868 -108
- aimnet/calculators/model_registry.py +2 -5
- aimnet/calculators/model_registry.yaml +55 -17
- aimnet/cli.py +62 -6
- aimnet/config.py +8 -9
- aimnet/data/sgdataset.py +23 -22
- aimnet/kernels/__init__.py +66 -0
- aimnet/kernels/conv_sv_2d_sp_wp.py +478 -0
- aimnet/models/__init__.py +13 -1
- aimnet/models/aimnet2.py +19 -22
- aimnet/models/base.py +183 -15
- aimnet/models/convert.py +30 -0
- aimnet/models/utils.py +735 -0
- aimnet/modules/__init__.py +1 -1
- aimnet/modules/aev.py +49 -48
- aimnet/modules/core.py +14 -13
- aimnet/modules/lr.py +520 -115
- aimnet/modules/ops.py +537 -0
- aimnet/nbops.py +105 -15
- aimnet/ops.py +90 -18
- aimnet/train/export_model.py +226 -0
- aimnet/train/loss.py +7 -7
- aimnet/train/metrics.py +5 -6
- aimnet/train/train.py +4 -1
- aimnet/train/utils.py +42 -13
- aimnet-0.1.0.dist-info/METADATA +308 -0
- aimnet-0.1.0.dist-info/RECORD +43 -0
- {aimnet-0.0.1.dist-info → aimnet-0.1.0.dist-info}/WHEEL +1 -1
- aimnet-0.1.0.dist-info/entry_points.txt +3 -0
- aimnet/calculators/nb_kernel_cpu.py +0 -222
- aimnet/calculators/nb_kernel_cuda.py +0 -217
- aimnet/calculators/nbmat.py +0 -220
- aimnet/train/pt2jpt.py +0 -81
- aimnet-0.0.1.dist-info/METADATA +0 -78
- aimnet-0.0.1.dist-info/RECORD +0 -41
- aimnet-0.0.1.dist-info/entry_points.txt +0 -5
- {aimnet-0.0.1.dist-info → aimnet-0.1.0.dist-info/licenses}/LICENSE +0 -0
aimnet/models/utils.py
ADDED
|
@@ -0,0 +1,735 @@
|
|
|
1
|
+
"""Utility functions for model inspection and metadata extraction.
|
|
2
|
+
|
|
3
|
+
This module provides helper functions for:
|
|
4
|
+
- Recursive module traversal
|
|
5
|
+
- Extracting attributes from JIT-compiled models
|
|
6
|
+
- Detecting embedded Coulomb and dispersion modules
|
|
7
|
+
- Extracting D3 parameters and implemented species
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
from __future__ import annotations
|
|
11
|
+
|
|
12
|
+
import contextlib
|
|
13
|
+
from collections.abc import Iterator
|
|
14
|
+
|
|
15
|
+
import torch
|
|
16
|
+
from torch import nn
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def named_children_rec(module: nn.Module) -> Iterator[tuple[str, nn.Module]]:
    """Yield (name, child) pairs for every descendant module, depth-first.

    Parameters
    ----------
    module : nn.Module
        The root module whose descendants are traversed.

    Yields
    ------
    tuple[str, nn.Module]
        Pairs of (name, child_module), each child emitted before its own
        descendants (preorder).
    """
    # Non-module inputs produce an empty iteration rather than an error.
    if not isinstance(module, nn.Module):
        return
    for child_name, child_mod in module.named_children():
        yield child_name, child_mod
        yield from named_children_rec(child_mod)
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def get_jit_attr(module: nn.Module, attr: str, default: float) -> float:
    """Extract a scalar attribute from a (possibly JIT) module as a float.

    TorchScript modules may store scalars as constants that are only
    reachable through ``__getattr__``, so both access paths are tried.

    Parameters
    ----------
    module : nn.Module
        The module to read the attribute from.
    attr : str
        The attribute name.
    default : float
        Value returned when the attribute is absent or not numeric.

    Returns
    -------
    float
        The attribute value as a Python float, or ``default``.
    """
    value = None

    # Plain attribute access; any failure falls through to the next path.
    try:
        value = getattr(module, attr, None)
    except Exception:
        value = None

    # TorchScript constants sometimes resolve only via __getattr__.
    if value is None:
        try:
            value = module.__getattr__(attr)
        except (AttributeError, RuntimeError):
            value = None

    if value is None:
        return default

    # 0-dim tensors and numpy scalars expose .item().
    extractor = getattr(value, "item", None)
    if extractor is not None:
        return float(extractor())
    if isinstance(value, (int, float)) or hasattr(value, "__float__"):
        return float(value)

    # Non-numeric attribute: fall back to the default.
    return default
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
def has_dispersion(model: nn.Module) -> bool:
    """Return True if the model embeds any dispersion module (DFTD3, D3BJ or D3TS).

    .. deprecated::
        Use ``model.metadata`` instead. This function iterates through model
        children, which is slow and unreliable for JIT models.

    All dispersion modules need nbmat_lr for neighbor calculations,
    regardless of whether they use tabulated (DFTD3/D3BJ) or learned (D3TS)
    parameters.

    Parameters
    ----------
    model : nn.Module
        The model to check.

    Returns
    -------
    bool
        True if any dispersion module is found.
    """
    import warnings

    warnings.warn(
        "has_dispersion() is deprecated. Use model._metadata instead.",
        DeprecationWarning,
        stacklevel=2,
    )
    dispersion_names = {"dftd3", "d3bj", "d3ts"}
    for child_name, _ in named_children_rec(model):
        if child_name in dispersion_names:
            return True
    return False
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
def has_externalizable_dftd3(model: nn.Module) -> bool:
    """Return True if the model embeds a DFTD3/D3BJ module that can be externalized.

    D3TS uses learned parameters from the NN and must stay embedded; only
    DFTD3/D3BJ with tabulated parameters can be externalized.

    Parameters
    ----------
    model : nn.Module
        The model to check.

    Returns
    -------
    bool
        True if a DFTD3 or D3BJ module is found.
    """
    for child_name, _ in named_children_rec(model):
        if child_name in ("dftd3", "d3bj"):
            return True
    return False
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
def has_d3ts(model: nn.Module) -> bool:
    """Return True if the model embeds a D3TS module (learned dispersion parameters).

    D3TS uses learned parameters from the NN and must stay embedded.

    Parameters
    ----------
    model : nn.Module
        The model to check.

    Returns
    -------
    bool
        True if a D3TS module is found.
    """
    for child_name, _ in named_children_rec(model):
        if child_name == "d3ts":
            return True
    return False
|
|
147
|
+
|
|
148
|
+
|
|
149
|
+
def has_lrcoulomb(model: nn.Module) -> bool:
    """Return True if the model embeds an LRCoulomb module.

    Parameters
    ----------
    model : nn.Module
        The model to check.

    Returns
    -------
    bool
        True if an LRCoulomb module is found.
    """
    for child_name, _ in named_children_rec(model):
        if child_name == "lrcoulomb":
            return True
    return False
|
|
163
|
+
|
|
164
|
+
|
|
165
|
+
def iter_lrcoulomb_mods(model: nn.Module) -> Iterator[nn.Module]:
    """Yield every LRCoulomb module found in the model.

    .. deprecated::
        Use ``model.metadata`` instead. This function iterates through model
        children, which is slow and unreliable for JIT models.

    Parameters
    ----------
    model : nn.Module
        The model to search.

    Yields
    ------
    nn.Module
        Each LRCoulomb module found.
    """
    import warnings

    warnings.warn(
        "iter_lrcoulomb_mods() is deprecated. Use model._metadata instead.",
        DeprecationWarning,
        stacklevel=2,
    )
    yield from (mod for child_name, mod in named_children_rec(model) if child_name == "lrcoulomb")
|
|
192
|
+
|
|
193
|
+
|
|
194
|
+
def extract_d3_params(model: nn.Module) -> dict[str, float] | None:
    """Extract D3 damping parameters from the model's DFTD3/D3BJ module.

    Only DFTD3/D3BJ (tabulated parameters) are read — D3TS uses learned
    parameters and is intentionally skipped. TorchScript constants are
    handled via :func:`get_jit_attr`.

    Parameters
    ----------
    model : nn.Module
        The model to extract D3 parameters from.

    Returns
    -------
    dict or None
        Dictionary with s6, s8, a1, a2 parameters, or None if no
        DFTD3/D3BJ module is present.
    """
    # (attribute, fallback) pairs; s6 conventionally defaults to 1.0.
    param_defaults = (("s8", 0.0), ("a1", 0.0), ("a2", 0.0), ("s6", 1.0))
    for child_name, child_mod in named_children_rec(model):
        # NOT d3ts - it uses learned params
        if child_name in ("dftd3", "d3bj"):
            return {key: get_jit_attr(child_mod, key, fallback) for key, fallback in param_defaults}
    return None
|
|
220
|
+
|
|
221
|
+
|
|
222
|
+
def extract_coulomb_rc(model: nn.Module) -> float:
    """Extract the Coulomb cutoff (rc) from the model's LRCoulomb module.

    Parameters
    ----------
    model : nn.Module
        The model to extract the cutoff from.

    Returns
    -------
    float
        The Coulomb short-range cutoff value.

    Raises
    ------
    KeyError
        If no LRCoulomb module is found, or it lacks an 'rc' attribute.
    """
    for child_name, child_mod in named_children_rec(model):
        if child_name != "lrcoulomb":
            continue
        rc = getattr(child_mod, "rc", None)
        if rc is None:
            raise KeyError("LRCoulomb module found but 'rc' attribute is missing")
        # rc may be stored as a 0-dim tensor or a plain number.
        return float(rc.item()) if hasattr(rc, "item") else float(rc)
    raise KeyError("No LRCoulomb module found in model")
|
|
247
|
+
|
|
248
|
+
|
|
249
|
+
def extract_species(model: nn.Module) -> list[int]:
    """Extract implemented species from the model's ``afv.weight`` rows.

    An element (row index = atomic number, starting at 1) is considered
    implemented when its embedding row contains at least one non-NaN value.

    Parameters
    ----------
    model : nn.Module
        The model to extract species from.

    Returns
    -------
    list[int]
        Sorted list of atomic numbers implemented in the model; empty when
        no ``afv.weight`` entry exists in the state dict.
    """
    weight = model.state_dict().get("afv.weight")
    if weight is None:
        return []
    # Row 0 is skipped (padding); ascending range keeps the result sorted.
    return [z for z in range(1, weight.shape[0]) if not torch.isnan(weight[z]).all()]
|
|
274
|
+
|
|
275
|
+
|
|
276
|
+
def has_d3ts_in_config(config: dict) -> bool:
    """Return True if the YAML config declares a D3TS output module.

    Parameters
    ----------
    config : dict
        Model YAML configuration dictionary.

    Returns
    -------
    bool
        True if "d3ts" appears in the outputs section.
    """
    return "d3ts" in config.get("kwargs", {}).get("outputs", {})
|
|
291
|
+
|
|
292
|
+
|
|
293
|
+
def has_dftd3_in_config(config: dict) -> bool:
    """Return True if the YAML config declares a DFTD3 or D3BJ output module.

    Parameters
    ----------
    config : dict
        Model YAML configuration dictionary.

    Returns
    -------
    bool
        True if "dftd3" or "d3bj" appears in the outputs section.
    """
    output_section = config.get("kwargs", {}).get("outputs", {})
    return any(key in output_section for key in ("dftd3", "d3bj"))
|
|
308
|
+
|
|
309
|
+
|
|
310
|
+
# --- State dict key validation ---
|
|
311
|
+
|
|
312
|
+
|
|
313
|
+
def validate_state_dict_keys(
    missing_keys: list[str],
    unexpected_keys: list[str],
) -> tuple[list[str], list[str]]:
    """Filter out expected missing/unexpected keys during format migration.

    During v1→v2 model conversion, certain keys are expected to be missing
    (SRCoulomb added) or unexpected (LRCoulomb/DFTD3 removed). This function
    filters those out and returns only keys that indicate actual problems.

    Parameters
    ----------
    missing_keys : list[str]
        Keys missing from the state dict.
    unexpected_keys : list[str]
        Keys in the state dict that weren't expected.

    Returns
    -------
    tuple[list[str], list[str]]
        (real_missing, real_unexpected) - keys that indicate actual problems.
    """
    # str.startswith accepts a tuple of prefixes, so one call covers each group.
    expected_missing_prefixes = ("outputs.srcoulomb.",)
    expected_unexpected_prefixes = (
        "outputs.lrcoulomb.",
        "outputs.dftd3.",
        "outputs.d3bj.",
    )

    real_missing = [key for key in missing_keys if not key.startswith(expected_missing_prefixes)]
    real_unexpected = [key for key in unexpected_keys if not key.startswith(expected_unexpected_prefixes)]
    return real_missing, real_unexpected
|
|
352
|
+
|
|
353
|
+
|
|
354
|
+
# --- YAML config manipulation ---
|
|
355
|
+
|
|
356
|
+
|
|
357
|
+
def strip_lr_modules_from_yaml(
    config: dict,
    source: dict | nn.Module,
) -> tuple[dict, str, bool, dict[str, float] | None, float | None, str, str | None]:
    """Remove LRCoulomb and DFTD3 from YAML config, add SRCoulomb.

    This is the unified function for both export (from state dict) and
    convert (from JIT model) paths.

    Parameters
    ----------
    config : dict
        Model YAML configuration dictionary.
    source : dict | nn.Module
        Either a state dict (for export path) or a JIT model (for convert path).
        Used to extract metadata like Coulomb rc and D3 params.

    Returns
    -------
    tuple
        (config, coulomb_mode, needs_dispersion, d3_params, coulomb_sr_rc, coulomb_sr_envelope, disp_ptfile):
        - config: Modified config with LR modules removed and SRCoulomb added
        - coulomb_mode: "sr_embedded" if LRCoulomb was present, else "none"
        - needs_dispersion: True if DFTD3/D3BJ was present
        - d3_params: D3 parameters dict or None
        - coulomb_sr_rc: Short-range Coulomb cutoff or None
        - coulomb_sr_envelope: Envelope function ("exp" or "cosine")
        - disp_ptfile: Path to DispParam ptfile (if any) for loading buffer

    Raises
    ------
    ValueError
        If model has both D3TS and DFTD3/D3BJ (double dispersion).
        If LRCoulomb is present but rc cannot be determined.

    Notes
    -----
    SRCoulomb is added to outputs only when LRCoulomb was present in the
    original config. This ensures proper energy accounting when the
    calculator adds external LRCoulomb.
    """
    import copy

    # Work on a deep copy so the caller's config is never mutated.
    config = copy.deepcopy(config)
    outputs = config.get("kwargs", {}).get("outputs", {})

    # Determine source type: JIT model (convert path) vs state dict (export path)
    is_jit_model = isinstance(source, nn.Module)

    # --- Detect Coulomb ---
    if is_jit_model:
        has_coulomb = has_lrcoulomb(source)
        coulomb_sr_rc = extract_coulomb_rc(source) if has_coulomb else None
        # Legacy models always used exp envelope
        coulomb_sr_envelope = "exp"
    else:
        # State dict path - check YAML config first, then state dict
        has_coulomb_in_sd = any(k.startswith("outputs.lrcoulomb") for k in source)
        if "lrcoulomb" in outputs:
            has_coulomb = True
            lrc_kwargs = outputs["lrcoulomb"].get("kwargs", {})
            rc_value = lrc_kwargs.get("rc")
            coulomb_sr_rc = float(rc_value) if rc_value is not None else None
            coulomb_sr_envelope = lrc_kwargs.get("envelope", "exp")
        elif has_coulomb_in_sd:
            has_coulomb = True
            rc_key = "outputs.lrcoulomb.rc"
            coulomb_sr_rc = float(source[rc_key].item()) if rc_key in source else None
            coulomb_sr_envelope = "exp"  # Cannot extract from state dict
        else:
            has_coulomb = False
            coulomb_sr_rc = None
            coulomb_sr_envelope = "exp"

    # Validate: if Coulomb is needed, rc must be determinable
    if has_coulomb and coulomb_sr_rc is None:
        raise ValueError(
            "Model requires Coulomb but 'rc' could not be determined from YAML config or source. "
            "Please specify 'rc' explicitly in the LRCoulomb config kwargs."
        )

    # --- Detect Dispersion ---
    if is_jit_model:
        # Check if model has dftd3/d3bj modules
        has_d3_module = any(name in ("dftd3", "d3bj") for name, _ in named_children_rec(source))

        # Check YAML to determine if it's D3TS (not externalizable).
        # D3TS uses NN-predicted C6/alpha and must stay embedded.
        is_d3ts = False
        if has_d3_module:
            for key in ("dftd3", "d3bj"):
                if key in outputs and "D3TS" in outputs[key].get("class", ""):
                    is_d3ts = True
                    break

        # Only externalize if NOT D3TS (DFTD3/D3BJ with tabulated params can be externalized)
        needs_dispersion = has_d3_module and not is_d3ts

        if needs_dispersion:
            # Try to extract from JIT model first
            d3_params = extract_d3_params(source)
            # If extraction failed or returned zeros, fall back to the YAML config
            if d3_params is None or (
                d3_params.get("s8") == 0.0 and d3_params.get("a1") == 0.0 and d3_params.get("a2") == 0.0
            ):
                for key in ("dftd3", "d3bj"):
                    if key in outputs:
                        d3_kwargs = outputs[key].get("kwargs", {})
                        d3_params = {
                            "s8": d3_kwargs.get("s8", 0.0),
                            "a1": d3_kwargs.get("a1", 0.0),
                            "a2": d3_kwargs.get("a2", 0.0),
                            "s6": d3_kwargs.get("s6", 1.0),
                        }
                        break
        else:
            d3_params = None
    else:
        # State dict path - check YAML config
        needs_dispersion = False
        d3_params = None
        for key in ("dftd3", "d3bj"):
            if key in outputs:
                d3_config = outputs[key]
                # Check if it's D3TS (must stay embedded, not externalizable)
                module_class = d3_config.get("class", "")
                if "D3TS" in module_class:
                    # D3TS uses NN-predicted C6/alpha, must stay embedded
                    needs_dispersion = False
                    d3_params = None
                    break
                # DFTD3/D3BJ with tabulated params can be externalized
                needs_dispersion = True
                d3_kwargs = d3_config.get("kwargs", {})
                d3_params = {
                    "s8": d3_kwargs.get("s8", 0.0),
                    "a1": d3_kwargs.get("a1", 0.0),
                    "a2": d3_kwargs.get("a2", 0.0),
                    "s6": d3_kwargs.get("s6", 1.0),
                }
                break

    # Validate: D3TS + DFTD3/D3BJ is invalid (would cause double dispersion)
    has_d3ts_model = has_d3ts(source) if is_jit_model else False
    if needs_dispersion and (has_d3ts_model or has_d3ts_in_config(config)):
        raise ValueError(
            "Model has both D3TS (learned) and DFTD3/D3BJ (tabulated) dispersion. "
            "D3TS uses learned parameters and must stay embedded, while DFTD3/D3BJ "
            "would be externalized. This configuration leads to double dispersion "
            "correction. Remove either D3TS or DFTD3/D3BJ from the model."
        )

    # --- Rebuild outputs dict ---
    new_outputs = {}
    for key, value in outputs.items():
        if key == "lrcoulomb":
            pass  # Will be added externally by calculator
        elif key in ("dftd3", "d3bj"):
            # Check if it's D3TS (must stay embedded)
            if "D3TS" in value.get("class", ""):
                new_outputs[key] = value  # Keep D3TS embedded
            # else: drop DFTD3/D3BJ for externalization
        else:
            new_outputs[key] = value

    # Strip ptfile from DispParam configs but save the path
    # (raw training weights don't contain disp_param0 buffer, need to load from ptfile)
    disp_ptfile: str | None = None
    for value in new_outputs.values():
        if isinstance(value, dict) and "DispParam" in value.get("class", ""):
            kwargs = value.get("kwargs", {})
            if "ptfile" in kwargs:
                disp_ptfile = kwargs.pop("ptfile")  # Save before removing

    # Add SRCoulomb if LRCoulomb was present
    if has_coulomb:
        new_outputs["srcoulomb"] = {
            "class": "aimnet.modules.SRCoulomb",
            "kwargs": {
                "rc": coulomb_sr_rc,
                "key_in": "charges",
                "key_out": "energy",
                "envelope": coulomb_sr_envelope,
            },
        }

    # BUGFIX: use setdefault so a config without a "kwargs" section does not
    # raise KeyError here (the function reads it defensively everywhere else).
    config.setdefault("kwargs", {})["outputs"] = new_outputs
    coulomb_mode = "sr_embedded" if has_coulomb else "none"

    return (
        config,
        coulomb_mode,
        needs_dispersion,
        d3_params,
        coulomb_sr_rc,
        coulomb_sr_envelope if coulomb_sr_envelope else "exp",
        disp_ptfile,
    )
|
|
563
|
+
|
|
564
|
+
|
|
565
|
+
# --- Model loading ---
|
|
566
|
+
|
|
567
|
+
|
|
568
|
+
def load_v1_model(
    jpt_path: str,
    yaml_config_path: str,
    output_path: str | None = None,
    verbose: bool = True,
) -> tuple[nn.Module, dict]:
    """Load legacy JIT model (v1) and convert to v2 format.

    This is the primary entry point for loading legacy models.

    Parameters
    ----------
    jpt_path : str
        Path to the input JIT-compiled model file (.jpt).
    yaml_config_path : str
        Path to the model YAML configuration file.
    output_path : str, optional
        If provided, save the converted model to this path.
    verbose : bool
        Whether to print progress messages.

    Returns
    -------
    model : nn.Module
        The loaded model in v2 format.
    metadata : dict
        Model metadata dictionary with keys:
        - format_version: 2
        - model_yaml: str (converted config serialized to YAML)
        - cutoff: float
        - needs_coulomb: bool
        - needs_dispersion: bool
        - coulomb_mode: str
        - coulomb_sr_rc: float | None
        - coulomb_sr_envelope: str | None
        - d3_params: dict | None
        - has_embedded_lr: bool
        - implemented_species: list[int]

    Example
    -------
    >>> from aimnet.models.utils import load_v1_model
    >>> model, metadata = load_v1_model("model.jpt", "config.yaml")
    >>> print(metadata["format_version"])  # 2

    Warnings
    --------
    UserWarning
        If D3 parameter extraction produces zero values.
    """
    import copy

    import torch
    import yaml

    from aimnet.config import build_module

    # Load YAML config
    with open(yaml_config_path, encoding="utf-8") as f:
        model_config = yaml.safe_load(f)

    # Load JIT model (CPU map avoids requiring the original device)
    if verbose:
        print(f"Loading JIT model from {jpt_path}")
    jit_model = torch.jit.load(jpt_path, map_location="cpu")

    # Extract metadata from JIT
    # NOTE(review): assumes the JIT model exposes a scalar `cutoff` attribute — confirm for all v1 models
    cutoff = float(jit_model.cutoff)
    implemented_species = extract_species(jit_model)

    # Strip LR modules from YAML and add SRCoulomb
    # Note: disp_ptfile is unused here because JIT model already has disp_param0 in its state dict
    core_config, coulomb_mode, needs_dispersion, d3_params, coulomb_sr_rc, coulomb_sr_envelope, _disp_ptfile = (
        strip_lr_modules_from_yaml(model_config, jit_model)
    )

    # Inform user about dispersion handling
    if verbose:
        if needs_dispersion:
            # External dispersion (DFTD3/D3BJ with tabulated params)
            if d3_params is None:
                print("WARNING: Model has DFTD3 module but D3 params extraction failed!")
            elif d3_params.get("s8") == 0.0 and d3_params.get("a1") == 0.0 and d3_params.get("a2") == 0.0:
                print("WARNING: D3 params appear to be all zeros - extraction may have failed!")
                print(f" Extracted: {d3_params}")
            else:
                print(
                    f" D3 parameters: s6={d3_params['s6']}, s8={d3_params['s8']}, "
                    f"a1={d3_params['a1']}, a2={d3_params['a2']}"
                )
        else:
            # Check if D3TS is embedded
            outputs = model_config.get("kwargs", {}).get("outputs", {})
            has_d3ts = any("D3TS" in outputs.get(k, {}).get("class", "") for k in ["dftd3", "d3bj", "d3ts"])
            if has_d3ts:
                print(" D3TS dispersion kept embedded (uses NN-predicted C6/alpha)")

    # Detect if model has any embedded LR modules that need nbmat_lr
    # (read from the ORIGINAL config — core_config already has LR modules stripped)
    outputs = model_config.get("kwargs", {}).get("outputs", {})
    has_embedded_lr = False

    # Check for embedded D3TS
    has_d3ts = any("D3TS" in outputs.get(k, {}).get("class", "") for k in ["dftd3", "d3bj", "d3ts"])
    if has_d3ts:
        has_embedded_lr = True

    # Check for embedded SRCoulomb (model had LRCoulomb before conversion)
    if coulomb_mode == "sr_embedded":
        has_embedded_lr = True

    # Convert config to YAML string (stored in metadata for reproducibility)
    core_yaml_str = yaml.dump(core_config, default_flow_style=False, sort_keys=False)

    # Build model from modified config
    if verbose:
        print("Building model from YAML config...")
    core_model = build_module(copy.deepcopy(core_config))

    # Load weights from JIT model; strict=False because SRCoulomb/LR keys are
    # expected to mismatch during v1→v2 migration (filtered below)
    jit_sd = jit_model.state_dict()
    load_result = core_model.load_state_dict(jit_sd, strict=False)

    # Validate keys — warnings are printed unconditionally (not gated by verbose)
    real_missing, real_unexpected = validate_state_dict_keys(load_result.missing_keys, load_result.unexpected_keys)
    if real_missing:
        print(f"WARNING: Unexpected missing keys: {real_missing}")
    if real_unexpected:
        print(f"WARNING: Unexpected extra keys: {real_unexpected}")
    if not real_missing and not real_unexpected and verbose:
        print("Loaded weights successfully")

    # Convert atomic_shift to float64 to preserve SAE precision
    if hasattr(core_model, "outputs") and hasattr(core_model.outputs, "atomic_shift"):
        core_model.outputs.atomic_shift.double()
        atomic_shift_key = "outputs.atomic_shift.shifts.weight"
        if atomic_shift_key in jit_sd:
            # Re-copy weights after .double() so values are preserved at full precision
            core_model.outputs.atomic_shift.shifts.weight.data.copy_(jit_sd[atomic_shift_key])
        if verbose:
            print(" Atomic shift converted to float64")

    core_model.eval()

    # Create metadata
    needs_coulomb = coulomb_mode == "sr_embedded"
    metadata = {
        "format_version": 2,
        "model_yaml": core_yaml_str,
        "cutoff": cutoff,
        "needs_coulomb": needs_coulomb,
        "needs_dispersion": needs_dispersion,
        "coulomb_mode": coulomb_mode,
        "coulomb_sr_rc": coulomb_sr_rc if needs_coulomb else None,
        "coulomb_sr_envelope": coulomb_sr_envelope if needs_coulomb else None,
        "d3_params": d3_params if needs_dispersion else None,
        "has_embedded_lr": has_embedded_lr,
        "implemented_species": implemented_species,
    }

    # Save if output path provided (metadata + weights in a single torch file)
    if output_path is not None:
        save_data = {**metadata, "state_dict": core_model.state_dict()}
        torch.save(save_data, output_path)
        if verbose:
            print(f"\nSaved model to {output_path}")
            print(f" cutoff: {cutoff:.3f}")
            print(f" needs_coulomb: {needs_coulomb}")
            print(f" needs_dispersion: {needs_dispersion}")
            print(f" has_embedded_lr: {has_embedded_lr}")

    return core_model, metadata
|