mxlpy 0.15.0__py3-none-any.whl → 0.17.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
mxlpy/types.py CHANGED
@@ -22,6 +22,7 @@ from dataclasses import dataclass, field
22
22
  import pandas as pd
23
23
 
24
24
  __all__ = [
25
+ "AbstractEstimator",
25
26
  "AbstractSurrogate",
26
27
  "Array",
27
28
  "ArrayLike",
@@ -354,7 +355,7 @@ class McSteadyStates:
354
355
  variables: pd.DataFrame
355
356
  fluxes: pd.DataFrame
356
357
  parameters: pd.DataFrame
357
- mc_parameters: pd.DataFrame
358
+ mc_to_scan: pd.DataFrame
358
359
 
359
360
  def __iter__(self) -> Iterator[pd.DataFrame]:
360
361
  """Iterate over the concentration and flux steady states."""
@@ -444,7 +445,8 @@ class AbstractSurrogate:
444
445
 
445
446
  """
446
447
 
447
- args: list[str] = field(default_factory=list)
448
+ args: list[str]
449
+ outputs: list[str]
448
450
  stoichiometries: dict[str, dict[str, float]] = field(default_factory=dict)
449
451
 
450
452
  @abstractmethod
@@ -455,7 +457,7 @@ class AbstractSurrogate:
455
457
  """Predict outputs based on input data."""
456
458
  return dict(
457
459
  zip(
458
- self.stoichiometries,
460
+ self.outputs,
459
461
  self.predict_raw(y),
460
462
  strict=True,
461
463
  )
@@ -475,7 +477,7 @@ class AbstractSurrogate:
475
477
  args: pd.DataFrame,
476
478
  ) -> None:
477
479
  """Predict outputs based on input data."""
478
- args[list(self.stoichiometries)] = pd.DataFrame(
480
+ args[self.outputs] = pd.DataFrame(
479
481
  [self.predict(y) for y in args.loc[:, self.args].to_numpy()],
480
482
  index=args.index,
481
483
  dtype=float,
@@ -491,4 +493,15 @@ class MockSurrogate(AbstractSurrogate):
491
493
  y: np.ndarray,
492
494
  ) -> dict[str, float]:
493
495
  """Predict outputs based on input data."""
494
- return dict(zip(self.stoichiometries, y, strict=True))
496
+ return dict(zip(self.outputs, y, strict=True))
497
+
498
+
499
+ @dataclass(kw_only=True)
500
+ class AbstractEstimator:
501
+ """Abstract class for parameter estimation using neural networks."""
502
+
503
+ parameter_names: list[str]
504
+
505
+ @abstractmethod
506
+ def predict(self, features: pd.Series | pd.DataFrame) -> pd.DataFrame:
507
+ """Predict the target values for the given features."""
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: mxlpy
3
- Version: 0.15.0
3
+ Version: 0.17.0
4
4
  Summary: A package to build metabolic models
5
5
  Author-email: Marvin van Aalst <marvin.vanaalst@gmail.com>
6
6
  Maintainer-email: Marvin van Aalst <marvin.vanaalst@gmail.com>
@@ -47,7 +47,6 @@ Requires-Dist: jupyter>=1.1.1; extra == 'dev'
47
47
  Requires-Dist: mkdocs-jupyter>=0.25.1; extra == 'dev'
48
48
  Requires-Dist: mkdocs-material>=9.5.42; extra == 'dev'
49
49
  Requires-Dist: mkdocs>=1.6.1; extra == 'dev'
50
- Requires-Dist: mypy>=1.13.0; extra == 'dev'
51
50
  Requires-Dist: pre-commit>=4.0.1; extra == 'dev'
52
51
  Requires-Dist: pyright>=1.1.387; extra == 'dev'
53
52
  Requires-Dist: pytest-cov>=5.0.0; extra == 'dev'
@@ -60,7 +59,9 @@ Provides-Extra: torch
60
59
  Requires-Dist: torch>=2.5.1; extra == 'torch'
61
60
  Description-Content-Type: text/markdown
62
61
 
63
- <img src="docs/assets/logo-diagram.png" style="display: block; max-height: 30rem; margin: auto; padding: 0" alt='mxlpy-logo'>
62
+ <p align="center">
63
+ <img src="docs/assets/logo-diagram.png" width="600px" alt='mxlpy-logo'>
64
+ </p>
64
65
 
65
66
  # mxlpy
66
67
 
@@ -70,7 +71,7 @@ Description-Content-Type: text/markdown
70
71
  ![Coverage](https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fgist.github.com%2Fmarvinvanaalst%2F98ab3ce1db511de42f9871e91d85e4cd%2Fraw%2Fcoverage.json&query=%24.message&label=Coverage&color=%24.color&suffix=%20%25)
71
72
  [![Ruff](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/astral-sh/ruff/main/assets/badge/v2.json)](https://github.com/astral-sh/ruff)
72
73
  [![security: bandit](https://img.shields.io/badge/security-bandit-yellow.svg)](https://github.com/PyCQA/bandit)
73
- [![Downloads](https://pepy.tech/badge/mxlpy)](https://pepy.tech/project/mxlpy)
74
+ [![PyPI Downloads](https://static.pepy.tech/badge/mxlpy)](https://pepy.tech/projects/mxlpy)
74
75
 
75
76
  [docs-badge]: https://img.shields.io/badge/docs-main-green.svg?style=flat-square
76
77
  [docs]: https://computational-biology-aachen.github.io/mxlpy/
@@ -1,50 +1,51 @@
1
- mxlpy/__init__.py,sha256=XZYNFyDC5rWcKi6139mq04cROI7LwJvxB2_3ApKwcvY,4194
1
+ mxlpy/__init__.py,sha256=lGo7XQTVuR1p8rW1J6gZsgdQWRqfYa9AWbvZQwT8oLQ,4236
2
2
  mxlpy/distributions.py,sha256=ce6RTqn19YzMMec-u09fSIUA8A92M6rehCuHuXWcX7A,8734
3
3
  mxlpy/fit.py,sha256=LwSoLfNVrqSlTtuUApwH36LjzGU0HLs4C_2qqTTjXFE,7859
4
- mxlpy/fns.py,sha256=ct_RFj9koW8vXHyr27GnbZUHUS_zfs4rDysybuFiOaU,4599
4
+ mxlpy/fns.py,sha256=VxDDyEdtGD7fEoT5LiiEaRqFk-0fIunRXHr1dCMpCdE,14002
5
5
  mxlpy/identify.py,sha256=af52SCG4nlY9sSw22goaIheuvXR09QYK4ksCT24QHWI,1946
6
6
  mxlpy/label_map.py,sha256=urv-QTb0MUEKjwWvKtJSB8H2kvhLn1EKfRIH7awQQ8Y,17769
7
7
  mxlpy/linear_label_map.py,sha256=DqzN_akacPccZwzYAR3ANIdzAU_GU6Xe6gWV9DHAAWU,10282
8
- mxlpy/mc.py,sha256=HWuJq4fV_wfTDERbLJRSF3fjCCYMxzLdqAyO53Z_uF8,16985
9
- mxlpy/mca.py,sha256=H0dfV45Kz5nMIW8s2V61op7x6LmI21wWgRf94i6iIY4,9328
10
- mxlpy/model.py,sha256=qzol8nDSbM3HdESh50c4UFjn6Pw5JwcvhQ5AyKnbyvc,57576
11
- mxlpy/npe.py,sha256=oiRLA43-qf-AcS2KpQfJIOt7-Ev9Aj5sF6TMq9bJn84,8747
8
+ mxlpy/mc.py,sha256=oYd8a3ycyZLyh-ZxTYUjDRNfsCcwSQaLWssxv0yC5Cc,17399
9
+ mxlpy/mca.py,sha256=1_qBX9lHI6svXSebtwvMldAMwPlLqMylAPmxMbMQdWw,9359
10
+ mxlpy/model.py,sha256=H1rAKaB5pAQcMuBh5GnXuBReADTx5IDa1x0CdUZ6VlI,58411
12
11
  mxlpy/parallel.py,sha256=kX4Td5YoovDwZp6kX_3cfO6QtHSS9ieJ0bMZiKs3Xv8,5002
13
12
  mxlpy/parameterise.py,sha256=2jMhhO-bHTFP_0kXercJekeATAZYBg5FrK1MQ_mWGpk,654
14
13
  mxlpy/paths.py,sha256=TK2wO4N9lG-UV1JGfeB64q48JVDbwqIUj63rl55MKuQ,1022
15
14
  mxlpy/plot.py,sha256=4uu-6d8LH-GWX-sG_TlSpkSsnikv1DLTtnjJzA7nuRA,24670
16
15
  mxlpy/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
17
- mxlpy/report.py,sha256=h7dhcBzPFydLPxdsEXokzDf7Ce4PirXMsvLqlDZLSWM,7181
18
- mxlpy/scan.py,sha256=-1SLyXJOX3U3CxeP1dEC4ytAoBMCH0Ql89wGvsG3LbI,18858
16
+ mxlpy/report.py,sha256=ZwnjquPAvo4A8UqK-BT19SZFSEUOy1FALqoh7uTmbAI,7793
17
+ mxlpy/scan.py,sha256=FBPpjv66v4IWZ5OwG_EWUdrucLWR9gq_XEsLFC-otaw,18969
19
18
  mxlpy/simulator.py,sha256=9Ne4P5Jrwgx4oAlljPvCqSCCy98_5Lv1B87y1AkbI4c,21041
20
- mxlpy/types.py,sha256=AGWFK59MRRVBOr8I1EmFKGdRpF1DiT8C4lB95-oBk40,14512
19
+ mxlpy/types.py,sha256=fB8-oTJkIpkGP0affoVx1ak2zOuTpT6xH-w62oSJxiU,14814
21
20
  mxlpy/experimental/__init__.py,sha256=kZTE-92OErpHzNRqmgSQYH4CGXrogGJ5EL35XGZQ81M,206
22
21
  mxlpy/experimental/diff.py,sha256=4bztagJzFMsQJM7dlun_kv-WrWssM8CIw7gcL63hFf8,8952
23
22
  mxlpy/integrators/__init__.py,sha256=kqmV6a0TRyLGR_XqbyAI652AfptYnXAUpqbSFg0CpP8,450
24
- mxlpy/integrators/int_assimulo.py,sha256=d-4HHOj4vmGpg8ig2IXMO5CPiIrq89_quEKvCxIKrhw,4747
23
+ mxlpy/integrators/int_assimulo.py,sha256=TCBWQd558ZeRdBba1jCNsFyLBOssKvm8dXK36Aqg4_k,4817
25
24
  mxlpy/integrators/int_scipy.py,sha256=dFHlYTeb2zX97f3VuNdMJdI7WEYshF4JAIgprKKk2z4,4581
26
25
  mxlpy/meta/__init__.py,sha256=Jyy4063fZy6iT4LSwjPyEAVr4N_3xxcLc8wDBoDPyKc,278
27
- mxlpy/meta/codegen_latex.py,sha256=R0OJqzE7PnOCWLk52C3XWuRb-zI2eYTvV2oJZJvPsOE,13414
26
+ mxlpy/meta/codegen_latex.py,sha256=vONj--_wmFM_FJpe15aAYyT06-kolqQwSe2NEbKrQWo,19934
28
27
  mxlpy/meta/codegen_modebase.py,sha256=_ZAW4NvXhKwJQLGz5hkwwZpL2JMAJlfG-GUWkYIiNvw,3124
29
28
  mxlpy/meta/codegen_py.py,sha256=xSdeuEGPGc-QKRMgJO4VSPGMlxCPEV5prkKjNQ2D2hg,3483
30
- mxlpy/meta/source_tools.py,sha256=EN3OoGQaXeIsDTJvA7S15-xDBra3DCIyFZEJ6h0Fy0k,11125
29
+ mxlpy/meta/source_tools.py,sha256=IyiCLZ1KScSqADC9p_QSRgedoHGibs9U1RGJuXm827U,13464
31
30
  mxlpy/nn/__init__.py,sha256=yUc4o-iqfVVzkq9tZCstWwizrCqNlMft0YUwWGFFO-E,219
32
31
  mxlpy/nn/_tensorflow.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
33
32
  mxlpy/nn/_torch.py,sha256=_4Rw87zpIlCnrOsXC7iFp1c64_FcpfVmgBXBU0p8mlg,4063
33
+ mxlpy/npe/__init__.py,sha256=IQmqUPJc5A8QXJLzp6Dq6Sjm8Hh2KAYZLrMxXQVeQP8,1181
34
+ mxlpy/npe/_torch.py,sha256=pMU4PL3eO9Aqdn9waDUpvvDRdmUlmaOFtREwSkZbvNs,13874
34
35
  mxlpy/sbml/__init__.py,sha256=AS7IwrBzBgN8coUZkyBEtiYa9ICWyY1wzp1ujVm5ItA,226
35
36
  mxlpy/sbml/_data.py,sha256=XwT1sSxn6KLTXYMbk4ORbEAEgZhQDBfoyrjMBDAoY_s,1135
36
37
  mxlpy/sbml/_export.py,sha256=Q6B9rxy-yt73DORzAYu4BpfkZXxCS3MDSDUXwpoXV6Q,19970
37
- mxlpy/sbml/_import.py,sha256=5DJklsAe2sMV1CFxAbkSFRT3amPzOZmpo53y9NYv6TY,22015
38
+ mxlpy/sbml/_import.py,sha256=5odQBdpD93mQJp2bVIabmPo6NK60nxqrdSVB8fEsF_A,22099
38
39
  mxlpy/sbml/_mathml.py,sha256=bNk9RQ_NQFDhY1R354p-gwqqHaIiyAwZ1xLPHHhiguQ,24436
39
40
  mxlpy/sbml/_name_conversion.py,sha256=XK9DEyzhrD0GBBwwjK9RA0yORrDX5c-Uvx0VtKMR5rA,1325
40
41
  mxlpy/sbml/_unit_conversion.py,sha256=dW_I6_Ou09ccwnp6LIdrPriIQnQUK5lJcjzM2Fawm6U,1927
41
- mxlpy/surrogates/__init__.py,sha256=N_iXERECKvmrHiihwnyQEKOSBsmlGEuQhEotn-mWKdk,924
42
- mxlpy/surrogates/_poly.py,sha256=E54CFscQBCcYMrty1X2ynl9GlS9uoEeAUgBPnhm3iIA,3113
43
- mxlpy/surrogates/_torch.py,sha256=E_1eDUlPSVFwROkdMDCqYwwHE-61pjNMJWotnhjzge0,5891
42
+ mxlpy/surrogates/__init__.py,sha256=ofHPNwe0LAILP2ZUWivAQpOv9LyHHzLZc6iu1cV2LeQ,894
43
+ mxlpy/surrogates/_poly.py,sha256=n1pe4xuD2A4BK8jJagzZ-17WW3kqvFBO-ZYuznmfosw,3303
44
+ mxlpy/surrogates/_torch.py,sha256=Ep5e5oDyUsUdUpEqCY7WKLKKuwbPu0gcmVTiRabNzQ4,8593
44
45
  mxlpy/symbolic/__init__.py,sha256=3hQjCMw8-6iOxeUdfnCg8449fF_BRF2u6lCM1GPpkRY,222
45
- mxlpy/symbolic/strikepy.py,sha256=r6nRtckV1nxKq3i1bYYWZOkzwZ5XeKQuZM5ck44vUo0,20010
46
+ mxlpy/symbolic/strikepy.py,sha256=UMx2LMRwCkASKjdCYEvh9tKlW9dk3nDoWM9NNJXWL_8,19960
46
47
  mxlpy/symbolic/symbolic_model.py,sha256=YL9noEeP3_0DoKXwMPELtfmPuP6mgNcLIJgDRCkyB7A,2434
47
- mxlpy-0.15.0.dist-info/METADATA,sha256=EXj4l7bToJEQx3jbL26ECI_VuxtSgutyL44zdxLQ4Bg,4564
48
- mxlpy-0.15.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
49
- mxlpy-0.15.0.dist-info/licenses/LICENSE,sha256=bEzjyjy1stQhfRDVaVHa3xV1x-V8emwdlbMvYO8Zo84,35073
50
- mxlpy-0.15.0.dist-info/RECORD,,
48
+ mxlpy-0.17.0.dist-info/METADATA,sha256=8fzqS2MFlBN-JkidtjpM3i5DyfooGggrRv3AylMRIVQ,4507
49
+ mxlpy-0.17.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
50
+ mxlpy-0.17.0.dist-info/licenses/LICENSE,sha256=bEzjyjy1stQhfRDVaVHa3xV1x-V8emwdlbMvYO8Zo84,35073
51
+ mxlpy-0.17.0.dist-info/RECORD,,
mxlpy/npe.py DELETED
@@ -1,277 +0,0 @@
1
- """Neural Network Parameter Estimation (NPE) Module.
2
-
3
- This module provides classes and functions for training neural network models to estimate
4
- parameters in metabolic models. It includes functionality for both steady-state and
5
- time-series data.
6
-
7
- Functions:
8
- train_torch_surrogate: Train a PyTorch surrogate model
9
- train_torch_time_course_estimator: Train a PyTorch time course estimator
10
- """
11
-
12
- from __future__ import annotations
13
-
14
- __all__ = [
15
- "AbstractEstimator",
16
- "DefaultCache",
17
- "TorchSSEstimator",
18
- "TorchTimeCourseEstimator",
19
- "train_torch_ss_estimator",
20
- "train_torch_time_course_estimator",
21
- ]
22
-
23
- from abc import abstractmethod
24
- from dataclasses import dataclass
25
- from pathlib import Path
26
- from typing import TYPE_CHECKING, cast
27
-
28
- import numpy as np
29
- import pandas as pd
30
- import torch
31
- import tqdm
32
- from torch import nn
33
- from torch.optim.adam import Adam
34
-
35
- from mxlpy.nn._torch import LSTM, MLP, DefaultDevice
36
- from mxlpy.parallel import Cache
37
-
38
- if TYPE_CHECKING:
39
- from collections.abc import Callable
40
-
41
- from torch.optim.optimizer import ParamsT
42
-
43
- DefaultCache = Cache(Path(".cache"))
44
-
45
-
46
- @dataclass(kw_only=True)
47
- class AbstractEstimator:
48
- """Abstract class for parameter estimation using neural networks."""
49
-
50
- parameter_names: list[str]
51
-
52
- @abstractmethod
53
- def predict(self, features: pd.Series | pd.DataFrame) -> pd.DataFrame:
54
- """Predict the target values for the given features."""
55
-
56
-
57
- @dataclass(kw_only=True)
58
- class TorchSSEstimator(AbstractEstimator):
59
- """Estimator for steady state data using PyTorch models."""
60
-
61
- model: torch.nn.Module
62
-
63
- def predict(self, features: pd.Series | pd.DataFrame) -> pd.DataFrame:
64
- """Predict the target values for the given features."""
65
- with torch.no_grad():
66
- pred = self.model(torch.tensor(features.to_numpy(), dtype=torch.float32))
67
- return pd.DataFrame(pred, columns=self.parameter_names)
68
-
69
-
70
- @dataclass(kw_only=True)
71
- class TorchTimeCourseEstimator(AbstractEstimator):
72
- """Estimator for time course data using PyTorch models."""
73
-
74
- model: torch.nn.Module
75
-
76
- def predict(self, features: pd.Series | pd.DataFrame) -> pd.DataFrame:
77
- """Predict the target values for the given features."""
78
- idx = cast(pd.MultiIndex, features.index)
79
- features_ = torch.Tensor(
80
- np.swapaxes(
81
- features.to_numpy().reshape(
82
- (
83
- len(idx.levels[0]),
84
- len(idx.levels[1]),
85
- len(features.columns),
86
- )
87
- ),
88
- axis1=0,
89
- axis2=1,
90
- ),
91
- )
92
- with torch.no_grad():
93
- pred = self.model(features_)
94
- return pd.DataFrame(pred, columns=self.parameter_names)
95
-
96
-
97
- def _train_batched(
98
- approximator: nn.Module,
99
- features: torch.Tensor,
100
- targets: torch.Tensor,
101
- epochs: int,
102
- optimizer: Adam,
103
- batch_size: int,
104
- ) -> pd.Series:
105
- losses = {}
106
-
107
- for epoch in tqdm.trange(epochs):
108
- permutation = torch.randperm(features.size()[0])
109
- epoch_loss = 0
110
- for i in range(0, features.size()[0], batch_size):
111
- optimizer.zero_grad()
112
- indices = permutation[i : i + batch_size]
113
-
114
- loss = torch.mean(
115
- torch.abs(approximator(features[indices]) - targets[indices])
116
- )
117
- loss.backward()
118
- optimizer.step()
119
- epoch_loss += loss.detach().numpy()
120
-
121
- losses[epoch] = epoch_loss / (features.size()[0] / batch_size)
122
- return pd.Series(losses, dtype=float)
123
-
124
-
125
- def _train_full(
126
- approximator: nn.Module,
127
- features: torch.Tensor,
128
- targets: torch.Tensor,
129
- epochs: int,
130
- optimizer: Adam,
131
- ) -> pd.Series:
132
- losses = {}
133
- for i in tqdm.trange(epochs):
134
- optimizer.zero_grad()
135
- loss = torch.mean(torch.abs(approximator(features) - targets))
136
- loss.backward()
137
- optimizer.step()
138
- losses[i] = loss.detach().numpy()
139
- return pd.Series(losses, dtype=float)
140
-
141
-
142
- def train_torch_ss_estimator(
143
- features: pd.DataFrame,
144
- targets: pd.DataFrame,
145
- epochs: int,
146
- batch_size: int | None = None,
147
- approximator: nn.Module | None = None,
148
- optimimzer_cls: Callable[[ParamsT], Adam] = Adam,
149
- device: torch.device = DefaultDevice,
150
- ) -> tuple[TorchSSEstimator, pd.Series]:
151
- """Train a PyTorch steady state estimator.
152
-
153
- This function trains a neural network model to estimate steady state data
154
- using the provided features and targets. It supports both full-batch and
155
- mini-batch training.
156
-
157
- Examples:
158
- >>> train_torch_ss_estimator(features, targets, epochs=100)
159
-
160
- Args:
161
- features: DataFrame containing the input features for training
162
- targets: DataFrame containing the target values for training
163
- epochs: Number of training epochs
164
- batch_size: Size of mini-batches for training (None for full-batch)
165
- approximator: Predefined neural network model (None to use default MLP)
166
- optimimzer_cls: Optimizer class to use for training (default: Adam)
167
- device: Device to run the training on (default: DefaultDevice)
168
-
169
- Returns:
170
- tuple[TorchTimeSeriesEstimator, pd.Series]: Trained estimator and loss history
171
-
172
- """
173
- if approximator is None:
174
- n_hidden = max(2 * len(features.columns) * len(targets.columns), 10)
175
- n_outputs = len(targets.columns)
176
- approximator = MLP(
177
- n_inputs=len(features.columns),
178
- neurons_per_layer=[n_hidden, n_hidden, n_outputs],
179
- ).to(device)
180
-
181
- features_ = torch.Tensor(features.to_numpy(), device=device)
182
- targets_ = torch.Tensor(targets.to_numpy(), device=device)
183
-
184
- optimizer = optimimzer_cls(approximator.parameters())
185
- if batch_size is None:
186
- losses = _train_full(
187
- approximator=approximator,
188
- features=features_,
189
- targets=targets_,
190
- epochs=epochs,
191
- optimizer=optimizer,
192
- )
193
- else:
194
- losses = _train_batched(
195
- approximator=approximator,
196
- features=features_,
197
- targets=targets_,
198
- epochs=epochs,
199
- optimizer=optimizer,
200
- batch_size=batch_size,
201
- )
202
-
203
- return TorchSSEstimator(
204
- model=approximator,
205
- parameter_names=list(targets.columns),
206
- ), losses
207
-
208
-
209
- def train_torch_time_course_estimator(
210
- features: pd.DataFrame,
211
- targets: pd.DataFrame,
212
- epochs: int,
213
- batch_size: int | None = None,
214
- approximator: nn.Module | None = None,
215
- optimimzer_cls: Callable[[ParamsT], Adam] = Adam,
216
- device: torch.device = DefaultDevice,
217
- ) -> tuple[TorchTimeCourseEstimator, pd.Series]:
218
- """Train a PyTorch time course estimator.
219
-
220
- This function trains a neural network model to estimate time course data
221
- using the provided features and targets. It supports both full-batch and
222
- mini-batch training.
223
-
224
- Examples:
225
- >>> train_torch_time_course_estimator(features, targets, epochs=100)
226
-
227
- Args:
228
- features: DataFrame containing the input features for training
229
- targets: DataFrame containing the target values for training
230
- epochs: Number of training epochs
231
- batch_size: Size of mini-batches for training (None for full-batch)
232
- approximator: Predefined neural network model (None to use default LSTM)
233
- optimimzer_cls: Optimizer class to use for training (default: Adam)
234
- device: Device to run the training on (default: DefaultDevice)
235
-
236
- Returns:
237
- tuple[TorchTimeSeriesEstimator, pd.Series]: Trained estimator and loss history
238
-
239
- """
240
- if approximator is None:
241
- approximator = LSTM(
242
- n_inputs=len(features.columns),
243
- n_outputs=len(targets.columns),
244
- n_hidden=1,
245
- ).to(device)
246
-
247
- optimizer = optimimzer_cls(approximator.parameters())
248
- features_ = torch.Tensor(
249
- np.swapaxes(
250
- features.to_numpy().reshape((len(targets), -1, len(features.columns))),
251
- axis1=0,
252
- axis2=1,
253
- ),
254
- device=device,
255
- )
256
- targets_ = torch.Tensor(targets.to_numpy(), device=device)
257
- if batch_size is None:
258
- losses = _train_full(
259
- approximator=approximator,
260
- features=features_,
261
- targets=targets_,
262
- epochs=epochs,
263
- optimizer=optimizer,
264
- )
265
- else:
266
- losses = _train_batched(
267
- approximator=approximator,
268
- features=features_,
269
- targets=targets_,
270
- epochs=epochs,
271
- optimizer=optimizer,
272
- batch_size=batch_size,
273
- )
274
- return TorchTimeCourseEstimator(
275
- model=approximator,
276
- parameter_names=list(targets.columns),
277
- ), losses
File without changes