mxlpy 0.16.0__py3-none-any.whl → 0.18.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
mxlpy/model.py CHANGED
@@ -18,6 +18,7 @@ import pandas as pd
 
 from mxlpy import fns
 from mxlpy.types import (
+    AbstractSurrogate,
     Array,
     Derived,
     Reaction,
@@ -27,6 +28,7 @@ from mxlpy.types import (
 __all__ = [
     "ArityMismatchError",
     "CircularDependencyError",
+    "Dependency",
     "MissingDependenciesError",
     "Model",
     "ModelCache",
@@ -36,7 +38,16 @@ if TYPE_CHECKING:
     from collections.abc import Iterable, Mapping
     from inspect import FullArgSpec
 
-    from mxlpy.types import AbstractSurrogate, Callable, Param, RateFn, RetType
+    from mxlpy.types import Callable, Param, RateFn, RetType
+
+
+@dataclass
+class Dependency:
+    """Container class for building dependency tree."""
+
+    name: str
+    required: set[str]
+    provided: set[str]
 
 
 class MissingDependenciesError(Exception):
@@ -145,30 +156,33 @@ def _invalidate_cache(method: Callable[Param, RetType]) -> Callable[Param, RetType]:
 
 def _check_if_is_sortable(
     available: set[str],
-    elements: list[tuple[str, set[str]]],
+    elements: list[Dependency],
 ) -> None:
     all_available = available.copy()
-    for name, _ in elements:
-        all_available.add(name)
+    for dependency in elements:
+        all_available.update(dependency.provided)
 
     # Check if it can be sorted in the first place
     not_solvable = {}
-    for name, args in elements:
-        if not args.issubset(all_available):
-            not_solvable[name] = sorted(args.difference(all_available))
+    for dependency in elements:
+        if not dependency.required.issubset(all_available):
+            not_solvable[dependency.name] = sorted(
+                dependency.required.difference(all_available)
+            )
 
     if not_solvable:
         raise MissingDependenciesError(not_solvable=not_solvable)
 
 
 def _sort_dependencies(
-    available: set[str], elements: list[tuple[str, set[str]]]
+    available: set[str],
+    elements: list[Dependency],
 ) -> list[str]:
     """Sort model elements topologically based on their dependencies.
 
     Args:
         available: Set of available component names
-        elements: List of (name, dependencies) tuples to sort
+        elements: List of Dependency objects to sort
 
     Returns:
         List of element names in dependency order
@@ -184,26 +198,27 @@ def _sort_dependencies(
     order = []
     # FIXME: what is the worst case here?
     max_iterations = len(elements) ** 2
-    queue: SimpleQueue[tuple[str, set[str]]] = SimpleQueue()
-    for k, v in elements:
-        queue.put((k, v))
+    queue: SimpleQueue[Dependency] = SimpleQueue()
+    for dependency in elements:
+        queue.put(dependency)
 
     last_name = None
     i = 0
     while True:
         try:
-            new, args = queue.get_nowait()
+            dependency = queue.get_nowait()
         except Empty:
             break
-        if args.issubset(available):
-            available.add(new)
-            order.append(new)
+        if dependency.required.issubset(available):
+            available.update(dependency.provided)
+            order.append(dependency.name)
+
         else:
-            if last_name == new:
-                order.append(new)
+            if last_name == dependency.name:
+                order.append(last_name)
                 break
-            queue.put((new, args))
-            last_name = new
+            queue.put(dependency)
+            last_name = dependency.name
         i += 1
 
     # Failure case
@@ -211,11 +226,13 @@ def _sort_dependencies(
         unsorted = []
         while True:
             try:
-                unsorted.append(queue.get_nowait()[0])
+                unsorted.append(queue.get_nowait().name)
             except Empty:
                 break
 
-        mod_to_args: dict[str, set[str]] = dict(elements)
+        mod_to_args: dict[str, set[str]] = {
+            dependency.name: dependency.required for dependency in elements
+        }
         missing = {k: mod_to_args[k].difference(available) for k in unsorted}
         raise CircularDependencyError(missing=missing)
     return order
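
The key shift here is that each element now carries a `provided` set that may contain more than one symbol, which is what lets a surrogate with several outputs slot into the topological order. A minimal sketch of the behaviour (note that `_sort_dependencies` is a private helper, so this is illustrative rather than public API; the names are made up):

```python
from mxlpy.model import Dependency, _sort_dependencies

# A surrogate consumes "x" and provides two outputs; a derived value
# depends on one of those outputs, so the surrogate must sort first.
order = _sort_dependencies(
    available={"x", "time"},
    elements=[
        Dependency(name="d1", required={"v1"}, provided={"d1"}),
        Dependency(name="surr", required={"x"}, provided={"v1", "v2"}),
    ],
)
assert order == ["surr", "d1"]
```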
@@ -303,7 +320,12 @@ class Model:
         to_sort = self._derived | self._reactions | self._surrogates
         order = _sort_dependencies(
             available=set(self._parameters) | set(self._variables) | {"time"},
-            elements=[(k, set(v.args)) for k, v in to_sort.items()],
+            elements=[
+                Dependency(name=k, required=set(v.args), provided={k})
+                if not isinstance(v, AbstractSurrogate)
+                else Dependency(name=k, required=set(v.args), provided=set(v.outputs))
+                for k, v in to_sort.items()
+            ],
         )
 
         # Split derived into parameters and variables
@@ -1227,6 +1249,7 @@ class Model:
         name: str,
         surrogate: AbstractSurrogate,
         args: list[str] | None = None,
+        outputs: list[str] | None = None,
         stoichiometries: dict[str, dict[str, float]] | None = None,
     ) -> Self:
         """Adds a surrogate model to the current instance.
@@ -1237,7 +1260,8 @@
         Args:
             name (str): The name of the surrogate model.
             surrogate (AbstractSurrogate): The surrogate model instance to be added.
-            args: A list of arguments for the surrogate model.
+            args: Names of the values passed to the surrogate model.
+            outputs: Names of the values produced by the surrogate model.
             stoichiometries: A dictionary mapping reaction names to stoichiometries.
 
         Returns:
@@ -1248,6 +1272,8 @@
 
         if args is not None:
             surrogate.args = args
+        if outputs is not None:
+            surrogate.outputs = outputs
         if stoichiometries is not None:
             surrogate.stoichiometries = stoichiometries
 
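The new `outputs` parameter is what feeds `provided=set(v.outputs)` in the dependency sort above. A hypothetical sketch of the resulting call, assuming `model` is a `Model` and `my_surrogate` is a trained `AbstractSurrogate`; all names are illustrative:

```python
model.add_surrogate(
    "rate_surrogate",
    my_surrogate,
    args=["x1", "x2"],      # values handed to the surrogate
    outputs=["v1", "v2"],   # names the surrogate provides to the model
    stoichiometries={"v1": {"x1": -1.0, "x2": 1.0}},
)
```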
mxlpy/nn/_torch.py CHANGED
@@ -8,17 +8,77 @@ from __future__ import annotations
 
 from typing import TYPE_CHECKING, cast
 
+import numpy as np
+import pandas as pd
 import torch
+import tqdm
 from torch import nn
+from torch.utils.data import DataLoader, TensorDataset
+
+type LossFn = Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
 
 if TYPE_CHECKING:
     from collections.abc import Callable
 
-__all__ = ["DefaultDevice", "LSTM", "MLP"]
+    from torch.optim.adam import Adam
+
+    from mxlpy.types import Array
+
+__all__ = ["DefaultDevice", "LSTM", "LossFn", "MLP", "train"]
 
 DefaultDevice = torch.device("cpu")
 
 
+def train(
+    aprox: nn.Module,
+    features: Array,
+    targets: Array,
+    epochs: int,
+    optimizer: Adam,
+    device: torch.device,
+    batch_size: int | None,
+    loss_fn: LossFn,
+) -> pd.Series:
+    """Train the neural network using mini-batch gradient descent.
+
+    Args:
+        aprox: Neural network model to train.
+        features: Input features as a tensor.
+        targets: Target values as a tensor.
+        epochs: Number of training epochs.
+        optimizer: Optimizer for training.
+        device: torch device
+        batch_size: Size of mini-batches for training.
+        loss_fn: Loss function
+
+    Returns:
+        pd.Series: Series containing the training loss history.
+
+    """
+    losses = {}
+
+    data = TensorDataset(
+        torch.tensor(features.astype(np.float32), dtype=torch.float32, device=device),
+        torch.tensor(targets.astype(np.float32), dtype=torch.float32, device=device),
+    )
+    data_loader = DataLoader(
+        data,
+        batch_size=len(features) if batch_size is None else batch_size,
+        shuffle=True,
+    )
+
+    for i in tqdm.trange(epochs):
+        epoch_loss = 0
+        for xb, yb in data_loader:
+            optimizer.zero_grad()
+            loss = loss_fn(aprox(xb), yb)
+            loss.backward()
+            optimizer.step()
+            epoch_loss += loss.item() * xb.size(0)
+        losses[i] = epoch_loss / len(data_loader.dataset)  # type: ignore
+    return pd.Series(losses, dtype=float)
+
+
 class MLP(nn.Module):
     """Multilayer Perceptron (MLP) for surrogate modeling and neural posterior estimation.
 
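The new `train` helper is shared by the surrogate and NPE trainers introduced below. A minimal sketch of calling it directly on toy data, based on the signature above (the layer sizes and loss are arbitrary choices for illustration):

```python
import numpy as np
from torch.optim.adam import Adam

from mxlpy.nn._torch import MLP, DefaultDevice, train

rng = np.random.default_rng(seed=0)
features = rng.random((256, 2))
targets = features.sum(axis=1, keepdims=True)  # learn y = x1 + x2

aprox = MLP(n_inputs=2, neurons_per_layer=[16, 16, 1])
losses = train(
    aprox=aprox,
    features=features,
    targets=targets,
    epochs=10,
    optimizer=Adam(aprox.parameters()),
    device=DefaultDevice,
    batch_size=32,
    loss_fn=lambda pred, true: (pred - true).abs().mean(),
)
print(losses.iloc[-1])  # mean loss of the final epoch
```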
mxlpy/npe/__init__.py ADDED
@@ -0,0 +1,38 @@
+"""Neural Process Estimation (NPE) module.
+
+This module provides classes and functions for estimating metabolic processes using
+neural networks. It includes functionality for both steady-state and time-course data.
+
+Classes:
+    TorchSteadyState: Class for steady-state neural network estimation.
+    TorchSteadyStateTrainer: Class for training steady-state neural networks.
+    TorchTimeCourse: Class for time-course neural network estimation.
+    TorchTimeCourseTrainer: Class for training time-course neural networks.
+
+Functions:
+    train_torch_steady_state: Train a PyTorch steady-state neural network.
+    train_torch_time_course: Train a PyTorch time-course neural network.
+"""
+
+from __future__ import annotations
+
+import contextlib
+
+with contextlib.suppress(ImportError):
+    from ._torch import (
+        TorchSteadyState,
+        TorchSteadyStateTrainer,
+        TorchTimeCourse,
+        TorchTimeCourseTrainer,
+        train_torch_steady_state,
+        train_torch_time_course,
+    )
+
+__all__ = [
+    "TorchSteadyState",
+    "TorchSteadyStateTrainer",
+    "TorchTimeCourse",
+    "TorchTimeCourseTrainer",
+    "train_torch_steady_state",
+    "train_torch_time_course",
+]
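
Because the torch-backed names are bound inside `contextlib.suppress(ImportError)`, they are simply absent when torch is not installed, even though `__all__` still lists them. A defensive import sketch:

```python
import mxlpy.npe as npe

if hasattr(npe, "train_torch_steady_state"):
    ...  # torch backend is available
else:
    ...  # fall back, or ask the user to install torch
```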
mxlpy/npe/_torch.py ADDED
@@ -0,0 +1,365 @@
+"""Neural Network Parameter Estimation (NPE) Module.
+
+This module provides classes and functions for training neural network models to estimate
+parameters in metabolic models. It includes functionality for both steady-state and
+time-series data.
+
+Functions:
+    train_torch_steady_state: Train a PyTorch steady-state estimator
+    train_torch_time_course: Train a PyTorch time course estimator
+"""
+
+from __future__ import annotations
+
+from dataclasses import dataclass
+from pathlib import Path
+from typing import TYPE_CHECKING, Self, cast
+
+import numpy as np
+import pandas as pd
+import torch
+from torch import nn
+from torch.optim.adam import Adam
+
+from mxlpy.nn._torch import LSTM, MLP, DefaultDevice, train
+from mxlpy.parallel import Cache
+from mxlpy.types import AbstractEstimator
+
+if TYPE_CHECKING:
+    from collections.abc import Callable
+
+    from torch.optim.optimizer import ParamsT
+
+DefaultCache = Cache(Path(".cache"))
+
+type LossFn = Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
+
+__all__ = [
+    "DefaultCache",
+    "LossFn",
+    "TorchSteadyState",
+    "TorchSteadyStateTrainer",
+    "TorchTimeCourse",
+    "TorchTimeCourseTrainer",
+    "train_torch_steady_state",
+    "train_torch_time_course",
+]
+
+
+def _mean_abs(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+    """Standard loss for surrogates.
+
+    Args:
+        x: Predictions of a model.
+        y: Targets.
+
+    Returns:
+        torch.Tensor: loss.
+
+    """
+    return torch.mean(torch.abs(x - y))
+
+
+@dataclass(kw_only=True)
+class TorchSteadyState(AbstractEstimator):
+    """Estimator for steady state data using PyTorch models."""
+
+    model: torch.nn.Module
+
+    def predict(self, features: pd.Series | pd.DataFrame) -> pd.DataFrame:
+        """Predict the target values for the given features."""
+        with torch.no_grad():
+            pred = self.model(torch.tensor(features.to_numpy(), dtype=torch.float32))
+            return pd.DataFrame(pred, columns=self.parameter_names)
+
+
+@dataclass(kw_only=True)
+class TorchTimeCourse(AbstractEstimator):
+    """Estimator for time course data using PyTorch models."""
+
+    model: torch.nn.Module
+
+    def predict(self, features: pd.Series | pd.DataFrame) -> pd.DataFrame:
+        """Predict the target values for the given features."""
+        idx = cast(pd.MultiIndex, features.index)
+        features_ = torch.Tensor(
+            np.swapaxes(
+                features.to_numpy().reshape(
+                    (
+                        len(idx.levels[0]),
+                        len(idx.levels[1]),
+                        len(features.columns),
+                    )
+                ),
+                axis1=0,
+                axis2=1,
+            ),
+        )
+        with torch.no_grad():
+            pred = self.model(features_)
+            return pd.DataFrame(pred, columns=self.parameter_names)
+
+
+@dataclass
+class TorchSteadyStateTrainer:
+    """Trainer for steady state data using PyTorch models."""
+
+    features: pd.DataFrame
+    targets: pd.DataFrame
+    approximator: nn.Module
+    optimimzer: Adam
+    device: torch.device
+    losses: list[pd.Series]
+    loss_fn: LossFn
+
+    def __init__(
+        self,
+        features: pd.DataFrame,
+        targets: pd.DataFrame,
+        approximator: nn.Module | None = None,
+        optimimzer_cls: Callable[[ParamsT], Adam] = Adam,
+        device: torch.device = DefaultDevice,
+        loss_fn: LossFn = _mean_abs,
+    ) -> None:
+        """Initialize the trainer with features, targets, and model.
+
+        Args:
+            features: DataFrame containing the input features for training
+            targets: DataFrame containing the target values for training
+            approximator: Predefined neural network model (None to use default MLP)
+            optimimzer_cls: Optimizer class to use for training (default: Adam)
+            device: Device to run the training on (default: DefaultDevice)
+            loss_fn: Loss function
+
+        """
+        self.features = features
+        self.targets = targets
+
+        if approximator is None:
+            n_hidden = max(2 * len(features.columns) * len(targets.columns), 10)
+            n_outputs = len(targets.columns)
+            approximator = MLP(
+                n_inputs=len(features.columns),
+                neurons_per_layer=[n_hidden, n_hidden, n_outputs],
+            )
+        self.approximator = approximator.to(device)
+        self.optimizer = optimimzer_cls(approximator.parameters())
+        self.device = device
+        self.loss_fn = loss_fn
+        self.losses = []
+
+    def train(
+        self,
+        epochs: int,
+        batch_size: int | None = None,
+    ) -> Self:
+        """Train the model using the provided features and targets.
+
+        Args:
+            epochs: Number of training epochs
+            batch_size: Size of mini-batches for training (None for full-batch)
+
+        """
+        losses = train(
+            aprox=self.approximator,
+            features=self.features.to_numpy(),
+            targets=self.targets.to_numpy(),
+            epochs=epochs,
+            optimizer=self.optimizer,
+            batch_size=batch_size,
+            loss_fn=self.loss_fn,
+            device=self.device,
+        )
+
+        if len(self.losses) > 0:
+            losses.index += self.losses[-1].index[-1]
+        self.losses.append(losses)
+        return self
+
+    def get_loss(self) -> pd.Series:
+        """Get the loss history of the training process."""
+        return pd.concat(self.losses)
+
+    def get_estimator(self) -> TorchSteadyState:
+        """Get the trained estimator."""
+        return TorchSteadyState(
+            model=self.approximator,
+            parameter_names=list(self.targets.columns),
+        )
+
+
+@dataclass
+class TorchTimeCourseTrainer:
+    """Trainer for time course data using PyTorch models."""
+
+    features: pd.DataFrame
+    targets: pd.DataFrame
+    approximator: nn.Module
+    optimimzer: Adam
+    device: torch.device
+    losses: list[pd.Series]
+    loss_fn: LossFn
+
+    def __init__(
+        self,
+        features: pd.DataFrame,
+        targets: pd.DataFrame,
+        approximator: nn.Module | None = None,
+        optimimzer_cls: Callable[[ParamsT], Adam] = Adam,
+        device: torch.device = DefaultDevice,
+        loss_fn: LossFn = _mean_abs,
+    ) -> None:
+        """Initialize the trainer with features, targets, and model.
+
+        Args:
+            features: DataFrame containing the input features for training
+            targets: DataFrame containing the target values for training
+            approximator: Predefined neural network model (None to use default LSTM)
+            optimimzer_cls: Optimizer class to use for training (default: Adam)
+            device: Device to run the training on (default: DefaultDevice)
+            loss_fn: Loss function
+
+        """
+        self.features = features
+        self.targets = targets
+
+        if approximator is None:
+            approximator = LSTM(
+                n_inputs=len(features.columns),
+                n_outputs=len(targets.columns),
+                n_hidden=1,
+            ).to(device)
+        self.approximator = approximator.to(device)
+        self.optimizer = optimimzer_cls(approximator.parameters())
+        self.device = device
+        self.loss_fn = loss_fn
+        self.losses = []
+
+    def train(
+        self,
+        epochs: int,
+        batch_size: int | None = None,
+    ) -> Self:
+        """Train the model using the provided features and targets.
+
+        Args:
+            epochs: Number of training epochs
+            batch_size: Size of mini-batches for training (None for full-batch)
+
+        """
+        losses = train(
+            aprox=self.approximator,
+            features=np.swapaxes(
+                self.features.to_numpy().reshape(
+                    (len(self.targets), -1, len(self.features.columns))
+                ),
+                axis1=0,
+                axis2=1,
+            ),
+            targets=self.targets.to_numpy(),
+            epochs=epochs,
+            optimizer=self.optimizer,
+            batch_size=batch_size,
+            loss_fn=self.loss_fn,
+            device=self.device,
+        )
+
+        if len(self.losses) > 0:
+            losses.index += self.losses[-1].index[-1]
+        self.losses.append(losses)
+        return self
+
+    def get_loss(self) -> pd.Series:
+        """Get the loss history of the training process."""
+        return pd.concat(self.losses)
+
+    def get_estimator(self) -> TorchTimeCourse:
+        """Get the trained estimator."""
+        return TorchTimeCourse(
+            model=self.approximator,
+            parameter_names=list(self.targets.columns),
+        )
+
+
+def train_torch_steady_state(
+    features: pd.DataFrame,
+    targets: pd.DataFrame,
+    epochs: int,
+    batch_size: int | None = None,
+    approximator: nn.Module | None = None,
+    optimimzer_cls: Callable[[ParamsT], Adam] = Adam,
+    device: torch.device = DefaultDevice,
+) -> tuple[TorchSteadyState, pd.Series]:
+    """Train a PyTorch steady state estimator.
+
+    This function trains a neural network model to estimate steady state data
+    using the provided features and targets. It supports both full-batch and
+    mini-batch training.
+
+    Examples:
+        >>> train_torch_steady_state(features, targets, epochs=100)
+
+    Args:
+        features: DataFrame containing the input features for training
+        targets: DataFrame containing the target values for training
+        epochs: Number of training epochs
+        batch_size: Size of mini-batches for training (None for full-batch)
+        approximator: Predefined neural network model (None to use default MLP)
+        optimimzer_cls: Optimizer class to use for training (default: Adam)
+        device: Device to run the training on (default: DefaultDevice)
+
+    Returns:
+        tuple[TorchSteadyState, pd.Series]: Trained estimator and loss history
+
+    """
+    trainer = TorchSteadyStateTrainer(
+        features=features,
+        targets=targets,
+        approximator=approximator,
+        optimimzer_cls=optimimzer_cls,
+        device=device,
+    ).train(epochs=epochs, batch_size=batch_size)
+
+    return trainer.get_estimator(), trainer.get_loss()
+
+
+def train_torch_time_course(
+    features: pd.DataFrame,
+    targets: pd.DataFrame,
+    epochs: int,
+    batch_size: int | None = None,
+    approximator: nn.Module | None = None,
+    optimimzer_cls: Callable[[ParamsT], Adam] = Adam,
+    device: torch.device = DefaultDevice,
+) -> tuple[TorchTimeCourse, pd.Series]:
+    """Train a PyTorch time course estimator.
+
+    This function trains a neural network model to estimate time course data
+    using the provided features and targets. It supports both full-batch and
+    mini-batch training.
+
+    Examples:
+        >>> train_torch_time_course(features, targets, epochs=100)
+
+    Args:
+        features: DataFrame containing the input features for training
+        targets: DataFrame containing the target values for training
+        epochs: Number of training epochs
+        batch_size: Size of mini-batches for training (None for full-batch)
+        approximator: Predefined neural network model (None to use default LSTM)
+        optimimzer_cls: Optimizer class to use for training (default: Adam)
+        device: Device to run the training on (default: DefaultDevice)
+
+    Returns:
+        tuple[TorchTimeCourse, pd.Series]: Trained estimator and loss history
+
+    """
+    trainer = TorchTimeCourseTrainer(
+        features=features,
+        targets=targets,
+        approximator=approximator,
+        optimimzer_cls=optimimzer_cls,
+        device=device,
+    ).train(epochs=epochs, batch_size=batch_size)
+
+    return trainer.get_estimator(), trainer.get_loss()
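
An end-to-end sketch of the steady-state path with synthetic data; the parameter and observable names and the invertible toy mapping are made up for illustration:

```python
import numpy as np
import pandas as pd

from mxlpy.npe import train_torch_steady_state

rng = np.random.default_rng(seed=0)
# Sampled parameter sets (targets) and the steady states they produce
# (features); the estimator learns the inverse map observations -> parameters.
params = pd.DataFrame(rng.random((200, 2)), columns=["k1", "k2"])
observations = pd.DataFrame(
    {"s1": 2.0 * params["k1"], "s2": params["k1"] + params["k2"]}
)

estimator, losses = train_torch_steady_state(
    features=observations,
    targets=params,
    epochs=100,
    batch_size=32,
)
predicted = estimator.predict(observations)  # DataFrame with columns k1, k2
```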