jinns-1.0.0-py3-none-any.whl → jinns-1.2.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
jinns/utils/_ppinn.py ADDED
@@ -0,0 +1,227 @@
+"""
+Implements utility functions to create PINNs
+"""
+
+from typing import Callable, Literal
+from dataclasses import InitVar
+import jax
+import jax.numpy as jnp
+import equinox as eqx
+
+from jaxtyping import Array, Key, PyTree, Float
+
+from jinns.parameters._params import Params, ParamsDict
+
+from jinns.utils._pinn import PINN, _MLP
+
+
+class PPINN(PINN):
+    r"""
+    A PPINN (Parallel PINN) object which mimics the PFNN architecture from
+    DeepXDE. This is in fact a PINN that encompasses several PINNs internally.
+
+    Parameters
+    ----------
+    slice_solution : slice
+        A jnp.s\_ object which indicates which axis of the PINN output is
+        dedicated to the actual equation solution. Default None
+        means that slice_solution = the whole PINN output. This argument is
+        useful when the PINN is also used to output equation parameters, for
+        example. Note that it must be a slice and not an integer (a
+        preprocessing of the user-provided argument takes care of it).
+    eq_type : Literal["ODE", "statio_PDE", "nonstatio_PDE"]
+        A string with three possibilities.
+        "ODE": the PINN is called with one input `t`.
+        "statio_PDE": the PINN is called with one input `x`; `x`
+        can be high dimensional.
+        "nonstatio_PDE": the PINN is called with two inputs `t` and `x`; `x`
+        can be high dimensional.
+        **Note**: the input dimension as given in eqx_list has to match the
+        sum of the dimensions of `t` and `x`, or the output dimension of the
+        `input_transform` function.
+    input_transform : Callable[[Float[Array, "input_dim"], Params], Float[Array, "output_dim"]]
+        A function that will be called before entering the PINN. Its output(s)
+        must match the PINN inputs (except for the parameters).
+        Its inputs are the PINN inputs (`t` and/or `x` concatenated together)
+        and the parameters. Default is no operation.
+    output_transform : Callable[[Float[Array, "input_dim"], Float[Array, "output_dim"], Params], Float[Array, "output_dim"]]
+        A function whose arguments are the PINN input, the PINN output and the
+        parameters. This function will be called after exiting the PINN.
+        Default is no operation.
+    mlp_list : list[eqx.Module]
+        The actual neural networks instantiated as eqx.Modules
+    """
+
+    slice_solution: slice = eqx.field(static=True, kw_only=True)
+    output_slice: slice = eqx.field(static=True, kw_only=True, default=None)
+
+    mlp_list: InitVar[list[eqx.Module]] = eqx.field(kw_only=True)
+
+    params: PyTree = eqx.field(init=False)
+    static: PyTree = eqx.field(init=False, static=True)
+
+    def __post_init__(self, mlp, mlp_list):
+        super().__post_init__(
+            mlp=mlp_list[0],
+        )
+        self.params, self.static = (), ()
+        for mlp in mlp_list:
+            params, static = eqx.partition(mlp, eqx.is_inexact_array)
+            self.params = self.params + (params,)
+            self.static = self.static + (static,)
+
+    @property
+    def init_params(self) -> PyTree:
+        """
+        Returns an initial set of parameters
+        """
+        return self.params
+
+    def __call__(
+        self,
+        inputs: Float[Array, "1"] | Float[Array, "dim"] | Float[Array, "1+dim"],
+        params: PyTree,
+    ) -> Float[Array, "output_dim"]:
+        """
+        Evaluate the PPINN on some inputs with some params.
+        """
+        if len(inputs.shape) == 0:
+            # This can often happen when the user directly provides some
+            # collocation points (e.g. for plotting, without using
+            # DataGenerators)
+            inputs = inputs[None]
+        transformed_inputs = self.input_transform(inputs, params)
+
+        outs = []
+        for params_, static in zip(params.nn_params, self.static):
+            model = eqx.combine(params_, static)
+            outs += [model(transformed_inputs)]
+        # Note that below is then a global output transform
+        res = self.output_transform(inputs, jnp.concatenate(outs, axis=0), params)
+
+        # force (1,) output for non-vectorial solution (consistency)
+        if not res.shape:
+            return jnp.expand_dims(res, axis=-1)
+        return res
+
+
+def create_PPINN(
+    key: Key,
+    eqx_list_list: list[tuple[tuple[Callable, int, int] | Callable, ...]],
+    eq_type: Literal["ODE", "statio_PDE", "nonstatio_PDE"],
+    dim_x: int = 0,
+    input_transform: Callable[
+        [Float[Array, "input_dim"], Params], Float[Array, "output_dim"]
+    ] = None,
+    output_transform: Callable[
+        [Float[Array, "input_dim"], Float[Array, "output_dim"], Params],
+        Float[Array, "output_dim"],
+    ] = None,
+    slice_solution: slice = None,
+) -> tuple[PINN | list[PINN], PyTree | list[PyTree]]:
+    r"""
+    Utility function to create a parallel PINN (PPINN) neural network with the
+    equinox library.
+
+    Parameters
+    ----------
+    key
+        A JAX random key that will be used to initialize the network
+        parameters.
+    eqx_list_list
+        A list of `eqx_list` (see `create_PINN`). The input dimension must be
+        the same for each sub-`eqx_list`; the parallel subnetworks can
+        otherwise differ. Their respective outputs are concatenated.
+    eq_type
+        A string with three possibilities.
+        "ODE": the PPINN is called with one input `t`.
+        "statio_PDE": the PPINN is called with one input `x`; `x`
+        can be high dimensional.
+        "nonstatio_PDE": the PPINN is called with two inputs `t` and `x`; `x`
+        can be high dimensional.
+        **Note**: the input dimension as given in eqx_list has to match the
+        sum of the dimensions of `t` and `x`, or the output dimension of the
+        `input_transform` function.
+    dim_x
+        An integer. The dimension of `x`. Default `0`.
+    input_transform
+        A function that will be called before entering the PPINN. Its
+        output(s) must match the PPINN inputs (except for the parameters).
+        Its inputs are the PPINN inputs (`t` and/or `x` concatenated together)
+        and the parameters. Default is no operation.
+    output_transform
+        This function will be called after exiting the PPINN, i.e., on the
+        concatenated outputs of all parallel networks.
+        Default is no operation.
+    slice_solution
+        A jnp.s\_ object which indicates which axis of the PPINN output is
+        dedicated to the actual equation solution. Default None
+        means that slice_solution = the whole PPINN output. This argument is
+        useful when the PPINN is also used to output equation parameters, for
+        example. Note that it must be a slice and not an integer (a
+        preprocessing of the user-provided argument takes care of it).
+
+
+    Returns
+    -------
+    ppinn
+        A PPINN instance
+    ppinn.init_params
+        An initial set of parameters for the PPINN
+
+    Raises
+    ------
+    RuntimeError
+        If the parameter value for eq_type is not in `["ODE", "statio_PDE",
+        "nonstatio_PDE"]`
+    RuntimeError
+        If we have `dim_x > 0` and `eq_type == "ODE"`,
+        or if we have `dim_x = 0` and `eq_type != "ODE"`
+    """
+    if eq_type not in ["ODE", "statio_PDE", "nonstatio_PDE"]:
+        raise RuntimeError("Wrong parameter value for eq_type")
+
+    if eq_type == "ODE" and dim_x != 0:
+        raise RuntimeError("Wrong parameter combination eq_type and dim_x")
+
+    if eq_type != "ODE" and dim_x == 0:
+        raise RuntimeError("Wrong parameter combination eq_type and dim_x")
+
+    nb_outputs_declared = 0
+    for eqx_list in eqx_list_list:
+        try:
+            # normally the number of outputs is the 3rd element of the last layer
+            nb_outputs_declared += eqx_list[-1][2]
+        except IndexError:
+            nb_outputs_declared += eqx_list[-2][2]
+
+    if slice_solution is None:
+        slice_solution = jnp.s_[0:nb_outputs_declared]
+    if isinstance(slice_solution, int):
+        # rewrite it as a slice to ensure that the axis does not disappear
+        # when indexing
+        slice_solution = jnp.s_[slice_solution : slice_solution + 1]
+
+    if input_transform is None:
+
+        def input_transform(_in, _params):
+            return _in
+
+    if output_transform is None:
+
+        def output_transform(_in_pinn, _out_pinn, _params):
+            return _out_pinn
+
+    mlp_list = []
+    for eqx_list in eqx_list_list:
+        mlp_list.append(_MLP(key=key, eqx_list=eqx_list))
+
+    ppinn = PPINN(
+        mlp=None,
+        mlp_list=mlp_list,
+        slice_solution=slice_solution,
+        eq_type=eq_type,
+        input_transform=input_transform,
+        output_transform=output_transform,
+    )
+    return ppinn, ppinn.init_params
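For context, here is a minimal sketch of how this new entry point might be used. It assumes the tuple-based `eqx_list` format documented in `_save_load.py` below and the `Params` container imported above; the layer sizes, the evaluation point and the empty `eq_params` are illustrative, not taken from the package's own examples.

```python
import jax
import jax.numpy as jnp
import equinox as eqx

from jinns.parameters._params import Params
from jinns.utils._ppinn import create_PPINN

key = jax.random.PRNGKey(0)

# Two parallel subnetworks with the same input dimension (t concatenated with
# a 1D x, so input_dim = 2); their 1D outputs are concatenated by the PPINN.
subnet = (
    (eqx.nn.Linear, 2, 20),
    (jax.nn.tanh,),
    (eqx.nn.Linear, 20, 1),
)
eqx_list_list = [subnet, subnet]

ppinn, init_params = create_PPINN(key, eqx_list_list, "nonstatio_PDE", dim_x=1)

# Evaluate at a single concatenated (t, x) point.
params = Params(nn_params=init_params, eq_params={})
out = ppinn(jnp.array([0.5, 0.3]), params)  # expected shape: (2,)
```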
jinns/utils/_save_load.py CHANGED
@@ -20,18 +20,18 @@ def function_to_string(
     We need this transformation for eqx_list to be pickled
 
     From `((eqx.nn.Linear, 2, 20),
-    (jax.nn.tanh),
+    (jax.nn.tanh,),
     (eqx.nn.Linear, 20, 20),
-    (jax.nn.tanh),
+    (jax.nn.tanh,),
     (eqx.nn.Linear, 20, 20),
-    (jax.nn.tanh),
+    (jax.nn.tanh,),
     (eqx.nn.Linear, 20, 1))` to
     `(("Linear", 2, 20),
-    ("tanh"),
+    ("tanh",),
     ("Linear", 20, 20),
-    ("tanh"),
+    ("tanh",),
     ("Linear", 20, 20),
-    ("tanh"),
+    ("tanh",),
     ("Linear", 20, 1))`
     """
     return jax.tree_util.tree_map(
@@ -210,14 +210,16 @@ def load_pinn(
     if type_ == "pinn":
         # next line creates a shallow model, the jax arrays are just shapes and
         # not populated, this just recreates the correct pytree structure
-        u_reloaded_shallow = eqx.filter_eval_shape(create_PINN, **kwargs_reloaded)
+        u_reloaded_shallow, _ = eqx.filter_eval_shape(create_PINN, **kwargs_reloaded)
     elif type_ == "spinn":
-        u_reloaded_shallow = eqx.filter_eval_shape(create_SPINN, **kwargs_reloaded)
+        u_reloaded_shallow, _ = eqx.filter_eval_shape(create_SPINN, **kwargs_reloaded)
     elif type_ == "hyperpinn":
         kwargs_reloaded["eqx_list_hyper"] = string_to_function(
             kwargs_reloaded["eqx_list_hyper"]
         )
-        u_reloaded_shallow = eqx.filter_eval_shape(create_HYPERPINN, **kwargs_reloaded)
+        u_reloaded_shallow, _ = eqx.filter_eval_shape(
+            create_HYPERPINN, **kwargs_reloaded
+        )
     else:
         raise ValueError(f"{type_} is not valid")
     if key_list_for_paramsdict is None:
@@ -226,15 +228,13 @@ def load_pinn(
         u_reloaded = eqx.tree_deserialise_leaves(
             filename + "-module.eqx", u_reloaded_shallow
         )
-        params = Params(
-            nn_params=u_reloaded.init_params(), eq_params=eq_params_reloaded
-        )
+        params = Params(nn_params=u_reloaded.init_params, eq_params=eq_params_reloaded)
     else:
         nn_params_dict = {}
        for key in key_list_for_paramsdict:
             u_reloaded = eqx.tree_deserialise_leaves(
                 filename + f"-module_{key}.eqx", u_reloaded_shallow
             )
-            nn_params_dict[key] = u_reloaded.init_params()
+            nn_params_dict[key] = u_reloaded.init_params
         params = ParamsDict(nn_params=nn_params_dict, eq_params=eq_params_reloaded)
     return u_reloaded, params
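The docstring fix above is more than cosmetic: in Python, parentheses alone do not make a tuple, so `(jax.nn.tanh)` is the bare activation function while `(jax.nn.tanh,)` is the 1-tuple that `eqx_list` and `function_to_string` actually expect. A quick check:

```python
import jax

bare = (jax.nn.tanh)    # parentheses are a no-op: still the function itself
layer = (jax.nn.tanh,)  # the trailing comma builds the expected 1-tuple

assert not isinstance(bare, tuple)
assert isinstance(layer, tuple) and len(layer) == 1
```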
jinns/utils/_spinn.py CHANGED
@@ -10,6 +10,8 @@ import jax.numpy as jnp
 import equinox as eqx
 from jaxtyping import Key, Array, Float, PyTree
 
+from jinns.parameters._params import Params, ParamsDict
+
 
 class _SPINN(eqx.Module):
     """
@@ -21,7 +23,8 @@ class _SPINN(eqx.Module):
     key : InitVar[Key]
         A jax random key for the layer initializations.
     d : int
-        The number of dimensions to treat separately.
+        The number of dimensions to treat separately, including time `t` if
+        used for non-stationary equations.
     eqx_list : InitVar[tuple[tuple[Callable, int, int] | Callable, ...]]
         A tuple of tuples of successive equinox modules and activation functions to
         describe the PINN architecture. The inner tuples must have the eqx module or
@@ -62,15 +65,11 @@ class _SPINN(eqx.Module):
         self.separated_mlp.append(self.layers)
 
     def __call__(
-        self, t: Float[Array, "1"], x: Float[Array, "omega_dim"]
+        self, inputs: Float[Array, "dim"] | Float[Array, "dim+1"]
     ) -> Float[Array, "d embed_dim*output_dim"]:
-        if t is not None:
-            dimensions = jnp.concatenate([t, x.flatten()], axis=0)
-        else:
-            dimensions = jnp.concatenate([x.flatten()], axis=0)
         outputs = []
         for d in range(self.d):
-            t_ = dimensions[d][None]
+            t_ = inputs[d : d + 1]
             for layer in self.separated_mlp[d]:
                 t_ = layer(t_)
             outputs += [t_]
@@ -82,13 +81,11 @@ class SPINN(eqx.Module):
     A SPINN object compatible with the rest of jinns.
     This is typically created with `create_SPINN`.
 
-    **NOTE**: SPINNs with `t` and `x` as inputs are best used with a
-    DataGenerator with `self.cartesian_product=False` for memory consideration
-
     Parameters
     ----------
     d : int
-        The number of dimensions to treat separately.
+        The number of dimensions to treat separately, including time `t` if
+        used for non-stationary equations.
 
 
     """
@@ -105,42 +102,28 @@
     def __post_init__(self, spinn_mlp):
         self.params, self.static = eqx.partition(spinn_mlp, eqx.is_inexact_array)
 
+    @property
     def init_params(self) -> PyTree:
         """
         Returns an initial set of parameters
         """
         return self.params
 
-    def __call__(self, *args) -> Float[Array, "output_dim"]:
-        """
-        Calls `eval_nn` with rearranged arguments
-        """
-        if self.eq_type == "statio_PDE":
-            (x, params) = args
-            try:
-                spinn = eqx.combine(params.nn_params, self.static)
-            except (KeyError, AttributeError, TypeError) as e:
-                spinn = eqx.combine(params, self.static)
-            v_model = jax.vmap(spinn, (0))
-            res = v_model(t=None, x=x)
-            return self.eval_nn(res)
-        if self.eq_type == "nonstatio_PDE":
-            (t, x, params) = args
-            try:
-                spinn = eqx.combine(params.nn_params, self.static)
-            except (KeyError, AttributeError, TypeError) as e:
-                spinn = eqx.combine(params, self.static)
-            v_model = jax.vmap(spinn, ((0, 0)))
-            res = v_model(t, x)
-            return self.eval_nn(res)
-        raise RuntimeError("Wrong parameter value for eq_type")
-
-    def eval_nn(
-        self, res: Float[Array, "d embed_dim*output_dim"]
+    def __call__(
+        self,
+        t_x: Float[Array, "batch_size 1+dim"],
+        params: Params | ParamsDict | PyTree,
     ) -> Float[Array, "output_dim"]:
         """
         Evaluate the SPINN on some inputs with some params.
         """
+        try:
+            spinn = eqx.combine(params.nn_params, self.static)
+        except (KeyError, AttributeError, TypeError) as e:
+            spinn = eqx.combine(params, self.static)
+        v_model = jax.vmap(spinn)
+        res = v_model(t_x)
+
         a = ", ".join([f"{chr(97 + d)}z" for d in range(res.shape[1])])
         b = "".join([f"{chr(97 + d)}" for d in range(res.shape[1])])
         res = jnp.stack(
@@ -170,7 +153,7 @@
     eqx_list: tuple[tuple[Callable, int, int] | Callable, ...],
     eq_type: Literal["ODE", "statio_PDE", "nonstatio_PDE"],
     m: int = 1,
-) -> SPINN:
+) -> tuple[SPINN, PyTree]:
     """
     Utility function to create a SPINN neural network with the equinox
     library.
@@ -218,16 +201,14 @@
         then sum groups of `r` embedding dimensions to compute each output.
         Default is 1.
 
-    !!! note
-        SPINNs with `t` and `x` as inputs are best used with a
-        DataGenerator with `self.cartesian_product=False` for memory
-        consideration
 
 
     Returns
     -------
     spinn
         An instantiated SPINN
+    spinn.init_params
+        The initial set of parameters of the model
 
     Raises
    ------
@@ -265,4 +246,4 @@
     spinn_mlp = _SPINN(key=key, d=d, eqx_list=eqx_list)
     spinn = SPINN(spinn_mlp=spinn_mlp, d=d, r=r, eq_type=eq_type, m=m)
 
-    return spinn
+    return spinn, spinn.init_params
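A minimal sketch of the new batched calling convention, assuming `create_SPINN`'s full signature is `(key, d, r, eqx_list, eq_type, m=1)`, consistent with the parameters visible in the hunks above; the collocation points and layer sizes are made up:

```python
import jax
import jax.numpy as jnp
import equinox as eqx

from jinns.utils._spinn import create_SPINN

key = jax.random.PRNGKey(0)
d, r = 3, 32  # separated dimensions (t, x1, x2) and embedding dimension

# Each separated MLP maps a single coordinate to the r-dimensional embedding.
eqx_list = (
    (eqx.nn.Linear, 1, 64),
    (jax.nn.tanh,),
    (eqx.nn.Linear, 64, r),
)

spinn, init_params = create_SPINN(key, d, r, eqx_list, "nonstatio_PDE")

# v1.2.0 convention: one batched array with t and x stacked column-wise,
# instead of the former separate (t, x) arguments.
t = jnp.linspace(0.0, 1.0, 16)[:, None]  # (16, 1)
x = jax.random.uniform(key, (16, 2))     # (16, dim)
t_x = jnp.concatenate([t, x], axis=1)    # (16, 1 + dim), with 1 + dim == d
out = spinn(t_x, init_params)            # grid-shaped output via the einsum
```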
jinns/utils/_types.py CHANGED
@@ -1,3 +1,4 @@
+# pragma: exclude file
 from __future__ import (
     annotations,
 )  # https://docs.python.org/3/library/typing.html#constant
jinns/utils/_utils.py CHANGED
@@ -2,13 +2,18 @@
 Implements various utility functions
 """
 
-from functools import reduce
-from operator import getitem
-import numpy as np
+from math import prod
+import warnings
 import jax
 import jax.numpy as jnp
 from jaxtyping import PyTree, Array
 
+from jinns.data._DataGenerators import (
+    DataGeneratorODE,
+    CubicMeshPDEStatio,
+    CubicMeshPDENonStatio,
+)
+
 
 def _check_nan_in_pytree(pytree: PyTree) -> bool:
     """
@@ -33,7 +38,7 @@ def _check_nan_in_pytree(pytree: PyTree) -> bool:
     )
 
 
-def _get_grid(in_array: Array) -> Array:
+def get_grid(in_array: Array) -> Array:
     """
     From an array of shape (B, D), D > 1, get the grid array, i.e., an array of
     shape (B, B, ...(D times)..., B, D): along the last axis we have the array
@@ -49,10 +54,14 @@ def _get_grid(in_array: Array) -> Array:
     return in_array
 
 
-def _check_user_func_return(r: Array | int, shape: tuple) -> Array | int:
+def _check_shape_and_type(
+    r: Array | int, expected_shape: tuple, cause: str = "", binop: str = ""
+) -> Array | float:
     """
-    Correctly handles the result from a user defined function (eg a boundary
-    condition) to get the correct broadcast
+    Ensures float type and correct shapes for broadcasting when performing a
+    binary operation (like -, + or *) between two arrays.
+    The first array is a custom user array (observation data or the output of
+    initial/boundary condition functions); the expected shape is the PINN's.
     """
     if isinstance(r, (int, float)):
         # if we have a scalar cast it to float
@@ -60,9 +69,28 @@ def _check_user_func_return(r: Array | int, shape: tuple) -> Array | int:
     if r.shape == ():
         # if we have a scalar inside a ndarray
         return r.astype(float)
-    if r.shape[-1] == shape[-1]:
-        # the broadcast will be OK
+    if r.shape[-1] == expected_shape[-1]:
+        # broadcasting will be OK
         return r.astype(float)
-    # the reshape below avoids a missing (1,) ending dimension
-    # depending on how the user has coded the inital function
-    return r.reshape(shape)
+
+    if r.shape != expected_shape:
+        # Usually, the reshape below adds a missing (1,) final axis so that the
+        # PINN output and the other array have the correct shape, depending on
+        # how the user has coded the initial/boundary condition.
+        warnings.warn(
+            f"[{cause}] Performing operation `{binop}` between arrays"
+            f" of different shapes: got {r.shape} for the custom array and"
+            f" {expected_shape} for the PINN."
+            f" This can cause unexpected and wrong broadcasting."
+            f" Reshaping {r.shape} into {expected_shape}. Reshape your"
+            f" custom array to match the {expected_shape=} to prevent this"
+            f" warning."
+        )
+        return r.reshape(expected_shape)
+
+
+def _subtract_with_check(
+    a: Array | int, b: Array | int, cause: str = ""
+) -> Array | float:
+    a = _check_shape_and_type(a, b.shape, cause=cause, binop="-")
+    return a - b
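The warning added here guards against a classic broadcasting pitfall that silently produces wrong residuals; a small illustration with made-up shapes:

```python
import jax.numpy as jnp

pinn_out = jnp.ones((10, 1))  # PINN output: batch of 10, output_dim 1
user_out = jnp.ones((10,))    # user-coded initial condition, final axis missing

# Silent, wrong broadcast: (10,) - (10, 1) expands to (10, 10).
assert (user_out - pinn_out).shape == (10, 10)

# After the reshape performed by _check_shape_and_type, the subtraction
# stays elementwise, as _subtract_with_check intends.
assert (user_out.reshape(pinn_out.shape) - pinn_out).shape == (10, 1)
```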
@@ -0,0 +1,2 @@
+Hugo Gangloff <hugo.gangloff@inrae.fr>
+Nicolas Jouvin <nicolas.jouvin@inrae.fr>
@@ -0,0 +1,127 @@
+Metadata-Version: 2.1
+Name: jinns
+Version: 1.2.0
+Summary: Physics Informed Neural Network with JAX
+Author-email: Hugo Gangloff <hugo.gangloff@inrae.fr>, Nicolas Jouvin <nicolas.jouvin@inrae.fr>
+Maintainer-email: Hugo Gangloff <hugo.gangloff@inrae.fr>, Nicolas Jouvin <nicolas.jouvin@inrae.fr>
+License: Apache License 2.0
+Project-URL: Repository, https://gitlab.com/mia_jinns/jinns
+Project-URL: Documentation, https://mia_jinns.gitlab.io/jinns/index.html
+Classifier: License :: OSI Approved :: Apache Software License
+Classifier: Development Status :: 4 - Beta
+Classifier: Programming Language :: Python
+Requires-Python: >=3.10
+Description-Content-Type: text/markdown
+License-File: LICENSE
+License-File: AUTHORS
+Requires-Dist: numpy
+Requires-Dist: jax
+Requires-Dist: jaxopt
+Requires-Dist: optax
+Requires-Dist: equinox>0.11.3
+Requires-Dist: jax-tqdm
+Requires-Dist: diffrax
+Requires-Dist: matplotlib
+Provides-Extra: notebook
+Requires-Dist: jupyter; extra == "notebook"
+Requires-Dist: seaborn; extra == "notebook"
+
+jinns
+=====
+
+![status](https://gitlab.com/mia_jinns/jinns/badges/main/pipeline.svg) ![coverage](https://gitlab.com/mia_jinns/jinns/badges/main/coverage.svg)
+
+Physics Informed Neural Networks with JAX. **jinns** is developed to estimate solutions of ODE and PDE problems using neural networks, with a strong focus on
+
+1. inverse problems: find equation parameters given noisy/indirect observations
+2. meta-modeling: solve for a parametric family of differential equations
+
+It can also be used for forward problems and hybrid modeling.
+
+**jinns** specific points:
+
+- **jinns uses JAX** - It is aimed at JAX users: forward and backward autodiff, vmapping, jitting and more! No reinventing the wheel: it relies on the JAX ecosystem whenever possible, such as [equinox](https://github.com/patrick-kidger/equinox/) for neural networks or [optax](https://optax.readthedocs.io/) for optimization.
+
+- **jinns is highly modular** - It gives users maximum control for defining their problems and extending the package. The maths and computations are visible and not hidden behind layers of code!
+
+- **jinns is efficient** - It compares favorably to other existing Python packages for PINNs on the [PINNacle benchmarks](https://github.com/i207M/PINNacle/), as demonstrated in the table below. For more details on the benchmarks, check out the [PINN multi-library benchmark](https://gitlab.com/mia_jinns/pinn-multi-library-benchmark)
+
+- Implemented PINN architectures:
+    - Vanilla Multi-Layer Perceptron, popular across the PINN literature.
+
+    - [Separable PINNs](https://openreview.net/pdf?id=dEySGIcDnI): leverage forward-mode autodiff for computational speed.
+
+    - [Hyper PINNs](https://arxiv.org/pdf/2111.01008.pdf): useful for meta-modeling
+
+
+- **Get started**: check out our various notebooks in the [documentation](https://mia_jinns.gitlab.io/jinns/index.html).
+
+| | jinns | DeepXDE - JAX | DeepXDE - Pytorch | PINA | Nvidia Modulus |
+|---|:---:|:---:|:---:|:---:|:---:|
+| Burgers1D | **445** | 723 | 671 | 1977 | 646 |
+| NS2d-C | **265** | 278 | 441 | 1600 | 275 |
+| PInv | 149 | 218 | *CC* | 1509 | **135** |
+| Diffusion-Reaction-Inv | **284** | *NI* | 3424 | 4061 | 2541 |
+| Navier-Stokes-Inv | **175** | *NI* | 1511 | 1403 | 498 |
+
+*Training time in seconds on an Nvidia T600 GPU. NI means the problem cannot be implemented in the backend, CC means the code crashed.*
+
+![A diagram of jinns workflow](img/jinns-diagram.png)
+
+
+# Installation
+
+Install the latest version with pip
+
+```bash
+pip install jinns
+```
+
+# Documentation
+
+The project's documentation is hosted on GitLab Pages and available at [https://mia_jinns.gitlab.io/jinns/index.html](https://mia_jinns.gitlab.io/jinns/index.html).
+
+
+# Found a bug / want a feature?
+
+Open an issue on the [Gitlab repo](https://gitlab.com/mia_jinns/jinns/-/issues).
+
+
+# Contributing
+
+Here are the contributor guidelines:
+
+1. First fork the library on Gitlab.
+
+2. Then clone and install the library in development mode with
+
+```bash
+pip install -e .
+```
+
+3. Install pre-commit and set up the hooks.
+
+```bash
+pip install pre-commit
+pre-commit install
+```
+
+4. Open a merge request once you are done with your changes; the review will be done via Gitlab.
+
+# Contributors
+
+Don't hesitate to contribute and get your name on the list here!
+
+**List of contributors:** Hugo Gangloff, Nicolas Jouvin
+
+# Cite us
+
+Please consider citing our work if you found it useful to yours, using the following lines
+
+```
+@software{jinns2024,
+  title={\texttt{jinns}: Physics-Informed Neural Networks with JAX},
+  author={Gangloff, Hugo and Jouvin, Nicolas},
+  url={https://gitlab.com/mia_jinns},
+  year={2024}
+}
+```