jinns 0.8.7__py3-none-any.whl → 0.8.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
jinns/__init__.py CHANGED
@@ -2,4 +2,5 @@ import jinns.data
2
2
  import jinns.loss
3
3
  import jinns.solver
4
4
  import jinns.utils
5
+ import jinns.experimental
5
6
  from jinns.solver._solve import solve
jinns/data/_display.py CHANGED
@@ -1,8 +1,13 @@
1
+ """
2
+ Utility functions for plotting
3
+ """
4
+
5
+ from functools import partial
6
+ import warnings
1
7
  import matplotlib.pyplot as plt
2
8
  import jax.numpy as jnp
3
9
  from jax import vmap
4
10
  from mpl_toolkits.axes_grid1 import ImageGrid
5
- from functools import partial
6
11
 
7
12
 
8
13
  def plot2d(
@@ -14,8 +19,10 @@ def plot2d(
14
19
  figsize=(7, 7),
15
20
  cmap="inferno",
16
21
  spinn=False,
22
+ vmin_vmax=None,
23
+ ax_for_plot=None,
17
24
  ):
18
- """Generic function for plotting functions over rectangular 2-D domains
25
+ r"""Generic function for plotting functions over rectangular 2-D domains
19
26
  :math:`\Omega`. It treats both the stationary case :math:`u(x)` or the
20
27
  non-stationnary case :math:`u(t, x)`.
21
28
 
@@ -40,6 +47,12 @@ def plot2d(
40
47
  _description_, by default (7, 7)
41
48
  cmap : str, optional
42
49
  _description_, by default "inferno"
50
+ vmin_vmax : tuple, optional
51
+ The colorbar minimum and maximum value. Defaults None.
52
+ ax_for_plot : Matplotlib axis, optional
53
+ If None, jinns triggers the plotting. Otherwise this argument
54
+ corresponds to the axis which will host the plot. Default is None.
55
+ NOTE: that this argument will have an effect only if times is None.
43
56
 
44
57
  Raises
45
58
  ------
@@ -59,25 +72,49 @@ def plot2d(
59
72
  # Statio case : expect a function of one argument fun(x)
60
73
  if not spinn:
61
74
  v_fun = vmap(fun, 0, 0)
62
- _plot_2D_statio(
63
- v_fun, mesh, plot=True, colorbar=True, cmap=cmap, figsize=figsize
75
+ ret = _plot_2D_statio(
76
+ v_fun,
77
+ mesh,
78
+ plot=not ax_for_plot,
79
+ colorbar=True,
80
+ cmap=cmap,
81
+ figsize=figsize,
82
+ vmin_vmax=vmin_vmax,
64
83
  )
65
84
  elif spinn:
66
85
  values_grid = jnp.squeeze(
67
86
  fun(jnp.stack([xy_data[0][..., None], xy_data[1][..., None]], axis=1))
68
87
  )
69
- _plot_2D_statio(
88
+ ret = _plot_2D_statio(
70
89
  values_grid,
71
90
  mesh,
72
- plot=True,
91
+ plot=not ax_for_plot,
73
92
  colorbar=True,
74
93
  cmap=cmap,
75
94
  spinn=True,
76
95
  figsize=figsize,
96
+ vmin_vmax=vmin_vmax,
77
97
  )
78
- plt.title(title)
98
+ if not ax_for_plot:
99
+ plt.title(title)
100
+ else:
101
+ if vmin_vmax is not None:
102
+ im = ax_for_plot.pcolormesh(
103
+ mesh[0],
104
+ mesh[1],
105
+ ret[0],
106
+ cmap=cmap,
107
+ vmin=vmin_vmax[0],
108
+ vmax=vmin_vmax[1],
109
+ )
110
+ else:
111
+ im = ax_for_plot.pcolormesh(mesh[0], mesh[1], ret[0], cmap=cmap)
112
+ ax_for_plot.set_title(title)
113
+ ax_for_plot.cax.colorbar(im, format="%0.2f")
79
114
 
80
115
  else:
116
+ if ax_for_plot is not None:
117
+ warnings.warn("ax_for_plot is ignored. jinns will plot the figure")
81
118
  if not isinstance(times, list):
82
119
  try:
83
120
  times = times.tolist()
@@ -101,7 +138,12 @@ def plot2d(
101
138
  if not spinn:
102
139
  v_fun_at_t = vmap(lambda x: fun(t=jnp.array([t]), x=x), 0, 0)
103
140
  t_slice, _ = _plot_2D_statio(
104
- v_fun_at_t, mesh, plot=False, colorbar=False, cmap=None
141
+ v_fun_at_t,
142
+ mesh,
143
+ plot=False,
144
+ colorbar=False,
145
+ cmap=None,
146
+ vmin_vmax=vmin_vmax,
105
147
  )
106
148
  elif spinn:
107
149
  values_grid = jnp.squeeze(
@@ -113,15 +155,37 @@ def plot2d(
113
155
  )[0]
114
156
  )
115
157
  t_slice, _ = _plot_2D_statio(
116
- values_grid, mesh, plot=False, colorbar=True, spinn=True
158
+ values_grid,
159
+ mesh,
160
+ plot=False,
161
+ colorbar=True,
162
+ spinn=True,
163
+ vmin_vmax=vmin_vmax,
164
+ )
165
+ if vmin_vmax is not None:
166
+ im = ax.pcolormesh(
167
+ mesh[0],
168
+ mesh[1],
169
+ t_slice,
170
+ cmap=cmap,
171
+ vmin=vmin_vmax[0],
172
+ vmax=vmin_vmax[1],
117
173
  )
118
- im = ax.pcolormesh(mesh[0], mesh[1], t_slice, cmap=cmap)
174
+ else:
175
+ im = ax.pcolormesh(mesh[0], mesh[1], t_slice, cmap=cmap)
119
176
  ax.set_title(f"t = {times[idx] * Tmax:.2f}")
120
177
  ax.cax.colorbar(im, format="%0.2f")
121
178
 
122
179
 
123
180
  def _plot_2D_statio(
124
- v_fun, mesh, plot=True, colorbar=True, cmap="inferno", figsize=(7, 7), spinn=False
181
+ v_fun,
182
+ mesh,
183
+ plot=True,
184
+ colorbar=True,
185
+ cmap="inferno",
186
+ figsize=(7, 7),
187
+ spinn=False,
188
+ vmin_vmax=None,
125
189
  ):
126
190
  """Function that plot the function u(x) with 2-D input x using pcolormesh()
127
191
 
@@ -136,6 +200,8 @@ def _plot_2D_statio(
136
200
  either show or return the plot, by default True
137
201
  colorbar : bool, optional
138
202
  add a colorbar, by default True
203
+ vmin_vmax: tuple, optional
204
+ The colorbar minimum and maximum value. Defaults None.
139
205
 
140
206
  Returns
141
207
  -------
@@ -153,7 +219,17 @@ def _plot_2D_statio(
153
219
 
154
220
  if plot:
155
221
  fig = plt.figure(figsize=figsize)
156
- im = plt.pcolormesh(x_grid, y_grid, values_grid, cmap=cmap)
222
+ if vmin_vmax is not None:
223
+ im = plt.pcolormesh(
224
+ x_grid,
225
+ y_grid,
226
+ values_grid,
227
+ cmap=cmap,
228
+ vmin=vmin_vmax[0],
229
+ vmax=vmin_vmax[1],
230
+ )
231
+ else:
232
+ im = plt.pcolormesh(x_grid, y_grid, values_grid, cmap=cmap)
157
233
  if colorbar:
158
234
  fig.colorbar(im, format="%0.2f")
159
235
  # don't plt.show() because it is done in plot2d()
@@ -217,6 +293,7 @@ def plot1d_image(
217
293
  colorbar=True,
218
294
  cmap="inferno",
219
295
  spinn=False,
296
+ vmin_vmax=None,
220
297
  ):
221
298
  """Function for plotting the 2-D image of a function :math:`f(t, x)` where
222
299
  `t` is time (1-D) and x is space (1-D).
@@ -237,6 +314,8 @@ def plot1d_image(
237
314
  , by default ""
238
315
  figsize : tuple, optional
239
316
  , by default (10, 10)
317
+ vmin_vmax: tuple
318
+ The colorbar minimum and maximum value. Defaults None.
240
319
  """
241
320
 
242
321
  mesh = jnp.meshgrid(times, xdata) # cartesian product
@@ -250,7 +329,17 @@ def plot1d_image(
250
329
  elif spinn:
251
330
  values_grid = jnp.squeeze(fun((times[..., None]), xdata[..., None]).T)
252
331
  fig, ax = plt.subplots(1, 1, figsize=figsize)
253
- im = ax.pcolormesh(mesh[0] * Tmax, mesh[1], values_grid, cmap=cmap)
332
+ if vmin_vmax is not None:
333
+ im = ax.pcolormesh(
334
+ mesh[0] * Tmax,
335
+ mesh[1],
336
+ values_grid,
337
+ cmap=cmap,
338
+ vmin=vmin_vmax[0],
339
+ vmax=vmin_vmax[1],
340
+ )
341
+ else:
342
+ im = ax.pcolormesh(mesh[0] * Tmax, mesh[1], values_grid, cmap=cmap)
254
343
  if colorbar:
255
344
  fig.colorbar(im, format="%0.2f")
256
345
  ax.set_title(title)
@@ -6,3 +6,5 @@ from ._diffrax_solver import (
6
6
  neumann_boundary_condition,
7
7
  plot_diffrax_solution,
8
8
  )
9
+ from ._sinuspinn import create_sinusPINN
10
+ from ._spectralpinn import create_spectralPINN
@@ -0,0 +1,135 @@
1
+ import jax
2
+ import equinox as eqx
3
+ import jax.numpy as jnp
4
+ from jinns.utils._pinn import PINN
5
+
6
+
7
def almost_zero_init(
    weight: jax.Array, key: jax.random.PRNGKey, stddev: float = 1e-2
) -> jax.Array:
    """Draw a near-zero replacement for a 2-D weight matrix.

    Only the shape and dtype of ``weight`` are used; its values are ignored.
    The result holds i.i.d. samples of a centered normal distribution with
    standard deviation ``stddev``.

    Parameters
    ----------
    weight
        The 2-D weight matrix to replace (e.g. ``eqx.nn.Linear.weight``).
    key
        A JAX PRNG key.
    stddev
        Standard deviation of the normal draw. Defaults to ``1e-2`` (the
        value that used to be hard-coded).

    Returns
    -------
    jax.Array
        An array of shape ``weight.shape`` and dtype ``weight.dtype``.

    Raises
    ------
    ValueError
        If ``weight`` is not 2-D.
    """
    # Unpacking keeps the original 2-D contract (ValueError otherwise).
    out, in_ = weight.shape
    # Match the weight dtype so the replaced leaf does not change precision.
    return stddev * jax.random.normal(key, shape=(out, in_), dtype=weight.dtype)
11
+
12
+
13
class _SinusPINN(eqx.Module):
    """
    A specific PINN whose activation layers take a frequency argument; the
    frequencies are determined by another (auxiliary) network.

    In the main network, layers at even positions are called as ``layer(x)``
    and layers at odd positions as ``layer(x, freq)``. The ``k``-th
    frequency-consuming layer receives the ``k``-th entry of the auxiliary
    network's output, squared and clipped to ``[1e-4, 5]``.
    """

    layers_pinn: list
    layers_aux_nn: list

    def __init__(self, key, list_layers_pinn, list_layers_aux_nn):
        """
        Parameters
        ----------
        key
            A jax random key
        list_layers_pinn
            A list as eqx_list in jinns' PINN utility for the main PINN.
            Each entry is a tuple: a 1-tuple is used as is (e.g. an
            activation); otherwise the first element is called with the
            remaining elements and a ``key=`` keyword to build the layer.
        list_layers_aux_nn
            A list in the same format for the network which outputs the
            PINN's activation frequencies. Its parametrized layers must
            expose ``weight`` and ``bias`` attributes (e.g. ``eqx.nn.Linear``).
        """
        self.layers_pinn = []
        for spec in list_layers_pinn:
            if len(spec) == 1:
                self.layers_pinn.append(spec[0])
            else:
                key, subkey = jax.random.split(key, 2)
                self.layers_pinn.append(spec[0](*spec[1:], key=subkey))

        n_aux = len(list_layers_aux_nn)
        self.layers_aux_nn = []
        for idx, spec in enumerate(list_layers_aux_nn):
            if len(spec) == 1:
                self.layers_aux_nn.append(spec[0])
                continue
            key, subkey = jax.random.split(key, 2)
            layer = spec[0](*spec[1:], key=subkey)
            key, subkey = jax.random.split(key, 2)
            # Every parametrized layer of the frequency network starts with
            # almost-zero weights.
            layer = eqx.tree_at(
                lambda m: m.weight,
                layer,
                almost_zero_init(layer.weight, subkey),
            )
            if idx >= n_aux - 2:
                # One of the last two list entries: this covers the final
                # linear layer whether or not a single 1-tuple (e.g. an
                # activation) follows it. Bias set to 0.5.
                layer = eqx.tree_at(
                    lambda m: m.bias,
                    layer,
                    0.5 * jnp.ones(layer.bias.shape),
                )
            else:
                # Earlier layers: zero bias.
                layer = eqx.tree_at(
                    lambda m: m.bias,
                    layer,
                    jnp.zeros(layer.bias.shape),
                )
            self.layers_aux_nn.append(layer)

    def __call__(self, x):
        # forward pass in the network which determines the frequencies
        h = x
        for layer in self.layers_aux_nn:
            h = layer(h)
        # Square to make frequencies non-negative, then clip them. Bounds are
        # passed positionally because the ``a_min``/``a_max`` keywords are
        # deprecated in recent JAX versions.
        freq_list = jnp.clip(jnp.square(h), 1e-4, 5)

        # forward pass through the actual PINN
        h = x
        for idx, layer in enumerate(self.layers_pinn):
            if idx % 2 == 0:
                h = layer(h)
            else:
                # Currently: every two layers we have an activation requiring
                # a frequency; for odd idx, idx // 2 == (idx - 1) // 2.
                h = layer(h, freq_list[idx // 2])
        return h
102
+
103
+
104
class sinusPINN(PINN):
    """
    Expose :class:`_SinusPINN` through jinns' ``PINN`` interface.

    MUST inherit from PINN to pass all the checks.

    HOWEVER we do not bother with reimplementing anything: the parent is
    initialised with fixed arguments and the forward pass is delegated to
    the wrapped ``_SinusPINN``.
    """

    def __init__(self, key, list_layers_pinn, list_layers_aux_nn):
        """
        Parameters
        ----------
        key
            A jax random key
        list_layers_pinn
            Layer specification of the main PINN (see ``_SinusPINN``)
        list_layers_aux_nn
            Layer specification of the frequency network (see ``_SinusPINN``)
        """
        # NOTE(review): fixed arguments for PINN.__init__; their meaning is
        # defined in jinns.utils._pinn (not visible here) — confirm.
        super().__init__({}, jnp.s_[...], "statio_PDE", None, None, None)
        _, subkey = jax.random.split(key, 2)
        _pinn = _SinusPINN(subkey, list_layers_pinn, list_layers_aux_nn)

        # Trainable leaves are the inexact (floating-point) arrays; the rest
        # of the module is kept as the static part.
        self.params, self.static = eqx.partition(_pinn, eqx.is_inexact_array)

    def init_params(self):
        """Return the initial trainable parameters."""
        return self.params

    def __call__(self, x, params):
        """
        Evaluate the network at ``x``.

        ``params`` is either a mapping holding the network parameters under
        ``"nn_params"`` or the network parameters themselves. A 0-d output
        is returned with a trailing axis of size 1.
        """
        try:
            nn_params = params["nn_params"]
        except (KeyError, TypeError):  # give more flexibility
            nn_params = params
        model = eqx.combine(nn_params, self.static)
        res = model(x)
        if not res.shape:
            return jnp.expand_dims(res, axis=-1)
        return res
130
+
131
+
132
def create_sinusPINN(key, list_layers_pinn, list_layers_aux_nn):
    """
    Build a :class:`sinusPINN`.

    Parameters
    ----------
    key
        A jax random key used to initialise the layers.
    list_layers_pinn
        Layer specification of the main PINN (eqx_list format).
    list_layers_aux_nn
        Layer specification of the network producing the activation
        frequencies (eqx_list format).

    Returns
    -------
    sinusPINN
        The newly created network.
    """
    return sinusPINN(key, list_layers_pinn, list_layers_aux_nn)
@@ -0,0 +1,87 @@
1
+ import jax
2
+ import equinox as eqx
3
+ import jax.numpy as jnp
4
+ from jinns.utils._pinn import PINN
5
+
6
+
7
def almost_zero_init(
    weight: jax.Array, key: jax.random.PRNGKey, stddev: float = 1e-2
) -> jax.Array:
    """Draw a near-zero replacement for a 2-D weight matrix.

    Only the shape and dtype of ``weight`` are used; its values are ignored.
    The result holds i.i.d. samples of a centered normal distribution with
    standard deviation ``stddev``.

    Parameters
    ----------
    weight
        The 2-D weight matrix to replace (e.g. ``eqx.nn.Linear.weight``).
    key
        A JAX PRNG key.
    stddev
        Standard deviation of the normal draw. Defaults to ``1e-2`` (the
        value that used to be hard-coded).

    Returns
    -------
    jax.Array
        An array of shape ``weight.shape`` and dtype ``weight.dtype``.

    Raises
    ------
    ValueError
        If ``weight`` is not 2-D.
    """
    # Unpacking keeps the original 2-D contract (ValueError otherwise).
    out, in_ = weight.shape
    # Match the weight dtype so the replaced leaf does not change precision.
    return stddev * jax.random.normal(key, shape=(out, in_), dtype=weight.dtype)
11
+
12
+
13
class _SpectralPINN(eqx.Module):
    """
    A specific PINN whose architecture is similar to the spectral method for
    the simulation of a spatial field (Chilès and Delfiner, 2012): a single
    layer, intended to use a cos() activation, followed by a sum.

    The forward pass applies the layers in order and returns
    ``sqrt(2 / nbands) * sum(output)``.
    """

    layers_pinn: list
    nbands: int

    def __init__(self, key, list_layers_pinn, nbands):
        """
        Parameters
        ----------
        key
            A jax random key
        list_layers_pinn
            A list as eqx_list in jinns' PINN utility for the main PINN: a
            1-tuple entry is used as is, otherwise the first element is
            called with the remaining ones and a ``key=`` keyword.
        nbands
            Number of spectral bands (i.e., neurons in the single layer of
            the PINN); also sets the ``sqrt(2 / nbands)`` output scaling.
        """
        self.nbands = nbands
        built = []
        for spec in list_layers_pinn:
            if len(spec) == 1:
                built.append(spec[0])
                continue
            key, subkey = jax.random.split(key, 2)
            built.append(spec[0](*spec[1:], key=subkey))
        self.layers_pinn = built

    def __call__(self, x):
        # forward pass through the actual PINN
        out = x
        for layer in self.layers_pinn:
            out = layer(out)
        # spectral-method normalisation of the summed band contributions
        scale = jnp.sqrt(2 / self.nbands)
        return scale * jnp.sum(out)
48
+
49
+
50
class spectralPINN(PINN):
    """
    Expose :class:`_SpectralPINN` through jinns' ``PINN`` interface.

    MUST inherit from PINN to pass all the checks.

    HOWEVER we do not bother with reimplementing anything: the parent is
    initialised with fixed arguments and the forward pass is delegated to
    the wrapped ``_SpectralPINN``.
    """

    def __init__(self, key, list_layers_pinn, nbands):
        """
        Parameters
        ----------
        key
            A jax random key
        list_layers_pinn
            Layer specification of the PINN (see ``_SpectralPINN``)
        nbands
            Number of spectral bands (see ``_SpectralPINN``)
        """
        # NOTE(review): fixed arguments for PINN.__init__; their meaning is
        # defined in jinns.utils._pinn (not visible here) — confirm.
        super().__init__({}, jnp.s_[...], "statio_PDE", None, None, None)
        _, subkey = jax.random.split(key, 2)
        _pinn = _SpectralPINN(subkey, list_layers_pinn, nbands)

        # Trainable leaves are the inexact (floating-point) arrays; the rest
        # of the module (including the integer ``nbands``) stays static.
        self.params, self.static = eqx.partition(_pinn, eqx.is_inexact_array)

    def init_params(self):
        """Return the initial trainable parameters."""
        return self.params

    def __call__(self, x, params):
        """
        Evaluate the network at ``x``.

        ``params`` is either a mapping holding the network parameters under
        ``"nn_params"`` or the network parameters themselves. A 0-d output
        is returned with a trailing axis of size 1.
        """
        try:
            nn_params = params["nn_params"]
        except (KeyError, TypeError):  # give more flexibility
            nn_params = params
        model = eqx.combine(nn_params, self.static)
        res = model(x)
        if not res.shape:
            return jnp.expand_dims(res, axis=-1)
        return res
82
+
83
+
84
def create_spectralPINN(key, list_layers_pinn, nbands):
    """
    Build a :class:`spectralPINN`.

    Parameters
    ----------
    key
        A jax random key used to initialise the layers.
    list_layers_pinn
        Layer specification of the PINN (eqx_list format).
    nbands
        Number of spectral bands.

    Returns
    -------
    spectralPINN
        The newly created network.
    """
    return spectralPINN(key, list_layers_pinn, nbands)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: jinns
3
- Version: 0.8.7
3
+ Version: 0.8.8
4
4
  Summary: Physics Informed Neural Network with JAX
5
5
  Author-email: Hugo Gangloff <hugo.gangloff@inrae.fr>, Nicolas Jouvin <nicolas.jouvin@inrae.fr>
6
6
  Maintainer-email: Hugo Gangloff <hugo.gangloff@inrae.fr>, Nicolas Jouvin <nicolas.jouvin@inrae.fr>
@@ -1,9 +1,11 @@
1
- jinns/__init__.py,sha256=Nw5pdlmDhJwco3bXX3YttkeCF8czX_6m0poh8vu0lDQ,113
1
+ jinns/__init__.py,sha256=T2XlmLbYqcXTumPJL00cJ80W98We5LH8Yg_Lss_exl4,139
2
2
  jinns/data/_DataGenerators.py,sha256=N4-U4z3MG46UIzHCbKScv9Z7AN40w1wlLY_VsVNj2sI,62293
3
3
  jinns/data/__init__.py,sha256=yBOmoavSD-cABp4XcjQY1zsEVO0mDyIhi2MJ5WNp0l8,326
4
- jinns/data/_display.py,sha256=6renz4H7kHktutmLY7HM6PmxYH7cBfGHpC7GQa1Fnlk,7778
5
- jinns/experimental/__init__.py,sha256=3jCIy2R2i_0Erwxg-HwISdH79Nt1XCXhS9yY1F5awiY,208
4
+ jinns/data/_display.py,sha256=vlqggDCgVMEwdGBtjVmZaTQORU6imSfDkssn2XCtITI,10392
5
+ jinns/experimental/__init__.py,sha256=qWbhC7Z8UgLWy0t-zU7RYze6v13-FngiCYXu-2bRVFQ,296
6
6
  jinns/experimental/_diffrax_solver.py,sha256=sLT22byqh-6015_fhe1xtMWlFOYcCjzYKET4sLhA9R4,6818
7
+ jinns/experimental/_sinuspinn.py,sha256=hxSzscwMV2LayWOqenIlT1zqEVVrE5Y8CKf7bHX5XFQ,5016
8
+ jinns/experimental/_spectralpinn.py,sha256=-4795pa7AYtRNSE-ugan3gHh64mtu2VdrRG5AS_J9Eg,2654
7
9
  jinns/loss/_DynamicLoss.py,sha256=L4CVmmF0rTPbHntgqsLLHlnrlQgLHsetUocpJm7ZYag,27461
8
10
  jinns/loss/_DynamicLossAbstract.py,sha256=kTQlhLx7SBuH5dIDmYaE79sVHUZt1nUFa8LxPU5IHhM,8504
9
11
  jinns/loss/_LossODE.py,sha256=b9doBHoQwYvlgpqzrNO4dOaTN87LRvjHtHbz9bMoH7E,22119
@@ -27,8 +29,8 @@ jinns/utils/_utils.py,sha256=8dgvWXX9NT7_7-zltWp0C9tG45ZFNwXxueyxPBb4hjo,6740
27
29
  jinns/utils/_utils_uspinn.py,sha256=qcKcOw3zrwWSQyGVj6fD8c9GinHt_U6JWN_k0auTtXM,26039
28
30
  jinns/validation/__init__.py,sha256=Jv58mzgC3F7cRfXA6caicL1t_U0UAhbwLrmMNVg6E7s,66
29
31
  jinns/validation/_validation.py,sha256=KfetbzB0xTNdBcYLwFWjEtP63Tf9wJirlhgqLTJDyy4,6761
30
- jinns-0.8.7.dist-info/LICENSE,sha256=BIAkGtXB59Q_BG8f6_OqtQ1BHPv60ggE9mpXJYz2dRM,11337
31
- jinns-0.8.7.dist-info/METADATA,sha256=L0P7JvMGKrJHx9OjrtFsmNKEwdKA_RlufAbOBf5l10I,2482
32
- jinns-0.8.7.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
33
- jinns-0.8.7.dist-info/top_level.txt,sha256=RXbkr2hzy8WBE8aiRyrJYFqn3JeMJIhMdybLjjLTB9c,6
34
- jinns-0.8.7.dist-info/RECORD,,
32
+ jinns-0.8.8.dist-info/LICENSE,sha256=BIAkGtXB59Q_BG8f6_OqtQ1BHPv60ggE9mpXJYz2dRM,11337
33
+ jinns-0.8.8.dist-info/METADATA,sha256=oTs2EJMu4Bwo2n9DLsAPSU5edpbgPtwhNXBuW8YjpOc,2482
34
+ jinns-0.8.8.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
35
+ jinns-0.8.8.dist-info/top_level.txt,sha256=RXbkr2hzy8WBE8aiRyrJYFqn3JeMJIhMdybLjjLTB9c,6
36
+ jinns-0.8.8.dist-info/RECORD,,
File without changes
File without changes