jinns 0.8.0__py3-none-any.whl → 0.8.1__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
jinns/utils/__init__.py CHANGED
@@ -7,3 +7,4 @@ from ._pinn import create_PINN
  from ._spinn import create_SPINN
  from ._hyperpinn import create_HYPERPINN
  from ._optim import alternate_optimizer, delayed_optimizer
+ from ._save_load import save_pinn, load_pinn
jinns/utils/_hyperpinn.py CHANGED
@@ -3,13 +3,13 @@ Implements utility function to create HYPERPINNs
  https://arxiv.org/pdf/2111.01008.pdf
  """

- from functools import partial
  import copy
  from math import prod
  import numpy as onp
  import jax
  import jax.numpy as jnp
  from jax.tree_util import tree_leaves, tree_map
+ from jax.typing import ArrayLike
  import equinox as eqx

  from jinns.utils._pinn import PINN, _MLP
@@ -35,47 +35,39 @@ class HYPERPINN(PINN):
      Composed of a PINN and an hypernetwork
      """

+     params_hyper: eqx.Module
+     static_hyper: eqx.Module
+     hyperparams: list = eqx.field(static=True)
+     hypernet_input_size: int
+     pinn_params_sum: ArrayLike
+     pinn_params_cumsum: ArrayLike
+
      def __init__(
          self,
-         key,
-         eqx_list,
-         eqx_list_hyper,
+         mlp,
+         hyper_mlp,
          slice_solution,
          eq_type,
          input_transform,
          output_transform,
          hyperparams,
          hypernet_input_size,
-         output_slice=None,
+         output_slice,
      ):
-         key, subkey = jax.random.split(key, 2)
          super().__init__(
-             subkey,
-             eqx_list,
+             mlp,
              slice_solution,
              eq_type,
              input_transform,
              output_transform,
              output_slice,
          )
-         self.pinn_params_sum, self.pinn_params_cumsum = _get_param_nb(self.params)
-         # the number of parameters for the pinn will be the number of ouputs
-         # for the hypetnetwork
-         self.hyperparams = hyperparams
-         self.hypernet_input_size = hypernet_input_size
-         key, subkey = jax.random.split(key, 2)
-         try:
-             eqx_list_hyper[-1][2] = self.pinn_params_sum
-         except IndexError:
-             eqx_list_hyper[-2][2] = self.pinn_params_sum
-         try:
-             eqx_list_hyper[0][1] = self.hypernet_input_size
-         except IndexError:
-             eqx_list_hyper[0][1] = self.hypernet_input_size
-         _hyper = _MLP(subkey, eqx_list_hyper)
          self.params_hyper, self.static_hyper = eqx.partition(
-             _hyper, eqx.is_inexact_array
+             hyper_mlp, eqx.is_inexact_array
          )
+         self.hyperparams = hyperparams
+         self.hypernet_input_size = hypernet_input_size
+         self.pinn_params_sum, self.pinn_params_cumsum = _get_param_nb(self.params)

      def init_params(self):
          return self.params_hyper
@@ -276,14 +268,32 @@ def create_HYPERPINN(
          def output_transform(_in_pinn, _out_pinn):
              return _out_pinn

+     key, subkey = jax.random.split(key, 2)
+     mlp = _MLP(subkey, eqx_list)
+     # quick partitioning to get the params to get the correct number of neurons
+     # for the last layer of hyper network
+     params_mlp, _ = eqx.partition(mlp, eqx.is_inexact_array)
+     pinn_params_sum, _ = _get_param_nb(params_mlp)
+     # the number of parameters for the pinn will be the number of ouputs
+     # for the hyper network
+     try:
+         eqx_list_hyper[-1][2] = pinn_params_sum
+     except IndexError:
+         eqx_list_hyper[-2][2] = pinn_params_sum
+     try:
+         eqx_list_hyper[0][1] = hypernet_input_size
+     except IndexError:
+         eqx_list_hyper[1][1] = hypernet_input_size
+     key, subkey = jax.random.split(key, 2)
+     hyper_mlp = _MLP(subkey, eqx_list_hyper)
+
      if shared_pinn_outputs is not None:
          hyperpinns = []
          static = None
          for output_slice in shared_pinn_outputs:
              hyperpinn = HYPERPINN(
-                 key,
-                 eqx_list,
-                 eqx_list_hyper,
+                 mlp,
+                 hyper_mlp,
                  slice_solution,
                  eq_type,
                  input_transform,
@@ -300,14 +310,14 @@ def create_HYPERPINN(
              hyperpinns.append(hyperpinn)
          return hyperpinns
      hyperpinn = HYPERPINN(
-         key,
-         eqx_list,
-         eqx_list_hyper,
+         mlp,
+         hyper_mlp,
          slice_solution,
          eq_type,
          input_transform,
          output_transform,
          hyperparams,
          hypernet_input_size,
+         None,
      )
      return hyperpinn
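The refactor above moves network construction out of `HYPERPINN.__init__` and into `create_HYPERPINN`, which must first count the PINN's parameters so the hypernetwork's last layer emits one scalar per weight. A minimal sketch of that counting step, assuming a helper equivalent to the internal `_get_param_nb` (the stand-in below is illustrative, not jinns code):

```python
import jax
import jax.numpy as jnp
import equinox as eqx
from math import prod

def param_count_and_cumsum(params):
    # Hypothetical stand-in for jinns' internal _get_param_nb: the total
    # number of scalars sizes the hypernetwork's output layer, while the
    # cumulative offsets allow a flat hypernetwork output to be split
    # back into the PINN's weight pytree.
    sizes = [prod(leaf.shape) for leaf in jax.tree_util.tree_leaves(params)]
    return sum(sizes), jnp.cumsum(jnp.asarray(sizes))

mlp = eqx.nn.MLP(in_size=2, out_size=1, width_size=20, depth=2,
                 key=jax.random.PRNGKey(0))
params, _ = eqx.partition(mlp, eqx.is_inexact_array)
total, offsets = param_count_and_cumsum(params)
# create_HYPERPINN then resizes the hypernetwork: eqx_list_hyper[-1][2] = total
```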
jinns/utils/_pinn.py CHANGED
@@ -2,8 +2,10 @@
  Implements utility function to create PINNs
  """

+ from typing import Callable
  import jax
  import jax.numpy as jnp
+ from jax.typing import ArrayLike
  import equinox as eqx


@@ -39,16 +41,11 @@ class _MLP(eqx.Module):
          """

          self.layers = []
-         # TODO we are limited currently in the number of layer type we can
-         # parse and we lack some safety checks
          for l in eqx_list:
              if len(l) == 1:
                  self.layers.append(l[0])
              else:
-                 # By default we append a random key at the end of the
-                 # arguments fed into a layer module call
                  key, subkey = jax.random.split(key, 2)
-                 # the argument key is keyword only
                  self.layers.append(l[0](*l[1:], key=subkey))

      def __call__(self, t):
@@ -57,25 +54,31 @@ class _MLP(eqx.Module):
          return t


- class PINN:
+ class PINN(eqx.Module):
      """
      Basically a wrapper around the `__call__` function to be able to give a type to
      our former `self.u`
      The function create_PINN has the role to population the `__call__` function
      """

+     slice_solution: ArrayLike
+     eq_type: str = eqx.field(static=True)
+     input_transform: Callable = eqx.field(static=True)
+     output_transform: Callable = eqx.field(static=True)
+     output_slice: ArrayLike
+     params: eqx.Module
+     static: eqx.Module
+
      def __init__(
          self,
-         key,
-         eqx_list,
+         mlp,
          slice_solution,
          eq_type,
          input_transform,
          output_transform,
-         output_slice=None,
+         output_slice,
      ):
-         _pinn = _MLP(key, eqx_list)
-         self.params, self.static = eqx.partition(_pinn, eqx.is_inexact_array)
+         self.params, self.static = eqx.partition(mlp, eqx.is_inexact_array)
          self.slice_solution = slice_solution
          self.eq_type = eq_type
          self.input_transform = input_transform
@@ -134,7 +137,7 @@ def create_PINN(
      shared_pinn_outputs=None,
      slice_solution=None,
  ):
-     """
+     r"""
      Utility function to create a standard PINN neural network with the equinox
      library.

@@ -246,13 +249,14 @@ def create_PINN(
          def output_transform(_in_pinn, _out_pinn):
              return _out_pinn

+     mlp = _MLP(key, eqx_list)
+
      if shared_pinn_outputs is not None:
          pinns = []
          static = None
          for output_slice in shared_pinn_outputs:
              pinn = PINN(
-                 key,
-                 eqx_list,
+                 mlp,
                  slice_solution,
                  eq_type,
                  input_transform,
@@ -266,7 +270,5 @@ def create_PINN(
                  pinn.static = static
              pinns.append(pinn)
          return pinns
-     pinn = PINN(
-         key, eqx_list, slice_solution, eq_type, input_transform, output_transform
-     )
+     pinn = PINN(mlp, slice_solution, eq_type, input_transform, output_transform, None)
      return pinn
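The key change in this file is that `PINN` (like `SPINN` and `HYPERPINN`) now subclasses `eqx.Module` with declared fields, so the wrapper itself is a valid JAX pytree; non-array metadata such as `eq_type` and the transforms is marked `eqx.field(static=True)` so it is treated as tree structure rather than as leaves. A small self-contained sketch of the pattern (illustrative, not jinns code):

```python
import jax
import jax.numpy as jnp
import equinox as eqx

class Wrapper(eqx.Module):
    # Static fields belong to the pytree structure, not the leaves;
    # only `weights` remains a (differentiable) array leaf.
    eq_type: str = eqx.field(static=True)
    weights: jax.Array

w = Wrapper(eq_type="ODE", weights=jnp.ones(3))
# The same split used in PINN.__init__ above:
params, static = eqx.partition(w, eqx.is_inexact_array)
print(jax.tree_util.tree_leaves(params))  # -> [Array([1., 1., 1.], ...)]
```

Making the wrappers `eqx.Module`s is also what lets the new `_save_load` module below serialise them with `eqx.tree_serialise_leaves`.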
jinns/utils/_save_load.py ADDED
@@ -0,0 +1,175 @@
+ """
+ Implements save and load functions
+ """
+
+ import pickle
+ import jax
+ import equinox as eqx
+
+ from jinns.utils._pinn import create_PINN
+ from jinns.utils._spinn import create_SPINN
+ from jinns.utils._hyperpinn import create_HYPERPINN
+
+
+ def function_to_string(eqx_list):
+     """
+     We need this transformation for eqx_list to be pickled
+
+     From `[[eqx.nn.Linear, 2, 20],
+     [jax.nn.tanh],
+     [eqx.nn.Linear, 20, 20],
+     [jax.nn.tanh],
+     [eqx.nn.Linear, 20, 20],
+     [jax.nn.tanh],
+     [eqx.nn.Linear, 20, 1]` to
+     `[["Linear", 2, 20],
+     ["tanh"],
+     ["Linear", 20, 20],
+     ["tanh"],
+     ["Linear", 20, 20],
+     ["tanh"],
+     ["Linear", 20, 1]`
+     """
+     return jax.tree_util.tree_map(
+         lambda x: x.__name__ if hasattr(x, "__call__") else x, eqx_list
+     )
+
+
+ def string_to_function(eqx_list_with_string):
+     """
+     We need this transformation for eqx_list at the loading ("unpickling")
+     operation.
+
+     From `[["Linear", 2, 20],
+     ["tanh"],
+     ["Linear", 20, 20],
+     ["tanh"],
+     ["Linear", 20, 20],
+     ["tanh"],
+     ["Linear", 20, 1]`
+     to `[[eqx.nn.Linear, 2, 20],
+     [jax.nn.tanh],
+     [eqx.nn.Linear, 20, 20],
+     [jax.nn.tanh],
+     [eqx.nn.Linear, 20, 20],
+     [jax.nn.tanh],
+     [eqx.nn.Linear, 20, 1]` to
+     """
+
+     def _str_to_fun(l):
+         try:
+             try:
+                 try:
+                     return getattr(jax.nn, l)
+                 except AttributeError:
+                     return getattr(jax.numpy, l)
+             except AttributeError:
+                 return getattr(eqx.nn, l)
+         except AttributeError as exc:
+             raise ValueError(
+                 "Activation functions must be from jax.nn or jax.numpy,"
+                 + "or layers must be eqx.nn layers"
+             ) from exc
+
+     return jax.tree_util.tree_map(
+         lambda x: _str_to_fun(x) if isinstance(x, str) else x, eqx_list_with_string
+     )
+
+
+ def save_pinn(filename, u, params, kwargs_creation):
+     """
+     Save a PINN / HyperPINN / SPINN model
+     This function creates 3 files, beggining by `filename`
+
+     1. an eqx file to save the eqx.Module (the PINN, HyperPINN, ...)
+     2. a pickle file for the parameters
+     3. a pickle file for the arguments that have been used at PINN
+
+     creation and that we need to reconstruct the eqx.module later on.
+
+     Parameters
+     ----------
+     filename
+         Filename (prefix) without extension
+     u
+         The PINN
+     params
+         The dictionary of parameters of the model.
+         Typically, it is a dictionary of
+         dictionaries: `eq_params` and `nn_params`, respectively the
+         differential equation parameters and the neural network parameter
+     kwargs_creation
+         The dictionary of arguments that were used to create the PINN, e.g.
+         the layers list, O/PDE type, etc.
+     """
+     eqx.tree_serialise_leaves(filename + "-module.eqx", u)
+     with open(filename + "-parameters.pkl", "wb") as f:
+         pickle.dump(params, f)
+     kwargs_creation = kwargs_creation.copy()  # avoid side-effect that would be
+     # very probably harmless anyway
+
+     # we now need to transform the functions in eqx_list into strings otherwise
+     # it could not be pickled
+     kwargs_creation["eqx_list"] = function_to_string(kwargs_creation["eqx_list"])
+
+     # same thing if there is an hypernetwork:
+     try:
+         kwargs_creation["eqx_list_hyper"] = function_to_string(
+             kwargs_creation["eqx_list_hyper"]
+         )
+     except KeyError:
+         pass
+
+     with open(filename + "-arguments.pkl", "wb") as f:
+         pickle.dump(kwargs_creation, f)
+
+
+ def load_pinn(filename, type_):
+     """
+     Load a PINN model. This function needs to access 3 files :
+     `{filename}-module.eqx`, `{filename}-parameters.pkl` and
+     `{filename}-arguments.pkl`.
+
+     These files are created by `jinns.utils.save_pinn`.
+
+     Note that this requires equinox v0.11.3 (currently latest version) for the
+     `eqx.filter_eval_shape` to work.
+
+     Parameters
+     ----------
+     filename
+         Filename (prefix) without extension.
+     type_
+         Type of model to load. Must be in ["pinn", "hyperpinn", "spinn"].
+
+     Returns
+     -------
+     u_reloaded
+         The reloaded PINN
+     params_reloaded
+         The reloaded parameters
+     """
+     with open(filename + "-arguments.pkl", "rb") as f:
+         kwargs_reloaded = pickle.load(f)
+     with open(filename + "-parameters.pkl", "rb") as f:
+         params_reloaded = pickle.load(f)
+     kwargs_reloaded["eqx_list"] = string_to_function(kwargs_reloaded["eqx_list"])
+     if type_ == "pinn":
+         # next line creates a shallow model, the jax arrays are just shapes and
+         # not populated, this just recreates the correct pytree structure
+         u_reloaded_shallow = eqx.filter_eval_shape(create_PINN, **kwargs_reloaded)
+     elif type_ == "spinn":
+         u_reloaded_shallow = eqx.filter_eval_shape(create_SPINN, **kwargs_reloaded)
+     elif type_ == "hyperpinn":
+         kwargs_reloaded["eqx_list_hyper"] = string_to_function(
+             kwargs_reloaded["eqx_list_hyper"]
+         )
+         u_reloaded_shallow = eqx.filter_eval_shape(create_HYPERPINN, **kwargs_reloaded)
+     else:
+         raise ValueError(f"{type_} is not valid")
+     # now the empty structure is populated with the actual saved array values
+     # stored in the eqx file
+     u_reloaded = eqx.tree_deserialise_leaves(
+         filename + "-module.eqx", u_reloaded_shallow
+     )
+     return u_reloaded, params_reloaded
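A possible round trip with the new functions, assuming a `create_PINN` call like the ones documented above (the filename and argument values here are illustrative only):

```python
import jax
import equinox as eqx
from jinns.utils import create_PINN, save_pinn, load_pinn

eqx_list = [[eqx.nn.Linear, 1, 20], [jax.nn.tanh], [eqx.nn.Linear, 20, 1]]
kwargs_creation = {
    "key": jax.random.PRNGKey(0),
    "eqx_list": eqx_list,
    "eq_type": "ODE",
}
u = create_PINN(**kwargs_creation)
params = {"nn_params": u.init_params(), "eq_params": {}}

# writes my_pinn-module.eqx, my_pinn-parameters.pkl, my_pinn-arguments.pkl
save_pinn("my_pinn", u, params, kwargs_creation)
u_reloaded, params_reloaded = load_pinn("my_pinn", type_="pinn")
```

Note the design choice: the pickled arguments only describe *how to rebuild* the module (with layer classes replaced by their names so the list pickles cleanly), while the actual weights travel through equinox's leaf serialisation; `eqx.filter_eval_shape` rebuilds the pytree skeleton without allocating arrays before `tree_deserialise_leaves` fills it.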
jinns/utils/_spinn.py CHANGED
@@ -76,17 +76,23 @@ class _SPINN(eqx.Module):
          return jnp.asarray(outputs)


- class SPINN:
+ class SPINN(eqx.Module):
      """
      Basically a wrapper around the `__call__` function to be able to give a type to
      our former `self.u`
      The function create_SPINN has the role to population the `__call__` function
      """

-     def __init__(self, key, d, r, eqx_list, eq_type, m=1):
+     d: int
+     r: int
+     eq_type: str = eqx.field(static=True)
+     m: int
+     params: eqx.Module
+     static: eqx.Module
+
+     def __init__(self, spinn_mlp, d, r, eq_type, m):
          self.d, self.r, self.m = d, r, m
-         _spinn = _SPINN(key, d, r, eqx_list, m)
-         self.params, self.static = eqx.partition(_spinn, eqx.is_inexact_array)
+         self.params, self.static = eqx.partition(spinn_mlp, eqx.is_inexact_array)
          self.eq_type = eq_type

      def init_params(self):
@@ -229,6 +235,7 @@ def create_SPINN(key, d, r, eqx_list, eq_type, m=1):
              "Too many dimensions, not enough letters available in jnp.einsum"
          )

-     spinn = SPINN(key, d, r, eqx_list, eq_type, m)
+     spinn_mlp = _SPINN(key, d, r, eqx_list, m)
+     spinn = SPINN(spinn_mlp, d, r, eq_type, m)

      return spinn
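For context, the separable architecture that `_SPINN` implements (Cho et al., NeurIPS 2023) evaluates one small MLP per input dimension and recombines the resulting r-dimensional embeddings with `jnp.einsum`, which is why `create_SPINN` can run out of einsum letters in high dimension. A schematic of the recombination for d = 2 (shapes are illustrative, not jinns code):

```python
import jax.numpy as jnp

r = 4                     # embedding dimension
f_t = jnp.ones((8, r))    # per-time-point embeddings from the first MLP
f_x = jnp.ones((16, r))   # per-space-point embeddings from the second MLP
# u(t_i, x_j) = sum_k f_t[i, k] * f_x[j, k], evaluated on the whole grid:
u_grid = jnp.einsum("ik,jk->ij", f_t, f_x)  # shape (8, 16)
```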
jinns/utils/_utils.py CHANGED
@@ -7,7 +7,6 @@ from operator import getitem
  import numpy as np
  import jax
  import jax.numpy as jnp
- import optax


  def _check_nan_in_pytree(pytree):
jinns/utils/_utils_uspinn.py ADDED
@@ -0,0 +1,727 @@
+ import numpy as np
+ import jax
+ import jax.numpy as jnp
+ import optax
+ import equinox as eqx
+ from functools import reduce
+ from operator import getitem
+
+
+ def _check_nan_in_pytree(pytree):
+     """
+     Check if there is a NaN value anywhere is the pytree
+
+     Parameters
+     ----------
+     pytree
+         A pytree
+
+     Returns
+     -------
+     res
+         A boolean. True if any of the pytree content is NaN
+     """
+     return jnp.any(
+         jnp.array(
+             [
+                 value
+                 for value in jax.tree_util.tree_leaves(
+                     jax.tree_util.tree_map(lambda x: jnp.any(jnp.isnan(x)), pytree)
+                 )
+             ]
+         )
+     )
+
+
+ def _tracked_parameters(params, tracked_params_key_list):
+     """
+     Returns a pytree with the same structure as params with True is the
+     parameter is tracked False otherwise
+     """
+
+     def set_nested_item(dataDict, mapList, val):
+         """
+         Set item in nested dictionary
+         https://stackoverflow.com/questions/54137991/how-to-update-values-in-nested-dictionary-if-keys-are-in-a-list
+         """
+         reduce(getitem, mapList[:-1], dataDict)[mapList[-1]] = val
+         return dataDict
+
+     tracked_params = jax.tree_util.tree_map(
+         lambda x: False, params
+     )  # init with all False
+
+     for key_list in tracked_params_key_list:
+         tracked_params = set_nested_item(tracked_params, key_list, True)
+
+     return tracked_params
+
+
+ class _MLP(eqx.Module):
+     """
+     Class to construct an equinox module from a key and a eqx_list. To be used
+     in pair with the function `create_PINN`
+     """
+
+     layers: list
+
+     def __init__(self, key, eqx_list):
+         """
+         Parameters
+         ----------
+         key
+             A jax random key
+         eqx_list
+             A list of list of successive equinox modules and activation functions to
+             describe the PINN architecture. The inner lists have the eqx module or
+             axtivation function as first item, other items represents arguments
+             that could be required (eg. the size of the layer).
+             __Note:__ the `key` argument need not be given.
+             Thus typical example is `eqx_list=
+             [[eqx.nn.Linear, 2, 20],
+             [jax.nn.tanh],
+             [eqx.nn.Linear, 20, 20],
+             [jax.nn.tanh],
+             [eqx.nn.Linear, 20, 20],
+             [jax.nn.tanh],
+             [eqx.nn.Linear, 20, 1]
+             ]`
+         """
+
+         self.layers = []
+         # TODO we are limited currently in the number of layer type we can
+         # parse and we lack some safety checks
+         for l in eqx_list:
+             if len(l) == 1:
+                 self.layers.append(l[0])
+             else:
+                 # By default we append a random key at the end of the
+                 # arguments fed into a layer module call
+                 key, subkey = jax.random.split(key, 2)
+                 # the argument key is keyword only
+                 self.layers.append(l[0](*l[1:], key=subkey))
+
+     def __call__(self, t):
+         for layer in self.layers:
+             t = layer(t)
+         return t
+
+
+ class PINN:
+     """
+     Basically a wrapper around the `__call__` function to be able to give a type to
+     our former `self.u`
+     The function create_PINN has the role to population the `__call__` function
+     """
+
+     def __init__(self, key, eqx_list, output_slice=None):
+         _pinn = _MLP(key, eqx_list)
+         self.params, self.static = eqx.partition(_pinn, eqx.is_inexact_array)
+         self.output_slice = output_slice
+
+     def init_params(self):
+         return self.params
+
+     def __call__(self, *args, **kwargs):
+         return self.apply_fn(self, *args, **kwargs)
+
+
+ def create_PINN(
+     key,
+     eqx_list,
+     eq_type,
+     dim_x=0,
+     with_eq_params=None,
+     input_transform=None,
+     output_transform=None,
+     shared_pinn_outputs=None,
+ ):
+     """
+     Utility function to create a standard PINN neural network with the equinox
+     library.
+
+     Parameters
+     ----------
+     key
+         A jax random key that will be used to initialize the network parameters
+     eqx_list
+         A list of list of successive equinox modules and activation functions to
+         describe the PINN architecture. The inner lists have the eqx module or
+         axtivation function as first item, other items represents arguments
+         that could be required (eg. the size of the layer).
+         __Note:__ the `key` argument need not be given.
+         Thus typical example is `eqx_list=
+         [[eqx.nn.Linear, 2, 20],
+         [jax.nn.tanh],
+         [eqx.nn.Linear, 20, 20],
+         [jax.nn.tanh],
+         [eqx.nn.Linear, 20, 20],
+         [jax.nn.tanh],
+         [eqx.nn.Linear, 20, 1]
+         ]`
+     eq_type
+         A string with three possibilities.
+         "ODE": the PINN is called with one input `t`.
+         "statio_PDE": the PINN is called with one input `x`, `x`
+         can be high dimensional.
+         "nonstatio_PDE": the PINN is called with two inputs `t` and `x`, `x`
+         can be high dimensional.
+         **Note: the input dimension as given in eqx_list has to match the sum
+         of the dimension of `t` + the dimension of `x` + the number of
+         parameters in `eq_params` if with_eq_params is `True` (see below)**
+     dim_x
+         An integer. The dimension of `x`. Default `0`
+     with_eq_params
+         Default is None. Otherwise a list of keys from the dict `eq_params`
+         that the network also takes as inputs.
+         the equation parameters (`eq_params`).
+         **If some keys are provided, the input dimension
+         as given in eqx_list must take into account the number of such provided
+         keys (i.e., the input dimension is the addition of the dimension of ``t``
+         + the dimension of ``x`` + the number of ``eq_params``)**
+     input_transform
+         A function that will be called before entering the PINN. Its output(s)
+         must mathc the PINN inputs.
+     output_transform
+         A function with arguments the same input(s) as the PINN AND the PINN
+         output that will be called after exiting the PINN
+     shared_pinn_outputs
+         A tuple of jnp.s_[] (slices) to determine the different output for each
+         network. In this case we return a list of PINNs, one for each output in
+         shared_pinn_outputs. This is useful to create PINNs that share the
+         same network and same parameters. Default is None, we only return one PINN.
+
+
+     Returns
+     -------
+     init_fn
+         A function which (re-)initializes the PINN parameters with the provided
+         jax random key
+     apply_fn
+         A function to apply the neural network on given inputs for given
+         parameters. A typical call will be of the form `u(t, nn_params)` for
+         ODE or `u(t, x, nn_params)` for nD PDEs (`x` being multidimensional)
+         or even `u(t, x, nn_params, eq_params)` if with_eq_params is `True`
+
+     Raises
+     ------
+     RuntimeError
+         If the parameter value for eq_type is not in `["ODE", "statio_PDE",
+         "nonstatio_PDE"]`
+     RuntimeError
+         If we have a `dim_x > 0` and `eq_type == "ODE"`
+         or if we have a `dim_x = 0` and `eq_type != "ODE"`
+     """
+     if eq_type not in ["ODE", "statio_PDE", "nonstatio_PDE"]:
+         raise RuntimeError("Wrong parameter value for eq_type")
+
+     if eq_type == "ODE" and dim_x != 0:
+         raise RuntimeError("Wrong parameter combination eq_type and dim_x")
+
+     if eq_type != "ODE" and dim_x == 0:
+         raise RuntimeError("Wrong parameter combination eq_type and dim_x")
+
+     dim_t = 0 if eq_type == "statio_PDE" else 1
+     dim_in_params = len(with_eq_params) if with_eq_params is not None else 0
+     try:
+         nb_inputs_declared = eqx_list[0][1]  # normally we look for 2nd ele of 1st layer
+     except IndexError:
+         nb_inputs_declared = eqx_list[1][1]
+         # but we can have, eg, a flatten first layer
+
+     try:
+         nb_outputs_declared = eqx_list[-1][2]  # normally we look for 3rd ele of
+         # last layer
+     except IndexError:
+         nb_outputs_declared = eqx_list[-2][2]
+         # but we can have, eg, a `jnp.exp` last layer
+
+     # NOTE Currently the check below is disabled because we added
+     # input_transform
+     # if dim_t + dim_x + dim_in_params != nb_inputs_declared:
+     #     raise RuntimeError("Error in the declarations of the number of parameters")
+
+     if eq_type == "ODE":
+         if with_eq_params is None:
+
+             def apply_fn(self, t, u_params, eq_params=None):
+                 model = eqx.combine(u_params, self.static)
+                 t = t[
+                     None
+                 ]  # Note that we added a dimension to t which is lacking for the ODE batches
+                 if output_transform is None:
+                     if input_transform is not None:
+                         res = model(input_transform(t)).squeeze()
+                     else:
+                         res = model(t).squeeze()
+                 else:
+                     if input_transform is not None:
+                         res = output_transform(t, model(input_transform(t)).squeeze())
+                     else:
+                         res = output_transform(t, model(t).squeeze())
+                 if self.output_slice is not None:
+                     return res[self.output_slice]
+                 else:
+                     return res
+
+         else:
+
+             def apply_fn(self, t, u_params, eq_params):
+                 model = eqx.combine(u_params, self.static)
+                 t = t[
+                     None
+                 ]  # We added a dimension to t which is lacking for the ODE batches
+                 eq_params_flatten = jnp.concatenate(
+                     [e.ravel() for k, e in eq_params.items() if k in with_eq_params]
+                 )
+                 t_eq_params = jnp.concatenate([t, eq_params_flatten], axis=-1)
+
+                 if output_transform is None:
+                     if input_transform is not None:
+                         res = model(input_transform(t_eq_params)).squeeze()
+                     else:
+                         res = model(t_eq_params).squeeze()
+                 else:
+                     if input_transform is not None:
+                         res = output_transform(
+                             t_eq_params,
+                             model(input_transform(t_eq_params)).squeeze(),
+                         )
+                     else:
+                         res = output_transform(
+                             t_eq_params, model(t_eq_params).squeeze()
+                         )
+
+                 if self.output_slice is not None:
+                     return res[self.output_slice]
+                 else:
+                     return res
+
+     elif eq_type == "statio_PDE":
+         # Here we add an argument `x` which can be high dimensional
+         if with_eq_params is None:
+
+             def apply_fn(self, x, u_params, eq_params=None):
+                 model = eqx.combine(u_params, self.static)
+
+                 if output_transform is None:
+                     if input_transform is not None:
+                         res = model(input_transform(x)).squeeze()
+                     else:
+                         res = model(x).squeeze()
+                 else:
+                     if input_transform is not None:
+                         res = output_transform(x, model(input_transform(x)).squeeze())
+                     else:
+                         res = output_transform(x, model(x).squeeze()).squeeze()
+
+                 if self.output_slice is not None:
+                     res = res[self.output_slice]
+
+                 # force (1,) output for non vectorial solution (consistency)
+                 if not res.shape:
+                     return jnp.expand_dims(res, axis=-1)
+                 else:
+                     return res
+
+         else:
+
+             def apply_fn(self, x, u_params, eq_params):
+                 model = eqx.combine(u_params, self.static)
+                 eq_params_flatten = jnp.concatenate(
+                     [e.ravel() for k, e in eq_params.items() if k in with_eq_params]
+                 )
+                 x_eq_params = jnp.concatenate([x, eq_params_flatten], axis=-1)
+
+                 if output_transform is None:
+                     if input_transform is not None:
+                         res = model(input_transform(x_eq_params)).squeeze()
+                     else:
+                         res = model(x_eq_params).squeeze()
+                 else:
+                     if input_transform is not None:
+                         res = output_transform(
+                             x_eq_params,
+                             model(input_transform(x_eq_params)).squeeze(),
+                         )
+                     else:
+                         res = output_transform(
+                             x_eq_params, model(x_eq_params).squeeze()
+                         )
+
+                 if self.output_slice is not None:
+                     res = res[self.output_slice]
+
+                 # force (1,) output for non vectorial solution (consistency)
+                 if not res.shape:
+                     return jnp.expand_dims(res, axis=-1)
+                 else:
+                     return res
+
+     elif eq_type == "nonstatio_PDE":
+         # Here we add an argument `x` which can be high dimensional
+         if with_eq_params is None:
+
+             def apply_fn(self, t, x, u_params, eq_params=None):
+                 model = eqx.combine(u_params, self.static)
+                 t_x = jnp.concatenate([t, x], axis=-1)
+
+                 if output_transform is None:
+                     if input_transform is not None:
+                         res = model(input_transform(t_x)).squeeze()
+                     else:
+                         res = model(t_x).squeeze()
+                 else:
+                     if input_transform is not None:
+                         res = output_transform(
+                             t_x, model(input_transform(t_x)).squeeze()
+                         )
+                     else:
+                         res = output_transform(t_x, model(t_x).squeeze())
+
+                 if self.output_slice is not None:
+                     res = res[self.output_slice]
+
+                 ## force (1,) output for non vectorial solution (consistency)
+                 if not res.shape:
+                     return jnp.expand_dims(res, axis=-1)
+                 else:
+                     return res
+
+         else:
+
+             def apply_fn(self, t, x, u_params, eq_params):
+                 model = eqx.combine(u_params, self.static)
+                 t_x = jnp.concatenate([t, x], axis=-1)
+                 eq_params_flatten = jnp.concatenate(
+                     [e.ravel() for k, e in eq_params.items() if k in with_eq_params]
+                 )
+                 t_x_eq_params = jnp.concatenate([t_x, eq_params_flatten], axis=-1)
+
+                 if output_transform is None:
+                     if input_transform is not None:
+                         res = model(input_transform(t_x_eq_params)).squeeze()
+                     else:
+                         res = model(t_x_eq_params).squeeze()
+                 else:
+                     if input_transform is not None:
+                         res = output_transform(
+                             t_x_eq_params,
+                             model(input_transform(t_x_eq_params)).squeeze(),
+                         )
+                     else:
+                         res = output_transform(
+                             t_x_eq_params,
+                             model(input_transform(t_x_eq_params)).squeeze(),
+                         )
+
+                 if self.output_slice is not None:
+                     res = res[self.output_slice]
+
+                 # force (1,) output for non vectorial solution (consistency)
+                 if not res.shape:
+                     return jnp.expand_dims(res, axis=-1)
+                 else:
+                     return res
+
+     else:
+         raise RuntimeError("Wrong parameter value for eq_type")
+
+     if shared_pinn_outputs is not None:
+         pinns = []
+         static = None
+         for output_slice in shared_pinn_outputs:
+             pinn = PINN(key, eqx_list, output_slice)
+             pinn.apply_fn = apply_fn
+             # all the pinns are in fact the same so we share the same static
+             if static is None:
+                 static = pinn.static
+             else:
+                 pinn.static = static
+             pinns.append(pinn)
+         return pinns
+     else:
+         pinn = PINN(key, eqx_list)
+         pinn.apply_fn = apply_fn
+         return pinn
+
+
+ class _SPINN(eqx.Module):
+     """
+     Construct a Separable PINN as proposed in
+     Cho et al., _Separable Physics-Informed Neural Networks_, NeurIPS, 2023
+     """
+
+     layers: list
+     separated_mlp: list
+     d: int
+     r: int
+     m: int
+
+     def __init__(self, key, d, r, eqx_list, m=1):
+         """
+         Parameters
+         ----------
+         key
+             A jax random key
+         d
+             An integer. The number of dimensions to treat separately
+         r
+             An integer. The dimension of the embedding
+         eqx_list
+             A list of list of successive equinox modules and activation functions to
+             describe *each separable PINN architecture*.
+             The inner lists have the eqx module or
+             axtivation function as first item, other items represents arguments
+             that could be required (eg. the size of the layer).
+             __Note:__ the `key` argument need not be given.
+             Thus typical example is `eqx_list=
+             [[eqx.nn.Linear, d, 20],
+             [jax.nn.tanh],
+             [eqx.nn.Linear, 20, 20],
+             [jax.nn.tanh],
+             [eqx.nn.Linear, 20, 20],
+             [jax.nn.tanh],
+             [eqx.nn.Linear, 20, r]
+             ]`
+         """
+         keys = jax.random.split(key, 8)
+
+         self.d = d
+         self.r = r
+         self.m = m
+
+         self.separated_mlp = []
+         for d in range(self.d):
+             self.layers = []
+             for l in eqx_list:
+                 if len(l) == 1:
+                     self.layers.append(l[0])
+                 else:
+                     key, subkey = jax.random.split(key, 2)
+                     self.layers.append(l[0](*l[1:], key=subkey))
+             self.separated_mlp.append(self.layers)
+
+     def __call__(self, t, x):
+         if t is not None:
+             dimensions = jnp.concatenate([t, x.flatten()], axis=0)
+         else:
+             dimensions = jnp.concatenate([x.flatten()], axis=0)
+         outputs = []
+         for d in range(self.d):
+             t_ = dimensions[d][None]
+             for layer in self.separated_mlp[d]:
+                 t_ = layer(t_)
+             outputs += [t_]
+         return jnp.asarray(outputs)
+
+
+ def _get_grid(in_array):
+     """
+     From an array of shape (B, D), D > 1, get the grid array, i.e., an array of
+     shape (B, B, ...(D times)..., B, D): along the last axis we have the array
+     of values
+     """
+     if in_array.shape[-1] > 1 or in_array.ndim > 1:
+         return jnp.stack(
+             jnp.meshgrid(
+                 *(in_array[..., d] for d in range(in_array.shape[-1])), indexing="ij"
+             ),
+             axis=-1,
+         )
+     else:
+         return in_array
+
+
+ def _get_vmap_in_axes_params(eq_params_batch_dict, params):
+     """
+     Return the input vmap axes when there is batch(es) of parameters to vmap
+     over. The latter are designated by keys in eq_params_batch_dict
+     If eq_params_batch_dict (ie no additional parameter batch), we return None
+     """
+     if eq_params_batch_dict is None:
+         return (None,)
+     else:
+         # We use pytree indexing of vmapped axes and vmap on axis
+         # 0 of the eq_parameters for which we have a batch
+         # this is for a fine-grained vmaping
+         # scheme over the params
+         vmap_in_axes_params = (
+             {
+                 "eq_params": {
+                     k: (0 if k in eq_params_batch_dict.keys() else None)
+                     for k in params["eq_params"].keys()
+                 },
+                 "nn_params": None,
+             },
+         )
+         return vmap_in_axes_params
+
+
+ def _check_user_func_return(r, shape):
+     """
+     Correctly handles the result from a user defined function (eg a boundary
+     condition) to get the correct broadcast
+     """
+     if isinstance(r, int) or isinstance(r, float):
+         # if we have a scalar cast it to float
+         return float(r)
+     if r.shape == () or len(r.shape) == 1:
+         # if we have a scalar (or a vector, but no batch dim) inside an array
+         return r.astype(float)
+     else:
+         # if we have an array of the shape of the batch dimension(s) check that
+         # we have the correct broadcast
+         # the reshape below avoids a missing (1,) ending dimension
+         # depending on how the user has coded the inital function
+         return r.reshape(shape)
+
+
+ def alternate_optax_solver(
+     steps, parameters_set1, parameters_set2, lr_set1, lr_set2, label_fn=None
+ ):
+     """
+     This function creates an optax optimizer that alternates the optimization
+     between two set of parameters (ie. when some parameters are update to a
+     given learning rates, others are not updated (learning rate = 0)
+     The optimizers are scaled by adam parameters.
+
+     __Note:__ The alternating pattern relies on
+     `optax.piecewise_constant_schedule` which __multiplies__ learning rates of
+     previous steps (current included) to set the new learning rate. Hence, our
+     strategy used here is to relying on potentially cancelling power of tens to
+     create the alternating scheme.
+
+     Parameters
+     ----------
+     steps
+         An array which describes the epochis number at which we alternate the
+         optimization: the parameter_set that is being updated now stops
+         updating, the other parameter_set starts updating.
+         __Note:__ The step 0 should not be included
+     parameters_set1
+         A list of leaf level keys which must be found in the general `params` dict. The
+         parameters in this `set1` will be the parameters which are updated
+         first in the alternating scheme.
+     parameters_set2
+         A list of leaf level keys which must be found in the general `params` dict. The
+         parameters in this `set2` will be the parameters which are not updated
+         first in the alternating scheme.
+     lr_set1
+         A float. The learning rate of updates for set1.
+     lr_set2
+         A float. The learning rate of updates for set2.
+     label_fn
+         The same function as the label_fn function passed in an optax
+         `multi_transform`
+         [https://optax.readthedocs.io/en/latest/api.html#optax.multi_transform](see
+         here)
+         Default None, ie, we already internally provide the default one (as
+         proposed in the optax documentation) which may suit many use cases
+
+     Returns
+     -------
+     tx
+         The optax optimizer object
+     """
+
+     def map_nested_fn(fn):
+         """
+         Recursively apply `fn` to the key-value pairs of a nested dict
+         We follow the example from
+         https://optax.readthedocs.io/en/latest/api.html#optax.multi_transform
+         for different learning rates
+         """
+
+         def map_fn(nested_dict):
+             return {
+                 k: (map_fn(v) if isinstance(v, dict) else fn(k, v))
+                 for k, v in nested_dict.items()
+             }
+
+         return map_fn
+
+     label_fn = map_nested_fn(lambda k, _: k)
+
+     power_to_0 = 1e-25  # power of ten used to force a learning rate to 0
+     power_to_lr = 1 / power_to_0  # power of ten used to force a learning rate to lr
+     nn_params_scheduler = optax.piecewise_constant_schedule(
+         init_value=lr_set1,
+         boundaries_and_scales={
+             k: (
+                 power_to_0
+                 if even_odd % 2 == 0  # set lr to 0 eg if even_odd is even ie at
+                 # first step
+                 else power_to_lr
+             )
+             for even_odd, k in enumerate(steps)
+         },
+     )
+     eq_params_scheduler = optax.piecewise_constant_schedule(
+         init_value=power_to_0 * lr_set2,  # so normal learning rate is 1e-3
+         boundaries_and_scales={
+             k: (power_to_lr if even_odd % 2 == 0 else power_to_0)
+             for even_odd, k in enumerate(steps)
+         },
+     )
+
+     # the scheduler for set1 is called nn_chain because we usually start by
+     # updating the NN parameters
+     nn_chain = optax.chain(
+         optax.scale_by_adam(),
+         optax.scale_by_schedule(nn_params_scheduler),
+         optax.scale(-1.0),
+     )
+     eq_chain = optax.chain(
+         optax.scale_by_adam(),
+         optax.scale_by_schedule(eq_params_scheduler),
+         optax.scale(-1.0),
+     )
+     dict_params_set1 = {p: nn_chain for p in parameters_set1}
+     dict_params_set2 = {p: eq_chain for p in parameters_set2}
+     tx = optax.multi_transform(
+         {**dict_params_set1, **dict_params_set2},
+         label_fn,
+     )
+
+     return tx
+
+
+ def euler_maruyama_density(t, x, s, y, params, Tmax=1):
+     eps = 1e-6
+     delta = jnp.abs(t - s) * Tmax
+     mu = params["alpha_sde"] * (params["mu_sde"] - y) * delta
+     var = params["sigma_sde"] ** 2 * delta
+     return (
+         1 / jnp.sqrt(2 * jnp.pi * var) * jnp.exp(-0.5 * ((x - y) - mu) ** 2 / var) + eps
+     )
+
+
+ def log_euler_maruyama_density(t, x, s, y, params):
+     eps = 1e-6
+     delta = jnp.abs(t - s)
+     mu = params["alpha_sde"] * (params["mu_sde"] - y) * delta
+     logvar = params["logvar_sde"]
+     return (
+         -0.5
+         * (jnp.log(2 * jnp.pi * delta) + logvar + ((x - y) - mu) ** 2 / jnp.exp(logvar))
+         + eps
+     )
+
+
+ def euler_maruyama(x0, alpha, mu, sigma, T, N):
+     """
+     Simulate 1D diffusion process with simple parametrization using the Euler
+     Maruyama method in the interval [0, T]
+     """
+     path = [np.array([x0])]
+
+     time_steps, step_size = np.linspace(0, T, N, retstep=True)
+     for i in time_steps[1:]:
+         path.append(
+             path[-1]
+             + step_size * (alpha * (mu - path[-1]))
+             + sigma * np.random.normal(loc=0.0, scale=np.sqrt(step_size))
+         )
+
+     return time_steps, np.stack(path)
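The docstring of `alternate_optax_solver` explains the trick this file relies on: `optax.piecewise_constant_schedule` multiplies the running learning rate by the scale attached to each boundary, so alternating factors of 1e-25 and 1e25 toggle a parameter set's rate between (numerically) zero and its nominal value. A small standalone check of that behaviour (boundary steps chosen arbitrarily):

```python
import optax

sched = optax.piecewise_constant_schedule(
    init_value=1e-3,
    boundaries_and_scales={100: 1e-25, 200: 1e25},  # scales multiply the rate
)
print(sched(50))   # 1e-3: before the first boundary
print(sched(150))  # 1e-28: effectively frozen
print(sched(250))  # ~1e-3 again: the powers of ten cancel
```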
{jinns-0.8.0.dist-info → jinns-0.8.1.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: jinns
- Version: 0.8.0
+ Version: 0.8.1
  Summary: Physics Informed Neural Network with JAX
  Author-email: Hugo Gangloff <hugo.gangloff@inrae.fr>, Nicolas Jouvin <nicolas.jouvin@inrae.fr>
  Maintainer-email: Hugo Gangloff <hugo.gangloff@inrae.fr>, Nicolas Jouvin <nicolas.jouvin@inrae.fr>
{jinns-0.8.0.dist-info → jinns-0.8.1.dist-info}/RECORD RENAMED
@@ -16,14 +16,16 @@ jinns/solver/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  jinns/solver/_rar.py,sha256=K-0y1-ofOAo1n_Ea3QShSGCGKVYTwiaE_Bz9-DZMJm8,14525
  jinns/solver/_seq2seq.py,sha256=FL-42hTgmVl7O3hHh1ccFVw2bT8bW82hvlDRz971Chk,5620
  jinns/solver/_solve.py,sha256=6kWFWpJ33uOUzZKn7gIOM7yQsVZUwSuorOWPojVeMQY,13721
- jinns/utils/__init__.py,sha256=RxWVc4GUHkCRVtATwF-tJZvEF9VqU0t6I6vvFI8pvzY,268
- jinns/utils/_hyperpinn.py,sha256=oPQrzqDhZK1plukRFPgNPOeBWTqrezrmgKFTwI1I0eU,11007
+ jinns/utils/__init__.py,sha256=44ms5UR6vMw3Nf6u4RCAzPFs4fom_YbBnH9mfne8m6k,313
+ jinns/utils/_hyperpinn.py,sha256=nuy_V6qIXNxvLKvQAY6aZ_PLVroiiXQZ7RkHlF3GG60,11320
  jinns/utils/_optim.py,sha256=550kxH75TL30o1iKx1swJyP0KqyUPsJ7-imL1w65Qd0,4444
- jinns/utils/_pinn.py,sha256=U2cgxFG0Eaa94xe_Niifc_l6JXg0Ft5jTTs5nvS_258,9764
- jinns/utils/_spinn.py,sha256=C0z3d19cpj3sHPfUX8bv2GwtAgGfRR7m89LAI4j8bec,7932
- jinns/utils/_utils.py,sha256=8U0AcczbMbSEUS1J__06zI1B3TWfZhx7E0BDnHWPcPQ,6753
- jinns-0.8.0.dist-info/LICENSE,sha256=BIAkGtXB59Q_BG8f6_OqtQ1BHPv60ggE9mpXJYz2dRM,11337
- jinns-0.8.0.dist-info/METADATA,sha256=QmD157p0PpAFV7OHm7tznSCUFVReSis-VamGhWbi8qQ,2482
- jinns-0.8.0.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
- jinns-0.8.0.dist-info/top_level.txt,sha256=RXbkr2hzy8WBE8aiRyrJYFqn3JeMJIhMdybLjjLTB9c,6
- jinns-0.8.0.dist-info/RECORD,,
+ jinns/utils/_pinn.py,sha256=N8LuB9Ql472O01USghkJkEOmx67DTjc279T8Lj-Lwd4,9722
+ jinns/utils/_save_load.py,sha256=qgZ23nUcB8-B5IZ2guuUWC4M7r5Lxd_Ms3staScdyJo,5668
+ jinns/utils/_spinn.py,sha256=aeIC3DBY7f_N8HABjvBNv375dMyjll3zt6KjY2bEIkM,8058
+ jinns/utils/_utils.py,sha256=8dgvWXX9NT7_7-zltWp0C9tG45ZFNwXxueyxPBb4hjo,6740
+ jinns/utils/_utils_uspinn.py,sha256=qcKcOw3zrwWSQyGVj6fD8c9GinHt_U6JWN_k0auTtXM,26039
+ jinns-0.8.1.dist-info/LICENSE,sha256=BIAkGtXB59Q_BG8f6_OqtQ1BHPv60ggE9mpXJYz2dRM,11337
+ jinns-0.8.1.dist-info/METADATA,sha256=AKJ921rioUOvCh8PRFt-_KbvuUMPt73i-i9QV2Orrrg,2482
+ jinns-0.8.1.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+ jinns-0.8.1.dist-info/top_level.txt,sha256=RXbkr2hzy8WBE8aiRyrJYFqn3JeMJIhMdybLjjLTB9c,6
+ jinns-0.8.1.dist-info/RECORD,,
File without changes
File without changes