jinns-0.8.8-py3-none-any.whl → jinns-0.8.10-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
jinns/data/_DataGenerators.py CHANGED
@@ -98,6 +98,33 @@ def _reset_or_increment(bend, n_eff, operands):
     )
 
 
+def _check_and_set_rar_parameters(rar_parameters, n, n_start):
+    if rar_parameters is not None and n_start is None:
+        raise ValueError(
+            f"n_start and/or nt_start must be provided in the context of the RAR sampling scheme, {n_start} was provided"
+        )
+    if rar_parameters is not None:
+        # Default p is None. However, in the RAR sampling scheme we use 0
+        # probability to specify non-used collocation points (i.e. points
+        # above n_start). Thus, p is a vector of probabilities of shape (n,).
+        p = jnp.zeros((n,))
+        p = p.at[:n_start].set(1 / n_start)
+        # set internal counter for the number of gradient steps since the
+        # last new collocation points have been added.
+        # It is not 0 to ensure the first iteration of RAR happens just
+        # after start_iter. See the _proceed_to_rar() function in _rar.py
+        rar_iter_from_last_sampling = rar_parameters["update_every"] - 1
+        # set internal counter for the number of times collocation points
+        # have been added
+        rar_iter_nb = 0
+    else:
+        p = None
+        rar_iter_from_last_sampling = None
+        rar_iter_nb = None
+
+    return n_start, p, rar_iter_from_last_sampling, rar_iter_nb
+
+
 #####################################################
 # DataGenerator for ODE : only returns time_batches
 #####################################################
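For reference, a minimal sketch of the helper's two code paths (hypothetical values; assumes `_check_and_set_rar_parameters` is in scope as defined above):

```python
import jax.numpy as jnp

# Hypothetical RAR setup: 100 allocated collocation points, 50 active at start.
# The helper itself only reads the "update_every" key.
rar_parameters = {"update_every": 100}
n_start, p, counter, n_added = _check_and_set_rar_parameters(
    rar_parameters, n=100, n_start=50
)
assert p.shape == (100,)  # 1/50 on the first 50 slots, 0 elsewhere
assert counter == 99      # update_every - 1, see _proceed_to_rar()
assert n_added == 0

# Without RAR everything comes back as None (n_start is passed through)
assert _check_and_set_rar_parameters(None, n=100, n_start=None) == (
    None,
    None,
    None,
    None,
)
```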
@@ -150,10 +177,10 @@ class DataGeneratorODE:
         Default to None: do not use Residual Adaptive Resampling.
         Otherwise a dictionary with the following keys. `start_iter`: the iteration at
         which we start the RAR sampling scheme (we first have a burn-in
-        period). `update_rate`: the number of gradient steps taken between
+        period). `update_every`: the number of gradient steps taken between
         each appending of collocation points in the RAR algo.
-        `sample_size`: the size of the sample from which we will select new
-        collocation points. `selected_sample_size`: the number of selected
+        `sample_size_times`: the size of the sample from which we will select new
+        collocation points. `selected_sample_size_times`: the number of selected
         points from the sample to be added to the current collocation
         points.
         "DeepXDE: A deep learning library for solving differential
@@ -179,29 +206,13 @@ class DataGeneratorODE:
         self.method = method
         self.rar_parameters = rar_parameters
 
-        if rar_parameters is not None and nt_start is None:
-            raise ValueError(
-                "nt_start must be provided in the context of RAR sampling scheme"
-            )
-        if rar_parameters is not None:
-            self.nt_start = nt_start
-            # Default p is None. However, in the RAR sampling scheme we use 0
-            # probability to specify non-used collocation points (i.e. points
-            # above nt_start). Thus, p is a vector of probability of shape (nt, 1).
-            self.p = jnp.zeros((self.nt,))
-            self.p = self.p.at[: self.nt_start].set(1 / nt_start)
-            # set internal counter for the number of gradient steps since the
-            # last new collocation points have been added
-            self.rar_iter_from_last_sampling = 0
-            # set iternal counter for the number of times collocation points
-            # have been added
-            self.rar_iter_nb = 0
-
-        if rar_parameters is None or nt_start is None:
-            self.nt_start = self.nt
-            self.p = None
-            self.rar_iter_from_last_sampling = None
-            self.rar_iter_nb = None
+        # Set-up for RAR (if used)
+        (
+            self.nt_start,
+            self.p_times,
+            self.rar_iter_from_last_sampling,
+            self.rar_iter_nb,
+        ) = _check_and_set_rar_parameters(rar_parameters, n=nt, n_start=nt_start)
 
         if not self.data_exists:
             # Useful when using a lax.scan with pytree
@@ -238,7 +249,7 @@ class DataGeneratorODE:
             self.times,
             self.curr_time_idx,
             self.temporal_batch_size,
-            self.p,
+            self.p_times,
         )
 
     def temporal_batch(self):
@@ -253,7 +264,7 @@ class DataGeneratorODE:
         if self.rar_parameters is not None:
            nt_eff = (
                 self.nt_start
-                + self.rar_iter_nb * self.rar_parameters["selected_sample_size"]
+                + self.rar_iter_nb * self.rar_parameters["selected_sample_size_omega"]
             )
         else:
             nt_eff = self.nt
@@ -283,14 +294,19 @@ class DataGeneratorODE:
             self.curr_time_idx,
             self.tmin,
             self.tmax,
-            self.rar_parameters,
-            self.p,
+            self.p_times,
             self.rar_iter_from_last_sampling,
             self.rar_iter_nb,
         )  # arrays / dynamic values
         aux_data = {
             k: vars(self)[k]
-            for k in ["temporal_batch_size", "method", "nt", "nt_start"]
+            for k in [
+                "temporal_batch_size",
+                "method",
+                "nt",
+                "rar_parameters",
+                "nt_start",
+            ]
         }  # static values
         return (children, aux_data)
 
@@ -308,8 +324,7 @@ class DataGeneratorODE:
             curr_time_idx,
             tmin,
             tmax,
-            rar_parameters,
-            p,
+            p_times,
             rar_iter_from_last_sampling,
             rar_iter_nb,
         ) = children
@@ -318,12 +333,11 @@ class DataGeneratorODE:
             data_exists=True,
             tmin=tmin,
             tmax=tmax,
-            rar_parameters=rar_parameters,
             **aux_data,
         )
         obj.times = times
         obj.curr_time_idx = curr_time_idx
-        obj.p = p
+        obj.p_times = p_times
         obj.rar_iter_from_last_sampling = rar_iter_from_last_sampling
         obj.rar_iter_nb = rar_iter_nb
         return obj
@@ -412,10 +426,10 @@ class CubicMeshPDEStatio(DataGeneratorPDEAbstract):
         Default to None: do not use Residual Adaptive Resampling.
         Otherwise a dictionary with the following keys. `start_iter`: the iteration at
         which we start the RAR sampling scheme (we first have a burn-in
-        period). `update_rate`: the number of gradient steps taken between
+        period). `update_every`: the number of gradient steps taken between
         each appending of collocation points in the RAR algo.
-        `sample_size`: the size of the sample from which we will select new
-        collocation points. `selected_sample_size`: the number of selected
+        `sample_size_omega`: the size of the sample from which we will select new
+        collocation points. `selected_sample_size_omega`: the number of selected
         points from the sample to be added to the current collocation
         points.
         "DeepXDE: A deep learning library for solving differential
@@ -442,30 +456,12 @@ class CubicMeshPDEStatio(DataGeneratorPDEAbstract):
         assert dim == len(max_pts) and isinstance(max_pts, tuple)
         self.n = n
         self.rar_parameters = rar_parameters
-
-        if rar_parameters is not None and n_start is None:
-            raise ValueError(
-                "n_start must be provided in the context of RAR sampling scheme"
-            )
-        if rar_parameters is not None:
-            self.n_start = n_start
-            # Default p is None. However, in the RAR sampling scheme we use 0
-            # probability to specify non-used collocation points (i.e. points
-            # above n_start). Thus, p is a vector of probability of shape (n, 1).
-            self.p = jnp.zeros((self.n,))
-            self.p = self.p.at[: self.n_start].set(1 / n_start)
-            # set internal counter for the number of gradient steps since the
-            # last new collocation points have been added
-            self.rar_iter_from_last_sampling = 0
-            # set iternal counter for the number of times collocation points
-            # have been added
-            self.rar_iter_nb = 0
-
-        if rar_parameters is None or n_start is None:
-            self.n_start = self.n
-            self.p = None
-            self.rar_iter_from_last_sampling = None
-            self.rar_iter_nb = None
+        (
+            self.n_start,
+            self.p_omega,
+            self.rar_iter_from_last_sampling,
+            self.rar_iter_nb,
+        ) = _check_and_set_rar_parameters(rar_parameters, n=n, n_start=n_start)
 
         self.p_border = None  # no RAR sampling for border for now
 
@@ -643,7 +639,7 @@ class CubicMeshPDEStatio(DataGeneratorPDEAbstract):
             self.omega,
             self.curr_omega_idx,
             self.omega_batch_size,
-            self.p,
+            self.p_omega,
         )
 
     def inside_batch(self):
@@ -656,7 +652,7 @@ class CubicMeshPDEStatio(DataGeneratorPDEAbstract):
         if self.rar_parameters is not None:
             n_eff = (
                 self.n_start
-                + self.rar_iter_nb * self.rar_parameters["selected_sample_size"]
+                + self.rar_iter_nb * self.rar_parameters["selected_sample_size_omega"]
             )
         else:
             n_eff = self.n
@@ -745,8 +741,7 @@ class CubicMeshPDEStatio(DataGeneratorPDEAbstract):
             self.curr_omega_border_idx,
             self.min_pts,
             self.max_pts,
-            self.rar_parameters,
-            self.p,
+            self.p_omega,
             self.rar_iter_from_last_sampling,
             self.rar_iter_nb,
         )
@@ -759,6 +754,7 @@ class CubicMeshPDEStatio(DataGeneratorPDEAbstract):
                 "omega_border_batch_size",
                 "method",
                 "dim",
+                "rar_parameters",
                 "n_start",
             ]
         }
@@ -780,8 +776,7 @@ class CubicMeshPDEStatio(DataGeneratorPDEAbstract):
             curr_omega_border_idx,
             min_pts,
             max_pts,
-            rar_parameters,
-            p,
+            p_omega,
             rar_iter_from_last_sampling,
             rar_iter_nb,
         ) = children
@@ -792,14 +787,13 @@ class CubicMeshPDEStatio(DataGeneratorPDEAbstract):
             data_exists=True,
             min_pts=min_pts,
             max_pts=max_pts,
-            rar_parameters=rar_parameters,
             **aux_data,
         )
         obj.omega = omega
         obj.omega_border = omega_border
         obj.curr_omega_idx = curr_omega_idx
         obj.curr_omega_border_idx = curr_omega_border_idx
-        obj.p = p
+        obj.p_omega = p_omega
         obj.rar_iter_from_last_sampling = rar_iter_from_last_sampling
         obj.rar_iter_nb = rar_iter_nb
         return obj
@@ -834,6 +828,7 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
         method="grid",
         rar_parameters=None,
         n_start=None,
+        nt_start=None,
         data_exists=False,
     ):
         r"""
@@ -887,27 +882,23 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
         Default to None: do not use Residual Adaptive Resampling.
         Otherwise a dictionary with the following keys. `start_iter`: the iteration at
         which we start the RAR sampling scheme (we first have a burn-in
-        period). `update_rate`: the number of gradient steps taken between
+        period). `update_every`: the number of gradient steps taken between
         each appending of collocation points in the RAR algo.
-        `sample_size`: the size of the sample from which we will select new
-        collocation points. `selected_sample_size`: the number of selected
+        `sample_size_omega`: the size of the sample from which we will select new
+        collocation points. `selected_sample_size_omega`: the number of selected
         points from the sample to be added to the current collocation
         points.
-        __Note:__ that if RAR sampling is chosen it will currently affect both
-        self.times and self.omega with the same hyperparameters
-        (rar_parameters and n_start)
-        "DeepXDE: A deep learning library for solving differential
-        equations", L. Lu, SIAM Review, 2021
     n_start
         Defaults to None. The effective size of n used at start time.
         This value must be
         provided when rar_parameters is not None. Otherwise we set internally
         n_start = n and this is hidden from the user.
         In RAR, n_start
-        then corresponds to the initial number of points we train the PINN.
-        __Note:__ that if RAR sampling is chosen it will currently affect both
-        self.times and self.omega with the same hyperparameters
-        (rar_parameters and n_start)
+        then corresponds to the initial number of omega points on which we train the PINN.
+    nt_start
+        Defaults to None. A RAR hyper-parameter. Same as ``n_start`` but
+        for time collocation points. See also the ``DataGeneratorODE``
+        documentation.
     data_exists
         Must be left to `False` when created by the user. Avoids the
         regeneration of :math:`\Omega`, :math:`\partial\Omega` and
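Per the `init_rar` dispatch later in this diff, `CubicMeshPDENonStatio` requires both key families; a sketch of a complete `rar_parameters` dictionary (hypothetical values):

```python
# Both n_start and nt_start must also be passed to the constructor.
rar_parameters = {
    "start_iter": 1000,                # shared burn-in and cadence settings
    "update_every": 50,
    "sample_size_times": 256,          # candidate/selected sizes for time points
    "selected_sample_size_times": 8,
    "sample_size_omega": 256,          # candidate/selected sizes for omega points
    "selected_sample_size_omega": 8,
}
```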
@@ -931,6 +922,15 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
         self.tmin = tmin
         self.tmax = tmax
         self.nt = nt
+
+        # Set-up for timewise RAR (some quantities are already set up by super())
+        (
+            self.nt_start,
+            self.p_times,
+            _,
+            _,
+        ) = _check_and_set_rar_parameters(rar_parameters, n=nt, n_start=nt_start)
+
         if not self.data_exists:
             # Useful when using a lax.scan with pytree
             # Optionally can tell JAX not to re-generate data
@@ -950,7 +950,7 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
             self.times,
             self.curr_time_idx,
             self.temporal_batch_size,
-            self.p,
+            self.p_times,
         )
 
     def generate_data_nonstatio(self):
@@ -980,7 +980,7 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
         if self.rar_parameters is not None:
             nt_eff = (
                 self.n_start
-                + self.rar_iter_nb * self.rar_parameters["selected_sample_size"]
+                + self.rar_iter_nb * self.rar_parameters["selected_sample_size_times"]
             )
         else:
             nt_eff = self.nt
@@ -1022,8 +1022,8 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
             self.max_pts,
             self.tmin,
             self.tmax,
-            self.rar_parameters,
-            self.p,
+            self.p_times,
+            self.p_omega,
             self.rar_iter_from_last_sampling,
             self.rar_iter_nb,
         )
@@ -1038,7 +1038,9 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
                 "temporal_batch_size",
                 "method",
                 "dim",
+                "rar_parameters",
                 "n_start",
+                "nt_start",
             ]
         }
         return (children, aux_data)
@@ -1063,8 +1065,8 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
             max_pts,
             tmin,
             tmax,
-            rar_parameters,
-            p,
+            p_times,
+            p_omega,
             rar_iter_from_last_sampling,
             rar_iter_nb,
         ) = children
@@ -1075,7 +1077,6 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
             max_pts=max_pts,
             tmin=tmin,
             tmax=tmax,
-            rar_parameters=rar_parameters,
             **aux_data,
         )
         obj.omega = omega
@@ -1084,9 +1085,11 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
         obj.curr_omega_idx = curr_omega_idx
         obj.curr_omega_border_idx = curr_omega_border_idx
         obj.curr_time_idx = curr_time_idx
-        obj.p = p
+        obj.p_times = p_times
+        obj.p_omega = p_omega
         obj.rar_iter_from_last_sampling = rar_iter_from_last_sampling
         obj.rar_iter_nb = rar_iter_nb
+
         return obj
 
 
jinns/loss/_LossODE.py CHANGED
@@ -237,6 +237,7 @@ class LossODE:
             "u": self.u,
             "dynamic_loss": self.dynamic_loss,
             "obs_slice": self.obs_slice,
+            "derivative_keys": self.derivative_keys,
         }
         return (children, aux_data)
 
jinns/solver/_rar.py CHANGED
@@ -11,6 +11,41 @@ from jinns.loss._LossODE import LossODE, SystemLossODE
 from jinns.loss._DynamicLossAbstract import PDEStatio
 
 from functools import partial
+from jinns.utils._hyperpinn import HYPERPINN
+from jinns.utils._spinn import SPINN
+
+
+def _proceed_to_rar(data, i):
+    """Utility function with various checks to ensure we can proceed with the
+    rar_step. Returns True if yes, and False otherwise"""
+
+    # Overall checks (universal for any data generator)
+    check_list = [
+        # check if the burn-in period has ended
+        data.rar_parameters["start_iter"] <= i,
+        # check if enough iterations since last points added
+        (data.rar_parameters["update_every"] - 1) == data.rar_iter_from_last_sampling,
+    ]
+
+    # Memory allocation checks (depends on the type of DataGenerator)
+    # check if we still have room to append new collocation points in the
+    # allocated jnp.array (can concern `data.p_times` or `data.p_omega`)
+    if isinstance(data, DataGeneratorODE) or isinstance(data, CubicMeshPDENonStatio):
+        check_list.append(
+            data.rar_parameters["selected_sample_size_times"]
+            <= jnp.count_nonzero(data.p_times == 0),
+        )
+
+    if isinstance(data, CubicMeshPDEStatio) or isinstance(data, CubicMeshPDENonStatio):
+        # for now the above checks are redundant but there may be a time when
+        # we drop inheritance
+        check_list.append(
+            data.rar_parameters["selected_sample_size_omega"]
+            <= jnp.count_nonzero(data.p_omega == 0),
+        )
+
+    proceed = jnp.all(jnp.array(check_list))
+    return proceed
 
 
 @partial(jax.jit, static_argnames=["_rar_step_true", "_rar_step_false"])
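The `update_every - 1` counter initialisation in `_check_and_set_rar_parameters` and the equality check above interlock so that the first RAR step fires right after the burn-in; a small standalone trace with hypothetical values:

```python
# Hypothetical schedule: burn-in of 10 steps, then one RAR step every 3 steps.
start_iter, update_every = 10, 3
counter = update_every - 1  # as set by _check_and_set_rar_parameters

for i in range(9, 14):
    fires = (start_iter <= i) and (counter == update_every - 1)
    if fires:
        counter = 0            # rar_step_true resets the counter
    elif i > start_iter:
        counter += 1           # rar_step_false increments after burn-in
    print(i, fires)
# 9 False, 10 True, 11 False, 12 False, 13 True: the first RAR step fires
# right after the burn-in, then every update_every steps.
```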
@@ -22,21 +57,7 @@ def trigger_rar(i, loss, params, data, _rar_step_true, _rar_step_false):
     else:
         # update `data` according to rar scheme.
         data = jax.lax.cond(
-            jnp.all(
-                jnp.array(
-                    [
-                        # check if enough it since last points added
-                        data.rar_parameters["update_rate"]
-                        == data.rar_iter_from_last_sampling,
-                        # check if burn in period has ended
-                        data.rar_parameters["start_iter"] < i,
-                        # check if we still have room to append new
-                        # collocation points in the allocated jnp array
-                        data.rar_parameters["selected_sample_size"]
-                        <= jnp.count_nonzero(data.p == 0),
-                    ]
-                )
-            ),
+            _proceed_to_rar(data, i),
             _rar_step_true,
             _rar_step_false,
             (loss, params, data, i),
@@ -49,13 +70,37 @@ def init_rar(data):
     Separated from the main rar, because the initialization to get _true and
     _false cannot be jit-ted.
     """
+    # NOTE if a user misspells some entry of ``rar_parameters`` the resulting
+    # error may be a bit obscure but it should be OK.
     if data.rar_parameters is None:
         _rar_step_true, _rar_step_false = None, None
     else:
-        _rar_step_true, _rar_step_false = _rar_step_init(
-            data.rar_parameters["sample_size"],
-            data.rar_parameters["selected_sample_size"],
-        )
+        if isinstance(data, DataGeneratorODE):
+            # In this case we only need rar parameters related to `times`
+            _rar_step_true, _rar_step_false = _rar_step_init(
+                data.rar_parameters["sample_size_times"],
+                data.rar_parameters["selected_sample_size_times"],
+            )
+        elif isinstance(data, CubicMeshPDENonStatio):
+            # In this case we need rar parameters related to both `times`
+            # and `omega`
+            _rar_step_true, _rar_step_false = _rar_step_init(
+                (
+                    data.rar_parameters["sample_size_times"],
+                    data.rar_parameters["sample_size_omega"],
+                ),
+                (
+                    data.rar_parameters["selected_sample_size_times"],
+                    data.rar_parameters["selected_sample_size_omega"],
+                ),
+            )
+        elif isinstance(data, CubicMeshPDEStatio):
+            # In this case we only need rar parameters related to `omega`
+            _rar_step_true, _rar_step_false = _rar_step_init(
+                data.rar_parameters["sample_size_omega"],
+                data.rar_parameters["selected_sample_size_omega"],
+            )
+
    data.rar_parameters["iter_from_last_sampling"] = 0
 
    return data, _rar_step_true, _rar_step_false
@@ -64,7 +109,7 @@
 def _rar_step_init(sample_size, selected_sample_size):
     """
     This is a wrapper because the sampling size and
-    selected_sample_size, must be treated static
+    selected_sample_size, must be treated as static
     in order to slice. So they must be set before jitting and not with the jitted
     dictionary values rar["test_points_nb"] and rar["added_points_nb"]
 
@@ -72,16 +117,10 @@ def _rar_step_init(sample_size, selected_sample_size):
     """
 
     def rar_step_true(operands):
-        """
-        Note: in all generality, we would need a stop gradient operator around
-        these dynamic_loss evaluations that follow which produce weights for
-        sampling. However, they appear through argsort and sampling
-        operations which definitely kill gradient flows
-        """
         loss, params, data, i = operands
 
         if isinstance(data, DataGeneratorODE):
-            s = data.sample_in_time_domain(sample_size)
+            new_omega_samples = data.sample_in_time_domain(sample_size)
 
             # We can have different types of Loss
             if isinstance(loss, LossODE):
@@ -90,7 +129,7 @@ def _rar_step_init(sample_size, selected_sample_size):
                     (0),
                     0,
                 )
-                dyn_on_s = v_dyn_loss(s)
+                dyn_on_s = v_dyn_loss(new_omega_samples)
                 if dyn_on_s.ndim > 1:
                     mse_on_s = (jnp.linalg.norm(dyn_on_s, axis=-1) ** 2).flatten()
                 else:
@@ -106,7 +145,7 @@ def _rar_step_init(sample_size, selected_sample_size):
                     (0),
                     0,
                 )
-                dyn_on_s = v_dyn_loss(s)
+                dyn_on_s = v_dyn_loss(new_omega_samples)
                 if dyn_on_s.ndim > 1:
                     mse_on_s += (jnp.linalg.norm(dyn_on_s, axis=-1) ** 2).flatten()
                 else:
@@ -118,9 +157,7 @@ def _rar_step_init(sample_size, selected_sample_size):
                 (mse_on_s.shape[0] - selected_sample_size,),
                 (selected_sample_size,),
             )
-            higher_residual_points = s[higher_residual_idx]
-
-            data.rar_iter_from_last_sampling = 0
+            higher_residual_points = new_omega_samples[higher_residual_idx]
 
             ## add the new points in times
             # start indices of update can be dynamic but not the shape (length)
@@ -135,7 +172,7 @@ def _rar_step_init(sample_size, selected_sample_size):
             ## points are non-zero
             new_proba = 1 / (data.nt_start + data.rar_iter_nb * selected_sample_size)
             # the next works because nt_start is static
-            data.p = data.p.at[: data.nt_start].set(new_proba)
+            data.p_times = data.p_times.at[: data.nt_start].set(new_proba)
 
             # the next requires a fori_loop because the range is dynamic
             def update_slices(i, p):
@@ -147,16 +184,14 @@ def _rar_step_init(sample_size, selected_sample_size):
 
             data.rar_iter_nb += 1
 
-            data.p = jax.lax.fori_loop(0, data.rar_iter_nb, update_slices, data.p)
-
-            # NOTE must return data to be correctly updated because we cannot
-            # have side effects in this function that will be jitted
-            return data
+            data.p_times = jax.lax.fori_loop(
+                0, data.rar_iter_nb, update_slices, data.p_times
+            )
 
         elif isinstance(data, CubicMeshPDEStatio) and not isinstance(
             data, CubicMeshPDENonStatio
         ):
-            s = data.sample_in_omega_domain(sample_size)
+            new_omega_samples = data.sample_in_omega_domain(sample_size)
 
             # We can have different types of Loss
             if isinstance(loss, LossPDEStatio):
@@ -169,7 +204,7 @@ def _rar_step_init(sample_size, selected_sample_size):
                     (0),
                     0,
                 )
-                dyn_on_s = v_dyn_loss(s)
+                dyn_on_s = v_dyn_loss(new_omega_samples)
                 if dyn_on_s.ndim > 1:
                     mse_on_s = (jnp.linalg.norm(dyn_on_s, axis=-1) ** 2).flatten()
                 else:
@@ -185,7 +220,7 @@ def _rar_step_init(sample_size, selected_sample_size):
                     0,
                     0,
                 )
-                dyn_on_s = v_dyn_loss(s)
+                dyn_on_s = v_dyn_loss(new_omega_samples)
                 if dyn_on_s.ndim > 1:
                     mse_on_s += (jnp.linalg.norm(dyn_on_s, axis=-1) ** 2).flatten()
                 else:
@@ -197,12 +232,10 @@ def _rar_step_init(sample_size, selected_sample_size):
                 (mse_on_s.shape[0] - selected_sample_size,),
                 (selected_sample_size,),
             )
-            higher_residual_points = s[higher_residual_idx]
+            higher_residual_points = new_omega_samples[higher_residual_idx]
 
-            data.rar_iter_from_last_sampling = 0
-
-            ## add the new points in times
-            # start indices of update can be dynamic but the the shape (length)
+            ## add the new points in omega
+            # start indices of update can be dynamic but not the shape (length)
             # of the slice
             data.omega = jax.lax.dynamic_update_slice(
                 data.omega,
@@ -214,7 +247,7 @@ def _rar_step_init(sample_size, selected_sample_size):
             ## points are non-zero
             new_proba = 1 / (data.n_start + data.rar_iter_nb * selected_sample_size)
             # the next works because n_start is static
-            data.p = data.p.at[: data.n_start].set(new_proba)
+            data.p_omega = data.p_omega.at[: data.n_start].set(new_proba)
 
             # the next requires a fori_loop because the range is dynamic
             def update_slices(i, p):
@@ -226,145 +259,169 @@ def _rar_step_init(sample_size, selected_sample_size):
 
             data.rar_iter_nb += 1
 
-            data.p = jax.lax.fori_loop(0, data.rar_iter_nb, update_slices, data.p)
-
-            # NOTE must return data to be correctly updated because we cannot
-            # have side effects in this function that will be jitted
-            return data
+            data.p_omega = jax.lax.fori_loop(
+                0, data.rar_iter_nb, update_slices, data.p_omega
+            )
 
         elif isinstance(data, CubicMeshPDENonStatio):
-            st = data.sample_in_time_domain(sample_size)
-            sx = data.sample_in_omega_domain(sample_size)
-
-            # According to the Loss type we have different syntax to call the
-            # dynamic_loss evaluate function
-            if isinstance(loss, LossPDEStatio) and not isinstance(
-                loss, LossPDENonStatio
-            ):
-                # This case might not happen very often...
-                v_dyn_loss = vmap(
-                    lambda x: loss.dynamic_loss.evaluate(
-                        x,
-                        loss.u,
-                        params,
-                    ),
-                    (0),
-                    0,
-                )
-                dyn_on_s = v_dyn_loss(sx)
-                if dyn_on_s.ndim > 1:
-                    mse_on_s = (jnp.linalg.norm(dyn_on_s, axis=-1) ** 2).flatten()
-                else:
-                    mse_on_s = dyn_on_s**2
-            elif isinstance(loss, LossPDENonStatio):
+            # NOTE in this case sample_size and selected_sample_size
+            # are tuples (times, omega) => we unpack them for clarity
+            selected_sample_size_times, selected_sample_size_omega = (
+                selected_sample_size
+            )
+            sample_size_times, sample_size_omega = sample_size
+
+            new_times_samples = data.sample_in_time_domain(sample_size_times)
+            new_omega_samples = data.sample_in_omega_domain(sample_size_omega)
+
+            if isinstance(loss.u, HYPERPINN) or isinstance(loss.u, SPINN):
+                raise NotImplementedError("RAR not implemented for hyperPINN and SPINN")
+            else:
+                # do cartesian product on new points
+                tile_omega = jnp.tile(
+                    new_omega_samples, reps=(sample_size_times, 1)
+                )  # it is tiled
+                repeat_times = jnp.repeat(new_times_samples, sample_size_omega, axis=0)[
+                    ..., None
+                ]  # it is repeated + add an axis
+
+            if isinstance(loss, LossPDENonStatio):
                 v_dyn_loss = vmap(
                     lambda t, x: loss.dynamic_loss.evaluate(t, x, loss.u, params),
                     (0, 0),
                     0,
                 )
-                dyn_on_s = v_dyn_loss(st[..., None], sx)
-                if dyn_on_s.ndim > 1:
-                    mse_on_s = (jnp.linalg.norm(dyn_on_s, axis=-1) ** 2).flatten()
-                else:
-                    mse_on_s = dyn_on_s**2
+                dyn_on_s = v_dyn_loss(repeat_times, tile_omega).reshape(
+                    (sample_size_times, sample_size_omega)
+                )
+                mse_on_s = dyn_on_s**2
             elif isinstance(loss, SystemLossPDE):
-                mse_on_s = 0
+                dyn_on_s = jnp.zeros((sample_size_times, sample_size_omega))
                 for i in loss.dynamic_loss_dict.keys():
-                    if isinstance(loss.dynamic_loss_dict[i], PDEStatio):
-                        v_dyn_loss = vmap(
-                            lambda x: loss.dynamic_loss_dict[i].evaluate(
-                                x, loss.u_dict, params
-                            ),
-                            0,
-                            0,
-                        )
-                        dyn_on_s = v_dyn_loss(sx)
-                        if dyn_on_s.ndim > 1:
-                            mse_on_s += (
-                                jnp.linalg.norm(dyn_on_s, axis=-1) ** 2
-                            ).flatten()
-                        else:
-                            mse_on_s += dyn_on_s**2
-                    else:
-                        v_dyn_loss = vmap(
-                            lambda t, x: loss.dynamic_loss_dict[i].evaluate(
-                                t, x, loss.u_dict, params
-                            ),
-                            (0, 0),
-                            0,
-                        )
-                        dyn_on_s = v_dyn_loss(st[..., None], sx)
-                        if dyn_on_s.ndim > 1:
-                            mse_on_s += (
-                                jnp.linalg.norm(dyn_on_s, axis=-1) ** 2
-                            ).flatten()
-                        else:
-                            mse_on_s += dyn_on_s**2
-
-            ## Now that we have the residuals, select the m points
-            # with higher dynamic loss (residuals)
-            higher_residual_idx = jax.lax.dynamic_slice(
-                jnp.argsort(mse_on_s),
-                (mse_on_s.shape[0] - selected_sample_size,),
-                (selected_sample_size,),
-            )
-            higher_residual_points_st = st[higher_residual_idx]
-            higher_residual_points_sx = sx[higher_residual_idx]
+                    v_dyn_loss = vmap(
+                        lambda t, x: loss.dynamic_loss_dict[i].evaluate(
+                            t, x, loss.u_dict, params
+                        ),
+                        (0, 0),
+                        0,
+                    )
+                    dyn_on_s += v_dyn_loss(repeat_times, tile_omega).reshape(
+                        (sample_size_times, sample_size_omega)
+                    )
 
-            data.rar_iter_from_last_sampling = 0
+                mse_on_s = dyn_on_s**2
+            # -- Select the m points with highest average residuals on time and
+            # -- space (times in rows / omega in columns)
+            # mean_times = mse_on_s.mean(axis=1)
+            # mean_omega = mse_on_s.mean(axis=0)
+            # times_idx = jax.lax.dynamic_slice(
+            #     jnp.argsort(mean_times),
+            #     (mse_on_s.shape[0] - selected_sample_size_times,),
+            #     (selected_sample_size_times,),
+            # )
+            # omega_idx = jax.lax.dynamic_slice(
+            #     jnp.argsort(mean_omega),
+            #     (mse_on_s.shape[1] - selected_sample_size_omega,),
+            #     (selected_sample_size_omega,),
+            # )
+
+            # -- Select the m worst points (t, x) with highest residuals
+            n_select = max(selected_sample_size_times, selected_sample_size_omega)
+            _, idx = jax.lax.top_k(mse_on_s.flatten(), k=n_select)
+            arr_idx = jnp.unravel_index(idx, mse_on_s.shape)
+            times_idx = arr_idx[0][:selected_sample_size_times]
+            omega_idx = arr_idx[1][:selected_sample_size_omega]
+
+            higher_residual_points_times = new_times_samples[times_idx]
+            higher_residual_points_omega = new_omega_samples[omega_idx]
 
             ## add the new points in times
-            # start indices of update can be dynamic but the the shape (length)
+            # start indices of update can be dynamic but not the shape (length)
             # of the slice
             data.times = jax.lax.dynamic_update_slice(
                 data.times,
-                higher_residual_points_st,
-                (data.n_start + data.rar_iter_nb * selected_sample_size,),
+                higher_residual_points_times,
+                (data.n_start + data.rar_iter_nb * selected_sample_size_times,),
             )
 
             ## add the new points in omega
             data.omega = jax.lax.dynamic_update_slice(
                 data.omega,
-                higher_residual_points_sx,
+                higher_residual_points_omega,
                 (
-                    data.n_start + data.rar_iter_nb * selected_sample_size,
+                    data.n_start + data.rar_iter_nb * selected_sample_size_omega,
                     data.dim,
                 ),
             )
 
             ## rearrange probabilities so that the probabilities of the new
             ## points are non-zero
-            new_proba = 1 / (data.n_start + data.rar_iter_nb * selected_sample_size)
+            new_p_times = 1 / (
+                data.nt_start + data.rar_iter_nb * selected_sample_size_times
+            )
             # the next works because nt_start is static
-            data.p = data.p.at[: data.nt_start].set(new_proba)
+            data.p_times = data.p_times.at[: data.nt_start].set(new_p_times)
 
-            # the next requires a fori_loop because the range is dynamic
-            def update_slices(i, p):
-                return jax.lax.dynamic_update_slice(
-                    p,
-                    1 / new_proba * jnp.ones((selected_sample_size,)),
-                    ((data.n_start + i * selected_sample_size),),
-                )
+            # same for p_omega (works because n_start is static)
+            new_p_omega = 1 / (
+                data.n_start + data.rar_iter_nb * selected_sample_size_omega
+            )
+            data.p_omega = data.p_omega.at[: data.n_start].set(new_p_omega)
+
+            # the part of data.p_* after n_start requires a fori_loop because
+            # the range is dynamic
+            def create_update_slices(new_val, selected_sample_size):
+                def update_slices(i, p):
+                    new_p = jax.lax.dynamic_update_slice(
+                        p,
+                        new_val * jnp.ones((selected_sample_size,)),
+                        ((data.n_start + i * selected_sample_size),),
+                    )
+                    return new_p
+
+                return update_slices
 
             data.rar_iter_nb += 1
 
-            data.p = jax.lax.fori_loop(0, data.rar_iter_nb, update_slices, data.p)
+            ## update rest of p_times
+            update_slices_times = create_update_slices(
+                new_p_times, selected_sample_size_times
+            )
+            data.p_times = jax.lax.fori_loop(
+                0,
+                data.rar_iter_nb,
+                update_slices_times,
+                data.p_times,
+            )
+            ## update rest of p_omega
+            update_slices_omega = create_update_slices(
+                new_p_omega, selected_sample_size_omega
+            )
+            data.p_omega = jax.lax.fori_loop(
+                0,
+                data.rar_iter_nb,
+                update_slices_omega,
+                data.p_omega,
+            )
 
-            # NOTE must return data to be correctly updated because we cannot
-            # have side effects in this function that will be jitted
-            return data
+        # update RAR parameters for all cases
+        data.rar_iter_from_last_sampling = 0
+
+        # NOTE must return data to be correctly updated because we cannot
+        # have side effects in this function that will be jitted
+        return data
 
     def rar_step_false(operands):
         _, _, data, i = operands
 
         # Add 1 only if we are after the burn-in period
-        data.rar_iter_from_last_sampling = jax.lax.cond(
-            i < data.rar_parameters["start_iter"],
-            lambda operand: 0,
-            lambda operand: operand + 1,
-            (data.rar_iter_from_last_sampling),
+        increment = jax.lax.cond(
+            i <= data.rar_parameters["start_iter"],
+            lambda: 0,
+            lambda: 1,
         )
 
+        data.rar_iter_from_last_sampling += increment
         return data
 
     return rar_step_true, rar_step_false
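The non-stationary branch above selects the worst (t, x) pairs by flattening the residual grid, taking the top-k entries, and unravelling them back into time and omega indices; a self-contained illustration with hypothetical numbers:

```python
import jax
import jax.numpy as jnp

# Hypothetical 3x4 residual grid: rows index candidate times, columns index
# candidate omega points, mirroring mse_on_s above.
mse_on_s = jnp.array(
    [[0.10, 0.90, 0.20, 0.30],
     [0.80, 0.05, 0.40, 0.70],
     [0.20, 0.60, 0.95, 0.10]]
)
_, idx = jax.lax.top_k(mse_on_s.flatten(), k=3)  # the 3 largest residuals
times_idx, omega_idx = jnp.unravel_index(idx, mse_on_s.shape)
print(times_idx)  # [2 0 1] -> time rows of the 3 worst (t, x) points
print(omega_idx)  # [2 1 0] -> omega columns of the 3 worst (t, x) points
```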
jinns/solver/_seq2seq.py CHANGED
@@ -88,7 +88,7 @@ def initialize_seq2seq(loss, data, seq2seq, opt_state):
     data.curr_omega_idx = 0
     data.generate_time_data()
     data._key, data.times, _ = _reset_batch_idx_and_permute(
-        (data._key, data.times, data.curr_omega_idx, None, data.p)
+        (data._key, data.times, data.curr_omega_idx, None, data.p_times)
     )
     opt_state.hyperparams["learning_rate"] = seq2seq["learning_rate"][curr_seq]
 
@@ -145,7 +145,7 @@ def _update_seq2seq_SystemLossODE(operands):
     data.curr_omega_idx = 0
     data.generate_time_data()
     data._key, data.times, _ = _reset_batch_idx_and_permute(
-        (data._key, data.times, data.curr_omega_idx, None, data.p)
+        (data._key, data.times, data.curr_omega_idx, None, data.p_times)
     )
 
     opt_state.hyperparams["learning_rate"] = seq2seq["learning_rate"][curr_seq]
jinns-0.8.8.dist-info/METADATA → jinns-0.8.10.dist-info/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: jinns
-Version: 0.8.8
+Version: 0.8.10
 Summary: Physics Informed Neural Network with JAX
 Author-email: Hugo Gangloff <hugo.gangloff@inrae.fr>, Nicolas Jouvin <nicolas.jouvin@inrae.fr>
 Maintainer-email: Hugo Gangloff <hugo.gangloff@inrae.fr>, Nicolas Jouvin <nicolas.jouvin@inrae.fr>
jinns-0.8.8.dist-info/RECORD → jinns-0.8.10.dist-info/RECORD RENAMED
@@ -1,5 +1,5 @@
 jinns/__init__.py,sha256=T2XlmLbYqcXTumPJL00cJ80W98We5LH8Yg_Lss_exl4,139
-jinns/data/_DataGenerators.py,sha256=N4-U4z3MG46UIzHCbKScv9Z7AN40w1wlLY_VsVNj2sI,62293
+jinns/data/_DataGenerators.py,sha256=_um-giHQ8mCILUOJHX231njHTHZp4S7EcGrUs7R1dUs,61829
 jinns/data/__init__.py,sha256=yBOmoavSD-cABp4XcjQY1zsEVO0mDyIhi2MJ5WNp0l8,326
 jinns/data/_display.py,sha256=vlqggDCgVMEwdGBtjVmZaTQORU6imSfDkssn2XCtITI,10392
 jinns/experimental/__init__.py,sha256=qWbhC7Z8UgLWy0t-zU7RYze6v13-FngiCYXu-2bRVFQ,296
@@ -8,15 +8,15 @@ jinns/experimental/_sinuspinn.py,sha256=hxSzscwMV2LayWOqenIlT1zqEVVrE5Y8CKf7bHX5
 jinns/experimental/_spectralpinn.py,sha256=-4795pa7AYtRNSE-ugan3gHh64mtu2VdrRG5AS_J9Eg,2654
 jinns/loss/_DynamicLoss.py,sha256=L4CVmmF0rTPbHntgqsLLHlnrlQgLHsetUocpJm7ZYag,27461
 jinns/loss/_DynamicLossAbstract.py,sha256=kTQlhLx7SBuH5dIDmYaE79sVHUZt1nUFa8LxPU5IHhM,8504
-jinns/loss/_LossODE.py,sha256=b9doBHoQwYvlgpqzrNO4dOaTN87LRvjHtHbz9bMoH7E,22119
+jinns/loss/_LossODE.py,sha256=Y1mxryPVFf7ruqw_mGNACLExfx4iQT4R2bZP3s5rg4c,22172
 jinns/loss/_LossPDE.py,sha256=purAEtc0e71kv9XnZUT-a7MrkDAkM_3tTI4xJPu6fH4,61629
 jinns/loss/_Losses.py,sha256=XOL3MFiKEd3ndsc78Qnpi1vbgR0B2HaAWOGGW2meDM8,11190
 jinns/loss/__init__.py,sha256=pFNYUxns-NPXBFdqrEVSiXkQLfCtKw-t2trlhvLzpYE,355
 jinns/loss/_boundary_conditions.py,sha256=YfSnLZ25hXqQ5KWAuxOrWSKkf_oBqAc9GQV4z7MjWyQ,17434
 jinns/loss/_operators.py,sha256=zDGJqYqeYH7xd-4dtGX9PS-pf0uSOpUUXGo5SVjIJ4o,11069
 jinns/solver/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-jinns/solver/_rar.py,sha256=K-0y1-ofOAo1n_Ea3QShSGCGKVYTwiaE_Bz9-DZMJm8,14525
-jinns/solver/_seq2seq.py,sha256=FL-42hTgmVl7O3hHh1ccFVw2bT8bW82hvlDRz971Chk,5620
+jinns/solver/_rar.py,sha256=IYP-jdbM0rbjBtxislrBYBuj49p9_QDOqejZKCHrKg8,17072
+jinns/solver/_seq2seq.py,sha256=S6IPfsXpS_fbqIqAy01eUM7GBSBSkRzURan_J-iXXzI,5632
 jinns/solver/_solve.py,sha256=mGi0zaT_fK_QpBjTxof5Ix4mmfmnPi66CNJ3GQFZuo4,19099
 jinns/utils/__init__.py,sha256=44ms5UR6vMw3Nf6u4RCAzPFs4fom_YbBnH9mfne8m6k,313
 jinns/utils/_containers.py,sha256=eYD277fO7X4EfX7PUFCCl69r3JBfh1sCfq8LkL5gd6o,1495
@@ -29,8 +29,8 @@ jinns/utils/_utils.py,sha256=8dgvWXX9NT7_7-zltWp0C9tG45ZFNwXxueyxPBb4hjo,6740
 jinns/utils/_utils_uspinn.py,sha256=qcKcOw3zrwWSQyGVj6fD8c9GinHt_U6JWN_k0auTtXM,26039
 jinns/validation/__init__.py,sha256=Jv58mzgC3F7cRfXA6caicL1t_U0UAhbwLrmMNVg6E7s,66
 jinns/validation/_validation.py,sha256=KfetbzB0xTNdBcYLwFWjEtP63Tf9wJirlhgqLTJDyy4,6761
-jinns-0.8.8.dist-info/LICENSE,sha256=BIAkGtXB59Q_BG8f6_OqtQ1BHPv60ggE9mpXJYz2dRM,11337
-jinns-0.8.8.dist-info/METADATA,sha256=oTs2EJMu4Bwo2n9DLsAPSU5edpbgPtwhNXBuW8YjpOc,2482
-jinns-0.8.8.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
-jinns-0.8.8.dist-info/top_level.txt,sha256=RXbkr2hzy8WBE8aiRyrJYFqn3JeMJIhMdybLjjLTB9c,6
-jinns-0.8.8.dist-info/RECORD,,
+jinns-0.8.10.dist-info/LICENSE,sha256=BIAkGtXB59Q_BG8f6_OqtQ1BHPv60ggE9mpXJYz2dRM,11337
+jinns-0.8.10.dist-info/METADATA,sha256=5lGoyi2W9MRamQdHVgZnYflJtp__zWDGyTiYgmfGc6g,2483
+jinns-0.8.10.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+jinns-0.8.10.dist-info/top_level.txt,sha256=RXbkr2hzy8WBE8aiRyrJYFqn3JeMJIhMdybLjjLTB9c,6
+jinns-0.8.10.dist-info/RECORD,,