jinns 0.9.0__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. jinns/__init__.py +2 -0
  2. jinns/data/_Batchs.py +27 -0
  3. jinns/data/_DataGenerators.py +904 -1203
  4. jinns/data/__init__.py +4 -8
  5. jinns/experimental/__init__.py +0 -2
  6. jinns/experimental/_diffrax_solver.py +5 -5
  7. jinns/loss/_DynamicLoss.py +282 -305
  8. jinns/loss/_DynamicLossAbstract.py +322 -167
  9. jinns/loss/_LossODE.py +324 -322
  10. jinns/loss/_LossPDE.py +652 -1027
  11. jinns/loss/__init__.py +21 -5
  12. jinns/loss/_boundary_conditions.py +87 -41
  13. jinns/loss/{_Losses.py → _loss_utils.py} +101 -45
  14. jinns/loss/_loss_weights.py +59 -0
  15. jinns/loss/_operators.py +78 -72
  16. jinns/parameters/__init__.py +6 -0
  17. jinns/parameters/_derivative_keys.py +521 -0
  18. jinns/parameters/_params.py +115 -0
  19. jinns/plot/__init__.py +5 -0
  20. jinns/{data/_display.py → plot/_plot.py} +98 -75
  21. jinns/solver/_rar.py +183 -39
  22. jinns/solver/_solve.py +151 -124
  23. jinns/utils/__init__.py +3 -9
  24. jinns/utils/_containers.py +37 -44
  25. jinns/utils/_hyperpinn.py +224 -119
  26. jinns/utils/_pinn.py +183 -111
  27. jinns/utils/_save_load.py +121 -56
  28. jinns/utils/_spinn.py +113 -86
  29. jinns/utils/_types.py +64 -0
  30. jinns/utils/_utils.py +6 -160
  31. jinns/validation/_validation.py +48 -140
  32. jinns-1.1.0.dist-info/AUTHORS +2 -0
  33. {jinns-0.9.0.dist-info → jinns-1.1.0.dist-info}/METADATA +5 -4
  34. jinns-1.1.0.dist-info/RECORD +39 -0
  35. {jinns-0.9.0.dist-info → jinns-1.1.0.dist-info}/WHEEL +1 -1
  36. jinns/experimental/_sinuspinn.py +0 -135
  37. jinns/experimental/_spectralpinn.py +0 -87
  38. jinns/solver/_seq2seq.py +0 -157
  39. jinns/utils/_optim.py +0 -147
  40. jinns/utils/_utils_uspinn.py +0 -727
  41. jinns-0.9.0.dist-info/RECORD +0 -36
  42. {jinns-0.9.0.dist-info → jinns-1.1.0.dist-info}/LICENSE +0 -0
  43. {jinns-0.9.0.dist-info → jinns-1.1.0.dist-info}/top_level.txt +0 -0
jinns/parameters/_derivative_keys.py ADDED
@@ -0,0 +1,521 @@
1
+ """
2
+ Formalize the data structure for the derivative keys
3
+ """
4
+
5
+ from functools import partial
6
+ from dataclasses import fields, InitVar
7
+ from typing import Literal
8
+ import jax
9
+ import equinox as eqx
10
+
11
+ from jinns.parameters._params import Params, ParamsDict
12
+
13
+
14
def _get_masked_parameters(
    derivative_mask_str: str, params: Params | ParamsDict
) -> Params | ParamsDict:
    """
    Build a boolean mask with the same PyTree structure as `params`.

    A `True` node marks a parameter with respect to which the loss term is
    differentiated; a `False` node marks a parameter that must not be.

    Parameters
    ----------
    derivative_mask_str
        One of `"both"`, `"nn_params"` or `"eq_params"`.
    params
        The `Params` or `ParamsDict` object of the problem. It provides the
        PyTree structure of `eq_params`.

    Returns
    -------
    Params | ParamsDict
        The mask, of the same type as `params`.

    Raises
    ------
    ValueError
        If `params` is neither a `Params` nor a `ParamsDict`, or if
        `derivative_mask_str` is not one of the accepted strings.
    """
    if isinstance(params, Params):
        # Start from a mask with True everywhere. Any eqx.Module found inside
        # `params` (other than `params` itself) is treated as a single leaf,
        # so e.g. an eqx.Module `nn_params` gets one boolean. More
        # granularity could be imagined here in the future.
        diff_params = jax.tree.map(
            lambda x: True,
            params,
            is_leaf=lambda x: isinstance(x, eqx.Module) and not isinstance(x, Params),
        )
    elif isinstance(params, ParamsDict):
        # Rebuilt by hand so that nn_params is not traversed: a single
        # boolean controls the whole of nn_params.
        diff_params = ParamsDict(
            nn_params=True, eq_params=jax.tree.map(lambda x: True, params.eq_params)
        )
    else:
        raise ValueError(
            f"Bad value for params. Got {type(params)}, expected Params "
            " or ParamsDict"
        )

    # From here on, the logic is the same for Params and ParamsDict: switch
    # off the part of the mask we do not want to differentiate with respect to.
    if derivative_mask_str == "both":
        return diff_params
    if derivative_mask_str == "eq_params":
        return eqx.tree_at(lambda p: p.nn_params, diff_params, False)
    if derivative_mask_str == "nn_params":
        return eqx.tree_at(
            lambda p: p.eq_params,
            diff_params,
            jax.tree.map(lambda x: False, params.eq_params),
        )
    raise ValueError(
        "Bad value for DerivativeKeys. Got "
        f'{derivative_mask_str}, expected "both", "nn_params" or '
        ' "eq_params"'
    )
72
+
73
+
74
class DerivativeKeysODE(eqx.Module):
    """
    A class that specifies with respect to which parameter(s) each term of the
    loss is differentiated. For example, you can specify that the
    [`DynamicLoss`][jinns.loss.DynamicLoss] should be differentiated both with
    respect to the neural network parameters *and* the equation parameters, or only some of them.

    To do so, user can either use strings or a `Params` object
    with PyTree structure matching the parameters of the problem at
    hand, and booleans indicating if gradient is to be taken or not. Internally,
    a `jax.lax.stop_gradient()` is appropriately set to each `False` node when
    computing each loss term.

    !!! note

        1. For unspecified loss term, the default is to differentiate with
        respect to `"nn_params"` only.
        2. No granularity inside `Params.nn_params` is currently supported.
        3. The main Params or ParamsDict object of the problem is mandatory
        when initializing via `from_str()`, and when some terms are left
        unspecified (None) in the default constructor.

    A typical specification is of the form:
    ```python
    Params(
        nn_params=True | False,
        eq_params={
            "alpha":True | False,
            "beta":True | False,
            ...
        }
    )
    ```

    Parameters
    ----------
    dyn_loss : Params | ParamsDict | None, default=None
        Tell wrt which node of `Params` we will differentiate the
        dynamic loss. To do so, the fields of `Params` contain True (if
        differentiation) or False (if no differentiation).
    observations : Params | ParamsDict | None, default=None
        Tell wrt which parameters among Params we will differentiate the
        observation loss. To do so, the fields of Params contain True (if
        differentiation) or False (if no differentiation).
    initial_condition : Params | ParamsDict | None, default=None
        Tell wrt which parameters among Params we will differentiate the
        initial condition loss. To do so, the fields of Params contain True (if
        differentiation) or False (if no differentiation).
    params : InitVar[Params | ParamsDict], default=None
        The main Params object of the problem. It is required
        if some terms are unspecified (None). This is because, jinns cannot
        infer the content of `Params.eq_params`.

    Raises
    ------
    ValueError
        If a term is None and `params` is not given, or if `params` is
        neither a `Params` nor a `ParamsDict`.
    """

    dyn_loss: Params | ParamsDict | None = eqx.field(kw_only=True, default=None)
    observations: Params | ParamsDict | None = eqx.field(kw_only=True, default=None)
    initial_condition: Params | ParamsDict | None = eqx.field(
        kw_only=True, default=None
    )

    params: InitVar[Params | ParamsDict] = eqx.field(kw_only=True, default=None)

    def __post_init__(self, params=None):
        # Every unspecified term defaults to the "nn_params" mask. Building
        # that mask needs the structure of `params`, so check for it
        # explicitly: _get_masked_parameters(..., None) would otherwise raise a
        # less informative "Bad value for params" error.
        for name in ("dyn_loss", "observations", "initial_condition"):
            if getattr(self, name) is not None:
                continue
            if params is None:
                raise ValueError(f"self.{name} is None, hence params should be passed")
            setattr(self, name, _get_masked_parameters("nn_params", params))

    @classmethod
    def from_str(
        cls,
        params: Params | ParamsDict,
        dyn_loss: (
            Literal["nn_params", "eq_params", "both"] | Params | ParamsDict | None
        ) = "nn_params",
        observations: (
            Literal["nn_params", "eq_params", "both"] | Params | ParamsDict | None
        ) = "nn_params",
        initial_condition: (
            Literal["nn_params", "eq_params", "both"] | Params | ParamsDict | None
        ) = "nn_params",
    ):
        """
        Construct the DerivativeKeysODE from strings. For each term of the
        loss, specify whether to differentiate wrt the neural network
        parameters, the equation parameters or both. The `Params` object, which
        contains the actual array of parameters must be passed to
        construct the fields with the appropriate PyTree structure.

        !!! note
            You can mix strings and `Params` if you need granularity. A term
            passed as `None` falls back to the `"nn_params"` default.

        Parameters
        ----------
        params
            The actual Params or ParamsDict object of the problem.
        dyn_loss
            Tell wrt which parameters among `"nn_params"`, `"eq_params"` or
            `"both"` we will differentiate the dynamic loss. Default is
            `"nn_params"`. Specifying a Params or ParamsDict is also possible.
        observations
            Tell wrt which parameters among `"nn_params"`, `"eq_params"` or
            `"both"` we will differentiate the observations. Default is
            `"nn_params"`. Specifying a Params or ParamsDict is also possible.
        initial_condition
            Tell wrt which parameters among `"nn_params"`, `"eq_params"` or
            `"both"` we will differentiate the initial condition. Default is
            `"nn_params"`. Specifying a Params or ParamsDict is also possible.
        """

        def _to_mask(spec):
            # Strings are turned into masks; Params / ParamsDict are kept as is.
            if isinstance(spec, str):
                return _get_masked_parameters(spec, params)
            return spec

        return cls(
            dyn_loss=_to_mask(dyn_loss),
            observations=_to_mask(observations),
            initial_condition=_to_mask(initial_condition),
            params=params,
        )
215
+
216
+
217
class DerivativeKeysPDEStatio(eqx.Module):
    """
    See [jinns.parameters.DerivativeKeysODE][].

    Parameters
    ----------
    dyn_loss : Params | ParamsDict | None, default=None
        Tell wrt which parameters among Params we will differentiate the
        dynamic loss. To do so, the fields of Params contain True (if
        differentiation) or False (if no differentiation).
    observations : Params | ParamsDict | None, default=None
        Tell wrt which parameters among Params we will differentiate the
        observation loss. To do so, the fields of Params contain True (if
        differentiation) or False (if no differentiation).
    boundary_loss : Params | ParamsDict | None, default=None
        Tell wrt which parameters among Params we will differentiate the
        boundary loss. To do so, the fields of Params contain True (if
        differentiation) or False (if no differentiation).
    norm_loss : Params | ParamsDict | None, default=None
        Tell wrt which parameters among Params we will differentiate the
        normalization loss. To do so, the fields of Params contain True (if
        differentiation) or False (if no differentiation).
    params : InitVar[Params | ParamsDict], default=None
        The main Params object of the problem. It is required
        if some terms are unspecified (None). This is because, jinns cannot infer the
        content of `Params.eq_params`.

    Raises
    ------
    ValueError
        If a term is None and `params` is not given, or if `params` is
        neither a `Params` nor a `ParamsDict`.
    """

    dyn_loss: Params | ParamsDict | None = eqx.field(kw_only=True, default=None)
    observations: Params | ParamsDict | None = eqx.field(kw_only=True, default=None)
    boundary_loss: Params | ParamsDict | None = eqx.field(kw_only=True, default=None)
    norm_loss: Params | ParamsDict | None = eqx.field(kw_only=True, default=None)

    params: InitVar[Params | ParamsDict] = eqx.field(kw_only=True, default=None)

    def __post_init__(self, params=None):
        # Every unspecified term defaults to the "nn_params" mask. Building
        # that mask needs the structure of `params`, so check for it
        # explicitly: _get_masked_parameters(..., None) would otherwise raise a
        # less informative "Bad value for params" error.
        for name in ("dyn_loss", "observations", "boundary_loss", "norm_loss"):
            if getattr(self, name) is not None:
                continue
            if params is None:
                raise ValueError(f"self.{name} is None, hence params should be passed")
            setattr(self, name, _get_masked_parameters("nn_params", params))

    @classmethod
    def from_str(
        cls,
        params: Params | ParamsDict,
        dyn_loss: (
            Literal["nn_params", "eq_params", "both"] | Params | ParamsDict | None
        ) = "nn_params",
        observations: (
            Literal["nn_params", "eq_params", "both"] | Params | ParamsDict | None
        ) = "nn_params",
        boundary_loss: (
            Literal["nn_params", "eq_params", "both"] | Params | ParamsDict | None
        ) = "nn_params",
        norm_loss: (
            Literal["nn_params", "eq_params", "both"] | Params | ParamsDict | None
        ) = "nn_params",
    ):
        """
        See [jinns.parameters.DerivativeKeysODE.from_str][].

        A term passed as `None` falls back to the `"nn_params"` default.

        Parameters
        ----------
        params
            The actual Params or ParamsDict object of the problem.
        dyn_loss
            Tell wrt which parameters among `"nn_params"`, `"eq_params"` or
            `"both"` we will differentiate the dynamic loss. Default is
            `"nn_params"`. Specifying a Params or ParamsDict is also possible.
        observations
            Tell wrt which parameters among `"nn_params"`, `"eq_params"` or
            `"both"` we will differentiate the observations. Default is
            `"nn_params"`. Specifying a Params or ParamsDict is also possible.
        boundary_loss
            Tell wrt which parameters among `"nn_params"`, `"eq_params"` or
            `"both"` we will differentiate the boundary loss. Default is
            `"nn_params"`. Specifying a Params or ParamsDict is also possible.
        norm_loss
            Tell wrt which parameters among `"nn_params"`, `"eq_params"` or
            `"both"` we will differentiate the normalization loss. Default is
            `"nn_params"`. Specifying a Params or ParamsDict is also possible.
        """

        def _to_mask(spec):
            # Strings are turned into masks; Params / ParamsDict are kept as is.
            if isinstance(spec, str):
                return _get_masked_parameters(spec, params)
            return spec

        return cls(
            dyn_loss=_to_mask(dyn_loss),
            observations=_to_mask(observations),
            boundary_loss=_to_mask(boundary_loss),
            norm_loss=_to_mask(norm_loss),
            params=params,
        )
343
+
344
+
345
class DerivativeKeysPDENonStatio(DerivativeKeysPDEStatio):
    """
    See [jinns.parameters.DerivativeKeysODE][].

    Parameters
    ----------
    dyn_loss : Params | ParamsDict | None, default=None
        Tell wrt which parameters among Params we will differentiate the
        dynamic loss. To do so, the fields of Params contain True (if
        differentiation) or False (if no differentiation).
    observations : Params | ParamsDict | None, default=None
        Tell wrt which parameters among Params we will differentiate the
        observation loss. To do so, the fields of Params contain True (if
        differentiation) or False (if no differentiation).
    boundary_loss : Params | ParamsDict | None, default=None
        Tell wrt which parameters among Params we will differentiate the
        boundary loss. To do so, the fields of Params contain True (if
        differentiation) or False (if no differentiation).
    norm_loss : Params | ParamsDict | None, default=None
        Tell wrt which parameters among Params we will differentiate the
        normalization loss. To do so, the fields of Params contain True (if
        differentiation) or False (if no differentiation).
    initial_condition : Params | ParamsDict | None, default=None
        Tell wrt which parameters among Params we will differentiate the
        initial_condition loss. To do so, the fields of Params contain True (if
        differentiation) or False (if no differentiation).
    params : InitVar[Params | ParamsDict], default=None
        The main Params object of the problem. It is required
        if some terms are unspecified (None). This is because, jinns cannot infer the
        content of `Params.eq_params`.

    Raises
    ------
    ValueError
        If a term is None and `params` is not given, or if `params` is
        neither a `Params` nor a `ParamsDict`.
    """

    initial_condition: Params | ParamsDict | None = eqx.field(
        kw_only=True, default=None
    )

    def __post_init__(self, params=None):
        # The terms shared with the stationary case are filled by the parent.
        super().__post_init__(params=params)
        if self.initial_condition is None:
            # Explicit check: _get_masked_parameters(..., None) would otherwise
            # raise a less informative "Bad value for params" error.
            if params is None:
                raise ValueError(
                    "self.initial_condition is None, hence params should be passed"
                )
            self.initial_condition = _get_masked_parameters("nn_params", params)

    @classmethod
    def from_str(
        cls,
        params: Params | ParamsDict,
        dyn_loss: (
            Literal["nn_params", "eq_params", "both"] | Params | ParamsDict | None
        ) = "nn_params",
        observations: (
            Literal["nn_params", "eq_params", "both"] | Params | ParamsDict | None
        ) = "nn_params",
        boundary_loss: (
            Literal["nn_params", "eq_params", "both"] | Params | ParamsDict | None
        ) = "nn_params",
        norm_loss: (
            Literal["nn_params", "eq_params", "both"] | Params | ParamsDict | None
        ) = "nn_params",
        initial_condition: (
            Literal["nn_params", "eq_params", "both"] | Params | ParamsDict | None
        ) = "nn_params",
    ):
        """
        See [jinns.parameters.DerivativeKeysODE.from_str][].

        A term passed as `None` falls back to the `"nn_params"` default.

        Parameters
        ----------
        params
            The actual Params | ParamsDict object of the problem.
        dyn_loss
            Tell wrt which parameters among `"nn_params"`, `"eq_params"` or
            `"both"` we will differentiate the dynamic loss. Default is
            `"nn_params"`. Specifying a Params or ParamsDict is also possible.
        observations
            Tell wrt which parameters among `"nn_params"`, `"eq_params"` or
            `"both"` we will differentiate the observations. Default is
            `"nn_params"`. Specifying a Params or ParamsDict is also possible.
        boundary_loss
            Tell wrt which parameters among `"nn_params"`, `"eq_params"` or
            `"both"` we will differentiate the boundary loss. Default is
            `"nn_params"`. Specifying a Params or ParamsDict is also possible.
        norm_loss
            Tell wrt which parameters among `"nn_params"`, `"eq_params"` or
            `"both"` we will differentiate the normalization loss. Default is
            `"nn_params"`. Specifying a Params or ParamsDict is also possible.
        initial_condition
            Tell wrt which parameters among `"nn_params"`, `"eq_params"` or
            `"both"` we will differentiate the initial_condition loss. Default is
            `"nn_params"`. Specifying a Params or ParamsDict is also possible.
        """

        def _to_mask(spec):
            # Strings are turned into masks; Params / ParamsDict are kept as is.
            if isinstance(spec, str):
                return _get_masked_parameters(spec, params)
            return spec

        return cls(
            dyn_loss=_to_mask(dyn_loss),
            observations=_to_mask(observations),
            boundary_loss=_to_mask(boundary_loss),
            norm_loss=_to_mask(norm_loss),
            initial_condition=_to_mask(initial_condition),
            params=params,
        )
466
+
467
+
468
def _set_derivatives(params, derivative_keys):
    """
    Return a copy of `params` in which every node whose mask value is `False`
    is wrapped in `jax.lax.stop_gradient`; `True` nodes are left untouched.

    **Note:** **No granularity inside `nn_params` is currently supported.**
    Typical masks are of the form
    `Params(nn_params=True | False, eq_params={"alpha": True | False,
    "beta": True | False})` or
    `ParamsDict(nn_params=True | False, eq_params={"0": {"alpha": True | False,
    "beta": True | False}, "1": {"alpha": True | False, "beta": True | False}})`.

    Parameters
    ----------
    params
        The `Params` or `ParamsDict` object of the problem.
    derivative_keys
        A boolean mask with the same structure as `params` (presumably one
        field of a `DerivativeKeys*` object, as built by
        `_get_masked_parameters`). Its `nn_params` and `eq_params` are read.

    Returns
    -------
    An object of the same type as `params` with the stop-gradients set.
    """

    def _keep_or_stop(p, d):
        # lax.cond (rather than a Python `if`) keeps this valid when `d` is a
        # traced boolean.
        return jax.lax.cond(d, lambda x: x, jax.lax.stop_gradient, p)

    if isinstance(params, ParamsDict):
        # The ParamsDict is rebuilt by hand because we do not want to traverse
        # nn_params: a single boolean controls the whole of it, for now.
        return ParamsDict(
            nn_params=_keep_or_stop(params.nn_params, derivative_keys.nn_params),
            eq_params=jax.tree.map(
                _keep_or_stop, params.eq_params, derivative_keys.eq_params
            ),
        )
    return jax.tree.map(
        _keep_or_stop,
        params,
        derivative_keys,
        # Any eqx.Module inside `params` (other than `params` itself) is a
        # single leaf, so nn_params is not traversed when it is an eqx.Module.
        is_leaf=lambda x: isinstance(x, eqx.Module) and not isinstance(x, Params),
    )
jinns/parameters/_params.py ADDED
@@ -0,0 +1,115 @@
1
+ """
2
+ Formalize the data structure for the parameters
3
+ """
4
+
5
+ import jax
6
+ import equinox as eqx
7
+ from typing import Dict
8
+ from jaxtyping import Array, PyTree
9
+
10
+
11
class Params(eqx.Module):
    """
    The equinox module for the parameters.

    Parameters
    ----------
    nn_params : PyTree, default=None
        A PyTree of the non-static part of the PINN eqx.Module, i.e., the
        parameters of the PINN.
    eq_params : Dict[str, Array] | None, default=None
        A dictionary of the equation parameters. Keys are the parameter names,
        values are their corresponding values.
    """

    # None is itself a valid (empty) PyTree, so no Optional is needed here.
    nn_params: PyTree = eqx.field(kw_only=True, default=None)
    eq_params: Dict[str, Array] | None = eqx.field(kw_only=True, default=None)
27
+
28
+
29
class ParamsDict(eqx.Module):
    """
    The equinox module for a dictionary of parameters with different keys
    corresponding to different equations.

    Parameters
    ----------
    nn_params : Dict[str, PyTree] | None, default=None
        The neural networks' parameters, keyed by network. Most of the time,
        each value will be the Array part of an `eqx.Module` obtained by
        `eqx.partition(module, eqx.is_inexact_array)`.
    eq_params : Dict[str, Array] | None, default=None
        A dictionary of the equation parameters. Dict keys are the parameter
        names as defined in your custom loss. It may instead be keyed like
        `nn_params` (see `extract_params`).
    """

    nn_params: Dict[str, PyTree] | None = eqx.field(kw_only=True, default=None)
    eq_params: Dict[str, Array] | None = eqx.field(kw_only=True, default=None)

    def extract_params(self, nn_key: str) -> Params:
        """
        Extract the `nn_params` and `eq_params` corresponding to `nn_key` and
        return them in the form of a `Params` object.

        If `eq_params` has an entry for `nn_key`, only that entry is used;
        otherwise the whole `eq_params` is shared.

        Raises
        ------
        KeyError
            If `nn_key` is not a key of `nn_params`.
        """
        # Looked up outside the try block: a missing network key is an error,
        # not a hint that eq_params is shared between networks.
        nn_params = self.nn_params[nn_key]
        try:
            eq_params = self.eq_params[nn_key]
        except (KeyError, IndexError):
            eq_params = self.eq_params
        return Params(nn_params=nn_params, eq_params=eq_params)
62
+
63
+
64
def _update_eq_params_dict(
    params: Params, param_batch_dict: Dict[str, Array]
) -> Params:
    """
    Return a copy of `params` whose `eq_params` entries are replaced by the
    values found in `param_batch_dict`.

    Entries of `params.eq_params` without a counterpart in `param_batch_dict`
    keep their current value.
    """

    def _pick(current, batched):
        # None marks "no batch for this key": keep the current value.
        return current if batched is None else batched

    # Pad the batch dict with None for every key it lacks, so that its PyTree
    # structure lines up with params.eq_params for the tree map below.
    untouched = {k: None for k in params.eq_params if k not in param_batch_dict}
    padded_batch = param_batch_dict | untouched

    new_eq_params = jax.tree.map(_pick, params.eq_params, padded_batch)
    return eqx.tree_at(lambda p: p.eq_params, params, new_eq_params)
89
+
90
+
91
def _get_vmap_in_axes_params(
    eq_params_batch_dict: Dict[str, Array] | None, params: Params | ParamsDict
) -> tuple[Params | ParamsDict | None]:
    """
    Return the `in_axes` entry (a 1-tuple) to use for the parameters when
    vmapping over batch(es) of equation parameters.

    Parameters
    ----------
    eq_params_batch_dict
        Batches of equation parameters keyed by parameter name, or None when
        there is no additional parameter batch.
    params
        The parameters of the problem. Only its type and the keys of its
        `eq_params` are used.

    Returns
    -------
    tuple
        `(None,)` when `eq_params_batch_dict` is None. Otherwise, a 1-tuple
        holding an object of the same type as `params` whose `nn_params` is
        None (not vmapped) and whose `eq_params` maps each key to 0 (vmap on
        axis 0) if that key has a batch, and to None otherwise.
    """
    if eq_params_batch_dict is None:
        return (None,)
    # PyTree-indexed in_axes give a fine-grained vmapping scheme: only the
    # equation parameters for which we have a batch are vmapped.
    return (
        type(params)(
            nn_params=None,
            eq_params={
                k: (0 if k in eq_params_batch_dict else None)
                for k in params.eq_params
            },
        ),
    )
jinns/plot/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ from jinns.plot._plot import (
2
+ plot2d,
3
+ plot1d_slice,
4
+ plot1d_image,
5
+ )