rapmat 0.2.2__tar.gz → 0.2.4__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (88)
  1. {rapmat-0.2.2 → rapmat-0.2.4}/PKG-INFO +3 -2
  2. {rapmat-0.2.2 → rapmat-0.2.4}/pyproject.toml +1 -1
  3. {rapmat-0.2.2 → rapmat-0.2.4}/rapmat/calculators/__init__.py +9 -0
  4. {rapmat-0.2.2 → rapmat-0.2.4}/rapmat/calculators/vasp.py +2 -1
  5. {rapmat-0.2.2 → rapmat-0.2.4}/rapmat/core/csp.py +26 -56
  6. {rapmat-0.2.2 → rapmat-0.2.4}/rapmat/core/dedup.py +0 -2
  7. {rapmat-0.2.2 → rapmat-0.2.4}/rapmat/core/dedup_analysis.py +11 -4
  8. {rapmat-0.2.2 → rapmat-0.2.4}/rapmat/core/evaluation.py +9 -18
  9. {rapmat-0.2.2 → rapmat-0.2.4}/rapmat/core/generation_worker.py +6 -4
  10. {rapmat-0.2.2 → rapmat-0.2.4}/rapmat/core/phonon.py +32 -0
  11. {rapmat-0.2.2 → rapmat-0.2.4}/rapmat/core/phonon_stability.py +11 -22
  12. {rapmat-0.2.2 → rapmat-0.2.4}/rapmat/core/relaxation.py +7 -6
  13. {rapmat-0.2.2 → rapmat-0.2.4}/rapmat/storage/base.py +5 -37
  14. rapmat-0.2.4/rapmat/storage/descriptors.py +35 -0
  15. rapmat-0.2.4/rapmat/storage/status.py +18 -0
  16. {rapmat-0.2.2 → rapmat-0.2.4}/rapmat/storage/surrealdb_store.py +117 -133
  17. rapmat-0.2.4/rapmat/tui/screens/base.py +86 -0
  18. {rapmat-0.2.2 → rapmat-0.2.4}/rapmat/tui/screens/base_results.py +80 -27
  19. {rapmat-0.2.2 → rapmat-0.2.4}/rapmat/tui/screens/csp_resume.py +10 -27
  20. {rapmat-0.2.2 → rapmat-0.2.4}/rapmat/tui/screens/csp_search.py +10 -27
  21. {rapmat-0.2.2 → rapmat-0.2.4}/rapmat/tui/screens/db_settings.py +10 -20
  22. {rapmat-0.2.2 → rapmat-0.2.4}/rapmat/tui/screens/dedup.py +19 -40
  23. {rapmat-0.2.2 → rapmat-0.2.4}/rapmat/tui/screens/eval.py +17 -45
  24. {rapmat-0.2.2 → rapmat-0.2.4}/rapmat/tui/screens/home.py +10 -17
  25. {rapmat-0.2.2 → rapmat-0.2.4}/rapmat/tui/screens/hull.py +16 -117
  26. {rapmat-0.2.2 → rapmat-0.2.4}/rapmat/tui/screens/phonon.py +10 -27
  27. {rapmat-0.2.2 → rapmat-0.2.4}/rapmat/tui/screens/results.py +4 -2
  28. {rapmat-0.2.2 → rapmat-0.2.4}/rapmat/tui/screens/status.py +8 -18
  29. {rapmat-0.2.2 → rapmat-0.2.4}/rapmat/tui/screens/study_create.py +9 -22
  30. {rapmat-0.2.2 → rapmat-0.2.4}/rapmat/tui/screens/study_detail.py +16 -23
  31. {rapmat-0.2.2 → rapmat-0.2.4}/rapmat/tui/screens/study_list.py +43 -51
  32. {rapmat-0.2.2 → rapmat-0.2.4}/rapmat/tui/state.py +6 -2
  33. {rapmat-0.2.2 → rapmat-0.2.4}/rapmat/tui/widgets/dialog.py +23 -36
  34. rapmat-0.2.4/rapmat/tui/widgets/search.py +45 -0
  35. {rapmat-0.2.2 → rapmat-0.2.4}/rapmat/utils/common.py +10 -0
  36. rapmat-0.2.4/rapmat/utils/progress.py +11 -0
  37. {rapmat-0.2.2 → rapmat-0.2.4}/tests/conftest.py +21 -54
  38. {rapmat-0.2.2 → rapmat-0.2.4}/tests/test_db_config_and_locking.py +23 -23
  39. {rapmat-0.2.2 → rapmat-0.2.4}/tests/test_dedup_analysis.py +9 -23
  40. {rapmat-0.2.2 → rapmat-0.2.4}/tests/test_hull.py +35 -32
  41. {rapmat-0.2.2 → rapmat-0.2.4}/tests/test_processing_loop.py +6 -21
  42. {rapmat-0.2.2 → rapmat-0.2.4}/tests/test_render.py +3 -3
  43. rapmat-0.2.4/tests/test_screens_render.py +271 -0
  44. rapmat-0.2.4/tests/test_storage.py +426 -0
  45. {rapmat-0.2.2 → rapmat-0.2.4}/tests/test_tui.py +154 -1
  46. {rapmat-0.2.2 → rapmat-0.2.4}/tests/test_tui_layout.py +3 -3
  47. {rapmat-0.2.2 → rapmat-0.2.4}/tests/test_ui_toggle.py +2 -7
  48. rapmat-0.2.2/rapmat/storage/descriptors.py +0 -65
  49. rapmat-0.2.2/tests/test_storage.py +0 -265
  50. {rapmat-0.2.2 → rapmat-0.2.4}/.github/workflows/python-publish.yml +0 -0
  51. {rapmat-0.2.2 → rapmat-0.2.4}/.gitignore +0 -0
  52. {rapmat-0.2.2 → rapmat-0.2.4}/LICENSE +0 -0
  53. {rapmat-0.2.2 → rapmat-0.2.4}/README.md +0 -0
  54. {rapmat-0.2.2 → rapmat-0.2.4}/rapmat/__init__.py +0 -0
  55. {rapmat-0.2.2 → rapmat-0.2.4}/rapmat/__main__.py +0 -0
  56. {rapmat-0.2.2 → rapmat-0.2.4}/rapmat/calculators/factory.py +0 -0
  57. {rapmat-0.2.2 → rapmat-0.2.4}/rapmat/calculators/mattersim.py +0 -0
  58. {rapmat-0.2.2 → rapmat-0.2.4}/rapmat/calculators/nequip.py +0 -0
  59. {rapmat-0.2.2 → rapmat-0.2.4}/rapmat/calculators/upet.py +0 -0
  60. {rapmat-0.2.2 → rapmat-0.2.4}/rapmat/config.py +0 -0
  61. {rapmat-0.2.2 → rapmat-0.2.4}/rapmat/core/__init__.py +0 -0
  62. {rapmat-0.2.2 → rapmat-0.2.4}/rapmat/core/hull.py +0 -0
  63. {rapmat-0.2.2 → rapmat-0.2.4}/rapmat/core/sanity.py +0 -0
  64. {rapmat-0.2.2 → rapmat-0.2.4}/rapmat/db_config.py +0 -0
  65. {rapmat-0.2.2 → rapmat-0.2.4}/rapmat/main.py +0 -0
  66. {rapmat-0.2.2 → rapmat-0.2.4}/rapmat/storage/__init__.py +0 -0
  67. {rapmat-0.2.2 → rapmat-0.2.4}/rapmat/tui/__init__.py +0 -0
  68. {rapmat-0.2.2 → rapmat-0.2.4}/rapmat/tui/app.py +0 -0
  69. {rapmat-0.2.2 → rapmat-0.2.4}/rapmat/tui/router.py +0 -0
  70. {rapmat-0.2.2 → rapmat-0.2.4}/rapmat/tui/screens/__init__.py +0 -0
  71. {rapmat-0.2.2 → rapmat-0.2.4}/rapmat/tui/tasks.py +0 -0
  72. {rapmat-0.2.2 → rapmat-0.2.4}/rapmat/tui/widgets/__init__.py +0 -0
  73. {rapmat-0.2.2 → rapmat-0.2.4}/rapmat/tui/widgets/calc_fields.py +0 -0
  74. {rapmat-0.2.2 → rapmat-0.2.4}/rapmat/tui/widgets/config_grid.py +0 -0
  75. {rapmat-0.2.2 → rapmat-0.2.4}/rapmat/tui/widgets/dropdown.py +0 -0
  76. {rapmat-0.2.2 → rapmat-0.2.4}/rapmat/tui/widgets/form.py +0 -0
  77. {rapmat-0.2.2 → rapmat-0.2.4}/rapmat/tui/widgets/progress.py +0 -0
  78. {rapmat-0.2.2 → rapmat-0.2.4}/rapmat/tui/widgets/status_bar.py +0 -0
  79. {rapmat-0.2.2 → rapmat-0.2.4}/rapmat/tui/widgets/table.py +0 -0
  80. {rapmat-0.2.2 → rapmat-0.2.4}/rapmat/utils/__init__.py +0 -0
  81. {rapmat-0.2.2 → rapmat-0.2.4}/rapmat/utils/console.py +0 -0
  82. {rapmat-0.2.2 → rapmat-0.2.4}/rapmat/utils/structure.py +0 -0
  83. {rapmat-0.2.2 → rapmat-0.2.4}/tests/test_checkbox.py +0 -0
  84. {rapmat-0.2.2 → rapmat-0.2.4}/tests/test_evaluation.py +0 -0
  85. {rapmat-0.2.2 → rapmat-0.2.4}/tests/test_relaxation.py +0 -0
  86. {rapmat-0.2.2 → rapmat-0.2.4}/tests/test_sanity.py +0 -0
  87. {rapmat-0.2.2 → rapmat-0.2.4}/tests/test_utils.py +0 -0
  88. {rapmat-0.2.2 → rapmat-0.2.4}/tests/test_vasp.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: rapmat
3
- Version: 0.2.2
3
+ Version: 0.2.4
4
4
  Summary: Rapmat - rapid materials discovery using MLIPs and random search
5
5
  Project-URL: Homepage, https://github.com/milevevvvv/rapmat
6
6
  Author-email: Michael Levenets <milevev256@gmail.com>
@@ -18,7 +18,6 @@ Requires-Dist: phonopy>=2.47.1
18
18
  Requires-Dist: platformdirs>=4.9.4
19
19
  Requires-Dist: pydantic>=2.12.5
20
20
  Requires-Dist: pymatgen>=2025.10.7
21
- Requires-Dist: pytest>=8.0.0
22
21
  Requires-Dist: pyxtal>=1.1.2
23
22
  Requires-Dist: seekpath>=2.2.1
24
23
  Requires-Dist: spglib>=2.7.0
@@ -32,6 +31,8 @@ Provides-Extra: all-calculators
32
31
  Requires-Dist: mattersim>=1.2.0; extra == 'all-calculators'
33
32
  Requires-Dist: nequip>=0.16.2; extra == 'all-calculators'
34
33
  Requires-Dist: upet>=0.2.1; extra == 'all-calculators'
34
+ Provides-Extra: dev
35
+ Requires-Dist: pytest>=8.0.0; extra == 'dev'
35
36
  Provides-Extra: mattersim
36
37
  Requires-Dist: mattersim>=1.2.0; extra == 'mattersim'
37
38
  Provides-Extra: nequip
@@ -21,7 +21,6 @@ dependencies = [
21
21
  "dscribe==2.1.2",
22
22
  "surrealdb>=2.0.0",
23
23
  "matplotlib>=3.8.0",
24
- "pytest>=8.0.0",
25
24
  "platformdirs>=4.9.4",
26
25
  "filelock>=3.13.0",
27
26
  "tomli-w>=1.2.0",
@@ -39,6 +38,7 @@ mattersim = ["mattersim>=1.2.0"]
39
38
  nequip = ["nequip>=0.16.2"]
40
39
  upet = ["upet>=0.2.1"]
41
40
  all-calculators = ["mattersim>=1.2.0", "nequip>=0.16.2", "upet>=0.2.1"]
41
+ dev = ["pytest>=8.0.0"]
42
42
 
43
43
  [build-system]
44
44
  requires = ["hatchling", "hatch-vcs"]
@@ -63,6 +63,15 @@ class CalculatorCallback(Protocol):
63
63
  def on_status(self, message: str) -> None: ...
64
64
 
65
65
 
66
+ class ProgressCalcCallback:
67
+ def __init__(self, progress_callback) -> None:
68
+ self._progress_callback = progress_callback
69
+
70
+ def on_status(self, message: str) -> None:
71
+ if self._progress_callback:
72
+ self._progress_callback(0, 0, message)
73
+
74
+
66
75
  def _notify(callback: CalculatorCallback | None, message: str) -> None:
67
76
  if callback is not None:
68
77
  callback.on_status(message)
@@ -5,7 +5,8 @@ from ase.calculators.vasp import Vasp
5
5
 
6
6
  def build_calculator_vasp(config: dict, directory: Path | None = None) -> Vasp:
7
7
  kwargs = dict(config)
8
- if directory is not None:
8
+
9
+ if directory is not None and "directory" not in kwargs:
9
10
  kwargs["directory"] = str(directory)
10
11
 
11
12
  if "txt" not in kwargs:
@@ -1,7 +1,8 @@
1
1
  import traceback
2
- from enum import Enum
3
2
  from pathlib import Path
4
3
 
4
+ from rapmat.storage.status import StructureStatus
5
+ from rapmat.utils.progress import ProgressCallback
5
6
 
6
7
  # ------------------------------------------------------------------ #
7
8
  # Orchestration loops (used by TUI and tests)
@@ -18,34 +19,26 @@ def run_processing_loop(
18
19
  config: dict,
19
20
  workdir_path: Path,
20
21
  worker_id: str | None = None,
21
- progress_callback=None,
22
+ progress_callback: ProgressCallback | None = None,
22
23
  cancel_flag: list[bool] | None = None,
23
24
  ):
24
25
  import numpy as np
25
- import torch
26
26
  from ase.units import GPa as _GPa
27
27
 
28
- from rapmat.calculators import CalculatorCallback, Calculators
28
+ from rapmat.calculators import Calculators, ProgressCalcCallback
29
29
  from rapmat.calculators.factory import load_calculator
30
30
  from rapmat.core.relaxation import structure_relax
31
31
  from rapmat.core.sanity import check_sanity
32
+ from rapmat.utils.common import free_cuda_memory
32
33
  from rapmat.utils.console import get_logger
33
34
  logger = get_logger("rapmat.csp")
34
- from rapmat.utils.structure import (calculate_thickness, format_spg,
35
- standardize_atoms)
35
+ from rapmat.utils.structure import format_spg
36
36
 
37
- class _ProgressCalcCallback:
38
-
39
- def on_status(self, message: str) -> None:
40
- if progress_callback:
41
- progress_callback(0, 0, message)
42
-
43
- _calc_cb = _ProgressCalcCallback()
37
+ _calc_cb = ProgressCalcCallback(progress_callback)
44
38
 
45
39
  calculator_name = config.get("calculator", "MATTERSIM").upper()
46
40
  calculator_config = config.get("calculator_config", {})
47
41
  domain_val = config.get("domain", "bulk")
48
- search_dim = 3 if domain_val == "bulk" else 2
49
42
  symprec = config.get("symprec", 1e-5)
50
43
  pressure_gpa = config.get("pressure_gpa", 0.0)
51
44
  pressure_evA3 = pressure_gpa * _GPa
@@ -64,11 +57,7 @@ def run_processing_loop(
64
57
 
65
58
  relaxed_structures = []
66
59
 
67
- try:
68
- if torch.cuda.is_available():
69
- torch.cuda.empty_cache()
70
- except Exception:
71
- pass
60
+ free_cuda_memory()
72
61
 
73
62
  _calc_cb.on_status(f"Loading calculator {calculator_name}...")
74
63
  calculator = load_calculator(
@@ -80,9 +69,6 @@ def run_processing_loop(
80
69
 
81
70
  counter: int = 0
82
71
  discarded_sanity = 0
83
- discarded_unstable = 0
84
- discarded_dup = 0
85
- discarded_candidate_dup = 0
86
72
  n_candidates = len(candidates)
87
73
 
88
74
  def _report(msg: str) -> None:
@@ -91,7 +77,6 @@ def run_processing_loop(
91
77
 
92
78
  def _run_loop():
93
79
  nonlocal counter, discarded_sanity
94
- nonlocal discarded_unstable, discarded_dup, discarded_candidate_dup
95
80
  nonlocal calculator
96
81
 
97
82
  for candidate in candidates:
@@ -121,7 +106,9 @@ def run_processing_loop(
121
106
  def _optim_cb(step: int, max_steps: int, msg: str) -> None:
122
107
  if progress_callback:
123
108
  msg_fmt = f"Relaxing {struct_id}: {msg}"
124
- progress_callback(counter, n_candidates, msg_fmt, False)
109
+ progress_callback(
110
+ counter, n_candidates, msg_fmt, is_log=False
111
+ )
125
112
 
126
113
  converged, relaxed_structure = structure_relax(
127
114
  structure,
@@ -161,17 +148,11 @@ def run_processing_loop(
161
148
 
162
149
  _report("Metadata preparation...")
163
150
  energy = relaxed_structure.info["energy"]
164
- volume = relaxed_structure.get_volume()
165
- enthalpy = energy + pressure_evA3 * volume
166
151
 
167
152
  meta = {
168
153
  "energy_per_atom": energy / len(relaxed_structure),
169
- "energy_total": energy,
170
- "enthalpy_per_atom": enthalpy / len(relaxed_structure),
171
- "volume": volume,
172
154
  "fmax": relaxed_structure.info["fmax"],
173
155
  "converged": relaxed_structure.info["converged"],
174
- "thickness": 0.0,
175
156
  }
176
157
 
177
158
  _report("Checking sanity...")
@@ -185,22 +166,16 @@ def run_processing_loop(
185
166
  _report(f"Discarded {struct_id}: failed sanity check")
186
167
  store.update_structure(
187
168
  struct_id,
188
- status="discarded",
169
+ status=StructureStatus.DISCARDED,
189
170
  atoms=relaxed_structure,
190
171
  metadata=meta,
191
172
  )
192
173
  break
193
174
 
194
- if search_dim == 2:
195
- _report("Calculating thickness...")
196
- current_thickness = calculate_thickness(relaxed_structure)
197
- relaxed_structure.info["thickness"] = current_thickness
198
- meta["thickness"] = current_thickness
199
-
200
175
  _report("Saving to database...")
201
176
  store.update_structure(
202
177
  struct_id,
203
- status="relaxed",
178
+ status=StructureStatus.RELAXED,
204
179
  atoms=relaxed_structure,
205
180
  metadata=meta,
206
181
  )
@@ -216,15 +191,16 @@ def run_processing_loop(
216
191
  "Failed to relax structure %s: %s",
217
192
  struct_id, ex, exc_info=True,
218
193
  )
219
- store.update_structure(struct_id, status="error")
194
+ store.update_structure(
195
+ struct_id, status=StructureStatus.ERROR
196
+ )
220
197
  break
221
198
 
222
199
  try:
223
200
  del calculator
224
- if torch.cuda.is_available():
225
- torch.cuda.empty_cache()
226
201
  except Exception:
227
202
  pass
203
+ free_cuda_memory()
228
204
 
229
205
  _report(
230
206
  f"Reloading calculator {calculator_name} after error (attempt {attempt + 1})..."
@@ -243,7 +219,9 @@ def run_processing_loop(
243
219
  )
244
220
  _report(f"CRITICAL ERROR: Reload failed: {reload_ex}")
245
221
  _report(f"Reload traceback:\n{traceback.format_exc()}")
246
- store.update_structure(struct_id, status="error")
222
+ store.update_structure(
223
+ struct_id, status=StructureStatus.ERROR
224
+ )
247
225
  break
248
226
 
249
227
  if progress_callback:
@@ -254,18 +232,10 @@ def run_processing_loop(
254
232
  _run_loop()
255
233
 
256
234
  n_relaxed = len(relaxed_structures)
257
- discarded_parts = [
258
- f"{discarded_candidate_dup} cand-dup",
259
- f"{discarded_sanity} sanity",
260
- f"{discarded_unstable} unstable",
261
- f"{discarded_dup} dup",
262
- ]
263
- discarded_str = ", ".join(discarded_parts)
264
-
265
235
  pressure_msg = f" | Pressure: {pressure_gpa} GPa" if pressure_gpa > 0 else ""
266
236
  logger.info(
267
- "Done. Run: %s | Storage: %s%s | Relaxed: %d | Discarded: %s",
268
- run_name, store._db_url, pressure_msg, n_relaxed, discarded_str,
237
+ "Done. Run: %s | Storage: %s%s | Relaxed: %d | Discarded (sanity): %d",
238
+ run_name, store.get_url(), pressure_msg, n_relaxed, discarded_sanity,
269
239
  )
270
240
 
271
241
  return None
@@ -277,7 +247,7 @@ def run_generation_loop(
277
247
  config: dict,
278
248
  worker_id: str | None = None,
279
249
  workers: int = 1,
280
- progress_callback=None,
250
+ progress_callback: ProgressCallback | None = None,
281
251
  cancel_flag: list[bool] | None = None,
282
252
  log_callback=None,
283
253
  ) -> int:
@@ -327,13 +297,13 @@ def run_generation_loop(
327
297
  def _handle_result(status, struct_id, atoms, spg, fu):
328
298
  nonlocal generated, discarded, errors
329
299
  match status:
330
- case "generated":
300
+ case StructureStatus.GENERATED:
331
301
  store.update_generated_structure(struct_id, atoms)
332
302
  generated += 1
333
- case "discarded":
303
+ case StructureStatus.DISCARDED:
334
304
  store.discard_generation_placeholder(struct_id)
335
305
  discarded += 1
336
- case "error":
306
+ case StructureStatus.ERROR:
337
307
  logger.error(
338
308
  "Structure for group %s / fu %s failed", spg, fu,
339
309
  )
@@ -1,9 +1,7 @@
1
1
  import warnings
2
- from typing import Optional
3
2
 
4
3
  import numpy as np
5
4
  from ase import Atoms
6
- from pymatgen.analysis.structure_matcher import StructureMatcher
7
5
  from pymatgen.io.ase import AseAtomsAdaptor
8
6
 
9
7
 
@@ -11,6 +11,8 @@ except ImportError:
11
11
  StructureMatcher = None # type: ignore[assignment,misc]
12
12
 
13
13
  from rapmat.core.dedup import _to_pymatgen, forces_cosine_similarity
14
+ from rapmat.utils.console import get_logger
15
+ from rapmat.utils.progress import ProgressCallback
14
16
 
15
17
 
16
18
  @dataclass
@@ -47,7 +49,7 @@ def simulate_deduplication(
47
49
  angle_tol: float = 5.0,
48
50
  use_forces: bool = False,
49
51
  force_cosine_threshold: float = 0.95,
50
- progress_callback=None,
52
+ progress_callback: ProgressCallback | None = None,
51
53
  ) -> DedupSimulationResult:
52
54
  result = DedupSimulationResult(total=len(structures))
53
55
 
@@ -72,7 +74,7 @@ def simulate_deduplication(
72
74
 
73
75
  for i in range(N):
74
76
  if progress_callback:
75
- progress_callback(i, N, False)
77
+ progress_callback(i, N, f"Dedup: {i}/{N}", is_log=False)
76
78
 
77
79
  if i in dropped:
78
80
  continue
@@ -98,7 +100,12 @@ def simulate_deduplication(
98
100
  result.rescued_by_pymatgen += 1
99
101
  result.dropped_by_vector -= 1
100
102
  confirmed = False
101
- except Exception:
103
+ except Exception as exc:
104
+ get_logger("rapmat.dedup").warning(
105
+ "pymatgen comparison failed for pair (%s, %s), "
106
+ "treating as non-duplicate: %s",
107
+ with_vec[i]["id"], with_vec[j]["id"], exc,
108
+ )
102
109
  result.pymatgen_mismatches += 1
103
110
  result.rescued_by_pymatgen += 1
104
111
  result.dropped_by_vector -= 1
@@ -124,7 +131,7 @@ def simulate_deduplication(
124
131
  dropped.add(j)
125
132
 
126
133
  if progress_callback:
127
- progress_callback(N, N, False)
134
+ progress_callback(N, N, f"Dedup: {N}/{N}", is_log=False)
128
135
 
129
136
  result.final_dropped = len(dropped)
130
137
  result.kept = N - len(dropped)
@@ -1,6 +1,6 @@
1
1
  from typing import Sequence
2
2
 
3
- from rapmat.utils.structure import standardize_atoms
3
+ from rapmat.utils.progress import ProgressCallback
4
4
 
5
5
  # ------------------------------------------------------------------ #
6
6
  # Evaluation loop (used by TUI and tests)
@@ -19,14 +19,13 @@ def run_eval_loop(
19
19
  phonon_displacement: float = 1e-2,
20
20
  phonon_supercell: tuple = (3, 3, 3),
21
21
  phonon_mesh: tuple = (20, 20, 20),
22
- progress_callback=None,
22
+ progress_callback: ProgressCallback | None = None,
23
23
  log_callback=None,
24
24
  reduce_to_primitive: bool = True,
25
25
  symprec: float = 1e-3,
26
26
  ) -> None:
27
27
  from rapmat.calculators import cleanup_calculator_files
28
- from rapmat.core.phonon import (get_mesh_min_frequency,
29
- structure_calculate_phonons)
28
+ from rapmat.core.phonon import calculate_min_phonon_freq
30
29
  from rapmat.utils.console import get_logger
31
30
  logger = get_logger("rapmat.evaluation")
32
31
 
@@ -45,26 +44,18 @@ def run_eval_loop(
45
44
 
46
45
  ref_phonon_freq = None
47
46
  if run_phonons:
48
- if reduce_to_primitive:
49
- atoms_len_before = len(atoms)
50
- atoms = standardize_atoms(atoms, to_primitive=True, symprec=symprec)
51
- atoms_len_after = len(atoms)
52
-
53
- atoms.calc = calculator
54
-
55
- if log_callback:
56
- log_callback(
57
- f"Reducing {rec['id']}: {atoms_len_before} -> {atoms_len_after} atoms"
58
- )
59
-
60
- phonons = structure_calculate_phonons(
47
+ ref_phonon_freq = calculate_min_phonon_freq(
61
48
  atoms,
49
+ calculator=calculator,
62
50
  displacement=phonon_displacement,
63
51
  supercell=phonon_supercell,
64
52
  qpoint_mesh=phonon_mesh,
53
+ reduce_primitive=reduce_to_primitive,
54
+ symprec=symprec,
65
55
  progress_callback=progress_callback,
56
+ log_label=rec["id"],
57
+ log_callback=log_callback,
66
58
  )
67
- ref_phonon_freq = get_mesh_min_frequency(phonons)
68
59
 
69
60
  store.add_evaluation(
70
61
  structure_id=rec["id"],
@@ -1,3 +1,5 @@
1
+ from rapmat.storage.status import StructureStatus
2
+
1
3
 
2
4
  def generate_one_structure(
3
5
  struct_id: str,
@@ -27,9 +29,9 @@ def generate_one_structure(
27
29
  if crystal.valid:
28
30
  atoms = crystal.to_ase()
29
31
 
30
- return ("generated", struct_id, atoms)
31
- return ("discarded", struct_id, None)
32
+ return (StructureStatus.GENERATED, struct_id, atoms)
33
+ return (StructureStatus.DISCARDED, struct_id, None)
32
34
  except pyxtal.msg.Comp_CompatibilityError:
33
- return ("discarded", struct_id, None)
35
+ return (StructureStatus.DISCARDED, struct_id, None)
34
36
  except RuntimeError:
35
- return ("error", struct_id, None)
37
+ return (StructureStatus.ERROR, struct_id, None)
@@ -70,6 +70,38 @@ def structure_calculate_phonons(
70
70
  return phonons
71
71
 
72
72
 
73
+ def calculate_min_phonon_freq(
74
+ atoms: Atoms,
75
+ *,
76
+ calculator,
77
+ displacement: float,
78
+ supercell: Tuple[int, int, int],
79
+ qpoint_mesh: Tuple[int, int, int],
80
+ reduce_primitive: bool = True,
81
+ symprec: float = 1e-3,
82
+ progress_callback=None,
83
+ log_label: str | None = None,
84
+ log_callback=None,
85
+ ) -> float:
86
+ if reduce_primitive:
87
+ from rapmat.utils.structure import standardize_atoms
88
+
89
+ n_before = len(atoms)
90
+ atoms = standardize_atoms(atoms, to_primitive=True, symprec=symprec)
91
+ if log_callback and log_label:
92
+ log_callback(f"Reducing {log_label}: {n_before} -> {len(atoms)} atoms")
93
+
94
+ atoms.calc = calculator
95
+ phonons = structure_calculate_phonons(
96
+ atoms,
97
+ displacement=displacement,
98
+ supercell=supercell,
99
+ qpoint_mesh=qpoint_mesh,
100
+ progress_callback=progress_callback,
101
+ )
102
+ return get_mesh_min_frequency(phonons)
103
+
104
+
73
105
  def get_mesh_min_frequency(phonons: Phonopy) -> float:
74
106
  return float(np.min(phonons.get_mesh_dict()["frequencies"]))
75
107
 
@@ -1,15 +1,14 @@
1
- from typing import Callable, List, Optional, Tuple
1
+ from typing import List, Optional, Tuple
2
2
 
3
3
  from ase import Atoms
4
4
 
5
- from rapmat.calculators import CalculatorCallback, Calculators
5
+ from rapmat.calculators import Calculators, ProgressCalcCallback
6
6
  from rapmat.calculators.factory import load_calculator
7
- from rapmat.core.phonon import (get_mesh_min_frequency,
8
- structure_calculate_phonons,
9
- structure_has_imag_phonon_freq)
7
+ from rapmat.core.phonon import calculate_min_phonon_freq
10
8
  from rapmat.storage.base import StructureStore
11
9
  from rapmat.utils.common import workdir_context
12
10
  from rapmat.utils.console import get_logger
11
+ from rapmat.utils.progress import ProgressCallback
13
12
 
14
13
  _logger = get_logger("rapmat.phonon_stability")
15
14
 
@@ -25,7 +24,7 @@ def compute_dynamical_stability_for_results(
25
24
  phonon_calculator: Calculators,
26
25
  store: Optional["StructureStore"] = None,
27
26
  calculator_config: dict | None = None,
28
- progress_callback: Callable[[int, int, str], None] | None = None,
27
+ progress_callback: ProgressCallback | None = None,
29
28
  symprec: float = 1e-3,
30
29
  reduce_primitive: bool = True,
31
30
  ) -> bool:
@@ -44,17 +43,12 @@ def compute_dynamical_stability_for_results(
44
43
  if not target_results:
45
44
  return False
46
45
 
47
- class _ProgressCalcCallback:
48
- def on_status(self, message: str) -> None:
49
- if progress_callback:
50
- progress_callback(0, 0, message)
51
-
52
46
  with workdir_context(None) as wdir:
53
47
  calculator = load_calculator(
54
48
  phonon_calculator,
55
49
  wdir,
56
50
  config=calculator_config,
57
- callback=_ProgressCalcCallback(),
51
+ callback=ProgressCalcCallback(progress_callback),
58
52
  )
59
53
  updated = False
60
54
  total = len(target_results)
@@ -73,25 +67,20 @@ def compute_dynamical_stability_for_results(
73
67
  return
74
68
 
75
69
  atoms = structures[structure_index]
76
- if reduce_primitive:
77
- from rapmat.utils.structure import standardize_atoms
78
-
79
- atoms = standardize_atoms(atoms, to_primitive=True, symprec=symprec)
80
- atoms.calc = calculator
81
70
 
82
71
  try:
83
- phonons = structure_calculate_phonons(
72
+ min_freq = calculate_min_phonon_freq(
84
73
  atoms,
74
+ calculator=calculator,
85
75
  displacement=phonon_displacement,
86
76
  supercell=phonon_supercell,
87
77
  qpoint_mesh=phonon_mesh,
78
+ reduce_primitive=reduce_primitive,
79
+ symprec=symprec,
88
80
  progress_callback=progress_callback,
89
81
  )
90
- min_freq = get_mesh_min_frequency(phonons)
91
82
  result["min_phonon_freq"] = min_freq
92
- result["dynamical_stability"] = not structure_has_imag_phonon_freq(
93
- phonons, threshold=phonon_cutoff
94
- )
83
+ result["dynamical_stability"] = not (min_freq < phonon_cutoff)
95
84
  if store is not None and result.get("structure_id"):
96
85
  store.update_structure_phonon(result["structure_id"], min_freq)
97
86
  updated = True
@@ -3,9 +3,12 @@ from typing import Optional, Tuple, Type
3
3
 
4
4
  import numpy as np
5
5
  from ase import Atoms
6
- from ase.filters import Filter, FrechetCellFilter
6
+ from ase.filters import FrechetCellFilter
7
7
  from ase.optimize import BFGS
8
8
 
9
+ from rapmat.utils.common import free_cuda_memory
10
+ from rapmat.utils.progress import ProgressCallback
11
+
9
12
 
10
13
  def _max_force(atoms) -> float:
11
14
  return float(np.max(np.linalg.norm(atoms.get_forces(), axis=1)))
@@ -23,10 +26,8 @@ def structure_relax(
23
26
  suppress_warnings: bool = True,
24
27
  scalar_pressure: float = 0.0,
25
28
  cancel_flag: list[bool] | None = None,
26
- progress_callback=None,
29
+ progress_callback: ProgressCallback | None = None,
27
30
  ) -> Tuple[bool, Atoms]:
28
- import torch
29
-
30
31
  if atoms.calc is None:
31
32
  raise RuntimeError("No calculator set for the structure.")
32
33
 
@@ -67,7 +68,7 @@ def structure_relax(
67
68
 
68
69
  converged = not force_broken and _max_force(atoms_cf) <= force_conv_crit
69
70
 
70
- if cleanup_gpu and torch.cuda.is_available():
71
- torch.cuda.empty_cache()
71
+ if cleanup_gpu:
72
+ free_cuda_memory()
72
73
 
73
74
  return converged, atoms_cf.atoms
@@ -1,29 +1,9 @@
1
- from pymatgen.util import string
2
1
  from abc import ABC, abstractmethod
3
- from pathlib import Path
4
- from typing import List, Optional, Tuple
2
+ from typing import Callable, List, Optional, Tuple
5
3
 
6
- import numpy as np
7
4
  from ase import Atoms
8
5
 
9
- # ------------------------------------------------------------------ #
10
- # Descriptor ABC
11
- # ------------------------------------------------------------------ #
12
-
13
-
14
- class StructureDescriptor(ABC):
15
- @abstractmethod
16
- def dimension(self) -> int: ...
17
-
18
- @abstractmethod
19
- def compute(self, atoms: Atoms) -> np.ndarray: ...
20
-
21
- @abstractmethod
22
- def code_version(self) -> str: ...
23
-
24
- @abstractmethod
25
- def descriptor_id(self) -> str: ...
26
-
6
+ from rapmat.storage.status import StructureStatus
27
7
 
28
8
  # ------------------------------------------------------------------ #
29
9
  # Store ABC
@@ -31,13 +11,6 @@ class StructureDescriptor(ABC):
31
11
 
32
12
 
33
13
  class StructureStore(ABC):
34
- @abstractmethod
35
- def register_descriptor(
36
- self,
37
- desc: StructureDescriptor,
38
- ) -> str:
39
- ...
40
-
41
14
  @abstractmethod
42
15
  def create_run(
43
16
  self,
@@ -66,7 +39,7 @@ class StructureStore(ABC):
66
39
  def claim_run(self, run_name: str, worker_id: str) -> bool: ...
67
40
 
68
41
  @abstractmethod
69
- def release_run(self, run_name: str, final_status: str = "completed") -> None: ...
42
+ def release_run(self, run_name: str, final_status: str) -> None: ...
70
43
 
71
44
  @abstractmethod
72
45
  def update_heartbeat(self, run_name: str, worker_id: str) -> None: ...
@@ -83,7 +56,6 @@ class StructureStore(ABC):
83
56
  struct_id: str,
84
57
  status: str,
85
58
  atoms: Optional[Atoms] = None,
86
- vector: Optional[np.ndarray] = None,
87
59
  metadata: Optional[dict] = None,
88
60
  ) -> None: ...
89
61
 
@@ -108,13 +80,14 @@ class StructureStore(ABC):
108
80
  status: Optional[str] = None,
109
81
  statuses: Optional[tuple[str, ...]] = None,
110
82
  symprec: float = 1e-3,
83
+ progress_callback: Optional[Callable[..., None]] = None,
111
84
  ) -> List[dict]: ...
112
85
 
113
86
  @abstractmethod
114
87
  def get_structures_for_analysis(
115
88
  self,
116
89
  run_id: str,
117
- statuses: tuple = ("relaxed",),
90
+ statuses: tuple = (StructureStatus.RELAXED,),
118
91
  ) -> List[dict]: ...
119
92
 
120
93
  @abstractmethod
@@ -193,7 +166,6 @@ class StructureStore(ABC):
193
166
  self,
194
167
  struct_id: str,
195
168
  atoms: Atoms,
196
- vector: Optional[np.ndarray] = None,
197
169
  ) -> None: ...
198
170
 
199
171
  @abstractmethod
@@ -201,7 +173,3 @@ class StructureStore(ABC):
201
173
 
202
174
  @abstractmethod
203
175
  def get_url(self) -> Optional[str]: ...
204
-
205
- @classmethod
206
- def from_path(cls, db_path: Path, **kwargs) -> "StructureStore":
207
- raise NotImplementedError