jinns 0.8.10__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. jinns/__init__.py +2 -0
  2. jinns/data/_Batchs.py +27 -0
  3. jinns/data/_DataGenerators.py +953 -1182
  4. jinns/data/__init__.py +4 -8
  5. jinns/experimental/__init__.py +0 -2
  6. jinns/experimental/_diffrax_solver.py +5 -5
  7. jinns/loss/_DynamicLoss.py +282 -305
  8. jinns/loss/_DynamicLossAbstract.py +321 -168
  9. jinns/loss/_LossODE.py +290 -307
  10. jinns/loss/_LossPDE.py +628 -1040
  11. jinns/loss/__init__.py +21 -5
  12. jinns/loss/_boundary_conditions.py +95 -96
  13. jinns/loss/{_Losses.py → _loss_utils.py} +104 -46
  14. jinns/loss/_loss_weights.py +59 -0
  15. jinns/loss/_operators.py +78 -72
  16. jinns/parameters/__init__.py +6 -0
  17. jinns/parameters/_derivative_keys.py +94 -0
  18. jinns/parameters/_params.py +115 -0
  19. jinns/plot/__init__.py +5 -0
  20. jinns/{data/_display.py → plot/_plot.py} +98 -75
  21. jinns/solver/_rar.py +193 -45
  22. jinns/solver/_solve.py +199 -144
  23. jinns/utils/__init__.py +3 -9
  24. jinns/utils/_containers.py +37 -43
  25. jinns/utils/_hyperpinn.py +226 -127
  26. jinns/utils/_pinn.py +183 -111
  27. jinns/utils/_save_load.py +121 -56
  28. jinns/utils/_spinn.py +117 -84
  29. jinns/utils/_types.py +64 -0
  30. jinns/utils/_utils.py +6 -160
  31. jinns/validation/_validation.py +52 -144
  32. {jinns-0.8.10.dist-info → jinns-1.0.0.dist-info}/METADATA +5 -4
  33. jinns-1.0.0.dist-info/RECORD +38 -0
  34. {jinns-0.8.10.dist-info → jinns-1.0.0.dist-info}/WHEEL +1 -1
  35. jinns/experimental/_sinuspinn.py +0 -135
  36. jinns/experimental/_spectralpinn.py +0 -87
  37. jinns/solver/_seq2seq.py +0 -157
  38. jinns/utils/_optim.py +0 -147
  39. jinns/utils/_utils_uspinn.py +0 -727
  40. jinns-0.8.10.dist-info/RECORD +0 -36
  41. {jinns-0.8.10.dist-info → jinns-1.0.0.dist-info}/LICENSE +0 -0
  42. {jinns-0.8.10.dist-info → jinns-1.0.0.dist-info}/top_level.txt +0 -0
jinns/utils/_spinn.py CHANGED
@@ -3,54 +3,53 @@ Implements utility function to create Separable PINNs
 https://arxiv.org/abs/2211.08761
 """
 
+from dataclasses import InitVar
+from typing import Callable, Literal
 import jax
 import jax.numpy as jnp
 import equinox as eqx
+from jaxtyping import Key, Array, Float, PyTree
 
 
 class _SPINN(eqx.Module):
     """
     Construct a Separable PINN as proposed in
     Cho et al., _Separable Physics-Informed Neural Networks_, NeurIPS, 2023
+
+    Parameters
+    ----------
+    key : InitVar[Key]
+        A jax random key for the layer initializations.
+    d : int
+        The number of dimensions to treat separately.
+    eqx_list : InitVar[tuple[tuple[Callable, int, int] | Callable, ...]]
+        A tuple of tuples of successive equinox modules and activation functions to
+        describe the PINN architecture. The inner tuples must have the eqx module or
+        activation function as first item, other items represents arguments
+        that could be required (eg. the size of the layer).
+        The `key` argument need not be given.
+        Thus typical example is `eqx_list=
+        ((eqx.nn.Linear, 2, 20),
+        jax.nn.tanh,
+        (eqx.nn.Linear, 20, 20),
+        jax.nn.tanh,
+        (eqx.nn.Linear, 20, 20),
+        jax.nn.tanh,
+        (eqx.nn.Linear, 20, 1)
+        )`.
     """
 
-    layers: list
-    separated_mlp: list
-    d: int
-    r: int
-    m: int
+    d: int = eqx.field(static=True, kw_only=True)
 
-    def __init__(self, key, d, r, eqx_list, m=1):
-        """
-        Parameters
-        ----------
-        key
-            A jax random key
-        d
-            An integer. The number of dimensions to treat separately
-        r
-            An integer. The dimension of the embedding
-        eqx_list
-            A list of list of successive equinox modules and activation functions to
-            describe *each separable PINN architecture*.
-            The inner lists have the eqx module or
-            axtivation function as first item, other items represents arguments
-            that could be required (eg. the size of the layer).
-            __Note:__ the `key` argument need not be given.
-            Thus typical example is `eqx_list=
-            [[eqx.nn.Linear, 1, 20],
-            [jax.nn.tanh],
-            [eqx.nn.Linear, 20, 20],
-            [jax.nn.tanh],
-            [eqx.nn.Linear, 20, 20],
-            [jax.nn.tanh],
-            [eqx.nn.Linear, 20, r]
-            ]`
-        """
-        self.d = d
-        self.r = r
-        self.m = m
+    key: InitVar[Key] = eqx.field(kw_only=True)
+    eqx_list: InitVar[tuple[tuple[Callable, int, int] | Callable, ...]] = eqx.field(
+        kw_only=True
+    )
+
+    layers: list = eqx.field(init=False)
+    separated_mlp: list = eqx.field(init=False)
 
+    def __post_init__(self, key, eqx_list):
         self.separated_mlp = []
         for _ in range(self.d):
             self.layers = []
@@ -62,7 +61,9 @@ class _SPINN(eqx.Module):
                 self.layers.append(l[0](*l[1:], key=subkey))
             self.separated_mlp.append(self.layers)
 
-    def __call__(self, t, x):
+    def __call__(
+        self, t: Float[Array, "1"], x: Float[Array, "omega_dim"]
+    ) -> Float[Array, "d embed_dim*output_dim"]:
         if t is not None:
             dimensions = jnp.concatenate([t, x.flatten()], axis=0)
         else:
@@ -78,50 +79,67 @@ class _SPINN(eqx.Module):
 
 class SPINN(eqx.Module):
     """
-    Basically a wrapper around the `__call__` function to be able to give a type to
-    our former `self.u`
-    The function create_SPINN has the role to population the `__call__` function
+    A SPINN object compatible with the rest of jinns.
+    This is typically created with `create_SPINN`.
+
+    **NOTE**: SPINNs with `t` and `x` as inputs are best used with a
+    DataGenerator with `self.cartesian_product=False` for memory consideration
+
+    Parameters
+    ----------
+    d : int
+        The number of dimensions to treat separately.
+
     """
 
-    d: int
-    r: int
-    eq_type: str = eqx.field(static=True)
-    m: int
-    params: eqx.Module
-    static: eqx.Module
+    d: int = eqx.field(static=True, kw_only=True)
+    r: int = eqx.field(static=True, kw_only=True)
+    eq_type: str = eqx.field(static=True, kw_only=True)
+    m: int = eqx.field(static=True, kw_only=True)
+
+    spinn_mlp: InitVar[eqx.Module] = eqx.field(kw_only=True)
 
-    def __init__(self, spinn_mlp, d, r, eq_type, m):
-        self.d, self.r, self.m = d, r, m
+    params: PyTree = eqx.field(init=False)
+    static: PyTree = eqx.field(init=False, static=True)
+
+    def __post_init__(self, spinn_mlp):
         self.params, self.static = eqx.partition(spinn_mlp, eqx.is_inexact_array)
-        self.eq_type = eq_type
 
-    def init_params(self):
+    def init_params(self) -> PyTree:
+        """
+        Returns an initial set of parameters
+        """
        return self.params
 
-    def __call__(self, *args):
+    def __call__(self, *args) -> Float[Array, "output_dim"]:
+        """
+        Calls `eval_nn` with rearranged arguments
+        """
         if self.eq_type == "statio_PDE":
             (x, params) = args
             try:
-                spinn = eqx.combine(params["nn_params"], self.static)
-            except (KeyError, TypeError) as e:
+                spinn = eqx.combine(params.nn_params, self.static)
+            except (KeyError, AttributeError, TypeError) as e:
                 spinn = eqx.combine(params, self.static)
             v_model = jax.vmap(spinn, (0))
             res = v_model(t=None, x=x)
-            return self._eval_nn(res)
+            return self.eval_nn(res)
 
         if self.eq_type == "nonstatio_PDE":
             (t, x, params) = args
             try:
-                spinn = eqx.combine(params["nn_params"], self.static)
-            except (KeyError, TypeError) as e:
+                spinn = eqx.combine(params.nn_params, self.static)
+            except (KeyError, AttributeError, TypeError) as e:
                 spinn = eqx.combine(params, self.static)
             v_model = jax.vmap(spinn, ((0, 0)))
             res = v_model(t, x)
-            return self._eval_nn(res)
+            return self.eval_nn(res)
         raise RuntimeError("Wrong parameter value for eq_type")
 
-    def _eval_nn(self, res):
+    def eval_nn(
+        self, res: Float[Array, "d embed_dim*output_dim"]
+    ) -> Float[Array, "output_dim"]:
         """
-        common content of apply_fn put here in order to factorize code
+        Evaluate the SPINN on some inputs with some params.
         """
         a = ", ".join([f"{chr(97 + d)}z" for d in range(res.shape[1])])
         b = "".join([f"{chr(97 + d)}" for d in range(res.shape[1])])
@@ -145,56 +163,71 @@ class SPINN(eqx.Module):
         return res
 
 
-def create_SPINN(key, d, r, eqx_list, eq_type, m=1):
+def create_SPINN(
+    key: Key,
+    d: int,
+    r: int,
+    eqx_list: tuple[tuple[Callable, int, int] | Callable, ...],
+    eq_type: Literal["ODE", "statio_PDE", "nonstatio_PDE"],
+    m: int = 1,
+) -> SPINN:
     """
     Utility function to create a SPINN neural network with the equinox
     library.
 
-    *Note* that a SPINN is not vmapped from the outside and expects batch of the
-    same size for each input. It outputs directly a solution of shape
-    (batchsize, batchsize). See the paper for more details.
+    *Note* that a SPINN is not vmapped and expects the
+    same batch size for each of its input axis. It directly outputs a solution
+    of shape `(batchsize, batchsize)`. See the paper for more details.
 
     Parameters
     ----------
     key
-        A jax random key that will be used to initialize the network parameters
+        A JAX random key that will be used to initialize the network parameters
     d
-        An integer. The number of dimensions to treat separately
+        The number of dimensions to treat separately.
     r
-        An integer. The dimension of the embedding
+        An integer. The dimension of the embedding.
     eqx_list
-        A list of list of successive equinox modules and activation functions to
-        describe *each separable PINN architecture*.
-        The inner lists have the eqx module or
-        axtivation function as first item, other items represents arguments
+        A tuple of tuples of successive equinox modules and activation functions to
+        describe the PINN architecture. The inner tuples must have the eqx module or
+        activation function as first item, other items represents arguments
        that could be required (eg. the size of the layer).
-        __Note:__ the `key` argument need not be given.
-        Thus typical example is `eqx_list=
-        [[eqx.nn.Linear, 1, 20],
-        [jax.nn.tanh],
-        [eqx.nn.Linear, 20, 20],
-        [jax.nn.tanh],
-        [eqx.nn.Linear, 20, 20],
-        [jax.nn.tanh],
-        [eqx.nn.Linear, 20, r]
-        ]`
-    eq_type
+        The `key` argument need not be given.
+        Thus typical example is
+        `eqx_list=((eqx.nn.Linear, 2, 20),
+        jax.nn.tanh,
+        (eqx.nn.Linear, 20, 20),
+        jax.nn.tanh,
+        (eqx.nn.Linear, 20, 20),
+        jax.nn.tanh,
+        (eqx.nn.Linear, 20, 1)
+        )`.
+    eq_type : Literal["ODE", "statio_PDE", "nonstatio_PDE"]
        A string with three possibilities.
        "ODE": the PINN is called with one input `t`.
        "statio_PDE": the PINN is called with one input `x`, `x`
        can be high dimensional.
        "nonstatio_PDE": the PINN is called with two inputs `t` and `x`, `x`
        can be high dimensional.
+        **Note**: the input dimension as given in eqx_list has to match the sum
+        of the dimension of `t` + the dimension of `x` or the output dimension
+        after the `input_transform` function.
     m
-        An integer. The output dimension of the neural network. According to
+        The output dimension of the neural network. According to
        the SPINN article, a total embedding dimension of `r*m` is defined. We
        then sum groups of `r` embedding dimensions to compute each output.
        Default is 1.
 
+    !!! note
+        SPINNs with `t` and `x` as inputs are best used with a
+        DataGenerator with `self.cartesian_product=False` for memory
+        consideration
+
 
     Returns
     -------
-    `u`, a :class:`.SPINN` object which inherits from `eqx.Module` (hence callable).
+    spinn
+        An instanciated SPINN
 
     Raises
     ------
@@ -229,7 +262,7 @@ def create_SPINN(key, d, r, eqx_list, eq_type, m=1):
            "Too many dimensions, not enough letters available in jnp.einsum"
        )
 
-    spinn_mlp = _SPINN(key, d, r, eqx_list, m)
-    spinn = SPINN(spinn_mlp, d, r, eq_type, m)
+    spinn_mlp = _SPINN(key=key, d=d, eqx_list=eqx_list)
+    spinn = SPINN(spinn_mlp=spinn_mlp, d=d, r=r, eq_type=eq_type, m=m)
 
     return spinn
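
For orientation, a minimal usage sketch of the reworked `create_SPINN` call above (positional signature `key, d, r, eqx_list, eq_type, m=1`, tuple-based `eqx_list`). The layer widths and the per-MLP in/out sizes follow the pre-1.0 docstring example (one coordinate in, `r` out) and are placeholders; check the jinns 1.0 documentation for the exact sizes expected in your setting.

```python
import jax
import equinox as eqx
from jinns.utils._spinn import create_SPINN

key = jax.random.PRNGKey(0)
d, r = 2, 32  # treat t and a 1D x separately; r is the embedding dimension

# eqx_list now uses tuples instead of lists (see the diff above);
# layer widths below are illustrative placeholders
eqx_list = (
    (eqx.nn.Linear, 1, 20),
    jax.nn.tanh,
    (eqx.nn.Linear, 20, 20),
    jax.nn.tanh,
    (eqx.nn.Linear, 20, r),
)

u = create_SPINN(key, d, r, eqx_list, "nonstatio_PDE")
params = u.init_params()
# For eq_type="nonstatio_PDE" the SPINN is then called as u(t, x, params),
# with t of shape (batchsize, 1) and x of shape (batchsize, omega_dim),
# as shown in SPINN.__call__ above.
```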
jinns/utils/_types.py ADDED
@@ -0,0 +1,64 @@
+from __future__ import (
+    annotations,
+)  # https://docs.python.org/3/library/typing.html#constant
+
+from typing import TypeAlias, TYPE_CHECKING, NewType
+from jaxtyping import Int
+
+if TYPE_CHECKING:
+    from jinns.loss._LossPDE import (
+        LossPDEStatio,
+        LossPDENonStatio,
+        SystemLossPDE,
+    )
+
+    from jinns.loss._LossODE import LossODE, SystemLossODE
+    from jinns.parameters._params import Params, ParamsDict
+    from jinns.data._DataGenerators import (
+        DataGeneratorODE,
+        CubicMeshPDEStatio,
+        CubicMeshPDENonStatio,
+        DataGeneratorObservations,
+        DataGeneratorParameter,
+        DataGeneratorObservationsMultiPINNs,
+    )
+
+    from jinns.loss import DynamicLoss
+    from jinns.data._Batchs import *
+    from jinns.utils._pinn import PINN
+    from jinns.utils._hyperpinn import HYPERPINN
+    from jinns.utils._spinn import SPINN
+    from jinns.utils._containers import *
+    from jinns.validation._validation import AbstractValidationModule
+
+AnyLoss: TypeAlias = (
+    LossPDEStatio | LossPDENonStatio | SystemLossPDE | LossODE | SystemLossODE
+)
+
+AnyParams: TypeAlias = Params | ParamsDict
+
+AnyDataGenerator: TypeAlias = (
+    DataGeneratorODE | CubicMeshPDEStatio | CubicMeshPDENonStatio
+)
+
+AnyPINN: TypeAlias = PINN | HYPERPINN | SPINN
+
+AnyBatch: TypeAlias = ODEBatch | PDEStatioBatch | PDENonStatioBatch
+rar_operands = NewType(
+    "rar_operands", tuple[AnyLoss, AnyParams, AnyDataGenerator, Int]
+)
+
+main_carry = NewType(
+    "main_carry",
+    tuple[
+        Int,
+        AnyLoss,
+        OptimizationContainer,
+        OptimizationExtraContainer,
+        DataGeneratorContainer,
+        AbstractValidationModule,
+        LossContainer,
+        StoredObjectContainer,
+        Float[Array, "n_iter"],
+    ],
+)
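
Since these aliases reference names that are only imported under `TYPE_CHECKING`, they are intended for static type checkers rather than runtime use. A hypothetical downstream annotation might look like the sketch below; the `evaluate` helper is illustrative only and not part of jinns.

```python
from __future__ import annotations  # keep annotations as strings at runtime

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Import the aliases only for the type checker, mirroring how
    # jinns/utils/_types.py itself guards its own imports.
    from jinns.utils._types import AnyPINN, AnyParams


def evaluate(u: AnyPINN, params: AnyParams, *inputs):
    """Hypothetical helper annotated with the new aliases."""
    return u(*inputs, params)
```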
jinns/utils/_utils.py CHANGED
@@ -7,9 +7,10 @@ from operator import getitem
 import numpy as np
 import jax
 import jax.numpy as jnp
+from jaxtyping import PyTree, Array
 
 
-def _check_nan_in_pytree(pytree):
+def _check_nan_in_pytree(pytree: PyTree) -> bool:
     """
     Check if there is a NaN value anywhere is the pytree
 
@@ -25,40 +26,14 @@ def _check_nan_in_pytree(pytree):
     """
     return jnp.any(
         jnp.array(
-            list(
-                jax.tree_util.tree_leaves(
-                    jax.tree_util.tree_map(lambda x: jnp.any(jnp.isnan(x)), pytree)
-                )
+            jax.tree_util.tree_leaves(
+                jax.tree_util.tree_map(lambda x: jnp.any(jnp.isnan(x)), pytree)
             )
         )
     )
 
 
-def _tracked_parameters(params, tracked_params_key_list):
-    """
-    Returns a pytree with the same structure as params with True is the
-    parameter is tracked False otherwise
-    """
-
-    def set_nested_item(dataDict, mapList, val):
-        """
-        Set item in nested dictionary
-        https://stackoverflow.com/questions/54137991/how-to-update-values-in-nested-dictionary-if-keys-are-in-a-list
-        """
-        reduce(getitem, mapList[:-1], dataDict)[mapList[-1]] = val
-        return dataDict
-
-    tracked_params = jax.tree_util.tree_map(
-        lambda x: False, params
-    )  # init with all False
-
-    for key_list in tracked_params_key_list:
-        tracked_params = set_nested_item(tracked_params, key_list, True)
-
-    return tracked_params
-
-
-def _get_grid(in_array):
+def _get_grid(in_array: Array) -> Array:
     """
     From an array of shape (B, D), D > 1, get the grid array, i.e., an array of
     shape (B, B, ...(D times)..., B, D): along the last axis we have the array
@@ -74,31 +49,7 @@ def _get_grid(in_array):
     return in_array
 
 
-def _get_vmap_in_axes_params(eq_params_batch_dict, params):
-    """
-    Return the input vmap axes when there is batch(es) of parameters to vmap
-    over. The latter are designated by keys in eq_params_batch_dict
-    If eq_params_batch_dict (ie no additional parameter batch), we return None
-    """
-    if eq_params_batch_dict is None:
-        return (None,)
-    # We use pytree indexing of vmapped axes and vmap on axis
-    # 0 of the eq_parameters for which we have a batch
-    # this is for a fine-grained vmaping
-    # scheme over the params
-    vmap_in_axes_params = (
-        {
-            "nn_params": None,
-            "eq_params": {
-                k: (0 if k in eq_params_batch_dict.keys() else None)
-                for k in params["eq_params"].keys()
-            },
-        },
-    )
-    return vmap_in_axes_params
-
-
-def _check_user_func_return(r, shape):
+def _check_user_func_return(r: Array | int, shape: tuple) -> Array | int:
     """
     Correctly handles the result from a user defined function (eg a boundary
     condition) to get the correct broadcast
@@ -115,108 +66,3 @@ def _check_user_func_return(r, shape):
     # the reshape below avoids a missing (1,) ending dimension
     # depending on how the user has coded the inital function
     return r.reshape(shape)
-
-
-def _set_derivatives(params, loss_term, derivative_keys):
-    """
-    Given derivative_keys, the parameters wrt which we want to compute
-    gradients in the loss, we set stop_gradient operators to not take the
-    derivatives with respect to the others. Note that we only operator at
-    top level
-    """
-    try:
-        params = {
-            k: (
-                value
-                if k in derivative_keys[loss_term]
-                else jax.lax.stop_gradient(value)
-            )
-            for k, value in params.items()
-        }
-    except KeyError:  # if the loss_term key has not been specified we
-        # only take gradients wrt "nn_params", all the other entries have
-        # stopped gradient
-        params = {
-            k: value if k in ["nn_params"] else jax.lax.stop_gradient(value)
-            for k, value in params.items()
-        }
-
-    return params
-
-
-def _extract_nn_params(params_dict, nn_key):
-    """
-    Given a params_dict for system loss (ie "nn_params" and "eq_params" as main
-    keys which contain dicts for each PINN (the nn_keys)) we extract the
-    corresponding "nn_params" for `nn_key` and reform a dict with "nn_params"
-    as main key as expected by the PINN/SPINN apply_fn
-    """
-    try:
-        return {
-            "nn_params": params_dict["nn_params"][nn_key],
-            "eq_params": params_dict["eq_params"][nn_key],
-        }
-    except (KeyError, IndexError) as e:
-        return {
-            "nn_params": params_dict["nn_params"][nn_key],
-            "eq_params": params_dict["eq_params"],
-        }
-
-
-def euler_maruyama_density(t, x, s, y, params, Tmax=1):
-    eps = 1e-6
-    delta = jnp.abs(t - s) * Tmax
-    mu = params["alpha_sde"] * (params["mu_sde"] - y) * delta
-    var = params["sigma_sde"] ** 2 * delta
-    return (
-        1 / jnp.sqrt(2 * jnp.pi * var) * jnp.exp(-0.5 * ((x - y) - mu) ** 2 / var) + eps
-    )
-
-
-def log_euler_maruyama_density(t, x, s, y, params):
-    eps = 1e-6
-    delta = jnp.abs(t - s)
-    mu = params["alpha_sde"] * (params["mu_sde"] - y) * delta
-    logvar = params["logvar_sde"]
-    return (
-        -0.5
-        * (jnp.log(2 * jnp.pi * delta) + logvar + ((x - y) - mu) ** 2 / jnp.exp(logvar))
-        + eps
-    )
-
-
-def euler_maruyama(x0, alpha, mu, sigma, T, N):
-    """
-    Simulate 1D diffusion process with simple parametrization using the Euler
-    Maruyama method in the interval [0, T]
-    """
-    path = [np.array([x0])]
-
-    time_steps, step_size = np.linspace(0, T, N, retstep=True)
-    for _ in time_steps[1:]:
-        path.append(
-            path[-1]
-            + step_size * (alpha * (mu - path[-1]))
-            + sigma * np.random.normal(loc=0.0, scale=np.sqrt(step_size))
-        )
-
-    return time_steps, np.stack(path)
-
-
-def _update_eq_params_dict(params, param_batch_dict):
-    # update params["eq_params"] with a batch of eq_params
-    # we avoid side_effect by recreating the dict `params`
-    # TODO transform `params` in a NamedTuple to be able to use _replace
-    # see Issue #1
-    param_batch_dict_ = param_batch_dict | {
-        k: None for k in set(params["eq_params"].keys()) - set(param_batch_dict.keys())
-    }
-    params = {"nn_params": params["nn_params"]} | {
-        "eq_params": jax.tree_util.tree_map(
-            lambda p, q: q if q is not None else p,
-            params["eq_params"],
-            param_batch_dict_,
-        )
-    }
-
-    return params
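
The NaN check retained above is useful on its own; a small sketch of how it behaves on an arbitrary pytree (the dictionary below is made up for illustration, and the function is an internal helper):

```python
import jax.numpy as jnp
from jinns.utils._utils import _check_nan_in_pytree

# A toy parameter pytree: one clean leaf, one leaf containing a NaN
params = {"w": jnp.ones((2, 2)), "b": jnp.array([0.0, jnp.nan])}

print(_check_nan_in_pytree(params))               # True
print(_check_nan_in_pytree({"w": jnp.zeros(3)}))  # False
```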