jinns 0.6.0__py3-none-any.whl → 0.6.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jinns/loss/_DynamicLoss.py +10 -34
- jinns/loss/_DynamicLossAbstract.py +91 -10
- jinns/loss/_LossPDE.py +82 -16
- jinns/loss/_boundary_conditions.py +48 -20
- jinns/loss/_operators.py +10 -10
- jinns/utils/_pinn.py +65 -30
- jinns/utils/_spinn.py +23 -30
- {jinns-0.6.0.dist-info → jinns-0.6.1.dist-info}/METADATA +1 -1
- {jinns-0.6.0.dist-info → jinns-0.6.1.dist-info}/RECORD +12 -12
- {jinns-0.6.0.dist-info → jinns-0.6.1.dist-info}/LICENSE +0 -0
- {jinns-0.6.0.dist-info → jinns-0.6.1.dist-info}/WHEEL +0 -0
- {jinns-0.6.0.dist-info → jinns-0.6.1.dist-info}/top_level.txt +0 -0
jinns/loss/_DynamicLoss.py
CHANGED
|
@@ -48,6 +48,7 @@ class FisherKPP(PDENonStatio):
|
|
|
48
48
|
"""
|
|
49
49
|
super().__init__(Tmax, eq_params_heterogeneity)
|
|
50
50
|
|
|
51
|
+
@PDENonStatio.evaluate_heterogeneous_parameters
|
|
51
52
|
def evaluate(self, t, x, u, params):
|
|
52
53
|
r"""
|
|
53
54
|
Evaluate the dynamic loss at :math:`(t,x)`.
|
|
@@ -67,16 +68,12 @@ class FisherKPP(PDENonStatio):
|
|
|
67
68
|
differential equation parameters and the neural network parameter
|
|
68
69
|
"""
|
|
69
70
|
if isinstance(u, PINN):
|
|
70
|
-
params["eq_params"] = self._eval_heterogeneous_parameters(
|
|
71
|
-
params["eq_params"], t, x, self.eq_params_heterogeneity
|
|
72
|
-
)
|
|
73
|
-
|
|
74
71
|
# Note that the last dim of u is nec. 1
|
|
75
72
|
u_ = lambda t, x: u(t, x, params)[0]
|
|
76
73
|
|
|
77
74
|
du_dt = grad(u_, 0)(t, x)
|
|
78
75
|
|
|
79
|
-
lap = _laplacian_rev(
|
|
76
|
+
lap = _laplacian_rev(t, x, u, params)[..., None]
|
|
80
77
|
|
|
81
78
|
return du_dt + self.Tmax * (
|
|
82
79
|
-params["eq_params"]["D"] * lap
|
|
@@ -87,17 +84,12 @@ class FisherKPP(PDENonStatio):
|
|
|
87
84
|
)
|
|
88
85
|
)
|
|
89
86
|
if isinstance(u, SPINN):
|
|
90
|
-
x_grid = _get_grid(x)
|
|
91
|
-
params["eq_params"] = self._eval_heterogeneous_parameters(
|
|
92
|
-
params["eq_params"], t, x_grid, self.eq_params_heterogeneity
|
|
93
|
-
)
|
|
94
|
-
|
|
95
87
|
u_tx, du_dt = jax.jvp(
|
|
96
88
|
lambda t: u(t, x, params),
|
|
97
89
|
(t,),
|
|
98
90
|
(jnp.ones_like(t),),
|
|
99
91
|
)
|
|
100
|
-
lap = _laplacian_fwd(
|
|
92
|
+
lap = _laplacian_fwd(t, x, u, params)[..., None]
|
|
101
93
|
return du_dt + self.Tmax * (
|
|
102
94
|
-params["eq_params"]["D"] * lap
|
|
103
95
|
- u_tx
|
|
@@ -160,12 +152,8 @@ class BurgerEquation(PDENonStatio):
|
|
|
160
152
|
differential equation parameters and the neural network parameter
|
|
161
153
|
"""
|
|
162
154
|
if isinstance(u, PINN):
|
|
163
|
-
params["eq_params"] = self._eval_heterogeneous_parameters(
|
|
164
|
-
params["eq_params"], t, x, self.eq_params_heterogeneity
|
|
165
|
-
)
|
|
166
|
-
|
|
167
155
|
# Note that the last dim of u is nec. 1
|
|
168
|
-
u_ = lambda t, x: u(t, x, params)[
|
|
156
|
+
u_ = lambda t, x: jnp.squeeze(u(t, x, params)[u.slice_solution])
|
|
169
157
|
du_dt = grad(u_, 0)
|
|
170
158
|
du_dx = grad(u_, 1)
|
|
171
159
|
d2u_dx2 = grad(
|
|
@@ -179,10 +167,6 @@ class BurgerEquation(PDENonStatio):
|
|
|
179
167
|
)
|
|
180
168
|
|
|
181
169
|
if isinstance(u, SPINN):
|
|
182
|
-
x_grid = _get_grid(x)
|
|
183
|
-
params["eq_params"] = self._eval_heterogeneous_parameters(
|
|
184
|
-
params["eq_params"], t, x_grid, self.eq_params_heterogeneity
|
|
185
|
-
)
|
|
186
170
|
# d=2 JVP calls are expected since we have time and x
|
|
187
171
|
# then with a batch of size B, we then have Bd JVP calls
|
|
188
172
|
u_tx, du_dt = jax.jvp(
|
|
@@ -352,10 +336,6 @@ class FPENonStatioLoss2D(PDENonStatio):
|
|
|
352
336
|
differential equation parameters and the neural network parameter
|
|
353
337
|
"""
|
|
354
338
|
if isinstance(u, PINN):
|
|
355
|
-
params["eq_params"] = self._eval_heterogeneous_parameters(
|
|
356
|
-
params["eq_params"], t, x, self.eq_params_heterogeneity
|
|
357
|
-
)
|
|
358
|
-
|
|
359
339
|
# Note that the last dim of u is nec. 1
|
|
360
340
|
u_ = lambda t, x: u(t, x, params)[0]
|
|
361
341
|
|
|
@@ -411,10 +391,6 @@ class FPENonStatioLoss2D(PDENonStatio):
|
|
|
411
391
|
|
|
412
392
|
if isinstance(u, SPINN):
|
|
413
393
|
x_grid = _get_grid(x)
|
|
414
|
-
params["eq_params"] = self._eval_heterogeneous_parameters(
|
|
415
|
-
params["eq_params"], t, x_grid, self.eq_params_heterogeneity
|
|
416
|
-
)
|
|
417
|
-
|
|
418
394
|
_, du_dt = jax.jvp(
|
|
419
395
|
lambda t: u(t, x, params),
|
|
420
396
|
(t,),
|
|
@@ -634,12 +610,12 @@ class MassConservation2DStatio(PDEStatio):
|
|
|
634
610
|
if isinstance(u_dict[self.nn_key], PINN):
|
|
635
611
|
u = u_dict[self.nn_key]
|
|
636
612
|
|
|
637
|
-
return _div_rev(u, params
|
|
613
|
+
return _div_rev(None, x, u, params)[..., None]
|
|
638
614
|
|
|
639
615
|
if isinstance(u_dict[self.nn_key], SPINN):
|
|
640
616
|
u = u_dict[self.nn_key]
|
|
641
617
|
|
|
642
|
-
return _div_fwd(u, params
|
|
618
|
+
return _div_fwd(None, x, u, params)[..., None]
|
|
643
619
|
raise ValueError("u is not among the recognized types (PINN or SPINN)")
|
|
644
620
|
|
|
645
621
|
|
|
@@ -724,12 +700,12 @@ class NavierStokes2DStatio(PDEStatio):
|
|
|
724
700
|
if isinstance(u_dict[self.u_key], PINN):
|
|
725
701
|
u = u_dict[self.u_key]
|
|
726
702
|
|
|
727
|
-
u_dot_nabla_x_u = _u_dot_nabla_times_u_rev(u, u_params
|
|
703
|
+
u_dot_nabla_x_u = _u_dot_nabla_times_u_rev(None, x, u, u_params)
|
|
728
704
|
|
|
729
705
|
p = lambda x: u_dict[self.p_key](x, p_params)
|
|
730
706
|
jac_p = jacrev(p, 0)(x) # compute the gradient
|
|
731
707
|
|
|
732
|
-
vec_laplacian_u = _vectorial_laplacian(u, u_params,
|
|
708
|
+
vec_laplacian_u = _vectorial_laplacian(None, x, u, u_params, u_vec_ndim=2)
|
|
733
709
|
|
|
734
710
|
# dynamic loss on x axis
|
|
735
711
|
result_x = (
|
|
@@ -751,7 +727,7 @@ class NavierStokes2DStatio(PDEStatio):
|
|
|
751
727
|
if isinstance(u_dict[self.u_key], SPINN):
|
|
752
728
|
u = u_dict[self.u_key]
|
|
753
729
|
|
|
754
|
-
u_dot_nabla_x_u = _u_dot_nabla_times_u_fwd(u, u_params
|
|
730
|
+
u_dot_nabla_x_u = _u_dot_nabla_times_u_fwd(None, x, u, u_params)
|
|
755
731
|
|
|
756
732
|
p = lambda x: u_dict[self.p_key](x, p_params)
|
|
757
733
|
|
|
@@ -761,7 +737,7 @@ class NavierStokes2DStatio(PDEStatio):
|
|
|
761
737
|
_, dp_dy = jax.jvp(p, (x,), (tangent_vec_1,))
|
|
762
738
|
|
|
763
739
|
vec_laplacian_u = jnp.moveaxis(
|
|
764
|
-
_vectorial_laplacian(u, u_params,
|
|
740
|
+
_vectorial_laplacian(None, x, u, u_params, u_vec_ndim=2),
|
|
765
741
|
source=0,
|
|
766
742
|
destination=-1,
|
|
767
743
|
)
|
|
@@ -23,8 +23,8 @@ class DynamicLoss:
|
|
|
23
23
|
Default None. A dict with the keys being the same as in eq_params
|
|
24
24
|
and the value being either None (no heterogeneity) or a function
|
|
25
25
|
which encodes for the spatio-temporal heterogeneity of the parameter.
|
|
26
|
-
Such a function must be jittable and take
|
|
27
|
-
`
|
|
26
|
+
Such a function must be jittable and take four arguments `t`, `x`,
|
|
27
|
+
`u` and `params` even if one is not used. Therefore,
|
|
28
28
|
one can introduce spatio-temporal covariates upon which a particular
|
|
29
29
|
parameter can depend, e.g. in a GLM fashion. The effect of these
|
|
30
30
|
covariables can themselves be estimated by being in `eq_params` too.
|
|
@@ -35,20 +35,24 @@ class DynamicLoss:
|
|
|
35
35
|
self.Tmax = Tmax
|
|
36
36
|
self.eq_params_heterogeneity = eq_params_heterogeneity
|
|
37
37
|
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
):
|
|
38
|
+
@staticmethod
|
|
39
|
+
def _eval_heterogeneous_parameters(t, x, u, params, eq_params_heterogeneity=None):
|
|
41
40
|
eq_params_ = {}
|
|
42
41
|
if eq_params_heterogeneity is None:
|
|
43
|
-
return eq_params
|
|
44
|
-
for k, p in eq_params.items():
|
|
42
|
+
return params["eq_params"]
|
|
43
|
+
for k, p in params["eq_params"].items():
|
|
45
44
|
try:
|
|
46
45
|
if eq_params_heterogeneity[k] is None:
|
|
47
46
|
eq_params_[k] = p
|
|
48
47
|
else:
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
48
|
+
if t is None:
|
|
49
|
+
eq_params_[k] = eq_params_heterogeneity[k](
|
|
50
|
+
x, u, params # heterogeneity encoded through a function
|
|
51
|
+
)
|
|
52
|
+
else:
|
|
53
|
+
eq_params_[k] = eq_params_heterogeneity[k](
|
|
54
|
+
t, x, u, params # heterogeneity encoded through a function
|
|
55
|
+
)
|
|
52
56
|
except KeyError:
|
|
53
57
|
# we authorize missing eq_params_heterogeneity key
|
|
54
58
|
# is its heterogeneity is None anyway
|
|
@@ -79,6 +83,31 @@ class ODE(DynamicLoss):
|
|
|
79
83
|
"""
|
|
80
84
|
super().__init__(Tmax, eq_params_heterogeneity)
|
|
81
85
|
|
|
86
|
+
def eval_heterogeneous_parameters(self, t, u, params, eq_params_heterogeneity=None):
|
|
87
|
+
return super()._eval_heterogeneous_parameters(
|
|
88
|
+
t, None, u, params, eq_params_heterogeneity
|
|
89
|
+
)
|
|
90
|
+
|
|
91
|
+
@staticmethod
|
|
92
|
+
def evaluate_heterogeneous_parameters(evaluate):
|
|
93
|
+
"""
|
|
94
|
+
Decorator which aims to decorate the evaluate methods of Dynamic losses
|
|
95
|
+
in order. It calls _eval_heterogeneous_parameters which applies the
|
|
96
|
+
user defined rules to obtain spatially / temporally heterogeneous
|
|
97
|
+
parameters
|
|
98
|
+
"""
|
|
99
|
+
|
|
100
|
+
def wrapper(*args):
|
|
101
|
+
self, t, u, params = args
|
|
102
|
+
params["eq_params"] = self.eval_heterogeneous_parameters(
|
|
103
|
+
t, u, params, self.eq_params_heterogeneity
|
|
104
|
+
)
|
|
105
|
+
new_args = args[:-1] + (params,)
|
|
106
|
+
res = evaluate(*new_args)
|
|
107
|
+
return res
|
|
108
|
+
|
|
109
|
+
return wrapper
|
|
110
|
+
|
|
82
111
|
|
|
83
112
|
class PDEStatio(DynamicLoss):
|
|
84
113
|
r"""
|
|
@@ -99,6 +128,31 @@ class PDEStatio(DynamicLoss):
|
|
|
99
128
|
"""
|
|
100
129
|
super().__init__(eq_params_heterogeneity=eq_params_heterogeneity)
|
|
101
130
|
|
|
131
|
+
def eval_heterogeneous_parameters(self, x, u, params, eq_params_heterogeneity=None):
|
|
132
|
+
return super()._eval_heterogeneous_parameters(
|
|
133
|
+
None, x, u, params, eq_params_heterogeneity
|
|
134
|
+
)
|
|
135
|
+
|
|
136
|
+
@staticmethod
|
|
137
|
+
def evaluate_heterogeneous_parameters(evaluate):
|
|
138
|
+
"""
|
|
139
|
+
Decorator which aims to decorate the evaluate methods of Dynamic losses
|
|
140
|
+
in order. It calls _eval_heterogeneous_parameters which applies the
|
|
141
|
+
user defined rules to obtain spatially / temporally heterogeneous
|
|
142
|
+
parameters
|
|
143
|
+
"""
|
|
144
|
+
|
|
145
|
+
def wrapper(*args):
|
|
146
|
+
self, x, u, params = args
|
|
147
|
+
params["eq_params"] = self.eval_heterogeneous_parameters(
|
|
148
|
+
x, u, params, self.eq_params_heterogeneity
|
|
149
|
+
)
|
|
150
|
+
new_args = args[:-1] + (params,)
|
|
151
|
+
res = evaluate(*new_args)
|
|
152
|
+
return res
|
|
153
|
+
|
|
154
|
+
return wrapper
|
|
155
|
+
|
|
102
156
|
|
|
103
157
|
class PDENonStatio(DynamicLoss):
|
|
104
158
|
r"""
|
|
@@ -122,3 +176,30 @@ class PDENonStatio(DynamicLoss):
|
|
|
122
176
|
heterogeneity for no parameters.
|
|
123
177
|
"""
|
|
124
178
|
super().__init__(Tmax, eq_params_heterogeneity)
|
|
179
|
+
|
|
180
|
+
def eval_heterogeneous_parameters(
|
|
181
|
+
self, t, x, u, params, eq_params_heterogeneity=None
|
|
182
|
+
):
|
|
183
|
+
return super()._eval_heterogeneous_parameters(
|
|
184
|
+
t, x, u, params, eq_params_heterogeneity
|
|
185
|
+
)
|
|
186
|
+
|
|
187
|
+
@staticmethod
|
|
188
|
+
def evaluate_heterogeneous_parameters(evaluate):
|
|
189
|
+
"""
|
|
190
|
+
Decorator which aims to decorate the evaluate methods of Dynamic losses
|
|
191
|
+
in order. It calls _eval_heterogeneous_parameters which applies the
|
|
192
|
+
user defined rules to obtain spatially / temporally heterogeneous
|
|
193
|
+
parameters
|
|
194
|
+
"""
|
|
195
|
+
|
|
196
|
+
def wrapper(*args):
|
|
197
|
+
self, t, x, u, params = args
|
|
198
|
+
params["eq_params"] = self.eval_heterogeneous_parameters(
|
|
199
|
+
t, x, u, params, self.eq_params_heterogeneity
|
|
200
|
+
)
|
|
201
|
+
new_args = args[:-1] + (params,)
|
|
202
|
+
res = evaluate(*new_args)
|
|
203
|
+
return res
|
|
204
|
+
|
|
205
|
+
return wrapper
|
jinns/loss/_LossPDE.py
CHANGED
|
@@ -254,6 +254,7 @@ class LossPDEStatio(LossPDEAbstract):
|
|
|
254
254
|
derivative_keys=None,
|
|
255
255
|
omega_boundary_fun=None,
|
|
256
256
|
omega_boundary_condition=None,
|
|
257
|
+
omega_boundary_dim=None,
|
|
257
258
|
norm_key=None,
|
|
258
259
|
norm_borders=None,
|
|
259
260
|
norm_samples=None,
|
|
@@ -310,6 +311,12 @@ class LossPDEStatio(LossPDEAbstract):
|
|
|
310
311
|
enforce a particular boundary condition on this facet.
|
|
311
312
|
The facet called "xmin", resp. "xmax" etc., in 2D,
|
|
312
313
|
refers to the set of 2D points with fixed "xmin", resp. "xmax", etc.
|
|
314
|
+
omega_boundary_dim
|
|
315
|
+
Either None, or a jnp.s_ or a dict of jnp.s_ with keys following
|
|
316
|
+
the logic of omega_boundary_fun. It indicates which dimension(s) of
|
|
317
|
+
the PINN will be forced to match the boundary condition
|
|
318
|
+
Note that it must be a slice and not an integer (a preprocessing of the
|
|
319
|
+
user provided argument takes care of it)
|
|
313
320
|
norm_key
|
|
314
321
|
Jax random key to draw samples in for the Monte Carlo computation
|
|
315
322
|
of the normalization constant. Default is None
|
|
@@ -345,7 +352,7 @@ class LossPDEStatio(LossPDEAbstract):
|
|
|
345
352
|
raise ValueError(
|
|
346
353
|
f"obs_batch must be a list of size 2. You gave {len(obs_batch)}"
|
|
347
354
|
)
|
|
348
|
-
if
|
|
355
|
+
if not all(isinstance(b, jnp.ndarray) for b in obs_batch):
|
|
349
356
|
raise ValueError("Every element of obs_batch should be a jnp.array.")
|
|
350
357
|
n_obs = obs_batch[0].shape[0]
|
|
351
358
|
if any(b.shape[0] != n_obs for b in obs_batch):
|
|
@@ -411,16 +418,49 @@ class LossPDEStatio(LossPDEAbstract):
|
|
|
411
418
|
|
|
412
419
|
self.omega_boundary_fun = omega_boundary_fun
|
|
413
420
|
self.omega_boundary_condition = omega_boundary_condition
|
|
421
|
+
|
|
422
|
+
self.omega_boundary_dim = omega_boundary_dim
|
|
423
|
+
if isinstance(self.omega_boundary_fun, dict):
|
|
424
|
+
if self.omega_boundary_dim is None:
|
|
425
|
+
self.omega_boundary_dim = {
|
|
426
|
+
k: jnp.s_[::] for k in self.omega_boundary_fun.keys()
|
|
427
|
+
}
|
|
428
|
+
if list(self.omega_boundary_dim.keys()) != list(
|
|
429
|
+
self.omega_boundary_fun.keys()
|
|
430
|
+
):
|
|
431
|
+
raise ValueError(
|
|
432
|
+
"If omega_boundary_fun is a dict,"
|
|
433
|
+
" omega_boundary_dim should be a dict with the same keys"
|
|
434
|
+
)
|
|
435
|
+
for k, v in self.omega_boundary_dim.items():
|
|
436
|
+
if isinstance(v, int):
|
|
437
|
+
# rewrite it as a slice to ensure that axis does not disappear when
|
|
438
|
+
# indexing
|
|
439
|
+
self.omega_boundary_dim[k] = jnp.s_[v : v + 1]
|
|
440
|
+
|
|
441
|
+
else:
|
|
442
|
+
if self.omega_boundary_dim is None:
|
|
443
|
+
self.omega_boundary_dim = jnp.s_[::]
|
|
444
|
+
if isinstance(self.omega_boundary_dim, int):
|
|
445
|
+
# rewrite it as a slice to ensure that axis does not disappear when
|
|
446
|
+
# indexing
|
|
447
|
+
self.omega_boundary_dim = jnp.s_[
|
|
448
|
+
self.omega_boundary_dim : self.omega_boundary_dim + 1
|
|
449
|
+
]
|
|
450
|
+
if not isinstance(self.omega_boundary_dim, slice):
|
|
451
|
+
raise ValueError("self.omega_boundary_dim must be a jnp.s_" " object")
|
|
452
|
+
|
|
414
453
|
self.dynamic_loss = dynamic_loss
|
|
415
454
|
self.obs_batch = obs_batch
|
|
416
455
|
|
|
417
|
-
|
|
456
|
+
self.sobolev_m = sobolev_m
|
|
457
|
+
if self.sobolev_m is not None:
|
|
418
458
|
self.sobolev_reg = _sobolev(
|
|
419
|
-
self.u, sobolev_m
|
|
459
|
+
self.u, self.sobolev_m
|
|
420
460
|
) # we return a function, that way
|
|
421
461
|
# the order of sobolev_m is static and the conditional in the recursive
|
|
422
462
|
# function is properly set
|
|
423
|
-
self.sobolev_m = sobolev_m
|
|
463
|
+
self.sobolev_m = self.sobolev_m
|
|
424
464
|
else:
|
|
425
465
|
self.sobolev_reg = None
|
|
426
466
|
|
|
@@ -512,7 +552,7 @@ class LossPDEStatio(LossPDEAbstract):
|
|
|
512
552
|
if self.normalization_loss is not None:
|
|
513
553
|
if isinstance(self.u, PINN):
|
|
514
554
|
v_u = vmap(
|
|
515
|
-
lambda x: self.u(x, params_),
|
|
555
|
+
lambda x: self.u(x, params_)[self.u.slice_solution],
|
|
516
556
|
(0),
|
|
517
557
|
0,
|
|
518
558
|
)
|
|
@@ -560,6 +600,7 @@ class LossPDEStatio(LossPDEAbstract):
|
|
|
560
600
|
self.u,
|
|
561
601
|
params_,
|
|
562
602
|
idx,
|
|
603
|
+
self.omega_boundary_dim[facet],
|
|
563
604
|
)
|
|
564
605
|
)
|
|
565
606
|
else:
|
|
@@ -574,6 +615,7 @@ class LossPDEStatio(LossPDEAbstract):
|
|
|
574
615
|
self.u,
|
|
575
616
|
params_,
|
|
576
617
|
facet,
|
|
618
|
+
self.omega_boundary_dim,
|
|
577
619
|
)
|
|
578
620
|
)
|
|
579
621
|
else:
|
|
@@ -585,11 +627,11 @@ class LossPDEStatio(LossPDEAbstract):
|
|
|
585
627
|
# TODO implement for SPINN
|
|
586
628
|
if isinstance(self.u, PINN):
|
|
587
629
|
v_u = vmap(
|
|
588
|
-
lambda x: self.u(x, params_),
|
|
630
|
+
lambda x: self.u(x, params_)[self.u.slice_solution],
|
|
589
631
|
0,
|
|
590
632
|
0,
|
|
591
633
|
)
|
|
592
|
-
val = v_u(self.obs_batch[0]
|
|
634
|
+
val = v_u(self.obs_batch[0])
|
|
593
635
|
mse_observation_loss = jnp.mean(
|
|
594
636
|
self.loss_weights["observations"]
|
|
595
637
|
* jnp.sum(
|
|
@@ -655,6 +697,7 @@ class LossPDEStatio(LossPDEAbstract):
|
|
|
655
697
|
"derivative_keys": self.derivative_keys,
|
|
656
698
|
"omega_boundary_fun": self.omega_boundary_fun,
|
|
657
699
|
"omega_boundary_condition": self.omega_boundary_condition,
|
|
700
|
+
"omega_boundary_dim": self.omega_boundary_dim,
|
|
658
701
|
"norm_borders": self.norm_borders,
|
|
659
702
|
"sobolev_m": self.sobolev_m,
|
|
660
703
|
}
|
|
@@ -670,6 +713,7 @@ class LossPDEStatio(LossPDEAbstract):
|
|
|
670
713
|
aux_data["derivative_keys"],
|
|
671
714
|
aux_data["omega_boundary_fun"],
|
|
672
715
|
aux_data["omega_boundary_condition"],
|
|
716
|
+
aux_data["omega_boundary_dim"],
|
|
673
717
|
norm_key,
|
|
674
718
|
aux_data["norm_borders"],
|
|
675
719
|
norm_samples,
|
|
@@ -706,6 +750,7 @@ class LossPDENonStatio(LossPDEStatio):
|
|
|
706
750
|
derivative_keys=None,
|
|
707
751
|
omega_boundary_fun=None,
|
|
708
752
|
omega_boundary_condition=None,
|
|
753
|
+
omega_boundary_dim=None,
|
|
709
754
|
initial_condition_fun=None,
|
|
710
755
|
norm_key=None,
|
|
711
756
|
norm_borders=None,
|
|
@@ -760,6 +805,12 @@ class LossPDENonStatio(LossPDEStatio):
|
|
|
760
805
|
enforce a particular boundary condition on this facet.
|
|
761
806
|
The facet called "xmin", resp. "xmax" etc., in 2D,
|
|
762
807
|
refers to the set of 2D points with fixed "xmin", resp. "xmax", etc.
|
|
808
|
+
omega_boundary_dim
|
|
809
|
+
Either None, or a jnp.s_ or a dict of jnp.s_ with keys following
|
|
810
|
+
the logic of omega_boundary_fun. It indicates which dimension(s) of
|
|
811
|
+
the PINN will be forced to match the boundary condition
|
|
812
|
+
Note that it must be a slice and not an integer (a preprocessing of the
|
|
813
|
+
user provided argument takes care of it)
|
|
763
814
|
initial_condition_fun
|
|
764
815
|
A function representing the temporal initial condition. If None
|
|
765
816
|
(default) then no initial condition is applied
|
|
@@ -815,6 +866,7 @@ class LossPDENonStatio(LossPDEStatio):
|
|
|
815
866
|
derivative_keys,
|
|
816
867
|
omega_boundary_fun,
|
|
817
868
|
omega_boundary_condition,
|
|
869
|
+
omega_boundary_dim,
|
|
818
870
|
norm_key,
|
|
819
871
|
norm_borders,
|
|
820
872
|
norm_samples,
|
|
@@ -978,6 +1030,7 @@ class LossPDENonStatio(LossPDEStatio):
|
|
|
978
1030
|
self.u,
|
|
979
1031
|
params_,
|
|
980
1032
|
idx,
|
|
1033
|
+
self.omega_boundary_dim[facet],
|
|
981
1034
|
)
|
|
982
1035
|
)
|
|
983
1036
|
else:
|
|
@@ -993,6 +1046,7 @@ class LossPDENonStatio(LossPDEStatio):
|
|
|
993
1046
|
self.u,
|
|
994
1047
|
params_,
|
|
995
1048
|
facet,
|
|
1049
|
+
self.omega_boundary_dim,
|
|
996
1050
|
)
|
|
997
1051
|
)
|
|
998
1052
|
else:
|
|
@@ -1036,7 +1090,7 @@ class LossPDENonStatio(LossPDEStatio):
|
|
|
1036
1090
|
# TODO implement for SPINN
|
|
1037
1091
|
if isinstance(self.u, PINN):
|
|
1038
1092
|
v_u = vmap(
|
|
1039
|
-
lambda t, x: self.u(t, x, params_),
|
|
1093
|
+
lambda t, x: self.u(t, x, params_)[self.u.slice_solution],
|
|
1040
1094
|
(0, 0),
|
|
1041
1095
|
0,
|
|
1042
1096
|
)
|
|
@@ -1108,6 +1162,7 @@ class LossPDENonStatio(LossPDEStatio):
|
|
|
1108
1162
|
"derivative_keys": self.derivative_keys,
|
|
1109
1163
|
"omega_boundary_fun": self.omega_boundary_fun,
|
|
1110
1164
|
"omega_boundary_condition": self.omega_boundary_condition,
|
|
1165
|
+
"omega_boundary_dim": self.omega_boundary_dim,
|
|
1111
1166
|
"initial_condition_fun": self.initial_condition_fun,
|
|
1112
1167
|
"norm_borders": self.norm_borders,
|
|
1113
1168
|
"sobolev_m": self.sobolev_m,
|
|
@@ -1124,6 +1179,7 @@ class LossPDENonStatio(LossPDEStatio):
|
|
|
1124
1179
|
aux_data["derivative_keys"],
|
|
1125
1180
|
aux_data["omega_boundary_fun"],
|
|
1126
1181
|
aux_data["omega_boundary_condition"],
|
|
1182
|
+
aux_data["omega_boundary_dim"],
|
|
1127
1183
|
aux_data["initial_condition_fun"],
|
|
1128
1184
|
norm_key,
|
|
1129
1185
|
aux_data["norm_borders"],
|
|
@@ -1163,6 +1219,7 @@ class SystemLossPDE:
|
|
|
1163
1219
|
derivative_keys_dict=None,
|
|
1164
1220
|
omega_boundary_fun_dict=None,
|
|
1165
1221
|
omega_boundary_condition_dict=None,
|
|
1222
|
+
omega_boundary_dim_dict=None,
|
|
1166
1223
|
initial_condition_fun_dict=None,
|
|
1167
1224
|
norm_key_dict=None,
|
|
1168
1225
|
norm_borders_dict=None,
|
|
@@ -1199,15 +1256,17 @@ class SystemLossPDE:
|
|
|
1199
1256
|
are not specified then the default behaviour for `derivative_keys`
|
|
1200
1257
|
of LossODE is used
|
|
1201
1258
|
omega_boundary_fun_dict
|
|
1202
|
-
A dict of
|
|
1203
|
-
|
|
1204
|
-
|
|
1205
|
-
Must share the keys of `u_dict`
|
|
1259
|
+
A dict of dict of functions (see doc for `omega_boundary_fun` in
|
|
1260
|
+
LossPDEStatio or LossPDENonStatio). Default is None.
|
|
1261
|
+
Must share the keys of `u_dict`.
|
|
1206
1262
|
omega_boundary_condition_dict
|
|
1207
|
-
A dict of
|
|
1208
|
-
|
|
1209
|
-
|
|
1210
|
-
|
|
1263
|
+
A dict of dict of strings (see doc for
|
|
1264
|
+
`omega_boundary_condition_dict` in
|
|
1265
|
+
LossPDEStatio or LossPDENonStatio). Default is None.
|
|
1266
|
+
Must share the keys of `u_dict`
|
|
1267
|
+
omega_boundary_dim_dict
|
|
1268
|
+
A dict of dict of slices (see doc for `omega_boundary_dim` in
|
|
1269
|
+
LossPDEStatio or LossPDENonStatio). Default is None.
|
|
1211
1270
|
Must share the keys of `u_dict`
|
|
1212
1271
|
initial_condition_fun_dict
|
|
1213
1272
|
A dict of functions representing the temporal initial condition. If None
|
|
@@ -1265,6 +1324,10 @@ class SystemLossPDE:
|
|
|
1265
1324
|
self.omega_boundary_condition_dict = {k: None for k in u_dict.keys()}
|
|
1266
1325
|
else:
|
|
1267
1326
|
self.omega_boundary_condition_dict = omega_boundary_condition_dict
|
|
1327
|
+
if omega_boundary_dim_dict is None:
|
|
1328
|
+
self.omega_boundary_dim_dict = {k: None for k in u_dict.keys()}
|
|
1329
|
+
else:
|
|
1330
|
+
self.omega_boundary_dim_dict = omega_boundary_dim_dict
|
|
1268
1331
|
if initial_condition_fun_dict is None:
|
|
1269
1332
|
self.initial_condition_fun_dict = {k: None for k in u_dict.keys()}
|
|
1270
1333
|
else:
|
|
@@ -1309,6 +1372,7 @@ class SystemLossPDE:
|
|
|
1309
1372
|
or u_dict.keys() != self.obs_batch_dict.keys()
|
|
1310
1373
|
or u_dict.keys() != self.omega_boundary_fun_dict.keys()
|
|
1311
1374
|
or u_dict.keys() != self.omega_boundary_condition_dict.keys()
|
|
1375
|
+
or u_dict.keys() != self.omega_boundary_dim_dict.keys()
|
|
1312
1376
|
or u_dict.keys() != self.initial_condition_fun_dict.keys()
|
|
1313
1377
|
or u_dict.keys() != self.norm_key_dict.keys()
|
|
1314
1378
|
or u_dict.keys() != self.norm_borders_dict.keys()
|
|
@@ -1345,6 +1409,7 @@ class SystemLossPDE:
|
|
|
1345
1409
|
derivative_keys=self.derivative_keys_dict[i],
|
|
1346
1410
|
omega_boundary_fun=self.omega_boundary_fun_dict[i],
|
|
1347
1411
|
omega_boundary_condition=self.omega_boundary_condition_dict[i],
|
|
1412
|
+
omega_boundary_dim=self.omega_boundary_dim_dict[i],
|
|
1348
1413
|
norm_key=self.norm_key_dict[i],
|
|
1349
1414
|
norm_borders=self.norm_borders_dict[i],
|
|
1350
1415
|
norm_samples=self.norm_samples_dict[i],
|
|
@@ -1366,6 +1431,7 @@ class SystemLossPDE:
|
|
|
1366
1431
|
derivative_keys=self.derivative_keys_dict[i],
|
|
1367
1432
|
omega_boundary_fun=self.omega_boundary_fun_dict[i],
|
|
1368
1433
|
omega_boundary_condition=self.omega_boundary_condition_dict[i],
|
|
1434
|
+
omega_boundary_dim=self.omega_boundary_dim_dict[i],
|
|
1369
1435
|
initial_condition_fun=self.initial_condition_fun_dict[i],
|
|
1370
1436
|
norm_key=self.norm_key_dict[i],
|
|
1371
1437
|
norm_borders=self.norm_borders_dict[i],
|
|
@@ -11,7 +11,7 @@ from jinns.utils._spinn import SPINN
|
|
|
11
11
|
|
|
12
12
|
|
|
13
13
|
def _compute_boundary_loss_statio(
|
|
14
|
-
boundary_condition_type, f, border_batch, u, params, facet
|
|
14
|
+
boundary_condition_type, f, border_batch, u, params, facet, dim_to_apply
|
|
15
15
|
):
|
|
16
16
|
r"""A generic function that will compute the mini-batch MSE of a
|
|
17
17
|
boundary condition in the stationary case, given by:
|
|
@@ -44,6 +44,9 @@ def _compute_boundary_loss_statio(
|
|
|
44
44
|
facet:
|
|
45
45
|
An integer which represents the id of the facet which is currently
|
|
46
46
|
considered (in the order provided wy the DataGenerator which is fixed)
|
|
47
|
+
dim_to_apply
|
|
48
|
+
A jnp.s_ object which indicates which dimension(s) of u will be forced
|
|
49
|
+
to match the boundary condition
|
|
47
50
|
|
|
48
51
|
Returns
|
|
49
52
|
-------
|
|
@@ -51,17 +54,24 @@ def _compute_boundary_loss_statio(
|
|
|
51
54
|
the MSE computed on `border_batch`
|
|
52
55
|
"""
|
|
53
56
|
if boundary_condition_type.lower() in "dirichlet":
|
|
54
|
-
mse = boundary_dirichlet_statio(f, border_batch, u, params)
|
|
57
|
+
mse = boundary_dirichlet_statio(f, border_batch, u, params, dim_to_apply)
|
|
55
58
|
elif any(
|
|
56
59
|
boundary_condition_type.lower() in s
|
|
57
60
|
for s in ["von neumann", "vn", "vonneumann"]
|
|
58
61
|
):
|
|
59
|
-
mse = boundary_neumann_statio(f, border_batch, u, params, facet)
|
|
62
|
+
mse = boundary_neumann_statio(f, border_batch, u, params, facet, dim_to_apply)
|
|
60
63
|
return mse
|
|
61
64
|
|
|
62
65
|
|
|
63
66
|
def _compute_boundary_loss_nonstatio(
|
|
64
|
-
boundary_condition_type,
|
|
67
|
+
boundary_condition_type,
|
|
68
|
+
f,
|
|
69
|
+
times_batch,
|
|
70
|
+
border_batch,
|
|
71
|
+
u,
|
|
72
|
+
params,
|
|
73
|
+
facet,
|
|
74
|
+
dim_to_apply,
|
|
65
75
|
):
|
|
66
76
|
r"""A generic function that will compute the mini-batch MSE of a
|
|
67
77
|
boundary condition in the non-stationary case, given by:
|
|
@@ -95,6 +105,8 @@ def _compute_boundary_loss_nonstatio(
|
|
|
95
105
|
facet:
|
|
96
106
|
An integer which represents the id of the facet which is currently
|
|
97
107
|
considered (in the order provided wy the DataGenerator which is fixed)
|
|
108
|
+
dim_to_apply
|
|
109
|
+
A jnp.s_ object. The dimension of u on which to apply the boundary condition
|
|
98
110
|
|
|
99
111
|
Returns
|
|
100
112
|
-------
|
|
@@ -102,16 +114,20 @@ def _compute_boundary_loss_nonstatio(
|
|
|
102
114
|
the MSE computed on `border_batch`
|
|
103
115
|
"""
|
|
104
116
|
if boundary_condition_type.lower() in "dirichlet":
|
|
105
|
-
mse = boundary_dirichlet_nonstatio(
|
|
117
|
+
mse = boundary_dirichlet_nonstatio(
|
|
118
|
+
f, times_batch, border_batch, u, params, dim_to_apply
|
|
119
|
+
)
|
|
106
120
|
elif any(
|
|
107
121
|
boundary_condition_type.lower() in s
|
|
108
122
|
for s in ["von neumann", "vn", "vonneumann"]
|
|
109
123
|
):
|
|
110
|
-
mse = boundary_neumann_nonstatio(
|
|
124
|
+
mse = boundary_neumann_nonstatio(
|
|
125
|
+
f, times_batch, border_batch, u, params, facet, dim_to_apply
|
|
126
|
+
)
|
|
111
127
|
return mse
|
|
112
128
|
|
|
113
129
|
|
|
114
|
-
def boundary_dirichlet_statio(f, border_batch, u, params):
|
|
130
|
+
def boundary_dirichlet_statio(f, border_batch, u, params, dim_to_apply):
|
|
115
131
|
r"""
|
|
116
132
|
This omega boundary condition enforces a solution that is equal to f on
|
|
117
133
|
border batch.
|
|
@@ -129,17 +145,19 @@ def boundary_dirichlet_statio(f, border_batch, u, params):
|
|
|
129
145
|
Typically, it is a dictionary of
|
|
130
146
|
dictionaries: `eq_params` and `nn_params``, respectively the
|
|
131
147
|
differential equation parameters and the neural network parameter
|
|
148
|
+
dim_to_apply
|
|
149
|
+
A jnp.s_ object. The dimension of u on which to apply the boundary condition
|
|
132
150
|
"""
|
|
133
151
|
if isinstance(u, PINN):
|
|
134
152
|
v_u_boundary = vmap(
|
|
135
|
-
lambda dx: u(dx, params) - f(dx),
|
|
153
|
+
lambda dx: u(dx, params)[dim_to_apply] - f(dx),
|
|
136
154
|
(0),
|
|
137
155
|
0,
|
|
138
156
|
)
|
|
139
157
|
|
|
140
158
|
mse_u_boundary = jnp.sum((v_u_boundary(border_batch)) ** 2, axis=-1)
|
|
141
159
|
elif isinstance(u, SPINN):
|
|
142
|
-
values = u(border_batch, params)
|
|
160
|
+
values = u(border_batch, params)[..., dim_to_apply]
|
|
143
161
|
x_grid = _get_grid(border_batch)
|
|
144
162
|
boundaries = _check_user_func_return(f(x_grid), values.shape)
|
|
145
163
|
res = values - boundaries
|
|
@@ -150,7 +168,7 @@ def boundary_dirichlet_statio(f, border_batch, u, params):
|
|
|
150
168
|
return mse_u_boundary
|
|
151
169
|
|
|
152
170
|
|
|
153
|
-
def boundary_neumann_statio(f, border_batch, u, params, facet):
|
|
171
|
+
def boundary_neumann_statio(f, border_batch, u, params, facet, dim_to_apply):
|
|
154
172
|
r"""
|
|
155
173
|
This omega boundary condition enforces a solution where :math:`\nabla u\cdot
|
|
156
174
|
n` is equal to `f` on omega borders. :math:`n` is the unitary
|
|
@@ -173,6 +191,8 @@ def boundary_neumann_statio(f, border_batch, u, params, facet):
|
|
|
173
191
|
facet:
|
|
174
192
|
An integer which represents the id of the facet which is currently
|
|
175
193
|
considered (in the order provided wy the DataGenerator which is fixed)
|
|
194
|
+
dim_to_apply
|
|
195
|
+
A jnp.s_ object. The dimension of u on which to apply the boundary condition
|
|
176
196
|
"""
|
|
177
197
|
# We resort to the shape of the border_batch to determine the dimension as
|
|
178
198
|
# described in the border_batch function
|
|
@@ -186,7 +206,7 @@ def boundary_neumann_statio(f, border_batch, u, params, facet):
|
|
|
186
206
|
n = jnp.array([[-1, 1, 0, 0], [0, 0, -1, 1]])
|
|
187
207
|
|
|
188
208
|
if isinstance(u, PINN):
|
|
189
|
-
u_ = lambda x, params: u(x, params)[
|
|
209
|
+
u_ = lambda x, params: jnp.squeeze(u(x, params)[dim_to_apply])
|
|
190
210
|
v_neumann = vmap(
|
|
191
211
|
lambda dx: jnp.dot(
|
|
192
212
|
grad(u_, 0)(dx, params),
|
|
@@ -206,7 +226,7 @@ def boundary_neumann_statio(f, border_batch, u, params, facet):
|
|
|
206
226
|
lambda x: u(
|
|
207
227
|
x,
|
|
208
228
|
params,
|
|
209
|
-
),
|
|
229
|
+
)[..., dim_to_apply],
|
|
210
230
|
(border_batch,),
|
|
211
231
|
(jnp.ones_like(border_batch),),
|
|
212
232
|
)
|
|
@@ -246,7 +266,9 @@ def boundary_neumann_statio(f, border_batch, u, params, facet):
|
|
|
246
266
|
return mse_u_boundary
|
|
247
267
|
|
|
248
268
|
|
|
249
|
-
def boundary_dirichlet_nonstatio(
|
|
269
|
+
def boundary_dirichlet_nonstatio(
|
|
270
|
+
f, times_batch, omega_border_batch, u, params, dim_to_apply
|
|
271
|
+
):
|
|
250
272
|
"""
|
|
251
273
|
This omega boundary condition enforces a solution that is equal to f
|
|
252
274
|
at times_batch x omega borders
|
|
@@ -266,6 +288,8 @@ def boundary_dirichlet_nonstatio(f, times_batch, omega_border_batch, u, params):
|
|
|
266
288
|
Typically, it is a dictionary of
|
|
267
289
|
dictionaries: `eq_params` and `nn_params``, respectively the
|
|
268
290
|
differential equation parameters and the neural network parameter
|
|
291
|
+
dim_to_apply
|
|
292
|
+
A jnp.s_ object. The dimension of u on which to apply the boundary condition
|
|
269
293
|
"""
|
|
270
294
|
if isinstance(u, PINN):
|
|
271
295
|
tile_omega_border_batch = jnp.tile(
|
|
@@ -280,7 +304,7 @@ def boundary_dirichlet_nonstatio(f, times_batch, omega_border_batch, u, params):
|
|
|
280
304
|
t,
|
|
281
305
|
dx,
|
|
282
306
|
params,
|
|
283
|
-
)
|
|
307
|
+
)[dim_to_apply]
|
|
284
308
|
- f(t, dx),
|
|
285
309
|
(0, 0),
|
|
286
310
|
0,
|
|
@@ -304,7 +328,7 @@ def boundary_dirichlet_nonstatio(f, times_batch, omega_border_batch, u, params):
|
|
|
304
328
|
# otherwise we require batches to have same shape and we do not need
|
|
305
329
|
# this operation
|
|
306
330
|
|
|
307
|
-
values = u(times_batch, tile_omega_border_batch, params)
|
|
331
|
+
values = u(times_batch, tile_omega_border_batch, params)[..., dim_to_apply]
|
|
308
332
|
tx_grid = _get_grid(jnp.concatenate([times_batch, omega_border_batch], axis=-1))
|
|
309
333
|
boundaries = _check_user_func_return(
|
|
310
334
|
f(tx_grid[..., 0:1], tx_grid[..., 1:]), values.shape
|
|
@@ -314,7 +338,9 @@ def boundary_dirichlet_nonstatio(f, times_batch, omega_border_batch, u, params):
|
|
|
314
338
|
return mse_u_boundary
|
|
315
339
|
|
|
316
340
|
|
|
317
|
-
def boundary_neumann_nonstatio(
|
|
341
|
+
def boundary_neumann_nonstatio(
|
|
342
|
+
f, times_batch, omega_border_batch, u, params, facet, dim_to_apply
|
|
343
|
+
):
|
|
318
344
|
r"""
|
|
319
345
|
This omega boundary condition enforces a solution where :math:`\nabla u\cdot
|
|
320
346
|
n` is equal to `f` at time_batch x omega borders. :math:`n` is the unitary
|
|
@@ -338,6 +364,8 @@ def boundary_neumann_nonstatio(f, times_batch, omega_border_batch, u, params, fa
|
|
|
338
364
|
facet:
|
|
339
365
|
An integer which represents the id of the facet which is currently
|
|
340
366
|
considered (in the order provided wy the DataGenerator which is fixed)
|
|
367
|
+
dim_to_apply
|
|
368
|
+
A jnp.s_ object. The dimension of u on which to apply the boundary condition
|
|
341
369
|
"""
|
|
342
370
|
# We resort to the shape of the border_batch to determine the dimension as
|
|
343
371
|
# described in the border_batch function
|
|
@@ -358,7 +386,7 @@ def boundary_neumann_nonstatio(f, times_batch, omega_border_batch, u, params, fa
|
|
|
358
386
|
def rep_times(k):
|
|
359
387
|
return jnp.repeat(times_batch, k, axis=0)
|
|
360
388
|
|
|
361
|
-
u_ = lambda t, x, params: u(t, x, params)[
|
|
389
|
+
u_ = lambda t, x, params: jnp.squeeze(u(t, x, params)[dim_to_apply])
|
|
362
390
|
v_neumann = vmap(
|
|
363
391
|
lambda t, dx: jnp.dot(
|
|
364
392
|
grad(u_, 1)(t, dx, params),
|
|
@@ -389,7 +417,7 @@ def boundary_neumann_nonstatio(f, times_batch, omega_border_batch, u, params, fa
|
|
|
389
417
|
# high dim output at once
|
|
390
418
|
if omega_border_batch.shape[0] == 1: # i.e. case 1D
|
|
391
419
|
_, du_dx = jax.jvp(
|
|
392
|
-
lambda x: u(times_batch, x, params),
|
|
420
|
+
lambda x: u(times_batch, x, params)[..., dim_to_apply],
|
|
393
421
|
(omega_border_batch,),
|
|
394
422
|
(jnp.ones_like(omega_border_batch),),
|
|
395
423
|
)
|
|
@@ -402,12 +430,12 @@ def boundary_neumann_nonstatio(f, times_batch, omega_border_batch, u, params, fa
|
|
|
402
430
|
jnp.array([0.0, 1.0])[None], omega_border_batch.shape[0], axis=0
|
|
403
431
|
)
|
|
404
432
|
_, du_dx1 = jax.jvp(
|
|
405
|
-
lambda x: u(times_batch, x, params),
|
|
433
|
+
lambda x: u(times_batch, x, params)[..., dim_to_apply],
|
|
406
434
|
(omega_border_batch,),
|
|
407
435
|
(tangent_vec_0,),
|
|
408
436
|
)
|
|
409
437
|
_, du_dx2 = jax.jvp(
|
|
410
|
-
lambda x: u(times_batch, x, params),
|
|
438
|
+
lambda x: u(times_batch, x, params)[..., dim_to_apply],
|
|
411
439
|
(omega_border_batch,),
|
|
412
440
|
(tangent_vec_1,),
|
|
413
441
|
)
|
jinns/loss/_operators.py
CHANGED
|
@@ -9,7 +9,7 @@ from jinns.utils._pinn import PINN
|
|
|
9
9
|
from jinns.utils._spinn import SPINN
|
|
10
10
|
|
|
11
11
|
|
|
12
|
-
def _div_rev(
|
|
12
|
+
def _div_rev(t, x, u, params):
|
|
13
13
|
r"""
|
|
14
14
|
Compute the divergence of a vector field :math:`\mathbf{u}`, i.e.,
|
|
15
15
|
:math:`\nabla \cdot \mathbf{u}(\mathbf{x})` with :math:`\mathbf{u}` a vector
|
|
@@ -21,14 +21,14 @@ def _div_rev(u, params, x, t=None):
|
|
|
21
21
|
if t is None:
|
|
22
22
|
du_dxi = grad(lambda x, params: u(x, params)[i], 0)(x, params)[i]
|
|
23
23
|
else:
|
|
24
|
-
du_dxi = grad(lambda t, x, params: u(x, params)[i], 1)(x, params)[i]
|
|
24
|
+
du_dxi = grad(lambda t, x, params: u(t, x, params)[i], 1)(t, x, params)[i]
|
|
25
25
|
return _, du_dxi
|
|
26
26
|
|
|
27
27
|
_, accu = jax.lax.scan(scan_fun, {}, jnp.arange(x.shape[0]))
|
|
28
28
|
return jnp.sum(accu)
|
|
29
29
|
|
|
30
30
|
|
|
31
|
-
def _div_fwd(
|
|
31
|
+
def _div_fwd(t, x, u, params):
|
|
32
32
|
r"""
|
|
33
33
|
Compute the divergence of a **batched** vector field :math:`\mathbf{u}`, i.e.,
|
|
34
34
|
:math:`\nabla \cdot \mathbf{u}(\mathbf{x})` with :math:`\mathbf{u}` a vector
|
|
@@ -55,7 +55,7 @@ def _div_fwd(u, params, x, t=None):
|
|
|
55
55
|
return jnp.sum(accu, axis=0)
|
|
56
56
|
|
|
57
57
|
|
|
58
|
-
def _laplacian_rev(
|
|
58
|
+
def _laplacian_rev(t, x, u, params):
|
|
59
59
|
r"""
|
|
60
60
|
Compute the Laplacian of a scalar field :math:`u` (from :math:`\mathbb{R}^d`
|
|
61
61
|
to :math:`\mathbb{R}`) for :math:`\mathbf{x}` of arbitrary dimension, i.e.,
|
|
@@ -98,7 +98,7 @@ def _laplacian_rev(u, params, x, t=None):
|
|
|
98
98
|
# return jnp.sum(trace_hessian)
|
|
99
99
|
|
|
100
100
|
|
|
101
|
-
def _laplacian_fwd(
|
|
101
|
+
def _laplacian_fwd(t, x, u, params):
|
|
102
102
|
r"""
|
|
103
103
|
Compute the Laplacian of a **batched** scalar field :math:`u`
|
|
104
104
|
(from :math:`\mathbb{R}^{b\times d}` to :math:`\mathbb{R}^{b\times b}`)
|
|
@@ -134,7 +134,7 @@ def _laplacian_fwd(u, params, x, t=None):
|
|
|
134
134
|
return jnp.sum(trace_hessian, axis=0)
|
|
135
135
|
|
|
136
136
|
|
|
137
|
-
def _vectorial_laplacian(
|
|
137
|
+
def _vectorial_laplacian(t, x, u, params, u_vec_ndim=None):
|
|
138
138
|
r"""
|
|
139
139
|
Compute the vectorial Laplacian of a vector field :math:`\mathbf{u}` (from
|
|
140
140
|
:math:`\mathbb{R}^d`
|
|
@@ -163,7 +163,7 @@ def _vectorial_laplacian(u, params, x, t=None, u_vec_ndim=None):
|
|
|
163
163
|
uj = lambda x, params: jnp.expand_dims(u(x, params)[j], axis=-1)
|
|
164
164
|
else:
|
|
165
165
|
uj = lambda t, x, params: jnp.expand_dims(u(t, x, params)[j], axis=-1)
|
|
166
|
-
lap_on_j = _laplacian_rev(
|
|
166
|
+
lap_on_j = _laplacian_rev(t, x, uj, params)
|
|
167
167
|
elif isinstance(u, SPINN):
|
|
168
168
|
if t is None:
|
|
169
169
|
uj = lambda x, params: jnp.expand_dims(u(x, params)[..., j], axis=-1)
|
|
@@ -171,7 +171,7 @@ def _vectorial_laplacian(u, params, x, t=None, u_vec_ndim=None):
|
|
|
171
171
|
uj = lambda t, x, params: jnp.expand_dims(
|
|
172
172
|
u(t, x, params)[..., j], axis=-1
|
|
173
173
|
)
|
|
174
|
-
lap_on_j = _laplacian_fwd(
|
|
174
|
+
lap_on_j = _laplacian_fwd(t, x, uj, params)
|
|
175
175
|
|
|
176
176
|
return _, lap_on_j
|
|
177
177
|
|
|
@@ -179,7 +179,7 @@ def _vectorial_laplacian(u, params, x, t=None, u_vec_ndim=None):
|
|
|
179
179
|
return vec_lap
|
|
180
180
|
|
|
181
181
|
|
|
182
|
-
def _u_dot_nabla_times_u_rev(
|
|
182
|
+
def _u_dot_nabla_times_u_rev(t, x, u, params):
|
|
183
183
|
r"""
|
|
184
184
|
Implement :math:`((\mathbf{u}\cdot\nabla)\mathbf{u})(\mathbf{x})` for
|
|
185
185
|
:math:`\mathbf{x}` of arbitrary
|
|
@@ -224,7 +224,7 @@ def _u_dot_nabla_times_u_rev(u, params, x, t=None):
|
|
|
224
224
|
raise NotImplementedError("x.ndim must be 2")
|
|
225
225
|
|
|
226
226
|
|
|
227
|
-
def _u_dot_nabla_times_u_fwd(
|
|
227
|
+
def _u_dot_nabla_times_u_fwd(t, x, u, params):
|
|
228
228
|
r"""
|
|
229
229
|
Implement :math:`((\mathbf{u}\cdot\nabla)\mathbf{u})(\mathbf{x})` for
|
|
230
230
|
:math:`\mathbf{x}` of arbitrary dimension **with a batch dimension**.
|
jinns/utils/_pinn.py
CHANGED
|
@@ -64,16 +64,44 @@ class PINN:
|
|
|
64
64
|
The function create_PINN has the role to population the `__call__` function
|
|
65
65
|
"""
|
|
66
66
|
|
|
67
|
-
def __init__(
|
|
67
|
+
def __init__(
|
|
68
|
+
self,
|
|
69
|
+
key,
|
|
70
|
+
eqx_list,
|
|
71
|
+
slice_solution,
|
|
72
|
+
eq_type,
|
|
73
|
+
input_transform,
|
|
74
|
+
output_transform,
|
|
75
|
+
output_slice=None,
|
|
76
|
+
):
|
|
68
77
|
_pinn = _MLP(key, eqx_list)
|
|
69
78
|
self.params, self.static = eqx.partition(_pinn, eqx.is_inexact_array)
|
|
79
|
+
self.slice_solution = slice_solution
|
|
80
|
+
self.eq_type = eq_type
|
|
81
|
+
self.input_transform = input_transform
|
|
82
|
+
self.output_transform = output_transform
|
|
70
83
|
self.output_slice = output_slice
|
|
71
84
|
|
|
72
85
|
def init_params(self):
|
|
73
86
|
return self.params
|
|
74
87
|
|
|
75
|
-
def __call__(self, *args
|
|
76
|
-
|
|
88
|
+
def __call__(self, *args):
|
|
89
|
+
if self.eq_type == "ODE":
|
|
90
|
+
(t, params) = args
|
|
91
|
+
t = t[None] # Add dimension which is lacking for the ODE batches
|
|
92
|
+
return self._eval_nn(
|
|
93
|
+
t, params, self.input_transform, self.output_transform
|
|
94
|
+
).squeeze()
|
|
95
|
+
if self.eq_type == "statio_PDE":
|
|
96
|
+
(x, params) = args
|
|
97
|
+
return self._eval_nn(x, params, self.input_transform, self.output_transform)
|
|
98
|
+
if self.eq_type == "nonstatio_PDE":
|
|
99
|
+
(t, x, params) = args
|
|
100
|
+
t_x = jnp.concatenate([t, x], axis=-1)
|
|
101
|
+
return self._eval_nn(
|
|
102
|
+
t_x, params, self.input_transform, self.output_transform
|
|
103
|
+
)
|
|
104
|
+
raise ValueError("Wrong value for self.eq_type")
|
|
77
105
|
|
|
78
106
|
def _eval_nn(self, inputs, params, input_transform, output_transform):
|
|
79
107
|
"""
|
|
@@ -82,7 +110,7 @@ class PINN:
|
|
|
82
110
|
"""
|
|
83
111
|
try:
|
|
84
112
|
model = eqx.combine(params["nn_params"], self.static)
|
|
85
|
-
except: # give more flexibility
|
|
113
|
+
except (KeyError, TypeError) as e: # give more flexibility
|
|
86
114
|
model = eqx.combine(params, self.static)
|
|
87
115
|
res = output_transform(inputs, model(input_transform(inputs, params)).squeeze())
|
|
88
116
|
|
|
@@ -103,6 +131,7 @@ def create_PINN(
|
|
|
103
131
|
input_transform=None,
|
|
104
132
|
output_transform=None,
|
|
105
133
|
shared_pinn_outputs=None,
|
|
134
|
+
slice_solution=None,
|
|
106
135
|
):
|
|
107
136
|
"""
|
|
108
137
|
Utility function to create a standard PINN neural network with the equinox
|
|
@@ -152,6 +181,13 @@ def create_PINN(
|
|
|
152
181
|
network. In this case we return a list of PINNs, one for each output in
|
|
153
182
|
shared_pinn_outputs. This is useful to create PINNs that share the
|
|
154
183
|
same network and same parameters. Default is None, we only return one PINN.
|
|
184
|
+
slice_solution
|
|
185
|
+
A jnp.s_ object which indicates which axis of the PINN output is
|
|
186
|
+
dedicated to the actual equation solution. Default None
|
|
187
|
+
means that slice_solution = the whole PINN output. This argument is useful
|
|
188
|
+
when the PINN is also used to output equation parameters for example
|
|
189
|
+
Note that it must be a slice and not an integer (a preprocessing of the
|
|
190
|
+
user provided argument takes care of it)
|
|
155
191
|
|
|
156
192
|
|
|
157
193
|
Returns
|
|
@@ -182,6 +218,19 @@ def create_PINN(
|
|
|
182
218
|
if eq_type != "ODE" and dim_x == 0:
|
|
183
219
|
raise RuntimeError("Wrong parameter combination eq_type and dim_x")
|
|
184
220
|
|
|
221
|
+
try:
|
|
222
|
+
nb_outputs_declared = eqx_list[-1][2] # normally we look for 3rd ele of
|
|
223
|
+
# last layer
|
|
224
|
+
except IndexError:
|
|
225
|
+
nb_outputs_declared = eqx_list[-2][2]
|
|
226
|
+
|
|
227
|
+
if slice_solution is None:
|
|
228
|
+
slice_solution = jnp.s_[0:nb_outputs_declared]
|
|
229
|
+
if isinstance(slice_solution, int):
|
|
230
|
+
# rewrite it as a slice to ensure that axis does not disappear when
|
|
231
|
+
# indexing
|
|
232
|
+
slice_solution = jnp.s_[slice_solution : slice_solution + 1]
|
|
233
|
+
|
|
185
234
|
if input_transform is None:
|
|
186
235
|
|
|
187
236
|
def input_transform(_in, _params):
|
|
@@ -192,34 +241,19 @@ def create_PINN(
|
|
|
192
241
|
def output_transform(_in_pinn, _out_pinn):
|
|
193
242
|
return _out_pinn
|
|
194
243
|
|
|
195
|
-
if eq_type == "ODE":
|
|
196
|
-
|
|
197
|
-
def apply_fn(self, t, params):
|
|
198
|
-
t = t[
|
|
199
|
-
None
|
|
200
|
-
] # Note that we added a dimension to t which is lacking for the ODE batches
|
|
201
|
-
return self._eval_nn(t, params, input_transform, output_transform).squeeze()
|
|
202
|
-
|
|
203
|
-
elif eq_type == "statio_PDE":
|
|
204
|
-
# Here we add an argument `x` which can be high dimensional
|
|
205
|
-
def apply_fn(self, x, params):
|
|
206
|
-
return self._eval_nn(x, params, input_transform, output_transform)
|
|
207
|
-
|
|
208
|
-
elif eq_type == "nonstatio_PDE":
|
|
209
|
-
# Here we add an argument `x` which can be high dimensional
|
|
210
|
-
def apply_fn(self, t, x, params):
|
|
211
|
-
t_x = jnp.concatenate([t, x], axis=-1)
|
|
212
|
-
return self._eval_nn(t_x, params, input_transform, output_transform)
|
|
213
|
-
|
|
214
|
-
else:
|
|
215
|
-
raise RuntimeError("Wrong parameter value for eq_type")
|
|
216
|
-
|
|
217
244
|
if shared_pinn_outputs is not None:
|
|
218
245
|
pinns = []
|
|
219
246
|
static = None
|
|
220
247
|
for output_slice in shared_pinn_outputs:
|
|
221
|
-
pinn = PINN(
|
|
222
|
-
|
|
248
|
+
pinn = PINN(
|
|
249
|
+
key,
|
|
250
|
+
eqx_list,
|
|
251
|
+
slice_solution,
|
|
252
|
+
eq_type,
|
|
253
|
+
input_transform,
|
|
254
|
+
output_transform,
|
|
255
|
+
output_slice,
|
|
256
|
+
)
|
|
223
257
|
# all the pinns are in fact the same so we share the same static
|
|
224
258
|
if static is None:
|
|
225
259
|
static = pinn.static
|
|
@@ -227,6 +261,7 @@ def create_PINN(
|
|
|
227
261
|
pinn.static = static
|
|
228
262
|
pinns.append(pinn)
|
|
229
263
|
return pinns
|
|
230
|
-
pinn = PINN(
|
|
231
|
-
|
|
264
|
+
pinn = PINN(
|
|
265
|
+
key, eqx_list, slice_solution, eq_type, input_transform, output_transform
|
|
266
|
+
)
|
|
232
267
|
return pinn
|
jinns/utils/_spinn.py
CHANGED
|
@@ -82,16 +82,35 @@ class SPINN:
|
|
|
82
82
|
The function create_SPINN has the role to population the `__call__` function
|
|
83
83
|
"""
|
|
84
84
|
|
|
85
|
-
def __init__(self, key, d, r, eqx_list, m=1):
|
|
85
|
+
def __init__(self, key, d, r, eqx_list, eq_type, m=1):
|
|
86
86
|
self.d, self.r, self.m = d, r, m
|
|
87
87
|
_spinn = _SPINN(key, d, r, eqx_list, m)
|
|
88
88
|
self.params, self.static = eqx.partition(_spinn, eqx.is_inexact_array)
|
|
89
|
+
self.eq_type = eq_type
|
|
89
90
|
|
|
90
91
|
def init_params(self):
|
|
91
92
|
return self.params
|
|
92
93
|
|
|
93
|
-
def __call__(self, *args
|
|
94
|
-
|
|
94
|
+
def __call__(self, *args):
|
|
95
|
+
if self.eq_type == "statio_PDE":
|
|
96
|
+
(x, params) = args
|
|
97
|
+
try:
|
|
98
|
+
spinn = eqx.combine(params["nn_params"], self.static)
|
|
99
|
+
except (KeyError, TypeError) as e:
|
|
100
|
+
spinn = eqx.combine(params, self.static)
|
|
101
|
+
v_model = jax.vmap(spinn, (0))
|
|
102
|
+
res = v_model(t=None, x=x)
|
|
103
|
+
return self._eval_nn(res)
|
|
104
|
+
if self.eq_type == "nonstatio_PDE":
|
|
105
|
+
(t, x, params) = args
|
|
106
|
+
try:
|
|
107
|
+
spinn = eqx.combine(params["nn_params"], self.static)
|
|
108
|
+
except (KeyError, TypeError) as e:
|
|
109
|
+
spinn = eqx.combine(params, self.static)
|
|
110
|
+
v_model = jax.vmap(spinn, ((0, 0)))
|
|
111
|
+
res = v_model(t, x)
|
|
112
|
+
return self._eval_nn(res)
|
|
113
|
+
raise RuntimeError("Wrong parameter value for eq_type")
|
|
95
114
|
|
|
96
115
|
def _eval_nn(self, res):
|
|
97
116
|
"""
|
|
@@ -209,32 +228,6 @@ def create_SPINN(key, d, r, eqx_list, eq_type, m=1):
|
|
|
209
228
|
"Too many dimensions, not enough letters available in jnp.einsum"
|
|
210
229
|
)
|
|
211
230
|
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
def apply_fn(self, x, params):
|
|
215
|
-
try:
|
|
216
|
-
spinn = eqx.combine(params["nn_params"], self.static)
|
|
217
|
-
except: # give more flexibility
|
|
218
|
-
spinn = eqx.combine(params, self.static)
|
|
219
|
-
v_model = jax.vmap(spinn, (0))
|
|
220
|
-
res = v_model(t=None, x=x)
|
|
221
|
-
return self._eval_nn(res)
|
|
222
|
-
|
|
223
|
-
elif eq_type == "nonstatio_PDE":
|
|
224
|
-
|
|
225
|
-
def apply_fn(self, t, x, params):
|
|
226
|
-
try:
|
|
227
|
-
spinn = eqx.combine(params["nn_params"], self.static)
|
|
228
|
-
except: # give more flexibility
|
|
229
|
-
spinn = eqx.combine(params, self.static)
|
|
230
|
-
v_model = jax.vmap(spinn, ((0, 0)))
|
|
231
|
-
res = v_model(t, x)
|
|
232
|
-
return self._eval_nn(res)
|
|
233
|
-
|
|
234
|
-
else:
|
|
235
|
-
raise RuntimeError("Wrong parameter value for eq_type")
|
|
236
|
-
|
|
237
|
-
spinn = SPINN(key, d, r, eqx_list, m)
|
|
238
|
-
spinn.apply_fn = apply_fn
|
|
231
|
+
spinn = SPINN(key, d, r, eqx_list, eq_type, m)
|
|
239
232
|
|
|
240
233
|
return spinn
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: jinns
|
|
3
|
-
Version: 0.6.
|
|
3
|
+
Version: 0.6.1
|
|
4
4
|
Summary: Physics Informed Neural Network with JAX
|
|
5
5
|
Author-email: Hugo Gangloff <hugo.gangloff@inrae.fr>, Nicolas Jouvin <nicolas.jouvin@inrae.fr>
|
|
6
6
|
Maintainer-email: Hugo Gangloff <hugo.gangloff@inrae.fr>, Nicolas Jouvin <nicolas.jouvin@inrae.fr>
|
|
@@ -2,24 +2,24 @@ jinns/__init__.py,sha256=Nw5pdlmDhJwco3bXX3YttkeCF8czX_6m0poh8vu0lDQ,113
|
|
|
2
2
|
jinns/data/_DataGenerators.py,sha256=nIuKtkX4V4ckfT4-g0bjlY7BLkgcok5JbI9OzJn73mA,44461
|
|
3
3
|
jinns/data/__init__.py,sha256=S13J59Fxuph4uNJ542fP_Mj8U72ilhb5t_UQ-c1k3nY,232
|
|
4
4
|
jinns/data/_display.py,sha256=NfINLJAGmQSPz30cWVaeQFpabzCXprp4RNH6Iycx-VU,7722
|
|
5
|
-
jinns/loss/_DynamicLoss.py,sha256=
|
|
6
|
-
jinns/loss/_DynamicLossAbstract.py,sha256=
|
|
5
|
+
jinns/loss/_DynamicLoss.py,sha256=JOn6rpVcM825QHFWBQBOGCrR9tvStDaZiaB6pxZ0fhE,27336
|
|
6
|
+
jinns/loss/_DynamicLossAbstract.py,sha256=NnvS8WKJZZZiGMfidxpGlkmWtINn_UDj3BNXe9l9O84,7787
|
|
7
7
|
jinns/loss/_LossODE.py,sha256=nYWzUVO6vpZMYXZkeUpDRMZSowyGfGWUdwzBNciQkyo,20639
|
|
8
|
-
jinns/loss/_LossPDE.py,sha256=
|
|
8
|
+
jinns/loss/_LossPDE.py,sha256=gPi75x4MTjNSxY0Vz7swnsG1serPOtIPeKYsV5uYkr4,72221
|
|
9
9
|
jinns/loss/__init__.py,sha256=pFNYUxns-NPXBFdqrEVSiXkQLfCtKw-t2trlhvLzpYE,355
|
|
10
|
-
jinns/loss/_boundary_conditions.py,sha256=
|
|
11
|
-
jinns/loss/_operators.py,sha256=
|
|
10
|
+
jinns/loss/_boundary_conditions.py,sha256=K4cLYe5ecW2gJznv51nM3zOvE4UO-ZvVmVp3VIiuddM,16296
|
|
11
|
+
jinns/loss/_operators.py,sha256=zDGJqYqeYH7xd-4dtGX9PS-pf0uSOpUUXGo5SVjIJ4o,11069
|
|
12
12
|
jinns/solver/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
13
13
|
jinns/solver/_rar.py,sha256=ZoklcHq5MFLUsIiWMUH_426MRxq7BUZkD86jcQYau0I,13918
|
|
14
14
|
jinns/solver/_seq2seq.py,sha256=ihHnb6UpvShgEXOynUhsrPyt0wqNXIjeEL8HxvgRRHE,5985
|
|
15
15
|
jinns/solver/_solve.py,sha256=-Ji16NJ9CMBcFDKLFpBibvQwow9ysAB4pVGwRRqoXRE,10138
|
|
16
16
|
jinns/utils/__init__.py,sha256=bksGuq0mNoqciKIhEA8wOHkUppYQYJqOiHK2SEn5jac,227
|
|
17
17
|
jinns/utils/_optim.py,sha256=Q8a_dMqb6rjdv1qGibgvl7cv8DENfXsthDJ56RziWJc,4441
|
|
18
|
-
jinns/utils/_pinn.py,sha256=
|
|
19
|
-
jinns/utils/_spinn.py,sha256=
|
|
18
|
+
jinns/utils/_pinn.py,sha256=wEbCU4aDT8nd1sz72ZrgR8e-xlR-P0iczHk_hnnZQBo,9427
|
|
19
|
+
jinns/utils/_spinn.py,sha256=RZveR5R2HKMyQKuWQE_H6WKAfxABCkB6YGKyDZK051w,7889
|
|
20
20
|
jinns/utils/_utils.py,sha256=b0-R6jtyMescaDfod0DuzeLMpm1C7gTwpU1MZEW2468,6060
|
|
21
|
-
jinns-0.6.
|
|
22
|
-
jinns-0.6.
|
|
23
|
-
jinns-0.6.
|
|
24
|
-
jinns-0.6.
|
|
25
|
-
jinns-0.6.
|
|
21
|
+
jinns-0.6.1.dist-info/LICENSE,sha256=BIAkGtXB59Q_BG8f6_OqtQ1BHPv60ggE9mpXJYz2dRM,11337
|
|
22
|
+
jinns-0.6.1.dist-info/METADATA,sha256=yUHlhcVpsm8b1vHJSy5_TZjEUaaufI3niTvWZNXehks,2388
|
|
23
|
+
jinns-0.6.1.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
|
|
24
|
+
jinns-0.6.1.dist-info/top_level.txt,sha256=RXbkr2hzy8WBE8aiRyrJYFqn3JeMJIhMdybLjjLTB9c,6
|
|
25
|
+
jinns-0.6.1.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|