mxlpy 0.24.0__py3-none-any.whl → 0.26.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mxlpy/__init__.py +2 -0
- mxlpy/fit.py +960 -359
- mxlpy/fuzzy.py +139 -0
- mxlpy/identify.py +1 -0
- mxlpy/integrators/int_scipy.py +4 -3
- mxlpy/meta/codegen_latex.py +1 -0
- mxlpy/meta/source_tools.py +1 -1
- mxlpy/model.py +74 -33
- mxlpy/nn/__init__.py +5 -0
- mxlpy/nn/_equinox.py +293 -0
- mxlpy/nn/_torch.py +59 -2
- mxlpy/npe/__init__.py +5 -0
- mxlpy/npe/_equinox.py +344 -0
- mxlpy/npe/_torch.py +6 -22
- mxlpy/parallel.py +73 -4
- mxlpy/surrogates/__init__.py +5 -0
- mxlpy/surrogates/_equinox.py +195 -0
- mxlpy/surrogates/_torch.py +5 -20
- mxlpy/symbolic/symbolic_model.py +30 -3
- mxlpy/types.py +172 -0
- {mxlpy-0.24.0.dist-info → mxlpy-0.26.0.dist-info}/METADATA +11 -1
- {mxlpy-0.24.0.dist-info → mxlpy-0.26.0.dist-info}/RECORD +24 -20
- {mxlpy-0.24.0.dist-info → mxlpy-0.26.0.dist-info}/WHEEL +0 -0
- {mxlpy-0.24.0.dist-info → mxlpy-0.26.0.dist-info}/licenses/LICENSE +0 -0
mxlpy/nn/_torch.py
CHANGED
@@ -15,8 +15,6 @@ import tqdm
|
|
15
15
|
from torch import nn
|
16
16
|
from torch.utils.data import DataLoader, TensorDataset
|
17
17
|
|
18
|
-
type LossFn = Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
|
19
|
-
|
20
18
|
if TYPE_CHECKING:
|
21
19
|
from collections.abc import Callable
|
22
20
|
|
@@ -29,11 +27,65 @@ __all__ = [
|
|
29
27
|
"LSTM",
|
30
28
|
"LossFn",
|
31
29
|
"MLP",
|
30
|
+
"cosine_similarity",
|
31
|
+
"mean_abs_error",
|
32
|
+
"mean_absolute_percentage",
|
33
|
+
"mean_error",
|
34
|
+
"mean_squared_error",
|
35
|
+
"mean_squared_logarithmic",
|
36
|
+
"rms_error",
|
32
37
|
"train",
|
33
38
|
]
|
34
39
|
|
35
40
|
DefaultDevice = torch.device("cpu")
|
36
41
|
|
42
|
+
###############################################################################
|
43
|
+
# Loss functions
|
44
|
+
###############################################################################
|
45
|
+
|
46
|
+
|
47
|
+
type LossFn = Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
|
48
|
+
|
49
|
+
|
50
|
+
def mean_error(pred: torch.Tensor, true: torch.Tensor) -> torch.Tensor:
    """Return the mean signed residual ``pred - true``.

    Positive values indicate that predictions overshoot the targets on average.
    """
    residuals = pred - true
    return residuals.mean()
|
53
|
+
|
54
|
+
|
55
|
+
def mean_squared_error(pred: torch.Tensor, true: torch.Tensor) -> torch.Tensor:
    """Return the mean of the squared residuals ``(pred - true) ** 2``."""
    residuals = pred - true
    return residuals.square().mean()
|
58
|
+
|
59
|
+
|
60
|
+
def rms_error(pred: torch.Tensor, true: torch.Tensor) -> torch.Tensor:
    """Return the root of the mean squared residual between ``pred`` and ``true``."""
    mse = (pred - true).square().mean()
    return mse.sqrt()
|
63
|
+
|
64
|
+
|
65
|
+
def mean_abs_error(pred: torch.Tensor, true: torch.Tensor) -> torch.Tensor:
    """Return the mean absolute residual ``|pred - true|``."""
    residuals = pred - true
    return residuals.abs().mean()
|
68
|
+
|
69
|
+
|
70
|
+
def mean_absolute_percentage(pred: torch.Tensor, true: torch.Tensor) -> torch.Tensor:
    """Calculate the mean absolute percentage error (MAPE).

    MAPE = 100 * mean(|(true - pred) / true|)

    The relative error is taken with respect to the *targets* (``true``), as in
    the standard definition. Zero-valued targets produce ``inf``/``nan``.

    Args:
        pred: Model predictions.
        true: Target values.

    Returns:
        Scalar tensor with the error in percent.

    """
    return 100 * torch.mean(torch.abs((true - pred) / true))
|
73
|
+
|
74
|
+
|
75
|
+
def mean_squared_logarithmic(pred: torch.Tensor, true: torch.Tensor) -> torch.Tensor:
    """Calculate the mean squared logarithmic error (MSLE).

    MSLE = mean((log(1 + pred) - log(1 + true)) ** 2)

    ``torch.log1p`` is used instead of ``torch.log(x + 1)`` so that values much
    smaller than 1 keep their precision (``1 + x`` rounds to ``1`` for tiny x).

    Args:
        pred: Model predictions (values <= -1 yield ``nan``/``-inf``).
        true: Target values (values <= -1 yield ``nan``/``-inf``).

    Returns:
        Scalar tensor with the error.

    """
    return torch.mean(torch.square(torch.log1p(pred) - torch.log1p(true)))
|
78
|
+
|
79
|
+
|
80
|
+
def cosine_similarity(pred: torch.Tensor, true: torch.Tensor) -> torch.Tensor:
    """Calculate the negative cosine similarity, usable as a loss.

    The cosine similarity is computed along the last dimension (one value per
    sample for ``(batch, features)`` inputs, a single value for 1-D inputs) and
    averaged. It is negated so that minimising the loss maximises the alignment
    between predictions and targets: identical directions give ``-1``,
    orthogonal ones ``0`` and opposite ones ``1``. The value is independent of
    the magnitude of the inputs.

    Args:
        pred: Model predictions.
        true: Target values, same shape as ``pred``.

    Returns:
        Scalar tensor in ``[-1, 1]``.

    """
    return -torch.nn.functional.cosine_similarity(pred, true, dim=-1).mean()
|
83
|
+
|
84
|
+
|
85
|
+
###############################################################################
|
86
|
+
# Training routines
|
87
|
+
###############################################################################
|
88
|
+
|
37
89
|
|
38
90
|
def train(
|
39
91
|
model: nn.Module,
|
@@ -85,6 +137,11 @@ def train(
|
|
85
137
|
return pd.Series(losses, dtype=float)
|
86
138
|
|
87
139
|
|
140
|
+
###############################################################################
|
141
|
+
# Actual models
|
142
|
+
###############################################################################
|
143
|
+
|
144
|
+
|
88
145
|
class MLP(nn.Module):
|
89
146
|
"""Multilayer Perceptron (MLP) for surrogate modeling and neural posterior estimation.
|
90
147
|
|
mxlpy/npe/__init__.py
CHANGED
@@ -22,11 +22,16 @@ if TYPE_CHECKING:
|
|
22
22
|
import contextlib
|
23
23
|
|
24
24
|
with contextlib.suppress(ImportError):
|
25
|
+
from . import _equinox as equinox
|
25
26
|
from . import _keras as keras
|
26
27
|
from . import _torch as torch
|
27
28
|
else:
|
28
29
|
from lazy_import import lazy_module
|
29
30
|
|
31
|
+
equinox = lazy_module(
|
32
|
+
"mxlpy.npe._equinox",
|
33
|
+
error_strings={"module": "equinox", "install_name": "mxlpy[equinox]"},
|
34
|
+
)
|
30
35
|
keras = lazy_module(
|
31
36
|
"mxlpy.npe._keras",
|
32
37
|
error_strings={"module": "keras", "install_name": "mxlpy[tf]"},
|
mxlpy/npe/_equinox.py
ADDED
@@ -0,0 +1,344 @@
|
|
1
|
+
"""Neural Network Parameter Estimation (NPE) Module using equinox.

This module provides classes and functions for training equinox neural network
models to estimate parameters in metabolic models. It includes functionality for
both steady-state and time-series data.

Functions:
    train_steady_state: Train an equinox steady state estimator
    train_time_course: Train an equinox time course estimator
"""
|
11
|
+
|
12
|
+
from __future__ import annotations
|
13
|
+
|
14
|
+
from dataclasses import dataclass
|
15
|
+
from typing import TYPE_CHECKING, Self, cast
|
16
|
+
|
17
|
+
import jax
|
18
|
+
import jax.numpy as jnp
|
19
|
+
import numpy as np
|
20
|
+
import optax
|
21
|
+
import pandas as pd
|
22
|
+
|
23
|
+
from mxlpy.nn._equinox import LSTM, MLP, LossFn, mean_abs_error
|
24
|
+
from mxlpy.nn._equinox import train as _train
|
25
|
+
from mxlpy.types import AbstractEstimator
|
26
|
+
|
27
|
+
if TYPE_CHECKING:
|
28
|
+
import equinox as eqx
|
29
|
+
|
30
|
+
__all__ = [
|
31
|
+
"SteadyState",
|
32
|
+
"SteadyStateTrainer",
|
33
|
+
"TimeCourse",
|
34
|
+
"TimeCourseTrainer",
|
35
|
+
"train_steady_state",
|
36
|
+
"train_time_course",
|
37
|
+
]
|
38
|
+
|
39
|
+
|
40
|
+
@dataclass(kw_only=True)
class SteadyState(AbstractEstimator):
    """Estimator for steady state data using equinox models.

    Attributes:
        model: Trained equinox module. It is applied to one feature row at a
            time via ``jax.vmap``.

    """

    model: eqx.Module

    def predict(self, features: pd.Series | pd.DataFrame) -> pd.DataFrame:
        """Predict the target values for the given features.

        Args:
            features: Either a DataFrame with one row per sample, or a Series
                holding the features of a single sample.

        Returns:
            DataFrame with one row per sample and one column per entry of
            ``parameter_names``.

        """
        # A Series is 1-D; promote it to a single-row matrix so that vmap maps
        # over samples instead of feeding individual scalars to the model.
        # This is a no-op for DataFrames.
        inputs = jnp.atleast_2d(jnp.array(features))
        # One has to implement __call__ on eqx.Module, so this should
        # always exist. Should really be abstract on eqx.Module
        pred = jax.vmap(self.model)(inputs)  # type: ignore
        return pd.DataFrame(pred, columns=self.parameter_names)
|
52
|
+
|
53
|
+
|
54
|
+
@dataclass(kw_only=True)
class TimeCourse(AbstractEstimator):
    """Estimator for time course data using equinox models.

    Attributes:
        model: Trained equinox module, applied via ``jax.vmap``.

    """

    model: eqx.Module

    def predict(self, features: pd.Series | pd.DataFrame) -> pd.DataFrame:
        """Predict the target values for the given time course features.

        Args:
            features: Feature values indexed by a two-level ``pd.MultiIndex``.
                The values are reshaped row-major into
                ``(n_level_0, n_level_1, n_columns)``, so rows must be sorted by
                the first index level, then by the second.

        Returns:
            DataFrame of predictions with one column per entry of
            ``parameter_names``.

        """
        idx = cast(pd.MultiIndex, features.index)
        # Count the values actually present in each level. `idx.levels` keeps
        # unused values after slicing/filtering, which would make the reshape
        # below fail.
        n_outer = len(idx.unique(level=0))
        n_inner = len(idx.unique(level=1))
        features_ = jnp.array(
            np.swapaxes(
                features.to_numpy().reshape(
                    (n_outer, n_inner, len(features.columns)),
                ),
                axis1=0,
                axis2=1,
            ),
        )
        # NOTE(review): after the swap, axis 0 is the *second* index level, so
        # vmap maps the model over that axis. Confirm this matches the input
        # layout expected by mxlpy.nn._equinox.LSTM and by the trainer.
        # One has to implement __call__ on eqx.Module, so this should
        # always exist. Should really be abstract on eqx.Module
        pred = jax.vmap(self.model)(features_)  # type: ignore
        return pd.DataFrame(pred, columns=self.parameter_names)
|
80
|
+
|
81
|
+
|
82
|
+
@dataclass
class SteadyStateTrainer:
    """Trainer for steady state data using equinox models.

    Keeps the training data, model, optax optimizer and the accumulated loss
    history, so :meth:`train` can be called repeatedly to continue training.

    """

    features: pd.DataFrame
    targets: pd.DataFrame
    model: eqx.Module
    optimizer: optax.GradientTransformation
    losses: list[pd.Series]
    loss_fn: LossFn
    seed: int

    def __init__(
        self,
        features: pd.DataFrame,
        targets: pd.DataFrame,
        model: eqx.Module | None = None,
        optimizer: optax.GradientTransformation | None = None,
        loss_fn: LossFn = mean_abs_error,
        seed: int = 0,
    ) -> None:
        """Initialize the trainer with features, targets, and model.

        Args:
            features: DataFrame containing the input features for training
            targets: DataFrame containing the target values for training
            model: Predefined equinox model (None to build a default MLP with
                layer sizes ``[n_hidden, n_hidden, n_targets]``, where
                ``n_hidden = max(2 * n_features * n_targets, 10)``)
            optimizer: optax optimizer (default: ``optax.adamw`` with
                learning rate 0.001)
            loss_fn: Loss function
            seed: Seed of the PRNG key used to initialise the default MLP

        """
        self.features = features
        self.targets = targets

        if model is None:
            n_hidden = max(2 * len(features.columns) * len(targets.columns), 10)
            n_outputs = len(targets.columns)
            model = MLP(
                n_inputs=len(features.columns),
                neurons_per_layer=[n_hidden, n_hidden, n_outputs],
                key=jax.random.PRNGKey(seed),
            )
        self.model = model
        self.optimizer = (
            optax.adamw(learning_rate=0.001) if optimizer is None else optimizer
        )
        self.loss_fn = loss_fn
        self.losses = []
        self.seed = seed

    def train(
        self,
        epochs: int,
        batch_size: int | None = None,
    ) -> Self:
        """Train the model using the provided features and targets.

        The loss index of every call after the first is offset by the last
        index of the previous call, so the history returned by
        :meth:`get_loss` is continuous.

        Args:
            epochs: Number of training epochs
            batch_size: Size of mini-batches for training (None for full-batch)

        Returns:
            The trainer itself, to allow method chaining.

        """
        # NOTE(review): only the loss history is taken from `_train`, and
        # `self.model` is never reassigned. eqx.Module instances are immutable,
        # so confirm that `_train` does not need to hand back an updated model;
        # otherwise `get_estimator` returns the untrained weights.
        losses = _train(
            model=self.model,
            features=jnp.array(self.features),
            targets=jnp.array(self.targets),
            epochs=epochs,
            optimizer=self.optimizer,
            batch_size=batch_size,
            loss_fn=self.loss_fn,
        )

        if len(self.losses) > 0:
            losses.index += self.losses[-1].index[-1]
        self.losses.append(losses)
        return self

    def get_loss(self) -> pd.Series:
        """Get the concatenated loss history of all training calls."""
        return pd.concat(self.losses)

    def get_estimator(self) -> SteadyState:
        """Wrap the current model in a :class:`SteadyState` estimator."""
        return SteadyState(
            model=self.model,
            parameter_names=list(self.targets.columns),
        )
|
171
|
+
|
172
|
+
|
173
|
+
@dataclass
class TimeCourseTrainer:
    """Trainer for time course data using equinox models.

    Keeps the training data, model, optax optimizer and the accumulated loss
    history, so :meth:`train` can be called repeatedly to continue training.

    """

    features: pd.DataFrame
    targets: pd.DataFrame
    model: eqx.Module
    optimizer: optax.GradientTransformation
    losses: list[pd.Series]
    loss_fn: LossFn
    seed: int

    def __init__(
        self,
        features: pd.DataFrame,
        targets: pd.DataFrame,
        model: eqx.Module | None = None,
        optimizer: optax.GradientTransformation | None = None,
        loss_fn: LossFn = mean_abs_error,
        seed: int = 0,
    ) -> None:
        """Initialize the trainer with features, targets, and model.

        Args:
            features: DataFrame containing the input features for training
            targets: DataFrame containing the target values for training
            model: Predefined equinox model (None to use a default LSTM with
                ``n_hidden=1``)
            optimizer: optax optimizer (default: ``optax.adamw`` with
                learning rate 0.001)
            loss_fn: Loss function
            seed: Seed of the PRNG key used to initialise the default LSTM

        """
        self.features = features
        self.targets = targets

        if model is None:
            model = LSTM(
                n_inputs=len(features.columns),
                n_outputs=len(targets.columns),
                n_hidden=1,
                # Weight initialisation needs a real PRNG key; an empty array
                # (as used previously) is not a valid key.
                key=jax.random.PRNGKey(seed),
            )
        self.model = model
        self.optimizer = (
            optax.adamw(learning_rate=0.001) if optimizer is None else optimizer
        )
        self.loss_fn = loss_fn
        self.losses = []
        self.seed = seed

    def train(
        self,
        epochs: int,
        batch_size: int | None = None,
    ) -> Self:
        """Train the model using the provided features and targets.

        The loss index of every call after the first is offset by the last
        index of the previous call, so the history returned by
        :meth:`get_loss` is continuous.

        Args:
            epochs: Number of training epochs
            batch_size: Size of mini-batches for training (None for full-batch)

        Returns:
            The trainer itself, to allow method chaining.

        """
        # NOTE(review): only the loss history is taken from `_train`, and
        # `self.model` is never reassigned. eqx.Module instances are immutable,
        # so confirm that `_train` does not need to hand back an updated model;
        # otherwise `get_estimator` returns the untrained weights.
        losses = _train(
            model=self.model,
            # Group the feature rows into one block per target row,
            # (n_targets, -1, n_features), then swap the first two axes.
            # NOTE(review): confirm this axis order matches what `_train`
            # batches over (targets keep n_targets on axis 0).
            features=jnp.array(
                np.swapaxes(
                    self.features.to_numpy().reshape(
                        (len(self.targets), -1, len(self.features.columns))
                    ),
                    axis1=0,
                    axis2=1,
                )
            ),
            targets=jnp.array(self.targets.to_numpy()),
            epochs=epochs,
            optimizer=self.optimizer,
            batch_size=batch_size,
            loss_fn=self.loss_fn,
        )

        if len(self.losses) > 0:
            losses.index += self.losses[-1].index[-1]
        self.losses.append(losses)
        return self

    def get_loss(self) -> pd.Series:
        """Get the concatenated loss history of all training calls."""
        return pd.concat(self.losses)

    def get_estimator(self) -> TimeCourse:
        """Wrap the current model in a :class:`TimeCourse` estimator."""
        return TimeCourse(
            model=self.model,
            parameter_names=list(self.targets.columns),
        )
|
265
|
+
|
266
|
+
|
267
|
+
def train_steady_state(
    features: pd.DataFrame,
    targets: pd.DataFrame,
    epochs: int,
    batch_size: int | None = None,
    model: eqx.Module | None = None,
    optimizer: optax.GradientTransformation | None = None,
) -> tuple[SteadyState, pd.Series]:
    """Train an equinox steady state estimator.

    This function trains a neural network model to estimate steady state data
    using the provided features and targets. It supports both full-batch and
    mini-batch training.

    Examples:
        >>> train_steady_state(features, targets, epochs=100)

    Args:
        features: DataFrame containing the input features for training
        targets: DataFrame containing the target values for training
        epochs: Number of training epochs
        batch_size: Size of mini-batches for training (None for full-batch)
        model: Predefined equinox model (None to use default MLP)
        optimizer: optax optimizer (None to use ``optax.adamw`` with
            learning rate 0.001)

    Returns:
        tuple[SteadyState, pd.Series]: Trained estimator and loss history

    """
    trainer = SteadyStateTrainer(
        features=features,
        targets=targets,
        model=model,
        optimizer=optimizer,
    ).train(epochs=epochs, batch_size=batch_size)

    return trainer.get_estimator(), trainer.get_loss()
|
305
|
+
|
306
|
+
|
307
|
+
def train_time_course(
    features: pd.DataFrame,
    targets: pd.DataFrame,
    epochs: int,
    batch_size: int | None = None,
    model: eqx.Module | None = None,
    optimizer: optax.GradientTransformation | None = None,
) -> tuple[TimeCourse, pd.Series]:
    """Train an equinox time course estimator.

    This function trains a neural network model to estimate time course data
    using the provided features and targets. It supports both full-batch and
    mini-batch training.

    Examples:
        >>> train_time_course(features, targets, epochs=100)

    Args:
        features: DataFrame containing the input features for training
        targets: DataFrame containing the target values for training
        epochs: Number of training epochs
        batch_size: Size of mini-batches for training (None for full-batch)
        model: Predefined equinox model (None to use default LSTM)
        optimizer: optax optimizer (None to use ``optax.adamw`` with
            learning rate 0.001)

    Returns:
        tuple[TimeCourse, pd.Series]: Trained estimator and loss history

    """
    trainer = TimeCourseTrainer(
        features=features,
        targets=targets,
        model=model,
        optimizer=optimizer,
    ).train(epochs=epochs, batch_size=batch_size)

    return trainer.get_estimator(), trainer.get_loss()
|
mxlpy/npe/_torch.py
CHANGED
@@ -20,7 +20,8 @@ import torch
|
|
20
20
|
from torch import nn
|
21
21
|
from torch.optim.adam import Adam
|
22
22
|
|
23
|
-
from mxlpy.nn._torch import LSTM, MLP, DefaultDevice,
|
23
|
+
from mxlpy.nn._torch import LSTM, MLP, DefaultDevice, LossFn, mean_abs_error
|
24
|
+
from mxlpy.nn._torch import train as _train
|
24
25
|
from mxlpy.types import AbstractEstimator
|
25
26
|
|
26
27
|
if TYPE_CHECKING:
|
@@ -29,10 +30,7 @@ if TYPE_CHECKING:
|
|
29
30
|
from torch.optim.optimizer import ParamsT
|
30
31
|
|
31
32
|
|
32
|
-
type LossFn = Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
|
33
|
-
|
34
33
|
__all__ = [
|
35
|
-
"LossFn",
|
36
34
|
"SteadyState",
|
37
35
|
"SteadyStateTrainer",
|
38
36
|
"TimeCourse",
|
@@ -42,20 +40,6 @@ __all__ = [
|
|
42
40
|
]
|
43
41
|
|
44
42
|
|
45
|
-
def _mean_abs(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
46
|
-
"""Standard loss for surrogates.
|
47
|
-
|
48
|
-
Args:
|
49
|
-
x: Predictions of a model.
|
50
|
-
y: Targets.
|
51
|
-
|
52
|
-
Returns:
|
53
|
-
torch.Tensor: loss.
|
54
|
-
|
55
|
-
"""
|
56
|
-
return torch.mean(torch.abs(x - y))
|
57
|
-
|
58
|
-
|
59
43
|
@dataclass(kw_only=True)
|
60
44
|
class SteadyState(AbstractEstimator):
|
61
45
|
"""Estimator for steady state data using PyTorch models."""
|
@@ -115,7 +99,7 @@ class SteadyStateTrainer:
|
|
115
99
|
model: nn.Module | None = None,
|
116
100
|
optimizer_cls: Callable[[ParamsT], Adam] = Adam,
|
117
101
|
device: torch.device = DefaultDevice,
|
118
|
-
loss_fn: LossFn =
|
102
|
+
loss_fn: LossFn = mean_abs_error,
|
119
103
|
) -> None:
|
120
104
|
"""Initialize the trainer with features, targets, and model.
|
121
105
|
|
@@ -156,7 +140,7 @@ class SteadyStateTrainer:
|
|
156
140
|
batch_size: Size of mini-batches for training (None for full-batch)
|
157
141
|
|
158
142
|
"""
|
159
|
-
losses =
|
143
|
+
losses = _train(
|
160
144
|
model=self.model,
|
161
145
|
features=self.features.to_numpy(),
|
162
146
|
targets=self.targets.to_numpy(),
|
@@ -203,7 +187,7 @@ class TimeCourseTrainer:
|
|
203
187
|
model: nn.Module | None = None,
|
204
188
|
optimizer_cls: Callable[[ParamsT], Adam] = Adam,
|
205
189
|
device: torch.device = DefaultDevice,
|
206
|
-
loss_fn: LossFn =
|
190
|
+
loss_fn: LossFn = mean_abs_error,
|
207
191
|
) -> None:
|
208
192
|
"""Initialize the trainer with features, targets, and model.
|
209
193
|
|
@@ -243,7 +227,7 @@ class TimeCourseTrainer:
|
|
243
227
|
batch_size: Size of mini-batches for training (None for full-batch)
|
244
228
|
|
245
229
|
"""
|
246
|
-
losses =
|
230
|
+
losses = _train(
|
247
231
|
model=self.model,
|
248
232
|
features=np.swapaxes(
|
249
233
|
self.features.to_numpy().reshape(
|
mxlpy/parallel.py
CHANGED
@@ -27,10 +27,7 @@ from tqdm import tqdm
|
|
27
27
|
if TYPE_CHECKING:
|
28
28
|
from collections.abc import Callable, Collection, Hashable
|
29
29
|
|
30
|
-
__all__ = [
|
31
|
-
"Cache",
|
32
|
-
"parallelise",
|
33
|
-
]
|
30
|
+
__all__ = ["Cache", "parallelise", "parallelise_keyless"]
|
34
31
|
|
35
32
|
|
36
33
|
def _pickle_name(k: Hashable) -> str:
|
@@ -173,3 +170,75 @@ def parallelise[K: Hashable, Tin, Tout](
|
|
173
170
|
) # type: ignore
|
174
171
|
|
175
172
|
return results
|
173
|
+
|
174
|
+
|
175
|
+
def parallelise_keyless[Tin, Tout](
|
176
|
+
fn: Callable[[Tin], Tout],
|
177
|
+
inputs: Collection[Tin],
|
178
|
+
*,
|
179
|
+
parallel: bool = True,
|
180
|
+
max_workers: int | None = None,
|
181
|
+
timeout: float | None = None,
|
182
|
+
disable_tqdm: bool = False,
|
183
|
+
tqdm_desc: str | None = None,
|
184
|
+
) -> list[Tout]:
|
185
|
+
"""Execute a function in parallel over a collection of inputs.
|
186
|
+
|
187
|
+
Examples:
|
188
|
+
>>> parallelise(square, [("a", 2), ("b", 3), ("c", 4)])
|
189
|
+
{"a": 4, "b": 9, "c": 16}
|
190
|
+
|
191
|
+
Args:
|
192
|
+
fn: Function to execute in parallel. Takes a single input and returns a result.
|
193
|
+
inputs: Collection of (key, input) tuples to process.
|
194
|
+
cache: Optional cache to store and retrieve results.
|
195
|
+
parallel: Whether to execute in parallel (default: True).
|
196
|
+
max_workers: Maximum number of worker processes (default: None, uses all available CPUs).
|
197
|
+
timeout: Maximum time (in seconds) to wait for each worker to complete (default: None).
|
198
|
+
disable_tqdm: Whether to disable the tqdm progress bar (default: False).
|
199
|
+
tqdm_desc: Description for the tqdm progress bar (default: None).
|
200
|
+
|
201
|
+
Returns:
|
202
|
+
dict[Tin, Tout]: Dictionary mapping inputs to their corresponding outputs.
|
203
|
+
|
204
|
+
"""
|
205
|
+
if sys.platform in ["win32", "cygwin"]:
|
206
|
+
parallel = False
|
207
|
+
|
208
|
+
results: list[Tout]
|
209
|
+
if parallel:
|
210
|
+
results = []
|
211
|
+
max_workers = (
|
212
|
+
multiprocessing.cpu_count() if max_workers is None else max_workers
|
213
|
+
)
|
214
|
+
|
215
|
+
with (
|
216
|
+
tqdm(
|
217
|
+
total=len(inputs),
|
218
|
+
disable=disable_tqdm,
|
219
|
+
desc=tqdm_desc,
|
220
|
+
) as pbar,
|
221
|
+
pebble.ProcessPool(max_workers=max_workers) as pool,
|
222
|
+
):
|
223
|
+
future = pool.map(fn, inputs, timeout=timeout)
|
224
|
+
it = future.result()
|
225
|
+
while True:
|
226
|
+
try:
|
227
|
+
value = next(it)
|
228
|
+
pbar.update(1)
|
229
|
+
results.append(value)
|
230
|
+
except StopIteration:
|
231
|
+
break
|
232
|
+
except TimeoutError:
|
233
|
+
pbar.update(1)
|
234
|
+
else:
|
235
|
+
results = list(
|
236
|
+
tqdm(
|
237
|
+
map(fn, inputs), # type: ignore
|
238
|
+
total=len(inputs),
|
239
|
+
disable=disable_tqdm,
|
240
|
+
desc=tqdm_desc,
|
241
|
+
)
|
242
|
+
) # type: ignore
|
243
|
+
|
244
|
+
return results
|
mxlpy/surrogates/__init__.py
CHANGED
@@ -14,11 +14,16 @@ if TYPE_CHECKING:
|
|
14
14
|
import contextlib
|
15
15
|
|
16
16
|
with contextlib.suppress(ImportError):
|
17
|
+
from . import _equinox as equinox
|
17
18
|
from . import _keras as keras
|
18
19
|
from . import _torch as torch
|
19
20
|
else:
|
20
21
|
from lazy_import import lazy_module
|
21
22
|
|
23
|
+
equinox = lazy_module(
|
24
|
+
"mxlpy.surrogates._equinox",
|
25
|
+
error_strings={"module": "equinox", "install_name": "mxlpy[equinox]"},
|
26
|
+
)
|
22
27
|
keras = lazy_module(
|
23
28
|
"mxlpy.surrogates._keras",
|
24
29
|
error_strings={"module": "keras", "install_name": "mxlpy[tf]"},
|