mxlpy 0.23.0__py3-none-any.whl → 0.24.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
mxlpy/model.py CHANGED
@@ -18,6 +18,7 @@ from typing import TYPE_CHECKING, Self, cast
18
18
  import numpy as np
19
19
  import pandas as pd
20
20
  import sympy
21
+ from wadler_lindig import pformat
21
22
 
22
23
  from mxlpy import fns
23
24
  from mxlpy.meta.source_tools import fn_to_sympy
@@ -91,6 +92,10 @@ class MdText:
91
92
 
92
93
  content: list[str]
93
94
 
95
+ def __repr__(self) -> str:
96
+ """Return a pretty-printed representation produced by `pformat`."""
97
+ return pformat(self)
98
+
94
99
  def _repr_markdown_(self) -> str:
95
100
  return "\n".join(self.content)
96
101
 
@@ -101,6 +106,10 @@ class UnitCheck:
101
106
 
102
107
  per_variable: dict[str, dict[str, bool | Failure | None]]
103
108
 
109
+ def __repr__(self) -> str:
110
+ """Return a pretty-printed representation produced by `pformat`."""
111
+ return pformat(self)
112
+
104
113
  @staticmethod
105
114
  def _fmt_success(s: str) -> str:
106
115
  return f"<span style='color: green'>{s}</span>"
@@ -171,6 +180,10 @@ class Dependency:
171
180
  required: set[str]
172
181
  provided: set[str]
173
182
 
183
+ def __repr__(self) -> str:
184
+ """Return a pretty-printed representation produced by `pformat`."""
185
+ return pformat(self)
186
+
174
187
 
175
188
  class MissingDependenciesError(Exception):
176
189
  """Raised when dependencies cannot be sorted topologically.
@@ -374,6 +387,10 @@ class ModelCache:
374
387
 
375
388
  """
376
389
 
390
+ def __repr__(self) -> str:
391
+ """Return a pretty-printed representation produced by `pformat`."""
392
+ return pformat(self)
393
+
377
394
  order: list[str] # mostly for debug purposes
378
395
  var_names: list[str]
379
396
  dyn_order: list[str]
@@ -402,15 +419,19 @@ class Model:
402
419
  """
403
420
 
404
421
  _ids: dict[str, str] = field(default_factory=dict, repr=False)
422
+ _cache: ModelCache | None = field(default=None, repr=False)
405
423
  _variables: dict[str, Variable] = field(default_factory=dict)
406
424
  _parameters: dict[str, Parameter] = field(default_factory=dict)
407
425
  _derived: dict[str, Derived] = field(default_factory=dict)
408
426
  _readouts: dict[str, Readout] = field(default_factory=dict)
409
427
  _reactions: dict[str, Reaction] = field(default_factory=dict)
410
428
  _surrogates: dict[str, AbstractSurrogate] = field(default_factory=dict)
411
- _cache: ModelCache | None = None
412
429
  _data: dict[str, pd.Series | pd.DataFrame] = field(default_factory=dict)
413
430
 
431
+ def __repr__(self) -> str:
432
+ """Return a pretty-printed representation produced by `pformat`."""
433
+ return pformat(self)
434
+
414
435
  ###########################################################################
415
436
  # Cache
416
437
  ###########################################################################
@@ -2281,7 +2302,7 @@ class Model:
2281
2302
  if (cache := self._cache) is None:
2282
2303
  cache = self._create_cache()
2283
2304
  var_names = self.get_variable_names()
2284
- dependent = self._get_args(
2305
+ args = self._get_args(
2285
2306
  variables=self.get_initial_conditions() if variables is None else variables,
2286
2307
  time=time,
2287
2308
  cache=cache,
@@ -2289,12 +2310,12 @@ class Model:
2289
2310
  dxdt = pd.Series(np.zeros(len(var_names), dtype=float), index=var_names)
2290
2311
  for k, stoc in cache.stoich_by_cpds.items():
2291
2312
  for flux, n in stoc.items():
2292
- dxdt[k] += n * dependent[flux]
2313
+ dxdt[k] += n * args[flux]
2293
2314
 
2294
2315
  for k, sd in cache.dyn_stoich_by_cpds.items():
2295
2316
  for flux, dv in sd.items():
2296
- n = dv.fn(*(dependent[i] for i in dv.args))
2297
- dxdt[k] += n * dependent[flux]
2317
+ n = dv.fn(*(args[i] for i in dv.args))
2318
+ dxdt[k] += n * args[flux]
2298
2319
  return dxdt
2299
2320
 
2300
2321
  ##########################################################################
mxlpy/plot.py CHANGED
@@ -41,6 +41,7 @@ from matplotlib.figure import Figure
41
41
  from matplotlib.legend import Legend
42
42
  from matplotlib.patches import Patch
43
43
  from mpl_toolkits.mplot3d import Axes3D
44
+ from wadler_lindig import pformat
44
45
 
45
46
  from mxlpy.label_map import LabelMapper
46
47
 
@@ -105,6 +106,10 @@ class Axs:
105
106
  """Length of axes."""
106
107
  return len(self.axs.flatten())
107
108
 
109
+ def __repr__(self) -> str:
110
+ """Return a pretty-printed representation produced by `pformat`."""
111
+ return pformat(self)
112
+
108
113
  @overload
109
114
  def __getitem__(self, row_col: int) -> Axes: ...
110
115
 
@@ -213,6 +218,28 @@ def _partition_by_order_of_magnitude(s: pd.Series) -> list[list[str]]:
213
218
  ]
214
219
 
215
220
 
221
+ def _combine_small_groups(
222
+ groups: list[list[str]], min_group_size: int
223
+ ) -> list[list[str]]:
224
+ """Combine smaller groups."""
225
+ result = []
226
+ current_group = groups[0]
227
+
228
+ for next_group in groups[1:]:
229
+ if len(current_group) < min_group_size:
230
+ current_group.extend(next_group)
231
+ else:
232
+ result.append(current_group)
233
+ current_group = next_group
234
+
235
+ # Last group
236
+ if len(current_group) < min_group_size:
237
+ result[-1].extend(current_group)
238
+ else:
239
+ result.append(current_group)
240
+ return result
241
+
242
+
216
243
  def _split_large_groups[T](groups: list[list[T]], max_size: int) -> list[list[T]]:
217
244
  """Split groups larger than the given size into smaller groups."""
218
245
  return list(
@@ -516,7 +543,7 @@ def grid_layout(
516
543
  n_rows = math.ceil(n_groups / n_cols)
517
544
  figsize = (n_cols * col_width, n_rows * row_height)
518
545
 
519
- return _default_fig_axs(
546
+ fig, axs = _default_fig_axs(
520
547
  ncols=n_cols,
521
548
  nrows=n_rows,
522
549
  figsize=figsize,
@@ -525,6 +552,12 @@ def grid_layout(
525
552
  grid=grid,
526
553
  )
527
554
 
555
+ # Disable unused plots by default
556
+ axsl = list(axs)
557
+ for i in range(n_groups, len(axs)):
558
+ axsl[i].set_visible(False)
559
+ return fig, axs
560
+
528
561
 
529
562
  ##########################################################################
530
563
  # Plots
@@ -586,10 +619,6 @@ def bars_grouped(
586
619
  ylabel=ylabel,
587
620
  )
588
621
 
589
- axsl = list(axs)
590
- for i in range(len(groups), len(axs)):
591
- axsl[i].set_visible(False)
592
-
593
622
  return fig, axs
594
623
 
595
624
 
@@ -599,18 +628,20 @@ def bars_autogrouped(
599
628
  n_cols: int = 2,
600
629
  col_width: float = 4,
601
630
  row_height: float = 3,
631
+ min_group_size: int = 1,
602
632
  max_group_size: int = 6,
603
633
  grid: bool = True,
604
634
  xlabel: str | None = None,
605
635
  ylabel: str | None = None,
606
636
  ) -> FigAxs:
607
637
  """Plot a series or dataframe with lines grouped by order of magnitude."""
608
- group_names = _split_large_groups(
638
+ group_names = (
609
639
  _partition_by_order_of_magnitude(s)
610
640
  if isinstance(s, pd.Series)
611
- else _partition_by_order_of_magnitude(s.max()),
612
- max_size=max_group_size,
641
+ else _partition_by_order_of_magnitude(s.max())
613
642
  )
643
+ group_names = _combine_small_groups(group_names, min_group_size=min_group_size)
644
+ group_names = _split_large_groups(group_names, max_size=max_group_size)
614
645
 
615
646
  groups: list[pd.Series] | list[pd.DataFrame] = (
616
647
  [s.loc[group] for group in group_names]
@@ -714,10 +745,6 @@ def lines_grouped(
714
745
  ylabel=ylabel,
715
746
  )
716
747
 
717
- axsl = list(axs)
718
- for i in range(len(groups), len(axs)):
719
- axsl[i].set_visible(False)
720
-
721
748
  return fig, axs
722
749
 
723
750
 
@@ -727,6 +754,7 @@ def line_autogrouped(
727
754
  n_cols: int = 2,
728
755
  col_width: float = 4,
729
756
  row_height: float = 3,
757
+ min_group_size: int = 1,
730
758
  max_group_size: int = 6,
731
759
  grid: bool = True,
732
760
  xlabel: str | None = None,
@@ -736,12 +764,13 @@ def line_autogrouped(
736
764
  linestyle: Linestyle | None = None,
737
765
  ) -> FigAxs:
738
766
  """Plot a series or dataframe with lines grouped by order of magnitude."""
739
- group_names = _split_large_groups(
767
+ group_names = (
740
768
  _partition_by_order_of_magnitude(s)
741
769
  if isinstance(s, pd.Series)
742
- else _partition_by_order_of_magnitude(s.max()),
743
- max_size=max_group_size,
770
+ else _partition_by_order_of_magnitude(s.max())
744
771
  )
772
+ group_names = _combine_small_groups(group_names, min_group_size=min_group_size)
773
+ group_names = _split_large_groups(group_names, max_size=max_group_size)
745
774
 
746
775
  groups: list[pd.Series] | list[pd.DataFrame] = (
747
776
  [s.loc[group] for group in group_names]
mxlpy/sbml/_data.py CHANGED
@@ -3,6 +3,8 @@ from __future__ import annotations
3
3
  from dataclasses import dataclass
4
4
  from typing import TYPE_CHECKING
5
5
 
6
+ from wadler_lindig import pformat
7
+
6
8
  if TYPE_CHECKING:
7
9
  from collections.abc import Mapping
8
10
 
@@ -26,18 +28,30 @@ class AtomicUnit:
26
28
  scale: int
27
29
  multiplier: float
28
30
 
31
+ def __repr__(self) -> str:
32
+ """Return a pretty-printed representation produced by `pformat`."""
33
+ return pformat(self)
34
+
29
35
 
30
36
  @dataclass
31
37
  class CompositeUnit:
32
38
  sbml_id: str
33
39
  units: list
34
40
 
41
+ def __repr__(self) -> str:
42
+ """Return a pretty-printed representation produced by `pformat`."""
43
+ return pformat(self)
44
+
35
45
 
36
46
  @dataclass
37
47
  class Parameter:
38
48
  value: float
39
49
  is_constant: bool
40
50
 
51
+ def __repr__(self) -> str:
52
+ """Return a pretty-printed representation produced by `pformat`."""
53
+ return pformat(self)
54
+
41
55
 
42
56
  @dataclass
43
57
  class Compartment:
@@ -47,6 +61,10 @@ class Compartment:
47
61
  units: str
48
62
  is_constant: bool
49
63
 
64
+ def __repr__(self) -> str:
65
+ """Return a pretty-printed representation produced by `pformat`."""
66
+ return pformat(self)
67
+
50
68
 
51
69
  @dataclass
52
70
  class Compound:
@@ -58,21 +76,37 @@ class Compound:
58
76
  is_constant: bool
59
77
  is_concentration: bool
60
78
 
79
+ def __repr__(self) -> str:
80
+ """Return a pretty-printed representation produced by `pformat`."""
81
+ return pformat(self)
82
+
61
83
 
62
84
  @dataclass
63
85
  class Derived:
64
86
  body: str
65
87
  args: list[str]
66
88
 
89
+ def __repr__(self) -> str:
90
+ """Return a pretty-printed representation produced by `pformat`."""
91
+ return pformat(self)
92
+
67
93
 
68
94
  @dataclass
69
95
  class Function:
70
96
  body: str
71
97
  args: list[str]
72
98
 
99
+ def __repr__(self) -> str:
100
+ """Return a pretty-printed representation produced by `pformat`."""
101
+ return pformat(self)
102
+
73
103
 
74
104
  @dataclass
75
105
  class Reaction:
76
106
  body: str
77
107
  stoichiometry: Mapping[str, float | str]
78
108
  args: list[str]
109
+
110
+ def __repr__(self) -> str:
111
+ """Return a pretty-printed representation produced by `pformat`."""
112
+ return pformat(self)
mxlpy/sbml/_export.py CHANGED
@@ -447,6 +447,7 @@ def _create_sbml_variables(
447
447
  cpd.setConstant(False)
448
448
  cpd.setBoundaryCondition(False)
449
449
  cpd.setHasOnlySubstanceUnits(False)
450
+ cpd.setCompartment("compartment")
450
451
  # cpd.setUnit() # FIXME: implement
451
452
  if isinstance((init := variable.initial_value), InitialAssignment):
452
453
  ar = sbml_model.createInitialAssignment()
@@ -455,7 +456,7 @@ def _create_sbml_variables(
455
456
  ar.setVariable(_convert_id_to_sbml(id_=name, prefix="IA"))
456
457
  ar.setMath(_sbmlify_fn(init.fn, init.args))
457
458
  else:
458
- cpd.setInitialAmount(float(init))
459
+ cpd.setInitialConcentration(float(init))
459
460
 
460
461
 
461
462
  def _create_sbml_derived_variables(*, model: Model, sbml_model: libsbml.Model) -> None:
@@ -591,8 +592,8 @@ def _default_compartments(
591
592
  ) -> dict[str, Compartment]:
592
593
  if compartments is None:
593
594
  return {
594
- "c": Compartment(
595
- name="cytosol",
595
+ "compartment": Compartment(
596
+ name="compartment",
596
597
  dimensions=3,
597
598
  size=1,
598
599
  units="litre",
mxlpy/scan.py CHANGED
@@ -39,12 +39,14 @@ if TYPE_CHECKING:
39
39
 
40
40
 
41
41
  __all__ = [
42
+ "ProtocolTimeCourseWorker",
42
43
  "ProtocolWorker",
43
44
  "SteadyStateWorker",
44
45
  "TimeCourseWorker",
46
+ "protocol",
47
+ "protocol_time_course",
45
48
  "steady_state",
46
49
  "time_course",
47
- "time_course_over_protocol",
48
50
  ]
49
51
 
50
52
 
@@ -70,11 +72,6 @@ def _update_parameters_and_initial_conditions[T](
70
72
  return fn(model)
71
73
 
72
74
 
73
- ###############################################################################
74
- # Single returns
75
- ###############################################################################
76
-
77
-
78
75
  ###############################################################################
79
76
  # Workers
80
77
  ###############################################################################
@@ -126,6 +123,22 @@ class ProtocolWorker(Protocol):
126
123
  ...
127
124
 
128
125
 
126
+ class ProtocolTimeCourseWorker(Protocol):
127
+ """Worker that simulates a model over a protocol, reporting results at given time points."""
128
+
129
+ def __call__(
130
+ self,
131
+ model: Model,
132
+ protocol: pd.DataFrame,
133
+ time_points: Array,
134
+ *,
135
+ integrator: IntegratorType | None,
136
+ y0: dict[str, float] | None,
137
+ ) -> Result:
138
+ """Simulate `model` following `protocol` and return the result at `time_points`."""
139
+ ...
140
+
141
+
129
142
  def _steady_state_worker(
130
143
  model: Model,
131
144
  *,
@@ -228,6 +241,42 @@ def _protocol_worker(
228
241
  return Result.default(model=model, time_points=time_points) if res is None else res
229
242
 
230
243
 
244
def _protocol_time_course_worker(
    model: Model,
    protocol: pd.DataFrame,
    time_points: Array,
    *,
    integrator: IntegratorType | None,
    y0: dict[str, float] | None,
) -> Result:
    """Simulate `model` following `protocol`, reporting results at `time_points`.

    A `ZeroDivisionError` raised anywhere during setup or simulation counts as
    a failed run, as does a simulation that yields no result. In both cases a
    default `Result` built from `model` and `time_points` is returned instead.

    Args:
        model: Model instance to simulate.
        protocol: DataFrame containing the protocol steps.
        time_points: Time points at which to return the simulation results.
        integrator: Integrator to use for the time course simulation.
        y0: Initial conditions as a dictionary {variable: value}.

    Returns:
        Result: Concentrations and fluxes of the protocol time course.

    """
    try:
        simulated = Simulator(
            model, integrator=integrator, y0=y0
        ).simulate_protocol_time_course(
            protocol=protocol,
            time_points=time_points,
        )
        result = simulated.get_result()
    except ZeroDivisionError:
        result = None

    if result is not None:
        return result
    return Result.default(model=model, time_points=time_points)
278
+
279
+
231
280
  def steady_state(
232
281
  model: Model,
233
282
  *,
@@ -393,7 +442,7 @@ def time_course(
393
442
  )
394
443
 
395
444
 
396
- def time_course_over_protocol(
445
+ def protocol(
397
446
  model: Model,
398
447
  *,
399
448
  to_scan: pd.DataFrame,
@@ -460,3 +509,72 @@ def time_course_over_protocol(
460
509
  protocol=protocol,
461
510
  raw_results=dict(res),
462
511
  )
512
+
513
+
514
def protocol_time_course(
    model: Model,
    *,
    to_scan: pd.DataFrame,
    protocol: pd.DataFrame,
    time_points: Array,
    y0: dict[str, float] | None = None,
    parallel: bool = True,
    cache: Cache | None = None,
    worker: ProtocolTimeCourseWorker = _protocol_time_course_worker,
    integrator: IntegratorType | None = None,
) -> ProtocolScan:
    """Simulate a protocol time course for every row of `to_scan`.

    Each row of `to_scan` supplies parameter values and/or initial
    conditions for one run. The model is then simulated following
    `protocol`, and results are reported at `time_points`.

    Examples:
        >>> scan.protocol_time_course(
        ...     model,
        ...     to_scan=pd.DataFrame({"k2": np.linspace(1, 2, 11)}),
        ...     protocol=make_protocol(
        ...         {
        ...             1: {"k1": 1},
        ...             2: {"k1": 2},
        ...         }
        ...     ),
        ...     time_points=np.linspace(0, 2, 21),
        ... )

    Args:
        model: Model instance to simulate.
        to_scan: DataFrame containing parameter or initial values to scan;
            one simulation is run per row.
        protocol: Protocol to follow for the simulation.
        time_points: Time points at which to return simulation results.
        y0: Initial conditions as a dictionary {variable: value}. If given,
            they are written into `model` via `model.update_variables`
            before the scan starts.
        parallel: Whether to execute in parallel (default: True).
        cache: Optional cache to store and retrieve results.
        worker: Worker function to use for the simulation.
        integrator: Integrator to use for the time course simulations.

    Returns:
        ProtocolScan: Protocol time course results for each row of `to_scan`.

    """
    # Initial conditions are applied separately here, because `to_scan` may
    # itself contain initial conditions that must take precedence per run.
    if y0 is not None:
        model.update_variables(y0)

    res = parallelise(
        partial(
            _update_parameters_and_initial_conditions,
            fn=partial(
                worker,
                protocol=protocol,
                time_points=time_points,
                integrator=integrator,
                # y0 has already been applied to `model` above.
                y0=None,
            ),
            model=model,
        ),
        inputs=list(to_scan.iterrows()),
        cache=cache,
        parallel=parallel,
    )
    return ProtocolScan(
        to_scan=to_scan,
        protocol=protocol,
        raw_results=dict(res),
    )
mxlpy/simulator.py CHANGED
@@ -17,6 +17,7 @@ from typing import TYPE_CHECKING, Self, cast
17
17
  import numpy as np
18
18
  import pandas as pd
19
19
  from sympy import lambdify
20
+ from wadler_lindig import pformat
20
21
 
21
22
  from mxlpy.integrators import DefaultIntegrator
22
23
  from mxlpy.symbolic import to_symbolic_model
@@ -63,6 +64,10 @@ class Simulator:
63
64
  _integrator_type: IntegratorType
64
65
  _time_shift: float | None
65
66
 
67
+ def __repr__(self) -> str:
68
+ """Return a pretty-printed representation produced by `pformat`."""
69
+ return pformat(self)
70
+
66
71
  def __init__(
67
72
  self,
68
73
  model: Model,