jinns 1.0.0__py3-none-any.whl → 1.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jinns/data/_Batchs.py +4 -8
- jinns/data/_DataGenerators.py +532 -341
- jinns/loss/_DynamicLoss.py +150 -173
- jinns/loss/_DynamicLossAbstract.py +27 -73
- jinns/loss/_LossODE.py +45 -26
- jinns/loss/_LossPDE.py +85 -84
- jinns/loss/__init__.py +7 -6
- jinns/loss/_boundary_conditions.py +148 -279
- jinns/loss/_loss_utils.py +85 -58
- jinns/loss/_operators.py +441 -184
- jinns/parameters/_derivative_keys.py +487 -60
- jinns/plot/_plot.py +111 -98
- jinns/solver/_rar.py +102 -407
- jinns/solver/_solve.py +73 -38
- jinns/solver/_utils.py +122 -0
- jinns/utils/__init__.py +2 -0
- jinns/utils/_containers.py +3 -1
- jinns/utils/_hyperpinn.py +17 -7
- jinns/utils/_pinn.py +17 -27
- jinns/utils/_ppinn.py +227 -0
- jinns/utils/_save_load.py +13 -13
- jinns/utils/_spinn.py +24 -43
- jinns/utils/_types.py +1 -0
- jinns/utils/_utils.py +40 -12
- jinns-1.2.0.dist-info/AUTHORS +2 -0
- jinns-1.2.0.dist-info/METADATA +127 -0
- jinns-1.2.0.dist-info/RECORD +41 -0
- {jinns-1.0.0.dist-info → jinns-1.2.0.dist-info}/WHEEL +1 -1
- jinns-1.0.0.dist-info/METADATA +0 -84
- jinns-1.0.0.dist-info/RECORD +0 -38
- {jinns-1.0.0.dist-info → jinns-1.2.0.dist-info}/LICENSE +0 -0
- {jinns-1.0.0.dist-info → jinns-1.2.0.dist-info}/top_level.txt +0 -0
jinns/plot/_plot.py
CHANGED
|
@@ -21,8 +21,7 @@ def plot2d(
|
|
|
21
21
|
figsize: tuple = (7, 7),
|
|
22
22
|
cmap: str = "inferno",
|
|
23
23
|
spinn: bool = False,
|
|
24
|
-
vmin_vmax: tuple[float, float]
|
|
25
|
-
ax_for_plot: plt.Axes | None = None,
|
|
24
|
+
vmin_vmax: tuple[float, float] = [None, None],
|
|
26
25
|
):
|
|
27
26
|
r"""Generic function for plotting functions over rectangular 2-D domains
|
|
28
27
|
$\Omega$. It handles both the
|
|
@@ -55,11 +54,7 @@ def plot2d(
|
|
|
55
54
|
vmin_vmax :
|
|
56
55
|
The colorbar minimum and maximum value. Defaults None.
|
|
57
56
|
spinn :
|
|
58
|
-
True if
|
|
59
|
-
ax_for_plot :
|
|
60
|
-
If None, jinns triggers the plotting. Otherwise this argument
|
|
61
|
-
corresponds to the axis which will host the plot. Default is None.
|
|
62
|
-
NOTE: that this argument will have an effect only if times is None.
|
|
57
|
+
True if the function is a `SPINN` object.
|
|
63
58
|
|
|
64
59
|
Raises
|
|
65
60
|
------
|
|
@@ -73,55 +68,48 @@ def plot2d(
|
|
|
73
68
|
"xy_data must be a list of length 2 containing"
|
|
74
69
|
"jnp.array of shape (nx,) and (ny,)."
|
|
75
70
|
)
|
|
71
|
+
|
|
76
72
|
mesh = jnp.meshgrid(xy_data[0], xy_data[1]) # cartesian product
|
|
77
73
|
|
|
78
74
|
if times is None:
|
|
79
75
|
# Statio case : expect a function of one argument fun(x)
|
|
76
|
+
|
|
80
77
|
if not spinn:
|
|
81
78
|
v_fun = vmap(fun, 0, 0)
|
|
82
79
|
ret = _plot_2D_statio(
|
|
83
80
|
v_fun,
|
|
84
81
|
mesh,
|
|
85
|
-
plot=not ax_for_plot,
|
|
86
82
|
colorbar=True,
|
|
87
83
|
cmap=cmap,
|
|
88
84
|
figsize=figsize,
|
|
89
85
|
vmin_vmax=vmin_vmax,
|
|
90
86
|
)
|
|
91
87
|
elif spinn:
|
|
92
|
-
values_grid = jnp.squeeze(
|
|
93
|
-
fun(jnp.stack([xy_data[0][..., None], xy_data[1][..., None]], axis=1))
|
|
94
|
-
)
|
|
88
|
+
values_grid = jnp.squeeze(fun(jnp.stack([xy_data[0], xy_data[1]], axis=1)))
|
|
95
89
|
ret = _plot_2D_statio(
|
|
96
90
|
values_grid,
|
|
97
91
|
mesh,
|
|
98
|
-
plot=not ax_for_plot,
|
|
99
92
|
colorbar=True,
|
|
100
93
|
cmap=cmap,
|
|
101
|
-
spinn=True,
|
|
102
94
|
figsize=figsize,
|
|
103
95
|
vmin_vmax=vmin_vmax,
|
|
104
96
|
)
|
|
105
|
-
if not ax_for_plot:
|
|
106
|
-
plt.title(title)
|
|
107
97
|
else:
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
98
|
+
fig, ax = plt.subplots(1, 1, figsize=figsize)
|
|
99
|
+
|
|
100
|
+
im = ax.pcolormesh(
|
|
101
|
+
mesh[0],
|
|
102
|
+
mesh[1],
|
|
103
|
+
ret[0],
|
|
104
|
+
cmap=cmap,
|
|
105
|
+
vmin=vmin_vmax[0],
|
|
106
|
+
vmax=vmin_vmax[1],
|
|
107
|
+
)
|
|
108
|
+
|
|
109
|
+
ax.set_title(title)
|
|
110
|
+
fig.cax.colorbar(im, format="%0.2f")
|
|
121
111
|
|
|
122
112
|
else:
|
|
123
|
-
if ax_for_plot is not None:
|
|
124
|
-
warnings.warn("ax_for_plot is ignored. jinns will plot the figure")
|
|
125
113
|
if not isinstance(times, list):
|
|
126
114
|
try:
|
|
127
115
|
times = times.tolist()
|
|
@@ -143,56 +131,64 @@ def plot2d(
|
|
|
143
131
|
|
|
144
132
|
for idx, (t, ax) in enumerate(zip(times, grid)):
|
|
145
133
|
if not spinn:
|
|
146
|
-
|
|
134
|
+
x_grid, y_grid = mesh
|
|
135
|
+
v_fun_at_t = vmap(fun)(
|
|
136
|
+
jnp.concatenate(
|
|
137
|
+
[
|
|
138
|
+
t
|
|
139
|
+
* jnp.ones((xy_data[0].shape[0] * xy_data[1].shape[0], 1)),
|
|
140
|
+
jnp.vstack([x_grid.flatten(), y_grid.flatten()]).T,
|
|
141
|
+
],
|
|
142
|
+
axis=-1,
|
|
143
|
+
)
|
|
144
|
+
)
|
|
147
145
|
t_slice, _ = _plot_2D_statio(
|
|
148
146
|
v_fun_at_t,
|
|
149
147
|
mesh,
|
|
150
|
-
plot=False,
|
|
148
|
+
plot=False, # only use to compute t_slice
|
|
151
149
|
colorbar=False,
|
|
152
150
|
cmap=None,
|
|
153
151
|
vmin_vmax=vmin_vmax,
|
|
154
152
|
)
|
|
155
153
|
elif spinn:
|
|
156
|
-
|
|
157
|
-
|
|
154
|
+
t_x = jnp.concatenate(
|
|
155
|
+
[
|
|
158
156
|
t * jnp.ones((xy_data[0].shape[0], 1)),
|
|
159
|
-
jnp.
|
|
160
|
-
[xy_data[0][..., None], xy_data[1][..., None]], axis
|
|
157
|
+
jnp.concatenate(
|
|
158
|
+
[xy_data[0][..., None], xy_data[1][..., None]], axis=-1
|
|
161
159
|
),
|
|
162
|
-
|
|
160
|
+
],
|
|
161
|
+
axis=-1,
|
|
163
162
|
)
|
|
163
|
+
values_grid = jnp.squeeze(fun(t_x)[0]).T
|
|
164
164
|
t_slice, _ = _plot_2D_statio(
|
|
165
165
|
values_grid,
|
|
166
166
|
mesh,
|
|
167
|
-
plot=False,
|
|
167
|
+
plot=False, # only use to compute t_slice
|
|
168
168
|
colorbar=True,
|
|
169
|
-
spinn=True,
|
|
170
169
|
vmin_vmax=vmin_vmax,
|
|
171
170
|
)
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
else:
|
|
182
|
-
im = ax.pcolormesh(mesh[0], mesh[1], t_slice, cmap=cmap)
|
|
171
|
+
|
|
172
|
+
im = ax.pcolormesh(
|
|
173
|
+
mesh[0],
|
|
174
|
+
mesh[1],
|
|
175
|
+
t_slice,
|
|
176
|
+
cmap=cmap,
|
|
177
|
+
vmin=vmin_vmax[0],
|
|
178
|
+
vmax=vmin_vmax[1],
|
|
179
|
+
)
|
|
183
180
|
ax.set_title(f"t = {times[idx] * Tmax:.2f}")
|
|
184
181
|
ax.cax.colorbar(im, format="%0.2f")
|
|
185
182
|
|
|
186
183
|
|
|
187
184
|
def _plot_2D_statio(
|
|
188
|
-
v_fun,
|
|
185
|
+
v_fun: Callable | Float[Array, "(nx*ny)^2 1"],
|
|
189
186
|
mesh: Float[Array, "nx*ny nx*ny"],
|
|
190
187
|
plot: Bool = True,
|
|
191
188
|
colorbar: Bool = True,
|
|
192
189
|
cmap: str = "inferno",
|
|
193
190
|
figsize: tuple[int, int] = (7, 7),
|
|
194
|
-
|
|
195
|
-
vmin_vmax: tuple[float, float] = None,
|
|
191
|
+
vmin_vmax: tuple[float, float] = [None, None],
|
|
196
192
|
):
|
|
197
193
|
"""Function that plot the function u(x) with 2-D input x using pcolormesh()
|
|
198
194
|
|
|
@@ -200,11 +196,11 @@ def _plot_2D_statio(
|
|
|
200
196
|
Parameters
|
|
201
197
|
----------
|
|
202
198
|
v_fun :
|
|
203
|
-
a vmapped function over jnp.array of shape (*, 2)
|
|
199
|
+
a vmapped function over jnp.array of shape (*, 2) OR a precomputed array of function values with shape compatible with `mesh`.
|
|
204
200
|
mesh :
|
|
205
201
|
a tuple of size 2, containing the x and y meshgrid.
|
|
206
202
|
plot : bool, optional
|
|
207
|
-
either
|
|
203
|
+
either displays the plot, or silently returns the grid of values `v_fun(mesh)`.
|
|
208
204
|
colorbar : bool, optional
|
|
209
205
|
add a colorbar, by default True
|
|
210
206
|
cmap :
|
|
@@ -213,7 +209,7 @@ def _plot_2D_statio(
|
|
|
213
209
|
By default (7, 7)
|
|
214
210
|
spinn :
|
|
215
211
|
True if a SPINN is to be plotted. False for PINNs and HYPERPINNs
|
|
216
|
-
vmin_vmax:
|
|
212
|
+
vmin_vmax: list, optional
|
|
217
213
|
The colorbar minimum and maximum value. Defaults None.
|
|
218
214
|
|
|
219
215
|
Returns
|
|
@@ -223,26 +219,23 @@ def _plot_2D_statio(
|
|
|
223
219
|
"""
|
|
224
220
|
|
|
225
221
|
x_grid, y_grid = mesh
|
|
226
|
-
if
|
|
222
|
+
if callable(v_fun):
|
|
227
223
|
values = v_fun(jnp.vstack([x_grid.flatten(), y_grid.flatten()]).T)
|
|
228
224
|
values_grid = values.reshape(x_grid.shape)
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
values_grid = v_fun.T
|
|
225
|
+
else:
|
|
226
|
+
values_grid = v_fun.reshape(x_grid.shape)
|
|
232
227
|
|
|
233
228
|
if plot:
|
|
234
229
|
fig = plt.figure(figsize=figsize)
|
|
235
|
-
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
else:
|
|
245
|
-
im = plt.pcolormesh(x_grid, y_grid, values_grid, cmap=cmap)
|
|
230
|
+
im = plt.pcolormesh(
|
|
231
|
+
x_grid,
|
|
232
|
+
y_grid,
|
|
233
|
+
values_grid,
|
|
234
|
+
cmap=cmap,
|
|
235
|
+
vmin=vmin_vmax[0],
|
|
236
|
+
vmax=vmin_vmax[1],
|
|
237
|
+
)
|
|
238
|
+
|
|
246
239
|
if colorbar:
|
|
247
240
|
fig.colorbar(im, format="%0.2f")
|
|
248
241
|
# don't plt.show() because it is done in plot2d()
|
|
@@ -258,6 +251,7 @@ def plot1d_slice(
|
|
|
258
251
|
title: str = "",
|
|
259
252
|
figsize: tuple[int, int] = (10, 10),
|
|
260
253
|
spinn: Bool = False,
|
|
254
|
+
ax=None,
|
|
261
255
|
):
|
|
262
256
|
"""Function for plotting time slices of a function :math:`f(t_i, x)` where
|
|
263
257
|
`t_i` is time (1-D) and x is 1-D
|
|
@@ -275,29 +269,40 @@ def plot1d_slice(
|
|
|
275
269
|
default 1
|
|
276
270
|
title
|
|
277
271
|
title of the plot, by default ""
|
|
278
|
-
spinn :
|
|
279
|
-
True if a SPINN is to be plotted. False for PINNs and HYPERPINNs
|
|
280
272
|
figsize
|
|
281
273
|
size of the figure, by default (10, 10)
|
|
274
|
+
spinn
|
|
275
|
+
True if a SPINN is to be plotted. False for PINNs and HYPERPINNs
|
|
276
|
+
ax
|
|
277
|
+
A pre-defined `matplotlib.Axes` where you want to plot.
|
|
278
|
+
|
|
279
|
+
Returns
|
|
280
|
+
-------
|
|
281
|
+
ax
|
|
282
|
+
A `matplotlib.Axes` object
|
|
282
283
|
"""
|
|
283
284
|
if time_slices is None:
|
|
284
285
|
time_slices = jnp.array([0])
|
|
285
|
-
|
|
286
|
+
if ax is None:
|
|
287
|
+
fig, ax = plt.subplots(figsize=figsize)
|
|
288
|
+
|
|
286
289
|
for t in time_slices:
|
|
290
|
+
t_xdata = jnp.concatenate(
|
|
291
|
+
[t * jnp.ones((xdata.shape[0], 1)), xdata[:, None]], axis=1
|
|
292
|
+
)
|
|
287
293
|
if not spinn:
|
|
288
294
|
# fix t with partial : shape is (1,)
|
|
289
|
-
v_u_tfixed = vmap(
|
|
295
|
+
v_u_tfixed = vmap(fun)
|
|
290
296
|
# add an axis to xdata for the concatenate function in the neural net
|
|
291
|
-
values = v_u_tfixed(
|
|
297
|
+
values = v_u_tfixed(t_xdata)
|
|
292
298
|
elif spinn:
|
|
293
|
-
values = jnp.squeeze(
|
|
294
|
-
|
|
295
|
-
|
|
296
|
-
|
|
297
|
-
|
|
298
|
-
|
|
299
|
-
|
|
300
|
-
plt.title(title)
|
|
299
|
+
values = jnp.squeeze(fun(t_xdata)[0])
|
|
300
|
+
ax.plot(xdata, values, label=f"$t_i={t * Tmax:.2f}$")
|
|
301
|
+
ax.set_xlabel("x")
|
|
302
|
+
ax.set_ylabel(r"$u(t_i, x)$")
|
|
303
|
+
ax.legend()
|
|
304
|
+
ax.set_title(title)
|
|
305
|
+
return ax
|
|
301
306
|
|
|
302
307
|
|
|
303
308
|
def plot1d_image(
|
|
@@ -310,7 +315,7 @@ def plot1d_image(
|
|
|
310
315
|
colorbar: Bool = True,
|
|
311
316
|
cmap: str = "inferno",
|
|
312
317
|
spinn: Bool = False,
|
|
313
|
-
vmin_vmax: tuple[float, float] = None,
|
|
318
|
+
vmin_vmax: tuple[float, float] = [None, None],
|
|
314
319
|
):
|
|
315
320
|
"""Function for plotting the 2-D image of a function :math:`f(t, x)` where
|
|
316
321
|
`t` is time (1-D) and x is space (1-D).
|
|
@@ -339,30 +344,38 @@ def plot1d_image(
|
|
|
339
344
|
True if a SPINN is to be plotted. False for PINNs and HYPERPINNs
|
|
340
345
|
vmin_vmax:
|
|
341
346
|
The colorbar minimum and maximum value. Defaults None.
|
|
347
|
+
|
|
348
|
+
Returns
|
|
349
|
+
-------
|
|
350
|
+
fig, ax
|
|
351
|
+
A `matplotlib` `Figure` and `Axes` objects with the figure.
|
|
342
352
|
"""
|
|
343
353
|
|
|
344
354
|
mesh = jnp.meshgrid(times, xdata) # cartesian product
|
|
345
355
|
if not spinn:
|
|
346
356
|
# the trick is to use _plot2Dstatio
|
|
347
|
-
v_fun = vmap(lambda tx: fun(t=tx[0, None], x=tx[1, None]), 0, 0)
|
|
357
|
+
v_fun = vmap(fun) # lambda tx: fun(t=tx[0, None], x=tx[1, None]), 0, 0)
|
|
348
358
|
t_grid, x_grid = mesh
|
|
349
359
|
values_grid = v_fun(jnp.vstack([t_grid.flatten(), x_grid.flatten()]).T).reshape(
|
|
350
360
|
t_grid.shape
|
|
351
361
|
)
|
|
352
362
|
elif spinn:
|
|
353
|
-
values_grid = jnp.squeeze(
|
|
363
|
+
values_grid = jnp.squeeze(
|
|
364
|
+
fun(jnp.concatenate([times[..., None], xdata[..., None]], axis=-1))
|
|
365
|
+
).T
|
|
366
|
+
|
|
354
367
|
fig, ax = plt.subplots(1, 1, figsize=figsize)
|
|
355
|
-
|
|
356
|
-
|
|
357
|
-
|
|
358
|
-
|
|
359
|
-
|
|
360
|
-
|
|
361
|
-
|
|
362
|
-
|
|
363
|
-
|
|
364
|
-
else:
|
|
365
|
-
im = ax.pcolormesh(mesh[0] * Tmax, mesh[1], values_grid, cmap=cmap)
|
|
368
|
+
im = ax.pcolormesh(
|
|
369
|
+
mesh[0] * Tmax,
|
|
370
|
+
mesh[1],
|
|
371
|
+
values_grid,
|
|
372
|
+
cmap=cmap,
|
|
373
|
+
vmin=vmin_vmax[0],
|
|
374
|
+
vmax=vmin_vmax[1],
|
|
375
|
+
)
|
|
376
|
+
|
|
366
377
|
if colorbar:
|
|
367
378
|
fig.colorbar(im, format="%0.2f")
|
|
368
379
|
ax.set_title(title)
|
|
380
|
+
|
|
381
|
+
return fig, ax
|