modelbase2 0.1.78__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (58)
  1. modelbase2/__init__.py +138 -26
  2. modelbase2/distributions.py +306 -0
  3. modelbase2/experimental/__init__.py +17 -0
  4. modelbase2/experimental/codegen.py +239 -0
  5. modelbase2/experimental/diff.py +227 -0
  6. modelbase2/experimental/notes.md +4 -0
  7. modelbase2/experimental/tex.py +521 -0
  8. modelbase2/fit.py +284 -0
  9. modelbase2/fns.py +185 -0
  10. modelbase2/integrators/__init__.py +19 -0
  11. modelbase2/integrators/int_assimulo.py +146 -0
  12. modelbase2/integrators/int_scipy.py +147 -0
  13. modelbase2/label_map.py +610 -0
  14. modelbase2/linear_label_map.py +301 -0
  15. modelbase2/mc.py +548 -0
  16. modelbase2/mca.py +280 -0
  17. modelbase2/model.py +1621 -0
  18. modelbase2/npe.py +343 -0
  19. modelbase2/parallel.py +171 -0
  20. modelbase2/parameterise.py +28 -0
  21. modelbase2/paths.py +36 -0
  22. modelbase2/plot.py +829 -0
  23. modelbase2/sbml/__init__.py +14 -0
  24. modelbase2/sbml/_data.py +77 -0
  25. modelbase2/sbml/_export.py +656 -0
  26. modelbase2/sbml/_import.py +585 -0
  27. modelbase2/sbml/_mathml.py +691 -0
  28. modelbase2/sbml/_name_conversion.py +52 -0
  29. modelbase2/sbml/_unit_conversion.py +74 -0
  30. modelbase2/scan.py +616 -0
  31. modelbase2/scope.py +96 -0
  32. modelbase2/simulator.py +635 -0
  33. modelbase2/surrogates/__init__.py +32 -0
  34. modelbase2/surrogates/_poly.py +66 -0
  35. modelbase2/surrogates/_torch.py +249 -0
  36. modelbase2/surrogates.py +316 -0
  37. modelbase2/types.py +352 -11
  38. modelbase2-0.2.0.dist-info/METADATA +81 -0
  39. modelbase2-0.2.0.dist-info/RECORD +42 -0
  40. {modelbase2-0.1.78.dist-info → modelbase2-0.2.0.dist-info}/WHEEL +1 -1
  41. modelbase2/core/__init__.py +0 -29
  42. modelbase2/core/algebraic_module_container.py +0 -130
  43. modelbase2/core/constant_container.py +0 -113
  44. modelbase2/core/data.py +0 -109
  45. modelbase2/core/name_container.py +0 -29
  46. modelbase2/core/reaction_container.py +0 -115
  47. modelbase2/core/utils.py +0 -28
  48. modelbase2/core/variable_container.py +0 -24
  49. modelbase2/ode/__init__.py +0 -13
  50. modelbase2/ode/integrator.py +0 -80
  51. modelbase2/ode/mca.py +0 -270
  52. modelbase2/ode/model.py +0 -470
  53. modelbase2/ode/simulator.py +0 -153
  54. modelbase2/utils/__init__.py +0 -0
  55. modelbase2/utils/plotting.py +0 -372
  56. modelbase2-0.1.78.dist-info/METADATA +0 -44
  57. modelbase2-0.1.78.dist-info/RECORD +0 -22
  58. {modelbase2-0.1.78.dist-info → modelbase2-0.2.0.dist-info/licenses}/LICENSE +0 -0
@@ -0,0 +1,66 @@
1
+ from collections.abc import Iterable
2
+ from dataclasses import dataclass
3
+
4
+ import numpy as np
5
+ import pandas as pd
6
+ from numpy.polynomial.polynomial import Polynomial
7
+
8
+ from modelbase2.types import AbstractSurrogate, ArrayLike
9
+
10
# Names exported on star-import from this module.
__all__ = ["PolySurrogate", "train_polynomial_surrogate"]
11
+
12
+
13
@dataclass(kw_only=True)
class PolySurrogate(AbstractSurrogate):
    """Surrogate backed by a fitted ``numpy.polynomial.Polynomial``.

    Attributes:
        model: The polynomial that ``predict_raw`` evaluates.

    """

    model: Polynomial

    def predict_raw(self, y: np.ndarray) -> np.ndarray:
        """Evaluate the stored polynomial at ``y``.

        Args:
            y: Points at which the polynomial is evaluated.

        Returns:
            The polynomial values at ``y``.

        """
        polynomial = self.model
        return polynomial(y)
19
+
20
+
21
def train_polynomial_surrogate(
    feature: ArrayLike,
    target: ArrayLike,
    degrees: Iterable[int] = (1, 2, 3, 4, 5, 6, 7),
    surrogate_args: list[str] | None = None,
    surrogate_stoichiometries: dict[str, dict[str, float]] | None = None,
) -> tuple[PolySurrogate, pd.DataFrame]:
    """Train a polynomial surrogate model.

    One polynomial is fitted per entry of ``degrees`` using
    ``Polynomial.fit``; the candidate with the lowest score is wrapped in a
    ``PolySurrogate``.

    Args:
        feature: Input data (converted to a float numpy array).
        target: Output data (converted to a float numpy array).
        degrees: Degrees of the polynomials to fit to the data. Any iterable
            is accepted (including generators); it is materialised once.
        surrogate_args: Additional arguments for the surrogate model
            (empty list if None).
        surrogate_stoichiometries: Stoichiometries for the surrogate model
            (empty dict if None).

    Returns:
        tuple[PolySurrogate, pd.DataFrame]: The selected surrogate, and a
        DataFrame indexed by ``degree`` with the fitted ``models``, their
        root-mean-square ``error`` and the selection ``score``.

    Raises:
        ValueError: If ``degrees`` is empty.

    """
    # Materialise once: degrees is used for fitting AND for the penalty term
    # and the result index; a generator would be exhausted after the fit.
    degree_list = tuple(degrees)
    if not degree_list:
        msg = "degrees must contain at least one polynomial degree"
        raise ValueError(msg)

    feature = np.array(feature, dtype=float)
    target = np.array(target, dtype=float)

    models = [Polynomial.fit(feature, target, degree) for degree in degree_list]
    predictions = np.array([model(feature) for model in models], dtype=float)
    # One row of squared residuals per candidate model.
    squared_residuals = np.square(predictions - target.reshape(1, -1))
    errors = np.sqrt(np.mean(squared_residuals, axis=1))
    # NOTE(review): Gaussian log-likelihood assuming unit noise variance, so
    # the score equals 2 * degree + SSE and depends on the scale of `target`.
    log_likelihood = -0.5 * np.sum(squared_residuals, axis=1)
    score = 2 * np.array(degree_list) - 2 * log_likelihood

    # Choose the model with the lowest AIC-style score.
    model = models[np.argmin(score)]
    return (
        PolySurrogate(
            model=model,
            args=surrogate_args if surrogate_args is not None else [],
            stoichiometries=surrogate_stoichiometries
            if surrogate_stoichiometries is not None
            else {},
        ),
        pd.DataFrame(
            {"models": models, "error": errors, "score": score},
            index=pd.Index(np.array(degree_list), name="degree"),
        ),
    )
@@ -0,0 +1,249 @@
1
+ from dataclasses import dataclass
2
+
3
+ import numpy as np
4
+ import pandas as pd
5
+ import torch
6
+ import tqdm
7
+ from torch import nn
8
+ from torch.optim.adam import Adam
9
+
10
+ from modelbase2.types import AbstractSurrogate
11
+
12
# Names exported on star-import from this module.
__all__ = ["DefaultDevice", "Dense", "TorchSurrogate", "train_torch_surrogate"]

# CPU device; the default `device` argument of `train_torch_surrogate`.
DefaultDevice = torch.device("cpu")
15
+
16
+
17
@dataclass(kw_only=True)
class TorchSurrogate(AbstractSurrogate):
    """Surrogate model using PyTorch.

    Attributes:
        model: PyTorch neural network model.

    Methods:
        predict_raw: Evaluate the PyTorch model on raw numpy input.

    """

    model: torch.nn.Module

    def predict_raw(self, y: np.ndarray) -> np.ndarray:
        """Predict outputs based on input data using the PyTorch model.

        The input is converted to a float32 tensor on the device of the
        model's first parameter (CPU if the model has no parameters). It is
        evaluated without gradient tracking, and the result is copied back to
        host memory.

        Args:
            y: Input data as a numpy array.

        Returns:
            np.ndarray: Raw model output.

        """
        first_param = next(self.model.parameters(), None)
        device = (
            first_param.device if first_param is not None else torch.device("cpu")
        )
        with torch.no_grad():
            output = self.model(torch.tensor(y, dtype=torch.float32, device=device))
        # `.cpu()` is a no-op for CPU tensors and required before `.numpy()`
        # for tensors living on an accelerator.
        return output.cpu().numpy()
45
+
46
+
47
class Dense(nn.Module):
    """Fully connected feed-forward network used as the default surrogate.

    Attributes:
        net: ``nn.Sequential`` container holding all layers.

    Methods:
        forward: Forward pass through the neural network.

    """

    def __init__(self, n_inputs: int, n_outputs: int) -> None:
        """Build the network for ``n_inputs`` features and ``n_outputs`` targets.

        Layer layout (positions 0-4 inside ``self.net``)::

            Linear(n_inputs, 50) -> ReLU -> Linear(50, 50) -> ReLU
            -> Linear(50, n_outputs)

        Every linear layer is re-initialised with weights drawn from a normal
        distribution (mean 0, std 0.1) and biases set to 0.

        Args:
            n_inputs (int): The number of input features.
            n_outputs (int): The number of output features.

        """
        super().__init__()

        width = 50
        layers = [
            nn.Linear(n_inputs, width),
            nn.ReLU(),
            nn.Linear(width, width),
            nn.ReLU(),
            nn.Linear(width, n_outputs),
        ]
        self.net = nn.Sequential(*layers)

        linear_layers = (
            layer for layer in self.net.modules() if isinstance(layer, nn.Linear)
        )
        for linear in linear_layers:
            nn.init.normal_(linear.weight, mean=0, std=0.1)
            nn.init.constant_(linear.bias, val=0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Propagate ``x`` through ``self.net``.

        Args:
            x: Input tensor.

        Returns:
            torch.Tensor: Output tensor.

        """
        output = self.net(x)
        return output
102
+
103
+
104
def _train_batched(
    aprox: nn.Module,
    features: pd.DataFrame,
    targets: pd.DataFrame,
    epochs: int,
    optimizer: Adam,
    device: torch.device,
    batch_size: int,
) -> pd.Series:
    """Train the neural network using mini-batch gradient descent.

    Each epoch draws ``batch_size`` rows with replacement by *position*, so
    ``features`` and ``targets`` must be row-aligned but may carry any index
    (e.g. time points or a filtered, non-contiguous index). The loss is the
    mean absolute error.

    Args:
        aprox: Neural network model to train.
        features: Input features, one row per sample.
        targets: Target values, row-aligned with ``features``.
        epochs: Number of training epochs.
        optimizer: Optimizer for training.
        device: Torch device on which batches are created.
        batch_size: Size of mini-batches for training.

    Returns:
        pd.Series: Series containing the training loss history.

    """
    # torch.tensor (unlike the legacy torch.Tensor constructor) accepts any
    # device; use the default dtype to match the legacy constructor's output.
    dtype = torch.get_default_dtype()
    # Convert once instead of once per epoch.
    feature_values = features.to_numpy()
    target_values = targets.to_numpy()
    n_rows = len(feature_values)

    rng = np.random.default_rng()
    losses: dict[int, float] = {}
    for i in tqdm.trange(epochs):
        # Positions, not index labels: `.iloc`-style access must be positional.
        idxs = rng.integers(0, n_rows, size=batch_size)
        X = torch.tensor(feature_values[idxs], dtype=dtype, device=device)
        Y = torch.tensor(target_values[idxs], dtype=dtype, device=device)
        optimizer.zero_grad()
        loss = torch.mean(torch.abs(aprox(X) - Y))
        loss.backward()
        optimizer.step()
        # `.item()` works on every device; `.numpy()` fails for CUDA tensors.
        losses[i] = loss.item()
    return pd.Series(losses, dtype=float)
140
+
141
+
142
def _train_full(
    aprox: nn.Module,
    features: pd.DataFrame,
    targets: pd.DataFrame,
    epochs: int,
    optimizer: Adam,
    device: torch.device,
) -> pd.Series:
    """Train the neural network using full-batch gradient descent.

    Every epoch uses all rows; the loss is the mean absolute error.

    Args:
        aprox: Neural network model to train.
        features: Input features, one row per sample.
        targets: Target values, row-aligned with ``features``.
        epochs: Number of training epochs.
        optimizer: Optimizer for training.
        device: Torch device on which the training tensors are created.

    Returns:
        pd.Series: Series containing the training loss history.

    """
    # torch.tensor (unlike the legacy torch.Tensor constructor) accepts any
    # device; use the default dtype to match the legacy constructor's output.
    dtype = torch.get_default_dtype()
    X = torch.tensor(features.to_numpy(), dtype=dtype, device=device)
    Y = torch.tensor(targets.to_numpy(), dtype=dtype, device=device)

    losses: dict[int, float] = {}
    for i in tqdm.trange(epochs):
        optimizer.zero_grad()
        loss = torch.mean(torch.abs(aprox(X) - Y))
        loss.backward()
        optimizer.step()
        # `.item()` works on every device; `.numpy()` fails for CUDA tensors.
        losses[i] = loss.item()
    return pd.Series(losses, dtype=float)
175
+
176
+
177
def train_torch_surrogate(
    features: pd.DataFrame,
    targets: pd.DataFrame,
    epochs: int,
    surrogate_args: list[str] | None = None,
    surrogate_stoichiometries: dict[str, dict[str, float]] | None = None,
    batch_size: int | None = None,
    approximator: nn.Module | None = None,
    optimimzer_cls: type[Adam] = Adam,
    device: torch.device = DefaultDevice,
) -> tuple[TorchSurrogate, pd.Series]:
    """Train a PyTorch surrogate model.

    Examples:
        >>> train_torch_surrogate(
        ...     features,
        ...     targets,
        ...     epochs=100,
        ...     surrogate_args=["x1", "x2"],
        ...     surrogate_stoichiometries={
        ...         "v1": {"x1": -1, "x2": 1, "ATP": -1},
        ...     },
        ... )

    Args:
        features: DataFrame containing the input features for training.
        targets: DataFrame containing the target values for training.
        epochs: Number of training epochs.
        surrogate_args: List of input variable names for the surrogate model.
        surrogate_stoichiometries: Dictionary mapping reaction names to stoichiometries.
        batch_size: Size of mini-batches for training (None for full-batch).
        approximator: Predefined neural network model (None to use default).
            It is moved to ``device`` in place before training.
        optimimzer_cls: Optimizer class to use for training (default: Adam).
        device: Device to run the training on (default: DefaultDevice).

    Returns:
        tuple[TorchSurrogate, pd.Series]: Trained surrogate model and loss history.

    """
    if approximator is None:
        approximator = Dense(
            n_inputs=len(features.columns),
            n_outputs=len(targets.columns),
        )
    # The training tensors are created on `device`, so the model has to live
    # there too -- including a user-supplied approximator.
    approximator = approximator.to(device)

    optimizer = optimimzer_cls(approximator.parameters())
    if batch_size is None:
        losses = _train_full(
            aprox=approximator,
            features=features,
            targets=targets,
            epochs=epochs,
            optimizer=optimizer,
            device=device,
        )
    else:
        losses = _train_batched(
            aprox=approximator,
            features=features,
            targets=targets,
            epochs=epochs,
            optimizer=optimizer,
            device=device,
            batch_size=batch_size,
        )
    surrogate = TorchSurrogate(
        model=approximator,
        args=surrogate_args if surrogate_args is not None else [],
        stoichiometries=surrogate_stoichiometries
        if surrogate_stoichiometries is not None
        else {},
    )
    return surrogate, losses
@@ -0,0 +1,316 @@
1
+ """Surrogate Models Module.
2
+
3
+ This module provides classes and functions for creating and training surrogate models
4
+ for metabolic simulations. It includes functionality for both steady-state and time-series
5
+ data using neural networks.
6
+
7
+ Classes:
8
+ AbstractSurrogate: Abstract base class for surrogate models.
9
+ TorchSurrogate: Surrogate model using PyTorch.
10
+ Approximator: Neural network approximator for surrogate modeling.
11
+
12
+ Functions:
13
+ train_torch_surrogate: Train a PyTorch surrogate model.
14
+ train_torch_time_course_estimator: Train a PyTorch time course estimator.
15
+ """
16
+
17
from __future__ import annotations

from abc import ABC, abstractmethod
from dataclasses import dataclass
from pathlib import Path

import numpy as np
import pandas as pd
import torch
import tqdm
from torch import nn
from torch.optim.adam import Adam

from modelbase2.parallel import Cache
31
+
32
# Names exported on star-import from this module.
__all__ = [
    "AbstractSurrogate",
    "Approximator",
    "DefaultCache",
    "DefaultDevice",
    "MockSurrogate",
    "TorchSurrogate",
    "train_torch_surrogate",
]


# CPU device; the default `device` argument of `train_torch_surrogate`.
DefaultDevice = torch.device("cpu")
# Cache constructed with the relative path ".cache".
# NOTE(review): not referenced by any other code in this module -- confirm it is still needed.
DefaultCache = Cache(Path(".cache"))
45
+
46
+
47
+ @dataclass(kw_only=True)
48
+ class AbstractSurrogate:
49
+ """Abstract base class for surrogate models.
50
+
51
+ Attributes:
52
+ inputs: List of input variable names.
53
+ stoichiometries: Dictionary mapping reaction names to stoichiometries.
54
+
55
+ Methods:
56
+ predict: Abstract method to predict outputs based on input data.
57
+
58
+ """
59
+
60
+ inputs: list[str]
61
+ stoichiometries: dict[str, dict[str, float]]
62
+
63
+ @abstractmethod
64
+ def predict(self, y: np.ndarray) -> dict[str, float]:
65
+ """Predict outputs based on input data."""
66
+
67
+
68
@dataclass(kw_only=True)
class MockSurrogate(AbstractSurrogate):
    """Mock surrogate model for testing purposes.

    ``predict`` echoes its input. The i-th value of ``y`` is reported for the
    i-th reaction name in ``stoichiometries``.
    """

    def predict(
        self,
        y: np.ndarray,
    ) -> dict[str, float]:
        """Pair each reaction name with the input value at the same position.

        Raises:
            ValueError: If ``y`` and ``stoichiometries`` differ in length.
        """
        reaction_names = list(self.stoichiometries)
        return {
            name: value for name, value in zip(reaction_names, y, strict=True)
        }
78
+
79
+
80
@dataclass(kw_only=True)
class TorchSurrogate(AbstractSurrogate):
    """Surrogate model using PyTorch.

    Attributes:
        model: PyTorch neural network model.

    Methods:
        predict: Predict outputs based on input data using the PyTorch model.

    """

    model: torch.nn.Module

    def predict(self, y: np.ndarray) -> dict[str, float]:
        """Predict outputs based on input data using the PyTorch model.

        The input is converted to a float32 tensor on the device of the
        model's first parameter (CPU if the model has none). It is evaluated
        without gradient tracking, and the result is copied back to host
        memory.

        Args:
            y: Input data as a numpy array.

        Returns:
            dict[str, float]: Dictionary mapping output variable names to predicted values.

        Raises:
            ValueError: If the number of model outputs differs from the
                number of reactions in ``stoichiometries``.

        """
        first_param = next(self.model.parameters(), None)
        device = (
            first_param.device if first_param is not None else torch.device("cpu")
        )
        with torch.no_grad():
            output = self.model(torch.tensor(y, dtype=torch.float32, device=device))
        # `.cpu()` is a no-op for CPU tensors and required before `.numpy()`
        # for tensors living on an accelerator.
        return dict(zip(self.stoichiometries, output.cpu().numpy(), strict=True))
114
+
115
+
116
class Approximator(nn.Module):
    """Fully connected feed-forward network used for surrogate modeling.

    Attributes:
        net: ``nn.Sequential`` container holding all layers.

    Methods:
        forward: Forward pass through the neural network.

    """

    def __init__(self, n_inputs: int, n_outputs: int) -> None:
        """Build the network for ``n_inputs`` features and ``n_outputs`` targets.

        Layer layout (positions 0-4 inside ``self.net``)::

            Linear(n_inputs, 50) -> ReLU -> Linear(50, 50) -> ReLU
            -> Linear(50, n_outputs)

        Every linear layer is re-initialised with weights drawn from a normal
        distribution (mean 0, std 0.1) and biases set to 0.

        Args:
            n_inputs (int): The number of input features.
            n_outputs (int): The number of output features.

        """
        super().__init__()

        width = 50
        layers = [
            nn.Linear(n_inputs, width),
            nn.ReLU(),
            nn.Linear(width, width),
            nn.ReLU(),
            nn.Linear(width, n_outputs),
        ]
        self.net = nn.Sequential(*layers)

        linear_layers = (
            layer for layer in self.net.modules() if isinstance(layer, nn.Linear)
        )
        for linear in linear_layers:
            nn.init.normal_(linear.weight, mean=0, std=0.1)
            nn.init.constant_(linear.bias, val=0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Propagate ``x`` through ``self.net``.

        Args:
            x: Input tensor.

        Returns:
            torch.Tensor: Output tensor.

        """
        output = self.net(x)
        return output
171
+
172
+
173
def _train_batched(
    aprox: nn.Module,
    features: pd.DataFrame,
    targets: pd.DataFrame,
    epochs: int,
    optimizer: Adam,
    device: torch.device,
    batch_size: int,
) -> pd.Series:
    """Train the neural network using mini-batch gradient descent.

    Each epoch draws ``batch_size`` rows with replacement by *position*, so
    ``features`` and ``targets`` must be row-aligned but may carry any index
    (e.g. time points or a filtered, non-contiguous index). The loss is the
    mean absolute error.

    Args:
        aprox: Neural network model to train.
        features: Input features, one row per sample.
        targets: Target values, row-aligned with ``features``.
        epochs: Number of training epochs.
        optimizer: Optimizer for training.
        device: Torch device on which batches are created.
        batch_size: Size of mini-batches for training.

    Returns:
        pd.Series: Series containing the training loss history.

    """
    # torch.tensor (unlike the legacy torch.Tensor constructor) accepts any
    # device; use the default dtype to match the legacy constructor's output.
    dtype = torch.get_default_dtype()
    # Convert once instead of once per epoch.
    feature_values = features.to_numpy()
    target_values = targets.to_numpy()
    n_rows = len(feature_values)

    rng = np.random.default_rng()
    losses: dict[int, float] = {}
    for i in tqdm.trange(epochs):
        # Positions, not index labels: `.iloc`-style access must be positional.
        idxs = rng.integers(0, n_rows, size=batch_size)
        X = torch.tensor(feature_values[idxs], dtype=dtype, device=device)
        Y = torch.tensor(target_values[idxs], dtype=dtype, device=device)
        optimizer.zero_grad()
        loss = torch.mean(torch.abs(aprox(X) - Y))
        loss.backward()
        optimizer.step()
        # `.item()` works on every device; `.numpy()` fails for CUDA tensors.
        losses[i] = loss.item()
    return pd.Series(losses, dtype=float)
209
+
210
+
211
def _train_full(
    aprox: nn.Module,
    features: pd.DataFrame,
    targets: pd.DataFrame,
    epochs: int,
    optimizer: Adam,
    device: torch.device,
) -> pd.Series:
    """Train the neural network using full-batch gradient descent.

    Every epoch uses all rows; the loss is the mean absolute error.

    Args:
        aprox: Neural network model to train.
        features: Input features, one row per sample.
        targets: Target values, row-aligned with ``features``.
        epochs: Number of training epochs.
        optimizer: Optimizer for training.
        device: Torch device on which the training tensors are created.

    Returns:
        pd.Series: Series containing the training loss history.

    """
    # torch.tensor (unlike the legacy torch.Tensor constructor) accepts any
    # device; use the default dtype to match the legacy constructor's output.
    dtype = torch.get_default_dtype()
    X = torch.tensor(features.to_numpy(), dtype=dtype, device=device)
    Y = torch.tensor(targets.to_numpy(), dtype=dtype, device=device)

    losses: dict[int, float] = {}
    for i in tqdm.trange(epochs):
        optimizer.zero_grad()
        loss = torch.mean(torch.abs(aprox(X) - Y))
        loss.backward()
        optimizer.step()
        # `.item()` works on every device; `.numpy()` fails for CUDA tensors.
        losses[i] = loss.item()
    return pd.Series(losses, dtype=float)
244
+
245
+
246
def train_torch_surrogate(
    features: pd.DataFrame,
    targets: pd.DataFrame,
    epochs: int,
    surrogate_inputs: list[str],
    surrogate_stoichiometries: dict[str, dict[str, float]],
    batch_size: int | None = None,
    approximator: nn.Module | None = None,
    optimimzer_cls: type[Adam] = Adam,
    device: torch.device = DefaultDevice,
) -> tuple[TorchSurrogate, pd.Series]:
    """Train a PyTorch surrogate model.

    Examples:
        >>> train_torch_surrogate(
        ...     features,
        ...     targets,
        ...     epochs=100,
        ...     surrogate_inputs=["x1", "x2"],
        ...     surrogate_stoichiometries={
        ...         "v1": {"x1": -1, "x2": 1, "ATP": -1},
        ...     },
        ... )

    Args:
        features: DataFrame containing the input features for training.
        targets: DataFrame containing the target values for training.
        epochs: Number of training epochs.
        surrogate_inputs: List of input variable names for the surrogate model.
        surrogate_stoichiometries: Dictionary mapping reaction names to stoichiometries.
        batch_size: Size of mini-batches for training (None for full-batch).
        approximator: Predefined neural network model (None to use default).
            It is moved to ``device`` in place before training.
        optimimzer_cls: Optimizer class to use for training (default: Adam).
        device: Device to run the training on (default: DefaultDevice).

    Returns:
        tuple[TorchSurrogate, pd.Series]: Trained surrogate model and loss history.

    """
    if approximator is None:
        approximator = Approximator(
            n_inputs=len(features.columns),
            n_outputs=len(targets.columns),
        )
    # The training tensors are created on `device`, so the model has to live
    # there too -- including a user-supplied approximator.
    approximator = approximator.to(device)

    optimizer = optimimzer_cls(approximator.parameters())
    if batch_size is None:
        losses = _train_full(
            aprox=approximator,
            features=features,
            targets=targets,
            epochs=epochs,
            optimizer=optimizer,
            device=device,
        )
    else:
        losses = _train_batched(
            aprox=approximator,
            features=features,
            targets=targets,
            epochs=epochs,
            optimizer=optimizer,
            device=device,
            batch_size=batch_size,
        )
    surrogate = TorchSurrogate(
        model=approximator,
        inputs=surrogate_inputs,
        stoichiometries=surrogate_stoichiometries,
    )
    return surrogate, losses