jinns 0.8.6__py3-none-any.whl → 0.8.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jinns/__init__.py +1 -0
- jinns/data/_display.py +102 -13
- jinns/experimental/__init__.py +2 -0
- jinns/experimental/_sinuspinn.py +135 -0
- jinns/experimental/_spectralpinn.py +87 -0
- jinns/loss/_LossODE.py +6 -0
- jinns/loss/_LossPDE.py +18 -18
- jinns/solver/_solve.py +264 -121
- jinns/utils/_containers.py +57 -0
- jinns/validation/__init__.py +1 -0
- jinns/validation/_validation.py +214 -0
- {jinns-0.8.6.dist-info → jinns-0.8.8.dist-info}/METADATA +1 -1
- {jinns-0.8.6.dist-info → jinns-0.8.8.dist-info}/RECORD +16 -11
- {jinns-0.8.6.dist-info → jinns-0.8.8.dist-info}/LICENSE +0 -0
- {jinns-0.8.6.dist-info → jinns-0.8.8.dist-info}/WHEEL +0 -0
- {jinns-0.8.6.dist-info → jinns-0.8.8.dist-info}/top_level.txt +0 -0
jinns/__init__.py
CHANGED
jinns/data/_display.py
CHANGED
|
@@ -1,8 +1,13 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Utility functions for plotting
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from functools import partial
|
|
6
|
+
import warnings
|
|
1
7
|
import matplotlib.pyplot as plt
|
|
2
8
|
import jax.numpy as jnp
|
|
3
9
|
from jax import vmap
|
|
4
10
|
from mpl_toolkits.axes_grid1 import ImageGrid
|
|
5
|
-
from functools import partial
|
|
6
11
|
|
|
7
12
|
|
|
8
13
|
def plot2d(
|
|
@@ -14,8 +19,10 @@ def plot2d(
|
|
|
14
19
|
figsize=(7, 7),
|
|
15
20
|
cmap="inferno",
|
|
16
21
|
spinn=False,
|
|
22
|
+
vmin_vmax=None,
|
|
23
|
+
ax_for_plot=None,
|
|
17
24
|
):
|
|
18
|
-
"""Generic function for plotting functions over rectangular 2-D domains
|
|
25
|
+
r"""Generic function for plotting functions over rectangular 2-D domains
|
|
19
26
|
:math:`\Omega`. It treats both the stationary case :math:`u(x)` or the
|
|
20
27
|
non-stationnary case :math:`u(t, x)`.
|
|
21
28
|
|
|
@@ -40,6 +47,12 @@ def plot2d(
|
|
|
40
47
|
_description_, by default (7, 7)
|
|
41
48
|
cmap : str, optional
|
|
42
49
|
_description_, by default "inferno"
|
|
50
|
+
vmin_vmax : tuple, optional
|
|
51
|
+
The colorbar minimum and maximum value. Defaults None.
|
|
52
|
+
ax_for_plot : Matplotlib axis, optional
|
|
53
|
+
If None, jinns triggers the plotting. Otherwise this argument
|
|
54
|
+
corresponds to the axis which will host the plot. Default is None.
|
|
55
|
+
NOTE: that this argument will have an effect only if times is None.
|
|
43
56
|
|
|
44
57
|
Raises
|
|
45
58
|
------
|
|
@@ -59,25 +72,49 @@ def plot2d(
|
|
|
59
72
|
# Statio case : expect a function of one argument fun(x)
|
|
60
73
|
if not spinn:
|
|
61
74
|
v_fun = vmap(fun, 0, 0)
|
|
62
|
-
_plot_2D_statio(
|
|
63
|
-
v_fun,
|
|
75
|
+
ret = _plot_2D_statio(
|
|
76
|
+
v_fun,
|
|
77
|
+
mesh,
|
|
78
|
+
plot=not ax_for_plot,
|
|
79
|
+
colorbar=True,
|
|
80
|
+
cmap=cmap,
|
|
81
|
+
figsize=figsize,
|
|
82
|
+
vmin_vmax=vmin_vmax,
|
|
64
83
|
)
|
|
65
84
|
elif spinn:
|
|
66
85
|
values_grid = jnp.squeeze(
|
|
67
86
|
fun(jnp.stack([xy_data[0][..., None], xy_data[1][..., None]], axis=1))
|
|
68
87
|
)
|
|
69
|
-
_plot_2D_statio(
|
|
88
|
+
ret = _plot_2D_statio(
|
|
70
89
|
values_grid,
|
|
71
90
|
mesh,
|
|
72
|
-
plot=
|
|
91
|
+
plot=not ax_for_plot,
|
|
73
92
|
colorbar=True,
|
|
74
93
|
cmap=cmap,
|
|
75
94
|
spinn=True,
|
|
76
95
|
figsize=figsize,
|
|
96
|
+
vmin_vmax=vmin_vmax,
|
|
77
97
|
)
|
|
78
|
-
|
|
98
|
+
if not ax_for_plot:
|
|
99
|
+
plt.title(title)
|
|
100
|
+
else:
|
|
101
|
+
if vmin_vmax is not None:
|
|
102
|
+
im = ax_for_plot.pcolormesh(
|
|
103
|
+
mesh[0],
|
|
104
|
+
mesh[1],
|
|
105
|
+
ret[0],
|
|
106
|
+
cmap=cmap,
|
|
107
|
+
vmin=vmin_vmax[0],
|
|
108
|
+
vmax=vmin_vmax[1],
|
|
109
|
+
)
|
|
110
|
+
else:
|
|
111
|
+
im = ax_for_plot.pcolormesh(mesh[0], mesh[1], ret[0], cmap=cmap)
|
|
112
|
+
ax_for_plot.set_title(title)
|
|
113
|
+
ax_for_plot.cax.colorbar(im, format="%0.2f")
|
|
79
114
|
|
|
80
115
|
else:
|
|
116
|
+
if ax_for_plot is not None:
|
|
117
|
+
warnings.warn("ax_for_plot is ignored. jinns will plot the figure")
|
|
81
118
|
if not isinstance(times, list):
|
|
82
119
|
try:
|
|
83
120
|
times = times.tolist()
|
|
@@ -101,7 +138,12 @@ def plot2d(
|
|
|
101
138
|
if not spinn:
|
|
102
139
|
v_fun_at_t = vmap(lambda x: fun(t=jnp.array([t]), x=x), 0, 0)
|
|
103
140
|
t_slice, _ = _plot_2D_statio(
|
|
104
|
-
v_fun_at_t,
|
|
141
|
+
v_fun_at_t,
|
|
142
|
+
mesh,
|
|
143
|
+
plot=False,
|
|
144
|
+
colorbar=False,
|
|
145
|
+
cmap=None,
|
|
146
|
+
vmin_vmax=vmin_vmax,
|
|
105
147
|
)
|
|
106
148
|
elif spinn:
|
|
107
149
|
values_grid = jnp.squeeze(
|
|
@@ -113,15 +155,37 @@ def plot2d(
|
|
|
113
155
|
)[0]
|
|
114
156
|
)
|
|
115
157
|
t_slice, _ = _plot_2D_statio(
|
|
116
|
-
values_grid,
|
|
158
|
+
values_grid,
|
|
159
|
+
mesh,
|
|
160
|
+
plot=False,
|
|
161
|
+
colorbar=True,
|
|
162
|
+
spinn=True,
|
|
163
|
+
vmin_vmax=vmin_vmax,
|
|
164
|
+
)
|
|
165
|
+
if vmin_vmax is not None:
|
|
166
|
+
im = ax.pcolormesh(
|
|
167
|
+
mesh[0],
|
|
168
|
+
mesh[1],
|
|
169
|
+
t_slice,
|
|
170
|
+
cmap=cmap,
|
|
171
|
+
vmin=vmin_vmax[0],
|
|
172
|
+
vmax=vmin_vmax[1],
|
|
117
173
|
)
|
|
118
|
-
|
|
174
|
+
else:
|
|
175
|
+
im = ax.pcolormesh(mesh[0], mesh[1], t_slice, cmap=cmap)
|
|
119
176
|
ax.set_title(f"t = {times[idx] * Tmax:.2f}")
|
|
120
177
|
ax.cax.colorbar(im, format="%0.2f")
|
|
121
178
|
|
|
122
179
|
|
|
123
180
|
def _plot_2D_statio(
|
|
124
|
-
v_fun,
|
|
181
|
+
v_fun,
|
|
182
|
+
mesh,
|
|
183
|
+
plot=True,
|
|
184
|
+
colorbar=True,
|
|
185
|
+
cmap="inferno",
|
|
186
|
+
figsize=(7, 7),
|
|
187
|
+
spinn=False,
|
|
188
|
+
vmin_vmax=None,
|
|
125
189
|
):
|
|
126
190
|
"""Function that plot the function u(x) with 2-D input x using pcolormesh()
|
|
127
191
|
|
|
@@ -136,6 +200,8 @@ def _plot_2D_statio(
|
|
|
136
200
|
either show or return the plot, by default True
|
|
137
201
|
colorbar : bool, optional
|
|
138
202
|
add a colorbar, by default True
|
|
203
|
+
vmin_vmax: tuple, optional
|
|
204
|
+
The colorbar minimum and maximum value. Defaults None.
|
|
139
205
|
|
|
140
206
|
Returns
|
|
141
207
|
-------
|
|
@@ -153,7 +219,17 @@ def _plot_2D_statio(
|
|
|
153
219
|
|
|
154
220
|
if plot:
|
|
155
221
|
fig = plt.figure(figsize=figsize)
|
|
156
|
-
|
|
222
|
+
if vmin_vmax is not None:
|
|
223
|
+
im = plt.pcolormesh(
|
|
224
|
+
x_grid,
|
|
225
|
+
y_grid,
|
|
226
|
+
values_grid,
|
|
227
|
+
cmap=cmap,
|
|
228
|
+
vmin=vmin_vmax[0],
|
|
229
|
+
vmax=vmin_vmax[1],
|
|
230
|
+
)
|
|
231
|
+
else:
|
|
232
|
+
im = plt.pcolormesh(x_grid, y_grid, values_grid, cmap=cmap)
|
|
157
233
|
if colorbar:
|
|
158
234
|
fig.colorbar(im, format="%0.2f")
|
|
159
235
|
# don't plt.show() because it is done in plot2d()
|
|
@@ -217,6 +293,7 @@ def plot1d_image(
|
|
|
217
293
|
colorbar=True,
|
|
218
294
|
cmap="inferno",
|
|
219
295
|
spinn=False,
|
|
296
|
+
vmin_vmax=None,
|
|
220
297
|
):
|
|
221
298
|
"""Function for plotting the 2-D image of a function :math:`f(t, x)` where
|
|
222
299
|
`t` is time (1-D) and x is space (1-D).
|
|
@@ -237,6 +314,8 @@ def plot1d_image(
|
|
|
237
314
|
, by default ""
|
|
238
315
|
figsize : tuple, optional
|
|
239
316
|
, by default (10, 10)
|
|
317
|
+
vmin_vmax: tuple
|
|
318
|
+
The colorbar minimum and maximum value. Defaults None.
|
|
240
319
|
"""
|
|
241
320
|
|
|
242
321
|
mesh = jnp.meshgrid(times, xdata) # cartesian product
|
|
@@ -250,7 +329,17 @@ def plot1d_image(
|
|
|
250
329
|
elif spinn:
|
|
251
330
|
values_grid = jnp.squeeze(fun((times[..., None]), xdata[..., None]).T)
|
|
252
331
|
fig, ax = plt.subplots(1, 1, figsize=figsize)
|
|
253
|
-
|
|
332
|
+
if vmin_vmax is not None:
|
|
333
|
+
im = ax.pcolormesh(
|
|
334
|
+
mesh[0] * Tmax,
|
|
335
|
+
mesh[1],
|
|
336
|
+
values_grid,
|
|
337
|
+
cmap=cmap,
|
|
338
|
+
vmin=vmin_vmax[0],
|
|
339
|
+
vmax=vmin_vmax[1],
|
|
340
|
+
)
|
|
341
|
+
else:
|
|
342
|
+
im = ax.pcolormesh(mesh[0] * Tmax, mesh[1], values_grid, cmap=cmap)
|
|
254
343
|
if colorbar:
|
|
255
344
|
fig.colorbar(im, format="%0.2f")
|
|
256
345
|
ax.set_title(title)
|
jinns/experimental/__init__.py
CHANGED
|
@@ -0,0 +1,135 @@
|
|
|
1
|
+
import jax
|
|
2
|
+
import equinox as eqx
|
|
3
|
+
import jax.numpy as jnp
|
|
4
|
+
from jinns.utils._pinn import PINN
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def almost_zero_init(weight: jax.Array, key: jax.random.PRNGKey) -> jax.Array:
|
|
8
|
+
out, in_ = weight.shape
|
|
9
|
+
stddev = 1e-2
|
|
10
|
+
return stddev * jax.random.normal(key, shape=(out, in_))
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class _SinusPINN(eqx.Module):
|
|
14
|
+
"""
|
|
15
|
+
A specific PINN whose layers are x_sin2x functions whose frequencies are
|
|
16
|
+
determined by an other network
|
|
17
|
+
"""
|
|
18
|
+
|
|
19
|
+
layers_pinn: list
|
|
20
|
+
layers_aux_nn: list
|
|
21
|
+
|
|
22
|
+
def __init__(self, key, list_layers_pinn, list_layers_aux_nn):
|
|
23
|
+
"""
|
|
24
|
+
Parameters
|
|
25
|
+
----------
|
|
26
|
+
key
|
|
27
|
+
A jax random key
|
|
28
|
+
list_layers_pinn
|
|
29
|
+
A list as eqx_list in jinns' PINN utility for the main PINN
|
|
30
|
+
list_layers_aux_nn
|
|
31
|
+
A list as eqx_list in jinns' PINN utility for the network which outputs
|
|
32
|
+
the PINN's activation frequencies
|
|
33
|
+
"""
|
|
34
|
+
self.layers_pinn = []
|
|
35
|
+
for l in list_layers_pinn:
|
|
36
|
+
if len(l) == 1:
|
|
37
|
+
self.layers_pinn.append(l[0])
|
|
38
|
+
else:
|
|
39
|
+
key, subkey = jax.random.split(key, 2)
|
|
40
|
+
self.layers_pinn.append(l[0](*l[1:], key=subkey))
|
|
41
|
+
self.layers_aux_nn = []
|
|
42
|
+
for idx, l in enumerate(list_layers_aux_nn):
|
|
43
|
+
if len(l) == 1:
|
|
44
|
+
self.layers_aux_nn.append(l[0])
|
|
45
|
+
else:
|
|
46
|
+
key, subkey = jax.random.split(key, 2)
|
|
47
|
+
linear_layer = l[0](*l[1:], key=subkey)
|
|
48
|
+
key, subkey = jax.random.split(key, 2)
|
|
49
|
+
linear_layer = eqx.tree_at(
|
|
50
|
+
lambda l: l.weight,
|
|
51
|
+
linear_layer,
|
|
52
|
+
almost_zero_init(linear_layer.weight, subkey),
|
|
53
|
+
)
|
|
54
|
+
if (idx == len(list_layers_aux_nn) - 1) or (
|
|
55
|
+
idx == len(list_layers_aux_nn) - 2
|
|
56
|
+
):
|
|
57
|
+
# for the last layer: almost 0 weights and 0.5 bias
|
|
58
|
+
linear_layer = eqx.tree_at(
|
|
59
|
+
lambda l: l.bias,
|
|
60
|
+
linear_layer,
|
|
61
|
+
0.5 * jnp.ones(linear_layer.bias.shape),
|
|
62
|
+
)
|
|
63
|
+
else:
|
|
64
|
+
# for the other previous layers:
|
|
65
|
+
# almost 0 weight and 0 bias
|
|
66
|
+
linear_layer = eqx.tree_at(
|
|
67
|
+
lambda l: l.bias,
|
|
68
|
+
linear_layer,
|
|
69
|
+
jnp.zeros(linear_layer.bias.shape),
|
|
70
|
+
)
|
|
71
|
+
self.layers_aux_nn.append(linear_layer)
|
|
72
|
+
|
|
73
|
+
## init to zero the frequency network except last biases
|
|
74
|
+
# key, subkey = jax.random.split(key, 2)
|
|
75
|
+
# _pinn = init_linear_weight(_pinn, almost_zero_init, subkey)
|
|
76
|
+
# key, subkey = jax.random.split(key, 2)
|
|
77
|
+
# _pinn = init_linear_bias(_pinn, zero_init, subkey)
|
|
78
|
+
# print(_pinn)
|
|
79
|
+
# print(jax.tree_util.tree_leaves(_pinn, is_leaf=lambda
|
|
80
|
+
# p:not isinstance(p,eqx.nn.Linear))[0].layers_aux_nn[-1].bias)
|
|
81
|
+
# _pinn = eqx.tree_at(lambda p:_pinn.layers_aux_nn[-1].bias, 0.5 *
|
|
82
|
+
# jnp.ones(_pinn.layers_aux_nn[-1].bias.shape))
|
|
83
|
+
# #, is_leaf=lambda
|
|
84
|
+
# #p:not isinstance(p, eqx.nn.Linear))
|
|
85
|
+
|
|
86
|
+
def __call__(self, x):
|
|
87
|
+
x_ = x.copy()
|
|
88
|
+
# forward pass in the network which determines the freq
|
|
89
|
+
for layer in self.layers_aux_nn:
|
|
90
|
+
x_ = layer(x_)
|
|
91
|
+
freq_list = jnp.clip(jnp.square(x_), a_min=1e-4, a_max=5)
|
|
92
|
+
x_ = x.copy()
|
|
93
|
+
# forward pass through the actual PINN
|
|
94
|
+
for idx, layer in enumerate(self.layers_pinn):
|
|
95
|
+
if idx % 2 == 0:
|
|
96
|
+
# Currently: every two layer we have an activation
|
|
97
|
+
# requiring a frequency
|
|
98
|
+
x_ = layer(x_)
|
|
99
|
+
else:
|
|
100
|
+
x_ = layer(x_, freq_list[(idx - 1) // 2])
|
|
101
|
+
return x_
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
class sinusPINN(PINN):
|
|
105
|
+
"""
|
|
106
|
+
MUST inherit from PINN to pass all the checks
|
|
107
|
+
|
|
108
|
+
HOWEVER we dot not bother with reimplementing anything
|
|
109
|
+
"""
|
|
110
|
+
|
|
111
|
+
def __init__(self, key, list_layers_pinn, list_layers_aux_nn):
|
|
112
|
+
super().__init__({}, jnp.s_[...], "statio_PDE", None, None, None)
|
|
113
|
+
key, subkey = jax.random.split(key, 2)
|
|
114
|
+
_pinn = _SinusPINN(subkey, list_layers_pinn, list_layers_aux_nn)
|
|
115
|
+
|
|
116
|
+
self.params, self.static = eqx.partition(_pinn, eqx.is_inexact_array)
|
|
117
|
+
|
|
118
|
+
def init_params(self):
|
|
119
|
+
return self.params
|
|
120
|
+
|
|
121
|
+
def __call__(self, x, params):
|
|
122
|
+
try:
|
|
123
|
+
model = eqx.combine(params["nn_params"], self.static)
|
|
124
|
+
except (KeyError, TypeError) as e: # give more flexibility
|
|
125
|
+
model = eqx.combine(params, self.static)
|
|
126
|
+
res = model(x)
|
|
127
|
+
if not res.shape:
|
|
128
|
+
return jnp.expand_dims(res, axis=-1)
|
|
129
|
+
return res
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
def create_sinusPINN(key, list_layers_pinn, list_layers_aux_nn):
|
|
133
|
+
""" """
|
|
134
|
+
u = sinusPINN(key, list_layers_pinn, list_layers_aux_nn)
|
|
135
|
+
return u
|
|
@@ -0,0 +1,87 @@
|
|
|
1
|
+
import jax
|
|
2
|
+
import equinox as eqx
|
|
3
|
+
import jax.numpy as jnp
|
|
4
|
+
from jinns.utils._pinn import PINN
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def almost_zero_init(weight: jax.Array, key: jax.random.PRNGKey) -> jax.Array:
|
|
8
|
+
out, in_ = weight.shape
|
|
9
|
+
stddev = 1e-2
|
|
10
|
+
return stddev * jax.random.normal(key, shape=(out, in_))
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class _SpectralPINN(eqx.Module):
|
|
14
|
+
"""
|
|
15
|
+
A specific PINN whose acrhitecture is similar to spectral method for simulation of a spatial field
|
|
16
|
+
(Chilès and Delfiner, 2012) - a single layer with cos() activation function and sum for last layer
|
|
17
|
+
"""
|
|
18
|
+
|
|
19
|
+
layers_pinn: list
|
|
20
|
+
nbands: int
|
|
21
|
+
|
|
22
|
+
def __init__(self, key, list_layers_pinn, nbands):
|
|
23
|
+
"""
|
|
24
|
+
Parameters
|
|
25
|
+
----------
|
|
26
|
+
key
|
|
27
|
+
A jax random key
|
|
28
|
+
list_layers_pinn
|
|
29
|
+
A list as eqx_list in jinns' PINN utility for the main PINN
|
|
30
|
+
nbands
|
|
31
|
+
Number of spectral bands (i.e., neurones in the single layer of the PINN)
|
|
32
|
+
"""
|
|
33
|
+
self.nbands = nbands
|
|
34
|
+
self.layers_pinn = []
|
|
35
|
+
for l in list_layers_pinn:
|
|
36
|
+
if len(l) == 1:
|
|
37
|
+
self.layers_pinn.append(l[0])
|
|
38
|
+
else:
|
|
39
|
+
key, subkey = jax.random.split(key, 2)
|
|
40
|
+
self.layers_pinn.append(l[0](*l[1:], key=subkey))
|
|
41
|
+
|
|
42
|
+
def __call__(self, x):
|
|
43
|
+
# forward pass through the actual PINN
|
|
44
|
+
for layer in self.layers_pinn:
|
|
45
|
+
x = layer(x)
|
|
46
|
+
|
|
47
|
+
return jnp.sqrt(2 / self.nbands) * jnp.sum(x)
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
class spectralPINN(PINN):
|
|
51
|
+
"""
|
|
52
|
+
MUST inherit from PINN to pass all the checks
|
|
53
|
+
|
|
54
|
+
HOWEVER we dot not bother with reimplementing anything
|
|
55
|
+
"""
|
|
56
|
+
|
|
57
|
+
def __init__(self, key, list_layers_pinn, nbands):
|
|
58
|
+
super().__init__({}, jnp.s_[...], "statio_PDE", None, None, None)
|
|
59
|
+
key, subkey = jax.random.split(key, 2)
|
|
60
|
+
_pinn = _SpectralPINN(subkey, list_layers_pinn, nbands)
|
|
61
|
+
|
|
62
|
+
self.params, self.static = eqx.partition(_pinn, eqx.is_inexact_array)
|
|
63
|
+
|
|
64
|
+
def init_params(self):
|
|
65
|
+
return self.params
|
|
66
|
+
|
|
67
|
+
def __call__(self, x, params):
|
|
68
|
+
try:
|
|
69
|
+
model = eqx.combine(params["nn_params"], self.static)
|
|
70
|
+
except (KeyError, TypeError) as e: # give more flexibility
|
|
71
|
+
model = eqx.combine(params, self.static)
|
|
72
|
+
# model = eqx.tree_at(lambda m:
|
|
73
|
+
# m.layers_pinn[0].bias,
|
|
74
|
+
# model,
|
|
75
|
+
# model.layers_pinn[0].bias % (2 *
|
|
76
|
+
# jnp.pi)
|
|
77
|
+
# )
|
|
78
|
+
res = model(x)
|
|
79
|
+
if not res.shape:
|
|
80
|
+
return jnp.expand_dims(res, axis=-1)
|
|
81
|
+
return res
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
def create_spectralPINN(key, list_layers_pinn, nbands):
|
|
85
|
+
""" """
|
|
86
|
+
u = spectralPINN(key, list_layers_pinn, nbands)
|
|
87
|
+
return u
|
jinns/loss/_LossODE.py
CHANGED
|
@@ -19,6 +19,8 @@ from jinns.loss._Losses import (
|
|
|
19
19
|
)
|
|
20
20
|
from jinns.utils._pinn import PINN
|
|
21
21
|
|
|
22
|
+
_LOSS_WEIGHT_KEYS_ODE = ["observations", "dyn_loss", "initial_condition"]
|
|
23
|
+
|
|
22
24
|
|
|
23
25
|
@register_pytree_node_class
|
|
24
26
|
class LossODE:
|
|
@@ -128,6 +130,10 @@ class LossODE:
|
|
|
128
130
|
if self.obs_slice is None:
|
|
129
131
|
self.obs_slice = jnp.s_[...]
|
|
130
132
|
|
|
133
|
+
for k in _LOSS_WEIGHT_KEYS_ODE:
|
|
134
|
+
if k not in self.loss_weights.keys():
|
|
135
|
+
self.loss_weights[k] = 0
|
|
136
|
+
|
|
131
137
|
def __call__(self, *args, **kwargs):
|
|
132
138
|
return self.evaluate(*args, **kwargs)
|
|
133
139
|
|
jinns/loss/_LossPDE.py
CHANGED
|
@@ -31,6 +31,16 @@ _IMPLEMENTED_BOUNDARY_CONDITIONS = [
|
|
|
31
31
|
"vonneumann",
|
|
32
32
|
]
|
|
33
33
|
|
|
34
|
+
_LOSS_WEIGHT_KEYS_PDESTATIO = [
|
|
35
|
+
"sobolev",
|
|
36
|
+
"observations",
|
|
37
|
+
"norm_loss",
|
|
38
|
+
"boundary_loss",
|
|
39
|
+
"dyn_loss",
|
|
40
|
+
]
|
|
41
|
+
|
|
42
|
+
_LOSS_WEIGHT_KEYS_PDENONSTATIO = _LOSS_WEIGHT_KEYS_PDESTATIO + ["initial_condition"]
|
|
43
|
+
|
|
34
44
|
|
|
35
45
|
@register_pytree_node_class
|
|
36
46
|
class LossPDEAbstract:
|
|
@@ -269,8 +279,8 @@ class LossPDEStatio(LossPDEAbstract):
|
|
|
269
279
|
the PINN object
|
|
270
280
|
loss_weights
|
|
271
281
|
a dictionary with values used to ponderate each term in the loss
|
|
272
|
-
function. Valid keys are `dyn_loss`, `norm_loss`, `boundary_loss
|
|
273
|
-
and `
|
|
282
|
+
function. Valid keys are `dyn_loss`, `norm_loss`, `boundary_loss`,
|
|
283
|
+
`observations` and `sobolev`.
|
|
274
284
|
Note that we can have jnp.arrays with the same dimension of
|
|
275
285
|
`u` which then ponderates each output of `u`
|
|
276
286
|
dynamic_loss
|
|
@@ -441,18 +451,12 @@ class LossPDEStatio(LossPDEAbstract):
|
|
|
441
451
|
) # we return a function, that way
|
|
442
452
|
# the order of sobolev_m is static and the conditional in the recursive
|
|
443
453
|
# function is properly set
|
|
444
|
-
self.sobolev_m = self.sobolev_m
|
|
445
454
|
else:
|
|
446
455
|
self.sobolev_reg = None
|
|
447
456
|
|
|
448
|
-
|
|
449
|
-
self.loss_weights
|
|
450
|
-
|
|
451
|
-
if self.omega_boundary_fun is None:
|
|
452
|
-
self.loss_weights["boundary_loss"] = 0
|
|
453
|
-
|
|
454
|
-
if self.sobolev_reg is None:
|
|
455
|
-
self.loss_weights["sobolev"] = 0
|
|
457
|
+
for k in _LOSS_WEIGHT_KEYS_PDESTATIO:
|
|
458
|
+
if k not in self.loss_weights.keys():
|
|
459
|
+
self.loss_weights[k] = 0
|
|
456
460
|
|
|
457
461
|
if (
|
|
458
462
|
isinstance(self.omega_boundary_fun, dict)
|
|
@@ -533,7 +537,6 @@ class LossPDEStatio(LossPDEAbstract):
|
|
|
533
537
|
)
|
|
534
538
|
else:
|
|
535
539
|
mse_norm_loss = jnp.array(0.0)
|
|
536
|
-
self.loss_weights["norm_loss"] = 0
|
|
537
540
|
|
|
538
541
|
# boundary part
|
|
539
542
|
params_ = _set_derivatives(params, "boundary_loss", self.derivative_keys)
|
|
@@ -567,7 +570,6 @@ class LossPDEStatio(LossPDEAbstract):
|
|
|
567
570
|
)
|
|
568
571
|
else:
|
|
569
572
|
mse_observation_loss = jnp.array(0.0)
|
|
570
|
-
self.loss_weights["observations"] = 0
|
|
571
573
|
|
|
572
574
|
# Sobolev regularization
|
|
573
575
|
params_ = _set_derivatives(params, "sobolev", self.derivative_keys)
|
|
@@ -582,7 +584,6 @@ class LossPDEStatio(LossPDEAbstract):
|
|
|
582
584
|
)
|
|
583
585
|
else:
|
|
584
586
|
mse_sobolev_loss = jnp.array(0.0)
|
|
585
|
-
self.loss_weights["sobolev"] = 0
|
|
586
587
|
|
|
587
588
|
# total loss
|
|
588
589
|
total_loss = (
|
|
@@ -785,8 +786,9 @@ class LossPDENonStatio(LossPDEStatio):
|
|
|
785
786
|
else:
|
|
786
787
|
self.sobolev_reg = None
|
|
787
788
|
|
|
788
|
-
|
|
789
|
-
self.loss_weights
|
|
789
|
+
for k in _LOSS_WEIGHT_KEYS_PDENONSTATIO:
|
|
790
|
+
if k not in self.loss_weights.keys():
|
|
791
|
+
self.loss_weights[k] = 0
|
|
790
792
|
|
|
791
793
|
def __call__(self, *args, **kwargs):
|
|
792
794
|
return self.evaluate(*args, **kwargs)
|
|
@@ -924,7 +926,6 @@ class LossPDENonStatio(LossPDEStatio):
|
|
|
924
926
|
)
|
|
925
927
|
else:
|
|
926
928
|
mse_observation_loss = jnp.array(0.0)
|
|
927
|
-
self.loss_weights["observations"] = 0
|
|
928
929
|
|
|
929
930
|
# Sobolev regularization
|
|
930
931
|
params_ = _set_derivatives(params, "sobolev", self.derivative_keys)
|
|
@@ -939,7 +940,6 @@ class LossPDENonStatio(LossPDEStatio):
|
|
|
939
940
|
)
|
|
940
941
|
else:
|
|
941
942
|
mse_sobolev_loss = jnp.array(0.0)
|
|
942
|
-
self.loss_weights["sobolev"] = 0.0
|
|
943
943
|
|
|
944
944
|
# total loss
|
|
945
945
|
total_loss = (
|