mxlpy 0.22.0__py3-none-any.whl → 0.24.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
mxlpy/plot.py CHANGED
@@ -28,6 +28,7 @@ import numpy as np
28
28
  import pandas as pd
29
29
  import seaborn as sns
30
30
  from cycler import cycler
31
+ from matplotlib import colormaps
31
32
  from matplotlib import pyplot as plt
32
33
  from matplotlib.axes import Axes
33
34
  from matplotlib.colors import (
@@ -37,7 +38,10 @@ from matplotlib.colors import (
37
38
  colorConverter, # type: ignore
38
39
  )
39
40
  from matplotlib.figure import Figure
41
+ from matplotlib.legend import Legend
42
+ from matplotlib.patches import Patch
40
43
  from mpl_toolkits.mplot3d import Axes3D
44
+ from wadler_lindig import pformat
41
45
 
42
46
  from mxlpy.label_map import LabelMapper
43
47
 
@@ -102,6 +106,10 @@ class Axs:
102
106
  """Length of axes."""
103
107
  return len(self.axs.flatten())
104
108
 
109
def __repr__(self) -> str:
    """Build a pretty-printed representation via ``wadler_lindig.pformat``."""
    # NOTE(review): pformat renders supported types (e.g. dataclasses)
    # structurally; for unsupported objects it may fall back to ``repr(obj)``,
    # which would re-enter this method. Confirm that Axs is a type pformat
    # handles natively.
    text = pformat(self)
    return text
112
+
105
113
  @overload
106
114
  def __getitem__(self, row_col: int) -> Axes: ...
107
115
 
@@ -210,6 +218,28 @@ def _partition_by_order_of_magnitude(s: pd.Series) -> list[list[str]]:
210
218
  ]
211
219
 
212
220
 
221
+ def _combine_small_groups(
222
+ groups: list[list[str]], min_group_size: int
223
+ ) -> list[list[str]]:
224
+ """Combine smaller groups."""
225
+ result = []
226
+ current_group = groups[0]
227
+
228
+ for next_group in groups[1:]:
229
+ if len(current_group) < min_group_size:
230
+ current_group.extend(next_group)
231
+ else:
232
+ result.append(current_group)
233
+ current_group = next_group
234
+
235
+ # Last group
236
+ if len(current_group) < min_group_size:
237
+ result[-1].extend(current_group)
238
+ else:
239
+ result.append(current_group)
240
+ return result
241
+
242
+
213
243
  def _split_large_groups[T](groups: list[list[T]], max_size: int) -> list[list[T]]:
214
244
  """Split groups larger than the given size into smaller groups."""
215
245
  return list(
@@ -513,7 +543,7 @@ def grid_layout(
513
543
  n_rows = math.ceil(n_groups / n_cols)
514
544
  figsize = (n_cols * col_width, n_rows * row_height)
515
545
 
516
- return _default_fig_axs(
546
+ fig, axs = _default_fig_axs(
517
547
  ncols=n_cols,
518
548
  nrows=n_rows,
519
549
  figsize=figsize,
@@ -522,6 +552,12 @@ def grid_layout(
522
552
  grid=grid,
523
553
  )
524
554
 
555
+ # Disable unused plots by default
556
+ axsl = list(axs)
557
+ for i in range(n_groups, len(axs)):
558
+ axsl[i].set_visible(False)
559
+ return fig, axs
560
+
525
561
 
526
562
  ##########################################################################
527
563
  # Plots
@@ -541,7 +577,7 @@ def bars(
541
577
  sns.barplot(data=cast(pd.DataFrame, x), ax=ax)
542
578
 
543
579
  if xlabel is None:
544
- xlabel = x.index.name if x.index.name is not None else ""
580
+ xlabel = x.index.name if x.index.name is not None else "" # type: ignore
545
581
  _default_labels(ax, xlabel=xlabel, ylabel=ylabel)
546
582
  if isinstance(x, pd.DataFrame):
547
583
  ax.legend(x.columns)
@@ -583,10 +619,6 @@ def bars_grouped(
583
619
  ylabel=ylabel,
584
620
  )
585
621
 
586
- axsl = list(axs)
587
- for i in range(len(groups), len(axs)):
588
- axsl[i].set_visible(False)
589
-
590
622
  return fig, axs
591
623
 
592
624
 
@@ -596,18 +628,20 @@ def bars_autogrouped(
596
628
  n_cols: int = 2,
597
629
  col_width: float = 4,
598
630
  row_height: float = 3,
631
+ min_group_size: int = 1,
599
632
  max_group_size: int = 6,
600
633
  grid: bool = True,
601
634
  xlabel: str | None = None,
602
635
  ylabel: str | None = None,
603
636
  ) -> FigAxs:
604
637
  """Plot a series or dataframe with lines grouped by order of magnitude."""
605
- group_names = _split_large_groups(
638
+ group_names = (
606
639
  _partition_by_order_of_magnitude(s)
607
640
  if isinstance(s, pd.Series)
608
- else _partition_by_order_of_magnitude(s.max()),
609
- max_size=max_group_size,
641
+ else _partition_by_order_of_magnitude(s.max())
610
642
  )
643
+ group_names = _combine_small_groups(group_names, min_group_size=min_group_size)
644
+ group_names = _split_large_groups(group_names, max_size=max_group_size)
611
645
 
612
646
  groups: list[pd.Series] | list[pd.DataFrame] = (
613
647
  [s.loc[group] for group in group_names]
@@ -651,7 +685,7 @@ def lines(
651
685
  )
652
686
  _default_labels(
653
687
  ax,
654
- xlabel=x.index.name if xlabel is None else xlabel,
688
+ xlabel=x.index.name if xlabel is None else xlabel, # type: ignore
655
689
  ylabel=ylabel,
656
690
  )
657
691
  if legend:
@@ -711,10 +745,6 @@ def lines_grouped(
711
745
  ylabel=ylabel,
712
746
  )
713
747
 
714
- axsl = list(axs)
715
- for i in range(len(groups), len(axs)):
716
- axsl[i].set_visible(False)
717
-
718
748
  return fig, axs
719
749
 
720
750
 
@@ -724,6 +754,7 @@ def line_autogrouped(
724
754
  n_cols: int = 2,
725
755
  col_width: float = 4,
726
756
  row_height: float = 3,
757
+ min_group_size: int = 1,
727
758
  max_group_size: int = 6,
728
759
  grid: bool = True,
729
760
  xlabel: str | None = None,
@@ -733,12 +764,13 @@ def line_autogrouped(
733
764
  linestyle: Linestyle | None = None,
734
765
  ) -> FigAxs:
735
766
  """Plot a series or dataframe with lines grouped by order of magnitude."""
736
- group_names = _split_large_groups(
767
+ group_names = (
737
768
  _partition_by_order_of_magnitude(s)
738
769
  if isinstance(s, pd.Series)
739
- else _partition_by_order_of_magnitude(s.max()),
740
- max_size=max_group_size,
770
+ else _partition_by_order_of_magnitude(s.max())
741
771
  )
772
+ group_names = _combine_small_groups(group_names, min_group_size=min_group_size)
773
+ group_names = _split_large_groups(group_names, max_size=max_group_size)
742
774
 
743
775
  groups: list[pd.Series] | list[pd.DataFrame] = (
744
776
  [s.loc[group] for group in group_names]
@@ -792,7 +824,11 @@ def line_mean_std(
792
824
  color=color,
793
825
  alpha=alpha,
794
826
  )
795
- _default_labels(ax, xlabel=df.index.name, ylabel=None)
827
+ _default_labels(
828
+ ax,
829
+ xlabel=df.index.name, # type: ignore
830
+ ylabel=None,
831
+ )
796
832
  return fig, ax
797
833
 
798
834
 
@@ -865,11 +901,11 @@ def _create_heatmap(
865
901
  if title is not None:
866
902
  ax.set_title(title)
867
903
  ax.set_xticks(
868
- np.arange(0, len(df.columns), 1, dtype=float) + 0.5,
904
+ np.arange(0, len(df.columns), 1, dtype=float) + 0.5, # type: ignore
869
905
  labels=xticklabels,
870
906
  )
871
907
  ax.set_yticks(
872
- np.arange(0, len(df.index), 1, dtype=float) + 0.5,
908
+ np.arange(0, len(df.index), 1, dtype=float) + 0.5, # type: ignore
873
909
  labels=yticklabels,
874
910
  )
875
911
 
@@ -906,8 +942,8 @@ def heatmap(
906
942
  ax=ax,
907
943
  df=df,
908
944
  title=title,
909
- xlabel=df.index.name,
910
- ylabel=df.columns.name,
945
+ xlabel=df.index.name, # type: ignore
946
+ ylabel=df.columns.name, # type: ignore
911
947
  xticklabels=cast(list, df.columns),
912
948
  yticklabels=cast(list, df.index),
913
949
  annotate=annotate,
@@ -943,8 +979,8 @@ def heatmap_from_2d_idx(
943
979
 
944
980
  return _create_heatmap(
945
981
  df=df2d,
946
- xlabel=df2d.index.name,
947
- ylabel=df2d.columns.name,
982
+ xlabel=df2d.index.name, # type: ignore
983
+ ylabel=df2d.columns.name, # type: ignore
948
984
  xticklabels=[f"{i:.2f}" for i in df2d.columns],
949
985
  yticklabels=[f"{i:.2f}" for i in df2d.index],
950
986
  ax=ax,
@@ -1064,11 +1100,6 @@ def shade_protocol(
1064
1100
  add_legend: bool = True,
1065
1101
  ) -> None:
1066
1102
  """Shade the given protocol on the given axis."""
1067
- from matplotlib import colormaps
1068
- from matplotlib.colors import Normalize
1069
- from matplotlib.legend import Legend
1070
- from matplotlib.patches import Patch
1071
-
1072
1103
  cmap = colormaps[cmap_name]
1073
1104
  norm = Normalize(
1074
1105
  vmin=protocol.min() if vmin is None else vmin,
mxlpy/sbml/_data.py CHANGED
@@ -3,6 +3,8 @@ from __future__ import annotations
3
3
  from dataclasses import dataclass
4
4
  from typing import TYPE_CHECKING
5
5
 
6
+ from wadler_lindig import pformat
7
+
6
8
  if TYPE_CHECKING:
7
9
  from collections.abc import Mapping
8
10
 
@@ -26,18 +28,30 @@ class AtomicUnit:
26
28
  scale: int
27
29
  multiplier: float
28
30
 
31
def __repr__(self) -> str:
    """Pretty-print this unit using ``wadler_lindig.pformat``."""
    formatted = pformat(self)
    return formatted
34
+
29
35
 
30
36
  @dataclass
31
37
  class CompositeUnit:
32
38
  sbml_id: str
33
39
  units: list
34
40
 
41
+ def __repr__(self) -> str:
42
+ """Return default representation."""
43
+ return pformat(self)
44
+
35
45
 
36
46
  @dataclass
37
47
  class Parameter:
38
48
  value: float
39
49
  is_constant: bool
40
50
 
51
+ def __repr__(self) -> str:
52
+ """Return default representation."""
53
+ return pformat(self)
54
+
41
55
 
42
56
  @dataclass
43
57
  class Compartment:
@@ -47,6 +61,10 @@ class Compartment:
47
61
  units: str
48
62
  is_constant: bool
49
63
 
64
def __repr__(self) -> str:
    """Pretty-print this compartment using ``wadler_lindig.pformat``."""
    formatted = pformat(self)
    return formatted
67
+
50
68
 
51
69
  @dataclass
52
70
  class Compound:
@@ -58,21 +76,37 @@ class Compound:
58
76
  is_constant: bool
59
77
  is_concentration: bool
60
78
 
79
def __repr__(self) -> str:
    """Pretty-print this compound using ``wadler_lindig.pformat``."""
    formatted = pformat(self)
    return formatted
82
+
61
83
 
62
84
  @dataclass
63
85
  class Derived:
64
86
  body: str
65
87
  args: list[str]
66
88
 
89
+ def __repr__(self) -> str:
90
+ """Return default representation."""
91
+ return pformat(self)
92
+
67
93
 
68
94
  @dataclass
69
95
  class Function:
70
96
  body: str
71
97
  args: list[str]
72
98
 
99
+ def __repr__(self) -> str:
100
+ """Return default representation."""
101
+ return pformat(self)
102
+
73
103
 
74
104
  @dataclass
75
105
  class Reaction:
76
106
  body: str
77
107
  stoichiometry: Mapping[str, float | str]
78
108
  args: list[str]
109
+
110
+ def __repr__(self) -> str:
111
+ """Return default representation."""
112
+ return pformat(self)
mxlpy/sbml/_export.py CHANGED
@@ -10,7 +10,7 @@ import numpy as np
10
10
 
11
11
  from mxlpy.meta.source_tools import get_fn_ast
12
12
  from mxlpy.sbml._data import AtomicUnit, Compartment
13
- from mxlpy.types import Derived
13
+ from mxlpy.types import Derived, InitialAssignment
14
14
 
15
15
  if TYPE_CHECKING:
16
16
  from collections.abc import Callable
@@ -447,15 +447,16 @@ def _create_sbml_variables(
447
447
  cpd.setConstant(False)
448
448
  cpd.setBoundaryCondition(False)
449
449
  cpd.setHasOnlySubstanceUnits(False)
450
+ cpd.setCompartment("compartment")
450
451
  # cpd.setUnit() # FIXME: implement
451
- if isinstance((init := variable.initial_value), Derived):
452
+ if isinstance((init := variable.initial_value), InitialAssignment):
452
453
  ar = sbml_model.createInitialAssignment()
453
454
  ar.setId(_convert_id_to_sbml(id_=name, prefix="IA"))
454
455
  ar.setName(_convert_id_to_sbml(id_=name, prefix="IA"))
455
456
  ar.setVariable(_convert_id_to_sbml(id_=name, prefix="IA"))
456
457
  ar.setMath(_sbmlify_fn(init.fn, init.args))
457
458
  else:
458
- cpd.setInitialAmount(float(init))
459
+ cpd.setInitialConcentration(float(init))
459
460
 
460
461
 
461
462
  def _create_sbml_derived_variables(*, model: Model, sbml_model: libsbml.Model) -> None:
@@ -494,11 +495,19 @@ def _create_sbml_parameters(
494
495
  sbml_model : libsbml.Model
495
496
 
496
497
  """
497
- for parameter_id, value in model.get_parameter_values().items():
498
+ for name, value in model.get_raw_parameters().items():
498
499
  k = sbml_model.createParameter()
499
- k.setId(_convert_id_to_sbml(id_=parameter_id, prefix="PAR"))
500
+ k.setId(_convert_id_to_sbml(id_=name, prefix="PAR"))
500
501
  k.setConstant(True)
501
- k.setValue(float(value))
502
+
503
+ if isinstance((init := value.value), InitialAssignment):
504
+ ar = sbml_model.createInitialAssignment()
505
+ ar.setId(_convert_id_to_sbml(id_=name, prefix="IA"))
506
+ ar.setName(_convert_id_to_sbml(id_=name, prefix="IA"))
507
+ ar.setVariable(_convert_id_to_sbml(id_=name, prefix="IA"))
508
+ ar.setMath(_sbmlify_fn(init.fn, init.args))
509
+ else:
510
+ k.setValue(float(init))
502
511
 
503
512
 
504
513
  def _create_sbml_derived_parameters(*, model: Model, sbml_model: libsbml.Model) -> None:
@@ -583,8 +592,8 @@ def _default_compartments(
583
592
  ) -> dict[str, Compartment]:
584
593
  if compartments is None:
585
594
  return {
586
- "c": Compartment(
587
- name="cytosol",
595
+ "compartment": Compartment(
596
+ name="compartment",
588
597
  dimensions=3,
589
598
  size=1,
590
599
  units="litre",