mxlpy 0.21.0__py3-none-any.whl → 0.23.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -17,7 +17,7 @@ with contextlib.redirect_stderr(open(os.devnull, "w")): # noqa: PTH123
 if TYPE_CHECKING:
     from collections.abc import Callable

-    from mxlpy.types import Array, ArrayLike
+    from mxlpy.types import Array, ArrayLike, Rhs


 __all__ = [
@@ -43,8 +43,8 @@ class Assimulo:

    """

-    rhs: Callable
-    y0: ArrayLike
+    rhs: Rhs
+    y0: tuple[float, ...]
    jacobian: Callable | None = None
    atol: float = 1e-8
    rtol: float = 1e-8
@@ -0,0 +1,119 @@
+"""Diffrax integrator for solving ODEs."""
+
+from __future__ import annotations
+
+from dataclasses import dataclass, field
+from typing import TYPE_CHECKING
+
+import numpy as np
+from diffrax import (
+    AbstractSolver,
+    AbstractStepSizeController,
+    Kvaerno5,
+    ODETerm,
+    PIDController,
+    SaveAt,
+    diffeqsolve,
+)
+
+__all__ = ["Diffrax"]
+
+if TYPE_CHECKING:
+    from collections.abc import Callable
+
+    from mxlpy.types import Array, Rhs
+
+
+@dataclass
+class Diffrax:
+    """Diffrax integrator for solving ODEs."""
+
+    rhs: Rhs
+    y0: tuple[float, ...]
+    jac: Callable | None = None
+    solver: AbstractSolver = field(default=Kvaerno5())
+    stepsize_controller: AbstractStepSizeController = field(
+        default=PIDController(rtol=1e-8, atol=1e-8)
+    )
+    t0: float = 0.0
+
+    def __post_init__(self) -> None:
+        """Create copy of initial state.
+
+        This method creates a copy of the initial state `y0` and stores it in the `_y0_orig` attribute.
+        This is useful for preserving the original initial state for future reference or reset operations.
+
+        """
+        self._y0_orig = self.y0
+
+    def reset(self) -> None:
+        """Reset the integrator."""
+        self.t0 = 0
+        self.y0 = self._y0_orig
+
+    def integrate_time_course(
+        self, *, time_points: Array
+    ) -> tuple[Array | None, Array | None]:
+        """Integrate the ODE system over a time course.
+
+        Args:
+            time_points: Time points for the integration.
+
+        Returns:
+            tuple[Array, Array]: Tuple containing the time points and the integrated values.
+
+        """
+        if time_points[0] != self.t0:
+            time_points = np.insert(time_points, 0, self.t0)
+
+        res = diffeqsolve(
+            ODETerm(lambda t, y, _: self.rhs(t, y)),  # type: ignore
+            solver=self.solver,
+            t0=time_points[0],
+            t1=time_points[-1],
+            dt0=None,
+            y0=self.y0,
+            max_steps=None,
+            saveat=SaveAt(ts=time_points),  # type: ignore
+            stepsize_controller=self.stepsize_controller,
+        )
+
+        t = np.atleast_1d(np.array(res.ts, dtype=float))
+        y = np.atleast_2d(np.array(res.ys, dtype=float).T)
+
+        self.t0 = t[-1]
+        self.y0 = y[-1]
+        return t, y
+
+    def integrate(
+        self,
+        *,
+        t_end: float,
+        steps: int | None = None,
+    ) -> tuple[Array | None, Array | None]:
+        """Integrate the ODE system over a time course."""
+        steps = 100 if steps is None else steps
+
+        return self.integrate_time_course(
+            time_points=np.linspace(self.t0, t_end, steps, dtype=float)
+        )
+
+    def integrate_to_steady_state(
+        self,
+        *,
+        tolerance: float,
+        rel_norm: bool,
+        t_max: float = 1_000_000_000,
+    ) -> tuple[float | None, Array | None]:
+        """Integrate the ODE system to steady state.
+
+        Args:
+            tolerance: Tolerance for determining steady state.
+            rel_norm: Whether to use relative normalization.
+            t_max: Maximum time point for the integration (default: 1,000,000,000).
+
+        Returns:
+            tuple[float | None, Array | None]: Tuple containing the final time point and the integrated values at steady state.
+
+        """
+        raise NotImplementedError
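For orientation, a minimal usage sketch of the new Diffrax wrapper as defined in the hunk above. The model, rate constant, and point counts are hypothetical; `Rhs` is the callable type from `mxlpy.types`, which the wrapper calls as `rhs(t, y)`:

import numpy as np

# Hypothetical one-variable model: dy/dt = -k * y with k = 0.5
def rhs(t: float, y: tuple[float, ...]) -> tuple[float, ...]:
    return (-0.5 * y[0],)

integrator = Diffrax(rhs=rhs, y0=(1.0,))           # Diffrax as defined above
t, y = integrator.integrate(t_end=10.0, steps=11)  # 11 points from t0 to t_end
t2, y2 = integrator.integrate_time_course(
    time_points=np.linspace(10.0, 20.0, 5)         # current t0 is prepended if missing
)

Note that, like the Scipy backend below, the integrator keeps its state: after each call `t0` and `y0` are advanced to the last returned point, and `reset()` restores the original initial state.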
@@ -14,6 +14,8 @@ from mxlpy.types import Array, ArrayLike
 if TYPE_CHECKING:
     from collections.abc import Callable

+    from mxlpy.types import Rhs
+

 __all__ = [
     "Scipy",
@@ -40,13 +42,13 @@ class Scipy:

    """

-    rhs: Callable
-    y0: ArrayLike
+    rhs: Rhs
+    y0: tuple[float, ...]
    jacobian: Callable | None = None
    atol: float = 1e-8
    rtol: float = 1e-8
    t0: float = 0.0
-    _y0_orig: ArrayLike = field(default_factory=list)
+    _y0_orig: tuple[float, ...] = field(default_factory=tuple)

    def __post_init__(self) -> None:
        """Create copy of initial state.
@@ -55,12 +57,12 @@ class Scipy:
        This is useful for preserving the original initial state for future reference or reset operations.

        """
-        self._y0_orig = self.y0.copy()
+        self._y0_orig = self.y0

    def reset(self) -> None:
        """Reset the integrator."""
        self.t0 = 0
-        self.y0 = self._y0_orig.copy()
+        self.y0 = self._y0_orig

    def integrate(
        self,
@@ -98,6 +100,9 @@ class Scipy:
            tuple[ArrayLike, ArrayLike]: Tuple containing the time points and the integrated values.

        """
+        if time_points[0] != self.t0:
+            time_points = np.insert(time_points, 0, self.t0)
+
        res = spi.solve_ivp(
            fun=self.rhs,
            y0=self.y0,
@@ -140,9 +145,13 @@ class Scipy:

        """
        self.reset()
-        integ = spi.ode(self.rhs, jac=self.jacobian)
+
+        # If rhs returns a tuple, we get weird errors, so we need
+        # to wrap this in a list for some reason
+        integ = spi.ode(lambda t, x: list(self.rhs(t, x)), jac=self.jacobian)
        integ.set_integrator(name="lsoda")
        integ.set_initial_value(self.y0)
+
        t = self.t0 + step_size
        y1 = copy.deepcopy(self.y0)
        for _ in range(max_steps):
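The list-wrapping above matters because `scipy.integrate.ode` wants a right-hand side it can coerce cleanly to an array. A rough standalone sketch of the same pattern, with a toy decay model and hypothetical names (not part of mxlpy):

import scipy.integrate as spi

def rhs(t, y):
    # Tuple-returning right-hand side, as now allowed by the Rhs type
    return (-0.5 * y[0],)

integ = spi.ode(lambda t, x: list(rhs(t, x)))  # wrap the tuple in a list
integ.set_integrator(name="lsoda")
integ.set_initial_value((1.0,))
print(integ.integrate(1.0))  # state at t = 1.0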
mxlpy/label_map.py CHANGED
@@ -27,7 +27,6 @@ from mxlpy.model import Model
 if TYPE_CHECKING:
     from collections.abc import Callable, Mapping

-    from mxlpy.types import Derived

 __all__ = [
     "LabelMapper",
@@ -551,13 +550,13 @@ class LabelMapper:

        m = Model()

-        m.add_parameters(self.model.parameters)
+        m.add_parameters(self.model.get_parameter_values())

-        for name, dp in self.model.derived_parameters.items():
+        for name, dp in self.model.get_derived_parameters().items():
            m.add_derived(name, fn=dp.fn, args=dp.args)

-        variables: dict[str, float | Derived] = {}
-        for k, v in self.model.variables.items():
+        variables: dict[str, float] = {}
+        for k, v in self.model.get_initial_conditions().items():
            if (isos := isotopomers.get(k)) is None:
                variables[k] = v
            else:
@@ -585,14 +584,14 @@ class LabelMapper:
                args=label_names,
            )

-        for name, dv in self.model.derived_variables.items():
+        for name, dv in self.model.get_derived_variables().items():
            m.add_derived(
                name,
                fn=dv.fn,
                args=[f"{i}__total" if i in isotopomers else i for i in dv.args],
            )

-        for rxn_name, rxn in self.model.reactions.items():
+        for rxn_name, rxn in self.model.get_raw_reactions().items():
            if (label_map := self.label_maps.get(rxn_name)) is None:
                m.add_reaction(
                    rxn_name,
mxlpy/linear_label_map.py CHANGED
@@ -272,8 +272,10 @@ class LinearLabelMapper:
        m = Model()
        m.add_variables(variables)
        m.add_parameters(concs.to_dict() | fluxes.to_dict() | {"EXT": external_label})
+
+        rxns = self.model.get_raw_reactions()
        for rxn_name, label_map in self.label_maps.items():
-            rxn = self.model.reactions[rxn_name]
+            rxn = rxns[rxn_name]
            subs, prods = _unpack_stoichiometries(rxn.stoichiometry)

            subs = _stoichiometry_to_duplicate_list(subs)
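Both mappers now read the model through its accessor methods instead of its attributes. Roughly, with `model` an mxlpy `Model` (sketch only, mirroring the calls used in the hunks above; "v1" is a hypothetical reaction name):

parameters = model.get_parameter_values()      # was: model.parameters
variables = model.get_initial_conditions()     # was: model.variables
derived_pars = model.get_derived_parameters()  # was: model.derived_parameters
derived_vars = model.get_derived_variables()   # was: model.derived_variables
reactions = model.get_raw_reactions()          # was: model.reactions
rxn = reactions["v1"]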
mxlpy/mc.py CHANGED
@@ -35,10 +35,10 @@ from mxlpy.scan import (
 from mxlpy.types import (
     IntegratorType,
     McSteadyStates,
-    ProtocolByPars,
+    ProtocolScan,
     ResponseCoefficientsByPars,
-    SteadyStates,
-    TimeCourseByPars,
+    SteadyStateScan,
+    TimeCourseScan,
 )

 if TYPE_CHECKING:
@@ -66,9 +66,10 @@ class ParameterScanWorker(Protocol):
        model: Model,
        *,
        parameters: pd.DataFrame,
+        y0: dict[str, float] | None,
        rel_norm: bool,
        integrator: IntegratorType,
-    ) -> SteadyStates:
+    ) -> SteadyStateScan:
        """Call the worker function."""
        ...

@@ -77,9 +78,10 @@ def _parameter_scan_worker(
    model: Model,
    *,
    parameters: pd.DataFrame,
+    y0: dict[str, float] | None,
    rel_norm: bool,
    integrator: IntegratorType,
-) -> SteadyStates:
+) -> SteadyStateScan:
    """Worker function for parallel steady state scanning across parameter sets.

    This function executes a parameter scan for steady state solutions for a
@@ -109,6 +111,7 @@
        parallel=False,
        rel_norm=rel_norm,
        integrator=integrator,
+        y0=y0,
    )


@@ -122,7 +125,7 @@ def steady_state(
    rel_norm: bool = False,
    worker: SteadyStateWorker = _steady_state_worker,
    integrator: IntegratorType | None = None,
-) -> SteadyStates:
+) -> SteadyStateScan:
    """Monte-carlo scan of steady states.

    Examples:
@@ -160,10 +163,14 @@
        max_workers=max_workers,
        cache=cache,
    )
-    return SteadyStates(
-        variables=pd.concat({k: v.variables for k, v in res}, axis=1).T,
-        fluxes=pd.concat({k: v.fluxes for k, v in res}, axis=1).T,
-        parameters=mc_to_scan,
+    return SteadyStateScan(
+        raw_index=(
+            pd.Index(mc_to_scan.iloc[:, 0])
+            if mc_to_scan.shape[1] == 1
+            else pd.MultiIndex.from_frame(mc_to_scan)
+        ),
+        raw_results=[i[1] for i in res],
+        to_scan=mc_to_scan,
    )

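The `raw_index` built here is a flat `pd.Index` when a single parameter is scanned and a `pd.MultiIndex` otherwise. A small pandas sketch of that branch, with hypothetical parameter names and values:

import pandas as pd

mc_to_scan = pd.DataFrame({"k1": [0.1, 0.2], "k2": [1.0, 2.0]})

raw_index = (
    pd.Index(mc_to_scan.iloc[:, 0])
    if mc_to_scan.shape[1] == 1
    else pd.MultiIndex.from_frame(mc_to_scan)
)
# Two scanned parameters -> MultiIndex with levels ("k1", "k2")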
@@ -177,7 +184,7 @@ def time_course(
    cache: Cache | None = None,
    worker: TimeCourseWorker = _time_course_worker,
    integrator: IntegratorType | None = None,
-) -> TimeCourseByPars:
+) -> TimeCourseScan:
    """MC time course.

    Examples:
@@ -216,10 +223,9 @@
        cache=cache,
    )

-    return TimeCourseByPars(
-        parameters=mc_to_scan,
-        variables=pd.concat({k: v.variables.T for k, v in res}, axis=1).T,
-        fluxes=pd.concat({k: v.fluxes.T for k, v in res}, axis=1).T,
+    return TimeCourseScan(
+        to_scan=mc_to_scan,
+        raw_results=dict(res),
    )


@@ -234,7 +240,7 @@ def time_course_over_protocol(
    cache: Cache | None = None,
    worker: ProtocolWorker = _protocol_worker,
    integrator: IntegratorType | None = None,
-) -> ProtocolByPars:
+) -> ProtocolScan:
    """MC time course.

    Examples:
@@ -274,13 +280,10 @@
        max_workers=max_workers,
        cache=cache,
    )
-    concs = {k: v.variables.T for k, v in res}
-    fluxes = {k: v.fluxes.T for k, v in res}
-    return ProtocolByPars(
-        variables=pd.concat(concs, axis=1).T,
-        fluxes=pd.concat(fluxes, axis=1).T,
-        parameters=mc_to_scan,
+    return ProtocolScan(
+        to_scan=mc_to_scan,
        protocol=protocol,
+        raw_results=dict(res),
    )


mxlpy/mca.py CHANGED
@@ -71,7 +71,7 @@ def _response_coefficient_worker(
        - Series of flux response coefficients

    """
-    old = model.parameters[parameter]
+    old = model.get_parameter_values()[parameter]
    if y0 is not None:
        model.update_variables(y0)

@@ -91,8 +91,12 @@
        y0=None,
    )

-    conc_resp = (upper.variables - lower.variables) / (2 * displacement * old)
-    flux_resp = (upper.fluxes - lower.fluxes) / (2 * displacement * old)
+    conc_resp = (upper.variables.iloc[-1] - lower.variables.iloc[-1]) / (
+        2 * displacement * old
+    )
+    flux_resp = (upper.fluxes.iloc[-1] - lower.fluxes.iloc[-1]) / (
+        2 * displacement * old
+    )
    # Reset
    model.update_parameters({parameter: old})
    if normalized:
@@ -102,8 +106,8 @@
            integrator=integrator,
            y0=None,
        )
-        conc_resp *= old / norm.variables
-        flux_resp *= old / norm.fluxes
+        conc_resp *= old / norm.variables.iloc[-1]
+        flux_resp *= old / norm.fluxes.iloc[-1]
    return conc_resp, flux_resp


@@ -205,7 +209,7 @@ def parameter_elasticities(

    variables = model.get_initial_conditions() if variables is None else variables
    for par in to_scan:
-        old = model.parameters[par]
+        old = model.get_parameter_values()[par]

        model.update_parameters({par: old * (1 + displacement)})
        upper = model.get_fluxes(variables=variables, time=time)
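The response-coefficient worker now takes `.iloc[-1]` of the simulated time courses, so the central difference is evaluated on the final (steady-state-like) row. A schematic worked example of that formula for one variable, with made-up numbers:

# Central difference on the last time point, optionally normalized
p, d = 2.0, 1e-2                               # hypothetical parameter value and displacement
x_upper, x_lower, x_ref = 1.050, 0.950, 1.000  # final values at p*(1+d), p*(1-d), and p

resp = (x_upper - x_lower) / (2 * d * p)       # absolute response coefficient: 2.5
resp_norm = resp * p / x_ref                   # normalized response coefficient: 5.0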
mxlpy/meta/__init__.py CHANGED
@@ -2,12 +2,14 @@

 from __future__ import annotations

-from .codegen_latex import generate_latex_code
-from .codegen_modebase import generate_mxlpy_code
-from .codegen_py import generate_model_code_py
+from .codegen_latex import generate_latex_code, to_tex_export
+from .codegen_model import generate_model_code_py, generate_model_code_rs
+from .codegen_mxlpy import generate_mxlpy_code

 __all__ = [
     "generate_latex_code",
     "generate_model_code_py",
+    "generate_model_code_rs",
     "generate_mxlpy_code",
+    "to_tex_export",
 ]
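The resulting public surface of `mxlpy.meta`, for reference (sketch; `model` stands for any mxlpy `Model` instance):

from mxlpy.meta import (
    generate_latex_code,
    generate_model_code_py,
    generate_model_code_rs,
    generate_mxlpy_code,
    to_tex_export,
)

# to_tex_export is newly exported; it builds the TexExport object that
# generate_latex_code renders into a LaTeX document.
tex = to_tex_export(model)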
@@ -7,7 +7,7 @@ from typing import TYPE_CHECKING

 import sympy

-from mxlpy.meta.source_tools import fn_to_sympy
+from mxlpy.meta.sympy_tools import fn_to_sympy, list_of_symbols
 from mxlpy.types import Derived, RateFn

 if TYPE_CHECKING:
@@ -21,6 +21,7 @@ __all__ = [
     "default_init",
     "generate_latex_code",
     "get_model_tex_diff",
+    "to_tex_export",
 ]

 cdot = r"\cdot"
@@ -31,10 +32,6 @@ newline = r"\\" + "\n"
 floatbarrier = r"\FloatBarrier"


-def _list_of_symbols(args: list[str]) -> list[sympy.Symbol | sympy.Expr]:
-    return [sympy.Symbol(arg) for arg in args]
-
-
 def default_init[T1, T2](d: dict[T1, T2] | None) -> dict[T1, T2]:
     """Return empty dict if d is None.

@@ -63,10 +60,6 @@ def _gls(s: str) -> str:
     return rf"\gls{{{s}}}"


-def _abbrev_and_full(s: str) -> str:
-    return rf"\acrfull{{{s}}}"
-
-
 def _gls_short(s: str) -> str:
     return rf"\acrshort{{{s}}}"

@@ -75,6 +68,10 @@ def _gls_full(s: str) -> str:
     return rf"\acrlong{{{s}}}"


+def _gls_short_and_full(s: str) -> str:
+    return rf"\acrfull{{{s}}}"
+
+
 def _rename_latex(s: str) -> str:
     if s[0].isdigit():
         s = s[1:]
@@ -109,6 +106,8 @@ def _sympy_to_latex(expr: sympy.Expr) -> str:

 def _fn_to_latex(
     fn: Callable,
+    *,
+    origin: str,
     arg_names: list[str],
     long_name_cutoff: int,
 ) -> tuple[str, dict[str, str]]:
@@ -121,10 +120,13 @@
    replacements = {k: _name_to_latex(f"_x{i}") for i, k in enumerate(long_names)}

    expr = fn_to_sympy(
-        fn, _list_of_symbols([replacements.get(k, k) for k in tex_names])
+        fn,
+        origin=origin,
+        model_args=list_of_symbols([replacements.get(k, k) for k in tex_names]),
    )
-    fn_str = _sympy_to_latex(expr)
-    return fn_str, replacements
+    if expr is None:
+        return rf"\textcolor{{red}}{{{origin}}}", replacements
+    return _sympy_to_latex(expr), replacements


def _table(
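A sketch of the new `fn_to_sympy` call shape used above, assuming the keyword-only `origin`/`model_args` form and the `None`-on-failure behaviour shown in this hunk; the rate law and name "v1" are hypothetical:

from mxlpy.meta.sympy_tools import fn_to_sympy, list_of_symbols

def mass_action(s, k):
    return k * s

expr = fn_to_sympy(
    mass_action,
    origin="v1",  # reported (in red) in the LaTeX output if conversion fails
    model_args=list_of_symbols(["s", "k"]),
)
if expr is None:
    print("could not convert 'v1' to sympy")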
@@ -323,7 +325,8 @@
            )
            sympy_fn = fn_to_sympy(
                rxn_stoich.fn,
-                _list_of_symbols([replacements.get(k, k) for k in arg_names]),
+                origin=rxn_name,
+                model_args=list_of_symbols([replacements.get(k, k) for k in arg_names]),
            )
            expr = expr + sympy_fn * sympy.Symbol(rxn_name)  # type: ignore
        else:
@@ -480,7 +483,7 @@ class TexExport:

        def _add_gls_if_found(k: str) -> str:
            if (new := gls.get(k)) is not None:
-                return _abbrev_and_full(new)
+                return _gls_short_and_full(new)
            return k

        return TexExport(
@@ -564,7 +567,7 @@ class TexExport:


    def export_derived(
        self,
-        long_name_cutoff: int,
+        long_name_cutoff: int = 10,
    ) -> str:
        """Export derived quantities as LaTeX equations.
@@ -587,16 +590,20 @@
        for k, v in sorted(self.derived.items()):
            fn_str, repls = _fn_to_latex(
                v.fn,
+                origin=k,
                arg_names=v.args,
                long_name_cutoff=long_name_cutoff,
            )
-            rows.append(f"{_mathrm(_name_to_latex(k))} &= {fn_str} \\\\")
+            rows.append(f" {_mathrm(_name_to_latex(k))} &= {fn_str} \\\\")
            if repls:
                rows.append(_replacements_in_align(repls))

        return _latex_align(rows)

-    def export_reactions(self, long_name_cutoff: int) -> str:
+    def export_reactions(
+        self,
+        long_name_cutoff: int = 10,
+    ) -> str:
        """Export reactions as LaTeX equations.

        Returns
@@ -618,17 +625,18 @@
        for k, v in sorted(self.reactions.items()):
            fn_str, repls = _fn_to_latex(
                v.fn,
+                origin=k,
                arg_names=v.args,
                long_name_cutoff=long_name_cutoff,
            )
-            rows.append(f"{_mathrm(_name_to_latex(k))} &= {fn_str} \\\\")
+            rows.append(f" {_mathrm(_name_to_latex(k))} &= {fn_str} \\\\")
            if repls:
                rows.append(_replacements_in_align(repls))
        return _latex_align(rows)

    def export_diff_eqs(
        self,
-        long_name_cutoff: int,
+        long_name_cutoff: int = 10,
    ) -> str:
        """Export stoichiometries as LaTeX table.

@@ -654,12 +662,15 @@
                long_name_cutoff=long_name_cutoff,
            )

-            rows.append(f"{dxdt} &= {stoich_str} \\\\")
+            rows.append(f" {dxdt} &= {stoich_str} \\\\")
            if repls:
                rows.append(_replacements_in_align(repls))
        return _latex_align(rows)

-    def export_all(self, long_name_cutoff: int = 10) -> str:
+    def export_all(
+        self,
+        long_name_cutoff: int = 10,
+    ) -> str:
        """Export all model parts as a complete LaTeX document section.

        Returns
@@ -754,7 +765,7 @@
 \usepackage[a4paper,top=2cm,bottom=2cm,left=2cm,right=2cm,marginparwidth=1.75cm]{{geometry}}
 \usepackage{{amsmath, amssymb, array, booktabs,
             breqn, caption, longtable, mathtools, placeins,
-            ragged2e, tabularx, titlesec, titling}}
+            ragged2e, tabularx, titlesec, titling, xcolor}}
 \newcommand{{\sectionbreak}}{{\clearpage}}
 \setlength{{\parindent}}{{0pt}}
 \allowdisplaybreaks
@@ -769,17 +780,20 @@
    """


-def _to_tex_export(self: Model) -> TexExport:
+def to_tex_export(model: Model) -> TexExport:
+    """Create TexExport object from a model."""
    diff_eqs = {}
-    for rxn_name, rxn in self.reactions.items():
+    for rxn_name, rxn in model.get_raw_reactions().items():
        for var_name, factor in rxn.stoichiometry.items():
            diff_eqs.setdefault(var_name, {})[rxn_name] = factor

    return TexExport(
-        parameters=self.parameters,
-        variables=self.get_initial_conditions(),  # FIXME: think about this later
-        derived=self.derived,
-        reactions={k: TexReaction(v.fn, v.args) for k, v in self.reactions.items()},
+        parameters=model.get_parameter_values(),
+        variables=model.get_initial_conditions(),  # FIXME: think about this later
+        derived=model.get_raw_derived(),
+        reactions={
+            k: TexReaction(v.fn, v.args) for k, v in model.get_raw_reactions().items()
+        },
        diff_eqs=diff_eqs,
    )

@@ -820,7 +834,7 @@ def generate_latex_code(

    """
    gls = default_init(gls)
    return (
-        _to_tex_export(model)
+        to_tex_export(model)
        .rename_with_glossary(gls)
        .export_document(long_name_cutoff=long_name_cutoff)
    )
@@ -866,7 +880,7 @@ def get_model_tex_diff(
    return f"""{" start autogenerated ":%^60}
{_clearpage()}
{_subsubsection("Model changes")}{_label(section_label)}
-{((_to_tex_export(m1) - _to_tex_export(m2)).rename_with_glossary(gls).export_all(long_name_cutoff=long_name_cutoff))}
+{((to_tex_export(m1) - to_tex_export(m2)).rename_with_glossary(gls).export_all(long_name_cutoff=long_name_cutoff))}
{_clearpage()}
{" end autogenerated ":%^60}
"""