mxlpy 0.18.0__py3-none-any.whl → 0.20.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51)
  1. mxlpy/__init__.py +13 -9
  2. mxlpy/compare.py +240 -0
  3. mxlpy/experimental/diff.py +16 -4
  4. mxlpy/fit.py +6 -11
  5. mxlpy/fns.py +37 -42
  6. mxlpy/identify.py +10 -3
  7. mxlpy/integrators/__init__.py +4 -3
  8. mxlpy/integrators/int_assimulo.py +16 -9
  9. mxlpy/integrators/int_scipy.py +13 -9
  10. mxlpy/label_map.py +7 -3
  11. mxlpy/linear_label_map.py +4 -2
  12. mxlpy/mc.py +5 -14
  13. mxlpy/mca.py +4 -4
  14. mxlpy/meta/__init__.py +6 -4
  15. mxlpy/meta/codegen_latex.py +180 -87
  16. mxlpy/meta/codegen_modebase.py +3 -1
  17. mxlpy/meta/codegen_py.py +11 -3
  18. mxlpy/meta/source_tools.py +9 -5
  19. mxlpy/model.py +187 -100
  20. mxlpy/nn/__init__.py +24 -5
  21. mxlpy/nn/_keras.py +92 -0
  22. mxlpy/nn/_torch.py +25 -18
  23. mxlpy/npe/__init__.py +21 -16
  24. mxlpy/npe/_keras.py +326 -0
  25. mxlpy/npe/_torch.py +56 -60
  26. mxlpy/parallel.py +5 -2
  27. mxlpy/parameterise.py +11 -3
  28. mxlpy/plot.py +205 -52
  29. mxlpy/report.py +33 -8
  30. mxlpy/sbml/__init__.py +3 -3
  31. mxlpy/sbml/_data.py +7 -6
  32. mxlpy/sbml/_export.py +8 -1
  33. mxlpy/sbml/_mathml.py +8 -7
  34. mxlpy/sbml/_name_conversion.py +5 -1
  35. mxlpy/scan.py +14 -19
  36. mxlpy/simulator.py +34 -31
  37. mxlpy/surrogates/__init__.py +25 -17
  38. mxlpy/surrogates/_keras.py +139 -0
  39. mxlpy/surrogates/_poly.py +25 -10
  40. mxlpy/surrogates/_qss.py +34 -0
  41. mxlpy/surrogates/_torch.py +50 -32
  42. mxlpy/symbolic/__init__.py +5 -3
  43. mxlpy/symbolic/strikepy.py +5 -2
  44. mxlpy/symbolic/symbolic_model.py +14 -5
  45. mxlpy/types.py +61 -120
  46. {mxlpy-0.18.0.dist-info → mxlpy-0.20.0.dist-info}/METADATA +25 -24
  47. mxlpy-0.20.0.dist-info/RECORD +55 -0
  48. mxlpy/nn/_tensorflow.py +0 -0
  49. mxlpy-0.18.0.dist-info/RECORD +0 -51
  50. {mxlpy-0.18.0.dist-info → mxlpy-0.20.0.dist-info}/WHEEL +0 -0
  51. {mxlpy-0.18.0.dist-info → mxlpy-0.20.0.dist-info}/licenses/LICENSE +0 -0
mxlpy/plot.py CHANGED
@@ -19,10 +19,41 @@ Functions:
 from __future__ import annotations
 
 import contextlib
+import itertools as it
+import math
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Any, Literal, cast, overload
 
+import numpy as np
+import pandas as pd
+import seaborn as sns
 from cycler import cycler
+from matplotlib import pyplot as plt
+from matplotlib.axes import Axes
+from matplotlib.colors import (
+    LogNorm,
+    Normalize,
+    SymLogNorm,
+    colorConverter,  # type: ignore
+)
+from matplotlib.figure import Figure
+from mpl_toolkits.mplot3d import Axes3D
+
+from mxlpy.label_map import LabelMapper
+
+if TYPE_CHECKING:
+    from collections.abc import Generator, Iterable, Iterator
+
+    from matplotlib.collections import QuadMesh
+    from numpy.typing import NDArray
+
+    from mxlpy.linear_label_map import LinearLabelMapper
+    from mxlpy.model import Model
+    from mxlpy.types import Array, ArrayLike
+
 
 __all__ = [
+    "Axs",
     "Color",
     "FigAx",
     "FigAxs",
@@ -31,7 +62,10 @@ __all__ = [
     "RGBA",
     "add_grid",
     "bars",
+    "bars_autogrouped",
+    "bars_grouped",
     "context",
+    "grid_labels",
     "grid_layout",
     "heatmap",
     "heatmap_from_2d_idx",
@@ -53,37 +87,45 @@ __all__ = [
     "violins_from_2d_idx",
 ]
 
-import itertools as it
-import math
-from typing import TYPE_CHECKING, Any, Literal, cast
 
-import numpy as np
-import pandas as pd
-import seaborn as sns
-from matplotlib import pyplot as plt
-from matplotlib.axes import Axes
-from matplotlib.colors import (
-    LogNorm,
-    Normalize,
-    SymLogNorm,
-    colorConverter,  # type: ignore
-)
-from matplotlib.figure import Figure
-from mpl_toolkits.mplot3d import Axes3D
+@dataclass
+class Axs:
+    """Convenience container axes."""
 
-from mxlpy.label_map import LabelMapper
+    axs: NDArray[np.object_]
 
-if TYPE_CHECKING:
-    from collections.abc import Generator, Iterable
+    def __iter__(self) -> Iterator[Axes]:
+        """Get flat axes."""
+        yield from cast(list[Axes], self.axs.flatten())
 
-    from matplotlib.collections import QuadMesh
+    def __len__(self) -> int:
+        """Length of axes."""
+        return len(self.axs.flatten())
+
+    @overload
+    def __getitem__(self, row_col: int) -> Axes: ...
+
+    @overload
+    def __getitem__(self, row_col: slice) -> NDArray[np.object_]: ...
+
+    @overload
+    def __getitem__(self, row_col: tuple[int, int]) -> Axes: ...
+
+    @overload
+    def __getitem__(self, row_col: tuple[slice, int]) -> NDArray[np.object_]: ...
+
+    @overload
+    def __getitem__(self, row_col: tuple[int, slice]) -> NDArray[np.object_]: ...
+
+    def __getitem__(
+        self, row_col: int | slice | tuple[int | slice, int | slice]
+    ) -> Axes | NDArray[np.object_]:
+        """Get Axes or Array of Axes."""
+        return cast(Axes, self.axs[row_col])
 
-    from mxlpy.linear_label_map import LinearLabelMapper
-    from mxlpy.model import Model
-    from mxlpy.types import Array, ArrayLike
 
 type FigAx = tuple[Figure, Axes]
-type FigAxs = tuple[Figure, list[Axes]]
+type FigAxs = tuple[Figure, Axs]
 
 type Linestyle = Literal[
     "solid",
@@ -97,6 +139,7 @@ type RGB = tuple[float, float, float]
 type RGBA = tuple[float, float, float, float]
 type Color = str | RGB | RGBA
 
+
 ##########################################################################
 # Helpers
 ##########################################################################
@@ -158,7 +201,12 @@ def _partition_by_order_of_magnitude(s: pd.Series) -> list[list[str]]:
     """Partition a series into groups based on the order of magnitude of the values."""
     return [
        i.to_list()
-        for i in np.floor(np.log10(s)).to_frame(name=0).groupby(0)[0].groups.values()  # type: ignore
+        for i in s.abs()
+        .apply(np.log10)
+        .apply(np.floor)
+        .to_frame(name=0)
+        .groupby(0)[0]
+        .groups.values()  # type: ignore
     ]
 
 
@@ -258,6 +306,18 @@ def add_grid(ax: Axes) -> Axes:
     return ax
 
 
+def grid_labels(
+    axs: Axs,
+    xlabel: str | None = None,
+    ylabel: str | None = None,
+) -> None:
+    """Apply labels to left and bottom axes."""
+    for ax in axs[-1, :]:
+        ax.set_xlabel(xlabel)
+    for ax in axs[:, 0]:
+        ax.set_ylabel(ylabel)
+
+
 def rotate_xlabels(
     ax: Axes,
     rotation: float = 45,
@@ -367,7 +427,6 @@ def _default_fig_ax(
 
 
 def _default_fig_axs(
-    axs: list[Axes] | None,
     *,
     ncols: int,
     nrows: int,
@@ -391,19 +450,16 @@ def _default_fig_axs(
        Figure and Axes objects for the plot.
 
     """
-    if axs is None or len(axs) == 0:
-        fig, axs_array = plt.subplots(
-            nrows=nrows,
-            ncols=ncols,
-            sharex=sharex,
-            sharey=sharey,
-            figsize=figsize,
-            squeeze=False,
-            layout="constrained",
-        )
-        axs = list(axs_array.flatten())
-    else:
-        fig = cast(Figure, axs[0].get_figure())
+    fig, axs_array = plt.subplots(
+        nrows=nrows,
+        ncols=ncols,
+        sharex=sharex,
+        sharey=sharey,
+        figsize=figsize,
+        squeeze=False,
+        layout="constrained",
+    )
+    axs = Axs(axs_array)
 
     if grid:
         for ax in axs:
@@ -433,7 +489,6 @@ def two_axes(
 ) -> FigAxs:
     """Create a figure with two axes."""
     return _default_fig_axs(
-        None,
        ncols=2,
        nrows=1,
        figsize=figsize,
@@ -448,18 +503,17 @@ def grid_layout(
     *,
     n_cols: int = 2,
     col_width: float = 3,
-    row_height: float = 4,
+    row_height: float = 2.5,
     sharex: bool = True,
     sharey: bool = False,
     grid: bool = True,
-) -> tuple[Figure, list[Axes]]:
+) -> FigAxs:
     """Create a grid layout for the given number of groups."""
     n_cols = min(n_groups, n_cols)
     n_rows = math.ceil(n_groups / n_cols)
     figsize = (n_cols * col_width, n_rows * row_height)
 
     return _default_fig_axs(
-        None,
        ncols=n_cols,
        nrows=n_rows,
        figsize=figsize,
@@ -475,19 +529,103 @@ def grid_layout(
 
 
 def bars(
-    x: pd.DataFrame,
+    x: pd.Series | pd.DataFrame,
     *,
     ax: Axes | None = None,
     grid: bool = True,
+    xlabel: str | None = None,
+    ylabel: str | None = None,
 ) -> FigAx:
     """Plot multiple lines on the same axis."""
     fig, ax = _default_fig_ax(ax=ax, grid=grid)
-    sns.barplot(data=x, ax=ax)
-    _default_labels(ax, xlabel=x.index.name, ylabel=None)
-    ax.legend(x.columns)
+    sns.barplot(data=cast(pd.DataFrame, x), ax=ax)
+
+    if xlabel is None:
+        xlabel = x.index.name if x.index.name is not None else ""
+    _default_labels(ax, xlabel=xlabel, ylabel=ylabel)
+    if isinstance(x, pd.DataFrame):
+        ax.legend(x.columns)
     return fig, ax
 
 
+def bars_grouped(
+    groups: list[pd.DataFrame] | list[pd.Series],
+    *,
+    n_cols: int = 2,
+    col_width: float = 3,
+    row_height: float = 4,
+    sharey: bool = False,
+    grid: bool = True,
+    xlabel: str | None = None,
+    ylabel: str | None = None,
+) -> FigAxs:
+    """Plot multiple groups of lines on separate axes."""
+    fig, axs = grid_layout(
+        len(groups),
+        n_cols=n_cols,
+        col_width=col_width,
+        row_height=row_height,
+        sharex=False,
+        sharey=sharey,
+        grid=grid,
+    )
+
+    for group, ax in zip(
+        groups,
+        axs,
+        strict=False,
+    ):
+        bars(
+            group,
+            ax=ax,
+            grid=grid,
+            xlabel=xlabel,
+            ylabel=ylabel,
+        )
+
+    axsl = list(axs)
+    for i in range(len(groups), len(axs)):
+        axsl[i].set_visible(False)
+
+    return fig, axs
+
+
+def bars_autogrouped(
+    s: pd.Series | pd.DataFrame,
+    *,
+    n_cols: int = 2,
+    col_width: float = 4,
+    row_height: float = 3,
+    max_group_size: int = 6,
+    grid: bool = True,
+    xlabel: str | None = None,
+    ylabel: str | None = None,
+) -> FigAxs:
+    """Plot a series or dataframe with lines grouped by order of magnitude."""
+    group_names = _split_large_groups(
+        _partition_by_order_of_magnitude(s)
+        if isinstance(s, pd.Series)
+        else _partition_by_order_of_magnitude(s.max()),
+        max_size=max_group_size,
+    )
+
+    groups: list[pd.Series] | list[pd.DataFrame] = (
+        [s.loc[group] for group in group_names]
+        if isinstance(s, pd.Series)
+        else [s.loc[:, group] for group in group_names]
+    )
+
+    return bars_grouped(
+        groups,
+        n_cols=n_cols,
+        col_width=col_width,
+        row_height=row_height,
+        grid=grid,
+        xlabel=xlabel,
+        ylabel=ylabel,
+    )
+
+
 def lines(
     x: pd.DataFrame | pd.Series,
     *,
@@ -498,6 +636,8 @@ def lines(
     legend: bool = True,
     linewidth: float | None = None,
     linestyle: Linestyle | None = None,
+    xlabel: str | None = None,
+    ylabel: str | None = None,
 ) -> FigAx:
     """Plot multiple lines on the same axis."""
     fig, ax = _default_fig_ax(ax=ax, grid=grid)
@@ -509,7 +649,11 @@ def lines(
        linestyle=linestyle,
        color=color,
     )
-    _default_labels(ax, xlabel=x.index.name, ylabel=None)
+    _default_labels(
+        ax,
+        xlabel=x.index.name if xlabel is None else xlabel,
+        ylabel=ylabel,
+    )
     if legend:
        names = x.columns if isinstance(x, pd.DataFrame) else [str(x.name)]
        for line, name in zip(_lines, names, strict=True):
@@ -533,6 +677,8 @@ def lines_grouped(
     sharex: bool = True,
     sharey: bool = False,
     grid: bool = True,
+    xlabel: str | None = None,
+    ylabel: str | None = None,
     color: Color | list[list[Color]] | None = None,
     linewidth: float | None = None,
     linestyle: Linestyle | None = None,
@@ -561,10 +707,13 @@ def lines_grouped(
            color=color_,
            linewidth=linewidth,
            linestyle=linestyle,
+            xlabel=xlabel,
+            ylabel=ylabel,
        )
 
+    axsl = list(axs)
     for i in range(len(groups), len(axs)):
-        axs[i].set_visible(False)
+        axsl[i].set_visible(False)
 
     return fig, axs
 
@@ -577,6 +726,8 @@ def line_autogrouped(
     row_height: float = 3,
     max_group_size: int = 6,
     grid: bool = True,
+    xlabel: str | None = None,
+    ylabel: str | None = None,
     color: Color | list[list[Color]] | None = None,
     linewidth: float | None = None,
     linestyle: Linestyle | None = None,
@@ -604,6 +755,8 @@ def line_autogrouped(
        color=color,
        linestyle=linestyle,
        linewidth=linewidth,
+        xlabel=xlabel,
+        ylabel=ylabel,
     )
 
 
@@ -712,11 +865,11 @@ def _create_heatmap(
     if title is not None:
        ax.set_title(title)
     ax.set_xticks(
-        np.arange(0, len(df.columns), 1) + 0.5,
+        np.arange(0, len(df.columns), 1, dtype=float) + 0.5,
        labels=xticklabels,
     )
     ax.set_yticks(
-        np.arange(0, len(df.index), 1) + 0.5,
+        np.arange(0, len(df.index), 1, dtype=float) + 0.5,
        labels=yticklabels,
     )
 
@@ -886,7 +1039,7 @@ def violins_from_2d_idx(
        grid=grid,
     )
 
-    for ax, col in zip(axs[: len(df.columns)], df.columns, strict=True):
+    for ax, col in zip(axs[: len(df.columns)].flatten(), df.columns, strict=True):
        ax.set_title(col)
        violins(df[col].unstack(), ax=ax)
 
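The plot module now returns an Axs container from its grid helpers and gains grouped bar plots (bars_grouped, bars_autogrouped) plus grid_labels. A minimal usage sketch against the 0.20.0 API shown above; the parameter values are made up:

import pandas as pd

from mxlpy import plot

# made-up values spanning several orders of magnitude
s = pd.Series({"k1": 0.12, "k2": 130.0, "k3": 0.35, "k4": 95.0})

# groups the bars by order of magnitude; returns (Figure, Axs)
fig, axs = plot.bars_autogrouped(s, n_cols=2)

# Axs supports flat iteration, len() and 2D indexing
for ax in axs:
    ax.tick_params(labelsize=8)
first = axs[0, 0]

# label only the bottom row / left column of the grid
plot.grid_labels(axs, xlabel="parameter", ylabel="value")

fig.savefig("parameters.png")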
mxlpy/report.py CHANGED
@@ -1,5 +1,7 @@
 """Generate a report comparing two models."""
 
+from __future__ import annotations
+
 from collections.abc import Callable
 from datetime import UTC, datetime
 from pathlib import Path
@@ -10,7 +12,10 @@ import sympy
 from mxlpy.meta.source_tools import fn_to_sympy
 from mxlpy.model import Model
 
-__all__ = ["AnalysisFn", "markdown"]
+__all__ = [
+    "AnalysisFn",
+    "markdown",
+]
 
 type AnalysisFn = Callable[[Model, Model, Path], tuple[str, Path]]
 
@@ -84,9 +89,29 @@ def markdown(
 
     """
     content: list[str] = [
-        f"# Report: {datetime.now(UTC).strftime('%Y-%m-%d')}",
+        f"# Report: {datetime.now(UTC).strftime('%Y-%m-%d')}\n",
     ]
 
+    # Unused
+    if unused := m2.get_unused_parameters():
+        content.append("## <span style='color: red'>Unused parameters</span>\n")
+        names = "\n".join(f"<li>{i}</li>\n" for i in sorted(unused))
+        content.append(f"<ul>\n{names}\n</ul>\n")
+
+    # Model stats
+    content.extend(
+        [
+            "| Model component | Old | New |",
+            "| --- | --- | --- |",
+            f"| variables | {len(m1.variables)} | {len(m2.variables)}|",
+            f"| parameters | {len(m1.parameters)} | {len(m2.parameters)}|",
+            f"| derived parameters | {len(m1.derived_parameters)} | {len(m2.derived_parameters)}|",
+            f"| derived variables | {len(m1.derived_variables)} | {len(m2.derived_variables)}|",
+            f"| reactions | {len(m1.reactions)} | {len(m2.reactions)}|",
+            f"| surrogates | {len(m1._surrogates)} | {len(m2._surrogates)}|",  # noqa: SLF001
+        ]
+    )
+
     # Variables
     new_variables, removed_variables, changed_variables = _new_removed_changed(
        m1.variables, m2.variables
@@ -106,7 +131,7 @@
     if len(variables) >= 1:
        content.extend(
            (
-                "## Variables\n",
+                "## Variables\n\n",
                "| Name | Old Value | New Value |",
                "| ---- | --------- | --------- |",
            )
@@ -132,7 +157,7 @@
     if len(pars) >= 1:
        content.extend(
            (
-                "## Parameters\n",
+                "## Parameters\n\n",
                "| Name | Old Value | New Value |",
                "| ---- | --------- | --------- |",
            )
@@ -159,7 +184,7 @@
     if len(derived) >= 1:
        content.extend(
            (
-                "## Derived\n",
+                "## Derived\n\n",
                "| Name | Old Value | New Value |",
                "| ---- | --------- | --------- |",
            )
@@ -187,7 +212,7 @@
     if len(reactions) >= 1:
        content.extend(
            (
-                "## Reactions\n",
+                "## Reactions\n\n",
                "| Name | Old Value | New Value |",
                "| ---- | --------- | --------- |",
            )
@@ -207,7 +232,7 @@
     if len(dependent) >= 1:
        content.extend(
            (
-                "## Numerical differences of dependent values\n",
+                "## Numerical differences of dependent values\n\n",
                "| Name | Old Value | New Value | Relative Change | ",
                "| ---- | --------- | --------- | --------------- | ",
            )
@@ -226,7 +251,7 @@
     if len(rhs) >= 1:
        content.extend(
            (
-                "## Numerical differences of right hand side values\n",
+                "## Numerical differences of right hand side values\n\n",
                "| Name | Old Value | New Value | Relative Change | ",
                "| ---- | --------- | --------- | --------------- | ",
            )
mxlpy/sbml/__init__.py CHANGED
@@ -5,10 +5,10 @@ Allows importing and exporting metabolic models in SBML format.
 
 from __future__ import annotations
 
+from ._export import write
+from ._import import read
+
 __all__ = [
     "read",
     "write",
 ]
-
-from ._export import write
-from ._import import read
mxlpy/sbml/_data.py CHANGED
@@ -1,5 +1,12 @@
 from __future__ import annotations
 
+from dataclasses import dataclass
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from collections.abc import Mapping
+
+
 __all__ = [
     "AtomicUnit",
     "Compartment",
@@ -11,12 +18,6 @@ __all__ = [
     "Reaction",
 ]
 
-from dataclasses import dataclass
-from typing import TYPE_CHECKING
-
-if TYPE_CHECKING:
-    from collections.abc import Mapping
-
 
 @dataclass
 class AtomicUnit:
mxlpy/sbml/_export.py CHANGED
@@ -447,7 +447,14 @@ def _create_sbml_variables(
        cpd.setConstant(False)
        cpd.setBoundaryCondition(False)
        cpd.setHasOnlySubstanceUnits(False)
-        cpd.setInitialAmount(float(value))
+        if isinstance(value, Derived):
+            ar = sbml_model.createInitialAssignment()
+            ar.setId(_convert_id_to_sbml(id_=name, prefix="IA"))
+            ar.setName(_convert_id_to_sbml(id_=name, prefix="IA"))
+            ar.setVariable(_convert_id_to_sbml(id_=name, prefix="IA"))
+            ar.setMath(_sbmlify_fn(value.fn, value.args))
+        else:
+            cpd.setInitialAmount(float(value))
 
 
 def _create_sbml_derived_variables(*, model: Model, sbml_model: libsbml.Model) -> None:
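With this change, a variable whose initial value is a Derived expression is exported as an SBML InitialAssignment rather than a fixed initialAmount. The construct that ends up in the SBML file looks roughly like the plain-libsbml sketch below; the ids and the formula are invented, and mxlpy itself goes through its internal _convert_id_to_sbml/_sbmlify_fn helpers instead:

import libsbml

doc = libsbml.SBMLDocument(3, 2)
model = doc.createModel()

comp = model.createCompartment()
comp.setId("cell")
comp.setConstant(True)

# a species whose initial amount is computed from other quantities
spec = model.createSpecies()
spec.setId("x")
spec.setCompartment("cell")
spec.setConstant(False)
spec.setBoundaryCondition(False)
spec.setHasOnlySubstanceUnits(False)

ia = model.createInitialAssignment()
ia.setSymbol("x")  # the entity being initialised
ia.setMath(libsbml.parseL3Formula("2.0 * k_fwd"))  # made-up derived expression

print(libsbml.writeSBMLToString(doc))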
mxlpy/sbml/_mathml.py CHANGED
@@ -1,5 +1,13 @@
 from __future__ import annotations
 
+from typing import TYPE_CHECKING, Any
+
+from ._name_conversion import _name_to_py
+from ._unit_conversion import get_ast_types
+
+if TYPE_CHECKING:
+    from libsbml import ASTNode
+
 __all__ = [
     "AST_TYPES",
     "handle_ast_constant_e",
@@ -73,13 +81,6 @@ __all__ = [
     "parse_sbml_math",
 ]
 
-from typing import TYPE_CHECKING, Any
-
-from ._name_conversion import _name_to_py
-from ._unit_conversion import get_ast_types
-
-if TYPE_CHECKING:
-    from libsbml import ASTNode
 
 AST_TYPES = get_ast_types()
 
mxlpy/sbml/_name_conversion.py CHANGED
@@ -3,7 +3,11 @@ from __future__ import annotations
 import keyword
 import re
 
-__all__ = ["RE_FROM_SBML", "RE_KWDS", "SBML_DOT"]
+__all__ = [
+    "RE_FROM_SBML",
+    "RE_KWDS",
+    "SBML_DOT",
+]
 
 RE_KWDS = re.compile("|".join(f"^{i}$" for i in keyword.kwlist))
 SBML_DOT = "__SBML_DOT__"
mxlpy/scan.py CHANGED
@@ -15,19 +15,6 @@ Functions:
 
 from __future__ import annotations
 
-from mxlpy.integrators import DefaultIntegrator
-
-__all__ = [
-    "ProtocolWorker",
-    "SteadyStateWorker",
-    "TimeCourse",
-    "TimeCourseWorker",
-    "TimePoint",
-    "steady_state",
-    "time_course",
-    "time_course_over_protocol",
-]
-
 from dataclasses import dataclass
 from functools import partial
 from typing import TYPE_CHECKING, Protocol, Self, cast
@@ -35,14 +22,10 @@ from typing import TYPE_CHECKING, Protocol, Self, cast
 import numpy as np
 import pandas as pd
 
+from mxlpy.integrators import DefaultIntegrator
 from mxlpy.parallel import Cache, parallelise
 from mxlpy.simulator import Result, Simulator
-from mxlpy.types import (
-    IntegratorType,
-    ProtocolByPars,
-    SteadyStates,
-    TimeCourseByPars,
-)
+from mxlpy.types import IntegratorType, ProtocolByPars, SteadyStates, TimeCourseByPars
 
 if TYPE_CHECKING:
     from collections.abc import Callable
@@ -51,6 +34,18 @@ if TYPE_CHECKING:
     from mxlpy.types import Array
 
 
+__all__ = [
+    "ProtocolWorker",
+    "SteadyStateWorker",
+    "TimeCourse",
+    "TimeCourseWorker",
+    "TimePoint",
+    "steady_state",
+    "time_course",
+    "time_course_over_protocol",
+]
+
+
 def _update_parameters_and_initial_conditions[T](
     pars: pd.Series,
     fn: Callable[[Model], T],