mxlpy 0.20.0__py3-none-any.whl → 0.22.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,174 @@
1
+ """Module to export models as code."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import logging
6
+ from typing import TYPE_CHECKING
7
+
8
+ from mxlpy.meta.sympy_tools import (
9
+ fn_to_sympy,
10
+ list_of_symbols,
11
+ stoichiometries_to_sympy,
12
+ sympy_to_inline_py,
13
+ sympy_to_inline_rust,
14
+ )
15
+
16
+ if TYPE_CHECKING:
17
+ from collections.abc import Callable
18
+
19
+ import sympy
20
+
21
+ from mxlpy.model import Model
22
+
23
+ __all__ = [
24
+ "generate_model_code_py",
25
+ "generate_model_code_rs",
26
+ ]
27
+
28
+ _LOGGER = logging.getLogger(__name__)
29
+
30
+
31
def _generate_model_code(
    model: Model,
    *,
    sized: bool,
    model_fn: str,
    variables_template: str,
    assignment_template: str,
    sympy_inline_fn: Callable[[sympy.Expr], str],
    return_template: str,
    imports: list[str] | None = None,
    end: str | None = None,
    free_parameters: list[str] | None = None,
) -> str:
    """Render the right-hand side of a model as source code.

    The target language is defined entirely by the templates and the
    ``sympy_inline_fn`` converter, so the same routine backs both the Python
    and the Rust generators.

    Args:
        model: Model whose variables, parameters, derived values and
            reactions are rendered.
        sized: If True, ``model_fn`` is ``str.format``-ed with ``n`` set to the
            number of variables (for fixed-size array signatures).
        model_fn: Function header line.
        variables_template: Line unpacking the state vector; ``{}`` receives
            the comma-separated variable names.
        assignment_template: Assignment line with ``{k}`` (name) and ``{v}``
            (value) placeholders.
        sympy_inline_fn: Converts a sympy expression to target-language code.
        return_template: Return line; ``{}`` receives the comma-separated
            derivative names.
        imports: Lines emitted before the function header.
        end: Line emitted after the return line (e.g. a closing brace).
        free_parameters: Parameters left out of the inlined constants, so
            that ``model_fn`` can take them as function arguments instead.

    Returns:
        The generated source, its lines joined by newlines.

    Raises:
        KeyError: If a free parameter is not a parameter of the model.
        ValueError: If the fn of a derived value or reaction cannot be parsed.

    """
    source: list[str] = []
    # Model components
    variables = model.get_initial_conditions()
    # Work on a copy so that removing free parameters below never mutates the
    # mapping handed out by the model.
    parameters = dict(model.get_parameter_values())

    if imports is not None:
        source.extend(imports)

    source.append(model_fn.format(n=len(variables)) if sized else model_fn)

    if len(variables) > 0:
        source.append(variables_template.format(", ".join(variables)))

    # Parameters: free ones become function arguments, not inlined constants
    for key in free_parameters or ():
        if key not in parameters:
            msg = f"Free parameter '{key}' is not a parameter of the model"
            raise KeyError(msg)
        del parameters[key]
    if len(parameters) > 0:
        source.append(
            "\n".join(
                assignment_template.format(k=k, v=v) for k, v in parameters.items()
            )
        )

    # Derived
    for name, derived in model.get_raw_derived().items():
        expr = fn_to_sympy(
            derived.fn,
            origin=name,
            model_args=list_of_symbols(derived.args),
        )
        if expr is None:
            msg = f"Unable to parse fn for derived value '{name}'"
            raise ValueError(msg)
        source.append(assignment_template.format(k=name, v=sympy_inline_fn(expr)))

    # Reactions (fetched once, reused for the differential equations below)
    reactions = model.get_raw_reactions()
    for name, rxn in reactions.items():
        expr = fn_to_sympy(
            rxn.fn,
            origin=name,
            model_args=list_of_symbols(rxn.args),
        )
        if expr is None:
            msg = f"Unable to parse fn for reaction value '{name}'"
            raise ValueError(msg)
        source.append(assignment_template.format(k=name, v=sympy_inline_fn(expr)))

    # Diff eqs: collect the stoichiometric factor of every reaction per variable
    diff_eqs: dict[str, dict] = {}
    for rxn_name, rxn in reactions.items():
        for var_name, factor in rxn.stoichiometry.items():
            diff_eqs.setdefault(var_name, {})[rxn_name] = factor

    # The derivatives are returned positionally, so they must follow the order
    # in which the variables were unpacked above. Variables that no reaction
    # touches get an explicit zero derivative ("0.0" is a valid float literal
    # in both Python and Rust). Stoichiometry targets that are not variables
    # (if any) are kept, after the variables.
    derivative_order = [*variables, *(k for k in diff_eqs if k not in variables)]
    for variable in derivative_order:
        stoich = diff_eqs.get(variable)
        value = (
            "0.0"
            if stoich is None
            else sympy_inline_fn(
                stoichiometries_to_sympy(origin=variable, stoichs=stoich)
            )
        )
        source.append(assignment_template.format(k=f"d{variable}dt", v=value))

    # Surrogates
    if len(model._surrogates) > 0:  # noqa: SLF001
        msg = "Generating code for Surrogates not yet supported."
        _LOGGER.warning(msg)

    # Return
    ret = ", ".join(f"d{name}dt" for name in derivative_order) or "()"
    source.append(return_template.format(ret))

    if end is not None:
        source.append(end)

    return "\n".join(source)
121
+
122
+
123
def generate_model_code_py(
    model: Model,
    free_parameters: list[str] | None = None,
) -> str:
    """Transform the model into a python function, inlining the function calls.

    The generated function takes ``time`` and the ``variables`` sequence,
    followed by one float argument per free parameter, and returns the
    derivatives assembled by the shared code generator.

    Args:
        model: Model to export.
        free_parameters: Parameters exposed as function arguments instead of
            having their current values inlined as constants.

    Returns:
        Python source code of the generated function, preceded by the
        ``Iterable`` import its annotations need.

    """
    if free_parameters is None:
        model_fn = (
            "def model(time: float, variables: Iterable[float]) -> Iterable[float]:"
        )
    else:
        args = ", ".join(f"{k}: float" for k in free_parameters)
        model_fn = f"def model(time: float, variables: Iterable[float], {args}) -> Iterable[float]:"

    return _generate_model_code(
        model,
        imports=[
            "from collections.abc import Iterable\n",
        ],
        sized=False,
        model_fn=model_fn,
        # List-pattern unpacking: a plain `x = variables` would bind the whole
        # sequence (instead of its single element) for one-variable models,
        # whereas `[x] = variables` / `[x, y] = variables` always unpack.
        variables_template="    [{}] = variables",
        assignment_template="    {k} = {v}",
        sympy_inline_fn=sympy_to_inline_py,
        return_template="    return {}",
        end=None,
        free_parameters=free_parameters,
    )
150
+
151
+
152
def generate_model_code_rs(
    model: Model,
    free_parameters: list[str] | None = None,
) -> str:
    """Export the model as a Rust function with all function calls inlined.

    The function signature uses fixed-size ``f64`` arrays whose length is left
    as an ``{n}`` placeholder; the generator is asked to fill it in by passing
    ``sized=True``.

    Args:
        model: Model to export.
        free_parameters: Parameters turned into ``f64`` arguments of the
            generated function instead of inlined constants.

    Returns:
        Rust source code of the generated ``model`` function.

    """
    # Extra f64 arguments for the free parameters (empty when there are none).
    extra_args = (
        ""
        if free_parameters is None
        else ", " + ", ".join(f"{name}: f64" for name in free_parameters)
    )
    # `{n}` is substituted later via str.format; `{{` survives it as `{`.
    signature = (
        "fn model(time: f64, variables: &[f64; {n}]"
        + extra_args
        + ") -> [f64; {n}] {{"
    )
    return _generate_model_code(
        model,
        sized=True,
        model_fn=signature,
        variables_template="    let [{}] = *variables;",
        assignment_template="    let {k}: f64 = {v};",
        sympy_inline_fn=sympy_to_inline_rust,
        return_template="    return [{}]",
        imports=None,
        end="}",
        free_parameters=free_parameters,
    )
@@ -2,43 +2,44 @@
2
2
 
3
3
  from __future__ import annotations
4
4
 
5
- import warnings
5
+ import logging
6
6
  from typing import TYPE_CHECKING
7
7
 
8
- import sympy
9
-
10
- from mxlpy.meta.source_tools import fn_to_sympy, sympy_to_fn
8
+ from mxlpy.meta.sympy_tools import fn_to_sympy, list_of_symbols, sympy_to_python_fn
11
9
  from mxlpy.types import Derived
12
10
 
13
11
  if TYPE_CHECKING:
12
+ import sympy
13
+
14
14
  from mxlpy.model import Model
15
15
 
16
16
  __all__ = [
17
17
  "generate_mxlpy_code",
18
18
  ]
19
19
 
20
-
21
- def _list_of_symbols(args: list[str]) -> list[sympy.Symbol | sympy.Expr]:
22
- return [sympy.Symbol(arg) for arg in args]
20
# Module-scoped logger (not the root logger) so callers can filter or
# configure this module's messages independently.
_LOGGER = logging.getLogger(__name__)
23
21
 
24
22
 
25
23
  def generate_mxlpy_code(model: Model) -> str:
26
24
  """Generate a mxlpy model from a model."""
27
- functions = {}
25
+ functions: dict[str, tuple[sympy.Expr, list[str]]] = {}
28
26
 
29
27
  # Variables and parameters
30
- variables = model.variables
31
- parameters = model.parameters
28
+ variables = model.get_raw_variables()
29
+ parameters = model.get_parameter_values()
32
30
 
33
31
  # Derived
34
32
  derived_source = []
35
- for k, der in model.derived.items():
33
+ for k, der in model.get_raw_derived().items():
36
34
  fn = der.fn
37
35
  fn_name = fn.__name__
38
- functions[fn_name] = (
39
- fn_to_sympy(fn, model_args=_list_of_symbols(der.args)),
40
- der.args,
41
- )
36
+ if (
37
+ expr := fn_to_sympy(fn, origin=k, model_args=list_of_symbols(der.args))
38
+ ) is None:
39
+ msg = f"Unable to parse fn for derived value '{k}'"
40
+ raise ValueError(msg)
41
+
42
+ functions[fn_name] = (expr, der.args)
42
43
 
43
44
  derived_source.append(
44
45
  f""" .add_derived(
@@ -50,20 +51,27 @@ def generate_mxlpy_code(model: Model) -> str:
50
51
 
51
52
  # Reactions
52
53
  reactions_source = []
53
- for k, rxn in model.reactions.items():
54
+ for k, rxn in model.get_raw_reactions().items():
54
55
  fn = rxn.fn
55
56
  fn_name = fn.__name__
56
- functions[fn_name] = (
57
- fn_to_sympy(fn, model_args=_list_of_symbols(rxn.args)),
58
- rxn.args,
59
- )
57
+ if (
58
+ expr := fn_to_sympy(fn, origin=k, model_args=list_of_symbols(rxn.args))
59
+ ) is None:
60
+ msg = f"Unable to parse fn for reaction '{k}'"
61
+ raise ValueError(msg)
62
+
63
+ functions[fn_name] = (expr, rxn.args)
60
64
  stoichiometry: list[str] = []
61
65
  for var, stoich in rxn.stoichiometry.items():
62
66
  if isinstance(stoich, Derived):
63
- functions[fn_name] = (
64
- fn_to_sympy(fn, model_args=_list_of_symbols(stoich.args)),
65
- rxn.args,
66
- )
67
+ if (
68
+ expr := fn_to_sympy(
69
+ fn, origin=var, model_args=list_of_symbols(stoich.args)
70
+ )
71
+ ) is None:
72
+ msg = f"Unable to parse fn for stoichiometry '{var}'"
73
+ raise ValueError(msg)
74
+ functions[fn_name] = (expr, rxn.args)
67
75
  args = ", ".join(f'"{k}"' for k in stoich.args)
68
76
  stoich = ( # noqa: PLW2901
69
77
  f"""Derived(fn={fn.__name__}, args=[{args}])"""
@@ -81,14 +89,12 @@ def generate_mxlpy_code(model: Model) -> str:
81
89
 
82
90
  # Surrogates
83
91
  if len(model._surrogates) > 0: # noqa: SLF001
84
- warnings.warn(
85
- "Generating code for Surrogates not yet supported.",
86
- stacklevel=1,
87
- )
92
+ msg = "Generating code for Surrogates not yet supported."
93
+ _LOGGER.warning(msg)
88
94
 
89
95
  # Combine all the sources
90
96
  functions_source = "\n\n".join(
91
- sympy_to_fn(fn_name=name, args=args, expr=expr)
97
+ sympy_to_python_fn(fn_name=name, args=args, expr=expr)
92
98
  for name, (expr, args) in functions.items()
93
99
  )
94
100
  source = [