mxlpy 0.21.0__py3-none-any.whl → 0.23.0__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the package versions exactly as they appear in that registry.
mxlpy/simulator.py CHANGED
@@ -10,9 +10,9 @@ Classes:
 
 from __future__ import annotations
 
-import warnings
-from dataclasses import dataclass, field
-from typing import TYPE_CHECKING, Literal, Self, cast, overload
+import logging
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Self, cast
 
 import numpy as np
 import pandas as pd
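Note on the import swap above: 0.23.0 replaces `warnings.warn` with a module-level logger (`_LOGGER`, introduced below), so integration problems are reported through the standard `logging` hierarchy instead of the warnings filter. A minimal sketch of what this means for downstream code, using only stdlib behaviour:

    import logging

    # The simulator logs under its module name ("mxlpy.simulator" follows
    # from logging.getLogger(__name__)), so the usual logger hierarchy applies.
    logging.basicConfig(level=logging.INFO)
    logging.getLogger("mxlpy.simulator").setLevel(logging.ERROR)  # mute warnings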
@@ -20,287 +20,19 @@ from sympy import lambdify
 
 from mxlpy.integrators import DefaultIntegrator
 from mxlpy.symbolic import to_symbolic_model
+from mxlpy.types import Result
 
 if TYPE_CHECKING:
-    from collections.abc import Iterator
-
     from mxlpy.model import Model
     from mxlpy.types import Array, ArrayLike, IntegratorProtocol, IntegratorType
 
+_LOGGER = logging.getLogger(__name__)
+
 __all__ = [
-    "Result",
     "Simulator",
 ]
 
 
-def _normalise_split_results(
-    results: list[pd.DataFrame],
-    normalise: float | ArrayLike,
-) -> list[pd.DataFrame]:
-    """Normalize split results by a given factor or array.
-
-    Args:
-        results: List of DataFrames containing the results to normalize.
-        normalise: Normalization factor or array.
-
-    Returns:
-        list[pd.DataFrame]: List of normalized DataFrames.
-
-    """
-    if isinstance(normalise, int | float):
-        return [i / normalise for i in results]
-    if len(normalise) == len(results):
-        return [(i.T / j).T for i, j in zip(results, normalise, strict=True)]
-
-    results = []
-    start = 0
-    end = 0
-    for i in results:
-        end += len(i)
-        results.append(i / np.reshape(normalise[start:end], (len(i), 1)))  # type: ignore
-        start += end
-    return results
-
-
-@dataclass(kw_only=True, slots=True)
-class Result:
-    """Simulation results."""
-
-    model: Model
-    _raw_variables: list[pd.DataFrame]
-    _parameters: list[dict[str, float]]
-    _dependent: list[pd.DataFrame] = field(default_factory=list)
-
-    @property
-    def variables(self) -> pd.DataFrame:
-        """Simulation variables."""
-        return self.get_variables(
-            include_derived=True,
-            include_readouts=True,
-            concatenated=True,
-            normalise=None,
-        )
-
-    @property
-    def fluxes(self) -> pd.DataFrame:
-        """Simulation fluxes."""
-        return self.get_fluxes()
-
-    def __iter__(self) -> Iterator[pd.DataFrame]:
-        """Iterate over the concentration and flux response coefficients."""
-        return iter((self.variables, self.fluxes))
-
-    def _get_dependent(
-        self,
-        *,
-        include_readouts: bool = True,
-    ) -> list[pd.DataFrame]:
-        # Already computed
-        if len(self._dependent) > 0:
-            return self._dependent
-
-        # Compute new otherwise
-        for res, p in zip(self._raw_variables, self._parameters, strict=True):
-            self.model.update_parameters(p)
-            self._dependent.append(
-                self.model.get_dependent_time_course(
-                    variables=res,
-                    include_readouts=include_readouts,
-                )
-            )
-        return self._dependent
-
-    def _select_variables(
-        self,
-        dependent: list[pd.DataFrame],
-        *,
-        include_derived: bool,
-        include_readouts: bool,
-    ) -> list[pd.DataFrame]:
-        names = self.model.get_variable_names()
-        if include_derived:
-            names.extend(self.model.get_derived_variable_names())
-        if include_readouts:
-            names.extend(self.model.get_readout_names())
-        return [i.loc[:, names] for i in dependent]
-
-    def _select_fluxes(
-        self,
-        dependent: list[pd.DataFrame],
-        *,
-        include_surrogates: bool,
-    ) -> list[pd.DataFrame]:
-        names = self.model.get_reaction_names()
-        if include_surrogates:
-            names.extend(self.model.get_surrogate_reaction_names())
-        return [i.loc[:, names] for i in dependent]
-
-    def _adjust_data(
-        self,
-        data: list[pd.DataFrame],
-        normalise: float | ArrayLike | None = None,
-        *,
-        concatenated: bool = True,
-    ) -> pd.DataFrame | list[pd.DataFrame]:
-        if normalise is not None:
-            data = _normalise_split_results(data, normalise=normalise)
-        if concatenated:
-            return pd.concat(data, axis=0)
-        return data
-
-    @overload
-    def get_variables(  # type: ignore
-        self,
-        *,
-        include_derived: bool = True,
-        include_readouts: bool = True,
-        concatenated: Literal[False],
-        normalise: float | ArrayLike | None = None,
-    ) -> list[pd.DataFrame]: ...
-
-    @overload
-    def get_variables(
-        self,
-        *,
-        include_derived: bool = True,
-        include_readouts: bool = True,
-        concatenated: Literal[True],
-        normalise: float | ArrayLike | None = None,
-    ) -> pd.DataFrame: ...
-
-    @overload
-    def get_variables(
-        self,
-        *,
-        include_derived: bool = True,
-        include_readouts: bool = True,
-        concatenated: bool = True,
-        normalise: float | ArrayLike | None = None,
-    ) -> pd.DataFrame: ...
-
-    def get_variables(
-        self,
-        *,
-        include_derived: bool = True,
-        include_readouts: bool = True,
-        concatenated: bool = True,
-        normalise: float | ArrayLike | None = None,
-    ) -> pd.DataFrame | list[pd.DataFrame]:
-        """Get the variables over time.
-
-        Examples:
-            >>> Result().get_variables()
-            Time            ATP      NADPH
-            0.000000   1.000000   1.000000
-            0.000100   0.999900   0.999900
-            0.000200   0.999800   0.999800
-
-        """
-        if not include_derived and not include_readouts:
-            return self._adjust_data(
-                self._raw_variables,
-                normalise=normalise,
-                concatenated=concatenated,
-            )
-
-        variables = self._select_variables(
-            self._get_dependent(),
-            include_derived=include_derived,
-            include_readouts=include_readouts,
-        )
-        return self._adjust_data(
-            variables, normalise=normalise, concatenated=concatenated
-        )
-
-    @overload
-    def get_fluxes(  # type: ignore
-        self,
-        *,
-        include_surrogates: bool = True,
-        normalise: float | ArrayLike | None = None,
-        concatenated: Literal[False],
-    ) -> list[pd.DataFrame]: ...
-
-    @overload
-    def get_fluxes(
-        self,
-        *,
-        include_surrogates: bool = True,
-        normalise: float | ArrayLike | None = None,
-        concatenated: Literal[True],
-    ) -> pd.DataFrame: ...
-
-    @overload
-    def get_fluxes(
-        self,
-        *,
-        include_surrogates: bool = True,
-        normalise: float | ArrayLike | None = None,
-        concatenated: bool = True,
-    ) -> pd.DataFrame: ...
-
-    def get_fluxes(
-        self,
-        *,
-        include_surrogates: bool = True,
-        normalise: float | ArrayLike | None = None,
-        concatenated: bool = True,
-    ) -> pd.DataFrame | list[pd.DataFrame]:
-        """Get the flux results.
-
-        Examples:
-            >>> Result.get_fluxes()
-            Time             v1         v2
-            0.000000   1.000000   10.00000
-            0.000100   0.999900   9.999000
-            0.000200   0.999800   9.998000
-
-        Returns:
-            pd.DataFrame: DataFrame of fluxes.
-
-        """
-        fluxes = self._select_fluxes(
-            self._get_dependent(),
-            include_surrogates=include_surrogates,
-        )
-        return self._adjust_data(
-            fluxes,
-            normalise=normalise,
-            concatenated=concatenated,
-        )
-
-    def get_combined(self) -> pd.DataFrame:
-        """Get the variables and fluxes as a single pandas.DataFrame.
-
-        Examples:
-            >>> Result.get_combined()
-            Time            ATP      NADPH         v1         v2
-            0.000000   1.000000   1.000000   1.000000   10.00000
-            0.000100   0.999900   0.999900   0.999900   9.999000
-            0.000200   0.999800   0.999800   0.999800   9.998000
-
-        Returns:
-            pd.DataFrame: DataFrame of fluxes.
-
-        """
-        return pd.concat((self.variables, self.fluxes), axis=1)
-
-    def get_new_y0(self) -> dict[str, float]:
-        """Get the new initial conditions after the simulation.
-
-        Examples:
-            >>> Simulator(model).simulate_to_steady_state().get_new_y0()
-            {"ATP": 1.0, "NADPH": 1.0}
-
-        """
-        return dict(
-            self.get_variables(
-                include_derived=False,
-                include_readouts=False,
-            ).iloc[-1]
-        )
-
-
 @dataclass(
     init=False,
     slots=True,
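The large removal above is a relocation rather than a deletion: the `Result` dataclass moves to `mxlpy.types` (note the new `from mxlpy.types import Result` import), presumably together with the `_normalise_split_results` helper it depends on, and the simulator module stops exporting it. Incidentally, the removed helper's fallback branch was dead code, since `results` is rebound to an empty list right before being iterated. Import migration sketch for downstream code:

    # from mxlpy.simulator import Result   # 0.21.0
    from mxlpy.types import Result         # 0.23.0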
@@ -386,12 +118,12 @@ class Simulator:
             )
 
         except Exception as e:  # noqa: BLE001
-            warnings.warn(str(e), stacklevel=2)
+            _LOGGER.warning(str(e), stacklevel=2)
 
         y0 = self.y0
         self.integrator = self._integrator_type(
             self.model,
-            [y0[k] for k in self.model.get_variable_names()],
+            tuple(y0[k] for k in self.model.get_variable_names()),
            jac_fn,
         )
 
@@ -431,7 +163,7 @@ class Simulator:
         # model._get_rhs sorts the return array by model.get_variable_names()
         # Do NOT change this ordering
         results_df = pd.DataFrame(
-            results,
+            data=results,
             index=time,
             columns=self.model.get_variable_names(),
         )
@@ -445,7 +177,7 @@
 
         if self.simulation_parameters is None:
             self.simulation_parameters = []
-        self.simulation_parameters.append(self.model.parameters)
+        self.simulation_parameters.append(self.model.get_parameter_values())
 
     def simulate(
         self,
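The switch from `self.model.parameters` to `self.model.get_parameter_values()` is part of a wider move from attribute access to getter methods on `Model`; the same substitution shows up in the symbolic-model file further down. A sketch, assuming `model` is any constructed `Model`:

    params: dict[str, float] = model.get_parameter_values()  # 0.23.0
    # params = model.parameters                              # 0.21.0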
@@ -472,6 +204,13 @@
         if self._time_shift is not None:
             t_end -= self._time_shift
 
+        prior_t_end: float = (
+            0.0 if (variables := self.variables) is None else variables[-1].index[-1]
+        )
+        if t_end <= prior_t_end:
+            msg = "End time point has to be larger than previous end time point"
+            raise ValueError(msg)
+
         time, results = self.integrator.integrate(t_end=t_end, steps=steps)
 
         self._handle_simulation_results(time, results, skipfirst=True)
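The added block gives `simulate()` a monotonicity guard: `prior_t_end` is taken from the last stored time course, and a `t_end` that does not extend it now raises a `ValueError` instead of silently producing an overlapping segment. Behavioural sketch with a hypothetical `model`:

    sim = Simulator(model)
    sim.simulate(t_end=10.0)
    sim.simulate(t_end=20.0)  # fine: continues from t = 10
    sim.simulate(t_end=5.0)   # new in 0.23.0: raises ValueError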
@@ -495,17 +234,33 @@
             Self: The Simulator instance with updated results.
 
         """
+        time_points = np.array(time_points, dtype=float)
+
         if self._time_shift is not None:
-            time_points = np.array(time_points, dtype=float)
             time_points -= self._time_shift
 
+        # Check if end is actually larger
+        prior_t_end: float = (
+            0.0 if (variables := self.variables) is None else variables[-1].index[-1]
+        )
+        if time_points[-1] <= prior_t_end:
+            msg = "End time point has to be larger than previous end time point"
+            raise ValueError(msg)
+
+        # Remove points which are smaller than previous t_end
+        if not (larger := time_points >= prior_t_end).all():
+            msg = f"Overlapping time points. Removing: {time_points[~larger]}"
+            _LOGGER.warning(msg)
+            time_points = time_points[larger]
+
         time, results = self.integrator.integrate_time_course(time_points=time_points)
         self._handle_simulation_results(time, results, skipfirst=True)
         return self
 
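`simulate_time_course()` gains the same guard plus a softer fallback: time points are coerced to a float array up front, the call fails only when the final point does not extend the simulation, and earlier overlapping points are dropped with a logged warning. Sketch, assuming a previous run that ended at t = 10:

    sim.simulate_time_course([5.0, 15.0, 20.0])
    # logs "Overlapping time points. Removing: [5.]", integrates 15.0 and 20.0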
-    def simulate_over_protocol(
+    def simulate_protocol(
         self,
         protocol: pd.DataFrame,
+        *,
         time_points_per_step: int = 10,
     ) -> Self:
         """Simulate the model over a given protocol.
@@ -524,10 +279,81 @@
             The Simulator instance with updated results.
 
         """
+        t_start = (
+            0.0 if (variables := self.variables) is None else variables[-1].index[-1]
+        )
+
         for t_end, pars in protocol.iterrows():
             t_end = cast(pd.Timedelta, t_end)
             self.model.update_parameters(pars.to_dict())
-            self.simulate(t_end.total_seconds(), steps=time_points_per_step)
+            self.simulate(t_start + t_end.total_seconds(), steps=time_points_per_step)
+            if self.variables is None:
+                break
+        return self
+
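Besides the rename from `simulate_over_protocol` to `simulate_protocol`, two behavioural changes are worth noting: `time_points_per_step` becomes keyword-only, and each step's end time is offset by `t_start`, so a protocol run after an earlier simulation continues from the current end time instead of restarting at zero. Migration sketch, with `protocol` a hypothetical DataFrame indexed by `pd.Timedelta` step ends:

    sim.simulate_protocol(protocol, time_points_per_step=10)  # 0.23.0
    # sim.simulate_over_protocol(protocol, 10)                # 0.21.0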
+    def simulate_protocol_time_course(
+        self,
+        protocol: pd.DataFrame,
+        time_points: ArrayLike,
+        *,
+        time_points_as_relative: bool = False,
+    ) -> Self:
+        """Simulate the model over a given protocol.
+
+        Examples:
+            >>> Simulator(model).simulate_protocol_time_course(
+            ...     protocol,
+            ...     time_points=np.array([1.0, 2.0, 3.0], dtype=float),
+            ... )
+
+        Args:
+            protocol: DataFrame containing the protocol.
+            time_points: Array of time points for which to return the simulation values.
+            time_points_as_relative: Interpret time points as relative time.
+
+        Notes:
+            This function will return **both** the control points of the protocol
+            and the supplied time points, in case they don't coincide.
+
+        Returns:
+            The Simulator instance with updated results.
+
+        """
+        t_start = (
+            0.0 if (variables := self.variables) is None else variables[-1].index[-1]
+        )
+
+        protocol = protocol.copy()
+        protocol.index = (
+            cast(pd.TimedeltaIndex, protocol.index) + pd.Timedelta(t_start, unit="s")
+        ).total_seconds()
+
+        time_points = np.array(time_points, dtype=float)
+        if time_points_as_relative:
+            time_points += t_start
+
+        # Error handling
+        if time_points[-1] <= t_start:
+            msg = "End time point has to be larger than previous end time point"
+            raise ValueError(msg)
+
+        larger = time_points > protocol.index[-1]
+        if any(larger):
+            msg = f"Ignoring time points outside of protocol range:\n {time_points[larger]}"
+            _LOGGER.warning(msg)
+
+        # Continue with logic
+        full_time_points = protocol.index.join(time_points, how="outer")
+
+        for t_end, pars in protocol.iterrows():
+            self.model.update_parameters(pars.to_dict())
+
+            self.simulate_time_course(
+                time_points=full_time_points[
+                    (full_time_points > t_start) & (full_time_points <= t_end)
+                ]
+            )
+            t_start = t_end
             if self.variables is None:
                 break
         return self
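The new `simulate_protocol_time_course()` evaluates a protocol at caller-supplied time points, merged with the protocol's own control points. A usage sketch under stated assumptions: `model` is hypothetical, `k1` is a made-up parameter name, and the index marks the end of each step, matching how the loop above slices `(t_start, t_end]`:

    import numpy as np
    import pandas as pd

    # Two-step protocol: k1 = 1.0 until t = 1 s, then k1 = 2.0 until t = 2 s.
    protocol = pd.DataFrame(
        {"k1": [1.0, 2.0]},
        index=pd.to_timedelta([1.0, 2.0], unit="s"),
    )
    sim = Simulator(model).simulate_protocol_time_course(
        protocol,
        time_points=np.linspace(0.25, 2.0, 8),
    )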
@@ -590,8 +416,8 @@
             return None
         return Result(
             model=self.model,
-            _raw_variables=variables,
-            _parameters=parameters,
+            raw_variables=variables,
+            raw_parameters=parameters,
         )
 
     def update_parameter(self, parameter: str, value: float) -> Self:
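With `Result` now a public type, its constructor fields lose their leading underscores. A hedged sketch of direct construction, assuming the relocated class keeps the keyword-only layout it had in this module (`frames` and `params` are hypothetical placeholders):

    result = Result(
        model=model,
        raw_variables=frames,   # list[pd.DataFrame], one per simulation segment
        raw_parameters=params,  # list[dict[str, float]], parameters per segment
    )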
@@ -4,15 +4,13 @@
 from __future__ import annotations
 
 from dataclasses import dataclass
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, cast
 
 import sympy
 
-from mxlpy.meta.source_tools import fn_to_sympy
+from mxlpy.meta.sympy_tools import fn_to_sympy, list_of_symbols
 
 if TYPE_CHECKING:
-    from collections.abc import Iterable
-
     from mxlpy.model import Model
 
 __all__ = [
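In this second file, the symbolic model builder, `fn_to_sympy` moves from `mxlpy.meta.source_tools` to `mxlpy.meta.sympy_tools`, which also provides the `list_of_symbols` helper that replaces the local `_list_of_symbols` removed below. Import migration sketch:

    # from mxlpy.meta.source_tools import fn_to_sympy                # 0.21.0
    from mxlpy.meta.sympy_tools import fn_to_sympy, list_of_symbols  # 0.23.0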
@@ -36,30 +34,44 @@ class SymbolicModel:
     )
 
 
-def _list_of_symbols(args: Iterable[str]) -> list[sympy.Symbol]:
-    return [sympy.Symbol(arg) for arg in args]
-
-
 def to_symbolic_model(model: Model) -> SymbolicModel:
     cache = model._create_cache()  # noqa: SLF001
+    initial_conditions = model.get_initial_conditions()
 
     variables: dict[str, sympy.Symbol] = dict(
-        zip(model.variables, _list_of_symbols(model.variables), strict=True)
+        zip(
+            initial_conditions,
+            cast(list[sympy.Symbol], list_of_symbols(initial_conditions)),
+            strict=True,
+        )
     )
     parameters: dict[str, sympy.Symbol] = dict(
-        zip(model.parameters, _list_of_symbols(model.parameters), strict=True)
+        zip(
+            model.get_parameter_values(),
+            cast(list[sympy.Symbol], list_of_symbols(model.get_parameter_values())),
+            strict=True,
+        )
     )
     symbols: dict[str, sympy.Symbol | sympy.Expr] = variables | parameters  # type: ignore
 
     # Insert derived into symbols
-    for k, v in model.derived.items():
-        symbols[k] = fn_to_sympy(v.fn, [symbols[i] for i in v.args])
+    for k, v in model.get_raw_derived().items():
+        if (
+            expr := fn_to_sympy(v.fn, origin=k, model_args=[symbols[i] for i in v.args])
+        ) is None:
+            msg = f"Unable to parse derived value '{k}'"
+            raise ValueError(msg)
+        symbols[k] = expr
 
     # Insert derived into reaction via args
-    rxns = {
-        k: fn_to_sympy(v.fn, [symbols[i] for i in v.args])
-        for k, v in model.reactions.items()
-    }
+    rxns: dict[str, sympy.Expr] = {}
+    for k, v in model.get_raw_reactions().items():
+        if (
+            expr := fn_to_sympy(v.fn, origin=k, model_args=[symbols[i] for i in v.args])
+        ) is None:
+            msg = f"Unable to parse reaction '{k}'"
+            raise ValueError(msg)
+        rxns[k] = expr
 
     eqs: dict[str, sympy.Expr] = {}
     for cpd, stoich in cache.stoich_by_cpds.items():
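The rewritten loops reflect a new `fn_to_sympy` calling convention: it now takes an `origin` label for error reporting and keyword `model_args`, and returns `None` when a function cannot be parsed, so every call site checks for `None` and raises. Condensed sketch, with `fn` and `args` as hypothetical stand-ins:

    if (expr := fn_to_sympy(fn, origin="v1", model_args=args)) is None:
        msg = "Unable to parse reaction 'v1'"
        raise ValueError(msg)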
@@ -80,5 +92,5 @@ def to_symbolic_model(model: Model) -> SymbolicModel:
         parameters=parameters,
         eqs=[eqs[i] for i in cache.var_names],
         initial_conditions=model.get_initial_conditions(),
-        parameter_values=model.parameters,
+        parameter_values=model.get_parameter_values(),
     )