mxlpy 0.25.0__py3-none-any.whl → 0.26.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
mxlpy/nn/_torch.py CHANGED
@@ -15,8 +15,6 @@ import tqdm
  from torch import nn
  from torch.utils.data import DataLoader, TensorDataset

- type LossFn = Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
-
  if TYPE_CHECKING:
      from collections.abc import Callable

@@ -29,11 +27,65 @@ __all__ = [
      "LSTM",
      "LossFn",
      "MLP",
+     "cosine_similarity",
+     "mean_abs_error",
+     "mean_absolute_percentage",
+     "mean_error",
+     "mean_squared_error",
+     "mean_squared_logarithmic",
+     "rms_error",
      "train",
  ]

  DefaultDevice = torch.device("cpu")

+ ###############################################################################
+ # Loss functions
+ ###############################################################################
+
+
+ type LossFn = Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
+
+
+ def mean_error(pred: torch.Tensor, true: torch.Tensor) -> torch.Tensor:
+     """Calculate mean error."""
+     return torch.mean(pred - true)
+
+
+ def mean_squared_error(pred: torch.Tensor, true: torch.Tensor) -> torch.Tensor:
+     """Calculate mean squared error."""
+     return torch.mean(torch.square(pred - true))
+
+
+ def rms_error(pred: torch.Tensor, true: torch.Tensor) -> torch.Tensor:
+     """Calculate root mean square error."""
+     return torch.sqrt(torch.mean(torch.square(pred - true)))
+
+
+ def mean_abs_error(pred: torch.Tensor, true: torch.Tensor) -> torch.Tensor:
+     """Calculate mean absolute error."""
+     return torch.mean(torch.abs(pred - true))
+
+
+ def mean_absolute_percentage(pred: torch.Tensor, true: torch.Tensor) -> torch.Tensor:
+     """Calculate mean absolute percentage error."""
+     return 100 * torch.mean(torch.abs((true - pred) / true))
+
+
+ def mean_squared_logarithmic(pred: torch.Tensor, true: torch.Tensor) -> torch.Tensor:
+     """Calculate mean squared logarithmic error."""
+     return torch.mean(torch.square(torch.log(pred + 1) - torch.log(true + 1)))
+
+
+ def cosine_similarity(pred: torch.Tensor, true: torch.Tensor) -> torch.Tensor:
+     """Calculate negative cosine similarity."""
+     return -torch.sum(pred * true) / (torch.norm(pred, 2) * torch.norm(true, 2))
+
+
+ ###############################################################################
+ # Training routines
+ ###############################################################################
+

  def train(
      model: nn.Module,
@@ -85,6 +137,11 @@ def train(
      return pd.Series(losses, dtype=float)


+ ###############################################################################
+ # Actual models
+ ###############################################################################
+
+
  class MLP(nn.Module):
      """Multilayer Perceptron (MLP) for surrogate modeling and neural posterior estimation.

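Each of the new loss helpers matches the `LossFn` signature, so any of them can be passed to `train` (and to the NPE trainers below) via the `loss_fn` argument. A minimal sketch of standalone use, with tensor values invented for illustration:

    import torch
    from mxlpy.nn._torch import mean_squared_error, rms_error

    pred = torch.tensor([1.0, 2.0, 3.0])
    true = torch.tensor([1.5, 2.0, 2.5])

    # residuals are (-0.5, 0.0, 0.5), so MSE = 0.5 / 3
    mean_squared_error(pred, true)  # tensor(0.1667)
    rms_error(pred, true)           # tensor(0.4082), the square root of the MSE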
mxlpy/npe/__init__.py CHANGED
@@ -22,11 +22,16 @@ if TYPE_CHECKING:
      import contextlib

      with contextlib.suppress(ImportError):
+         from . import _equinox as equinox
          from . import _keras as keras
          from . import _torch as torch
  else:
      from lazy_import import lazy_module

+     equinox = lazy_module(
+         "mxlpy.npe._equinox",
+         error_strings={"module": "equinox", "install_name": "mxlpy[equinox]"},
+     )
      keras = lazy_module(
          "mxlpy.npe._keras",
          error_strings={"module": "keras", "install_name": "mxlpy[tf]"},
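As with the existing keras and torch backends, the equinox backend is imported eagerly only for type checking; at runtime `lazy_module` defers the import until first use and points users at the `mxlpy[equinox]` extra if it is missing. A sketch of the resulting usage, assuming the extra is installed (`features` and `targets` are placeholder DataFrames):

    from mxlpy import npe

    # The submodule is resolved on first attribute access.
    estimator, losses = npe.equinox.train_steady_state(features, targets, epochs=100)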
mxlpy/npe/_equinox.py ADDED
@@ -0,0 +1,344 @@
+ """Neural Network Parameter Estimation (NPE) Module.
+
+ This module provides classes and functions for training equinox models to estimate
+ parameters in metabolic models. It includes functionality for both steady-state and
+ time-series data.
+
+ Functions:
+     train_steady_state: Train an equinox steady state estimator
+     train_time_course: Train an equinox time course estimator
+ """
+
+ from __future__ import annotations
+
+ from dataclasses import dataclass
+ from typing import TYPE_CHECKING, Self, cast
+
+ import jax
+ import jax.numpy as jnp
+ import numpy as np
+ import optax
+ import pandas as pd
+
+ from mxlpy.nn._equinox import LSTM, MLP, LossFn, mean_abs_error
+ from mxlpy.nn._equinox import train as _train
+ from mxlpy.types import AbstractEstimator
+
+ if TYPE_CHECKING:
+     import equinox as eqx
+
+ __all__ = [
+     "SteadyState",
+     "SteadyStateTrainer",
+     "TimeCourse",
+     "TimeCourseTrainer",
+     "train_steady_state",
+     "train_time_course",
+ ]
+
+
+ @dataclass(kw_only=True)
+ class SteadyState(AbstractEstimator):
+     """Estimator for steady state data using equinox models."""
+
+     model: eqx.Module
+
+     def predict(self, features: pd.Series | pd.DataFrame) -> pd.DataFrame:
+         """Predict the target values for the given features."""
+         # One has to implement __call__ on eqx.Module, so this should
+         # always exist. Should really be abstract on eqx.Module
+         pred = jax.vmap(self.model)(jnp.array(features))  # type: ignore
+         return pd.DataFrame(pred, columns=self.parameter_names)
+
+
+ @dataclass(kw_only=True)
+ class TimeCourse(AbstractEstimator):
+     """Estimator for time course data using equinox models."""
+
+     model: eqx.Module
+
+     def predict(self, features: pd.Series | pd.DataFrame) -> pd.DataFrame:
+         """Predict the target values for the given features."""
+         idx = cast(pd.MultiIndex, features.index)
+         features_ = jnp.array(
+             np.swapaxes(
+                 features.to_numpy().reshape(
+                     (
+                         len(idx.levels[0]),
+                         len(idx.levels[1]),
+                         len(features.columns),
+                     )
+                 ),
+                 axis1=0,
+                 axis2=1,
+             ),
+         )
+         # One has to implement __call__ on eqx.Module, so this should
+         # always exist. Should really be abstract on eqx.Module
+         pred = jax.vmap(self.model)(features_)  # type: ignore
+         return pd.DataFrame(pred, columns=self.parameter_names)
+
+
+ @dataclass
+ class SteadyStateTrainer:
+     """Trainer for steady state data using equinox models."""
+
+     features: pd.DataFrame
+     targets: pd.DataFrame
+     model: eqx.Module
+     optimizer: optax.GradientTransformation
+     losses: list[pd.Series]
+     loss_fn: LossFn
+     seed: int
+
+     def __init__(
+         self,
+         features: pd.DataFrame,
+         targets: pd.DataFrame,
+         model: eqx.Module | None = None,
+         optimizer: optax.GradientTransformation | None = None,
+         loss_fn: LossFn = mean_abs_error,
+         seed: int = 0,
+     ) -> None:
+         """Initialize the trainer with features, targets, and model.
+
+         Args:
+             features: DataFrame containing the input features for training
+             targets: DataFrame containing the target values for training
+             model: Predefined neural network model (None to use default MLP)
+             optimizer: Optimizer to use for training (default: AdamW)
+             loss_fn: Loss function
+             seed: Seed for the random initialisation
+
+         """
+         self.features = features
+         self.targets = targets
+
+         if model is None:
+             n_hidden = max(2 * len(features.columns) * len(targets.columns), 10)
+             n_outputs = len(targets.columns)
+             model = MLP(
+                 n_inputs=len(features.columns),
+                 neurons_per_layer=[n_hidden, n_hidden, n_outputs],
+                 key=jax.random.PRNGKey(seed),
+             )
+         self.model = model
+         self.optimizer = (
+             optax.adamw(learning_rate=0.001) if optimizer is None else optimizer
+         )
+         self.loss_fn = loss_fn
+         self.losses = []
+         self.seed = seed
+
+     def train(
+         self,
+         epochs: int,
+         batch_size: int | None = None,
+     ) -> Self:
+         """Train the model using the provided features and targets.
+
+         Args:
+             epochs: Number of training epochs
+             batch_size: Size of mini-batches for training (None for full-batch)
+
+         """
+         losses = _train(
+             model=self.model,
+             features=jnp.array(self.features),
+             targets=jnp.array(self.targets),
+             epochs=epochs,
+             optimizer=self.optimizer,
+             batch_size=batch_size,
+             loss_fn=self.loss_fn,
+         )
+
+         if len(self.losses) > 0:
+             losses.index += self.losses[-1].index[-1]
+         self.losses.append(losses)
+         return self
+
+     def get_loss(self) -> pd.Series:
+         """Get the loss history of the training process."""
+         return pd.concat(self.losses)
+
+     def get_estimator(self) -> SteadyState:
+         """Get the trained estimator."""
+         return SteadyState(
+             model=self.model,
+             parameter_names=list(self.targets.columns),
+         )
+
+
+ @dataclass
+ class TimeCourseTrainer:
+     """Trainer for time course data using equinox models."""
+
+     features: pd.DataFrame
+     targets: pd.DataFrame
+     model: eqx.Module
+     optimizer: optax.GradientTransformation
+     losses: list[pd.Series]
+     loss_fn: LossFn
+
+     def __init__(
+         self,
+         features: pd.DataFrame,
+         targets: pd.DataFrame,
+         model: eqx.Module | None = None,
+         optimizer: optax.GradientTransformation | None = None,
+         loss_fn: LossFn = mean_abs_error,
+     ) -> None:
+         """Initialize the trainer with features, targets, and model.
+
+         Args:
+             features: DataFrame containing the input features for training
+             targets: DataFrame containing the target values for training
+             model: Predefined neural network model (None to use default LSTM)
+             optimizer: Optimizer to use for training (default: AdamW)
+             loss_fn: Loss function
+
+         """
+         self.features = features
+         self.targets = targets
+
+         if model is None:
+             model = LSTM(
+                 n_inputs=len(features.columns),
+                 n_outputs=len(targets.columns),
+                 n_hidden=1,
+                 key=jax.random.PRNGKey(0),
+             )
+         self.model = model
+         self.optimizer = (
+             optax.adamw(learning_rate=0.001) if optimizer is None else optimizer
+         )
+         self.loss_fn = loss_fn
+         self.losses = []
+
+     def train(
+         self,
+         epochs: int,
+         batch_size: int | None = None,
+     ) -> Self:
+         """Train the model using the provided features and targets.
+
+         Args:
+             epochs: Number of training epochs
+             batch_size: Size of mini-batches for training (None for full-batch)
+
+         """
+         losses = _train(
+             model=self.model,
+             features=jnp.array(
+                 np.swapaxes(
+                     self.features.to_numpy().reshape(
+                         (len(self.targets), -1, len(self.features.columns))
+                     ),
+                     axis1=0,
+                     axis2=1,
+                 )
+             ),
+             targets=jnp.array(self.targets.to_numpy()),
+             epochs=epochs,
+             optimizer=self.optimizer,
+             batch_size=batch_size,
+             loss_fn=self.loss_fn,
+         )
+
+         if len(self.losses) > 0:
+             losses.index += self.losses[-1].index[-1]
+         self.losses.append(losses)
+         return self
+
+     def get_loss(self) -> pd.Series:
+         """Get the loss history of the training process."""
+         return pd.concat(self.losses)
+
+     def get_estimator(self) -> TimeCourse:
+         """Get the trained estimator."""
+         return TimeCourse(
+             model=self.model,
+             parameter_names=list(self.targets.columns),
+         )
+
+
+ def train_steady_state(
+     features: pd.DataFrame,
+     targets: pd.DataFrame,
+     epochs: int,
+     batch_size: int | None = None,
+     model: eqx.Module | None = None,
+     optimizer: optax.GradientTransformation | None = None,
+ ) -> tuple[SteadyState, pd.Series]:
+     """Train an equinox steady state estimator.
+
+     This function trains a neural network model to estimate steady state data
+     using the provided features and targets. It supports both full-batch and
+     mini-batch training.
+
+     Examples:
+         >>> train_steady_state(features, targets, epochs=100)
+
+     Args:
+         features: DataFrame containing the input features for training
+         targets: DataFrame containing the target values for training
+         epochs: Number of training epochs
+         batch_size: Size of mini-batches for training (None for full-batch)
+         model: Predefined neural network model (None to use default MLP)
+         optimizer: Optimizer to use for training (default: AdamW)
+
+     Returns:
+         tuple[SteadyState, pd.Series]: Trained estimator and loss history
+
+     """
+     trainer = SteadyStateTrainer(
+         features=features,
+         targets=targets,
+         model=model,
+         optimizer=optimizer,
+     ).train(epochs=epochs, batch_size=batch_size)
+
+     return trainer.get_estimator(), trainer.get_loss()
+
+
+ def train_time_course(
+     features: pd.DataFrame,
+     targets: pd.DataFrame,
+     epochs: int,
+     batch_size: int | None = None,
+     model: eqx.Module | None = None,
+     optimizer: optax.GradientTransformation | None = None,
+ ) -> tuple[TimeCourse, pd.Series]:
+     """Train an equinox time course estimator.
+
+     This function trains a neural network model to estimate time course data
+     using the provided features and targets. It supports both full-batch and
+     mini-batch training.
+
+     Examples:
+         >>> train_time_course(features, targets, epochs=100)
+
+     Args:
+         features: DataFrame containing the input features for training
+         targets: DataFrame containing the target values for training
+         epochs: Number of training epochs
+         batch_size: Size of mini-batches for training (None for full-batch)
+         model: Predefined neural network model (None to use default LSTM)
+         optimizer: Optimizer to use for training (default: AdamW)
+
+     Returns:
+         tuple[TimeCourse, pd.Series]: Trained estimator and loss history
+
+     """
+     trainer = TimeCourseTrainer(
+         features=features,
+         targets=targets,
+         model=model,
+         optimizer=optimizer,
+     ).train(epochs=epochs, batch_size=batch_size)
+
+     return trainer.get_estimator(), trainer.get_loss()
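Putting the new module together, a minimal end-to-end sketch with synthetic data (shapes, column names, and values are invented for illustration):

    import numpy as np
    import pandas as pd
    from mxlpy.npe import equinox as npe_eqx

    rng = np.random.default_rng(0)
    features = pd.DataFrame(rng.normal(size=(64, 3)), columns=["v1", "v2", "v3"])
    targets = pd.DataFrame(rng.normal(size=(64, 2)), columns=["k1", "k2"])

    # Default MLP, AdamW optimizer, mean absolute error loss
    estimator, losses = npe_eqx.train_steady_state(
        features, targets, epochs=10, batch_size=16
    )
    predictions = estimator.predict(features)  # DataFrame with columns ["k1", "k2"]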
mxlpy/npe/_torch.py CHANGED
@@ -20,7 +20,8 @@ import torch
  from torch import nn
  from torch.optim.adam import Adam

- from mxlpy.nn._torch import LSTM, MLP, DefaultDevice, train
+ from mxlpy.nn._torch import LSTM, MLP, DefaultDevice, LossFn, mean_abs_error
+ from mxlpy.nn._torch import train as _train
  from mxlpy.types import AbstractEstimator

  if TYPE_CHECKING:
@@ -29,10 +30,7 @@ if TYPE_CHECKING:
      from torch.optim.optimizer import ParamsT


- type LossFn = Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
-
  __all__ = [
-     "LossFn",
      "SteadyState",
      "SteadyStateTrainer",
      "TimeCourse",
@@ -42,20 +40,6 @@ __all__ = [
  ]


- def _mean_abs(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
-     """Standard loss for surrogates.
-
-     Args:
-         x: Predictions of a model.
-         y: Targets.
-
-     Returns:
-         torch.Tensor: loss.
-
-     """
-     return torch.mean(torch.abs(x - y))
-
-
  @dataclass(kw_only=True)
  class SteadyState(AbstractEstimator):
      """Estimator for steady state data using PyTorch models."""
@@ -115,7 +99,7 @@ class SteadyStateTrainer:
          model: nn.Module | None = None,
          optimizer_cls: Callable[[ParamsT], Adam] = Adam,
          device: torch.device = DefaultDevice,
-         loss_fn: LossFn = _mean_abs,
+         loss_fn: LossFn = mean_abs_error,
      ) -> None:
          """Initialize the trainer with features, targets, and model.

@@ -156,7 +140,7 @@ class SteadyStateTrainer:
              batch_size: Size of mini-batches for training (None for full-batch)

          """
-         losses = train(
+         losses = _train(
              model=self.model,
              features=self.features.to_numpy(),
              targets=self.targets.to_numpy(),
@@ -203,7 +187,7 @@ class TimeCourseTrainer:
          model: nn.Module | None = None,
          optimizer_cls: Callable[[ParamsT], Adam] = Adam,
          device: torch.device = DefaultDevice,
-         loss_fn: LossFn = _mean_abs,
+         loss_fn: LossFn = mean_abs_error,
      ) -> None:
          """Initialize the trainer with features, targets, and model.

@@ -243,7 +227,7 @@ class TimeCourseTrainer:
              batch_size: Size of mini-batches for training (None for full-batch)

          """
-         losses = train(
+         losses = _train(
              model=self.model,
              features=np.swapaxes(
                  self.features.to_numpy().reshape(
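With `_mean_abs` removed, the torch NPE trainers now default to the shared `mean_abs_error`, and any other `LossFn` from `mxlpy.nn._torch` can be swapped in. A sketch, with `features` and `targets` as placeholder DataFrames:

    from mxlpy.nn._torch import mean_squared_error
    from mxlpy.npe._torch import SteadyStateTrainer

    trainer = SteadyStateTrainer(
        features=features,
        targets=targets,
        loss_fn=mean_squared_error,  # any LossFn works here
    ).train(epochs=100)
    estimator = trainer.get_estimator()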
mxlpy/parallel.py CHANGED
@@ -27,10 +27,7 @@ from tqdm import tqdm
  if TYPE_CHECKING:
      from collections.abc import Callable, Collection, Hashable

- __all__ = [
-     "Cache",
-     "parallelise",
- ]
+ __all__ = ["Cache", "parallelise", "parallelise_keyless"]


  def _pickle_name(k: Hashable) -> str:
@@ -173,3 +170,75 @@ def parallelise[K: Hashable, Tin, Tout](
      ) # type: ignore

      return results
+
+
+ def parallelise_keyless[Tin, Tout](
+     fn: Callable[[Tin], Tout],
+     inputs: Collection[Tin],
+     *,
+     parallel: bool = True,
+     max_workers: int | None = None,
+     timeout: float | None = None,
+     disable_tqdm: bool = False,
+     tqdm_desc: str | None = None,
+ ) -> list[Tout]:
+     """Execute a function in parallel over a collection of inputs.
+
+     Examples:
+         >>> parallelise_keyless(square, [2, 3, 4])
+         [4, 9, 16]
+
+     Args:
+         fn: Function to execute in parallel. Takes a single input and returns a result.
+         inputs: Collection of inputs to process.
+         parallel: Whether to execute in parallel (default: True).
+         max_workers: Maximum number of worker processes (default: None, uses all available CPUs).
+         timeout: Maximum time (in seconds) to wait for each worker to complete (default: None).
+         disable_tqdm: Whether to disable the tqdm progress bar (default: False).
+         tqdm_desc: Description for the tqdm progress bar (default: None).
+
+     Returns:
+         list[Tout]: Outputs in the same order as the inputs.
+
+     """
+     if sys.platform in ["win32", "cygwin"]:
+         parallel = False
+
+     results: list[Tout]
+     if parallel:
+         results = []
+         max_workers = (
+             multiprocessing.cpu_count() if max_workers is None else max_workers
+         )
+
+         with (
+             tqdm(
+                 total=len(inputs),
+                 disable=disable_tqdm,
+                 desc=tqdm_desc,
+             ) as pbar,
+             pebble.ProcessPool(max_workers=max_workers) as pool,
+         ):
+             future = pool.map(fn, inputs, timeout=timeout)
+             it = future.result()
+             while True:
+                 try:
+                     value = next(it)
+                     pbar.update(1)
+                     results.append(value)
+                 except StopIteration:
+                     break
+                 except TimeoutError:
+                     pbar.update(1)
+     else:
+         results = list(
+             tqdm(
+                 map(fn, inputs), # type: ignore
+                 total=len(inputs),
+                 disable=disable_tqdm,
+                 desc=tqdm_desc,
+             )
+         ) # type: ignore
+
+     return results
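`parallelise_keyless` behaves like `parallelise` but takes a bare collection of inputs and returns a list rather than keyed results. A quick sketch (`square` is a hypothetical module-level function, as required for pickling by the process pool):

    from mxlpy.parallel import parallelise_keyless

    def square(x: int) -> int:
        return x * x

    parallelise_keyless(square, [2, 3, 4])  # [4, 9, 16]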
mxlpy/surrogates/__init__.py CHANGED
@@ -14,11 +14,16 @@ if TYPE_CHECKING:
      import contextlib

      with contextlib.suppress(ImportError):
+         from . import _equinox as equinox
          from . import _keras as keras
          from . import _torch as torch
  else:
      from lazy_import import lazy_module

+     equinox = lazy_module(
+         "mxlpy.surrogates._equinox",
+         error_strings={"module": "equinox", "install_name": "mxlpy[equinox]"},
+     )
      keras = lazy_module(
          "mxlpy.surrogates._keras",
          error_strings={"module": "keras", "install_name": "mxlpy[tf]"},