jinns 0.8.7__py3-none-any.whl → 0.8.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jinns/__init__.py +1 -0
- jinns/data/_DataGenerators.py +93 -90
- jinns/data/_display.py +102 -13
- jinns/experimental/__init__.py +2 -0
- jinns/experimental/_sinuspinn.py +135 -0
- jinns/experimental/_spectralpinn.py +87 -0
- jinns/solver/_rar.py +203 -146
- jinns/solver/_seq2seq.py +2 -2
- {jinns-0.8.7.dist-info → jinns-0.8.9.dist-info}/METADATA +1 -1
- {jinns-0.8.7.dist-info → jinns-0.8.9.dist-info}/RECORD +13 -11
- {jinns-0.8.7.dist-info → jinns-0.8.9.dist-info}/LICENSE +0 -0
- {jinns-0.8.7.dist-info → jinns-0.8.9.dist-info}/WHEEL +0 -0
- {jinns-0.8.7.dist-info → jinns-0.8.9.dist-info}/top_level.txt +0 -0
jinns/solver/_rar.py
CHANGED
|
@@ -11,6 +11,41 @@ from jinns.loss._LossODE import LossODE, SystemLossODE
|
|
|
11
11
|
from jinns.loss._DynamicLossAbstract import PDEStatio
|
|
12
12
|
|
|
13
13
|
from functools import partial
|
|
14
|
+
from jinns.utils._hyperpinn import HYPERPINN
|
|
15
|
+
from jinns.utils._spinn import SPINN
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def _proceed_to_rar(data, i):
|
|
19
|
+
"""Utility function with various checks to ensure we can proceed with the rar_step.
|
|
20
|
+
Return True if yes, and False otherwise"""
|
|
21
|
+
|
|
22
|
+
# Overall checks (universal for any data generator)
|
|
23
|
+
check_list = [
|
|
24
|
+
# check if burn-in period has ended
|
|
25
|
+
data.rar_parameters["start_iter"] <= i,
|
|
26
|
+
# check if enough iterations since last points added
|
|
27
|
+
(data.rar_parameters["update_every"] - 1) == data.rar_iter_from_last_sampling,
|
|
28
|
+
]
|
|
29
|
+
|
|
30
|
+
# Memory allocation checks (depends on the type of DataGenerator)
|
|
31
|
+
# check if we still have room to append new collocation points in the
|
|
32
|
+
# allocated jnp.array (can concern `data.p_times` or `p_omega`)
|
|
33
|
+
if isinstance(data, DataGeneratorODE) or isinstance(data, CubicMeshPDENonStatio):
|
|
34
|
+
check_list.append(
|
|
35
|
+
data.rar_parameters["selected_sample_size_times"]
|
|
36
|
+
<= jnp.count_nonzero(data.p_times == 0),
|
|
37
|
+
)
|
|
38
|
+
|
|
39
|
+
if isinstance(data, CubicMeshPDEStatio) or isinstance(data, CubicMeshPDENonStatio):
|
|
40
|
+
# for now the above checks are redundant but there may be a time when
|
|
41
|
+
# we drop inheritance
|
|
42
|
+
check_list.append(
|
|
43
|
+
data.rar_parameters["selected_sample_size_omega"]
|
|
44
|
+
<= jnp.count_nonzero(data.p_omega == 0),
|
|
45
|
+
)
|
|
46
|
+
|
|
47
|
+
proceed = jnp.all(jnp.array(check_list))
|
|
48
|
+
return proceed
|
|
14
49
|
|
|
15
50
|
|
|
16
51
|
@partial(jax.jit, static_argnames=["_rar_step_true", "_rar_step_false"])
|
|
@@ -22,21 +57,7 @@ def trigger_rar(i, loss, params, data, _rar_step_true, _rar_step_false):
|
|
|
22
57
|
else:
|
|
23
58
|
# update `data` according to rar scheme.
|
|
24
59
|
data = jax.lax.cond(
|
|
25
|
-
|
|
26
|
-
jnp.array(
|
|
27
|
-
[
|
|
28
|
-
# check if enough it since last points added
|
|
29
|
-
data.rar_parameters["update_rate"]
|
|
30
|
-
== data.rar_iter_from_last_sampling,
|
|
31
|
-
# check if burn in period has ended
|
|
32
|
-
data.rar_parameters["start_iter"] < i,
|
|
33
|
-
# check if we still have room to append new
|
|
34
|
-
# collocation points in the allocated jnp array
|
|
35
|
-
data.rar_parameters["selected_sample_size"]
|
|
36
|
-
<= jnp.count_nonzero(data.p == 0),
|
|
37
|
-
]
|
|
38
|
-
)
|
|
39
|
-
),
|
|
60
|
+
_proceed_to_rar(data, i),
|
|
40
61
|
_rar_step_true,
|
|
41
62
|
_rar_step_false,
|
|
42
63
|
(loss, params, data, i),
|
|
@@ -49,13 +70,37 @@ def init_rar(data):
|
|
|
49
70
|
Separated from the main rar, because the initialization to get _true and
|
|
50
71
|
_false cannot be jit-ted.
|
|
51
72
|
"""
|
|
73
|
+
# NOTE if a user misspells some entry of ``rar_parameters`` the error
|
|
74
|
+
# may be a bit obscure, but it should be ok.
|
|
52
75
|
if data.rar_parameters is None:
|
|
53
76
|
_rar_step_true, _rar_step_false = None, None
|
|
54
77
|
else:
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
78
|
+
if isinstance(data, DataGeneratorODE):
|
|
79
|
+
# In this case we only need rar parameters related to `times`
|
|
80
|
+
_rar_step_true, _rar_step_false = _rar_step_init(
|
|
81
|
+
data.rar_parameters["sample_size_times"],
|
|
82
|
+
data.rar_parameters["selected_sample_size_times"],
|
|
83
|
+
)
|
|
84
|
+
elif isinstance(data, CubicMeshPDENonStatio):
|
|
85
|
+
# In this case we need rar parameters related to both `times`
|
|
86
|
+
# and`omega`
|
|
87
|
+
_rar_step_true, _rar_step_false = _rar_step_init(
|
|
88
|
+
(
|
|
89
|
+
data.rar_parameters["sample_size_times"],
|
|
90
|
+
data.rar_parameters["sample_size_omega"],
|
|
91
|
+
),
|
|
92
|
+
(
|
|
93
|
+
data.rar_parameters["selected_sample_size_times"],
|
|
94
|
+
data.rar_parameters["selected_sample_size_omega"],
|
|
95
|
+
),
|
|
96
|
+
)
|
|
97
|
+
elif isinstance(data, CubicMeshPDEStatio):
|
|
98
|
+
# In this case we only need rar parameters related to `omega`
|
|
99
|
+
_rar_step_true, _rar_step_false = _rar_step_init(
|
|
100
|
+
data.rar_parameters["sample_size_omega"],
|
|
101
|
+
data.rar_parameters["selected_sample_size_omega"],
|
|
102
|
+
)
|
|
103
|
+
|
|
59
104
|
data.rar_parameters["iter_from_last_sampling"] = 0
|
|
60
105
|
|
|
61
106
|
return data, _rar_step_true, _rar_step_false
|
|
@@ -64,7 +109,7 @@ def init_rar(data):
|
|
|
64
109
|
def _rar_step_init(sample_size, selected_sample_size):
|
|
65
110
|
"""
|
|
66
111
|
This is a wrapper because the sampling size and
|
|
67
|
-
selected_sample_size, must be treated static
|
|
112
|
+
selected_sample_size, must be treated as static
|
|
68
113
|
in order to slice. So they must be set before jitting and not with the jitted
|
|
69
114
|
dictionary values rar["test_points_nb"] and rar["added_points_nb"]
|
|
70
115
|
|
|
@@ -72,16 +117,10 @@ def _rar_step_init(sample_size, selected_sample_size):
|
|
|
72
117
|
"""
|
|
73
118
|
|
|
74
119
|
def rar_step_true(operands):
|
|
75
|
-
"""
|
|
76
|
-
Note: in all generality, we would need a stop gradient operator around
|
|
77
|
-
these dynamic_loss evaluations that follow which produce weights for
|
|
78
|
-
sampling. However, they appear through a argsort and sampling
|
|
79
|
-
operations which definitly kills gradient flows
|
|
80
|
-
"""
|
|
81
120
|
loss, params, data, i = operands
|
|
82
121
|
|
|
83
122
|
if isinstance(data, DataGeneratorODE):
|
|
84
|
-
|
|
123
|
+
new_omega_samples = data.sample_in_time_domain(sample_size)
|
|
85
124
|
|
|
86
125
|
# We can have different types of Loss
|
|
87
126
|
if isinstance(loss, LossODE):
|
|
@@ -90,7 +129,7 @@ def _rar_step_init(sample_size, selected_sample_size):
|
|
|
90
129
|
(0),
|
|
91
130
|
0,
|
|
92
131
|
)
|
|
93
|
-
dyn_on_s = v_dyn_loss(
|
|
132
|
+
dyn_on_s = v_dyn_loss(new_omega_samples)
|
|
94
133
|
if dyn_on_s.ndim > 1:
|
|
95
134
|
mse_on_s = (jnp.linalg.norm(dyn_on_s, axis=-1) ** 2).flatten()
|
|
96
135
|
else:
|
|
@@ -106,7 +145,7 @@ def _rar_step_init(sample_size, selected_sample_size):
|
|
|
106
145
|
(0),
|
|
107
146
|
0,
|
|
108
147
|
)
|
|
109
|
-
dyn_on_s = v_dyn_loss(
|
|
148
|
+
dyn_on_s = v_dyn_loss(new_omega_samples)
|
|
110
149
|
if dyn_on_s.ndim > 1:
|
|
111
150
|
mse_on_s += (jnp.linalg.norm(dyn_on_s, axis=-1) ** 2).flatten()
|
|
112
151
|
else:
|
|
@@ -118,9 +157,7 @@ def _rar_step_init(sample_size, selected_sample_size):
|
|
|
118
157
|
(mse_on_s.shape[0] - selected_sample_size,),
|
|
119
158
|
(selected_sample_size,),
|
|
120
159
|
)
|
|
121
|
-
higher_residual_points =
|
|
122
|
-
|
|
123
|
-
data.rar_iter_from_last_sampling = 0
|
|
160
|
+
higher_residual_points = new_omega_samples[higher_residual_idx]
|
|
124
161
|
|
|
125
162
|
## add the new points in times
|
|
126
163
|
# start indices of update can be dynamic but not the shape (length)
|
|
@@ -135,7 +172,7 @@ def _rar_step_init(sample_size, selected_sample_size):
|
|
|
135
172
|
## points are non-zero
|
|
136
173
|
new_proba = 1 / (data.nt_start + data.rar_iter_nb * selected_sample_size)
|
|
137
174
|
# the next line works because nt_start is static
|
|
138
|
-
data.
|
|
175
|
+
data.p_times = data.p_times.at[: data.nt_start].set(new_proba)
|
|
139
176
|
|
|
140
177
|
# the next requires a fori_loop because the range is dynamic
|
|
141
178
|
def update_slices(i, p):
|
|
@@ -147,16 +184,14 @@ def _rar_step_init(sample_size, selected_sample_size):
|
|
|
147
184
|
|
|
148
185
|
data.rar_iter_nb += 1
|
|
149
186
|
|
|
150
|
-
data.
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
# have side effects in this function that will be jitted
|
|
154
|
-
return data
|
|
187
|
+
data.p_times = jax.lax.fori_loop(
|
|
188
|
+
0, data.rar_iter_nb, update_slices, data.p_times
|
|
189
|
+
)
|
|
155
190
|
|
|
156
191
|
elif isinstance(data, CubicMeshPDEStatio) and not isinstance(
|
|
157
192
|
data, CubicMeshPDENonStatio
|
|
158
193
|
):
|
|
159
|
-
|
|
194
|
+
new_omega_samples = data.sample_in_omega_domain(sample_size)
|
|
160
195
|
|
|
161
196
|
# We can have different types of Loss
|
|
162
197
|
if isinstance(loss, LossPDEStatio):
|
|
@@ -169,7 +204,7 @@ def _rar_step_init(sample_size, selected_sample_size):
|
|
|
169
204
|
(0),
|
|
170
205
|
0,
|
|
171
206
|
)
|
|
172
|
-
dyn_on_s = v_dyn_loss(
|
|
207
|
+
dyn_on_s = v_dyn_loss(new_omega_samples)
|
|
173
208
|
if dyn_on_s.ndim > 1:
|
|
174
209
|
mse_on_s = (jnp.linalg.norm(dyn_on_s, axis=-1) ** 2).flatten()
|
|
175
210
|
else:
|
|
@@ -185,7 +220,7 @@ def _rar_step_init(sample_size, selected_sample_size):
|
|
|
185
220
|
0,
|
|
186
221
|
0,
|
|
187
222
|
)
|
|
188
|
-
dyn_on_s = v_dyn_loss(
|
|
223
|
+
dyn_on_s = v_dyn_loss(new_omega_samples)
|
|
189
224
|
if dyn_on_s.ndim > 1:
|
|
190
225
|
mse_on_s += (jnp.linalg.norm(dyn_on_s, axis=-1) ** 2).flatten()
|
|
191
226
|
else:
|
|
@@ -197,12 +232,10 @@ def _rar_step_init(sample_size, selected_sample_size):
|
|
|
197
232
|
(mse_on_s.shape[0] - selected_sample_size,),
|
|
198
233
|
(selected_sample_size,),
|
|
199
234
|
)
|
|
200
|
-
higher_residual_points =
|
|
235
|
+
higher_residual_points = new_omega_samples[higher_residual_idx]
|
|
201
236
|
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
## add the new points in times
|
|
205
|
-
# start indices of update can be dynamic but the the shape (length)
|
|
237
|
+
## add the new points in omega
|
|
238
|
+
# start indices of update can be dynamic but not the shape (length)
|
|
206
239
|
# of the slice
|
|
207
240
|
data.omega = jax.lax.dynamic_update_slice(
|
|
208
241
|
data.omega,
|
|
@@ -214,7 +247,7 @@ def _rar_step_init(sample_size, selected_sample_size):
|
|
|
214
247
|
## points are non-zero
|
|
215
248
|
new_proba = 1 / (data.n_start + data.rar_iter_nb * selected_sample_size)
|
|
216
249
|
# the next line works because n_start is static
|
|
217
|
-
data.
|
|
250
|
+
data.p_omega = data.p_omega.at[: data.n_start].set(new_proba)
|
|
218
251
|
|
|
219
252
|
# the next requires a fori_loop because the range is dynamic
|
|
220
253
|
def update_slices(i, p):
|
|
@@ -226,145 +259,169 @@ def _rar_step_init(sample_size, selected_sample_size):
|
|
|
226
259
|
|
|
227
260
|
data.rar_iter_nb += 1
|
|
228
261
|
|
|
229
|
-
data.
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
# have side effects in this function that will be jitted
|
|
233
|
-
return data
|
|
262
|
+
data.p_omega = jax.lax.fori_loop(
|
|
263
|
+
0, data.rar_iter_nb, update_slices, data.p_omega
|
|
264
|
+
)
|
|
234
265
|
|
|
235
266
|
elif isinstance(data, CubicMeshPDENonStatio):
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
)
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
(
|
|
252
|
-
|
|
253
|
-
)
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
mse_on_s = dyn_on_s**2
|
|
259
|
-
elif isinstance(loss, LossPDENonStatio):
|
|
267
|
+
# NOTE in this case sample_size and selected_sample_size
|
|
268
|
+
# are tuples (times, omega) => we unpack them for clarity
|
|
269
|
+
selected_sample_size_times, selected_sample_size_omega = (
|
|
270
|
+
selected_sample_size
|
|
271
|
+
)
|
|
272
|
+
sample_size_times, sample_size_omega = sample_size
|
|
273
|
+
|
|
274
|
+
new_times_samples = data.sample_in_time_domain(sample_size_times)
|
|
275
|
+
new_omega_samples = data.sample_in_omega_domain(sample_size_omega)
|
|
276
|
+
|
|
277
|
+
if isinstance(loss.u, HYPERPINN) or isinstance(loss.u, SPINN):
|
|
278
|
+
raise NotImplementedError("RAR not implemented for hyperPINN and SPINN")
|
|
279
|
+
else:
|
|
280
|
+
# do cartesian product on new points
|
|
281
|
+
tile_omega = jnp.tile(
|
|
282
|
+
new_omega_samples, reps=(sample_size_times, 1)
|
|
283
|
+
) # it is tiled
|
|
284
|
+
repeat_times = jnp.repeat(new_times_samples, sample_size_omega, axis=0)[
|
|
285
|
+
..., None
|
|
286
|
+
] # it is repeated + add an axis
|
|
287
|
+
|
|
288
|
+
if isinstance(loss, LossPDENonStatio):
|
|
260
289
|
v_dyn_loss = vmap(
|
|
261
290
|
lambda t, x: loss.dynamic_loss.evaluate(t, x, loss.u, params),
|
|
262
291
|
(0, 0),
|
|
263
292
|
0,
|
|
264
293
|
)
|
|
265
|
-
dyn_on_s = v_dyn_loss(
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
mse_on_s = dyn_on_s**2
|
|
294
|
+
dyn_on_s = v_dyn_loss(repeat_times, tile_omega).reshape(
|
|
295
|
+
(sample_size_times, sample_size_omega)
|
|
296
|
+
)
|
|
297
|
+
mse_on_s = dyn_on_s**2
|
|
270
298
|
elif isinstance(loss, SystemLossPDE):
|
|
271
|
-
|
|
299
|
+
dyn_on_s = jnp.zeros((sample_size_times, sample_size_omega))
|
|
272
300
|
for i in loss.dynamic_loss_dict.keys():
|
|
273
|
-
|
|
274
|
-
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
|
|
278
|
-
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
mse_on_s += (
|
|
284
|
-
jnp.linalg.norm(dyn_on_s, axis=-1) ** 2
|
|
285
|
-
).flatten()
|
|
286
|
-
else:
|
|
287
|
-
mse_on_s += dyn_on_s**2
|
|
288
|
-
else:
|
|
289
|
-
v_dyn_loss = vmap(
|
|
290
|
-
lambda t, x: loss.dynamic_loss_dict[i].evaluate(
|
|
291
|
-
t, x, loss.u_dict, params
|
|
292
|
-
),
|
|
293
|
-
(0, 0),
|
|
294
|
-
0,
|
|
295
|
-
)
|
|
296
|
-
dyn_on_s = v_dyn_loss(st[..., None], sx)
|
|
297
|
-
if dyn_on_s.ndim > 1:
|
|
298
|
-
mse_on_s += (
|
|
299
|
-
jnp.linalg.norm(dyn_on_s, axis=-1) ** 2
|
|
300
|
-
).flatten()
|
|
301
|
-
else:
|
|
302
|
-
mse_on_s += dyn_on_s**2
|
|
303
|
-
|
|
304
|
-
## Now that we have the residuals, select the m points
|
|
305
|
-
# with higher dynamic loss (residuals)
|
|
306
|
-
higher_residual_idx = jax.lax.dynamic_slice(
|
|
307
|
-
jnp.argsort(mse_on_s),
|
|
308
|
-
(mse_on_s.shape[0] - selected_sample_size,),
|
|
309
|
-
(selected_sample_size,),
|
|
310
|
-
)
|
|
311
|
-
higher_residual_points_st = st[higher_residual_idx]
|
|
312
|
-
higher_residual_points_sx = sx[higher_residual_idx]
|
|
301
|
+
v_dyn_loss = vmap(
|
|
302
|
+
lambda t, x: loss.dynamic_loss_dict[i].evaluate(
|
|
303
|
+
t, x, loss.u_dict, params
|
|
304
|
+
),
|
|
305
|
+
(0, 0),
|
|
306
|
+
0,
|
|
307
|
+
)
|
|
308
|
+
dyn_on_s += v_dyn_loss(repeat_times, tile_omega).reshape(
|
|
309
|
+
(sample_size_times, sample_size_omega)
|
|
310
|
+
)
|
|
313
311
|
|
|
314
|
-
|
|
312
|
+
mse_on_s = dyn_on_s**2
|
|
313
|
+
# -- Select the m points with highest average residuals on time and
|
|
314
|
+
# -- space (times in rows / omega in columns)
|
|
315
|
+
# mean_times = mse_on_s.mean(axis=1)
|
|
316
|
+
# mean_omega = mse_on_s.mean(axis=0)
|
|
317
|
+
# times_idx = jax.lax.dynamic_slice(
|
|
318
|
+
# jnp.argsort(mean_times),
|
|
319
|
+
# (mse_on_s.shape[0] - selected_sample_size_times,),
|
|
320
|
+
# (selected_sample_size_times,),
|
|
321
|
+
# )
|
|
322
|
+
# omega_idx = jax.lax.dynamic_slice(
|
|
323
|
+
# jnp.argsort(mean_omega),
|
|
324
|
+
# (mse_on_s.shape[1] - selected_sample_size_omega,),
|
|
325
|
+
# (selected_sample_size_omega,),
|
|
326
|
+
# )
|
|
327
|
+
|
|
328
|
+
# -- Select the m worst points (t, x) with highest residuals
|
|
329
|
+
n_select = max(selected_sample_size_times, selected_sample_size_omega)
|
|
330
|
+
_, idx = jax.lax.top_k(mse_on_s.flatten(), k=n_select)
|
|
331
|
+
arr_idx = jnp.unravel_index(idx, mse_on_s.shape)
|
|
332
|
+
times_idx = arr_idx[0][:selected_sample_size_times]
|
|
333
|
+
omega_idx = arr_idx[1][:selected_sample_size_omega]
|
|
334
|
+
|
|
335
|
+
higher_residual_points_times = new_times_samples[times_idx]
|
|
336
|
+
higher_residual_points_omega = new_omega_samples[omega_idx]
|
|
315
337
|
|
|
316
338
|
## add the new points in times
|
|
317
|
-
# start indices of update can be dynamic but
|
|
339
|
+
# start indices of update can be dynamic but not the shape (length)
|
|
318
340
|
# of the slice
|
|
319
341
|
data.times = jax.lax.dynamic_update_slice(
|
|
320
342
|
data.times,
|
|
321
|
-
|
|
322
|
-
(data.n_start + data.rar_iter_nb *
|
|
343
|
+
higher_residual_points_times,
|
|
344
|
+
(data.n_start + data.rar_iter_nb * selected_sample_size_times,),
|
|
323
345
|
)
|
|
324
346
|
|
|
325
347
|
## add the new points in omega
|
|
326
348
|
data.omega = jax.lax.dynamic_update_slice(
|
|
327
349
|
data.omega,
|
|
328
|
-
|
|
350
|
+
higher_residual_points_omega,
|
|
329
351
|
(
|
|
330
|
-
data.n_start + data.rar_iter_nb *
|
|
352
|
+
data.n_start + data.rar_iter_nb * selected_sample_size_omega,
|
|
331
353
|
data.dim,
|
|
332
354
|
),
|
|
333
355
|
)
|
|
334
356
|
|
|
335
357
|
## rearrange probabilities so that the probabilities of the new
|
|
336
358
|
## points are non-zero
|
|
337
|
-
|
|
359
|
+
new_p_times = 1 / (
|
|
360
|
+
data.nt_start + data.rar_iter_nb * selected_sample_size_times
|
|
361
|
+
)
|
|
338
362
|
# the next line works because nt_start is static
|
|
339
|
-
data.
|
|
363
|
+
data.p_times = data.p_times.at[: data.nt_start].set(new_p_times)
|
|
340
364
|
|
|
341
|
-
#
|
|
342
|
-
|
|
343
|
-
|
|
344
|
-
|
|
345
|
-
|
|
346
|
-
|
|
347
|
-
|
|
365
|
+
# same for p_omega (works because n_start is static)
|
|
366
|
+
new_p_omega = 1 / (
|
|
367
|
+
data.n_start + data.rar_iter_nb * selected_sample_size_omega
|
|
368
|
+
)
|
|
369
|
+
data.p_omega = data.p_omega.at[: data.n_start].set(new_p_omega)
|
|
370
|
+
|
|
371
|
+
# the part of data.p_* after n_start requires a fori_loop because
|
|
372
|
+
# the range is dynamic
|
|
373
|
+
def create_update_slices(new_val, selected_sample_size):
|
|
374
|
+
def update_slices(i, p):
|
|
375
|
+
new_p = jax.lax.dynamic_update_slice(
|
|
376
|
+
p,
|
|
377
|
+
new_val * jnp.ones((selected_sample_size,)),
|
|
378
|
+
((data.n_start + i * selected_sample_size),),
|
|
379
|
+
)
|
|
380
|
+
return new_p
|
|
381
|
+
|
|
382
|
+
return update_slices
|
|
348
383
|
|
|
349
384
|
data.rar_iter_nb += 1
|
|
350
385
|
|
|
351
|
-
|
|
386
|
+
## update rest of p_times
|
|
387
|
+
update_slices_times = create_update_slices(
|
|
388
|
+
new_p_times, selected_sample_size_times
|
|
389
|
+
)
|
|
390
|
+
data.p_times = jax.lax.fori_loop(
|
|
391
|
+
0,
|
|
392
|
+
data.rar_iter_nb,
|
|
393
|
+
update_slices_times,
|
|
394
|
+
data.p_times,
|
|
395
|
+
)
|
|
396
|
+
## update rest of p_omega
|
|
397
|
+
update_slices_omega = create_update_slices(
|
|
398
|
+
new_p_omega, selected_sample_size_omega
|
|
399
|
+
)
|
|
400
|
+
data.p_omega = jax.lax.fori_loop(
|
|
401
|
+
0,
|
|
402
|
+
data.rar_iter_nb,
|
|
403
|
+
update_slices_omega,
|
|
404
|
+
data.p_omega,
|
|
405
|
+
)
|
|
352
406
|
|
|
353
|
-
|
|
354
|
-
|
|
355
|
-
|
|
407
|
+
# update RAR parameters for all cases
|
|
408
|
+
data.rar_iter_from_last_sampling = 0
|
|
409
|
+
|
|
410
|
+
# NOTE must return data to be correctly updated because we cannot
|
|
411
|
+
# have side effects in this function that will be jitted
|
|
412
|
+
return data
|
|
356
413
|
|
|
357
414
|
def rar_step_false(operands):
|
|
358
415
|
_, _, data, i = operands
|
|
359
416
|
|
|
360
417
|
# Add 1 only if we are after the burn in period
|
|
361
|
-
|
|
362
|
-
i
|
|
363
|
-
lambda
|
|
364
|
-
lambda
|
|
365
|
-
(data.rar_iter_from_last_sampling),
|
|
418
|
+
increment = jax.lax.cond(
|
|
419
|
+
i <= data.rar_parameters["start_iter"],
|
|
420
|
+
lambda: 0,
|
|
421
|
+
lambda: 1,
|
|
366
422
|
)
|
|
367
423
|
|
|
424
|
+
data.rar_iter_from_last_sampling += increment
|
|
368
425
|
return data
|
|
369
426
|
|
|
370
427
|
return rar_step_true, rar_step_false
|
jinns/solver/_seq2seq.py
CHANGED
|
@@ -88,7 +88,7 @@ def initialize_seq2seq(loss, data, seq2seq, opt_state):
|
|
|
88
88
|
data.curr_omega_idx = 0
|
|
89
89
|
data.generate_time_data()
|
|
90
90
|
data._key, data.times, _ = _reset_batch_idx_and_permute(
|
|
91
|
-
(data._key, data.times, data.curr_omega_idx, None, data.
|
|
91
|
+
(data._key, data.times, data.curr_omega_idx, None, data.p_times)
|
|
92
92
|
)
|
|
93
93
|
opt_state.hyperparams["learning_rate"] = seq2seq["learning_rate"][curr_seq]
|
|
94
94
|
|
|
@@ -145,7 +145,7 @@ def _update_seq2seq_SystemLossODE(operands):
|
|
|
145
145
|
data.curr_omega_idx = 0
|
|
146
146
|
data.generate_time_data()
|
|
147
147
|
data._key, data.times, _ = _reset_batch_idx_and_permute(
|
|
148
|
-
(data._key, data.times, data.curr_omega_idx, None, data.
|
|
148
|
+
(data._key, data.times, data.curr_omega_idx, None, data.p_times)
|
|
149
149
|
)
|
|
150
150
|
|
|
151
151
|
opt_state.hyperparams["learning_rate"] = seq2seq["learning_rate"][curr_seq]
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: jinns
|
|
3
|
-
Version: 0.8.
|
|
3
|
+
Version: 0.8.9
|
|
4
4
|
Summary: Physics Informed Neural Network with JAX
|
|
5
5
|
Author-email: Hugo Gangloff <hugo.gangloff@inrae.fr>, Nicolas Jouvin <nicolas.jouvin@inrae.fr>
|
|
6
6
|
Maintainer-email: Hugo Gangloff <hugo.gangloff@inrae.fr>, Nicolas Jouvin <nicolas.jouvin@inrae.fr>
|
|
@@ -1,9 +1,11 @@
|
|
|
1
|
-
jinns/__init__.py,sha256=
|
|
2
|
-
jinns/data/_DataGenerators.py,sha256=
|
|
1
|
+
jinns/__init__.py,sha256=T2XlmLbYqcXTumPJL00cJ80W98We5LH8Yg_Lss_exl4,139
|
|
2
|
+
jinns/data/_DataGenerators.py,sha256=_um-giHQ8mCILUOJHX231njHTHZp4S7EcGrUs7R1dUs,61829
|
|
3
3
|
jinns/data/__init__.py,sha256=yBOmoavSD-cABp4XcjQY1zsEVO0mDyIhi2MJ5WNp0l8,326
|
|
4
|
-
jinns/data/_display.py,sha256=
|
|
5
|
-
jinns/experimental/__init__.py,sha256=
|
|
4
|
+
jinns/data/_display.py,sha256=vlqggDCgVMEwdGBtjVmZaTQORU6imSfDkssn2XCtITI,10392
|
|
5
|
+
jinns/experimental/__init__.py,sha256=qWbhC7Z8UgLWy0t-zU7RYze6v13-FngiCYXu-2bRVFQ,296
|
|
6
6
|
jinns/experimental/_diffrax_solver.py,sha256=sLT22byqh-6015_fhe1xtMWlFOYcCjzYKET4sLhA9R4,6818
|
|
7
|
+
jinns/experimental/_sinuspinn.py,sha256=hxSzscwMV2LayWOqenIlT1zqEVVrE5Y8CKf7bHX5XFQ,5016
|
|
8
|
+
jinns/experimental/_spectralpinn.py,sha256=-4795pa7AYtRNSE-ugan3gHh64mtu2VdrRG5AS_J9Eg,2654
|
|
7
9
|
jinns/loss/_DynamicLoss.py,sha256=L4CVmmF0rTPbHntgqsLLHlnrlQgLHsetUocpJm7ZYag,27461
|
|
8
10
|
jinns/loss/_DynamicLossAbstract.py,sha256=kTQlhLx7SBuH5dIDmYaE79sVHUZt1nUFa8LxPU5IHhM,8504
|
|
9
11
|
jinns/loss/_LossODE.py,sha256=b9doBHoQwYvlgpqzrNO4dOaTN87LRvjHtHbz9bMoH7E,22119
|
|
@@ -13,8 +15,8 @@ jinns/loss/__init__.py,sha256=pFNYUxns-NPXBFdqrEVSiXkQLfCtKw-t2trlhvLzpYE,355
|
|
|
13
15
|
jinns/loss/_boundary_conditions.py,sha256=YfSnLZ25hXqQ5KWAuxOrWSKkf_oBqAc9GQV4z7MjWyQ,17434
|
|
14
16
|
jinns/loss/_operators.py,sha256=zDGJqYqeYH7xd-4dtGX9PS-pf0uSOpUUXGo5SVjIJ4o,11069
|
|
15
17
|
jinns/solver/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
16
|
-
jinns/solver/_rar.py,sha256=
|
|
17
|
-
jinns/solver/_seq2seq.py,sha256=
|
|
18
|
+
jinns/solver/_rar.py,sha256=IYP-jdbM0rbjBtxislrBYBuj49p9_QDOqejZKCHrKg8,17072
|
|
19
|
+
jinns/solver/_seq2seq.py,sha256=S6IPfsXpS_fbqIqAy01eUM7GBSBSkRzURan_J-iXXzI,5632
|
|
18
20
|
jinns/solver/_solve.py,sha256=mGi0zaT_fK_QpBjTxof5Ix4mmfmnPi66CNJ3GQFZuo4,19099
|
|
19
21
|
jinns/utils/__init__.py,sha256=44ms5UR6vMw3Nf6u4RCAzPFs4fom_YbBnH9mfne8m6k,313
|
|
20
22
|
jinns/utils/_containers.py,sha256=eYD277fO7X4EfX7PUFCCl69r3JBfh1sCfq8LkL5gd6o,1495
|
|
@@ -27,8 +29,8 @@ jinns/utils/_utils.py,sha256=8dgvWXX9NT7_7-zltWp0C9tG45ZFNwXxueyxPBb4hjo,6740
|
|
|
27
29
|
jinns/utils/_utils_uspinn.py,sha256=qcKcOw3zrwWSQyGVj6fD8c9GinHt_U6JWN_k0auTtXM,26039
|
|
28
30
|
jinns/validation/__init__.py,sha256=Jv58mzgC3F7cRfXA6caicL1t_U0UAhbwLrmMNVg6E7s,66
|
|
29
31
|
jinns/validation/_validation.py,sha256=KfetbzB0xTNdBcYLwFWjEtP63Tf9wJirlhgqLTJDyy4,6761
|
|
30
|
-
jinns-0.8.
|
|
31
|
-
jinns-0.8.
|
|
32
|
-
jinns-0.8.
|
|
33
|
-
jinns-0.8.
|
|
34
|
-
jinns-0.8.
|
|
32
|
+
jinns-0.8.9.dist-info/LICENSE,sha256=BIAkGtXB59Q_BG8f6_OqtQ1BHPv60ggE9mpXJYz2dRM,11337
|
|
33
|
+
jinns-0.8.9.dist-info/METADATA,sha256=na97ODHPafvEMKGvmUq6XszFrjQ9LP8L2FtCY2gZ8oI,2482
|
|
34
|
+
jinns-0.8.9.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
|
|
35
|
+
jinns-0.8.9.dist-info/top_level.txt,sha256=RXbkr2hzy8WBE8aiRyrJYFqn3JeMJIhMdybLjjLTB9c,6
|
|
36
|
+
jinns-0.8.9.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|