mxlpy-0.8.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48)
  1. mxlpy/__init__.py +165 -0
  2. mxlpy/distributions.py +339 -0
  3. mxlpy/experimental/__init__.py +12 -0
  4. mxlpy/experimental/diff.py +226 -0
  5. mxlpy/fit.py +291 -0
  6. mxlpy/fns.py +191 -0
  7. mxlpy/integrators/__init__.py +19 -0
  8. mxlpy/integrators/int_assimulo.py +146 -0
  9. mxlpy/integrators/int_scipy.py +146 -0
  10. mxlpy/label_map.py +610 -0
  11. mxlpy/linear_label_map.py +303 -0
  12. mxlpy/mc.py +548 -0
  13. mxlpy/mca.py +280 -0
  14. mxlpy/meta/__init__.py +11 -0
  15. mxlpy/meta/codegen_latex.py +516 -0
  16. mxlpy/meta/codegen_modebase.py +110 -0
  17. mxlpy/meta/codegen_py.py +107 -0
  18. mxlpy/meta/source_tools.py +320 -0
  19. mxlpy/model.py +1737 -0
  20. mxlpy/nn/__init__.py +10 -0
  21. mxlpy/nn/_tensorflow.py +0 -0
  22. mxlpy/nn/_torch.py +129 -0
  23. mxlpy/npe.py +277 -0
  24. mxlpy/parallel.py +171 -0
  25. mxlpy/parameterise.py +27 -0
  26. mxlpy/paths.py +36 -0
  27. mxlpy/plot.py +875 -0
  28. mxlpy/py.typed +0 -0
  29. mxlpy/sbml/__init__.py +14 -0
  30. mxlpy/sbml/_data.py +77 -0
  31. mxlpy/sbml/_export.py +644 -0
  32. mxlpy/sbml/_import.py +599 -0
  33. mxlpy/sbml/_mathml.py +691 -0
  34. mxlpy/sbml/_name_conversion.py +52 -0
  35. mxlpy/sbml/_unit_conversion.py +74 -0
  36. mxlpy/scan.py +629 -0
  37. mxlpy/simulator.py +655 -0
  38. mxlpy/surrogates/__init__.py +31 -0
  39. mxlpy/surrogates/_poly.py +97 -0
  40. mxlpy/surrogates/_torch.py +196 -0
  41. mxlpy/symbolic/__init__.py +10 -0
  42. mxlpy/symbolic/strikepy.py +582 -0
  43. mxlpy/symbolic/symbolic_model.py +75 -0
  44. mxlpy/types.py +474 -0
  45. mxlpy-0.8.0.dist-info/METADATA +106 -0
  46. mxlpy-0.8.0.dist-info/RECORD +48 -0
  47. mxlpy-0.8.0.dist-info/WHEEL +4 -0
  48. mxlpy-0.8.0.dist-info/licenses/LICENSE +674 -0
mxlpy/simulator.py ADDED
@@ -0,0 +1,655 @@
+ """Simulation Module.
+
+ This module provides classes and functions for simulating metabolic models.
+ It includes functionality for running simulations, normalizing results, and
+ retrieving simulation data.
+
+ Classes:
+     Result: Container for simulation results.
+     Simulator: Class for running simulations on a metabolic model.
+ """
+
+ from __future__ import annotations
+
+ from dataclasses import dataclass, field
+ from typing import TYPE_CHECKING, Literal, Self, cast, overload
+
+ import numpy as np
+ import pandas as pd
+
+ from mxlpy.integrators import DefaultIntegrator
+
+ __all__ = ["Result", "Simulator"]
+
+ if TYPE_CHECKING:
+     from collections.abc import Callable, Iterator
+
+     from mxlpy.model import Model
+     from mxlpy.types import Array, ArrayLike, IntegratorProtocol
+
+
+ def _normalise_split_results(
+     results: list[pd.DataFrame],
+     normalise: float | ArrayLike,
+ ) -> list[pd.DataFrame]:
+     """Normalize split results by a given factor or array.
+
+     Args:
+         results: List of DataFrames containing the results to normalize.
+         normalise: Normalization factor or array.
+
+     Returns:
+         list[pd.DataFrame]: List of normalized DataFrames.
+
+     """
+     if isinstance(normalise, int | float):
+         return [i / normalise for i in results]
+     if len(normalise) == len(results):
+         return [(i.T / j).T for i, j in zip(results, normalise, strict=True)]
+
+     # Otherwise `normalise` is a flat array spanning the concatenated results:
+     # normalise each DataFrame by its own contiguous slice of that array.
+     normalised = []
+     start = 0
+     end = 0
+     for i in results:
+         end += len(i)
+         normalised.append(i / np.reshape(normalise[start:end], (len(i), 1)))  # type: ignore
+         start = end
+     return normalised
+
+
+ @dataclass(kw_only=True, slots=True)
+ class Result:
+     """Simulation results."""
+
+     model: Model
+     _raw_variables: list[pd.DataFrame]
+     _parameters: list[dict[str, float]]
+     _dependent: list[pd.DataFrame] = field(default_factory=list)
+
+     @property
+     def variables(self) -> pd.DataFrame:
+         """Simulation variables."""
+         return self.get_variables(
+             include_derived=True,
+             include_readouts=True,
+             concatenated=True,
+             normalise=None,
+         )
+
+     @property
+     def fluxes(self) -> pd.DataFrame:
+         """Simulation fluxes."""
+         return self.get_fluxes()
+
+     def __iter__(self) -> Iterator[pd.DataFrame]:
+         """Iterate over the variables and fluxes."""
+         return iter((self.variables, self.fluxes))
+
+     def _get_dependent(
+         self,
+         *,
+         include_readouts: bool = True,
+     ) -> list[pd.DataFrame]:
+         # Already computed
+         if len(self._dependent) > 0:
+             return self._dependent
+
+         # Compute new otherwise
+         for res, p in zip(self._raw_variables, self._parameters, strict=True):
+             self.model.update_parameters(p)
+             self._dependent.append(
+                 self.model.get_dependent_time_course(
+                     variables=res,
+                     include_readouts=include_readouts,
+                 )
+             )
+         return self._dependent
+
+     def _select_variables(
+         self,
+         dependent: list[pd.DataFrame],
+         *,
+         include_derived: bool,
+         include_readouts: bool,
+     ) -> list[pd.DataFrame]:
+         names = self.model.get_variable_names()
+         if include_derived:
+             names.extend(self.model.get_derived_variable_names())
+         if include_readouts:
+             names.extend(self.model.get_readout_names())
+         return [i.loc[:, names] for i in dependent]
+
+     def _select_fluxes(
+         self,
+         dependent: list[pd.DataFrame],
+         *,
+         include_surrogates: bool,
+     ) -> list[pd.DataFrame]:
+         names = self.model.get_reaction_names()
+         if include_surrogates:
+             names.extend(self.model.get_surrogate_reaction_names())
+         return [i.loc[:, names] for i in dependent]
+
+     def _adjust_data(
+         self,
+         data: list[pd.DataFrame],
+         normalise: float | ArrayLike | None = None,
+         *,
+         concatenated: bool = True,
+     ) -> pd.DataFrame | list[pd.DataFrame]:
+         if normalise is not None:
+             data = _normalise_split_results(data, normalise=normalise)
+         if concatenated:
+             return pd.concat(data, axis=0)
+         return data
+
+     @overload
+     def get_variables(  # type: ignore
+         self,
+         *,
+         include_derived: bool = True,
+         include_readouts: bool = True,
+         concatenated: Literal[False],
+         normalise: float | ArrayLike | None = None,
+     ) -> list[pd.DataFrame]: ...
+
+     @overload
+     def get_variables(
+         self,
+         *,
+         include_derived: bool = True,
+         include_readouts: bool = True,
+         concatenated: Literal[True],
+         normalise: float | ArrayLike | None = None,
+     ) -> pd.DataFrame: ...
+
+     @overload
+     def get_variables(
+         self,
+         *,
+         include_derived: bool = True,
+         include_readouts: bool = True,
+         concatenated: bool = True,
+         normalise: float | ArrayLike | None = None,
+     ) -> pd.DataFrame: ...
+
+     def get_variables(
+         self,
+         *,
+         include_derived: bool = True,
+         include_readouts: bool = True,
+         concatenated: bool = True,
+         normalise: float | ArrayLike | None = None,
+     ) -> pd.DataFrame | list[pd.DataFrame]:
+         """Get the variables over time.
+
+         Examples:
+             >>> Result().get_variables()
+             Time ATP NADPH
+             0.000000 1.000000 1.000000
+             0.000100 0.999900 0.999900
+             0.000200 0.999800 0.999800
+
+         """
+         if not include_derived and not include_readouts:
+             return self._adjust_data(
+                 self._raw_variables,
+                 normalise=normalise,
+                 concatenated=concatenated,
+             )
+
+         variables = self._select_variables(
+             self._get_dependent(),
+             include_derived=include_derived,
+             include_readouts=include_readouts,
+         )
+         return self._adjust_data(
+             variables, normalise=normalise, concatenated=concatenated
+         )
+
+     @overload
+     def get_fluxes(  # type: ignore
+         self,
+         *,
+         include_surrogates: bool = True,
+         normalise: float | ArrayLike | None = None,
+         concatenated: Literal[False],
+     ) -> list[pd.DataFrame]: ...
+
+     @overload
+     def get_fluxes(
+         self,
+         *,
+         include_surrogates: bool = True,
+         normalise: float | ArrayLike | None = None,
+         concatenated: Literal[True],
+     ) -> pd.DataFrame: ...
+
+     @overload
+     def get_fluxes(
+         self,
+         *,
+         include_surrogates: bool = True,
+         normalise: float | ArrayLike | None = None,
+         concatenated: bool = True,
+     ) -> pd.DataFrame: ...
+
+     def get_fluxes(
+         self,
+         *,
+         include_surrogates: bool = True,
+         normalise: float | ArrayLike | None = None,
+         concatenated: bool = True,
+     ) -> pd.DataFrame | list[pd.DataFrame]:
+         """Get the flux results.
+
+         Examples:
+             >>> Result.get_fluxes()
+             Time v1 v2
+             0.000000 1.000000 10.00000
+             0.000100 0.999900 9.999000
+             0.000200 0.999800 9.998000
+
+         Returns:
+             pd.DataFrame: DataFrame of fluxes.
+
+         """
+         fluxes = self._select_fluxes(
+             self._get_dependent(),
+             include_surrogates=include_surrogates,
+         )
+         return self._adjust_data(
+             fluxes,
+             normalise=normalise,
+             concatenated=concatenated,
+         )
+
+     def get_combined(self) -> pd.DataFrame:
+         """Get the variables and fluxes as a single pandas.DataFrame.
+
+         Examples:
+             >>> Result.get_combined()
+             Time ATP NADPH v1 v2
+             0.000000 1.000000 1.000000 1.000000 10.00000
+             0.000100 0.999900 0.999900 0.999900 9.999000
+             0.000200 0.999800 0.999800 0.999800 9.998000
+
+         Returns:
+             pd.DataFrame: Combined DataFrame of variables and fluxes.
+
+         """
+         return pd.concat((self.variables, self.fluxes), axis=1)
+
+     def get_new_y0(self) -> dict[str, float]:
+         """Get the new initial conditions after the simulation.
+
+         Examples:
+             >>> Simulator(model).simulate_to_steady_state().get_result().get_new_y0()
+             {"ATP": 1.0, "NADPH": 1.0}
+
+         """
+         return dict(
+             self.get_variables(
+                 include_derived=False,
+                 include_readouts=False,
+             ).iloc[-1]
+         )
+
+
+ @dataclass(
+     init=False,
+     slots=True,
+     eq=False,
+ )
+ class Simulator:
+     """Simulator class for running simulations on a metabolic model.
+
+     Attributes:
+         model: Model instance to simulate.
+         y0: Initial conditions for the simulation.
+         integrator: Integrator protocol to use for the simulation.
+         variables: List of DataFrames containing concentration results.
+         dependent: List of DataFrames containing dependent (derived) values.
+         simulation_parameters: List of dictionaries containing simulation parameters.
+
+     """
+
+     model: Model
+     y0: dict[str, float]
+     integrator: IntegratorProtocol
+     variables: list[pd.DataFrame] | None
+     dependent: list[pd.DataFrame] | None
+     simulation_parameters: list[dict[str, float]] | None
+
+     # For resets (e.g. update variable)
+     _integrator_type: Callable[[Callable, ArrayLike], IntegratorProtocol]
+     _time_shift: float | None
+
+     def __init__(
+         self,
+         model: Model,
+         y0: dict[str, float] | None = None,
+         integrator: Callable[
+             [Callable, ArrayLike], IntegratorProtocol
+         ] = DefaultIntegrator,
+         *,
+         test_run: bool = True,
+     ) -> None:
+         """Initialize the Simulator.
+
+         Args:
+             model: The model to be simulated.
+             y0: Initial conditions for the model variables.
+                 If None, the initial conditions are obtained from the model.
+             integrator: The integrator to use for the simulation.
+             test_run: If True, performs a test run of the right-hand side
+                 to give better error messages.
+
+         """
+         self.model = model
+         self.y0 = model.get_initial_conditions() if y0 is None else y0
+
+         self._integrator_type = integrator
+         self._time_shift = None
+         self.variables = None
+         self.dependent = None
+         self.simulation_parameters = None
+
+         if test_run:
+             self.model.get_right_hand_side(self.y0, time=0)
+
+         self._initialise_integrator()
+
+     def _initialise_integrator(self) -> None:
+         y0 = self.y0
+         self.integrator = self._integrator_type(
+             self.model,
+             [y0[k] for k in self.model.get_variable_names()],
+         )
+
+     def clear_results(self) -> None:
+         """Clear simulation results."""
+         self.variables = None
+         self.dependent = None
+         self.simulation_parameters = None
+         self._time_shift = None
+         self._initialise_integrator()
+
+     def _handle_simulation_results(
+         self,
+         time: Array | None,
+         results: ArrayLike | None,
+         *,
+         skipfirst: bool,
+     ) -> None:
+         """Handle simulation results.
+
+         Args:
+             time: Array of time points for the simulation.
+             results: Array of results for the simulation.
+             skipfirst: Whether to skip the first row of results.
+
+         """
+         if time is None or results is None:
+             # Need to clear results in case continued integration fails
+             # to keep expectation that failure = None
+             self.clear_results()
+             return
+
+         if self._time_shift is not None:
+             time += self._time_shift
+
+         # NOTE: IMPORTANT!
+         # model._get_rhs sorts the return array by model.get_variable_names()
+         # Do NOT change this ordering
+         results_df = pd.DataFrame(
+             results,
+             index=time,
+             columns=self.model.get_variable_names(),
+         )
+
+         if self.variables is None:
+             self.variables = [results_df]
+         elif skipfirst:
+             self.variables.append(results_df.iloc[1:, :])
+         else:
+             self.variables.append(results_df)
+
+         if self.simulation_parameters is None:
+             self.simulation_parameters = []
+         self.simulation_parameters.append(self.model.parameters)
+
+     def simulate(
+         self,
+         t_end: float,
+         steps: int | None = None,
+     ) -> Self:
+         """Simulate the model.
+
+         Examples:
+             >>> s.simulate(t_end=100)
+             >>> s.simulate(t_end=100, steps=100)
+
+         You can either supply only a terminal time point, or additionally also the
+         number of steps for which values should be returned.
+
+         Args:
+             t_end: Terminal time point for the simulation.
+             steps: Number of steps for the simulation.
+
+         Returns:
+             Self: The Simulator instance with updated results.
+
+         """
+         if self._time_shift is not None:
+             t_end -= self._time_shift
+
+         time, results = self.integrator.integrate(t_end=t_end, steps=steps)
+
+         self._handle_simulation_results(time, results, skipfirst=True)
+         return self
+
+     def simulate_time_course(self, time_points: ArrayLike) -> Self:
+         """Simulate the model over a given set of time points.
+
+         Examples:
+             >>> Simulator(model).simulate_time_course([1, 2, 3])
+
+         Values are returned exactly at the supplied time points.
+
+         Args:
+             time_points: Exact time points for which values should be returned.
+
+         Returns:
+             Self: The Simulator instance with updated results.
+
+         """
+         if self._time_shift is not None:
+             time_points = np.array(time_points, dtype=float)
+             time_points -= self._time_shift
+
+         time, results = self.integrator.integrate_time_course(time_points=time_points)
+         self._handle_simulation_results(time, results, skipfirst=True)
+         return self
+
+     def simulate_to_steady_state(
+         self,
+         tolerance: float = 1e-6,
+         *,
+         rel_norm: bool = False,
+     ) -> Self:
+         """Simulate the model to steady state.
+
+         Examples:
+             >>> Simulator(model).simulate_to_steady_state()
+             >>> Simulator(model).simulate_to_steady_state(tolerance=1e-8)
+             >>> Simulator(model).simulate_to_steady_state(rel_norm=True)
+
+         Args:
+             tolerance: Tolerance for the steady-state calculation.
+             rel_norm: Whether to use relative norm for the steady-state calculation.
+
+         Returns:
+             Self: The Simulator instance with updated results.
+
+         """
+         time, results = self.integrator.integrate_to_steady_state(
+             tolerance=tolerance,
+             rel_norm=rel_norm,
+         )
+         self._handle_simulation_results(
+             np.array([time], dtype=float) if time is not None else None,
+             [results] if results is not None else None,  # type: ignore
+             skipfirst=False,
+         )
+         return self
+
+     def simulate_over_protocol(
+         self,
+         protocol: pd.DataFrame,
+         time_points_per_step: int = 10,
+     ) -> Self:
+         """Simulate the model over a given protocol.
+
+         Examples:
+             >>> Simulator(model).simulate_over_protocol(
+             ...     protocol,
+             ...     time_points_per_step=10
+             ... )
+
+         Args:
+             protocol: DataFrame describing the protocol, indexed by pd.Timedelta
+                 giving the end time of each step, with one column per parameter
+                 to update.
+             time_points_per_step: Number of time points per step.
+
+         Returns:
+             The Simulator instance with updated results.
+
+         """
+         for t_end, pars in protocol.iterrows():
+             t_end = cast(pd.Timedelta, t_end)
+             self.model.update_parameters(pars.to_dict())
+             self.simulate(t_end.total_seconds(), steps=time_points_per_step)
+             if self.variables is None:
+                 break
+         return self
+
+     def get_result(self) -> Result | None:
+         """Get result of the simulation.
+
+         Examples:
+             >>> variables, fluxes = Simulator(model).simulate(t_end=100).get_result()
+             >>> variables
+             Time ATP NADPH
+             0.000000 1.000000 1.000000
+             0.000100 0.999900 0.999900
+             0.000200 0.999800 0.999800
+             >>> fluxes
+             Time v1 v2
+             0.000000 1.000000 10.00000
+             0.000100 0.999900 9.999000
+             0.000200 0.999800 9.998000
+
+         """
+         if (variables := self.variables) is None:
+             return None
+         if (parameters := self.simulation_parameters) is None:
+             return None
+         return Result(
+             model=self.model,
+             _raw_variables=variables,
+             _parameters=parameters,
+         )
+
+     def update_parameter(self, parameter: str, value: float) -> Self:
+         """Update the value of a specified parameter in the model.
+
+         Examples:
+             >>> Simulator(model).update_parameter("k1", 0.1)
+
+         Args:
+             parameter: The name of the parameter to update.
+             value: The new value to set for the parameter.
+
+         """
+         self.model.update_parameter(parameter, value)
+         return self
+
+     def update_parameters(self, parameters: dict[str, float]) -> Self:
+         """Update the model parameters with the provided dictionary of parameters.
+
+         Examples:
+             >>> Simulator(model).update_parameters({"k1": 0.1, "k2": 0.2})
+
+         Args:
+             parameters: A dictionary where the keys are parameter names
+                 and the values are the new parameter values.
+
+         """
+         self.model.update_parameters(parameters)
+         return self
+
+     def scale_parameter(self, parameter: str, factor: float) -> Self:
+         """Scale the value of a specified parameter in the model.
+
+         Examples:
+             >>> Simulator(model).scale_parameter("k1", 0.1)
+
+         Args:
+             parameter: The name of the parameter to scale.
+             factor: The factor by which to scale the parameter.
+
+         """
+         self.model.scale_parameter(parameter, factor)
+         return self
+
+     def scale_parameters(self, parameters: dict[str, float]) -> Self:
+         """Scale the values of specified parameters in the model.
+
+         Examples:
+             >>> Simulator(model).scale_parameters({"k1": 0.1, "k2": 0.2})
+
+         Args:
+             parameters: A dictionary where the keys are parameter names
+                 and the values are the scaling factors.
+
+         """
+         self.model.scale_parameters(parameters)
+         return self
+
+     def update_variable(self, variable: str, value: float) -> Self:
+         """Update the value of a specified variable in the simulation.
+
+         Examples:
+             >>> Simulator(model).update_variable("ATP", 1.0)
+
+         Args:
+             variable: name of the model variable
+             value: new value
+
+         """
+         return self.update_variables({variable: value})
+
+     def update_variables(self, variables: dict[str, float]) -> Self:
+         """Update the values of the specified variables in the simulation.
+
+         Examples:
+             >>> Simulator(model).update_variables({"ATP": 1.0})
+
+         Args:
+             variables: {variable: value} pairs
+
+         """
+         sim_variables = self.variables
+
+         # In case someone calls this before the first simulation
+         if sim_variables is None:
+             self.y0 = self.y0 | variables
+             return self
+
+         # Otherwise restart the integrator from the last simulated state,
+         # shifting time so the next simulation continues at the last time point.
+         self.y0 = sim_variables[-1].iloc[-1, :].to_dict() | variables
+         self._time_shift = float(sim_variables[-1].index[-1])
+         self._initialise_integrator()
+         return self
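
The API above chains naturally. A minimal usage sketch, assuming a hypothetical `model: Model` instance that defines a parameter "k1" (the model and parameter names are illustrative, not part of the package):

import pandas as pd

from mxlpy.simulator import Simulator

# Plain time course; get_result() returns None if integration failed.
res = Simulator(model).simulate(t_end=100.0, steps=100).get_result()
if res is not None:
    variables, fluxes = res  # Result.__iter__ yields (variables, fluxes)

# Protocol simulation: the index gives the end time of each step as a
# pd.Timedelta, the columns give the parameter values to apply per step.
protocol = pd.DataFrame(
    {"k1": [0.1, 0.2]},
    index=[pd.Timedelta(seconds=10), pd.Timedelta(seconds=20)],
)
res = (
    Simulator(model)
    .simulate_over_protocol(protocol, time_points_per_step=10)
    .get_result()
)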
mxlpy/surrogates/__init__.py ADDED
@@ -0,0 +1,31 @@
+ """Surrogate Models Module.
+
+ This module provides classes and functions for creating and training surrogate models
+ for metabolic simulations. It includes functionality for both steady-state and time-series
+ data using neural networks.
+
+ Classes:
+     AbstractSurrogate: Abstract base class for surrogate models.
+     TorchSurrogate: Surrogate model using PyTorch.
+     Approximator: Neural network approximator for surrogate modeling.
+
+ Functions:
+     train_torch_surrogate: Train a PyTorch surrogate model.
+     train_torch_time_course_estimator: Train a PyTorch time course estimator.
+ """
+
+ from __future__ import annotations
+
+ import contextlib
+
+ with contextlib.suppress(ImportError):
+     from ._torch import TorchSurrogate, train_torch_surrogate
+
+ from ._poly import PolySurrogate, train_polynomial_surrogate
+
+ __all__ = [
+     "PolySurrogate",
+     "TorchSurrogate",
+     "train_polynomial_surrogate",
+     "train_torch_surrogate",
+ ]
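
Because the `_torch` import is wrapped in `contextlib.suppress(ImportError)`, the torch-backed names are only bound when PyTorch is installed. A minimal sketch for probing availability (illustrative usage, not part of the package):

import mxlpy.surrogates as surrogates

# TorchSurrogate only exists in the namespace if the optional torch import succeeded.
if hasattr(surrogates, "TorchSurrogate"):
    print("PyTorch surrogate backend available")
else:
    print("Falling back to PolySurrogate, which has no optional dependency")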