modelbase2 0.2.0__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- modelbase2/__init__.py +12 -1
- modelbase2/distributions.py +33 -0
- modelbase2/experimental/__init__.py +2 -0
- modelbase2/experimental/_backup.py +1017 -0
- modelbase2/experimental/strikepy.py +562 -0
- modelbase2/experimental/symbolic.py +286 -0
- modelbase2/fit.py +6 -6
- modelbase2/model.py +0 -1
- modelbase2/nnarchitectures.py +128 -0
- modelbase2/npe.py +15 -82
- modelbase2/plot.py +4 -1
- modelbase2/simulator.py +7 -3
- modelbase2/surrogates/__init__.py +1 -2
- modelbase2/surrogates/_poly.py +32 -5
- modelbase2/surrogates/_torch.py +8 -64
- modelbase2/surrogates.py +7 -1
- {modelbase2-0.2.0.dist-info → modelbase2-0.4.0.dist-info}/METADATA +14 -1
- {modelbase2-0.2.0.dist-info → modelbase2-0.4.0.dist-info}/RECORD +20 -16
- {modelbase2-0.2.0.dist-info → modelbase2-0.4.0.dist-info}/WHEEL +0 -0
- {modelbase2-0.2.0.dist-info → modelbase2-0.4.0.dist-info}/licenses/LICENSE +0 -0
modelbase2/experimental/symbolic.py
@@ -0,0 +1,286 @@
|
|
1
|
+
# ruff: noqa: D100, D101, D102, D103, D104, D105, D106, D107, D200, D203, D400, D401
|
2
|
+
|
3
|
+
import ast
|
4
|
+
import inspect
|
5
|
+
import textwrap
|
6
|
+
from collections.abc import Callable
|
7
|
+
from dataclasses import dataclass
|
8
|
+
from typing import Any, cast
|
9
|
+
|
10
|
+
import sympy
|
11
|
+
|
12
|
+
from modelbase2.model import Model
|
13
|
+
|
14
|
+
__all__ = ["Context", "SymbolicModel", "model_fn_to_sympy", "to_symbolic_model"]


@dataclass
class Context:
    """Mutable state threaded through the translation of one Python function.

    Attributes:
        symbols: Python name -> sympy value for every name visible so far
            (the function's parameters plus locally assigned variables).
        caller: The function being translated; calls inside its body are
            resolved against the module that defines it.

    """

    symbols: dict[str, sympy.Symbol | sympy.Expr]
    caller: Callable
|
21
|
+
|
22
|
+
|
23
|
+
@dataclass
class SymbolicModel:
    """Symbolic ODE system produced by `to_symbolic_model`.

    Attributes:
        variables: Model variable name -> sympy symbol.
        parameters: Model parameter name -> sympy symbol.
        eqs: One right-hand-side expression per variable, in the order used
            by `to_symbolic_model` (the model cache's variable names).

    """

    variables: dict[str, sympy.Symbol]
    parameters: dict[str, sympy.Symbol]
    eqs: list[sympy.Expr]
|
28
|
+
|
29
|
+
|
30
|
+
def to_symbolic_model(model: Model) -> SymbolicModel:
    """Convert a model into a symbolic ODE system.

    Every variable and parameter becomes a sympy symbol of the same name.
    Derived quantities and reaction rates are translated from their Python
    source via `model_fn_to_sympy`. The right-hand side of each variable is
    the sum of ``stoichiometry * rate`` over its reactions. This includes
    reactions whose stoichiometric factor is itself a function (dynamic
    stoichiometry).

    Args:
        model: The model to convert.

    Returns:
        A `SymbolicModel` whose `eqs` follow the order of the model cache's
        variable names. Variables without any stoichiometry entry get a
        right-hand side of ``0``.

    """
    cache = model._create_cache()  # noqa: SLF001

    variables = dict(
        zip(model.variables, sympy.symbols(list(model.variables)), strict=True)
    )
    parameters = dict(
        zip(model.parameters, sympy.symbols(list(model.parameters)), strict=True)
    )
    symbols = variables | parameters

    # Derived values are translated in `model.derived` iteration order. Each
    # one can only reference variables, parameters and derived values that
    # were translated before it.
    for name, derived in model.derived.items():
        symbols[name] = model_fn_to_sympy(
            derived.fn, [symbols[arg] for arg in derived.args]
        )

    rxns = {
        name: model_fn_to_sympy(rxn.fn, [symbols[arg] for arg in rxn.args])
        for name, rxn in model.reactions.items()
    }

    zero = sympy.Float(0.0)
    eqs: dict[str, sympy.Expr] = {}

    for cpd, stoich in cache.stoich_by_cpds.items():
        for rxn, stoich_value in stoich.items():
            eqs[cpd] = eqs.get(cpd, zero) + sympy.Float(stoich_value) * rxns[rxn]  # type: ignore

    for cpd, dstoich in cache.dyn_stoich_by_cpds.items():
        for rxn, der in dstoich.items():
            # The stoichiometric factor is a function of model quantities:
            # translate it first, then scale it by the reaction rate.
            # (Multiplying the *argument list* by the rate is a TypeError.)
            factor = model_fn_to_sympy(der.fn, [symbols[arg] for arg in der.args])
            eqs[cpd] = eqs.get(cpd, zero) + factor * rxns[rxn]  # type: ignore

    return SymbolicModel(
        variables=variables,
        parameters=parameters,
        eqs=[eqs.get(name, zero) for name in cache.var_names],
    )
|
68
|
+
|
69
|
+
|
70
|
+
def model_fn_to_sympy(
    fn: Callable, model_args: list[sympy.Symbol | sympy.Expr] | None = None
) -> sympy.Expr:
    """Translate a plain Python function into a sympy expression.

    The function's source is parsed with `ast` and its body is interpreted
    symbolically. Each positional parameter is represented by a sympy symbol
    of the same name.

    Args:
        fn: Function to translate. Its source must be retrievable with
            `inspect.getsource` and must start with a ``def`` statement.
        model_args: Optional expressions to substitute for the function's
            positional parameters, in order. Substitution is simultaneous,
            so an argument may itself contain symbols named like the
            parameters (e.g. calling ``def f(s, k)`` as ``f(k, s)``).

    Returns:
        The translated expression. Literal results (e.g. ``return 1.0``) are
        converted to sympy numbers.

    Raises:
        TypeError: If the parsed source does not start with a function
            definition.
        ValueError: If `model_args` does not match the number of parameters.

    """
    source = textwrap.dedent(inspect.getsource(fn))

    if not isinstance(fn_def := ast.parse(source).body[0], ast.FunctionDef):
        msg = "Expected a function definition"
        raise TypeError(msg)

    # Positional-only parameters (before `/`) are positional arguments too.
    fn_args = [
        str(arg.arg) for arg in (*fn_def.args.posonlyargs, *fn_def.args.args)
    ]

    sympy_expr = _handle_fn_body(
        fn_def.body,
        ctx=Context(
            symbols={name: sympy.Symbol(name) for name in fn_args},
            caller=fn,
        ),
    )
    # Literals are handled as raw Python values, so a constant function would
    # return a plain float/int without `.subs`. `strict=True` converts numbers
    # and bools but refuses to parse strings as sympy code.
    sympy_expr = sympy.sympify(sympy_expr, strict=True)

    if model_args is not None:
        sympy_expr = sympy_expr.subs(
            dict(zip(fn_args, model_args, strict=True)), simultaneous=True
        )

    return cast(sympy.Expr, sympy_expr)
|
93
|
+
|
94
|
+
|
95
|
+
def _branch_ctx(ctx: Context) -> Context:
    """Return a copy of ``ctx`` with an independent symbol table.

    Branch bodies of ``if``/``elif``/``else`` are translated in such a copy.
    This keeps names assigned inside one branch from leaking into sibling
    branches or into the statements that follow the ``if``.
    """
    return Context(symbols=dict(ctx.symbols), caller=ctx.caller)


def _handle_assign(node: ast.Assign, ctx: Context) -> None:
    """Translate an assignment statement by updating ``ctx.symbols`` in place.

    Supports ``x = expr`` and tuple unpacking into plain names, either
    ``a, b = e1, e2`` or ``a, b = <iterable value>``. Only the first target
    of a chained assignment (``a = b = expr``) is bound.

    Raises:
        TypeError: If a target is not a plain name.
        NotImplementedError: If a tuple target is assigned a non-iterable value.
        ValueError: If the number of targets and values differ.

    """
    target = node.targets[0]
    msg = "Only single variable assignments are supported"

    if not isinstance(target, ast.Tuple):
        if not isinstance(target, ast.Name):
            raise TypeError(msg)
        ctx.symbols[target.id] = _handle_expr(node.value, ctx)
        return

    names = [element.id for element in target.elts if isinstance(element, ast.Name)]
    if len(names) != len(target.elts):
        raise TypeError(msg)

    if isinstance(node.value, ast.Tuple):
        # Evaluate the whole right-hand side before binding any name, as
        # Python does, so that e.g. ``a, b = b, a`` swaps correctly.
        values = [_handle_expr(element, ctx) for element in node.value.elts]
    else:
        value = _handle_expr(node.value, ctx)
        try:
            values = list(cast(Any, value))
        except TypeError:
            unpack_msg = (
                f"Cannot unpack a value of type {type(value).__name__} "
                "into multiple names"
            )
            raise NotImplementedError(unpack_msg) from None

    ctx.symbols.update(dict(zip(names, values, strict=True)))


def _handle_fn_body(body: list[ast.stmt], ctx: Context) -> sympy.Expr:
    """Translate a list of statements into a single sympy expression.

    Handled statements:

    * ``return <expr>`` ends the body.
    * ``if``/``elif``/``else`` chains become pieces of a ``sympy.Piecewise``,
      each guarded by its translated condition. After an ``if`` without
      ``else``, the following statements supply the remaining pieces.
    * Assignments bind names in ``ctx.symbols`` (see ``_handle_assign``).

    Any other statement (e.g. a docstring expression) is skipped. Each branch
    body is translated with its own copy of the symbol table.

    NOTE(review): a branch body without ``return`` contributes the value of
    its last assignment as its piece. That is only correct when this variable
    is what the function ultimately returns.

    Args:
        body: Statements to translate, in source order.
        ctx: Translation context; assignments update ``ctx.symbols``.

    Returns:
        The returned expression, a ``Piecewise`` for conditional returns,
        or the value of the last simple assignment if there is no ``return``.

    Raises:
        ValueError: On a bare ``return`` or if no value can be determined.

    """
    pieces: list[tuple[Any, Any]] = []
    remaining_body = list(body)

    while remaining_body:
        node = remaining_body.pop(0)

        if isinstance(node, ast.If):
            condition = _handle_expr(node.test, ctx)
            if_expr = _handle_fn_body(node.body, _branch_ctx(ctx))
            pieces.append((if_expr, condition))

            if len(node.orelse) == 1 and isinstance(node.orelse[0], ast.If):
                # ``elif``: handle it next, exactly like an ``if`` directly
                # following this one.
                remaining_body.insert(0, node.orelse[0])
            elif node.orelse:
                # Plain ``else``: its value is the final catch-all piece.
                else_expr = _handle_fn_body(node.orelse, _branch_ctx(ctx))
                pieces.append((else_expr, True))
                break
            # Without ``else`` the statements that follow provide the fallback.

        elif isinstance(node, ast.Return):
            if (value := node.value) is None:
                msg = "Return value cannot be None"
                raise ValueError(msg)

            expr = _handle_expr(value, ctx)
            if not pieces:
                return expr
            pieces.append((expr, True))
            break

        elif isinstance(node, ast.Assign):
            _handle_assign(node, ctx)

    if pieces:
        return sympy.Piecewise(*pieces)

    # No ``return`` found: fall back to the value of the last simple assignment.
    for stmt in reversed(body):
        if isinstance(stmt, ast.Assign) and isinstance(stmt.targets[0], ast.Name):
            return ctx.symbols[stmt.targets[0].id]

    msg = "No return value found in function body"
    raise ValueError(msg)
|
176
|
+
|
177
|
+
|
178
|
+
def _handle_unaryop(node: ast.UnaryOp, ctx: Context) -> sympy.Expr:
|
179
|
+
left = _handle_expr(node.operand, ctx)
|
180
|
+
left = cast(Any, left) # stupid sympy types don't allow ops on symbols
|
181
|
+
|
182
|
+
if isinstance(node.op, ast.UAdd):
|
183
|
+
return +left
|
184
|
+
if isinstance(node.op, ast.USub):
|
185
|
+
return -left
|
186
|
+
|
187
|
+
msg = f"Operation {type(node.op).__name__} not implemented"
|
188
|
+
raise NotImplementedError(msg)
|
189
|
+
|
190
|
+
|
191
|
+
def _handle_binop(node: ast.BinOp, ctx: Context) -> sympy.Expr:
|
192
|
+
left = _handle_expr(node.left, ctx)
|
193
|
+
left = cast(Any, left) # stupid sympy types don't allow ops on symbols
|
194
|
+
|
195
|
+
right = _handle_expr(node.right, ctx)
|
196
|
+
right = cast(Any, right) # stupid sympy types don't allow ops on symbols
|
197
|
+
|
198
|
+
if isinstance(node.op, ast.Add):
|
199
|
+
return left + right
|
200
|
+
if isinstance(node.op, ast.Sub):
|
201
|
+
return left - right
|
202
|
+
if isinstance(node.op, ast.Mult):
|
203
|
+
return left * right
|
204
|
+
if isinstance(node.op, ast.Div):
|
205
|
+
return left / right
|
206
|
+
if isinstance(node.op, ast.Pow):
|
207
|
+
return left**right
|
208
|
+
if isinstance(node.op, ast.Mod):
|
209
|
+
return left % right
|
210
|
+
if isinstance(node.op, ast.FloorDiv):
|
211
|
+
return left // right
|
212
|
+
|
213
|
+
msg = f"Operation {type(node.op).__name__} not implemented"
|
214
|
+
raise NotImplementedError(msg)
|
215
|
+
|
216
|
+
|
217
|
+
# Sympy counterparts for callables whose Python source cannot be translated:
# builtins (abs/min/max) and C-implemented math functions such as `math.exp`
# or numpy ufuncs imported into the model's module.
_SYMPY_EQUIVALENTS: dict[str, Callable[..., Any]] = {
    "abs": sympy.Abs,
    "min": sympy.Min,
    "max": sympy.Max,
    "exp": sympy.exp,
    "log": sympy.log,
    "sqrt": sympy.sqrt,
    "sin": sympy.sin,
    "cos": sympy.cos,
    "tan": sympy.tan,
}


def _handle_call(node: ast.Call, ctx: Context) -> sympy.Expr:
    """Translate a call ``f(a, b, ...)`` where ``f`` is a plain name.

    The name is resolved in the module that defines ``ctx.caller``.
    Python functions are translated recursively via ``model_fn_to_sympy``,
    with the translated positional arguments substituted for their
    parameters. Names listed in ``_SYMPY_EQUIVALENTS`` are mapped to the
    matching sympy function in two cases: when the module does not define
    the name (e.g. the builtins ``abs``/``min``/``max``), or when it binds
    the name to a non-Python callable (e.g. ``math.exp``).

    Raises:
        TypeError: If the callee is not a plain name (e.g. ``np.exp(x)``).
        KeyError: If the callee cannot be resolved to a callable.

    """
    if not isinstance(callee := node.func, ast.Name):
        msg = "Only function calls with names are supported"
        raise TypeError(msg)

    fn_name = str(callee.id)
    parent_module = inspect.getmodule(ctx.caller)
    fn = getattr(parent_module, fn_name, None)

    sympy_fn = _SYMPY_EQUIVALENTS.get(fn_name)
    if (
        sympy_fn is not None
        and not inspect.isfunction(fn)
        and (fn is None or callable(fn))
    ):
        args = [_handle_expr(arg, ctx) for arg in node.args]
        return cast(sympy.Expr, sympy_fn(*args))

    if fn is None or not callable(fn):
        raise KeyError(fn_name)

    return model_fn_to_sympy(
        fn,
        model_args=[_handle_expr(arg, ctx) for arg in node.args],
    )
|
230
|
+
|
231
|
+
|
232
|
+
def _handle_name(node: ast.Name, ctx: Context) -> sympy.Symbol | sympy.Expr:
    """Return the sympy value currently bound to a Python name.

    Raises:
        KeyError: If the name is neither a parameter nor a previously
            assigned local variable.

    """
    symbol_table = ctx.symbols
    name = node.id
    return symbol_table[name]
|
234
|
+
|
235
|
+
|
236
|
+
def _handle_compare(node: ast.Compare, ctx: Context) -> sympy.Expr:
    """Translate a (possibly chained) comparison into a sympy relational.

    ``a < b <= c`` becomes ``And(a < b, b <= c)``, mirroring Python's
    chained-comparison semantics.

    Raises:
        NotImplementedError: For ``is``, ``is not``, ``in`` and ``not in``.

    """
    prev_value: Any = _handle_expr(node.left, ctx)
    comparisons: list[Any] = []

    for op, comparator in zip(node.ops, node.comparators, strict=True):
        right: Any = _handle_expr(comparator, ctx)

        if isinstance(op, ast.Gt):
            comparisons.append(prev_value > right)
        elif isinstance(op, ast.GtE):
            comparisons.append(prev_value >= right)
        elif isinstance(op, ast.Lt):
            comparisons.append(prev_value < right)
        elif isinstance(op, ast.LtE):
            comparisons.append(prev_value <= right)
        elif isinstance(op, ast.Eq):
            # `==` on sympy objects is structural equality and returns a plain
            # Python bool; sympy.Eq builds the symbolic condition instead.
            comparisons.append(sympy.Eq(prev_value, right))
        elif isinstance(op, ast.NotEq):
            comparisons.append(sympy.Ne(prev_value, right))
        else:
            msg = f"Comparison {type(op).__name__} not implemented"
            raise NotImplementedError(msg)

        prev_value = right

    # ast.Compare always has at least one operator, so comparisons is non-empty.
    result = comparisons[0]
    for comp in comparisons[1:]:
        result = sympy.And(result, comp)
    return cast(sympy.Expr, result)


def _handle_expr(node: ast.expr, ctx: Context) -> sympy.Expr:
    """Translate a single Python expression node into sympy.

    Supported: unary/binary arithmetic, names, literals (returned as the raw
    Python value), comparisons including chains, ``and``/``or``, calls to
    named functions, and conditional expressions (``a if c else b``).

    Raises:
        NotImplementedError: For any other expression type.

    """
    if isinstance(node, ast.UnaryOp):
        return _handle_unaryop(node, ctx)
    if isinstance(node, ast.BinOp):
        return _handle_binop(node, ctx)
    if isinstance(node, ast.Name):
        return _handle_name(node, ctx)
    if isinstance(node, ast.Constant):
        return node.value
    if isinstance(node, ast.Compare):
        return _handle_compare(node, ctx)
    if isinstance(node, ast.BoolOp):
        # Logical ``and``/``or``. This is meaningful for boolean operands such
        # as comparisons; Python's value-returning short-circuit is not modelled.
        operands = [_handle_expr(value, ctx) for value in node.values]
        combine = sympy.And if isinstance(node.op, ast.And) else sympy.Or
        return cast(sympy.Expr, combine(*operands))
    if isinstance(node, ast.Call):
        return _handle_call(node, ctx)

    # Conditional expressions (ternary operator).
    if isinstance(node, ast.IfExp):
        condition = _handle_expr(node.test, ctx)
        if_true = _handle_expr(node.body, ctx)
        if_false = _handle_expr(node.orelse, ctx)
        return sympy.Piecewise((if_true, condition), (if_false, True))

    msg = f"Expression type {type(node).__name__} not implemented"
    raise NotImplementedError(msg)
|
modelbase2/fit.py
CHANGED
@@ -50,7 +50,7 @@ class SteadyStateResidualFn(Protocol):
|
|
50
50
|
data: pd.Series,
|
51
51
|
model: Model,
|
52
52
|
y0: dict[str, float],
|
53
|
-
integrator:
|
53
|
+
integrator: Callable[[Callable, ArrayLike], IntegratorProtocol],
|
54
54
|
) -> float:
|
55
55
|
"""Calculate residual error between model steady state and experimental data."""
|
56
56
|
...
|
@@ -67,7 +67,7 @@ class TimeSeriesResidualFn(Protocol):
|
|
67
67
|
data: pd.DataFrame,
|
68
68
|
model: Model,
|
69
69
|
y0: dict[str, float],
|
70
|
-
integrator:
|
70
|
+
integrator: Callable[[Callable, ArrayLike], IntegratorProtocol],
|
71
71
|
) -> float:
|
72
72
|
"""Calculate residual error between model time course and experimental data."""
|
73
73
|
...
|
@@ -101,7 +101,7 @@ def _steady_state_residual(
|
|
101
101
|
data: pd.Series,
|
102
102
|
model: Model,
|
103
103
|
y0: dict[str, float] | None,
|
104
|
-
integrator:
|
104
|
+
integrator: Callable[[Callable, ArrayLike], IntegratorProtocol],
|
105
105
|
) -> float:
|
106
106
|
"""Calculate residual error between model steady state and experimental data.
|
107
107
|
|
@@ -148,7 +148,7 @@ def _time_course_residual(
|
|
148
148
|
data: pd.DataFrame,
|
149
149
|
model: Model,
|
150
150
|
y0: dict[str, float],
|
151
|
-
integrator:
|
151
|
+
integrator: Callable[[Callable, ArrayLike], IntegratorProtocol],
|
152
152
|
) -> float:
|
153
153
|
"""Calculate residual error between model time course and experimental data.
|
154
154
|
|
@@ -187,7 +187,7 @@ def steady_state(
|
|
187
187
|
y0: dict[str, float] | None = None,
|
188
188
|
minimize_fn: MinimizeFn = _default_minimize_fn,
|
189
189
|
residual_fn: SteadyStateResidualFn = _steady_state_residual,
|
190
|
-
integrator:
|
190
|
+
integrator: Callable[[Callable, ArrayLike], IntegratorProtocol] = DefaultIntegrator,
|
191
191
|
) -> dict[str, float]:
|
192
192
|
"""Fit model parameters to steady-state experimental data.
|
193
193
|
|
@@ -241,7 +241,7 @@ def time_course(
|
|
241
241
|
y0: dict[str, float] | None = None,
|
242
242
|
minimize_fn: MinimizeFn = _default_minimize_fn,
|
243
243
|
residual_fn: TimeSeriesResidualFn = _time_course_residual,
|
244
|
-
integrator:
|
244
|
+
integrator: Callable[[Callable, ArrayLike], IntegratorProtocol] = DefaultIntegrator,
|
245
245
|
) -> dict[str, float]:
|
246
246
|
"""Fit model parameters to time course of experimental data.
|
247
247
|
|
modelbase2/nnarchitectures.py
CHANGED
@@ -0,0 +1,128 @@
|
|
1
|
+
"""Neural network architectures.
|
2
|
+
|
3
|
+
This module provides implementations of neural network architectures used for mechanistic learning.
|
4
|
+
|
5
|
+
"""
|
6
|
+
|
7
|
+
from __future__ import annotations
|
8
|
+
|
9
|
+
from typing import TYPE_CHECKING, cast
|
10
|
+
|
11
|
+
import torch
|
12
|
+
from torch import nn
|
13
|
+
|
14
|
+
if TYPE_CHECKING:
|
15
|
+
from collections.abc import Callable
|
16
|
+
|
17
|
+
__all__ = ["DefaultDevice", "LSTMnn", "MLP"]

# Default device for the networks and for the training helpers that import it.
# Work runs on the CPU unless a caller passes a different device.
DefaultDevice = torch.device(type="cpu")
|
20
|
+
|
21
|
+
|
22
|
+
class MLP(nn.Module):
|
23
|
+
"""Multilayer Perceptron (MLP) for surrogate modeling and neural posterior estimation.
|
24
|
+
|
25
|
+
Attributes:
|
26
|
+
net: Sequential neural network model.
|
27
|
+
|
28
|
+
Methods:
|
29
|
+
forward: Forward pass through the neural network.
|
30
|
+
|
31
|
+
"""
|
32
|
+
|
33
|
+
def __init__(
|
34
|
+
self,
|
35
|
+
n_inputs: int,
|
36
|
+
layers: list[int],
|
37
|
+
activation: Callable | None = nn.ReLU(),
|
38
|
+
output_activation: Callable | None = None,
|
39
|
+
) -> None:
|
40
|
+
"""Initializes the MLP with the given number of inputs and list of (hidden) layers.
|
41
|
+
|
42
|
+
Args:
|
43
|
+
n_inputs (int): The number of input features.
|
44
|
+
n_outputs list(int): A list containing the number of neurons in hidden and output layer.
|
45
|
+
activation Callable | None (default nn.ReLU()): The activation function to be applied after each hidden layer
|
46
|
+
activation Callable | None (default None): The activation function to be applied after the final (output) layer
|
47
|
+
|
48
|
+
For instance, MLP(10, layers = [50, 50, 10]) initializes a neural network with the following architecture:
|
49
|
+
- Linear layer with `n_inputs` inputs and 50 outputs
|
50
|
+
- ReLU activation
|
51
|
+
- Linear layer with 50 inputs and 50 outputs
|
52
|
+
- ReLU activation
|
53
|
+
- Linear layer with 50 inputs and 10 outputs
|
54
|
+
|
55
|
+
The weights of the linear layers are initialized with a normal distribution
|
56
|
+
(mean=0, std=0.1) and the biases are initialized to 0.
|
57
|
+
|
58
|
+
"""
|
59
|
+
super().__init__()
|
60
|
+
self.layers = layers
|
61
|
+
self.activation = activation
|
62
|
+
self.output_activation = output_activation
|
63
|
+
|
64
|
+
levels = []
|
65
|
+
previous_neurons = n_inputs
|
66
|
+
|
67
|
+
for idx, neurons in enumerate(self.layers):
|
68
|
+
if idx == (len(self.layers) - 1):
|
69
|
+
levels.append(nn.Linear(previous_neurons, neurons))
|
70
|
+
|
71
|
+
if self.output_activation:
|
72
|
+
levels.append(self.output_activation)
|
73
|
+
|
74
|
+
else:
|
75
|
+
levels.append(nn.Linear(previous_neurons, neurons))
|
76
|
+
|
77
|
+
if self.activation:
|
78
|
+
levels.append(self.activation)
|
79
|
+
|
80
|
+
previous_neurons = neurons
|
81
|
+
|
82
|
+
self.net = nn.Sequential(*levels)
|
83
|
+
|
84
|
+
for m in self.net.modules():
|
85
|
+
if isinstance(m, nn.Linear):
|
86
|
+
nn.init.normal_(m.weight, mean=0, std=0.1)
|
87
|
+
nn.init.constant_(m.bias, val=0)
|
88
|
+
|
89
|
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
90
|
+
"""Forward pass through the neural network.
|
91
|
+
|
92
|
+
Args:
|
93
|
+
x: Input tensor.
|
94
|
+
|
95
|
+
Returns:
|
96
|
+
torch.Tensor: Output tensor.
|
97
|
+
|
98
|
+
"""
|
99
|
+
return self.net(x)
|
100
|
+
|
101
|
+
|
102
|
+
class LSTMnn(nn.Module):
    """Default LSTM neural network model for time-series approximation.

    A single ``nn.LSTM`` (library-default number of layers, no
    ``batch_first``) is followed by a linear read-out. The read-out is
    applied to the final hidden state of the last LSTM layer.
    """

    def __init__(self, n_inputs: int, n_outputs: int, n_hidden: int) -> None:
        """Build the LSTM and its linear read-out.

        Args:
            n_inputs (int): Number of input features per time step.
            n_outputs (int): Number of output features.
            n_hidden (int): Number of hidden units in the LSTM layer.

        """
        super().__init__()
        self.n_hidden = n_hidden

        self.lstm = nn.LSTM(n_inputs, n_hidden)
        self.to_out = nn.Linear(n_hidden, n_outputs)

        # Only the read-out gets the custom init; the LSTM keeps PyTorch's
        # default parameter initialization.
        readout = self.to_out
        nn.init.normal_(readout.weight, mean=0, std=0.1)
        nn.init.constant_(readout.bias, val=0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the sequence through the LSTM and map the final hidden state to outputs.

        Because ``batch_first`` is not set, ``x`` is interpreted as
        ``(seq_len, batch, n_inputs)`` or, unbatched, ``(seq_len, n_inputs)``.
        """
        _sequence_output, state = self.lstm(x)
        hidden_states = state[0]  # h_n; state[1] is the cell state
        final_hidden = hidden_states[-1]  # hidden state of the last LSTM layer
        return cast(torch.Tensor, self.to_out(final_hidden))
|
modelbase2/npe.py
CHANGED
@@ -4,10 +4,6 @@ This module provides classes and functions for training neural network models to
|
|
4
4
|
parameters in metabolic models. It includes functionality for both steady-state and
|
5
5
|
time-series data.
|
6
6
|
|
7
|
-
Classes:
|
8
|
-
DefaultSSAproximator: Default neural network model for steady-state approximation
|
9
|
-
DefaultTimeSeriesApproximator: Default neural network model for time-series approximation
|
10
|
-
|
11
7
|
Functions:
|
12
8
|
train_torch_surrogate: Train a PyTorch surrogate model
|
13
9
|
train_torch_time_course_estimator: Train a PyTorch time course estimator
|
@@ -18,9 +14,6 @@ from __future__ import annotations
|
|
18
14
|
__all__ = [
|
19
15
|
"AbstractEstimator",
|
20
16
|
"DefaultCache",
|
21
|
-
"DefaultDevice",
|
22
|
-
"DefaultSSAproximator",
|
23
|
-
"DefaultTimeSeriesApproximator",
|
24
17
|
"TorchSSEstimator",
|
25
18
|
"TorchTimeCourseEstimator",
|
26
19
|
"train_torch_ss_estimator",
|
@@ -30,7 +23,7 @@ __all__ = [
|
|
30
23
|
from abc import abstractmethod
|
31
24
|
from dataclasses import dataclass
|
32
25
|
from pathlib import Path
|
33
|
-
from typing import cast
|
26
|
+
from typing import TYPE_CHECKING, cast
|
34
27
|
|
35
28
|
import numpy as np
|
36
29
|
import pandas as pd
|
@@ -39,75 +32,15 @@ import tqdm
|
|
39
32
|
from torch import nn
|
40
33
|
from torch.optim.adam import Adam
|
41
34
|
|
35
|
+
from modelbase2.nnarchitectures import MLP, DefaultDevice, LSTMnn
|
42
36
|
from modelbase2.parallel import Cache
|
43
37
|
|
44
|
-
|
45
|
-
|
46
|
-
|
47
|
-
|
48
|
-
class DefaultSSAproximator(nn.Module):
|
49
|
-
"""Default neural network model for steady-state approximation."""
|
50
|
-
|
51
|
-
def __init__(self, n_inputs: int, n_outputs: int, n_hidden: int = 50) -> None:
|
52
|
-
"""Initializes the neural network with the specified number of inputs and outputs.
|
53
|
-
|
54
|
-
Args:
|
55
|
-
n_inputs (int): The number of input features.
|
56
|
-
n_outputs (int): The number of output features.
|
57
|
-
n_hidden (int): The number of hidden units in the fully connected layers
|
58
|
-
|
59
|
-
The network consists of three fully connected layers with ReLU activations in between.
|
60
|
-
The weights of the linear layers are initialized with a normal distribution (mean=0, std=0.1),
|
61
|
-
and the biases are initialized to zero.
|
62
|
-
|
63
|
-
"""
|
64
|
-
super().__init__()
|
65
|
-
|
66
|
-
self.net = nn.Sequential(
|
67
|
-
nn.Linear(n_inputs, n_hidden),
|
68
|
-
nn.ReLU(),
|
69
|
-
nn.Linear(n_hidden, n_hidden),
|
70
|
-
nn.ReLU(),
|
71
|
-
nn.Linear(n_hidden, n_outputs),
|
72
|
-
)
|
73
|
-
|
74
|
-
for m in self.net.modules():
|
75
|
-
if isinstance(m, nn.Linear):
|
76
|
-
nn.init.normal_(m.weight, mean=0, std=0.1)
|
77
|
-
nn.init.constant_(m.bias, val=0)
|
38
|
+
if TYPE_CHECKING:
|
39
|
+
from collections.abc import Callable
|
78
40
|
|
79
|
-
|
80
|
-
"""Forward pass through the neural network."""
|
81
|
-
return cast(torch.Tensor, self.net(x))
|
41
|
+
from torch.optim.optimizer import ParamsT
|
82
42
|
|
83
|
-
|
84
|
-
class DefaultTimeSeriesApproximator(nn.Module):
|
85
|
-
"""Default neural network model for time-series approximation."""
|
86
|
-
|
87
|
-
def __init__(self, n_inputs: int, n_outputs: int, n_hidden: int) -> None:
|
88
|
-
"""Initializes the neural network model.
|
89
|
-
|
90
|
-
Args:
|
91
|
-
n_inputs (int): Number of input features.
|
92
|
-
n_outputs (int): Number of output features.
|
93
|
-
n_hidden (int): Number of hidden units in the LSTM layer.
|
94
|
-
|
95
|
-
"""
|
96
|
-
super().__init__()
|
97
|
-
|
98
|
-
self.n_hidden = n_hidden
|
99
|
-
|
100
|
-
self.lstm = nn.LSTM(n_inputs, n_hidden)
|
101
|
-
self.to_out = nn.Linear(n_hidden, n_outputs)
|
102
|
-
|
103
|
-
nn.init.normal_(self.to_out.weight, mean=0, std=0.1)
|
104
|
-
nn.init.constant_(self.to_out.bias, val=0)
|
105
|
-
|
106
|
-
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
107
|
-
"""Forward pass through the neural network."""
|
108
|
-
# lstm_out, (hidden_state, cell_state)
|
109
|
-
_, (hn, _) = self.lstm(x)
|
110
|
-
return cast(torch.Tensor, self.to_out(hn[-1])) # Use last hidden state
|
43
|
+
DefaultCache = Cache(Path(".cache"))
|
111
44
|
|
112
45
|
|
113
46
|
@dataclass(kw_only=True)
|
@@ -212,7 +145,7 @@ def train_torch_ss_estimator(
|
|
212
145
|
epochs: int,
|
213
146
|
batch_size: int | None = None,
|
214
147
|
approximator: nn.Module | None = None,
|
215
|
-
optimimzer_cls:
|
148
|
+
optimimzer_cls: Callable[[ParamsT], Adam] = Adam,
|
216
149
|
device: torch.device = DefaultDevice,
|
217
150
|
) -> tuple[TorchSSEstimator, pd.Series]:
|
218
151
|
"""Train a PyTorch steady state estimator.
|
@@ -229,7 +162,7 @@ def train_torch_ss_estimator(
|
|
229
162
|
targets: DataFrame containing the target values for training
|
230
163
|
epochs: Number of training epochs
|
231
164
|
batch_size: Size of mini-batches for training (None for full-batch)
|
232
|
-
approximator: Predefined neural network model (None to use default)
|
165
|
+
approximator: Predefined neural network model (None to use default MLP)
|
233
166
|
optimimzer_cls: Optimizer class to use for training (default: Adam)
|
234
167
|
device: Device to run the training on (default: DefaultDevice)
|
235
168
|
|
@@ -238,10 +171,10 @@ def train_torch_ss_estimator(
|
|
238
171
|
|
239
172
|
"""
|
240
173
|
if approximator is None:
|
241
|
-
|
242
|
-
|
243
|
-
|
244
|
-
|
174
|
+
n_hidden = max(2 * len(features.columns) * len(targets.columns), 10)
|
175
|
+
n_outputs = len(targets.columns)
|
176
|
+
approximator = MLP(
|
177
|
+
n_inputs=len(features.columns), layers=[n_hidden, n_hidden, n_outputs]
|
245
178
|
).to(device)
|
246
179
|
|
247
180
|
features_ = torch.Tensor(features.to_numpy(), device=device)
|
@@ -278,7 +211,7 @@ def train_torch_time_course_estimator(
|
|
278
211
|
epochs: int,
|
279
212
|
batch_size: int | None = None,
|
280
213
|
approximator: nn.Module | None = None,
|
281
|
-
optimimzer_cls:
|
214
|
+
optimimzer_cls: Callable[[ParamsT], Adam] = Adam,
|
282
215
|
device: torch.device = DefaultDevice,
|
283
216
|
) -> tuple[TorchTimeCourseEstimator, pd.Series]:
|
284
217
|
"""Train a PyTorch time course estimator.
|
@@ -295,7 +228,7 @@ def train_torch_time_course_estimator(
|
|
295
228
|
targets: DataFrame containing the target values for training
|
296
229
|
epochs: Number of training epochs
|
297
230
|
batch_size: Size of mini-batches for training (None for full-batch)
|
298
|
-
approximator: Predefined neural network model (None to use default)
|
231
|
+
approximator: Predefined neural network model (None to use default LSTM)
|
299
232
|
optimimzer_cls: Optimizer class to use for training (default: Adam)
|
300
233
|
device: Device to run the training on (default: DefaultDevice)
|
301
234
|
|
@@ -304,7 +237,7 @@ def train_torch_time_course_estimator(
|
|
304
237
|
|
305
238
|
"""
|
306
239
|
if approximator is None:
|
307
|
-
approximator =
|
240
|
+
approximator = LSTMnn(
|
308
241
|
n_inputs=len(features.columns),
|
309
242
|
n_outputs=len(targets.columns),
|
310
243
|
n_hidden=1,
|
modelbase2/plot.py
CHANGED
@@ -399,7 +399,10 @@ def lines(
|
|
399
399
|
fig, ax = _default_fig_ax(ax=ax, grid=grid)
|
400
400
|
ax.plot(x.index, x)
|
401
401
|
_default_labels(ax, xlabel=x.index.name, ylabel=None)
|
402
|
-
|
402
|
+
if isinstance(x, pd.Series):
|
403
|
+
ax.legend([str(x.name)])
|
404
|
+
else:
|
405
|
+
ax.legend(x.columns)
|
403
406
|
return fig, ax
|
404
407
|
|
405
408
|
|
modelbase2/simulator.py
CHANGED
@@ -21,6 +21,8 @@ from modelbase2.integrators import DefaultIntegrator
|
|
21
21
|
__all__ = ["Simulator"]
|
22
22
|
|
23
23
|
if TYPE_CHECKING:
|
24
|
+
from collections.abc import Callable
|
25
|
+
|
24
26
|
from modelbase2.model import Model
|
25
27
|
from modelbase2.types import ArrayLike, IntegratorProtocol
|
26
28
|
|
@@ -83,7 +85,9 @@ class Simulator:
|
|
83
85
|
self,
|
84
86
|
model: Model,
|
85
87
|
y0: dict[str, float] | None = None,
|
86
|
-
integrator:
|
88
|
+
integrator: Callable[
|
89
|
+
[Callable, ArrayLike], IntegratorProtocol
|
90
|
+
] = DefaultIntegrator,
|
87
91
|
*,
|
88
92
|
test_run: bool = True,
|
89
93
|
) -> None:
|
@@ -93,7 +97,7 @@ class Simulator:
|
|
93
97
|
model (Model): The model to be simulated.
|
94
98
|
y0 (dict[str, float] | None, optional): Initial conditions for the model variables.
|
95
99
|
If None, the initial conditions are obtained from the model. Defaults to None.
|
96
|
-
integrator (
|
100
|
+
integrator (Callable[[Callable, ArrayLike], IntegratorProtocol], optional): The integrator to use for the simulation.
|
97
101
|
Defaults to DefaultIntegrator.
|
98
102
|
test_run (bool, optional): If True, performs a test run to ensure the model's methods
|
99
103
|
(get_full_concs, get_fluxes, get_right_hand_side) work correctly with the initial conditions.
|
@@ -104,7 +108,7 @@ class Simulator:
|
|
104
108
|
y0 = model.get_initial_conditions() if y0 is None else y0
|
105
109
|
self.y0 = [y0[k] for k in model.get_variable_names()]
|
106
110
|
|
107
|
-
self.integrator = integrator(self.model,
|
111
|
+
self.integrator = integrator(self.model, self.y0)
|
108
112
|
self.concs = None
|
109
113
|
self.args = None
|
110
114
|
self.simulation_parameters = None
|