mxlpy 0.18.0__py3-none-any.whl → 0.20.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51)
  1. mxlpy/__init__.py +13 -9
  2. mxlpy/compare.py +240 -0
  3. mxlpy/experimental/diff.py +16 -4
  4. mxlpy/fit.py +6 -11
  5. mxlpy/fns.py +37 -42
  6. mxlpy/identify.py +10 -3
  7. mxlpy/integrators/__init__.py +4 -3
  8. mxlpy/integrators/int_assimulo.py +16 -9
  9. mxlpy/integrators/int_scipy.py +13 -9
  10. mxlpy/label_map.py +7 -3
  11. mxlpy/linear_label_map.py +4 -2
  12. mxlpy/mc.py +5 -14
  13. mxlpy/mca.py +4 -4
  14. mxlpy/meta/__init__.py +6 -4
  15. mxlpy/meta/codegen_latex.py +180 -87
  16. mxlpy/meta/codegen_modebase.py +3 -1
  17. mxlpy/meta/codegen_py.py +11 -3
  18. mxlpy/meta/source_tools.py +9 -5
  19. mxlpy/model.py +187 -100
  20. mxlpy/nn/__init__.py +24 -5
  21. mxlpy/nn/_keras.py +92 -0
  22. mxlpy/nn/_torch.py +25 -18
  23. mxlpy/npe/__init__.py +21 -16
  24. mxlpy/npe/_keras.py +326 -0
  25. mxlpy/npe/_torch.py +56 -60
  26. mxlpy/parallel.py +5 -2
  27. mxlpy/parameterise.py +11 -3
  28. mxlpy/plot.py +205 -52
  29. mxlpy/report.py +33 -8
  30. mxlpy/sbml/__init__.py +3 -3
  31. mxlpy/sbml/_data.py +7 -6
  32. mxlpy/sbml/_export.py +8 -1
  33. mxlpy/sbml/_mathml.py +8 -7
  34. mxlpy/sbml/_name_conversion.py +5 -1
  35. mxlpy/scan.py +14 -19
  36. mxlpy/simulator.py +34 -31
  37. mxlpy/surrogates/__init__.py +25 -17
  38. mxlpy/surrogates/_keras.py +139 -0
  39. mxlpy/surrogates/_poly.py +25 -10
  40. mxlpy/surrogates/_qss.py +34 -0
  41. mxlpy/surrogates/_torch.py +50 -32
  42. mxlpy/symbolic/__init__.py +5 -3
  43. mxlpy/symbolic/strikepy.py +5 -2
  44. mxlpy/symbolic/symbolic_model.py +14 -5
  45. mxlpy/types.py +61 -120
  46. {mxlpy-0.18.0.dist-info → mxlpy-0.20.0.dist-info}/METADATA +25 -24
  47. mxlpy-0.20.0.dist-info/RECORD +55 -0
  48. mxlpy/nn/_tensorflow.py +0 -0
  49. mxlpy-0.18.0.dist-info/RECORD +0 -51
  50. {mxlpy-0.18.0.dist-info → mxlpy-0.20.0.dist-info}/WHEEL +0 -0
  51. {mxlpy-0.18.0.dist-info → mxlpy-0.20.0.dist-info}/licenses/LICENSE +0 -0
@@ -1,17 +1,21 @@
1
1
  """Tools for working with python source files."""
2
2
 
3
+ from __future__ import annotations
4
+
3
5
  import ast
4
6
  import inspect
5
7
  import textwrap
6
- from collections.abc import Callable
7
8
  from dataclasses import dataclass
8
- from types import ModuleType
9
- from typing import Any, cast
9
+ from typing import TYPE_CHECKING, Any, cast
10
10
 
11
11
  import dill
12
12
  import sympy
13
13
  from sympy.printing.pycode import pycode
14
14
 
15
+ if TYPE_CHECKING:
16
+ from collections.abc import Callable
17
+ from types import ModuleType
18
+
15
19
  __all__ = [
16
20
  "Context",
17
21
  "fn_to_sympy",
@@ -35,7 +39,7 @@ class Context:
35
39
  symbols: dict[str, sympy.Symbol | sympy.Expr] | None = None,
36
40
  caller: Callable | None = None,
37
41
  parent_module: ModuleType | None = None,
38
- ) -> "Context":
42
+ ) -> Context:
39
43
  """Update the context with new values."""
40
44
  return Context(
41
45
  symbols=self.symbols if symbols is None else symbols,
@@ -434,5 +438,5 @@ def _handle_call(node: ast.Call, ctx: Context) -> sympy.Expr:
434
438
  ctx=ctx.updated(parent_module=imports[module_name.id]),
435
439
  )
436
440
 
437
- msg = f"Onsupported function type {node.func}"
441
+ msg = f"Unsupported function type {node.func}"
438
442
  raise NotImplementedError(msg)
mxlpy/model.py CHANGED
@@ -17,13 +17,13 @@ import numpy as np
17
17
  import pandas as pd
18
18
 
19
19
  from mxlpy import fns
20
- from mxlpy.types import (
21
- AbstractSurrogate,
22
- Array,
23
- Derived,
24
- Reaction,
25
- Readout,
26
- )
20
+ from mxlpy.types import AbstractSurrogate, Array, Derived, Reaction, Readout
21
+
22
+ if TYPE_CHECKING:
23
+ from collections.abc import Iterable, Mapping
24
+ from inspect import FullArgSpec
25
+
26
+ from mxlpy.types import Callable, Param, RateFn, RetType
27
27
 
28
28
  __all__ = [
29
29
  "ArityMismatchError",
@@ -34,12 +34,6 @@ __all__ = [
34
34
  "ModelCache",
35
35
  ]
36
36
 
37
- if TYPE_CHECKING:
38
- from collections.abc import Iterable, Mapping
39
- from inspect import FullArgSpec
40
-
41
- from mxlpy.types import Callable, Param, RateFn, RetType
42
-
43
37
 
44
38
  @dataclass
45
39
  class Dependency:
@@ -250,17 +244,17 @@ class ModelCache:
250
244
  stoich_by_cpds: A dictionary mapping compound names to their stoichiometric coefficients.
251
245
  dyn_stoich_by_cpds: A dictionary mapping compound names to their dynamic stoichiometric coefficients.
252
246
  dxdt: A pandas Series representing the rate of change of variables.
247
+ initial_conditions: calculated initial conditions
253
248
 
254
249
  """
255
250
 
256
251
  var_names: list[str]
257
- order: list[str]
252
+ dyn_order: list[str]
258
253
  all_parameter_values: dict[str, float]
259
- derived_parameter_names: list[str]
260
- derived_variable_names: list[str]
261
254
  stoich_by_cpds: dict[str, dict[str, float]]
262
255
  dyn_stoich_by_cpds: dict[str, dict[str, Derived]]
263
256
  dxdt: pd.Series
257
+ initial_conditions: dict[str, float]
264
258
 
265
259
 
266
260
  @dataclass(slots=True)
@@ -276,17 +270,19 @@ class Model:
276
270
  _reactions: Dictionary of reactions in the model.
277
271
  _surrogates: Dictionary of surrogate models.
278
272
  _cache: Cache for storing model-related data structures.
273
+ _data: Named references to data sets
279
274
 
280
275
  """
281
276
 
282
277
  _ids: dict[str, str] = field(default_factory=dict)
283
- _variables: dict[str, float] = field(default_factory=dict)
278
+ _variables: dict[str, float | Derived] = field(default_factory=dict)
284
279
  _parameters: dict[str, float] = field(default_factory=dict)
285
280
  _derived: dict[str, Derived] = field(default_factory=dict)
286
281
  _readouts: dict[str, Readout] = field(default_factory=dict)
287
282
  _reactions: dict[str, Reaction] = field(default_factory=dict)
288
283
  _surrogates: dict[str, AbstractSurrogate] = field(default_factory=dict)
289
284
  _cache: ModelCache | None = None
285
+ _data: dict[str, pd.Series | pd.DataFrame] = field(default_factory=dict)
290
286
 
291
287
  ###########################################################################
292
288
  # Cache
@@ -317,9 +313,17 @@ class Model:
317
313
  raise ArityMismatchError(name, el.fn, el.args)
318
314
 
319
315
  # Sort derived & reactions
320
- to_sort = self._derived | self._reactions | self._surrogates
316
+ to_sort = (
317
+ self._derived
318
+ | self._reactions
319
+ | self._surrogates
320
+ | {k: v for k, v in self._variables.items() if isinstance(v, Derived)}
321
+ )
321
322
  order = _sort_dependencies(
322
- available=set(self._parameters) | set(self._variables) | {"time"},
323
+ available=all_parameter_names
324
+ | {k for k, v in self._variables.items() if not isinstance(v, Derived)}
325
+ | set(self._data)
326
+ | {"time"},
323
327
  elements=[
324
328
  Dependency(name=k, required=set(v.args), provided={k})
325
329
  if not isinstance(v, AbstractSurrogate)
@@ -328,35 +332,42 @@ class Model:
328
332
  ],
329
333
  )
330
334
 
331
- # Split derived into parameters and variables
332
- # for user convenience
333
- derived_variable_names: list[str] = []
334
- derived_parameter_names: list[str] = []
335
+ # Calculate all values once, including dynamic ones
336
+ # That way, we can make initial conditions dependent on e.g. rates
337
+ dependent = (
338
+ all_parameter_values
339
+ | self._data
340
+ | {k: v for k, v in self._variables.items() if not isinstance(v, Derived)}
341
+ | {"time": 0.0}
342
+ )
343
+ for name in order:
344
+ to_sort[name].calculate_inpl(name, dependent)
345
+
346
+ # Split derived into static and dynamic variables
347
+ static_order = []
348
+ dyn_order = []
335
349
  for name in order:
336
350
  if name in self._reactions or name in self._surrogates:
337
- continue
338
- derived = self._derived[name]
339
- if all(i in all_parameter_names for i in derived.args):
340
- all_parameter_names.add(name)
341
- derived_parameter_names.append(name)
342
- all_parameter_values[name] = float(
343
- derived.fn(*(all_parameter_values[i] for i in derived.args))
344
- )
351
+ dyn_order.append(name)
352
+ elif name in self._variables:
353
+ static_order.append(name)
345
354
  else:
346
- derived_variable_names.append(name)
355
+ derived = self._derived[name]
356
+ if all(i in all_parameter_names for i in derived.args):
357
+ static_order.append(name)
358
+ all_parameter_names.add(name)
359
+ else:
360
+ dyn_order.append(name)
347
361
 
362
+ # Calculate dynamic and static stochiometries
348
363
  stoich_by_compounds: dict[str, dict[str, float]] = {}
349
364
  dyn_stoich_by_compounds: dict[str, dict[str, Derived]] = {}
350
-
351
365
  for rxn_name, rxn in self._reactions.items():
352
366
  for cpd_name, factor in rxn.stoichiometry.items():
353
367
  d_static = stoich_by_compounds.setdefault(cpd_name, {})
354
-
355
368
  if isinstance(factor, Derived):
356
369
  if all(i in all_parameter_names for i in factor.args):
357
- d_static[rxn_name] = float(
358
- factor.fn(*(all_parameter_values[i] for i in factor.args))
359
- )
370
+ d_static[rxn_name] = factor.calculate(dependent)
360
371
  else:
361
372
  dyn_stoich_by_compounds.setdefault(cpd_name, {})[rxn_name] = (
362
373
  factor
@@ -367,20 +378,40 @@ class Model:
367
378
  for surrogate in self._surrogates.values():
368
379
  for rxn_name, rxn in surrogate.stoichiometries.items():
369
380
  for cpd_name, factor in rxn.items():
370
- stoich_by_compounds.setdefault(cpd_name, {})[rxn_name] = factor
381
+ d_static = stoich_by_compounds.setdefault(cpd_name, {})
382
+ if isinstance(factor, Derived):
383
+ if all(i in all_parameter_names for i in factor.args):
384
+ d_static[rxn_name] = factor.calculate(dependent)
385
+ else:
386
+ dyn_stoich_by_compounds.setdefault(cpd_name, {})[
387
+ rxn_name
388
+ ] = factor
389
+ else:
390
+ d_static[rxn_name] = factor
371
391
 
372
392
  var_names = self.get_variable_names()
373
393
  dxdt = pd.Series(np.zeros(len(var_names), dtype=float), index=var_names)
374
394
 
395
+ initial_conditions: dict[str, float] = {
396
+ k: v for k, v in self._variables.items() if not isinstance(v, Derived)
397
+ }
398
+ for name in static_order:
399
+ if name in self._variables:
400
+ initial_conditions[name] = cast(float, dependent[name])
401
+ elif name in self._derived:
402
+ all_parameter_values[name] = cast(float, dependent[name])
403
+ else:
404
+ msg = "Unknown target for static derived variable."
405
+ raise KeyError(msg)
406
+
375
407
  self._cache = ModelCache(
376
408
  var_names=var_names,
377
- order=order,
409
+ dyn_order=dyn_order,
378
410
  all_parameter_values=all_parameter_values,
379
411
  stoich_by_cpds=stoich_by_compounds,
380
412
  dyn_stoich_by_cpds=dyn_stoich_by_compounds,
381
- derived_variable_names=derived_variable_names,
382
- derived_parameter_names=derived_parameter_names,
383
413
  dxdt=dxdt,
414
+ initial_conditions=initial_conditions,
384
415
  )
385
416
  return self._cache
386
417
 
@@ -657,12 +688,27 @@ class Model:
657
688
 
658
689
  return self
659
690
 
691
def get_unused_parameters(self) -> set[str]:
    """Get parameters which aren't used anywhere in the model.

    A parameter counts as used if its name appears in the ``args`` of any
    derived initial condition, derived value, readout, reaction or
    surrogate, or in the ``args`` of a ``Derived`` stoichiometric factor of
    a reaction or surrogate.

    Returns:
        set[str]: Names of parameters not referenced by any model element.

    """
    used: set[str] = set()
    for variable in self._variables.values():
        if isinstance(variable, Derived):
            used.update(variable.args)
    for derived in self._derived.values():
        used.update(derived.args)
    # Readouts are evaluated from model values too, so removing a parameter
    # they depend on would break them.
    for readout in self._readouts.values():
        used.update(readout.args)
    for reaction in self._reactions.values():
        used.update(reaction.args)
        # Dynamic stoichiometries may reference parameters that the rate
        # function itself does not use.
        for factor in reaction.stoichiometry.values():
            if isinstance(factor, Derived):
                used.update(factor.args)
    for surrogate in self._surrogates.values():
        used.update(surrogate.args)
        for stoich in surrogate.stoichiometries.values():
            for factor in stoich.values():
                if isinstance(factor, Derived):
                    used.update(factor.args)

    return set(self._parameters).difference(used)
705
+
660
706
  ##########################################################################
661
707
  # Variables
662
708
  ##########################################################################
663
709
 
664
710
  @property
665
- def variables(self) -> dict[str, float]:
711
+ def variables(self) -> dict[str, float | Derived]:
666
712
  """Returns a copy of the variables dictionary.
667
713
 
668
714
  Examples:
@@ -679,7 +725,7 @@ class Model:
679
725
  return self._variables.copy()
680
726
 
681
727
  @_invalidate_cache
682
- def add_variable(self, name: str, initial_condition: float) -> Self:
728
+ def add_variable(self, name: str, initial_condition: float | Derived) -> Self:
683
729
  """Adds a variable to the model with the given name and initial condition.
684
730
 
685
731
  Examples:
@@ -697,7 +743,7 @@ class Model:
697
743
  self._variables[name] = initial_condition
698
744
  return self
699
745
 
700
- def add_variables(self, variables: dict[str, float]) -> Self:
746
+ def add_variables(self, variables: Mapping[str, float | Derived]) -> Self:
701
747
  """Adds multiple variables to the model with their initial conditions.
702
748
 
703
749
  Examples:
@@ -751,7 +797,7 @@ class Model:
751
797
  return self
752
798
 
753
799
  @_invalidate_cache
754
- def update_variable(self, name: str, initial_condition: float) -> Self:
800
+ def update_variable(self, name: str, initial_condition: float | Derived) -> Self:
755
801
  """Updates the value of a variable in the model.
756
802
 
757
803
  Examples:
@@ -771,7 +817,7 @@ class Model:
771
817
  self._variables[name] = initial_condition
772
818
  return self
773
819
 
774
- def update_variables(self, variables: dict[str, float]) -> Self:
820
+ def update_variables(self, variables: Mapping[str, float | Derived]) -> Self:
775
821
  """Updates multiple variables in the model.
776
822
 
777
823
  Examples:
@@ -812,7 +858,9 @@ class Model:
812
858
  initial_conditions: A dictionary where the keys are variable names and the values are their initial conditions.
813
859
 
814
860
  """
815
- return self._variables
861
+ if (cache := self._cache) is None:
862
+ cache = self._create_cache()
863
+ return cache.initial_conditions
816
864
 
817
865
  def make_variable_static(self, name: str, value: float | None = None) -> Self:
818
866
  """Converts a variable to a static parameter.
@@ -833,9 +881,12 @@ class Model:
833
881
  Self: The instance of the class for method chaining.
834
882
 
835
883
  """
836
- value = self._variables[name] if value is None else value
884
+ value_or_derived = self._variables[name] if value is None else value
837
885
  self.remove_variable(name)
838
- self.add_parameter(name, value)
886
+ if isinstance(value_or_derived, Derived):
887
+ self.add_derived(name, value_or_derived.fn, args=value_or_derived.args)
888
+ else:
889
+ self.add_parameter(name, value_or_derived)
839
890
 
840
891
  # Remove from stoichiometries
841
892
  for reaction in self._reactions.values():
@@ -886,7 +937,8 @@ class Model:
886
937
  if (cache := self._cache) is None:
887
938
  cache = self._create_cache()
888
939
  derived = self._derived
889
- return {k: derived[k] for k in cache.derived_variable_names}
940
+
941
+ return {k: v for k, v in derived.items() if k not in cache.all_parameter_values}
890
942
 
891
943
  @property
892
944
  def derived_parameters(self) -> dict[str, Derived]:
@@ -905,7 +957,7 @@ class Model:
905
957
  if (cache := self._cache) is None:
906
958
  cache = self._create_cache()
907
959
  derived = self._derived
908
- return {k: derived[k] for k in cache.derived_parameter_names}
960
+ return {k: v for k, v in derived.items() if k in cache.all_parameter_values}
909
961
 
910
962
  @_invalidate_cache
911
963
  def add_derived(
@@ -1049,6 +1101,52 @@ class Model:
1049
1101
  )
1050
1102
  return pd.DataFrame(stoich_by_cpds).T.fillna(0)
1051
1103
 
1104
def get_stoichiometries_of_variable(
    self,
    variable: str,
    variables: dict[str, float] | None = None,
    time: float = 0.0,
) -> dict[str, float]:
    """Retrieve the stoichiometry of a specific variable.

    Static factors are taken from the model cache. ``Derived`` (dynamic)
    factors are evaluated at the given variable values and time.

    Examples:
        >>> model.get_stoichiometries_of_variable("x1")
        {"v1": -1, "v2": 1}

    Args:
        variable: The name of the variable for which to retrieve the stoichiometry.
        variables: A dictionary of variable names and their values.
        time: The time point at which to evaluate the stoichiometry.

    Returns:
        dict[str, float]: Mapping of reaction name to stoichiometric factor.

    Raises:
        KeyError: If ``variable`` does not appear in any reaction or
            surrogate stoichiometry.

    """
    if (cache := self._cache) is None:
        cache = self._create_cache()

    # Values are plain floats, so a shallow copy protects the cache.
    stoich = dict(cache.stoich_by_cpds[variable])

    dyn_stoich = cache.dyn_stoich_by_cpds.get(variable, {})
    if not dyn_stoich:
        # Nothing to evaluate: skip computing all dependent values.
        return stoich

    args = self.get_dependent(variables=variables, time=time)
    for rxn, derived in dyn_stoich.items():
        stoich[rxn] = float(derived.fn(*(args[i] for i in derived.args)))
    return stoich
1130
+
1131
def get_raw_stoichiometries_of_variable(
    self, variable: str
) -> dict[str, float | Derived]:
    """Retrieve the raw (unevaluated) stoichiometry of a specific variable.

    ``Derived`` factors are returned as-is rather than being evaluated.
    Both reactions and surrogates are considered.

    Examples:
        >>> model.get_raw_stoichiometries_of_variable("x1")
        {"v1": -1, "v2": Derived(...)}

    Args:
        variable: The name of the variable for which to retrieve the stoichiometry.

    Returns:
        dict[str, float | Derived]: Mapping of reaction name to raw factor.

    Raises:
        KeyError: If no reaction or surrogate has a stoichiometry for ``variable``.

    """
    stoich: dict[str, float | Derived] = {}
    for rxn_name, rxn in self._reactions.items():
        if variable in rxn.stoichiometry:
            stoich[rxn_name] = rxn.stoichiometry[variable]

    # Surrogates carry their own per-output stoichiometries; the cached,
    # evaluated stoichiometries include them as well.
    for surrogate in self._surrogates.values():
        for rxn_name, rxn in surrogate.stoichiometries.items():
            if variable in rxn:
                stoich[rxn_name] = rxn[variable]

    if not stoich:
        raise KeyError(variable)
    return stoich
1149
+
1052
1150
  @_invalidate_cache
1053
1151
  def add_reaction(
1054
1152
  self,
@@ -1250,7 +1348,7 @@ class Model:
1250
1348
  surrogate: AbstractSurrogate,
1251
1349
  args: list[str] | None = None,
1252
1350
  outputs: list[str] | None = None,
1253
- stoichiometries: dict[str, dict[str, float]] | None = None,
1351
+ stoichiometries: dict[str, dict[str, float | Derived]] | None = None,
1254
1352
  ) -> Self:
1255
1353
  """Adds a surrogate model to the current instance.
1256
1354
 
@@ -1285,7 +1383,7 @@ class Model:
1285
1383
  name: str,
1286
1384
  surrogate: AbstractSurrogate | None = None,
1287
1385
  args: list[str] | None = None,
1288
- stoichiometries: dict[str, dict[str, float]] | None = None,
1386
+ stoichiometries: dict[str, dict[str, float | Derived]] | None = None,
1289
1387
  ) -> Self:
1290
1388
  """Update a surrogate model in the model.
1291
1389
 
@@ -1337,6 +1435,27 @@ class Model:
1337
1435
  names.extend(i.stoichiometries)
1338
1436
  return names
1339
1437
 
1438
+ ##########################################################################
1439
+ # Datasets
1440
+ ##########################################################################
1441
+
1442
@_invalidate_cache
def add_data(self, name: str, data: pd.Series | pd.DataFrame) -> Self:
    """Add a named data set to the model.

    The name is registered via ``_insert_id``. The data set is then available
    as an argument to derived values, reactions and surrogates.

    The cache is invalidated because its dependency sort and its static
    values read the model's data sets.

    Args:
        name: Name under which the data set is stored.
        data: The data set.

    Returns:
        Self: The instance of the class for method chaining.

    """
    self._insert_id(name=name, ctx="data")
    self._data[name] = data
    return self
1447
+
1448
@_invalidate_cache
def update_data(self, name: str, data: pd.Series | pd.DataFrame) -> Self:
    """Update an existing named data set.

    The cache is invalidated because cached static values (e.g. derived
    initial conditions) may have been computed from the old data.

    Args:
        name: Name of a data set previously added with ``add_data``.
        data: The replacement data set.

    Returns:
        Self: The instance of the class for method chaining.

    Raises:
        KeyError: If no data set called ``name`` exists.

    """
    # Refuse unknown names so ids are only ever registered via add_data.
    if name not in self._data:
        msg = f"'{name}' not found in data"
        raise KeyError(msg)
    self._data[name] = data
    return self
1452
+
1453
@_invalidate_cache
def remove_data(self, name: str) -> Self:
    """Remove a data set from the model.

    Args:
        name: Name of the data set to remove.

    Returns:
        Self: The instance of the class for method chaining.

    Raises:
        KeyError: If no data set called ``name`` exists.

    """
    # Pop first: if ``name`` is not a data set, this raises before the
    # id of some other model element with that name is unregistered.
    self._data.pop(name)
    self._remove_id(name=name)
    return self
1458
+
1340
1459
  ##########################################################################
1341
1460
  # Get dependent values. This includes
1342
1461
  # - derived parameters
@@ -1371,14 +1490,17 @@ class Model:
1371
1490
  with their respective names as keys and their calculated values as values.
1372
1491
 
1373
1492
  """
1374
- args: dict[str, float] = cache.all_parameter_values | variables
1493
+ args = cache.all_parameter_values | variables | self._data
1375
1494
  args["time"] = time
1376
1495
 
1377
1496
  containers = self._derived | self._reactions | self._surrogates
1378
- for name in cache.order:
1497
+ for name in cache.dyn_order:
1379
1498
  containers[name].calculate_inpl(name, args)
1380
1499
 
1381
- return args
1500
+ for k in self._data:
1501
+ args.pop(k)
1502
+
1503
+ return cast(dict[str, float], args)
1382
1504
 
1383
1505
  def get_dependent(
1384
1506
  self,
@@ -1454,29 +1576,16 @@ class Model:
1454
1576
  derived variables, and optionally readout variables, with time as an additional column.
1455
1577
 
1456
1578
  """
1457
- if (cache := self._cache) is None:
1458
- cache = self._create_cache()
1459
-
1460
- pars_df = pd.DataFrame(
1461
- np.full(
1462
- (len(variables), len(cache.all_parameter_values)),
1463
- np.fromiter(cache.all_parameter_values.values(), dtype=float),
1464
- ),
1465
- index=variables.index,
1466
- columns=list(cache.all_parameter_values),
1467
- )
1468
-
1469
- args = pd.concat((variables, pars_df), axis=1)
1470
- args["time"] = args.index
1471
-
1472
- containers = self._derived | self._reactions | self._surrogates
1473
- for name in cache.order:
1474
- containers[name].calculate_inpl_time_course(name, args)
1579
+ args = {
1580
+ time: self.get_dependent(
1581
+ variables=values.to_dict(),
1582
+ time=cast(float, time),
1583
+ include_readouts=include_readouts,
1584
+ )
1585
+ for time, values in variables.iterrows()
1586
+ }
1475
1587
 
1476
- if include_readouts:
1477
- for name, ro in self._readouts.items():
1478
- args[name] = ro.fn(*args.loc[:, ro.args].to_numpy().T)
1479
- return args
1588
+ return pd.DataFrame(args, dtype=float).T
1480
1589
 
1481
1590
  ##########################################################################
1482
1591
  # Get args
@@ -1569,28 +1678,6 @@ class Model:
1569
1678
  # Get fluxes
1570
1679
  ##########################################################################
1571
1680
 
1572
- def _get_fluxes(self, args: dict[str, float]) -> dict[str, float]:
1573
- """Calculate the fluxes for the given arguments.
1574
-
1575
- Examples:
1576
- >>> model._get_fluxes({"x1": 1.0, "x2": 2.0, "k1": 0.1, "time": 0.0})
1577
- {"r1": 0.1, "r2": 0.2}
1578
-
1579
- Args:
1580
- args (dict[str, float]): A dictionary where the keys are argument names and the values are their corresponding float values.
1581
-
1582
- Returns:
1583
- dict[str, float]: A dictionary where the keys are reaction names and the values are the calculated fluxes.
1584
-
1585
- """
1586
- fluxes: dict[str, float] = {}
1587
- for name, rxn in self._reactions.items():
1588
- fluxes[name] = cast(float, rxn.fn(*(args[arg] for arg in rxn.args)))
1589
-
1590
- for surrogate in self._surrogates.values():
1591
- fluxes |= surrogate.predict(np.array([args[arg] for arg in surrogate.args]))
1592
- return fluxes
1593
-
1594
1681
  def get_fluxes(
1595
1682
  self,
1596
1683
  variables: dict[str, float] | None = None,
mxlpy/nn/__init__.py CHANGED
@@ -1,10 +1,29 @@
1
1
  """Collection of neural network architectures."""
2
2
 
3
- import contextlib
3
+ from __future__ import annotations
4
4
 
5
- __all__ = ["tensorflow", "torch"]
5
+ from typing import TYPE_CHECKING
6
6
 
7
- with contextlib.suppress(ImportError):
8
- from . import _torch as torch
7
+ if TYPE_CHECKING:
8
+ import contextlib
9
9
 
10
- from . import _tensorflow as tensorflow
10
+ with contextlib.suppress(ImportError):
11
+ from . import _keras as keras
12
+ from . import _torch as torch
13
+ else:
14
+ from lazy_import import lazy_module
15
+
16
+ keras = lazy_module(
17
+ "mxlpy.nn._keras",
18
+ error_strings={"module": "keras", "install_name": "mxlpy[tf]"},
19
+ )
20
+ torch = lazy_module(
21
+ "mxlpy.nn._torch",
22
+ error_strings={"module": "torch", "install_name": "mxlpy[torch]"},
23
+ )
24
+
25
+
26
+ __all__ = [
27
+ "keras",
28
+ "torch",
29
+ ]
mxlpy/nn/_keras.py ADDED
@@ -0,0 +1,92 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import TYPE_CHECKING, cast
4
+
5
+ import keras
6
+ import pandas as pd
7
+ from tqdm.keras import TqdmCallback
8
+
9
+ if TYPE_CHECKING:
10
+ from mxlpy.types import Array
11
+
12
+ __all__ = [
13
+ "LSTM",
14
+ "MLP",
15
+ "train",
16
+ ]
17
+
18
+
19
def train(
    model: keras.Model,
    features: pd.DataFrame | Array,
    targets: pd.DataFrame | Array,
    epochs: int,
    batch_size: int | None,
) -> pd.Series:
    """Train a Keras model on features and targets using ``model.fit``.

    No optimizer or loss is passed here, so the model is expected to have
    been compiled beforehand. Keras' own logging is silenced and progress is
    reported through a tqdm callback instead.

    Args:
        model: Compiled Keras model to train.
        features: Input features (samples along the first axis).
        targets: Target values (samples along the first axis).
        epochs: Number of training epochs.
        batch_size: Size of mini-batches; ``None`` uses the Keras default.

    Returns:
        pd.Series: Training loss for each epoch.

    """
    history = model.fit(
        features,
        targets,
        batch_size=batch_size,
        epochs=epochs,
        # cast only satisfies the type checker; 0 silences Keras' progress bar
        verbose=cast(str, 0),
        callbacks=[TqdmCallback()],
    )
    return pd.Series(history.history["loss"])
51
+
52
+
53
def MLP(  # noqa: N802
    n_inputs: int,
    neurons_per_layer: list[int],
    activation: str | None = None,
    output_activation: str | None = None,
) -> keras.Sequential:
    """Build a multilayer perceptron (MLP) as a Keras ``Sequential`` model.

    Intended for surrogate modeling and neural posterior estimation.

    Args:
        n_inputs: Number of input features.
        neurons_per_layer: Units of each ``Dense`` layer. All entries but the
            last are hidden layers; the last entry is the output size.
        activation: Activation of every hidden layer (``None``: linear).
        output_activation: Activation of the output layer (``None``: linear).

    Returns:
        keras.Sequential: The (uncompiled) model.

    Raises:
        ValueError: If ``neurons_per_layer`` is empty.

    """
    if not neurons_per_layer:
        msg = "neurons_per_layer must contain at least one layer size"
        raise ValueError(msg)

    *hidden, n_outputs = neurons_per_layer
    model = keras.Sequential([keras.Input(shape=(n_inputs,))])
    for neurons in hidden:
        model.add(keras.layers.Dense(neurons, activation=activation))
    model.add(keras.layers.Dense(n_outputs, activation=output_activation))
    return model
70
+
71
+
72
def LSTM(  # noqa: N802
    n_inputs: int,
    n_outputs: int,
    n_hidden: int,
) -> keras.Sequential:
    """Build a long short-term memory (LSTM) network for time series modeling.

    The model consumes sequences of shape ``(batch, timesteps, n_inputs)``
    with an arbitrary number of timesteps. The LSTM's final output is mapped
    to ``n_outputs`` values by a linear ``Dense`` layer.

    Args:
        n_inputs: Number of features per timestep.
        n_outputs: Number of output values.
        n_hidden: Number of LSTM units.

    Returns:
        keras.Sequential: The (uncompiled) model.

    """
    model = keras.Sequential(
        [
            # The shape must be a tuple. LSTM layers need 3-D input
            # (batch, timesteps, features); ``None`` leaves the sequence
            # length open.
            keras.Input(shape=(None, n_inputs)),
        ]
    )
    model.add(keras.layers.LSTM(n_hidden))
    model.add(keras.layers.Dense(n_outputs))
    return model