jinns 1.1.0-py3-none-any.whl → 1.3.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
jinns/solver/_solve.py CHANGED
@@ -7,6 +7,7 @@ from __future__ import (
     annotations,
 )  # https://docs.python.org/3/library/typing.html#constant
 
+import time
 from typing import TYPE_CHECKING, NamedTuple, Dict, Union
 from functools import partial
 import optax
@@ -16,6 +17,7 @@ import jax.numpy as jnp
 from jaxtyping import Int, Bool, Float, Array
 from jinns.solver._rar import init_rar, trigger_rar
 from jinns.utils._utils import _check_nan_in_pytree
+from jinns.solver._utils import _check_batch_size
 from jinns.utils._containers import *
 from jinns.data._DataGenerators import (
     DataGeneratorODE,
@@ -29,31 +31,6 @@ if TYPE_CHECKING:
     from jinns.utils._types import *
 
 
-def _check_batch_size(other_data, main_data, attr_name):
-    if (
-        (
-            isinstance(main_data, DataGeneratorODE)
-            and getattr(other_data, attr_name) != main_data.temporal_batch_size
-        )
-        or (
-            isinstance(main_data, CubicMeshPDEStatio)
-            and not isinstance(main_data, CubicMeshPDENonStatio)
-            and getattr(other_data, attr_name) != main_data.omega_batch_size
-        )
-        or (
-            isinstance(main_data, CubicMeshPDENonStatio)
-            and getattr(other_data, attr_name)
-            != main_data.omega_batch_size * main_data.temporal_batch_size
-        )
-    ):
-        raise ValueError(
-            "Optional other_data.param_batch_size must be"
-            " equal to main_data.temporal_batch_size or main_data.omega_batch_size or"
-            " the product of both dependeing on the type of the main"
-            " datagenerator"
-        )
-
-
 def solve(
     n_iter: Int,
     init_params: AnyParams,
@@ -70,6 +47,7 @@ def solve(
     validation: AbstractValidationModule | None = None,
     obs_batch_sharding: jax.sharding.Sharding | None = None,
     verbose: Bool = True,
+    ahead_of_time: Bool = True,
 ) -> tuple[
     Params | ParamsDict,
     Float[Array, "n_iter"],
@@ -141,6 +119,14 @@ def solve(
     verbose
         Default True. If False, no std output (loss or cause of
         exiting the optimization loop) will be produced.
+    ahead_of_time
+        Default True. Separate the compilation of the main training loop from
+        its execution in order to report both timings. You may need to disable
+        this behaviour if you perform JAX transforms over chunks of code
+        containing `jinns.solve()`, since AOT-compiled functions cannot be JAX
+        transformed (see https://jax.readthedocs.io/en/latest/aot.html#aot-compiled-functions-cannot-be-transformed).
+        When False, jinns does not provide any timing information (which would
+        be meaningless inside a JIT-transformed `solve()`).
 
     Returns
     -------
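A minimal usage sketch of the new flag. This assumes `init_params`, `data`, `optimizer` and `loss` are built beforehand as in the jinns tutorials; only `ahead_of_time` comes from this diff, and the exact keyword names of the other arguments follow `solve()`'s signature above.

```python
import jax
import jinns

# Assumed setup (not shown): init_params, data, optimizer, loss as in the tutorials.
params, *rest = jinns.solve(
    n_iter=1000,
    init_params=init_params,
    data=data,
    optimizer=optimizer,
    loss=loss,
    ahead_of_time=True,  # default: compilation and training times reported separately
)

# When jinns.solve() itself must be wrapped in a JAX transform (jit, vmap, ...),
# pass ahead_of_time=False, since AOT-compiled functions cannot be transformed.
```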
@@ -167,10 +153,22 @@ def solve(
         The best parameters according to the validation criterion
     """
     if param_data is not None:
-        _check_batch_size(param_data, data, "param_batch_size")
-
-    if obs_data is not None:
-        _check_batch_size(obs_data, data, "obs_batch_size")
+        if param_data.param_batch_size is not None:
+            # We need to check that batch sizes will all be compliant for
+            # correct vectorization
+            _check_batch_size(param_data, data, "param_batch_size")
+        else:
+            # If DataGeneratorParameter does not have a batch size, we will
+            # vectorize using `n`, and the same checks must be done
+            _check_batch_size(param_data, data, "n")
+
+    if obs_data is not None and param_data is not None:
+        # obs_data batch dimensions only need to be aligned with param_data
+        # batch dimensions if the latter exist
+        if obs_data.obs_batch_size is not None:
+            _check_batch_size(obs_data, param_data, "obs_batch_size")
+        else:
+            _check_batch_size(obs_data, param_data, "n")
 
     if opt_state is None:
         opt_state = optimizer.init(init_params)
@@ -224,6 +222,8 @@ def solve(
     )
     optimization_extra = OptimizationExtraContainer(
         curr_seq=curr_seq,
+        best_iter_id=0,
+        best_val_criterion=jnp.nan,
         best_val_params=init_params,
     )
     loss_container = LossContainer(
@@ -323,16 +323,26 @@ def solve(
                 validation_criterion
             )
 
-            # update best_val_params w.r.t val_loss if needed
-            best_val_params = jax.lax.cond(
+            # update best_val_params and best_val_criterion w.r.t. val_loss if needed
+            (best_val_params, best_val_criterion, best_iter_id) = jax.lax.cond(
                 update_best_params,
-                lambda _: params,  # update with current value
-                lambda operands: operands[0].best_val_params,  # unchanged
+                lambda operands: (
+                    params,
+                    validation_criterion,
+                    i,
+                ),  # update with current values
+                lambda operands: (
+                    operands[0].best_val_params,
+                    operands[0].best_val_criterion,
+                    operands[0].best_iter_id,
+                ),  # unchanged
                 (optimization_extra,),
             )
         else:
             early_stopping = False
+            best_iter_id = 0
             best_val_params = params
+            best_val_criterion = jnp.nan
 
         # Trigger RAR
         loss, params, data = trigger_rar(
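The change above relies on the fact that both branches of `jax.lax.cond` must return pytrees with identical structure and dtypes, which is why each branch now returns a three-element tuple. A self-contained sketch of the pattern (illustrative names, not jinns API):

```python
import jax
import jax.numpy as jnp

def track_best(update, current, best):
    # Both branches must return the same pytree structure and dtypes.
    return jax.lax.cond(
        update,
        lambda operands: current,      # update with the current value
        lambda operands: operands[0],  # keep the previous best
        (best,),
    )

best = (jnp.array(1.0), jnp.array(0))  # (criterion, iteration)
print(track_best(jnp.array(True), (jnp.array(0.5), jnp.array(3)), best))
```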
@@ -358,7 +368,13 @@ def solve(
         i,
         loss,
         OptimizationContainer(params, last_non_nan_params, opt_state),
-        OptimizationExtraContainer(curr_seq, best_val_params, early_stopping),
+        OptimizationExtraContainer(
+            curr_seq,
+            best_iter_id,
+            best_val_criterion,
+            best_val_params,
+            early_stopping,
+        ),
         DataGeneratorContainer(data, param_data, obs_data),
         validation,
         LossContainer(stored_loss_terms, train_loss_values),
@@ -373,7 +389,25 @@ def solve(
         while break_fun(carry):
             carry = _one_iteration(carry)
     else:
-        carry = jax.lax.while_loop(break_fun, _one_iteration, carry)
+
+        def train_fun(carry):
+            return jax.lax.while_loop(break_fun, _one_iteration, carry)
+
+        if ahead_of_time:
+            start = time.time()
+            compiled_train_fun = jax.jit(train_fun).lower(carry).compile()
+            end = time.time()
+            if verbose:
+                print("\nCompilation took\n", end - start, "\n")
+
+            start = time.time()
+            carry = compiled_train_fun(carry)
+            jax.block_until_ready(carry)
+            end = time.time()
+            if verbose:
+                print("\nTraining took\n", end - start, "\n")
+        else:
+            carry = train_fun(carry)
 
     (
         i,
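`jax.jit(...).lower(...).compile()` is JAX's ahead-of-time compilation API: tracing, lowering and compilation all happen eagerly, so compile time can be measured separately from pure execution time. A self-contained sketch of the same timing pattern:

```python
import time
import jax
import jax.numpy as jnp

def step(x):
    return jnp.tanh(x @ x.T).sum()

x = jnp.ones((512, 512))

t0 = time.time()
compiled = jax.jit(step).lower(x).compile()  # AOT: all compilation happens here
t1 = time.time()
out = jax.block_until_ready(compiled(x))     # pure execution, no compile cost
t2 = time.time()
print(f"compile: {t1 - t0:.3f}s, run: {t2 - t1:.3f}s")
```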
@@ -389,15 +423,30 @@ def solve(
 
     if verbose:
         jax.debug.print(
-            "Final iteration {i}: train loss value = {train_loss_val}",
+            "\nFinal iteration {i}: train loss value = {train_loss_val}",
             i=i,
             train_loss_val=loss_container.train_loss_values[i - 1],
         )
+
+    # get ready to return the parameters of the last iteration...
+    # (an arbitrary default choice; this could be None)
+    validation_parameters = optimization.last_non_nan_params
     if validation is not None:
         jax.debug.print(
             "validation loss value = {validation_loss_val}",
             validation_loss_val=validation_crit_values[i - 1],
         )
+        if optimization_extra.early_stopping:
+            jax.debug.print(
+                "\n Returning a set of best parameters from early stopping"
+                " as last argument!\n"
+                " Best parameters from iteration {best_iter_id}"
+                " with validation loss criterion = {best_val_criterion}",
+                best_iter_id=optimization_extra.best_iter_id,
+                best_val_criterion=optimization_extra.best_val_criterion,
+            )
+        # ...but with early stopping, return the parameters from best_iter_id
+        validation_parameters = optimization_extra.best_val_params
 
     return (
         optimization.last_non_nan_params,
@@ -408,7 +457,7 @@ def solve(
         optimization.opt_state,
         stored_objects.stored_params,
         validation_crit_values if validation is not None else None,
-        optimization_extra.best_val_params if validation is not None else None,
+        validation_parameters,
     )
 
 
@@ -531,7 +580,7 @@ def _get_break_fun(n_iter: Int, verbose: Bool) -> Callable[[main_carry], Bool]:
         string is not a valid JAX type that can be fed into the operands
         """
         if verbose:
-            jax.debug.print(f"Stopping main optimization loop, cause: {msg}")
+            jax.debug.print(f"\nStopping main optimization loop, cause: {msg}")
         return False
 
     def continue_while_loop(_):
jinns/solver/_utils.py ADDED
@@ -0,0 +1,122 @@
+from jinns.data._DataGenerators import (
+    DataGeneratorODE,
+    CubicMeshPDEStatio,
+    CubicMeshPDENonStatio,
+    DataGeneratorParameter,
+)
+
+
+def _check_batch_size(other_data, main_data, attr_name):
+    if isinstance(main_data, DataGeneratorODE):
+        if main_data.temporal_batch_size is not None:
+            if getattr(other_data, attr_name) != main_data.temporal_batch_size:
+                raise ValueError(
+                    f"{other_data.__class__}.{attr_name} must be equal"
+                    f" to {main_data.__class__}.temporal_batch_size for correct"
+                    " vectorization"
+                )
+        else:
+            if main_data.nt is not None:
+                if getattr(other_data, attr_name) != main_data.nt:
+                    raise ValueError(
+                        f"{other_data.__class__}.{attr_name} must be equal"
+                        f" to {main_data.__class__}.nt for correct"
+                        " vectorization"
+                    )
+    if isinstance(main_data, CubicMeshPDEStatio) and not isinstance(
+        main_data, CubicMeshPDENonStatio
+    ):
+        if main_data.omega_batch_size is not None:
+            if getattr(other_data, attr_name) != main_data.omega_batch_size:
+                raise ValueError(
+                    f"{other_data.__class__}.{attr_name} must be equal"
+                    f" to {main_data.__class__}.omega_batch_size for correct"
+                    " vectorization"
+                )
+        else:
+            if main_data.n is not None:
+                if getattr(other_data, attr_name) != main_data.n:
+                    raise ValueError(
+                        f"{other_data.__class__}.{attr_name} must be equal"
+                        f" to {main_data.__class__}.n for correct"
+                        " vectorization"
+                    )
+        if main_data.omega_border_batch_size is not None:
+            if getattr(other_data, attr_name) != main_data.omega_border_batch_size:
+                raise ValueError(
+                    f"{other_data.__class__}.{attr_name} must be equal"
+                    f" to {main_data.__class__}.omega_border_batch_size for correct"
+                    " vectorization"
+                )
+        else:
+            if main_data.nb is not None:
+                if getattr(other_data, attr_name) != main_data.nb:
+                    raise ValueError(
+                        f"{other_data.__class__}.{attr_name} must be equal"
+                        f" to {main_data.__class__}.nb for correct"
+                        " vectorization"
+                    )
+    if isinstance(main_data, CubicMeshPDENonStatio):
+        if main_data.domain_batch_size is not None:
+            if getattr(other_data, attr_name) != main_data.domain_batch_size:
+                raise ValueError(
+                    f"{other_data.__class__}.{attr_name} must be equal"
+                    f" to {main_data.__class__}.domain_batch_size for correct"
+                    " vectorization"
+                )
+        else:
+            if main_data.n is not None:
+                if getattr(other_data, attr_name) != main_data.n:
+                    raise ValueError(
+                        f"{other_data.__class__}.{attr_name} must be equal"
+                        f" to {main_data.__class__}.n for correct"
+                        " vectorization"
+                    )
+        if main_data.border_batch_size is not None:
+            if getattr(other_data, attr_name) != main_data.border_batch_size:
+                raise ValueError(
+                    f"{other_data.__class__}.{attr_name} must be equal"
+                    f" to {main_data.__class__}.border_batch_size for correct"
+                    " vectorization"
+                )
+        else:
+            if main_data.nb is not None:
+                if main_data.dim > 1 and getattr(other_data, attr_name) != (
+                    main_data.nb // 2**main_data.dim
+                ):
+                    raise ValueError(
+                        f"{other_data.__class__}.{attr_name} must be equal"
+                        f" to ({main_data.__class__}.nb // 2**{main_data.__class__}.dim)"
+                        " for correct vectorization"
+                    )
+        if main_data.initial_batch_size is not None:
+            if getattr(other_data, attr_name) != main_data.initial_batch_size:
+                raise ValueError(
+                    f"{other_data.__class__}.{attr_name} must be equal"
+                    f" to {main_data.__class__}.initial_batch_size for correct"
+                    " vectorization"
+                )
+        else:
+            if main_data.ni is not None:
+                if getattr(other_data, attr_name) != main_data.ni:
+                    raise ValueError(
+                        f"{other_data.__class__}.{attr_name} must be equal"
+                        f" to {main_data.__class__}.ni for correct"
+                        " vectorization"
+                    )
+    if isinstance(main_data, DataGeneratorParameter):
+        if main_data.param_batch_size is not None:
+            if getattr(other_data, attr_name) != main_data.param_batch_size:
+                raise ValueError(
+                    f"{other_data.__class__}.{attr_name} must be equal"
+                    f" to {main_data.__class__}.param_batch_size for correct"
+                    " vectorization"
+                )
+        else:
+            if main_data.n is not None:
+                if getattr(other_data, attr_name) != main_data.n:
+                    raise ValueError(
+                        f"{other_data.__class__}.{attr_name} must be equal"
+                        f" to {main_data.__class__}.n for correct"
+                        " vectorization"
+                    )
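The reason these checks exist: the comments in the diff say the batch sizes must agree "for correct vectorization", and in JAX terms vectorizing a function over several batched arguments with `jax.vmap` requires identical leading (batch) dimensions. A minimal illustration (not jinns API):

```python
import jax
import jax.numpy as jnp

f = jax.vmap(lambda t, p: t * p)

t_batch = jnp.ones((32, 1))  # e.g. a temporal batch of size 32
p_batch = jnp.ones((32, 4))  # the parameter batch must also have size 32
print(f(t_batch, p_batch).shape)  # (32, 4)

# With a mismatched batch, e.g. p_batch of shape (16, 4), vmap raises an
# error at trace time; _check_batch_size surfaces the problem earlier with
# a clearer message.
```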
jinns/utils/__init__.py CHANGED
@@ -1,4 +1 @@
-from ._pinn import create_PINN, PINN
-from ._spinn import create_SPINN, SPINN
-from ._hyperpinn import create_HYPERPINN, HYPERPINN
-from ._save_load import save_pinn, load_pinn
+from ._utils import get_grid
jinns/utils/_containers.py CHANGED
@@ -38,7 +38,9 @@ class OptimizationContainer(eqx.Module):
 
 class OptimizationExtraContainer(eqx.Module):
     curr_seq: int
-    best_val_params: Params
+    best_iter_id: int  # the iteration number that achieves best_val_criterion and best_val_params
+    best_val_criterion: float  # the best validation criterion at early stopping
+    best_val_params: Params  # the best parameter values at early stopping
     early_stopping: Bool = False
 
 
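Since the container is an `eqx.Module`, it is a pytree, so the new scalar fields flow through `jax.lax.cond` and the training-loop carry like any other leaves. A minimal illustration with hypothetical names:

```python
import equinox as eqx
import jax

class Extra(eqx.Module):
    curr_seq: int
    best_iter_id: int
    best_val_criterion: float

extra = Extra(curr_seq=0, best_iter_id=0, best_val_criterion=float("nan"))
print(jax.tree_util.tree_leaves(extra))  # [0, 0, nan]: usable as a loop carry
```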
jinns/utils/_types.py CHANGED
@@ -1,3 +1,4 @@
+# pragma: exclude file
 from __future__ import (
     annotations,
 )  # https://docs.python.org/3/library/typing.html#constant
@@ -25,9 +26,9 @@ if TYPE_CHECKING:
 
     from jinns.loss import DynamicLoss
    from jinns.data._Batchs import *
-    from jinns.utils._pinn import PINN
-    from jinns.utils._hyperpinn import HYPERPINN
-    from jinns.utils._spinn import SPINN
+    from jinns.nn._pinn import PINN
+    from jinns.nn._hyperpinn import HyperPINN
+    from jinns.nn._spinn_mlp import SPINN
     from jinns.utils._containers import *
     from jinns.validation._validation import AbstractValidationModule
 
@@ -41,7 +42,7 @@ if TYPE_CHECKING:
         DataGeneratorODE | CubicMeshPDEStatio | CubicMeshPDENonStatio
     )
 
-    AnyPINN: TypeAlias = PINN | HYPERPINN | SPINN
+    AnyPINN: TypeAlias = PINN | HyperPINN | SPINN
 
     AnyBatch: TypeAlias = ODEBatch | PDEStatioBatch | PDENonStatioBatch
     rar_operands = NewType(
jinns/utils/_utils.py CHANGED
@@ -2,13 +2,18 @@
 Implements various utility functions
 """
 
-from functools import reduce
-from operator import getitem
-import numpy as np
+from math import prod
+import warnings
 import jax
 import jax.numpy as jnp
 from jaxtyping import PyTree, Array
 
+from jinns.data._DataGenerators import (
+    DataGeneratorODE,
+    CubicMeshPDEStatio,
+    CubicMeshPDENonStatio,
+)
+
 
 def _check_nan_in_pytree(pytree: PyTree) -> bool:
     """
@@ -33,7 +38,7 @@ def _check_nan_in_pytree(pytree: PyTree) -> bool:
     )
 
 
-def _get_grid(in_array: Array) -> Array:
+def get_grid(in_array: Array) -> Array:
     """
     From an array of shape (B, D), D > 1, get the grid array, i.e., an array of
     shape (B, B, ...(D times)..., B, D): along the last axis we have the array
@@ -49,10 +54,14 @@ def get_grid(in_array: Array) -> Array:
     return in_array
 
 
-def _check_user_func_return(r: Array | int, shape: tuple) -> Array | int:
+def _check_shape_and_type(
+    r: Array | int, expected_shape: tuple, cause: str = "", binop: str = ""
+) -> Array | float:
     """
-    Correctly handles the result from a user defined function (eg a boundary
-    condition) to get the correct broadcast
+    Ensures float type and correct shapes for broadcasting when performing a
+    binary operation (like -, + or *) between two arrays.
+    The first array is a custom user array (observation data or the output of
+    initial/boundary condition functions); the expected shape is the PINN's.
     """
     if isinstance(r, (int, float)):
         # if we have a scalar, cast it to float
@@ -60,9 +69,28 @@ def _check_shape_and_type(
     if r.shape == ():
         # if we have a scalar inside a ndarray
         return r.astype(float)
-    if r.shape[-1] == shape[-1]:
-        # the broadcast will be OK
+    if r.shape[-1] == expected_shape[-1]:
+        # broadcasting will be OK
         return r.astype(float)
-    # the reshape below avoids a missing (1,) ending dimension
-    # depending on how the user has coded the inital function
-    return r.reshape(shape)
+
+    if r.shape != expected_shape:
+        # Usually, the reshape below adds a missing (1,) final axis so that
+        # the PINN output and the other function (initial/boundary condition)
+        # have the correct shape, depending on how the user has coded the
+        # initial/boundary condition.
+        warnings.warn(
+            f"[{cause}] Performing operation `{binop}` between arrays"
+            f" of different shapes: got {r.shape} for the custom array and"
+            f" {expected_shape} for the PINN."
+            f" This can cause unexpected and wrong broadcasting."
+            f" Reshaping {r.shape} into {expected_shape}. Reshape your"
+            f" custom array to match the {expected_shape=} to prevent this"
+            f" warning."
+        )
+        return r.reshape(expected_shape)
+
+
+def _subtract_with_check(
+    a: Array | int, b: Array | int, cause: str = ""
+) -> Array | float:
+    a = _check_shape_and_type(a, b.shape, cause=cause, binop="-")
+    return a - b
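`_subtract_with_check` guards against NumPy-style silent broadcasting, which turns an intended elementwise difference into an outer one when a trailing axis is missing:

```python
import jax.numpy as jnp

a = jnp.zeros((100,))    # e.g. user-provided observations
b = jnp.zeros((100, 1))  # e.g. PINN output
print((a - b).shape)     # (100, 100): silent broadcast, almost surely a bug

print((a.reshape(b.shape) - b).shape)  # (100, 1): the intended elementwise result
```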
jinns-1.3.0.dist-info/METADATA ADDED
@@ -0,0 +1,127 @@
+Metadata-Version: 2.2
+Name: jinns
+Version: 1.3.0
+Summary: Physics Informed Neural Network with JAX
+Author-email: Hugo Gangloff <hugo.gangloff@inrae.fr>, Nicolas Jouvin <nicolas.jouvin@inrae.fr>
+Maintainer-email: Hugo Gangloff <hugo.gangloff@inrae.fr>, Nicolas Jouvin <nicolas.jouvin@inrae.fr>
+License: Apache License 2.0
+Project-URL: Repository, https://gitlab.com/mia_jinns/jinns
+Project-URL: Documentation, https://mia_jinns.gitlab.io/jinns/index.html
+Classifier: License :: OSI Approved :: Apache Software License
+Classifier: Development Status :: 4 - Beta
+Classifier: Programming Language :: Python
+Requires-Python: >=3.10
+Description-Content-Type: text/markdown
+License-File: LICENSE
+License-File: AUTHORS
+Requires-Dist: numpy
+Requires-Dist: jax
+Requires-Dist: jaxopt
+Requires-Dist: optax
+Requires-Dist: equinox>0.11.3
+Requires-Dist: jax-tqdm
+Requires-Dist: diffrax
+Requires-Dist: matplotlib
+Provides-Extra: notebook
+Requires-Dist: jupyter; extra == "notebook"
+Requires-Dist: seaborn; extra == "notebook"
+
+jinns
+=====
+
+![status](https://gitlab.com/mia_jinns/jinns/badges/main/pipeline.svg) ![coverage](https://gitlab.com/mia_jinns/jinns/badges/main/coverage.svg)
+
+Physics Informed Neural Networks with JAX. **jinns** is developed to estimate solutions of ODE and PDE problems using neural networks, with a strong focus on:
+
+1. inverse problems: find equation parameters given noisy/indirect observations
+2. meta-modeling: solve for a parametric family of differential equations
+
+It can also be used for forward problems and hybrid modeling.
+
+What makes **jinns** specific:
+
+- **jinns uses JAX** - It is directed at JAX users: forward and backward autodiff, vmapping, jitting and more! No reinventing the wheel: it relies on the JAX ecosystem whenever possible, such as [equinox](https://github.com/patrick-kidger/equinox/) for neural networks or [optax](https://optax.readthedocs.io/) for optimization.
+
+- **jinns is highly modular** - It gives users maximum control for defining their problems and extending the package. The maths and computations are visible and not hidden behind layers of code!
+
+- **jinns is efficient** - It compares favorably to other existing Python packages for PINNs on the [PINNacle benchmarks](https://github.com/i207M/PINNacle/), as demonstrated in the table below. For more details on the benchmarks, check out the [PINN multi-library benchmark](https://gitlab.com/mia_jinns/pinn-multi-library-benchmark).
+
+- Implemented PINN architectures:
+    - Vanilla Multi-Layer Perceptron, popular across the PINN literature.
+
+    - [Separable PINNs](https://openreview.net/pdf?id=dEySGIcDnI): leverage forward-mode autodiff for computational speed.
+
+    - [Hyper PINNs](https://arxiv.org/pdf/2111.01008.pdf): useful for meta-modeling.
+
+- **Get started**: check out our various notebooks in the [documentation](https://mia_jinns.gitlab.io/jinns/index.html).
+
+| | jinns | DeepXDE - JAX | DeepXDE - Pytorch | PINA | Nvidia Modulus |
+|---|:---:|:---:|:---:|:---:|:---:|
+| Burgers1D | **445** | 723 | 671 | 1977 | 646 |
+| NS2d-C | **265** | 278 | 441 | 1600 | 275 |
+| PInv | 149 | 218 | *CC* | 1509 | **135** |
+| Diffusion-Reaction-Inv | **284** | *NI* | 3424 | 4061 | 2541 |
+| Navier-Stokes-Inv | **175** | *NI* | 1511 | 1403 | 498 |
+
+*Training time in seconds on an Nvidia T600 GPU. NI means the problem cannot be implemented in the backend; CC means the code crashed.*
+
+![A diagram of jinns workflow](img/jinns-diagram.png)
+
+
+# Installation
+
+Install the latest version with pip
+
+```bash
+pip install jinns
+```
+
+# Documentation
+
+The project's documentation is hosted on Gitlab pages and available at [https://mia_jinns.gitlab.io/jinns/index.html](https://mia_jinns.gitlab.io/jinns/index.html).
+
+
+# Found a bug / want a feature?
+
+Open an issue on the [Gitlab repo](https://gitlab.com/mia_jinns/jinns/-/issues).
+
+
+# Contributing
+
+Here are the contributor guidelines:
+
+1. First fork the library on Gitlab.
+
+2. Then clone and install the library in development mode with
+
+```bash
+pip install -e .
+```
+
+3. Install pre-commit and run it.
+
+```bash
+pip install pre-commit
+pre-commit install
+```
+
+4. Open a merge request once you are done with your changes; the review will be done via Gitlab.
+
+# Contributors
+
+Don't hesitate to contribute and get your name on the list here!
+
+**List of contributors:** Hugo Gangloff, Nicolas Jouvin, Lucia Clarotto, Inass Soukarieh
+
+# Cite us
+
+Please consider citing our work if you found it useful to yours, using this [arXiv preprint](https://arxiv.org/abs/2412.14132):
+```
+@article{gangloff_jouvin2024jinns,
+  title={jinns: a JAX Library for Physics-Informed Neural Networks},
+  author={Gangloff, Hugo and Jouvin, Nicolas},
+  journal={arXiv preprint arXiv:2412.14132},
+  year={2024}
+}
+```
jinns-1.3.0.dist-info/RECORD ADDED
@@ -0,0 +1,44 @@
+jinns/__init__.py,sha256=5p7V5VJd7PXEINqhqS4mUsnQtXlyPwfctRhL4p0loFg,181
+jinns/data/_Batchs.py,sha256=oc7-N1wEbsEvbe9fjVFKG2OPoZJVEjzPm8uj_icACf4,817
+jinns/data/_DataGenerators.py,sha256=3pyUqzQ12AUBqOV-yqpt4X6K_7CqTFtUKMjg-gJE6KA,65101
+jinns/data/__init__.py,sha256=TRCH0Z4-SQZ50MbSf46CUYWBkWVDmXCyez9T-EGiv_8,338
+jinns/experimental/__init__.py,sha256=3jCIy2R2i_0Erwxg-HwISdH79Nt1XCXhS9yY1F5awiY,208
+jinns/experimental/_diffrax_solver.py,sha256=upMr3kTTNrxEiSUO_oLvCXcjS9lPxSjvbB81h3qlhaU,6813
+jinns/loss/_DynamicLoss.py,sha256=lUpFl37_TfwxSREpoVKqUOpQEVqD3hrFXqwP2GZReWw,25817
+jinns/loss/_DynamicLossAbstract.py,sha256=bqmPxyrcvZh_dL74DTpj-TGiFxchvG8qC6KhuGeyOoA,12006
+jinns/loss/_LossODE.py,sha256=QhhSyJpDbcyW4TdShX0HkxbvJQWXvnYg8lik8_wyOg4,23415
+jinns/loss/_LossPDE.py,sha256=DZPinl7KYV2vp_CdjnhaR9M_gE-WOvyi4s8VSDEgti0,51046
+jinns/loss/__init__.py,sha256=PRiJV9fd2GSwaCBVCPyh6pFc6pdA40jfb_T1YvO8ERc,712
+jinns/loss/_boundary_conditions.py,sha256=kxHwNFSMsNzFso6nvAewcAdzW50yTi7IX-5Pthe65XY,12271
+jinns/loss/_loss_utils.py,sha256=IkZAWmBumNWwk3hzeO0dh5RjHKZpt_hL4XnG5-Gpfr8,14690
+jinns/loss/_loss_weights.py,sha256=F0Fgji2XpVk3pr9oIryGuXcG1FGQo4Dv6WFgze2BtA0,2201
+jinns/loss/_operators.py,sha256=qaRxwqgnZzlE_zTyUvafZGnUH5EZY1lpgjT9Vb7QJAQ,21718
+jinns/nn/__init__.py,sha256=k9guJSKmKlHEadAjU-0HlYXJe55Tt783QrkZz6EYyO8,231
+jinns/nn/_hyperpinn.py,sha256=nH8c9DeiiAujprEd7CVKU1chWn-kcSAY-fYLzd8_ikY,18049
+jinns/nn/_mlp.py,sha256=AbbFLF85ayJcQ6kVwfSNdAvjP69UWBP6Z3V-1De-pI4,8028
+jinns/nn/_pinn.py,sha256=45lXgrZQHv-7PQ3EDWWIoo8FlXRnjL1nl7mALTSJ45o,8391
+jinns/nn/_ppinn.py,sha256=vqIH_v1DF3LoHyl3pJ1qhfnGMRMfvbfNK6m9s5LC21k,9212
+jinns/nn/_save_load.py,sha256=VaO9LtR6dajEfo8iP7FgOvyLdQxT2IawazC2sxs97lc,9139
+jinns/nn/_spinn.py,sha256=QmKhDZ0-ToJk3_glQ9BQWgoC0d-EEAWxMrDeHfB2slw,4191
+jinns/nn/_spinn_mlp.py,sha256=9iU_-TIUFMVBcYv0nQmsa07ZwApIKqnXm7v4CY87PTo,7224
+jinns/parameters/__init__.py,sha256=1gxNLoAXUjhUzBWuh86YjU5pYy8SOboCs8TrKcU1wZc,158
+jinns/parameters/_derivative_keys.py,sha256=UyEcgfNF1vwPcGWD2ShAZkZiq4thzRDm_OUJzOfjjiY,21909
+jinns/parameters/_params.py,sha256=wK9ZSqoL9KnjOWqc_ZhJ09ffbsgeUEcttc1Rhme0lLk,3550
+jinns/plot/__init__.py,sha256=Q279h5veYWNLQyttsC8_tDOToqUHh8WaRON90CiWXqk,81
+jinns/plot/_plot.py,sha256=6OqCNvOeqbat3dViOtehILbRfGIS3pnTmNRfbZYaVTA,11433
+jinns/solver/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+jinns/solver/_rar.py,sha256=JU4FgWt5w3tzgn2mNyftGi8Erxn5N0Za60-lRaL2poI,9724
+jinns/solver/_solve.py,sha256=Bh7uplfcInJEQj1wmMquisN_vvUghARgX_uaYf7NUpw,23423
+jinns/solver/_utils.py,sha256=b2zYvwZY_fU0NMNWvUEMvHez9s7hwcxfpGzQlz5F6HA,5762
+jinns/utils/__init__.py,sha256=uw3I-lWT3wLabo6-H8FbKpSXI2xobzSs2W-Xno280g0,29
+jinns/utils/_containers.py,sha256=a7A-iUApnjc1YVc7bdt9tKUvHHPDOKMB9OfdrDZGWN8,1450
+jinns/utils/_types.py,sha256=4Qgsg6r9UPGpRwmERv4Cx2nU5ZIweehDlZQPo-FuR4Y,1896
+jinns/utils/_utils.py,sha256=hoRcJqcTuQi_Ip40oI4EbxW46E1rp2C01_HfuCpwKRM,2932
+jinns/validation/__init__.py,sha256=Jv58mzgC3F7cRfXA6caicL1t_U0UAhbwLrmMNVg6E7s,66
+jinns/validation/_validation.py,sha256=bvqL2poTFJfn9lspWqMqXvQGcQIodKwKrC786QtEZ7A,4700
+jinns-1.3.0.dist-info/AUTHORS,sha256=7NwCj9nU-HNG1asvy4qhQ2w7oZHrn-Lk5_wK_Ve7a3M,80
+jinns-1.3.0.dist-info/LICENSE,sha256=BIAkGtXB59Q_BG8f6_OqtQ1BHPv60ggE9mpXJYz2dRM,11337
+jinns-1.3.0.dist-info/METADATA,sha256=PM3iLQFd-vHDU697ECGjD2vQpgxo1vo1GTFl5AdIWoo,4744
+jinns-1.3.0.dist-info/WHEEL,sha256=jB7zZ3N9hIM9adW7qlTAyycLYW9npaWKLRzaoVcLKcM,91
+jinns-1.3.0.dist-info/top_level.txt,sha256=RXbkr2hzy8WBE8aiRyrJYFqn3JeMJIhMdybLjjLTB9c,6
+jinns-1.3.0.dist-info/RECORD,,
jinns-1.3.0.dist-info/WHEEL CHANGED
@@ -1,5 +1,5 @@
 Wheel-Version: 1.0
-Generator: setuptools (75.1.0)
+Generator: setuptools (75.8.2)
 Root-Is-Purelib: true
 Tag: py3-none-any
 