jinns 1.3.0__py3-none-any.whl → 1.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55)
  1. jinns/__init__.py +17 -7
  2. jinns/data/_AbstractDataGenerator.py +19 -0
  3. jinns/data/_Batchs.py +31 -12
  4. jinns/data/_CubicMeshPDENonStatio.py +431 -0
  5. jinns/data/_CubicMeshPDEStatio.py +464 -0
  6. jinns/data/_DataGeneratorODE.py +187 -0
  7. jinns/data/_DataGeneratorObservations.py +189 -0
  8. jinns/data/_DataGeneratorParameter.py +206 -0
  9. jinns/data/__init__.py +19 -9
  10. jinns/data/_utils.py +149 -0
  11. jinns/experimental/__init__.py +9 -0
  12. jinns/loss/_DynamicLoss.py +114 -187
  13. jinns/loss/_DynamicLossAbstract.py +74 -69
  14. jinns/loss/_LossODE.py +132 -348
  15. jinns/loss/_LossPDE.py +262 -549
  16. jinns/loss/__init__.py +32 -6
  17. jinns/loss/_abstract_loss.py +128 -0
  18. jinns/loss/_boundary_conditions.py +20 -19
  19. jinns/loss/_loss_components.py +43 -0
  20. jinns/loss/_loss_utils.py +85 -179
  21. jinns/loss/_loss_weight_updates.py +202 -0
  22. jinns/loss/_loss_weights.py +64 -40
  23. jinns/loss/_operators.py +84 -74
  24. jinns/nn/__init__.py +15 -0
  25. jinns/nn/_abstract_pinn.py +22 -0
  26. jinns/nn/_hyperpinn.py +94 -57
  27. jinns/nn/_mlp.py +50 -25
  28. jinns/nn/_pinn.py +33 -19
  29. jinns/nn/_ppinn.py +70 -34
  30. jinns/nn/_save_load.py +21 -51
  31. jinns/nn/_spinn.py +33 -16
  32. jinns/nn/_spinn_mlp.py +28 -22
  33. jinns/nn/_utils.py +38 -0
  34. jinns/parameters/__init__.py +8 -1
  35. jinns/parameters/_derivative_keys.py +116 -177
  36. jinns/parameters/_params.py +18 -46
  37. jinns/plot/__init__.py +2 -0
  38. jinns/plot/_plot.py +35 -34
  39. jinns/solver/_rar.py +80 -63
  40. jinns/solver/_solve.py +207 -92
  41. jinns/solver/_utils.py +4 -6
  42. jinns/utils/__init__.py +2 -0
  43. jinns/utils/_containers.py +16 -10
  44. jinns/utils/_types.py +20 -54
  45. jinns/utils/_utils.py +4 -11
  46. jinns/validation/__init__.py +2 -0
  47. jinns/validation/_validation.py +20 -19
  48. {jinns-1.3.0.dist-info → jinns-1.5.0.dist-info}/METADATA +8 -4
  49. jinns-1.5.0.dist-info/RECORD +55 -0
  50. {jinns-1.3.0.dist-info → jinns-1.5.0.dist-info}/WHEEL +1 -1
  51. jinns/data/_DataGenerators.py +0 -1634
  52. jinns-1.3.0.dist-info/RECORD +0 -44
  53. {jinns-1.3.0.dist-info → jinns-1.5.0.dist-info/licenses}/AUTHORS +0 -0
  54. {jinns-1.3.0.dist-info → jinns-1.5.0.dist-info/licenses}/LICENSE +0 -0
  55. {jinns-1.3.0.dist-info → jinns-1.5.0.dist-info}/top_level.txt +0 -0
jinns/parameters/_derivative_keys.py CHANGED
@@ -2,73 +2,45 @@
 Formalize the data structure for the derivative keys
 """
 
-from functools import partial
-from dataclasses import fields, InitVar
+from dataclasses import InitVar
 from typing import Literal
+from jaxtyping import Array
 import jax
 import equinox as eqx
 
-from jinns.parameters._params import Params, ParamsDict
+from jinns.parameters._params import Params
 
 
 def _get_masked_parameters(
-    derivative_mask_str: str, params: Params | ParamsDict
-) -> Params | ParamsDict:
+    derivative_mask_str: str, params: Params[Array]
+) -> Params[bool]:
     """
     Creates the Params object with True values where we want to differentiate
     """
-    if isinstance(params, Params):
-        # start with a params object with True everywhere. We will update to False
-        # for parameters wrt which we do not want to differentiate the loss
-        diff_params = jax.tree.map(
-            lambda x: True,
-            params,
-            is_leaf=lambda x: isinstance(x, eqx.Module)
-            and not isinstance(x, Params),  # do not traverse nn_params, more
-            # granularity could be imagined here, in the future
-        )
-        if derivative_mask_str == "both":
-            return diff_params
-        if derivative_mask_str == "eq_params":
-            return eqx.tree_at(lambda p: p.nn_params, diff_params, False)
-        if derivative_mask_str == "nn_params":
-            return eqx.tree_at(
-                lambda p: p.eq_params,
-                diff_params,
-                jax.tree.map(lambda x: False, params.eq_params),
-            )
-        raise ValueError(
-            "Bad value for DerivativeKeys. Got "
-            f'{derivative_mask_str}, expected "both", "nn_params" or '
-            ' "eq_params"'
-        )
-    elif isinstance(params, ParamsDict):
-        # do not traverse nn_params, more
+    # start with a params object with True everywhere. We will update to False
+    # for parameters wrt which we do not want to differentiate the loss
+    diff_params = jax.tree.map(
+        lambda x: True,
+        params,
+        is_leaf=lambda x: isinstance(x, eqx.Module)
+        and not isinstance(x, Params),  # do not traverse nn_params, more
         # granularity could be imagined here, in the future
-        diff_params = ParamsDict(
-            nn_params=True, eq_params=jax.tree.map(lambda x: True, params.eq_params)
-        )
-        if derivative_mask_str == "both":
-            return diff_params
-        if derivative_mask_str == "eq_params":
-            return eqx.tree_at(lambda p: p.nn_params, diff_params, False)
-        if derivative_mask_str == "nn_params":
-            return eqx.tree_at(
-                lambda p: p.eq_params,
-                diff_params,
-                jax.tree.map(lambda x: False, params.eq_params),
-            )
-        raise ValueError(
-            "Bad value for DerivativeKeys. Got "
-            f'{derivative_mask_str}, expected "both", "nn_params" or '
-            ' "eq_params"'
-        )
-
-    else:
-        raise ValueError(
-            f"Bad value for params. Got {type(params)}, expected Params "
-            " or ParamsDict"
+    )
+    if derivative_mask_str == "both":
+        return diff_params
+    if derivative_mask_str == "eq_params":
+        return eqx.tree_at(lambda p: p.nn_params, diff_params, False)
+    if derivative_mask_str == "nn_params":
+        return eqx.tree_at(
+            lambda p: p.eq_params,
+            diff_params,
+            jax.tree.map(lambda x: False, params.eq_params),
         )
+    raise ValueError(
+        "Bad value for DerivativeKeys. Got "
+        f'{derivative_mask_str}, expected "both", "nn_params" or '
+        ' "eq_params"'
+    )
 
 
 class DerivativeKeysODE(eqx.Module):
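With `ParamsDict` removed, `_get_masked_parameters` reduces to the single `Params` branch above. A minimal sketch of what it returns, assuming a toy `Params` whose contents (the `"theta"` key, the weight stand-in) are purely illustrative and where the private import path simply mirrors the file touched by this diff:

```python
# Hedged sketch: the Params[bool] mask produced by the new single-branch
# _get_masked_parameters. "theta" is a made-up eq_params name and
# _get_masked_parameters is private API, so treat this as an assumption.
import jax.numpy as jnp
from jinns.parameters import Params
from jinns.parameters._derivative_keys import _get_masked_parameters

params = Params(
    nn_params={"w": jnp.ones((2, 2))},    # stand-in for the PINN's trainable PyTree
    eq_params={"theta": jnp.array(0.5)},  # illustrative equation parameter
)
mask = _get_masked_parameters("nn_params", params)
# mask is a Params[bool]: the nn_params side is left all-True while every
# eq_params leaf is set to False, so only the network is differentiated.
```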
@@ -89,7 +61,7 @@ class DerivativeKeysODE(eqx.Module):
     1. For unspecified loss term, the default is to differentiate with
     respect to `"nn_params"` only.
     2. No granularity inside `Params.nn_params` is currently supported.
-    3. Note that the main Params or ParamsDict object of the problem is mandatory when initializing via `from_str()`.
+    3. Note that the main Params object of the problem is mandatory when initializing via `from_str()`.
 
     A typical specification is of the form:
     ```python
@@ -105,67 +77,69 @@ class DerivativeKeysODE(eqx.Module):
 
     Parameters
     ----------
-    dyn_loss : Params | ParamsDict | None, default=None
+    dyn_loss : Params[bool] | None, default=None
         Tell wrt which node of `Params` we will differentiate the
         dynamic loss. To do so, the fields of `Params` contain True (if
         differentiation) or False (if no differentiation).
-    observations : Params | ParamsDict | None, default=None
+    observations : Params[bool] | None, default=None
         Tell wrt which parameters among Params we will differentiate the
         observation loss. To do so, the fields of Params contain True (if
         differentiation) or False (if no differentiation).
-    initial_condition : Params | ParamsDict | None, default=None
+    initial_condition : Params[bool] | None, default=None
         Tell wrt which parameters among Params we will differentiate the
         initial condition loss. To do so, the fields of Params contain True (if
         differentiation) or False (if no differentiation).
-    params : InitVar[Params | ParamsDict], default=None
+    params : InitVar[Params[Array]], default=None
         The main Params object of the problem. It is required
         if some terms are unspecified (None). This is because jinns cannot
         infer the content of `Params.eq_params`.
     """
 
-    dyn_loss: Params | ParamsDict | None = eqx.field(kw_only=True, default=None)
-    observations: Params | ParamsDict | None = eqx.field(kw_only=True, default=None)
-    initial_condition: Params | ParamsDict | None = eqx.field(
-        kw_only=True, default=None
-    )
-
-    params: InitVar[Params | ParamsDict] = eqx.field(kw_only=True, default=None)
-
-    def __post_init__(self, params=None):
+    dyn_loss: Params[bool] | None = eqx.field(kw_only=True, default=None)
+    observations: Params[bool] | None = eqx.field(kw_only=True, default=None)
+    initial_condition: Params[bool] | None = eqx.field(kw_only=True, default=None)
+
+    params: InitVar[Params[Array] | None] = eqx.field(kw_only=True, default=None)
+
+    def __post_init__(self, params: Params[Array] | None = None):
+        if params is None and (
+            self.dyn_loss is None
+            or self.observations is None
+            or self.initial_condition is None
+        ):
+            raise ValueError(
+                "params cannot be None since at least one loss "
+                "term has an undefined derivative key Params PyTree"
+            )
         if self.dyn_loss is None:
-            try:
-                self.dyn_loss = _get_masked_parameters("nn_params", params)
-            except AttributeError:
-                raise ValueError(
-                    "self.dyn_loss is None, hence params should be " "passed"
-                )
+            if params is None:
+                raise ValueError("self.dyn_loss is None, hence params should be passed")
+            self.dyn_loss = _get_masked_parameters("nn_params", params)
         if self.observations is None:
-            try:
-                self.observations = _get_masked_parameters("nn_params", params)
-            except AttributeError:
+            if params is None:
                 raise ValueError(
-                    "self.observations is None, hence params should be " "passed"
+                    "self.observations is None, hence params should be passed"
                 )
+            self.observations = _get_masked_parameters("nn_params", params)
         if self.initial_condition is None:
-            try:
-                self.initial_condition = _get_masked_parameters("nn_params", params)
-            except AttributeError:
+            if params is None:
                 raise ValueError(
-                    "self.initial_condition is None, hence params should be " "passed"
+                    "self.initial_condition is None, hence params should be passed"
                 )
+            self.initial_condition = _get_masked_parameters("nn_params", params)
 
     @classmethod
     def from_str(
         cls,
-        params: Params | ParamsDict,
+        params: Params[Array],
         dyn_loss: (
-            Literal["nn_params", "eq_params", "both"] | Params | ParamsDict
+            Literal["nn_params", "eq_params", "both"] | Params[bool]
         ) = "nn_params",
         observations: (
-            Literal["nn_params", "eq_params", "both"] | Params | ParamsDict
+            Literal["nn_params", "eq_params", "both"] | Params[bool]
         ) = "nn_params",
         initial_condition: (
-            Literal["nn_params", "eq_params", "both"] | Params | ParamsDict
+            Literal["nn_params", "eq_params", "both"] | Params[bool]
         ) = "nn_params",
     ):
         """
@@ -181,19 +155,19 @@ class DerivativeKeysODE(eqx.Module):
         Parameters
         ----------
         params
-            The actual Params or ParamsDict object of the problem.
+            The actual Params object of the problem.
         dyn_loss
             Tell wrt which parameters among `"nn_params"`, `"eq_params"` or
             `"both"` we will differentiate the dynamic loss. Default is
-            `"nn_params"`. Specifying a Params or ParamsDict is also possible.
+            `"nn_params"`. Specifying a Params is also possible.
         observations
             Tell wrt which parameters among `"nn_params"`, `"eq_params"` or
             `"both"` we will differentiate the observations. Default is
-            `"nn_params"`. Specifying a Params or ParamsDict is also possible.
+            `"nn_params"`. Specifying a Params is also possible.
         initial_condition
             Tell wrt which parameters among `"nn_params"`, `"eq_params"` or
             `"both"` we will differentiate the initial condition. Default is
-            `"nn_params"`. Specifying a Params or ParamsDict is also possible.
+            `"nn_params"`. Specifying a Params is also possible.
         """
         return DerivativeKeysODE(
             dyn_loss=(
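Before the PDE variants, a hedged usage sketch of this API after the change; the `"alpha"` key and the weight stand-in are illustrative, and `DerivativeKeysODE` is imported from the private module touched by this diff (the public re-export path may differ):

```python
# Sketch: building derivative keys for an ODE problem with the new
# string-or-Params[bool] interface of from_str().
import jax.numpy as jnp
from jinns.parameters import Params
from jinns.parameters._derivative_keys import DerivativeKeysODE

params = Params(
    nn_params={"w": jnp.zeros(3)},        # stand-in for real PINN weights
    eq_params={"alpha": jnp.array(1.0)},  # illustrative equation parameter
)

# Differentiate the dynamic loss wrt both network and equation parameters,
# keeping the default ("nn_params" only) for the other loss terms.
keys = DerivativeKeysODE.from_str(params, dyn_loss="both")
```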
@@ -220,78 +194,74 @@ class DerivativeKeysPDEStatio(eqx.Module):
 
     Parameters
     ----------
-    dyn_loss : Params | ParamsDict | None, default=None
+    dyn_loss : Params[bool] | None, default=None
         Tell wrt which parameters among Params we will differentiate the
         dynamic loss. To do so, the fields of Params contain True (if
         differentiation) or False (if no differentiation).
-    observations : Params | ParamsDict | None, default=None
+    observations : Params[bool] | None, default=None
         Tell wrt which parameters among Params we will differentiate the
         observation loss. To do so, the fields of Params contain True (if
         differentiation) or False (if no differentiation).
-    boundary_loss : Params | ParamsDict | None, default=None
+    boundary_loss : Params[bool] | None, default=None
         Tell wrt which parameters among Params we will differentiate the
         boundary loss. To do so, the fields of Params contain True (if
         differentiation) or False (if no differentiation).
-    norm_loss : Params | ParamsDict | None, default=None
+    norm_loss : Params[bool] | None, default=None
         Tell wrt which parameters among Params we will differentiate the
         normalization loss. To do so, the fields of Params contain True (if
         differentiation) or False (if no differentiation).
-    params : InitVar[Params | ParamsDict], default=None
+    params : InitVar[Params[Array]], default=None
         The main Params object of the problem. It is required
         if some terms are unspecified (None). This is because jinns cannot infer the
         content of `Params.eq_params`.
     """
 
-    dyn_loss: Params | ParamsDict | None = eqx.field(kw_only=True, default=None)
-    observations: Params | ParamsDict | None = eqx.field(kw_only=True, default=None)
-    boundary_loss: Params | ParamsDict | None = eqx.field(kw_only=True, default=None)
-    norm_loss: Params | ParamsDict | None = eqx.field(kw_only=True, default=None)
+    dyn_loss: Params[bool] | None = eqx.field(kw_only=True, default=None)
+    observations: Params[bool] | None = eqx.field(kw_only=True, default=None)
+    boundary_loss: Params[bool] | None = eqx.field(kw_only=True, default=None)
+    norm_loss: Params[bool] | None = eqx.field(kw_only=True, default=None)
 
-    params: InitVar[Params | ParamsDict] = eqx.field(kw_only=True, default=None)
+    params: InitVar[Params[Array] | None] = eqx.field(kw_only=True, default=None)
 
-    def __post_init__(self, params=None):
+    def __post_init__(self, params: Params[Array] | None = None):
         if self.dyn_loss is None:
-            try:
-                self.dyn_loss = _get_masked_parameters("nn_params", params)
-            except AttributeError:
+            if params is None:
                 raise ValueError("self.dyn_loss is None, hence params should be passed")
+            self.dyn_loss = _get_masked_parameters("nn_params", params)
         if self.observations is None:
-            try:
-                self.observations = _get_masked_parameters("nn_params", params)
-            except AttributeError:
+            if params is None:
                 raise ValueError(
                     "self.observations is None, hence params should be passed"
                 )
+            self.observations = _get_masked_parameters("nn_params", params)
         if self.boundary_loss is None:
-            try:
-                self.boundary_loss = _get_masked_parameters("nn_params", params)
-            except AttributeError:
+            if params is None:
                 raise ValueError(
                     "self.boundary_loss is None, hence params should be passed"
                 )
+            self.boundary_loss = _get_masked_parameters("nn_params", params)
         if self.norm_loss is None:
-            try:
-                self.norm_loss = _get_masked_parameters("nn_params", params)
-            except AttributeError:
+            if params is None:
                 raise ValueError(
                     "self.norm_loss is None, hence params should be passed"
                 )
+            self.norm_loss = _get_masked_parameters("nn_params", params)
 
     @classmethod
     def from_str(
         cls,
-        params: Params | ParamsDict,
+        params: Params[Array],
         dyn_loss: (
-            Literal["nn_params", "eq_params", "both"] | Params | ParamsDict
+            Literal["nn_params", "eq_params", "both"] | Params[bool]
         ) = "nn_params",
         observations: (
-            Literal["nn_params", "eq_params", "both"] | Params | ParamsDict
+            Literal["nn_params", "eq_params", "both"] | Params[bool]
         ) = "nn_params",
         boundary_loss: (
-            Literal["nn_params", "eq_params", "both"] | Params | ParamsDict
+            Literal["nn_params", "eq_params", "both"] | Params[bool]
         ) = "nn_params",
         norm_loss: (
-            Literal["nn_params", "eq_params", "both"] | Params | ParamsDict
+            Literal["nn_params", "eq_params", "both"] | Params[bool]
         ) = "nn_params",
     ):
         """
@@ -300,23 +270,23 @@ class DerivativeKeysPDEStatio(eqx.Module):
         Parameters
         ----------
         params
-            The actual Params or ParamsDict object of the problem.
+            The actual Params object of the problem.
         dyn_loss
             Tell wrt which parameters among `"nn_params"`, `"eq_params"` or
             `"both"` we will differentiate the dynamic loss. Default is
-            `"nn_params"`. Specifying a Params or ParamsDict is also possible.
+            `"nn_params"`. Specifying a Params is also possible.
         observations
             Tell wrt which parameters among `"nn_params"`, `"eq_params"` or
             `"both"` we will differentiate the observations. Default is
-            `"nn_params"`. Specifying a Params or ParamsDict is also possible.
+            `"nn_params"`. Specifying a Params is also possible.
         boundary_loss
             Tell wrt which parameters among `"nn_params"`, `"eq_params"` or
             `"both"` we will differentiate the boundary loss. Default is
-            `"nn_params"`. Specifying a Params or ParamsDict is also possible.
+            `"nn_params"`. Specifying a Params is also possible.
         norm_loss
             Tell wrt which parameters among `"nn_params"`, `"eq_params"` or
             `"both"` we will differentiate the normalization loss. Default is
-            `"nn_params"`. Specifying a Params or ParamsDict is also possible.
+            `"nn_params"`. Specifying a Params is also possible.
         """
         return DerivativeKeysPDEStatio(
             dyn_loss=(
@@ -348,64 +318,61 @@ class DerivativeKeysPDENonStatio(DerivativeKeysPDEStatio):
 
     Parameters
    ----------
-    dyn_loss : Params | ParamsDict | None, default=None
+    dyn_loss : Params[bool] | None, default=None
         Tell wrt which parameters among Params we will differentiate the
         dynamic loss. To do so, the fields of Params contain True (if
         differentiation) or False (if no differentiation).
-    observations : Params | ParamsDict | None, default=None
+    observations : Params[bool] | None, default=None
         Tell wrt which parameters among Params we will differentiate the
         observation loss. To do so, the fields of Params contain True (if
         differentiation) or False (if no differentiation).
-    boundary_loss : Params | ParamsDict | None, default=None
+    boundary_loss : Params[bool] | None, default=None
         Tell wrt which parameters among Params we will differentiate the
         boundary loss. To do so, the fields of Params contain True (if
         differentiation) or False (if no differentiation).
-    norm_loss : Params | ParamsDict | None, default=None
+    norm_loss : Params[bool] | None, default=None
         Tell wrt which parameters among Params we will differentiate the
         normalization loss. To do so, the fields of Params contain True (if
         differentiation) or False (if no differentiation).
-    initial_condition : Params | ParamsDict | None, default=None
+    initial_condition : Params[bool] | None, default=None
         Tell wrt which parameters among Params we will differentiate the
         initial_condition loss. To do so, the fields of Params contain True (if
         differentiation) or False (if no differentiation).
-    params : InitVar[Params | ParamsDict], default=None
+    params : InitVar[Params[Array]], default=None
         The main Params object of the problem. It is required
         if some terms are unspecified (None). This is because jinns cannot infer the
         content of `Params.eq_params`.
     """
 
-    initial_condition: Params | ParamsDict | None = eqx.field(
-        kw_only=True, default=None
-    )
+    initial_condition: Params[bool] | None = eqx.field(kw_only=True, default=None)
 
-    def __post_init__(self, params=None):
+    def __post_init__(self, params: Params[Array] | None = None):
         super().__post_init__(params=params)
         if self.initial_condition is None:
-            try:
-                self.initial_condition = _get_masked_parameters("nn_params", params)
-            except AttributeError:
+            if params is None:
                 raise ValueError(
                     "self.initial_condition is None, hence params should be passed"
                 )
+            self.initial_condition = _get_masked_parameters("nn_params", params)
 
     @classmethod
     def from_str(
         cls,
-        params: Params | ParamsDict,
+        params: Params[Array],
         dyn_loss: (
-            Literal["nn_params", "eq_params", "both"] | Params | ParamsDict
+            Literal["nn_params", "eq_params", "both"] | Params[bool]
         ) = "nn_params",
         observations: (
-            Literal["nn_params", "eq_params", "both"] | Params | ParamsDict
+            Literal["nn_params", "eq_params", "both"] | Params[bool]
         ) = "nn_params",
         boundary_loss: (
-            Literal["nn_params", "eq_params", "both"] | Params | ParamsDict
+            Literal["nn_params", "eq_params", "both"] | Params[bool]
         ) = "nn_params",
         norm_loss: (
-            Literal["nn_params", "eq_params", "both"] | Params | ParamsDict
+            Literal["nn_params", "eq_params", "both"] | Params[bool]
         ) = "nn_params",
         initial_condition: (
-            Literal["nn_params", "eq_params", "both"] | Params | ParamsDict
+            Literal["nn_params", "eq_params", "both"] | Params[bool]
         ) = "nn_params",
     ):
         """
@@ -414,27 +381,27 @@ class DerivativeKeysPDENonStatio(DerivativeKeysPDEStatio):
         Parameters
         ----------
         params
-            The actual Params | ParamsDict object of the problem.
+            The actual Params object of the problem.
         dyn_loss
             Tell wrt which parameters among `"nn_params"`, `"eq_params"` or
             `"both"` we will differentiate the dynamic loss. Default is
-            `"nn_params"`. Specifying a Params or ParamsDict is also possible.
+            `"nn_params"`. Specifying a Params is also possible.
         observations
             Tell wrt which parameters among `"nn_params"`, `"eq_params"` or
             `"both"` we will differentiate the observations. Default is
-            `"nn_params"`. Specifying a Params or ParamsDict is also possible.
+            `"nn_params"`. Specifying a Params is also possible.
         boundary_loss
             Tell wrt which parameters among `"nn_params"`, `"eq_params"` or
             `"both"` we will differentiate the boundary loss. Default is
-            `"nn_params"`. Specifying a Params or ParamsDict is also possible.
+            `"nn_params"`. Specifying a Params is also possible.
         norm_loss
             Tell wrt which parameters among `"nn_params"`, `"eq_params"` or
             `"both"` we will differentiate the normalization loss. Default is
-            `"nn_params"`. Specifying a Params or ParamsDict is also possible.
+            `"nn_params"`. Specifying a Params is also possible.
         initial_condition
             Tell wrt which parameters among `"nn_params"`, `"eq_params"` or
             `"both"` we will differentiate the initial_condition loss. Default is
-            `"nn_params"`. Specifying a Params or ParamsDict is also possible.
+            `"nn_params"`. Specifying a Params is also possible.
         """
         return DerivativeKeysPDENonStatio(
             dyn_loss=(
@@ -471,32 +438,6 @@ def _set_derivatives(params, derivative_keys):
     has a copy of the params with appropriate derivatives set
     """
 
-    def _set_derivatives_ParamsDict(params_, derivative_mask):
-        """
-        The next lines put a stop_gradient around the fields that do not
-        differentiate the loss term
-        **Note:** **No granularity inside `ParamsDict.nn_params` is currently
-        supported.**
-        This means a typical Params specification is of the form:
-        `ParamsDict(nn_params=True | False, eq_params={"0":{"alpha":True | False,
-        "beta":True | False}}, "1":{"alpha":True | False, "beta":True | False}})`.
-        """
-        # a ParamsDict object is reconstructed by hand since we do not want to
-        # traverse nn_params, for now...
-        return ParamsDict(
-            nn_params=jax.lax.cond(
-                derivative_mask.nn_params,
-                lambda p: p,
-                jax.lax.stop_gradient,
-                params_.nn_params,
-            ),
-            eq_params=jax.tree.map(
-                lambda p, d: jax.lax.cond(d, lambda p: p, jax.lax.stop_gradient, p),
-                params_.eq_params,
-                derivative_mask.eq_params,
-            ),
-        )
-
     def _set_derivatives_(params_, derivative_mask):
         """
         The next lines put a stop_gradient around the fields that do not
@@ -516,6 +457,4 @@ def _set_derivatives(params, derivative_keys):
             # granularity could be imagined here, in the future
         )
 
-    if isinstance(params, ParamsDict):
-        return _set_derivatives_ParamsDict(params, derivative_keys)
     return _set_derivatives_(params, derivative_keys)
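The surviving `_set_derivatives_` helper keeps the `stop_gradient` masking described in its docstring. A small, self-contained sketch of that principle in plain JAX, using the same `cond`/`stop_gradient` pattern the removed `ParamsDict` helper spelled out:

```python
# Demonstrates the masking principle: a False mask leaf wraps the value in
# stop_gradient, so gradients wrt it vanish while the forward value is
# unchanged.
import jax
import jax.numpy as jnp

def apply_mask(value, keep: bool):
    # Single-leaf version of the tree-mapped branch in the diff above.
    return jax.lax.cond(keep, lambda p: p, jax.lax.stop_gradient, value)

f = lambda x: jnp.sum(apply_mask(x, False) ** 2)
print(jax.grad(f)(jnp.ones(3)))  # [0. 0. 0.]: the derivative path is cut
print(f(jnp.ones(3)))            # 3.0: the forward value still flows through
```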
jinns/parameters/_params.py CHANGED
@@ -2,67 +2,36 @@
 Formalize the data structure for the parameters
 """
 
+from typing import Generic, TypeVar
 import jax
 import equinox as eqx
-from typing import Dict
-from jaxtyping import Array, PyTree
+from jaxtyping import Array, PyTree, Float
 
+T = TypeVar("T")  # the generic type for what is in the Params PyTree because we
+# have possibly Params of Arrays, booleans, ...
 
-class Params(eqx.Module):
+
+class Params(eqx.Module, Generic[T]):
     """
     The equinox module for the parameters
 
     Parameters
     ----------
-    nn_params : Pytree
+    nn_params : PyTree[T]
         A PyTree of the non-static part of the PINN eqx.Module, i.e., the
         parameters of the PINN
-    eq_params : Dict[str, Array]
+    eq_params : dict[str, T]
         A dictionary of the equation parameters. Keys are the parameter name,
         values are their corresponding value
     """
 
-    nn_params: PyTree = eqx.field(kw_only=True, default=None)
-    eq_params: Dict[str, Array] = eqx.field(kw_only=True, default=None)
-
-
-class ParamsDict(eqx.Module):
-    """
-    The equinox module for a dictionary of parameters with different keys
-    corresponding to different equations.
-
-    Parameters
-    ----------
-    nn_params : Dict[str, PyTree]
-        The neural network's parameters. Most of the time, it will be the
-        Array part of an `eqx.Module` obtained by
-        `eqx.partition(module, eqx.is_inexact_array)`.
-    eq_params : Dict[str, Array]
-        A dictionary of the equation parameters. Dict keys are the parameter name as defined in your custom loss.
-    """
-
-    nn_params: Dict[str, PyTree] = eqx.field(kw_only=True, default=None)
-    eq_params: Dict[str, Array] = eqx.field(kw_only=True, default=None)
-
-    def extract_params(self, nn_key: str) -> Params:
-        """
-        Extract the corresponding `nn_params` and `eq_params` for `nn_key` and
-        return them in the form of a `Params` object.
-        """
-        try:
-            return Params(
-                nn_params=self.nn_params[nn_key],
-                eq_params=self.eq_params[nn_key],
-            )
-        except (KeyError, IndexError) as e:
-            return Params(
-                nn_params=self.nn_params[nn_key],
-                eq_params=self.eq_params,
-            )
+    nn_params: PyTree[T] = eqx.field(kw_only=True, default=None)
+    eq_params: dict[str, T] = eqx.field(kw_only=True, default=None)
 
 
 def _update_eq_params_dict(
-    params: Params, param_batch_dict: Dict[str, Array]
+    params: Params[Array],
+    param_batch_dict: dict[str, Float[Array, " param_batch_size dim"]],
 ) -> Params:
     """
     Update params.eq_params with a batch of eq_params for given key(s)
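A brief sketch of what the generic rewrite enables: one container class now covers both value trees and the boolean masks used by the derivative keys. All names below are illustrative:

```python
# Sketch of the new generic Params[T]: the same eqx.Module holds values
# (Params[Array]) or derivative masks (Params[bool]). "nu" is a made-up key.
import jax.numpy as jnp
from jinns.parameters import Params

values = Params(                      # plays the role of a Params[Array]
    nn_params={"w": jnp.zeros(4)},    # stand-in for the PINN weight PyTree
    eq_params={"nu": jnp.array(0.01)},
)
mask = Params(                        # plays the role of a Params[bool]
    nn_params=True,
    eq_params={"nu": False},
)
```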
@@ -89,13 +58,16 @@ def _update_eq_params_dict(
 
 
 def _get_vmap_in_axes_params(
-    eq_params_batch_dict: Dict[str, Array], params: Params | ParamsDict
-) -> tuple[Params]:
+    eq_params_batch_dict: dict[str, Array], params: Params[Array]
+) -> tuple[Params[int | None] | None]:
     """
     Return the input vmap axes when there is batch(es) of parameters to vmap
     over. The latter are designated by keys in eq_params_batch_dict.
     If eq_params_batch_dict is None (i.e. no additional parameter batch), we
     return (None,).
+
+    Note that we return a Params PyTree with an integer to designate the
+    vmapped axis or None if there is none
     """
     if eq_params_batch_dict is None:
         return (None,)
@@ -104,7 +76,7 @@ def _get_vmap_in_axes_params(
     # this is for a fine-grained vmapping
     # scheme over the params
     vmap_in_axes_params = (
-        type(params)(
+        Params(
             nn_params=None,
             eq_params={
                 k: (0 if k in eq_params_batch_dict.keys() else None)
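Because `Params` is a PyTree, the object built here can be passed directly to `jax.vmap` as an `in_axes` specification. A hedged sketch of that use, with an illustrative `"alpha"` batch:

```python
# Sketch: a Params of axis indices acts as the in_axes spec, mapping only
# the batched eq_params leaf while broadcasting the network parameters.
import jax
import jax.numpy as jnp
from jinns.parameters import Params

def toy_loss(params):
    return jnp.sum(params.eq_params["alpha"] ** 2)  # illustrative loss

batched = Params(
    nn_params={"w": jnp.zeros(2)},                   # shared across the batch
    eq_params={"alpha": jnp.linspace(0.0, 1.0, 8)},  # batch of 8 values
)
in_axes = Params(nn_params=None, eq_params={"alpha": 0})
losses = jax.vmap(toy_loss, in_axes=(in_axes,))(batched)  # shape (8,)
```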
jinns/plot/__init__.py CHANGED
@@ -3,3 +3,5 @@ from jinns.plot._plot import (
     plot1d_slice,
     plot1d_image,
 )
+
+__all__ = ["plot2d", "plot1d_slice", "plot1d_image"]
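The practical effect of the new `__all__` is that the public plotting surface is now explicit; star imports from `jinns.plot` expose exactly these three functions:

```python
from jinns.plot import plot2d, plot1d_slice, plot1d_image
```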