reaxkit 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (130) hide show
  1. reaxkit/__init__.py +0 -0
  2. reaxkit/analysis/__init__.py +0 -0
  3. reaxkit/analysis/composed/RDF_analyzer.py +560 -0
  4. reaxkit/analysis/composed/__init__.py +0 -0
  5. reaxkit/analysis/composed/connectivity_analyzer.py +706 -0
  6. reaxkit/analysis/composed/coordination_analyzer.py +144 -0
  7. reaxkit/analysis/composed/electrostatics_analyzer.py +687 -0
  8. reaxkit/analysis/per_file/__init__.py +0 -0
  9. reaxkit/analysis/per_file/control_analyzer.py +165 -0
  10. reaxkit/analysis/per_file/eregime_analyzer.py +108 -0
  11. reaxkit/analysis/per_file/ffield_analyzer.py +305 -0
  12. reaxkit/analysis/per_file/fort13_analyzer.py +79 -0
  13. reaxkit/analysis/per_file/fort57_analyzer.py +106 -0
  14. reaxkit/analysis/per_file/fort73_analyzer.py +61 -0
  15. reaxkit/analysis/per_file/fort74_analyzer.py +65 -0
  16. reaxkit/analysis/per_file/fort76_analyzer.py +191 -0
  17. reaxkit/analysis/per_file/fort78_analyzer.py +154 -0
  18. reaxkit/analysis/per_file/fort79_analyzer.py +83 -0
  19. reaxkit/analysis/per_file/fort7_analyzer.py +393 -0
  20. reaxkit/analysis/per_file/fort99_analyzer.py +411 -0
  21. reaxkit/analysis/per_file/molfra_analyzer.py +359 -0
  22. reaxkit/analysis/per_file/params_analyzer.py +258 -0
  23. reaxkit/analysis/per_file/summary_analyzer.py +84 -0
  24. reaxkit/analysis/per_file/trainset_analyzer.py +84 -0
  25. reaxkit/analysis/per_file/vels_analyzer.py +95 -0
  26. reaxkit/analysis/per_file/xmolout_analyzer.py +528 -0
  27. reaxkit/cli.py +181 -0
  28. reaxkit/count_loc.py +276 -0
  29. reaxkit/data/alias.yaml +89 -0
  30. reaxkit/data/constants.yaml +27 -0
  31. reaxkit/data/reaxff_input_files_contents.yaml +186 -0
  32. reaxkit/data/reaxff_output_files_contents.yaml +301 -0
  33. reaxkit/data/units.yaml +38 -0
  34. reaxkit/help/__init__.py +0 -0
  35. reaxkit/help/help_index_loader.py +531 -0
  36. reaxkit/help/introspection_utils.py +131 -0
  37. reaxkit/io/__init__.py +0 -0
  38. reaxkit/io/base_handler.py +165 -0
  39. reaxkit/io/generators/__init__.py +0 -0
  40. reaxkit/io/generators/control_generator.py +123 -0
  41. reaxkit/io/generators/eregime_generator.py +341 -0
  42. reaxkit/io/generators/geo_generator.py +967 -0
  43. reaxkit/io/generators/trainset_generator.py +1758 -0
  44. reaxkit/io/generators/tregime_generator.py +113 -0
  45. reaxkit/io/generators/vregime_generator.py +164 -0
  46. reaxkit/io/generators/xmolout_generator.py +304 -0
  47. reaxkit/io/handlers/__init__.py +0 -0
  48. reaxkit/io/handlers/control_handler.py +209 -0
  49. reaxkit/io/handlers/eregime_handler.py +122 -0
  50. reaxkit/io/handlers/ffield_handler.py +812 -0
  51. reaxkit/io/handlers/fort13_handler.py +123 -0
  52. reaxkit/io/handlers/fort57_handler.py +143 -0
  53. reaxkit/io/handlers/fort73_handler.py +145 -0
  54. reaxkit/io/handlers/fort74_handler.py +155 -0
  55. reaxkit/io/handlers/fort76_handler.py +195 -0
  56. reaxkit/io/handlers/fort78_handler.py +142 -0
  57. reaxkit/io/handlers/fort79_handler.py +227 -0
  58. reaxkit/io/handlers/fort7_handler.py +264 -0
  59. reaxkit/io/handlers/fort99_handler.py +128 -0
  60. reaxkit/io/handlers/geo_handler.py +224 -0
  61. reaxkit/io/handlers/molfra_handler.py +184 -0
  62. reaxkit/io/handlers/params_handler.py +137 -0
  63. reaxkit/io/handlers/summary_handler.py +135 -0
  64. reaxkit/io/handlers/trainset_handler.py +658 -0
  65. reaxkit/io/handlers/vels_handler.py +293 -0
  66. reaxkit/io/handlers/xmolout_handler.py +174 -0
  67. reaxkit/utils/__init__.py +0 -0
  68. reaxkit/utils/alias.py +219 -0
  69. reaxkit/utils/cache.py +77 -0
  70. reaxkit/utils/constants.py +75 -0
  71. reaxkit/utils/equation_of_states.py +96 -0
  72. reaxkit/utils/exceptions.py +27 -0
  73. reaxkit/utils/frame_utils.py +175 -0
  74. reaxkit/utils/log.py +43 -0
  75. reaxkit/utils/media/__init__.py +0 -0
  76. reaxkit/utils/media/convert.py +90 -0
  77. reaxkit/utils/media/make_video.py +91 -0
  78. reaxkit/utils/media/plotter.py +812 -0
  79. reaxkit/utils/numerical/__init__.py +0 -0
  80. reaxkit/utils/numerical/extrema_finder.py +96 -0
  81. reaxkit/utils/numerical/moving_average.py +103 -0
  82. reaxkit/utils/numerical/numerical_calcs.py +75 -0
  83. reaxkit/utils/numerical/signal_ops.py +135 -0
  84. reaxkit/utils/path.py +55 -0
  85. reaxkit/utils/units.py +104 -0
  86. reaxkit/webui/__init__.py +0 -0
  87. reaxkit/webui/app.py +0 -0
  88. reaxkit/webui/components.py +0 -0
  89. reaxkit/webui/layouts.py +0 -0
  90. reaxkit/webui/utils.py +0 -0
  91. reaxkit/workflows/__init__.py +0 -0
  92. reaxkit/workflows/composed/__init__.py +0 -0
  93. reaxkit/workflows/composed/coordination_workflow.py +393 -0
  94. reaxkit/workflows/composed/electrostatics_workflow.py +587 -0
  95. reaxkit/workflows/composed/xmolout_fort7_workflow.py +343 -0
  96. reaxkit/workflows/meta/__init__.py +0 -0
  97. reaxkit/workflows/meta/help_workflow.py +136 -0
  98. reaxkit/workflows/meta/introspection_workflow.py +235 -0
  99. reaxkit/workflows/meta/make_video_workflow.py +61 -0
  100. reaxkit/workflows/meta/plotter_workflow.py +601 -0
  101. reaxkit/workflows/per_file/__init__.py +0 -0
  102. reaxkit/workflows/per_file/control_workflow.py +110 -0
  103. reaxkit/workflows/per_file/eregime_workflow.py +267 -0
  104. reaxkit/workflows/per_file/ffield_workflow.py +390 -0
  105. reaxkit/workflows/per_file/fort13_workflow.py +86 -0
  106. reaxkit/workflows/per_file/fort57_workflow.py +137 -0
  107. reaxkit/workflows/per_file/fort73_workflow.py +151 -0
  108. reaxkit/workflows/per_file/fort74_workflow.py +88 -0
  109. reaxkit/workflows/per_file/fort76_workflow.py +188 -0
  110. reaxkit/workflows/per_file/fort78_workflow.py +135 -0
  111. reaxkit/workflows/per_file/fort79_workflow.py +314 -0
  112. reaxkit/workflows/per_file/fort7_workflow.py +592 -0
  113. reaxkit/workflows/per_file/fort83_workflow.py +60 -0
  114. reaxkit/workflows/per_file/fort99_workflow.py +223 -0
  115. reaxkit/workflows/per_file/geo_workflow.py +554 -0
  116. reaxkit/workflows/per_file/molfra_workflow.py +577 -0
  117. reaxkit/workflows/per_file/params_workflow.py +135 -0
  118. reaxkit/workflows/per_file/summary_workflow.py +161 -0
  119. reaxkit/workflows/per_file/trainset_workflow.py +356 -0
  120. reaxkit/workflows/per_file/tregime_workflow.py +79 -0
  121. reaxkit/workflows/per_file/vels_workflow.py +309 -0
  122. reaxkit/workflows/per_file/vregime_workflow.py +75 -0
  123. reaxkit/workflows/per_file/xmolout_workflow.py +678 -0
  124. reaxkit-1.0.0.dist-info/METADATA +128 -0
  125. reaxkit-1.0.0.dist-info/RECORD +130 -0
  126. reaxkit-1.0.0.dist-info/WHEEL +5 -0
  127. reaxkit-1.0.0.dist-info/entry_points.txt +2 -0
  128. reaxkit-1.0.0.dist-info/licenses/AUTHORS.md +20 -0
  129. reaxkit-1.0.0.dist-info/licenses/LICENSE +21 -0
  130. reaxkit-1.0.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,812 @@
1
+ """
2
+ Plotting utilities for ReaxKit.
3
+
4
+ This module provides a collection of lightweight, reusable plotting helpers
5
+ built on Matplotlib for visualizing ReaxFF simulation data. The functions
6
+ support common use cases such as time-series plots, multi-axis comparisons,
7
+ stacked subplots, 3D atom visualizations, sensitivity (tornado) plots, and
8
+ 2D projections of 3D data.
9
+
10
+ All helpers share a consistent save/show behavior and are designed to be
11
+ used directly by analyzers and workflows.
12
+ """
13
+
14
+ from __future__ import annotations
15
+ import pandas as pd
16
+ from mpl_toolkits.mplot3d import Axes3D # noqa: F401 (needed for 3D)
17
+ import numpy as np
18
+ import matplotlib.pyplot as plt
19
+ from pathlib import Path
20
+ from typing import Optional, Sequence, Tuple, Union, Callable, Mapping, Any, Dict
21
+
22
+ # ---------- Helpers ----------
23
def _save_or_show(
    fig: plt.Figure,
    save_dir: Optional[Union[str, Path]],
    filename: str,
    show_message: bool = True,
) -> None:
    """
    Save a figure to disk or display it interactively.

    Parameters
    ----------
    fig : matplotlib.figure.Figure
        Figure to save or show.
    save_dir : str or Path, optional
        Either a full file path ending in a known image/document extension
        (``.png``, ``.jpg``, ``.jpeg``, ``.svg``, ``.pdf``, ``.tif``,
        ``.tiff``, ``.bmp``) or a directory. Missing directories are created.
        If falsy, nothing is written and ``plt.show()`` is called instead.
    filename : str
        File stem used when ``save_dir`` is a directory; the file is written
        as ``<filename>.png``. Spaces and path separators are replaced by
        underscores so the result is always a single file inside
        ``save_dir``.
    show_message : bool, optional
        Whether to print a confirmation line after saving.

    Notes
    -----
    After a successful save the figure is closed with ``plt.close``.
    """
    if not save_dir:
        plt.show()
        return

    target = Path(save_dir)
    known_suffixes = {".png", ".jpg", ".jpeg", ".svg", ".pdf", ".tif", ".tiff", ".bmp"}

    if target.suffix.lower() in known_suffixes:
        # save_dir is a full file path: make sure its parent exists.
        target.parent.mkdir(parents=True, exist_ok=True)
        out_file = target
    else:
        # save_dir is a directory: build "<filename>.png" inside it.
        target.mkdir(parents=True, exist_ok=True)
        safe_name = filename
        # Titles are commonly reused as file names; a "/" (e.g. in a unit
        # such as "kcal/mol") would otherwise address a missing subdirectory.
        for unsafe in (" ", "/", "\\"):
            safe_name = safe_name.replace(unsafe, "_")
        out_file = target / f"{safe_name}.png"

    fig.savefig(out_file, dpi=300, bbox_inches="tight")
    plt.close(fig)
    if show_message:
        print(f"[Done] saved plot to {out_file}")
55
+
56
+
57
+ # ---------- Plots ----------
58
def single_plot(
    x: Optional[Sequence[float]] = None,
    y: Optional[Sequence[float]] = None,
    *,
    series: Optional[Sequence[Mapping[str, Any]]] = None,
    hlines: Optional[Sequence[Union[float, tuple, Mapping[str, Any]]]] = None,
    title: Optional[str] = None,
    xlabel: Optional[str] = None,
    ylabel: Optional[str] = None,
    save: Optional[Union[str, Path]] = None,
    legend: bool = False,
    figsize: tuple[float, float] = (8.0, 3.2),
    plot_type: str = "line",
) -> plt.Figure:
    """
    Create a single plot for one or multiple data series.

    Either pass ``x`` and ``y`` directly, or pass ``series`` as a list of
    mappings. When ``series`` is given, ``x`` and ``y`` are ignored.

    Parameters
    ----------
    x, y : sequence of float, optional
        Data to plot when using the simple API.
    series : sequence of mapping, optional
        Multi-series specification. Each mapping must provide ``x`` and
        ``y`` (entries missing either are skipped) and may provide
        ``label``, ``linewidth`` (default 1.2), ``marker``, ``markersize``
        (default 4) and ``alpha`` (default 1.0).
    hlines : sequence, optional
        Horizontal reference lines. Each item is a number, a
        ``(y, label)`` tuple/list, or a mapping with key ``y`` and optional
        ``label``, ``linestyle``, ``linewidth`` and ``alpha``.
    title, xlabel, ylabel : str, optional
        Plot title and axis labels.
    save : str or Path, optional
        Output directory or file path. If not provided, the plot is shown.
    legend : bool, optional
        Whether to display a legend.
    figsize : tuple of float, optional
        Figure size.
    plot_type : {'line', 'scatter'}, optional
        ``'scatter'`` draws with ``Axes.scatter``; any other value draws
        with ``Axes.plot``.

    Returns
    -------
    matplotlib.figure.Figure
        The created figure.

    Raises
    ------
    ValueError
        If ``series`` is not given and ``x`` or ``y`` is missing.
    """
    # Validate before creating the figure so a bad call does not leave an
    # orphaned open figure behind.
    if series is None and (x is None or y is None):
        raise ValueError("Provide (x, y) or 'series=[...]'.")

    is_scatter = plot_type == "scatter"
    fig, ax = plt.subplots(figsize=figsize)
    draw = ax.scatter if is_scatter else ax.plot

    if series is not None:
        for spec in series:
            sx = spec.get("x")
            sy = spec.get("y")
            if sx is None or sy is None:
                continue
            style = {
                "linewidth": spec.get("linewidth", 1.2),
                "marker": spec.get("marker", "." if is_scatter else None),
                "alpha": spec.get("alpha", 1.0),
            }
            # The marker-size keyword differs between scatter and line plots.
            style["s" if is_scatter else "markersize"] = spec.get("markersize", 4)
            draw(sx, sy, label=spec.get("label"), **style)
    else:
        draw(x, y, label=None)

    # Horizontal reference lines
    if hlines:
        for ref in hlines:
            if isinstance(ref, Mapping):
                level = ref.get("y")
                if level is None:
                    continue
                ax.axhline(
                    level,
                    linestyle=ref.get("linestyle", "--"),
                    linewidth=ref.get("linewidth", 1.0),
                    alpha=ref.get("alpha", 1.0),
                    color="gray",
                    label=ref.get("label"),
                )
            elif isinstance(ref, (tuple, list)):
                level = ref[0]
                ref_label = ref[1] if len(ref) > 1 else None
                ax.axhline(level, linestyle="--", linewidth=1.0, alpha=1.0,
                           color="gray", label=ref_label)
            else:
                ax.axhline(float(ref), linestyle="--", linewidth=1.0, alpha=1.0,
                           color="gray")

    if title:
        ax.set_title(title)
    if xlabel:
        ax.set_xlabel(xlabel)
    if ylabel:
        ax.set_ylabel(ylabel)
    if legend:
        ax.legend()

    fig.tight_layout()
    _save_or_show(fig, save, title or "single_plot")
    return fig
174
+
175
+
176
def directed_plot(
    x: Sequence[float],
    y: Sequence[float],
    *,
    figsize: Tuple[float, float] = (10, 6),
    title: str = '',
    xlabel: str = '',
    ylabel: str = '',
    color: str = 'blue',
    linestyle: str = '-',
    arrow_color: str = 'red',
    arrow_width: float = 0.003,
    grid: bool = False,
    xlim: Optional[Tuple[float, float]] = None,
    ylim: Optional[Tuple[float, float]] = None,
    hline: Optional[float] = None,
    hline_kwargs: Optional[Dict] = None,
    legend: bool = False,
    save: Optional[Union[str, Path]] = None
) -> None:
    """
    Plot a 2D path with directional arrows.

    A line is drawn through the ``(x, y)`` points in the order given, and an
    arrow is overlaid on every segment, pointing from each point to the next.

    Parameters
    ----------
    x, y : sequence of float
        Path coordinates; must have the same shape.
    figsize : tuple of float, optional
        Figure size.
    title, xlabel, ylabel : str, optional
        Plot title and axis labels.
    color, linestyle : str, optional
        Style of the path line.
    arrow_color : str, optional
        Colour of the direction arrows.
    arrow_width : float, optional
        Shaft width passed to ``Axes.quiver``.
    grid : bool, optional
        Whether to draw a grid.
    xlim, ylim : tuple of float, optional
        Axis limits.
    hline : float, optional
        Position of an optional horizontal reference line.
    hline_kwargs : dict, optional
        Overrides for the reference line style (default: black, dashed,
        width 1).
    legend : bool, optional
        Whether to display a legend.
    save : str or Path, optional
        Output directory or file path.

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If ``x`` and ``y`` do not have the same shape.
    """
    xs = np.asarray(x)
    ys = np.asarray(y)
    # Fail before the figure exists so a bad call does not leak a figure.
    if xs.shape != ys.shape:
        raise ValueError(
            f"x and y must have the same shape, got {xs.shape} and {ys.shape}"
        )

    fig, ax = plt.subplots(figsize=figsize)
    ax.plot(xs, ys, linestyle=linestyle, color=color, label='Path')

    # A path needs at least two points to have a direction.
    if xs.size > 1:
        ax.quiver(
            xs[:-1], ys[:-1], np.diff(xs), np.diff(ys),
            angles='xy', scale_units='xy', scale=1,
            color=arrow_color, width=arrow_width
        )

    ax.set(title=title, xlabel=xlabel, ylabel=ylabel)

    if xlim:
        ax.set_xlim(xlim)
    if ylim:
        ax.set_ylim(ylim)
    if grid:
        ax.grid(True)
    if hline is not None:
        line_style = {'color': 'black', 'linestyle': '--', 'linewidth': 1}
        if hline_kwargs:
            line_style.update(hline_kwargs)
        ax.axhline(hline, **line_style)
    if legend:
        ax.legend()

    _save_or_show(fig, save, title or 'directed_plot')
245
+
246
+
247
def dual_yaxis_plot(
    x: Sequence[float],
    y1: Sequence[float],
    y2: Sequence[float],
    *,
    figsize: Tuple[float, float] = (10, 6),
    title: str = '',
    xlabel: str = '',
    ylabel1: str = '',
    ylabel2: str = '',
    color1: str = 'blue',
    linestyle1: str = '-',
    marker1: str = '',
    color2: str = 'green',
    linestyle2: str = '--',
    marker2: str = '',
    xlim: Optional[Tuple[float, float]] = None,
    ylim1: Optional[Tuple[float, float]] = None,
    ylim2: Optional[Tuple[float, float]] = None,
    grid: bool = False,
    hline1: Optional[float] = None,
    hline1_kwargs: Optional[Dict] = None,
    hline2: Optional[float] = None,
    hline2_kwargs: Optional[Dict] = None,
    vline: Optional[float] = None,
    vline_kwargs: Optional[Dict] = None,
    save: Optional[Union[str, Path]] = None
) -> None:
    """
    Plot two datasets against a shared x-axis with separate y-axes.

    ``y1`` is drawn on the left axis and ``y2`` on a twin right axis, each
    with its own colour, line style, marker, label and limits.

    Parameters
    ----------
    x : sequence of float
        Shared x-axis values.
    y1, y2 : sequence of float
        Data for the left and right y-axes.
    figsize : tuple of float, optional
        Figure size.
    title, xlabel, ylabel1, ylabel2 : str, optional
        Figure title and axis labels; each y label is coloured like its line.
    color1, linestyle1, marker1 : str, optional
        Style of the left-axis line.
    color2, linestyle2, marker2 : str, optional
        Style of the right-axis line.
    xlim, ylim1, ylim2 : tuple of float, optional
        Limits of the x-axis, left y-axis and right y-axis.
    grid : bool, optional
        Whether to draw a grid on the left axis.
    hline1, hline2 : float, optional
        Horizontal reference lines on the left / right axis. Defaults to a
        dashed line of width 1 in the colour of that axis.
    hline1_kwargs, hline2_kwargs : dict, optional
        Style overrides for the corresponding horizontal line.
    vline : float, optional
        Vertical reference line (default: black, dotted, width 1).
    vline_kwargs : dict, optional
        Style overrides for the vertical line.
    save : str or Path, optional
        Output directory or file path.

    Returns
    -------
    None
    """
    def _reference_style(default_color, default_linestyle, overrides):
        # Start from the defaults and let caller-supplied keys win.
        style = {'color': default_color, 'linestyle': default_linestyle, 'linewidth': 1}
        if overrides:
            style.update(overrides)
        return style

    fig, left_ax = plt.subplots(figsize=figsize)

    # Left axis
    left_ax.plot(x, y1, linestyle=linestyle1, marker=marker1, color=color1)
    left_ax.set_xlabel(xlabel)
    left_ax.set_ylabel(ylabel1, color=color1)
    if xlim:
        left_ax.set_xlim(xlim)
    if ylim1:
        left_ax.set_ylim(ylim1)
    if grid:
        left_ax.grid(True)
    if hline1 is not None:
        left_ax.axhline(hline1, **_reference_style(color1, '--', hline1_kwargs))

    # Right axis, sharing x with the left one
    right_ax = left_ax.twinx()
    right_ax.plot(x, y2, linestyle=linestyle2, marker=marker2, color=color2)
    right_ax.set_ylabel(ylabel2, color=color2)
    if ylim2:
        right_ax.set_ylim(ylim2)
    if hline2 is not None:
        right_ax.axhline(hline2, **_reference_style(color2, '--', hline2_kwargs))

    if vline is not None:
        left_ax.axvline(vline, **_reference_style('black', ':', vline_kwargs))

    fig.suptitle(title)
    _save_or_show(fig, save, title or 'dual_yaxis_plot')
332
+
333
+
334
def multi_subplots(
    subplots: Sequence[Sequence[Mapping[str, Any]]],
    *,
    title: Optional[Union[str, Sequence[Optional[str]]]] = None,
    xlabel: Optional[Union[str, Sequence[Optional[str]]]] = None,
    ylabel: Optional[Union[str, Sequence[Optional[str]]]] = None,
    sharex: bool = False,
    sharey: bool = False,
    legend: bool = True,
    grid: bool = False,
    figsize: tuple[float, float] = (8.0, 6.0),
    save: Optional[Union[str, Path]] = None,
) -> plt.Figure:
    """
    Create multiple vertically stacked subplots.

    One axes is created per entry of ``subplots`` and every series in that
    entry is drawn as a line on it.

    Parameters
    ----------
    subplots : sequence of sequence of dict
        Series definitions for each subplot. Each series must provide ``x``
        and ``y`` (entries missing either are skipped) and may provide
        ``label``. The optional keys ``linewidth``, ``marker``,
        ``markersize`` and ``alpha`` are forwarded to ``Axes.plot`` when
        present.
    title : str or sequence of str, optional
        A single string becomes the figure-level title; a list/tuple gives
        one title per subplot.
    xlabel, ylabel : str or sequence of str, optional
        A single string is applied to every subplot; a list/tuple of length
        1 is repeated, and a list/tuple of length ``len(subplots)`` is used
        per subplot.
    sharex, sharey : bool, optional
        Forwarded to ``plt.subplots``.
    legend : bool, optional
        Whether to add a legend to subplots that contain labelled series.
    grid : bool, optional
        Whether to draw a light grid on every subplot.
    figsize : tuple of float, optional
        Figure size.
    save : str or Path, optional
        Output directory or file path.

    Returns
    -------
    matplotlib.figure.Figure
        The created figure, or ``None`` (after printing a message) when
        ``subplots`` is empty.

    Raises
    ------
    ValueError
        If a list/tuple title or label has a length other than 1 or
        ``len(subplots)``.
    """
    nplots = len(subplots)
    if nplots == 0:
        print("multi_subplots: no subplot data provided.")
        return None  # type: ignore[return-value]

    def _per_subplot(
        val: Optional[Union[str, Sequence[Optional[str]]]],
        n: int,
    ) -> list[Optional[str]]:
        """Expand ``val`` into a list with one entry per subplot."""
        if val is None:
            return [None] * n
        if isinstance(val, (list, tuple)):
            if len(val) == 1:
                return [val[0]] * n
            if len(val) != n:
                raise ValueError(
                    f"Expected sequence of length 1 or {n}, got length {len(val)}"
                )
            return list(val)
        # single string
        return [val] * n

    # A list/tuple title is per-subplot; a plain string is the figure title.
    if isinstance(title, (list, tuple)):
        per_titles = _per_subplot(title, nplots)
        global_title = None
    else:
        per_titles = [None] * nplots
        global_title = title

    # Normalise labels before creating the figure so that invalid input
    # raises without leaving an open figure behind.
    xlabels = _per_subplot(xlabel, nplots)
    ylabels = _per_subplot(ylabel, nplots)

    fig, axes = plt.subplots(
        nplots,
        1,
        figsize=figsize,
        sharex=sharex,
        sharey=sharey,
        squeeze=False,
    )

    style_keys = ("linewidth", "marker", "markersize", "alpha")
    panels = zip(axes.flatten(), subplots, xlabels, ylabels, per_titles)
    for ax, panel_series, x_text, y_text, panel_title in panels:
        for spec in panel_series:
            xs = spec.get("x")
            ys = spec.get("y")
            if xs is None or ys is None:
                continue
            # Only forward style keys the caller actually set, so series
            # without them keep Matplotlib's defaults.
            style = {key: spec[key] for key in style_keys if spec.get(key) is not None}
            ax.plot(xs, ys, label=spec.get("label"), **style)

        if y_text:
            ax.set_ylabel(y_text)
        if x_text:
            ax.set_xlabel(x_text)
        if panel_title:
            ax.set_title(panel_title)

        if grid:
            ax.grid(True, alpha=0.3)
        # Skip the legend when nothing is labelled; calling ax.legend() on
        # such an axes only produces a Matplotlib warning.
        if legend and ax.get_legend_handles_labels()[0]:
            ax.legend(fontsize=9)

    if global_title:
        fig.suptitle(global_title, fontsize=14, y=0.98)

    fig.tight_layout()
    _save_or_show(fig, save, global_title or "multi_subplots")
    return fig
455
+
456
+
457
+
458
+
459
def tornado_plot(
    labels: Sequence[str],
    min_vals: Sequence[float],
    max_vals: Sequence[float],
    *,
    median_vals: Optional[Sequence[float]] = None,
    title: str = "Tornado Plot",
    xlabel: str = "Value",
    ylabel: str = "Value",
    save: Optional[Union[str, Path]] = None,
    top: int = 0,
    vline: Optional[float] = None,
    left_color="#1F77B4",
    right_color=(225 / 255, 113 / 255, 29 / 255)  # RGB tuple
) -> None:
    """
    Create a tornado plot to visualize sensitivity or uncertainty ranges.

    Each label is drawn as a horizontal bar spanning its minimum and maximum
    value. Bars are ordered by the absolute width of that span, widest on
    top.

    Parameters
    ----------
    labels : sequence of str
        Parameter or variable names.
    min_vals, max_vals : sequence of float
        Lower and upper bounds for each parameter. A pair given in swapped
        order is drawn as if it were ordered.
    median_vals : sequence of float, optional
        Central estimates, marked with an asterisk on each bar. Missing
        (NaN) entries are not marked.
    title, xlabel, ylabel : str, optional
        Plot title and axis labels.
    save : str or Path, optional
        Output directory or file path.
    top : int, optional
        If positive, keep only the ``top`` widest bars.
    vline : float, optional
        Reference value. When given, a dashed vertical line is drawn and the
        part of each bar left of it uses ``left_color`` while the part right
        of it uses ``right_color``. Without it, bars are a neutral gray.
    left_color, right_color : color, optional
        Bar colours on either side of ``vline``.

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If ``labels``, ``min_vals``, ``max_vals`` (and ``median_vals`` when
        given) do not all have the same length.
    """
    n_items = len(labels)
    if len(min_vals) != n_items or len(max_vals) != n_items:
        raise ValueError("labels, min_vals and max_vals must be the same length")
    if median_vals is not None and len(median_vals) != n_items:
        raise ValueError("median_vals must be the same length as labels/min_vals/max_vals")

    # Plain lists keep pandas from aligning the inputs on an index.
    df = pd.DataFrame(
        {"label": list(labels), "min": list(min_vals), "max": list(max_vals)}
    )
    has_median = median_vals is not None
    if has_median:
        df["median"] = list(median_vals)

    if df.empty:
        fig, ax = plt.subplots(figsize=(6, 3))
        ax.text(0.5, 0.5, "No data to plot", ha="center", va="center")
        ax.axis("off")
        _save_or_show(fig, save, filename=title or "tornado_plot")
        return

    # Rank by absolute width so a swapped (min, max) pair is ranked the same
    # way it is drawn.
    df["span"] = (df["max"] - df["min"]).abs()
    df = df.sort_values("span", ascending=False).reset_index(drop=True)
    if top and top > 0:
        df = df.head(top)

    fig_height = max(3.2, 0.38 * len(df))
    fig, ax = plt.subplots(figsize=(8.6, fig_height))

    y_positions = np.arange(len(df))
    bar_height = 0.5
    edge_common = dict(edgecolor="black", linewidth=0.5)

    for y_pos, row in zip(y_positions, df.itertuples(index=False)):
        if row.max < row.min:  # guard against swapped inputs
            left, right = row.max, row.min
        else:
            left, right = row.min, row.max

        # Each segment is (start, width, colour, alpha).
        if vline is None:
            segments = [(left, right - left, "tab:gray", 0.7)]
        elif right <= vline:
            segments = [(left, right - left, left_color, 0.8)]
        elif left >= vline:
            segments = [(left, right - left, right_color, 0.8)]
        else:
            # Bar straddles the reference: split it at vline.
            segments = [
                (left, vline - left, left_color, 0.8),
                (vline, right - vline, right_color, 0.8),
            ]
        for start, width, bar_color, bar_alpha in segments:
            ax.barh(y=y_pos, width=width, left=start, height=bar_height,
                    color=bar_color, alpha=bar_alpha, **edge_common)

        # median marker as asterisk
        if has_median and not pd.isna(row.median):
            ax.plot(row.median, y_pos, marker="*", markersize=5, color='black', zorder=5)

    if vline is not None:
        ax.axvline(vline, linestyle="--", linewidth=1, color=(66/255, 196/255, 127/255))

    ax.set_yticks(list(y_positions))
    ax.set_yticklabels(df["label"])
    ax.invert_yaxis()
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.grid(axis="x", linestyle=":", alpha=0.4)
    fig.tight_layout()

    _save_or_show(fig, save, filename=title or "tornado_plot")
574
+
575
+
576
+
577
+
578
def scatter3d_points(
    coords: np.ndarray,                 # (N, 3) array of XYZ
    values: np.ndarray,                 # (N,) array, e.g., partial charges
    *,
    title: str = "atoms (3D)",
    s: float = 8.0,
    alpha: float = 0.9,
    vmin: Optional[float] = None,
    vmax: Optional[float] = None,
    cmap: str = "coolwarm",
    figsize: Tuple[float, float] = (7.5, 6.0),
    elev: float = 22.0,
    azim: float = 38.0,
    save: Optional[Union[str, Path]] = None,   # dir or full path w/ extension
    show_colorbar: bool = True,
    show_message: bool = True,
):
    """
    Render a 3D scatter plot of points coloured by a scalar value.

    Parameters
    ----------
    coords : array-like, shape (N, 3)
        Point coordinates; converted to a float array.
    values : array-like, shape (N,)
        Scalar values mapped to colours; converted to a float array.
    title : str, optional
        Axes title, also used as the file name when saving into a directory.
    s : float, optional
        Marker size.
    alpha : float, optional
        Marker opacity.
    vmin, vmax : float, optional
        Colour-scale limits.
    cmap : str, optional
        Colormap name.
    figsize : tuple of float, optional
        Figure size.
    elev, azim : float, optional
        Camera elevation and azimuth passed to ``view_init``.
    save : str or Path, optional
        Output directory or file path.
    show_colorbar : bool, optional
        Whether to add a colorbar.
    show_message : bool, optional
        Whether to print a confirmation line after saving.

    Returns
    -------
    matplotlib.figure.Figure
        The created figure.

    Raises
    ------
    ValueError
        If ``coords`` is not of shape (N, 3) or ``values`` is not of shape
        (N,).
    """
    coords = np.asarray(coords, float)
    values = np.asarray(values, float)
    # Explicit raises instead of assert: assertions vanish under ``python -O``.
    if coords.ndim != 2 or coords.shape[1] != 3:
        raise ValueError("coords must be (N,3)")
    if values.ndim != 1 or len(values) != len(coords):
        raise ValueError("values must be (N,)")

    fig = plt.figure(figsize=figsize)
    ax = fig.add_subplot(111, projection="3d")

    points = ax.scatter(
        coords[:, 0], coords[:, 1], coords[:, 2],
        c=values, cmap=cmap, vmin=vmin, vmax=vmax,
        s=s, alpha=alpha, depthshade=True
    )
    if show_colorbar:
        colorbar = fig.colorbar(points, ax=ax, shrink=0.75, pad=0.02)
        colorbar.set_label("value")

    ax.set_xlabel("x (Å)")
    ax.set_ylabel("y (Å)")
    ax.set_zlabel("z (Å)")
    ax.set_title(title)
    ax.view_init(elev=elev, azim=azim)

    # Fall back to a fixed stem so an empty title does not yield ".png".
    _save_or_show(fig, save, filename=title or "scatter3d_points",
                  show_message=show_message)

    return fig
642
+
643
+
644
def heatmap2d_from_3d(
    coords: np.ndarray,                  # (N,3) array of XYZ
    values: np.ndarray,                  # (N,) scalar to aggregate (e.g., partial_charge)
    *,
    plane: str = "xy",                   # "xy", "xz", or "yz"
    bins: Union[int, Tuple[int, int]] = 50,  # grid resolution; int or (nx, ny)
    xlim: Optional[Tuple[float, float]] = None,
    ylim: Optional[Tuple[float, float]] = None,
    agg: Union[str, Callable[[np.ndarray], float]] = "mean",  # "mean","max","min","sum","count" or a callable
    vmin: Optional[float] = None,
    vmax: Optional[float] = None,
    cmap: str = "viridis",
    title: str = "2D aggregated heatmap",
    figsize: Tuple[float, float] = (6.5, 5.5),
    save: Optional[Union[str, Path]] = None,    # dir or full path; if dir, saves PNG
    show_colorbar: bool = True,
    show_message: bool = True,
):
    """
    Project 3D point data onto a 2D plane and aggregate values on a grid.

    The two coordinates selected by ``plane`` are binned onto a regular
    grid; the values falling into each cell are combined with ``agg`` and
    the result is drawn with ``imshow``.

    Parameters
    ----------
    coords : array-like, shape (N, 3)
        3D coordinates; converted to a float array.
    values : array-like, shape (N,)
        Values to aggregate; converted to a float array.
    plane : {'xy', 'xz', 'yz'}, optional
        Projection plane. The first letter is the horizontal axis.
    bins : int or (int, int), optional
        Number of cells along both axes, or ``(nx, ny)``.
    xlim, ylim : tuple of float, optional
        Grid range along the horizontal / vertical axis. Defaults to the
        data range. Points outside an explicit range are ignored.
    agg : {'mean', 'max', 'min', 'sum', 'count'} or callable, optional
        Aggregation method. A callable receives the 1-D float array of the
        values in one cell; if it raises, that cell is left as NaN.
    vmin, vmax : float, optional
        Colour-scale limits; default to the finite range of the grid.
    cmap : str, optional
        Colormap name.
    title : str, optional
        Axes title, also used (spaces replaced by underscores) as the file
        name when saving into a directory.
    figsize : tuple of float, optional
        Figure size.
    save : str or Path, optional
        Output directory or file path.
    show_colorbar : bool, optional
        Whether to add a colorbar.
    show_message : bool, optional
        Whether to print a confirmation line after saving.

    Returns
    -------
    fig : matplotlib.figure.Figure
    grid : numpy.ndarray
        Aggregated 2D grid of shape ``(ny, nx)``. Cells without points are
        NaN, except for ``'count'`` and ``'sum'`` where they are 0.
    xedges, yedges : numpy.ndarray
        Bin edges along each axis.

    Raises
    ------
    ValueError
        If ``coords``/``values`` have the wrong shape, ``plane`` or ``agg``
        is not recognised, ``bins`` is not positive, or the grid range has
        to be inferred from empty input.
    """
    coords = np.asarray(coords, float)
    values = np.asarray(values, float)
    # Explicit raises instead of assert: assertions vanish under ``python -O``.
    if coords.ndim != 2 or coords.shape[1] != 3:
        raise ValueError("coords must be (N,3)")
    if values.ndim != 1 or len(values) != len(coords):
        raise ValueError("values must be (N,)")

    # Choose plane projection
    columns_by_plane = {"xy": (0, 1), "xz": (0, 2), "yz": (1, 2)}
    if plane not in columns_by_plane:
        raise ValueError("plane must be one of {'xy','xz','yz'}")
    col_u, col_v = columns_by_plane[plane]
    u, v = coords[:, col_u], coords[:, col_v]
    xlabel, ylabel = f"{'xyz'[col_u]} (Å)", f"{'xyz'[col_v]} (Å)"

    # Grid resolution (NumPy integers are accepted as a scalar bin count)
    if isinstance(bins, (int, np.integer)):
        nx = ny = int(bins)
    else:
        nx, ny = int(bins[0]), int(bins[1])
    if nx < 1 or ny < 1:
        raise ValueError("bins must be positive")

    if len(coords) == 0 and (xlim is None or ylim is None):
        raise ValueError("cannot infer the grid range from empty coords; pass xlim and ylim")

    # Ranges
    umin = np.min(u) if xlim is None else xlim[0]
    umax = np.max(u) if xlim is None else xlim[1]
    vmin_edge = np.min(v) if ylim is None else ylim[0]
    vmax_edge = np.max(v) if ylim is None else ylim[1]

    # Bin edges
    xedges = np.linspace(umin, umax, nx + 1)
    yedges = np.linspace(vmin_edge, vmax_edge, ny + 1)

    # Drop points outside the grid; clipping them into the border cells
    # would contaminate those cells when explicit limits crop the data.
    inside = (
        (u >= min(umin, umax)) & (u <= max(umin, umax))
        & (v >= min(vmin_edge, vmax_edge)) & (v <= max(vmin_edge, vmax_edge))
    )
    u_in, v_in, values_in = u[inside], v[inside], values[inside]

    # Digitize points to cell indices; the clip sends points lying exactly
    # on the last edge into the last bin.
    ui = np.clip(np.digitize(u_in, xedges) - 1, 0, nx - 1)
    vi = np.clip(np.digitize(v_in, yedges) - 1, 0, ny - 1)

    # Row-major flattened cell id
    flat_idx = vi * nx + ui
    grid = _aggregate_cells(flat_idx, values_in, nx * ny, agg).reshape(ny, nx)

    # Color scale defaults from data if not provided (left to Matplotlib
    # when the grid holds no data at all).
    has_data = not np.all(np.isnan(grid))
    cmin = vmin if vmin is not None else (np.nanmin(grid) if has_data else None)
    cmax = vmax if vmax is not None else (np.nanmax(grid) if has_data else None)

    # Plot
    fig, ax = plt.subplots(figsize=figsize)
    # extent maps grid cells to physical coordinates
    extent = (xedges[0], xedges[-1], yedges[0], yedges[-1])
    im = ax.imshow(
        grid,
        origin="lower",
        extent=extent,
        aspect="auto",
        vmin=cmin,
        vmax=cmax,
        cmap=cmap,
        interpolation="nearest",
    )
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)

    if show_colorbar:
        cb = fig.colorbar(im, ax=ax, pad=0.02)
        cb.set_label(f"{agg if isinstance(agg,str) else 'agg'} of values")

    _save_or_show(fig, save, title.replace(" ", "_"), show_message=show_message)

    return fig, grid, xedges, yedges


def _aggregate_cells(
    flat_idx: np.ndarray,
    values: np.ndarray,
    n_cells: int,
    agg: Union[str, Callable[[np.ndarray], float]],
) -> np.ndarray:
    """Combine ``values`` per cell id in ``flat_idx`` into a flat float array of length ``n_cells``."""
    if isinstance(agg, str):
        name = agg.lower()
        if name == "count":
            return np.bincount(flat_idx, minlength=n_cells).astype(float)
        if name in {"sum", "mean"}:
            sums = np.bincount(flat_idx, weights=values, minlength=n_cells).astype(float)
            if name == "sum":
                return sums
            counts = np.bincount(flat_idx, minlength=n_cells).astype(float)
            with np.errstate(invalid="ignore", divide="ignore"):
                means = sums / counts
            means[counts == 0] = np.nan
            return means
        if name in {"max", "min"}:
            # fmax/fmin skip NaN values, so a NaN never replaces a number.
            if name == "max":
                sentinel, reducer = -np.inf, np.fmax
            else:
                sentinel, reducer = np.inf, np.fmin
            cells = np.full(n_cells, sentinel, float)
            reducer.at(cells, flat_idx, values)
            # Convert untouched cells to NaN
            cells[cells == sentinel] = np.nan
            return cells
        raise ValueError("agg must be one of {'mean','max','min','sum','count'} or a callable")

    # Callable aggregator: group values by cell, keeping their input order.
    cells = np.full(n_cells, np.nan, float)
    order = np.argsort(flat_idx, kind="stable")
    sorted_idx = flat_idx[order]
    occupied, starts = np.unique(sorted_idx, return_index=True)
    for cell, chunk in zip(occupied, np.split(values[order], starts[1:])):
        try:
            cells[cell] = agg(np.asarray(chunk, float))
        except Exception:
            # Deliberate best effort: a failing aggregator leaves the cell NaN.
            cells[cell] = np.nan
    return cells