aimnet 0.0.1__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
aimnet/nbops.py CHANGED
@@ -1,10 +1,8 @@
- from typing import Dict, Tuple
-
  import torch
  from torch import Tensor


- def set_nb_mode(data: Dict[str, Tensor]) -> Dict[str, Tensor]:
+ def set_nb_mode(data: dict[str, Tensor]) -> dict[str, Tensor]:
      """Logic to guess and set the neighbor model."""
      if "nbmat" in data:
          if data["nbmat"].ndim == 2:
@@ -18,12 +16,12 @@ def set_nb_mode(data: Dict[str, Tensor]) -> Dict[str, Tensor]:
      return data


- def get_nb_mode(data: Dict[str, Tensor]) -> int:
+ def get_nb_mode(data: dict[str, Tensor]) -> int:
      """Get the neighbor model."""
      return int(data["_nb_mode"].item())


- def calc_masks(data: Dict[str, Tensor]) -> Dict[str, Tensor]:
+ def calc_masks(data: dict[str, Tensor]) -> dict[str, Tensor]:
      """Calculate neighbor masks"""
      nb_mode = get_nb_mode(data)
      if nb_mode == 0:
@@ -45,9 +43,20 @@ def calc_masks(data: Dict[str, Tensor]) -> Dict[str, Tensor]:
          # padding must be the last atom
          data["mask_i"] = torch.zeros(data["numbers"].shape[0], device=data["numbers"].device, dtype=torch.bool)
          data["mask_i"][-1] = True
-         for suffix in ("", "_lr"):
-             if f"nbmat{suffix}" in data:
-                 data[f"mask_ij{suffix}"] = data[f"nbmat{suffix}"] == data["numbers"].shape[0] - 1
+         # Track processed arrays by their data pointer to avoid redundant mask calculations
+         processed: dict[int, str] = {}  # data_ptr -> mask_suffix
+         for suffix in ("", "_lr", "_coulomb", "_dftd3"):
+             nbmat_key = f"nbmat{suffix}"
+             if nbmat_key in data:
+                 if not torch.jit.is_scripting():
+                     # data_ptr() not supported in TorchScript
+                     ptr = data[nbmat_key].data_ptr()
+                     if ptr in processed:
+                         # Same array - reuse existing mask
+                         data[f"mask_ij{suffix}"] = data[f"mask_ij{processed[ptr]}"]
+                         continue
+                     processed[ptr] = suffix
+                 data[f"mask_ij{suffix}"] = data[nbmat_key] == data["numbers"].shape[0] - 1
          data["_input_padded"] = torch.tensor(True)
          data["mol_sizes"] = torch.bincount(data["mol_idx"])
          # last atom is padding
@@ -56,9 +65,20 @@ def calc_masks(data: Dict[str, Tensor]) -> Dict[str, Tensor]:
          data["mask_i"] = data["numbers"] == 0
          w = torch.where(data["mask_i"])
          pad_idx = w[0] * data["numbers"].shape[1] + w[1]
-         for suffix in ("", "_lr"):
-             if f"nbmat{suffix}" in data:
-                 data[f"mask_ij{suffix}"] = torch.isin(data[f"nbmat{suffix}"], pad_idx)
+         # Track processed arrays by their data pointer to avoid redundant mask calculations
+         processed: dict[int, str] = {}  # data_ptr -> mask_suffix
+         for suffix in ("", "_lr", "_coulomb", "_dftd3"):
+             nbmat_key = f"nbmat{suffix}"
+             if nbmat_key in data:
+                 if not torch.jit.is_scripting():
+                     # data_ptr() not supported in TorchScript
+                     ptr = data[nbmat_key].data_ptr()
+                     if ptr in processed:
+                         # Same array - reuse existing mask
+                         data[f"mask_ij{suffix}"] = data[f"mask_ij{processed[ptr]}"]
+                         continue
+                     processed[ptr] = suffix
+                 data[f"mask_ij{suffix}"] = torch.isin(data[nbmat_key], pad_idx)
          data["_input_padded"] = torch.tensor(True)
          data["mol_sizes"] = (~data["mask_i"]).sum(-1)
      else:
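
The deduplication added in both padding branches of calc_masks keys masks on Tensor.data_ptr(), so neighbor-list keys that alias the same tensor share one mask instead of recomputing it. A minimal standalone sketch of the idea with made-up tensors (not package code):

    import torch

    data = {"numbers": torch.tensor([6, 1, 1, 0])}   # last entry is the padding atom
    nbmat = torch.randint(0, 4, (4, 8))
    data["nbmat"] = nbmat
    data["nbmat_lr"] = nbmat               # same storage, different key
    data["nbmat_coulomb"] = nbmat.clone()  # distinct storage

    processed: dict[int, str] = {}
    for suffix in ("", "_lr", "_coulomb"):
        t = data[f"nbmat{suffix}"]
        if t.data_ptr() in processed:
            data[f"mask_ij{suffix}"] = data[f"mask_ij{processed[t.data_ptr()]}"]
            continue
        processed[t.data_ptr()] = suffix
        data[f"mask_ij{suffix}"] = t == data["numbers"].shape[0] - 1

    assert data["mask_ij_lr"] is data["mask_ij"]     # reused, not recomputed
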
@@ -69,7 +89,7 @@ def calc_masks(data: Dict[str, Tensor]) -> Dict[str, Tensor]:

  def mask_ij_(
      x: Tensor,
-     data: Dict[str, Tensor],
+     data: dict[str, Tensor],
      mask_value: float = 0.0,
      inplace: bool = True,
      suffix: str = "",
@@ -84,7 +104,7 @@ def mask_ij_(
      return x


- def mask_i_(x: Tensor, data: Dict[str, Tensor], mask_value: float = 0.0, inplace: bool = True) -> Tensor:
+ def mask_i_(x: Tensor, data: dict[str, Tensor], mask_value: float = 0.0, inplace: bool = True) -> Tensor:
      nb_mode = get_nb_mode(data)
      if nb_mode == 0:
          if data["_input_padded"].item():
@@ -110,7 +130,47 @@ def mask_i_(x: Tensor, data: Dict[str, Tensor], mask_value: float = 0.0, inplace
      return x


- def get_ij(x: Tensor, data: Dict[str, Tensor], suffix: str = "") -> Tuple[Tensor, Tensor]:
+ def resolve_suffix(data: dict[str, Tensor], suffixes: list[str]) -> str:
+     """Try suffixes in order, return first found, raise if none exist.
+
+     This function makes fallback behavior explicit by requiring a list
+     of acceptable suffixes. Each module controls which neighbor lists
+     are acceptable for its operations.
+
+     For nb_mode=0 (no neighbor matrix), returns empty string since
+     neighbor lists are not used in that mode.
+
+     Parameters
+     ----------
+     data : dict
+         Data dictionary containing neighbor matrices.
+     suffixes : list[str]
+         List of suffixes to try in priority order (e.g., ["_dftd3", "_lr"]).
+         Empty string "" can be included for fallback to base nbmat.
+
+     Returns
+     -------
+     str
+         The first suffix that has a corresponding nbmat{suffix} in data.
+
+     Raises
+     ------
+     KeyError
+         If none of the suffixes have corresponding neighbor matrices.
+     """
+     # In nb_mode=0, there are no neighbor matrices - suffix is unused
+     nb_mode = get_nb_mode(data)
+     if nb_mode == 0:
+         return ""
+
+     for suffix in suffixes:
+         if f"nbmat{suffix}" in data:
+             return suffix
+
+     raise KeyError(f"No neighbor matrix found for any suffix in {suffixes}")
+
+
+ def get_ij(x: Tensor, data: dict[str, Tensor], suffix: str = "") -> tuple[Tensor, Tensor]:
      nb_mode = get_nb_mode(data)
      if nb_mode == 0:
          x_i = x.unsqueeze(2)
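
A usage sketch for the new resolve_suffix helper, assuming aimnet 0.1.0 is importable; the data dict is a minimal stand-in (flat neighbor mode with only a long-range list present):

    import torch
    from aimnet import nbops

    data = {"_nb_mode": torch.tensor(1), "nbmat_lr": torch.zeros(5, 16, dtype=torch.long)}
    suffix = nbops.resolve_suffix(data, ["_dftd3", "_lr", ""])  # -> "_lr": no nbmat_dftd3, so it falls through
    nbmat = data[f"nbmat{suffix}"]
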
@@ -128,7 +188,36 @@ def get_ij(x: Tensor, data: Dict[str, Tensor], suffix: str = "") -> Tuple[Tensor
      return x_i, x_j


- def mol_sum(x: Tensor, data: Dict[str, Tensor]) -> Tensor:
+ def get_i(x: Tensor, data: dict[str, Tensor]) -> Tensor:
+     """Get the i-component of pairwise expansion without computing j.
+
+     This is an optimized version of get_ij when only x_i is needed,
+     avoiding the expensive index_select operation for x_j.
+
+     Parameters
+     ----------
+     x : Tensor
+         Input tensor to expand.
+     data : dict[str, Tensor]
+         Data dictionary containing neighbor mode information.
+
+     Returns
+     -------
+     Tensor
+         The i-component with appropriate unsqueeze for the neighbor mode.
+     """
+     nb_mode = get_nb_mode(data)
+     if nb_mode == 0:
+         return x.unsqueeze(2)
+     elif nb_mode == 1:
+         return x.unsqueeze(1)
+     elif nb_mode == 2:
+         return x.unsqueeze(2)
+     else:
+         raise ValueError(f"Invalid neighbor mode: {nb_mode}")
+
+
+ def mol_sum(x: Tensor, data: dict[str, Tensor]) -> Tensor:
      nb_mode = get_nb_mode(data)
      if nb_mode in (0, 2):
          res = x.sum(dim=1)
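
Callers that only need the broadcastable i-side can now skip the neighbor gather entirely, e.g. (illustrative shapes, flat mode):

    import torch
    from aimnet import nbops

    data = {"_nb_mode": torch.tensor(1)}
    x = torch.randn(5, 16)        # per-atom features
    x_i = nbops.get_i(x, data)    # shape (5, 1, 16): just an unsqueeze, no index_select
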
@@ -140,6 +229,7 @@ def mol_sum(x: Tensor, data: Dict[str, Tensor]) -> Tensor:
          idx = data["mol_idx"]
          # assuming mol_idx is sorted, replace with max if not
          out_size = int(idx[-1].item()) + 1
+
          if x.ndim == 1:
              res = torch.zeros(out_size, device=x.device, dtype=x.dtype)
          else:
aimnet/ops.py CHANGED
@@ -1,5 +1,4 @@
  import math
- from typing import Dict, Optional, Tuple

  import torch
  from torch import Tensor
@@ -7,7 +6,7 @@ from torch import Tensor
  from aimnet import nbops


- def lazy_calc_dij_lr(data: Dict[str, Tensor]) -> Dict[str, Tensor]:
+ def lazy_calc_dij_lr(data: dict[str, Tensor]) -> dict[str, Tensor]:
      if "d_ij_lr" not in data:
          nb_mode = nbops.get_nb_mode(data)
          if nb_mode == 0:
@@ -17,23 +16,67 @@ def lazy_calc_dij_lr(data: Dict[str, Tensor]) -> Dict[str, Tensor]:
      return data


- def calc_distances(data: Dict[str, Tensor], suffix: str = "", pad_value: float = 1.0) -> Tuple[Tensor, Tensor]:
+ def lazy_calc_dij(data: dict[str, Tensor], suffix: str) -> dict[str, Tensor]:
+     """Lazily calculate distances for a given suffix.
+
+     Computes and caches d_ij{suffix} in data dict if not present.
+     For nb_mode=0 (no neighbor list), reuses d_ij.
+
+     Parameters
+     ----------
+     data : dict
+         Data dictionary.
+     suffix : str
+         Suffix for neighbor matrix (e.g., "_coulomb", "_dftd3", "_lr").
+
+     Returns
+     -------
+     dict
+         Data dictionary with d_ij{suffix} added.
+     """
+     key = f"d_ij{suffix}"
+     if key not in data:
+         nb_mode = nbops.get_nb_mode(data)
+         if nb_mode == 0:
+             data[key] = data["d_ij"]
+         else:
+             data[key] = calc_distances(data, suffix=suffix)[0]
+     return data
+
+
+ def calc_distances(data: dict[str, Tensor], suffix: str = "", pad_value: float = 1.0) -> tuple[Tensor, Tensor]:
      coord_i, coord_j = nbops.get_ij(data["coord"], data, suffix)
      if f"shifts{suffix}" in data:
          assert "cell" in data, "cell is required if shifts are provided"
          nb_mode = nbops.get_nb_mode(data)
+         cell = data["cell"]
          if nb_mode == 2:
-             shifts = torch.einsum("bnmd,bdh->bnmh", data[f"shifts{suffix}"], data["cell"])
+             # Batched format: shifts (B, N, M, 3), cell (B, 3, 3) or (3, 3)
+             if cell.ndim == 2:
+                 shifts = torch.einsum("bnmd,dh->bnmh", data[f"shifts{suffix}"], cell)
+             else:
+                 shifts = torch.einsum("bnmd,bdh->bnmh", data[f"shifts{suffix}"], cell)
+         elif nb_mode == 1:
+             # Flat format: shifts (N_total, M, 3), cell (3, 3) or (B, 3, 3)
+             if cell.ndim == 2:
+                 shifts = data[f"shifts{suffix}"] @ cell
+             else:
+                 # Batched cells - need mol_idx to select correct cell for each atom
+                 mol_idx = data["mol_idx"]
+                 atom_cell = cell[mol_idx]  # (N_total, 3, 3)
+                 # shifts: (N_total, M, 3), atom_cell: (N_total, 3, 3)
+                 shifts = torch.einsum("nmd,ndh->nmh", data[f"shifts{suffix}"], atom_cell)
          else:
-             shifts = data[f"shifts{suffix}"] @ data["cell"]
+             # nb_mode == 0: no neighbor matrix, shouldn't have shifts
+             shifts = data[f"shifts{suffix}"] @ cell
          coord_j = coord_j + shifts
      r_ij = coord_j - coord_i
+     r_ij = nbops.mask_ij_(r_ij, data, mask_value=pad_value, inplace=False, suffix=suffix)
      d_ij = torch.norm(r_ij, p=2, dim=-1)
-     d_ij = nbops.mask_ij_(d_ij, data, mask_value=pad_value, inplace=False, suffix=suffix)
      return d_ij, r_ij


- def center_coordinates(coord: Tensor, data: Dict[str, Tensor], masses: Optional[Tensor] = None) -> Tensor:
+ def center_coordinates(coord: Tensor, data: dict[str, Tensor], masses: Tensor | None = None) -> Tensor:
      if masses is not None:
          masses = masses.unsqueeze(-1)
          center = nbops.mol_sum(coord * masses, data) / nbops.mol_sum(masses, data) / data["mol_sizes"].unsqueeze(-1)
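
The flat-mode (nb_mode == 1) branch above converts fractional lattice shifts to Cartesian ones using the cell of the molecule each atom belongs to. A standalone sketch of that indexing with arbitrary shapes (not package code):

    import torch

    n_total, m = 7, 12                               # atoms across the batch, neighbor slots per atom
    mol_idx = torch.tensor([0, 0, 0, 1, 1, 1, 1])    # molecule index of each atom
    cell = torch.rand(2, 3, 3)                       # one 3x3 cell per molecule
    shifts_frac = torch.randint(-1, 2, (n_total, m, 3)).float()

    atom_cell = cell[mol_idx]                        # (n_total, 3, 3): broadcast cells onto atoms
    shifts_cart = torch.einsum("nmd,ndh->nmh", shifts_frac, atom_cell)
    print(shifts_cart.shape)                         # torch.Size([7, 12, 3])
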
@@ -61,16 +104,17 @@ def exp_expand(d_ij: Tensor, shifts: Tensor, eta: float) -> Tensor:
      return torch.exp(-eta * (d_ij.unsqueeze(-1) - shifts) ** 2)


- # pylint: disable=invalid-name
  def nse(
      Q: Tensor,
      q_u: Tensor,
      f_u: Tensor,
-     data: Dict[str, Tensor],
+     data: dict[str, Tensor],
      epsilon: float = 1.0e-6,
  ) -> Tensor:
      # Q and q_u and f_u must have last dimension size 1 or 2
-     F_u = nbops.mol_sum(f_u, data) + epsilon
+     F_u = nbops.mol_sum(f_u, data)
+     if epsilon > 0:
+         F_u = F_u + epsilon
      Q_u = nbops.mol_sum(q_u, data)
      dQ = Q - Q_u
      # for loss
@@ -92,30 +136,36 @@ def nse(
      return q


- def coulomb_matrix_dsf(d_ij: Tensor, Rc: float, alpha: float, data: Dict[str, Tensor]) -> Tensor:
+ def coulomb_matrix_dsf(d_ij: Tensor, Rc: float, alpha: float, data: dict[str, Tensor]) -> Tensor:
      _c1 = (alpha * d_ij).erfc() / d_ij
      _c2 = math.erfc(alpha * Rc) / Rc
      _c3 = _c2 / Rc
      _c4 = 2 * alpha * math.exp(-((alpha * Rc) ** 2)) / (Rc * math.pi**0.5)
      J = _c1 - _c2 + (d_ij - Rc) * (_c3 + _c4)
-     # mask for d_ij > Rc
-     mask = data["mask_ij_lr"] & (d_ij > Rc)
+     # Zero invalid pairs: padding/diagonal (mask_ij_lr) OR beyond cutoff
+     mask = data["mask_ij_lr"] | (d_ij > Rc)
      J.masked_fill_(mask, 0.0)
      return J


- def coulomb_matrix_sf(q_j: Tensor, d_ij: Tensor, Rc: float, data: Dict[str, Tensor]) -> Tensor:
+ def coulomb_matrix_sf(q_j: Tensor, d_ij: Tensor, Rc: float, data: dict[str, Tensor]) -> Tensor:
      _c1 = 1.0 / d_ij
      _c2 = 1.0 / Rc
      _c3 = _c2 / Rc
      J = _c1 - _c2 + (d_ij - Rc) * _c3
-     mask = data["mask_ij_lr"] & (d_ij > Rc)
+     # Zero invalid pairs: padding/diagonal (mask_ij_lr) OR beyond cutoff
+     mask = data["mask_ij_lr"] | (d_ij > Rc)
      J.masked_fill_(mask, 0.0)
      return J


  def get_shifts_within_cutoff(cell: Tensor, cutoff: Tensor) -> Tensor:
-     assert cell.shape == (3, 3), "Batch cell is not supported"
+     """Get all lattice shift vectors within cutoff distance.
+
+     Note: Batched cells are not supported - this function is only used by Ewald summation
+     which is a single-molecule calculation.
+     """
+     assert cell.ndim == 2 and cell.shape == (3, 3), "Batched cells not supported for Ewald summation"
      cell_inv = torch.inverse(cell).mT
      inv_distances = cell_inv.norm(p=2, dim=-1)
      num_repeats = torch.ceil(cutoff * inv_distances).to(torch.long)
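
The switch from & to | in both Coulomb matrices changes which pairs are zeroed: mask_ij_lr flags padding/self pairs that must always be dropped, so an entry now survives only if it is neither padded nor beyond Rc. A small numeric illustration:

    import torch

    d_ij = torch.tensor([1.0, 4.5, 2.0, 6.0])
    pad = torch.tensor([False, False, True, True])   # stand-in for mask_ij_lr
    Rc = 5.0

    old_mask = pad & (d_ij > Rc)   # only zeros padded pairs that are ALSO beyond Rc
    new_mask = pad | (d_ij > Rc)   # zeros every padded pair and every pair beyond Rc
    print(old_mask.tolist())       # [False, False, False, True]
    print(new_mask.tolist())       # [False, False, True, True]
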
@@ -128,10 +178,32 @@ def get_shifts_within_cutoff(cell: Tensor, cutoff: Tensor) -> Tensor:
      return shifts


- def coulomb_matrix_ewald(coord: Tensor, cell: Tensor) -> Tensor:
+ def coulomb_matrix_ewald(coord: Tensor, cell: Tensor, accuracy: float = 1e-8) -> Tensor:
+     """Compute Coulomb matrix using Ewald summation.
+
+     Parameters
+     ----------
+     coord : Tensor
+         Atomic coordinates, shape (N, 3).
+     cell : Tensor
+         Unit cell vectors, shape (3, 3).
+     accuracy : float
+         Target accuracy for the Ewald summation. Controls the real-space
+         and reciprocal-space cutoffs. Lower values give higher accuracy
+         but require more computation. Default is 1e-8.
+
+     The cutoffs are computed as:
+     - eta = (V^2 / N)^(1/6) / sqrt(2*pi)
+     - cutoff_real = sqrt(-2 * ln(accuracy)) * eta
+     - cutoff_recip = sqrt(-2 * ln(accuracy)) / eta
+
+     Returns
+     -------
+     Tensor
+         Coulomb matrix J, shape (N, N).
+     """
      # single molecule implementation. nb_mode == 1
      assert coord.ndim == 2 and cell.ndim == 2, "Only single molecule is supported"
-     accuracy = 1e-8
      N = coord.shape[0]
      volume = torch.det(cell)
      eta = ((volume**2 / N) ** (1 / 6)) / math.sqrt(2.0 * math.pi)
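
The accuracy argument feeds the cutoff formulas quoted in the docstring; a quick numeric check with illustrative values (a 10 Å cubic cell, 64 atoms):

    import math

    volume, n_atoms, accuracy = 1000.0, 64, 1e-8
    eta = ((volume**2 / n_atoms) ** (1 / 6)) / math.sqrt(2.0 * math.pi)
    scale = math.sqrt(-2.0 * math.log(accuracy))
    cutoff_real = scale * eta    # real-space cutoff grows with eta
    cutoff_recip = scale / eta   # reciprocal-space cutoff shrinks with eta
    print(round(eta, 3), round(cutoff_real, 2), round(cutoff_recip, 2))  # 1.995 12.11 3.04
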
@@ -0,0 +1,226 @@
+ #!/usr/bin/env python3
+ """Export trained model to distributable state dict format.
+
+ This script creates a self-contained .pt file from training artifacts:
+ - Raw PyTorch weights (.pt)
+ - Self-atomic energies (.sae)
+ - Model YAML configuration
+
+ The output file contains:
+ - model_yaml: Core model config (without LRCoulomb/DFTD3, with SRCoulomb if needed)
+ - cutoff: Model cutoff
+ - needs_coulomb: Whether calculator should add external Coulomb
+ - needs_dispersion: Whether calculator should add external DFTD3
+ - coulomb_mode: "sr_embedded" | "none" (describes what's in the model)
+ - coulomb_sr_rc: Coulomb short-range cutoff (optional, if coulomb_mode="sr_embedded")
+ - coulomb_sr_envelope: Envelope function ("exp" or "cosine", optional)
+ - d3_params: D3 parameters {s8, a1, a2, s6} (optional, if needs_dispersion=True)
+ - has_embedded_lr: Whether model has embedded LR modules (D3TS, SRCoulomb) needing nbmat_lr
+ - implemented_species: Parametrized atomic numbers
+ - state_dict: Model weights with SAE baked into atomic_shift (float64)
+ """
+
+ import copy
+
+ import click
+ import torch
+ import yaml
+ from torch import nn
+
+ from aimnet.config import build_module, load_yaml
+ from aimnet.models.utils import strip_lr_modules_from_yaml, validate_state_dict_keys
+
+
+ def load_sae(sae_file: str) -> dict[int, float]:
+     """Load SAE file (YAML-like format: atomic_number: energy)."""
+     sae = load_yaml(sae_file)
+     if not isinstance(sae, dict):
+         raise TypeError("SAE file must contain a dictionary.")
+     return {int(k): float(v) for k, v in sae.items()}
+
+
+ def bake_sae_into_model(model: nn.Module, sae: dict[int, float]) -> nn.Module:
+     """Add SAE values to atomic_shift.shifts.weight (converted to float64)."""
+     # Disable gradients before in-place operation
+     for p in model.parameters():
+         p.requires_grad_(False)
+     model.outputs.atomic_shift.double()  # type: ignore
+     for k, v in sae.items():
+         model.outputs.atomic_shift.shifts.weight[k] += v  # type: ignore
+     return model
+
+
+ def extract_cutoff(model: nn.Module) -> float:
+     """Extract cutoff from model's AEV module."""
+     return float(model.aev.rc_s.item())  # type: ignore
+
+
+ def get_implemented_species(sae: dict[int, float]) -> list[int]:
+     """Get list of implemented species from SAE."""
+     return sorted(sae.keys())
+
+
+ def mask_not_implemented_species(model: nn.Module, species: list[int]) -> nn.Module:
+     """Set NaN for species not in the SAE."""
+     weight = model.afv.weight  # type: ignore
+     for i in range(1, weight.shape[0]):  # type: ignore
+         if i not in species:
+             weight[i, :] = torch.nan  # type: ignore
+     return model
+
+
+ @click.command()
+ @click.argument("weights", type=click.Path(exists=True))
+ @click.argument("output", type=str)
+ @click.option("--model", "-m", type=click.Path(exists=True), required=True, help="Path to model definition YAML file")
+ @click.option("--sae", "-s", type=click.Path(exists=True), required=True, help="Path to the SAE YAML file")
+ @click.option(
+     "--needs-coulomb/--no-coulomb", default=None, help="Override Coulomb detection. Default: auto-detect from YAML"
+ )
+ @click.option(
+     "--needs-dispersion/--no-dispersion",
+     default=None,
+     help="Override dispersion detection. Default: auto-detect from YAML",
+ )
+ def export_model(
+     weights: str,
+     output: str,
+     model: str,
+     sae: str,
+     needs_coulomb: bool | None,
+     needs_dispersion: bool | None,
+ ):
+     """Export trained model to distributable state dict format.
+
+     weights: Path to the raw PyTorch weights file (.pt).
+     output: Path to the output .pt file.
+
+     Example:
+         aimnet export weights.pt model.pt --model config.yaml --sae model.sae
+     """
+     # Load model YAML
+     print(f"Loading config from {model}")
+     with open(model, encoding="utf-8") as f:
+         model_config = yaml.safe_load(f)
+
+     # Load SAE
+     print(f"Loading SAE from {sae}")
+     sae_dict = load_sae(sae)
+     implemented_species = get_implemented_species(sae_dict)
+
+     # Load source state dict
+     print(f"Loading weights from {weights}")
+     source_sd = torch.load(weights, map_location="cpu", weights_only=True)
+
+     # Strip LR modules and detect flags
+     core_config, coulomb_mode, needs_dispersion_auto, d3_params, coulomb_sr_rc, coulomb_sr_envelope, disp_ptfile = (
+         strip_lr_modules_from_yaml(model_config, source_sd)
+     )
+
+     # Serialize YAML BEFORE building module (build_module mutates the dict)
+     core_yaml_str = yaml.dump(core_config, default_flow_style=False, sort_keys=False)
+
+     # Build model from modified config
+     print("Building model...")
+     core_model = build_module(copy.deepcopy(core_config))
+     if not isinstance(core_model, nn.Module):
+         raise TypeError("Built module is not an nn.Module")
+
+     # Load weights with strict=False (modules may differ)
+     load_result = core_model.load_state_dict(source_sd, strict=False)
+
+     # Check for unexpected missing/extra keys
+     real_missing, real_unexpected = validate_state_dict_keys(load_result.missing_keys, load_result.unexpected_keys)
+     if real_missing:
+         print(f"WARNING: Unexpected missing keys: {real_missing}")
+     if real_unexpected:
+         print(f"WARNING: Unexpected extra keys in source: {real_unexpected}")
+     if not real_missing and not real_unexpected:
+         print("Loaded weights successfully")
+
+     # Load dispersion parameters from ptfile and inject into model
+     # (raw training weights don't contain disp_param0 buffer)
+     if disp_ptfile is not None:
+         disp_params = torch.load(disp_ptfile, map_location="cpu", weights_only=True)
+         for _name, module in core_model.named_modules():
+             if hasattr(module, "disp_param0"):
+                 # Resize buffer if needed (ptfile may have different shape than placeholder)
+                 if module.disp_param0.shape != disp_params.shape:
+                     module.disp_param0 = torch.zeros_like(disp_params)
+                 module.disp_param0.copy_(disp_params)
+                 print(f"Loaded disp_param0 from {disp_ptfile}")
+                 break
+
+     # Bake SAE into atomic_shift (float64)
+     print("Baking SAE into atomic_shift...")
+     core_model = bake_sae_into_model(core_model, sae_dict)
+
+     # Mask not-implemented species
+     core_model = mask_not_implemented_species(core_model, implemented_species)
+
+     # Extract cutoff
+     cutoff = extract_cutoff(core_model)
+
+     # Set model to eval mode
+     core_model.eval()
+
+     # Determine final flags (CLI overrides auto-detection)
+     auto_needs_coulomb = coulomb_mode == "sr_embedded"
+     auto_needs_dispersion = needs_dispersion_auto
+
+     final_needs_coulomb = needs_coulomb if needs_coulomb is not None else auto_needs_coulomb
+     final_needs_dispersion = needs_dispersion if needs_dispersion is not None else auto_needs_dispersion
+
+     # Warn if overriding auto-detection
+     if needs_coulomb is not None and needs_coulomb != auto_needs_coulomb:
+         print(f" Overriding needs_coulomb: {auto_needs_coulomb} -> {needs_coulomb}")
+     if needs_dispersion is not None and needs_dispersion != auto_needs_dispersion:
+         print(f" Overriding needs_dispersion: {auto_needs_dispersion} -> {needs_dispersion}")
+
+     # Detect if model has any embedded LR modules that need nbmat_lr
+     outputs = model_config.get("kwargs", {}).get("outputs", {})
+     has_embedded_lr = False
+
+     # Check for embedded D3TS (uses NN-predicted C6/alpha, must stay embedded)
+     has_d3ts = any("D3TS" in outputs.get(k, {}).get("class", "") for k in ["dftd3", "d3bj", "d3ts"])
+     if has_d3ts:
+         has_embedded_lr = True
+
+     # Check for embedded SRCoulomb (model had LRCoulomb before conversion)
+     if coulomb_mode == "sr_embedded":
+         has_embedded_lr = True
+
+     # Create new format dict
+     new_format = {
+         "format_version": 2,  # v2 = new .pt format (v1 = legacy .jpt)
+         "model_yaml": core_yaml_str,
+         "cutoff": cutoff,
+         "needs_coulomb": final_needs_coulomb,
+         "needs_dispersion": final_needs_dispersion,
+         "coulomb_mode": coulomb_mode,
+         "coulomb_sr_rc": coulomb_sr_rc if final_needs_coulomb else None,
+         "coulomb_sr_envelope": coulomb_sr_envelope if final_needs_coulomb else None,
+         "d3_params": d3_params if final_needs_dispersion else None,
+         "has_embedded_lr": has_embedded_lr,
+         "implemented_species": implemented_species,
+         "state_dict": core_model.state_dict(),
+     }
+
+     # Save
+     torch.save(new_format, output)
+     print(f"\nSaved model to {output}")
+     print(f" cutoff: {cutoff}")
+     print(f" needs_coulomb: {final_needs_coulomb}")
+     print(f" needs_dispersion: {final_needs_dispersion}")
+     print(f" coulomb_mode: {coulomb_mode}")
+     if final_needs_coulomb:
+         print(f" coulomb_sr_rc: {coulomb_sr_rc}")
+         print(f" coulomb_sr_envelope: {coulomb_sr_envelope}")
+     if final_needs_dispersion:
+         print(f" d3_params: {d3_params}")
+     print(f" has_embedded_lr: {has_embedded_lr}")
+     print(f" implemented_species: {implemented_species}")
+
+
+ if __name__ == "__main__":
+     export_model()
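
A sketch of how downstream code might consume the exported v2 file; the field names come from the module docstring above, while the path and loading flags are illustrative:

    import torch
    import yaml

    blob = torch.load("model.pt", map_location="cpu", weights_only=True)  # plain dict of tensors and primitives
    assert blob["format_version"] == 2
    config = yaml.safe_load(blob["model_yaml"])
    print(blob["cutoff"], blob["needs_coulomb"], blob["needs_dispersion"], blob["coulomb_mode"])
    print(blob["implemented_species"])
    # blob["state_dict"] carries the weights with the SAE baked into atomic_shift
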
aimnet/train/loss.py CHANGED
@@ -1,5 +1,5 @@
  from functools import partial
- from typing import Any, Dict
+ from typing import Any

  import torch
  from torch import Tensor
@@ -30,7 +30,7 @@ class MTLoss:
          Dict[str, Tensor]: total loss under key 'loss' and values for individual components.
      """

-     def __init__(self, components: Dict[str, Any]):
+     def __init__(self, components: dict[str, Any]):
          w_sum = sum(c["weight"] for c in components.values())
          self.components = {}
          for name, c in components.items():
@@ -38,7 +38,7 @@ class MTLoss:
              fn = partial(get_module(c["fn"]), **kwargs)
              self.components[name] = (fn, c["weight"] / w_sum)

-     def __call__(self, y_pred: Dict[str, Tensor], y_true: Dict[str, Tensor]) -> Dict[str, Tensor]:
+     def __call__(self, y_pred: dict[str, Tensor], y_true: dict[str, Tensor]) -> dict[str, Tensor]:
          loss = {}
          for name, (fn, w) in self.components.items():
              _l = fn(y_pred=y_pred, y_true=y_true)
@@ -48,7 +48,7 @@ class MTLoss:
          return loss


- def mse_loss_fn(y_pred: Dict[str, Tensor], y_true: Dict[str, Tensor], key_pred: str, key_true: str) -> Tensor:
+ def mse_loss_fn(y_pred: dict[str, Tensor], y_true: dict[str, Tensor], key_pred: str, key_true: str) -> Tensor:
      """General MSE loss function"""
      x = y_true[key_true]
      y = y_pred[key_pred]
@@ -56,7 +56,7 @@ def mse_loss_fn(y_pred: Dict[str, Tensor], y_true: Dict[str, Tensor], key_pred:
      return loss


- def peratom_loss_fn(y_pred: Dict[str, Tensor], y_true: Dict[str, Tensor], key_pred: str, key_true: str) -> Tensor:
+ def peratom_loss_fn(y_pred: dict[str, Tensor], y_true: dict[str, Tensor], key_pred: str, key_true: str) -> Tensor:
      """MSE loss function with per-atom normalization correction.
      Suitable when some of the values are zero both in y_pred and y_true due to padding of inputs.
      """
@@ -73,11 +73,11 @@ def peratom_loss_fn(y_pred: Dict[str, Tensor], y_true: Dict[str, Tensor], key_pr


  def energy_loss_fn(
-     y_pred: Dict[str, Tensor], y_true: Dict[str, Tensor], key_pred: str = "energy", key_true: str = "energy"
+     y_pred: dict[str, Tensor], y_true: dict[str, Tensor], key_pred: str = "energy", key_true: str = "energy"
  ) -> Tensor:
      """MSE loss normalized by the number of atoms."""
      x = y_true[key_true]
      y = y_pred[key_pred]
-     s = y_pred["_natom"].sqrt()
+     s = y_pred["_natom"] ** 0.5
      loss = ((x - y).pow(2) / s).mean() if y_pred["_natom"].numel() > 1 else torch.nn.functional.mse_loss(x, y) / s
      return loss