jinns 0.4.2__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
jinns/utils/_utils.py CHANGED
@@ -57,314 +57,21 @@ def _tracked_parameters(params, tracked_params_key_list):
57
57
  return tracked_params
58
58
 
59
59
 
60
- class _MLP(eqx.Module):
60
+ def _get_grid(in_array):
61
61
  """
62
- Class to construct an equinox module from a key and a eqx_list. To be used
63
- in pair with the function `create_PINN`
62
+ From an array of shape (B, D), D > 1, get the grid array, i.e., an array of
63
+ shape (B, B, ...(D times)..., B, D): along the last axis we have the array
64
+ of values
64
65
  """
65
-
66
- layers: list
67
-
68
- def __init__(self, key, eqx_list):
69
- """
70
- Parameters
71
- ----------
72
- key
73
- A jax random key
74
- eqx_list
75
- A list of list of successive equinox modules and activation functions to
76
- describe the PINN architecture. The inner lists have the eqx module or
77
- axtivation function as first item, other items represents arguments
78
- that could be required (eg. the size of the layer).
79
- __Note:__ the `key` argument need not be given.
80
- Thus typical example is `eqx_list=
81
- [[eqx.nn.Linear, 2, 20],
82
- [jax.nn.tanh],
83
- [eqx.nn.Linear, 20, 20],
84
- [jax.nn.tanh],
85
- [eqx.nn.Linear, 20, 20],
86
- [jax.nn.tanh],
87
- [eqx.nn.Linear, 20, 1]
88
- ]`
89
- """
90
-
91
- self.layers = []
92
- # TODO we are limited currently in the number of layer type we can
93
- # parse and we lack some safety checks
94
- for l in eqx_list:
95
- if len(l) == 1:
96
- self.layers.append(l[0])
97
- else:
98
- # By default we append a random key at the end of the
99
- # arguments fed into a layer module call
100
- key, subkey = jax.random.split(key, 2)
101
- # the argument key is keyword only
102
- self.layers.append(l[0](*l[1:], key=subkey))
103
-
104
- def __call__(self, t):
105
- for layer in self.layers:
106
- t = layer(t)
107
- return t
108
-
109
-
110
- def create_PINN(
111
- key,
112
- eqx_list,
113
- eq_type,
114
- dim_x=0,
115
- with_eq_params=None,
116
- input_transform=None,
117
- output_transform=None,
118
- ):
119
- """
120
- Utility function to create a standard PINN neural network with the equinox
121
- library.
122
-
123
- Parameters
124
- ----------
125
- key
126
- A jax random key that will be used to initialize the network parameters
127
- eqx_list
128
- A list of list of successive equinox modules and activation functions to
129
- describe the PINN architecture. The inner lists have the eqx module or
130
- axtivation function as first item, other items represents arguments
131
- that could be required (eg. the size of the layer).
132
- __Note:__ the `key` argument need not be given.
133
- Thus typical example is `eqx_list=
134
- [[eqx.nn.Linear, 2, 20],
135
- [jax.nn.tanh],
136
- [eqx.nn.Linear, 20, 20],
137
- [jax.nn.tanh],
138
- [eqx.nn.Linear, 20, 20],
139
- [jax.nn.tanh],
140
- [eqx.nn.Linear, 20, 1]
141
- ]`
142
- eq_type
143
- A string with three possibilities.
144
- "ODE": the PINN is called with one input `t`.
145
- "statio_PDE": the PINN is called with one input `x`, `x`
146
- can be high dimensional.
147
- "nonstatio_PDE": the PINN is called with two inputs `t` and `x`, `x`
148
- can be high dimensional.
149
- **Note: the input dimension as given in eqx_list has to match the sum
150
- of the dimension of `t` + the dimension of `x` + the number of
151
- parameters in `eq_params` if with_eq_params is `True` (see below)**
152
- dim_x
153
- An integer. The dimension of `x`. Default `0`
154
- with_eq_params
155
- Default is None. Otherwise a list of keys from the dict `eq_params`
156
- that the network also takes as inputs.
157
- the equation parameters (`eq_params`).
158
- **If some keys are provided, the input dimension
159
- as given in eqx_list must take into account the number of such provided
160
- keys (i.e., the input dimension is the addition of the dimension of ``t``
161
- + the dimension of ``x`` + the number of ``eq_params``)**
162
- input_transform
163
- A function that will be called before entering the PINN. Its output(s)
164
- must mathc the PINN inputs.
165
- output_transform
166
- A function with arguments the same input(s) as the PINN AND the PINN
167
- output that will be called after exiting the PINN
168
-
169
-
170
- Returns
171
- -------
172
- init_fn
173
- A function which (re-)initializes the PINN parameters with the provided
174
- jax random key
175
- apply_fn
176
- A function to apply the neural network on given inputs for given
177
- parameters. A typical call will be of the form `u(t, nn_params)` for
178
- ODE or `u(t, x, nn_params)` for nD PDEs (`x` being multidimensional)
179
- or even `u(t, x, nn_params, eq_params)` if with_eq_params is `True`
180
-
181
- Raises
182
- ------
183
- RuntimeError
184
- If the parameter value for eq_type is not in `["ODE", "statio_PDE",
185
- "nonstatio_PDE"]`
186
- RuntimeError
187
- If we have a `dim_x > 0` and `eq_type == "ODE"`
188
- or if we have a `dim_x = 0` and `eq_type != "ODE"`
189
- """
190
- if eq_type not in ["ODE", "statio_PDE", "nonstatio_PDE"]:
191
- raise RuntimeError("Wrong parameter value for eq_type")
192
-
193
- if eq_type == "ODE" and dim_x != 0:
194
- raise RuntimeError("Wrong parameter combination eq_type and dim_x")
195
-
196
- if eq_type != "ODE" and dim_x == 0:
197
- raise RuntimeError("Wrong parameter combination eq_type and dim_x")
198
-
199
- dim_t = 0 if eq_type == "statio_PDE" else 1
200
- dim_in_params = len(with_eq_params) if with_eq_params is not None else 0
201
- try:
202
- nb_inputs_declared = eqx_list[0][1] # normally we look for 2nd ele of 1st layer
203
- except IndexError:
204
- nb_inputs_declared = eqx_list[1][
205
- 1
206
- ] # but we can have, eg, a flatten first layer
207
-
208
- # NOTE Currently the check below is disabled because we added
209
- # input_transform
210
- # if dim_t + dim_x + dim_in_params != nb_inputs_declared:
211
- # raise RuntimeError("Error in the declarations of the number of parameters")
212
-
213
- def make_mlp(key, eqx_list):
214
- mlp = _MLP(key, eqx_list)
215
- params, static = eqx.partition(mlp, eqx.is_inexact_array)
216
-
217
- def init_fn():
218
- return params
219
-
220
- if eq_type == "ODE":
221
- if with_eq_params is None:
222
-
223
- def apply_fn(t, u_params, eq_params=None):
224
- model = eqx.combine(u_params, static)
225
- t = t[
226
- None
227
- ] # Note that we added a dimension to t which is lacking for the ODE batches
228
- if output_transform is None:
229
- if input_transform is not None:
230
- return model(input_transform(t)).squeeze()
231
- else:
232
- return model(t).squeeze()
233
- else:
234
- if input_transform is not None:
235
- return output_transform(
236
- t, model(input_transform(t)).squeeze()
237
- )
238
- else:
239
- return output_transform(t, model(t).squeeze())
240
-
241
- else:
242
-
243
- def apply_fn(t, u_params, eq_params):
244
- model = eqx.combine(u_params, static)
245
- t = t[
246
- None
247
- ] # We added a dimension to t which is lacking for the ODE batches
248
- eq_params_flatten = jnp.concatenate(
249
- [e.ravel() for k, e in eq_params.items() if k in with_eq_params]
250
- )
251
- t_eq_params = jnp.concatenate([t, eq_params_flatten], axis=-1)
252
-
253
- if output_transform is None:
254
- if input_transform is not None:
255
- return model(input_transform(t_eq_params)).squeeze()
256
- else:
257
- return model(t_eq_params).squeeze()
258
- else:
259
- if input_transform is not None:
260
- return output_transform(
261
- t_eq_params,
262
- model(input_transform(t_eq_params)).squeeze(),
263
- )
264
- else:
265
- return output_transform(
266
- t_eq_params, model(t_eq_params).squeeze()
267
- )
268
-
269
- elif eq_type == "statio_PDE":
270
- # Here we add an argument `x` which can be high dimensional
271
- if with_eq_params is None:
272
-
273
- def apply_fn(x, u_params, eq_params=None):
274
- model = eqx.combine(u_params, static)
275
-
276
- if output_transform is None:
277
- if input_transform is not None:
278
- return model(input_transform(x)).squeeze()
279
- else:
280
- return model(x).squeeze()
281
- else:
282
- if input_transform is not None:
283
- return output_transform(
284
- x, model(input_transform(x)).squeeze()
285
- )
286
- else:
287
- return output_transform(x, model(x).squeeze())
288
-
289
- else:
290
-
291
- def apply_fn(x, u_params, eq_params):
292
- model = eqx.combine(u_params, static)
293
- eq_params_flatten = jnp.concatenate(
294
- [e.ravel() for k, e in eq_params.items() if k in with_eq_params]
295
- )
296
- x_eq_params = jnp.concatenate([x, eq_params_flatten], axis=-1)
297
-
298
- if output_transform is None:
299
- if input_transform is not None:
300
- return model(input_transform(x_eq_params)).squeeze()
301
- else:
302
- return model(x_eq_params).squeeze()
303
- else:
304
- if input_transform is not None:
305
- return output_transform(
306
- x_eq_params,
307
- model(input_transform(x_eq_params)).squeeze(),
308
- )
309
- else:
310
- return output_transform(
311
- x_eq_params, model(x_eq_params).squeeze()
312
- )
313
-
314
- elif eq_type == "nonstatio_PDE":
315
- # Here we add an argument `x` which can be high dimensional
316
- if with_eq_params is None:
317
-
318
- def apply_fn(t, x, u_params, eq_params=None):
319
- model = eqx.combine(u_params, static)
320
- t_x = jnp.concatenate([t, x], axis=-1)
321
-
322
- if output_transform is None:
323
- if input_transform is not None:
324
- return model(input_transform(t_x)).squeeze()
325
- else:
326
- return model(t_x).squeeze()
327
- else:
328
- if input_transform is not None:
329
- return output_transform(
330
- t_x, model(input_transform(t_x)).squeeze()
331
- )
332
- else:
333
- return output_transform(t_x, model(t_x).squeeze())
334
-
335
- else:
336
-
337
- def apply_fn(t, x, u_params, eq_params):
338
- model = eqx.combine(u_params, static)
339
- t_x = jnp.concatenate([t, x], axis=-1)
340
- eq_params_flatten = jnp.concatenate(
341
- [e.ravel() for k, e in eq_params.items() if k in with_eq_params]
342
- )
343
- t_x_eq_params = jnp.concatenate([t_x, eq_params_flatten], axis=-1)
344
-
345
- if output_transform is None:
346
- if input_transform is not None:
347
- return model(input_transform(t_x_eq_params)).squeeze()
348
- else:
349
- return model(t_x_eq_params).squeeze()
350
- else:
351
- if input_transform is not None:
352
- return output_transform(
353
- t_x_eq_params,
354
- model(input_transform(t_x_eq_params)).squeeze(),
355
- )
356
- else:
357
- return output_transform(
358
- t_x_eq_params,
359
- model(input_transform(t_x_eq_params)).squeeze(),
360
- )
361
-
362
- else:
363
- raise RuntimeError("Wrong parameter value for eq_type")
364
-
365
- return init_fn, apply_fn
366
-
367
- return make_mlp(key, eqx_list)
66
+ if in_array.shape[-1] > 1 or in_array.ndim > 1:
67
+ return jnp.stack(
68
+ jnp.meshgrid(
69
+ *(in_array[..., d] for d in range(in_array.shape[-1])), indexing="ij"
70
+ ),
71
+ axis=-1,
72
+ )
73
+ else:
74
+ return in_array
368
75
 
369
76
 
370
77
  def _get_vmap_in_axes_params(eq_params_batch_dict, params):
@@ -392,6 +99,25 @@ def _get_vmap_in_axes_params(eq_params_batch_dict, params):
392
99
  return vmap_in_axes_params
393
100
 
394
101
 
102
+ def _check_user_func_return(r, shape):
103
+ """
104
+ Correctly handles the result from a user defined function (eg a boundary
105
+ condition) to get the correct broadcast
106
+ """
107
+ if isinstance(r, int) or isinstance(r, float):
108
+ # if we have a scalar cast it to float
109
+ return float(r)
110
+ if r.shape == () or len(r.shape) == 1:
111
+ # if we have a scalar (or a vector, but no batch dim) inside an array
112
+ return r.astype(float)
113
+ else:
114
+ # if we have an array of the shape of the batch dimension(s) check that
115
+ # we have the correct broadcast
116
+ # the reshape below avoids a missing (1,) ending dimension
117
+ # depending on how the user has coded the inital function
118
+ return r.reshape(shape)
119
+
120
+
395
121
  def alternate_optax_solver(
396
122
  steps, parameters_set1, parameters_set2, lr_set1, lr_set2, label_fn=None
397
123
  ):
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: jinns
3
- Version: 0.4.2
3
+ Version: 0.5.0
4
4
  Summary: Physics Informed Neural Network with JAX
5
5
  Author-email: Hugo Gangloff <hugo.gangloff@inrae.fr>, Nicolas Jouvin <nicolas.jouvin@inrae.fr>
6
6
  Maintainer-email: Hugo Gangloff <hugo.gangloff@inrae.fr>, Nicolas Jouvin <nicolas.jouvin@inrae.fr>
@@ -29,6 +29,18 @@ jinns
29
29
 Physics Informed Neural Networks with JAX. **jinns** has been developed to estimate solutions to your ODE and PDE problems using neural networks.
30
30
  **jinns** is built on JAX.
31
31
 
32
+ **jinns** specific points:
33
+
34
+ - **jinns** is coded with JAX as a backend: forward and backward autodiff, vmapping, jitting and more!
35
+
36
- We focus the development towards inverse problems and inference in mechanistic-statistical models
37
+
38
- [Separable PINNs](https://openreview.net/pdf?id=dEySGIcDnI) are implemented
39
+
40
+ - Check out our various notebooks to get started with `jinns`
41
+
42
+ For more information, open an issue or contact us!
43
+
32
44
  # Installation
33
45
 
34
46
  Install the latest version with pip
@@ -41,6 +53,8 @@ pip install jinns
41
53
 
42
54
  The project's documentation is available at [https://mia_jinns.gitlab.io/jinns/index.html](https://mia_jinns.gitlab.io/jinns/index.html)
43
55
 
56
+ Note that all the tests were performed on a rather small Nvidia T600 GPU; expect a substantial performance gain on bigger devices.
57
+
44
58
  # Contributing
45
59
 
46
60
  * First fork the library on Gitlab.
@@ -56,7 +70,6 @@ pip install -e .
56
70
  ```bash
57
71
  pip install pre-commit
58
72
  pre-commit install
59
- pre-commit install --hook-type pre-push
60
73
  ```
61
74
 
62
75
  * Open a merge request once you are done with your changes.
@@ -0,0 +1,24 @@
1
+ jinns/__init__.py,sha256=Nw5pdlmDhJwco3bXX3YttkeCF8czX_6m0poh8vu0lDQ,113
2
+ jinns/data/_DataGenerators.py,sha256=nIuKtkX4V4ckfT4-g0bjlY7BLkgcok5JbI9OzJn73mA,44461
3
+ jinns/data/__init__.py,sha256=S13J59Fxuph4uNJ542fP_Mj8U72ilhb5t_UQ-c1k3nY,232
4
+ jinns/data/_display.py,sha256=NfINLJAGmQSPz30cWVaeQFpabzCXprp4RNH6Iycx-VU,7722
5
+ jinns/loss/_DynamicLoss.py,sha256=AuW71-Kyt7tj_aatxroym9BcvqzTfsBI7qMdEF3021w,36550
6
+ jinns/loss/_DynamicLossAbstract.py,sha256=V9NJvMmqnSC06yccu9bkFCWLF3cjyr8ze8qf0LRykjo,7718
7
+ jinns/loss/_LossODE.py,sha256=FHTKQPqLSoMh18j_RYUoR7tUGg_ljd0JKL8xf2HiF5M,17541
8
+ jinns/loss/_LossPDE.py,sha256=Pegyr4jZsI9hZMUGjRjohYeVVKLuSU53rc-epD0KZGI,64328
9
+ jinns/loss/__init__.py,sha256=p9pXrttAgsx_V25Yt8lT8Guoc3z9OAfRRvzl9OIc3Gs,402
10
+ jinns/loss/_boundary_conditions.py,sha256=OE7lDwCBKUJaXkOP5uJZyakT1YWb1D7YogbvFT0RaJo,16633
11
+ jinns/loss/_operators.py,sha256=l9nrLGluD4Y-0L-yydYS3--mH-nKXLRymhm3JCapaZ0,12067
12
+ jinns/solver/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
13
+ jinns/solver/_rar.py,sha256=V9y07F6objmP6rPA305dIJ82h7kwP4AWBAktZ68b-38,13894
14
+ jinns/solver/_seq2seq.py,sha256=XNL9e0fBj85Q86XfGDzq9dzxIkPPMwoJF38C8doNYtM,6032
15
+ jinns/solver/_solve.py,sha256=Yz_asD0ZuYN923E6ysxtPemN-36Zt-BtlSPkTX-BfA8,10047
16
+ jinns/utils/__init__.py,sha256=ClGy9Ppye1z75daH_3Ngb0406K-AJcZHgj4gM5vfX_8,196
17
+ jinns/utils/_pinn.py,sha256=-UmdSRMzkqe_TEsuVHCtu8b_9Na3wz-5YDoKSAeF8ZM,11307
18
+ jinns/utils/_spinn.py,sha256=SYz8eatM_AkIPYhNf2sOKoGGhtAWKZX47W8xDXR0aK0,8027
19
+ jinns/utils/_utils.py,sha256=zsSPoqQi9OQKTRiCRfeINCwCfzo7MHFCc4zPCIU2UpY,8790
20
+ jinns-0.5.0.dist-info/LICENSE,sha256=BIAkGtXB59Q_BG8f6_OqtQ1BHPv60ggE9mpXJYz2dRM,11337
21
+ jinns-0.5.0.dist-info/METADATA,sha256=C8fZOd7PqroIKIwS80YUIvwbaLbNVuYPxGCv0JG4PkM,2338
22
+ jinns-0.5.0.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
23
+ jinns-0.5.0.dist-info/top_level.txt,sha256=RXbkr2hzy8WBE8aiRyrJYFqn3JeMJIhMdybLjjLTB9c,6
24
+ jinns-0.5.0.dist-info/RECORD,,
@@ -1,22 +0,0 @@
1
- jinns/__init__.py,sha256=Nw5pdlmDhJwco3bXX3YttkeCF8czX_6m0poh8vu0lDQ,113
2
- jinns/data/_DataGenerators.py,sha256=nIuKtkX4V4ckfT4-g0bjlY7BLkgcok5JbI9OzJn73mA,44461
3
- jinns/data/__init__.py,sha256=S13J59Fxuph4uNJ542fP_Mj8U72ilhb5t_UQ-c1k3nY,232
4
- jinns/data/_display.py,sha256=Xnfo6_PH1g-ZFpWJcbF6CF6Pp12wJtNQb1W1bADuQrA,6134
5
- jinns/loss/_DynamicLoss.py,sha256=VyoyWdkoxRPeP2vs4ZZBK_T9xWgwkcDuaFrjUSid3Zo,52975
6
- jinns/loss/_DynamicLossAbstract.py,sha256=V9NJvMmqnSC06yccu9bkFCWLF3cjyr8ze8qf0LRykjo,7718
7
- jinns/loss/_LossODE.py,sha256=FHTKQPqLSoMh18j_RYUoR7tUGg_ljd0JKL8xf2HiF5M,17541
8
- jinns/loss/_LossPDE.py,sha256=8wEEPnkibNOdqCXtcuU5nieRsPoffSyV8wcUMjcUkvg,57145
9
- jinns/loss/__init__.py,sha256=4JxMHHVMxTMsVZmV8mRSIyMVAEp3QIR8QtZvnwvj96Q,560
10
- jinns/loss/_boundary_conditions.py,sha256=qe1fHnMduVM1h89feHOc-O2--Zgd4ZYM4XjqHYUjuvw,9970
11
- jinns/loss/_operators.py,sha256=HYGDq3K_K7lztDp6al88Q78F6o-hDRo3ODhw8UMRdC8,5690
12
- jinns/solver/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
13
- jinns/solver/_rar.py,sha256=V9y07F6objmP6rPA305dIJ82h7kwP4AWBAktZ68b-38,13894
14
- jinns/solver/_seq2seq.py,sha256=XNL9e0fBj85Q86XfGDzq9dzxIkPPMwoJF38C8doNYtM,6032
15
- jinns/solver/_solve.py,sha256=Yz_asD0ZuYN923E6ysxtPemN-36Zt-BtlSPkTX-BfA8,10047
16
- jinns/utils/__init__.py,sha256=-jDlwCjyEzWweswKdwLal3OhaUU3FVzK_Ge2S-7KHXs,149
17
- jinns/utils/_utils.py,sha256=bQm6z_xPKJj9BMCr2tXc44IA8JyGNN-PR5LNRhZ1fD8,20085
18
- jinns-0.4.2.dist-info/LICENSE,sha256=BIAkGtXB59Q_BG8f6_OqtQ1BHPv60ggE9mpXJYz2dRM,11337
19
- jinns-0.4.2.dist-info/METADATA,sha256=maaTIojnCdHhTIPd2kCJa8QGfX36mWAt1DN7q8Pd_3o,1821
20
- jinns-0.4.2.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
21
- jinns-0.4.2.dist-info/top_level.txt,sha256=RXbkr2hzy8WBE8aiRyrJYFqn3JeMJIhMdybLjjLTB9c,6
22
- jinns-0.4.2.dist-info/RECORD,,
File without changes
File without changes