jinns 0.8.7__py3-none-any.whl → 0.8.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jinns/__init__.py +1 -0
- jinns/data/_DataGenerators.py +93 -90
- jinns/data/_display.py +102 -13
- jinns/experimental/__init__.py +2 -0
- jinns/experimental/_sinuspinn.py +135 -0
- jinns/experimental/_spectralpinn.py +87 -0
- jinns/solver/_rar.py +203 -146
- jinns/solver/_seq2seq.py +2 -2
- {jinns-0.8.7.dist-info → jinns-0.8.9.dist-info}/METADATA +1 -1
- {jinns-0.8.7.dist-info → jinns-0.8.9.dist-info}/RECORD +13 -11
- {jinns-0.8.7.dist-info → jinns-0.8.9.dist-info}/LICENSE +0 -0
- {jinns-0.8.7.dist-info → jinns-0.8.9.dist-info}/WHEEL +0 -0
- {jinns-0.8.7.dist-info → jinns-0.8.9.dist-info}/top_level.txt +0 -0
jinns/__init__.py
CHANGED
jinns/data/_DataGenerators.py
CHANGED
|
@@ -98,6 +98,33 @@ def _reset_or_increment(bend, n_eff, operands):
|
|
|
98
98
|
)
|
|
99
99
|
|
|
100
100
|
|
|
101
|
+
def _check_and_set_rar_parameters(rar_parameters, n, n_start):
|
|
102
|
+
if rar_parameters is not None and n_start is None:
|
|
103
|
+
raise ValueError(
|
|
104
|
+
f"n_start or/and nt_start must be provided in the context of RAR sampling scheme, {n_start} was provided"
|
|
105
|
+
)
|
|
106
|
+
if rar_parameters is not None:
|
|
107
|
+
# Default p is None. However, in the RAR sampling scheme we use 0
|
|
108
|
+
# probability to specify non-used collocation points (i.e. points
|
|
109
|
+
# above n_start). Thus, p is a vector of probability of shape (n, 1).
|
|
110
|
+
p = jnp.zeros((n,))
|
|
111
|
+
p = p.at[:n_start].set(1 / n_start)
|
|
112
|
+
# set internal counter for the number of gradient steps since the
|
|
113
|
+
# last new collocation points have been added
|
|
114
|
+
# It is not 0 to ensure the first iteration of RAR happens just
|
|
115
|
+
# after start_iter. See the _proceed_to_rar() function in _rar.py
|
|
116
|
+
rar_iter_from_last_sampling = rar_parameters["update_every"] - 1
|
|
117
|
+
# set iternal counter for the number of times collocation points
|
|
118
|
+
# have been added
|
|
119
|
+
rar_iter_nb = 0
|
|
120
|
+
else:
|
|
121
|
+
p = None
|
|
122
|
+
rar_iter_from_last_sampling = None
|
|
123
|
+
rar_iter_nb = None
|
|
124
|
+
|
|
125
|
+
return n_start, p, rar_iter_from_last_sampling, rar_iter_nb
|
|
126
|
+
|
|
127
|
+
|
|
101
128
|
#####################################################
|
|
102
129
|
# DataGenerator for ODE : only returns time_batches
|
|
103
130
|
#####################################################
|
|
@@ -150,10 +177,10 @@ class DataGeneratorODE:
|
|
|
150
177
|
Default to None: do not use Residual Adaptative Resampling.
|
|
151
178
|
Otherwise a dictionary with keys. `start_iter`: the iteration at
|
|
152
179
|
which we start the RAR sampling scheme (we first have a burn in
|
|
153
|
-
period). `
|
|
180
|
+
period). `update_every`: the number of gradient steps taken between
|
|
154
181
|
each appending of collocation points in the RAR algo.
|
|
155
|
-
`
|
|
156
|
-
collocation points. `
|
|
182
|
+
`sample_size_times`: the size of the sample from which we will select new
|
|
183
|
+
collocation points. `selected_sample_size_times`: the number of selected
|
|
157
184
|
points from the sample to be added to the current collocation
|
|
158
185
|
points
|
|
159
186
|
"DeepXDE: A deep learning library for solving differential
|
|
@@ -179,29 +206,13 @@ class DataGeneratorODE:
|
|
|
179
206
|
self.method = method
|
|
180
207
|
self.rar_parameters = rar_parameters
|
|
181
208
|
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
self.
|
|
188
|
-
|
|
189
|
-
# probability to specify non-used collocation points (i.e. points
|
|
190
|
-
# above nt_start). Thus, p is a vector of probability of shape (nt, 1).
|
|
191
|
-
self.p = jnp.zeros((self.nt,))
|
|
192
|
-
self.p = self.p.at[: self.nt_start].set(1 / nt_start)
|
|
193
|
-
# set internal counter for the number of gradient steps since the
|
|
194
|
-
# last new collocation points have been added
|
|
195
|
-
self.rar_iter_from_last_sampling = 0
|
|
196
|
-
# set iternal counter for the number of times collocation points
|
|
197
|
-
# have been added
|
|
198
|
-
self.rar_iter_nb = 0
|
|
199
|
-
|
|
200
|
-
if rar_parameters is None or nt_start is None:
|
|
201
|
-
self.nt_start = self.nt
|
|
202
|
-
self.p = None
|
|
203
|
-
self.rar_iter_from_last_sampling = None
|
|
204
|
-
self.rar_iter_nb = None
|
|
209
|
+
# Set-up for RAR (if used)
|
|
210
|
+
(
|
|
211
|
+
self.nt_start,
|
|
212
|
+
self.p_times,
|
|
213
|
+
self.rar_iter_from_last_sampling,
|
|
214
|
+
self.rar_iter_nb,
|
|
215
|
+
) = _check_and_set_rar_parameters(rar_parameters, n=nt, n_start=nt_start)
|
|
205
216
|
|
|
206
217
|
if not self.data_exists:
|
|
207
218
|
# Useful when using a lax.scan with pytree
|
|
@@ -238,7 +249,7 @@ class DataGeneratorODE:
|
|
|
238
249
|
self.times,
|
|
239
250
|
self.curr_time_idx,
|
|
240
251
|
self.temporal_batch_size,
|
|
241
|
-
self.
|
|
252
|
+
self.p_times,
|
|
242
253
|
)
|
|
243
254
|
|
|
244
255
|
def temporal_batch(self):
|
|
@@ -253,7 +264,7 @@ class DataGeneratorODE:
|
|
|
253
264
|
if self.rar_parameters is not None:
|
|
254
265
|
nt_eff = (
|
|
255
266
|
self.nt_start
|
|
256
|
-
+ self.rar_iter_nb * self.rar_parameters["
|
|
267
|
+
+ self.rar_iter_nb * self.rar_parameters["selected_sample_size_omega"]
|
|
257
268
|
)
|
|
258
269
|
else:
|
|
259
270
|
nt_eff = self.nt
|
|
@@ -283,14 +294,19 @@ class DataGeneratorODE:
|
|
|
283
294
|
self.curr_time_idx,
|
|
284
295
|
self.tmin,
|
|
285
296
|
self.tmax,
|
|
286
|
-
self.
|
|
287
|
-
self.p,
|
|
297
|
+
self.p_times,
|
|
288
298
|
self.rar_iter_from_last_sampling,
|
|
289
299
|
self.rar_iter_nb,
|
|
290
300
|
) # arrays / dynamic values
|
|
291
301
|
aux_data = {
|
|
292
302
|
k: vars(self)[k]
|
|
293
|
-
for k in [
|
|
303
|
+
for k in [
|
|
304
|
+
"temporal_batch_size",
|
|
305
|
+
"method",
|
|
306
|
+
"nt",
|
|
307
|
+
"rar_parameters",
|
|
308
|
+
"nt_start",
|
|
309
|
+
]
|
|
294
310
|
} # static values
|
|
295
311
|
return (children, aux_data)
|
|
296
312
|
|
|
@@ -308,8 +324,7 @@ class DataGeneratorODE:
|
|
|
308
324
|
curr_time_idx,
|
|
309
325
|
tmin,
|
|
310
326
|
tmax,
|
|
311
|
-
|
|
312
|
-
p,
|
|
327
|
+
p_times,
|
|
313
328
|
rar_iter_from_last_sampling,
|
|
314
329
|
rar_iter_nb,
|
|
315
330
|
) = children
|
|
@@ -318,12 +333,11 @@ class DataGeneratorODE:
|
|
|
318
333
|
data_exists=True,
|
|
319
334
|
tmin=tmin,
|
|
320
335
|
tmax=tmax,
|
|
321
|
-
rar_parameters=rar_parameters,
|
|
322
336
|
**aux_data,
|
|
323
337
|
)
|
|
324
338
|
obj.times = times
|
|
325
339
|
obj.curr_time_idx = curr_time_idx
|
|
326
|
-
obj.
|
|
340
|
+
obj.p_times = p_times
|
|
327
341
|
obj.rar_iter_from_last_sampling = rar_iter_from_last_sampling
|
|
328
342
|
obj.rar_iter_nb = rar_iter_nb
|
|
329
343
|
return obj
|
|
@@ -412,10 +426,10 @@ class CubicMeshPDEStatio(DataGeneratorPDEAbstract):
|
|
|
412
426
|
Default to None: do not use Residual Adaptative Resampling.
|
|
413
427
|
Otherwise a dictionary with keys. `start_iter`: the iteration at
|
|
414
428
|
which we start the RAR sampling scheme (we first have a burn in
|
|
415
|
-
period). `
|
|
429
|
+
period). `update_every`: the number of gradient steps taken between
|
|
416
430
|
each appending of collocation points in the RAR algo.
|
|
417
|
-
`
|
|
418
|
-
collocation points. `
|
|
431
|
+
`sample_size_omega`: the size of the sample from which we will select new
|
|
432
|
+
collocation points. `selected_sample_size_omega`: the number of selected
|
|
419
433
|
points from the sample to be added to the current collocation
|
|
420
434
|
points
|
|
421
435
|
"DeepXDE: A deep learning library for solving differential
|
|
@@ -442,30 +456,12 @@ class CubicMeshPDEStatio(DataGeneratorPDEAbstract):
|
|
|
442
456
|
assert dim == len(max_pts) and isinstance(max_pts, tuple)
|
|
443
457
|
self.n = n
|
|
444
458
|
self.rar_parameters = rar_parameters
|
|
445
|
-
|
|
446
|
-
|
|
447
|
-
|
|
448
|
-
|
|
449
|
-
|
|
450
|
-
|
|
451
|
-
self.n_start = n_start
|
|
452
|
-
# Default p is None. However, in the RAR sampling scheme we use 0
|
|
453
|
-
# probability to specify non-used collocation points (i.e. points
|
|
454
|
-
# above n_start). Thus, p is a vector of probability of shape (n, 1).
|
|
455
|
-
self.p = jnp.zeros((self.n,))
|
|
456
|
-
self.p = self.p.at[: self.n_start].set(1 / n_start)
|
|
457
|
-
# set internal counter for the number of gradient steps since the
|
|
458
|
-
# last new collocation points have been added
|
|
459
|
-
self.rar_iter_from_last_sampling = 0
|
|
460
|
-
# set iternal counter for the number of times collocation points
|
|
461
|
-
# have been added
|
|
462
|
-
self.rar_iter_nb = 0
|
|
463
|
-
|
|
464
|
-
if rar_parameters is None or n_start is None:
|
|
465
|
-
self.n_start = self.n
|
|
466
|
-
self.p = None
|
|
467
|
-
self.rar_iter_from_last_sampling = None
|
|
468
|
-
self.rar_iter_nb = None
|
|
459
|
+
(
|
|
460
|
+
self.n_start,
|
|
461
|
+
self.p_omega,
|
|
462
|
+
self.rar_iter_from_last_sampling,
|
|
463
|
+
self.rar_iter_nb,
|
|
464
|
+
) = _check_and_set_rar_parameters(rar_parameters, n=n, n_start=n_start)
|
|
469
465
|
|
|
470
466
|
self.p_border = None # no RAR sampling for border for now
|
|
471
467
|
|
|
@@ -643,7 +639,7 @@ class CubicMeshPDEStatio(DataGeneratorPDEAbstract):
|
|
|
643
639
|
self.omega,
|
|
644
640
|
self.curr_omega_idx,
|
|
645
641
|
self.omega_batch_size,
|
|
646
|
-
self.
|
|
642
|
+
self.p_omega,
|
|
647
643
|
)
|
|
648
644
|
|
|
649
645
|
def inside_batch(self):
|
|
@@ -656,7 +652,7 @@ class CubicMeshPDEStatio(DataGeneratorPDEAbstract):
|
|
|
656
652
|
if self.rar_parameters is not None:
|
|
657
653
|
n_eff = (
|
|
658
654
|
self.n_start
|
|
659
|
-
+ self.rar_iter_nb * self.rar_parameters["
|
|
655
|
+
+ self.rar_iter_nb * self.rar_parameters["selected_sample_size_omega"]
|
|
660
656
|
)
|
|
661
657
|
else:
|
|
662
658
|
n_eff = self.n
|
|
@@ -745,8 +741,7 @@ class CubicMeshPDEStatio(DataGeneratorPDEAbstract):
|
|
|
745
741
|
self.curr_omega_border_idx,
|
|
746
742
|
self.min_pts,
|
|
747
743
|
self.max_pts,
|
|
748
|
-
self.
|
|
749
|
-
self.p,
|
|
744
|
+
self.p_omega,
|
|
750
745
|
self.rar_iter_from_last_sampling,
|
|
751
746
|
self.rar_iter_nb,
|
|
752
747
|
)
|
|
@@ -759,6 +754,7 @@ class CubicMeshPDEStatio(DataGeneratorPDEAbstract):
|
|
|
759
754
|
"omega_border_batch_size",
|
|
760
755
|
"method",
|
|
761
756
|
"dim",
|
|
757
|
+
"rar_parameters",
|
|
762
758
|
"n_start",
|
|
763
759
|
]
|
|
764
760
|
}
|
|
@@ -780,8 +776,7 @@ class CubicMeshPDEStatio(DataGeneratorPDEAbstract):
|
|
|
780
776
|
curr_omega_border_idx,
|
|
781
777
|
min_pts,
|
|
782
778
|
max_pts,
|
|
783
|
-
|
|
784
|
-
p,
|
|
779
|
+
p_omega,
|
|
785
780
|
rar_iter_from_last_sampling,
|
|
786
781
|
rar_iter_nb,
|
|
787
782
|
) = children
|
|
@@ -792,14 +787,13 @@ class CubicMeshPDEStatio(DataGeneratorPDEAbstract):
|
|
|
792
787
|
data_exists=True,
|
|
793
788
|
min_pts=min_pts,
|
|
794
789
|
max_pts=max_pts,
|
|
795
|
-
rar_parameters=rar_parameters,
|
|
796
790
|
**aux_data,
|
|
797
791
|
)
|
|
798
792
|
obj.omega = omega
|
|
799
793
|
obj.omega_border = omega_border
|
|
800
794
|
obj.curr_omega_idx = curr_omega_idx
|
|
801
795
|
obj.curr_omega_border_idx = curr_omega_border_idx
|
|
802
|
-
obj.
|
|
796
|
+
obj.p_omega = p_omega
|
|
803
797
|
obj.rar_iter_from_last_sampling = rar_iter_from_last_sampling
|
|
804
798
|
obj.rar_iter_nb = rar_iter_nb
|
|
805
799
|
return obj
|
|
@@ -834,6 +828,7 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
|
|
|
834
828
|
method="grid",
|
|
835
829
|
rar_parameters=None,
|
|
836
830
|
n_start=None,
|
|
831
|
+
nt_start=None,
|
|
837
832
|
data_exists=False,
|
|
838
833
|
):
|
|
839
834
|
r"""
|
|
@@ -887,27 +882,23 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
|
|
|
887
882
|
Default to None: do not use Residual Adaptative Resampling.
|
|
888
883
|
Otherwise a dictionary with keys. `start_iter`: the iteration at
|
|
889
884
|
which we start the RAR sampling scheme (we first have a burn in
|
|
890
|
-
period). `
|
|
885
|
+
period). `update_every`: the number of gradient steps taken between
|
|
891
886
|
each appending of collocation points in the RAR algo.
|
|
892
|
-
`
|
|
893
|
-
collocation points. `
|
|
887
|
+
`sample_size_omega`: the size of the sample from which we will select new
|
|
888
|
+
collocation points. `selected_sample_size_omega`: the number of selected
|
|
894
889
|
points from the sample to be added to the current collocation
|
|
895
890
|
points.
|
|
896
|
-
__Note:__ that if RAR sampling is chosen it will currently affect both
|
|
897
|
-
self.times and self.omega with the same hyperparameters
|
|
898
|
-
(rar_parameters and n_start)
|
|
899
|
-
"DeepXDE: A deep learning library for solving differential
|
|
900
|
-
equations", L. Lu, SIAM Review, 2021
|
|
901
891
|
n_start
|
|
902
892
|
Defaults to None. The effective size of n used at start time.
|
|
903
893
|
This value must be
|
|
904
894
|
provided when rar_parameters is not None. Otherwise we set internally
|
|
905
895
|
n_start = n and this is hidden from the user.
|
|
906
896
|
In RAR, n_start
|
|
907
|
-
then corresponds to the initial number of points we train the PINN.
|
|
908
|
-
|
|
909
|
-
|
|
910
|
-
|
|
897
|
+
then corresponds to the initial number of omega points we train the PINN.
|
|
898
|
+
nt_start
|
|
899
|
+
Defaults to None. A RAR hyper-parameter. Same as ``n_start`` but
|
|
900
|
+
for times collocation point. See also ``DataGeneratorODE``
|
|
901
|
+
documentation.
|
|
911
902
|
data_exists
|
|
912
903
|
Must be left to `False` when created by the user. Avoids the
|
|
913
904
|
regeneration of :math:`\Omega`, :math:`\partial\Omega` and
|
|
@@ -931,6 +922,15 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
|
|
|
931
922
|
self.tmin = tmin
|
|
932
923
|
self.tmax = tmax
|
|
933
924
|
self.nt = nt
|
|
925
|
+
|
|
926
|
+
# Set-up for timewise RAR (some quantity are already set-up by super())
|
|
927
|
+
(
|
|
928
|
+
self.nt_start,
|
|
929
|
+
self.p_times,
|
|
930
|
+
_,
|
|
931
|
+
_,
|
|
932
|
+
) = _check_and_set_rar_parameters(rar_parameters, n=nt, n_start=nt_start)
|
|
933
|
+
|
|
934
934
|
if not self.data_exists:
|
|
935
935
|
# Useful when using a lax.scan with pytree
|
|
936
936
|
# Optionally can tell JAX not to re-generate data
|
|
@@ -950,7 +950,7 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
|
|
|
950
950
|
self.times,
|
|
951
951
|
self.curr_time_idx,
|
|
952
952
|
self.temporal_batch_size,
|
|
953
|
-
self.
|
|
953
|
+
self.p_times,
|
|
954
954
|
)
|
|
955
955
|
|
|
956
956
|
def generate_data_nonstatio(self):
|
|
@@ -980,7 +980,7 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
|
|
|
980
980
|
if self.rar_parameters is not None:
|
|
981
981
|
nt_eff = (
|
|
982
982
|
self.n_start
|
|
983
|
-
+ self.rar_iter_nb * self.rar_parameters["
|
|
983
|
+
+ self.rar_iter_nb * self.rar_parameters["selected_sample_size_times"]
|
|
984
984
|
)
|
|
985
985
|
else:
|
|
986
986
|
nt_eff = self.nt
|
|
@@ -1022,8 +1022,8 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
|
|
|
1022
1022
|
self.max_pts,
|
|
1023
1023
|
self.tmin,
|
|
1024
1024
|
self.tmax,
|
|
1025
|
-
self.
|
|
1026
|
-
self.
|
|
1025
|
+
self.p_times,
|
|
1026
|
+
self.p_omega,
|
|
1027
1027
|
self.rar_iter_from_last_sampling,
|
|
1028
1028
|
self.rar_iter_nb,
|
|
1029
1029
|
)
|
|
@@ -1038,7 +1038,9 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
|
|
|
1038
1038
|
"temporal_batch_size",
|
|
1039
1039
|
"method",
|
|
1040
1040
|
"dim",
|
|
1041
|
+
"rar_parameters",
|
|
1041
1042
|
"n_start",
|
|
1043
|
+
"nt_start",
|
|
1042
1044
|
]
|
|
1043
1045
|
}
|
|
1044
1046
|
return (children, aux_data)
|
|
@@ -1063,8 +1065,8 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
|
|
|
1063
1065
|
max_pts,
|
|
1064
1066
|
tmin,
|
|
1065
1067
|
tmax,
|
|
1066
|
-
|
|
1067
|
-
|
|
1068
|
+
p_times,
|
|
1069
|
+
p_omega,
|
|
1068
1070
|
rar_iter_from_last_sampling,
|
|
1069
1071
|
rar_iter_nb,
|
|
1070
1072
|
) = children
|
|
@@ -1075,7 +1077,6 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
|
|
|
1075
1077
|
max_pts=max_pts,
|
|
1076
1078
|
tmin=tmin,
|
|
1077
1079
|
tmax=tmax,
|
|
1078
|
-
rar_parameters=rar_parameters,
|
|
1079
1080
|
**aux_data,
|
|
1080
1081
|
)
|
|
1081
1082
|
obj.omega = omega
|
|
@@ -1084,9 +1085,11 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
|
|
|
1084
1085
|
obj.curr_omega_idx = curr_omega_idx
|
|
1085
1086
|
obj.curr_omega_border_idx = curr_omega_border_idx
|
|
1086
1087
|
obj.curr_time_idx = curr_time_idx
|
|
1087
|
-
obj.
|
|
1088
|
+
obj.p_times = p_times
|
|
1089
|
+
obj.p_omega = p_omega
|
|
1088
1090
|
obj.rar_iter_from_last_sampling = rar_iter_from_last_sampling
|
|
1089
1091
|
obj.rar_iter_nb = rar_iter_nb
|
|
1092
|
+
|
|
1090
1093
|
return obj
|
|
1091
1094
|
|
|
1092
1095
|
|
jinns/data/_display.py
CHANGED
|
@@ -1,8 +1,13 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Utility functions for plotting
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from functools import partial
|
|
6
|
+
import warnings
|
|
1
7
|
import matplotlib.pyplot as plt
|
|
2
8
|
import jax.numpy as jnp
|
|
3
9
|
from jax import vmap
|
|
4
10
|
from mpl_toolkits.axes_grid1 import ImageGrid
|
|
5
|
-
from functools import partial
|
|
6
11
|
|
|
7
12
|
|
|
8
13
|
def plot2d(
|
|
@@ -14,8 +19,10 @@ def plot2d(
|
|
|
14
19
|
figsize=(7, 7),
|
|
15
20
|
cmap="inferno",
|
|
16
21
|
spinn=False,
|
|
22
|
+
vmin_vmax=None,
|
|
23
|
+
ax_for_plot=None,
|
|
17
24
|
):
|
|
18
|
-
"""Generic function for plotting functions over rectangular 2-D domains
|
|
25
|
+
r"""Generic function for plotting functions over rectangular 2-D domains
|
|
19
26
|
:math:`\Omega`. It treats both the stationary case :math:`u(x)` or the
|
|
20
27
|
non-stationnary case :math:`u(t, x)`.
|
|
21
28
|
|
|
@@ -40,6 +47,12 @@ def plot2d(
|
|
|
40
47
|
_description_, by default (7, 7)
|
|
41
48
|
cmap : str, optional
|
|
42
49
|
_description_, by default "inferno"
|
|
50
|
+
vmin_vmax : tuple, optional
|
|
51
|
+
The colorbar minimum and maximum value. Defaults None.
|
|
52
|
+
ax_for_plot : Matplotlib axis, optional
|
|
53
|
+
If None, jinns triggers the plotting. Otherwise this argument
|
|
54
|
+
corresponds to the axis which will host the plot. Default is None.
|
|
55
|
+
NOTE: that this argument will have an effect only if times is None.
|
|
43
56
|
|
|
44
57
|
Raises
|
|
45
58
|
------
|
|
@@ -59,25 +72,49 @@ def plot2d(
|
|
|
59
72
|
# Statio case : expect a function of one argument fun(x)
|
|
60
73
|
if not spinn:
|
|
61
74
|
v_fun = vmap(fun, 0, 0)
|
|
62
|
-
_plot_2D_statio(
|
|
63
|
-
v_fun,
|
|
75
|
+
ret = _plot_2D_statio(
|
|
76
|
+
v_fun,
|
|
77
|
+
mesh,
|
|
78
|
+
plot=not ax_for_plot,
|
|
79
|
+
colorbar=True,
|
|
80
|
+
cmap=cmap,
|
|
81
|
+
figsize=figsize,
|
|
82
|
+
vmin_vmax=vmin_vmax,
|
|
64
83
|
)
|
|
65
84
|
elif spinn:
|
|
66
85
|
values_grid = jnp.squeeze(
|
|
67
86
|
fun(jnp.stack([xy_data[0][..., None], xy_data[1][..., None]], axis=1))
|
|
68
87
|
)
|
|
69
|
-
_plot_2D_statio(
|
|
88
|
+
ret = _plot_2D_statio(
|
|
70
89
|
values_grid,
|
|
71
90
|
mesh,
|
|
72
|
-
plot=
|
|
91
|
+
plot=not ax_for_plot,
|
|
73
92
|
colorbar=True,
|
|
74
93
|
cmap=cmap,
|
|
75
94
|
spinn=True,
|
|
76
95
|
figsize=figsize,
|
|
96
|
+
vmin_vmax=vmin_vmax,
|
|
77
97
|
)
|
|
78
|
-
|
|
98
|
+
if not ax_for_plot:
|
|
99
|
+
plt.title(title)
|
|
100
|
+
else:
|
|
101
|
+
if vmin_vmax is not None:
|
|
102
|
+
im = ax_for_plot.pcolormesh(
|
|
103
|
+
mesh[0],
|
|
104
|
+
mesh[1],
|
|
105
|
+
ret[0],
|
|
106
|
+
cmap=cmap,
|
|
107
|
+
vmin=vmin_vmax[0],
|
|
108
|
+
vmax=vmin_vmax[1],
|
|
109
|
+
)
|
|
110
|
+
else:
|
|
111
|
+
im = ax_for_plot.pcolormesh(mesh[0], mesh[1], ret[0], cmap=cmap)
|
|
112
|
+
ax_for_plot.set_title(title)
|
|
113
|
+
ax_for_plot.cax.colorbar(im, format="%0.2f")
|
|
79
114
|
|
|
80
115
|
else:
|
|
116
|
+
if ax_for_plot is not None:
|
|
117
|
+
warnings.warn("ax_for_plot is ignored. jinns will plot the figure")
|
|
81
118
|
if not isinstance(times, list):
|
|
82
119
|
try:
|
|
83
120
|
times = times.tolist()
|
|
@@ -101,7 +138,12 @@ def plot2d(
|
|
|
101
138
|
if not spinn:
|
|
102
139
|
v_fun_at_t = vmap(lambda x: fun(t=jnp.array([t]), x=x), 0, 0)
|
|
103
140
|
t_slice, _ = _plot_2D_statio(
|
|
104
|
-
v_fun_at_t,
|
|
141
|
+
v_fun_at_t,
|
|
142
|
+
mesh,
|
|
143
|
+
plot=False,
|
|
144
|
+
colorbar=False,
|
|
145
|
+
cmap=None,
|
|
146
|
+
vmin_vmax=vmin_vmax,
|
|
105
147
|
)
|
|
106
148
|
elif spinn:
|
|
107
149
|
values_grid = jnp.squeeze(
|
|
@@ -113,15 +155,37 @@ def plot2d(
|
|
|
113
155
|
)[0]
|
|
114
156
|
)
|
|
115
157
|
t_slice, _ = _plot_2D_statio(
|
|
116
|
-
values_grid,
|
|
158
|
+
values_grid,
|
|
159
|
+
mesh,
|
|
160
|
+
plot=False,
|
|
161
|
+
colorbar=True,
|
|
162
|
+
spinn=True,
|
|
163
|
+
vmin_vmax=vmin_vmax,
|
|
164
|
+
)
|
|
165
|
+
if vmin_vmax is not None:
|
|
166
|
+
im = ax.pcolormesh(
|
|
167
|
+
mesh[0],
|
|
168
|
+
mesh[1],
|
|
169
|
+
t_slice,
|
|
170
|
+
cmap=cmap,
|
|
171
|
+
vmin=vmin_vmax[0],
|
|
172
|
+
vmax=vmin_vmax[1],
|
|
117
173
|
)
|
|
118
|
-
|
|
174
|
+
else:
|
|
175
|
+
im = ax.pcolormesh(mesh[0], mesh[1], t_slice, cmap=cmap)
|
|
119
176
|
ax.set_title(f"t = {times[idx] * Tmax:.2f}")
|
|
120
177
|
ax.cax.colorbar(im, format="%0.2f")
|
|
121
178
|
|
|
122
179
|
|
|
123
180
|
def _plot_2D_statio(
|
|
124
|
-
v_fun,
|
|
181
|
+
v_fun,
|
|
182
|
+
mesh,
|
|
183
|
+
plot=True,
|
|
184
|
+
colorbar=True,
|
|
185
|
+
cmap="inferno",
|
|
186
|
+
figsize=(7, 7),
|
|
187
|
+
spinn=False,
|
|
188
|
+
vmin_vmax=None,
|
|
125
189
|
):
|
|
126
190
|
"""Function that plot the function u(x) with 2-D input x using pcolormesh()
|
|
127
191
|
|
|
@@ -136,6 +200,8 @@ def _plot_2D_statio(
|
|
|
136
200
|
either show or return the plot, by default True
|
|
137
201
|
colorbar : bool, optional
|
|
138
202
|
add a colorbar, by default True
|
|
203
|
+
vmin_vmax: tuple, optional
|
|
204
|
+
The colorbar minimum and maximum value. Defaults None.
|
|
139
205
|
|
|
140
206
|
Returns
|
|
141
207
|
-------
|
|
@@ -153,7 +219,17 @@ def _plot_2D_statio(
|
|
|
153
219
|
|
|
154
220
|
if plot:
|
|
155
221
|
fig = plt.figure(figsize=figsize)
|
|
156
|
-
|
|
222
|
+
if vmin_vmax is not None:
|
|
223
|
+
im = plt.pcolormesh(
|
|
224
|
+
x_grid,
|
|
225
|
+
y_grid,
|
|
226
|
+
values_grid,
|
|
227
|
+
cmap=cmap,
|
|
228
|
+
vmin=vmin_vmax[0],
|
|
229
|
+
vmax=vmin_vmax[1],
|
|
230
|
+
)
|
|
231
|
+
else:
|
|
232
|
+
im = plt.pcolormesh(x_grid, y_grid, values_grid, cmap=cmap)
|
|
157
233
|
if colorbar:
|
|
158
234
|
fig.colorbar(im, format="%0.2f")
|
|
159
235
|
# don't plt.show() because it is done in plot2d()
|
|
@@ -217,6 +293,7 @@ def plot1d_image(
|
|
|
217
293
|
colorbar=True,
|
|
218
294
|
cmap="inferno",
|
|
219
295
|
spinn=False,
|
|
296
|
+
vmin_vmax=None,
|
|
220
297
|
):
|
|
221
298
|
"""Function for plotting the 2-D image of a function :math:`f(t, x)` where
|
|
222
299
|
`t` is time (1-D) and x is space (1-D).
|
|
@@ -237,6 +314,8 @@ def plot1d_image(
|
|
|
237
314
|
, by default ""
|
|
238
315
|
figsize : tuple, optional
|
|
239
316
|
, by default (10, 10)
|
|
317
|
+
vmin_vmax: tuple
|
|
318
|
+
The colorbar minimum and maximum value. Defaults None.
|
|
240
319
|
"""
|
|
241
320
|
|
|
242
321
|
mesh = jnp.meshgrid(times, xdata) # cartesian product
|
|
@@ -250,7 +329,17 @@ def plot1d_image(
|
|
|
250
329
|
elif spinn:
|
|
251
330
|
values_grid = jnp.squeeze(fun((times[..., None]), xdata[..., None]).T)
|
|
252
331
|
fig, ax = plt.subplots(1, 1, figsize=figsize)
|
|
253
|
-
|
|
332
|
+
if vmin_vmax is not None:
|
|
333
|
+
im = ax.pcolormesh(
|
|
334
|
+
mesh[0] * Tmax,
|
|
335
|
+
mesh[1],
|
|
336
|
+
values_grid,
|
|
337
|
+
cmap=cmap,
|
|
338
|
+
vmin=vmin_vmax[0],
|
|
339
|
+
vmax=vmin_vmax[1],
|
|
340
|
+
)
|
|
341
|
+
else:
|
|
342
|
+
im = ax.pcolormesh(mesh[0] * Tmax, mesh[1], values_grid, cmap=cmap)
|
|
254
343
|
if colorbar:
|
|
255
344
|
fig.colorbar(im, format="%0.2f")
|
|
256
345
|
ax.set_title(title)
|