modelbase2 0.1.78__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (58)
  1. modelbase2/__init__.py +138 -26
  2. modelbase2/distributions.py +306 -0
  3. modelbase2/experimental/__init__.py +17 -0
  4. modelbase2/experimental/codegen.py +239 -0
  5. modelbase2/experimental/diff.py +227 -0
  6. modelbase2/experimental/notes.md +4 -0
  7. modelbase2/experimental/tex.py +521 -0
  8. modelbase2/fit.py +284 -0
  9. modelbase2/fns.py +185 -0
  10. modelbase2/integrators/__init__.py +19 -0
  11. modelbase2/integrators/int_assimulo.py +146 -0
  12. modelbase2/integrators/int_scipy.py +147 -0
  13. modelbase2/label_map.py +610 -0
  14. modelbase2/linear_label_map.py +301 -0
  15. modelbase2/mc.py +548 -0
  16. modelbase2/mca.py +280 -0
  17. modelbase2/model.py +1621 -0
  18. modelbase2/npe.py +343 -0
  19. modelbase2/parallel.py +171 -0
  20. modelbase2/parameterise.py +28 -0
  21. modelbase2/paths.py +36 -0
  22. modelbase2/plot.py +829 -0
  23. modelbase2/sbml/__init__.py +14 -0
  24. modelbase2/sbml/_data.py +77 -0
  25. modelbase2/sbml/_export.py +656 -0
  26. modelbase2/sbml/_import.py +585 -0
  27. modelbase2/sbml/_mathml.py +691 -0
  28. modelbase2/sbml/_name_conversion.py +52 -0
  29. modelbase2/sbml/_unit_conversion.py +74 -0
  30. modelbase2/scan.py +616 -0
  31. modelbase2/scope.py +96 -0
  32. modelbase2/simulator.py +635 -0
  33. modelbase2/surrogates/__init__.py +32 -0
  34. modelbase2/surrogates/_poly.py +66 -0
  35. modelbase2/surrogates/_torch.py +249 -0
  36. modelbase2/surrogates.py +316 -0
  37. modelbase2/types.py +352 -11
  38. modelbase2-0.2.0.dist-info/METADATA +81 -0
  39. modelbase2-0.2.0.dist-info/RECORD +42 -0
  40. {modelbase2-0.1.78.dist-info → modelbase2-0.2.0.dist-info}/WHEEL +1 -1
  41. modelbase2/core/__init__.py +0 -29
  42. modelbase2/core/algebraic_module_container.py +0 -130
  43. modelbase2/core/constant_container.py +0 -113
  44. modelbase2/core/data.py +0 -109
  45. modelbase2/core/name_container.py +0 -29
  46. modelbase2/core/reaction_container.py +0 -115
  47. modelbase2/core/utils.py +0 -28
  48. modelbase2/core/variable_container.py +0 -24
  49. modelbase2/ode/__init__.py +0 -13
  50. modelbase2/ode/integrator.py +0 -80
  51. modelbase2/ode/mca.py +0 -270
  52. modelbase2/ode/model.py +0 -470
  53. modelbase2/ode/simulator.py +0 -153
  54. modelbase2/utils/__init__.py +0 -0
  55. modelbase2/utils/plotting.py +0 -372
  56. modelbase2-0.1.78.dist-info/METADATA +0 -44
  57. modelbase2-0.1.78.dist-info/RECORD +0 -22
  58. {modelbase2-0.1.78.dist-info → modelbase2-0.2.0.dist-info/licenses}/LICENSE +0 -0
modelbase2/model.py ADDED
@@ -0,0 +1,1621 @@
1
+ """Model for Metabolic System Representation.
2
+
3
+ This module provides the core Model class and supporting functionality for representing
4
+ metabolic models, including reactions, variables, parameters and derived quantities.
5
+
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import copy
11
+ import inspect
12
+ import itertools as it
13
+ import math
14
+ from dataclasses import dataclass, field
15
+ from typing import TYPE_CHECKING, Self, cast
16
+
17
+ import numpy as np
18
+ import pandas as pd
19
+
20
+ from modelbase2 import fns
21
+ from modelbase2.types import (
22
+ Array,
23
+ Derived,
24
+ Float,
25
+ Reaction,
26
+ Readout,
27
+ )
28
+
29
+ __all__ = ["ArityMismatchError", "Model", "ModelCache", "SortError"]
30
+
31
+ if TYPE_CHECKING:
32
+ from collections.abc import Iterable, Mapping
33
+ from inspect import FullArgSpec
34
+
35
+ from modelbase2.types import AbstractSurrogate, Callable, Param, RateFn, RetType
36
+
37
+
38
class SortError(Exception):
    """Signal that model elements could not be put into dependency order.

    Usually caused by circular references between derived quantities.
    """

    def __init__(self, unsorted: list[str], order: list[str]) -> None:
        """Build the error message from the leftover and already-sorted names.

        Args:
            unsorted: Names that could not be placed.
            order: Names that were successfully ordered before giving up.

        """
        parts = (
            "Exceeded max iterations on sorting derived. ",
            "Check if there are circular references.\n",
            f"Unsorted: {unsorted}\n",
            f"Order: {order}",
        )
        super().__init__("".join(parts))
53
+
54
+
55
def _get_all_args(argspec: FullArgSpec) -> list[str]:
    """Return the positional parameter names followed by the keyword-only ones."""
    keyword_only = argspec.kwonlyargs if argspec.kwonlyargs is not None else []
    return [*argspec.args, *keyword_only]
58
+
59
+
60
def _check_function_arity(function: Callable, arity: int) -> bool:
    """Check whether ``function`` can be called positionally with ``arity`` arguments.

    Args:
        function: The callable to inspect.
        arity: Number of model arguments that will be passed to it.

    Returns:
        ``True`` if ``function`` takes ``*args``, or if ``arity`` lies between
        the number of required positional parameters and the total number of
        positional parameters. Also ``True`` (retained from the original check)
        when the function has positional defaults and ``arity`` equals the
        number of positional plus keyword-only parameters.

    """
    argspec = inspect.getfullargspec(function)
    # Give up on *args functions: they accept any number of positional args.
    if argspec.varargs is not None:
        return True

    n_positional = len(argspec.args)
    defaults = argspec.defaults
    n_defaults = 0 if defaults is None else len(defaults)

    # Parameters that have default values may be omitted, so every arity from
    # the number of required parameters up to all positional ones is valid.
    if n_positional - n_defaults <= arity <= n_positional:
        return True

    # Might be kwonly as well (original fallback, kept unchanged).
    kwonly = argspec.kwonlyargs
    return bool(defaults is not None and n_positional + len(kwonly) == arity)
78
+
79
+
80
class ArityMismatchError(Exception):
    """Raised when a function's parameters do not match its model arguments."""

    def __init__(self, name: str, fn: Callable, args: list[str]) -> None:
        """Format a side-by-side table of function vs. model arguments.

        Args:
            name: Name of the model element the function belongs to.
            fn: The offending function.
            args: The model argument names registered for ``fn``.

        """
        fn_params = inspect.getfullargspec(fn).args
        rows = [("Fn args", "Model args"), ("-------", "----------")]
        rows.extend(it.zip_longest(fn_params, args, fillvalue="---"))
        # Columns are padded and truncated to 8 and 10 characters.
        table = "\n".join(f"{left:<8.8} | {right:<10.10}" for left, right in rows)
        super().__init__(f"Function arity mismatch for {name}.\n" + table)
99
+
100
+
101
def _invalidate_cache(method: Callable[Param, RetType]) -> Callable[Param, RetType]:
    """Decorator that invalidates the model cache when the decorated method is called.

    Applied to ``Model`` mutators (e.g. ``add_parameter``) so that the cache is
    rebuilt from the updated state the next time it is needed.

    Args:
        method: Method to wrap. Its first positional argument must be the
            model instance.

    Returns:
        Wrapped method that sets ``self._cache = None`` and then calls
        ``method``. The wrapper keeps ``method``'s name and docstring.

    """
    from functools import wraps

    @wraps(method)
    def wrapper(
        *args: Param.args,
        **kwargs: Param.kwargs,
    ) -> RetType:
        # String form of ``cast`` avoids a runtime lookup of ``Model``.
        self = cast("Model", args[0])
        self._cache = None
        return method(*args, **kwargs)

    return wrapper  # type: ignore
121
+
122
+
123
def _sort_dependencies(
    available: set[str], elements: list[tuple[str, set[str]]]
) -> list[str]:
    """Sort model elements topologically based on their dependencies.

    Elements are processed from a FIFO queue. An element whose dependencies
    are all in ``available`` is appended to the result and becomes available
    itself; otherwise it is moved to the back of the queue.

    Note:
        ``available`` is updated in place with every sorted name. If exactly
        one element remains and its dependencies cannot be satisfied, it is
        appended to the result as-is.

    Args:
        available: Set of available component names (mutated).
        elements: List of (name, dependencies) tuples to sort.

    Returns:
        List of element names in dependency order.

    Raises:
        SortError: If sorting does not finish within ``len(elements) ** 2``
            iterations, e.g. because of circular dependencies.

    """
    from collections import deque

    order: list[str] = []
    # Each full pass over the queue places at least one element when sorting is
    # possible, so n + (n - 1) + ... + 1 <= n**2 iterations always suffice.
    max_iterations = len(elements) ** 2
    pending: deque[tuple[str, set[str]]] = deque(
        (name, deps) for name, deps in elements
    )

    last_name: str | None = None
    i = 0
    while pending:
        new, args = pending.popleft()
        if args.issubset(available):
            available.add(new)
            order.append(new)
        else:
            if last_name == new:
                # Only this element is left and it cannot be resolved.
                order.append(new)
                break
            pending.append((new, args))
            last_name = new
        i += 1

        # Failure case. This must be checked inside the loop: two or more
        # mutually dependent elements would otherwise cycle forever.
        if i > max_iterations:
            raise SortError(unsorted=[name for name, _ in pending], order=order)
    return order
176
+
177
+
178
+ @dataclass(slots=True)
179
+ class ModelCache:
180
+ """ModelCache is a class that stores various model-related data structures.
181
+
182
+ Attributes:
183
+ var_names: A list of variable names.
184
+ parameter_values: A dictionary mapping parameter names to their values.
185
+ derived_parameters: A dictionary mapping parameter names to their derived parameter objects.
186
+ derived_variables: A dictionary mapping variable names to their derived variable objects.
187
+ stoich_by_cpds: A dictionary mapping compound names to their stoichiometric coefficients.
188
+ dyn_stoich_by_cpds: A dictionary mapping compound names to their dynamic stoichiometric coefficients.
189
+ dxdt: A pandas Series representing the rate of change of variables.
190
+
191
+ """
192
+
193
+ var_names: list[str]
194
+ all_parameter_values: dict[str, float]
195
+ derived_parameter_names: list[str]
196
+ derived_variable_names: list[str]
197
+ stoich_by_cpds: dict[str, dict[str, float]]
198
+ dyn_stoich_by_cpds: dict[str, dict[str, Derived]]
199
+ dxdt: pd.Series
200
+
201
+
202
+ @dataclass(slots=True)
203
+ class Model:
204
+ """Represents a metabolic model.
205
+
206
+ Attributes:
207
+ _ids: Dictionary mapping internal IDs to names.
208
+ _variables: Dictionary of model variables and their initial values.
209
+ _parameters: Dictionary of model parameters and their values.
210
+ _derived: Dictionary of derived quantities.
211
+ _readouts: Dictionary of readout functions.
212
+ _reactions: Dictionary of reactions in the model.
213
+ _surrogates: Dictionary of surrogate models.
214
+ _cache: Cache for storing model-related data structures.
215
+
216
+ """
217
+
218
+ _ids: dict[str, str] = field(default_factory=dict)
219
+ _variables: dict[str, float] = field(default_factory=dict)
220
+ _parameters: dict[str, float] = field(default_factory=dict)
221
+ _derived: dict[str, Derived] = field(default_factory=dict)
222
+ _readouts: dict[str, Readout] = field(default_factory=dict)
223
+ _reactions: dict[str, Reaction] = field(default_factory=dict)
224
+ _surrogates: dict[str, AbstractSurrogate] = field(default_factory=dict)
225
+ _cache: ModelCache | None = None
226
+
227
+ ###########################################################################
228
+ # Cache
229
+ ###########################################################################
230
+
231
def _create_cache(self) -> ModelCache:
    """Build, store and return a fresh ``ModelCache`` for the current model state.

    Steps:
        1. Verify that every derived, readout and reaction function accepts
           the number of arguments registered for it.
        2. Order the derived quantities topologically, starting from the
           parameters, the variables and ``time``.
        3. Derived quantities whose arguments are all (derived) parameters are
           evaluated once and become derived parameters; all others are
           derived variables.
        4. Collect stoichiometries per compound. ``Derived`` factors that depend
           only on parameters are evaluated now, the rest are kept as dynamic.
           Surrogate stoichiometries are copied as-is.

    Returns:
        ModelCache: The new cache, which is also stored in ``self._cache``.

    Raises:
        ArityMismatchError: If a function does not fit its registered args.
        SortError: If the derived quantities cannot be ordered.

    """
    param_values: dict[str, float] = self._parameters.copy()
    param_names: set[str] = set(param_values)

    # 1. Sanity checks on function arities
    checked = it.chain(
        self._derived.items(),
        self._readouts.items(),
        self._reactions.items(),
    )
    for element_name, element in checked:
        if not _check_function_arity(element.fn, len(element.args)):
            raise ArityMismatchError(element_name, element.fn, element.args)

    # 2. Dependency order of derived quantities
    derived_order = _sort_dependencies(
        available={*self._parameters, *self._variables, "time"},
        elements=[(der_name, set(der.args)) for der_name, der in self._derived.items()],
    )

    # 3. Split derived into parameters and variables
    derived_parameter_names: list[str] = []
    derived_variable_names: list[str] = []
    for der_name in derived_order:
        der = self._derived[der_name]
        if not all(arg in param_names for arg in der.args):
            derived_variable_names.append(der_name)
            continue
        param_names.add(der_name)
        derived_parameter_names.append(der_name)
        param_values[der_name] = float(der.fn(*(param_values[arg] for arg in der.args)))

    # 4. Stoichiometries, grouped by compound
    stoich_by_cpds: dict[str, dict[str, float]] = {}
    dyn_stoich_by_cpds: dict[str, dict[str, Derived]] = {}
    for rxn_name, rxn in self._reactions.items():
        for cpd_name, factor in rxn.stoichiometry.items():
            static_row = stoich_by_cpds.setdefault(cpd_name, {})
            if not isinstance(factor, Derived):
                static_row[rxn_name] = factor
            elif all(arg in param_names for arg in factor.args):
                static_row[rxn_name] = float(
                    factor.fn(*(param_values[arg] for arg in factor.args))
                )
            else:
                dyn_stoich_by_cpds.setdefault(cpd_name, {})[rxn_name] = factor

    for surrogate in self._surrogates.values():
        for rxn_name, rxn_stoich in surrogate.stoichiometries.items():
            for cpd_name, factor in rxn_stoich.items():
                stoich_by_cpds.setdefault(cpd_name, {})[rxn_name] = factor

    var_names = self.get_variable_names()
    self._cache = ModelCache(
        var_names=var_names,
        all_parameter_values=param_values,
        derived_parameter_names=derived_parameter_names,
        derived_variable_names=derived_variable_names,
        stoich_by_cpds=stoich_by_cpds,
        dyn_stoich_by_cpds=dyn_stoich_by_cpds,
        dxdt=pd.Series(np.zeros(len(var_names), dtype=float), index=var_names),
    )
    return self._cache
312
+
313
+ ###########################################################################
314
+ # Ids
315
+ ###########################################################################
316
+
317
@property
def ids(self) -> dict[str, str]:
    """Registered element names and the kind each one was registered as.

    Returns:
        dict[str, str]: A shallow copy of the id registry, mapping each name to
        the ``ctx`` string it was inserted with (e.g. ``"parameter"``).

    """
    registry = self._ids
    return dict(registry)
328
+
329
def _insert_id(self, *, name: str, ctx: str) -> None:
    """Register ``name`` in the id registry under the kind ``ctx``.

    Args:
        name: Identifier to register.
        ctx: Kind of element the identifier refers to (also used in errors).

    Raises:
        KeyError: If ``name`` is ``"time"``, which is reserved.
        NameError: If ``name`` is already registered.

    """
    if name == "time":
        msg = "time is a protected variable for time"
        raise KeyError(msg)
    if name not in self._ids:
        self._ids[name] = ctx
        return
    msg = f"Model already contains {ctx} called '{name}'"
    raise NameError(msg)
349
+
350
def _remove_id(self, *, name: str) -> None:
    """Unregister ``name`` from the id registry.

    Args:
        name: Identifier to remove.

    Raises:
        KeyError: If ``name`` is not registered.

    """
    self._ids.pop(name)
361
+
362
+ ##########################################################################
363
+ # Parameters
364
+ ##########################################################################
365
+
366
@_invalidate_cache
def add_parameter(self, name: str, value: float) -> Self:
    """Register a new parameter.

    Examples:
        >>> model.add_parameter("k1", 0.1)

    Args:
        name: Name of the new parameter; must not be used yet.
        value: Value of the parameter.

    Returns:
        Self: The model itself, for method chaining.

    Raises:
        KeyError: If ``name`` is ``"time"``.
        NameError: If ``name`` is already used in the model.

    """
    self._insert_id(name=name, ctx="parameter")
    self._parameters.update({name: value})
    return self
384
+
385
def add_parameters(self, parameters: dict[str, float]) -> Self:
    """Register several parameters by calling ``add_parameter`` for each entry.

    Examples:
        >>> model.add_parameters({"k1": 0.1, "k2": 0.2})

    Args:
        parameters: Mapping of parameter name to value.

    Returns:
        Self: The model itself, for method chaining.

    """
    for param_name, param_value in parameters.items():
        self.add_parameter(param_name, param_value)
    return self
402
+
403
@property
def parameters(self) -> dict[str, float]:
    """Parameter values of the model (not including derived parameters).

    Examples:
        >>> model.parameters
        {"k1": 0.1, "k2": 0.2}

    Returns:
        dict[str, float]: A shallow copy mapping parameter names to values.

    """
    return dict(self._parameters)
417
+
418
def get_parameter_names(self) -> list[str]:
    """Names of the model parameters, in insertion order.

    Examples:
        >>> model.get_parameter_names()
        ['k1', 'k2']

    Returns:
        list[str]: A new list with the parameter names.

    """
    return [*self._parameters]
430
+
431
@_invalidate_cache
def remove_parameter(self, name: str) -> Self:
    """Remove a parameter from the model.

    Examples:
        >>> model.remove_parameter("k1")

    Args:
        name: The name of the parameter to remove.

    Returns:
        Self: The instance of the model with the parameter removed.

    Raises:
        KeyError: If ``name`` is not a parameter. The model is left unchanged.

    """
    # Validate before touching the id registry. Otherwise passing e.g. a
    # variable name would unregister that variable's id and only then fail.
    if name not in self._parameters:
        msg = f"'{name}' not found in parameters"
        raise KeyError(msg)
    self._remove_id(name=name)
    del self._parameters[name]
    return self
448
+
449
def remove_parameters(self, names: list[str]) -> Self:
    """Remove several parameters by calling ``remove_parameter`` for each name.

    Examples:
        >>> model.remove_parameters(["k1", "k2"])

    Args:
        names: Parameter names to remove.

    Returns:
        Self: The model itself, for method chaining.

    """
    for param_name in names:
        self.remove_parameter(param_name)
    return self
465
+
466
@_invalidate_cache
def update_parameter(self, name: str, value: float) -> Self:
    """Update the value of an existing parameter.

    Examples:
        >>> model.update_parameter("k1", 0.2)

    Args:
        name: The name of the parameter to update.
        value: The new value for the parameter.

    Returns:
        Self: The instance of the class with the updated parameter.

    Raises:
        KeyError: If the parameter name is not found in the parameters.

    """
    if name not in self._parameters:
        msg = f"'{name}' not found in parameters"
        raise KeyError(msg)
    self._parameters[name] = value
    return self
489
+
490
def update_parameters(self, parameters: dict[str, float]) -> Self:
    """Update several parameters by calling ``update_parameter`` for each entry.

    Examples:
        >>> model.update_parameters({"k1": 0.2, "k2": 0.3})

    Args:
        parameters: Mapping of existing parameter names to new values.

    Returns:
        Self: The model itself, for method chaining.

    """
    for param_name, new_value in parameters.items():
        self.update_parameter(param_name, new_value)
    return self
506
+
507
def scale_parameter(self, name: str, factor: float) -> Self:
    """Multiply the value of a parameter by ``factor``.

    Examples:
        >>> model.scale_parameter("k1", 2.0)

    Args:
        name: Name of the parameter to scale.
        factor: Multiplicative factor.

    Returns:
        Self: The model itself, for method chaining.

    Raises:
        KeyError: If ``name`` is not a parameter.

    """
    scaled = self._parameters[name] * factor
    return self.update_parameter(name, scaled)
522
+
523
def scale_parameters(self, parameters: dict[str, float]) -> Self:
    """Scale several parameters by calling ``scale_parameter`` for each entry.

    Examples:
        >>> model.scale_parameters({"k1": 2.0, "k2": 0.5})

    Args:
        parameters: Mapping of parameter names to scaling factors.

    Returns:
        Self: The model itself, for method chaining.

    """
    for param_name, factor in parameters.items():
        self.scale_parameter(param_name, factor)
    return self
540
+
541
@_invalidate_cache
def make_parameter_dynamic(
    self,
    name: str,
    initial_value: float | None = None,
    stoichiometries: dict[str, float] | None = None,
) -> Self:
    """Converts a parameter to a dynamic variable in the model.

    The parameter is removed and re-added as a variable. Optionally, the new
    variable is inserted into the stoichiometries of existing reactions or
    surrogate reactions.

    Examples:
        >>> model.make_parameter_dynamic("k1")
        >>> model.make_parameter_dynamic("k2", initial_value=0.5)

    Args:
        name: The name of the parameter to be converted.
        initial_value: The initial value for the new variable. If None, the
            current value of the parameter is used. Defaults to None.
        stoichiometries: A dictionary mapping reaction names to stoichiometric
            factors for the new variable. Defaults to None.

    Returns:
        Self: The instance of the model with the parameter converted to a variable.

    Raises:
        KeyError: If ``name`` is not a parameter, or if a reaction named in
            ``stoichiometries`` is neither a reaction nor a surrogate reaction.
            Unknown reaction names are detected before the model is modified.

    """
    value = self._parameters[name] if initial_value is None else initial_value
    targets = {} if stoichiometries is None else stoichiometries

    # Validate all reaction names first, so that a typo cannot leave the model
    # half-converted (parameter already turned into a variable).
    for rxn_name in targets:
        in_surrogate = any(
            rxn_name in surrogate.stoichiometries
            for surrogate in self._surrogates.values()
        )
        if rxn_name not in self._reactions and not in_surrogate:
            msg = f"Reaction '{rxn_name}' not found in reactions or surrogates"
            raise KeyError(msg)

    self.remove_parameter(name)
    self.add_variable(name, value)

    for rxn_name, factor in targets.items():
        if rxn_name in self._reactions:
            # Look up the target reaction by its own name.
            cast(dict, self._reactions[rxn_name].stoichiometry)[name] = factor
            continue
        for surrogate in self._surrogates.values():
            if rxn_name in surrogate.stoichiometries:
                surrogate.stoichiometries[rxn_name][name] = factor

    return self
585
+
586
+ ##########################################################################
587
+ # Variables
588
+ ##########################################################################
589
+
590
@property
def variables(self) -> dict[str, float]:
    """Variables of the model together with their initial values.

    Examples:
        >>> model.variables
        {"x1": 1.0, "x2": 2.0}

    Returns:
        dict[str, float]: A shallow copy mapping variable names to initial values.

    """
    return dict(self._variables)
606
+
607
@_invalidate_cache
def add_variable(self, name: str, initial_condition: float) -> Self:
    """Register a new variable with its initial condition.

    Examples:
        >>> model.add_variable("x1", 1.0)

    Args:
        name: Name of the new variable; must not be used yet.
        initial_condition: Initial value of the variable.

    Returns:
        Self: The model itself, for method chaining.

    Raises:
        KeyError: If ``name`` is ``"time"``.
        NameError: If ``name`` is already used in the model.

    """
    self._insert_id(name=name, ctx="variable")
    self._variables.update({name: initial_condition})
    return self
625
+
626
def add_variables(self, variables: dict[str, float]) -> Self:
    """Register several variables by calling ``add_variable`` for each entry.

    Examples:
        >>> model.add_variables({"x1": 1.0, "x2": 2.0})

    Args:
        variables: Mapping of variable names to initial conditions.

    Returns:
        Self: The model itself, for method chaining.

    """
    for var_name, start in variables.items():
        self.add_variable(name=var_name, initial_condition=start)
    return self
643
+
644
@_invalidate_cache
def remove_variable(self, name: str) -> Self:
    """Remove a variable from the model.

    Examples:
        >>> model.remove_variable("x1")

    Args:
        name: The name of the variable to remove.

    Returns:
        Self: The instance of the model with the variable removed.

    Raises:
        KeyError: If ``name`` is not a variable. The model is left unchanged.

    """
    # Validate before touching the id registry. Otherwise passing e.g. a
    # parameter name would unregister that parameter's id and only then fail.
    if name not in self._variables:
        msg = f"'{name}' not found in variables"
        raise KeyError(msg)
    self._remove_id(name=name)
    del self._variables[name]
    return self
661
+
662
def remove_variables(self, variables: Iterable[str]) -> Self:
    """Remove several variables by calling ``remove_variable`` for each name.

    Examples:
        >>> model.remove_variables(["x1", "x2"])

    Args:
        variables: Variable names to remove.

    Returns:
        Self: The model itself, for method chaining.

    """
    for var_name in variables:
        self.remove_variable(name=var_name)
    return self
678
+
679
@_invalidate_cache
def update_variable(self, name: str, initial_condition: float) -> Self:
    """Change the initial condition of an existing variable.

    Examples:
        >>> model.update_variable("x1", 2.0)

    Args:
        name: Name of the variable to update.
        initial_condition: New initial value.

    Returns:
        Self: The model itself, for method chaining.

    Raises:
        KeyError: If ``name`` is not a variable.

    """
    if name in self._variables:
        self._variables[name] = initial_condition
        return self
    msg = f"'{name}' not found in variables"
    raise KeyError(msg)
699
+
700
def get_variable_names(self) -> list[str]:
    """Names of the model variables, in insertion order.

    Examples:
        >>> model.get_variable_names()
        ["x1", "x2"]

    Returns:
        list[str]: A new list with the variable names.

    """
    return [*self._variables]
712
+
713
def get_initial_conditions(self) -> dict[str, float]:
    """Retrieve the initial conditions of the model.

    Examples:
        >>> model.get_initial_conditions()
        {"x1": 1.0, "x2": 2.0}

    Returns:
        initial_conditions: A dictionary where the keys are variable names and
        the values are their initial conditions. This is a copy, consistent
        with ``Model.variables``; mutating it does not change the model.

    """
    # Return a copy so callers cannot silently mutate internal model state.
    return self._variables.copy()
725
+
726
def make_variable_static(self, name: str, value: float | None = None) -> Self:
    """Turn a variable into a (static) parameter.

    The variable is also dropped from the stoichiometries of all reactions and
    surrogates. Calling ``Model.make_parameter_dynamic`` later does not put it
    back there.

    Examples:
        >>> model.make_variable_static("x1")
        >>> model.make_variable_static("x2", value=2.0)

    Args:
        name: Name of the variable to convert.
        value: Value of the new parameter. Defaults to the variable's current
            value when None.

    Returns:
        Self: The model itself, for method chaining.

    """
    param_value = value if value is not None else self._variables[name]
    self.remove_variable(name)
    self.add_parameter(name, param_value)

    # Drop the former variable from every stoichiometry
    for rxn in self._reactions.values():
        rxn_stoich = cast(dict, rxn.stoichiometry)
        if name in rxn_stoich:
            rxn_stoich.pop(name)
    for surrogate in self._surrogates.values():
        pruned = {}
        for rxn_name, rxn_stoich in surrogate.stoichiometries.items():
            if rxn_name == name:
                continue
            pruned[rxn_name] = {
                cpd: factor for cpd, factor in rxn_stoich.items() if cpd != name
            }
        surrogate.stoichiometries = pruned
    return self
760
+
761
+ ##########################################################################
762
+ # Derived
763
+ ##########################################################################
764
+
765
@property
def derived(self) -> dict[str, Derived]:
    """All derived quantities, both derived parameters and derived variables.

    Examples:
        >>> model.derived
        {"d1": Derived(fn1, ["x1", "x2"]),
         "d2": Derived(fn2, ["x1", "d1"])}

    Returns:
        dict[str, Derived]: A shallow copy of the name -> ``Derived`` mapping.

    """
    return dict(self._derived)
779
+
780
@property
def derived_variables(self) -> dict[str, Derived]:
    """Derived quantities that are not derived parameters, in dependency order.

    Builds the model cache if it does not exist yet.

    Examples:
        >>> model.derived_variables
        {"d1": Derived(fn1, ["x1", "x2"]),
         "d2": Derived(fn2, ["x1", "d1"])}

    Returns:
        dict[str, Derived]: Mapping of derived variable names to ``Derived``.

    """
    cache = self._cache if self._cache is not None else self._create_cache()
    return {der_name: self._derived[der_name] for der_name in cache.derived_variable_names}
799
+
800
@property
def derived_parameters(self) -> dict[str, Derived]:
    """Derived quantities that depend only on parameters, in dependency order.

    Builds the model cache if it does not exist yet.

    Examples:
        >>> model.derived_parameters
        {"kd1": Derived(fn1, ["k1", "k2"]),
         "kd2": Derived(fn2, ["k1", "kd1"])}

    Returns:
        dict[str, Derived]: Mapping of derived parameter names to ``Derived``.

    """
    cache = self._cache if self._cache is not None else self._create_cache()
    return {der_name: self._derived[der_name] for der_name in cache.derived_parameter_names}
818
+
819
@_invalidate_cache
def add_derived(
    self,
    name: str,
    fn: RateFn,
    *,
    args: list[str],
) -> Self:
    """Register a derived quantity computed as ``fn(*args)``.

    Examples:
        >>> model.add_derived("d1", add, args=["x1", "x2"])

    Args:
        name: Name of the derived quantity; must not be used yet.
        fn: Function computing the value.
        args: Names of the model elements passed positionally to ``fn``.

    Returns:
        Self: The model itself, for method chaining.

    Raises:
        KeyError: If ``name`` is ``"time"``.
        NameError: If ``name`` is already used in the model.

    """
    self._insert_id(name=name, ctx="derived")
    self._derived.update({name: Derived(fn, args)})
    return self
844
+
845
def get_derived_parameter_names(self) -> list[str]:
    """Names of the derived parameters, in dependency order.

    Examples:
        >>> model.get_derived_parameter_names()
        ["kd1", "kd2"]

    Returns:
        list[str]: A new list with the derived parameter names.

    """
    return [*self.derived_parameters]
857
+
858
def get_derived_variable_names(self) -> list[str]:
    """Names of the derived variables, in dependency order.

    Examples:
        >>> model.get_derived_variable_names()
        ["d1", "d2"]

    Returns:
        list[str]: A new list with the derived variable names.

    """
    return [*self.derived_variables]
870
+
871
@_invalidate_cache
def update_derived(
    self,
    name: str,
    fn: RateFn | None = None,
    *,
    args: list[str] | None = None,
) -> Self:
    """Replace the function and/or arguments of an existing derived quantity.

    The stored ``Derived`` object is modified in place.

    Examples:
        >>> model.update_derived("d1", add, args=["x1", "x2"])

    Args:
        name: Name of the derived quantity to update.
        fn: New function; keeps the current one if None.
        args: New argument names; keeps the current ones if None.

    Returns:
        Self: The model itself, for method chaining.

    Raises:
        KeyError: If ``name`` is not a derived quantity.

    """
    target = self._derived[name]
    if fn is not None:
        target.fn = fn
    if args is not None:
        target.args = args
    return self
897
+
898
@_invalidate_cache
def remove_derived(self, name: str) -> Self:
    """Remove a derived attribute from the model.

    Examples:
        >>> model.remove_derived("d1")

    Args:
        name: The name of the derived attribute to remove.

    Returns:
        Self: The instance of the model with the derived attribute removed.

    Raises:
        KeyError: If ``name`` is not a derived quantity. The model is left
            unchanged.

    """
    # Validate before touching the id registry, so a non-derived name does
    # not lose its registry entry when the removal fails.
    if name not in self._derived:
        msg = f"'{name}' not found in derived"
        raise KeyError(msg)
    self._remove_id(name=name)
    del self._derived[name]
    return self
915
+
916
+ ###########################################################################
917
+ # Reactions
918
+ ###########################################################################
919
+
920
@property
def reactions(self) -> dict[str, Reaction]:
    """Reactions of the model.

    Examples:
        >>> model.reactions
        {"r1": Reaction(fn1, {"x1": -1, "x2": 1}, ["k1"])}

    Returns:
        dict[str, Reaction]: A deep copy, so neither the mapping nor the
        contained reactions can be used to modify the model.

    """
    snapshot = copy.deepcopy(self._reactions)
    return snapshot
933
+
934
def get_stoichiometries(
    self, concs: dict[str, float] | None = None, time: float = 0.0
) -> pd.DataFrame:
    """Return the stoichiometric matrix of the model.

    Constant stoichiometries are taken from the cache. Dynamic (derived)
    stoichiometries are evaluated using the arguments returned by
    `get_args(concs=concs, time=time)`.

    Examples:
        >>> model.get_stoichiometries()
            v1  v2
        x1  -1   1
        x2   1  -1

    Args:
        concs: Concentrations used to evaluate dynamic stoichiometries.
            If None, `get_args` uses the model's initial conditions.
        time: Time point used to evaluate dynamic stoichiometries.

    Returns:
        pd.DataFrame: Compounds as rows, reactions as columns; entries that
        are not defined are filled with 0.

    """
    if (cache := self._cache) is None:
        cache = self._create_cache()
    args = self.get_args(concs=concs, time=time)

    # Deep copy so that inserting the evaluated dynamic values never
    # mutates the cached static stoichiometries.
    stoich_by_cpds = copy.deepcopy(cache.stoich_by_cpds)
    for cpd, stoich in cache.dyn_stoich_by_cpds.items():
        # setdefault: a compound may only appear with dynamic stoichiometries.
        cpd_stoich = stoich_by_cpds.setdefault(cpd, {})
        for rxn, derived in stoich.items():
            cpd_stoich[rxn] = float(derived.fn(*(args[i] for i in derived.args)))
    return pd.DataFrame(stoich_by_cpds).T.fillna(0)
960
+
961
@_invalidate_cache
def add_reaction(
    self,
    name: str,
    fn: RateFn,
    *,
    args: list[str],
    stoichiometry: Mapping[str, float | str | Derived],
) -> Self:
    """Register a new reaction with the model.

    Examples:
        >>> model.add_reaction("v1",
        ...     fn=mass_action,
        ...     args=["x1", "kf1"],
        ...     stoichiometry={"x1": -1, "x2": 1},
        ... )

    Args:
        name: Unique name of the reaction.
        fn: Rate function of the reaction.
        args: Names of the quantities passed to `fn`, in order.
        stoichiometry: Mapping from compound name to coefficient. A string
            coefficient is the name of another model quantity whose value
            is used as the coefficient.

    Returns:
        Self: The model instance, allowing method chaining.

    """
    self._insert_id(name=name, ctx="reaction")

    normalised: dict[str, Derived | float] = {}
    for cpd, coeff in stoichiometry.items():
        if isinstance(coeff, str):
            # Wrap a named coefficient in Derived(fns.constant, [name]) so it
            # is resolved from the model arguments when evaluated.
            normalised[cpd] = Derived(fns.constant, [coeff])
        else:
            normalised[cpd] = coeff

    self._reactions[name] = Reaction(fn=fn, stoichiometry=normalised, args=args)
    return self
997
+
998
def get_reaction_names(self) -> list[str]:
    """Return the names of all reactions, in the order they are stored.

    Examples:
        >>> model.get_reaction_names()
        ["v1", "v2"]

    Returns:
        list[str]: A new list of reaction names.

    """
    return [*self._reactions]
1010
+
1011
@_invalidate_cache
def update_reaction(
    self,
    name: str,
    fn: RateFn | None = None,
    *,
    args: list[str] | None = None,
    stoichiometry: Mapping[str, float | Derived | str] | None = None,
) -> Self:
    """Modify an existing reaction in place.

    Any argument left as `None` keeps the reaction's current value.

    Examples:
        >>> model.update_reaction("v1",
        ...     fn=mass_action,
        ...     args=["x1", "kf1"],
        ...     stoichiometry={"x1": -1, "x2": 1},
        ... )

    Args:
        name: Name of the reaction to modify.
        fn: Replacement rate function, or None to keep the current one.
        args: Replacement argument names, or None to keep the current ones.
        stoichiometry: Replacement stoichiometry, or None to keep the current
            one. String coefficients name another model quantity.

    Returns:
        Self: The model instance, allowing method chaining.

    """
    reaction = self._reactions[name]
    reaction.fn = fn if fn is not None else reaction.fn

    if stoichiometry is not None:
        reaction.stoichiometry = {
            cpd: (Derived(fns.constant, [coeff]) if isinstance(coeff, str) else coeff)
            for cpd, coeff in stoichiometry.items()
        }

    reaction.args = args if args is not None else reaction.args
    return self
1050
+
1051
@_invalidate_cache
def remove_reaction(self, name: str) -> Self:
    """Remove a reaction from the model by its name.

    Examples:
        >>> model.remove_reaction("v1")

    Args:
        name: The name of the reaction to be removed.

    Returns:
        Self: The model instance, allowing method chaining.

    Raises:
        KeyError: If `name` is not a reaction. The model is left unchanged
            in that case.

    """
    # Pop first so that an unknown name fails before the shared id registry
    # is touched; releasing the id first could drop the id of a different
    # component registered under the same name.
    self._reactions.pop(name)
    self._remove_id(name=name)
    return self
1068
+
1069
+ # def update_stoichiometry_of_cpd(
1070
+ # self,
1071
+ # rate_name: str,
1072
+ # compound: str,
1073
+ # value: float,
1074
+ # ) -> Model:
1075
+ # self.update_stoichiometry(
1076
+ # rate_name=rate_name,
1077
+ # stoichiometry=self.stoichiometries[rate_name] | {compound: value},
1078
+ # )
1079
+ # return self
1080
+
1081
+ # def scale_stoichiometry_of_cpd(
1082
+ # self,
1083
+ # rate_name: str,
1084
+ # compound: str,
1085
+ # scale: float,
1086
+ # ) -> Model:
1087
+ # return self.update_stoichiometry_of_cpd(
1088
+ # rate_name=rate_name,
1089
+ # compound=compound,
1090
+ # value=self.stoichiometries[rate_name][compound] * scale,
1091
+ # )
1092
+
1093
+ ##########################################################################
1094
+ # Readouts
1095
+ # They are like derived variables, but only calculated on demand
1096
+ # Think of something like NADPH / (NADP + NADPH) as a proxy for energy state
1097
+ ##########################################################################
1098
+
1099
def add_readout(self, name: str, fn: RateFn, *, args: list[str]) -> Self:
    """Register a readout, i.e. a quantity that is only computed on demand.

    Readouts are evaluated when arguments are requested with
    `include_readouts=True` (see `_get_args`).

    Examples:
        >>> model.add_readout("energy_state",
        ...     fn=div,
        ...     args=["NADPH", "NADP*_total"]
        ... )

    Args:
        name: Unique name of the readout.
        fn: Function computing the readout value.
        args: Names of the quantities passed to `fn`, in order.

    Returns:
        Self: The model instance, allowing method chaining.

    """
    self._insert_id(name=name, ctx="readout")
    readout = Readout(fn, args)
    self._readouts[name] = readout
    return self
1120
+
1121
def get_readout_names(self) -> list[str]:
    """Return the names of all readouts, in the order they are stored.

    Examples:
        >>> model.get_readout_names()
        ["energy_state", "redox_state"]

    Returns:
        list[str]: A new list of readout names.

    """
    return list(self._readouts.keys())
1133
+
1134
def remove_readout(self, name: str) -> Self:
    """Remove a readout by its name.

    Examples:
        >>> model.remove_readout("energy_state")

    Args:
        name: The name of the readout to remove.

    Returns:
        Self: The model instance, allowing method chaining.

    Raises:
        KeyError: If `name` is not a readout. The model is left unchanged
            in that case.

    """
    # Delete first so that an unknown name fails before the shared id
    # registry is touched; releasing the id first could drop the id of a
    # different component registered under the same name.
    del self._readouts[name]
    self._remove_id(name=name)
    return self
1150
+
1151
+ ##########################################################################
1152
+ # Surrogates
1153
+ ##########################################################################
1154
+
1155
@_invalidate_cache
def add_surrogate(
    self,
    name: str,
    surrogate: AbstractSurrogate,
    args: list[str] | None = None,
    stoichiometries: dict[str, dict[str, float]] | None = None,
) -> Self:
    """Register a surrogate model under `name`.

    Examples:
        >>> model.add_surrogate("name", surrogate)

    Args:
        name: Unique name of the surrogate.
        surrogate: The surrogate instance to register.
        args: If given, overrides the surrogate's input argument names.
        stoichiometries: If given, overrides the surrogate's mapping from
            output (reaction) names to stoichiometries.

    Returns:
        Self: The model instance, allowing method chaining.

    """
    self._insert_id(name=name, ctx="surrogate")

    # Overrides are written onto the passed object itself, so the caller's
    # surrogate instance is modified as well.
    if args is not None:
        surrogate.args = args
    if stoichiometries is not None:
        surrogate.stoichiometries = stoichiometries

    self._surrogates[name] = surrogate
    return self
1187
+
1188
# Invalidate the cache: add_surrogate does, and __call__ applies surrogate
# fluxes only through the cached stoichiometries, so changing a surrogate's
# stoichiometries must force the cache to be rebuilt.
@_invalidate_cache
def update_surrogate(
    self,
    name: str,
    surrogate: AbstractSurrogate | None = None,
    args: list[str] | None = None,
    stoichiometries: dict[str, dict[str, float]] | None = None,
) -> Self:
    """Update a surrogate model in the model.

    If `surrogate` is given, it replaces the stored one entirely. `args` and
    `stoichiometries` overrides are then written onto whichever surrogate
    object ends up stored under `name`.

    Examples:
        >>> model.update_surrogate("name", surrogate)

    Args:
        name: The name of the surrogate model to update.
        surrogate: Optional replacement surrogate instance. If None, the
            currently stored surrogate is kept.
        args: If given, new input argument names for the surrogate.
        stoichiometries: If given, a new mapping from reaction names to
            stoichiometries.

    Returns:
        Self: The model instance, allowing method chaining.

    Raises:
        KeyError: If no surrogate called `name` exists.

    """
    if name not in self._surrogates:
        msg = f"Surrogate '{name}' not found in model"
        raise KeyError(msg)

    if surrogate is None:
        surrogate = self._surrogates[name]
    if args is not None:
        surrogate.args = args
    if stoichiometries is not None:
        surrogate.stoichiometries = stoichiometries

    self._surrogates[name] = surrogate
    return self
1223
+
1224
# Invalidate the cache: add_surrogate does, and __call__ looks up every flux
# named in the cached stoichiometries, so a removed surrogate must not stay
# in the cache.
@_invalidate_cache
def remove_surrogate(self, name: str) -> Self:
    """Remove a surrogate model from the model.

    Examples:
        >>> model.remove_surrogate("name")

    Args:
        name: The name of the surrogate model to remove.

    Returns:
        Self: The model instance, allowing method chaining.

    Raises:
        KeyError: If `name` is not a surrogate. The model is left unchanged
            in that case.

    """
    # Pop first so that an unknown name fails before the shared id registry
    # is touched.
    self._surrogates.pop(name)
    self._remove_id(name=name)
    return self
1237
+
1238
+ ##########################################################################
1239
+ # Get args
1240
+ ##########################################################################
1241
+
1242
def _get_args(
    self,
    concs: dict[str, float],
    time: float = 0.0,
    *,
    include_readouts: bool,
) -> dict[str, float]:
    """Build the flat argument dictionary used to evaluate model functions.

    The result starts from all cached parameter values, then `concs` is
    merged on top (a concentration overrides a parameter of the same name),
    `time` is added under the key "time", and derived variables are
    evaluated in the order of `cache.derived_variable_names`. Each derived
    value is written into the dictionary before the next one is evaluated,
    so later ones may use earlier ones (the cache presumably provides a
    dependency-sorted order; not visible here).

    Examples:
        >>> model._get_args({"x1": 1.0, "x2": 2.0}, time=0.0, include_readouts=False)
        {"x1": 1.0, "x2": 2.0, "k1": 0.1, "time": 0.0}

    Args:
        concs: Mapping of variable names to their concentrations.
        time: The time point for the calculation.
        include_readouts: Whether to also evaluate all readouts and include
            them in the result (keyword-only, required).

    Returns:
        dict[str, float]: A new dictionary containing parameter values,
        concentrations, "time", derived variables and, optionally, readouts.

    """
    if (cache := self._cache) is None:
        cache = self._create_cache()

    # `|` builds a new dict, so the cached parameter values are not mutated.
    args: dict[str, float] = cache.all_parameter_values | concs
    args["time"] = time

    derived = self._derived
    for name in cache.derived_variable_names:
        dv = derived[name]
        args[name] = cast(float, dv.fn(*(args[arg] for arg in dv.args)))

    if include_readouts:
        for name, ro in self._readouts.items():
            args[name] = cast(float, ro.fn(*(args[arg] for arg in ro.args)))
    return args
1282
+
1283
def get_args(
    self,
    concs: dict[str, float] | None = None,
    time: float = 0.0,
    *,
    include_readouts: bool = False,
) -> pd.Series:
    """Return all model arguments as a float Series.

    This is the public, Series-returning counterpart of `_get_args`.

    Examples:
        # Using initial conditions
        >>> model.get_args()
        {"x1": 1.0, "x2": 2.0, "k1": 0.1, "time": 0.0}

        # With custom concentrations
        >>> model.get_args({"x1": 1.0, "x2": 2.0})
        {"x1": 1.0, "x2": 2.0, "k1": 0.1, "time": 0.0}

        # With custom concentrations and time
        >>> model.get_args({"x1": 1.0, "x2": 2.0}, time=1.0)
        {"x1": 1.0, "x2": 2.0, "k1": 0.1, "time": 1.0}

    Args:
        concs: Mapping of variable names to concentrations. If None, the
            model's initial conditions are used.
        time: The time point at which the arguments are generated.
        include_readouts: Whether readouts are evaluated and included.

    Returns:
        pd.Series: The arguments, indexed by name, with float dtype.

    """
    if concs is None:
        concs = self.get_initial_conditions()
    values = self._get_args(
        concs=concs,
        time=time,
        include_readouts=include_readouts,
    )
    return pd.Series(values, dtype=float)
1322
+
1323
def get_args_time_course(
    self,
    concs: pd.DataFrame,
    *,
    include_readouts: bool = False,
) -> pd.DataFrame:
    """Return model arguments for every row (time point) of `concs`.

    Parameters are repeated on every row, the index of `concs` is copied
    into a "time" column, and derived variables (and optionally readouts)
    are evaluated column-wise, one call per function.

    Examples:
        >>> model.get_args_time_course(
        ...     pd.DataFrame({"x1": [1.0, 2.0], "x2": [2.0, 3.0]}
        ... )
        pd.DataFrame({
            "x1": [1.0, 2.0],
            "x2": [2.0, 3.0],
            "k1": [0.1, 0.1],
            "time": [0.0, 1.0]},
        )

    Args:
        concs: Concentrations, one column per variable, indexed by time.
        include_readouts: Whether readouts are evaluated and added as columns.

    Returns:
        pd.DataFrame: Concentrations, parameters, "time", derived variables
        and, optionally, readouts, sharing the index of `concs`.

    """
    cache = self._cache
    if cache is None:
        cache = self._create_cache()

    par_names = list(cache.all_parameter_values)
    par_row = np.fromiter(cache.all_parameter_values.values(), dtype=float)
    # Parameters are constant over time: broadcast the same row everywhere.
    pars_df = pd.DataFrame(
        np.full((len(concs), len(par_names)), par_row),
        index=concs.index,
        columns=par_names,
    )

    args = pd.concat((concs, pars_df), axis=1)
    args["time"] = args.index

    for dv_name in cache.derived_variable_names:
        dv = self._derived[dv_name]
        # Transpose so each argument arrives as one column vector.
        columns = args.loc[:, dv.args].to_numpy().T
        args[dv_name] = dv.fn(*columns)

    if include_readouts:
        for ro_name, ro in self._readouts.items():
            columns = args.loc[:, ro.args].to_numpy().T
            args[ro_name] = ro.fn(*columns)
    return args
1375
+
1376
+ ##########################################################################
1377
+ # Get full concs
1378
+ ##########################################################################
1379
+
1380
def get_full_concs(
    self,
    concs: dict[str, float] | None = None,
    time: float = 0.0,
    *,
    include_readouts: bool = True,
) -> pd.Series:
    """Return variables, derived variables and (optionally) readouts.

    Parameters and "time" are computed by `get_args` but filtered out of
    the result.

    Examples:
        >>> model.get_full_concs({"x1": 1.0, "x2": 2.0}, time=0.0)
        pd.Series({
            "x1": 1.0,
            "x2": 2.0,
            "d1": 3.0,
            "d2": 4.0,
            "energy_state": 0.5,
        })

    Args:
        concs: Mapping of variable names to concentrations. If None, the
            model's initial conditions are used (via `get_args`).
        time: The time point at which to evaluate.
        include_readouts: Whether readouts are evaluated and included.

    Returns:
        pd.Series: Values ordered as variables, then derived variables,
        then readouts.

    """
    selection = [
        *self.get_variable_names(),
        *self.get_derived_variable_names(),
        *(self.get_readout_names() if include_readouts else []),
    ]
    all_args = self.get_args(
        concs=concs,
        time=time,
        include_readouts=include_readouts,
    )
    return all_args.loc[selection]
1419
+
1420
+ ##########################################################################
1421
+ # Get fluxes
1422
+ ##########################################################################
1423
+
1424
def _get_fluxes(self, args: dict[str, float]) -> dict[str, float]:
    """Evaluate all reaction rates and surrogate outputs for `args`.

    Examples:
        >>> model._get_fluxes({"x1": 1.0, "x2": 2.0, "k1": 0.1, "time": 0.0})
        {"r1": 0.1, "r2": 0.2}

    Args:
        args: Flat mapping of argument names to values, as produced by
            `_get_args`.

    Returns:
        dict[str, float]: Reaction fluxes, followed by every entry returned
        by each surrogate's `predict` (later entries overwrite earlier ones).

    """
    fluxes: dict[str, float] = {
        rxn_name: cast(float, rxn.fn(*(args[arg] for arg in rxn.args)))
        for rxn_name, rxn in self._reactions.items()
    }
    for surrogate in self._surrogates.values():
        inputs = np.array([args[arg] for arg in surrogate.args])
        fluxes.update(surrogate.predict(inputs))
    return fluxes
1445
+
1446
def get_fluxes(
    self,
    concs: dict[str, float] | None = None,
    time: float = 0.0,
) -> pd.Series:
    """Return the flux of every reaction and surrogate output.

    Examples:
        # Using initial conditions as default
        >>> model.get_fluxes()
        pd.Series({"r1": 0.1, "r2": 0.2})

        # Using custom concentrations
        >>> model.get_fluxes({"x1": 1.0, "x2": 2.0})
        pd.Series({"r1": 0.1, "r2": 0.2})

        # Using custom concentrations and time
        >>> model.get_fluxes({"x1": 1.0, "x2": 2.0}, time=0.0)
        pd.Series({"r1": 0.1, "r2": 0.2})

    Args:
        concs: Mapping of variable names to concentrations. If None, the
            model's initial conditions are used (via `get_args`).
        time: The time at which to calculate the fluxes.

    Returns:
        pd.Series: Fluxes indexed by reaction / surrogate output name.

    """
    args = self.get_args(concs=concs, time=time, include_readouts=False)

    fluxes: dict[str, float] = {}
    for rxn_name, rxn in self._reactions.items():
        rxn_inputs = args.loc[rxn.args]
        fluxes[rxn_name] = cast(float, rxn.fn(*rxn_inputs))

    for surrogate in self._surrogates.values():
        surrogate_inputs = args.loc[surrogate.args].to_numpy()
        fluxes.update(surrogate.predict(surrogate_inputs))
    return pd.Series(fluxes, dtype=float)
1487
+
1488
def get_fluxes_time_course(self, args: pd.DataFrame) -> pd.DataFrame:
    """Return fluxes for every row (time point) of an argument table.

    Reaction rates are evaluated column-wise (one call per reaction);
    surrogates are evaluated row by row and their outputs appended as
    additional columns.

    Examples:
        >>> model.get_fluxes_time_course(args)
        pd.DataFrame({"v1": [0.1, 0.2], "v2": [0.2, 0.3]})

    Args:
        args: Argument table, e.g. from `get_args_time_course`, with one
            column per argument and one row per time point.

    Returns:
        pd.DataFrame: Reaction fluxes followed by surrogate outputs, sharing
        the index of `args`.

    """
    rxn_fluxes: dict[str, Float] = {
        rxn_name: rate.fn(*args.loc[:, rate.args].to_numpy().T)
        for rxn_name, rate in self._reactions.items()
    }

    # Build the frame first so surrogate outputs (one dict per row) can be
    # attached without reshaping them by hand.
    flux_df = pd.DataFrame(rxn_fluxes, index=args.index)
    for surrogate in self._surrogates.values():
        rows = args.loc[:, surrogate.args].to_numpy()
        predictions = [surrogate.predict(row) for row in rows]
        flux_df = pd.concat(
            (flux_df, pd.DataFrame(predictions, index=args.index)),
            axis=1,
        )
    return flux_df
1524
+
1525
+ ##########################################################################
1526
+ # Get rhs
1527
+ ##########################################################################
1528
+
1529
+ def __call__(self, /, time: float, concs: Array) -> Array:
1530
+ """Simulation version of get_right_hand_side.
1531
+
1532
+ Examples:
1533
+ >>> model(0.0, np.array([1.0, 2.0]))
1534
+ np.array([0.1, 0.2])
1535
+
1536
+ Warning: Swaps t and y!
1537
+ This can't get kw-only args, as the integrators call it with pos-only
1538
+
1539
+ Args:
1540
+ time: The current time point.
1541
+ concs: Array of concentrations
1542
+
1543
+
1544
+ Returns:
1545
+ The rate of change of each variable in the model.
1546
+
1547
+ """
1548
+ if (cache := self._cache) is None:
1549
+ cache = self._create_cache()
1550
+ concsd: dict[str, float] = dict(
1551
+ zip(
1552
+ cache.var_names,
1553
+ concs,
1554
+ strict=True,
1555
+ )
1556
+ )
1557
+ args: dict[str, float] = self._get_args(
1558
+ concs=concsd,
1559
+ time=time,
1560
+ include_readouts=False,
1561
+ )
1562
+ fluxes: dict[str, float] = self._get_fluxes(args)
1563
+
1564
+ dxdt = cache.dxdt
1565
+ dxdt[:] = 0
1566
+ for k, stoc in cache.stoich_by_cpds.items():
1567
+ for flux, n in stoc.items():
1568
+ dxdt[k] += n * fluxes[flux]
1569
+ for k, sd in cache.dyn_stoich_by_cpds.items():
1570
+ for flux, dv in sd.items():
1571
+ n = dv.fn(*(args[i] for i in dv.args))
1572
+ dxdt[k] += n * fluxes[flux]
1573
+ return cast(Array, dxdt.to_numpy())
1574
+
1575
def get_right_hand_side(
    self,
    concs: dict[str, float] | None = None,
    time: float = 0.0,
) -> pd.Series:
    """Return dx/dt for every variable as a labelled Series.

    Each flux is multiplied by its stoichiometric coefficient and added to
    the compound it affects. Dynamic coefficients are evaluated from the
    current arguments first.

    Examples:
        # Using initial conditions as default
        >>> model.get_right_hand_side()
        pd.Series({"x1": 0.1, "x2": 0.2})

        # Using custom concentrations
        >>> model.get_right_hand_side({"x1": 1.0, "x2": 2.0})
        pd.Series({"x1": 0.1, "x2": 0.2})

        # Using custom concentrations and time
        >>> model.get_right_hand_side({"x1": 1.0, "x2": 2.0}, time=0.0)
        pd.Series({"x1": 0.1, "x2": 0.2})

    Args:
        concs: Mapping of variable names to concentrations. If None, the
            model's initial conditions are used.
        time: The current time point.

    Returns:
        pd.Series: The rate of change of each variable, indexed by name.

    """
    cache = self._cache
    if cache is None:
        cache = self._create_cache()
    var_names = self.get_variable_names()
    args = self._get_args(
        concs=self.get_initial_conditions() if concs is None else concs,
        time=time,
        include_readouts=False,
    )
    fluxes = self._get_fluxes(args)

    def _contributions():
        # Static coefficients first, then dynamically evaluated ones.
        for cpd, stoich in cache.stoich_by_cpds.items():
            for rxn, factor in stoich.items():
                yield cpd, factor * fluxes[rxn]
        for cpd, stoich in cache.dyn_stoich_by_cpds.items():
            for rxn, derived in stoich.items():
                factor = derived.fn(*(args[i] for i in derived.args))
                yield cpd, factor * fluxes[rxn]

    dxdt = pd.Series(np.zeros(len(var_names), dtype=float), index=var_names)
    for cpd, contribution in _contributions():
        dxdt[cpd] += contribution
    return dxdt