reaxkit 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (130) hide show
  1. reaxkit/__init__.py +0 -0
  2. reaxkit/analysis/__init__.py +0 -0
  3. reaxkit/analysis/composed/RDF_analyzer.py +560 -0
  4. reaxkit/analysis/composed/__init__.py +0 -0
  5. reaxkit/analysis/composed/connectivity_analyzer.py +706 -0
  6. reaxkit/analysis/composed/coordination_analyzer.py +144 -0
  7. reaxkit/analysis/composed/electrostatics_analyzer.py +687 -0
  8. reaxkit/analysis/per_file/__init__.py +0 -0
  9. reaxkit/analysis/per_file/control_analyzer.py +165 -0
  10. reaxkit/analysis/per_file/eregime_analyzer.py +108 -0
  11. reaxkit/analysis/per_file/ffield_analyzer.py +305 -0
  12. reaxkit/analysis/per_file/fort13_analyzer.py +79 -0
  13. reaxkit/analysis/per_file/fort57_analyzer.py +106 -0
  14. reaxkit/analysis/per_file/fort73_analyzer.py +61 -0
  15. reaxkit/analysis/per_file/fort74_analyzer.py +65 -0
  16. reaxkit/analysis/per_file/fort76_analyzer.py +191 -0
  17. reaxkit/analysis/per_file/fort78_analyzer.py +154 -0
  18. reaxkit/analysis/per_file/fort79_analyzer.py +83 -0
  19. reaxkit/analysis/per_file/fort7_analyzer.py +393 -0
  20. reaxkit/analysis/per_file/fort99_analyzer.py +411 -0
  21. reaxkit/analysis/per_file/molfra_analyzer.py +359 -0
  22. reaxkit/analysis/per_file/params_analyzer.py +258 -0
  23. reaxkit/analysis/per_file/summary_analyzer.py +84 -0
  24. reaxkit/analysis/per_file/trainset_analyzer.py +84 -0
  25. reaxkit/analysis/per_file/vels_analyzer.py +95 -0
  26. reaxkit/analysis/per_file/xmolout_analyzer.py +528 -0
  27. reaxkit/cli.py +181 -0
  28. reaxkit/count_loc.py +276 -0
  29. reaxkit/data/alias.yaml +89 -0
  30. reaxkit/data/constants.yaml +27 -0
  31. reaxkit/data/reaxff_input_files_contents.yaml +186 -0
  32. reaxkit/data/reaxff_output_files_contents.yaml +301 -0
  33. reaxkit/data/units.yaml +38 -0
  34. reaxkit/help/__init__.py +0 -0
  35. reaxkit/help/help_index_loader.py +531 -0
  36. reaxkit/help/introspection_utils.py +131 -0
  37. reaxkit/io/__init__.py +0 -0
  38. reaxkit/io/base_handler.py +165 -0
  39. reaxkit/io/generators/__init__.py +0 -0
  40. reaxkit/io/generators/control_generator.py +123 -0
  41. reaxkit/io/generators/eregime_generator.py +341 -0
  42. reaxkit/io/generators/geo_generator.py +967 -0
  43. reaxkit/io/generators/trainset_generator.py +1758 -0
  44. reaxkit/io/generators/tregime_generator.py +113 -0
  45. reaxkit/io/generators/vregime_generator.py +164 -0
  46. reaxkit/io/generators/xmolout_generator.py +304 -0
  47. reaxkit/io/handlers/__init__.py +0 -0
  48. reaxkit/io/handlers/control_handler.py +209 -0
  49. reaxkit/io/handlers/eregime_handler.py +122 -0
  50. reaxkit/io/handlers/ffield_handler.py +812 -0
  51. reaxkit/io/handlers/fort13_handler.py +123 -0
  52. reaxkit/io/handlers/fort57_handler.py +143 -0
  53. reaxkit/io/handlers/fort73_handler.py +145 -0
  54. reaxkit/io/handlers/fort74_handler.py +155 -0
  55. reaxkit/io/handlers/fort76_handler.py +195 -0
  56. reaxkit/io/handlers/fort78_handler.py +142 -0
  57. reaxkit/io/handlers/fort79_handler.py +227 -0
  58. reaxkit/io/handlers/fort7_handler.py +264 -0
  59. reaxkit/io/handlers/fort99_handler.py +128 -0
  60. reaxkit/io/handlers/geo_handler.py +224 -0
  61. reaxkit/io/handlers/molfra_handler.py +184 -0
  62. reaxkit/io/handlers/params_handler.py +137 -0
  63. reaxkit/io/handlers/summary_handler.py +135 -0
  64. reaxkit/io/handlers/trainset_handler.py +658 -0
  65. reaxkit/io/handlers/vels_handler.py +293 -0
  66. reaxkit/io/handlers/xmolout_handler.py +174 -0
  67. reaxkit/utils/__init__.py +0 -0
  68. reaxkit/utils/alias.py +219 -0
  69. reaxkit/utils/cache.py +77 -0
  70. reaxkit/utils/constants.py +75 -0
  71. reaxkit/utils/equation_of_states.py +96 -0
  72. reaxkit/utils/exceptions.py +27 -0
  73. reaxkit/utils/frame_utils.py +175 -0
  74. reaxkit/utils/log.py +43 -0
  75. reaxkit/utils/media/__init__.py +0 -0
  76. reaxkit/utils/media/convert.py +90 -0
  77. reaxkit/utils/media/make_video.py +91 -0
  78. reaxkit/utils/media/plotter.py +812 -0
  79. reaxkit/utils/numerical/__init__.py +0 -0
  80. reaxkit/utils/numerical/extrema_finder.py +96 -0
  81. reaxkit/utils/numerical/moving_average.py +103 -0
  82. reaxkit/utils/numerical/numerical_calcs.py +75 -0
  83. reaxkit/utils/numerical/signal_ops.py +135 -0
  84. reaxkit/utils/path.py +55 -0
  85. reaxkit/utils/units.py +104 -0
  86. reaxkit/webui/__init__.py +0 -0
  87. reaxkit/webui/app.py +0 -0
  88. reaxkit/webui/components.py +0 -0
  89. reaxkit/webui/layouts.py +0 -0
  90. reaxkit/webui/utils.py +0 -0
  91. reaxkit/workflows/__init__.py +0 -0
  92. reaxkit/workflows/composed/__init__.py +0 -0
  93. reaxkit/workflows/composed/coordination_workflow.py +393 -0
  94. reaxkit/workflows/composed/electrostatics_workflow.py +587 -0
  95. reaxkit/workflows/composed/xmolout_fort7_workflow.py +343 -0
  96. reaxkit/workflows/meta/__init__.py +0 -0
  97. reaxkit/workflows/meta/help_workflow.py +136 -0
  98. reaxkit/workflows/meta/introspection_workflow.py +235 -0
  99. reaxkit/workflows/meta/make_video_workflow.py +61 -0
  100. reaxkit/workflows/meta/plotter_workflow.py +601 -0
  101. reaxkit/workflows/per_file/__init__.py +0 -0
  102. reaxkit/workflows/per_file/control_workflow.py +110 -0
  103. reaxkit/workflows/per_file/eregime_workflow.py +267 -0
  104. reaxkit/workflows/per_file/ffield_workflow.py +390 -0
  105. reaxkit/workflows/per_file/fort13_workflow.py +86 -0
  106. reaxkit/workflows/per_file/fort57_workflow.py +137 -0
  107. reaxkit/workflows/per_file/fort73_workflow.py +151 -0
  108. reaxkit/workflows/per_file/fort74_workflow.py +88 -0
  109. reaxkit/workflows/per_file/fort76_workflow.py +188 -0
  110. reaxkit/workflows/per_file/fort78_workflow.py +135 -0
  111. reaxkit/workflows/per_file/fort79_workflow.py +314 -0
  112. reaxkit/workflows/per_file/fort7_workflow.py +592 -0
  113. reaxkit/workflows/per_file/fort83_workflow.py +60 -0
  114. reaxkit/workflows/per_file/fort99_workflow.py +223 -0
  115. reaxkit/workflows/per_file/geo_workflow.py +554 -0
  116. reaxkit/workflows/per_file/molfra_workflow.py +577 -0
  117. reaxkit/workflows/per_file/params_workflow.py +135 -0
  118. reaxkit/workflows/per_file/summary_workflow.py +161 -0
  119. reaxkit/workflows/per_file/trainset_workflow.py +356 -0
  120. reaxkit/workflows/per_file/tregime_workflow.py +79 -0
  121. reaxkit/workflows/per_file/vels_workflow.py +309 -0
  122. reaxkit/workflows/per_file/vregime_workflow.py +75 -0
  123. reaxkit/workflows/per_file/xmolout_workflow.py +678 -0
  124. reaxkit-1.0.0.dist-info/METADATA +128 -0
  125. reaxkit-1.0.0.dist-info/RECORD +130 -0
  126. reaxkit-1.0.0.dist-info/WHEEL +5 -0
  127. reaxkit-1.0.0.dist-info/entry_points.txt +2 -0
  128. reaxkit-1.0.0.dist-info/licenses/AUTHORS.md +20 -0
  129. reaxkit-1.0.0.dist-info/licenses/LICENSE +21 -0
  130. reaxkit-1.0.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1758 @@
1
+ """
2
+ Trainset generation utilities for ReaxFF parameter training.
3
+
4
+ This module provides end-to-end helpers for generating elastic-energy
5
+ training targets (bulk EOS and elastic constants), optional strained
6
+ geometries, YAML-based configuration, and Materials Project-based
7
+ bootstrapping for trainset creation.
8
+
9
+ Typical use cases include:
10
+
11
+ - generating ``trainset_elastic.in`` plus energy tables (E vs strain/volume)
12
+ - generating strained XYZ/GEO structures for ReaxFF runs
13
+ - writing/reading a ``trainset_elastic.yaml`` settings file
14
+ - creating a ready-to-run trainset from a Materials Project material ID
15
+ """
16
+
17
+
18
+ from __future__ import annotations
19
+
20
+ import math
21
+ import os
22
+ from typing import Dict, List, Tuple, Optional, Any, Literal
23
+ from mp_api.client import MPRester
24
+ import numpy as np
25
+ from ase import Atoms
26
+ from ase.geometry import cellpar_to_cell, cell_to_cellpar
27
+ import shutil
28
+ from pathlib import Path
29
+
30
+ from reaxkit.io.generators.geo_generator import read_structure, write_structure, xtob
31
+ from reaxkit.utils.equation_of_states import vinet_energy_trainset
32
+ from reaxkit.utils.constants import const
33
+
34
+ # =============================================================================
35
+ # OVERVIEW OF THE CODE
36
+ # =============================================================================
37
+
38
+ """
39
+ This is a refactor/translation of the Fortran program (elastic_energy_v2) developed by Y. Shin.
40
+
41
+ This code is used to generate trainset data, and is structured into 4 parts:
42
+
43
+ 1. ELASTIC_ENERGY SECTION:
44
+ which is used to obtain energy vs volume for orthogonal systems only. This is because it treats
45
+ strain as independent scalar components instead of a full strain tensor,
46
+ which breaks down for non-orthogonal cells.
47
+ This part yields two sorts of data:
48
+ - Bulk modulus: Energy vs Volume using an EOS (Vinet)
49
+ - Elastic constants (c11..c66): Energy vs strain using quadratic strain-energy forms
50
+
51
+ There is a top-level function here, generate_all_energy_vs_volume_data, which calls
52
+ two other functions generate_bulk_data and generate_elastic_data to compute energy
53
+ vs volume change according to the explanation above.
54
+
55
+ 2. ELASTIC_GEO SECTION:
56
+ which is used to obtain expanded or compressed geometries in xyz and bgf format for any
57
+ crystal structure (i.e., not limited to orthognal systems), using the main function
58
+ generate_strained_geometries_with_xtob, which calls other related functions.
59
+
60
+ 3. YAML file management for settings of trainset
61
+ this part is used to write (using the write_trainset_settings_yaml function) or read
62
+ (using read_trainset_settings_yaml) a settings file with .yaml format. This file determines
63
+ cell dimensions, bulk modulus, and any other required settings for getting
64
+ (using generate_trainset_from_yaml function) elastic energy or expanded/compressed geo files.
65
+
66
+ 4. MP API Handler:
67
+ which is used to get crystal structure, cell dimension and angles, and mechanical properties
68
+ directly from the Materials Project website, and generate the corresponding trainset.
69
+ The main function here is generate_trainset_settings_yaml_from_mp_simple which:
70
+ 1. makes the connection to MP website and gets the data
71
+ 2. writes an informative YAML file using the write_trainset_settings_yaml function
72
+ 3. makes two geometry files in .xyz and .cif format of the material
73
+ 4. generates the training set from the YAML and XYZ files using generate_trainset_from_yaml
74
+
75
+ """
76
+
77
+ # =============================================================================
78
+ # 1. ELASTIC_ENERGY SECTION
79
+ # =============================================================================
80
+
81
+ # -----------------------------------------------------------------------------
82
+ # Constants (match Fortran code)
83
+ # -----------------------------------------------------------------------------
84
+
85
# Physical constants used to build the elastic-energy prefactors.
# Fortran names: AVOGADRO_CONSTANT is `NA` and ENERGY_CONVERSION_FACTOR is
# `factor` in the original program. The literals 10.0 and 4.184 are copied
# verbatim from the Fortran code (4.184 is presumably the J/cal conversion).
# The unit of AVOGADRO_CONSTANT is whatever `const` returns (presumably from
# data/constants.yaml); NOTE(review): confirm the resulting energy unit.
AVOGADRO_CONSTANT = const("AVOGADRO_CONSTANT")
ENERGY_CONVERSION_FACTOR = (10.0 * 4.184) / AVOGADRO_CONSTANT
89
+
90
+ # -----------------------------------------------------------------------------
91
+ # Small utilities for elastic_energy
92
+ # -----------------------------------------------------------------------------
93
+
94
+ def _fortran_nint(nonnegative_value: float) -> int:
95
+ """
96
+ Round a value to the nearest integer using Fortran-style NINT behavior.
97
+
98
+ The Fortran code uses nint(tstrain/dstrain) with tstrain>=0.
99
+ We implement "round half up" for x>=0 (ties are rare for these inputs).
100
+
101
+ Works on
102
+ --------
103
+ Numeric values (utility)
104
+
105
+ Parameters
106
+ ----------
107
+ nonnegative_value : float
108
+ Value to round.
109
+
110
+ Returns
111
+ -------
112
+ int
113
+ Nearest integer (ties rounded away from zero for typical non-negative inputs).
114
+
115
+ Examples
116
+ --------
117
+ >>> _fortran_nint(1.2)
118
+ 1
119
+ >>> _fortran_nint(1.5)
120
+ 2
121
+ """
122
+ # We support negative too, though it won't be used here.
123
+ if nonnegative_value < 0:
124
+ return -_fortran_nint(-nonnegative_value)
125
+ return int(math.floor(nonnegative_value + 0.5))
126
+
127
+
128
+ def _compute_orthogonal_lattice_cell_volume(
129
+ *,
130
+ a_length: float,
131
+ b_length: float,
132
+ c_length: float,
133
+ alpha_deg: float,
134
+ beta_deg: float,
135
+ gamma_deg: float,
136
+ ) -> float:
137
+ """
138
+ Compute the unit-cell volume from lattice lengths and angles.
139
+
140
+ This follows the same trig construction used in the Fortran snippet
141
+ (with cosphi/sinphi). For orthogonal cells, reduces to a*b*c.
142
+
143
+ Fortran mapping
144
+ --------------
145
+ - a_length, b_length, c_length correspond to L2(1), L2(2), L2(3) when used for the bulk volume
146
+ (and also for elastic volume in your snippet).
147
+ - alpha_deg, beta_deg, gamma_deg correspond to ang2(1..3).
148
+
149
+ Works on
150
+ --------
151
+ Lattice parameters for (typically) orthogonal cells
152
+
153
+ Parameters
154
+ ----------
155
+ a_length, b_length, c_length : float
156
+ Cell edge lengths (Å).
157
+ alpha_deg, beta_deg, gamma_deg : float
158
+ Cell angles (degrees).
159
+
160
+ Returns
161
+ -------
162
+ float
163
+ Unit-cell volume (Å^3).
164
+
165
+ Examples
166
+ --------
167
+ >>> _compute_orthogonal_lattice_cell_volume(
168
+ ... a_length=2.0, b_length=3.0, c_length=4.0,
169
+ ... alpha_deg=90.0, beta_deg=90.0, gamma_deg=90.0
170
+ ... )
171
+ 24.0
172
+ """
173
+ degrees_to_radians = math.pi / 180.0 # degrees_to_radians is dgrrdn in Fortran (via rdndgr)
174
+ alpha_rad = alpha_deg * degrees_to_radians # alpha_rad is halfa in Fortran
175
+ beta_rad = beta_deg * degrees_to_radians # beta_rad is hbeta in Fortran
176
+ gamma_rad = gamma_deg * degrees_to_radians # gamma_rad is hgamma in Fortran
177
+
178
+ sin_alpha = math.sin(alpha_rad) # sin_alpha is sinalf in Fortran
179
+ cos_alpha = math.cos(alpha_rad) # cos_alpha is cosalf in Fortran
180
+ sin_beta = math.sin(beta_rad) # sin_beta is sinbet in Fortran
181
+ cos_beta = math.cos(beta_rad) # cos_beta is cosbet in Fortran
182
+
183
+ denominator = sin_alpha * sin_beta # denominator is sinalf*sinbet in Fortran
184
+ if abs(denominator) < 1e-16:
185
+ raise ValueError("Invalid cell angles: sin(alpha)*sin(beta) ~ 0, cannot compute volume.")
186
+
187
+ cos_intermediate_angle = (math.cos(gamma_rad) - cos_alpha * cos_beta) / denominator
188
+ # cos_intermediate_angle is cosphi in the original Fortran code
189
+ # Fortran clamps only > 1.0; we clamp to [-1,1] for numerical safety.
190
+ cos_intermediate_angle = max(-1.0, min(1.0, cos_intermediate_angle))
191
+
192
+ sin_intermediate_angle = math.sqrt(max(0.0, 1.0 - cos_intermediate_angle * cos_intermediate_angle))
193
+ # sin_intermediate_angle is sinphi in the original Fortran code
194
+
195
+ # Build the same determinant components as the Fortran snippet.
196
+ basis_x_component = a_length * sin_beta * sin_intermediate_angle
197
+ # basis_x_component is tm11 in the original Fortran code
198
+
199
+ basis_y_component = b_length * sin_alpha
200
+ # basis_y_component is tm22 in the original Fortran code
201
+
202
+ basis_z_component = c_length
203
+ # basis_z_component is tm33 in the original Fortran code
204
+
205
+ reference_volume = basis_x_component * basis_y_component * basis_z_component
206
+ # reference_volume is vol in the original Fortran code
207
+
208
+ return reference_volume
209
+
210
+
211
def _build_symmetric_grid(
    *,
    max_abs_value: float,
    step: float,
    grid_mode: str,
) -> List[float]:
    """
    Build a symmetric grid spanning [-max_abs_value, +max_abs_value].

    The half-width (Fortran ``nstrain``) starts at ``nint(max_abs_value / step)``
    and is bumped by one when that many steps fall short of ``max_abs_value``.
    The outermost point therefore reaches or slightly exceeds ``max_abs_value``
    (ceil semantics, up to floating-point rounding). Grid points are
    ``step * k`` for integer ``k`` in ``[-nstrain, +nstrain]``.

    Fortran mapping
    ---------------
    - max_abs_value : ``tstrain2`` in the bulk block, ``tstrain`` in elastic blocks
    - step          : ``dstrain`` (0.004 in the bulk block, 0.005 in elastic blocks)
    - half-width    : ``nstrain``

    Works on
    --------
    Strain/parameter grids for bulk and elastic target generation

    Parameters
    ----------
    max_abs_value : float
        Maximum absolute value to cover; expected to be non-negative.
    step : float
        Grid spacing (> 0).
    grid_mode : {"bulk","elastic"}
        Generator block the grid is for. It is only validated; both modes apply
        the same rule (the Fortran bulk block used ``.ge``, the elastic block
        used ``.eq`` and fell through to the same result).

    Returns
    -------
    list[float]
        Symmetric grid values including 0.

    Raises
    ------
    ValueError
        If ``step <= 0`` or ``grid_mode`` is not ``"bulk"``/``"elastic"``.

    Examples
    --------
    >>> _build_symmetric_grid(max_abs_value=0.01, step=0.005, grid_mode="elastic")
    [-0.01, -0.005, 0.0, 0.005, 0.01]
    """
    if step <= 0:
        raise ValueError("step must be positive.")

    nearest_count = _fortran_nint(max_abs_value / step)  # t in Fortran

    if grid_mode not in ("bulk", "elastic"):
        raise ValueError("grid_mode must be 'bulk' or 'elastic'.")

    # nstrain: add one more step when `nearest_count` steps do not reach the bound.
    half_width = nearest_count
    if nearest_count * step < max_abs_value:
        half_width += 1

    indices = range(-half_width, half_width + 1)
    return [step * k for k in indices]
269
+
270
+
271
+ def _make_label(prefix: str, signed_index: int) -> str:
272
+ """
273
+ Construct a trainset label for a signed strain index.
274
+
275
+ Construct labels like Fortran:
276
+ bulk_c0001, bulk_e0001, bulk_0
277
+ c11_c0001, c11_e0001, c11_0
278
+
279
+ Works on
280
+ --------
281
+ Trainset naming for bulk/elastic targets
282
+
283
+ Parameters
284
+ ----------
285
+ prefix : str
286
+ Label prefix (e.g., "bulk", "c11").
287
+ signed_index : int
288
+ Signed index (negative = compression, positive = expansion, 0 = reference).
289
+
290
+ Returns
291
+ -------
292
+ str
293
+ Label string (e.g., ``"bulk_c0001"``, ``"c11_e0002"``, ``"c11_0"``).
294
+
295
+ Examples
296
+ --------
297
+ >>> _make_label("bulk", -3)
298
+ 'bulk_c0003'
299
+ >>> _make_label("c11", 0)
300
+ 'c11_0'
301
+ """
302
+ if signed_index == 0:
303
+ return f"{prefix}_0"
304
+ compression_or_expansion = "c" if signed_index < 0 else "e"
305
+ return f"{prefix}_{compression_or_expansion}{abs(signed_index):04d}"
306
+
307
+
308
+ def _index_from_grid_value(grid_value: float, step: float) -> int:
309
+ """
310
+ Convert a grid value to its signed integer index for labeling.
311
+
312
+ Convert a grid value back to its signed integer index (used for label naming).
313
+ Fortran uses eps = step*n exactly, so this is safe; we still round for robustness.
314
+
315
+ Works on
316
+ --------
317
+ Grid-to-index conversion for trainset label naming
318
+
319
+ Parameters
320
+ ----------
321
+ grid_value : float
322
+ Grid value (typically a strain).
323
+ step : float
324
+ Grid spacing.
325
+
326
+ Returns
327
+ -------
328
+ int
329
+ Signed integer index corresponding to ``grid_value / step``.
330
+
331
+ Examples
332
+ --------
333
+ >>> _index_from_grid_value(-0.01, 0.005)
334
+ -2
335
+ """
336
+ if step == 0:
337
+ return 0
338
+ return int(round(grid_value / step))
339
+
340
+
341
+ # -----------------------------------------------------------------------------
342
+ # Public API: generators for elastic_energy
343
+ # -----------------------------------------------------------------------------
344
+
345
def _generate_bulk_data(
    *,
    bulk_modulus_gpa: float,
    bulk_modulus_pressure_derivative: float,
    max_volumetric_strain_percent: float,
    cell: Dict[str, float],
    linear_strain_step: float = 0.004,
    reference_energy: float = 0.0,
) -> Tuple[List[Tuple[float, float]], List[str]]:
    """
    Generate bulk EOS energy-vs-volume targets for trainset fitting.

    The reference volume ``V0`` is computed from ``cell``. The volumetric strain
    bound is converted into the equivalent isotropic linear strain
    ``(1 + max%/100)**(1/3) - 1``. A symmetric linear-strain grid is sampled with
    spacing ``linear_strain_step``. Each strained volume ``V0 * (1 + eps)**3`` is
    evaluated with ``vinet_energy_trainset``. Energies that come out exactly
    0.0 are replaced by ``1e-4`` in both outputs, mirroring the Fortran
    safeguard.

    Works on
    --------
    Elastic-energy training targets (bulk EOS), driven by lattice parameters

    Parameters
    ----------
    bulk_modulus_gpa : float
        Bulk modulus B0 (GPa).
    bulk_modulus_pressure_derivative : float
        Pressure derivative B0' (dimensionless).
    max_volumetric_strain_percent : float
        Maximum volumetric strain magnitude (%).
    cell : dict
        Reference cell parameters with keys: ``a``, ``b``, ``c``, ``alpha``, ``beta``, ``gamma``.
    linear_strain_step : float, optional
        Sampling step for the equivalent linear strain grid.
    reference_energy : float, optional
        Reference energy offset E0.

    Returns
    -------
    tuple[list[tuple[float, float]], list[str]]
        (1) Table rows of ``(volume, energy)`` for writing ``EvsStrain_bulk.dat``.
        (2) Trainset ENERGY-block lines for inserting into ``trainset_elastic.in``.
        Each line has the form ``" 1.0 + <label> /1 - bulk_0 /1 <energy>"``, with
        the label left-aligned in 11 columns and the energy formatted ``12.4f``.

    Raises
    ------
    KeyError
        If ``cell`` lacks one of the required keys.
    ValueError
        Propagated from the cell-volume or grid helpers (e.g. invalid angles,
        non-positive step).

    Examples
    --------
    >>> cell = {"a": 2.9, "b": 2.9, "c": 3.5, "alpha": 90, "beta": 90, "gamma": 90}
    >>> table, lines = _generate_bulk_data(
    ...     bulk_modulus_gpa=180.0,
    ...     bulk_modulus_pressure_derivative=4.0,
    ...     max_volumetric_strain_percent=6.0,
    ...     cell=cell,
    ... )
    """
    # vol in Fortran (computed from L2/ang2)
    v0 = _compute_orthogonal_lattice_cell_volume(
        a_length=cell["a"],
        b_length=cell["b"],
        c_length=cell["c"],
        alpha_deg=cell["alpha"],
        beta_deg=cell["beta"],
        gamma_deg=cell["gamma"],
    )

    # tstrain2 in Fortran: linear strain eps with (1 + eps)**3 = 1 + max volumetric strain.
    eps_max = (1.0 + max_volumetric_strain_percent / 100.0) ** (1.0 / 3.0) - 1.0

    # Values of the Fortran n-loop from -nstrain..+nstrain (bulk block).
    strains = _build_symmetric_grid(
        max_abs_value=eps_max,
        step=linear_strain_step,
        grid_mode="bulk",
    )

    label_width = 11  # fits "bulk_c0005"-style labels and pads "bulk_0"
    table: List[Tuple[float, float]] = []
    lines: List[str] = []

    for eps in strains:
        n = _index_from_grid_value(eps, linear_strain_step)  # n in Fortran
        volume = v0 * (1.0 + eps) ** 3  # strain0 in Fortran

        energy = vinet_energy_trainset(
            volume=volume,
            reference_volume=v0,
            bulk_modulus_gpa=bulk_modulus_gpa,
            bulk_modulus_pressure_derivative=bulk_modulus_pressure_derivative,
            reference_energy=reference_energy,
            energy_conversion_factor=ENERGY_CONVERSION_FACTOR,
        )  # db in Fortran
        energy = energy if energy != 0.0 else 1e-4  # Fortran safeguard on db == 0

        label = _make_label("bulk", n)  # title0 in Fortran
        # Fortran: ' 1.0 + ' title0 ' /1 - ' 'bulk_0 /1' ' ' db
        lines.append(f" 1.0 + {label:<{label_width}} /1 - bulk_0 /1 {energy:12.4f}")
        table.append((volume, energy))

    return table, lines
442
+
443
+
444
def _generate_elastic_data(
    *,
    elastic_constants_gpa: Dict[str, float],
    max_strain_percent: float,
    volume_reference_cell: Dict[str, float],
    strain_step: float = 0.005,
) -> Dict[str, Tuple[List[Tuple[float, float]], List[str]]]:
    """
    Generate elastic-constant energy-vs-strain targets for trainset fitting.

    For each mode the energy is ``E = a * eps**2`` (Fortran ``d = a*eps^2 + b*eps + c``
    with ``b = c = 0``). With ``V`` the reference-cell volume and
    ``f = ENERGY_CONVERSION_FACTOR``, the prefactors are:

    - normal   (c11, c22, c33): ``a = Cii * V / (2 f)``
    - coupling (c12, c13, c23): ``a = (-Cij + (Cii + Cjj) / 2) * V / f``
    - shear    (c44, c55, c66): ``a = 2 Cii * V / f``

    Energies that come out exactly 0.0 (always the case at ``eps = 0``) are
    replaced by ``1e-4``, matching the Fortran safeguard.

    Works on
    --------
    Elastic-energy training targets (c11..c66), driven by lattice parameters

    Parameters
    ----------
    elastic_constants_gpa : dict
        Elastic constants in GPa with keys:
        ``c11,c22,c33,c12,c13,c23,c44,c55,c66``.
    max_strain_percent : float
        Maximum linear strain magnitude (%).
    volume_reference_cell : dict
        Reference cell parameters with keys: ``a``, ``b``, ``c``, ``alpha``, ``beta``, ``gamma``.
    strain_step : float, optional
        Linear strain step size (unitless).

    Returns
    -------
    dict[str, tuple[list[tuple[float, float]], list[str]]]
        Mapping ``mode -> (table_rows, trainset_lines)`` in the order
        c11, c22, c33, c12, c13, c23, c44, c55, c66, where:
        - table_rows: ``[(strain, energy), ...]``
        - trainset_lines: ENERGY-block lines for ``trainset_elastic.in``
          (label left-aligned in 12 columns, energy formatted ``12.4f``)

    Raises
    ------
    KeyError
        If a required cell key or elastic constant is missing.
    ValueError
        Propagated from the cell-volume or grid helpers.

    Examples
    --------
    >>> cij = {"c11": 300, "c22": 300, "c33": 250, "c12": 120, "c13": 140,
    ...        "c23": 140, "c44": 80, "c55": 80, "c66": 60}
    >>> cell = {"a": 2.9, "b": 2.9, "c": 3.5, "alpha": 90, "beta": 90, "gamma": 90}
    >>> out = _generate_elastic_data(
    ...     elastic_constants_gpa=cij,
    ...     max_strain_percent=3.0,
    ...     volume_reference_cell=cell,
    ... )
    >>> sorted(out.keys())[:3]
    ['c11', 'c12', 'c13']
    """
    cell = volume_reference_cell
    # vol in Fortran; scales every elastic-energy prefactor.
    v0 = _compute_orthogonal_lattice_cell_volume(
        a_length=cell["a"],
        b_length=cell["b"],
        c_length=cell["c"],
        alpha_deg=cell["alpha"],
        beta_deg=cell["beta"],
        gamma_deg=cell["gamma"],
    )

    # tstrain in Fortran, expanded into the n = -nstrain..+nstrain loop values.
    strains = _build_symmetric_grid(
        max_abs_value=max_strain_percent / 100.0,
        step=strain_step,
        grid_mode="elastic",
    )

    cij = elastic_constants_gpa
    factor = ENERGY_CONVERSION_FACTOR

    # Quadratic prefactors (a11, a22, ... in Fortran), in output order.
    prefactors: Dict[str, float] = {}
    for mode in ("c11", "c22", "c33"):
        prefactors[mode] = cij[mode] * v0 / (2.0 * factor)
    for mode, (diag_i, diag_j) in (
        ("c12", ("c11", "c22")),
        ("c13", ("c11", "c33")),
        ("c23", ("c22", "c33")),
    ):
        prefactors[mode] = (-cij[mode] + (cij[diag_i] + cij[diag_j]) / 2.0) * v0 / factor
    for mode in ("c44", "c55", "c66"):
        prefactors[mode] = 2.0 * cij[mode] * v0 / factor

    label_width = 12
    result: Dict[str, Tuple[List[Tuple[float, float]], List[str]]] = {}

    for mode, a in prefactors.items():
        rows: List[Tuple[float, float]] = []
        lines: List[str] = []
        for eps in strains:
            n = _index_from_grid_value(eps, strain_step)  # n in Fortran
            energy = a * (eps ** 2)  # d11/d22/... in Fortran
            if energy == 0.0:
                energy = 1e-4  # Fortran: dxx = 0 -> 0.0001
            label = _make_label(mode, n)  # title0 in Fortran
            lines.append(f" 1.0 + {label:<{label_width}} /1 - {mode}_0 /1 {energy:12.4f}")
            rows.append((eps, energy))
        result[mode] = (rows, lines)

    return result
561
+
562
+
563
+ def generate_all_energy_vs_volume_data(
564
+ *,
565
+ out_dir: str,
566
+ bulk_inputs: Dict[str, float],
567
+ elastic_inputs: Dict[str, float],
568
+ bulk_cell: Dict[str, float],
569
+ elastic_volume_cell: Optional[Dict[str, float]] = None,
570
+ bulk_options: Optional[Dict[str, float]] = None,
571
+ elastic_options: Optional[Dict[str, float]] = None,
572
+ trainset_filename: str = "trainset_elastic.in",
573
+ ) -> None:
574
+ """
575
+ Write bulk and elastic energy targets to trainset and table files.
576
+
577
+ High-level generator that:
578
+ 1) generates bulk + elastic energy targets
579
+ 2) writes:
580
+ - trainset_elastic.in
581
+ - EvsStrain_bulk.dat
582
+ - EvsStrain_c11.dat ... EvsStrain_c66.dat
583
+
584
+ This function performs both generation and writing and returns None.
585
+
586
+ Works on
587
+ --------
588
+ Elastic-energy training targets written to disk (trainset + tables)
589
+
590
+ Parameters
591
+ ----------
592
+ out_dir : str
593
+ Output directory to write files into.
594
+ bulk_inputs : dict
595
+ Bulk target inputs with keys such as ``B0_gpa``, ``B0_prime``,
596
+ and ``max_volumetric_strain_percent``.
597
+ elastic_inputs : dict
598
+ Elastic target inputs including ``max_strain_percent`` and ``cij`` values.
599
+ bulk_cell : dict
600
+ Bulk reference cell with keys: ``a,b,c,alpha,beta,gamma``.
601
+ elastic_volume_cell : dict or None, optional
602
+ Cell used to compute volume prefactors for elastic targets. If None, uses ``bulk_cell``.
603
+ bulk_options : dict or None, optional
604
+ Optional overrides (e.g., ``linear_strain_step``, ``reference_energy``).
605
+ elastic_options : dict or None, optional
606
+ Optional overrides (e.g., ``strain_step``).
607
+ trainset_filename : str, optional
608
+ Output trainset file name (default: ``"trainset_elastic.in"``).
609
+
610
+ Returns
611
+ -------
612
+ None
613
+ Writes ``trainset_elastic.in`` and E-vs-strain/volume tables to ``out_dir``.
614
+
615
+ Examples
616
+ --------
617
+ >>> generate_all_energy_vs_volume_data(
618
+ ... out_dir="out",
619
+ ... bulk_inputs={"B0_gpa": 180, "B0_prime": 4.0, "max_volumetric_strain_percent": 6.0},
620
+ ... elastic_inputs={"max_strain_percent": 3.0, "c11": 300, "c22": 300, "c33": 250,
621
+ ... "c12": 120, "c13": 140, "c23": 140, "c44": 80, "c55": 80, "c66": 60},
622
+ ... bulk_cell={"a": 2.9, "b": 2.9, "c": 3.5, "alpha": 90, "beta": 90, "gamma": 90},
623
+ ... )
624
+ """
625
+ import os
626
+
627
+ bulk_options = bulk_options or {}
628
+ elastic_options = elastic_options or {}
629
+ elastic_volume_cell = elastic_volume_cell or bulk_cell
630
+
631
+ os.makedirs(out_dir, exist_ok=True)
632
+
633
+ # -------------------------
634
+ # Orthogonality warnings
635
+ # -------------------------
636
+ def _warn_if_nonorthogonal(cell: dict, label: str) -> None:
637
+ angles = [cell.get("alpha", 90.0),
638
+ cell.get("beta", 90.0),
639
+ cell.get("gamma", 90.0)]
640
+ tol = 1e-6
641
+ if any(abs(a - 90.0) > tol for a in angles):
642
+ print(
643
+ f"⚠️ WARNING: {label} cell is non-orthogonal "
644
+ f"(angles = {angles}).\n"
645
+ " Elastic energy targets assume an orthogonal lattice.\n"
646
+ " Geometry generation is correct, but elastic energies may be inconsistent.\n"
647
+ )
648
+
649
+ _warn_if_nonorthogonal(elastic_volume_cell, label="Elastic")
650
+ _warn_if_nonorthogonal(bulk_cell, label="Bulk")
651
+
652
+ # -------------------------
653
+ # Bulk targets
654
+ # -------------------------
655
+ bulk_table, bulk_trainset_lines = _generate_bulk_data(
656
+ bulk_modulus_gpa=bulk_inputs["B0_gpa"],
657
+ bulk_modulus_pressure_derivative=bulk_inputs["B0_prime"],
658
+ max_volumetric_strain_percent=bulk_inputs["max_volumetric_strain_percent"],
659
+ cell=bulk_cell,
660
+ linear_strain_step=float(bulk_options.get("linear_strain_step", 0.004)),
661
+ reference_energy=float(bulk_options.get("reference_energy", 0.0)),
662
+ )
663
+
664
+ # -------------------------
665
+ # Elastic targets
666
+ # -------------------------
667
+ elastic_constants = {
668
+ k: elastic_inputs[k]
669
+ for k in ("c11", "c22", "c33", "c12", "c13", "c23", "c44", "c55", "c66")
670
+ }
671
+
672
+ elastic_targets = _generate_elastic_data(
673
+ elastic_constants_gpa=elastic_constants,
674
+ max_strain_percent=elastic_inputs["max_strain_percent"],
675
+ volume_reference_cell=elastic_volume_cell,
676
+ strain_step=float(elastic_options.get("strain_step", 0.005)),
677
+ )
678
+
679
+ # -------------------------
680
+ # Write trainset file
681
+ # -------------------------
682
+ mode_order = ["c11", "c22", "c33", "c12", "c13", "c23", "c44", "c55", "c66"]
683
+
684
+ trainset_lines: List[str] = [
685
+ "ENERGY",
686
+ "# Volume Bulk_EOS",
687
+ *bulk_trainset_lines,
688
+ ]
689
+
690
+ for mode in mode_order:
691
+ trainset_lines.append(f"# Volume {mode.upper()}_EOS")
692
+ trainset_lines.extend(elastic_targets[mode][1])
693
+
694
+ trainset_lines.append("ENDENERGY")
695
+
696
+ with open(os.path.join(out_dir, trainset_filename), "w", encoding="utf-8") as f:
697
+ f.write("\n".join(trainset_lines) + "\n")
698
+
699
+ # -------------------------
700
+ # Write tables
701
+ # -------------------------
702
+ def _write_two_column_table(path: str, header: str, rows):
703
+ with open(path, "w", encoding="utf-8") as f:
704
+ f.write(header.rstrip() + "\n")
705
+ for x, y in rows:
706
+ f.write(f"{x:8.3f} {y:12.4f}\n")
707
+
708
+ _write_two_column_table(
709
+ os.path.join(out_dir, "EvsStrain_bulk.dat"),
710
+ "# Volume Energy",
711
+ bulk_table,
712
+ )
713
+
714
+ for mode in mode_order:
715
+ _write_two_column_table(
716
+ os.path.join(out_dir, f"EvsStrain_{mode}.dat"),
717
+ "# Strain Energy",
718
+ elastic_targets[mode][0],
719
+ )
720
+
721
+
722
+ # =============================================================================
723
+ # 2. ELASTIC_GEO SECTION
724
+ # =============================================================================
725
+
726
+
727
+ # -----------------------------
728
+ # Strain matrices
729
+ # -----------------------------
730
+
731
+ def _deformation_matrix(mode: str, eps: float) -> np.ndarray:
732
+ """
733
+ Build a 3x3 deformation matrix for a named strain mode.
734
+
735
+ Return 3x3 deformation matrix d(mode, eps) similar to elastic_geo Fortran.
736
+ mode: bulk, c11,c22,c33,c12,c13,c23,c44,c55,c66
737
+
738
+ Works on
739
+ --------
740
+ Strain-mode deformation matrices for strained-geometry generation
741
+
742
+ Parameters
743
+ ----------
744
+ mode : str
745
+ Strain mode (e.g., ``"bulk"``, ``"c11"``, ``"c44"``).
746
+ eps : float
747
+ Strain magnitude (unitless).
748
+
749
+ Returns
750
+ -------
751
+ numpy.ndarray
752
+ 3x3 deformation matrix.
753
+
754
+ Examples
755
+ --------
756
+ >>> D = _deformation_matrix("c11", 0.01)
757
+ >>> D.shape
758
+ (3, 3)
759
+ """
760
+ I = np.eye(3, dtype=float)
761
+
762
+ if mode == "bulk":
763
+ return np.diag([1.0 + eps, 1.0 + eps, 1.0 + eps])
764
+
765
+ if mode == "c11":
766
+ d = I.copy()
767
+ d[0, 0] = 1.0 + eps
768
+ return d
769
+ if mode == "c22":
770
+ d = I.copy()
771
+ d[1, 1] = 1.0 + eps
772
+ return d
773
+ if mode == "c33":
774
+ d = I.copy()
775
+ d[2, 2] = 1.0 + eps
776
+ return d
777
+
778
+ # Coupled modes (Fortran uses u = 1/sqrt(1-eps^2))
779
+ if mode in {"c12", "c13", "c23"}:
780
+ u = 1.0 / np.sqrt(max(1e-30, 1.0 - eps * eps))
781
+ d = I.copy()
782
+ if mode == "c12":
783
+ d[0, 0] = u * (1.0 + eps)
784
+ d[1, 1] = u * (1.0 - eps)
785
+ d[2, 2] = 1.0
786
+ elif mode == "c13":
787
+ d[0, 0] = u * (1.0 + eps)
788
+ d[2, 2] = u * (1.0 - eps)
789
+ d[1, 1] = 1.0
790
+ else: # c23
791
+ d[1, 1] = u * (1.0 + eps)
792
+ d[2, 2] = u * (1.0 - eps)
793
+ d[0, 0] = 1.0
794
+ return d
795
+
796
+ # Shear modes (Fortran uses u = 1/(1-eps^2)^(1/3))
797
+ if mode in {"c44", "c55", "c66"}:
798
+ u = 1.0 / (max(1e-30, 1.0 - eps * eps) ** (1.0 / 3.0))
799
+ d = I.copy()
800
+ if mode == "c44":
801
+ d[1, 2] = eps
802
+ d[2, 1] = eps
803
+ elif mode == "c55":
804
+ d[0, 2] = eps
805
+ d[2, 0] = eps
806
+ else: # c66
807
+ d[0, 1] = eps
808
+ d[1, 0] = eps
809
+ return u * d
810
+
811
+ raise ValueError(f"Unknown mode: {mode!r}")
812
+
813
+
814
+ def _symmetric_strain_grid(max_abs: float, step: float) -> List[float]:
815
+ """
816
+ Build a symmetric strain grid spanning [-max_abs, +max_abs].
817
+
818
+ Works on
819
+ --------
820
+ Strain grids for strained-geometry generation
821
+
822
+ Parameters
823
+ ----------
824
+ max_abs : float
825
+ Maximum absolute strain magnitude.
826
+ step : float
827
+ Strain step size.
828
+
829
+ Returns
830
+ -------
831
+ list[float]
832
+ Symmetric strain grid including 0.
833
+
834
+ Examples
835
+ --------
836
+ >>> _symmetric_strain_grid(0.01, 0.005)
837
+ [-0.01, -0.005, 0.0, 0.005, 0.01]
838
+ """
839
+ n = int(np.ceil(max_abs / step))
840
+ grid = [k * step for k in range(-n, n + 1)]
841
+ grid = [x for x in grid if abs(x) <= max_abs + 1e-12]
842
+ if 0.0 not in grid:
843
+ grid.append(0.0)
844
+ grid.sort()
845
+ return grid
846
+
847
+
848
+ def _strain_title(prefix: str, eps: float, idx_abs: int) -> str:
849
+ """
850
+ Format a canonical title string for a strain state.
851
+
852
+ Works on
853
+ --------
854
+ Strain-state naming for strained-geometry outputs
855
+
856
+ Parameters
857
+ ----------
858
+ prefix : str
859
+ Mode prefix (e.g., ``"bulk"``, ``"c11"``).
860
+ eps : float
861
+ Strain value (unitless).
862
+ idx_abs : int
863
+ Absolute strain-step index.
864
+
865
+ Returns
866
+ -------
867
+ str
868
+ Title string used for output file naming.
869
+
870
+ Examples
871
+ --------
872
+ >>> _strain_title("bulk", -0.01, idx_abs=2)
873
+ 'bulk_c0002'
874
+ """
875
+
876
+ if abs(eps) < 1e-15:
877
+ return f"{prefix}_0"
878
+ return f"{prefix}_{'c' if eps < 0 else 'e'}{idx_abs:04d}"
879
+
880
+
881
def _make_base_atoms_from_xyz_and_cell(
    xyz_path: str | Path,
    cell: np.ndarray,
) -> Atoms:
    """
    Load an XYZ structure and turn it into a periodic structure object.

    The file is read with ``read_structure(..., format="xyz")``. ``cell`` is
    then attached with ``scale_atoms=False``, so the Cartesian positions are
    left exactly as read. Finally, periodic boundary conditions are enabled.

    Parameters
    ----------
    xyz_path : str or pathlib.Path
        XYZ file to read.
    cell : numpy.ndarray
        3x3 cell matrix to attach.

    Returns
    -------
    Atoms
        The loaded structure with ``cell`` attached and PBC enabled.
    """
    structure = read_structure(xyz_path, format="xyz")
    structure.set_cell(cell, scale_atoms=False)
    structure.set_pbc(True)
    return structure
892
+
893
+
894
def _xtob_strain_step_index(eps: float, step: float) -> int:
    """Return the Fortran-style step index ``abs(n)`` for ``eps ~= n * step`` (0 for ~zero strain)."""
    if abs(eps) < 1e-15:
        return 0
    return abs(int(round(eps / step)))


def _write_strained_xyz_and_geo(
    *,
    mode: str,
    eps: float,
    step: float,
    base_atoms: "Atoms",
    frac_positions: np.ndarray,
    ref_cell: np.ndarray,
    xyz_dir: Path,
    geo_dir: Path,
    geo_suffix: str,
    sort_by: Optional[str],
) -> Path:
    """Write one strained XYZ file, convert it with xtob, and return the GEO path."""
    deformation = _deformation_matrix(mode, eps)
    # ASE stores lattice vectors as rows, so a deformation F acting on column
    # vectors transforms the cell as ``cell @ F.T``. (``F @ cell`` would mix
    # lattice vectors and give the wrong shear strain for non-cubic cells.)
    strained_cell = ref_cell @ deformation.T
    cellpar = cell_to_cellpar(strained_cell)
    a, b, c, alpha, beta, gamma = cellpar
    # Re-express the strained lattice in ASE's standard orientation (the one
    # cellpar_to_cell builds). xtob only receives lengths/angles, so the
    # Cartesian XYZ coordinates must be in that same frame; this is a rigid
    # rotation of the crystal.
    standard_cell = cellpar_to_cell(cellpar)

    title = _strain_title(mode, eps, idx_abs=_xtob_strain_step_index(eps, step))
    xyz_path = xyz_dir / f"{title}.xyz"
    geo_path = geo_dir / f"{title}{geo_suffix}"

    atoms = base_atoms.copy()
    atoms.set_cell(standard_cell, scale_atoms=False)
    # Fractional coordinates are held fixed, so atoms follow the cell (affine strain).
    atoms.set_scaled_positions(frac_positions)

    # The title goes into the XYZ comment (line 2).
    write_structure(atoms, xyz_path, format="xyz", comment=title)

    xtob(
        xyz_file=xyz_path,
        geo_file=geo_path,
        box_lengths=(float(a), float(b), float(c)),
        box_angles=(float(alpha), float(beta), float(gamma)),
        sort_by=sort_by,
        ascending=True,
    )
    return geo_path


def generate_strained_geometries_with_xtob(
    *,
    elastic_xyz: str | Path,
    bulk_xyz: Optional[str | Path],
    elastic_cell: Dict[str, float],  # keys: a,b,c,alpha,beta,gamma
    bulk_cell: Dict[str, float],
    max_strain_elastic: float,  # e.g. 0.02 for ±2%
    dstrain_elastic: float,  # e.g. 0.005
    max_strain_bulk_linear: float,  # linear strain, not volumetric
    dstrain_bulk_linear: float,  # e.g. 0.004
    out_dir: str | Path,
    sort_by: Optional[str] = None,
    geo_suffix: str = ".bgf",
) -> Dict[str, List[Path]]:
    """
    Generate strained XYZ structures and convert them to GEO via xtob.

    For every strain on a symmetric grid, the reference cell is deformed with
    ``_deformation_matrix``. Atoms keep their fractional coordinates. Each
    structure is written as XYZ, with the strain title as the line-2 comment,
    and then converted with ``xtob()``.

    Output folders:
        out_dir/xyz_strained/<title>.xyz
        out_dir/geo_strained/<title><geo_suffix>   (default suffix ``.bgf``)

    Works on
    --------
    XYZ input structures + GEO/XTLGRF outputs via ``xtob``

    Parameters
    ----------
    elastic_xyz : str or pathlib.Path
        Base XYZ used for elastic strain modes.
    bulk_xyz : str or pathlib.Path or None
        Optional base XYZ used for bulk mode. If None, reuse ``elastic_xyz``
        positions with ``bulk_cell`` attached.
    elastic_cell : dict
        Elastic reference cell with keys: ``a,b,c,alpha,beta,gamma``.
    bulk_cell : dict
        Bulk reference cell with keys: ``a,b,c,alpha,beta,gamma``.
    max_strain_elastic : float
        Maximum absolute linear strain for elastic modes (unitless).
    dstrain_elastic : float
        Linear strain step for elastic modes (unitless).
    max_strain_bulk_linear : float
        Maximum absolute linear bulk strain (unitless).
    dstrain_bulk_linear : float
        Linear bulk strain step (unitless).
    out_dir : str or pathlib.Path
        Output directory where ``xyz_strained`` and ``geo_strained`` are created.
    sort_by : str or None, optional
        Sorting key passed to ``xtob`` (e.g., ``"z"``).
    geo_suffix : str, optional
        File suffix for GEO outputs. It is used for bulk and elastic modes
        alike (default ``".bgf"``).

    Returns
    -------
    dict[str, list[pathlib.Path]]
        Mapping mode name (``"bulk"``, ``"c11"`` ... ``"c66"``) to the written
        GEO paths, in ascending strain order.

    Notes
    -----
    Deformations are applied as ``cell @ F.T`` (ASE row-vector convention).
    The strained lattice is then rotated into the orientation produced by
    ``cellpar_to_cell``, so the XYZ coordinates match the box lengths and
    angles handed to ``xtob``.

    Examples
    --------
    >>> cell = {"a": 2.9, "b": 2.9, "c": 3.5, "alpha": 90, "beta": 90, "gamma": 90}
    >>> out = generate_strained_geometries_with_xtob(
    ...     elastic_xyz="ground_elastic.xyz",
    ...     bulk_xyz=None,
    ...     elastic_cell=cell,
    ...     bulk_cell=cell,
    ...     max_strain_elastic=0.02,
    ...     dstrain_elastic=0.005,
    ...     max_strain_bulk_linear=0.01,
    ...     dstrain_bulk_linear=0.004,
    ...     out_dir="out",
    ... )
    """
    out_dir = Path(out_dir)
    xyz_dir = out_dir / "xyz_strained"
    geo_dir = out_dir / "geo_strained"
    xyz_dir.mkdir(parents=True, exist_ok=True)
    geo_dir.mkdir(parents=True, exist_ok=True)

    cellpar_keys = ("a", "b", "c", "alpha", "beta", "gamma")
    cell_e = cellpar_to_cell([elastic_cell[k] for k in cellpar_keys])
    cell_b = cellpar_to_cell([bulk_cell[k] for k in cellpar_keys])

    # Base atoms for elastic modes.
    base_e = _make_base_atoms_from_xyz_and_cell(elastic_xyz, cell_e)
    frac_e = base_e.get_scaled_positions(wrap=False)

    # Base atoms for bulk mode (reuse elastic positions if no bulk XYZ is given).
    if bulk_xyz is None:
        base_b = base_e.copy()
        base_b.set_cell(cell_b, scale_atoms=False)
        base_b.set_pbc(True)
    else:
        base_b = _make_base_atoms_from_xyz_and_cell(bulk_xyz, cell_b)
    frac_b = base_b.get_scaled_positions(wrap=False)

    elastic_modes = ("c11", "c22", "c33", "c12", "c13", "c23", "c44", "c55", "c66")
    out: Dict[str, List[Path]] = {mode: [] for mode in ("bulk", *elastic_modes)}

    bulk_grid = _symmetric_strain_grid(max_strain_bulk_linear, dstrain_bulk_linear)
    elastic_grid = _symmetric_strain_grid(max_strain_elastic, dstrain_elastic)

    # Bulk first, then each elastic mode in order (same order as before).
    jobs = [("bulk", eps, dstrain_bulk_linear, base_b, frac_b, cell_b) for eps in bulk_grid]
    jobs += [
        (mode, eps, dstrain_elastic, base_e, frac_e, cell_e)
        for mode in elastic_modes
        for eps in elastic_grid
    ]

    for mode, eps, step, base_atoms, frac_positions, ref_cell in jobs:
        geo_path = _write_strained_xyz_and_geo(
            mode=mode,
            eps=eps,
            step=step,
            base_atoms=base_atoms,
            frac_positions=frac_positions,
            ref_cell=ref_cell,
            xyz_dir=xyz_dir,
            geo_dir=geo_dir,
            geo_suffix=geo_suffix,
            sort_by=sort_by,
        )
        out[mode].append(geo_path)

    return out
1062
+
1063
+ # =============================================================================
1064
+ # 3. YAML file management for settings of trainset
1065
+ # =============================================================================
1066
+
1067
+ # -----------------------------------------------------------------------------
1068
# YAML writer for the cell dimensions/angles and the other settings used to
# generate energy-vs-volume data for expanded or compressed cells
1070
+ # -----------------------------------------------------------------------------
1071
+
1072
+ def write_trainset_settings_yaml(
1073
+ *,
1074
+ out_path: str,
1075
+ name: str = "AlN example",
1076
+ source: str = "manual",
1077
+ mp_id: Optional[str] = None,
1078
+ # Elastic inputs
1079
+ elastic_max_strain_percent: float = 3.0,
1080
+ elastic_dstrain: float = 0.005,
1081
+ cij_gpa: Optional[Dict[str, float]] = None,
1082
+ elastic_cell: Optional[Dict[str, float]] = None,
1083
+ # Bulk inputs
1084
+ B0_gpa: float = 174.0,
1085
+ B0_prime: float = 1.5,
1086
+ bulk_max_volumetric_strain_percent: float = 6.0,
1087
+ bulk_dstrain_linear: float = 0.004,
1088
+ bulk_cell: Optional[Dict[str, float]] = None,
1089
+ # Output names
1090
+ trainset_file: str = "trainset_elastic.in",
1091
+ tables: Optional[Dict[str, str]] = None,
1092
+ elastic_xyz: Optional[str | Path] = "ground_elastic.xyz",
1093
+ bulk_xyz: Optional[str | Path] = "null",
1094
+ geo_enable: bool = True
1095
+ ) -> None:
1096
+ """
1097
+ Write a trainset settings YAML file for elastic-energy trainset generation.
1098
+
1099
+ Works on
1100
+ --------
1101
+ YAML configuration files for trainset generation (trainset_elastic.yaml)
1102
+
1103
+ Parameters
1104
+ ----------
1105
+ out_path : str
1106
+ Output YAML file path.
1107
+ name : str, optional
1108
+ Descriptive material name stored in metadata.
1109
+ source : str, optional
1110
+ Settings source label (e.g., ``"manual"`` or ``"materials_project"``).
1111
+ mp_id : str or None, optional
1112
+ Materials Project ID to store in metadata.
1113
+ elastic_max_strain_percent : float, optional
1114
+ Maximum elastic strain magnitude (%).
1115
+ elastic_dstrain : float, optional
1116
+ Elastic strain step size (unitless).
1117
+ cij_gpa : dict or None, optional
1118
+ Elastic constants in GPa with keys ``c11..c66``.
1119
+ elastic_cell : dict or None, optional
1120
+ Elastic reference cell with keys ``a,b,c,alpha,beta,gamma``.
1121
+ B0_gpa : float, optional
1122
+ Bulk modulus B0 (GPa).
1123
+ B0_prime : float, optional
1124
+ Bulk modulus pressure derivative B0' (dimensionless).
1125
+ bulk_max_volumetric_strain_percent : float, optional
1126
+ Maximum volumetric strain magnitude (%).
1127
+ bulk_dstrain_linear : float, optional
1128
+ Bulk linear strain step (unitless).
1129
+ bulk_cell : dict or None, optional
1130
+ Bulk reference cell with keys ``a,b,c,alpha,beta,gamma``.
1131
+ trainset_file : str, optional
1132
+ Trainset file name to store under output settings.
1133
+ tables : dict or None, optional
1134
+ Output table filenames keyed by mode (e.g., ``"bulk"``, ``"c11"``).
1135
+ elastic_xyz : str or pathlib.Path or None, optional
1136
+ Base XYZ used for elastic geometry generation when enabled.
1137
+ bulk_xyz : str or pathlib.Path or None, optional
1138
+ Optional base XYZ for bulk geometry generation when enabled.
1139
+ geo_enable : bool, optional
1140
+ Whether the YAML enables geometry generation.
1141
+
1142
+ Returns
1143
+ -------
1144
+ None
1145
+ Writes a YAML settings file to disk.
1146
+
1147
+ Examples
1148
+ --------
1149
+ >>> write_trainset_settings_yaml(
1150
+ ... out_path="trainset_elastic.yaml",
1151
+ ... name="AlN example",
1152
+ ... source="manual",
1153
+ ... )
1154
+ """
1155
+ import os
1156
+ from typing import List
1157
+
1158
+ # -------------------------
1159
+ # Defaults (match your example)
1160
+ # -------------------------
1161
+ if cij_gpa is None:
1162
+ cij_gpa = {
1163
+ "c11": 287,
1164
+ "c22": 287,
1165
+ "c33": 219,
1166
+ "c12": 100,
1167
+ "c13": 144,
1168
+ "c23": 144,
1169
+ "c44": 76,
1170
+ "c55": 76,
1171
+ "c66": 50,
1172
+ }
1173
+
1174
+ if elastic_cell is None:
1175
+ elastic_cell = {
1176
+ "a": 2.85086,
1177
+ "b": 2.85086,
1178
+ "c": 3.49456,
1179
+ "alpha": 90.0,
1180
+ "beta": 90.0,
1181
+ "gamma": 90.0,
1182
+ }
1183
+
1184
+ if bulk_cell is None:
1185
+ bulk_cell = {
1186
+ "a": 2.85086,
1187
+ "b": 2.85086,
1188
+ "c": 3.49456,
1189
+ "alpha": 90.0,
1190
+ "beta": 90.0,
1191
+ "gamma": 90.0,
1192
+ }
1193
+
1194
+ if tables is None:
1195
+ tables = {
1196
+ "bulk": "EvsStrain_bulk.dat",
1197
+ "c11": "EvsStrain_c11.dat",
1198
+ "c22": "EvsStrain_c22.dat",
1199
+ "c33": "EvsStrain_c33.dat",
1200
+ "c12": "EvsStrain_c12.dat",
1201
+ "c13": "EvsStrain_c13.dat",
1202
+ "c23": "EvsStrain_c23.dat",
1203
+ "c44": "EvsStrain_c44.dat",
1204
+ "c55": "EvsStrain_c55.dat",
1205
+ "c66": "EvsStrain_c66.dat",
1206
+ }
1207
+
1208
+ # -------------------------
1209
+ # Minimal validation
1210
+ # -------------------------
1211
+ required_cij = ("c11", "c22", "c33", "c12", "c13", "c23", "c44", "c55", "c66")
1212
+ missing = [k for k in required_cij if k not in cij_gpa]
1213
+ if missing:
1214
+ raise ValueError(f"cij_gpa is missing required keys: {missing}")
1215
+
1216
+ for cell_name, cell in (("elastic_cell", elastic_cell), ("bulk_cell", bulk_cell)):
1217
+ for k in ("a", "b", "c", "alpha", "beta", "gamma"):
1218
+ if k not in cell:
1219
+ raise ValueError(f"{cell_name} missing key '{k}'")
1220
+
1221
+ # -------------------------
1222
+ # YAML writing (manual; stable schema; no external deps)
1223
+ # -------------------------
1224
+ def _q(s: str) -> str:
1225
+ """Quote a string for YAML safely enough for our simple schema."""
1226
+ s2 = s.replace("\\", "\\\\").replace('"', '\\"')
1227
+ return f'"{s2}"'
1228
+
1229
+ mp_id_yaml = "null" if mp_id is None else _q(mp_id)
1230
+
1231
+ # Requirement (1): header comment must match the file name exactly
1232
+ yaml_filename = os.path.basename(out_path)
1233
+
1234
+ lines: List[str] = []
1235
+ lines.append(f"# {yaml_filename}")
1236
+ lines.append("")
1237
+ # Requirement (2): short docstring-like comment at top
1238
+ lines.append("# This is the settings file used by ReaxKit's trainset generator.")
1239
+ lines.append("# Edit the values below (especially strains/moduli/cell) to match your material/system.")
1240
+ lines.append("")
1241
+
1242
+ lines.append("metadata:")
1243
+ lines.append(f" name: {_q(name)}")
1244
+ lines.append(f" source: {_q(source)} # 'manual' or 'materials_project'")
1245
+ lines.append(f" mp_id: {mp_id_yaml} # Optional: e.g. \"mp-661\"")
1246
+ lines.append("")
1247
+
1248
+ lines.append("units:")
1249
+ lines.append(' elastic_constants: "GPa"')
1250
+ lines.append(' bulk_modulus: "GPa"')
1251
+ lines.append(' angles: "deg"')
1252
+ lines.append(' lengths: "angstrom"')
1253
+ lines.append(' strain: "percent"')
1254
+ lines.append("")
1255
+
1256
+ # Requirement (4): explain elastic vs bulk with 1–2 comment lines before each section
1257
+ lines.append("# Elastic section: generates energy-vs-strain targets for c11..c66 (small linear strains).")
1258
+ lines.append("# Use this for harmonic elastic response around the reference cell.")
1259
+ lines.append("elastic:")
1260
+ lines.append(f" max_strain_percent: {elastic_max_strain_percent} # Max linear strain magnitude (%) for elastic targets")
1261
+ # Requirement (3) + (5): inline comment for dstrain explaining what it is and user input
1262
+ lines.append(f" dstrain: {elastic_dstrain} # Strain step size (unitless). Default = 0.5% = 0.005")
1263
+ lines.append(" cij_gpa: # Elastic constants in GPa (c11,c22,c33,c12,c13,c23,c44,c55,c66)")
1264
+ for k in required_cij:
1265
+ lines.append(f" {k}: {cij_gpa[k]}")
1266
+ lines.append("")
1267
+ lines.append(" cell: # Elastic reference cell (a,b,c in Å; angles in deg)")
1268
+ lines.append(f" a: {elastic_cell['a']}")
1269
+ lines.append(f" b: {elastic_cell['b']}")
1270
+ lines.append(f" c: {elastic_cell['c']}")
1271
+ lines.append(f" alpha: {elastic_cell['alpha']}")
1272
+ lines.append(f" beta: {elastic_cell['beta']}")
1273
+ lines.append(f" gamma: {elastic_cell['gamma']}")
1274
+ lines.append("")
1275
+
1276
+ # PATCH 1: keep separate structure sections (and make intent explicit)
1277
+ lines.append("# Input structures (XYZ). Used when geo.enable=true.")
1278
+ lines.append("structure 1:")
1279
+ lines.append(f' elastic_xyz: {elastic_xyz} # required if geo.enable=true')
1280
+ lines.append("")
1281
+
1282
+ lines.append("# Bulk section: generates energy-vs-volume targets using an EOS (Vinet) over a wider strain range.")
1283
+ lines.append("# Use this to constrain compressibility (B0, B0') around the reference volume.")
1284
+ lines.append("bulk:")
1285
+ lines.append(f" B0_gpa: {B0_gpa} # Bulk modulus B0 at P=0 (GPa)")
1286
+ lines.append(f" B0_prime: {B0_prime} # Pressure derivative B0' = dB/dP at P=0 (dimensionless)")
1287
+ lines.append(f" max_volumetric_strain_percent: {bulk_max_volumetric_strain_percent} # Max volumetric strain magnitude (%)")
1288
+ # Requirement (5): inline comment for dstrain_linear
1289
+ lines.append(
1290
+ f" dstrain_linear: {bulk_dstrain_linear} # Linear isotropic strain step ε (unitless). "
1291
+ f"Volume uses V=V0*(1+ε)^3. Default = 0.4% = 0.004"
1292
+ )
1293
+ lines.append("")
1294
+ lines.append(" cell: # Bulk/EOS reference cell (used to compute V0; a,b,c in Å; angles in deg)")
1295
+ lines.append(f" a: {bulk_cell['a']}")
1296
+ lines.append(f" b: {bulk_cell['b']}")
1297
+ lines.append(f" c: {bulk_cell['c']}")
1298
+ lines.append(f" alpha: {bulk_cell['alpha']}")
1299
+ lines.append(f" beta: {bulk_cell['beta']}")
1300
+ lines.append(f" gamma: {bulk_cell['gamma']}")
1301
+ lines.append("")
1302
+
1303
+ # PATCH 1: second structure section
1304
+ lines.append("structure 2:")
1305
+ lines.append(f" bulk_xyz: {bulk_xyz} # optional; if null, reuse elastic_xyz")
1306
+ lines.append("")
1307
+
1308
+ # PATCH 1: geo generation options (opt-in)
1309
+ lines.append("# Geometry generation options.")
1310
+ lines.append("geo:")
1311
+ lines.append(f" enable: {geo_enable} # set true to generate strained xyz + geo")
1312
+ lines.append(" sort_by: null # e.g. 'z' or null")
1313
+ lines.append("")
1314
+
1315
+ # Output section: you can usually keep this as-is unless you want different filenames.
1316
+ lines.append("# Output section: you can usually keep this as-is unless you want different filenames.")
1317
+ lines.append("# Added: output folders for strained XYZ and GEO files.")
1318
+ lines.append("output:")
1319
+ lines.append(f" trainset_file: {_q(trainset_file)}")
1320
+ lines.append(f" xyz_strained_dir: {_q('xyz_strained')}")
1321
+ lines.append(f" geo_strained_dir: {_q('geo_strained')}")
1322
+ lines.append(" tables:")
1323
+ for key in ("bulk", "c11", "c22", "c33", "c12", "c13", "c23", "c44", "c55", "c66"):
1324
+ lines.append(f" {key}: {_q(tables[key])}")
1325
+ lines.append("")
1326
+
1327
+ with open(out_path, "w", encoding="utf-8") as f:
1328
+ f.write("\n".join(lines))
1329
+
1330
+
1331
+ # -----------------------------------------------------------------------------
1332
+ # reads a Yaml settings file and generates a trainset
1333
+ # -----------------------------------------------------------------------------
1334
+
1335
def read_trainset_settings_yaml(yaml_path: str) -> dict:
    """
    Load and lightly validate a trainset settings YAML file.

    The ``elastic``, ``bulk`` and ``output`` top-level sections must be present.
    When ``geo.enable`` is truthy, two further rules apply:

    - ``structure 1.elastic_xyz`` must be given and must exist.
    - ``structure 2.bulk_xyz`` is checked only when it is set.

    Relative XYZ paths are resolved against the YAML file's directory. The
    resulting absolute path strings are written back into the returned mapping.

    Works on
    --------
    YAML configuration files for trainset generation (trainset_elastic.yaml)

    Parameters
    ----------
    yaml_path : str
        Path to a YAML settings file.

    Returns
    -------
    dict
        Parsed configuration mapping containing ``elastic``, ``bulk`` and
        ``output`` sections (plus any others present in the file).

    Raises
    ------
    ImportError
        If PyYAML is not installed.
    FileNotFoundError
        If the YAML file, or a referenced XYZ file, does not exist.
    ValueError
        If the root is not a mapping, a required section is missing, or the
        structure sections are malformed while ``geo.enable`` is true.

    Examples
    --------
    >>> cfg = read_trainset_settings_yaml("trainset_elastic.yaml")
    >>> sorted(cfg.keys())[:3]
    ['bulk', 'elastic', 'metadata']
    """
    try:
        import yaml
    except ImportError as exc:
        raise ImportError(
            "PyYAML is required to read trainset YAML files. "
            "Install with: pip install pyyaml"
        ) from exc

    from pathlib import Path

    settings_file = Path(yaml_path)
    if not settings_file.exists():
        raise FileNotFoundError(f"YAML file does not exist: {settings_file}")

    with settings_file.open("r", encoding="utf-8") as handle:
        data = yaml.safe_load(handle)

    if not isinstance(data, dict):
        raise ValueError("YAML root must be a mapping/dictionary.")

    absent = [section for section in ("elastic", "bulk", "output") if section not in data]
    if absent:
        raise ValueError(f"YAML is missing required sections: {absent}")

    def _relative_to_settings(raw) -> Path:
        """Absolute paths pass through; relative ones are resolved next to the YAML file."""
        candidate = Path(raw)
        if candidate.is_absolute():
            return candidate
        return (settings_file.parent / candidate).resolve()

    geo_settings = data.get("geo", {}) or {}
    if not bool(geo_settings.get("enable", False)):
        # Structure sections are only required for geometry generation.
        return data

    structure_1 = data.get("structure 1")
    if not isinstance(structure_1, dict):
        raise ValueError("Missing required section: 'structure 1' (required when geo.enable=true)")

    elastic_raw = structure_1.get("elastic_xyz")
    if not elastic_raw:
        raise ValueError("Missing required key: structure 1.elastic_xyz (required when geo.enable=true)")

    elastic_path = _relative_to_settings(elastic_raw)
    if not elastic_path.exists():
        raise FileNotFoundError(f"structure 1.elastic_xyz does not exist: {elastic_path}")
    structure_1["elastic_xyz"] = str(elastic_path)

    structure_2 = data.get("structure 2", {}) or {}
    if not isinstance(structure_2, dict):
        raise ValueError("'structure 2' must be a mapping if provided")

    bulk_raw = structure_2.get("bulk_xyz")
    if bulk_raw:
        bulk_path = _relative_to_settings(bulk_raw)
        if not bulk_path.exists():
            raise FileNotFoundError(f"structure 2.bulk_xyz does not exist: {bulk_path}")
        # A truthy bulk_raw means structure_2 is the dict stored in ``data``.
        structure_2["bulk_xyz"] = str(bulk_path)

    return data
1422
+
1423
+
1424
def generate_trainset_from_yaml(
    yaml_path: str,
    out_dir: str,
    *,
    place_all_outputs_in_out_dir: bool = True,
    copy_input_xyz_into_out_dir: bool = True,
):
    """
    Generate a trainset and optional strained geometries from a YAML settings file.

    Works on
    --------
    YAML settings + XYZ inputs (optional) → trainset files and strained structures

    Parameters
    ----------
    yaml_path : str
        Path to the trainset settings YAML file.
    out_dir : str
        Output directory for the trainset file and E-vs-strain tables.
    place_all_outputs_in_out_dir : bool, optional
        If True, strained XYZ/GEO outputs also go under ``out_dir``. Otherwise
        they are written next to the YAML file.
    copy_input_xyz_into_out_dir : bool, optional
        If True and geometry generation is enabled, copy the input XYZ files
        into the geometry output directory and use those copies. If the elastic
        and bulk inputs are different files with the same name, the bulk copy
        is stored as ``bulk_<name>`` so it cannot overwrite the elastic one.

    Returns
    -------
    None
        Writes trainset files and (optionally) strained XYZ/GEO files to disk.

    Notes
    -----
    ``output.trainset_file`` from the YAML sets the trainset file name
    (default ``"trainset_elastic.in"``). This function does not read
    ``output.tables``, ``output.xyz_strained_dir`` or ``output.geo_strained_dir``.

    Examples
    --------
    >>> generate_trainset_from_yaml("trainset_elastic.yaml", out_dir="out")
    """
    cfg = read_trainset_settings_yaml(yaml_path)
    yaml_path_p = Path(yaml_path).resolve()

    out_dir_p = Path(out_dir).resolve()
    out_dir_p.mkdir(parents=True, exist_ok=True)

    bulk_cfg = cfg["bulk"]
    elastic_cfg = cfg["elastic"]
    output_cfg = cfg.get("output")
    if not isinstance(output_cfg, dict):
        output_cfg = {}

    # -------------------------
    # Inputs for the energy targets
    # -------------------------
    bulk_inputs = {
        "B0_gpa": bulk_cfg["B0_gpa"],
        "B0_prime": bulk_cfg["B0_prime"],
        "max_volumetric_strain_percent": bulk_cfg["max_volumetric_strain_percent"],
    }
    elastic_inputs = {
        "max_strain_percent": elastic_cfg["max_strain_percent"],
        **elastic_cfg["cij_gpa"],
    }

    bulk_cell = bulk_cfg["cell"]
    # An absent or null elastic cell falls back to the bulk cell. A null value
    # previously reached the geometry step as None and crashed there.
    elastic_cell = elastic_cfg.get("cell") or bulk_cell

    dstrain_bulk_linear = bulk_cfg.get("dstrain_linear", 0.004)
    dstrain_elastic = elastic_cfg.get("dstrain", 0.005)

    # Honor the trainset file name written into the YAML (it was ignored before).
    trainset_filename = str(output_cfg.get("trainset_file") or "trainset_elastic.in")

    generate_all_energy_vs_volume_data(
        out_dir=str(out_dir_p),
        bulk_inputs=bulk_inputs,
        elastic_inputs=elastic_inputs,
        bulk_cell=bulk_cell,
        elastic_volume_cell=elastic_cell,
        bulk_options={"linear_strain_step": dstrain_bulk_linear},
        elastic_options={"strain_step": dstrain_elastic},
        trainset_filename=trainset_filename,
    )

    # -------------------------
    # Optional: strained geometries (xyz + geo)
    # -------------------------
    geo_cfg = cfg.get("geo", {}) or {}
    if not bool(geo_cfg.get("enable", False)):
        return

    def _existing_input(raw, label: str) -> Path:
        """Resolve ``raw`` relative to the YAML directory and require that it exists."""
        path = Path(raw)
        if not path.is_absolute():
            path = (yaml_path_p.parent / path).resolve()
        if not path.exists():
            raise FileNotFoundError(f"{label} does not exist: {path}")
        return path

    def _copy_unless_same(src: Path, dst: Path) -> Path:
        """Copy ``src`` to ``dst`` unless they are already the same file; return ``dst``."""
        if src.resolve() != dst.resolve():
            shutil.copy2(src, dst)
        return dst

    elastic_xyz = _existing_input(cfg["structure 1"]["elastic_xyz"], "structure 1.elastic_xyz")

    s2 = cfg.get("structure 2", {}) or {}
    bulk_xyz_val = s2.get("bulk_xyz")
    bulk_xyz = _existing_input(bulk_xyz_val, "structure 2.bulk_xyz") if bulk_xyz_val else None

    # Decide where geo outputs go.
    geo_out_dir = out_dir_p if place_all_outputs_in_out_dir else yaml_path_p.parent

    # Optionally copy the input xyz files into geo_out_dir so the folder is self-contained.
    if copy_input_xyz_into_out_dir:
        elastic_src = elastic_xyz
        elastic_xyz = _copy_unless_same(elastic_src, geo_out_dir / elastic_src.name)

        if bulk_xyz is not None:
            bulk_dst = geo_out_dir / bulk_xyz.name
            # Two different inputs with the same file name would otherwise
            # overwrite the elastic copy (or the original elastic file).
            if bulk_dst.resolve() == elastic_xyz.resolve() and bulk_xyz.resolve() != elastic_src.resolve():
                bulk_dst = geo_out_dir / f"bulk_{bulk_xyz.name}"
            bulk_xyz = _copy_unless_same(bulk_xyz, bulk_dst)

    # Convert YAML percent strains to linear strain limits.
    max_strain_elastic = elastic_cfg["max_strain_percent"] / 100.0

    # Isotropic linear strain e giving the requested volumetric strain:
    # (1 + e)^3 = 1 + dV/V (Fortran-style).
    max_vol = bulk_cfg["max_volumetric_strain_percent"] / 100.0
    max_strain_bulk_linear = (1.0 + max_vol) ** (1.0 / 3.0) - 1.0

    generate_strained_geometries_with_xtob(
        elastic_xyz=str(elastic_xyz),
        bulk_xyz=None if bulk_xyz is None else str(bulk_xyz),
        elastic_cell=elastic_cell,
        bulk_cell=bulk_cell,
        max_strain_elastic=max_strain_elastic,
        dstrain_elastic=dstrain_elastic,
        max_strain_bulk_linear=max_strain_bulk_linear,
        dstrain_bulk_linear=dstrain_bulk_linear,
        out_dir=str(geo_out_dir),
        sort_by=geo_cfg.get("sort_by"),  # e.g. "z" or None
    )
1576
+
1577
+
1578
+ # =============================================================================
1579
+ # 4. MP API Handler:
1580
+ # Handles the Materials Project API to get mechanical properties, lattice
1581
+ # dimensions and angles, and structure of the system
1582
+ # =============================================================================
1583
+
1584
# Selects which bulk-modulus estimate is read from an MP elasticity document:
# "voigt", "reuss", or "vrh" (Voigt-Reuss-Hill). _pick_bulk_modulus uses the
# chosen string directly as an attribute name on the bulk_modulus object.
BulkModulusMode = Literal["voigt", "reuss", "vrh"]
1585
+
1586
def _tensor6x6_to_cij_dict(t6: List[List[float]]) -> Dict[str, float]:
    """Convert a 6x6 elastic-constant matrix into a ``{"c11": ..., ...}`` dict.

    Only nine components are kept: the three axial diagonals (c11, c22, c33),
    the upper-triangle couplings c12, c13, c23, and the shear diagonals
    (c44, c55, c66). All other entries of the matrix are ignored. Every value
    is converted with ``float``.

    Raises
    ------
    ValueError
        If ``t6`` is None or is not 6 rows of 6 entries each.
    """
    looks_6x6 = t6 is not None and len(t6) == 6 and all(len(row) == 6 for row in t6)
    if not looks_6x6:
        raise ValueError("Elastic tensor must be a 6x6 matrix.")

    # (label, 0-based row, 0-based col). The order here fixes the key order of
    # the returned dict, which is what ends up in the YAML.
    index_table = (
        ("c11", 0, 0), ("c22", 1, 1), ("c33", 2, 2),
        ("c12", 0, 1), ("c13", 0, 2), ("c23", 1, 2),
        ("c44", 3, 3), ("c55", 4, 4), ("c66", 5, 5),
    )
    return {label: float(t6[r][c]) for label, r, c in index_table}
1595
+
1596
+
1597
def _extract_tensor6(elastic_tensor_obj: Any) -> Optional[List[List[float]]]:
    """Pull a 6x6 matrix out of an mp-api elastic-tensor object (kept deliberately small).

    Lookup order:
      1. the ``ieee_format`` attribute, if present and not None;
      2. otherwise the ``raw`` attribute, if present and not None;
      3. otherwise, if the object itself is a list/tuple, a list copy of it.

    Returns None for a None input or when none of the above apply.
    """
    if elastic_tensor_obj is None:
        return None

    for attr_name in ("ieee_format", "raw"):
        found = getattr(elastic_tensor_obj, attr_name, None)
        if found is not None:
            return found

    if not isinstance(elastic_tensor_obj, (list, tuple)):
        return None
    return list(elastic_tensor_obj)  # type: ignore[return-value]
1609
+
1610
+
1611
def _pick_bulk_modulus(bm: Any, mode: BulkModulusMode) -> Optional[float]:
    """Read one bulk-modulus estimate from ``bm`` as a float.

    ``mode`` is used verbatim as an attribute name (``bm.voigt``, ``bm.reuss``
    or ``bm.vrh``). Returns None when ``bm`` is None, when the attribute is
    absent, or when it is None.
    """
    picked = None if bm is None else getattr(bm, mode, None)
    if picked is None:
        return None
    return float(picked)
1616
+
1617
+
1618
def generate_trainset_settings_yaml_from_mp_simple(
    *,
    mp_id: str,
    out_yaml: str | Path,
    structure_dir: Optional[str | Path] = None,
    bulk_mode: BulkModulusMode = "vrh",
    api_key: Optional[str] = None,
    verbose: bool = True,
) -> Dict[str, str]:
    """
    Generate a trainset settings YAML and structures from a Materials Project ID.

    Minimal MP -> (structure + mechanics) -> CIF -> XYZ -> trainset_settings.yaml.

    - Fetches: structure, lattice (a,b,c,alpha,beta,gamma), elastic tensor (6x6), bulk modulus.
    - Writes: <mp_id>.cif and <mp_id>.xyz (any ":" in the ID is replaced by "_").
    - Writes YAML where:
        - elastic_cell == bulk_cell == MP lattice
        - structure 1.elastic_xyz == structure 2.bulk_xyz == generated XYZ,
          stored relative to the YAML folder (absolute if no relative path exists)
        - geo.enable is set true (since geo comes from the XYZ)

    Works on
    --------
    Materials Project API + structure files (CIF/XYZ) + trainset settings YAML

    Parameters
    ----------
    mp_id : str
        Materials Project material ID (e.g., ``"mp-661"``).
    out_yaml : str or pathlib.Path
        Output YAML path to write. Its parent directory is created if needed.
    structure_dir : str or pathlib.Path or None, optional
        Directory to write structure files (CIF/XYZ). If None, uses the YAML folder.
        It may also lie outside the YAML folder.
    bulk_mode : {"voigt","reuss","vrh"}, optional
        Which bulk modulus value to store in YAML.
    api_key : str or None, optional
        Materials Project API key. If None, uses ``MP_API_KEY`` environment variable.
    verbose : bool, optional
        If True, print written paths to stdout.

    Returns
    -------
    dict[str, str]
        Mapping with keys: ``"cif"``, ``"xyz"``, ``"yaml"`` pointing to written file paths.

    Raises
    ------
    RuntimeError
        If no API key is passed and ``MP_API_KEY`` is unset or empty.
    ValueError
        If the Materials Project returns no summary or no elasticity document
        for ``mp_id``, the elastic tensor is missing/unreadable or not 6x6,
        or the requested ``bulk_mode`` value is missing.

    Examples
    --------
    >>> out = generate_trainset_settings_yaml_from_mp_simple(  # doctest: +SKIP
    ...     mp_id="mp-661",
    ...     out_yaml="trainset_elastic.yaml",
    ... )
    >>> sorted(out.keys())  # doctest: +SKIP
    ['cif', 'xyz', 'yaml']
    """
    api_key = api_key or os.getenv("MP_API_KEY")
    if not api_key:
        raise RuntimeError("Set MP_API_KEY env var (or pass api_key=...).")

    out_yaml = Path(out_yaml)
    out_yaml.parent.mkdir(parents=True, exist_ok=True)

    sdir = Path(structure_dir) if structure_dir is not None else out_yaml.parent
    sdir.mkdir(parents=True, exist_ok=True)

    base = mp_id.replace(":", "_")
    cif_path = sdir / f"{base}.cif"
    xyz_path = sdir / f"{base}.xyz"

    # When the config is loaded, relative XYZ paths in the YAML are resolved
    # against the YAML's folder, so store the XYZ path relative to that folder.
    # Path.relative_to() only works for descendants and would raise for e.g.
    # structure_dir="../structures". os.path.relpath handles siblings and
    # parents. It raises ValueError only when no relative path exists (Windows
    # paths on different drives); in that case fall back to the absolute path.
    yaml_dir = out_yaml.resolve().parent
    xyz_abs = xyz_path.resolve()
    try:
        xyz_for_yaml = Path(os.path.relpath(xyz_abs, yaml_dir)).as_posix()
    except ValueError:
        xyz_for_yaml = xyz_abs.as_posix()

    with MPRester(api_key) as mpr:
        # 1) summary: structure + lattice
        sdocs = mpr.materials.summary.search(
            material_ids=[mp_id],
            fields=["material_id", "formula_pretty", "structure"],
        )
        if not sdocs:
            # Without this check an unknown ID surfaces as a bare IndexError.
            raise ValueError(f"No summary data for {mp_id} (unknown material ID?).")
        sdoc = sdocs[0]

        # 2) elasticity: elastic tensor + bulk modulus
        edocs = mpr.materials.elasticity.search(
            material_ids=[mp_id],
            fields=["material_id", "elastic_tensor", "bulk_modulus"],
        )
        if not edocs:
            raise ValueError(f"No elasticity data for {mp_id} (cannot populate elastic/bulk).")
        edoc = edocs[0]

    structure = sdoc.structure
    lat = structure.lattice
    name = getattr(sdoc, "formula_pretty", None) or mp_id

    cell = {
        "a": float(lat.a), "b": float(lat.b), "c": float(lat.c),
        "alpha": float(lat.alpha), "beta": float(lat.beta), "gamma": float(lat.gamma),
    }

    tensor6 = _extract_tensor6(getattr(edoc, "elastic_tensor", None))
    if tensor6 is None:
        raise ValueError(f"{mp_id}: elastic_tensor missing/unreadable.")
    cij = _tensor6x6_to_cij_dict(tensor6)

    B0 = _pick_bulk_modulus(getattr(edoc, "bulk_modulus", None), bulk_mode)
    if B0 is None:
        raise ValueError(f"{mp_id}: bulk_modulus.{bulk_mode} missing/unreadable.")

    # 3) Structure -> CIF (pymatgen writer), then CIF -> XYZ via this package's
    #    read_structure/write_structure helpers.
    structure.to(filename=str(cif_path), fmt="cif")

    atoms = read_structure(cif_path, format="cif")
    write_structure(atoms, xyz_path, format="xyz", comment=mp_id)

    # 4) YAML: bulk settings mirror the elastic ones; both point at the same XYZ.
    write_trainset_settings_yaml(
        out_path=str(out_yaml),
        name=f"{name} ({mp_id})",
        source="materials_project",
        mp_id=mp_id,
        cij_gpa=cij,
        B0_gpa=B0,
        elastic_cell=cell,
        bulk_cell=cell,
        elastic_xyz=xyz_for_yaml,
        bulk_xyz=xyz_for_yaml,
        geo_enable=True,
    )

    if verbose:
        print(f"[MP] CIF: {cif_path}")
        print(f"[MP] XYZ: {xyz_path}")
        print(f"[MP] YAML: {out_yaml}")

    return {"cif": str(cif_path), "xyz": str(xyz_path), "yaml": str(out_yaml)}
1756
+
1757
+
1758
+