mxlpy 0.22.0__py3-none-any.whl → 0.24.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
mxlpy/scan.py CHANGED
@@ -15,16 +15,21 @@ Functions:
 
 from __future__ import annotations
 
-from dataclasses import dataclass
 from functools import partial
-from typing import TYPE_CHECKING, Protocol, Self, cast
+from typing import TYPE_CHECKING, Protocol
 
 import numpy as np
 import pandas as pd
 
 from mxlpy.parallel import Cache, parallelise
-from mxlpy.simulator import Result, Simulator
-from mxlpy.types import IntegratorType, ProtocolByPars, SteadyStates, TimeCourseByPars
+from mxlpy.simulator import Simulator
+from mxlpy.types import (
+    IntegratorType,
+    ProtocolScan,
+    Result,
+    SteadyStateScan,
+    TimeCourseScan,
+)
 
 if TYPE_CHECKING:
     from collections.abc import Callable
@@ -34,14 +39,14 @@ if TYPE_CHECKING:
 
 
 __all__ = [
+    "ProtocolTimeCourseWorker",
     "ProtocolWorker",
     "SteadyStateWorker",
-    "TimeCourse",
     "TimeCourseWorker",
-    "TimePoint",
+    "protocol",
+    "protocol_time_course",
     "steady_state",
     "time_course",
-    "time_course_over_protocol",
 ]
 
 
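The __all__ hunk above summarizes the public API change in 0.24.0: the TimePoint and TimeCourse result containers are removed, time_course_over_protocol is renamed to protocol, and protocol_time_course plus ProtocolTimeCourseWorker are new exports. A minimal migration sketch (model, params and protocol_df are placeholder objects, not taken from this diff):

    from mxlpy import scan

    # mxlpy 0.22.0
    # res = scan.time_course_over_protocol(model, to_scan=params, protocol=protocol_df)

    # mxlpy 0.24.0: same keyword arguments, new name, returns a ProtocolScan
    res = scan.protocol(model, to_scan=params, protocol=protocol_df)
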
@@ -67,206 +72,6 @@ def _update_parameters_and_initial_conditions[T](
     return fn(model)
 
 
-def _empty_conc_series(model: Model) -> pd.Series:
-    """Create an empty concentration series for the model.
-
-    Args:
-        model: Model instance to generate the series for.
-
-    Returns:
-        pd.Series: Series with NaN values for each model variable.
-
-    """
-    return pd.Series(
-        data=np.full(shape=len(model.get_variable_names()), fill_value=np.nan),
-        index=model.get_variable_names(),
-    )
-
-
-def _empty_flux_series(model: Model) -> pd.Series:
-    """Create an empty flux series for the model.
-
-    Args:
-        model: Model instance to generate the series for.
-
-    Returns:
-        pd.Series: Series with NaN values for each model reaction.
-
-    """
-    return pd.Series(
-        data=np.full(shape=len(model.get_reaction_names()), fill_value=np.nan),
-        index=model.get_reaction_names(),
-    )
-
-
-def _empty_conc_df(model: Model, time_points: Array) -> pd.DataFrame:
-    """Create an empty concentration DataFrame for the model over given time points.
-
-    Args:
-        model: Model instance to generate the DataFrame for.
-        time_points: Array of time points.
-
-    Returns:
-        pd.DataFrame: DataFrame with NaN values for each model variable at each time point.
-
-    """
-    return pd.DataFrame(
-        data=np.full(
-            shape=(len(time_points), len(model.get_variable_names())),
-            fill_value=np.nan,
-        ),
-        index=time_points,
-        columns=model.get_variable_names(),
-    )
-
-
-def _empty_flux_df(model: Model, time_points: Array) -> pd.DataFrame:
-    """Create an empty concentration DataFrame for the model over given time points.
-
-    Args:
-        model: Model instance to generate the DataFrame for.
-        time_points: Array of time points.
-
-    Returns:
-        pd.DataFrame: DataFrame with NaN values for each model variable at each time point.
-
-    """
-    return pd.DataFrame(
-        data=np.full(
-            shape=(len(time_points), len(model.get_reaction_names())),
-            fill_value=np.nan,
-        ),
-        index=time_points,
-        columns=model.get_reaction_names(),
-    )
-
-
-###############################################################################
-# Single returns
-###############################################################################
-
-
-@dataclass(slots=True)
-class TimePoint:
-    """Represents a single time point in a simulation.
-
-    Attributes:
-        concs: Series of concentrations at the time point.
-        fluxes: Series of fluxes at the time point.
-
-    Args:
-        model: Model instance to generate the time point for.
-        concs: DataFrame of concentrations (default: None).
-        fluxes: DataFrame of fluxes (default: None).
-        idx: Index of the time point in the DataFrame (default: -1).
-
-    """
-
-    variables: pd.Series
-    fluxes: pd.Series
-
-    @classmethod
-    def from_result(
-        cls,
-        *,
-        model: Model,
-        result: Result | None,
-        idx: int = -1,
-    ) -> Self:
-        """Initialize the Scan object.
-
-        Args:
-            model: The model object.
-            result: Result of the simulation
-            idx: Index to select specific row from concs and fluxes DataFrames.
-
-        """
-        if result is None:
-            return cls(
-                variables=_empty_conc_series(model),
-                fluxes=_empty_flux_series(model),
-            )
-
-        return cls(
-            variables=result.variables.iloc[idx],
-            fluxes=result.fluxes.iloc[idx],
-        )
-
-    @property
-    def results(self) -> pd.Series:
-        """Get the combined results of concentrations and fluxes.
-
-        Example:
-            >>> time_point.results
-            x1    1.0
-            x2    0.5
-            v1    0.1
-            v2    0.2
-
-        Returns:
-            pd.Series: Combined series of concentrations and fluxes.
-
-        """
-        return pd.concat((self.variables, self.fluxes), axis=0)
-
-
-@dataclass(slots=True)
-class TimeCourse:
-    """Represents a time course in a simulation.
-
-    Attributes:
-        variables: DataFrame of concentrations over time.
-        fluxes: DataFrame of fluxes over time.
-
-    """
-
-    variables: pd.DataFrame
-    fluxes: pd.DataFrame
-
-    @classmethod
-    def from_scan(
-        cls,
-        *,
-        model: Model,
-        time_points: Array,
-        result: Result | None,
-    ) -> Self:
-        """Initialize the Scan object.
-
-        Args:
-            model (Model): The model object.
-            time_points (Array): Array of time points.
-            result: Result of the simulation
-
-        """
-        if result is None:
-            return cls(
-                _empty_conc_df(model, time_points),
-                _empty_flux_df(model, time_points),
-            )
-        return cls(
-            result.variables,
-            result.fluxes,
-        )
-
-    @property
-    def results(self) -> pd.DataFrame:
-        """Get the combined results of concentrations and fluxes over time.
-
-        Examples:
-            >>> time_course.results
-            Time    x1    x2    v1    v2
-            0.0    1.0  1.00  1.00  1.00
-            0.1    0.9  0.99  0.99  0.99
-            0.2    0.8  0.99  0.99  0.99
-
-        Returns:
-            pd.DataFrame: Combined DataFrame of concentrations and fluxes.
-
-        """
-        return pd.concat((self.variables, self.fluxes), axis=1)
-
-
 ###############################################################################
 # Workers
 ###############################################################################
@@ -282,7 +87,7 @@ class SteadyStateWorker(Protocol):
         rel_norm: bool,
         integrator: IntegratorType | None,
         y0: dict[str, float] | None,
-    ) -> TimePoint:
+    ) -> Result:
         """Call the worker function."""
         ...
 
@@ -297,7 +102,7 @@ class TimeCourseWorker(Protocol):
         *,
         integrator: IntegratorType | None,
         y0: dict[str, float] | None,
-    ) -> TimeCourse:
+    ) -> Result:
         """Call the worker function."""
         ...
 
@@ -313,7 +118,23 @@ class ProtocolWorker(Protocol):
         integrator: IntegratorType | None,
         y0: dict[str, float] | None,
         time_points_per_step: int = 10,
-    ) -> TimeCourse:
+    ) -> Result:
+        """Call the worker function."""
+        ...
+
+
+class ProtocolTimeCourseWorker(Protocol):
+    """Worker function for protocol-based simulations."""
+
+    def __call__(
+        self,
+        model: Model,
+        protocol: pd.DataFrame,
+        time_points: Array,
+        *,
+        integrator: IntegratorType | None,
+        y0: dict[str, float] | None,
+    ) -> Result:
         """Call the worker function."""
         ...
 
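ProtocolTimeCourseWorker (added above) only fixes the callable signature that the new protocol_time_course scan expects from its worker argument; all workers now return the shared Result type instead of the removed module-local containers. A sketch of a user-defined worker, using only calls that appear elsewhere in this diff (Simulator.simulate_protocol_time_course, get_result, Result.default); the broader exception handling is the reason one might swap the worker, not library behaviour:

    from mxlpy.simulator import Simulator
    from mxlpy.types import Result

    def tolerant_worker(model, protocol, time_points, *, integrator, y0) -> Result:
        # Matches the ProtocolTimeCourseWorker signature above, but treats any
        # simulation failure (not only ZeroDivisionError) as "no result".
        try:
            res = (
                Simulator(model, integrator=integrator, y0=y0)
                .simulate_protocol_time_course(protocol=protocol, time_points=time_points)
                .get_result()
            )
        except Exception:
            res = None
        return Result.default(model=model, time_points=time_points) if res is None else res

Such a worker would then be passed as worker=tolerant_worker to scan.protocol_time_course().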
@@ -324,7 +145,7 @@ def _steady_state_worker(
     rel_norm: bool,
     integrator: IntegratorType | None,
     y0: dict[str, float] | None,
-) -> TimePoint:
+) -> Result:
     """Simulate the model to steady state and return concentrations and fluxes.
 
     Args:
@@ -345,7 +166,9 @@ def _steady_state_worker(
         )
     except ZeroDivisionError:
         res = None
-    return TimePoint.from_result(model=model, result=res)
+    return (
+        Result.default(model=model, time_points=np.array([0.0])) if res is None else res
+    )
 
 
 def _time_course_worker(
@@ -353,7 +176,7 @@ def _time_course_worker(
     time_points: Array,
     y0: dict[str, float] | None,
     integrator: IntegratorType | None,
-) -> TimeCourse:
+) -> Result:
     """Simulate the model to steady state and return concentrations and fluxes.
 
     Args:
@@ -374,11 +197,7 @@ def _time_course_worker(
         )
     except ZeroDivisionError:
         res = None
-    return TimeCourse.from_scan(
-        model=model,
-        time_points=time_points,
-        result=res,
-    )
+    return Result.default(model=model, time_points=time_points) if res is None else res
 
 
 def _protocol_worker(
@@ -388,7 +207,7 @@ def _protocol_worker(
     integrator: IntegratorType | None,
     y0: dict[str, float] | None,
     time_points_per_step: int = 10,
-) -> TimeCourse:
+) -> Result:
     """Simulate the model over a protocol and return concentrations and fluxes.
 
     Args:
@@ -419,11 +238,43 @@ def _protocol_worker(
         protocol.index[-1].total_seconds(),
         len(protocol) * time_points_per_step,
     )
-    return TimeCourse.from_scan(
-        model=model,
-        time_points=time_points,
-        result=res,
-    )
+    return Result.default(model=model, time_points=time_points) if res is None else res
+
+
+def _protocol_time_course_worker(
+    model: Model,
+    protocol: pd.DataFrame,
+    time_points: Array,
+    *,
+    integrator: IntegratorType | None,
+    y0: dict[str, float] | None,
+) -> Result:
+    """Simulate the model over a protocol and return concentrations and fluxes.
+
+    Args:
+        model: Model instance to simulate.
+        y0: Initial conditions as a dictionary {species: value}.
+        protocol: DataFrame containing the protocol steps.
+        time_points: Time points where to return the simulation
+        integrator: Integrator function to use for steady state calculation
+
+    Returns:
+        TimeCourse: Object containing protocol series concentrations and fluxes.
+
+    """
+    try:
+        res = (
+            Simulator(model, integrator=integrator, y0=y0)
+            .simulate_protocol_time_course(
+                protocol=protocol,
+                time_points=time_points,
+            )
+            .get_result()
+        )
+    except ZeroDivisionError:
+        res = None
+
+    return Result.default(model=model, time_points=time_points) if res is None else res
 
 
 def steady_state(
@@ -436,7 +287,7 @@ def steady_state(
     cache: Cache | None = None,
     worker: SteadyStateWorker = _steady_state_worker,
     integrator: IntegratorType | None = None,
-) -> SteadyStates:
+) -> SteadyStateScan:
     """Get steady-state results over supplied values.
 
     Args:
@@ -492,16 +343,16 @@ def steady_state(
         cache=cache,
         parallel=parallel,
     )
-    concs = pd.DataFrame({k: v.variables.T for k, v in res}).T
-    fluxes = pd.DataFrame({k: v.fluxes.T for k, v in res}).T
-    idx = (
-        pd.Index(to_scan.iloc[:, 0])
-        if to_scan.shape[1] == 1
-        else pd.MultiIndex.from_frame(to_scan)
+
+    return SteadyStateScan(
+        raw_index=(
+            pd.Index(to_scan.iloc[:, 0])
+            if to_scan.shape[1] == 1
+            else pd.MultiIndex.from_frame(to_scan)
+        ),
+        raw_results=[i[1] for i in res],
+        to_scan=to_scan,
     )
-    concs.index = idx
-    fluxes.index = idx
-    return SteadyStates(variables=concs, fluxes=fluxes, parameters=to_scan)
 
 
 def time_course(
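With this hunk, steady_state no longer assembles the concentration and flux frames itself; it hands the raw per-parameter-set Result objects plus the scan index to SteadyStateScan. A hedged call sketch (model is a placeholder; the to_scan keyword follows the function body shown above, while the accessor API of SteadyStateScan lives in mxlpy.types and is not part of this diff):

    import numpy as np
    import pandas as pd
    from mxlpy import scan

    # One steady-state simulation per row of to_scan; the rows become the scan index.
    res = scan.steady_state(model, to_scan=pd.DataFrame({"k1": np.linspace(0.1, 1.0, 10)}))
    # 0.22.0 returned SteadyStates(variables=..., fluxes=..., parameters=...);
    # 0.24.0 returns SteadyStateScan(raw_index=..., raw_results=..., to_scan=...).
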
@@ -514,7 +365,7 @@ def time_course(
    cache: Cache | None = None,
    integrator: IntegratorType | None = None,
    worker: TimeCourseWorker = _time_course_worker,
-) -> TimeCourseByPars:
+) -> TimeCourseScan:
    """Get time course for each supplied parameter.
 
    Examples:
@@ -585,16 +436,13 @@ def time_course(
         cache=cache,
         parallel=parallel,
     )
-    concs = cast(dict, {k: v.variables for k, v in res})
-    fluxes = cast(dict, {k: v.fluxes for k, v in res})
-    return TimeCourseByPars(
-        parameters=to_scan,
-        variables=pd.concat(concs, names=["n", "time"]),
-        fluxes=pd.concat(fluxes, names=["n", "time"]),
+    return TimeCourseScan(
+        to_scan=to_scan,
+        raw_results=dict(res),
     )
 
 
-def time_course_over_protocol(
+def protocol(
     model: Model,
     *,
     to_scan: pd.DataFrame,
@@ -605,7 +453,7 @@ def time_course_over_protocol(
     cache: Cache | None = None,
     worker: ProtocolWorker = _protocol_worker,
     integrator: IntegratorType | None = None,
-) -> ProtocolByPars:
+) -> ProtocolScan:
     """Get protocol series for each supplied parameter.
 
     Examples:
@@ -656,11 +504,77 @@ def time_course_over_protocol(
         cache=cache,
         parallel=parallel,
     )
-    concs = cast(dict, {k: v.variables for k, v in res})
-    fluxes = cast(dict, {k: v.fluxes for k, v in res})
-    return ProtocolByPars(
-        parameters=to_scan,
+    return ProtocolScan(
+        to_scan=to_scan,
+        protocol=protocol,
+        raw_results=dict(res),
+    )
+
+
+def protocol_time_course(
+    model: Model,
+    *,
+    to_scan: pd.DataFrame,
+    protocol: pd.DataFrame,
+    time_points: Array,
+    y0: dict[str, float] | None = None,
+    parallel: bool = True,
+    cache: Cache | None = None,
+    worker: ProtocolTimeCourseWorker = _protocol_time_course_worker,
+    integrator: IntegratorType | None = None,
+) -> ProtocolScan:
+    """Get protocol series for each supplied parameter.
+
+    Examples:
+        >>> scan.time_course_over_protocol(
+        ...     model,
+        ...     parameters=pd.DataFrame({"k2": np.linspace(1, 2, 11)}),
+        ...     protocol=make_protocol(
+        ...         {
+        ...             1: {"k1": 1},
+        ...             2: {"k1": 2},
+        ...         }
+        ...     ),
+        ... )
+
+    Args:
+        model: Model instance to simulate.
+        to_scan: DataFrame containing parameter or initial values to scan.
+        protocol: Protocol to follow for the simulation.
+        time_points: Time points where to return simulation results
+        y0: Initial conditions as a dictionary {variable: value}.
+        parallel: Whether to execute in parallel (default: True).
+        cache: Optional cache to store and retrieve results.
+        worker: Worker function to use for the simulation.
+        integrator: Integrator function to use for steady state calculation
+
+    Returns:
+        TimeCourseByPars: Protocol series results for each parameter set.
+
+    """
+    # We update the initial conditions separately here, because `to_scan` might also
+    # contain initial conditions.
+    if y0 is not None:
+        model.update_variables(y0)
+
+    res = parallelise(
+        partial(
+            _update_parameters_and_initial_conditions,
+            fn=partial(
+                worker,
+                protocol=protocol,
+                time_points=time_points,
+                integrator=integrator,
+                y0=None,
+            ),
+            model=model,
+        ),
+        inputs=list(to_scan.iterrows()),
+        cache=cache,
+        parallel=parallel,
+    )
+    return ProtocolScan(
+        to_scan=to_scan,
         protocol=protocol,
-        variables=pd.concat(concs, names=["n", "time"]),
-        fluxes=pd.concat(fluxes, names=["n", "time"]),
+        raw_results=dict(res),
     )
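
Taken together, the last two hunks replace the old combined return objects with ProtocolScan and add protocol_time_course, which runs the protocol for every parameter set but reports results at caller-chosen time points rather than a fixed number of points per protocol step. A usage sketch matching the new signature; make_protocol is used as in the docstring example above, assuming it is importable from the top-level package, and model is a placeholder:

    import numpy as np
    import pandas as pd
    from mxlpy import make_protocol, scan

    protocol = make_protocol({1: {"k1": 1}, 2: {"k1": 2}})
    res = scan.protocol_time_course(
        model,
        to_scan=pd.DataFrame({"k2": np.linspace(1, 2, 11)}),
        protocol=protocol,
        time_points=np.linspace(0.0, 2.0, 21),
    )
    # res is a ProtocolScan built from one Result per row of to_scan.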