mxlpy 0.15.0__py3-none-any.whl → 0.17.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mxlpy/__init__.py +4 -1
- mxlpy/fns.py +513 -21
- mxlpy/integrators/int_assimulo.py +2 -1
- mxlpy/mc.py +84 -70
- mxlpy/mca.py +97 -98
- mxlpy/meta/codegen_latex.py +279 -14
- mxlpy/meta/source_tools.py +122 -4
- mxlpy/model.py +50 -24
- mxlpy/npe/__init__.py +38 -0
- mxlpy/npe/_torch.py +436 -0
- mxlpy/report.py +33 -6
- mxlpy/sbml/_import.py +5 -2
- mxlpy/scan.py +40 -38
- mxlpy/surrogates/__init__.py +7 -6
- mxlpy/surrogates/_poly.py +12 -9
- mxlpy/surrogates/_torch.py +137 -43
- mxlpy/symbolic/strikepy.py +1 -3
- mxlpy/types.py +18 -5
- {mxlpy-0.15.0.dist-info → mxlpy-0.17.0.dist-info}/METADATA +5 -4
- {mxlpy-0.15.0.dist-info → mxlpy-0.17.0.dist-info}/RECORD +22 -21
- mxlpy/npe.py +0 -277
- {mxlpy-0.15.0.dist-info → mxlpy-0.17.0.dist-info}/WHEEL +0 -0
- {mxlpy-0.15.0.dist-info → mxlpy-0.17.0.dist-info}/licenses/LICENSE +0 -0
mxlpy/types.py
CHANGED
@@ -22,6 +22,7 @@ from dataclasses import dataclass, field
|
|
22
22
|
import pandas as pd
|
23
23
|
|
24
24
|
__all__ = [
|
25
|
+
"AbstractEstimator",
|
25
26
|
"AbstractSurrogate",
|
26
27
|
"Array",
|
27
28
|
"ArrayLike",
|
@@ -354,7 +355,7 @@ class McSteadyStates:
|
|
354
355
|
variables: pd.DataFrame
|
355
356
|
fluxes: pd.DataFrame
|
356
357
|
parameters: pd.DataFrame
|
357
|
-
|
358
|
+
mc_to_scan: pd.DataFrame
|
358
359
|
|
359
360
|
def __iter__(self) -> Iterator[pd.DataFrame]:
|
360
361
|
"""Iterate over the concentration and flux steady states."""
|
@@ -444,7 +445,8 @@ class AbstractSurrogate:
|
|
444
445
|
|
445
446
|
"""
|
446
447
|
|
447
|
-
args: list[str]
|
448
|
+
args: list[str]
|
449
|
+
outputs: list[str]
|
448
450
|
stoichiometries: dict[str, dict[str, float]] = field(default_factory=dict)
|
449
451
|
|
450
452
|
@abstractmethod
|
@@ -455,7 +457,7 @@ class AbstractSurrogate:
|
|
455
457
|
"""Predict outputs based on input data."""
|
456
458
|
return dict(
|
457
459
|
zip(
|
458
|
-
self.
|
460
|
+
self.outputs,
|
459
461
|
self.predict_raw(y),
|
460
462
|
strict=True,
|
461
463
|
)
|
@@ -475,7 +477,7 @@ class AbstractSurrogate:
|
|
475
477
|
args: pd.DataFrame,
|
476
478
|
) -> None:
|
477
479
|
"""Predict outputs based on input data."""
|
478
|
-
args[
|
480
|
+
args[self.outputs] = pd.DataFrame(
|
479
481
|
[self.predict(y) for y in args.loc[:, self.args].to_numpy()],
|
480
482
|
index=args.index,
|
481
483
|
dtype=float,
|
@@ -491,4 +493,15 @@ class MockSurrogate(AbstractSurrogate):
|
|
491
493
|
y: np.ndarray,
|
492
494
|
) -> dict[str, float]:
|
493
495
|
"""Predict outputs based on input data."""
|
494
|
-
return dict(zip(self.
|
496
|
+
return dict(zip(self.outputs, y, strict=True))
|
497
|
+
|
498
|
+
|
499
|
+
@dataclass(kw_only=True)
|
500
|
+
class AbstractEstimator:
|
501
|
+
"""Abstract class for parameter estimation using neural networks."""
|
502
|
+
|
503
|
+
parameter_names: list[str]
|
504
|
+
|
505
|
+
@abstractmethod
|
506
|
+
def predict(self, features: pd.Series | pd.DataFrame) -> pd.DataFrame:
|
507
|
+
"""Predict the target values for the given features."""
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: mxlpy
|
3
|
-
Version: 0.15.0
|
3
|
+
Version: 0.17.0
|
4
4
|
Summary: A package to build metabolic models
|
5
5
|
Author-email: Marvin van Aalst <marvin.vanaalst@gmail.com>
|
6
6
|
Maintainer-email: Marvin van Aalst <marvin.vanaalst@gmail.com>
|
@@ -47,7 +47,6 @@ Requires-Dist: jupyter>=1.1.1; extra == 'dev'
|
|
47
47
|
Requires-Dist: mkdocs-jupyter>=0.25.1; extra == 'dev'
|
48
48
|
Requires-Dist: mkdocs-material>=9.5.42; extra == 'dev'
|
49
49
|
Requires-Dist: mkdocs>=1.6.1; extra == 'dev'
|
50
|
-
Requires-Dist: mypy>=1.13.0; extra == 'dev'
|
51
50
|
Requires-Dist: pre-commit>=4.0.1; extra == 'dev'
|
52
51
|
Requires-Dist: pyright>=1.1.387; extra == 'dev'
|
53
52
|
Requires-Dist: pytest-cov>=5.0.0; extra == 'dev'
|
@@ -60,7 +59,9 @@ Provides-Extra: torch
|
|
60
59
|
Requires-Dist: torch>=2.5.1; extra == 'torch'
|
61
60
|
Description-Content-Type: text/markdown
|
62
61
|
|
63
|
-
<
|
62
|
+
<p align="center">
|
63
|
+
<img src="docs/assets/logo-diagram.png" width="600px" alt='mxlpy-logo'>
|
64
|
+
</p>
|
64
65
|
|
65
66
|
# mxlpy
|
66
67
|
|
@@ -70,7 +71,7 @@ Description-Content-Type: text/markdown
|
|
70
71
|

|
71
72
|
[](https://github.com/astral-sh/ruff)
|
72
73
|
[](https://github.com/PyCQA/bandit)
|
73
|
-
[](https://pepy.tech/
|
74
|
+
[](https://pepy.tech/projects/mxlpy)
|
74
75
|
|
75
76
|
[docs-badge]: https://img.shields.io/badge/docs-main-green.svg?style=flat-square
|
76
77
|
[docs]: https://computational-biology-aachen.github.io/mxlpy/
|
@@ -1,50 +1,51 @@
|
|
1
|
-
mxlpy/__init__.py,sha256=
|
1
|
+
mxlpy/__init__.py,sha256=lGo7XQTVuR1p8rW1J6gZsgdQWRqfYa9AWbvZQwT8oLQ,4236
|
2
2
|
mxlpy/distributions.py,sha256=ce6RTqn19YzMMec-u09fSIUA8A92M6rehCuHuXWcX7A,8734
|
3
3
|
mxlpy/fit.py,sha256=LwSoLfNVrqSlTtuUApwH36LjzGU0HLs4C_2qqTTjXFE,7859
|
4
|
-
mxlpy/fns.py,sha256=
|
4
|
+
mxlpy/fns.py,sha256=VxDDyEdtGD7fEoT5LiiEaRqFk-0fIunRXHr1dCMpCdE,14002
|
5
5
|
mxlpy/identify.py,sha256=af52SCG4nlY9sSw22goaIheuvXR09QYK4ksCT24QHWI,1946
|
6
6
|
mxlpy/label_map.py,sha256=urv-QTb0MUEKjwWvKtJSB8H2kvhLn1EKfRIH7awQQ8Y,17769
|
7
7
|
mxlpy/linear_label_map.py,sha256=DqzN_akacPccZwzYAR3ANIdzAU_GU6Xe6gWV9DHAAWU,10282
|
8
|
-
mxlpy/mc.py,sha256=
|
9
|
-
mxlpy/mca.py,sha256=
|
10
|
-
mxlpy/model.py,sha256=
|
11
|
-
mxlpy/npe.py,sha256=oiRLA43-qf-AcS2KpQfJIOt7-Ev9Aj5sF6TMq9bJn84,8747
|
8
|
+
mxlpy/mc.py,sha256=oYd8a3ycyZLyh-ZxTYUjDRNfsCcwSQaLWssxv0yC5Cc,17399
|
9
|
+
mxlpy/mca.py,sha256=1_qBX9lHI6svXSebtwvMldAMwPlLqMylAPmxMbMQdWw,9359
|
10
|
+
mxlpy/model.py,sha256=H1rAKaB5pAQcMuBh5GnXuBReADTx5IDa1x0CdUZ6VlI,58411
|
12
11
|
mxlpy/parallel.py,sha256=kX4Td5YoovDwZp6kX_3cfO6QtHSS9ieJ0bMZiKs3Xv8,5002
|
13
12
|
mxlpy/parameterise.py,sha256=2jMhhO-bHTFP_0kXercJekeATAZYBg5FrK1MQ_mWGpk,654
|
14
13
|
mxlpy/paths.py,sha256=TK2wO4N9lG-UV1JGfeB64q48JVDbwqIUj63rl55MKuQ,1022
|
15
14
|
mxlpy/plot.py,sha256=4uu-6d8LH-GWX-sG_TlSpkSsnikv1DLTtnjJzA7nuRA,24670
|
16
15
|
mxlpy/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
17
|
-
mxlpy/report.py,sha256=
|
18
|
-
mxlpy/scan.py,sha256
|
16
|
+
mxlpy/report.py,sha256=ZwnjquPAvo4A8UqK-BT19SZFSEUOy1FALqoh7uTmbAI,7793
|
17
|
+
mxlpy/scan.py,sha256=FBPpjv66v4IWZ5OwG_EWUdrucLWR9gq_XEsLFC-otaw,18969
|
19
18
|
mxlpy/simulator.py,sha256=9Ne4P5Jrwgx4oAlljPvCqSCCy98_5Lv1B87y1AkbI4c,21041
|
20
|
-
mxlpy/types.py,sha256=
|
19
|
+
mxlpy/types.py,sha256=fB8-oTJkIpkGP0affoVx1ak2zOuTpT6xH-w62oSJxiU,14814
|
21
20
|
mxlpy/experimental/__init__.py,sha256=kZTE-92OErpHzNRqmgSQYH4CGXrogGJ5EL35XGZQ81M,206
|
22
21
|
mxlpy/experimental/diff.py,sha256=4bztagJzFMsQJM7dlun_kv-WrWssM8CIw7gcL63hFf8,8952
|
23
22
|
mxlpy/integrators/__init__.py,sha256=kqmV6a0TRyLGR_XqbyAI652AfptYnXAUpqbSFg0CpP8,450
|
24
|
-
mxlpy/integrators/int_assimulo.py,sha256=
|
23
|
+
mxlpy/integrators/int_assimulo.py,sha256=TCBWQd558ZeRdBba1jCNsFyLBOssKvm8dXK36Aqg4_k,4817
|
25
24
|
mxlpy/integrators/int_scipy.py,sha256=dFHlYTeb2zX97f3VuNdMJdI7WEYshF4JAIgprKKk2z4,4581
|
26
25
|
mxlpy/meta/__init__.py,sha256=Jyy4063fZy6iT4LSwjPyEAVr4N_3xxcLc8wDBoDPyKc,278
|
27
|
-
mxlpy/meta/codegen_latex.py,sha256=
|
26
|
+
mxlpy/meta/codegen_latex.py,sha256=vONj--_wmFM_FJpe15aAYyT06-kolqQwSe2NEbKrQWo,19934
|
28
27
|
mxlpy/meta/codegen_modebase.py,sha256=_ZAW4NvXhKwJQLGz5hkwwZpL2JMAJlfG-GUWkYIiNvw,3124
|
29
28
|
mxlpy/meta/codegen_py.py,sha256=xSdeuEGPGc-QKRMgJO4VSPGMlxCPEV5prkKjNQ2D2hg,3483
|
30
|
-
mxlpy/meta/source_tools.py,sha256=
|
29
|
+
mxlpy/meta/source_tools.py,sha256=IyiCLZ1KScSqADC9p_QSRgedoHGibs9U1RGJuXm827U,13464
|
31
30
|
mxlpy/nn/__init__.py,sha256=yUc4o-iqfVVzkq9tZCstWwizrCqNlMft0YUwWGFFO-E,219
|
32
31
|
mxlpy/nn/_tensorflow.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
33
32
|
mxlpy/nn/_torch.py,sha256=_4Rw87zpIlCnrOsXC7iFp1c64_FcpfVmgBXBU0p8mlg,4063
|
33
|
+
mxlpy/npe/__init__.py,sha256=IQmqUPJc5A8QXJLzp6Dq6Sjm8Hh2KAYZLrMxXQVeQP8,1181
|
34
|
+
mxlpy/npe/_torch.py,sha256=pMU4PL3eO9Aqdn9waDUpvvDRdmUlmaOFtREwSkZbvNs,13874
|
34
35
|
mxlpy/sbml/__init__.py,sha256=AS7IwrBzBgN8coUZkyBEtiYa9ICWyY1wzp1ujVm5ItA,226
|
35
36
|
mxlpy/sbml/_data.py,sha256=XwT1sSxn6KLTXYMbk4ORbEAEgZhQDBfoyrjMBDAoY_s,1135
|
36
37
|
mxlpy/sbml/_export.py,sha256=Q6B9rxy-yt73DORzAYu4BpfkZXxCS3MDSDUXwpoXV6Q,19970
|
37
|
-
mxlpy/sbml/_import.py,sha256=
|
38
|
+
mxlpy/sbml/_import.py,sha256=5odQBdpD93mQJp2bVIabmPo6NK60nxqrdSVB8fEsF_A,22099
|
38
39
|
mxlpy/sbml/_mathml.py,sha256=bNk9RQ_NQFDhY1R354p-gwqqHaIiyAwZ1xLPHHhiguQ,24436
|
39
40
|
mxlpy/sbml/_name_conversion.py,sha256=XK9DEyzhrD0GBBwwjK9RA0yORrDX5c-Uvx0VtKMR5rA,1325
|
40
41
|
mxlpy/sbml/_unit_conversion.py,sha256=dW_I6_Ou09ccwnp6LIdrPriIQnQUK5lJcjzM2Fawm6U,1927
|
41
|
-
mxlpy/surrogates/__init__.py,sha256=
|
42
|
-
mxlpy/surrogates/_poly.py,sha256=
|
43
|
-
mxlpy/surrogates/_torch.py,sha256=
|
42
|
+
mxlpy/surrogates/__init__.py,sha256=ofHPNwe0LAILP2ZUWivAQpOv9LyHHzLZc6iu1cV2LeQ,894
|
43
|
+
mxlpy/surrogates/_poly.py,sha256=n1pe4xuD2A4BK8jJagzZ-17WW3kqvFBO-ZYuznmfosw,3303
|
44
|
+
mxlpy/surrogates/_torch.py,sha256=Ep5e5oDyUsUdUpEqCY7WKLKKuwbPu0gcmVTiRabNzQ4,8593
|
44
45
|
mxlpy/symbolic/__init__.py,sha256=3hQjCMw8-6iOxeUdfnCg8449fF_BRF2u6lCM1GPpkRY,222
|
45
|
-
mxlpy/symbolic/strikepy.py,sha256=
|
46
|
+
mxlpy/symbolic/strikepy.py,sha256=UMx2LMRwCkASKjdCYEvh9tKlW9dk3nDoWM9NNJXWL_8,19960
|
46
47
|
mxlpy/symbolic/symbolic_model.py,sha256=YL9noEeP3_0DoKXwMPELtfmPuP6mgNcLIJgDRCkyB7A,2434
|
47
|
-
mxlpy-0.15.0.dist-info/METADATA,sha256=
|
48
|
-
mxlpy-0.15.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
49
|
-
mxlpy-0.15.0.dist-info/licenses/LICENSE,sha256=bEzjyjy1stQhfRDVaVHa3xV1x-V8emwdlbMvYO8Zo84,35073
|
50
|
-
mxlpy-0.15.0.dist-info/RECORD,,
|
48
|
+
mxlpy-0.17.0.dist-info/METADATA,sha256=8fzqS2MFlBN-JkidtjpM3i5DyfooGggrRv3AylMRIVQ,4507
|
49
|
+
mxlpy-0.17.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
50
|
+
mxlpy-0.17.0.dist-info/licenses/LICENSE,sha256=bEzjyjy1stQhfRDVaVHa3xV1x-V8emwdlbMvYO8Zo84,35073
|
51
|
+
mxlpy-0.17.0.dist-info/RECORD,,
|
mxlpy/npe.py
DELETED
@@ -1,277 +0,0 @@
|
|
1
|
-
"""Neural Network Parameter Estimation (NPE) Module.
|
2
|
-
|
3
|
-
This module provides classes and functions for training neural network models to estimate
|
4
|
-
parameters in metabolic models. It includes functionality for both steady-state and
|
5
|
-
time-series data.
|
6
|
-
|
7
|
-
Functions:
|
8
|
-
train_torch_surrogate: Train a PyTorch surrogate model
|
9
|
-
train_torch_time_course_estimator: Train a PyTorch time course estimator
|
10
|
-
"""
|
11
|
-
|
12
|
-
from __future__ import annotations
|
13
|
-
|
14
|
-
__all__ = [
|
15
|
-
"AbstractEstimator",
|
16
|
-
"DefaultCache",
|
17
|
-
"TorchSSEstimator",
|
18
|
-
"TorchTimeCourseEstimator",
|
19
|
-
"train_torch_ss_estimator",
|
20
|
-
"train_torch_time_course_estimator",
|
21
|
-
]
|
22
|
-
|
23
|
-
from abc import abstractmethod
|
24
|
-
from dataclasses import dataclass
|
25
|
-
from pathlib import Path
|
26
|
-
from typing import TYPE_CHECKING, cast
|
27
|
-
|
28
|
-
import numpy as np
|
29
|
-
import pandas as pd
|
30
|
-
import torch
|
31
|
-
import tqdm
|
32
|
-
from torch import nn
|
33
|
-
from torch.optim.adam import Adam
|
34
|
-
|
35
|
-
from mxlpy.nn._torch import LSTM, MLP, DefaultDevice
|
36
|
-
from mxlpy.parallel import Cache
|
37
|
-
|
38
|
-
if TYPE_CHECKING:
|
39
|
-
from collections.abc import Callable
|
40
|
-
|
41
|
-
from torch.optim.optimizer import ParamsT
|
42
|
-
|
43
|
-
DefaultCache = Cache(Path(".cache"))
|
44
|
-
|
45
|
-
|
46
|
-
@dataclass(kw_only=True)
|
47
|
-
class AbstractEstimator:
|
48
|
-
"""Abstract class for parameter estimation using neural networks."""
|
49
|
-
|
50
|
-
parameter_names: list[str]
|
51
|
-
|
52
|
-
@abstractmethod
|
53
|
-
def predict(self, features: pd.Series | pd.DataFrame) -> pd.DataFrame:
|
54
|
-
"""Predict the target values for the given features."""
|
55
|
-
|
56
|
-
|
57
|
-
@dataclass(kw_only=True)
|
58
|
-
class TorchSSEstimator(AbstractEstimator):
|
59
|
-
"""Estimator for steady state data using PyTorch models."""
|
60
|
-
|
61
|
-
model: torch.nn.Module
|
62
|
-
|
63
|
-
def predict(self, features: pd.Series | pd.DataFrame) -> pd.DataFrame:
|
64
|
-
"""Predict the target values for the given features."""
|
65
|
-
with torch.no_grad():
|
66
|
-
pred = self.model(torch.tensor(features.to_numpy(), dtype=torch.float32))
|
67
|
-
return pd.DataFrame(pred, columns=self.parameter_names)
|
68
|
-
|
69
|
-
|
70
|
-
@dataclass(kw_only=True)
|
71
|
-
class TorchTimeCourseEstimator(AbstractEstimator):
|
72
|
-
"""Estimator for time course data using PyTorch models."""
|
73
|
-
|
74
|
-
model: torch.nn.Module
|
75
|
-
|
76
|
-
def predict(self, features: pd.Series | pd.DataFrame) -> pd.DataFrame:
|
77
|
-
"""Predict the target values for the given features."""
|
78
|
-
idx = cast(pd.MultiIndex, features.index)
|
79
|
-
features_ = torch.Tensor(
|
80
|
-
np.swapaxes(
|
81
|
-
features.to_numpy().reshape(
|
82
|
-
(
|
83
|
-
len(idx.levels[0]),
|
84
|
-
len(idx.levels[1]),
|
85
|
-
len(features.columns),
|
86
|
-
)
|
87
|
-
),
|
88
|
-
axis1=0,
|
89
|
-
axis2=1,
|
90
|
-
),
|
91
|
-
)
|
92
|
-
with torch.no_grad():
|
93
|
-
pred = self.model(features_)
|
94
|
-
return pd.DataFrame(pred, columns=self.parameter_names)
|
95
|
-
|
96
|
-
|
97
|
-
def _train_batched(
|
98
|
-
approximator: nn.Module,
|
99
|
-
features: torch.Tensor,
|
100
|
-
targets: torch.Tensor,
|
101
|
-
epochs: int,
|
102
|
-
optimizer: Adam,
|
103
|
-
batch_size: int,
|
104
|
-
) -> pd.Series:
|
105
|
-
losses = {}
|
106
|
-
|
107
|
-
for epoch in tqdm.trange(epochs):
|
108
|
-
permutation = torch.randperm(features.size()[0])
|
109
|
-
epoch_loss = 0
|
110
|
-
for i in range(0, features.size()[0], batch_size):
|
111
|
-
optimizer.zero_grad()
|
112
|
-
indices = permutation[i : i + batch_size]
|
113
|
-
|
114
|
-
loss = torch.mean(
|
115
|
-
torch.abs(approximator(features[indices]) - targets[indices])
|
116
|
-
)
|
117
|
-
loss.backward()
|
118
|
-
optimizer.step()
|
119
|
-
epoch_loss += loss.detach().numpy()
|
120
|
-
|
121
|
-
losses[epoch] = epoch_loss / (features.size()[0] / batch_size)
|
122
|
-
return pd.Series(losses, dtype=float)
|
123
|
-
|
124
|
-
|
125
|
-
def _train_full(
|
126
|
-
approximator: nn.Module,
|
127
|
-
features: torch.Tensor,
|
128
|
-
targets: torch.Tensor,
|
129
|
-
epochs: int,
|
130
|
-
optimizer: Adam,
|
131
|
-
) -> pd.Series:
|
132
|
-
losses = {}
|
133
|
-
for i in tqdm.trange(epochs):
|
134
|
-
optimizer.zero_grad()
|
135
|
-
loss = torch.mean(torch.abs(approximator(features) - targets))
|
136
|
-
loss.backward()
|
137
|
-
optimizer.step()
|
138
|
-
losses[i] = loss.detach().numpy()
|
139
|
-
return pd.Series(losses, dtype=float)
|
140
|
-
|
141
|
-
|
142
|
-
def train_torch_ss_estimator(
|
143
|
-
features: pd.DataFrame,
|
144
|
-
targets: pd.DataFrame,
|
145
|
-
epochs: int,
|
146
|
-
batch_size: int | None = None,
|
147
|
-
approximator: nn.Module | None = None,
|
148
|
-
optimimzer_cls: Callable[[ParamsT], Adam] = Adam,
|
149
|
-
device: torch.device = DefaultDevice,
|
150
|
-
) -> tuple[TorchSSEstimator, pd.Series]:
|
151
|
-
"""Train a PyTorch steady state estimator.
|
152
|
-
|
153
|
-
This function trains a neural network model to estimate steady state data
|
154
|
-
using the provided features and targets. It supports both full-batch and
|
155
|
-
mini-batch training.
|
156
|
-
|
157
|
-
Examples:
|
158
|
-
>>> train_torch_ss_estimator(features, targets, epochs=100)
|
159
|
-
|
160
|
-
Args:
|
161
|
-
features: DataFrame containing the input features for training
|
162
|
-
targets: DataFrame containing the target values for training
|
163
|
-
epochs: Number of training epochs
|
164
|
-
batch_size: Size of mini-batches for training (None for full-batch)
|
165
|
-
approximator: Predefined neural network model (None to use default MLP)
|
166
|
-
optimimzer_cls: Optimizer class to use for training (default: Adam)
|
167
|
-
device: Device to run the training on (default: DefaultDevice)
|
168
|
-
|
169
|
-
Returns:
|
170
|
-
tuple[TorchTimeSeriesEstimator, pd.Series]: Trained estimator and loss history
|
171
|
-
|
172
|
-
"""
|
173
|
-
if approximator is None:
|
174
|
-
n_hidden = max(2 * len(features.columns) * len(targets.columns), 10)
|
175
|
-
n_outputs = len(targets.columns)
|
176
|
-
approximator = MLP(
|
177
|
-
n_inputs=len(features.columns),
|
178
|
-
neurons_per_layer=[n_hidden, n_hidden, n_outputs],
|
179
|
-
).to(device)
|
180
|
-
|
181
|
-
features_ = torch.Tensor(features.to_numpy(), device=device)
|
182
|
-
targets_ = torch.Tensor(targets.to_numpy(), device=device)
|
183
|
-
|
184
|
-
optimizer = optimimzer_cls(approximator.parameters())
|
185
|
-
if batch_size is None:
|
186
|
-
losses = _train_full(
|
187
|
-
approximator=approximator,
|
188
|
-
features=features_,
|
189
|
-
targets=targets_,
|
190
|
-
epochs=epochs,
|
191
|
-
optimizer=optimizer,
|
192
|
-
)
|
193
|
-
else:
|
194
|
-
losses = _train_batched(
|
195
|
-
approximator=approximator,
|
196
|
-
features=features_,
|
197
|
-
targets=targets_,
|
198
|
-
epochs=epochs,
|
199
|
-
optimizer=optimizer,
|
200
|
-
batch_size=batch_size,
|
201
|
-
)
|
202
|
-
|
203
|
-
return TorchSSEstimator(
|
204
|
-
model=approximator,
|
205
|
-
parameter_names=list(targets.columns),
|
206
|
-
), losses
|
207
|
-
|
208
|
-
|
209
|
-
def train_torch_time_course_estimator(
|
210
|
-
features: pd.DataFrame,
|
211
|
-
targets: pd.DataFrame,
|
212
|
-
epochs: int,
|
213
|
-
batch_size: int | None = None,
|
214
|
-
approximator: nn.Module | None = None,
|
215
|
-
optimimzer_cls: Callable[[ParamsT], Adam] = Adam,
|
216
|
-
device: torch.device = DefaultDevice,
|
217
|
-
) -> tuple[TorchTimeCourseEstimator, pd.Series]:
|
218
|
-
"""Train a PyTorch time course estimator.
|
219
|
-
|
220
|
-
This function trains a neural network model to estimate time course data
|
221
|
-
using the provided features and targets. It supports both full-batch and
|
222
|
-
mini-batch training.
|
223
|
-
|
224
|
-
Examples:
|
225
|
-
>>> train_torch_time_course_estimator(features, targets, epochs=100)
|
226
|
-
|
227
|
-
Args:
|
228
|
-
features: DataFrame containing the input features for training
|
229
|
-
targets: DataFrame containing the target values for training
|
230
|
-
epochs: Number of training epochs
|
231
|
-
batch_size: Size of mini-batches for training (None for full-batch)
|
232
|
-
approximator: Predefined neural network model (None to use default LSTM)
|
233
|
-
optimimzer_cls: Optimizer class to use for training (default: Adam)
|
234
|
-
device: Device to run the training on (default: DefaultDevice)
|
235
|
-
|
236
|
-
Returns:
|
237
|
-
tuple[TorchTimeSeriesEstimator, pd.Series]: Trained estimator and loss history
|
238
|
-
|
239
|
-
"""
|
240
|
-
if approximator is None:
|
241
|
-
approximator = LSTM(
|
242
|
-
n_inputs=len(features.columns),
|
243
|
-
n_outputs=len(targets.columns),
|
244
|
-
n_hidden=1,
|
245
|
-
).to(device)
|
246
|
-
|
247
|
-
optimizer = optimimzer_cls(approximator.parameters())
|
248
|
-
features_ = torch.Tensor(
|
249
|
-
np.swapaxes(
|
250
|
-
features.to_numpy().reshape((len(targets), -1, len(features.columns))),
|
251
|
-
axis1=0,
|
252
|
-
axis2=1,
|
253
|
-
),
|
254
|
-
device=device,
|
255
|
-
)
|
256
|
-
targets_ = torch.Tensor(targets.to_numpy(), device=device)
|
257
|
-
if batch_size is None:
|
258
|
-
losses = _train_full(
|
259
|
-
approximator=approximator,
|
260
|
-
features=features_,
|
261
|
-
targets=targets_,
|
262
|
-
epochs=epochs,
|
263
|
-
optimizer=optimizer,
|
264
|
-
)
|
265
|
-
else:
|
266
|
-
losses = _train_batched(
|
267
|
-
approximator=approximator,
|
268
|
-
features=features_,
|
269
|
-
targets=targets_,
|
270
|
-
epochs=epochs,
|
271
|
-
optimizer=optimizer,
|
272
|
-
batch_size=batch_size,
|
273
|
-
)
|
274
|
-
return TorchTimeCourseEstimator(
|
275
|
-
model=approximator,
|
276
|
-
parameter_names=list(targets.columns),
|
277
|
-
), losses
|
File without changes
|
File without changes
|