jinns 0.8.4__py3-none-any.whl → 0.8.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jinns/solver/_solve.py +3 -1
- jinns/utils/_hyperpinn.py +1 -13
- jinns/utils/_pinn.py +1 -13
- jinns/utils/_spinn.py +1 -7
- {jinns-0.8.4.dist-info → jinns-0.8.6.dist-info}/METADATA +1 -1
- {jinns-0.8.4.dist-info → jinns-0.8.6.dist-info}/RECORD +9 -9
- {jinns-0.8.4.dist-info → jinns-0.8.6.dist-info}/LICENSE +0 -0
- {jinns-0.8.4.dist-info → jinns-0.8.6.dist-info}/WHEEL +0 -0
- {jinns-0.8.4.dist-info → jinns-0.8.6.dist-info}/top_level.txt +0 -0
jinns/solver/_solve.py
CHANGED
|
@@ -218,6 +218,7 @@ def solve(
|
|
|
218
218
|
batch, data, param_data, obs_data = get_batch(data, param_data, obs_data)
|
|
219
219
|
|
|
220
220
|
(
|
|
221
|
+
loss,
|
|
221
222
|
loss_val,
|
|
222
223
|
loss_terms,
|
|
223
224
|
params,
|
|
@@ -311,7 +312,7 @@ def solve(
|
|
|
311
312
|
)
|
|
312
313
|
|
|
313
314
|
|
|
314
|
-
@partial(jit, static_argnames=["
|
|
315
|
+
@partial(jit, static_argnames=["optimizer"])
|
|
315
316
|
def gradient_step(loss, optimizer, batch, params, opt_state, last_non_nan_params):
|
|
316
317
|
"""
|
|
317
318
|
loss and optimizer cannot be jit-ted.
|
|
@@ -330,6 +331,7 @@ def gradient_step(loss, optimizer, batch, params, opt_state, last_non_nan_params
|
|
|
330
331
|
)
|
|
331
332
|
|
|
332
333
|
return (
|
|
334
|
+
loss,
|
|
333
335
|
loss_val,
|
|
334
336
|
loss_terms,
|
|
335
337
|
params,
|
jinns/utils/_hyperpinn.py
CHANGED
|
@@ -216,13 +216,7 @@ def create_HYPERPINN(
|
|
|
216
216
|
|
|
217
217
|
Returns
|
|
218
218
|
-------
|
|
219
|
-
|
|
220
|
-
A function which (re-)initializes the PINN parameters with the provided
|
|
221
|
-
jax random key
|
|
222
|
-
apply_fn
|
|
223
|
-
A function to apply the neural network on given inputs for given
|
|
224
|
-
parameters. A typical call will be of the form `u(t, params)` for
|
|
225
|
-
ODE or `u(t, x, params)` for nD PDEs (`x` being multidimensional)
|
|
219
|
+
`u`, a :class:`.HyperPINN` object which inherits from `eqx.Module` (hence callable).
|
|
226
220
|
|
|
227
221
|
Raises
|
|
228
222
|
------
|
|
@@ -289,7 +283,6 @@ def create_HYPERPINN(
|
|
|
289
283
|
|
|
290
284
|
if shared_pinn_outputs is not None:
|
|
291
285
|
hyperpinns = []
|
|
292
|
-
static = None
|
|
293
286
|
for output_slice in shared_pinn_outputs:
|
|
294
287
|
hyperpinn = HYPERPINN(
|
|
295
288
|
mlp,
|
|
@@ -302,11 +295,6 @@ def create_HYPERPINN(
|
|
|
302
295
|
hypernet_input_size,
|
|
303
296
|
output_slice,
|
|
304
297
|
)
|
|
305
|
-
# all the pinns are in fact the same so we share the same static
|
|
306
|
-
if static is None:
|
|
307
|
-
static = hyperpinn.static
|
|
308
|
-
else:
|
|
309
|
-
hyperpinn.static = static
|
|
310
298
|
hyperpinns.append(hyperpinn)
|
|
311
299
|
return hyperpinns
|
|
312
300
|
hyperpinn = HYPERPINN(
|
jinns/utils/_pinn.py
CHANGED
|
@@ -200,13 +200,7 @@ def create_PINN(
|
|
|
200
200
|
|
|
201
201
|
Returns
|
|
202
202
|
-------
|
|
203
|
-
|
|
204
|
-
A function which (re-)initializes the PINN parameters with the provided
|
|
205
|
-
jax random key
|
|
206
|
-
apply_fn
|
|
207
|
-
A function to apply the neural network on given inputs for given
|
|
208
|
-
parameters. A typical call will be of the form `u(t, params)` for
|
|
209
|
-
ODE or `u(t, x, params)` for nD PDEs (`x` being multidimensional)
|
|
203
|
+
`u`, a :class:`.PINN` object which inherits from `eqx.Module` (hence callable). This comes with a bound method :func:`u.init_params() <PINN.init_params>`. When `shared_pinn_outputs` is not None, a list of :class:`.PINN` with the same structure is returned, only differing by their final slicing of the network output.
|
|
210
204
|
|
|
211
205
|
Raises
|
|
212
206
|
------
|
|
@@ -253,7 +247,6 @@ def create_PINN(
|
|
|
253
247
|
|
|
254
248
|
if shared_pinn_outputs is not None:
|
|
255
249
|
pinns = []
|
|
256
|
-
static = None
|
|
257
250
|
for output_slice in shared_pinn_outputs:
|
|
258
251
|
pinn = PINN(
|
|
259
252
|
mlp,
|
|
@@ -263,11 +256,6 @@ def create_PINN(
|
|
|
263
256
|
output_transform,
|
|
264
257
|
output_slice,
|
|
265
258
|
)
|
|
266
|
-
# all the pinns are in fact the same so we share the same static
|
|
267
|
-
if static is None:
|
|
268
|
-
static = pinn.static
|
|
269
|
-
else:
|
|
270
|
-
pinn.static = static
|
|
271
259
|
pinns.append(pinn)
|
|
272
260
|
return pinns
|
|
273
261
|
pinn = PINN(mlp, slice_solution, eq_type, input_transform, output_transform, None)
|
jinns/utils/_spinn.py
CHANGED
|
@@ -194,13 +194,7 @@ def create_SPINN(key, d, r, eqx_list, eq_type, m=1):
|
|
|
194
194
|
|
|
195
195
|
Returns
|
|
196
196
|
-------
|
|
197
|
-
|
|
198
|
-
A function which (re-)initializes the SPINN parameters with the provided
|
|
199
|
-
jax random key
|
|
200
|
-
apply_fn
|
|
201
|
-
A function to apply the neural network on given inputs for given
|
|
202
|
-
parameters. A typical call will be of the form `u(t, params)` for
|
|
203
|
-
ODE or `u(t, x, params)` for nD PDEs (`x` being multidimensional)
|
|
197
|
+
`u`, a :class:`.SPINN` object which inherits from `eqx.Module` (hence callable).
|
|
204
198
|
|
|
205
199
|
Raises
|
|
206
200
|
------
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: jinns
|
|
3
|
-
Version: 0.8.4
|
|
3
|
+
Version: 0.8.6
|
|
4
4
|
Summary: Physics Informed Neural Network with JAX
|
|
5
5
|
Author-email: Hugo Gangloff <hugo.gangloff@inrae.fr>, Nicolas Jouvin <nicolas.jouvin@inrae.fr>
|
|
6
6
|
Maintainer-email: Hugo Gangloff <hugo.gangloff@inrae.fr>, Nicolas Jouvin <nicolas.jouvin@inrae.fr>
|
|
@@ -15,17 +15,17 @@ jinns/loss/_operators.py,sha256=zDGJqYqeYH7xd-4dtGX9PS-pf0uSOpUUXGo5SVjIJ4o,1106
|
|
|
15
15
|
jinns/solver/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
16
16
|
jinns/solver/_rar.py,sha256=K-0y1-ofOAo1n_Ea3QShSGCGKVYTwiaE_Bz9-DZMJm8,14525
|
|
17
17
|
jinns/solver/_seq2seq.py,sha256=FL-42hTgmVl7O3hHh1ccFVw2bT8bW82hvlDRz971Chk,5620
|
|
18
|
-
jinns/solver/_solve.py,sha256=
|
|
18
|
+
jinns/solver/_solve.py,sha256=r4jn6hx7_t-Y2rBWA2npUmWWnDg4iRbgYBHZDNn9tmY,13745
|
|
19
19
|
jinns/utils/__init__.py,sha256=44ms5UR6vMw3Nf6u4RCAzPFs4fom_YbBnH9mfne8m6k,313
|
|
20
|
-
jinns/utils/_hyperpinn.py,sha256=
|
|
20
|
+
jinns/utils/_hyperpinn.py,sha256=93hbiATdp5W4l1cu9Oe6O2c45o-ZF_z2u6FzNLyjnm4,10878
|
|
21
21
|
jinns/utils/_optim.py,sha256=550kxH75TL30o1iKx1swJyP0KqyUPsJ7-imL1w65Qd0,4444
|
|
22
|
-
jinns/utils/_pinn.py,sha256=
|
|
22
|
+
jinns/utils/_pinn.py,sha256=mhA4-3PazyQTbWIx9oLaNwL0QDe8ZIBhbiy5J3kwa4I,9471
|
|
23
23
|
jinns/utils/_save_load.py,sha256=qgZ23nUcB8-B5IZ2guuUWC4M7r5Lxd_Ms3staScdyJo,5668
|
|
24
|
-
jinns/utils/_spinn.py,sha256=
|
|
24
|
+
jinns/utils/_spinn.py,sha256=SzOUt1KHtB9QOpghpvitnXN-KEqXUXbvabC5k0TnKEo,7793
|
|
25
25
|
jinns/utils/_utils.py,sha256=8dgvWXX9NT7_7-zltWp0C9tG45ZFNwXxueyxPBb4hjo,6740
|
|
26
26
|
jinns/utils/_utils_uspinn.py,sha256=qcKcOw3zrwWSQyGVj6fD8c9GinHt_U6JWN_k0auTtXM,26039
|
|
27
|
-
jinns-0.8.
|
|
28
|
-
jinns-0.8.
|
|
29
|
-
jinns-0.8.
|
|
30
|
-
jinns-0.8.
|
|
31
|
-
jinns-0.8.
|
|
27
|
+
jinns-0.8.6.dist-info/LICENSE,sha256=BIAkGtXB59Q_BG8f6_OqtQ1BHPv60ggE9mpXJYz2dRM,11337
|
|
28
|
+
jinns-0.8.6.dist-info/METADATA,sha256=3Ml6PCA-569v9-1FgyPDySX09RQas0zPOVEV_gqy9lk,2482
|
|
29
|
+
jinns-0.8.6.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
|
|
30
|
+
jinns-0.8.6.dist-info/top_level.txt,sha256=RXbkr2hzy8WBE8aiRyrJYFqn3JeMJIhMdybLjjLTB9c,6
|
|
31
|
+
jinns-0.8.6.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|