mxlpy 0.17.0__py3-none-any.whl → 0.19.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
mxlpy/plot.py CHANGED
@@ -23,9 +23,12 @@ import contextlib
 from cycler import cycler
 
 __all__ = [
+    "Color",
     "FigAx",
     "FigAxs",
     "Linestyle",
+    "RGB",
+    "RGBA",
     "add_grid",
     "bars",
     "context",
@@ -71,7 +74,7 @@ from mpl_toolkits.mplot3d import Axes3D
 from mxlpy.label_map import LabelMapper
 
 if TYPE_CHECKING:
-    from collections.abc import Generator
+    from collections.abc import Generator, Iterable
 
     from matplotlib.collections import QuadMesh
 
@@ -89,6 +92,11 @@ type Linestyle = Literal[
     "dashdot",
 ]
 
+
+type RGB = tuple[float, float, float]
+type RGBA = tuple[float, float, float, float]
+type Color = str | RGB | RGBA
+
 ##########################################################################
 # Helpers
 ##########################################################################
@@ -133,6 +141,13 @@ def _get_norm(vmin: float, vmax: float) -> Normalize:
     return norm
 
 
+def _norm(df: pd.DataFrame) -> Normalize:
+    """Get a normalization object for the given data."""
+    vmin = df.min().min()
+    vmax = df.max().max()
+    return _get_norm(vmin, vmax)
+
+
 def _norm_with_zero_center(df: pd.DataFrame) -> Normalize:
     """Get a normalization object with zero-centered values for the given data."""
     v = max(abs(df.min().min()), abs(df.max().max()))
@@ -166,7 +181,7 @@ def _split_large_groups[T](groups: list[list[T]], max_size: int) -> list[list[T]
     )  # type: ignore
 
 
-def _default_color(ax: Axes, color: str | None) -> str:
+def _default_color(ax: Axes, color: Color | None) -> Color:
    """Get a default color for the given axis."""
    return f"C{len(ax.lines)}" if color is None else color
 
@@ -291,16 +306,16 @@ def reset_prop_cycle(ax: Axes) -> None:
 @contextlib.contextmanager
 def context(
     colors: list[str] | None = None,
-    line_width: float | None = None,
-    line_style: Linestyle | None = None,
+    linewidth: float | None = None,
+    linestyle: Linestyle | None = None,
     rc: dict[str, Any] | None = None,
 ) -> Generator[None, None, None]:
     """Context manager to set the defaults for plots.
 
     Args:
         colors: colors to use for the plot.
-        line_width: line width to use for the plot.
-        line_style: line style to use for the plot.
+        linewidth: line width to use for the plot.
+        linestyle: line style to use for the plot.
         rc: additional keyword arguments to pass to the rc context.
 
     """
@@ -309,11 +324,11 @@ def context(
     if colors is not None:
         rc["axes.prop_cycle"] = cycler(color=colors)
 
-    if line_width is not None:
-        rc["lines.linewidth"] = line_width
+    if linewidth is not None:
+        rc["lines.linewidth"] = linewidth
 
-    if line_style is not None:
-        rc["lines.linestyle"] = line_style
+    if linestyle is not None:
+        rc["lines.linestyle"] = linestyle
 
     with plt.rc_context(rc):
         yield
@@ -477,9 +492,12 @@ def lines(
     x: pd.DataFrame | pd.Series,
     *,
     ax: Axes | None = None,
-    grid: bool = True,
     alpha: float = 1.0,
+    color: Color | list[Color] | None = None,
+    grid: bool = True,
     legend: bool = True,
+    linewidth: float | None = None,
+    linestyle: Linestyle | None = None,
 ) -> FigAx:
     """Plot multiple lines on the same axis."""
     fig, ax = _default_fig_ax(ax=ax, grid=grid)
@@ -487,6 +505,9 @@ def lines(
         x.index,
         x,
         alpha=alpha,
+        linewidth=linewidth,
+        linestyle=linestyle,
+        color=color,
     )
     _default_labels(ax, xlabel=x.index.name, ylabel=None)
     if legend:
@@ -497,6 +518,12 @@ def lines(
     return fig, ax
 
 
+def _repeat_color_if_necessary(
+    color: list[list[Color]] | Color | None, n: int
+) -> Iterable[list[Color] | Color | None]:
+    return [color] * n if not isinstance(color, list) else color
+
+
 def lines_grouped(
     groups: list[pd.DataFrame] | list[pd.Series],
     *,
@@ -506,6 +533,9 @@ def lines_grouped(
     sharex: bool = True,
     sharey: bool = False,
     grid: bool = True,
+    color: Color | list[list[Color]] | None = None,
+    linewidth: float | None = None,
+    linestyle: Linestyle | None = None,
 ) -> FigAxs:
     """Plot multiple groups of lines on separate axes."""
     fig, axs = grid_layout(
@@ -518,8 +548,20 @@ def lines_grouped(
         grid=grid,
     )
 
-    for group, ax in zip(groups, axs, strict=False):
-        lines(group, ax=ax, grid=grid)
+    for group, ax, color_ in zip(
+        groups,
+        axs,
+        _repeat_color_if_necessary(color, n=len(groups)),
+        strict=False,
+    ):
+        lines(
+            group,
+            ax=ax,
+            grid=grid,
+            color=color_,
+            linewidth=linewidth,
+            linestyle=linestyle,
+        )
 
     for i in range(len(groups), len(axs)):
         axs[i].set_visible(False)
@@ -535,6 +577,9 @@ def line_autogrouped(
     row_height: float = 3,
     max_group_size: int = 6,
     grid: bool = True,
+    color: Color | list[list[Color]] | None = None,
+    linewidth: float | None = None,
+    linestyle: Linestyle | None = None,
 ) -> FigAxs:
     """Plot a series or dataframe with lines grouped by order of magnitude."""
     group_names = _split_large_groups(
@@ -556,6 +601,9 @@ def line_autogrouped(
         col_width=col_width,
         row_height=row_height,
         grid=grid,
+        color=color,
+        linestyle=linestyle,
+        linewidth=linewidth,
     )
 
 
@@ -564,7 +612,9 @@ def line_mean_std(
     *,
     label: str | None = None,
     ax: Axes | None = None,
-    color: str | None = None,
+    color: Color | None = None,
+    linewidth: float | None = None,
+    linestyle: Linestyle | None = None,
     alpha: float = 0.2,
     grid: bool = True,
 ) -> FigAx:
@@ -579,6 +629,8 @@ def line_mean_std(
         mean,
         color=color,
         label=label,
+        linewidth=linewidth,
+        linestyle=linestyle,
     )
     ax.fill_between(
         df.index,
598
650
  ax: Axes | None = None,
599
651
  alpha: float = 0.2,
600
652
  grid: bool = True,
653
+ color: Color | None = None,
654
+ linewidth: float | None = None,
655
+ linestyle: Linestyle | None = None,
601
656
  ) -> FigAx:
602
657
  """Plot the mean and standard deviation of a 2D indexed dataframe."""
603
658
  if len(cast(pd.MultiIndex, df.index).levels) != 2: # noqa: PLR2004
@@ -612,25 +667,32 @@ def lines_mean_std_from_2d_idx(
             label=name,
             alpha=alpha,
             ax=ax,
+            color=color,
+            linestyle=linestyle,
+            linewidth=linewidth,
         )
     ax.legend()
     return fig, ax
 
 
-def heatmap(
-    df: pd.DataFrame,
+def _create_heatmap(
     *,
+    ax: Axes | None,
+    df: pd.DataFrame,
+    title: str | None = None,
+    xlabel: str,
+    ylabel: str,
+    xticklabels: list[str],
+    yticklabels: list[str],
+    norm: Normalize,
     annotate: bool = False,
     colorbar: bool = True,
     invert_yaxis: bool = True,
     cmap: str = "RdBu_r",
-    norm: Normalize | None = None,
-    ax: Axes | None = None,
     cax: Axes | None = None,
     sci_annotation_bounds: tuple[float, float] = (0.01, 100),
     annotation_style: str = "2g",
 ) -> tuple[Figure, Axes, QuadMesh]:
-    """Plot a heatmap of the given data."""
     fig, ax = _default_fig_ax(
         ax=ax,
         figsize=(
@@ -639,17 +701,23 @@ def heatmap(
         ),
         grid=False,
     )
-    if norm is None:
-        norm = _norm_with_zero_center(df)
 
+    # Note: pcolormesh swaps index/columns
     hm = ax.pcolormesh(df, norm=norm, cmap=cmap)
+
+    if xlabel is not None:
+        ax.set_xlabel(xlabel)
+    if ylabel is not None:
+        ax.set_ylabel(ylabel)
+    if title is not None:
+        ax.set_title(title)
     ax.set_xticks(
-        np.arange(0, len(df.columns), 1) + 0.5,
-        labels=df.columns,
+        np.arange(0, len(df.columns), 1, dtype=float) + 0.5,
+        labels=xticklabels,
     )
     ax.set_yticks(
-        np.arange(0, len(df.index), 1) + 0.5,
-        labels=df.index,
+        np.arange(0, len(df.index), 1, dtype=float) + 0.5,
+        labels=yticklabels,
     )
 
     if annotate:
@@ -666,39 +734,77 @@ def heatmap(
     return fig, ax, hm
 
 
+def heatmap(
+    df: pd.DataFrame,
+    *,
+    ax: Axes | None = None,
+    title: str | None = None,
+    annotate: bool = False,
+    colorbar: bool = True,
+    invert_yaxis: bool = True,
+    cmap: str = "RdBu_r",
+    norm: Normalize | None = None,
+    cax: Axes | None = None,
+    sci_annotation_bounds: tuple[float, float] = (0.01, 100),
+    annotation_style: str = "2g",
+) -> tuple[Figure, Axes, QuadMesh]:
+    """Plot a heatmap of the given data."""
+    return _create_heatmap(
+        ax=ax,
+        df=df,
+        title=title,
+        xlabel=df.index.name,
+        ylabel=df.columns.name,
+        xticklabels=cast(list, df.columns),
+        yticklabels=cast(list, df.index),
+        annotate=annotate,
+        colorbar=colorbar,
+        invert_yaxis=invert_yaxis,
+        cmap=cmap,
+        norm=_norm_with_zero_center(df) if norm is None else norm,
+        cax=cax,
+        sci_annotation_bounds=sci_annotation_bounds,
+        annotation_style=annotation_style,
+    )
+
+
 def heatmap_from_2d_idx(
     df: pd.DataFrame,
     variable: str,
+    *,
     ax: Axes | None = None,
-) -> FigAx:
+    annotate: bool = False,
+    colorbar: bool = True,
+    invert_yaxis: bool = False,
+    cmap: str = "viridis",
+    norm: Normalize | None = None,
+    cax: Axes | None = None,
+    sci_annotation_bounds: tuple[float, float] = (0.01, 100),
+    annotation_style: str = "2g",
+) -> tuple[Figure, Axes, QuadMesh]:
     """Plot a heatmap of a 2D indexed dataframe."""
     if len(cast(pd.MultiIndex, df.index).levels) != 2:  # noqa: PLR2004
         msg = "MultiIndex must have exactly two levels"
         raise ValueError(msg)
-
-    fig, ax = _default_fig_ax(ax=ax, grid=False)
-    df2d = df[variable].unstack()
-
-    ax.set_title(variable)
-    # Note: pcolormesh swaps index/columns
-    hm = ax.pcolormesh(df2d.T)
-    ax.set_xlabel(df2d.index.name)
-    ax.set_ylabel(df2d.columns.name)
-    ax.set_xticks(
-        np.arange(0, len(df2d.index), 1) + 0.5,
-        labels=[f"{i:.2f}" for i in df2d.index],
-    )
-    ax.set_yticks(
-        np.arange(0, len(df2d.columns), 1) + 0.5,
-        labels=[f"{i:.2f}" for i in df2d.columns],
+    df2d = df[variable].unstack().T
+
+    return _create_heatmap(
+        df=df2d,
+        xlabel=df2d.index.name,
+        ylabel=df2d.columns.name,
+        xticklabels=[f"{i:.2f}" for i in df2d.columns],
+        yticklabels=[f"{i:.2f}" for i in df2d.index],
+        ax=ax,
+        cax=cax,
+        annotate=annotate,
+        colorbar=colorbar,
+        invert_yaxis=invert_yaxis,
+        cmap=cmap,
+        norm=_norm(df2d) if norm is None else norm,
+        sci_annotation_bounds=sci_annotation_bounds,
+        annotation_style=annotation_style,
     )
 
-    rotate_xlabels(ax, rotation=45, ha="right")
-
-    # Add colorbar
-    fig.colorbar(hm, ax=ax)
-    return fig, ax
-
 
 def heatmaps_from_2d_idx(
     df: pd.DataFrame,
@@ -708,6 +814,13 @@ def heatmaps_from_2d_idx(
     row_height_factor: float = 0.6,
     sharex: bool = True,
     sharey: bool = False,
+    annotate: bool = False,
+    colorbar: bool = True,
+    invert_yaxis: bool = False,
+    cmap: str = "viridis",
+    norm: Normalize | None = None,
+    sci_annotation_bounds: tuple[float, float] = (0.01, 100),
+    annotation_style: str = "2g",
 ) -> FigAxs:
     """Plot multiple heatmaps of a 2D indexed dataframe."""
     idx = cast(pd.MultiIndex, df.index)
@@ -722,7 +835,18 @@ def heatmaps_from_2d_idx(
         grid=False,
     )
     for ax, var in zip(axs, df.columns, strict=False):
-        heatmap_from_2d_idx(df, var, ax=ax)
+        heatmap_from_2d_idx(
+            df,
+            var,
+            ax=ax,
+            annotate=annotate,
+            colorbar=colorbar,
+            invert_yaxis=invert_yaxis,
+            cmap=cmap,
+            norm=norm,
+            sci_annotation_bounds=sci_annotation_bounds,
+            annotation_style=annotation_style,
+        )
     return fig, axs
 
 
@@ -892,6 +1016,9 @@ def relative_label_distribution(
     row_height: float = 3,
     sharey: bool = False,
     grid: bool = True,
+    color: Color | None = None,
+    linewidth: float | None = None,
+    linestyle: Linestyle | None = None,
 ) -> FigAxs:
     """Plot the relative distribution of labels in the given data."""
     variables = list(mapper.label_variables) if subset is None else subset
@@ -903,20 +1030,37 @@ def relative_label_distribution(
         sharey=sharey,
         grid=grid,
     )
+    # FIXME: rewrite as building a dict of dataframes
+    # and passing it to lines_grouped
     if isinstance(mapper, LabelMapper):
         for ax, name in zip(axs, variables, strict=False):
             for i in range(mapper.label_variables[name]):
                 isos = mapper.get_isotopomers_of_at_position(name, i)
                 labels = cast(pd.DataFrame, concs.loc[:, isos])
                 total = concs.loc[:, f"{name}__total"]
-                ax.plot(labels.index, (labels.sum(axis=1) / total), label=f"C{i + 1}")
+                ax.plot(
+                    labels.index,
+                    (labels.sum(axis=1) / total),
+                    label=f"C{i + 1}",
+                    linewidth=linewidth,
+                    linestyle=linestyle,
+                    color=color,
+                )
             ax.set_title(name)
             ax.legend()
     else:
         for ax, (name, isos) in zip(
-            axs, mapper.get_isotopomers(variables).items(), strict=False
+            axs,
+            mapper.get_isotopomers(variables).items(),
+            strict=False,
         ):
-            ax.plot(concs.index, concs.loc[:, isos])
+            ax.plot(
+                concs.index,
+                concs.loc[:, isos],
+                linewidth=linewidth,
+                linestyle=linestyle,
+                color=color,
+            )
             ax.set_title(name)
             ax.legend([f"C{i + 1}" for i in range(len(isos))])
 
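The plot.py changes above introduce `RGB`/`RGBA`/`Color` type aliases, rename the `context` keywords to matplotlib's own spelling (`linewidth`/`linestyle`), and thread `color`, `linewidth`, and `linestyle` through the line-plotting helpers. A minimal usage sketch of the new API; the DataFrame is illustrative, not part of the package:

```python
import pandas as pd

from mxlpy import plot

# Toy time course; any DataFrame with a named index works the same way.
time_course = pd.DataFrame(
    {"S": [1.0, 0.7, 0.5], "P": [0.0, 0.3, 0.5]},
    index=pd.Index([0.0, 1.0, 2.0], name="time"),
)

# context() now takes linewidth/linestyle (previously line_width/line_style),
# and lines() accepts per-call styling; a Color may be a matplotlib color
# name or an RGB/RGBA tuple per the new aliases.
with plot.context(colors=["tab:blue", "tab:orange"], linewidth=2.0):
    fig, ax = plot.lines(time_course, linestyle="dashed")
```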
mxlpy/sbml/_export.py CHANGED
@@ -447,7 +447,14 @@ def _create_sbml_variables(
         cpd.setConstant(False)
         cpd.setBoundaryCondition(False)
         cpd.setHasOnlySubstanceUnits(False)
-        cpd.setInitialAmount(float(value))
+        if isinstance(value, Derived):
+            ar = sbml_model.createInitialAssignment()
+            ar.setId(_convert_id_to_sbml(id_=name, prefix="IA"))
+            ar.setName(_convert_id_to_sbml(id_=name, prefix="IA"))
+            ar.setVariable(_convert_id_to_sbml(id_=name, prefix="IA"))
+            ar.setMath(_sbmlify_fn(value.fn, value.args))
+        else:
+            cpd.setInitialAmount(float(value))
 
 
 def _create_sbml_derived_variables(*, model: Model, sbml_model: libsbml.Model) -> None:
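The hunk above makes `_create_sbml_variables` emit an SBML InitialAssignment when a variable's initial value is a `Derived` expression rather than a plain number. For orientation, a minimal standalone libsbml sketch of that pattern (not mxlpy code; the ids and formula are made up, and it uses libsbml's spec-level `setSymbol` for the assignment target):

```python
import libsbml

# Build a bare SBML L3V2 document with one initial assignment.
doc = libsbml.SBMLDocument(3, 2)
model = doc.createModel()

ia = model.createInitialAssignment()
ia.setSymbol("S")  # id of the species whose initial amount is derived
# The math corresponds to _sbmlify_fn(value.fn, value.args) in the hunk.
ia.setMath(libsbml.parseL3Formula("2 * k"))
```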
mxlpy/surrogates/__init__.py CHANGED
@@ -4,29 +4,37 @@ This module provides classes and functions for creating and training surrogate m
 for metabolic simulations. It includes functionality for both steady-state and time-series
 data using neural networks.
 
-Classes:
-    AbstractSurrogate: Abstract base class for surrogate models.
-    TorchSurrogate: Surrogate model using PyTorch.
-    Approximator: Neural network approximator for surrogate modeling.
-
-Functions:
-    train_torch_surrogate: Train a PyTorch surrogate model.
-    train_torch_time_course_estimator: Train a PyTorch time course estimator.
 """
 
 from __future__ import annotations
 
-import contextlib
+from typing import TYPE_CHECKING
 
-with contextlib.suppress(ImportError):
-    from ._torch import Torch, TorchTrainer, train_torch
+if TYPE_CHECKING:
+    import contextlib
 
-from ._poly import Polynomial, train_polynomial
+    with contextlib.suppress(ImportError):
+        from . import _keras as keras
+        from . import _torch as torch
+else:
+    from lazy_import import lazy_module
+
+    keras = lazy_module(
+        "mxlpy.surrogates._keras",
+        error_strings={"module": "keras", "install_name": "mxlpy[tf]"},
+    )
+    torch = lazy_module(
+        "mxlpy.surrogates._torch",
+        error_strings={"module": "torch", "install_name": "mxlpy[torch]"},
+    )
+
+
+from . import _poly as poly
+from . import _qss as qss
 
 __all__ = [
-    "Polynomial",
-    "Torch",
-    "TorchTrainer",
-    "train_polynomial",
-    "train_torch",
+    "keras",
+    "poly",
+    "qss",
+    "torch",
 ]
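With this rewrite, `mxlpy.surrogates` exposes submodule namespaces (`keras`, `poly`, `qss`, `torch`) instead of the previous flat names, and the heavy backends are imported lazily via `lazy_import`, so importing the package no longer pulls in TensorFlow or PyTorch; a missing backend raises an error pointing at the matching extra (`mxlpy[tf]` or `mxlpy[torch]`). A usage sketch based on the `train` signature in the new `_keras` module below; the data is illustrative:

```python
import pandas as pd

from mxlpy import surrogates

features = pd.DataFrame({"x": [0.0, 0.5, 1.0]})
targets = pd.DataFrame({"y": [0.0, 0.25, 1.0]})

# 0.17.0: from mxlpy.surrogates import train_torch
# 0.19.0: namespace access; keras is only imported on first use.
surrogate, losses = surrogates.keras.train(
    features=features,
    targets=targets,
    epochs=10,
)
```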
mxlpy/surrogates/_keras.py ADDED
@@ -0,0 +1,137 @@
+from dataclasses import dataclass
+from typing import Self, cast
+
+import keras
+import numpy as np
+import pandas as pd
+
+from mxlpy.nn._keras import MLP
+from mxlpy.nn._keras import train as _train
+from mxlpy.types import AbstractSurrogate, Array, Derived
+
+__all__ = [
+    "DefaultLoss",
+    "DefaultOptimizer",
+    "LossFn",
+    "Optimizer",
+    "Surrogate",
+    "Trainer",
+    "train",
+]
+
+type Optimizer = keras.optimizers.Optimizer | str
+type LossFn = keras.losses.Loss | str
+
+DefaultOptimizer = keras.optimizers.Adam()
+DefaultLoss = keras.losses.MeanAbsoluteError()
+
+
+@dataclass(kw_only=True)
+class Surrogate(AbstractSurrogate):
+    model: keras.Model
+
+    def predict_raw(self, y: Array) -> Array:
+        return np.atleast_1d(np.squeeze(self.model.predict(y)))
+
+    def predict(
+        self, args: dict[str, float | pd.Series | pd.DataFrame]
+    ) -> dict[str, float]:
+        return dict(
+            zip(
+                self.outputs,
+                self.predict_raw(np.array([args[arg] for arg in self.args])),
+                strict=True,
+            )
+        )
+
+
+@dataclass(init=False)
+class Trainer:
+    features: pd.DataFrame
+    targets: pd.DataFrame
+    model: keras.Model
+    optimizer: Optimizer | str
+    losses: list[pd.Series]
+    loss_fn: LossFn
+
+    def __init__(
+        self,
+        features: pd.DataFrame,
+        targets: pd.DataFrame,
+        model: keras.Model | None = None,
+        optimizer: Optimizer = DefaultOptimizer,
+        loss: LossFn = DefaultLoss,
+    ) -> None:
+        self.features = features
+        self.targets = targets
+        if model is None:
+            model = MLP(
+                n_inputs=len(features.columns),
+                neurons_per_layer=[50, 50, len(targets.columns)],
+            )
+        self.model = model
+        model.compile(optimizer=cast(str, optimizer), loss=loss)
+
+        self.losses = []
+
+    def train(self, epochs: int, batch_size: int | None = None) -> Self:
+        losses = _train(
+            model=self.model,
+            features=self.features,
+            targets=self.targets,
+            epochs=epochs,
+            batch_size=batch_size,
+        )
+
+        if len(self.losses) > 0:
+            losses.index += self.losses[-1].index[-1]
+        self.losses.append(losses)
+
+        return self
+
+    def get_loss(self) -> pd.Series:
+        return pd.concat(self.losses)
+
+    def get_surrogate(
+        self,
+        surrogate_args: list[str] | None = None,
+        surrogate_outputs: list[str] | None = None,
+        surrogate_stoichiometries: dict[str, dict[str, float | Derived]] | None = None,
+    ) -> Surrogate:
+        return Surrogate(
+            model=self.model,
+            args=surrogate_args if surrogate_args is not None else [],
+            outputs=surrogate_outputs if surrogate_outputs is not None else [],
+            stoichiometries=surrogate_stoichiometries
+            if surrogate_stoichiometries is not None
+            else {},
+        )
+
+
+def train(
+    features: pd.DataFrame,
+    targets: pd.DataFrame,
+    epochs: int,
+    surrogate_args: list[str] | None = None,
+    surrogate_outputs: list[str] | None = None,
+    surrogate_stoichiometries: dict[str, dict[str, float | Derived]] | None = None,
+    batch_size: int | None = None,
+    model: keras.Model | None = None,
+    optimizer: Optimizer = DefaultOptimizer,
+    loss: LossFn = DefaultLoss,
+) -> tuple[Surrogate, pd.Series]:
+    trainer = Trainer(
+        features=features,
+        targets=targets,
+        model=model,
+        optimizer=optimizer,
+        loss=loss,
+    ).train(
+        epochs=epochs,
+        batch_size=batch_size,
+    )
+    return trainer.get_surrogate(
+        surrogate_args=surrogate_args,
+        surrogate_outputs=surrogate_outputs,
+        surrogate_stoichiometries=surrogate_stoichiometries,
+    ), trainer.get_loss()
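A sketch of the incremental-training workflow the new `Trainer` supports, using the default MLP (two hidden layers of 50 units per the code above); the data is illustrative:

```python
import pandas as pd

from mxlpy import surrogates

features = pd.DataFrame({"s1": [0.0, 0.5, 1.0], "s2": [1.0, 0.5, 0.0]})
targets = pd.DataFrame({"v1": [0.0, 0.25, 1.0]})

trainer = surrogates.keras.Trainer(features=features, targets=targets)
# train() returns self, so rounds can be chained; get_loss() concatenates
# the per-round histories, shifting later indices so the epochs continue.
trainer.train(epochs=50).train(epochs=50)
loss_history = trainer.get_loss()  # pd.Series spanning both rounds

surrogate = trainer.get_surrogate(
    surrogate_args=["s1", "s2"],
    surrogate_outputs=["v1"],
)
```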