jinns 1.2.0__py3-none-any.whl → 1.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (57)
  1. jinns/__init__.py +17 -7
  2. jinns/data/_AbstractDataGenerator.py +19 -0
  3. jinns/data/_Batchs.py +31 -12
  4. jinns/data/_CubicMeshPDENonStatio.py +431 -0
  5. jinns/data/_CubicMeshPDEStatio.py +464 -0
  6. jinns/data/_DataGeneratorODE.py +187 -0
  7. jinns/data/_DataGeneratorObservations.py +189 -0
  8. jinns/data/_DataGeneratorParameter.py +206 -0
  9. jinns/data/__init__.py +19 -9
  10. jinns/data/_utils.py +149 -0
  11. jinns/experimental/__init__.py +9 -0
  12. jinns/loss/_DynamicLoss.py +116 -189
  13. jinns/loss/_DynamicLossAbstract.py +45 -68
  14. jinns/loss/_LossODE.py +71 -336
  15. jinns/loss/_LossPDE.py +176 -513
  16. jinns/loss/__init__.py +28 -6
  17. jinns/loss/_abstract_loss.py +15 -0
  18. jinns/loss/_boundary_conditions.py +22 -21
  19. jinns/loss/_loss_utils.py +98 -173
  20. jinns/loss/_loss_weights.py +12 -44
  21. jinns/loss/_operators.py +84 -76
  22. jinns/nn/__init__.py +22 -0
  23. jinns/nn/_abstract_pinn.py +22 -0
  24. jinns/nn/_hyperpinn.py +434 -0
  25. jinns/nn/_mlp.py +217 -0
  26. jinns/nn/_pinn.py +204 -0
  27. jinns/nn/_ppinn.py +239 -0
  28. jinns/{utils → nn}/_save_load.py +39 -53
  29. jinns/nn/_spinn.py +123 -0
  30. jinns/nn/_spinn_mlp.py +202 -0
  31. jinns/nn/_utils.py +38 -0
  32. jinns/parameters/__init__.py +8 -1
  33. jinns/parameters/_derivative_keys.py +116 -177
  34. jinns/parameters/_params.py +18 -46
  35. jinns/plot/__init__.py +2 -0
  36. jinns/plot/_plot.py +38 -37
  37. jinns/solver/_rar.py +82 -65
  38. jinns/solver/_solve.py +111 -71
  39. jinns/solver/_utils.py +4 -6
  40. jinns/utils/__init__.py +2 -5
  41. jinns/utils/_containers.py +12 -9
  42. jinns/utils/_types.py +11 -57
  43. jinns/utils/_utils.py +4 -11
  44. jinns/validation/__init__.py +2 -0
  45. jinns/validation/_validation.py +20 -19
  46. {jinns-1.2.0.dist-info → jinns-1.4.0.dist-info}/METADATA +11 -10
  47. jinns-1.4.0.dist-info/RECORD +53 -0
  48. {jinns-1.2.0.dist-info → jinns-1.4.0.dist-info}/WHEEL +1 -1
  49. jinns/data/_DataGenerators.py +0 -1634
  50. jinns/utils/_hyperpinn.py +0 -420
  51. jinns/utils/_pinn.py +0 -324
  52. jinns/utils/_ppinn.py +0 -227
  53. jinns/utils/_spinn.py +0 -249
  54. jinns-1.2.0.dist-info/RECORD +0 -41
  55. {jinns-1.2.0.dist-info → jinns-1.4.0.dist-info/licenses}/AUTHORS +0 -0
  56. {jinns-1.2.0.dist-info → jinns-1.4.0.dist-info/licenses}/LICENSE +0 -0
  57. {jinns-1.2.0.dist-info → jinns-1.4.0.dist-info}/top_level.txt +0 -0
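The list above shows the main structural change in this release: the network implementations formerly under jinns/utils/ (items 50-53, removed) now live in the new jinns/nn/ subpackage (items 22-31), and the monolithic jinns/data/_DataGenerators.py (item 49) was split into one module per generator (items 2-8). A hedged migration sketch follows; the re-exports from jinns/nn/__init__.py (item 22) are assumed, not shown in this diff:

    # jinns 1.2.0: network code lived under jinns/utils/ (see items 50-53 above)
    # jinns 1.4.0: import the classes from the new jinns.nn subpackage
    from jinns.nn import PINN_MLP, HyperPINN  # assumed re-exports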
jinns/nn/_hyperpinn.py ADDED
@@ -0,0 +1,434 @@
+ """
+ Implements utility function to create HyperPINNs
+ https://arxiv.org/pdf/2111.01008.pdf
+ """
+
+ from __future__ import annotations
+
+ import warnings
+ from dataclasses import InitVar
+ from typing import Callable, Literal, Self, Union, Any, cast, overload
+ from math import prod
+ import jax
+ import jax.numpy as jnp
+ from jaxtyping import Array, Float, PyTree, Key
+ import equinox as eqx
+ import numpy as onp
+
+ from jinns.nn._pinn import PINN
+ from jinns.nn._mlp import MLP
+ from jinns.parameters._params import Params
+ from jinns.nn._utils import _PyTree_to_Params
+
+
+ def _get_param_nb(
+     params: PyTree[Array],
+ ) -> tuple[int, list[int]]:
+     """Returns the number of parameters in a Params object and also
+     the cumulative sum when parsing the object.
+
+     Parameters
+     ----------
+     params :
+         A Params object.
+     """
+     dim_prod_all_arrays = [
+         prod(a.shape)
+         for a in jax.tree.leaves(params, is_leaf=lambda x: isinstance(x, jnp.ndarray))
+     ]
+     return (
+         sum(dim_prod_all_arrays),
+         onp.cumsum(dim_prod_all_arrays).tolist(),
+     )
+
+
+ class HyperPINN(PINN):
+     r"""
+     A HyperPINN object compatible with the rest of jinns.
+     Composed of a PINN and a hypernetwork. The HyperPINN is typically
+     instantiated with `create`.
+
+     Parameters
+     ----------
+     hyperparams : list[str]
+         A list of keys from Params.eq_params that will be considered as
+         hyperparameters for metamodeling.
+     hypernet_input_size : int
+         An integer. The input size of the MLP used for the hypernetwork. Must
+         be equal to the size of the flattened concatenation of the parameter
+         arrays designated by the `hyperparams` argument.
+     slice_solution : slice
+         A jnp.s\_ object which indicates which axis of the PINN output is
+         dedicated to the actual equation solution. Default None
+         means that slice_solution = the whole PINN output. This argument is
+         useful when the PINN is also used to output equation parameters, for
+         example. Note that it must be a slice and not an integer (a
+         preprocessing of the user-provided argument takes care of it).
+     eq_type : str
+         A string with three possibilities.
+         "ODE": the HyperPINN is called with one input `t`.
+         "statio_PDE": the HyperPINN is called with one input `x`; `x`
+         can be high dimensional.
+         "nonstatio_PDE": the HyperPINN is called with two inputs `t` and `x`;
+         `x` can be high dimensional.
+         **Note**: the input dimension given in eqx_list has to match the sum
+         of the dimensions of `t` and `x`, or the output dimension of the
+         `input_transform` function.
+     input_transform : Callable[[Float[Array, " input_dim"], Params[Array]], Float[Array, " output_dim"]]
+         A function that will be called before entering the PINN. Its output(s)
+         must match the PINN inputs (except for the parameters).
+         Its inputs are the PINN inputs (`t` and/or `x` concatenated together)
+         and the parameters. Default is no operation.
+     output_transform : Callable[[Float[Array, " input_dim"], Float[Array, " output_dim"], Params[Array]], Float[Array, " output_dim"]]
+         A function whose arguments are the same inputs as the PINN, the PINN
+         output and the parameters. This function will be called after exiting
+         the PINN. Default is no operation.
+     mlp : eqx.Module
+         The actual neural network instantiated as an eqx.Module.
+     hyper_mlp : eqx.Module
+         The actual hyper neural network instantiated as an eqx.Module.
+     filter_spec : PyTree[Union[bool, Callable[[Any], bool]]]
+         Default is `eqx.is_inexact_array`. This tells jinns what to consider
+         as a trainable parameter. Quoting from the equinox documentation:
+         a PyTree whose structure should be a prefix of the structure of
+         pytree. Each of its leaves should either be 1) True, in which case the
+         leaf or subtree is kept; 2) False, in which case the leaf or subtree
+         is replaced with `replace`; 3) a callable Leaf -> bool, in which case
+         this is evaluated on the leaf or mapped over the subtree, and the leaf
+         kept or replaced as appropriate.
+     """
+
+     hyperparams: list[str] = eqx.field(static=True, kw_only=True)
+     hypernet_input_size: int = eqx.field(kw_only=True)
+
+     eqx_hyper_network: InitVar[eqx.Module] = eqx.field(kw_only=True)
+
+     pinn_params_sum: int = eqx.field(init=False, static=True)
+     pinn_params_cumsum: list[int] = eqx.field(init=False, static=True)
+
+     init_params_hyper: HyperPINN = eqx.field(init=False)
+     static_hyper: HyperPINN = eqx.field(init=False, static=True)
+
+     def __post_init__(self, eqx_network, eqx_hyper_network):
+         super().__post_init__(
+             eqx_network,
+         )
+         # In addition, we store the PyTree structure of the hypernetwork as well
+         self.init_params_hyper, self.static_hyper = eqx.partition(
+             eqx_hyper_network, self.filter_spec
+         )
+         self.pinn_params_sum, self.pinn_params_cumsum = _get_param_nb(self.init_params)
+
+     def _hyper_to_pinn(self, hyper_output: Float[Array, " output_dim"]) -> PINN:
+         """
+         From the output of the hypernetwork, transform to well-formed
+         parameters for the pinn network (i.e. with the same PyTree structure
+         as `self.init_params`)
+         """
+         pinn_params_flat = eqx.tree_at(
+             jax.tree.leaves,  # the is_leaf=eqx.is_array argument for
+             # jax.tree.leaves is not needed in general when working with
+             # eqx.nn.Linear for example: jax.tree.leaves already returns the
+             # arrays of weights and biases only, since the other attributes
+             # (that we do not want to be returned) are marked as static (in
+             # eqx.nn.Linear), hence are not part of the leaves.
+             # Note that custom layers should then be properly designed to pass
+             # this jax.tree.leaves.
+             self.init_params,
+             jnp.split(hyper_output, self.pinn_params_cumsum[:-1]),
+         )
+
+         return jax.tree.map(
+             lambda a, b: a.reshape(b.shape),
+             pinn_params_flat,
+             self.init_params,
+             is_leaf=lambda x: isinstance(x, jnp.ndarray),
+         )
+
+     @overload
+     @_PyTree_to_Params
+     def __call__(
+         self,
+         inputs: Float[Array, " input_dim"],
+         params: PyTree,
+         *args,
+         **kwargs,
+     ) -> Float[Array, " output_dim"]: ...
+
+     @_PyTree_to_Params
+     def __call__(
+         self,
+         inputs: Float[Array, " input_dim"],
+         params: Params[Array],
+         *args,
+         **kwargs,
+     ) -> Float[Array, " output_dim"]:
+         """
+         Evaluate the HyperPINN on some inputs with some params.
+
+         Note that, thanks to the decorator, params can also directly be the
+         PyTree (SPINN, PINN_MLP, ...) that we get out of eqx.combine
+         """
+         if len(inputs.shape) == 0:
+             # This can often happen when the user directly provides some
+             # collocation points (eg for plotting, without using
+             # DataGenerators)
+             inputs = inputs[None]
+
+         # try:
+         hyper = eqx.combine(params.nn_params, self.static_hyper)
+         # except (KeyError, AttributeError, TypeError) as e:  # give more flexibility
+         #     hyper = eqx.combine(params, self.static_hyper)
+
+         eq_params_batch = jnp.concatenate(
+             [params.eq_params[k].flatten() for k in self.hyperparams],
+             axis=0,
+         )
+
+         hyper_output = hyper(eq_params_batch)  # type: ignore
+
+         pinn_params = self._hyper_to_pinn(hyper_output)
+
+         pinn = eqx.combine(pinn_params, self.static)
+         res = self.eval(pinn, self.input_transform(inputs, params), *args, **kwargs)
+
+         res = self.output_transform(inputs, res.squeeze(), params)
+
+         # force (1,) output for non-vectorial solutions (consistency)
+         if not res.shape:
+             return jnp.expand_dims(res, axis=-1)
+         return res
+
+     @classmethod
+     def create(
+         cls,
+         eq_type: Literal["ODE", "statio_PDE", "nonstatio_PDE"],
+         hyperparams: list[str],
+         hypernet_input_size: int,
+         eqx_network: eqx.nn.MLP | MLP | None = None,
+         eqx_hyper_network: eqx.nn.MLP | MLP | None = None,
+         key: Key = None,
+         eqx_list: tuple[tuple[Callable, int, int] | tuple[Callable], ...] | None = None,
+         eqx_list_hyper: (
+             tuple[tuple[Callable, int, int] | tuple[Callable], ...] | None
+         ) = None,
+         input_transform: (
+             Callable[
+                 [Float[Array, " input_dim"], Params[Array]],
+                 Float[Array, " output_dim"],
+             ]
+             | None
+         ) = None,
+         output_transform: (
+             Callable[
+                 [
+                     Float[Array, " input_dim"],
+                     Float[Array, " output_dim"],
+                     Params[Array],
+                 ],
+                 Float[Array, " output_dim"],
+             ]
+             | None
+         ) = None,
+         slice_solution: slice | None = None,
+         filter_spec: PyTree[Union[bool, Callable[[Any], bool]]] = None,
+     ) -> tuple[Self, HyperPINN]:
+         r"""
+         Utility function to create a HyperPINN neural network with the
+         equinox library.
+
+         Parameters
+         ----------
+         eq_type
+             A string with three possibilities.
+             "ODE": the HyperPINN is called with one input `t`.
+             "statio_PDE": the HyperPINN is called with one input `x`; `x`
+             can be high dimensional.
+             "nonstatio_PDE": the HyperPINN is called with two inputs `t` and
+             `x`; `x` can be high dimensional.
+             **Note**: the input dimension given in eqx_list has to match the
+             sum of the dimensions of `t` and `x`, or the output dimension of
+             the `input_transform` function.
+         hyperparams
+             A list of keys from Params.eq_params that will be considered as
+             hyperparameters for metamodeling.
+         hypernet_input_size
+             An integer. The input size of the MLP used for the hypernetwork.
+             Must be equal to the size of the flattened concatenation of the
+             parameter arrays designated by the `hyperparams` argument.
+         eqx_network
+             Default is None. An eqx.nn.MLP for the base network that will be
+             wrapped inside our PINN_MLP object in order to make it easily
+             jinns compatible.
+         eqx_hyper_network
+             Default is None. An eqx.nn.MLP for the hyper network that will be
+             wrapped inside our PINN_MLP object in order to make it easily
+             jinns compatible.
+         key
+             Default is None. Must be provided with `eqx_list` and
+             `eqx_list_hyper` if `eqx_network` or `eqx_hyper_network` is not
+             provided. A JAX random key that will be used to initialize the
+             network parameters.
+         eqx_list
+             Default is None. Must be provided if `eqx_network` or
+             `eqx_hyper_network` is not provided.
+             A tuple of tuples of successive equinox modules and activation
+             functions that describes the base network architecture. The inner
+             tuples must have the eqx module or activation function as first
+             item; the other items represent arguments that could be required
+             (eg. the size of the layer).
+             The `key` argument need not be given.
+             A typical example is `eqx_list=
+             ((eqx.nn.Linear, 2, 20),
+             (jax.nn.tanh,),
+             (eqx.nn.Linear, 20, 20),
+             (jax.nn.tanh,),
+             (eqx.nn.Linear, 20, 20),
+             (jax.nn.tanh,),
+             (eqx.nn.Linear, 20, 1)
+             )`.
+         eqx_list_hyper
+             Same as `eqx_list` but for the hypernetwork. Default is None.
+             Must be provided, together with `key` and `eqx_list`, if
+             `eqx_network` or `eqx_hyper_network` is not provided. Note that
+             the number of inputs of the hypernetwork must be equal to the
+             size of the flattened concatenation of the parameter arrays
+             designated by the `hyperparams` argument, and the number of
+             outputs must be equal to the number of parameters of the pinn
+             network; both are adjusted automatically when the architecture is
+             built from `eqx_list_hyper`.
+         input_transform
+             A function that will be called before entering the PINN. Its
+             output(s) must match the PINN inputs (except for the parameters).
+             Its inputs are the PINN inputs (`t` and/or `x` concatenated
+             together) and the parameters. Default is no operation.
+         output_transform
+             A function whose arguments are the same inputs as the PINN, the
+             PINN output and the parameters. This function will be called
+             after exiting the PINN. Default is no operation.
+         slice_solution
+             A jnp.s\_ object which indicates which axis of the PINN output is
+             dedicated to the actual equation solution. Default None
+             means that slice_solution = the whole PINN output. This argument
+             is useful when the PINN is also used to output equation
+             parameters, for example. Note that it must be a slice and not an
+             integer (a preprocessing of the user-provided argument takes care
+             of it).
+         filter_spec
+             Default is None, which leads to `eqx.is_inexact_array` in the
+             class instantiation. This tells jinns what to consider as a
+             trainable parameter. Quoting from the equinox documentation:
+             a PyTree whose structure should be a prefix of the structure of
+             pytree. Each of its leaves should either be 1) True, in which
+             case the leaf or subtree is kept; 2) False, in which case the
+             leaf or subtree is replaced with `replace`; 3) a callable
+             Leaf -> bool, in which case this is evaluated on the leaf or
+             mapped over the subtree, and the leaf kept or replaced as
+             appropriate.
+
+         Returns
+         -------
+         hyperpinn
+             A HyperPINN instance.
+         hyperpinn.init_params_hyper
+             The initial set of parameters of the hypernetwork.
+         """
+         if eqx_network is None or eqx_hyper_network is None:
+             if eqx_list is None or key is None or eqx_list_hyper is None:
+                 raise ValueError(
+                     "If eqx_network is None or eqx_hyper_network is None, then"
+                     " key, eqx_list and eqx_list_hyper must be provided"
+                 )
+
+             ### Now we finetune the hypernetwork architecture
+
+             key, subkey = jax.random.split(key, 2)
+             # with warnings.catch_warnings():
+             #     warnings.filterwarnings("ignore", message="A JAX array is being set as static!")
+             eqx_network = MLP(key=subkey, eqx_list=eqx_list)
+             # quick partitioning to get the params and hence the correct
+             # number of neurons for the last layer of the hyper network
+             params_mlp, _ = eqx.partition(eqx_network, eqx.is_inexact_array)
+             pinn_params_sum, _ = _get_param_nb(params_mlp)
+             # the number of parameters of the pinn will be the number of
+             # outputs of the hyper network
+             if len(eqx_list_hyper[-1]) > 1:
+                 eqx_list_hyper = eqx_list_hyper[:-1] + (
+                     (eqx_list_hyper[-1][:2] + (pinn_params_sum,)),
+                 )
+             else:
+                 eqx_list_hyper = cast(
+                     tuple[tuple[Callable, int, int] | tuple[Callable], ...],
+                     (
+                         eqx_list_hyper[:-2]
+                         + ((eqx_list_hyper[-2][:2] + (pinn_params_sum,)),)
+                         + eqx_list_hyper[-1]
+                     ),
+                 )
+             if len(eqx_list_hyper[0]) > 1:
+                 eqx_list_hyper = (
+                     (
+                         (eqx_list_hyper[0][0],)
+                         + (hypernet_input_size,)
+                         + (eqx_list_hyper[0][2],)
+                     ),
+                 ) + eqx_list_hyper[1:]
+             else:
+                 eqx_list_hyper = cast(
+                     tuple[tuple[Callable, int, int] | tuple[Callable], ...],
+                     (
+                         eqx_list_hyper[0]
+                         + (
+                             (
+                                 (eqx_list_hyper[1][0],)
+                                 + (hypernet_input_size,)
+                                 + (eqx_list_hyper[1][2],)  # type: ignore  # we suppose the second element of the tuple is necessarily of length > 1 since we expect smth like eqx.nn.Linear
+                             ),
+                         )
+                         + eqx_list_hyper[2:]
+                     ),
+                 )
+             key, subkey = jax.random.split(key, 2)
+             # with warnings.catch_warnings():
+             #     warnings.filterwarnings("ignore", message="A JAX array is being set as static!")
+             eqx_hyper_network = cast(MLP, MLP(key=subkey, eqx_list=eqx_list_hyper))
+
+             ### End of finetuning the hypernetwork architecture
+
+         with warnings.catch_warnings():
+             # Catch the equinox warning raised because we mark the numbers of
+             # parameters as static while they are jnp.Arrays. This time it is
+             # correct to do so, because they are used as indices and will
+             # never be modified
+             warnings.filterwarnings(
+                 "ignore", message="A JAX array is being set as static!"
+             )
+             hyperpinn = cls(
+                 eqx_network=eqx_network,
+                 eqx_hyper_network=eqx_hyper_network,
+                 slice_solution=slice_solution,  # type: ignore
+                 eq_type=eq_type,
+                 input_transform=input_transform,  # type: ignore
+                 output_transform=output_transform,  # type: ignore
+                 hyperparams=hyperparams,
+                 hypernet_input_size=hypernet_input_size,
+                 filter_spec=filter_spec,
+             )
+         return hyperpinn, hyperpinn.init_params_hyper
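To make the new class-method API concrete, here is a minimal usage sketch. It assumes HyperPINN is re-exported from jinns.nn and that "nu" is a key of Params.eq_params; neither fact is shown in this diff.

    import jax
    import equinox as eqx
    from jinns.nn import HyperPINN  # assumed re-export

    key = jax.random.PRNGKey(0)
    # hidden architecture shared by the base PINN and the hypernetwork;
    # create() resizes the hypernetwork's first layer to hypernet_input_size
    # and its last layer to the base PINN's parameter count (via _get_param_nb)
    eqx_list = (
        (eqx.nn.Linear, 2, 20),  # input dim = dim(t) + dim(x) for nonstatio_PDE
        (jax.nn.tanh,),
        (eqx.nn.Linear, 20, 20),
        (jax.nn.tanh,),
        (eqx.nn.Linear, 20, 1),
    )
    hyperpinn, init_params = HyperPINN.create(
        eq_type="nonstatio_PDE",
        hyperparams=["nu"],      # hypothetical key of Params.eq_params
        hypernet_input_size=1,   # flattened size of eq_params["nu"]
        key=key,
        eqx_list=eqx_list,
        eqx_list_hyper=eqx_list,
    )

At call time, __call__ flattens the designated eq_params into the hypernetwork input, and _hyper_to_pinn splits the flat hypernetwork output at pinn_params_cumsum, reshaping each piece to the corresponding leaf of init_params before the base PINN is evaluated.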
jinns/nn/_mlp.py ADDED
@@ -0,0 +1,217 @@
+ """
+ Implements utility functions to create PINNs
+ """
+
+ from __future__ import annotations
+
+ from typing import Callable, Literal, Self, Union, Any, TYPE_CHECKING, cast
+ from dataclasses import InitVar
+ import jax
+ import equinox as eqx
+ from typing import Protocol
+ from jaxtyping import Array, Key, PyTree, Float
+
+ from jinns.parameters._params import Params
+ from jinns.nn._pinn import PINN
+
+ if TYPE_CHECKING:
+
+     class CallableMLPModule(Protocol):
+         """
+         Basically just a way to add a __call__ to an eqx.Module.
+         https://github.com/patrick-kidger/equinox/issues/1002
+         We chose the structural subtyping of protocols instead of subclassing
+         an eqx.Module just to add a __call__ here
+         """
+
+         def __call__(self, *_, **__) -> Array: ...
+
+
+ class MLP(eqx.Module):
+     """
+     Custom MLP equinox module built from a key and an eqx_list
+
+     Parameters
+     ----------
+     key : InitVar[Key]
+         A jax random key for the layer initializations.
+     eqx_list : InitVar[tuple[tuple[Callable, int, int] | tuple[Callable], ...]]
+         A tuple of tuples of successive equinox modules and activation
+         functions that describes the PINN architecture. The inner tuples must
+         have the eqx module or activation function as first item; the other
+         items represent arguments that could be required (eg. the size of the
+         layer). The `key` argument need not be given.
+         A typical example is `eqx_list=
+         ((eqx.nn.Linear, 2, 20),
+         (jax.nn.tanh,),
+         (eqx.nn.Linear, 20, 20),
+         (jax.nn.tanh,),
+         (eqx.nn.Linear, 20, 20),
+         (jax.nn.tanh,),
+         (eqx.nn.Linear, 20, 1)
+         )`.
+     """
+
+     key: InitVar[Key] = eqx.field(kw_only=True)
+     eqx_list: InitVar[tuple[tuple[Callable, int, int] | tuple[Callable], ...]] = (
+         eqx.field(kw_only=True)
+     )
+
+     # NOTE that the following should NOT be declared as static otherwise the
+     # eqx.partition that we use in the PINN module will misbehave
+     layers: list[CallableMLPModule | Callable[[Array], Array]] = eqx.field(init=False)
+
+     def __post_init__(self, key, eqx_list):
+         self.layers = []
+         # nb_keys_required = sum(1 if len(l) > 1 else 0 for l in eqx_list)
+         # keys = jax.random.split(key, nb_keys_required)
+         # a global split before the loop (commented out above) would be
+         # needed to maintain strict equivalence with eqx.nn.MLP, for
+         # debugging purposes
+         k = 0
+         for l in eqx_list:
+             if len(l) == 1:
+                 self.layers.append(l[0])
+             else:
+                 key, subkey = jax.random.split(key, 2)  # nb_keys_required)
+                 self.layers.append(l[0](*l[1:], key=subkey))
+                 k += 1
+
+     def __call__(self, t: Float[Array, " input_dim"]) -> Float[Array, " output_dim"]:
+         for layer in self.layers:
+             t = layer(t)
+         return t
+
+
+ class PINN_MLP(PINN):
+     """
+     An implementable PINN based on an MLP architecture
+     """
+
+     # Here we could have a more complex __call__ method that redefines the
+     # parent's __call__. But there is no need for that in the simple PINN_MLP
+
+     @classmethod
+     def create(
+         cls,
+         eq_type: Literal["ODE", "statio_PDE", "nonstatio_PDE"],
+         eqx_network: eqx.nn.MLP | MLP | None = None,
+         key: Key = None,
+         eqx_list: tuple[tuple[Callable, int, int] | tuple[Callable], ...] | None = None,
+         input_transform: (
+             Callable[
+                 [Float[Array, " input_dim"], Params[Array]],
+                 Float[Array, " output_dim"],
+             ]
+             | None
+         ) = None,
+         output_transform: (
+             Callable[
+                 [
+                     Float[Array, " input_dim"],
+                     Float[Array, " output_dim"],
+                     Params[Array],
+                 ],
+                 Float[Array, " output_dim"],
+             ]
+             | None
+         ) = None,
+         slice_solution: slice | None = None,
+         filter_spec: PyTree[Union[bool, Callable[[Any], bool]]] = None,
+     ) -> tuple[Self, PINN]:
+         r"""
+         Instantiate a standard PINN MLP object. The actual NN is either
+         passed as an eqx.nn.MLP (`eqx_network` argument) or constructed as a
+         custom jinns.nn.MLP when `key` and `eqx_list` are provided.
+
+         Parameters
+         ----------
+         eq_type
+             A string with three possibilities.
+             "ODE": the MLP is called with one input `t`.
+             "statio_PDE": the MLP is called with one input `x`; `x`
+             can be high dimensional.
+             "nonstatio_PDE": the MLP is called with two inputs `t` and `x`;
+             `x` can be high dimensional.
+             **Note**: the input dimension given in eqx_list has to match the
+             sum of the dimensions of `t` and `x`, or the output dimension of
+             the `input_transform` function.
+         eqx_network
+             Default is None. An eqx.nn.MLP object that will be wrapped inside
+             our PINN_MLP object in order to make it easily jinns compatible.
+         key
+             Default is None. Must be provided with `eqx_list` if
+             `eqx_network` is not provided. A JAX random key that will be used
+             to initialize the network parameters.
+         eqx_list
+             Default is None. Must be provided if `eqx_network` is not
+             provided. A tuple of tuples of successive equinox modules and
+             activation functions that describes the MLP architecture. The
+             inner tuples must have the eqx module or activation function as
+             first item; the other items represent arguments that could be
+             required (eg. the size of the layer).
+
+             The `key` argument does not need to be given.
+
+             A typical example is `eqx_list = (
+             (eqx.nn.Linear, input_dim, 20),
+             (jax.nn.tanh,),
+             (eqx.nn.Linear, 20, 20),
+             (jax.nn.tanh,),
+             (eqx.nn.Linear, 20, 20),
+             (jax.nn.tanh,),
+             (eqx.nn.Linear, 20, output_dim)
+             )`.
+         input_transform
+             A function that will be called before entering the MLP. Its
+             output(s) must match the MLP inputs (except for the parameters).
+             Its inputs are the MLP inputs (`t` and/or `x` concatenated
+             together) and the parameters. Default is no operation.
+         output_transform
+             A function whose arguments are the same inputs as the MLP, the
+             MLP output and the parameters. This function will be called after
+             exiting the MLP. Default is no operation.
+         slice_solution
+             A jnp.s\_ object which indicates which axis of the MLP output is
+             dedicated to the actual equation solution. Default None
+             means that slice_solution = the whole MLP output. This argument
+             is useful when the MLP is also used to output equation
+             parameters, for example. Note that it must be a slice and not an
+             integer (a preprocessing of the user-provided argument takes care
+             of it).
+         filter_spec
+             Default is None, which leads to `eqx.is_inexact_array` in the
+             class instantiation. This tells jinns what to consider as a
+             trainable parameter. Quoting from the equinox documentation:
+             a PyTree whose structure should be a prefix of the structure of
+             pytree. Each of its leaves should either be 1) True, in which
+             case the leaf or subtree is kept; 2) False, in which case the
+             leaf or subtree is replaced with `replace`; 3) a callable
+             Leaf -> bool, in which case this is evaluated on the leaf or
+             mapped over the subtree, and the leaf kept or replaced as
+             appropriate.
+
+         Returns
+         -------
+         mlp
+             A PINN_MLP instance.
+         mlp.init_params
+             An initial set of parameters for the PINN_MLP.
+         """
+         if eqx_network is None:
+             if eqx_list is None or key is None:
+                 raise ValueError(
+                     "If eqx_network is None, then key and eqx_list must be provided"
+                 )
+             eqx_network = cast(MLP, MLP(key=key, eqx_list=eqx_list))
+
+         mlp = cls(
+             eqx_network=eqx_network,
+             slice_solution=slice_solution,  # type: ignore
+             eq_type=eq_type,
+             input_transform=input_transform,  # type: ignore
+             output_transform=output_transform,  # type: ignore
+             filter_spec=filter_spec,
+         )
+         return mlp, mlp.init_params
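A corresponding sketch for PINN_MLP.create, under the same assumption that PINN_MLP is re-exported from jinns.nn; the 2D stationary setup is purely illustrative.

    import jax
    import equinox as eqx
    from jinns.nn import PINN_MLP  # assumed re-export

    key = jax.random.PRNGKey(0)
    u, init_params = PINN_MLP.create(
        eq_type="statio_PDE",
        key=key,
        eqx_list=(
            (eqx.nn.Linear, 2, 20),  # input dim = dim of x for a stationary PDE
            (jax.nn.tanh,),
            (eqx.nn.Linear, 20, 20),
            (jax.nn.tanh,),
            (eqx.nn.Linear, 20, 1),  # scalar solution
        ),
    )
    # u(inputs, params) then evaluates the network, with init_params carried
    # in the nn_params field of a jinns Params object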