mxlpy 0.16.0__py3-none-any.whl → 0.18.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mxlpy/__init__.py +4 -1
- mxlpy/fit.py +173 -7
- mxlpy/fns.py +513 -21
- mxlpy/identify.py +7 -1
- mxlpy/meta/codegen_latex.py +279 -14
- mxlpy/meta/source_tools.py +122 -4
- mxlpy/model.py +50 -24
- mxlpy/nn/_torch.py +61 -1
- mxlpy/npe/__init__.py +38 -0
- mxlpy/npe/_torch.py +365 -0
- mxlpy/plot.py +194 -50
- mxlpy/report.py +33 -6
- mxlpy/sbml/_import.py +5 -2
- mxlpy/surrogates/__init__.py +7 -6
- mxlpy/surrogates/_poly.py +12 -9
- mxlpy/surrogates/_torch.py +118 -114
- mxlpy/symbolic/strikepy.py +1 -3
- mxlpy/types.py +17 -7
- {mxlpy-0.16.0.dist-info → mxlpy-0.18.0.dist-info}/METADATA +7 -8
- {mxlpy-0.16.0.dist-info → mxlpy-0.18.0.dist-info}/RECORD +22 -21
- mxlpy-0.18.0.dist-info/licenses/LICENSE +21 -0
- mxlpy/npe.py +0 -277
- mxlpy-0.16.0.dist-info/licenses/LICENSE +0 -674
- {mxlpy-0.16.0.dist-info → mxlpy-0.18.0.dist-info}/WHEEL +0 -0
mxlpy/plot.py
CHANGED
@@ -23,9 +23,12 @@ import contextlib
|
|
23
23
|
from cycler import cycler
|
24
24
|
|
25
25
|
__all__ = [
|
26
|
+
"Color",
|
26
27
|
"FigAx",
|
27
28
|
"FigAxs",
|
28
29
|
"Linestyle",
|
30
|
+
"RGB",
|
31
|
+
"RGBA",
|
29
32
|
"add_grid",
|
30
33
|
"bars",
|
31
34
|
"context",
|
@@ -71,7 +74,7 @@ from mpl_toolkits.mplot3d import Axes3D
|
|
71
74
|
from mxlpy.label_map import LabelMapper
|
72
75
|
|
73
76
|
if TYPE_CHECKING:
|
74
|
-
from collections.abc import Generator
|
77
|
+
from collections.abc import Generator, Iterable
|
75
78
|
|
76
79
|
from matplotlib.collections import QuadMesh
|
77
80
|
|
@@ -89,6 +92,11 @@ type Linestyle = Literal[
|
|
89
92
|
"dashdot",
|
90
93
|
]
|
91
94
|
|
95
|
+
|
96
|
+
type RGB = tuple[float, float, float]
|
97
|
+
type RGBA = tuple[float, float, float, float]
|
98
|
+
type Color = str | RGB | RGBA
|
99
|
+
|
92
100
|
##########################################################################
|
93
101
|
# Helpers
|
94
102
|
##########################################################################
|
@@ -133,6 +141,13 @@ def _get_norm(vmin: float, vmax: float) -> Normalize:
|
|
133
141
|
return norm
|
134
142
|
|
135
143
|
|
144
|
+
def _norm(df: pd.DataFrame) -> Normalize:
|
145
|
+
"""Get a normalization object for the given data."""
|
146
|
+
vmin = df.min().min()
|
147
|
+
vmax = df.max().max()
|
148
|
+
return _get_norm(vmin, vmax)
|
149
|
+
|
150
|
+
|
136
151
|
def _norm_with_zero_center(df: pd.DataFrame) -> Normalize:
|
137
152
|
"""Get a normalization object with zero-centered values for the given data."""
|
138
153
|
v = max(abs(df.min().min()), abs(df.max().max()))
|
@@ -166,7 +181,7 @@ def _split_large_groups[T](groups: list[list[T]], max_size: int) -> list[list[T]
|
|
166
181
|
) # type: ignore
|
167
182
|
|
168
183
|
|
169
|
-
def _default_color(ax: Axes, color:
|
184
|
+
def _default_color(ax: Axes, color: Color | None) -> Color:
|
170
185
|
"""Get a default color for the given axis."""
|
171
186
|
return f"C{len(ax.lines)}" if color is None else color
|
172
187
|
|
@@ -291,16 +306,16 @@ def reset_prop_cycle(ax: Axes) -> None:
|
|
291
306
|
@contextlib.contextmanager
|
292
307
|
def context(
|
293
308
|
colors: list[str] | None = None,
|
294
|
-
|
295
|
-
|
309
|
+
linewidth: float | None = None,
|
310
|
+
linestyle: Linestyle | None = None,
|
296
311
|
rc: dict[str, Any] | None = None,
|
297
312
|
) -> Generator[None, None, None]:
|
298
313
|
"""Context manager to set the defaults for plots.
|
299
314
|
|
300
315
|
Args:
|
301
316
|
colors: colors to use for the plot.
|
302
|
-
|
303
|
-
|
317
|
+
linewidth: line width to use for the plot.
|
318
|
+
linestyle: line style to use for the plot.
|
304
319
|
rc: additional keyword arguments to pass to the rc context.
|
305
320
|
|
306
321
|
"""
|
@@ -309,11 +324,11 @@ def context(
|
|
309
324
|
if colors is not None:
|
310
325
|
rc["axes.prop_cycle"] = cycler(color=colors)
|
311
326
|
|
312
|
-
if
|
313
|
-
rc["lines.linewidth"] =
|
327
|
+
if linewidth is not None:
|
328
|
+
rc["lines.linewidth"] = linewidth
|
314
329
|
|
315
|
-
if
|
316
|
-
rc["lines.linestyle"] =
|
330
|
+
if linestyle is not None:
|
331
|
+
rc["lines.linestyle"] = linestyle
|
317
332
|
|
318
333
|
with plt.rc_context(rc):
|
319
334
|
yield
|
@@ -477,9 +492,12 @@ def lines(
|
|
477
492
|
x: pd.DataFrame | pd.Series,
|
478
493
|
*,
|
479
494
|
ax: Axes | None = None,
|
480
|
-
grid: bool = True,
|
481
495
|
alpha: float = 1.0,
|
496
|
+
color: Color | list[Color] | None = None,
|
497
|
+
grid: bool = True,
|
482
498
|
legend: bool = True,
|
499
|
+
linewidth: float | None = None,
|
500
|
+
linestyle: Linestyle | None = None,
|
483
501
|
) -> FigAx:
|
484
502
|
"""Plot multiple lines on the same axis."""
|
485
503
|
fig, ax = _default_fig_ax(ax=ax, grid=grid)
|
@@ -487,6 +505,9 @@ def lines(
|
|
487
505
|
x.index,
|
488
506
|
x,
|
489
507
|
alpha=alpha,
|
508
|
+
linewidth=linewidth,
|
509
|
+
linestyle=linestyle,
|
510
|
+
color=color,
|
490
511
|
)
|
491
512
|
_default_labels(ax, xlabel=x.index.name, ylabel=None)
|
492
513
|
if legend:
|
@@ -497,6 +518,12 @@ def lines(
|
|
497
518
|
return fig, ax
|
498
519
|
|
499
520
|
|
521
|
+
def _repeat_color_if_necessary(
|
522
|
+
color: list[list[Color]] | Color | None, n: int
|
523
|
+
) -> Iterable[list[Color] | Color | None]:
|
524
|
+
return [color] * n if not isinstance(color, list) else color
|
525
|
+
|
526
|
+
|
500
527
|
def lines_grouped(
|
501
528
|
groups: list[pd.DataFrame] | list[pd.Series],
|
502
529
|
*,
|
@@ -506,6 +533,9 @@ def lines_grouped(
|
|
506
533
|
sharex: bool = True,
|
507
534
|
sharey: bool = False,
|
508
535
|
grid: bool = True,
|
536
|
+
color: Color | list[list[Color]] | None = None,
|
537
|
+
linewidth: float | None = None,
|
538
|
+
linestyle: Linestyle | None = None,
|
509
539
|
) -> FigAxs:
|
510
540
|
"""Plot multiple groups of lines on separate axes."""
|
511
541
|
fig, axs = grid_layout(
|
@@ -518,8 +548,20 @@ def lines_grouped(
|
|
518
548
|
grid=grid,
|
519
549
|
)
|
520
550
|
|
521
|
-
for group, ax in zip(
|
522
|
-
|
551
|
+
for group, ax, color_ in zip(
|
552
|
+
groups,
|
553
|
+
axs,
|
554
|
+
_repeat_color_if_necessary(color, n=len(groups)),
|
555
|
+
strict=False,
|
556
|
+
):
|
557
|
+
lines(
|
558
|
+
group,
|
559
|
+
ax=ax,
|
560
|
+
grid=grid,
|
561
|
+
color=color_,
|
562
|
+
linewidth=linewidth,
|
563
|
+
linestyle=linestyle,
|
564
|
+
)
|
523
565
|
|
524
566
|
for i in range(len(groups), len(axs)):
|
525
567
|
axs[i].set_visible(False)
|
@@ -535,6 +577,9 @@ def line_autogrouped(
|
|
535
577
|
row_height: float = 3,
|
536
578
|
max_group_size: int = 6,
|
537
579
|
grid: bool = True,
|
580
|
+
color: Color | list[list[Color]] | None = None,
|
581
|
+
linewidth: float | None = None,
|
582
|
+
linestyle: Linestyle | None = None,
|
538
583
|
) -> FigAxs:
|
539
584
|
"""Plot a series or dataframe with lines grouped by order of magnitude."""
|
540
585
|
group_names = _split_large_groups(
|
@@ -556,6 +601,9 @@ def line_autogrouped(
|
|
556
601
|
col_width=col_width,
|
557
602
|
row_height=row_height,
|
558
603
|
grid=grid,
|
604
|
+
color=color,
|
605
|
+
linestyle=linestyle,
|
606
|
+
linewidth=linewidth,
|
559
607
|
)
|
560
608
|
|
561
609
|
|
@@ -564,7 +612,9 @@ def line_mean_std(
|
|
564
612
|
*,
|
565
613
|
label: str | None = None,
|
566
614
|
ax: Axes | None = None,
|
567
|
-
color:
|
615
|
+
color: Color | None = None,
|
616
|
+
linewidth: float | None = None,
|
617
|
+
linestyle: Linestyle | None = None,
|
568
618
|
alpha: float = 0.2,
|
569
619
|
grid: bool = True,
|
570
620
|
) -> FigAx:
|
@@ -579,6 +629,8 @@ def line_mean_std(
|
|
579
629
|
mean,
|
580
630
|
color=color,
|
581
631
|
label=label,
|
632
|
+
linewidth=linewidth,
|
633
|
+
linestyle=linestyle,
|
582
634
|
)
|
583
635
|
ax.fill_between(
|
584
636
|
df.index,
|
@@ -598,6 +650,9 @@ def lines_mean_std_from_2d_idx(
|
|
598
650
|
ax: Axes | None = None,
|
599
651
|
alpha: float = 0.2,
|
600
652
|
grid: bool = True,
|
653
|
+
color: Color | None = None,
|
654
|
+
linewidth: float | None = None,
|
655
|
+
linestyle: Linestyle | None = None,
|
601
656
|
) -> FigAx:
|
602
657
|
"""Plot the mean and standard deviation of a 2D indexed dataframe."""
|
603
658
|
if len(cast(pd.MultiIndex, df.index).levels) != 2: # noqa: PLR2004
|
@@ -612,25 +667,32 @@ def lines_mean_std_from_2d_idx(
|
|
612
667
|
label=name,
|
613
668
|
alpha=alpha,
|
614
669
|
ax=ax,
|
670
|
+
color=color,
|
671
|
+
linestyle=linestyle,
|
672
|
+
linewidth=linewidth,
|
615
673
|
)
|
616
674
|
ax.legend()
|
617
675
|
return fig, ax
|
618
676
|
|
619
677
|
|
620
|
-
def
|
621
|
-
df: pd.DataFrame,
|
678
|
+
def _create_heatmap(
|
622
679
|
*,
|
680
|
+
ax: Axes | None,
|
681
|
+
df: pd.DataFrame,
|
682
|
+
title: str | None = None,
|
683
|
+
xlabel: str,
|
684
|
+
ylabel: str,
|
685
|
+
xticklabels: list[str],
|
686
|
+
yticklabels: list[str],
|
687
|
+
norm: Normalize,
|
623
688
|
annotate: bool = False,
|
624
689
|
colorbar: bool = True,
|
625
690
|
invert_yaxis: bool = True,
|
626
691
|
cmap: str = "RdBu_r",
|
627
|
-
norm: Normalize | None = None,
|
628
|
-
ax: Axes | None = None,
|
629
692
|
cax: Axes | None = None,
|
630
693
|
sci_annotation_bounds: tuple[float, float] = (0.01, 100),
|
631
694
|
annotation_style: str = "2g",
|
632
695
|
) -> tuple[Figure, Axes, QuadMesh]:
|
633
|
-
"""Plot a heatmap of the given data."""
|
634
696
|
fig, ax = _default_fig_ax(
|
635
697
|
ax=ax,
|
636
698
|
figsize=(
|
@@ -639,17 +701,23 @@ def heatmap(
|
|
639
701
|
),
|
640
702
|
grid=False,
|
641
703
|
)
|
642
|
-
if norm is None:
|
643
|
-
norm = _norm_with_zero_center(df)
|
644
704
|
|
705
|
+
# Note: pcolormesh swaps index/columns
|
645
706
|
hm = ax.pcolormesh(df, norm=norm, cmap=cmap)
|
707
|
+
|
708
|
+
if xlabel is not None:
|
709
|
+
ax.set_xlabel(xlabel)
|
710
|
+
if ylabel is not None:
|
711
|
+
ax.set_ylabel(ylabel)
|
712
|
+
if title is not None:
|
713
|
+
ax.set_title(title)
|
646
714
|
ax.set_xticks(
|
647
715
|
np.arange(0, len(df.columns), 1) + 0.5,
|
648
|
-
labels=
|
716
|
+
labels=xticklabels,
|
649
717
|
)
|
650
718
|
ax.set_yticks(
|
651
719
|
np.arange(0, len(df.index), 1) + 0.5,
|
652
|
-
labels=
|
720
|
+
labels=yticklabels,
|
653
721
|
)
|
654
722
|
|
655
723
|
if annotate:
|
@@ -666,39 +734,77 @@ def heatmap(
|
|
666
734
|
return fig, ax, hm
|
667
735
|
|
668
736
|
|
737
|
+
def heatmap(
|
738
|
+
df: pd.DataFrame,
|
739
|
+
*,
|
740
|
+
ax: Axes | None = None,
|
741
|
+
title: str | None = None,
|
742
|
+
annotate: bool = False,
|
743
|
+
colorbar: bool = True,
|
744
|
+
invert_yaxis: bool = True,
|
745
|
+
cmap: str = "RdBu_r",
|
746
|
+
norm: Normalize | None = None,
|
747
|
+
cax: Axes | None = None,
|
748
|
+
sci_annotation_bounds: tuple[float, float] = (0.01, 100),
|
749
|
+
annotation_style: str = "2g",
|
750
|
+
) -> tuple[Figure, Axes, QuadMesh]:
|
751
|
+
"""Plot a heatmap of the given data."""
|
752
|
+
return _create_heatmap(
|
753
|
+
ax=ax,
|
754
|
+
df=df,
|
755
|
+
title=title,
|
756
|
+
xlabel=df.index.name,
|
757
|
+
ylabel=df.columns.name,
|
758
|
+
xticklabels=cast(list, df.columns),
|
759
|
+
yticklabels=cast(list, df.index),
|
760
|
+
annotate=annotate,
|
761
|
+
colorbar=colorbar,
|
762
|
+
invert_yaxis=invert_yaxis,
|
763
|
+
cmap=cmap,
|
764
|
+
norm=_norm_with_zero_center(df) if norm is None else norm,
|
765
|
+
cax=cax,
|
766
|
+
sci_annotation_bounds=sci_annotation_bounds,
|
767
|
+
annotation_style=annotation_style,
|
768
|
+
)
|
769
|
+
|
770
|
+
|
669
771
|
def heatmap_from_2d_idx(
|
670
772
|
df: pd.DataFrame,
|
671
773
|
variable: str,
|
774
|
+
*,
|
672
775
|
ax: Axes | None = None,
|
673
|
-
|
776
|
+
annotate: bool = False,
|
777
|
+
colorbar: bool = True,
|
778
|
+
invert_yaxis: bool = False,
|
779
|
+
cmap: str = "viridis",
|
780
|
+
norm: Normalize | None = None,
|
781
|
+
cax: Axes | None = None,
|
782
|
+
sci_annotation_bounds: tuple[float, float] = (0.01, 100),
|
783
|
+
annotation_style: str = "2g",
|
784
|
+
) -> tuple[Figure, Axes, QuadMesh]:
|
674
785
|
"""Plot a heatmap of a 2D indexed dataframe."""
|
675
786
|
if len(cast(pd.MultiIndex, df.index).levels) != 2: # noqa: PLR2004
|
676
787
|
msg = "MultiIndex must have exactly two levels"
|
677
788
|
raise ValueError(msg)
|
678
|
-
|
679
|
-
|
680
|
-
|
681
|
-
|
682
|
-
|
683
|
-
|
684
|
-
|
685
|
-
|
686
|
-
|
687
|
-
|
688
|
-
|
689
|
-
|
690
|
-
|
691
|
-
|
692
|
-
|
693
|
-
|
789
|
+
df2d = df[variable].unstack().T
|
790
|
+
|
791
|
+
return _create_heatmap(
|
792
|
+
df=df2d,
|
793
|
+
xlabel=df2d.index.name,
|
794
|
+
ylabel=df2d.columns.name,
|
795
|
+
xticklabels=[f"{i:.2f}" for i in df2d.columns],
|
796
|
+
yticklabels=[f"{i:.2f}" for i in df2d.index],
|
797
|
+
ax=ax,
|
798
|
+
cax=cax,
|
799
|
+
annotate=annotate,
|
800
|
+
colorbar=colorbar,
|
801
|
+
invert_yaxis=invert_yaxis,
|
802
|
+
cmap=cmap,
|
803
|
+
norm=_norm(df2d) if norm is None else norm,
|
804
|
+
sci_annotation_bounds=sci_annotation_bounds,
|
805
|
+
annotation_style=annotation_style,
|
694
806
|
)
|
695
807
|
|
696
|
-
rotate_xlabels(ax, rotation=45, ha="right")
|
697
|
-
|
698
|
-
# Add colorbar
|
699
|
-
fig.colorbar(hm, ax=ax)
|
700
|
-
return fig, ax
|
701
|
-
|
702
808
|
|
703
809
|
def heatmaps_from_2d_idx(
|
704
810
|
df: pd.DataFrame,
|
@@ -708,6 +814,13 @@ def heatmaps_from_2d_idx(
|
|
708
814
|
row_height_factor: float = 0.6,
|
709
815
|
sharex: bool = True,
|
710
816
|
sharey: bool = False,
|
817
|
+
annotate: bool = False,
|
818
|
+
colorbar: bool = True,
|
819
|
+
invert_yaxis: bool = False,
|
820
|
+
cmap: str = "viridis",
|
821
|
+
norm: Normalize | None = None,
|
822
|
+
sci_annotation_bounds: tuple[float, float] = (0.01, 100),
|
823
|
+
annotation_style: str = "2g",
|
711
824
|
) -> FigAxs:
|
712
825
|
"""Plot multiple heatmaps of a 2D indexed dataframe."""
|
713
826
|
idx = cast(pd.MultiIndex, df.index)
|
@@ -722,7 +835,18 @@ def heatmaps_from_2d_idx(
|
|
722
835
|
grid=False,
|
723
836
|
)
|
724
837
|
for ax, var in zip(axs, df.columns, strict=False):
|
725
|
-
heatmap_from_2d_idx(
|
838
|
+
heatmap_from_2d_idx(
|
839
|
+
df,
|
840
|
+
var,
|
841
|
+
ax=ax,
|
842
|
+
annotate=annotate,
|
843
|
+
colorbar=colorbar,
|
844
|
+
invert_yaxis=invert_yaxis,
|
845
|
+
cmap=cmap,
|
846
|
+
norm=norm,
|
847
|
+
sci_annotation_bounds=sci_annotation_bounds,
|
848
|
+
annotation_style=annotation_style,
|
849
|
+
)
|
726
850
|
return fig, axs
|
727
851
|
|
728
852
|
|
@@ -892,6 +1016,9 @@ def relative_label_distribution(
|
|
892
1016
|
row_height: float = 3,
|
893
1017
|
sharey: bool = False,
|
894
1018
|
grid: bool = True,
|
1019
|
+
color: Color | None = None,
|
1020
|
+
linewidth: float | None = None,
|
1021
|
+
linestyle: Linestyle | None = None,
|
895
1022
|
) -> FigAxs:
|
896
1023
|
"""Plot the relative distribution of labels in the given data."""
|
897
1024
|
variables = list(mapper.label_variables) if subset is None else subset
|
@@ -903,20 +1030,37 @@ def relative_label_distribution(
|
|
903
1030
|
sharey=sharey,
|
904
1031
|
grid=grid,
|
905
1032
|
)
|
1033
|
+
# FIXME: rewrite as building a dict of dataframes
|
1034
|
+
# and passing it to lines_grouped
|
906
1035
|
if isinstance(mapper, LabelMapper):
|
907
1036
|
for ax, name in zip(axs, variables, strict=False):
|
908
1037
|
for i in range(mapper.label_variables[name]):
|
909
1038
|
isos = mapper.get_isotopomers_of_at_position(name, i)
|
910
1039
|
labels = cast(pd.DataFrame, concs.loc[:, isos])
|
911
1040
|
total = concs.loc[:, f"{name}__total"]
|
912
|
-
ax.plot(
|
1041
|
+
ax.plot(
|
1042
|
+
labels.index,
|
1043
|
+
(labels.sum(axis=1) / total),
|
1044
|
+
label=f"C{i + 1}",
|
1045
|
+
linewidth=linewidth,
|
1046
|
+
linestyle=linestyle,
|
1047
|
+
color=color,
|
1048
|
+
)
|
913
1049
|
ax.set_title(name)
|
914
1050
|
ax.legend()
|
915
1051
|
else:
|
916
1052
|
for ax, (name, isos) in zip(
|
917
|
-
axs,
|
1053
|
+
axs,
|
1054
|
+
mapper.get_isotopomers(variables).items(),
|
1055
|
+
strict=False,
|
918
1056
|
):
|
919
|
-
ax.plot(
|
1057
|
+
ax.plot(
|
1058
|
+
concs.index,
|
1059
|
+
concs.loc[:, isos],
|
1060
|
+
linewidth=linewidth,
|
1061
|
+
linestyle=linestyle,
|
1062
|
+
color=color,
|
1063
|
+
)
|
920
1064
|
ax.set_title(name)
|
921
1065
|
ax.legend([f"C{i + 1}" for i in range(len(isos))])
|
922
1066
|
|
mxlpy/report.py
CHANGED
@@ -48,12 +48,39 @@ def markdown(
|
|
48
48
|
) -> str:
|
49
49
|
"""Generate a markdown report comparing two models.
|
50
50
|
|
51
|
-
|
52
|
-
|
53
|
-
|
54
|
-
|
55
|
-
|
56
|
-
|
51
|
+
Parameters
|
52
|
+
----------
|
53
|
+
m1
|
54
|
+
The first model to compare
|
55
|
+
m2
|
56
|
+
The second model to compare
|
57
|
+
analyses
|
58
|
+
A list of functions that analyze both models and return a report section with image
|
59
|
+
rel_change
|
60
|
+
The relative change threshold for numerical differences
|
61
|
+
img_path
|
62
|
+
The path to save images
|
63
|
+
|
64
|
+
Returns
|
65
|
+
-------
|
66
|
+
str
|
67
|
+
Markdown formatted report comparing the two models
|
68
|
+
|
69
|
+
Examples
|
70
|
+
--------
|
71
|
+
>>> from mxlpy import Model
|
72
|
+
>>> m1 = Model().add_parameter("k1", 0.1).add_variable("S", 1.0)
|
73
|
+
>>> m2 = Model().add_parameter("k1", 0.2).add_variable("S", 1.0)
|
74
|
+
>>> report = markdown(m1, m2)
|
75
|
+
>>> "Parameters" in report and "k1" in report
|
76
|
+
True
|
77
|
+
|
78
|
+
>>> # With custom analysis function
|
79
|
+
>>> def custom_analysis(m1, m2, path):
|
80
|
+
... return "## Custom analysis", path / "image.png"
|
81
|
+
>>> report = markdown(m1, m2, analyses=[custom_analysis])
|
82
|
+
>>> "Custom analysis" in report
|
83
|
+
True
|
57
84
|
|
58
85
|
"""
|
59
86
|
content: list[str] = [
|
mxlpy/sbml/_import.py
CHANGED
@@ -12,7 +12,7 @@ import libsbml
|
|
12
12
|
import numpy as np # noqa: F401 # models might need it
|
13
13
|
import sympy
|
14
14
|
|
15
|
-
from mxlpy.model import Model, _sort_dependencies
|
15
|
+
from mxlpy.model import Dependency, Model, _sort_dependencies
|
16
16
|
from mxlpy.paths import default_tmp_dir
|
17
17
|
from mxlpy.sbml._data import (
|
18
18
|
AtomicUnit,
|
@@ -522,7 +522,10 @@ def _codgen(name: str, sbml: Parser) -> Path:
|
|
522
522
|
^ set(variables)
|
523
523
|
^ set(sbml.derived)
|
524
524
|
| {"time"},
|
525
|
-
elements=[
|
525
|
+
elements=[
|
526
|
+
Dependency(name=k, required=set(v.args), provided={k})
|
527
|
+
for k, v in sbml.initial_assignment.items()
|
528
|
+
],
|
526
529
|
)
|
527
530
|
|
528
531
|
if len(initial_assignment_order) > 0:
|
mxlpy/surrogates/__init__.py
CHANGED
@@ -19,13 +19,14 @@ from __future__ import annotations
|
|
19
19
|
import contextlib
|
20
20
|
|
21
21
|
with contextlib.suppress(ImportError):
|
22
|
-
from ._torch import
|
22
|
+
from ._torch import Torch, TorchTrainer, train_torch
|
23
23
|
|
24
|
-
from ._poly import
|
24
|
+
from ._poly import Polynomial, train_polynomial
|
25
25
|
|
26
26
|
__all__ = [
|
27
|
-
"
|
28
|
-
"
|
29
|
-
"
|
30
|
-
"
|
27
|
+
"Polynomial",
|
28
|
+
"Torch",
|
29
|
+
"TorchTrainer",
|
30
|
+
"train_polynomial",
|
31
|
+
"train_torch",
|
31
32
|
]
|
mxlpy/surrogates/_poly.py
CHANGED
@@ -9,9 +9,9 @@ from numpy import polynomial
|
|
9
9
|
from mxlpy.types import AbstractSurrogate, ArrayLike
|
10
10
|
|
11
11
|
__all__ = [
|
12
|
-
"
|
12
|
+
"Polynomial",
|
13
13
|
"PolynomialExpansion",
|
14
|
-
"
|
14
|
+
"train_polynomial",
|
15
15
|
]
|
16
16
|
|
17
17
|
# define custom type
|
@@ -26,23 +26,24 @@ PolynomialExpansion = (
|
|
26
26
|
|
27
27
|
|
28
28
|
@dataclass(kw_only=True)
|
29
|
-
class
|
29
|
+
class Polynomial(AbstractSurrogate):
|
30
30
|
model: PolynomialExpansion
|
31
31
|
|
32
32
|
def predict_raw(self, y: np.ndarray) -> np.ndarray:
|
33
33
|
return self.model(y)
|
34
34
|
|
35
35
|
|
36
|
-
def
|
37
|
-
feature: ArrayLike,
|
38
|
-
target: ArrayLike,
|
36
|
+
def train_polynomial(
|
37
|
+
feature: ArrayLike | pd.Series,
|
38
|
+
target: ArrayLike | pd.Series,
|
39
39
|
series: Literal[
|
40
40
|
"Power", "Chebyshev", "Legendre", "Laguerre", "Hermite", "HermiteE"
|
41
41
|
] = "Power",
|
42
42
|
degrees: Iterable[int] = (1, 2, 3, 4, 5, 6, 7),
|
43
43
|
surrogate_args: list[str] | None = None,
|
44
|
+
surrogate_outputs: list[str] | None = None,
|
44
45
|
surrogate_stoichiometries: dict[str, dict[str, float]] | None = None,
|
45
|
-
) -> tuple[
|
46
|
+
) -> tuple[Polynomial, pd.DataFrame]:
|
46
47
|
"""Train a surrogate model based on function series expansion.
|
47
48
|
|
48
49
|
Args:
|
@@ -51,7 +52,8 @@ def train_polynomial_surrogate(
|
|
51
52
|
series: Base functions for the surrogate model
|
52
53
|
degrees: Degrees of the polynomial to fit to the data.
|
53
54
|
surrogate_args: Additional arguments for the surrogate model.
|
54
|
-
|
55
|
+
surrogate_outputs: Names of the surrogate model outputs.
|
56
|
+
surrogate_stoichiometries: Mapping of variables to their stoichiometries
|
55
57
|
|
56
58
|
Returns:
|
57
59
|
PolySurrogate: Polynomial surrogate model.
|
@@ -83,9 +85,10 @@ def train_polynomial_surrogate(
|
|
83
85
|
# Choose the model with the lowest AIC
|
84
86
|
model = models[np.argmin(score)]
|
85
87
|
return (
|
86
|
-
|
88
|
+
Polynomial(
|
87
89
|
model=model,
|
88
90
|
args=surrogate_args if surrogate_args is not None else [],
|
91
|
+
outputs=surrogate_outputs if surrogate_outputs is not None else [],
|
89
92
|
stoichiometries=surrogate_stoichiometries
|
90
93
|
if surrogate_stoichiometries is not None
|
91
94
|
else {},
|