jinns 1.1.0__py3-none-any.whl → 1.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
jinns/nn/_spinn_mlp.py ADDED
@@ -0,0 +1,196 @@
1
+ """
2
+ Implements utility function to create Separable PINNs
3
+ https://arxiv.org/abs/2211.08761
4
+ """
5
+
6
+ from dataclasses import InitVar
7
+ from typing import Callable, Literal, Self, Union, Any
8
+ import jax
9
+ import jax.numpy as jnp
10
+ import equinox as eqx
11
+ from jaxtyping import Key, Array, Float, PyTree
12
+
13
+ from jinns.parameters._params import Params, ParamsDict
14
+ from jinns.nn._mlp import MLP
15
+ from jinns.nn._spinn import SPINN
16
+
17
+
18
+ class SMLP(eqx.Module):
19
+ """
20
+ Construct a Separable MLP
21
+
22
+ Parameters
23
+ ----------
24
+ key : InitVar[Key]
25
+ A jax random key for the layer initializations.
26
+ d : int
27
+ The number of dimensions to treat separately, including time `t` if
28
+ used for non-stationnary equations.
29
+ eqx_list : InitVar[tuple[tuple[Callable, int, int] | Callable, ...]]
30
+ A tuple of tuples of successive equinox modules and activation functions to
31
+ describe the PINN architecture. The inner tuples must have the eqx module or
32
+ activation function as first item, other items represents arguments
33
+ that could be required (eg. the size of the layer).
34
+ The `key` argument need not be given.
35
+ Thus typical example is `eqx_list=
36
+ ((eqx.nn.Linear, 1, 20),
37
+ jax.nn.tanh,
38
+ (eqx.nn.Linear, 20, 20),
39
+ jax.nn.tanh,
40
+ (eqx.nn.Linear, 20, 20),
41
+ jax.nn.tanh,
42
+ (eqx.nn.Linear, 20, r * m)
43
+ )`.
44
+ """
45
+
46
+ key: InitVar[Key] = eqx.field(kw_only=True)
47
+ eqx_list: InitVar[tuple[tuple[Callable, int, int] | Callable, ...]] = eqx.field(
48
+ kw_only=True
49
+ )
50
+ d: int = eqx.field(static=True, kw_only=True)
51
+
52
+ separated_mlp: list[MLP] = eqx.field(init=False)
53
+
54
+ def __post_init__(self, key, eqx_list):
55
+ keys = jax.random.split(key, self.d)
56
+ self.separated_mlp = [
57
+ MLP(key=keys[d_], eqx_list=eqx_list) for d_ in range(self.d)
58
+ ]
59
+
60
+ def __call__(
61
+ self, inputs: Float[Array, "dim"] | Float[Array, "dim+1"]
62
+ ) -> Float[Array, "d embed_dim*output_dim"]:
63
+ outputs = []
64
+ for d in range(self.d):
65
+ x_i = inputs[d : d + 1]
66
+ outputs += [self.separated_mlp[d](x_i)]
67
+ return jnp.asarray(outputs)
68
+
69
+
70
+ class SPINN_MLP(SPINN):
71
+ """
72
+ An implementable SPINN based on a MLP architecture
73
+ """
74
+
75
+ @classmethod
76
+ def create(
77
+ cls,
78
+ key: Key,
79
+ d: int,
80
+ r: int,
81
+ eqx_list: tuple[tuple[Callable, int, int] | Callable, ...],
82
+ eq_type: Literal["ODE", "statio_PDE", "nonstatio_PDE"],
83
+ m: int = 1,
84
+ filter_spec: PyTree[Union[bool, Callable[[Any], bool]]] = None,
85
+ ) -> tuple[Self, PyTree]:
86
+ """
87
+ Utility function to create a SPINN neural network with the equinox
88
+ library.
89
+
90
+ *Note* that a SPINN is not vmapped and expects the
91
+ same batch size for each of its input axis. It directly outputs a
92
+ solution of shape `(batchsize,) * d`. See the paper for more
93
+ details.
94
+
95
+ Parameters
96
+ ----------
97
+ key : Key
98
+ A JAX random key that will be used to initialize the network parameters
99
+ d : int
100
+ The number of dimensions to treat separately.
101
+ r : int
102
+ An integer. The dimension of the embedding.
103
+ eqx_list : tuple[tuple[Callable, int, int] | Callable, ...],
104
+ A tuple of tuples of successive equinox modules and activation functions to
105
+ describe the PINN architecture. The inner tuples must have the eqx module or
106
+ activation function as first item, other items represents arguments
107
+ that could be required (eg. the size of the layer).
108
+ The `key` argument need not be given.
109
+ Thus typical example is
110
+ `eqx_list=((eqx.nn.Linear, 1, 20),
111
+ jax.nn.tanh,
112
+ (eqx.nn.Linear, 20, 20),
113
+ jax.nn.tanh,
114
+ (eqx.nn.Linear, 20, 20),
115
+ jax.nn.tanh,
116
+ (eqx.nn.Linear, 20, r * m)
117
+ )`.
118
+ eq_type : Literal["ODE", "statio_PDE", "nonstatio_PDE"]
119
+ A string with three possibilities.
120
+ "ODE": the PINN is called with one input `t`.
121
+ "statio_PDE": the PINN is called with one input `x`, `x`
122
+ can be high dimensional.
123
+ "nonstatio_PDE": the PINN is called with two inputs `t` and `x`, `x`
124
+ can be high dimensional.
125
+ **Note**: the input dimension as given in eqx_list has to match the sum
126
+ of the dimension of `t` + the dimension of `x`.
127
+ m : int
128
+ The output dimension of the neural network. According to
129
+ the SPINN article, a total embedding dimension of `r*m` is defined. We
130
+ then sum groups of `r` embedding dimensions to compute each output.
131
+ Default is 1.
132
+ filter_spec : PyTree[Union[bool, Callable[[Any], bool]]]
133
+ Default is None which leads to `eqx.is_inexact_array` in the class
134
+ instanciation. This tells Jinns what to consider as
135
+ a trainable parameter. Quoting from equinox documentation:
136
+ a PyTree whose structure should be a prefix of the structure of pytree.
137
+ Each of its leaves should either be 1) True, in which case the leaf or
138
+ subtree is kept; 2) False, in which case the leaf or subtree is
139
+ replaced with replace; 3) a callable Leaf -> bool, in which case this is evaluated on the leaf or mapped over the subtree, and the leaf kept or replaced as appropriate.
140
+
141
+
142
+
143
+
144
+ Returns
145
+ -------
146
+ spinn
147
+ An instanciated SPINN
148
+ spinn.init_params
149
+ The initial set of parameters of the model
150
+
151
+ Raises
152
+ ------
153
+ RuntimeError
154
+ If the parameter value for eq_type is not in `["ODE", "statio_PDE",
155
+ "nonstatio_PDE"]` and for various failing checks
156
+ """
157
+
158
+ if eq_type not in ["ODE", "statio_PDE", "nonstatio_PDE"]:
159
+ raise RuntimeError("Wrong parameter value for eq_type")
160
+
161
+ try:
162
+ nb_inputs_declared = eqx_list[0][
163
+ 1
164
+ ] # normally we look for 2nd ele of 1st layer
165
+ except IndexError:
166
+ nb_inputs_declared = eqx_list[1][
167
+ 1
168
+ ] # but we can have, eg, a flatten first layer
169
+ if nb_inputs_declared != 1:
170
+ raise ValueError("Input dim must be set to 1 in SPINN!")
171
+
172
+ try:
173
+ nb_outputs_declared = eqx_list[-1][2] # normally we look for 3rd ele of
174
+ # last layer
175
+ except IndexError:
176
+ nb_outputs_declared = eqx_list[-2][2]
177
+ # but we can have, eg, a `jnp.exp` last layer
178
+ if nb_outputs_declared != r * m:
179
+ raise ValueError("Output dim must be set to r * m in SPINN!")
180
+
181
+ if d > 24:
182
+ raise ValueError(
183
+ "Too many dimensions, not enough letters available in jnp.einsum"
184
+ )
185
+
186
+ smlp = SMLP(key=key, d=d, eqx_list=eqx_list)
187
+ spinn = cls(
188
+ eqx_spinn_network=smlp,
189
+ d=d,
190
+ r=r,
191
+ eq_type=eq_type,
192
+ m=m,
193
+ filter_spec=filter_spec,
194
+ )
195
+
196
+ return spinn, spinn.init_params
jinns/plot/_plot.py CHANGED
@@ -21,8 +21,7 @@ def plot2d(
21
21
  figsize: tuple = (7, 7),
22
22
  cmap: str = "inferno",
23
23
  spinn: bool = False,
24
- vmin_vmax: tuple[float, float] | None = None,
25
- ax_for_plot: plt.Axes | None = None,
24
+ vmin_vmax: tuple[float, float] = [None, None],
26
25
  ):
27
26
  r"""Generic function for plotting functions over rectangular 2-D domains
28
27
  $\Omega$. It handles both the
@@ -55,11 +54,7 @@ def plot2d(
55
54
  vmin_vmax :
56
55
  The colorbar minimum and maximum value. Defaults None.
57
56
  spinn :
58
- True if a SPINN is to be plotted. False for PINNs and HYPERPINNs
59
- ax_for_plot :
60
- If None, jinns triggers the plotting. Otherwise this argument
61
- corresponds to the axis which will host the plot. Default is None.
62
- NOTE: that this argument will have an effect only if times is None.
57
+ True if the function is a `SPINN` object.
63
58
 
64
59
  Raises
65
60
  ------
@@ -73,55 +68,48 @@ def plot2d(
73
68
  "xy_data must be a list of length 2 containing"
74
69
  "jnp.array of shape (nx,) and (ny,)."
75
70
  )
71
+
76
72
  mesh = jnp.meshgrid(xy_data[0], xy_data[1]) # cartesian product
77
73
 
78
74
  if times is None:
79
75
  # Statio case : expect a function of one argument fun(x)
76
+
80
77
  if not spinn:
81
78
  v_fun = vmap(fun, 0, 0)
82
79
  ret = _plot_2D_statio(
83
80
  v_fun,
84
81
  mesh,
85
- plot=not ax_for_plot,
86
82
  colorbar=True,
87
83
  cmap=cmap,
88
84
  figsize=figsize,
89
85
  vmin_vmax=vmin_vmax,
90
86
  )
91
87
  elif spinn:
92
- values_grid = jnp.squeeze(
93
- fun(jnp.stack([xy_data[0][..., None], xy_data[1][..., None]], axis=1))
94
- )
88
+ values_grid = jnp.squeeze(fun(jnp.stack([xy_data[0], xy_data[1]], axis=1)))
95
89
  ret = _plot_2D_statio(
96
90
  values_grid,
97
91
  mesh,
98
- plot=not ax_for_plot,
99
92
  colorbar=True,
100
93
  cmap=cmap,
101
- spinn=True,
102
94
  figsize=figsize,
103
95
  vmin_vmax=vmin_vmax,
104
96
  )
105
- if not ax_for_plot:
106
- plt.title(title)
107
97
  else:
108
- if vmin_vmax is not None:
109
- im = ax_for_plot.pcolormesh(
110
- mesh[0],
111
- mesh[1],
112
- ret[0],
113
- cmap=cmap,
114
- vmin=vmin_vmax[0],
115
- vmax=vmin_vmax[1],
116
- )
117
- else:
118
- im = ax_for_plot.pcolormesh(mesh[0], mesh[1], ret[0], cmap=cmap)
119
- ax_for_plot.set_title(title)
120
- ax_for_plot.cax.colorbar(im, format="%0.2f")
98
+ fig, ax = plt.subplots(1, 1, figsize=figsize)
99
+
100
+ im = ax.pcolormesh(
101
+ mesh[0],
102
+ mesh[1],
103
+ ret[0],
104
+ cmap=cmap,
105
+ vmin=vmin_vmax[0],
106
+ vmax=vmin_vmax[1],
107
+ )
108
+
109
+ ax.set_title(title)
110
+ fig.cax.colorbar(im, format="%0.2f")
121
111
 
122
112
  else:
123
- if ax_for_plot is not None:
124
- warnings.warn("ax_for_plot is ignored. jinns will plot the figure")
125
113
  if not isinstance(times, list):
126
114
  try:
127
115
  times = times.tolist()
@@ -143,56 +131,64 @@ def plot2d(
143
131
 
144
132
  for idx, (t, ax) in enumerate(zip(times, grid)):
145
133
  if not spinn:
146
- v_fun_at_t = vmap(lambda x: fun(t=jnp.array([t]), x=x), 0, 0)
134
+ x_grid, y_grid = mesh
135
+ v_fun_at_t = vmap(fun)(
136
+ jnp.concatenate(
137
+ [
138
+ t
139
+ * jnp.ones((xy_data[0].shape[0] * xy_data[1].shape[0], 1)),
140
+ jnp.vstack([x_grid.flatten(), y_grid.flatten()]).T,
141
+ ],
142
+ axis=-1,
143
+ )
144
+ )
147
145
  t_slice, _ = _plot_2D_statio(
148
146
  v_fun_at_t,
149
147
  mesh,
150
- plot=False,
148
+ plot=False, # only use to compute t_slice
151
149
  colorbar=False,
152
150
  cmap=None,
153
151
  vmin_vmax=vmin_vmax,
154
152
  )
155
153
  elif spinn:
156
- values_grid = jnp.squeeze(
157
- fun(
154
+ t_x = jnp.concatenate(
155
+ [
158
156
  t * jnp.ones((xy_data[0].shape[0], 1)),
159
- jnp.stack(
160
- [xy_data[0][..., None], xy_data[1][..., None]], axis=1
157
+ jnp.concatenate(
158
+ [xy_data[0][..., None], xy_data[1][..., None]], axis=-1
161
159
  ),
162
- )[0]
160
+ ],
161
+ axis=-1,
163
162
  )
163
+ values_grid = jnp.squeeze(fun(t_x)[0]).T
164
164
  t_slice, _ = _plot_2D_statio(
165
165
  values_grid,
166
166
  mesh,
167
- plot=False,
167
+ plot=False, # only use to compute t_slice
168
168
  colorbar=True,
169
- spinn=True,
170
169
  vmin_vmax=vmin_vmax,
171
170
  )
172
- if vmin_vmax is not None:
173
- im = ax.pcolormesh(
174
- mesh[0],
175
- mesh[1],
176
- t_slice,
177
- cmap=cmap,
178
- vmin=vmin_vmax[0],
179
- vmax=vmin_vmax[1],
180
- )
181
- else:
182
- im = ax.pcolormesh(mesh[0], mesh[1], t_slice, cmap=cmap)
171
+
172
+ im = ax.pcolormesh(
173
+ mesh[0],
174
+ mesh[1],
175
+ t_slice,
176
+ cmap=cmap,
177
+ vmin=vmin_vmax[0],
178
+ vmax=vmin_vmax[1],
179
+ )
183
180
  ax.set_title(f"t = {times[idx] * Tmax:.2f}")
184
181
  ax.cax.colorbar(im, format="%0.2f")
185
182
 
186
183
 
187
184
  def _plot_2D_statio(
188
- v_fun,
185
+ v_fun: Callable | Float[Array, "(nx*ny)^2 1"],
189
186
  mesh: Float[Array, "nx*ny nx*ny"],
190
187
  plot: Bool = True,
191
188
  colorbar: Bool = True,
192
189
  cmap: str = "inferno",
193
190
  figsize: tuple[int, int] = (7, 7),
194
- spinn: Bool = False,
195
- vmin_vmax: tuple[float, float] = None,
191
+ vmin_vmax: tuple[float, float] = [None, None],
196
192
  ):
197
193
  """Function that plot the function u(x) with 2-D input x using pcolormesh()
198
194
 
@@ -200,11 +196,11 @@ def _plot_2D_statio(
200
196
  Parameters
201
197
  ----------
202
198
  v_fun :
203
- a vmapped function over jnp.array of shape (*, 2)
199
+ a vmapped function over jnp.array of shape (*, 2) OR a precomputed array of function values with shape compatible with `mesh`.
204
200
  mesh :
205
201
  a tuple of size 2, containing the x and y meshgrid.
206
202
  plot : bool, optional
207
- either show or return the plot, by default True
203
+ either displays the plot, or silently returns the grid of values `v_fun(mesh)`.
208
204
  colorbar : bool, optional
209
205
  add a colorbar, by default True
210
206
  cmap :
@@ -212,8 +208,8 @@ def _plot_2D_statio(
212
208
  figsize :
213
209
  By default (7, 7)
214
210
  spinn :
215
- True if a SPINN is to be plotted. False for PINNs and HYPERPINNs
216
- vmin_vmax: tuple, optional
211
+ True if a SPINN is to be plotted. False for PINNs and HyperPINNs
212
+ vmin_vmax: list, optional
217
213
  The colorbar minimum and maximum value. Defaults None.
218
214
 
219
215
  Returns
@@ -223,26 +219,23 @@ def _plot_2D_statio(
223
219
  """
224
220
 
225
221
  x_grid, y_grid = mesh
226
- if not spinn:
222
+ if callable(v_fun):
227
223
  values = v_fun(jnp.vstack([x_grid.flatten(), y_grid.flatten()]).T)
228
224
  values_grid = values.reshape(x_grid.shape)
229
- elif spinn:
230
- # in this case v_fun is directly the values :)
231
- values_grid = v_fun.T
225
+ else:
226
+ values_grid = v_fun.reshape(x_grid.shape)
232
227
 
233
228
  if plot:
234
229
  fig = plt.figure(figsize=figsize)
235
- if vmin_vmax is not None:
236
- im = plt.pcolormesh(
237
- x_grid,
238
- y_grid,
239
- values_grid,
240
- cmap=cmap,
241
- vmin=vmin_vmax[0],
242
- vmax=vmin_vmax[1],
243
- )
244
- else:
245
- im = plt.pcolormesh(x_grid, y_grid, values_grid, cmap=cmap)
230
+ im = plt.pcolormesh(
231
+ x_grid,
232
+ y_grid,
233
+ values_grid,
234
+ cmap=cmap,
235
+ vmin=vmin_vmax[0],
236
+ vmax=vmin_vmax[1],
237
+ )
238
+
246
239
  if colorbar:
247
240
  fig.colorbar(im, format="%0.2f")
248
241
  # don't plt.show() because it is done in plot2d()
@@ -258,6 +251,7 @@ def plot1d_slice(
258
251
  title: str = "",
259
252
  figsize: tuple[int, int] = (10, 10),
260
253
  spinn: Bool = False,
254
+ ax=None,
261
255
  ):
262
256
  """Function for plotting time slices of a function :math:`f(t_i, x)` where
263
257
  `t_i` is time (1-D) and x is 1-D
@@ -275,29 +269,40 @@ def plot1d_slice(
275
269
  default 1
276
270
  title
277
271
  title of the plot, by default ""
278
- spinn :
279
- True if a SPINN is to be plotted. False for PINNs and HYPERPINNs
280
272
  figsize
281
273
  size of the figure, by default (10, 10)
274
+ spinn
275
+ True if a SPINN is to be plotted. False for PINNs and HyperPINNs
276
+ ax
277
+ A pre-defined `matplotlib.Axes` where you want to plot.
278
+
279
+ Returns
280
+ -------
281
+ ax
282
+ A `matplotlib.Axes` object
282
283
  """
283
284
  if time_slices is None:
284
285
  time_slices = jnp.array([0])
285
- plt.figure(figsize=figsize)
286
+ if ax is None:
287
+ fig, ax = plt.subplots(figsize=figsize)
288
+
286
289
  for t in time_slices:
290
+ t_xdata = jnp.concatenate(
291
+ [t * jnp.ones((xdata.shape[0], 1)), xdata[:, None]], axis=1
292
+ )
287
293
  if not spinn:
288
294
  # fix t with partial : shape is (1,)
289
- v_u_tfixed = vmap(partial(fun, t=t * jnp.ones((1,))), 0, 0)
295
+ v_u_tfixed = vmap(fun)
290
296
  # add an axis to xdata for the concatenate function in the neural net
291
- values = v_u_tfixed(x=xdata[:, None])
297
+ values = v_u_tfixed(t_xdata)
292
298
  elif spinn:
293
- values = jnp.squeeze(
294
- fun(t * jnp.ones((xdata.shape[0], 1)), xdata[..., None])[0]
295
- )
296
- plt.plot(xdata, values, label=f"$t_i={t * Tmax:.2f}$")
297
- plt.xlabel("x")
298
- plt.ylabel(r"$u(t_i, x)$")
299
- plt.legend()
300
- plt.title(title)
299
+ values = jnp.squeeze(fun(t_xdata)[0])
300
+ ax.plot(xdata, values, label=f"$t_i={t * Tmax:.2f}$")
301
+ ax.set_xlabel("x")
302
+ ax.set_ylabel(r"$u(t_i, x)$")
303
+ ax.legend()
304
+ ax.set_title(title)
305
+ return ax
301
306
 
302
307
 
303
308
  def plot1d_image(
@@ -310,7 +315,7 @@ def plot1d_image(
310
315
  colorbar: Bool = True,
311
316
  cmap: str = "inferno",
312
317
  spinn: Bool = False,
313
- vmin_vmax: tuple[float, float] = None,
318
+ vmin_vmax: tuple[float, float] = [None, None],
314
319
  ):
315
320
  """Function for plotting the 2-D image of a function :math:`f(t, x)` where
316
321
  `t` is time (1-D) and x is space (1-D).
@@ -336,33 +341,41 @@ def plot1d_image(
336
341
  cmap :
337
342
  the matplotlib color map used in the ImageGrid.
338
343
  spinn :
339
- True if a SPINN is to be plotted. False for PINNs and HYPERPINNs
344
+ True if a SPINN is to be plotted. False for PINNs and HyperPINNs
340
345
  vmin_vmax:
341
346
  The colorbar minimum and maximum value. Defaults None.
347
+
348
+ Returns
349
+ -------
350
+ fig, ax
351
+ A `matplotlib` `Figure` and `Axes` objects with the figure.
342
352
  """
343
353
 
344
354
  mesh = jnp.meshgrid(times, xdata) # cartesian product
345
355
  if not spinn:
346
356
  # the trick is to use _plot2Dstatio
347
- v_fun = vmap(lambda tx: fun(t=tx[0, None], x=tx[1, None]), 0, 0)
357
+ v_fun = vmap(fun) # lambda tx: fun(t=tx[0, None], x=tx[1, None]), 0, 0)
348
358
  t_grid, x_grid = mesh
349
359
  values_grid = v_fun(jnp.vstack([t_grid.flatten(), x_grid.flatten()]).T).reshape(
350
360
  t_grid.shape
351
361
  )
352
362
  elif spinn:
353
- values_grid = jnp.squeeze(fun((times[..., None]), xdata[..., None]).T)
363
+ values_grid = jnp.squeeze(
364
+ fun(jnp.concatenate([times[..., None], xdata[..., None]], axis=-1))
365
+ ).T
366
+
354
367
  fig, ax = plt.subplots(1, 1, figsize=figsize)
355
- if vmin_vmax is not None:
356
- im = ax.pcolormesh(
357
- mesh[0] * Tmax,
358
- mesh[1],
359
- values_grid,
360
- cmap=cmap,
361
- vmin=vmin_vmax[0],
362
- vmax=vmin_vmax[1],
363
- )
364
- else:
365
- im = ax.pcolormesh(mesh[0] * Tmax, mesh[1], values_grid, cmap=cmap)
368
+ im = ax.pcolormesh(
369
+ mesh[0] * Tmax,
370
+ mesh[1],
371
+ values_grid,
372
+ cmap=cmap,
373
+ vmin=vmin_vmax[0],
374
+ vmax=vmin_vmax[1],
375
+ )
376
+
366
377
  if colorbar:
367
378
  fig.colorbar(im, format="%0.2f")
368
379
  ax.set_title(title)
380
+
381
+ return fig, ax