modelbase2 0.1.79__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (58) hide show
  1. modelbase2/__init__.py +138 -26
  2. modelbase2/distributions.py +306 -0
  3. modelbase2/experimental/__init__.py +17 -0
  4. modelbase2/experimental/codegen.py +239 -0
  5. modelbase2/experimental/diff.py +227 -0
  6. modelbase2/experimental/notes.md +4 -0
  7. modelbase2/experimental/tex.py +521 -0
  8. modelbase2/fit.py +284 -0
  9. modelbase2/fns.py +185 -0
  10. modelbase2/integrators/__init__.py +19 -0
  11. modelbase2/integrators/int_assimulo.py +146 -0
  12. modelbase2/integrators/int_scipy.py +147 -0
  13. modelbase2/label_map.py +610 -0
  14. modelbase2/linear_label_map.py +301 -0
  15. modelbase2/mc.py +548 -0
  16. modelbase2/mca.py +280 -0
  17. modelbase2/model.py +1621 -0
  18. modelbase2/npe.py +343 -0
  19. modelbase2/parallel.py +171 -0
  20. modelbase2/parameterise.py +28 -0
  21. modelbase2/paths.py +36 -0
  22. modelbase2/plot.py +829 -0
  23. modelbase2/sbml/__init__.py +14 -0
  24. modelbase2/sbml/_data.py +77 -0
  25. modelbase2/sbml/_export.py +656 -0
  26. modelbase2/sbml/_import.py +585 -0
  27. modelbase2/sbml/_mathml.py +691 -0
  28. modelbase2/sbml/_name_conversion.py +52 -0
  29. modelbase2/sbml/_unit_conversion.py +74 -0
  30. modelbase2/scan.py +616 -0
  31. modelbase2/scope.py +96 -0
  32. modelbase2/simulator.py +635 -0
  33. modelbase2/surrogates/__init__.py +32 -0
  34. modelbase2/surrogates/_poly.py +66 -0
  35. modelbase2/surrogates/_torch.py +249 -0
  36. modelbase2/surrogates.py +316 -0
  37. modelbase2/types.py +352 -11
  38. modelbase2-0.2.0.dist-info/METADATA +81 -0
  39. modelbase2-0.2.0.dist-info/RECORD +42 -0
  40. {modelbase2-0.1.79.dist-info → modelbase2-0.2.0.dist-info}/WHEEL +1 -1
  41. modelbase2/core/__init__.py +0 -29
  42. modelbase2/core/algebraic_module_container.py +0 -130
  43. modelbase2/core/constant_container.py +0 -113
  44. modelbase2/core/data.py +0 -109
  45. modelbase2/core/name_container.py +0 -29
  46. modelbase2/core/reaction_container.py +0 -115
  47. modelbase2/core/utils.py +0 -28
  48. modelbase2/core/variable_container.py +0 -24
  49. modelbase2/ode/__init__.py +0 -13
  50. modelbase2/ode/integrator.py +0 -80
  51. modelbase2/ode/mca.py +0 -270
  52. modelbase2/ode/model.py +0 -470
  53. modelbase2/ode/simulator.py +0 -153
  54. modelbase2/utils/__init__.py +0 -0
  55. modelbase2/utils/plotting.py +0 -372
  56. modelbase2-0.1.79.dist-info/METADATA +0 -44
  57. modelbase2-0.1.79.dist-info/RECORD +0 -22
  58. {modelbase2-0.1.79.dist-info → modelbase2-0.2.0.dist-info/licenses}/LICENSE +0 -0
modelbase2/plot.py ADDED
@@ -0,0 +1,829 @@
"""Plotting Utilities Module.

This module provides functions for creating plots and visualizations
for metabolic models. It includes functionality for line, bar and violin
plots, heatmaps, protocol shading, 2D phase-space trajectories and label
distributions.

Functions:
    bars: Plot the columns of a dataframe as bars on one axis.
    lines: Plot multiple lines on the same axis.
    lines_grouped: Plot groups of lines on separate axes.
    line_autogrouped: Plot lines grouped by their order of magnitude.
    line_mean_std: Plot a mean line with a standard-deviation band.
    lines_mean_std_from_2d_idx: Plot mean/std lines of a 2D-indexed dataframe.
    heatmap: Plot a heatmap of a dataframe.
    heatmap_from_2d_idx: Plot a heatmap of one column of a 2D-indexed
        dataframe.
    heatmaps_from_2d_idx: Plot heatmaps of all columns of a 2D-indexed
        dataframe.
    violins: Plot multiple violins on the same axis.
    violins_from_2d_idx: Plot violins of a 2D-indexed dataframe.
    shade_protocol: Shade the steps of a protocol on an axis.
    trajectories_2d: Plot the vector field of two model variables.
    relative_label_distribution: Plot the relative label distribution of
        labelled variables.
"""
18
+
19
+ from __future__ import annotations
20
+
21
# Names exported by ``from modelbase2.plot import *`` (ASCII-sorted, so the
# capitalised type aliases come first).
__all__ = [
    "FigAx", "FigAxs",
    "add_grid", "bars", "grid_layout",
    "heatmap", "heatmap_from_2d_idx", "heatmaps_from_2d_idx",
    "line_autogrouped", "line_mean_std",
    "lines", "lines_grouped", "lines_mean_std_from_2d_idx",
    "relative_label_distribution", "rotate_xlabels", "shade_protocol",
    "trajectories_2d", "two_axes",
    "violins", "violins_from_2d_idx",
]
43
+
44
+ import itertools as it
45
+ import math
46
+ from typing import TYPE_CHECKING, Literal, cast
47
+
48
+ import numpy as np
49
+ import pandas as pd
50
+ import seaborn as sns
51
+ from matplotlib import pyplot as plt
52
+ from matplotlib.axes import Axes
53
+ from matplotlib.colors import (
54
+ LogNorm,
55
+ Normalize,
56
+ SymLogNorm,
57
+ colorConverter, # type: ignore
58
+ )
59
+ from matplotlib.figure import Figure
60
+ from mpl_toolkits.mplot3d import Axes3D
61
+
62
+ from modelbase2.label_map import LabelMapper
63
+
64
+ if TYPE_CHECKING:
65
+ from matplotlib.collections import QuadMesh
66
+
67
+ from modelbase2.linear_label_map import LinearLabelMapper
68
+ from modelbase2.model import Model
69
+ from modelbase2.types import Array, ArrayLike
70
+
71
# Return type of the single-axis plotting functions: the figure and its axis.
type FigAx = tuple[Figure, Axes]
# Return type of the multi-axis plotting functions: the figure and a list of
# its axes.
type FigAxs = tuple[Figure, list[Axes]]
73
+
74
+
75
+ ##########################################################################
76
+ # Helpers
77
+ ##########################################################################
78
+
79
+
80
def _relative_luminance(color: Array) -> float:
    """Return the relative luminance of a color (WCAG 2.0 formula).

    The color is converted to RGBA. Each sRGB channel is linearised, and the
    linear channels are combined with the weights 0.2126 (R), 0.7152 (G)
    and 0.0722 (B).

    Args:
        color: Any color specification accepted by
            ``colorConverter.to_rgba_array``; only the first color is used.

    Returns:
        The relative luminance of the first color.

    """
    channels = colorConverter.to_rgba_array(color)[:, :3]

    # sRGB -> linear: linear segment near black, power curve elsewhere.
    linear = np.where(
        channels <= 0.03928,  # noqa: PLR2004
        channels / 12.92,
        np.power((channels + 0.055) / 1.055, 2.4),
    )

    weights = np.array([0.2126, 0.7152, 0.0722])
    return (linear @ weights)[0]
93
+
94
+
95
def _get_norm(vmin: float, vmax: float) -> Normalize:
    """Pick a color normalization suited to the range ``[vmin, vmax]``.

    * If the range lies strictly inside ``(-1000, 1000)``, a linear
      ``Normalize`` is returned.
    * Otherwise, if ``vmin <= 0``, a ``SymLogNorm`` (``linthresh=1``,
      base 10) is returned.
    * Otherwise (strictly positive data with ``vmax >= 1000``), a
      ``LogNorm`` is returned.

    Args:
        vmin: Minimum value of the data.
        vmax: Maximum value of the data.

    Returns:
        Normalize: A normalization object for the given data.

    """
    limit = 1000
    if vmax < limit and vmin > -limit:
        return Normalize(vmin=vmin, vmax=vmax)
    if vmin <= 0:
        return SymLogNorm(linthresh=1, vmin=vmin, vmax=vmax, base=10)
    return LogNorm(vmin=vmin, vmax=vmax)
117
+
118
+
119
def _norm_with_zero_center(df: pd.DataFrame) -> Normalize:
    """Return a normalization whose range is symmetric around zero.

    The limits are ``-bound`` and ``bound``, where ``bound`` is the largest
    absolute value in ``df``. The kind of normalization is chosen by
    ``_get_norm``.
    """
    smallest = df.min().min()
    largest = df.max().max()
    bound = max(abs(smallest), abs(largest))
    return _get_norm(vmin=-bound, vmax=bound)
123
+
124
+
125
+ def _partition_by_order_of_magnitude(s: pd.Series) -> list[list[str]]:
126
+ """Partition a series into groups based on the order of magnitude of the values."""
127
+ return [
128
+ i.to_list()
129
+ for i in np.floor(np.log10(s)).to_frame(name=0).groupby(0)[0].groups.values() # type: ignore
130
+ ]
131
+
132
+
133
+ def _split_large_groups[T](groups: list[list[T]], max_size: int) -> list[list[T]]:
134
+ """Split groups larger than the given size into smaller groups."""
135
+ return list(
136
+ it.chain(
137
+ *(
138
+ (
139
+ [group]
140
+ if len(group) < max_size
141
+ else [ # type: ignore
142
+ list(i)
143
+ for i in np.array_split(group, math.ceil(len(group) / max_size)) # type: ignore
144
+ ]
145
+ )
146
+ for group in groups
147
+ )
148
+ )
149
+ ) # type: ignore
150
+
151
+
152
+ def _default_color(ax: Axes, color: str | None) -> str:
153
+ """Get a default color for the given axis."""
154
+ return f"C{len(ax.lines)}" if color is None else color
155
+
156
+
157
def _default_labels(
    ax: Axes,
    xlabel: str | None = None,
    ylabel: str | None = None,
    zlabel: str | None = None,
) -> None:
    """Label the axes of ``ax``, using a placeholder for missing labels.

    Any label given as ``None`` is replaced by "Add a label / unit". The
    z label is only set when ``ax`` is a 3D axis.

    Args:
        ax: matplotlib Axes
        xlabel: Label for the x-axis.
        ylabel: Label for the y-axis.
        zlabel: Label for the z-axis (3D axes only).

    """
    placeholder = "Add a label / unit"

    def _or_placeholder(label: str | None) -> str:
        return placeholder if label is None else label

    ax.set_xlabel(_or_placeholder(xlabel))
    ax.set_ylabel(_or_placeholder(ylabel))
    if isinstance(ax, Axes3D):
        ax.set_zlabel(_or_placeholder(zlabel))
176
+
177
+
178
+ def _annotate_colormap(
179
+ df: pd.DataFrame,
180
+ ax: Axes,
181
+ sci_annotation_bounds: tuple[float, float],
182
+ annotation_style: str,
183
+ hm: QuadMesh,
184
+ ) -> None:
185
+ """Annotate a heatmap with the values of the data.
186
+
187
+ Args:
188
+ df: Dataframe to annotate.
189
+ ax: Axes to annotate.
190
+ sci_annotation_bounds: Bounds for scientific notation.
191
+ annotation_style: Style for the annotations.
192
+ hm: QuadMesh object of the heatmap.
193
+
194
+ """
195
+ hm.update_scalarmappable() # So that get_facecolor is an array
196
+ xpos, ypos = np.meshgrid(
197
+ np.arange(len(df.columns)),
198
+ np.arange(len(df.index)),
199
+ )
200
+ for x, y, val, color in zip(
201
+ xpos.flat,
202
+ ypos.flat,
203
+ hm.get_array().flat, # type: ignore
204
+ hm.get_facecolor(),
205
+ strict=True,
206
+ ):
207
+ val_text = (
208
+ f"{val:.{annotation_style}}"
209
+ if sci_annotation_bounds[0] < abs(val) <= sci_annotation_bounds[1]
210
+ else f"{val:.0e}"
211
+ )
212
+ ax.text(
213
+ x + 0.5,
214
+ y + 0.5,
215
+ val_text,
216
+ ha="center",
217
+ va="center",
218
+ color="black" if _relative_luminance(color) > 0.45 else "white", # type: ignore # noqa: PLR2004
219
+ )
220
+
221
+
222
def add_grid(ax: Axes) -> Axes:
    """Turn on the grid of ``ax`` and draw it underneath the other artists.

    Args:
        ax: Axis to modify in place.

    Returns:
        The same axis, to allow call chaining.

    """
    ax.grid(visible=True)
    # Keep gridlines behind lines, bars, patches etc.
    ax.set_axisbelow(b=True)
    return ax
227
+
228
+
229
def rotate_xlabels(
    ax: Axes,
    rotation: float = 45,
    ha: Literal["left", "center", "right"] = "right",
) -> Axes:
    """Rotate the x-axis tick labels of ``ax`` in place.

    Only the tick labels present at call time are modified.

    Args:
        ax: Axis whose x tick labels are rotated.
        rotation: Rotation angle in degrees (default: 45).
        ha: Horizontal alignment of the labels (default: "right").

    Returns:
        Axes object for object chaining

    """
    tick_labels = ax.get_xticklabels()
    for tick_label in tick_labels:
        tick_label.set_rotation(rotation)
        tick_label.set_horizontalalignment(ha)
    return ax
249
+
250
+
251
+ ##########################################################################
252
+ # General plot layout
253
+ ##########################################################################
254
+
255
+
256
def _default_fig_ax(
    *,
    ax: Axes | None,
    grid: bool,
    figsize: tuple[float, float] | None = None,
) -> FigAx:
    """Return the figure/axis pair to draw on, creating both if needed.

    Args:
        ax: Existing axis to draw on. If ``None``, a new figure with a single
            axis is created.
        grid: Whether to add a grid to the axis (via ``add_grid``).
        figsize: Figure size; only used when a new figure is created.

    Returns:
        Figure and Axes objects for the plot.

    """
    if ax is not None:
        fig = cast(Figure, ax.get_figure())
    else:
        fig, ax = plt.subplots(nrows=1, ncols=1, figsize=figsize)

    if grid:
        add_grid(ax)
    return fig, ax
281
+
282
+
283
def _default_fig_axs(
    axs: list[Axes] | None,
    *,
    ncols: int,
    nrows: int,
    figsize: tuple[float, float] | None,
    grid: bool,
    sharex: bool,
    sharey: bool,
) -> FigAxs:
    """Return a figure and a flat list of axes, creating them if needed.

    If ``axs`` is ``None`` or empty, a new ``nrows`` x ``ncols`` figure with
    constrained layout is created, and its axes are flattened row by row into
    a list. Otherwise the given axes are reused together with the figure of
    the first axis. In that case the layout arguments are ignored.

    Args:
        axs: Axes to use for the plot.
        ncols: Number of columns for the plot.
        nrows: Number of rows for the plot.
        figsize: Size of the figure (default: None).
        grid: Whether to add a grid to every axis.
        sharex: Whether to share the x-axis between the axes.
        sharey: Whether to share the y-axis between the axes.

    Returns:
        Figure and Axes objects for the plot.

    """
    # Explicit length check (not truthiness) so numpy arrays of axes work too.
    if axs is not None and len(axs) > 0:
        fig = cast(Figure, axs[0].get_figure())
    else:
        fig, ax_grid = plt.subplots(
            nrows=nrows,
            ncols=ncols,
            sharex=sharex,
            sharey=sharey,
            figsize=figsize,
            squeeze=False,  # always get a 2D array, even for a 1x1 grid
            layout="constrained",
        )
        axs = list(ax_grid.flatten())

    if grid:
        for single_ax in axs:
            add_grid(single_ax)
    return fig, axs
326
+
327
+
328
def two_axes(
    *,
    figsize: tuple[float, float] | None = None,
    sharex: bool = True,
    sharey: bool = False,
    grid: bool = False,
) -> FigAxs:
    """Create a figure with two axes side by side (one row, two columns).

    Args:
        figsize: Size of the figure (default: None).
        sharex: Whether both axes share the x-axis.
        sharey: Whether both axes share the y-axis.
        grid: Whether to add a grid to both axes.

    Returns:
        The figure and a list with its two axes.

    """
    return _default_fig_axs(
        axs=None,
        nrows=1,
        ncols=2,
        sharex=sharex,
        sharey=sharey,
        figsize=figsize,
        grid=grid,
    )
345
+
346
+
347
+ def grid_layout(
348
+ n_groups: int,
349
+ *,
350
+ n_cols: int = 2,
351
+ col_width: float = 3,
352
+ row_height: float = 4,
353
+ sharex: bool = True,
354
+ sharey: bool = False,
355
+ grid: bool = True,
356
+ ) -> tuple[Figure, list[Axes]]:
357
+ """Create a grid layout for the given number of groups."""
358
+ n_cols = min(n_groups, n_cols)
359
+ n_rows = math.ceil(n_groups / n_cols)
360
+ figsize = (n_cols * col_width, n_rows * row_height)
361
+
362
+ return _default_fig_axs(
363
+ None,
364
+ ncols=n_cols,
365
+ nrows=n_rows,
366
+ figsize=figsize,
367
+ sharex=sharex,
368
+ sharey=sharey,
369
+ grid=grid,
370
+ )
371
+
372
+
373
+ ##########################################################################
374
+ # Plots
375
+ ##########################################################################
376
+
377
+
378
def bars(
    x: pd.DataFrame,
    *,
    ax: Axes | None = None,
    grid: bool = True,
) -> FigAx:
    """Plot the columns of a dataframe as bars on a single axis.

    The dataframe is passed to ``seaborn.barplot`` as wide-form data, and a
    legend is added with the column names as labels.

    NOTE(review): the x label is taken from ``x.index.name`` although the
    bars are laid out by column; confirm this is intended.

    Args:
        x: Data to plot.
        ax: Axis to draw on; a new figure is created if ``None``.
        grid: Whether to add a grid to the axis.

    Returns:
        Figure and Axes of the plot.

    """
    fig, ax = _default_fig_ax(ax=ax, grid=grid)
    sns.barplot(data=x, ax=ax)
    _default_labels(ax, xlabel=x.index.name, ylabel=None)
    column_names = x.columns
    ax.legend(column_names)
    return fig, ax
390
+
391
+
392
def lines(
    x: pd.DataFrame | pd.Series,
    *,
    ax: Axes | None = None,
    grid: bool = True,
) -> FigAx:
    """Plot multiple lines on the same axis.

    Args:
        x: Data plotted against its index. A DataFrame gives one line per
            column, labelled in the legend with the column names. A Series
            gives a single line, labelled with the series name if it has one.
        ax: Axis to draw on; a new figure is created if ``None``.
        grid: Whether to add a grid to the axis.

    Returns:
        Figure and Axes of the plot.

    """
    fig, ax = _default_fig_ax(ax=ax, grid=grid)
    ax.plot(x.index, x)
    _default_labels(ax, xlabel=x.index.name, ylabel=None)
    if isinstance(x, pd.DataFrame):
        ax.legend(x.columns)
    elif x.name is not None:
        # A Series has no ``.columns``; label its single line with its name.
        ax.legend([x.name])
    return fig, ax
404
+
405
+
406
def lines_grouped(
    groups: list[pd.DataFrame] | list[pd.Series],
    *,
    n_cols: int = 2,
    col_width: float = 3,
    row_height: float = 4,
    sharex: bool = True,
    sharey: bool = False,
    grid: bool = True,
) -> FigAxs:
    """Plot each group of lines on its own axis of a grid.

    Grid cells that are not needed for a group are hidden.

    Args:
        groups: One DataFrame or Series per panel, drawn with ``lines``.
        n_cols: Maximum number of panels per row.
        col_width: Width of each column of panels.
        row_height: Height of each row of panels.
        sharex: Whether the panels share the x-axis.
        sharey: Whether the panels share the y-axis.
        grid: Whether to add a grid to every panel.

    Returns:
        The figure and the list of all its axes.

    """
    fig, axs = grid_layout(
        len(groups),
        n_cols=n_cols,
        col_width=col_width,
        row_height=row_height,
        sharex=sharex,
        sharey=sharey,
        grid=grid,
    )

    n_groups = len(groups)
    for position, ax in enumerate(axs):
        if position < n_groups:
            lines(groups[position], ax=ax, grid=grid)
        else:
            ax.set_visible(False)

    return fig, axs
434
+
435
+
436
def line_autogrouped(
    s: pd.Series | pd.DataFrame,
    *,
    n_cols: int = 2,
    col_width: float = 4,
    row_height: float = 3,
    max_group_size: int = 6,
    grid: bool = True,
) -> FigAxs:
    """Plot a series or dataframe with lines grouped by order of magnitude.

    For a Series, its index labels are grouped by the magnitude of their
    values. For a DataFrame, its columns are grouped by the magnitude of
    each column's maximum. Groups larger than ``max_group_size`` are split,
    and each group is drawn on its own panel via ``lines_grouped``.

    Args:
        s: Data to plot.
        n_cols: Maximum number of panels per row.
        col_width: Width of each column of panels.
        row_height: Height of each row of panels.
        max_group_size: Maximum number of lines per panel.
        grid: Whether to add a grid to every panel.

    Returns:
        The figure and the list of all its axes.

    """
    is_series = isinstance(s, pd.Series)
    reference = s if is_series else s.max()
    group_names = _split_large_groups(
        _partition_by_order_of_magnitude(reference),
        max_size=max_group_size,
    )

    groups: list[pd.Series] | list[pd.DataFrame] = [
        s.loc[names] if is_series else s.loc[:, names] for names in group_names
    ]

    return lines_grouped(
        groups,
        n_cols=n_cols,
        col_width=col_width,
        row_height=row_height,
        grid=grid,
    )
466
+
467
+
468
def line_mean_std(
    df: pd.DataFrame,
    *,
    label: str | None = None,
    ax: Axes | None = None,
    color: str | None = None,
    alpha: float = 0.2,
    grid: bool = True,
) -> FigAx:
    """Plot the row-wise mean as a line and mean ± std as a shaded band.

    Mean and standard deviation are computed across the columns of ``df``
    (pandas ``DataFrame.std`` defaults, i.e. ``ddof=1``) and plotted against
    the index.

    Args:
        df: Data; each column is one sample of the same index.
        label: Legend label of the mean line.
        ax: Axis to draw on; a new figure is created if ``None``.
        color: Color of line and band; defaults to the next cycle color.
        alpha: Transparency of the band.
        grid: Whether to add a grid to the axis.

    Returns:
        Figure and Axes of the plot.

    """
    fig, ax = _default_fig_ax(ax=ax, grid=grid)
    # Determined after the axis exists, as it depends on the lines already drawn.
    color = _default_color(ax=ax, color=color)

    centre = df.mean(axis=1)
    spread = df.std(axis=1)
    lower = centre - spread
    upper = centre + spread

    ax.plot(centre.index, centre, color=color, label=label)
    ax.fill_between(df.index, lower, upper, color=color, alpha=alpha)
    _default_labels(ax, xlabel=df.index.name, ylabel=None)
    return fig, ax
498
+
499
+
500
def lines_mean_std_from_2d_idx(
    df: pd.DataFrame,
    *,
    names: list[str] | None = None,
    ax: Axes | None = None,
    alpha: float = 0.2,
    grid: bool = True,
) -> FigAx:
    """Plot mean ± standard deviation lines for columns of a 2D-indexed dataframe.

    For each selected column, ``df[name].unstack().T`` puts the second index
    level on the rows (the x-axis) and the first level on the columns. Mean
    and standard deviation are therefore taken across the first level.

    Args:
        df: Dataframe with a two-level index.
        names: Columns to plot (default: all columns).
        ax: Axis to draw on; a new figure is created if ``None``.
        alpha: Transparency of the std bands.
        grid: Whether to add a grid to the axis.

    Returns:
        Figure and Axes of the plot.

    Raises:
        ValueError: If the index does not have exactly two levels.

    """
    # ``nlevels`` exists on every Index, so a plain Index raises the intended
    # ValueError instead of an AttributeError from ``.levels``.
    if df.index.nlevels != 2:  # noqa: PLR2004
        msg = "MultiIndex must have exactly two levels"
        raise ValueError(msg)

    fig, ax = _default_fig_ax(ax=ax, grid=grid)

    for name in df.columns if names is None else names:
        line_mean_std(
            df[name].unstack().T,
            label=name,
            alpha=alpha,
            ax=ax,
            grid=grid,  # forward, otherwise line_mean_std re-adds the grid
        )
    ax.legend()
    return fig, ax
524
+
525
+
526
def heatmap(
    df: pd.DataFrame,
    *,
    annotate: bool = False,
    colorbar: bool = True,
    invert_yaxis: bool = True,
    cmap: str = "RdBu_r",
    norm: Normalize | None = None,
    ax: Axes | None = None,
    cax: Axes | None = None,
    sci_annotation_bounds: tuple[float, float] = (0.01, 100),
    annotation_style: str = "2g",
) -> tuple[Figure, Axes, QuadMesh]:
    """Plot a heatmap of ``df`` with one cell per entry.

    Args:
        df: Data to plot; rows go on the y-axis and columns on the x-axis.
        annotate: Whether to write the value into every cell.
        colorbar: Whether to add a colorbar.
        invert_yaxis: Whether to invert the y-axis so the first row is on top.
        cmap: Name of the colormap.
        norm: Color normalization; defaults to one centred on zero.
        ax: Axis to draw on; a new figure sized to the data is created if
            ``None``.
        cax: Axis for the colorbar.
        sci_annotation_bounds: Magnitude bounds outside of which annotations
            use scientific notation.
        annotation_style: Format spec for annotations inside the bounds.

    Returns:
        The figure, the axis, and the QuadMesh of the heatmap.

    """
    figsize = (
        max(4, 0.5 * len(df.columns)),
        max(4, 0.5 * len(df.index)),
    )
    fig, ax = _default_fig_ax(ax=ax, figsize=figsize, grid=False)
    norm = _norm_with_zero_center(df) if norm is None else norm

    hm = ax.pcolormesh(df, norm=norm, cmap=cmap)
    # Ticks at the cell centres.
    ax.set_xticks(np.arange(len(df.columns)) + 0.5, labels=df.columns)
    ax.set_yticks(np.arange(len(df.index)) + 0.5, labels=df.index)

    if annotate:
        _annotate_colormap(df, ax, sci_annotation_bounds, annotation_style, hm)

    if colorbar:
        cb = fig.colorbar(hm, cax=cax, ax=ax)
        cb.outline.set_linewidth(0)  # type: ignore

    if invert_yaxis:
        ax.invert_yaxis()
    rotate_xlabels(ax, rotation=45, ha="right")
    return fig, ax, hm
573
+
574
+
575
+ def heatmap_from_2d_idx(
576
+ df: pd.DataFrame,
577
+ variable: str,
578
+ ax: Axes | None = None,
579
+ ) -> FigAx:
580
+ """Plot a heatmap of a 2D indexed dataframe."""
581
+ if len(cast(pd.MultiIndex, df.index).levels) != 2: # noqa: PLR2004
582
+ msg = "MultiIndex must have exactly two levels"
583
+ raise ValueError(msg)
584
+
585
+ fig, ax = _default_fig_ax(ax=ax, grid=False)
586
+ df2d = df[variable].unstack()
587
+
588
+ ax.set_title(variable)
589
+ # Note: pcolormesh swaps index/columns
590
+ hm = ax.pcolormesh(df2d.T)
591
+ ax.set_xlabel(df2d.index.name)
592
+ ax.set_ylabel(df2d.columns.name)
593
+ ax.set_xticks(
594
+ np.arange(0, len(df2d.index), 1) + 0.5,
595
+ labels=[f"{i:.2f}" for i in df2d.index],
596
+ )
597
+ ax.set_yticks(
598
+ np.arange(0, len(df2d.columns), 1) + 0.5,
599
+ labels=[f"{i:.2f}" for i in df2d.columns],
600
+ )
601
+
602
+ rotate_xlabels(ax, rotation=45, ha="right")
603
+
604
+ # Add colorbar
605
+ fig.colorbar(hm, ax=ax)
606
+ return fig, ax
607
+
608
+
609
def heatmaps_from_2d_idx(
    df: pd.DataFrame,
    *,
    n_cols: int = 3,
    col_width_factor: float = 1,
    row_height_factor: float = 0.6,
    sharex: bool = True,
    sharey: bool = False,
) -> FigAxs:
    """Plot one heatmap per column of a 2D-indexed dataframe.

    Each panel is drawn by ``heatmap_from_2d_idx``. Panel width scales with
    the number of values of the first index level, and panel height with
    the number of values of the second level.

    Args:
        df: Dataframe with a two-level index.
        n_cols: Maximum number of panels per row.
        col_width_factor: Panel width per value of the first index level.
        row_height_factor: Panel height per value of the second index level.
        sharex: Whether the panels share the x-axis.
        sharey: Whether the panels share the y-axis.

    Returns:
        The figure and the list of all its axes. Panels beyond
        ``len(df.columns)`` are hidden.

    Raises:
        ValueError: If the index does not have exactly two levels.

    """
    if df.index.nlevels != 2:  # noqa: PLR2004
        msg = "MultiIndex must have exactly two levels"
        raise ValueError(msg)
    idx = cast(pd.MultiIndex, df.index)

    # grid_layout already caps the column count at the number of panels
    # (one per dataframe column); the row count ``len(df)`` is irrelevant.
    fig, axs = grid_layout(
        n_groups=len(df.columns),
        n_cols=n_cols,
        col_width=len(idx.levels[0]) * col_width_factor,
        row_height=len(idx.levels[1]) * row_height_factor,
        sharex=sharex,
        sharey=sharey,
        grid=False,
    )
    for ax, var in zip(axs, df.columns, strict=False):
        heatmap_from_2d_idx(df, var, ax=ax)

    # Hide the panels the grid created beyond the number of columns.
    for ax in axs[len(df.columns) :]:
        ax.set_visible(False)
    return fig, axs
633
+
634
+
635
def violins(
    df: pd.DataFrame,
    *,
    ax: Axes | None = None,
    grid: bool = True,
) -> FigAx:
    """Draw the columns of ``df`` as violins on a single axis.

    The dataframe is passed unchanged to ``seaborn.violinplot``.

    Args:
        df: Data to plot.
        ax: Axis to draw on; a new figure is created if ``None``.
        grid: Whether to add a grid to the axis.

    Returns:
        Figure and Axes of the plot.

    """
    fig, ax = _default_fig_ax(ax=ax, grid=grid)
    sns.violinplot(df, ax=ax)
    # The x label is deliberately blank (an empty string, not the placeholder).
    _default_labels(ax=ax, xlabel="", ylabel=None)
    return fig, ax
646
+
647
+
648
def violins_from_2d_idx(
    df: pd.DataFrame,
    *,
    n_cols: int = 4,
    row_height: int = 2,
    sharex: bool = True,
    sharey: bool = False,
    grid: bool = True,
) -> FigAxs:
    """Plot one panel of violins per column of a 2D-indexed dataframe.

    For each column, ``df[col].unstack()`` moves the second index level into
    the columns. Each of those columns is then drawn as a violin by
    ``violins``. Unused panels have their spines and y ticks removed.

    Args:
        df: Dataframe with a two-level index.
        n_cols: Maximum number of panels per row.
        row_height: Height of each row of panels.
        sharex: Whether the panels share the x-axis.
        sharey: Whether the panels share the y-axis.
        grid: Whether to add a grid to the panels.

    Returns:
        The figure and the list of all its axes.

    Raises:
        ValueError: If the index does not have exactly two levels.

    """
    if df.index.nlevels != 2:  # noqa: PLR2004
        msg = "MultiIndex must have exactly two levels"
        raise ValueError(msg)

    fig, axs = grid_layout(
        len(df.columns),
        n_cols=n_cols,
        row_height=row_height,
        sharex=sharex,
        sharey=sharey,
        grid=grid,
    )

    n_vars = len(df.columns)
    for ax, col in zip(axs[:n_vars], df.columns, strict=True):
        ax.set_title(col)
        # Forward ``grid``; violins defaults to grid=True and would override it.
        violins(df[col].unstack(), ax=ax, grid=grid)

    for ax in axs[n_vars:]:
        for axis in ["top", "bottom", "left", "right"]:
            ax.spines[axis].set_linewidth(0)
        ax.yaxis.set_ticks([])

    for ax in axs:
        rotate_xlabels(ax)
    return fig, axs
683
+
684
+
685
def shade_protocol(
    protocol: pd.Series,
    *,
    ax: Axes,
    cmap_name: str = "Greys_r",
    vmin: float | None = None,
    vmax: float | None = None,
    alpha: float = 0.5,
    add_legend: bool = True,
) -> None:
    """Shade the steps of a protocol as vertical bands on the given axis.

    ``protocol`` is read as a step function. Its index holds the end time of
    each step as a ``pd.Timedelta``, and its values hold the value during
    that step. The first step starts at 0 s. Each step is drawn with
    ``axvspan`` from its start to its end (in seconds), colored by its
    value.

    Args:
        protocol: Step values indexed by step end time.
        ax: Axis to draw on.
        cmap_name: Name of the matplotlib colormap used for the values.
        vmin: Lower end of the color scale (default: ``protocol.min()``).
        vmax: Upper end of the color scale (default: ``protocol.max()``).
        alpha: Transparency of the bands and of the legend patches.
        add_legend: Whether to add a legend with one entry per distinct
            protocol value.

    """
    from matplotlib import colormaps
    from matplotlib.legend import Legend
    from matplotlib.patches import Patch

    cmap = colormaps[cmap_name]
    norm = Normalize(
        vmin=protocol.min() if vmin is None else vmin,
        vmax=protocol.max() if vmax is None else vmax,
    )

    t0 = pd.Timedelta(seconds=0)
    for t_end, val in protocol.items():
        t_end = cast(pd.Timedelta, t_end)
        ax.axvspan(
            t0.total_seconds(),
            t_end.total_seconds(),
            facecolor=cmap(norm(val)),
            edgecolor=None,
            alpha=alpha,
        )
        t0 = t_end  # type: ignore

    if add_legend:
        # Steps with the same value share one color, so list each distinct
        # value once, in order of first appearance.
        legend_values = list(dict.fromkeys(protocol))
        # Added via add_artist so it does not replace a legend from ax.legend().
        ax.add_artist(
            Legend(
                ax,
                handles=[
                    Patch(
                        facecolor=cmap(norm(val)),
                        alpha=alpha,
                        label=val,
                    )  # type: ignore
                    for val in legend_values
                ],
                labels=legend_values,
                loc="lower right",
                bbox_to_anchor=(1.0, 0.0),
                title="protocol" if protocol.name is None else cast(str, protocol.name),
            )
        )
737
+
738
+
739
+ ##########################################################################
740
+ # Plots that actually require a model :/
741
+ ##########################################################################
742
+
743
+
744
def trajectories_2d(
    model: Model,
    x1: tuple[str, ArrayLike],
    x2: tuple[str, ArrayLike],
    y0: dict[str, float] | None = None,
    ax: Axes | None = None,
) -> FigAx:
    """Plot the vector field of two variables in a 2D phase space.

    For every combination of the given values of the two variables, the
    model's right-hand side is evaluated at ``y0``, with those two
    variables overridden. The resulting derivatives are drawn as arrows
    with ``Axes.quiver``. The axes are labelled with the variable names.

    Examples:
        >>> trajectories_2d(
        ...     model,
        ...     ("S", np.linspace(0, 1, 10)),
        ...     ("P", np.linspace(0, 1, 10)),
        ... )

    Args:
        model: Model to use for the plot.
        x1: Tuple of the first variable name (x-axis) and its values.
        x2: Tuple of the second variable name (y-axis) and its values.
        y0: Initial conditions for the model (default: the model's own).
        ax: Axes to use for the plot.

    Returns:
        Figure and Axes of the plot.

    """
    name1, values1 = x1
    name2, values2 = x2
    u = np.zeros((len(values1), len(values2)))
    v = np.zeros((len(values1), len(values2)))
    y0 = model.get_initial_conditions() if y0 is None else y0
    for i, val1 in enumerate(values1):
        for j, val2 in enumerate(values2):
            rhs = model.get_right_hand_side(y0 | {name1: val1, name2: val2})
            u[i, j] = rhs[name1]
            v[i, j] = rhs[name2]

    fig, ax = _default_fig_ax(ax=ax, grid=False)
    # quiver expects U/V shaped (len(Y), len(X)), hence the transposes.
    ax.quiver(values1, values2, u.T, v.T)
    _default_labels(ax, xlabel=name1, ylabel=name2)
    return fig, ax
784
+
785
+
786
+ ##########################################################################
787
+ # Label Plots
788
+ ##########################################################################
789
+
790
+
791
def relative_label_distribution(
    mapper: LabelMapper | LinearLabelMapper,
    concs: pd.DataFrame,
    *,
    subset: list[str] | None = None,
    n_cols: int = 2,
    col_width: float = 3,
    row_height: float = 3,
    sharey: bool = False,
    grid: bool = True,
) -> FigAxs:
    """Plot the relative distribution of labels, one panel per variable.

    For a ``LabelMapper``, each panel shows, for every position ``i``, the
    summed concentration of the isotopomers labelled at that position
    divided by the ``"<name>__total"`` column. For a ``LinearLabelMapper``,
    each panel shows the isotopomer columns returned by
    ``mapper.get_isotopomers``. In both cases position ``i`` appears in the
    legend as ``C{i + 1}``. Grid panels that are not needed are hidden.

    Args:
        mapper: Label mapper describing the labelled variables.
        concs: Concentrations, with time (or another x quantity) as index.
        subset: Variables to plot (default: all of ``mapper.label_variables``).
        n_cols: Maximum number of panels per row.
        col_width: Width of each column of panels.
        row_height: Height of each row of panels.
        sharey: Whether the panels share the y-axis.
        grid: Whether to add a grid to the panels.

    Returns:
        The figure and the list of all its axes.

    """
    variables = list(mapper.label_variables) if subset is None else subset
    fig, axs = grid_layout(
        n_groups=len(variables),
        n_cols=n_cols,
        col_width=col_width,
        row_height=row_height,
        sharey=sharey,
        grid=grid,
    )
    if isinstance(mapper, LabelMapper):
        for ax, name in zip(axs, variables, strict=False):
            for i in range(mapper.label_variables[name]):
                isos = mapper.get_isotopomers_of_at_position(name, i)
                labels = cast(pd.DataFrame, concs.loc[:, isos])
                total = concs.loc[:, f"{name}__total"]
                ax.plot(labels.index, (labels.sum(axis=1) / total), label=f"C{i+1}")
            ax.set_title(name)
            ax.legend()
    else:
        for ax, (name, isos) in zip(
            axs, mapper.get_isotopomers(variables).items(), strict=False
        ):
            ax.plot(concs.index, concs.loc[:, isos])
            ax.set_title(name)
            ax.legend([f"C{i+1}" for i in range(len(isos))])

    # Hide panels the grid created beyond the number of variables, as
    # lines_grouped does.
    for ax in axs[len(variables) :]:
        ax.set_visible(False)

    return fig, axs