mxlpy 0.15.0__py3-none-any.whl → 0.17.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mxlpy/__init__.py +4 -1
- mxlpy/fns.py +513 -21
- mxlpy/integrators/int_assimulo.py +2 -1
- mxlpy/mc.py +84 -70
- mxlpy/mca.py +97 -98
- mxlpy/meta/codegen_latex.py +279 -14
- mxlpy/meta/source_tools.py +122 -4
- mxlpy/model.py +50 -24
- mxlpy/npe/__init__.py +38 -0
- mxlpy/npe/_torch.py +436 -0
- mxlpy/report.py +33 -6
- mxlpy/sbml/_import.py +5 -2
- mxlpy/scan.py +40 -38
- mxlpy/surrogates/__init__.py +7 -6
- mxlpy/surrogates/_poly.py +12 -9
- mxlpy/surrogates/_torch.py +137 -43
- mxlpy/symbolic/strikepy.py +1 -3
- mxlpy/types.py +18 -5
- {mxlpy-0.15.0.dist-info → mxlpy-0.17.0.dist-info}/METADATA +5 -4
- {mxlpy-0.15.0.dist-info → mxlpy-0.17.0.dist-info}/RECORD +22 -21
- mxlpy/npe.py +0 -277
- {mxlpy-0.15.0.dist-info → mxlpy-0.17.0.dist-info}/WHEEL +0 -0
- {mxlpy-0.15.0.dist-info → mxlpy-0.17.0.dist-info}/licenses/LICENSE +0 -0
mxlpy/scan.py
CHANGED
@@ -51,7 +51,7 @@ if TYPE_CHECKING:
     from mxlpy.types import Array
 
 
-def _update_parameters_and[T](
+def _update_parameters_and_initial_conditions[T](
     pars: pd.Series,
     fn: Callable[[Model], T],
     model: Model,
@@ -67,7 +67,9 @@ def _update_parameters_and[T](
         Result of the function execution.
 
     """
-    model.update_parameters(pars.to_dict())
+    pd = pars.to_dict()
+    model.update_variables({k: v for k, v in pd.items() if k in model._variables})  # noqa: SLF001
+    model.update_parameters({k: v for k, v in pd.items() if k in model._parameters})  # noqa: SLF001
     return fn(model)
 
 
@@ -282,7 +284,6 @@ class SteadyStateWorker(Protocol):
     def __call__(
         self,
         model: Model,
-        y0: dict[str, float] | None,
         *,
         rel_norm: bool,
         integrator: IntegratorType,
@@ -297,7 +298,6 @@ class TimeCourseWorker(Protocol):
     def __call__(
         self,
         model: Model,
-        y0: dict[str, float] | None,
         time_points: Array,
         *,
         integrator: IntegratorType,
@@ -312,7 +312,6 @@ class ProtocolWorker(Protocol):
     def __call__(
         self,
         model: Model,
-        y0: dict[str, float] | None,
         protocol: pd.DataFrame,
         *,
         integrator: IntegratorType,
@@ -324,7 +323,6 @@ class ProtocolWorker(Protocol):
 
 def _steady_state_worker(
     model: Model,
-    y0: dict[str, float] | None,
     *,
     rel_norm: bool,
     integrator: IntegratorType,
@@ -343,7 +341,7 @@ def _steady_state_worker(
     """
     try:
        res = (
-            Simulator(model, y0=y0, integrator=integrator)
+            Simulator(model, integrator=integrator)
            .simulate_to_steady_state(rel_norm=rel_norm)
            .get_result()
        )
@@ -354,7 +352,6 @@ def _steady_state_worker(
 
 def _time_course_worker(
     model: Model,
-    y0: dict[str, float] | None,
     time_points: Array,
     integrator: IntegratorType,
 ) -> TimeCourse:
@@ -372,7 +369,7 @@ def _time_course_worker(
     """
     try:
        res = (
-            Simulator(model, y0=y0, integrator=integrator)
+            Simulator(model, integrator=integrator)
            .simulate_time_course(time_points=time_points)
            .get_result()
        )
@@ -387,7 +384,6 @@ def _time_course_worker(
 
 def _protocol_worker(
     model: Model,
-    y0: dict[str, float] | None,
     protocol: pd.DataFrame,
     *,
     integrator: IntegratorType = DefaultIntegrator,
@@ -408,7 +404,7 @@ def _protocol_worker(
     """
     try:
        res = (
-            Simulator(model, y0=y0, integrator=integrator)
+            Simulator(model, integrator=integrator)
            .simulate_over_protocol(
                protocol=protocol,
                time_points_per_step=time_points_per_step,
@@ -432,20 +428,20 @@ def _protocol_worker(
 
 def steady_state(
     model: Model,
-    parameters: pd.DataFrame,
-    y0: dict[str, float] | None = None,
     *,
+    to_scan: pd.DataFrame,
+    y0: dict[str, float] | None = None,
     parallel: bool = True,
     rel_norm: bool = False,
     cache: Cache | None = None,
     worker: SteadyStateWorker = _steady_state_worker,
     integrator: IntegratorType = DefaultIntegrator,
 ) -> SteadyStates:
-    """Get steady-state results over supplied parameters.
+    """Get steady-state results over supplied values.
 
     Args:
         model: Model instance to simulate.
-        parameters: DataFrame containing parameter values to scan.
+        to_scan: DataFrame containing parameter or initial values to scan.
         y0: Initial conditions as a dictionary {variable: value}.
         parallel: Whether to execute in parallel (default: True).
         rel_norm: Whether to use relative normalization (default: False).
@@ -478,39 +474,41 @@ def steady_state(
     | (2, 4) | 0.5 | 2 |
 
     """
+    if y0 is not None:
+        model.update_variables(y0)
+
     res = parallelise(
         partial(
-            _update_parameters_and,
+            _update_parameters_and_initial_conditions,
             fn=partial(
                 worker,
-                y0=y0,
                 rel_norm=rel_norm,
                 integrator=integrator,
             ),
             model=model,
         ),
-        inputs=list(parameters.iterrows()),
+        inputs=list(to_scan.iterrows()),
         cache=cache,
         parallel=parallel,
     )
     concs = pd.DataFrame({k: v.variables.T for k, v in res.items()}).T
     fluxes = pd.DataFrame({k: v.fluxes.T for k, v in res.items()}).T
     idx = (
-        pd.Index(parameters.iloc[:, 0])
-        if parameters.shape[1] == 1
-        else pd.MultiIndex.from_frame(parameters)
+        pd.Index(to_scan.iloc[:, 0])
+        if to_scan.shape[1] == 1
+        else pd.MultiIndex.from_frame(to_scan)
     )
     concs.index = idx
     fluxes.index = idx
-    return SteadyStates(variables=concs, fluxes=fluxes, parameters=parameters)
+    return SteadyStates(variables=concs, fluxes=fluxes, parameters=to_scan)
 
 
 def time_course(
     model: Model,
-    parameters: pd.DataFrame,
+    *,
+    to_scan: pd.DataFrame,
     time_points: Array,
     y0: dict[str, float] | None = None,
-    *,
     parallel: bool = True,
     cache: Cache | None = None,
     worker: TimeCourseWorker = _time_course_worker,
@@ -521,7 +519,7 @@ def time_course(
     Examples:
         >>> time_course(
        >>>     model,
-        >>>     parameters=pd.DataFrame({"k1": [1, 1.5, 2]}),
+        >>>     to_scan=pd.DataFrame({"k1": [1, 1.5, 2]}),
        >>>     time_points=np.linspace(0, 1, 3)
        >>> ).variables
 
@@ -539,7 +537,7 @@ def time_course(
 
        >>> time_course(
        >>>     model,
-        >>>     parameters=cartesian_product({"k1": [1, 2], "k2": [3, 4]}),
+        >>>     to_scan=cartesian_product({"k1": [1, 2], "k2": [3, 4]}),
        >>>     time_points=[0.0, 0.5, 1.0],
        >>> ).variables
 
@@ -553,7 +551,7 @@ def time_course(
 
     Args:
         model: Model instance to simulate.
-        parameters: DataFrame containing parameter values to scan.
+        to_scan: DataFrame containing parameter or initial values to scan.
         time_points: Array of time points for the simulation.
         y0: Initial conditions as a dictionary {variable: value}.
         cache: Optional cache to store and retrieve results.
@@ -566,25 +564,27 @@ def time_course(
 
 
     """
+    if y0 is not None:
+        model.update_variables(y0)
+
     res = parallelise(
         partial(
-            _update_parameters_and,
+            _update_parameters_and_initial_conditions,
             fn=partial(
                 worker,
                 time_points=time_points,
-                y0=y0,
                 integrator=integrator,
             ),
             model=model,
         ),
-        inputs=list(parameters.iterrows()),
+        inputs=list(to_scan.iterrows()),
         cache=cache,
         parallel=parallel,
     )
     concs = cast(dict, {k: v.variables for k, v in res.items()})
     fluxes = cast(dict, {k: v.fluxes for k, v in res.items()})
     return TimeCourseByPars(
-        parameters=parameters,
+        parameters=to_scan,
         variables=pd.concat(concs, names=["n", "time"]),
         fluxes=pd.concat(fluxes, names=["n", "time"]),
     )
@@ -592,11 +592,11 @@ def time_course(
 
 def time_course_over_protocol(
     model: Model,
-    parameters: pd.DataFrame,
+    *,
+    to_scan: pd.DataFrame,
     protocol: pd.DataFrame,
     time_points_per_step: int = 10,
     y0: dict[str, float] | None = None,
-    *,
     parallel: bool = True,
     cache: Cache | None = None,
     worker: ProtocolWorker = _protocol_worker,
@@ -618,7 +618,7 @@ def time_course_over_protocol(
 
     Args:
         model: Model instance to simulate.
-        parameters: DataFrame containing parameter values to scan.
+        to_scan: DataFrame containing parameter or initial values to scan.
         protocol: Protocol to follow for the simulation.
         time_points_per_step: Number of time points per protocol step (default: 10).
         y0: Initial conditions as a dictionary {variable: value}.
@@ -631,26 +631,28 @@ def time_course_over_protocol(
         TimeCourseByPars: Protocol series results for each parameter set.
 
     """
+    if y0 is not None:
+        model.update_variables(y0)
+
     res = parallelise(
         partial(
-            _update_parameters_and,
+            _update_parameters_and_initial_conditions,
             fn=partial(
                 worker,
                 protocol=protocol,
-                y0=y0,
                 time_points_per_step=time_points_per_step,
                 integrator=integrator,
             ),
             model=model,
         ),
-        inputs=list(parameters.iterrows()),
+        inputs=list(to_scan.iterrows()),
         cache=cache,
         parallel=parallel,
    )
     concs = cast(dict, {k: v.variables for k, v in res.items()})
     fluxes = cast(dict, {k: v.fluxes for k, v in res.items()})
     return ProtocolByPars(
-        parameters=parameters,
+        parameters=to_scan,
         protocol=protocol,
         variables=pd.concat(concs, names=["n", "time"]),
         fluxes=pd.concat(fluxes, names=["n", "time"]),
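
Taken together, the scan.py changes replace the positional `parameters` DataFrame with a keyword-only `to_scan` that may mix parameter columns and initial-condition columns, and drop `y0` from the worker protocols (it is now applied to the model once, up front). A minimal sketch of the new call shape, assuming a hypothetical model with a parameter "k1" and a variable "s1" (names not taken from this diff):

    import pandas as pd
    from mxlpy import scan

    # Each column of to_scan is routed by name: "k1" matches a model parameter,
    # "s1" matches a model variable, per _update_parameters_and_initial_conditions.
    res = scan.steady_state(
        model,  # any mxlpy Model instance
        to_scan=pd.DataFrame({"k1": [1.0, 1.5, 2.0], "s1": [0.1, 0.1, 0.2]}),
    )
    print(res.variables)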
mxlpy/surrogates/__init__.py
CHANGED
@@ -19,13 +19,14 @@ from __future__ import annotations
 import contextlib
 
 with contextlib.suppress(ImportError):
-    from ._torch import TorchSurrogate, train_torch_surrogate
+    from ._torch import Torch, TorchTrainer, train_torch
 
-from ._poly import PolySurrogate, train_polynomial_surrogate
+from ._poly import Polynomial, train_polynomial
 
 __all__ = [
-    "PolySurrogate",
-    "TorchSurrogate",
-    "train_polynomial_surrogate",
-    "train_torch_surrogate",
+    "Polynomial",
+    "Torch",
+    "TorchTrainer",
+    "train_polynomial",
+    "train_torch",
 ]
mxlpy/surrogates/_poly.py
CHANGED
@@ -9,9 +9,9 @@ from numpy import polynomial
 from mxlpy.types import AbstractSurrogate, ArrayLike
 
 __all__ = [
-    "PolySurrogate",
+    "Polynomial",
     "PolynomialExpansion",
-    "train_polynomial_surrogate",
+    "train_polynomial",
 ]
 
 # define custom type
@@ -26,23 +26,24 @@ PolynomialExpansion = (
 
 
 @dataclass(kw_only=True)
-class PolySurrogate(AbstractSurrogate):
+class Polynomial(AbstractSurrogate):
     model: PolynomialExpansion
 
     def predict_raw(self, y: np.ndarray) -> np.ndarray:
         return self.model(y)
 
 
-def train_polynomial_surrogate(
-    feature: ArrayLike,
-    target: ArrayLike,
+def train_polynomial(
+    feature: ArrayLike | pd.Series,
+    target: ArrayLike | pd.Series,
     series: Literal[
         "Power", "Chebyshev", "Legendre", "Laguerre", "Hermite", "HermiteE"
     ] = "Power",
     degrees: Iterable[int] = (1, 2, 3, 4, 5, 6, 7),
     surrogate_args: list[str] | None = None,
+    surrogate_outputs: list[str] | None = None,
     surrogate_stoichiometries: dict[str, dict[str, float]] | None = None,
-) -> tuple[PolySurrogate, pd.DataFrame]:
+) -> tuple[Polynomial, pd.DataFrame]:
     """Train a surrogate model based on function series expansion.
 
     Args:
@@ -51,7 +52,8 @@ def train_polynomial_surrogate(
         series: Base functions for the surrogate model
         degrees: Degrees of the polynomial to fit to the data.
         surrogate_args: Additional arguments for the surrogate model.
-        surrogate_stoichiometries: …
+        surrogate_outputs: Names of the surrogate model outputs.
+        surrogate_stoichiometries: Mapping of variables to their stoichiometries
 
     Returns:
         PolySurrogate: Polynomial surrogate model.
@@ -83,9 +85,10 @@ def train_polynomial_surrogate(
     # Choose the model with the lowest AIC
     model = models[np.argmin(score)]
     return (
-        PolySurrogate(
+        Polynomial(
             model=model,
             args=surrogate_args if surrogate_args is not None else [],
+            outputs=surrogate_outputs if surrogate_outputs is not None else [],
             stoichiometries=surrogate_stoichiometries
             if surrogate_stoichiometries is not None
             else {},
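
A minimal usage sketch of the renamed trainer with the new `surrogate_outputs` argument; the data and the output name "v1" are illustrative, not taken from this diff:

    import numpy as np
    from mxlpy.surrogates import train_polynomial

    feature = np.linspace(0.1, 1.0, 50)
    target = feature**2

    # Fits one series per degree in `degrees`, keeps the lowest-AIC fit,
    # and also returns a DataFrame with fit information.
    surrogate, fit_df = train_polynomial(
        feature=feature,
        target=target,
        series="Power",
        surrogate_outputs=["v1"],
    )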
mxlpy/surrogates/_torch.py
CHANGED
@@ -1,5 +1,6 @@
 from collections.abc import Callable
 from dataclasses import dataclass
+from typing import Self
 
 import numpy as np
 import pandas as pd
@@ -12,14 +13,32 @@ from torch.optim.optimizer import ParamsT
 from mxlpy.nn._torch import MLP, DefaultDevice
 from mxlpy.types import AbstractSurrogate
 
+type LossFn = Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
+
 __all__ = [
-    "TorchSurrogate",
-    "train_torch_surrogate",
+    "LossFn",
+    "Torch",
+    "TorchTrainer",
+    "train_torch",
 ]
 
 
+def _mean_abs(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+    """Standard loss for surrogates.
+
+    Args:
+        x: Predictions of a model.
+        y: Targets.
+
+    Returns:
+        torch.Tensor: loss.
+
+    """
+    return torch.mean(torch.abs(x - y))
+
+
 @dataclass(kw_only=True)
-class TorchSurrogate(AbstractSurrogate):
+class Torch(AbstractSurrogate):
     """Surrogate model using PyTorch.
 
     Attributes:
@@ -48,6 +67,91 @@ class TorchSurrogate(AbstractSurrogate):
         ).numpy()
 
 
+@dataclass(init=False)
+class TorchTrainer:
+    features: pd.DataFrame
+    targets: pd.DataFrame
+    approximator: nn.Module
+    optimizer: Adam
+    device: torch.device
+    losses: list[pd.Series]
+    loss_fn: LossFn
+
+    def __init__(
+        self,
+        features: pd.DataFrame,
+        targets: pd.DataFrame,
+        approximator: nn.Module | None = None,
+        optimimzer_cls: Callable[[ParamsT], Adam] = Adam,
+        device: torch.device = DefaultDevice,
+        loss_fn: LossFn = _mean_abs,
+    ) -> None:
+        self.features = features
+        self.targets = targets
+
+        if approximator is None:
+            approximator = MLP(
+                n_inputs=len(features.columns),
+                neurons_per_layer=[50, 50, len(targets.columns)],
+            )
+        self.approximator = approximator.to(device)
+
+        self.optimizer = optimimzer_cls(approximator.parameters())
+        self.device = device
+        self.loss_fn = loss_fn
+        self.losses = []
+
+    def train(
+        self,
+        epochs: int,
+        batch_size: int | None = None,
+    ) -> Self:
+        if batch_size is None:
+            losses = _train_full(
+                aprox=self.approximator,
+                features=self.features,
+                targets=self.targets,
+                epochs=epochs,
+                optimizer=self.optimizer,
+                device=self.device,
+                loss_fn=self.loss_fn,
+            )
+        else:
+            losses = _train_batched(
+                aprox=self.approximator,
+                features=self.features,
+                targets=self.targets,
+                epochs=epochs,
+                optimizer=self.optimizer,
+                device=self.device,
+                batch_size=batch_size,
+                loss_fn=self.loss_fn,
+            )
+
+        if len(self.losses) > 0:
+            losses.index += self.losses[-1].index[-1]
+        self.losses.append(losses)
+        return self
+
+    def get_loss(self) -> pd.Series:
+        return pd.concat(self.losses)
+
+    def get_surrogate(
+        self,
+        surrogate_args: list[str] | None = None,
+        surrogate_outputs: list[str] | None = None,
+        surrogate_stoichiometries: dict[str, dict[str, float]] | None = None,
+    ) -> Torch:
+        return Torch(
+            model=self.approximator,
+            args=surrogate_args if surrogate_args is not None else [],
+            outputs=surrogate_outputs if surrogate_outputs is not None else [],
+            stoichiometries=surrogate_stoichiometries
+            if surrogate_stoichiometries is not None
+            else {},
+        )
+
+
 def _train_batched(
     aprox: nn.Module,
     features: pd.DataFrame,
@@ -56,6 +160,7 @@ def _train_batched(
     optimizer: Adam,
     device: torch.device,
     batch_size: int,
+    loss_fn: LossFn,
 ) -> pd.Series:
     """Train the neural network using mini-batch gradient descent.
 
@@ -67,6 +172,7 @@ def _train_batched(
         optimizer: Optimizer for training.
         device: torch device
         batch_size: Size of mini-batches for training.
+        loss_fn: Loss function
 
     Returns:
         pd.Series: Series containing the training loss history.
@@ -79,7 +185,7 @@ def _train_batched(
         X = torch.Tensor(features.iloc[idxs].to_numpy(), device=device)
         Y = torch.Tensor(targets.iloc[idxs].to_numpy(), device=device)
         optimizer.zero_grad()
-        loss = torch.mean(torch.abs(aprox(X) - Y))
+        loss = loss_fn(aprox(X), Y)
         loss.backward()
         optimizer.step()
         losses[i] = loss.detach().numpy()
@@ -93,6 +199,7 @@ def _train_full(
     epochs: int,
     optimizer: Adam,
     device: torch.device,
+    loss_fn: Callable,
 ) -> pd.Series:
     """Train the neural network using full-batch gradient descent.
 
@@ -103,6 +210,7 @@ def _train_full(
         epochs: Number of training epochs.
         optimizer: Optimizer for training.
         device: Torch device
+        loss_fn: Loss function
 
     Returns:
         pd.Series: Series containing the training loss history.
@@ -114,24 +222,26 @@ def _train_full(
     losses = {}
     for i in tqdm.trange(epochs):
         optimizer.zero_grad()
-        loss = torch.mean(torch.abs(aprox(X) - Y))
+        loss = loss_fn(aprox(X), Y)
         loss.backward()
         optimizer.step()
         losses[i] = loss.detach().numpy()
     return pd.Series(losses, dtype=float)
 
 
-def train_torch_surrogate(
+def train_torch(
     features: pd.DataFrame,
     targets: pd.DataFrame,
     epochs: int,
     surrogate_args: list[str] | None = None,
+    surrogate_outputs: list[str] | None = None,
     surrogate_stoichiometries: dict[str, dict[str, float]] | None = None,
     batch_size: int | None = None,
     approximator: nn.Module | None = None,
     optimimzer_cls: Callable[[ParamsT], Adam] = Adam,
     device: torch.device = DefaultDevice,
-) -> tuple[TorchSurrogate, pd.Series]:
+    loss_fn: LossFn = _mean_abs,
+) -> tuple[Torch, pd.Series]:
     """Train a PyTorch surrogate model.
 
     Examples:
@@ -143,54 +253,38 @@ def train_torch_surrogate(
     ...     surrogate_stoichiometries={
     ...         "v1": {"x1": -1, "x2": 1, "ATP": -1},
     ...     },
-    ...)
+    ...)surrogate_stoichiometries
 
     Args:
         features: DataFrame containing the input features for training.
         targets: DataFrame containing the target values for training.
         epochs: Number of training epochs.
-        surrogate_args: …
-        surrogate_stoichiometries: …
+        surrogate_args: Names of inputs arguments for the surrogate model.
+        surrogate_outputs: Names of output arguments from the surrogate.
+        surrogate_stoichiometries: Mapping of variables to their stoichiometries
         batch_size: Size of mini-batches for training (None for full-batch).
         approximator: Predefined neural network model (None to use default MLP features-50-50-output).
         optimimzer_cls: Optimizer class to use for training (default: Adam).
         device: Device to run the training on (default: DefaultDevice).
+        loss_fn: Custom loss function or instance of torch loss object
 
     Returns:
         tuple[TorchSurrogate, pd.Series]: Trained surrogate model and loss history.
 
     """
-    if approximator is None:
-        approximator = MLP(
-            n_inputs=len(features.columns),
-            neurons_per_layer=[50, 50, len(targets.columns)],
-        ).to(device)
-
-    optimizer = optimimzer_cls(approximator.parameters())
-    if batch_size is None:
-        losses = _train_full(
-            aprox=approximator,
-            features=features,
-            targets=targets,
-            epochs=epochs,
-            optimizer=optimizer,
-            device=device,
-        )
-    else:
-        losses = _train_batched(
-            aprox=approximator,
-            features=features,
-            targets=targets,
-            epochs=epochs,
-            optimizer=optimizer,
-            device=device,
-            batch_size=batch_size,
-        )
-    surrogate = TorchSurrogate(
-        model=approximator,
-        args=surrogate_args if surrogate_args is not None else [],
-        stoichiometries=surrogate_stoichiometries
-        if surrogate_stoichiometries is not None
-        else {},
+    trainer = TorchTrainer(
+        features=features,
+        targets=targets,
+        approximator=approximator,
+        optimimzer_cls=optimimzer_cls,
+        device=device,
+        loss_fn=loss_fn,
+    ).train(
+        epochs=epochs,
+        batch_size=batch_size,
     )
-    return surrogate, losses
+    return trainer.get_surrogate(
+        surrogate_args=surrogate_args,
+        surrogate_outputs=surrogate_outputs,
+        surrogate_stoichiometries=surrogate_stoichiometries,
+    ), trainer.get_loss()
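
TorchTrainer makes incremental training explicit, and train_torch is now a thin wrapper around it. A sketch assuming tiny synthetic DataFrames (column names are illustrative, not from this diff):

    import pandas as pd
    import torch
    from mxlpy.surrogates import TorchTrainer, train_torch

    features = pd.DataFrame({"x1": [0.0, 0.5, 1.0]})
    targets = pd.DataFrame({"v1": [0.0, 0.25, 1.0]})

    trainer = TorchTrainer(features=features, targets=targets)  # default MLP, Adam, mean-abs loss
    trainer.train(epochs=100)               # full-batch pass
    trainer.train(epochs=50, batch_size=2)  # continue with mini-batches; loss index carries on
    surrogate = trainer.get_surrogate(surrogate_args=["x1"], surrogate_outputs=["v1"])
    loss_history = trainer.get_loss()       # one pd.Series spanning both train() calls

    # The same flow in one call, with the default loss swapped for mean squared error:
    surrogate2, losses2 = train_torch(
        features, targets, epochs=100, loss_fn=torch.nn.MSELoss()
    )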
mxlpy/symbolic/strikepy.py
CHANGED