jinns 0.9.0__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. jinns/__init__.py +2 -0
  2. jinns/data/_Batchs.py +27 -0
  3. jinns/data/_DataGenerators.py +904 -1203
  4. jinns/data/__init__.py +4 -8
  5. jinns/experimental/__init__.py +0 -2
  6. jinns/experimental/_diffrax_solver.py +5 -5
  7. jinns/loss/_DynamicLoss.py +282 -305
  8. jinns/loss/_DynamicLossAbstract.py +322 -167
  9. jinns/loss/_LossODE.py +324 -322
  10. jinns/loss/_LossPDE.py +652 -1027
  11. jinns/loss/__init__.py +21 -5
  12. jinns/loss/_boundary_conditions.py +87 -41
  13. jinns/loss/{_Losses.py → _loss_utils.py} +101 -45
  14. jinns/loss/_loss_weights.py +59 -0
  15. jinns/loss/_operators.py +78 -72
  16. jinns/parameters/__init__.py +6 -0
  17. jinns/parameters/_derivative_keys.py +521 -0
  18. jinns/parameters/_params.py +115 -0
  19. jinns/plot/__init__.py +5 -0
  20. jinns/{data/_display.py → plot/_plot.py} +98 -75
  21. jinns/solver/_rar.py +183 -39
  22. jinns/solver/_solve.py +151 -124
  23. jinns/utils/__init__.py +3 -9
  24. jinns/utils/_containers.py +37 -44
  25. jinns/utils/_hyperpinn.py +224 -119
  26. jinns/utils/_pinn.py +183 -111
  27. jinns/utils/_save_load.py +121 -56
  28. jinns/utils/_spinn.py +113 -86
  29. jinns/utils/_types.py +64 -0
  30. jinns/utils/_utils.py +6 -160
  31. jinns/validation/_validation.py +48 -140
  32. jinns-1.1.0.dist-info/AUTHORS +2 -0
  33. {jinns-0.9.0.dist-info → jinns-1.1.0.dist-info}/METADATA +5 -4
  34. jinns-1.1.0.dist-info/RECORD +39 -0
  35. {jinns-0.9.0.dist-info → jinns-1.1.0.dist-info}/WHEEL +1 -1
  36. jinns/experimental/_sinuspinn.py +0 -135
  37. jinns/experimental/_spectralpinn.py +0 -87
  38. jinns/solver/_seq2seq.py +0 -157
  39. jinns/utils/_optim.py +0 -147
  40. jinns/utils/_utils_uspinn.py +0 -727
  41. jinns-0.9.0.dist-info/RECORD +0 -36
  42. {jinns-0.9.0.dist-info → jinns-1.1.0.dist-info}/LICENSE +0 -0
  43. {jinns-0.9.0.dist-info → jinns-1.1.0.dist-info}/top_level.txt +0 -0
jinns/utils/_spinn.py CHANGED
@@ -3,54 +3,53 @@ Implements utility function to create Separable PINNs
 https://arxiv.org/abs/2211.08761
 """
 
+from dataclasses import InitVar
+from typing import Callable, Literal
 import jax
 import jax.numpy as jnp
 import equinox as eqx
+from jaxtyping import Key, Array, Float, PyTree
 
 
 class _SPINN(eqx.Module):
     """
     Construct a Separable PINN as proposed in
     Cho et al., _Separable Physics-Informed Neural Networks_, NeurIPS, 2023
+
+    Parameters
+    ----------
+    key : InitVar[Key]
+        A jax random key for the layer initializations.
+    d : int
+        The number of dimensions to treat separately.
+    eqx_list : InitVar[tuple[tuple[Callable, int, int] | Callable, ...]]
+        A tuple of tuples of successive equinox modules and activation functions to
+        describe the PINN architecture. The inner tuples must have the eqx module or
+        activation function as first item, other items represents arguments
+        that could be required (eg. the size of the layer).
+        The `key` argument need not be given.
+        Thus typical example is `eqx_list=
+        ((eqx.nn.Linear, 2, 20),
+        jax.nn.tanh,
+        (eqx.nn.Linear, 20, 20),
+        jax.nn.tanh,
+        (eqx.nn.Linear, 20, 20),
+        jax.nn.tanh,
+        (eqx.nn.Linear, 20, 1)
+        )`.
     """
 
-    layers: list
-    separated_mlp: list
-    d: int
-    r: int
-    m: int
+    d: int = eqx.field(static=True, kw_only=True)
 
-    def __init__(self, key, d, r, eqx_list, m=1):
-        """
-        Parameters
-        ----------
-        key
-            A jax random key
-        d
-            An integer. The number of dimensions to treat separately
-        r
-            An integer. The dimension of the embedding
-        eqx_list
-            A list of list of successive equinox modules and activation functions to
-            describe *each separable PINN architecture*.
-            The inner lists have the eqx module or
-            axtivation function as first item, other items represents arguments
-            that could be required (eg. the size of the layer).
-            __Note:__ the `key` argument need not be given.
-            Thus typical example is `eqx_list=
-            [[eqx.nn.Linear, 1, 20],
-            [jax.nn.tanh],
-            [eqx.nn.Linear, 20, 20],
-            [jax.nn.tanh],
-            [eqx.nn.Linear, 20, 20],
-            [jax.nn.tanh],
-            [eqx.nn.Linear, 20, r]
-            ]`
-        """
-        self.d = d
-        self.r = r
-        self.m = m
+    key: InitVar[Key] = eqx.field(kw_only=True)
+    eqx_list: InitVar[tuple[tuple[Callable, int, int] | Callable, ...]] = eqx.field(
+        kw_only=True
+    )
+
+    layers: list = eqx.field(init=False)
+    separated_mlp: list = eqx.field(init=False)
 
+    def __post_init__(self, key, eqx_list):
         self.separated_mlp = []
         for _ in range(self.d):
             self.layers = []
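The hunk above replaces `_SPINN`'s hand-written `__init__` with equinox's dataclass-style fields: `d` becomes a static keyword-only field, `key` and `eqx_list` become `InitVar`s that are consumed in `__post_init__`, and `r`/`m` no longer live on this inner class. A minimal sketch of that pattern, illustrative only (not jinns code, with a made-up `TinyModule` and arbitrary layer sizes):

from dataclasses import InitVar

import jax
import equinox as eqx


class TinyModule(eqx.Module):
    # static hyperparameter, passed by keyword only
    d: int = eqx.field(static=True, kw_only=True)
    # consumed at construction time, not stored on the module
    key: InitVar[jax.Array] = eqx.field(kw_only=True)
    # built in __post_init__ rather than passed by the caller
    layers: list = eqx.field(init=False)

    def __post_init__(self, key):
        keys = jax.random.split(key, self.d)
        self.layers = [eqx.nn.Linear(2, 2, key=k) for k in keys]


mod = TinyModule(d=3, key=jax.random.PRNGKey(0))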
@@ -62,7 +61,9 @@ class _SPINN(eqx.Module):
                 self.layers.append(l[0](*l[1:], key=subkey))
             self.separated_mlp.append(self.layers)
 
-    def __call__(self, t, x):
+    def __call__(
+        self, t: Float[Array, "1"], x: Float[Array, "omega_dim"]
+    ) -> Float[Array, "d embed_dim*output_dim"]:
         if t is not None:
             dimensions = jnp.concatenate([t, x.flatten()], axis=0)
         else:
@@ -78,53 +79,67 @@ class _SPINN(eqx.Module):
 
 
 class SPINN(eqx.Module):
     """
-    Basically a wrapper around the `__call__` function to be able to give a type to
-    our former `self.u`
-    The function create_SPINN has the role to population the `__call__` function
+    A SPINN object compatible with the rest of jinns.
+    This is typically created with `create_SPINN`.
 
     **NOTE**: SPINNs with `t` and `x` as inputs are best used with a
     DataGenerator with `self.cartesian_product=False` for memory consideration
+
+    Parameters
+    ----------
+    d : int
+        The number of dimensions to treat separately.
+
     """
 
-    d: int
-    r: int
-    eq_type: str = eqx.field(static=True)
-    m: int
-    params: eqx.Module
-    static: eqx.Module
+    d: int = eqx.field(static=True, kw_only=True)
+    r: int = eqx.field(static=True, kw_only=True)
+    eq_type: str = eqx.field(static=True, kw_only=True)
+    m: int = eqx.field(static=True, kw_only=True)
+
+    spinn_mlp: InitVar[eqx.Module] = eqx.field(kw_only=True)
+
+    params: PyTree = eqx.field(init=False)
+    static: PyTree = eqx.field(init=False, static=True)
 
-    def __init__(self, spinn_mlp, d, r, eq_type, m):
-        self.d, self.r, self.m = d, r, m
+    def __post_init__(self, spinn_mlp):
         self.params, self.static = eqx.partition(spinn_mlp, eqx.is_inexact_array)
-        self.eq_type = eq_type
 
-    def init_params(self):
+    def init_params(self) -> PyTree:
+        """
+        Returns an initial set of parameters
+        """
         return self.params
 
-    def __call__(self, *args):
+    def __call__(self, *args) -> Float[Array, "output_dim"]:
+        """
+        Calls `eval_nn` with rearranged arguments
+        """
         if self.eq_type == "statio_PDE":
             (x, params) = args
             try:
-                spinn = eqx.combine(params["nn_params"], self.static)
-            except (KeyError, TypeError) as e:
+                spinn = eqx.combine(params.nn_params, self.static)
+            except (KeyError, AttributeError, TypeError) as e:
                 spinn = eqx.combine(params, self.static)
             v_model = jax.vmap(spinn, (0))
             res = v_model(t=None, x=x)
-            return self._eval_nn(res)
+            return self.eval_nn(res)
         if self.eq_type == "nonstatio_PDE":
             (t, x, params) = args
             try:
-                spinn = eqx.combine(params["nn_params"], self.static)
-            except (KeyError, TypeError) as e:
+                spinn = eqx.combine(params.nn_params, self.static)
+            except (KeyError, AttributeError, TypeError) as e:
                 spinn = eqx.combine(params, self.static)
             v_model = jax.vmap(spinn, ((0, 0)))
             res = v_model(t, x)
-            return self._eval_nn(res)
+            return self.eval_nn(res)
         raise RuntimeError("Wrong parameter value for eq_type")
 
-    def _eval_nn(self, res):
+    def eval_nn(
+        self, res: Float[Array, "d embed_dim*output_dim"]
+    ) -> Float[Array, "output_dim"]:
         """
-        common content of apply_fn put here in order to factorize code
+        Evaluate the SPINN on some inputs with some params.
         """
         a = ", ".join([f"{chr(97 + d)}z" for d in range(res.shape[1])])
         b = "".join([f"{chr(97 + d)}" for d in range(res.shape[1])])
@@ -148,59 +163,71 @@ class SPINN(eqx.Module):
         return res
 
 
-def create_SPINN(key, d, r, eqx_list, eq_type, m=1):
+def create_SPINN(
+    key: Key,
+    d: int,
+    r: int,
+    eqx_list: tuple[tuple[Callable, int, int] | Callable, ...],
+    eq_type: Literal["ODE", "statio_PDE", "nonstatio_PDE"],
+    m: int = 1,
+) -> SPINN:
     """
     Utility function to create a SPINN neural network with the equinox
     library.
 
-    *Note* that a SPINN is not vmapped from the outside and expects batch of the
-    same size for each input. It outputs directly a solution of shape
-    (batchsize, batchsize). See the paper for more details.
+    *Note* that a SPINN is not vmapped and expects the
+    same batch size for each of its input axis. It directly outputs a solution
+    of shape `(batchsize, batchsize)`. See the paper for more details.
 
     Parameters
     ----------
     key
-        A jax random key that will be used to initialize the network parameters
+        A JAX random key that will be used to initialize the network parameters
     d
-        An integer. The number of dimensions to treat separately
+        The number of dimensions to treat separately.
     r
-        An integer. The dimension of the embedding
+        An integer. The dimension of the embedding.
     eqx_list
-        A list of list of successive equinox modules and activation functions to
-        describe *each separable PINN architecture*.
-        The inner lists have the eqx module or
-        axtivation function as first item, other items represents arguments
+        A tuple of tuples of successive equinox modules and activation functions to
+        describe the PINN architecture. The inner tuples must have the eqx module or
+        activation function as first item, other items represents arguments
         that could be required (eg. the size of the layer).
-        __Note:__ the `key` argument need not be given.
-        Thus typical example is `eqx_list=
-        [[eqx.nn.Linear, 1, 20],
-        [jax.nn.tanh],
-        [eqx.nn.Linear, 20, 20],
-        [jax.nn.tanh],
-        [eqx.nn.Linear, 20, 20],
-        [jax.nn.tanh],
-        [eqx.nn.Linear, 20, r]
-        ]`
-    eq_type
+        The `key` argument need not be given.
+        Thus typical example is
+        `eqx_list=((eqx.nn.Linear, 2, 20),
+        jax.nn.tanh,
+        (eqx.nn.Linear, 20, 20),
+        jax.nn.tanh,
+        (eqx.nn.Linear, 20, 20),
+        jax.nn.tanh,
+        (eqx.nn.Linear, 20, 1)
+        )`.
+    eq_type : Literal["ODE", "statio_PDE", "nonstatio_PDE"]
         A string with three possibilities.
         "ODE": the PINN is called with one input `t`.
         "statio_PDE": the PINN is called with one input `x`, `x`
         can be high dimensional.
         "nonstatio_PDE": the PINN is called with two inputs `t` and `x`, `x`
         can be high dimensional.
+        **Note**: the input dimension as given in eqx_list has to match the sum
+        of the dimension of `t` + the dimension of `x` or the output dimension
+        after the `input_transform` function.
     m
-        An integer. The output dimension of the neural network. According to
+        The output dimension of the neural network. According to
         the SPINN article, a total embedding dimension of `r*m` is defined. We
         then sum groups of `r` embedding dimensions to compute each output.
         Default is 1.
 
-    **NOTE**: SPINNs with `t` and `x` as inputs are best used with a
-    DataGenerator with `self.cartesian_product=False` for memory consideration
+    !!! note
+        SPINNs with `t` and `x` as inputs are best used with a
+        DataGenerator with `self.cartesian_product=False` for memory
+        consideration
 
 
     Returns
     -------
-    `u`, a :class:`.SPINN` object which inherits from `eqx.Module` (hence callable).
+    spinn
+        An instanciated SPINN
 
     Raises
     ------
@@ -235,7 +262,7 @@ def create_SPINN(key, d, r, eqx_list, eq_type, m=1):
             "Too many dimensions, not enough letters available in jnp.einsum"
         )
 
-    spinn_mlp = _SPINN(key, d, r, eqx_list, m)
-    spinn = SPINN(spinn_mlp, d, r, eq_type, m)
+    spinn_mlp = _SPINN(key=key, d=d, eqx_list=eqx_list)
+    spinn = SPINN(spinn_mlp=spinn_mlp, d=d, r=r, eq_type=eq_type, m=m)
 
     return spinn
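For reference, a hedged usage sketch of the new call, assuming `create_SPINN` is re-exported from `jinns.utils` and following the per-dimension layout from the removed `_SPINN` docstring (each separated MLP maps one coordinate to `r` embedding features); the layer sizes and the choice of `eq_type` are purely illustrative:

import jax
import equinox as eqx
from jinns.utils import create_SPINN  # assumed public import path

key = jax.random.PRNGKey(0)
r = 32  # embedding dimension, illustrative
# tuple-of-tuples architecture, one MLP replicated for each of the d dimensions
eqx_list = (
    (eqx.nn.Linear, 1, 20),
    jax.nn.tanh,
    (eqx.nn.Linear, 20, 20),
    jax.nn.tanh,
    (eqx.nn.Linear, 20, r),
)
u = create_SPINN(key, d=2, r=r, eqx_list=eqx_list, eq_type="nonstatio_PDE")
params = u.init_params()  # the partitioned parameter PyTree stored on the SPINN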
jinns/utils/_types.py ADDED
@@ -0,0 +1,64 @@
+from __future__ import (
+    annotations,
+)  # https://docs.python.org/3/library/typing.html#constant
+
+from typing import TypeAlias, TYPE_CHECKING, NewType
+from jaxtyping import Int
+
+if TYPE_CHECKING:
+    from jinns.loss._LossPDE import (
+        LossPDEStatio,
+        LossPDENonStatio,
+        SystemLossPDE,
+    )
+
+    from jinns.loss._LossODE import LossODE, SystemLossODE
+    from jinns.parameters._params import Params, ParamsDict
+    from jinns.data._DataGenerators import (
+        DataGeneratorODE,
+        CubicMeshPDEStatio,
+        CubicMeshPDENonStatio,
+        DataGeneratorObservations,
+        DataGeneratorParameter,
+        DataGeneratorObservationsMultiPINNs,
+    )
+
+    from jinns.loss import DynamicLoss
+    from jinns.data._Batchs import *
+    from jinns.utils._pinn import PINN
+    from jinns.utils._hyperpinn import HYPERPINN
+    from jinns.utils._spinn import SPINN
+    from jinns.utils._containers import *
+    from jinns.validation._validation import AbstractValidationModule
+
+AnyLoss: TypeAlias = (
+    LossPDEStatio | LossPDENonStatio | SystemLossPDE | LossODE | SystemLossODE
+)
+
+AnyParams: TypeAlias = Params | ParamsDict
+
+AnyDataGenerator: TypeAlias = (
+    DataGeneratorODE | CubicMeshPDEStatio | CubicMeshPDENonStatio
+)
+
+AnyPINN: TypeAlias = PINN | HYPERPINN | SPINN
+
+AnyBatch: TypeAlias = ODEBatch | PDEStatioBatch | PDENonStatioBatch
+rar_operands = NewType(
+    "rar_operands", tuple[AnyLoss, AnyParams, AnyDataGenerator, Int]
+)
+
+main_carry = NewType(
+    "main_carry",
+    tuple[
+        Int,
+        AnyLoss,
+        OptimizationContainer,
+        OptimizationExtraContainer,
+        DataGeneratorContainer,
+        AbstractValidationModule,
+        LossContainer,
+        StoredObjectContainer,
+        Float[Array, "n_iter"],
+    ],
+)
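Since the new module postpones annotation evaluation (`from __future__ import annotations`) and guards its imports behind `TYPE_CHECKING`, these aliases are meant to be consumed in annotations only. A small hedged sketch of how a downstream module could use one of them, with a hypothetical helper name:

from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from jinns.utils._types import AnyPINN  # resolved by the type checker only


def describe_network(u: AnyPINN) -> str:
    # hypothetical helper: the annotation stays a string at runtime thanks to
    # postponed evaluation, so AnyPINN never needs to exist outside type checking
    return type(u).__name__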
jinns/utils/_utils.py CHANGED
@@ -7,9 +7,10 @@ from operator import getitem
 import numpy as np
 import jax
 import jax.numpy as jnp
+from jaxtyping import PyTree, Array
 
 
-def _check_nan_in_pytree(pytree):
+def _check_nan_in_pytree(pytree: PyTree) -> bool:
     """
     Check if there is a NaN value anywhere is the pytree
 
@@ -25,40 +26,14 @@ def _check_nan_in_pytree(pytree):
     """
     return jnp.any(
         jnp.array(
-            list(
-                jax.tree_util.tree_leaves(
-                    jax.tree_util.tree_map(lambda x: jnp.any(jnp.isnan(x)), pytree)
-                )
+            jax.tree_util.tree_leaves(
+                jax.tree_util.tree_map(lambda x: jnp.any(jnp.isnan(x)), pytree)
             )
         )
     )
 
 
-def _tracked_parameters(params, tracked_params_key_list):
-    """
-    Returns a pytree with the same structure as params with True is the
-    parameter is tracked False otherwise
-    """
-
-    def set_nested_item(dataDict, mapList, val):
-        """
-        Set item in nested dictionary
-        https://stackoverflow.com/questions/54137991/how-to-update-values-in-nested-dictionary-if-keys-are-in-a-list
-        """
-        reduce(getitem, mapList[:-1], dataDict)[mapList[-1]] = val
-        return dataDict
-
-    tracked_params = jax.tree_util.tree_map(
-        lambda x: False, params
-    )  # init with all False
-
-    for key_list in tracked_params_key_list:
-        tracked_params = set_nested_item(tracked_params, key_list, True)
-
-    return tracked_params
-
-
-def _get_grid(in_array):
+def _get_grid(in_array: Array) -> Array:
     """
     From an array of shape (B, D), D > 1, get the grid array, i.e., an array of
     shape (B, B, ...(D times)..., B, D): along the last axis we have the array
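The change above simply drops a redundant `list(...)` wrapper, since `jax.tree_util.tree_leaves` already returns a list. As a quick illustration of what the helper computes, here is a hedged sketch on a toy pytree, assuming the private `_check_nan_in_pytree` is importable from `jinns.utils._utils`:

import jax.numpy as jnp
from jinns.utils._utils import _check_nan_in_pytree  # private helper, assumed importable

# toy parameter pytrees, not jinns objects
clean = {"w": jnp.ones((2, 2)), "b": jnp.zeros(2)}
broken = {"w": jnp.array([[1.0, jnp.nan], [0.0, 2.0]]), "b": jnp.zeros(2)}

print(_check_nan_in_pytree(clean))   # False
print(_check_nan_in_pytree(broken))  # True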
@@ -74,31 +49,7 @@ def _get_grid(in_array):
     return in_array
 
 
-def _get_vmap_in_axes_params(eq_params_batch_dict, params):
-    """
-    Return the input vmap axes when there is batch(es) of parameters to vmap
-    over. The latter are designated by keys in eq_params_batch_dict
-    If eq_params_batch_dict (ie no additional parameter batch), we return None
-    """
-    if eq_params_batch_dict is None:
-        return (None,)
-    # We use pytree indexing of vmapped axes and vmap on axis
-    # 0 of the eq_parameters for which we have a batch
-    # this is for a fine-grained vmaping
-    # scheme over the params
-    vmap_in_axes_params = (
-        {
-            "nn_params": None,
-            "eq_params": {
-                k: (0 if k in eq_params_batch_dict.keys() else None)
-                for k in params["eq_params"].keys()
-            },
-        },
-    )
-    return vmap_in_axes_params
-
-
-def _check_user_func_return(r, shape):
+def _check_user_func_return(r: Array | int, shape: tuple) -> Array | int:
     """
     Correctly handles the result from a user defined function (eg a boundary
     condition) to get the correct broadcast
@@ -115,108 +66,3 @@ def _check_user_func_return(r, shape):
     # the reshape below avoids a missing (1,) ending dimension
     # depending on how the user has coded the inital function
     return r.reshape(shape)
-
-
-def _set_derivatives(params, loss_term, derivative_keys):
-    """
-    Given derivative_keys, the parameters wrt which we want to compute
-    gradients in the loss, we set stop_gradient operators to not take the
-    derivatives with respect to the others. Note that we only operator at
-    top level
-    """
-    try:
-        params = {
-            k: (
-                value
-                if k in derivative_keys[loss_term]
-                else jax.lax.stop_gradient(value)
-            )
-            for k, value in params.items()
-        }
-    except KeyError:  # if the loss_term key has not been specified we
-        # only take gradients wrt "nn_params", all the other entries have
-        # stopped gradient
-        params = {
-            k: value if k in ["nn_params"] else jax.lax.stop_gradient(value)
-            for k, value in params.items()
-        }
-
-    return params
-
-
-def _extract_nn_params(params_dict, nn_key):
-    """
-    Given a params_dict for system loss (ie "nn_params" and "eq_params" as main
-    keys which contain dicts for each PINN (the nn_keys)) we extract the
-    corresponding "nn_params" for `nn_key` and reform a dict with "nn_params"
-    as main key as expected by the PINN/SPINN apply_fn
-    """
-    try:
-        return {
-            "nn_params": params_dict["nn_params"][nn_key],
-            "eq_params": params_dict["eq_params"][nn_key],
-        }
-    except (KeyError, IndexError) as e:
-        return {
-            "nn_params": params_dict["nn_params"][nn_key],
-            "eq_params": params_dict["eq_params"],
-        }
-
-
-def euler_maruyama_density(t, x, s, y, params, Tmax=1):
-    eps = 1e-6
-    delta = jnp.abs(t - s) * Tmax
-    mu = params["alpha_sde"] * (params["mu_sde"] - y) * delta
-    var = params["sigma_sde"] ** 2 * delta
-    return (
-        1 / jnp.sqrt(2 * jnp.pi * var) * jnp.exp(-0.5 * ((x - y) - mu) ** 2 / var) + eps
-    )
-
-
-def log_euler_maruyama_density(t, x, s, y, params):
-    eps = 1e-6
-    delta = jnp.abs(t - s)
-    mu = params["alpha_sde"] * (params["mu_sde"] - y) * delta
-    logvar = params["logvar_sde"]
-    return (
-        -0.5
-        * (jnp.log(2 * jnp.pi * delta) + logvar + ((x - y) - mu) ** 2 / jnp.exp(logvar))
-        + eps
-    )
-
-
-def euler_maruyama(x0, alpha, mu, sigma, T, N):
-    """
-    Simulate 1D diffusion process with simple parametrization using the Euler
-    Maruyama method in the interval [0, T]
-    """
-    path = [np.array([x0])]
-
-    time_steps, step_size = np.linspace(0, T, N, retstep=True)
-    for _ in time_steps[1:]:
-        path.append(
-            path[-1]
-            + step_size * (alpha * (mu - path[-1]))
-            + sigma * np.random.normal(loc=0.0, scale=np.sqrt(step_size))
-        )
-
-    return time_steps, np.stack(path)
-
-
-def _update_eq_params_dict(params, param_batch_dict):
-    # update params["eq_params"] with a batch of eq_params
-    # we avoid side_effect by recreating the dict `params`
-    # TODO transform `params` in a NamedTuple to be able to use _replace
-    # see Issue #1
-    param_batch_dict_ = param_batch_dict | {
-        k: None for k in set(params["eq_params"].keys()) - set(param_batch_dict.keys())
-    }
-    params = {"nn_params": params["nn_params"]} | {
-        "eq_params": jax.tree_util.tree_map(
-            lambda p, q: q if q is not None else p,
-            params["eq_params"],
-            param_batch_dict_,
-        )
-    }
-
-    return params