jinns 0.8.4__py3-none-any.whl → 0.8.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
jinns/solver/_solve.py CHANGED
@@ -218,6 +218,7 @@ def solve(
218
218
  batch, data, param_data, obs_data = get_batch(data, param_data, obs_data)
219
219
 
220
220
  (
221
+ loss,
221
222
  loss_val,
222
223
  loss_terms,
223
224
  params,
@@ -311,7 +312,7 @@ def solve(
311
312
  )
312
313
 
313
314
 
314
- @partial(jit, static_argnames=["loss", "optimizer"])
315
+ @partial(jit, static_argnames=["optimizer"])
315
316
  def gradient_step(loss, optimizer, batch, params, opt_state, last_non_nan_params):
316
317
  """
317
318
  loss and optimizer cannot be jit-ted.
@@ -330,6 +331,7 @@ def gradient_step(loss, optimizer, batch, params, opt_state, last_non_nan_params
330
331
  )
331
332
 
332
333
  return (
334
+ loss,
333
335
  loss_val,
334
336
  loss_terms,
335
337
  params,
jinns/utils/_hyperpinn.py CHANGED
@@ -216,13 +216,7 @@ def create_HYPERPINN(
216
216
 
217
217
  Returns
218
218
  -------
219
- init_fn
220
- A function which (re-)initializes the PINN parameters with the provided
221
- jax random key
222
- apply_fn
223
- A function to apply the neural network on given inputs for given
224
- parameters. A typical call will be of the form `u(t, params)` for
225
- ODE or `u(t, x, params)` for nD PDEs (`x` being multidimensional)
219
+ `u`, a :class:`.HyperPINN` object which inherits from `eqx.Module` (hence callable).
226
220
 
227
221
  Raises
228
222
  ------
@@ -289,7 +283,6 @@ def create_HYPERPINN(
289
283
 
290
284
  if shared_pinn_outputs is not None:
291
285
  hyperpinns = []
292
- static = None
293
286
  for output_slice in shared_pinn_outputs:
294
287
  hyperpinn = HYPERPINN(
295
288
  mlp,
@@ -302,11 +295,6 @@ def create_HYPERPINN(
302
295
  hypernet_input_size,
303
296
  output_slice,
304
297
  )
305
- # all the pinns are in fact the same so we share the same static
306
- if static is None:
307
- static = hyperpinn.static
308
- else:
309
- hyperpinn.static = static
310
298
  hyperpinns.append(hyperpinn)
311
299
  return hyperpinns
312
300
  hyperpinn = HYPERPINN(
jinns/utils/_pinn.py CHANGED
@@ -200,13 +200,7 @@ def create_PINN(
200
200
 
201
201
  Returns
202
202
  -------
203
- init_fn
204
- A function which (re-)initializes the PINN parameters with the provided
205
- jax random key
206
- apply_fn
207
- A function to apply the neural network on given inputs for given
208
- parameters. A typical call will be of the form `u(t, params)` for
209
- ODE or `u(t, x, params)` for nD PDEs (`x` being multidimensional)
203
+ `u`, a :class:`.PINN` object which inherits from `eqx.Module` (hence callable). This comes with a bound method :func:`u.init_params() <PINN.init_params>`. When `shared_pinn_outputs` is not None, a list of :class:`.PINN` with the same structure is returned, only differing by their final slicing of the network output.
210
204
 
211
205
  Raises
212
206
  ------
@@ -253,7 +247,6 @@ def create_PINN(
253
247
 
254
248
  if shared_pinn_outputs is not None:
255
249
  pinns = []
256
- static = None
257
250
  for output_slice in shared_pinn_outputs:
258
251
  pinn = PINN(
259
252
  mlp,
@@ -263,11 +256,6 @@ def create_PINN(
263
256
  output_transform,
264
257
  output_slice,
265
258
  )
266
- # all the pinns are in fact the same so we share the same static
267
- if static is None:
268
- static = pinn.static
269
- else:
270
- pinn.static = static
271
259
  pinns.append(pinn)
272
260
  return pinns
273
261
  pinn = PINN(mlp, slice_solution, eq_type, input_transform, output_transform, None)
jinns/utils/_spinn.py CHANGED
@@ -194,13 +194,7 @@ def create_SPINN(key, d, r, eqx_list, eq_type, m=1):
194
194
 
195
195
  Returns
196
196
  -------
197
- init_fn
198
- A function which (re-)initializes the SPINN parameters with the provided
199
- jax random key
200
- apply_fn
201
- A function to apply the neural network on given inputs for given
202
- parameters. A typical call will be of the form `u(t, params)` for
203
- ODE or `u(t, x, params)` for nD PDEs (`x` being multidimensional)
197
+ `u`, a :class:`.SPINN` object which inherits from `eqx.Module` (hence callable).
204
198
 
205
199
  Raises
206
200
  ------
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: jinns
3
- Version: 0.8.4
3
+ Version: 0.8.6
4
4
  Summary: Physics Informed Neural Network with JAX
5
5
  Author-email: Hugo Gangloff <hugo.gangloff@inrae.fr>, Nicolas Jouvin <nicolas.jouvin@inrae.fr>
6
6
  Maintainer-email: Hugo Gangloff <hugo.gangloff@inrae.fr>, Nicolas Jouvin <nicolas.jouvin@inrae.fr>
@@ -15,17 +15,17 @@ jinns/loss/_operators.py,sha256=zDGJqYqeYH7xd-4dtGX9PS-pf0uSOpUUXGo5SVjIJ4o,1106
15
15
  jinns/solver/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
16
16
  jinns/solver/_rar.py,sha256=K-0y1-ofOAo1n_Ea3QShSGCGKVYTwiaE_Bz9-DZMJm8,14525
17
17
  jinns/solver/_seq2seq.py,sha256=FL-42hTgmVl7O3hHh1ccFVw2bT8bW82hvlDRz971Chk,5620
18
- jinns/solver/_solve.py,sha256=6kWFWpJ33uOUzZKn7gIOM7yQsVZUwSuorOWPojVeMQY,13721
18
+ jinns/solver/_solve.py,sha256=r4jn6hx7_t-Y2rBWA2npUmWWnDg4iRbgYBHZDNn9tmY,13745
19
19
  jinns/utils/__init__.py,sha256=44ms5UR6vMw3Nf6u4RCAzPFs4fom_YbBnH9mfne8m6k,313
20
- jinns/utils/_hyperpinn.py,sha256=Mb5d6auzFfXcA81WgjiuhDBvAypAzVOENj_gUeqz6gI,11370
20
+ jinns/utils/_hyperpinn.py,sha256=93hbiATdp5W4l1cu9Oe6O2c45o-ZF_z2u6FzNLyjnm4,10878
21
21
  jinns/utils/_optim.py,sha256=550kxH75TL30o1iKx1swJyP0KqyUPsJ7-imL1w65Qd0,4444
22
- jinns/utils/_pinn.py,sha256=N8LuB9Ql472O01USghkJkEOmx67DTjc279T8Lj-Lwd4,9722
22
+ jinns/utils/_pinn.py,sha256=mhA4-3PazyQTbWIx9oLaNwL0QDe8ZIBhbiy5J3kwa4I,9471
23
23
  jinns/utils/_save_load.py,sha256=qgZ23nUcB8-B5IZ2guuUWC4M7r5Lxd_Ms3staScdyJo,5668
24
- jinns/utils/_spinn.py,sha256=aeIC3DBY7f_N8HABjvBNv375dMyjll3zt6KjY2bEIkM,8058
24
+ jinns/utils/_spinn.py,sha256=SzOUt1KHtB9QOpghpvitnXN-KEqXUXbvabC5k0TnKEo,7793
25
25
  jinns/utils/_utils.py,sha256=8dgvWXX9NT7_7-zltWp0C9tG45ZFNwXxueyxPBb4hjo,6740
26
26
  jinns/utils/_utils_uspinn.py,sha256=qcKcOw3zrwWSQyGVj6fD8c9GinHt_U6JWN_k0auTtXM,26039
27
- jinns-0.8.4.dist-info/LICENSE,sha256=BIAkGtXB59Q_BG8f6_OqtQ1BHPv60ggE9mpXJYz2dRM,11337
28
- jinns-0.8.4.dist-info/METADATA,sha256=QAq8dRIxqTZaBMb0YVOymae5X2kO5XqeHJXJHfe0380,2482
29
- jinns-0.8.4.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
30
- jinns-0.8.4.dist-info/top_level.txt,sha256=RXbkr2hzy8WBE8aiRyrJYFqn3JeMJIhMdybLjjLTB9c,6
31
- jinns-0.8.4.dist-info/RECORD,,
27
+ jinns-0.8.6.dist-info/LICENSE,sha256=BIAkGtXB59Q_BG8f6_OqtQ1BHPv60ggE9mpXJYz2dRM,11337
28
+ jinns-0.8.6.dist-info/METADATA,sha256=3Ml6PCA-569v9-1FgyPDySX09RQas0zPOVEV_gqy9lk,2482
29
+ jinns-0.8.6.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
30
+ jinns-0.8.6.dist-info/top_level.txt,sha256=RXbkr2hzy8WBE8aiRyrJYFqn3JeMJIhMdybLjjLTB9c,6
31
+ jinns-0.8.6.dist-info/RECORD,,
File without changes
File without changes