modelbase2 0.2.0__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,286 @@
+# ruff: noqa: D100, D101, D102, D103, D104, D105, D106, D107, D200, D203, D400, D401
+
+import ast
+import inspect
+import textwrap
+from collections.abc import Callable
+from dataclasses import dataclass
+from typing import Any, cast
+
+import sympy
+
+from modelbase2.model import Model
+
+__all__ = ["Context", "SymbolicModel", "model_fn_to_sympy", "to_symbolic_model"]
+
+
+@dataclass
+class Context:
+    symbols: dict[str, sympy.Symbol | sympy.Expr]
+    caller: Callable
+
+
+@dataclass
+class SymbolicModel:
+    variables: dict[str, sympy.Symbol]
+    parameters: dict[str, sympy.Symbol]
+    eqs: list[sympy.Expr]
+
+
+def to_symbolic_model(model: Model) -> SymbolicModel:
+    cache = model._create_cache()  # noqa: SLF001
+
+    variables = dict(
+        zip(model.variables, sympy.symbols(list(model.variables)), strict=True)
+    )
+    parameters = dict(
+        zip(model.parameters, sympy.symbols(list(model.parameters)), strict=True)
+    )
+    symbols = variables | parameters
+
+    for k, v in model.derived.items():
+        symbols[k] = model_fn_to_sympy(v.fn, [symbols[i] for i in v.args])
+
+    rxns = {
+        k: model_fn_to_sympy(v.fn, [symbols[i] for i in v.args])
+        for k, v in model.reactions.items()
+    }
+
+    eqs: dict[str, sympy.Expr] = {}
+    for cpd, stoich in cache.stoich_by_cpds.items():
+        for rxn, stoich_value in stoich.items():
+            eqs[cpd] = (
+                eqs.get(cpd, sympy.Float(0.0)) + sympy.Float(stoich_value) * rxns[rxn]  # type: ignore
+            )
+
+    for cpd, dstoich in cache.dyn_stoich_by_cpds.items():
+        for rxn, der in dstoich.items():
+            # each dynamic coefficient contributes (derived stoichiometry) * (reaction rate)
+            eqs[cpd] = eqs.get(cpd, sympy.Float(0.0)) + (
+                model_fn_to_sympy(der.fn, [symbols[i] for i in der.args])
+                * rxns[rxn]  # type: ignore
+            )
+
+    return SymbolicModel(
+        variables=variables,
+        parameters=parameters,
+        eqs=[eqs[i] for i in cache.var_names],
+    )
+
+
+def model_fn_to_sympy(
+    fn: Callable, model_args: list[sympy.Symbol | sympy.Expr] | None = None
+) -> sympy.Expr:
+    source = textwrap.dedent(inspect.getsource(fn))
+
+    if not isinstance(fn_def := ast.parse(source).body[0], ast.FunctionDef):
+        msg = "Expected a function definition"
+        raise TypeError(msg)
+
+    fn_args = [str(arg.arg) for arg in fn_def.args.args]
+
+    sympy_expr = _handle_fn_body(
+        fn_def.body,
+        ctx=Context(
+            symbols={name: sympy.Symbol(name) for name in fn_args},
+            caller=fn,
+        ),
+    )
+
+    if model_args is not None:
+        sympy_expr = sympy_expr.subs(dict(zip(fn_args, model_args, strict=True)))
+
+    return cast(sympy.Expr, sympy_expr)
+
+
+def _handle_fn_body(body: list[ast.stmt], ctx: Context) -> sympy.Expr:
+    pieces = []
+    remaining_body = list(body)
+
+    while remaining_body:
+        node = remaining_body.pop(0)
+
+        if isinstance(node, ast.If):
+            condition = _handle_expr(node.test, ctx)
+            if_expr = _handle_fn_body(node.body, ctx)
+            pieces.append((if_expr, condition))
+
+            # If there's an else clause
+            if node.orelse:
+                # Check if it's an elif (an If node in orelse)
+                if len(node.orelse) == 1 and isinstance(node.orelse[0], ast.If):
+                    # Push the elif back to the beginning of remaining_body to process next
+                    remaining_body.insert(0, node.orelse[0])
+                else:
+                    # It's a regular else
+                    else_expr = _handle_fn_body(node.orelse, ctx)  # FIXME: copy here
+                    pieces.append((else_expr, True))
+                    break  # We're done with this chain
+
+            elif not remaining_body and any(
+                isinstance(n, ast.Return) for n in body[body.index(node) + 1 :]
+            ):
+                else_expr = _handle_fn_body(
+                    body[body.index(node) + 1 :], ctx
+                )  # FIXME: copy here
+                pieces.append((else_expr, True))
+
+        elif isinstance(node, ast.Return):
+            if (value := node.value) is None:
+                msg = "Return value cannot be None"
+                raise ValueError(msg)
+
+            expr = _handle_expr(value, ctx)
+            if not pieces:
+                return expr
+            pieces.append((expr, True))
+            break
+
+        elif isinstance(node, ast.Assign):
+            # Handle tuple assignments like c, d = a, b
+            if isinstance(node.targets[0], ast.Tuple):
+                # Handle tuple unpacking
+                target_elements = node.targets[0].elts
+
+                if isinstance(node.value, ast.Tuple):
+                    # Direct unpacking like c, d = a, b
+                    value_elements = node.value.elts
+                    for target, value_expr in zip(
+                        target_elements, value_elements, strict=True
+                    ):
+                        if isinstance(target, ast.Name):
+                            ctx.symbols[target.id] = _handle_expr(value_expr, ctx)
+                else:
+                    # Handle potential iterable unpacking
+                    value = _handle_expr(node.value, ctx)
+            else:
+                # Regular single assignment
+                if not isinstance(target := node.targets[0], ast.Name):
+                    msg = "Only single variable assignments are supported"
+                    raise TypeError(msg)
+                target_name = target.id
+                value = _handle_expr(node.value, ctx)
+                ctx.symbols[target_name] = value
+
+    # If we have pieces to combine into a Piecewise
+    if pieces:
+        return sympy.Piecewise(*pieces)
+
+    # If no return was found but we have assignments, return the last assigned variable
+    for node in reversed(body):
+        if isinstance(node, ast.Assign) and isinstance(node.targets[0], ast.Name):
+            target_name = node.targets[0].id
+            return ctx.symbols[target_name]
+
+    msg = "No return value found in function body"
+    raise ValueError(msg)
+
+
+def _handle_unaryop(node: ast.UnaryOp, ctx: Context) -> sympy.Expr:
+    left = _handle_expr(node.operand, ctx)
+    left = cast(Any, left)  # sympy's type stubs don't allow ops on symbols
+
+    if isinstance(node.op, ast.UAdd):
+        return +left
+    if isinstance(node.op, ast.USub):
+        return -left
+
+    msg = f"Operation {type(node.op).__name__} not implemented"
+    raise NotImplementedError(msg)
+
+
+def _handle_binop(node: ast.BinOp, ctx: Context) -> sympy.Expr:
+    left = _handle_expr(node.left, ctx)
+    left = cast(Any, left)  # sympy's type stubs don't allow ops on symbols
+
+    right = _handle_expr(node.right, ctx)
+    right = cast(Any, right)  # sympy's type stubs don't allow ops on symbols
+
+    if isinstance(node.op, ast.Add):
+        return left + right
+    if isinstance(node.op, ast.Sub):
+        return left - right
+    if isinstance(node.op, ast.Mult):
+        return left * right
+    if isinstance(node.op, ast.Div):
+        return left / right
+    if isinstance(node.op, ast.Pow):
+        return left**right
+    if isinstance(node.op, ast.Mod):
+        return left % right
+    if isinstance(node.op, ast.FloorDiv):
+        return left // right
+
+    msg = f"Operation {type(node.op).__name__} not implemented"
+    raise NotImplementedError(msg)
+
+
+def _handle_call(node: ast.Call, ctx: Context) -> sympy.Expr:
+    if not isinstance(callee := node.func, ast.Name):
+        msg = "Only function calls with names are supported"
+        raise TypeError(msg)
+
+    fn_name = str(callee.id)
+    parent_module = inspect.getmodule(ctx.caller)
+    fns = dict(inspect.getmembers(parent_module, predicate=callable))
+
+    return model_fn_to_sympy(
+        fns[fn_name],
+        model_args=[_handle_expr(i, ctx) for i in node.args],
+    )
+
+
+def _handle_name(node: ast.Name, ctx: Context) -> sympy.Symbol | sympy.Expr:
+    return ctx.symbols[node.id]
+
+
+def _handle_expr(node: ast.expr, ctx: Context) -> sympy.Expr:
+    if isinstance(node, ast.UnaryOp):
+        return _handle_unaryop(node, ctx)
+    if isinstance(node, ast.BinOp):
+        return _handle_binop(node, ctx)
+    if isinstance(node, ast.Name):
+        return _handle_name(node, ctx)
+    if isinstance(node, ast.Constant):
+        return node.value
+    if isinstance(node, ast.Compare):
+        # Handle chained comparisons like 1 < a < 2
+        left = cast(Any, _handle_expr(node.left, ctx))
+        comparisons = []
+
+        # Build all individual comparisons from the chain
+        prev_value = left
+        for op, comparator in zip(node.ops, node.comparators, strict=True):
+            right = cast(Any, _handle_expr(comparator, ctx))
+
+            if isinstance(op, ast.Gt):
+                comparisons.append(prev_value > right)
+            elif isinstance(op, ast.GtE):
+                comparisons.append(prev_value >= right)
+            elif isinstance(op, ast.Lt):
+                comparisons.append(prev_value < right)
+            elif isinstance(op, ast.LtE):
+                comparisons.append(prev_value <= right)
+            elif isinstance(op, ast.Eq):
+                comparisons.append(prev_value == right)
+            elif isinstance(op, ast.NotEq):
+                comparisons.append(prev_value != right)
+
+            prev_value = right
+
+        # Combine all comparisons with logical AND
+        result = comparisons[0]
+        for comp in comparisons[1:]:
+            result = sympy.And(result, comp)
+        return cast(sympy.Expr, result)
+    if isinstance(node, ast.Call):
+        return _handle_call(node, ctx)
+
+    # Handle conditional expressions (ternary operators)
+    if isinstance(node, ast.IfExp):
+        condition = _handle_expr(node.test, ctx)
+        if_true = _handle_expr(node.body, ctx)
+        if_false = _handle_expr(node.orelse, ctx)
+        return sympy.Piecewise((if_true, condition), (if_false, True))
+
+    msg = f"Expression type {type(node).__name__} not implemented"
+    raise NotImplementedError(msg)
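
The new module above turns plain Python rate functions into sympy expressions by walking their AST. A minimal usage sketch (the Michaelis-Menten rate law below is illustrative, not part of the package):

    import sympy

    def michaelis_menten(s: float, vmax: float, km: float) -> float:
        return vmax * s / (km + s)

    # Without model_args, the function's own argument names become sympy Symbols
    expr = model_fn_to_sympy(michaelis_menten)  # -> s*vmax/(km + s)

    # With model_args, those symbols are substituted in argument order,
    # e.g. binding the substrate to a model variable named "glc"
    expr_glc = model_fn_to_sympy(
        michaelis_menten,
        [sympy.Symbol("glc"), sympy.Float(1.0), sympy.Float(0.1)],
    )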
modelbase2/fit.py CHANGED
@@ -50,7 +50,7 @@ class SteadyStateResidualFn(Protocol):
         data: pd.Series,
         model: Model,
         y0: dict[str, float],
-        integrator: type[IntegratorProtocol],
+        integrator: Callable[[Callable, ArrayLike], IntegratorProtocol],
     ) -> float:
         """Calculate residual error between model steady state and experimental data."""
         ...
@@ -67,7 +67,7 @@ class TimeSeriesResidualFn(Protocol):
         data: pd.DataFrame,
         model: Model,
         y0: dict[str, float],
-        integrator: type[IntegratorProtocol],
+        integrator: Callable[[Callable, ArrayLike], IntegratorProtocol],
     ) -> float:
         """Calculate residual error between model time course and experimental data."""
         ...
@@ -101,7 +101,7 @@ def _steady_state_residual(
     data: pd.Series,
     model: Model,
     y0: dict[str, float] | None,
-    integrator: type[IntegratorProtocol],
+    integrator: Callable[[Callable, ArrayLike], IntegratorProtocol],
 ) -> float:
     """Calculate residual error between model steady state and experimental data.
 
@@ -148,7 +148,7 @@ def _time_course_residual(
     data: pd.DataFrame,
     model: Model,
     y0: dict[str, float],
-    integrator: type[IntegratorProtocol],
+    integrator: Callable[[Callable, ArrayLike], IntegratorProtocol],
 ) -> float:
     """Calculate residual error between model time course and experimental data.
 
@@ -187,7 +187,7 @@ def steady_state(
     y0: dict[str, float] | None = None,
     minimize_fn: MinimizeFn = _default_minimize_fn,
     residual_fn: SteadyStateResidualFn = _steady_state_residual,
-    integrator: type[IntegratorProtocol] = DefaultIntegrator,
+    integrator: Callable[[Callable, ArrayLike], IntegratorProtocol] = DefaultIntegrator,
 ) -> dict[str, float]:
     """Fit model parameters to steady-state experimental data.
 
@@ -241,7 +241,7 @@ def time_course(
     y0: dict[str, float] | None = None,
     minimize_fn: MinimizeFn = _default_minimize_fn,
     residual_fn: TimeSeriesResidualFn = _time_course_residual,
-    integrator: type[IntegratorProtocol] = DefaultIntegrator,
+    integrator: Callable[[Callable, ArrayLike], IntegratorProtocol] = DefaultIntegrator,
 ) -> dict[str, float]:
     """Fit model parameters to time course of experimental data.
 
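
The fitting entry points now take the integrator as a factory callable `(rhs, y0) -> IntegratorProtocol` rather than a concrete class, so any callable with that shape qualifies. A hedged sketch (whether `DefaultIntegrator` accepts extra keyword arguments such as `rtol` is an assumption; check its actual signature):

    from functools import partial

    from modelbase2.integrators import DefaultIntegrator

    # A partial application is itself a Callable[[Callable, ArrayLike], IntegratorProtocol],
    # so it can be passed as integrator=... to fit.steady_state or fit.time_course
    custom_integrator = partial(DefaultIntegrator, rtol=1e-10)  # rtol is an assumed kwarg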
modelbase2/model.py CHANGED
@@ -10,7 +10,6 @@ from __future__ import annotations
 import copy
 import inspect
 import itertools as it
-import math
 from dataclasses import dataclass, field
 from typing import TYPE_CHECKING, Self, cast
 
modelbase2/nnarchitectures.py ADDED
@@ -0,0 +1,128 @@
+"""Neural network architectures.
+
+This module provides implementations of neural network architectures used for mechanistic learning.
+
+"""
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, cast
+
+import torch
+from torch import nn
+
+if TYPE_CHECKING:
+    from collections.abc import Callable
+
+__all__ = ["DefaultDevice", "LSTMnn", "MLP"]
+
+DefaultDevice = torch.device("cpu")
+
+
+class MLP(nn.Module):
+    """Multilayer Perceptron (MLP) for surrogate modeling and neural posterior estimation.
+
+    Attributes:
+        net: Sequential neural network model.
+
+    Methods:
+        forward: Forward pass through the neural network.
+
+    """
+
+    def __init__(
+        self,
+        n_inputs: int,
+        layers: list[int],
+        activation: Callable | None = nn.ReLU(),
+        output_activation: Callable | None = None,
+    ) -> None:
+        """Initializes the MLP with the given number of inputs and list of (hidden) layers.
+
+        Args:
+            n_inputs (int): The number of input features.
+            layers (list[int]): The number of neurons in each hidden layer and in the output layer.
+            activation (Callable | None, default nn.ReLU()): The activation function applied after each hidden layer.
+            output_activation (Callable | None, default None): The activation function applied after the final (output) layer.
+
+        For instance, MLP(10, layers=[50, 50, 10]) initializes a neural network with the following architecture:
+        - Linear layer with `n_inputs` inputs and 50 outputs
+        - ReLU activation
+        - Linear layer with 50 inputs and 50 outputs
+        - ReLU activation
+        - Linear layer with 50 inputs and 10 outputs
+
+        The weights of the linear layers are initialized with a normal distribution
+        (mean=0, std=0.1) and the biases are initialized to 0.
+
+        """
+        super().__init__()
+        self.layers = layers
+        self.activation = activation
+        self.output_activation = output_activation
+
+        levels = []
+        previous_neurons = n_inputs
+
+        for idx, neurons in enumerate(self.layers):
+            if idx == (len(self.layers) - 1):
+                levels.append(nn.Linear(previous_neurons, neurons))
+
+                if self.output_activation:
+                    levels.append(self.output_activation)
+
+            else:
+                levels.append(nn.Linear(previous_neurons, neurons))
+
+                if self.activation:
+                    levels.append(self.activation)
+
+            previous_neurons = neurons
+
+        self.net = nn.Sequential(*levels)
+
+        for m in self.net.modules():
+            if isinstance(m, nn.Linear):
+                nn.init.normal_(m.weight, mean=0, std=0.1)
+                nn.init.constant_(m.bias, val=0)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Forward pass through the neural network.
+
+        Args:
+            x: Input tensor.
+
+        Returns:
+            torch.Tensor: Output tensor.
+
+        """
+        return self.net(x)
+
+
+class LSTMnn(nn.Module):
+    """Default LSTM neural network model for time-series approximation."""
+
+    def __init__(self, n_inputs: int, n_outputs: int, n_hidden: int) -> None:
+        """Initializes the neural network model.
+
+        Args:
+            n_inputs (int): Number of input features.
+            n_outputs (int): Number of output features.
+            n_hidden (int): Number of hidden units in the LSTM layer.
+
+        """
+        super().__init__()
+
+        self.n_hidden = n_hidden
+
+        self.lstm = nn.LSTM(n_inputs, n_hidden)
+        self.to_out = nn.Linear(n_hidden, n_outputs)
+
+        nn.init.normal_(self.to_out.weight, mean=0, std=0.1)
+        nn.init.constant_(self.to_out.bias, val=0)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Forward pass through the neural network."""
+        # lstm_out, (hidden_state, cell_state)
+        _, (hn, _) = self.lstm(x)
+        return cast(torch.Tensor, self.to_out(hn[-1]))  # Use last hidden state
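
The shapes described in the docstrings can be exercised directly; a minimal sketch:

    import torch

    # MLP(10, layers=[50, 50, 10]): two ReLU hidden layers of 50 neurons, linear 10-neuron output
    net = MLP(n_inputs=10, layers=[50, 50, 10])
    y = net(torch.rand(4, 10))  # batch of 4 samples -> shape (4, 10)

    # LSTMnn expects (sequence, batch, features) input (torch.nn.LSTM default layout)
    # and maps the last hidden state to the outputs
    lstm = LSTMnn(n_inputs=3, n_outputs=2, n_hidden=8)
    out = lstm(torch.rand(20, 4, 3))  # -> shape (4, 2)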
modelbase2/npe.py CHANGED
@@ -4,10 +4,6 @@ This module provides classes and functions for training neural network models to
 parameters in metabolic models. It includes functionality for both steady-state and
 time-series data.
 
-Classes:
-    DefaultSSAproximator: Default neural network model for steady-state approximation
-    DefaultTimeSeriesApproximator: Default neural network model for time-series approximation
-
 Functions:
     train_torch_surrogate: Train a PyTorch surrogate model
     train_torch_time_course_estimator: Train a PyTorch time course estimator
@@ -18,9 +14,6 @@ from __future__ import annotations
 __all__ = [
     "AbstractEstimator",
     "DefaultCache",
-    "DefaultDevice",
-    "DefaultSSAproximator",
-    "DefaultTimeSeriesApproximator",
     "TorchSSEstimator",
     "TorchTimeCourseEstimator",
     "train_torch_ss_estimator",
@@ -30,7 +23,7 @@ __all__ = [
 from abc import abstractmethod
 from dataclasses import dataclass
 from pathlib import Path
-from typing import cast
+from typing import TYPE_CHECKING, cast
 
 import numpy as np
 import pandas as pd
@@ -39,75 +32,15 @@ import tqdm
 from torch import nn
 from torch.optim.adam import Adam
 
+from modelbase2.nnarchitectures import MLP, DefaultDevice, LSTMnn
 from modelbase2.parallel import Cache
 
-DefaultDevice = torch.device("cpu")
-DefaultCache = Cache(Path(".cache"))
-
-
-class DefaultSSAproximator(nn.Module):
-    """Default neural network model for steady-state approximation."""
-
-    def __init__(self, n_inputs: int, n_outputs: int, n_hidden: int = 50) -> None:
-        """Initializes the neural network with the specified number of inputs and outputs.
-
-        Args:
-            n_inputs (int): The number of input features.
-            n_outputs (int): The number of output features.
-            n_hidden (int): The number of hidden units in the fully connected layers
-
-        The network consists of three fully connected layers with ReLU activations in between.
-        The weights of the linear layers are initialized with a normal distribution (mean=0, std=0.1),
-        and the biases are initialized to zero.
-
-        """
-        super().__init__()
-
-        self.net = nn.Sequential(
-            nn.Linear(n_inputs, n_hidden),
-            nn.ReLU(),
-            nn.Linear(n_hidden, n_hidden),
-            nn.ReLU(),
-            nn.Linear(n_hidden, n_outputs),
-        )
-
-        for m in self.net.modules():
-            if isinstance(m, nn.Linear):
-                nn.init.normal_(m.weight, mean=0, std=0.1)
-                nn.init.constant_(m.bias, val=0)
+if TYPE_CHECKING:
+    from collections.abc import Callable
 
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        """Forward pass through the neural network."""
-        return cast(torch.Tensor, self.net(x))
+    from torch.optim.optimizer import ParamsT
 
-
-class DefaultTimeSeriesApproximator(nn.Module):
-    """Default neural network model for time-series approximation."""
-
-    def __init__(self, n_inputs: int, n_outputs: int, n_hidden: int) -> None:
-        """Initializes the neural network model.
-
-        Args:
-            n_inputs (int): Number of input features.
-            n_outputs (int): Number of output features.
-            n_hidden (int): Number of hidden units in the LSTM layer.
-
-        """
-        super().__init__()
-
-        self.n_hidden = n_hidden
-
-        self.lstm = nn.LSTM(n_inputs, n_hidden)
-        self.to_out = nn.Linear(n_hidden, n_outputs)
-
-        nn.init.normal_(self.to_out.weight, mean=0, std=0.1)
-        nn.init.constant_(self.to_out.bias, val=0)
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        """Forward pass through the neural network."""
-        # lstm_out, (hidden_state, cell_state)
-        _, (hn, _) = self.lstm(x)
-        return cast(torch.Tensor, self.to_out(hn[-1]))  # Use last hidden state
+DefaultCache = Cache(Path(".cache"))
 
 
 @dataclass(kw_only=True)
@@ -212,7 +145,7 @@ def train_torch_ss_estimator(
     epochs: int,
     batch_size: int | None = None,
     approximator: nn.Module | None = None,
-    optimimzer_cls: type[Adam] = Adam,
+    optimimzer_cls: Callable[[ParamsT], Adam] = Adam,
     device: torch.device = DefaultDevice,
 ) -> tuple[TorchSSEstimator, pd.Series]:
     """Train a PyTorch steady state estimator.
@@ -229,7 +162,7 @@ def train_torch_ss_estimator(
         targets: DataFrame containing the target values for training
         epochs: Number of training epochs
         batch_size: Size of mini-batches for training (None for full-batch)
-        approximator: Predefined neural network model (None to use default)
+        approximator: Predefined neural network model (None to use default MLP)
         optimimzer_cls: Optimizer class to use for training (default: Adam)
         device: Device to run the training on (default: DefaultDevice)
 
@@ -238,10 +171,10 @@ def train_torch_ss_estimator(
 
     """
     if approximator is None:
-        approximator = DefaultSSAproximator(
-            n_inputs=len(features.columns),
-            n_outputs=len(targets.columns),
-            n_hidden=max(2 * len(features.columns) * len(targets.columns), 10),
+        n_hidden = max(2 * len(features.columns) * len(targets.columns), 10)
+        n_outputs = len(targets.columns)
+        approximator = MLP(
+            n_inputs=len(features.columns), layers=[n_hidden, n_hidden, n_outputs]
         ).to(device)
 
     features_ = torch.Tensor(features.to_numpy(), device=device)
@@ -278,7 +211,7 @@ def train_torch_time_course_estimator(
     epochs: int,
     batch_size: int | None = None,
     approximator: nn.Module | None = None,
-    optimimzer_cls: type[Adam] = Adam,
+    optimimzer_cls: Callable[[ParamsT], Adam] = Adam,
     device: torch.device = DefaultDevice,
 ) -> tuple[TorchTimeCourseEstimator, pd.Series]:
     """Train a PyTorch time course estimator.
@@ -295,7 +228,7 @@ def train_torch_time_course_estimator(
         targets: DataFrame containing the target values for training
         epochs: Number of training epochs
         batch_size: Size of mini-batches for training (None for full-batch)
-        approximator: Predefined neural network model (None to use default)
+        approximator: Predefined neural network model (None to use default LSTM)
         optimimzer_cls: Optimizer class to use for training (default: Adam)
         device: Device to run the training on (default: DefaultDevice)
 
@@ -304,7 +237,7 @@ def train_torch_time_course_estimator(
 
     """
     if approximator is None:
-        approximator = DefaultTimeSeriesApproximator(
+        approximator = LSTMnn(
             n_inputs=len(features.columns),
             n_outputs=len(targets.columns),
             n_hidden=1,
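
Since the default networks now come from `modelbase2.nnarchitectures`, a custom layout can be supplied in their place. A sketch assuming `features` and `targets` are the training DataFrames described in the docstrings (the 64/64 hidden layout is an arbitrary choice for illustration):

    from modelbase2.nnarchitectures import MLP, DefaultDevice

    approximator = MLP(
        n_inputs=len(features.columns),
        layers=[64, 64, len(targets.columns)],
    ).to(DefaultDevice)

    estimator, losses = train_torch_ss_estimator(
        features=features,
        targets=targets,
        epochs=100,
        approximator=approximator,
    )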
modelbase2/plot.py CHANGED
@@ -399,7 +399,10 @@ def lines(
     fig, ax = _default_fig_ax(ax=ax, grid=grid)
     ax.plot(x.index, x)
     _default_labels(ax, xlabel=x.index.name, ylabel=None)
-    ax.legend(x.columns)
+    if isinstance(x, pd.Series):
+        ax.legend([str(x.name)])
+    else:
+        ax.legend(x.columns)
     return fig, ax
 
 
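
The `lines` fix distinguishes `Series` from `DataFrame` input when building the legend; a small sketch (assuming `lines` takes the data as its first argument, as the hunk suggests):

    import pandas as pd

    s = pd.Series([0.0, 0.5, 0.8], index=[0.0, 1.0, 2.0], name="ATP")
    fig, ax = lines(s)  # legend now reads "ATP"; a DataFrame still uses its columns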
modelbase2/simulator.py CHANGED
@@ -21,6 +21,8 @@ from modelbase2.integrators import DefaultIntegrator
 __all__ = ["Simulator"]
 
 if TYPE_CHECKING:
+    from collections.abc import Callable
+
     from modelbase2.model import Model
     from modelbase2.types import ArrayLike, IntegratorProtocol
 
@@ -83,7 +85,9 @@ class Simulator:
         self,
         model: Model,
         y0: dict[str, float] | None = None,
-        integrator: type[IntegratorProtocol] = DefaultIntegrator,
+        integrator: Callable[
+            [Callable, ArrayLike], IntegratorProtocol
+        ] = DefaultIntegrator,
         *,
         test_run: bool = True,
     ) -> None:
@@ -93,7 +97,7 @@ class Simulator:
             model (Model): The model to be simulated.
             y0 (dict[str, float] | None, optional): Initial conditions for the model variables.
                 If None, the initial conditions are obtained from the model. Defaults to None.
-            integrator (type[IntegratorProtocol], optional): The integrator to use for the simulation.
+            integrator (Callable[[Callable, ArrayLike], IntegratorProtocol], optional): The integrator to use for the simulation.
                 Defaults to DefaultIntegrator.
             test_run (bool, optional): If True, performs a test run to ensure the model's methods
                 (get_full_concs, get_fluxes, get_right_hand_side) work correctly with the initial conditions.
@@ -104,7 +108,7 @@ class Simulator:
         y0 = model.get_initial_conditions() if y0 is None else y0
         self.y0 = [y0[k] for k in model.get_variable_names()]
 
-        self.integrator = integrator(self.model, y0=self.y0)
+        self.integrator = integrator(self.model, self.y0)
         self.concs = None
         self.args = None
         self.simulation_parameters = None
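
As in `fit.py`, `Simulator` now accepts any `(rhs, y0)` factory returning an `IntegratorProtocol`, and invokes it with both arguments positionally. A hedged sketch (`model` is assumed to be an existing `Model` instance):

    from modelbase2.integrators import DefaultIntegrator
    from modelbase2.simulator import Simulator

    # Any callable of two positional arguments works where the class itself did
    sim = Simulator(model, integrator=lambda rhs, y0: DefaultIntegrator(rhs, y0))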