jinns 0.8.10__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jinns/__init__.py +2 -0
- jinns/data/_Batchs.py +27 -0
- jinns/data/_DataGenerators.py +953 -1182
- jinns/data/__init__.py +4 -8
- jinns/experimental/__init__.py +0 -2
- jinns/experimental/_diffrax_solver.py +5 -5
- jinns/loss/_DynamicLoss.py +282 -305
- jinns/loss/_DynamicLossAbstract.py +321 -168
- jinns/loss/_LossODE.py +290 -307
- jinns/loss/_LossPDE.py +628 -1040
- jinns/loss/__init__.py +21 -5
- jinns/loss/_boundary_conditions.py +95 -96
- jinns/loss/{_Losses.py → _loss_utils.py} +104 -46
- jinns/loss/_loss_weights.py +59 -0
- jinns/loss/_operators.py +78 -72
- jinns/parameters/__init__.py +6 -0
- jinns/parameters/_derivative_keys.py +94 -0
- jinns/parameters/_params.py +115 -0
- jinns/plot/__init__.py +5 -0
- jinns/{data/_display.py → plot/_plot.py} +98 -75
- jinns/solver/_rar.py +193 -45
- jinns/solver/_solve.py +199 -144
- jinns/utils/__init__.py +3 -9
- jinns/utils/_containers.py +37 -43
- jinns/utils/_hyperpinn.py +226 -127
- jinns/utils/_pinn.py +183 -111
- jinns/utils/_save_load.py +121 -56
- jinns/utils/_spinn.py +117 -84
- jinns/utils/_types.py +64 -0
- jinns/utils/_utils.py +6 -160
- jinns/validation/_validation.py +52 -144
- {jinns-0.8.10.dist-info → jinns-1.0.0.dist-info}/METADATA +5 -4
- jinns-1.0.0.dist-info/RECORD +38 -0
- {jinns-0.8.10.dist-info → jinns-1.0.0.dist-info}/WHEEL +1 -1
- jinns/experimental/_sinuspinn.py +0 -135
- jinns/experimental/_spectralpinn.py +0 -87
- jinns/solver/_seq2seq.py +0 -157
- jinns/utils/_optim.py +0 -147
- jinns/utils/_utils_uspinn.py +0 -727
- jinns-0.8.10.dist-info/RECORD +0 -36
- {jinns-0.8.10.dist-info → jinns-1.0.0.dist-info}/LICENSE +0 -0
- {jinns-0.8.10.dist-info → jinns-1.0.0.dist-info}/top_level.txt +0 -0
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
"""
|
|
2
|
-
Utility functions for plotting
|
|
2
|
+
Utility functions for plotting in 1D and 2D, with and without time.
|
|
3
3
|
"""
|
|
4
4
|
|
|
5
5
|
from functools import partial
|
|
@@ -8,48 +8,55 @@ import matplotlib.pyplot as plt
|
|
|
8
8
|
import jax.numpy as jnp
|
|
9
9
|
from jax import vmap
|
|
10
10
|
from mpl_toolkits.axes_grid1 import ImageGrid
|
|
11
|
+
from typing import Callable, List
|
|
12
|
+
from jaxtyping import Array, Float, Bool
|
|
11
13
|
|
|
12
14
|
|
|
13
15
|
def plot2d(
|
|
14
|
-
fun,
|
|
15
|
-
xy_data,
|
|
16
|
-
times=None,
|
|
17
|
-
Tmax=1,
|
|
18
|
-
title="",
|
|
19
|
-
figsize=(7, 7),
|
|
20
|
-
cmap="inferno",
|
|
21
|
-
spinn=False,
|
|
22
|
-
vmin_vmax=None,
|
|
23
|
-
ax_for_plot=None,
|
|
16
|
+
fun: Callable,
|
|
17
|
+
xy_data: tuple[Float[Array, "nx"], Float[Array, "ny"]],
|
|
18
|
+
times: Float[Array, "nt"] | List[float] | None = None,
|
|
19
|
+
Tmax: float = 1,
|
|
20
|
+
title: str = "",
|
|
21
|
+
figsize: tuple = (7, 7),
|
|
22
|
+
cmap: str = "inferno",
|
|
23
|
+
spinn: bool = False,
|
|
24
|
+
vmin_vmax: tuple[float, float] | None = None,
|
|
25
|
+
ax_for_plot: plt.Axes | None = None,
|
|
24
26
|
):
|
|
25
27
|
r"""Generic function for plotting functions over rectangular 2-D domains
|
|
26
|
-
|
|
27
|
-
non-stationnary case :math:`u(t, x)`.
|
|
28
|
+
$\Omega$. It handles both the
|
|
28
29
|
|
|
29
|
-
|
|
30
|
-
|
|
30
|
+
1. the stationary case $u(x)$
|
|
31
|
+
2. the non-stationary case $u(t, x)$
|
|
32
|
+
|
|
33
|
+
In the non-stationary case, the `times` argument gives the time
|
|
34
|
+
slices $t_i$ at which to plot $u(t_i, x)$.
|
|
31
35
|
|
|
32
36
|
|
|
33
37
|
Parameters
|
|
34
38
|
----------
|
|
35
|
-
fun :
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
39
|
+
fun :
|
|
40
|
+
the function $u$ to plot on the meshgrid, and possibly the time
|
|
41
|
+
slices. It is supposed to have signature `u(x)` in the stationary case, and `u(t, x)` in the non-stationary case. Use `partial` or `lambda` to freeze / reorder any other arguments.
|
|
42
|
+
xy_data :
|
|
43
|
+
A list of 2 `jnp.Array` providing grid values for meshgrid creation
|
|
44
|
+
times :
|
|
45
|
+
list or Array of time slices where to plot the function. Use Tmax if
|
|
46
|
+
you trained with time-rescaling.
|
|
47
|
+
Tmax :
|
|
48
|
+
Useful if you used time rescaling in the differential equation for training, defaults to 1 (no rescaling).
|
|
49
|
+
title :
|
|
45
50
|
plot title, by default ""
|
|
46
|
-
figsize :
|
|
47
|
-
|
|
48
|
-
cmap :
|
|
49
|
-
|
|
50
|
-
vmin_vmax :
|
|
51
|
+
figsize :
|
|
52
|
+
By default (7, 7)
|
|
53
|
+
cmap :
|
|
54
|
+
the matplotlib color map used in the ImageGrid.
|
|
55
|
+
vmin_vmax :
|
|
51
56
|
The colorbar minimum and maximum value. Defaults None.
|
|
52
|
-
|
|
57
|
+
spinn :
|
|
58
|
+
True if a SPINN is to be plotted. False for PINNs and HYPERPINNs
|
|
59
|
+
ax_for_plot :
|
|
53
60
|
If None, jinns triggers the plotting. Otherwise this argument
|
|
54
61
|
corresponds to the axis which will host the plot. Default is None.
|
|
55
62
|
NOTE: this argument has an effect only if times is None.
|
|
@@ -57,7 +64,7 @@ def plot2d(
|
|
|
57
64
|
Raises
|
|
58
65
|
------
|
|
59
66
|
ValueError
|
|
60
|
-
if xy_data is not a list of
|
|
67
|
+
if xy_data is not a list of length 2
|
|
61
68
|
"""
|
|
62
69
|
|
|
63
70
|
# if not isinstance(xy_data, jnp.ndarray) and not xy_data.shape[-1] == 2:
|
|
@@ -179,13 +186,13 @@ def plot2d(
|
|
|
179
186
|
|
|
180
187
|
def _plot_2D_statio(
|
|
181
188
|
v_fun,
|
|
182
|
-
mesh,
|
|
183
|
-
plot=True,
|
|
184
|
-
colorbar=True,
|
|
185
|
-
cmap="inferno",
|
|
186
|
-
figsize=(7, 7),
|
|
187
|
-
spinn=False,
|
|
188
|
-
vmin_vmax=None,
|
|
189
|
+
mesh: Float[Array, "nx*ny nx*ny"],
|
|
190
|
+
plot: Bool = True,
|
|
191
|
+
colorbar: Bool = True,
|
|
192
|
+
cmap: str = "inferno",
|
|
193
|
+
figsize: tuple[int, int] = (7, 7),
|
|
194
|
+
spinn: Bool = False,
|
|
195
|
+
vmin_vmax: tuple[float, float] = None,
|
|
189
196
|
):
|
|
190
197
|
"""Function that plots the function u(x) with 2-D input x using pcolormesh()
|
|
191
198
|
|
|
@@ -200,6 +207,12 @@ def _plot_2D_statio(
|
|
|
200
207
|
either show or return the plot, by default True
|
|
201
208
|
colorbar : bool, optional
|
|
202
209
|
add a colorbar, by default True
|
|
210
|
+
cmap :
|
|
211
|
+
the matplotlib color map used in the ImageGrid.
|
|
212
|
+
figsize :
|
|
213
|
+
By default (7, 7)
|
|
214
|
+
spinn :
|
|
215
|
+
True if a SPINN is to be plotted. False for PINNs and HYPERPINNs
|
|
203
216
|
vmin_vmax: tuple, optional
|
|
204
217
|
The colorbar minimum and maximum value. Defaults None.
|
|
205
218
|
|
|
@@ -238,33 +251,37 @@ def _plot_2D_statio(
|
|
|
238
251
|
|
|
239
252
|
|
|
240
253
|
def plot1d_slice(
|
|
241
|
-
fun,
|
|
242
|
-
xdata,
|
|
243
|
-
time_slices
|
|
244
|
-
Tmax=1,
|
|
245
|
-
title="",
|
|
246
|
-
figsize=(10, 10),
|
|
247
|
-
spinn=False,
|
|
254
|
+
fun: Callable[[float, float], float],
|
|
255
|
+
xdata: Float[Array, "nx"],
|
|
256
|
+
time_slices: Float[Array, "nt"] | None = None,
|
|
257
|
+
Tmax: float = 1.0,
|
|
258
|
+
title: str = "",
|
|
259
|
+
figsize: tuple[int, int] = (10, 10),
|
|
260
|
+
spinn: Bool = False,
|
|
248
261
|
):
|
|
249
262
|
"""Function for plotting time slices of a function :math:`f(t_i, x)` where
|
|
250
|
-
`
|
|
263
|
+
`t_i` is time (1-D) and x is 1-D
|
|
251
264
|
|
|
252
265
|
Parameters
|
|
253
266
|
----------
|
|
254
|
-
fun
|
|
267
|
+
fun
|
|
255
268
|
f(t, x)
|
|
256
|
-
xdata
|
|
269
|
+
xdata
|
|
257
270
|
the discretization of space
|
|
258
|
-
time_slices
|
|
259
|
-
the time slices :math:`t_i` at which to plot
|
|
260
|
-
Tmax
|
|
271
|
+
time_slices
|
|
272
|
+
the time slices :math:`t_i` at which to plot.
|
|
273
|
+
Tmax
|
|
261
274
|
Useful if you used time re-scaling in the differential equation, by
|
|
262
275
|
default 1
|
|
263
|
-
title
|
|
276
|
+
title
|
|
264
277
|
title of the plot, by default ""
|
|
265
|
-
|
|
278
|
+
spinn :
|
|
279
|
+
True if a SPINN is to be plotted. False for PINNs and HYPERPINNs
|
|
280
|
+
figsize
|
|
266
281
|
size of the figure, by default (10, 10)
|
|
267
282
|
"""
|
|
283
|
+
if time_slices is None:
|
|
284
|
+
time_slices = jnp.array([0])
|
|
268
285
|
plt.figure(figsize=figsize)
|
|
269
286
|
for t in time_slices:
|
|
270
287
|
if not spinn:
|
|
@@ -284,16 +301,16 @@ def plot1d_slice(
|
|
|
284
301
|
|
|
285
302
|
|
|
286
303
|
def plot1d_image(
|
|
287
|
-
fun,
|
|
288
|
-
xdata,
|
|
289
|
-
times,
|
|
290
|
-
Tmax=1,
|
|
291
|
-
title="",
|
|
292
|
-
figsize=(10, 10),
|
|
293
|
-
colorbar=True,
|
|
294
|
-
cmap="inferno",
|
|
295
|
-
spinn=False,
|
|
296
|
-
vmin_vmax=None,
|
|
304
|
+
fun: Callable[[float, float], float],
|
|
305
|
+
xdata: Float[Array, "nx"],
|
|
306
|
+
times: Float[Array, "nt"],
|
|
307
|
+
Tmax: float = 1.0,
|
|
308
|
+
title: str = "",
|
|
309
|
+
figsize: tuple[int, int] = (10, 10),
|
|
310
|
+
colorbar: Bool = True,
|
|
311
|
+
cmap: str = "inferno",
|
|
312
|
+
spinn: Bool = False,
|
|
313
|
+
vmin_vmax: tuple[float, float] = None,
|
|
297
314
|
):
|
|
298
315
|
"""Function for plotting the 2-D image of a function :math:`f(t, x)` where
|
|
299
316
|
`t` is time (1-D) and x is space (1-D).
|
|
@@ -302,19 +319,25 @@ def plot1d_image(
|
|
|
302
319
|
|
|
303
320
|
Parameters
|
|
304
321
|
----------
|
|
305
|
-
fun :
|
|
306
|
-
the function to plot
|
|
307
|
-
xdata :
|
|
322
|
+
fun :
|
|
323
|
+
the function to plot, a callable with two arguments t and x
|
|
324
|
+
xdata :
|
|
308
325
|
the discretization of space
|
|
309
|
-
times :
|
|
326
|
+
times :
|
|
310
327
|
the discretization of time
|
|
311
|
-
Tmax :
|
|
312
|
-
|
|
313
|
-
title :
|
|
314
|
-
|
|
315
|
-
figsize :
|
|
316
|
-
|
|
317
|
-
|
|
328
|
+
Tmax :
|
|
329
|
+
by default 1
|
|
330
|
+
title :
|
|
331
|
+
by default ""
|
|
332
|
+
figsize :
|
|
333
|
+
by default (10, 10)
|
|
334
|
+
colorbar :
|
|
335
|
+
Whether to add a colorbar
|
|
336
|
+
cmap :
|
|
337
|
+
the matplotlib color map used in the ImageGrid.
|
|
338
|
+
spinn :
|
|
339
|
+
True if a SPINN is to be plotted. False for PINNs and HYPERPINNs
|
|
340
|
+
vmin_vmax:
|
|
318
341
|
The colorbar minimum and maximum value. Defaults None.
|
|
319
342
|
"""
|
|
320
343
|
|