jinns 1.0.0__py3-none-any.whl → 1.2.0__py3-none-any.whl

This diff shows the changes between two publicly available versions of a package that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
jinns/plot/_plot.py CHANGED
@@ -21,8 +21,7 @@ def plot2d(
21
21
  figsize: tuple = (7, 7),
22
22
  cmap: str = "inferno",
23
23
  spinn: bool = False,
24
- vmin_vmax: tuple[float, float] | None = None,
25
- ax_for_plot: plt.Axes | None = None,
24
+ vmin_vmax: tuple[float, float] = [None, None],
26
25
  ):
27
26
  r"""Generic function for plotting functions over rectangular 2-D domains
28
27
  $\Omega$. It handles both the
@@ -55,11 +54,7 @@ def plot2d(
55
54
  vmin_vmax :
56
55
  The colorbar minimum and maximum value. Defaults None.
57
56
  spinn :
58
- True if a SPINN is to be plotted. False for PINNs and HYPERPINNs
59
- ax_for_plot :
60
- If None, jinns triggers the plotting. Otherwise this argument
61
- corresponds to the axis which will host the plot. Default is None.
62
- NOTE: that this argument will have an effect only if times is None.
57
+ True if the function is a `SPINN` object.
63
58
 
64
59
  Raises
65
60
  ------
@@ -73,55 +68,48 @@ def plot2d(
73
68
  "xy_data must be a list of length 2 containing"
74
69
  "jnp.array of shape (nx,) and (ny,)."
75
70
  )
71
+
76
72
  mesh = jnp.meshgrid(xy_data[0], xy_data[1]) # cartesian product
77
73
 
78
74
  if times is None:
79
75
  # Statio case : expect a function of one argument fun(x)
76
+
80
77
  if not spinn:
81
78
  v_fun = vmap(fun, 0, 0)
82
79
  ret = _plot_2D_statio(
83
80
  v_fun,
84
81
  mesh,
85
- plot=not ax_for_plot,
86
82
  colorbar=True,
87
83
  cmap=cmap,
88
84
  figsize=figsize,
89
85
  vmin_vmax=vmin_vmax,
90
86
  )
91
87
  elif spinn:
92
- values_grid = jnp.squeeze(
93
- fun(jnp.stack([xy_data[0][..., None], xy_data[1][..., None]], axis=1))
94
- )
88
+ values_grid = jnp.squeeze(fun(jnp.stack([xy_data[0], xy_data[1]], axis=1)))
95
89
  ret = _plot_2D_statio(
96
90
  values_grid,
97
91
  mesh,
98
- plot=not ax_for_plot,
99
92
  colorbar=True,
100
93
  cmap=cmap,
101
- spinn=True,
102
94
  figsize=figsize,
103
95
  vmin_vmax=vmin_vmax,
104
96
  )
105
- if not ax_for_plot:
106
- plt.title(title)
107
97
  else:
108
- if vmin_vmax is not None:
109
- im = ax_for_plot.pcolormesh(
110
- mesh[0],
111
- mesh[1],
112
- ret[0],
113
- cmap=cmap,
114
- vmin=vmin_vmax[0],
115
- vmax=vmin_vmax[1],
116
- )
117
- else:
118
- im = ax_for_plot.pcolormesh(mesh[0], mesh[1], ret[0], cmap=cmap)
119
- ax_for_plot.set_title(title)
120
- ax_for_plot.cax.colorbar(im, format="%0.2f")
98
+ fig, ax = plt.subplots(1, 1, figsize=figsize)
99
+
100
+ im = ax.pcolormesh(
101
+ mesh[0],
102
+ mesh[1],
103
+ ret[0],
104
+ cmap=cmap,
105
+ vmin=vmin_vmax[0],
106
+ vmax=vmin_vmax[1],
107
+ )
108
+
109
+ ax.set_title(title)
110
+ fig.cax.colorbar(im, format="%0.2f")
121
111
 
122
112
  else:
123
- if ax_for_plot is not None:
124
- warnings.warn("ax_for_plot is ignored. jinns will plot the figure")
125
113
  if not isinstance(times, list):
126
114
  try:
127
115
  times = times.tolist()
@@ -143,56 +131,64 @@ def plot2d(
143
131
 
144
132
  for idx, (t, ax) in enumerate(zip(times, grid)):
145
133
  if not spinn:
146
- v_fun_at_t = vmap(lambda x: fun(t=jnp.array([t]), x=x), 0, 0)
134
+ x_grid, y_grid = mesh
135
+ v_fun_at_t = vmap(fun)(
136
+ jnp.concatenate(
137
+ [
138
+ t
139
+ * jnp.ones((xy_data[0].shape[0] * xy_data[1].shape[0], 1)),
140
+ jnp.vstack([x_grid.flatten(), y_grid.flatten()]).T,
141
+ ],
142
+ axis=-1,
143
+ )
144
+ )
147
145
  t_slice, _ = _plot_2D_statio(
148
146
  v_fun_at_t,
149
147
  mesh,
150
- plot=False,
148
+ plot=False, # only use to compute t_slice
151
149
  colorbar=False,
152
150
  cmap=None,
153
151
  vmin_vmax=vmin_vmax,
154
152
  )
155
153
  elif spinn:
156
- values_grid = jnp.squeeze(
157
- fun(
154
+ t_x = jnp.concatenate(
155
+ [
158
156
  t * jnp.ones((xy_data[0].shape[0], 1)),
159
- jnp.stack(
160
- [xy_data[0][..., None], xy_data[1][..., None]], axis=1
157
+ jnp.concatenate(
158
+ [xy_data[0][..., None], xy_data[1][..., None]], axis=-1
161
159
  ),
162
- )[0]
160
+ ],
161
+ axis=-1,
163
162
  )
163
+ values_grid = jnp.squeeze(fun(t_x)[0]).T
164
164
  t_slice, _ = _plot_2D_statio(
165
165
  values_grid,
166
166
  mesh,
167
- plot=False,
167
+ plot=False, # only use to compute t_slice
168
168
  colorbar=True,
169
- spinn=True,
170
169
  vmin_vmax=vmin_vmax,
171
170
  )
172
- if vmin_vmax is not None:
173
- im = ax.pcolormesh(
174
- mesh[0],
175
- mesh[1],
176
- t_slice,
177
- cmap=cmap,
178
- vmin=vmin_vmax[0],
179
- vmax=vmin_vmax[1],
180
- )
181
- else:
182
- im = ax.pcolormesh(mesh[0], mesh[1], t_slice, cmap=cmap)
171
+
172
+ im = ax.pcolormesh(
173
+ mesh[0],
174
+ mesh[1],
175
+ t_slice,
176
+ cmap=cmap,
177
+ vmin=vmin_vmax[0],
178
+ vmax=vmin_vmax[1],
179
+ )
183
180
  ax.set_title(f"t = {times[idx] * Tmax:.2f}")
184
181
  ax.cax.colorbar(im, format="%0.2f")
185
182
 
186
183
 
187
184
  def _plot_2D_statio(
188
- v_fun,
185
+ v_fun: Callable | Float[Array, "(nx*ny)^2 1"],
189
186
  mesh: Float[Array, "nx*ny nx*ny"],
190
187
  plot: Bool = True,
191
188
  colorbar: Bool = True,
192
189
  cmap: str = "inferno",
193
190
  figsize: tuple[int, int] = (7, 7),
194
- spinn: Bool = False,
195
- vmin_vmax: tuple[float, float] = None,
191
+ vmin_vmax: tuple[float, float] = [None, None],
196
192
  ):
197
193
  """Function that plot the function u(x) with 2-D input x using pcolormesh()
198
194
 
@@ -200,11 +196,11 @@ def _plot_2D_statio(
200
196
  Parameters
201
197
  ----------
202
198
  v_fun :
203
- a vmapped function over jnp.array of shape (*, 2)
199
+ a vmapped function over jnp.array of shape (*, 2) OR a precomputed array of function values with shape compatible with `mesh`.
204
200
  mesh :
205
201
  a tuple of size 2, containing the x and y meshgrid.
206
202
  plot : bool, optional
207
- either show or return the plot, by default True
203
+ either displays the plot, or silently returns the grid of values `v_fun(mesh)`.
208
204
  colorbar : bool, optional
209
205
  add a colorbar, by default True
210
206
  cmap :
@@ -213,7 +209,7 @@ def _plot_2D_statio(
213
209
  By default (7, 7)
214
210
  spinn :
215
211
  True if a SPINN is to be plotted. False for PINNs and HYPERPINNs
216
- vmin_vmax: tuple, optional
212
+ vmin_vmax: list, optional
217
213
  The colorbar minimum and maximum value. Defaults None.
218
214
 
219
215
  Returns
@@ -223,26 +219,23 @@ def _plot_2D_statio(
223
219
  """
224
220
 
225
221
  x_grid, y_grid = mesh
226
- if not spinn:
222
+ if callable(v_fun):
227
223
  values = v_fun(jnp.vstack([x_grid.flatten(), y_grid.flatten()]).T)
228
224
  values_grid = values.reshape(x_grid.shape)
229
- elif spinn:
230
- # in this case v_fun is directly the values :)
231
- values_grid = v_fun.T
225
+ else:
226
+ values_grid = v_fun.reshape(x_grid.shape)
232
227
 
233
228
  if plot:
234
229
  fig = plt.figure(figsize=figsize)
235
- if vmin_vmax is not None:
236
- im = plt.pcolormesh(
237
- x_grid,
238
- y_grid,
239
- values_grid,
240
- cmap=cmap,
241
- vmin=vmin_vmax[0],
242
- vmax=vmin_vmax[1],
243
- )
244
- else:
245
- im = plt.pcolormesh(x_grid, y_grid, values_grid, cmap=cmap)
230
+ im = plt.pcolormesh(
231
+ x_grid,
232
+ y_grid,
233
+ values_grid,
234
+ cmap=cmap,
235
+ vmin=vmin_vmax[0],
236
+ vmax=vmin_vmax[1],
237
+ )
238
+
246
239
  if colorbar:
247
240
  fig.colorbar(im, format="%0.2f")
248
241
  # don't plt.show() because it is done in plot2d()
@@ -258,6 +251,7 @@ def plot1d_slice(
258
251
  title: str = "",
259
252
  figsize: tuple[int, int] = (10, 10),
260
253
  spinn: Bool = False,
254
+ ax=None,
261
255
  ):
262
256
  """Function for plotting time slices of a function :math:`f(t_i, x)` where
263
257
  `t_i` is time (1-D) and x is 1-D
@@ -275,29 +269,40 @@ def plot1d_slice(
275
269
  default 1
276
270
  title
277
271
  title of the plot, by default ""
278
- spinn :
279
- True if a SPINN is to be plotted. False for PINNs and HYPERPINNs
280
272
  figsize
281
273
  size of the figure, by default (10, 10)
274
+ spinn
275
+ True if a SPINN is to be plotted. False for PINNs and HYPERPINNs
276
+ ax
277
+ A pre-defined `matplotlib.Axes` where you want to plot.
278
+
279
+ Returns
280
+ -------
281
+ ax
282
+ A `matplotlib.Axes` object
282
283
  """
283
284
  if time_slices is None:
284
285
  time_slices = jnp.array([0])
285
- plt.figure(figsize=figsize)
286
+ if ax is None:
287
+ fig, ax = plt.subplots(figsize=figsize)
288
+
286
289
  for t in time_slices:
290
+ t_xdata = jnp.concatenate(
291
+ [t * jnp.ones((xdata.shape[0], 1)), xdata[:, None]], axis=1
292
+ )
287
293
  if not spinn:
288
294
  # fix t with partial : shape is (1,)
289
- v_u_tfixed = vmap(partial(fun, t=t * jnp.ones((1,))), 0, 0)
295
+ v_u_tfixed = vmap(fun)
290
296
  # add an axis to xdata for the concatenate function in the neural net
291
- values = v_u_tfixed(x=xdata[:, None])
297
+ values = v_u_tfixed(t_xdata)
292
298
  elif spinn:
293
- values = jnp.squeeze(
294
- fun(t * jnp.ones((xdata.shape[0], 1)), xdata[..., None])[0]
295
- )
296
- plt.plot(xdata, values, label=f"$t_i={t * Tmax:.2f}$")
297
- plt.xlabel("x")
298
- plt.ylabel(r"$u(t_i, x)$")
299
- plt.legend()
300
- plt.title(title)
299
+ values = jnp.squeeze(fun(t_xdata)[0])
300
+ ax.plot(xdata, values, label=f"$t_i={t * Tmax:.2f}$")
301
+ ax.set_xlabel("x")
302
+ ax.set_ylabel(r"$u(t_i, x)$")
303
+ ax.legend()
304
+ ax.set_title(title)
305
+ return ax
301
306
 
302
307
 
303
308
  def plot1d_image(
@@ -310,7 +315,7 @@ def plot1d_image(
310
315
  colorbar: Bool = True,
311
316
  cmap: str = "inferno",
312
317
  spinn: Bool = False,
313
- vmin_vmax: tuple[float, float] = None,
318
+ vmin_vmax: tuple[float, float] = [None, None],
314
319
  ):
315
320
  """Function for plotting the 2-D image of a function :math:`f(t, x)` where
316
321
  `t` is time (1-D) and x is space (1-D).
@@ -339,30 +344,38 @@ def plot1d_image(
339
344
  True if a SPINN is to be plotted. False for PINNs and HYPERPINNs
340
345
  vmin_vmax:
341
346
  The colorbar minimum and maximum value. Defaults None.
347
+
348
+ Returns
349
+ -------
350
+ fig, ax
351
+ A `matplotlib` `Figure` and `Axes` objects with the figure.
342
352
  """
343
353
 
344
354
  mesh = jnp.meshgrid(times, xdata) # cartesian product
345
355
  if not spinn:
346
356
  # the trick is to use _plot2Dstatio
347
- v_fun = vmap(lambda tx: fun(t=tx[0, None], x=tx[1, None]), 0, 0)
357
+ v_fun = vmap(fun) # lambda tx: fun(t=tx[0, None], x=tx[1, None]), 0, 0)
348
358
  t_grid, x_grid = mesh
349
359
  values_grid = v_fun(jnp.vstack([t_grid.flatten(), x_grid.flatten()]).T).reshape(
350
360
  t_grid.shape
351
361
  )
352
362
  elif spinn:
353
- values_grid = jnp.squeeze(fun((times[..., None]), xdata[..., None]).T)
363
+ values_grid = jnp.squeeze(
364
+ fun(jnp.concatenate([times[..., None], xdata[..., None]], axis=-1))
365
+ ).T
366
+
354
367
  fig, ax = plt.subplots(1, 1, figsize=figsize)
355
- if vmin_vmax is not None:
356
- im = ax.pcolormesh(
357
- mesh[0] * Tmax,
358
- mesh[1],
359
- values_grid,
360
- cmap=cmap,
361
- vmin=vmin_vmax[0],
362
- vmax=vmin_vmax[1],
363
- )
364
- else:
365
- im = ax.pcolormesh(mesh[0] * Tmax, mesh[1], values_grid, cmap=cmap)
368
+ im = ax.pcolormesh(
369
+ mesh[0] * Tmax,
370
+ mesh[1],
371
+ values_grid,
372
+ cmap=cmap,
373
+ vmin=vmin_vmax[0],
374
+ vmax=vmin_vmax[1],
375
+ )
376
+
366
377
  if colorbar:
367
378
  fig.colorbar(im, format="%0.2f")
368
379
  ax.set_title(title)
380
+
381
+ return fig, ax