mxlpy 0.22.0__py3-none-any.whl → 0.23.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
mxlpy/model.py CHANGED
@@ -10,7 +10,9 @@ from __future__ import annotations
10
10
  import copy
11
11
  import inspect
12
12
  import itertools as it
13
+ import logging
13
14
  from dataclasses import dataclass, field
15
+ from queue import Empty, SimpleQueue
14
16
  from typing import TYPE_CHECKING, Self, cast
15
17
 
16
18
  import numpy as np
@@ -25,8 +27,8 @@ from mxlpy.meta.sympy_tools import (
25
27
  )
26
28
  from mxlpy.types import (
27
29
  AbstractSurrogate,
28
- Array,
29
30
  Derived,
31
+ InitialAssignment,
30
32
  Parameter,
31
33
  Reaction,
32
34
  Readout,
@@ -37,16 +39,25 @@ if TYPE_CHECKING:
37
39
  from collections.abc import Iterable, Mapping
38
40
  from inspect import FullArgSpec
39
41
 
42
+ from sympy.physics.units.quantities import Quantity
43
+
40
44
  from mxlpy.types import Callable, Param, RateFn, RetType
41
45
 
46
+ LOGGER = logging.getLogger(__name__)
47
+
42
48
  __all__ = [
43
49
  "ArityMismatchError",
44
50
  "CircularDependencyError",
45
51
  "Dependency",
52
+ "Failure",
53
+ "LOGGER",
54
+ "MdText",
46
55
  "MissingDependenciesError",
47
56
  "Model",
48
57
  "ModelCache",
49
58
  "TableView",
59
+ "UnitCheck",
60
+ "unit_of",
50
61
  ]
51
62
 
52
63
 
@@ -56,6 +67,80 @@ def _latex_view(expr: sympy.Expr | None) -> str:
56
67
  return f"${sympy.latex(expr)}$"
57
68
 
58
69
 
70
def unit_of(expr: sympy.Expr) -> sympy.Expr:
    """Return the unit part of a sympy expression.

    The expression is split with sympy's ``as_coeff_Mul`` into a
    ``(coefficient, remainder)`` pair; the multiplicative remainder, i.e.
    everything except the leading numeric coefficient, is returned.
    """
    _coefficient, unit = expr.as_coeff_Mul()
    return unit
73
+
74
+
75
+ @dataclass
76
+ class Failure:
77
+ """Unit test failure."""
78
+
79
+ expected: sympy.Expr
80
+ obtained: sympy.Expr
81
+
82
+ @property
83
+ def difference(self) -> sympy.Expr:
84
+ """Difference between expected and obtained unit."""
85
+ return self.expected / self.obtained # type: ignore
86
+
87
+
88
@dataclass
class MdText:
    """Markdown text rendered through the IPython rich-display hook.

    Attributes:
        content: markdown lines, joined with newlines when rendered

    """

    content: list[str]

    def _repr_markdown_(self) -> str:
        # Each entry of ``content`` becomes one line of the rendered markdown.
        separator = "\n"
        return separator.join(self.content)
96
+
97
+
98
@dataclass
class UnitCheck:
    """Result container of a unit consistency check.

    Attributes:
        per_variable: for every variable, a mapping from reaction name to the
            outcome of that reaction's check. A ``bool`` outcome counts as
            passed, a ``Failure`` describes a unit mismatch and ``None`` marks
            a reaction whose unit could not be determined.

    """

    per_variable: dict[str, dict[str, bool | Failure | None]]

    @staticmethod
    def _fmt_success(s: str) -> str:
        """Wrap ``s`` in a green HTML span."""
        return f"<span style='color: green'>{s}</span>"

    @staticmethod
    def _fmt_failed(s: str) -> str:
        """Wrap ``s`` in a red HTML span."""
        return f"<span style='color: red'>{s}</span>"

    def correct_diff_eqs(self) -> dict[str, bool]:
        """Map every variable to whether all of its reaction checks passed.

        Only ``bool`` outcomes count as passed; ``Failure`` and ``None``
        outcomes make the whole differential equation count as failed. A
        variable without any checked reaction counts as passed.
        """
        verdicts: dict[str, bool] = {}
        for variable, outcomes in self.per_variable.items():
            verdicts[variable] = not any(
                not isinstance(outcome, bool) for outcome in outcomes.values()
            )
        return verdicts

    def report(self) -> MdText:
        """Render the check as a markdown report.

        Every differential equation gets a heading marked correct or failed.
        For failed ones, each non-passing reaction is listed either with a
        parse-failure note (``None`` outcome) or with its expected, obtained
        and difference units (``Failure`` outcome).
        """
        lines = ["## Type check"]
        for variable, passed in self.correct_diff_eqs().items():
            if passed:
                status = self._fmt_success("Correct")
            else:
                status = self._fmt_failed("Failed")
            lines.append(f"\n### d{variable}dt: {status}")

            if passed:
                continue
            for rxn_name, outcome in self.per_variable[variable].items():
                if isinstance(outcome, bool):
                    continue
                if outcome is None:
                    lines.extend((f"\n- {rxn_name}", " - Failed to parse"))
                elif isinstance(outcome, Failure):
                    lines.extend(
                        (
                            f"\n- {rxn_name}",
                            f" - expected: {_latex_view(outcome.expected)}",
                            f" - obtained: {_latex_view(outcome.obtained)}",
                            f" - difference: {_latex_view(outcome.difference)}",
                        )
                    )

        return MdText(lines)
142
+
143
+
59
144
  @dataclass(kw_only=True, slots=True)
60
145
  class TableView:
61
146
  """Markdown view of pandas Dataframe.
@@ -228,8 +313,6 @@ def _sort_dependencies(
228
313
  SortError: If circular dependencies are detected
229
314
 
230
315
  """
231
- from queue import Empty, SimpleQueue
232
-
233
316
  _check_if_is_sortable(available, elements)
234
317
 
235
318
  order = []
@@ -291,12 +374,13 @@ class ModelCache:
291
374
 
292
375
  """
293
376
 
377
+ order: list[str] # mostly for debug purposes
294
378
  var_names: list[str]
295
379
  dyn_order: list[str]
380
+ base_parameter_values: dict[str, float]
296
381
  all_parameter_values: dict[str, float]
297
382
  stoich_by_cpds: dict[str, dict[str, float]]
298
383
  dyn_stoich_by_cpds: dict[str, dict[str, Derived]]
299
- dxdt: pd.Series
300
384
  initial_conditions: dict[str, float]
301
385
 
302
386
 
@@ -317,7 +401,7 @@ class Model:
317
401
 
318
402
  """
319
403
 
320
- _ids: dict[str, str] = field(default_factory=dict)
404
+ _ids: dict[str, str] = field(default_factory=dict, repr=False)
321
405
  _variables: dict[str, Variable] = field(default_factory=dict)
322
406
  _parameters: dict[str, Parameter] = field(default_factory=dict)
323
407
  _derived: dict[str, Derived] = field(default_factory=dict)
@@ -343,11 +427,32 @@ class Model:
343
427
  ModelCache: An instance of ModelCache containing the initialized cache data.
344
428
 
345
429
  """
346
- all_parameter_values: dict[str, float] = self.get_parameter_values()
347
- all_parameter_names: set[str] = set(all_parameter_values)
430
+ parameter_names = set(self._parameters)
431
+ all_parameter_names = set(parameter_names) # later include static derived
432
+
433
+ base_parameter_values: dict[str, float] = {
434
+ k: val
435
+ for k, v in self._parameters.items()
436
+ if not isinstance(val := v.value, InitialAssignment)
437
+ }
438
+ base_variable_values: dict[str, float] = {
439
+ k: init
440
+ for k, v in self._variables.items()
441
+ if not isinstance(init := v.initial_value, InitialAssignment)
442
+ }
443
+ initial_assignments: dict[str, InitialAssignment] = {
444
+ k: init
445
+ for k, v in self._variables.items()
446
+ if isinstance(init := v.initial_value, InitialAssignment)
447
+ } | {
448
+ k: init
449
+ for k, v in self._parameters.items()
450
+ if isinstance(init := v.value, InitialAssignment)
451
+ }
348
452
 
349
453
  # Sanity checks
350
454
  for name, el in it.chain(
455
+ initial_assignments.items(),
351
456
  self._derived.items(),
352
457
  self._reactions.items(),
353
458
  self._readouts.items(),
@@ -356,25 +461,17 @@ class Model:
356
461
  raise ArityMismatchError(name, el.fn, el.args)
357
462
 
358
463
  # Sort derived & reactions
464
+ available = (
465
+ set(base_parameter_values)
466
+ | set(base_variable_values)
467
+ | set(self._data)
468
+ | {"time"}
469
+ )
359
470
  to_sort = (
360
- self._derived
361
- | self._reactions
362
- | self._surrogates
363
- | {
364
- k: init
365
- for k, v in self._variables.items()
366
- if isinstance(init := v.initial_value, Derived)
367
- }
471
+ initial_assignments | self._derived | self._reactions | self._surrogates
368
472
  )
369
473
  order = _sort_dependencies(
370
- available=all_parameter_names
371
- | {
372
- k
373
- for k, v in self._variables.items()
374
- if not isinstance(v.initial_value, Derived)
375
- }
376
- | set(self._data)
377
- | {"time"},
474
+ available=available,
378
475
  elements=[
379
476
  Dependency(name=k, required=set(v.args), provided={k})
380
477
  if not isinstance(v, AbstractSurrogate)
@@ -386,14 +483,7 @@ class Model:
386
483
  # Calculate all values once, including dynamic ones
387
484
  # That way, we can make initial conditions dependent on e.g. rates
388
485
  dependent = (
389
- all_parameter_values
390
- | self._data
391
- | {
392
- k: init
393
- for k, v in self._variables.items()
394
- if not isinstance(init := v.initial_value, Derived)
395
- }
396
- | {"time": 0.0}
486
+ base_parameter_values | base_variable_values | self._data | {"time": 0.0}
397
487
  )
398
488
  for name in order:
399
489
  to_sort[name].calculate_inpl(name, dependent)
@@ -404,7 +494,7 @@ class Model:
404
494
  for name in order:
405
495
  if name in self._reactions or name in self._surrogates:
406
496
  dyn_order.append(name)
407
- elif name in self._variables:
497
+ elif name in self._variables or name in self._parameters:
408
498
  static_order.append(name)
409
499
  else:
410
500
  derived = self._derived[name]
@@ -445,29 +535,27 @@ class Model:
445
535
  d_static[rxn_name] = factor
446
536
 
447
537
  var_names = self.get_variable_names()
448
- dxdt = pd.Series(np.zeros(len(var_names), dtype=float), index=var_names)
449
-
450
538
  initial_conditions: dict[str, float] = {
451
- k: init
452
- for k, v in self._variables.items()
453
- if not isinstance(init := v.initial_value, Derived)
539
+ k: cast(float, dependent[k]) for k in self._variables
454
540
  }
541
+ all_parameter_values = dict(base_parameter_values)
455
542
  for name in static_order:
456
543
  if name in self._variables:
457
- initial_conditions[name] = cast(float, dependent[name])
458
- elif name in self._derived:
544
+ continue # handled in initial_conditions above
545
+ if name in self._parameters or name in self._derived:
459
546
  all_parameter_values[name] = cast(float, dependent[name])
460
547
  else:
461
548
  msg = "Unknown target for static derived variable."
462
549
  raise KeyError(msg)
463
550
 
464
551
  self._cache = ModelCache(
552
+ order=order,
465
553
  var_names=var_names,
466
554
  dyn_order=dyn_order,
555
+ base_parameter_values=base_parameter_values,
467
556
  all_parameter_values=all_parameter_values,
468
557
  stoich_by_cpds=stoich_by_compounds,
469
558
  dyn_stoich_by_cpds=dyn_stoich_by_compounds,
470
- dxdt=dxdt,
471
559
  initial_conditions=initial_conditions,
472
560
  )
473
561
  return self._cache
@@ -529,14 +617,25 @@ class Model:
def parameters(self) -> TableView:
    """Return a markdown table view of all parameters.

    Numeric values are shown via ``str``. Values given as an
    ``InitialAssignment`` are converted with ``fn_to_sympy`` and rendered as
    LaTeX. Units are rendered as LaTeX when set and left empty otherwise.
    """
    row_names = list(self._parameters.keys())
    rows = []
    for par_name, par in self._parameters.items():
        value = par.value
        if isinstance(value, InitialAssignment):
            expr = fn_to_sympy(
                value.fn,
                origin=par_name,
                model_args=list_of_symbols(value.args),
            )
            shown_value = _latex_view(expr)
        else:
            shown_value = str(value)

        unit = par.unit
        shown_unit = "" if unit is None else _latex_view(unit)
        rows.append({"value": shown_value, "unit": shown_unit})
    return TableView(data=pd.DataFrame(rows, index=row_names))
541
640
 
542
641
  def get_raw_parameters(self, *, as_copy: bool = True) -> dict[str, Parameter]:
@@ -557,7 +656,9 @@ class Model:
557
656
  and the values are parameter values (as floats).
558
657
 
559
658
  """
560
- return {k: v.value for k, v in self._parameters.items()}
659
+ if (cache := self._cache) is None:
660
+ cache = self._create_cache()
661
+ return cache.base_parameter_values
561
662
 
562
663
  def get_parameter_names(self) -> list[str]:
563
664
  """Retrieve the names of the parameters.
@@ -580,7 +681,7 @@ class Model:
580
681
  def add_parameter(
581
682
  self,
582
683
  name: str,
583
- value: float,
684
+ value: float | InitialAssignment,
584
685
  unit: sympy.Expr | None = None,
585
686
  source: str | None = None,
586
687
  ) -> Self:
@@ -603,7 +704,9 @@ class Model:
603
704
  self._parameters[name] = Parameter(value=value, unit=unit, source=source)
604
705
  return self
605
706
 
606
- def add_parameters(self, parameters: Mapping[str, float | Parameter]) -> Self:
707
+ def add_parameters(
708
+ self, parameters: Mapping[str, float | Parameter | InitialAssignment]
709
+ ) -> Self:
607
710
  """Adds multiple parameters to the model.
608
711
 
609
712
  Examples:
@@ -671,7 +774,7 @@ class Model:
671
774
  def update_parameter(
672
775
  self,
673
776
  name: str,
674
- value: float | None = None,
777
+ value: float | InitialAssignment | None = None,
675
778
  *,
676
779
  unit: sympy.Expr | None = None,
677
780
  source: str | None = None,
@@ -695,7 +798,7 @@ class Model:
695
798
 
696
799
  """
697
800
  if name not in self._parameters:
698
- msg = f"'{name}' not found in parameters"
801
+ msg = f"{name!r} not found in parameters"
699
802
  raise KeyError(msg)
700
803
 
701
804
  parameter = self._parameters[name]
@@ -707,7 +810,9 @@ class Model:
707
810
  parameter.source = source
708
811
  return self
709
812
 
710
- def update_parameters(self, parameters: Mapping[str, float | Parameter]) -> Self:
813
+ def update_parameters(
814
+ self, parameters: Mapping[str, float | Parameter | InitialAssignment]
815
+ ) -> Self:
711
816
  """Update multiple parameters of the model.
712
817
 
713
818
  Examples:
@@ -741,7 +846,17 @@ class Model:
741
846
  Self: The instance of the class with the updated parameter.
742
847
 
743
848
  """
744
- return self.update_parameter(name, self._parameters[name].value * factor)
849
+ old = self._parameters[name].value
850
+ if isinstance(old, InitialAssignment):
851
+ LOGGER.warning("Overwriting initial assignment %s", name)
852
+ if (cache := self._cache) is None:
853
+ cache = self._create_cache()
854
+
855
+ return self.update_parameter(
856
+ name, cache.all_parameter_values[name] * factor
857
+ )
858
+
859
+ return self.update_parameter(name, old * factor)
745
860
 
746
861
  def scale_parameters(self, parameters: dict[str, float]) -> Self:
747
862
  """Scales the parameters of the model.
@@ -843,7 +958,7 @@ class Model:
843
958
  index = list(self._variables.keys())
844
959
  data = []
845
960
  for name, el in self._variables.items():
846
- if isinstance(init := el.initial_value, Derived):
961
+ if isinstance(init := el.initial_value, InitialAssignment):
847
962
  value_str = _latex_view(
848
963
  fn_to_sympy(
849
964
  init.fn,
@@ -909,7 +1024,7 @@ class Model:
909
1024
  def add_variable(
910
1025
  self,
911
1026
  name: str,
912
- initial_value: float | Derived,
1027
+ initial_value: float | InitialAssignment,
913
1028
  unit: sympy.Expr | None = None,
914
1029
  source: str | None = None,
915
1030
  ) -> Self:
@@ -935,7 +1050,7 @@ class Model:
935
1050
  return self
936
1051
 
937
1052
  def add_variables(
938
- self, variables: Mapping[str, float | Variable | Derived]
1053
+ self, variables: Mapping[str, float | Variable | InitialAssignment]
939
1054
  ) -> Self:
940
1055
  """Adds multiple variables to the model with their initial conditions.
941
1056
 
@@ -1001,7 +1116,7 @@ class Model:
1001
1116
  def update_variable(
1002
1117
  self,
1003
1118
  name: str,
1004
- initial_value: float | Derived,
1119
+ initial_value: float | InitialAssignment,
1005
1120
  unit: sympy.Expr | None = None,
1006
1121
  source: str | None = None,
1007
1122
  ) -> Self:
@@ -1035,7 +1150,7 @@ class Model:
1035
1150
  return self
1036
1151
 
1037
1152
  def update_variables(
1038
- self, variables: Mapping[str, float | Derived | Variable]
1153
+ self, variables: Mapping[str, float | Variable | InitialAssignment]
1039
1154
  ) -> Self:
1040
1155
  """Updates multiple variables in the model.
1041
1156
 
@@ -1104,7 +1219,7 @@ class Model:
1104
1219
  return self
1105
1220
 
1106
1221
  ##########################################################################
1107
- # Derived - views
1222
+ # Derived
1108
1223
  ##########################################################################
1109
1224
 
1110
1225
  @property
@@ -1713,11 +1828,28 @@ class Model:
1713
1828
  return copy.deepcopy(self._surrogates)
1714
1829
  return self._surrogates
1715
1830
 
1716
def get_surrogate_output_names(
    self,
    *,
    include_fluxes: bool = True,
) -> list[str]:
    """Return the output names of all surrogates.

    Some surrogate outputs are fluxes (reactions) rather than variables:
    those are the outputs that are also keys of the surrogate's
    ``stoichiometries``. They can optionally be filtered out.

    Args:
        include_fluxes: if True (default), return every surrogate output;
            if False, drop outputs that are keys of the surrogate's
            stoichiometries.

    Returns:
        list[str]: output names, following the order in which surrogates are
            stored and each surrogate's own output order.

    """
    return [
        output
        for surrogate in self._surrogates.values()
        for output in surrogate.outputs
        # short-circuit: stoichiometries are only consulted when filtering
        if include_fluxes or output not in surrogate.stoichiometries
    ]
1722
1854
 
1723
1855
  def get_surrogate_reaction_names(self) -> list[str]:
@@ -1765,7 +1897,8 @@ class Model:
1765
1897
  include_derived_parameters: bool,
1766
1898
  include_derived_variables: bool,
1767
1899
  include_reactions: bool,
1768
- include_surrogate_outputs: bool,
1900
+ include_surrogate_variables: bool,
1901
+ include_surrogate_fluxes: bool,
1769
1902
  include_readouts: bool,
1770
1903
  ) -> list[str]:
1771
1904
  """Get names of all kinds of model components."""
@@ -1782,8 +1915,10 @@ class Model:
1782
1915
  names.extend(self.get_derived_parameter_names())
1783
1916
  if include_reactions:
1784
1917
  names.extend(self.get_reaction_names())
1785
- if include_surrogate_outputs:
1786
- names.extend(self.get_surrogate_output_names())
1918
+ if include_surrogate_variables:
1919
+ names.extend(self.get_surrogate_output_names(include_fluxes=False))
1920
+ if include_surrogate_fluxes:
1921
+ names.extend(self.get_surrogate_reaction_names())
1787
1922
  if include_readouts:
1788
1923
  names.extend(self.get_readout_names())
1789
1924
  return names
@@ -1837,7 +1972,8 @@ class Model:
1837
1972
  include_derived_parameters: bool = True,
1838
1973
  include_derived_variables: bool = True,
1839
1974
  include_reactions: bool = True,
1840
- include_surrogate_outputs: bool = True,
1975
+ include_surrogate_variables: bool = True,
1976
+ include_surrogate_fluxes: bool = True,
1841
1977
  include_readouts: bool = False,
1842
1978
  ) -> pd.Series:
1843
1979
  """Generate a pandas Series of arguments for the model.
@@ -1864,7 +2000,8 @@ class Model:
1864
2000
  include_derived_parameters: Whether to include derived parameters
1865
2001
  include_derived_variables: Whether to include derived variables
1866
2002
  include_reactions: Whether to include reactions
1867
- include_surrogate_outputs: Whether to include surrogate outputs
2003
+ include_surrogate_variables: Whether to include derived variables obtained from surrogates
2004
+ include_surrogate_fluxes: Whether to include surrogate fluxes
1868
2005
  include_readouts: Whether to include readouts
1869
2006
 
1870
2007
  Returns:
@@ -1890,7 +2027,8 @@ class Model:
1890
2027
  include_derived_parameters=include_derived_parameters,
1891
2028
  include_derived_variables=include_derived_variables,
1892
2029
  include_reactions=include_reactions,
1893
- include_surrogate_outputs=include_surrogate_outputs,
2030
+ include_surrogate_variables=include_surrogate_variables,
2031
+ include_surrogate_fluxes=include_surrogate_fluxes,
1894
2032
  include_readouts=include_readouts,
1895
2033
  )
1896
2034
  ]
@@ -1926,7 +2064,8 @@ class Model:
1926
2064
  include_derived_parameters: bool = True,
1927
2065
  include_derived_variables: bool = True,
1928
2066
  include_reactions: bool = True,
1929
- include_surrogate_outputs: bool = True,
2067
+ include_surrogate_variables: bool = True,
2068
+ include_surrogate_fluxes: bool = True,
1930
2069
  include_readouts: bool = False,
1931
2070
  ) -> pd.DataFrame:
1932
2071
  """Generate a DataFrame containing time course arguments for model evaluation.
@@ -1949,7 +2088,8 @@ class Model:
1949
2088
  include_derived_parameters: Whether to include derived parameters
1950
2089
  include_derived_variables: Whether to include derived variables
1951
2090
  include_reactions: Whether to include reactions
1952
- include_surrogate_outputs: Whether to include surrogate outputs
2091
+ include_surrogate_variables: Whether to include variables derived from surrogates
2092
+ include_surrogate_fluxes: Whether to include surrogate fluxes
1953
2093
  include_readouts: Whether to include readouts
1954
2094
 
1955
2095
  Returns:
@@ -1974,7 +2114,8 @@ class Model:
1974
2114
  include_derived_parameters=include_derived_parameters,
1975
2115
  include_derived_variables=include_derived_variables,
1976
2116
  include_reactions=include_reactions,
1977
- include_surrogate_outputs=include_surrogate_outputs,
2117
+ include_surrogate_variables=include_surrogate_variables,
2118
+ include_surrogate_fluxes=include_surrogate_fluxes,
1978
2119
  include_readouts=include_readouts,
1979
2120
  ),
1980
2121
  ]
@@ -2011,15 +2152,19 @@ class Model:
2011
2152
  Fluxes: A pandas Series containing the fluxes for each reaction.
2012
2153
 
2013
2154
  """
2014
- names = self.get_reaction_names()
2015
- names.extend(self.get_surrogate_reaction_names())
2016
-
2017
- args = self.get_args(
2155
+ return self.get_args(
2018
2156
  variables=variables,
2019
2157
  time=time,
2158
+ include_time=False,
2159
+ include_variables=False,
2160
+ include_parameters=False,
2161
+ include_derived_parameters=False,
2162
+ include_derived_variables=False,
2163
+ include_reactions=True,
2164
+ include_surrogate_variables=False,
2165
+ include_surrogate_fluxes=True,
2020
2166
  include_readouts=False,
2021
2167
  )
2022
- return args.loc[names]
2023
2168
 
2024
2169
  def get_fluxes_time_course(self, variables: pd.DataFrame) -> pd.DataFrame:
2025
2170
  """Generate a time course of fluxes for the given reactions and surrogates.
@@ -2029,7 +2174,9 @@ class Model:
2029
2174
  pd.DataFrame({"v1": [0.1, 0.2], "v2": [0.2, 0.3]})
2030
2175
 
2031
2176
  This method calculates the fluxes for each reaction in the model using the provided
2032
- arguments and combines them with the outputs from the surrogates to create a complete
2177
+ arguments and combines them with the outputs from the surrogates to create a complete
2033
2180
  time course of fluxes.
2034
2181
 
2035
2182
  Args:
@@ -2043,21 +2190,23 @@ class Model:
2043
2190
  the index of the input arguments.
2044
2191
 
2045
2192
  """
2046
- names: list[str] = self.get_reaction_names()
2047
- for surrogate in self._surrogates.values():
2048
- names.extend(surrogate.stoichiometries)
2049
-
2050
- variables = self.get_args_time_course(
2193
+ return self.get_args_time_course(
2051
2194
  variables=variables,
2195
+ include_variables=False,
2196
+ include_parameters=False,
2197
+ include_derived_parameters=False,
2198
+ include_derived_variables=False,
2199
+ include_reactions=True,
2200
+ include_surrogate_variables=False,
2201
+ include_surrogate_fluxes=True,
2052
2202
  include_readouts=False,
2053
2203
  )
2054
- return variables.loc[:, names]
2055
2204
 
2056
2205
  ##########################################################################
2057
2206
  # Get rhs
2058
2207
  ##########################################################################
2059
2208
 
2060
- def __call__(self, /, time: float, variables: Array) -> Array:
2209
+ def __call__(self, /, time: float, variables: Iterable[float]) -> tuple[float, ...]:
2061
2210
  """Simulation version of get_right_hand_side.
2062
2211
 
2063
2212
  Examples:
@@ -2065,7 +2214,7 @@ class Model:
2065
2214
  np.array([0.1, 0.2])
2066
2215
 
2067
2216
  Warning: Swaps t and y!
2068
- This can't get kw-only args, as the integrators call it with pos-only
2217
+ This can't get kw args, as the integrators call it with pos-only
2069
2218
 
2070
2219
  Args:
2071
2220
  time: The current time point.
@@ -2091,8 +2240,7 @@ class Model:
2091
2240
  cache=cache,
2092
2241
  )
2093
2242
 
2094
- dxdt = cache.dxdt
2095
- dxdt[:] = 0
2243
+ dxdt = dict.fromkeys(cache.var_names, 0.0)
2096
2244
  for k, stoc in cache.stoich_by_cpds.items():
2097
2245
  for flux, n in stoc.items():
2098
2246
  dxdt[k] += n * dependent[flux]
@@ -2100,7 +2248,7 @@ class Model:
2100
2248
  for flux, dv in sd.items():
2101
2249
  n = dv.calculate(dependent)
2102
2250
  dxdt[k] += n * dependent[flux]
2103
- return cast(Array, dxdt.to_numpy())
2251
+ return tuple(dxdt[i] for i in cache.var_names)
2104
2252
 
2105
2253
  def get_right_hand_side(
2106
2254
  self,
@@ -2148,3 +2296,52 @@ class Model:
2148
2296
  n = dv.fn(*(dependent[i] for i in dv.args))
2149
2297
  dxdt[k] += n * dependent[flux]
2150
2298
  return dxdt
2299
+
2300
+ ##########################################################################
2301
+ # Check units
2302
+ ##########################################################################
2303
+
2304
def check_units(self, time_unit: Quantity) -> UnitCheck:
    """Check unit consistency per differential equation and reaction.

    For every reaction, the units annotated on its arguments (parameters and
    variables) are substituted into the symbolic form of its rate function.
    Then, for every variable that carries a unit, the unit obtained for each
    reaction listed by ``get_stoichiometries_of_variable`` is compared with
    the unit the differential equation requires, ``variable unit / time_unit``.

    Args:
        time_unit: sympy unit quantity of the time axis

    Returns:
        UnitCheck: per variable, a mapping from reaction name to ``True``
            (units match), a ``Failure`` (units differ) or ``None`` (the
            reaction's unit could not be determined because its rate function
            could not be converted to sympy or one of its arguments has no
            unit). Variables without a unit are skipped.

    Raises:
        NotImplementedError: if a reaction argument is neither a parameter nor
            a variable (e.g. a derived value or ``time``), since its unit
            cannot be looked up.

    """
    units_per_fn: dict[str, sympy.Expr | None] = {}
    for name, rxn in self._reactions.items():
        unit_per_arg = {}
        for arg in rxn.args:
            if (par := self._parameters.get(arg)) is not None:
                unit_per_arg[sympy.Symbol(arg)] = par.unit
            elif (var := self._variables.get(arg)) is not None:
                unit_per_arg[sympy.Symbol(arg)] = var.unit
            else:
                msg = (
                    f"Cannot determine unit of argument {arg!r} of reaction "
                    f"{name!r}: only parameters and variables are supported"
                )
                raise NotImplementedError(msg)

        symbolic_fn = fn_to_sympy(
            rxn.fn,
            origin="unit-checking",
            model_args=list_of_symbols(rxn.args),
        )
        units_per_fn[name] = None
        if symbolic_fn is None:
            continue
        # Substituting ``None`` would not yield a unit, so mark it unknown.
        if any(i is None for i in unit_per_arg.values()):
            continue
        units_per_fn[name] = symbolic_fn.subs(unit_per_arg)

    check_per_variable: dict[str, dict[str, bool | Failure | None]] = {}
    for name, variable in self._variables.items():
        # A variable without a unit cannot be checked. Skip it instead of
        # aborting the check for every remaining variable.
        if (var_unit := variable.unit) is None:
            continue

        # The right-hand side of d(variable)/dt must have this unit.
        expected = var_unit / time_unit  # type: ignore
        check_per_rxn: dict[str, bool | Failure | None] = {}
        for rxn in self.get_stoichiometries_of_variable(name):
            if (rxn_unit := units_per_fn.get(rxn)) is None:
                check_per_rxn[rxn] = None
                continue
            obtained = unit_of(rxn_unit)
            if obtained == expected:
                check_per_rxn[rxn] = True
            else:
                check_per_rxn[rxn] = Failure(expected=expected, obtained=obtained)
        check_per_variable[name] = check_per_rxn
    return UnitCheck(check_per_variable)