mxlpy-0.16.0-py3-none-any.whl → mxlpy-0.18.0-py3-none-any.whl

This diff shows the changes between publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
mxlpy/plot.py CHANGED
@@ -23,9 +23,12 @@ import contextlib
23
23
  from cycler import cycler
24
24
 
25
25
  __all__ = [
26
+ "Color",
26
27
  "FigAx",
27
28
  "FigAxs",
28
29
  "Linestyle",
30
+ "RGB",
31
+ "RGBA",
29
32
  "add_grid",
30
33
  "bars",
31
34
  "context",
@@ -71,7 +74,7 @@ from mpl_toolkits.mplot3d import Axes3D
71
74
  from mxlpy.label_map import LabelMapper
72
75
 
73
76
  if TYPE_CHECKING:
74
- from collections.abc import Generator
77
+ from collections.abc import Generator, Iterable
75
78
 
76
79
  from matplotlib.collections import QuadMesh
77
80
 
@@ -89,6 +92,11 @@ type Linestyle = Literal[
89
92
  "dashdot",
90
93
  ]
91
94
 
95
+
96
+ type RGB = tuple[float, float, float]
97
+ type RGBA = tuple[float, float, float, float]
98
+ type Color = str | RGB | RGBA
99
+
92
100
  ##########################################################################
93
101
  # Helpers
94
102
  ##########################################################################
@@ -133,6 +141,13 @@ def _get_norm(vmin: float, vmax: float) -> Normalize:
133
141
  return norm
134
142
 
135
143
 
144
+ def _norm(df: pd.DataFrame) -> Normalize:
145
+ """Get a normalization object for the given data."""
146
+ vmin = df.min().min()
147
+ vmax = df.max().max()
148
+ return _get_norm(vmin, vmax)
149
+
150
+
136
151
  def _norm_with_zero_center(df: pd.DataFrame) -> Normalize:
137
152
  """Get a normalization object with zero-centered values for the given data."""
138
153
  v = max(abs(df.min().min()), abs(df.max().max()))
@@ -166,7 +181,7 @@ def _split_large_groups[T](groups: list[list[T]], max_size: int) -> list[list[T]
166
181
  ) # type: ignore
167
182
 
168
183
 
169
- def _default_color(ax: Axes, color: str | None) -> str:
184
+ def _default_color(ax: Axes, color: Color | None) -> Color:
170
185
  """Get a default color for the given axis."""
171
186
  return f"C{len(ax.lines)}" if color is None else color
172
187
 
@@ -291,16 +306,16 @@ def reset_prop_cycle(ax: Axes) -> None:
291
306
  @contextlib.contextmanager
292
307
  def context(
293
308
  colors: list[str] | None = None,
294
- line_width: float | None = None,
295
- line_style: Linestyle | None = None,
309
+ linewidth: float | None = None,
310
+ linestyle: Linestyle | None = None,
296
311
  rc: dict[str, Any] | None = None,
297
312
  ) -> Generator[None, None, None]:
298
313
  """Context manager to set the defaults for plots.
299
314
 
300
315
  Args:
301
316
  colors: colors to use for the plot.
302
- line_width: line width to use for the plot.
303
- line_style: line style to use for the plot.
317
+ linewidth: line width to use for the plot.
318
+ linestyle: line style to use for the plot.
304
319
  rc: additional keyword arguments to pass to the rc context.
305
320
 
306
321
  """
@@ -309,11 +324,11 @@ def context(
309
324
  if colors is not None:
310
325
  rc["axes.prop_cycle"] = cycler(color=colors)
311
326
 
312
- if line_width is not None:
313
- rc["lines.linewidth"] = line_width
327
+ if linewidth is not None:
328
+ rc["lines.linewidth"] = linewidth
314
329
 
315
- if line_style is not None:
316
- rc["lines.linestyle"] = line_style
330
+ if linestyle is not None:
331
+ rc["lines.linestyle"] = linestyle
317
332
 
318
333
  with plt.rc_context(rc):
319
334
  yield
@@ -477,9 +492,12 @@ def lines(
477
492
  x: pd.DataFrame | pd.Series,
478
493
  *,
479
494
  ax: Axes | None = None,
480
- grid: bool = True,
481
495
  alpha: float = 1.0,
496
+ color: Color | list[Color] | None = None,
497
+ grid: bool = True,
482
498
  legend: bool = True,
499
+ linewidth: float | None = None,
500
+ linestyle: Linestyle | None = None,
483
501
  ) -> FigAx:
484
502
  """Plot multiple lines on the same axis."""
485
503
  fig, ax = _default_fig_ax(ax=ax, grid=grid)
@@ -487,6 +505,9 @@ def lines(
487
505
  x.index,
488
506
  x,
489
507
  alpha=alpha,
508
+ linewidth=linewidth,
509
+ linestyle=linestyle,
510
+ color=color,
490
511
  )
491
512
  _default_labels(ax, xlabel=x.index.name, ylabel=None)
492
513
  if legend:
@@ -497,6 +518,12 @@ def lines(
497
518
  return fig, ax
498
519
 
499
520
 
521
+ def _repeat_color_if_necessary(
522
+ color: list[list[Color]] | Color | None, n: int
523
+ ) -> Iterable[list[Color] | Color | None]:
524
+ return [color] * n if not isinstance(color, list) else color
525
+
526
+
500
527
  def lines_grouped(
501
528
  groups: list[pd.DataFrame] | list[pd.Series],
502
529
  *,
@@ -506,6 +533,9 @@ def lines_grouped(
506
533
  sharex: bool = True,
507
534
  sharey: bool = False,
508
535
  grid: bool = True,
536
+ color: Color | list[list[Color]] | None = None,
537
+ linewidth: float | None = None,
538
+ linestyle: Linestyle | None = None,
509
539
  ) -> FigAxs:
510
540
  """Plot multiple groups of lines on separate axes."""
511
541
  fig, axs = grid_layout(
@@ -518,8 +548,20 @@ def lines_grouped(
518
548
  grid=grid,
519
549
  )
520
550
 
521
- for group, ax in zip(groups, axs, strict=False):
522
- lines(group, ax=ax, grid=grid)
551
+ for group, ax, color_ in zip(
552
+ groups,
553
+ axs,
554
+ _repeat_color_if_necessary(color, n=len(groups)),
555
+ strict=False,
556
+ ):
557
+ lines(
558
+ group,
559
+ ax=ax,
560
+ grid=grid,
561
+ color=color_,
562
+ linewidth=linewidth,
563
+ linestyle=linestyle,
564
+ )
523
565
 
524
566
  for i in range(len(groups), len(axs)):
525
567
  axs[i].set_visible(False)
@@ -535,6 +577,9 @@ def line_autogrouped(
535
577
  row_height: float = 3,
536
578
  max_group_size: int = 6,
537
579
  grid: bool = True,
580
+ color: Color | list[list[Color]] | None = None,
581
+ linewidth: float | None = None,
582
+ linestyle: Linestyle | None = None,
538
583
  ) -> FigAxs:
539
584
  """Plot a series or dataframe with lines grouped by order of magnitude."""
540
585
  group_names = _split_large_groups(
@@ -556,6 +601,9 @@ def line_autogrouped(
556
601
  col_width=col_width,
557
602
  row_height=row_height,
558
603
  grid=grid,
604
+ color=color,
605
+ linestyle=linestyle,
606
+ linewidth=linewidth,
559
607
  )
560
608
 
561
609
 
@@ -564,7 +612,9 @@ def line_mean_std(
564
612
  *,
565
613
  label: str | None = None,
566
614
  ax: Axes | None = None,
567
- color: str | None = None,
615
+ color: Color | None = None,
616
+ linewidth: float | None = None,
617
+ linestyle: Linestyle | None = None,
568
618
  alpha: float = 0.2,
569
619
  grid: bool = True,
570
620
  ) -> FigAx:
@@ -579,6 +629,8 @@ def line_mean_std(
579
629
  mean,
580
630
  color=color,
581
631
  label=label,
632
+ linewidth=linewidth,
633
+ linestyle=linestyle,
582
634
  )
583
635
  ax.fill_between(
584
636
  df.index,
@@ -598,6 +650,9 @@ def lines_mean_std_from_2d_idx(
598
650
  ax: Axes | None = None,
599
651
  alpha: float = 0.2,
600
652
  grid: bool = True,
653
+ color: Color | None = None,
654
+ linewidth: float | None = None,
655
+ linestyle: Linestyle | None = None,
601
656
  ) -> FigAx:
602
657
  """Plot the mean and standard deviation of a 2D indexed dataframe."""
603
658
  if len(cast(pd.MultiIndex, df.index).levels) != 2: # noqa: PLR2004
@@ -612,25 +667,32 @@ def lines_mean_std_from_2d_idx(
612
667
  label=name,
613
668
  alpha=alpha,
614
669
  ax=ax,
670
+ color=color,
671
+ linestyle=linestyle,
672
+ linewidth=linewidth,
615
673
  )
616
674
  ax.legend()
617
675
  return fig, ax
618
676
 
619
677
 
620
- def heatmap(
621
- df: pd.DataFrame,
678
+ def _create_heatmap(
622
679
  *,
680
+ ax: Axes | None,
681
+ df: pd.DataFrame,
682
+ title: str | None = None,
683
+ xlabel: str,
684
+ ylabel: str,
685
+ xticklabels: list[str],
686
+ yticklabels: list[str],
687
+ norm: Normalize,
623
688
  annotate: bool = False,
624
689
  colorbar: bool = True,
625
690
  invert_yaxis: bool = True,
626
691
  cmap: str = "RdBu_r",
627
- norm: Normalize | None = None,
628
- ax: Axes | None = None,
629
692
  cax: Axes | None = None,
630
693
  sci_annotation_bounds: tuple[float, float] = (0.01, 100),
631
694
  annotation_style: str = "2g",
632
695
  ) -> tuple[Figure, Axes, QuadMesh]:
633
- """Plot a heatmap of the given data."""
634
696
  fig, ax = _default_fig_ax(
635
697
  ax=ax,
636
698
  figsize=(
@@ -639,17 +701,23 @@ def heatmap(
639
701
  ),
640
702
  grid=False,
641
703
  )
642
- if norm is None:
643
- norm = _norm_with_zero_center(df)
644
704
 
705
+ # Note: pcolormesh swaps index/columns
645
706
  hm = ax.pcolormesh(df, norm=norm, cmap=cmap)
707
+
708
+ if xlabel is not None:
709
+ ax.set_xlabel(xlabel)
710
+ if ylabel is not None:
711
+ ax.set_ylabel(ylabel)
712
+ if title is not None:
713
+ ax.set_title(title)
646
714
  ax.set_xticks(
647
715
  np.arange(0, len(df.columns), 1) + 0.5,
648
- labels=df.columns,
716
+ labels=xticklabels,
649
717
  )
650
718
  ax.set_yticks(
651
719
  np.arange(0, len(df.index), 1) + 0.5,
652
- labels=df.index,
720
+ labels=yticklabels,
653
721
  )
654
722
 
655
723
  if annotate:
@@ -666,39 +734,77 @@ def heatmap(
666
734
  return fig, ax, hm
667
735
 
668
736
 
737
+ def heatmap(
738
+ df: pd.DataFrame,
739
+ *,
740
+ ax: Axes | None = None,
741
+ title: str | None = None,
742
+ annotate: bool = False,
743
+ colorbar: bool = True,
744
+ invert_yaxis: bool = True,
745
+ cmap: str = "RdBu_r",
746
+ norm: Normalize | None = None,
747
+ cax: Axes | None = None,
748
+ sci_annotation_bounds: tuple[float, float] = (0.01, 100),
749
+ annotation_style: str = "2g",
750
+ ) -> tuple[Figure, Axes, QuadMesh]:
751
+ """Plot a heatmap of the given data."""
752
+ return _create_heatmap(
753
+ ax=ax,
754
+ df=df,
755
+ title=title,
756
+ xlabel=df.index.name,
757
+ ylabel=df.columns.name,
758
+ xticklabels=cast(list, df.columns),
759
+ yticklabels=cast(list, df.index),
760
+ annotate=annotate,
761
+ colorbar=colorbar,
762
+ invert_yaxis=invert_yaxis,
763
+ cmap=cmap,
764
+ norm=_norm_with_zero_center(df) if norm is None else norm,
765
+ cax=cax,
766
+ sci_annotation_bounds=sci_annotation_bounds,
767
+ annotation_style=annotation_style,
768
+ )
769
+
770
+
669
771
  def heatmap_from_2d_idx(
670
772
  df: pd.DataFrame,
671
773
  variable: str,
774
+ *,
672
775
  ax: Axes | None = None,
673
- ) -> FigAx:
776
+ annotate: bool = False,
777
+ colorbar: bool = True,
778
+ invert_yaxis: bool = False,
779
+ cmap: str = "viridis",
780
+ norm: Normalize | None = None,
781
+ cax: Axes | None = None,
782
+ sci_annotation_bounds: tuple[float, float] = (0.01, 100),
783
+ annotation_style: str = "2g",
784
+ ) -> tuple[Figure, Axes, QuadMesh]:
674
785
  """Plot a heatmap of a 2D indexed dataframe."""
675
786
  if len(cast(pd.MultiIndex, df.index).levels) != 2: # noqa: PLR2004
676
787
  msg = "MultiIndex must have exactly two levels"
677
788
  raise ValueError(msg)
678
-
679
- fig, ax = _default_fig_ax(ax=ax, grid=False)
680
- df2d = df[variable].unstack()
681
-
682
- ax.set_title(variable)
683
- # Note: pcolormesh swaps index/columns
684
- hm = ax.pcolormesh(df2d.T)
685
- ax.set_xlabel(df2d.index.name)
686
- ax.set_ylabel(df2d.columns.name)
687
- ax.set_xticks(
688
- np.arange(0, len(df2d.index), 1) + 0.5,
689
- labels=[f"{i:.2f}" for i in df2d.index],
690
- )
691
- ax.set_yticks(
692
- np.arange(0, len(df2d.columns), 1) + 0.5,
693
- labels=[f"{i:.2f}" for i in df2d.columns],
789
+ df2d = df[variable].unstack().T
790
+
791
+ return _create_heatmap(
792
+ df=df2d,
793
+ xlabel=df2d.index.name,
794
+ ylabel=df2d.columns.name,
795
+ xticklabels=[f"{i:.2f}" for i in df2d.columns],
796
+ yticklabels=[f"{i:.2f}" for i in df2d.index],
797
+ ax=ax,
798
+ cax=cax,
799
+ annotate=annotate,
800
+ colorbar=colorbar,
801
+ invert_yaxis=invert_yaxis,
802
+ cmap=cmap,
803
+ norm=_norm(df2d) if norm is None else norm,
804
+ sci_annotation_bounds=sci_annotation_bounds,
805
+ annotation_style=annotation_style,
694
806
  )
695
807
 
696
- rotate_xlabels(ax, rotation=45, ha="right")
697
-
698
- # Add colorbar
699
- fig.colorbar(hm, ax=ax)
700
- return fig, ax
701
-
702
808
 
703
809
  def heatmaps_from_2d_idx(
704
810
  df: pd.DataFrame,
@@ -708,6 +814,13 @@ def heatmaps_from_2d_idx(
708
814
  row_height_factor: float = 0.6,
709
815
  sharex: bool = True,
710
816
  sharey: bool = False,
817
+ annotate: bool = False,
818
+ colorbar: bool = True,
819
+ invert_yaxis: bool = False,
820
+ cmap: str = "viridis",
821
+ norm: Normalize | None = None,
822
+ sci_annotation_bounds: tuple[float, float] = (0.01, 100),
823
+ annotation_style: str = "2g",
711
824
  ) -> FigAxs:
712
825
  """Plot multiple heatmaps of a 2D indexed dataframe."""
713
826
  idx = cast(pd.MultiIndex, df.index)
@@ -722,7 +835,18 @@ def heatmaps_from_2d_idx(
722
835
  grid=False,
723
836
  )
724
837
  for ax, var in zip(axs, df.columns, strict=False):
725
- heatmap_from_2d_idx(df, var, ax=ax)
838
+ heatmap_from_2d_idx(
839
+ df,
840
+ var,
841
+ ax=ax,
842
+ annotate=annotate,
843
+ colorbar=colorbar,
844
+ invert_yaxis=invert_yaxis,
845
+ cmap=cmap,
846
+ norm=norm,
847
+ sci_annotation_bounds=sci_annotation_bounds,
848
+ annotation_style=annotation_style,
849
+ )
726
850
  return fig, axs
727
851
 
728
852
 
@@ -892,6 +1016,9 @@ def relative_label_distribution(
892
1016
  row_height: float = 3,
893
1017
  sharey: bool = False,
894
1018
  grid: bool = True,
1019
+ color: Color | None = None,
1020
+ linewidth: float | None = None,
1021
+ linestyle: Linestyle | None = None,
895
1022
  ) -> FigAxs:
896
1023
  """Plot the relative distribution of labels in the given data."""
897
1024
  variables = list(mapper.label_variables) if subset is None else subset
@@ -903,20 +1030,37 @@ def relative_label_distribution(
903
1030
  sharey=sharey,
904
1031
  grid=grid,
905
1032
  )
1033
+ # FIXME: rewrite as building a dict of dataframes
1034
+ # and passing it to lines_grouped
906
1035
  if isinstance(mapper, LabelMapper):
907
1036
  for ax, name in zip(axs, variables, strict=False):
908
1037
  for i in range(mapper.label_variables[name]):
909
1038
  isos = mapper.get_isotopomers_of_at_position(name, i)
910
1039
  labels = cast(pd.DataFrame, concs.loc[:, isos])
911
1040
  total = concs.loc[:, f"{name}__total"]
912
- ax.plot(labels.index, (labels.sum(axis=1) / total), label=f"C{i + 1}")
1041
+ ax.plot(
1042
+ labels.index,
1043
+ (labels.sum(axis=1) / total),
1044
+ label=f"C{i + 1}",
1045
+ linewidth=linewidth,
1046
+ linestyle=linestyle,
1047
+ color=color,
1048
+ )
913
1049
  ax.set_title(name)
914
1050
  ax.legend()
915
1051
  else:
916
1052
  for ax, (name, isos) in zip(
917
- axs, mapper.get_isotopomers(variables).items(), strict=False
1053
+ axs,
1054
+ mapper.get_isotopomers(variables).items(),
1055
+ strict=False,
918
1056
  ):
919
- ax.plot(concs.index, concs.loc[:, isos])
1057
+ ax.plot(
1058
+ concs.index,
1059
+ concs.loc[:, isos],
1060
+ linewidth=linewidth,
1061
+ linestyle=linestyle,
1062
+ color=color,
1063
+ )
920
1064
  ax.set_title(name)
921
1065
  ax.legend([f"C{i + 1}" for i in range(len(isos))])
922
1066
 
mxlpy/report.py CHANGED
@@ -48,12 +48,39 @@ def markdown(
48
48
  ) -> str:
49
49
  """Generate a markdown report comparing two models.
50
50
 
51
- Args:
52
- m1: The first model to compare.
53
- m2: The second model to compare.
54
- analyses: A list of functions that take a Path and return a tuple of a string and a Path. Defaults to None.
55
- rel_change: The relative change threshold for numerical differences. Defaults to 1e-2.
56
- img_path: The path to save images. Defaults to Path().
51
+ Parameters
52
+ ----------
53
+ m1
54
+ The first model to compare
55
+ m2
56
+ The second model to compare
57
+ analyses
58
+ A list of functions that analyze both models and return a report section with image
59
+ rel_change
60
+ The relative change threshold for numerical differences
61
+ img_path
62
+ The path to save images
63
+
64
+ Returns
65
+ -------
66
+ str
67
+ Markdown formatted report comparing the two models
68
+
69
+ Examples
70
+ --------
71
+ >>> from mxlpy import Model
72
+ >>> m1 = Model().add_parameter("k1", 0.1).add_variable("S", 1.0)
73
+ >>> m2 = Model().add_parameter("k1", 0.2).add_variable("S", 1.0)
74
+ >>> report = markdown(m1, m2)
75
+ >>> "Parameters" in report and "k1" in report
76
+ True
77
+
78
+ >>> # With custom analysis function
79
+ >>> def custom_analysis(m1, m2, path):
80
+ ... return "## Custom analysis", path / "image.png"
81
+ >>> report = markdown(m1, m2, analyses=[custom_analysis])
82
+ >>> "Custom analysis" in report
83
+ True
57
84
 
58
85
  """
59
86
  content: list[str] = [
mxlpy/sbml/_import.py CHANGED
@@ -12,7 +12,7 @@ import libsbml
12
12
  import numpy as np # noqa: F401 # models might need it
13
13
  import sympy
14
14
 
15
- from mxlpy.model import Model, _sort_dependencies
15
+ from mxlpy.model import Dependency, Model, _sort_dependencies
16
16
  from mxlpy.paths import default_tmp_dir
17
17
  from mxlpy.sbml._data import (
18
18
  AtomicUnit,
@@ -522,7 +522,10 @@ def _codgen(name: str, sbml: Parser) -> Path:
522
522
  ^ set(variables)
523
523
  ^ set(sbml.derived)
524
524
  | {"time"},
525
- elements=[(k, set(v.args)) for k, v in sbml.initial_assignment.items()],
525
+ elements=[
526
+ Dependency(name=k, required=set(v.args), provided={k})
527
+ for k, v in sbml.initial_assignment.items()
528
+ ],
526
529
  )
527
530
 
528
531
  if len(initial_assignment_order) > 0:
mxlpy/surrogates/__init__.py CHANGED
@@ -19,13 +19,14 @@ from __future__ import annotations
19
19
  import contextlib
20
20
 
21
21
  with contextlib.suppress(ImportError):
22
- from ._torch import TorchSurrogate, train_torch_surrogate
22
+ from ._torch import Torch, TorchTrainer, train_torch
23
23
 
24
- from ._poly import PolySurrogate, train_polynomial_surrogate
24
+ from ._poly import Polynomial, train_polynomial
25
25
 
26
26
  __all__ = [
27
- "PolySurrogate",
28
- "TorchSurrogate",
29
- "train_polynomial_surrogate",
30
- "train_torch_surrogate",
27
+ "Polynomial",
28
+ "Torch",
29
+ "TorchTrainer",
30
+ "train_polynomial",
31
+ "train_torch",
31
32
  ]
mxlpy/surrogates/_poly.py CHANGED
@@ -9,9 +9,9 @@ from numpy import polynomial
9
9
  from mxlpy.types import AbstractSurrogate, ArrayLike
10
10
 
11
11
  __all__ = [
12
- "PolySurrogate",
12
+ "Polynomial",
13
13
  "PolynomialExpansion",
14
- "train_polynomial_surrogate",
14
+ "train_polynomial",
15
15
  ]
16
16
 
17
17
  # define custom type
@@ -26,23 +26,24 @@ PolynomialExpansion = (
26
26
 
27
27
 
28
28
  @dataclass(kw_only=True)
29
- class PolySurrogate(AbstractSurrogate):
29
+ class Polynomial(AbstractSurrogate):
30
30
  model: PolynomialExpansion
31
31
 
32
32
  def predict_raw(self, y: np.ndarray) -> np.ndarray:
33
33
  return self.model(y)
34
34
 
35
35
 
36
- def train_polynomial_surrogate(
37
- feature: ArrayLike,
38
- target: ArrayLike,
36
+ def train_polynomial(
37
+ feature: ArrayLike | pd.Series,
38
+ target: ArrayLike | pd.Series,
39
39
  series: Literal[
40
40
  "Power", "Chebyshev", "Legendre", "Laguerre", "Hermite", "HermiteE"
41
41
  ] = "Power",
42
42
  degrees: Iterable[int] = (1, 2, 3, 4, 5, 6, 7),
43
43
  surrogate_args: list[str] | None = None,
44
+ surrogate_outputs: list[str] | None = None,
44
45
  surrogate_stoichiometries: dict[str, dict[str, float]] | None = None,
45
- ) -> tuple[PolySurrogate, pd.DataFrame]:
46
+ ) -> tuple[Polynomial, pd.DataFrame]:
46
47
  """Train a surrogate model based on function series expansion.
47
48
 
48
49
  Args:
@@ -51,7 +52,8 @@ def train_polynomial_surrogate(
51
52
  series: Base functions for the surrogate model
52
53
  degrees: Degrees of the polynomial to fit to the data.
53
54
  surrogate_args: Additional arguments for the surrogate model.
54
- surrogate_stoichiometries: Stoichiometries for the surrogate model.
55
+ surrogate_outputs: Names of the surrogate model outputs.
56
+ surrogate_stoichiometries: Mapping of variables to their stoichiometries
55
57
 
56
58
  Returns:
57
59
  PolySurrogate: Polynomial surrogate model.
@@ -83,9 +85,10 @@ def train_polynomial_surrogate(
83
85
  # Choose the model with the lowest AIC
84
86
  model = models[np.argmin(score)]
85
87
  return (
86
- PolySurrogate(
88
+ Polynomial(
87
89
  model=model,
88
90
  args=surrogate_args if surrogate_args is not None else [],
91
+ outputs=surrogate_outputs if surrogate_outputs is not None else [],
89
92
  stoichiometries=surrogate_stoichiometries
90
93
  if surrogate_stoichiometries is not None
91
94
  else {},