pychnosz 1.1.12__cp310-cp310-macosx_15_0_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pychnosz/.dylibs/libgcc_s.1.1.dylib +0 -0
- pychnosz/.dylibs/libgfortran.5.dylib +0 -0
- pychnosz/.dylibs/libquadmath.0.dylib +0 -0
- pychnosz/__init__.py +129 -0
- pychnosz/_version.py +34 -0
- pychnosz/biomolecules/__init__.py +29 -0
- pychnosz/biomolecules/ionize_aa.py +197 -0
- pychnosz/biomolecules/proteins.py +595 -0
- pychnosz/core/__init__.py +46 -0
- pychnosz/core/affinity.py +1256 -0
- pychnosz/core/animation.py +593 -0
- pychnosz/core/balance.py +334 -0
- pychnosz/core/basis.py +716 -0
- pychnosz/core/diagram.py +3336 -0
- pychnosz/core/equilibrate.py +813 -0
- pychnosz/core/equilibrium.py +554 -0
- pychnosz/core/info.py +821 -0
- pychnosz/core/retrieve.py +364 -0
- pychnosz/core/speciation.py +580 -0
- pychnosz/core/species.py +599 -0
- pychnosz/core/subcrt.py +1696 -0
- pychnosz/core/thermo.py +593 -0
- pychnosz/core/unicurve.py +1226 -0
- pychnosz/data/__init__.py +11 -0
- pychnosz/data/add_obigt.py +327 -0
- pychnosz/data/extdata/Berman/BDat17_2017.csv +2 -0
- pychnosz/data/extdata/Berman/Ber88_1988.csv +68 -0
- pychnosz/data/extdata/Berman/Ber90_1990.csv +5 -0
- pychnosz/data/extdata/Berman/DS10_2010.csv +6 -0
- pychnosz/data/extdata/Berman/FDM+14_2014.csv +2 -0
- pychnosz/data/extdata/Berman/Got04_2004.csv +5 -0
- pychnosz/data/extdata/Berman/JUN92_1992.csv +3 -0
- pychnosz/data/extdata/Berman/SHD91_1991.csv +12 -0
- pychnosz/data/extdata/Berman/VGT92_1992.csv +2 -0
- pychnosz/data/extdata/Berman/VPT01_2001.csv +3 -0
- pychnosz/data/extdata/Berman/VPV05_2005.csv +2 -0
- pychnosz/data/extdata/Berman/ZS92_1992.csv +11 -0
- pychnosz/data/extdata/Berman/sympy.R +99 -0
- pychnosz/data/extdata/Berman/testing/BA96.bib +12 -0
- pychnosz/data/extdata/Berman/testing/BA96_Berman.csv +21 -0
- pychnosz/data/extdata/Berman/testing/BA96_OBIGT.csv +21 -0
- pychnosz/data/extdata/Berman/testing/BA96_refs.csv +6 -0
- pychnosz/data/extdata/OBIGT/AD.csv +25 -0
- pychnosz/data/extdata/OBIGT/Berman_cr.csv +93 -0
- pychnosz/data/extdata/OBIGT/DEW.csv +211 -0
- pychnosz/data/extdata/OBIGT/H2O_aq.csv +4 -0
- pychnosz/data/extdata/OBIGT/SLOP98.csv +411 -0
- pychnosz/data/extdata/OBIGT/SUPCRT92.csv +178 -0
- pychnosz/data/extdata/OBIGT/inorganic_aq.csv +729 -0
- pychnosz/data/extdata/OBIGT/inorganic_cr.csv +273 -0
- pychnosz/data/extdata/OBIGT/inorganic_gas.csv +20 -0
- pychnosz/data/extdata/OBIGT/organic_aq.csv +1104 -0
- pychnosz/data/extdata/OBIGT/organic_cr.csv +481 -0
- pychnosz/data/extdata/OBIGT/organic_gas.csv +268 -0
- pychnosz/data/extdata/OBIGT/organic_liq.csv +533 -0
- pychnosz/data/extdata/OBIGT/testing/GEMSFIT.csv +43 -0
- pychnosz/data/extdata/OBIGT/testing/IGEM.csv +17 -0
- pychnosz/data/extdata/OBIGT/testing/Sandia.csv +8 -0
- pychnosz/data/extdata/OBIGT/testing/SiO2.csv +4 -0
- pychnosz/data/extdata/misc/AD03_Fig1a.csv +69 -0
- pychnosz/data/extdata/misc/AD03_Fig1b.csv +43 -0
- pychnosz/data/extdata/misc/AD03_Fig1c.csv +89 -0
- pychnosz/data/extdata/misc/AD03_Fig1d.csv +30 -0
- pychnosz/data/extdata/misc/BZA10.csv +5 -0
- pychnosz/data/extdata/misc/HW97_Cp.csv +90 -0
- pychnosz/data/extdata/misc/HWM96_V.csv +229 -0
- pychnosz/data/extdata/misc/LA19_test.csv +7 -0
- pychnosz/data/extdata/misc/Mer75_Table4.csv +42 -0
- pychnosz/data/extdata/misc/OBIGT_check.csv +423 -0
- pychnosz/data/extdata/misc/PM90.csv +7 -0
- pychnosz/data/extdata/misc/RH95.csv +23 -0
- pychnosz/data/extdata/misc/RH98_Table15.csv +17 -0
- pychnosz/data/extdata/misc/SC10_Rainbow.csv +19 -0
- pychnosz/data/extdata/misc/SK95.csv +55 -0
- pychnosz/data/extdata/misc/SOJSH.csv +61 -0
- pychnosz/data/extdata/misc/SS98_Fig5a.csv +81 -0
- pychnosz/data/extdata/misc/SS98_Fig5b.csv +84 -0
- pychnosz/data/extdata/misc/TKSS14_Fig2.csv +25 -0
- pychnosz/data/extdata/misc/bluered.txt +1000 -0
- pychnosz/data/extdata/protein/Cas/Cas_aa.csv +177 -0
- pychnosz/data/extdata/protein/Cas/Cas_uniprot.csv +186 -0
- pychnosz/data/extdata/protein/Cas/download.R +34 -0
- pychnosz/data/extdata/protein/Cas/mkaa.R +34 -0
- pychnosz/data/extdata/protein/POLG.csv +12 -0
- pychnosz/data/extdata/protein/TBD+05.csv +393 -0
- pychnosz/data/extdata/protein/TBD+05_aa.csv +393 -0
- pychnosz/data/extdata/protein/rubisco.csv +28 -0
- pychnosz/data/extdata/protein/rubisco.fasta +239 -0
- pychnosz/data/extdata/protein/rubisco_aa.csv +28 -0
- pychnosz/data/extdata/src/H2O92D.f.orig +3457 -0
- pychnosz/data/extdata/src/README.txt +5 -0
- pychnosz/data/extdata/taxonomy/names.dmp +215 -0
- pychnosz/data/extdata/taxonomy/nodes.dmp +63 -0
- pychnosz/data/extdata/thermo/Bdot_acirc.csv +60 -0
- pychnosz/data/extdata/thermo/buffer.csv +40 -0
- pychnosz/data/extdata/thermo/element.csv +135 -0
- pychnosz/data/extdata/thermo/groups.csv +6 -0
- pychnosz/data/extdata/thermo/opt.csv +2 -0
- pychnosz/data/extdata/thermo/protein.csv +506 -0
- pychnosz/data/extdata/thermo/refs.csv +343 -0
- pychnosz/data/extdata/thermo/stoich.csv.xz +0 -0
- pychnosz/data/loader.py +431 -0
- pychnosz/data/mod_obigt.py +322 -0
- pychnosz/data/obigt.py +471 -0
- pychnosz/data/worm.py +228 -0
- pychnosz/fortran/.gitignore +6 -0
- pychnosz/fortran/__init__.py +16 -0
- pychnosz/fortran/h2o92.dylib +0 -0
- pychnosz/fortran/h2o92_interface.py +527 -0
- pychnosz/geochemistry/__init__.py +21 -0
- pychnosz/geochemistry/minerals.py +514 -0
- pychnosz/geochemistry/redox.py +500 -0
- pychnosz/models/__init__.py +47 -0
- pychnosz/models/archer_wang.py +165 -0
- pychnosz/models/berman.py +309 -0
- pychnosz/models/cgl.py +381 -0
- pychnosz/models/dew.py +997 -0
- pychnosz/models/hkf.py +523 -0
- pychnosz/models/hkf_helpers.py +231 -0
- pychnosz/models/iapws95.py +1113 -0
- pychnosz/models/supcrt92_fortran.py +238 -0
- pychnosz/models/water.py +480 -0
- pychnosz/utils/__init__.py +27 -0
- pychnosz/utils/expression.py +1074 -0
- pychnosz/utils/formula.py +830 -0
- pychnosz/utils/formula_ox.py +227 -0
- pychnosz/utils/reset.py +33 -0
- pychnosz/utils/units.py +259 -0
- pychnosz-1.1.12.dist-info/METADATA +197 -0
- pychnosz-1.1.12.dist-info/RECORD +133 -0
- pychnosz-1.1.12.dist-info/WHEEL +5 -0
- pychnosz-1.1.12.dist-info/licenses/LICENSE.txt +19 -0
- pychnosz-1.1.12.dist-info/top_level.txt +1 -0
pychnosz/core/diagram.py
ADDED
|
@@ -0,0 +1,3336 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Diagram module for plotting chemical activity and predominance diagrams.
|
|
3
|
+
|
|
4
|
+
This module provides Python equivalents of the R functions in diagram.R:
|
|
5
|
+
- diagram(): Plot equilibrium chemical activity and predominance diagrams
|
|
6
|
+
- Supporting utilities for 1D line plots and 2D predominance diagrams
|
|
7
|
+
|
|
8
|
+
Author: CHNOSZ Python port
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
import numpy as np
|
|
12
|
+
import pandas as pd
|
|
13
|
+
import matplotlib.pyplot as plt
|
|
14
|
+
from typing import Union, List, Optional, Dict, Any, Tuple
|
|
15
|
+
import warnings
|
|
16
|
+
import copy
|
|
17
|
+
from ..utils.expression import _format_species_latex
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def copy_plot(diagram_result: Dict[str, Any]) -> Dict[str, Any]:
    """
    Return an independent deep copy of a diagram result dictionary.

    The dictionaries returned by diagram() can hold mutable plotting objects
    under the 'fig' and 'ax' keys. Handing the same dictionary to several
    pieces of code means they all draw on the same objects; copying it first
    gives each consumer its own set to modify.

    Parameters
    ----------
    diagram_result : dict
        Result dictionary from diagram(), which may contain 'fig' and 'ax' keys

    Returns
    -------
    dict
        A new dictionary produced by copy.deepcopy(); nested containers and
        objects are copied recursively rather than shared with the input

    Examples
    --------
    Manual copying workflow (diagram()'s add_to parameter performs this copy
    for you and is usually more convenient):

    >>> basis(['SiO2', 'Ca+2', 'Mg+2', 'CO2', 'H2O', 'O2', 'H+'])
    >>> species(['quartz', 'talc', 'chrysotile', 'forsterite'])
    >>> a = affinity(**{'Mg+2': [4, 10, 500], 'Ca+2': [5, 15, 500]})
    >>> plot_a = diagram(a, fill='terrain')
    >>>
    >>> # One copy per variation; then draw on plot_a1['ax'] / plot_a2['ax']
    >>> plot_a1 = copy_plot(plot_a)
    >>> plot_a2 = copy_plot(plot_a)
    >>>
    >>> # Equivalent workflow through add_to
    >>> basis('CO2', -1)
    >>> species(['calcite', 'dolomite'])
    >>> a2 = affinity(**{'Mg+2': [4, 10, 500], 'Ca+2': [5, 15, 500]})
    >>> plot_a1 = diagram(a2, type='saturation', add_to=plot_a, col='blue')
    >>> plot_a2 = diagram(a2, type='saturation', add_to=plot_a, col='red')

    Notes
    -----
    - The copy is made with copy.deepcopy(), so its cost grows with the size
      of the stored figure and data
    - Whether every object stored in the dictionary (for example an
      interactive Plotly figure) deep-copies faithfully is not checked here;
      test before relying on it

    See Also
    --------
    diagram : Create plots that can be copied with this function
    """
    independent_result = copy.deepcopy(diagram_result)
    return independent_result
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
def diagram(eout: Dict[str, Any],
            type: str = "auto",
            alpha: Union[bool, str] = False,
            balance: Optional[Union[str, float, List[float]]] = None,
            names: Optional[List[str]] = None,
            format_names: bool = True,
            xlab: Optional[str] = None,
            ylab: Optional[str] = None,
            xlim: Optional[List[float]] = None,
            ylim: Optional[List[float]] = None,
            col: Optional[Union[str, List[str]]] = None,
            col_names: Optional[Union[str, List[str]]] = None,
            lty: Optional[Union[str, int, List]] = None,
            lwd: Union[float, List[float]] = 1,
            cex: Union[float, List[float]] = 1.0,
            main: Optional[str] = None,
            fill: Optional[str] = None,
            fill_NA: str = "0.8",
            limit_water: Optional[bool] = None,
            plot_it: bool = True,
            add_to: Optional[Dict[str, Any]] = None,
            contour_method: Optional[Union[str, List[str]]] = "edge",
            messages: bool = True,
            interactive: bool = False,
            annotation: Optional[str] = None,
            annotation_coords: Optional[List[float]] = None,
            width: int = 600,
            height: int = 520,
            save_as: Optional[str] = None,
            save_format: Optional[str] = None,
            save_scale: float = 1,
            normalize: Union[bool, List[bool]] = False,
            as_residue: bool = False,
            **kwargs) -> Dict[str, Any]:
    """
    Plot equilibrium chemical activity and predominance diagrams.

    This function creates plots from the output of affinity() or equilibrate().
    For 1D diagrams, it produces line plots showing how affinity or activity
    varies with a single variable. For 2D diagrams, it creates predominance
    field diagrams.

    Parameters
    ----------
    eout : dict
        Output from affinity(), equilibrate(), or solubility(); its 'fun'
        entry must be one of those three names
    type : str, default "auto"
        Type of diagram:
        - "auto" (default): Plot affinity values (A/2.303RT)
        - "loga.equil": Plot equilibrium activities from equilibrate()
        - "saturation": Draw affinity=0 contour lines (mineral saturation)
        - Basis species name (e.g., "O2", "H2O", "CO2"): Plot equilibrium
          log activity/fugacity of the specified basis species where affinity=0
          for each formed species. Useful for Eh-pH diagrams and showing
          oxygen/water fugacities at equilibrium.
    alpha : bool or str, default False
        Plot degree of formation instead of activities?
        If "balance", scale by balancing coefficients
    balance : str, float, or list of float, optional
        Balancing coefficients or method for balancing reactions
    names : list of str, optional
        Custom names for species (for labels)
    format_names : bool, default True
        Apply formatting to chemical formulas? Formatting is skipped when
        alpha is set.
    xlab : str, optional
        Custom x-axis label
    ylab : str, optional
        Custom y-axis label
    xlim : list of float, optional
        X-axis limits [min, max]
    ylim : list of float, optional
        Y-axis limits [min, max]
    col : str or list of str, optional
        Line colors for 1-D plots and boundary lines in 2-D plots (matplotlib color specs)
    col_names : str or list of str, optional
        Text colors for field labels in 2-D plots (matplotlib color specs)
    lty : str, int, or list, optional
        Line styles (matplotlib linestyle specs)
    lwd : float or list of float, default 1
        Line widths for 1-D plots and boundary lines in 2-D predominance
        diagrams. Set to 0 to disable borders in 2-D diagrams. If fill is
        None and lwd > 0, uses white fill with black borders (R CHNOSZ default).
    cex : float or list of float, default 1.0
        Character expansion factor for text labels. Values > 1 make text larger,
        values < 1 make text smaller. Can be a single value or a list (one per species).
        Used for contour labels in type="saturation" plots.
    main : str, optional
        Plot title
    fill : str, optional
        Color palette for 2-D predominance diagrams. Can be any matplotlib
        colormap name (e.g., 'viridis', 'plasma', 'terrain', 'rainbow',
        'Set1', 'tab10', 'Pastel1'). If None, uses discrete colors from
        the default color cycle. Ignored for 1-D diagrams.
    fill_NA : str, default "0.8"
        Color for regions outside water stability limits (water instability regions).
        Matplotlib color specification (e.g., "0.8" for gray, "#CCCCCC").
        Set to "transparent" to disable shading. Default "0.8" matches R's "gray80".
    limit_water : bool, optional
        Whether to show water stability limits as shaded regions (default True for
        2-D diagrams). If True, also clips the diagram to the water stability region.
        Set to False to disable water stability shading.
    plot_it : bool, default True
        Display the plot? When False, the figure is closed before returning
        (it is still available under the 'fig' key of the result).
    add_to : dict, optional
        A diagram result dictionary from a previous diagram() call. When provided,
        this plot will be AUTOMATICALLY COPIED and the new diagram will be added to
        the copy. This preserves the original plot while creating a modified version.
        The axes object is extracted from add_to['ax'].

        This parameter eliminates the need for a separate 'add' boolean - when
        add_to is provided, the function automatically operates in "add" mode.

        Example workflow:
        >>> plot_a = diagram(affinity1, fill='terrain')  # Create base plot
        >>> plot_a1 = diagram(affinity2, add_to=plot_a, col='blue')  # Copy and add
        >>> plot_a2 = diagram(affinity3, add_to=plot_a, col='red')  # Copy and add again
        >>> # plot_a remains unchanged, plot_a1 and plot_a2 are independent modifications
    contour_method : str or list of str, optional
        Method for labeling contour lines. Default "edge" labels at plot edges.
        Can be a single value (applied to all species) or a list (one per species).
        Set to None, NA, or "" to disable labels (only for type="saturation").
        In R CHNOSZ, different methods like "edge", "flattest", "simple" control
        label placement; in Python, this mainly controls whether labels are shown.
    messages : bool, default True
        Print informational messages (e.g. which quantity is being plotted)?
        Also forwarded to the balancing and plotting helpers.
    interactive : bool, default False
        Create an interactive plot using Plotly instead of matplotlib?
        If True, calls diagram_interactive() with the appropriate parameters.
    annotation : str, optional
        For interactive plots only. Annotation text to add to the plot.
    annotation_coords : list of float, optional
        For interactive plots only. Coordinates of annotation, where [0, 0] is
        bottom left and [1, 1] is top right. None (the default) means [0, 0].
    width : int, default 600
        Width of the plot in pixels. Passed to both the interactive and the
        matplotlib plotting helpers.
    height : int, default 520
        Height of the plot in pixels. Passed to both the interactive and the
        matplotlib plotting helpers.
    save_as : str, optional
        For interactive plots only. Provide a filename to save this figure.
        Filetype is determined by `save_format`.
    save_format : str, optional
        For interactive plots only. Desired format of saved or downloaded figure.
        Can be 'png', 'jpg', 'jpeg', 'webp', 'svg', 'pdf', 'eps', 'json', or 'html'.
        If 'html', an interactive plot will be saved.
    save_scale : float, default 1
        For interactive plots only. Multiply title/legend/axis/canvas sizes by
        this factor when saving the figure.
    normalize : bool or list of bool, default False
        Accepted for interface compatibility; not read by this function.
    as_residue : bool, default False
        Accepted for interface compatibility; not read by this function.
    **kwargs
        Additional arguments passed to matplotlib plotting functions

    Returns
    -------
    dict
        For matplotlib plots, the dictionary produced by the 1-D or 2-D
        plotting helper, documented as containing:
        - plotvar : str, Variable that was plotted
        - plotvals : dict, Values that were plotted
        - names : list, Names used for labels
        - predominant : array or NA, Predominance matrix (for 2D)
        - balance : str or list, Balancing coefficients used
        - n.balance : list, Numerical balancing coefficients
        - ax : matplotlib.axes.Axes, The axes object used for plotting (if plot_it=True)
        - fig : matplotlib.figure.Figure, The figure object used for plotting (if plot_it=True)
        - All original eout contents

        For interactive plots, all entries of eout plus 'df' and 'fig' as
        returned by diagram_interactive(), and 'ax' set to the same figure.

    Raises
    ------
    ValueError
        If add_to has no 'ax' key; if eout is not output from affinity(),
        equilibrate(), or solubility(); if a basis-species type is combined
        with alpha; if type="saturation" is used with equilibrate() output;
        if the basis species' logact is not numeric; or if the values have
        more than two dimensions.
    NotImplementedError
        If the values are zero-dimensional (single numbers).

    Examples
    --------
    >>> import pychnosz
    >>> pychnosz.basis(["Fe2O3", "CO2", "H2O", "NH3", "H2S", "oxygen", "H+"],
    ...                [0, -3, 0, -4, -7, -80, -7])
    >>> pychnosz.species(["pyrite", "goethite"])
    >>> a = pychnosz.affinity(H2S=[-60, 20, 5], T=25, P=1)
    >>> d = diagram(a)

    Notes
    -----
    This implementation is based on R CHNOSZ diagram() function but adapted
    for Python's matplotlib plotting instead of R's base graphics. The key
    differences from diagram_from_WORM.py are:
    - Works directly with Python dict output from affinity() (no rpy2)
    - Uses matplotlib for 1D plots by default
    - Can optionally use plotly if requested
    """
    # NOTE: the parameter name `type` shadows the builtin inside this function;
    # it is part of the public signature and therefore kept.

    # Resolve the default here instead of in the signature so that no list
    # object is shared between calls.
    if annotation_coords is None:
        annotation_coords = [0, 0]

    # Handle add_to: providing a previous result switches on "add" mode and
    # the drawing happens on a deep copy so the caller's plot is preserved.
    ax = None
    add = add_to is not None
    plot_was_provided = add

    if add_to is not None:
        # Validate before copying so an unusable argument fails fast.
        if 'ax' not in add_to:
            raise ValueError("The 'add_to' parameter must contain an 'ax' key (a diagram result dictionary)")
        ax = copy_plot(add_to)['ax']

    # If interactive mode is requested, delegate to diagram_interactive
    if interactive:
        df, fig = diagram_interactive(
            eout=eout,
            type=type,
            main=main,
            borders=lwd,
            names=names,
            format_names=format_names,
            annotation=annotation,
            annotation_coords=annotation_coords,
            balance=balance,
            xlab=xlab,
            ylab=ylab,
            fill=fill,
            width=width,
            height=height,
            alpha=alpha,
            plot_it=plot_it,
            add=add,
            ax=ax,
            col=col,
            lty=lty,
            lwd=lwd,
            cex=cex,
            contour_method=contour_method,
            save_as=save_as,
            save_format=save_format,
            save_scale=save_scale,
            messages=messages
        )
        # diagram_interactive returns (df, fig); wrap them in a dict that also
        # carries the eout data (vars, vals, basis, ...) for downstream helpers.
        return {
            **eout,
            'df': df,
            'fig': fig,
            'ax': fig  # For compatibility, store fig in ax key for add_to workflow
        }

    # Check that eout is valid
    efun = eout.get('fun', '')
    if efun not in ['affinity', 'equilibrate', 'solubility']:
        raise ValueError("'eout' is not the output from affinity(), equilibrate(), or solubility()")

    # eout counts as affinity() output when it carries no equilibrium
    # activities under either the Python (loga_equil) or R (loga.equil) name.
    eout_is_aout = 'loga_equil' not in eout and 'loga.equil' not in eout

    # Any type that is not a known keyword may be the name of a basis species
    known_types = ("auto", "saturation", "loga.equil", "loga_equil", "loga.balance", "loga_balance")
    plot_loga_basis = False
    if type not in known_types and 'basis' in eout:
        basis_species = list(eout['basis'].index) if hasattr(eout['basis'], 'index') else []
        if type in basis_species:
            plot_loga_basis = True
            if alpha:
                raise ValueError("equilibrium activities of basis species not available with alpha = TRUE")

    # type="saturation" requires affinity output
    if type == "saturation" and not eout_is_aout:
        raise ValueError("type='saturation' requires output from affinity(), not equilibrate()")

    # Get number of dimensions from the first entry of the values, which may
    # be stored as a dict (affinity), a list (equilibrate) or a bare object.
    values = eout['values']
    if isinstance(values, dict):
        first_values = list(values.values())[0]
    elif isinstance(values, list):
        first_values = values[0]
    else:
        first_values = values

    if hasattr(first_values, 'shape'):
        nd = len(first_values.shape)
    elif hasattr(first_values, '__len__'):
        nd = 1
    else:
        nd = 0  # Single value

    # Fallback balancing coefficients: one per entry, all equal to 1
    unit_balance = [1] * len(values) if isinstance(values, (dict, list)) else [1]

    if eout_is_aout and type == "auto":
        # For affinity output, derive balancing coefficients from the reactions
        n_balance, balance = _get_balance(eout, balance, messages)
    elif eout_is_aout and type == "saturation":
        # Saturation lines are not normalized by stoichiometry
        n_balance = unit_balance
        if balance is None:
            balance = 1
    elif 'n_balance' in eout:
        # Reuse the coefficients stored in eout (e.g. by equilibrate())
        n_balance = eout['n_balance']
        balance = eout.get('balance', 1)
    else:
        n_balance = unit_balance
        if balance is None:
            balance = 1

    # Determine what to plot
    plotvals = {}
    plotvar = eout.get('property', 'A')

    if plot_loga_basis:
        # Equilibrium log activity/fugacity of the chosen basis species
        basis_df = eout['basis']
        ibasis = list(basis_df.index).index(type)

        # Logarithm of activity used in the affinity calculation
        logact = basis_df.iloc[ibasis]['logact']
        try:
            loga_basis = float(logact)
        except (ValueError, TypeError) as err:
            raise ValueError(f"the logarithm of activity for basis species {type} is not numeric - was a buffer selected?") from err

        # Reaction coefficients of this basis species in each formation
        # reaction. NOTE(review): this selects column `ibasis` of the species
        # table by position, i.e. it assumes the basis-species columns come
        # first and in the same order as the rows of eout['basis'] - confirm.
        nu_basis = eout['species'].iloc[:, ibasis].values

        # Activity of the basis species where affinity = 0:
        # loga_equilibrium = loga_basis - affinity / nu_basis
        plotvals = {
            sp_idx: loga_basis - affinity_vals / nu_basis[i]
            for i, (sp_idx, affinity_vals) in enumerate(eout['values'].items())
        }
        plotvar = type

        # n_balance is not used for basis species plots, but the plotting
        # helpers still expect one coefficient per species.
        n_balance = [1] * len(plotvals)
        if balance is None:
            balance = 1
    elif eout_is_aout:
        # Plot affinity values divided by balancing coefficients
        if isinstance(values, dict):
            for i, (species_idx, species_values) in enumerate(values.items()):
                plotvals[species_idx] = species_values / n_balance[i]
        elif isinstance(values, list):
            for i, species_values in enumerate(values):
                species_idx = eout['species']['ispecies'].iloc[i]
                plotvals[species_idx] = species_values / n_balance[i]

        if plotvar == 'A':
            # Rename so the axis label is made correctly
            plotvar = 'A/(2.303RT)'
        if nd == 1 and messages:
            print(f"diagram: plotting {plotvar} / n.balance")
    else:
        # Plot equilibrated activities, stored under the Python (loga_equil)
        # or R (loga.equil) name.
        loga_equil_key = 'loga_equil' if 'loga_equil' in eout else 'loga.equil'
        loga_equil = eout[loga_equil_key]

        if isinstance(loga_equil, list):
            # Key by position, not by ispecies: duplicate species would
            # otherwise overwrite each other and break the 1:1 correspondence
            # with the species list.
            plotvals = dict(enumerate(loga_equil))
        else:
            # Already a dict
            plotvals = loga_equil

        plotvar = 'loga.equil'

    # Handle alpha (degree of formation)
    if alpha:
        # Convert to activities (remove logarithms); np.power works
        # element-wise on arrays and on plain numbers alike.
        act_vals = {k: np.power(10, v) for k, v in plotvals.items()}

        # Scale by balance if requested; coefficients are matched by position
        if alpha == "balance":
            for i, k in enumerate(list(act_vals)):
                act_vals[k] = act_vals[k] * n_balance[i]

        # Total activity; the builtin sum adds arrays element-wise
        sum_act = sum(act_vals.values())

        # Fraction of the total contributed by each species
        plotvals = {k: v / sum_act for k, v in act_vals.items()}
        plotvar = "alpha"

    # Get species information for labels
    species_df = eout['species']
    if names is None:
        names = species_df['name'].tolist()

    # Format chemical names if requested
    if format_names and not alpha:
        names = [_format_chemname(name) for name in names]

    # Dispatch on dimensionality
    if nd == 0:
        # 0-D: Bar plot (not implemented yet)
        raise NotImplementedError("0-D bar plots not yet implemented")
    elif nd == 1:
        # 1-D: Line plot
        result = _plot_1d(eout, plotvals, plotvar, names, n_balance, balance,
                          xlab, ylab, xlim, ylim, col, lty, lwd, main, add, plot_it, ax,
                          width, height, plot_was_provided, **kwargs)
    elif nd == 2:
        # 2-D: Predominance diagram or saturation lines;
        # lty and cex travel as keywords for saturation plots.
        result = _plot_2d(eout, plotvals, plotvar, names, n_balance, balance,
                          xlab, ylab, xlim, ylim, col, col_names, fill, fill_NA, limit_water,
                          lwd, main, add, plot_it, ax, type, contour_method, messages,
                          width, height, plot_was_provided, lty=lty, cex=cex, **kwargs)
    else:
        raise ValueError(f"Cannot create diagram with {nd} dimensions")

    # Handle Jupyter display behavior
    if result is not None and 'fig' in result:
        if not plot_it:
            # Close the figure to prevent auto-display in Jupyter; it stays
            # reachable through result['fig'].
            plt.close(result['fig'])
        else:
            # Use IPython display if available (for Jupyter notebooks)
            try:
                from IPython.display import display
                display(result['fig'])
            except (ImportError, NameError):
                # Not in IPython/Jupyter: deliberately fall back to regular
                # matplotlib display behavior.
                pass

    return result
|
|
552
|
+
|
|
553
|
+
|
|
554
|
+
def _get_balance(eout: Dict[str, Any], balance: Optional[Union[str, float, List[float]]], messages: bool = True) -> Tuple[List[float], Union[str, int, List[float]]]:
    """
    Get balancing coefficients for formation reactions.

    This implements the R CHNOSZ balance() function logic for determining
    how to balance formation reactions when calculating diagrams.

    Parameters
    ----------
    eout : dict
        Output from affinity(). Only ``eout['species']`` is read here; every
        column other than 'ispecies', 'name', 'state' and 'logact' is treated
        as a basis-species coefficient column.
    balance : str, float, sequence of float, or None
        Balancing specification:

        - None: use the first basis species that has a non-zero coefficient
          in every formation reaction.
        - 1 or "1": balance on one mole of each species (formula units).
        - any other number: 1-based position of a basis-species column
          (fractional values are truncated; 1 itself is reserved for
          formula units).
        - str: name of a basis-species column.
        - list, tuple or numpy array: explicit coefficients, one per species.
    messages : bool, default True
        Whether to print a line describing the balance that was chosen.

    Returns
    -------
    tuple of (list of float, balance_name)
        - Balancing coefficients for each species
        - The balance identifier used (column name, 1, or the supplied list)

    Raises
    ------
    ValueError
        If ``balance`` is None and no basis species is present in all
        formation reactions.
    """
    species_df = eout['species']
    n_species = len(species_df)

    # Get basis species column names (exclude metadata columns)
    metadata_cols = ('ispecies', 'name', 'state', 'logact')
    basis_cols = [c for c in species_df.columns if c not in metadata_cols]

    # Normalise array-like input to a plain list. Without this, a NumPy array
    # makes the `balance == 1` test below raise "truth value is ambiguous",
    # and a tuple would be silently ignored.
    if isinstance(balance, np.ndarray):
        balance = balance.tolist()
    elif isinstance(balance, tuple):
        balance = list(balance)

    if balance is None:
        # Auto-select using which_balance logic
        ibalance = _which_balance(species_df, basis_cols)
        if len(ibalance) == 0:
            raise ValueError("no basis species is present in all formation reactions")
        balance_col = basis_cols[ibalance[0]]
        n_balance = species_df[balance_col].tolist()
        if messages:
            print(f"balance: on moles of {balance_col} in formation reactions")
        balance = balance_col
    elif balance == 1 or balance == "1":
        # Balance on one mole of species (formula units)
        n_balance = [1] * n_species
        if messages:
            print("balance: on supplied numeric argument (1) [1 means balance on formula units]")
        balance = 1
    elif isinstance(balance, (int, float)):
        # Use a specific basis species by (1-based) index.
        # The lower bound must be 1, not 0: a value in (0, 1) would truncate
        # to 0 and `basis_cols[0 - 1]` would silently pick the LAST column.
        # NaN and inf fail this comparison and take the fallback branch.
        if 1 <= balance <= len(basis_cols):
            balance_col = basis_cols[int(balance) - 1]
            n_balance = species_df[balance_col].tolist()
            if messages:
                print(f"balance: on moles of {balance_col} in formation reactions")
            balance = balance_col
        else:
            warnings.warn(f"Balance index {balance} out of range, using 1")
            n_balance = [1] * n_species
            balance = 1
    elif isinstance(balance, str):
        # Use named basis species. Only basis columns qualify: metadata
        # columns such as 'name' hold strings and would break the sign check.
        if balance in basis_cols:
            n_balance = species_df[balance].tolist()
            if messages:
                print(f"balance: on moles of {balance} in formation reactions")
        else:
            warnings.warn(f"Balance species '{balance}' not found, using 1")
            n_balance = [1] * n_species
            balance = 1
    elif isinstance(balance, list):
        # Use provided list
        if len(balance) == n_species:
            n_balance = balance
            if messages:
                print(f"balance: on supplied numeric argument ({','.join(map(str, balance))})")
        else:
            warnings.warn(f"Balance list length ({len(balance)}) doesn't match species count ({n_species}), using 1")
            n_balance = [1] * n_species
            balance = 1
    else:
        # Unsupported type: fall back to formula units, but say so, as the
        # other fallback branches do.
        warnings.warn(f"Unsupported balance specification {balance!r}, using 1")
        n_balance = [1] * n_species
        balance = 1

    # Handle negative coefficients (make all positive if all negative)
    if all(x < 0 for x in n_balance):
        n_balance = [-x for x in n_balance]

    return n_balance, balance
|
|
639
|
+
|
|
640
|
+
|
|
641
|
+
def _which_balance(species_df: pd.DataFrame, basis_cols: List[str]) -> List[int]:
    """
    Locate the basis species that take part in every formation reaction.

    Python counterpart of the R CHNOSZ which.balance() function.

    Parameters
    ----------
    species_df : pd.DataFrame
        Species table holding one stoichiometric-coefficient column per
        basis species.
    basis_cols : list of str
        Column names in ``species_df`` to examine, in the order their
        positions should be reported.

    Returns
    -------
    list of int
        0-based positions within ``basis_cols`` of the columns whose
        coefficients are all non-zero.
    """
    # A basis species qualifies when no species has a zero coefficient for it.
    return [
        position
        for position, column in enumerate(basis_cols)
        if np.all(species_df[column].values != 0)
    ]
|
|
666
|
+
|
|
667
|
+
|
|
668
|
+
def _plot_1d(eout: Dict[str, Any],
             plotvals: Dict,
             plotvar: str,
             names: List[str],
             n_balance: List[float],
             balance: Optional[Union[str, float, List[float]]],
             xlab: Optional[str],
             ylab: Optional[str],
             xlim: Optional[List[float]],
             ylim: Optional[List[float]],
             col: Optional[Union[str, List[str]]],
             lty: Optional[Union[str, int, List]],
             lwd: Union[float, List[float]],
             main: Optional[str],
             add: bool,
             plot_it: bool,
             ax: Optional[Any],
             width: int = 600,
             height: int = 520,
             plot_was_provided: bool = False,
             **kwargs) -> Dict[str, Any]:
    """
    Create a 1-D line plot with one line per entry of ``plotvals``.

    Parameters
    ----------
    eout : dict
        Calculation output; ``eout['vars'][0]`` names the x variable and
        ``eout['vals'][xvar]`` supplies its values.
    plotvals : dict
        Mapping whose values are the y series to draw, in iteration order.
        A scalar value is drawn as a horizontal line.
    plotvar : str
        Name of the plotted variable; used for the default y label.
    names : list of str
        Legend labels, matched to ``plotvals`` by position.
    n_balance, balance
        Passed through unchanged to the returned dictionary.
    xlab, ylab : str or None
        Axis labels; when None they are generated with ``_axis_label``.
    xlim, ylim : list or None
        Axis limits. ``xlim`` defaults to the first and last x values.
    col, lty, lwd
        Colour, line style and line width: a single value or a sequence that
        is recycled across the species. Integer line styles 1-6 are mapped
        to Matplotlib style strings.
    main : str or None
        Plot title.
    add : bool
        If True and no ``ax`` is given, draw on the current pyplot axes.
    plot_it : bool
        If False, interactive mode is switched off while drawing.
    ax : matplotlib axes or None
        Axes to draw on.
    width, height : int
        Size in pixels (at 96 DPI) of a newly created figure.
    plot_was_provided : bool
        When True, 'ax' and 'fig' are returned even if ``ax`` was supplied.
    **kwargs
        Forwarded to ``ax.plot``.

    Returns
    -------
    dict
        ``eout`` extended with 'plotvar', 'plotvals', 'names', 'predominant'
        (NaN for 1-D plots), 'balance' and 'n.balance', plus 'ax' and 'fig'
        when the axes were created here or ``plot_was_provided`` is True.
    """

    def _recycle(values, n):
        """Repeat ``values`` cyclically to exactly ``n`` entries."""
        values = list(values)
        return [values[i % len(values)] for i in range(n)]

    # Get x-axis values
    xvar = eout['vars'][0]
    xvals = np.asarray(eout['vals'][xvar])

    # Set up axis labels
    if xlab is None:
        xlab = _axis_label(xvar, eout)
    if ylab is None:
        ylab = _axis_label(plotvar, eout)

    # Set up x limits (keeps the original direction of the x values)
    if xlim is None:
        xlim = [xvals[0], xvals[-1]]

    # Set up colors, line styles and line widths, one entry per species
    n_species = len(plotvals)

    if col is None:
        # Use matplotlib default color cycle
        col = _recycle(plt.rcParams['axes.prop_cycle'].by_key()['color'], n_species)
    elif isinstance(col, str):
        col = [col] * n_species
    else:
        col = _recycle(col, n_species)

    if lty is None:
        lty = ['-'] * n_species
    elif isinstance(lty, (str, int)):
        lty = [lty] * n_species
    else:
        lty = _recycle(lty, n_species)

    if isinstance(lwd, (int, float)):
        lwd = [lwd] * n_species
    else:
        lwd = _recycle(lwd, n_species)

    # Convert numeric line styles to matplotlib styles
    lty_map = {1: '-', 2: '--', 3: '-.', 4: ':', 5: '-', 6: '--'}
    lty = [lty_map.get(lt, lt) if isinstance(lt, int) else lt for lt in lty]

    # Temporarily disable interactive mode if plot_it=False.
    # The try/finally guarantees the caller's interactive state is restored
    # even when drawing raises.
    was_interactive = plt.isinteractive()
    if not plot_it:
        plt.ioff()

    try:
        # Convert width and height from pixels to inches for matplotlib
        dpi = 96
        figsize_inches = (width / dpi, height / dpi)

        # Create figure and axes (always, even if plot_it=False)
        fig = None
        ax_was_provided = ax is not None  # Track if ax was passed as parameter

        if ax_was_provided:
            # Use provided axes
            fig = ax.get_figure()
        elif not add:
            # Create new figure and axes with specified size
            fig, ax = plt.subplots(figsize=figsize_inches, dpi=dpi)
        else:
            # Try to get current axes, create new if that fails
            try:
                ax = plt.gca()
                fig = ax.get_figure()
            except Exception:
                fig, ax = plt.subplots(figsize=figsize_inches, dpi=dpi)

        # Plot each species (always draw; plot_it only controls display)
        for i, yvals in enumerate(plotvals.values()):
            yvals = np.asarray(yvals)
            if yvals.ndim == 0:
                # Scalar: draw a constant line across the whole x range
                yvals = np.full(len(xvals), yvals.item())

            ax.plot(xvals, yvals,
                    color=col[i],
                    linestyle=lty[i],
                    linewidth=lwd[i],
                    label=names[i],
                    **kwargs)

        # Set labels and limits
        ax.set_xlabel(xlab)
        ax.set_ylabel(ylab)

        if xlim is not None:
            ax.set_xlim(xlim)

        if ylim is not None:
            ax.set_ylim(ylim)

        # Add legend
        ax.legend()

        # Add title
        if main is not None:
            ax.set_title(main)

        # Add grid
        ax.grid(True, alpha=0.3)

        if not add:
            # Lay out the figure that owns `ax`. plt.tight_layout() would act
            # on pyplot's *current* figure, which may be a different one (or
            # a freshly created blank one) when `ax` was supplied.
            if fig is not None:
                fig.tight_layout()
            else:
                plt.tight_layout()

        # Build output dictionary
        result = {
            **eout,
            'plotvar': plotvar,
            'plotvals': plotvals,
            'names': names,
            'predominant': np.nan,
            'balance': balance,
            'n.balance': n_balance
        }

        # Add figure and axes to output if they were created here
        # (or if the caller asked for them via plot_was_provided)
        if fig is not None and (not ax_was_provided or plot_was_provided):
            result['ax'] = ax
            result['fig'] = fig
    finally:
        # Always restore interactive mode to its original state
        if was_interactive and not plt.isinteractive():
            plt.ion()
        elif not was_interactive and plt.isinteractive():
            plt.ioff()

    return result
|
|
844
|
+
|
|
845
|
+
|
|
846
|
+
def _plot_2d(eout: Dict[str, Any],
             plotvals: Dict[int, np.ndarray],
             plotvar: str,
             names: List[str],
             n_balance: List[float],
             balance: Union[str, int, List[float]],
             xlab: Optional[str],
             ylab: Optional[str],
             xlim: Optional[List[float]],
             ylim: Optional[List[float]],
             col: Optional[Union[str, List[str]]],
             col_names: Optional[Union[str, List[str]]],
             fill: Optional[str],
             fill_NA: str,
             limit_water: Optional[bool],
             lwd: Union[float, List[float]],
             main: Optional[str],
             add: bool,
             plot_it: bool,
             ax: Optional[Any],
             type: str = "auto",
             contour_method: Optional[str] = "edge",
             messages: bool = True,
             width: int = 600,
             height: int = 520,
             plot_was_provided: bool = False,
             **kwargs) -> Dict[str, Any]:
    """
    Create a 2-D predominance diagram (internal function).

    This function determines which species has the maximum value at each
    point on a 2D grid and creates a predominance field diagram. When
    ``type == "saturation"`` the work is delegated to ``_plot_saturation_2d``.

    Parameters
    ----------
    eout : dict
        Output from affinity(); 'vars', 'vals' and 'species' are read.
    plotvals : dict
        Values to plot (affinity or equilibrium values), one 2-D array of
        shape (n_x, n_y) per species. Keys are either positions into
        ``names`` or 'ispecies' values looked up in ``eout['species']``.
    plotvar : str
        Variable being plotted (passed through to the result).
    names : list of str
        Species names for labels
    n_balance : list of float
        Balancing coefficients (passed through to the result).
    balance : str, int, or list
        Balance identifier (passed through to the result).
    xlab, ylab : str or None
        Axis labels; generated with ``_axis_label`` when None.
    xlim, ylim : list or None
        Axis limits; default to the first and last grid values, which
        preserves decreasing axes.
    col : str, list, or None
        Colors for boundary lines in 2D plots (default black).
    col_names : str, list, or None
        Colors for field labels (text) in 2D plots (default black).
    fill : str or None
        Matplotlib colormap name for coloring predominance fields
        (e.g., 'viridis', 'plasma', 'terrain', 'rainbow', 'Set1', 'tab10').
        With None, fields are white when borders are drawn.
    fill_NA : str
        Not used by this function.
    limit_water : bool or None
        Whether to call ``water_lines``; None means ``not add``.
    lwd : float or list of float
        Line width for drawing boundaries between predominance fields; a
        list is recycled per species. Set to 0 to disable borders.
    main : str or None
        Plot title
    add : bool
        Add to existing plot? Fields are then filled only if ``fill`` is given.
    plot_it : bool
        Display the plot? If False, interactive mode is off while drawing.
    ax : matplotlib axes or None
        Axes to draw on.
    type : str
        "saturation" selects the saturation-line plot.
    contour_method : str or None
        Forwarded to ``_plot_saturation_2d`` only.
    messages : bool
        Whether to print informational messages.
    width, height : int
        Size in pixels (at 96 DPI) of a newly created figure.
    plot_was_provided : bool
        When True, 'ax' and 'fig' are returned even if ``ax`` was supplied.
    **kwargs
        Additional matplotlib arguments, forwarded to ``ax.imshow``
        ('lty' and 'cex' are removed first).

    Returns
    -------
    dict
        ``eout`` extended with 'plotvar', 'plotvals', 'names', 'predominant'
        (1-based species positions), 'predominant.values', 'balance' and
        'n.balance', plus 'ax' and 'fig' when the axes were created here or
        ``plot_was_provided`` is True.

    Raises
    ------
    ValueError
        If ``eout['vars']`` does not hold exactly two variables or the
        per-species arrays are not 2-D.
    """

    def _recycle(values, n):
        """Repeat ``values`` cyclically to exactly ``n`` entries."""
        values = list(values)
        return [values[i % len(values)] for i in range(n)]

    def _per_species(spec, default, n):
        """Expand a None / single string / sequence spec to ``n`` entries."""
        if spec is None:
            return [default] * n
        if isinstance(spec, str):
            return [spec] * n
        return _recycle(spec, n)

    def _default_cycle(n):
        """First ``n`` colors of Matplotlib's default color cycle (recycled)."""
        return _recycle(plt.rcParams['axes.prop_cycle'].by_key()['color'], n)

    # Get the two variables
    vars_list = eout['vars']
    if len(vars_list) != 2:
        raise ValueError(f"Expected 2 variables for 2-D plot, got {len(vars_list)}")

    # R CHNOSZ convention: first variable in affinity() -> x-axis, second -> y-axis
    # In the array: rows correspond to first var, columns to second var
    xvar = vars_list[0]
    yvar = vars_list[1]

    # Grid values for each variable
    xvals = np.asarray(eout['vals'][xvar])
    yvals = np.asarray(eout['vals'][yvar])

    # Get axis labels
    if xlab is None:
        xlab = _axis_label(xvar, eout)
    if ylab is None:
        ylab = _axis_label(yvar, eout)

    # Handle saturation lines separately from predominance diagrams
    if type == "saturation":
        # lty and cex arrive through kwargs from the diagram() call
        lty_param = kwargs.pop('lty', None)
        cex_param = kwargs.pop('cex', 1.0)

        return _plot_saturation_2d(eout, plotvals, plotvar, names, n_balance, balance,
                                   xlab, ylab, xlim, ylim, col, lwd, lty_param, cex_param,
                                   main, add, plot_it, ax, contour_method, messages,
                                   width, height, plot_was_provided, **kwargs)

    # For non-saturation plots, drop lty and cex so they are not passed to matplotlib
    kwargs.pop('lty', None)
    kwargs.pop('cex', None)

    # Print message about diagram method
    if messages:
        print("diagram: using maximum affinity method for 2-D diagram")

    # The plotvals dict can have two types of keys:
    # 1. Integer indices (0, 1, 2, ...) for equilibrate output - preserves duplicates
    # 2. ispecies values (1844, 1876, ...) for affinity output - unique species only
    species_keys = list(plotvals.keys())
    n_species = len(species_keys)

    # Map each stack position to the index into `names` used for its label
    species_df = eout['species']
    if all(isinstance(k, int) and k < len(names) for k in species_keys):
        # Integer indices: direct mapping to names
        predominant_to_names_idx = {i: species_keys[i] for i in range(n_species)}
    else:
        # ispecies values: use the first matching row of the species table,
        # falling back to the stack position when there is no match
        predominant_to_names_idx = {}
        for i, sp_idx in enumerate(species_keys):
            matching_rows = species_df[species_df['ispecies'] == sp_idx].index.tolist()
            predominant_to_names_idx[i] = matching_rows[0] if matching_rows else i

    # Get dimensions from first species
    first_vals = plotvals[species_keys[0]]
    if len(first_vals.shape) != 2:
        raise ValueError(f"Expected 2-D array for each species, got shape {first_vals.shape}")

    # Array shape: (n_xvar, n_yvar) since first var -> x-axis, second -> y-axis
    n_xvar, n_yvar = first_vals.shape

    # Stack all species values into a 3D array (species, rows, cols)
    affinity_stack = np.zeros((n_species, n_xvar, n_yvar))
    for i, sp_idx in enumerate(species_keys):
        affinity_stack[i, :, :] = plotvals[sp_idx]

    # Species (stack position) with the maximum value at each grid point,
    # and the value itself
    predominant_indices = np.argmax(affinity_stack, axis=0)
    predominant_values = np.max(affinity_stack, axis=0)

    # 1-based positions for R compatibility (R's predominant holds 1, 2, 3, ...)
    predominant = predominant_indices + 1

    # Calculate water stability limits if requested (default True unless adding)
    if limit_water is None:
        limit_water = not add

    # NOTE(review): H2O_predominant is computed but not consumed by any later
    # step in this function (no shading, no masking of labels) - confirm
    # whether it should be applied to the drawn fields or returned.
    H2O_predominant = None
    if limit_water:
        # Call water_lines with plot_it=False to get boundaries
        wl = water_lines(eout, plot_it=False, messages=messages)
        # Proceed only if water_lines produced both boundaries
        if wl['y_oxidation'] is not None and wl['y_reduction'] is not None:
            # Float copy of the predominance matrix so it can hold NaN
            H2O_predominant = predominant.astype(float)

            # For each x-point, mark values outside the water stability limits
            for i in range(len(wl['xpoints'])):
                ymin = min(wl['y_oxidation'][i], wl['y_reduction'][i])
                ymax = max(wl['y_oxidation'][i], wl['y_reduction'][i])

                if not wl['swapped']:
                    # Normal orientation: x is first var, y is second var
                    iNA = (yvals < ymin) | (yvals > ymax)
                    H2O_predominant[i, iNA] = np.nan
                else:
                    # Swapped: first var is y-axis
                    iNA = (xvals < ymin) | (xvals > ymax)
                    H2O_predominant[iNA, i] = np.nan

    # For plotting: transpose so that x-axis is horizontal and y-axis is vertical
    # imshow expects (n_rows, n_cols) = (n_yaxis, n_xaxis)
    predominant_indices_T = predominant_indices.T

    # Per-species border widths. `lwd` may be a scalar or a sequence; comparing
    # a list with `> 0` directly would raise TypeError.
    if isinstance(lwd, (list, tuple, np.ndarray)):
        line_widths = _recycle(lwd, n_species)
    else:
        line_widths = [lwd] * n_species
    draw_borders = any(w > 0 for w in line_widths)

    # Temporarily disable interactive mode if plot_it=False.
    # The try/finally guarantees the caller's interactive state is restored
    # even when drawing raises.
    was_interactive = plt.isinteractive()
    if not plot_it:
        plt.ioff()

    try:
        # Convert width and height from pixels to inches for matplotlib
        dpi = 96
        figsize_inches = (width / dpi, height / dpi)

        # Create figure and axes (always, even if plot_it=False)
        fig = None
        ax_was_provided = ax is not None  # Track if ax was passed as parameter

        if ax_was_provided:
            # Use provided axes
            fig = ax.get_figure()
        elif not add:
            # Create new figure and axes with specified size
            fig, ax = plt.subplots(figsize=figsize_inches, dpi=dpi)
        else:
            # Try to get current axes, create new if that fails
            try:
                ax = plt.gca()
                fig = ax.get_figure()
            except Exception:
                fig, ax = plt.subplots(figsize=figsize_inches, dpi=dpi)

        # When add=True, don't draw fill by default (to overlay on existing plot)
        # User can explicitly provide fill parameter to override this
        draw_fill = not (add and fill is None)

        # Determine fill colors for predominance fields
        # Priority: fill parameter > white (when borders are drawn) > color cycle
        if fill is not None:
            try:
                # pyplot.get_cmap is used because matplotlib.cm.get_cmap was
                # removed in Matplotlib 3.9.
                cmap = plt.get_cmap(fill)
                # Sample evenly in 0.1-0.9 to avoid very light/dark ends
                fill_colors = [cmap(v) for v in np.linspace(0.1, 0.9, n_species)]
            except (ValueError, KeyError, TypeError):
                warnings.warn(f"Colormap '{fill}' not found, using default colors")
                fill_colors = _default_cycle(n_species)
        elif draw_borders:
            # R CHNOSZ default behavior: white fill with black borders
            fill_colors = ['white'] * n_species
        else:
            # Use default matplotlib colors
            fill_colors = _default_cycle(n_species)

        # Boundary line colors (col) and text label colors (col_names)
        boundary_colors = _per_species(col, 'black', n_species)
        text_colors = _per_species(col_names, 'black', n_species)

        # Draw filled predominance fields only if not adding to existing plot
        # (or if user explicitly provided fill parameter)
        if draw_fill:
            from matplotlib.colors import to_rgb

            # RGB image of shape (n_yvar, n_xvar, 3): one color per field
            colored_predominant = np.zeros((n_yvar, n_xvar, 3))
            for i in range(n_species):
                colored_predominant[predominant_indices_T == i] = to_rgb(fill_colors[i])

            # extent sets the data coordinates: [x_start, x_end, y_start, y_end]
            # Use original order (first to last value), not min/max, so that
            # axes with decreasing values keep their direction.
            extent = [xvals[0], xvals[-1], yvals[0], yvals[-1]]

            # Expand extent if a dimension has identical limits to avoid matplotlib warning
            if extent[0] == extent[1]:
                x_range = abs(extent[0]) * 0.1 if extent[0] != 0 else 0.1
                extent[0] -= x_range
                extent[1] += x_range
            if extent[2] == extent[3]:
                y_range = abs(extent[2]) * 0.1 if extent[2] != 0 else 0.1
                extent[2] -= y_range
                extent[3] += y_range

            ax.imshow(colored_predominant, aspect='auto', origin='lower',
                      extent=extent, interpolation='nearest', **kwargs)

        # Add species labels at the center of their predominance regions
        for i in range(n_species):
            # Grid points where this species predominates (transposed array:
            # rows index yvals, cols index xvals)
            rows, cols = np.where(predominant_indices_T == i)
            if rows.size == 0:
                continue

            # Mean grid position, truncated to an index, in data coordinates
            x_pos = xvals[int(cols.mean())]
            y_pos = yvals[int(rows.mean())]

            # Map stack position to the names index (handles duplicate species)
            names_idx = predominant_to_names_idx[i]

            ax.text(x_pos, y_pos, names[names_idx],
                    ha='center', va='center',
                    color=text_colors[i],
                    bbox=dict(boxstyle='round', facecolor='white', edgecolor='none', alpha=0.7),
                    fontsize=10, fontweight='bold')

        # Set labels and limits
        ax.set_xlabel(xlab)
        ax.set_ylabel(ylab)

        if xlim is not None:
            ax.set_xlim(xlim)
        else:
            # Preserve original axis direction (important for decreasing axes)
            x_min, x_max = xvals[0], xvals[-1]
            # Expand range if limits are identical to avoid matplotlib warning
            if x_min == x_max:
                x_range = abs(x_max) * 0.1 if x_max != 0 else 0.1
                x_min -= x_range
                x_max += x_range
            ax.set_xlim([x_min, x_max])

        if ylim is not None:
            ax.set_ylim(ylim)
        else:
            # Preserve original axis direction (important for decreasing axes)
            y_min, y_max = yvals[0], yvals[-1]
            # Expand range if limits are identical to avoid matplotlib warning
            if y_min == y_max:
                y_range = abs(y_max) * 0.1 if y_max != 0 else 0.1
                y_min -= y_range
                y_max += y_range
            ax.set_ylim([y_min, y_max])

        # Add title
        if main is not None:
            ax.set_title(main)

        # Draw borders between predominance fields
        if draw_borders:
            # Following the R CHNOSZ approach: for each species, contour a 0/1
            # mask of its field at level 0.5.
            present_species = np.unique(predominant_indices_T)

            # Meshgrid in data coordinates, shape (n_yvar, n_xvar)
            X, Y = np.meshgrid(xvals, yvals)

            # Skip the last present species to avoid double-plotting
            for species_idx in present_species[:-1]:
                species_idx = int(species_idx)

                # Binary mask: 1 where this species predominates, 0 elsewhere
                z = (predominant_indices_T == species_idx).astype(float)

                try:
                    ax.contour(X, Y, z, levels=[0.5],
                               colors=[boundary_colors[species_idx]],
                               linewidths=line_widths[species_idx], zorder=10)
                except Exception:
                    # Best effort: skip a boundary that cannot be contoured
                    # rather than failing the whole diagram.
                    continue

        if not add:
            # Lay out the figure that owns `ax`. plt.tight_layout() would act
            # on pyplot's *current* figure, which may be a different one (or
            # a freshly created blank one) when `ax` was supplied.
            if fig is not None:
                fig.tight_layout()
            else:
                plt.tight_layout()

        # The figure is intentionally not closed here when plot_it=False, so
        # it stays available in the result for later modification or display.

        # Build output dictionary (matching R CHNOSZ structure)
        result = {
            **eout,
            'plotvar': plotvar,
            'plotvals': plotvals,
            'names': names,
            'predominant': predominant,
            'predominant.values': predominant_values,
            'balance': balance,
            'n.balance': n_balance
        }

        # Add figure and axes to output if they were created here
        # (or if the caller asked for them via plot_was_provided)
        if fig is not None and (not ax_was_provided or plot_was_provided):
            result['ax'] = ax
            result['fig'] = fig
    finally:
        # Always restore interactive mode to its original state
        if was_interactive and not plt.isinteractive():
            plt.ion()
        elif not was_interactive and plt.isinteractive():
            plt.ioff()

    return result
|
|
1316
|
+
|
|
1317
|
+
|
|
1318
|
+
def _recycle_saturation_style(value: Any, n: int, scalar_types: Any) -> List[Any]:
    """
    Expand a per-species style argument to exactly ``n`` entries.

    A scalar (an instance of ``scalar_types``) is repeated ``n`` times; any
    other iterable is cycled and truncated, mirroring R's argument recycling.

    Raises
    ------
    ValueError
        If an empty sequence is given while ``n`` is positive (there is
        nothing to recycle).
    """
    if isinstance(value, scalar_types):
        return [value] * n
    items = list(value)
    if n > 0 and not items:
        raise ValueError("per-species style arguments must not be empty")
    if len(items) < n:
        items = items * (n // len(items) + 1)
    return items[:n]


def _plot_saturation_2d(eout: Dict[str, Any],
                        plotvals: Dict[int, np.ndarray],
                        plotvar: str,
                        names: List[str],
                        n_balance: List[float],
                        balance: Union[str, int, List[float]],
                        xlab: str,
                        ylab: str,
                        xlim: Optional[List[float]],
                        ylim: Optional[List[float]],
                        col: Optional[Union[str, List[str]]],
                        lwd: Union[float, List[float]],
                        lty: Optional[Union[str, int, List]],
                        cex: Union[float, List[float]],
                        main: Optional[str],
                        add: bool,
                        plot_it: bool,
                        ax: Optional[Any],
                        contour_method: Optional[Union[str, List[str]]],
                        messages: bool = True,
                        width: int = 600,
                        height: int = 520,
                        plot_was_provided: bool = False,
                        **kwargs) -> Dict[str, Any]:
    """
    Plot saturation lines (affinity=0 contours) for 2-D diagrams.

    For every species in ``plotvals`` a single contour is drawn at level 0,
    indicating saturation boundaries (e.g., mineral precipitation thresholds).
    Species whose values are constant, or entirely above or below zero, are
    skipped (with a message when ``messages`` is True).

    Parameters
    ----------
    eout : dict
        Calculation output; ``eout['vars']`` names the x and y variables and
        ``eout['vals']`` holds their grid values. All keys are copied into
        the returned dictionary.
    plotvals : dict
        Values to contour, one 2-D array per species, laid out as
        (n_xvar, n_yvar); each array is transposed before contouring.
    plotvar, names, n_balance, balance
        Passed through to the result dictionary; ``names`` also supplies the
        contour label text and the species name used in messages.
    xlab, ylab, main : str
        Axis labels and optional title (applied only when ``add`` is False).
    xlim, ylim : list of float or None
        Axis limits; None uses the range of the corresponding grid values.
    col, lwd, lty, cex
        Colour, line width, line style and label size. Each may be a single
        value or a sequence that is recycled to one entry per species.
        ``col=None`` uses matplotlib's colour cycle, ``lty=None`` uses solid
        lines, and integer line styles 1-6 are mapped to matplotlib styles.
    add : bool
        If True, draw on existing axes and leave labels/limits untouched.
    plot_it : bool
        If False, interactive mode is switched off while drawing so that the
        figure is not displayed automatically; the plot is still drawn.
    ax : matplotlib axes or None
        Axes to draw on. If None, a new figure is created (or the current
        axes are used when ``add`` is True).
    contour_method : str, list of str, or None
        Controls contour labels. Can be a single value (applied to all
        species) or a list (recycled to one per species). None, "" or "NA"
        disables labels; an empty list disables labels for all species.
        Matplotlib doesn't support the same contour methods as R, so any
        other value simply enables labels.
    messages : bool, default True
        Whether to print progress messages.
    width, height : int
        Figure size in pixels, converted at 96 DPI for new figures.
    plot_was_provided : bool, default False
        When True, 'ax' and 'fig' are returned even if ``ax`` was supplied.
    **kwargs
        Forwarded to ``ax.contour``.

    Returns
    -------
    dict
        ``eout`` plus 'plotvar', 'plotvals', 'names', 'predominant' (always
        NaN here), 'balance' and 'n.balance'; 'ax' and 'fig' are added when a
        figure exists and either ``ax`` was not supplied or
        ``plot_was_provided`` is True.
    """

    # Get the two variables and the values for each of them
    vars_list = eout['vars']
    xvar = vars_list[0]
    yvar = vars_list[1]
    xvals = np.asarray(eout['vals'][xvar])
    yvals = np.asarray(eout['vals'][yvar])

    species_keys = list(plotvals.keys())
    n_species = len(species_keys)

    if messages:
        print("diagram: plotting saturation lines for 2-D diagram")

    # Set up colors, line widths, line styles and label sizes (one per species)
    number_types = (int, float, np.integer, np.floating)
    if col is None:
        # Use matplotlib default color cycle
        cycle_colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
        col = [cycle_colors[i % len(cycle_colors)] for i in range(n_species)]
    else:
        col = _recycle_saturation_style(col, n_species, str)

    lwd = _recycle_saturation_style(lwd, n_species, number_types)

    if lty is None:
        lty = ['-'] * n_species
    else:
        lty = _recycle_saturation_style(lty, n_species, (str, int, np.integer))

    # Convert numeric line styles to matplotlib styles
    lty_map = {1: '-', 2: '--', 3: '-.', 4: ':', 5: '-', 6: '--'}
    lty = [lty_map.get(int(lt), lt) if isinstance(lt, (int, np.integer)) else lt
           for lt in lty]

    cex_list = _recycle_saturation_style(cex, n_species, number_types)

    # Determine if labels should be drawn (per species)
    def _labels_disabled(method: Any) -> bool:
        if method is None:
            return True
        return isinstance(method, str) and (method == "" or method.upper() == "NA")

    if _labels_disabled(contour_method):
        drawlabels = [False] * n_species
    elif isinstance(contour_method, list):
        if contour_method:
            methods = _recycle_saturation_style(contour_method, n_species, ())
        else:
            # Nothing to recycle: treat an empty list as "no labels"
            methods = [None] * n_species
        drawlabels = [not _labels_disabled(method) for method in methods]
    else:
        # A single method name (or any other value) enables labels everywhere
        drawlabels = [True] * n_species

    # Temporarily disable interactive mode if plot_it=False
    # This prevents Jupyter from auto-displaying the figure
    was_interactive = plt.isinteractive()
    if not plot_it:
        plt.ioff()

    try:
        # Convert width and height from pixels to inches for matplotlib
        # Use standard 96 DPI for consistency with web/screen displays
        dpi = 96
        figsize_inches = (width / dpi, height / dpi)

        # Create figure and axes (always, even if plot_it=False)
        # This allows the plot to be used with add_to parameter later
        fig = None
        ax_was_provided = ax is not None  # Track if ax was passed as parameter

        if ax is not None:
            fig = ax.get_figure()
        elif not add:
            fig, ax = plt.subplots(figsize=figsize_inches, dpi=dpi)
        else:
            try:
                ax = plt.gca()
                fig = ax.get_figure()
            except Exception:
                fig, ax = plt.subplots(figsize=figsize_inches, dpi=dpi)

        # Draw the plot content (always, regardless of plot_it);
        # plot_it only controls display, not drawing
        if not add:
            ax.set_xlabel(xlab)
            ax.set_ylabel(ylab)

            if xlim is not None:
                ax.set_xlim(xlim)
            else:
                ax.set_xlim([xvals.min(), xvals.max()])

            if ylim is not None:
                ax.set_ylim(ylim)
            else:
                ax.set_ylim([yvals.min(), yvals.max()])

            if main is not None:
                ax.set_title(main)

        # matplotlib's contour needs (X, Y, Z) where X and Y are meshgrids;
        # the grid is the same for every species, so build it once
        X, Y = np.meshgrid(xvals, yvals)

        # Plot saturation lines (affinity = 0 contours) for each species
        for i, sp_idx in enumerate(species_keys):
            zs = plotvals[sp_idx]

            # Skip plotting if this species has no possible saturation line
            if len(np.unique(zs)) == 1:
                if messages:
                    print(f"diagram: no saturation line possible for {names[i]}")
                continue

            # Skip if line is outside the plot range
            if np.all(zs < 0) or np.all(zs > 0):
                if messages:
                    print(f"diagram: beyond range for saturation line of {names[i]}")
                continue

            # plotvals has shape (n_xvar, n_yvar), but contour expects
            # (n_yvar, n_xvar), so transpose to match the meshgrid orientation
            zs_T = zs.T

            try:
                # Calculate font size from cex (matplotlib default is ~10pt)
                fontsize = 9 * cex_list[i]

                CS = ax.contour(X, Y, zs_T, levels=[0], colors=col[i],
                                linewidths=lwd[i], linestyles=lty[i], **kwargs)
                if drawlabels[i]:
                    # A plain string fmt is applied as ``fmt % level`` by
                    # matplotlib, which fails for a species name; map each
                    # level to the name so the label text is used verbatim.
                    label_text = {level: names[i] for level in CS.levels}
                    ax.clabel(CS, inline=True, fontsize=fontsize, fmt=label_text)
            except Exception as e:
                warnings.warn(f"Could not plot contour for {names[i]}: {e}")

        if not add:
            plt.tight_layout()

        # Build output dictionary
        result = {
            **eout,
            'plotvar': plotvar,
            'plotvals': plotvals,
            'names': names,
            'predominant': np.nan,
            'balance': balance,
            'n.balance': n_balance
        }

        # Add figure and axes to output if they were created
        if fig is not None:
            if not ax_was_provided or plot_was_provided:
                result['ax'] = ax
                result['fig'] = fig
    finally:
        # Always restore interactive mode to its original state,
        # even when drawing raised an exception
        if was_interactive and not plt.isinteractive():
            plt.ion()
        elif not was_interactive and plt.isinteractive():
            plt.ioff()

    return result
|
|
1562
|
+
|
|
1563
|
+
|
|
1564
|
+
def _axis_label(var: str, eout: Dict[str, Any]) -> str:
|
|
1565
|
+
"""
|
|
1566
|
+
Generate axis label for a variable.
|
|
1567
|
+
|
|
1568
|
+
Parameters
|
|
1569
|
+
----------
|
|
1570
|
+
var : str
|
|
1571
|
+
Variable name
|
|
1572
|
+
eout : dict
|
|
1573
|
+
Output from affinity()
|
|
1574
|
+
|
|
1575
|
+
Returns
|
|
1576
|
+
-------
|
|
1577
|
+
str
|
|
1578
|
+
Formatted axis label
|
|
1579
|
+
"""
|
|
1580
|
+
|
|
1581
|
+
# Special cases
|
|
1582
|
+
if var == 'A/(2.303RT)':
|
|
1583
|
+
return r'A/(2.303RT)'
|
|
1584
|
+
elif var == 'alpha':
|
|
1585
|
+
return r'$\alpha$'
|
|
1586
|
+
elif var == 'loga.equil':
|
|
1587
|
+
return r'log activity'
|
|
1588
|
+
elif var == 'pH':
|
|
1589
|
+
return 'pH'
|
|
1590
|
+
elif var == 'pe':
|
|
1591
|
+
return 'pe'
|
|
1592
|
+
elif var == 'Eh':
|
|
1593
|
+
return 'Eh (V)'
|
|
1594
|
+
elif var == 'T':
|
|
1595
|
+
return 'Temperature (°C)'
|
|
1596
|
+
elif var == 'P':
|
|
1597
|
+
return 'Pressure (bar)'
|
|
1598
|
+
elif var == 'IS':
|
|
1599
|
+
return 'Ionic strength'
|
|
1600
|
+
else:
|
|
1601
|
+
# Check if it's a basis species
|
|
1602
|
+
basis_df = eout.get('basis')
|
|
1603
|
+
if basis_df is not None and var in basis_df.index:
|
|
1604
|
+
state = basis_df.loc[var, 'state']
|
|
1605
|
+
# Format the chemical formula with proper LaTeX subscripts and superscripts
|
|
1606
|
+
var_formatted = _format_species_latex(var)
|
|
1607
|
+
if state in ['aq', 'liq', 'cr']:
|
|
1608
|
+
return f'$\\log\\ a_{{{var_formatted}}}$'
|
|
1609
|
+
else:
|
|
1610
|
+
return f'$\\log\\ f_{{{var_formatted}}}$'
|
|
1611
|
+
return var
|
|
1612
|
+
|
|
1613
|
+
|
|
1614
|
+
def _format_chemname(name: str) -> str:
    """
    Return a chemical formula as a matplotlib math-mode string.

    The LaTeX markup itself comes from _format_species_latex, so formulas
    are rendered the same way everywhere; this wrapper only adds the
    surrounding ``$`` delimiters.

    Parameters
    ----------
    name : str
        Chemical formula

    Returns
    -------
    str
        The formatted formula enclosed in ``$...$``
    """
    # Delegate the formatting, then wrap in math mode for matplotlib
    return '${}$'.format(_format_species_latex(name))
|
|
1634
|
+
|
|
1635
|
+
|
|
1636
|
+
def _water_line_values(values: Any, xpoints: np.ndarray) -> np.ndarray:
    """
    Recycle ``values`` into a float array with one entry per x point.

    A single value is repeated; a vector that already has one entry per
    x point is kept in full. The result is always floating point, whatever
    the dtype of ``xpoints``.
    """
    flat = np.atleast_1d(np.asarray(values, dtype=float)).ravel()
    return np.resize(flat, np.shape(xpoints))


def water_lines(eout: Dict[str, Any],
                which: Union[str, List[str]] = ['oxidation', 'reduction'],
                lty: Union[int, str] = 2,
                lwd: float = 1,
                col: Optional[str] = None,
                plot_it: bool = True,
                messages: bool = True) -> Dict[str, Any]:
    """
    Draw water stability limits for Eh-pH, logfO2-pH, logfO2-T or Eh-T diagrams.

    This function adds lines showing the oxidation and reduction stability limits
    of water to diagrams. Above the oxidation line, water breaks down to O2.
    Below the reduction line, water breaks down to H2.

    Parameters
    ----------
    eout : dict
        Output from affinity(), equilibrate(), or diagram()
    which : str or list of str, default ['oxidation', 'reduction']
        Which line(s) to draw: 'oxidation', 'reduction', or both
    lty : int or str, default 2
        Line style (matplotlib linestyle or numeric code)
    lwd : float, default 1
        Line width
    col : str, optional
        Line color (matplotlib color spec). If None, black is used
    plot_it : bool, default True
        Whether to display the figure after the lines have been added
        (useful when the original diagram was created with plot_it=False).
        For matplotlib diagrams the lines are drawn on the axes of the copied
        diagram whenever the input carries axes; plot_it only controls
        whether the figure is then displayed through IPython. For Plotly
        figures the flag is passed on to the Plotly helper.
    messages : bool, default True
        Passed on to subcrt() and convert() to control their messages

    Returns
    -------
    dict
        Dictionary containing all keys from the input diagram (including 'fig', 'ax',
        'plotvar', 'plotvals', 'names', 'predominant', etc. if present) plus the
        following water line specific keys:
        - xpoints: x-axis values
        - y_oxidation: y values for oxidation line (or None)
        - y_reduction: y values for reduction line (or None)
        - swapped: whether axes were swapped

    Examples
    --------
    >>> # Add water lines to an existing displayed diagram
    >>> basis(["Fe+2", "SO4-2", "H2O", "H+", "e-"], [0, math.log10(3), math.log10(0.75), 999, 999])
    >>> species(["rhomboclase", "ferricopiapite", "hydronium jarosite", "goethite", "melanterite", "pyrite"])
    >>> a = affinity(pH=[-1, 4, 256], pe=[-5, 23, 256])
    >>> d = diagram(a, main="Fe-S-O-H, after Majzlan et al., 2006")
    >>> water_lines(d, lwd=2)

    >>> # Add water lines and display when diagram was created with plot_it=False
    >>> d = diagram(a, main="Fe-S-O-H", plot_it=False)
    >>> water_lines(d, lwd=2)  # This will display the figure with water lines

    Notes
    -----
    This function only works on diagrams with a redox variable (Eh, pe, O2, or H2)
    on one axis and pH, T, P, or another non-redox variable on the other axis.
    For 1-D diagrams, vertical lines are drawn.
    """

    # Import here to avoid circular imports
    from ..utils.units import convert, envert
    from ..core.subcrt import subcrt

    # Work on a copy of the input to preserve all diagram information
    # This allows us to return all the original keys plus water line data
    result = copy_plot(eout)

    # Detect if this is a Plotly figure (interactive diagram)
    is_plotly = False
    if 'fig' in result and result['fig'] is not None:
        is_plotly = hasattr(result['fig'], 'add_trace') and hasattr(result['fig'], 'update_layout')

    # Ensure which is a list
    if isinstance(which, str):
        which = [which]

    # Get number of variables used in affinity()
    nvar1 = len(result['vars'])

    # Determine actual number of variables from array dimensions
    # Check both loga.equil (equilibrate output) and values (affinity output)
    if 'loga_equil' in result or 'loga.equil' in result:
        loga_key = 'loga_equil' if 'loga_equil' in result else 'loga.equil'
        first_val = result[loga_key][0] if isinstance(result[loga_key], list) else list(result[loga_key].values())[0]
    else:
        first_val = list(result['values'].values())[0] if isinstance(result['values'], dict) else result['values'][0]

    if hasattr(first_val, 'shape'):
        dim = first_val.shape
    elif hasattr(first_val, '__len__'):
        dim = (len(first_val),)
    else:
        dim = ()

    nvar2 = len(dim)

    # We only work on diagrams with 1 or 2 variables
    if nvar1 not in [1, 2] or nvar2 not in [1, 2]:
        result.update({'xpoints': None, 'y_oxidation': None, 'y_reduction': None, 'swapped': False})
        return result

    # Get variables from result
    vars_list = result['vars'].copy()

    # If needed, swap axes so redox variable is on y-axis
    # Also do this for 1-D diagrams
    if len(vars_list) == 1:
        vars_list.append('nothing')

    swapped = False
    if vars_list[1] in ['T', 'P', 'nothing']:
        vars_list = list(reversed(vars_list))
        vals_dict = {vars_list[0]: result['vals'][vars_list[0]]} if vars_list[0] != 'nothing' else {}
        if len(result['vars']) > 1:
            vals_dict[vars_list[1]] = result['vals'][vars_list[1]]
        swapped = True
    else:
        vals_dict = result['vals']

    xaxis = vars_list[0]
    yaxis = vars_list[1]
    xpoints = np.asarray(vals_dict[xaxis]) if xaxis in vals_dict else np.array([0])

    # Make xaxis "nothing" if it is not pH, T, or P
    # (so that horizontal water lines can be drawn for any non-redox variable on the x-axis)
    if xaxis not in ['pH', 'T', 'P']:
        xaxis = 'nothing'

    # T and P are constants unless they are plotted on one of the axes
    T = result['T']
    if vars_list[0] == 'T':
        T = envert(xpoints, 'K')
    P = result['P']
    if vars_list[0] == 'P':
        P = envert(xpoints, 'bar')

    # Handle the case where P is "Psat" - keep it as is for subcrt
    # (subcrt knows how to handle "Psat")

    # logaH2O is 0 unless given in result['basis']
    basis_df = result['basis']
    if 'H2O' in basis_df.index:
        logaH2O = float(basis_df.loc['H2O', 'logact'])
    else:
        logaH2O = 0

    # pH is 7 unless given in eout['basis'] or plotted on one of the axes
    if vars_list[0] == 'pH':
        pH = xpoints
    elif 'H+' in basis_df.index:
        minuspH = basis_df.loc['H+', 'logact']
        # Special treatment for non-numeric value (happens when a buffer is used)
        try:
            pH = -float(minuspH)
        except (ValueError, TypeError):
            pH = np.nan
    else:
        pH = 7

    # O2 state is gas unless given in eout['basis']
    O2state = 'gas'
    if 'O2' in basis_df.index:
        O2state = basis_df.loc['O2', 'state']

    # H2 state is gas unless given in eout['basis']
    H2state = 'gas'
    if 'H2' in basis_df.index:
        H2state = basis_df.loc['H2', 'state']

    # Where the calculated values will go
    y_oxidation = None
    y_reduction = None

    if xaxis in ['pH', 'T', 'P', 'nothing'] and yaxis in ['Eh', 'pe', 'O2', 'H2']:
        # Eh/pe/logfO2/logaO2/logfH2/logaH2 vs pH/T/P

        # Reduction line (H2O + e- = 1/2 H2 + OH-)
        if 'reduction' in which:
            logfH2 = logaH2O  # usually 0

            if yaxis == 'H2':
                # Calculate equilibrium constant for gas-aqueous conversion if needed
                logK = subcrt(['H2', 'H2'], [-1, 1], ['gas', H2state], T=T, P=P, convert=False, messages=messages, show=False).out['logK']
                # This is logfH2 if H2state == "gas", or logaH2 if H2state == "aq"
                logfH2 = logfH2 + logK
                # One value per x point: keeps the full T/P dependence when
                # logK is a vector, and never inherits an integer dtype
                y_reduction = _water_line_values(logfH2, xpoints)
            else:
                # Calculate logfO2 from H2O = 1/2 O2 + H2
                logK = subcrt(['H2O', 'O2', 'H2'], [-1, 0.5, 1], ['liq', O2state, 'gas'], T=T, P=P, convert=False, messages=messages, show=False).out['logK']
                # This is logfO2 if O2state == "gas", or logaO2 if O2state == "aq"
                logfO2 = 2 * (logK - logfH2 + logaH2O)

                if yaxis == 'O2':
                    y_reduction = _water_line_values(logfO2, xpoints)
                elif yaxis == 'Eh':
                    y_reduction = convert(logfO2, 'E0', T=T, P=P, pH=pH, logaH2O=logaH2O, messages=messages)
                elif yaxis == 'pe':
                    Eh_val = convert(logfO2, 'E0', T=T, P=P, pH=pH, logaH2O=logaH2O, messages=messages)
                    y_reduction = convert(Eh_val, 'pe', T=T, messages=messages)

        # Oxidation line (H2O = 1/2 O2 + 2H+ + 2e-)
        if 'oxidation' in which:
            logfO2 = logaH2O  # usually 0

            if yaxis == 'H2':
                # Calculate logfH2 from H2O = 1/2 O2 + H2
                logK = subcrt(['H2O', 'O2', 'H2'], [-1, 0.5, 1], ['liq', 'gas', H2state], T=T, P=P, convert=False, messages=messages, show=False).out['logK']
                # This is logfH2 if H2state == "gas", or logaH2 if H2state == "aq"
                logfH2 = logK - 0.5*logfO2 + logaH2O
                y_oxidation = _water_line_values(logfH2, xpoints)
            else:
                # Calculate equilibrium constant for gas-aqueous conversion if needed
                logK = subcrt(['O2', 'O2'], [-1, 1], ['gas', O2state], T=T, P=P, convert=False, messages=messages, show=False).out['logK']
                # This is logfO2 if O2state == "gas", or logaO2 if O2state == "aq"
                logfO2 = logfO2 + logK

                if yaxis == 'O2':
                    y_oxidation = _water_line_values(logfO2, xpoints)
                elif yaxis == 'Eh':
                    y_oxidation = convert(logfO2, 'E0', T=T, P=P, pH=pH, logaH2O=logaH2O, messages=messages)
                elif yaxis == 'pe':
                    Eh_val = convert(logfO2, 'E0', T=T, P=P, pH=pH, logaH2O=logaH2O, messages=messages)
                    y_oxidation = convert(Eh_val, 'pe', T=T, messages=messages)

    else:
        # Invalid axis combination
        result.update({'xpoints': xpoints, 'y_oxidation': None, 'y_reduction': None, 'swapped': swapped})
        return result

    # Route to Plotly or matplotlib implementation
    if is_plotly:
        return _water_lines_plotly(result, xpoints, y_oxidation, y_reduction, swapped,
                                   lty, lwd, col, plot_it)

    # Matplotlib implementation
    # Only draw water lines if eout already has an axes (meaning it's from a diagram)
    # If no axes, this is being called just for calculation (e.g., from within diagram())
    if 'ax' not in eout or eout['ax'] is None:
        # No axes to plot on - just return the calculated values
        result.update({'xpoints': xpoints, 'y_oxidation': y_oxidation, 'y_reduction': y_reduction, 'swapped': swapped})
        return result

    # Use the axes from result
    ax = result['ax']

    # First, shade the water-unstable regions with gray
    # This creates the same effect as R's fill.NA for H2O.predominant
    if y_oxidation is not None and y_reduction is not None:
        from matplotlib.colors import ListedColormap
        import numpy.ma as ma

        # Get current axis limits to create shading
        xlim = ax.get_xlim()
        ylim = ax.get_ylim()

        # Create a high-resolution mesh for smooth shading
        n_points = 500
        x_mesh = np.linspace(xlim[0], xlim[1], n_points)
        y_mesh = np.linspace(ylim[0], ylim[1], n_points)
        X, Y = np.meshgrid(x_mesh, y_mesh)

        if swapped:
            # When swapped, xpoints is on the y-axis: interpolate the limits
            # along y and compare each mesh row against its own limits
            ox_interp = np.interp(y_mesh, xpoints, y_oxidation)
            red_interp = np.interp(y_mesh, xpoints, y_reduction)
            lower = np.minimum(ox_interp, red_interp)[:, np.newaxis]
            upper = np.maximum(ox_interp, red_interp)[:, np.newaxis]
            # Unstable where x < min or x > max
            unstable = (X < lower) | (X > upper)
        else:
            # Normal: xpoints on x-axis, y values on y-axis; compare each
            # mesh column against its own limits
            ox_interp = np.interp(x_mesh, xpoints, y_oxidation)
            red_interp = np.interp(x_mesh, xpoints, y_reduction)
            lower = np.minimum(ox_interp, red_interp)[np.newaxis, :]
            upper = np.maximum(ox_interp, red_interp)[np.newaxis, :]
            # Unstable where y < min or y > max
            unstable = (Y < lower) | (Y > upper)

        # Create masked array for unstable regions (stable cells are masked out)
        unstable_mask = ma.masked_where(~unstable, np.ones_like(X))

        # Draw the shading with gray (matching R's gray80 = 0.8)
        fill_na_cmap = ListedColormap(['0.8'])
        extent = [xlim[0], xlim[1], ylim[0], ylim[1]]
        ax.imshow(unstable_mask, aspect='auto', origin='lower',
                  extent=extent, interpolation='nearest',
                  cmap=fill_na_cmap, vmin=0, vmax=1, zorder=1)

    # Set line color
    if col is None:
        col = 'black'

    # Convert numeric line style to matplotlib style
    lty_map = {1: '-', 2: '--', 3: '-.', 4: ':', 5: '-', 6: '--'}
    if isinstance(lty, int):
        lty = lty_map.get(lty, '--')

    if swapped:
        if nvar1 == 1 or nvar2 == 2:
            # Add vertical lines on 1-D diagram
            if y_oxidation is not None and len(y_oxidation) > 0:
                ax.axvline(x=y_oxidation[0], linestyle=lty, linewidth=lwd, color=col)
            if y_reduction is not None and len(y_reduction) > 0:
                ax.axvline(x=y_reduction[0], linestyle=lty, linewidth=lwd, color=col)
        else:
            # xpoints above is really the ypoints
            if y_oxidation is not None:
                ax.plot(y_oxidation, xpoints, linestyle=lty, linewidth=lwd, color=col)
            if y_reduction is not None:
                ax.plot(y_reduction, xpoints, linestyle=lty, linewidth=lwd, color=col)
    else:
        if y_oxidation is not None:
            ax.plot(xpoints, y_oxidation, linestyle=lty, linewidth=lwd, color=col)
        if y_reduction is not None:
            ax.plot(xpoints, y_reduction, linestyle=lty, linewidth=lwd, color=col)

    # Update the figure and axes references in result to reflect the water lines
    fig = ax.get_figure()
    result['fig'] = fig
    result['ax'] = ax

    # Display the figure if plot_it=True
    # This allows water_lines() to display a figure that was created with plot_it=False
    if plot_it and fig is not None:
        try:
            from IPython.display import display
            display(fig)
        except (ImportError, NameError):
            # Not in IPython/Jupyter, matplotlib will handle display
            pass

    # Update result with water line data and return
    result.update({'xpoints': xpoints, 'y_oxidation': y_oxidation, 'y_reduction': y_reduction, 'swapped': swapped})
    return result
|
|
2007
|
+
|
|
2008
|
+
|
|
2009
|
+
def _water_lines_plotly(eout: Dict[str, Any],
                        xpoints: np.ndarray,
                        y_oxidation: Optional[np.ndarray],
                        y_reduction: Optional[np.ndarray],
                        swapped: bool,
                        lty: Union[int, str],
                        lwd: float,
                        col: Optional[str],
                        plot_it: bool) -> Dict[str, Any]:
    """
    Add water stability lines to a Plotly interactive diagram.

    Gray shading for the water-unstable regions is added first (only when
    both limits are supplied), followed by one line trace per supplied limit.
    The figure stored in ``eout['fig']`` is modified in place.

    Parameters
    ----------
    eout : dict
        Diagram output holding the Plotly figure under the key ``'fig'``.
    xpoints : np.ndarray
        Coordinates along the independent variable at which the limits
        were evaluated.
    y_oxidation : np.ndarray or None
        Oxidation-limit values at ``xpoints``; ``None`` omits that line.
    y_reduction : np.ndarray or None
        Reduction-limit values at ``xpoints``; ``None`` omits that line.
    swapped : bool
        If True, ``xpoints`` is drawn on the y-axis and the limit values on
        the x-axis; otherwise ``xpoints`` is on the x-axis.
    lty : int or str
        Numeric or matplotlib-style line type; anything not recognised is
        drawn with the Plotly ``'dash'`` style.
    lwd : float
        Line width passed to Plotly.
    col : str or None
        Line color; ``None`` selects black.
    plot_it : bool
        Whether to display the figure after the traces have been added.

    Returns
    -------
    dict
        The same ``eout`` object, with ``'fig'`` refreshed and the keys
        ``'xpoints'``, ``'y_oxidation'``, ``'y_reduction'`` and ``'swapped'``
        added.
    """
    import plotly.graph_objects as go

    # Get the Plotly figure from eout
    fig = eout['fig']

    # Set line color (default to black)
    if col is None:
        col = 'black'

    # Convert numeric/matplotlib line styles to Plotly dash types
    lty_map = {
        1: 'solid', '-': 'solid',
        2: 'dash', '--': 'dash',
        3: 'dashdot', '-.': 'dashdot',
        4: 'dot', ':': 'dot',
        5: 'solid', 6: 'dash'
    }
    dash_style = lty_map.get(lty, 'dash')

    def _axis_limits(axis, fallback):
        # Prefer the range stored on the figure; otherwise use the data extent
        return axis.range if axis.range else fallback

    def _shade(xs, ys):
        # One closed, borderless gray polygon covering an unstable region
        fig.add_trace(go.Scatter(
            x=xs,
            y=ys,
            fill='toself',
            fillcolor='rgba(128, 128, 128, 0.5)',  # Gray with transparency
            line=dict(width=0),
            showlegend=False,
            hoverinfo='skip'
        ))

    def _limit_line(values, label):
        # Draw a single stability limit, honoring the axis orientation
        if values is None:
            return
        xs, ys = (values, xpoints) if swapped else (xpoints, values)
        fig.add_trace(go.Scatter(
            x=xs,
            y=ys,
            mode='lines',
            line=dict(color=col, width=lwd, dash=dash_style),
            name=label,
            showlegend=False,
            hoverinfo='skip'
        ))

    # Add gray shading for water-unstable regions
    if y_oxidation is not None and y_reduction is not None:
        # High-resolution mesh for smooth shading
        n_points = 200

        # np.interp needs increasing sample coordinates; a stable sort leaves
        # already-increasing input untouched
        order = np.argsort(np.asarray(xpoints), kind='stable')
        grid = np.asarray(xpoints)[order]
        ox = np.asarray(y_oxidation)[order]
        red = np.asarray(y_reduction)[order]

        # Data extents along the independent variable and the limit values
        grid_extent = [grid.min(), grid.max()]
        limit_extent = [min(ox.min(), red.min()), max(ox.max(), red.max())]

        if swapped:
            # xpoints is on the y-axis, limit values on the x-axis, so the
            # fallback extents swap roles as well
            xlim = _axis_limits(fig.layout.xaxis, limit_extent)
            ylim = _axis_limits(fig.layout.yaxis, grid_extent)

            y_mesh = np.linspace(ylim[0], ylim[1], n_points)
            x_ox_interp = np.interp(y_mesh, grid, ox)
            x_red_interp = np.interp(y_mesh, grid, red)
            y_loop = np.concatenate([y_mesh, y_mesh[::-1]])

            # Oxidation side: up the line, then back down the right edge.
            # Exactly two segments; retracing the line would add a diagonal
            # edge and leave part of the region unfilled.
            _shade(np.concatenate([x_ox_interp, np.full(n_points, xlim[1])]),
                   y_loop)

            # Reduction side: up the left edge, then back down the line
            _shade(np.concatenate([np.full(n_points, xlim[0]), x_red_interp[::-1]]),
                   y_loop)
        else:
            # Normal: xpoints on x-axis, limit values on y-axis
            xlim = _axis_limits(fig.layout.xaxis, grid_extent)
            ylim = _axis_limits(fig.layout.yaxis, limit_extent)

            x_mesh = np.linspace(xlim[0], xlim[1], n_points)
            y_ox_interp = np.interp(x_mesh, grid, ox)
            y_red_interp = np.interp(x_mesh, grid, red)
            x_loop = np.concatenate([x_mesh, x_mesh[::-1]])

            # Upper unstable region (above oxidation line)
            _shade(x_loop,
                   np.concatenate([y_ox_interp, np.full(n_points, ylim[1])]))

            # Lower unstable region (below reduction line)
            _shade(x_loop,
                   np.concatenate([np.full(n_points, ylim[0]), y_red_interp[::-1]]))

    # Add water stability lines (drawn after the shading so they sit on top)
    _limit_line(y_oxidation, 'H₂O oxidation limit')
    _limit_line(y_reduction, 'H₂O reduction limit')

    # Update the figure reference in eout to reflect the water lines
    eout['fig'] = fig

    # Display the figure if plot_it=True
    if plot_it:
        try:
            from IPython.display import display
            display(fig)
        except (ImportError, NameError):
            # Not in IPython/Jupyter, use fig.show()
            fig.show()

    # Update eout with water line data and return
    eout.update({'xpoints': xpoints, 'y_oxidation': y_oxidation,
                 'y_reduction': y_reduction, 'swapped': swapped})
    return eout
|
|
2191
|
+
|
|
2192
|
+
|
|
2193
|
+
def find_tp(predominant: np.ndarray) -> np.ndarray:
    """
    Find triple points in a predominance diagram.

    This function identifies the approximate positions of triple points
    (where three phases meet) in a 2-D predominance diagram by locating
    cells with the greatest number of different neighboring values.

    Parameters
    ----------
    predominant : np.ndarray
        Matrix of integers from diagram() output indicating which species
        predominates at each point. Should be a 2-D array (or array-like)
        where each value represents a different species/phase.

    Returns
    -------
    np.ndarray
        Array of shape (n, 2) where n is the number of triple points found.
        Each row contains [row_index, col_index] of a triple point location.
        Indices are 1-based to match R behavior. When the matrix holds no
        positive values the result is an empty integer array of shape (0, 2).

    Examples
    --------
    >>> from pychnosz import *
    >>> reset()
    >>> basis(["corundum", "quartz", "oxygen"])
    >>> species(["kyanite", "sillimanite", "andalusite"])
    >>> a = affinity(T=[200, 900, 99], P=[0, 9000, 101], exceed_Ttr=True)
    >>> d = diagram(a)
    >>> tp = find_tp(d['predominant'])
    >>> # Get T and P at the triple point
    >>> Ttp = a['vals'][0][tp[0, 1] - 1]  # -1 for 0-based indexing
    >>> Ptp = a['vals'][1][::-1][tp[0, 0] - 1]  # reversed and -1

    Notes
    -----
    This is a Python translation of the R function find.tp() from CHNOSZ.
    The R version returns 1-based indices, and this Python version does too
    for consistency. When using these indices to access Python arrays,
    remember to subtract 1.

    The function works by:
    1. Rearranging the matrix as done by diagram() for plotting
    2. For each position, examining a 3x3 neighborhood
    3. Counting the number of unique values in that neighborhood
    4. Returning positions with the maximum count (typically 3 or more)
    """
    # Rearrange the matrix in the same way that diagram() does for 2-D predominance diagrams
    # R code: x <- t(x[, ncol(x):1])
    # This means: first reverse columns, then transpose
    x = np.transpose(np.asarray(predominant)[:, ::-1])

    # Get all positions with valid values (> 0), in row-major order
    valid_positions = np.argwhere(x > 0)

    if len(valid_positions) == 0:
        # Keep the documented (n, 2) layout so column indexing still works
        return np.empty((0, 2), dtype=np.intp)

    # Count unique values in the 3x3 neighborhood of every valid position.
    # Only the lower bound needs clamping: a negative start would wrap around,
    # whereas slice ends past the array edge are clipped automatically.
    counts = np.array([
        len(np.unique(x[max(row - 1, 0):row + 2, max(col - 1, 0):col + 2]))
        for row, col in valid_positions
    ])

    # Keep the positions with the maximum count and convert to 1-based
    # indexing (to match R)
    return valid_positions[counts == counts.max()] + 1
|
|
2281
|
+
|
|
2282
|
+
|
|
2283
|
+
def diagram_interactive(eout: Dict[str, Any],
|
|
2284
|
+
type: str = "auto",
|
|
2285
|
+
main: Optional[str] = None,
|
|
2286
|
+
borders: Union[float, str] = 0,
|
|
2287
|
+
names: Optional[List[str]] = None,
|
|
2288
|
+
format_names: bool = True,
|
|
2289
|
+
annotation: Optional[str] = None,
|
|
2290
|
+
annotation_coords: List[float] = [0, 0],
|
|
2291
|
+
balance: Optional[Union[str, float, List[float]]] = None,
|
|
2292
|
+
xlab: Optional[str] = None,
|
|
2293
|
+
ylab: Optional[str] = None,
|
|
2294
|
+
fill: Optional[Union[str, List[str]]] = "viridis",
|
|
2295
|
+
width: int = 600,
|
|
2296
|
+
height: int = 520,
|
|
2297
|
+
alpha: Union[bool, str] = False,
|
|
2298
|
+
add: bool = False,
|
|
2299
|
+
ax: Optional[Any] = None,
|
|
2300
|
+
col: Optional[Union[str, List[str]]] = None,
|
|
2301
|
+
lty: Optional[Union[str, int, List]] = None,
|
|
2302
|
+
lwd: Union[float, List[float]] = 1,
|
|
2303
|
+
cex: Union[float, List[float]] = 1.0,
|
|
2304
|
+
contour_method: Optional[Union[str, List[str]]] = "edge",
|
|
2305
|
+
messages: bool = True,
|
|
2306
|
+
plot_it: bool = True,
|
|
2307
|
+
save_as: Optional[str] = None,
|
|
2308
|
+
save_format: Optional[str] = None,
|
|
2309
|
+
save_scale: float = 1) -> Tuple[pd.DataFrame, Any]:
|
|
2310
|
+
"""
|
|
2311
|
+
Create an interactive diagram using Plotly.
|
|
2312
|
+
|
|
2313
|
+
This function produces interactive versions of the diagrams created by diagram(),
|
|
2314
|
+
using Plotly for interactivity. It accepts output from affinity() or equilibrate()
|
|
2315
|
+
and creates either 1D line plots or 2D predominance diagrams.
|
|
2316
|
+
|
|
2317
|
+
Parameters
|
|
2318
|
+
----------
|
|
2319
|
+
eout : dict
|
|
2320
|
+
Output from affinity() or equilibrate().
|
|
2321
|
+
main : str, optional
|
|
2322
|
+
Title of the plot.
|
|
2323
|
+
borders : float or str, default 0
|
|
2324
|
+
Controls boundary lines between regions in 2D predominance diagrams.
|
|
2325
|
+
- If numeric > 0: draws grid-aligned borders with specified thickness (pixels)
|
|
2326
|
+
- If "contour": draws smooth contour-based boundaries (like diagram())
|
|
2327
|
+
- If 0 or None: no borders drawn
|
|
2328
|
+
names : list of str, optional
|
|
2329
|
+
Names of species for activity lines or predominance fields.
|
|
2330
|
+
format_names : bool, default True
|
|
2331
|
+
Apply formatting to chemical formulas?
|
|
2332
|
+
annotation : str, optional
|
|
2333
|
+
Annotation to add to the plot.
|
|
2334
|
+
annotation_coords : list of float, default [0, 0]
|
|
2335
|
+
Coordinates of annotation, where 0,0 is bottom left and 1,1 is top right.
|
|
2336
|
+
balance : str or numeric, optional
|
|
2337
|
+
How to balance the transformations.
|
|
2338
|
+
xlab : str, optional
|
|
2339
|
+
Custom x-axis label.
|
|
2340
|
+
ylab : str, optional
|
|
2341
|
+
Custom y-axis label.
|
|
2342
|
+
fill : str or list of str, default "viridis"
|
|
2343
|
+
For 2D diagrams: colormap name (e.g., "viridis", "hot") or list of colors.
|
|
2344
|
+
For 1D diagrams: list of line colors.
|
|
2345
|
+
width : int, default 600
|
|
2346
|
+
Width of the plot in pixels.
|
|
2347
|
+
height : int, default 520
|
|
2348
|
+
Height of the plot in pixels.
|
|
2349
|
+
alpha : bool or str, default False
|
|
2350
|
+
For speciation diagrams, plot degree of formation instead of activities?
|
|
2351
|
+
If True, plots mole fractions. If "balance", scales by stoichiometry.
|
|
2352
|
+
messages : bool, default True
|
|
2353
|
+
Display messages?
|
|
2354
|
+
plot_it : bool, default True
|
|
2355
|
+
Show the plot?
|
|
2356
|
+
save_as : str, optional
|
|
2357
|
+
Provide a filename to save this figure. Filetype of saved figure is
|
|
2358
|
+
determined by save_format.
|
|
2359
|
+
save_format : str, default "png"
|
|
2360
|
+
Desired format of saved or downloaded figure. Can be 'png', 'jpg', 'jpeg',
|
|
2361
|
+
'webp', 'svg', 'pdf', 'eps', 'json', or 'html'. If 'html', an interactive
|
|
2362
|
+
plot will be saved.
|
|
2363
|
+
save_scale : float, default 1
|
|
2364
|
+
Multiply title/legend/axis/canvas sizes by this factor when saving.
|
|
2365
|
+
|
|
2366
|
+
Returns
|
|
2367
|
+
-------
|
|
2368
|
+
tuple
|
|
2369
|
+
(df, fig) where df is a pandas DataFrame with the data and fig is the
|
|
2370
|
+
Plotly figure object.
|
|
2371
|
+
|
|
2372
|
+
Examples
|
|
2373
|
+
--------
|
|
2374
|
+
1D diagram:
|
|
2375
|
+
>>> basis("CHNOS+")
|
|
2376
|
+
>>> species(info(["glycinium", "glycine", "glycinate"]))
|
|
2377
|
+
>>> a = affinity(pH=[0, 14])
|
|
2378
|
+
>>> e = equilibrate(a)
|
|
2379
|
+
>>> diagram_interactive(e, alpha=True)
|
|
2380
|
+
|
|
2381
|
+
2D diagram:
|
|
2382
|
+
>>> basis(["Fe", "oxygen", "S2"])
|
|
2383
|
+
>>> species(["iron", "ferrous-oxide", "magnetite", "hematite", "pyrite", "pyrrhotite"])
|
|
2384
|
+
>>> a = affinity(S2=[-50, 0], O2=[-90, -10], T=200)
|
|
2385
|
+
>>> diagram_interactive(a, fill="hot")
|
|
2386
|
+
|
|
2387
|
+
Notes
|
|
2388
|
+
-----
|
|
2389
|
+
This function requires plotly to be installed. Install with:
|
|
2390
|
+
pip install plotly
|
|
2391
|
+
|
|
2392
|
+
The function adapts the pyCHNOSZ diagram_interactive() implementation
|
|
2393
|
+
to work with Python CHNOSZ's native data structures.
|
|
2394
|
+
"""
|
|
2395
|
+
|
|
2396
|
+
# Import plotly (lazy import to avoid dependency issues)
|
|
2397
|
+
try:
|
|
2398
|
+
import plotly.express as px
|
|
2399
|
+
import plotly.graph_objects as go
|
|
2400
|
+
import plotly.io as pio
|
|
2401
|
+
except ImportError:
|
|
2402
|
+
raise ImportError("diagram_interactive() requires plotly. Install with: pip install plotly")
|
|
2403
|
+
|
|
2404
|
+
# Check that eout is valid
|
|
2405
|
+
efun = eout.get('fun', '')
|
|
2406
|
+
if efun not in ['affinity', 'equilibrate', 'solubility']:
|
|
2407
|
+
raise ValueError("'eout' is not the output from affinity(), equilibrate(), or solubility()")
|
|
2408
|
+
|
|
2409
|
+
# Determine if this is affinity or equilibrate output
|
|
2410
|
+
calc_type = "a" if ('loga_equil' not in eout and 'loga.equil' not in eout) else "e"
|
|
2411
|
+
|
|
2412
|
+
# Get basis species and their states
|
|
2413
|
+
basis_df = eout['basis']
|
|
2414
|
+
basis_sp = list(basis_df.index)
|
|
2415
|
+
basis_state = list(basis_df['state'])
|
|
2416
|
+
|
|
2417
|
+
# Get variable names and values
|
|
2418
|
+
xyvars = eout['vars']
|
|
2419
|
+
xyvals_dict = eout['vals']
|
|
2420
|
+
# Convert vals dict to list format for easier access
|
|
2421
|
+
xyvals = [xyvals_dict[var] for var in xyvars]
|
|
2422
|
+
|
|
2423
|
+
# Determine balance if not provided
|
|
2424
|
+
if balance is None or balance == "":
|
|
2425
|
+
# For saturation diagrams, use balance=1 (formula units) to match R behavior
|
|
2426
|
+
# This avoids issues when minerals don't have a common basis element
|
|
2427
|
+
if type == "saturation":
|
|
2428
|
+
balance = 1
|
|
2429
|
+
n_balance = [1] * len(eout['values'])
|
|
2430
|
+
else:
|
|
2431
|
+
# Call diagram with plot_it=False to get balance
|
|
2432
|
+
# Need to import matplotlib to close the figure afterward
|
|
2433
|
+
import matplotlib.pyplot as plt_temp
|
|
2434
|
+
temp_result = diagram(eout, messages=False, plot_it=False)
|
|
2435
|
+
balance = temp_result.get('balance', 1)
|
|
2436
|
+
n_balance = temp_result.get('n_balance', [1])
|
|
2437
|
+
# Close the matplotlib figure created by diagram() since we don't need it
|
|
2438
|
+
if 'fig' in temp_result and temp_result['fig'] is not None:
|
|
2439
|
+
plt_temp.close(temp_result['fig'])
|
|
2440
|
+
else:
|
|
2441
|
+
# Calculate n_balance from balance
|
|
2442
|
+
try:
|
|
2443
|
+
balance_float = float(balance)
|
|
2444
|
+
n_balance = [balance_float] * len(eout['values'])
|
|
2445
|
+
except (ValueError, TypeError):
|
|
2446
|
+
# balance is a string (element name)
|
|
2447
|
+
# Get species from eout instead of global state
|
|
2448
|
+
if 'species' in eout and eout['species'] is not None:
|
|
2449
|
+
sp_df = eout['species']
|
|
2450
|
+
else:
|
|
2451
|
+
# Fallback to global species if not in eout
|
|
2452
|
+
from .species import species as species_func
|
|
2453
|
+
sp_df = species_func()
|
|
2454
|
+
|
|
2455
|
+
# Check if balance is a list (user-provided values) or a string (column name)
|
|
2456
|
+
if isinstance(balance, list):
|
|
2457
|
+
n_balance = balance
|
|
2458
|
+
elif balance in sp_df.columns:
|
|
2459
|
+
n_balance = list(sp_df[balance])
|
|
2460
|
+
else:
|
|
2461
|
+
n_balance = [1] * len(eout['values'])
|
|
2462
|
+
|
|
2463
|
+
# Get output values
|
|
2464
|
+
if calc_type == "a":
|
|
2465
|
+
# handling output of affinity()
|
|
2466
|
+
out_vals = eout['values']
|
|
2467
|
+
out_units = "A/(2.303RT)"
|
|
2468
|
+
else:
|
|
2469
|
+
# handling output of equilibrate()
|
|
2470
|
+
loga_equil_key = 'loga_equil' if 'loga_equil' in eout else 'loga.equil'
|
|
2471
|
+
out_vals = eout[loga_equil_key]
|
|
2472
|
+
out_units = "log a"
|
|
2473
|
+
|
|
2474
|
+
# Convert values to a list format
|
|
2475
|
+
if isinstance(out_vals, dict):
|
|
2476
|
+
nsp = len(out_vals)
|
|
2477
|
+
values_list = list(out_vals.values())
|
|
2478
|
+
species_indices = list(out_vals.keys())
|
|
2479
|
+
else:
|
|
2480
|
+
nsp = len(out_vals)
|
|
2481
|
+
values_list = out_vals
|
|
2482
|
+
species_indices = eout['species']['ispecies'].tolist()
|
|
2483
|
+
|
|
2484
|
+
# Get species names
|
|
2485
|
+
from .info import info as info_func
|
|
2486
|
+
# Convert numpy types to Python types
|
|
2487
|
+
species_indices_py = [int(idx) for idx in species_indices]
|
|
2488
|
+
sp_info = info_func(species_indices_py, messages=False)
|
|
2489
|
+
sp_names = sp_info['name'].tolist()
|
|
2490
|
+
|
|
2491
|
+
# Use custom names if provided
|
|
2492
|
+
if isinstance(names, list) and len(names) == len(sp_names):
|
|
2493
|
+
sp_names = names
|
|
2494
|
+
|
|
2495
|
+
# Determine dimensions
|
|
2496
|
+
first_val = values_list[0]
|
|
2497
|
+
if hasattr(first_val, 'shape'):
|
|
2498
|
+
nd = len(first_val.shape)
|
|
2499
|
+
else:
|
|
2500
|
+
nd = 1 if hasattr(first_val, '__len__') else 0
|
|
2501
|
+
|
|
2502
|
+
# Handle type="saturation" - plot contour lines where affinity=0
|
|
2503
|
+
if type == "saturation":
|
|
2504
|
+
if nd != 2:
|
|
2505
|
+
raise ValueError("type='saturation' requires 2-D diagram")
|
|
2506
|
+
if calc_type != "a":
|
|
2507
|
+
raise ValueError("type='saturation' requires output from affinity(), not equilibrate()")
|
|
2508
|
+
|
|
2509
|
+
# Delegate to saturation plotting function
|
|
2510
|
+
return _plot_saturation_interactive(
|
|
2511
|
+
eout, values_list, sp_names, xyvars, xyvals,
|
|
2512
|
+
xlab, ylab, col, lwd, lty, cex, contour_method,
|
|
2513
|
+
main, add, ax, width, height, plot_it,
|
|
2514
|
+
save_as, save_format, save_scale, messages
|
|
2515
|
+
)
|
|
2516
|
+
|
|
2517
|
+
# Build DataFrame
|
|
2518
|
+
if nd == 2:
|
|
2519
|
+
# 2D case - flatten the data
|
|
2520
|
+
xvals = xyvals[0]
|
|
2521
|
+
yvals = xyvals[1]
|
|
2522
|
+
xvar = xyvars[0]
|
|
2523
|
+
yvar = xyvars[1]
|
|
2524
|
+
|
|
2525
|
+
# Flatten the data - transpose first so coordinates match
|
|
2526
|
+
# Original shape is (nx, ny) where nx=len(xvals), ny=len(yvals)
|
|
2527
|
+
# After transpose, shape is (ny, nx)
|
|
2528
|
+
# Flattening with C-order then gives: [row0, row1, ...] = [x-values at y[0], x-values at y[1], ...]
|
|
2529
|
+
flat_out_vals = []
|
|
2530
|
+
for v in values_list:
|
|
2531
|
+
# Transpose then flatten so coordinates align correctly
|
|
2532
|
+
flat_out_vals.append(v.T.flatten())
|
|
2533
|
+
df = pd.DataFrame(flat_out_vals, index=sp_names).T
|
|
2534
|
+
|
|
2535
|
+
# Apply balance if needed
|
|
2536
|
+
if calc_type == "a":
|
|
2537
|
+
if isinstance(balance, str):
|
|
2538
|
+
# Get balance from species dataframe
|
|
2539
|
+
# Get species from eout instead of global state
|
|
2540
|
+
if 'species' in eout and eout['species'] is not None:
|
|
2541
|
+
sp_df = eout['species']
|
|
2542
|
+
else:
|
|
2543
|
+
# Fallback to global species if not in eout
|
|
2544
|
+
from .species import species as species_func
|
|
2545
|
+
sp_df = species_func()
|
|
2546
|
+
|
|
2547
|
+
# Check if balance is a list (user-provided values) or a string (column name)
|
|
2548
|
+
if isinstance(balance, list):
|
|
2549
|
+
n_balance = balance
|
|
2550
|
+
elif balance in sp_df.columns:
|
|
2551
|
+
n_balance = list(sp_df[balance])
|
|
2552
|
+
# Divide by balance
|
|
2553
|
+
for i, sp in enumerate(sp_names):
|
|
2554
|
+
df[sp] = df[sp] / n_balance[i]
|
|
2555
|
+
|
|
2556
|
+
# Find predominant species
|
|
2557
|
+
df["pred"] = df.idxmax(axis=1, skipna=True)
|
|
2558
|
+
df["prednames"] = df["pred"]
|
|
2559
|
+
|
|
2560
|
+
# Add x and y coordinates
|
|
2561
|
+
# After transpose and flatten, data is ordered as:
|
|
2562
|
+
# [x0,y0], [x1,y0], ..., [xn,y0], [x0,y1], [x1,y1], ...
|
|
2563
|
+
xvals_full = list(xvals) * len(yvals)
|
|
2564
|
+
yvals_full = []
|
|
2565
|
+
for y in yvals:
|
|
2566
|
+
yvals_full.extend([y] * len(xvals))
|
|
2567
|
+
df[xvar] = xvals_full
|
|
2568
|
+
df[yvar] = yvals_full
|
|
2569
|
+
|
|
2570
|
+
else:
|
|
2571
|
+
# 1D case
|
|
2572
|
+
xvar = xyvars[0]
|
|
2573
|
+
xvals = xyvals[0]
|
|
2574
|
+
|
|
2575
|
+
flat_out_vals = []
|
|
2576
|
+
for v in values_list:
|
|
2577
|
+
flat_out_vals.append(v)
|
|
2578
|
+
df = pd.DataFrame(flat_out_vals, index=sp_names).T
|
|
2579
|
+
|
|
2580
|
+
# Apply balance if needed
|
|
2581
|
+
if calc_type == "a":
|
|
2582
|
+
if isinstance(balance, str):
|
|
2583
|
+
# Get species from eout instead of global state
|
|
2584
|
+
if 'species' in eout and eout['species'] is not None:
|
|
2585
|
+
sp_df = eout['species']
|
|
2586
|
+
else:
|
|
2587
|
+
# Fallback to global species if not in eout
|
|
2588
|
+
from .species import species as species_func
|
|
2589
|
+
sp_df = species_func()
|
|
2590
|
+
|
|
2591
|
+
# Check if balance is a list (user-provided values) or a string (column name)
|
|
2592
|
+
if isinstance(balance, list):
|
|
2593
|
+
n_balance = balance
|
|
2594
|
+
elif balance in sp_df.columns:
|
|
2595
|
+
n_balance = list(sp_df[balance])
|
|
2596
|
+
# Divide by balance
|
|
2597
|
+
for i, sp in enumerate(sp_names):
|
|
2598
|
+
df[sp] = df[sp] / n_balance[i]
|
|
2599
|
+
|
|
2600
|
+
# Handle alpha (degree of formation)
|
|
2601
|
+
if alpha:
|
|
2602
|
+
df = df.apply(lambda x: 10**x)
|
|
2603
|
+
df = df[sp_names].div(df[sp_names].sum(axis=1), axis=0)
|
|
2604
|
+
|
|
2605
|
+
df[xvar] = xvals
|
|
2606
|
+
|
|
2607
|
+
# Create axis labels
|
|
2608
|
+
unit_dict = {"P": "bar", "T": "°C", "pH": "", "Eh": "volts", "IS": "mol/kg"}
|
|
2609
|
+
|
|
2610
|
+
for i, s in enumerate(basis_sp):
|
|
2611
|
+
if basis_state[i] in ["aq", "liq", "cr"]:
|
|
2612
|
+
if format_names:
|
|
2613
|
+
unit_dict[s] = f"log <i>a</i><sub>{_format_html_species(s)}</sub>"
|
|
2614
|
+
else:
|
|
2615
|
+
unit_dict[s] = f"log <i>a</i><sub>{s}</sub>"
|
|
2616
|
+
else:
|
|
2617
|
+
if format_names:
|
|
2618
|
+
unit_dict[s] = f"log <i>f</i><sub>{_format_html_species(s)}</sub>"
|
|
2619
|
+
else:
|
|
2620
|
+
unit_dict[s] = f"log <i>f</i><sub>{s}</sub>"
|
|
2621
|
+
|
|
2622
|
+
# Set x-axis label
|
|
2623
|
+
if not isinstance(xlab, str):
|
|
2624
|
+
xlab = xvar + ", " + unit_dict.get(xvar, "")
|
|
2625
|
+
if xvar == "pH":
|
|
2626
|
+
xlab = "pH"
|
|
2627
|
+
if xvar in basis_sp:
|
|
2628
|
+
xlab = unit_dict[xvar]
|
|
2629
|
+
|
|
2630
|
+
# Create the plot
|
|
2631
|
+
if nd == 1:
|
|
2632
|
+
# 1D plot
|
|
2633
|
+
# Melt the dataframe for plotting
|
|
2634
|
+
df_melted = pd.melt(df, id_vars=[xvar], value_vars=sp_names, var_name='variable', value_name='value')
|
|
2635
|
+
|
|
2636
|
+
# Format species names if requested
|
|
2637
|
+
if format_names:
|
|
2638
|
+
df_melted['variable'] = df_melted['variable'].apply(_format_html_species)
|
|
2639
|
+
|
|
2640
|
+
# Set y-axis label
|
|
2641
|
+
if not isinstance(ylab, str):
|
|
2642
|
+
if alpha:
|
|
2643
|
+
ylab = "alpha"
|
|
2644
|
+
else:
|
|
2645
|
+
ylab = out_units
|
|
2646
|
+
|
|
2647
|
+
fig = px.line(df_melted, x=xvar, y="value", color='variable',
|
|
2648
|
+
template="simple_white", width=width, height=height,
|
|
2649
|
+
labels={'value': ylab, xvar: xlab},
|
|
2650
|
+
render_mode='svg')
|
|
2651
|
+
|
|
2652
|
+
# Apply custom colors if provided
|
|
2653
|
+
if isinstance(fill, list):
|
|
2654
|
+
for i, color in enumerate(fill):
|
|
2655
|
+
if i < len(fig.data):
|
|
2656
|
+
fig.data[i].line.color = color
|
|
2657
|
+
|
|
2658
|
+
# Check for LaTeX format in axis labels
|
|
2659
|
+
if xlab and _detect_latex_format(xlab):
|
|
2660
|
+
warnings.warn(
|
|
2661
|
+
"LaTeX formatting detected in 'xlab' parameter. "
|
|
2662
|
+
"Plotly requires HTML format (<sub>, <sup>) instead of LaTeX ($, _, ^). "
|
|
2663
|
+
"For activity ratios, use ratlab_html() instead of ratlab().",
|
|
2664
|
+
UserWarning
|
|
2665
|
+
)
|
|
2666
|
+
if ylab and _detect_latex_format(ylab):
|
|
2667
|
+
warnings.warn(
|
|
2668
|
+
"LaTeX formatting detected in 'ylab' parameter. "
|
|
2669
|
+
"Plotly requires HTML format (<sub>, <sup>) instead of LaTeX ($, _, ^). "
|
|
2670
|
+
"For activity ratios, use ratlab_html() instead of ratlab().",
|
|
2671
|
+
UserWarning
|
|
2672
|
+
)
|
|
2673
|
+
|
|
2674
|
+
fig.update_layout(xaxis_title=xlab,
|
|
2675
|
+
yaxis_title=ylab,
|
|
2676
|
+
legend_title=None)
|
|
2677
|
+
|
|
2678
|
+
if isinstance(main, str):
|
|
2679
|
+
fig.update_layout(title={'text': main, 'x': 0.5, 'xanchor': 'center'})
|
|
2680
|
+
|
|
2681
|
+
if isinstance(annotation, str):
|
|
2682
|
+
# Check for LaTeX format and warn user
|
|
2683
|
+
if _detect_latex_format(annotation):
|
|
2684
|
+
warnings.warn(
|
|
2685
|
+
"LaTeX formatting detected in 'annotation' parameter. "
|
|
2686
|
+
"Plotly requires HTML format (<sub>, <sup>) instead of LaTeX ($, _, ^). "
|
|
2687
|
+
"For activity ratios, use ratlab_html() instead of ratlab().",
|
|
2688
|
+
UserWarning
|
|
2689
|
+
)
|
|
2690
|
+
|
|
2691
|
+
fig.add_annotation(
|
|
2692
|
+
x=annotation_coords[0],
|
|
2693
|
+
y=annotation_coords[1],
|
|
2694
|
+
text=annotation,
|
|
2695
|
+
showarrow=False,
|
|
2696
|
+
xref="paper",
|
|
2697
|
+
yref="paper",
|
|
2698
|
+
align='left',
|
|
2699
|
+
bgcolor="rgba(255, 255, 255, 0.5)")
|
|
2700
|
+
|
|
2701
|
+
# Configure download button
|
|
2702
|
+
save_as_name, save_format_final = _save_figure(fig, save_as, save_format, save_scale,
|
|
2703
|
+
plot_width=width, plot_height=height, ppi=1)
|
|
2704
|
+
|
|
2705
|
+
config = {'displaylogo': False,
|
|
2706
|
+
'modeBarButtonsToRemove': ['resetScale2d', 'toggleSpikelines'],
|
|
2707
|
+
'toImageButtonOptions': {
|
|
2708
|
+
'format': save_format_final,
|
|
2709
|
+
'filename': save_as_name,
|
|
2710
|
+
'height': height,
|
|
2711
|
+
'width': width,
|
|
2712
|
+
'scale': save_scale,
|
|
2713
|
+
}}
|
|
2714
|
+
|
|
2715
|
+
# Store config on figure so it persists when fig.show() is called later
|
|
2716
|
+
fig._config = fig._config | config
|
|
2717
|
+
|
|
2718
|
+
else:
|
|
2719
|
+
# 2D plot
|
|
2720
|
+
# Map species names to numeric values
|
|
2721
|
+
mappings = {s: lab for s, lab in zip(sp_names, range(len(sp_names)))}
|
|
2722
|
+
df['pred'] = df['pred'].map(mappings).astype(int)
|
|
2723
|
+
|
|
2724
|
+
# Reshape data
|
|
2725
|
+
# Data is flattened as [x0,y0], [x1,y0], ..., [xn,y0], [x0,y1], ...
|
|
2726
|
+
# Reshape to (ny, nx) for proper orientation in Plotly
|
|
2727
|
+
# Plotly expects data[i,j] to correspond to x[j], y[i]
|
|
2728
|
+
data = np.array(df['pred'])
|
|
2729
|
+
shape = (len(yvals), len(xvals))
|
|
2730
|
+
dmap = data.reshape(shape)
|
|
2731
|
+
|
|
2732
|
+
data_names = np.array(df['prednames'])
|
|
2733
|
+
dmap_names = data_names.reshape(shape)
|
|
2734
|
+
|
|
2735
|
+
# Set y-axis label
|
|
2736
|
+
if not isinstance(ylab, str):
|
|
2737
|
+
ylab = yvar + ", " + unit_dict.get(yvar, "")
|
|
2738
|
+
if yvar in basis_sp:
|
|
2739
|
+
ylab = unit_dict[yvar]
|
|
2740
|
+
if yvar == "pH":
|
|
2741
|
+
ylab = "pH"
|
|
2742
|
+
|
|
2743
|
+
# Check for LaTeX format in axis labels (2D plot)
|
|
2744
|
+
if xlab and _detect_latex_format(xlab):
|
|
2745
|
+
warnings.warn(
|
|
2746
|
+
"LaTeX formatting detected in 'xlab' parameter. "
|
|
2747
|
+
"Plotly requires HTML format (<sub>, <sup>) instead of LaTeX ($, _, ^). "
|
|
2748
|
+
"For activity ratios, use ratlab_html() instead of ratlab().",
|
|
2749
|
+
UserWarning
|
|
2750
|
+
)
|
|
2751
|
+
if ylab and _detect_latex_format(ylab):
|
|
2752
|
+
warnings.warn(
|
|
2753
|
+
"LaTeX formatting detected in 'ylab' parameter. "
|
|
2754
|
+
"Plotly requires HTML format (<sub>, <sup>) instead of LaTeX ($, _, ^). "
|
|
2755
|
+
"For activity ratios, use ratlab_html() instead of ratlab().",
|
|
2756
|
+
UserWarning
|
|
2757
|
+
)
|
|
2758
|
+
|
|
2759
|
+
# Create heatmap
|
|
2760
|
+
fig = px.imshow(dmap, width=width, height=height, aspect="auto",
|
|
2761
|
+
labels={'x': xlab, 'y': ylab, 'color': "region"},
|
|
2762
|
+
x=xvals, y=yvals, template="simple_white")
|
|
2763
|
+
|
|
2764
|
+
fig.update(data=[{'customdata': dmap_names,
|
|
2765
|
+
'hovertemplate': xlab + ': %{x}<br>' + ylab + ': %{y}<br>Region: %{customdata}<extra></extra>'}])
|
|
2766
|
+
|
|
2767
|
+
# Set colormap
|
|
2768
|
+
if fill == 'none':
|
|
2769
|
+
colormap = [[0, 'white'], [1, 'white']]
|
|
2770
|
+
elif isinstance(fill, list):
|
|
2771
|
+
colmap_temp = []
|
|
2772
|
+
for i, v in enumerate(fill):
|
|
2773
|
+
colmap_temp.append([i / (len(fill) - 1) if len(fill) > 1 else 0, v])
|
|
2774
|
+
colormap = colmap_temp
|
|
2775
|
+
else:
|
|
2776
|
+
colormap = fill
|
|
2777
|
+
|
|
2778
|
+
fig.update_traces(dict(showscale=False,
|
|
2779
|
+
coloraxis=None,
|
|
2780
|
+
colorscale=colormap),
|
|
2781
|
+
selector={'type': 'heatmap'})
|
|
2782
|
+
|
|
2783
|
+
fig.update_yaxes(autorange=True)
|
|
2784
|
+
|
|
2785
|
+
if isinstance(main, str):
|
|
2786
|
+
fig.update_layout(title={'text': main, 'x': 0.5, 'xanchor': 'center'})
|
|
2787
|
+
|
|
2788
|
+
# Add species labels
|
|
2789
|
+
for s in sp_names:
|
|
2790
|
+
if s in set(df["prednames"]):
|
|
2791
|
+
df_s = df.loc[df["prednames"] == s]
|
|
2792
|
+
namex = df_s[xvar].mean()
|
|
2793
|
+
namey = df_s[yvar].mean()
|
|
2794
|
+
|
|
2795
|
+
if format_names:
|
|
2796
|
+
annot_text = _format_html_species(s)
|
|
2797
|
+
else:
|
|
2798
|
+
annot_text = str(s)
|
|
2799
|
+
|
|
2800
|
+
fig.add_annotation(x=namex, y=namey,
|
|
2801
|
+
text=annot_text,
|
|
2802
|
+
bgcolor="rgba(255, 255, 255, 0.5)",
|
|
2803
|
+
showarrow=False)
|
|
2804
|
+
|
|
2805
|
+
if isinstance(annotation, str):
|
|
2806
|
+
# Check for LaTeX format and warn user
|
|
2807
|
+
if _detect_latex_format(annotation):
|
|
2808
|
+
warnings.warn(
|
|
2809
|
+
"LaTeX formatting detected in 'annotation' parameter. "
|
|
2810
|
+
"Plotly requires HTML format (<sub>, <sup>) instead of LaTeX ($, _, ^). "
|
|
2811
|
+
"For activity ratios, use ratlab_html() instead of ratlab().",
|
|
2812
|
+
UserWarning
|
|
2813
|
+
)
|
|
2814
|
+
|
|
2815
|
+
fig.add_annotation(
|
|
2816
|
+
x=annotation_coords[0],
|
|
2817
|
+
y=annotation_coords[1],
|
|
2818
|
+
text=annotation,
|
|
2819
|
+
showarrow=False,
|
|
2820
|
+
xref="paper",
|
|
2821
|
+
yref="paper",
|
|
2822
|
+
align='left',
|
|
2823
|
+
bgcolor="rgba(255, 255, 255, 0.5)")
|
|
2824
|
+
|
|
2825
|
+
# Add borders if requested
|
|
2826
|
+
if borders == "contour":
|
|
2827
|
+
# Use contour-based boundaries (smooth, like diagram())
|
|
2828
|
+
# Draw boundaries using matplotlib contour extraction without filling
|
|
2829
|
+
|
|
2830
|
+
# Get unique species (excluding any that don't appear)
|
|
2831
|
+
unique_species_names = sorted(df["prednames"].unique())
|
|
2832
|
+
|
|
2833
|
+
# Create a temporary matplotlib figure to extract contour paths
|
|
2834
|
+
# We won't display it, just use it to calculate contours
|
|
2835
|
+
temp_fig, temp_ax = plt.subplots()
|
|
2836
|
+
|
|
2837
|
+
# For each species, create a binary mask and extract contours
|
|
2838
|
+
for i, sp_name in enumerate(unique_species_names):
|
|
2839
|
+
# Create binary mask: 1 where this species predominates, 0 elsewhere
|
|
2840
|
+
z = (dmap_names == sp_name).astype(float)
|
|
2841
|
+
|
|
2842
|
+
# Create meshgrid for contour
|
|
2843
|
+
X, Y = np.meshgrid(xvals, yvals)
|
|
2844
|
+
|
|
2845
|
+
# Find contours at level 0.5 using matplotlib
|
|
2846
|
+
try:
|
|
2847
|
+
cs = temp_ax.contour(X, Y, z, levels=[0.5])
|
|
2848
|
+
|
|
2849
|
+
# Extract the contour segments
|
|
2850
|
+
# cs.allsegs is a list of lists: [level][segment]
|
|
2851
|
+
for level_segs in cs.allsegs:
|
|
2852
|
+
for segment in level_segs:
|
|
2853
|
+
# segment is an (N, 2) array of (x, y) coordinates
|
|
2854
|
+
# Add as a scatter trace with lines
|
|
2855
|
+
fig.add_trace(
|
|
2856
|
+
go.Scatter(
|
|
2857
|
+
x=segment[:, 0],
|
|
2858
|
+
y=segment[:, 1],
|
|
2859
|
+
mode='lines',
|
|
2860
|
+
line=dict(color='black', width=2),
|
|
2861
|
+
hoverinfo='skip',
|
|
2862
|
+
showlegend=False
|
|
2863
|
+
)
|
|
2864
|
+
)
|
|
2865
|
+
|
|
2866
|
+
# Clear the temp axes for next species
|
|
2867
|
+
temp_ax.clear()
|
|
2868
|
+
except Exception as e:
|
|
2869
|
+
if messages:
|
|
2870
|
+
warnings.warn(f"Could not draw contour for {sp_name}: {e}")
|
|
2871
|
+
pass # Skip if contour can't be drawn
|
|
2872
|
+
|
|
2873
|
+
# Close the temporary figure
|
|
2874
|
+
plt.close(temp_fig)
|
|
2875
|
+
|
|
2876
|
+
elif isinstance(borders, (int, float)) and borders > 0:
|
|
2877
|
+
unique_x_vals = sorted(list(set(df[xvar])))
|
|
2878
|
+
unique_y_vals = sorted(list(set(df[yvar])))
|
|
2879
|
+
|
|
2880
|
+
# Skip border drawing if there are fewer than 2 unique values
|
|
2881
|
+
# (single point or single line - no borders to draw between regions)
|
|
2882
|
+
if len(unique_x_vals) < 2 or len(unique_y_vals) < 2:
|
|
2883
|
+
if messages:
|
|
2884
|
+
warnings.warn("Skipping border drawing: need at least 2 unique values in each dimension")
|
|
2885
|
+
else:
|
|
2886
|
+
def mov_mean(numbers, window_size=2):
|
|
2887
|
+
moving_averages = []
|
|
2888
|
+
for i in range(len(numbers) - window_size + 1):
|
|
2889
|
+
window_average = sum(numbers[i:i + window_size]) / window_size
|
|
2890
|
+
moving_averages.append(window_average)
|
|
2891
|
+
return moving_averages
|
|
2892
|
+
|
|
2893
|
+
x_mov_mean = mov_mean(unique_x_vals)
|
|
2894
|
+
y_mov_mean = mov_mean(unique_y_vals)
|
|
2895
|
+
|
|
2896
|
+
x_plot_min = x_mov_mean[0] - (x_mov_mean[1] - x_mov_mean[0])
|
|
2897
|
+
y_plot_min = y_mov_mean[0] - (y_mov_mean[1] - y_mov_mean[0])
|
|
2898
|
+
|
|
2899
|
+
x_plot_max = x_mov_mean[-1] + (x_mov_mean[1] - x_mov_mean[0])
|
|
2900
|
+
y_plot_max = y_mov_mean[-1] + (y_mov_mean[1] - y_mov_mean[0])
|
|
2901
|
+
|
|
2902
|
+
x_vals_border = [x_plot_min] + x_mov_mean + [x_plot_max]
|
|
2903
|
+
y_vals_border = [y_plot_min] + y_mov_mean + [y_plot_max]
|
|
2904
|
+
|
|
2905
|
+
# Find border lines
|
|
2906
|
+
def find_line(dmap, row_index):
|
|
2907
|
+
return [i for i in range(len(dmap[row_index]) - 1) if dmap[row_index][i] != dmap[row_index][i + 1]]
|
|
2908
|
+
|
|
2909
|
+
nrows, ncols = dmap.shape
|
|
2910
|
+
vlines = [find_line(dmap, row_i) for row_i in range(nrows)]
|
|
2911
|
+
|
|
2912
|
+
dmap_transposed = dmap.transpose()
|
|
2913
|
+
nrows_t, ncols_t = dmap_transposed.shape
|
|
2914
|
+
hlines = [find_line(dmap_transposed, row_i) for row_i in range(nrows_t)]
|
|
2915
|
+
|
|
2916
|
+
y_coord_list_vertical = []
|
|
2917
|
+
x_coord_list_vertical = []
|
|
2918
|
+
for i, row in enumerate(vlines):
|
|
2919
|
+
for line in row:
|
|
2920
|
+
x_coord_list_vertical += [x_vals_border[line + 1], x_vals_border[line + 1], np.nan]
|
|
2921
|
+
y_coord_list_vertical += [y_vals_border[i], y_vals_border[i + 1], np.nan]
|
|
2922
|
+
|
|
2923
|
+
y_coord_list_horizontal = []
|
|
2924
|
+
x_coord_list_horizontal = []
|
|
2925
|
+
for i, col in enumerate(hlines):
|
|
2926
|
+
for line in col:
|
|
2927
|
+
y_coord_list_horizontal += [y_vals_border[line + 1], y_vals_border[line + 1], np.nan]
|
|
2928
|
+
x_coord_list_horizontal += [x_vals_border[i], x_vals_border[i + 1], np.nan]
|
|
2929
|
+
|
|
2930
|
+
fig.add_trace(
|
|
2931
|
+
go.Scatter(
|
|
2932
|
+
mode='lines',
|
|
2933
|
+
x=x_coord_list_horizontal,
|
|
2934
|
+
y=y_coord_list_horizontal,
|
|
2935
|
+
line={'width': borders, 'color': 'black'},
|
|
2936
|
+
hoverinfo='skip',
|
|
2937
|
+
showlegend=False))
|
|
2938
|
+
|
|
2939
|
+
fig.add_trace(
|
|
2940
|
+
go.Scatter(
|
|
2941
|
+
mode='lines',
|
|
2942
|
+
x=x_coord_list_vertical,
|
|
2943
|
+
y=y_coord_list_vertical,
|
|
2944
|
+
line={'width': borders, 'color': 'black'},
|
|
2945
|
+
hoverinfo='skip',
|
|
2946
|
+
showlegend=False))
|
|
2947
|
+
|
|
2948
|
+
fig.update_yaxes(range=[min(yvals), max(yvals)], autorange=False, mirror=True)
|
|
2949
|
+
fig.update_xaxes(range=[min(xvals), max(xvals)], autorange=False, mirror=True)
|
|
2950
|
+
|
|
2951
|
+
# Configure download button
|
|
2952
|
+
save_as_name, save_format_final = _save_figure(fig, save_as, save_format, save_scale,
|
|
2953
|
+
plot_width=width, plot_height=height, ppi=1)
|
|
2954
|
+
|
|
2955
|
+
config = {'displaylogo': False,
|
|
2956
|
+
'modeBarButtonsToRemove': ['zoom2d', 'pan2d', 'zoomIn2d', 'zoomOut2d',
|
|
2957
|
+
'autoScale2d', 'resetScale2d', 'toggleSpikelines',
|
|
2958
|
+
'hoverClosestCartesian', 'hoverCompareCartesian'],
|
|
2959
|
+
'toImageButtonOptions': {
|
|
2960
|
+
'format': save_format_final,
|
|
2961
|
+
'filename': save_as_name,
|
|
2962
|
+
'height': height,
|
|
2963
|
+
'width': width,
|
|
2964
|
+
'scale': save_scale,
|
|
2965
|
+
}}
|
|
2966
|
+
|
|
2967
|
+
# Store config on figure so it persists when fig.show() is called later
|
|
2968
|
+
fig._config = fig._config | config
|
|
2969
|
+
|
|
2970
|
+
if plot_it:
|
|
2971
|
+
fig.show(config=config)
|
|
2972
|
+
|
|
2973
|
+
return df, fig
|
|
2974
|
+
|
|
2975
|
+
|
|
2976
|
+
def _detect_latex_format(text: str) -> bool:
|
|
2977
|
+
"""
|
|
2978
|
+
Detect if a string contains LaTeX formatting (incompatible with Plotly).
|
|
2979
|
+
|
|
2980
|
+
Parameters
|
|
2981
|
+
----------
|
|
2982
|
+
text : str
|
|
2983
|
+
Text to check for LaTeX formatting
|
|
2984
|
+
|
|
2985
|
+
Returns
|
|
2986
|
+
-------
|
|
2987
|
+
bool
|
|
2988
|
+
True if LaTeX formatting is detected (e.g., $...$, _{...}, ^{...})
|
|
2989
|
+
"""
|
|
2990
|
+
import re
|
|
2991
|
+
# Check for common LaTeX patterns:
|
|
2992
|
+
# - Text wrapped in $ $
|
|
2993
|
+
# - LaTeX subscripts _{...}
|
|
2994
|
+
# - LaTeX superscripts ^{...}
|
|
2995
|
+
latex_patterns = [
|
|
2996
|
+
r'\$[^$]+\$', # $...$
|
|
2997
|
+
r'_\{[^}]+\}', # _{...}
|
|
2998
|
+
r'\^\{[^}]+\}' # ^{...}
|
|
2999
|
+
]
|
|
3000
|
+
|
|
3001
|
+
for pattern in latex_patterns:
|
|
3002
|
+
if re.search(pattern, text):
|
|
3003
|
+
return True
|
|
3004
|
+
return False
|
|
3005
|
+
|
|
3006
|
+
|
|
3007
|
+
def _format_html_species(formula: str) -> str:
|
|
3008
|
+
"""
|
|
3009
|
+
Format a chemical formula for HTML rendering in Plotly.
|
|
3010
|
+
|
|
3011
|
+
Converts chemical formulas like "H2O" to "H<sub>2</sub>O" and
|
|
3012
|
+
"Ca+2" to "Ca<sup>2+</sup>".
|
|
3013
|
+
|
|
3014
|
+
Parameters
|
|
3015
|
+
----------
|
|
3016
|
+
formula : str
|
|
3017
|
+
Chemical formula to format
|
|
3018
|
+
|
|
3019
|
+
Returns
|
|
3020
|
+
-------
|
|
3021
|
+
str
|
|
3022
|
+
HTML-formatted formula
|
|
3023
|
+
"""
|
|
3024
|
+
import re
|
|
3025
|
+
|
|
3026
|
+
# Handle charge notation (e.g., +2, -1)
|
|
3027
|
+
# Match patterns like +2, -2, +, -
|
|
3028
|
+
charge_pattern = r'([+-])(\d*)'
|
|
3029
|
+
|
|
3030
|
+
def format_charge(match):
|
|
3031
|
+
sign = match.group(1)
|
|
3032
|
+
num = match.group(2)
|
|
3033
|
+
if num == '' or num == '1':
|
|
3034
|
+
return f"<sup>{sign}</sup>"
|
|
3035
|
+
else:
|
|
3036
|
+
return f"<sup>{num}{sign}</sup>"
|
|
3037
|
+
|
|
3038
|
+
# First handle charges at the end
|
|
3039
|
+
formula = re.sub(charge_pattern + r'$', format_charge, formula)
|
|
3040
|
+
|
|
3041
|
+
# Handle subscript numbers (digits that aren't part of the charge)
|
|
3042
|
+
# Match digits that come after letters and aren't preceded by < or >
|
|
3043
|
+
def format_subscript(match):
|
|
3044
|
+
return f"<sub>{match.group(0)}</sub>"
|
|
3045
|
+
|
|
3046
|
+
# Find all digits that should be subscripts
|
|
3047
|
+
# This matches digits that come after letters
|
|
3048
|
+
result = []
|
|
3049
|
+
i = 0
|
|
3050
|
+
while i < len(formula):
|
|
3051
|
+
if formula[i].isdigit() and i > 0 and formula[i-1].isalpha():
|
|
3052
|
+
# Start of a number sequence
|
|
3053
|
+
num_start = i
|
|
3054
|
+
while i < len(formula) and formula[i].isdigit():
|
|
3055
|
+
i += 1
|
|
3056
|
+
result.append(f"<sub>{formula[num_start:i]}</sub>")
|
|
3057
|
+
else:
|
|
3058
|
+
result.append(formula[i])
|
|
3059
|
+
i += 1
|
|
3060
|
+
|
|
3061
|
+
return ''.join(result)
|
|
3062
|
+
|
|
3063
|
+
|
|
3064
|
+
def _save_figure(fig, save_as, save_format, save_scale, plot_width, plot_height, ppi):
    """
    Write a Plotly figure to disk when both a file name and a format are given.

    Parameters
    ----------
    fig : plotly figure
        The figure to save
    save_as : str or None
        Filename (without extension) to save as; a non-string, non-None
        value is replaced by "newplot" when a file is written
    save_format : str or None
        Format to save ('png', 'jpg', 'jpeg', 'webp', 'svg', 'pdf', 'eps', 'json', 'html')
    save_scale : float
        Scale factor for saving
    plot_width : int
        Width of the plot
    plot_height : int
        Height of the plot
    ppi : int
        Pixels per inch; multiplies plot_width and plot_height for the
        written image

    Returns
    -------
    tuple
        (save_as, save_format) - processed values for use in config.
        The returned format is "png" when nothing was written and after
        writing 'html', 'pdf', 'eps' or 'json'; otherwise it is the format
        that was written.

    Raises
    ------
    ValueError
        If save_format is a string that is not one of the supported formats
    """
    import plotly.io as pio

    valid_formats = ['png', 'jpg', 'jpeg', 'webp', 'svg', 'pdf', 'eps', 'json', 'html']
    format_given = isinstance(save_format, str)

    if format_given and save_format not in valid_formats:
        raise ValueError(f"{save_format} is an unrecognized save format. "
                         f"Supported formats include: {', '.join(valid_formats)}")

    # Nothing to write unless both a format and a target name were supplied.
    if not format_given or save_as is None:
        return save_as, "png"

    if not isinstance(save_as, str):
        save_as = "newplot"

    if save_format == "html":
        fig.write_html(save_as + ".html")
        print(f"Saved figure as {save_as}.html")
        return save_as, 'png'

    # Every remaining format goes through the static image exporter.
    pio.write_image(fig, save_as + "." + save_format, format=save_format,
                    scale=save_scale, width=plot_width * ppi, height=plot_height * ppi)
    print(f"Saved figure as {save_as}.{save_format}")

    # After writing these formats the reported format falls back to "png".
    if save_format in ('pdf', 'eps', 'json'):
        save_format = "png"

    return save_as, save_format
|
|
3119
|
+
|
|
3120
|
+
|
|
3121
|
+
def _plot_saturation_interactive(eout, values_list, sp_names, xyvars, xyvals,
                                 xlab, ylab, col, lwd, lty, cex, contour_method,
                                 main, add, ax, width, height, plot_it,
                                 save_as, save_format, save_scale, messages):
    """
    Plot saturation lines (affinity=0 contours) for interactive 2-D diagrams using Plotly.

    This function draws contour lines where affinity = 0 for each species,
    indicating saturation boundaries (e.g., mineral precipitation thresholds).
    One ``go.Contour`` trace is added per entry of ``sp_names``.

    Parameters
    ----------
    eout : mapping
        Only ``eout['basis']`` is read here, and only when ``add`` is falsy;
        its index supplies the basis species names and its 'state' column
        their states.
    values_list : indexable
        ``values_list[i]`` is the affinity grid for ``sp_names[i]``; it is
        transposed (``.T``) before being passed to Plotly.
    sp_names : sequence of str
        Species names; used as trace name and legend group.
    xyvars : sequence
        ``xyvars[0]`` / ``xyvars[1]`` are the x / y variable names.
    xyvals : sequence
        ``xyvals[0]`` / ``xyvals[1]`` are the x / y grid values.
    xlab, ylab : str or other
        Axis titles. A non-string value is replaced by a generated label.
    col : None, str or iterable
        Line colors; None selects a built-in palette, a string is used for
        every species, an iterable is recycled to the number of species.
    lwd : int, float or iterable
        Line widths, recycled like ``col``.
    lty : None, str, int or iterable
        Line styles; translated through ``lty_map`` below, anything that is
        not a key of that map becomes 'solid'.
    cex : int, float or iterable
        Text size multiplier; only used to build ``font_sizes``, which is
        not referenced afterwards in this function.
    contour_method : None, str, list or other
        Only used to build ``show_labels``, which is not referenced
        afterwards in this function.
    main : str or other
        Figure title, applied when it is a string.
    add : bool
        When truthy and ``ax`` is not None, traces are added to ``ax``.
    ax : plotly figure or None
        Existing figure to draw on (see ``add``).
    width, height : int
        Figure size, also forwarded to the download-button options.
    plot_it : bool
        When truthy, ``fig.show`` is called.
    save_as, save_format, save_scale
        Forwarded to ``_save_figure``.
    messages : bool
        When truthy, a progress line is printed.

    Returns
    -------
    tuple
        ``(df, fig)`` where ``df`` is an empty pandas DataFrame and ``fig``
        is the Plotly figure carrying the contour traces.
    """
    import plotly.graph_objects as go
    # NOTE(review): pio is not referenced anywhere in this function.
    import plotly.io as pio

    # Get x and y values
    xvals = xyvals[0]
    yvals = xyvals[1]
    xvar = xyvars[0]
    yvar = xyvars[1]

    n_species = len(sp_names)

    if messages:
        print(f"diagram: plotting saturation lines for interactive 2-D diagram")

    # Set up colors: normalise `col` to a list with exactly one entry per species
    if col is None:
        # Use default Plotly colors
        default_colors = ['#636EFA', '#EF553B', '#00CC96', '#AB63FA', '#FFA15A',
                          '#19D3F3', '#FF6692', '#B6E880', '#FF97FF', '#FECB52']
        col = [default_colors[i % len(default_colors)] for i in range(n_species)]
    elif isinstance(col, str):
        col = [col] * n_species
    else:
        col = list(col)
        if len(col) < n_species:
            # Recycle the given values, then trim to length below.
            # NOTE(review): an empty iterable makes len(col) zero and this
            # division raises ZeroDivisionError.
            col = col * (n_species // len(col) + 1)
        col = col[:n_species]

    # Set up line widths (same recycling scheme as for colors)
    if isinstance(lwd, (int, float)):
        lwd = [lwd] * n_species
    else:
        lwd = list(lwd)
        if len(lwd) < n_species:
            lwd = lwd * (n_species // len(lwd) + 1)
        lwd = lwd[:n_species]

    # Handle line styles (lty)
    if lty is None:
        lty = ['solid'] * n_species
    elif isinstance(lty, (str, int)):
        lty = [lty] * n_species
    else:
        lty = list(lty)
        if len(lty) < n_species:
            lty = lty * (n_species // len(lty) + 1)
        lty = lty[:n_species]

    # Convert numeric/matplotlib line styles to Plotly dash types
    lty_map = {
        1: 'solid', '-': 'solid',
        2: 'dash', '--': 'dash',
        3: 'dashdot', '-.': 'dashdot',
        4: 'dot', ':': 'dot',
        5: 'solid', 6: 'dash'
    }
    # Any value that is not a key of lty_map ends up as 'solid'.
    # NOTE(review): that includes the Plotly names themselves ('dash', 'dot',
    # 'dashdot'), so lty='dash' is drawn solid - confirm this is intended.
    lty = [lty_map.get(lt, 'solid') if (isinstance(lt, int) or lt in lty_map) else 'solid' for lt in lty]

    # Handle contour label control
    # None, "" and "NA" (any case) switch labels off; everything else switches them on.
    # NOTE(review): show_labels is built here but not used further down; every
    # trace is created with showlabels=False.
    if contour_method is None or contour_method == "" or (isinstance(contour_method, str) and contour_method.upper() == "NA"):
        show_labels = [False] * n_species
    elif isinstance(contour_method, str):
        show_labels = [True] * n_species
    elif isinstance(contour_method, list):
        if len(contour_method) != n_species:
            contour_method_extended = list(contour_method) * (n_species // len(contour_method) + 1)
            contour_method_extended = contour_method_extended[:n_species]
        else:
            contour_method_extended = contour_method

        show_labels = []
        for method in contour_method_extended:
            if method is None or method == "" or (isinstance(method, str) and method.upper() == "NA"):
                show_labels.append(False)
            else:
                show_labels.append(True)
    else:
        show_labels = [True] * n_species

    # Handle text size (cex) for contour labels
    if isinstance(cex, (int, float)):
        cex_list = [cex] * n_species
    else:
        cex_list = list(cex)
        if len(cex_list) < n_species:
            cex_list = cex_list * (n_species // len(cex_list) + 1)
        cex_list = cex_list[:n_species]

    # Base font size for labels (Plotly default ~12)
    # NOTE(review): font_sizes is computed but not used further down.
    base_font_size = 12
    font_sizes = [base_font_size * c for c in cex_list]

    # Create or get figure
    if add and ax is not None:
        # ax is actually the Plotly figure from previous call
        fig = ax
    else:
        # Create new figure
        fig = go.Figure()

    # Add contour lines for each species
    for i, sp_name in enumerate(sp_names):
        # Get affinity values for this species
        # values_list[i] has shape (nx, ny)
        affinity_2d = values_list[i]

        # Create contour trace for affinity=0 only
        # Use ncontours=1 with start=0 and end=0 to force a single contour at zero
        # Note: Plotly doesn't support custom text on contour lines, so we rely on the legend
        contour = go.Contour(
            x=xvals,
            y=yvals,
            z=affinity_2d.T,  # Transpose to match Plotly's expected orientation
            ncontours=1,  # Only generate one contour level
            contours=dict(
                coloring='lines',  # Draw only contour lines, not filled regions
                start=0,  # Start at 0
                end=0,  # End at 0
                showlabels=False  # Don't show "0" labels on contour lines
            ),
            line=dict(
                color=col[i],
                width=lwd[i],
                dash=lty[i]
            ),
            colorscale=[[0, col[i]], [1, col[i]]],  # Force uniform color
            showscale=False,
            hoverinfo='skip',
            name=sp_name,
            legendgroup=sp_name,
            showlegend=True
        )

        fig.add_trace(contour)

    # Set axis labels if not adding to existing plot
    if not add:
        # Create axis labels with proper units
        # NOTE(review): thermo_func is imported but not used in this function.
        from .thermo import thermo as thermo_func
        basis_df = eout['basis']
        basis_sp = list(basis_df.index)
        basis_state = list(basis_df['state'])

        unit_dict = {"P": "bar", "T": "°C", "pH": "", "Eh": "volts", "IS": "mol/kg"}

        # Basis species get an HTML label: "log a" (activity) for states
        # aq/liq/cr, "log f" (fugacity) for any other state.
        for i, s in enumerate(basis_sp):
            if basis_state[i] in ["aq", "liq", "cr"]:
                unit_dict[s] = f"log <i>a</i><sub>{s}</sub>"
            else:
                unit_dict[s] = f"log <i>f</i><sub>{s}</sub>"

        # Generate labels only when the caller did not supply a string.
        if not isinstance(xlab, str):
            xlab = xvar + ", " + unit_dict.get(xvar, "")
            if xvar == "pH":
                xlab = "pH"
            if xvar in basis_sp:
                xlab = unit_dict[xvar]

        if not isinstance(ylab, str):
            ylab = yvar + ", " + unit_dict.get(yvar, "")
            if yvar in basis_sp:
                ylab = unit_dict[yvar]
            if yvar == "pH":
                ylab = "pH"

        fig.update_xaxes(title_text=xlab)
        fig.update_yaxes(title_text=ylab)

        fig.update_layout(
            template="simple_white",
            width=width,
            height=height,
            showlegend=True
        )

        if isinstance(main, str):
            fig.update_layout(title={'text': main, 'x': 0.5, 'xanchor': 'center'})

    # Configure download button
    # _save_figure may also write the figure to disk; its return values feed
    # the download-button options below.
    save_as_name, save_format_final = _save_figure(fig, save_as, save_format, save_scale,
                                                   plot_width=width, plot_height=height, ppi=1)

    config = {'displaylogo': False,
              'modeBarButtonsToRemove': ['resetScale2d', 'toggleSpikelines'],
              'toImageButtonOptions': {
                  'format': save_format_final,
                  'filename': save_as_name,
                  'height': height,
                  'width': width,
                  'scale': save_scale,
              }}

    # The conditional expression binds looser than `|`: merge into an existing
    # fig._config when the attribute exists, otherwise store config as-is.
    fig._config = fig._config | config if hasattr(fig, '_config') else config

    # Show plot if requested
    if plot_it:
        fig.show(config=config)

    # Return empty DataFrame (saturation doesn't produce tabular data like predominance diagrams)
    df = pd.DataFrame()

    return df, fig
|
|
3333
|
+
|
|
3334
|
+
|
|
3335
|
+
# Names exported by a star-import of this module
__all__ = [
    'diagram',
    'diagram_interactive',
    'water_lines',
    'find_tp',
]
|