lineshape_tools 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
lineshape_tools/cli.py ADDED
@@ -0,0 +1,965 @@
1
+ """Contains functionality for the command line interface."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import logging
6
+ import sys
7
+ import warnings
8
+ from importlib.util import find_spec
9
+ from pathlib import Path
10
+ from typing import TYPE_CHECKING, Annotated
11
+
12
+ import ase.io
13
+ import numpy as np
14
+ from cyclopts import Parameter
15
+ from tqdm import tqdm
16
+
17
+ from lineshape_tools.constants import omega2eV
18
+ from lineshape_tools.lineshape import convert_A_to_L, gaussian, get_phonon_spec_func
19
+ from lineshape_tools.phonon import get_disp_vect, get_ipr, get_phonons
20
+ from lineshape_tools.plot import plot_spec_funcs
21
+
22
+ if TYPE_CHECKING:
23
+ from ase.atoms import Atoms
24
+ from mace.calculators import MACECalculator
25
+
26
+ logger = logging.getLogger(__name__)
27
+
28
+
29
def collect(
    files: Annotated[list[Path], Parameter(negative="")],
    output_file: Path = Path("./database.extxyz"),
    strategy: str = "none",
    read_index: str = ":",
    max_force: float = 2.0,
    min_force: float = -np.inf,
    dx_tol: float = 0.1,
    rtol: float = 1e-5,
    config_weight: float = 1.0,
    force_weighting: bool = False,
) -> None:
    """Collect and process data for fine-tuning.

    Collect files into an extxyz database that can be used for fine-tuning. An optional filtering
    strategy can be applied. This is potentially useful if relaxation data is being included in the
    dataset, as it can be noisy from having multiple closely spaced geometries close to the
    equilibrium geometry. By default, configurations with too large forces will be thrown away to
    avoid potential anharmonic contributions to the PES.

    Args:
        files (list): list of paths to files that are parseable by ase.io. The files should contain
            atomic geometries, total energies, and forces at a minimum (for example, vasprun.xml).
        output_file (Path): path where data is written. The file is opened in append mode, so
            an existing database is extended rather than overwritten.
        strategy (str): optional specification of strategy to be used for filtering. Available
            options are 'none', 'qr', or 'dx'.
        read_index (str): pythonic index passed to ase.io.read to determine which structures are
            read from the input files (the same value is used for each file). The default ":" reads
            all of the structures, while ":3" would read the first three for example. A scalar
            index such as "-1" is also accepted.
        max_force (float): remove structures where the maximum force acting on any atom is above
            the specified value (in eV/Å)
        min_force (float): remove structures where the maximum force acting on any atom is below
            the specified value (in eV/Å)
        dx_tol (float): tolerance (in Å) for how far atoms must move to accept configuration in
            'dx' filtering strategy
        rtol (float): tolerance ratio for determining rank of displacement vectors in 'qr' strategy
        config_weight (float): set the configuration weight for training
        force_weighting (bool): store a config_weight that's inversely proportional to the max
            force that any atom feels in the configuration [min(0.02 / max_fpa, 1)]. Overwrites
            the value specified by config_weight.

    Raises:
        ValueError: if ``strategy`` is not one of the supported options.
    """
    strategy = strategy.lower()
    if strategy not in ("none", "qr", "dx"):
        raise ValueError("invalid strategy choice")

    if output_file.exists():
        logger.info(f"{output_file} already exists, appending to it")

    total_read = 0
    all_atoms = []
    for fname in tqdm(files, desc="[*] reading files", disable=len(files) < 2):
        read_atoms = ase.io.read(fname, read_index)
        # a scalar index makes ase return a single Atoms object instead of a list; wrap it so
        # that len() and iteration below act on configurations rather than individual atoms
        if not isinstance(read_atoms, list):
            read_atoms = [read_atoms]
        total_read += len(read_atoms)

        for atoms in tqdm(read_atoms, desc="[*] processing atoms", disable=len(read_atoms) < 10):
            forces = atoms.get_forces()
            max_fpa = np.linalg.norm(forces, axis=1).max()
            if not (min_force < max_fpa < max_force):
                continue

            # store reference data under the REF_* keys before detaching the calculator
            atoms.info["REF_energy"] = atoms.get_potential_energy()
            atoms.new_array("REF_forces", forces)
            atoms.calc = None

            atoms.info["config_weight"] = config_weight
            if force_weighting:
                atoms.info["config_weight"] = min(0.02 / max_fpa, 1)
            if (cw := atoms.info["config_weight"]) < 0.02:
                logger.warning(f"small config weight found ({cw})")
            all_atoms.append(atoms)

    logger.info(
        f"read {total_read} configurations and discarded {total_read - len(all_atoms)} based"
        " on forces criteria"
    )

    filtered_atoms = []
    if not all_atoms:
        # nothing survived the force criteria, so there is nothing to filter or to index into
        logger.warning("no configurations left after applying the forces criteria")
    elif strategy == "qr":
        from scipy.linalg import qr

        logger.warning("qr strategy assumes the last read structure is the equilibrium geometry")
        reference = all_atoms[-1]
        if np.linalg.norm(reference.arrays["REF_forces"], axis=1).max() > 0.02:
            warnings.warn("large forces found in last structure", stacklevel=0)

        if len(all_atoms) == 1:
            # only the reference structure is available, no displacements to factorize
            filtered_atoms = [reference]
        else:
            sqrt_mass = np.repeat(np.sqrt(reference.get_masses()), 3)

            # columns of dq are the mass-weighted displacements relative to the reference
            dq = np.zeros((sqrt_mass.shape[0], len(all_atoms) - 1), dtype=np.float64)
            for atoms_i, atoms in enumerate(tqdm(all_atoms[:-1], desc="[*] computing dq")):
                dx = get_disp_vect(atoms, reference)
                dq[:, atoms_i] = sqrt_mass * dx

            logger.info("computing pivoted QR factorization")
            _, R, P = qr(dq, mode="economic", pivoting=True)

            # the diagonal of R may carry either sign, so the rank threshold has to be built
            # from magnitudes rather than from the largest signed entry
            rdiag = np.abs(np.diag(R))
            rank = int(np.sum(rdiag > rtol * rdiag.max()))

            filtered_atoms = [reference] + [all_atoms[atoms_i] for atoms_i in np.sort(P[:rank])]
    elif strategy == "dx":
        filtered_atoms.append(all_atoms[0])
        for atoms in tqdm(all_atoms[1:], desc="[*] dx filtering"):
            # compare against the most recently accepted configuration
            dx = get_disp_vect(filtered_atoms[-1], atoms)
            if np.linalg.norm(dx) > dx_tol:
                filtered_atoms.append(atoms)
    else:
        filtered_atoms = all_atoms

    logger.info(f"filtered {len(all_atoms) - len(filtered_atoms)} configurations from dataset")

    with open(output_file, "a") as f:
        for atoms in tqdm(filtered_atoms, desc="[*] writing"):
            ase.io.write(f, atoms, format="extxyz")
137
+
138
+
139
def get_force_opt_modes(
    n: int,
    omega2: np.ndarray,
    U: np.ndarray,
    sqrt_mass: np.ndarray,
    F: float = 0.5,
    tol: float = 1e-6,
    seed: int = 897689932,
    start_with_min_spread: bool = False,
    save_plot: bool = False,
) -> tuple[np.ndarray, np.ndarray]:
    """Bin phonons by energy and find a vector in the subspace that minimizes the spread in forces.

    The input ``omega2`` array is not modified; filtering is applied to a private copy.

    Args:
        n (int): number of modes to select
        omega2 (np.ndarray): frequencies squared of the modes (directly from np.linalg.eigh)
        U (np.ndarray): matrix where eigenvectors of modes are cols (directly from np.linalg.eigh)
        sqrt_mass (np.ndarray): sqrt of the vector of atomic masses
        F (float): target forces to optimize amplitudes for
        tol (float): convergence tolerance for scipy minimize call
        seed (int): seed value for random number generator
        start_with_min_spread (bool): determines if mode with smallest force spread is used as the
            starting point for optimization. Uses a random vector otherwise.
        save_plot (bool): save a plot for analyzing resulting modes (written to "pdos.png")

    Returns:
        modes (np.ndarray): generated modes as columns of the matrix
        mode_dqs (np.ndarray): optimized displacement amplitudes following above criteria

    Raises:
        ValueError: if ``n`` is negative or exceeds the number of modes left after skipping the
            first three.
    """
    if n < 0:
        raise ValueError("number of modes must be non-negative")

    n_dof = U.shape[0]
    if n == 0:
        # nothing requested: hand back empty containers instead of failing in array_split
        return np.zeros((n_dof, 0), dtype=np.float64), np.zeros(0, dtype=np.float64)
    if n > n_dof - 3:
        raise ValueError(f"cannot generate {n} modes from {n_dof - 3} available phonons")

    inv_sqrt_mass = 1 / sqrt_mass

    # acoustic phonon filtering (on a copy so the caller's eigenvalues stay untouched)
    omega2 = omega2.copy()
    omega2[omega2 < (0.0005 / omega2eV) ** 2] = 0.0

    opt_path, opt_info = np.einsum_path("i,ij,j,j->i", sqrt_mass, U, omega2, np.ones(n_dof))
    logger.debug(opt_info)

    def atom_force_norms(coeffs):
        # per-atom norm of the vector obtained from coefficients in the eigenvector basis
        Fx = np.einsum("i,ij,j,j->i", sqrt_mass, U, omega2, coeffs, optimize=opt_path)
        return np.linalg.norm(Fx.reshape((-1, 3)), axis=1)

    def loss(x, ind=None):
        # x holds coefficients restricted to the subspace `ind`; embed them in the full basis
        if ind is not None:
            tx = np.zeros(n_dof)
            tx[ind] = x
        else:
            tx = x

        Fx = atom_force_norms(tx)
        return np.mean((Fx - Fx.mean()) ** 2)

    def constr_f(x):
        return np.sum(x**2)

    def constr_J(x):
        return 2 * x

    def constr_H(x, v):
        return np.diag(2 * v[0] * np.ones(x.shape[0]))

    def log_callback(intermediate_result):
        ir = intermediate_result
        logger.debug(f"{ir.nit:04d} | {ir.fun:.06e} | {ir.optimality:.06e}")

    from scipy.optimize import NonlinearConstraint, minimize

    # keep the coefficient vector normalized: sum(x**2) == 1
    norm_constr = NonlinearConstraint(constr_f, 1, 1, jac=constr_J, hess=constr_H)

    rng = np.random.default_rng(seed)

    modes = np.zeros((n_dof, n), dtype=np.float64)
    mode_dqs = np.zeros(n, dtype=np.float64)
    freqs = np.empty(n)

    # the first three modes are skipped, the rest is split into n contiguous bins
    subspace_inds = np.array_split(np.arange(3, n_dof), n)
    for imode, ind in enumerate(tqdm(subspace_inds, desc="[*] optimizing modes")):
        subs_spread = np.empty(ind.shape[0])
        for i, ind_i in enumerate(ind):
            Fx = np.linalg.norm((sqrt_mass * omega2[ind_i] * U[:, ind_i]).reshape((-1, 3)), axis=1)
            subs_spread[i] = np.mean((Fx - Fx.mean()) ** 2)

        if start_with_min_spread:
            imin = np.argmin(subs_spread)
            logger.info(
                f"starting from mode {ind[imin]} "
                f"with frequency {1000 * omega2eV * np.sqrt(omega2[ind[imin]]):.02f} meV "
                f"and spread {subs_spread[imin]:.06e}"
            )
            x0 = np.zeros(ind.shape[0])
            x0[imin] = 1.0
        else:
            x0 = rng.random(ind.shape[0]) - 0.5
            x0 /= np.linalg.norm(x0)

        logger.debug(" nit | loss | optimality")
        res = minimize(
            lambda x, ind=ind: loss(x, ind=ind),
            x0,
            tol=tol,
            method="trust-constr",
            constraints=[norm_constr],
            callback=log_callback,
        )

        if not res.success:
            logger.warning(f"optimization failed - {res.message}")

        logger.debug(f"final spread {res.fun}, subspace spread {subs_spread}")
        if not np.all(res.fun < subs_spread):
            logger.warning("optimization failed to find a vector with smaller spread")

        modes[ind, imode] = res.x
        # expectation value of omega2 for the mode, without building a dense diagonal matrix
        freqs[imode] = omega2eV * np.sqrt(np.sum(omega2 * modes[:, imode] ** 2))
        logger.debug(f"found mode with frequency {1000 * freqs[imode]:.02f} meV")

        # largest per-atom force and displacement produced by a unit amplitude
        unit_force = atom_force_norms(modes[:, imode]).max()
        unit_dx = np.linalg.norm(
            (inv_sqrt_mass * (U @ modes[:, imode])).reshape((-1, 3)), axis=1
        ).max()

        # a zero force per unit amplitude would otherwise give an infinite amplitude that
        # turns into NaN when rescaled; treat it as "too large" and let the clamp take over
        dq = F / unit_force if unit_force > 0 else np.inf
        max_dx = dq * unit_dx
        if max_dx > 0.05:
            new_dq = 0.05 / unit_dx
            logger.warning(
                f"max displacement too large ({max_dx} > 0.05), "
                f"resetting dq {dq} -> {new_dq}"
            )
            dq = new_dq
        elif max_dx < 0.005:
            new_dq = 0.005 / unit_dx
            logger.warning(
                f"max displacement too small ({max_dx} < 0.005), "
                f"resetting dq {dq} -> {new_dq}"
            )
            dq = new_dq
        mode_dqs[imode] = dq

    if save_plot:
        w = np.linspace(0, 0.1, 1000)
        dos = np.zeros_like(w)
        for i in range(n_dof):
            dos += gaussian(w - omega2eV * np.sqrt(omega2[i]), 0.001)
        dos /= n_dof

        import matplotlib.pyplot as plt

        fig, ax = plt.subplots(figsize=(4, 3))
        ax.plot(w, dos, color="k")
        ax.fill_between(w, dos, color="k", alpha=0.2, lw=0)
        for imode in range(n):
            # density of states weighted by the squared coefficients of the generated mode
            pdos = np.zeros_like(w)
            for i in range(n_dof):
                pdos += modes[i, imode] ** 2 * gaussian(w - omega2eV * np.sqrt(omega2[i]), 0.001)
            p = ax.plot(w, pdos / n, lw=1)
            ax.axvline(x=freqs[imode], color=p[0].get_color(), alpha=0.2)
        ax.set_xlabel("Energy [eV]")
        ax.set_ylabel("Density of States")
        plt.savefig("pdos.png", dpi=600, bbox_inches="tight")

    # convert coefficients in the eigenvector basis back to the original coordinates
    return U @ modes, mode_dqs
298
+
299
+
300
def get_random_phonons(
    n: int,
    omega2: np.ndarray,
    U: np.ndarray,
    sqrt_mass: np.ndarray,
    F: float = 0.5,
    seed: int = 897689932,
) -> tuple[np.ndarray, np.ndarray]:
    """Bin phonons by energy and select one randomly from each bin.

    The amplitude of each phonon is chosen to produce a max force per atom as close to F as
    possible. The max displacement on a given atom is kept within a reasonable range (0.005, 0.05).
    The input ``omega2`` array is not modified; filtering is applied to a private copy.

    Args:
        n (int): number of modes to select
        omega2 (np.ndarray): frequencies squared of the modes (directly from np.linalg.eigh)
        U (np.ndarray): matrix where eigenvectors of modes are cols (directly from np.linalg.eigh)
        sqrt_mass (np.ndarray): sqrt of the vector of atomic masses
        F (float): target forces to optimize amplitudes for
        seed (int): seed value for random number generator

    Returns:
        modes (np.ndarray): generated modes as columns of the matrix
        mode_dqs (np.ndarray): optimized displacement amplitudes following above criteria

    Raises:
        ValueError: if ``n`` is negative or exceeds the number of modes left after skipping the
            first three.
    """
    if n < 0:
        raise ValueError("number of modes must be non-negative")

    n_dof = U.shape[0]
    modes = np.zeros((n_dof, n), dtype=np.float64)
    mode_dqs = np.empty(n, dtype=np.float64)
    if n == 0:
        # nothing requested: hand back empty containers instead of failing in array_split
        return modes, mode_dqs
    if n > n_dof - 3:
        raise ValueError(f"cannot select {n} modes from {n_dof - 3} available phonons")

    inv_sqrt_mass = 1 / sqrt_mass

    # acoustic phonon filtering (on a copy so the caller's eigenvalues stay untouched)
    omega2 = omega2.copy()
    omega2[omega2 < (0.0005 / omega2eV) ** 2] = 0.0

    rng = np.random.default_rng(seed)

    # starting from 3 to skip acoustic phonons
    subspace_inds = np.array_split(np.arange(3, n_dof), n)
    for i, inds in enumerate(tqdm(subspace_inds, desc="[*] selecting random modes")):
        imode = inds[rng.integers(inds.shape[0])]
        modes[:, i] = U[:, imode]

        # largest per-atom force and displacement produced by a unit amplitude
        unit_force = np.linalg.norm(
            (sqrt_mass * omega2[imode] * U[:, imode]).reshape((-1, 3)), axis=1
        ).max()
        unit_dx = np.linalg.norm((inv_sqrt_mass * U[:, imode]).reshape((-1, 3)), axis=1).max()

        # a mode whose frequency was zeroed by the filter exerts no force; dividing by it
        # would give an infinite amplitude that becomes NaN when rescaled, so treat it as
        # "too large" and let the displacement clamp below pick the amplitude
        dq = F / unit_force if unit_force > 0 else np.inf
        max_dx = dq * unit_dx
        if max_dx > 0.05:
            new_dq = 0.05 / unit_dx
            logger.warning(
                f"max displacement too large ({max_dx} > 0.05), "
                f"resetting dq {dq} -> {new_dq}"
            )
            dq = new_dq
        elif max_dx < 0.005:
            new_dq = 0.005 / unit_dx
            logger.warning(
                f"max displacement too small ({max_dx} < 0.005), "
                f"resetting dq {dq} -> {new_dq}"
            )
            dq = new_dq
        mode_dqs[i] = dq
    return modes, mode_dqs
360
+
361
+
362
def gen_confs(
    struct_path: Path,
    num_conf: int,
    strategy: str = "rand",
    output_dir: Path = Path("./confs"),
    accepting_mode: Path | None = None,
    dynmat_file: Path | None = None,
    orthogonalize: bool = False,
    default_max_dx: float = 0.015,
    start_with_min_spread: bool = False,
    opt_tol: float = 1e-6,
    seed: int = 897689932,
) -> None:
    """Generate additional configurations to enhance fine-tuning dataset.

    Each configuration is written as ``<output_dir>/<index>/POSCAR`` and the invoking command
    is stored in ``<output_dir>/cmd.txt``.

    Args:
        struct_path (Path): path to file containing structure that will be displaced
        num_conf (int): total number of additional configurations to generate
        strategy (str): strategy used to generate the additional configurations. Available options
            are 'rand', 'phon_rand', and 'phon_opt'.
        output_dir (Path): output directory where the new configurations will be written to
        accepting_mode (Path): path to file containing the structure that defines the accepting
            mode. For example, if struct_path refers to the ground-state equilibrium geometry, then
            accepting_mode should refer to the excited-state equilibrium geometry and vice versa.
        dynmat_file (Path): path to the .npz file containing the dynamical matrix presumably
            calculated using the "compute-dynmat" function.
        orthogonalize (bool): perform Gram-Schmidt orthogonalization at the last step
        default_max_dx (float): default value for max displacement of a given atom
        start_with_min_spread (bool): determines if mode with smallest force spread is used as the
            starting point for optimization in phon_opt strategy. Uses a random vector otherwise.
        opt_tol (float): convergence tolerance for the call to scipy minimize in the phon_opt strat
        seed (int): seed value for random number generator

    Raises:
        ValueError: on an unknown strategy, a non-positive ``num_conf``, an existing output
            directory, a missing or incompatible dynamical matrix file, or an accepting-mode
            structure that does not differ from the input structure.
    """
    strategy = strategy.lower()
    if strategy not in ("rand", "phon_rand", "phon_opt"):
        raise ValueError("invalid strategy choice")

    if num_conf < 1:
        # without this, the output directory is never created and writing cmd.txt fails
        raise ValueError("num_conf must be at least 1")

    if output_dir.exists():
        raise ValueError("output directory already exists")

    struct: Atoms = ase.io.read(struct_path)  # type: ignore[assignment]
    sqrt_mass = np.repeat(np.sqrt(struct.get_masses()), 3)

    # makes written poscars ugly and adds unnecessary data
    if struct.has("momenta"):
        del struct.arrays["momenta"]

    def max_dx_amplitude(mode):
        # amplitude for which the largest per-atom displacement equals default_max_dx
        return default_max_dx / np.linalg.norm((mode / sqrt_mass).reshape((-1, 3)), axis=1).max()

    imode = 0
    modes = np.zeros((3 * len(struct), num_conf), dtype=np.float64)
    mode_dqs = np.zeros(num_conf, dtype=np.float64)

    if accepting_mode is not None:
        am_struct: Atoms = ase.io.read(accepting_mode)  # type: ignore[assignment]
        dx = get_disp_vect(struct, am_struct)
        dq = sqrt_mass * dx
        dq_norm = np.linalg.norm(dq)
        logger.info(f"accepting mode dQ={dq_norm} amu^0.5 Å")

        if dq_norm == 0:
            # normalizing a zero vector would fill the first configuration with NaN
            raise ValueError("accepting mode structure is identical to the input structure")

        # the accepting mode always occupies the first column
        modes[:, 0] = dq / dq_norm
        mode_dqs[0] = max_dx_amplitude(modes[:, 0])

        imode += 1

    n_remaining = num_conf - imode
    if n_remaining == 0:
        logger.info("accepting mode fills all requested configurations")
    elif strategy.startswith("phon"):
        logger.info("working in phonon basis")

        if dynmat_file is None:
            raise ValueError("dynamical matrix file is needed to compute phonons")

        logger.info(f"reading dynamical matrix from {dynmat_file}")
        # the archive is closed once the two arrays have been read
        with np.load(dynmat_file) as data:
            H = data["H"]
            ref_sqrt_mass = data["sqrt_mass"]

        if ref_sqrt_mass.shape != sqrt_mass.shape or not np.allclose(
            sqrt_mass, ref_sqrt_mass, rtol=1e-4
        ):
            logger.debug("%s %s", sqrt_mass, ref_sqrt_mass)
            raise ValueError("sqrt_mass from struct is not compatible with dynamical matrix file")

        # diagonalize dynamical matrix
        omega2, U = np.linalg.eigh(H)

        if strategy == "phon_opt":
            logger.info("generating force-optimized phonons")
            modes[:, imode:], mode_dqs[imode:] = get_force_opt_modes(
                n_remaining,
                omega2,
                U,
                sqrt_mass,
                tol=opt_tol,
                start_with_min_spread=start_with_min_spread,
                seed=seed,
            )
        else:
            logger.info("selecting phonons randomly")
            modes[:, imode:], mode_dqs[imode:] = get_random_phonons(
                n_remaining,
                omega2,
                U,
                sqrt_mass,
                seed=seed,
            )
    else:
        logger.info("generating random modes")

        from scipy.linalg import qr

        rng = np.random.default_rng(seed)
        # orthonormal columns obtained from a random matrix with entries in [-0.5, 0.5)
        modes[:, imode:], _, _ = qr(
            rng.random((modes.shape[0], n_remaining)) - 0.5, mode="economic", pivoting=True
        )
        for i in range(imode, num_conf):
            mode_dqs[i] = max_dx_amplitude(modes[:, i])

    if orthogonalize:
        from scipy.linalg import qr

        logger.info("performing final Gram-Schmidt orthogonalization of all modes")
        modes, _ = qr(modes, mode="economic")

        logger.info("recomputing displacement amplitude after orthogonalization")
        for i in range(num_conf):
            mode_dqs[i] = max_dx_amplitude(modes[:, i])

    # write to output directory
    for imode in tqdm(range(num_conf), desc="[*] writing structures"):
        dx = mode_dqs[imode] * modes[:, imode] / sqrt_mass
        logger.debug(
            f"mode {imode} - max displacement "
            f"{np.linalg.norm(dx.reshape((-1, 3)), axis=1).max():.06f} Å "
            f"{np.abs(dx).max():.06f} Å"
        )

        atoms = struct.copy()
        atoms.positions += dx.reshape((-1, 3))

        fname = output_dir / f"{imode}/POSCAR"
        fname.parent.mkdir(parents=True)
        ase.io.write(fname, atoms, format="vasp", direct=True)

    # store invoking command
    with open(output_dir / "cmd.txt", "w") as f:
        f.write(" ".join(sys.argv) + "\n")
510
+
511
+
512
def reestimate_e0s_linear_system(
    calculator: MACECalculator,
    database_atoms: list[Atoms],
    elements: list | None = None,
    initial_e0s: dict | None = None,
) -> dict:
    """Estimate atomic reference energies (E0s) by solving a linear system.

    Notes:
        Slightly adapted from code by Noam Bernstein based on private communications
        with Ilyes Batatia and Joe Hart.

        This functionality will eventually be removed once merged into MACE.

    Args:
        calculator (MACECalculator): Calculator object for the MACE model.
        database_atoms (list): List of ase Atoms objects with energy and atomic_numbers. Only
            entries that carry a "REF_energy" value in ``atoms.info`` are used.
        elements (list): List of element atomic numbers to consider, default to set present in
            database_atoms.
        initial_e0s (dict): Dictionary mapping element atomic numbers to E0 values, default to
            values returned by foundation_model for isolated atom configs.

    Returns:
        Dictionary with re-estimated E0 values for each element, stored as builtin floats. If
        the least-squares solve raises, a copy of the initial E0s is returned instead.

    Raises:
        ValueError: if no configuration in ``database_atoms`` carries a "REF_energy" entry.
    """
    # filter configs without energy
    database_atoms = [
        atoms for atoms in database_atoms if atoms.info.get("REF_energy") is not None
    ]

    if len(database_atoms) == 0:
        raise ValueError("database does not contain REF_energy tag")

    if not elements:
        elements = np.unique([Z for atoms in database_atoms for Z in atoms.numbers]).tolist()

    if not initial_e0s:
        initial_e0s = {Z: 0.0 for Z in elements}
        try:
            for Z in elements:
                # the table lookup does not depend on the model, so do it once per element
                z_ind = calculator.z_table.z_to_index(Z)
                for model in calculator.models:
                    initial_e0s[Z] += float(model.atomic_energies_fn.atomic_energies[z_ind])
        except Exception as e:
            # deliberate best-effort: any failure reading the stored E0s triggers explicit
            # isolated-atom calculations, which overwrite every entry
            logger.warning(f"unexpected exception in getting initial E0s: {e}")
            logger.warning("falling back to explicit isolated atom calculations")

            from ase.atoms import Atoms

            for Z in elements:
                calculator.calculate(
                    atoms=Atoms(numbers=[Z], cell=[20] * 3, pbc=[True] * 3), properties=["energy"]
                )
                initial_e0s[Z] = calculator.results.get("energy")
    logger.info(f"using initial E0s: {initial_e0s}")

    # A matrix: each row contains atom counts for each element
    # b vector: each entry is the prediction error for a configuration
    A = np.zeros((len(database_atoms), len(elements)))
    b = np.zeros(len(database_atoms))

    logger.info(
        f"solving linear system with {len(database_atoms)} equations and {len(elements)} unknowns"
    )

    # - A[i,j] is the count of element j in configuration i
    # - b[i] is the error (true - predicted) for configuration i
    # - x[j] will be the energy correction for element j
    for i, atoms in enumerate(tqdm(database_atoms, desc="[*] foundation model predictions")):
        calculator.calculate(atoms=atoms.copy(), properties=["energy"])
        b[i] = atoms.info["REF_energy"] - calculator.results.get("energy")

        # atom counts for each element
        numbers = atoms.get_atomic_numbers()
        for j, element in enumerate(elements):
            A[i, j] = np.sum(numbers == element)

    # solve with least squares
    try:
        corrections, _, rank, s = np.linalg.lstsq(A, b, rcond=None)
    except np.linalg.LinAlgError as e:
        logger.warning(f"error using lstsq to solve the linear system: {e}")
        logger.warning("falling back to foundation model E0s")
        return initial_e0s.copy()

    if np.linalg.norm(corrections) > 1e6:
        logger.critical(
            f"abnormally large corrections found, rank determination may have failed: {s}"
        )

    new_e0s = {}
    for i, element in enumerate(elements):
        # store builtin floats: NumPy scalars change their repr between NumPy versions,
        # which matters when the dictionary is later written out as text
        new_e0s[element] = float(initial_e0s[element] + corrections[i])
        logger.debug(
            f"element {element}: foundation E0 = {initial_e0s[element]:.4f}, "
            f"correction = {corrections[i]:.4f}, new E0 = {new_e0s[element]:.4f}"
        )

    # statistics about the fit
    b_after = b - A @ corrections
    mse_before, mse_after = np.mean(b**2), np.mean(b_after**2)
    # a perfect initial prediction leaves nothing to improve and would divide by zero
    improvement = (1 - mse_after / mse_before) * 100 if mse_before > 0 else 0.0

    logger.debug(f"mean squared error before correction: {mse_before:.4f} eV²")
    logger.debug(f"mean squared error after correction: {mse_after:.4f} eV²")
    logger.debug(f"improvement: {improvement:.1f}%")

    if rank < len(elements):
        logger.warning(f"system is rank deficient (rank {rank}/{len(elements)})")
        logger.warning(
            "some elements may be linearly dependent or not well represented in the dataset."
        )

    return new_e0s
627
+
628
+
629
def gen_ft_config(
    out: Path | str = "./config.default",
    estimate_e0s: bool = False,
    device: str = "cuda",
    name: str = "fine-tuned",
    mace_model: str = "medium-omat-0",
    database: Path | str = "./database.extxyz",
    head: str = "default",
) -> None:
    """Generate a configuration file for mace_run_train.

    Args:
        out (Path): path where the mace_run_train configuration file is written to.
        estimate_e0s (bool): estimate the E0s for training of the foundation model. When False,
            a "TO_BE_REPLACED" placeholder is written and must be edited by hand.
        device (str): device string passed to MACE to determine where calculation is performed.
        name (str): name of the model.
        mace_model (str): pre-trained MACE model that is used, can be a local path
        database (Path): path to the training dataset file (likely generated with :func:`collect`)
        head (str): which head from the model to use for prediction

    Raises:
        ValueError: if ``out`` already exists.
    """
    e0s: str | dict

    if Path(out).exists():
        raise ValueError(f"Output path {out} already exists!")

    if not estimate_e0s:
        logger.warning(f'Please update the "E0s" tag in {out} prior to running mace_run_train.')
        e0s = "TO_BE_REPLACED"
    else:
        from mace.calculators import mace_mp

        calc = mace_mp(
            model=mace_model,
            dispersion=False,
            default_dtype="float64",
            device=device,
            enable_cueq=(device == "cuda" and find_spec("cuequivariance") is not None),
            head=head,
        )

        logger.info("running E0 estimation")
        db_atoms: list[Atoms] = ase.io.read(database, ":")  # type: ignore[assignment]
        estimated = reestimate_e0s_linear_system(calc, db_atoms)
        # the dictionary is embedded below through its text form; NumPy >= 2 renders scalars
        # as "np.float64(...)", so convert to builtin types to keep a plain
        # "{Z: energy}" literal in the config file
        e0s = {int(Z): float(E) for Z, E in estimated.items()}

    config_lines = [
        f"name: {name}",
        f"foundation_model: {mace_model}",
        f"train_file: {database}",
        f"valid_file: {database}",
        f'E0s: "{e0s}"',
        "multiheads_finetuning: false",
        "energy_weight: 1",
        "forces_weight: 10",
        "stress_weight: 0",
        "ema: true",
        "ema_decay: 0.999",
        "lr: 0.001",
        "max_num_epochs: 500",
        "default_dtype: float64",
        "batch_size: 1",
        "valid_batch_size: 1",
    ]

    with open(out, "w") as f:
        f.write("\n".join(config_lines) + "\n")
697
+
698
+
699
def parse_force_constants_file(fname: Path | str) -> np.ndarray:
    """Parse a force constants file from phonopy.

    The file is expected to start with a header holding the matrix dimensions, followed by
    one record per atom pair: a line with the two 1-based atom indices and three lines with
    the rows of the corresponding 3x3 block.

    Args:
        fname (Path): path to the FORCE_CONSTANTS file.

    Returns:
        Array of shape (3 * natoms, 3 * natoms) where the block of atom pair (i, j) occupies
        rows 3i..3i+2 and columns 3j..3j+2.

    Raises:
        ValueError: if the header does not describe a square matrix, the file ends before all
            blocks were read, or a record cannot be interpreted.
    """
    with open(fname) as f:
        try:
            dims = [int(x) for x in next(f).split()]
            # some files carry a single atom count in the header instead of two numbers
            if len(dims) == 1:
                dims = dims * 2
            if len(dims) != 2 or dims[0] != dims[1]:
                raise ValueError(
                    f"unsupported force constants shape {dims} in {fname}, "
                    "a full (square) force constants matrix is required"
                )

            natoms = dims[0]
            H = np.zeros((natoms, natoms, 3, 3), dtype=np.float64)

            for _ in range(natoms * natoms):
                # indices in the file are 1-based
                i, j = [int(x) - 1 for x in next(f).split()]
                rows = [[float(x) for x in next(f).split()] for _ in range(3)]
                H[i, j, :, :] = np.array(rows, dtype=np.float64).reshape((3, 3))
        except StopIteration:
            raise ValueError(
                f"{fname} ended before all force constant blocks were read"
            ) from None

    # (atom_i, atom_j, a, b) -> (atom_i, a, atom_j, b) -> flat (3 * natoms, 3 * natoms)
    return H.swapaxes(1, 2).reshape((3 * natoms, 3 * natoms))
713
+
714
+
715
def convert_from_phonopy(
    fname: Path | str, atoms_file: Path | str, save_file: Path | str = "dynmat.npz"
) -> None:
    """Turn a phonopy FORCE_CONSTANTS file into a dynmat.npz archive.

    phonopy writes FORCE_CONSTANTS when it is invoked with the "--writefc" tag. The force
    constants are scaled by the inverse square-root masses on both sides and stored together
    with the square-root masses and the invoking command line.

    Args:
        fname (Path): location of the FORCE_CONSTANTS file.
        atoms_file (Path): location of a file holding the equilibrium structure for which the
            force constants were evaluated. (Only the masses are taken from it.)
        save_file (Path): destination of the dynamical matrix archive (should end in .npz)
    """
    logger.info(f"reading force constants from {fname}")
    force_constants = parse_force_constants_file(fname)

    structure: Atoms = ase.io.read(atoms_file)  # type: ignore[assignment]

    # one entry per Cartesian degree of freedom (three per atom)
    root_masses = np.repeat(np.sqrt(structure.get_masses()), 3)
    inverse_root_masses = 1 / root_masses

    # D_ij = H_ij / (sqrt(m_i) * sqrt(m_j))
    dynmat = np.einsum("i,ij,j->ij", inverse_root_masses, force_constants, inverse_root_masses)

    logger.info(f"saving dynamical matrix to {save_file}")
    invoking_cmd = " ".join(sys.argv)
    np.savez_compressed(save_file, H=dynmat, sqrt_mass=root_masses, cmd=invoking_cmd)
739
+
740
+
741
def compute_dynmat(
    input_struct: Path,
    save_file: Path = Path("./dynmat.npz"),
    mace_model: str = "medium-omat-0",
    device: str = "cuda",
    head: str = "default",
    relax_struct: bool = True,
    analytical_hessian: bool = True,
    relax_algo: str = "LBFGSLineSearch",
    fmax: float = 0.001,
) -> None:
    """Evaluate the dynamical matrix of a structure with a MACE model.

    The Hessian is obtained either from the calculator directly or from phonopy finite
    displacements, mass-weighted, and saved together with the square-root masses and the
    invoking command line.

    Args:
        input_struct (Path): structure about which the dynamical matrix is computed
        save_file (Path): destination of the dynamical matrix archive (should end in .npz)
        mace_model (str): pre-trained MACE model to use, may be a local path
        device (str): device string handed to MACE to select where the calculation runs
        head (str): head of the model that is used for prediction
        relax_struct (bool): run an atomic relaxation before the Hessian is evaluated. This is
            recommended if the model does not predict the same equilibrium structure as your
            explicit DFT calculation, which is generally the case unless good fine tuning has
            been performed.
        analytical_hessian (bool): take the Hessian from the calculator (True) or build it
            numerically from finite differences (False)
        relax_algo (str): name of the ase.optimize algorithm used for the atomic relaxation.
        fmax (float): force convergence criterion for the atomic relaxation in eV/Å
    """
    from mace.calculators import mace_mp

    atoms: Atoms = ase.io.read(input_struct)  # type: ignore[assignment] # noqa: F823

    # cuequivariance acceleration is only requested on cuda and when the package is installed
    use_cueq = device == "cuda" and find_spec("cuequivariance") is not None
    atoms.calc = mace_mp(
        model=mace_model,
        dispersion=False,
        default_dtype="float64",
        device=device,
        enable_cueq=use_cueq,
        head=head,
    )

    if relax_struct:
        import ase.optimize as ase_optim

        optimizer_cls = getattr(ase_optim, relax_algo)
        optimizer_cls(atoms).run(fmax=fmax)

    largest_force = np.linalg.norm(atoms.get_forces(), axis=1).max()
    if largest_force > 0.02:
        warnings.warn("large forces found", stacklevel=0)

    n_dof = 3 * len(atoms)
    if not analytical_hessian:
        from ase.atoms import Atoms
        from phonopy import Phonopy
        from phonopy.structure.atoms import PhonopyAtoms

        unit_cell = PhonopyAtoms(
            symbols=atoms.get_chemical_symbols(),
            cell=atoms.get_cell(),
            scaled_positions=atoms.get_scaled_positions(),
        )
        ph = Phonopy(unit_cell, supercell_matrix=np.eye(3), log_level=2)
        ph.generate_displacements(distance=0.02)

        # evaluate the forces of every displaced cell with the calculator attached to atoms
        displaced_forces = []
        for displaced in tqdm(ph.supercells_with_displacements, desc="[*] computing forces"):
            displaced_atoms = Atoms(
                symbols=displaced.symbols,
                scaled_positions=displaced.scaled_positions,
                cell=displaced.cell,
                pbc=True,
            )
            atoms.calc.calculate(atoms=displaced_atoms, properties=["forces"])
            displaced_forces.append(atoms.calc.results["forces"])

        ph.forces = np.array(displaced_forces)
        ph.produce_force_constants()
        ph.symmetrize_force_constants()

        if ph.force_constants is None:
            raise RuntimeError("phonopy failed to produce force constants")

        # (atom_i, atom_j, a, b) -> (atom_i, a, atom_j, b) -> flat square matrix
        H = ph.force_constants.swapaxes(1, 2).reshape((n_dof, n_dof))
    else:
        H = atoms.calc.get_hessian(atoms).reshape((n_dof, n_dof))

    if not np.allclose(H, H.T):
        warnings.warn("Hessian matrix is not symmetric", stacklevel=0)

    # mass-weight the Hessian: D_ij = H_ij / (sqrt(m_i) * sqrt(m_j))
    sqrt_mass = np.repeat(np.sqrt(atoms.get_masses()), 3)
    inv_sqrt_mass = 1 / sqrt_mass
    H = np.einsum("i,ij,j->ij", inv_sqrt_mass, H, inv_sqrt_mass)

    invoking_cmd = " ".join(sys.argv)
    np.savez_compressed(save_file, H=H, sqrt_mass=sqrt_mass, cmd=invoking_cmd)
834
+
835
+
836
def compute_lineshape(
    ground: Path,
    excited: Path,
    dynmat_file: Path,
    emission: Annotated[bool, Parameter(name="--luminescence", negative="--absorption")] = True,
    dE: float | None = None,
    gamma_zpl: float = 0.001,
    sigma_zpl: float = 0.0,
    sigma_psb: tuple[float, float] = (0.005, 0.001),
    gamma_psb: tuple[float, float] | None = None,
    omega_mult: float = 5.0,
    norm: str = "area",
    T: Annotated[float, Parameter(name=["--T", "-T"])] = 0.0,
    plot: str | None = None,
) -> None:
    """Calculate the spectral density/function and lineshape for a given dynamical matrix.

    The results are written to ``spec_funcs.txt`` (columns: w, S, A) and ``lineshape.txt``
    (columns: tw, L) in the current working directory. If ``plot`` is given, a figure is also
    saved to ``lineshape.png``.

    Args:
        ground (Path): path to structure containing the ground state equilibrium geometry.
        excited (Path): path to structure containing the excited state equilibrium geometry.
        dynmat_file (Path): path to dynamical matrix file produced by :func:`compute_dynmat` or by
            phonopy and converted with :func:`convert_from_phonopy`.
        emission (bool): write luminescence (True) or absorption (False) spectrum.
        dE (float): zero-phonon line energy in eV, inferred from ground/excited if not provided.
        gamma_zpl (float): Lorentzian broadening in the ZPL to capture homogeneous broadening.
        sigma_zpl (float): Gaussian broadening in the ZPL to capture inhomogeneous broadening.
        sigma_psb (float, float): Gaussian broadening used to broaden the partial Huang-Rhys
            factors. The broadening factor is linearly interpolated from sigma_psb[0] at zero
            frequency to sigma_psb[1] at the highest (non-LVM) frequency.
        gamma_psb (float, float): Turns on Lorentzian broadening of local vibrational modes
            identified by their inverse participation ratio. gamma_psb[0] is ipr_cut and
            gamma_psb[1] is gamma_lvm. See :class:`Broadening`.
        omega_mult (float): number of factors of maximum phonon frequency from ZPL to plot.
        norm (str): normalization of luminescence (area or max).
        T (float): Temperature in kelvin.
        plot (str): if provided, specifies the type of plot to be generated and saved in the
            current working directory. Can be "subplot", "inset", "dos", "S", or "L". See
            :func:`plot_spec_funcs` for more info.

    Raises:
        ValueError: if the ``sqrt_mass`` array stored in ``dynmat_file`` does not have the same
            shape as the one derived from the input structures (e.g. different atom count).
    """
    # The transition goes initial -> final: excited -> ground for emission, reversed otherwise.
    initial, final = (excited, ground) if emission else (ground, excited)

    ini_atoms: Atoms = ase.io.read(initial)  # type: ignore[assignment]
    fin_atoms: Atoms = ase.io.read(final)  # type: ignore[assignment]

    if dE is None:
        dE = np.abs(fin_atoms.get_potential_energy() - ini_atoms.get_potential_energy())
        logger.info(f"energy difference not provided, found dE = {dE} from input structures")

    # One sqrt(mass) entry per Cartesian degree of freedom (3 per atom).
    sqrt_mass = np.repeat(np.sqrt(fin_atoms.get_masses()), 3)

    # Mass-weight the displacement between the two geometries.
    # NOTE(review): assumes get_disp_vect returns a flat vector of length 3 * n_atoms -- confirm.
    dx = get_disp_vect(fin_atoms, ini_atoms)
    dq = sqrt_mass * dx
    logger.info(f"dQ={np.linalg.norm(dq):.06f} amu^{{1/2}} Å")

    logger.info("reading dynmat matrix")
    # np.load returns an NpzFile holding an open archive; the context manager closes it once the
    # arrays have been read into memory.
    with np.load(dynmat_file) as data:
        H = data["H"]
        dynmat_sqrt_mass = data["sqrt_mass"]

    # A shape mismatch means the dynamical matrix was built for a different number of atoms;
    # fail with a clear message instead of an opaque broadcasting error.
    if dynmat_sqrt_mass.shape != sqrt_mass.shape:
        raise ValueError(
            f"sqrt_mass in dynmat file has shape {dynmat_sqrt_mass.shape}, but the "
            f"ground/excited structures give shape {sqrt_mass.shape}"
        )

    # probably not needed, but doesn't hurt
    if not np.allclose(sqrt_mass, dynmat_sqrt_mass):
        warnings.warn(
            "sqrt_mass in dynmat file is not compatible with ground/excited conf",
            stacklevel=0,
        )

    logger.info("diagonalizing")
    omega, U = get_phonons(H)
    # Express the mass-weighted displacement in the basis returned by get_phonons.
    dq_k = U.T @ dq

    if gamma_psb is not None:
        logger.info("computing inverse participation ratios")
        ipr, ipr_cut, gamma_lvm = get_ipr(U), gamma_psb[0], gamma_psb[1]
    else:
        ipr, ipr_cut, gamma_lvm = None, None, None

    logger.info("computing spectral functions")
    w, dos, S, A = get_phonon_spec_func(
        dq_k,
        omega,
        sigma_psb=sigma_psb,
        gamma_zpl=gamma_zpl,
        sigma_zpl=sigma_zpl,
        gamma_lvm=gamma_lvm,
        ipr=ipr,
        ipr_cut=ipr_cut,
        T=T,
    )
    tw, L = convert_A_to_L(w, A, dE, emission=emission, norm=norm)

    logger.info("saving results to .txt files")
    np.savetxt("spec_funcs.txt", np.array((w, S, A)).T)
    np.savetxt("lineshape.txt", np.array((tw, L)).T)

    if plot is not None:
        # Imported lazily so matplotlib is only needed when plotting is requested.
        import matplotlib.pyplot as plt

        plot_spec_funcs(
            (w, dos, S, tw, L),
            None,
            dE,
            emission=emission,
            omega_mult=omega_mult,
            omega_max=(omega2eV * omega.max()),
            plot_type=plot,
        )
        plt.savefig("lineshape.png", dpi=600, bbox_inches="tight")
942
+
943
+
944
def analyze_dynmat(dynmat: Path, structure: Path) -> None:
    """Produce analysis plots of the dynamical matrix.

    Two images are written to the current working directory:

    * ``H.png``: image of ``log(|H|)`` for the matrix stored under the ``"H"`` key.
    * ``radial_H.png``: scatter of ``|H|`` entries against the minimum-image interatomic
      distance, one series per pair of Cartesian components.

    Args:
        dynmat (Path): path to an ``.npz`` file containing the matrix under the key ``"H"``.
        structure (Path): path to a structure file readable by :func:`ase.io.read`.

    Raises:
        ValueError: if the shape of ``H`` is not ``(3 * n_atoms, 3 * n_atoms)`` for the given
            structure.
    """
    # Close the NpzFile archive as soon as the array has been read.
    with np.load(dynmat) as data:
        H = data["H"]

    import matplotlib.pyplot as plt

    fig, ax = plt.subplots(figsize=(4, 3))
    # Exactly-zero entries give log(0) = -inf; silence the divide-by-zero warning since those
    # entries simply fall below vmin in the image.
    with np.errstate(divide="ignore"):
        log_abs_H = np.log(np.abs(H))
    p = ax.imshow(log_abs_H, vmin=-6)
    fig.colorbar(p, ax=ax)
    fig.savefig("H.png", dpi=600, bbox_inches="tight")
    plt.close(fig)

    atoms: Atoms = ase.io.read(structure)  # type: ignore[assignment]
    n_dof = 3 * len(atoms)
    if H.shape != (n_dof, n_dof):
        raise ValueError(
            f"dynamical matrix has shape {H.shape}, expected {(n_dof, n_dof)} "
            f"for a structure with {len(atoms)} atoms"
        )

    # Pairwise distances are the same for every Cartesian component pair; flatten them once.
    distances = atoms.get_all_distances(mic=True).flatten()

    fig, ax = plt.subplots(figsize=(4, 3))
    for i in range(3):
        for j in range(3):
            # H[i::3, j::3] selects the (i, j) Cartesian block for every atom pair.
            ax.plot(distances, np.abs(H[i::3, j::3].flatten()), ".", ms=1)
    ax.set_yscale("log")
    ax.set_xlabel(r"$\vert {\bf R}_I - {\bf R}_J \vert$ [${\rm \AA}$]")
    ax.set_ylabel(r"$\vert \Phi_{I\alpha,J\beta} \vert$ [eV/amu$^{1/2}$ ${\rm \AA}$]")
    fig.savefig("radial_H.png", dpi=600, bbox_inches="tight")
    plt.close(fig)