mxlpy 0.17.0__py3-none-any.whl → 0.19.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mxlpy/__init__.py +14 -4
- mxlpy/experimental/diff.py +1 -1
- mxlpy/fit.py +173 -7
- mxlpy/identify.py +7 -1
- mxlpy/integrators/int_assimulo.py +10 -3
- mxlpy/integrators/int_scipy.py +7 -3
- mxlpy/label_map.py +3 -1
- mxlpy/meta/codegen_latex.py +1 -1
- mxlpy/meta/source_tools.py +1 -1
- mxlpy/model.py +146 -87
- mxlpy/nn/__init__.py +24 -5
- mxlpy/nn/_keras.py +85 -0
- mxlpy/nn/_torch.py +76 -15
- mxlpy/npe/__init__.py +21 -16
- mxlpy/npe/_keras.py +326 -0
- mxlpy/npe/_torch.py +73 -148
- mxlpy/plot.py +196 -52
- mxlpy/sbml/_export.py +8 -1
- mxlpy/surrogates/__init__.py +25 -17
- mxlpy/surrogates/_keras.py +137 -0
- mxlpy/surrogates/_poly.py +19 -8
- mxlpy/surrogates/_qss.py +31 -0
- mxlpy/surrogates/_torch.py +51 -127
- mxlpy/symbolic/symbolic_model.py +2 -2
- mxlpy/types.py +57 -114
- {mxlpy-0.17.0.dist-info → mxlpy-0.19.0.dist-info}/METADATA +27 -28
- mxlpy-0.19.0.dist-info/RECORD +54 -0
- mxlpy-0.19.0.dist-info/licenses/LICENSE +21 -0
- mxlpy/nn/_tensorflow.py +0 -0
- mxlpy-0.17.0.dist-info/RECORD +0 -51
- mxlpy-0.17.0.dist-info/licenses/LICENSE +0 -674
- {mxlpy-0.17.0.dist-info → mxlpy-0.19.0.dist-info}/WHEEL +0 -0
mxlpy/plot.py
CHANGED
@@ -23,9 +23,12 @@ import contextlib
|
|
23
23
|
from cycler import cycler
|
24
24
|
|
25
25
|
__all__ = [
|
26
|
+
"Color",
|
26
27
|
"FigAx",
|
27
28
|
"FigAxs",
|
28
29
|
"Linestyle",
|
30
|
+
"RGB",
|
31
|
+
"RGBA",
|
29
32
|
"add_grid",
|
30
33
|
"bars",
|
31
34
|
"context",
|
@@ -71,7 +74,7 @@ from mpl_toolkits.mplot3d import Axes3D
|
|
71
74
|
from mxlpy.label_map import LabelMapper
|
72
75
|
|
73
76
|
if TYPE_CHECKING:
|
74
|
-
from collections.abc import Generator
|
77
|
+
from collections.abc import Generator, Iterable
|
75
78
|
|
76
79
|
from matplotlib.collections import QuadMesh
|
77
80
|
|
@@ -89,6 +92,11 @@ type Linestyle = Literal[
|
|
89
92
|
"dashdot",
|
90
93
|
]
|
91
94
|
|
95
|
+
|
96
|
+
type RGB = tuple[float, float, float]
|
97
|
+
type RGBA = tuple[float, float, float, float]
|
98
|
+
type Color = str | RGB | RGBA
|
99
|
+
|
92
100
|
##########################################################################
|
93
101
|
# Helpers
|
94
102
|
##########################################################################
|
@@ -133,6 +141,13 @@ def _get_norm(vmin: float, vmax: float) -> Normalize:
|
|
133
141
|
return norm
|
134
142
|
|
135
143
|
|
144
|
+
def _norm(df: pd.DataFrame) -> Normalize:
    """Build a normalization spanning the smallest to the largest entry of `df`.

    Args:
        df: data whose overall minimum and maximum define the range.

    Returns:
        The object produced by `_get_norm` for that range.

    """
    # Reduce per column first, then across the per-column extrema.
    lowest = df.min().min()
    highest = df.max().max()
    return _get_norm(vmin=lowest, vmax=highest)
|
149
|
+
|
150
|
+
|
136
151
|
def _norm_with_zero_center(df: pd.DataFrame) -> Normalize:
|
137
152
|
"""Get a normalization object with zero-centered values for the given data."""
|
138
153
|
v = max(abs(df.min().min()), abs(df.max().max()))
|
@@ -166,7 +181,7 @@ def _split_large_groups[T](groups: list[list[T]], max_size: int) -> list[list[T]
|
|
166
181
|
) # type: ignore
|
167
182
|
|
168
183
|
|
169
|
-
def _default_color(ax: Axes, color:
|
184
|
+
def _default_color(ax: Axes, color: Color | None) -> Color:
    """Return `color` unchanged, or a "C<n>" cycle color when it is None.

    `n` is the number of lines currently attached to `ax`.
    """
    if color is not None:
        return color
    return f"C{len(ax.lines)}"
|
172
187
|
|
@@ -291,16 +306,16 @@ def reset_prop_cycle(ax: Axes) -> None:
|
|
291
306
|
@contextlib.contextmanager
|
292
307
|
def context(
|
293
308
|
colors: list[str] | None = None,
|
294
|
-
|
295
|
-
|
309
|
+
linewidth: float | None = None,
|
310
|
+
linestyle: Linestyle | None = None,
|
296
311
|
rc: dict[str, Any] | None = None,
|
297
312
|
) -> Generator[None, None, None]:
|
298
313
|
"""Context manager to set the defaults for plots.
|
299
314
|
|
300
315
|
Args:
|
301
316
|
colors: colors to use for the plot.
|
302
|
-
|
303
|
-
|
317
|
+
linewidth: line width to use for the plot.
|
318
|
+
linestyle: line style to use for the plot.
|
304
319
|
rc: additional keyword arguments to pass to the rc context.
|
305
320
|
|
306
321
|
"""
|
@@ -309,11 +324,11 @@ def context(
|
|
309
324
|
if colors is not None:
|
310
325
|
rc["axes.prop_cycle"] = cycler(color=colors)
|
311
326
|
|
312
|
-
if
|
313
|
-
rc["lines.linewidth"] =
|
327
|
+
if linewidth is not None:
|
328
|
+
rc["lines.linewidth"] = linewidth
|
314
329
|
|
315
|
-
if
|
316
|
-
rc["lines.linestyle"] =
|
330
|
+
if linestyle is not None:
|
331
|
+
rc["lines.linestyle"] = linestyle
|
317
332
|
|
318
333
|
with plt.rc_context(rc):
|
319
334
|
yield
|
@@ -477,9 +492,12 @@ def lines(
|
|
477
492
|
x: pd.DataFrame | pd.Series,
|
478
493
|
*,
|
479
494
|
ax: Axes | None = None,
|
480
|
-
grid: bool = True,
|
481
495
|
alpha: float = 1.0,
|
496
|
+
color: Color | list[Color] | None = None,
|
497
|
+
grid: bool = True,
|
482
498
|
legend: bool = True,
|
499
|
+
linewidth: float | None = None,
|
500
|
+
linestyle: Linestyle | None = None,
|
483
501
|
) -> FigAx:
|
484
502
|
"""Plot multiple lines on the same axis."""
|
485
503
|
fig, ax = _default_fig_ax(ax=ax, grid=grid)
|
@@ -487,6 +505,9 @@ def lines(
|
|
487
505
|
x.index,
|
488
506
|
x,
|
489
507
|
alpha=alpha,
|
508
|
+
linewidth=linewidth,
|
509
|
+
linestyle=linestyle,
|
510
|
+
color=color,
|
490
511
|
)
|
491
512
|
_default_labels(ax, xlabel=x.index.name, ylabel=None)
|
492
513
|
if legend:
|
@@ -497,6 +518,12 @@ def lines(
|
|
497
518
|
return fig, ax
|
498
519
|
|
499
520
|
|
521
|
+
def _repeat_color_if_necessary(
    color: list[list[Color]] | Color | None, n: int
) -> Iterable[list[Color] | Color | None]:
    """Provide an iterable of color specifications.

    A list is handed back untouched (its length is not checked here); any
    other value, including None, is repeated `n` times.
    """
    if isinstance(color, list):
        return color
    return [color for _ in range(n)]
|
525
|
+
|
526
|
+
|
500
527
|
def lines_grouped(
|
501
528
|
groups: list[pd.DataFrame] | list[pd.Series],
|
502
529
|
*,
|
@@ -506,6 +533,9 @@ def lines_grouped(
|
|
506
533
|
sharex: bool = True,
|
507
534
|
sharey: bool = False,
|
508
535
|
grid: bool = True,
|
536
|
+
color: Color | list[list[Color]] | None = None,
|
537
|
+
linewidth: float | None = None,
|
538
|
+
linestyle: Linestyle | None = None,
|
509
539
|
) -> FigAxs:
|
510
540
|
"""Plot multiple groups of lines on separate axes."""
|
511
541
|
fig, axs = grid_layout(
|
@@ -518,8 +548,20 @@ def lines_grouped(
|
|
518
548
|
grid=grid,
|
519
549
|
)
|
520
550
|
|
521
|
-
for group, ax in zip(
|
522
|
-
|
551
|
+
for group, ax, color_ in zip(
|
552
|
+
groups,
|
553
|
+
axs,
|
554
|
+
_repeat_color_if_necessary(color, n=len(groups)),
|
555
|
+
strict=False,
|
556
|
+
):
|
557
|
+
lines(
|
558
|
+
group,
|
559
|
+
ax=ax,
|
560
|
+
grid=grid,
|
561
|
+
color=color_,
|
562
|
+
linewidth=linewidth,
|
563
|
+
linestyle=linestyle,
|
564
|
+
)
|
523
565
|
|
524
566
|
for i in range(len(groups), len(axs)):
|
525
567
|
axs[i].set_visible(False)
|
@@ -535,6 +577,9 @@ def line_autogrouped(
|
|
535
577
|
row_height: float = 3,
|
536
578
|
max_group_size: int = 6,
|
537
579
|
grid: bool = True,
|
580
|
+
color: Color | list[list[Color]] | None = None,
|
581
|
+
linewidth: float | None = None,
|
582
|
+
linestyle: Linestyle | None = None,
|
538
583
|
) -> FigAxs:
|
539
584
|
"""Plot a series or dataframe with lines grouped by order of magnitude."""
|
540
585
|
group_names = _split_large_groups(
|
@@ -556,6 +601,9 @@ def line_autogrouped(
|
|
556
601
|
col_width=col_width,
|
557
602
|
row_height=row_height,
|
558
603
|
grid=grid,
|
604
|
+
color=color,
|
605
|
+
linestyle=linestyle,
|
606
|
+
linewidth=linewidth,
|
559
607
|
)
|
560
608
|
|
561
609
|
|
@@ -564,7 +612,9 @@ def line_mean_std(
|
|
564
612
|
*,
|
565
613
|
label: str | None = None,
|
566
614
|
ax: Axes | None = None,
|
567
|
-
color:
|
615
|
+
color: Color | None = None,
|
616
|
+
linewidth: float | None = None,
|
617
|
+
linestyle: Linestyle | None = None,
|
568
618
|
alpha: float = 0.2,
|
569
619
|
grid: bool = True,
|
570
620
|
) -> FigAx:
|
@@ -579,6 +629,8 @@ def line_mean_std(
|
|
579
629
|
mean,
|
580
630
|
color=color,
|
581
631
|
label=label,
|
632
|
+
linewidth=linewidth,
|
633
|
+
linestyle=linestyle,
|
582
634
|
)
|
583
635
|
ax.fill_between(
|
584
636
|
df.index,
|
@@ -598,6 +650,9 @@ def lines_mean_std_from_2d_idx(
|
|
598
650
|
ax: Axes | None = None,
|
599
651
|
alpha: float = 0.2,
|
600
652
|
grid: bool = True,
|
653
|
+
color: Color | None = None,
|
654
|
+
linewidth: float | None = None,
|
655
|
+
linestyle: Linestyle | None = None,
|
601
656
|
) -> FigAx:
|
602
657
|
"""Plot the mean and standard deviation of a 2D indexed dataframe."""
|
603
658
|
if len(cast(pd.MultiIndex, df.index).levels) != 2: # noqa: PLR2004
|
@@ -612,25 +667,32 @@ def lines_mean_std_from_2d_idx(
|
|
612
667
|
label=name,
|
613
668
|
alpha=alpha,
|
614
669
|
ax=ax,
|
670
|
+
color=color,
|
671
|
+
linestyle=linestyle,
|
672
|
+
linewidth=linewidth,
|
615
673
|
)
|
616
674
|
ax.legend()
|
617
675
|
return fig, ax
|
618
676
|
|
619
677
|
|
620
|
-
def
|
621
|
-
df: pd.DataFrame,
|
678
|
+
def _create_heatmap(
|
622
679
|
*,
|
680
|
+
ax: Axes | None,
|
681
|
+
df: pd.DataFrame,
|
682
|
+
title: str | None = None,
|
683
|
+
xlabel: str,
|
684
|
+
ylabel: str,
|
685
|
+
xticklabels: list[str],
|
686
|
+
yticklabels: list[str],
|
687
|
+
norm: Normalize,
|
623
688
|
annotate: bool = False,
|
624
689
|
colorbar: bool = True,
|
625
690
|
invert_yaxis: bool = True,
|
626
691
|
cmap: str = "RdBu_r",
|
627
|
-
norm: Normalize | None = None,
|
628
|
-
ax: Axes | None = None,
|
629
692
|
cax: Axes | None = None,
|
630
693
|
sci_annotation_bounds: tuple[float, float] = (0.01, 100),
|
631
694
|
annotation_style: str = "2g",
|
632
695
|
) -> tuple[Figure, Axes, QuadMesh]:
|
633
|
-
"""Plot a heatmap of the given data."""
|
634
696
|
fig, ax = _default_fig_ax(
|
635
697
|
ax=ax,
|
636
698
|
figsize=(
|
@@ -639,17 +701,23 @@ def heatmap(
|
|
639
701
|
),
|
640
702
|
grid=False,
|
641
703
|
)
|
642
|
-
if norm is None:
|
643
|
-
norm = _norm_with_zero_center(df)
|
644
704
|
|
705
|
+
# Note: pcolormesh swaps index/columns
|
645
706
|
hm = ax.pcolormesh(df, norm=norm, cmap=cmap)
|
707
|
+
|
708
|
+
if xlabel is not None:
|
709
|
+
ax.set_xlabel(xlabel)
|
710
|
+
if ylabel is not None:
|
711
|
+
ax.set_ylabel(ylabel)
|
712
|
+
if title is not None:
|
713
|
+
ax.set_title(title)
|
646
714
|
ax.set_xticks(
|
647
|
-
np.arange(0, len(df.columns), 1) + 0.5,
|
648
|
-
labels=
|
715
|
+
np.arange(0, len(df.columns), 1, dtype=float) + 0.5,
|
716
|
+
labels=xticklabels,
|
649
717
|
)
|
650
718
|
ax.set_yticks(
|
651
|
-
np.arange(0, len(df.index), 1) + 0.5,
|
652
|
-
labels=
|
719
|
+
np.arange(0, len(df.index), 1, dtype=float) + 0.5,
|
720
|
+
labels=yticklabels,
|
653
721
|
)
|
654
722
|
|
655
723
|
if annotate:
|
@@ -666,39 +734,77 @@ def heatmap(
|
|
666
734
|
return fig, ax, hm
|
667
735
|
|
668
736
|
|
737
|
+
def heatmap(
    df: pd.DataFrame,
    *,
    ax: Axes | None = None,
    title: str | None = None,
    annotate: bool = False,
    colorbar: bool = True,
    invert_yaxis: bool = True,
    cmap: str = "RdBu_r",
    norm: Normalize | None = None,
    cax: Axes | None = None,
    sci_annotation_bounds: tuple[float, float] = (0.01, 100),
    annotation_style: str = "2g",
) -> tuple[Figure, Axes, QuadMesh]:
    """Plot a heatmap of the given data.

    The columns of `df` are laid out along the x axis and its index along
    the y axis; tick labels and axis labels are taken from the matching
    axis of the dataframe.

    Args:
        df: data to plot.
        ax: axis to draw on, forwarded to `_create_heatmap`.
        title: optional axis title.
        annotate: forwarded to `_create_heatmap`.
        colorbar: forwarded to `_create_heatmap`.
        invert_yaxis: forwarded to `_create_heatmap`.
        cmap: name of the colormap.
        norm: color normalization; when None a zero-centered normalization
            is derived from `df`.
        cax: forwarded to `_create_heatmap`.
        sci_annotation_bounds: forwarded to `_create_heatmap`.
        annotation_style: forwarded to `_create_heatmap`.

    Returns:
        Whatever `_create_heatmap` returns: figure, axis and the mesh.

    """
    return _create_heatmap(
        ax=ax,
        df=df,
        title=title,
        # x ticks are labelled with the columns and y ticks with the index,
        # so the axis labels have to come from the same side.
        xlabel=df.columns.name,
        ylabel=df.index.name,
        xticklabels=cast(list, df.columns),
        yticklabels=cast(list, df.index),
        annotate=annotate,
        colorbar=colorbar,
        invert_yaxis=invert_yaxis,
        cmap=cmap,
        norm=_norm_with_zero_center(df) if norm is None else norm,
        cax=cax,
        sci_annotation_bounds=sci_annotation_bounds,
        annotation_style=annotation_style,
    )
|
769
|
+
|
770
|
+
|
669
771
|
def heatmap_from_2d_idx(
    df: pd.DataFrame,
    variable: str,
    *,
    ax: Axes | None = None,
    annotate: bool = False,
    colorbar: bool = True,
    invert_yaxis: bool = False,
    cmap: str = "viridis",
    norm: Normalize | None = None,
    cax: Axes | None = None,
    sci_annotation_bounds: tuple[float, float] = (0.01, 100),
    annotation_style: str = "2g",
) -> tuple[Figure, Axes, QuadMesh]:
    """Plot a heatmap of one column of a 2D indexed dataframe.

    The selected column is unstacked and transposed, so the first index
    level ends up along the x axis and the second along the y axis.

    Args:
        df: dataframe with a two-level MultiIndex.
        variable: name of the column to plot.
        ax: axis to draw on, forwarded to `_create_heatmap`.
        annotate: forwarded to `_create_heatmap`.
        colorbar: forwarded to `_create_heatmap`.
        invert_yaxis: forwarded to `_create_heatmap`.
        cmap: name of the colormap.
        norm: color normalization; when None it spans the minimum to the
            maximum of the selected data.
        cax: forwarded to `_create_heatmap`.
        sci_annotation_bounds: forwarded to `_create_heatmap`.
        annotation_style: forwarded to `_create_heatmap`.

    Returns:
        Whatever `_create_heatmap` returns: figure, axis and the mesh.

    Raises:
        ValueError: if the index does not have exactly two levels.

    """
    if len(cast(pd.MultiIndex, df.index).levels) != 2:  # noqa: PLR2004
        msg = "MultiIndex must have exactly two levels"
        raise ValueError(msg)
    df2d = df[variable].unstack().T

    return _create_heatmap(
        df=df2d,
        # Tick labels along x come from the columns of the transposed frame,
        # so the x label must use the columns' name (and y the index name).
        xlabel=df2d.columns.name,
        ylabel=df2d.index.name,
        xticklabels=[f"{i:.2f}" for i in df2d.columns],
        yticklabels=[f"{i:.2f}" for i in df2d.index],
        ax=ax,
        cax=cax,
        annotate=annotate,
        colorbar=colorbar,
        invert_yaxis=invert_yaxis,
        cmap=cmap,
        norm=_norm(df2d) if norm is None else norm,
        sci_annotation_bounds=sci_annotation_bounds,
        annotation_style=annotation_style,
    )
|
695
807
|
|
696
|
-
rotate_xlabels(ax, rotation=45, ha="right")
|
697
|
-
|
698
|
-
# Add colorbar
|
699
|
-
fig.colorbar(hm, ax=ax)
|
700
|
-
return fig, ax
|
701
|
-
|
702
808
|
|
703
809
|
def heatmaps_from_2d_idx(
|
704
810
|
df: pd.DataFrame,
|
@@ -708,6 +814,13 @@ def heatmaps_from_2d_idx(
|
|
708
814
|
row_height_factor: float = 0.6,
|
709
815
|
sharex: bool = True,
|
710
816
|
sharey: bool = False,
|
817
|
+
annotate: bool = False,
|
818
|
+
colorbar: bool = True,
|
819
|
+
invert_yaxis: bool = False,
|
820
|
+
cmap: str = "viridis",
|
821
|
+
norm: Normalize | None = None,
|
822
|
+
sci_annotation_bounds: tuple[float, float] = (0.01, 100),
|
823
|
+
annotation_style: str = "2g",
|
711
824
|
) -> FigAxs:
|
712
825
|
"""Plot multiple heatmaps of a 2D indexed dataframe."""
|
713
826
|
idx = cast(pd.MultiIndex, df.index)
|
@@ -722,7 +835,18 @@ def heatmaps_from_2d_idx(
|
|
722
835
|
grid=False,
|
723
836
|
)
|
724
837
|
for ax, var in zip(axs, df.columns, strict=False):
|
725
|
-
heatmap_from_2d_idx(
|
838
|
+
heatmap_from_2d_idx(
|
839
|
+
df,
|
840
|
+
var,
|
841
|
+
ax=ax,
|
842
|
+
annotate=annotate,
|
843
|
+
colorbar=colorbar,
|
844
|
+
invert_yaxis=invert_yaxis,
|
845
|
+
cmap=cmap,
|
846
|
+
norm=norm,
|
847
|
+
sci_annotation_bounds=sci_annotation_bounds,
|
848
|
+
annotation_style=annotation_style,
|
849
|
+
)
|
726
850
|
return fig, axs
|
727
851
|
|
728
852
|
|
@@ -892,6 +1016,9 @@ def relative_label_distribution(
|
|
892
1016
|
row_height: float = 3,
|
893
1017
|
sharey: bool = False,
|
894
1018
|
grid: bool = True,
|
1019
|
+
color: Color | None = None,
|
1020
|
+
linewidth: float | None = None,
|
1021
|
+
linestyle: Linestyle | None = None,
|
895
1022
|
) -> FigAxs:
|
896
1023
|
"""Plot the relative distribution of labels in the given data."""
|
897
1024
|
variables = list(mapper.label_variables) if subset is None else subset
|
@@ -903,20 +1030,37 @@ def relative_label_distribution(
|
|
903
1030
|
sharey=sharey,
|
904
1031
|
grid=grid,
|
905
1032
|
)
|
1033
|
+
# FIXME: rewrite as building a dict of dataframes
|
1034
|
+
# and passing it to lines_grouped
|
906
1035
|
if isinstance(mapper, LabelMapper):
|
907
1036
|
for ax, name in zip(axs, variables, strict=False):
|
908
1037
|
for i in range(mapper.label_variables[name]):
|
909
1038
|
isos = mapper.get_isotopomers_of_at_position(name, i)
|
910
1039
|
labels = cast(pd.DataFrame, concs.loc[:, isos])
|
911
1040
|
total = concs.loc[:, f"{name}__total"]
|
912
|
-
ax.plot(
|
1041
|
+
ax.plot(
|
1042
|
+
labels.index,
|
1043
|
+
(labels.sum(axis=1) / total),
|
1044
|
+
label=f"C{i + 1}",
|
1045
|
+
linewidth=linewidth,
|
1046
|
+
linestyle=linestyle,
|
1047
|
+
color=color,
|
1048
|
+
)
|
913
1049
|
ax.set_title(name)
|
914
1050
|
ax.legend()
|
915
1051
|
else:
|
916
1052
|
for ax, (name, isos) in zip(
|
917
|
-
axs,
|
1053
|
+
axs,
|
1054
|
+
mapper.get_isotopomers(variables).items(),
|
1055
|
+
strict=False,
|
918
1056
|
):
|
919
|
-
ax.plot(
|
1057
|
+
ax.plot(
|
1058
|
+
concs.index,
|
1059
|
+
concs.loc[:, isos],
|
1060
|
+
linewidth=linewidth,
|
1061
|
+
linestyle=linestyle,
|
1062
|
+
color=color,
|
1063
|
+
)
|
920
1064
|
ax.set_title(name)
|
921
1065
|
ax.legend([f"C{i + 1}" for i in range(len(isos))])
|
922
1066
|
|
mxlpy/sbml/_export.py
CHANGED
@@ -447,7 +447,14 @@ def _create_sbml_variables(
|
|
447
447
|
cpd.setConstant(False)
|
448
448
|
cpd.setBoundaryCondition(False)
|
449
449
|
cpd.setHasOnlySubstanceUnits(False)
|
450
|
-
|
450
|
+
if isinstance(value, Derived):
|
451
|
+
ar = sbml_model.createInitialAssignment()
|
452
|
+
ar.setId(_convert_id_to_sbml(id_=name, prefix="IA"))
|
453
|
+
ar.setName(_convert_id_to_sbml(id_=name, prefix="IA"))
|
454
|
+
ar.setVariable(_convert_id_to_sbml(id_=name, prefix="IA"))
|
455
|
+
ar.setMath(_sbmlify_fn(value.fn, value.args))
|
456
|
+
else:
|
457
|
+
cpd.setInitialAmount(float(value))
|
451
458
|
|
452
459
|
|
453
460
|
def _create_sbml_derived_variables(*, model: Model, sbml_model: libsbml.Model) -> None:
|
mxlpy/surrogates/__init__.py
CHANGED
@@ -4,29 +4,37 @@ This module provides classes and functions for creating and training surrogate m
|
|
4
4
|
for metabolic simulations. It includes functionality for both steady-state and time-series
|
5
5
|
data using neural networks.
|
6
6
|
|
7
|
-
Classes:
|
8
|
-
AbstractSurrogate: Abstract base class for surrogate models.
|
9
|
-
TorchSurrogate: Surrogate model using PyTorch.
|
10
|
-
Approximator: Neural network approximator for surrogate modeling.
|
11
|
-
|
12
|
-
Functions:
|
13
|
-
train_torch_surrogate: Train a PyTorch surrogate model.
|
14
|
-
train_torch_time_course_estimator: Train a PyTorch time course estimator.
|
15
7
|
"""
|
16
8
|
|
17
9
|
from __future__ import annotations
|
18
10
|
|
19
|
-
import
|
11
|
+
from typing import TYPE_CHECKING
|
20
12
|
|
21
|
-
|
22
|
-
|
13
|
+
if TYPE_CHECKING:
|
14
|
+
import contextlib
|
23
15
|
|
24
|
-
|
16
|
+
with contextlib.suppress(ImportError):
|
17
|
+
from . import _keras as keras
|
18
|
+
from . import _torch as torch
|
19
|
+
else:
|
20
|
+
from lazy_import import lazy_module
|
21
|
+
|
22
|
+
keras = lazy_module(
|
23
|
+
"mxlpy.surrogates._keras",
|
24
|
+
error_strings={"module": "keras", "install_name": "mxlpy[tf]"},
|
25
|
+
)
|
26
|
+
torch = lazy_module(
|
27
|
+
"mxlpy.surrogates._torch",
|
28
|
+
error_strings={"module": "torch", "install_name": "mxlpy[torch]"},
|
29
|
+
)
|
30
|
+
|
31
|
+
|
32
|
+
from . import _poly as poly
|
33
|
+
from . import _qss as qss
|
25
34
|
|
26
35
|
__all__ = [
|
27
|
-
"
|
28
|
-
"
|
29
|
-
"
|
30
|
-
"
|
31
|
-
"train_torch",
|
36
|
+
"keras",
|
37
|
+
"poly",
|
38
|
+
"qss",
|
39
|
+
"torch",
|
32
40
|
]
|
@@ -0,0 +1,137 @@
|
|
1
|
+
from dataclasses import dataclass
|
2
|
+
from typing import Self, cast
|
3
|
+
|
4
|
+
import keras
|
5
|
+
import numpy as np
|
6
|
+
import pandas as pd
|
7
|
+
|
8
|
+
from mxlpy.nn._keras import MLP
|
9
|
+
from mxlpy.nn._keras import train as _train
|
10
|
+
from mxlpy.types import AbstractSurrogate, Array, Derived
|
11
|
+
|
12
|
+
__all__ = [
|
13
|
+
"DefaultLoss",
|
14
|
+
"DefaultOptimizer",
|
15
|
+
"LossFn",
|
16
|
+
"Optimizer",
|
17
|
+
"Surrogate",
|
18
|
+
"Trainer",
|
19
|
+
"train",
|
20
|
+
]
|
21
|
+
|
22
|
+
type Optimizer = keras.optimizers.Optimizer | str
|
23
|
+
type LossFn = keras.losses.Loss | str
|
24
|
+
|
25
|
+
DefaultOptimizer = keras.optimizers.Adam()
|
26
|
+
DefaultLoss = keras.losses.MeanAbsoluteError()
|
27
|
+
|
28
|
+
|
29
|
+
@dataclass(kw_only=True)
class Surrogate(AbstractSurrogate):
    """Surrogate whose predictions come from a keras model."""

    # Model whose `predict` method produces the raw outputs.
    model: keras.Model

    def predict_raw(self, y: Array) -> Array:
        """Run the model on `y` and return the result as an array of rank >= 1."""
        prediction = self.model.predict(y)
        # Drop singleton axes, but never collapse all the way to a scalar.
        squeezed = np.squeeze(prediction)
        return np.atleast_1d(squeezed)

    def predict(
        self, args: dict[str, float | pd.Series | pd.DataFrame]
    ) -> dict[str, float]:
        """Map each output name to the model prediction for `args`.

        The values of `args` are gathered in the order given by `self.args`
        and the raw predictions are paired with `self.outputs`; a length
        mismatch between the two raises (strict zip).
        """
        inputs = np.array([args[name] for name in self.args])
        raw = self.predict_raw(inputs)
        return dict(zip(self.outputs, raw, strict=True))
|
46
|
+
|
47
|
+
|
48
|
+
@dataclass(init=False)
class Trainer:
    """Train a keras model on feature / target data and wrap it as a surrogate."""

    features: pd.DataFrame
    targets: pd.DataFrame
    model: keras.Model
    optimizer: Optimizer
    losses: list[pd.Series]
    loss_fn: LossFn

    def __init__(
        self,
        features: pd.DataFrame,
        targets: pd.DataFrame,
        model: keras.Model | None = None,
        optimizer: Optimizer = DefaultOptimizer,
        loss: LossFn = DefaultLoss,
    ) -> None:
        """Store the data and compile the model.

        Args:
            features: training inputs, one column per model input.
            targets: training targets, one column per model output.
            model: model to train; when None an `MLP` with two hidden
                layers of 50 neurons is created to match the data.
            optimizer: optimizer handed to `model.compile`.
            loss: loss handed to `model.compile`.

        """
        self.features = features
        self.targets = targets
        if model is None:
            model = MLP(
                n_inputs=len(features.columns),
                neurons_per_layer=[50, 50, len(targets.columns)],
            )
        if optimizer is DefaultOptimizer:
            # The default is one module-level instance shared by every call.
            # Keras optimizers keep per-variable state once built, so each
            # trainer gets its own fresh optimizer instead.
            optimizer = keras.optimizers.Adam()
        self.model = model
        # Assign every declared field; otherwise the dataclass-generated
        # __repr__ / __eq__ fail with AttributeError.
        self.optimizer = optimizer
        self.loss_fn = loss
        model.compile(optimizer=cast(str, optimizer), loss=loss)

        self.losses = []

    def train(self, epochs: int, batch_size: int | None = None) -> Self:
        """Train for `epochs` epochs and record the resulting losses.

        Returns:
            The trainer itself, so calls can be chained.

        """
        losses = _train(
            model=self.model,
            features=self.features,
            targets=self.targets,
            epochs=epochs,
            batch_size=batch_size,
        )

        if len(self.losses) > 0:
            # Continue the index where the previous run stopped.
            # NOTE(review): if `_train` returns a zero-based index, the first
            # new entry repeats the previous last index value - confirm.
            losses.index += self.losses[-1].index[-1]
        self.losses.append(losses)

        return self

    def get_loss(self) -> pd.Series:
        """Return the losses of all training runs concatenated in order."""
        return pd.concat(self.losses)

    def get_surrogate(
        self,
        surrogate_args: list[str] | None = None,
        surrogate_outputs: list[str] | None = None,
        surrogate_stoichiometries: dict[str, dict[str, float | Derived]] | None = None,
    ) -> Surrogate:
        """Wrap the current model in a `Surrogate`.

        Each argument left as None is replaced by an empty list / dict.
        """
        return Surrogate(
            model=self.model,
            args=surrogate_args if surrogate_args is not None else [],
            outputs=surrogate_outputs if surrogate_outputs is not None else [],
            stoichiometries=surrogate_stoichiometries
            if surrogate_stoichiometries is not None
            else {},
        )
|
109
|
+
|
110
|
+
|
111
|
+
def train(
    features: pd.DataFrame,
    targets: pd.DataFrame,
    epochs: int,
    surrogate_args: list[str] | None = None,
    surrogate_outputs: list[str] | None = None,
    surrogate_stoichiometries: dict[str, dict[str, float | Derived]] | None = None,
    batch_size: int | None = None,
    model: keras.Model | None = None,
    optimizer: Optimizer = DefaultOptimizer,
    loss: LossFn = DefaultLoss,
) -> tuple[Surrogate, pd.Series]:
    """Train a surrogate in one call.

    Builds a `Trainer`, trains it for `epochs` epochs and returns the
    resulting surrogate together with the recorded losses.
    """
    trainer = Trainer(
        features=features,
        targets=targets,
        model=model,
        optimizer=optimizer,
        loss=loss,
    )
    # `Trainer.train` returns the trainer itself, so no rebinding is needed.
    trainer.train(epochs=epochs, batch_size=batch_size)

    surrogate = trainer.get_surrogate(
        surrogate_args=surrogate_args,
        surrogate_outputs=surrogate_outputs,
        surrogate_stoichiometries=surrogate_stoichiometries,
    )
    return surrogate, trainer.get_loss()
|