mxlpy 0.25.0__py3-none-any.whl → 0.26.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mxlpy/__init__.py +4 -4
- mxlpy/fit.py +1414 -0
- mxlpy/fuzzy.py +139 -0
- mxlpy/identify.py +5 -5
- mxlpy/integrators/int_scipy.py +4 -3
- mxlpy/meta/codegen_latex.py +1 -0
- mxlpy/meta/source_tools.py +1 -1
- mxlpy/model.py +41 -24
- mxlpy/nn/__init__.py +5 -0
- mxlpy/nn/_equinox.py +293 -0
- mxlpy/nn/_torch.py +59 -2
- mxlpy/npe/__init__.py +5 -0
- mxlpy/npe/_equinox.py +344 -0
- mxlpy/npe/_torch.py +6 -22
- mxlpy/parallel.py +73 -4
- mxlpy/surrogates/__init__.py +5 -0
- mxlpy/surrogates/_equinox.py +195 -0
- mxlpy/surrogates/_torch.py +5 -20
- mxlpy/symbolic/symbolic_model.py +30 -3
- mxlpy/types.py +1 -0
- {mxlpy-0.25.0.dist-info → mxlpy-0.26.0.dist-info}/METADATA +4 -1
- {mxlpy-0.25.0.dist-info → mxlpy-0.26.0.dist-info}/RECORD +24 -23
- mxlpy/fit/__init__.py +0 -9
- mxlpy/fit/common.py +0 -298
- mxlpy/fit/global_.py +0 -534
- mxlpy/fit/local_.py +0 -591
- {mxlpy-0.25.0.dist-info → mxlpy-0.26.0.dist-info}/WHEEL +0 -0
- {mxlpy-0.25.0.dist-info → mxlpy-0.26.0.dist-info}/licenses/LICENSE +0 -0
mxlpy/fuzzy.py
ADDED
@@ -0,0 +1,139 @@
+"""Fuzzy / bayesian fitting methods."""
+
+from __future__ import annotations
+
+import multiprocessing
+import sys
+from dataclasses import dataclass, field
+from functools import partial
+from math import ceil
+from typing import TYPE_CHECKING, Self
+
+import numpy as np
+import pandas as pd
+import pebble
+from tqdm import tqdm, trange
+
+from mxlpy.simulator import Simulator
+
+if TYPE_CHECKING:
+    from collections.abc import Iterable
+
+    from mxlpy import Model
+
+__all__ = ["ThompsonState", "thompson_sampling"]
+
+
+@dataclass
+class ThompsonState:
+    """State of thompson sampling."""
+
+    rng: np.random.Generator = field(default_factory=np.random.default_rng)
+    state: dict[str, pd.DataFrame] = field(default_factory=dict)
+
+    @classmethod
+    def from_parameter_values(cls, parameters: dict[str, Iterable[float]]) -> Self:
+        """Create state from parameter values."""
+        return cls(
+            state={
+                k: pd.DataFrame(
+                    {
+                        "x": v,
+                        "success": np.ones_like(v, dtype=int),
+                        "fail": np.ones_like(v, dtype=int),
+                    }
+                )
+                for k, v in parameters.items()
+            },
+        )
+
+    def sample(self) -> tuple[dict[str, int], dict[str, float]]:
+        """Sample idxs and parameters."""
+        idxs = {
+            k: int(np.argmax(self.rng.beta(v["success"], v["fail"])))
+            for k, v in self.state.items()
+        }
+        parameters = {k: v["x"][idxs[k]] for k, v in self.state.items()}
+        return idxs, parameters
+
+    def update(
+        self,
+        idxs: dict[str, int],
+        pred: pd.DataFrame | None,
+        data: pd.DataFrame,
+        rtol: float,
+    ) -> None:
+        """Sample state."""
+        accept: bool = (
+            False if pred is None else np.sqrt(np.mean(np.square(pred - data))) < rtol
+        )
+        for k, v in self.state.items():
+            v.loc[idxs[k], "success" if accept else "fail"] += 1  # type: ignore
+
+
+def _thompson_worker(
+    inp: tuple[dict[str, int], dict[str, float]],
+    model: Model,
+    data: pd.DataFrame,
+) -> tuple[dict[str, int], pd.DataFrame | None]:
+    idxs, parameters = inp
+    if (
+        res := (
+            Simulator(model)
+            .update_parameters(parameters)
+            .simulate_time_course(data.index)
+            .get_result()
+        )
+    ) is None:
+        return idxs, None
+    return idxs, res.get_variables()
+
+
+def thompson_sampling(
+    model: Model,
+    data: pd.DataFrame,
+    state: ThompsonState,
+    rtol: float,
+    n: int,
+    *,
+    max_workers: int | None = None,
+    disable_tqdm: bool = False,
+    timeout: float | None = None,
+    parallel: bool = True,
+) -> ThompsonState:
+    """Perform thompson sampling."""
+    if sys.platform in ["win32", "cygwin"]:
+        parallel = False
+
+    max_workers = multiprocessing.cpu_count() if max_workers is None else max_workers
+    worker = partial(_thompson_worker, model=model, data=data)
+
+    if not parallel:
+        for _ in trange(n):
+            idxs, pred = worker(state.sample())
+            state.update(idxs, pred, data=data, rtol=rtol)
+    else:
+        # FIXME: think about whether this is ok to do. Thompson sampling is state-
+        # dependent. We are breaking up that state a bit by chunking the approach
+        # Is that fine to do?
+        with (
+            tqdm(total=n, disable=disable_tqdm) as pbar,
+            pebble.ProcessPool(max_workers=max_workers) as pool,
+        ):
+            for _ in range(ceil(n / max_workers)):
+                future = pool.map(
+                    worker,
+                    [state.sample() for _ in range(max_workers)],
+                    timeout=timeout,
+                )
+                it = future.result()
+                while True:
+                    try:
+                        idxs, pred = next(it)
+                        state.update(idxs, pred, data=data, rtol=rtol)
+                        pbar.update(1)
+                    except StopIteration:
+                        break
+                    except TimeoutError:
+                        pbar.update(1)
+    return state
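
Usage sketch for the new module (not part of the published diff): a minimal Thompson-sampling loop, assuming an existing mxlpy `Model` and a measured time course `data`; the candidate parameter grids below are made-up illustrations.

import numpy as np

from mxlpy.fuzzy import ThompsonState, thompson_sampling

# Hypothetical candidate values per parameter.
state = ThompsonState.from_parameter_values(
    {
        "k1": np.linspace(0.1, 2.0, 20),
        "k2": np.linspace(0.5, 5.0, 20),
    }
)

# `model` and `data` are assumed to be an existing mxlpy Model and a
# pd.DataFrame of measured variables indexed by time.
state = thompson_sampling(model, data, state=state, rtol=0.1, n=200, parallel=False)

# Each parameter keeps beta-distribution success/fail counts per candidate value.
print(state.state["k1"].sort_values("success", ascending=False).head())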
mxlpy/identify.py
CHANGED
@@ -9,9 +9,8 @@ import numpy as np
 import pandas as pd
 from tqdm import tqdm
 
-from mxlpy import
+from mxlpy import fit
 from mxlpy.distributions import LogNormal, sample
-from mxlpy.fit.common import LossFn, rmse
 from mxlpy.parallel import parallelise
 
 if TYPE_CHECKING:
@@ -27,13 +26,14 @@ def _mc_fit_time_course_worker(
     p0: pd.Series,
     model: Model,
     data: pd.DataFrame,
-    loss_fn:
+    loss_fn: fit.LossFn,
 ) -> float:
-    fit_result =
+    fit_result = fit.time_course(
         model=model,
         p0=p0.to_dict(),
         data=data,
         loss_fn=loss_fn,
+        minimizer=fit.LocalScipyMinimizer(),
     )
     if fit_result is None:
         return np.inf
@@ -46,7 +46,7 @@ def profile_likelihood(
     parameter_name: str,
     parameter_values: Array,
     n_random: int = 10,
-    loss_fn: LossFn = rmse,
+    loss_fn: fit.LossFn = fit.rmse,
 ) -> pd.Series:
     """Estimate the profile likelihood of model parameters given data.
 
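
For context (not part of the published diff): the worker above now calls the consolidated `mxlpy.fit` API directly. A hedged sketch of that call, with `model` and `data` assumed to exist and illustrative initial-guess values:

from mxlpy import fit

# `model` and `data` are assumed to be an existing mxlpy Model and a
# pd.DataFrame of measured variables; the p0 values are illustrative.
fit_result = fit.time_course(
    model=model,
    p0={"k1": 1.0, "k2": 0.5},
    data=data,
    loss_fn=fit.rmse,
    minimizer=fit.LocalScipyMinimizer(),
)
if fit_result is None:
    print("fit did not converge")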
mxlpy/integrators/int_scipy.py
CHANGED
@@ -4,7 +4,7 @@ from __future__ import annotations
 
 import copy
 from dataclasses import dataclass, field
-from typing import TYPE_CHECKING, cast
+from typing import TYPE_CHECKING, Literal, cast
 
 import numpy as np
 import scipy.integrate as spi
@@ -48,6 +48,7 @@ class Scipy:
     atol: float = 1e-8
     rtol: float = 1e-8
     t0: float = 0.0
+    method: Literal["RK45", "RK23", "DOP853", "Radau", "BDF", "LSODA"] = "LSODA"
     _y0_orig: tuple[float, ...] = field(default_factory=tuple)
 
     def __post_init__(self) -> None:
@@ -111,7 +112,7 @@ class Scipy:
             jac=self.jacobian,
             atol=self.atol,
             rtol=self.rtol,
-            method=
+            method=self.method,
         )
 
         if res.success:
@@ -149,7 +150,7 @@ class Scipy:
         # If rhs returns a tuple, we get weird errors, so we need
         # to wrap this in a list for some reason
         integ = spi.ode(lambda t, x: list(self.rhs(t, x)), jac=self.jacobian)
-        integ.set_integrator(name=
+        integ.set_integrator(name=self.method)
         integ.set_initial_value(self.y0)
 
         t = self.t0 + step_size
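
A hedged sketch of the new option (not part of the published diff): `method` now selects the SciPy solver, defaulting to "LSODA". The `functools.partial` pattern below is an assumption about how a preconfigured integrator would be passed around; check the mxlpy docs for the exact hook into `Simulator`.

from functools import partial

from mxlpy.integrators.int_scipy import Scipy

# Preconfigure the integrator to use the implicit Radau solver for stiff systems.
stiff_scipy = partial(Scipy, method="Radau")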
mxlpy/meta/codegen_latex.py
CHANGED
@@ -795,6 +795,7 @@ def to_tex_export(model: Model) -> TexExport:
     for rxn_name, rxn in model.get_raw_reactions().items():
         for var_name, factor in rxn.stoichiometry.items():
             diff_eqs.setdefault(var_name, {})[rxn_name] = factor
+    # FIXME: think about surrogates here
 
     return TexExport(
         parameters=model.get_parameter_values(),
mxlpy/meta/source_tools.py
CHANGED
@@ -635,7 +635,7 @@ def _handle_attribute(node: ast.Attribute, ctx: Context) -> sympy.Expr | None:
         )
     elif (var := variables.get(level)) is not None:
         _LOGGER.debug("var %s", var)
-        return _get_inner_object(var, levels[
+        return _get_inner_object(var, [*levels[idx + 1 :], node.attr])
 
     else:
         _LOGGER.debug("No target found")
mxlpy/model.py
CHANGED
@@ -928,14 +928,14 @@ class Model:
         if stoichiometries is not None:
             for rxn_name, value in stoichiometries.items():
                 target = False
-                if
+                if (rxn := self._reactions.get(rxn_name)) is not None:
                     target = True
-                    cast(dict,
+                    cast(dict, rxn.stoichiometry)[name] = value
                 else:
                     for surrogate in self._surrogates.values():
-                        if
+                        if stoich := surrogate.stoichiometries.get(rxn_name):
                             target = True
-
+                            stoich[name] = value
                 if not target:
                     msg = f"Reaction '{rxn_name}' not found in reactions or surrogates"
                     raise KeyError(msg)
@@ -1099,7 +1099,12 @@
         return self
 
     @_invalidate_cache
-    def remove_variable(
+    def remove_variable(
+        self,
+        name: str,
+        *,
+        remove_stoichiometries: bool = True,
+    ) -> Self:
         """Remove a variable from the model.
 
         Examples:
@@ -1107,16 +1112,31 @@
 
         Args:
             name: The name of the variable to remove.
+            remove_stoichiometries: whether to remove the variable from all reactions
 
         Returns:
             Self: The instance of the model with the variable removed.
 
         """
+        if remove_stoichiometries:
+            for rxn in self._reactions.values():
+                if name in rxn.stoichiometry:
+                    cast(dict, rxn.stoichiometry).pop(name)
+            for surrogate in self._surrogates.values():
+                for stoich in surrogate.stoichiometries.values():
+                    if name in stoich:
+                        cast(dict, stoich).pop(name)
+
         self._remove_id(name=name)
         del self._variables[name]
         return self
 
-    def remove_variables(
+    def remove_variables(
+        self,
+        variables: Iterable[str],
+        *,
+        remove_stoichiometries: bool = True,
+    ) -> Self:
         """Remove multiple variables from the model.
 
         Examples:
@@ -1124,13 +1144,16 @@
 
         Args:
             variables: An iterable of variable names to be removed.
+            remove_stoichiometries: whether to remove the variables from all reactions
 
         Returns:
             Self: The instance of the model with the specified variables removed.
 
         """
         for variable in variables:
-            self.remove_variable(
+            self.remove_variable(
+                name=variable, remove_stoichiometries=remove_stoichiometries
+            )
         return self
 
     @_invalidate_cache
@@ -1219,24 +1242,17 @@
         value_or_derived = (
            self._variables[name].initial_value if value is None else value
         )
-        self.remove_variable(name)
-
-
-
-
+        self.remove_variable(name, remove_stoichiometries=True)
+
+        if isinstance(der := value_or_derived, Derived):
+            self.add_derived(
+                name,
+                der.fn,
+                args=der.args,
+                unit=der.unit,
+            )
         else:
             self.add_parameter(name, value_or_derived)
-
-        # Remove from stoichiometries
-        for reaction in self._reactions.values():
-            if name in reaction.stoichiometry:
-                cast(dict, reaction.stoichiometry).pop(name)
-        for surrogate in self._surrogates.values():
-            surrogate.stoichiometries = {
-                k: {k2: v2 for k2, v2 in v.items() if k2 != name}
-                for k, v in surrogate.stoichiometries.items()
-                if k != name
-            }
         return self
 
     ##########################################################################
@@ -1680,7 +1696,8 @@
 
     ##########################################################################
     # Readouts
-    # They are like derived variables, but only calculated on demand
+    # They are like derived variables, but only calculated on demand, e.g. after
+    # a simulation
     # Think of something like NADPH / (NADP + NADPH) as a proxy for energy state
     ##########################################################################
 
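
A hedged usage sketch of the changed variable-removal API (not part of the published diff; `model` is assumed to be an existing mxlpy Model containing variables "S" and "P"):

# By default the variable is also dropped from all reaction and surrogate
# stoichiometries; pass remove_stoichiometries=False to leave them untouched.
model.remove_variable("S")
model.remove_variables(["P"], remove_stoichiometries=False)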
mxlpy/nn/__init__.py
CHANGED
@@ -8,11 +8,16 @@ if TYPE_CHECKING:
     import contextlib
 
     with contextlib.suppress(ImportError):
+        from . import _equinox as equinox
         from . import _keras as keras
         from . import _torch as torch
 else:
     from lazy_import import lazy_module
 
+    equinox = lazy_module(
+        "mxlpy.nn._equinox",
+        error_strings={"module": "equinox", "install_name": "mxlpy[equinox]"},
+    )
     keras = lazy_module(
         "mxlpy.nn._keras",
         error_strings={"module": "keras", "install_name": "mxlpy[tf]"},
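
A hedged sketch of the lazily loaded backend (not part of the published diff): with equinox installed (the extra name mxlpy[equinox] is taken from the error string above), the new module is reachable as `mxlpy.nn.equinox`.

import jax

from mxlpy import nn

# Resolves mxlpy.nn._equinox lazily; a missing equinox install raises an error
# pointing at the install extra instead of a plain ImportError.
model = nn.equinox.MLP(n_inputs=3, neurons_per_layer=[32, 32, 2], key=jax.random.PRNGKey(0))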
mxlpy/nn/_equinox.py
ADDED
@@ -0,0 +1,293 @@
+"""Neural network architectures.
+
+This module provides implementations of neural network architectures used for mechanistic learning.
+
+"""
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+import equinox as eqx
+import jax
+import jax.numpy as jnp
+import numpy as np
+import pandas as pd
+import torch
+import tqdm
+from jaxtyping import Array, PyTree
+from torch.utils.data import DataLoader, TensorDataset
+
+if TYPE_CHECKING:
+    from collections.abc import Callable
+
+    import optax
+
+
+__all__ = [
+    "LSTM",
+    "LossFn",
+    "MLP",
+    "cosine_similarity",
+    "mean_abs_error",
+    "mean_absolute_percentage",
+    "mean_error",
+    "mean_squared_error",
+    "mean_squared_logarithmic",
+    "rms_error",
+    "train",
+]
+
+
+###############################################################################
+# Loss functions
+###############################################################################
+
+type LossFn = Callable[[eqx.Module, Array, Array], Array]
+
+
+@eqx.filter_jit
+def mean_error(model: eqx.Module, inp: Array, true: Array) -> Array:
+    """Calculate mean error."""
+    pred = jax.vmap(model)(inp)  # type: ignore
+    return jnp.mean(pred - true)
+
+
+@eqx.filter_jit
+def mean_squared_error(model: eqx.Module, inp: Array, true: Array) -> Array:
+    """Calculate mean squared error."""
+    pred = jax.vmap(model)(inp)  # type: ignore
+    return jnp.mean(jnp.square(pred - true))
+
+
+@eqx.filter_jit
+def rms_error(model: eqx.Module, inp: Array, true: Array) -> Array:
+    """Calculate root mean square error."""
+    pred = jax.vmap(model)(inp)  # type: ignore
+    return jnp.sqrt(jnp.mean(jnp.square(pred - true)))
+
+
+@eqx.filter_jit
+def mean_abs_error(model: eqx.Module, inp: Array, true: Array) -> Array:
+    """Calculate mean absolute error."""
+    pred = jax.vmap(model)(inp)  # type: ignore
+    return jnp.mean(jnp.abs(pred - true))
+
+
+@eqx.filter_jit
+def mean_absolute_percentage(model: eqx.Module, inp: Array, true: Array) -> Array:
+    """Calculate mean absolute percentag error."""
+    pred = jax.vmap(model)(inp)  # type: ignore
+    return 100 * jnp.mean(jnp.abs((true - pred) / pred))
+
+
+@eqx.filter_jit
+def mean_squared_logarithmic(model: eqx.Module, inp: Array, true: Array) -> Array:
+    """Calculate root mean square error between model and data."""
+    pred = jax.vmap(model)(inp)  # type: ignore
+    return jnp.mean(jnp.square(jnp.log(pred + 1) - jnp.log(true + 1)))
+
+
+@eqx.filter_jit
+def cosine_similarity(model: eqx.Module, inp: Array, true: Array) -> Array:
+    """Calculate root mean square error between model and data."""
+    pred = jax.vmap(model)(inp)  # type: ignore
+    return -jnp.sum(jnp.linalg.norm(pred, 2) * jnp.linalg.norm(true, 2))
+
+
+###############################################################################
+# Training routines
+###############################################################################
+
+
+def train(
+    model: eqx.Module,
+    features: Array,
+    targets: Array,
+    epochs: int,
+    optimizer: optax.GradientTransformation,
+    batch_size: int | None,
+    loss_fn: LossFn,
+) -> pd.Series:
+    """Train the neural network using mini-batch gradient descent.
+
+    Args:
+        model: Neural network model to train.
+        features: Input features as a tensor.
+        targets: Target values as a tensor.
+        epochs: Number of training epochs.
+        optimizer: Optimizer for training.
+        device: torch device
+        batch_size: Size of mini-batches for training.
+        loss_fn: Loss function
+
+    Returns:
+        pd.Series: Series containing the training loss history.
+
+    """
+    losses = {}
+
+    data = TensorDataset(
+        torch.tensor(features.astype(np.float32), dtype=torch.float32),
+        torch.tensor(targets.astype(np.float32), dtype=torch.float32),
+    )
+    data_loader = DataLoader(
+        data,
+        batch_size=len(features) if batch_size is None else batch_size,
+        shuffle=True,
+    )
+
+    opt_state = optimizer.init(eqx.filter(model, eqx.is_array))
+
+    @eqx.filter_jit
+    def make_step(
+        model: eqx.Module,
+        opt_state: PyTree,
+        x: Array,
+        y: Array,
+    ) -> tuple[eqx.Module, Array, Array]:
+        loss_value, grads = eqx.filter_value_and_grad(loss_fn)(model, x, y)
+        updates, opt_state = optimizer.update(
+            grads, opt_state, eqx.filter(model, eqx.is_array)
+        )
+        model = eqx.apply_updates(model, updates)
+        return model, opt_state, loss_value
+
+    for i in tqdm.trange(epochs):
+        epoch_loss = 0
+        for xb, yb in data_loader:
+            model, opt_state, train_loss = make_step(
+                model,
+                opt_state,
+                xb.numpy(),
+                yb.numpy(),
+            )
+            epoch_loss += train_loss * xb.size(0)
+        losses[i] = epoch_loss / len(data_loader.dataset)  # type: ignore
+    return pd.Series(losses, dtype=float)
+
+
+###############################################################################
+# Actual models
+###############################################################################
+
+
+class MLP(eqx.Module):
+    """Multilayer Perceptron (MLP) for surrogate modeling and neural posterior estimation.
+
+    Attributes:
+        net: Sequential neural network model.
+
+    Methods:
+        forward: Forward pass through the neural network.
+
+    """
+
+    layers: list
+
+    def __init__(
+        self,
+        n_inputs: int,
+        neurons_per_layer: list[int],
+        key: Array,
+    ) -> None:
+        """Initializes the MLP with the given number of inputs and list of (hidden) layers.
+
+        Args:
+            n_inputs: The number of input features.
+            neurons_per_layer: Number of neurons per layer
+            n_outputs: A list containing the number of neurons in hidden and output layer.
+            key: jax.random.PRNGKey(SEED) for initial parameters
+
+        For instance, MLP(10, layers = [50, 50, 10]) initializes a neural network with the following architecture:
+        - Linear layer with `n_inputs` inputs and 50 outputs
+        - ReLU activation
+        - Linear layer with 50 inputs and 50 outputs
+        - ReLU activation
+        - Linear layer with 50 inputs and 10 outputs
+
+        The weights of the linear layers are initialized with a normal distribution
+        (mean=0, std=0.1) and the biases are initialized to 0.
+
+        """
+        keys = iter(jax.random.split(key, len(neurons_per_layer)))
+        previous_neurons = n_inputs
+        layers = []
+        for neurons in neurons_per_layer:
+            layers.append(eqx.nn.Linear(previous_neurons, neurons, key=next(keys)))
+            previous_neurons = neurons
+        self.layers = layers
+
+    def __call__(self, x: Array) -> Array:
+        """Forward pass through the neural network.
+
+        Args:
+            x: Input tensor.
+
+        Returns:
+            Output tensor.
+
+        """
+        for layer in self.layers[:-1]:
+            x = jax.nn.relu(layer(x))
+        return self.layers[-1](x)
+
+
+class LSTM(eqx.Module):
+    """Default LSTM neural network model for time-series approximation."""
+
+    lstm_cell: eqx.nn.LSTMCell
+    n_hidden: int
+    linear: eqx.nn.Linear
+
+    def __init__(
+        self,
+        n_inputs: int,
+        n_outputs: int,
+        n_hidden: int,
+        key: Array,
+    ) -> None:
+        """Initializes the LSTM neural network model.
+
+        Args:
+            n_inputs (int): Number of input features.
+            n_outputs (int): Number of output features.
+            n_hidden (int): Number of hidden units in the LSTM layer.
+            key (Array): JAX random key for initialization.
+
+        """
+        k1, k2 = jax.random.split(key, 2)
+        self.lstm_cell = eqx.nn.LSTMCell(n_inputs, n_hidden, key=k1)
+        self.n_hidden = n_hidden
+        self.linear = eqx.nn.Linear(n_hidden, n_outputs, key=k2)
+
+    def __call__(
+        self,
+        x: Array,
+        *,
+        h: Array | None = None,
+        c: Array | None = None,
+    ) -> Array:
+        """Forward pass through the LSTM network.
+
+        Args:
+            x: Input tensor of shape (seq_len, batch_size, n_inputs).
+            h: Optional initial hidden state (batch_size, n_hidden).
+            c: Optional initial cell state (batch_size, n_hidden).
+
+        Returns:
+            Output tensor of shape (seq_len, batch_size, n_outputs).
+
+        """
+        seq_len, batch_size, _ = x.shape
+        if h is None:
+            h = jnp.zeros((batch_size, self.n_hidden))
+        if c is None:
+            c = jnp.zeros((batch_size, self.n_hidden))
+
+        outputs = []
+        for t in range(seq_len):
+            h, c = self.lstm_cell(x[t], (h, c))
+            outputs.append(h)
+        outputs = jnp.stack(outputs, axis=0)
+        return jax.vmap(self.linear)(outputs)
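
A hedged training sketch for the new backend (not part of the published diff): fitting the `MLP` defined above on toy data with optax, assuming equinox, jax, optax, and torch (used internally for batching) are installed.

import jax
import numpy as np
import optax

from mxlpy.nn import equinox as eqx_nn

key = jax.random.PRNGKey(0)
model = eqx_nn.MLP(n_inputs=2, neurons_per_layer=[16, 16, 1], key=key)

# Toy regression data: y = x0 + x1.
features = np.random.rand(256, 2).astype(np.float32)
targets = features.sum(axis=1, keepdims=True)

losses = eqx_nn.train(
    model,
    features=features,
    targets=targets,
    epochs=10,
    optimizer=optax.adam(1e-3),
    batch_size=32,
    loss_fn=eqx_nn.mean_squared_error,
)
print(losses.tail())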