modelbase2 0.3.0__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,286 @@
+ # ruff: noqa: D100, D101, D102, D103, D104, D105, D106, D107, D200, D203, D400, D401
+
+ import ast
+ import inspect
+ import textwrap
+ from collections.abc import Callable
+ from dataclasses import dataclass
+ from typing import Any, cast
+
+ import sympy
+
+ from modelbase2.model import Model
+
+ __all__ = ["Context", "SymbolicModel", "model_fn_to_sympy", "to_symbolic_model"]
+
+
+ @dataclass
+ class Context:
+     symbols: dict[str, sympy.Symbol | sympy.Expr]
+     caller: Callable
+
+
+ @dataclass
+ class SymbolicModel:
+     variables: dict[str, sympy.Symbol]
+     parameters: dict[str, sympy.Symbol]
+     eqs: list[sympy.Expr]
+
+
+ def to_symbolic_model(model: Model) -> SymbolicModel:
+     cache = model._create_cache()  # noqa: SLF001
+
+     variables = dict(
+         zip(model.variables, sympy.symbols(list(model.variables)), strict=True)
+     )
+     parameters = dict(
+         zip(model.parameters, sympy.symbols(list(model.parameters)), strict=True)
+     )
+     symbols = variables | parameters
+
+     for k, v in model.derived.items():
+         symbols[k] = model_fn_to_sympy(v.fn, [symbols[i] for i in v.args])
+
+     rxns = {
+         k: model_fn_to_sympy(v.fn, [symbols[i] for i in v.args])
+         for k, v in model.reactions.items()
+     }
+
+     eqs: dict[str, sympy.Expr] = {}
+     for cpd, stoich in cache.stoich_by_cpds.items():
+         for rxn, stoich_value in stoich.items():
+             eqs[cpd] = (
+                 eqs.get(cpd, sympy.Float(0.0)) + sympy.Float(stoich_value) * rxns[rxn]  # type: ignore
+             )
+
+     for cpd, dstoich in cache.dyn_stoich_by_cpds.items():
+         for rxn, der in dstoich.items():
+             eqs[cpd] = eqs.get(cpd, sympy.Float(0.0)) + model_fn_to_sympy(
+                 der.fn,
+                 [symbols[i] for i in der.args],  # type: ignore
+             ) * rxns[rxn]  # type: ignore
+
+     return SymbolicModel(
+         variables=variables,
+         parameters=parameters,
+         eqs=[eqs[i] for i in cache.var_names],
+     )
+
+
+ def model_fn_to_sympy(
+     fn: Callable, model_args: list[sympy.Symbol | sympy.Expr] | None = None
+ ) -> sympy.Expr:
+     source = textwrap.dedent(inspect.getsource(fn))
+
+     if not isinstance(fn_def := ast.parse(source).body[0], ast.FunctionDef):
+         msg = "Expected a function definition"
+         raise TypeError(msg)
+
+     fn_args = [str(arg.arg) for arg in fn_def.args.args]
+
+     sympy_expr = _handle_fn_body(
+         fn_def.body,
+         ctx=Context(
+             symbols={name: sympy.Symbol(name) for name in fn_args},
+             caller=fn,
+         ),
+     )
+
+     if model_args is not None:
+         sympy_expr = sympy_expr.subs(dict(zip(fn_args, model_args, strict=True)))
+
+     return cast(sympy.Expr, sympy_expr)
+
+
+ def _handle_fn_body(body: list[ast.stmt], ctx: Context) -> sympy.Expr:
+     pieces = []
+     remaining_body = list(body)
+
+     while remaining_body:
+         node = remaining_body.pop(0)
+
+         if isinstance(node, ast.If):
+             condition = _handle_expr(node.test, ctx)
+             if_expr = _handle_fn_body(node.body, ctx)
+             pieces.append((if_expr, condition))
+
+             # If there's an else clause
+             if node.orelse:
+                 # Check if it's an elif (an If node in orelse)
+                 if len(node.orelse) == 1 and isinstance(node.orelse[0], ast.If):
+                     # Push the elif back to the beginning of remaining_body to process next
+                     remaining_body.insert(0, node.orelse[0])
+                 else:
+                     # It's a regular else
+                     else_expr = _handle_fn_body(node.orelse, ctx)  # FIXME: copy here
+                     pieces.append((else_expr, True))
+                     break  # We're done with this chain
+
+             elif not remaining_body and any(
+                 isinstance(n, ast.Return) for n in body[body.index(node) + 1 :]
+             ):
+                 else_expr = _handle_fn_body(
+                     body[body.index(node) + 1 :], ctx
+                 )  # FIXME: copy here
+                 pieces.append((else_expr, True))
+
+         elif isinstance(node, ast.Return):
+             if (value := node.value) is None:
+                 msg = "Return value cannot be None"
+                 raise ValueError(msg)
+
+             expr = _handle_expr(value, ctx)
+             if not pieces:
+                 return expr
+             pieces.append((expr, True))
+             break
+
+         elif isinstance(node, ast.Assign):
+             # Handle tuple assignments like c, d = a, b
+             if isinstance(node.targets[0], ast.Tuple):
+                 # Handle tuple unpacking
+                 target_elements = node.targets[0].elts
+
+                 if isinstance(node.value, ast.Tuple):
+                     # Direct unpacking like c, d = a, b
+                     value_elements = node.value.elts
+                     for target, value_expr in zip(
+                         target_elements, value_elements, strict=True
+                     ):
+                         if isinstance(target, ast.Name):
+                             ctx.symbols[target.id] = _handle_expr(value_expr, ctx)
+                 else:
+                     # Handle potential iterable unpacking
+                     value = _handle_expr(node.value, ctx)
+             else:
+                 # Regular single assignment
+                 if not isinstance(target := node.targets[0], ast.Name):
+                     msg = "Only single variable assignments are supported"
+                     raise TypeError(msg)
+                 target_name = target.id
+                 value = _handle_expr(node.value, ctx)
+                 ctx.symbols[target_name] = value
+
+     # If we have pieces to combine into a Piecewise
+     if pieces:
+         return sympy.Piecewise(*pieces)
+
+     # If no return was found but we have assignments, return the last assigned variable
+     for node in reversed(body):
+         if isinstance(node, ast.Assign) and isinstance(node.targets[0], ast.Name):
+             target_name = node.targets[0].id
+             return ctx.symbols[target_name]
+
+     msg = "No return value found in function body"
+     raise ValueError(msg)
+
+
+ def _handle_unaryop(node: ast.UnaryOp, ctx: Context) -> sympy.Expr:
+     left = _handle_expr(node.operand, ctx)
+     left = cast(Any, left)  # stupid sympy types don't allow ops on symbols
+
+     if isinstance(node.op, ast.UAdd):
+         return +left
+     if isinstance(node.op, ast.USub):
+         return -left
+
+     msg = f"Operation {type(node.op).__name__} not implemented"
+     raise NotImplementedError(msg)
+
+
+ def _handle_binop(node: ast.BinOp, ctx: Context) -> sympy.Expr:
+     left = _handle_expr(node.left, ctx)
+     left = cast(Any, left)  # stupid sympy types don't allow ops on symbols
+
+     right = _handle_expr(node.right, ctx)
+     right = cast(Any, right)  # stupid sympy types don't allow ops on symbols
+
+     if isinstance(node.op, ast.Add):
+         return left + right
+     if isinstance(node.op, ast.Sub):
+         return left - right
+     if isinstance(node.op, ast.Mult):
+         return left * right
+     if isinstance(node.op, ast.Div):
+         return left / right
+     if isinstance(node.op, ast.Pow):
+         return left**right
+     if isinstance(node.op, ast.Mod):
+         return left % right
+     if isinstance(node.op, ast.FloorDiv):
+         return left // right
+
+     msg = f"Operation {type(node.op).__name__} not implemented"
+     raise NotImplementedError(msg)
+
+
+ def _handle_call(node: ast.Call, ctx: Context) -> sympy.Expr:
+     if not isinstance(callee := node.func, ast.Name):
+         msg = "Only function calls with names are supported"
+         raise TypeError(msg)
+
+     fn_name = str(callee.id)
+     parent_module = inspect.getmodule(ctx.caller)
+     fns = dict(inspect.getmembers(parent_module, predicate=callable))
+
+     return model_fn_to_sympy(
+         fns[fn_name],
+         model_args=[_handle_expr(i, ctx) for i in node.args],
+     )
+
+
+ def _handle_name(node: ast.Name, ctx: Context) -> sympy.Symbol | sympy.Expr:
+     return ctx.symbols[node.id]
+
+
+ def _handle_expr(node: ast.expr, ctx: Context) -> sympy.Expr:
+     if isinstance(node, ast.UnaryOp):
+         return _handle_unaryop(node, ctx)
+     if isinstance(node, ast.BinOp):
+         return _handle_binop(node, ctx)
+     if isinstance(node, ast.Name):
+         return _handle_name(node, ctx)
+     if isinstance(node, ast.Constant):
+         return node.value
+     if isinstance(node, ast.Compare):
+         # Handle chained comparisons like 1 < a < 2
+         left = cast(Any, _handle_expr(node.left, ctx))
+         comparisons = []
+
+         # Build all individual comparisons from the chain
+         prev_value = left
+         for op, comparator in zip(node.ops, node.comparators, strict=True):
+             right = cast(Any, _handle_expr(comparator, ctx))
+
+             if isinstance(op, ast.Gt):
+                 comparisons.append(prev_value > right)
+             elif isinstance(op, ast.GtE):
+                 comparisons.append(prev_value >= right)
+             elif isinstance(op, ast.Lt):
+                 comparisons.append(prev_value < right)
+             elif isinstance(op, ast.LtE):
+                 comparisons.append(prev_value <= right)
+             elif isinstance(op, ast.Eq):
+                 comparisons.append(prev_value == right)
+             elif isinstance(op, ast.NotEq):
+                 comparisons.append(prev_value != right)
+
+             prev_value = right
+
+         # Combine all comparisons with logical AND
+         result = comparisons[0]
+         for comp in comparisons[1:]:
+             result = sympy.And(result, comp)
+         return cast(sympy.Expr, result)
+     if isinstance(node, ast.Call):
+         return _handle_call(node, ctx)
+
+     # Handle conditional expressions (ternary operators)
+     if isinstance(node, ast.IfExp):
+         condition = _handle_expr(node.test, ctx)
+         if_true = _handle_expr(node.body, ctx)
+         if_false = _handle_expr(node.orelse, ctx)
+         return sympy.Piecewise((if_true, condition), (if_false, True))
+
+     msg = f"Expression type {type(node).__name__} not implemented"
+     raise NotImplementedError(msg)
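
Aside: the file added above (this hunk does not name its path) turns plain Python rate functions into sympy expressions by walking their AST. A minimal sketch of the round trip; michaelis_menten is a made-up rate law, and the import of model_fn_to_sympy is elided because the diff does not show the new module's location:

    import sympy

    def michaelis_menten(s: float, vmax: float, km: float) -> float:
        return vmax * s / (km + s)

    # parsed from source via inspect/ast, so the function must live in a file
    expr = model_fn_to_sympy(michaelis_menten)  # -> vmax*s/(km + s)

    # with model_args, the function's own argument names are substituted
    expr = model_fn_to_sympy(
        michaelis_menten,
        model_args=[sympy.Symbol("PYR"), sympy.Symbol("vmax_pdh"), sympy.Symbol("km_pdh")],
    )  # -> vmax_pdh*PYR/(km_pdh + PYR)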
@@ -508,14 +508,10 @@ def get_model_tex_diff(
      gls = default_init(gls)
      section_label = "sec:model-diff"
 
-     return f"""{' start autogenerated ':%^60}
+     return f"""{" start autogenerated ":%^60}
  {_clearpage()}
- {_subsubsection('Model changes')}{_label(section_label)}
- {(
-     (_to_tex_export(m1) - _to_tex_export(m2))
-     .rename_with_glossary(gls)
-     .export_all()
- )}
+ {_subsubsection("Model changes")}{_label(section_label)}
+ {((_to_tex_export(m1) - _to_tex_export(m2)).rename_with_glossary(gls).export_all())}
  {_clearpage()}
- {' end autogenerated ':%^60}
+ {" end autogenerated ":%^60}
  """
modelbase2/fit.py CHANGED
@@ -50,7 +50,7 @@ class SteadyStateResidualFn(Protocol):
          data: pd.Series,
          model: Model,
          y0: dict[str, float],
-         integrator: type[IntegratorProtocol],
+         integrator: Callable[[Callable, ArrayLike], IntegratorProtocol],
      ) -> float:
          """Calculate residual error between model steady state and experimental data."""
          ...
@@ -67,7 +67,7 @@ class TimeSeriesResidualFn(Protocol):
          data: pd.DataFrame,
          model: Model,
          y0: dict[str, float],
-         integrator: type[IntegratorProtocol],
+         integrator: Callable[[Callable, ArrayLike], IntegratorProtocol],
      ) -> float:
          """Calculate residual error between model time course and experimental data."""
          ...
@@ -101,7 +101,7 @@ def _steady_state_residual(
      data: pd.Series,
      model: Model,
      y0: dict[str, float] | None,
-     integrator: type[IntegratorProtocol],
+     integrator: Callable[[Callable, ArrayLike], IntegratorProtocol],
  ) -> float:
      """Calculate residual error between model steady state and experimental data.
 
@@ -148,7 +148,7 @@ def _time_course_residual(
      data: pd.DataFrame,
      model: Model,
      y0: dict[str, float],
-     integrator: type[IntegratorProtocol],
+     integrator: Callable[[Callable, ArrayLike], IntegratorProtocol],
  ) -> float:
      """Calculate residual error between model time course and experimental data.
 
@@ -187,7 +187,7 @@ def steady_state(
      y0: dict[str, float] | None = None,
      minimize_fn: MinimizeFn = _default_minimize_fn,
      residual_fn: SteadyStateResidualFn = _steady_state_residual,
-     integrator: type[IntegratorProtocol] = DefaultIntegrator,
+     integrator: Callable[[Callable, ArrayLike], IntegratorProtocol] = DefaultIntegrator,
  ) -> dict[str, float]:
      """Fit model parameters to steady-state experimental data.
 
@@ -241,7 +241,7 @@ def time_course(
      y0: dict[str, float] | None = None,
      minimize_fn: MinimizeFn = _default_minimize_fn,
      residual_fn: TimeSeriesResidualFn = _time_course_residual,
-     integrator: type[IntegratorProtocol] = DefaultIntegrator,
+     integrator: Callable[[Callable, ArrayLike], IntegratorProtocol] = DefaultIntegrator,
  ) -> dict[str, float]:
      """Fit model parameters to time course of experimental data.
 
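
Aside on the recurring signature change in this file: `integrator` is widened from a concrete class (`type[IntegratorProtocol]`) to any factory `Callable[[Callable, ArrayLike], IntegratorProtocol]` that builds an integrator from a right-hand-side function and initial values. A sketch of a custom factory; `model`, `p0`, and `data` here are placeholders for the fitting inputs:

    from modelbase2 import fit
    from modelbase2.integrators import DefaultIntegrator

    def my_integrator_factory(rhs, y0):
        # invoked positionally as integrator(rhs, y0); customise construction here
        return DefaultIntegrator(rhs, y0)

    fit.steady_state(model, p0=p0, data=data, integrator=my_integrator_factory)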
modelbase2/model.py CHANGED
@@ -10,7 +10,6 @@ from __future__ import annotations
  import copy
  import inspect
  import itertools as it
- import math
  from dataclasses import dataclass, field
  from typing import TYPE_CHECKING, Self, cast
 
@@ -26,7 +25,13 @@ from modelbase2.types import (
      Readout,
  )
 
- __all__ = ["ArityMismatchError", "Model", "ModelCache", "SortError"]
+ __all__ = [
+     "ArityMismatchError",
+     "CircularDependencyError",
+     "MissingDependenciesError",
+     "Model",
+     "ModelCache",
+ ]
 
  if TYPE_CHECKING:
      from collections.abc import Iterable, Mapping
@@ -35,19 +40,38 @@ if TYPE_CHECKING:
      from modelbase2.types import AbstractSurrogate, Callable, Param, RateFn, RetType
 
 
- class SortError(Exception):
+ class MissingDependenciesError(Exception):
      """Raised when dependencies cannot be sorted topologically.
 
      This typically indicates circular dependencies in model components.
      """
 
-     def __init__(self, unsorted: list[str], order: list[str]) -> None:
+     def __init__(self, not_solvable: dict[str, list[str]]) -> None:
          """Initialise exception."""
+         missing_by_module = "\n".join(f"\t{k}: {v}" for k, v in not_solvable.items())
          msg = (
-             f"Exceeded max iterations on sorting derived. "
-             "Check if there are circular references.\n"
-             f"Unsorted: {unsorted}\n"
-             f"Order: {order}"
+             f"Dependencies cannot be solved. Missing dependencies:\n{missing_by_module}"
+         )
+         super().__init__(msg)
+
+
+ class CircularDependencyError(Exception):
+     """Raised when dependencies cannot be sorted topologically.
+
+     This typically indicates circular dependencies in model components.
+     """
+
+     def __init__(
+         self,
+         missing: dict[str, set[str]],
+     ) -> None:
+         """Initialise exception."""
+         missing_by_module = "\n".join(f"\t{k}: {v}" for k, v in missing.items())
+         msg = (
+             f"Exceeded max iterations on sorting dependencies.\n"
+             "Check if there are circular references. "
+             "Missing dependencies:\n"
+             f"{missing_by_module}"
          )
          super().__init__(msg)
 
@@ -120,6 +144,24 @@ def _invalidate_cache(method: Callable[Param, RetType]) -> Callable[Param, RetTy
      return wrapper  # type: ignore
 
 
+ def _check_if_is_sortable(
+     available: set[str],
+     elements: list[tuple[str, set[str]]],
+ ) -> None:
+     all_available = available.copy()
+     for name, _ in elements:
+         all_available.add(name)
+
+     # Check if it can be sorted in the first place
+     not_solvable = {}
+     for name, args in elements:
+         if not args.issubset(all_available):
+             not_solvable[name] = sorted(args.difference(all_available))
+
+     if not_solvable:
+         raise MissingDependenciesError(not_solvable=not_solvable)
+
+
  def _sort_dependencies(
      available: set[str], elements: list[tuple[str, set[str]]]
  ) -> list[str]:
@@ -138,6 +180,8 @@ def _sort_dependencies(
      """
      from queue import Empty, SimpleQueue
 
+     _check_if_is_sortable(available, elements)
+
      order = []
      # FIXME: what is the worst case here?
      max_iterations = len(elements) ** 2
@@ -171,7 +215,10 @@
                  unsorted.append(queue.get_nowait()[0])
              except Empty:
                  break
-         raise SortError(unsorted=unsorted, order=order)
+
+         mod_to_args: dict[str, set[str]] = dict(elements)
+         missing = {k: mod_to_args[k].difference(available) for k in unsorted}
+         raise CircularDependencyError(missing=missing)
      return order
 
 
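Aside: the former catch-all SortError is split into two exceptions. MissingDependenciesError is raised up front by _check_if_is_sortable when a required name exists nowhere, while CircularDependencyError is raised when every name exists but no topological order can be reached. A sketch of both failure modes against the private helper (_sort_dependencies is internal API, so this is illustration only):

    from modelbase2.model import _sort_dependencies

    # "a" needs "missing", which neither `available` nor the elements provide
    _sort_dependencies(available={"x"}, elements=[("a", {"x", "missing"})])
    # -> MissingDependenciesError, listing a: ['missing']

    # "a" and "b" require each other, so sorting cannot converge
    _sort_dependencies(available={"x"}, elements=[("a", {"b"}), ("b", {"a"})])
    # -> CircularDependencyError, listing a: {'b'} and b: {'a'}
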
modelbase2/npe.py CHANGED
@@ -23,7 +23,7 @@ __all__ = [
  from abc import abstractmethod
  from dataclasses import dataclass
  from pathlib import Path
- from typing import cast
+ from typing import TYPE_CHECKING, cast
 
  import numpy as np
  import pandas as pd
@@ -35,6 +35,11 @@ from torch.optim.adam import Adam
  from modelbase2.nnarchitectures import MLP, DefaultDevice, LSTMnn
  from modelbase2.parallel import Cache
 
+ if TYPE_CHECKING:
+     from collections.abc import Callable
+
+     from torch.optim.optimizer import ParamsT
+
  DefaultCache = Cache(Path(".cache"))
 
 
@@ -140,7 +145,7 @@ def train_torch_ss_estimator(
      epochs: int,
      batch_size: int | None = None,
      approximator: nn.Module | None = None,
-     optimimzer_cls: type[Adam] = Adam,
+     optimimzer_cls: Callable[[ParamsT], Adam] = Adam,
      device: torch.device = DefaultDevice,
  ) -> tuple[TorchSSEstimator, pd.Series]:
      """Train a PyTorch steady state estimator.
@@ -206,7 +211,7 @@ def train_torch_time_course_estimator(
      epochs: int,
      batch_size: int | None = None,
      approximator: nn.Module | None = None,
-     optimimzer_cls: type[Adam] = Adam,
+     optimimzer_cls: Callable[[ParamsT], Adam] = Adam,
      device: torch.device = DefaultDevice,
  ) -> tuple[TorchTimeCourseEstimator, pd.Series]:
      """Train a PyTorch time course estimator.
modelbase2/plot.py CHANGED
@@ -818,7 +818,7 @@ def relative_label_distribution(
              isos = mapper.get_isotopomers_of_at_position(name, i)
              labels = cast(pd.DataFrame, concs.loc[:, isos])
              total = concs.loc[:, f"{name}__total"]
-             ax.plot(labels.index, (labels.sum(axis=1) / total), label=f"C{i+1}")
+             ax.plot(labels.index, (labels.sum(axis=1) / total), label=f"C{i + 1}")
              ax.set_title(name)
              ax.legend()
          else:
@@ -827,6 +827,6 @@
          ):
              ax.plot(concs.index, concs.loc[:, isos])
              ax.set_title(name)
-             ax.legend([f"C{i+1}" for i in range(len(isos))])
+             ax.legend([f"C{i + 1}" for i in range(len(isos))])
 
      return fig, axs
@@ -507,7 +507,11 @@ def _codgen(name: str, sbml: Parser) -> Path:
 
      # Initial assignments
      initial_assignment_order = _sort_dependencies(
-         available=set(sbml.initial_assignment) ^ set(parameters) ^ set(variables),
+         available=set(sbml.initial_assignment)
+         ^ set(parameters)
+         ^ set(variables)
+         ^ set(sbml.derived)
+         | {"time"},
          elements=[(k, set(v.args)) for k, v in sbml.initial_assignment.items()],
      )
 
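Aside on reading the widened `available` set: `^` (symmetric difference) binds more tightly than `|` (union) in Python, so the chained `^` terms are combined first and `{"time"}` is unioned onto that result:

    a, b = {"p1", "t0"}, {"p1", "v1"}
    a ^ b             # {'t0', 'v1'}
    a ^ b | {"time"}  # {'t0', 'v1', 'time'}, i.e. (a ^ b) | {"time"}
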
modelbase2/simulator.py CHANGED
@@ -21,6 +21,8 @@ from modelbase2.integrators import DefaultIntegrator
  __all__ = ["Simulator"]
 
  if TYPE_CHECKING:
+     from collections.abc import Callable
+
      from modelbase2.model import Model
      from modelbase2.types import ArrayLike, IntegratorProtocol
 
@@ -83,7 +85,9 @@ class Simulator:
          self,
          model: Model,
          y0: dict[str, float] | None = None,
-         integrator: type[IntegratorProtocol] = DefaultIntegrator,
+         integrator: Callable[
+             [Callable, ArrayLike], IntegratorProtocol
+         ] = DefaultIntegrator,
          *,
          test_run: bool = True,
      ) -> None:
@@ -93,7 +97,7 @@ class Simulator:
              model (Model): The model to be simulated.
              y0 (dict[str, float] | None, optional): Initial conditions for the model variables.
                  If None, the initial conditions are obtained from the model. Defaults to None.
-             integrator (type[IntegratorProtocol], optional): The integrator to use for the simulation.
+             integrator (Callable[[Callable, ArrayLike], IntegratorProtocol], optional): The integrator to use for the simulation.
                  Defaults to DefaultIntegrator.
              test_run (bool, optional): If True, performs a test run to ensure the model's methods
                  (get_full_concs, get_fluxes, get_right_hand_side) work correctly with the initial conditions.
@@ -104,7 +108,7 @@ class Simulator:
          y0 = model.get_initial_conditions() if y0 is None else y0
          self.y0 = [y0[k] for k in model.get_variable_names()]
 
-         self.integrator = integrator(self.model, y0=self.y0)
+         self.integrator = integrator(self.model, self.y0)
          self.concs = None
          self.args = None
          self.simulation_parameters = None
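
Aside: besides the widened annotation, the constructor call changes from `integrator(self.model, y0=self.y0)` to the positional `integrator(self.model, self.y0)`, which is what lets factories that do not accept a `y0` keyword slot in. A sketch, with import paths assumed from the package layout:

    from modelbase2 import Simulator
    from modelbase2.integrators import DefaultIntegrator

    # any factory of shape (rhs, y0) -> IntegratorProtocol now works
    s = Simulator(model, integrator=lambda rhs, y0: DefaultIntegrator(rhs, y0))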
@@ -32,7 +32,9 @@ class PolySurrogate(AbstractSurrogate):
  def train_polynomial_surrogate(
      feature: ArrayLike,
      target: ArrayLike,
-     series: Literal["Power", "Chebyshev", "Legendre", "Laguerre", "Hermite", "HermiteE"] = "Power",
+     series: Literal[
+         "Power", "Chebyshev", "Legendre", "Laguerre", "Hermite", "HermiteE"
+     ] = "Power",
      degrees: Iterable[int] = (1, 2, 3, 4, 5, 6, 7),
      surrogate_args: list[str] | None = None,
      surrogate_stoichiometries: dict[str, dict[str, float]] | None = None,
@@ -1,3 +1,4 @@
+ from collections.abc import Callable
  from dataclasses import dataclass
 
  import numpy as np
@@ -6,9 +7,10 @@ import torch
  import tqdm
  from torch import nn
  from torch.optim.adam import Adam
+ from torch.optim.optimizer import ParamsT
 
- from modelbase2.types import AbstractSurrogate
  from modelbase2.nnarchitectures import MLP, DefaultDevice
+ from modelbase2.types import AbstractSurrogate
 
  __all__ = ["TorchSurrogate", "train_torch_surrogate"]
 
@@ -124,7 +126,7 @@ def train_torch_surrogate(
      surrogate_stoichiometries: dict[str, dict[str, float]] | None = None,
      batch_size: int | None = None,
      approximator: nn.Module | None = None,
-     optimimzer_cls: type[Adam] = Adam,
+     optimimzer_cls: Callable[[ParamsT], Adam] = Adam,
      device: torch.device = DefaultDevice,
  ) -> tuple[TorchSurrogate, pd.Series]:
      """Train a PyTorch surrogate model.
modelbase2/surrogates.py CHANGED
@@ -19,6 +19,7 @@ from __future__ import annotations
  from abc import abstractmethod
  from dataclasses import dataclass
  from pathlib import Path
+ from typing import TYPE_CHECKING
 
  import numpy as np
  import pandas as pd
@@ -29,6 +30,11 @@ from torch.optim.adam import Adam
 
  from modelbase2.parallel import Cache
 
+ if TYPE_CHECKING:
+     from collections.abc import Callable
+
+     from torch.optim.optimizer import ParamsT
+
  __all__ = [
      "AbstractSurrogate",
      "Approximator",
@@ -251,7 +257,7 @@ def train_torch_surrogate(
      surrogate_stoichiometries: dict[str, dict[str, float]],
      batch_size: int | None = None,
      approximator: nn.Module | None = None,
-     optimimzer_cls: type[Adam] = Adam,
+     optimimzer_cls: Callable[[ParamsT], Adam] = Adam,
      device: torch.device = DefaultDevice,
  ) -> tuple[TorchSurrogate, pd.Series]:
      """Train a PyTorch surrogate model.
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: modelbase2
- Version: 0.3.0
+ Version: 0.5.0
  Summary: A package to build metabolic models
  Author-email: Marvin van Aalst <marvin.vanaalst@gmail.com>
  Maintainer-email: Marvin van Aalst <marvin.vanaalst@gmail.com>
@@ -33,6 +33,7 @@ Requires-Dist: pebble>=5.0.7
  Requires-Dist: python-libsbml>=5.20.4
  Requires-Dist: scipy>=1.14.1
  Requires-Dist: seaborn>=0.13.2
+ Requires-Dist: symbtools>=0.4.0
  Requires-Dist: sympy>=1.13.1
  Requires-Dist: tabulate>=0.9.0
  Requires-Dist: tqdm>=4.66.6