jinns 0.8.8__py3-none-any.whl → 0.8.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jinns/data/_DataGenerators.py +93 -90
- jinns/loss/_LossODE.py +1 -0
- jinns/solver/_rar.py +203 -146
- jinns/solver/_seq2seq.py +2 -2
- {jinns-0.8.8.dist-info → jinns-0.8.10.dist-info}/METADATA +1 -1
- {jinns-0.8.8.dist-info → jinns-0.8.10.dist-info}/RECORD +9 -9
- {jinns-0.8.8.dist-info → jinns-0.8.10.dist-info}/LICENSE +0 -0
- {jinns-0.8.8.dist-info → jinns-0.8.10.dist-info}/WHEEL +0 -0
- {jinns-0.8.8.dist-info → jinns-0.8.10.dist-info}/top_level.txt +0 -0
jinns/data/_DataGenerators.py
CHANGED
|
@@ -98,6 +98,33 @@ def _reset_or_increment(bend, n_eff, operands):
|
|
|
98
98
|
)
|
|
99
99
|
|
|
100
100
|
|
|
101
|
+
def _check_and_set_rar_parameters(rar_parameters, n, n_start):
|
|
102
|
+
if rar_parameters is not None and n_start is None:
|
|
103
|
+
raise ValueError(
|
|
104
|
+
f"n_start or/and nt_start must be provided in the context of RAR sampling scheme, {n_start} was provided"
|
|
105
|
+
)
|
|
106
|
+
if rar_parameters is not None:
|
|
107
|
+
# Default p is None. However, in the RAR sampling scheme we use 0
|
|
108
|
+
# probability to specify non-used collocation points (i.e. points
|
|
109
|
+
# above n_start). Thus, p is a vector of probability of shape (n, 1).
|
|
110
|
+
p = jnp.zeros((n,))
|
|
111
|
+
p = p.at[:n_start].set(1 / n_start)
|
|
112
|
+
# set internal counter for the number of gradient steps since the
|
|
113
|
+
# last new collocation points have been added
|
|
114
|
+
# It is not 0 to ensure the first iteration of RAR happens just
|
|
115
|
+
# after start_iter. See the _proceed_to_rar() function in _rar.py
|
|
116
|
+
rar_iter_from_last_sampling = rar_parameters["update_every"] - 1
|
|
117
|
+
# set iternal counter for the number of times collocation points
|
|
118
|
+
# have been added
|
|
119
|
+
rar_iter_nb = 0
|
|
120
|
+
else:
|
|
121
|
+
p = None
|
|
122
|
+
rar_iter_from_last_sampling = None
|
|
123
|
+
rar_iter_nb = None
|
|
124
|
+
|
|
125
|
+
return n_start, p, rar_iter_from_last_sampling, rar_iter_nb
|
|
126
|
+
|
|
127
|
+
|
|
101
128
|
#####################################################
|
|
102
129
|
# DataGenerator for ODE : only returns time_batches
|
|
103
130
|
#####################################################
|
|
@@ -150,10 +177,10 @@ class DataGeneratorODE:
|
|
|
150
177
|
Default to None: do not use Residual Adaptative Resampling.
|
|
151
178
|
Otherwise a dictionary with keys. `start_iter`: the iteration at
|
|
152
179
|
which we start the RAR sampling scheme (we first have a burn in
|
|
153
|
-
period). `
|
|
180
|
+
period). `update_every`: the number of gradient steps taken between
|
|
154
181
|
each appending of collocation points in the RAR algo.
|
|
155
|
-
`
|
|
156
|
-
collocation points. `
|
|
182
|
+
`sample_size_times`: the size of the sample from which we will select new
|
|
183
|
+
collocation points. `selected_sample_size_times`: the number of selected
|
|
157
184
|
points from the sample to be added to the current collocation
|
|
158
185
|
points
|
|
159
186
|
"DeepXDE: A deep learning library for solving differential
|
|
@@ -179,29 +206,13 @@ class DataGeneratorODE:
|
|
|
179
206
|
self.method = method
|
|
180
207
|
self.rar_parameters = rar_parameters
|
|
181
208
|
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
self.
|
|
188
|
-
|
|
189
|
-
# probability to specify non-used collocation points (i.e. points
|
|
190
|
-
# above nt_start). Thus, p is a vector of probability of shape (nt, 1).
|
|
191
|
-
self.p = jnp.zeros((self.nt,))
|
|
192
|
-
self.p = self.p.at[: self.nt_start].set(1 / nt_start)
|
|
193
|
-
# set internal counter for the number of gradient steps since the
|
|
194
|
-
# last new collocation points have been added
|
|
195
|
-
self.rar_iter_from_last_sampling = 0
|
|
196
|
-
# set iternal counter for the number of times collocation points
|
|
197
|
-
# have been added
|
|
198
|
-
self.rar_iter_nb = 0
|
|
199
|
-
|
|
200
|
-
if rar_parameters is None or nt_start is None:
|
|
201
|
-
self.nt_start = self.nt
|
|
202
|
-
self.p = None
|
|
203
|
-
self.rar_iter_from_last_sampling = None
|
|
204
|
-
self.rar_iter_nb = None
|
|
209
|
+
# Set-up for RAR (if used)
|
|
210
|
+
(
|
|
211
|
+
self.nt_start,
|
|
212
|
+
self.p_times,
|
|
213
|
+
self.rar_iter_from_last_sampling,
|
|
214
|
+
self.rar_iter_nb,
|
|
215
|
+
) = _check_and_set_rar_parameters(rar_parameters, n=nt, n_start=nt_start)
|
|
205
216
|
|
|
206
217
|
if not self.data_exists:
|
|
207
218
|
# Useful when using a lax.scan with pytree
|
|
@@ -238,7 +249,7 @@ class DataGeneratorODE:
|
|
|
238
249
|
self.times,
|
|
239
250
|
self.curr_time_idx,
|
|
240
251
|
self.temporal_batch_size,
|
|
241
|
-
self.
|
|
252
|
+
self.p_times,
|
|
242
253
|
)
|
|
243
254
|
|
|
244
255
|
def temporal_batch(self):
|
|
@@ -253,7 +264,7 @@ class DataGeneratorODE:
|
|
|
253
264
|
if self.rar_parameters is not None:
|
|
254
265
|
nt_eff = (
|
|
255
266
|
self.nt_start
|
|
256
|
-
+ self.rar_iter_nb * self.rar_parameters["
|
|
267
|
+
+ self.rar_iter_nb * self.rar_parameters["selected_sample_size_omega"]
|
|
257
268
|
)
|
|
258
269
|
else:
|
|
259
270
|
nt_eff = self.nt
|
|
@@ -283,14 +294,19 @@ class DataGeneratorODE:
|
|
|
283
294
|
self.curr_time_idx,
|
|
284
295
|
self.tmin,
|
|
285
296
|
self.tmax,
|
|
286
|
-
self.
|
|
287
|
-
self.p,
|
|
297
|
+
self.p_times,
|
|
288
298
|
self.rar_iter_from_last_sampling,
|
|
289
299
|
self.rar_iter_nb,
|
|
290
300
|
) # arrays / dynamic values
|
|
291
301
|
aux_data = {
|
|
292
302
|
k: vars(self)[k]
|
|
293
|
-
for k in [
|
|
303
|
+
for k in [
|
|
304
|
+
"temporal_batch_size",
|
|
305
|
+
"method",
|
|
306
|
+
"nt",
|
|
307
|
+
"rar_parameters",
|
|
308
|
+
"nt_start",
|
|
309
|
+
]
|
|
294
310
|
} # static values
|
|
295
311
|
return (children, aux_data)
|
|
296
312
|
|
|
@@ -308,8 +324,7 @@ class DataGeneratorODE:
|
|
|
308
324
|
curr_time_idx,
|
|
309
325
|
tmin,
|
|
310
326
|
tmax,
|
|
311
|
-
|
|
312
|
-
p,
|
|
327
|
+
p_times,
|
|
313
328
|
rar_iter_from_last_sampling,
|
|
314
329
|
rar_iter_nb,
|
|
315
330
|
) = children
|
|
@@ -318,12 +333,11 @@ class DataGeneratorODE:
|
|
|
318
333
|
data_exists=True,
|
|
319
334
|
tmin=tmin,
|
|
320
335
|
tmax=tmax,
|
|
321
|
-
rar_parameters=rar_parameters,
|
|
322
336
|
**aux_data,
|
|
323
337
|
)
|
|
324
338
|
obj.times = times
|
|
325
339
|
obj.curr_time_idx = curr_time_idx
|
|
326
|
-
obj.
|
|
340
|
+
obj.p_times = p_times
|
|
327
341
|
obj.rar_iter_from_last_sampling = rar_iter_from_last_sampling
|
|
328
342
|
obj.rar_iter_nb = rar_iter_nb
|
|
329
343
|
return obj
|
|
@@ -412,10 +426,10 @@ class CubicMeshPDEStatio(DataGeneratorPDEAbstract):
|
|
|
412
426
|
Default to None: do not use Residual Adaptative Resampling.
|
|
413
427
|
Otherwise a dictionary with keys. `start_iter`: the iteration at
|
|
414
428
|
which we start the RAR sampling scheme (we first have a burn in
|
|
415
|
-
period). `
|
|
429
|
+
period). `update_every`: the number of gradient steps taken between
|
|
416
430
|
each appending of collocation points in the RAR algo.
|
|
417
|
-
`
|
|
418
|
-
collocation points. `
|
|
431
|
+
`sample_size_omega`: the size of the sample from which we will select new
|
|
432
|
+
collocation points. `selected_sample_size_omega`: the number of selected
|
|
419
433
|
points from the sample to be added to the current collocation
|
|
420
434
|
points
|
|
421
435
|
"DeepXDE: A deep learning library for solving differential
|
|
@@ -442,30 +456,12 @@ class CubicMeshPDEStatio(DataGeneratorPDEAbstract):
|
|
|
442
456
|
assert dim == len(max_pts) and isinstance(max_pts, tuple)
|
|
443
457
|
self.n = n
|
|
444
458
|
self.rar_parameters = rar_parameters
|
|
445
|
-
|
|
446
|
-
|
|
447
|
-
|
|
448
|
-
|
|
449
|
-
|
|
450
|
-
|
|
451
|
-
self.n_start = n_start
|
|
452
|
-
# Default p is None. However, in the RAR sampling scheme we use 0
|
|
453
|
-
# probability to specify non-used collocation points (i.e. points
|
|
454
|
-
# above n_start). Thus, p is a vector of probability of shape (n, 1).
|
|
455
|
-
self.p = jnp.zeros((self.n,))
|
|
456
|
-
self.p = self.p.at[: self.n_start].set(1 / n_start)
|
|
457
|
-
# set internal counter for the number of gradient steps since the
|
|
458
|
-
# last new collocation points have been added
|
|
459
|
-
self.rar_iter_from_last_sampling = 0
|
|
460
|
-
# set iternal counter for the number of times collocation points
|
|
461
|
-
# have been added
|
|
462
|
-
self.rar_iter_nb = 0
|
|
463
|
-
|
|
464
|
-
if rar_parameters is None or n_start is None:
|
|
465
|
-
self.n_start = self.n
|
|
466
|
-
self.p = None
|
|
467
|
-
self.rar_iter_from_last_sampling = None
|
|
468
|
-
self.rar_iter_nb = None
|
|
459
|
+
(
|
|
460
|
+
self.n_start,
|
|
461
|
+
self.p_omega,
|
|
462
|
+
self.rar_iter_from_last_sampling,
|
|
463
|
+
self.rar_iter_nb,
|
|
464
|
+
) = _check_and_set_rar_parameters(rar_parameters, n=n, n_start=n_start)
|
|
469
465
|
|
|
470
466
|
self.p_border = None # no RAR sampling for border for now
|
|
471
467
|
|
|
@@ -643,7 +639,7 @@ class CubicMeshPDEStatio(DataGeneratorPDEAbstract):
|
|
|
643
639
|
self.omega,
|
|
644
640
|
self.curr_omega_idx,
|
|
645
641
|
self.omega_batch_size,
|
|
646
|
-
self.
|
|
642
|
+
self.p_omega,
|
|
647
643
|
)
|
|
648
644
|
|
|
649
645
|
def inside_batch(self):
|
|
@@ -656,7 +652,7 @@ class CubicMeshPDEStatio(DataGeneratorPDEAbstract):
|
|
|
656
652
|
if self.rar_parameters is not None:
|
|
657
653
|
n_eff = (
|
|
658
654
|
self.n_start
|
|
659
|
-
+ self.rar_iter_nb * self.rar_parameters["
|
|
655
|
+
+ self.rar_iter_nb * self.rar_parameters["selected_sample_size_omega"]
|
|
660
656
|
)
|
|
661
657
|
else:
|
|
662
658
|
n_eff = self.n
|
|
@@ -745,8 +741,7 @@ class CubicMeshPDEStatio(DataGeneratorPDEAbstract):
|
|
|
745
741
|
self.curr_omega_border_idx,
|
|
746
742
|
self.min_pts,
|
|
747
743
|
self.max_pts,
|
|
748
|
-
self.
|
|
749
|
-
self.p,
|
|
744
|
+
self.p_omega,
|
|
750
745
|
self.rar_iter_from_last_sampling,
|
|
751
746
|
self.rar_iter_nb,
|
|
752
747
|
)
|
|
@@ -759,6 +754,7 @@ class CubicMeshPDEStatio(DataGeneratorPDEAbstract):
|
|
|
759
754
|
"omega_border_batch_size",
|
|
760
755
|
"method",
|
|
761
756
|
"dim",
|
|
757
|
+
"rar_parameters",
|
|
762
758
|
"n_start",
|
|
763
759
|
]
|
|
764
760
|
}
|
|
@@ -780,8 +776,7 @@ class CubicMeshPDEStatio(DataGeneratorPDEAbstract):
|
|
|
780
776
|
curr_omega_border_idx,
|
|
781
777
|
min_pts,
|
|
782
778
|
max_pts,
|
|
783
|
-
|
|
784
|
-
p,
|
|
779
|
+
p_omega,
|
|
785
780
|
rar_iter_from_last_sampling,
|
|
786
781
|
rar_iter_nb,
|
|
787
782
|
) = children
|
|
@@ -792,14 +787,13 @@ class CubicMeshPDEStatio(DataGeneratorPDEAbstract):
|
|
|
792
787
|
data_exists=True,
|
|
793
788
|
min_pts=min_pts,
|
|
794
789
|
max_pts=max_pts,
|
|
795
|
-
rar_parameters=rar_parameters,
|
|
796
790
|
**aux_data,
|
|
797
791
|
)
|
|
798
792
|
obj.omega = omega
|
|
799
793
|
obj.omega_border = omega_border
|
|
800
794
|
obj.curr_omega_idx = curr_omega_idx
|
|
801
795
|
obj.curr_omega_border_idx = curr_omega_border_idx
|
|
802
|
-
obj.
|
|
796
|
+
obj.p_omega = p_omega
|
|
803
797
|
obj.rar_iter_from_last_sampling = rar_iter_from_last_sampling
|
|
804
798
|
obj.rar_iter_nb = rar_iter_nb
|
|
805
799
|
return obj
|
|
@@ -834,6 +828,7 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
|
|
|
834
828
|
method="grid",
|
|
835
829
|
rar_parameters=None,
|
|
836
830
|
n_start=None,
|
|
831
|
+
nt_start=None,
|
|
837
832
|
data_exists=False,
|
|
838
833
|
):
|
|
839
834
|
r"""
|
|
@@ -887,27 +882,23 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
|
|
|
887
882
|
Default to None: do not use Residual Adaptative Resampling.
|
|
888
883
|
Otherwise a dictionary with keys. `start_iter`: the iteration at
|
|
889
884
|
which we start the RAR sampling scheme (we first have a burn in
|
|
890
|
-
period). `
|
|
885
|
+
period). `update_every`: the number of gradient steps taken between
|
|
891
886
|
each appending of collocation points in the RAR algo.
|
|
892
|
-
`
|
|
893
|
-
collocation points. `
|
|
887
|
+
`sample_size_omega`: the size of the sample from which we will select new
|
|
888
|
+
collocation points. `selected_sample_size_omega`: the number of selected
|
|
894
889
|
points from the sample to be added to the current collocation
|
|
895
890
|
points.
|
|
896
|
-
__Note:__ that if RAR sampling is chosen it will currently affect both
|
|
897
|
-
self.times and self.omega with the same hyperparameters
|
|
898
|
-
(rar_parameters and n_start)
|
|
899
|
-
"DeepXDE: A deep learning library for solving differential
|
|
900
|
-
equations", L. Lu, SIAM Review, 2021
|
|
901
891
|
n_start
|
|
902
892
|
Defaults to None. The effective size of n used at start time.
|
|
903
893
|
This value must be
|
|
904
894
|
provided when rar_parameters is not None. Otherwise we set internally
|
|
905
895
|
n_start = n and this is hidden from the user.
|
|
906
896
|
In RAR, n_start
|
|
907
|
-
then corresponds to the initial number of points we train the PINN.
|
|
908
|
-
|
|
909
|
-
|
|
910
|
-
|
|
897
|
+
then corresponds to the initial number of omega points we train the PINN.
|
|
898
|
+
nt_start
|
|
899
|
+
Defaults to None. A RAR hyper-parameter. Same as ``n_start`` but
|
|
900
|
+
for times collocation point. See also ``DataGeneratorODE``
|
|
901
|
+
documentation.
|
|
911
902
|
data_exists
|
|
912
903
|
Must be left to `False` when created by the user. Avoids the
|
|
913
904
|
regeneration of :math:`\Omega`, :math:`\partial\Omega` and
|
|
@@ -931,6 +922,15 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
|
|
|
931
922
|
self.tmin = tmin
|
|
932
923
|
self.tmax = tmax
|
|
933
924
|
self.nt = nt
|
|
925
|
+
|
|
926
|
+
# Set-up for timewise RAR (some quantity are already set-up by super())
|
|
927
|
+
(
|
|
928
|
+
self.nt_start,
|
|
929
|
+
self.p_times,
|
|
930
|
+
_,
|
|
931
|
+
_,
|
|
932
|
+
) = _check_and_set_rar_parameters(rar_parameters, n=nt, n_start=nt_start)
|
|
933
|
+
|
|
934
934
|
if not self.data_exists:
|
|
935
935
|
# Useful when using a lax.scan with pytree
|
|
936
936
|
# Optionally can tell JAX not to re-generate data
|
|
@@ -950,7 +950,7 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
|
|
|
950
950
|
self.times,
|
|
951
951
|
self.curr_time_idx,
|
|
952
952
|
self.temporal_batch_size,
|
|
953
|
-
self.
|
|
953
|
+
self.p_times,
|
|
954
954
|
)
|
|
955
955
|
|
|
956
956
|
def generate_data_nonstatio(self):
|
|
@@ -980,7 +980,7 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
|
|
|
980
980
|
if self.rar_parameters is not None:
|
|
981
981
|
nt_eff = (
|
|
982
982
|
self.n_start
|
|
983
|
-
+ self.rar_iter_nb * self.rar_parameters["
|
|
983
|
+
+ self.rar_iter_nb * self.rar_parameters["selected_sample_size_times"]
|
|
984
984
|
)
|
|
985
985
|
else:
|
|
986
986
|
nt_eff = self.nt
|
|
@@ -1022,8 +1022,8 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
|
|
|
1022
1022
|
self.max_pts,
|
|
1023
1023
|
self.tmin,
|
|
1024
1024
|
self.tmax,
|
|
1025
|
-
self.
|
|
1026
|
-
self.
|
|
1025
|
+
self.p_times,
|
|
1026
|
+
self.p_omega,
|
|
1027
1027
|
self.rar_iter_from_last_sampling,
|
|
1028
1028
|
self.rar_iter_nb,
|
|
1029
1029
|
)
|
|
@@ -1038,7 +1038,9 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
|
|
|
1038
1038
|
"temporal_batch_size",
|
|
1039
1039
|
"method",
|
|
1040
1040
|
"dim",
|
|
1041
|
+
"rar_parameters",
|
|
1041
1042
|
"n_start",
|
|
1043
|
+
"nt_start",
|
|
1042
1044
|
]
|
|
1043
1045
|
}
|
|
1044
1046
|
return (children, aux_data)
|
|
@@ -1063,8 +1065,8 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
|
|
|
1063
1065
|
max_pts,
|
|
1064
1066
|
tmin,
|
|
1065
1067
|
tmax,
|
|
1066
|
-
|
|
1067
|
-
|
|
1068
|
+
p_times,
|
|
1069
|
+
p_omega,
|
|
1068
1070
|
rar_iter_from_last_sampling,
|
|
1069
1071
|
rar_iter_nb,
|
|
1070
1072
|
) = children
|
|
@@ -1075,7 +1077,6 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
|
|
|
1075
1077
|
max_pts=max_pts,
|
|
1076
1078
|
tmin=tmin,
|
|
1077
1079
|
tmax=tmax,
|
|
1078
|
-
rar_parameters=rar_parameters,
|
|
1079
1080
|
**aux_data,
|
|
1080
1081
|
)
|
|
1081
1082
|
obj.omega = omega
|
|
@@ -1084,9 +1085,11 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
|
|
|
1084
1085
|
obj.curr_omega_idx = curr_omega_idx
|
|
1085
1086
|
obj.curr_omega_border_idx = curr_omega_border_idx
|
|
1086
1087
|
obj.curr_time_idx = curr_time_idx
|
|
1087
|
-
obj.
|
|
1088
|
+
obj.p_times = p_times
|
|
1089
|
+
obj.p_omega = p_omega
|
|
1088
1090
|
obj.rar_iter_from_last_sampling = rar_iter_from_last_sampling
|
|
1089
1091
|
obj.rar_iter_nb = rar_iter_nb
|
|
1092
|
+
|
|
1090
1093
|
return obj
|
|
1091
1094
|
|
|
1092
1095
|
|
jinns/loss/_LossODE.py
CHANGED
jinns/solver/_rar.py
CHANGED
|
@@ -11,6 +11,41 @@ from jinns.loss._LossODE import LossODE, SystemLossODE
|
|
|
11
11
|
from jinns.loss._DynamicLossAbstract import PDEStatio
|
|
12
12
|
|
|
13
13
|
from functools import partial
|
|
14
|
+
from jinns.utils._hyperpinn import HYPERPINN
|
|
15
|
+
from jinns.utils._spinn import SPINN
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def _proceed_to_rar(data, i):
|
|
19
|
+
"""Utilility function with various check to ensure we can proceed with the rar_step.
|
|
20
|
+
Return True if yes, and False otherwise"""
|
|
21
|
+
|
|
22
|
+
# Overall checks (universal for any data generator)
|
|
23
|
+
check_list = [
|
|
24
|
+
# check if burn-in period has ended
|
|
25
|
+
data.rar_parameters["start_iter"] <= i,
|
|
26
|
+
# check if enough iterations since last points added
|
|
27
|
+
(data.rar_parameters["update_every"] - 1) == data.rar_iter_from_last_sampling,
|
|
28
|
+
]
|
|
29
|
+
|
|
30
|
+
# Memory allocation checks (depends on the type of DataGenerator)
|
|
31
|
+
# check if we still have room to append new collocation points in the
|
|
32
|
+
# allocated jnp.array (can concern `data.p_times` or `p_omega`)
|
|
33
|
+
if isinstance(data, DataGeneratorODE) or isinstance(data, CubicMeshPDENonStatio):
|
|
34
|
+
check_list.append(
|
|
35
|
+
data.rar_parameters["selected_sample_size_times"]
|
|
36
|
+
<= jnp.count_nonzero(data.p_times == 0),
|
|
37
|
+
)
|
|
38
|
+
|
|
39
|
+
if isinstance(data, CubicMeshPDEStatio) or isinstance(data, CubicMeshPDENonStatio):
|
|
40
|
+
# for now the above check are redundants but there may be a time when
|
|
41
|
+
# we drop inheritence
|
|
42
|
+
check_list.append(
|
|
43
|
+
data.rar_parameters["selected_sample_size_omega"]
|
|
44
|
+
<= jnp.count_nonzero(data.p_omega == 0),
|
|
45
|
+
)
|
|
46
|
+
|
|
47
|
+
proceed = jnp.all(jnp.array(check_list))
|
|
48
|
+
return proceed
|
|
14
49
|
|
|
15
50
|
|
|
16
51
|
@partial(jax.jit, static_argnames=["_rar_step_true", "_rar_step_false"])
|
|
@@ -22,21 +57,7 @@ def trigger_rar(i, loss, params, data, _rar_step_true, _rar_step_false):
|
|
|
22
57
|
else:
|
|
23
58
|
# update `data` according to rar scheme.
|
|
24
59
|
data = jax.lax.cond(
|
|
25
|
-
|
|
26
|
-
jnp.array(
|
|
27
|
-
[
|
|
28
|
-
# check if enough it since last points added
|
|
29
|
-
data.rar_parameters["update_rate"]
|
|
30
|
-
== data.rar_iter_from_last_sampling,
|
|
31
|
-
# check if burn in period has ended
|
|
32
|
-
data.rar_parameters["start_iter"] < i,
|
|
33
|
-
# check if we still have room to append new
|
|
34
|
-
# collocation points in the allocated jnp array
|
|
35
|
-
data.rar_parameters["selected_sample_size"]
|
|
36
|
-
<= jnp.count_nonzero(data.p == 0),
|
|
37
|
-
]
|
|
38
|
-
)
|
|
39
|
-
),
|
|
60
|
+
_proceed_to_rar(data, i),
|
|
40
61
|
_rar_step_true,
|
|
41
62
|
_rar_step_false,
|
|
42
63
|
(loss, params, data, i),
|
|
@@ -49,13 +70,37 @@ def init_rar(data):
|
|
|
49
70
|
Separated from the main rar, because the initialization to get _true and
|
|
50
71
|
_false cannot be jit-ted.
|
|
51
72
|
"""
|
|
73
|
+
# NOTE if a user misspell some entry of ``rar_parameters`` the error
|
|
74
|
+
# risks to be a bit obscure but it should be ok.
|
|
52
75
|
if data.rar_parameters is None:
|
|
53
76
|
_rar_step_true, _rar_step_false = None, None
|
|
54
77
|
else:
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
78
|
+
if isinstance(data, DataGeneratorODE):
|
|
79
|
+
# In this case we only need rar parameters related to `times`
|
|
80
|
+
_rar_step_true, _rar_step_false = _rar_step_init(
|
|
81
|
+
data.rar_parameters["sample_size_times"],
|
|
82
|
+
data.rar_parameters["selected_sample_size_times"],
|
|
83
|
+
)
|
|
84
|
+
elif isinstance(data, CubicMeshPDENonStatio):
|
|
85
|
+
# In this case we need rar parameters related to both `times`
|
|
86
|
+
# and`omega`
|
|
87
|
+
_rar_step_true, _rar_step_false = _rar_step_init(
|
|
88
|
+
(
|
|
89
|
+
data.rar_parameters["sample_size_times"],
|
|
90
|
+
data.rar_parameters["sample_size_omega"],
|
|
91
|
+
),
|
|
92
|
+
(
|
|
93
|
+
data.rar_parameters["selected_sample_size_times"],
|
|
94
|
+
data.rar_parameters["selected_sample_size_omega"],
|
|
95
|
+
),
|
|
96
|
+
)
|
|
97
|
+
elif isinstance(data, CubicMeshPDEStatio):
|
|
98
|
+
# In this case we only need rar parameters related to `omega`
|
|
99
|
+
_rar_step_true, _rar_step_false = _rar_step_init(
|
|
100
|
+
data.rar_parameters["sample_size_omega"],
|
|
101
|
+
data.rar_parameters["selected_sample_size_omega"],
|
|
102
|
+
)
|
|
103
|
+
|
|
59
104
|
data.rar_parameters["iter_from_last_sampling"] = 0
|
|
60
105
|
|
|
61
106
|
return data, _rar_step_true, _rar_step_false
|
|
@@ -64,7 +109,7 @@ def init_rar(data):
|
|
|
64
109
|
def _rar_step_init(sample_size, selected_sample_size):
|
|
65
110
|
"""
|
|
66
111
|
This is a wrapper because the sampling size and
|
|
67
|
-
selected_sample_size, must be treated static
|
|
112
|
+
selected_sample_size, must be treated as static
|
|
68
113
|
in order to slice. So they must be set before jitting and not with the jitted
|
|
69
114
|
dictionary values rar["test_points_nb"] and rar["added_points_nb"]
|
|
70
115
|
|
|
@@ -72,16 +117,10 @@ def _rar_step_init(sample_size, selected_sample_size):
|
|
|
72
117
|
"""
|
|
73
118
|
|
|
74
119
|
def rar_step_true(operands):
|
|
75
|
-
"""
|
|
76
|
-
Note: in all generality, we would need a stop gradient operator around
|
|
77
|
-
these dynamic_loss evaluations that follow which produce weights for
|
|
78
|
-
sampling. However, they appear through a argsort and sampling
|
|
79
|
-
operations which definitly kills gradient flows
|
|
80
|
-
"""
|
|
81
120
|
loss, params, data, i = operands
|
|
82
121
|
|
|
83
122
|
if isinstance(data, DataGeneratorODE):
|
|
84
|
-
|
|
123
|
+
new_omega_samples = data.sample_in_time_domain(sample_size)
|
|
85
124
|
|
|
86
125
|
# We can have different types of Loss
|
|
87
126
|
if isinstance(loss, LossODE):
|
|
@@ -90,7 +129,7 @@ def _rar_step_init(sample_size, selected_sample_size):
|
|
|
90
129
|
(0),
|
|
91
130
|
0,
|
|
92
131
|
)
|
|
93
|
-
dyn_on_s = v_dyn_loss(
|
|
132
|
+
dyn_on_s = v_dyn_loss(new_omega_samples)
|
|
94
133
|
if dyn_on_s.ndim > 1:
|
|
95
134
|
mse_on_s = (jnp.linalg.norm(dyn_on_s, axis=-1) ** 2).flatten()
|
|
96
135
|
else:
|
|
@@ -106,7 +145,7 @@ def _rar_step_init(sample_size, selected_sample_size):
|
|
|
106
145
|
(0),
|
|
107
146
|
0,
|
|
108
147
|
)
|
|
109
|
-
dyn_on_s = v_dyn_loss(
|
|
148
|
+
dyn_on_s = v_dyn_loss(new_omega_samples)
|
|
110
149
|
if dyn_on_s.ndim > 1:
|
|
111
150
|
mse_on_s += (jnp.linalg.norm(dyn_on_s, axis=-1) ** 2).flatten()
|
|
112
151
|
else:
|
|
@@ -118,9 +157,7 @@ def _rar_step_init(sample_size, selected_sample_size):
|
|
|
118
157
|
(mse_on_s.shape[0] - selected_sample_size,),
|
|
119
158
|
(selected_sample_size,),
|
|
120
159
|
)
|
|
121
|
-
higher_residual_points =
|
|
122
|
-
|
|
123
|
-
data.rar_iter_from_last_sampling = 0
|
|
160
|
+
higher_residual_points = new_omega_samples[higher_residual_idx]
|
|
124
161
|
|
|
125
162
|
## add the new points in times
|
|
126
163
|
# start indices of update can be dynamic but the the shape (length)
|
|
@@ -135,7 +172,7 @@ def _rar_step_init(sample_size, selected_sample_size):
|
|
|
135
172
|
## points are non-zero
|
|
136
173
|
new_proba = 1 / (data.nt_start + data.rar_iter_nb * selected_sample_size)
|
|
137
174
|
# the next work because nt_start is static
|
|
138
|
-
data.
|
|
175
|
+
data.p_times = data.p_times.at[: data.nt_start].set(new_proba)
|
|
139
176
|
|
|
140
177
|
# the next requires a fori_loop because the range is dynamic
|
|
141
178
|
def update_slices(i, p):
|
|
@@ -147,16 +184,14 @@ def _rar_step_init(sample_size, selected_sample_size):
|
|
|
147
184
|
|
|
148
185
|
data.rar_iter_nb += 1
|
|
149
186
|
|
|
150
|
-
data.
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
# have side effects in this function that will be jitted
|
|
154
|
-
return data
|
|
187
|
+
data.p_times = jax.lax.fori_loop(
|
|
188
|
+
0, data.rar_iter_nb, update_slices, data.p_times
|
|
189
|
+
)
|
|
155
190
|
|
|
156
191
|
elif isinstance(data, CubicMeshPDEStatio) and not isinstance(
|
|
157
192
|
data, CubicMeshPDENonStatio
|
|
158
193
|
):
|
|
159
|
-
|
|
194
|
+
new_omega_samples = data.sample_in_omega_domain(sample_size)
|
|
160
195
|
|
|
161
196
|
# We can have different types of Loss
|
|
162
197
|
if isinstance(loss, LossPDEStatio):
|
|
@@ -169,7 +204,7 @@ def _rar_step_init(sample_size, selected_sample_size):
|
|
|
169
204
|
(0),
|
|
170
205
|
0,
|
|
171
206
|
)
|
|
172
|
-
dyn_on_s = v_dyn_loss(
|
|
207
|
+
dyn_on_s = v_dyn_loss(new_omega_samples)
|
|
173
208
|
if dyn_on_s.ndim > 1:
|
|
174
209
|
mse_on_s = (jnp.linalg.norm(dyn_on_s, axis=-1) ** 2).flatten()
|
|
175
210
|
else:
|
|
@@ -185,7 +220,7 @@ def _rar_step_init(sample_size, selected_sample_size):
|
|
|
185
220
|
0,
|
|
186
221
|
0,
|
|
187
222
|
)
|
|
188
|
-
dyn_on_s = v_dyn_loss(
|
|
223
|
+
dyn_on_s = v_dyn_loss(new_omega_samples)
|
|
189
224
|
if dyn_on_s.ndim > 1:
|
|
190
225
|
mse_on_s += (jnp.linalg.norm(dyn_on_s, axis=-1) ** 2).flatten()
|
|
191
226
|
else:
|
|
@@ -197,12 +232,10 @@ def _rar_step_init(sample_size, selected_sample_size):
|
|
|
197
232
|
(mse_on_s.shape[0] - selected_sample_size,),
|
|
198
233
|
(selected_sample_size,),
|
|
199
234
|
)
|
|
200
|
-
higher_residual_points =
|
|
235
|
+
higher_residual_points = new_omega_samples[higher_residual_idx]
|
|
201
236
|
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
## add the new points in times
|
|
205
|
-
# start indices of update can be dynamic but the the shape (length)
|
|
237
|
+
## add the new points in omega
|
|
238
|
+
# start indices of update can be dynamic but not the shape (length)
|
|
206
239
|
# of the slice
|
|
207
240
|
data.omega = jax.lax.dynamic_update_slice(
|
|
208
241
|
data.omega,
|
|
@@ -214,7 +247,7 @@ def _rar_step_init(sample_size, selected_sample_size):
|
|
|
214
247
|
## points are non-zero
|
|
215
248
|
new_proba = 1 / (data.n_start + data.rar_iter_nb * selected_sample_size)
|
|
216
249
|
# the next work because n_start is static
|
|
217
|
-
data.
|
|
250
|
+
data.p_omega = data.p_omega.at[: data.n_start].set(new_proba)
|
|
218
251
|
|
|
219
252
|
# the next requires a fori_loop because the range is dynamic
|
|
220
253
|
def update_slices(i, p):
|
|
@@ -226,145 +259,169 @@ def _rar_step_init(sample_size, selected_sample_size):
|
|
|
226
259
|
|
|
227
260
|
data.rar_iter_nb += 1
|
|
228
261
|
|
|
229
|
-
data.
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
# have side effects in this function that will be jitted
|
|
233
|
-
return data
|
|
262
|
+
data.p_omega = jax.lax.fori_loop(
|
|
263
|
+
0, data.rar_iter_nb, update_slices, data.p_omega
|
|
264
|
+
)
|
|
234
265
|
|
|
235
266
|
elif isinstance(data, CubicMeshPDENonStatio):
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
)
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
(
|
|
252
|
-
|
|
253
|
-
)
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
mse_on_s = dyn_on_s**2
|
|
259
|
-
elif isinstance(loss, LossPDENonStatio):
|
|
267
|
+
# NOTE in this case sample_size and selected_sample_size
|
|
268
|
+
# are tuples (times, omega) => we unpack them for clarity
|
|
269
|
+
selected_sample_size_times, selected_sample_size_omega = (
|
|
270
|
+
selected_sample_size
|
|
271
|
+
)
|
|
272
|
+
sample_size_times, sample_size_omega = sample_size
|
|
273
|
+
|
|
274
|
+
new_times_samples = data.sample_in_time_domain(sample_size_times)
|
|
275
|
+
new_omega_samples = data.sample_in_omega_domain(sample_size_omega)
|
|
276
|
+
|
|
277
|
+
if isinstance(loss.u, HYPERPINN) or isinstance(loss.u, SPINN):
|
|
278
|
+
raise NotImplementedError("RAR not implemented for hyperPINN and SPINN")
|
|
279
|
+
else:
|
|
280
|
+
# do cartesian product on new points
|
|
281
|
+
tile_omega = jnp.tile(
|
|
282
|
+
new_omega_samples, reps=(sample_size_times, 1)
|
|
283
|
+
) # it is tiled
|
|
284
|
+
repeat_times = jnp.repeat(new_times_samples, sample_size_omega, axis=0)[
|
|
285
|
+
..., None
|
|
286
|
+
] # it is repeated + add an axis
|
|
287
|
+
|
|
288
|
+
if isinstance(loss, LossPDENonStatio):
|
|
260
289
|
v_dyn_loss = vmap(
|
|
261
290
|
lambda t, x: loss.dynamic_loss.evaluate(t, x, loss.u, params),
|
|
262
291
|
(0, 0),
|
|
263
292
|
0,
|
|
264
293
|
)
|
|
265
|
-
dyn_on_s = v_dyn_loss(
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
mse_on_s = dyn_on_s**2
|
|
294
|
+
dyn_on_s = v_dyn_loss(repeat_times, tile_omega).reshape(
|
|
295
|
+
(sample_size_times, sample_size_omega)
|
|
296
|
+
)
|
|
297
|
+
mse_on_s = dyn_on_s**2
|
|
270
298
|
elif isinstance(loss, SystemLossPDE):
|
|
271
|
-
|
|
299
|
+
dyn_on_s = jnp.zeros((sample_size_times, sample_size_omega))
|
|
272
300
|
for i in loss.dynamic_loss_dict.keys():
|
|
273
|
-
|
|
274
|
-
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
|
|
278
|
-
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
mse_on_s += (
|
|
284
|
-
jnp.linalg.norm(dyn_on_s, axis=-1) ** 2
|
|
285
|
-
).flatten()
|
|
286
|
-
else:
|
|
287
|
-
mse_on_s += dyn_on_s**2
|
|
288
|
-
else:
|
|
289
|
-
v_dyn_loss = vmap(
|
|
290
|
-
lambda t, x: loss.dynamic_loss_dict[i].evaluate(
|
|
291
|
-
t, x, loss.u_dict, params
|
|
292
|
-
),
|
|
293
|
-
(0, 0),
|
|
294
|
-
0,
|
|
295
|
-
)
|
|
296
|
-
dyn_on_s = v_dyn_loss(st[..., None], sx)
|
|
297
|
-
if dyn_on_s.ndim > 1:
|
|
298
|
-
mse_on_s += (
|
|
299
|
-
jnp.linalg.norm(dyn_on_s, axis=-1) ** 2
|
|
300
|
-
).flatten()
|
|
301
|
-
else:
|
|
302
|
-
mse_on_s += dyn_on_s**2
|
|
303
|
-
|
|
304
|
-
## Now that we have the residuals, select the m points
|
|
305
|
-
# with higher dynamic loss (residuals)
|
|
306
|
-
higher_residual_idx = jax.lax.dynamic_slice(
|
|
307
|
-
jnp.argsort(mse_on_s),
|
|
308
|
-
(mse_on_s.shape[0] - selected_sample_size,),
|
|
309
|
-
(selected_sample_size,),
|
|
310
|
-
)
|
|
311
|
-
higher_residual_points_st = st[higher_residual_idx]
|
|
312
|
-
higher_residual_points_sx = sx[higher_residual_idx]
|
|
301
|
+
v_dyn_loss = vmap(
|
|
302
|
+
lambda t, x: loss.dynamic_loss_dict[i].evaluate(
|
|
303
|
+
t, x, loss.u_dict, params
|
|
304
|
+
),
|
|
305
|
+
(0, 0),
|
|
306
|
+
0,
|
|
307
|
+
)
|
|
308
|
+
dyn_on_s += v_dyn_loss(repeat_times, tile_omega).reshape(
|
|
309
|
+
(sample_size_times, sample_size_omega)
|
|
310
|
+
)
|
|
313
311
|
|
|
314
|
-
|
|
312
|
+
mse_on_s = dyn_on_s**2
|
|
313
|
+
# -- Select the m points with highest average residuals on time and
|
|
314
|
+
# -- space (times in rows / omega in columns)
|
|
315
|
+
# mean_times = mse_on_s.mean(axis=1)
|
|
316
|
+
# mean_omega = mse_on_s.mean(axis=0)
|
|
317
|
+
# times_idx = jax.lax.dynamic_slice(
|
|
318
|
+
# jnp.argsort(mean_times),
|
|
319
|
+
# (mse_on_s.shape[0] - selected_sample_size_times,),
|
|
320
|
+
# (selected_sample_size_times,),
|
|
321
|
+
# )
|
|
322
|
+
# omega_idx = jax.lax.dynamic_slice(
|
|
323
|
+
# jnp.argsort(mean_omega),
|
|
324
|
+
# (mse_on_s.shape[1] - selected_sample_size_omega,),
|
|
325
|
+
# (selected_sample_size_omega,),
|
|
326
|
+
# )
|
|
327
|
+
|
|
328
|
+
# -- Select the m worst points (t, x) with highest residuals
|
|
329
|
+
n_select = max(selected_sample_size_times, selected_sample_size_omega)
|
|
330
|
+
_, idx = jax.lax.top_k(mse_on_s.flatten(), k=n_select)
|
|
331
|
+
arr_idx = jnp.unravel_index(idx, mse_on_s.shape)
|
|
332
|
+
times_idx = arr_idx[0][:selected_sample_size_times]
|
|
333
|
+
omega_idx = arr_idx[1][:selected_sample_size_omega]
|
|
334
|
+
|
|
335
|
+
higher_residual_points_times = new_times_samples[times_idx]
|
|
336
|
+
higher_residual_points_omega = new_omega_samples[omega_idx]
|
|
315
337
|
|
|
316
338
|
## add the new points in times
|
|
317
|
-
# start indices of update can be dynamic but
|
|
339
|
+
# start indices of update can be dynamic but not the shape (length)
|
|
318
340
|
# of the slice
|
|
319
341
|
data.times = jax.lax.dynamic_update_slice(
|
|
320
342
|
data.times,
|
|
321
|
-
|
|
322
|
-
(data.n_start + data.rar_iter_nb *
|
|
343
|
+
higher_residual_points_times,
|
|
344
|
+
(data.n_start + data.rar_iter_nb * selected_sample_size_times,),
|
|
323
345
|
)
|
|
324
346
|
|
|
325
347
|
## add the new points in omega
|
|
326
348
|
data.omega = jax.lax.dynamic_update_slice(
|
|
327
349
|
data.omega,
|
|
328
|
-
|
|
350
|
+
higher_residual_points_omega,
|
|
329
351
|
(
|
|
330
|
-
data.n_start + data.rar_iter_nb *
|
|
352
|
+
data.n_start + data.rar_iter_nb * selected_sample_size_omega,
|
|
331
353
|
data.dim,
|
|
332
354
|
),
|
|
333
355
|
)
|
|
334
356
|
|
|
335
357
|
## rearrange probabilities so that the probabilities of the new
|
|
336
358
|
## points are non-zero
|
|
337
|
-
|
|
359
|
+
new_p_times = 1 / (
|
|
360
|
+
data.nt_start + data.rar_iter_nb * selected_sample_size_times
|
|
361
|
+
)
|
|
338
362
|
# the next work because nt_start is static
|
|
339
|
-
data.
|
|
363
|
+
data.p_times = data.p_times.at[: data.nt_start].set(new_p_times)
|
|
340
364
|
|
|
341
|
-
#
|
|
342
|
-
|
|
343
|
-
|
|
344
|
-
|
|
345
|
-
|
|
346
|
-
|
|
347
|
-
|
|
365
|
+
# same for p_omega (work because n_start is static)
|
|
366
|
+
new_p_omega = 1 / (
|
|
367
|
+
data.n_start + data.rar_iter_nb * selected_sample_size_omega
|
|
368
|
+
)
|
|
369
|
+
data.p_omega = data.p_omega.at[: data.n_start].set(new_p_omega)
|
|
370
|
+
|
|
371
|
+
# the part of data.p_* after n_start requires a fori_loop because
|
|
372
|
+
# the range is dynamic
|
|
373
|
+
def create_update_slices(new_val, selected_sample_size):
|
|
374
|
+
def update_slices(i, p):
|
|
375
|
+
new_p = jax.lax.dynamic_update_slice(
|
|
376
|
+
p,
|
|
377
|
+
new_val * jnp.ones((selected_sample_size,)),
|
|
378
|
+
((data.n_start + i * selected_sample_size),),
|
|
379
|
+
)
|
|
380
|
+
return new_p
|
|
381
|
+
|
|
382
|
+
return update_slices
|
|
348
383
|
|
|
349
384
|
data.rar_iter_nb += 1
|
|
350
385
|
|
|
351
|
-
|
|
386
|
+
## update rest of p_times
|
|
387
|
+
update_slices_times = create_update_slices(
|
|
388
|
+
new_p_times, selected_sample_size_times
|
|
389
|
+
)
|
|
390
|
+
data.p_times = jax.lax.fori_loop(
|
|
391
|
+
0,
|
|
392
|
+
data.rar_iter_nb,
|
|
393
|
+
update_slices_times,
|
|
394
|
+
data.p_times,
|
|
395
|
+
)
|
|
396
|
+
## update rest of p_omega
|
|
397
|
+
update_slices_omega = create_update_slices(
|
|
398
|
+
new_p_omega, selected_sample_size_omega
|
|
399
|
+
)
|
|
400
|
+
data.p_omega = jax.lax.fori_loop(
|
|
401
|
+
0,
|
|
402
|
+
data.rar_iter_nb,
|
|
403
|
+
update_slices_omega,
|
|
404
|
+
data.p_omega,
|
|
405
|
+
)
|
|
352
406
|
|
|
353
|
-
|
|
354
|
-
|
|
355
|
-
|
|
407
|
+
# update RAR parameters for all cases
|
|
408
|
+
data.rar_iter_from_last_sampling = 0
|
|
409
|
+
|
|
410
|
+
# NOTE must return data to be correctly updated because we cannot
|
|
411
|
+
# have side effects in this function that will be jitted
|
|
412
|
+
return data
|
|
356
413
|
|
|
357
414
|
def rar_step_false(operands):
|
|
358
415
|
_, _, data, i = operands
|
|
359
416
|
|
|
360
417
|
# Add 1 only if we are after the burn in period
|
|
361
|
-
|
|
362
|
-
i
|
|
363
|
-
lambda
|
|
364
|
-
lambda
|
|
365
|
-
(data.rar_iter_from_last_sampling),
|
|
418
|
+
increment = jax.lax.cond(
|
|
419
|
+
i <= data.rar_parameters["start_iter"],
|
|
420
|
+
lambda: 0,
|
|
421
|
+
lambda: 1,
|
|
366
422
|
)
|
|
367
423
|
|
|
424
|
+
data.rar_iter_from_last_sampling += increment
|
|
368
425
|
return data
|
|
369
426
|
|
|
370
427
|
return rar_step_true, rar_step_false
|
jinns/solver/_seq2seq.py
CHANGED
|
@@ -88,7 +88,7 @@ def initialize_seq2seq(loss, data, seq2seq, opt_state):
|
|
|
88
88
|
data.curr_omega_idx = 0
|
|
89
89
|
data.generate_time_data()
|
|
90
90
|
data._key, data.times, _ = _reset_batch_idx_and_permute(
|
|
91
|
-
(data._key, data.times, data.curr_omega_idx, None, data.
|
|
91
|
+
(data._key, data.times, data.curr_omega_idx, None, data.p_times)
|
|
92
92
|
)
|
|
93
93
|
opt_state.hyperparams["learning_rate"] = seq2seq["learning_rate"][curr_seq]
|
|
94
94
|
|
|
@@ -145,7 +145,7 @@ def _update_seq2seq_SystemLossODE(operands):
|
|
|
145
145
|
data.curr_omega_idx = 0
|
|
146
146
|
data.generate_time_data()
|
|
147
147
|
data._key, data.times, _ = _reset_batch_idx_and_permute(
|
|
148
|
-
(data._key, data.times, data.curr_omega_idx, None, data.
|
|
148
|
+
(data._key, data.times, data.curr_omega_idx, None, data.p_times)
|
|
149
149
|
)
|
|
150
150
|
|
|
151
151
|
opt_state.hyperparams["learning_rate"] = seq2seq["learning_rate"][curr_seq]
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: jinns
|
|
3
|
-
Version: 0.8.
|
|
3
|
+
Version: 0.8.10
|
|
4
4
|
Summary: Physics Informed Neural Network with JAX
|
|
5
5
|
Author-email: Hugo Gangloff <hugo.gangloff@inrae.fr>, Nicolas Jouvin <nicolas.jouvin@inrae.fr>
|
|
6
6
|
Maintainer-email: Hugo Gangloff <hugo.gangloff@inrae.fr>, Nicolas Jouvin <nicolas.jouvin@inrae.fr>
|
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
jinns/__init__.py,sha256=T2XlmLbYqcXTumPJL00cJ80W98We5LH8Yg_Lss_exl4,139
|
|
2
|
-
jinns/data/_DataGenerators.py,sha256=
|
|
2
|
+
jinns/data/_DataGenerators.py,sha256=_um-giHQ8mCILUOJHX231njHTHZp4S7EcGrUs7R1dUs,61829
|
|
3
3
|
jinns/data/__init__.py,sha256=yBOmoavSD-cABp4XcjQY1zsEVO0mDyIhi2MJ5WNp0l8,326
|
|
4
4
|
jinns/data/_display.py,sha256=vlqggDCgVMEwdGBtjVmZaTQORU6imSfDkssn2XCtITI,10392
|
|
5
5
|
jinns/experimental/__init__.py,sha256=qWbhC7Z8UgLWy0t-zU7RYze6v13-FngiCYXu-2bRVFQ,296
|
|
@@ -8,15 +8,15 @@ jinns/experimental/_sinuspinn.py,sha256=hxSzscwMV2LayWOqenIlT1zqEVVrE5Y8CKf7bHX5
|
|
|
8
8
|
jinns/experimental/_spectralpinn.py,sha256=-4795pa7AYtRNSE-ugan3gHh64mtu2VdrRG5AS_J9Eg,2654
|
|
9
9
|
jinns/loss/_DynamicLoss.py,sha256=L4CVmmF0rTPbHntgqsLLHlnrlQgLHsetUocpJm7ZYag,27461
|
|
10
10
|
jinns/loss/_DynamicLossAbstract.py,sha256=kTQlhLx7SBuH5dIDmYaE79sVHUZt1nUFa8LxPU5IHhM,8504
|
|
11
|
-
jinns/loss/_LossODE.py,sha256=
|
|
11
|
+
jinns/loss/_LossODE.py,sha256=Y1mxryPVFf7ruqw_mGNACLExfx4iQT4R2bZP3s5rg4c,22172
|
|
12
12
|
jinns/loss/_LossPDE.py,sha256=purAEtc0e71kv9XnZUT-a7MrkDAkM_3tTI4xJPu6fH4,61629
|
|
13
13
|
jinns/loss/_Losses.py,sha256=XOL3MFiKEd3ndsc78Qnpi1vbgR0B2HaAWOGGW2meDM8,11190
|
|
14
14
|
jinns/loss/__init__.py,sha256=pFNYUxns-NPXBFdqrEVSiXkQLfCtKw-t2trlhvLzpYE,355
|
|
15
15
|
jinns/loss/_boundary_conditions.py,sha256=YfSnLZ25hXqQ5KWAuxOrWSKkf_oBqAc9GQV4z7MjWyQ,17434
|
|
16
16
|
jinns/loss/_operators.py,sha256=zDGJqYqeYH7xd-4dtGX9PS-pf0uSOpUUXGo5SVjIJ4o,11069
|
|
17
17
|
jinns/solver/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
18
|
-
jinns/solver/_rar.py,sha256=
|
|
19
|
-
jinns/solver/_seq2seq.py,sha256=
|
|
18
|
+
jinns/solver/_rar.py,sha256=IYP-jdbM0rbjBtxislrBYBuj49p9_QDOqejZKCHrKg8,17072
|
|
19
|
+
jinns/solver/_seq2seq.py,sha256=S6IPfsXpS_fbqIqAy01eUM7GBSBSkRzURan_J-iXXzI,5632
|
|
20
20
|
jinns/solver/_solve.py,sha256=mGi0zaT_fK_QpBjTxof5Ix4mmfmnPi66CNJ3GQFZuo4,19099
|
|
21
21
|
jinns/utils/__init__.py,sha256=44ms5UR6vMw3Nf6u4RCAzPFs4fom_YbBnH9mfne8m6k,313
|
|
22
22
|
jinns/utils/_containers.py,sha256=eYD277fO7X4EfX7PUFCCl69r3JBfh1sCfq8LkL5gd6o,1495
|
|
@@ -29,8 +29,8 @@ jinns/utils/_utils.py,sha256=8dgvWXX9NT7_7-zltWp0C9tG45ZFNwXxueyxPBb4hjo,6740
|
|
|
29
29
|
jinns/utils/_utils_uspinn.py,sha256=qcKcOw3zrwWSQyGVj6fD8c9GinHt_U6JWN_k0auTtXM,26039
|
|
30
30
|
jinns/validation/__init__.py,sha256=Jv58mzgC3F7cRfXA6caicL1t_U0UAhbwLrmMNVg6E7s,66
|
|
31
31
|
jinns/validation/_validation.py,sha256=KfetbzB0xTNdBcYLwFWjEtP63Tf9wJirlhgqLTJDyy4,6761
|
|
32
|
-
jinns-0.8.
|
|
33
|
-
jinns-0.8.
|
|
34
|
-
jinns-0.8.
|
|
35
|
-
jinns-0.8.
|
|
36
|
-
jinns-0.8.
|
|
32
|
+
jinns-0.8.10.dist-info/LICENSE,sha256=BIAkGtXB59Q_BG8f6_OqtQ1BHPv60ggE9mpXJYz2dRM,11337
|
|
33
|
+
jinns-0.8.10.dist-info/METADATA,sha256=5lGoyi2W9MRamQdHVgZnYflJtp__zWDGyTiYgmfGc6g,2483
|
|
34
|
+
jinns-0.8.10.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
|
|
35
|
+
jinns-0.8.10.dist-info/top_level.txt,sha256=RXbkr2hzy8WBE8aiRyrJYFqn3JeMJIhMdybLjjLTB9c,6
|
|
36
|
+
jinns-0.8.10.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|