jinns 1.1.0__py3-none-any.whl → 1.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jinns/data/_Batchs.py +4 -8
- jinns/data/_DataGenerators.py +532 -341
- jinns/loss/_DynamicLoss.py +150 -173
- jinns/loss/_DynamicLossAbstract.py +25 -73
- jinns/loss/_LossODE.py +3 -3
- jinns/loss/_LossPDE.py +27 -36
- jinns/loss/__init__.py +7 -6
- jinns/loss/_boundary_conditions.py +148 -279
- jinns/loss/_loss_utils.py +78 -56
- jinns/loss/_operators.py +441 -184
- jinns/plot/_plot.py +111 -98
- jinns/solver/_rar.py +102 -407
- jinns/solver/_solve.py +73 -38
- jinns/solver/_utils.py +122 -0
- jinns/utils/__init__.py +2 -0
- jinns/utils/_containers.py +3 -1
- jinns/utils/_hyperpinn.py +17 -7
- jinns/utils/_pinn.py +17 -27
- jinns/utils/_ppinn.py +227 -0
- jinns/utils/_save_load.py +13 -13
- jinns/utils/_spinn.py +24 -43
- jinns/utils/_types.py +1 -0
- jinns/utils/_utils.py +40 -12
- jinns-1.2.0.dist-info/METADATA +127 -0
- jinns-1.2.0.dist-info/RECORD +41 -0
- {jinns-1.1.0.dist-info → jinns-1.2.0.dist-info}/WHEEL +1 -1
- jinns-1.1.0.dist-info/METADATA +0 -85
- jinns-1.1.0.dist-info/RECORD +0 -39
- {jinns-1.1.0.dist-info → jinns-1.2.0.dist-info}/AUTHORS +0 -0
- {jinns-1.1.0.dist-info → jinns-1.2.0.dist-info}/LICENSE +0 -0
- {jinns-1.1.0.dist-info → jinns-1.2.0.dist-info}/top_level.txt +0 -0
|
@@ -11,11 +11,7 @@ import jax
|
|
|
11
11
|
import jax.numpy as jnp
|
|
12
12
|
from jax import vmap, grad
|
|
13
13
|
import equinox as eqx
|
|
14
|
-
from jinns.utils._utils import
|
|
15
|
-
_get_grid,
|
|
16
|
-
_check_user_func_return,
|
|
17
|
-
)
|
|
18
|
-
from jinns.parameters._params import _get_vmap_in_axes_params
|
|
14
|
+
from jinns.utils._utils import get_grid, _subtract_with_check
|
|
19
15
|
from jinns.data._Batchs import *
|
|
20
16
|
from jinns.utils._pinn import PINN
|
|
21
17
|
from jinns.utils._spinn import SPINN
|
|
@@ -26,12 +22,15 @@ if TYPE_CHECKING:
|
|
|
26
22
|
|
|
27
23
|
def _compute_boundary_loss(
|
|
28
24
|
boundary_condition_type: str,
|
|
29
|
-
f: Callable
|
|
25
|
+
f: Callable[
|
|
26
|
+
[Float[Array, "dim"] | Float[Array, "dim + 1"]], Float[Array, "dim_solution"]
|
|
27
|
+
],
|
|
30
28
|
batch: PDEStatioBatch | PDENonStatioBatch,
|
|
31
29
|
u: eqx.Module,
|
|
32
|
-
params:
|
|
30
|
+
params: AnyParams,
|
|
33
31
|
facet: int,
|
|
34
32
|
dim_to_apply: slice,
|
|
33
|
+
vmap_in_axes: tuple,
|
|
35
34
|
) -> float:
|
|
36
35
|
r"""A generic function that will compute the mini-batch MSE of a
|
|
37
36
|
boundary condition in the stationary case, resp. non-stationary, given by:
|
|
@@ -62,9 +61,9 @@ def _compute_boundary_loss(
|
|
|
62
61
|
unitary outgoing vector normal to $\partial\Omega$
|
|
63
62
|
f
|
|
64
63
|
the function to be matched in the boundary condition. It should have
|
|
65
|
-
one
|
|
64
|
+
one argument only (for `t`, `x` or `t_x`) (other are ignored).
|
|
66
65
|
batch
|
|
67
|
-
|
|
66
|
+
the batch
|
|
68
67
|
u
|
|
69
68
|
a PINN
|
|
70
69
|
params
|
|
@@ -75,233 +74,54 @@ def _compute_boundary_loss(
|
|
|
75
74
|
dim_to_apply
|
|
76
75
|
A `jnp.s_` object which indicates which dimension(s) of u will be forced
|
|
77
76
|
to match the boundary condition
|
|
77
|
+
vmap_in_axes
|
|
78
|
+
A tuple object which specifies the in_axes of the vmapping
|
|
78
79
|
|
|
79
80
|
Returns
|
|
80
81
|
-------
|
|
81
82
|
scalar
|
|
82
83
|
the MSE computed on `batch`
|
|
83
84
|
"""
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
):
|
|
92
|
-
mse = boundary_neumann_statio(f, batch, u, params, facet, dim_to_apply)
|
|
93
|
-
elif isinstance(batch, PDENonStatioBatch):
|
|
94
|
-
if boundary_condition_type.lower() in "dirichlet":
|
|
95
|
-
mse = boundary_dirichlet_nonstatio(f, batch, u, params, facet, dim_to_apply)
|
|
96
|
-
elif any(
|
|
97
|
-
boundary_condition_type.lower() in s
|
|
98
|
-
for s in ["von neumann", "vn", "vonneumann"]
|
|
99
|
-
):
|
|
100
|
-
mse = boundary_neumann_nonstatio(f, batch, u, params, facet, dim_to_apply)
|
|
85
|
+
if boundary_condition_type.lower() in "dirichlet":
|
|
86
|
+
mse = boundary_dirichlet(f, batch, u, params, facet, dim_to_apply, vmap_in_axes)
|
|
87
|
+
elif any(
|
|
88
|
+
boundary_condition_type.lower() in s
|
|
89
|
+
for s in ["von neumann", "vn", "vonneumann"]
|
|
90
|
+
):
|
|
91
|
+
mse = boundary_neumann(f, batch, u, params, facet, dim_to_apply, vmap_in_axes)
|
|
101
92
|
else:
|
|
102
|
-
raise ValueError("Wrong type of
|
|
93
|
+
raise ValueError("Wrong type of initial condition")
|
|
103
94
|
return mse
|
|
104
95
|
|
|
105
96
|
|
|
106
|
-
def
|
|
107
|
-
f: Callable
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
facet: int,
|
|
112
|
-
dim_to_apply: slice,
|
|
113
|
-
) -> float:
|
|
114
|
-
r"""
|
|
115
|
-
This omega boundary condition enforces a solution that is equal to f on
|
|
116
|
-
border batch.
|
|
117
|
-
|
|
118
|
-
__Note__: if using a batch.param_batch_dict, we need to resolve the
|
|
119
|
-
vmapping axes here however params["eq_params"] has already been fed with
|
|
120
|
-
the batch in the `evaluate()` of `LossPDEStatio`.
|
|
121
|
-
|
|
122
|
-
Parameters
|
|
123
|
-
----------
|
|
124
|
-
f
|
|
125
|
-
the constraint function
|
|
126
|
-
batch
|
|
127
|
-
A PDEStatioBatch object.
|
|
128
|
-
u
|
|
129
|
-
The PINN
|
|
130
|
-
params
|
|
131
|
-
Params or ParamsDict
|
|
132
|
-
dim_to_apply
|
|
133
|
-
A jnp.s\_ object. The dimension of u on which to apply the boundary condition
|
|
134
|
-
"""
|
|
135
|
-
_, border_batch = batch.inside_batch, batch.border_batch
|
|
136
|
-
border_batch = border_batch[..., facet]
|
|
137
|
-
|
|
138
|
-
if isinstance(u, PINN):
|
|
139
|
-
vmap_in_axes_params = _get_vmap_in_axes_params(batch.param_batch_dict, params)
|
|
140
|
-
vmap_in_axes_x = (0,)
|
|
141
|
-
|
|
142
|
-
v_u_boundary = vmap(
|
|
143
|
-
lambda dx, params: u(dx, params)[dim_to_apply] - f(dx),
|
|
144
|
-
vmap_in_axes_x + vmap_in_axes_params,
|
|
145
|
-
0,
|
|
146
|
-
)
|
|
147
|
-
|
|
148
|
-
mse_u_boundary = jnp.sum((v_u_boundary(border_batch, params)) ** 2, axis=-1)
|
|
149
|
-
elif isinstance(u, SPINN):
|
|
150
|
-
values = u(border_batch, params)[..., dim_to_apply]
|
|
151
|
-
x_grid = _get_grid(border_batch)
|
|
152
|
-
boundaries = _check_user_func_return(f(x_grid), values.shape)
|
|
153
|
-
res = values - boundaries
|
|
154
|
-
mse_u_boundary = jnp.sum(
|
|
155
|
-
res**2,
|
|
156
|
-
axis=-1,
|
|
157
|
-
)
|
|
158
|
-
else:
|
|
159
|
-
raise ValueError(f"Bad type for u. Got {type(u)}, expected PINN or SPINN")
|
|
160
|
-
return mse_u_boundary
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
def boundary_neumann_statio(
|
|
164
|
-
f: Callable,
|
|
165
|
-
batch: PDEStatioBatch,
|
|
166
|
-
u: eqx.Module,
|
|
167
|
-
params: Params | ParamsDict,
|
|
168
|
-
facet: int,
|
|
169
|
-
dim_to_apply: slice,
|
|
170
|
-
) -> float:
|
|
171
|
-
r"""
|
|
172
|
-
This omega boundary condition enforces a solution where $\nabla u\cdot
|
|
173
|
-
n$ is equal to `f` on omega borders. $n$ is the unitary
|
|
174
|
-
outgoing vector normal at border $\partial\Omega$.
|
|
175
|
-
|
|
176
|
-
__Note__: if using a batch.param_batch_dict, we need to resolve the
|
|
177
|
-
vmapping axes here however params["eq_params"] has already been fed with
|
|
178
|
-
the batch in the `evaluate()` of `LossPDEStatio`.
|
|
179
|
-
|
|
180
|
-
Parameters
|
|
181
|
-
----------
|
|
182
|
-
f
|
|
183
|
-
the constraint function
|
|
184
|
-
batch
|
|
185
|
-
A PDEStatioBatch object.
|
|
186
|
-
u
|
|
187
|
-
The PINN
|
|
188
|
-
params
|
|
189
|
-
The dictionary of parameters of the model.
|
|
190
|
-
Typically, it is a dictionary of
|
|
191
|
-
dictionaries: `eq_params` and `nn_params``, respectively the
|
|
192
|
-
differential equation parameters and the neural network parameter
|
|
193
|
-
facet
|
|
194
|
-
An integer which represents the id of the facet which is currently
|
|
195
|
-
considered (in the order provided wy the DataGenerator which is fixed)
|
|
196
|
-
dim_to_apply
|
|
197
|
-
A jnp.s\_ object. The dimension of u on which to apply the boundary condition
|
|
198
|
-
"""
|
|
199
|
-
_, border_batch = batch.inside_batch, batch.border_batch
|
|
200
|
-
border_batch = border_batch[..., facet]
|
|
201
|
-
|
|
202
|
-
# We resort to the shape of the border_batch to determine the dimension as
|
|
203
|
-
# described in the border_batch function
|
|
204
|
-
if jnp.squeeze(border_batch).ndim == 0: # case 1D borders (just a scalar)
|
|
205
|
-
n = jnp.array([1, -1]) # the unit vectors normal to the two borders
|
|
206
|
-
|
|
207
|
-
else: # case 2D borders (because 3D borders are not supported yet)
|
|
208
|
-
# they are in the order: left, right, bottom, top so we give the normal
|
|
209
|
-
# outgoing vectors accordingly with shape in concordance with
|
|
210
|
-
# border_batch shape (batch_size, ndim, nfacets)
|
|
211
|
-
n = jnp.array([[-1, 1, 0, 0], [0, 0, -1, 1]])
|
|
212
|
-
|
|
213
|
-
if isinstance(u, PINN):
|
|
214
|
-
u_ = lambda x, params: jnp.squeeze(u(x, params)[dim_to_apply])
|
|
215
|
-
|
|
216
|
-
vmap_in_axes_params = _get_vmap_in_axes_params(batch.param_batch_dict, params)
|
|
217
|
-
vmap_in_axes_x = (0,)
|
|
218
|
-
|
|
219
|
-
v_neumann = vmap(
|
|
220
|
-
lambda dx, params: jnp.dot(
|
|
221
|
-
grad(u_, 0)(dx, params),
|
|
222
|
-
n[..., facet],
|
|
223
|
-
)
|
|
224
|
-
- f(dx),
|
|
225
|
-
vmap_in_axes_x + vmap_in_axes_params,
|
|
226
|
-
0,
|
|
227
|
-
)
|
|
228
|
-
mse_u_boundary = jnp.sum((v_neumann(border_batch, params)) ** 2, axis=-1)
|
|
229
|
-
elif isinstance(u, SPINN):
|
|
230
|
-
# the gradient we see in the PINN case can get gradients wrt to x
|
|
231
|
-
# dimensions at once. But it would be very inefficient in SPINN because
|
|
232
|
-
# of the high dim output of u. So we do 2 explicit forward AD, handling all the
|
|
233
|
-
# high dim output at once
|
|
234
|
-
if border_batch.shape[0] == 1: # i.e. case 1D
|
|
235
|
-
_, du_dx = jax.jvp(
|
|
236
|
-
lambda x: u(
|
|
237
|
-
x,
|
|
238
|
-
params,
|
|
239
|
-
)[..., dim_to_apply],
|
|
240
|
-
(border_batch,),
|
|
241
|
-
(jnp.ones_like(border_batch),),
|
|
242
|
-
)
|
|
243
|
-
values = du_dx * n[facet]
|
|
244
|
-
elif border_batch.shape[-1] == 2:
|
|
245
|
-
tangent_vec_0 = jnp.repeat(
|
|
246
|
-
jnp.array([1.0, 0.0])[None], border_batch.shape[0], axis=0
|
|
247
|
-
)
|
|
248
|
-
tangent_vec_1 = jnp.repeat(
|
|
249
|
-
jnp.array([0.0, 1.0])[None], border_batch.shape[0], axis=0
|
|
250
|
-
)
|
|
251
|
-
_, du_dx1 = jax.jvp(
|
|
252
|
-
lambda x: u(
|
|
253
|
-
x,
|
|
254
|
-
params,
|
|
255
|
-
),
|
|
256
|
-
(border_batch,),
|
|
257
|
-
(tangent_vec_0,),
|
|
258
|
-
)
|
|
259
|
-
_, du_dx2 = jax.jvp(
|
|
260
|
-
lambda x: u(
|
|
261
|
-
x,
|
|
262
|
-
params,
|
|
263
|
-
),
|
|
264
|
-
(border_batch,),
|
|
265
|
-
(tangent_vec_1,),
|
|
266
|
-
)
|
|
267
|
-
values = du_dx1 * n[0, facet] + du_dx2 * n[1, facet] # dot product
|
|
268
|
-
# explicitly written
|
|
269
|
-
else:
|
|
270
|
-
raise ValueError("Not implemented, we'll do that with a loop")
|
|
271
|
-
|
|
272
|
-
x_grid = _get_grid(border_batch)
|
|
273
|
-
boundaries = _check_user_func_return(f(x_grid), values.shape)
|
|
274
|
-
res = values - boundaries
|
|
275
|
-
mse_u_boundary = jnp.sum(res**2, axis=-1)
|
|
276
|
-
else:
|
|
277
|
-
raise ValueError(f"Bad type for u. Got {type(u)}, expected PINN or SPINN")
|
|
278
|
-
return mse_u_boundary
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
def boundary_dirichlet_nonstatio(
|
|
282
|
-
f: Callable,
|
|
283
|
-
batch: PDENonStatioBatch,
|
|
97
|
+
def boundary_dirichlet(
|
|
98
|
+
f: Callable[
|
|
99
|
+
[Float[Array, "dim"] | Float[Array, "dim + 1"]], Float[Array, "dim_solution"]
|
|
100
|
+
],
|
|
101
|
+
batch: PDEStatioBatch | PDENonStatioBatch,
|
|
284
102
|
u: eqx.Module,
|
|
285
103
|
params: Params | ParamsDict,
|
|
286
104
|
facet: int,
|
|
287
105
|
dim_to_apply: slice,
|
|
106
|
+
vmap_in_axes: tuple,
|
|
288
107
|
) -> float:
|
|
289
108
|
r"""
|
|
290
109
|
This omega boundary condition enforces a solution that is equal to `f`
|
|
291
|
-
at `times_batch` x `
|
|
110
|
+
at `times_batch` x `omega_border` (non stationary case) or at `omega_border`
|
|
111
|
+
(stationary case)
|
|
292
112
|
|
|
293
113
|
__Note__: if using a batch.param_batch_dict, we need to resolve the
|
|
294
114
|
vmapping axes here however params["eq_params"] has already been fed with
|
|
295
|
-
the batch in the `evaluate()` of `
|
|
115
|
+
the batch in the `evaluate()` of `LossPDE*`.
|
|
296
116
|
|
|
297
117
|
Parameters
|
|
298
118
|
----------
|
|
299
119
|
f
|
|
300
120
|
the constraint function
|
|
301
121
|
batch
|
|
302
|
-
|
|
122
|
+
The batch
|
|
303
123
|
u
|
|
304
|
-
The PINN
|
|
124
|
+
The PINN or SPINN
|
|
305
125
|
params
|
|
306
126
|
The dictionary of parameters of the model.
|
|
307
127
|
Typically, it is a dictionary of
|
|
@@ -312,49 +132,50 @@ def boundary_dirichlet_nonstatio(
|
|
|
312
132
|
considered (in the order provided wy the DataGenerator which is fixed)
|
|
313
133
|
dim_to_apply
|
|
314
134
|
A jnp.s\_ object. The dimension of u on which to apply the boundary condition
|
|
135
|
+
vmap_in_axes
|
|
136
|
+
A tuple object which specifies the in_axes of the vmapping
|
|
315
137
|
"""
|
|
316
|
-
|
|
317
|
-
|
|
138
|
+
batch_array = batch.border_batch
|
|
139
|
+
batch_array = batch_array[..., facet]
|
|
318
140
|
|
|
319
141
|
if isinstance(u, PINN):
|
|
320
|
-
vmap_in_axes_params = _get_vmap_in_axes_params(batch.param_batch_dict, params)
|
|
321
|
-
vmap_in_axes_x_t = (0, 0)
|
|
322
|
-
|
|
323
142
|
v_u_boundary = vmap(
|
|
324
|
-
lambda
|
|
325
|
-
|
|
326
|
-
|
|
327
|
-
|
|
328
|
-
|
|
329
|
-
|
|
330
|
-
|
|
143
|
+
lambda inputs, params: _subtract_with_check(
|
|
144
|
+
f(inputs),
|
|
145
|
+
u(
|
|
146
|
+
inputs,
|
|
147
|
+
params,
|
|
148
|
+
)[dim_to_apply],
|
|
149
|
+
cause="boundary condition fun",
|
|
150
|
+
),
|
|
151
|
+
vmap_in_axes,
|
|
331
152
|
0,
|
|
332
153
|
)
|
|
333
|
-
res = v_u_boundary(
|
|
154
|
+
res = v_u_boundary(batch_array, params)
|
|
334
155
|
mse_u_boundary = jnp.sum(
|
|
335
156
|
res**2,
|
|
336
157
|
axis=-1,
|
|
337
158
|
)
|
|
338
159
|
elif isinstance(u, SPINN):
|
|
339
|
-
values = u(
|
|
340
|
-
|
|
341
|
-
|
|
342
|
-
f(tx_grid[..., 0:1], tx_grid[..., 1:]), values.shape
|
|
343
|
-
)
|
|
344
|
-
res = values - boundaries
|
|
160
|
+
values = u(batch_array, params)[..., dim_to_apply]
|
|
161
|
+
grid = get_grid(batch_array)
|
|
162
|
+
res = _subtract_with_check(f(grid), values, cause="boundary condition fun")
|
|
345
163
|
mse_u_boundary = jnp.sum(res**2, axis=-1)
|
|
346
164
|
else:
|
|
347
165
|
raise ValueError(f"Bad type for u. Got {type(u)}, expected PINN or SPINN")
|
|
348
166
|
return mse_u_boundary
|
|
349
167
|
|
|
350
168
|
|
|
351
|
-
def
|
|
352
|
-
f: Callable
|
|
353
|
-
|
|
169
|
+
def boundary_neumann(
|
|
170
|
+
f: Callable[
|
|
171
|
+
[Float[Array, "dim"] | Float[Array, "dim + 1"]], Float[Array, "dim_solution"]
|
|
172
|
+
],
|
|
173
|
+
batch: PDEStatioBatch | PDENonStatioBatch,
|
|
354
174
|
u: eqx.Module,
|
|
355
175
|
params: Params | ParamsDict,
|
|
356
176
|
facet: int,
|
|
357
177
|
dim_to_apply: slice,
|
|
178
|
+
vmap_in_axes: tuple,
|
|
358
179
|
) -> float:
|
|
359
180
|
r"""
|
|
360
181
|
This omega boundary condition enforces a solution where $\nabla u\cdot
|
|
@@ -371,7 +192,7 @@ def boundary_neumann_nonstatio(
|
|
|
371
192
|
f:
|
|
372
193
|
the constraint function
|
|
373
194
|
batch
|
|
374
|
-
|
|
195
|
+
The batch
|
|
375
196
|
u
|
|
376
197
|
The PINN
|
|
377
198
|
params
|
|
@@ -384,13 +205,15 @@ def boundary_neumann_nonstatio(
|
|
|
384
205
|
considered (in the order provided wy the DataGenerator which is fixed)
|
|
385
206
|
dim_to_apply
|
|
386
207
|
A jnp.s\_ object. The dimension of u on which to apply the boundary condition
|
|
208
|
+
vmap_in_axes
|
|
209
|
+
A tuple object which specifies the in_axes of the vmapping
|
|
387
210
|
"""
|
|
388
|
-
|
|
389
|
-
|
|
211
|
+
batch_array = batch.border_batch
|
|
212
|
+
batch_array = batch_array[..., facet]
|
|
390
213
|
|
|
391
214
|
# We resort to the shape of the border_batch to determine the dimension as
|
|
392
215
|
# described in the border_batch function
|
|
393
|
-
if jnp.squeeze(
|
|
216
|
+
if jnp.squeeze(batch_array).ndim == 0: # case 1D borders (just a scalar)
|
|
394
217
|
n = jnp.array([1, -1]) # the unit vectors normal to the two borders
|
|
395
218
|
|
|
396
219
|
else: # case 2D borders (because 3D borders are not supported yet)
|
|
@@ -400,24 +223,41 @@ def boundary_neumann_nonstatio(
|
|
|
400
223
|
n = jnp.array([[-1, 1, 0, 0], [0, 0, -1, 1]])
|
|
401
224
|
|
|
402
225
|
if isinstance(u, PINN):
|
|
403
|
-
vmap_in_axes_params = _get_vmap_in_axes_params(batch.param_batch_dict, params)
|
|
404
|
-
vmap_in_axes_x_t = (0, 0)
|
|
405
226
|
|
|
406
|
-
u_ = lambda
|
|
407
|
-
|
|
408
|
-
|
|
409
|
-
|
|
410
|
-
|
|
227
|
+
u_ = lambda inputs, params: jnp.squeeze(u(inputs, params)[dim_to_apply])
|
|
228
|
+
|
|
229
|
+
if u.eq_type == "statio_PDE":
|
|
230
|
+
v_neumann = vmap(
|
|
231
|
+
lambda inputs, params: _subtract_with_check(
|
|
232
|
+
f(inputs),
|
|
233
|
+
jnp.dot(
|
|
234
|
+
grad(u_, 0)(inputs, params),
|
|
235
|
+
n[..., facet],
|
|
236
|
+
),
|
|
237
|
+
cause="boundary condition fun",
|
|
238
|
+
),
|
|
239
|
+
vmap_in_axes,
|
|
240
|
+
0,
|
|
411
241
|
)
|
|
412
|
-
|
|
413
|
-
|
|
414
|
-
|
|
415
|
-
|
|
242
|
+
elif u.eq_type == "nonstatio_PDE":
|
|
243
|
+
v_neumann = vmap(
|
|
244
|
+
lambda inputs, params: _subtract_with_check(
|
|
245
|
+
f(inputs),
|
|
246
|
+
jnp.dot(
|
|
247
|
+
grad(u_, 0)(inputs, params)[1:], # get rid of time dim
|
|
248
|
+
n[..., facet],
|
|
249
|
+
),
|
|
250
|
+
cause="boundary condition fun",
|
|
251
|
+
),
|
|
252
|
+
vmap_in_axes,
|
|
253
|
+
0,
|
|
254
|
+
)
|
|
255
|
+
else:
|
|
256
|
+
raise ValueError("Wrong u.eq_type")
|
|
416
257
|
mse_u_boundary = jnp.sum(
|
|
417
258
|
(
|
|
418
259
|
v_neumann(
|
|
419
|
-
|
|
420
|
-
omega_border_batch,
|
|
260
|
+
batch_array,
|
|
421
261
|
params,
|
|
422
262
|
)
|
|
423
263
|
)
|
|
@@ -430,40 +270,69 @@ def boundary_neumann_nonstatio(
|
|
|
430
270
|
# dimensions at once. But it would be very inefficient in SPINN because
|
|
431
271
|
# of the high dim output of u. So we do 2 explicit forward AD, handling all the
|
|
432
272
|
# high dim output at once
|
|
433
|
-
if
|
|
434
|
-
|
|
435
|
-
|
|
436
|
-
|
|
437
|
-
|
|
438
|
-
|
|
439
|
-
|
|
440
|
-
|
|
441
|
-
|
|
442
|
-
|
|
443
|
-
|
|
444
|
-
|
|
445
|
-
|
|
446
|
-
|
|
447
|
-
|
|
448
|
-
|
|
449
|
-
|
|
450
|
-
|
|
451
|
-
)
|
|
452
|
-
|
|
453
|
-
|
|
454
|
-
(
|
|
455
|
-
|
|
456
|
-
|
|
457
|
-
|
|
458
|
-
|
|
273
|
+
if (batch_array.shape[0] == 1 and isinstance(batch, PDEStatioBatch)) or (
|
|
274
|
+
batch_array.shape[-1] == 2 and isinstance(batch, PDENonStatioBatch)
|
|
275
|
+
):
|
|
276
|
+
if u.eq_type == "statio_PDE":
|
|
277
|
+
_, du_dx = jax.jvp(
|
|
278
|
+
lambda inputs: u(inputs, params)[..., dim_to_apply],
|
|
279
|
+
(batch_array,),
|
|
280
|
+
(jnp.ones_like(batch_array),),
|
|
281
|
+
)
|
|
282
|
+
values = du_dx * n[facet]
|
|
283
|
+
if u.eq_type == "nonstatio_PDE":
|
|
284
|
+
_, du_dx = jax.jvp(
|
|
285
|
+
lambda inputs: u(inputs, params)[..., dim_to_apply],
|
|
286
|
+
(batch_array,),
|
|
287
|
+
(jnp.ones_like(batch_array),),
|
|
288
|
+
)
|
|
289
|
+
values = du_dx[..., 1] * n[facet]
|
|
290
|
+
elif (batch_array.shape[-1] == 2 and isinstance(batch, PDEStatioBatch)) or (
|
|
291
|
+
batch_array.shape[-1] == 3 and isinstance(batch, PDENonStatioBatch)
|
|
292
|
+
):
|
|
293
|
+
if u.eq_type == "statio_PDE":
|
|
294
|
+
tangent_vec_0 = jnp.repeat(
|
|
295
|
+
jnp.array([1.0, 0.0])[None], batch_array.shape[0], axis=0
|
|
296
|
+
)
|
|
297
|
+
tangent_vec_1 = jnp.repeat(
|
|
298
|
+
jnp.array([0.0, 1.0])[None], batch_array.shape[0], axis=0
|
|
299
|
+
)
|
|
300
|
+
_, du_dx1 = jax.jvp(
|
|
301
|
+
lambda inputs: u(inputs, params)[..., dim_to_apply],
|
|
302
|
+
(batch_array,),
|
|
303
|
+
(tangent_vec_0,),
|
|
304
|
+
)
|
|
305
|
+
_, du_dx2 = jax.jvp(
|
|
306
|
+
lambda inputs: u(inputs, params)[..., dim_to_apply],
|
|
307
|
+
(batch_array,),
|
|
308
|
+
(tangent_vec_1,),
|
|
309
|
+
)
|
|
310
|
+
values = du_dx1 * n[0, facet] + du_dx2 * n[1, facet] # dot product
|
|
311
|
+
if u.eq_type == "nonstatio_PDE":
|
|
312
|
+
tangent_vec_0 = jnp.repeat(
|
|
313
|
+
jnp.array([0.0, 1.0, 0.0])[None], batch_array.shape[0], axis=0
|
|
314
|
+
)
|
|
315
|
+
tangent_vec_1 = jnp.repeat(
|
|
316
|
+
jnp.array([0.0, 0.0, 1.0])[None], batch_array.shape[0], axis=0
|
|
317
|
+
)
|
|
318
|
+
_, du_dx1 = jax.jvp(
|
|
319
|
+
lambda inputs: u(inputs, params)[..., dim_to_apply],
|
|
320
|
+
(batch_array,),
|
|
321
|
+
(tangent_vec_0,),
|
|
322
|
+
)
|
|
323
|
+
_, du_dx2 = jax.jvp(
|
|
324
|
+
lambda inputs: u(inputs, params)[..., dim_to_apply],
|
|
325
|
+
(batch_array,),
|
|
326
|
+
(tangent_vec_1,),
|
|
327
|
+
)
|
|
328
|
+
values = (
|
|
329
|
+
du_dx1.squeeze() * n[0, facet] + du_dx2.squeeze() * n[1, facet]
|
|
330
|
+
) # dot product
|
|
459
331
|
else:
|
|
460
332
|
raise ValueError("Not implemented, we'll do that with a loop")
|
|
461
333
|
|
|
462
|
-
|
|
463
|
-
|
|
464
|
-
f(tx_grid[..., 0:1], tx_grid[..., 1:]), values.shape
|
|
465
|
-
)
|
|
466
|
-
res = values - boundaries
|
|
334
|
+
grid = get_grid(batch_array)
|
|
335
|
+
res = _subtract_with_check(f(grid), values, cause="boundary condition fun")
|
|
467
336
|
mse_u_boundary = jnp.sum(
|
|
468
337
|
res**2,
|
|
469
338
|
axis=-1,
|