mxlpy 0.17.0__py3-none-any.whl → 0.19.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
mxlpy/model.py CHANGED
@@ -250,17 +250,17 @@ class ModelCache:
250
250
  stoich_by_cpds: A dictionary mapping compound names to their stoichiometric coefficients.
251
251
  dyn_stoich_by_cpds: A dictionary mapping compound names to their dynamic stoichiometric coefficients.
252
252
  dxdt: A pandas Series representing the rate of change of variables.
253
+ initial_conditions: calculated initial conditions
253
254
 
254
255
  """
255
256
 
256
257
  var_names: list[str]
257
- order: list[str]
258
+ dyn_order: list[str]
258
259
  all_parameter_values: dict[str, float]
259
- derived_parameter_names: list[str]
260
- derived_variable_names: list[str]
261
260
  stoich_by_cpds: dict[str, dict[str, float]]
262
261
  dyn_stoich_by_cpds: dict[str, dict[str, Derived]]
263
262
  dxdt: pd.Series
263
+ initial_conditions: dict[str, float]
264
264
 
265
265
 
266
266
  @dataclass(slots=True)
@@ -276,17 +276,19 @@ class Model:
276
276
  _reactions: Dictionary of reactions in the model.
277
277
  _surrogates: Dictionary of surrogate models.
278
278
  _cache: Cache for storing model-related data structures.
279
+ _data: Named references to data sets
279
280
 
280
281
  """
281
282
 
282
283
  _ids: dict[str, str] = field(default_factory=dict)
283
- _variables: dict[str, float] = field(default_factory=dict)
284
+ _variables: dict[str, float | Derived] = field(default_factory=dict)
284
285
  _parameters: dict[str, float] = field(default_factory=dict)
285
286
  _derived: dict[str, Derived] = field(default_factory=dict)
286
287
  _readouts: dict[str, Readout] = field(default_factory=dict)
287
288
  _reactions: dict[str, Reaction] = field(default_factory=dict)
288
289
  _surrogates: dict[str, AbstractSurrogate] = field(default_factory=dict)
289
290
  _cache: ModelCache | None = None
291
+ _data: dict[str, pd.Series | pd.DataFrame] = field(default_factory=dict)
290
292
 
291
293
  ###########################################################################
292
294
  # Cache
@@ -317,9 +319,17 @@ class Model:
317
319
  raise ArityMismatchError(name, el.fn, el.args)
318
320
 
319
321
  # Sort derived & reactions
320
- to_sort = self._derived | self._reactions | self._surrogates
322
+ to_sort = (
323
+ self._derived
324
+ | self._reactions
325
+ | self._surrogates
326
+ | {k: v for k, v in self._variables.items() if isinstance(v, Derived)}
327
+ )
321
328
  order = _sort_dependencies(
322
- available=set(self._parameters) | set(self._variables) | {"time"},
329
+ available=all_parameter_names
330
+ | {k for k, v in self._variables.items() if not isinstance(v, Derived)}
331
+ | set(self._data)
332
+ | {"time"},
323
333
  elements=[
324
334
  Dependency(name=k, required=set(v.args), provided={k})
325
335
  if not isinstance(v, AbstractSurrogate)
@@ -328,35 +338,42 @@ class Model:
328
338
  ],
329
339
  )
330
340
 
331
- # Split derived into parameters and variables
332
- # for user convenience
333
- derived_variable_names: list[str] = []
334
- derived_parameter_names: list[str] = []
341
+ # Calculate all values once, including dynamic ones
342
+ # That way, we can make initial conditions dependent on e.g. rates
343
+ dependent = (
344
+ all_parameter_values
345
+ | self._data
346
+ | {k: v for k, v in self._variables.items() if not isinstance(v, Derived)}
347
+ | {"time": 0.0}
348
+ )
349
+ for name in order:
350
+ to_sort[name].calculate_inpl(name, dependent)
351
+
352
+ # Split derived into static and dynamic variables
353
+ static_order = []
354
+ dyn_order = []
335
355
  for name in order:
336
356
  if name in self._reactions or name in self._surrogates:
337
- continue
338
- derived = self._derived[name]
339
- if all(i in all_parameter_names for i in derived.args):
340
- all_parameter_names.add(name)
341
- derived_parameter_names.append(name)
342
- all_parameter_values[name] = float(
343
- derived.fn(*(all_parameter_values[i] for i in derived.args))
344
- )
357
+ dyn_order.append(name)
358
+ elif name in self._variables:
359
+ static_order.append(name)
345
360
  else:
346
- derived_variable_names.append(name)
361
+ derived = self._derived[name]
362
+ if all(i in all_parameter_names for i in derived.args):
363
+ static_order.append(name)
364
+ all_parameter_names.add(name)
365
+ else:
366
+ dyn_order.append(name)
347
367
 
368
+ # Calculate dynamic and static stoichiometries
348
369
  stoich_by_compounds: dict[str, dict[str, float]] = {}
349
370
  dyn_stoich_by_compounds: dict[str, dict[str, Derived]] = {}
350
-
351
371
  for rxn_name, rxn in self._reactions.items():
352
372
  for cpd_name, factor in rxn.stoichiometry.items():
353
373
  d_static = stoich_by_compounds.setdefault(cpd_name, {})
354
-
355
374
  if isinstance(factor, Derived):
356
375
  if all(i in all_parameter_names for i in factor.args):
357
- d_static[rxn_name] = float(
358
- factor.fn(*(all_parameter_values[i] for i in factor.args))
359
- )
376
+ d_static[rxn_name] = factor.calculate(dependent)
360
377
  else:
361
378
  dyn_stoich_by_compounds.setdefault(cpd_name, {})[rxn_name] = (
362
379
  factor
@@ -367,20 +384,40 @@ class Model:
367
384
  for surrogate in self._surrogates.values():
368
385
  for rxn_name, rxn in surrogate.stoichiometries.items():
369
386
  for cpd_name, factor in rxn.items():
370
- stoich_by_compounds.setdefault(cpd_name, {})[rxn_name] = factor
387
+ d_static = stoich_by_compounds.setdefault(cpd_name, {})
388
+ if isinstance(factor, Derived):
389
+ if all(i in all_parameter_names for i in factor.args):
390
+ d_static[rxn_name] = factor.calculate(dependent)
391
+ else:
392
+ dyn_stoich_by_compounds.setdefault(cpd_name, {})[
393
+ rxn_name
394
+ ] = factor
395
+ else:
396
+ d_static[rxn_name] = factor
371
397
 
372
398
  var_names = self.get_variable_names()
373
399
  dxdt = pd.Series(np.zeros(len(var_names), dtype=float), index=var_names)
374
400
 
401
+ initial_conditions: dict[str, float] = {
402
+ k: v for k, v in self._variables.items() if not isinstance(v, Derived)
403
+ }
404
+ for name in static_order:
405
+ if name in self._variables:
406
+ initial_conditions[name] = cast(float, dependent[name])
407
+ elif name in self._derived:
408
+ all_parameter_values[name] = cast(float, dependent[name])
409
+ else:
410
+ msg = "Unknown target for static derived variable."
411
+ raise KeyError(msg)
412
+
375
413
  self._cache = ModelCache(
376
414
  var_names=var_names,
377
- order=order,
415
+ dyn_order=dyn_order,
378
416
  all_parameter_values=all_parameter_values,
379
417
  stoich_by_cpds=stoich_by_compounds,
380
418
  dyn_stoich_by_cpds=dyn_stoich_by_compounds,
381
- derived_variable_names=derived_variable_names,
382
- derived_parameter_names=derived_parameter_names,
383
419
  dxdt=dxdt,
420
+ initial_conditions=initial_conditions,
384
421
  )
385
422
  return self._cache
386
423
 
@@ -662,7 +699,7 @@ class Model:
662
699
  ##########################################################################
663
700
 
664
701
  @property
665
- def variables(self) -> dict[str, float]:
702
+ def variables(self) -> dict[str, float | Derived]:
666
703
  """Returns a copy of the variables dictionary.
667
704
 
668
705
  Examples:
@@ -679,7 +716,7 @@ class Model:
679
716
  return self._variables.copy()
680
717
 
681
718
  @_invalidate_cache
682
- def add_variable(self, name: str, initial_condition: float) -> Self:
719
+ def add_variable(self, name: str, initial_condition: float | Derived) -> Self:
683
720
  """Adds a variable to the model with the given name and initial condition.
684
721
 
685
722
  Examples:
@@ -697,7 +734,7 @@ class Model:
697
734
  self._variables[name] = initial_condition
698
735
  return self
699
736
 
700
- def add_variables(self, variables: dict[str, float]) -> Self:
737
+ def add_variables(self, variables: Mapping[str, float | Derived]) -> Self:
701
738
  """Adds multiple variables to the model with their initial conditions.
702
739
 
703
740
  Examples:
@@ -751,7 +788,7 @@ class Model:
751
788
  return self
752
789
 
753
790
  @_invalidate_cache
754
- def update_variable(self, name: str, initial_condition: float) -> Self:
791
+ def update_variable(self, name: str, initial_condition: float | Derived) -> Self:
755
792
  """Updates the value of a variable in the model.
756
793
 
757
794
  Examples:
@@ -771,7 +808,7 @@ class Model:
771
808
  self._variables[name] = initial_condition
772
809
  return self
773
810
 
774
- def update_variables(self, variables: dict[str, float]) -> Self:
811
+ def update_variables(self, variables: Mapping[str, float | Derived]) -> Self:
775
812
  """Updates multiple variables in the model.
776
813
 
777
814
  Examples:
@@ -812,7 +849,9 @@ class Model:
812
849
  initial_conditions: A dictionary where the keys are variable names and the values are their initial conditions.
813
850
 
814
851
  """
815
- return self._variables
852
+ if (cache := self._cache) is None:
853
+ cache = self._create_cache()
854
+ return cache.initial_conditions
816
855
 
817
856
  def make_variable_static(self, name: str, value: float | None = None) -> Self:
818
857
  """Converts a variable to a static parameter.
@@ -833,9 +872,12 @@ class Model:
833
872
  Self: The instance of the class for method chaining.
834
873
 
835
874
  """
836
- value = self._variables[name] if value is None else value
875
+ value_or_derived = self._variables[name] if value is None else value
837
876
  self.remove_variable(name)
838
- self.add_parameter(name, value)
877
+ if isinstance(value_or_derived, Derived):
878
+ self.add_derived(name, value_or_derived.fn, args=value_or_derived.args)
879
+ else:
880
+ self.add_parameter(name, value_or_derived)
839
881
 
840
882
  # Remove from stoichiometries
841
883
  for reaction in self._reactions.values():
@@ -886,7 +928,8 @@ class Model:
886
928
  if (cache := self._cache) is None:
887
929
  cache = self._create_cache()
888
930
  derived = self._derived
889
- return {k: derived[k] for k in cache.derived_variable_names}
931
+
932
+ return {k: v for k, v in derived.items() if k not in cache.all_parameter_values}
890
933
 
891
934
  @property
892
935
  def derived_parameters(self) -> dict[str, Derived]:
@@ -905,7 +948,7 @@ class Model:
905
948
  if (cache := self._cache) is None:
906
949
  cache = self._create_cache()
907
950
  derived = self._derived
908
- return {k: derived[k] for k in cache.derived_parameter_names}
951
+ return {k: v for k, v in derived.items() if k in cache.all_parameter_values}
909
952
 
910
953
  @_invalidate_cache
911
954
  def add_derived(
@@ -1049,6 +1092,33 @@ class Model:
1049
1092
  )
1050
1093
  return pd.DataFrame(stoich_by_cpds).T.fillna(0)
1051
1094
 
1095
def get_stoichiometries_of_variable(
    self,
    variable: str,
    variables: dict[str, float] | None = None,
    time: float = 0.0,
) -> dict[str, float]:
    """Retrieve the stoichiometry of a specific variable.

    Static coefficients are taken from the cache; coefficients given as
    derived expressions are evaluated against the dependent values at the
    requested state and override/extend the static ones.

    Examples:
        >>> model.get_stoichiometries_of_variable("x1")
        {"v1": -1, "v2": 1}

    Args:
        variable: The name of the variable for which to retrieve the stoichiometry.
        variables: A dictionary of variable names and their values.
        time: The time point at which to evaluate the stoichiometry.

    Returns:
        A new dictionary mapping reaction names to the coefficient of
        ``variable`` in that reaction. The cached data is not modified.

    Raises:
        KeyError: If the cache holds no static stoichiometry entry for ``variable``.

    """
    if (cache := self._cache) is None:
        cache = self._create_cache()
    args = self.get_dependent(variables=variables, time=time)

    # Copy so that the dynamic entries written below never leak into the cache.
    stoich = dict(cache.stoich_by_cpds[variable])

    # The dynamic mapping only has an entry for compounds that actually carry
    # a derived coefficient, so a missing key simply means "nothing dynamic".
    for rxn, derived in cache.dyn_stoich_by_cpds.get(variable, {}).items():
        stoich[rxn] = float(derived.fn(*(args[i] for i in derived.args)))
    return stoich
1121
+
1052
1122
  @_invalidate_cache
1053
1123
  def add_reaction(
1054
1124
  self,
@@ -1250,7 +1320,7 @@ class Model:
1250
1320
  surrogate: AbstractSurrogate,
1251
1321
  args: list[str] | None = None,
1252
1322
  outputs: list[str] | None = None,
1253
- stoichiometries: dict[str, dict[str, float]] | None = None,
1323
+ stoichiometries: dict[str, dict[str, float | Derived]] | None = None,
1254
1324
  ) -> Self:
1255
1325
  """Adds a surrogate model to the current instance.
1256
1326
 
@@ -1285,7 +1355,7 @@ class Model:
1285
1355
  name: str,
1286
1356
  surrogate: AbstractSurrogate | None = None,
1287
1357
  args: list[str] | None = None,
1288
- stoichiometries: dict[str, dict[str, float]] | None = None,
1358
+ stoichiometries: dict[str, dict[str, float | Derived]] | None = None,
1289
1359
  ) -> Self:
1290
1360
  """Update a surrogate model in the model.
1291
1361
 
@@ -1337,6 +1407,27 @@ class Model:
1337
1407
  names.extend(i.stoichiometries)
1338
1408
  return names
1339
1409
 
1410
+ ##########################################################################
1411
+ # Datasets
1412
+ ##########################################################################
1413
+
1414
def add_data(self, name: str, data: pd.Series | pd.DataFrame) -> Self:
    """Add a named data set to the model.

    Args:
        name: Identifier under which the data set is registered.
        data: The data set to store.

    Returns:
        Self: The instance of the class for method chaining.

    """
    self._insert_id(name=name, ctx="data")
    self._data[name] = data
    # The cache is built from the registered data names and values, so it
    # must be rebuilt on next use once a data set has been added.
    self._cache = None
    return self
1419
+
1420
def update_data(self, name: str, data: pd.Series | pd.DataFrame) -> Self:
    """Replace the data set stored under ``name``.

    Args:
        name: Identifier of the data set.
        data: The new data set.

    Returns:
        Self: The instance of the class for method chaining.

    """
    # NOTE(review): an unknown name is stored without registering an id,
    # exactly as before - confirm whether that upsert behaviour is intended.
    self._data[name] = data
    # Values computed from the previous data may be held in the cache;
    # drop it so they are recalculated on next use.
    self._cache = None
    return self
1424
+
1425
def remove_data(self, name: str) -> Self:
    """Remove a data set from the model.

    Args:
        name: Identifier of the data set to remove.

    Returns:
        Self: The instance of the class for method chaining.

    Raises:
        KeyError: If no data set is registered under ``name``.

    """
    # Check first: failing after the id has been dropped would leave the
    # id registry and the data mapping out of step.
    if name not in self._data:
        raise KeyError(name)
    self._remove_id(name=name)
    del self._data[name]
    # The cache was built with this data set available; force a rebuild.
    self._cache = None
    return self
1430
+
1340
1431
  ##########################################################################
1341
1432
  # Get dependent values. This includes
1342
1433
  # - derived parameters
@@ -1371,14 +1462,17 @@ class Model:
1371
1462
  with their respective names as keys and their calculated values as values.
1372
1463
 
1373
1464
  """
1374
- args: dict[str, float] = cache.all_parameter_values | variables
1465
+ args = cache.all_parameter_values | variables | self._data
1375
1466
  args["time"] = time
1376
1467
 
1377
1468
  containers = self._derived | self._reactions | self._surrogates
1378
- for name in cache.order:
1469
+ for name in cache.dyn_order:
1379
1470
  containers[name].calculate_inpl(name, args)
1380
1471
 
1381
- return args
1472
+ for k in self._data:
1473
+ args.pop(k)
1474
+
1475
+ return cast(dict[str, float], args)
1382
1476
 
1383
1477
  def get_dependent(
1384
1478
  self,
@@ -1454,29 +1548,16 @@ class Model:
1454
1548
  derived variables, and optionally readout variables, with time as an additional column.
1455
1549
 
1456
1550
  """
1457
- if (cache := self._cache) is None:
1458
- cache = self._create_cache()
1459
-
1460
- pars_df = pd.DataFrame(
1461
- np.full(
1462
- (len(variables), len(cache.all_parameter_values)),
1463
- np.fromiter(cache.all_parameter_values.values(), dtype=float),
1464
- ),
1465
- index=variables.index,
1466
- columns=list(cache.all_parameter_values),
1467
- )
1468
-
1469
- args = pd.concat((variables, pars_df), axis=1)
1470
- args["time"] = args.index
1471
-
1472
- containers = self._derived | self._reactions | self._surrogates
1473
- for name in cache.order:
1474
- containers[name].calculate_inpl_time_course(name, args)
1551
+ args = {
1552
+ time: self.get_dependent(
1553
+ variables=values.to_dict(),
1554
+ time=cast(float, time),
1555
+ include_readouts=include_readouts,
1556
+ )
1557
+ for time, values in variables.iterrows()
1558
+ }
1475
1559
 
1476
- if include_readouts:
1477
- for name, ro in self._readouts.items():
1478
- args[name] = ro.fn(*args.loc[:, ro.args].to_numpy().T)
1479
- return args
1560
+ return pd.DataFrame(args, dtype=float).T
1480
1561
 
1481
1562
  ##########################################################################
1482
1563
  # Get args
@@ -1569,28 +1650,6 @@ class Model:
1569
1650
  # Get fluxes
1570
1651
  ##########################################################################
1571
1652
 
1572
- def _get_fluxes(self, args: dict[str, float]) -> dict[str, float]:
1573
- """Calculate the fluxes for the given arguments.
1574
-
1575
- Examples:
1576
- >>> model._get_fluxes({"x1": 1.0, "x2": 2.0, "k1": 0.1, "time": 0.0})
1577
- {"r1": 0.1, "r2": 0.2}
1578
-
1579
- Args:
1580
- args (dict[str, float]): A dictionary where the keys are argument names and the values are their corresponding float values.
1581
-
1582
- Returns:
1583
- dict[str, float]: A dictionary where the keys are reaction names and the values are the calculated fluxes.
1584
-
1585
- """
1586
- fluxes: dict[str, float] = {}
1587
- for name, rxn in self._reactions.items():
1588
- fluxes[name] = cast(float, rxn.fn(*(args[arg] for arg in rxn.args)))
1589
-
1590
- for surrogate in self._surrogates.values():
1591
- fluxes |= surrogate.predict(np.array([args[arg] for arg in surrogate.args]))
1592
- return fluxes
1593
-
1594
1653
  def get_fluxes(
1595
1654
  self,
1596
1655
  variables: dict[str, float] | None = None,
mxlpy/nn/__init__.py CHANGED
@@ -1,10 +1,29 @@
1
1
  """Collection of neural network architectures."""
2
2
 
3
- import contextlib
3
+ from __future__ import annotations
4
4
 
5
- __all__ = ["tensorflow", "torch"]
5
+ from typing import TYPE_CHECKING
6
6
 
7
- with contextlib.suppress(ImportError):
8
- from . import _torch as torch
7
+ if TYPE_CHECKING:
8
+ import contextlib
9
9
 
10
- from . import _tensorflow as tensorflow
10
+ with contextlib.suppress(ImportError):
11
+ from . import _keras as keras
12
+ from . import _torch as torch
13
+ else:
14
+ from lazy_import import lazy_module
15
+
16
+ keras = lazy_module(
17
+ "mxlpy.nn._keras",
18
+ error_strings={"module": "keras", "install_name": "mxlpy[tf]"},
19
+ )
20
+ torch = lazy_module(
21
+ "mxlpy.nn._torch",
22
+ error_strings={"module": "torch", "install_name": "mxlpy[torch]"},
23
+ )
24
+
25
+
26
+ __all__ = [
27
+ "keras",
28
+ "torch",
29
+ ]
mxlpy/nn/_keras.py ADDED
@@ -0,0 +1,85 @@
1
+ from typing import cast
2
+
3
+ import keras
4
+ import pandas as pd
5
+ from tqdm.keras import TqdmCallback
6
+
7
+ from mxlpy.types import Array
8
+
9
+ __all__ = ["LSTM", "MLP", "train"]
10
+
11
+
12
def train(
    model: keras.Model,
    features: pd.DataFrame | Array,
    targets: pd.DataFrame | Array,
    epochs: int,
    batch_size: int | None,
) -> pd.Series:
    """Fit a keras model and return its training loss history.

    Args:
        model: Neural network model to train; its ``fit`` method is called.
        features: Input features.
        targets: Target values.
        epochs: Number of training epochs.
        batch_size: Size of mini-batches for training, forwarded unchanged
            to ``model.fit``.

    Returns:
        pd.Series: The ``"loss"`` entry of the history returned by ``model.fit``.

    """
    history = model.fit(
        features,
        targets,
        batch_size=batch_size,
        epochs=epochs,
        # keras' own console output is switched off; progress is reported
        # through the tqdm callback instead. The cast only adjusts the static type.
        verbose=cast(str, 0),
        callbacks=[TqdmCallback()],
    )
    return pd.Series(history.history["loss"])
44
+
45
+
46
def MLP(  # noqa: N802
    n_inputs: int,
    neurons_per_layer: list[int],
    activation: str | None = None,
    output_activation: str | None = None,
) -> keras.Sequential:
    """Build a Multilayer Perceptron (MLP) as a keras ``Sequential`` model.

    Args:
        n_inputs: Number of input features.
        neurons_per_layer: Width of every dense layer; the last entry is the
            width of the output layer. Must contain at least one entry.
        activation: Activation passed to every layer except the last.
        output_activation: Activation passed to the last layer.

    Returns:
        keras.Sequential: The assembled model.

    """
    model = keras.Sequential([keras.Input(shape=(n_inputs,))])
    for neurons in neurons_per_layer[:-1]:
        model.add(keras.layers.Dense(neurons, activation=activation))
    # The output layer gets its own activation setting.
    model.add(keras.layers.Dense(neurons_per_layer[-1], activation=output_activation))
    return model
63
+
64
+
65
def LSTM(  # noqa: N802
    n_inputs: int,
    n_outputs: int,
    n_hidden: int,
) -> keras.Sequential:
    """Build a Long Short-Term Memory (LSTM) network as a keras ``Sequential`` model.

    Args:
        n_inputs: Number of input features per time step.
        n_outputs: Width of the final dense layer.
        n_hidden: Number of units of the LSTM layer.

    Returns:
        keras.Sequential: The assembled model.

    """
    model = keras.Sequential(
        [
            # `(n_inputs)` is a plain int, not a shape tuple. An LSTM layer
            # consumes (timesteps, features) per sample, so the time axis is
            # left open.
            # NOTE(review): assumes n_inputs is the feature count per time
            # step - confirm against the callers.
            keras.Input(shape=(None, n_inputs)),
        ]
    )
    model.add(keras.layers.LSTM(n_hidden))
    model.add(keras.layers.Dense(n_outputs))
    return model
mxlpy/nn/_torch.py CHANGED
@@ -8,17 +8,77 @@ from __future__ import annotations
8
8
 
9
9
  from typing import TYPE_CHECKING, cast
10
10
 
11
+ import numpy as np
12
+ import pandas as pd
11
13
  import torch
14
+ import tqdm
12
15
  from torch import nn
16
+ from torch.utils.data import DataLoader, TensorDataset
17
+
18
+ type LossFn = Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
13
19
 
14
20
  if TYPE_CHECKING:
15
21
  from collections.abc import Callable
16
22
 
17
- __all__ = ["DefaultDevice", "LSTM", "MLP"]
23
+ from torch.optim.adam import Adam
24
+
25
+ from mxlpy.types import Array
26
+
27
+ __all__ = ["DefaultDevice", "LSTM", "LossFn", "MLP", "train"]
18
28
 
19
29
  DefaultDevice = torch.device("cpu")
20
30
 
21
31
 
32
def train(
    model: nn.Module,
    features: Array,
    targets: Array,
    epochs: int,
    optimizer: Adam,
    device: torch.device,
    batch_size: int | None,
    loss_fn: LossFn,
) -> pd.Series:
    """Run mini-batch gradient descent and record the mean loss per epoch.

    Args:
        model: Neural network model to train.
        features: Input features; converted to a float32 tensor on ``device``.
        targets: Target values; converted to a float32 tensor on ``device``.
        epochs: Number of training epochs.
        optimizer: Optimizer for training.
        device: torch device the tensors are created on.
        batch_size: Size of mini-batches; ``None`` uses all samples as one batch.
        loss_fn: Loss function applied to ``(prediction, target)``.

    Returns:
        pd.Series: Sample-weighted mean loss of every epoch, indexed by epoch number.

    """
    x_all = torch.tensor(
        features.astype(np.float32), dtype=torch.float32, device=device
    )
    y_all = torch.tensor(
        targets.astype(np.float32), dtype=torch.float32, device=device
    )
    loader = DataLoader(
        TensorDataset(x_all, y_all),
        batch_size=batch_size if batch_size is not None else len(features),
        shuffle=True,
    )

    history = {}
    for epoch in tqdm.trange(epochs):
        running_total = 0
        for batch_x, batch_y in loader:
            optimizer.zero_grad()
            batch_loss = loss_fn(model(batch_x), batch_y)
            batch_loss.backward()
            optimizer.step()
            # Weight by batch size so the epoch value is a per-sample mean
            # even when the last batch is smaller.
            running_total += batch_loss.item() * batch_x.size(0)
        history[epoch] = running_total / len(loader.dataset)  # type: ignore
    return pd.Series(history, dtype=float)
80
+
81
+
22
82
  class MLP(nn.Module):
23
83
  """Multilayer Perceptron (MLP) for surrogate modeling and neural posterior estimation.
24
84
 
@@ -65,21 +125,17 @@ class MLP(nn.Module):
65
125
  levels = []
66
126
  previous_neurons = n_inputs
67
127
 
68
- for idx, neurons in enumerate(self.layers):
69
- if idx == (len(self.layers) - 1):
70
- levels.append(nn.Linear(previous_neurons, neurons))
71
-
72
- if self.output_activation:
73
- levels.append(self.output_activation)
74
-
75
- else:
76
- levels.append(nn.Linear(previous_neurons, neurons))
77
-
78
- if self.activation:
79
- levels.append(self.activation)
80
-
128
+ for neurons in self.layers[:-1]:
129
+ levels.append(nn.Linear(previous_neurons, neurons))
130
+ if self.activation:
131
+ levels.append(self.activation)
81
132
  previous_neurons = neurons
82
133
 
134
+ # Output layer
135
+ levels.append(nn.Linear(previous_neurons, self.layers[-1]))
136
+ if self.output_activation:
137
+ levels.append(self.output_activation)
138
+
83
139
  self.net = nn.Sequential(*levels)
84
140
 
85
141
  for m in self.net.modules():
@@ -103,7 +159,12 @@ class MLP(nn.Module):
103
159
  class LSTM(nn.Module):
104
160
  """Default LSTM neural network model for time-series approximation."""
105
161
 
106
- def __init__(self, n_inputs: int, n_outputs: int, n_hidden: int) -> None:
162
+ def __init__(
163
+ self,
164
+ n_inputs: int,
165
+ n_outputs: int,
166
+ n_hidden: int,
167
+ ) -> None:
107
168
  """Initializes the neural network model.
108
169
 
109
170
  Args: