modelbase2 0.1.79__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- modelbase2/__init__.py +148 -25
 - modelbase2/distributions.py +336 -0
 - modelbase2/experimental/__init__.py +17 -0
 - modelbase2/experimental/codegen.py +239 -0
 - modelbase2/experimental/diff.py +227 -0
 - modelbase2/experimental/notes.md +4 -0
 - modelbase2/experimental/tex.py +521 -0
 - modelbase2/fit.py +284 -0
 - modelbase2/fns.py +185 -0
 - modelbase2/integrators/__init__.py +19 -0
 - modelbase2/integrators/int_assimulo.py +146 -0
 - modelbase2/integrators/int_scipy.py +147 -0
 - modelbase2/label_map.py +610 -0
 - modelbase2/linear_label_map.py +301 -0
 - modelbase2/mc.py +548 -0
 - modelbase2/mca.py +280 -0
 - modelbase2/model.py +1621 -0
 - modelbase2/nnarchitectures.py +128 -0
 - modelbase2/npe.py +271 -0
 - modelbase2/parallel.py +171 -0
 - modelbase2/parameterise.py +28 -0
 - modelbase2/paths.py +36 -0
 - modelbase2/plot.py +832 -0
 - modelbase2/sbml/__init__.py +14 -0
 - modelbase2/sbml/_data.py +77 -0
 - modelbase2/sbml/_export.py +656 -0
 - modelbase2/sbml/_import.py +585 -0
 - modelbase2/sbml/_mathml.py +691 -0
 - modelbase2/sbml/_name_conversion.py +52 -0
 - modelbase2/sbml/_unit_conversion.py +74 -0
 - modelbase2/scan.py +616 -0
 - modelbase2/scope.py +96 -0
 - modelbase2/simulator.py +635 -0
 - modelbase2/surrogates/__init__.py +31 -0
 - modelbase2/surrogates/_poly.py +91 -0
 - modelbase2/surrogates/_torch.py +191 -0
 - modelbase2/surrogates.py +316 -0
 - modelbase2/types.py +352 -11
 - modelbase2-0.3.0.dist-info/METADATA +93 -0
 - modelbase2-0.3.0.dist-info/RECORD +43 -0
 - {modelbase2-0.1.79.dist-info → modelbase2-0.3.0.dist-info}/WHEEL +1 -1
 - modelbase2/core/__init__.py +0 -29
 - modelbase2/core/algebraic_module_container.py +0 -130
 - modelbase2/core/constant_container.py +0 -113
 - modelbase2/core/data.py +0 -109
 - modelbase2/core/name_container.py +0 -29
 - modelbase2/core/reaction_container.py +0 -115
 - modelbase2/core/utils.py +0 -28
 - modelbase2/core/variable_container.py +0 -24
 - modelbase2/ode/__init__.py +0 -13
 - modelbase2/ode/integrator.py +0 -80
 - modelbase2/ode/mca.py +0 -270
 - modelbase2/ode/model.py +0 -470
 - modelbase2/ode/simulator.py +0 -153
 - modelbase2/utils/__init__.py +0 -0
 - modelbase2/utils/plotting.py +0 -372
 - modelbase2-0.1.79.dist-info/METADATA +0 -44
 - modelbase2-0.1.79.dist-info/RECORD +0 -22
 - {modelbase2-0.1.79.dist-info → modelbase2-0.3.0.dist-info/licenses}/LICENSE +0 -0
 
    
        modelbase2/ode/mca.py
    DELETED
    
    | 
         @@ -1,270 +0,0 @@ 
     | 
|
| 
       1 
     | 
    
         
            -
            from __future__ import annotations
         
     | 
| 
       2 
     | 
    
         
            -
             
     | 
| 
       3 
     | 
    
         
            -
            import copy
         
     | 
| 
       4 
     | 
    
         
            -
            import math
         
     | 
| 
       5 
     | 
    
         
            -
            import matplotlib.pyplot as plt
         
     | 
| 
       6 
     | 
    
         
            -
            import numpy as np
         
     | 
| 
       7 
     | 
    
         
            -
            import pandas as pd
         
     | 
| 
       8 
     | 
    
         
            -
            from ..types import Array, DataFrame, Series
         
     | 
| 
       9 
     | 
    
         
            -
            from ..utils.plotting import get_norm as _get_norm
         
     | 
| 
       10 
     | 
    
         
            -
            from ..utils.plotting import heatmap_from_dataframe as _heatmap_from_dataframe
         
     | 
| 
       11 
     | 
    
         
            -
            from .integrator import Assimulo
         
     | 
| 
       12 
     | 
    
         
            -
            from .model import Model
         
     | 
| 
       13 
     | 
    
         
            -
            from .simulator import Simulator
         
     | 
| 
       14 
     | 
    
         
            -
            from matplotlib import cm
         
     | 
| 
       15 
     | 
    
         
            -
            from matplotlib.axes import Axes
         
     | 
| 
       16 
     | 
    
         
            -
            from matplotlib.collections import QuadMesh
         
     | 
| 
       17 
     | 
    
         
            -
            from matplotlib.figure import Figure
         
     | 
| 
       18 
     | 
    
         
            -
            from typing import Optional, cast
         
     | 
| 
       19 
     | 
    
         
            -
             
     | 
| 
       20 
     | 
    
         
            -
            _DISPLACEMENT = 1e-4
         
     | 
| 
       21 
     | 
    
         
            -
            _DEFAULT_TOLERANCE = 1e-8
         
     | 
| 
       22 
     | 
    
         
            -
             
     | 
| 
       23 
     | 
    
         
            -
             
     | 
| 
       24 
     | 
    
         
            -
            def get_variable_elasticity(
                m: Model,
                variable: str,
                y: dict[str, float],
                displacement: float = _DISPLACEMENT,
            ) -> Array:
                """Get sensitivity of all rates to a change of the concentration of a variable.

                Also called epsilon-elasticities. Not in steady state!

                The derivative of every rate with respect to ``variable`` is estimated by a
                central finite difference around its value in ``y`` merged with the derived
                variables computed from ``y``. It is then scaled by ``x / v`` to give a
                normalised elasticity.

                Args:
                    m: Model whose rates are evaluated.
                    variable: Name of the variable to perturb; it may also be one of the
                        derived variables returned by ``m.get_derived_variables(y)``.
                    y: Concentrations at which the elasticities are evaluated. The dict is
                        not modified; a merged copy is perturbed instead.
                    displacement: Relative step; the variable is set to
                        ``x * (1 + displacement)`` and ``x * (1 - displacement)``.

                Returns:
                    An at-least-one-dimensional array with one entry per rate, in the order
                    produced by ``m.get_fluxes(..., "array")``.
                """
                state = y | m.get_derived_variables(y)
                x0 = state[variable]

                def _rates_at(value: float) -> Array:
                    # Overwrite the variable in the working copy and evaluate all rates.
                    state[variable] = value
                    return m.get_fluxes(state, "array")

                rates_up = _rates_at(x0 * (1 + displacement))
                rates_down = _rates_at(x0 * (1 - displacement))
                slope = (rates_up - rates_down) / (2 * displacement * x0)
                # Evaluating at x0 also restores the original value in ``state``.
                slope *= x0 / _rates_at(x0)
                return np.atleast_1d(np.squeeze(slope))
         
     | 
| 
       49 
     | 
    
         
            -
             
     | 
| 
       50 
     | 
    
         
            -
             
     | 
| 
       51 
     | 
    
         
            -
            def get_variable_elasticities(
                m: Model,
                variables: list[str],
                y: dict[str, float],
                displacement: float = _DISPLACEMENT,
            ) -> DataFrame:
                """Get sensitivity of all rates to a change of the concentration of multiple variables.

                Also called epsilon-elasticities. Not in steady state!

                Args:
                    m: Model whose rates are evaluated.
                    variables: Names of the variables to perturb; each becomes one row.
                    y: Concentrations at which the elasticities are evaluated.
                    displacement: Relative finite-difference step.

                Returns:
                    DataFrame indexed by ``variables`` whose columns are the columns of
                    ``m.get_stoichiometries()``. Each row is filled by
                    ``get_variable_elasticity``.
                """
                reaction_names = m.get_stoichiometries().columns
                table = np.full(
                    shape=(len(variables), len(reaction_names)),
                    fill_value=np.nan,
                )
                for row, name in enumerate(variables):
                    table[row, :] = get_variable_elasticity(
                        m=m,
                        variable=name,
                        y=y,
                        displacement=displacement,
                    )
                return pd.DataFrame(table, index=variables, columns=reaction_names)
         
     | 
| 
       73 
     | 
    
         
            -
             
     | 
| 
       74 
     | 
    
         
            -
             
     | 
| 
       75 
     | 
    
         
            -
            def get_constant_elasticity(
                m: Model,
                constant: str,
                y: dict[str, float],
                displacement: float = _DISPLACEMENT,
            ) -> Array:
                """Get sensitivity of all rates to a change of a constant value.

                Also called pi-elasticities. Not in steady state!

                The constant is changed on a deep copy of ``m``, so the caller's model is left
                untouched. The rates are evaluated at ``y`` for the constant scaled by
                ``1 + displacement`` and ``1 - displacement``. The central-difference
                derivative is then scaled by ``p / v``.

                Args:
                    m: Model whose rates are evaluated.
                    constant: Name of the constant to perturb.
                    y: Concentrations at which the rates are evaluated.
                    displacement: Relative finite-difference step.

                Returns:
                    An at-least-one-dimensional array with one entry per rate, in the order
                    produced by ``get_fluxes(..., "array")``.
                """
                model = copy.deepcopy(m)
                p0 = model.constant_values[constant]

                def _rates_with(value: float) -> Array:
                    # Set the constant on the private copy, then evaluate all rates at y.
                    model.update_constant(constant, value)
                    return model.get_fluxes(y, "array")

                rates_up = _rates_with(p0 * (1 + displacement))
                rates_down = _rates_with(p0 * (1 - displacement))
                slope = (rates_up - rates_down) / (2 * displacement * p0)
                # Evaluating at p0 also resets the constant before normalising.
                slope *= p0 / _rates_with(p0)
                return np.atleast_1d(np.squeeze(slope))
         
     | 
| 
       97 
     | 
    
         
            -
             
     | 
| 
       98 
     | 
    
         
            -
             
     | 
| 
       99 
     | 
    
         
            -
            def get_constant_elasticities(
                m: Model,
                constants: list[str],
                y: dict[str, float],
                displacement: float = _DISPLACEMENT,
            ) -> DataFrame:
                """Get sensitivity of all rates to a change of multiple constant values.

                Also called pi-elasticities. Not in steady state!

                Args:
                    m: Model whose rates are evaluated; it is not modified.
                    constants: Names of the constants to perturb; each becomes one row.
                    y: Concentrations at which the rates are evaluated.
                    displacement: Relative finite-difference step.

                Returns:
                    DataFrame indexed by ``constants`` whose columns are the columns of
                    ``m.get_stoichiometries()``. Each row is filled by
                    ``get_constant_elasticity``.
                """
                reaction_names = m.get_stoichiometries().columns
                table = np.full(
                    shape=(len(constants), len(reaction_names)),
                    fill_value=np.nan,
                )
                for row, name in enumerate(constants):
                    table[row, :] = get_constant_elasticity(
                        m=m,
                        constant=name,
                        y=y,
                        displacement=displacement,
                    )
                return pd.DataFrame(table, index=constants, columns=reaction_names)
         
     | 
| 
       121 
     | 
    
         
            -
             
     | 
| 
       122 
     | 
    
         
            -
             
     | 
| 
       123 
     | 
    
         
            -
            def _get_response_coefficients_single_constant(
                m: Model,
                constant: str,
                y0: dict[str, float],
                displacement: float = _DISPLACEMENT,
                tolerance: float = _DEFAULT_TOLERANCE,
            ) -> tuple[Optional[Series], Optional[Series]]:
                """Get response of the steady state concentrations and fluxes to a change of the given constant.

                The constant is scaled by ``1 + displacement`` and ``1 - displacement``, the
                model is simulated to steady state for each value, and the derivatives are
                estimated by central finite differences. They are normalised by ``p / x``
                and ``p / v``, evaluated at the unperturbed steady state.

                Args:
                    m: Model to analyse. A deep copy is modified; ``m`` itself is not.
                    constant: Name of the constant to perturb.
                    y0: Initial conditions used to reach the reference steady state.
                    displacement: Relative perturbation of the constant.
                    tolerance: Tolerance passed to ``simulate_to_steady_state``.

                Returns:
                    ``(concentration_responses, flux_responses)``. The first is indexed by the
                    steady-state variable names (including derived variables), the second by
                    ``m.reactions``. In both, +/-inf is replaced by NaN. Returns
                    ``(None, None)`` if any steady-state simulation returns ``None``.
                """
                old_value = m.constant_values[constant]
                m = copy.deepcopy(m)

                # Reference steady state at the unperturbed constant.
                y_ss = Simulator(m, Assimulo, y0).simulate_to_steady_state(tolerance=tolerance)
                if y_ss is None:
                    return None, None
                # normalise at the steady state. The previous version merged the *initial
                # conditions* y0 here, which scaled the coefficients by the wrong state.
                # Using the same construction as the scan below also keeps this vector
                # aligned with the keys used as the returned index (assuming
                # simulate_to_steady_state returns keys in a consistent order).
                y_full = y_ss | m.get_derived_variables(y_ss)
                y_ss_norm = old_value / np.fromiter(y_full.values(), dtype="float")
                fluxes_norm = old_value / m.get_fluxes(y_full, return_type="array")

                # scan
                ss: list[Array] = []
                fluxes: list[Array] = []
                for new_value in [
                    old_value * (1 + displacement),
                    old_value * (1 - displacement),
                ]:
                    m.update_constant(constant, new_value)
                    # Each run starts from the previously found steady state.
                    y_ss = Simulator(m, Assimulo, y_ss).simulate_to_steady_state(tolerance=tolerance)
                    if y_ss is None:
                        return None, None
                    y_full = y_ss | m.get_derived_variables(y_ss)
                    ss.append(np.fromiter(y_full.values(), dtype="float"))
                    fluxes.append(m.get_fluxes(y_full, return_type="array"))

                conc_resp_coef = (ss[0] - ss[1]) / (2 * displacement * old_value)
                flux_resp_coef = (fluxes[0] - fluxes[1]) / (2 * displacement * old_value)

                return (
                    pd.Series(conc_resp_coef * y_ss_norm, index=list(y_full.keys())).replace(
                        [np.inf, -np.inf], np.nan
                    ),
                    pd.Series(flux_resp_coef * fluxes_norm, index=list(m.reactions)).replace(
                        [np.inf, -np.inf], np.nan
                    ),
                )
         
     | 
| 
       168 
     | 
    
         
            -
             
     | 
| 
       169 
     | 
    
         
            -
             
     | 
| 
       170 
     | 
    
         
            -
            def get_response_coefficients(
                m: Model,
                constants: list[str],
                y0: dict[str, float],
                displacement: float = _DISPLACEMENT,
                tolerance: float = _DEFAULT_TOLERANCE,
            ) -> tuple[DataFrame, DataFrame]:
                """Get steady-state concentration and flux response coefficients for several constants.

                Args:
                    m: Model to analyse.
                    constants: Names of the constants to perturb, one at a time.
                    y0: Initial conditions used to reach each reference steady state.
                    displacement: Relative perturbation of each constant.
                    tolerance: Tolerance passed to the steady-state simulation.

                Returns:
                    ``(concentration_responses, flux_responses)`` as DataFrames with one row
                    per constant. A constant is omitted from both frames when
                    ``_get_response_coefficients_single_constant`` returns ``None`` for it.
                """
                concentration_rows: dict[str, pd.Series] = {}
                flux_rows: dict[str, pd.Series] = {}
                for name in constants:
                    crc, frc = _get_response_coefficients_single_constant(
                        m=m,
                        constant=name,
                        y0=y0,
                        displacement=displacement,
                        tolerance=tolerance,
                    )
                    if crc is None or frc is None:
                        # No steady state was found for this constant; skip it.
                        continue
                    concentration_rows[name] = crc
                    flux_rows[name] = frc
                return pd.DataFrame(concentration_rows).T, pd.DataFrame(flux_rows).T
         
     | 
| 
       191 
     | 
    
         
            -
             
     | 
| 
       192 
     | 
    
         
            -
             
     | 
| 
       193 
     | 
    
         
            -
            def plot_coefficient_heatmap(
                df: pd.DataFrame,
                title: str,
                cmap: str = "RdBu_r",
                norm: plt.Normalize | None = None,
                annotate: bool = True,
                colorbar: bool = True,
                xlabel: str | None = None,
                ylabel: str | None = None,
                ax: Optional[Axes] = None,
                cax: Optional[Axes] = None,
                figsize: tuple[float, float] = (8, 6),
            ) -> tuple[Figure, Axes, QuadMesh]:
                """Plot a table of coefficients as a heatmap.

                The table is transposed and rounded to two decimals before plotting.

                Args:
                    df: Coefficient table; it is plotted transposed.
                    title: Axes title.
                    cmap: Colormap name.
                    norm: Colour normalisation. If ``None``, a scale symmetric around zero is
                        built from the largest finite absolute value; it is 0 if there is none.
                    annotate: Forwarded to the heatmap helper.
                    colorbar: Forwarded to the heatmap helper.
                    xlabel: Forwarded to the heatmap helper.
                    ylabel: Forwarded to the heatmap helper.
                    ax: Target axes, forwarded to the heatmap helper.
                    cax: Colorbar axes, forwarded to the heatmap helper.
                    figsize: Forwarded to the heatmap helper.

                Returns:
                    ``(fig, ax, mesh)`` as returned by the heatmap helper. The x tick labels
                    are rotated by 45 degrees.
                """
                df = df.T.round(2)
                if norm is None:
                    # Ignore NaN and +/-inf (e.g. an elasticity divided by a zero flux).
                    # Otherwise the limit itself becomes NaN/inf and the colour scale is
                    # unusable.
                    magnitudes = np.abs(df.to_numpy(dtype=float))
                    finite = magnitudes[np.isfinite(magnitudes)]
                    end = float(finite.max()) if finite.size else 0.0
                    norm = _get_norm(vmin=-end, vmax=end)

                fig, ax, hm = _heatmap_from_dataframe(
                    df=df,
                    title=title,
                    xlabel=xlabel,
                    ylabel=ylabel,
                    annotate=annotate,
                    colorbar=colorbar,
                    cmap=cmap,
                    norm=norm,
                    ax=ax,
                    cax=cax,
                    figsize=figsize,
                )
                ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha="right")
                return fig, ax, hm
         
     | 
| 
       226 
     | 
    
         
            -
             
     | 
| 
       227 
     | 
    
         
            -
             
     | 
| 
       228 
     | 
    
         
            -
            def plot_multiple(
                dfs: list[pd.DataFrame],
                titles: list[str],
                cmap: str = "RdBu_r",
                annotate: bool = True,
                colorbar: bool = True,
                figsize: tuple[float, float] | None = (20, 10),
                norm: plt.Normalize | None = None,
            ) -> tuple[Figure, np.ndarray]:
                """Plot several coefficient tables as heatmaps on a shared colour scale.

                The tables are laid out on a grid with two columns. Each one is drawn with
                ``plot_coefficient_heatmap`` without its own colorbar, and an optional single
                colorbar is attached to the last axes.

                Args:
                    dfs: Coefficient tables to plot (at least one).
                    titles: One title per table; ``zip`` stops at the shorter of the two lists.
                    cmap: Colormap name.
                    annotate: Forwarded to ``plot_coefficient_heatmap``.
                    colorbar: Whether to draw the shared colorbar.
                    figsize: Figure size. If ``None``, 4 inches per grid cell is used.
                    norm: Shared colour normalisation. If ``None``, a scale symmetric around
                        zero is built from the largest finite absolute value over all tables.

                Returns:
                    ``(fig, axs)``. ``axs`` is the 2-D array of Axes created with
                    ``squeeze=False``.

                Raises:
                    ValueError: If ``dfs`` is empty.
                """
                if not dfs:
                    msg = "plot_multiple requires at least one DataFrame"
                    raise ValueError(msg)

                if norm is None:
                    # ndarray.min()/max() propagate NaN, and response coefficients contain
                    # NaN wherever inf was replaced. Only finite values set the limit.
                    magnitudes = np.concatenate(
                        [np.abs(df.to_numpy(dtype=float)).ravel() for df in dfs]
                    )
                    finite = magnitudes[np.isfinite(magnitudes)]
                    end = float(finite.max()) if finite.size else 0.0
                    norm = _get_norm(vmin=-end, vmax=end)

                n_cols = 2
                n_rows = math.ceil(len(dfs) / n_cols)

                if figsize is None:
                    figsize = (n_cols * 4, n_rows * 4)

                # squeeze=False guarantees a 2-D ndarray of Axes (not a single Axes).
                fig, axs = plt.subplots(nrows=n_rows, ncols=n_cols, figsize=figsize, squeeze=False)
                for ax, df, title in zip(axs.ravel(), dfs, titles):
                    plot_coefficient_heatmap(
                        df=df,
                        title=title,
                        cmap=cmap,
                        annotate=annotate,
                        colorbar=False,
                        norm=norm,
                        ax=ax,
                    )

                # Add a single colorbar shared by all heatmaps
                if colorbar:
                    cb = fig.colorbar(
                        cm.ScalarMappable(norm=norm, cmap=cmap),
                        ax=axs.ravel()[-1],
                    )
                    cb.outline.set_linewidth(0)
                fig.tight_layout()
                return fig, axs
         
     |