reaxkit 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- reaxkit/__init__.py +0 -0
- reaxkit/analysis/__init__.py +0 -0
- reaxkit/analysis/composed/RDF_analyzer.py +560 -0
- reaxkit/analysis/composed/__init__.py +0 -0
- reaxkit/analysis/composed/connectivity_analyzer.py +706 -0
- reaxkit/analysis/composed/coordination_analyzer.py +144 -0
- reaxkit/analysis/composed/electrostatics_analyzer.py +687 -0
- reaxkit/analysis/per_file/__init__.py +0 -0
- reaxkit/analysis/per_file/control_analyzer.py +165 -0
- reaxkit/analysis/per_file/eregime_analyzer.py +108 -0
- reaxkit/analysis/per_file/ffield_analyzer.py +305 -0
- reaxkit/analysis/per_file/fort13_analyzer.py +79 -0
- reaxkit/analysis/per_file/fort57_analyzer.py +106 -0
- reaxkit/analysis/per_file/fort73_analyzer.py +61 -0
- reaxkit/analysis/per_file/fort74_analyzer.py +65 -0
- reaxkit/analysis/per_file/fort76_analyzer.py +191 -0
- reaxkit/analysis/per_file/fort78_analyzer.py +154 -0
- reaxkit/analysis/per_file/fort79_analyzer.py +83 -0
- reaxkit/analysis/per_file/fort7_analyzer.py +393 -0
- reaxkit/analysis/per_file/fort99_analyzer.py +411 -0
- reaxkit/analysis/per_file/molfra_analyzer.py +359 -0
- reaxkit/analysis/per_file/params_analyzer.py +258 -0
- reaxkit/analysis/per_file/summary_analyzer.py +84 -0
- reaxkit/analysis/per_file/trainset_analyzer.py +84 -0
- reaxkit/analysis/per_file/vels_analyzer.py +95 -0
- reaxkit/analysis/per_file/xmolout_analyzer.py +528 -0
- reaxkit/cli.py +181 -0
- reaxkit/count_loc.py +276 -0
- reaxkit/data/alias.yaml +89 -0
- reaxkit/data/constants.yaml +27 -0
- reaxkit/data/reaxff_input_files_contents.yaml +186 -0
- reaxkit/data/reaxff_output_files_contents.yaml +301 -0
- reaxkit/data/units.yaml +38 -0
- reaxkit/help/__init__.py +0 -0
- reaxkit/help/help_index_loader.py +531 -0
- reaxkit/help/introspection_utils.py +131 -0
- reaxkit/io/__init__.py +0 -0
- reaxkit/io/base_handler.py +165 -0
- reaxkit/io/generators/__init__.py +0 -0
- reaxkit/io/generators/control_generator.py +123 -0
- reaxkit/io/generators/eregime_generator.py +341 -0
- reaxkit/io/generators/geo_generator.py +967 -0
- reaxkit/io/generators/trainset_generator.py +1758 -0
- reaxkit/io/generators/tregime_generator.py +113 -0
- reaxkit/io/generators/vregime_generator.py +164 -0
- reaxkit/io/generators/xmolout_generator.py +304 -0
- reaxkit/io/handlers/__init__.py +0 -0
- reaxkit/io/handlers/control_handler.py +209 -0
- reaxkit/io/handlers/eregime_handler.py +122 -0
- reaxkit/io/handlers/ffield_handler.py +812 -0
- reaxkit/io/handlers/fort13_handler.py +123 -0
- reaxkit/io/handlers/fort57_handler.py +143 -0
- reaxkit/io/handlers/fort73_handler.py +145 -0
- reaxkit/io/handlers/fort74_handler.py +155 -0
- reaxkit/io/handlers/fort76_handler.py +195 -0
- reaxkit/io/handlers/fort78_handler.py +142 -0
- reaxkit/io/handlers/fort79_handler.py +227 -0
- reaxkit/io/handlers/fort7_handler.py +264 -0
- reaxkit/io/handlers/fort99_handler.py +128 -0
- reaxkit/io/handlers/geo_handler.py +224 -0
- reaxkit/io/handlers/molfra_handler.py +184 -0
- reaxkit/io/handlers/params_handler.py +137 -0
- reaxkit/io/handlers/summary_handler.py +135 -0
- reaxkit/io/handlers/trainset_handler.py +658 -0
- reaxkit/io/handlers/vels_handler.py +293 -0
- reaxkit/io/handlers/xmolout_handler.py +174 -0
- reaxkit/utils/__init__.py +0 -0
- reaxkit/utils/alias.py +219 -0
- reaxkit/utils/cache.py +77 -0
- reaxkit/utils/constants.py +75 -0
- reaxkit/utils/equation_of_states.py +96 -0
- reaxkit/utils/exceptions.py +27 -0
- reaxkit/utils/frame_utils.py +175 -0
- reaxkit/utils/log.py +43 -0
- reaxkit/utils/media/__init__.py +0 -0
- reaxkit/utils/media/convert.py +90 -0
- reaxkit/utils/media/make_video.py +91 -0
- reaxkit/utils/media/plotter.py +812 -0
- reaxkit/utils/numerical/__init__.py +0 -0
- reaxkit/utils/numerical/extrema_finder.py +96 -0
- reaxkit/utils/numerical/moving_average.py +103 -0
- reaxkit/utils/numerical/numerical_calcs.py +75 -0
- reaxkit/utils/numerical/signal_ops.py +135 -0
- reaxkit/utils/path.py +55 -0
- reaxkit/utils/units.py +104 -0
- reaxkit/webui/__init__.py +0 -0
- reaxkit/webui/app.py +0 -0
- reaxkit/webui/components.py +0 -0
- reaxkit/webui/layouts.py +0 -0
- reaxkit/webui/utils.py +0 -0
- reaxkit/workflows/__init__.py +0 -0
- reaxkit/workflows/composed/__init__.py +0 -0
- reaxkit/workflows/composed/coordination_workflow.py +393 -0
- reaxkit/workflows/composed/electrostatics_workflow.py +587 -0
- reaxkit/workflows/composed/xmolout_fort7_workflow.py +343 -0
- reaxkit/workflows/meta/__init__.py +0 -0
- reaxkit/workflows/meta/help_workflow.py +136 -0
- reaxkit/workflows/meta/introspection_workflow.py +235 -0
- reaxkit/workflows/meta/make_video_workflow.py +61 -0
- reaxkit/workflows/meta/plotter_workflow.py +601 -0
- reaxkit/workflows/per_file/__init__.py +0 -0
- reaxkit/workflows/per_file/control_workflow.py +110 -0
- reaxkit/workflows/per_file/eregime_workflow.py +267 -0
- reaxkit/workflows/per_file/ffield_workflow.py +390 -0
- reaxkit/workflows/per_file/fort13_workflow.py +86 -0
- reaxkit/workflows/per_file/fort57_workflow.py +137 -0
- reaxkit/workflows/per_file/fort73_workflow.py +151 -0
- reaxkit/workflows/per_file/fort74_workflow.py +88 -0
- reaxkit/workflows/per_file/fort76_workflow.py +188 -0
- reaxkit/workflows/per_file/fort78_workflow.py +135 -0
- reaxkit/workflows/per_file/fort79_workflow.py +314 -0
- reaxkit/workflows/per_file/fort7_workflow.py +592 -0
- reaxkit/workflows/per_file/fort83_workflow.py +60 -0
- reaxkit/workflows/per_file/fort99_workflow.py +223 -0
- reaxkit/workflows/per_file/geo_workflow.py +554 -0
- reaxkit/workflows/per_file/molfra_workflow.py +577 -0
- reaxkit/workflows/per_file/params_workflow.py +135 -0
- reaxkit/workflows/per_file/summary_workflow.py +161 -0
- reaxkit/workflows/per_file/trainset_workflow.py +356 -0
- reaxkit/workflows/per_file/tregime_workflow.py +79 -0
- reaxkit/workflows/per_file/vels_workflow.py +309 -0
- reaxkit/workflows/per_file/vregime_workflow.py +75 -0
- reaxkit/workflows/per_file/xmolout_workflow.py +678 -0
- reaxkit-1.0.0.dist-info/METADATA +128 -0
- reaxkit-1.0.0.dist-info/RECORD +130 -0
- reaxkit-1.0.0.dist-info/WHEEL +5 -0
- reaxkit-1.0.0.dist-info/entry_points.txt +2 -0
- reaxkit-1.0.0.dist-info/licenses/AUTHORS.md +20 -0
- reaxkit-1.0.0.dist-info/licenses/LICENSE +21 -0
- reaxkit-1.0.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,812 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Plotting utilities for ReaxKit.
|
|
3
|
+
|
|
4
|
+
This module provides a collection of lightweight, reusable plotting helpers
|
|
5
|
+
built on Matplotlib for visualizing ReaxFF simulation data. The functions
|
|
6
|
+
support common use cases such as time-series plots, multi-axis comparisons,
|
|
7
|
+
stacked subplots, 3D atom visualizations, sensitivity (tornado) plots, and
|
|
8
|
+
2D projections of 3D data.
|
|
9
|
+
|
|
10
|
+
All helpers share a consistent save/show behavior and are designed to be
|
|
11
|
+
used directly by analyzers and workflows.
|
|
12
|
+
"""
|
|
13
|
+
|
|
14
|
+
from __future__ import annotations
|
|
15
|
+
import pandas as pd
|
|
16
|
+
from mpl_toolkits.mplot3d import Axes3D # noqa: F401 (needed for 3D)
|
|
17
|
+
import numpy as np
|
|
18
|
+
import matplotlib.pyplot as plt
|
|
19
|
+
from pathlib import Path
|
|
20
|
+
from typing import Optional, Sequence, Tuple, Union, Callable, Mapping, Any, Dict
|
|
21
|
+
|
|
22
|
+
# ---------- Helpers ----------
|
|
23
|
+
def _save_or_show(
    fig: plt.Figure,
    save_dir: Optional[Union[str, Path]],
    filename: str,
    show_message: bool = True,
) -> None:
    """
    Save a figure to disk or display it interactively.

    If ``save_dir`` is provided, the figure is saved either to the given
    file path (when ``save_dir`` ends in a known image/document extension)
    or into the specified directory as ``<filename>.png``. If ``save_dir``
    is ``None`` (or otherwise falsy), the figure is shown with
    ``plt.show()``.

    Parameters
    ----------
    fig : matplotlib.figure.Figure
        Figure to save or show. It is closed after saving.
    save_dir : str or Path, optional
        Output file path or directory. Missing parent directories are
        created.
    filename : str
        Base name (without extension) used when ``save_dir`` is a
        directory. Spaces and path separators are replaced by ``_`` so a
        title such as ``"Energy (kcal/mol)"`` yields a single file rather
        than a path into a non-existent sub-directory.
    show_message : bool, optional
        Print the output path after saving.
    """
    import os

    if not save_dir:
        plt.show()
        return

    save_path = Path(save_dir)
    # Suffixes that mark ``save_dir`` as a full output file path rather
    # than a directory.
    exts = {
        ".png", ".jpg", ".jpeg", ".svg", ".svgz", ".pdf",
        ".eps", ".ps", ".tif", ".tiff", ".bmp", ".webp",
    }
    if save_path.suffix.lower() in exts:
        save_path.parent.mkdir(parents=True, exist_ok=True)
        out_file = save_path
    else:
        # save_dir is a directory; build <filename>.png inside it
        save_path.mkdir(parents=True, exist_ok=True)
        safe_name = filename.replace(" ", "_").replace("/", "_")
        if os.sep != "/":
            # e.g. backslash on Windows
            safe_name = safe_name.replace(os.sep, "_")
        out_file = save_path / f"{safe_name}.png"

    fig.savefig(out_file, dpi=300, bbox_inches="tight")
    plt.close(fig)
    if show_message:
        print(f"[Done] saved plot to {out_file}")
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
# ---------- Plots ----------
|
|
58
|
+
def single_plot(
    x: Optional[Sequence[float]] = None,
    y: Optional[Sequence[float]] = None,
    *,
    series: Optional[Sequence[Mapping[str, Any]]] = None,
    hlines: Optional[Sequence[Union[float, tuple, Mapping[str, Any]]]] = None,
    title: Optional[str] = None,
    xlabel: Optional[str] = None,
    ylabel: Optional[str] = None,
    save: Optional[Union[str, Path]] = None,
    legend: bool = False,
    figsize: tuple[float, float] = (8.0, 3.2),
    plot_type: str = "line",
) -> plt.Figure:
    """
    Create a single plot for one or multiple data series.

    This function supports both simple ``(x, y)`` inputs and flexible
    multi-series plotting using a list of dictionaries. It can render
    line or scatter plots, add horizontal reference lines, and
    automatically save or display the figure.

    Parameters
    ----------
    x, y : sequence of float, optional
        Data to plot when using the simple API (ignored when ``series``
        is given).
    series : sequence of mapping, optional
        Multi-series specification. Recognized keys: ``x`` and ``y``
        (specs missing either are skipped), ``label``, ``linewidth``
        (default 1.2), ``marker`` (default ``"."`` for scatter, none for
        lines), ``markersize`` (default 4; passed as ``s`` for scatter)
        and ``alpha`` (default 1.0).
    hlines : sequence, optional
        Horizontal reference lines, drawn dashed and gray. Each entry may
        be a float ``y``, a tuple/list ``(y, label)``, or a mapping with
        keys ``y``, ``label``, ``linestyle``, ``linewidth`` and ``alpha``.
    title, xlabel, ylabel : str, optional
        Plot title and axis labels. The title (or ``"single_plot"``) is
        also the file name used when ``save`` is a directory.
    save : str or Path, optional
        Output directory or file path. If not provided, the plot is shown.
    legend : bool, optional
        Whether to display a legend.
    figsize : tuple of float, optional
        Figure size.
    plot_type : {'line', 'scatter'}, optional
        Type of plot to generate; any value other than ``'scatter'``
        draws lines.

    Returns
    -------
    matplotlib.figure.Figure
        The created figure (already closed if it was saved).

    Raises
    ------
    ValueError
        If ``series`` is not given and either ``x`` or ``y`` is missing.
    """
    # Validate before creating the figure so a bad call does not leave an
    # orphan figure registered with pyplot.
    if series is None and (x is None or y is None):
        raise ValueError("Provide (x, y) or 'series=[...]'.")

    is_scatter = plot_type == "scatter"
    fig, ax = plt.subplots(figsize=figsize)

    def _plot(px, py, label=None, **kwargs):
        if is_scatter:
            ax.scatter(px, py, label=label, **kwargs)
        else:
            ax.plot(px, py, label=label, **kwargs)

    if series is not None:
        # Multiple series
        for s in series:
            sx = s.get("x")
            sy = s.get("y")
            if sx is None or sy is None:
                continue
            ms = s.get("markersize", 4)
            kwargs = dict(
                linewidth=s.get("linewidth", 1.2),
                marker=s.get("marker", "." if is_scatter else None),
                alpha=s.get("alpha", 1.0),
            )
            # scatter sizes markers with 's'; line plots use 'markersize'
            if is_scatter:
                kwargs["s"] = ms
            else:
                kwargs["markersize"] = ms
            _plot(sx, sy, label=s.get("label"), **kwargs)
    else:
        _plot(x, y)

    # Horizontal reference lines
    if hlines is not None:
        for h in hlines:
            if isinstance(h, Mapping):
                yv = h.get("y")
                if yv is None:
                    continue
                ax.axhline(
                    yv,
                    linestyle=h.get("linestyle", "--"),
                    linewidth=h.get("linewidth", 1.0),
                    alpha=h.get("alpha", 1.0),
                    color="gray",
                    label=h.get("label"),
                )
            elif isinstance(h, (tuple, list)):
                yv, lbl = h[0], (h[1] if len(h) > 1 else None)
                ax.axhline(yv, linestyle="--", linewidth=1.0, alpha=1.0, color="gray", label=lbl)
            else:
                ax.axhline(float(h), linestyle="--", linewidth=1.0, alpha=1.0, color="gray")

    if title:
        ax.set_title(title)
    if xlabel:
        ax.set_xlabel(xlabel)
    if ylabel:
        ax.set_ylabel(ylabel)
    if legend:
        ax.legend()

    fig.tight_layout()
    _save_or_show(fig, save, title or "single_plot")
    return fig
|
|
174
|
+
|
|
175
|
+
|
|
176
|
+
def directed_plot(
    x: Sequence[float],
    y: Sequence[float],
    *,
    figsize: Tuple[float, float] = (10, 6),
    title: str = '',
    xlabel: str = '',
    ylabel: str = '',
    color: str = 'blue',
    linestyle: str = '-',
    arrow_color: str = 'red',
    arrow_width: float = 0.003,
    grid: bool = False,
    xlim: Optional[Tuple[float, float]] = None,
    ylim: Optional[Tuple[float, float]] = None,
    hline: Optional[float] = None,
    hline_kwargs: Optional[Dict] = None,
    legend: bool = False,
    save: Optional[Union[str, Path]] = None
) -> None:
    """
    Plot a 2D path with directional arrows.

    This helper visualizes progression along a trajectory by drawing a
    continuous line through ``(x, y)`` points and overlaying one arrow per
    consecutive pair of points. It is useful for trajectories, energy
    paths, or ordered parameter sweeps.

    Parameters
    ----------
    x, y : sequence of float
        Path coordinates; must have the same length. Any array-like
        (list, tuple, ndarray, pandas Series) is accepted. Arrows are
        computed positionally, regardless of a Series index.
    figsize : tuple of float, optional
        Figure size.
    title, xlabel, ylabel : str, optional
        Plot title and axis labels.
    color, linestyle : str, optional
        Style of the path line (labelled ``'Path'`` in the legend).
    arrow_color : str, optional
        Arrow color.
    arrow_width : float, optional
        Arrow shaft width passed to ``Axes.quiver``.
    grid : bool, optional
        Whether to draw a grid.
    xlim, ylim : tuple of float, optional
        Axis limits; applied only when truthy.
    hline : float, optional
        Horizontal reference line (black dashed unless overridden with
        ``hline_kwargs``).
    legend : bool, optional
        Whether to display a legend.
    save : str or Path, optional
        Output directory or file path; the plot is shown when omitted.

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If ``x`` and ``y`` have different lengths.
    """
    # Positional arrays: ``x[:-1]`` on a pandas Series can be label-based,
    # and generic array-likes may not support slicing at all.
    xs = np.atleast_1d(np.asarray(x))
    ys = np.atleast_1d(np.asarray(y))
    # Check lengths before creating the figure so no orphan figure leaks.
    if len(xs) != len(ys):
        raise ValueError(
            f"x and y must have the same length, got {len(xs)} and {len(ys)}"
        )

    fig, ax = plt.subplots(figsize=figsize)
    ax.plot(x, y, linestyle=linestyle, color=color, label='Path')
    # One arrow per segment; a path with fewer than two points has none.
    if len(xs) >= 2:
        ax.quiver(
            xs[:-1], ys[:-1], np.diff(xs), np.diff(ys),
            angles='xy', scale_units='xy', scale=1,
            color=arrow_color, width=arrow_width
        )

    ax.set(title=title, xlabel=xlabel, ylabel=ylabel)

    if xlim:
        ax.set_xlim(xlim)
    if ylim:
        ax.set_ylim(ylim)
    if grid:
        ax.grid(True)
    if hline is not None:
        params = {'color': 'black', 'linestyle': '--', 'linewidth': 1}
        if hline_kwargs:
            params.update(hline_kwargs)
        ax.axhline(hline, **params)
    if legend:
        ax.legend()

    _save_or_show(fig, save, title or 'directed_plot')
|
|
245
|
+
|
|
246
|
+
|
|
247
|
+
def dual_yaxis_plot(
    x: Sequence[float],
    y1: Sequence[float],
    y2: Sequence[float],
    *,
    figsize: Tuple[float, float] = (10, 6),
    title: str = '',
    xlabel: str = '',
    ylabel1: str = '',
    ylabel2: str = '',
    color1: str = 'blue',
    linestyle1: str = '-',
    marker1: str = '',
    color2: str = 'green',
    linestyle2: str = '--',
    marker2: str = '',
    xlim: Optional[Tuple[float, float]] = None,
    ylim1: Optional[Tuple[float, float]] = None,
    ylim2: Optional[Tuple[float, float]] = None,
    grid: bool = False,
    hline1: Optional[float] = None,
    hline1_kwargs: Optional[Dict] = None,
    hline2: Optional[float] = None,
    hline2_kwargs: Optional[Dict] = None,
    vline: Optional[float] = None,
    vline_kwargs: Optional[Dict] = None,
    save: Optional[Union[str, Path]] = None
) -> None:
    """
    Draw two series over a common x-axis, each against its own y-axis.

    ``y1`` goes on the left axis and ``y2`` on a twin right axis, so
    quantities with different units or magnitudes (e.g. energy vs
    temperature) can be compared in one figure. Each y-axis label is
    tinted with its series color.

    Parameters
    ----------
    x : sequence of float
        Shared x-axis values.
    y1, y2 : sequence of float
        Series drawn on the left and right y-axes respectively.
    title, xlabel, ylabel1, ylabel2 : str, optional
        Figure title (drawn with ``suptitle``) and axis labels.
    color1, linestyle1, marker1, color2, linestyle2, marker2 : str, optional
        Line styling for each series.
    xlim, ylim1, ylim2 : tuple of float, optional
        Axis limits; applied only when truthy.
    grid : bool, optional
        Draw a grid on the left axis.
    hline1, hline2 : float, optional
        Horizontal reference lines on the left / right axis. They are
        dashed and colored like their series unless overridden via
        ``hline1_kwargs`` / ``hline2_kwargs``.
    vline : float, optional
        Vertical reference line, black and dotted by default (override
        with ``vline_kwargs``).
    save : str or Path, optional
        Output directory or file path; the plot is shown when omitted.

    Returns
    -------
    None
    """

    def _merged(base: Dict, extra: Optional[Dict]) -> Dict:
        # Caller-supplied kwargs take precedence over the default style.
        if extra:
            base.update(extra)
        return base

    fig, left = plt.subplots(figsize=figsize)

    # --- left axis: y1 ---
    left.plot(x, y1, linestyle=linestyle1, marker=marker1, color=color1)
    left.set_xlabel(xlabel)
    left.set_ylabel(ylabel1, color=color1)
    if xlim:
        left.set_xlim(xlim)
    if ylim1:
        left.set_ylim(ylim1)
    if grid:
        left.grid(True)
    if hline1 is not None:
        left_style = _merged({'color': color1, 'linestyle': '--', 'linewidth': 1}, hline1_kwargs)
        left.axhline(hline1, **left_style)

    # --- right axis: y2, sharing the x-axis ---
    right = left.twinx()
    right.plot(x, y2, linestyle=linestyle2, marker=marker2, color=color2)
    right.set_ylabel(ylabel2, color=color2)
    if ylim2:
        right.set_ylim(ylim2)
    if hline2 is not None:
        right_style = _merged({'color': color2, 'linestyle': '--', 'linewidth': 1}, hline2_kwargs)
        right.axhline(hline2, **right_style)

    # The vertical marker lives on the left axis.
    if vline is not None:
        vertical_style = _merged({'color': 'black', 'linestyle': ':', 'linewidth': 1}, vline_kwargs)
        left.axvline(vline, **vertical_style)

    fig.suptitle(title)
    _save_or_show(fig, save, title or 'dual_yaxis_plot')
|
|
332
|
+
|
|
333
|
+
|
|
334
|
+
def multi_subplots(
    subplots: Sequence[Sequence[Mapping[str, Any]]],
    *,
    title: Optional[Union[str, Sequence[Optional[str]]]] = None,
    xlabel: Optional[Union[str, Sequence[Optional[str]]]] = None,
    ylabel: Optional[Union[str, Sequence[Optional[str]]]] = None,
    sharex: bool = False,
    sharey: bool = False,
    legend: bool = True,
    grid: bool = False,
    figsize: tuple[float, float] = (8.0, 6.0),
    save: Optional[Union[str, Path]] = None,
) -> plt.Figure:
    """
    Create multiple vertically stacked subplots.

    Each subplot takes a list of series specs with keys ``x``, ``y`` and
    ``label``. Specs missing ``x`` or ``y`` are skipped. The optional
    ``single_plot`` style keys ``linewidth``, ``marker``, ``markersize``
    and ``alpha`` are also honored. They are forwarded only when present,
    so specs without them keep Matplotlib's line defaults. Titles and
    axis labels may be shared or specified per subplot.

    Parameters
    ----------
    subplots : sequence of sequence of dict
        Series definitions for each subplot (one inner sequence per panel).
    title : str or sequence of str, optional
        A single string becomes the figure suptitle (and the save file
        name). A list/tuple gives per-subplot titles (length 1 or
        ``len(subplots)``).
    xlabel, ylabel : str or sequence of str, optional
        Labels applied to every subplot (str or length-1 sequence) or per
        subplot (sequence of length ``len(subplots)``).
    sharex, sharey : bool, optional
        Share axes between subplots.
    legend : bool, optional
        Draw a legend on each subplot.
    grid : bool, optional
        Draw a light grid on each subplot.
    figsize : tuple of float, optional
        Figure size.
    save : str or Path, optional
        Output directory or file path.

    Returns
    -------
    matplotlib.figure.Figure or None
        The created figure, or ``None`` (with a printed notice) when
        ``subplots`` is empty.

    Raises
    ------
    ValueError
        If a per-subplot title/label sequence has a length other than 1 or
        ``len(subplots)``.
    """
    nplots = len(subplots)
    if nplots == 0:
        print("multi_subplots: no subplot data provided.")
        return None  # type: ignore[return-value]

    def _normalize_seq(
        val: Optional[Union[str, Sequence[Optional[str]]]],
        n: int,
    ) -> list[Optional[str]]:
        """Turn val into a list of length n.

        - None → [None] * n
        - str → [str] * n
        - list/tuple:
            * len == 1 → repeat for all
            * len == n → use as-is
            * otherwise → error
        """
        if val is None:
            return [None] * n
        if isinstance(val, (list, tuple)):
            if len(val) == 1:
                return [val[0]] * n
            if len(val) != n:
                raise ValueError(
                    f"Expected sequence of length 1 or {n}, got length {len(val)}"
                )
            return list(val)
        # single string
        return [val] * n

    # Handle global vs per-subplot title
    if isinstance(title, (list, tuple)):
        per_titles = _normalize_seq(title, nplots)
        global_title = None
    else:
        per_titles = [None] * nplots
        global_title = title

    # Normalize before creating the figure so errors do not leak a figure.
    xlabels = _normalize_seq(xlabel, nplots)
    ylabels = _normalize_seq(ylabel, nplots)

    fig, axes = plt.subplots(
        nplots,
        1,
        figsize=figsize,
        sharex=sharex,
        sharey=sharey,
        squeeze=False,
    )

    style_keys = ("linewidth", "marker", "markersize", "alpha")

    for ax, specs, sub_title, sub_xlabel, sub_ylabel in zip(
        axes.flat, subplots, per_titles, xlabels, ylabels
    ):
        for spec in specs:
            sx = spec.get("x")
            sy = spec.get("y")
            if sx is None or sy is None:
                continue
            style = {k: spec[k] for k in style_keys if k in spec}
            ax.plot(sx, sy, label=spec.get("label"), **style)

        # Per-subplot labels/titles
        if sub_ylabel:
            ax.set_ylabel(sub_ylabel)
        if sub_xlabel:
            ax.set_xlabel(sub_xlabel)
        if sub_title:
            ax.set_title(sub_title)

        if grid:
            ax.grid(True, alpha=0.3)
        if legend:
            ax.legend(fontsize=9)

    # Global suptitle if provided as a single string
    if global_title:
        fig.suptitle(global_title, fontsize=14, y=0.98)

    fig.tight_layout()
    _save_or_show(fig, save, global_title or "multi_subplots")
    return fig
|
|
455
|
+
|
|
456
|
+
|
|
457
|
+
|
|
458
|
+
|
|
459
|
+
def tornado_plot(
    labels: Sequence[str],
    min_vals: Sequence[float],
    max_vals: Sequence[float],
    *,
    median_vals: Optional[Sequence[float]] = None,
    title: str = "Tornado Plot",
    xlabel: str = "Value",
    ylabel: str = "Value",
    save: Optional[Union[str, Path]] = None,
    top: int = 0,
    vline: Optional[float] = None,
    left_color="#1F77B4",
    right_color=(225 / 255, 113 / 255, 29 / 255)  # RGB tuple
) -> None:
    """
    Create a tornado plot to visualize sensitivity or uncertainty ranges.

    Each label is represented by a horizontal bar spanning from a minimum
    to a maximum value. Bars are ordered by the magnitude of their span
    (largest at the top), highlighting the most influential parameters.
    Ties keep their input order. Swapped bounds (``min > max``) are
    tolerated both when ranking and when drawing.

    Parameters
    ----------
    labels : sequence of str
        Parameter or variable names.
    min_vals, max_vals : sequence of float
        Lower and upper bounds for each parameter.
    median_vals : sequence of float, optional
        Optional central estimates, drawn as a black asterisk on each bar.
    title, xlabel, ylabel : str, optional
        Title and axis labels; the title (or ``"tornado_plot"``) is also
        the save file name.
    save : str or Path, optional
        Output directory or file path.
    top : int, optional
        Keep only the ``top`` widest bars when positive; ``0`` keeps all.
    vline : float, optional
        Reference value. Without it, bars are drawn gray. With it, the
        part of each bar left of ``vline`` uses ``left_color`` and the part
        right of it uses ``right_color``, and a dashed line marks ``vline``.
    left_color, right_color : color, optional
        Colors for the two sides of ``vline``.

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If ``median_vals`` does not have the same length as ``labels``
        (pandas also raises if labels/min/max lengths disagree).
    """
    df = pd.DataFrame({"label": labels, "min": min_vals, "max": max_vals})
    if median_vals is not None:
        if len(median_vals) != len(labels):
            raise ValueError("median_vals must be the same length as labels/min_vals/max_vals")
        df["median"] = list(median_vals)

    # Rank by |max - min| so rows with swapped bounds are ordered the same
    # way they are drawn below; a stable sort keeps ties in input order.
    df["span"] = (df["max"] - df["min"]).abs()
    df = df.sort_values("span", ascending=False, kind="mergesort").reset_index(drop=True)
    if top and top > 0:
        df = df.head(top)

    if df.empty:
        fig, ax = plt.subplots(figsize=(6, 3))
        ax.text(0.5, 0.5, "No data to plot", ha="center", va="center")
        ax.axis("off")
        _save_or_show(fig, save, filename=title or "tornado_plot")
        return

    # Grow the figure with the number of bars.
    fig_height = max(3.2, 0.38 * len(df))
    fig, ax = plt.subplots(figsize=(8.6, fig_height))

    y_positions = np.arange(len(df))
    bar_height = 0.5
    edge_common = dict(edgecolor="black", linewidth=0.5)

    for y, row in zip(y_positions, df.itertuples(index=False)):
        if row.max < row.min:  # guard against swapped inputs
            left, right = row.max, row.min
        else:
            left, right = row.min, row.max

        if vline is None:
            # no reference: draw a single neutral bar
            ax.barh(y=y, width=right - left, left=left, height=bar_height,
                    color="tab:gray", alpha=0.7, **edge_common)
        elif right <= vline:
            # fully left of vline
            ax.barh(y=y, width=right - left, left=left, height=bar_height,
                    color=left_color, alpha=0.8, **edge_common)
        elif left >= vline:
            # fully right of vline
            ax.barh(y=y, width=right - left, left=left, height=bar_height,
                    color=right_color, alpha=0.8, **edge_common)
        else:
            # straddles vline -> split into two bars
            ax.barh(y=y, width=vline - left, left=left, height=bar_height,
                    color=left_color, alpha=0.8, **edge_common)
            ax.barh(y=y, width=right - vline, left=vline, height=bar_height,
                    color=right_color, alpha=0.8, **edge_common)

        # median marker as asterisk
        if "median" in df.columns and not pd.isna(row.median):
            ax.plot(row.median, y, marker="*", markersize=5, color='black', zorder=5)

    if vline is not None:
        ax.axvline(vline, linestyle="--", linewidth=1, color=(66/255, 196/255, 127/255))

    ax.set_yticks(list(y_positions))
    ax.set_yticklabels(df["label"])
    ax.invert_yaxis()  # widest bar at the top
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.grid(axis="x", linestyle=":", alpha=0.4)
    fig.tight_layout()

    _save_or_show(fig, save, filename=title or "tornado_plot")
|
|
574
|
+
|
|
575
|
+
|
|
576
|
+
|
|
577
|
+
|
|
578
|
+
def scatter3d_points(
    coords: np.ndarray,  # (N, 3) array of XYZ
    values: np.ndarray,  # (N,) array, e.g., partial charges
    *,
    title: str = "atoms (3D)",
    s: float = 8.0,
    alpha: float = 0.9,
    vmin: Optional[float] = None,
    vmax: Optional[float] = None,
    cmap: str = "coolwarm",
    figsize: Tuple[float, float] = (7.5, 6.0),
    elev: float = 22.0,
    azim: float = 38.0,
    save: Optional[Union[str, Path]] = None,  # dir or full path w/ extension
    show_colorbar: bool = True,
    show_message: bool = True,
    colorbar_label: str = "value",
) -> plt.Figure:
    """
    Render a 3D scatter plot of atomic coordinates.

    Points are colored by a scalar per-atom property such as partial
    charge or bond-order sum, enabling spatial visualization of
    atom-resolved quantities. Axes are labelled ``x/y/z (Å)``.

    Parameters
    ----------
    coords : array-like, shape (N, 3)
        Atomic coordinates.
    values : array-like, shape (N,)
        Scalar values mapped to colors.
    title : str, optional
        Axes title; also the save file name (``"scatter3d_points"`` if
        empty).
    s, alpha : float, optional
        Marker size and opacity.
    vmin, vmax : float, optional
        Color-scale limits (auto when ``None``).
    cmap : str, optional
        Colormap name.
    figsize : tuple of float, optional
        Figure size.
    elev, azim : float, optional
        Camera elevation and azimuth in degrees (``Axes3D.view_init``).
    save : str or Path, optional
        Output directory or file path.
    show_colorbar : bool, optional
        Whether to draw a colorbar.
    show_message : bool, optional
        Print the output path after saving.
    colorbar_label : str, optional
        Colorbar label (default ``"value"``).

    Returns
    -------
    matplotlib.figure.Figure
        The created figure.

    Raises
    ------
    ValueError
        If ``coords`` is not (N, 3) or ``values`` is not (N,).
    """
    coords = np.asarray(coords, float)
    values = np.asarray(values, float)
    # Explicit checks instead of ``assert`` so validation survives ``python -O``.
    if coords.ndim != 2 or coords.shape[1] != 3:
        raise ValueError("coords must be (N,3)")
    if values.ndim != 1 or len(values) != len(coords):
        raise ValueError("values must be (N,)")

    fig = plt.figure(figsize=figsize)
    ax = fig.add_subplot(111, projection="3d")

    sc = ax.scatter(
        coords[:, 0], coords[:, 1], coords[:, 2],
        c=values, cmap=cmap, vmin=vmin, vmax=vmax,
        s=s, alpha=alpha, depthshade=True
    )
    if show_colorbar:
        cb = fig.colorbar(sc, ax=ax, shrink=0.75, pad=0.02)
        cb.set_label(colorbar_label)

    ax.set_xlabel("x (Å)")
    ax.set_ylabel("y (Å)")
    ax.set_zlabel("z (Å)")
    ax.set_title(title)
    ax.view_init(elev=elev, azim=azim)

    # Fall back to a stable name (like sibling plots) so an empty title
    # does not produce a hidden ".png" file.
    _save_or_show(fig, save, filename=title or "scatter3d_points", show_message=show_message)

    return fig
|
|
642
|
+
|
|
643
|
+
|
|
644
|
+
def heatmap2d_from_3d(
    coords: np.ndarray,  # (N,3) array of XYZ
    values: np.ndarray,  # (N,) scalar to aggregate (e.g., partial_charge)
    *,
    plane: str = "xy",  # "xy", "xz", or "yz"
    bins: Union[int, Tuple[int, int]] = 50,  # grid resolution; int or (nx, ny)
    xlim: Optional[Tuple[float, float]] = None,
    ylim: Optional[Tuple[float, float]] = None,
    agg: Union[str, Callable[[np.ndarray], float]] = "mean",  # "mean","max","min","sum","count" or a callable
    vmin: Optional[float] = None,
    vmax: Optional[float] = None,
    cmap: str = "viridis",
    title: str = "2D aggregated heatmap",
    figsize: Tuple[float, float] = (6.5, 5.5),
    save: Optional[Union[str, Path]] = None,  # dir or full path; if dir, saves PNG
    show_colorbar: bool = True,
    show_message: bool = True,
):
    """
    Project 3D point data onto a 2D plane and aggregate values on a grid.

    This function bins projected coordinates onto a regular grid and
    aggregates per-point values using a specified operation, producing
    a 2D heatmap suitable for planar analysis.

    Parameters
    ----------
    coords : array-like, shape (N, 3)
        3D coordinates.
    values : array-like, shape (N,)
        Values to aggregate.
    plane : {'xy', 'xz', 'yz'}, optional
        Projection plane. The first letter is the horizontal plot axis,
        the second letter the vertical one.
    bins : int or (int, int), optional
        Number of grid cells. A single int (Python or NumPy) applies to
        both axes; a pair is ``(nx, ny)`` for horizontal / vertical.
    xlim, ylim : (float, float), optional
        Grid range along the horizontal / vertical plot axis. Defaults to
        the range of the finite projected coordinates. Points outside the
        range are excluded (not piled into the border cells). A zero-width
        range is widened by 0.5 on each side.
    agg : {'mean', 'max', 'min', 'sum', 'count'} or callable, optional
        Aggregation method (string names are case-insensitive). A callable
        receives the 1D float array of the values in one cell; if it
        raises, that cell is left NaN.
    vmin, vmax : float, optional
        Color-scale limits; default to the min/max of the finite grid cells.
    cmap : str, optional
        Matplotlib colormap name.
    title : str, optional
        Plot title; with spaces replaced by underscores it is also passed
        to the save/show helper as the output name.
    figsize : (float, float), optional
        Figure size in inches.
    save : str or Path, optional
        Output directory or file path.
    show_colorbar : bool, optional
        Whether to draw a colorbar.
    show_message : bool, optional
        Forwarded to the save/show helper.

    Returns
    -------
    fig : matplotlib.figure.Figure
    grid : numpy.ndarray, shape (ny, nx)
        Aggregated 2D grid. Cells without points are NaN, except for
        ``agg='count'`` where they are 0.
    xedges, yedges : numpy.ndarray
        Bin edges along each axis.

    Raises
    ------
    ValueError
        If ``coords``/``values`` have the wrong shape, ``plane`` or ``agg``
        is not recognised, or ``bins`` is not positive.
    """
    coords = np.asarray(coords, float)
    values = np.asarray(values, float)
    # Raise instead of ``assert`` so the checks survive ``python -O``.
    if coords.ndim != 2 or coords.shape[1] != 3:
        raise ValueError("coords must be (N,3)")
    if values.ndim != 1 or len(values) != len(coords):
        raise ValueError("values must be (N,)")

    # Choose plane projection: column indices of the (horizontal, vertical) axes
    plane_axes = {"xy": (0, 1), "xz": (0, 2), "yz": (1, 2)}
    if not isinstance(plane, str) or plane not in plane_axes:
        raise ValueError("plane must be one of {'xy','xz','yz'}")
    iu, iv = plane_axes[plane]
    u, v = coords[:, iu], coords[:, iv]
    xlabel, ylabel = f"{plane[0]} (Å)", f"{plane[1]} (Å)"

    # Grid resolution (np.ndim accepts Python ints and NumPy integer scalars)
    if np.ndim(bins) == 0:
        nx = ny = int(bins)
    else:
        nx, ny = int(bins[0]), int(bins[1])
    if nx < 1 or ny < 1:
        raise ValueError("bins must be a positive int or a pair of positive ints")

    def _axis_span(data, lim):
        """Return (lo, hi) for one axis: explicit limits or the finite data range."""
        if lim is not None:
            lo, hi = float(lim[0]), float(lim[1])
        else:
            finite = data[np.isfinite(data)]
            lo, hi = (float(finite.min()), float(finite.max())) if finite.size else (0.0, 1.0)
        if lo == hi:
            # Zero-width span (single point, flat slab): widen it so the bin
            # edges are strictly monotonic and imshow gets a non-singular extent.
            lo, hi = lo - 0.5, hi + 0.5
        return lo, hi

    # Ranges
    umin, umax = _axis_span(u, xlim)
    vmin_edge, vmax_edge = _axis_span(v, ylim)

    # Bin edges
    xedges = np.linspace(umin, umax, nx + 1)
    yedges = np.linspace(vmin_edge, vmax_edge, ny + 1)

    # Keep only points inside the grid (this also drops non-finite coordinates).
    # Without this, the clip below would pile out-of-range points into the
    # border cells whenever xlim/ylim is narrower than the data.
    inside = (
        (u >= min(umin, umax)) & (u <= max(umin, umax))
        & (v >= min(vmin_edge, vmax_edge)) & (v <= max(vmin_edge, vmax_edge))
    )
    u, v, vals = u[inside], v[inside], values[inside]

    # Digitize points to cell indices.
    # Points exactly on the right/top edge go to the previous (last) bin.
    ui = np.clip(np.digitize(u, xedges) - 1, 0, nx - 1)
    vi = np.clip(np.digitize(v, yedges) - 1, 0, ny - 1)

    # Flattened row-major (ny, nx) cell id for bincount / ufunc.at tricks
    flat_idx = vi * nx + ui
    n_cells = nx * ny

    if isinstance(agg, str):
        agg_lower = agg.lower()
        if agg_lower == "count":
            flat_grid = np.bincount(flat_idx, minlength=n_cells).astype(float)
        elif agg_lower in {"sum", "mean"}:
            sums = np.bincount(flat_idx, weights=vals, minlength=n_cells)
            counts = np.bincount(flat_idx, minlength=n_cells)
            occupied = counts > 0
            # Empty cells stay NaN; no 0/0 division is ever performed.
            flat_grid = np.full(n_cells, np.nan)
            flat_grid[occupied] = sums[occupied]
            if agg_lower == "mean":
                flat_grid[occupied] /= counts[occupied]
        elif agg_lower in {"max", "min"}:
            # Vectorized one-pass update. fmax/fmin ignore NaN values, the same
            # as a "val > current" comparison that never accepts NaN.
            sentinel = -np.inf if agg_lower == "max" else np.inf
            flat_grid = np.full(n_cells, sentinel)
            reducer = np.fmax if agg_lower == "max" else np.fmin
            reducer.at(flat_grid, flat_idx, vals)
            # Cells never updated still hold the sentinel -> mark as empty
            flat_grid[flat_grid == sentinel] = np.nan
        else:
            raise ValueError("agg must be one of {'mean','max','min','sum','count'} or a callable")
    elif callable(agg):
        # Group values per occupied cell (input order preserved), then apply agg.
        buckets = {}
        for idx, val in zip(flat_idx.tolist(), vals.tolist()):
            buckets.setdefault(idx, []).append(val)
        flat_grid = np.full(n_cells, np.nan)
        for idx, bucket in buckets.items():
            try:
                flat_grid[idx] = agg(np.asarray(bucket, float))
            except Exception:
                # Deliberate best-effort: a cell the aggregator cannot handle stays NaN
                flat_grid[idx] = np.nan
    else:
        raise ValueError("agg must be one of {'mean','max','min','sum','count'} or a callable")

    grid = flat_grid.reshape(ny, nx)

    # Color scale defaults from finite data if not provided; an all-empty grid
    # leaves the limits to matplotlib instead of passing NaN.
    finite_vals = grid[np.isfinite(grid)]
    has_data = finite_vals.size > 0
    cmin = vmin if vmin is not None else (float(finite_vals.min()) if has_data else None)
    cmax = vmax if vmax is not None else (float(finite_vals.max()) if has_data else None)

    # Plot
    fig, ax = plt.subplots(figsize=figsize)
    # extent maps grid cells to physical coordinates
    extent = (xedges[0], xedges[-1], yedges[0], yedges[-1])
    im = ax.imshow(
        grid,
        origin="lower",
        extent=extent,
        aspect="auto",
        vmin=cmin,
        vmax=cmax,
        cmap=cmap,
        interpolation="nearest",
    )
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)

    if show_colorbar:
        cb = fig.colorbar(im, ax=ax, pad=0.02)
        cb.set_label(f"{agg if isinstance(agg,str) else 'agg'} of values")

    _save_or_show(fig, save, title.replace(" ", "_"), show_message=show_message)

    return fig, grid, xedges, yedges
|