jinns 0.8.6-py3-none-any.whl → 0.8.8-py3-none-any.whl

This diff shows the content of publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
jinns/__init__.py CHANGED
@@ -2,4 +2,5 @@ import jinns.data
  import jinns.loss
  import jinns.solver
  import jinns.utils
+ import jinns.experimental
  from jinns.solver._solve import solve
jinns/data/_display.py CHANGED
@@ -1,8 +1,13 @@
+ """
+ Utility functions for plotting
+ """
+
+ from functools import partial
+ import warnings
  import matplotlib.pyplot as plt
  import jax.numpy as jnp
  from jax import vmap
  from mpl_toolkits.axes_grid1 import ImageGrid
- from functools import partial


  def plot2d(
@@ -14,8 +19,10 @@ def plot2d(
      figsize=(7, 7),
      cmap="inferno",
      spinn=False,
+     vmin_vmax=None,
+     ax_for_plot=None,
  ):
-     """Generic function for plotting functions over rectangular 2-D domains
+     r"""Generic function for plotting functions over rectangular 2-D domains
      :math:`\Omega`. It treats both the stationary case :math:`u(x)` or the
      non-stationnary case :math:`u(t, x)`.

@@ -40,6 +47,12 @@ def plot2d(
          _description_, by default (7, 7)
      cmap : str, optional
          _description_, by default "inferno"
+     vmin_vmax : tuple, optional
+         The colorbar minimum and maximum value. Defaults None.
+     ax_for_plot : Matplotlib axis, optional
+         If None, jinns triggers the plotting. Otherwise this argument
+         corresponds to the axis which will host the plot. Default is None.
+         NOTE: that this argument will have an effect only if times is None.

      Raises
      ------
@@ -59,25 +72,49 @@
          # Statio case : expect a function of one argument fun(x)
          if not spinn:
              v_fun = vmap(fun, 0, 0)
-             _plot_2D_statio(
-                 v_fun, mesh, plot=True, colorbar=True, cmap=cmap, figsize=figsize
+             ret = _plot_2D_statio(
+                 v_fun,
+                 mesh,
+                 plot=not ax_for_plot,
+                 colorbar=True,
+                 cmap=cmap,
+                 figsize=figsize,
+                 vmin_vmax=vmin_vmax,
              )
          elif spinn:
              values_grid = jnp.squeeze(
                  fun(jnp.stack([xy_data[0][..., None], xy_data[1][..., None]], axis=1))
              )
-             _plot_2D_statio(
+             ret = _plot_2D_statio(
                  values_grid,
                  mesh,
-                 plot=True,
+                 plot=not ax_for_plot,
                  colorbar=True,
                  cmap=cmap,
                  spinn=True,
                  figsize=figsize,
+                 vmin_vmax=vmin_vmax,
              )
-         plt.title(title)
+         if not ax_for_plot:
+             plt.title(title)
+         else:
+             if vmin_vmax is not None:
+                 im = ax_for_plot.pcolormesh(
+                     mesh[0],
+                     mesh[1],
+                     ret[0],
+                     cmap=cmap,
+                     vmin=vmin_vmax[0],
+                     vmax=vmin_vmax[1],
+                 )
+             else:
+                 im = ax_for_plot.pcolormesh(mesh[0], mesh[1], ret[0], cmap=cmap)
+             ax_for_plot.set_title(title)
+             ax_for_plot.cax.colorbar(im, format="%0.2f")

      else:
+         if ax_for_plot is not None:
+             warnings.warn("ax_for_plot is ignored. jinns will plot the figure")
          if not isinstance(times, list):
              try:
                  times = times.tolist()
@@ -101,7 +138,12 @@
              if not spinn:
                  v_fun_at_t = vmap(lambda x: fun(t=jnp.array([t]), x=x), 0, 0)
                  t_slice, _ = _plot_2D_statio(
-                     v_fun_at_t, mesh, plot=False, colorbar=False, cmap=None
+                     v_fun_at_t,
+                     mesh,
+                     plot=False,
+                     colorbar=False,
+                     cmap=None,
+                     vmin_vmax=vmin_vmax,
                  )
              elif spinn:
                  values_grid = jnp.squeeze(
@@ -113,15 +155,37 @@
                      )[0]
                  )
                  t_slice, _ = _plot_2D_statio(
-                     values_grid, mesh, plot=False, colorbar=True, spinn=True
+                     values_grid,
+                     mesh,
+                     plot=False,
+                     colorbar=True,
+                     spinn=True,
+                     vmin_vmax=vmin_vmax,
+                 )
+             if vmin_vmax is not None:
+                 im = ax.pcolormesh(
+                     mesh[0],
+                     mesh[1],
+                     t_slice,
+                     cmap=cmap,
+                     vmin=vmin_vmax[0],
+                     vmax=vmin_vmax[1],
                  )
-             im = ax.pcolormesh(mesh[0], mesh[1], t_slice, cmap=cmap)
+             else:
+                 im = ax.pcolormesh(mesh[0], mesh[1], t_slice, cmap=cmap)
              ax.set_title(f"t = {times[idx] * Tmax:.2f}")
              ax.cax.colorbar(im, format="%0.2f")


  def _plot_2D_statio(
-     v_fun, mesh, plot=True, colorbar=True, cmap="inferno", figsize=(7, 7), spinn=False
+     v_fun,
+     mesh,
+     plot=True,
+     colorbar=True,
+     cmap="inferno",
+     figsize=(7, 7),
+     spinn=False,
+     vmin_vmax=None,
  ):
      """Function that plot the function u(x) with 2-D input x using pcolormesh()

@@ -136,6 +200,8 @@ def _plot_2D_statio(
          either show or return the plot, by default True
      colorbar : bool, optional
          add a colorbar, by default True
+     vmin_vmax: tuple, optional
+         The colorbar minimum and maximum value. Defaults None.

      Returns
      -------
@@ -153,7 +219,17 @@

      if plot:
          fig = plt.figure(figsize=figsize)
-         im = plt.pcolormesh(x_grid, y_grid, values_grid, cmap=cmap)
+         if vmin_vmax is not None:
+             im = plt.pcolormesh(
+                 x_grid,
+                 y_grid,
+                 values_grid,
+                 cmap=cmap,
+                 vmin=vmin_vmax[0],
+                 vmax=vmin_vmax[1],
+             )
+         else:
+             im = plt.pcolormesh(x_grid, y_grid, values_grid, cmap=cmap)
          if colorbar:
              fig.colorbar(im, format="%0.2f")
          # don't plt.show() because it is done in plot2d()
@@ -217,6 +293,7 @@ def plot1d_image(
      colorbar=True,
      cmap="inferno",
      spinn=False,
+     vmin_vmax=None,
  ):
      """Function for plotting the 2-D image of a function :math:`f(t, x)` where
      `t` is time (1-D) and x is space (1-D).
@@ -237,6 +314,8 @@
          , by default ""
      figsize : tuple, optional
          , by default (10, 10)
+     vmin_vmax: tuple
+         The colorbar minimum and maximum value. Defaults None.
      """

      mesh = jnp.meshgrid(times, xdata) # cartesian product
@@ -250,7 +329,17 @@
      elif spinn:
          values_grid = jnp.squeeze(fun((times[..., None]), xdata[..., None]).T)
      fig, ax = plt.subplots(1, 1, figsize=figsize)
-     im = ax.pcolormesh(mesh[0] * Tmax, mesh[1], values_grid, cmap=cmap)
+     if vmin_vmax is not None:
+         im = ax.pcolormesh(
+             mesh[0] * Tmax,
+             mesh[1],
+             values_grid,
+             cmap=cmap,
+             vmin=vmin_vmax[0],
+             vmax=vmin_vmax[1],
+         )
+     else:
+         im = ax.pcolormesh(mesh[0] * Tmax, mesh[1], values_grid, cmap=cmap)
      if colorbar:
          fig.colorbar(im, format="%0.2f")
      ax.set_title(title)
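A minimal usage sketch (not part of the package) of the new vmin_vmax / ax_for_plot arguments: drawing two stationary solutions on a common colour scale inside a caller-managed figure. The import path, the first two argument names of plot2d, and the placeholders u1, u2 and xy_data are assumptions; ImageGrid axes are used because the new code calls ax_for_plot.cax.colorbar.

import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid

from jinns.data._display import plot2d  # module path taken from this diff

# u1, u2: hypothetical trained PINNs u(x) on a 2-D domain; xy_data: the grid data plot2d expects
fig = plt.figure(figsize=(10, 5))
# cbar_mode="each" gives every axis a .cax, which plot2d uses for the colorbar
grid = ImageGrid(fig, 111, nrows_ncols=(1, 2), cbar_mode="each", cbar_pad=0.1)
for ax, u, name in zip(grid, (u1, u2), ("run A", "run B")):
    # times=None selects the stationary branch, the only case where ax_for_plot is honoured
    plot2d(u, xy_data, times=None, title=name, vmin_vmax=(0.0, 1.0), ax_for_plot=ax)
plt.show()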
jinns/experimental/__init__.py CHANGED
@@ -6,3 +6,5 @@ from ._diffrax_solver import (
      neumann_boundary_condition,
      plot_diffrax_solution,
  )
+ from ._sinuspinn import create_sinusPINN
+ from ._spectralpinn import create_spectralPINN
jinns/experimental/_sinuspinn.py ADDED
@@ -0,0 +1,135 @@
+ import jax
+ import equinox as eqx
+ import jax.numpy as jnp
+ from jinns.utils._pinn import PINN
+
+
+ def almost_zero_init(weight: jax.Array, key: jax.random.PRNGKey) -> jax.Array:
+     out, in_ = weight.shape
+     stddev = 1e-2
+     return stddev * jax.random.normal(key, shape=(out, in_))
+
+
+ class _SinusPINN(eqx.Module):
+     """
+     A specific PINN whose layers are x_sin2x functions whose frequencies are
+     determined by an other network
+     """
+
+     layers_pinn: list
+     layers_aux_nn: list
+
+     def __init__(self, key, list_layers_pinn, list_layers_aux_nn):
+         """
+         Parameters
+         ----------
+         key
+             A jax random key
+         list_layers_pinn
+             A list as eqx_list in jinns' PINN utility for the main PINN
+         list_layers_aux_nn
+             A list as eqx_list in jinns' PINN utility for the network which outputs
+             the PINN's activation frequencies
+         """
+         self.layers_pinn = []
+         for l in list_layers_pinn:
+             if len(l) == 1:
+                 self.layers_pinn.append(l[0])
+             else:
+                 key, subkey = jax.random.split(key, 2)
+                 self.layers_pinn.append(l[0](*l[1:], key=subkey))
+         self.layers_aux_nn = []
+         for idx, l in enumerate(list_layers_aux_nn):
+             if len(l) == 1:
+                 self.layers_aux_nn.append(l[0])
+             else:
+                 key, subkey = jax.random.split(key, 2)
+                 linear_layer = l[0](*l[1:], key=subkey)
+                 key, subkey = jax.random.split(key, 2)
+                 linear_layer = eqx.tree_at(
+                     lambda l: l.weight,
+                     linear_layer,
+                     almost_zero_init(linear_layer.weight, subkey),
+                 )
+                 if (idx == len(list_layers_aux_nn) - 1) or (
+                     idx == len(list_layers_aux_nn) - 2
+                 ):
+                     # for the last layer: almost 0 weights and 0.5 bias
+                     linear_layer = eqx.tree_at(
+                         lambda l: l.bias,
+                         linear_layer,
+                         0.5 * jnp.ones(linear_layer.bias.shape),
+                     )
+                 else:
+                     # for the other previous layers:
+                     # almost 0 weight and 0 bias
+                     linear_layer = eqx.tree_at(
+                         lambda l: l.bias,
+                         linear_layer,
+                         jnp.zeros(linear_layer.bias.shape),
+                     )
+                 self.layers_aux_nn.append(linear_layer)
+
+         ## init to zero the frequency network except last biases
+         # key, subkey = jax.random.split(key, 2)
+         # _pinn = init_linear_weight(_pinn, almost_zero_init, subkey)
+         # key, subkey = jax.random.split(key, 2)
+         # _pinn = init_linear_bias(_pinn, zero_init, subkey)
+         # print(_pinn)
+         # print(jax.tree_util.tree_leaves(_pinn, is_leaf=lambda
+         # p:not isinstance(p,eqx.nn.Linear))[0].layers_aux_nn[-1].bias)
+         # _pinn = eqx.tree_at(lambda p:_pinn.layers_aux_nn[-1].bias, 0.5 *
+         # jnp.ones(_pinn.layers_aux_nn[-1].bias.shape))
+         # #, is_leaf=lambda
+         # #p:not isinstance(p, eqx.nn.Linear))
+
+     def __call__(self, x):
+         x_ = x.copy()
+         # forward pass in the network which determines the freq
+         for layer in self.layers_aux_nn:
+             x_ = layer(x_)
+         freq_list = jnp.clip(jnp.square(x_), a_min=1e-4, a_max=5)
+         x_ = x.copy()
+         # forward pass through the actual PINN
+         for idx, layer in enumerate(self.layers_pinn):
+             if idx % 2 == 0:
+                 # Currently: every two layer we have an activation
+                 # requiring a frequency
+                 x_ = layer(x_)
+             else:
+                 x_ = layer(x_, freq_list[(idx - 1) // 2])
+         return x_
+
+
+ class sinusPINN(PINN):
+     """
+     MUST inherit from PINN to pass all the checks
+
+     HOWEVER we dot not bother with reimplementing anything
+     """
+
+     def __init__(self, key, list_layers_pinn, list_layers_aux_nn):
+         super().__init__({}, jnp.s_[...], "statio_PDE", None, None, None)
+         key, subkey = jax.random.split(key, 2)
+         _pinn = _SinusPINN(subkey, list_layers_pinn, list_layers_aux_nn)
+
+         self.params, self.static = eqx.partition(_pinn, eqx.is_inexact_array)
+
+     def init_params(self):
+         return self.params
+
+     def __call__(self, x, params):
+         try:
+             model = eqx.combine(params["nn_params"], self.static)
+         except (KeyError, TypeError) as e: # give more flexibility
+             model = eqx.combine(params, self.static)
+         res = model(x)
+         if not res.shape:
+             return jnp.expand_dims(res, axis=-1)
+         return res
+
+
+ def create_sinusPINN(key, list_layers_pinn, list_layers_aux_nn):
+     """ """
+     u = sinusPINN(key, list_layers_pinn, list_layers_aux_nn)
+     return u
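A hedged usage sketch for the sinus PINN above, inferred only from the constructor and __call__ logic in this file: even-indexed entries of list_layers_pinn are layer constructors called as l[0](*l[1:], key=subkey), single-element entries are used as-is, and odd-indexed layers receive a frequency from the auxiliary network. The x_sin2x activation itself is not part of this diff, so a placeholder stands in for it, and the import path is inferred from the experimental __init__ hunk.

import jax
import jax.numpy as jnp
import equinox as eqx

from jinns.experimental import create_sinusPINN  # path inferred from the __init__ hunk


def x_sin2x(x, freq):
    # hypothetical stand-in for the frequency-aware activation expected at odd indices
    return x * jnp.sin(2 * freq * x)


key = jax.random.PRNGKey(0)
list_layers_pinn = [
    [eqx.nn.Linear, 2, 32],
    [x_sin2x],  # called as layer(x, freq_list[0]) in _SinusPINN.__call__
    [eqx.nn.Linear, 32, 1],
]
# the auxiliary network must output one frequency per frequency-aware activation
list_layers_aux_nn = [
    [eqx.nn.Linear, 2, 16],
    [jax.nn.tanh],
    [eqx.nn.Linear, 16, 1],
]
u = create_sinusPINN(key, list_layers_pinn, list_layers_aux_nn)
params = u.init_params()
print(u(jnp.array([0.5, 0.5]), params))  # array of shape (1,)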
jinns/experimental/_spectralpinn.py ADDED
@@ -0,0 +1,87 @@
+ import jax
+ import equinox as eqx
+ import jax.numpy as jnp
+ from jinns.utils._pinn import PINN
+
+
+ def almost_zero_init(weight: jax.Array, key: jax.random.PRNGKey) -> jax.Array:
+     out, in_ = weight.shape
+     stddev = 1e-2
+     return stddev * jax.random.normal(key, shape=(out, in_))
+
+
+ class _SpectralPINN(eqx.Module):
+     """
+     A specific PINN whose acrhitecture is similar to spectral method for simulation of a spatial field
+     (Chilès and Delfiner, 2012) - a single layer with cos() activation function and sum for last layer
+     """
+
+     layers_pinn: list
+     nbands: int
+
+     def __init__(self, key, list_layers_pinn, nbands):
+         """
+         Parameters
+         ----------
+         key
+             A jax random key
+         list_layers_pinn
+             A list as eqx_list in jinns' PINN utility for the main PINN
+         nbands
+             Number of spectral bands (i.e., neurones in the single layer of the PINN)
+         """
+         self.nbands = nbands
+         self.layers_pinn = []
+         for l in list_layers_pinn:
+             if len(l) == 1:
+                 self.layers_pinn.append(l[0])
+             else:
+                 key, subkey = jax.random.split(key, 2)
+                 self.layers_pinn.append(l[0](*l[1:], key=subkey))
+
+     def __call__(self, x):
+         # forward pass through the actual PINN
+         for layer in self.layers_pinn:
+             x = layer(x)
+
+         return jnp.sqrt(2 / self.nbands) * jnp.sum(x)
+
+
+ class spectralPINN(PINN):
+     """
+     MUST inherit from PINN to pass all the checks
+
+     HOWEVER we dot not bother with reimplementing anything
+     """
+
+     def __init__(self, key, list_layers_pinn, nbands):
+         super().__init__({}, jnp.s_[...], "statio_PDE", None, None, None)
+         key, subkey = jax.random.split(key, 2)
+         _pinn = _SpectralPINN(subkey, list_layers_pinn, nbands)
+
+         self.params, self.static = eqx.partition(_pinn, eqx.is_inexact_array)
+
+     def init_params(self):
+         return self.params
+
+     def __call__(self, x, params):
+         try:
+             model = eqx.combine(params["nn_params"], self.static)
+         except (KeyError, TypeError) as e: # give more flexibility
+             model = eqx.combine(params, self.static)
+         # model = eqx.tree_at(lambda m:
+         # m.layers_pinn[0].bias,
+         # model,
+         # model.layers_pinn[0].bias % (2 *
+         # jnp.pi)
+         # )
+         res = model(x)
+         if not res.shape:
+             return jnp.expand_dims(res, axis=-1)
+         return res
+
+
+ def create_spectralPINN(key, list_layers_pinn, nbands):
+     """ """
+     u = spectralPINN(key, list_layers_pinn, nbands)
+     return u
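Similarly, a hedged sketch for the spectral PINN: a single hidden layer of nbands neurons followed by a cos() activation, summed and rescaled by sqrt(2 / nbands) in _SpectralPINN.__call__. The layer-list format and import path are the same assumptions as above; jnp.cos is chosen here as the single-entry activation.

import jax
import jax.numpy as jnp
import equinox as eqx

from jinns.experimental import create_spectralPINN  # path inferred from the __init__ hunk

key = jax.random.PRNGKey(0)
nbands = 128
list_layers_pinn = [
    [eqx.nn.Linear, 2, nbands],  # 2-D spatial input projected onto nbands "bands"
    [jnp.cos],  # single-element entry: used as-is as the activation
]
u = create_spectralPINN(key, list_layers_pinn, nbands)
params = u.init_params()
print(u(jnp.array([0.5, 0.5]), params))  # scalar result, expanded to shape (1,)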
jinns/loss/_LossODE.py CHANGED
@@ -19,6 +19,8 @@ from jinns.loss._Losses import (
  )
  from jinns.utils._pinn import PINN

+ _LOSS_WEIGHT_KEYS_ODE = ["observations", "dyn_loss", "initial_condition"]
+

  @register_pytree_node_class
  class LossODE:
@@ -128,6 +130,10 @@ class LossODE:
          if self.obs_slice is None:
              self.obs_slice = jnp.s_[...]

+         for k in _LOSS_WEIGHT_KEYS_ODE:
+             if k not in self.loss_weights.keys():
+                 self.loss_weights[k] = 0
+
      def __call__(self, *args, **kwargs):
          return self.evaluate(*args, **kwargs)

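The effect of the new _LOSS_WEIGHT_KEYS_ODE loop is that any weight missing from the user-supplied loss_weights dict is filled with 0 when the loss object is built. A standalone illustration of the loop shown above (plain Python, outside any jinns class):

_LOSS_WEIGHT_KEYS_ODE = ["observations", "dyn_loss", "initial_condition"]

loss_weights = {"dyn_loss": 1.0}  # the user only weights the dynamic loss
for k in _LOSS_WEIGHT_KEYS_ODE:
    if k not in loss_weights.keys():
        loss_weights[k] = 0
print(loss_weights)  # {'dyn_loss': 1.0, 'observations': 0, 'initial_condition': 0}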
jinns/loss/_LossPDE.py CHANGED
@@ -31,6 +31,16 @@ _IMPLEMENTED_BOUNDARY_CONDITIONS = [
      "vonneumann",
  ]

+ _LOSS_WEIGHT_KEYS_PDESTATIO = [
+     "sobolev",
+     "observations",
+     "norm_loss",
+     "boundary_loss",
+     "dyn_loss",
+ ]
+
+ _LOSS_WEIGHT_KEYS_PDENONSTATIO = _LOSS_WEIGHT_KEYS_PDESTATIO + ["initial_condition"]
+

  @register_pytree_node_class
  class LossPDEAbstract:
@@ -269,8 +279,8 @@ class LossPDEStatio(LossPDEAbstract):
              the PINN object
          loss_weights
              a dictionary with values used to ponderate each term in the loss
-             function. Valid keys are `dyn_loss`, `norm_loss`, `boundary_loss`
-             and `observations`.
+             function. Valid keys are `dyn_loss`, `norm_loss`, `boundary_loss`,
+             `observations` and `sobolev`.
              Note that we can have jnp.arrays with the same dimension of
              `u` which then ponderates each output of `u`
          dynamic_loss
@@ -441,18 +451,12 @@ class LossPDEStatio(LossPDEAbstract):
              ) # we return a function, that way
              # the order of sobolev_m is static and the conditional in the recursive
              # function is properly set
-             self.sobolev_m = self.sobolev_m
          else:
              self.sobolev_reg = None

-         if self.normalization_loss is None:
-             self.loss_weights["norm_loss"] = 0
-
-         if self.omega_boundary_fun is None:
-             self.loss_weights["boundary_loss"] = 0
-
-         if self.sobolev_reg is None:
-             self.loss_weights["sobolev"] = 0
+         for k in _LOSS_WEIGHT_KEYS_PDESTATIO:
+             if k not in self.loss_weights.keys():
+                 self.loss_weights[k] = 0

          if (
              isinstance(self.omega_boundary_fun, dict)
@@ -533,7 +537,6 @@ class LossPDEStatio(LossPDEAbstract):
              )
          else:
              mse_norm_loss = jnp.array(0.0)
-             self.loss_weights["norm_loss"] = 0

          # boundary part
          params_ = _set_derivatives(params, "boundary_loss", self.derivative_keys)
@@ -567,7 +570,6 @@ class LossPDEStatio(LossPDEAbstract):
              )
          else:
              mse_observation_loss = jnp.array(0.0)
-             self.loss_weights["observations"] = 0

          # Sobolev regularization
          params_ = _set_derivatives(params, "sobolev", self.derivative_keys)
@@ -582,7 +584,6 @@ class LossPDEStatio(LossPDEAbstract):
              )
          else:
              mse_sobolev_loss = jnp.array(0.0)
-             self.loss_weights["sobolev"] = 0

          # total loss
          total_loss = (
@@ -785,8 +786,9 @@ class LossPDENonStatio(LossPDEStatio):
          else:
              self.sobolev_reg = None

-         if self.sobolev_reg is None:
-             self.loss_weights["sobolev"] = 0
+         for k in _LOSS_WEIGHT_KEYS_PDENONSTATIO:
+             if k not in self.loss_weights.keys():
+                 self.loss_weights[k] = 0

      def __call__(self, *args, **kwargs):
          return self.evaluate(*args, **kwargs)
@@ -924,7 +926,6 @@ class LossPDENonStatio(LossPDEStatio):
              )
          else:
              mse_observation_loss = jnp.array(0.0)
-             self.loss_weights["observations"] = 0

          # Sobolev regularization
          params_ = _set_derivatives(params, "sobolev", self.derivative_keys)
@@ -939,7 +940,6 @@ class LossPDENonStatio(LossPDEStatio):
              )
          else:
              mse_sobolev_loss = jnp.array(0.0)
-             self.loss_weights["sobolev"] = 0.0

          # total loss
          total_loss = (
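The PDE losses follow the same pattern: missing keys are filled with 0 once at construction, using the stationary or non-stationary key list, which is presumably why the self.loss_weights[...] = 0 assignments inside evaluate() could be dropped above. A standalone illustration using the key lists taken verbatim from the hunks:

_LOSS_WEIGHT_KEYS_PDESTATIO = [
    "sobolev",
    "observations",
    "norm_loss",
    "boundary_loss",
    "dyn_loss",
]
_LOSS_WEIGHT_KEYS_PDENONSTATIO = _LOSS_WEIGHT_KEYS_PDESTATIO + ["initial_condition"]

loss_weights = {"dyn_loss": 1.0, "boundary_loss": 10.0}  # a typical partial dict
for k in _LOSS_WEIGHT_KEYS_PDENONSTATIO:
    if k not in loss_weights.keys():
        loss_weights[k] = 0
# "sobolev", "observations", "norm_loss" and "initial_condition" now weigh 0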