jinns 0.8.7__py3-none-any.whl → 0.8.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
jinns/__init__.py CHANGED
@@ -2,4 +2,5 @@ import jinns.data
2
2
  import jinns.loss
3
3
  import jinns.solver
4
4
  import jinns.utils
5
+ import jinns.experimental
5
6
  from jinns.solver._solve import solve
@@ -98,6 +98,33 @@ def _reset_or_increment(bend, n_eff, operands):
98
98
  )
99
99
 
100
100
 
101
+ def _check_and_set_rar_parameters(rar_parameters, n, n_start):
102
+ if rar_parameters is not None and n_start is None:
103
+ raise ValueError(
104
+ f"n_start or/and nt_start must be provided in the context of RAR sampling scheme, {n_start} was provided"
105
+ )
106
+ if rar_parameters is not None:
107
+ # Default p is None. However, in the RAR sampling scheme we use 0
108
+ # probability to specify non-used collocation points (i.e. points
109
+ # above n_start). Thus, p is a vector of probability of shape (n, 1).
110
+ p = jnp.zeros((n,))
111
+ p = p.at[:n_start].set(1 / n_start)
112
+ # set internal counter for the number of gradient steps since the
113
+ # last new collocation points have been added
114
+ # It is not 0 to ensure the first iteration of RAR happens just
115
+ # after start_iter. See the _proceed_to_rar() function in _rar.py
116
+ rar_iter_from_last_sampling = rar_parameters["update_every"] - 1
117
+ # set iternal counter for the number of times collocation points
118
+ # have been added
119
+ rar_iter_nb = 0
120
+ else:
121
+ p = None
122
+ rar_iter_from_last_sampling = None
123
+ rar_iter_nb = None
124
+
125
+ return n_start, p, rar_iter_from_last_sampling, rar_iter_nb
126
+
127
+
101
128
  #####################################################
102
129
  # DataGenerator for ODE : only returns time_batches
103
130
  #####################################################
@@ -150,10 +177,10 @@ class DataGeneratorODE:
150
177
  Default to None: do not use Residual Adaptative Resampling.
151
178
  Otherwise a dictionary with keys. `start_iter`: the iteration at
152
179
  which we start the RAR sampling scheme (we first have a burn in
153
- period). `update_rate`: the number of gradient steps taken between
180
+ period). `update_every`: the number of gradient steps taken between
154
181
  each appending of collocation points in the RAR algo.
155
- `sample_size`: the size of the sample from which we will select new
156
- collocation points. `selected_sample_size`: the number of selected
182
+ `sample_size_times`: the size of the sample from which we will select new
183
+ collocation points. `selected_sample_size_times`: the number of selected
157
184
  points from the sample to be added to the current collocation
158
185
  points
159
186
  "DeepXDE: A deep learning library for solving differential
@@ -179,29 +206,13 @@ class DataGeneratorODE:
179
206
  self.method = method
180
207
  self.rar_parameters = rar_parameters
181
208
 
182
- if rar_parameters is not None and nt_start is None:
183
- raise ValueError(
184
- "nt_start must be provided in the context of RAR sampling scheme"
185
- )
186
- if rar_parameters is not None:
187
- self.nt_start = nt_start
188
- # Default p is None. However, in the RAR sampling scheme we use 0
189
- # probability to specify non-used collocation points (i.e. points
190
- # above nt_start). Thus, p is a vector of probability of shape (nt, 1).
191
- self.p = jnp.zeros((self.nt,))
192
- self.p = self.p.at[: self.nt_start].set(1 / nt_start)
193
- # set internal counter for the number of gradient steps since the
194
- # last new collocation points have been added
195
- self.rar_iter_from_last_sampling = 0
196
- # set iternal counter for the number of times collocation points
197
- # have been added
198
- self.rar_iter_nb = 0
199
-
200
- if rar_parameters is None or nt_start is None:
201
- self.nt_start = self.nt
202
- self.p = None
203
- self.rar_iter_from_last_sampling = None
204
- self.rar_iter_nb = None
209
+ # Set-up for RAR (if used)
210
+ (
211
+ self.nt_start,
212
+ self.p_times,
213
+ self.rar_iter_from_last_sampling,
214
+ self.rar_iter_nb,
215
+ ) = _check_and_set_rar_parameters(rar_parameters, n=nt, n_start=nt_start)
205
216
 
206
217
  if not self.data_exists:
207
218
  # Useful when using a lax.scan with pytree
@@ -238,7 +249,7 @@ class DataGeneratorODE:
238
249
  self.times,
239
250
  self.curr_time_idx,
240
251
  self.temporal_batch_size,
241
- self.p,
252
+ self.p_times,
242
253
  )
243
254
 
244
255
  def temporal_batch(self):
@@ -253,7 +264,7 @@ class DataGeneratorODE:
253
264
  if self.rar_parameters is not None:
254
265
  nt_eff = (
255
266
  self.nt_start
256
- + self.rar_iter_nb * self.rar_parameters["selected_sample_size"]
267
+ + self.rar_iter_nb * self.rar_parameters["selected_sample_size_omega"]
257
268
  )
258
269
  else:
259
270
  nt_eff = self.nt
@@ -283,14 +294,19 @@ class DataGeneratorODE:
283
294
  self.curr_time_idx,
284
295
  self.tmin,
285
296
  self.tmax,
286
- self.rar_parameters,
287
- self.p,
297
+ self.p_times,
288
298
  self.rar_iter_from_last_sampling,
289
299
  self.rar_iter_nb,
290
300
  ) # arrays / dynamic values
291
301
  aux_data = {
292
302
  k: vars(self)[k]
293
- for k in ["temporal_batch_size", "method", "nt", "nt_start"]
303
+ for k in [
304
+ "temporal_batch_size",
305
+ "method",
306
+ "nt",
307
+ "rar_parameters",
308
+ "nt_start",
309
+ ]
294
310
  } # static values
295
311
  return (children, aux_data)
296
312
 
@@ -308,8 +324,7 @@ class DataGeneratorODE:
308
324
  curr_time_idx,
309
325
  tmin,
310
326
  tmax,
311
- rar_parameters,
312
- p,
327
+ p_times,
313
328
  rar_iter_from_last_sampling,
314
329
  rar_iter_nb,
315
330
  ) = children
@@ -318,12 +333,11 @@ class DataGeneratorODE:
318
333
  data_exists=True,
319
334
  tmin=tmin,
320
335
  tmax=tmax,
321
- rar_parameters=rar_parameters,
322
336
  **aux_data,
323
337
  )
324
338
  obj.times = times
325
339
  obj.curr_time_idx = curr_time_idx
326
- obj.p = p
340
+ obj.p_times = p_times
327
341
  obj.rar_iter_from_last_sampling = rar_iter_from_last_sampling
328
342
  obj.rar_iter_nb = rar_iter_nb
329
343
  return obj
@@ -412,10 +426,10 @@ class CubicMeshPDEStatio(DataGeneratorPDEAbstract):
412
426
  Default to None: do not use Residual Adaptative Resampling.
413
427
  Otherwise a dictionary with keys. `start_iter`: the iteration at
414
428
  which we start the RAR sampling scheme (we first have a burn in
415
- period). `update_rate`: the number of gradient steps taken between
429
+ period). `update_every`: the number of gradient steps taken between
416
430
  each appending of collocation points in the RAR algo.
417
- `sample_size`: the size of the sample from which we will select new
418
- collocation points. `selected_sample_size`: the number of selected
431
+ `sample_size_omega`: the size of the sample from which we will select new
432
+ collocation points. `selected_sample_size_omega`: the number of selected
419
433
  points from the sample to be added to the current collocation
420
434
  points
421
435
  "DeepXDE: A deep learning library for solving differential
@@ -442,30 +456,12 @@ class CubicMeshPDEStatio(DataGeneratorPDEAbstract):
442
456
  assert dim == len(max_pts) and isinstance(max_pts, tuple)
443
457
  self.n = n
444
458
  self.rar_parameters = rar_parameters
445
-
446
- if rar_parameters is not None and n_start is None:
447
- raise ValueError(
448
- "n_start must be provided in the context of RAR sampling scheme"
449
- )
450
- if rar_parameters is not None:
451
- self.n_start = n_start
452
- # Default p is None. However, in the RAR sampling scheme we use 0
453
- # probability to specify non-used collocation points (i.e. points
454
- # above n_start). Thus, p is a vector of probability of shape (n, 1).
455
- self.p = jnp.zeros((self.n,))
456
- self.p = self.p.at[: self.n_start].set(1 / n_start)
457
- # set internal counter for the number of gradient steps since the
458
- # last new collocation points have been added
459
- self.rar_iter_from_last_sampling = 0
460
- # set iternal counter for the number of times collocation points
461
- # have been added
462
- self.rar_iter_nb = 0
463
-
464
- if rar_parameters is None or n_start is None:
465
- self.n_start = self.n
466
- self.p = None
467
- self.rar_iter_from_last_sampling = None
468
- self.rar_iter_nb = None
459
+ (
460
+ self.n_start,
461
+ self.p_omega,
462
+ self.rar_iter_from_last_sampling,
463
+ self.rar_iter_nb,
464
+ ) = _check_and_set_rar_parameters(rar_parameters, n=n, n_start=n_start)
469
465
 
470
466
  self.p_border = None # no RAR sampling for border for now
471
467
 
@@ -643,7 +639,7 @@ class CubicMeshPDEStatio(DataGeneratorPDEAbstract):
643
639
  self.omega,
644
640
  self.curr_omega_idx,
645
641
  self.omega_batch_size,
646
- self.p,
642
+ self.p_omega,
647
643
  )
648
644
 
649
645
  def inside_batch(self):
@@ -656,7 +652,7 @@ class CubicMeshPDEStatio(DataGeneratorPDEAbstract):
656
652
  if self.rar_parameters is not None:
657
653
  n_eff = (
658
654
  self.n_start
659
- + self.rar_iter_nb * self.rar_parameters["selected_sample_size"]
655
+ + self.rar_iter_nb * self.rar_parameters["selected_sample_size_omega"]
660
656
  )
661
657
  else:
662
658
  n_eff = self.n
@@ -745,8 +741,7 @@ class CubicMeshPDEStatio(DataGeneratorPDEAbstract):
745
741
  self.curr_omega_border_idx,
746
742
  self.min_pts,
747
743
  self.max_pts,
748
- self.rar_parameters,
749
- self.p,
744
+ self.p_omega,
750
745
  self.rar_iter_from_last_sampling,
751
746
  self.rar_iter_nb,
752
747
  )
@@ -759,6 +754,7 @@ class CubicMeshPDEStatio(DataGeneratorPDEAbstract):
759
754
  "omega_border_batch_size",
760
755
  "method",
761
756
  "dim",
757
+ "rar_parameters",
762
758
  "n_start",
763
759
  ]
764
760
  }
@@ -780,8 +776,7 @@ class CubicMeshPDEStatio(DataGeneratorPDEAbstract):
780
776
  curr_omega_border_idx,
781
777
  min_pts,
782
778
  max_pts,
783
- rar_parameters,
784
- p,
779
+ p_omega,
785
780
  rar_iter_from_last_sampling,
786
781
  rar_iter_nb,
787
782
  ) = children
@@ -792,14 +787,13 @@ class CubicMeshPDEStatio(DataGeneratorPDEAbstract):
792
787
  data_exists=True,
793
788
  min_pts=min_pts,
794
789
  max_pts=max_pts,
795
- rar_parameters=rar_parameters,
796
790
  **aux_data,
797
791
  )
798
792
  obj.omega = omega
799
793
  obj.omega_border = omega_border
800
794
  obj.curr_omega_idx = curr_omega_idx
801
795
  obj.curr_omega_border_idx = curr_omega_border_idx
802
- obj.p = p
796
+ obj.p_omega = p_omega
803
797
  obj.rar_iter_from_last_sampling = rar_iter_from_last_sampling
804
798
  obj.rar_iter_nb = rar_iter_nb
805
799
  return obj
@@ -834,6 +828,7 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
834
828
  method="grid",
835
829
  rar_parameters=None,
836
830
  n_start=None,
831
+ nt_start=None,
837
832
  data_exists=False,
838
833
  ):
839
834
  r"""
@@ -887,27 +882,23 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
887
882
  Default to None: do not use Residual Adaptative Resampling.
888
883
  Otherwise a dictionary with keys. `start_iter`: the iteration at
889
884
  which we start the RAR sampling scheme (we first have a burn in
890
- period). `update_rate`: the number of gradient steps taken between
885
+ period). `update_every`: the number of gradient steps taken between
891
886
  each appending of collocation points in the RAR algo.
892
- `sample_size`: the size of the sample from which we will select new
893
- collocation points. `selected_sample_size`: the number of selected
887
+ `sample_size_omega`: the size of the sample from which we will select new
888
+ collocation points. `selected_sample_size_omega`: the number of selected
894
889
  points from the sample to be added to the current collocation
895
890
  points.
896
- __Note:__ that if RAR sampling is chosen it will currently affect both
897
- self.times and self.omega with the same hyperparameters
898
- (rar_parameters and n_start)
899
- "DeepXDE: A deep learning library for solving differential
900
- equations", L. Lu, SIAM Review, 2021
901
891
  n_start
902
892
  Defaults to None. The effective size of n used at start time.
903
893
  This value must be
904
894
  provided when rar_parameters is not None. Otherwise we set internally
905
895
  n_start = n and this is hidden from the user.
906
896
  In RAR, n_start
907
- then corresponds to the initial number of points we train the PINN.
908
- __Note:__ that if RAR sampling is chosen it will currently affect both
909
- self.times and self.omega with the same hyperparameters
910
- (rar_parameters and n_start)
897
+ then corresponds to the initial number of omega points on which we train the PINN.
898
+ nt_start
899
+ Defaults to None. A RAR hyper-parameter. Same as ``n_start`` but
900
+ for the time collocation points. See also ``DataGeneratorODE``
901
+ documentation.
911
902
  data_exists
912
903
  Must be left to `False` when created by the user. Avoids the
913
904
  regeneration of :math:`\Omega`, :math:`\partial\Omega` and
@@ -931,6 +922,15 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
931
922
  self.tmin = tmin
932
923
  self.tmax = tmax
933
924
  self.nt = nt
925
+
926
+ # Set-up for timewise RAR (some quantities are already set up by super())
927
+ (
928
+ self.nt_start,
929
+ self.p_times,
930
+ _,
931
+ _,
932
+ ) = _check_and_set_rar_parameters(rar_parameters, n=nt, n_start=nt_start)
933
+
934
934
  if not self.data_exists:
935
935
  # Useful when using a lax.scan with pytree
936
936
  # Optionally can tell JAX not to re-generate data
@@ -950,7 +950,7 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
950
950
  self.times,
951
951
  self.curr_time_idx,
952
952
  self.temporal_batch_size,
953
- self.p,
953
+ self.p_times,
954
954
  )
955
955
 
956
956
  def generate_data_nonstatio(self):
@@ -980,7 +980,7 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
980
980
  if self.rar_parameters is not None:
981
981
  nt_eff = (
982
982
  self.n_start
983
- + self.rar_iter_nb * self.rar_parameters["selected_sample_size"]
983
+ + self.rar_iter_nb * self.rar_parameters["selected_sample_size_times"]
984
984
  )
985
985
  else:
986
986
  nt_eff = self.nt
@@ -1022,8 +1022,8 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
1022
1022
  self.max_pts,
1023
1023
  self.tmin,
1024
1024
  self.tmax,
1025
- self.rar_parameters,
1026
- self.p,
1025
+ self.p_times,
1026
+ self.p_omega,
1027
1027
  self.rar_iter_from_last_sampling,
1028
1028
  self.rar_iter_nb,
1029
1029
  )
@@ -1038,7 +1038,9 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
1038
1038
  "temporal_batch_size",
1039
1039
  "method",
1040
1040
  "dim",
1041
+ "rar_parameters",
1041
1042
  "n_start",
1043
+ "nt_start",
1042
1044
  ]
1043
1045
  }
1044
1046
  return (children, aux_data)
@@ -1063,8 +1065,8 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
1063
1065
  max_pts,
1064
1066
  tmin,
1065
1067
  tmax,
1066
- rar_parameters,
1067
- p,
1068
+ p_times,
1069
+ p_omega,
1068
1070
  rar_iter_from_last_sampling,
1069
1071
  rar_iter_nb,
1070
1072
  ) = children
@@ -1075,7 +1077,6 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
1075
1077
  max_pts=max_pts,
1076
1078
  tmin=tmin,
1077
1079
  tmax=tmax,
1078
- rar_parameters=rar_parameters,
1079
1080
  **aux_data,
1080
1081
  )
1081
1082
  obj.omega = omega
@@ -1084,9 +1085,11 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
1084
1085
  obj.curr_omega_idx = curr_omega_idx
1085
1086
  obj.curr_omega_border_idx = curr_omega_border_idx
1086
1087
  obj.curr_time_idx = curr_time_idx
1087
- obj.p = p
1088
+ obj.p_times = p_times
1089
+ obj.p_omega = p_omega
1088
1090
  obj.rar_iter_from_last_sampling = rar_iter_from_last_sampling
1089
1091
  obj.rar_iter_nb = rar_iter_nb
1092
+
1090
1093
  return obj
1091
1094
 
1092
1095
 
jinns/data/_display.py CHANGED
@@ -1,8 +1,13 @@
1
+ """
2
+ Utility functions for plotting
3
+ """
4
+
5
+ from functools import partial
6
+ import warnings
1
7
  import matplotlib.pyplot as plt
2
8
  import jax.numpy as jnp
3
9
  from jax import vmap
4
10
  from mpl_toolkits.axes_grid1 import ImageGrid
5
- from functools import partial
6
11
 
7
12
 
8
13
  def plot2d(
@@ -14,8 +19,10 @@ def plot2d(
14
19
  figsize=(7, 7),
15
20
  cmap="inferno",
16
21
  spinn=False,
22
+ vmin_vmax=None,
23
+ ax_for_plot=None,
17
24
  ):
18
- """Generic function for plotting functions over rectangular 2-D domains
25
+ r"""Generic function for plotting functions over rectangular 2-D domains
19
26
  :math:`\Omega`. It treats both the stationary case :math:`u(x)` or the
20
27
  non-stationnary case :math:`u(t, x)`.
21
28
 
@@ -40,6 +47,12 @@ def plot2d(
40
47
  _description_, by default (7, 7)
41
48
  cmap : str, optional
42
49
  _description_, by default "inferno"
50
+ vmin_vmax : tuple, optional
51
+ The colorbar minimum and maximum values. Defaults to None.
52
+ ax_for_plot : Matplotlib axis, optional
53
+ If None, jinns triggers the plotting. Otherwise this argument
54
+ corresponds to the axis which will host the plot. Default is None.
55
+ NOTE: this argument only has an effect if times is None.
43
56
 
44
57
  Raises
45
58
  ------
@@ -59,25 +72,49 @@ def plot2d(
59
72
  # Statio case : expect a function of one argument fun(x)
60
73
  if not spinn:
61
74
  v_fun = vmap(fun, 0, 0)
62
- _plot_2D_statio(
63
- v_fun, mesh, plot=True, colorbar=True, cmap=cmap, figsize=figsize
75
+ ret = _plot_2D_statio(
76
+ v_fun,
77
+ mesh,
78
+ plot=not ax_for_plot,
79
+ colorbar=True,
80
+ cmap=cmap,
81
+ figsize=figsize,
82
+ vmin_vmax=vmin_vmax,
64
83
  )
65
84
  elif spinn:
66
85
  values_grid = jnp.squeeze(
67
86
  fun(jnp.stack([xy_data[0][..., None], xy_data[1][..., None]], axis=1))
68
87
  )
69
- _plot_2D_statio(
88
+ ret = _plot_2D_statio(
70
89
  values_grid,
71
90
  mesh,
72
- plot=True,
91
+ plot=not ax_for_plot,
73
92
  colorbar=True,
74
93
  cmap=cmap,
75
94
  spinn=True,
76
95
  figsize=figsize,
96
+ vmin_vmax=vmin_vmax,
77
97
  )
78
- plt.title(title)
98
+ if not ax_for_plot:
99
+ plt.title(title)
100
+ else:
101
+ if vmin_vmax is not None:
102
+ im = ax_for_plot.pcolormesh(
103
+ mesh[0],
104
+ mesh[1],
105
+ ret[0],
106
+ cmap=cmap,
107
+ vmin=vmin_vmax[0],
108
+ vmax=vmin_vmax[1],
109
+ )
110
+ else:
111
+ im = ax_for_plot.pcolormesh(mesh[0], mesh[1], ret[0], cmap=cmap)
112
+ ax_for_plot.set_title(title)
113
+ ax_for_plot.cax.colorbar(im, format="%0.2f")
79
114
 
80
115
  else:
116
+ if ax_for_plot is not None:
117
+ warnings.warn("ax_for_plot is ignored. jinns will plot the figure")
81
118
  if not isinstance(times, list):
82
119
  try:
83
120
  times = times.tolist()
@@ -101,7 +138,12 @@ def plot2d(
101
138
  if not spinn:
102
139
  v_fun_at_t = vmap(lambda x: fun(t=jnp.array([t]), x=x), 0, 0)
103
140
  t_slice, _ = _plot_2D_statio(
104
- v_fun_at_t, mesh, plot=False, colorbar=False, cmap=None
141
+ v_fun_at_t,
142
+ mesh,
143
+ plot=False,
144
+ colorbar=False,
145
+ cmap=None,
146
+ vmin_vmax=vmin_vmax,
105
147
  )
106
148
  elif spinn:
107
149
  values_grid = jnp.squeeze(
@@ -113,15 +155,37 @@ def plot2d(
113
155
  )[0]
114
156
  )
115
157
  t_slice, _ = _plot_2D_statio(
116
- values_grid, mesh, plot=False, colorbar=True, spinn=True
158
+ values_grid,
159
+ mesh,
160
+ plot=False,
161
+ colorbar=True,
162
+ spinn=True,
163
+ vmin_vmax=vmin_vmax,
164
+ )
165
+ if vmin_vmax is not None:
166
+ im = ax.pcolormesh(
167
+ mesh[0],
168
+ mesh[1],
169
+ t_slice,
170
+ cmap=cmap,
171
+ vmin=vmin_vmax[0],
172
+ vmax=vmin_vmax[1],
117
173
  )
118
- im = ax.pcolormesh(mesh[0], mesh[1], t_slice, cmap=cmap)
174
+ else:
175
+ im = ax.pcolormesh(mesh[0], mesh[1], t_slice, cmap=cmap)
119
176
  ax.set_title(f"t = {times[idx] * Tmax:.2f}")
120
177
  ax.cax.colorbar(im, format="%0.2f")
121
178
 
122
179
 
123
180
  def _plot_2D_statio(
124
- v_fun, mesh, plot=True, colorbar=True, cmap="inferno", figsize=(7, 7), spinn=False
181
+ v_fun,
182
+ mesh,
183
+ plot=True,
184
+ colorbar=True,
185
+ cmap="inferno",
186
+ figsize=(7, 7),
187
+ spinn=False,
188
+ vmin_vmax=None,
125
189
  ):
126
190
  """Function that plot the function u(x) with 2-D input x using pcolormesh()
127
191
 
@@ -136,6 +200,8 @@ def _plot_2D_statio(
136
200
  either show or return the plot, by default True
137
201
  colorbar : bool, optional
138
202
  add a colorbar, by default True
203
+ vmin_vmax: tuple, optional
204
+ The colorbar minimum and maximum values. Defaults to None.
139
205
 
140
206
  Returns
141
207
  -------
@@ -153,7 +219,17 @@ def _plot_2D_statio(
153
219
 
154
220
  if plot:
155
221
  fig = plt.figure(figsize=figsize)
156
- im = plt.pcolormesh(x_grid, y_grid, values_grid, cmap=cmap)
222
+ if vmin_vmax is not None:
223
+ im = plt.pcolormesh(
224
+ x_grid,
225
+ y_grid,
226
+ values_grid,
227
+ cmap=cmap,
228
+ vmin=vmin_vmax[0],
229
+ vmax=vmin_vmax[1],
230
+ )
231
+ else:
232
+ im = plt.pcolormesh(x_grid, y_grid, values_grid, cmap=cmap)
157
233
  if colorbar:
158
234
  fig.colorbar(im, format="%0.2f")
159
235
  # don't plt.show() because it is done in plot2d()
@@ -217,6 +293,7 @@ def plot1d_image(
217
293
  colorbar=True,
218
294
  cmap="inferno",
219
295
  spinn=False,
296
+ vmin_vmax=None,
220
297
  ):
221
298
  """Function for plotting the 2-D image of a function :math:`f(t, x)` where
222
299
  `t` is time (1-D) and x is space (1-D).
@@ -237,6 +314,8 @@ def plot1d_image(
237
314
  , by default ""
238
315
  figsize : tuple, optional
239
316
  , by default (10, 10)
317
+ vmin_vmax: tuple
318
+ The colorbar minimum and maximum values. Defaults to None.
240
319
  """
241
320
 
242
321
  mesh = jnp.meshgrid(times, xdata) # cartesian product
@@ -250,7 +329,17 @@ def plot1d_image(
250
329
  elif spinn:
251
330
  values_grid = jnp.squeeze(fun((times[..., None]), xdata[..., None]).T)
252
331
  fig, ax = plt.subplots(1, 1, figsize=figsize)
253
- im = ax.pcolormesh(mesh[0] * Tmax, mesh[1], values_grid, cmap=cmap)
332
+ if vmin_vmax is not None:
333
+ im = ax.pcolormesh(
334
+ mesh[0] * Tmax,
335
+ mesh[1],
336
+ values_grid,
337
+ cmap=cmap,
338
+ vmin=vmin_vmax[0],
339
+ vmax=vmin_vmax[1],
340
+ )
341
+ else:
342
+ im = ax.pcolormesh(mesh[0] * Tmax, mesh[1], values_grid, cmap=cmap)
254
343
  if colorbar:
255
344
  fig.colorbar(im, format="%0.2f")
256
345
  ax.set_title(title)
@@ -6,3 +6,5 @@ from ._diffrax_solver import (
6
6
  neumann_boundary_condition,
7
7
  plot_diffrax_solution,
8
8
  )
9
+ from ._sinuspinn import create_sinusPINN
10
+ from ._spectralpinn import create_spectralPINN