jinns 0.4.2__py3-none-any.whl → 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jinns/data/_display.py +78 -21
- jinns/loss/_DynamicLoss.py +405 -907
- jinns/loss/_LossPDE.py +303 -154
- jinns/loss/__init__.py +0 -6
- jinns/loss/_boundary_conditions.py +231 -65
- jinns/loss/_operators.py +201 -45
- jinns/utils/__init__.py +2 -1
- jinns/utils/_pinn.py +308 -0
- jinns/utils/_spinn.py +237 -0
- jinns/utils/_utils.py +32 -306
- {jinns-0.4.2.dist-info → jinns-0.5.0.dist-info}/METADATA +15 -2
- jinns-0.5.0.dist-info/RECORD +24 -0
- jinns-0.4.2.dist-info/RECORD +0 -22
- {jinns-0.4.2.dist-info → jinns-0.5.0.dist-info}/LICENSE +0 -0
- {jinns-0.4.2.dist-info → jinns-0.5.0.dist-info}/WHEEL +0 -0
- {jinns-0.4.2.dist-info → jinns-0.5.0.dist-info}/top_level.txt +0 -0
jinns/data/_display.py
CHANGED
|
@@ -13,6 +13,7 @@ def plot2d(
|
|
|
13
13
|
title="",
|
|
14
14
|
figsize=(7, 7),
|
|
15
15
|
cmap="inferno",
|
|
16
|
+
spinn=False,
|
|
16
17
|
):
|
|
17
18
|
"""Generic function for plotting functions over rectangular 2-D domains
|
|
18
19
|
:math:`\Omega`. It treats both the stationary case :math:`u(x)` or the
|
|
@@ -56,8 +57,24 @@ def plot2d(
|
|
|
56
57
|
|
|
57
58
|
if times is None:
|
|
58
59
|
# Stationary case: expect a function of one argument fun(x)
|
|
59
|
-
|
|
60
|
-
|
|
60
|
+
if not spinn:
|
|
61
|
+
v_fun = vmap(fun, 0, 0)
|
|
62
|
+
_plot_2D_statio(
|
|
63
|
+
v_fun, mesh, plot=True, colorbar=True, cmap=cmap, figsize=figsize
|
|
64
|
+
)
|
|
65
|
+
elif spinn:
|
|
66
|
+
values_grid = jnp.squeeze(
|
|
67
|
+
fun(jnp.stack([xy_data[0][..., None], xy_data[1][..., None]], axis=1))
|
|
68
|
+
)
|
|
69
|
+
_plot_2D_statio(
|
|
70
|
+
values_grid,
|
|
71
|
+
mesh,
|
|
72
|
+
plot=True,
|
|
73
|
+
colorbar=True,
|
|
74
|
+
cmap=cmap,
|
|
75
|
+
spinn=True,
|
|
76
|
+
figsize=figsize,
|
|
77
|
+
)
|
|
61
78
|
plt.title(title)
|
|
62
79
|
|
|
63
80
|
else:
|
|
@@ -81,17 +98,30 @@ def plot2d(
|
|
|
81
98
|
)
|
|
82
99
|
|
|
83
100
|
for idx, (t, ax) in enumerate(zip(times, grid)):
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
101
|
+
if not spinn:
|
|
102
|
+
v_fun_at_t = vmap(lambda x: fun(t=jnp.array([t]), x=x), 0, 0)
|
|
103
|
+
t_slice, _ = _plot_2D_statio(
|
|
104
|
+
v_fun_at_t, mesh, plot=False, colorbar=False, cmap=None
|
|
105
|
+
)
|
|
106
|
+
elif spinn:
|
|
107
|
+
values_grid = jnp.squeeze(
|
|
108
|
+
fun(
|
|
109
|
+
t * jnp.ones((xy_data[0].shape[0], 1)),
|
|
110
|
+
jnp.stack(
|
|
111
|
+
[xy_data[0][..., None], xy_data[1][..., None]], axis=1
|
|
112
|
+
),
|
|
113
|
+
)[0]
|
|
114
|
+
)
|
|
115
|
+
t_slice, _ = _plot_2D_statio(
|
|
116
|
+
values_grid, mesh, plot=False, colorbar=True, spinn=True
|
|
117
|
+
)
|
|
88
118
|
im = ax.pcolormesh(mesh[0], mesh[1], t_slice, cmap=cmap)
|
|
89
119
|
ax.set_title(f"t = {times[idx] * Tmax}")
|
|
90
120
|
ax.cax.colorbar(im)
|
|
91
121
|
|
|
92
122
|
|
|
93
123
|
def _plot_2D_statio(
|
|
94
|
-
v_fun, mesh, plot=True, colorbar=True, cmap="inferno", figsize=(7, 7)
|
|
124
|
+
v_fun, mesh, plot=True, colorbar=True, cmap="inferno", figsize=(7, 7), spinn=False
|
|
95
125
|
):
|
|
96
126
|
"""Function that plot the function u(x) with 2-D input x using pcolormesh()
|
|
97
127
|
|
|
@@ -114,8 +144,12 @@ def _plot_2D_statio(
|
|
|
114
144
|
"""
|
|
115
145
|
|
|
116
146
|
x_grid, y_grid = mesh
|
|
117
|
-
|
|
118
|
-
|
|
147
|
+
if not spinn:
|
|
148
|
+
values = v_fun(jnp.vstack([x_grid.flatten(), y_grid.flatten()]).T)
|
|
149
|
+
values_grid = values.reshape(x_grid.shape)
|
|
150
|
+
elif spinn:
|
|
151
|
+
# in this case v_fun is directly the values :)
|
|
152
|
+
values_grid = v_fun.T
|
|
119
153
|
|
|
120
154
|
if plot:
|
|
121
155
|
fig = plt.figure(figsize=figsize)
|
|
@@ -128,7 +162,13 @@ def _plot_2D_statio(
|
|
|
128
162
|
|
|
129
163
|
|
|
130
164
|
def plot1d_slice(
|
|
131
|
-
fun,
|
|
165
|
+
fun,
|
|
166
|
+
xdata,
|
|
167
|
+
time_slices=jnp.array([0]),
|
|
168
|
+
Tmax=1,
|
|
169
|
+
title="",
|
|
170
|
+
figsize=(10, 10),
|
|
171
|
+
spinn=False,
|
|
132
172
|
):
|
|
133
173
|
"""Function for plotting time slices of a function :math:`f(t_i, x)` where
|
|
134
174
|
`t` is time (1-D) and x is 1-D
|
|
@@ -151,10 +191,16 @@ def plot1d_slice(
|
|
|
151
191
|
"""
|
|
152
192
|
plt.figure(figsize=figsize)
|
|
153
193
|
for t in time_slices:
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
194
|
+
if not spinn:
|
|
195
|
+
# fix t with partial : shape is (1,)
|
|
196
|
+
v_u_tfixed = vmap(partial(fun, t=t * jnp.ones((1,))), 0, 0)
|
|
197
|
+
# add an axis to xdata for the concatenate function in the neural net
|
|
198
|
+
values = v_u_tfixed(x=xdata[:, None])
|
|
199
|
+
elif spinn:
|
|
200
|
+
values = jnp.squeeze(
|
|
201
|
+
fun(t * jnp.ones((xdata.shape[0], 1)), xdata[..., None])[0]
|
|
202
|
+
)
|
|
203
|
+
plt.plot(xdata, values, label=f"$t_i={t * Tmax}$")
|
|
158
204
|
plt.xlabel("x")
|
|
159
205
|
plt.ylabel(r"$u(t_i, x)$")
|
|
160
206
|
plt.legend()
|
|
@@ -162,7 +208,15 @@ def plot1d_slice(
|
|
|
162
208
|
|
|
163
209
|
|
|
164
210
|
def plot1d_image(
|
|
165
|
-
fun,
|
|
211
|
+
fun,
|
|
212
|
+
xdata,
|
|
213
|
+
times,
|
|
214
|
+
Tmax=1,
|
|
215
|
+
title="",
|
|
216
|
+
figsize=(10, 10),
|
|
217
|
+
colorbar=True,
|
|
218
|
+
cmap="inferno",
|
|
219
|
+
spinn=False,
|
|
166
220
|
):
|
|
167
221
|
"""Function for plotting the 2-D image of a function :math:`f(t, x)` where
|
|
168
222
|
`t` is time (1-D) and x is space (1-D).
|
|
@@ -186,12 +240,15 @@ def plot1d_image(
|
|
|
186
240
|
"""
|
|
187
241
|
|
|
188
242
|
mesh = jnp.meshgrid(times, xdata) # cartesian product
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
t_grid.
|
|
194
|
-
|
|
243
|
+
if not spinn:
|
|
244
|
+
# same trick as in _plot_2D_statio: evaluate on the flattened mesh, then reshape
|
|
245
|
+
v_fun = vmap(lambda tx: fun(t=tx[0, None], x=tx[1, None]), 0, 0)
|
|
246
|
+
t_grid, x_grid = mesh
|
|
247
|
+
values_grid = v_fun(jnp.vstack([t_grid.flatten(), x_grid.flatten()]).T).reshape(
|
|
248
|
+
t_grid.shape
|
|
249
|
+
)
|
|
250
|
+
elif spinn:
|
|
251
|
+
values_grid = jnp.squeeze(fun((times[..., None]), xdata[..., None]).T)
|
|
195
252
|
fig, ax = plt.subplots(1, 1, figsize=figsize)
|
|
196
253
|
im = ax.pcolormesh(mesh[0] * Tmax, mesh[1], values_grid, cmap=cmap)
|
|
197
254
|
if colorbar:
|