pychnosz 1.1.1__cp310-cp310-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (128)
  1. pychnosz/__init__.py +129 -0
  2. pychnosz/biomolecules/__init__.py +29 -0
  3. pychnosz/biomolecules/ionize_aa.py +197 -0
  4. pychnosz/biomolecules/proteins.py +595 -0
  5. pychnosz/core/__init__.py +46 -0
  6. pychnosz/core/affinity.py +1256 -0
  7. pychnosz/core/animation.py +593 -0
  8. pychnosz/core/balance.py +334 -0
  9. pychnosz/core/basis.py +716 -0
  10. pychnosz/core/diagram.py +3336 -0
  11. pychnosz/core/equilibrate.py +813 -0
  12. pychnosz/core/equilibrium.py +554 -0
  13. pychnosz/core/info.py +821 -0
  14. pychnosz/core/retrieve.py +364 -0
  15. pychnosz/core/speciation.py +580 -0
  16. pychnosz/core/species.py +599 -0
  17. pychnosz/core/subcrt.py +1700 -0
  18. pychnosz/core/thermo.py +593 -0
  19. pychnosz/core/unicurve.py +1226 -0
  20. pychnosz/data/__init__.py +11 -0
  21. pychnosz/data/add_obigt.py +327 -0
  22. pychnosz/data/extdata/Berman/BDat17_2017.csv +2 -0
  23. pychnosz/data/extdata/Berman/Ber88_1988.csv +68 -0
  24. pychnosz/data/extdata/Berman/Ber90_1990.csv +5 -0
  25. pychnosz/data/extdata/Berman/DS10_2010.csv +6 -0
  26. pychnosz/data/extdata/Berman/FDM+14_2014.csv +2 -0
  27. pychnosz/data/extdata/Berman/Got04_2004.csv +5 -0
  28. pychnosz/data/extdata/Berman/JUN92_1992.csv +3 -0
  29. pychnosz/data/extdata/Berman/SHD91_1991.csv +12 -0
  30. pychnosz/data/extdata/Berman/VGT92_1992.csv +2 -0
  31. pychnosz/data/extdata/Berman/VPT01_2001.csv +3 -0
  32. pychnosz/data/extdata/Berman/VPV05_2005.csv +2 -0
  33. pychnosz/data/extdata/Berman/ZS92_1992.csv +11 -0
  34. pychnosz/data/extdata/Berman/sympy.R +99 -0
  35. pychnosz/data/extdata/Berman/testing/BA96.bib +12 -0
  36. pychnosz/data/extdata/Berman/testing/BA96_Berman.csv +21 -0
  37. pychnosz/data/extdata/Berman/testing/BA96_OBIGT.csv +21 -0
  38. pychnosz/data/extdata/Berman/testing/BA96_refs.csv +6 -0
  39. pychnosz/data/extdata/OBIGT/AD.csv +25 -0
  40. pychnosz/data/extdata/OBIGT/Berman_cr.csv +93 -0
  41. pychnosz/data/extdata/OBIGT/DEW.csv +211 -0
  42. pychnosz/data/extdata/OBIGT/H2O_aq.csv +4 -0
  43. pychnosz/data/extdata/OBIGT/SLOP98.csv +411 -0
  44. pychnosz/data/extdata/OBIGT/SUPCRT92.csv +178 -0
  45. pychnosz/data/extdata/OBIGT/inorganic_aq.csv +729 -0
  46. pychnosz/data/extdata/OBIGT/inorganic_cr.csv +273 -0
  47. pychnosz/data/extdata/OBIGT/inorganic_gas.csv +20 -0
  48. pychnosz/data/extdata/OBIGT/organic_aq.csv +1104 -0
  49. pychnosz/data/extdata/OBIGT/organic_cr.csv +481 -0
  50. pychnosz/data/extdata/OBIGT/organic_gas.csv +268 -0
  51. pychnosz/data/extdata/OBIGT/organic_liq.csv +533 -0
  52. pychnosz/data/extdata/OBIGT/testing/GEMSFIT.csv +43 -0
  53. pychnosz/data/extdata/OBIGT/testing/IGEM.csv +17 -0
  54. pychnosz/data/extdata/OBIGT/testing/Sandia.csv +8 -0
  55. pychnosz/data/extdata/OBIGT/testing/SiO2.csv +4 -0
  56. pychnosz/data/extdata/misc/AD03_Fig1a.csv +69 -0
  57. pychnosz/data/extdata/misc/AD03_Fig1b.csv +43 -0
  58. pychnosz/data/extdata/misc/AD03_Fig1c.csv +89 -0
  59. pychnosz/data/extdata/misc/AD03_Fig1d.csv +30 -0
  60. pychnosz/data/extdata/misc/BZA10.csv +5 -0
  61. pychnosz/data/extdata/misc/HW97_Cp.csv +90 -0
  62. pychnosz/data/extdata/misc/HWM96_V.csv +229 -0
  63. pychnosz/data/extdata/misc/LA19_test.csv +7 -0
  64. pychnosz/data/extdata/misc/Mer75_Table4.csv +42 -0
  65. pychnosz/data/extdata/misc/OBIGT_check.csv +423 -0
  66. pychnosz/data/extdata/misc/PM90.csv +7 -0
  67. pychnosz/data/extdata/misc/RH95.csv +23 -0
  68. pychnosz/data/extdata/misc/RH98_Table15.csv +17 -0
  69. pychnosz/data/extdata/misc/SC10_Rainbow.csv +19 -0
  70. pychnosz/data/extdata/misc/SK95.csv +55 -0
  71. pychnosz/data/extdata/misc/SOJSH.csv +61 -0
  72. pychnosz/data/extdata/misc/SS98_Fig5a.csv +81 -0
  73. pychnosz/data/extdata/misc/SS98_Fig5b.csv +84 -0
  74. pychnosz/data/extdata/misc/TKSS14_Fig2.csv +25 -0
  75. pychnosz/data/extdata/misc/bluered.txt +1000 -0
  76. pychnosz/data/extdata/protein/Cas/Cas_aa.csv +177 -0
  77. pychnosz/data/extdata/protein/Cas/Cas_uniprot.csv +186 -0
  78. pychnosz/data/extdata/protein/Cas/download.R +34 -0
  79. pychnosz/data/extdata/protein/Cas/mkaa.R +34 -0
  80. pychnosz/data/extdata/protein/POLG.csv +12 -0
  81. pychnosz/data/extdata/protein/TBD+05.csv +393 -0
  82. pychnosz/data/extdata/protein/TBD+05_aa.csv +393 -0
  83. pychnosz/data/extdata/protein/rubisco.csv +28 -0
  84. pychnosz/data/extdata/protein/rubisco.fasta +239 -0
  85. pychnosz/data/extdata/protein/rubisco_aa.csv +28 -0
  86. pychnosz/data/extdata/src/H2O92D.f.orig +3457 -0
  87. pychnosz/data/extdata/src/README.txt +5 -0
  88. pychnosz/data/extdata/taxonomy/names.dmp +215 -0
  89. pychnosz/data/extdata/taxonomy/nodes.dmp +63 -0
  90. pychnosz/data/extdata/thermo/Bdot_acirc.csv +60 -0
  91. pychnosz/data/extdata/thermo/buffer.csv +40 -0
  92. pychnosz/data/extdata/thermo/element.csv +135 -0
  93. pychnosz/data/extdata/thermo/groups.csv +6 -0
  94. pychnosz/data/extdata/thermo/opt.csv +2 -0
  95. pychnosz/data/extdata/thermo/protein.csv +506 -0
  96. pychnosz/data/extdata/thermo/refs.csv +343 -0
  97. pychnosz/data/extdata/thermo/stoich.csv.xz +0 -0
  98. pychnosz/data/loader.py +431 -0
  99. pychnosz/data/mod_obigt.py +322 -0
  100. pychnosz/data/obigt.py +471 -0
  101. pychnosz/data/worm.py +228 -0
  102. pychnosz/fortran/__init__.py +16 -0
  103. pychnosz/fortran/h2o92.dll +0 -0
  104. pychnosz/fortran/h2o92_interface.py +527 -0
  105. pychnosz/geochemistry/__init__.py +21 -0
  106. pychnosz/geochemistry/minerals.py +514 -0
  107. pychnosz/geochemistry/redox.py +500 -0
  108. pychnosz/models/__init__.py +47 -0
  109. pychnosz/models/archer_wang.py +165 -0
  110. pychnosz/models/berman.py +309 -0
  111. pychnosz/models/cgl.py +381 -0
  112. pychnosz/models/dew.py +997 -0
  113. pychnosz/models/hkf.py +523 -0
  114. pychnosz/models/hkf_helpers.py +222 -0
  115. pychnosz/models/iapws95.py +1113 -0
  116. pychnosz/models/supcrt92_fortran.py +238 -0
  117. pychnosz/models/water.py +480 -0
  118. pychnosz/utils/__init__.py +27 -0
  119. pychnosz/utils/expression.py +1074 -0
  120. pychnosz/utils/formula.py +830 -0
  121. pychnosz/utils/formula_ox.py +227 -0
  122. pychnosz/utils/reset.py +33 -0
  123. pychnosz/utils/units.py +259 -0
  124. pychnosz-1.1.1.dist-info/METADATA +197 -0
  125. pychnosz-1.1.1.dist-info/RECORD +128 -0
  126. pychnosz-1.1.1.dist-info/WHEEL +5 -0
  127. pychnosz-1.1.1.dist-info/licenses/LICENSE.txt +19 -0
  128. pychnosz-1.1.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,3336 @@
1
+ """
2
+ Diagram module for plotting chemical activity and predominance diagrams.
3
+
4
+ This module provides Python equivalents of the R functions in diagram.R:
5
+ - diagram(): Plot equilibrium chemical activity and predominance diagrams
6
+ - Supporting utilities for 1D line plots and 2D predominance diagrams
7
+
8
+ Author: CHNOSZ Python port
9
+ """
10
+
11
+ import numpy as np
12
+ import pandas as pd
13
+ import matplotlib.pyplot as plt
14
+ from typing import Union, List, Optional, Dict, Any, Tuple
15
+ import warnings
16
+ import copy
17
+ from ..utils.expression import _format_species_latex
18
+
19
+
20
def copy_plot(diagram_result: Dict[str, Any]) -> Dict[str, Any]:
    """
    Return an independent deep copy of a diagram() result dictionary.

    matplotlib figure and axes objects are mutable, so every reference to a
    plot points at the same underlying objects. Modifying one reference
    therefore modifies them all. This helper duplicates the whole result
    with ``copy.deepcopy`` so the copy can be changed without touching the
    original.

    Parameters
    ----------
    diagram_result : dict
        Result dictionary returned by diagram(). It may contain 'fig' and
        'ax' entries.

    Returns
    -------
    dict
        A deep copy of ``diagram_result``. Because a single deepcopy call is
        used, objects shared inside the input (e.g. an axes and the figure
        that owns it) stay consistently shared inside the copy.

    Examples
    --------
    Manual copying workflow (advanced usage - normally use add_to instead):

    >>> import pychnosz
    >>> # Create base plot (Plot A)
    >>> basis(['SiO2', 'Ca+2', 'Mg+2', 'CO2', 'H2O', 'O2', 'H+'])
    >>> species(['quartz', 'talc', 'chrysotile', 'forsterite'])
    >>> a = affinity(**{'Mg+2': [4, 10, 500], 'Ca+2': [5, 15, 500]})
    >>> plot_a = diagram(a, fill='terrain')
    >>>
    >>> # Manual approach: make copies first, then modify their axes directly
    >>> plot_a1 = copy_plot(plot_a)  # for modification 1
    >>> plot_a2 = copy_plot(plot_a)  # for modification 2
    >>>
    >>> # Recommended approach: the add_to parameter copies internally
    >>> basis('CO2', -1)
    >>> species(['calcite', 'dolomite'])
    >>> a2 = affinity(**{'Mg+2': [4, 10, 500], 'Ca+2': [5, 15, 500]})
    >>> plot_a1 = diagram(a2, type='saturation', add_to=plot_a, col='blue')
    >>> plot_a2 = diagram(a2, type='saturation', add_to=plot_a, col='red')
    >>> # plot_a, plot_a1 and plot_a2 are now three independent plots

    Notes
    -----
    - Copying very large plots may use a lot of memory.
    - Interactive (plotly) results may not copy perfectly; test before
      relying on this.
    - The copy can be saved, displayed or modified without affecting the
      original.

    See Also
    --------
    diagram : Create plots that can be copied with this function
    """
    duplicated = copy.deepcopy(diagram_result)
    return duplicated
83
+
84
+
85
def diagram(eout: Dict[str, Any],
            type: str = "auto",
            alpha: Union[bool, str] = False,
            balance: Optional[Union[str, float, List[float]]] = None,
            names: Optional[List[str]] = None,
            format_names: bool = True,
            xlab: Optional[str] = None,
            ylab: Optional[str] = None,
            xlim: Optional[List[float]] = None,
            ylim: Optional[List[float]] = None,
            col: Optional[Union[str, List[str]]] = None,
            col_names: Optional[Union[str, List[str]]] = None,
            lty: Optional[Union[str, int, List]] = None,
            lwd: Union[float, List[float]] = 1,
            cex: Union[float, List[float]] = 1.0,
            main: Optional[str] = None,
            fill: Optional[str] = None,
            fill_NA: str = "0.8",
            limit_water: Optional[bool] = None,
            plot_it: bool = True,
            add_to: Optional[Dict[str, Any]] = None,
            contour_method: Optional[Union[str, List[str]]] = "edge",
            messages: bool = True,
            interactive: bool = False,
            annotation: Optional[str] = None,
            annotation_coords: Optional[List[float]] = None,
            width: int = 600,
            height: int = 520,
            save_as: Optional[str] = None,
            save_format: Optional[str] = None,
            save_scale: float = 1,
            normalize: Union[bool, List[bool]] = False,
            as_residue: bool = False,
            **kwargs) -> Dict[str, Any]:
    """
    Plot equilibrium chemical activity and predominance diagrams.

    This function creates plots from the output of affinity() or equilibrate().
    For 1D diagrams, it produces line plots showing how affinity or activity
    varies with a single variable. For 2D diagrams, it creates predominance
    field diagrams.

    Parameters
    ----------
    eout : dict
        Output from affinity(), equilibrate() or solubility() (identified by
        eout['fun']).
    type : str, default "auto"
        Type of diagram:
        - "auto" (default): Plot affinity values (A/2.303RT)
        - "loga.equil": Plot equilibrium activities from equilibrate()
        - "saturation": Draw affinity=0 contour lines (mineral saturation)
        - Basis species name (e.g., "O2", "H2O", "CO2"): Plot equilibrium
          log activity/fugacity of the specified basis species where affinity=0
          for each formed species. Useful for Eh-pH diagrams and showing
          oxygen/water fugacities at equilibrium. Species whose formation
          reaction does not involve that basis species get NaN values.
    alpha : bool or str, default False
        Plot degree of formation instead of activities?
        If "balance", scale by balancing coefficients
    balance : str, float, or list of float, optional
        Balancing coefficients or method for balancing reactions
    names : list of str, optional
        Custom names for species (for labels)
    format_names : bool, default True
        Apply formatting to chemical formulas? (Not applied when alpha is set.)
    xlab : str, optional
        Custom x-axis label
    ylab : str, optional
        Custom y-axis label
    xlim : list of float, optional
        X-axis limits [min, max]
    ylim : list of float, optional
        Y-axis limits [min, max]
    col : str or list of str, optional
        Line colors for 1-D plots and boundary lines in 2-D plots (matplotlib color specs)
    col_names : str or list of str, optional
        Text colors for field labels in 2-D plots (matplotlib color specs)
    lty : str, int, or list, optional
        Line styles (matplotlib linestyle specs)
    lwd : float or list of float, default 1
        Line widths for 1-D plots and boundary lines in 2-D predominance
        diagrams. Set to 0 to disable borders in 2-D diagrams. If fill is
        None and lwd > 0, uses white fill with black borders (R CHNOSZ default).
    cex : float or list of float, default 1.0
        Character expansion factor for text labels. Values > 1 make text larger,
        values < 1 make text smaller. Can be a single value or a list (one per species).
        Used for contour labels in type="saturation" plots.
    main : str, optional
        Plot title
    fill : str, optional
        Color palette for 2-D predominance diagrams. Can be any matplotlib
        colormap name (e.g., 'viridis', 'plasma', 'terrain', 'rainbow',
        'Set1', 'tab10', 'Pastel1'). If None, uses discrete colors from
        the default color cycle. Ignored for 1-D diagrams.
    fill_NA : str, default "0.8"
        Color for regions outside water stability limits (water instability regions).
        Matplotlib color specification (e.g., "0.8" for gray, "#CCCCCC").
        Set to "transparent" to disable shading. Default "0.8" matches R's "gray80".
    limit_water : bool, optional
        Whether to show water stability limits as shaded regions (default True for
        2-D diagrams). If True, also clips the diagram to the water stability region.
        Set to False to disable water stability shading.
    plot_it : bool, default True
        Display the plot? If False, the figure is closed (but still returned).
    add_to : dict, optional
        A diagram result dictionary from a previous diagram() call. When provided,
        this plot will be AUTOMATICALLY COPIED and the new diagram will be added to
        the copy. This preserves the original plot while creating a modified version.
        The axes object is extracted from add_to['ax'].

        This parameter eliminates the need for a separate 'add' boolean - when
        add_to is provided, the function automatically operates in "add" mode.

        Example workflow:
        >>> plot_a = diagram(affinity1, fill='terrain')  # Create base plot
        >>> plot_a1 = diagram(affinity2, add_to=plot_a, col='blue')  # Copy and add
        >>> plot_a2 = diagram(affinity3, add_to=plot_a, col='red')  # Copy and add again
        >>> # plot_a remains unchanged, plot_a1 and plot_a2 are independent modifications
    contour_method : str or list of str, optional
        Method for labeling contour lines. Default "edge" labels at plot edges.
        Can be a single value (applied to all species) or a list (one per species).
        Set to None, NA, or "" to disable labels (only for type="saturation").
        In R CHNOSZ, different methods like "edge", "flattest", "simple" control
        label placement; in Python, this mainly controls whether labels are shown.
    messages : bool, default True
        Print informational messages (e.g. which balance is used)?
    interactive : bool, default False
        Create an interactive plot using Plotly instead of matplotlib?
        If True, calls diagram_interactive() with the appropriate parameters.
    annotation : str, optional
        For interactive plots only. Annotation text to add to the plot.
    annotation_coords : list of float, optional
        For interactive plots only. Coordinates of annotation, where [0, 0] is
        bottom left and [1, 1] is top right. None (default) means [0, 0].
    width : int, default 600
        For interactive plots only. Width of the plot in pixels.
    height : int, default 520
        For interactive plots only. Height of the plot in pixels.
    save_as : str, optional
        For interactive plots only. Provide a filename to save this figure.
        Filetype is determined by `save_format`.
    save_format : str, optional
        For interactive plots only. Desired format of saved or downloaded figure.
        Can be 'png', 'jpg', 'jpeg', 'webp', 'svg', 'pdf', 'eps', 'json', or 'html'.
        If 'html', an interactive plot will be saved.
    save_scale : float, default 1
        For interactive plots only. Multiply title/legend/axis/canvas sizes by
        this factor when saving the figure.
    normalize : bool or list of bool, default False
        Accepted for compatibility with R CHNOSZ but not implemented by this
        function; a warning is issued if it is set.
    as_residue : bool, default False
        Accepted for compatibility with R CHNOSZ but not implemented by this
        function; a warning is issued if it is set.
    **kwargs
        Additional arguments passed to matplotlib plotting functions

    Returns
    -------
    dict
        Dictionary containing:
        - plotvar : str, Variable that was plotted
        - plotvals : dict, Values that were plotted
        - names : list, Names used for labels
        - predominant : array or NA, Predominance matrix (for 2D)
        - balance : str or list, Balancing coefficients used
        - n.balance : list, Numerical balancing coefficients
        - ax : matplotlib.axes.Axes, The axes object used for plotting (if plot_it=True)
        - fig : matplotlib.figure.Figure, The figure object used for plotting (if plot_it=True)
        - All original eout contents
        For interactive=True the result is eout plus 'df', 'fig' and 'ax'
        (where 'ax' holds the plotly figure).

    Raises
    ------
    ValueError
        If eout is not from affinity()/equilibrate()/solubility(), if add_to
        has no 'ax' key, if type="saturation" is used with equilibrate()
        output, if a basis-species type is combined with alpha, if that basis
        species' logact is not numeric, or if there are more than 2 dimensions.
    NotImplementedError
        For 0-D (single point) input.

    Examples
    --------
    >>> import pychnosz
    >>> pychnosz.basis(["Fe2O3", "CO2", "H2O", "NH3", "H2S", "oxygen", "H+"],
    ...                [0, -3, 0, -4, -7, -80, -7])
    >>> pychnosz.species(["pyrite", "goethite"])
    >>> a = pychnosz.affinity(H2S=[-60, 20, 5], T=25, P=1)
    >>> d = diagram(a)

    Notes
    -----
    This implementation is based on R CHNOSZ diagram() function but adapted
    for Python's matplotlib plotting instead of R's base graphics. The key
    differences from diagram_from_WORM.py are:
    - Works directly with Python dict output from affinity() (no rpy2)
    - Uses matplotlib for 1D plots by default
    - Can optionally use plotly if requested
    """
    # Build the default here rather than in the signature so that calls never
    # share (and possibly mutate) one list object.
    if annotation_coords is None:
        annotation_coords = [0, 0]

    # These options are accepted but never used below; tell the caller instead
    # of silently ignoring them.
    if isinstance(normalize, (list, tuple)):
        normalize_requested = any(normalize)
    else:
        normalize_requested = bool(normalize)
    if normalize_requested or as_residue:
        warnings.warn("diagram: the 'normalize' and 'as_residue' arguments are "
                      "not implemented and are ignored", stacklevel=2)

    # Handle add_to: work on a deep copy so the caller's plot is preserved.
    # Providing add_to switches on "add" mode.
    ax = None
    add = add_to is not None
    plot_was_provided = add

    if add_to is not None:
        plot_copy = copy_plot(add_to)
        if 'ax' in plot_copy:
            ax = plot_copy['ax']
        else:
            raise ValueError("The 'add_to' parameter must contain an 'ax' key (a diagram result dictionary)")

    # If interactive mode is requested, delegate to diagram_interactive
    if interactive:
        df, fig = diagram_interactive(
            eout=eout,
            type=type,
            main=main,
            borders=lwd,
            names=names,
            format_names=format_names,
            annotation=annotation,
            annotation_coords=annotation_coords,
            balance=balance,
            xlab=xlab,
            ylab=ylab,
            fill=fill,
            width=width,
            height=height,
            alpha=alpha,
            plot_it=plot_it,
            add=add,
            ax=ax,
            col=col,
            lty=lty,
            lwd=lwd,
            cex=cex,
            contour_method=contour_method,
            save_as=save_as,
            save_format=save_format,
            save_scale=save_scale,
            messages=messages
        )
        # diagram_interactive returns (df, fig); wrap in a dict that also carries
        # the eout data so water_lines() can access vars, vals, basis, etc.
        result = {
            **eout,
            'df': df,
            'fig': fig,
            'ax': fig  # for compatibility with the add_to workflow
        }
        return result

    # Check that eout is valid
    efun = eout.get('fun', '')
    if efun not in ['affinity', 'equilibrate', 'solubility']:
        raise ValueError("'eout' is not the output from affinity(), equilibrate(), or solubility()")

    # affinity() output has no equilibrated activities (Python or R naming)
    eout_is_aout = 'loga_equil' not in eout and 'loga.equil' not in eout

    # Is type the name of a basis species?
    plot_loga_basis = False
    if type not in ["auto", "saturation", "loga.equil", "loga_equil", "loga.balance", "loga_balance"]:
        if 'basis' in eout:
            basis_species = list(eout['basis'].index) if hasattr(eout['basis'], 'index') else []
            if type in basis_species:
                plot_loga_basis = True
                if alpha:
                    raise ValueError("equilibrium activities of basis species not available with alpha = TRUE")

    # type="saturation" requires affinity output
    if type == "saturation" and not eout_is_aout:
        raise ValueError("type='saturation' requires output from affinity(), not equilibrate()")

    # Number of dimensions; values is a dict (affinity) or a list (equilibrate)
    values = eout['values']
    if isinstance(values, dict):
        first_values = list(values.values())[0]
    elif isinstance(values, list):
        first_values = values[0]
    else:
        first_values = values

    if hasattr(first_values, 'shape'):
        nd = len(first_values.shape)
    elif hasattr(first_values, '__len__'):
        nd = 1
    else:
        nd = 0  # single value

    # Balancing coefficients
    n_values = len(values) if isinstance(values, (dict, list)) else 1
    if eout_is_aout and type == "auto":
        n_balance, balance = _get_balance(eout, balance, messages)
    elif eout_is_aout and type == "saturation":
        # Saturation lines are not normalized by stoichiometry
        n_balance = [1] * n_values
        if balance is None:
            balance = 1
    elif 'n_balance' in eout:
        # equilibrate() output carries its own balancing coefficients
        n_balance = eout['n_balance']
        balance = eout.get('balance', 1)
    else:
        n_balance = [1] * n_values
        if balance is None:
            balance = 1

    # Determine what to plot
    plotvals = {}
    plotvar = eout.get('property', 'A')

    if plot_loga_basis:
        # Equilibrium log activity/fugacity of a basis species
        basis_df = eout['basis']
        ibasis = list(basis_df.index).index(type)

        # Logarithm of activity used in the affinity calculation
        logact = basis_df.iloc[ibasis]['logact']
        try:
            loga_basis = float(logact)
        except (ValueError, TypeError):
            raise ValueError(f"the logarithm of activity for basis species {type} is not numeric - was a buffer selected?")

        # Reaction coefficients of this basis species
        # (eout['species'] has the basis species as its leading columns)
        nu_basis = eout['species'].iloc[:, ibasis].values

        if isinstance(values, dict):
            value_items = list(values.items())
        else:
            ispecies = eout['species']['ispecies']
            value_items = [(ispecies.iloc[i], v) for i, v in enumerate(values)]

        # loga_equilibrium = loga_basis - affinity / nu_basis
        for i, (sp_idx, affinity_vals) in enumerate(value_items):
            if nu_basis[i] == 0:
                # The affinity does not depend on this basis species' activity,
                # so there is no equilibrium value (R CHNOSZ returns NA here).
                plotvals[sp_idx] = np.full(np.shape(affinity_vals), np.nan)
            else:
                plotvals[sp_idx] = loga_basis - affinity_vals / nu_basis[i]

        plotvar = type

        # n_balance is not used for basis species plots but kept for compatibility
        n_balance = [1] * len(plotvals)
        if balance is None:
            balance = 1
    elif eout_is_aout:
        # Affinity values divided by balancing coefficients
        if isinstance(values, dict):
            for i, (species_idx, species_values) in enumerate(values.items()):
                plotvals[species_idx] = species_values / n_balance[i]
        elif isinstance(values, list):
            for i, species_values in enumerate(values):
                species_idx = eout['species']['ispecies'].iloc[i]
                plotvals[species_idx] = species_values / n_balance[i]

        if plotvar == 'A':
            plotvar = 'A/(2.303RT)'
            if nd == 1 and messages:
                print(f"diagram: plotting {plotvar} / n.balance")
    else:
        # Equilibrated activities (Python or R key naming)
        loga_equil_key = 'loga_equil' if 'loga_equil' in eout else 'loga.equil'
        loga_equil_list = eout[loga_equil_key]

        # Use integer positions (not ispecies) as keys so that duplicated
        # species do not overwrite each other.
        if isinstance(loga_equil_list, list):
            for i, loga_val in enumerate(loga_equil_list):
                plotvals[i] = loga_val
        else:
            plotvals = loga_equil_list

        plotvar = 'loga.equil'

    # Degree of formation
    if alpha:
        # Remove logarithms; a float base avoids numpy's error for integer
        # arrays raised to negative integer powers.
        act_vals = {k: np.power(10.0, v) for k, v in plotvals.items()}

        if alpha == "balance":
            for i, k in enumerate(list(act_vals.keys())):
                act_vals[k] = act_vals[k] * n_balance[i]

        # Element-wise sum of activities
        first_val = list(act_vals.values())[0]
        if isinstance(first_val, np.ndarray):
            sum_act = np.zeros_like(first_val)
            for v in act_vals.values():
                sum_act = sum_act + v
        else:
            sum_act = sum(act_vals.values())

        plotvals = {k: v / sum_act for k, v in act_vals.items()}
        plotvar = "alpha"

    # Species labels
    species_df = eout['species']
    if names is None:
        names = species_df['name'].tolist()

    if format_names and not alpha:
        names = [_format_chemname(name) for name in names]

    if nd == 0:
        raise NotImplementedError("0-D bar plots not yet implemented")
    elif nd == 1:
        result = _plot_1d(eout, plotvals, plotvar, names, n_balance, balance,
                          xlab, ylab, xlim, ylim, col, lty, lwd, main, add, plot_it, ax, width, height, plot_was_provided, **kwargs)
    elif nd == 2:
        # lty and cex go through kwargs for saturation plots
        result = _plot_2d(eout, plotvals, plotvar, names, n_balance, balance,
                          xlab, ylab, xlim, ylim, col, col_names, fill, fill_NA, limit_water, lwd, main, add, plot_it, ax,
                          type, contour_method, messages, width, height, plot_was_provided, lty=lty, cex=cex, **kwargs)
    else:
        raise ValueError(f"Cannot create diagram with {nd} dimensions")

    # Jupyter display behavior: close the figure when plot_it=False so it is
    # not auto-displayed (it stays available as result['fig']); otherwise try
    # IPython's display and fall back silently outside IPython.
    if not plot_it and result is not None and 'fig' in result:
        plt.close(result['fig'])
    elif plot_it and result is not None and 'fig' in result:
        try:
            from IPython.display import display
            display(result['fig'])
        except (ImportError, NameError):
            pass

    return result
552
+
553
+
554
def _balance_on_basis(species_df: pd.DataFrame, balance_col: str, messages: bool) -> List[float]:
    """Return the coefficients of basis species ``balance_col`` in each formation reaction.

    Raises ValueError if any species lacks that basis species, because the
    balanced affinities would then be divided by zero (R CHNOSZ stops here too).
    """
    n_balance = species_df[balance_col].tolist()
    if messages:
        print(f"balance: on moles of {balance_col} in formation reactions")
    if any(x == 0 for x in n_balance):
        raise ValueError(f"some species have no {balance_col} in the formation reaction")
    return n_balance


def _get_balance(eout: Dict[str, Any], balance: Optional[Union[str, float, List[float]]], messages: bool = True) -> Tuple[List[float], Union[str, int, List[float]]]:
    """
    Get balancing coefficients for formation reactions.

    This implements the R CHNOSZ balance() function logic for determining
    how to balance formation reactions when calculating diagrams.

    Parameters
    ----------
    eout : dict
        Output from affinity(); only eout['species'] is read. Its columns
        other than 'ispecies', 'name', 'state' and 'logact' are taken to be
        basis-species coefficients.
    balance : str, float, list of float, or None
        Balancing specification:
        - None: the first basis species present in every formation reaction
        - 1 or "1": one mole of each species (formula units)
        - other number: 1-based index of a basis species column
        - str: name of a basis species column
        - list/tuple/array: explicit coefficients, one per species
        Invalid specifications fall back to 1 with a warning.
    messages : bool, default True
        Print which balance is used?

    Returns
    -------
    tuple of (list of float, balance_name)
        - Balancing coefficients for each species (sign-flipped if all negative)
        - The balance identifier used

    Raises
    ------
    ValueError
        If balance is None and no basis species is present in all formation
        reactions, or if the selected basis species is absent (coefficient 0)
        from some formation reaction.
    """
    species_df = eout['species']
    n_species = len(species_df)

    # Basis species coefficient columns (exclude metadata columns)
    basis_cols = [col for col in species_df.columns
                  if col not in ['ispecies', 'name', 'state', 'logact']]

    # Normalize sequence-like inputs; this also avoids the ambiguous
    # element-wise comparison `array == 1` below.
    if isinstance(balance, np.ndarray):
        balance = balance.item() if balance.ndim == 0 else balance.tolist()
    elif isinstance(balance, tuple):
        balance = list(balance)

    if balance is None:
        # Auto-select using which_balance logic
        ibalance = _which_balance(species_df, basis_cols)
        if len(ibalance) == 0:
            raise ValueError("no basis species is present in all formation reactions")
        balance = basis_cols[ibalance[0]]
        n_balance = _balance_on_basis(species_df, balance, messages)
    elif not isinstance(balance, list) and (balance == 1 or balance == "1"):
        # Balance on one mole of species (formula units)
        n_balance = [1] * n_species
        if messages:
            print("balance: on supplied numeric argument (1) [1 means balance on formula units]")
        balance = 1
    elif isinstance(balance, (int, float, np.integer, np.floating)):
        # 1-based index of a basis species
        if 0 < balance <= len(basis_cols):
            balance_col = basis_cols[int(balance) - 1]
            n_balance = _balance_on_basis(species_df, balance_col, messages)
            balance = balance_col
        else:
            warnings.warn(f"Balance index {balance} out of range, using 1")
            n_balance = [1] * n_species
            balance = 1
    elif isinstance(balance, str):
        # Named basis species; metadata columns such as 'name' are not valid
        if balance in basis_cols:
            n_balance = _balance_on_basis(species_df, balance, messages)
        else:
            warnings.warn(f"Balance species '{balance}' not found, using 1")
            n_balance = [1] * n_species
            balance = 1
    elif isinstance(balance, list):
        # User-supplied coefficients
        if len(balance) == n_species:
            n_balance = list(balance)
            if messages:
                print(f"balance: on supplied numeric argument ({','.join(map(str, balance))})")
        else:
            warnings.warn(f"Balance list length ({len(balance)}) doesn't match species count ({n_species}), using 1")
            n_balance = [1] * n_species
            balance = 1
    else:
        n_balance = [1] * n_species
        balance = 1

    # If every coefficient is negative, make them all positive
    if all(x < 0 for x in n_balance):
        n_balance = [-x for x in n_balance]

    return n_balance, balance
639
+
640
+
641
+ def _which_balance(species_df: pd.DataFrame, basis_cols: List[str]) -> List[int]:
642
+ """
643
+ Find basis species present in all formation reactions.
644
+
645
+ This implements R CHNOSZ which.balance() function.
646
+
647
+ Parameters
648
+ ----------
649
+ species_df : pd.DataFrame
650
+ Species dataframe with stoichiometric coefficients
651
+ basis_cols : list of str
652
+ Names of basis species columns
653
+
654
+ Returns
655
+ -------
656
+ list of int
657
+ Indices of basis species present in all reactions (0-indexed)
658
+ """
659
+ ib = []
660
+ for i, col in enumerate(basis_cols):
661
+ coeffs = species_df[col].values
662
+ # Check if all coefficients are non-zero
663
+ if np.all(coeffs != 0):
664
+ ib.append(i)
665
+ return ib
666
+
667
+
668
def _plot_1d(eout: Dict[str, Any],
             plotvals: Dict,
             plotvar: str,
             names: List[str],
             n_balance: List[float],
             balance: Optional[Union[str, float, List[float]]],
             xlab: Optional[str],
             ylab: Optional[str],
             xlim: Optional[List[float]],
             ylim: Optional[List[float]],
             col: Optional[Union[str, List[str]]],
             lty: Optional[Union[str, int, List]],
             lwd: Union[float, List[float]],
             main: Optional[str],
             add: bool,
             plot_it: bool,
             ax: Optional[Any],
             width: int = 600,
             height: int = 520,
             plot_was_provided: bool = False,
             **kwargs) -> Dict[str, Any]:
    """
    Create a 1-D line plot of ``plotvals`` against the first affinity() variable.

    Parameters
    ----------
    (See diagram() for parameter descriptions.)
    Extra ``kwargs`` are forwarded to every ``ax.plot()`` call.

    Returns
    -------
    dict
        ``eout`` augmented with 'plotvar', 'plotvals', 'names',
        'predominant' (always NaN for 1-D plots), 'balance', 'n.balance'
        and, when the figure should be handed back to the caller, 'ax' and
        'fig'.
    """
    def _recycle(values, n):
        # R-style recycling of a per-species style sequence to length n
        values = list(values)
        return (values * (n // len(values) + 1))[:n]

    # x-axis values come from the first (only) affinity() variable
    xvar = eout['vars'][0]
    xvals = np.asarray(eout['vals'][xvar])

    if xlab is None:
        xlab = _axis_label(xvar, eout)
    if ylab is None:
        ylab = _axis_label(plotvar, eout)
    if xlim is None:
        xlim = [xvals[0], xvals[-1]]

    # Per-species colors, line styles and widths
    n_species = len(plotvals)

    if col is None:
        # Use the matplotlib default color cycle
        colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
        col = [colors[i % len(colors)] for i in range(n_species)]
    elif isinstance(col, str):
        col = [col] * n_species
    else:
        col = _recycle(col, n_species)

    if lty is None:
        lty = ['-'] * n_species
    elif isinstance(lty, (str, int)):
        lty = [lty] * n_species
    else:
        lty = _recycle(lty, n_species)

    # Numeric (R-style) line types -> matplotlib line styles
    lty_map = {1: '-', 2: '--', 3: '-.', 4: ':', 5: '-', 6: '--'}
    lty = [lty_map.get(lt, lt) if isinstance(lt, int) else lt for lt in lty]

    # np.ndim() treats Python and NumPy scalars alike
    lwd = [lwd] * n_species if np.ndim(lwd) == 0 else _recycle(lwd, n_species)

    # Temporarily disable interactive mode if plot_it=False so Jupyter does
    # not auto-display the figure; the finally clause always restores it.
    was_interactive = plt.isinteractive()
    if not plot_it:
        plt.ioff()

    try:
        # Pixels -> inches at a fixed 96 DPI
        dpi = 96
        figsize_inches = (width / dpi, height / dpi)

        # Always create/obtain axes (even if plot_it=False) so the result can
        # be reused later, e.g. with the add_to parameter.
        fig = None
        ax_was_provided = ax is not None

        if ax is not None:
            fig = ax.get_figure()
        elif not add:
            fig, ax = plt.subplots(figsize=figsize_inches, dpi=dpi)
        else:
            try:
                ax = plt.gca()
                fig = ax.get_figure()
            except Exception:
                fig, ax = plt.subplots(figsize=figsize_inches, dpi=dpi)

        # Draw every species; plot_it only controls display, not drawing
        for i, yvals in enumerate(plotvals.values()):
            yvals = np.asarray(yvals)
            if yvals.ndim == 0:
                # A single value is drawn as a horizontal line across the x range
                yvals = np.full(len(xvals), yvals)
            ax.plot(xvals, yvals,
                    color=col[i],
                    linestyle=lty[i],
                    linewidth=lwd[i],
                    label=names[i],
                    **kwargs)

        ax.set_xlabel(xlab)
        ax.set_ylabel(ylab)
        if xlim is not None:
            ax.set_xlim(xlim)
        if ylim is not None:
            ax.set_ylim(ylim)

        ax.legend()
        if main is not None:
            ax.set_title(main)
        ax.grid(True, alpha=0.3)

        if not add:
            plt.tight_layout()

        result = {
            **eout,
            'plotvar': plotvar,
            'plotvals': plotvals,
            'names': names,
            'predominant': np.nan,
            'balance': balance,
            'n.balance': n_balance
        }

        # Hand the figure back unless the caller supplied its own axes
        # (plot_was_provided overrides that)
        if fig is not None and (not ax_was_provided or plot_was_provided):
            result['ax'] = ax
            result['fig'] = fig
    finally:
        # Always restore interactive mode to its original state
        if was_interactive and not plt.isinteractive():
            plt.ion()
        elif not was_interactive and plt.isinteractive():
            plt.ioff()

    return result
844
+
845
+
846
def _plot_2d(eout: Dict[str, Any],
             plotvals: Dict[int, np.ndarray],
             plotvar: str,
             names: List[str],
             n_balance: List[float],
             balance: Union[str, int, List[float]],
             xlab: Optional[str],
             ylab: Optional[str],
             xlim: Optional[List[float]],
             ylim: Optional[List[float]],
             col: Optional[Union[str, List[str]]],
             col_names: Optional[Union[str, List[str]]],
             fill: Optional[str],
             fill_NA: str,
             limit_water: Optional[bool],
             lwd: Union[float, List[float]],
             main: Optional[str],
             add: bool,
             plot_it: bool,
             ax: Optional[Any],
             type: str = "auto",
             contour_method: Optional[str] = "edge",
             messages: bool = True,
             width: int = 600,
             height: int = 520,
             plot_was_provided: bool = False,
             **kwargs) -> Dict[str, Any]:
    """
    Create a 2-D predominance diagram (internal function).

    This function determines which species has the maximum value in
    ``plotvals`` at each point on a 2-D grid and creates a predominance
    field diagram. When ``type == "saturation"`` the work is delegated to
    ``_plot_saturation_2d()`` instead.

    Parameters
    ----------
    eout : dict
        Output from affinity(); 'vars' must name exactly two variables.
    plotvals : dict
        Values to plot (affinity or equilibrium values), one 2-D array per
        species with rows for the first variable and columns for the second.
    plotvar : str
        Variable being plotted
    names : list of str
        Species names for labels
    n_balance : list of float
        Balancing coefficients
    balance : str, int, or list
        Balance identifier
    xlab, ylab : str or None
        Axis labels
    xlim, ylim : list or None
        Axis limits
    col : str, list, or None
        Colors for boundary lines in 2D plots
    col_names : str, list, or None
        Colors for field labels (text) in 2D plots
    fill : str or None
        Matplotlib colormap name for coloring predominance fields
        (e.g., 'viridis', 'plasma', 'terrain', 'rainbow', 'Set1', 'tab10')
    limit_water : bool or None
        Compute water stability limits; None means "True unless add".
    lwd : float or list of float
        Line width for drawing boundaries between predominance fields.
        A list is recycled over the species (one width per species'
        boundary contour). Set to 0 to disable borders.
    main : str or None
        Plot title
    add : bool
        Add to existing plot?
    plot_it : bool
        Display the plot?
    **kwargs
        Additional matplotlib arguments (forwarded to ``ax.imshow()``)

    Returns
    -------
    dict
        ``eout`` augmented with 'plotvar', 'plotvals', 'names',
        'predominant' (1-based index of the species with the maximum value
        at each grid point), 'predominant.values', 'balance', 'n.balance'
        and, when the figure should be handed back to the caller, 'ax' and
        'fig'.

    Raises
    ------
    ValueError
        If ``eout['vars']`` does not hold two variables or the first
        species' values are not a 2-D array.
    """
    def _recycle(values, n):
        # R-style recycling of a per-species style sequence to length n
        values = list(values)
        return (values * (n // len(values) + 1))[:n]

    def _span(vals):
        # First/last grid values (keeps decreasing axes, e.g. F- from -2 to -9,
        # decreasing). Identical limits are widened by 10% (0.1 around zero)
        # to avoid a matplotlib warning.
        lo, hi = vals[0], vals[-1]
        if lo == hi:
            pad = abs(hi) * 0.1 if hi != 0 else 0.1
            lo, hi = lo - pad, hi + pad
        return lo, hi

    def _pixel_edges(vals):
        # imshow's extent gives the outer edges of the image, while the grid
        # values are cell centres (as with R's image()). Pad by half a grid
        # step so each pixel centre sits on its grid value and lines up with
        # the contour-drawn borders. Assumes an evenly spaced grid.
        lo, hi = vals[0], vals[-1]
        if lo == hi:
            return _span(vals)
        half_step = (hi - lo) / (2 * (len(vals) - 1))
        return lo - half_step, hi + half_step

    vars_list = eout['vars']
    if len(vars_list) != 2:
        raise ValueError(f"Expected 2 variables for 2-D plot, got {len(vars_list)}")

    # R CHNOSZ convention: first variable in affinity() -> x-axis, second -> y-axis.
    # In each plotvals array, rows correspond to the first variable and
    # columns to the second.
    xvar, yvar = vars_list[0], vars_list[1]
    xvals = np.asarray(eout['vals'][xvar])
    yvals = np.asarray(eout['vals'][yvar])

    if xlab is None:
        xlab = _axis_label(xvar, eout)
    if ylab is None:
        ylab = _axis_label(yvar, eout)

    # Saturation lines are handled separately from predominance diagrams
    if type == "saturation":
        # lty and cex arrive in kwargs from the diagram() call
        lty_param = kwargs.pop('lty', None)
        cex_param = kwargs.pop('cex', 1.0)
        return _plot_saturation_2d(eout, plotvals, plotvar, names, n_balance, balance,
                                   xlab, ylab, xlim, ylim, col, lwd, lty_param, cex_param,
                                   main, add, plot_it, ax, contour_method, messages,
                                   width, height, plot_was_provided, **kwargs)

    # lty and cex must not be forwarded to matplotlib for predominance plots
    kwargs.pop('lty', None)
    kwargs.pop('cex', None)

    if messages:
        print("diagram: using maximum affinity method for 2-D diagram")

    # plotvals keys are either integer positions 0..n-1 (equilibrate output,
    # preserves duplicates) or ispecies values (affinity output).
    species_keys = list(plotvals.keys())
    n_species = len(species_keys)

    # Map stack position -> index into `names`
    species_df = eout['species']
    if all(isinstance(k, int) and k < len(names) for k in species_keys):
        predominant_to_names_idx = {i: species_keys[i] for i in range(n_species)}
    else:
        # ispecies values: use the first matching species row, else fall back
        # to the stack position
        predominant_to_names_idx = {}
        for i, sp_idx in enumerate(species_keys):
            matching_rows = species_df[species_df['ispecies'] == sp_idx].index.tolist()
            predominant_to_names_idx[i] = matching_rows[0] if matching_rows else i

    first_vals = np.asarray(plotvals[species_keys[0]])
    if first_vals.ndim != 2:
        raise ValueError(f"Expected 2-D array for each species, got shape {first_vals.shape}")
    n_xvar, n_yvar = first_vals.shape

    # Stack into (species, n_xvar, n_yvar); assignment broadcasts like the
    # original per-species fill
    affinity_stack = np.zeros((n_species, n_xvar, n_yvar))
    for i, sp_idx in enumerate(species_keys):
        affinity_stack[i, :, :] = plotvals[sp_idx]

    # Species with the maximum value at each grid point (0-based), its value,
    # and the 1-based R-compatible version
    predominant_indices = np.argmax(affinity_stack, axis=0)
    predominant_values = np.max(affinity_stack, axis=0)
    predominant = predominant_indices + 1

    # Water stability limits (default: on unless adding to an existing plot)
    H2O_predominant = None
    if limit_water is None:
        limit_water = not add

    if limit_water:
        wl = water_lines(eout, plot_it=False, messages=messages)
        if not (wl['y_oxidation'] is None or wl['y_reduction'] is None):
            # Float copy so points outside water stability can be set to NaN
            H2O_predominant = predominant.astype(float)
            for i in range(len(wl['xpoints'])):
                ymin = min(wl['y_oxidation'][i], wl['y_reduction'][i])
                ymax = max(wl['y_oxidation'][i], wl['y_reduction'][i])
                if not wl['swapped']:
                    # Normal orientation: second variable is on the y-axis
                    iNA = (yvals < ymin) | (yvals > ymax)
                    H2O_predominant[i, iNA] = np.nan
                else:
                    # Swapped: the first variable plays the y role
                    iNA = (xvals < ymin) | (xvals > ymax)
                    H2O_predominant[iNA, i] = np.nan
            # NOTE(review): H2O_predominant is not used further in this
            # function (no masking/shading is drawn here and it is not
            # returned); confirm whether fill_NA masking was intended.

    # imshow expects (n_rows, n_cols) = (n_yaxis, n_xaxis)
    predominant_indices_T = predominant_indices.T

    # Temporarily disable interactive mode if plot_it=False so Jupyter does
    # not auto-display the figure; the finally clause always restores it.
    was_interactive = plt.isinteractive()
    if not plot_it:
        plt.ioff()

    try:
        # Pixels -> inches at a fixed 96 DPI
        dpi = 96
        figsize_inches = (width / dpi, height / dpi)

        # Always create/obtain axes (even if plot_it=False) so the result can
        # be reused later, e.g. with the add_to parameter.
        fig = None
        ax_was_provided = ax is not None

        if ax is not None:
            fig = ax.get_figure()
        elif not add:
            fig, ax = plt.subplots(figsize=figsize_inches, dpi=dpi)
        else:
            try:
                ax = plt.gca()
                fig = ax.get_figure()
            except Exception:
                fig, ax = plt.subplots(figsize=figsize_inches, dpi=dpi)

        # With add=True, overlay without fill unless fill was requested
        draw_fill = not (add and fill is None)

        # Per-species border widths (scalar lwd -> same width for all)
        lwd_list = [lwd] * n_species if np.ndim(lwd) == 0 else _recycle(lwd, n_species)
        draw_borders = any(w > 0 for w in lwd_list)

        # Field fill colors: colormap > R-style white (with borders) > color cycle
        if fill is not None:
            try:
                # pyplot.get_cmap works on all supported matplotlib versions;
                # matplotlib.cm.get_cmap was removed in 3.9.
                cmap = plt.get_cmap(fill)
                # Sample 0.1-0.9 to avoid the very light/dark ends
                fill_colors = [cmap(v) for v in np.linspace(0.1, 0.9, n_species)]
            except (ValueError, KeyError):
                warnings.warn(f"Colormap '{fill}' not found, using default colors")
                colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
                fill_colors = [colors[i % len(colors)] for i in range(n_species)]
        elif draw_borders:
            fill_colors = ['white'] * n_species
        else:
            colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
            fill_colors = [colors[i % len(colors)] for i in range(n_species)]

        # Boundary line colors (col) and text label colors (col_names)
        if col is None:
            boundary_colors = ['black'] * n_species
        elif isinstance(col, str):
            boundary_colors = [col] * n_species
        else:
            boundary_colors = _recycle(col, n_species)

        if col_names is None:
            text_colors = ['black'] * n_species
        elif isinstance(col_names, str):
            text_colors = [col_names] * n_species
        else:
            text_colors = _recycle(col_names, n_species)

        # NOTE: water-instability shading is not drawn here; it is drawn only
        # when water_lines() is called explicitly.

        if draw_fill:
            from matplotlib.colors import to_rgb

            # RGB image of shape (n_yvar, n_xvar, 3) colored by predominant species
            colored_predominant = np.zeros((n_yvar, n_xvar, 3))
            for i in range(n_species):
                colored_predominant[predominant_indices_T == i] = to_rgb(fill_colors[i])

            # Keep first->last order so decreasing axes stay decreasing
            x_lo, x_hi = _pixel_edges(xvals)
            y_lo, y_hi = _pixel_edges(yvals)
            ax.imshow(colored_predominant, aspect='auto', origin='lower',
                      extent=[x_lo, x_hi, y_lo, y_hi], interpolation='nearest', **kwargs)

            # Label each field at the mean grid position of its points
            for i in range(n_species):
                rows, cols = np.where(predominant_indices_T == i)
                if rows.size == 0:
                    continue
                # rows index yvals, cols index xvals (transposed array)
                x_pos = xvals[int(cols.mean())]
                y_pos = yvals[int(rows.mean())]
                # Map stack position to names index (handles duplicate species)
                ax.text(x_pos, y_pos, names[predominant_to_names_idx[i]],
                        ha='center', va='center',
                        color=text_colors[i],
                        bbox=dict(boxstyle='round', facecolor='white', edgecolor='none', alpha=0.7),
                        fontsize=10, fontweight='bold')

        ax.set_xlabel(xlab)
        ax.set_ylabel(ylab)
        ax.set_xlim(xlim if xlim is not None else list(_span(xvals)))
        ax.set_ylim(ylim if ylim is not None else list(_span(yvals)))

        if main is not None:
            ax.set_title(main)

        # Field borders, following R CHNOSZ: contour each species' indicator
        # at level 0.5. contour() needs at least a 2x2 grid.
        if draw_borders and n_xvar > 1 and n_yvar > 1:
            X, Y = np.meshgrid(xvals, yvals)
            # Skip the last species: each of its borders is shared with a
            # field that is contoured earlier in the loop.
            for species_idx in np.unique(predominant_indices_T)[:-1]:
                species_idx = int(species_idx)
                line_width = lwd_list[species_idx]
                if line_width <= 0:
                    continue
                z = (predominant_indices_T == species_idx).astype(float)
                try:
                    ax.contour(X, Y, z, levels=[0.5], colors=[boundary_colors[species_idx]],
                               linewidths=line_width, zorder=10)
                except Exception as err:
                    warnings.warn(f"Could not draw field boundary for "
                                  f"{names[predominant_to_names_idx[species_idx]]}: {err}")

        if not add:
            plt.tight_layout()

        # The figure is intentionally not closed when plot_it=False; display it
        # later with d['fig'].show() or display(d['fig']) in Jupyter.
        result = {
            **eout,
            'plotvar': plotvar,
            'plotvals': plotvals,
            'names': names,
            'predominant': predominant,
            'predominant.values': predominant_values,
            'balance': balance,
            'n.balance': n_balance
        }

        # Hand the figure back unless the caller supplied its own axes
        # (plot_was_provided overrides that)
        if fig is not None and (not ax_was_provided or plot_was_provided):
            result['ax'] = ax
            result['fig'] = fig
    finally:
        # Always restore interactive mode to its original state
        if was_interactive and not plt.isinteractive():
            plt.ion()
        elif not was_interactive and plt.isinteractive():
            plt.ioff()

    return result
1316
+
1317
+
1318
+ def _plot_saturation_2d(eout: Dict[str, Any],
1319
+ plotvals: Dict[int, np.ndarray],
1320
+ plotvar: str,
1321
+ names: List[str],
1322
+ n_balance: List[float],
1323
+ balance: Union[str, int, List[float]],
1324
+ xlab: str,
1325
+ ylab: str,
1326
+ xlim: Optional[List[float]],
1327
+ ylim: Optional[List[float]],
1328
+ col: Optional[Union[str, List[str]]],
1329
+ lwd: Union[float, List[float]],
1330
+ lty: Optional[Union[str, int, List]],
1331
+ cex: Union[float, List[float]],
1332
+ main: Optional[str],
1333
+ add: bool,
1334
+ plot_it: bool,
1335
+ ax: Optional[Any],
1336
+ contour_method: Optional[Union[str, List[str]]],
1337
+ messages: bool = True,
1338
+ width: int = 600,
1339
+ height: int = 520,
1340
+ plot_was_provided: bool = False,
1341
+ **kwargs) -> Dict[str, Any]:
1342
+ """
1343
+ Plot saturation lines (affinity=0 contours) for 2-D diagrams.
1344
+
1345
+ This function draws contour lines where affinity = 0 for each species,
1346
+ indicating saturation boundaries (e.g., mineral precipitation thresholds).
1347
+
1348
+ Parameters
1349
+ ----------
1350
+ (Most parameters are the same as _plot_2d)
1351
+ contour_method : str, list of str, or None
1352
+ Method for labeling contour lines. Can be a single value (applied to all species)
1353
+ or a list (one per species). If None, NA, or "", disable labels.
1354
+ Matplotlib doesn't support the same contour methods as R, so this mainly
1355
+ controls whether labels are drawn (any non-None/non-empty value enables labels).
1356
+
1357
+ Returns
1358
+ -------
1359
+ dict
1360
+ Result dictionary
1361
+ """
1362
+
1363
+ # Get the two variables
1364
+ vars_list = eout['vars']
1365
+ xvar = vars_list[0]
1366
+ yvar = vars_list[1]
1367
+
1368
+ # Get the values for each variable
1369
+ xvals = np.asarray(eout['vals'][xvar])
1370
+ yvals = np.asarray(eout['vals'][yvar])
1371
+
1372
+ species_keys = list(plotvals.keys())
1373
+ n_species = len(species_keys)
1374
+
1375
+ if messages:
1376
+ print(f"diagram: plotting saturation lines for 2-D diagram")
1377
+
1378
+ # Set up colors and line styles
1379
+ if col is None:
1380
+ # Use matplotlib default color cycle
1381
+ prop_cycle = plt.rcParams['axes.prop_cycle']
1382
+ colors = prop_cycle.by_key()['color']
1383
+ col = [colors[i % len(colors)] for i in range(n_species)]
1384
+ elif isinstance(col, str):
1385
+ col = [col] * n_species
1386
+ else:
1387
+ col = list(col)
1388
+ if len(col) < n_species:
1389
+ col = col * (n_species // len(col) + 1)
1390
+ col = col[:n_species]
1391
+
1392
+ if isinstance(lwd, (int, float)):
1393
+ lwd = [lwd] * n_species
1394
+ else:
1395
+ lwd = list(lwd)
1396
+ if len(lwd) < n_species:
1397
+ lwd = lwd * (n_species // len(lwd) + 1)
1398
+ lwd = lwd[:n_species]
1399
+
1400
+ # Handle line styles (lty)
1401
+ if lty is None:
1402
+ lty = ['-'] * n_species
1403
+ elif isinstance(lty, (str, int)):
1404
+ lty = [lty] * n_species
1405
+ else:
1406
+ lty = list(lty)
1407
+ if len(lty) < n_species:
1408
+ lty = lty * (n_species // len(lty) + 1)
1409
+ lty = lty[:n_species]
1410
+
1411
+ # Convert numeric line styles to matplotlib styles
1412
+ lty_map = {1: '-', 2: '--', 3: '-.', 4: ':', 5: '-', 6: '--'}
1413
+ lty = [lty_map.get(lt, lt) if isinstance(lt, int) else lt for lt in lty]
1414
+
1415
+ # Handle text size (cex) for contour labels
1416
+ if isinstance(cex, (int, float)):
1417
+ cex_list = [cex] * n_species
1418
+ else:
1419
+ cex_list = list(cex)
1420
+ if len(cex_list) < n_species:
1421
+ cex_list = cex_list * (n_species // len(cex_list) + 1)
1422
+ cex_list = cex_list[:n_species]
1423
+
1424
+ # Determine if labels should be drawn (per species)
1425
+ # Convert contour_method to a list (one per species)
1426
+ if contour_method is None or contour_method == "" or (isinstance(contour_method, str) and contour_method.upper() == "NA"):
1427
+ # No labels for any species
1428
+ drawlabels = [False] * n_species
1429
+ elif isinstance(contour_method, str):
1430
+ # Same method for all species
1431
+ drawlabels = [True] * n_species
1432
+ elif isinstance(contour_method, list):
1433
+ # Per-species methods
1434
+ if len(contour_method) != n_species:
1435
+ # Repeat/extend to match number of species
1436
+ contour_method_extended = list(contour_method) * (n_species // len(contour_method) + 1)
1437
+ contour_method_extended = contour_method_extended[:n_species]
1438
+ else:
1439
+ contour_method_extended = contour_method
1440
+
1441
+ # Check each method to determine if labels should be drawn
1442
+ drawlabels = []
1443
+ for method in contour_method_extended:
1444
+ if method is None or method == "" or (isinstance(method, str) and method.upper() == "NA"):
1445
+ drawlabels.append(False)
1446
+ else:
1447
+ drawlabels.append(True)
1448
+ else:
1449
+ drawlabels = [True] * n_species
1450
+
1451
+ # Temporarily disable interactive mode if plot_it=False
1452
+ # This prevents Jupyter from auto-displaying the figure
1453
+ was_interactive = plt.isinteractive()
1454
+ if not plot_it:
1455
+ plt.ioff()
1456
+
1457
+ # Convert width and height from pixels to inches for matplotlib
1458
+ # Use standard 96 DPI for consistency with web/screen displays
1459
+ dpi = 96
1460
+ figsize_inches = (width / dpi, height / dpi)
1461
+
1462
+ # Create figure and axes (always, even if plot_it=False)
1463
+ # This allows the plot to be used with add_to parameter later
1464
+ fig = None
1465
+ ax_was_provided = ax is not None # Track if ax was passed as parameter
1466
+
1467
+ if ax is not None:
1468
+ fig = ax.get_figure()
1469
+ elif not add:
1470
+ fig, ax = plt.subplots(figsize=figsize_inches, dpi=dpi)
1471
+ else:
1472
+ try:
1473
+ ax = plt.gca()
1474
+ fig = ax.get_figure()
1475
+ except:
1476
+ fig, ax = plt.subplots(figsize=figsize_inches, dpi=dpi)
1477
+
1478
+ # Only do the actual plotting if plot_it=True
1479
+ # Draw the plot content (always, regardless of plot_it)
1480
+ # plot_it only controls display, not drawing
1481
+ if not add:
1482
+ ax.set_xlabel(xlab)
1483
+ ax.set_ylabel(ylab)
1484
+
1485
+ if xlim is not None:
1486
+ ax.set_xlim(xlim)
1487
+ else:
1488
+ ax.set_xlim([xvals.min(), xvals.max()])
1489
+
1490
+ if ylim is not None:
1491
+ ax.set_ylim(ylim)
1492
+ else:
1493
+ ax.set_ylim([yvals.min(), yvals.max()])
1494
+
1495
+ if main is not None:
1496
+ ax.set_title(main)
1497
+
1498
+ # Plot saturation lines (affinity = 0 contours) for each species
1499
+ for i, sp_idx in enumerate(species_keys):
1500
+ zs = plotvals[sp_idx]
1501
+
1502
+ # Skip plotting if this species has no possible saturation line
1503
+ if len(np.unique(zs)) == 1:
1504
+ if messages:
1505
+ print(f"diagram: no saturation line possible for {names[i]}")
1506
+ continue
1507
+
1508
+ # Skip if line is outside the plot range
1509
+ if np.all(zs < 0) or np.all(zs > 0):
1510
+ if messages:
1511
+ print(f"diagram: beyond range for saturation line of {names[i]}")
1512
+ continue
1513
+
1514
+ # Draw the contour line at affinity = 0
1515
+ # matplotlib's contour needs (X, Y, Z) where X and Y are meshgrids
1516
+ X, Y = np.meshgrid(xvals, yvals)
1517
+ # Transpose zs to match meshgrid orientation
1518
+ # plotvals has shape (n_xvar, n_yvar), but contour expects (n_yvar, n_xvar)
1519
+ zs_T = zs.T
1520
+
1521
+ try:
1522
+ # Calculate font size from cex (matplotlib default is ~10pt)
1523
+ fontsize = 9 * cex_list[i]
1524
+
1525
+ if drawlabels[i]:
1526
+ CS = ax.contour(X, Y, zs_T, levels=[0], colors=col[i],
1527
+ linewidths=lwd[i], linestyles=lty[i], **kwargs)
1528
+ ax.clabel(CS, inline=True, fontsize=fontsize, fmt=names[i])
1529
+ else:
1530
+ ax.contour(X, Y, zs_T, levels=[0], colors=col[i],
1531
+ linewidths=lwd[i], linestyles=lty[i], **kwargs)
1532
+ except Exception as e:
1533
+ warnings.warn(f"Could not plot contour for {names[i]}: {e}")
1534
+
1535
+ if not add:
1536
+ plt.tight_layout()
1537
+
1538
+ # Build output dictionary
1539
+ result = {
1540
+ **eout,
1541
+ 'plotvar': plotvar,
1542
+ 'plotvals': plotvals,
1543
+ 'names': names,
1544
+ 'predominant': np.nan,
1545
+ 'balance': balance,
1546
+ 'n.balance': n_balance
1547
+ }
1548
+
1549
+ # Add figure and axes to output if they were created
1550
+ if fig is not None:
1551
+ if not ax_was_provided or plot_was_provided:
1552
+ result['ax'] = ax
1553
+ result['fig'] = fig
1554
+
1555
+ # Always restore interactive mode to its original state
1556
+ if was_interactive and not plt.isinteractive():
1557
+ plt.ion()
1558
+ elif not was_interactive and plt.isinteractive():
1559
+ plt.ioff()
1560
+
1561
+ return result
1562
+
1563
+
1564
+ def _axis_label(var: str, eout: Dict[str, Any]) -> str:
1565
+ """
1566
+ Generate axis label for a variable.
1567
+
1568
+ Parameters
1569
+ ----------
1570
+ var : str
1571
+ Variable name
1572
+ eout : dict
1573
+ Output from affinity()
1574
+
1575
+ Returns
1576
+ -------
1577
+ str
1578
+ Formatted axis label
1579
+ """
1580
+
1581
+ # Special cases
1582
+ if var == 'A/(2.303RT)':
1583
+ return r'A/(2.303RT)'
1584
+ elif var == 'alpha':
1585
+ return r'$\alpha$'
1586
+ elif var == 'loga.equil':
1587
+ return r'log activity'
1588
+ elif var == 'pH':
1589
+ return 'pH'
1590
+ elif var == 'pe':
1591
+ return 'pe'
1592
+ elif var == 'Eh':
1593
+ return 'Eh (V)'
1594
+ elif var == 'T':
1595
+ return 'Temperature (°C)'
1596
+ elif var == 'P':
1597
+ return 'Pressure (bar)'
1598
+ elif var == 'IS':
1599
+ return 'Ionic strength'
1600
+ else:
1601
+ # Check if it's a basis species
1602
+ basis_df = eout.get('basis')
1603
+ if basis_df is not None and var in basis_df.index:
1604
+ state = basis_df.loc[var, 'state']
1605
+ # Format the chemical formula with proper LaTeX subscripts and superscripts
1606
+ var_formatted = _format_species_latex(var)
1607
+ if state in ['aq', 'liq', 'cr']:
1608
+ return f'$\\log\\ a_{{{var_formatted}}}$'
1609
+ else:
1610
+ return f'$\\log\\ f_{{{var_formatted}}}$'
1611
+ return var
1612
+
1613
+
1614
def _format_chemname(name: str) -> str:
    """
    Render a chemical formula as a matplotlib math-text label.

    The formula is converted by the shared ``_format_species_latex`` helper
    (from expression.py, so labels look the same everywhere) and the result
    is wrapped in ``$...$`` so matplotlib renders it in math mode.

    Parameters
    ----------
    name : str
        Chemical formula to format.

    Returns
    -------
    str
        The LaTeX-formatted formula enclosed in dollar signs.
    """
    return f'${_format_species_latex(name)}$'
1634
+
1635
+
1636
def water_lines(eout: Dict[str, Any],
                which: Union[str, List[str]] = ['oxidation', 'reduction'],
                lty: Union[int, str] = 2,
                lwd: float = 1,
                col: Optional[str] = None,
                plot_it: bool = True,
                messages: bool = True) -> Dict[str, Any]:
    """
    Draw water stability limits for Eh-pH, logfO2-pH, logfO2-T or Eh-T diagrams.

    This function adds lines showing the oxidation and reduction stability limits
    of water to diagrams. Above the oxidation line, water breaks down to O2.
    Below the reduction line, water breaks down to H2. When both lines are
    calculated, the regions outside the water stability field are also shaded
    gray.

    Parameters
    ----------
    eout : dict
        Output from affinity(), equilibrate(), or diagram()
    which : str or list of str, default ['oxidation', 'reduction']
        Which line(s) to draw: 'oxidation', 'reduction', or both.
        (The default list is only read, never modified.)
    lty : int or str, default 2
        Line style: a matplotlib linestyle string, or a numeric code mapped as
        1 '-', 2 '--', 3 '-.', 4 ':', 5 '-', 6 '--' (other integers give '--')
    lwd : float, default 1
        Line width
    col : str, optional
        Line color (matplotlib or Plotly color spec). If None, 'black' is used.
    plot_it : bool, default True
        Whether to display the figure after the lines are added. The lines
        (and shading) are drawn whenever ``eout`` carries a matplotlib axes
        ('ax') or a Plotly figure ('fig'), independent of this flag;
        plot_it=True additionally displays the figure (useful when the original
        diagram was created with plot_it=False). If ``eout`` has neither, only
        the water line coordinates are calculated and returned.
    messages : bool, default True
        Passed to subcrt() and convert() to control their messages.

    Returns
    -------
    dict
        Dictionary containing all keys from the input diagram (including 'fig', 'ax',
        'plotvar', 'plotvals', 'names', 'predominant', etc. if present) plus the
        following water line specific keys:
        - xpoints: x-axis values
        - y_oxidation: y values for oxidation line (or None)
        - y_reduction: y values for reduction line (or None)
        - swapped: whether axes were swapped
        The y values are None when the diagram does not have 1 or 2 variables
        or the axes are not a redox variable vs. pH/T/P (or a 1-D diagram).

    Examples
    --------
    >>> # Add water lines to an existing displayed diagram
    >>> basis(["Fe+2", "SO4-2", "H2O", "H+", "e-"], [0, math.log10(3), math.log10(0.75), 999, 999])
    >>> species(["rhomboclase", "ferricopiapite", "hydronium jarosite", "goethite", "melanterite", "pyrite"])
    >>> a = affinity(pH=[-1, 4, 256], pe=[-5, 23, 256])
    >>> d = diagram(a, main="Fe-S-O-H, after Majzlan et al., 2006")
    >>> water_lines(d, lwd=2)

    >>> # Add water lines and display when diagram was created with plot_it=False
    >>> d = diagram(a, main="Fe-S-O-H", plot_it=False)
    >>> water_lines(d, lwd=2)  # This will display the figure with water lines

    Notes
    -----
    This function only works on diagrams with a redox variable (Eh, pe, O2, or H2)
    on one axis and pH, T, P, or another non-redox variable on the other axis.
    For 1-D diagrams, vertical lines are drawn.
    """

    # Import here to avoid circular imports
    from ..utils.units import convert, envert
    from ..core.subcrt import subcrt

    def _recycle(values, shape):
        # Return `values` as a float array recycled to `shape`, like R's
        # rep(values, length.out = n). A scalar (or length-1 result) gives a
        # constant line, while a vector (e.g. logK computed along a T or P
        # axis) keeps every element instead of only the first. Forcing float
        # avoids the truncation np.full_like() caused with integer xpoints.
        return np.resize(np.asarray(values, dtype=float).ravel(), shape)

    # Create a deep copy of the input to preserve all diagram information
    # This allows us to return all the original keys plus water line data
    result = copy_plot(eout)

    # Detect if this is a Plotly figure (interactive diagram)
    is_plotly = False
    if 'fig' in result and result['fig'] is not None:
        is_plotly = hasattr(result['fig'], 'add_trace') and hasattr(result['fig'], 'update_layout')

    # Ensure which is a list
    if isinstance(which, str):
        which = [which]

    # Get number of variables used in affinity()
    nvar1 = len(result['vars'])

    # Determine actual number of variables from array dimensions
    # Check both loga.equil (equilibrate output) and values (affinity output)
    if 'loga_equil' in result or 'loga.equil' in result:
        loga_key = 'loga_equil' if 'loga_equil' in result else 'loga.equil'
        first_val = result[loga_key][0] if isinstance(result[loga_key], list) else list(result[loga_key].values())[0]
    else:
        first_val = list(result['values'].values())[0] if isinstance(result['values'], dict) else result['values'][0]

    if hasattr(first_val, 'shape'):
        dim = first_val.shape
    elif hasattr(first_val, '__len__'):
        dim = (len(first_val),)
    else:
        dim = ()

    nvar2 = len(dim)

    # We only work on diagrams with 1 or 2 variables
    if nvar1 not in [1, 2] or nvar2 not in [1, 2]:
        result.update({'xpoints': None, 'y_oxidation': None, 'y_reduction': None, 'swapped': False})
        return result

    # Get variables from result
    vars_list = result['vars'].copy()

    # If needed, swap axes so redox variable is on y-axis
    # Also do this for 1-D diagrams
    if len(vars_list) == 1:
        vars_list.append('nothing')

    swapped = False
    if vars_list[1] in ['T', 'P', 'nothing']:
        vars_list = list(reversed(vars_list))
        vals_dict = {vars_list[0]: result['vals'][vars_list[0]]} if vars_list[0] != 'nothing' else {}
        if len(result['vars']) > 1:
            vals_dict[vars_list[1]] = result['vals'][vars_list[1]]
        swapped = True
    else:
        vals_dict = result['vals']

    xaxis = vars_list[0]
    yaxis = vars_list[1]
    xpoints = np.asarray(vals_dict[xaxis]) if xaxis in vals_dict else np.array([0])

    # Make xaxis "nothing" if it is not pH, T, or P
    # (so that horizontal water lines can be drawn for any non-redox variable on the x-axis)
    if xaxis not in ['pH', 'T', 'P']:
        xaxis = 'nothing'

    # T and P are constants unless they are plotted on one of the axes
    T = result['T']
    if vars_list[0] == 'T':
        T = envert(xpoints, 'K')
    P = result['P']
    if vars_list[0] == 'P':
        P = envert(xpoints, 'bar')

    # P may be "Psat"; it is passed through unchanged (subcrt handles it)

    # logaH2O is 0 unless given in result['basis']
    basis_df = result['basis']
    if 'H2O' in basis_df.index:
        logaH2O = float(basis_df.loc['H2O', 'logact'])
    else:
        logaH2O = 0

    # pH is 7 unless given in eout['basis'] or plotted on one of the axes
    if vars_list[0] == 'pH':
        pH = xpoints
    elif 'H+' in basis_df.index:
        minuspH = basis_df.loc['H+', 'logact']
        # Special treatment for non-numeric value (happens when a buffer is used)
        try:
            pH = -float(minuspH)
        except (ValueError, TypeError):
            pH = np.nan
    else:
        pH = 7

    # O2 state is gas unless given in eout['basis']
    O2state = 'gas'
    if 'O2' in basis_df.index:
        O2state = basis_df.loc['O2', 'state']

    # H2 state is gas unless given in eout['basis']
    H2state = 'gas'
    if 'H2' in basis_df.index:
        H2state = basis_df.loc['H2', 'state']

    # Where the calculated values will go
    y_oxidation = None
    y_reduction = None

    if xaxis in ['pH', 'T', 'P', 'nothing'] and yaxis in ['Eh', 'pe', 'O2', 'H2']:
        # Eh/pe/logfO2/logaO2/logfH2/logaH2 vs pH/T/P

        # Reduction line (H2O + e- = 1/2 H2 + OH-)
        if 'reduction' in which:
            logfH2 = logaH2O  # usually 0

            if yaxis == 'H2':
                # Calculate equilibrium constant for gas-aqueous conversion if needed
                logK = subcrt(['H2', 'H2'], [-1, 1], ['gas', H2state], T=T, P=P, convert=False, messages=messages, show=False).out['logK']
                # This is logfH2 if H2state == "gas", or logaH2 if H2state == "aq"
                logfH2 = logfH2 + logK
                y_reduction = _recycle(logfH2, xpoints.shape)
            else:
                # Calculate logfO2 from H2O = 1/2 O2 + H2
                logK = subcrt(['H2O', 'O2', 'H2'], [-1, 0.5, 1], ['liq', O2state, 'gas'], T=T, P=P, convert=False, messages=messages, show=False).out['logK']
                # This is logfO2 if O2state == "gas", or logaO2 if O2state == "aq"
                logfO2 = 2 * (logK - logfH2 + logaH2O)

                if yaxis == 'O2':
                    y_reduction = _recycle(logfO2, xpoints.shape)
                elif yaxis == 'Eh':
                    y_reduction = convert(logfO2, 'E0', T=T, P=P, pH=pH, logaH2O=logaH2O, messages=messages)
                elif yaxis == 'pe':
                    Eh_val = convert(logfO2, 'E0', T=T, P=P, pH=pH, logaH2O=logaH2O, messages=messages)
                    y_reduction = convert(Eh_val, 'pe', T=T, messages=messages)

        # Oxidation line (H2O = 1/2 O2 + 2H+ + 2e-)
        if 'oxidation' in which:
            logfO2 = logaH2O  # usually 0

            if yaxis == 'H2':
                # Calculate logfH2 from H2O = 1/2 O2 + H2
                logK = subcrt(['H2O', 'O2', 'H2'], [-1, 0.5, 1], ['liq', 'gas', H2state], T=T, P=P, convert=False, messages=messages, show=False).out['logK']
                # This is logfH2 if H2state == "gas", or logaH2 if H2state == "aq"
                logfH2 = logK - 0.5*logfO2 + logaH2O
                y_oxidation = _recycle(logfH2, xpoints.shape)
            else:
                # Calculate equilibrium constant for gas-aqueous conversion if needed
                logK = subcrt(['O2', 'O2'], [-1, 1], ['gas', O2state], T=T, P=P, convert=False, messages=messages, show=False).out['logK']
                # This is logfO2 if O2state == "gas", or logaO2 if O2state == "aq"
                logfO2 = logfO2 + logK

                if yaxis == 'O2':
                    y_oxidation = _recycle(logfO2, xpoints.shape)
                elif yaxis == 'Eh':
                    y_oxidation = convert(logfO2, 'E0', T=T, P=P, pH=pH, logaH2O=logaH2O, messages=messages)
                elif yaxis == 'pe':
                    Eh_val = convert(logfO2, 'E0', T=T, P=P, pH=pH, logaH2O=logaH2O, messages=messages)
                    y_oxidation = convert(Eh_val, 'pe', T=T, messages=messages)

    else:
        # Invalid axis combination
        result.update({'xpoints': xpoints, 'y_oxidation': None, 'y_reduction': None, 'swapped': swapped})
        return result

    # Route to Plotly or matplotlib implementation
    if is_plotly:
        return _water_lines_plotly(result, xpoints, y_oxidation, y_reduction, swapped,
                                   lty, lwd, col, plot_it)

    # Matplotlib implementation
    # Only draw water lines if eout already has an axes (meaning it's from a diagram)
    # If no axes, this is being called just for calculation (e.g., from within diagram())
    if 'ax' not in eout or eout['ax'] is None:
        # No axes to plot on - just return the calculated values
        result.update({'xpoints': xpoints, 'y_oxidation': y_oxidation, 'y_reduction': y_reduction, 'swapped': swapped})
        return result

    # Use the axes from result
    ax = result['ax']

    # First, shade the water-unstable regions with gray
    # This creates the same effect as R's fill.NA for H2O.predominant
    if y_oxidation is not None and y_reduction is not None:
        from matplotlib.colors import ListedColormap
        import numpy.ma as ma

        # Shade over the current axis limits on a fine mesh
        xlim = ax.get_xlim()
        ylim = ax.get_ylim()
        n_points = 500
        x_mesh = np.linspace(xlim[0], xlim[1], n_points)
        y_mesh = np.linspace(ylim[0], ylim[1], n_points)
        X, Y = np.meshgrid(x_mesh, y_mesh)

        if swapped:
            # xpoints run along the plot's y-axis; the water lines give x-values.
            # Each mesh row is unstable left of the lower line or right of the upper one.
            y_ox_interp = np.interp(y_mesh, xpoints, y_oxidation)
            y_red_interp = np.interp(y_mesh, xpoints, y_reduction)
            lower = np.minimum(y_ox_interp, y_red_interp)[:, np.newaxis]
            upper = np.maximum(y_ox_interp, y_red_interp)[:, np.newaxis]
            unstable = (X < lower) | (X > upper)
        else:
            # Normal: xpoints on x-axis, y values on y-axis.
            # Each mesh column is unstable below the lower line or above the upper one.
            y_ox_interp = np.interp(x_mesh, xpoints, y_oxidation)
            y_red_interp = np.interp(x_mesh, xpoints, y_reduction)
            lower = np.minimum(y_ox_interp, y_red_interp)[np.newaxis, :]
            upper = np.maximum(y_ox_interp, y_red_interp)[np.newaxis, :]
            unstable = (Y < lower) | (Y > upper)

        # Mask out the stable cells so only unstable ones are painted
        unstable_mask = ma.masked_where(~unstable, np.ones_like(X))

        # Draw the shading with gray (matching R's gray80 = 0.8)
        fill_na_cmap = ListedColormap(['0.8'])
        extent = [xlim[0], xlim[1], ylim[0], ylim[1]]
        ax.imshow(unstable_mask, aspect='auto', origin='lower',
                  extent=extent, interpolation='nearest',
                  cmap=fill_na_cmap, vmin=0, vmax=1, zorder=1)

    # Set line color
    if col is None:
        col = 'black'

    # Convert numeric line style to matplotlib style
    lty_map = {1: '-', 2: '--', 3: '-.', 4: ':', 5: '-', 6: '--'}
    if isinstance(lty, int):
        lty = lty_map.get(lty, '--')

    if swapped:
        # NOTE(review): this condition follows CHNOSZ's R water.lines(); with
        # nvar2 == 2 only the first value of each line is drawn (as a vertical
        # line) -- confirm this is intended for 2-D swapped diagrams.
        if nvar1 == 1 or nvar2 == 2:
            # Add vertical lines on 1-D diagram
            for y_line in (y_oxidation, y_reduction):
                if y_line is None:
                    continue
                # Positional access works for arrays, Series and scalars alike
                flat = np.ravel(y_line)
                if flat.size > 0:
                    ax.axvline(x=flat[0], linestyle=lty, linewidth=lwd, color=col)
        else:
            # xpoints above is really the ypoints
            if y_oxidation is not None:
                ax.plot(y_oxidation, xpoints, linestyle=lty, linewidth=lwd, color=col)
            if y_reduction is not None:
                ax.plot(y_reduction, xpoints, linestyle=lty, linewidth=lwd, color=col)
    else:
        if y_oxidation is not None:
            ax.plot(xpoints, y_oxidation, linestyle=lty, linewidth=lwd, color=col)
        if y_reduction is not None:
            ax.plot(xpoints, y_reduction, linestyle=lty, linewidth=lwd, color=col)

    # Update the figure and axes references in result to reflect the water lines
    fig = ax.get_figure()
    result['fig'] = fig
    result['ax'] = ax

    # Display the figure if plot_it=True
    # This allows water_lines() to display a figure that was created with plot_it=False
    if plot_it and fig is not None:
        try:
            from IPython.display import display
            display(fig)
        except (ImportError, NameError):
            # Not in IPython/Jupyter, matplotlib will handle display
            pass

    # Update result with water line data and return
    result.update({'xpoints': xpoints, 'y_oxidation': y_oxidation, 'y_reduction': y_reduction, 'swapped': swapped})
    return result
2007
+
2008
+
2009
def _water_lines_plotly(eout: Dict[str, Any],
                        xpoints: np.ndarray,
                        y_oxidation: Optional[np.ndarray],
                        y_reduction: Optional[np.ndarray],
                        swapped: bool,
                        lty: Union[int, str],
                        lwd: float,
                        col: Optional[str],
                        plot_it: bool) -> Dict[str, Any]:
    """
    Add water stability lines to a Plotly interactive diagram.

    This helper function adds dashed lines showing water oxidation and reduction
    stability limits, plus gray shading for water-unstable regions, to an
    interactive Plotly diagram.

    Parameters
    ----------
    eout : dict
        Diagram result holding the Plotly figure under 'fig'; it is updated
        in place and returned.
    xpoints : np.ndarray
        Values of the non-redox variable (plotted on the y-axis when swapped).
    y_oxidation, y_reduction : array-like or None
        Values of the redox variable along each water line.
    swapped : bool
        True when the redox variable is on the x-axis.
    lty : int or str
        Numeric code or matplotlib-style string, mapped to a Plotly dash type
        (unknown values give 'dash').
    lwd : float
        Line width.
    col : str or None
        Line color; None means 'black'.
    plot_it : bool
        Display the figure after adding the traces.

    Returns
    -------
    dict
        ``eout`` updated with 'fig', 'xpoints', 'y_oxidation', 'y_reduction'
        and 'swapped'.
    """
    import plotly.graph_objects as go

    # Get the Plotly figure from eout
    fig = eout['fig']

    # Set line color (default to black)
    if col is None:
        col = 'black'

    # Convert numeric/matplotlib line styles to Plotly dash types
    lty_map = {
        1: 'solid', '-': 'solid',
        2: 'dash', '--': 'dash',
        3: 'dashdot', '-.': 'dashdot',
        4: 'dot', ':': 'dot',
        5: 'solid', 6: 'dash'
    }
    dash_style = lty_map.get(lty, 'dash')

    has_both = y_oxidation is not None and y_reduction is not None
    if has_both:
        y_ox = np.asarray(y_oxidation, dtype=float).ravel()
        y_red = np.asarray(y_reduction, dtype=float).ravel()

    # Fallback ranges when the figure has none set: xpoints span one axis and
    # the water-line values span the other (which one depends on `swapped`).
    along_fallback = [xpoints.min(), xpoints.max()]
    if has_both:
        both = np.concatenate([y_ox, y_red])
        across_fallback = [np.nanmin(both), np.nanmax(both)]
    else:
        across_fallback = [0, 1]
    if swapped:
        x_fallback, y_fallback = across_fallback, along_fallback
    else:
        x_fallback, y_fallback = along_fallback, across_fallback
    xlim = fig.layout.xaxis.range or x_fallback
    ylim = fig.layout.yaxis.range or y_fallback

    def _shade(x, y):
        # Closed, semi-transparent gray polygon that ignores hover
        return go.Scatter(
            x=x,
            y=y,
            fill='toself',
            fillcolor='rgba(128, 128, 128, 0.5)',  # Gray with transparency
            line=dict(width=0),
            showlegend=False,
            hoverinfo='skip'
        )

    # Add gray shading for water-unstable regions
    if has_both:
        n_points = 200
        # Mesh along the axis that carries xpoints
        if swapped:
            along = np.linspace(ylim[0], ylim[1], n_points)
        else:
            along = np.linspace(xlim[0], xlim[1], n_points)
        ox_interp = np.interp(along, xpoints, y_ox)
        red_interp = np.interp(along, xpoints, y_red)
        # Shade outside the band between the two lines whichever of them is
        # higher (as the matplotlib implementation does), so the result is
        # also correct when the oxidation line lies below the reduction line.
        upper = np.maximum(ox_interp, red_interp)
        lower = np.minimum(ox_interp, red_interp)
        there_and_back = np.concatenate([along, along[::-1]])

        if swapped:
            # Right of the upper line: up the line, then down the right edge
            fig.add_trace(_shade(np.concatenate([upper, np.full(n_points, xlim[1])]),
                                 there_and_back))
            # Left of the lower line: up the left edge, then down the line
            fig.add_trace(_shade(np.concatenate([np.full(n_points, xlim[0]), lower[::-1]]),
                                 there_and_back))
        else:
            # Above the upper line: along the line, then back along the top edge
            fig.add_trace(_shade(there_and_back,
                                 np.concatenate([upper, np.full(n_points, ylim[1])])))
            # Below the lower line: along the bottom edge, then back along the line
            fig.add_trace(_shade(there_and_back,
                                 np.concatenate([np.full(n_points, ylim[0]), lower[::-1]])))

    # Add water stability lines (when swapped, xpoints go on the y-axis)
    for y_line, label in ((y_oxidation, 'H₂O oxidation limit'),
                          (y_reduction, 'H₂O reduction limit')):
        if y_line is None:
            continue
        x_vals, y_vals = (y_line, xpoints) if swapped else (xpoints, y_line)
        fig.add_trace(go.Scatter(
            x=x_vals,
            y=y_vals,
            mode='lines',
            line=dict(color=col, width=lwd, dash=dash_style),
            name=label,
            showlegend=False,
            hoverinfo='skip'
        ))

    # Update the figure reference in eout to reflect the water lines
    eout['fig'] = fig

    # Display the figure if plot_it=True
    if plot_it:
        try:
            from IPython.display import display
            display(fig)
        except (ImportError, NameError):
            # Not in IPython/Jupyter, use fig.show()
            fig.show()

    # Update eout with water line data and return
    eout.update({'xpoints': xpoints, 'y_oxidation': y_oxidation, 'y_reduction': y_reduction, 'swapped': swapped})
    return eout
2191
+
2192
+
2193
def find_tp(predominant: np.ndarray) -> np.ndarray:
    """
    Find triple points in a predominance diagram.

    This function identifies the approximate positions of triple points
    (where three phases meet) in a 2-D predominance diagram by locating
    cells with the greatest number of different neighboring values.

    Parameters
    ----------
    predominant : np.ndarray
        Matrix of integers from diagram() output indicating which species
        predominates at each point. Should be a 2-D array (or nested list)
        where each value represents a different species/phase; values <= 0
        are not considered as candidate positions.

    Returns
    -------
    np.ndarray
        Integer array of shape (n, 2) where n is the number of triple points
        found (n == 0 when there are no positive values).
        Each row contains [row_index, col_index] of a triple point location.
        Indices are 1-based to match R behavior.

    Examples
    --------
    >>> from pychnosz import *
    >>> reset()
    >>> basis(["corundum", "quartz", "oxygen"])
    >>> species(["kyanite", "sillimanite", "andalusite"])
    >>> a = affinity(T=[200, 900, 99], P=[0, 9000, 101], exceed_Ttr=True)
    >>> d = diagram(a)
    >>> tp = find_tp(d['predominant'])
    >>> # Get T and P at the triple point
    >>> Ttp = a['vals']['T'][tp[0, 1] - 1]  # -1 for 0-based indexing
    >>> Ptp = a['vals']['P'][::-1][tp[0, 0] - 1]  # reversed and -1

    Notes
    -----
    This is a Python translation of the R function find.tp() from CHNOSZ.
    The R version returns 1-based indices, and this Python version does too
    for consistency. When using these indices to access Python arrays,
    remember to subtract 1.

    The function works by:
    1. Rearranging the matrix as done by diagram() for plotting
    2. For each position, examining a 3x3 neighborhood
    3. Counting the number of unique values in that neighborhood
    4. Returning positions with the maximum count (typically 3 or more)
    """
    predominant = np.asarray(predominant)

    # Rearrange the matrix in the same way that diagram() does for 2-D predominance diagrams
    # R code: x <- t(x[, ncol(x):1])  (reverse the columns, then transpose)
    x = predominant[:, ::-1].T

    # Candidate positions are those with valid (> 0) values, in row-major order
    valid_positions = np.argwhere(x > 0)
    if valid_positions.size == 0:
        # Keep the documented (n, 2) shape even when nothing is found
        return np.empty((0, 2), dtype=valid_positions.dtype)

    # Number of distinct values in the 3x3 neighborhood of each position;
    # slicing past the array end is clipped, so edges get smaller windows
    counts = np.array([
        np.unique(x[max(row - 1, 0):row + 2, max(col - 1, 0):col + 2]).size
        for row, col in valid_positions
    ])

    # Keep the positions with the maximum count, converted to 1-based [row, col]
    return valid_positions[counts == counts.max()] + 1
2281
+
2282
+
2283
def diagram_interactive(eout: Dict[str, Any],
                        type: str = "auto",
                        main: Optional[str] = None,
                        borders: Union[float, str] = 0,
                        names: Optional[List[str]] = None,
                        format_names: bool = True,
                        annotation: Optional[str] = None,
                        annotation_coords: Optional[List[float]] = None,
                        balance: Optional[Union[str, float, List[float]]] = None,
                        xlab: Optional[str] = None,
                        ylab: Optional[str] = None,
                        fill: Optional[Union[str, List[str]]] = "viridis",
                        width: int = 600,
                        height: int = 520,
                        alpha: Union[bool, str] = False,
                        add: bool = False,
                        ax: Optional[Any] = None,
                        col: Optional[Union[str, List[str]]] = None,
                        lty: Optional[Union[str, int, List]] = None,
                        lwd: Union[float, List[float]] = 1,
                        cex: Union[float, List[float]] = 1.0,
                        contour_method: Optional[Union[str, List[str]]] = "edge",
                        messages: bool = True,
                        plot_it: bool = True,
                        save_as: Optional[str] = None,
                        save_format: Optional[str] = None,
                        save_scale: float = 1) -> Tuple[pd.DataFrame, Any]:
    """
    Create an interactive diagram using Plotly.

    This function produces interactive versions of the diagrams created by diagram(),
    using Plotly for interactivity. It accepts output from affinity() or equilibrate()
    and creates either 1D line plots or 2D predominance diagrams.

    Parameters
    ----------
    eout : dict
        Output from affinity(), equilibrate(), or solubility().
    type : str, default "auto"
        "saturation" draws the lines where the affinity of each species is
        zero (requires 2-D output from affinity()); any other value draws
        activity lines (1-D) or a predominance diagram (2-D).
    main : str, optional
        Title of the plot.
    borders : float or str, default 0
        Controls boundary lines between regions in 2D predominance diagrams.
        - If numeric > 0: draws grid-aligned borders with specified thickness (pixels)
        - If "contour": draws smooth contour-based boundaries (like diagram())
        - If 0 or None: no borders drawn
    names : list of str, optional
        Names of species for activity lines or predominance fields. Used only
        when the list has one entry per species.
    format_names : bool, default True
        Apply formatting to chemical formulas?
    annotation : str, optional
        Annotation to add to the plot.
    annotation_coords : list of float, optional
        Coordinates of annotation, where 0,0 is bottom left and 1,1 is top
        right. Defaults to [0, 0].
    balance : str, numeric, or list of numeric, optional
        How to balance the transformations. A string names a column of the
        species table (e.g. a basis species); a number is used for every
        species; a list gives one coefficient per species. If omitted, the
        balance chosen by diagram() is used. For affinity() output the
        affinities are divided by these coefficients before plotting.
    xlab : str, optional
        Custom x-axis label.
    ylab : str, optional
        Custom y-axis label.
    fill : str or list of str, default "viridis"
        For 2D diagrams: colormap name (e.g., "viridis", "hot"), list of
        colors, or "none" for a white background.
        For 1D diagrams: list of line colors.
    width : int, default 600
        Width of the plot in pixels.
    height : int, default 520
        Height of the plot in pixels.
    alpha : bool or str, default False
        For speciation diagrams, plot degree of formation instead of activities?
        If truthy, plots mole fractions.
    add : bool, default False
        With type="saturation", add lines to the figure given in `ax`.
    ax : plotly Figure, optional
        With type="saturation" and add=True, the figure to add lines to.
    col, lty, lwd, cex, contour_method
        Line colors, line types, line widths, label size factors and label
        control for saturation lines; only used when type="saturation".
    messages : bool, default True
        Display messages?
    plot_it : bool, default True
        Show the plot?
    save_as : str, optional
        Provide a filename to save this figure. Filetype of saved figure is
        determined by save_format.
    save_format : str, optional
        Desired format of saved or downloaded figure. Can be 'png', 'jpg', 'jpeg',
        'webp', 'svg', 'pdf', 'eps', 'json', or 'html'. If 'html', an interactive
        plot will be saved. The download button uses PNG when the chosen format
        cannot be produced by the browser.
    save_scale : float, default 1
        Multiply title/legend/axis/canvas sizes by this factor when saving.

    Returns
    -------
    tuple
        (df, fig) where df is a pandas DataFrame with the data and fig is the
        Plotly figure object.

    Raises
    ------
    ImportError
        If plotly is not installed.
    ValueError
        If `eout` is not output from affinity(), equilibrate() or
        solubility(); if type="saturation" is used with non-2-D or
        equilibrate() output; or if the number of balance coefficients does
        not match the number of species.

    Examples
    --------
    1D diagram:
    >>> basis("CHNOS+")
    >>> species(info(["glycinium", "glycine", "glycinate"]))
    >>> a = affinity(pH=[0, 14])
    >>> e = equilibrate(a)
    >>> diagram_interactive(e, alpha=True)

    2D diagram:
    >>> basis(["Fe", "oxygen", "S2"])
    >>> species(["iron", "ferrous-oxide", "magnetite", "hematite", "pyrite", "pyrrhotite"])
    >>> a = affinity(S2=[-50, 0], O2=[-90, -10], T=200)
    >>> diagram_interactive(a, fill="hot")

    Notes
    -----
    This function requires plotly to be installed. Install with:
    pip install plotly

    The function adapts the pyCHNOSZ diagram_interactive() implementation
    to work with Python CHNOSZ's native data structures.
    """
    # Import plotly lazily so the rest of the package works without it
    try:
        import plotly.express as px
    except ImportError as err:
        raise ImportError("diagram_interactive() requires plotly. Install with: pip install plotly") from err

    if annotation_coords is None:
        annotation_coords = [0, 0]

    # Check that eout is valid
    efun = eout.get('fun', '')
    if efun not in ['affinity', 'equilibrate', 'solubility']:
        raise ValueError("'eout' is not the output from affinity(), equilibrate(), or solubility()")

    # "a" for output of affinity(), "e" for output of equilibrate()
    calc_type = "a" if ('loga_equil' not in eout and 'loga.equil' not in eout) else "e"

    # Basis species and their states (used for axis labels)
    basis_df = eout['basis']
    basis_sp = list(basis_df.index)
    basis_state = list(basis_df['state'])

    # Variable names and values, as a list in variable order
    xyvars = eout['vars']
    xyvals_dict = eout['vals']
    xyvals = [xyvals_dict[var] for var in xyvars]

    # Resolve the per-species balance coefficients once
    n_values = len(eout['values'])
    if balance is None or (isinstance(balance, str) and balance == ""):
        if type == "saturation":
            # Formula units (balance=1) match R and avoid failures when
            # minerals share no basis element
            balance = 1
            n_balance = [1] * n_values
        else:
            balance, n_balance = _interactive_default_balance(eout, calc_type == "a")
    else:
        n_balance = _interactive_n_balance(eout, balance, n_values)

    # Output values and their units
    if calc_type == "a":
        out_vals = eout['values']
        out_units = "A/(2.303RT)"
    else:
        loga_equil_key = 'loga_equil' if 'loga_equil' in eout else 'loga.equil'
        out_vals = eout[loga_equil_key]
        out_units = "log a"

    if isinstance(out_vals, dict):
        values_list = list(out_vals.values())
        species_indices = list(out_vals.keys())
    else:
        values_list = out_vals
        species_indices = eout['species']['ispecies'].tolist()

    # Species names (numpy integer types converted to Python ints for info())
    from .info import info as info_func
    sp_info = info_func([int(idx) for idx in species_indices], messages=False)
    sp_names = sp_info['name'].tolist()
    if isinstance(names, list) and len(names) == len(sp_names):
        sp_names = names

    # Number of dimensions of each species' values
    first_val = values_list[0]
    if hasattr(first_val, 'shape'):
        nd = len(first_val.shape)
    else:
        nd = 1 if hasattr(first_val, '__len__') else 0

    # type="saturation": draw contour lines where affinity = 0
    if type == "saturation":
        if nd != 2:
            raise ValueError("type='saturation' requires 2-D diagram")
        if calc_type != "a":
            raise ValueError("type='saturation' requires output from affinity(), not equilibrate()")
        return _plot_saturation_interactive(
            eout, values_list, sp_names, xyvars, xyvals,
            xlab, ylab, col, lwd, lty, cex, contour_method,
            main, add, ax, width, height, plot_it,
            save_as, save_format, save_scale, messages
        )

    # Build the DataFrame
    if nd == 2:
        xvals, yvals = xyvals[0], xyvals[1]
        xvar, yvar = xyvars[0], xyvars[1]

        # Each species' values have shape (nx, ny); transposing to (ny, nx)
        # before flattening makes x vary fastest:
        # [x0,y0], [x1,y0], ..., [xn,y0], [x0,y1], ...
        df = pd.DataFrame([v.T.flatten() for v in values_list], index=sp_names).T

        if calc_type == "a":
            df = _interactive_divide_by_balance(df, sp_names, n_balance)

        # Predominant species is the one with the largest (balanced) value
        df["pred"] = df.idxmax(axis=1, skipna=True)
        df["prednames"] = df["pred"]

        # Coordinates in the same (x fastest) order as the flattened data
        df[xvar] = list(xvals) * len(yvals)
        df[yvar] = [y for y in yvals for _ in range(len(xvals))]
    else:
        xvar = xyvars[0]
        xvals = xyvals[0]

        df = pd.DataFrame(list(values_list), index=sp_names).T

        if calc_type == "a":
            df = _interactive_divide_by_balance(df, sp_names, n_balance)

        # Degree of formation: convert log activities to mole fractions
        if alpha:
            df = df.apply(lambda x: 10**x)
            df = df[sp_names].div(df[sp_names].sum(axis=1), axis=0)

        df[xvar] = xvals

    # Axis label units; basis species get log activity or log fugacity
    unit_dict = {"P": "bar", "T": "°C", "pH": "", "Eh": "volts", "IS": "mol/kg"}
    for s, state in zip(basis_sp, basis_state):
        symbol = "a" if state in ["aq", "liq", "cr"] else "f"
        label = _format_html_species(s) if format_names else s
        unit_dict[s] = f"log <i>{symbol}</i><sub>{label}</sub>"

    if not isinstance(xlab, str):
        xlab = xvar + ", " + unit_dict.get(xvar, "")
        if xvar == "pH":
            xlab = "pH"
        if xvar in basis_sp:
            xlab = unit_dict[xvar]

    if nd == 1:
        # 1D plot: one line per species
        df_melted = pd.melt(df, id_vars=[xvar], value_vars=sp_names, var_name='variable', value_name='value')
        if format_names:
            df_melted['variable'] = df_melted['variable'].apply(_format_html_species)

        if not isinstance(ylab, str):
            ylab = "alpha" if alpha else out_units

        fig = px.line(df_melted, x=xvar, y="value", color='variable',
                      template="simple_white", width=width, height=height,
                      labels={'value': ylab, xvar: xlab},
                      render_mode='svg')

        # Custom line colors (extra colors beyond the number of lines are ignored)
        if isinstance(fill, list):
            for trace, color in zip(fig.data, fill):
                trace.line.color = color

        _interactive_warn_latex(xlab, 'xlab')
        _interactive_warn_latex(ylab, 'ylab')

        fig.update_layout(xaxis_title=xlab,
                          yaxis_title=ylab,
                          legend_title=None)

        if isinstance(main, str):
            fig.update_layout(title={'text': main, 'x': 0.5, 'xanchor': 'center'})

        _interactive_add_annotation(fig, annotation, annotation_coords)

        config = _interactive_attach_config(
            fig, ['resetScale2d', 'toggleSpikelines'],
            save_as, save_format, save_scale, width, height)
    else:
        # 2D plot: map species names to integer region codes
        mappings = {s: i for i, s in enumerate(sp_names)}
        df['pred'] = df['pred'].map(mappings).astype(int)

        # Reshape to (ny, nx): Plotly expects data[i, j] at x[j], y[i]
        shape = (len(yvals), len(xvals))
        dmap = np.array(df['pred']).reshape(shape)
        dmap_names = np.array(df['prednames']).reshape(shape)

        if not isinstance(ylab, str):
            ylab = yvar + ", " + unit_dict.get(yvar, "")
            if yvar in basis_sp:
                ylab = unit_dict[yvar]
            if yvar == "pH":
                ylab = "pH"

        _interactive_warn_latex(xlab, 'xlab')
        _interactive_warn_latex(ylab, 'ylab')

        fig = px.imshow(dmap, width=width, height=height, aspect="auto",
                        labels={'x': xlab, 'y': ylab, 'color': "region"},
                        x=xvals, y=yvals, template="simple_white")

        fig.update(data=[{'customdata': dmap_names,
                          'hovertemplate': xlab + ': %{x}<br>' + ylab + ': %{y}<br>Region: %{customdata}<extra></extra>'}])

        if fill == 'none':
            colormap = [[0, 'white'], [1, 'white']]
        elif isinstance(fill, list):
            n_fill = len(fill)
            colormap = [[i / (n_fill - 1) if n_fill > 1 else 0, v] for i, v in enumerate(fill)]
        else:
            colormap = fill

        fig.update_traces(dict(showscale=False,
                               coloraxis=None,
                               colorscale=colormap),
                          selector={'type': 'heatmap'})

        # px.imshow reverses the y axis by default; restore normal orientation
        fig.update_yaxes(autorange=True)

        if isinstance(main, str):
            fig.update_layout(title={'text': main, 'x': 0.5, 'xanchor': 'center'})

        # Label each predominance field at the mean position of its points
        predominant = set(df["prednames"])
        for s in sp_names:
            if s not in predominant:
                continue
            df_s = df.loc[df["prednames"] == s]
            annot_text = _format_html_species(s) if format_names else str(s)
            fig.add_annotation(x=df_s[xvar].mean(), y=df_s[yvar].mean(),
                               text=annot_text,
                               bgcolor="rgba(255, 255, 255, 0.5)",
                               showarrow=False)

        _interactive_add_annotation(fig, annotation, annotation_coords)

        if borders == "contour":
            _interactive_contour_borders(fig, df, dmap_names, xvals, yvals, messages)
        elif isinstance(borders, (int, float)) and borders > 0:
            unique_x_vals = sorted(set(df[xvar]))
            unique_y_vals = sorted(set(df[yvar]))
            # A single point or single line has no borders between regions
            if len(unique_x_vals) < 2 or len(unique_y_vals) < 2:
                if messages:
                    warnings.warn("Skipping border drawing: need at least 2 unique values in each dimension")
            else:
                _interactive_grid_borders(fig, dmap, unique_x_vals, unique_y_vals, borders)
                # Border lines reach the outer cell edges; keep the axes on the data range
                fig.update_yaxes(range=[min(yvals), max(yvals)], autorange=False, mirror=True)
                fig.update_xaxes(range=[min(xvals), max(xvals)], autorange=False, mirror=True)

        config = _interactive_attach_config(
            fig, ['zoom2d', 'pan2d', 'zoomIn2d', 'zoomOut2d',
                  'autoScale2d', 'resetScale2d', 'toggleSpikelines',
                  'hoverClosestCartesian', 'hoverCompareCartesian'],
            save_as, save_format, save_scale, width, height)

    if plot_it:
        fig.show(config=config)

    return df, fig


def _interactive_species_table(eout):
    """Return the species table stored in *eout*, falling back to the global species."""
    sp_df = eout.get('species')
    if sp_df is not None:
        return sp_df
    from .species import species as species_func
    return species_func()


def _interactive_default_balance(eout, need_species_lookup):
    """Return (balance, n_balance) chosen by diagram() when no balance was given."""
    # diagram() may create a matplotlib figure even with plot_it=False
    import matplotlib.pyplot as plt_temp
    temp_result = diagram(eout, messages=False, plot_it=False)
    balance = temp_result.get('balance', 1)
    n_balance = temp_result.get('n_balance', [1])
    if temp_result.get('fig') is not None:
        plt_temp.close(temp_result['fig'])
    # When balance names a column of the species table, take coefficients from it
    if need_species_lookup and isinstance(balance, str):
        sp_df = _interactive_species_table(eout)
        if balance in sp_df.columns:
            n_balance = list(sp_df[balance])
    return balance, n_balance


def _interactive_n_balance(eout, balance, n_values):
    """Return per-species balance coefficients for a user-supplied *balance*."""
    if isinstance(balance, str):
        sp_df = _interactive_species_table(eout)
        if balance in sp_df.columns:
            return list(sp_df[balance])
        try:
            return [float(balance)] * n_values
        except ValueError:
            return [1] * n_values
    if isinstance(balance, (list, tuple, np.ndarray)) and np.ndim(balance) > 0:
        return list(balance)
    try:
        return [float(balance)] * n_values
    except (TypeError, ValueError):
        return [1] * n_values


def _interactive_divide_by_balance(df, sp_names, n_balance):
    """Divide each species column of *df* by its balance coefficient (in place) and return *df*."""
    coeffs = list(np.atleast_1d(np.asarray(n_balance)))
    # A single coefficient applies to every species
    if len(coeffs) == 1 and len(sp_names) > 1:
        coeffs = coeffs * len(sp_names)
    if len(coeffs) != len(sp_names):
        raise ValueError(f"got {len(coeffs)} balance coefficients for {len(sp_names)} species")
    for name, coeff in zip(sp_names, coeffs):
        df[name] = df[name] / coeff
    return df


def _interactive_warn_latex(text, param):
    """Warn when *text*, the value of parameter *param*, contains LaTeX markup."""
    if text and _detect_latex_format(text):
        warnings.warn(
            f"LaTeX formatting detected in '{param}' parameter. "
            "Plotly requires HTML format (<sub>, <sup>) instead of LaTeX ($, _, ^). "
            "For activity ratios, use ratlab_html() instead of ratlab().",
            UserWarning
        )


def _interactive_add_annotation(fig, annotation, annotation_coords):
    """Add a user annotation at paper coordinates if *annotation* is a string."""
    if not isinstance(annotation, str):
        return
    _interactive_warn_latex(annotation, 'annotation')
    fig.add_annotation(
        x=annotation_coords[0],
        y=annotation_coords[1],
        text=annotation,
        showarrow=False,
        xref="paper",
        yref="paper",
        align='left',
        bgcolor="rgba(255, 255, 255, 0.5)")


def _interactive_attach_config(fig, buttons_to_remove, save_as, save_format, save_scale, width, height):
    """Save the figure if requested and store/return the Plotly display config."""
    save_as_name, save_format_final = _save_figure(fig, save_as, save_format, save_scale,
                                                   plot_width=width, plot_height=height, ppi=1)
    config = {'displaylogo': False,
              'modeBarButtonsToRemove': buttons_to_remove,
              'toImageButtonOptions': {
                  'format': save_format_final,
                  'filename': save_as_name,
                  'height': height,
                  'width': width,
                  'scale': save_scale,
              }}
    # Store config on figure so it persists when fig.show() is called later
    fig._config = fig._config | config
    return config


def _interactive_cell_edges(centers):
    """Return the len(centers) + 1 cell edges around sorted grid *centers* (at least 2)."""
    mids = [(a + b) / 2 for a, b in zip(centers[:-1], centers[1:])]
    first = centers[0] - (centers[1] - centers[0]) / 2
    last = centers[-1] + (centers[-1] - centers[-2]) / 2
    return [first] + mids + [last]


def _interactive_grid_borders(fig, dmap, x_centers, y_centers, line_width):
    """Draw grid-aligned lines between cells of *dmap* (shape (ny, nx)) that differ."""
    import plotly.graph_objects as go
    x_edges = _interactive_cell_edges(x_centers)
    y_edges = _interactive_cell_edges(y_centers)
    nrows, ncols = dmap.shape

    # Vertical segments between horizontally adjacent cells that differ
    x_vert, y_vert = [], []
    for i in range(nrows):
        for j in np.flatnonzero(dmap[i, :-1] != dmap[i, 1:]):
            x_vert += [x_edges[j + 1], x_edges[j + 1], np.nan]
            y_vert += [y_edges[i], y_edges[i + 1], np.nan]

    # Horizontal segments between vertically adjacent cells that differ
    x_horiz, y_horiz = [], []
    for j in range(ncols):
        for i in np.flatnonzero(dmap[:-1, j] != dmap[1:, j]):
            y_horiz += [y_edges[i + 1], y_edges[i + 1], np.nan]
            x_horiz += [x_edges[j], x_edges[j + 1], np.nan]

    # NaN entries break the line so one trace holds many separate segments
    for xs, ys in ((x_horiz, y_horiz), (x_vert, y_vert)):
        fig.add_trace(
            go.Scatter(
                mode='lines',
                x=xs,
                y=ys,
                line={'width': line_width, 'color': 'black'},
                hoverinfo='skip',
                showlegend=False))


def _interactive_contour_borders(fig, df, dmap_names, xvals, yvals, messages):
    """Draw smooth region boundaries using matplotlib contours of each species' mask."""
    import matplotlib.pyplot as plt
    import plotly.graph_objects as go

    X, Y = np.meshgrid(xvals, yvals)
    # Temporary, never-displayed figure used only to compute contour paths
    temp_fig, temp_ax = plt.subplots()
    try:
        for sp_name in sorted(df["prednames"].unique()):
            # 1 where this species predominates, 0 elsewhere; boundary at 0.5
            z = (dmap_names == sp_name).astype(float)
            try:
                cs = temp_ax.contour(X, Y, z, levels=[0.5])
                # cs.allsegs is [level][segment], each segment an (N, 2) array
                for level_segs in cs.allsegs:
                    for segment in level_segs:
                        fig.add_trace(
                            go.Scatter(
                                x=segment[:, 0],
                                y=segment[:, 1],
                                mode='lines',
                                line=dict(color='black', width=2),
                                hoverinfo='skip',
                                showlegend=False
                            )
                        )
            except Exception as e:
                # Best effort: skip species whose contour cannot be drawn
                if messages:
                    warnings.warn(f"Could not draw contour for {sp_name}: {e}")
            finally:
                temp_ax.clear()
    finally:
        plt.close(temp_fig)
2974
+
2975
+
2976
+ def _detect_latex_format(text: str) -> bool:
2977
+ """
2978
+ Detect if a string contains LaTeX formatting (incompatible with Plotly).
2979
+
2980
+ Parameters
2981
+ ----------
2982
+ text : str
2983
+ Text to check for LaTeX formatting
2984
+
2985
+ Returns
2986
+ -------
2987
+ bool
2988
+ True if LaTeX formatting is detected (e.g., $...$, _{...}, ^{...})
2989
+ """
2990
+ import re
2991
+ # Check for common LaTeX patterns:
2992
+ # - Text wrapped in $ $
2993
+ # - LaTeX subscripts _{...}
2994
+ # - LaTeX superscripts ^{...}
2995
+ latex_patterns = [
2996
+ r'\$[^$]+\$', # $...$
2997
+ r'_\{[^}]+\}', # _{...}
2998
+ r'\^\{[^}]+\}' # ^{...}
2999
+ ]
3000
+
3001
+ for pattern in latex_patterns:
3002
+ if re.search(pattern, text):
3003
+ return True
3004
+ return False
3005
+
3006
+
3007
+ def _format_html_species(formula: str) -> str:
3008
+ """
3009
+ Format a chemical formula for HTML rendering in Plotly.
3010
+
3011
+ Converts chemical formulas like "H2O" to "H<sub>2</sub>O" and
3012
+ "Ca+2" to "Ca<sup>2+</sup>".
3013
+
3014
+ Parameters
3015
+ ----------
3016
+ formula : str
3017
+ Chemical formula to format
3018
+
3019
+ Returns
3020
+ -------
3021
+ str
3022
+ HTML-formatted formula
3023
+ """
3024
+ import re
3025
+
3026
+ # Handle charge notation (e.g., +2, -1)
3027
+ # Match patterns like +2, -2, +, -
3028
+ charge_pattern = r'([+-])(\d*)'
3029
+
3030
+ def format_charge(match):
3031
+ sign = match.group(1)
3032
+ num = match.group(2)
3033
+ if num == '' or num == '1':
3034
+ return f"<sup>{sign}</sup>"
3035
+ else:
3036
+ return f"<sup>{num}{sign}</sup>"
3037
+
3038
+ # First handle charges at the end
3039
+ formula = re.sub(charge_pattern + r'$', format_charge, formula)
3040
+
3041
+ # Handle subscript numbers (digits that aren't part of the charge)
3042
+ # Match digits that come after letters and aren't preceded by < or >
3043
+ def format_subscript(match):
3044
+ return f"<sub>{match.group(0)}</sub>"
3045
+
3046
+ # Find all digits that should be subscripts
3047
+ # This matches digits that come after letters
3048
+ result = []
3049
+ i = 0
3050
+ while i < len(formula):
3051
+ if formula[i].isdigit() and i > 0 and formula[i-1].isalpha():
3052
+ # Start of a number sequence
3053
+ num_start = i
3054
+ while i < len(formula) and formula[i].isdigit():
3055
+ i += 1
3056
+ result.append(f"<sub>{formula[num_start:i]}</sub>")
3057
+ else:
3058
+ result.append(formula[i])
3059
+ i += 1
3060
+
3061
+ return ''.join(result)
3062
+
3063
+
3064
+ def _save_figure(fig, save_as, save_format, save_scale, plot_width, plot_height, ppi):
3065
+ """
3066
+ Save a Plotly figure to a file.
3067
+
3068
+ Parameters
3069
+ ----------
3070
+ fig : plotly figure
3071
+ The figure to save
3072
+ save_as : str or None
3073
+ Filename (without extension) to save as
3074
+ save_format : str or None
3075
+ Format to save ('png', 'jpg', 'jpeg', 'webp', 'svg', 'pdf', 'eps', 'json', 'html')
3076
+ save_scale : float
3077
+ Scale factor for saving
3078
+ plot_width : int
3079
+ Width of the plot
3080
+ plot_height : int
3081
+ Height of the plot
3082
+ ppi : int
3083
+ Pixels per inch
3084
+
3085
+ Returns
3086
+ -------
3087
+ tuple
3088
+ (save_as, save_format) - processed values for use in config
3089
+ """
3090
+ import plotly.io as pio
3091
+
3092
+ valid_formats = ['png', 'jpg', 'jpeg', 'webp', 'svg', 'pdf', 'eps', 'json', 'html']
3093
+
3094
+ if isinstance(save_format, str) and save_format not in valid_formats:
3095
+ raise ValueError(f"{save_format} is an unrecognized save format. "
3096
+ f"Supported formats include: {', '.join(valid_formats)}")
3097
+
3098
+ if isinstance(save_format, str) and save_as is not None:
3099
+ if not isinstance(save_as, str):
3100
+ save_as = "newplot"
3101
+
3102
+ if save_format == "html":
3103
+ fig.write_html(save_as + ".html")
3104
+ print(f"Saved figure as {save_as}.html")
3105
+ save_format = 'png'
3106
+ elif save_format in ['pdf', 'eps', 'json']:
3107
+ pio.write_image(fig, save_as + "." + save_format, format=save_format,
3108
+ scale=save_scale, width=plot_width * ppi, height=plot_height * ppi)
3109
+ print(f"Saved figure as {save_as}.{save_format}")
3110
+ save_format = "png"
3111
+ else:
3112
+ pio.write_image(fig, save_as + "." + save_format, format=save_format,
3113
+ scale=save_scale, width=plot_width * ppi, height=plot_height * ppi)
3114
+ print(f"Saved figure as {save_as}.{save_format}")
3115
+ else:
3116
+ save_format = "png"
3117
+
3118
+ return save_as, save_format
3119
+
3120
+
3121
+ def _plot_saturation_interactive(eout, values_list, sp_names, xyvars, xyvals,
3122
+ xlab, ylab, col, lwd, lty, cex, contour_method,
3123
+ main, add, ax, width, height, plot_it,
3124
+ save_as, save_format, save_scale, messages):
3125
+ """
3126
+ Plot saturation lines (affinity=0 contours) for interactive 2-D diagrams using Plotly.
3127
+
3128
+ This function draws contour lines where affinity = 0 for each species,
3129
+ indicating saturation boundaries (e.g., mineral precipitation thresholds).
3130
+ """
3131
+ import plotly.graph_objects as go
3132
+ import plotly.io as pio
3133
+
3134
+ # Get x and y values
3135
+ xvals = xyvals[0]
3136
+ yvals = xyvals[1]
3137
+ xvar = xyvars[0]
3138
+ yvar = xyvars[1]
3139
+
3140
+ n_species = len(sp_names)
3141
+
3142
+ if messages:
3143
+ print(f"diagram: plotting saturation lines for interactive 2-D diagram")
3144
+
3145
+ # Set up colors
3146
+ if col is None:
3147
+ # Use default Plotly colors
3148
+ default_colors = ['#636EFA', '#EF553B', '#00CC96', '#AB63FA', '#FFA15A',
3149
+ '#19D3F3', '#FF6692', '#B6E880', '#FF97FF', '#FECB52']
3150
+ col = [default_colors[i % len(default_colors)] for i in range(n_species)]
3151
+ elif isinstance(col, str):
3152
+ col = [col] * n_species
3153
+ else:
3154
+ col = list(col)
3155
+ if len(col) < n_species:
3156
+ col = col * (n_species // len(col) + 1)
3157
+ col = col[:n_species]
3158
+
3159
+ # Set up line widths
3160
+ if isinstance(lwd, (int, float)):
3161
+ lwd = [lwd] * n_species
3162
+ else:
3163
+ lwd = list(lwd)
3164
+ if len(lwd) < n_species:
3165
+ lwd = lwd * (n_species // len(lwd) + 1)
3166
+ lwd = lwd[:n_species]
3167
+
3168
+ # Handle line styles (lty)
3169
+ if lty is None:
3170
+ lty = ['solid'] * n_species
3171
+ elif isinstance(lty, (str, int)):
3172
+ lty = [lty] * n_species
3173
+ else:
3174
+ lty = list(lty)
3175
+ if len(lty) < n_species:
3176
+ lty = lty * (n_species // len(lty) + 1)
3177
+ lty = lty[:n_species]
3178
+
3179
+ # Convert numeric/matplotlib line styles to Plotly dash types
3180
+ lty_map = {
3181
+ 1: 'solid', '-': 'solid',
3182
+ 2: 'dash', '--': 'dash',
3183
+ 3: 'dashdot', '-.': 'dashdot',
3184
+ 4: 'dot', ':': 'dot',
3185
+ 5: 'solid', 6: 'dash'
3186
+ }
3187
+ lty = [lty_map.get(lt, 'solid') if (isinstance(lt, int) or lt in lty_map) else 'solid' for lt in lty]
3188
+
3189
+ # Handle contour label control
3190
+ if contour_method is None or contour_method == "" or (isinstance(contour_method, str) and contour_method.upper() == "NA"):
3191
+ show_labels = [False] * n_species
3192
+ elif isinstance(contour_method, str):
3193
+ show_labels = [True] * n_species
3194
+ elif isinstance(contour_method, list):
3195
+ if len(contour_method) != n_species:
3196
+ contour_method_extended = list(contour_method) * (n_species // len(contour_method) + 1)
3197
+ contour_method_extended = contour_method_extended[:n_species]
3198
+ else:
3199
+ contour_method_extended = contour_method
3200
+
3201
+ show_labels = []
3202
+ for method in contour_method_extended:
3203
+ if method is None or method == "" or (isinstance(method, str) and method.upper() == "NA"):
3204
+ show_labels.append(False)
3205
+ else:
3206
+ show_labels.append(True)
3207
+ else:
3208
+ show_labels = [True] * n_species
3209
+
3210
+ # Handle text size (cex) for contour labels
3211
+ if isinstance(cex, (int, float)):
3212
+ cex_list = [cex] * n_species
3213
+ else:
3214
+ cex_list = list(cex)
3215
+ if len(cex_list) < n_species:
3216
+ cex_list = cex_list * (n_species // len(cex_list) + 1)
3217
+ cex_list = cex_list[:n_species]
3218
+
3219
+ # Base font size for labels (Plotly default ~12)
3220
+ base_font_size = 12
3221
+ font_sizes = [base_font_size * c for c in cex_list]
3222
+
3223
+ # Create or get figure
3224
+ if add and ax is not None:
3225
+ # ax is actually the Plotly figure from previous call
3226
+ fig = ax
3227
+ else:
3228
+ # Create new figure
3229
+ fig = go.Figure()
3230
+
3231
+ # Add contour lines for each species
3232
+ for i, sp_name in enumerate(sp_names):
3233
+ # Get affinity values for this species
3234
+ # values_list[i] has shape (nx, ny)
3235
+ affinity_2d = values_list[i]
3236
+
3237
+ # Create contour trace for affinity=0 only
3238
+ # Use ncontours=1 with start=0 and end=0 to force a single contour at zero
3239
+ # Note: Plotly doesn't support custom text on contour lines, so we rely on the legend
3240
+ contour = go.Contour(
3241
+ x=xvals,
3242
+ y=yvals,
3243
+ z=affinity_2d.T, # Transpose to match Plotly's expected orientation
3244
+ ncontours=1, # Only generate one contour level
3245
+ contours=dict(
3246
+ coloring='lines', # Draw only contour lines, not filled regions
3247
+ start=0, # Start at 0
3248
+ end=0, # End at 0
3249
+ showlabels=False # Don't show "0" labels on contour lines
3250
+ ),
3251
+ line=dict(
3252
+ color=col[i],
3253
+ width=lwd[i],
3254
+ dash=lty[i]
3255
+ ),
3256
+ colorscale=[[0, col[i]], [1, col[i]]], # Force uniform color
3257
+ showscale=False,
3258
+ hoverinfo='skip',
3259
+ name=sp_name,
3260
+ legendgroup=sp_name,
3261
+ showlegend=True
3262
+ )
3263
+
3264
+ fig.add_trace(contour)
3265
+
3266
+ # Set axis labels if not adding to existing plot
3267
+ if not add:
3268
+ # Create axis labels with proper units
3269
+ from .thermo import thermo as thermo_func
3270
+ basis_df = eout['basis']
3271
+ basis_sp = list(basis_df.index)
3272
+ basis_state = list(basis_df['state'])
3273
+
3274
+ unit_dict = {"P": "bar", "T": "°C", "pH": "", "Eh": "volts", "IS": "mol/kg"}
3275
+
3276
+ for i, s in enumerate(basis_sp):
3277
+ if basis_state[i] in ["aq", "liq", "cr"]:
3278
+ unit_dict[s] = f"log <i>a</i><sub>{s}</sub>"
3279
+ else:
3280
+ unit_dict[s] = f"log <i>f</i><sub>{s}</sub>"
3281
+
3282
+ if not isinstance(xlab, str):
3283
+ xlab = xvar + ", " + unit_dict.get(xvar, "")
3284
+ if xvar == "pH":
3285
+ xlab = "pH"
3286
+ if xvar in basis_sp:
3287
+ xlab = unit_dict[xvar]
3288
+
3289
+ if not isinstance(ylab, str):
3290
+ ylab = yvar + ", " + unit_dict.get(yvar, "")
3291
+ if yvar in basis_sp:
3292
+ ylab = unit_dict[yvar]
3293
+ if yvar == "pH":
3294
+ ylab = "pH"
3295
+
3296
+ fig.update_xaxes(title_text=xlab)
3297
+ fig.update_yaxes(title_text=ylab)
3298
+
3299
+ fig.update_layout(
3300
+ template="simple_white",
3301
+ width=width,
3302
+ height=height,
3303
+ showlegend=True
3304
+ )
3305
+
3306
+ if isinstance(main, str):
3307
+ fig.update_layout(title={'text': main, 'x': 0.5, 'xanchor': 'center'})
3308
+
3309
+ # Configure download button
3310
+ save_as_name, save_format_final = _save_figure(fig, save_as, save_format, save_scale,
3311
+ plot_width=width, plot_height=height, ppi=1)
3312
+
3313
+ config = {'displaylogo': False,
3314
+ 'modeBarButtonsToRemove': ['resetScale2d', 'toggleSpikelines'],
3315
+ 'toImageButtonOptions': {
3316
+ 'format': save_format_final,
3317
+ 'filename': save_as_name,
3318
+ 'height': height,
3319
+ 'width': width,
3320
+ 'scale': save_scale,
3321
+ }}
3322
+
3323
+ fig._config = fig._config | config if hasattr(fig, '_config') else config
3324
+
3325
+ # Show plot if requested
3326
+ if plot_it:
3327
+ fig.show(config=config)
3328
+
3329
+ # Return empty DataFrame (saturation doesn't produce tabular data like predominance diagrams)
3330
+ df = pd.DataFrame()
3331
+
3332
+ return df, fig
3333
+
3334
+
3335
+ # Export main functions
3336
+ __all__ = ['diagram', 'diagram_interactive', 'water_lines', 'find_tp']