aimnet 0.0.1__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
aimnet/models/utils.py ADDED
@@ -0,0 +1,735 @@
1
+ """Utility functions for model inspection and metadata extraction.
2
+
3
+ This module provides helper functions for:
4
+ - Recursive module traversal
5
+ - Extracting attributes from JIT-compiled models
6
+ - Detecting embedded Coulomb and dispersion modules
7
+ - Extracting D3 parameters and implemented species
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ import contextlib
13
+ from collections.abc import Iterator
14
+
15
+ import torch
16
+ from torch import nn
17
+
18
+
19
def named_children_rec(module: nn.Module) -> Iterator[tuple[str, nn.Module]]:
    """Depth-first, pre-order traversal of all descendant modules.

    Parameters
    ----------
    module : nn.Module
        Root module whose descendants are enumerated. Non-module inputs
        produce an empty iterator.

    Yields
    ------
    tuple[str, nn.Module]
        ``(name, child)`` pairs for every descendant, each parent yielded
        before its own children.
    """
    # Guard: anything that is not an nn.Module has no children to walk.
    if not isinstance(module, nn.Module):
        return
    for child_name, child_module in module.named_children():
        yield child_name, child_module
        yield from named_children_rec(child_module)
37
+
38
def get_jit_attr(module: nn.Module, attr: str, default: float) -> float:
    """Read a scalar attribute from a (possibly JIT-compiled) module as a float.

    TorchScript modules may expose scalar constants only through
    ``__getattr__`` rather than normal attribute lookup, so both access
    paths are tried before falling back to ``default``.

    Parameters
    ----------
    module : nn.Module
        Module (eager or scripted) to read from.
    attr : str
        Name of the attribute.
    default : float
        Value returned when the attribute is absent or not convertible.

    Returns
    -------
    float
        The attribute value converted to a Python float, or ``default``.
    """
    value = None

    # Normal attribute lookup; any TorchScript-side failure is treated as "absent".
    try:
        value = getattr(module, attr, None)
    except Exception:
        value = None

    # Fallback for scripted modules that route constants through __getattr__.
    if value is None:
        try:
            value = module.__getattr__(attr)
        except (AttributeError, RuntimeError):
            value = None

    if value is None:
        return default

    # 0-dim tensors / numpy scalars expose .item(); plain numbers convert directly.
    if hasattr(value, "item"):
        return float(value.item())
    if isinstance(value, (int, float)) or hasattr(value, "__float__"):
        return float(value)

    # Present but not a number-like object: keep the caller's default.
    return default
81
+
82
def has_dispersion(model: nn.Module) -> bool:
    """Check whether any dispersion module (DFTD3, D3BJ, or D3TS) is embedded.

    .. deprecated::
        Prefer the model metadata. This helper walks all model children,
        which is slow and unreliable for JIT models.

    Every dispersion variant needs ``nbmat_lr`` for neighbor calculations,
    whether it uses tabulated (DFTD3/D3BJ) or learned (D3TS) parameters.

    Parameters
    ----------
    model : nn.Module
        Model to inspect.

    Returns
    -------
    bool
        True if a dispersion module is present.
    """
    import warnings

    warnings.warn(
        "has_dispersion() is deprecated. Use model._metadata instead.",
        DeprecationWarning,
        stacklevel=2,
    )
    dispersion_names = {"dftd3", "d3bj", "d3ts"}
    for child_name, _child in named_children_rec(model):
        if child_name in dispersion_names:
            return True
    return False
111
+
112
def has_externalizable_dftd3(model: nn.Module) -> bool:
    """Check whether the model embeds a DFTD3/D3BJ module that can be externalized.

    D3TS relies on NN-learned parameters and must stay embedded; only the
    tabulated-parameter variants (DFTD3/D3BJ) qualify here.

    Parameters
    ----------
    model : nn.Module
        Model to inspect.

    Returns
    -------
    bool
        True if a DFTD3 or D3BJ child module is present.
    """
    for child_name, _child in named_children_rec(model):
        if child_name in ("dftd3", "d3bj"):
            return True
    return False
130
+
131
def has_d3ts(model: nn.Module) -> bool:
    """Check whether the model embeds a D3TS module (learned dispersion parameters).

    D3TS draws its parameters from the NN and therefore must remain embedded.

    Parameters
    ----------
    model : nn.Module
        Model to inspect.

    Returns
    -------
    bool
        True if a child module named ``d3ts`` is present.
    """
    for child_name, _child in named_children_rec(model):
        if child_name == "d3ts":
            return True
    return False
148
+
149
def has_lrcoulomb(model: nn.Module) -> bool:
    """Check whether the model embeds an LRCoulomb module.

    Parameters
    ----------
    model : nn.Module
        Model to inspect.

    Returns
    -------
    bool
        True if a child module named ``lrcoulomb`` is present.
    """
    for child_name, _child in named_children_rec(model):
        if child_name == "lrcoulomb":
            return True
    return False
164
+
165
def iter_lrcoulomb_mods(model: nn.Module) -> Iterator[nn.Module]:
    """Yield every LRCoulomb module embedded in the model.

    .. deprecated::
        Prefer the model metadata. This helper walks all model children,
        which is slow and unreliable for JIT models.

    Parameters
    ----------
    model : nn.Module
        Model to search.

    Yields
    ------
    nn.Module
        Each child module named ``lrcoulomb``.
    """
    import warnings

    warnings.warn(
        "iter_lrcoulomb_mods() is deprecated. Use model._metadata instead.",
        DeprecationWarning,
        stacklevel=2,
    )
    yield from (
        child for child_name, child in named_children_rec(model) if child_name == "lrcoulomb"
    )
193
+
194
def extract_d3_params(model: nn.Module) -> dict[str, float] | None:
    """Extract D3 damping parameters from the model's DFTD3/D3BJ module.

    Only tabulated-parameter modules (DFTD3/D3BJ) are considered; D3TS is
    skipped because its parameters are learned. Attribute access goes
    through :func:`get_jit_attr` to cope with TorchScript constants.

    Parameters
    ----------
    model : nn.Module
        Model to read from.

    Returns
    -------
    dict or None
        ``{"s8", "a1", "a2", "s6"}`` parameter dict from the first matching
        module, or None when no DFTD3/D3BJ module is found.
    """
    # Defaults mirror the usual D3(BJ) convention: s6 defaults to 1, others to 0.
    attr_defaults = (("s8", 0.0), ("a1", 0.0), ("a2", 0.0), ("s6", 1.0))
    for child_name, child in named_children_rec(model):
        if child_name not in ("dftd3", "d3bj"):  # d3ts excluded: learned params
            continue
        return {attr: get_jit_attr(child, attr, fallback) for attr, fallback in attr_defaults}
    return None
221
+
222
def extract_coulomb_rc(model: nn.Module) -> float:
    """Extract the Coulomb short-range cutoff (``rc``) from the LRCoulomb module.

    Parameters
    ----------
    model : nn.Module
        Model to read from.

    Returns
    -------
    float
        The cutoff of the first LRCoulomb module found.

    Raises
    ------
    KeyError
        If no LRCoulomb module exists, or it lacks an ``rc`` attribute.
    """
    for child_name, child in named_children_rec(model):
        if child_name != "lrcoulomb":
            continue
        rc_value = getattr(child, "rc", None)
        if rc_value is None:
            # First LRCoulomb module decides; a missing rc is an error, not a fallthrough.
            raise KeyError("LRCoulomb module found but 'rc' attribute is missing")
        # rc may be stored as a 0-dim tensor or a plain number.
        return float(rc_value.item()) if hasattr(rc_value, "item") else float(rc_value)
    raise KeyError("No LRCoulomb module found in model")
248
+
249
def extract_species(model: nn.Module) -> list[int]:
    """List implemented species via non-NaN rows of ``afv.weight``.

    A row of the atomic-feature-vector embedding that is entirely NaN marks
    an element the model does not implement; index 0 is never reported.

    Parameters
    ----------
    model : nn.Module
        Model whose state dict is inspected.

    Returns
    -------
    list[int]
        Ascending atomic numbers with at least one finite entry in their
        ``afv.weight`` row; empty list if the key is absent.
    """
    weights = model.state_dict().get("afv.weight")
    if weights is None:
        return []
    # Row index == atomic number; an all-NaN row means "not implemented".
    return [z for z in range(1, weights.shape[0]) if not torch.isnan(weights[z]).all()]
275
+
276
def has_d3ts_in_config(config: dict) -> bool:
    """Check whether the YAML config declares a D3TS output module.

    Parameters
    ----------
    config : dict
        Model YAML configuration dictionary.

    Returns
    -------
    bool
        True if ``d3ts`` appears under ``kwargs.outputs``.
    """
    return "d3ts" in config.get("kwargs", {}).get("outputs", {})
292
+
293
def has_dftd3_in_config(config: dict) -> bool:
    """Check whether the YAML config declares a DFTD3 or D3BJ output module.

    Parameters
    ----------
    config : dict
        Model YAML configuration dictionary.

    Returns
    -------
    bool
        True if ``dftd3`` or ``d3bj`` appears under ``kwargs.outputs``.
    """
    declared_outputs = config.get("kwargs", {}).get("outputs", {})
    return any(key in declared_outputs for key in ("dftd3", "d3bj"))
309
+
310
+ # --- State dict key validation ---
311
+
312
+
313
def validate_state_dict_keys(
    missing_keys: list[str],
    unexpected_keys: list[str],
) -> tuple[list[str], list[str]]:
    """Filter out key mismatches that are expected during format migration.

    During v1→v2 conversion, SRCoulomb keys are legitimately missing (the
    module is newly added) and LRCoulomb/DFTD3/D3BJ keys are legitimately
    unexpected (those modules are removed). Anything else signals a real
    load problem.

    Parameters
    ----------
    missing_keys : list[str]
        Keys absent from the loaded state dict.
    unexpected_keys : list[str]
        Keys present in the state dict but not in the model.

    Returns
    -------
    tuple[list[str], list[str]]
        ``(real_missing, real_unexpected)`` — only the genuinely problematic keys.
    """
    # str.startswith accepts a tuple: one C-level call per key.
    benign_missing = ("outputs.srcoulomb.",)
    benign_unexpected = (
        "outputs.lrcoulomb.",
        "outputs.dftd3.",
        "outputs.d3bj.",
    )
    real_missing = [key for key in missing_keys if not key.startswith(benign_missing)]
    real_unexpected = [key for key in unexpected_keys if not key.startswith(benign_unexpected)]
    return real_missing, real_unexpected
353
+
354
+ # --- YAML config manipulation ---
355
+
356
+
357
def strip_lr_modules_from_yaml(
    config: dict,
    source: dict | nn.Module,
) -> tuple[dict, str, bool, dict[str, float] | None, float | None, str, str | None]:
    """Remove LRCoulomb and DFTD3 from YAML config, add SRCoulomb.

    This is the unified function for both export (from state dict) and
    convert (from JIT model) paths.

    Parameters
    ----------
    config : dict
        Model YAML configuration dictionary. The input is not mutated;
        a deep copy is modified and returned.
    source : dict | nn.Module
        Either a state dict (for export path) or a JIT model (for convert path).
        Used to extract metadata like Coulomb rc and D3 params.

    Returns
    -------
    tuple
        (config, coulomb_mode, needs_dispersion, d3_params, coulomb_sr_rc, coulomb_sr_envelope, disp_ptfile):
        - config: Modified config with LR modules removed and SRCoulomb added
        - coulomb_mode: "sr_embedded" if LRCoulomb was present, else "none"
        - needs_dispersion: True if DFTD3/D3BJ was present
        - d3_params: D3 parameters dict or None
        - coulomb_sr_rc: Short-range Coulomb cutoff or None
        - coulomb_sr_envelope: Envelope function ("exp" or "cosine")
        - disp_ptfile: Path to DispParam ptfile (if any) for loading buffer

    Raises
    ------
    ValueError
        If model has both D3TS and DFTD3/D3BJ (double dispersion).
        If LRCoulomb is present but rc cannot be determined.

    Notes
    -----
    SRCoulomb is added to outputs only when LRCoulomb was present in the
    original config. This ensures proper energy accounting when the
    calculator adds external LRCoulomb.

    A config without a top-level ``kwargs`` key is tolerated (it is created
    on write-back), matching the lenient ``.get("kwargs", {})`` reads.
    """
    import copy

    # Never mutate the caller's config.
    config = copy.deepcopy(config)
    outputs = config.get("kwargs", {}).get("outputs", {})

    # Determine source type: JIT model (convert path) vs state dict (export path).
    is_jit_model = isinstance(source, nn.Module)

    # --- Detect Coulomb ---
    if is_jit_model:
        has_coulomb = has_lrcoulomb(source)
        coulomb_sr_rc = extract_coulomb_rc(source) if has_coulomb else None
        # Legacy models always used exp envelope
        coulomb_sr_envelope = "exp"
    else:
        # State dict path - check YAML config first, then state dict
        has_coulomb_in_sd = any(k.startswith("outputs.lrcoulomb") for k in source)
        if "lrcoulomb" in outputs:
            has_coulomb = True
            lrc_config = outputs["lrcoulomb"]
            lrc_kwargs = lrc_config.get("kwargs", {})
            rc_value = lrc_kwargs.get("rc")
            coulomb_sr_rc = float(rc_value) if rc_value is not None else None
            coulomb_sr_envelope = lrc_kwargs.get("envelope", "exp")
        elif has_coulomb_in_sd:
            has_coulomb = True
            rc_key = "outputs.lrcoulomb.rc"
            coulomb_sr_rc = float(source[rc_key].item()) if rc_key in source else None
            coulomb_sr_envelope = "exp"  # Cannot extract from state dict
        else:
            has_coulomb = False
            coulomb_sr_rc = None
            coulomb_sr_envelope = "exp"

    # Validate: if Coulomb is needed, rc must be determinable
    if has_coulomb and coulomb_sr_rc is None:
        raise ValueError(
            "Model requires Coulomb but 'rc' could not be determined from YAML config or source. "
            "Please specify 'rc' explicitly in the LRCoulomb config kwargs."
        )

    # --- Detect Dispersion ---
    if is_jit_model:
        # Check if model has dftd3/d3bj modules
        has_d3_module = any(name in ("dftd3", "d3bj") for name, _ in named_children_rec(source))

        # Check YAML to determine if it's D3TS (not externalizable).
        # D3TS uses NN-predicted C6/alpha and must stay embedded.
        is_d3ts = False
        if has_d3_module:
            for key in ["dftd3", "d3bj"]:
                if key in outputs:
                    d3_class = outputs[key].get("class", "")
                    if "D3TS" in d3_class:
                        is_d3ts = True
                    # First declared key decides the class.
                    break

        # Only externalize if NOT D3TS (DFTD3/D3BJ with tabulated params can be externalized)
        needs_dispersion = has_d3_module and not is_d3ts

        if needs_dispersion:
            # Try to extract from JIT model first
            d3_params = extract_d3_params(source)
            # If extraction failed or returned all-zero damping params, fall back to YAML config
            if d3_params is None or (
                d3_params.get("s8") == 0.0 and d3_params.get("a1") == 0.0 and d3_params.get("a2") == 0.0
            ):
                for key in ["dftd3", "d3bj"]:
                    if key in outputs:
                        d3_config = outputs[key]
                        d3_kwargs = d3_config.get("kwargs", {})
                        d3_params = {
                            "s8": d3_kwargs.get("s8", 0.0),
                            "a1": d3_kwargs.get("a1", 0.0),
                            "a2": d3_kwargs.get("a2", 0.0),
                            "s6": d3_kwargs.get("s6", 1.0),
                        }
                        break
        else:
            d3_params = None
    else:
        # State dict path - dispersion info comes from the YAML config only
        needs_dispersion = False
        d3_params = None
        for key in ["dftd3", "d3bj"]:
            if key in outputs:
                d3_config = outputs[key]
                # Check if it's D3TS (must stay embedded, not externalizable)
                module_class = d3_config.get("class", "")
                if "D3TS" in module_class:
                    # D3TS uses NN-predicted C6/alpha, must stay embedded
                    needs_dispersion = False
                    d3_params = None
                    break
                # DFTD3/D3BJ with tabulated params can be externalized
                needs_dispersion = True
                d3_kwargs = d3_config.get("kwargs", {})
                d3_params = {
                    "s8": d3_kwargs.get("s8", 0.0),
                    "a1": d3_kwargs.get("a1", 0.0),
                    "a2": d3_kwargs.get("a2", 0.0),
                    "s6": d3_kwargs.get("s6", 1.0),
                }
                break

    # Validate: D3TS + DFTD3/D3BJ is invalid (would cause double dispersion)
    has_d3ts_model = has_d3ts(source) if is_jit_model else False
    if needs_dispersion and (has_d3ts_model or has_d3ts_in_config(config)):
        raise ValueError(
            "Model has both D3TS (learned) and DFTD3/D3BJ (tabulated) dispersion. "
            "D3TS uses learned parameters and must stay embedded, while DFTD3/D3BJ "
            "would be externalized. This configuration leads to double dispersion "
            "correction. Remove either D3TS or DFTD3/D3BJ from the model."
        )

    # --- Rebuild outputs dict ---
    new_outputs = {}
    for key, value in outputs.items():
        if key == "lrcoulomb":
            pass  # Will be added externally by calculator
        elif key in ["dftd3", "d3bj"]:
            # Check if it's D3TS (must stay embedded)
            module_class = value.get("class", "")
            if "D3TS" in module_class:
                new_outputs[key] = value  # Keep D3TS embedded
            else:
                pass  # Remove DFTD3/D3BJ for externalization
        else:
            new_outputs[key] = value

    # Strip ptfile from DispParam configs but save the path
    # (raw training weights don't contain disp_param0 buffer, need to load from ptfile)
    disp_ptfile: str | None = None
    for _key, value in new_outputs.items():
        if isinstance(value, dict):
            module_class = value.get("class", "")
            if "DispParam" in module_class:
                kwargs = value.get("kwargs", {})
                if "ptfile" in kwargs:
                    disp_ptfile = kwargs.pop("ptfile")  # Save before removing

    # Add SRCoulomb if LRCoulomb was present
    if has_coulomb:
        new_outputs["srcoulomb"] = {
            "class": "aimnet.modules.SRCoulomb",
            "kwargs": {
                "rc": coulomb_sr_rc,
                "key_in": "charges",
                "key_out": "energy",
                "envelope": coulomb_sr_envelope,
            },
        }

    # FIX: all reads above tolerate a missing "kwargs" key via .get(); the
    # write-back must tolerate it too, otherwise a config without "kwargs"
    # raised KeyError here. setdefault creates the section on demand.
    config.setdefault("kwargs", {})["outputs"] = new_outputs
    coulomb_mode = "sr_embedded" if has_coulomb else "none"

    return (
        config,
        coulomb_mode,
        needs_dispersion,
        d3_params,
        coulomb_sr_rc,
        coulomb_sr_envelope if coulomb_sr_envelope else "exp",
        disp_ptfile,
    )
564
+
565
+ # --- Model loading ---
566
+
567
+
568
def load_v1_model(
    jpt_path: str,
    yaml_config_path: str,
    output_path: str | None = None,
    verbose: bool = True,
) -> tuple[nn.Module, dict]:
    """Load legacy JIT model (v1) and convert to v2 format.

    This is the primary entry point for loading legacy models.

    Parameters
    ----------
    jpt_path : str
        Path to the input JIT-compiled model file (.jpt).
    yaml_config_path : str
        Path to the model YAML configuration file.
    output_path : str, optional
        If provided, save the converted model (metadata + state dict) to this path.
    verbose : bool
        Whether to print progress messages.

    Returns
    -------
    model : nn.Module
        The loaded model in v2 format.
    metadata : dict
        Model metadata dictionary with keys:
        - format_version: 2
        - model_yaml: str (YAML dump of the converted config)
        - cutoff: float
        - needs_coulomb: bool
        - needs_dispersion: bool
        - coulomb_mode: str
        - coulomb_sr_rc: float | None
        - coulomb_sr_envelope: str | None
        - d3_params: dict | None
        - has_embedded_lr: bool
        - implemented_species: list[int]

    Example
    -------
    >>> from aimnet.models.utils import load_v1_model
    >>> model, metadata = load_v1_model("model.jpt", "config.yaml")
    >>> print(metadata["format_version"])  # 2
    """
    # Local imports keep heavy/optional dependencies off the module import path.
    import copy

    import torch
    import yaml

    from aimnet.config import build_module

    # Load YAML config
    with open(yaml_config_path, encoding="utf-8") as f:
        model_config = yaml.safe_load(f)

    # Load JIT model (CPU map_location so conversion works without a GPU)
    if verbose:
        print(f"Loading JIT model from {jpt_path}")
    jit_model = torch.jit.load(jpt_path, map_location="cpu")

    # Extract metadata from JIT
    # NOTE(review): assumes every v1 JIT model exposes a scalar `cutoff`
    # attribute — confirm against the legacy model format.
    cutoff = float(jit_model.cutoff)
    implemented_species = extract_species(jit_model)

    # Strip LR modules from YAML and add SRCoulomb
    # Note: disp_ptfile is unused here because JIT model already has disp_param0 in its state dict
    core_config, coulomb_mode, needs_dispersion, d3_params, coulomb_sr_rc, coulomb_sr_envelope, _disp_ptfile = (
        strip_lr_modules_from_yaml(model_config, jit_model)
    )

    # Inform user about dispersion handling (reporting only; no behavior change)
    if verbose:
        if needs_dispersion:
            # External dispersion (DFTD3/D3BJ with tabulated params)
            if d3_params is None:
                print("WARNING: Model has DFTD3 module but D3 params extraction failed!")
            elif d3_params.get("s8") == 0.0 and d3_params.get("a1") == 0.0 and d3_params.get("a2") == 0.0:
                # All-zero damping parameters almost certainly mean extraction failed.
                print("WARNING: D3 params appear to be all zeros - extraction may have failed!")
                print(f" Extracted: {d3_params}")
            else:
                print(
                    f" D3 parameters: s6={d3_params['s6']}, s8={d3_params['s8']}, "
                    f"a1={d3_params['a1']}, a2={d3_params['a2']}"
                )
        else:
            # Check if D3TS is embedded
            outputs = model_config.get("kwargs", {}).get("outputs", {})
            # NOTE: local name shadows the module-level has_d3ts() helper;
            # the helper is not called below, so this is safe but confusing.
            has_d3ts = any("D3TS" in outputs.get(k, {}).get("class", "") for k in ["dftd3", "d3bj", "d3ts"])
            if has_d3ts:
                print(" D3TS dispersion kept embedded (uses NN-predicted C6/alpha)")

    # Detect if model has any embedded LR modules that need nbmat_lr
    # (inspects the ORIGINAL config, before LR stripping)
    outputs = model_config.get("kwargs", {}).get("outputs", {})
    has_embedded_lr = False

    # Check for embedded D3TS (class-name match covers any of the three output keys)
    has_d3ts = any("D3TS" in outputs.get(k, {}).get("class", "") for k in ["dftd3", "d3bj", "d3ts"])
    if has_d3ts:
        has_embedded_lr = True

    # Check for embedded SRCoulomb (model had LRCoulomb before conversion)
    if coulomb_mode == "sr_embedded":
        has_embedded_lr = True

    # Convert config to YAML string (stored in metadata for reproducibility)
    core_yaml_str = yaml.dump(core_config, default_flow_style=False, sort_keys=False)

    # Build model from modified config
    if verbose:
        print("Building model from YAML config...")
    # deepcopy: build_module may mutate the config it receives
    core_model = build_module(copy.deepcopy(core_config))

    # Load weights from JIT model; strict=False because LR keys were stripped
    jit_sd = jit_model.state_dict()
    load_result = core_model.load_state_dict(jit_sd, strict=False)

    # Validate keys: filter out the mismatches expected from LR-module stripping
    real_missing, real_unexpected = validate_state_dict_keys(load_result.missing_keys, load_result.unexpected_keys)
    if real_missing:
        print(f"WARNING: Unexpected missing keys: {real_missing}")
    if real_unexpected:
        print(f"WARNING: Unexpected extra keys: {real_unexpected}")
    if not real_missing and not real_unexpected and verbose:
        print("Loaded weights successfully")

    # Convert atomic_shift to float64 to preserve SAE precision
    if hasattr(core_model, "outputs") and hasattr(core_model.outputs, "atomic_shift"):
        core_model.outputs.atomic_shift.double()
        atomic_shift_key = "outputs.atomic_shift.shifts.weight"
        if atomic_shift_key in jit_sd:
            # Re-copy the weights: .double() above only changed the dtype in place,
            # this restores the exact source values at full precision.
            core_model.outputs.atomic_shift.shifts.weight.data.copy_(jit_sd[atomic_shift_key])
        if verbose:
            print(" Atomic shift converted to float64")

    core_model.eval()

    # Create metadata
    needs_coulomb = coulomb_mode == "sr_embedded"
    metadata = {
        "format_version": 2,
        "model_yaml": core_yaml_str,
        "cutoff": cutoff,
        "needs_coulomb": needs_coulomb,
        "needs_dispersion": needs_dispersion,
        "coulomb_mode": coulomb_mode,
        "coulomb_sr_rc": coulomb_sr_rc if needs_coulomb else None,
        "coulomb_sr_envelope": coulomb_sr_envelope if needs_coulomb else None,
        "d3_params": d3_params if needs_dispersion else None,
        "has_embedded_lr": has_embedded_lr,
        "implemented_species": implemented_species,
    }

    # Save if output path provided
    if output_path is not None:
        save_data = {**metadata, "state_dict": core_model.state_dict()}
        torch.save(save_data, output_path)
        if verbose:
            print(f"\nSaved model to {output_path}")
            print(f" cutoff: {cutoff:.3f}")
            print(f" needs_coulomb: {needs_coulomb}")
            print(f" needs_dispersion: {needs_dispersion}")
            print(f" has_embedded_lr: {has_embedded_lr}")

    return core_model, metadata