mxlpy 0.15.0__py3-none-any.whl → 0.17.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -24,6 +24,8 @@ __all__ = [
24
24
 
25
25
  @dataclass
26
26
  class Context:
27
+ """Context for converting a function to sympy expression."""
28
+
27
29
  symbols: dict[str, sympy.Symbol | sympy.Expr]
28
30
  caller: Callable
29
31
  parent_module: ModuleType | None
@@ -34,6 +36,7 @@ class Context:
34
36
  caller: Callable | None = None,
35
37
  parent_module: ModuleType | None = None,
36
38
  ) -> "Context":
39
+ """Update the context with new values."""
37
40
  return Context(
38
41
  symbols=self.symbols if symbols is None else symbols,
39
42
  caller=self.caller if caller is None else caller,
@@ -44,7 +47,26 @@ class Context:
44
47
 
45
48
 
46
49
  def get_fn_source(fn: Callable) -> str:
47
- """Get the string representation of a function."""
50
+ """Get the string representation of a function.
51
+
52
+ Parameters
53
+ ----------
54
+ fn
55
+ The function to extract source from
56
+
57
+ Returns
58
+ -------
59
+ str
60
+ String representation of the function's source code
61
+
62
+ Examples
63
+ --------
64
+ >>> def example_fn(x): return x * 2
65
+ >>> source = get_fn_source(example_fn)
66
+ >>> print(source)
67
+ def example_fn(x): return x * 2
68
+
69
+ """
48
70
  try:
49
71
  return inspect.getsource(fn)
50
72
  except OSError: # could not get source code
@@ -52,7 +74,31 @@ def get_fn_source(fn: Callable) -> str:
52
74
 
53
75
 
54
76
  def get_fn_ast(fn: Callable) -> ast.FunctionDef:
55
- """Get the source code of a function as an AST."""
77
+ """Get the source code of a function as an AST.
78
+
79
+ Parameters
80
+ ----------
81
+ fn
82
+ The function to convert to AST
83
+
84
+ Returns
85
+ -------
86
+ ast.FunctionDef
87
+ Abstract syntax tree representation of the function
88
+
89
+ Raises
90
+ ------
91
+ TypeError
92
+ If the input is not a function
93
+
94
+ Examples
95
+ --------
96
+ >>> def example_fn(x): return x * 2
97
+ >>> ast_tree = get_fn_ast(example_fn)
98
+ >>> isinstance(ast_tree, ast.FunctionDef)
99
+ True
100
+
101
+ """
56
102
  tree = ast.parse(textwrap.dedent(get_fn_source(fn)))
57
103
  if not isinstance(fn_def := tree.body[0], ast.FunctionDef):
58
104
  msg = "Not a function"
@@ -61,6 +107,27 @@ def get_fn_ast(fn: Callable) -> ast.FunctionDef:
61
107
 
62
108
 
63
109
  def sympy_to_inline(expr: sympy.Expr) -> str:
110
+ """Convert a sympy expression to inline Python code.
111
+
112
+ Parameters
113
+ ----------
114
+ expr
115
+ The sympy expression to convert
116
+
117
+ Returns
118
+ -------
119
+ str
120
+ Python code string for the expression
121
+
122
+ Examples
123
+ --------
124
+ >>> import sympy
125
+ >>> x = sympy.Symbol('x')
126
+ >>> expr = x**2 + 2*x + 1
127
+ >>> sympy_to_inline(expr)
128
+ 'x**2 + 2*x + 1'
129
+
130
+ """
64
131
  return cast(str, pycode(expr, fully_qualified_modules=True))
65
132
 
66
133
 
@@ -70,7 +137,32 @@ def sympy_to_fn(
70
137
  args: list[str],
71
138
  expr: sympy.Expr,
72
139
  ) -> str:
73
- """Convert a sympy expression to a python function."""
140
+ """Convert a sympy expression to a python function.
141
+
142
+ Parameters
143
+ ----------
144
+ fn_name
145
+ Name of the function to generate
146
+ args
147
+ List of argument names for the function
148
+ expr
149
+ Sympy expression to convert to a function body
150
+
151
+ Returns
152
+ -------
153
+ str
154
+ String representation of the generated function
155
+
156
+ Examples
157
+ --------
158
+ >>> import sympy
159
+ >>> x, y = sympy.symbols('x y')
160
+ >>> expr = x**2 + y
161
+ >>> print(sympy_to_fn(fn_name="square_plus_y", args=["x", "y"], expr=expr))
162
+ def square_plus_y(x: float, y: float) -> float:
163
+ return x**2 + y
164
+
165
+ """
74
166
  fn_args = ", ".join(f"{i}: float" for i in args)
75
167
 
76
168
  return f"""def {fn_name}({fn_args}) -> float:
@@ -82,7 +174,33 @@ def fn_to_sympy(
82
174
  fn: Callable,
83
175
  model_args: list[sympy.Symbol | sympy.Expr] | None = None,
84
176
  ) -> sympy.Expr:
85
- """Convert a python function to a sympy expression."""
177
+ """Convert a python function to a sympy expression.
178
+
179
+ Parameters
180
+ ----------
181
+ fn
182
+ The function to convert
183
+ model_args
184
+ Optional list of sympy symbols to substitute for function arguments
185
+
186
+ Returns
187
+ -------
188
+ sympy.Expr
189
+ Sympy expression equivalent to the function
190
+
191
+ Examples
192
+ --------
193
+ >>> def square_fn(x):
194
+ ... return x**2
195
+ >>> import sympy
196
+ >>> fn_to_sympy(square_fn)
197
+ x**2
198
+ >>> # With model_args
199
+ >>> y = sympy.Symbol('y')
200
+ >>> fn_to_sympy(square_fn, [y])
201
+ y**2
202
+
203
+ """
86
204
  fn_def = get_fn_ast(fn)
87
205
  fn_args = [str(arg.arg) for arg in fn_def.args.args]
88
206
  sympy_expr = _handle_fn_body(
mxlpy/model.py CHANGED
@@ -18,6 +18,7 @@ import pandas as pd
18
18
 
19
19
  from mxlpy import fns
20
20
  from mxlpy.types import (
21
+ AbstractSurrogate,
21
22
  Array,
22
23
  Derived,
23
24
  Reaction,
@@ -27,6 +28,7 @@ from mxlpy.types import (
27
28
  __all__ = [
28
29
  "ArityMismatchError",
29
30
  "CircularDependencyError",
31
+ "Dependency",
30
32
  "MissingDependenciesError",
31
33
  "Model",
32
34
  "ModelCache",
@@ -36,7 +38,16 @@ if TYPE_CHECKING:
36
38
  from collections.abc import Iterable, Mapping
37
39
  from inspect import FullArgSpec
38
40
 
39
- from mxlpy.types import AbstractSurrogate, Callable, Param, RateFn, RetType
41
+ from mxlpy.types import Callable, Param, RateFn, RetType
42
+
43
+
44
+ @dataclass
45
+ class Dependency:
46
+ """Container class for building dependency tree."""
47
+
48
+ name: str
49
+ required: set[str]
50
+ provided: set[str]
40
51
 
41
52
 
42
53
  class MissingDependenciesError(Exception):
@@ -145,30 +156,33 @@ def _invalidate_cache(method: Callable[Param, RetType]) -> Callable[Param, RetTy
145
156
 
146
157
  def _check_if_is_sortable(
147
158
  available: set[str],
148
- elements: list[tuple[str, set[str]]],
159
+ elements: list[Dependency],
149
160
  ) -> None:
150
161
  all_available = available.copy()
151
- for name, _ in elements:
152
- all_available.add(name)
162
+ for dependency in elements:
163
+ all_available.update(dependency.provided)
153
164
 
154
165
  # Check if it can be sorted in the first place
155
166
  not_solvable = {}
156
- for name, args in elements:
157
- if not args.issubset(all_available):
158
- not_solvable[name] = sorted(args.difference(all_available))
167
+ for dependency in elements:
168
+ if not dependency.required.issubset(all_available):
169
+ not_solvable[dependency.name] = sorted(
170
+ dependency.required.difference(all_available)
171
+ )
159
172
 
160
173
  if not_solvable:
161
174
  raise MissingDependenciesError(not_solvable=not_solvable)
162
175
 
163
176
 
164
177
  def _sort_dependencies(
165
- available: set[str], elements: list[tuple[str, set[str]]]
178
+ available: set[str],
179
+ elements: list[Dependency],
166
180
  ) -> list[str]:
167
181
  """Sort model elements topologically based on their dependencies.
168
182
 
169
183
  Args:
170
184
  available: Set of available component names
171
- elements: List of (name, dependencies) tuples to sort
185
+ elements: List of (name, dependencies, supplier) tuples to sort
172
186
 
173
187
  Returns:
174
188
  List of element names in dependency order
@@ -184,26 +198,27 @@ def _sort_dependencies(
184
198
  order = []
185
199
  # FIXME: what is the worst case here?
186
200
  max_iterations = len(elements) ** 2
187
- queue: SimpleQueue[tuple[str, set[str]]] = SimpleQueue()
188
- for k, v in elements:
189
- queue.put((k, v))
201
+ queue: SimpleQueue[Dependency] = SimpleQueue()
202
+ for dependency in elements:
203
+ queue.put(dependency)
190
204
 
191
205
  last_name = None
192
206
  i = 0
193
207
  while True:
194
208
  try:
195
- new, args = queue.get_nowait()
209
+ dependency = queue.get_nowait()
196
210
  except Empty:
197
211
  break
198
- if args.issubset(available):
199
- available.add(new)
200
- order.append(new)
212
+ if dependency.required.issubset(available):
213
+ available.update(dependency.provided)
214
+ order.append(dependency.name)
215
+
201
216
  else:
202
- if last_name == new:
203
- order.append(new)
217
+ if last_name == dependency.name:
218
+ order.append(last_name)
204
219
  break
205
- queue.put((new, args))
206
- last_name = new
220
+ queue.put(dependency)
221
+ last_name = dependency.name
207
222
  i += 1
208
223
 
209
224
  # Failure case
@@ -211,11 +226,13 @@ def _sort_dependencies(
211
226
  unsorted = []
212
227
  while True:
213
228
  try:
214
- unsorted.append(queue.get_nowait()[0])
229
+ unsorted.append(queue.get_nowait().name)
215
230
  except Empty:
216
231
  break
217
232
 
218
- mod_to_args: dict[str, set[str]] = dict(elements)
233
+ mod_to_args: dict[str, set[str]] = {
234
+ dependency.name: dependency.required for dependency in elements
235
+ }
219
236
  missing = {k: mod_to_args[k].difference(available) for k in unsorted}
220
237
  raise CircularDependencyError(missing=missing)
221
238
  return order
@@ -303,7 +320,12 @@ class Model:
303
320
  to_sort = self._derived | self._reactions | self._surrogates
304
321
  order = _sort_dependencies(
305
322
  available=set(self._parameters) | set(self._variables) | {"time"},
306
- elements=[(k, set(v.args)) for k, v in to_sort.items()],
323
+ elements=[
324
+ Dependency(name=k, required=set(v.args), provided={k})
325
+ if not isinstance(v, AbstractSurrogate)
326
+ else Dependency(name=k, required=set(v.args), provided=set(v.outputs))
327
+ for k, v in to_sort.items()
328
+ ],
307
329
  )
308
330
 
309
331
  # Split derived into parameters and variables
@@ -1227,6 +1249,7 @@ class Model:
1227
1249
  name: str,
1228
1250
  surrogate: AbstractSurrogate,
1229
1251
  args: list[str] | None = None,
1252
+ outputs: list[str] | None = None,
1230
1253
  stoichiometries: dict[str, dict[str, float]] | None = None,
1231
1254
  ) -> Self:
1232
1255
  """Adds a surrogate model to the current instance.
@@ -1237,7 +1260,8 @@ class Model:
1237
1260
  Args:
1238
1261
  name (str): The name of the surrogate model.
1239
1262
  surrogate (AbstractSurrogate): The surrogate model instance to be added.
1240
- args: A list of arguments for the surrogate model.
1263
+ args: Names of the values passed for the surrogate model.
1264
+ outputs: Names of values produced by the surrogate model.
1241
1265
  stoichiometries: A dictionary mapping reaction names to stoichiometries.
1242
1266
 
1243
1267
  Returns:
@@ -1248,6 +1272,8 @@ class Model:
1248
1272
 
1249
1273
  if args is not None:
1250
1274
  surrogate.args = args
1275
+ if outputs is not None:
1276
+ surrogate.outputs = outputs
1251
1277
  if stoichiometries is not None:
1252
1278
  surrogate.stoichiometries = stoichiometries
1253
1279
 
mxlpy/npe/__init__.py ADDED
@@ -0,0 +1,38 @@
1
+ """Neural Process Estimation (NPE) module.
2
+
3
+ This module provides classes and functions for estimating metabolic processes using
4
+ neural networks. It includes functionality for both steady-state and time-course data.
5
+
6
+ Classes:
7
+ TorchSteadyState: Class for steady-state neural network estimation.
8
+ TorchSteadyStateTrainer: Class for training steady-state neural networks.
9
+ TorchTimeCourse: Class for time-course neural network estimation.
10
+ TorchTimeCourseTrainer: Class for training time-course neural networks.
11
+
12
+ Functions:
13
+ train_torch_steady_state: Train a PyTorch steady-state neural network.
14
+ train_torch_time_course: Train a PyTorch time-course neural network.
15
+ """
16
+
17
+ from __future__ import annotations
18
+
19
+ import contextlib
20
+
21
+ with contextlib.suppress(ImportError):
22
+ from ._torch import (
23
+ TorchSteadyState,
24
+ TorchSteadyStateTrainer,
25
+ TorchTimeCourse,
26
+ TorchTimeCourseTrainer,
27
+ train_torch_steady_state,
28
+ train_torch_time_course,
29
+ )
30
+
31
+ __all__ = [
32
+ "TorchSteadyState",
33
+ "TorchSteadyStateTrainer",
34
+ "TorchTimeCourse",
35
+ "TorchTimeCourseTrainer",
36
+ "train_torch_steady_state",
37
+ "train_torch_time_course",
38
+ ]