modelbase2 0.1.79__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- modelbase2/__init__.py +138 -26
 - modelbase2/distributions.py +306 -0
 - modelbase2/experimental/__init__.py +17 -0
 - modelbase2/experimental/codegen.py +239 -0
 - modelbase2/experimental/diff.py +227 -0
 - modelbase2/experimental/notes.md +4 -0
 - modelbase2/experimental/tex.py +521 -0
 - modelbase2/fit.py +284 -0
 - modelbase2/fns.py +185 -0
 - modelbase2/integrators/__init__.py +19 -0
 - modelbase2/integrators/int_assimulo.py +146 -0
 - modelbase2/integrators/int_scipy.py +147 -0
 - modelbase2/label_map.py +610 -0
 - modelbase2/linear_label_map.py +301 -0
 - modelbase2/mc.py +548 -0
 - modelbase2/mca.py +280 -0
 - modelbase2/model.py +1621 -0
 - modelbase2/npe.py +343 -0
 - modelbase2/parallel.py +171 -0
 - modelbase2/parameterise.py +28 -0
 - modelbase2/paths.py +36 -0
 - modelbase2/plot.py +829 -0
 - modelbase2/sbml/__init__.py +14 -0
 - modelbase2/sbml/_data.py +77 -0
 - modelbase2/sbml/_export.py +656 -0
 - modelbase2/sbml/_import.py +585 -0
 - modelbase2/sbml/_mathml.py +691 -0
 - modelbase2/sbml/_name_conversion.py +52 -0
 - modelbase2/sbml/_unit_conversion.py +74 -0
 - modelbase2/scan.py +616 -0
 - modelbase2/scope.py +96 -0
 - modelbase2/simulator.py +635 -0
 - modelbase2/surrogates/__init__.py +32 -0
 - modelbase2/surrogates/_poly.py +66 -0
 - modelbase2/surrogates/_torch.py +249 -0
 - modelbase2/surrogates.py +316 -0
 - modelbase2/types.py +352 -11
 - modelbase2-0.2.0.dist-info/METADATA +81 -0
 - modelbase2-0.2.0.dist-info/RECORD +42 -0
 - {modelbase2-0.1.79.dist-info → modelbase2-0.2.0.dist-info}/WHEEL +1 -1
 - modelbase2/core/__init__.py +0 -29
 - modelbase2/core/algebraic_module_container.py +0 -130
 - modelbase2/core/constant_container.py +0 -113
 - modelbase2/core/data.py +0 -109
 - modelbase2/core/name_container.py +0 -29
 - modelbase2/core/reaction_container.py +0 -115
 - modelbase2/core/utils.py +0 -28
 - modelbase2/core/variable_container.py +0 -24
 - modelbase2/ode/__init__.py +0 -13
 - modelbase2/ode/integrator.py +0 -80
 - modelbase2/ode/mca.py +0 -270
 - modelbase2/ode/model.py +0 -470
 - modelbase2/ode/simulator.py +0 -153
 - modelbase2/utils/__init__.py +0 -0
 - modelbase2/utils/plotting.py +0 -372
 - modelbase2-0.1.79.dist-info/METADATA +0 -44
 - modelbase2-0.1.79.dist-info/RECORD +0 -22
 - {modelbase2-0.1.79.dist-info → modelbase2-0.2.0.dist-info/licenses}/LICENSE +0 -0
 
    
        modelbase2/plot.py
    ADDED
    
    | 
         @@ -0,0 +1,829 @@ 
     | 
|
| 
      
 1 
     | 
    
         
            +
"""Plotting Utilities Module.

This module provides functions for creating various plots and visualizations
for metabolic models. It includes functionality for plotting heatmaps, time
courses, and parameter scans.

The public API is defined by ``__all__``:
    FigAx, FigAxs (type aliases),
    add_grid, bars, grid_layout, heatmap, heatmap_from_2d_idx,
    heatmaps_from_2d_idx, line_autogrouped, line_mean_std, lines,
    lines_grouped, lines_mean_std_from_2d_idx, relative_label_distribution,
    rotate_xlabels, shade_protocol, trajectories_2d, two_axes, violins,
    violins_from_2d_idx.
"""
         
     | 
| 
      
 18 
     | 
    
         
            +
             
     | 
| 
      
 19 
     | 
    
         
            +
            from __future__ import annotations
         
     | 
| 
      
 20 
     | 
    
         
            +
             
     | 
| 
      
 21 
     | 
    
         
            +
# Explicit public API of this module: the names exported by
# ``from modelbase2.plot import *``. Underscore-prefixed helpers are excluded.
__all__ = [
    "FigAx",
    "FigAxs",
    "add_grid",
    "bars",
    "grid_layout",
    "heatmap",
    "heatmap_from_2d_idx",
    "heatmaps_from_2d_idx",
    "line_autogrouped",
    "line_mean_std",
    "lines",
    "lines_grouped",
    "lines_mean_std_from_2d_idx",
    "relative_label_distribution",
    "rotate_xlabels",
    "shade_protocol",
    "trajectories_2d",
    "two_axes",
    "violins",
    "violins_from_2d_idx",
]
         
     | 
| 
      
 43 
     | 
    
         
            +
             
     | 
| 
      
 44 
     | 
    
         
            +
            import itertools as it
         
     | 
| 
      
 45 
     | 
    
         
            +
            import math
         
     | 
| 
      
 46 
     | 
    
         
            +
            from typing import TYPE_CHECKING, Literal, cast
         
     | 
| 
      
 47 
     | 
    
         
            +
             
     | 
| 
      
 48 
     | 
    
         
            +
            import numpy as np
         
     | 
| 
      
 49 
     | 
    
         
            +
            import pandas as pd
         
     | 
| 
      
 50 
     | 
    
         
            +
            import seaborn as sns
         
     | 
| 
      
 51 
     | 
    
         
            +
            from matplotlib import pyplot as plt
         
     | 
| 
      
 52 
     | 
    
         
            +
            from matplotlib.axes import Axes
         
     | 
| 
      
 53 
     | 
    
         
            +
            from matplotlib.colors import (
         
     | 
| 
      
 54 
     | 
    
         
            +
                LogNorm,
         
     | 
| 
      
 55 
     | 
    
         
            +
                Normalize,
         
     | 
| 
      
 56 
     | 
    
         
            +
                SymLogNorm,
         
     | 
| 
      
 57 
     | 
    
         
            +
                colorConverter,  # type: ignore
         
     | 
| 
      
 58 
     | 
    
         
            +
            )
         
     | 
| 
      
 59 
     | 
    
         
            +
            from matplotlib.figure import Figure
         
     | 
| 
      
 60 
     | 
    
         
            +
            from mpl_toolkits.mplot3d import Axes3D
         
     | 
| 
      
 61 
     | 
    
         
            +
             
     | 
| 
      
 62 
     | 
    
         
            +
            from modelbase2.label_map import LabelMapper
         
     | 
| 
      
 63 
     | 
    
         
            +
             
     | 
| 
      
 64 
     | 
    
         
            +
            if TYPE_CHECKING:
         
     | 
| 
      
 65 
     | 
    
         
            +
                from matplotlib.collections import QuadMesh
         
     | 
| 
      
 66 
     | 
    
         
            +
             
     | 
| 
      
 67 
     | 
    
         
            +
                from modelbase2.linear_label_map import LinearLabelMapper
         
     | 
| 
      
 68 
     | 
    
         
            +
                from modelbase2.model import Model
         
     | 
| 
      
 69 
     | 
    
         
            +
                from modelbase2.types import Array, ArrayLike
         
     | 
| 
      
 70 
     | 
    
         
            +
             
     | 
| 
      
 71 
     | 
    
         
            +
# A (figure, axes) pair, e.g. the return value of ``_default_fig_ax``.
type FigAx = tuple[Figure, Axes]
# A figure together with a list of its axes, for multi-panel plots.
type FigAxs = tuple[Figure, list[Axes]]
         
     | 
| 
      
 73 
     | 
    
         
            +
             
     | 
| 
      
 74 
     | 
    
         
            +
             
     | 
| 
      
 75 
     | 
    
         
            +
            ##########################################################################
         
     | 
| 
      
 76 
     | 
    
         
            +
            # Helpers
         
     | 
| 
      
 77 
     | 
    
         
            +
            ##########################################################################
         
     | 
| 
      
 78 
     | 
    
         
            +
             
     | 
| 
      
 79 
     | 
    
         
            +
             
     | 
| 
      
 80 
     | 
    
         
            +
            def _relative_luminance(color: Array) -> float:
         
     | 
| 
      
 81 
     | 
    
         
            +
                """Calculate the relative luminance of a color."""
         
     | 
| 
      
 82 
     | 
    
         
            +
                rgb = colorConverter.to_rgba_array(color)[:, :3]
         
     | 
| 
      
 83 
     | 
    
         
            +
             
     | 
| 
      
 84 
     | 
    
         
            +
                # If RsRGB <= 0.03928 then R = RsRGB/12.92 else R = ((RsRGB+0.055)/1.055) ^ 2.4
         
     | 
| 
      
 85 
     | 
    
         
            +
                rsrgb = np.where(
         
     | 
| 
      
 86 
     | 
    
         
            +
                    rgb <= 0.03928,  # noqa: PLR2004
         
     | 
| 
      
 87 
     | 
    
         
            +
                    rgb / 12.92,
         
     | 
| 
      
 88 
     | 
    
         
            +
                    ((rgb + 0.055) / 1.055) ** 2.4,
         
     | 
| 
      
 89 
     | 
    
         
            +
                )
         
     | 
| 
      
 90 
     | 
    
         
            +
             
     | 
| 
      
 91 
     | 
    
         
            +
                # L = 0.2126 * R + 0.7152 * G + 0.0722 * B
         
     | 
| 
      
 92 
     | 
    
         
            +
                return np.matmul(rsrgb, [0.2126, 0.7152, 0.0722])[0]
         
     | 
| 
      
 93 
     | 
    
         
            +
             
     | 
| 
      
 94 
     | 
    
         
            +
             
     | 
| 
      
 95 
     | 
    
         
            +
            def _get_norm(vmin: float, vmax: float) -> Normalize:
         
     | 
| 
      
 96 
     | 
    
         
            +
                """Get a suitable normalization object for the given data.
         
     | 
| 
      
 97 
     | 
    
         
            +
             
     | 
| 
      
 98 
     | 
    
         
            +
                Uses a logarithmic scale for values greater than 1000 or less than -1000,
         
     | 
| 
      
 99 
     | 
    
         
            +
                a symmetrical logarithmic scale for values less than or equal to 0,
         
     | 
| 
      
 100 
     | 
    
         
            +
                and a linear scale for all other values.
         
     | 
| 
      
 101 
     | 
    
         
            +
             
     | 
| 
      
 102 
     | 
    
         
            +
                Args:
         
     | 
| 
      
 103 
     | 
    
         
            +
                    vmin: Minimum value of the data.
         
     | 
| 
      
 104 
     | 
    
         
            +
                    vmax: Maximum value of the data.
         
     | 
| 
      
 105 
     | 
    
         
            +
             
     | 
| 
      
 106 
     | 
    
         
            +
                Returns:
         
     | 
| 
      
 107 
     | 
    
         
            +
                    Normalize: A normalization object for the given data.
         
     | 
| 
      
 108 
     | 
    
         
            +
             
     | 
| 
      
 109 
     | 
    
         
            +
                """
         
     | 
| 
      
 110 
     | 
    
         
            +
                if vmax < 1000 and vmin > -1000:  # noqa: PLR2004
         
     | 
| 
      
 111 
     | 
    
         
            +
                    norm = Normalize(vmin=vmin, vmax=vmax)
         
     | 
| 
      
 112 
     | 
    
         
            +
                elif vmin <= 0:
         
     | 
| 
      
 113 
     | 
    
         
            +
                    norm = SymLogNorm(linthresh=1, vmin=vmin, vmax=vmax, base=10)
         
     | 
| 
      
 114 
     | 
    
         
            +
                else:
         
     | 
| 
      
 115 
     | 
    
         
            +
                    norm = LogNorm(vmin=vmin, vmax=vmax)
         
     | 
| 
      
 116 
     | 
    
         
            +
                return norm
         
     | 
| 
      
 117 
     | 
    
         
            +
             
     | 
| 
      
 118 
     | 
    
         
            +
             
     | 
| 
      
 119 
     | 
    
         
            +
            def _norm_with_zero_center(df: pd.DataFrame) -> Normalize:
         
     | 
| 
      
 120 
     | 
    
         
            +
                """Get a normalization object with zero-centered values for the given data."""
         
     | 
| 
      
 121 
     | 
    
         
            +
                v = max(abs(df.min().min()), abs(df.max().max()))
         
     | 
| 
      
 122 
     | 
    
         
            +
                return _get_norm(vmin=-v, vmax=v)
         
     | 
| 
      
 123 
     | 
    
         
            +
             
     | 
| 
      
 124 
     | 
    
         
            +
             
     | 
| 
      
 125 
     | 
    
         
            +
            def _partition_by_order_of_magnitude(s: pd.Series) -> list[list[str]]:
         
     | 
| 
      
 126 
     | 
    
         
            +
                """Partition a series into groups based on the order of magnitude of the values."""
         
     | 
| 
      
 127 
     | 
    
         
            +
                return [
         
     | 
| 
      
 128 
     | 
    
         
            +
                    i.to_list()
         
     | 
| 
      
 129 
     | 
    
         
            +
                    for i in np.floor(np.log10(s)).to_frame(name=0).groupby(0)[0].groups.values()  # type: ignore
         
     | 
| 
      
 130 
     | 
    
         
            +
                ]
         
     | 
| 
      
 131 
     | 
    
         
            +
             
     | 
| 
      
 132 
     | 
    
         
            +
             
     | 
| 
      
 133 
     | 
    
         
            +
            def _split_large_groups[T](groups: list[list[T]], max_size: int) -> list[list[T]]:
         
     | 
| 
      
 134 
     | 
    
         
            +
                """Split groups larger than the given size into smaller groups."""
         
     | 
| 
      
 135 
     | 
    
         
            +
                return list(
         
     | 
| 
      
 136 
     | 
    
         
            +
                    it.chain(
         
     | 
| 
      
 137 
     | 
    
         
            +
                        *(
         
     | 
| 
      
 138 
     | 
    
         
            +
                            (
         
     | 
| 
      
 139 
     | 
    
         
            +
                                [group]
         
     | 
| 
      
 140 
     | 
    
         
            +
                                if len(group) < max_size
         
     | 
| 
      
 141 
     | 
    
         
            +
                                else [  # type: ignore
         
     | 
| 
      
 142 
     | 
    
         
            +
                                    list(i)
         
     | 
| 
      
 143 
     | 
    
         
            +
                                    for i in np.array_split(group, math.ceil(len(group) / max_size))  # type: ignore
         
     | 
| 
      
 144 
     | 
    
         
            +
                                ]
         
     | 
| 
      
 145 
     | 
    
         
            +
                            )
         
     | 
| 
      
 146 
     | 
    
         
            +
                            for group in groups
         
     | 
| 
      
 147 
     | 
    
         
            +
                        )
         
     | 
| 
      
 148 
     | 
    
         
            +
                    )
         
     | 
| 
      
 149 
     | 
    
         
            +
                )  # type: ignore
         
     | 
| 
      
 150 
     | 
    
         
            +
             
     | 
| 
      
 151 
     | 
    
         
            +
             
     | 
| 
      
 152 
     | 
    
         
            +
            def _default_color(ax: Axes, color: str | None) -> str:
         
     | 
| 
      
 153 
     | 
    
         
            +
                """Get a default color for the given axis."""
         
     | 
| 
      
 154 
     | 
    
         
            +
                return f"C{len(ax.lines)}" if color is None else color
         
     | 
| 
      
 155 
     | 
    
         
            +
             
     | 
| 
      
 156 
     | 
    
         
            +
             
     | 
| 
      
 157 
     | 
    
         
            +
            def _default_labels(
         
     | 
| 
      
 158 
     | 
    
         
            +
                ax: Axes,
         
     | 
| 
      
 159 
     | 
    
         
            +
                xlabel: str | None = None,
         
     | 
| 
      
 160 
     | 
    
         
            +
                ylabel: str | None = None,
         
     | 
| 
      
 161 
     | 
    
         
            +
                zlabel: str | None = None,
         
     | 
| 
      
 162 
     | 
    
         
            +
            ) -> None:
         
     | 
| 
      
 163 
     | 
    
         
            +
                """Set default labels for the given axis.
         
     | 
| 
      
 164 
     | 
    
         
            +
             
     | 
| 
      
 165 
     | 
    
         
            +
                Args:
         
     | 
| 
      
 166 
     | 
    
         
            +
                    ax: matplotlib Axes
         
     | 
| 
      
 167 
     | 
    
         
            +
                    xlabel: Label for the x-axis.
         
     | 
| 
      
 168 
     | 
    
         
            +
                    ylabel: Label for the y-axis.
         
     | 
| 
      
 169 
     | 
    
         
            +
                    zlabel: Label for the z-axis.
         
     | 
| 
      
 170 
     | 
    
         
            +
             
     | 
| 
      
 171 
     | 
    
         
            +
                """
         
     | 
| 
      
 172 
     | 
    
         
            +
                ax.set_xlabel("Add a label / unit" if xlabel is None else xlabel)
         
     | 
| 
      
 173 
     | 
    
         
            +
                ax.set_ylabel("Add a label / unit" if ylabel is None else ylabel)
         
     | 
| 
      
 174 
     | 
    
         
            +
                if isinstance(ax, Axes3D):
         
     | 
| 
      
 175 
     | 
    
         
            +
                    ax.set_zlabel("Add a label / unit" if zlabel is None else zlabel)
         
     | 
| 
      
 176 
     | 
    
         
            +
             
     | 
| 
      
 177 
     | 
    
         
            +
             
     | 
| 
      
 178 
     | 
    
         
            +
            def _annotate_colormap(
         
     | 
| 
      
 179 
     | 
    
         
            +
                df: pd.DataFrame,
         
     | 
| 
      
 180 
     | 
    
         
            +
                ax: Axes,
         
     | 
| 
      
 181 
     | 
    
         
            +
                sci_annotation_bounds: tuple[float, float],
         
     | 
| 
      
 182 
     | 
    
         
            +
                annotation_style: str,
         
     | 
| 
      
 183 
     | 
    
         
            +
                hm: QuadMesh,
         
     | 
| 
      
 184 
     | 
    
         
            +
            ) -> None:
         
     | 
| 
      
 185 
     | 
    
         
            +
                """Annotate a heatmap with the values of the data.
         
     | 
| 
      
 186 
     | 
    
         
            +
             
     | 
| 
      
 187 
     | 
    
         
            +
                Args:
         
     | 
| 
      
 188 
     | 
    
         
            +
                    df: Dataframe to annotate.
         
     | 
| 
      
 189 
     | 
    
         
            +
                    ax: Axes to annotate.
         
     | 
| 
      
 190 
     | 
    
         
            +
                    sci_annotation_bounds: Bounds for scientific notation.
         
     | 
| 
      
 191 
     | 
    
         
            +
                    annotation_style: Style for the annotations.
         
     | 
| 
      
 192 
     | 
    
         
            +
                    hm: QuadMesh object of the heatmap.
         
     | 
| 
      
 193 
     | 
    
         
            +
             
     | 
| 
      
 194 
     | 
    
         
            +
                """
         
     | 
| 
      
 195 
     | 
    
         
            +
                hm.update_scalarmappable()  # So that get_facecolor is an array
         
     | 
| 
      
 196 
     | 
    
         
            +
                xpos, ypos = np.meshgrid(
         
     | 
| 
      
 197 
     | 
    
         
            +
                    np.arange(len(df.columns)),
         
     | 
| 
      
 198 
     | 
    
         
            +
                    np.arange(len(df.index)),
         
     | 
| 
      
 199 
     | 
    
         
            +
                )
         
     | 
| 
      
 200 
     | 
    
         
            +
                for x, y, val, color in zip(
         
     | 
| 
      
 201 
     | 
    
         
            +
                    xpos.flat,
         
     | 
| 
      
 202 
     | 
    
         
            +
                    ypos.flat,
         
     | 
| 
      
 203 
     | 
    
         
            +
                    hm.get_array().flat,  # type: ignore
         
     | 
| 
      
 204 
     | 
    
         
            +
                    hm.get_facecolor(),
         
     | 
| 
      
 205 
     | 
    
         
            +
                    strict=True,
         
     | 
| 
      
 206 
     | 
    
         
            +
                ):
         
     | 
| 
      
 207 
     | 
    
         
            +
                    val_text = (
         
     | 
| 
      
 208 
     | 
    
         
            +
                        f"{val:.{annotation_style}}"
         
     | 
| 
      
 209 
     | 
    
         
            +
                        if sci_annotation_bounds[0] < abs(val) <= sci_annotation_bounds[1]
         
     | 
| 
      
 210 
     | 
    
         
            +
                        else f"{val:.0e}"
         
     | 
| 
      
 211 
     | 
    
         
            +
                    )
         
     | 
| 
      
 212 
     | 
    
         
            +
                    ax.text(
         
     | 
| 
      
 213 
     | 
    
         
            +
                        x + 0.5,
         
     | 
| 
      
 214 
     | 
    
         
            +
                        y + 0.5,
         
     | 
| 
      
 215 
     | 
    
         
            +
                        val_text,
         
     | 
| 
      
 216 
     | 
    
         
            +
                        ha="center",
         
     | 
| 
      
 217 
     | 
    
         
            +
                        va="center",
         
     | 
| 
      
 218 
     | 
    
         
            +
                        color="black" if _relative_luminance(color) > 0.45 else "white",  # type: ignore  # noqa: PLR2004
         
     | 
| 
      
 219 
     | 
    
         
            +
                    )
         
     | 
| 
      
 220 
     | 
    
         
            +
             
     | 
| 
      
 221 
     | 
    
         
            +
             
     | 
| 
      
 222 
     | 
    
         
            +
            def add_grid(ax: Axes) -> Axes:
         
     | 
| 
      
 223 
     | 
    
         
            +
                """Add a grid to the given axis."""
         
     | 
| 
      
 224 
     | 
    
         
            +
                ax.grid(visible=True)
         
     | 
| 
      
 225 
     | 
    
         
            +
                ax.set_axisbelow(b=True)
         
     | 
| 
      
 226 
     | 
    
         
            +
                return ax
         
     | 
| 
      
 227 
     | 
    
         
            +
             
     | 
| 
      
 228 
     | 
    
         
            +
             
     | 
| 
      
 229 
     | 
    
         
            +
            def rotate_xlabels(
         
     | 
| 
      
 230 
     | 
    
         
            +
                ax: Axes,
         
     | 
| 
      
 231 
     | 
    
         
            +
                rotation: float = 45,
         
     | 
| 
      
 232 
     | 
    
         
            +
                ha: Literal["left", "center", "right"] = "right",
         
     | 
| 
      
 233 
     | 
    
         
            +
            ) -> Axes:
         
     | 
| 
      
 234 
     | 
    
         
            +
                """Rotate the x-axis labels of the given axis.
         
     | 
| 
      
 235 
     | 
    
         
            +
             
     | 
| 
      
 236 
     | 
    
         
            +
                Args:
         
     | 
| 
      
 237 
     | 
    
         
            +
                    ax: Axis to rotate the labels of.
         
     | 
| 
      
 238 
     | 
    
         
            +
                    rotation: Rotation angle in degrees (default: 45).
         
     | 
| 
      
 239 
     | 
    
         
            +
                    ha: Horizontal alignment of the labels (default
         
     | 
| 
      
 240 
     | 
    
         
            +
             
     | 
| 
      
 241 
     | 
    
         
            +
                Returns:
         
     | 
| 
      
 242 
     | 
    
         
            +
                    Axes object for object chaining
         
     | 
| 
      
 243 
     | 
    
         
            +
             
     | 
| 
      
 244 
     | 
    
         
            +
                """
         
     | 
| 
      
 245 
     | 
    
         
            +
                for label in ax.get_xticklabels():
         
     | 
| 
      
 246 
     | 
    
         
            +
                    label.set_rotation(rotation)
         
     | 
| 
      
 247 
     | 
    
         
            +
                    label.set_horizontalalignment(ha)
         
     | 
| 
      
 248 
     | 
    
         
            +
                return ax
         
     | 
| 
      
 249 
     | 
    
         
            +
             
     | 
| 
      
 250 
     | 
    
         
            +
             
     | 
| 
      
 251 
     | 
    
         
            +
            ##########################################################################
         
     | 
| 
      
 252 
     | 
    
         
            +
            # General plot layout
         
     | 
| 
      
 253 
     | 
    
         
            +
            ##########################################################################
         
     | 
| 
      
 254 
     | 
    
         
            +
             
     | 
| 
      
 255 
     | 
    
         
            +
             
     | 
| 
      
 256 
     | 
    
         
            +
            def _default_fig_ax(
                *,
                ax: Axes | None,
                grid: bool,
                figsize: tuple[float, float] | None = None,
            ) -> FigAx:
                """Return a figure and axes, creating a new pair only if needed.

                Args:
                    ax: Axes to draw on. When None, a new figure with a single
                        axes is created.
                    grid: Whether to add a grid to the axes.
                    figsize: Size of the newly created figure; ignored when `ax`
                        is given (default: None).

                Returns:
                    Figure and Axes objects for the plot.

                """
                if ax is not None:
                    # Reuse the caller's axes and look up the figure that owns them.
                    fig = cast(Figure, ax.get_figure())
                else:
                    fig, ax = plt.subplots(nrows=1, ncols=1, figsize=figsize)

                if grid:
                    add_grid(ax)
                return fig, ax
         
     | 
| 
      
 281 
     | 
    
         
            +
             
     | 
| 
      
 282 
     | 
    
         
            +
             
     | 
| 
      
 283 
     | 
    
         
            +
            def _default_fig_axs(
                axs: list[Axes] | None,
                *,
                ncols: int,
                nrows: int,
                figsize: tuple[float, float] | None,
                grid: bool,
                sharex: bool,
                sharey: bool,
            ) -> FigAxs:
                """Return a figure and a flat list of axes, creating them if needed.

                Args:
                    axs: Axes to draw on. When None or empty, a new figure with an
                        `nrows` x `ncols` grid of axes is created.
                    ncols: Number of columns of a newly created grid.
                    nrows: Number of rows of a newly created grid.
                    figsize: Size of a newly created figure (default: None).
                    grid: Whether to add a grid to every axes.
                    sharex: Whether newly created axes share the x-axis.
                    sharey: Whether newly created axes share the y-axis.

                Returns:
                    Figure and a flat list of Axes objects for the plot.

                """
                if axs is not None and len(axs) > 0:
                    # The figure is looked up from the first of the provided axes.
                    fig = cast(Figure, axs[0].get_figure())
                else:
                    fig, axs_grid = plt.subplots(
                        nrows=nrows,
                        ncols=ncols,
                        sharex=sharex,
                        sharey=sharey,
                        figsize=figsize,
                        squeeze=False,  # always a 2D array, even for a 1x1 grid
                        layout="constrained",
                    )
                    axs = list(axs_grid.flatten())

                if grid:
                    for each_ax in axs:
                        add_grid(each_ax)
                return fig, axs
         
     | 
| 
      
 326 
     | 
    
         
            +
             
     | 
| 
      
 327 
     | 
    
         
            +
             
     | 
| 
      
 328 
     | 
    
         
            +
            def two_axes(
                *,
                figsize: tuple[float, float] | None = None,
                sharex: bool = True,
                sharey: bool = False,
                grid: bool = False,
            ) -> FigAxs:
                """Create a figure with two side-by-side axes.

                Args:
                    figsize: Size of the figure (default: None).
                    sharex: Whether both axes share the x-axis (default: True).
                    sharey: Whether both axes share the y-axis (default: False).
                    grid: Whether to add a grid to both axes (default: False).

                Returns:
                    Figure and a list holding its two Axes objects.

                """
                return _default_fig_axs(
                    None,
                    nrows=1,
                    ncols=2,
                    figsize=figsize,
                    grid=grid,
                    sharex=sharex,
                    sharey=sharey,
                )
         
     | 
| 
      
 345 
     | 
    
         
            +
             
     | 
| 
      
 346 
     | 
    
         
            +
             
     | 
| 
      
 347 
     | 
    
         
            +
            def grid_layout(
                n_groups: int,
                *,
                n_cols: int = 2,
                col_width: float = 3,
                row_height: float = 4,
                sharex: bool = True,
                sharey: bool = False,
                grid: bool = True,
            ) -> tuple[Figure, list[Axes]]:
                """Create a grid layout for the given number of groups.

                The number of columns is capped at `n_groups`, and just enough rows
                are created to hold every group. The figure size scales with the
                number of rows and columns.

                Args:
                    n_groups: Number of groups (cells) the layout must hold. A value
                        below 1 yields a single cell instead of an error.
                    n_cols: Maximum number of columns (must be at least 1).
                    col_width: Width of each column, in figure-size units.
                    row_height: Height of each row, in figure-size units.
                    sharex: Whether the axes share the x-axis.
                    sharey: Whether the axes share the y-axis.
                    grid: Whether to add a grid to every axes.

                Returns:
                    Figure and a flat list of Axes objects. The list can contain more
                    axes than `n_groups` when the last row is not completely filled.

                Raises:
                    ValueError: If `n_cols` is smaller than 1.

                """
                if n_cols < 1:
                    msg = f"n_cols must be at least 1, got {n_cols}"
                    raise ValueError(msg)

                # Allocate at least one cell so that n_groups == 0 does not lead to
                # a division by zero when computing the number of rows.
                n_cells = max(n_groups, 1)
                n_cols = min(n_cells, n_cols)
                n_rows = math.ceil(n_cells / n_cols)
                figsize = (n_cols * col_width, n_rows * row_height)

                return _default_fig_axs(
                    None,
                    ncols=n_cols,
                    nrows=n_rows,
                    figsize=figsize,
                    sharex=sharex,
                    sharey=sharey,
                    grid=grid,
                )
         
     | 
| 
      
 371 
     | 
    
         
            +
             
     | 
| 
      
 372 
     | 
    
         
            +
             
     | 
| 
      
 373 
     | 
    
         
            +
            ##########################################################################
         
     | 
| 
      
 374 
     | 
    
         
            +
            # Plots
         
     | 
| 
      
 375 
     | 
    
         
            +
            ##########################################################################
         
     | 
| 
      
 376 
     | 
    
         
            +
             
     | 
| 
      
 377 
     | 
    
         
            +
             
     | 
| 
      
 378 
     | 
    
         
            +
            def bars(
                x: pd.DataFrame,
                *,
                ax: Axes | None = None,
                grid: bool = True,
            ) -> FigAx:
                """Plot the columns of a dataframe as a bar chart.

                The dataframe is passed to `seaborn.barplot` as wide-form data, so
                seaborn draws one bar per column, aggregating over the rows.

                Args:
                    x: Data to plot; each column becomes one bar.
                    ax: Axes to draw on. A new figure is created when None.
                    grid: Whether to add a grid to the axes.

                Returns:
                    Figure and Axes objects for the plot.

                """
                fig, ax = _default_fig_ax(ax=ax, grid=grid)
                sns.barplot(data=x, ax=ax)
                # NOTE(review): in wide-form mode the x-axis shows the column names,
                # so labelling it with the index name may be misleading — confirm.
                _default_labels(ax, xlabel=x.index.name, ylabel=None)
                ax.legend(x.columns)
                return fig, ax
         
     | 
| 
      
 390 
     | 
    
         
            +
             
     | 
| 
      
 391 
     | 
    
         
            +
             
     | 
| 
      
 392 
     | 
    
         
            +
            def lines(
                x: pd.DataFrame | pd.Series,
                *,
                ax: Axes | None = None,
                grid: bool = True,
            ) -> FigAx:
                """Plot multiple lines on the same axis.

                Each column of a dataframe is drawn as one line against the index;
                a series is drawn as a single line.

                Args:
                    x: Data to plot. Its index provides the x-values.
                    ax: Axes to draw on. A new figure is created when None.
                    grid: Whether to add a grid to the axes.

                Returns:
                    Figure and Axes objects for the plot.

                """
                fig, ax = _default_fig_ax(ax=ax, grid=grid)
                ax.plot(x.index, x)
                _default_labels(ax, xlabel=x.index.name, ylabel=None)
                if isinstance(x, pd.Series):
                    # A series has no `.columns`; label its single line with the
                    # series name and skip the legend when the series is unnamed.
                    if x.name is not None:
                        ax.legend([x.name])
                else:
                    ax.legend(x.columns)
                return fig, ax
         
     | 
| 
      
 404 
     | 
    
         
            +
             
     | 
| 
      
 405 
     | 
    
         
            +
             
     | 
| 
      
 406 
     | 
    
         
            +
            def lines_grouped(
                groups: list[pd.DataFrame] | list[pd.Series],
                *,
                n_cols: int = 2,
                col_width: float = 3,
                row_height: float = 4,
                sharex: bool = True,
                sharey: bool = False,
                grid: bool = True,
            ) -> FigAxs:
                """Plot multiple groups of lines on separate axes.

                Args:
                    groups: Data to plot; each entry is drawn on its own axes.
                    n_cols: Maximum number of columns in the grid of axes.
                    col_width: Width of each grid column.
                    row_height: Height of each grid row.
                    sharex: Whether the axes share the x-axis.
                    sharey: Whether the axes share the y-axis.
                    grid: Whether to add a grid to every axes.

                Returns:
                    Figure and a flat list of Axes objects; axes beyond the number
                    of groups are hidden.

                """
                n_used = len(groups)
                fig, axs = grid_layout(
                    n_used,
                    n_cols=n_cols,
                    col_width=col_width,
                    row_height=row_height,
                    sharex=sharex,
                    sharey=sharey,
                    grid=grid,
                )

                for target_ax, group in zip(axs, groups, strict=False):
                    lines(group, ax=target_ax, grid=grid)

                # Hide grid cells that did not receive a group.
                for unused_ax in axs[n_used:]:
                    unused_ax.set_visible(False)

                return fig, axs
         
     | 
| 
      
 434 
     | 
    
         
            +
             
     | 
| 
      
 435 
     | 
    
         
            +
             
     | 
| 
      
 436 
     | 
    
         
            +
            def line_autogrouped(
                s: pd.Series | pd.DataFrame,
                *,
                n_cols: int = 2,
                col_width: float = 4,
                row_height: float = 3,
                max_group_size: int = 6,
                grid: bool = True,
            ) -> FigAxs:
                """Plot a series or dataframe with lines grouped by order of magnitude.

                Args:
                    s: Data to plot. A series is partitioned by its values; a
                        dataframe is partitioned by the maximum of each column.
                    n_cols: Maximum number of columns in the grid of axes.
                    col_width: Width of each grid column.
                    row_height: Height of each grid row.
                    max_group_size: Passed to `_split_large_groups` as `max_size`.
                    grid: Whether to add a grid to every axes.

                Returns:
                    Figure and a flat list of Axes objects, as produced by
                    `lines_grouped`.

                """
                # Dataframes are partitioned by their per-column maximum.
                magnitudes = s if isinstance(s, pd.Series) else s.max()
                partition = _partition_by_order_of_magnitude(magnitudes)
                group_names = _split_large_groups(partition, max_size=max_group_size)

                groups: list[pd.Series] | list[pd.DataFrame]
                if isinstance(s, pd.Series):
                    groups = [s.loc[names] for names in group_names]
                else:
                    groups = [s.loc[:, names] for names in group_names]

                return lines_grouped(
                    groups,
                    n_cols=n_cols,
                    col_width=col_width,
                    row_height=row_height,
                    grid=grid,
                )
         
     | 
| 
      
 466 
     | 
    
         
            +
             
     | 
| 
      
 467 
     | 
    
         
            +
             
     | 
| 
      
 468 
     | 
    
         
            +
            def line_mean_std(
                df: pd.DataFrame,
                *,
                label: str | None = None,
                ax: Axes | None = None,
                color: str | None = None,
                alpha: float = 0.2,
                grid: bool = True,
            ) -> FigAx:
                """Plot the mean and standard deviation using a line and fill.

                The mean and standard deviation are computed per row (across the
                columns). The mean is drawn as a line and the band
                mean - std .. mean + std is filled in the same color.

                Args:
                    df: Data whose rows are aggregated; the index gives x-values.
                    label: Legend label of the mean line.
                    ax: Axes to draw on. A new figure is created when None.
                    color: Color of line and band; chosen by `_default_color` when
                        None.
                    alpha: Opacity of the standard-deviation band.
                    grid: Whether to add a grid to the axes.

                Returns:
                    Figure and Axes objects for the plot.

                """
                fig, ax = _default_fig_ax(ax=ax, grid=grid)
                color = _default_color(ax=ax, color=color)

                row_mean = df.mean(axis=1)
                row_std = df.std(axis=1)
                lower = row_mean - row_std
                upper = row_mean + row_std

                ax.plot(row_mean.index, row_mean, color=color, label=label)
                ax.fill_between(df.index, lower, upper, color=color, alpha=alpha)
                _default_labels(ax, xlabel=df.index.name, ylabel=None)
                return fig, ax
         
     | 
| 
      
 498 
     | 
    
         
            +
             
     | 
| 
      
 499 
     | 
    
         
            +
             
     | 
| 
      
 500 
     | 
    
         
            +
            def lines_mean_std_from_2d_idx(
                df: pd.DataFrame,
                *,
                names: list[str] | None = None,
                ax: Axes | None = None,
                alpha: float = 0.2,
                grid: bool = True,
            ) -> FigAx:
                """Plot the mean and standard deviation of a 2D indexed dataframe.

                For every selected column, the values are unstacked on the second
                index level and transposed. The mean and standard deviation are
                therefore taken across the first index level and plotted against
                the second one.

                Args:
                    df: Dataframe with a two-level MultiIndex.
                    names: Columns to plot (default: all columns).
                    ax: Axes to draw on. A new figure is created when None.
                    alpha: Opacity of the standard-deviation bands.
                    grid: Whether to add a grid to the axes.

                Returns:
                    Figure and Axes objects for the plot.

                Raises:
                    ValueError: If the index is not a MultiIndex with exactly two
                        levels.

                """
                # Check the type first: a plain Index has no `.levels`, which would
                # otherwise surface as an AttributeError instead of a ValueError.
                index = df.index
                if not isinstance(index, pd.MultiIndex) or index.nlevels != 2:  # noqa: PLR2004
                    msg = "MultiIndex must have exactly two levels"
                    raise ValueError(msg)

                fig, ax = _default_fig_ax(ax=ax, grid=grid)

                for name in df.columns if names is None else names:
                    line_mean_std(
                        df[name].unstack().T,
                        label=name,
                        alpha=alpha,
                        ax=ax,
                        # The grid (if requested) was already added above; the
                        # helper's default grid=True would otherwise add one even
                        # when the caller passed grid=False.
                        grid=False,
                    )
                ax.legend()
                return fig, ax
         
     | 
| 
      
 524 
     | 
    
         
            +
             
     | 
| 
      
 525 
     | 
    
         
            +
             
     | 
| 
      
 526 
     | 
    
         
            +
            def heatmap(
                df: pd.DataFrame,
                *,
                annotate: bool = False,
                colorbar: bool = True,
                invert_yaxis: bool = True,
                cmap: str = "RdBu_r",
                norm: Normalize | None = None,
                ax: Axes | None = None,
                cax: Axes | None = None,
                sci_annotation_bounds: tuple[float, float] = (0.01, 100),
                annotation_style: str = "2g",
            ) -> tuple[Figure, Axes, QuadMesh]:
                """Plot a heatmap of the given data.

                Every cell of `df` is drawn as one colored square via
                `Axes.pcolormesh`. Column labels go on the x-axis and index labels
                on the y-axis.

                Args:
                    df: Data to plot.
                    annotate: Whether to annotate the cells via `_annotate_colormap`.
                    colorbar: Whether to add a colorbar.
                    invert_yaxis: Whether to invert the y-axis, so that the first
                        row of `df` is drawn at the top.
                    cmap: Name of the colormap.
                    norm: Color normalization. When None, it is computed by
                        `_norm_with_zero_center(df)`.
                    ax: Axes to draw on. When None, a new figure is created whose
                        size grows with the shape of `df` (at least 4 x 4).
                    cax: Axes for the colorbar, forwarded to `Figure.colorbar`.
                    sci_annotation_bounds: Forwarded to `_annotate_colormap`; only
                        used when `annotate` is True.
                    annotation_style: Forwarded to `_annotate_colormap`; only used
                        when `annotate` is True.

                Returns:
                    The figure, the axes, and the QuadMesh created by `pcolormesh`.

                """
                n_rows = len(df.index)
                n_cols = len(df.columns)
                fig, ax = _default_fig_ax(
                    ax=ax,
                    figsize=(max(4, 0.5 * n_cols), max(4, 0.5 * n_rows)),
                    grid=False,
                )
                if norm is None:
                    norm = _norm_with_zero_center(df)

                mesh = ax.pcolormesh(df, norm=norm, cmap=cmap)

                # pcolormesh cell i spans [i, i + 1], so labels are centered at i + 0.5.
                ax.set_xticks(np.arange(n_cols) + 0.5, labels=df.columns)
                ax.set_yticks(np.arange(n_rows) + 0.5, labels=df.index)

                if annotate:
                    _annotate_colormap(df, ax, sci_annotation_bounds, annotation_style, mesh)

                if colorbar:
                    cbar = fig.colorbar(mesh, cax=cax, ax=ax)
                    cbar.outline.set_linewidth(0)  # type: ignore

                if invert_yaxis:
                    ax.invert_yaxis()
                rotate_xlabels(ax, rotation=45, ha="right")
                return fig, ax, mesh
         
     | 
| 
      
 573 
     | 
    
         
            +
             
     | 
| 
      
 574 
     | 
    
         
            +
             
     | 
| 
      
 575 
     | 
    
         
            +
            def heatmap_from_2d_idx(
                df: pd.DataFrame,
                variable: str,
                ax: Axes | None = None,
            ) -> FigAx:
                """Plot a heatmap of one column of a dataframe with a two-level index.

                The first index level is drawn along the x-axis and the second index
                level along the y-axis. Tick labels are formatted with two decimals.

                Args:
                    df: Dataframe whose index has exactly two levels.
                    variable: Name of the column to plot.
                    ax: Axes to draw on. If None, a new figure and axes are created.

                Returns:
                    The figure and axes containing the heatmap.

                Raises:
                    ValueError: If the index of `df` does not have exactly two levels.

                """
                # `nlevels` exists on every Index, so a plain (single-level) index now
                # raises the documented ValueError instead of an AttributeError on
                # `.levels`, which only MultiIndex provides.
                if df.index.nlevels != 2:  # noqa: PLR2004
                    msg = "MultiIndex must have exactly two levels"
                    raise ValueError(msg)

                fig, ax = _default_fig_ax(ax=ax, grid=False)
                # unstack() moves the last index level into the columns:
                # rows = first level, columns = second level
                df2d = df[variable].unstack()

                ax.set_title(variable)
                # Note: pcolormesh swaps index/columns
                hm = ax.pcolormesh(df2d.T)
                ax.set_xlabel(df2d.index.name)
                ax.set_ylabel(df2d.columns.name)
                # +0.5 centers each tick on its cell
                ax.set_xticks(
                    np.arange(0, len(df2d.index), 1) + 0.5,
                    labels=[f"{i:.2f}" for i in df2d.index],
                )
                ax.set_yticks(
                    np.arange(0, len(df2d.columns), 1) + 0.5,
                    labels=[f"{i:.2f}" for i in df2d.columns],
                )

                rotate_xlabels(ax, rotation=45, ha="right")

                # Add colorbar
                fig.colorbar(hm, ax=ax)
                return fig, ax
         
     | 
| 
      
 607 
     | 
    
         
            +
             
     | 
| 
      
 608 
     | 
    
         
            +
             
     | 
| 
      
 609 
     | 
    
         
            +
            def heatmaps_from_2d_idx(
                df: pd.DataFrame,
                *,
                n_cols: int = 3,
                col_width_factor: float = 1,
                row_height_factor: float = 0.6,
                sharex: bool = True,
                sharey: bool = False,
            ) -> FigAxs:
                """Plot one heatmap per column of a dataframe with a two-level index.

                Args:
                    df: Dataframe whose index has exactly two levels.
                    n_cols: Maximum number of columns of the subplot grid.
                    col_width_factor: Subplot width per unique value of the first level.
                    row_height_factor: Subplot height per unique value of the second level.
                    sharex: Share the x-axis between subplots.
                    sharey: Share the y-axis between subplots.

                Returns:
                    The figure and the axes of the grid.

                Raises:
                    ValueError: If the index of `df` does not have exactly two levels.

                """
                # Validate up front (as the sibling functions do) so a plain index
                # raises ValueError rather than AttributeError on `.levels`.
                if df.index.nlevels != 2:  # noqa: PLR2004
                    msg = "MultiIndex must have exactly two levels"
                    raise ValueError(msg)
                idx = cast(pd.MultiIndex, df.index)

                n_plots = len(df.columns)
                fig, axs = grid_layout(
                    n_groups=n_plots,
                    # Cap by the number of plots (not the number of rows of `df`), so
                    # that fewer variables than `n_cols` do not leave empty columns.
                    n_cols=max(1, min(n_cols, n_plots)),
                    col_width=len(idx.levels[0]) * col_width_factor,
                    row_height=len(idx.levels[1]) * row_height_factor,
                    sharex=sharex,
                    sharey=sharey,
                    grid=False,
                )
                for ax, var in zip(axs, df.columns, strict=False):
                    heatmap_from_2d_idx(df, var, ax=ax)
                return fig, axs
         
     | 
| 
      
 633 
     | 
    
         
            +
             
     | 
| 
      
 634 
     | 
    
         
            +
             
     | 
| 
      
 635 
     | 
    
         
            +
            def violins(
                df: pd.DataFrame,
                *,
                ax: Axes | None = None,
                grid: bool = True,
            ) -> FigAx:
                """Draw a violin plot for every column of `df` on a single axis.

                Args:
                    df: Dataframe handed to seaborn's `violinplot`.
                    ax: Axes to draw on. If None, a new figure and axes are created.
                    grid: Whether the newly set-up axes should show a grid.

                Returns:
                    The figure and the axes that were drawn on.

                """
                figure, target = _default_fig_ax(ax=ax, grid=grid)
                sns.violinplot(df, ax=target)
                # Blank x label; y label left to the helper's default handling
                _default_labels(ax=target, xlabel="", ylabel=None)
                return figure, target
         
     | 
| 
      
 646 
     | 
    
         
            +
             
     | 
| 
      
 647 
     | 
    
         
            +
             
     | 
| 
      
 648 
     | 
    
         
            +
            def violins_from_2d_idx(
                df: pd.DataFrame,
                *,
                n_cols: int = 4,
                row_height: int = 2,
                sharex: bool = True,
                sharey: bool = False,
                grid: bool = True,
            ) -> FigAxs:
                """Plot one violin subplot per column of a dataframe with a two-level index.

                For each column, the second index level is unstacked into columns, so
                every value of the second level becomes one violin whose samples are
                taken across the first level.

                Args:
                    df: Dataframe whose index has exactly two levels.
                    n_cols: Number of columns of the subplot grid.
                    row_height: Height of each grid row.
                    sharex: Share the x-axis between subplots.
                    sharey: Share the y-axis between subplots.
                    grid: Whether the subplots show a grid.

                Returns:
                    The figure and the axes of the grid.

                Raises:
                    ValueError: If the index of `df` does not have exactly two levels.

                """
                # `nlevels` exists on every Index, so a plain index raises the
                # documented ValueError instead of an AttributeError on `.levels`.
                if df.index.nlevels != 2:  # noqa: PLR2004
                    msg = "MultiIndex must have exactly two levels"
                    raise ValueError(msg)

                fig, axs = grid_layout(
                    len(df.columns),
                    n_cols=n_cols,
                    row_height=row_height,
                    sharex=sharex,
                    sharey=sharey,
                    grid=grid,
                )

                for ax, col in zip(axs[: len(df.columns)], df.columns, strict=True):
                    ax.set_title(col)
                    violins(df[col].unstack(), ax=ax)

                # Visually blank the unused trailing grid cells (frame and y ticks)
                for ax in axs[len(df.columns) :]:
                    for axis in ["top", "bottom", "left", "right"]:
                        ax.spines[axis].set_linewidth(0)
                    ax.yaxis.set_ticks([])

                for ax in axs:
                    rotate_xlabels(ax)
                return fig, axs
         
     | 
| 
      
 683 
     | 
    
         
            +
             
     | 
| 
      
 684 
     | 
    
         
            +
             
     | 
| 
      
 685 
     | 
    
         
            +
            def shade_protocol(
                protocol: pd.Series,
                *,
                ax: Axes,
                cmap_name: str = "Greys_r",
                vmin: float | None = None,
                vmax: float | None = None,
                alpha: float = 0.5,
                add_legend: bool = True,
            ) -> None:
                """Shade the time intervals of a step protocol on the given axis.

                Each entry of `protocol` is the value of the interval ending at its
                index time. The interval starts at the previous index entry, or at 0 s
                for the first one. Each interval is shaded with the colormap color of
                that value.

                Args:
                    protocol: Step values indexed by interval end times (`pd.Timedelta`).
                    ax: Axes to shade.
                    cmap_name: Name of a registered matplotlib colormap.
                    vmin: Lower bound of the color normalisation (default: protocol min).
                    vmax: Upper bound of the color normalisation (default: protocol max).
                    alpha: Transparency of the shaded spans and legend patches.
                    add_legend: Add a legend with one entry per distinct protocol value.

                """
                from matplotlib import colormaps
                from matplotlib.colors import Normalize
                from matplotlib.legend import Legend
                from matplotlib.patches import Patch

                cmap = colormaps[cmap_name]
                norm = Normalize(
                    vmin=protocol.min() if vmin is None else vmin,
                    vmax=protocol.max() if vmax is None else vmax,
                )

                t0 = pd.Timedelta(seconds=0)
                for t_end, val in protocol.items():
                    t_end = cast(pd.Timedelta, t_end)
                    ax.axvspan(
                        t0.total_seconds(),
                        t_end.total_seconds(),
                        facecolor=cmap(norm(val)),
                        edgecolor=None,
                        alpha=alpha,
                    )
                    t0 = t_end  # type: ignore

                if add_legend:
                    # Repeated values (e.g. alternating steps) map to the same color.
                    # List each distinct value once, in first-seen order, instead of
                    # one duplicate legend entry per protocol step.
                    unique_values = protocol.drop_duplicates().to_list()
                    ax.add_artist(
                        Legend(
                            ax,
                            handles=[
                                Patch(
                                    facecolor=cmap(norm(val)),
                                    alpha=alpha,
                                    label=str(val),
                                )  # type: ignore
                                for val in unique_values
                            ],
                            labels=[str(val) for val in unique_values],
                            loc="lower right",
                            bbox_to_anchor=(1.0, 0.0),
                            title="protocol" if protocol.name is None else cast(str, protocol.name),
                        )
                    )
         
     | 
| 
      
 737 
     | 
    
         
            +
             
     | 
| 
      
 738 
     | 
    
         
            +
             
     | 
| 
      
 739 
     | 
    
         
            +
            ##########################################################################
         
     | 
| 
      
 740 
     | 
    
         
            +
            # Plots that actually require a model :/
         
     | 
| 
      
 741 
     | 
    
         
            +
            ##########################################################################
         
     | 
| 
      
 742 
     | 
    
         
            +
             
     | 
| 
      
 743 
     | 
    
         
            +
             
     | 
| 
      
 744 
     | 
    
         
            +
            def trajectories_2d(
                model: Model,
                x1: tuple[str, ArrayLike],
                x2: tuple[str, ArrayLike],
                y0: dict[str, float] | None = None,
                ax: Axes | None = None,
            ) -> FigAx:
                """Plot the vector field of two variables in a 2D phase space.

                For every combination of the given values of the two variables, the
                model's right-hand side is evaluated. The derivatives of the two
                variables are drawn as arrows with `quiver`.

                Examples:
                    >>> trajectories_2d(
                    ...     model,
                    ...     ("S", np.linspace(0, 1, 10)),
                    ...     ("P", np.linspace(0, 1, 10)),
                    ... )

                Args:
                    model: Model whose right-hand side is evaluated.
                    x1: Name of the first variable (x-axis) and its grid values.
                    x2: Name of the second variable (y-axis) and its grid values.
                    y0: Values of all other variables. Defaults to the model's
                        initial conditions.
                    ax: Axes to draw on. If None, a new figure and axes are created.

                Returns:
                    The figure and axes containing the vector field.

                """
                name1, values1 = x1
                name2, values2 = x2
                base = model.get_initial_conditions() if y0 is None else y0

                # Rows follow values1, columns follow values2
                du = np.zeros((len(values1), len(values2)))
                dv = np.zeros_like(du)
                for row, a in enumerate(values1):
                    for col, b in enumerate(values2):
                        rhs = model.get_right_hand_side(base | {name1: a, name2: b})
                        du[row, col] = rhs[name1]
                        dv[row, col] = rhs[name2]

                fig, ax = _default_fig_ax(ax=ax, grid=False)
                # quiver expects U/V shaped (len(y), len(x)), hence the transposes
                ax.quiver(values1, values2, du.T, dv.T)
                return fig, ax
         
     | 
| 
      
 784 
     | 
    
         
            +
             
     | 
| 
      
 785 
     | 
    
         
            +
             
     | 
| 
      
 786 
     | 
    
         
            +
            ##########################################################################
         
     | 
| 
      
 787 
     | 
    
         
            +
            # Label Plots
         
     | 
| 
      
 788 
     | 
    
         
            +
            ##########################################################################
         
     | 
| 
      
 789 
     | 
    
         
            +
             
     | 
| 
      
 790 
     | 
    
         
            +
             
     | 
| 
      
 791 
     | 
    
         
            +
            def relative_label_distribution(
                mapper: LabelMapper | LinearLabelMapper,
                concs: pd.DataFrame,
                *,
                subset: list[str] | None = None,
                n_cols: int = 2,
                col_width: float = 3,
                row_height: float = 3,
                sharey: bool = False,
                grid: bool = True,
            ) -> FigAxs:
                """Plot the relative distribution of labels in the given data.

                One subplot is drawn per label variable, and the index of `concs` is
                used as the x-axis.

                - For a `LabelMapper`, one line is drawn per position `i`. It shows the
                  summed columns returned by `get_isotopomers_of_at_position(name, i)`,
                  divided by the `"{name}__total"` column.
                - For a `LinearLabelMapper`, the columns returned by `get_isotopomers`
                  are plotted directly.

                Args:
                    mapper: Label mapper defining the label variables.
                    concs: Data whose columns include the isotopomer (and total) names.
                    subset: Label variables to plot (default: all of
                        `mapper.label_variables`).
                    n_cols: Number of columns of the subplot grid.
                    col_width: Width of each grid column.
                    row_height: Height of each grid row.
                    sharey: Share the y-axis between subplots.
                    grid: Whether the subplots show a grid.

                Returns:
                    The figure and the axes of the grid.

                """
                variables = list(mapper.label_variables) if subset is None else subset
                fig, axs = grid_layout(
                    n_groups=len(variables),
                    n_cols=n_cols,
                    col_width=col_width,
                    row_height=row_height,
                    sharey=sharey,
                    grid=grid,
                )
                if isinstance(mapper, LabelMapper):
                    for ax, name in zip(axs, variables, strict=False):
                        # The total pool is the same for every position: fetch it once
                        total = concs.loc[:, f"{name}__total"]
                        for i in range(mapper.label_variables[name]):
                            isos = mapper.get_isotopomers_of_at_position(name, i)
                            labels = cast(pd.DataFrame, concs.loc[:, isos])
                            ax.plot(labels.index, (labels.sum(axis=1) / total), label=f"C{i + 1}")
                        ax.set_title(name)
                        ax.legend()
                else:
                    for ax, (name, isos) in zip(
                        axs, mapper.get_isotopomers(variables).items(), strict=False
                    ):
                        ax.plot(concs.index, concs.loc[:, isos])
                        ax.set_title(name)
                        ax.legend([f"C{i + 1}" for i in range(len(isos))])

                return fig, axs
         
     |