jinns 0.8.10__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. jinns/__init__.py +2 -0
  2. jinns/data/_Batchs.py +27 -0
  3. jinns/data/_DataGenerators.py +953 -1182
  4. jinns/data/__init__.py +4 -8
  5. jinns/experimental/__init__.py +0 -2
  6. jinns/experimental/_diffrax_solver.py +5 -5
  7. jinns/loss/_DynamicLoss.py +282 -305
  8. jinns/loss/_DynamicLossAbstract.py +321 -168
  9. jinns/loss/_LossODE.py +290 -307
  10. jinns/loss/_LossPDE.py +628 -1040
  11. jinns/loss/__init__.py +21 -5
  12. jinns/loss/_boundary_conditions.py +95 -96
  13. jinns/loss/{_Losses.py → _loss_utils.py} +104 -46
  14. jinns/loss/_loss_weights.py +59 -0
  15. jinns/loss/_operators.py +78 -72
  16. jinns/parameters/__init__.py +6 -0
  17. jinns/parameters/_derivative_keys.py +94 -0
  18. jinns/parameters/_params.py +115 -0
  19. jinns/plot/__init__.py +5 -0
  20. jinns/{data/_display.py → plot/_plot.py} +98 -75
  21. jinns/solver/_rar.py +193 -45
  22. jinns/solver/_solve.py +199 -144
  23. jinns/utils/__init__.py +3 -9
  24. jinns/utils/_containers.py +37 -43
  25. jinns/utils/_hyperpinn.py +226 -127
  26. jinns/utils/_pinn.py +183 -111
  27. jinns/utils/_save_load.py +121 -56
  28. jinns/utils/_spinn.py +117 -84
  29. jinns/utils/_types.py +64 -0
  30. jinns/utils/_utils.py +6 -160
  31. jinns/validation/_validation.py +52 -144
  32. {jinns-0.8.10.dist-info → jinns-1.0.0.dist-info}/METADATA +5 -4
  33. jinns-1.0.0.dist-info/RECORD +38 -0
  34. {jinns-0.8.10.dist-info → jinns-1.0.0.dist-info}/WHEEL +1 -1
  35. jinns/experimental/_sinuspinn.py +0 -135
  36. jinns/experimental/_spectralpinn.py +0 -87
  37. jinns/solver/_seq2seq.py +0 -157
  38. jinns/utils/_optim.py +0 -147
  39. jinns/utils/_utils_uspinn.py +0 -727
  40. jinns-0.8.10.dist-info/RECORD +0 -36
  41. {jinns-0.8.10.dist-info → jinns-1.0.0.dist-info}/LICENSE +0 -0
  42. {jinns-0.8.10.dist-info → jinns-1.0.0.dist-info}/top_level.txt +0 -0
@@ -1,5 +1,5 @@
1
1
  """
2
- Utility functions for plotting
2
+ Utility functions for plotting in 1D and 2D, with and without time.
3
3
  """
4
4
 
5
5
  from functools import partial
@@ -8,48 +8,55 @@ import matplotlib.pyplot as plt
8
8
  import jax.numpy as jnp
9
9
  from jax import vmap
10
10
  from mpl_toolkits.axes_grid1 import ImageGrid
11
+ from typing import Callable, List
12
+ from jaxtyping import Array, Float, Bool
11
13
 
12
14
 
13
15
  def plot2d(
14
- fun,
15
- xy_data,
16
- times=None,
17
- Tmax=1,
18
- title="",
19
- figsize=(7, 7),
20
- cmap="inferno",
21
- spinn=False,
22
- vmin_vmax=None,
23
- ax_for_plot=None,
16
+ fun: Callable,
17
+ xy_data: tuple[Float[Array, "nx"], Float[Array, "ny"]],
18
+ times: Float[Array, "nt"] | List[float] | None = None,
19
+ Tmax: float = 1,
20
+ title: str = "",
21
+ figsize: tuple = (7, 7),
22
+ cmap: str = "inferno",
23
+ spinn: bool = False,
24
+ vmin_vmax: tuple[float, float] | None = None,
25
+ ax_for_plot: plt.Axes | None = None,
24
26
  ):
25
27
  r"""Generic function for plotting functions over rectangular 2-D domains
26
- :math:`\Omega`. It treats both the stationary case :math:`u(x)` or the
27
- non-stationnary case :math:`u(t, x)`.
28
+ $\Omega$. It handles both the
28
29
 
29
- When in the non-stationnary case, the `times` argument gives the time
30
- slices :math:`t_i` at which to plot :math:`u(t_i, x)`.
30
+ 1. the stationary case $u(x)$
31
+ 2. the non-stationnary case $u(t, x)$
32
+
33
+ In the non-stationnary case, the `times` argument gives the time
34
+ slices $t_i$ at which to plot $u(t_i, x)$.
31
35
 
32
36
 
33
37
  Parameters
34
38
  ----------
35
- fun : _type_
36
- _description_
37
- xy_data : _type_
38
- _description_
39
- times : _type_, optional
40
- _description_, by default None
41
- Tmax : float, only in non-stationary cases
42
- Useful if you used time re-scaling in the differential equation, by
43
- default 1
44
- title : str, optional
39
+ fun :
40
+ the function $u$ to plot on the meshgrid, and eventually the time
41
+ slices. It's suppose to have signature `u(x)` in the stationnary case,, and `u(t, x)` in the non-stationnary case. Use `partial` or `lambda to freeze / reorder any other arguments.
42
+ xy_data :
43
+ A list of 2 `jnp.Array` providing grid values for meshgrid creation
44
+ times :
45
+ list or Array of time slices where to plot the function. Use Tmax if
46
+ you trained with time-rescaling.
47
+ Tmax :
48
+ Useful if you used time rescaling in the differential equation for training, default to 1 (no rescaling).
49
+ title :
45
50
  plot title, by default ""
46
- figsize : tuple, optional
47
- _description_, by default (7, 7)
48
- cmap : str, optional
49
- _description_, by default "inferno"
50
- vmin_vmax : tuple, optional
51
+ figsize :
52
+ By default (7, 7)
53
+ cmap :
54
+ the matplotlib color map used in the ImageGrid.
55
+ vmin_vmax :
51
56
  The colorbar minimum and maximum value. Defaults None.
52
- ax_for_plot : Matplotlib axis, optional
57
+ spinn :
58
+ True if a SPINN is to be plotted. False for PINNs and HYPERPINNs
59
+ ax_for_plot :
53
60
  If None, jinns triggers the plotting. Otherwise this argument
54
61
  corresponds to the axis which will host the plot. Default is None.
55
62
  NOTE: that this argument will have an effect only if times is None.
@@ -57,7 +64,7 @@ def plot2d(
57
64
  Raises
58
65
  ------
59
66
  ValueError
60
- if xy_data is not a list of type 2
67
+ if xy_data is not a list of length 2
61
68
  """
62
69
 
63
70
  # if not isinstance(xy_data, jnp.ndarray) and not xy_data.shape[-1] == 2:
@@ -179,13 +186,13 @@ def plot2d(
179
186
 
180
187
  def _plot_2D_statio(
181
188
  v_fun,
182
- mesh,
183
- plot=True,
184
- colorbar=True,
185
- cmap="inferno",
186
- figsize=(7, 7),
187
- spinn=False,
188
- vmin_vmax=None,
189
+ mesh: Float[Array, "nx*ny nx*ny"],
190
+ plot: Bool = True,
191
+ colorbar: Bool = True,
192
+ cmap: str = "inferno",
193
+ figsize: tuple[int, int] = (7, 7),
194
+ spinn: Bool = False,
195
+ vmin_vmax: tuple[float, float] = None,
189
196
  ):
190
197
  """Function that plot the function u(x) with 2-D input x using pcolormesh()
191
198
 
@@ -200,6 +207,12 @@ def _plot_2D_statio(
200
207
  either show or return the plot, by default True
201
208
  colorbar : bool, optional
202
209
  add a colorbar, by default True
210
+ cmap :
211
+ the matplotlib color map used in the ImageGrid.
212
+ figsize :
213
+ By default (7, 7)
214
+ spinn :
215
+ True if a SPINN is to be plotted. False for PINNs and HYPERPINNs
203
216
  vmin_vmax: tuple, optional
204
217
  The colorbar minimum and maximum value. Defaults None.
205
218
 
@@ -238,33 +251,37 @@ def _plot_2D_statio(
238
251
 
239
252
 
240
253
  def plot1d_slice(
241
- fun,
242
- xdata,
243
- time_slices=jnp.array([0]),
244
- Tmax=1,
245
- title="",
246
- figsize=(10, 10),
247
- spinn=False,
254
+ fun: Callable[[float, float], float],
255
+ xdata: Float[Array, "nx"],
256
+ time_slices: Float[Array, "nt"] | None = None,
257
+ Tmax: float = 1.0,
258
+ title: str = "",
259
+ figsize: tuple[int, int] = (10, 10),
260
+ spinn: Bool = False,
248
261
  ):
249
262
  """Function for plotting time slices of a function :math:`f(t_i, x)` where
250
- `t` is time (1-D) and x is 1-D
263
+ `t_i` is time (1-D) and x is 1-D
251
264
 
252
265
  Parameters
253
266
  ----------
254
- fun : callable with two arguments `t` and `x`
267
+ fun
255
268
  f(t, x)
256
- xdata : jnp.array
269
+ xdata
257
270
  the discretization of space
258
- time_slices : list, optional
259
- the time slices :math:`t_i` at which to plot, by default [0]
260
- Tmax : int, optional
271
+ time_slices
272
+ the time slices :math:`t_i` at which to plot.
273
+ Tmax
261
274
  Useful if you used time re-scaling in the differential equation, by
262
275
  default 1
263
- title : str, optional
276
+ title
264
277
  title of the plot, by default ""
265
- figsize : tuple, optional
278
+ spinn :
279
+ True if a SPINN is to be plotted. False for PINNs and HYPERPINNs
280
+ figsize
266
281
  size of the figure, by default (10, 10)
267
282
  """
283
+ if time_slices is None:
284
+ time_slices = jnp.array([0])
268
285
  plt.figure(figsize=figsize)
269
286
  for t in time_slices:
270
287
  if not spinn:
@@ -284,16 +301,16 @@ def plot1d_slice(
284
301
 
285
302
 
286
303
  def plot1d_image(
287
- fun,
288
- xdata,
289
- times,
290
- Tmax=1,
291
- title="",
292
- figsize=(10, 10),
293
- colorbar=True,
294
- cmap="inferno",
295
- spinn=False,
296
- vmin_vmax=None,
304
+ fun: Callable[[float, float], float],
305
+ xdata: Float[Array, "nx"],
306
+ times: Float[Array, "nt"],
307
+ Tmax: float = 1.0,
308
+ title: str = "",
309
+ figsize: tuple[int, int] = (10, 10),
310
+ colorbar: Bool = True,
311
+ cmap: str = "inferno",
312
+ spinn: Bool = False,
313
+ vmin_vmax: tuple[float, float] = None,
297
314
  ):
298
315
  """Function for plotting the 2-D image of a function :math:`f(t, x)` where
299
316
  `t` is time (1-D) and x is space (1-D).
@@ -302,19 +319,25 @@ def plot1d_image(
302
319
 
303
320
  Parameters
304
321
  ----------
305
- fun : callable with two arguments t and x
306
- the function to plot
307
- xdata : jnp.array
322
+ fun :
323
+ callable with two arguments t and x the function to plot
324
+ xdata :
308
325
  the discretization of space
309
- times : jnp.array
326
+ times :
310
327
  the discretization of time
311
- Tmax : int, optional
312
- _description_, by default 1
313
- title : str, optional
314
- , by default ""
315
- figsize : tuple, optional
316
- , by default (10, 10)
317
- vmin_vmax: tuple
328
+ Tmax :
329
+ by default 1
330
+ title :
331
+ by default ""
332
+ figsize :
333
+ by default (10, 10)
334
+ colorbar :
335
+ Whether to add a colobar
336
+ cmap :
337
+ the matplotlib color map used in the ImageGrid.
338
+ spinn :
339
+ True if a SPINN is to be plotted. False for PINNs and HYPERPINNs
340
+ vmin_vmax:
318
341
  The colorbar minimum and maximum value. Defaults None.
319
342
  """
320
343