jinns 1.0.0__py3-none-any.whl → 1.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jinns/data/_Batchs.py +4 -8
- jinns/data/_DataGenerators.py +532 -341
- jinns/loss/_DynamicLoss.py +150 -173
- jinns/loss/_DynamicLossAbstract.py +27 -73
- jinns/loss/_LossODE.py +45 -26
- jinns/loss/_LossPDE.py +85 -84
- jinns/loss/__init__.py +7 -6
- jinns/loss/_boundary_conditions.py +148 -279
- jinns/loss/_loss_utils.py +85 -58
- jinns/loss/_operators.py +441 -184
- jinns/parameters/_derivative_keys.py +487 -60
- jinns/plot/_plot.py +111 -98
- jinns/solver/_rar.py +102 -407
- jinns/solver/_solve.py +73 -38
- jinns/solver/_utils.py +122 -0
- jinns/utils/__init__.py +2 -0
- jinns/utils/_containers.py +3 -1
- jinns/utils/_hyperpinn.py +17 -7
- jinns/utils/_pinn.py +17 -27
- jinns/utils/_ppinn.py +227 -0
- jinns/utils/_save_load.py +13 -13
- jinns/utils/_spinn.py +24 -43
- jinns/utils/_types.py +1 -0
- jinns/utils/_utils.py +40 -12
- jinns-1.2.0.dist-info/AUTHORS +2 -0
- jinns-1.2.0.dist-info/METADATA +127 -0
- jinns-1.2.0.dist-info/RECORD +41 -0
- {jinns-1.0.0.dist-info → jinns-1.2.0.dist-info}/WHEEL +1 -1
- jinns-1.0.0.dist-info/METADATA +0 -84
- jinns-1.0.0.dist-info/RECORD +0 -38
- {jinns-1.0.0.dist-info → jinns-1.2.0.dist-info}/LICENSE +0 -0
- {jinns-1.0.0.dist-info → jinns-1.2.0.dist-info}/top_level.txt +0 -0
jinns/solver/_rar.py
CHANGED
|
@@ -13,6 +13,7 @@ from jaxtyping import Int, Bool
|
|
|
13
13
|
from jinns.data._Batchs import *
|
|
14
14
|
from jinns.loss._LossODE import LossODE, SystemLossODE
|
|
15
15
|
from jinns.loss._LossPDE import LossPDEStatio, LossPDENonStatio, SystemLossPDE
|
|
16
|
+
from jinns.loss._loss_utils import dynamic_loss_apply
|
|
16
17
|
from jinns.data._DataGenerators import (
|
|
17
18
|
DataGeneratorODE,
|
|
18
19
|
CubicMeshPDEStatio,
|
|
@@ -30,7 +31,7 @@ def _proceed_to_rar(data: AnyDataGenerator, i: Int) -> Bool:
|
|
|
30
31
|
"""Utilility function with various check to ensure we can proceed with the rar_step.
|
|
31
32
|
Return True if yes, and False otherwise"""
|
|
32
33
|
|
|
33
|
-
# Overall checks
|
|
34
|
+
# Overall checks
|
|
34
35
|
check_list = [
|
|
35
36
|
# check if burn-in period has ended
|
|
36
37
|
data.rar_parameters["start_iter"] <= i,
|
|
@@ -38,22 +39,12 @@ def _proceed_to_rar(data: AnyDataGenerator, i: Int) -> Bool:
|
|
|
38
39
|
(data.rar_parameters["update_every"] - 1) == data.rar_iter_from_last_sampling,
|
|
39
40
|
]
|
|
40
41
|
|
|
41
|
-
# Memory allocation checks
|
|
42
|
+
# Memory allocation checks
|
|
42
43
|
# check if we still have room to append new collocation points in the
|
|
43
|
-
# allocated jnp.array
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
<= jnp.count_nonzero(data.p_times == 0),
|
|
48
|
-
)
|
|
49
|
-
|
|
50
|
-
if isinstance(data, (CubicMeshPDEStatio, CubicMeshPDENonStatio)):
|
|
51
|
-
# for now the above check are redundants but there may be a time when
|
|
52
|
-
# we drop inheritence
|
|
53
|
-
check_list.append(
|
|
54
|
-
data.rar_parameters["selected_sample_size_omega"]
|
|
55
|
-
<= jnp.count_nonzero(data.p_omega == 0),
|
|
56
|
-
)
|
|
44
|
+
# allocated jnp.array
|
|
45
|
+
check_list.append(
|
|
46
|
+
data.rar_parameters["selected_sample_size"] <= jnp.count_nonzero(data.p == 0),
|
|
47
|
+
)
|
|
57
48
|
|
|
58
49
|
proceed = jnp.all(jnp.array(check_list))
|
|
59
50
|
return proceed
|
|
@@ -99,38 +90,12 @@ def init_rar(
|
|
|
99
90
|
if data.rar_parameters is None:
|
|
100
91
|
_rar_step_true, _rar_step_false = None, None
|
|
101
92
|
else:
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
data.rar_parameters["selected_sample_size_times"],
|
|
107
|
-
)
|
|
108
|
-
elif isinstance(data, CubicMeshPDENonStatio):
|
|
109
|
-
# In this case we need rar parameters related to both `times`
|
|
110
|
-
# and`omega`
|
|
111
|
-
_rar_step_true, _rar_step_false = _rar_step_init(
|
|
112
|
-
(
|
|
113
|
-
data.rar_parameters["sample_size_times"],
|
|
114
|
-
data.rar_parameters["sample_size_omega"],
|
|
115
|
-
),
|
|
116
|
-
(
|
|
117
|
-
data.rar_parameters["selected_sample_size_times"],
|
|
118
|
-
data.rar_parameters["selected_sample_size_omega"],
|
|
119
|
-
),
|
|
120
|
-
)
|
|
121
|
-
elif isinstance(data, CubicMeshPDEStatio):
|
|
122
|
-
# In this case we only need rar parameters related to `omega`
|
|
123
|
-
_rar_step_true, _rar_step_false = _rar_step_init(
|
|
124
|
-
data.rar_parameters["sample_size_omega"],
|
|
125
|
-
data.rar_parameters["selected_sample_size_omega"],
|
|
126
|
-
)
|
|
127
|
-
else:
|
|
128
|
-
raise ValueError(f"Wrong type for data got {type(data)}")
|
|
93
|
+
_rar_step_true, _rar_step_false = _rar_step_init(
|
|
94
|
+
data.rar_parameters["sample_size"],
|
|
95
|
+
data.rar_parameters["selected_sample_size"],
|
|
96
|
+
)
|
|
129
97
|
|
|
130
|
-
|
|
131
|
-
data = eqx.tree_at(lambda m: m.rar_iter_from_last_sampling, data, 0)
|
|
132
|
-
else:
|
|
133
|
-
data.rar_iter_from_last_sampling = 0
|
|
98
|
+
data = eqx.tree_at(lambda m: m.rar_iter_from_last_sampling, data, 0)
|
|
134
99
|
|
|
135
100
|
return data, _rar_step_true, _rar_step_false
|
|
136
101
|
|
|
@@ -150,402 +115,132 @@ def _rar_step_init(sample_size: Int, selected_sample_size: Int) -> tuple[
|
|
|
150
115
|
|
|
151
116
|
def rar_step_true(operands: rar_operands) -> AnyDataGenerator:
|
|
152
117
|
loss, params, data, i = operands
|
|
118
|
+
if isinstance(loss.u, HYPERPINN) or isinstance(loss.u, SPINN):
|
|
119
|
+
raise NotImplementedError("RAR not implemented for hyperPINN and SPINN")
|
|
153
120
|
|
|
154
121
|
if isinstance(data, DataGeneratorODE):
|
|
155
122
|
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
data = eqx.tree_at(lambda m: m.key, data, new_key)
|
|
160
|
-
else:
|
|
161
|
-
new_omega_samples = data.sample_in_time_domain(sample_size)
|
|
123
|
+
new_key, subkey = jax.random.split(data.key)
|
|
124
|
+
new_samples = data.sample_in_time_domain(subkey, sample_size)
|
|
125
|
+
data = eqx.tree_at(lambda m: m.key, data, new_key)
|
|
162
126
|
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
)
|
|
170
|
-
dyn_on_s = v_dyn_loss(new_omega_samples)
|
|
171
|
-
if dyn_on_s.ndim > 1:
|
|
172
|
-
mse_on_s = (jnp.linalg.norm(dyn_on_s, axis=-1) ** 2).flatten()
|
|
173
|
-
else:
|
|
174
|
-
mse_on_s = dyn_on_s**2
|
|
175
|
-
elif isinstance(loss, SystemLossODE):
|
|
176
|
-
mse_on_s = 0
|
|
177
|
-
|
|
178
|
-
for i in loss.dynamic_loss_dict.keys():
|
|
179
|
-
v_dyn_loss = vmap(
|
|
180
|
-
lambda t: loss.dynamic_loss_dict[i].evaluate(
|
|
181
|
-
t, loss.u_dict, params
|
|
182
|
-
),
|
|
183
|
-
(0),
|
|
184
|
-
0,
|
|
185
|
-
)
|
|
186
|
-
dyn_on_s = v_dyn_loss(new_omega_samples)
|
|
187
|
-
if dyn_on_s.ndim > 1:
|
|
188
|
-
mse_on_s += (jnp.linalg.norm(dyn_on_s, axis=-1) ** 2).flatten()
|
|
189
|
-
else:
|
|
190
|
-
mse_on_s += dyn_on_s**2
|
|
191
|
-
|
|
192
|
-
## Select the m points with higher dynamic loss
|
|
193
|
-
higher_residual_idx = jax.lax.dynamic_slice(
|
|
194
|
-
jnp.argsort(mse_on_s),
|
|
195
|
-
(mse_on_s.shape[0] - selected_sample_size,),
|
|
196
|
-
(selected_sample_size,),
|
|
197
|
-
)
|
|
198
|
-
higher_residual_points = new_omega_samples[higher_residual_idx]
|
|
127
|
+
elif isinstance(data, CubicMeshPDEStatio) and not isinstance(
|
|
128
|
+
data, CubicMeshPDENonStatio
|
|
129
|
+
):
|
|
130
|
+
new_key, *subkeys = jax.random.split(data.key, data.dim + 1)
|
|
131
|
+
new_samples = data.sample_in_omega_domain(subkeys, sample_size)
|
|
132
|
+
data = eqx.tree_at(lambda m: m.key, data, new_key)
|
|
199
133
|
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
(data.
|
|
134
|
+
elif isinstance(data, CubicMeshPDENonStatio):
|
|
135
|
+
new_key, subkey = jax.random.split(data.key)
|
|
136
|
+
new_samples_times = data.sample_in_time_domain(subkey, sample_size)
|
|
137
|
+
if data.dim == 1:
|
|
138
|
+
new_key, subkeys = jax.random.split(new_key, 2)
|
|
139
|
+
else:
|
|
140
|
+
new_key, *subkeys = jax.random.split(new_key, data.dim + 1)
|
|
141
|
+
new_samples_omega = data.sample_in_omega_domain(subkeys, sample_size)
|
|
142
|
+
new_samples = jnp.concatenate(
|
|
143
|
+
[new_samples_times, new_samples_omega], axis=1
|
|
207
144
|
)
|
|
208
145
|
|
|
209
|
-
|
|
210
|
-
data = eqx.tree_at(lambda m: m.times, data, new_times)
|
|
211
|
-
else:
|
|
212
|
-
data.times = new_times
|
|
213
|
-
## rearrange probabilities so that the probabilities of the new
|
|
214
|
-
## points are non-zero
|
|
215
|
-
new_proba = 1 / (data.nt_start + data.rar_iter_nb * selected_sample_size)
|
|
216
|
-
# the next work because nt_start is static
|
|
217
|
-
new_p_times = data.p_times.at[: data.nt_start].set(new_proba)
|
|
218
|
-
if isinstance(data, eqx.Module):
|
|
219
|
-
data = eqx.tree_at(
|
|
220
|
-
lambda m: m.p_times,
|
|
221
|
-
data,
|
|
222
|
-
new_p_times,
|
|
223
|
-
)
|
|
224
|
-
else:
|
|
225
|
-
data.p_times = new_p_times
|
|
226
|
-
|
|
227
|
-
# the next requires a fori_loop because the range is dynamic
|
|
228
|
-
def update_slices(i, p):
|
|
229
|
-
return jax.lax.dynamic_update_slice(
|
|
230
|
-
p,
|
|
231
|
-
1 / new_proba * jnp.ones((selected_sample_size,)),
|
|
232
|
-
((data.nt_start + i * selected_sample_size),),
|
|
233
|
-
)
|
|
146
|
+
data = eqx.tree_at(lambda m: m.key, data, new_key)
|
|
234
147
|
|
|
235
|
-
|
|
236
|
-
|
|
237
|
-
|
|
148
|
+
# We can have different types of Loss
|
|
149
|
+
if isinstance(loss, (LossODE, LossPDEStatio, LossPDENonStatio)):
|
|
150
|
+
v_dyn_loss = vmap(
|
|
151
|
+
lambda inputs: loss.dynamic_loss.evaluate(inputs, loss.u, params),
|
|
238
152
|
)
|
|
239
|
-
|
|
240
|
-
data = eqx.tree_at(
|
|
241
|
-
lambda m: (m.rar_iter_nb, m.p_times),
|
|
242
|
-
data,
|
|
243
|
-
(new_rar_iter_nb, new_p_times),
|
|
244
|
-
)
|
|
245
|
-
else:
|
|
246
|
-
data.rar_iter_nb = new_rar_iter_nb
|
|
247
|
-
data.p_times = new_p_times
|
|
153
|
+
dyn_on_s = v_dyn_loss(new_samples)
|
|
248
154
|
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
):
|
|
252
|
-
if isinstance(data, eqx.Module):
|
|
253
|
-
new_key, *subkeys = jax.random.split(data.key, data.dim + 1)
|
|
254
|
-
new_omega_samples = data.sample_in_omega_domain(subkeys, sample_size)
|
|
255
|
-
data = eqx.tree_at(lambda m: m.key, data, new_key)
|
|
155
|
+
if dyn_on_s.ndim > 1:
|
|
156
|
+
mse_on_s = (jnp.linalg.norm(dyn_on_s, axis=-1) ** 2).flatten()
|
|
256
157
|
else:
|
|
257
|
-
|
|
158
|
+
mse_on_s = dyn_on_s**2
|
|
159
|
+
elif isinstance(loss, SystemLossODE, SystemLossPDE):
|
|
160
|
+
mse_on_s = 0
|
|
258
161
|
|
|
259
|
-
|
|
260
|
-
if isinstance(loss, LossPDEStatio):
|
|
162
|
+
for i in loss.dynamic_loss_dict.keys():
|
|
261
163
|
v_dyn_loss = vmap(
|
|
262
|
-
lambda
|
|
263
|
-
|
|
264
|
-
loss.u,
|
|
265
|
-
params,
|
|
164
|
+
lambda inputs: loss.dynamic_loss_dict[i].evaluate(
|
|
165
|
+
inputs, loss.u_dict, params
|
|
266
166
|
),
|
|
267
167
|
(0),
|
|
268
168
|
0,
|
|
269
169
|
)
|
|
270
|
-
dyn_on_s = v_dyn_loss(
|
|
170
|
+
dyn_on_s = v_dyn_loss(new_samples)
|
|
271
171
|
if dyn_on_s.ndim > 1:
|
|
272
|
-
mse_on_s
|
|
172
|
+
mse_on_s += (jnp.linalg.norm(dyn_on_s, axis=-1) ** 2).flatten()
|
|
273
173
|
else:
|
|
274
|
-
mse_on_s
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
|
|
278
|
-
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
|
|
284
|
-
|
|
285
|
-
|
|
286
|
-
|
|
287
|
-
|
|
288
|
-
|
|
289
|
-
|
|
290
|
-
|
|
291
|
-
|
|
292
|
-
## Select the m points with higher dynamic loss
|
|
293
|
-
higher_residual_idx = jax.lax.dynamic_slice(
|
|
294
|
-
jnp.argsort(mse_on_s),
|
|
295
|
-
(mse_on_s.shape[0] - selected_sample_size,),
|
|
296
|
-
(selected_sample_size,),
|
|
174
|
+
mse_on_s += dyn_on_s**2
|
|
175
|
+
|
|
176
|
+
## Select the m points with higher dynamic loss
|
|
177
|
+
higher_residual_idx = jax.lax.dynamic_slice(
|
|
178
|
+
jnp.argsort(mse_on_s),
|
|
179
|
+
(mse_on_s.shape[0] - selected_sample_size,),
|
|
180
|
+
(selected_sample_size,),
|
|
181
|
+
)
|
|
182
|
+
higher_residual_points = new_samples[higher_residual_idx]
|
|
183
|
+
|
|
184
|
+
# add the new points
|
|
185
|
+
# start indices of update can be dynamic but the the shape (length)
|
|
186
|
+
# of the slice
|
|
187
|
+
if isinstance(data, DataGeneratorODE):
|
|
188
|
+
new_times = jax.lax.dynamic_update_slice(
|
|
189
|
+
data.times,
|
|
190
|
+
higher_residual_points,
|
|
191
|
+
(data.n_start + data.rar_iter_nb * selected_sample_size,),
|
|
297
192
|
)
|
|
298
|
-
higher_residual_points = new_omega_samples[higher_residual_idx]
|
|
299
193
|
|
|
300
|
-
|
|
301
|
-
|
|
302
|
-
|
|
194
|
+
data = eqx.tree_at(lambda m: m.times, data, new_times)
|
|
195
|
+
elif isinstance(data, CubicMeshPDEStatio) and not isinstance(
|
|
196
|
+
data, CubicMeshPDENonStatio
|
|
197
|
+
):
|
|
303
198
|
new_omega = jax.lax.dynamic_update_slice(
|
|
304
199
|
data.omega,
|
|
305
200
|
higher_residual_points,
|
|
306
201
|
(data.n_start + data.rar_iter_nb * selected_sample_size, data.dim),
|
|
307
202
|
)
|
|
308
203
|
|
|
309
|
-
|
|
310
|
-
data = eqx.tree_at(lambda m: m.omega, data, new_omega)
|
|
311
|
-
else:
|
|
312
|
-
data.omega = new_omega
|
|
313
|
-
|
|
314
|
-
## rearrange probabilities so that the probabilities of the new
|
|
315
|
-
## points are non-zero
|
|
316
|
-
new_proba = 1 / (data.n_start + data.rar_iter_nb * selected_sample_size)
|
|
317
|
-
# the next work because n_start is static
|
|
318
|
-
new_p_omega = data.p_omega.at[: data.n_start].set(new_proba)
|
|
319
|
-
if isinstance(data, eqx.Module):
|
|
320
|
-
data = eqx.tree_at(
|
|
321
|
-
lambda m: m.p_omega,
|
|
322
|
-
data,
|
|
323
|
-
new_p_omega,
|
|
324
|
-
)
|
|
325
|
-
else:
|
|
326
|
-
data.p_omega = new_p_omega
|
|
327
|
-
|
|
328
|
-
# the next requires a fori_loop because the range is dynamic
|
|
329
|
-
def update_slices(i, p):
|
|
330
|
-
return jax.lax.dynamic_update_slice(
|
|
331
|
-
p,
|
|
332
|
-
1 / new_proba * jnp.ones((selected_sample_size,)),
|
|
333
|
-
((data.n_start + i * selected_sample_size),),
|
|
334
|
-
)
|
|
335
|
-
|
|
336
|
-
new_rar_iter_nb = data.rar_iter_nb + 1
|
|
337
|
-
new_p_omega = jax.lax.fori_loop(
|
|
338
|
-
0, data.rar_iter_nb, update_slices, data.p_omega
|
|
339
|
-
)
|
|
340
|
-
if isinstance(data, eqx.Module):
|
|
341
|
-
data = eqx.tree_at(
|
|
342
|
-
lambda m: (m.rar_iter_nb, m.p_omega),
|
|
343
|
-
data,
|
|
344
|
-
(new_rar_iter_nb, new_p_omega),
|
|
345
|
-
)
|
|
346
|
-
else:
|
|
347
|
-
data.rar_iter_nb = new_rar_iter_nb
|
|
348
|
-
data.p_omega = new_p_omega
|
|
204
|
+
data = eqx.tree_at(lambda m: m.omega, data, new_omega)
|
|
349
205
|
|
|
350
206
|
elif isinstance(data, CubicMeshPDENonStatio):
|
|
351
|
-
|
|
352
|
-
|
|
353
|
-
|
|
354
|
-
|
|
355
|
-
# are tuples (times, omega) => we unpack them for clarity
|
|
356
|
-
selected_sample_size_times, selected_sample_size_omega = (
|
|
357
|
-
selected_sample_size
|
|
358
|
-
)
|
|
359
|
-
sample_size_times, sample_size_omega = sample_size
|
|
360
|
-
|
|
361
|
-
if isinstance(data, eqx.Module):
|
|
362
|
-
new_key, subkey = jax.random.split(data.key)
|
|
363
|
-
new_times_samples = data.sample_in_time_domain(
|
|
364
|
-
subkey, sample_size_times
|
|
365
|
-
)
|
|
366
|
-
new_key, *subkeys = jax.random.split(new_key, data.dim + 1)
|
|
367
|
-
new_omega_samples = data.sample_in_omega_domain(
|
|
368
|
-
subkeys, sample_size_omega
|
|
369
|
-
)
|
|
370
|
-
data = eqx.tree_at(lambda m: m.key, data, new_key)
|
|
371
|
-
else:
|
|
372
|
-
new_times_samples = data.sample_in_time_domain(sample_size_times)
|
|
373
|
-
new_omega_samples = data.sample_in_omega_domain(sample_size_omega)
|
|
374
|
-
|
|
375
|
-
if not data.cartesian_product:
|
|
376
|
-
times = new_times_samples
|
|
377
|
-
omega = new_omega_samples
|
|
378
|
-
else:
|
|
379
|
-
# do cartesian product on new points
|
|
380
|
-
omega = jnp.tile(
|
|
381
|
-
new_omega_samples, reps=(sample_size_times, 1)
|
|
382
|
-
) # it is tiled
|
|
383
|
-
times = jnp.repeat(new_times_samples, sample_size_omega, axis=0)[
|
|
384
|
-
..., None
|
|
385
|
-
] # it is repeated + add an axis
|
|
386
|
-
|
|
387
|
-
if isinstance(loss, LossPDENonStatio):
|
|
388
|
-
v_dyn_loss = vmap(
|
|
389
|
-
lambda t, x: loss.dynamic_loss.evaluate(t, x, loss.u, params),
|
|
390
|
-
(0, 0),
|
|
391
|
-
0,
|
|
392
|
-
)
|
|
393
|
-
dyn_on_s = v_dyn_loss(times, omega).reshape(
|
|
394
|
-
(sample_size_times, sample_size_omega)
|
|
395
|
-
)
|
|
396
|
-
mse_on_s = dyn_on_s**2
|
|
397
|
-
elif isinstance(loss, SystemLossPDE):
|
|
398
|
-
dyn_on_s = jnp.zeros((sample_size_times, sample_size_omega))
|
|
399
|
-
for i in loss.dynamic_loss_dict.keys():
|
|
400
|
-
v_dyn_loss = vmap(
|
|
401
|
-
lambda t, x: loss.dynamic_loss_dict[i].evaluate(
|
|
402
|
-
t, x, loss.u_dict, params
|
|
403
|
-
),
|
|
404
|
-
(0, 0),
|
|
405
|
-
0,
|
|
406
|
-
)
|
|
407
|
-
dyn_on_s += v_dyn_loss(times, omega).reshape(
|
|
408
|
-
(sample_size_times, sample_size_omega)
|
|
409
|
-
)
|
|
410
|
-
|
|
411
|
-
mse_on_s = dyn_on_s**2
|
|
412
|
-
# -- Select the m points with highest average residuals on time and
|
|
413
|
-
# -- space (times in rows / omega in columns)
|
|
414
|
-
# mean_times = mse_on_s.mean(axis=1)
|
|
415
|
-
# mean_omega = mse_on_s.mean(axis=0)
|
|
416
|
-
# times_idx = jax.lax.dynamic_slice(
|
|
417
|
-
# jnp.argsort(mean_times),
|
|
418
|
-
# (mse_on_s.shape[0] - selected_sample_size_times,),
|
|
419
|
-
# (selected_sample_size_times,),
|
|
420
|
-
# )
|
|
421
|
-
# omega_idx = jax.lax.dynamic_slice(
|
|
422
|
-
# jnp.argsort(mean_omega),
|
|
423
|
-
# (mse_on_s.shape[1] - selected_sample_size_omega,),
|
|
424
|
-
# (selected_sample_size_omega,),
|
|
425
|
-
# )
|
|
426
|
-
|
|
427
|
-
# -- Select the m worst points (t, x) with highest residuals
|
|
428
|
-
n_select = max(selected_sample_size_times, selected_sample_size_omega)
|
|
429
|
-
_, idx = jax.lax.top_k(mse_on_s.flatten(), k=n_select)
|
|
430
|
-
arr_idx = jnp.unravel_index(idx, mse_on_s.shape)
|
|
431
|
-
times_idx = arr_idx[0][:selected_sample_size_times]
|
|
432
|
-
omega_idx = arr_idx[1][:selected_sample_size_omega]
|
|
433
|
-
|
|
434
|
-
higher_residual_points_times = new_times_samples[times_idx]
|
|
435
|
-
higher_residual_points_omega = new_omega_samples[omega_idx]
|
|
436
|
-
|
|
437
|
-
## add the new points in times
|
|
438
|
-
# start indices of update can be dynamic but not the shape (length)
|
|
439
|
-
# of the slice
|
|
440
|
-
new_times = jax.lax.dynamic_update_slice(
|
|
441
|
-
data.times,
|
|
442
|
-
higher_residual_points_times,
|
|
443
|
-
(
|
|
444
|
-
data.n_start
|
|
445
|
-
+ data.rar_iter_nb # NOTE typo here nt_start ?
|
|
446
|
-
* selected_sample_size_times,
|
|
447
|
-
),
|
|
448
|
-
)
|
|
449
|
-
|
|
450
|
-
if isinstance(data, eqx.Module):
|
|
451
|
-
data = eqx.tree_at(lambda m: m.times, data, new_times)
|
|
452
|
-
else:
|
|
453
|
-
data.times = new_times
|
|
454
|
-
|
|
455
|
-
## add the new points in omega
|
|
456
|
-
new_omega = jax.lax.dynamic_update_slice(
|
|
457
|
-
data.omega,
|
|
458
|
-
higher_residual_points_omega,
|
|
459
|
-
(
|
|
460
|
-
data.n_start + data.rar_iter_nb * selected_sample_size_omega,
|
|
461
|
-
data.dim,
|
|
462
|
-
),
|
|
207
|
+
new_domain = jax.lax.dynamic_update_slice(
|
|
208
|
+
data.domain,
|
|
209
|
+
higher_residual_points,
|
|
210
|
+
(data.n_start + data.rar_iter_nb * selected_sample_size, 1 + data.dim),
|
|
463
211
|
)
|
|
464
212
|
|
|
465
|
-
|
|
466
|
-
|
|
467
|
-
|
|
468
|
-
|
|
213
|
+
data = eqx.tree_at(lambda m: m.domain, data, new_domain)
|
|
214
|
+
|
|
215
|
+
## rearrange probabilities so that the probabilities of the new
|
|
216
|
+
## points are non-zero
|
|
217
|
+
new_proba = 1 / (data.n_start + data.rar_iter_nb * selected_sample_size)
|
|
218
|
+
# the next work because nt_start is static
|
|
219
|
+
new_p = data.p.at[: data.n_start].set(new_proba)
|
|
220
|
+
data = eqx.tree_at(
|
|
221
|
+
lambda m: m.p,
|
|
222
|
+
data,
|
|
223
|
+
new_p,
|
|
224
|
+
)
|
|
469
225
|
|
|
470
|
-
|
|
471
|
-
|
|
472
|
-
|
|
473
|
-
|
|
226
|
+
# the next requires a fori_loop because the range is dynamic
|
|
227
|
+
def update_slices(i, p):
|
|
228
|
+
return jax.lax.dynamic_update_slice(
|
|
229
|
+
p,
|
|
230
|
+
1 / new_proba * jnp.ones((selected_sample_size,)),
|
|
231
|
+
((data.n_start + i * selected_sample_size),),
|
|
474
232
|
)
|
|
475
|
-
# the next work because nt_start is static
|
|
476
|
-
if isinstance(data, eqx.Module):
|
|
477
|
-
data = eqx.tree_at(
|
|
478
|
-
lambda m: m.p_times,
|
|
479
|
-
data,
|
|
480
|
-
data.p_times.at[: data.nt_start].set(new_p_times),
|
|
481
|
-
)
|
|
482
|
-
else:
|
|
483
|
-
data.p_times = data.p_times.at[: data.nt_start].set(new_p_times)
|
|
484
233
|
|
|
485
|
-
|
|
486
|
-
|
|
487
|
-
|
|
488
|
-
)
|
|
489
|
-
|
|
490
|
-
|
|
491
|
-
|
|
492
|
-
data,
|
|
493
|
-
data.p_omega.at[: data.n_start].set(new_p_omega),
|
|
494
|
-
)
|
|
495
|
-
else:
|
|
496
|
-
data.p_omega = data.p_omega.at[: data.n_start].set(new_p_omega)
|
|
497
|
-
|
|
498
|
-
# the part of data.p_* after n_start requires a fori_loop because
|
|
499
|
-
# the range is dynamic
|
|
500
|
-
def create_update_slices(new_val, selected_sample_size):
|
|
501
|
-
def update_slices(i, p):
|
|
502
|
-
new_p = jax.lax.dynamic_update_slice(
|
|
503
|
-
p,
|
|
504
|
-
new_val * jnp.ones((selected_sample_size,)),
|
|
505
|
-
((data.n_start + i * selected_sample_size),),
|
|
506
|
-
)
|
|
507
|
-
return new_p
|
|
508
|
-
|
|
509
|
-
return update_slices
|
|
510
|
-
|
|
511
|
-
new_rar_iter_nb = data.rar_iter_nb + 1
|
|
512
|
-
|
|
513
|
-
## update rest of p_times
|
|
514
|
-
update_slices_times = create_update_slices(
|
|
515
|
-
new_p_times, selected_sample_size_times
|
|
516
|
-
)
|
|
517
|
-
new_p_times = jax.lax.fori_loop(
|
|
518
|
-
0,
|
|
519
|
-
data.rar_iter_nb,
|
|
520
|
-
update_slices_times,
|
|
521
|
-
data.p_times,
|
|
522
|
-
)
|
|
523
|
-
## update rest of p_omega
|
|
524
|
-
update_slices_omega = create_update_slices(
|
|
525
|
-
new_p_omega, selected_sample_size_omega
|
|
526
|
-
)
|
|
527
|
-
new_p_omega = jax.lax.fori_loop(
|
|
528
|
-
0,
|
|
529
|
-
data.rar_iter_nb,
|
|
530
|
-
update_slices_omega,
|
|
531
|
-
data.p_omega,
|
|
532
|
-
)
|
|
533
|
-
if isinstance(data, eqx.Module):
|
|
534
|
-
data = eqx.tree_at(
|
|
535
|
-
lambda m: (m.rar_iter_nb, m.p_omega, m.p_times),
|
|
536
|
-
data,
|
|
537
|
-
(new_rar_iter_nb, new_p_omega, new_p_times),
|
|
538
|
-
)
|
|
539
|
-
else:
|
|
540
|
-
data.rar_iter_nb = new_rar_iter_nb
|
|
541
|
-
data.p_times = new_p_times
|
|
542
|
-
data.p_omega = new_p_omega
|
|
234
|
+
new_rar_iter_nb = data.rar_iter_nb + 1
|
|
235
|
+
new_p = jax.lax.fori_loop(0, new_rar_iter_nb, update_slices, data.p)
|
|
236
|
+
data = eqx.tree_at(
|
|
237
|
+
lambda m: (m.rar_iter_nb, m.p),
|
|
238
|
+
data,
|
|
239
|
+
(new_rar_iter_nb, new_p),
|
|
240
|
+
)
|
|
543
241
|
|
|
544
242
|
# update RAR parameters for all cases
|
|
545
|
-
|
|
546
|
-
data = eqx.tree_at(lambda m: m.rar_iter_from_last_sampling, data, 0)
|
|
547
|
-
else:
|
|
548
|
-
data.rar_iter_from_last_sampling = 0
|
|
243
|
+
data = eqx.tree_at(lambda m: m.rar_iter_from_last_sampling, data, 0)
|
|
549
244
|
|
|
550
245
|
# NOTE must return data to be correctly updated because we cannot
|
|
551
246
|
# have side effects in this function that will be jitted
|