mxlpy 0.18.0__py3-none-any.whl → 0.20.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51)
  1. mxlpy/__init__.py +13 -9
  2. mxlpy/compare.py +240 -0
  3. mxlpy/experimental/diff.py +16 -4
  4. mxlpy/fit.py +6 -11
  5. mxlpy/fns.py +37 -42
  6. mxlpy/identify.py +10 -3
  7. mxlpy/integrators/__init__.py +4 -3
  8. mxlpy/integrators/int_assimulo.py +16 -9
  9. mxlpy/integrators/int_scipy.py +13 -9
  10. mxlpy/label_map.py +7 -3
  11. mxlpy/linear_label_map.py +4 -2
  12. mxlpy/mc.py +5 -14
  13. mxlpy/mca.py +4 -4
  14. mxlpy/meta/__init__.py +6 -4
  15. mxlpy/meta/codegen_latex.py +180 -87
  16. mxlpy/meta/codegen_modebase.py +3 -1
  17. mxlpy/meta/codegen_py.py +11 -3
  18. mxlpy/meta/source_tools.py +9 -5
  19. mxlpy/model.py +187 -100
  20. mxlpy/nn/__init__.py +24 -5
  21. mxlpy/nn/_keras.py +92 -0
  22. mxlpy/nn/_torch.py +25 -18
  23. mxlpy/npe/__init__.py +21 -16
  24. mxlpy/npe/_keras.py +326 -0
  25. mxlpy/npe/_torch.py +56 -60
  26. mxlpy/parallel.py +5 -2
  27. mxlpy/parameterise.py +11 -3
  28. mxlpy/plot.py +205 -52
  29. mxlpy/report.py +33 -8
  30. mxlpy/sbml/__init__.py +3 -3
  31. mxlpy/sbml/_data.py +7 -6
  32. mxlpy/sbml/_export.py +8 -1
  33. mxlpy/sbml/_mathml.py +8 -7
  34. mxlpy/sbml/_name_conversion.py +5 -1
  35. mxlpy/scan.py +14 -19
  36. mxlpy/simulator.py +34 -31
  37. mxlpy/surrogates/__init__.py +25 -17
  38. mxlpy/surrogates/_keras.py +139 -0
  39. mxlpy/surrogates/_poly.py +25 -10
  40. mxlpy/surrogates/_qss.py +34 -0
  41. mxlpy/surrogates/_torch.py +50 -32
  42. mxlpy/symbolic/__init__.py +5 -3
  43. mxlpy/symbolic/strikepy.py +5 -2
  44. mxlpy/symbolic/symbolic_model.py +14 -5
  45. mxlpy/types.py +61 -120
  46. {mxlpy-0.18.0.dist-info → mxlpy-0.20.0.dist-info}/METADATA +25 -24
  47. mxlpy-0.20.0.dist-info/RECORD +55 -0
  48. mxlpy/nn/_tensorflow.py +0 -0
  49. mxlpy-0.18.0.dist-info/RECORD +0 -51
  50. {mxlpy-0.18.0.dist-info → mxlpy-0.20.0.dist-info}/WHEEL +0 -0
  51. {mxlpy-0.18.0.dist-info → mxlpy-0.20.0.dist-info}/licenses/LICENSE +0 -0
mxlpy/integrators/int_scipy.py CHANGED
@@ -2,13 +2,8 @@
2
2
 
3
3
  from __future__ import annotations
4
4
 
5
- from dataclasses import dataclass, field
6
-
7
- __all__ = [
8
- "Scipy",
9
- ]
10
-
11
5
  import copy
6
+ from dataclasses import dataclass, field
12
7
  from typing import TYPE_CHECKING, cast
13
8
 
14
9
  import numpy as np
@@ -20,6 +15,11 @@ if TYPE_CHECKING:
20
15
  from collections.abc import Callable
21
16
 
22
17
 
18
+ __all__ = [
19
+ "Scipy",
20
+ ]
21
+
22
+
23
23
  @dataclass
24
24
  class Scipy:
25
25
  """Scipy integrator for solving ODEs.
@@ -108,10 +108,14 @@ class Scipy:
108
108
  rtol=self.rtol,
109
109
  method="LSODA",
110
110
  )
111
+
111
112
  if res.success:
112
- self.t0 = time_points[-1]
113
- self.y0 = res.y[:, -1]
114
- return np.array(time_points, dtype=float), res.y.T
113
+ t = np.atleast_1d(np.array(res.t, dtype=float))
114
+ y = np.atleast_2d(np.array(res.y, dtype=float).T)
115
+
116
+ self.t0 = t[-1]
117
+ self.y0 = y[-1]
118
+ return t, y
115
119
  return None, None
116
120
 
117
121
  def integrate_to_steady_state(
mxlpy/label_map.py CHANGED
@@ -24,11 +24,15 @@ import numpy as np
24
24
 
25
25
  from mxlpy.model import Model
26
26
 
27
- __all__ = ["LabelMapper"]
28
-
29
27
  if TYPE_CHECKING:
30
28
  from collections.abc import Callable, Mapping
31
29
 
30
+ from mxlpy.types import Derived
31
+
32
+ __all__ = [
33
+ "LabelMapper",
34
+ ]
35
+
32
36
 
33
37
  def _total_concentration(*args: float) -> float:
34
38
  """Calculate sum of isotopomer concentrations.
@@ -552,7 +556,7 @@ class LabelMapper:
552
556
  for name, dp in self.model.derived_parameters.items():
553
557
  m.add_derived(name, fn=dp.fn, args=dp.args)
554
558
 
555
- variables: dict[str, float] = {}
559
+ variables: dict[str, float | Derived] = {}
556
560
  for k, v in self.model.variables.items():
557
561
  if (isos := isotopomers.get(k)) is None:
558
562
  variables[k] = v
mxlpy/linear_label_map.py CHANGED
@@ -16,13 +16,15 @@ from typing import TYPE_CHECKING
16
16
 
17
17
  from mxlpy.model import Derived, Model
18
18
 
19
- __all__ = ["LinearLabelMapper"]
20
-
21
19
  if TYPE_CHECKING:
22
20
  from collections.abc import Mapping
23
21
 
24
22
  import pandas as pd
25
23
 
24
+ __all__ = [
25
+ "LinearLabelMapper",
26
+ ]
27
+
26
28
 
27
29
  def _generate_isotope_labels(base_name: str, num_labels: int) -> list[str]:
28
30
  """Generate list of isotopomer names for a compound.
mxlpy/mc.py CHANGED
@@ -42,6 +42,11 @@ from mxlpy.types import (
42
42
  TimeCourseByPars,
43
43
  )
44
44
 
45
+ if TYPE_CHECKING:
46
+ from mxlpy.model import Model
47
+ from mxlpy.types import Array
48
+
49
+
45
50
  __all__ = [
46
51
  "ParameterScanWorker",
47
52
  "parameter_elasticities",
@@ -53,20 +58,6 @@ __all__ = [
53
58
  "variable_elasticities",
54
59
  ]
55
60
 
56
- if TYPE_CHECKING:
57
- from mxlpy.model import Model
58
- from mxlpy.types import Array
59
-
60
- __ALL__ = [
61
- "steady_state",
62
- "time_course",
63
- "time_course_over_protocol",
64
- "parameter_scan_ss",
65
- "compound_elasticities",
66
- "parameter_elasticities",
67
- "response_coefficients",
68
- ]
69
-
70
61
 
71
62
  class ParameterScanWorker(Protocol):
72
63
  """Protocol for the parameter scan worker function."""
mxlpy/mca.py CHANGED
@@ -27,16 +27,16 @@ from mxlpy.parallel import parallelise
27
27
  from mxlpy.scan import _steady_state_worker
28
28
  from mxlpy.types import ResponseCoefficients
29
29
 
30
+ if TYPE_CHECKING:
31
+ from mxlpy.model import Model
32
+ from mxlpy.types import IntegratorType
33
+
30
34
  __all__ = [
31
35
  "parameter_elasticities",
32
36
  "response_coefficients",
33
37
  "variable_elasticities",
34
38
  ]
35
39
 
36
- if TYPE_CHECKING:
37
- from mxlpy.model import Model
38
- from mxlpy.types import IntegratorType
39
-
40
40
 
41
41
  def _response_coefficient_worker(
42
42
  parameter: str,
mxlpy/meta/__init__.py CHANGED
@@ -1,11 +1,13 @@
1
1
  """Metaprogramming facilities."""
2
2
 
3
+ from __future__ import annotations
4
+
5
+ from .codegen_latex import generate_latex_code
6
+ from .codegen_modebase import generate_mxlpy_code
7
+ from .codegen_py import generate_model_code_py
8
+
3
9
  __all__ = [
4
10
  "generate_latex_code",
5
11
  "generate_model_code_py",
6
12
  "generate_mxlpy_code",
7
13
  ]
8
-
9
- from .codegen_latex import generate_latex_code
10
- from .codegen_modebase import generate_mxlpy_code
11
- from .codegen_py import generate_model_code_py
mxlpy/meta/codegen_latex.py CHANGED
@@ -10,6 +10,11 @@ import sympy
10
10
  from mxlpy.meta.source_tools import fn_to_sympy
11
11
  from mxlpy.types import Derived, RateFn
12
12
 
13
+ if TYPE_CHECKING:
14
+ from collections.abc import Callable, Mapping
15
+
16
+ from mxlpy import Model
17
+
13
18
  __all__ = [
14
19
  "TexExport",
15
20
  "TexReaction",
@@ -18,12 +23,6 @@ __all__ = [
18
23
  "get_model_tex_diff",
19
24
  ]
20
25
 
21
- if TYPE_CHECKING:
22
- from collections.abc import Callable, Mapping
23
-
24
- from mxlpy import Model
25
-
26
-
27
26
  cdot = r"\cdot"
28
27
  empty_set = r"\emptyset"
29
28
  left_right_arrows = r"\xleftrightharpoons{}"
@@ -94,8 +93,38 @@ def _escape_non_math(s: str) -> str:
94
93
  return s.replace("_", r"\_")
95
94
 
96
95
 
97
- def _fn_to_latex(fn: Callable, arg_names: list[str]) -> str:
98
- return sympy.latex(fn_to_sympy(fn, _list_of_symbols(arg_names)))
96
+ def _name_to_latex(s: str) -> str:
97
+ return _escape_non_math(_rename_latex(s))
98
+
99
+
100
+ def _sympy_to_latex(expr: sympy.Expr) -> str:
101
+ return sympy.latex(
102
+ expr,
103
+ fold_frac_powers=True,
104
+ fold_func_brackets=True,
105
+ fold_short_frac=True,
106
+ mul_symbol="dot",
107
+ )
108
+
109
+
110
+ def _fn_to_latex(
111
+ fn: Callable,
112
+ arg_names: list[str],
113
+ long_name_cutoff: int,
114
+ ) -> tuple[str, dict[str, str]]:
115
+ tex_names = [_mathrm(_name_to_latex(i)) for i in arg_names]
116
+ long_names = (
117
+ k
118
+ for k, k_orig in zip(tex_names, arg_names, strict=True)
119
+ if len(k_orig) >= long_name_cutoff
120
+ )
121
+ replacements = {k: _name_to_latex(f"_x{i}") for i, k in enumerate(long_names)}
122
+
123
+ expr = fn_to_sympy(
124
+ fn, _list_of_symbols([replacements.get(k, k) for k in tex_names])
125
+ )
126
+ fn_str = _sympy_to_latex(expr)
127
+ return fn_str, replacements
99
128
 
100
129
 
101
130
  def _table(
@@ -213,6 +242,12 @@ def _latex_list(rows: list[str]) -> str:
213
242
  return "\n\n".join(rows)
214
243
 
215
244
 
245
+ def _latex_align(items: list[str]) -> str:
246
+ return rf"""\begin{{align*}}
247
+ {"\n".join(items)}
248
+ \end{{align*}}"""
249
+
250
+
216
251
  def _latex_list_as_sections(
217
252
  rows: list[tuple[str, str]], sec_fn: Callable[[str], str]
218
253
  ) -> str:
@@ -220,7 +255,7 @@ def _latex_list_as_sections(
220
255
  [
221
256
  "\n".join(
222
257
  (
223
- sec_fn(_escape_non_math(name)),
258
+ sec_fn(_name_to_latex(name)),
224
259
  content,
225
260
  )
226
261
  )
@@ -234,7 +269,7 @@ def _latex_list_as_bold(rows: list[tuple[str, str]]) -> str:
234
269
  [
235
270
  "\n".join(
236
271
  (
237
- _bold(_escape_non_math(name)) + r"\\",
272
+ _bold(_name_to_latex(name)) + r"\\",
238
273
  content,
239
274
  r"\vspace{20pt}",
240
275
  )
@@ -244,33 +279,57 @@ def _latex_list_as_bold(rows: list[tuple[str, str]]) -> str:
244
279
  )
245
280
 
246
281
 
247
- def _stoichiometries_to_latex(stoich: Mapping[str, float | Derived]) -> str:
248
- def optional_factor(k: str, v: float) -> str:
249
- if v == 1:
250
- return _mathrm(k)
251
- if v == -1:
252
- return f"-{_mathrm(k)}"
253
- return f"{v} {cdot} {_mathrm(k)}"
282
+ def _replacements_in_align(replacements: dict[str, str]) -> str:
283
+ reps = "\n".join(rf"&\qquad {v} :: {k} \\" for k, v in replacements.items())
254
284
 
255
- def latex_for_empty(s: str) -> str:
256
- if len(s) == 0:
257
- return empty_set
258
- return s
285
+ return rf"""&\quad \mathrm{{with}}\\
286
+ {reps}\\"""
259
287
 
260
- line = []
261
- for k, v in stoich.items():
262
- if isinstance(v, int | float):
263
- line.append(optional_factor(k, v))
264
- else:
265
- line.append(_fn_to_latex(v.fn, [_rename_latex(i) for i in v.args]))
266
288
 
267
- line_str = f"{line[0]}"
268
- for element in line[1:]:
269
- if element.startswith("-"):
270
- line_str += f" {element}"
289
+ def _diff_eq(name: str) -> str:
290
+ return rf"\frac{{d\left({name}\right)}}{{dt}}"
291
+
292
+
293
+ def _optional_factor(k: str, v: float) -> str:
294
+ if v == 1:
295
+ return k
296
+ if v == -1:
297
+ return f"-{k}"
298
+ return f"{v} {cdot} {k}"
299
+
300
+
301
+ def _stoichs_to_latex(
302
+ stoichs: Mapping[str, float | Derived],
303
+ long_name_cutoff: int,
304
+ ) -> tuple[str, dict[str, str]]:
305
+ replacements = {}
306
+ expr = sympy.Integer(0)
307
+
308
+ for rxn_name, rxn_stoich in stoichs.items():
309
+ rxn_name = _name_to_latex(rxn_name) # noqa: PLW2901
310
+
311
+ if isinstance(rxn_stoich, Derived):
312
+ arg_names = [_mathrm(_name_to_latex(i)) for i in rxn_stoich.args]
313
+ long_names = (
314
+ k
315
+ for k, k_orig in zip(arg_names, rxn_stoich.args, strict=True)
316
+ if len(k_orig) >= long_name_cutoff
317
+ )
318
+ replacements.update(
319
+ {
320
+ k: _name_to_latex(f"_x{i}")
321
+ for i, k in enumerate(long_names, len(replacements))
322
+ }
323
+ )
324
+ sympy_fn = fn_to_sympy(
325
+ rxn_stoich.fn,
326
+ _list_of_symbols([replacements.get(k, k) for k in arg_names]),
327
+ )
328
+ expr = expr + sympy_fn * sympy.Symbol(rxn_name) # type: ignore
271
329
  else:
272
- line_str += f" + {element}"
273
- return _math_il(line_str)
330
+ expr = expr + rxn_stoich * sympy.Symbol(rxn_name) # type: ignore
331
+
332
+ return _sympy_to_latex(expr.subs(1.0, 1).simplify()), replacements
274
333
 
275
334
 
276
335
  @dataclass
@@ -333,7 +392,7 @@ class TexExport:
333
392
  variables: dict[str, float]
334
393
  derived: dict[str, Derived]
335
394
  reactions: dict[str, TexReaction]
336
- stoichiometries: dict[str, Mapping[str, float | Derived]]
395
+ diff_eqs: dict[str, Mapping[str, float | Derived]]
337
396
 
338
397
  @staticmethod
339
398
  def _diff_parameters(
@@ -390,10 +449,8 @@ class TexExport:
390
449
  variables=self._diff_variables(self.variables, other.variables),
391
450
  derived=self._diff_derived(self.derived, other.derived),
392
451
  reactions=self._diff_reactions(self.reactions, other.reactions),
393
- stoichiometries={
394
- k: v
395
- for k, v in other.stoichiometries.items()
396
- if self.stoichiometries.get(k, {}) != v
452
+ diff_eqs={
453
+ k: v for k, v in other.diff_eqs.items() if self.diff_eqs.get(k, {}) != v
397
454
  },
398
455
  )
399
456
 
@@ -440,9 +497,9 @@ class TexExport:
440
497
  )
441
498
  for k, v in self.reactions.items()
442
499
  },
443
- stoichiometries={
500
+ diff_eqs={
444
501
  _add_gls_if_found(k): {gls.get(k2, k2): v2 for k2, v2 in v.items()}
445
- for k, v in self.stoichiometries.items()
502
+ for k, v in self.diff_eqs.items()
446
503
  },
447
504
  )
448
505
 
@@ -466,7 +523,7 @@ class TexExport:
466
523
  headers=["Model name", "Initial concentration"],
467
524
  rows=[
468
525
  [
469
- k,
526
+ _name_to_latex(k),
470
527
  f"{v:.2e}",
471
528
  ]
472
529
  for k, v in self.variables.items()
@@ -496,7 +553,7 @@ class TexExport:
496
553
  return _table(
497
554
  headers=["Parameter name", "Parameter value"],
498
555
  rows=[
499
- [_math_il(_mathrm(_escape_non_math(_rename_latex(k)))), f"{v:.2e}"]
556
+ [_name_to_latex(k), f"{v:.2e}"]
500
557
  for k, v in sorted(self.parameters.items())
501
558
  ],
502
559
  n_columns=2,
@@ -505,7 +562,10 @@ class TexExport:
505
562
  long_desc="Model parameters",
506
563
  )
507
564
 
508
- def export_derived(self) -> str:
565
+ def export_derived(
566
+ self,
567
+ long_name_cutoff: int,
568
+ ) -> str:
509
569
  """Export derived quantities as LaTeX equations.
510
570
 
511
571
  Returns
@@ -523,16 +583,20 @@ class TexExport:
523
583
  True
524
584
 
525
585
  """
526
- return _latex_list(
527
- rows=[
528
- _dmath(
529
- f"{_rename_latex(k)} = {_fn_to_latex(v.fn, [_rename_latex(i) for i in v.args])}"
530
- )
531
- for k, v in sorted(self.derived.items())
532
- ]
533
- )
586
+ rows = []
587
+ for k, v in sorted(self.derived.items()):
588
+ fn_str, repls = _fn_to_latex(
589
+ v.fn,
590
+ arg_names=v.args,
591
+ long_name_cutoff=long_name_cutoff,
592
+ )
593
+ rows.append(f"{_mathrm(_name_to_latex(k))} &= {fn_str} \\\\")
594
+ if repls:
595
+ rows.append(_replacements_in_align(repls))
534
596
 
535
- def export_reactions(self) -> str:
597
+ return _latex_align(rows)
598
+
599
+ def export_reactions(self, long_name_cutoff: int) -> str:
536
600
  """Export reactions as LaTeX equations.
537
601
 
538
602
  Returns
@@ -550,16 +614,22 @@ class TexExport:
550
614
  True
551
615
 
552
616
  """
553
- return _latex_list(
554
- rows=[
555
- _dmath(
556
- f"{_rename_latex(k)} = {_fn_to_latex(v.fn, [_rename_latex(i) for i in v.args])}"
557
- )
558
- for k, v in sorted(self.reactions.items())
559
- ]
560
- )
617
+ rows = []
618
+ for k, v in sorted(self.reactions.items()):
619
+ fn_str, repls = _fn_to_latex(
620
+ v.fn,
621
+ arg_names=v.args,
622
+ long_name_cutoff=long_name_cutoff,
623
+ )
624
+ rows.append(f"{_mathrm(_name_to_latex(k))} &= {fn_str} \\\\")
625
+ if repls:
626
+ rows.append(_replacements_in_align(repls))
627
+ return _latex_align(rows)
561
628
 
562
- def export_stoichiometries(self) -> str:
629
+ def export_diff_eqs(
630
+ self,
631
+ long_name_cutoff: int,
632
+ ) -> str:
563
633
  """Export stoichiometries as LaTeX table.
564
634
 
565
635
  Returns
@@ -576,22 +646,20 @@ class TexExport:
576
646
  True
577
647
 
578
648
  """
579
- return _table(
580
- headers=["Rate name", "Stoichiometry"],
581
- rows=[
582
- [
583
- _escape_non_math(_rename_latex(k)),
584
- _stoichiometries_to_latex(v),
585
- ]
586
- for k, v in sorted(self.stoichiometries.items())
587
- ],
588
- n_columns=2,
589
- label="model-stoichs",
590
- short_desc="Model stoichiometries.",
591
- long_desc="Model stoichiometries.",
592
- )
649
+ rows = []
650
+ for var_name, stoich in sorted(self.diff_eqs.items()):
651
+ dxdt = _diff_eq(_mathrm(_name_to_latex(var_name)))
652
+ stoich_str, repls = _stoichs_to_latex(
653
+ stoich,
654
+ long_name_cutoff=long_name_cutoff,
655
+ )
656
+
657
+ rows.append(f"{dxdt} &= {stoich_str} \\\\")
658
+ if repls:
659
+ rows.append(_replacements_in_align(repls))
660
+ return _latex_align(rows)
593
661
 
594
- def export_all(self) -> str:
662
+ def export_all(self, long_name_cutoff: int = 10) -> str:
595
663
  """Export all model parts as a complete LaTeX document section.
596
664
 
597
665
  Returns
@@ -626,28 +694,35 @@ class TexExport:
626
694
  sections.append(
627
695
  (
628
696
  "Derived",
629
- self.export_derived(),
697
+ self.export_derived(
698
+ long_name_cutoff=long_name_cutoff,
699
+ ),
630
700
  )
631
701
  )
632
702
  if len(self.reactions) > 0:
633
703
  sections.append(
634
704
  (
635
705
  "Reactions",
636
- self.export_reactions(),
706
+ self.export_reactions(
707
+ long_name_cutoff=long_name_cutoff,
708
+ ),
637
709
  )
638
710
  )
639
711
  sections.append(
640
712
  (
641
- "Stoichiometries",
642
- self.export_stoichiometries(),
713
+ "Differential equations",
714
+ self.export_diff_eqs(
715
+ long_name_cutoff=long_name_cutoff,
716
+ ),
643
717
  )
644
718
  )
645
- return _latex_list_as_sections(sections, _subsubsection_)
719
+ return _latex_list_as_sections(sections, _subsection_)
646
720
 
647
721
  def export_document(
648
722
  self,
649
723
  author: str = "mxlpy",
650
724
  title: str = "Model construction",
725
+ long_name_cutoff: int = 10,
651
726
  ) -> str:
652
727
  r"""Export complete LaTeX document with all model components.
653
728
 
@@ -657,6 +732,8 @@ class TexExport:
657
732
  Name of the author for the document
658
733
  title
659
734
  Title for the document
735
+ long_name_cutoff
736
+ length of function argument names before they are shortened
660
737
 
661
738
  Returns
662
739
  -------
@@ -671,8 +748,8 @@ class TexExport:
671
748
  True
672
749
 
673
750
  """
674
- content = self.export_all()
675
- return rf"""\documentclass{{article}}
751
+ content = self.export_all(long_name_cutoff=long_name_cutoff)
752
+ return rf"""\documentclass[fleqn]{{article}}
676
753
  \usepackage[english]{{babel}}
677
754
  \usepackage[a4paper,top=2cm,bottom=2cm,left=2cm,right=2cm,marginparwidth=1.75cm]{{geometry}}
678
755
  \usepackage{{amsmath, amssymb, array, booktabs,
@@ -680,6 +757,7 @@ class TexExport:
680
757
  ragged2e, tabularx, titlesec, titling}}
681
758
  \newcommand{{\sectionbreak}}{{\clearpage}}
682
759
  \setlength{{\parindent}}{{0pt}}
760
+ \allowdisplaybreaks
683
761
 
684
762
  \title{{{title}}}
685
763
  \date{{}} % clear date
@@ -692,18 +770,24 @@ class TexExport:
692
770
 
693
771
 
694
772
  def _to_tex_export(self: Model) -> TexExport:
773
+ diff_eqs = {}
774
+ for rxn_name, rxn in self.reactions.items():
775
+ for var_name, factor in rxn.stoichiometry.items():
776
+ diff_eqs.setdefault(var_name, {})[rxn_name] = factor
777
+
695
778
  return TexExport(
696
779
  parameters=self.parameters,
697
- variables=self.variables,
780
+ variables=self.get_initial_conditions(), # FIXME: think about this later
698
781
  derived=self.derived,
699
782
  reactions={k: TexReaction(v.fn, v.args) for k, v in self.reactions.items()},
700
- stoichiometries={k: v.stoichiometry for k, v in self.reactions.items()},
783
+ diff_eqs=diff_eqs,
701
784
  )
702
785
 
703
786
 
704
787
  def generate_latex_code(
705
788
  model: Model,
706
789
  gls: dict[str, str] | None = None,
790
+ long_name_cutoff: int = 10,
707
791
  ) -> str:
708
792
  """Export model as LaTeX document.
709
793
 
@@ -713,6 +797,8 @@ def generate_latex_code(
713
797
  The model to export
714
798
  gls
715
799
  Optional glossary mapping for renaming model components
800
+ long_name_cutoff
801
+ length of function argument names before they are shortened
716
802
 
717
803
  Returns
718
804
  -------
@@ -733,13 +819,18 @@ def generate_latex_code(
733
819
 
734
820
  """
735
821
  gls = default_init(gls)
736
- return _to_tex_export(model).rename_with_glossary(gls).export_document()
822
+ return (
823
+ _to_tex_export(model)
824
+ .rename_with_glossary(gls)
825
+ .export_document(long_name_cutoff=long_name_cutoff)
826
+ )
737
827
 
738
828
 
739
829
  def get_model_tex_diff(
740
830
  m1: Model,
741
831
  m2: Model,
742
832
  gls: dict[str, str] | None = None,
833
+ long_name_cutoff: int = 10,
743
834
  ) -> str:
744
835
  """Create LaTeX diff showing changes between two models.
745
836
 
@@ -751,6 +842,8 @@ def get_model_tex_diff(
751
842
  Second model (compared against the base)
752
843
  gls
753
844
  Optional glossary mapping for renaming model components
845
+ long_name_cutoff
846
+ length of function argument names before they are shortened
754
847
 
755
848
  Returns
756
849
  -------
@@ -773,7 +866,7 @@ def get_model_tex_diff(
773
866
  return f"""{" start autogenerated ":%^60}
774
867
  {_clearpage()}
775
868
  {_subsubsection("Model changes")}{_label(section_label)}
776
- {((_to_tex_export(m1) - _to_tex_export(m2)).rename_with_glossary(gls).export_all())}
869
+ {((_to_tex_export(m1) - _to_tex_export(m2)).rename_with_glossary(gls).export_all(long_name_cutoff=long_name_cutoff))}
777
870
  {_clearpage()}
778
871
  {" end autogenerated ":%^60}
779
872
  """
mxlpy/meta/codegen_modebase.py CHANGED
@@ -13,7 +13,9 @@ from mxlpy.types import Derived
13
13
  if TYPE_CHECKING:
14
14
  from mxlpy.model import Model
15
15
 
16
- __all__ = ["generate_mxlpy_code"]
16
+ __all__ = [
17
+ "generate_mxlpy_code",
18
+ ]
17
19
 
18
20
 
19
21
  def _list_of_symbols(args: list[str]) -> list[sympy.Symbol | sympy.Expr]:
mxlpy/meta/codegen_py.py CHANGED
@@ -1,15 +1,23 @@
1
1
  """Module to export models as code."""
2
2
 
3
+ from __future__ import annotations
4
+
3
5
  import warnings
4
- from collections.abc import Callable, Generator, Iterable, Iterator
6
+ from typing import TYPE_CHECKING
5
7
 
6
8
  import sympy
7
9
 
8
10
  from mxlpy.meta.source_tools import fn_to_sympy, sympy_to_inline
9
- from mxlpy.model import Model
10
11
  from mxlpy.types import Derived
11
12
 
12
- __all__ = ["generate_model_code_py"]
13
+ if TYPE_CHECKING:
14
+ from collections.abc import Callable, Generator, Iterable, Iterator
15
+
16
+ from mxlpy.model import Model
17
+
18
+ __all__ = [
19
+ "generate_model_code_py",
20
+ ]
13
21
 
14
22
 
15
23
  def _conditional_join[T](