mxlpy 0.21.0__py3-none-any.whl → 0.23.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
mxlpy/model.py CHANGED
@@ -10,31 +10,159 @@ from __future__ import annotations
10
10
  import copy
11
11
  import inspect
12
12
  import itertools as it
13
+ import logging
13
14
  from dataclasses import dataclass, field
15
+ from queue import Empty, SimpleQueue
14
16
  from typing import TYPE_CHECKING, Self, cast
15
17
 
16
18
  import numpy as np
17
19
  import pandas as pd
20
+ import sympy
18
21
 
19
22
  from mxlpy import fns
20
- from mxlpy.types import AbstractSurrogate, Array, Derived, Reaction, Readout
23
+ from mxlpy.meta.source_tools import fn_to_sympy
24
+ from mxlpy.meta.sympy_tools import (
25
+ list_of_symbols,
26
+ stoichiometries_to_sympy,
27
+ )
28
+ from mxlpy.types import (
29
+ AbstractSurrogate,
30
+ Derived,
31
+ InitialAssignment,
32
+ Parameter,
33
+ Reaction,
34
+ Readout,
35
+ Variable,
36
+ )
21
37
 
22
38
  if TYPE_CHECKING:
23
39
  from collections.abc import Iterable, Mapping
24
40
  from inspect import FullArgSpec
25
41
 
42
+ from sympy.physics.units.quantities import Quantity
43
+
26
44
  from mxlpy.types import Callable, Param, RateFn, RetType
27
45
 
46
+ LOGGER = logging.getLogger(__name__)
47
+
28
48
  __all__ = [
29
49
  "ArityMismatchError",
30
50
  "CircularDependencyError",
31
51
  "Dependency",
52
+ "Failure",
53
+ "LOGGER",
54
+ "MdText",
32
55
  "MissingDependenciesError",
33
56
  "Model",
34
57
  "ModelCache",
58
+ "TableView",
59
+ "UnitCheck",
60
+ "unit_of",
35
61
  ]
36
62
 
37
63
 
64
+ def _latex_view(expr: sympy.Expr | None) -> str:
65
+ if expr is None:
66
+ return "PARSE-ERROR"
67
+ return f"${sympy.latex(expr)}$"
68
+
69
+
70
def unit_of(expr: sympy.Expr) -> sympy.Expr:
    """Return the unit part of a sympy expression.

    ``as_coeff_Mul`` splits a product into its leading numeric coefficient
    and the remaining factor (e.g. ``3*x*y`` -> ``(3, x*y)``).
    The numeric coefficient is discarded and the remaining factor, which
    carries the unit, is returned.
    """
    _coefficient, unit = expr.as_coeff_Mul()
    return unit
73
+
74
+
75
+ @dataclass
76
+ class Failure:
77
+ """Unit test failure."""
78
+
79
+ expected: sympy.Expr
80
+ obtained: sympy.Expr
81
+
82
+ @property
83
+ def difference(self) -> sympy.Expr:
84
+ """Difference between expected and obtained unit."""
85
+ return self.expected / self.obtained # type: ignore
86
+
87
+
88
@dataclass
class MdText:
    """Markdown document assembled from individual lines.

    IPython/Jupyter picks up ``_repr_markdown_`` and renders the lines,
    joined with newlines, as markdown.

    Attributes:
        content: The markdown lines, in display order.

    """

    content: list[str]

    def _repr_markdown_(self) -> str:
        separator = "\n"
        return separator.join(self.content)
96
+
97
+
98
+ @dataclass
99
+ class UnitCheck:
100
+ """Container for unit check."""
101
+
102
+ per_variable: dict[str, dict[str, bool | Failure | None]]
103
+
104
+ @staticmethod
105
+ def _fmt_success(s: str) -> str:
106
+ return f"<span style='color: green'>{s}</span>"
107
+
108
+ @staticmethod
109
+ def _fmt_failed(s: str) -> str:
110
+ return f"<span style='color: red'>{s}</span>"
111
+
112
+ def correct_diff_eqs(self) -> dict[str, bool]:
113
+ """Get all correctly annotated reactions by variable."""
114
+ return {
115
+ var: all(isinstance(i, bool) for i in checks.values())
116
+ for var, checks in self.per_variable.items()
117
+ }
118
+
119
+ def report(self) -> MdText:
120
+ """Export check as markdown report."""
121
+ report = ["## Type check"]
122
+ for diff_eq, res in self.correct_diff_eqs().items():
123
+ txt = self._fmt_success("Correct") if res else self._fmt_failed("Failed")
124
+ report.append(f"\n### d{diff_eq}dt: {txt}")
125
+
126
+ if res:
127
+ continue
128
+ for k, v in self.per_variable[diff_eq].items():
129
+ match v:
130
+ case bool():
131
+ continue
132
+ case None:
133
+ report.append(f"\n- {k}")
134
+ report.append(" - Failed to parse")
135
+ case Failure(expected, obtained):
136
+ report.append(f"\n- {k}")
137
+ report.append(f" - expected: {_latex_view(expected)}")
138
+ report.append(f" - obtained: {_latex_view(obtained)}")
139
+ report.append(f" - difference: {_latex_view(v.difference)}")
140
+
141
+ return MdText(report)
142
+
143
+
144
+ @dataclass(kw_only=True, slots=True)
145
+ class TableView:
146
+ """Markdown view of pandas Dataframe.
147
+
148
+ Mostly used to get nice LaTeX rendering of sympy expressions.
149
+ """
150
+
151
+ data: pd.DataFrame
152
+
153
+ def __repr__(self) -> str:
154
+ """Normal Python shell output."""
155
+ return self.data.to_markdown()
156
+
157
+ def _repr_markdown_(self) -> str:
158
+ """Fancy IPython shell output.
159
+
160
+ Looks the same as __repr__, but is handled by IPython to output
161
+ `IPython.display.Markdown`, so looks nice
162
+ """
163
+ return self.data.to_markdown()
164
+
165
+
38
166
  @dataclass
39
167
  class Dependency:
40
168
  """Container class for building dependency tree."""
@@ -185,8 +313,6 @@ def _sort_dependencies(
185
313
  SortError: If circular dependencies are detected
186
314
 
187
315
  """
188
- from queue import Empty, SimpleQueue
189
-
190
316
  _check_if_is_sortable(available, elements)
191
317
 
192
318
  order = []
@@ -248,12 +374,13 @@ class ModelCache:
248
374
 
249
375
  """
250
376
 
377
+ order: list[str] # mostly for debug purposes
251
378
  var_names: list[str]
252
379
  dyn_order: list[str]
380
+ base_parameter_values: dict[str, float]
253
381
  all_parameter_values: dict[str, float]
254
382
  stoich_by_cpds: dict[str, dict[str, float]]
255
383
  dyn_stoich_by_cpds: dict[str, dict[str, Derived]]
256
- dxdt: pd.Series
257
384
  initial_conditions: dict[str, float]
258
385
 
259
386
 
@@ -274,9 +401,9 @@ class Model:
274
401
 
275
402
  """
276
403
 
277
- _ids: dict[str, str] = field(default_factory=dict)
278
- _variables: dict[str, float | Derived] = field(default_factory=dict)
279
- _parameters: dict[str, float] = field(default_factory=dict)
404
+ _ids: dict[str, str] = field(default_factory=dict, repr=False)
405
+ _variables: dict[str, Variable] = field(default_factory=dict)
406
+ _parameters: dict[str, Parameter] = field(default_factory=dict)
280
407
  _derived: dict[str, Derived] = field(default_factory=dict)
281
408
  _readouts: dict[str, Readout] = field(default_factory=dict)
282
409
  _reactions: dict[str, Reaction] = field(default_factory=dict)
@@ -300,11 +427,32 @@ class Model:
300
427
  ModelCache: An instance of ModelCache containing the initialized cache data.
301
428
 
302
429
  """
303
- all_parameter_values: dict[str, float] = self._parameters.copy()
304
- all_parameter_names: set[str] = set(all_parameter_values)
430
+ parameter_names = set(self._parameters)
431
+ all_parameter_names = set(parameter_names) # later include static derived
432
+
433
+ base_parameter_values: dict[str, float] = {
434
+ k: val
435
+ for k, v in self._parameters.items()
436
+ if not isinstance(val := v.value, InitialAssignment)
437
+ }
438
+ base_variable_values: dict[str, float] = {
439
+ k: init
440
+ for k, v in self._variables.items()
441
+ if not isinstance(init := v.initial_value, InitialAssignment)
442
+ }
443
+ initial_assignments: dict[str, InitialAssignment] = {
444
+ k: init
445
+ for k, v in self._variables.items()
446
+ if isinstance(init := v.initial_value, InitialAssignment)
447
+ } | {
448
+ k: init
449
+ for k, v in self._parameters.items()
450
+ if isinstance(init := v.value, InitialAssignment)
451
+ }
305
452
 
306
453
  # Sanity checks
307
454
  for name, el in it.chain(
455
+ initial_assignments.items(),
308
456
  self._derived.items(),
309
457
  self._reactions.items(),
310
458
  self._readouts.items(),
@@ -313,17 +461,17 @@ class Model:
313
461
  raise ArityMismatchError(name, el.fn, el.args)
314
462
 
315
463
  # Sort derived & reactions
464
+ available = (
465
+ set(base_parameter_values)
466
+ | set(base_variable_values)
467
+ | set(self._data)
468
+ | {"time"}
469
+ )
316
470
  to_sort = (
317
- self._derived
318
- | self._reactions
319
- | self._surrogates
320
- | {k: v for k, v in self._variables.items() if isinstance(v, Derived)}
471
+ initial_assignments | self._derived | self._reactions | self._surrogates
321
472
  )
322
473
  order = _sort_dependencies(
323
- available=all_parameter_names
324
- | {k for k, v in self._variables.items() if not isinstance(v, Derived)}
325
- | set(self._data)
326
- | {"time"},
474
+ available=available,
327
475
  elements=[
328
476
  Dependency(name=k, required=set(v.args), provided={k})
329
477
  if not isinstance(v, AbstractSurrogate)
@@ -335,10 +483,7 @@ class Model:
335
483
  # Calculate all values once, including dynamic ones
336
484
  # That way, we can make initial conditions dependent on e.g. rates
337
485
  dependent = (
338
- all_parameter_values
339
- | self._data
340
- | {k: v for k, v in self._variables.items() if not isinstance(v, Derived)}
341
- | {"time": 0.0}
486
+ base_parameter_values | base_variable_values | self._data | {"time": 0.0}
342
487
  )
343
488
  for name in order:
344
489
  to_sort[name].calculate_inpl(name, dependent)
@@ -349,7 +494,7 @@ class Model:
349
494
  for name in order:
350
495
  if name in self._reactions or name in self._surrogates:
351
496
  dyn_order.append(name)
352
- elif name in self._variables:
497
+ elif name in self._variables or name in self._parameters:
353
498
  static_order.append(name)
354
499
  else:
355
500
  derived = self._derived[name]
@@ -390,27 +535,27 @@ class Model:
390
535
  d_static[rxn_name] = factor
391
536
 
392
537
  var_names = self.get_variable_names()
393
- dxdt = pd.Series(np.zeros(len(var_names), dtype=float), index=var_names)
394
-
395
538
  initial_conditions: dict[str, float] = {
396
- k: v for k, v in self._variables.items() if not isinstance(v, Derived)
539
+ k: cast(float, dependent[k]) for k in self._variables
397
540
  }
541
+ all_parameter_values = dict(base_parameter_values)
398
542
  for name in static_order:
399
543
  if name in self._variables:
400
- initial_conditions[name] = cast(float, dependent[name])
401
- elif name in self._derived:
544
+ continue # handled in initial_conditions above
545
+ if name in self._parameters or name in self._derived:
402
546
  all_parameter_values[name] = cast(float, dependent[name])
403
547
  else:
404
548
  msg = "Unknown target for static derived variable."
405
549
  raise KeyError(msg)
406
550
 
407
551
  self._cache = ModelCache(
552
+ order=order,
408
553
  var_names=var_names,
409
554
  dyn_order=dyn_order,
555
+ base_parameter_values=base_parameter_values,
410
556
  all_parameter_values=all_parameter_values,
411
557
  stoich_by_cpds=stoich_by_compounds,
412
558
  dyn_stoich_by_cpds=dyn_stoich_by_compounds,
413
- dxdt=dxdt,
414
559
  initial_conditions=initial_conditions,
415
560
  )
416
561
  return self._cache
@@ -465,29 +610,103 @@ class Model:
465
610
  del self._ids[name]
466
611
 
467
612
  ##########################################################################
468
- # Parameters
613
+ # Parameters - views
469
614
  ##########################################################################
470
615
 
616
+ @property
617
+ def parameters(self) -> TableView:
618
+ """Return view of parameters."""
619
+ index = list(self._parameters.keys())
620
+ data = []
621
+ for name, el in self._parameters.items():
622
+ if isinstance(init := el.value, InitialAssignment):
623
+ value_str = _latex_view(
624
+ fn_to_sympy(
625
+ init.fn,
626
+ origin=name,
627
+ model_args=list_of_symbols(init.args),
628
+ )
629
+ )
630
+ else:
631
+ value_str = str(init)
632
+ data.append(
633
+ {
634
+ "value": value_str,
635
+ "unit": _latex_view(unit) if (unit := el.unit) is not None else "",
636
+ # "source": ...,
637
+ }
638
+ )
639
+ return TableView(data=pd.DataFrame(data, index=index))
640
+
641
+ def get_raw_parameters(self, *, as_copy: bool = True) -> dict[str, Parameter]:
642
+ """Returns the parameters of the model."""
643
+ if as_copy:
644
+ return copy.deepcopy(self._parameters)
645
+ return self._parameters
646
+
647
+ def get_parameter_values(self) -> dict[str, float]:
648
+ """Returns the parameters of the model.
649
+
650
+ Examples:
651
+ >>> model.parameters
652
+ {"k1": 0.1, "k2": 0.2}
653
+
654
+ Returns:
655
+ parameters: A dictionary where the keys are parameter names (as strings)
656
+ and the values are parameter values (as floats).
657
+
658
+ """
659
+ if (cache := self._cache) is None:
660
+ cache = self._create_cache()
661
+ return cache.base_parameter_values
662
+
663
+ def get_parameter_names(self) -> list[str]:
664
+ """Retrieve the names of the parameters.
665
+
666
+ Examples:
667
+ >>> model.get_parameter_names()
668
+ ['k1', 'k2']
669
+
670
+ Returns:
671
+ parametes: A list containing the names of the parameters.
672
+
673
+ """
674
+ return list(self._parameters)
675
+
676
+ #####################################
677
+ # Parameters - create
678
+ #####################################
679
+
471
680
  @_invalidate_cache
472
- def add_parameter(self, name: str, value: float) -> Self:
681
+ def add_parameter(
682
+ self,
683
+ name: str,
684
+ value: float | InitialAssignment,
685
+ unit: sympy.Expr | None = None,
686
+ source: str | None = None,
687
+ ) -> Self:
473
688
  """Adds a parameter to the model.
474
689
 
475
690
  Examples:
476
691
  >>> model.add_parameter("k1", 0.1)
477
692
 
478
693
  Args:
479
- name (str): The name of the parameter.
480
- value (float): The value of the parameter.
694
+ name: The name of the parameter.
695
+ value: The value of the parameter.
696
+ unit: unit of the parameter
697
+ source: source of the information given
481
698
 
482
699
  Returns:
483
700
  Self: The instance of the model with the added parameter.
484
701
 
485
702
  """
486
703
  self._insert_id(name=name, ctx="parameter")
487
- self._parameters[name] = value
704
+ self._parameters[name] = Parameter(value=value, unit=unit, source=source)
488
705
  return self
489
706
 
490
- def add_parameters(self, parameters: dict[str, float]) -> Self:
707
+ def add_parameters(
708
+ self, parameters: Mapping[str, float | Parameter | InitialAssignment]
709
+ ) -> Self:
491
710
  """Adds multiple parameters to the model.
492
711
 
493
712
  Examples:
@@ -502,36 +721,15 @@ class Model:
502
721
 
503
722
  """
504
723
  for k, v in parameters.items():
505
- self.add_parameter(k, v)
724
+ if isinstance(v, Parameter):
725
+ self.add_parameter(k, v.value, unit=v.unit, source=v.source)
726
+ else:
727
+ self.add_parameter(k, v)
506
728
  return self
507
729
 
508
- @property
509
- def parameters(self) -> dict[str, float]:
510
- """Returns the parameters of the model.
511
-
512
- Examples:
513
- >>> model.parameters
514
- {"k1": 0.1, "k2": 0.2}
515
-
516
- Returns:
517
- parameters: A dictionary where the keys are parameter names (as strings)
518
- and the values are parameter values (as floats).
519
-
520
- """
521
- return self._parameters.copy()
522
-
523
- def get_parameter_names(self) -> list[str]:
524
- """Retrieve the names of the parameters.
525
-
526
- Examples:
527
- >>> model.get_parameter_names()
528
- ['k1', 'k2']
529
-
530
- Returns:
531
- parametes: A list containing the names of the parameters.
532
-
533
- """
534
- return list(self._parameters)
730
+ #####################################
731
+ # Parameters - delete
732
+ #####################################
535
733
 
536
734
  @_invalidate_cache
537
735
  def remove_parameter(self, name: str) -> Self:
@@ -568,8 +766,19 @@ class Model:
568
766
  self.remove_parameter(name)
569
767
  return self
570
768
 
769
+ #####################################
770
+ # Parameters - update
771
+ #####################################
772
+
571
773
  @_invalidate_cache
572
- def update_parameter(self, name: str, value: float) -> Self:
774
+ def update_parameter(
775
+ self,
776
+ name: str,
777
+ value: float | InitialAssignment | None = None,
778
+ *,
779
+ unit: sympy.Expr | None = None,
780
+ source: str | None = None,
781
+ ) -> Self:
573
782
  """Update the value of a parameter.
574
783
 
575
784
  Examples:
@@ -578,6 +787,8 @@ class Model:
578
787
  Args:
579
788
  name: The name of the parameter to update.
580
789
  value: The new value for the parameter.
790
+ unit: Unit of the parameter
791
+ source: Source of the information
581
792
 
582
793
  Returns:
583
794
  Self: The instance of the class with the updated parameter.
@@ -587,12 +798,21 @@ class Model:
587
798
 
588
799
  """
589
800
  if name not in self._parameters:
590
- msg = f"'{name}' not found in parameters"
801
+ msg = f"{name!r} not found in parameters"
591
802
  raise KeyError(msg)
592
- self._parameters[name] = value
803
+
804
+ parameter = self._parameters[name]
805
+ if value is not None:
806
+ parameter.value = value
807
+ if unit is not None:
808
+ parameter.unit = unit
809
+ if source is not None:
810
+ parameter.source = source
593
811
  return self
594
812
 
595
- def update_parameters(self, parameters: dict[str, float]) -> Self:
813
+ def update_parameters(
814
+ self, parameters: Mapping[str, float | Parameter | InitialAssignment]
815
+ ) -> Self:
596
816
  """Update multiple parameters of the model.
597
817
 
598
818
  Examples:
@@ -606,7 +826,10 @@ class Model:
606
826
 
607
827
  """
608
828
  for k, v in parameters.items():
609
- self.update_parameter(k, v)
829
+ if isinstance(v, Parameter):
830
+ self.update_parameter(k, value=v.value, unit=v.unit, source=v.source)
831
+ else:
832
+ self.update_parameter(k, v)
610
833
  return self
611
834
 
612
835
  def scale_parameter(self, name: str, factor: float) -> Self:
@@ -623,7 +846,17 @@ class Model:
623
846
  Self: The instance of the class with the updated parameter.
624
847
 
625
848
  """
626
- return self.update_parameter(name, self._parameters[name] * factor)
849
+ old = self._parameters[name].value
850
+ if isinstance(old, InitialAssignment):
851
+ LOGGER.warning("Overwriting initial assignment %s", name)
852
+ if (cache := self._cache) is None:
853
+ cache = self._create_cache()
854
+
855
+ return self.update_parameter(
856
+ name, cache.all_parameter_values[name] * factor
857
+ )
858
+
859
+ return self.update_parameter(name, old * factor)
627
860
 
628
861
  def scale_parameters(self, parameters: dict[str, float]) -> Self:
629
862
  """Scales the parameters of the model.
@@ -667,7 +900,7 @@ class Model:
667
900
  Self: The instance of the model with the parameter converted to a variable.
668
901
 
669
902
  """
670
- value = self._parameters[name] if initial_value is None else initial_value
903
+ value = self._parameters[name].value if initial_value is None else initial_value
671
904
  self.remove_parameter(name)
672
905
  self.add_variable(name, value)
673
906
 
@@ -708,7 +941,7 @@ class Model:
708
941
  ##########################################################################
709
942
 
710
943
  @property
711
- def variables(self) -> dict[str, float | Derived]:
944
+ def variables(self) -> TableView:
712
945
  """Returns a copy of the variables dictionary.
713
946
 
714
947
  Examples:
@@ -722,10 +955,79 @@ class Model:
722
955
  dict[str, float]: A copy of the variables dictionary.
723
956
 
724
957
  """
725
- return self._variables.copy()
958
+ index = list(self._variables.keys())
959
+ data = []
960
+ for name, el in self._variables.items():
961
+ if isinstance(init := el.initial_value, InitialAssignment):
962
+ value_str = _latex_view(
963
+ fn_to_sympy(
964
+ init.fn,
965
+ origin=name,
966
+ model_args=list_of_symbols(init.args),
967
+ )
968
+ )
969
+ else:
970
+ value_str = str(init)
971
+ data.append(
972
+ {
973
+ "value": value_str,
974
+ "unit": _latex_view(unit) if (unit := el.unit) is not None else "",
975
+ # "source"
976
+ }
977
+ )
978
+ return TableView(data=pd.DataFrame(data, index=index))
979
+
980
+ def get_raw_variables(self, *, as_copy: bool = True) -> dict[str, Variable]:
981
+ """Retrieve the initial conditions of the model.
982
+
983
+ Examples:
984
+ >>> model.get_initial_conditions()
985
+ {"x1": 1.0, "x2": 2.0}
986
+
987
+ Returns:
988
+ initial_conditions: A dictionary where the keys are variable names and the values are their initial conditions.
989
+
990
+ """
991
+ if as_copy:
992
+ return copy.deepcopy(self._variables)
993
+ return self._variables
994
+
995
+ def get_initial_conditions(self) -> dict[str, float]:
996
+ """Retrieve the initial conditions of the model.
997
+
998
+ Examples:
999
+ >>> model.get_initial_conditions()
1000
+ {"x1": 1.0, "x2": 2.0}
1001
+
1002
+ Returns:
1003
+ initial_conditions: A dictionary where the keys are variable names and the values are their initial conditions.
1004
+
1005
+ """
1006
+ if (cache := self._cache) is None:
1007
+ cache = self._create_cache()
1008
+ return cache.initial_conditions
1009
+
1010
+ def get_variable_names(self) -> list[str]:
1011
+ """Retrieve the names of all variables.
1012
+
1013
+ Examples:
1014
+ >>> model.get_variable_names()
1015
+ ["x1", "x2"]
1016
+
1017
+ Returns:
1018
+ variable_names: A list containing the names of all variables.
1019
+
1020
+ """
1021
+ return list(self._variables)
726
1022
 
727
1023
  @_invalidate_cache
728
- def add_variable(self, name: str, initial_condition: float | Derived) -> Self:
1024
+ def add_variable(
1025
+ self,
1026
+ name: str,
1027
+ initial_value: float | InitialAssignment,
1028
+ unit: sympy.Expr | None = None,
1029
+ source: str | None = None,
1030
+ ) -> Self:
729
1031
  """Adds a variable to the model with the given name and initial condition.
730
1032
 
731
1033
  Examples:
@@ -733,17 +1035,23 @@ class Model:
733
1035
 
734
1036
  Args:
735
1037
  name: The name of the variable to add.
736
- initial_condition: The initial condition value for the variable.
1038
+ initial_value: The initial condition value for the variable.
1039
+ unit: unit of the variable
1040
+ source: source of the information given
737
1041
 
738
1042
  Returns:
739
1043
  Self: The instance of the model with the added variable.
740
1044
 
741
1045
  """
742
1046
  self._insert_id(name=name, ctx="variable")
743
- self._variables[name] = initial_condition
1047
+ self._variables[name] = Variable(
1048
+ initial_value=initial_value, unit=unit, source=source
1049
+ )
744
1050
  return self
745
1051
 
746
- def add_variables(self, variables: Mapping[str, float | Derived]) -> Self:
1052
+ def add_variables(
1053
+ self, variables: Mapping[str, float | Variable | InitialAssignment]
1054
+ ) -> Self:
747
1055
  """Adds multiple variables to the model with their initial conditions.
748
1056
 
749
1057
  Examples:
@@ -757,8 +1065,16 @@ class Model:
757
1065
  Self: The instance of the model with the added variables.
758
1066
 
759
1067
  """
760
- for name, y0 in variables.items():
761
- self.add_variable(name=name, initial_condition=y0)
1068
+ for name, v in variables.items():
1069
+ if isinstance(v, Variable):
1070
+ self.add_variable(
1071
+ name=name,
1072
+ initial_value=v.initial_value,
1073
+ unit=v.unit,
1074
+ source=v.source,
1075
+ )
1076
+ else:
1077
+ self.add_variable(name=name, initial_value=v)
762
1078
  return self
763
1079
 
764
1080
  @_invalidate_cache
@@ -797,7 +1113,13 @@ class Model:
797
1113
  return self
798
1114
 
799
1115
  @_invalidate_cache
800
- def update_variable(self, name: str, initial_condition: float | Derived) -> Self:
1116
+ def update_variable(
1117
+ self,
1118
+ name: str,
1119
+ initial_value: float | InitialAssignment,
1120
+ unit: sympy.Expr | None = None,
1121
+ source: str | None = None,
1122
+ ) -> Self:
801
1123
  """Updates the value of a variable in the model.
802
1124
 
803
1125
  Examples:
@@ -805,7 +1127,9 @@ class Model:
805
1127
 
806
1128
  Args:
807
1129
  name: The name of the variable to update.
808
- initial_condition: The initial condition or value to set for the variable.
1130
+ initial_value: The initial condition or value to set for the variable.
1131
+ unit: Unit of the variable
1132
+ source: Source of the information
809
1133
 
810
1134
  Returns:
811
1135
  Self: The instance of the model with the updated variable.
@@ -814,10 +1138,20 @@ class Model:
814
1138
  if name not in self._variables:
815
1139
  msg = f"'{name}' not found in variables"
816
1140
  raise KeyError(msg)
817
- self._variables[name] = initial_condition
1141
+
1142
+ variable = self._variables[name]
1143
+
1144
+ if initial_value is not None:
1145
+ variable.initial_value = initial_value
1146
+ if unit is not None:
1147
+ variable.unit = unit
1148
+ if source is not None:
1149
+ variable.source = source
818
1150
  return self
819
1151
 
820
- def update_variables(self, variables: Mapping[str, float | Derived]) -> Self:
1152
+ def update_variables(
1153
+ self, variables: Mapping[str, float | Variable | InitialAssignment]
1154
+ ) -> Self:
821
1155
  """Updates multiple variables in the model.
822
1156
 
823
1157
  Examples:
@@ -831,37 +1165,17 @@ class Model:
831
1165
 
832
1166
  """
833
1167
  for k, v in variables.items():
834
- self.update_variable(k, v)
1168
+ if isinstance(v, Variable):
1169
+ self.update_variable(
1170
+ k,
1171
+ initial_value=v.initial_value,
1172
+ unit=v.unit,
1173
+ source=v.source,
1174
+ )
1175
+ else:
1176
+ self.update_variable(k, v)
835
1177
  return self
836
1178
 
837
- def get_variable_names(self) -> list[str]:
838
- """Retrieve the names of all variables.
839
-
840
- Examples:
841
- >>> model.get_variable_names()
842
- ["x1", "x2"]
843
-
844
- Returns:
845
- variable_names: A list containing the names of all variables.
846
-
847
- """
848
- return list(self._variables)
849
-
850
- def get_initial_conditions(self) -> dict[str, float]:
851
- """Retrieve the initial conditions of the model.
852
-
853
- Examples:
854
- >>> model.get_initial_conditions()
855
- {"x1": 1.0, "x2": 2.0}
856
-
857
- Returns:
858
- initial_conditions: A dictionary where the keys are variable names and the values are their initial conditions.
859
-
860
- """
861
- if (cache := self._cache) is None:
862
- cache = self._create_cache()
863
- return cache.initial_conditions
864
-
865
1179
  def make_variable_static(self, name: str, value: float | None = None) -> Self:
866
1180
  """Converts a variable to a static parameter.
867
1181
 
@@ -881,8 +1195,12 @@ class Model:
881
1195
  Self: The instance of the class for method chaining.
882
1196
 
883
1197
  """
884
- value_or_derived = self._variables[name] if value is None else value
1198
+ value_or_derived = (
1199
+ self._variables[name].initial_value if value is None else value
1200
+ )
885
1201
  self.remove_variable(name)
1202
+
1203
+ # FIXME: better handling of unit
886
1204
  if isinstance(value_or_derived, Derived):
887
1205
  self.add_derived(name, value_or_derived.fn, args=value_or_derived.args)
888
1206
  else:
@@ -905,8 +1223,8 @@ class Model:
905
1223
  ##########################################################################
906
1224
 
907
1225
  @property
908
- def derived(self) -> dict[str, Derived]:
909
- """Returns a copy of the derived quantities.
1226
+ def derived(self) -> TableView:
1227
+ """Returns a view of the derived quantities.
910
1228
 
911
1229
  Examples:
912
1230
  >>> model.derived
@@ -917,10 +1235,30 @@ class Model:
917
1235
  dict[str, Derived]: A copy of the derived dictionary.
918
1236
 
919
1237
  """
920
- return self._derived.copy()
1238
+ index = list(self._derived.keys())
1239
+ data = [
1240
+ {
1241
+ "value": _latex_view(
1242
+ fn_to_sympy(
1243
+ el.fn,
1244
+ origin=name,
1245
+ model_args=list_of_symbols(el.args),
1246
+ )
1247
+ ),
1248
+ "unit": _latex_view(unit) if (unit := el.unit) is not None else "",
1249
+ }
1250
+ for name, el in self._derived.items()
1251
+ ]
921
1252
 
922
- @property
923
- def derived_variables(self) -> dict[str, Derived]:
1253
+ return TableView(data=pd.DataFrame(data, index=index))
1254
+
1255
+ def get_raw_derived(self, *, as_copy: bool = True) -> dict[str, Derived]:
1256
+ """Get copy of derived values."""
1257
+ if as_copy:
1258
+ return copy.deepcopy(self._derived)
1259
+ return self._derived
1260
+
1261
+ def get_derived_variables(self) -> dict[str, Derived]:
924
1262
  """Returns a dictionary of derived variables.
925
1263
 
926
1264
  Examples:
@@ -940,8 +1278,7 @@ class Model:
940
1278
 
941
1279
  return {k: v for k, v in derived.items() if k not in cache.all_parameter_values}
942
1280
 
943
- @property
944
- def derived_parameters(self) -> dict[str, Derived]:
1281
+ def get_derived_parameters(self) -> dict[str, Derived]:
945
1282
  """Returns a dictionary of derived parameters.
946
1283
 
947
1284
  Examples:
@@ -966,6 +1303,7 @@ class Model:
966
1303
  fn: RateFn,
967
1304
  *,
968
1305
  args: list[str],
1306
+ unit: sympy.Expr | None = None,
969
1307
  ) -> Self:
970
1308
  """Adds a derived attribute to the model.
971
1309
 
@@ -976,13 +1314,14 @@ class Model:
976
1314
  name: The name of the derived attribute.
977
1315
  fn: The function used to compute the derived attribute.
978
1316
  args: The list of arguments to be passed to the function.
1317
+ unit: Unit of the derived value
979
1318
 
980
1319
  Returns:
981
1320
  Self: The instance of the model with the added derived attribute.
982
1321
 
983
1322
  """
984
1323
  self._insert_id(name=name, ctx="derived")
985
- self._derived[name] = Derived(fn=fn, args=args)
1324
+ self._derived[name] = Derived(fn=fn, args=args, unit=unit)
986
1325
  return self
987
1326
 
988
1327
  def get_derived_parameter_names(self) -> list[str]:
@@ -996,7 +1335,7 @@ class Model:
996
1335
  A list of names of the derived parameters.
997
1336
 
998
1337
  """
999
- return list(self.derived_parameters)
1338
+ return list(self.get_derived_parameters())
1000
1339
 
1001
1340
  def get_derived_variable_names(self) -> list[str]:
1002
1341
  """Retrieve the names of derived variables.
@@ -1009,7 +1348,7 @@ class Model:
1009
1348
  A list of names of derived variables.
1010
1349
 
1011
1350
  """
1012
- return list(self.derived_variables)
1351
+ return list(self.get_derived_variables())
1013
1352
 
1014
1353
  @_invalidate_cache
1015
1354
  def update_derived(
@@ -1018,6 +1357,7 @@ class Model:
1018
1357
  fn: RateFn | None = None,
1019
1358
  *,
1020
1359
  args: list[str] | None = None,
1360
+ unit: sympy.Expr | None = None,
1021
1361
  ) -> Self:
1022
1362
  """Updates the derived function and its arguments for a given name.
1023
1363
 
@@ -1026,16 +1366,21 @@ class Model:
1026
1366
 
1027
1367
  Args:
1028
1368
  name: The name of the derived function to update.
1029
- fn: The new derived function. If None, the existing function is retained. Defaults to None.
1030
- args: The new arguments for the derived function. If None, the existing arguments are retained. Defaults to None.
1369
+ fn: The new derived function. If None, the existing function is retained.
1370
+ args: The new arguments for the derived function. If None, the existing arguments are retained.
1371
+ unit: Unit of the derived value
1031
1372
 
1032
1373
  Returns:
1033
1374
  Self: The instance of the class with the updated derived function and arguments.
1034
1375
 
1035
1376
  """
1036
1377
  der = self._derived[name]
1037
- der.fn = der.fn if fn is None else fn
1038
- der.args = der.args if args is None else args
1378
+ if fn is not None:
1379
+ der.fn = fn
1380
+ if args is not None:
1381
+ der.args = args
1382
+ if unit is not None:
1383
+ der.unit = unit
1039
1384
  return self
1040
1385
 
1041
1386
  @_invalidate_cache
@@ -1061,7 +1406,27 @@ class Model:
1061
1406
  ###########################################################################
1062
1407
 
1063
1408
  @property
1064
- def reactions(self) -> dict[str, Reaction]:
1409
+ def reactions(self) -> TableView:
1410
+ """Get view of reactions."""
1411
+ index = list(self._reactions.keys())
1412
+ data = [
1413
+ {
1414
+ "value": _latex_view(
1415
+ fn_to_sympy(
1416
+ rxn.fn,
1417
+ origin=name,
1418
+ model_args=list_of_symbols(rxn.args),
1419
+ )
1420
+ ),
1421
+ "stoichiometry": stoichiometries_to_sympy(name, rxn.stoichiometry),
1422
+ "unit": _latex_view(unit) if (unit := rxn.unit) is not None else "",
1423
+ # "source"
1424
+ }
1425
+ for name, rxn in self._reactions.items()
1426
+ ]
1427
+ return TableView(data=pd.DataFrame(data, index=index))
1428
+
1429
+ def get_raw_reactions(self, *, as_copy: bool = True) -> dict[str, Reaction]:
1065
1430
  """Retrieve the reactions in the model.
1066
1431
 
1067
1432
  Examples:
@@ -1072,7 +1437,9 @@ class Model:
1072
1437
  dict[str, Reaction]: A deep copy of the reactions dictionary.
1073
1438
 
1074
1439
  """
1075
- return copy.deepcopy(self._reactions)
1440
+ if as_copy:
1441
+ return copy.deepcopy(self._reactions)
1442
+ return self._reactions
1076
1443
 
1077
1444
  def get_stoichiometries(
1078
1445
  self, variables: dict[str, float] | None = None, time: float = 0.0
@@ -1091,7 +1458,7 @@ class Model:
1091
1458
  """
1092
1459
  if (cache := self._cache) is None:
1093
1460
  cache = self._create_cache()
1094
- args = self.get_dependent(variables=variables, time=time)
1461
+ args = self.get_args(variables=variables, time=time)
1095
1462
 
1096
1463
  stoich_by_cpds = copy.deepcopy(cache.stoich_by_cpds)
1097
1464
  for cpd, stoich in cache.dyn_stoich_by_cpds.items():
@@ -1121,7 +1488,7 @@ class Model:
1121
1488
  """
1122
1489
  if (cache := self._cache) is None:
1123
1490
  cache = self._create_cache()
1124
- args = self.get_dependent(variables=variables, time=time)
1491
+ args = self.get_args(variables=variables, time=time)
1125
1492
 
1126
1493
  stoich = copy.deepcopy(cache.stoich_by_cpds[variable])
1127
1494
  for rxn, derived in cache.dyn_stoich_by_cpds.get(variable, {}).items():
@@ -1155,6 +1522,8 @@ class Model:
1155
1522
  *,
1156
1523
  args: list[str],
1157
1524
  stoichiometry: Mapping[str, float | str | Derived],
1525
+ unit: sympy.Expr | None = None,
1526
+ # source: str | None = None,
1158
1527
  ) -> Self:
1159
1528
  """Adds a reaction to the model.
1160
1529
 
@@ -1170,6 +1539,7 @@ class Model:
1170
1539
  fn: The function representing the reaction.
1171
1540
  args: A list of arguments for the reaction function.
1172
1541
  stoichiometry: The stoichiometry of the reaction, mapping species to their coefficients.
1542
+ unit: Unit of the rate
1173
1543
 
1174
1544
  Returns:
1175
1545
  Self: The instance of the model with the added reaction.
@@ -1181,7 +1551,12 @@ class Model:
1181
1551
  k: Derived(fn=fns.constant, args=[v]) if isinstance(v, str) else v
1182
1552
  for k, v in stoichiometry.items()
1183
1553
  }
1184
- self._reactions[name] = Reaction(fn=fn, stoichiometry=stoich, args=args)
1554
+ self._reactions[name] = Reaction(
1555
+ fn=fn,
1556
+ stoichiometry=stoich,
1557
+ args=args,
1558
+ unit=unit,
1559
+ )
1185
1560
  return self
1186
1561
 
1187
1562
  def get_reaction_names(self) -> list[str]:
@@ -1205,6 +1580,7 @@ class Model:
1205
1580
  *,
1206
1581
  args: list[str] | None = None,
1207
1582
  stoichiometry: Mapping[str, float | Derived | str] | None = None,
1583
+ unit: sympy.Expr | None = None,
1208
1584
  ) -> Self:
1209
1585
  """Updates the properties of an existing reaction in the model.
1210
1586
 
@@ -1220,6 +1596,7 @@ class Model:
1220
1596
  fn: The new function for the reaction. If None, the existing function is retained.
1221
1597
  args: The new arguments for the reaction. If None, the existing arguments are retained.
1222
1598
  stoichiometry: The new stoichiometry for the reaction. If None, the existing stoichiometry is retained.
1599
+ unit: Unit of the reaction
1223
1600
 
1224
1601
  Returns:
1225
1602
  Self: The instance of the model with the updated reaction.
@@ -1235,6 +1612,7 @@ class Model:
1235
1612
  }
1236
1613
  rxn.stoichiometry = stoich
1237
1614
  rxn.args = rxn.args if args is None else args
1615
+ rxn.unit = rxn.unit if unit is None else unit
1238
1616
  return self
1239
1617
 
1240
1618
  @_invalidate_cache
@@ -1285,7 +1663,14 @@ class Model:
1285
1663
  # Think of something like NADPH / (NADP + NADPH) as a proxy for energy state
1286
1664
  ##########################################################################
1287
1665
 
1288
- def add_readout(self, name: str, fn: RateFn, *, args: list[str]) -> Self:
1666
+ def add_readout(
1667
+ self,
1668
+ name: str,
1669
+ fn: RateFn,
1670
+ *,
1671
+ args: list[str],
1672
+ unit: sympy.Expr | None = None,
1673
+ ) -> Self:
1289
1674
  """Adds a readout to the model.
1290
1675
 
1291
1676
  Examples:
@@ -1298,13 +1683,14 @@ class Model:
1298
1683
  name: The name of the readout.
1299
1684
  fn: The function to be used for the readout.
1300
1685
  args: The list of arguments for the function.
1686
+ unit: Unit of the readout
1301
1687
 
1302
1688
  Returns:
1303
1689
  Self: The instance of the model with the added readout.
1304
1690
 
1305
1691
  """
1306
1692
  self._insert_id(name=name, ctx="readout")
1307
- self._readouts[name] = Readout(fn=fn, args=args)
1693
+ self._readouts[name] = Readout(fn=fn, args=args, unit=unit)
1308
1694
  return self
1309
1695
 
1310
1696
  def get_readout_names(self) -> list[str]:
@@ -1320,6 +1706,12 @@ class Model:
1320
1706
  """
1321
1707
  return list(self._readouts)
1322
1708
 
1709
+ def get_raw_readouts(self, *, as_copy: bool = True) -> dict[str, Readout]:
1710
+ """Get copy of readouts in the model."""
1711
+ if as_copy:
1712
+ return copy.deepcopy(self._readouts)
1713
+ return self._readouts
1714
+
1323
1715
  def remove_readout(self, name: str) -> Self:
1324
1716
  """Remove a readout by its name.
1325
1717
 
@@ -1428,6 +1820,38 @@ class Model:
1428
1820
  self._surrogates.pop(name)
1429
1821
  return self
1430
1822
 
1823
+ def get_raw_surrogates(
1824
+ self, *, as_copy: bool = True
1825
+ ) -> dict[str, AbstractSurrogate]:
1826
+ """Get direct copies of model surrogates."""
1827
+ if as_copy:
1828
+ return copy.deepcopy(self._surrogates)
1829
+ return self._surrogates
1830
+
1831
+ def get_surrogate_output_names(
1832
+ self,
1833
+ *,
1834
+ include_fluxes: bool = True,
1835
+ ) -> list[str]:
1836
+ """Return output names by surrogates.
1837
+
1838
+ Optionally filter out the names of which surrogate outfluxes are actually
1839
+ fluxes / reactions rather than variables.
1840
+
1841
+ Args:
1842
+ include_fluxes: whether to also include outputs which are reaction
1843
+ names
1844
+
1845
+ """
1846
+ names = []
1847
+ if include_fluxes:
1848
+ for i in self._surrogates.values():
1849
+ names.extend(i.outputs)
1850
+ else:
1851
+ for i in self._surrogates.values():
1852
+ names.extend(x for x in i.outputs if x not in i.stoichiometries)
1853
+ return names
1854
+
1431
1855
  def get_surrogate_reaction_names(self) -> list[str]:
1432
1856
  """Return reaction names by surrogates."""
1433
1857
  names = []
@@ -1464,7 +1888,42 @@ class Model:
1464
1888
  # - readouts
1465
1889
  ##########################################################################
1466
1890
 
1467
- def _get_dependent(
1891
+ def get_arg_names(
1892
+ self,
1893
+ *,
1894
+ include_time: bool,
1895
+ include_variables: bool,
1896
+ include_parameters: bool,
1897
+ include_derived_parameters: bool,
1898
+ include_derived_variables: bool,
1899
+ include_reactions: bool,
1900
+ include_surrogate_variables: bool,
1901
+ include_surrogate_fluxes: bool,
1902
+ include_readouts: bool,
1903
+ ) -> list[str]:
1904
+ """Get names of all kinds of model components."""
1905
+ names = []
1906
+ if include_time:
1907
+ names.append("time")
1908
+ if include_variables:
1909
+ names.extend(self.get_variable_names())
1910
+ if include_parameters:
1911
+ names.extend(self.get_parameter_names())
1912
+ if include_derived_variables:
1913
+ names.extend(self.get_derived_variable_names())
1914
+ if include_derived_parameters:
1915
+ names.extend(self.get_derived_parameter_names())
1916
+ if include_reactions:
1917
+ names.extend(self.get_reaction_names())
1918
+ if include_surrogate_variables:
1919
+ names.extend(self.get_surrogate_output_names(include_fluxes=False))
1920
+ if include_surrogate_fluxes:
1921
+ names.extend(self.get_surrogate_reaction_names())
1922
+ if include_readouts:
1923
+ names.extend(self.get_readout_names())
1924
+ return names
1925
+
1926
+ def _get_args(
1468
1927
  self,
1469
1928
  variables: dict[str, float],
1470
1929
  time: float = 0.0,
@@ -1474,7 +1933,7 @@ class Model:
1474
1933
  """Generate a dictionary of model components dependent on other components.
1475
1934
 
1476
1935
  Examples:
1477
- >>> model._get_dependent({"x1": 1.0, "x2": 2.0}, time=0.0)
1936
+ >>> model._get_args({"x1": 1.0, "x2": 2.0}, time=0.0)
1478
1937
  {"x1": 1.0, "x2": 2.0, "k1": 0.1, "time": 0.0}
1479
1938
 
1480
1939
  Args:
@@ -1502,11 +1961,19 @@ class Model:
1502
1961
 
1503
1962
  return cast(dict[str, float], args)
1504
1963
 
1505
- def get_dependent(
1964
+ def get_args(
1506
1965
  self,
1507
1966
  variables: dict[str, float] | None = None,
1508
1967
  time: float = 0.0,
1509
1968
  *,
1969
+ include_time: bool = True,
1970
+ include_variables: bool = True,
1971
+ include_parameters: bool = True,
1972
+ include_derived_parameters: bool = True,
1973
+ include_derived_variables: bool = True,
1974
+ include_reactions: bool = True,
1975
+ include_surrogate_variables: bool = True,
1976
+ include_surrogate_fluxes: bool = True,
1510
1977
  include_readouts: bool = False,
1511
1978
  ) -> pd.Series:
1512
1979
  """Generate a pandas Series of arguments for the model.
@@ -1514,20 +1981,28 @@ class Model:
1514
1981
  Examples:
1515
1982
  # Using initial conditions
1516
1983
  >>> model.get_args()
1517
- {"x1": 1.get_dependent, "x2": 2.0, "k1": 0.1, "time": 0.0}
1984
+ {"x1": 1.get_args, "x2": 2.0, "k1": 0.1, "time": 0.0}
1518
1985
 
1519
1986
  # With custom concentrations
1520
- >>> model.get_dependent({"x1": 1.0, "x2": 2.0})
1987
+ >>> model.get_args({"x1": 1.0, "x2": 2.0})
1521
1988
  {"x1": 1.0, "x2": 2.0, "k1": 0.1, "time": 0.0}
1522
1989
 
1523
1990
  # With custom concentrations and time
1524
- >>> model.get_dependent({"x1": 1.0, "x2": 2.0}, time=1.0)
1991
+ >>> model.get_args({"x1": 1.0, "x2": 2.0}, time=1.0)
1525
1992
  {"x1": 1.0, "x2": 2.0, "k1": 0.1, "time": 1.0}
1526
1993
 
1527
1994
  Args:
1528
1995
  variables: A dictionary where keys are the names of the concentrations and values are their respective float values.
1529
- time: The time point at which the arguments are generated (default is 0.0).
1530
- include_readouts: Whether to include readouts in the arguments (default is False).
1996
+ time: The time point at which the arguments are generated.
1997
+ include_time: Whether to include the time as an argument
1998
+ include_variables: Whether to include variables
1999
+ include_parameters: Whether to include parameters
2000
+ include_derived_parameters: Whether to include derived parameters
2001
+ include_derived_variables: Whether to include derived variables
2002
+ include_reactions: Whether to include reactions
2003
+ include_surrogate_variables: Whether to include derive variables obtained from surrogate
2004
+ include_surrogate_fluxes: Whether to include surrogate fluxes
2005
+ include_readouts: Whether to include readouts
1531
2006
 
1532
2007
  Returns:
1533
2008
  A pandas Series containing the generated arguments with float dtype.
@@ -1535,111 +2010,62 @@ class Model:
1535
2010
  """
1536
2011
  if (cache := self._cache) is None:
1537
2012
  cache = self._create_cache()
1538
-
1539
- args = self._get_dependent(
2013
+ raw = self._get_args(
1540
2014
  variables=self.get_initial_conditions() if variables is None else variables,
1541
2015
  time=time,
1542
2016
  cache=cache,
1543
2017
  )
1544
-
1545
2018
  if include_readouts:
1546
2019
  for name, ro in self._readouts.items(): # FIXME: order?
1547
- ro.calculate_inpl(name, args)
1548
-
1549
- return pd.Series(args, dtype=float)
2020
+ ro.calculate_inpl(name, raw)
2021
+ args = pd.Series(raw, dtype=float)
2022
+ return args.loc[
2023
+ self.get_arg_names(
2024
+ include_time=include_time,
2025
+ include_variables=include_variables,
2026
+ include_parameters=include_parameters,
2027
+ include_derived_parameters=include_derived_parameters,
2028
+ include_derived_variables=include_derived_variables,
2029
+ include_reactions=include_reactions,
2030
+ include_surrogate_variables=include_surrogate_variables,
2031
+ include_surrogate_fluxes=include_surrogate_fluxes,
2032
+ include_readouts=include_readouts,
2033
+ )
2034
+ ]
1550
2035
 
1551
- def get_dependent_time_course(
2036
+ def _get_args_time_course(
1552
2037
  self,
1553
- variables: pd.DataFrame,
1554
2038
  *,
1555
- include_readouts: bool = False,
1556
- ) -> pd.DataFrame:
1557
- """Generate a DataFrame containing time course arguments for model evaluation.
1558
-
1559
- Examples:
1560
- >>> model.get_dependent_time_course(
1561
- ... pd.DataFrame({"x1": [1.0, 2.0], "x2": [2.0, 3.0]}
1562
- ... )
1563
- pd.DataFrame({
1564
- "x1": [1.0, 2.0],
1565
- "x2": [2.0, 3.0],
1566
- "k1": [0.1, 0.1],
1567
- "time": [0.0, 1.0]},
1568
- )
1569
-
1570
- Args:
1571
- variables: A DataFrame containing concentration data with time as the index.
1572
- include_readouts: If True, include readout variables in the resulting DataFrame.
1573
-
1574
- Returns:
1575
- A DataFrame containing the combined concentration data, parameter values,
1576
- derived variables, and optionally readout variables, with time as an additional column.
2039
+ variables: pd.DataFrame,
2040
+ include_readouts: bool,
2041
+ ) -> dict[float, dict[str, float]]:
2042
+ if (cache := self._cache) is None:
2043
+ cache = self._create_cache()
1577
2044
 
1578
- """
1579
- args = {
1580
- time: self.get_dependent(
2045
+ args_by_time = {}
2046
+ for time, values in variables.iterrows():
2047
+ args = self._get_args(
1581
2048
  variables=values.to_dict(),
1582
2049
  time=cast(float, time),
1583
- include_readouts=include_readouts,
2050
+ cache=cache,
1584
2051
  )
1585
- for time, values in variables.iterrows()
1586
- }
1587
-
1588
- return pd.DataFrame(args, dtype=float).T
1589
-
1590
- ##########################################################################
1591
- # Get args
1592
- ##########################################################################
1593
-
1594
- def get_args(
1595
- self,
1596
- variables: dict[str, float] | None = None,
1597
- time: float = 0.0,
1598
- *,
1599
- include_derived: bool = True,
1600
- include_readouts: bool = False,
1601
- ) -> pd.Series:
1602
- """Generate a pandas Series of arguments for the model.
1603
-
1604
- Examples:
1605
- # Using initial conditions
1606
- >>> model.get_args()
1607
- {"x1": 1.0, "x2": 2.0, "k1": 0.1, "time": 0.0}
1608
-
1609
- # With custom concentrations
1610
- >>> model.get_args({"x1": 1.0, "x2": 2.0})
1611
- {"x1": 1.0, "x2": 2.0, "k1": 0.1, "time": 0.0}
1612
-
1613
- # With custom concentrations and time
1614
- >>> model.get_args({"x1": 1.0, "x2": 2.0}, time=1.0)
1615
- {"x1": 1.0, "x2": 2.0, "k1": 0.1, "time": 1.0}
1616
-
1617
- Args:
1618
- variables: A dictionary where keys are the names of the concentrations and values are their respective float values.
1619
- time: The time point at which the arguments are generated.
1620
- include_derived: Whether to include derived variables in the arguments.
1621
- include_readouts: Whether to include readouts in the arguments.
1622
-
1623
- Returns:
1624
- A pandas Series containing the generated arguments with float dtype.
1625
-
1626
- """
1627
- names = self.get_variable_names()
1628
- if include_derived:
1629
- names.extend(self.get_derived_variable_names())
1630
- if include_readouts:
1631
- names.extend(self._readouts)
1632
-
1633
- args = self.get_dependent(
1634
- variables=variables, time=time, include_readouts=include_readouts
1635
- )
1636
- return args.loc[names]
2052
+ if include_readouts:
2053
+ for name, ro in self._readouts.items(): # FIXME: order?
2054
+ ro.calculate_inpl(name, args)
2055
+ args_by_time[time] = args
2056
+ return args_by_time
1637
2057
 
1638
2058
  def get_args_time_course(
1639
2059
  self,
1640
2060
  variables: pd.DataFrame,
1641
2061
  *,
1642
- include_derived: bool = True,
2062
+ include_variables: bool = True,
2063
+ include_parameters: bool = True,
2064
+ include_derived_parameters: bool = True,
2065
+ include_derived_variables: bool = True,
2066
+ include_reactions: bool = True,
2067
+ include_surrogate_variables: bool = True,
2068
+ include_surrogate_fluxes: bool = True,
1643
2069
  include_readouts: bool = False,
1644
2070
  ) -> pd.DataFrame:
1645
2071
  """Generate a DataFrame containing time course arguments for model evaluation.
@@ -1657,22 +2083,42 @@ class Model:
1657
2083
 
1658
2084
  Args:
1659
2085
  variables: A DataFrame containing concentration data with time as the index.
1660
- include_derived: Whether to include derived variables in the arguments.
1661
- include_readouts: If True, include readout variables in the resulting DataFrame.
2086
+ include_variables: Whether to include variables
2087
+ include_parameters: Whether to include parameters
2088
+ include_derived_parameters: Whether to include derived parameters
2089
+ include_derived_variables: Whether to include derived variables
2090
+ include_reactions: Whether to include reactions
2091
+ include_surrogate_variables: Whether to include variables derived from surrogates
2092
+ include_surrogate_fluxes: Whether to include surrogate fluxes
2093
+ include_readouts: Whether to include readouts
1662
2094
 
1663
2095
  Returns:
1664
2096
  A DataFrame containing the combined concentration data, parameter values,
1665
2097
  derived variables, and optionally readout variables, with time as an additional column.
1666
2098
 
1667
2099
  """
1668
- names = self.get_variable_names()
1669
- if include_derived:
1670
- names.extend(self.get_derived_variable_names())
1671
-
1672
- args = self.get_dependent_time_course(
1673
- variables=variables, include_readouts=include_readouts
1674
- )
1675
- return args.loc[:, names]
2100
+ args = pd.DataFrame(
2101
+ self._get_args_time_course(
2102
+ variables=variables,
2103
+ include_readouts=include_readouts,
2104
+ ),
2105
+ dtype=float,
2106
+ ).T
2107
+
2108
+ return args.loc[
2109
+ :,
2110
+ self.get_arg_names(
2111
+ include_time=False,
2112
+ include_variables=include_variables,
2113
+ include_parameters=include_parameters,
2114
+ include_derived_parameters=include_derived_parameters,
2115
+ include_derived_variables=include_derived_variables,
2116
+ include_reactions=include_reactions,
2117
+ include_surrogate_variables=include_surrogate_variables,
2118
+ include_surrogate_fluxes=include_surrogate_fluxes,
2119
+ include_readouts=include_readouts,
2120
+ ),
2121
+ ]
1676
2122
 
1677
2123
  ##########################################################################
1678
2124
  # Get fluxes
@@ -1706,16 +2152,19 @@ class Model:
1706
2152
  Fluxes: A pandas Series containing the fluxes for each reaction.
1707
2153
 
1708
2154
  """
1709
- names = self.get_reaction_names()
1710
- for surrogate in self._surrogates.values():
1711
- names.extend(surrogate.stoichiometries)
1712
-
1713
- args = self.get_dependent(
2155
+ return self.get_args(
1714
2156
  variables=variables,
1715
2157
  time=time,
2158
+ include_time=False,
2159
+ include_variables=False,
2160
+ include_parameters=False,
2161
+ include_derived_parameters=False,
2162
+ include_derived_variables=False,
2163
+ include_reactions=True,
2164
+ include_surrogate_variables=False,
2165
+ include_surrogate_fluxes=True,
1716
2166
  include_readouts=False,
1717
2167
  )
1718
- return args.loc[names]
1719
2168
 
1720
2169
  def get_fluxes_time_course(self, variables: pd.DataFrame) -> pd.DataFrame:
1721
2170
  """Generate a time course of fluxes for the given reactions and surrogates.
@@ -1725,7 +2174,9 @@ class Model:
1725
2174
  pd.DataFrame({"v1": [0.1, 0.2], "v2": [0.2, 0.3]})
1726
2175
 
1727
2176
  This method calculates the fluxes for each reaction in the model using the provided
1728
- arguments and combines them with the outputs from the surrogates to create a complete
2177
+ arguments and combines them wit names: list[str] = self.get_reaction_names()
2178
+ for surrogate in self._surrogates.values():
2179
+ names.extend(surrogate.stoichiometries)h the outputs from the surrogates to create a complete
1729
2180
  time course of fluxes.
1730
2181
 
1731
2182
  Args:
@@ -1739,21 +2190,23 @@ class Model:
1739
2190
  the index of the input arguments.
1740
2191
 
1741
2192
  """
1742
- names = self.get_reaction_names()
1743
- for surrogate in self._surrogates.values():
1744
- names.extend(surrogate.stoichiometries)
1745
-
1746
- variables = self.get_dependent_time_course(
2193
+ return self.get_args_time_course(
1747
2194
  variables=variables,
2195
+ include_variables=False,
2196
+ include_parameters=False,
2197
+ include_derived_parameters=False,
2198
+ include_derived_variables=False,
2199
+ include_reactions=True,
2200
+ include_surrogate_variables=False,
2201
+ include_surrogate_fluxes=True,
1748
2202
  include_readouts=False,
1749
2203
  )
1750
- return variables.loc[:, names]
1751
2204
 
1752
2205
  ##########################################################################
1753
2206
  # Get rhs
1754
2207
  ##########################################################################
1755
2208
 
1756
- def __call__(self, /, time: float, variables: Array) -> Array:
2209
+ def __call__(self, /, time: float, variables: Iterable[float]) -> tuple[float, ...]:
1757
2210
  """Simulation version of get_right_hand_side.
1758
2211
 
1759
2212
  Examples:
@@ -1761,7 +2214,7 @@ class Model:
1761
2214
  np.array([0.1, 0.2])
1762
2215
 
1763
2216
  Warning: Swaps t and y!
1764
- This can't get kw-only args, as the integrators call it with pos-only
2217
+ This can't get kw args, as the integrators call it with pos-only
1765
2218
 
1766
2219
  Args:
1767
2220
  time: The current time point.
@@ -1781,14 +2234,13 @@ class Model:
1781
2234
  strict=True,
1782
2235
  )
1783
2236
  )
1784
- dependent: dict[str, float] = self._get_dependent(
2237
+ dependent: dict[str, float] = self._get_args(
1785
2238
  variables=vars_d,
1786
2239
  time=time,
1787
2240
  cache=cache,
1788
2241
  )
1789
2242
 
1790
- dxdt = cache.dxdt
1791
- dxdt[:] = 0
2243
+ dxdt = dict.fromkeys(cache.var_names, 0.0)
1792
2244
  for k, stoc in cache.stoich_by_cpds.items():
1793
2245
  for flux, n in stoc.items():
1794
2246
  dxdt[k] += n * dependent[flux]
@@ -1796,7 +2248,7 @@ class Model:
1796
2248
  for flux, dv in sd.items():
1797
2249
  n = dv.calculate(dependent)
1798
2250
  dxdt[k] += n * dependent[flux]
1799
- return cast(Array, dxdt.to_numpy())
2251
+ return tuple(dxdt[i] for i in cache.var_names)
1800
2252
 
1801
2253
  def get_right_hand_side(
1802
2254
  self,
@@ -1829,7 +2281,7 @@ class Model:
1829
2281
  if (cache := self._cache) is None:
1830
2282
  cache = self._create_cache()
1831
2283
  var_names = self.get_variable_names()
1832
- dependent = self._get_dependent(
2284
+ dependent = self._get_args(
1833
2285
  variables=self.get_initial_conditions() if variables is None else variables,
1834
2286
  time=time,
1835
2287
  cache=cache,
@@ -1844,3 +2296,52 @@ class Model:
1844
2296
  n = dv.fn(*(dependent[i] for i in dv.args))
1845
2297
  dxdt[k] += n * dependent[flux]
1846
2298
  return dxdt
2299
+
2300
+ ##########################################################################
2301
+ # Check units
2302
+ ##########################################################################
2303
+
2304
+ def check_units(self, time_unit: Quantity) -> UnitCheck:
2305
+ """Check unit consistency per differential equation and reaction."""
2306
+ units_per_fn = {}
2307
+ for name, rxn in self._reactions.items():
2308
+ unit_per_arg = {}
2309
+ for arg in rxn.args:
2310
+ if (par := self._parameters.get(arg)) is not None:
2311
+ unit_per_arg[sympy.Symbol(arg)] = par.unit
2312
+ elif (var := self._variables.get(arg)) is not None:
2313
+ unit_per_arg[sympy.Symbol(arg)] = var.unit
2314
+ else:
2315
+ raise NotImplementedError
2316
+
2317
+ symbolic_fn = fn_to_sympy(
2318
+ rxn.fn,
2319
+ origin="unit-checking",
2320
+ model_args=list_of_symbols(rxn.args),
2321
+ )
2322
+ units_per_fn[name] = None
2323
+ if symbolic_fn is None:
2324
+ continue
2325
+ if any(i is None for i in unit_per_arg.values()):
2326
+ continue
2327
+ units_per_fn[name] = symbolic_fn.subs(unit_per_arg)
2328
+
2329
+ check_per_variable = {}
2330
+ for name, var in self._variables.items():
2331
+ check_per_rxn = {}
2332
+
2333
+ if (var_unit := var.unit) is None:
2334
+ break
2335
+
2336
+ for rxn in self.get_stoichiometries_of_variable(name):
2337
+ if (rxn_unit := units_per_fn.get(rxn)) is None:
2338
+ check_per_rxn[rxn] = None
2339
+ elif unit_of(rxn_unit) == var_unit / time_unit: # type: ignore
2340
+ check_per_rxn[rxn] = True
2341
+ else:
2342
+ check_per_rxn[rxn] = Failure(
2343
+ expected=unit_of(rxn_unit),
2344
+ obtained=var_unit / time_unit, # type: ignore
2345
+ )
2346
+ check_per_variable[name] = check_per_rxn
2347
+ return UnitCheck(check_per_variable)