modelbase2 0.1.78__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- modelbase2/__init__.py +138 -26
- modelbase2/distributions.py +306 -0
- modelbase2/experimental/__init__.py +17 -0
- modelbase2/experimental/codegen.py +239 -0
- modelbase2/experimental/diff.py +227 -0
- modelbase2/experimental/notes.md +4 -0
- modelbase2/experimental/tex.py +521 -0
- modelbase2/fit.py +284 -0
- modelbase2/fns.py +185 -0
- modelbase2/integrators/__init__.py +19 -0
- modelbase2/integrators/int_assimulo.py +146 -0
- modelbase2/integrators/int_scipy.py +147 -0
- modelbase2/label_map.py +610 -0
- modelbase2/linear_label_map.py +301 -0
- modelbase2/mc.py +548 -0
- modelbase2/mca.py +280 -0
- modelbase2/model.py +1621 -0
- modelbase2/npe.py +343 -0
- modelbase2/parallel.py +171 -0
- modelbase2/parameterise.py +28 -0
- modelbase2/paths.py +36 -0
- modelbase2/plot.py +829 -0
- modelbase2/sbml/__init__.py +14 -0
- modelbase2/sbml/_data.py +77 -0
- modelbase2/sbml/_export.py +656 -0
- modelbase2/sbml/_import.py +585 -0
- modelbase2/sbml/_mathml.py +691 -0
- modelbase2/sbml/_name_conversion.py +52 -0
- modelbase2/sbml/_unit_conversion.py +74 -0
- modelbase2/scan.py +616 -0
- modelbase2/scope.py +96 -0
- modelbase2/simulator.py +635 -0
- modelbase2/surrogates/__init__.py +32 -0
- modelbase2/surrogates/_poly.py +66 -0
- modelbase2/surrogates/_torch.py +249 -0
- modelbase2/surrogates.py +316 -0
- modelbase2/types.py +352 -11
- modelbase2-0.2.0.dist-info/METADATA +81 -0
- modelbase2-0.2.0.dist-info/RECORD +42 -0
- {modelbase2-0.1.78.dist-info → modelbase2-0.2.0.dist-info}/WHEEL +1 -1
- modelbase2/core/__init__.py +0 -29
- modelbase2/core/algebraic_module_container.py +0 -130
- modelbase2/core/constant_container.py +0 -113
- modelbase2/core/data.py +0 -109
- modelbase2/core/name_container.py +0 -29
- modelbase2/core/reaction_container.py +0 -115
- modelbase2/core/utils.py +0 -28
- modelbase2/core/variable_container.py +0 -24
- modelbase2/ode/__init__.py +0 -13
- modelbase2/ode/integrator.py +0 -80
- modelbase2/ode/mca.py +0 -270
- modelbase2/ode/model.py +0 -470
- modelbase2/ode/simulator.py +0 -153
- modelbase2/utils/__init__.py +0 -0
- modelbase2/utils/plotting.py +0 -372
- modelbase2-0.1.78.dist-info/METADATA +0 -44
- modelbase2-0.1.78.dist-info/RECORD +0 -22
- {modelbase2-0.1.78.dist-info → modelbase2-0.2.0.dist-info/licenses}/LICENSE +0 -0
modelbase2/plot.py
ADDED
@@ -0,0 +1,829 @@
|
|
1
|
+
"""Plotting Utilities Module.
|
2
|
+
|
3
|
+
This module provides functions for creating various plots and visualizations
for metabolic models. It includes functionality for plotting heatmaps, time
courses, parameter scans, violin plots, phase-space trajectories and label
distributions.
|
6
|
+
|
7
|
+
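Functions:
    heatmap: Plot a heatmap of the given data.
    heatmap_from_2d_idx: Plot one variable of a 2D-indexed dataframe as a heatmap.
    heatmaps_from_2d_idx: Plot every variable of a 2D-indexed dataframe as heatmaps.
    lines: Plot multiple lines on the same axis.
    lines_grouped: Plot groups of lines on separate axes.
    line_autogrouped: Plot lines grouped by order of magnitude.
    line_mean_std: Plot a mean line with a standard-deviation band.
    lines_mean_std_from_2d_idx: Plot mean/std lines of a 2D-indexed dataframe.
    bars: Plot a bar chart of the given data.
    violins: Plot violins on the same axis.
    violins_from_2d_idx: Plot violins of a 2D-indexed dataframe on separate axes.
    shade_protocol: Shade a piecewise-constant protocol on an axis.
    trajectories_2d: Plot the vector field of two model variables.
    relative_label_distribution: Plot the relative distribution of labels.
    add_grid, rotate_xlabels, two_axes, grid_layout: Layout helpers.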
|
17
|
+
"""
|
18
|
+
|
19
|
+
from __future__ import annotations
|
20
|
+
|
21
|
+
# Public API of this module: the names exported on star-import. Underscore-
# prefixed helpers and the imported modules are deliberately not listed.
__all__ = [
    "FigAx",
    "FigAxs",
    "add_grid",
    "bars",
    "grid_layout",
    "heatmap",
    "heatmap_from_2d_idx",
    "heatmaps_from_2d_idx",
    "line_autogrouped",
    "line_mean_std",
    "lines",
    "lines_grouped",
    "lines_mean_std_from_2d_idx",
    "relative_label_distribution",
    "rotate_xlabels",
    "shade_protocol",
    "trajectories_2d",
    "two_axes",
    "violins",
    "violins_from_2d_idx",
]
|
43
|
+
|
44
|
+
import itertools as it
|
45
|
+
import math
|
46
|
+
from typing import TYPE_CHECKING, Literal, cast
|
47
|
+
|
48
|
+
import numpy as np
|
49
|
+
import pandas as pd
|
50
|
+
import seaborn as sns
|
51
|
+
from matplotlib import pyplot as plt
|
52
|
+
from matplotlib.axes import Axes
|
53
|
+
from matplotlib.colors import (
|
54
|
+
LogNorm,
|
55
|
+
Normalize,
|
56
|
+
SymLogNorm,
|
57
|
+
colorConverter, # type: ignore
|
58
|
+
)
|
59
|
+
from matplotlib.figure import Figure
|
60
|
+
from mpl_toolkits.mplot3d import Axes3D
|
61
|
+
|
62
|
+
from modelbase2.label_map import LabelMapper
|
63
|
+
|
64
|
+
if TYPE_CHECKING:
|
65
|
+
from matplotlib.collections import QuadMesh
|
66
|
+
|
67
|
+
from modelbase2.linear_label_map import LinearLabelMapper
|
68
|
+
from modelbase2.model import Model
|
69
|
+
from modelbase2.types import Array, ArrayLike
|
70
|
+
|
71
|
+
# Return type of single-axis plotting functions: the figure and its axis.
type FigAx = tuple[Figure, Axes]
# Return type of multi-axis plotting functions: the figure and a flat list of axes.
type FigAxs = tuple[Figure, list[Axes]]
|
73
|
+
|
74
|
+
|
75
|
+
##########################################################################
|
76
|
+
# Helpers
|
77
|
+
##########################################################################
|
78
|
+
|
79
|
+
|
80
|
+
def _relative_luminance(color: Array) -> float:
    """Return the relative luminance of ``color`` (WCAG formula).

    The colour is converted to RGBA and its RGB channels are linearised with
    the sRGB transfer function (linear segment below 0.03928). The result is
    weighted with the coefficients 0.2126 / 0.7152 / 0.0722.

    Args:
        color: Any colour specification accepted by matplotlib's converter.

    Returns:
        Luminance of the first colour in ``color``.

    """
    channels = colorConverter.to_rgba_array(color)[:, :3]
    linear = np.where(
        channels <= 0.03928,  # noqa: PLR2004
        channels / 12.92,
        np.power((channels + 0.055) / 1.055, 2.4),
    )
    weights = np.array([0.2126, 0.7152, 0.0722])
    return (linear @ weights)[0]
|
93
|
+
|
94
|
+
|
95
|
+
def _get_norm(vmin: float, vmax: float) -> Normalize:
    """Choose a colour normalisation for data spanning ``[vmin, vmax]``.

    * Both bounds strictly inside (-1000, 1000): linear ``Normalize``.
    * Otherwise, if ``vmin <= 0``: ``SymLogNorm`` (base 10, linear threshold 1).
    * Otherwise (strictly positive data with ``vmax >= 1000``): ``LogNorm``.

    Args:
        vmin: Minimum value of the data.
        vmax: Maximum value of the data.

    Returns:
        A normalisation object for the given range.

    """
    if -1000 < vmin and vmax < 1000:  # noqa: PLR2004
        return Normalize(vmin=vmin, vmax=vmax)
    if vmin <= 0:
        return SymLogNorm(linthresh=1, vmin=vmin, vmax=vmax, base=10)
    return LogNorm(vmin=vmin, vmax=vmax)
|
117
|
+
|
118
|
+
|
119
|
+
def _norm_with_zero_center(df: pd.DataFrame) -> Normalize:
    """Return a normalisation symmetric around zero for the values of ``df``.

    The range is ``[-m, m]``, where ``m`` is the largest absolute value found
    in ``df``. Selection of the norm type is delegated to ``_get_norm``.
    """
    lowest = df.min().min()
    highest = df.max().max()
    bound = max(abs(lowest), abs(highest))
    return _get_norm(vmin=-bound, vmax=bound)
|
123
|
+
|
124
|
+
|
125
|
+
def _partition_by_order_of_magnitude(s: pd.Series) -> list[list[str]]:
    """Partition the labels of a series into groups of equal order of magnitude.

    The order of magnitude of a value ``x`` is ``floor(log10(|x|))``. Taking
    the absolute value puts negative entries in the same group as positive
    entries of the same size. Without it they would silently disappear:
    ``log10`` of a negative number is NaN, and ``groupby`` drops NaN keys.
    Zeros map to ``-inf`` and form a group of their own. NaN entries are
    skipped.

    Args:
        s: Series whose index holds the labels to group.

    Returns:
        Lists of index labels, ordered from the smallest to the largest order
        of magnitude. Within a group, labels keep their order in ``s``.

    """
    # log10(0) triggers a divide-by-zero RuntimeWarning; -inf is the intended key.
    with np.errstate(divide="ignore"):
        magnitude = np.floor(np.log10(np.abs(s.to_numpy(dtype=float))))
    return [group.index.to_list() for _, group in s.groupby(magnitude, sort=True)]
|
131
|
+
|
132
|
+
|
133
|
+
def _split_large_groups[T](groups: list[list[T]], max_size: int) -> list[list[T]]:
|
134
|
+
"""Split groups larger than the given size into smaller groups."""
|
135
|
+
return list(
|
136
|
+
it.chain(
|
137
|
+
*(
|
138
|
+
(
|
139
|
+
[group]
|
140
|
+
if len(group) < max_size
|
141
|
+
else [ # type: ignore
|
142
|
+
list(i)
|
143
|
+
for i in np.array_split(group, math.ceil(len(group) / max_size)) # type: ignore
|
144
|
+
]
|
145
|
+
)
|
146
|
+
for group in groups
|
147
|
+
)
|
148
|
+
)
|
149
|
+
) # type: ignore
|
150
|
+
|
151
|
+
|
152
|
+
def _default_color(ax: Axes, color: str | None) -> str:
    """Return ``color``, falling back to a ``"C<n>"`` colour spec.

    The fallback index ``n`` is the number of lines already drawn on ``ax``.
    """
    if color is not None:
        return color
    n_existing = len(ax.lines)
    return f"C{n_existing}"
|
155
|
+
|
156
|
+
|
157
|
+
def _default_labels(
    ax: Axes,
    xlabel: str | None = None,
    ylabel: str | None = None,
    zlabel: str | None = None,
) -> None:
    """Label the axes of ``ax``, substituting a placeholder for missing labels.

    Any label passed as ``None`` is replaced by the text
    ``"Add a label / unit"``. The z label is only set when ``ax`` is a 3D axis.

    Args:
        ax: matplotlib Axes
        xlabel: Label for the x-axis.
        ylabel: Label for the y-axis.
        zlabel: Label for the z-axis (3D axes only).

    """
    placeholder = "Add a label / unit"

    def _or_placeholder(label: str | None) -> str:
        return placeholder if label is None else label

    ax.set_xlabel(_or_placeholder(xlabel))
    ax.set_ylabel(_or_placeholder(ylabel))
    if isinstance(ax, Axes3D):
        ax.set_zlabel(_or_placeholder(zlabel))
|
176
|
+
|
177
|
+
|
178
|
+
def _annotate_colormap(
    df: pd.DataFrame,
    ax: Axes,
    sci_annotation_bounds: tuple[float, float],
    annotation_style: str,
    hm: QuadMesh,
) -> None:
    """Write the value of every heatmap cell into the cell.

    Values whose magnitude lies in ``(lower, upper]`` of
    ``sci_annotation_bounds`` are formatted with ``annotation_style``. All
    others use scientific notation without decimals (``.0e``). The text is
    black on light cells and white on dark cells, based on the cell's
    relative luminance (threshold 0.45).

    Args:
        df: Dataframe that was drawn (defines the number of rows/columns).
        ax: Axes to write the annotations on.
        sci_annotation_bounds: Magnitude range that uses ``annotation_style``.
        annotation_style: Format spec (without the leading ``.``) for values
            inside that range.
        hm: QuadMesh object of the heatmap.

    """
    # Forces the facecolors to be computed so get_facecolor returns an array.
    hm.update_scalarmappable()
    lower = sci_annotation_bounds[0]
    upper = sci_annotation_bounds[1]
    col_idx, row_idx = np.meshgrid(
        np.arange(len(df.columns)),
        np.arange(len(df.index)),
    )
    cells = zip(
        col_idx.flat,
        row_idx.flat,
        hm.get_array().flat,  # type: ignore
        hm.get_facecolor(),
        strict=True,
    )
    for col, row, value, face in cells:
        if lower < abs(value) <= upper:
            text = f"{value:.{annotation_style}}"
        else:
            text = f"{value:.0e}"
        luminance = _relative_luminance(face)  # type: ignore
        # Cell (row, col) spans [col, col + 1] x [row, row + 1]; centre the text.
        ax.text(
            col + 0.5,
            row + 0.5,
            text,
            ha="center",
            va="center",
            color="black" if luminance > 0.45 else "white",  # noqa: PLR2004
        )
|
220
|
+
|
221
|
+
|
222
|
+
def add_grid(ax: Axes) -> Axes:
    """Switch on the grid of ``ax`` and draw it below the plotted artists.

    Args:
        ax: Axes to modify in place.

    Returns:
        The same axes, to allow call chaining.

    """
    enabled = True
    ax.grid(visible=enabled)
    ax.set_axisbelow(b=enabled)
    return ax
|
227
|
+
|
228
|
+
|
229
|
+
def rotate_xlabels(
    ax: Axes,
    rotation: float = 45,
    ha: Literal["left", "center", "right"] = "right",
) -> Axes:
    """Rotate and align the x tick labels of ``ax``.

    Only the tick labels that currently exist on the axis are modified.

    Args:
        ax: Axis whose x tick labels are rotated.
        rotation: Rotation angle in degrees (default: 45).
        ha: Horizontal alignment of the labels (default: "right").

    Returns:
        Axes object for object chaining

    """
    tick_labels = ax.get_xticklabels()
    for tick_label in tick_labels:
        tick_label.set_horizontalalignment(ha)
        tick_label.set_rotation(rotation)
    return ax
|
249
|
+
|
250
|
+
|
251
|
+
##########################################################################
|
252
|
+
# General plot layout
|
253
|
+
##########################################################################
|
254
|
+
|
255
|
+
|
256
|
+
def _default_fig_ax(
    *,
    ax: Axes | None,
    grid: bool,
    figsize: tuple[float, float] | None = None,
) -> FigAx:
    """Return ``ax`` with its figure, or a fresh single-axis figure.

    ``figsize`` is only used when a new figure is created.

    Args:
        ax: Existing axis to reuse, or ``None`` to create a new figure.
        grid: Whether to add a grid to the axis.
        figsize: Size of a newly created figure (default: None).

    Returns:
        Figure and Axes objects for the plot.

    """
    if ax is not None:
        fig = cast(Figure, ax.get_figure())
    else:
        fig, ax = plt.subplots(nrows=1, ncols=1, figsize=figsize)

    if grid:
        add_grid(ax)
    return fig, ax
|
281
|
+
|
282
|
+
|
283
|
+
def _default_fig_axs(
    axs: list[Axes] | None,
    *,
    ncols: int,
    nrows: int,
    figsize: tuple[float, float] | None,
    grid: bool,
    sharex: bool,
    sharey: bool,
) -> FigAxs:
    """Return ``axs`` with its figure, or a fresh multi-axis figure.

    If ``axs`` is ``None`` or empty, a new ``nrows`` x ``ncols`` figure with
    constrained layout is created. Its axes are returned as a flat list in
    row-major order. Otherwise the figure owning ``axs[0]`` is returned
    together with ``axs``, and the layout arguments are ignored.

    Args:
        axs: Existing axes to reuse, or ``None``/empty to create new ones.
        ncols: Number of columns for a new figure.
        nrows: Number of rows for a new figure.
        figsize: Size of a new figure (default: None).
        grid: Whether to add a grid to every axis.
        sharex: Whether new axes share the x-axis.
        sharey: Whether new axes share the y-axis.

    Returns:
        Figure and Axes objects for the plot.

    """
    if axs is not None and len(axs) > 0:
        fig = cast(Figure, axs[0].get_figure())
    else:
        fig, axs_array = plt.subplots(
            nrows=nrows,
            ncols=ncols,
            sharex=sharex,
            sharey=sharey,
            figsize=figsize,
            squeeze=False,
            layout="constrained",
        )
        axs = list(axs_array.ravel())

    if grid:
        for member in axs:
            add_grid(member)
    return fig, axs
|
326
|
+
|
327
|
+
|
328
|
+
def two_axes(
    *,
    figsize: tuple[float, float] | None = None,
    sharex: bool = True,
    sharey: bool = False,
    grid: bool = False,
) -> FigAxs:
    """Create a figure with two side-by-side axes (one row, two columns).

    Args:
        figsize: Size of the figure (default: None).
        sharex: Whether the axes share the x-axis.
        sharey: Whether the axes share the y-axis.
        grid: Whether to add a grid to both axes.

    Returns:
        The figure and a list of its two axes, left to right.

    """
    return _default_fig_axs(
        None,
        nrows=1,
        ncols=2,
        grid=grid,
        sharex=sharex,
        sharey=sharey,
        figsize=figsize,
    )
|
345
|
+
|
346
|
+
|
347
|
+
def grid_layout(
    n_groups: int,
    *,
    n_cols: int = 2,
    col_width: float = 3,
    row_height: float = 4,
    sharex: bool = True,
    sharey: bool = False,
    grid: bool = True,
) -> tuple[Figure, list[Axes]]:
    """Create a grid of axes with room for ``n_groups`` plots.

    The grid has at most ``n_cols`` columns (fewer if ``n_groups`` is smaller)
    and as many rows as needed. The figure is ``n_cols * col_width`` wide and
    ``n_rows * row_height`` tall. Surplus axes (when ``n_groups`` is not a
    multiple of the column count) are returned too; callers decide whether
    to hide them.

    At least one row and one column are always created. So ``n_groups == 0``
    or a non-positive ``n_cols`` yields a single axis instead of a
    ``ZeroDivisionError`` from ``ceil(0 / 0)``.

    Args:
        n_groups: Number of plots the grid must hold.
        n_cols: Maximum number of columns.
        col_width: Width of each column in figure-size units.
        row_height: Height of each row in figure-size units.
        sharex: Whether the axes share the x-axis.
        sharey: Whether the axes share the y-axis.
        grid: Whether to add a grid to every axis.

    Returns:
        The figure and a flat list of its ``n_rows * n_cols`` axes.

    """
    n_cols = max(1, min(n_groups, n_cols))
    n_rows = max(1, math.ceil(n_groups / n_cols))
    figsize = (n_cols * col_width, n_rows * row_height)

    return _default_fig_axs(
        None,
        ncols=n_cols,
        nrows=n_rows,
        figsize=figsize,
        sharex=sharex,
        sharey=sharey,
        grid=grid,
    )
|
371
|
+
|
372
|
+
|
373
|
+
##########################################################################
|
374
|
+
# Plots
|
375
|
+
##########################################################################
|
376
|
+
|
377
|
+
|
378
|
+
def bars(
    x: pd.DataFrame,
    *,
    ax: Axes | None = None,
    grid: bool = True,
) -> FigAx:
    """Plot ``x`` as a seaborn bar chart on a single axis.

    ``x`` is passed to ``seaborn.barplot`` as ``data``. A legend listing the
    columns of ``x`` is added.

    Args:
        x: Data to plot.
        ax: Axis to draw on; a new figure and axis are created if ``None``.
        grid: Whether to add a grid to the axis.

    Returns:
        Figure and axis containing the plot.

    """
    fig, target = _default_fig_ax(ax=ax, grid=grid)
    sns.barplot(data=x, ax=target)
    # NOTE(review): for wide-form data seaborn appears to put the column names
    # on the x-axis, so ``x.index.name`` may not be the right x label - confirm.
    _default_labels(target, xlabel=x.index.name, ylabel=None)
    target.legend(x.columns)
    return fig, target
|
390
|
+
|
391
|
+
|
392
|
+
def lines(
    x: pd.DataFrame | pd.Series,
    *,
    ax: Axes | None = None,
    grid: bool = True,
) -> FigAx:
    """Plot ``x`` against its index on a single axis.

    A DataFrame produces one line per column, labelled with the column names.
    A Series produces a single line, labelled with the series name if it has
    one. A Series has no ``columns`` attribute, so building the legend from
    ``x.columns`` unconditionally would raise ``AttributeError``.

    Args:
        x: Data to plot.
        ax: Axis to draw on; a new figure and axis are created if ``None``.
        grid: Whether to add a grid to the axis.

    Returns:
        Figure and axis containing the plot.

    """
    fig, ax = _default_fig_ax(ax=ax, grid=grid)
    ax.plot(x.index, x)
    _default_labels(ax, xlabel=x.index.name, ylabel=None)
    if isinstance(x, pd.DataFrame):
        ax.legend(x.columns)
    elif x.name is not None:
        ax.legend([x.name])
    return fig, ax
|
404
|
+
|
405
|
+
|
406
|
+
def lines_grouped(
    groups: list[pd.DataFrame] | list[pd.Series],
    *,
    n_cols: int = 2,
    col_width: float = 3,
    row_height: float = 4,
    sharex: bool = True,
    sharey: bool = False,
    grid: bool = True,
) -> FigAxs:
    """Plot each group of lines on its own axis of a grid.

    Every element of ``groups`` is drawn with ``lines`` on one axis. Grid
    cells left over after the last group are hidden.

    Args:
        groups: Data for each axis.
        n_cols: Maximum number of columns of the grid.
        col_width: Width of each column.
        row_height: Height of each row.
        sharex: Whether the axes share the x-axis.
        sharey: Whether the axes share the y-axis.
        grid: Whether to add a grid to the axes.

    Returns:
        The figure and all axes of the grid (including hidden ones).

    """
    fig, axs = grid_layout(
        len(groups),
        n_cols=n_cols,
        col_width=col_width,
        row_height=row_height,
        sharex=sharex,
        sharey=sharey,
        grid=grid,
    )

    for target, group in zip(axs, groups, strict=False):
        lines(group, ax=target, grid=grid)

    for unused in axs[len(groups) :]:
        unused.set_visible(False)

    return fig, axs
|
434
|
+
|
435
|
+
|
436
|
+
def line_autogrouped(
    s: pd.Series | pd.DataFrame,
    *,
    n_cols: int = 2,
    col_width: float = 4,
    row_height: float = 3,
    max_group_size: int = 6,
    grid: bool = True,
) -> FigAxs:
    """Plot ``s`` on several axes, grouping entries by order of magnitude.

    A Series is grouped by the magnitude of its values. A DataFrame is grouped
    column-wise by the magnitude of each column's maximum. Groups larger than
    ``max_group_size`` are split further, and each group gets its own axis.

    Args:
        s: Data to plot.
        n_cols: Maximum number of columns of the grid.
        col_width: Width of each column.
        row_height: Height of each row.
        max_group_size: Maximum number of entries per axis.
        grid: Whether to add a grid to the axes.

    Returns:
        The figure and all axes of the grid.

    """
    is_series = isinstance(s, pd.Series)
    magnitude_source = s if is_series else s.max()
    group_names = _split_large_groups(
        _partition_by_order_of_magnitude(magnitude_source),
        max_size=max_group_size,
    )

    groups: list[pd.Series] | list[pd.DataFrame] = [
        s.loc[names] if is_series else s.loc[:, names] for names in group_names
    ]

    return lines_grouped(
        groups,
        n_cols=n_cols,
        col_width=col_width,
        row_height=row_height,
        grid=grid,
    )
|
466
|
+
|
467
|
+
|
468
|
+
def line_mean_std(
    df: pd.DataFrame,
    *,
    label: str | None = None,
    ax: Axes | None = None,
    color: str | None = None,
    alpha: float = 0.2,
    grid: bool = True,
) -> FigAx:
    """Plot the row-wise mean of ``df`` as a line with a +/- 1 std band.

    Args:
        df: Data; mean and standard deviation are taken across columns.
        label: Legend label of the mean line.
        ax: Axis to draw on; a new figure and axis are created if ``None``.
        color: Colour of line and band; defaults to the next ``"C<n>"`` colour.
        alpha: Opacity of the band.
        grid: Whether to add a grid to the axis.

    Returns:
        Figure and axis containing the plot.

    """
    fig, ax = _default_fig_ax(ax=ax, grid=grid)
    # Resolve the colour before plotting: the default depends on len(ax.lines).
    color = _default_color(ax=ax, color=color)

    centre = df.mean(axis=1)
    spread = df.std(axis=1)
    lower = centre - spread
    upper = centre + spread

    ax.plot(centre.index, centre, color=color, label=label)
    ax.fill_between(df.index, lower, upper, color=color, alpha=alpha)
    _default_labels(ax, xlabel=df.index.name, ylabel=None)
    return fig, ax
|
498
|
+
|
499
|
+
|
500
|
+
def lines_mean_std_from_2d_idx(
    df: pd.DataFrame,
    *,
    names: list[str] | None = None,
    ax: Axes | None = None,
    alpha: float = 0.2,
    grid: bool = True,
) -> FigAx:
    """Plot mean and std of columns of a two-level-indexed dataframe.

    For each selected column, ``df[name].unstack().T`` has the second index
    level as rows and the first as columns. It is drawn with
    ``line_mean_std``, i.e. mean and std are taken over the first index level
    and plotted against the second.

    Args:
        df: Dataframe with a two-level index.
        names: Columns to plot (default: all columns).
        ax: Axis to draw on; a new figure and axis are created if ``None``.
        alpha: Opacity of the std bands.
        grid: Whether to add a grid to the axis.

    Returns:
        Figure and axis containing the plot.

    Raises:
        ValueError: If the index of ``df`` does not have exactly two levels.

    """
    # ``nlevels`` exists on every Index, so a plain (1-level) index now raises
    # the intended ValueError instead of an AttributeError on ``.levels``.
    if df.index.nlevels != 2:  # noqa: PLR2004
        msg = "MultiIndex must have exactly two levels"
        raise ValueError(msg)

    fig, ax = _default_fig_ax(ax=ax, grid=grid)

    for name in df.columns if names is None else names:
        line_mean_std(
            df[name].unstack().T,
            label=name,
            alpha=alpha,
            ax=ax,
        )
    ax.legend()
    return fig, ax
|
524
|
+
|
525
|
+
|
526
|
+
def heatmap(
    df: pd.DataFrame,
    *,
    annotate: bool = False,
    colorbar: bool = True,
    invert_yaxis: bool = True,
    cmap: str = "RdBu_r",
    norm: Normalize | None = None,
    ax: Axes | None = None,
    cax: Axes | None = None,
    sci_annotation_bounds: tuple[float, float] = (0.01, 100),
    annotation_style: str = "2g",
) -> tuple[Figure, Axes, QuadMesh]:
    """Draw ``df`` as a colour-coded grid of cells.

    Columns run along the x-axis and rows along the y-axis. Tick labels are
    the column and index labels. A newly created figure gets 0.5 size units
    per column/row, with a minimum of 4 per side.

    Args:
        df: Data to plot.
        annotate: Whether to write each value into its cell.
        colorbar: Whether to add a colorbar (with its outline removed).
        invert_yaxis: Whether to put the first row at the top.
        cmap: Name of the colormap.
        norm: Colour normalisation; defaults to one centred on zero.
        ax: Axis to draw on; a new figure and axis are created if ``None``.
        cax: Axis to draw the colorbar into.
        sci_annotation_bounds: Magnitude range annotated with ``annotation_style``.
        annotation_style: Format spec for annotations inside that range.

    Returns:
        The figure, the axis and the QuadMesh of the heatmap.

    """
    n_rows = len(df.index)
    n_cols = len(df.columns)
    fig, ax = _default_fig_ax(
        ax=ax,
        figsize=(max(4, 0.5 * n_cols), max(4, 0.5 * n_rows)),
        grid=False,
    )
    colour_norm = _norm_with_zero_center(df) if norm is None else norm

    hm = ax.pcolormesh(df, norm=colour_norm, cmap=cmap)
    # Place the tick labels at the cell centres.
    ax.set_xticks(np.arange(n_cols) + 0.5, labels=df.columns)
    ax.set_yticks(np.arange(n_rows) + 0.5, labels=df.index)

    if annotate:
        _annotate_colormap(df, ax, sci_annotation_bounds, annotation_style, hm)

    if colorbar:
        cb = fig.colorbar(hm, cax=cax, ax=ax)
        cb.outline.set_linewidth(0)  # type: ignore

    if invert_yaxis:
        ax.invert_yaxis()
    rotate_xlabels(ax, rotation=45, ha="right")
    return fig, ax, hm
|
573
|
+
|
574
|
+
|
575
|
+
def heatmap_from_2d_idx(
    df: pd.DataFrame,
    variable: str,
    ax: Axes | None = None,
) -> FigAx:
    """Plot one column of a two-level-indexed dataframe as a heatmap.

    ``df[variable]`` is unstacked so that the first index level runs along the
    x-axis and the second along the y-axis. Tick labels are formatted with two
    decimals. Labels that cannot be formatted that way (e.g. strings) fall
    back to ``str`` instead of raising.

    Args:
        df: Dataframe with a two-level index.
        variable: Column of ``df`` to plot.
        ax: Axis to draw on; a new figure and axis are created if ``None``.

    Returns:
        Figure and axis containing the heatmap.

    Raises:
        ValueError: If the index of ``df`` does not have exactly two levels.

    """
    # ``nlevels`` works for any Index, unlike ``MultiIndex.levels``.
    if df.index.nlevels != 2:  # noqa: PLR2004
        msg = "MultiIndex must have exactly two levels"
        raise ValueError(msg)

    def _tick_label(value: object) -> str:
        try:
            return f"{value:.2f}"
        except (TypeError, ValueError):
            return str(value)

    fig, ax = _default_fig_ax(ax=ax, grid=False)
    df2d = df[variable].unstack()

    ax.set_title(variable)
    # Note: pcolormesh puts rows on the y-axis, so transpose to get the first
    # index level (the rows of df2d) onto the x-axis.
    hm = ax.pcolormesh(df2d.T)
    ax.set_xlabel(df2d.index.name)
    ax.set_ylabel(df2d.columns.name)
    ax.set_xticks(
        np.arange(0, len(df2d.index), 1) + 0.5,
        labels=[_tick_label(i) for i in df2d.index],
    )
    ax.set_yticks(
        np.arange(0, len(df2d.columns), 1) + 0.5,
        labels=[_tick_label(i) for i in df2d.columns],
    )

    rotate_xlabels(ax, rotation=45, ha="right")

    # Add colorbar
    fig.colorbar(hm, ax=ax)
    return fig, ax
|
607
|
+
|
608
|
+
|
609
|
+
def heatmaps_from_2d_idx(
    df: pd.DataFrame,
    *,
    n_cols: int = 3,
    col_width_factor: float = 1,
    row_height_factor: float = 0.6,
    sharex: bool = True,
    sharey: bool = False,
) -> FigAxs:
    """Plot every column of a two-level-indexed dataframe as its own heatmap.

    Each panel is ``len(levels[0]) * col_width_factor`` wide and
    ``len(levels[1]) * row_height_factor`` tall. Grid cells left over after
    the last column are hidden, as in ``lines_grouped``.

    Args:
        df: Dataframe with a two-level MultiIndex.
        n_cols: Maximum number of panel columns. ``grid_layout`` already caps
            this at the number of variables; the previous cap used
            ``len(df)`` (the number of *rows*) by mistake.
        col_width_factor: Panel width per value of the first index level.
        row_height_factor: Panel height per value of the second index level.
        sharex: Whether the panels share the x-axis.
        sharey: Whether the panels share the y-axis.

    Returns:
        The figure and all axes of the grid (including hidden ones).

    Raises:
        ValueError: If the index of ``df`` does not have exactly two levels.

    """
    if df.index.nlevels != 2:  # noqa: PLR2004
        msg = "MultiIndex must have exactly two levels"
        raise ValueError(msg)
    idx = cast(pd.MultiIndex, df.index)

    fig, axs = grid_layout(
        n_groups=len(df.columns),
        n_cols=n_cols,
        col_width=len(idx.levels[0]) * col_width_factor,
        row_height=len(idx.levels[1]) * row_height_factor,
        sharex=sharex,
        sharey=sharey,
        grid=False,
    )
    for ax, var in zip(axs, df.columns, strict=False):
        heatmap_from_2d_idx(df, var, ax=ax)

    for unused in axs[len(df.columns) :]:
        unused.set_visible(False)
    return fig, axs
|
633
|
+
|
634
|
+
|
635
|
+
def violins(
    df: pd.DataFrame,
    *,
    ax: Axes | None = None,
    grid: bool = True,
) -> FigAx:
    """Draw a seaborn violin plot of ``df`` on a single axis.

    The x label is left empty; the y label gets the placeholder text.

    Args:
        df: Data passed to ``seaborn.violinplot``.
        ax: Axis to draw on; a new figure and axis are created if ``None``.
        grid: Whether to add a grid to the axis.

    Returns:
        Figure and axis containing the plot.

    """
    fig, target = _default_fig_ax(ax=ax, grid=grid)
    sns.violinplot(df, ax=target)
    _default_labels(ax=target, xlabel="", ylabel=None)
    return fig, target
|
646
|
+
|
647
|
+
|
648
|
+
def violins_from_2d_idx(
    df: pd.DataFrame,
    *,
    n_cols: int = 4,
    row_height: int = 2,
    sharex: bool = True,
    sharey: bool = False,
    grid: bool = True,
) -> FigAxs:
    """Plot violins for every column of a two-level-indexed dataframe.

    Each column gets its own panel, titled with the column name. The panel
    shows ``violins(df[col].unstack())``, so the second index level becomes
    the violin categories. Surplus grid cells keep their x tick labels but
    have no spines and no y ticks.

    Args:
        df: Dataframe with a two-level index.
        n_cols: Maximum number of panel columns.
        row_height: Height of each row.
        sharex: Whether the panels share the x-axis.
        sharey: Whether the panels share the y-axis.
        grid: Whether to add a grid to the panels. Forwarded to ``violins``,
            which otherwise adds a grid regardless of this flag.

    Returns:
        The figure and all axes of the grid.

    Raises:
        ValueError: If the index of ``df`` does not have exactly two levels.

    """
    if df.index.nlevels != 2:  # noqa: PLR2004
        msg = "MultiIndex must have exactly two levels"
        raise ValueError(msg)

    fig, axs = grid_layout(
        len(df.columns),
        n_cols=n_cols,
        row_height=row_height,
        sharex=sharex,
        sharey=sharey,
        grid=grid,
    )

    for ax, col in zip(axs[: len(df.columns)], df.columns, strict=True):
        ax.set_title(col)
        violins(df[col].unstack(), ax=ax, grid=grid)

    for ax in axs[len(df.columns) :]:
        for axis in ["top", "bottom", "left", "right"]:
            ax.spines[axis].set_linewidth(0)
        ax.yaxis.set_ticks([])

    for ax in axs:
        rotate_xlabels(ax)
    return fig, axs
|
683
|
+
|
684
|
+
|
685
|
+
def shade_protocol(
    protocol: pd.Series,
    *,
    ax: Axes,
    cmap_name: str = "Greys_r",
    vmin: float | None = None,
    vmax: float | None = None,
    alpha: float = 0.5,
    add_legend: bool = True,
) -> None:
    """Shade the intervals of a piecewise-constant protocol on ``ax``.

    ``protocol`` is indexed by the *end* time of each interval (``pd.Timedelta``).
    The first interval starts at 0 s. Each interval is drawn as a vertical span
    (in seconds on the x-axis) whose colour encodes its value.

    Args:
        protocol: Value of each interval, indexed by interval end time.
        ax: Axis to draw on.
        cmap_name: Name of the matplotlib colormap used for the shading.
        vmin: Value mapped to the low end of the colormap (default: protocol minimum).
        vmax: Value mapped to the high end of the colormap (default: protocol maximum).
        alpha: Opacity of the shading.
        add_legend: Whether to add a legend with one entry per distinct value.
            The legend is added as an extra artist, so an existing legend on
            ``ax`` is kept.

    """
    from matplotlib import colormaps
    from matplotlib.colors import Normalize
    from matplotlib.legend import Legend
    from matplotlib.patches import Patch

    cmap = colormaps[cmap_name]
    norm = Normalize(
        vmin=protocol.min() if vmin is None else vmin,
        vmax=protocol.max() if vmax is None else vmax,
    )

    t0 = pd.Timedelta(seconds=0)
    for t_end, val in protocol.items():
        t_end = cast(pd.Timedelta, t_end)
        ax.axvspan(
            t0.total_seconds(),
            t_end.total_seconds(),
            facecolor=cmap(norm(val)),
            edgecolor=None,
            alpha=alpha,
        )
        t0 = t_end  # type: ignore

    if add_legend:
        # A protocol can revisit the same value several times; list each
        # distinct value once, in order of first appearance.
        distinct = protocol.unique()
        ax.add_artist(
            Legend(
                ax,
                handles=[
                    Patch(
                        facecolor=cmap(norm(val)),
                        alpha=alpha,
                        label=val,
                    )  # type: ignore
                    for val in distinct
                ],
                labels=list(distinct),
                loc="lower right",
                bbox_to_anchor=(1.0, 0.0),
                title="protocol" if protocol.name is None else cast(str, protocol.name),
            )
        )
|
737
|
+
|
738
|
+
|
739
|
+
##########################################################################
|
740
|
+
# Plots that actually require a model :/
|
741
|
+
##########################################################################
|
742
|
+
|
743
|
+
|
744
|
+
def trajectories_2d(
    model: Model,
    x1: tuple[str, ArrayLike],
    x2: tuple[str, ArrayLike],
    y0: dict[str, float] | None = None,
    ax: Axes | None = None,
) -> FigAx:
    """Plot the vector field of two model variables as a quiver plot.

    For every pair ``(a, b)`` from the two value grids, the model's right-hand
    side is evaluated at ``y0`` with ``x1[0] = a`` and ``x2[0] = b``. The
    derivatives of the two variables are drawn as an arrow at ``(a, b)``.

    Examples:
        >>> trajectories_2d(
        ...     model,
        ...     ("S", np.linspace(0, 1, 10)),
        ...     ("P", np.linspace(0, 1, 10)),
        ... )

    Args:
        model: Model to use for the plot.
        x1: Tuple of the first variable name and its values (x-axis).
        x2: Tuple of the second variable name and its values (y-axis).
        y0: Values of all other variables (default: the model's initial conditions).
        ax: Axes to use for the plot.

    Returns:
        Figure and axis containing the quiver plot.

    """
    name1, values1 = x1
    name2, values2 = x2
    grid_shape = (len(values1), len(values2))
    du = np.zeros(grid_shape)
    dv = np.zeros(grid_shape)
    if y0 is None:
        y0 = model.get_initial_conditions()

    for i, first in enumerate(values1):
        for j, second in enumerate(values2):
            state = y0 | {name1: first, name2: second}
            rhs = model.get_right_hand_side(state)
            du[i, j] = rhs[name1]
            dv[i, j] = rhs[name2]

    fig, ax = _default_fig_ax(ax=ax, grid=False)
    # quiver expects U/V indexed as [y, x], hence the transpose.
    ax.quiver(values1, values2, du.T, dv.T)
    return fig, ax
|
784
|
+
|
785
|
+
|
786
|
+
##########################################################################
|
787
|
+
# Label Plots
|
788
|
+
##########################################################################
|
789
|
+
|
790
|
+
|
791
|
+
def relative_label_distribution(
    mapper: LabelMapper | LinearLabelMapper,
    concs: pd.DataFrame,
    *,
    subset: list[str] | None = None,
    n_cols: int = 2,
    col_width: float = 3,
    row_height: float = 3,
    sharey: bool = False,
    grid: bool = True,
) -> FigAxs:
    """Plot the relative distribution of labels for each labelled variable.

    One panel is drawn per variable. For a ``LabelMapper``, each line ``C<i+1>``
    is the summed concentration of the isotopomers labelled at position ``i``
    divided by the ``"<name>__total"`` column. For a ``LinearLabelMapper``,
    the isotopomer columns of each variable are plotted directly. Grid cells
    left over after the last variable are hidden, as in ``lines_grouped``.

    Args:
        mapper: Label mapper describing the labelled variables.
        concs: Concentrations, indexed by time, one column per isotopomer.
        subset: Variables to plot (default: all label variables of ``mapper``).
        n_cols: Maximum number of panel columns.
        col_width: Width of each column.
        row_height: Height of each row.
        sharey: Whether the panels share the y-axis.
        grid: Whether to add a grid to the panels.

    Returns:
        The figure and all axes of the grid (including hidden ones).

    """
    variables = list(mapper.label_variables) if subset is None else subset
    fig, axs = grid_layout(
        n_groups=len(variables),
        n_cols=n_cols,
        col_width=col_width,
        row_height=row_height,
        sharey=sharey,
        grid=grid,
    )
    if isinstance(mapper, LabelMapper):
        for ax, name in zip(axs, variables, strict=False):
            for i in range(mapper.label_variables[name]):
                isos = mapper.get_isotopomers_of_at_position(name, i)
                labels = cast(pd.DataFrame, concs.loc[:, isos])
                total = concs.loc[:, f"{name}__total"]
                ax.plot(labels.index, (labels.sum(axis=1) / total), label=f"C{i+1}")
            ax.set_title(name)
            ax.legend()
    else:
        for ax, (name, isos) in zip(
            axs, mapper.get_isotopomers(variables).items(), strict=False
        ):
            ax.plot(concs.index, concs.loc[:, isos])
            ax.set_title(name)
            ax.legend([f"C{i+1}" for i in range(len(isos))])

    for unused in axs[len(variables) :]:
        unused.set_visible(False)

    return fig, axs
|