modelbase2 0.3.0__py3-none-any.whl → 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- modelbase2/distributions.py +5 -2
- modelbase2/experimental/__init__.py +2 -0
- modelbase2/experimental/_backup.py +1017 -0
- modelbase2/experimental/codegen.py +1 -1
- modelbase2/experimental/strikepy.py +562 -0
- modelbase2/experimental/symbolic.py +286 -0
- modelbase2/experimental/tex.py +4 -8
- modelbase2/fit.py +6 -6
- modelbase2/model.py +56 -9
- modelbase2/npe.py +8 -3
- modelbase2/plot.py +2 -2
- modelbase2/sbml/_import.py +5 -1
- modelbase2/simulator.py +7 -3
- modelbase2/surrogates/_poly.py +3 -1
- modelbase2/surrogates/_torch.py +4 -2
- modelbase2/surrogates.py +7 -1
- {modelbase2-0.3.0.dist-info → modelbase2-0.5.0.dist-info}/METADATA +2 -1
- {modelbase2-0.3.0.dist-info → modelbase2-0.5.0.dist-info}/RECORD +20 -17
- {modelbase2-0.3.0.dist-info → modelbase2-0.5.0.dist-info}/WHEEL +0 -0
- {modelbase2-0.3.0.dist-info → modelbase2-0.5.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,286 @@
|
|
1
|
+
# ruff: noqa: D100, D101, D102, D103, D104, D105, D106, D107, D200, D203, D400, D401
|
2
|
+
|
3
|
+
import ast
|
4
|
+
import inspect
|
5
|
+
import textwrap
|
6
|
+
from collections.abc import Callable
|
7
|
+
from dataclasses import dataclass
|
8
|
+
from typing import Any, cast
|
9
|
+
|
10
|
+
import sympy
|
11
|
+
|
12
|
+
from modelbase2.model import Model
|
13
|
+
|
14
|
+
__all__ = ["Context", "SymbolicModel", "model_fn_to_sympy", "to_symbolic_model"]
|
15
|
+
|
16
|
+
|
17
|
+
@dataclass
class Context:
    """State shared while translating a single Python function into sympy."""

    # Name -> symbolic value for everything visible in the function body:
    # its parameters plus any locally assigned variables
    symbols: dict[str, sympy.Symbol | sympy.Expr]
    # The Python function being translated; its module is where called
    # helper functions are looked up
    caller: Callable
|
21
|
+
|
22
|
+
|
23
|
+
@dataclass
class SymbolicModel:
    """Symbolic (sympy) representation of a model's right-hand side."""

    # Model variable name -> sympy symbol representing it
    variables: dict[str, sympy.Symbol]
    # Model parameter name -> sympy symbol representing it
    parameters: dict[str, sympy.Symbol]
    # One right-hand-side expression per variable, ordered like the model
    # cache's variable names
    eqs: list[sympy.Expr]
|
28
|
+
|
29
|
+
|
30
|
+
def to_symbolic_model(model: Model) -> SymbolicModel:
    """Convert a model into a symbolic representation of its right-hand side.

    Variables and parameters become sympy symbols. Derived values and reaction
    rates are translated with `model_fn_to_sympy`, and each variable's
    right-hand side is assembled from the model's (dynamic) stoichiometries.

    Args:
        model: The model to convert.

    Returns:
        A SymbolicModel with one expression per variable, ordered like the
        model cache's variable names.

    """
    cache = model._create_cache()  # noqa: SLF001

    variables = dict(
        zip(model.variables, sympy.symbols(list(model.variables)), strict=True)
    )
    parameters = dict(
        zip(model.parameters, sympy.symbols(list(model.parameters)), strict=True)
    )
    symbols = variables | parameters

    # Added one at a time so a derived value can reference earlier ones
    # (assumes model.derived iterates in dependency order - TODO confirm)
    for k, v in model.derived.items():
        symbols[k] = model_fn_to_sympy(v.fn, [symbols[i] for i in v.args])

    rxns = {
        k: model_fn_to_sympy(v.fn, [symbols[i] for i in v.args])
        for k, v in model.reactions.items()
    }

    eqs: dict[str, sympy.Expr] = {}
    for cpd, stoich in cache.stoich_by_cpds.items():
        for rxn, stoich_value in stoich.items():
            eqs[cpd] = (
                eqs.get(cpd, sympy.Float(0.0)) + sympy.Float(stoich_value) * rxns[rxn]  # type: ignore
            )

    # A dynamic stoichiometry is itself a function: translate the factor
    # first, then multiply it with the reaction rate
    for cpd, dstoich in cache.dyn_stoich_by_cpds.items():
        for rxn, der in dstoich.items():
            factor = model_fn_to_sympy(der.fn, [symbols[i] for i in der.args])
            eqs[cpd] = eqs.get(cpd, sympy.Float(0.0)) + factor * rxns[rxn]  # type: ignore

    return SymbolicModel(
        variables=variables,
        parameters=parameters,
        # Variables not touched by any reaction do not change over time
        eqs=[eqs.get(i, sympy.Float(0.0)) for i in cache.var_names],
    )
|
68
|
+
|
69
|
+
|
70
|
+
def model_fn_to_sympy(
    fn: Callable, model_args: list[sympy.Symbol | sympy.Expr] | None = None
) -> sympy.Expr:
    """Convert a plain Python model function into a sympy expression.

    The function's source is parsed and its body translated statement by
    statement. Each positional parameter is represented by a sympy symbol of
    the same name.

    Args:
        fn: The function to convert. Its source must be retrievable with
            `inspect.getsource` and start with a ``def`` statement.
        model_args: Optional expressions to substitute for the function's
            positional parameters, in order.

    Returns:
        The sympy expression computed by the function.

    Raises:
        TypeError: If the parsed source is not a function definition.
        ValueError: If `model_args` does not match the number of positional
            parameters.

    """
    source = textwrap.dedent(inspect.getsource(fn))

    if not isinstance(fn_def := ast.parse(source).body[0], ast.FunctionDef):
        msg = "Expected a function definition"
        raise TypeError(msg)

    # Positional-only parameters come first in the call signature
    fn_args = [
        str(arg.arg) for arg in (*fn_def.args.posonlyargs, *fn_def.args.args)
    ]
    fn_symbols = {name: sympy.Symbol(name) for name in fn_args}

    sympy_expr = _handle_fn_body(
        fn_def.body,
        # Copy, because assignments in the body update ctx.symbols in place
        ctx=Context(symbols=dict(fn_symbols), caller=fn),
    )
    # A body returning a plain constant yields a Python number; convert it so
    # substitution and callers can rely on a sympy object
    sympy_expr = sympy.sympify(sympy_expr)

    if model_args is not None:
        replacements = {
            fn_symbols[name]: value
            for name, value in zip(fn_args, model_args, strict=True)
        }
        # Simultaneous substitution prevents chained replacements when names
        # collide, e.g. f(a, b) called with (b, a) must not become f(a, a)
        sympy_expr = sympy_expr.subs(replacements, simultaneous=True)

    return cast(sympy.Expr, sympy_expr)
|
93
|
+
|
94
|
+
|
95
|
+
def _handle_fn_body(body: list[ast.stmt], ctx: Context) -> sympy.Expr:
    """Translate a sequence of statements into a single sympy expression.

    Statements are processed in order:

    - ``if``/``elif``/``else`` chains contribute ``(expression, condition)``
      pairs that are combined into a ``sympy.Piecewise``; an ``else`` branch
      supplies the ``True`` fallback.
    - ``return`` ends processing. Without preceding conditionals its value is
      returned directly; after a conditional chain it becomes the ``True``
      fallback.
    - Assignments to a single name, or tuple-to-tuple unpacking, bind the
      translated value in ``ctx.symbols`` for later statements.

    If the statements end without a ``return`` and no conditional pieces were
    collected, the value of the last plain-name assignment is returned.

    Args:
        body: Statements of a function (or branch) body.
        ctx: Translation context; ``ctx.symbols`` is updated in place by
            assignments, including those inside branches.

    Returns:
        The resulting sympy expression.

    Raises:
        ValueError: If a bare ``return`` is found, or no value can be derived.
        TypeError: If a non-tuple assignment target is not a single name.

    """
    # (expression, condition) pairs collected from if/elif/else chains
    pieces = []
    # Work queue of statements still to process; `elif` nodes are pushed to
    # its front so they are handled directly after their parent `if`
    remaining_body = list(body)

    while remaining_body:
        node = remaining_body.pop(0)

        if isinstance(node, ast.If):
            condition = _handle_expr(node.test, ctx)
            # NOTE(review): `ctx` is passed without copying (see the FIXMEs
            # below), so assignments made inside a branch stay visible to
            # sibling branches and to all later statements
            if_expr = _handle_fn_body(node.body, ctx)
            pieces.append((if_expr, condition))

            # If there's an else clause
            if node.orelse:
                # Check if it's an elif (an If node in orelse)
                if len(node.orelse) == 1 and isinstance(node.orelse[0], ast.If):
                    # Push the elif back to the beginning of remaining_body to process next
                    remaining_body.insert(0, node.orelse[0])
                else:
                    # It's a regular else
                    else_expr = _handle_fn_body(node.orelse, ctx)  # FIXME: copy here
                    pieces.append((else_expr, True))
                    # NOTE(review): this also skips any statements that follow
                    # the whole if/else in `body` (e.g. a trailing
                    # `return y * 2` is never translated)
                    break  # We're done with this chain

            # NOTE(review): when `remaining_body` is empty, no statement of
            # `body` follows `node`, so the slice is empty and `any(...)` is
            # False; for an `elif` node pushed from `orelse` (not an element
            # of `body`), `body.index` raises ValueError instead. This branch
            # looks ineffective - confirm the intended condition.
            elif not remaining_body and any(
                isinstance(n, ast.Return) for n in body[body.index(node) + 1 :]
            ):
                else_expr = _handle_fn_body(
                    body[body.index(node) + 1 :], ctx
                )  # FIXME: copy here
                pieces.append((else_expr, True))

        elif isinstance(node, ast.Return):
            if (value := node.value) is None:
                msg = "Return value cannot be None"
                raise ValueError(msg)

            expr = _handle_expr(value, ctx)
            # Without preceding conditionals the return value is the result
            if not pieces:
                return expr
            # After an if/elif chain, the return acts as the fallback branch
            pieces.append((expr, True))
            break

        elif isinstance(node, ast.Assign):
            # Handle tuple assignments like c, d = a, b
            if isinstance(node.targets[0], ast.Tuple):
                # Handle tuple unpacking
                target_elements = node.targets[0].elts

                if isinstance(node.value, ast.Tuple):
                    # Direct unpacking like c, d = a, b
                    value_elements = node.value.elts
                    for target, value_expr in zip(
                        target_elements, value_elements, strict=True
                    ):
                        # Targets that are not plain names are skipped
                        if isinstance(target, ast.Name):
                            ctx.symbols[target.id] = _handle_expr(value_expr, ctx)
                else:
                    # Handle potential iterable unpacking
                    # NOTE(review): the translated value is discarded and no
                    # target name is bound; later references to these names
                    # raise KeyError unless they were bound earlier
                    value = _handle_expr(node.value, ctx)
            else:
                # Regular single assignment
                if not isinstance(target := node.targets[0], ast.Name):
                    msg = "Only single variable assignments are supported"
                    raise TypeError(msg)
                target_name = target.id
                value = _handle_expr(node.value, ctx)
                ctx.symbols[target_name] = value

        # NOTE(review): other statement types (e.g. augmented `x += 1` or
        # annotated `x: float = 1` assignments) match none of the branches
        # above and are silently ignored

    # If we have pieces to combine into a Piecewise
    # NOTE(review): without an `else` or a trailing `return`, the Piecewise
    # has no `True` fallback condition
    if pieces:
        return sympy.Piecewise(*pieces)

    # If no return was found but we have assignments, return the last assigned variable
    for node in reversed(body):
        if isinstance(node, ast.Assign) and isinstance(node.targets[0], ast.Name):
            target_name = node.targets[0].id
            return ctx.symbols[target_name]

    msg = "No return value found in function body"
    raise ValueError(msg)
|
176
|
+
|
177
|
+
|
178
|
+
def _handle_unaryop(node: ast.UnaryOp, ctx: Context) -> sympy.Expr:
    """Translate a unary operation (``+x``, ``-x`` or ``not x``) into sympy.

    Raises:
        NotImplementedError: For other unary operators such as ``~x``.

    """
    left = _handle_expr(node.operand, ctx)
    left = cast(Any, left)  # stupid sympy types don't allow ops on symbols

    if isinstance(node.op, ast.UAdd):
        return +left
    if isinstance(node.op, ast.USub):
        return -left
    if isinstance(node.op, ast.Not):
        # Python's `not` would need a concrete truth value, which symbolic
        # conditions do not have; build the symbolic negation instead
        return cast(sympy.Expr, sympy.Not(left))

    msg = f"Operation {type(node.op).__name__} not implemented"
    raise NotImplementedError(msg)
|
189
|
+
|
190
|
+
|
191
|
+
def _handle_binop(node: ast.BinOp, ctx: Context) -> sympy.Expr:
|
192
|
+
left = _handle_expr(node.left, ctx)
|
193
|
+
left = cast(Any, left) # stupid sympy types don't allow ops on symbols
|
194
|
+
|
195
|
+
right = _handle_expr(node.right, ctx)
|
196
|
+
right = cast(Any, right) # stupid sympy types don't allow ops on symbols
|
197
|
+
|
198
|
+
if isinstance(node.op, ast.Add):
|
199
|
+
return left + right
|
200
|
+
if isinstance(node.op, ast.Sub):
|
201
|
+
return left - right
|
202
|
+
if isinstance(node.op, ast.Mult):
|
203
|
+
return left * right
|
204
|
+
if isinstance(node.op, ast.Div):
|
205
|
+
return left / right
|
206
|
+
if isinstance(node.op, ast.Pow):
|
207
|
+
return left**right
|
208
|
+
if isinstance(node.op, ast.Mod):
|
209
|
+
return left % right
|
210
|
+
if isinstance(node.op, ast.FloorDiv):
|
211
|
+
return left // right
|
212
|
+
|
213
|
+
msg = f"Operation {type(node.op).__name__} not implemented"
|
214
|
+
raise NotImplementedError(msg)
|
215
|
+
|
216
|
+
|
217
|
+
def _handle_call(node: ast.Call, ctx: Context) -> sympy.Expr:
    """Translate a call to a named function into a sympy expression.

    The callee is looked up among the callables of the module that defines
    ``ctx.caller``. It is translated recursively, with the translated call
    arguments substituted for its parameters. If the module has no callable
    of that name, the builtins ``abs``, ``min`` and ``max`` are mapped to
    their sympy counterparts.

    Raises:
        TypeError: If the callee is not a plain name (e.g. ``math.exp``).
        KeyError: If no function of that name can be resolved.

    """
    if not isinstance(callee := node.func, ast.Name):
        msg = "Only function calls with names are supported"
        raise TypeError(msg)

    fn_name = str(callee.id)
    parent_module = inspect.getmodule(ctx.caller)
    fns = dict(inspect.getmembers(parent_module, predicate=callable))

    if fn_name in fns:
        return model_fn_to_sympy(
            fns[fn_name],
            model_args=[_handle_expr(i, ctx) for i in node.args],
        )

    # Builtins are not members of the caller's module and have no Python
    # source to parse. Python's own min/max would also try to compare
    # symbolic values, so map them onto sympy's symbolic equivalents.
    sympy_builtins: dict[str, Callable[..., Any]] = {
        "abs": sympy.Abs,
        "min": sympy.Min,
        "max": sympy.Max,
    }
    if fn_name in sympy_builtins:
        args = [_handle_expr(i, ctx) for i in node.args]
        return cast(sympy.Expr, sympy_builtins[fn_name](*args))

    msg = f"Function '{fn_name}' not found in the module of the converted function"
    raise KeyError(msg)
|
230
|
+
|
231
|
+
|
232
|
+
def _handle_name(node: ast.Name, ctx: Context) -> sympy.Symbol | sympy.Expr:
    """Resolve a variable reference to its symbolic value.

    Names bound in the function (parameters and local assignments) take
    precedence. Otherwise, mirroring Python's global name lookup, a numeric
    module-level constant from the converted function's module is used.

    Raises:
        KeyError: If the name is neither bound locally nor a numeric
            module-level constant.

    """
    if node.id in ctx.symbols:
        return ctx.symbols[node.id]

    # Model functions may reference numeric module-level constants; inline
    # their current value
    value = getattr(inspect.getmodule(ctx.caller), node.id, None)
    if isinstance(value, int | float):
        return cast(sympy.Expr, sympy.sympify(value))

    raise KeyError(node.id)
|
234
|
+
|
235
|
+
|
236
|
+
def _handle_compare(node: ast.Compare, ctx: Context) -> sympy.Expr:
    """Translate a (possibly chained) comparison into a sympy condition.

    Chained comparisons such as ``1 < a < 2`` are split into their pairwise
    comparisons, which are then combined with a logical AND.

    Raises:
        NotImplementedError: For operators such as ``is`` or ``in``.

    """
    left = cast(Any, _handle_expr(node.left, ctx))
    comparisons = []

    # Build all individual comparisons from the chain
    prev_value = left
    for op, comparator in zip(node.ops, node.comparators, strict=True):
        right = cast(Any, _handle_expr(comparator, ctx))

        if isinstance(op, ast.Gt):
            comparisons.append(prev_value > right)
        elif isinstance(op, ast.GtE):
            comparisons.append(prev_value >= right)
        elif isinstance(op, ast.Lt):
            comparisons.append(prev_value < right)
        elif isinstance(op, ast.LtE):
            comparisons.append(prev_value <= right)
        elif isinstance(op, ast.Eq):
            # `==` on sympy objects is a structural check returning a Python
            # bool; an explicit symbolic equation is needed instead
            comparisons.append(sympy.Eq(prev_value, right))
        elif isinstance(op, ast.NotEq):
            comparisons.append(sympy.Ne(prev_value, right))
        else:
            msg = f"Comparison {type(op).__name__} not implemented"
            raise NotImplementedError(msg)

        prev_value = right

    # Combine all comparisons with logical AND
    result = comparisons[0]
    for comp in comparisons[1:]:
        result = sympy.And(result, comp)
    return cast(sympy.Expr, result)


def _handle_boolop(node: ast.BoolOp, ctx: Context) -> sympy.Expr:
    """Translate ``and``/``or`` of conditions into sympy ``And``/``Or``."""
    values = [_handle_expr(value, ctx) for value in node.values]
    if isinstance(node.op, ast.And):
        return cast(sympy.Expr, sympy.And(*values))
    if isinstance(node.op, ast.Or):
        return cast(sympy.Expr, sympy.Or(*values))

    msg = f"Operation {type(node.op).__name__} not implemented"
    raise NotImplementedError(msg)


def _handle_expr(node: ast.expr, ctx: Context) -> sympy.Expr:
    """Dispatch an AST expression node to the matching translator.

    Returns:
        The sympy translation of the expression. Constants are returned as
        their raw Python value.

    Raises:
        NotImplementedError: For unsupported expression types.

    """
    if isinstance(node, ast.UnaryOp):
        return _handle_unaryop(node, ctx)
    if isinstance(node, ast.BinOp):
        return _handle_binop(node, ctx)
    if isinstance(node, ast.Name):
        return _handle_name(node, ctx)
    if isinstance(node, ast.Constant):
        return node.value
    if isinstance(node, ast.Compare):
        return _handle_compare(node, ctx)
    if isinstance(node, ast.BoolOp):
        return _handle_boolop(node, ctx)
    if isinstance(node, ast.Call):
        return _handle_call(node, ctx)

    # Handle conditional expressions (ternary operators)
    if isinstance(node, ast.IfExp):
        condition = _handle_expr(node.test, ctx)
        if_true = _handle_expr(node.body, ctx)
        if_false = _handle_expr(node.orelse, ctx)
        return sympy.Piecewise((if_true, condition), (if_false, True))

    msg = f"Expression type {type(node).__name__} not implemented"
    raise NotImplementedError(msg)
|
modelbase2/experimental/tex.py
CHANGED
@@ -508,14 +508,10 @@ def get_model_tex_diff(
|
|
508
508
|
gls = default_init(gls)
|
509
509
|
section_label = "sec:model-diff"
|
510
510
|
|
511
|
-
return f"""{
|
511
|
+
return f"""{" start autogenerated ":%^60}
|
512
512
|
{_clearpage()}
|
513
|
-
{_subsubsection(
|
514
|
-
{(
|
515
|
-
(_to_tex_export(m1) - _to_tex_export(m2))
|
516
|
-
.rename_with_glossary(gls)
|
517
|
-
.export_all()
|
518
|
-
)}
|
513
|
+
{_subsubsection("Model changes")}{_label(section_label)}
|
514
|
+
{((_to_tex_export(m1) - _to_tex_export(m2)).rename_with_glossary(gls).export_all())}
|
519
515
|
{_clearpage()}
|
520
|
-
{
|
516
|
+
{" end autogenerated ":%^60}
|
521
517
|
"""
|
modelbase2/fit.py
CHANGED
@@ -50,7 +50,7 @@ class SteadyStateResidualFn(Protocol):
|
|
50
50
|
data: pd.Series,
|
51
51
|
model: Model,
|
52
52
|
y0: dict[str, float],
|
53
|
-
integrator:
|
53
|
+
integrator: Callable[[Callable, ArrayLike], IntegratorProtocol],
|
54
54
|
) -> float:
|
55
55
|
"""Calculate residual error between model steady state and experimental data."""
|
56
56
|
...
|
@@ -67,7 +67,7 @@ class TimeSeriesResidualFn(Protocol):
|
|
67
67
|
data: pd.DataFrame,
|
68
68
|
model: Model,
|
69
69
|
y0: dict[str, float],
|
70
|
-
integrator:
|
70
|
+
integrator: Callable[[Callable, ArrayLike], IntegratorProtocol],
|
71
71
|
) -> float:
|
72
72
|
"""Calculate residual error between model time course and experimental data."""
|
73
73
|
...
|
@@ -101,7 +101,7 @@ def _steady_state_residual(
|
|
101
101
|
data: pd.Series,
|
102
102
|
model: Model,
|
103
103
|
y0: dict[str, float] | None,
|
104
|
-
integrator:
|
104
|
+
integrator: Callable[[Callable, ArrayLike], IntegratorProtocol],
|
105
105
|
) -> float:
|
106
106
|
"""Calculate residual error between model steady state and experimental data.
|
107
107
|
|
@@ -148,7 +148,7 @@ def _time_course_residual(
|
|
148
148
|
data: pd.DataFrame,
|
149
149
|
model: Model,
|
150
150
|
y0: dict[str, float],
|
151
|
-
integrator:
|
151
|
+
integrator: Callable[[Callable, ArrayLike], IntegratorProtocol],
|
152
152
|
) -> float:
|
153
153
|
"""Calculate residual error between model time course and experimental data.
|
154
154
|
|
@@ -187,7 +187,7 @@ def steady_state(
|
|
187
187
|
y0: dict[str, float] | None = None,
|
188
188
|
minimize_fn: MinimizeFn = _default_minimize_fn,
|
189
189
|
residual_fn: SteadyStateResidualFn = _steady_state_residual,
|
190
|
-
integrator:
|
190
|
+
integrator: Callable[[Callable, ArrayLike], IntegratorProtocol] = DefaultIntegrator,
|
191
191
|
) -> dict[str, float]:
|
192
192
|
"""Fit model parameters to steady-state experimental data.
|
193
193
|
|
@@ -241,7 +241,7 @@ def time_course(
|
|
241
241
|
y0: dict[str, float] | None = None,
|
242
242
|
minimize_fn: MinimizeFn = _default_minimize_fn,
|
243
243
|
residual_fn: TimeSeriesResidualFn = _time_course_residual,
|
244
|
-
integrator:
|
244
|
+
integrator: Callable[[Callable, ArrayLike], IntegratorProtocol] = DefaultIntegrator,
|
245
245
|
) -> dict[str, float]:
|
246
246
|
"""Fit model parameters to time course of experimental data.
|
247
247
|
|
modelbase2/model.py
CHANGED
@@ -10,7 +10,6 @@ from __future__ import annotations
|
|
10
10
|
import copy
|
11
11
|
import inspect
|
12
12
|
import itertools as it
|
13
|
-
import math
|
14
13
|
from dataclasses import dataclass, field
|
15
14
|
from typing import TYPE_CHECKING, Self, cast
|
16
15
|
|
@@ -26,7 +25,13 @@ from modelbase2.types import (
|
|
26
25
|
Readout,
|
27
26
|
)
|
28
27
|
|
29
|
-
__all__ = [
|
28
|
+
__all__ = [
|
29
|
+
"ArityMismatchError",
|
30
|
+
"CircularDependencyError",
|
31
|
+
"MissingDependenciesError",
|
32
|
+
"Model",
|
33
|
+
"ModelCache",
|
34
|
+
]
|
30
35
|
|
31
36
|
if TYPE_CHECKING:
|
32
37
|
from collections.abc import Iterable, Mapping
|
@@ -35,19 +40,38 @@ if TYPE_CHECKING:
|
|
35
40
|
from modelbase2.types import AbstractSurrogate, Callable, Param, RateFn, RetType
|
36
41
|
|
37
42
|
|
38
|
-
class
|
43
|
+
class MissingDependenciesError(Exception):
|
39
44
|
"""Raised when dependencies cannot be sorted topologically.
|
40
45
|
|
41
46
|
This typically indicates circular dependencies in model components.
|
42
47
|
"""
|
43
48
|
|
44
|
-
def __init__(self,
|
49
|
+
def __init__(self, not_solvable: dict[str, list[str]]) -> None:
|
45
50
|
"""Initialise exception."""
|
51
|
+
missing_by_module = "\n".join(f"\t{k}: {v}" for k, v in not_solvable.items())
|
46
52
|
msg = (
|
47
|
-
f"
|
48
|
-
|
49
|
-
|
50
|
-
|
53
|
+
f"Dependencies cannot be solved. Missing dependencies:\n{missing_by_module}"
|
54
|
+
)
|
55
|
+
super().__init__(msg)
|
56
|
+
|
57
|
+
|
58
|
+
class CircularDependencyError(Exception):
    """Raised when dependencies cannot be sorted topologically.

    This typically indicates circular dependencies in model components.
    """

    def __init__(
        self,
        missing: dict[str, set[str]],
    ) -> None:
        """Initialise exception.

        Args:
            missing: Mapping of each component that could not be sorted to
                the dependencies reported as missing for it.

        """
        detail_lines = [f"\t{name}: {deps}" for name, deps in missing.items()]
        msg = "".join(
            [
                "Exceeded max iterations on sorting dependencies.\n",
                "Check if there are circular references. ",
                "Missing dependencies:\n",
                "\n".join(detail_lines),
            ]
        )
        super().__init__(msg)
|
53
77
|
|
@@ -120,6 +144,24 @@ def _invalidate_cache(method: Callable[Param, RetType]) -> Callable[Param, RetTy
|
|
120
144
|
return wrapper # type: ignore
|
121
145
|
|
122
146
|
|
147
|
+
def _check_if_is_sortable(
|
148
|
+
available: set[str],
|
149
|
+
elements: list[tuple[str, set[str]]],
|
150
|
+
) -> None:
|
151
|
+
all_available = available.copy()
|
152
|
+
for name, _ in elements:
|
153
|
+
all_available.add(name)
|
154
|
+
|
155
|
+
# Check if it can be sorted in the first place
|
156
|
+
not_solvable = {}
|
157
|
+
for name, args in elements:
|
158
|
+
if not args.issubset(all_available):
|
159
|
+
not_solvable[name] = sorted(args.difference(all_available))
|
160
|
+
|
161
|
+
if not_solvable:
|
162
|
+
raise MissingDependenciesError(not_solvable=not_solvable)
|
163
|
+
|
164
|
+
|
123
165
|
def _sort_dependencies(
|
124
166
|
available: set[str], elements: list[tuple[str, set[str]]]
|
125
167
|
) -> list[str]:
|
@@ -138,6 +180,8 @@ def _sort_dependencies(
|
|
138
180
|
"""
|
139
181
|
from queue import Empty, SimpleQueue
|
140
182
|
|
183
|
+
_check_if_is_sortable(available, elements)
|
184
|
+
|
141
185
|
order = []
|
142
186
|
# FIXME: what is the worst case here?
|
143
187
|
max_iterations = len(elements) ** 2
|
@@ -171,7 +215,10 @@ def _sort_dependencies(
|
|
171
215
|
unsorted.append(queue.get_nowait()[0])
|
172
216
|
except Empty:
|
173
217
|
break
|
174
|
-
|
218
|
+
|
219
|
+
mod_to_args: dict[str, set[str]] = dict(elements)
|
220
|
+
missing = {k: mod_to_args[k].difference(available) for k in unsorted}
|
221
|
+
raise CircularDependencyError(missing=missing)
|
175
222
|
return order
|
176
223
|
|
177
224
|
|
modelbase2/npe.py
CHANGED
@@ -23,7 +23,7 @@ __all__ = [
|
|
23
23
|
from abc import abstractmethod
|
24
24
|
from dataclasses import dataclass
|
25
25
|
from pathlib import Path
|
26
|
-
from typing import cast
|
26
|
+
from typing import TYPE_CHECKING, cast
|
27
27
|
|
28
28
|
import numpy as np
|
29
29
|
import pandas as pd
|
@@ -35,6 +35,11 @@ from torch.optim.adam import Adam
|
|
35
35
|
from modelbase2.nnarchitectures import MLP, DefaultDevice, LSTMnn
|
36
36
|
from modelbase2.parallel import Cache
|
37
37
|
|
38
|
+
if TYPE_CHECKING:
|
39
|
+
from collections.abc import Callable
|
40
|
+
|
41
|
+
from torch.optim.optimizer import ParamsT
|
42
|
+
|
38
43
|
DefaultCache = Cache(Path(".cache"))
|
39
44
|
|
40
45
|
|
@@ -140,7 +145,7 @@ def train_torch_ss_estimator(
|
|
140
145
|
epochs: int,
|
141
146
|
batch_size: int | None = None,
|
142
147
|
approximator: nn.Module | None = None,
|
143
|
-
optimimzer_cls:
|
148
|
+
optimimzer_cls: Callable[[ParamsT], Adam] = Adam,
|
144
149
|
device: torch.device = DefaultDevice,
|
145
150
|
) -> tuple[TorchSSEstimator, pd.Series]:
|
146
151
|
"""Train a PyTorch steady state estimator.
|
@@ -206,7 +211,7 @@ def train_torch_time_course_estimator(
|
|
206
211
|
epochs: int,
|
207
212
|
batch_size: int | None = None,
|
208
213
|
approximator: nn.Module | None = None,
|
209
|
-
optimimzer_cls:
|
214
|
+
optimimzer_cls: Callable[[ParamsT], Adam] = Adam,
|
210
215
|
device: torch.device = DefaultDevice,
|
211
216
|
) -> tuple[TorchTimeCourseEstimator, pd.Series]:
|
212
217
|
"""Train a PyTorch time course estimator.
|
modelbase2/plot.py
CHANGED
@@ -818,7 +818,7 @@ def relative_label_distribution(
|
|
818
818
|
isos = mapper.get_isotopomers_of_at_position(name, i)
|
819
819
|
labels = cast(pd.DataFrame, concs.loc[:, isos])
|
820
820
|
total = concs.loc[:, f"{name}__total"]
|
821
|
-
ax.plot(labels.index, (labels.sum(axis=1) / total), label=f"C{i+1}")
|
821
|
+
ax.plot(labels.index, (labels.sum(axis=1) / total), label=f"C{i + 1}")
|
822
822
|
ax.set_title(name)
|
823
823
|
ax.legend()
|
824
824
|
else:
|
@@ -827,6 +827,6 @@ def relative_label_distribution(
|
|
827
827
|
):
|
828
828
|
ax.plot(concs.index, concs.loc[:, isos])
|
829
829
|
ax.set_title(name)
|
830
|
-
ax.legend([f"C{i+1}" for i in range(len(isos))])
|
830
|
+
ax.legend([f"C{i + 1}" for i in range(len(isos))])
|
831
831
|
|
832
832
|
return fig, axs
|
modelbase2/sbml/_import.py
CHANGED
@@ -507,7 +507,11 @@ def _codgen(name: str, sbml: Parser) -> Path:
|
|
507
507
|
|
508
508
|
# Initial assignments
|
509
509
|
initial_assignment_order = _sort_dependencies(
|
510
|
-
available=set(sbml.initial_assignment)
|
510
|
+
available=set(sbml.initial_assignment)
|
511
|
+
^ set(parameters)
|
512
|
+
^ set(variables)
|
513
|
+
^ set(sbml.derived)
|
514
|
+
| {"time"},
|
511
515
|
elements=[(k, set(v.args)) for k, v in sbml.initial_assignment.items()],
|
512
516
|
)
|
513
517
|
|
modelbase2/simulator.py
CHANGED
@@ -21,6 +21,8 @@ from modelbase2.integrators import DefaultIntegrator
|
|
21
21
|
__all__ = ["Simulator"]
|
22
22
|
|
23
23
|
if TYPE_CHECKING:
|
24
|
+
from collections.abc import Callable
|
25
|
+
|
24
26
|
from modelbase2.model import Model
|
25
27
|
from modelbase2.types import ArrayLike, IntegratorProtocol
|
26
28
|
|
@@ -83,7 +85,9 @@ class Simulator:
|
|
83
85
|
self,
|
84
86
|
model: Model,
|
85
87
|
y0: dict[str, float] | None = None,
|
86
|
-
integrator:
|
88
|
+
integrator: Callable[
|
89
|
+
[Callable, ArrayLike], IntegratorProtocol
|
90
|
+
] = DefaultIntegrator,
|
87
91
|
*,
|
88
92
|
test_run: bool = True,
|
89
93
|
) -> None:
|
@@ -93,7 +97,7 @@ class Simulator:
|
|
93
97
|
model (Model): The model to be simulated.
|
94
98
|
y0 (dict[str, float] | None, optional): Initial conditions for the model variables.
|
95
99
|
If None, the initial conditions are obtained from the model. Defaults to None.
|
96
|
-
integrator (
|
100
|
+
integrator (Callable[[Callable, ArrayLike], IntegratorProtocol], optional): The integrator to use for the simulation.
|
97
101
|
Defaults to DefaultIntegrator.
|
98
102
|
test_run (bool, optional): If True, performs a test run to ensure the model's methods
|
99
103
|
(get_full_concs, get_fluxes, get_right_hand_side) work correctly with the initial conditions.
|
@@ -104,7 +108,7 @@ class Simulator:
|
|
104
108
|
y0 = model.get_initial_conditions() if y0 is None else y0
|
105
109
|
self.y0 = [y0[k] for k in model.get_variable_names()]
|
106
110
|
|
107
|
-
self.integrator = integrator(self.model,
|
111
|
+
self.integrator = integrator(self.model, self.y0)
|
108
112
|
self.concs = None
|
109
113
|
self.args = None
|
110
114
|
self.simulation_parameters = None
|
modelbase2/surrogates/_poly.py
CHANGED
@@ -32,7 +32,9 @@ class PolySurrogate(AbstractSurrogate):
|
|
32
32
|
def train_polynomial_surrogate(
|
33
33
|
feature: ArrayLike,
|
34
34
|
target: ArrayLike,
|
35
|
-
series: Literal[
|
35
|
+
series: Literal[
|
36
|
+
"Power", "Chebyshev", "Legendre", "Laguerre", "Hermite", "HermiteE"
|
37
|
+
] = "Power",
|
36
38
|
degrees: Iterable[int] = (1, 2, 3, 4, 5, 6, 7),
|
37
39
|
surrogate_args: list[str] | None = None,
|
38
40
|
surrogate_stoichiometries: dict[str, dict[str, float]] | None = None,
|
modelbase2/surrogates/_torch.py
CHANGED
@@ -1,3 +1,4 @@
|
|
1
|
+
from collections.abc import Callable
|
1
2
|
from dataclasses import dataclass
|
2
3
|
|
3
4
|
import numpy as np
|
@@ -6,9 +7,10 @@ import torch
|
|
6
7
|
import tqdm
|
7
8
|
from torch import nn
|
8
9
|
from torch.optim.adam import Adam
|
10
|
+
from torch.optim.optimizer import ParamsT
|
9
11
|
|
10
|
-
from modelbase2.types import AbstractSurrogate
|
11
12
|
from modelbase2.nnarchitectures import MLP, DefaultDevice
|
13
|
+
from modelbase2.types import AbstractSurrogate
|
12
14
|
|
13
15
|
__all__ = ["TorchSurrogate", "train_torch_surrogate"]
|
14
16
|
|
@@ -124,7 +126,7 @@ def train_torch_surrogate(
|
|
124
126
|
surrogate_stoichiometries: dict[str, dict[str, float]] | None = None,
|
125
127
|
batch_size: int | None = None,
|
126
128
|
approximator: nn.Module | None = None,
|
127
|
-
optimimzer_cls:
|
129
|
+
optimimzer_cls: Callable[[ParamsT], Adam] = Adam,
|
128
130
|
device: torch.device = DefaultDevice,
|
129
131
|
) -> tuple[TorchSurrogate, pd.Series]:
|
130
132
|
"""Train a PyTorch surrogate model.
|
modelbase2/surrogates.py
CHANGED
@@ -19,6 +19,7 @@ from __future__ import annotations
|
|
19
19
|
from abc import abstractmethod
|
20
20
|
from dataclasses import dataclass
|
21
21
|
from pathlib import Path
|
22
|
+
from typing import TYPE_CHECKING
|
22
23
|
|
23
24
|
import numpy as np
|
24
25
|
import pandas as pd
|
@@ -29,6 +30,11 @@ from torch.optim.adam import Adam
|
|
29
30
|
|
30
31
|
from modelbase2.parallel import Cache
|
31
32
|
|
33
|
+
if TYPE_CHECKING:
|
34
|
+
from collections.abc import Callable
|
35
|
+
|
36
|
+
from torch.optim.optimizer import ParamsT
|
37
|
+
|
32
38
|
__all__ = [
|
33
39
|
"AbstractSurrogate",
|
34
40
|
"Approximator",
|
@@ -251,7 +257,7 @@ def train_torch_surrogate(
|
|
251
257
|
surrogate_stoichiometries: dict[str, dict[str, float]],
|
252
258
|
batch_size: int | None = None,
|
253
259
|
approximator: nn.Module | None = None,
|
254
|
-
optimimzer_cls:
|
260
|
+
optimimzer_cls: Callable[[ParamsT], Adam] = Adam,
|
255
261
|
device: torch.device = DefaultDevice,
|
256
262
|
) -> tuple[TorchSurrogate, pd.Series]:
|
257
263
|
"""Train a PyTorch surrogate model.
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: modelbase2
|
3
|
-
Version: 0.
|
3
|
+
Version: 0.5.0
|
4
4
|
Summary: A package to build metabolic models
|
5
5
|
Author-email: Marvin van Aalst <marvin.vanaalst@gmail.com>
|
6
6
|
Maintainer-email: Marvin van Aalst <marvin.vanaalst@gmail.com>
|
@@ -33,6 +33,7 @@ Requires-Dist: pebble>=5.0.7
|
|
33
33
|
Requires-Dist: python-libsbml>=5.20.4
|
34
34
|
Requires-Dist: scipy>=1.14.1
|
35
35
|
Requires-Dist: seaborn>=0.13.2
|
36
|
+
Requires-Dist: symbtools>=0.4.0
|
36
37
|
Requires-Dist: sympy>=1.13.1
|
37
38
|
Requires-Dist: tabulate>=0.9.0
|
38
39
|
Requires-Dist: tqdm>=4.66.6
|