jinns 0.9.0__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. jinns/__init__.py +2 -0
  2. jinns/data/_Batchs.py +27 -0
  3. jinns/data/_DataGenerators.py +904 -1203
  4. jinns/data/__init__.py +4 -8
  5. jinns/experimental/__init__.py +0 -2
  6. jinns/experimental/_diffrax_solver.py +5 -5
  7. jinns/loss/_DynamicLoss.py +282 -305
  8. jinns/loss/_DynamicLossAbstract.py +322 -167
  9. jinns/loss/_LossODE.py +324 -322
  10. jinns/loss/_LossPDE.py +652 -1027
  11. jinns/loss/__init__.py +21 -5
  12. jinns/loss/_boundary_conditions.py +87 -41
  13. jinns/loss/{_Losses.py → _loss_utils.py} +101 -45
  14. jinns/loss/_loss_weights.py +59 -0
  15. jinns/loss/_operators.py +78 -72
  16. jinns/parameters/__init__.py +6 -0
  17. jinns/parameters/_derivative_keys.py +521 -0
  18. jinns/parameters/_params.py +115 -0
  19. jinns/plot/__init__.py +5 -0
  20. jinns/{data/_display.py → plot/_plot.py} +98 -75
  21. jinns/solver/_rar.py +183 -39
  22. jinns/solver/_solve.py +151 -124
  23. jinns/utils/__init__.py +3 -9
  24. jinns/utils/_containers.py +37 -44
  25. jinns/utils/_hyperpinn.py +224 -119
  26. jinns/utils/_pinn.py +183 -111
  27. jinns/utils/_save_load.py +121 -56
  28. jinns/utils/_spinn.py +113 -86
  29. jinns/utils/_types.py +64 -0
  30. jinns/utils/_utils.py +6 -160
  31. jinns/validation/_validation.py +48 -140
  32. jinns-1.1.0.dist-info/AUTHORS +2 -0
  33. {jinns-0.9.0.dist-info → jinns-1.1.0.dist-info}/METADATA +5 -4
  34. jinns-1.1.0.dist-info/RECORD +39 -0
  35. {jinns-0.9.0.dist-info → jinns-1.1.0.dist-info}/WHEEL +1 -1
  36. jinns/experimental/_sinuspinn.py +0 -135
  37. jinns/experimental/_spectralpinn.py +0 -87
  38. jinns/solver/_seq2seq.py +0 -157
  39. jinns/utils/_optim.py +0 -147
  40. jinns/utils/_utils_uspinn.py +0 -727
  41. jinns-0.9.0.dist-info/RECORD +0 -36
  42. {jinns-0.9.0.dist-info → jinns-1.1.0.dist-info}/LICENSE +0 -0
  43. {jinns-0.9.0.dist-info → jinns-1.1.0.dist-info}/top_level.txt +0 -0
jinns/utils/_utils_uspinn.py
@@ -1,727 +0,0 @@
- import numpy as np
- import jax
- import jax.numpy as jnp
- import optax
- import equinox as eqx
- from functools import reduce
- from operator import getitem
-
-
- def _check_nan_in_pytree(pytree):
- """
- Check if there is a NaN value anywhere is the pytree
-
- Parameters
- ----------
- pytree
- A pytree
-
- Returns
- -------
- res
- A boolean. True if any of the pytree content is NaN
- """
- return jnp.any(
- jnp.array(
- [
- value
- for value in jax.tree_util.tree_leaves(
- jax.tree_util.tree_map(lambda x: jnp.any(jnp.isnan(x)), pytree)
- )
- ]
- )
- )
-
-
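For context, a minimal sketch of how this removed helper would be called; the parameter pytree below is hypothetical:

    import jax.numpy as jnp

    params = {"nn_params": jnp.array([0.1, 0.2]), "eq_params": {"D": jnp.array(jnp.nan)}}
    _check_nan_in_pytree(params)  # -> Array(True): one leaf contains a NaN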
- def _tracked_parameters(params, tracked_params_key_list):
- """
- Returns a pytree with the same structure as params with True is the
- parameter is tracked False otherwise
- """
-
- def set_nested_item(dataDict, mapList, val):
- """
- Set item in nested dictionary
- https://stackoverflow.com/questions/54137991/how-to-update-values-in-nested-dictionary-if-keys-are-in-a-list
- """
- reduce(getitem, mapList[:-1], dataDict)[mapList[-1]] = val
- return dataDict
-
- tracked_params = jax.tree_util.tree_map(
- lambda x: False, params
- ) # init with all False
-
- for key_list in tracked_params_key_list:
- tracked_params = set_nested_item(tracked_params, key_list, True)
-
- return tracked_params
-
-
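A minimal sketch of the intended call, assuming a hypothetical nested params dict; each entry of tracked_params_key_list is a path of keys down to a leaf:

    params = {"nn_params": {"w": 0.0}, "eq_params": {"D": 1.0, "r": 2.0}}
    tracked = _tracked_parameters(params, [["eq_params", "D"]])
    # tracked == {"nn_params": {"w": False}, "eq_params": {"D": True, "r": False}}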
- class _MLP(eqx.Module):
- """
- Class to construct an equinox module from a key and a eqx_list. To be used
- in pair with the function `create_PINN`
- """
-
- layers: list
-
- def __init__(self, key, eqx_list):
- """
- Parameters
- ----------
- key
- A jax random key
- eqx_list
- A list of list of successive equinox modules and activation functions to
- describe the PINN architecture. The inner lists have the eqx module or
- axtivation function as first item, other items represents arguments
- that could be required (eg. the size of the layer).
- __Note:__ the `key` argument need not be given.
- Thus typical example is `eqx_list=
- [[eqx.nn.Linear, 2, 20],
- [jax.nn.tanh],
- [eqx.nn.Linear, 20, 20],
- [jax.nn.tanh],
- [eqx.nn.Linear, 20, 20],
- [jax.nn.tanh],
- [eqx.nn.Linear, 20, 1]
- ]`
- """
-
- self.layers = []
- # TODO we are limited currently in the number of layer type we can
- # parse and we lack some safety checks
- for l in eqx_list:
- if len(l) == 1:
- self.layers.append(l[0])
- else:
- # By default we append a random key at the end of the
- # arguments fed into a layer module call
- key, subkey = jax.random.split(key, 2)
- # the argument key is keyword only
- self.layers.append(l[0](*l[1:], key=subkey))
-
- def __call__(self, t):
- for layer in self.layers:
- t = layer(t)
- return t
-
-
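To illustrate the eqx_list convention described in the docstring above, a small sketch (layer sizes are illustrative):

    import jax
    import jax.numpy as jnp
    import equinox as eqx

    key = jax.random.PRNGKey(0)
    eqx_list = [
        [eqx.nn.Linear, 2, 20],   # module, in_features, out_features; a key is appended internally
        [jax.nn.tanh],            # single-item entries are used as-is (activations)
        [eqx.nn.Linear, 20, 1],
    ]
    mlp = _MLP(key, eqx_list)
    y = mlp(jnp.ones((2,)))       # sequential forward pass, output shape (1,)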
- class PINN:
- """
- Basically a wrapper around the `__call__` function to be able to give a type to
- our former `self.u`
- The function create_PINN has the role to population the `__call__` function
- """
-
- def __init__(self, key, eqx_list, output_slice=None):
- _pinn = _MLP(key, eqx_list)
- self.params, self.static = eqx.partition(_pinn, eqx.is_inexact_array)
- self.output_slice = output_slice
-
- def init_params(self):
- return self.params
-
- def __call__(self, *args, **kwargs):
- return self.apply_fn(self, *args, **kwargs)
-
-
- def create_PINN(
- key,
- eqx_list,
- eq_type,
- dim_x=0,
- with_eq_params=None,
- input_transform=None,
- output_transform=None,
- shared_pinn_outputs=None,
- ):
- """
- Utility function to create a standard PINN neural network with the equinox
- library.
-
- Parameters
- ----------
- key
- A jax random key that will be used to initialize the network parameters
- eqx_list
- A list of list of successive equinox modules and activation functions to
- describe the PINN architecture. The inner lists have the eqx module or
- axtivation function as first item, other items represents arguments
- that could be required (eg. the size of the layer).
- __Note:__ the `key` argument need not be given.
- Thus typical example is `eqx_list=
- [[eqx.nn.Linear, 2, 20],
- [jax.nn.tanh],
- [eqx.nn.Linear, 20, 20],
- [jax.nn.tanh],
- [eqx.nn.Linear, 20, 20],
- [jax.nn.tanh],
- [eqx.nn.Linear, 20, 1]
- ]`
- eq_type
- A string with three possibilities.
- "ODE": the PINN is called with one input `t`.
- "statio_PDE": the PINN is called with one input `x`, `x`
- can be high dimensional.
- "nonstatio_PDE": the PINN is called with two inputs `t` and `x`, `x`
- can be high dimensional.
- **Note: the input dimension as given in eqx_list has to match the sum
- of the dimension of `t` + the dimension of `x` + the number of
- parameters in `eq_params` if with_eq_params is `True` (see below)**
- dim_x
- An integer. The dimension of `x`. Default `0`
- with_eq_params
- Default is None. Otherwise a list of keys from the dict `eq_params`
- that the network also takes as inputs.
- the equation parameters (`eq_params`).
- **If some keys are provided, the input dimension
- as given in eqx_list must take into account the number of such provided
- keys (i.e., the input dimension is the addition of the dimension of ``t``
- + the dimension of ``x`` + the number of ``eq_params``)**
- input_transform
- A function that will be called before entering the PINN. Its output(s)
- must mathc the PINN inputs.
- output_transform
- A function with arguments the same input(s) as the PINN AND the PINN
- output that will be called after exiting the PINN
- shared_pinn_outputs
- A tuple of jnp.s_[] (slices) to determine the different output for each
- network. In this case we return a list of PINNs, one for each output in
- shared_pinn_outputs. This is useful to create PINNs that share the
- same network and same parameters. Default is None, we only return one PINN.
-
-
- Returns
- -------
- init_fn
- A function which (re-)initializes the PINN parameters with the provided
- jax random key
- apply_fn
- A function to apply the neural network on given inputs for given
- parameters. A typical call will be of the form `u(t, nn_params)` for
- ODE or `u(t, x, nn_params)` for nD PDEs (`x` being multidimensional)
- or even `u(t, x, nn_params, eq_params)` if with_eq_params is `True`
-
- Raises
- ------
- RuntimeError
- If the parameter value for eq_type is not in `["ODE", "statio_PDE",
- "nonstatio_PDE"]`
- RuntimeError
- If we have a `dim_x > 0` and `eq_type == "ODE"`
- or if we have a `dim_x = 0` and `eq_type != "ODE"`
- """
- if eq_type not in ["ODE", "statio_PDE", "nonstatio_PDE"]:
- raise RuntimeError("Wrong parameter value for eq_type")
-
- if eq_type == "ODE" and dim_x != 0:
- raise RuntimeError("Wrong parameter combination eq_type and dim_x")
-
- if eq_type != "ODE" and dim_x == 0:
- raise RuntimeError("Wrong parameter combination eq_type and dim_x")
-
- dim_t = 0 if eq_type == "statio_PDE" else 1
- dim_in_params = len(with_eq_params) if with_eq_params is not None else 0
- try:
- nb_inputs_declared = eqx_list[0][1] # normally we look for 2nd ele of 1st layer
- except IndexError:
- nb_inputs_declared = eqx_list[1][1]
- # but we can have, eg, a flatten first layer
-
- try:
- nb_outputs_declared = eqx_list[-1][2] # normally we look for 3rd ele of
- # last layer
- except IndexError:
- nb_outputs_declared = eqx_list[-2][2]
- # but we can have, eg, a `jnp.exp` last layer
-
- # NOTE Currently the check below is disabled because we added
- # input_transform
- # if dim_t + dim_x + dim_in_params != nb_inputs_declared:
- # raise RuntimeError("Error in the declarations of the number of parameters")
-
- if eq_type == "ODE":
- if with_eq_params is None:
-
- def apply_fn(self, t, u_params, eq_params=None):
- model = eqx.combine(u_params, self.static)
- t = t[
- None
- ] # Note that we added a dimension to t which is lacking for the ODE batches
- if output_transform is None:
- if input_transform is not None:
- res = model(input_transform(t)).squeeze()
- else:
- res = model(t).squeeze()
- else:
- if input_transform is not None:
- res = output_transform(t, model(input_transform(t)).squeeze())
- else:
- res = output_transform(t, model(t).squeeze())
- if self.output_slice is not None:
- return res[self.output_slice]
- else:
- return res
-
- else:
-
- def apply_fn(self, t, u_params, eq_params):
- model = eqx.combine(u_params, self.static)
- t = t[
- None
- ] # We added a dimension to t which is lacking for the ODE batches
- eq_params_flatten = jnp.concatenate(
- [e.ravel() for k, e in eq_params.items() if k in with_eq_params]
- )
- t_eq_params = jnp.concatenate([t, eq_params_flatten], axis=-1)
-
- if output_transform is None:
- if input_transform is not None:
- res = model(input_transform(t_eq_params)).squeeze()
- else:
- res = model(t_eq_params).squeeze()
- else:
- if input_transform is not None:
- res = output_transform(
- t_eq_params,
- model(input_transform(t_eq_params)).squeeze(),
- )
- else:
- res = output_transform(
- t_eq_params, model(t_eq_params).squeeze()
- )
-
- if self.output_slice is not None:
- return res[self.output_slice]
- else:
- return res
-
- elif eq_type == "statio_PDE":
- # Here we add an argument `x` which can be high dimensional
- if with_eq_params is None:
-
- def apply_fn(self, x, u_params, eq_params=None):
- model = eqx.combine(u_params, self.static)
-
- if output_transform is None:
- if input_transform is not None:
- res = model(input_transform(x)).squeeze()
- else:
- res = model(x).squeeze()
- else:
- if input_transform is not None:
- res = output_transform(x, model(input_transform(x)).squeeze())
- else:
- res = output_transform(x, model(x).squeeze()).squeeze()
-
- if self.output_slice is not None:
- res = res[self.output_slice]
-
- # force (1,) output for non vectorial solution (consistency)
- if not res.shape:
- return jnp.expand_dims(res, axis=-1)
- else:
- return res
-
- else:
-
- def apply_fn(self, x, u_params, eq_params):
- model = eqx.combine(u_params, self.static)
- eq_params_flatten = jnp.concatenate(
- [e.ravel() for k, e in eq_params.items() if k in with_eq_params]
- )
- x_eq_params = jnp.concatenate([x, eq_params_flatten], axis=-1)
-
- if output_transform is None:
- if input_transform is not None:
- res = model(input_transform(x_eq_params)).squeeze()
- else:
- res = model(x_eq_params).squeeze()
- else:
- if input_transform is not None:
- res = output_transform(
- x_eq_params,
- model(input_transform(x_eq_params)).squeeze(),
- )
- else:
- res = output_transform(
- x_eq_params, model(x_eq_params).squeeze()
- )
-
- if self.output_slice is not None:
- res = res[self.output_slice]
-
- # force (1,) output for non vectorial solution (consistency)
- if not res.shape:
- return jnp.expand_dims(res, axis=-1)
- else:
- return res
-
- elif eq_type == "nonstatio_PDE":
- # Here we add an argument `x` which can be high dimensional
- if with_eq_params is None:
-
- def apply_fn(self, t, x, u_params, eq_params=None):
- model = eqx.combine(u_params, self.static)
- t_x = jnp.concatenate([t, x], axis=-1)
-
- if output_transform is None:
- if input_transform is not None:
- res = model(input_transform(t_x)).squeeze()
- else:
- res = model(t_x).squeeze()
- else:
- if input_transform is not None:
- res = output_transform(
- t_x, model(input_transform(t_x)).squeeze()
- )
- else:
- res = output_transform(t_x, model(t_x).squeeze())
-
- if self.output_slice is not None:
- res = res[self.output_slice]
-
- ## force (1,) output for non vectorial solution (consistency)
- if not res.shape:
- return jnp.expand_dims(res, axis=-1)
- else:
- return res
-
- else:
-
- def apply_fn(self, t, x, u_params, eq_params):
- model = eqx.combine(u_params, self.static)
- t_x = jnp.concatenate([t, x], axis=-1)
- eq_params_flatten = jnp.concatenate(
- [e.ravel() for k, e in eq_params.items() if k in with_eq_params]
- )
- t_x_eq_params = jnp.concatenate([t_x, eq_params_flatten], axis=-1)
-
- if output_transform is None:
- if input_transform is not None:
- res = model(input_transform(t_x_eq_params)).squeeze()
- else:
- res = model(t_x_eq_params).squeeze()
- else:
- if input_transform is not None:
- res = output_transform(
- t_x_eq_params,
- model(input_transform(t_x_eq_params)).squeeze(),
- )
- else:
- res = output_transform(
- t_x_eq_params,
- model(input_transform(t_x_eq_params)).squeeze(),
- )
-
- if self.output_slice is not None:
- res = res[self.output_slice]
-
- # force (1,) output for non vectorial solution (consistency)
- if not res.shape:
- return jnp.expand_dims(res, axis=-1)
- else:
- return res
-
- else:
- raise RuntimeError("Wrong parameter value for eq_type")
-
- if shared_pinn_outputs is not None:
- pinns = []
- static = None
- for output_slice in shared_pinn_outputs:
- pinn = PINN(key, eqx_list, output_slice)
- pinn.apply_fn = apply_fn
- # all the pinns are in fact the same so we share the same static
- if static is None:
- static = pinn.static
- else:
- pinn.static = static
- pinns.append(pinn)
- return pinns
- else:
- pinn = PINN(key, eqx_list)
- pinn.apply_fn = apply_fn
- return pinn
-
-
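A hedged usage sketch of the factory above, following its docstring for the "ODE" case (architecture and call values are illustrative, and the surrounding training loop is omitted):

    import jax
    import jax.numpy as jnp
    import equinox as eqx

    key = jax.random.PRNGKey(0)
    eqx_list = [
        [eqx.nn.Linear, 1, 20], [jax.nn.tanh],
        [eqx.nn.Linear, 20, 20], [jax.nn.tanh],
        [eqx.nn.Linear, 20, 1],
    ]
    u = create_PINN(key, eqx_list, eq_type="ODE")   # dim_x stays 0 for an ODE
    nn_params = u.init_params()
    out = u(jnp.array(0.5), nn_params)              # apply_fn adds the missing time dimension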
- class _SPINN(eqx.Module):
- """
- Construct a Separable PINN as proposed in
- Cho et al., _Separable Physics-Informed Neural Networks_, NeurIPS, 2023
- """
-
- layers: list
- separated_mlp: list
- d: int
- r: int
- m: int
-
- def __init__(self, key, d, r, eqx_list, m=1):
- """
- Parameters
- ----------
- key
- A jax random key
- d
- An integer. The number of dimensions to treat separately
- r
- An integer. The dimension of the embedding
- eqx_list
- A list of list of successive equinox modules and activation functions to
- describe *each separable PINN architecture*.
- The inner lists have the eqx module or
- axtivation function as first item, other items represents arguments
- that could be required (eg. the size of the layer).
- __Note:__ the `key` argument need not be given.
- Thus typical example is `eqx_list=
- [[eqx.nn.Linear, d, 20],
- [jax.nn.tanh],
- [eqx.nn.Linear, 20, 20],
- [jax.nn.tanh],
- [eqx.nn.Linear, 20, 20],
- [jax.nn.tanh],
- [eqx.nn.Linear, 20, r]
- ]`
- """
- keys = jax.random.split(key, 8)
-
- self.d = d
- self.r = r
- self.m = m
-
- self.separated_mlp = []
- for d in range(self.d):
- self.layers = []
- for l in eqx_list:
- if len(l) == 1:
- self.layers.append(l[0])
- else:
- key, subkey = jax.random.split(key, 2)
- self.layers.append(l[0](*l[1:], key=subkey))
- self.separated_mlp.append(self.layers)
-
- def __call__(self, t, x):
- if t is not None:
- dimensions = jnp.concatenate([t, x.flatten()], axis=0)
- else:
- dimensions = jnp.concatenate([x.flatten()], axis=0)
- outputs = []
- for d in range(self.d):
- t_ = dimensions[d][None]
- for layer in self.separated_mlp[d]:
- t_ = layer(t_)
- outputs += [t_]
- return jnp.asarray(outputs)
-
-
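A rough sketch of the separable forward pass implemented above: each separated input coordinate goes through its own small MLP and the per-dimension embeddings are stacked (sizes are illustrative, not taken from jinns defaults):

    import jax
    import jax.numpy as jnp
    import equinox as eqx

    key = jax.random.PRNGKey(0)
    # each separated MLP maps one scalar coordinate to an r-dimensional embedding
    eqx_list = [[eqx.nn.Linear, 1, 32], [jax.nn.tanh], [eqx.nn.Linear, 32, 8]]
    spinn = _SPINN(key, d=2, r=8, eqx_list=eqx_list)
    out = spinn(jnp.array([0.5]), jnp.array([0.3]))  # t and x as 1D arrays
    # out.shape == (2, 8): one embedding per separated dimension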
- def _get_grid(in_array):
- """
- From an array of shape (B, D), D > 1, get the grid array, i.e., an array of
- shape (B, B, ...(D times)..., B, D): along the last axis we have the array
- of values
- """
- if in_array.shape[-1] > 1 or in_array.ndim > 1:
- return jnp.stack(
- jnp.meshgrid(
- *(in_array[..., d] for d in range(in_array.shape[-1])), indexing="ij"
- ),
- axis=-1,
- )
- else:
- return in_array
-
-
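A small worked example of the meshgrid construction above (values are hypothetical):

    import jax.numpy as jnp

    xy = jnp.array([[0.0, 10.0], [1.0, 20.0], [2.0, 30.0]])  # shape (B, D) = (3, 2)
    grid = _get_grid(xy)
    # grid.shape == (3, 3, 2) and grid[i, j] == [xy[i, 0], xy[j, 1]]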
- def _get_vmap_in_axes_params(eq_params_batch_dict, params):
- """
- Return the input vmap axes when there is batch(es) of parameters to vmap
- over. The latter are designated by keys in eq_params_batch_dict
- If eq_params_batch_dict (ie no additional parameter batch), we return None
- """
- if eq_params_batch_dict is None:
- return (None,)
- else:
- # We use pytree indexing of vmapped axes and vmap on axis
- # 0 of the eq_parameters for which we have a batch
- # this is for a fine-grained vmaping
- # scheme over the params
- vmap_in_axes_params = (
- {
- "eq_params": {
- k: (0 if k in eq_params_batch_dict.keys() else None)
- for k in params["eq_params"].keys()
- },
- "nn_params": None,
- },
- )
- return vmap_in_axes_params
-
-
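To illustrate the in_axes pytree returned above (keys are hypothetical), which is meant to be appended to the positional in_axes of a jax.vmap over batched equation parameters:

    import jax.numpy as jnp

    params = {"nn_params": None, "eq_params": {"D": jnp.zeros((16,)), "r": jnp.array(1.0)}}
    eq_params_batch_dict = {"D": jnp.zeros((16,))}   # only "D" carries a batch axis
    axes = _get_vmap_in_axes_params(eq_params_batch_dict, params)
    # axes == ({"eq_params": {"D": 0, "r": None}, "nn_params": None},)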
- def _check_user_func_return(r, shape):
- """
- Correctly handles the result from a user defined function (eg a boundary
- condition) to get the correct broadcast
- """
- if isinstance(r, int) or isinstance(r, float):
- # if we have a scalar cast it to float
- return float(r)
- if r.shape == () or len(r.shape) == 1:
- # if we have a scalar (or a vector, but no batch dim) inside an array
- return r.astype(float)
- else:
- # if we have an array of the shape of the batch dimension(s) check that
- # we have the correct broadcast
- # the reshape below avoids a missing (1,) ending dimension
- # depending on how the user has coded the inital function
- return r.reshape(shape)
-
-
- def alternate_optax_solver(
- steps, parameters_set1, parameters_set2, lr_set1, lr_set2, label_fn=None
- ):
- """
- This function creates an optax optimizer that alternates the optimization
- between two set of parameters (ie. when some parameters are update to a
- given learning rates, others are not updated (learning rate = 0)
- The optimizers are scaled by adam parameters.
-
- __Note:__ The alternating pattern relies on
- `optax.piecewise_constant_schedule` which __multiplies__ learning rates of
- previous steps (current included) to set the new learning rate. Hence, our
- strategy used here is to relying on potentially cancelling power of tens to
- create the alternating scheme.
-
- Parameters
- ----------
- steps
- An array which describes the epochis number at which we alternate the
- optimization: the parameter_set that is being updated now stops
- updating, the other parameter_set starts updating.
- __Note:__ The step 0 should not be included
- parameters_set1
- A list of leaf level keys which must be found in the general `params` dict. The
- parameters in this `set1` will be the parameters which are updated
- first in the alternating scheme.
- parameters_set2
- A list of leaf level keys which must be found in the general `params` dict. The
- parameters in this `set2` will be the parameters which are not updated
- first in the alternating scheme.
- lr_set1
- A float. The learning rate of updates for set1.
- lr_set2
- A float. The learning rate of updates for set2.
- label_fn
- The same function as the label_fn function passed in an optax
- `multi_transform`
- [https://optax.readthedocs.io/en/latest/api.html#optax.multi_transform](see
- here)
- Default None, ie, we already internally provide the default one (as
- proposed in the optax documentation) which may suit many use cases
-
- Returns
- -------
- tx
- The optax optimizer object
- """
-
- def map_nested_fn(fn):
- """
- Recursively apply `fn` to the key-value pairs of a nested dict
- We follow the example from
- https://optax.readthedocs.io/en/latest/api.html#optax.multi_transform
- for different learning rates
- """
-
- def map_fn(nested_dict):
- return {
- k: (map_fn(v) if isinstance(v, dict) else fn(k, v))
- for k, v in nested_dict.items()
- }
-
- return map_fn
-
- label_fn = map_nested_fn(lambda k, _: k)
-
- power_to_0 = 1e-25 # power of ten used to force a learning rate to 0
- power_to_lr = 1 / power_to_0 # power of ten used to force a learning rate to lr
- nn_params_scheduler = optax.piecewise_constant_schedule(
- init_value=lr_set1,
- boundaries_and_scales={
- k: (
- power_to_0
- if even_odd % 2 == 0 # set lr to 0 eg if even_odd is even ie at
- # first step
- else power_to_lr
- )
- for even_odd, k in enumerate(steps)
- },
- )
- eq_params_scheduler = optax.piecewise_constant_schedule(
- init_value=power_to_0 * lr_set2, # so normal learning rate is 1e-3
- boundaries_and_scales={
- k: (power_to_lr if even_odd % 2 == 0 else power_to_0)
- for even_odd, k in enumerate(steps)
- },
- )
-
- # the scheduler for set1 is called nn_chain because we usually start by
- # updating the NN parameters
- nn_chain = optax.chain(
- optax.scale_by_adam(),
- optax.scale_by_schedule(nn_params_scheduler),
- optax.scale(-1.0),
- )
- eq_chain = optax.chain(
- optax.scale_by_adam(),
- optax.scale_by_schedule(eq_params_scheduler),
- optax.scale(-1.0),
- )
- dict_params_set1 = {p: nn_chain for p in parameters_set1}
- dict_params_set2 = {p: eq_chain for p in parameters_set2}
- tx = optax.multi_transform(
- {**dict_params_set1, **dict_params_set2},
- label_fn,
- )
-
- return tx
-
-
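A hedged sketch of the alternating scheme documented above; the params dict, its leaf-level keys and the learning rates are hypothetical, and the default label_fn labels every leaf by its key, so the two sets must jointly cover all leaf-level keys:

    import jax.numpy as jnp

    params = {"nn_params": jnp.zeros((5,)), "eq_params": {"D": jnp.array(1.0), "r": jnp.array(0.5)}}
    steps = [1000, 2000, 3000]           # epochs where the optimized set switches (step 0 excluded)
    tx = alternate_optax_solver(
        steps,
        parameters_set1=["nn_params"],   # updated during the first 1000 epochs
        parameters_set2=["D", "r"],      # frozen until epoch 1000, then updated, and so on
        lr_set1=1e-3,
        lr_set2=1e-3,
    )
    opt_state = tx.init(params)          # used like any other optax GradientTransformation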
- def euler_maruyama_density(t, x, s, y, params, Tmax=1):
- eps = 1e-6
- delta = jnp.abs(t - s) * Tmax
- mu = params["alpha_sde"] * (params["mu_sde"] - y) * delta
- var = params["sigma_sde"] ** 2 * delta
- return (
- 1 / jnp.sqrt(2 * jnp.pi * var) * jnp.exp(-0.5 * ((x - y) - mu) ** 2 / var) + eps
- )
-
-
- def log_euler_maruyama_density(t, x, s, y, params):
- eps = 1e-6
- delta = jnp.abs(t - s)
- mu = params["alpha_sde"] * (params["mu_sde"] - y) * delta
- logvar = params["logvar_sde"]
- return (
- -0.5
- * (jnp.log(2 * jnp.pi * delta) + logvar + ((x - y) - mu) ** 2 / jnp.exp(logvar))
- + eps
- )
-
-
- def euler_maruyama(x0, alpha, mu, sigma, T, N):
- """
- Simulate 1D diffusion process with simple parametrization using the Euler
- Maruyama method in the interval [0, T]
- """
- path = [np.array([x0])]
-
- time_steps, step_size = np.linspace(0, T, N, retstep=True)
- for i in time_steps[1:]:
- path.append(
- path[-1]
- + step_size * (alpha * (mu - path[-1]))
- + sigma * np.random.normal(loc=0.0, scale=np.sqrt(step_size))
- )
-
- return time_steps, np.stack(path)
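Finally, a brief sketch of the simulator's intended use, an Ornstein-Uhlenbeck-type path with illustrative parameter values:

    import numpy as np

    # dX_t = alpha * (mu - X_t) dt + sigma dW_t on [0, T], with N equally spaced points
    times, path = euler_maruyama(x0=0.0, alpha=1.5, mu=2.0, sigma=0.3, T=1.0, N=100)
    # times.shape == (100,), path.shape == (100, 1)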