modelbase2 0.1.79__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (59)
  1. modelbase2/__init__.py +148 -25
  2. modelbase2/distributions.py +336 -0
  3. modelbase2/experimental/__init__.py +17 -0
  4. modelbase2/experimental/codegen.py +239 -0
  5. modelbase2/experimental/diff.py +227 -0
  6. modelbase2/experimental/notes.md +4 -0
  7. modelbase2/experimental/tex.py +521 -0
  8. modelbase2/fit.py +284 -0
  9. modelbase2/fns.py +185 -0
  10. modelbase2/integrators/__init__.py +19 -0
  11. modelbase2/integrators/int_assimulo.py +146 -0
  12. modelbase2/integrators/int_scipy.py +147 -0
  13. modelbase2/label_map.py +610 -0
  14. modelbase2/linear_label_map.py +301 -0
  15. modelbase2/mc.py +548 -0
  16. modelbase2/mca.py +280 -0
  17. modelbase2/model.py +1621 -0
  18. modelbase2/nnarchitectures.py +128 -0
  19. modelbase2/npe.py +271 -0
  20. modelbase2/parallel.py +171 -0
  21. modelbase2/parameterise.py +28 -0
  22. modelbase2/paths.py +36 -0
  23. modelbase2/plot.py +832 -0
  24. modelbase2/sbml/__init__.py +14 -0
  25. modelbase2/sbml/_data.py +77 -0
  26. modelbase2/sbml/_export.py +656 -0
  27. modelbase2/sbml/_import.py +585 -0
  28. modelbase2/sbml/_mathml.py +691 -0
  29. modelbase2/sbml/_name_conversion.py +52 -0
  30. modelbase2/sbml/_unit_conversion.py +74 -0
  31. modelbase2/scan.py +616 -0
  32. modelbase2/scope.py +96 -0
  33. modelbase2/simulator.py +635 -0
  34. modelbase2/surrogates/__init__.py +31 -0
  35. modelbase2/surrogates/_poly.py +91 -0
  36. modelbase2/surrogates/_torch.py +191 -0
  37. modelbase2/surrogates.py +316 -0
  38. modelbase2/types.py +352 -11
  39. modelbase2-0.3.0.dist-info/METADATA +93 -0
  40. modelbase2-0.3.0.dist-info/RECORD +43 -0
  41. {modelbase2-0.1.79.dist-info → modelbase2-0.3.0.dist-info}/WHEEL +1 -1
  42. modelbase2/core/__init__.py +0 -29
  43. modelbase2/core/algebraic_module_container.py +0 -130
  44. modelbase2/core/constant_container.py +0 -113
  45. modelbase2/core/data.py +0 -109
  46. modelbase2/core/name_container.py +0 -29
  47. modelbase2/core/reaction_container.py +0 -115
  48. modelbase2/core/utils.py +0 -28
  49. modelbase2/core/variable_container.py +0 -24
  50. modelbase2/ode/__init__.py +0 -13
  51. modelbase2/ode/integrator.py +0 -80
  52. modelbase2/ode/mca.py +0 -270
  53. modelbase2/ode/model.py +0 -470
  54. modelbase2/ode/simulator.py +0 -153
  55. modelbase2/utils/__init__.py +0 -0
  56. modelbase2/utils/plotting.py +0 -372
  57. modelbase2-0.1.79.dist-info/METADATA +0 -44
  58. modelbase2-0.1.79.dist-info/RECORD +0 -22
  59. {modelbase2-0.1.79.dist-info → modelbase2-0.3.0.dist-info/licenses}/LICENSE +0 -0
@@ -0,0 +1,91 @@
1
+ from collections.abc import Iterable
2
+ from dataclasses import dataclass
3
+
4
+ import numpy as np
5
+ import pandas as pd
6
+ from numpy import polynomial
7
+ from typing import Union, Literal
8
+
9
+ from modelbase2.types import AbstractSurrogate, ArrayLike
10
+
11
__all__ = ["PolySurrogate", "PolynomialExpansion", "train_polynomial_surrogate"]

# Type alias: union of the numpy polynomial convenience classes that
# `train_polynomial_surrogate` can fit (one per supported `series` name).
PolynomialExpansion = (
    polynomial.polynomial.Polynomial
    | polynomial.chebyshev.Chebyshev
    | polynomial.legendre.Legendre
    | polynomial.laguerre.Laguerre
    | polynomial.hermite.Hermite
    | polynomial.hermite_e.HermiteE
)
22
+
23
+
24
@dataclass(kw_only=True)
class PolySurrogate(AbstractSurrogate):
    """Surrogate backed by a numpy polynomial-series object.

    Attributes:
        model: Series object (power, Chebyshev, Legendre, ...) that is called
            directly on the raw inputs.

    """

    model: PolynomialExpansion

    def predict_raw(self, y: np.ndarray) -> np.ndarray:
        """Evaluate the stored series at ``y``.

        Args:
            y: Points at which the series is evaluated.

        Returns:
            The series evaluated at ``y``.

        """
        series = self.model
        return series(y)
30
+
31
+
32
def train_polynomial_surrogate(
    feature: ArrayLike,
    target: ArrayLike,
    series: Literal[
        "Power", "Chebyshev", "Legendre", "Laguerre", "Hermite", "HermiteE"
    ] = "Power",
    degrees: Iterable[int] = (1, 2, 3, 4, 5, 6, 7),
    surrogate_args: list[str] | None = None,
    surrogate_stoichiometries: dict[str, dict[str, float]] | None = None,
) -> tuple[PolySurrogate, pd.DataFrame]:
    """Train a surrogate model based on a function series expansion.

    One series is least-squares fitted per requested degree. The fit with the
    lowest Akaike information criterion (AIC) is wrapped in a surrogate.

    Args:
        feature: Input data (sample x-coordinates) as a numpy array.
        target: Output data as a numpy array, one value per feature sample.
        series: Base functions for the surrogate model.
        degrees: Degrees of the series to fit to the data. Any iterable is
            accepted (including generators); it is materialised once.
        surrogate_args: Additional arguments for the surrogate model.
        surrogate_stoichiometries: Stoichiometries for the surrogate model.

    Returns:
        A tuple ``(surrogate, fits)``. ``surrogate`` is the selected
        ``PolySurrogate``. ``fits`` is a DataFrame indexed by ``degree`` with
        columns ``models`` (fitted series), ``error`` (root-mean-square error)
        and ``score`` (AIC; lower is better).

    Raises:
        KeyError: If ``series`` is not one of the supported names.
        ValueError: If ``degrees`` is empty.

    """
    feature = np.array(feature, dtype=float)
    target = np.array(target, dtype=float)

    # ``degrees`` is used for fitting, scoring and the result index, so
    # materialise it; a one-shot iterator would be exhausted after the fits.
    degrees = tuple(degrees)
    if not degrees:
        msg = "`degrees` must contain at least one degree"
        raise ValueError(msg)

    # Choose numpy polynomial convenience classes
    series_dictionary = {
        "Power": polynomial.polynomial.Polynomial,
        "Chebyshev": polynomial.chebyshev.Chebyshev,
        "Legendre": polynomial.legendre.Legendre,
        "Laguerre": polynomial.laguerre.Laguerre,
        "Hermite": polynomial.hermite.Hermite,
        "HermiteE": polynomial.hermite_e.HermiteE,
    }
    fn_series = series_dictionary[series]

    models = [fn_series.fit(feature, target, degree) for degree in degrees]
    predictions = np.array([model(feature) for model in models], dtype=float)
    residuals = predictions - target.reshape(1, -1)
    rss = np.sum(np.square(residuals), axis=1)
    n_samples = target.size
    errors = np.sqrt(rss / n_samples)

    # AIC for least squares with Gaussian noise of *estimated* variance:
    #     AIC = n * ln(RSS / n) + 2 * k,   k = degree + 1 fitted coefficients.
    # Assuming unit variance instead (log-likelihood = -RSS / 2) would make
    # the selected degree depend on the units of ``target``. RSS is floored at
    # a relative round-off level so that exact fits tie, and the penalty term
    # then prefers the lowest such degree.
    rss_floor = max(
        n_samples * np.finfo(float).eps * float(np.mean(np.square(target))),
        np.finfo(float).tiny,
    )
    n_coefficients = np.array(degrees) + 1
    score = n_samples * np.log(np.maximum(rss, rss_floor) / n_samples) + (
        2 * n_coefficients
    )

    # Choose the model with the lowest AIC
    model = models[int(np.argmin(score))]
    return (
        PolySurrogate(
            model=model,
            args=surrogate_args if surrogate_args is not None else [],
            stoichiometries=surrogate_stoichiometries
            if surrogate_stoichiometries is not None
            else {},
        ),
        pd.DataFrame(
            {"models": models, "error": errors, "score": score},
            index=pd.Index(np.array(degrees), name="degree"),
        ),
    )
@@ -0,0 +1,191 @@
1
+ from dataclasses import dataclass
2
+
3
+ import numpy as np
4
+ import pandas as pd
5
+ import torch
6
+ import tqdm
7
+ from torch import nn
8
+ from torch.optim.adam import Adam
9
+
10
+ from modelbase2.types import AbstractSurrogate
11
+ from modelbase2.nnarchitectures import MLP, DefaultDevice
12
+
13
# Names exported by ``from <module> import *``.
__all__ = ["TorchSurrogate", "train_torch_surrogate"]
14
+
15
+
16
@dataclass(kw_only=True)
class TorchSurrogate(AbstractSurrogate):
    """Surrogate model using PyTorch.

    Attributes:
        model: PyTorch neural network model.

    Methods:
        predict_raw: Predict outputs based on input data using the PyTorch model.

    """

    model: torch.nn.Module

    def predict_raw(self, y: np.ndarray) -> np.ndarray:
        """Predict outputs based on input data using the PyTorch model.

        The input is sent to the device holding the model's parameters (CPU
        for a parameter-less model). The output is copied back to host memory
        before numpy conversion, so models trained on an accelerator also work.

        Args:
            y: Input data as a numpy array.

        Returns:
            np.ndarray: Raw model output.

        """
        first_param = next(self.model.parameters(), None)
        device = torch.device("cpu") if first_param is None else first_param.device
        with torch.no_grad():
            prediction = self.model(
                torch.tensor(y, dtype=torch.float32, device=device),
            )
        # ``.numpy()`` only works for CPU tensors; ``.cpu()`` is a no-op there.
        return prediction.cpu().numpy()
44
+
45
+
46
def _train_batched(
    aprox: nn.Module,
    features: pd.DataFrame,
    targets: pd.DataFrame,
    epochs: int,
    optimizer: Adam,
    device: torch.device,
    batch_size: int,
) -> pd.Series:
    """Train the neural network using mini-batch gradient descent.

    In every epoch, ``batch_size`` rows are drawn uniformly with replacement,
    and one optimizer step is taken on their mean absolute error.

    Args:
        aprox: Neural network model to train; should already be on ``device``.
        features: Input features, one sample per row.
        targets: Target values, row-aligned (by position) with ``features``.
        epochs: Number of training epochs (one optimizer step each).
        optimizer: Optimizer for training.
        device: Torch device the training data is moved to.
        batch_size: Size of mini-batches for training.

    Returns:
        pd.Series: Training loss per epoch, indexed by epoch number.

    """
    # Convert once instead of copying the DataFrames every epoch. Unlike the
    # legacy ``torch.Tensor(data, device=...)`` constructor (CPU only),
    # ``torch.tensor`` accepts any device. The default dtype matches what the
    # legacy constructor produced and what freshly built modules use.
    dtype = torch.get_default_dtype()
    x_all = torch.tensor(features.to_numpy(), dtype=dtype, device=device)
    y_all = torch.tensor(targets.to_numpy(), dtype=dtype, device=device)
    n_rows = len(features)

    rng = np.random.default_rng()
    losses = {}
    for i in tqdm.trange(epochs):
        # Draw row *positions*: the DataFrame index may hold arbitrary labels,
        # which must not be used for positional selection.
        idxs = torch.tensor(rng.choice(n_rows, size=batch_size), device=device)
        optimizer.zero_grad()
        loss = torch.mean(torch.abs(aprox(x_all[idxs]) - y_all[idxs]))
        loss.backward()
        optimizer.step()
        # ``.item()`` also works for losses living on an accelerator.
        losses[i] = loss.item()
    return pd.Series(losses, dtype=float)
82
+
83
+
84
def _train_full(
    aprox: nn.Module,
    features: pd.DataFrame,
    targets: pd.DataFrame,
    epochs: int,
    optimizer: Adam,
    device: torch.device,
) -> pd.Series:
    """Train the neural network using full-batch gradient descent.

    Each epoch takes one optimizer step on the mean absolute error over all
    rows.

    Args:
        aprox: Neural network model to train; should already be on ``device``.
        features: Input features, one sample per row.
        targets: Target values, row-aligned with ``features``.
        epochs: Number of training epochs (one optimizer step each).
        optimizer: Optimizer for training.
        device: Torch device the training data is moved to.

    Returns:
        pd.Series: Training loss per epoch, indexed by epoch number.

    """
    # Unlike the legacy ``torch.Tensor(data, device=...)`` constructor (CPU
    # only), ``torch.tensor`` accepts any device. The default dtype matches
    # what the legacy constructor produced and what freshly built modules use.
    dtype = torch.get_default_dtype()
    x = torch.tensor(features.to_numpy(), dtype=dtype, device=device)
    y = torch.tensor(targets.to_numpy(), dtype=dtype, device=device)

    losses = {}
    for i in tqdm.trange(epochs):
        optimizer.zero_grad()
        loss = torch.mean(torch.abs(aprox(x) - y))
        loss.backward()
        optimizer.step()
        # ``.item()`` also works for losses living on an accelerator.
        losses[i] = loss.item()
    return pd.Series(losses, dtype=float)
117
+
118
+
119
def train_torch_surrogate(
    features: pd.DataFrame,
    targets: pd.DataFrame,
    epochs: int,
    surrogate_args: list[str] | None = None,
    surrogate_stoichiometries: dict[str, dict[str, float]] | None = None,
    batch_size: int | None = None,
    approximator: nn.Module | None = None,
    optimimzer_cls: type[Adam] = Adam,
    device: torch.device = DefaultDevice,
) -> tuple[TorchSurrogate, pd.Series]:
    """Train a PyTorch surrogate model.

    Examples:
        >>> train_torch_surrogate(
        ...     features,
        ...     targets,
        ...     epochs=100,
        ...     surrogate_args=["x1", "x2"],
        ...     surrogate_stoichiometries={
        ...         "v1": {"x1": -1, "x2": 1, "ATP": -1},
        ...     },
        ... )

    Args:
        features: DataFrame containing the input features for training.
        targets: DataFrame containing the target values for training, with
            the same number of rows as ``features``.
        epochs: Number of training epochs.
        surrogate_args: List of input variable names for the surrogate model.
        surrogate_stoichiometries: Dictionary mapping reaction names to stoichiometries.
        batch_size: Size of mini-batches for training (None for full-batch).
        approximator: Predefined neural network model (None to use default MLP features-50-50-output).
        optimimzer_cls: Optimizer class to use for training (default: Adam).
            Note that the parameter name is spelled ``optimimzer_cls``.
        device: Device to run the training on (default: DefaultDevice). The
            approximator, whether default or user-supplied, is moved there.

    Returns:
        tuple[TorchSurrogate, pd.Series]: Trained surrogate model and loss history.

    Raises:
        ValueError: If ``features`` and ``targets`` differ in row count.

    """
    # Without this check a one-row side would silently broadcast in the loss.
    if len(features) != len(targets):
        msg = (
            "features and targets must have the same number of rows, "
            f"got {len(features)} and {len(targets)}"
        )
        raise ValueError(msg)

    if approximator is None:
        approximator = MLP(
            n_inputs=len(features.columns),
            layers=[50, 50, len(targets.columns)],
        )
    # Model and training data must share a device, also for user-supplied models.
    approximator = approximator.to(device)

    optimizer = optimimzer_cls(approximator.parameters())
    if batch_size is None:
        losses = _train_full(
            aprox=approximator,
            features=features,
            targets=targets,
            epochs=epochs,
            optimizer=optimizer,
            device=device,
        )
    else:
        losses = _train_batched(
            aprox=approximator,
            features=features,
            targets=targets,
            epochs=epochs,
            optimizer=optimizer,
            device=device,
            batch_size=batch_size,
        )
    surrogate = TorchSurrogate(
        model=approximator,
        args=surrogate_args if surrogate_args is not None else [],
        stoichiometries=surrogate_stoichiometries
        if surrogate_stoichiometries is not None
        else {},
    )
    return surrogate, losses
@@ -0,0 +1,316 @@
1
"""Surrogate Models Module.

This module provides classes and functions for creating and training neural-network
surrogate models for metabolic simulations. Surrogates are trained on tabular
features/targets, such as steady-state or time-series simulation results.

Classes:
    AbstractSurrogate: Abstract base class for surrogate models.
    MockSurrogate: Pass-through surrogate model for testing purposes.
    TorchSurrogate: Surrogate model using PyTorch.
    Approximator: Neural network approximator for surrogate modeling.

Functions:
    train_torch_surrogate: Train a PyTorch surrogate model.
"""
16
+
17
from __future__ import annotations

from abc import ABC, abstractmethod
from dataclasses import dataclass
from pathlib import Path

import numpy as np
import pandas as pd
import torch
import tqdm
from torch import nn
from torch.optim.adam import Adam

from modelbase2.parallel import Cache
31
+
32
# Names exported by ``from <module> import *``.
__all__ = [
    "AbstractSurrogate",
    "Approximator",
    "DefaultCache",
    "DefaultDevice",
    "MockSurrogate",
    "TorchSurrogate",
    "train_torch_surrogate",
]


# Default `device` argument of `train_torch_surrogate`.
DefaultDevice = torch.device("cpu")
# Module-level cache, constructed at import time with the relative path
# ".cache" (presumably resolved against the current working directory;
# confirm in `modelbase2.parallel.Cache`).
DefaultCache = Cache(Path(".cache"))
45
+
46
+
47
+ @dataclass(kw_only=True)
48
+ class AbstractSurrogate:
49
+ """Abstract base class for surrogate models.
50
+
51
+ Attributes:
52
+ inputs: List of input variable names.
53
+ stoichiometries: Dictionary mapping reaction names to stoichiometries.
54
+
55
+ Methods:
56
+ predict: Abstract method to predict outputs based on input data.
57
+
58
+ """
59
+
60
+ inputs: list[str]
61
+ stoichiometries: dict[str, dict[str, float]]
62
+
63
+ @abstractmethod
64
+ def predict(self, y: np.ndarray) -> dict[str, float]:
65
+ """Predict outputs based on input data."""
66
+
67
+
68
@dataclass(kw_only=True)
class MockSurrogate(AbstractSurrogate):
    """Mock surrogate model for testing purposes.

    Echoes its input. The i-th entry of ``y`` is reported as the output of the
    i-th key of ``stoichiometries``.
    """

    def predict(
        self,
        y: np.ndarray,
    ) -> dict[str, float]:
        """Predict outputs based on input data.

        Raises:
            ValueError: If ``y`` and ``stoichiometries`` differ in length.
        """
        reaction_names = list(self.stoichiometries)
        return dict(zip(reaction_names, y, strict=True))
78
+
79
+
80
@dataclass(kw_only=True)
class TorchSurrogate(AbstractSurrogate):
    """Surrogate model using PyTorch.

    Attributes:
        model: PyTorch neural network model.

    Methods:
        predict: Predict outputs based on input data using the PyTorch model.

    """

    model: torch.nn.Module

    def predict(self, y: np.ndarray) -> dict[str, float]:
        """Predict outputs based on input data using the PyTorch model.

        The input is sent to the device holding the model's parameters (CPU
        for a parameter-less model). The output is copied back to host memory
        before numpy conversion, so models trained on an accelerator also work.

        Args:
            y: Input data as a numpy array.

        Returns:
            dict[str, float]: The keys of ``stoichiometries`` (in order) paired
            with the corresponding model outputs.

        Raises:
            ValueError: If the number of model outputs differs from the number
                of ``stoichiometries`` entries.

        """
        first_param = next(self.model.parameters(), None)
        device = torch.device("cpu") if first_param is None else first_param.device
        with torch.no_grad():
            prediction = self.model(
                torch.tensor(y, dtype=torch.float32, device=device),
            )
        # ``.numpy()`` only works for CPU tensors; ``.cpu()`` is a no-op there.
        return dict(
            zip(
                self.stoichiometries,
                prediction.cpu().numpy(),
                strict=True,
            )
        )
114
+
115
+
116
class Approximator(nn.Module):
    """Fully connected network used as the default surrogate approximator.

    Attributes:
        net: ``nn.Sequential`` stack
            ``Linear(n_inputs, 50) -> ReLU -> Linear(50, 50) -> ReLU -> Linear(50, n_outputs)``.

    Methods:
        forward: Forward pass through the neural network.

    """

    def __init__(self, n_inputs: int, n_outputs: int) -> None:
        """Build the network and re-initialise its linear layers.

        Args:
            n_inputs (int): The number of input features.
            n_outputs (int): The number of output features.

        Every linear layer gets weights drawn from a normal distribution with
        mean 0 and standard deviation 0.1, and zero biases.

        """
        super().__init__()

        hidden = 50
        self.net = nn.Sequential(
            nn.Linear(n_inputs, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, n_outputs),
        )

        # Re-initialise in layer order so random-number consumption is stable.
        linear_layers = (layer for layer in self.net if isinstance(layer, nn.Linear))
        for layer in linear_layers:
            nn.init.normal_(layer.weight, mean=0, std=0.1)
            nn.init.constant_(layer.bias, val=0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the network.

        Args:
            x: Input tensor whose last dimension has ``n_inputs`` entries.

        Returns:
            torch.Tensor: Network output.

        """
        output = self.net(x)
        return output
171
+
172
+
173
def _train_batched(
    aprox: nn.Module,
    features: pd.DataFrame,
    targets: pd.DataFrame,
    epochs: int,
    optimizer: Adam,
    device: torch.device,
    batch_size: int,
) -> pd.Series:
    """Train the neural network using mini-batch gradient descent.

    In every epoch, ``batch_size`` rows are drawn uniformly with replacement,
    and one optimizer step is taken on their mean absolute error.

    Args:
        aprox: Neural network model to train; should already be on ``device``.
        features: Input features, one sample per row.
        targets: Target values, row-aligned (by position) with ``features``.
        epochs: Number of training epochs (one optimizer step each).
        optimizer: Optimizer for training.
        device: Torch device the training data is moved to.
        batch_size: Size of mini-batches for training.

    Returns:
        pd.Series: Training loss per epoch, indexed by epoch number.

    """
    # Convert once instead of copying the DataFrames every epoch. Unlike the
    # legacy ``torch.Tensor(data, device=...)`` constructor (CPU only),
    # ``torch.tensor`` accepts any device. The default dtype matches what the
    # legacy constructor produced and what freshly built modules use.
    dtype = torch.get_default_dtype()
    x_all = torch.tensor(features.to_numpy(), dtype=dtype, device=device)
    y_all = torch.tensor(targets.to_numpy(), dtype=dtype, device=device)
    n_rows = len(features)

    rng = np.random.default_rng()
    losses = {}
    for i in tqdm.trange(epochs):
        # Draw row *positions*: the DataFrame index may hold arbitrary labels,
        # which must not be used for positional selection.
        idxs = torch.tensor(rng.choice(n_rows, size=batch_size), device=device)
        optimizer.zero_grad()
        loss = torch.mean(torch.abs(aprox(x_all[idxs]) - y_all[idxs]))
        loss.backward()
        optimizer.step()
        # ``.item()`` also works for losses living on an accelerator.
        losses[i] = loss.item()
    return pd.Series(losses, dtype=float)
209
+
210
+
211
def _train_full(
    aprox: nn.Module,
    features: pd.DataFrame,
    targets: pd.DataFrame,
    epochs: int,
    optimizer: Adam,
    device: torch.device,
) -> pd.Series:
    """Train the neural network using full-batch gradient descent.

    Each epoch takes one optimizer step on the mean absolute error over all
    rows.

    Args:
        aprox: Neural network model to train; should already be on ``device``.
        features: Input features, one sample per row.
        targets: Target values, row-aligned with ``features``.
        epochs: Number of training epochs (one optimizer step each).
        optimizer: Optimizer for training.
        device: Torch device the training data is moved to.

    Returns:
        pd.Series: Training loss per epoch, indexed by epoch number.

    """
    # Unlike the legacy ``torch.Tensor(data, device=...)`` constructor (CPU
    # only), ``torch.tensor`` accepts any device. The default dtype matches
    # what the legacy constructor produced and what freshly built modules use.
    dtype = torch.get_default_dtype()
    x = torch.tensor(features.to_numpy(), dtype=dtype, device=device)
    y = torch.tensor(targets.to_numpy(), dtype=dtype, device=device)

    losses = {}
    for i in tqdm.trange(epochs):
        optimizer.zero_grad()
        loss = torch.mean(torch.abs(aprox(x) - y))
        loss.backward()
        optimizer.step()
        # ``.item()`` also works for losses living on an accelerator.
        losses[i] = loss.item()
    return pd.Series(losses, dtype=float)
244
+
245
+
246
def train_torch_surrogate(
    features: pd.DataFrame,
    targets: pd.DataFrame,
    epochs: int,
    surrogate_inputs: list[str],
    surrogate_stoichiometries: dict[str, dict[str, float]],
    batch_size: int | None = None,
    approximator: nn.Module | None = None,
    optimimzer_cls: type[Adam] = Adam,
    device: torch.device = DefaultDevice,
) -> tuple[TorchSurrogate, pd.Series]:
    """Train a PyTorch surrogate model.

    Examples:
        >>> train_torch_surrogate(
        ...     features,
        ...     targets,
        ...     epochs=100,
        ...     surrogate_inputs=["x1", "x2"],
        ...     surrogate_stoichiometries={
        ...         "v1": {"x1": -1, "x2": 1, "ATP": -1},
        ...     },
        ... )

    Args:
        features: DataFrame containing the input features for training.
        targets: DataFrame containing the target values for training, with
            the same number of rows as ``features``.
        epochs: Number of training epochs.
        surrogate_inputs: List of input variable names for the surrogate model.
        surrogate_stoichiometries: Dictionary mapping reaction names to stoichiometries.
        batch_size: Size of mini-batches for training (None for full-batch).
        approximator: Predefined neural network model (None to use default
            ``Approximator``).
        optimimzer_cls: Optimizer class to use for training (default: Adam).
            Note that the parameter name is spelled ``optimimzer_cls``.
        device: Device to run the training on (default: DefaultDevice). The
            approximator, whether default or user-supplied, is moved there.

    Returns:
        tuple[TorchSurrogate, pd.Series]: Trained surrogate model and loss history.

    Raises:
        ValueError: If ``features`` and ``targets`` differ in row count.

    """
    # Without this check a one-row side would silently broadcast in the loss.
    if len(features) != len(targets):
        msg = (
            "features and targets must have the same number of rows, "
            f"got {len(features)} and {len(targets)}"
        )
        raise ValueError(msg)

    if approximator is None:
        approximator = Approximator(
            n_inputs=len(features.columns),
            n_outputs=len(targets.columns),
        )
    # Model and training data must share a device, also for user-supplied models.
    approximator = approximator.to(device)

    optimizer = optimimzer_cls(approximator.parameters())
    if batch_size is None:
        losses = _train_full(
            aprox=approximator,
            features=features,
            targets=targets,
            epochs=epochs,
            optimizer=optimizer,
            device=device,
        )
    else:
        losses = _train_batched(
            aprox=approximator,
            features=features,
            targets=targets,
            epochs=epochs,
            optimizer=optimizer,
            device=device,
            batch_size=batch_size,
        )
    surrogate = TorchSurrogate(
        model=approximator,
        inputs=surrogate_inputs,
        stoichiometries=surrogate_stoichiometries,
    )
    return surrogate, losses