mxlpy 0.25.0__py3-none-any.whl → 0.26.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
mxlpy/fuzzy.py ADDED
@@ -0,0 +1,139 @@
1
+ """Fuzzy / bayesian fitting methods."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import multiprocessing
6
+ import sys
7
+ from dataclasses import dataclass, field
8
+ from functools import partial
9
+ from math import ceil
10
+ from typing import TYPE_CHECKING, Self
11
+
12
+ import numpy as np
13
+ import pandas as pd
14
+ import pebble
15
+ from tqdm import tqdm, trange
16
+
17
+ from mxlpy.simulator import Simulator
18
+
19
+ if TYPE_CHECKING:
20
+ from collections.abc import Iterable
21
+
22
+ from mxlpy import Model
23
+
24
+ __all__ = ["ThompsonState", "thompson_sampling"]
25
+
26
+
27
@dataclass
class ThompsonState:
    """Beta-Bernoulli state for Thompson sampling over discrete candidate values.

    ``state`` maps each parameter name to a DataFrame with one row per
    candidate value and the columns

    - ``x``: the candidate value,
    - ``success`` / ``fail``: the Beta distribution counts of that candidate.

    Rows are always addressed by *position* (the index returned by
    ``np.argmax`` in :meth:`sample`), so DataFrames with a non-default index
    are handled correctly.
    """

    rng: np.random.Generator = field(default_factory=np.random.default_rng)
    state: dict[str, pd.DataFrame] = field(default_factory=dict)

    @classmethod
    def from_parameter_values(
        cls,
        parameters: dict[str, Iterable[float]],
        *,
        rng: np.random.Generator | None = None,
    ) -> Self:
        """Create a state with ``success = fail = 1`` for every candidate value.

        Args:
            parameters: Candidate values per parameter name. Any iterable is
                accepted (including generators); it is consumed exactly once.
            rng: Random generator used for sampling. A fresh
                ``np.random.default_rng()`` is used when omitted.

        Returns:
            The new state.

        """
        state: dict[str, pd.DataFrame] = {}
        for name, values in parameters.items():
            # Materialise first: a generator would otherwise be consumed by
            # one column and leave nothing for sizing the count columns.
            xs = list(values)
            n_candidates = len(xs)
            state[name] = pd.DataFrame(
                {
                    "x": xs,
                    "success": np.ones(n_candidates, dtype=int),
                    "fail": np.ones(n_candidates, dtype=int),
                }
            )
        if rng is None:
            return cls(state=state)
        return cls(rng=rng, state=state)

    def sample(self) -> tuple[dict[str, int], dict[str, float]]:
        """Draw one candidate per parameter from its Beta posterior.

        For every parameter a Beta(success, fail) value is drawn for each
        candidate and the candidate with the largest draw is selected.

        Returns:
            ``(idxs, parameters)``: the selected row *positions* and the
            corresponding candidate values, both keyed by parameter name.

        """
        idxs: dict[str, int] = {}
        parameters: dict[str, float] = {}
        for name, df in self.state.items():
            draws = self.rng.beta(df["success"], df["fail"])
            pos = int(np.argmax(draws))
            idxs[name] = pos
            # Positional access: ``pos`` is an argmax position, not a label.
            parameters[name] = df["x"].iloc[pos]
        return idxs, parameters

    def update(
        self,
        idxs: dict[str, int],
        pred: pd.DataFrame | None,
        data: pd.DataFrame,
        rtol: float,
    ) -> None:
        """Record the outcome of one simulation for the sampled candidates.

        The sample is accepted when ``pred`` is not ``None`` and the RMSE
        between ``pred`` and ``data`` is below ``rtol``. Residuals are computed
        after pandas index/column alignment and NaN entries (e.g. columns only
        present in one frame) are ignored; if no residual remains, the sample
        is rejected. Accepted samples increment ``success`` of every selected
        candidate, rejected ones increment ``fail``.

        Args:
            idxs: Row positions as returned by :meth:`sample`.
            pred: Simulated values, or ``None`` if the simulation failed.
            data: Reference data.
            rtol: Acceptance threshold on the RMSE. Despite its name this is
                an absolute threshold in the units of ``data``.

        """
        accept = False
        if pred is not None:
            residuals = (pred - data).to_numpy(dtype=float).ravel()
            residuals = residuals[~np.isnan(residuals)]
            accept = bool(
                residuals.size > 0
                and float(np.sqrt(np.mean(np.square(residuals)))) < rtol
            )

        column = "success" if accept else "fail"
        for name, df in self.state.items():
            # Positional update, consistent with ``sample``.
            df.iloc[idxs[name], df.columns.get_loc(column)] += 1
72
+
73
+
74
def _thompson_worker(
    inp: tuple[dict[str, int], dict[str, float]],
    model: Model,
    data: pd.DataFrame,
) -> tuple[dict[str, int], pd.DataFrame | None]:
    """Simulate ``model`` for one sampled parameter set.

    Args:
        inp: ``(idxs, parameters)`` pair, as produced by ``ThompsonState.sample``.
        model: The model to simulate.
        data: Reference data; only its index is used, as the time points
            passed to ``simulate_time_course``.

    Returns:
        ``idxs`` unchanged, paired with the simulated variables, or with
        ``None`` when the simulator produced no result.

    """
    sampled_idxs, sampled_parameters = inp

    simulator = Simulator(model).update_parameters(sampled_parameters)
    simulator = simulator.simulate_time_course(data.index)
    result = simulator.get_result()

    variables = None if result is None else result.get_variables()
    return sampled_idxs, variables
90
+
91
+
92
def _chunk_sizes(n: int, chunk: int) -> list[int]:
    """Split ``n`` items into consecutive chunks of at most ``chunk`` items."""
    return [min(chunk, n - start) for start in range(0, n, chunk)]


def thompson_sampling(
    model: Model,
    data: pd.DataFrame,
    state: ThompsonState,
    rtol: float,
    n: int,
    *,
    max_workers: int | None = None,
    disable_tqdm: bool = False,
    timeout: float | None = None,
    parallel: bool = True,
) -> ThompsonState:
    """Perform Thompson sampling over the discrete candidates in ``state``.

    Each sample draws one candidate per parameter (``state.sample()``),
    simulates ``model`` at the time points of ``data`` and records the
    outcome with ``state.update`` (accepted when the RMSE to ``data`` is
    below ``rtol``).

    Args:
        model: Model to fit.
        data: Reference data; its index is used as simulation time points.
        state: Sampling state. It is updated in place and also returned.
        rtol: Acceptance threshold on the RMSE (absolute, units of ``data``).
        n: Total number of samples / simulations.
        max_workers: Number of worker processes (default: CPU count).
            Only used in parallel mode.
        disable_tqdm: Hide the progress bar.
        timeout: Per-simulation timeout in seconds (parallel mode only).
            Timed-out simulations are counted in the progress bar but not
            recorded in ``state``.
        parallel: Run simulations in a process pool. Forced off on
            Windows / Cygwin.

    Returns:
        The updated ``state``.

    Raises:
        ValueError: If running in parallel with ``max_workers < 1``.

    """
    if sys.platform in ("win32", "cygwin"):
        parallel = False

    max_workers = multiprocessing.cpu_count() if max_workers is None else max_workers
    worker = partial(_thompson_worker, model=model, data=data)

    if not parallel:
        for _ in trange(n, disable=disable_tqdm):
            idxs, pred = worker(state.sample())
            state.update(idxs, pred, data=data, rtol=rtol)
        return state

    if max_workers < 1:
        msg = f"max_workers must be at least 1, got {max_workers}"
        raise ValueError(msg)

    # NOTE: Thompson sampling is sequential by nature. Here all samples of one
    # chunk (of up to ``max_workers`` samples) are drawn from the same
    # posterior, which is only updated between chunks. This trades some
    # sample efficiency for parallelism.
    with (
        tqdm(total=n, disable=disable_tqdm) as pbar,
        pebble.ProcessPool(max_workers=max_workers) as pool,
    ):
        # Exact chunk sizes so that exactly ``n`` simulations are run.
        for size in _chunk_sizes(n, max_workers):
            future = pool.map(
                worker,
                [state.sample() for _ in range(size)],
                timeout=timeout,
            )
            results = future.result()
            # pebble raises per-task timeouts from ``next``, so a plain
            # ``for`` loop would abort the whole chunk on the first timeout.
            while True:
                try:
                    idxs, pred = next(results)
                except StopIteration:
                    break
                except TimeoutError:
                    pass  # timed out: no outcome to record
                else:
                    state.update(idxs, pred, data=data, rtol=rtol)
                pbar.update(1)
    return state
mxlpy/identify.py CHANGED
@@ -9,9 +9,8 @@ import numpy as np
9
9
  import pandas as pd
10
10
  from tqdm import tqdm
11
11
 
12
- from mxlpy import fit_local
12
+ from mxlpy import fit
13
13
  from mxlpy.distributions import LogNormal, sample
14
- from mxlpy.fit.common import LossFn, rmse
15
14
  from mxlpy.parallel import parallelise
16
15
 
17
16
  if TYPE_CHECKING:
@@ -27,13 +26,14 @@ def _mc_fit_time_course_worker(
27
26
  p0: pd.Series,
28
27
  model: Model,
29
28
  data: pd.DataFrame,
30
- loss_fn: fit_local.LossFn,
29
+ loss_fn: fit.LossFn,
31
30
  ) -> float:
32
- fit_result = fit_local.time_course(
31
+ fit_result = fit.time_course(
33
32
  model=model,
34
33
  p0=p0.to_dict(),
35
34
  data=data,
36
35
  loss_fn=loss_fn,
36
+ minimizer=fit.LocalScipyMinimizer(),
37
37
  )
38
38
  if fit_result is None:
39
39
  return np.inf
@@ -46,7 +46,7 @@ def profile_likelihood(
46
46
  parameter_name: str,
47
47
  parameter_values: Array,
48
48
  n_random: int = 10,
49
- loss_fn: LossFn = rmse,
49
+ loss_fn: fit.LossFn = fit.rmse,
50
50
  ) -> pd.Series:
51
51
  """Estimate the profile likelihood of model parameters given data.
52
52
 
@@ -4,7 +4,7 @@ from __future__ import annotations
4
4
 
5
5
  import copy
6
6
  from dataclasses import dataclass, field
7
- from typing import TYPE_CHECKING, cast
7
+ from typing import TYPE_CHECKING, Literal, cast
8
8
 
9
9
  import numpy as np
10
10
  import scipy.integrate as spi
@@ -48,6 +48,7 @@ class Scipy:
48
48
  atol: float = 1e-8
49
49
  rtol: float = 1e-8
50
50
  t0: float = 0.0
51
+ method: Literal["RK45", "RK23", "DOP853", "Radau", "BDF", "LSODA"] = "LSODA"
51
52
  _y0_orig: tuple[float, ...] = field(default_factory=tuple)
52
53
 
53
54
  def __post_init__(self) -> None:
@@ -111,7 +112,7 @@ class Scipy:
111
112
  jac=self.jacobian,
112
113
  atol=self.atol,
113
114
  rtol=self.rtol,
114
- method="LSODA",
115
+ method=self.method,
115
116
  )
116
117
 
117
118
  if res.success:
@@ -149,7 +150,7 @@ class Scipy:
149
150
  # If rhs returns a tuple, we get weird errors, so we need
150
151
  # to wrap this in a list for some reason
151
152
  integ = spi.ode(lambda t, x: list(self.rhs(t, x)), jac=self.jacobian)
152
- integ.set_integrator(name="lsoda")
153
+ integ.set_integrator(name=self.method)
153
154
  integ.set_initial_value(self.y0)
154
155
 
155
156
  t = self.t0 + step_size
@@ -795,6 +795,7 @@ def to_tex_export(model: Model) -> TexExport:
795
795
  for rxn_name, rxn in model.get_raw_reactions().items():
796
796
  for var_name, factor in rxn.stoichiometry.items():
797
797
  diff_eqs.setdefault(var_name, {})[rxn_name] = factor
798
+ # FIXME: think about surrogates here
798
799
 
799
800
  return TexExport(
800
801
  parameters=model.get_parameter_values(),
@@ -635,7 +635,7 @@ def _handle_attribute(node: ast.Attribute, ctx: Context) -> sympy.Expr | None:
635
635
  )
636
636
  elif (var := variables.get(level)) is not None:
637
637
  _LOGGER.debug("var %s", var)
638
- return _get_inner_object(var, levels[(idx + 1) :] + [node.attr])
638
+ return _get_inner_object(var, [*levels[idx + 1 :], node.attr])
639
639
 
640
640
  else:
641
641
  _LOGGER.debug("No target found")
mxlpy/model.py CHANGED
@@ -928,14 +928,14 @@ class Model:
928
928
  if stoichiometries is not None:
929
929
  for rxn_name, value in stoichiometries.items():
930
930
  target = False
931
- if rxn_name in self._reactions:
931
+ if (rxn := self._reactions.get(rxn_name)) is not None:
932
932
  target = True
933
- cast(dict, self._reactions[name].stoichiometry)[name] = value
933
+ cast(dict, rxn.stoichiometry)[name] = value
934
934
  else:
935
935
  for surrogate in self._surrogates.values():
936
- if rxn_name in surrogate.stoichiometries:
936
+ if stoich := surrogate.stoichiometries.get(rxn_name):
937
937
  target = True
938
- surrogate.stoichiometries[rxn_name][name] = value
938
+ stoich[name] = value
939
939
  if not target:
940
940
  msg = f"Reaction '{rxn_name}' not found in reactions or surrogates"
941
941
  raise KeyError(msg)
@@ -1099,7 +1099,12 @@ class Model:
1099
1099
  return self
1100
1100
 
1101
1101
  @_invalidate_cache
1102
- def remove_variable(self, name: str) -> Self:
1102
+ def remove_variable(
1103
+ self,
1104
+ name: str,
1105
+ *,
1106
+ remove_stoichiometries: bool = True,
1107
+ ) -> Self:
1103
1108
  """Remove a variable from the model.
1104
1109
 
1105
1110
  Examples:
@@ -1107,16 +1112,31 @@ class Model:
1107
1112
 
1108
1113
  Args:
1109
1114
  name: The name of the variable to remove.
1115
+ remove_stoichiometries: whether to remove the variable from all reactions
1110
1116
 
1111
1117
  Returns:
1112
1118
  Self: The instance of the model with the variable removed.
1113
1119
 
1114
1120
  """
1121
+ if remove_stoichiometries:
1122
+ for rxn in self._reactions.values():
1123
+ if name in rxn.stoichiometry:
1124
+ cast(dict, rxn.stoichiometry).pop(name)
1125
+ for surrogate in self._surrogates.values():
1126
+ for stoich in surrogate.stoichiometries.values():
1127
+ if name in stoich:
1128
+ cast(dict, stoich).pop(name)
1129
+
1115
1130
  self._remove_id(name=name)
1116
1131
  del self._variables[name]
1117
1132
  return self
1118
1133
 
1119
- def remove_variables(self, variables: Iterable[str]) -> Self:
1134
+ def remove_variables(
1135
+ self,
1136
+ variables: Iterable[str],
1137
+ *,
1138
+ remove_stoichiometries: bool = True,
1139
+ ) -> Self:
1120
1140
  """Remove multiple variables from the model.
1121
1141
 
1122
1142
  Examples:
@@ -1124,13 +1144,16 @@ class Model:
1124
1144
 
1125
1145
  Args:
1126
1146
  variables: An iterable of variable names to be removed.
1147
+ remove_stoichiometries: whether to remove the variables from all reactions
1127
1148
 
1128
1149
  Returns:
1129
1150
  Self: The instance of the model with the specified variables removed.
1130
1151
 
1131
1152
  """
1132
1153
  for variable in variables:
1133
- self.remove_variable(name=variable)
1154
+ self.remove_variable(
1155
+ name=variable, remove_stoichiometries=remove_stoichiometries
1156
+ )
1134
1157
  return self
1135
1158
 
1136
1159
  @_invalidate_cache
@@ -1219,24 +1242,17 @@ class Model:
1219
1242
  value_or_derived = (
1220
1243
  self._variables[name].initial_value if value is None else value
1221
1244
  )
1222
- self.remove_variable(name)
1223
-
1224
- # FIXME: better handling of unit
1225
- if isinstance(value_or_derived, Derived):
1226
- self.add_derived(name, value_or_derived.fn, args=value_or_derived.args)
1245
+ self.remove_variable(name, remove_stoichiometries=True)
1246
+
1247
+ if isinstance(der := value_or_derived, Derived):
1248
+ self.add_derived(
1249
+ name,
1250
+ der.fn,
1251
+ args=der.args,
1252
+ unit=der.unit,
1253
+ )
1227
1254
  else:
1228
1255
  self.add_parameter(name, value_or_derived)
1229
-
1230
- # Remove from stoichiometries
1231
- for reaction in self._reactions.values():
1232
- if name in reaction.stoichiometry:
1233
- cast(dict, reaction.stoichiometry).pop(name)
1234
- for surrogate in self._surrogates.values():
1235
- surrogate.stoichiometries = {
1236
- k: {k2: v2 for k2, v2 in v.items() if k2 != name}
1237
- for k, v in surrogate.stoichiometries.items()
1238
- if k != name
1239
- }
1240
1256
  return self
1241
1257
 
1242
1258
  ##########################################################################
@@ -1680,7 +1696,8 @@ class Model:
1680
1696
 
1681
1697
  ##########################################################################
1682
1698
  # Readouts
1683
- # They are like derived variables, but only calculated on demand
1699
+ # They are like derived variables, but only calculated on demand, e.g. after
1700
+ # a simulation
1684
1701
  # Think of something like NADPH / (NADP + NADPH) as a proxy for energy state
1685
1702
  ##########################################################################
1686
1703
 
mxlpy/nn/__init__.py CHANGED
@@ -8,11 +8,16 @@ if TYPE_CHECKING:
8
8
  import contextlib
9
9
 
10
10
  with contextlib.suppress(ImportError):
11
+ from . import _equinox as equinox
11
12
  from . import _keras as keras
12
13
  from . import _torch as torch
13
14
  else:
14
15
  from lazy_import import lazy_module
15
16
 
17
+ equinox = lazy_module(
18
+ "mxlpy.nn._equinox",
19
+ error_strings={"module": "equinox", "install_name": "mxlpy[equinox]"},
20
+ )
16
21
  keras = lazy_module(
17
22
  "mxlpy.nn._keras",
18
23
  error_strings={"module": "keras", "install_name": "mxlpy[tf]"},
mxlpy/nn/_equinox.py ADDED
@@ -0,0 +1,293 @@
1
+ """Neural network architectures.
2
+
3
+ This module provides implementations of neural network architectures used for mechanistic learning.
4
+
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ from typing import TYPE_CHECKING
10
+
11
+ import equinox as eqx
12
+ import jax
13
+ import jax.numpy as jnp
14
+ import numpy as np
15
+ import pandas as pd
16
+ import torch
17
+ import tqdm
18
+ from jaxtyping import Array, PyTree
19
+ from torch.utils.data import DataLoader, TensorDataset
20
+
21
+ if TYPE_CHECKING:
22
+ from collections.abc import Callable
23
+
24
+ import optax
25
+
26
+
27
+ __all__ = [
28
+ "LSTM",
29
+ "LossFn",
30
+ "MLP",
31
+ "cosine_similarity",
32
+ "mean_abs_error",
33
+ "mean_absolute_percentage",
34
+ "mean_error",
35
+ "mean_squared_error",
36
+ "mean_squared_logarithmic",
37
+ "rms_error",
38
+ "train",
39
+ ]
40
+
41
+
42
+ ###############################################################################
43
+ # Loss functions
44
+ ###############################################################################
45
+
46
+ type LossFn = Callable[[eqx.Module, Array, Array], Array]
47
+
48
+
49
@eqx.filter_jit
def mean_error(model: eqx.Module, inp: Array, true: Array) -> Array:
    """Return the mean signed error between model predictions and targets.

    The model is evaluated sample by sample over the leading axis of ``inp``
    (via ``jax.vmap``); the result is ``mean(pred - true)`` over all entries.
    Positive and negative errors cancel, so this measures bias, not fit.
    """
    predict_batch = jax.vmap(model)  # type: ignore
    residual = predict_batch(inp) - true
    return jnp.mean(residual)
54
+
55
+
56
@eqx.filter_jit
def mean_squared_error(model: eqx.Module, inp: Array, true: Array) -> Array:
    """Return the mean squared error ``mean((pred - true) ** 2)``.

    ``pred`` is the model applied to each sample of ``inp`` via ``jax.vmap``.
    """
    predict_batch = jax.vmap(model)  # type: ignore
    residual = predict_batch(inp) - true
    squared_residual = jnp.square(residual)
    return jnp.mean(squared_residual)
61
+
62
+
63
@eqx.filter_jit
def rms_error(model: eqx.Module, inp: Array, true: Array) -> Array:
    """Return the root mean squared error ``sqrt(mean((pred - true) ** 2))``.

    ``pred`` is the model applied to each sample of ``inp`` via ``jax.vmap``.
    """
    predict_batch = jax.vmap(model)  # type: ignore
    residual = predict_batch(inp) - true
    mse = jnp.mean(jnp.square(residual))
    return jnp.sqrt(mse)
68
+
69
+
70
@eqx.filter_jit
def mean_abs_error(model: eqx.Module, inp: Array, true: Array) -> Array:
    """Return the mean absolute error ``mean(|pred - true|)``.

    ``pred`` is the model applied to each sample of ``inp`` via ``jax.vmap``.
    """
    predict_batch = jax.vmap(model)  # type: ignore
    residual = predict_batch(inp) - true
    abs_residual = jnp.abs(residual)
    return jnp.mean(abs_residual)
75
+
76
+
77
@eqx.filter_jit
def mean_absolute_percentage(model: eqx.Module, inp: Array, true: Array) -> Array:
    """Calculate the mean absolute percentage error (MAPE).

    Computes ``100 * mean(|(true - pred) / true|)``, where ``pred`` is the
    model applied to each sample of ``inp`` via ``jax.vmap``. The error is
    relative to the *targets* (the standard MAPE definition); entries where
    ``true`` is zero yield ``inf``/``nan``.
    """
    pred = jax.vmap(model)(inp)  # type: ignore
    return 100 * jnp.mean(jnp.abs((true - pred) / true))
82
+
83
+
84
@eqx.filter_jit
def mean_squared_logarithmic(model: eqx.Module, inp: Array, true: Array) -> Array:
    """Calculate the mean squared logarithmic error (MSLE).

    Computes ``mean((log(1 + pred) - log(1 + true)) ** 2)``, where ``pred``
    is the model applied to each sample of ``inp`` via ``jax.vmap``.
    ``log1p`` is used because it is more accurate than ``log(x + 1)`` for
    small values. Values below -1 produce ``nan``.
    """
    pred = jax.vmap(model)(inp)  # type: ignore
    return jnp.mean(jnp.square(jnp.log1p(pred) - jnp.log1p(true)))
89
+
90
+
91
@eqx.filter_jit
def cosine_similarity(model: eqx.Module, inp: Array, true: Array) -> Array:
    """Calculate the negative mean cosine similarity between predictions and targets.

    The cosine similarity is computed per sample along the last axis and
    averaged. It is negated so that minimising the loss maximises the
    similarity (same convention as the Keras ``CosineSimilarity`` loss). The
    loss therefore ranges from -1 (same direction) to 1 (opposite direction).

    ``pred`` is the model applied to each sample of ``inp`` via ``jax.vmap``.
    """
    pred = jax.vmap(model)(inp)  # type: ignore

    def _l2_normalize(x: Array) -> Array:
        # Clamp the squared norm so all-zero vectors do not divide by zero.
        sq_norm = jnp.sum(jnp.square(x), axis=-1, keepdims=True)
        return x / jnp.sqrt(jnp.maximum(sq_norm, 1e-12))

    per_sample = jnp.sum(_l2_normalize(pred) * _l2_normalize(true), axis=-1)
    return -jnp.mean(per_sample)
96
+
97
+
98
+ ###############################################################################
99
+ # Training routines
100
+ ###############################################################################
101
+
102
+
103
+ def train(
104
+ model: eqx.Module,
105
+ features: Array,
106
+ targets: Array,
107
+ epochs: int,
108
+ optimizer: optax.GradientTransformation,
109
+ batch_size: int | None,
110
+ loss_fn: LossFn,
111
+ ) -> pd.Series:
112
+ """Train the neural network using mini-batch gradient descent.
113
+
114
+ Args:
115
+ model: Neural network model to train.
116
+ features: Input features as a tensor.
117
+ targets: Target values as a tensor.
118
+ epochs: Number of training epochs.
119
+ optimizer: Optimizer for training.
120
+ device: torch device
121
+ batch_size: Size of mini-batches for training.
122
+ loss_fn: Loss function
123
+
124
+ Returns:
125
+ pd.Series: Series containing the training loss history.
126
+
127
+ """
128
+ losses = {}
129
+
130
+ data = TensorDataset(
131
+ torch.tensor(features.astype(np.float32), dtype=torch.float32),
132
+ torch.tensor(targets.astype(np.float32), dtype=torch.float32),
133
+ )
134
+ data_loader = DataLoader(
135
+ data,
136
+ batch_size=len(features) if batch_size is None else batch_size,
137
+ shuffle=True,
138
+ )
139
+
140
+ opt_state = optimizer.init(eqx.filter(model, eqx.is_array))
141
+
142
+ @eqx.filter_jit
143
+ def make_step(
144
+ model: eqx.Module,
145
+ opt_state: PyTree,
146
+ x: Array,
147
+ y: Array,
148
+ ) -> tuple[eqx.Module, Array, Array]:
149
+ loss_value, grads = eqx.filter_value_and_grad(loss_fn)(model, x, y)
150
+ updates, opt_state = optimizer.update(
151
+ grads, opt_state, eqx.filter(model, eqx.is_array)
152
+ )
153
+ model = eqx.apply_updates(model, updates)
154
+ return model, opt_state, loss_value
155
+
156
+ for i in tqdm.trange(epochs):
157
+ epoch_loss = 0
158
+ for xb, yb in data_loader:
159
+ model, opt_state, train_loss = make_step(
160
+ model,
161
+ opt_state,
162
+ xb.numpy(),
163
+ yb.numpy(),
164
+ )
165
+ epoch_loss += train_loss * xb.size(0)
166
+ losses[i] = epoch_loss / len(data_loader.dataset) # type: ignore
167
+ return pd.Series(losses, dtype=float)
168
+
169
+
170
+ ###############################################################################
171
+ # Actual models
172
+ ###############################################################################
173
+
174
+
175
+ class MLP(eqx.Module):
176
+ """Multilayer Perceptron (MLP) for surrogate modeling and neural posterior estimation.
177
+
178
+ Attributes:
179
+ net: Sequential neural network model.
180
+
181
+ Methods:
182
+ forward: Forward pass through the neural network.
183
+
184
+ """
185
+
186
+ layers: list
187
+
188
+ def __init__(
189
+ self,
190
+ n_inputs: int,
191
+ neurons_per_layer: list[int],
192
+ key: Array,
193
+ ) -> None:
194
+ """Initializes the MLP with the given number of inputs and list of (hidden) layers.
195
+
196
+ Args:
197
+ n_inputs: The number of input features.
198
+ neurons_per_layer: Number of neurons per layer
199
+ n_outputs: A list containing the number of neurons in hidden and output layer.
200
+ key: jax.random.PRNGKey(SEED) for initial parameters
201
+
202
+ For instance, MLP(10, layers = [50, 50, 10]) initializes a neural network with the following architecture:
203
+ - Linear layer with `n_inputs` inputs and 50 outputs
204
+ - ReLU activation
205
+ - Linear layer with 50 inputs and 50 outputs
206
+ - ReLU activation
207
+ - Linear layer with 50 inputs and 10 outputs
208
+
209
+ The weights of the linear layers are initialized with a normal distribution
210
+ (mean=0, std=0.1) and the biases are initialized to 0.
211
+
212
+ """
213
+ keys = iter(jax.random.split(key, len(neurons_per_layer)))
214
+ previous_neurons = n_inputs
215
+ layers = []
216
+ for neurons in neurons_per_layer:
217
+ layers.append(eqx.nn.Linear(previous_neurons, neurons, key=next(keys)))
218
+ previous_neurons = neurons
219
+ self.layers = layers
220
+
221
+ def __call__(self, x: Array) -> Array:
222
+ """Forward pass through the neural network.
223
+
224
+ Args:
225
+ x: Input tensor.
226
+
227
+ Returns:
228
+ Output tensor.
229
+
230
+ """
231
+ for layer in self.layers[:-1]:
232
+ x = jax.nn.relu(layer(x))
233
+ return self.layers[-1](x)
234
+
235
+
236
+ class LSTM(eqx.Module):
237
+ """Default LSTM neural network model for time-series approximation."""
238
+
239
+ lstm_cell: eqx.nn.LSTMCell
240
+ n_hidden: int
241
+ linear: eqx.nn.Linear
242
+
243
+ def __init__(
244
+ self,
245
+ n_inputs: int,
246
+ n_outputs: int,
247
+ n_hidden: int,
248
+ key: Array,
249
+ ) -> None:
250
+ """Initializes the LSTM neural network model.
251
+
252
+ Args:
253
+ n_inputs (int): Number of input features.
254
+ n_outputs (int): Number of output features.
255
+ n_hidden (int): Number of hidden units in the LSTM layer.
256
+ key (Array): JAX random key for initialization.
257
+
258
+ """
259
+ k1, k2 = jax.random.split(key, 2)
260
+ self.lstm_cell = eqx.nn.LSTMCell(n_inputs, n_hidden, key=k1)
261
+ self.n_hidden = n_hidden
262
+ self.linear = eqx.nn.Linear(n_hidden, n_outputs, key=k2)
263
+
264
+ def __call__(
265
+ self,
266
+ x: Array,
267
+ *,
268
+ h: Array | None = None,
269
+ c: Array | None = None,
270
+ ) -> Array:
271
+ """Forward pass through the LSTM network.
272
+
273
+ Args:
274
+ x: Input tensor of shape (seq_len, batch_size, n_inputs).
275
+ h: Optional initial hidden state (batch_size, n_hidden).
276
+ c: Optional initial cell state (batch_size, n_hidden).
277
+
278
+ Returns:
279
+ Output tensor of shape (seq_len, batch_size, n_outputs).
280
+
281
+ """
282
+ seq_len, batch_size, _ = x.shape
283
+ if h is None:
284
+ h = jnp.zeros((batch_size, self.n_hidden))
285
+ if c is None:
286
+ c = jnp.zeros((batch_size, self.n_hidden))
287
+
288
+ outputs = []
289
+ for t in range(seq_len):
290
+ h, c = self.lstm_cell(x[t], (h, c))
291
+ outputs.append(h)
292
+ outputs = jnp.stack(outputs, axis=0)
293
+ return jax.vmap(self.linear)(outputs)