mxlpy 0.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48)
  1. mxlpy/__init__.py +165 -0
  2. mxlpy/distributions.py +339 -0
  3. mxlpy/experimental/__init__.py +12 -0
  4. mxlpy/experimental/diff.py +226 -0
  5. mxlpy/fit.py +291 -0
  6. mxlpy/fns.py +191 -0
  7. mxlpy/integrators/__init__.py +19 -0
  8. mxlpy/integrators/int_assimulo.py +146 -0
  9. mxlpy/integrators/int_scipy.py +146 -0
  10. mxlpy/label_map.py +610 -0
  11. mxlpy/linear_label_map.py +303 -0
  12. mxlpy/mc.py +548 -0
  13. mxlpy/mca.py +280 -0
  14. mxlpy/meta/__init__.py +11 -0
  15. mxlpy/meta/codegen_latex.py +516 -0
  16. mxlpy/meta/codegen_modebase.py +110 -0
  17. mxlpy/meta/codegen_py.py +107 -0
  18. mxlpy/meta/source_tools.py +320 -0
  19. mxlpy/model.py +1737 -0
  20. mxlpy/nn/__init__.py +10 -0
  21. mxlpy/nn/_tensorflow.py +0 -0
  22. mxlpy/nn/_torch.py +129 -0
  23. mxlpy/npe.py +277 -0
  24. mxlpy/parallel.py +171 -0
  25. mxlpy/parameterise.py +27 -0
  26. mxlpy/paths.py +36 -0
  27. mxlpy/plot.py +875 -0
  28. mxlpy/py.typed +0 -0
  29. mxlpy/sbml/__init__.py +14 -0
  30. mxlpy/sbml/_data.py +77 -0
  31. mxlpy/sbml/_export.py +644 -0
  32. mxlpy/sbml/_import.py +599 -0
  33. mxlpy/sbml/_mathml.py +691 -0
  34. mxlpy/sbml/_name_conversion.py +52 -0
  35. mxlpy/sbml/_unit_conversion.py +74 -0
  36. mxlpy/scan.py +629 -0
  37. mxlpy/simulator.py +655 -0
  38. mxlpy/surrogates/__init__.py +31 -0
  39. mxlpy/surrogates/_poly.py +97 -0
  40. mxlpy/surrogates/_torch.py +196 -0
  41. mxlpy/symbolic/__init__.py +10 -0
  42. mxlpy/symbolic/strikepy.py +582 -0
  43. mxlpy/symbolic/symbolic_model.py +75 -0
  44. mxlpy/types.py +474 -0
  45. mxlpy-0.8.0.dist-info/METADATA +106 -0
  46. mxlpy-0.8.0.dist-info/RECORD +48 -0
  47. mxlpy-0.8.0.dist-info/WHEEL +4 -0
  48. mxlpy-0.8.0.dist-info/licenses/LICENSE +674 -0
mxlpy/plot.py ADDED
@@ -0,0 +1,875 @@
1
"""Plotting Utilities Module.

This module provides functions for creating plots and visualizations for
metabolic models, built on matplotlib and seaborn.

Layout helpers:
    one_axes: Create a figure with a single axes.
    two_axes: Create a figure with two axes side by side.
    grid_layout: Create a grid of axes for a given number of groups.
    add_grid: Add a grid to an axes.
    rotate_xlabels: Rotate the x tick labels of an axes.
    reset_prop_cycle: Reset the property cycle of an axes.
    show: Show a figure, or all figures.

Plots:
    bars: Bar plot of the columns of a dataframe.
    lines: Line plot of a series or dataframe.
    lines_grouped: Line plots of several groups on separate axes.
    line_autogrouped: Line plots grouped by order of magnitude.
    line_mean_std: Mean line with a standard deviation band.
    lines_mean_std_from_2d_idx: Mean / std lines of a two-level indexed dataframe.
    heatmap: Heatmap of a dataframe.
    heatmap_from_2d_idx: Heatmap of one column of a two-level indexed dataframe.
    heatmaps_from_2d_idx: Heatmaps of all columns of a two-level indexed dataframe.
    violins: Violin plot of the columns of a dataframe.
    violins_from_2d_idx: Violin plots of a two-level indexed dataframe.
    shade_protocol: Shade the steps of a protocol on an axes.
    trajectories_2d: Vector field of two model variables in phase space.
    relative_label_distribution: Relative label distribution per position.
"""
18
+
19
+ from __future__ import annotations
20
+
21
# Public API of this module, kept in ASCII-sorted order: the type aliases
# (upper case) come first, followed by the layout helpers and plots.
__all__ = [
    # Type aliases
    "FigAx", "FigAxs", "Linestyle",
    # Layout helpers and plots
    "add_grid", "bars", "grid_layout", "heatmap", "heatmap_from_2d_idx",
    "heatmaps_from_2d_idx", "line_autogrouped", "line_mean_std", "lines",
    "lines_grouped", "lines_mean_std_from_2d_idx", "one_axes",
    "relative_label_distribution", "reset_prop_cycle", "rotate_xlabels",
    "shade_protocol", "show", "trajectories_2d", "two_axes", "violins",
    "violins_from_2d_idx",
]
47
+
48
+ import itertools as it
49
+ import math
50
+ from typing import TYPE_CHECKING, Literal, cast
51
+
52
+ import numpy as np
53
+ import pandas as pd
54
+ import seaborn as sns
55
+ from matplotlib import pyplot as plt
56
+ from matplotlib.axes import Axes
57
+ from matplotlib.colors import (
58
+ LogNorm,
59
+ Normalize,
60
+ SymLogNorm,
61
+ colorConverter, # type: ignore
62
+ )
63
+ from matplotlib.figure import Figure
64
+ from mpl_toolkits.mplot3d import Axes3D
65
+
66
+ from mxlpy.label_map import LabelMapper
67
+
68
+ if TYPE_CHECKING:
69
+ from matplotlib.collections import QuadMesh
70
+
71
+ from mxlpy.linear_label_map import LinearLabelMapper
72
+ from mxlpy.model import Model
73
+ from mxlpy.types import Array, ArrayLike
74
+
75
+ type FigAx = tuple[Figure, Axes]
76
+ type FigAxs = tuple[Figure, list[Axes]]
77
+
78
+ type Linestyle = Literal[
79
+ "solid",
80
+ "dotted",
81
+ "dashed",
82
+ "dashdot",
83
+ ]
84
+
85
+ ##########################################################################
86
+ # Helpers
87
+ ##########################################################################
88
+
89
+
90
def _relative_luminance(color: Array) -> float:
    """Return the relative luminance of ``color`` (WCAG formula).

    The colour is converted to RGBA with matplotlib, each sRGB channel is
    linearised, and the channels are combined with the weights
    0.2126 / 0.7152 / 0.0722. If several colours are given, only the first
    one is used.
    """
    channels = colorConverter.to_rgba_array(color)[:, :3]

    # Linearise: c / 12.92 for small values, ((c + 0.055) / 1.055) ** 2.4 otherwise.
    linear = np.where(
        channels <= 0.03928,  # noqa: PLR2004
        channels / 12.92,
        ((channels + 0.055) / 1.055) ** 2.4,
    )

    weights = [0.2126, 0.7152, 0.0722]
    return (linear @ weights)[0]
103
+
104
+
105
def _get_norm(vmin: float, vmax: float) -> Normalize:
    """Get a suitable normalization object for the value range ``[vmin, vmax]``.

    The choice depends only on the two bounds:

    * linear (``Normalize``) if the whole range lies strictly inside
      ``(-1000, 1000)``;
    * otherwise symmetric-logarithmic (``SymLogNorm``, base 10, linear
      within ``linthresh=1``) if ``vmin <= 0``;
    * otherwise, i.e. a strictly positive range reaching 1000 or more,
      logarithmic (``LogNorm``).

    Args:
        vmin: Minimum value of the data.
        vmax: Maximum value of the data.

    Returns:
        Normalize: A normalization object for the given data.

    """
    if vmax < 1000 and vmin > -1000:  # noqa: PLR2004
        return Normalize(vmin=vmin, vmax=vmax)
    if vmin <= 0:
        # A plain LogNorm cannot represent zero or negative values.
        return SymLogNorm(linthresh=1, vmin=vmin, vmax=vmax, base=10)
    return LogNorm(vmin=vmin, vmax=vmax)
127
+
128
+
129
def _norm_with_zero_center(df: pd.DataFrame) -> Normalize:
    """Return a normalization for ``df`` that is symmetric around zero.

    The limits are ``-m`` and ``+m``, where ``m`` is the larger absolute value
    of the overall minimum and maximum of ``df``. The normalization class is
    chosen by ``_get_norm``.
    """
    lowest = df.min().min()
    highest = df.max().max()
    bound = max(abs(lowest), abs(highest))
    return _get_norm(vmin=-bound, vmax=bound)
133
+
134
+
135
+ def _partition_by_order_of_magnitude(s: pd.Series) -> list[list[str]]:
136
+ """Partition a series into groups based on the order of magnitude of the values."""
137
+ return [
138
+ i.to_list()
139
+ for i in np.floor(np.log10(s)).to_frame(name=0).groupby(0)[0].groups.values() # type: ignore
140
+ ]
141
+
142
+
143
+ def _split_large_groups[T](groups: list[list[T]], max_size: int) -> list[list[T]]:
144
+ """Split groups larger than the given size into smaller groups."""
145
+ return list(
146
+ it.chain(
147
+ *(
148
+ (
149
+ [group]
150
+ if len(group) < max_size
151
+ else [ # type: ignore
152
+ list(i)
153
+ for i in np.array_split(group, math.ceil(len(group) / max_size)) # type: ignore
154
+ ]
155
+ )
156
+ for group in groups
157
+ )
158
+ )
159
+ ) # type: ignore
160
+
161
+
162
+ def _default_color(ax: Axes, color: str | None) -> str:
163
+ """Get a default color for the given axis."""
164
+ return f"C{len(ax.lines)}" if color is None else color
165
+
166
+
167
def _default_labels(
    ax: Axes,
    xlabel: str | None = None,
    ylabel: str | None = None,
    zlabel: str | None = None,
) -> None:
    """Label the axes of ``ax``, using a reminder text for missing labels.

    Every label that is ``None`` is replaced by ``"Add a label / unit"``.
    The z label is only set if ``ax`` is a 3D axes.

    Args:
        ax: matplotlib Axes
        xlabel: Label for the x-axis.
        ylabel: Label for the y-axis.
        zlabel: Label for the z-axis (3D axes only).

    """
    placeholder = "Add a label / unit"

    def pick(label: str | None) -> str:
        return placeholder if label is None else label

    ax.set_xlabel(pick(xlabel))
    ax.set_ylabel(pick(ylabel))
    if isinstance(ax, Axes3D):
        ax.set_zlabel(pick(zlabel))
186
+
187
+
188
def _annotate_colormap(
    df: pd.DataFrame,
    ax: Axes,
    sci_annotation_bounds: tuple[float, float],
    annotation_style: str,
    hm: QuadMesh,
) -> None:
    """Write the value of every heatmap cell into the cell itself.

    Values whose magnitude lies in ``(lower, upper]`` of
    ``sci_annotation_bounds`` are formatted with ``annotation_style``. All
    other values are written in scientific notation without decimals. Text
    is black on cells with relative luminance above 0.45 and white otherwise.

    Args:
        df: Dataframe that was drawn; only its number of rows and columns is used.
        ax: Axes holding the heatmap.
        sci_annotation_bounds: ``(lower, upper)`` magnitude bounds outside of
            which scientific notation is used.
        annotation_style: Format spec (without the leading ``.``) for values
            inside the bounds, e.g. ``"2g"``.
        hm: QuadMesh of the heatmap.

    """
    lower = sci_annotation_bounds[0]
    upper = sci_annotation_bounds[1]
    hm.update_scalarmappable()  # So that get_facecolor is an array

    # Row-major cell order: (row 0, col 0), (row 0, col 1), ...
    cells = [
        (col, row)
        for row in range(len(df.index))
        for col in range(len(df.columns))
    ]
    for (col, row), value, face in zip(
        cells,
        hm.get_array().flat,  # type: ignore
        hm.get_facecolor(),
        strict=True,
    ):
        if lower < abs(value) <= upper:
            text = f"{value:.{annotation_style}}"
        else:
            text = f"{value:.0e}"
        luminance = _relative_luminance(face)  # type: ignore
        ax.text(
            col + 0.5,
            row + 0.5,
            text,
            ha="center",
            va="center",
            color="black" if luminance > 0.45 else "white",  # noqa: PLR2004
        )
230
+
231
+
232
def add_grid(ax: Axes) -> Axes:
    """Turn on the grid of ``ax`` and draw it below the plotted artists.

    Args:
        ax: Axes to modify in place.

    Returns:
        The same axes, to allow chaining.

    """
    ax.set_axisbelow(b=True)
    ax.grid(visible=True)
    return ax
237
+
238
+
239
def rotate_xlabels(
    ax: Axes,
    rotation: float = 45,
    ha: Literal["left", "center", "right"] = "right",
) -> Axes:
    """Rotate the x-axis tick labels of the given axis.

    Only the tick labels that currently exist on ``ax`` are modified.

    Args:
        ax: Axis to rotate the labels of.
        rotation: Rotation angle in degrees (default: 45).
        ha: Horizontal alignment of the labels (default: "right").

    Returns:
        Axes object for object chaining

    """
    for label in ax.get_xticklabels():
        label.set_rotation(rotation)
        label.set_horizontalalignment(ha)
    return ax
259
+
260
+
261
def show(fig: Figure | None = None) -> None:
    """Show ``fig`` with ``Figure.show``, or call ``plt.show()`` if it is ``None``.

    Args:
        fig: Figure to show (default: None, which defers to ``plt.show()``).

    """
    if fig is not None:
        fig.show()
        return
    plt.show()
266
+
267
+
268
def reset_prop_cycle(ax: Axes) -> None:
    """Set the property cycle of ``ax`` to the current ``axes.prop_cycle`` rcParam.

    Args:
        ax: Axes whose property cycle is reset.

    """
    default_cycle = plt.rcParams["axes.prop_cycle"]
    ax.set_prop_cycle(default_cycle)
270
+
271
+
272
+ ##########################################################################
273
+ # General plot layout
274
+ ##########################################################################
275
+
276
+
277
def _default_fig_ax(
    *,
    ax: Axes | None,
    grid: bool,
    figsize: tuple[float, float] | None = None,
) -> FigAx:
    """Return ``(figure, axes)``, creating a single-axes figure if needed.

    When ``ax`` is given, its figure is looked up and ``figsize`` is ignored.

    Args:
        ax: Existing axes to draw on, or ``None`` to create a new figure.
        grid: Whether to add a grid to the axes.
        figsize: Size of a newly created figure (default: None).

    Returns:
        Figure and Axes objects for the plot.

    """
    if ax is not None:
        fig = cast(Figure, ax.get_figure())
    else:
        fig, ax = plt.subplots(nrows=1, ncols=1, figsize=figsize)

    if grid:
        add_grid(ax)
    return fig, ax
302
+
303
+
304
def _default_fig_axs(
    axs: list[Axes] | None,
    *,
    ncols: int,
    nrows: int,
    figsize: tuple[float, float] | None,
    grid: bool,
    sharex: bool,
    sharey: bool,
) -> FigAxs:
    """Return a figure and a flat list of axes, creating them if needed.

    If ``axs`` contains axes, they are reused together with the figure of the
    first one, and the layout arguments are ignored. Otherwise a new
    ``nrows`` x ``ncols`` grid with constrained layout is created and
    flattened in row-major order.

    Args:
        axs: Axes to use for the plot, or ``None`` / empty to create new ones.
        ncols: Number of columns for the plot.
        nrows: Number of rows for the plot.
        figsize: Size of the figure (default: None).
        grid: Whether to add a grid to every axes.
        sharex: Whether to share the x-axis between the axes.
        sharey: Whether to share the y-axis between the axes.

    Returns:
        Figure and Axes objects for the plot.

    """
    if axs is not None and len(axs) > 0:
        fig = cast(Figure, axs[0].get_figure())
    else:
        fig, grid_of_axes = plt.subplots(
            nrows=nrows,
            ncols=ncols,
            sharex=sharex,
            sharey=sharey,
            figsize=figsize,
            squeeze=False,
            layout="constrained",
        )
        axs = list(grid_of_axes.flatten())

    if grid:
        for ax in axs:
            add_grid(ax)
    return fig, axs
347
+
348
+
349
def one_axes(
    *,
    figsize: tuple[float, float] | None = None,
    grid: bool = False,
) -> FigAx:
    """Create a figure with a single axes.

    Args:
        figsize: Size of the figure (default: None, i.e. matplotlib's default).
        grid: Whether to add a grid to the axes (default: False).

    Returns:
        Figure and Axes objects.

    """
    return _default_fig_ax(
        ax=None,
        grid=grid,
        figsize=figsize,
    )
360
+
361
+
362
def two_axes(
    *,
    figsize: tuple[float, float] | None = None,
    sharex: bool = True,
    sharey: bool = False,
    grid: bool = False,
) -> FigAxs:
    """Create a figure with two axes side by side (one row, two columns).

    Args:
        figsize: Size of the figure (default: None).
        sharex: Whether both axes share the x-axis (default: True).
        sharey: Whether both axes share the y-axis (default: False).
        grid: Whether to add a grid to both axes (default: False).

    Returns:
        Figure and a list with the two axes.

    """
    return _default_fig_axs(
        axs=None,
        nrows=1,
        ncols=2,
        sharex=sharex,
        sharey=sharey,
        figsize=figsize,
        grid=grid,
    )
379
+
380
+
381
def grid_layout(
    n_groups: int,
    *,
    n_cols: int = 2,
    col_width: float = 3,
    row_height: float = 4,
    sharex: bool = True,
    sharey: bool = False,
    grid: bool = True,
) -> tuple[Figure, list[Axes]]:
    """Create a grid layout for the given number of groups.

    The number of columns is capped at ``n_groups``, and as many rows as
    needed are added. The figure size is ``(n_cols * col_width,
    n_rows * row_height)``. At least one axes is always created, so
    ``n_groups == 0`` yields a 1x1 grid instead of a division by zero.
    Surplus axes are returned as well; callers decide what to do with them.

    Args:
        n_groups: Number of plots the grid must hold.
        n_cols: Maximum number of columns (default: 2).
        col_width: Width of each grid column.
        row_height: Height of each grid row.
        sharex: Whether the axes share the x-axis.
        sharey: Whether the axes share the y-axis.
        grid: Whether to add a grid to every axes.

    Returns:
        Figure and a flat list of all axes of the grid.

    Raises:
        ValueError: If ``n_groups`` is negative or ``n_cols`` is smaller than 1.

    """
    if n_groups < 0:
        msg = f"n_groups must not be negative, got {n_groups}"
        raise ValueError(msg)
    if n_cols < 1:
        msg = f"n_cols must be at least 1, got {n_cols}"
        raise ValueError(msg)

    n_cols = max(1, min(n_groups, n_cols))
    n_rows = max(1, math.ceil(n_groups / n_cols))
    figsize = (n_cols * col_width, n_rows * row_height)

    return _default_fig_axs(
        None,
        ncols=n_cols,
        nrows=n_rows,
        figsize=figsize,
        sharex=sharex,
        sharey=sharey,
        grid=grid,
    )
405
+
406
+
407
+ ##########################################################################
408
+ # Plots
409
+ ##########################################################################
410
+
411
+
412
def bars(
    x: pd.DataFrame,
    *,
    ax: Axes | None = None,
    grid: bool = True,
) -> FigAx:
    """Plot the columns of a dataframe as bars.

    ``x`` is passed to ``seaborn.barplot`` as ``data`` without ``x``/``y``,
    i.e. as wide-form data, so seaborn draws its bars per column.

    Args:
        x: Data to plot.
        ax: Axes to draw on. A new figure is created if ``None``.
        grid: Whether to add a grid.

    Returns:
        Figure and Axes objects for the plot.

    """
    fig, ax = _default_fig_ax(ax=ax, grid=grid)
    sns.barplot(data=x, ax=ax)
    _default_labels(ax, xlabel=x.index.name, ylabel=None)
    # NOTE(review): the x label is taken from ``x.index.name`` although the
    # bars are per column, and the legend labels are matched to artists in
    # drawing order; confirm both are intended.
    ax.legend(x.columns)
    return fig, ax
424
+
425
+
426
def lines(
    x: pd.DataFrame | pd.Series,
    *,
    ax: Axes | None = None,
    grid: bool = True,
    alpha: float = 1.0,
    legend: bool = True,
) -> FigAx:
    """Plot multiple lines on the same axis.

    A series is drawn as one line, a dataframe as one line per column, all
    against the index of ``x``.

    Args:
        x: Data to plot; its index is used for the x-axis.
        ax: Axes to draw on. A new figure is created if ``None``.
        grid: Whether to add a grid.
        alpha: Opacity of the lines.
        legend: Whether to add a legend for the lines drawn by this call.

    Returns:
        Figure and Axes objects for the plot.

    """
    fig, ax = _default_fig_ax(ax=ax, grid=grid)
    # Keep the handles of *these* lines, so the legend labels cannot be
    # attached to lines already present on a reused ``ax``.
    handles = ax.plot(x.index, x, alpha=alpha)
    _default_labels(ax, xlabel=x.index.name, ylabel=None)
    if legend:
        if isinstance(x, pd.Series):
            labels = [str(x.name)]
        else:
            labels = [str(column) for column in x.columns]
        ax.legend(handles, labels)
    return fig, ax
450
+
451
+
452
def lines_grouped(
    groups: list[pd.DataFrame] | list[pd.Series],
    *,
    n_cols: int = 2,
    col_width: float = 3,
    row_height: float = 4,
    sharex: bool = True,
    sharey: bool = False,
    grid: bool = True,
) -> FigAxs:
    """Plot every group with ``lines`` on its own axes of a grid.

    Axes of the grid that are not needed for a group are hidden.

    Args:
        groups: Data to plot, one entry per axes.
        n_cols: Maximum number of columns of the grid.
        col_width: Width of each grid column.
        row_height: Height of each grid row.
        sharex: Whether the axes share the x-axis.
        sharey: Whether the axes share the y-axis.
        grid: Whether to add a grid.

    Returns:
        Figure and all axes of the grid, including hidden ones.

    """
    fig, axs = grid_layout(
        len(groups),
        n_cols=n_cols,
        col_width=col_width,
        row_height=row_height,
        sharex=sharex,
        sharey=sharey,
        grid=grid,
    )

    n_groups = len(groups)
    for position, ax in enumerate(axs):
        if position < n_groups:
            lines(groups[position], ax=ax, grid=grid)
        else:
            ax.set_visible(False)
    return fig, axs
480
+
481
+
482
def line_autogrouped(
    s: pd.Series | pd.DataFrame,
    *,
    n_cols: int = 2,
    col_width: float = 4,
    row_height: float = 3,
    max_group_size: int = 6,
    grid: bool = True,
) -> FigAxs:
    """Plot a series or dataframe with lines grouped by order of magnitude.

    Series entries are grouped by their absolute value, and dataframe columns
    by their largest absolute value. Each group is drawn on its own axes with
    ``lines_grouped``; groups larger than ``max_group_size`` are split first.

    Args:
        s: Data to plot.
        n_cols: Maximum number of columns of the grid.
        col_width: Width of each grid column.
        row_height: Height of each grid row.
        max_group_size: Maximum number of lines per axes.
        grid: Whether to add a grid.

    Returns:
        Figure and all axes of the grid.

    """
    # Group by absolute values, so purely negative entries/columns (e.g.
    # negative fluxes) are grouped by their actual order of magnitude instead
    # of by a NaN log10.
    magnitudes = s.abs() if isinstance(s, pd.Series) else s.abs().max()
    group_names = _split_large_groups(
        _partition_by_order_of_magnitude(magnitudes),
        max_size=max_group_size,
    )

    groups: list[pd.Series] | list[pd.DataFrame] = (
        [s.loc[group] for group in group_names]
        if isinstance(s, pd.Series)
        else [s.loc[:, group] for group in group_names]
    )

    return lines_grouped(
        groups,
        n_cols=n_cols,
        col_width=col_width,
        row_height=row_height,
        grid=grid,
    )
512
+
513
+
514
def line_mean_std(
    df: pd.DataFrame,
    *,
    label: str | None = None,
    ax: Axes | None = None,
    color: str | None = None,
    alpha: float = 0.2,
    grid: bool = True,
) -> FigAx:
    """Plot the row-wise mean as a line and mean ± std as a shaded band.

    Mean and standard deviation are taken across the columns of ``df``
    (pandas defaults, so ``std`` uses ``ddof=1``) and plotted against the index.

    Args:
        df: Data; each row is summarised.
        label: Legend label of the mean line.
        ax: Axes to draw on. A new figure is created if ``None``.
        color: Colour of line and band; the next ``"C<n>"`` colour if ``None``.
        alpha: Opacity of the band.
        grid: Whether to add a grid.

    Returns:
        Figure and Axes objects for the plot.

    """
    fig, ax = _default_fig_ax(ax=ax, grid=grid)
    # Resolve the colour before plotting, since it depends on the line count.
    color = _default_color(ax=ax, color=color)

    mean = df.mean(axis=1)
    std = df.std(axis=1)
    lower, upper = mean - std, mean + std

    ax.plot(mean.index, mean, color=color, label=label)
    ax.fill_between(df.index, lower, upper, color=color, alpha=alpha)
    _default_labels(ax, xlabel=df.index.name, ylabel=None)
    return fig, ax
544
+
545
+
546
def lines_mean_std_from_2d_idx(
    df: pd.DataFrame,
    *,
    names: list[str] | None = None,
    ax: Axes | None = None,
    alpha: float = 0.2,
    grid: bool = True,
) -> FigAx:
    """Plot the mean and standard deviation of a 2D indexed dataframe.

    Each selected column is unstacked. Mean and standard deviation are taken
    across the first index level and plotted with ``line_mean_std`` against
    the second index level.

    Args:
        df: Dataframe with a two-level index.
        names: Columns to plot (default: all columns).
        ax: Axes to draw on. A new figure is created if ``None``.
        alpha: Opacity of the standard deviation bands.
        grid: Whether to add a grid.

    Returns:
        Figure and Axes objects for the plot.

    Raises:
        ValueError: If the index of ``df`` does not have exactly two levels.

    """
    # ``nlevels`` exists on every Index, so a plain single-level index raises
    # the intended ValueError instead of an AttributeError on ``.levels``.
    if df.index.nlevels != 2:  # noqa: PLR2004
        msg = "MultiIndex must have exactly two levels"
        raise ValueError(msg)

    fig, ax = _default_fig_ax(ax=ax, grid=grid)

    for name in df.columns if names is None else names:
        line_mean_std(
            df[name].unstack().T,
            label=name,
            alpha=alpha,
            ax=ax,
        )
    ax.legend()
    return fig, ax
570
+
571
+
572
def heatmap(
    df: pd.DataFrame,
    *,
    annotate: bool = False,
    colorbar: bool = True,
    invert_yaxis: bool = True,
    cmap: str = "RdBu_r",
    norm: Normalize | None = None,
    ax: Axes | None = None,
    cax: Axes | None = None,
    sci_annotation_bounds: tuple[float, float] = (0.01, 100),
    annotation_style: str = "2g",
) -> tuple[Figure, Axes, QuadMesh]:
    """Plot a heatmap of the given data.

    Rows of ``df`` are drawn along the y-axis and columns along the x-axis.
    Index and column labels are used as tick labels centred on the cells.

    Args:
        df: Data to plot.
        annotate: Whether to write each cell's value into the cell.
        colorbar: Whether to add a colorbar.
        invert_yaxis: Whether to invert the y-axis.
        cmap: Name of the colormap.
        norm: Normalization of the values; derived from ``df`` and centred on
            zero if ``None``.
        ax: Axes to draw on; if ``None`` a new figure sized to ``df`` is created.
        cax: Axes for the colorbar, passed to ``Figure.colorbar``.
        sci_annotation_bounds: Magnitude bounds outside of which annotations
            use scientific notation.
        annotation_style: Format spec for annotations inside those bounds.

    Returns:
        Figure, Axes and the QuadMesh returned by ``pcolormesh``.

    """
    n_rows, n_cols = df.shape
    fig, ax = _default_fig_ax(
        ax=ax,
        figsize=(max(4, 0.5 * n_cols), max(4, 0.5 * n_rows)),
        grid=False,
    )
    mesh_norm = _norm_with_zero_center(df) if norm is None else norm

    hm = ax.pcolormesh(df, norm=mesh_norm, cmap=cmap)
    ax.set_xticks(np.arange(n_cols) + 0.5, labels=df.columns)
    ax.set_yticks(np.arange(n_rows) + 0.5, labels=df.index)

    if annotate:
        _annotate_colormap(df, ax, sci_annotation_bounds, annotation_style, hm)

    if colorbar:
        cb = fig.colorbar(hm, cax=cax, ax=ax)
        cb.outline.set_linewidth(0)  # type: ignore

    if invert_yaxis:
        ax.invert_yaxis()
    rotate_xlabels(ax, rotation=45, ha="right")
    return fig, ax, hm
619
+
620
+
621
def heatmap_from_2d_idx(
    df: pd.DataFrame,
    variable: str,
    ax: Axes | None = None,
) -> FigAx:
    """Plot a heatmap of one column of a 2D indexed dataframe.

    The column ``variable`` is unstacked. The first index level is shown on
    the x-axis and the second on the y-axis. Tick labels use two decimals for
    values that support the ``.2f`` format; other values (e.g. strings) are
    shown as ``str(value)`` instead of raising.

    Args:
        df: Dataframe with a two-level index.
        variable: Column of ``df`` to plot; also used as the title.
        ax: Axes to draw on. A new figure is created if ``None``.

    Returns:
        Figure and Axes objects for the plot.

    Raises:
        ValueError: If the index of ``df`` does not have exactly two levels.

    """
    # ``nlevels`` exists on every Index, so a plain single-level index raises
    # the intended ValueError instead of an AttributeError on ``.levels``.
    if df.index.nlevels != 2:  # noqa: PLR2004
        msg = "MultiIndex must have exactly two levels"
        raise ValueError(msg)

    def _tick_label(value: object) -> str:
        try:
            return f"{value:.2f}"
        except (TypeError, ValueError):
            return str(value)

    fig, ax = _default_fig_ax(ax=ax, grid=False)
    df2d = df[variable].unstack()

    ax.set_title(variable)
    # Note: pcolormesh swaps index/columns
    hm = ax.pcolormesh(df2d.T)
    ax.set_xlabel(df2d.index.name)
    ax.set_ylabel(df2d.columns.name)
    ax.set_xticks(
        np.arange(0, len(df2d.index), 1) + 0.5,
        labels=[_tick_label(i) for i in df2d.index],
    )
    ax.set_yticks(
        np.arange(0, len(df2d.columns), 1) + 0.5,
        labels=[_tick_label(i) for i in df2d.columns],
    )

    rotate_xlabels(ax, rotation=45, ha="right")

    # Add colorbar
    fig.colorbar(hm, ax=ax)
    return fig, ax
653
+
654
+
655
def heatmaps_from_2d_idx(
    df: pd.DataFrame,
    *,
    n_cols: int = 3,
    col_width_factor: float = 1,
    row_height_factor: float = 0.6,
    sharex: bool = True,
    sharey: bool = False,
) -> FigAxs:
    """Plot multiple heatmaps of a 2D indexed dataframe.

    Each column of ``df`` is drawn with ``heatmap_from_2d_idx`` on its own
    axes. The width of each grid column is ``len(levels[0]) * col_width_factor``
    and the height of each grid row is ``len(levels[1]) * row_height_factor``.

    Args:
        df: Dataframe with a two-level index; one heatmap per column.
        n_cols: Maximum number of heatmaps per row.
        col_width_factor: Grid column width per value of the first index level.
        row_height_factor: Grid row height per value of the second index level.
        sharex: Whether the axes share the x-axis.
        sharey: Whether the axes share the y-axis.

    Returns:
        Figure and all axes of the grid.

    Raises:
        ValueError: If the index of ``df`` does not have exactly two levels.

    """
    if df.index.nlevels != 2:  # noqa: PLR2004
        msg = "MultiIndex must have exactly two levels"
        raise ValueError(msg)
    idx = cast(pd.MultiIndex, df.index)

    fig, axs = grid_layout(
        n_groups=len(df.columns),
        # Cap by the number of heatmaps, i.e. the columns of df.
        n_cols=min(n_cols, len(df.columns)),
        col_width=len(idx.levels[0]) * col_width_factor,
        row_height=len(idx.levels[1]) * row_height_factor,
        sharex=sharex,
        sharey=sharey,
        grid=False,
    )
    for ax, var in zip(axs, df.columns, strict=False):
        heatmap_from_2d_idx(df, var, ax=ax)
    return fig, axs
679
+
680
+
681
def violins(
    df: pd.DataFrame,
    *,
    ax: Axes | None = None,
    grid: bool = True,
) -> FigAx:
    """Draw violins for the columns of ``df`` on a single axes.

    ``df`` is passed to ``seaborn.violinplot`` as wide-form ``data``. The x
    label is cleared and the y label gets the placeholder text.

    Args:
        df: Data to plot.
        ax: Axes to draw on. A new figure is created if ``None``.
        grid: Whether to add a grid.

    Returns:
        Figure and Axes objects for the plot.

    """
    fig, ax = _default_fig_ax(ax=ax, grid=grid)
    sns.violinplot(data=df, ax=ax)
    _default_labels(ax=ax, xlabel="", ylabel=None)
    return fig, ax
692
+
693
+
694
def violins_from_2d_idx(
    df: pd.DataFrame,
    *,
    n_cols: int = 4,
    row_height: int = 2,
    sharex: bool = True,
    sharey: bool = False,
    grid: bool = True,
) -> FigAxs:
    """Plot multiple violins of a 2D indexed dataframe.

    Each column gets its own axes, titled with the column name. The column is
    unstacked, so the values of the second index level become the columns
    passed to ``violins``.

    Args:
        df: Dataframe with a two-level index.
        n_cols: Maximum number of columns of the grid.
        row_height: Height of each grid row.
        sharex: Whether the axes share the x-axis.
        sharey: Whether the axes share the y-axis.
        grid: Whether to add a grid.

    Returns:
        Figure and all axes of the grid.

    Raises:
        ValueError: If the index of ``df`` does not have exactly two levels.

    """
    # ``nlevels`` exists on every Index, so a plain single-level index raises
    # the intended ValueError instead of an AttributeError on ``.levels``.
    if df.index.nlevels != 2:  # noqa: PLR2004
        msg = "MultiIndex must have exactly two levels"
        raise ValueError(msg)

    fig, axs = grid_layout(
        len(df.columns),
        n_cols=n_cols,
        row_height=row_height,
        sharex=sharex,
        sharey=sharey,
        grid=grid,
    )

    for ax, col in zip(axs[: len(df.columns)], df.columns, strict=True):
        ax.set_title(col)
        violins(df[col].unstack(), ax=ax)

    # Surplus axes are blanked (no spines, no y ticks) rather than hidden.
    # NOTE(review): presumably so that shared x tick labels stay visible; confirm.
    for ax in axs[len(df.columns) :]:
        for axis in ["top", "bottom", "left", "right"]:
            ax.spines[axis].set_linewidth(0)
        ax.yaxis.set_ticks([])

    for ax in axs:
        rotate_xlabels(ax)
    return fig, axs
729
+
730
+
731
def shade_protocol(
    protocol: pd.Series,
    *,
    ax: Axes,
    cmap_name: str = "Greys_r",
    vmin: float | None = None,
    vmax: float | None = None,
    alpha: float = 0.5,
    add_legend: bool = True,
) -> None:
    """Shade the steps of a protocol as vertical spans on the given axis.

    ``protocol`` is indexed by the end time of each step (cast to
    ``pd.Timedelta``). Each step spans from the previous end time (0 s for
    the first step) to its own end time. Positions are given in seconds on
    the x-axis, and each span is coloured by mapping its value through
    ``cmap_name``.

    Args:
        protocol: Step values indexed by step end time.
        ax: Axes to shade.
        cmap_name: Name of a registered matplotlib colormap.
        vmin: Lower bound of the colour normalization (default: protocol minimum).
        vmax: Upper bound of the colour normalization (default: protocol maximum).
        alpha: Opacity of the spans and of the legend patches.
        add_legend: Whether to add a legend with one entry per distinct value.

    """
    from matplotlib import colormaps
    from matplotlib.colors import Normalize
    from matplotlib.legend import Legend
    from matplotlib.patches import Patch

    cmap = colormaps[cmap_name]
    norm = Normalize(
        vmin=protocol.min() if vmin is None else vmin,
        vmax=protocol.max() if vmax is None else vmax,
    )

    t0 = pd.Timedelta(seconds=0)
    for t_end, val in protocol.items():
        t_end = cast(pd.Timedelta, t_end)
        ax.axvspan(
            t0.total_seconds(),
            t_end.total_seconds(),
            facecolor=cmap(norm(val)),
            edgecolor=None,
            alpha=alpha,
        )
        t0 = t_end  # type: ignore

    if add_legend:
        # A value used by several steps gets a single legend entry
        # (first-occurrence order).
        unique_values = list(dict.fromkeys(protocol))
        ax.add_artist(
            Legend(
                ax,
                handles=[
                    Patch(
                        facecolor=cmap(norm(val)),
                        alpha=alpha,
                        label=val,
                    )  # type: ignore
                    for val in unique_values
                ],
                labels=unique_values,
                loc="lower right",
                bbox_to_anchor=(1.0, 0.0),
                title="protocol" if protocol.name is None else cast(str, protocol.name),
            )
        )
783
+
784
+
785
+ ##########################################################################
786
+ # Plots that actually require a model :/
787
+ ##########################################################################
788
+
789
+
790
def trajectories_2d(
    model: Model,
    x1: tuple[str, ArrayLike],
    x2: tuple[str, ArrayLike],
    y0: dict[str, float] | None = None,
    ax: Axes | None = None,
) -> FigAx:
    """Plot trajectories of two variables in a 2D phase space.

    The right-hand side of ``model`` is evaluated at every combination of the
    values of ``x1`` and ``x2``, with all other variables taken from ``y0``.
    The resulting derivatives of the two variables are drawn as a quiver plot.

    Examples:
        >>> trajectories_2d(
        ...     model,
        ...     ("S", np.linspace(0, 1, 10)),
        ...     ("P", np.linspace(0, 1, 10)),
        ... )

    Args:
        model: Model to use for the plot.
        x1: Tuple of the first variable name and its values.
        x2: Tuple of the second variable name and its values.
        y0: Initial conditions for the model (default: the model's own).
        ax: Axes to use for the plot.

    Returns:
        Figure and Axes objects for the plot.

    """
    (name1, values1), (name2, values2) = x1, x2
    shape = (len(values1), len(values2))
    u = np.zeros(shape)
    v = np.zeros(shape)
    base = model.get_initial_conditions() if y0 is None else y0

    for i, first in enumerate(values1):
        for j, second in enumerate(values2):
            rhs = model.get_right_hand_side(base | {name1: first, name2: second})
            u[i, j] = rhs[name1]
            v[i, j] = rhs[name2]

    fig, ax = _default_fig_ax(ax=ax, grid=False)
    # quiver expands 1D X/Y with meshgrid to shape (len(Y), len(X)),
    # hence the transposed derivative arrays.
    ax.quiver(values1, values2, u.T, v.T)
    return fig, ax
830
+
831
+
832
+ ##########################################################################
833
+ # Label Plots
834
+ ##########################################################################
835
+
836
+
837
def relative_label_distribution(
    mapper: LabelMapper | LinearLabelMapper,
    concs: pd.DataFrame,
    *,
    subset: list[str] | None = None,
    n_cols: int = 2,
    col_width: float = 3,
    row_height: float = 3,
    sharey: bool = False,
    grid: bool = True,
) -> FigAxs:
    """Plot the relative distribution of labels in the given data.

    One axes per variable, titled with its name, with one line per position
    (legend ``C1``, ``C2``, ...) plotted against the index of ``concs``.

    * ``LabelMapper``: for each position, the summed isotopomers labelled at
      that position divided by the ``"<name>__total"`` column.
    * Otherwise: the columns returned by ``mapper.get_isotopomers`` for the
      variable, plotted directly.

    Args:
        mapper: Label mapper describing the labelled variables.
        concs: Concentrations containing the isotopomer columns.
        subset: Variables to plot (default: all label variables of ``mapper``).
        n_cols: Maximum number of columns of the grid.
        col_width: Width of each grid column.
        row_height: Height of each grid row.
        sharey: Whether the axes share the y-axis.
        grid: Whether to add a grid.

    Returns:
        Figure and all axes of the grid.

    """
    variables = list(mapper.label_variables) if subset is None else subset
    fig, axs = grid_layout(
        n_groups=len(variables),
        n_cols=n_cols,
        col_width=col_width,
        row_height=row_height,
        sharey=sharey,
        grid=grid,
    )

    if not isinstance(mapper, LabelMapper):
        isotopomers = mapper.get_isotopomers(variables)
        for ax, (name, isos) in zip(axs, isotopomers.items(), strict=False):
            ax.plot(concs.index, concs.loc[:, isos])
            ax.set_title(name)
            ax.legend([f"C{i + 1}" for i in range(len(isos))])
        return fig, axs

    for ax, name in zip(axs, variables, strict=False):
        for position in range(mapper.label_variables[name]):
            isos = mapper.get_isotopomers_of_at_position(name, position)
            labelled = cast(pd.DataFrame, concs.loc[:, isos])
            total = concs.loc[:, f"{name}__total"]
            ax.plot(
                labelled.index,
                (labelled.sum(axis=1) / total),
                label=f"C{position + 1}",
            )
        ax.set_title(name)
        ax.legend()
    return fig, axs