mxlpy 0.18.0__py3-none-any.whl → 0.20.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mxlpy/__init__.py +13 -9
- mxlpy/compare.py +240 -0
- mxlpy/experimental/diff.py +16 -4
- mxlpy/fit.py +6 -11
- mxlpy/fns.py +37 -42
- mxlpy/identify.py +10 -3
- mxlpy/integrators/__init__.py +4 -3
- mxlpy/integrators/int_assimulo.py +16 -9
- mxlpy/integrators/int_scipy.py +13 -9
- mxlpy/label_map.py +7 -3
- mxlpy/linear_label_map.py +4 -2
- mxlpy/mc.py +5 -14
- mxlpy/mca.py +4 -4
- mxlpy/meta/__init__.py +6 -4
- mxlpy/meta/codegen_latex.py +180 -87
- mxlpy/meta/codegen_modebase.py +3 -1
- mxlpy/meta/codegen_py.py +11 -3
- mxlpy/meta/source_tools.py +9 -5
- mxlpy/model.py +187 -100
- mxlpy/nn/__init__.py +24 -5
- mxlpy/nn/_keras.py +92 -0
- mxlpy/nn/_torch.py +25 -18
- mxlpy/npe/__init__.py +21 -16
- mxlpy/npe/_keras.py +326 -0
- mxlpy/npe/_torch.py +56 -60
- mxlpy/parallel.py +5 -2
- mxlpy/parameterise.py +11 -3
- mxlpy/plot.py +205 -52
- mxlpy/report.py +33 -8
- mxlpy/sbml/__init__.py +3 -3
- mxlpy/sbml/_data.py +7 -6
- mxlpy/sbml/_export.py +8 -1
- mxlpy/sbml/_mathml.py +8 -7
- mxlpy/sbml/_name_conversion.py +5 -1
- mxlpy/scan.py +14 -19
- mxlpy/simulator.py +34 -31
- mxlpy/surrogates/__init__.py +25 -17
- mxlpy/surrogates/_keras.py +139 -0
- mxlpy/surrogates/_poly.py +25 -10
- mxlpy/surrogates/_qss.py +34 -0
- mxlpy/surrogates/_torch.py +50 -32
- mxlpy/symbolic/__init__.py +5 -3
- mxlpy/symbolic/strikepy.py +5 -2
- mxlpy/symbolic/symbolic_model.py +14 -5
- mxlpy/types.py +61 -120
- {mxlpy-0.18.0.dist-info → mxlpy-0.20.0.dist-info}/METADATA +25 -24
- mxlpy-0.20.0.dist-info/RECORD +55 -0
- mxlpy/nn/_tensorflow.py +0 -0
- mxlpy-0.18.0.dist-info/RECORD +0 -51
- {mxlpy-0.18.0.dist-info → mxlpy-0.20.0.dist-info}/WHEEL +0 -0
- {mxlpy-0.18.0.dist-info → mxlpy-0.20.0.dist-info}/licenses/LICENSE +0 -0
mxlpy/meta/source_tools.py
CHANGED
@@ -1,17 +1,21 @@
|
|
1
1
|
"""Tools for working with python source files."""
|
2
2
|
|
3
|
+
from __future__ import annotations
|
4
|
+
|
3
5
|
import ast
|
4
6
|
import inspect
|
5
7
|
import textwrap
|
6
|
-
from collections.abc import Callable
|
7
8
|
from dataclasses import dataclass
|
8
|
-
from
|
9
|
-
from typing import Any, cast
|
9
|
+
from typing import TYPE_CHECKING, Any, cast
|
10
10
|
|
11
11
|
import dill
|
12
12
|
import sympy
|
13
13
|
from sympy.printing.pycode import pycode
|
14
14
|
|
15
|
+
if TYPE_CHECKING:
|
16
|
+
from collections.abc import Callable
|
17
|
+
from types import ModuleType
|
18
|
+
|
15
19
|
__all__ = [
|
16
20
|
"Context",
|
17
21
|
"fn_to_sympy",
|
@@ -35,7 +39,7 @@ class Context:
|
|
35
39
|
symbols: dict[str, sympy.Symbol | sympy.Expr] | None = None,
|
36
40
|
caller: Callable | None = None,
|
37
41
|
parent_module: ModuleType | None = None,
|
38
|
-
) ->
|
42
|
+
) -> Context:
|
39
43
|
"""Update the context with new values."""
|
40
44
|
return Context(
|
41
45
|
symbols=self.symbols if symbols is None else symbols,
|
@@ -434,5 +438,5 @@ def _handle_call(node: ast.Call, ctx: Context) -> sympy.Expr:
|
|
434
438
|
ctx=ctx.updated(parent_module=imports[module_name.id]),
|
435
439
|
)
|
436
440
|
|
437
|
-
msg = f"
|
441
|
+
msg = f"Unsupported function type {node.func}"
|
438
442
|
raise NotImplementedError(msg)
|
mxlpy/model.py
CHANGED
@@ -17,13 +17,13 @@ import numpy as np
|
|
17
17
|
import pandas as pd
|
18
18
|
|
19
19
|
from mxlpy import fns
|
20
|
-
from mxlpy.types import
|
21
|
-
|
22
|
-
|
23
|
-
|
24
|
-
|
25
|
-
|
26
|
-
|
20
|
+
from mxlpy.types import AbstractSurrogate, Array, Derived, Reaction, Readout
|
21
|
+
|
22
|
+
if TYPE_CHECKING:
|
23
|
+
from collections.abc import Iterable, Mapping
|
24
|
+
from inspect import FullArgSpec
|
25
|
+
|
26
|
+
from mxlpy.types import Callable, Param, RateFn, RetType
|
27
27
|
|
28
28
|
__all__ = [
|
29
29
|
"ArityMismatchError",
|
@@ -34,12 +34,6 @@ __all__ = [
|
|
34
34
|
"ModelCache",
|
35
35
|
]
|
36
36
|
|
37
|
-
if TYPE_CHECKING:
|
38
|
-
from collections.abc import Iterable, Mapping
|
39
|
-
from inspect import FullArgSpec
|
40
|
-
|
41
|
-
from mxlpy.types import Callable, Param, RateFn, RetType
|
42
|
-
|
43
37
|
|
44
38
|
@dataclass
|
45
39
|
class Dependency:
|
@@ -250,17 +244,17 @@ class ModelCache:
|
|
250
244
|
stoich_by_cpds: A dictionary mapping compound names to their stoichiometric coefficients.
|
251
245
|
dyn_stoich_by_cpds: A dictionary mapping compound names to their dynamic stoichiometric coefficients.
|
252
246
|
dxdt: A pandas Series representing the rate of change of variables.
|
247
|
+
initial_conditions: calculated initial conditions
|
253
248
|
|
254
249
|
"""
|
255
250
|
|
256
251
|
var_names: list[str]
|
257
|
-
|
252
|
+
dyn_order: list[str]
|
258
253
|
all_parameter_values: dict[str, float]
|
259
|
-
derived_parameter_names: list[str]
|
260
|
-
derived_variable_names: list[str]
|
261
254
|
stoich_by_cpds: dict[str, dict[str, float]]
|
262
255
|
dyn_stoich_by_cpds: dict[str, dict[str, Derived]]
|
263
256
|
dxdt: pd.Series
|
257
|
+
initial_conditions: dict[str, float]
|
264
258
|
|
265
259
|
|
266
260
|
@dataclass(slots=True)
|
@@ -276,17 +270,19 @@ class Model:
|
|
276
270
|
_reactions: Dictionary of reactions in the model.
|
277
271
|
_surrogates: Dictionary of surrogate models.
|
278
272
|
_cache: Cache for storing model-related data structures.
|
273
|
+
_data: Named references to data sets
|
279
274
|
|
280
275
|
"""
|
281
276
|
|
282
277
|
_ids: dict[str, str] = field(default_factory=dict)
|
283
|
-
_variables: dict[str, float] = field(default_factory=dict)
|
278
|
+
_variables: dict[str, float | Derived] = field(default_factory=dict)
|
284
279
|
_parameters: dict[str, float] = field(default_factory=dict)
|
285
280
|
_derived: dict[str, Derived] = field(default_factory=dict)
|
286
281
|
_readouts: dict[str, Readout] = field(default_factory=dict)
|
287
282
|
_reactions: dict[str, Reaction] = field(default_factory=dict)
|
288
283
|
_surrogates: dict[str, AbstractSurrogate] = field(default_factory=dict)
|
289
284
|
_cache: ModelCache | None = None
|
285
|
+
_data: dict[str, pd.Series | pd.DataFrame] = field(default_factory=dict)
|
290
286
|
|
291
287
|
###########################################################################
|
292
288
|
# Cache
|
@@ -317,9 +313,17 @@ class Model:
|
|
317
313
|
raise ArityMismatchError(name, el.fn, el.args)
|
318
314
|
|
319
315
|
# Sort derived & reactions
|
320
|
-
to_sort =
|
316
|
+
to_sort = (
|
317
|
+
self._derived
|
318
|
+
| self._reactions
|
319
|
+
| self._surrogates
|
320
|
+
| {k: v for k, v in self._variables.items() if isinstance(v, Derived)}
|
321
|
+
)
|
321
322
|
order = _sort_dependencies(
|
322
|
-
available=
|
323
|
+
available=all_parameter_names
|
324
|
+
| {k for k, v in self._variables.items() if not isinstance(v, Derived)}
|
325
|
+
| set(self._data)
|
326
|
+
| {"time"},
|
323
327
|
elements=[
|
324
328
|
Dependency(name=k, required=set(v.args), provided={k})
|
325
329
|
if not isinstance(v, AbstractSurrogate)
|
@@ -328,35 +332,42 @@ class Model:
|
|
328
332
|
],
|
329
333
|
)
|
330
334
|
|
331
|
-
#
|
332
|
-
#
|
333
|
-
|
334
|
-
|
335
|
+
# Calculate all values once, including dynamic ones
|
336
|
+
# That way, we can make initial conditions dependent on e.g. rates
|
337
|
+
dependent = (
|
338
|
+
all_parameter_values
|
339
|
+
| self._data
|
340
|
+
| {k: v for k, v in self._variables.items() if not isinstance(v, Derived)}
|
341
|
+
| {"time": 0.0}
|
342
|
+
)
|
343
|
+
for name in order:
|
344
|
+
to_sort[name].calculate_inpl(name, dependent)
|
345
|
+
|
346
|
+
# Split derived into static and dynamic variables
|
347
|
+
static_order = []
|
348
|
+
dyn_order = []
|
335
349
|
for name in order:
|
336
350
|
if name in self._reactions or name in self._surrogates:
|
337
|
-
|
338
|
-
|
339
|
-
|
340
|
-
all_parameter_names.add(name)
|
341
|
-
derived_parameter_names.append(name)
|
342
|
-
all_parameter_values[name] = float(
|
343
|
-
derived.fn(*(all_parameter_values[i] for i in derived.args))
|
344
|
-
)
|
351
|
+
dyn_order.append(name)
|
352
|
+
elif name in self._variables:
|
353
|
+
static_order.append(name)
|
345
354
|
else:
|
346
|
-
|
355
|
+
derived = self._derived[name]
|
356
|
+
if all(i in all_parameter_names for i in derived.args):
|
357
|
+
static_order.append(name)
|
358
|
+
all_parameter_names.add(name)
|
359
|
+
else:
|
360
|
+
dyn_order.append(name)
|
347
361
|
|
362
|
+
# Calculate dynamic and static stochiometries
|
348
363
|
stoich_by_compounds: dict[str, dict[str, float]] = {}
|
349
364
|
dyn_stoich_by_compounds: dict[str, dict[str, Derived]] = {}
|
350
|
-
|
351
365
|
for rxn_name, rxn in self._reactions.items():
|
352
366
|
for cpd_name, factor in rxn.stoichiometry.items():
|
353
367
|
d_static = stoich_by_compounds.setdefault(cpd_name, {})
|
354
|
-
|
355
368
|
if isinstance(factor, Derived):
|
356
369
|
if all(i in all_parameter_names for i in factor.args):
|
357
|
-
d_static[rxn_name] =
|
358
|
-
factor.fn(*(all_parameter_values[i] for i in factor.args))
|
359
|
-
)
|
370
|
+
d_static[rxn_name] = factor.calculate(dependent)
|
360
371
|
else:
|
361
372
|
dyn_stoich_by_compounds.setdefault(cpd_name, {})[rxn_name] = (
|
362
373
|
factor
|
@@ -367,20 +378,40 @@ class Model:
|
|
367
378
|
for surrogate in self._surrogates.values():
|
368
379
|
for rxn_name, rxn in surrogate.stoichiometries.items():
|
369
380
|
for cpd_name, factor in rxn.items():
|
370
|
-
stoich_by_compounds.setdefault(cpd_name, {})
|
381
|
+
d_static = stoich_by_compounds.setdefault(cpd_name, {})
|
382
|
+
if isinstance(factor, Derived):
|
383
|
+
if all(i in all_parameter_names for i in factor.args):
|
384
|
+
d_static[rxn_name] = factor.calculate(dependent)
|
385
|
+
else:
|
386
|
+
dyn_stoich_by_compounds.setdefault(cpd_name, {})[
|
387
|
+
rxn_name
|
388
|
+
] = factor
|
389
|
+
else:
|
390
|
+
d_static[rxn_name] = factor
|
371
391
|
|
372
392
|
var_names = self.get_variable_names()
|
373
393
|
dxdt = pd.Series(np.zeros(len(var_names), dtype=float), index=var_names)
|
374
394
|
|
395
|
+
initial_conditions: dict[str, float] = {
|
396
|
+
k: v for k, v in self._variables.items() if not isinstance(v, Derived)
|
397
|
+
}
|
398
|
+
for name in static_order:
|
399
|
+
if name in self._variables:
|
400
|
+
initial_conditions[name] = cast(float, dependent[name])
|
401
|
+
elif name in self._derived:
|
402
|
+
all_parameter_values[name] = cast(float, dependent[name])
|
403
|
+
else:
|
404
|
+
msg = "Unknown target for static derived variable."
|
405
|
+
raise KeyError(msg)
|
406
|
+
|
375
407
|
self._cache = ModelCache(
|
376
408
|
var_names=var_names,
|
377
|
-
|
409
|
+
dyn_order=dyn_order,
|
378
410
|
all_parameter_values=all_parameter_values,
|
379
411
|
stoich_by_cpds=stoich_by_compounds,
|
380
412
|
dyn_stoich_by_cpds=dyn_stoich_by_compounds,
|
381
|
-
derived_variable_names=derived_variable_names,
|
382
|
-
derived_parameter_names=derived_parameter_names,
|
383
413
|
dxdt=dxdt,
|
414
|
+
initial_conditions=initial_conditions,
|
384
415
|
)
|
385
416
|
return self._cache
|
386
417
|
|
@@ -657,12 +688,27 @@ class Model:
|
|
657
688
|
|
658
689
|
return self
|
659
690
|
|
691
|
+
def get_unused_parameters(self) -> set[str]:
|
692
|
+
"""Get parameters which aren't used in the model."""
|
693
|
+
args = set()
|
694
|
+
for variable in self._variables.values():
|
695
|
+
if isinstance(variable, Derived):
|
696
|
+
args.update(variable.args)
|
697
|
+
for derived in self._derived.values():
|
698
|
+
args.update(derived.args)
|
699
|
+
for reaction in self._reactions.values():
|
700
|
+
args.update(reaction.args)
|
701
|
+
for surrogate in self._surrogates.values():
|
702
|
+
args.update(surrogate.args)
|
703
|
+
|
704
|
+
return set(self._parameters).difference(args)
|
705
|
+
|
660
706
|
##########################################################################
|
661
707
|
# Variables
|
662
708
|
##########################################################################
|
663
709
|
|
664
710
|
@property
|
665
|
-
def variables(self) -> dict[str, float]:
|
711
|
+
def variables(self) -> dict[str, float | Derived]:
|
666
712
|
"""Returns a copy of the variables dictionary.
|
667
713
|
|
668
714
|
Examples:
|
@@ -679,7 +725,7 @@ class Model:
|
|
679
725
|
return self._variables.copy()
|
680
726
|
|
681
727
|
@_invalidate_cache
|
682
|
-
def add_variable(self, name: str, initial_condition: float) -> Self:
|
728
|
+
def add_variable(self, name: str, initial_condition: float | Derived) -> Self:
|
683
729
|
"""Adds a variable to the model with the given name and initial condition.
|
684
730
|
|
685
731
|
Examples:
|
@@ -697,7 +743,7 @@ class Model:
|
|
697
743
|
self._variables[name] = initial_condition
|
698
744
|
return self
|
699
745
|
|
700
|
-
def add_variables(self, variables:
|
746
|
+
def add_variables(self, variables: Mapping[str, float | Derived]) -> Self:
|
701
747
|
"""Adds multiple variables to the model with their initial conditions.
|
702
748
|
|
703
749
|
Examples:
|
@@ -751,7 +797,7 @@ class Model:
|
|
751
797
|
return self
|
752
798
|
|
753
799
|
@_invalidate_cache
|
754
|
-
def update_variable(self, name: str, initial_condition: float) -> Self:
|
800
|
+
def update_variable(self, name: str, initial_condition: float | Derived) -> Self:
|
755
801
|
"""Updates the value of a variable in the model.
|
756
802
|
|
757
803
|
Examples:
|
@@ -771,7 +817,7 @@ class Model:
|
|
771
817
|
self._variables[name] = initial_condition
|
772
818
|
return self
|
773
819
|
|
774
|
-
def update_variables(self, variables:
|
820
|
+
def update_variables(self, variables: Mapping[str, float | Derived]) -> Self:
|
775
821
|
"""Updates multiple variables in the model.
|
776
822
|
|
777
823
|
Examples:
|
@@ -812,7 +858,9 @@ class Model:
|
|
812
858
|
initial_conditions: A dictionary where the keys are variable names and the values are their initial conditions.
|
813
859
|
|
814
860
|
"""
|
815
|
-
|
861
|
+
if (cache := self._cache) is None:
|
862
|
+
cache = self._create_cache()
|
863
|
+
return cache.initial_conditions
|
816
864
|
|
817
865
|
def make_variable_static(self, name: str, value: float | None = None) -> Self:
|
818
866
|
"""Converts a variable to a static parameter.
|
@@ -833,9 +881,12 @@ class Model:
|
|
833
881
|
Self: The instance of the class for method chaining.
|
834
882
|
|
835
883
|
"""
|
836
|
-
|
884
|
+
value_or_derived = self._variables[name] if value is None else value
|
837
885
|
self.remove_variable(name)
|
838
|
-
|
886
|
+
if isinstance(value_or_derived, Derived):
|
887
|
+
self.add_derived(name, value_or_derived.fn, args=value_or_derived.args)
|
888
|
+
else:
|
889
|
+
self.add_parameter(name, value_or_derived)
|
839
890
|
|
840
891
|
# Remove from stoichiometries
|
841
892
|
for reaction in self._reactions.values():
|
@@ -886,7 +937,8 @@ class Model:
|
|
886
937
|
if (cache := self._cache) is None:
|
887
938
|
cache = self._create_cache()
|
888
939
|
derived = self._derived
|
889
|
-
|
940
|
+
|
941
|
+
return {k: v for k, v in derived.items() if k not in cache.all_parameter_values}
|
890
942
|
|
891
943
|
@property
|
892
944
|
def derived_parameters(self) -> dict[str, Derived]:
|
@@ -905,7 +957,7 @@ class Model:
|
|
905
957
|
if (cache := self._cache) is None:
|
906
958
|
cache = self._create_cache()
|
907
959
|
derived = self._derived
|
908
|
-
return {k:
|
960
|
+
return {k: v for k, v in derived.items() if k in cache.all_parameter_values}
|
909
961
|
|
910
962
|
@_invalidate_cache
|
911
963
|
def add_derived(
|
@@ -1049,6 +1101,52 @@ class Model:
|
|
1049
1101
|
)
|
1050
1102
|
return pd.DataFrame(stoich_by_cpds).T.fillna(0)
|
1051
1103
|
|
1104
|
+
def get_stoichiometries_of_variable(
|
1105
|
+
self,
|
1106
|
+
variable: str,
|
1107
|
+
variables: dict[str, float] | None = None,
|
1108
|
+
time: float = 0.0,
|
1109
|
+
) -> dict[str, float]:
|
1110
|
+
"""Retrieve the stoichiometry of a specific variable.
|
1111
|
+
|
1112
|
+
Examples:
|
1113
|
+
>>> model.get_stoichiometries_of_variable("x1")
|
1114
|
+
{"v1": -1, "v2": 1}
|
1115
|
+
|
1116
|
+
Args:
|
1117
|
+
variable: The name of the variable for which to retrieve the stoichiometry.
|
1118
|
+
variables: A dictionary of variable names and their values.
|
1119
|
+
time: The time point at which to evaluate the stoichiometry.
|
1120
|
+
|
1121
|
+
"""
|
1122
|
+
if (cache := self._cache) is None:
|
1123
|
+
cache = self._create_cache()
|
1124
|
+
args = self.get_dependent(variables=variables, time=time)
|
1125
|
+
|
1126
|
+
stoich = copy.deepcopy(cache.stoich_by_cpds[variable])
|
1127
|
+
for rxn, derived in cache.dyn_stoich_by_cpds.get(variable, {}).items():
|
1128
|
+
stoich[rxn] = float(derived.fn(*(args[i] for i in derived.args)))
|
1129
|
+
return stoich
|
1130
|
+
|
1131
|
+
def get_raw_stoichiometries_of_variable(
|
1132
|
+
self, variable: str
|
1133
|
+
) -> dict[str, float | Derived]:
|
1134
|
+
"""Retrieve the raw stoichiometry of a specific variable.
|
1135
|
+
|
1136
|
+
Examples:
|
1137
|
+
>>> model.get_stoichiometries_of_variable("x1")
|
1138
|
+
{"v1": -1, "v2": Derived(...)}
|
1139
|
+
|
1140
|
+
Args:
|
1141
|
+
variable: The name of the variable for which to retrieve the stoichiometry.
|
1142
|
+
|
1143
|
+
"""
|
1144
|
+
stoichs: dict[str, dict[str, float | Derived]] = {}
|
1145
|
+
for rxn_name, rxn in self._reactions.items():
|
1146
|
+
for cpd_name, factor in rxn.stoichiometry.items():
|
1147
|
+
stoichs.setdefault(cpd_name, {})[rxn_name] = factor
|
1148
|
+
return stoichs[variable]
|
1149
|
+
|
1052
1150
|
@_invalidate_cache
|
1053
1151
|
def add_reaction(
|
1054
1152
|
self,
|
@@ -1250,7 +1348,7 @@ class Model:
|
|
1250
1348
|
surrogate: AbstractSurrogate,
|
1251
1349
|
args: list[str] | None = None,
|
1252
1350
|
outputs: list[str] | None = None,
|
1253
|
-
stoichiometries: dict[str, dict[str, float]] | None = None,
|
1351
|
+
stoichiometries: dict[str, dict[str, float | Derived]] | None = None,
|
1254
1352
|
) -> Self:
|
1255
1353
|
"""Adds a surrogate model to the current instance.
|
1256
1354
|
|
@@ -1285,7 +1383,7 @@ class Model:
|
|
1285
1383
|
name: str,
|
1286
1384
|
surrogate: AbstractSurrogate | None = None,
|
1287
1385
|
args: list[str] | None = None,
|
1288
|
-
stoichiometries: dict[str, dict[str, float]] | None = None,
|
1386
|
+
stoichiometries: dict[str, dict[str, float | Derived]] | None = None,
|
1289
1387
|
) -> Self:
|
1290
1388
|
"""Update a surrogate model in the model.
|
1291
1389
|
|
@@ -1337,6 +1435,27 @@ class Model:
|
|
1337
1435
|
names.extend(i.stoichiometries)
|
1338
1436
|
return names
|
1339
1437
|
|
1438
|
+
##########################################################################
|
1439
|
+
# Datasets
|
1440
|
+
##########################################################################
|
1441
|
+
|
1442
|
+
def add_data(self, name: str, data: pd.Series | pd.DataFrame) -> Self:
|
1443
|
+
"""Add named data set to model."""
|
1444
|
+
self._insert_id(name=name, ctx="data")
|
1445
|
+
self._data[name] = data
|
1446
|
+
return self
|
1447
|
+
|
1448
|
+
def update_data(self, name: str, data: pd.Series | pd.DataFrame) -> Self:
|
1449
|
+
"""Update named data set."""
|
1450
|
+
self._data[name] = data
|
1451
|
+
return self
|
1452
|
+
|
1453
|
+
def remove_data(self, name: str) -> Self:
|
1454
|
+
"""Remove data set from model."""
|
1455
|
+
self._remove_id(name=name)
|
1456
|
+
self._data.pop(name)
|
1457
|
+
return self
|
1458
|
+
|
1340
1459
|
##########################################################################
|
1341
1460
|
# Get dependent values. This includes
|
1342
1461
|
# - derived parameters
|
@@ -1371,14 +1490,17 @@ class Model:
|
|
1371
1490
|
with their respective names as keys and their calculated values as values.
|
1372
1491
|
|
1373
1492
|
"""
|
1374
|
-
args
|
1493
|
+
args = cache.all_parameter_values | variables | self._data
|
1375
1494
|
args["time"] = time
|
1376
1495
|
|
1377
1496
|
containers = self._derived | self._reactions | self._surrogates
|
1378
|
-
for name in cache.
|
1497
|
+
for name in cache.dyn_order:
|
1379
1498
|
containers[name].calculate_inpl(name, args)
|
1380
1499
|
|
1381
|
-
|
1500
|
+
for k in self._data:
|
1501
|
+
args.pop(k)
|
1502
|
+
|
1503
|
+
return cast(dict[str, float], args)
|
1382
1504
|
|
1383
1505
|
def get_dependent(
|
1384
1506
|
self,
|
@@ -1454,29 +1576,16 @@ class Model:
|
|
1454
1576
|
derived variables, and optionally readout variables, with time as an additional column.
|
1455
1577
|
|
1456
1578
|
"""
|
1457
|
-
|
1458
|
-
|
1459
|
-
|
1460
|
-
|
1461
|
-
|
1462
|
-
|
1463
|
-
|
1464
|
-
|
1465
|
-
index=variables.index,
|
1466
|
-
columns=list(cache.all_parameter_values),
|
1467
|
-
)
|
1468
|
-
|
1469
|
-
args = pd.concat((variables, pars_df), axis=1)
|
1470
|
-
args["time"] = args.index
|
1471
|
-
|
1472
|
-
containers = self._derived | self._reactions | self._surrogates
|
1473
|
-
for name in cache.order:
|
1474
|
-
containers[name].calculate_inpl_time_course(name, args)
|
1579
|
+
args = {
|
1580
|
+
time: self.get_dependent(
|
1581
|
+
variables=values.to_dict(),
|
1582
|
+
time=cast(float, time),
|
1583
|
+
include_readouts=include_readouts,
|
1584
|
+
)
|
1585
|
+
for time, values in variables.iterrows()
|
1586
|
+
}
|
1475
1587
|
|
1476
|
-
|
1477
|
-
for name, ro in self._readouts.items():
|
1478
|
-
args[name] = ro.fn(*args.loc[:, ro.args].to_numpy().T)
|
1479
|
-
return args
|
1588
|
+
return pd.DataFrame(args, dtype=float).T
|
1480
1589
|
|
1481
1590
|
##########################################################################
|
1482
1591
|
# Get args
|
@@ -1569,28 +1678,6 @@ class Model:
|
|
1569
1678
|
# Get fluxes
|
1570
1679
|
##########################################################################
|
1571
1680
|
|
1572
|
-
def _get_fluxes(self, args: dict[str, float]) -> dict[str, float]:
|
1573
|
-
"""Calculate the fluxes for the given arguments.
|
1574
|
-
|
1575
|
-
Examples:
|
1576
|
-
>>> model._get_fluxes({"x1": 1.0, "x2": 2.0, "k1": 0.1, "time": 0.0})
|
1577
|
-
{"r1": 0.1, "r2": 0.2}
|
1578
|
-
|
1579
|
-
Args:
|
1580
|
-
args (dict[str, float]): A dictionary where the keys are argument names and the values are their corresponding float values.
|
1581
|
-
|
1582
|
-
Returns:
|
1583
|
-
dict[str, float]: A dictionary where the keys are reaction names and the values are the calculated fluxes.
|
1584
|
-
|
1585
|
-
"""
|
1586
|
-
fluxes: dict[str, float] = {}
|
1587
|
-
for name, rxn in self._reactions.items():
|
1588
|
-
fluxes[name] = cast(float, rxn.fn(*(args[arg] for arg in rxn.args)))
|
1589
|
-
|
1590
|
-
for surrogate in self._surrogates.values():
|
1591
|
-
fluxes |= surrogate.predict(np.array([args[arg] for arg in surrogate.args]))
|
1592
|
-
return fluxes
|
1593
|
-
|
1594
1681
|
def get_fluxes(
|
1595
1682
|
self,
|
1596
1683
|
variables: dict[str, float] | None = None,
|
mxlpy/nn/__init__.py
CHANGED
@@ -1,10 +1,29 @@
|
|
1
1
|
"""Collection of neural network architectures."""
|
2
2
|
|
3
|
-
import
|
3
|
+
from __future__ import annotations
|
4
4
|
|
5
|
-
|
5
|
+
from typing import TYPE_CHECKING
|
6
6
|
|
7
|
-
|
8
|
-
|
7
|
+
if TYPE_CHECKING:
|
8
|
+
import contextlib
|
9
9
|
|
10
|
-
|
10
|
+
with contextlib.suppress(ImportError):
|
11
|
+
from . import _keras as keras
|
12
|
+
from . import _torch as torch
|
13
|
+
else:
|
14
|
+
from lazy_import import lazy_module
|
15
|
+
|
16
|
+
keras = lazy_module(
|
17
|
+
"mxlpy.nn._keras",
|
18
|
+
error_strings={"module": "keras", "install_name": "mxlpy[tf]"},
|
19
|
+
)
|
20
|
+
torch = lazy_module(
|
21
|
+
"mxlpy.nn._torch",
|
22
|
+
error_strings={"module": "torch", "install_name": "mxlpy[torch]"},
|
23
|
+
)
|
24
|
+
|
25
|
+
|
26
|
+
__all__ = [
|
27
|
+
"keras",
|
28
|
+
"torch",
|
29
|
+
]
|
mxlpy/nn/_keras.py
ADDED
@@ -0,0 +1,92 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
from typing import TYPE_CHECKING, cast
|
4
|
+
|
5
|
+
import keras
|
6
|
+
import pandas as pd
|
7
|
+
from tqdm.keras import TqdmCallback
|
8
|
+
|
9
|
+
if TYPE_CHECKING:
|
10
|
+
from mxlpy.types import Array
|
11
|
+
|
12
|
+
__all__ = [
|
13
|
+
"LSTM",
|
14
|
+
"MLP",
|
15
|
+
"train",
|
16
|
+
]
|
17
|
+
|
18
|
+
|
19
|
+
def train(
|
20
|
+
model: keras.Model,
|
21
|
+
features: pd.DataFrame | Array,
|
22
|
+
targets: pd.DataFrame | Array,
|
23
|
+
epochs: int,
|
24
|
+
batch_size: int | None,
|
25
|
+
) -> pd.Series:
|
26
|
+
"""Train the neural network using mini-batch gradient descent.
|
27
|
+
|
28
|
+
Args:
|
29
|
+
model: Neural network model to train.
|
30
|
+
features: Input features as a tensor.
|
31
|
+
targets: Target values as a tensor.
|
32
|
+
epochs: Number of training epochs.
|
33
|
+
optimizer: Optimizer for training.
|
34
|
+
device: torch device
|
35
|
+
batch_size: Size of mini-batches for training.
|
36
|
+
loss_fn: Loss function
|
37
|
+
|
38
|
+
Returns:
|
39
|
+
pd.Series: Series containing the training loss history.
|
40
|
+
|
41
|
+
"""
|
42
|
+
history = model.fit(
|
43
|
+
features,
|
44
|
+
targets,
|
45
|
+
batch_size=batch_size,
|
46
|
+
epochs=epochs,
|
47
|
+
verbose=cast(str, 0),
|
48
|
+
callbacks=[TqdmCallback()],
|
49
|
+
)
|
50
|
+
return pd.Series(history.history["loss"])
|
51
|
+
|
52
|
+
|
53
|
+
def MLP( # noqa: N802
|
54
|
+
n_inputs: int,
|
55
|
+
neurons_per_layer: list[int],
|
56
|
+
activation: None = None,
|
57
|
+
output_activation: None = None,
|
58
|
+
) -> keras.Sequential:
|
59
|
+
"""Multilayer Perceptron (MLP) for surrogate modeling and neural posterior estimation.
|
60
|
+
|
61
|
+
Methods:
|
62
|
+
forward: Forward pass through the neural network.
|
63
|
+
|
64
|
+
"""
|
65
|
+
model = keras.Sequential([keras.Input(shape=(n_inputs,))])
|
66
|
+
for neurons in neurons_per_layer[:-1]:
|
67
|
+
model.add(keras.layers.Dense(neurons, activation=activation))
|
68
|
+
model.add(keras.layers.Dense(neurons_per_layer[-1], activation=output_activation))
|
69
|
+
return model
|
70
|
+
|
71
|
+
|
72
|
+
def LSTM( # noqa: N802
|
73
|
+
n_inputs: int,
|
74
|
+
n_outputs: int,
|
75
|
+
n_hidden: int,
|
76
|
+
) -> keras.Sequential:
|
77
|
+
"""Long Short-Term Memory (LSTM) network for time series modeling.
|
78
|
+
|
79
|
+
Methods:
|
80
|
+
forward: Forward pass through the neural network.
|
81
|
+
|
82
|
+
"""
|
83
|
+
model = keras.Sequential(
|
84
|
+
[
|
85
|
+
keras.Input(
|
86
|
+
shape=(n_inputs),
|
87
|
+
)
|
88
|
+
]
|
89
|
+
)
|
90
|
+
model.add(keras.layers.LSTM(n_hidden))
|
91
|
+
model.add(keras.layers.Dense(n_outputs))
|
92
|
+
return model
|