jinns 1.2.0__py3-none-any.whl → 1.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
jinns/nn/_spinn_mlp.py ADDED
@@ -0,0 +1,196 @@
1
+ """
2
+ Implements utility function to create Separable PINNs
3
+ https://arxiv.org/abs/2211.08761
4
+ """
5
+
6
+ from dataclasses import InitVar
7
+ from typing import Callable, Literal, Self, Union, Any
8
+ import jax
9
+ import jax.numpy as jnp
10
+ import equinox as eqx
11
+ from jaxtyping import Key, Array, Float, PyTree
12
+
13
+ from jinns.parameters._params import Params, ParamsDict
14
+ from jinns.nn._mlp import MLP
15
+ from jinns.nn._spinn import SPINN
16
+
17
+
18
class SMLP(eqx.Module):
    """
    Separable MLP: one independent MLP per separated input dimension.

    Each of the ``d`` input coordinates is handled by its own `MLP`. All of
    them are built from the same ``eqx_list`` description but each receives
    its own sub-key drawn from ``key``.

    Parameters
    ----------
    key : InitVar[Key]
        A jax random key, split into ``d`` sub-keys for the layer
        initializations.
    d : int
        The number of dimensions to treat separately, including time `t` if
        used for non-stationary equations.
    eqx_list : InitVar[tuple[tuple[Callable, int, int] | Callable, ...]]
        A tuple of tuples of successive equinox modules and activation
        functions describing each per-dimension MLP. The inner tuples must
        have the eqx module or activation function as first item; the other
        items are the arguments it requires (e.g. the size of the layer).
        The `key` argument need not be given.
        A typical example is `eqx_list=
            ((eqx.nn.Linear, 1, 20),
            jax.nn.tanh,
            (eqx.nn.Linear, 20, 20),
            jax.nn.tanh,
            (eqx.nn.Linear, 20, 20),
            jax.nn.tanh,
            (eqx.nn.Linear, 20, r * m)
        )`.
    """

    key: InitVar[Key] = eqx.field(kw_only=True)
    eqx_list: InitVar[tuple[tuple[Callable, int, int] | Callable, ...]] = eqx.field(
        kw_only=True
    )
    d: int = eqx.field(static=True, kw_only=True)

    # The d per-dimension networks; filled in by __post_init__.
    separated_mlp: list[MLP] = eqx.field(init=False)

    def __post_init__(self, key, eqx_list):
        # One sub-key per separated dimension, so every MLP gets its own
        # initialization.
        subkeys = jax.random.split(key, self.d)
        self.separated_mlp = [MLP(key=subkey, eqx_list=eqx_list) for subkey in subkeys]

    def __call__(
        self, inputs: Float[Array, "dim"] | Float[Array, "dim+1"]
    ) -> Float[Array, "d embed_dim*output_dim"]:
        """
        Evaluate each per-dimension MLP on its own input coordinate.

        For ``i`` in ``range(d)``, the length-1 slice ``inputs[i : i + 1]``
        is fed to ``separated_mlp[i]``. The ``d`` results are gathered with
        ``jnp.asarray`` into one array whose leading axis indexes the
        separated dimension.
        """
        per_dim = [
            self.separated_mlp[i](inputs[i : i + 1]) for i in range(self.d)
        ]
        return jnp.asarray(per_dim)
68
+
69
+
70
class SPINN_MLP(SPINN):
    """
    An implementable SPINN based on an MLP architecture.

    Build instances with `SPINN_MLP.create`, which also returns the initial
    parameters.
    """

    @classmethod
    def create(
        cls,
        key: Key,
        d: int,
        r: int,
        eqx_list: tuple[tuple[Callable, int, int] | Callable, ...],
        eq_type: Literal["ODE", "statio_PDE", "nonstatio_PDE"],
        m: int = 1,
        filter_spec: PyTree[Union[bool, Callable[[Any], bool]]] = None,
    ) -> tuple[Self, PyTree]:
        """
        Utility function to create a SPINN neural network with the equinox
        library.

        *Note* that a SPINN is not vmapped and expects the
        same batch size for each of its input axis. It directly outputs a
        solution of shape `(batchsize,) * d`. See the paper for more
        details.

        Parameters
        ----------
        key : Key
            A JAX random key that will be used to initialize the network parameters
        d : int
            The number of dimensions to treat separately.
        r : int
            An integer. The dimension of the embedding.
        eqx_list : tuple[tuple[Callable, int, int] | Callable, ...],
            A tuple of tuples of successive equinox modules and activation functions to
            describe the PINN architecture. The inner tuples must have the eqx module or
            activation function as first item; the other items represent arguments
            that could be required (eg. the size of the layer).
            The `key` argument need not be given.
            Thus typical example is
            `eqx_list=((eqx.nn.Linear, 1, 20),
                jax.nn.tanh,
                (eqx.nn.Linear, 20, 20),
                jax.nn.tanh,
                (eqx.nn.Linear, 20, 20),
                jax.nn.tanh,
                (eqx.nn.Linear, 20, r * m)
            )`.
            The first layer must declare an input size of 1 and the last layer
            an output size of `r * m`. A parameter-free first or last entry
            (e.g. a flatten or a `jnp.exp`) is allowed; in that case the
            sizes are read from the neighbouring entry instead.
        eq_type : Literal["ODE", "statio_PDE", "nonstatio_PDE"]
            A string with three possibilities.
            "ODE": the PINN is called with one input `t`.
            "statio_PDE": the PINN is called with one input `x`, `x`
            can be high dimensional.
            "nonstatio_PDE": the PINN is called with two inputs `t` and `x`, `x`
            can be high dimensional.
            **Note**: the input dimension as given in eqx_list has to match the sum
            of the dimension of `t` + the dimension of `x`.
        m : int
            The output dimension of the neural network. According to
            the SPINN article, a total embedding dimension of `r*m` is defined. We
            then sum groups of `r` embedding dimensions to compute each output.
            Default is 1.
        filter_spec : PyTree[Union[bool, Callable[[Any], bool]]]
            Default is None which leads to `eqx.is_inexact_array` in the class
            instantiation. This tells Jinns what to consider as
            a trainable parameter. Quoting from equinox documentation:
            a PyTree whose structure should be a prefix of the structure of pytree.
            Each of its leaves should either be 1) True, in which case the leaf or
            subtree is kept; 2) False, in which case the leaf or subtree is
            replaced with replace; 3) a callable Leaf -> bool, in which case this
            is evaluated on the leaf or mapped over the subtree, and the leaf kept
            or replaced as appropriate.

        Returns
        -------
        spinn
            An instantiated SPINN
        spinn.init_params
            The initial set of parameters of the model

        Raises
        ------
        RuntimeError
            If the parameter value for eq_type is not in `["ODE", "statio_PDE",
            "nonstatio_PDE"]`.
        ValueError
            If the input size declared in `eqx_list` is not 1, if the output
            size declared in `eqx_list` is not `r * m`, or if `d > 24`.
        """

        if eq_type not in ["ODE", "statio_PDE", "nonstatio_PDE"]:
            raise RuntimeError("Wrong parameter value for eq_type")

        # Normally we look for the 2nd item of the 1st layer. The 1st entry
        # can also be a parameter-free layer (eg. a flatten). That entry is
        # either a short tuple (IndexError) or a bare callable / module that
        # cannot be subscripted (TypeError); in both cases read the 2nd entry.
        try:
            nb_inputs_declared = eqx_list[0][1]
        except (IndexError, TypeError):
            nb_inputs_declared = eqx_list[1][1]
        if nb_inputs_declared != 1:
            raise ValueError("Input dim must be set to 1 in SPINN!")

        # Normally we look for the 3rd item of the last layer. The last entry
        # can also be a bare activation (eg. `jnp.exp`), which raises
        # TypeError when subscripted, or a short tuple (IndexError); in both
        # cases read the next-to-last entry.
        try:
            nb_outputs_declared = eqx_list[-1][2]
        except (IndexError, TypeError):
            nb_outputs_declared = eqx_list[-2][2]
        if nb_outputs_declared != r * m:
            raise ValueError("Output dim must be set to r * m in SPINN!")

        if d > 24:
            raise ValueError(
                "Too many dimensions, not enough letters available in jnp.einsum"
            )

        smlp = SMLP(key=key, d=d, eqx_list=eqx_list)
        spinn = cls(
            eqx_spinn_network=smlp,
            d=d,
            r=r,
            eq_type=eq_type,
            m=m,
            filter_spec=filter_spec,
        )

        return spinn, spinn.init_params
jinns/plot/_plot.py CHANGED
@@ -208,7 +208,7 @@ def _plot_2D_statio(
208
208
  figsize :
209
209
  By default (7, 7)
210
210
  spinn :
211
- True if a SPINN is to be plotted. False for PINNs and HYPERPINNs
211
+ True if a SPINN is to be plotted. False for PINNs and HyperPINNs
212
212
  vmin_vmax: list, optional
213
213
  The colorbar minimum and maximum value. Defaults None.
214
214
 
@@ -272,7 +272,7 @@ def plot1d_slice(
272
272
  figsize
273
273
  size of the figure, by default (10, 10)
274
274
  spinn
275
- True if a SPINN is to be plotted. False for PINNs and HYPERPINNs
275
+ True if a SPINN is to be plotted. False for PINNs and HyperPINNs
276
276
  ax
277
277
  A pre-defined `matplotlib.Axes` where you want to plot.
278
278
 
@@ -341,7 +341,7 @@ def plot1d_image(
341
341
  cmap :
342
342
  the matplotlib color map used in the ImageGrid.
343
343
  spinn :
344
- True if a SPINN is to be plotted. False for PINNs and HYPERPINNs
344
+ True if a SPINN is to be plotted. False for PINNs and HyperPINNs
345
345
  vmin_vmax:
346
346
  The colorbar minimum and maximum value. Defaults None.
347
347
 
jinns/solver/_rar.py CHANGED
@@ -19,8 +19,8 @@ from jinns.data._DataGenerators import (
19
19
  CubicMeshPDEStatio,
20
20
  CubicMeshPDENonStatio,
21
21
  )
22
- from jinns.utils._hyperpinn import HYPERPINN
23
- from jinns.utils._spinn import SPINN
22
+ from jinns.nn._hyperpinn import HyperPINN
23
+ from jinns.nn._spinn import SPINN
24
24
 
25
25
 
26
26
  if TYPE_CHECKING:
@@ -115,7 +115,7 @@ def _rar_step_init(sample_size: Int, selected_sample_size: Int) -> tuple[
115
115
 
116
116
  def rar_step_true(operands: rar_operands) -> AnyDataGenerator:
117
117
  loss, params, data, i = operands
118
- if isinstance(loss.u, HYPERPINN) or isinstance(loss.u, SPINN):
118
+ if isinstance(loss.u, HyperPINN) or isinstance(loss.u, SPINN):
119
119
  raise NotImplementedError("RAR not implemented for hyperPINN and SPINN")
120
120
 
121
121
  if isinstance(data, DataGeneratorODE):
jinns/solver/_solve.py CHANGED
@@ -47,6 +47,7 @@ def solve(
47
47
  validation: AbstractValidationModule | None = None,
48
48
  obs_batch_sharding: jax.sharding.Sharding | None = None,
49
49
  verbose: Bool = True,
50
+ ahead_of_time: Bool = True,
50
51
  ) -> tuple[
51
52
  Params | ParamsDict,
52
53
  Float[Array, "n_iter"],
@@ -118,6 +119,14 @@ def solve(
118
119
  verbose
119
120
  Default True. If False, no std output (loss or cause of
120
121
  exiting the optimization loop) will be produced.
122
+ ahead_of_time
123
+ Default True. Separate the compilation of the main training loop from
124
+ the execution to get both timings. You might need to avoid this
125
+ behaviour if you need to perform JAX transforms over chunks of code
126
+ containing `jinns.solve()` since AOT-compiled functions cannot be JAX
127
+ transformed (see https://jax.readthedocs.io/en/latest/aot.html#aot-compiled-functions-cannot-be-transformed).
128
+ When False, jinns does not provide any timing information (which would
129
+ be nonsense in a JIT transformed `solve()` function).
121
130
 
122
131
  Returns
123
132
  -------
@@ -384,16 +393,21 @@ def solve(
384
393
  def train_fun(carry):
385
394
  return jax.lax.while_loop(break_fun, _one_iteration, carry)
386
395
 
387
- start = time.time()
388
- compiled_train_fun = jax.jit(train_fun).lower(carry).compile()
389
- end = time.time()
390
- print("\nCompilation took\n", end - start, "\n")
396
+ if ahead_of_time:
397
+ start = time.time()
398
+ compiled_train_fun = jax.jit(train_fun).lower(carry).compile()
399
+ end = time.time()
400
+ if verbose:
401
+ print("\nCompilation took\n", end - start, "\n")
391
402
 
392
- start = time.time()
393
- carry = compiled_train_fun(carry)
394
- jax.block_until_ready(carry)
395
- end = time.time()
396
- print("\nTraining took\n", end - start, "\n")
403
+ start = time.time()
404
+ carry = compiled_train_fun(carry)
405
+ jax.block_until_ready(carry)
406
+ end = time.time()
407
+ if verbose:
408
+ print("\nTraining took\n", end - start, "\n")
409
+ else:
410
+ carry = train_fun(carry)
397
411
 
398
412
  (
399
413
  i,
jinns/utils/__init__.py CHANGED
@@ -1,6 +1 @@
1
- from ._pinn import create_PINN, PINN
2
- from ._ppinn import create_PPINN, PPINN
3
- from ._spinn import create_SPINN, SPINN
4
- from ._hyperpinn import create_HYPERPINN, HYPERPINN
5
- from ._save_load import save_pinn, load_pinn
6
1
  from ._utils import get_grid
jinns/utils/_types.py CHANGED
@@ -26,9 +26,9 @@ if TYPE_CHECKING:
26
26
 
27
27
  from jinns.loss import DynamicLoss
28
28
  from jinns.data._Batchs import *
29
- from jinns.utils._pinn import PINN
30
- from jinns.utils._hyperpinn import HYPERPINN
31
- from jinns.utils._spinn import SPINN
29
+ from jinns.nn._pinn import PINN
30
+ from jinns.nn._hyperpinn import HyperPINN
31
+ from jinns.nn._spinn_mlp import SPINN
32
32
  from jinns.utils._containers import *
33
33
  from jinns.validation._validation import AbstractValidationModule
34
34
 
@@ -42,7 +42,7 @@ if TYPE_CHECKING:
42
42
  DataGeneratorODE | CubicMeshPDEStatio | CubicMeshPDENonStatio
43
43
  )
44
44
 
45
- AnyPINN: TypeAlias = PINN | HYPERPINN | SPINN
45
+ AnyPINN: TypeAlias = PINN | HyperPINN | SPINN
46
46
 
47
47
  AnyBatch: TypeAlias = ODEBatch | PDEStatioBatch | PDENonStatioBatch
48
48
  rar_operands = NewType(
@@ -1,6 +1,6 @@
1
- Metadata-Version: 2.1
1
+ Metadata-Version: 2.2
2
2
  Name: jinns
3
- Version: 1.2.0
3
+ Version: 1.3.0
4
4
  Summary: Physics Informed Neural Network with JAX
5
5
  Author-email: Hugo Gangloff <hugo.gangloff@inrae.fr>, Nicolas Jouvin <nicolas.jouvin@inrae.fr>
6
6
  Maintainer-email: Hugo Gangloff <hugo.gangloff@inrae.fr>, Nicolas Jouvin <nicolas.jouvin@inrae.fr>
@@ -112,16 +112,16 @@ pre-commit install
112
112
 
113
113
  Don't hesitate to contribute and get your name on the list here !
114
114
 
115
- **List of contributors:** Hugo Gangloff, Nicolas Jouvin
115
+ **List of contributors:** Hugo Gangloff, Nicolas Jouvin, Lucia Clarotto, Inass Soukarieh
116
116
 
117
117
  # Cite us
118
118
 
119
- Please consider citing our work if you found it useful to yours, using the following lines
119
+ Please consider citing our work if you find it useful in your own, using this [arXiv preprint](https://arxiv.org/abs/2412.14132)
120
120
  ```
121
- @software{jinns2024,
122
- title={\texttt{jinns}: Physics-Informed Neural Networks with JAX},
123
- author={Gangloff, Hugo and Jouvin, Nicolas},
124
- url={https://gitlab.com/mia_jinns},
125
- year={2024}
121
+ @article{gangloff_jouvin2024jinns,
122
+ title={jinns: a JAX Library for Physics-Informed Neural Networks},
123
+ author={Gangloff, Hugo and Jouvin, Nicolas},
124
+ journal={arXiv preprint arXiv:2412.14132},
125
+ year={2024}
126
126
  }
127
127
  ```
@@ -0,0 +1,44 @@
1
+ jinns/__init__.py,sha256=5p7V5VJd7PXEINqhqS4mUsnQtXlyPwfctRhL4p0loFg,181
2
+ jinns/data/_Batchs.py,sha256=oc7-N1wEbsEvbe9fjVFKG2OPoZJVEjzPm8uj_icACf4,817
3
+ jinns/data/_DataGenerators.py,sha256=3pyUqzQ12AUBqOV-yqpt4X6K_7CqTFtUKMjg-gJE6KA,65101
4
+ jinns/data/__init__.py,sha256=TRCH0Z4-SQZ50MbSf46CUYWBkWVDmXCyez9T-EGiv_8,338
5
+ jinns/experimental/__init__.py,sha256=3jCIy2R2i_0Erwxg-HwISdH79Nt1XCXhS9yY1F5awiY,208
6
+ jinns/experimental/_diffrax_solver.py,sha256=upMr3kTTNrxEiSUO_oLvCXcjS9lPxSjvbB81h3qlhaU,6813
7
+ jinns/loss/_DynamicLoss.py,sha256=lUpFl37_TfwxSREpoVKqUOpQEVqD3hrFXqwP2GZReWw,25817
8
+ jinns/loss/_DynamicLossAbstract.py,sha256=bqmPxyrcvZh_dL74DTpj-TGiFxchvG8qC6KhuGeyOoA,12006
9
+ jinns/loss/_LossODE.py,sha256=QhhSyJpDbcyW4TdShX0HkxbvJQWXvnYg8lik8_wyOg4,23415
10
+ jinns/loss/_LossPDE.py,sha256=DZPinl7KYV2vp_CdjnhaR9M_gE-WOvyi4s8VSDEgti0,51046
11
+ jinns/loss/__init__.py,sha256=PRiJV9fd2GSwaCBVCPyh6pFc6pdA40jfb_T1YvO8ERc,712
12
+ jinns/loss/_boundary_conditions.py,sha256=kxHwNFSMsNzFso6nvAewcAdzW50yTi7IX-5Pthe65XY,12271
13
+ jinns/loss/_loss_utils.py,sha256=IkZAWmBumNWwk3hzeO0dh5RjHKZpt_hL4XnG5-Gpfr8,14690
14
+ jinns/loss/_loss_weights.py,sha256=F0Fgji2XpVk3pr9oIryGuXcG1FGQo4Dv6WFgze2BtA0,2201
15
+ jinns/loss/_operators.py,sha256=qaRxwqgnZzlE_zTyUvafZGnUH5EZY1lpgjT9Vb7QJAQ,21718
16
+ jinns/nn/__init__.py,sha256=k9guJSKmKlHEadAjU-0HlYXJe55Tt783QrkZz6EYyO8,231
17
+ jinns/nn/_hyperpinn.py,sha256=nH8c9DeiiAujprEd7CVKU1chWn-kcSAY-fYLzd8_ikY,18049
18
+ jinns/nn/_mlp.py,sha256=AbbFLF85ayJcQ6kVwfSNdAvjP69UWBP6Z3V-1De-pI4,8028
19
+ jinns/nn/_pinn.py,sha256=45lXgrZQHv-7PQ3EDWWIoo8FlXRnjL1nl7mALTSJ45o,8391
20
+ jinns/nn/_ppinn.py,sha256=vqIH_v1DF3LoHyl3pJ1qhfnGMRMfvbfNK6m9s5LC21k,9212
21
+ jinns/nn/_save_load.py,sha256=VaO9LtR6dajEfo8iP7FgOvyLdQxT2IawazC2sxs97lc,9139
22
+ jinns/nn/_spinn.py,sha256=QmKhDZ0-ToJk3_glQ9BQWgoC0d-EEAWxMrDeHfB2slw,4191
23
+ jinns/nn/_spinn_mlp.py,sha256=9iU_-TIUFMVBcYv0nQmsa07ZwApIKqnXm7v4CY87PTo,7224
24
+ jinns/parameters/__init__.py,sha256=1gxNLoAXUjhUzBWuh86YjU5pYy8SOboCs8TrKcU1wZc,158
25
+ jinns/parameters/_derivative_keys.py,sha256=UyEcgfNF1vwPcGWD2ShAZkZiq4thzRDm_OUJzOfjjiY,21909
26
+ jinns/parameters/_params.py,sha256=wK9ZSqoL9KnjOWqc_ZhJ09ffbsgeUEcttc1Rhme0lLk,3550
27
+ jinns/plot/__init__.py,sha256=Q279h5veYWNLQyttsC8_tDOToqUHh8WaRON90CiWXqk,81
28
+ jinns/plot/_plot.py,sha256=6OqCNvOeqbat3dViOtehILbRfGIS3pnTmNRfbZYaVTA,11433
29
+ jinns/solver/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
30
+ jinns/solver/_rar.py,sha256=JU4FgWt5w3tzgn2mNyftGi8Erxn5N0Za60-lRaL2poI,9724
31
+ jinns/solver/_solve.py,sha256=Bh7uplfcInJEQj1wmMquisN_vvUghARgX_uaYf7NUpw,23423
32
+ jinns/solver/_utils.py,sha256=b2zYvwZY_fU0NMNWvUEMvHez9s7hwcxfpGzQlz5F6HA,5762
33
+ jinns/utils/__init__.py,sha256=uw3I-lWT3wLabo6-H8FbKpSXI2xobzSs2W-Xno280g0,29
34
+ jinns/utils/_containers.py,sha256=a7A-iUApnjc1YVc7bdt9tKUvHHPDOKMB9OfdrDZGWN8,1450
35
+ jinns/utils/_types.py,sha256=4Qgsg6r9UPGpRwmERv4Cx2nU5ZIweehDlZQPo-FuR4Y,1896
36
+ jinns/utils/_utils.py,sha256=hoRcJqcTuQi_Ip40oI4EbxW46E1rp2C01_HfuCpwKRM,2932
37
+ jinns/validation/__init__.py,sha256=Jv58mzgC3F7cRfXA6caicL1t_U0UAhbwLrmMNVg6E7s,66
38
+ jinns/validation/_validation.py,sha256=bvqL2poTFJfn9lspWqMqXvQGcQIodKwKrC786QtEZ7A,4700
39
+ jinns-1.3.0.dist-info/AUTHORS,sha256=7NwCj9nU-HNG1asvy4qhQ2w7oZHrn-Lk5_wK_Ve7a3M,80
40
+ jinns-1.3.0.dist-info/LICENSE,sha256=BIAkGtXB59Q_BG8f6_OqtQ1BHPv60ggE9mpXJYz2dRM,11337
41
+ jinns-1.3.0.dist-info/METADATA,sha256=PM3iLQFd-vHDU697ECGjD2vQpgxo1vo1GTFl5AdIWoo,4744
42
+ jinns-1.3.0.dist-info/WHEEL,sha256=jB7zZ3N9hIM9adW7qlTAyycLYW9npaWKLRzaoVcLKcM,91
43
+ jinns-1.3.0.dist-info/top_level.txt,sha256=RXbkr2hzy8WBE8aiRyrJYFqn3JeMJIhMdybLjjLTB9c,6
44
+ jinns-1.3.0.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (75.6.0)
2
+ Generator: setuptools (75.8.2)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5