mxlpy 0.22.0__py3-none-any.whl → 0.24.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
mxlpy/model.py CHANGED
@@ -10,12 +10,15 @@ from __future__ import annotations
10
10
  import copy
11
11
  import inspect
12
12
  import itertools as it
13
+ import logging
13
14
  from dataclasses import dataclass, field
15
+ from queue import Empty, SimpleQueue
14
16
  from typing import TYPE_CHECKING, Self, cast
15
17
 
16
18
  import numpy as np
17
19
  import pandas as pd
18
20
  import sympy
21
+ from wadler_lindig import pformat
19
22
 
20
23
  from mxlpy import fns
21
24
  from mxlpy.meta.source_tools import fn_to_sympy
@@ -25,8 +28,8 @@ from mxlpy.meta.sympy_tools import (
25
28
  )
26
29
  from mxlpy.types import (
27
30
  AbstractSurrogate,
28
- Array,
29
31
  Derived,
32
+ InitialAssignment,
30
33
  Parameter,
31
34
  Reaction,
32
35
  Readout,
@@ -37,16 +40,25 @@ if TYPE_CHECKING:
37
40
  from collections.abc import Iterable, Mapping
38
41
  from inspect import FullArgSpec
39
42
 
43
+ from sympy.physics.units.quantities import Quantity
44
+
40
45
  from mxlpy.types import Callable, Param, RateFn, RetType
41
46
 
47
+ LOGGER = logging.getLogger(__name__)
48
+
42
49
  __all__ = [
43
50
  "ArityMismatchError",
44
51
  "CircularDependencyError",
45
52
  "Dependency",
53
+ "Failure",
54
+ "LOGGER",
55
+ "MdText",
46
56
  "MissingDependenciesError",
47
57
  "Model",
48
58
  "ModelCache",
49
59
  "TableView",
60
+ "UnitCheck",
61
+ "unit_of",
50
62
  ]
51
63
 
52
64
 
@@ -56,6 +68,88 @@ def _latex_view(expr: sympy.Expr | None) -> str:
56
68
  return f"${sympy.latex(expr)}$"
57
69
 
58
70
 
71
def unit_of(expr: sympy.Expr) -> sympy.Expr:
    """Return the unit part of ``expr``.

    ``expr.as_coeff_Mul()`` splits the expression into a
    ``(coefficient, rest)`` pair; the numeric coefficient is discarded and
    ``rest`` is returned as the unit.
    """
    _coefficient, unit = expr.as_coeff_Mul()
    return unit
74
+
75
+
76
+ @dataclass
77
+ class Failure:
78
+ """Unit test failure."""
79
+
80
+ expected: sympy.Expr
81
+ obtained: sympy.Expr
82
+
83
+ @property
84
+ def difference(self) -> sympy.Expr:
85
+ """Difference between expected and obtained unit."""
86
+ return self.expected / self.obtained # type: ignore
87
+
88
+
89
@dataclass
class MdText:
    """Markdown text stored as a list of lines.

    ``_repr_markdown_`` joins the lines with newlines (the method name
    follows the IPython rich-display hook naming).
    """

    content: list[str]

    def __repr__(self) -> str:
        """Return a pretty-printed representation produced by ``pformat``."""
        rendered = pformat(self)
        return rendered

    def _repr_markdown_(self) -> str:
        """Return all stored lines joined by ``\\n``."""
        separator = "\n"
        return separator.join(self.content)
101
+
102
+
103
+ @dataclass
104
+ class UnitCheck:
105
+ """Container for unit check."""
106
+
107
+ per_variable: dict[str, dict[str, bool | Failure | None]]
108
+
109
+ def __repr__(self) -> str:
110
+ """Return default representation."""
111
+ return pformat(self)
112
+
113
+ @staticmethod
114
+ def _fmt_success(s: str) -> str:
115
+ return f"<span style='color: green'>{s}</span>"
116
+
117
+ @staticmethod
118
+ def _fmt_failed(s: str) -> str:
119
+ return f"<span style='color: red'>{s}</span>"
120
+
121
+ def correct_diff_eqs(self) -> dict[str, bool]:
122
+ """Get all correctly annotated reactions by variable."""
123
+ return {
124
+ var: all(isinstance(i, bool) for i in checks.values())
125
+ for var, checks in self.per_variable.items()
126
+ }
127
+
128
+ def report(self) -> MdText:
129
+ """Export check as markdown report."""
130
+ report = ["## Type check"]
131
+ for diff_eq, res in self.correct_diff_eqs().items():
132
+ txt = self._fmt_success("Correct") if res else self._fmt_failed("Failed")
133
+ report.append(f"\n### d{diff_eq}dt: {txt}")
134
+
135
+ if res:
136
+ continue
137
+ for k, v in self.per_variable[diff_eq].items():
138
+ match v:
139
+ case bool():
140
+ continue
141
+ case None:
142
+ report.append(f"\n- {k}")
143
+ report.append(" - Failed to parse")
144
+ case Failure(expected, obtained):
145
+ report.append(f"\n- {k}")
146
+ report.append(f" - expected: {_latex_view(expected)}")
147
+ report.append(f" - obtained: {_latex_view(obtained)}")
148
+ report.append(f" - difference: {_latex_view(v.difference)}")
149
+
150
+ return MdText(report)
151
+
152
+
59
153
  @dataclass(kw_only=True, slots=True)
60
154
  class TableView:
61
155
  """Markdown view of pandas Dataframe.
@@ -86,6 +180,10 @@ class Dependency:
86
180
  required: set[str]
87
181
  provided: set[str]
88
182
 
183
def __repr__(self) -> str:
    """Return a pretty-printed representation produced by ``pformat``."""
    rendered = pformat(self)
    return rendered
186
+
89
187
 
90
188
  class MissingDependenciesError(Exception):
91
189
  """Raised when dependencies cannot be sorted topologically.
@@ -228,8 +326,6 @@ def _sort_dependencies(
228
326
  SortError: If circular dependencies are detected
229
327
 
230
328
  """
231
- from queue import Empty, SimpleQueue
232
-
233
329
  _check_if_is_sortable(available, elements)
234
330
 
235
331
  order = []
@@ -291,12 +387,17 @@ class ModelCache:
291
387
 
292
388
  """
293
389
 
390
def __repr__(self) -> str:
    """Return a pretty-printed representation produced by ``pformat``."""
    rendered = pformat(self)
    return rendered
393
+
394
+ order: list[str] # mostly for debug purposes
294
395
  var_names: list[str]
295
396
  dyn_order: list[str]
397
+ base_parameter_values: dict[str, float]
296
398
  all_parameter_values: dict[str, float]
297
399
  stoich_by_cpds: dict[str, dict[str, float]]
298
400
  dyn_stoich_by_cpds: dict[str, dict[str, Derived]]
299
- dxdt: pd.Series
300
401
  initial_conditions: dict[str, float]
301
402
 
302
403
 
@@ -317,16 +418,20 @@ class Model:
317
418
 
318
419
  """
319
420
 
320
- _ids: dict[str, str] = field(default_factory=dict)
421
+ _ids: dict[str, str] = field(default_factory=dict, repr=False)
422
+ _cache: ModelCache | None = field(default=None, repr=False)
321
423
  _variables: dict[str, Variable] = field(default_factory=dict)
322
424
  _parameters: dict[str, Parameter] = field(default_factory=dict)
323
425
  _derived: dict[str, Derived] = field(default_factory=dict)
324
426
  _readouts: dict[str, Readout] = field(default_factory=dict)
325
427
  _reactions: dict[str, Reaction] = field(default_factory=dict)
326
428
  _surrogates: dict[str, AbstractSurrogate] = field(default_factory=dict)
327
- _cache: ModelCache | None = None
328
429
  _data: dict[str, pd.Series | pd.DataFrame] = field(default_factory=dict)
329
430
 
431
def __repr__(self) -> str:
    """Return a pretty-printed representation produced by ``pformat``."""
    rendered = pformat(self)
    return rendered
434
+
330
435
  ###########################################################################
331
436
  # Cache
332
437
  ###########################################################################
@@ -343,11 +448,32 @@ class Model:
343
448
  ModelCache: An instance of ModelCache containing the initialized cache data.
344
449
 
345
450
  """
346
- all_parameter_values: dict[str, float] = self.get_parameter_values()
347
- all_parameter_names: set[str] = set(all_parameter_values)
451
+ parameter_names = set(self._parameters)
452
+ all_parameter_names = set(parameter_names) # later include static derived
453
+
454
+ base_parameter_values: dict[str, float] = {
455
+ k: val
456
+ for k, v in self._parameters.items()
457
+ if not isinstance(val := v.value, InitialAssignment)
458
+ }
459
+ base_variable_values: dict[str, float] = {
460
+ k: init
461
+ for k, v in self._variables.items()
462
+ if not isinstance(init := v.initial_value, InitialAssignment)
463
+ }
464
+ initial_assignments: dict[str, InitialAssignment] = {
465
+ k: init
466
+ for k, v in self._variables.items()
467
+ if isinstance(init := v.initial_value, InitialAssignment)
468
+ } | {
469
+ k: init
470
+ for k, v in self._parameters.items()
471
+ if isinstance(init := v.value, InitialAssignment)
472
+ }
348
473
 
349
474
  # Sanity checks
350
475
  for name, el in it.chain(
476
+ initial_assignments.items(),
351
477
  self._derived.items(),
352
478
  self._reactions.items(),
353
479
  self._readouts.items(),
@@ -356,25 +482,17 @@ class Model:
356
482
  raise ArityMismatchError(name, el.fn, el.args)
357
483
 
358
484
  # Sort derived & reactions
485
+ available = (
486
+ set(base_parameter_values)
487
+ | set(base_variable_values)
488
+ | set(self._data)
489
+ | {"time"}
490
+ )
359
491
  to_sort = (
360
- self._derived
361
- | self._reactions
362
- | self._surrogates
363
- | {
364
- k: init
365
- for k, v in self._variables.items()
366
- if isinstance(init := v.initial_value, Derived)
367
- }
492
+ initial_assignments | self._derived | self._reactions | self._surrogates
368
493
  )
369
494
  order = _sort_dependencies(
370
- available=all_parameter_names
371
- | {
372
- k
373
- for k, v in self._variables.items()
374
- if not isinstance(v.initial_value, Derived)
375
- }
376
- | set(self._data)
377
- | {"time"},
495
+ available=available,
378
496
  elements=[
379
497
  Dependency(name=k, required=set(v.args), provided={k})
380
498
  if not isinstance(v, AbstractSurrogate)
@@ -386,14 +504,7 @@ class Model:
386
504
  # Calculate all values once, including dynamic ones
387
505
  # That way, we can make initial conditions dependent on e.g. rates
388
506
  dependent = (
389
- all_parameter_values
390
- | self._data
391
- | {
392
- k: init
393
- for k, v in self._variables.items()
394
- if not isinstance(init := v.initial_value, Derived)
395
- }
396
- | {"time": 0.0}
507
+ base_parameter_values | base_variable_values | self._data | {"time": 0.0}
397
508
  )
398
509
  for name in order:
399
510
  to_sort[name].calculate_inpl(name, dependent)
@@ -404,7 +515,7 @@ class Model:
404
515
  for name in order:
405
516
  if name in self._reactions or name in self._surrogates:
406
517
  dyn_order.append(name)
407
- elif name in self._variables:
518
+ elif name in self._variables or name in self._parameters:
408
519
  static_order.append(name)
409
520
  else:
410
521
  derived = self._derived[name]
@@ -445,29 +556,27 @@ class Model:
445
556
  d_static[rxn_name] = factor
446
557
 
447
558
  var_names = self.get_variable_names()
448
- dxdt = pd.Series(np.zeros(len(var_names), dtype=float), index=var_names)
449
-
450
559
  initial_conditions: dict[str, float] = {
451
- k: init
452
- for k, v in self._variables.items()
453
- if not isinstance(init := v.initial_value, Derived)
560
+ k: cast(float, dependent[k]) for k in self._variables
454
561
  }
562
+ all_parameter_values = dict(base_parameter_values)
455
563
  for name in static_order:
456
564
  if name in self._variables:
457
- initial_conditions[name] = cast(float, dependent[name])
458
- elif name in self._derived:
565
+ continue # handled in initial_conditions above
566
+ if name in self._parameters or name in self._derived:
459
567
  all_parameter_values[name] = cast(float, dependent[name])
460
568
  else:
461
569
  msg = "Unknown target for static derived variable."
462
570
  raise KeyError(msg)
463
571
 
464
572
  self._cache = ModelCache(
573
+ order=order,
465
574
  var_names=var_names,
466
575
  dyn_order=dyn_order,
576
+ base_parameter_values=base_parameter_values,
467
577
  all_parameter_values=all_parameter_values,
468
578
  stoich_by_cpds=stoich_by_compounds,
469
579
  dyn_stoich_by_cpds=dyn_stoich_by_compounds,
470
- dxdt=dxdt,
471
580
  initial_conditions=initial_conditions,
472
581
  )
473
582
  return self._cache
@@ -529,14 +638,25 @@ class Model:
529
638
  def parameters(self) -> TableView:
530
639
  """Return view of parameters."""
531
640
  index = list(self._parameters.keys())
532
- data = [
533
- {
534
- "value": el.value,
535
- "unit": _latex_view(unit) if (unit := el.unit) is not None else "",
536
- # "source": ...,
537
- }
538
- for el in self._parameters.values()
539
- ]
641
+ data = []
642
+ for name, el in self._parameters.items():
643
+ if isinstance(init := el.value, InitialAssignment):
644
+ value_str = _latex_view(
645
+ fn_to_sympy(
646
+ init.fn,
647
+ origin=name,
648
+ model_args=list_of_symbols(init.args),
649
+ )
650
+ )
651
+ else:
652
+ value_str = str(init)
653
+ data.append(
654
+ {
655
+ "value": value_str,
656
+ "unit": _latex_view(unit) if (unit := el.unit) is not None else "",
657
+ # "source": ...,
658
+ }
659
+ )
540
660
  return TableView(data=pd.DataFrame(data, index=index))
541
661
 
542
662
  def get_raw_parameters(self, *, as_copy: bool = True) -> dict[str, Parameter]:
@@ -557,7 +677,9 @@ class Model:
557
677
  and the values are parameter values (as floats).
558
678
 
559
679
  """
560
- return {k: v.value for k, v in self._parameters.items()}
680
+ if (cache := self._cache) is None:
681
+ cache = self._create_cache()
682
+ return cache.base_parameter_values
561
683
 
562
684
  def get_parameter_names(self) -> list[str]:
563
685
  """Retrieve the names of the parameters.
@@ -580,7 +702,7 @@ class Model:
580
702
  def add_parameter(
581
703
  self,
582
704
  name: str,
583
- value: float,
705
+ value: float | InitialAssignment,
584
706
  unit: sympy.Expr | None = None,
585
707
  source: str | None = None,
586
708
  ) -> Self:
@@ -603,7 +725,9 @@ class Model:
603
725
  self._parameters[name] = Parameter(value=value, unit=unit, source=source)
604
726
  return self
605
727
 
606
- def add_parameters(self, parameters: Mapping[str, float | Parameter]) -> Self:
728
+ def add_parameters(
729
+ self, parameters: Mapping[str, float | Parameter | InitialAssignment]
730
+ ) -> Self:
607
731
  """Adds multiple parameters to the model.
608
732
 
609
733
  Examples:
@@ -671,7 +795,7 @@ class Model:
671
795
  def update_parameter(
672
796
  self,
673
797
  name: str,
674
- value: float | None = None,
798
+ value: float | InitialAssignment | None = None,
675
799
  *,
676
800
  unit: sympy.Expr | None = None,
677
801
  source: str | None = None,
@@ -695,7 +819,7 @@ class Model:
695
819
 
696
820
  """
697
821
  if name not in self._parameters:
698
- msg = f"'{name}' not found in parameters"
822
+ msg = f"{name!r} not found in parameters"
699
823
  raise KeyError(msg)
700
824
 
701
825
  parameter = self._parameters[name]
@@ -707,7 +831,9 @@ class Model:
707
831
  parameter.source = source
708
832
  return self
709
833
 
710
- def update_parameters(self, parameters: Mapping[str, float | Parameter]) -> Self:
834
+ def update_parameters(
835
+ self, parameters: Mapping[str, float | Parameter | InitialAssignment]
836
+ ) -> Self:
711
837
  """Update multiple parameters of the model.
712
838
 
713
839
  Examples:
@@ -741,7 +867,17 @@ class Model:
741
867
  Self: The instance of the class with the updated parameter.
742
868
 
743
869
  """
744
- return self.update_parameter(name, self._parameters[name].value * factor)
870
+ old = self._parameters[name].value
871
+ if isinstance(old, InitialAssignment):
872
+ LOGGER.warning("Overwriting initial assignment %s", name)
873
+ if (cache := self._cache) is None:
874
+ cache = self._create_cache()
875
+
876
+ return self.update_parameter(
877
+ name, cache.all_parameter_values[name] * factor
878
+ )
879
+
880
+ return self.update_parameter(name, old * factor)
745
881
 
746
882
  def scale_parameters(self, parameters: dict[str, float]) -> Self:
747
883
  """Scales the parameters of the model.
@@ -843,7 +979,7 @@ class Model:
843
979
  index = list(self._variables.keys())
844
980
  data = []
845
981
  for name, el in self._variables.items():
846
- if isinstance(init := el.initial_value, Derived):
982
+ if isinstance(init := el.initial_value, InitialAssignment):
847
983
  value_str = _latex_view(
848
984
  fn_to_sympy(
849
985
  init.fn,
@@ -909,7 +1045,7 @@ class Model:
909
1045
  def add_variable(
910
1046
  self,
911
1047
  name: str,
912
- initial_value: float | Derived,
1048
+ initial_value: float | InitialAssignment,
913
1049
  unit: sympy.Expr | None = None,
914
1050
  source: str | None = None,
915
1051
  ) -> Self:
@@ -935,7 +1071,7 @@ class Model:
935
1071
  return self
936
1072
 
937
1073
  def add_variables(
938
- self, variables: Mapping[str, float | Variable | Derived]
1074
+ self, variables: Mapping[str, float | Variable | InitialAssignment]
939
1075
  ) -> Self:
940
1076
  """Adds multiple variables to the model with their initial conditions.
941
1077
 
@@ -1001,7 +1137,7 @@ class Model:
1001
1137
  def update_variable(
1002
1138
  self,
1003
1139
  name: str,
1004
- initial_value: float | Derived,
1140
+ initial_value: float | InitialAssignment,
1005
1141
  unit: sympy.Expr | None = None,
1006
1142
  source: str | None = None,
1007
1143
  ) -> Self:
@@ -1035,7 +1171,7 @@ class Model:
1035
1171
  return self
1036
1172
 
1037
1173
  def update_variables(
1038
- self, variables: Mapping[str, float | Derived | Variable]
1174
+ self, variables: Mapping[str, float | Variable | InitialAssignment]
1039
1175
  ) -> Self:
1040
1176
  """Updates multiple variables in the model.
1041
1177
 
@@ -1104,7 +1240,7 @@ class Model:
1104
1240
  return self
1105
1241
 
1106
1242
  ##########################################################################
1107
- # Derived - views
1243
+ # Derived
1108
1244
  ##########################################################################
1109
1245
 
1110
1246
  @property
@@ -1713,11 +1849,28 @@ class Model:
1713
1849
  return copy.deepcopy(self._surrogates)
1714
1850
  return self._surrogates
1715
1851
 
1716
def get_surrogate_output_names(
    self,
    *,
    include_fluxes: bool = True,
) -> list[str]:
    """Return the output names of all surrogates.

    Outputs that are also keys of a surrogate's ``stoichiometries`` are
    fluxes / reactions rather than variables; they can be left out.

    Args:
        include_fluxes: If True (default), return every output. If False,
            skip outputs that appear in the surrogate's ``stoichiometries``.

    Returns:
        Output names, surrogate by surrogate, in insertion order.

    """
    collected: list[str] = []
    for surrogate in self._surrogates.values():
        if include_fluxes:
            kept = surrogate.outputs
        else:
            kept = [
                out for out in surrogate.outputs
                if out not in surrogate.stoichiometries
            ]
        collected.extend(kept)
    return collected
1722
1875
 
1723
1876
  def get_surrogate_reaction_names(self) -> list[str]:
@@ -1765,7 +1918,8 @@ class Model:
1765
1918
  include_derived_parameters: bool,
1766
1919
  include_derived_variables: bool,
1767
1920
  include_reactions: bool,
1768
- include_surrogate_outputs: bool,
1921
+ include_surrogate_variables: bool,
1922
+ include_surrogate_fluxes: bool,
1769
1923
  include_readouts: bool,
1770
1924
  ) -> list[str]:
1771
1925
  """Get names of all kinds of model components."""
@@ -1782,8 +1936,10 @@ class Model:
1782
1936
  names.extend(self.get_derived_parameter_names())
1783
1937
  if include_reactions:
1784
1938
  names.extend(self.get_reaction_names())
1785
- if include_surrogate_outputs:
1786
- names.extend(self.get_surrogate_output_names())
1939
+ if include_surrogate_variables:
1940
+ names.extend(self.get_surrogate_output_names(include_fluxes=False))
1941
+ if include_surrogate_fluxes:
1942
+ names.extend(self.get_surrogate_reaction_names())
1787
1943
  if include_readouts:
1788
1944
  names.extend(self.get_readout_names())
1789
1945
  return names
@@ -1837,7 +1993,8 @@ class Model:
1837
1993
  include_derived_parameters: bool = True,
1838
1994
  include_derived_variables: bool = True,
1839
1995
  include_reactions: bool = True,
1840
- include_surrogate_outputs: bool = True,
1996
+ include_surrogate_variables: bool = True,
1997
+ include_surrogate_fluxes: bool = True,
1841
1998
  include_readouts: bool = False,
1842
1999
  ) -> pd.Series:
1843
2000
  """Generate a pandas Series of arguments for the model.
@@ -1864,7 +2021,8 @@ class Model:
1864
2021
  include_derived_parameters: Whether to include derived parameters
1865
2022
  include_derived_variables: Whether to include derived variables
1866
2023
  include_reactions: Whether to include reactions
1867
- include_surrogate_outputs: Whether to include surrogate outputs
2024
+ include_surrogate_variables: Whether to include derived variables obtained from surrogates
2025
+ include_surrogate_fluxes: Whether to include surrogate fluxes
1868
2026
  include_readouts: Whether to include readouts
1869
2027
 
1870
2028
  Returns:
@@ -1890,7 +2048,8 @@ class Model:
1890
2048
  include_derived_parameters=include_derived_parameters,
1891
2049
  include_derived_variables=include_derived_variables,
1892
2050
  include_reactions=include_reactions,
1893
- include_surrogate_outputs=include_surrogate_outputs,
2051
+ include_surrogate_variables=include_surrogate_variables,
2052
+ include_surrogate_fluxes=include_surrogate_fluxes,
1894
2053
  include_readouts=include_readouts,
1895
2054
  )
1896
2055
  ]
@@ -1926,7 +2085,8 @@ class Model:
1926
2085
  include_derived_parameters: bool = True,
1927
2086
  include_derived_variables: bool = True,
1928
2087
  include_reactions: bool = True,
1929
- include_surrogate_outputs: bool = True,
2088
+ include_surrogate_variables: bool = True,
2089
+ include_surrogate_fluxes: bool = True,
1930
2090
  include_readouts: bool = False,
1931
2091
  ) -> pd.DataFrame:
1932
2092
  """Generate a DataFrame containing time course arguments for model evaluation.
@@ -1949,7 +2109,8 @@ class Model:
1949
2109
  include_derived_parameters: Whether to include derived parameters
1950
2110
  include_derived_variables: Whether to include derived variables
1951
2111
  include_reactions: Whether to include reactions
1952
- include_surrogate_outputs: Whether to include surrogate outputs
2112
+ include_surrogate_variables: Whether to include variables derived from surrogates
2113
+ include_surrogate_fluxes: Whether to include surrogate fluxes
1953
2114
  include_readouts: Whether to include readouts
1954
2115
 
1955
2116
  Returns:
@@ -1974,7 +2135,8 @@ class Model:
1974
2135
  include_derived_parameters=include_derived_parameters,
1975
2136
  include_derived_variables=include_derived_variables,
1976
2137
  include_reactions=include_reactions,
1977
- include_surrogate_outputs=include_surrogate_outputs,
2138
+ include_surrogate_variables=include_surrogate_variables,
2139
+ include_surrogate_fluxes=include_surrogate_fluxes,
1978
2140
  include_readouts=include_readouts,
1979
2141
  ),
1980
2142
  ]
@@ -2011,15 +2173,19 @@ class Model:
2011
2173
  Fluxes: A pandas Series containing the fluxes for each reaction.
2012
2174
 
2013
2175
  """
2014
- names = self.get_reaction_names()
2015
- names.extend(self.get_surrogate_reaction_names())
2016
-
2017
- args = self.get_args(
2176
+ return self.get_args(
2018
2177
  variables=variables,
2019
2178
  time=time,
2179
+ include_time=False,
2180
+ include_variables=False,
2181
+ include_parameters=False,
2182
+ include_derived_parameters=False,
2183
+ include_derived_variables=False,
2184
+ include_reactions=True,
2185
+ include_surrogate_variables=False,
2186
+ include_surrogate_fluxes=True,
2020
2187
  include_readouts=False,
2021
2188
  )
2022
- return args.loc[names]
2023
2189
 
2024
2190
  def get_fluxes_time_course(self, variables: pd.DataFrame) -> pd.DataFrame:
2025
2191
  """Generate a time course of fluxes for the given reactions and surrogates.
@@ -2029,7 +2195,9 @@ class Model:
2029
2195
  pd.DataFrame({"v1": [0.1, 0.2], "v2": [0.2, 0.3]})
2030
2196
 
2031
2197
  This method calculates the fluxes for each reaction in the model using the provided
2032
- arguments and combines them with the outputs from the surrogates to create a complete
2198
+ arguments and combines them with the outputs from the surrogates to create a complete
2033
2201
  time course of fluxes.
2034
2202
 
2035
2203
  Args:
@@ -2043,21 +2211,23 @@ class Model:
2043
2211
  the index of the input arguments.
2044
2212
 
2045
2213
  """
2046
- names: list[str] = self.get_reaction_names()
2047
- for surrogate in self._surrogates.values():
2048
- names.extend(surrogate.stoichiometries)
2049
-
2050
- variables = self.get_args_time_course(
2214
+ return self.get_args_time_course(
2051
2215
  variables=variables,
2216
+ include_variables=False,
2217
+ include_parameters=False,
2218
+ include_derived_parameters=False,
2219
+ include_derived_variables=False,
2220
+ include_reactions=True,
2221
+ include_surrogate_variables=False,
2222
+ include_surrogate_fluxes=True,
2052
2223
  include_readouts=False,
2053
2224
  )
2054
- return variables.loc[:, names]
2055
2225
 
2056
2226
  ##########################################################################
2057
2227
  # Get rhs
2058
2228
  ##########################################################################
2059
2229
 
2060
- def __call__(self, /, time: float, variables: Array) -> Array:
2230
+ def __call__(self, /, time: float, variables: Iterable[float]) -> tuple[float, ...]:
2061
2231
  """Simulation version of get_right_hand_side.
2062
2232
 
2063
2233
  Examples:
@@ -2065,7 +2235,7 @@ class Model:
2065
2235
  np.array([0.1, 0.2])
2066
2236
 
2067
2237
  Warning: Swaps t and y!
2068
- This can't get kw-only args, as the integrators call it with pos-only
2238
+ This can't get kw args, as the integrators call it with pos-only
2069
2239
 
2070
2240
  Args:
2071
2241
  time: The current time point.
@@ -2091,8 +2261,7 @@ class Model:
2091
2261
  cache=cache,
2092
2262
  )
2093
2263
 
2094
- dxdt = cache.dxdt
2095
- dxdt[:] = 0
2264
+ dxdt = dict.fromkeys(cache.var_names, 0.0)
2096
2265
  for k, stoc in cache.stoich_by_cpds.items():
2097
2266
  for flux, n in stoc.items():
2098
2267
  dxdt[k] += n * dependent[flux]
@@ -2100,7 +2269,7 @@ class Model:
2100
2269
  for flux, dv in sd.items():
2101
2270
  n = dv.calculate(dependent)
2102
2271
  dxdt[k] += n * dependent[flux]
2103
- return cast(Array, dxdt.to_numpy())
2272
+ return tuple(dxdt[i] for i in cache.var_names)
2104
2273
 
2105
2274
  def get_right_hand_side(
2106
2275
  self,
@@ -2133,7 +2302,7 @@ class Model:
2133
2302
  if (cache := self._cache) is None:
2134
2303
  cache = self._create_cache()
2135
2304
  var_names = self.get_variable_names()
2136
- dependent = self._get_args(
2305
+ args = self._get_args(
2137
2306
  variables=self.get_initial_conditions() if variables is None else variables,
2138
2307
  time=time,
2139
2308
  cache=cache,
@@ -2141,10 +2310,59 @@ class Model:
2141
2310
  dxdt = pd.Series(np.zeros(len(var_names), dtype=float), index=var_names)
2142
2311
  for k, stoc in cache.stoich_by_cpds.items():
2143
2312
  for flux, n in stoc.items():
2144
- dxdt[k] += n * dependent[flux]
2313
+ dxdt[k] += n * args[flux]
2145
2314
 
2146
2315
  for k, sd in cache.dyn_stoich_by_cpds.items():
2147
2316
  for flux, dv in sd.items():
2148
- n = dv.fn(*(dependent[i] for i in dv.args))
2149
- dxdt[k] += n * dependent[flux]
2317
+ n = dv.fn(*(args[i] for i in dv.args))
2318
+ dxdt[k] += n * args[flux]
2150
2319
  return dxdt
2320
+
2321
+ ##########################################################################
2322
+ # Check units
2323
+ ##########################################################################
2324
+
2325
def check_units(self, time_unit: Quantity) -> UnitCheck:
    """Check unit consistency per differential equation and reaction.

    Steps:

    1. Convert each reaction's rate function to a sympy expression.
    2. Substitute each argument symbol with the unit of the parameter or
       variable of the same name.
    3. For every variable that has a unit, compare the unit of each
       reaction returned by ``get_stoichiometries_of_variable`` against
       ``variable unit / time_unit``.

    Args:
        time_unit: Unit of time used to build the expected unit of each
            rate of change.

    Returns:
        UnitCheck mapping variable name -> reaction name -> result. The
        result is ``True`` if the units agree, or a ``Failure`` (expected
        unit = ``variable unit / time_unit``, obtained unit = unit of the
        reaction) if they do not. It is ``None`` if the rate function could
        not be converted or one of its arguments has no unit. Variables
        without a unit are left out.

    Raises:
        NotImplementedError: If a reaction argument is neither a parameter
            nor a variable of the model.

    """
    units_per_fn: dict[str, sympy.Expr | None] = {}
    for name, rxn in self._reactions.items():
        unit_per_arg = {}
        for arg in rxn.args:
            if (par := self._parameters.get(arg)) is not None:
                unit_per_arg[sympy.Symbol(arg)] = par.unit
            elif (var := self._variables.get(arg)) is not None:
                unit_per_arg[sympy.Symbol(arg)] = var.unit
            else:
                msg = (
                    f"Unit lookup for argument {arg!r} of reaction {name!r} "
                    "is only supported for parameters and variables"
                )
                raise NotImplementedError(msg)

        symbolic_fn = fn_to_sympy(
            rxn.fn,
            origin="unit-checking",
            model_args=list_of_symbols(rxn.args),
        )
        units_per_fn[name] = None
        if symbolic_fn is None:
            continue
        if any(unit is None for unit in unit_per_arg.values()):
            continue
        units_per_fn[name] = symbolic_fn.subs(unit_per_arg)

    check_per_variable: dict[str, dict[str, bool | Failure | None]] = {}
    for name, var in self._variables.items():
        if (var_unit := var.unit) is None:
            # Nothing to compare against: skip this variable, but keep
            # checking the remaining ones.
            continue

        expected = var_unit / time_unit  # type: ignore
        check_per_rxn: dict[str, bool | Failure | None] = {}
        for rxn in self.get_stoichiometries_of_variable(name):
            if (rxn_unit := units_per_fn.get(rxn)) is None:
                check_per_rxn[rxn] = None
                continue
            obtained = unit_of(rxn_unit)
            if obtained == expected:
                check_per_rxn[rxn] = True
            else:
                check_per_rxn[rxn] = Failure(expected=expected, obtained=obtained)
        check_per_variable[name] = check_per_rxn
    return UnitCheck(check_per_variable)