mxlpy 0.21.0__py3-none-any.whl → 0.22.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
mxlpy/simulator.py CHANGED
@@ -10,7 +10,7 @@ Classes:
 
 from __future__ import annotations
 
-import warnings
+import logging
 from dataclasses import dataclass, field
 from typing import TYPE_CHECKING, Literal, Self, cast, overload
 
@@ -27,6 +27,8 @@ if TYPE_CHECKING:
     from mxlpy.model import Model
     from mxlpy.types import Array, ArrayLike, IntegratorProtocol, IntegratorType
 
+_LOGGER = logging.getLogger(__name__)
+
 __all__ = [
     "Result",
     "Simulator",
@@ -67,15 +69,15 @@ class Result:
     """Simulation results."""
 
     model: Model
-    _raw_variables: list[pd.DataFrame]
-    _parameters: list[dict[str, float]]
-    _dependent: list[pd.DataFrame] = field(default_factory=list)
+    raw_variables: list[pd.DataFrame]
+    raw_parameters: list[dict[str, float]]
+    raw_args: list[pd.DataFrame] = field(default_factory=list)
 
     @property
     def variables(self) -> pd.DataFrame:
         """Simulation variables."""
         return self.get_variables(
-            include_derived=True,
+            include_derived_variables=True,
             include_readouts=True,
             concatenated=True,
             normalise=None,
@@ -90,49 +92,50 @@ class Result:
         """Iterate over the concentration and flux response coefficients."""
         return iter((self.variables, self.fluxes))
 
-    def _get_dependent(
-        self,
-        *,
-        include_readouts: bool = True,
-    ) -> list[pd.DataFrame]:
+    def _compute_args(self) -> list[pd.DataFrame]:
         # Already computed
-        if len(self._dependent) > 0:
-            return self._dependent
+        if len(self.raw_args) > 0:
+            return self.raw_args
 
         # Compute new otherwise
-        for res, p in zip(self._raw_variables, self._parameters, strict=True):
+        for res, p in zip(self.raw_variables, self.raw_parameters, strict=True):
             self.model.update_parameters(p)
-            self._dependent.append(
-                self.model.get_dependent_time_course(
+            self.raw_args.append(
+                self.model.get_args_time_course(
                     variables=res,
-                    include_readouts=include_readouts,
+                    include_variables=True,
+                    include_parameters=True,
+                    include_derived_parameters=True,
+                    include_derived_variables=True,
+                    include_reactions=True,
+                    include_surrogate_outputs=True,
+                    include_readouts=True,
                 )
             )
-        return self._dependent
+        return self.raw_args
 
-    def _select_variables(
+    def _select_data(
         self,
         dependent: list[pd.DataFrame],
         *,
-        include_derived: bool,
-        include_readouts: bool,
+        include_variables: bool = False,
+        include_parameters: bool = False,
+        include_derived_parameters: bool = False,
+        include_derived_variables: bool = False,
+        include_reactions: bool = False,
+        include_surrogate_outputs: bool = False,
+        include_readouts: bool = False,
     ) -> list[pd.DataFrame]:
-        names = self.model.get_variable_names()
-        if include_derived:
-            names.extend(self.model.get_derived_variable_names())
-        if include_readouts:
-            names.extend(self.model.get_readout_names())
-        return [i.loc[:, names] for i in dependent]
-
-    def _select_fluxes(
-        self,
-        dependent: list[pd.DataFrame],
-        *,
-        include_surrogates: bool,
-    ) -> list[pd.DataFrame]:
-        names = self.model.get_reaction_names()
-        if include_surrogates:
-            names.extend(self.model.get_surrogate_reaction_names())
+        names = self.model.get_arg_names(
+            include_time=False,
+            include_variables=include_variables,
+            include_parameters=include_parameters,
+            include_derived_parameters=include_derived_parameters,
+            include_derived_variables=include_derived_variables,
+            include_reactions=include_reactions,
+            include_surrogate_outputs=include_surrogate_outputs,
+            include_readouts=include_readouts,
+        )
         return [i.loc[:, names] for i in dependent]
 
     def _adjust_data(
@@ -148,11 +151,93 @@ class Result:
             return pd.concat(data, axis=0)
         return data
 
+    @overload
+    def get_args(  # type: ignore
+        self,
+        *,
+        include_variables: bool = True,
+        include_parameters: bool = False,
+        include_derived_parameters: bool = False,
+        include_derived_variables: bool = True,
+        include_reactions: bool = True,
+        include_surrogate_outputs: bool = False,
+        include_readouts: bool = False,
+        concatenated: Literal[False],
+        normalise: float | ArrayLike | None = None,
+    ) -> list[pd.DataFrame]: ...
+
+    @overload
+    def get_args(
+        self,
+        *,
+        include_variables: bool = True,
+        include_parameters: bool = False,
+        include_derived_parameters: bool = False,
+        include_derived_variables: bool = True,
+        include_reactions: bool = True,
+        include_surrogate_outputs: bool = False,
+        include_readouts: bool = False,
+        concatenated: Literal[True],
+        normalise: float | ArrayLike | None = None,
+    ) -> pd.DataFrame: ...
+
+    @overload
+    def get_args(
+        self,
+        *,
+        include_variables: bool = True,
+        include_parameters: bool = False,
+        include_derived_parameters: bool = False,
+        include_derived_variables: bool = True,
+        include_reactions: bool = True,
+        include_surrogate_outputs: bool = False,
+        include_readouts: bool = False,
+        concatenated: bool = True,
+        normalise: float | ArrayLike | None = None,
+    ) -> pd.DataFrame: ...
+
+    def get_args(
+        self,
+        *,
+        include_variables: bool = True,
+        include_parameters: bool = False,
+        include_derived_parameters: bool = False,
+        include_derived_variables: bool = True,
+        include_reactions: bool = True,
+        include_surrogate_outputs: bool = False,
+        include_readouts: bool = False,
+        concatenated: bool = True,
+        normalise: float | ArrayLike | None = None,
+    ) -> pd.DataFrame | list[pd.DataFrame]:
+        """Get the variables over time.
+
+        Examples:
+            >>> Result().get_variables()
+            Time            ATP      NADPH
+            0.000000   1.000000   1.000000
+            0.000100   0.999900   0.999900
+            0.000200   0.999800   0.999800
+
+        """
+        variables = self._select_data(
+            self._compute_args(),
+            include_variables=include_variables,
+            include_parameters=include_parameters,
+            include_derived_parameters=include_derived_parameters,
+            include_derived_variables=include_derived_variables,
+            include_reactions=include_reactions,
+            include_surrogate_outputs=include_surrogate_outputs,
+            include_readouts=include_readouts,
+        )
+        return self._adjust_data(
+            variables, normalise=normalise, concatenated=concatenated
+        )
+
     @overload
     def get_variables(  # type: ignore
         self,
         *,
-        include_derived: bool = True,
+        include_derived_variables: bool = True,
         include_readouts: bool = True,
         concatenated: Literal[False],
         normalise: float | ArrayLike | None = None,
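A usage sketch for the new Result.get_args accessor, assuming `model` is an existing mxlpy Model and that the simulator exposes its Result via get_result() (that accessor's name is not visible in this diff):

from mxlpy import Simulator  # assumed import path

res = Simulator(model).simulate(10.0).get_result()  # get_result() is an assumption

# One wide frame with variables, derived quantities, parameters and fluxes over time.
args = res.get_args(
    include_parameters=True,
    include_derived_parameters=True,
    include_readouts=True,
)

# Or keep one DataFrame per simulation call instead of concatenating them.
per_run = res.get_args(concatenated=False)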
@@ -162,7 +247,7 @@ class Result:
     def get_variables(
         self,
         *,
-        include_derived: bool = True,
+        include_derived_variables: bool = True,
         include_readouts: bool = True,
         concatenated: Literal[True],
         normalise: float | ArrayLike | None = None,
@@ -172,7 +257,7 @@ class Result:
     def get_variables(
         self,
         *,
-        include_derived: bool = True,
+        include_derived_variables: bool = True,
         include_readouts: bool = True,
         concatenated: bool = True,
         normalise: float | ArrayLike | None = None,
@@ -181,7 +266,7 @@ class Result:
     def get_variables(
         self,
         *,
-        include_derived: bool = True,
+        include_derived_variables: bool = True,
         include_readouts: bool = True,
         concatenated: bool = True,
         normalise: float | ArrayLike | None = None,
@@ -196,16 +281,17 @@ class Result:
             0.000200   0.999800   0.999800
 
         """
-        if not include_derived and not include_readouts:
+        if not include_derived_variables and not include_readouts:
             return self._adjust_data(
-                self._raw_variables,
+                self.raw_variables,
                 normalise=normalise,
                 concatenated=concatenated,
             )
 
-        variables = self._select_variables(
-            self._get_dependent(),
-            include_derived=include_derived,
+        variables = self._select_data(
+            self._compute_args(),
+            include_variables=True,
+            include_derived_variables=include_derived_variables,
             include_readouts=include_readouts,
         )
         return self._adjust_data(
@@ -259,9 +345,10 @@ class Result:
             pd.DataFrame: DataFrame of fluxes.
 
         """
-        fluxes = self._select_fluxes(
-            self._get_dependent(),
-            include_surrogates=include_surrogates,
+        fluxes = self._select_data(
+            self._compute_args(),
+            include_reactions=True,
+            include_surrogate_outputs=include_surrogates,
         )
         return self._adjust_data(
             fluxes,
@@ -295,7 +382,7 @@ class Result:
         """
         return dict(
             self.get_variables(
-                include_derived=False,
+                include_derived_variables=False,
                 include_readouts=False,
             ).iloc[-1]
         )
@@ -386,7 +473,7 @@ class Simulator:
             )
 
         except Exception as e:  # noqa: BLE001
-            warnings.warn(str(e), stacklevel=2)
+            _LOGGER.warning(str(e), stacklevel=2)
 
         y0 = self.y0
         self.integrator = self._integrator_type(
@@ -431,7 +518,7 @@ class Simulator:
         # model._get_rhs sorts the return array by model.get_variable_names()
         # Do NOT change this ordering
         results_df = pd.DataFrame(
-            results,
+            data=results,
            index=time,
            columns=self.model.get_variable_names(),
        )
@@ -445,7 +532,7 @@ class Simulator:
 
         if self.simulation_parameters is None:
             self.simulation_parameters = []
-        self.simulation_parameters.append(self.model.parameters)
+        self.simulation_parameters.append(self.model.get_parameter_values())
 
     def simulate(
         self,
@@ -472,6 +559,13 @@ class Simulator:
         if self._time_shift is not None:
             t_end -= self._time_shift
 
+        prior_t_end: float = (
+            0.0 if (variables := self.variables) is None else variables[-1].index[-1]
+        )
+        if t_end <= prior_t_end:
+            msg = "End time point has to be larger than previous end time point"
+            raise ValueError(msg)
+
         time, results = self.integrator.integrate(t_end=t_end, steps=steps)
 
         self._handle_simulation_results(time, results, skipfirst=True)
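With the new guard, consecutive simulate() calls must move time forward. A sketch of the intended behaviour, assuming `model` is an existing mxlpy Model:

s = Simulator(model)
s.simulate(10.0)  # first run: 0 -> 10
s.simulate(20.0)  # continues the run: 10 -> 20
s.simulate(15.0)  # raises ValueError: end time must exceed the previous end time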
@@ -495,17 +589,33 @@ class Simulator:
             Self: The Simulator instance with updated results.
 
         """
+        time_points = np.array(time_points, dtype=float)
+
         if self._time_shift is not None:
-            time_points = np.array(time_points, dtype=float)
             time_points -= self._time_shift
 
+        # Check if end is actually larger
+        prior_t_end: float = (
+            0.0 if (variables := self.variables) is None else variables[-1].index[-1]
+        )
+        if time_points[-1] <= prior_t_end:
+            msg = "End time point has to be larger than previous end time point"
+            raise ValueError(msg)
+
+        # Remove points which are smaller than previous t_end
+        if not (larger := time_points >= prior_t_end).all():
+            msg = f"Overlapping time points. Removing: {time_points[~larger]}"
+            _LOGGER.warning(msg)
+            time_points = time_points[larger]
+
         time, results = self.integrator.integrate_time_course(time_points=time_points)
         self._handle_simulation_results(time, results, skipfirst=True)
         return self
 
-    def simulate_over_protocol(
+    def simulate_protocol(
         self,
         protocol: pd.DataFrame,
+        *,
         time_points_per_step: int = 10,
     ) -> Self:
         """Simulate the model over a given protocol.
@@ -524,10 +634,81 @@ class Simulator:
             The Simulator instance with updated results.
 
         """
+        t_start = (
+            0.0 if (variables := self.variables) is None else variables[-1].index[-1]
+        )
+
         for t_end, pars in protocol.iterrows():
             t_end = cast(pd.Timedelta, t_end)
             self.model.update_parameters(pars.to_dict())
-            self.simulate(t_end.total_seconds(), steps=time_points_per_step)
+            self.simulate(t_start + t_end.total_seconds(), steps=time_points_per_step)
+            if self.variables is None:
+                break
+        return self
+
+    def simulate_protocol_time_course(
+        self,
+        protocol: pd.DataFrame,
+        time_points: ArrayLike,
+        *,
+        time_points_as_relative: bool = False,
+    ) -> Self:
+        """Simulate the model over a given protocol.
+
+        Examples:
+            >>> Simulator(model).simulate_over_protocol(
+            ...     protocol,
+            ...     time_points=np.array([1.0, 2.0, 3.0], dtype=float),
+            ... )
+
+        Args:
+            protocol: DataFrame containing the protocol.
+            time_points: Array of time points for which to return the simulation values.
+            time_points_as_relative: Interpret time points as relative time
+
+        Notes:
+            This function will return **both** the control points of the protocol as well
+            as the time points supplied in case they don't match.
+
+        Returns:
+            The Simulator instance with updated results.
+
+        """
+        t_start = (
+            0.0 if (variables := self.variables) is None else variables[-1].index[-1]
+        )
+
+        protocol = protocol.copy()
+        protocol.index = (
+            cast(pd.TimedeltaIndex, protocol.index) + pd.Timedelta(t_start, unit="s")
+        ).total_seconds()
+
+        time_points = np.array(time_points, dtype=float)
+        if time_points_as_relative:
+            time_points += t_start
+
+        # Error handling
+        if time_points[-1] <= t_start:
+            msg = "End time point has to be larger than previous end time point"
+            raise ValueError(msg)
+
+        larger = time_points > protocol.index[-1]
+        if any(larger):
+            msg = f"Ignoring time points outside of protocol range:\n {time_points[larger]}"
+            _LOGGER.warning(msg)
+
+        # Continue with logic
+        full_time_points = protocol.index.join(time_points, how="outer")
+
+        for t_end, pars in protocol.iterrows():
+            self.model.update_parameters(pars.to_dict())
+
+            self.simulate_time_course(
+                time_points=full_time_points[
+                    (full_time_points > t_start) & (full_time_points <= t_end)
+                ]
+            )
+            t_start = t_end
             if self.variables is None:
                 break
         return self
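A sketch of the renamed protocol API, assuming `model` is an existing mxlpy Model with a parameter named "k1" (both placeholders). As in the loop above, the protocol is a DataFrame of parameter values indexed by Timedelta end times:

import numpy as np
import pandas as pd

# Hypothetical two-phase protocol: k1 = 1.0 until 60 s, then k1 = 2.0 until 120 s.
protocol = pd.DataFrame(
    {"k1": [1.0, 2.0]},
    index=pd.to_timedelta([60, 120], unit="s"),
)

Simulator(model).simulate_protocol(protocol, time_points_per_step=10)

# Or request explicit output times, interpreted relative to the current end of the run.
Simulator(model).simulate_protocol_time_course(
    protocol,
    time_points=np.linspace(0.0, 120.0, 121),
    time_points_as_relative=True,
)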
@@ -590,8 +771,8 @@ class Simulator:
             return None
         return Result(
             model=self.model,
-            _raw_variables=variables,
-            _parameters=parameters,
+            raw_variables=variables,
+            raw_parameters=parameters,
         )
 
     def update_parameter(self, parameter: str, value: float) -> Self:
@@ -4,15 +4,13 @@
 from __future__ import annotations
 
 from dataclasses import dataclass
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, cast
 
 import sympy
 
-from mxlpy.meta.source_tools import fn_to_sympy
+from mxlpy.meta.sympy_tools import fn_to_sympy, list_of_symbols
 
 if TYPE_CHECKING:
-    from collections.abc import Iterable
-
     from mxlpy.model import Model
 
 __all__ = [
@@ -36,30 +34,44 @@ class SymbolicModel:
         )
 
 
-def _list_of_symbols(args: Iterable[str]) -> list[sympy.Symbol]:
-    return [sympy.Symbol(arg) for arg in args]
-
-
 def to_symbolic_model(model: Model) -> SymbolicModel:
     cache = model._create_cache()  # noqa: SLF001
+    initial_conditions = model.get_initial_conditions()
 
     variables: dict[str, sympy.Symbol] = dict(
-        zip(model.variables, _list_of_symbols(model.variables), strict=True)
+        zip(
+            initial_conditions,
+            cast(list[sympy.Symbol], list_of_symbols(initial_conditions)),
+            strict=True,
+        )
     )
     parameters: dict[str, sympy.Symbol] = dict(
-        zip(model.parameters, _list_of_symbols(model.parameters), strict=True)
+        zip(
+            model.get_parameter_values(),
+            cast(list[sympy.Symbol], list_of_symbols(model.get_parameter_values())),
+            strict=True,
+        )
    )
    symbols: dict[str, sympy.Symbol | sympy.Expr] = variables | parameters  # type: ignore
 
     # Insert derived into symbols
-    for k, v in model.derived.items():
-        symbols[k] = fn_to_sympy(v.fn, [symbols[i] for i in v.args])
+    for k, v in model.get_raw_derived().items():
+        if (
+            expr := fn_to_sympy(v.fn, origin=k, model_args=[symbols[i] for i in v.args])
+        ) is None:
+            msg = f"Unable to parse derived value '{k}'"
+            raise ValueError(msg)
+        symbols[k] = expr
 
     # Insert derived into reaction via args
-    rxns = {
-        k: fn_to_sympy(v.fn, [symbols[i] for i in v.args])
-        for k, v in model.reactions.items()
-    }
+    rxns: dict[str, sympy.Expr] = {}
+    for k, v in model.get_raw_reactions().items():
+        if (
+            expr := fn_to_sympy(v.fn, origin=k, model_args=[symbols[i] for i in v.args])
+        ) is None:
+            msg = f"Unable to parse reaction '{k}'"
+            raise ValueError(msg)
+        rxns[k] = expr
 
     eqs: dict[str, sympy.Expr] = {}
     for cpd, stoich in cache.stoich_by_cpds.items():
@@ -80,5 +92,5 @@ def to_symbolic_model(model: Model) -> SymbolicModel:
         parameters=parameters,
         eqs=[eqs[i] for i in cache.var_names],
         initial_conditions=model.get_initial_conditions(),
-        parameter_values=model.parameters,
+        parameter_values=model.get_parameter_values(),
     )
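A short sketch of the updated converter, assuming `model` is an existing mxlpy Model; the import path of to_symbolic_model is not shown in this diff:

sym = to_symbolic_model(model)
print(sym.eqs)               # one sympy expression per model variable
print(sym.parameter_values)  # now taken from model.get_parameter_values()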