mxlpy 0.15.0__py3-none-any.whl → 0.17.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
mxlpy/scan.py CHANGED
@@ -51,7 +51,7 @@ if TYPE_CHECKING:
     from mxlpy.types import Array


-def _update_parameters_and[T](
+def _update_parameters_and_initial_conditions[T](
     pars: pd.Series,
     fn: Callable[[Model], T],
     model: Model,
@@ -67,7 +67,9 @@ def _update_parameters_and[T](
         Result of the function execution.

     """
-    model.update_parameters(pars.to_dict())
+    pd = pars.to_dict()
+    model.update_variables({k: v for k, v in pd.items() if k in model._variables})  # noqa: SLF001
+    model.update_parameters({k: v for k, v in pd.items() if k in model._parameters})  # noqa: SLF001
     return fn(model)


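Note: the renamed helper now splits each scanned row by key — names found among the model's variables update initial conditions, names found among its parameters update parameter values. A minimal sketch of that dispatch using plain dicts (the dicts below are illustrative stand-ins for `model._variables` and `model._parameters`, not the `Model` API):

    import pandas as pd

    variables = {"x1": 1.0}   # stand-in for model._variables
    parameters = {"k1": 0.5}  # stand-in for model._parameters

    row = pd.Series({"k1": 2.0, "x1": 0.1})  # one row of the scan DataFrame
    values = row.to_dict()
    variables.update({k: v for k, v in values.items() if k in variables})
    parameters.update({k: v for k, v in values.items() if k in parameters})
    assert variables == {"x1": 0.1}
    assert parameters == {"k1": 2.0}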
@@ -282,7 +284,6 @@ class SteadyStateWorker(Protocol):
     def __call__(
         self,
         model: Model,
-        y0: dict[str, float] | None,
         *,
         rel_norm: bool,
         integrator: IntegratorType,
@@ -297,7 +298,6 @@ class TimeCourseWorker(Protocol):
     def __call__(
         self,
         model: Model,
-        y0: dict[str, float] | None,
         time_points: Array,
         *,
         integrator: IntegratorType,
@@ -312,7 +312,6 @@ class ProtocolWorker(Protocol):
     def __call__(
         self,
         model: Model,
-        y0: dict[str, float] | None,
         protocol: pd.DataFrame,
         *,
         integrator: IntegratorType,
@@ -324,7 +323,6 @@ class ProtocolWorker(Protocol):

 def _steady_state_worker(
     model: Model,
-    y0: dict[str, float] | None,
     *,
     rel_norm: bool,
     integrator: IntegratorType,
@@ -343,7 +341,7 @@ def _steady_state_worker(
     """
     try:
         res = (
-            Simulator(model, y0=y0, integrator=integrator)
+            Simulator(model, integrator=integrator)
             .simulate_to_steady_state(rel_norm=rel_norm)
             .get_result()
         )
@@ -354,7 +352,6 @@ def _steady_state_worker(

 def _time_course_worker(
     model: Model,
-    y0: dict[str, float] | None,
     time_points: Array,
     integrator: IntegratorType,
 ) -> TimeCourse:
@@ -372,7 +369,7 @@ def _time_course_worker(
     """
     try:
         res = (
-            Simulator(model, y0=y0, integrator=integrator)
+            Simulator(model, integrator=integrator)
             .simulate_time_course(time_points=time_points)
             .get_result()
         )
@@ -387,7 +384,6 @@ def _protocol_worker(

 def _protocol_worker(
     model: Model,
-    y0: dict[str, float] | None,
     protocol: pd.DataFrame,
     *,
     integrator: IntegratorType = DefaultIntegrator,
@@ -408,7 +404,7 @@ def _protocol_worker(
     """
     try:
         res = (
-            Simulator(model, y0=y0, integrator=integrator)
+            Simulator(model, integrator=integrator)
             .simulate_over_protocol(
                 protocol=protocol,
                 time_points_per_step=time_points_per_step,
@@ -432,20 +428,20 @@ def _protocol_worker(

 def steady_state(
     model: Model,
-    parameters: pd.DataFrame,
-    y0: dict[str, float] | None = None,
     *,
+    to_scan: pd.DataFrame,
+    y0: dict[str, float] | None = None,
     parallel: bool = True,
     rel_norm: bool = False,
     cache: Cache | None = None,
     worker: SteadyStateWorker = _steady_state_worker,
     integrator: IntegratorType = DefaultIntegrator,
 ) -> SteadyStates:
-    """Get steady-state results over supplied parameters.
+    """Get steady-state results over supplied values.

     Args:
         model: Model instance to simulate.
-        parameters: DataFrame containing parameter values to scan.
+        to_scan: DataFrame containing parameter or initial values to scan.
         y0: Initial conditions as a dictionary {variable: value}.
         parallel: Whether to execute in parallel (default: True).
         rel_norm: Whether to use relative normalization (default: False).
@@ -478,39 +474,41 @@ def steady_state(
         | (2, 4) | 0.5 | 2 |

     """
+    if y0 is not None:
+        model.update_variables(y0)
+
     res = parallelise(
         partial(
-            _update_parameters_and,
+            _update_parameters_and_initial_conditions,
             fn=partial(
                 worker,
-                y0=y0,
                 rel_norm=rel_norm,
                 integrator=integrator,
             ),
             model=model,
         ),
-        inputs=list(parameters.iterrows()),
+        inputs=list(to_scan.iterrows()),
         cache=cache,
         parallel=parallel,
     )
     concs = pd.DataFrame({k: v.variables.T for k, v in res.items()}).T
     fluxes = pd.DataFrame({k: v.fluxes.T for k, v in res.items()}).T
     idx = (
-        pd.Index(parameters.iloc[:, 0])
-        if parameters.shape[1] == 1
-        else pd.MultiIndex.from_frame(parameters)
+        pd.Index(to_scan.iloc[:, 0])
+        if to_scan.shape[1] == 1
+        else pd.MultiIndex.from_frame(to_scan)
     )
     concs.index = idx
     fluxes.index = idx
-    return SteadyStates(variables=concs, fluxes=fluxes, parameters=parameters)
+    return SteadyStates(variables=concs, fluxes=fluxes, parameters=to_scan)


 def time_course(
     model: Model,
-    parameters: pd.DataFrame,
+    *,
+    to_scan: pd.DataFrame,
     time_points: Array,
     y0: dict[str, float] | None = None,
-    *,
     parallel: bool = True,
     cache: Cache | None = None,
     worker: TimeCourseWorker = _time_course_worker,
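Note: `to_scan` replaces `parameters` and is keyword-only; the same rename applies to `time_course` and `time_course_over_protocol` below. Because scanned columns may now name either parameters or variables, a single call can vary both. A minimal migration sketch against the new signature (the model and the names "k1" and "x1" are illustrative, not part of this diff):

    import pandas as pd
    from mxlpy import scan

    # 0.15.0: scan.steady_state(model, parameters=pd.DataFrame({"k1": [1.0, 2.0]}))
    res = scan.steady_state(
        model,
        to_scan=pd.DataFrame({"k1": [1.0, 2.0], "x1": [0.1, 0.5]}),
    )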
@@ -521,7 +519,7 @@ def time_course(
     Examples:
         >>> time_course(
         >>>     model,
-        >>>     parameters=pd.DataFrame({"k1": [1, 1.5, 2]}),
+        >>>     to_scan=pd.DataFrame({"k1": [1, 1.5, 2]}),
         >>>     time_points=np.linspace(0, 1, 3)
         >>> ).variables

@@ -539,7 +537,7 @@ def time_course(

        >>> time_course(
        >>>     model,
-       >>>     parameters=cartesian_product({"k1": [1, 2], "k2": [3, 4]}),
+       >>>     to_scan=cartesian_product({"k1": [1, 2], "k2": [3, 4]}),
        >>>     time_points=[0.0, 0.5, 1.0],
        >>> ).variables

@@ -553,7 +551,7 @@ def time_course(

     Args:
         model: Model instance to simulate.
-        parameters: DataFrame containing parameter values to scan.
+        to_scan: DataFrame containing parameter or initial values to scan.
         time_points: Array of time points for the simulation.
         y0: Initial conditions as a dictionary {variable: value}.
         cache: Optional cache to store and retrieve results.
@@ -566,25 +564,27 @@ def time_course(


     """
+    if y0 is not None:
+        model.update_variables(y0)
+
     res = parallelise(
         partial(
-            _update_parameters_and,
+            _update_parameters_and_initial_conditions,
             fn=partial(
                 worker,
                 time_points=time_points,
-                y0=y0,
                 integrator=integrator,
             ),
             model=model,
         ),
-        inputs=list(parameters.iterrows()),
+        inputs=list(to_scan.iterrows()),
         cache=cache,
         parallel=parallel,
     )
     concs = cast(dict, {k: v.variables for k, v in res.items()})
     fluxes = cast(dict, {k: v.fluxes for k, v in res.items()})
     return TimeCourseByPars(
-        parameters=parameters,
+        parameters=to_scan,
         variables=pd.concat(concs, names=["n", "time"]),
         fluxes=pd.concat(fluxes, names=["n", "time"]),
     )
@@ -592,11 +592,11 @@ def time_course(

 def time_course_over_protocol(
     model: Model,
-    parameters: pd.DataFrame,
+    *,
+    to_scan: pd.DataFrame,
     protocol: pd.DataFrame,
     time_points_per_step: int = 10,
     y0: dict[str, float] | None = None,
-    *,
     parallel: bool = True,
     cache: Cache | None = None,
     worker: ProtocolWorker = _protocol_worker,
@@ -618,7 +618,7 @@ def time_course_over_protocol(

     Args:
         model: Model instance to simulate.
-        parameters: DataFrame containing parameter values to scan.
+        to_scan: DataFrame containing parameter or initial values to scan.
         protocol: Protocol to follow for the simulation.
         time_points_per_step: Number of time points per protocol step (default: 10).
         y0: Initial conditions as a dictionary {variable: value}.
@@ -631,26 +631,28 @@ def time_course_over_protocol(
         TimeCourseByPars: Protocol series results for each parameter set.

     """
+    if y0 is not None:
+        model.update_variables(y0)
+
     res = parallelise(
         partial(
-            _update_parameters_and,
+            _update_parameters_and_initial_conditions,
             fn=partial(
                 worker,
                 protocol=protocol,
-                y0=y0,
                 time_points_per_step=time_points_per_step,
                 integrator=integrator,
             ),
             model=model,
         ),
-        inputs=list(parameters.iterrows()),
+        inputs=list(to_scan.iterrows()),
         cache=cache,
         parallel=parallel,
     )
     concs = cast(dict, {k: v.variables for k, v in res.items()})
     fluxes = cast(dict, {k: v.fluxes for k, v in res.items()})
     return ProtocolByPars(
-        parameters=parameters,
+        parameters=to_scan,
         protocol=protocol,
         variables=pd.concat(concs, names=["n", "time"]),
         fluxes=pd.concat(fluxes, names=["n", "time"]),
mxlpy/surrogates/__init__.py CHANGED
@@ -19,13 +19,14 @@ from __future__ import annotations
 import contextlib

 with contextlib.suppress(ImportError):
-    from ._torch import TorchSurrogate, train_torch_surrogate
+    from ._torch import Torch, TorchTrainer, train_torch

-from ._poly import PolySurrogate, train_polynomial_surrogate
+from ._poly import Polynomial, train_polynomial

 __all__ = [
-    "PolySurrogate",
-    "TorchSurrogate",
-    "train_polynomial_surrogate",
-    "train_torch_surrogate",
+    "Polynomial",
+    "Torch",
+    "TorchTrainer",
+    "train_polynomial",
+    "train_torch",
 ]
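Note: all public surrogate names lose their `Surrogate`/`_surrogate` suffixes. An import-level migration sketch for code written against the 0.15.0 API:

    # 0.15.0
    # from mxlpy.surrogates import PolySurrogate, TorchSurrogate
    # from mxlpy.surrogates import train_polynomial_surrogate, train_torch_surrogate

    # 0.17.0
    from mxlpy.surrogates import Polynomial, Torch, TorchTrainer
    from mxlpy.surrogates import train_polynomial, train_torch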
mxlpy/surrogates/_poly.py CHANGED
@@ -9,9 +9,9 @@ from numpy import polynomial
 from mxlpy.types import AbstractSurrogate, ArrayLike

 __all__ = [
-    "PolySurrogate",
+    "Polynomial",
     "PolynomialExpansion",
-    "train_polynomial_surrogate",
+    "train_polynomial",
 ]

 # define custom type
@@ -26,23 +26,24 @@ PolynomialExpansion = (


 @dataclass(kw_only=True)
-class PolySurrogate(AbstractSurrogate):
+class Polynomial(AbstractSurrogate):
     model: PolynomialExpansion

     def predict_raw(self, y: np.ndarray) -> np.ndarray:
         return self.model(y)


-def train_polynomial_surrogate(
-    feature: ArrayLike,
-    target: ArrayLike,
+def train_polynomial(
+    feature: ArrayLike | pd.Series,
+    target: ArrayLike | pd.Series,
     series: Literal[
         "Power", "Chebyshev", "Legendre", "Laguerre", "Hermite", "HermiteE"
     ] = "Power",
     degrees: Iterable[int] = (1, 2, 3, 4, 5, 6, 7),
     surrogate_args: list[str] | None = None,
+    surrogate_outputs: list[str] | None = None,
     surrogate_stoichiometries: dict[str, dict[str, float]] | None = None,
-) -> tuple[PolySurrogate, pd.DataFrame]:
+) -> tuple[Polynomial, pd.DataFrame]:
     """Train a surrogate model based on function series expansion.

     Args:
@@ -51,7 +52,8 @@ def train_polynomial_surrogate(
         series: Base functions for the surrogate model
         degrees: Degrees of the polynomial to fit to the data.
         surrogate_args: Additional arguments for the surrogate model.
-        surrogate_stoichiometries: Stoichiometries for the surrogate model.
+        surrogate_outputs: Names of the surrogate model outputs.
+        surrogate_stoichiometries: Mapping of variables to their stoichiometries

     Returns:
         PolySurrogate: Polynomial surrogate model.
@@ -83,9 +85,10 @@ def train_polynomial_surrogate(
     # Choose the model with the lowest AIC
     model = models[np.argmin(score)]
     return (
-        PolySurrogate(
+        Polynomial(
             model=model,
             args=surrogate_args if surrogate_args is not None else [],
+            outputs=surrogate_outputs if surrogate_outputs is not None else [],
             stoichiometries=surrogate_stoichiometries
             if surrogate_stoichiometries is not None
             else {},
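Note: a short usage sketch of the renamed polynomial trainer with the new `surrogate_outputs` argument, based on the signature shown above (the data and the names "x1"/"v1" are illustrative):

    import numpy as np
    from mxlpy.surrogates import train_polynomial

    x = np.linspace(0.0, 1.0, 100)
    # Fits one expansion per degree and keeps the lowest-AIC model
    surrogate, scores = train_polynomial(
        feature=x,
        target=x**2,
        series="Chebyshev",
        surrogate_args=["x1"],
        surrogate_outputs=["v1"],
    )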
mxlpy/surrogates/_torch.py CHANGED
@@ -1,5 +1,6 @@
 from collections.abc import Callable
 from dataclasses import dataclass
+from typing import Self

 import numpy as np
 import pandas as pd
@@ -12,14 +13,32 @@ from torch.optim.optimizer import ParamsT
 from mxlpy.nn._torch import MLP, DefaultDevice
 from mxlpy.types import AbstractSurrogate

+type LossFn = Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
+
 __all__ = [
-    "TorchSurrogate",
-    "train_torch_surrogate",
+    "LossFn",
+    "Torch",
+    "TorchTrainer",
+    "train_torch",
 ]


+def _mean_abs(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+    """Standard loss for surrogates.
+
+    Args:
+        x: Predictions of a model.
+        y: Targets.
+
+    Returns:
+        torch.Tensor: loss.
+
+    """
+    return torch.mean(torch.abs(x - y))
+
+
 @dataclass(kw_only=True)
-class TorchSurrogate(AbstractSurrogate):
+class Torch(AbstractSurrogate):
     """Surrogate model using PyTorch.

     Attributes:
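Note: the new `loss_fn` hooks accept any callable matching `LossFn`, i.e. `(predictions, targets) -> scalar tensor`; `_mean_abs` is merely the default. A sketch of an alternative loss (this MSE helper is illustrative, not part of the package):

    import torch

    def mse(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        # Same contract as _mean_abs: reduce a prediction/target pair to a scalar loss
        return torch.mean((x - y) ** 2)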
@@ -48,6 +67,91 @@ class TorchSurrogate(AbstractSurrogate):
         ).numpy()


+@dataclass(init=False)
+class TorchTrainer:
+    features: pd.DataFrame
+    targets: pd.DataFrame
+    approximator: nn.Module
+    optimizer: Adam
+    device: torch.device
+    losses: list[pd.Series]
+    loss_fn: LossFn
+
+    def __init__(
+        self,
+        features: pd.DataFrame,
+        targets: pd.DataFrame,
+        approximator: nn.Module | None = None,
+        optimimzer_cls: Callable[[ParamsT], Adam] = Adam,
+        device: torch.device = DefaultDevice,
+        loss_fn: LossFn = _mean_abs,
+    ) -> None:
+        self.features = features
+        self.targets = targets
+
+        if approximator is None:
+            approximator = MLP(
+                n_inputs=len(features.columns),
+                neurons_per_layer=[50, 50, len(targets.columns)],
+            )
+        self.approximator = approximator.to(device)
+
+        self.optimizer = optimimzer_cls(approximator.parameters())
+        self.device = device
+        self.loss_fn = loss_fn
+        self.losses = []
+
+    def train(
+        self,
+        epochs: int,
+        batch_size: int | None = None,
+    ) -> Self:
+        if batch_size is None:
+            losses = _train_full(
+                aprox=self.approximator,
+                features=self.features,
+                targets=self.targets,
+                epochs=epochs,
+                optimizer=self.optimizer,
+                device=self.device,
+                loss_fn=self.loss_fn,
+            )
+        else:
+            losses = _train_batched(
+                aprox=self.approximator,
+                features=self.features,
+                targets=self.targets,
+                epochs=epochs,
+                optimizer=self.optimizer,
+                device=self.device,
+                batch_size=batch_size,
+                loss_fn=self.loss_fn,
+            )
+
+        if len(self.losses) > 0:
+            losses.index += self.losses[-1].index[-1]
+        self.losses.append(losses)
+        return self
+
+    def get_loss(self) -> pd.Series:
+        return pd.concat(self.losses)
+
+    def get_surrogate(
+        self,
+        surrogate_args: list[str] | None = None,
+        surrogate_outputs: list[str] | None = None,
+        surrogate_stoichiometries: dict[str, dict[str, float]] | None = None,
+    ) -> Torch:
+        return Torch(
+            model=self.approximator,
+            args=surrogate_args if surrogate_args is not None else [],
+            outputs=surrogate_outputs if surrogate_outputs is not None else [],
+            stoichiometries=surrogate_stoichiometries
+            if surrogate_stoichiometries is not None
+            else {},
+        )
+
+
 def _train_batched(
     aprox: nn.Module,
     features: pd.DataFrame,
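Note: `TorchTrainer` makes training resumable — repeated `train` calls reuse the same optimizer state, and each round's loss index is offset by the previous round's last index so `get_loss` returns one continuous series. A minimal usage sketch based on the class above (`features_df` and `targets_df` are illustrative placeholders):

    from mxlpy.surrogates import TorchTrainer

    trainer = TorchTrainer(features=features_df, targets=targets_df)
    trainer.train(epochs=100)                # full-batch round
    trainer.train(epochs=50, batch_size=32)  # continue with mini-batches
    surrogate = trainer.get_surrogate(
        surrogate_args=["x1"],
        surrogate_outputs=["v1"],
    )
    loss_history = trainer.get_loss()        # concatenated across both rounds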
@@ -56,6 +160,7 @@ def _train_batched(
     optimizer: Adam,
     device: torch.device,
     batch_size: int,
+    loss_fn: LossFn,
 ) -> pd.Series:
     """Train the neural network using mini-batch gradient descent.

@@ -67,6 +172,7 @@ def _train_batched(
         optimizer: Optimizer for training.
         device: torch device
         batch_size: Size of mini-batches for training.
+        loss_fn: Loss function

     Returns:
         pd.Series: Series containing the training loss history.
@@ -79,7 +185,7 @@ def _train_batched(
         X = torch.Tensor(features.iloc[idxs].to_numpy(), device=device)
         Y = torch.Tensor(targets.iloc[idxs].to_numpy(), device=device)
         optimizer.zero_grad()
-        loss = torch.mean(torch.abs(aprox(X) - Y))
+        loss = loss_fn(aprox(X), Y)
         loss.backward()
         optimizer.step()
         losses[i] = loss.detach().numpy()
@@ -93,6 +199,7 @@ def _train_full(
     epochs: int,
     optimizer: Adam,
     device: torch.device,
+    loss_fn: Callable,
 ) -> pd.Series:
     """Train the neural network using full-batch gradient descent.

@@ -103,6 +210,7 @@ def _train_full(
         epochs: Number of training epochs.
         optimizer: Optimizer for training.
         device: Torch device
+        loss_fn: Loss function

     Returns:
         pd.Series: Series containing the training loss history.
@@ -114,24 +222,26 @@ def _train_full(
     losses = {}
     for i in tqdm.trange(epochs):
         optimizer.zero_grad()
-        loss = torch.mean(torch.abs(aprox(X) - Y))
+        loss = loss_fn(aprox(X), Y)
         loss.backward()
         optimizer.step()
         losses[i] = loss.detach().numpy()
     return pd.Series(losses, dtype=float)


-def train_torch_surrogate(
+def train_torch(
     features: pd.DataFrame,
     targets: pd.DataFrame,
     epochs: int,
     surrogate_args: list[str] | None = None,
+    surrogate_outputs: list[str] | None = None,
     surrogate_stoichiometries: dict[str, dict[str, float]] | None = None,
     batch_size: int | None = None,
     approximator: nn.Module | None = None,
     optimimzer_cls: Callable[[ParamsT], Adam] = Adam,
     device: torch.device = DefaultDevice,
-) -> tuple[TorchSurrogate, pd.Series]:
+    loss_fn: LossFn = _mean_abs,
+) -> tuple[Torch, pd.Series]:
     """Train a PyTorch surrogate model.

     Examples:
@@ -143,54 +253,38 @@ def train_torch_surrogate(
     ...     surrogate_stoichiometries={
     ...         "v1": {"x1": -1, "x2": 1, "ATP": -1},
     ...     },
-    ...)
+    ...)surrogate_stoichiometries

     Args:
         features: DataFrame containing the input features for training.
         targets: DataFrame containing the target values for training.
         epochs: Number of training epochs.
-        surrogate_args: List of input variable names for the surrogate model.
-        surrogate_stoichiometries: Dictionary mapping reaction names to stoichiometries.
+        surrogate_args: Names of inputs arguments for the surrogate model.
+        surrogate_outputs: Names of output arguments from the surrogate.
+        surrogate_stoichiometries: Mapping of variables to their stoichiometries
         batch_size: Size of mini-batches for training (None for full-batch).
         approximator: Predefined neural network model (None to use default MLP features-50-50-output).
         optimimzer_cls: Optimizer class to use for training (default: Adam).
         device: Device to run the training on (default: DefaultDevice).
+        loss_fn: Custom loss function or instance of torch loss object

     Returns:
         tuple[TorchSurrogate, pd.Series]: Trained surrogate model and loss history.

     """
-    if approximator is None:
-        approximator = MLP(
-            n_inputs=len(features.columns),
-            neurons_per_layer=[50, 50, len(targets.columns)],
-        ).to(device)
-
-    optimizer = optimimzer_cls(approximator.parameters())
-    if batch_size is None:
-        losses = _train_full(
-            aprox=approximator,
-            features=features,
-            targets=targets,
-            epochs=epochs,
-            optimizer=optimizer,
-            device=device,
-        )
-    else:
-        losses = _train_batched(
-            aprox=approximator,
-            features=features,
-            targets=targets,
-            epochs=epochs,
-            optimizer=optimizer,
-            device=device,
-            batch_size=batch_size,
-        )
-    surrogate = TorchSurrogate(
-        model=approximator,
-        args=surrogate_args if surrogate_args is not None else [],
-        stoichiometries=surrogate_stoichiometries
-        if surrogate_stoichiometries is not None
-        else {},
+    trainer = TorchTrainer(
+        features=features,
+        targets=targets,
+        approximator=approximator,
+        optimimzer_cls=optimimzer_cls,
+        device=device,
+        loss_fn=loss_fn,
+    ).train(
+        epochs=epochs,
+        batch_size=batch_size,
     )
-    return surrogate, losses
+    return trainer.get_surrogate(
+        surrogate_args=surrogate_args,
+        surrogate_outputs=surrogate_outputs,
+        surrogate_stoichiometries=surrogate_stoichiometries,
+    ), trainer.get_loss()
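Note: `train_torch` is now a thin wrapper over `TorchTrainer` and forwards the new `loss_fn` and `surrogate_outputs` arguments. A sketch passing a custom loss, reusing the `mse` helper sketched earlier (`features_df`/`targets_df` remain illustrative placeholders):

    from mxlpy.surrogates import train_torch

    surrogate, losses = train_torch(
        features=features_df,
        targets=targets_df,
        epochs=250,
        surrogate_args=["x1"],
        surrogate_outputs=["v1"],
        loss_fn=mse,  # any Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
    )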
@@ -27,9 +27,7 @@ import tqdm
 from sympy import Matrix
 from sympy.matrices import zeros

-from mxlpy.model import Model
-
-from .symbolic_model import SymbolicModel, to_symbolic_model
+from .symbolic_model import SymbolicModel

 __all__ = [
     "Options",