mxlpy 0.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48)
  1. mxlpy/__init__.py +165 -0
  2. mxlpy/distributions.py +339 -0
  3. mxlpy/experimental/__init__.py +12 -0
  4. mxlpy/experimental/diff.py +226 -0
  5. mxlpy/fit.py +291 -0
  6. mxlpy/fns.py +191 -0
  7. mxlpy/integrators/__init__.py +19 -0
  8. mxlpy/integrators/int_assimulo.py +146 -0
  9. mxlpy/integrators/int_scipy.py +146 -0
  10. mxlpy/label_map.py +610 -0
  11. mxlpy/linear_label_map.py +303 -0
  12. mxlpy/mc.py +548 -0
  13. mxlpy/mca.py +280 -0
  14. mxlpy/meta/__init__.py +11 -0
  15. mxlpy/meta/codegen_latex.py +516 -0
  16. mxlpy/meta/codegen_modebase.py +110 -0
  17. mxlpy/meta/codegen_py.py +107 -0
  18. mxlpy/meta/source_tools.py +320 -0
  19. mxlpy/model.py +1737 -0
  20. mxlpy/nn/__init__.py +10 -0
  21. mxlpy/nn/_tensorflow.py +0 -0
  22. mxlpy/nn/_torch.py +129 -0
  23. mxlpy/npe.py +277 -0
  24. mxlpy/parallel.py +171 -0
  25. mxlpy/parameterise.py +27 -0
  26. mxlpy/paths.py +36 -0
  27. mxlpy/plot.py +875 -0
  28. mxlpy/py.typed +0 -0
  29. mxlpy/sbml/__init__.py +14 -0
  30. mxlpy/sbml/_data.py +77 -0
  31. mxlpy/sbml/_export.py +644 -0
  32. mxlpy/sbml/_import.py +599 -0
  33. mxlpy/sbml/_mathml.py +691 -0
  34. mxlpy/sbml/_name_conversion.py +52 -0
  35. mxlpy/sbml/_unit_conversion.py +74 -0
  36. mxlpy/scan.py +629 -0
  37. mxlpy/simulator.py +655 -0
  38. mxlpy/surrogates/__init__.py +31 -0
  39. mxlpy/surrogates/_poly.py +97 -0
  40. mxlpy/surrogates/_torch.py +196 -0
  41. mxlpy/symbolic/__init__.py +10 -0
  42. mxlpy/symbolic/strikepy.py +582 -0
  43. mxlpy/symbolic/symbolic_model.py +75 -0
  44. mxlpy/types.py +474 -0
  45. mxlpy-0.8.0.dist-info/METADATA +106 -0
  46. mxlpy-0.8.0.dist-info/RECORD +48 -0
  47. mxlpy-0.8.0.dist-info/WHEEL +4 -0
  48. mxlpy-0.8.0.dist-info/licenses/LICENSE +674 -0
mxlpy/model.py ADDED
@@ -0,0 +1,1737 @@
1
+ """Model for Metabolic System Representation.
2
+
3
+ This module provides the core Model class and supporting functionality for representing
4
+ metabolic models, including reactions, variables, parameters and derived quantities.
5
+
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import copy
11
+ import inspect
12
+ import itertools as it
13
+ from dataclasses import dataclass, field
14
+ from typing import TYPE_CHECKING, Self, cast
15
+
16
+ import numpy as np
17
+ import pandas as pd
18
+
19
+ from mxlpy import fns
20
+ from mxlpy.types import (
21
+ Array,
22
+ Derived,
23
+ Reaction,
24
+ Readout,
25
+ )
26
+
27
+ __all__ = [
28
+ "ArityMismatchError",
29
+ "CircularDependencyError",
30
+ "MissingDependenciesError",
31
+ "Model",
32
+ "ModelCache",
33
+ ]
34
+
35
+ if TYPE_CHECKING:
36
+ from collections.abc import Iterable, Mapping
37
+ from inspect import FullArgSpec
38
+
39
+ from mxlpy.types import AbstractSurrogate, Callable, Param, RateFn, RetType
40
+
41
+
42
class MissingDependenciesError(Exception):
    """Raised when model elements depend on names that do not exist.

    Unlike `CircularDependencyError`, this signals that a dependency is
    neither an available name (parameter, variable, ``time``) nor another
    element being sorted, so no ordering can ever satisfy it.
    """

    def __init__(self, not_solvable: dict[str, list[str]]) -> None:
        """Initialise exception.

        Args:
            not_solvable: Mapping of element name to the dependency names that
                could not be found.

        """
        missing_by_module = "\n".join(f"\t{k}: {v}" for k, v in not_solvable.items())
        msg = (
            f"Dependencies cannot be solved. Missing dependencies:\n{missing_by_module}"
        )
        super().__init__(msg)
55
+
56
+
57
class CircularDependencyError(Exception):
    """Raised when the remaining model elements cannot be ordered.

    This usually means the elements reference each other in a cycle.
    """

    def __init__(
        self,
        missing: dict[str, set[str]],
    ) -> None:
        """Build the message from the still-unresolved dependencies.

        Args:
            missing: Mapping of element name to the dependencies that were
                not yet available when sorting gave up.

        """
        listing = "\n".join([f"\t{key}: {deps}" for key, deps in missing.items()])
        super().__init__(
            "Exceeded max iterations on sorting dependencies.\n"
            "Check if there are circular references. "
            "Missing dependencies:\n" + listing
        )
76
+
77
+
78
def _get_all_args(argspec: FullArgSpec) -> list[str]:
    """Return positional-or-keyword argument names followed by keyword-only ones."""
    keyword_only = argspec.kwonlyargs or []
    return [*argspec.args, *keyword_only]
81
+
82
+
83
def _check_function_arity(function: Callable, arity: int) -> bool:
    """Check whether ``function`` can be called with ``arity`` positional arguments.

    Model functions are invoked positionally (``fn(*values)``). A function
    therefore fits when ``arity`` lies between its number of required
    positional parameters and its total number of positional parameters.
    Parameters with default values may be omitted by the caller.

    Args:
        function: The callable to inspect.
        arity: Number of positional arguments the model will pass.

    Returns:
        True if the call is valid (always True for ``*args`` functions).

    """
    argspec = inspect.getfullargspec(function)
    # Give up on *args functions
    if argspec.varargs is not None:
        return True

    n_positional = len(argspec.args)
    n_defaults = 0 if argspec.defaults is None else len(argspec.defaults)
    # Trailing parameters with defaults may be left out, so anything between
    # the number of required parameters and the full count is acceptable.
    return n_positional - n_defaults <= arity <= n_positional
101
+
102
+
103
class ArityMismatchError(Exception):
    """Raised when a function's signature does not fit the model arguments wired to it."""

    def __init__(self, name: str, fn: Callable, args: list[str]) -> None:
        """Render a side-by-side table of function and model arguments.

        Args:
            name: Name of the model element owning ``fn``.
            fn: The offending callable.
            args: Model argument names that were wired to ``fn``.

        """
        spec = inspect.getfullargspec(fn)
        rows = [("Fn args", "Model args"), ("-------", "----------")]
        rows.extend(it.zip_longest(spec.args, args, fillvalue="---"))
        # Columns are padded and truncated to fixed widths (8 and 10 chars).
        table = "\n".join(f"{left:<8.8} | {right:<10.10}" for left, right in rows)
        super().__init__(f"Function arity mismatch for {name}.\n" + table)
122
+
123
+
124
def _invalidate_cache(method: Callable[Param, RetType]) -> Callable[Param, RetType]:
    """Decorator that invalidates model cache when decorated method is called.

    The first positional argument of the call is treated as the model
    (``self``); its ``_cache`` attribute is reset to None before ``method``
    runs.

    Args:
        method: Method to wrap with cache invalidation

    Returns:
        Wrapped method that clears cache before execution. Name, docstring
        and signature of ``method`` are preserved via ``functools.wraps``.

    """
    from functools import wraps

    # Without ``wraps`` every decorated public method would show up as
    # ``wrapper`` with no docstring in help() / IDEs / inspect.signature.
    @wraps(method)
    def wrapper(
        *args: Param.args,
        **kwargs: Param.kwargs,
    ) -> RetType:
        # String form avoids a runtime lookup of ``Model``; cast is a no-op.
        self = cast("Model", args[0])
        self._cache = None
        return method(*args, **kwargs)

    return wrapper  # type: ignore
144
+
145
+
146
def _check_if_is_sortable(
    available: set[str],
    elements: list[tuple[str, set[str]]],
) -> None:
    """Ensure every dependency is either available or itself an element.

    Args:
        available: Names that are already known. Not modified.
        elements: (name, dependencies) pairs that are to be sorted.

    Raises:
        MissingDependenciesError: If any element depends on an unknown name.

    """
    known = available | {elem_name for elem_name, _ in elements}
    unresolved = {
        elem_name: sorted(deps - known)
        for elem_name, deps in elements
        if not deps <= known
    }
    if unresolved:
        raise MissingDependenciesError(not_solvable=unresolved)
162
+
163
+
164
def _sort_dependencies(
    available: set[str], elements: list[tuple[str, set[str]]]
) -> list[str]:
    """Sort model elements topologically based on their dependencies.

    Elements are processed first-in-first-out. An element whose dependencies
    are all available is emitted; otherwise it is moved to the back of the
    queue. Sorting fails once a full pass over the remaining elements makes
    no progress. That happens exactly when they depend on each other (or on
    themselves) in a cycle.

    Args:
        available: Set of available component names. Updated in place with
            every element name that gets sorted.
        elements: List of (name, dependencies) tuples to sort

    Returns:
        List of element names in dependency order

    Raises:
        MissingDependenciesError: If an element depends on a name that is
            neither available nor one of ``elements``.
        CircularDependencyError: If the elements contain a dependency cycle,
            including an element that depends on itself.

    """
    from collections import deque

    _check_if_is_sortable(available, elements)

    order: list[str] = []
    pending: deque[tuple[str, set[str]]] = deque(elements)
    # Count of consecutive elements that could not be resolved. When it
    # reaches the queue length, every pending element has been retried
    # against the same ``available`` set without success, so the loop must
    # stop instead of spinning forever on a cycle.
    stalled = 0
    while pending:
        name, args = pending.popleft()
        if args.issubset(available):
            available.add(name)
            order.append(name)
            stalled = 0
            continue

        pending.append((name, args))
        stalled += 1
        if stalled >= len(pending):
            missing = {k: v.difference(available) for k, v in pending}
            raise CircularDependencyError(missing=missing)
    return order
222
+
223
+
224
+ @dataclass(slots=True)
225
+ class ModelCache:
226
+ """ModelCache is a class that stores various model-related data structures.
227
+
228
+ Attributes:
229
+ var_names: A list of variable names.
230
+ parameter_values: A dictionary mapping parameter names to their values.
231
+ derived_parameters: A dictionary mapping parameter names to their derived parameter objects.
232
+ derived_variables: A dictionary mapping variable names to their derived variable objects.
233
+ stoich_by_cpds: A dictionary mapping compound names to their stoichiometric coefficients.
234
+ dyn_stoich_by_cpds: A dictionary mapping compound names to their dynamic stoichiometric coefficients.
235
+ dxdt: A pandas Series representing the rate of change of variables.
236
+
237
+ """
238
+
239
+ var_names: list[str]
240
+ order: list[str]
241
+ all_parameter_values: dict[str, float]
242
+ derived_parameter_names: list[str]
243
+ derived_variable_names: list[str]
244
+ stoich_by_cpds: dict[str, dict[str, float]]
245
+ dyn_stoich_by_cpds: dict[str, dict[str, Derived]]
246
+ dxdt: pd.Series
247
+
248
+
249
+ @dataclass(slots=True)
250
+ class Model:
251
+ """Represents a metabolic model.
252
+
253
+ Attributes:
254
+ _ids: Dictionary mapping internal IDs to names.
255
+ _variables: Dictionary of model variables and their initial values.
256
+ _parameters: Dictionary of model parameters and their values.
257
+ _derived: Dictionary of derived quantities.
258
+ _readouts: Dictionary of readout functions.
259
+ _reactions: Dictionary of reactions in the model.
260
+ _surrogates: Dictionary of surrogate models.
261
+ _cache: Cache for storing model-related data structures.
262
+
263
+ """
264
+
265
+ _ids: dict[str, str] = field(default_factory=dict)
266
+ _variables: dict[str, float] = field(default_factory=dict)
267
+ _parameters: dict[str, float] = field(default_factory=dict)
268
+ _derived: dict[str, Derived] = field(default_factory=dict)
269
+ _readouts: dict[str, Readout] = field(default_factory=dict)
270
+ _reactions: dict[str, Reaction] = field(default_factory=dict)
271
+ _surrogates: dict[str, AbstractSurrogate] = field(default_factory=dict)
272
+ _cache: ModelCache | None = None
273
+
274
+ ###########################################################################
275
+ # Cache
276
+ ###########################################################################
277
+
278
def _create_cache(self) -> ModelCache:
    """Build the model cache and store it on ``self._cache``.

    Steps:
        1. Verify that every derived quantity, reaction and readout function
           accepts the number of arguments wired to it.
        2. Sort derived quantities, reactions and surrogates by dependency.
        3. Evaluate derived quantities that only depend on (derived)
           parameters; all others are classified as derived variables.
        4. Collect stoichiometries per compound, splitting them into constant
           and dynamic (``Derived``) coefficients.

    Returns:
        ModelCache: The freshly built cache (also assigned to ``self._cache``).

    Raises:
        ArityMismatchError: If a function does not fit its argument list.
        MissingDependenciesError: Propagated from ``_sort_dependencies``.
        CircularDependencyError: Propagated from ``_sort_dependencies``.

    """
    param_values: dict[str, float] = dict(self._parameters)
    param_names: set[str] = set(param_values)

    # 1. Arity sanity checks
    to_check = it.chain(
        self._derived.items(), self._reactions.items(), self._readouts.items()
    )
    for el_name, element in to_check:
        if _check_function_arity(element.fn, len(element.args)):
            continue
        raise ArityMismatchError(el_name, element.fn, element.args)

    # 2. Dependency order of everything that is computed from other names
    sortable = self._derived | self._reactions | self._surrogates
    order = _sort_dependencies(
        available={*self._parameters, *self._variables, "time"},
        elements=[(key, set(obj.args)) for key, obj in sortable.items()],
    )

    # 3. Derived quantities built only from parameters are constants
    derived_variable_names: list[str] = []
    derived_parameter_names: list[str] = []
    for key in order:
        if key in self._reactions or key in self._surrogates:
            continue
        der = self._derived[key]
        if not all(arg in param_names for arg in der.args):
            derived_variable_names.append(key)
            continue
        param_names.add(key)
        derived_parameter_names.append(key)
        param_values[key] = float(der.fn(*(param_values[arg] for arg in der.args)))

    # 4. Stoichiometries grouped by compound
    stoich_by_compounds: dict[str, dict[str, float]] = {}
    dyn_stoich_by_compounds: dict[str, dict[str, Derived]] = {}
    for rxn_name, rxn in self._reactions.items():
        for cpd_name, factor in rxn.stoichiometry.items():
            static_row = stoich_by_compounds.setdefault(cpd_name, {})
            if not isinstance(factor, Derived):
                static_row[rxn_name] = factor
            elif all(arg in param_names for arg in factor.args):
                static_row[rxn_name] = float(
                    factor.fn(*(param_values[arg] for arg in factor.args))
                )
            else:
                dyn_stoich_by_compounds.setdefault(cpd_name, {})[rxn_name] = factor

    for surrogate in self._surrogates.values():
        for rxn_name, rxn_stoich in surrogate.stoichiometries.items():
            for cpd_name, factor in rxn_stoich.items():
                stoich_by_compounds.setdefault(cpd_name, {})[rxn_name] = factor

    var_names = self.get_variable_names()
    self._cache = ModelCache(
        var_names=var_names,
        order=order,
        all_parameter_values=param_values,
        stoich_by_cpds=stoich_by_compounds,
        dyn_stoich_by_cpds=dyn_stoich_by_compounds,
        derived_variable_names=derived_variable_names,
        derived_parameter_names=derived_parameter_names,
        dxdt=pd.Series(np.zeros(len(var_names), dtype=float), index=var_names),
    )
    return self._cache
364
+
365
+ ###########################################################################
366
+ # Ids
367
+ ###########################################################################
368
+
369
@property
def ids(self) -> dict[str, str]:
    """Map every registered name to the kind of element it denotes.

    The kind is the ``ctx`` string given at registration, e.g.
    ``"parameter"``, ``"variable"`` or ``"derived"``.

    Returns:
        dict[str, str]: A shallow copy; mutating it does not affect the model.

    """
    return dict(self._ids)
380
+
381
def _insert_id(self, *, name: str, ctx: str) -> None:
    """Register ``name`` as a model element of kind ``ctx``.

    Args:
        name: Identifier to register.
        ctx: Kind of element (used in the error message and stored as value).

    Raises:
        KeyError: If ``name`` is ``"time"``, which is reserved.
        NameError: If ``name`` is already registered.

    """
    if name == "time":
        raise KeyError("time is a protected variable for time")
    if name in self._ids:
        raise NameError(f"Model already contains {ctx} called '{name}'")
    self._ids[name] = ctx
401
+
402
def _remove_id(self, *, name: str) -> None:
    """Unregister ``name`` from the model's identifier table.

    Args:
        name: Identifier to remove.

    Raises:
        KeyError: If ``name`` is not registered.

    """
    self._ids.pop(name)
413
+
414
+ ##########################################################################
415
+ # Parameters
416
+ ##########################################################################
417
+
418
@_invalidate_cache
def add_parameter(self, name: str, value: float) -> Self:
    """Register a new constant parameter.

    Examples:
        >>> model.add_parameter("k1", 0.1)

    Args:
        name: Unique name of the parameter.
        value: Numerical value of the parameter.

    Returns:
        Self: The model itself, to allow chaining.

    Raises:
        KeyError: If ``name`` is ``"time"``.
        NameError: If ``name`` is already used by another model element.

    """
    self._insert_id(ctx="parameter", name=name)
    self._parameters.update({name: value})
    return self
436
+
437
def add_parameters(self, parameters: dict[str, float]) -> Self:
    """Register several parameters at once, in iteration order.

    Examples:
        >>> model.add_parameters({"k1": 0.1, "k2": 0.2})

    Args:
        parameters: Mapping of parameter name to value.

    Returns:
        Self: The model itself, to allow chaining.

    """
    for item in parameters.items():
        self.add_parameter(*item)
    return self
454
+
455
@property
def parameters(self) -> dict[str, float]:
    """Return the (non-derived) parameters and their values.

    Examples:
        >>> model.parameters
        {"k1": 0.1, "k2": 0.2}

    Returns:
        parameters: A shallow copy mapping parameter names to values.

    """
    return dict(self._parameters)
469
+
470
def get_parameter_names(self) -> list[str]:
    """List the names of all (non-derived) parameters in insertion order.

    Examples:
        >>> model.get_parameter_names()
        ['k1', 'k2']

    Returns:
        parameters: A new list of parameter names.

    """
    return [*self._parameters]
482
+
483
@_invalidate_cache
def remove_parameter(self, name: str) -> Self:
    """Remove a parameter from the model.

    Examples:
        >>> model.remove_parameter("k1")

    Args:
        name: The name of the parameter to remove.

    Returns:
        Self: The instance of the model with the parameter removed.

    Raises:
        KeyError: If ``name`` is not a parameter. The model is left
            unchanged in that case.

    """
    # Validate before touching ``_ids``. Otherwise passing e.g. a variable
    # name would unregister its id and then fail, leaving the model in an
    # inconsistent state.
    if name not in self._parameters:
        msg = f"'{name}' not found in parameters"
        raise KeyError(msg)
    self._remove_id(name=name)
    del self._parameters[name]
    return self
500
+
501
def remove_parameters(self, names: list[str]) -> Self:
    """Remove several parameters, one after another.

    Examples:
        >>> model.remove_parameters(["k1", "k2"])

    Args:
        names: Names of the parameters to remove.

    Returns:
        Self: The model itself, to allow chaining.

    """
    for param_name in names:
        self.remove_parameter(name=param_name)
    return self
517
+
518
@_invalidate_cache
def update_parameter(self, name: str, value: float) -> Self:
    """Update the value of a parameter.

    Examples:
        >>> model.update_parameter("k1", 0.2)

    Args:
        name: The name of the parameter to update.
        value: The new value for the parameter.

    Returns:
        Self: The instance of the class with the updated parameter.

    Raises:
        KeyError: If the parameter name is not found in the parameters.

    """
    if name not in self._parameters:
        msg = f"'{name}' not found in parameters"
        raise KeyError(msg)
    self._parameters[name] = value
    return self
541
+
542
def update_parameters(self, parameters: dict[str, float]) -> Self:
    """Set new values for several existing parameters.

    Examples:
        >>> model.update_parameters({"k1": 0.2, "k2": 0.3})

    Args:
        parameters: Mapping of parameter name to its new value.

    Returns:
        Self: The model itself, to allow chaining.

    """
    for item in parameters.items():
        self.update_parameter(*item)
    return self
558
+
559
def scale_parameter(self, name: str, factor: float) -> Self:
    """Multiply the current value of a parameter by ``factor``.

    Examples:
        >>> model.scale_parameter("k1", 2.0)

    Args:
        name: Name of the parameter to scale.
        factor: Multiplicative factor.

    Returns:
        Self: The model itself, to allow chaining.

    """
    current = self._parameters[name]
    return self.update_parameter(name, current * factor)
574
+
575
def scale_parameters(self, parameters: dict[str, float]) -> Self:
    """Multiply several parameters by their respective factors.

    Examples:
        >>> model.scale_parameters({"k1": 2.0, "k2": 0.5})

    Args:
        parameters: Mapping of parameter name to scaling factor.

    Returns:
        Self: The model itself, to allow chaining.

    """
    for item in parameters.items():
        self.scale_parameter(*item)
    return self
592
+
593
@_invalidate_cache
def make_parameter_dynamic(
    self,
    name: str,
    initial_value: float | None = None,
    stoichiometries: dict[str, float] | None = None,
) -> Self:
    """Converts a parameter to a dynamic variable in the model.

    Examples:
        >>> model.make_parameter_dynamic("k1")
        >>> model.make_parameter_dynamic("k2", initial_value=0.5)

    This method removes the specified parameter from the model and adds it as
    a variable with an optional initial value.

    Args:
        name: The name of the parameter to be converted.
        initial_value: The initial value for the new variable. If None, the
            current value of the parameter is used. Defaults to None.
        stoichiometries: A dictionary mapping reaction names to
            stoichiometries for the new variable. Reactions are looked up in
            the model's reactions first, then in all surrogates.
            Defaults to None.

    Returns:
        Self: The instance of the model with the parameter converted to a variable.

    Raises:
        KeyError: If ``name`` is not a parameter, or if a reaction in
            ``stoichiometries`` exists neither as reaction nor in a surrogate.
            Unknown reactions are detected before the model is modified.

    """
    value = self._parameters[name] if initial_value is None else initial_value

    if stoichiometries is not None:
        # Validate all targets up-front so a typo does not leave the model
        # half-converted.
        known_reactions = set(self._reactions)
        for surrogate in self._surrogates.values():
            known_reactions.update(surrogate.stoichiometries)
        for rxn_name in stoichiometries:
            if rxn_name not in known_reactions:
                msg = f"Reaction '{rxn_name}' not found in reactions or surrogates"
                raise KeyError(msg)

    self.remove_parameter(name)
    self.add_variable(name, value)

    if stoichiometries is not None:
        for rxn_name, factor in stoichiometries.items():
            if rxn_name in self._reactions:
                # Look up the *reaction* and add the new variable to it.
                cast(dict, self._reactions[rxn_name].stoichiometry)[name] = factor
                continue
            for surrogate in self._surrogates.values():
                if rxn_name in surrogate.stoichiometries:
                    surrogate.stoichiometries[rxn_name][name] = factor

    return self
637
+
638
+ ##########################################################################
639
+ # Variables
640
+ ##########################################################################
641
+
642
@property
def variables(self) -> dict[str, float]:
    """Return the variables and their initial conditions.

    Examples:
        >>> model.variables
        {"x1": 1.0, "x2": 2.0}

    Returns:
        dict[str, float]: A shallow copy mapping variable names to initial
        conditions; mutating it does not affect the model.

    """
    return dict(self._variables)
658
+
659
@_invalidate_cache
def add_variable(self, name: str, initial_condition: float) -> Self:
    """Register a new dynamic variable.

    Examples:
        >>> model.add_variable("x1", 1.0)

    Args:
        name: Unique name of the variable.
        initial_condition: Initial value of the variable.

    Returns:
        Self: The model itself, to allow chaining.

    Raises:
        KeyError: If ``name`` is ``"time"``.
        NameError: If ``name`` is already used by another model element.

    """
    self._insert_id(ctx="variable", name=name)
    self._variables.update({name: initial_condition})
    return self
677
+
678
def add_variables(self, variables: dict[str, float]) -> Self:
    """Register several variables at once, in iteration order.

    Examples:
        >>> model.add_variables({"x1": 1.0, "x2": 2.0})

    Args:
        variables: Mapping of variable name to initial condition.

    Returns:
        Self: The model itself, to allow chaining.

    """
    for var_name, init in variables.items():
        self.add_variable(initial_condition=init, name=var_name)
    return self
695
+
696
@_invalidate_cache
def remove_variable(self, name: str) -> Self:
    """Remove a variable from the model.

    Examples:
        >>> model.remove_variable("x1")

    Args:
        name: The name of the variable to remove.

    Returns:
        Self: The instance of the model with the variable removed.

    Raises:
        KeyError: If ``name`` is not a variable. The model is left unchanged
            in that case.

    """
    # Validate before touching ``_ids`` so a wrong name (e.g. a parameter)
    # cannot unregister another element's id.
    if name not in self._variables:
        msg = f"'{name}' not found in variables"
        raise KeyError(msg)
    self._remove_id(name=name)
    del self._variables[name]
    return self
713
+
714
def remove_variables(self, variables: Iterable[str]) -> Self:
    """Remove several variables, one after another.

    Examples:
        >>> model.remove_variables(["x1", "x2"])

    Args:
        variables: Names of the variables to remove.

    Returns:
        Self: The model itself, to allow chaining.

    """
    for var_name in variables:
        self.remove_variable(name=var_name)
    return self
730
+
731
@_invalidate_cache
def update_variable(self, name: str, initial_condition: float) -> Self:
    """Set a new initial condition for an existing variable.

    Examples:
        >>> model.update_variable("x1", 2.0)

    Args:
        name: Name of the variable.
        initial_condition: New initial value.

    Returns:
        Self: The model itself, to allow chaining.

    Raises:
        KeyError: If ``name`` is not a variable.

    """
    if name in self._variables:
        self._variables[name] = initial_condition
        return self
    msg = f"'{name}' not found in variables"
    raise KeyError(msg)
751
+
752
def update_variables(self, variables: dict[str, float]) -> Self:
    """Set new initial conditions for several existing variables.

    Examples:
        >>> model.update_variables({"x1": 2.0, "x2": 3.0})

    Args:
        variables: Mapping of variable name to its new initial condition.

    Returns:
        Self: The model itself, to allow chaining.

    """
    for item in variables.items():
        self.update_variable(*item)
    return self
768
+
769
def get_variable_names(self) -> list[str]:
    """List the names of all variables in insertion order.

    Examples:
        >>> model.get_variable_names()
        ["x1", "x2"]

    Returns:
        variable_names: A new list of variable names.

    """
    return [*self._variables]
781
+
782
def get_initial_conditions(self) -> dict[str, float]:
    """Retrieve the initial conditions of the model.

    Examples:
        >>> model.get_initial_conditions()
        {"x1": 1.0, "x2": 2.0}

    Returns:
        initial_conditions: A copy of the mapping of variable names to their
        initial conditions. Mutating it does not affect the model; use
        ``update_variable(s)`` for that.

    """
    # Return a copy like every other accessor (``variables``, ``parameters``)
    # so callers cannot silently alter model state.
    return self._variables.copy()
794
+
795
def make_variable_static(self, name: str, value: float | None = None) -> Self:
    """Turn a variable into a constant parameter.

    The variable is also dropped from the stoichiometries of every reaction
    and surrogate. Those entries are not restored if
    `Model.make_parameter_dynamic` is called later.

    Examples:
        >>> model.make_variable_static("x1")
        >>> model.make_variable_static("x2", value=2.0)

    Args:
        name: Name of the variable to convert.
        value: Value of the new parameter. If None, the variable's current
            initial condition is used. Defaults to None.

    Returns:
        Self: The model itself, to allow chaining.

    """
    if value is None:
        value = self._variables[name]
    self.remove_variable(name)
    self.add_parameter(name, value)

    # Purge the former variable from all reaction stoichiometries.
    for rxn in self._reactions.values():
        if name in rxn.stoichiometry:
            cast(dict, rxn.stoichiometry).pop(name)

    # Rebuild each surrogate's stoichiometries without the former variable.
    for surrogate in self._surrogates.values():
        pruned: dict = {}
        for rxn_name, cpd_factors in surrogate.stoichiometries.items():
            if rxn_name == name:
                continue
            pruned[rxn_name] = {
                cpd: factor for cpd, factor in cpd_factors.items() if cpd != name
            }
        surrogate.stoichiometries = pruned
    return self
829
+
830
+ ##########################################################################
831
+ # Derived
832
+ ##########################################################################
833
+
834
@property
def derived(self) -> dict[str, Derived]:
    """Return all derived quantities, both constant and dynamic.

    Examples:
        >>> model.derived
        {"d1": Derived(fn1, ["x1", "x2"]),
         "d2": Derived(fn2, ["x1", "d1"])}

    Returns:
        dict[str, Derived]: A shallow copy; the `Derived` objects themselves
        are shared with the model.

    """
    return dict(self._derived)
848
+
849
@property
def derived_variables(self) -> dict[str, Derived]:
    """Returns a dictionary of derived variables.

    A derived quantity counts as a derived variable when at least one of its
    arguments is not a (derived) parameter, e.g. a variable or ``time``.
    The model cache is built on demand.

    Examples:
        >>> model.derived_variables
        {"d1": Derived(fn1, ["x1", "x2"]),
         "d2": Derived(fn2, ["x1", "d1"])}

    Returns:
        derived_variables: A new dict mapping the names of the derived
        variables, in dependency order, to their (shared) `Derived` objects.

    """
    if (cache := self._cache) is None:
        cache = self._create_cache()
    derived = self._derived
    return {k: derived[k] for k in cache.derived_variable_names}
868
+
869
@property
def derived_parameters(self) -> dict[str, Derived]:
    """Returns a dictionary of derived parameters.

    A derived parameter is a derived quantity whose arguments are all
    (derived) parameters, so its value is constant. The model cache is
    built on demand.

    Examples:
        >>> model.derived_parameters
        {"kd1": Derived(fn1, ["k1", "k2"]),
         "kd2": Derived(fn2, ["k1", "kd1"])}

    Returns:
        A new dict mapping the names of the derived parameters, in dependency
        order, to their (shared) `Derived` objects.

    """
    if (cache := self._cache) is None:
        cache = self._create_cache()
    derived = self._derived
    return {k: derived[k] for k in cache.derived_parameter_names}
887
+
888
@_invalidate_cache
def add_derived(
    self,
    name: str,
    fn: RateFn,
    *,
    args: list[str],
) -> Self:
    """Register a quantity computed as ``fn(*args)`` from other model names.

    Examples:
        >>> model.add_derived("d1", add, args=["x1", "x2"])

    Args:
        name: Unique name of the derived quantity.
        fn: Function computing the quantity.
        args: Names of the model elements passed positionally to ``fn``.

    Returns:
        Self: The model itself, to allow chaining.

    Raises:
        KeyError: If ``name`` is ``"time"``.
        NameError: If ``name`` is already used by another model element.

    """
    self._insert_id(ctx="derived", name=name)
    new_derived = Derived(name=name, fn=fn, args=args)
    self._derived[name] = new_derived
    return self
913
+
914
def get_derived_parameter_names(self) -> list[str]:
    """List the names of derived quantities that are constant.

    Examples:
        >>> model.get_derived_parameter_names()
        ["kd1", "kd2"]

    Returns:
        A new list of derived parameter names, in dependency order.

    """
    return [*self.derived_parameters]
926
+
927
def get_derived_variable_names(self) -> list[str]:
    """List the names of derived quantities that are not constant.

    Examples:
        >>> model.get_derived_variable_names()
        ["d1", "d2"]

    Returns:
        A new list of derived variable names, in dependency order.

    """
    return [*self.derived_variables]
939
+
940
@_invalidate_cache
def update_derived(
    self,
    name: str,
    fn: RateFn | None = None,
    *,
    args: list[str] | None = None,
) -> Self:
    """Updates the derived function and its arguments for a given name.

    The existing `Derived` object is modified in place, so references
    obtained earlier (e.g. via ``model.derived``) see the change.

    Examples:
        >>> model.update_derived("d1", add, args=["x1", "x2"])

    Args:
        name: The name of the derived function to update.
        fn: The new derived function. If None, the existing function is
            retained. Defaults to None.
        args: The new arguments for the derived function (keyword-only).
            If None, the existing arguments are retained. Defaults to None.

    Returns:
        Self: The instance of the class with the updated derived function and arguments.

    Raises:
        KeyError: If ``name`` is not a derived quantity.

    """
    der = self._derived[name]
    der.fn = der.fn if fn is None else fn
    der.args = der.args if args is None else args
    return self
966
+
967
@_invalidate_cache
def remove_derived(self, name: str) -> Self:
    """Delete a derived quantity from the model.

    Examples:
        >>> model.remove_derived("d1")

    Args:
        name: The name of the derived quantity to delete.

    Returns:
        Self: The model itself, so calls can be chained.

    Raises:
        KeyError: If ``name`` is not a stored derived quantity.

    """
    self._remove_id(name=name)
    del self._derived[name]
    return self
984
+
985
+ ###########################################################################
986
+ # Reactions
987
+ ###########################################################################
988
+
989
@property
def reactions(self) -> dict[str, Reaction]:
    """Return all reactions of the model, keyed by reaction name.

    The result is a deep copy, so modifying it does not affect the model.

    Examples:
        >>> model.reactions
        {"r1": Reaction(fn1, {"x1": -1, "x2": 1}, ["k1"])}

    Returns:
        dict[str, Reaction]: A deep copy of the stored reactions.

    """
    detached = copy.deepcopy(self._reactions)
    return detached
1002
+
1003
def get_stoichiometries(
    self, variables: dict[str, float] | None = None, time: float = 0.0
) -> pd.DataFrame:
    """Return the stoichiometric matrix of the model.

    Rows are compounds and columns are reactions; missing entries are 0.
    Dynamic (non-constant) stoichiometries are evaluated at the given
    variables and time.

    Examples:
        >>> model.get_stoichiometries()
            v1  v2
        x1  -1   1
        x2   1  -1

    Args:
        variables: Variable values used to evaluate dynamic stoichiometries.
            If None, the initial conditions are used.
        time: Time point used to evaluate dynamic stoichiometries.

    Returns:
        pd.DataFrame: A DataFrame containing the stoichiometries of the model.

    """
    if (cache := self._cache) is None:
        cache = self._create_cache()

    # Deep copy so the cached static stoichiometries are never mutated.
    stoich_by_cpds = copy.deepcopy(cache.stoich_by_cpds)
    if cache.dyn_stoich_by_cpds:
        # Only evaluate the model when some stoichiometry actually depends on it.
        args = self.get_dependent(variables=variables, time=time)
        for cpd, stoich in cache.dyn_stoich_by_cpds.items():
            # setdefault: a compound may appear only with dynamic stoichiometries.
            cpd_stoich = stoich_by_cpds.setdefault(cpd, {})
            for rxn, derived in stoich.items():
                cpd_stoich[rxn] = float(derived.fn(*(args[i] for i in derived.args)))
    return pd.DataFrame(stoich_by_cpds).T.fillna(0)
1029
+
1030
@_invalidate_cache
def add_reaction(
    self,
    name: str,
    fn: RateFn,
    *,
    args: list[str],
    stoichiometry: Mapping[str, float | str | Derived],
) -> Self:
    """Register a new reaction.

    Examples:
        >>> model.add_reaction("v1",
        ...     fn=mass_action,
        ...     args=["x1", "kf1"],
        ...     stoichiometry={"x1": -1, "x2": 1},
        ... )

    Args:
        name: Identifier of the reaction.
        fn: Rate function, called with the values of the components in ``args``.
        args: Names of the model components passed to ``fn``, in call order.
        stoichiometry: Mapping of compound name to coefficient. A string
            coefficient is treated as the name of a model component and is
            wrapped in a ``Derived`` using ``fns.constant``.

    Returns:
        Self: The model itself, so calls can be chained.

    """
    self._insert_id(name=name, ctx="reaction")

    stoich: dict[str, Derived | float] = {}
    for cpd, coef in stoichiometry.items():
        if isinstance(coef, str):
            stoich[cpd] = Derived(name=cpd, fn=fns.constant, args=[coef])
        else:
            stoich[cpd] = coef

    self._reactions[name] = Reaction(name=name, fn=fn, stoichiometry=stoich, args=args)
    return self
1068
+
1069
def get_reaction_names(self) -> list[str]:
    """Return the names of all reactions, in insertion order.

    Examples:
        >>> model.get_reaction_names()
        ["v1", "v2"]

    Returns:
        list[str]: A new list of reaction names.

    """
    return [*self._reactions]
1081
+
1082
@_invalidate_cache
def update_reaction(
    self,
    name: str,
    fn: RateFn | None = None,
    *,
    args: list[str] | None = None,
    stoichiometry: Mapping[str, float | Derived | str] | None = None,
) -> Self:
    """Modify an existing reaction in place.

    Examples:
        >>> model.update_reaction("v1",
        ...     fn=mass_action,
        ...     args=["x1", "kf1"],
        ...     stoichiometry={"x1": -1, "x2": 1},
        ... )

    Args:
        name: The name of the reaction to modify.
        fn: New rate function; None keeps the current one.
        args: New argument names; None keeps the current ones.
        stoichiometry: New stoichiometry, replacing the old one completely;
            None keeps the current one. String coefficients are wrapped in a
            ``Derived`` using ``fns.constant``.

    Returns:
        Self: The model itself, so calls can be chained.

    Raises:
        KeyError: If no reaction called ``name`` exists.

    """
    rxn = self._reactions[name]
    if fn is not None:
        rxn.fn = fn

    if stoichiometry is not None:
        rxn.stoichiometry = {
            cpd: (
                Derived(name=cpd, fn=fns.constant, args=[coef])
                if isinstance(coef, str)
                else coef
            )
            for cpd, coef in stoichiometry.items()
        }
    if args is not None:
        rxn.args = args
    return self
1123
+
1124
@_invalidate_cache
def remove_reaction(self, name: str) -> Self:
    """Delete a reaction from the model.

    Examples:
        >>> model.remove_reaction("v1")

    Args:
        name: The name of the reaction to delete.

    Returns:
        Self: The model itself, so calls can be chained.

    Raises:
        KeyError: If ``name`` is not a stored reaction.

    """
    self._remove_id(name=name)
    del self._reactions[name]
    return self
1141
+
1142
+ # def update_stoichiometry_of_cpd(
1143
+ # self,
1144
+ # rate_name: str,
1145
+ # compound: str,
1146
+ # value: float,
1147
+ # ) -> Model:
1148
+ # self.update_stoichiometry(
1149
+ # rate_name=rate_name,
1150
+ # stoichiometry=self.stoichiometries[rate_name] | {compound: value},
1151
+ # )
1152
+ # return self
1153
+
1154
+ # def scale_stoichiometry_of_cpd(
1155
+ # self,
1156
+ # rate_name: str,
1157
+ # compound: str,
1158
+ # scale: float,
1159
+ # ) -> Model:
1160
+ # return self.update_stoichiometry_of_cpd(
1161
+ # rate_name=rate_name,
1162
+ # compound=compound,
1163
+ # value=self.stoichiometries[rate_name][compound] * scale,
1164
+ # )
1165
+
1166
+ ##########################################################################
1167
+ # Readouts
1168
+ # They are like derived variables, but only calculated on demand
1169
+ # Think of something like NADPH / (NADP + NADPH) as a proxy for energy state
1170
+ ##########################################################################
1171
+
1172
def add_readout(self, name: str, fn: RateFn, *, args: list[str]) -> Self:
    """Register a readout, a quantity evaluated only on request.

    Examples:
        >>> model.add_readout("energy_state",
        ...     fn=div,
        ...     args=["NADPH", "NADP*_total"]
        ... )

    Args:
        name: Identifier of the readout.
        fn: Callable evaluated with the values of the components in ``args``.
        args: Names of the model components passed to ``fn``, in call order.

    Returns:
        Self: The model itself, so calls can be chained.

    """
    self._insert_id(name=name, ctx="readout")
    new_readout = Readout(name=name, fn=fn, args=args)
    self._readouts[name] = new_readout
    return self
1193
+
1194
def get_readout_names(self) -> list[str]:
    """Return the names of all readouts, in insertion order.

    Examples:
        >>> model.get_readout_names()
        ["energy_state", "redox_state"]

    Returns:
        list[str]: A new list of readout names.

    """
    return list(self._readouts.keys())
1206
+
1207
def remove_readout(self, name: str) -> Self:
    """Delete a readout from the model.

    Examples:
        >>> model.remove_readout("energy_state")

    Args:
        name: The name of the readout to delete.

    Returns:
        Self: The model itself, so calls can be chained.

    Raises:
        KeyError: If ``name`` is not a stored readout.

    """
    self._remove_id(name=name)
    self._readouts.pop(name)
    return self
1223
+
1224
+ ##########################################################################
1225
+ # Surrogates
1226
+ ##########################################################################
1227
+
1228
@_invalidate_cache
def add_surrogate(
    self,
    name: str,
    surrogate: AbstractSurrogate,
    args: list[str] | None = None,
    stoichiometries: dict[str, dict[str, float]] | None = None,
) -> Self:
    """Register a surrogate model.

    Examples:
        >>> model.add_surrogate("name", surrogate)

    Args:
        name: Identifier of the surrogate.
        surrogate: The surrogate instance. It is stored as-is, and its
            ``args``/``stoichiometries`` attributes are overwritten when the
            corresponding argument is given.
        args: Argument names for the surrogate; None keeps the surrogate's own.
        stoichiometries: Mapping of reaction name to its stoichiometry; None
            keeps the surrogate's own.

    Returns:
        Self: The model itself, so calls can be chained.

    """
    self._insert_id(name=name, ctx="surrogate")

    overrides = {"args": args, "stoichiometries": stoichiometries}
    for attr, value in overrides.items():
        if value is not None:
            setattr(surrogate, attr, value)

    self._surrogates[name] = surrogate
    return self
1260
+
1261
@_invalidate_cache
def update_surrogate(
    self,
    name: str,
    surrogate: AbstractSurrogate | None = None,
    args: list[str] | None = None,
    stoichiometries: dict[str, dict[str, float]] | None = None,
) -> Self:
    """Update a surrogate model in the model.

    Decorated with ``_invalidate_cache`` like ``add_surrogate``. Swapping a
    surrogate or changing its args/stoichiometries changes what the model
    evaluates, so a cache built before the update would be stale.

    Examples:
        >>> model.update_surrogate("name", surrogate)

    Args:
        name: The name of the surrogate model to update.
        surrogate: The new surrogate instance. If None, the existing one is kept.
        args: New argument names for the surrogate. If None, they are left unchanged.
        stoichiometries: A dictionary mapping reaction names to stoichiometries.
            If None, they are left unchanged.

    Returns:
        Self: The instance of the model with the updated surrogate model.

    Raises:
        KeyError: If no surrogate called ``name`` exists.

    """
    if name not in self._surrogates:
        msg = f"Surrogate '{name}' not found in model"
        raise KeyError(msg)

    if surrogate is None:
        surrogate = self._surrogates[name]
    if args is not None:
        surrogate.args = args
    if stoichiometries is not None:
        surrogate.stoichiometries = stoichiometries

    self._surrogates[name] = surrogate
    return self
1296
+
1297
@_invalidate_cache
def remove_surrogate(self, name: str) -> Self:
    """Remove a surrogate model from the model.

    Decorated with ``_invalidate_cache``. The cached evaluation order iterated
    by ``_get_dependent`` includes surrogate names, so keeping a cache from
    before the removal would make it look up the deleted surrogate.

    Examples:
        >>> model.remove_surrogate("name")

    Args:
        name: The name of the surrogate model to remove.

    Returns:
        Self: The instance of the model with the specified surrogate model removed.

    Raises:
        KeyError: If ``name`` is not a stored surrogate.

    """
    self._remove_id(name=name)
    self._surrogates.pop(name)
    return self
1310
+
1311
def get_surrogate_reaction_names(self) -> list[str]:
    """Return the names of all reactions provided by surrogates.

    Names are taken from the keys of each surrogate's ``stoichiometries``,
    in surrogate insertion order.

    Returns:
        list[str]: A new list of reaction names.

    """
    return [
        rxn_name
        for surrogate in self._surrogates.values()
        for rxn_name in surrogate.stoichiometries
    ]
1317
+
1318
+ ##########################################################################
1319
+ # Get dependent values. This includes
1320
+ # - derived parameters
1321
+ # - derived variables
1322
+ # - fluxes
1323
+ # - readouts
1324
+ ##########################################################################
1325
+
1326
def _get_dependent(
    self,
    variables: dict[str, float],
    time: float = 0.0,
    *,
    cache: ModelCache,
) -> dict[str, float]:
    """Evaluate all derived quantities, reaction fluxes and surrogate outputs.

    Starts from the cached parameter values overlaid with ``variables`` and
    the key ``"time"``. Then each derived quantity, reaction and surrogate
    named in ``cache.order`` is evaluated in that order into the same
    dictionary. Readouts are not computed here.

    Examples:
        >>> model._get_dependent({"x1": 1.0, "x2": 2.0}, time=0.0, cache=cache)
        {"x1": 1.0, "x2": 2.0, "k1": 0.1, "time": 0.0, ...}

    Args:
        variables: A dictionary mapping variable names to their values.
        time: The time point, stored under the ``"time"`` key.
        cache: A ModelCache object providing parameter values and the
            evaluation order.

    Returns:
        dict[str, float]: A new dictionary with parameters, variables, time
        and all values computed from them, keyed by name.

    """
    # ``|`` builds a new dict, so the cached parameter values are not mutated.
    args: dict[str, float] = cache.all_parameter_values | variables
    args["time"] = time

    containers = self._derived | self._reactions | self._surrogates
    for name in cache.order:
        # Each container stores its result(s) in ``args`` so later entries can use them.
        containers[name].calculate_inpl(args)

    return args
1360
+
1361
def get_dependent(
    self,
    variables: dict[str, float] | None = None,
    time: float = 0.0,
    *,
    include_readouts: bool = False,
) -> pd.Series:
    """Return parameters, variables, time and every value derived from them.

    Examples:
        # Using initial conditions
        >>> model.get_dependent()
        pd.Series({"x1": 1.0, "x2": 2.0, "k1": 0.1, "time": 0.0, ...})

        # With custom concentrations
        >>> model.get_dependent({"x1": 1.0, "x2": 2.0})
        pd.Series({"x1": 1.0, "x2": 2.0, "k1": 0.1, "time": 0.0, ...})

        # With custom concentrations and time
        >>> model.get_dependent({"x1": 1.0, "x2": 2.0}, time=1.0)
        pd.Series({"x1": 1.0, "x2": 2.0, "k1": 0.1, "time": 1.0, ...})

    Args:
        variables: A dictionary mapping variable names to their values.
            If None, the initial conditions are used.
        time: The time point at which the values are computed (default 0.0).
        include_readouts: Whether to also compute readouts (default False).

    Returns:
        pd.Series: All computed values, keyed by name, with float dtype.

    """
    if (cache := self._cache) is None:
        cache = self._create_cache()

    args = self._get_dependent(
        variables=self.get_initial_conditions() if variables is None else variables,
        time=time,
        cache=cache,
    )

    if include_readouts:
        # FIXME: readouts are evaluated in insertion order; a readout that
        # depends on another readout must have been added after it.
        for ro in self._readouts.values():
            ro.calculate_inpl(args)

    return pd.Series(args, dtype=float)
1406
+
1407
def get_dependent_time_course(
    self,
    variables: pd.DataFrame,
    *,
    include_readouts: bool = False,
) -> pd.DataFrame:
    """Compute all dependent values for every row of a time course.

    Examples:
        >>> model.get_dependent_time_course(
        ...     pd.DataFrame({"x1": [1.0, 2.0], "x2": [2.0, 3.0]})
        ... )
        pd.DataFrame({
            "x1": [1.0, 2.0],
            "x2": [2.0, 3.0],
            "k1": [0.1, 0.1],
            "time": [0.0, 1.0],
            ...
        })

    Args:
        variables: Variable values, one column per variable; the index is
            used as the time column.
        include_readouts: If True, also compute one column per readout.

    Returns:
        pd.DataFrame: Variables, constant parameter columns, a ``time``
        column, and all derived, flux and surrogate columns (plus readouts
        if requested).

    """
    cache = self._cache
    if cache is None:
        cache = self._create_cache()

    param_values = cache.all_parameter_values
    n_rows, n_pars = len(variables), len(param_values)
    # Parameters are constant over time: broadcast one row to every time point.
    param_row = np.fromiter(param_values.values(), dtype=float)
    pars_df = pd.DataFrame(
        np.full((n_rows, n_pars), param_row),
        index=variables.index,
        columns=list(param_values),
    )

    args = pd.concat((variables, pars_df), axis=1)
    args["time"] = args.index

    containers = self._derived | self._reactions | self._surrogates
    for name in cache.order:
        containers[name].calculate_inpl_time_course(args)

    if include_readouts:
        for ro_name, readout in self._readouts.items():
            columns = args.loc[:, readout.args].to_numpy().T
            args[ro_name] = readout.fn(*columns)
    return args
1458
+
1459
+ ##########################################################################
1460
+ # Get args
1461
+ ##########################################################################
1462
+
1463
def get_args(
    self,
    variables: dict[str, float] | None = None,
    time: float = 0.0,
    *,
    include_derived: bool = True,
    include_readouts: bool = False,
) -> pd.Series:
    """Return variable values, optionally with derived variables and readouts.

    Parameters and time are computed internally but are not part of the result.

    Examples:
        # Using initial conditions
        >>> model.get_args()
        pd.Series({"x1": 1.0, "x2": 2.0, "d1": 3.0})

        # With custom concentrations
        >>> model.get_args({"x1": 1.0, "x2": 2.0})
        pd.Series({"x1": 1.0, "x2": 2.0, "d1": 3.0})

        # Without derived variables
        >>> model.get_args({"x1": 1.0, "x2": 2.0}, include_derived=False)
        pd.Series({"x1": 1.0, "x2": 2.0})

    Args:
        variables: A dictionary mapping variable names to their values.
            If None, the initial conditions are used.
        time: The time point at which the values are computed.
        include_derived: Whether to include derived variables.
        include_readouts: Whether to include readouts.

    Returns:
        pd.Series: The selected values with float dtype.

    """
    # Copy so the list returned by get_variable_names is never extended in place.
    names = list(self.get_variable_names())
    if include_derived:
        names.extend(self.get_derived_variable_names())
    if include_readouts:
        names.extend(self._readouts)

    args = self.get_dependent(
        variables=variables, time=time, include_readouts=include_readouts
    )
    return args.loc[names]
1506
+
1507
def get_args_time_course(
    self,
    variables: pd.DataFrame,
    *,
    include_derived: bool = True,
    include_readouts: bool = False,
) -> pd.DataFrame:
    """Return a time course of variables, optionally with derived variables and readouts.

    Parameters and time are computed internally but are not part of the result.

    Examples:
        >>> model.get_args_time_course(
        ...     pd.DataFrame({"x1": [1.0, 2.0], "x2": [2.0, 3.0]})
        ... )
        pd.DataFrame({
            "x1": [1.0, 2.0],
            "x2": [2.0, 3.0],
            "d1": [3.0, 5.0],
        })

    Args:
        variables: A DataFrame containing concentration data with time as the index.
        include_derived: Whether to include derived variables.
        include_readouts: If True, include readout columns in the result.

    Returns:
        pd.DataFrame: The selected columns, indexed like ``variables``.

    """
    # Copy so the list returned by get_variable_names is never extended in place.
    names = list(self.get_variable_names())
    if include_derived:
        names.extend(self.get_derived_variable_names())
    if include_readouts:
        # Readouts are computed below; keep their columns in the selection
        # (they were previously computed and then dropped).
        names.extend(self._readouts)

    args = self.get_dependent_time_course(
        variables=variables, include_readouts=include_readouts
    )
    return args.loc[:, names]
1545
+
1546
+ ##########################################################################
1547
+ # Get fluxes
1548
+ ##########################################################################
1549
+
1550
def _get_fluxes(self, args: dict[str, float]) -> dict[str, float]:
    """Evaluate every reaction rate and surrogate output for the given values.

    Examples:
        >>> model._get_fluxes({"x1": 1.0, "x2": 2.0, "k1": 0.1, "time": 0.0})
        {"r1": 0.1, "r2": 0.2}

    Args:
        args: Mapping of component names to values. It must contain every
            name listed in the reactions' and surrogates' ``args``.

    Returns:
        dict[str, float]: Flux values keyed by reaction name. Surrogate
        predictions are merged in last and overwrite equal keys.

    """
    fluxes: dict[str, float] = {
        rxn_name: cast(float, rxn.fn(*(args[a] for a in rxn.args)))
        for rxn_name, rxn in self._reactions.items()
    }
    for surrogate in self._surrogates.values():
        inputs = np.array([args[a] for a in surrogate.args])
        fluxes |= surrogate.predict(inputs)
    return fluxes
1571
+
1572
def get_fluxes(
    self,
    variables: dict[str, float] | None = None,
    time: float = 0.0,
) -> pd.Series:
    """Return the flux of every reaction and surrogate-provided reaction.

    Examples:
        # Using initial conditions as default
        >>> model.get_fluxes()
        pd.Series({"r1": 0.1, "r2": 0.2})

        # Using custom concentrations
        >>> model.get_fluxes({"x1": 1.0, "x2": 2.0})
        pd.Series({"r1": 0.1, "r2": 0.2})

        # Using custom concentrations and time
        >>> model.get_fluxes({"x1": 1.0, "x2": 2.0}, time=0.0)
        pd.Series({"r1": 0.1, "r2": 0.2})

    Args:
        variables: Mapping of variable names to values; None uses the
            initial conditions.
        time: The time at which to evaluate the fluxes. Defaults to 0.0.

    Returns:
        pd.Series: Fluxes of the model reactions followed by those of the
        surrogates.

    """
    flux_names = [
        *self.get_reaction_names(),
        *(
            rxn_name
            for surrogate in self._surrogates.values()
            for rxn_name in surrogate.stoichiometries
        ),
    ]
    dependent = self.get_dependent(
        variables=variables,
        time=time,
        include_readouts=False,
    )
    return dependent.loc[flux_names]
1610
+
1611
def get_fluxes_time_course(self, variables: pd.DataFrame) -> pd.DataFrame:
    """Return a time course of reaction and surrogate fluxes.

    Examples:
        >>> model.get_fluxes_time_course(variables)
        pd.DataFrame({"v1": [0.1, 0.2], "v2": [0.2, 0.3]})

    Args:
        variables: Variable values, one column per variable and one row per
            time point (the index is used as time).

    Returns:
        pd.DataFrame: One column per reaction, followed by the
        surrogate-provided reactions, indexed like ``variables``.

    """
    flux_names = [
        *self.get_reaction_names(),
        *(
            rxn_name
            for surrogate in self._surrogates.values()
            for rxn_name in surrogate.stoichiometries
        ),
    ]
    dependent = self.get_dependent_time_course(
        variables=variables,
        include_readouts=False,
    )
    return dependent.loc[:, flux_names]
1642
+
1643
+ ##########################################################################
1644
+ # Get rhs
1645
+ ##########################################################################
1646
+
1647
+ def __call__(self, /, time: float, variables: Array) -> Array:
1648
+ """Simulation version of get_right_hand_side.
1649
+
1650
+ Examples:
1651
+ >>> model(0.0, np.array([1.0, 2.0]))
1652
+ np.array([0.1, 0.2])
1653
+
1654
+ Warning: Swaps t and y!
1655
+ This can't get kw-only args, as the integrators call it with pos-only
1656
+
1657
+ Args:
1658
+ time: The current time point.
1659
+ variables: Array of concentrations
1660
+
1661
+
1662
+ Returns:
1663
+ The rate of change of each variable in the model.
1664
+
1665
+ """
1666
+ if (cache := self._cache) is None:
1667
+ cache = self._create_cache()
1668
+ vars_d: dict[str, float] = dict(
1669
+ zip(
1670
+ cache.var_names,
1671
+ variables,
1672
+ strict=True,
1673
+ )
1674
+ )
1675
+ dependent: dict[str, float] = self._get_dependent(
1676
+ variables=vars_d,
1677
+ time=time,
1678
+ cache=cache,
1679
+ )
1680
+
1681
+ dxdt = cache.dxdt
1682
+ dxdt[:] = 0
1683
+ for k, stoc in cache.stoich_by_cpds.items():
1684
+ for flux, n in stoc.items():
1685
+ dxdt[k] += n * dependent[flux]
1686
+ for k, sd in cache.dyn_stoich_by_cpds.items():
1687
+ for flux, dv in sd.items():
1688
+ n = dv.calculate(dependent)
1689
+ dxdt[k] += n * dependent[flux]
1690
+ return cast(Array, dxdt.to_numpy())
1691
+
1692
def get_right_hand_side(
    self,
    variables: dict[str, float] | None = None,
    time: float = 0.0,
) -> pd.Series:
    """Return the time derivative of every variable.

    Examples:
        # Using initial conditions as default
        >>> model.get_right_hand_side()
        pd.Series({"x1": 0.1, "x2": 0.2})

        # Using custom concentrations
        >>> model.get_right_hand_side({"x1": 1.0, "x2": 2.0})
        pd.Series({"x1": 0.1, "x2": 0.2})

        # Using custom concentrations and time
        >>> model.get_right_hand_side({"x1": 1.0, "x2": 2.0}, time=0.0)
        pd.Series({"x1": 0.1, "x2": 0.2})

    Args:
        variables: Mapping of variable names to values; None uses the
            initial conditions.
        time: The current time point. Defaults to 0.0.

    Returns:
        pd.Series: The sum of stoichiometry times flux for each variable,
        indexed by variable name.

    """
    cache = self._cache
    if cache is None:
        cache = self._create_cache()
    var_names = self.get_variable_names()
    if variables is None:
        variables = self.get_initial_conditions()
    dependent = self._get_dependent(variables=variables, time=time, cache=cache)

    rhs = pd.Series(np.zeros(len(var_names), dtype=float), index=var_names)
    # Constant stoichiometries first, then those evaluated from current values.
    for cpd, by_flux in cache.stoich_by_cpds.items():
        for flux_name, factor in by_flux.items():
            rhs[cpd] += factor * dependent[flux_name]

    for cpd, by_flux in cache.dyn_stoich_by_cpds.items():
        for flux_name, derived in by_flux.items():
            factor = derived.fn(*(dependent[arg] for arg in derived.args))
            rhs[cpd] += factor * dependent[flux_name]
    return rhs