modelbase2 0.1.79__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- modelbase2/__init__.py +148 -25
- modelbase2/distributions.py +336 -0
- modelbase2/experimental/__init__.py +17 -0
- modelbase2/experimental/codegen.py +239 -0
- modelbase2/experimental/diff.py +227 -0
- modelbase2/experimental/notes.md +4 -0
- modelbase2/experimental/tex.py +521 -0
- modelbase2/fit.py +284 -0
- modelbase2/fns.py +185 -0
- modelbase2/integrators/__init__.py +19 -0
- modelbase2/integrators/int_assimulo.py +146 -0
- modelbase2/integrators/int_scipy.py +147 -0
- modelbase2/label_map.py +610 -0
- modelbase2/linear_label_map.py +301 -0
- modelbase2/mc.py +548 -0
- modelbase2/mca.py +280 -0
- modelbase2/model.py +1621 -0
- modelbase2/nnarchitectures.py +128 -0
- modelbase2/npe.py +271 -0
- modelbase2/parallel.py +171 -0
- modelbase2/parameterise.py +28 -0
- modelbase2/paths.py +36 -0
- modelbase2/plot.py +832 -0
- modelbase2/sbml/__init__.py +14 -0
- modelbase2/sbml/_data.py +77 -0
- modelbase2/sbml/_export.py +656 -0
- modelbase2/sbml/_import.py +585 -0
- modelbase2/sbml/_mathml.py +691 -0
- modelbase2/sbml/_name_conversion.py +52 -0
- modelbase2/sbml/_unit_conversion.py +74 -0
- modelbase2/scan.py +616 -0
- modelbase2/scope.py +96 -0
- modelbase2/simulator.py +635 -0
- modelbase2/surrogates/__init__.py +31 -0
- modelbase2/surrogates/_poly.py +91 -0
- modelbase2/surrogates/_torch.py +191 -0
- modelbase2/surrogates.py +316 -0
- modelbase2/types.py +352 -11
- modelbase2-0.3.0.dist-info/METADATA +93 -0
- modelbase2-0.3.0.dist-info/RECORD +43 -0
- {modelbase2-0.1.79.dist-info → modelbase2-0.3.0.dist-info}/WHEEL +1 -1
- modelbase2/core/__init__.py +0 -29
- modelbase2/core/algebraic_module_container.py +0 -130
- modelbase2/core/constant_container.py +0 -113
- modelbase2/core/data.py +0 -109
- modelbase2/core/name_container.py +0 -29
- modelbase2/core/reaction_container.py +0 -115
- modelbase2/core/utils.py +0 -28
- modelbase2/core/variable_container.py +0 -24
- modelbase2/ode/__init__.py +0 -13
- modelbase2/ode/integrator.py +0 -80
- modelbase2/ode/mca.py +0 -270
- modelbase2/ode/model.py +0 -470
- modelbase2/ode/simulator.py +0 -153
- modelbase2/utils/__init__.py +0 -0
- modelbase2/utils/plotting.py +0 -372
- modelbase2-0.1.79.dist-info/METADATA +0 -44
- modelbase2-0.1.79.dist-info/RECORD +0 -22
- {modelbase2-0.1.79.dist-info → modelbase2-0.3.0.dist-info/licenses}/LICENSE +0 -0
@@ -1,113 +0,0 @@
|
|
1
|
-
from __future__ import annotations
|
2
|
-
|
3
|
-
from .data import Constant, DerivedConstant
|
4
|
-
from dataclasses import dataclass, field
|
5
|
-
from queue import Empty, SimpleQueue
|
6
|
-
|
7
|
-
# import logging
|
8
|
-
# logger = logging.getLogger(__name__)
|
9
|
-
|
10
|
-
|
11
|
-
@dataclass
class ConstantContainer:
    """Store basic constants together with constants derived from them.

    ``values`` holds the current numeric value of every constant (basic and
    derived) keyed by name.  Derived constants are re-evaluated in
    ``_derived_constant_order`` so that each one only reads values that have
    already been computed.
    """

    constants: dict[str, Constant] = field(default_factory=dict)
    derived_constants: dict[str, DerivedConstant] = field(default_factory=dict)
    # Names that occur in the ``args`` of at least one derived constant.
    _derived_from_constants: set[str] = field(default_factory=set)
    # Evaluation order computed by ``_sort_derived_constants``.
    _derived_constant_order: list[str] = field(default_factory=list)
    values: dict[str, float] = field(default_factory=dict)

    ###############################################################################
    # Basic
    ###############################################################################

    def add_basic(self, constant: Constant, update_derived: bool = True) -> None:
        """Register a new basic constant and record its value.

        ``update_derived`` is accepted for interface compatibility; it is not
        used here.

        Raises:
            KeyError: if a basic constant with the same name already exists.
        """
        if (name := constant.name) in self.constants:
            raise KeyError(f"Constant {name} already exists in the model.")
        self.constants[name] = constant
        self.values[name] = constant.value

    def remove_basic(self, name: str) -> None:
        """Delete a basic constant and its value.

        Raises:
            KeyError: if ``name`` is not a registered basic constant.
        """
        del self.constants[name]
        del self.values[name]

    def update_basic(self, name: str, value: float, update_derived: bool = True) -> None:
        """Set a new value for an existing basic constant.

        Derived constants are recomputed when ``update_derived`` is true and
        at least one derived constant lists ``name`` among its args.

        Raises:
            KeyError: if ``name`` has not been added yet.
        """
        if name not in self.constants:
            raise KeyError(
                f"Constant {name} is not in the model. You have to add it first"
            )
        self.constants[name].value = value
        self.values[name] = value
        if name in self._derived_from_constants and update_derived:
            self._update_derived_constant_values()

    ###############################################################################
    # Derived
    ###############################################################################

    def _sort_derived_constants(self, max_iterations: int = 10_000) -> None:
        """Compute an evaluation order in which every arg is available.

        Derived constants are taken from a queue; one whose args are not all
        available yet is put back at the end and retried later.

        Raises:
            ValueError: if the same constant fails twice with nothing else
                failing in between, or if ``max_iterations`` is exceeded.
        """
        available_args = set(self.constants)
        order = []
        to_sort: SimpleQueue[tuple[str, DerivedConstant]] = SimpleQueue()
        for k, v in self.derived_constants.items():
            to_sort.put((k, v))

        i = 0
        last_name = ""
        while True:
            try:
                name, constant = to_sort.get_nowait()
            except Empty:
                break
            if set(constant.args).issubset(available_args):
                available_args.add(name)
                order.append(name)
            elif name == last_name:
                # Everything queued before this retry was sorted in and the
                # args are still unavailable, so they can never be satisfied.
                raise ValueError(f"Missing args for {name}")
            else:
                to_sort.put((name, constant))
                last_name = name
            i += 1
            if i > max_iterations:
                raise ValueError(
                    f"Exceeded max iterations on derived constants sorting {name}. "
                    "Check if there are circular references."
                )
        self._derived_constant_order = order

    def _update_derived_constant_values(self) -> None:
        """Re-evaluate every derived constant in dependency order."""
        for name in self._derived_constant_order:
            derived_constant = self.derived_constants[name]
            value = derived_constant.function(
                *(self.values[i] for i in derived_constant.args)
            )
            self.values[name] = value

    def add_derived(self, constant: DerivedConstant, update_derived: bool = True) -> None:
        """Register a derived constant and compute its initial value.

        Raises:
            KeyError: if one of ``constant.args`` has no value yet.  The
                container is left unchanged in that case.
        """
        name = constant.name

        # Evaluate before touching any state: a missing argument (or a failing
        # function) must not leave a half-registered constant behind.
        value = constant.function(*(self.values[i] for i in constant.args))

        self.derived_constants[name] = constant
        self._derived_from_constants.update(constant.args)
        self.values[name] = value
        self._sort_derived_constants()

        if name in self._derived_from_constants and update_derived:
            self._update_derived_constant_values()

    def remove_derived(self, name: str) -> DerivedConstant:
        """Remove a derived constant from the model and return it.

        Raises:
            KeyError: if ``name`` is not a registered derived constant.
        """
        old_constant = self.derived_constants.pop(name)
        for arg in old_constant.args:
            if all(arg not in j.args for j in self.derived_constants.values()):
                # ``discard`` rather than ``remove``: the same arg may be
                # listed more than once in ``args``.
                self._derived_from_constants.discard(arg)
        del self.values[name]
        self._derived_constant_order.remove(name)
        return old_constant
|
modelbase2/core/data.py
DELETED
@@ -1,109 +0,0 @@
|
|
1
|
-
from __future__ import annotations
|
2
|
-
|
3
|
-
from ..types import Array
|
4
|
-
from .utils import check_function_arity
|
5
|
-
from dataclasses import dataclass, field
|
6
|
-
from typing import Any, Callable, Generator, Iterable, Protocol, TypeVar, Union
|
7
|
-
|
8
|
-
T = TypeVar("T")


class Arithmetic(Protocol[T]):
    """Structural type for objects that define ``+`` and ``*`` over ``T``."""

    def __add__(self, other: T) -> T:
        """Return the result of ``self + other``."""

    def __mul__(self, other: T) -> T:
        """Return the result of ``self * other``."""
|
17
|
-
|
18
|
-
|
19
|
-
# Callable signatures used for the ``function`` fields of the dataclasses
# declared in this module.
RateFunction = Callable[..., float]
ModuleFunction = Callable[..., Iterable[float]]

# Name -> stoichiometric factor mappings.
StoichiometryByReaction = dict[str, float]
StoichiometryByVariable = dict[str, float]

# Accepted containers for values and for time points.
ValueData = Union[
    dict[str, float], dict[str, list[float]], dict[str, Array], Array, list[float]
]
TimeData = Union[float, list[float], Array]
|
31
|
-
|
32
|
-
|
33
|
-
@dataclass
class Variable:
    """A named model variable with its unit."""

    name: str
    unit: str
|
37
|
-
|
38
|
-
|
39
|
-
@dataclass
class Constant:
    """A named constant with a numeric value, a unit and optional sources."""

    name: str
    value: float
    unit: str
    # A fresh list per instance, so instances never share one ``sources`` list.
    sources: list[str] = field(default_factory=list)
|
45
|
-
|
46
|
-
|
47
|
-
@dataclass
class DerivedConstant:
    """A constant whose value is produced by ``function`` from the named ``args``."""

    name: str
    function: RateFunction
    # Names of the values passed to ``function``.
    args: list[str]
    unit: str
|
53
|
-
|
54
|
-
|
55
|
-
@dataclass
class DerivedStoichiometry:
    """A stoichiometric factor produced by ``function`` from the named ``args``."""

    function: RateFunction
    # Names of the values passed to ``function``.
    args: list[str]
|
59
|
-
|
60
|
-
|
61
|
-
@dataclass
class AlgebraicModule:
    """A named function of ``args`` associated with ``derived_variables``.

    Instances also offer a read-only, mapping-like view of their own fields
    through ``module[key]``, iteration and ``keys()``.
    """

    name: str
    function: ModuleFunction
    derived_variables: list[str]
    args: list[str]

    def __post_init__(self) -> None:
        """Reject a function whose arity does not fit the number of args."""
        arity_ok = check_function_arity(function=self.function, arity=len(self.args))
        if arity_ok:
            return
        raise ValueError(f"Function arity does not match args of {self.name}")

    def __getitem__(self, key: str) -> Any:
        """Return the instance attribute stored under ``key``."""
        return vars(self)[key]

    def __iter__(self) -> Generator:
        """Yield the names of the instance attributes."""
        for attribute in vars(self):
            yield attribute

    def keys(self) -> tuple[str, ...]:
        """Get all valid keys of the algebraic module"""
        return tuple(vars(self).keys())
|
81
|
-
|
82
|
-
|
83
|
-
@dataclass
class Rate:
    """A named rate function together with the names of its arguments.

    Instances also offer a read-only, mapping-like view of their own fields
    through ``rate[key]``, iteration and ``keys()``.
    """

    name: str
    function: RateFunction
    args: list[str]

    def __post_init__(self) -> None:
        """Reject a function whose arity does not fit the number of args."""
        arity_ok = check_function_arity(function=self.function, arity=len(self.args))
        if arity_ok:
            return
        raise ValueError(f"Function arity does not match args of {self.name}")

    def __getitem__(self, key: str) -> Any:
        """Return the instance attribute stored under ``key``."""
        return vars(self)[key]

    def __iter__(self) -> Generator:
        """Yield the names of the instance attributes."""
        for attribute in vars(self):
            yield attribute

    def keys(self) -> tuple[str, ...]:
        """Return the names of the instance attributes as a tuple."""
        return tuple(vars(self).keys())
|
101
|
-
|
102
|
-
|
103
|
-
@dataclass
class Reaction:
    """A named rate function with its stoichiometry and argument names."""

    name: str
    function: RateFunction
    # Fixed stoichiometric factors keyed by name.
    stoichiometry: StoichiometryByReaction
    # Names of the values passed to ``function``.
    args: list[str]
    # Factors that are computed from other values; empty unless supplied.
    derived_stoichiometry: dict[str, DerivedStoichiometry] = field(
        default_factory=dict
    )
|
@@ -1,29 +0,0 @@
|
|
1
|
-
from __future__ import annotations
|
2
|
-
|
3
|
-
import logging
|
4
|
-
from dataclasses import dataclass, field
|
5
|
-
from typing import Iterable
|
6
|
-
|
7
|
-
logger = logging.getLogger(__name__)
|
8
|
-
|
9
|
-
|
10
|
-
@dataclass
class NameContainer:
    """Registry that maps every used name to the type of element owning it."""

    names: dict[str, str] = field(default_factory=dict)

    def add(self, name: str, element_type: str) -> None:
        """Register ``name`` as belonging to an element of ``element_type``.

        Raises:
            KeyError: if the name is already registered.
        """
        if (old_type := self.names.get(name)) is not None:
            raise KeyError(
                f"Cannot add {element_type} {name}, as there already exists a {old_type} with that name."
            )

        # Lazy %-formatting: the message is only built if the record is emitted.
        logger.info("Adding name %s", name)
        self.names[name] = element_type

    def remove(self, name: str) -> None:
        """Unregister ``name``.

        Raises:
            KeyError: if the name is not registered.
        """
        logger.info("Removing name %s", name)
        del self.names[name]

    def require_multiple(self, names: Iterable[str]) -> None:
        """Check that every entry of ``names`` is registered.

        Raises:
            KeyError: listing all missing names in sorted order.
        """
        # Sorted so that the error text does not depend on set iteration order.
        missing = sorted(set(names).difference(self.names))
        if missing:
            raise KeyError(f"Names '{', '.join(missing)}' are missing.")
|
@@ -1,115 +0,0 @@
|
|
1
|
-
from __future__ import annotations
|
2
|
-
|
3
|
-
import logging
|
4
|
-
import numpy as np
|
5
|
-
import pandas as pd
|
6
|
-
from ..types import Array
|
7
|
-
from .data import DerivedStoichiometry, RateFunction, Reaction, StoichiometryByVariable
|
8
|
-
from dataclasses import dataclass, field
|
9
|
-
from typing import Optional
|
10
|
-
|
11
|
-
logger = logging.getLogger()
|
12
|
-
|
13
|
-
|
14
|
-
@dataclass
class ReactionContainer:
    """Store reactions and the stoichiometric factors they contribute.

    ``stoichiometries_by_variables`` maps a variable name to a
    ``{reaction name: factor}`` dict and is kept in sync with ``reactions``
    by ``add``, ``remove`` and ``update``.
    """

    reactions: dict[str, Reaction] = field(default_factory=dict)
    stoichiometries_by_variables: dict[str, StoichiometryByVariable] = field(
        default_factory=dict
    )

    def set_stoichiometry(self, variable: str, reaction: str, factor: float) -> None:
        """Store ``factor`` for the given variable / reaction pair."""
        self.stoichiometries_by_variables.setdefault(variable, {})[reaction] = factor

    def update_stoichiometry(self, name: str, reaction: Reaction) -> None:
        """Store all fixed stoichiometric factors of ``reaction`` under ``name``."""
        for variable, factor in reaction.stoichiometry.items():
            self.set_stoichiometry(variable, name, factor)

    def update_derived_stoichiometry(
        self, name: str, reaction: Reaction, constants: dict[str, float]
    ) -> None:
        """Evaluate and store the derived factors of ``reaction`` under ``name``.

        Raises:
            KeyError: if an arg of a derived stoichiometry is not in ``constants``.
        """
        for variable, derived_stoich in reaction.derived_stoichiometry.items():
            factor = derived_stoich.function(*(constants[i] for i in derived_stoich.args))
            self.set_stoichiometry(variable, name, factor)

    def update_derived_stoichiometries(self, constants: dict[str, float]) -> None:
        """Re-evaluate the derived factors of every stored reaction."""
        for name, reaction in self.reactions.items():
            self.update_derived_stoichiometry(name, reaction, constants)

    def add(self, reaction: Reaction, constants: dict[str, float]) -> None:
        """Register ``reaction`` and record its stoichiometric factors.

        Raises:
            KeyError: if a reaction with the same name already exists.
        """
        if (name := reaction.name) in self.reactions:
            raise KeyError(f"Reaction {name} already exists in the model.")
        self.reactions[name] = reaction
        self.update_stoichiometry(name, reaction)
        self.update_derived_stoichiometry(name, reaction, constants)

    def remove(self, name: str) -> Reaction:
        """Remove the reaction ``name`` and all factors it contributed.

        Raises:
            KeyError: if no reaction called ``name`` is stored.
        """
        reaction = self.reactions.pop(name)
        # Derived stoichiometries write into the same table as fixed ones, so
        # both key sets have to be cleared; otherwise stale entries keep
        # pointing at a reaction that no longer exists.
        for variable in (*reaction.stoichiometry, *reaction.derived_stoichiometry):
            by_reaction = self.stoichiometries_by_variables.get(variable)
            if by_reaction is None:
                # Already cleared, e.g. the variable is listed in both dicts.
                continue
            by_reaction.pop(name, None)
            if not by_reaction:
                del self.stoichiometries_by_variables[variable]

        return reaction

    def update(
        self,
        name: str,
        function: Optional[RateFunction],
        stoichiometry: Optional[StoichiometryByVariable],
        derived_stoichiometry: Optional[dict[str, DerivedStoichiometry]],
        args: Optional[list[str]],
        constants: dict[str, float],
    ) -> None:
        """Replace selected parts of the reaction ``name``.

        Arguments passed as ``None`` keep their current value.  The reaction is
        removed and added again so that the stoichiometry table is rebuilt.
        """
        reaction = self.remove(name)
        if function is not None:
            reaction.function = function
        if stoichiometry is not None:
            reaction.stoichiometry = stoichiometry
        if derived_stoichiometry is not None:
            reaction.derived_stoichiometry = derived_stoichiometry
        if args is not None:
            reaction.args = args
        self.add(reaction, constants)

    def get_names(self) -> list[str]:
        """Return the names of all stored reactions."""
        return list(self.reactions)

    def get_stoichiometries(self) -> pd.DataFrame:
        """Return the factors as a variables x reactions DataFrame.

        Pairs without a stored factor are 0.
        """
        variable_names = list(self.stoichiometries_by_variables)
        reaction_names = list(self.reactions)
        row_of = {name: idx for idx, name in enumerate(variable_names)}
        column_of = {name: idx for idx, name in enumerate(reaction_names)}

        data = np.zeros(shape=[len(variable_names), len(reaction_names)])
        for variable, by_reaction in self.stoichiometries_by_variables.items():
            for reaction, factor in by_reaction.items():
                data[row_of[variable], column_of[reaction]] = factor
        return pd.DataFrame(
            data=data,
            index=variable_names,
            columns=reaction_names,
        )

    def get_fluxes_float(self, args: dict[str, float]) -> dict[str, float]:
        """Evaluate every reaction function with the values looked up in ``args``."""
        return {
            name: reaction.function(*(args[arg] for arg in reaction.args))
            for name, reaction in self.reactions.items()
        }

    def get_fluxes_array(self, args: dict[str, Array]) -> dict[str, Array]:
        """Evaluate every reaction function on array-valued ``args``.

        The length of ``args["time"]`` fixes the length of each flux row.
        """
        fluxes = np.full((len(self.reactions), len(args["time"])), np.nan, dtype=float)
        for i, reaction in enumerate(self.reactions.values()):
            fluxes[i, :] = reaction.function(*(args[arg] for arg in reaction.args))
        return dict(zip(self.reactions.keys(), fluxes))

    def get_right_hand_side_float(self, fluxes: dict[str, float]) -> dict[str, float]:
        """Return, per variable, the sum of ``factor * flux`` over its reactions."""
        rhs: dict[str, float] = {}
        for cpd, stoichiometry in self.stoichiometries_by_variables.items():
            for rate, factor in stoichiometry.items():
                rhs[cpd] = rhs.get(cpd, 0) + factor * fluxes[rate]
        return rhs
|
modelbase2/core/utils.py
DELETED
@@ -1,28 +0,0 @@
|
|
1
|
-
from __future__ import annotations
|
2
|
-
|
3
|
-
import inspect
|
4
|
-
from typing import Callable
|
5
|
-
|
6
|
-
|
7
|
-
def check_function_arity(function: Callable, arity: int) -> bool:
    """Check whether ``function`` can be called with ``arity`` arguments.

    Args:
        function: callable to inspect.
        arity: number of arguments that will be supplied.

    Returns:
        True if the argument count fits the signature, False otherwise.
        Functions that take ``*args`` are always accepted.
    """
    argspec = inspect.getfullargspec(function)
    # Give up on *args functions
    if argspec.varargs is not None:
        return True

    n_args = len(argspec.args)
    # ``argspec.defaults`` belong to the trailing entries of ``argspec.args``,
    # so they lower the required count rather than raise the total.
    n_required = n_args - len(argspec.defaults or ())
    if n_required <= arity <= n_args:
        return True

    # Keyword-only parameters are counted on top of the positional ones.
    # NOTE(review): a caller that passes every value positionally cannot fill
    # these; kept because the helper was written to accept them.
    kwonly = argspec.kwonlyargs
    if kwonly and n_args + len(kwonly) == arity:
        return True
    return False
|
@@ -1,24 +0,0 @@
|
|
1
|
-
from __future__ import annotations
|
2
|
-
|
3
|
-
from .data import Variable
|
4
|
-
from dataclasses import dataclass, field
|
5
|
-
from typing import Iterator
|
6
|
-
|
7
|
-
|
8
|
-
@dataclass
class VariableContainer:
    """Store model variables keyed by their name."""

    variables: dict[str, Variable] = field(default_factory=dict)

    def __iter__(self) -> Iterator[str]:
        """Iterate over the names of the stored variables."""
        return iter(self.variables)

    def add(self, variable: Variable) -> None:
        """Register ``variable`` under its name.

        Raises:
            KeyError: if the name is ``"time"`` or is already registered.
        """
        name = variable.name
        if name == "time":
            raise KeyError("'time' is a protected variable for the simulation time")
        if name in self.variables:
            # Report the name, not the repr of the whole object.
            raise KeyError(f"Variable {name} already exists.")
        self.variables[name] = variable

    def remove(self, variable: str) -> None:
        """Delete the variable stored under the name ``variable``.

        Raises:
            KeyError: if no such variable is stored.
        """
        del self.variables[variable]
|
modelbase2/ode/__init__.py
DELETED
modelbase2/ode/integrator.py
DELETED
@@ -1,80 +0,0 @@
|
|
1
|
-
from __future__ import annotations
|
2
|
-
|
3
|
-
import numpy as np
|
4
|
-
from ..types import Array
|
5
|
-
from assimulo.problem import Explicit_Problem # type: ignore
|
6
|
-
from assimulo.solvers import CVode # type: ignore
|
7
|
-
from assimulo.solvers.sundials import CVodeError # type: ignore
|
8
|
-
from dataclasses import dataclass
|
9
|
-
from typing import Callable, Optional, cast
|
10
|
-
|
11
|
-
|
12
|
-
@dataclass
class AssmimuloSettings:
    """Integrator options with their default values.

    ``Assimulo._set_settings`` copies every field onto the CVode integrator
    under the same attribute name.
    """

    atol: float = 1e-8
    rtol: float = 1e-8
    maxnef: int = 4
    maxncf: int = 1
    verbosity: int = 50
|
19
|
-
|
20
|
-
|
21
|
-
@dataclass
class IntegrationResult:
    """Time points and values returned by ``Assimulo.integrate``."""

    time: Array
    values: Array
|
25
|
-
|
26
|
-
|
27
|
-
@dataclass(init=False)
class Assimulo:
    """Wrapper around an assimulo ``CVode`` integrator for an explicit problem."""

    problem: Explicit_Problem
    integrator: CVode
    settings: AssmimuloSettings

    def __init__(
        self,
        rhs: Callable,
        y0: list[float],
        settings: Optional[AssmimuloSettings] = None,
    ) -> None:
        """Build the problem from ``rhs`` and ``y0`` and create the integrator.

        Default settings are used when ``settings`` is ``None``.
        """
        self.problem = Explicit_Problem(rhs, y0)
        self.integrator = CVode(self.problem)
        if settings is None:
            settings = AssmimuloSettings()
        self.settings = settings

    def _set_settings(self) -> None:
        """Copy every field of ``self.settings`` onto the integrator."""
        for option, value in vars(self.settings).items():
            setattr(self.integrator, option, value)

    def update_settings(self, settings: AssmimuloSettings) -> None:
        """Replace the stored settings; they are applied on the next integration."""
        self.settings = settings

    def integrate(
        self,
        t_end: float,
        steps: Optional[int],
        time_points: Optional[list[float]],
    ) -> Optional[IntegrationResult]:
        """Simulate up to ``t_end`` and return the result.

        ``steps=None`` is passed to the integrator as ``0``.  Returns ``None``
        when the integrator raises ``CVodeError``.
        """
        self._set_settings()
        n_steps = 0 if steps is None else steps
        try:
            t, y = self.integrator.simulate(t_end, n_steps, time_points)
        except CVodeError:
            return None
        return IntegrationResult(np.array(t, dtype=float), np.array(y, dtype=float))

    def integrate_to_steady_state(self, tolerance: float) -> Optional[Array]:
        """Integrate until the last two stored states differ by less than ``tolerance``.

        Up to three integrations are run, with the end time multiplied by 1000
        after each unsuccessful one.  Returns the final state, or ``None`` if an
        integration fails or the tolerance is never reached.
        """
        self.reset()
        self._set_settings()
        horizon = 1000
        for _ in range(3):
            result = self.integrate(horizon, None, None)
            if result is None:
                return None
            trajectory = result.values
            last_change = np.linalg.norm(trajectory[-1] - trajectory[-2], ord=2)
            if last_change < tolerance:
                return cast(Array, trajectory[-1])
            horizon *= 1000
        return None

    def reset(self) -> None:
        """Reset the underlying integrator."""
        self.integrator.reset()
|