jinns 1.3.0__py3-none-any.whl → 1.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55)
  1. jinns/__init__.py +17 -7
  2. jinns/data/_AbstractDataGenerator.py +19 -0
  3. jinns/data/_Batchs.py +31 -12
  4. jinns/data/_CubicMeshPDENonStatio.py +431 -0
  5. jinns/data/_CubicMeshPDEStatio.py +464 -0
  6. jinns/data/_DataGeneratorODE.py +187 -0
  7. jinns/data/_DataGeneratorObservations.py +189 -0
  8. jinns/data/_DataGeneratorParameter.py +206 -0
  9. jinns/data/__init__.py +19 -9
  10. jinns/data/_utils.py +149 -0
  11. jinns/experimental/__init__.py +9 -0
  12. jinns/loss/_DynamicLoss.py +114 -187
  13. jinns/loss/_DynamicLossAbstract.py +74 -69
  14. jinns/loss/_LossODE.py +132 -348
  15. jinns/loss/_LossPDE.py +262 -549
  16. jinns/loss/__init__.py +32 -6
  17. jinns/loss/_abstract_loss.py +128 -0
  18. jinns/loss/_boundary_conditions.py +20 -19
  19. jinns/loss/_loss_components.py +43 -0
  20. jinns/loss/_loss_utils.py +85 -179
  21. jinns/loss/_loss_weight_updates.py +202 -0
  22. jinns/loss/_loss_weights.py +64 -40
  23. jinns/loss/_operators.py +84 -74
  24. jinns/nn/__init__.py +15 -0
  25. jinns/nn/_abstract_pinn.py +22 -0
  26. jinns/nn/_hyperpinn.py +94 -57
  27. jinns/nn/_mlp.py +50 -25
  28. jinns/nn/_pinn.py +33 -19
  29. jinns/nn/_ppinn.py +70 -34
  30. jinns/nn/_save_load.py +21 -51
  31. jinns/nn/_spinn.py +33 -16
  32. jinns/nn/_spinn_mlp.py +28 -22
  33. jinns/nn/_utils.py +38 -0
  34. jinns/parameters/__init__.py +8 -1
  35. jinns/parameters/_derivative_keys.py +116 -177
  36. jinns/parameters/_params.py +18 -46
  37. jinns/plot/__init__.py +2 -0
  38. jinns/plot/_plot.py +35 -34
  39. jinns/solver/_rar.py +80 -63
  40. jinns/solver/_solve.py +207 -92
  41. jinns/solver/_utils.py +4 -6
  42. jinns/utils/__init__.py +2 -0
  43. jinns/utils/_containers.py +16 -10
  44. jinns/utils/_types.py +20 -54
  45. jinns/utils/_utils.py +4 -11
  46. jinns/validation/__init__.py +2 -0
  47. jinns/validation/_validation.py +20 -19
  48. {jinns-1.3.0.dist-info → jinns-1.5.0.dist-info}/METADATA +8 -4
  49. jinns-1.5.0.dist-info/RECORD +55 -0
  50. {jinns-1.3.0.dist-info → jinns-1.5.0.dist-info}/WHEEL +1 -1
  51. jinns/data/_DataGenerators.py +0 -1634
  52. jinns-1.3.0.dist-info/RECORD +0 -44
  53. {jinns-1.3.0.dist-info → jinns-1.5.0.dist-info/licenses}/AUTHORS +0 -0
  54. {jinns-1.3.0.dist-info → jinns-1.5.0.dist-info/licenses}/LICENSE +0 -0
  55. {jinns-1.3.0.dist-info → jinns-1.5.0.dist-info}/top_level.txt +0 -0
jinns/plot/_plot.py CHANGED
@@ -2,26 +2,24 @@
2
2
  Utility functions for plotting in 1D and 2D, with and without time.
3
3
  """
4
4
 
5
- from functools import partial
6
- import warnings
7
- import matplotlib.pyplot as plt
5
+ from typing import Callable
8
6
  import jax.numpy as jnp
9
7
  from jax import vmap
8
+ import matplotlib.pyplot as plt
10
9
  from mpl_toolkits.axes_grid1 import ImageGrid
11
- from typing import Callable, List
12
- from jaxtyping import Array, Float, Bool
10
+ from jaxtyping import Array, Float
13
11
 
14
12
 
15
13
  def plot2d(
16
14
  fun: Callable,
17
- xy_data: tuple[Float[Array, "nx"], Float[Array, "ny"]],
18
- times: Float[Array, "nt"] | List[float] | None = None,
15
+ xy_data: tuple[Float[Array, " nx"], Float[Array, " ny"]],
16
+ times: Float[Array, " nt"] | list[float] | None = None,
19
17
  Tmax: float = 1,
20
18
  title: str = "",
21
19
  figsize: tuple = (7, 7),
22
20
  cmap: str = "inferno",
23
21
  spinn: bool = False,
24
- vmin_vmax: tuple[float, float] = [None, None],
22
+ vmin_vmax: tuple[float | None, float | None] | None = None,
25
23
  ):
26
24
  r"""Generic function for plotting functions over rectangular 2-D domains
27
25
  $\Omega$. It handles both the
@@ -61,6 +59,8 @@ def plot2d(
61
59
  ValueError
62
60
  if xy_data is not a list of length 2
63
61
  """
62
+ if vmin_vmax is None:
63
+ vmin_vmax = (None, None)
64
64
 
65
65
  # if not isinstance(xy_data, jnp.ndarray) and not xy_data.shape[-1] == 2:
66
66
  if not isinstance(xy_data, list) and not len(xy_data) == 2:
@@ -100,7 +100,7 @@ def plot2d(
100
100
  im = ax.pcolormesh(
101
101
  mesh[0],
102
102
  mesh[1],
103
- ret[0],
103
+ ret,
104
104
  cmap=cmap,
105
105
  vmin=vmin_vmax[0],
106
106
  vmax=vmin_vmax[1],
@@ -113,7 +113,7 @@ def plot2d(
113
113
  if not isinstance(times, list):
114
114
  try:
115
115
  times = times.tolist()
116
- except:
116
+ except AttributeError:
117
117
  raise ValueError("times must be a list or an array")
118
118
 
119
119
  fig = plt.figure(figsize=figsize)
@@ -129,7 +129,7 @@ def plot2d(
129
129
  cbar_pad=0.4,
130
130
  )
131
131
 
132
- for idx, (t, ax) in enumerate(zip(times, grid)):
132
+ for idx, (t, ax) in enumerate(zip(times, iter(grid))):
133
133
  if not spinn:
134
134
  x_grid, y_grid = mesh
135
135
  v_fun_at_t = vmap(fun)(
@@ -142,12 +142,11 @@ def plot2d(
142
142
  axis=-1,
143
143
  )
144
144
  )
145
- t_slice, _ = _plot_2D_statio(
145
+ t_slice = _plot_2D_statio(
146
146
  v_fun_at_t,
147
147
  mesh,
148
148
  plot=False, # only use to compute t_slice
149
149
  colorbar=False,
150
- cmap=None,
151
150
  vmin_vmax=vmin_vmax,
152
151
  )
153
152
  elif spinn:
@@ -161,7 +160,7 @@ def plot2d(
161
160
  axis=-1,
162
161
  )
163
162
  values_grid = jnp.squeeze(fun(t_x)[0]).T
164
- t_slice, _ = _plot_2D_statio(
163
+ t_slice = _plot_2D_statio(
165
164
  values_grid,
166
165
  mesh,
167
166
  plot=False, # only use to compute t_slice
@@ -182,14 +181,14 @@ def plot2d(
182
181
 
183
182
 
184
183
  def _plot_2D_statio(
185
- v_fun: Callable | Float[Array, "(nx*ny)^2 1"],
186
- mesh: Float[Array, "nx*ny nx*ny"],
187
- plot: Bool = True,
188
- colorbar: Bool = True,
184
+ v_fun: Callable | Float[Array, " (nx*ny)^2 1"],
185
+ mesh: list[Float[Array, " nx*ny nx*ny"]],
186
+ plot: bool = True,
187
+ colorbar: bool = True,
189
188
  cmap: str = "inferno",
190
189
  figsize: tuple[int, int] = (7, 7),
191
- vmin_vmax: tuple[float, float] = [None, None],
192
- ):
190
+ vmin_vmax: tuple[float | None, float | None] | None = None,
191
+ ) -> Array | None:
193
192
  """Function that plot the function u(x) with 2-D input x using pcolormesh()
194
193
 
195
194
 
@@ -217,6 +216,8 @@ def _plot_2D_statio(
217
216
  Either None or the values of u() over the meshgrid and the current plt axis
218
217
 
219
218
  """
219
+ if vmin_vmax is None:
220
+ vmin_vmax = (None, None)
220
221
 
221
222
  x_grid, y_grid = mesh
222
223
  if callable(v_fun):
@@ -238,19 +239,18 @@ def _plot_2D_statio(
238
239
 
239
240
  if colorbar:
240
241
  fig.colorbar(im, format="%0.2f")
241
- # don't plt.show() because it is done in plot2d()
242
242
  else:
243
- return values_grid, plt.gca()
243
+ return values_grid
244
244
 
245
245
 
246
246
  def plot1d_slice(
247
- fun: Callable[[float, float], float],
248
- xdata: Float[Array, "nx"],
249
- time_slices: Float[Array, "nt"] | None = None,
247
+ fun: Callable[[Float[Array, " "]], Float[Array, " "]],
248
+ xdata: Float[Array, " nx"],
249
+ time_slices: Float[Array, " nt"] | None = None,
250
250
  Tmax: float = 1.0,
251
251
  title: str = "",
252
252
  figsize: tuple[int, int] = (10, 10),
253
- spinn: Bool = False,
253
+ spinn: bool = False,
254
254
  ax=None,
255
255
  ):
256
256
  """Function for plotting time slices of a function :math:`f(t_i, x)` where
@@ -284,7 +284,7 @@ def plot1d_slice(
284
284
  if time_slices is None:
285
285
  time_slices = jnp.array([0])
286
286
  if ax is None:
287
- fig, ax = plt.subplots(figsize=figsize)
287
+ _, ax = plt.subplots(figsize=figsize)
288
288
 
289
289
  for t in time_slices:
290
290
  t_xdata = jnp.concatenate(
@@ -306,16 +306,16 @@ def plot1d_slice(
306
306
 
307
307
 
308
308
  def plot1d_image(
309
- fun: Callable[[float, float], float],
310
- xdata: Float[Array, "nx"],
311
- times: Float[Array, "nt"],
309
+ fun: Callable[[Float[Array, " "]], Float[Array, " "]],
310
+ xdata: Float[Array, " nx"],
311
+ times: Float[Array, " nt"],
312
312
  Tmax: float = 1.0,
313
313
  title: str = "",
314
314
  figsize: tuple[int, int] = (10, 10),
315
- colorbar: Bool = True,
315
+ colorbar: bool = True,
316
316
  cmap: str = "inferno",
317
- spinn: Bool = False,
318
- vmin_vmax: tuple[float, float] = [None, None],
317
+ spinn: bool = False,
318
+ vmin_vmax: tuple[float | None, float | None] | None = None,
319
319
  ):
320
320
  """Function for plotting the 2-D image of a function :math:`f(t, x)` where
321
321
  `t` is time (1-D) and x is space (1-D).
@@ -350,7 +350,8 @@ def plot1d_image(
350
350
  fig, ax
351
351
  A `matplotlib` `Figure` and `Axes` objects with the figure.
352
352
  """
353
-
353
+ if vmin_vmax is None:
354
+ vmin_vmax = (None, None)
354
355
  mesh = jnp.meshgrid(times, xdata) # cartesian product
355
356
  if not spinn:
356
357
  # the trick is to use _plot2Dstatio
jinns/solver/_rar.py CHANGED
@@ -2,41 +2,73 @@ from __future__ import (
2
2
  annotations,
3
3
  ) # https://docs.python.org/3/library/typing.html#constant
4
4
 
5
- from typing import TYPE_CHECKING, Callable
5
+ from typing import TYPE_CHECKING, Callable, TypeAlias, Any, TypedDict
6
6
  from functools import partial
7
+ from jaxtyping import Float, Array, Bool
7
8
  import jax
8
9
  from jax import vmap
9
10
  import jax.numpy as jnp
10
11
  import equinox as eqx
11
- from jaxtyping import Int, Bool
12
-
13
- from jinns.data._Batchs import *
14
- from jinns.loss._LossODE import LossODE, SystemLossODE
15
- from jinns.loss._LossPDE import LossPDEStatio, LossPDENonStatio, SystemLossPDE
16
- from jinns.loss._loss_utils import dynamic_loss_apply
17
- from jinns.data._DataGenerators import (
18
- DataGeneratorODE,
19
- CubicMeshPDEStatio,
20
- CubicMeshPDENonStatio,
21
- )
12
+
13
+ from jinns.data._DataGeneratorODE import DataGeneratorODE
14
+ from jinns.data._CubicMeshPDEStatio import CubicMeshPDEStatio
15
+ from jinns.data._CubicMeshPDENonStatio import CubicMeshPDENonStatio
22
16
  from jinns.nn._hyperpinn import HyperPINN
23
17
  from jinns.nn._spinn import SPINN
24
18
 
25
19
 
26
20
  if TYPE_CHECKING:
27
- from jinns.utils._types import *
21
+ from jinns.data._AbstractDataGenerator import AbstractDataGenerator
22
+ from jinns.utils._types import AnyLoss
23
+ from jinns.parameters._params import Params
24
+
25
+ class DataGeneratorWithRAR(AbstractDataGenerator):
26
+ """
27
+ Add the required RAR operands for type checks
28
+ """
29
+
30
+ rar_parameters: RarParameterDict
31
+ n_start: int
32
+ rar_iter_from_last_sampling: int
33
+ rar_iter_nb: int
34
+ p: Float[Array, " n 1"]
35
+
36
+ rar_operands: TypeAlias = tuple[Any, Params, DataGeneratorWithRAR, int]
37
+
38
+
39
+ class RarParameterDict(TypedDict):
40
+ """
41
+ TypedDict to specify the Residual Adaptative Resampling procedure
42
+ Otherwise a dictionary with keys
43
+ - `start_iter`: the iteration at which we start the RAR sampling scheme (we first have a "burn-in" period).
44
+ - `update_every`: the number of gradient steps taken between
45
+ each update of collocation points in the RAR algo.
46
+ - `sample_size`: the size of the sample from which we will select new
47
+ collocation points.
48
+ - `selected_sample_size`: the number of selected
49
+ points from the sample to be added to the current collocation
50
+ points.
51
+ """
28
52
 
53
+ start_iter: int
54
+ update_every: int
55
+ sample_size: int
56
+ selected_sample_size: int
29
57
 
30
- def _proceed_to_rar(data: AnyDataGenerator, i: Int) -> Bool:
58
+
59
+ def _proceed_to_rar(data: DataGeneratorWithRAR, i: int) -> Bool[Array, " "]:
31
60
  """Utilility function with various check to ensure we can proceed with the rar_step.
32
61
  Return True if yes, and False otherwise"""
33
62
 
34
63
  # Overall checks
35
64
  check_list = [
36
65
  # check if burn-in period has ended
37
- data.rar_parameters["start_iter"] <= i,
66
+ jnp.asarray(data.rar_parameters["start_iter"] <= i),
38
67
  # check if enough iterations since last points added
39
- (data.rar_parameters["update_every"] - 1) == data.rar_iter_from_last_sampling,
68
+ jnp.asarray(
69
+ (data.rar_parameters["update_every"] - 1)
70
+ == data.rar_iter_from_last_sampling
71
+ ),
40
72
  ]
41
73
 
42
74
  # Memory allocation checks
@@ -52,14 +84,13 @@ def _proceed_to_rar(data: AnyDataGenerator, i: Int) -> Bool:
52
84
 
53
85
  @partial(jax.jit, static_argnames=["_rar_step_true", "_rar_step_false"])
54
86
  def trigger_rar(
55
- i: Int,
87
+ i: int,
56
88
  loss: AnyLoss,
57
- params: AnyParams,
58
- data: AnyDataGenerator,
59
- _rar_step_true: Callable[[rar_operands], AnyDataGenerator],
60
- _rar_step_false: Callable[[rar_operands], AnyDataGenerator],
61
- ) -> tuple[AnyLoss, AnyParams, AnyDataGenerator]:
62
-
89
+ params: Params,
90
+ data: DataGeneratorWithRAR,
91
+ _rar_step_true: Callable[[rar_operands], DataGeneratorWithRAR],
92
+ _rar_step_false: Callable[[rar_operands], DataGeneratorWithRAR],
93
+ ) -> tuple[AnyLoss, Params, DataGeneratorWithRAR]:
63
94
  if data.rar_parameters is None:
64
95
  # do nothing.
65
96
  return loss, params, data
@@ -75,11 +106,11 @@ def trigger_rar(
75
106
 
76
107
 
77
108
  def init_rar(
78
- data: AnyDataGenerator,
109
+ data: DataGeneratorWithRAR,
79
110
  ) -> tuple[
80
- AnyDataGenerator,
81
- Callable[[rar_operands], AnyDataGenerator],
82
- Callable[[rar_operands], AnyDataGenerator],
111
+ DataGeneratorWithRAR,
112
+ Callable[[rar_operands], DataGeneratorWithRAR] | None,
113
+ Callable[[rar_operands], DataGeneratorWithRAR] | None,
83
114
  ]:
84
115
  """
85
116
  Separated from the main rar, because the initialization to get _true and
@@ -100,9 +131,11 @@ def init_rar(
100
131
  return data, _rar_step_true, _rar_step_false
101
132
 
102
133
 
103
- def _rar_step_init(sample_size: Int, selected_sample_size: Int) -> tuple[
104
- Callable[[rar_operands], AnyDataGenerator],
105
- Callable[[rar_operands], AnyDataGenerator],
134
+ def _rar_step_init(
135
+ sample_size: int, selected_sample_size: int
136
+ ) -> tuple[
137
+ Callable[[rar_operands], DataGeneratorWithRAR],
138
+ Callable[[rar_operands], DataGeneratorWithRAR],
106
139
  ]:
107
140
  """
108
141
  This is a wrapper because the sampling size and
@@ -113,13 +146,12 @@ def _rar_step_init(sample_size: Int, selected_sample_size: Int) -> tuple[
113
146
  This is a kind of manual declaration of static argnums
114
147
  """
115
148
 
116
- def rar_step_true(operands: rar_operands) -> AnyDataGenerator:
117
- loss, params, data, i = operands
149
+ def rar_step_true(operands: rar_operands) -> DataGeneratorWithRAR:
150
+ loss, params, data, _ = operands
118
151
  if isinstance(loss.u, HyperPINN) or isinstance(loss.u, SPINN):
119
152
  raise NotImplementedError("RAR not implemented for hyperPINN and SPINN")
120
153
 
121
154
  if isinstance(data, DataGeneratorODE):
122
-
123
155
  new_key, subkey = jax.random.split(data.key)
124
156
  new_samples = data.sample_in_time_domain(subkey, sample_size)
125
157
  data = eqx.tree_at(lambda m: m.key, data, new_key)
@@ -145,33 +177,15 @@ def _rar_step_init(sample_size: Int, selected_sample_size: Int) -> tuple[
145
177
 
146
178
  data = eqx.tree_at(lambda m: m.key, data, new_key)
147
179
 
148
- # We can have different types of Loss
149
- if isinstance(loss, (LossODE, LossPDEStatio, LossPDENonStatio)):
150
- v_dyn_loss = vmap(
151
- lambda inputs: loss.dynamic_loss.evaluate(inputs, loss.u, params),
152
- )
153
- dyn_on_s = v_dyn_loss(new_samples)
180
+ v_dyn_loss = vmap(
181
+ lambda inputs: loss.dynamic_loss.evaluate(inputs, loss.u, params),
182
+ )
183
+ dyn_on_s = v_dyn_loss(new_samples)
154
184
 
155
- if dyn_on_s.ndim > 1:
156
- mse_on_s = (jnp.linalg.norm(dyn_on_s, axis=-1) ** 2).flatten()
157
- else:
158
- mse_on_s = dyn_on_s**2
159
- elif isinstance(loss, SystemLossODE, SystemLossPDE):
160
- mse_on_s = 0
161
-
162
- for i in loss.dynamic_loss_dict.keys():
163
- v_dyn_loss = vmap(
164
- lambda inputs: loss.dynamic_loss_dict[i].evaluate(
165
- inputs, loss.u_dict, params
166
- ),
167
- (0),
168
- 0,
169
- )
170
- dyn_on_s = v_dyn_loss(new_samples)
171
- if dyn_on_s.ndim > 1:
172
- mse_on_s += (jnp.linalg.norm(dyn_on_s, axis=-1) ** 2).flatten()
173
- else:
174
- mse_on_s += dyn_on_s**2
185
+ if dyn_on_s.ndim > 1:
186
+ mse_on_s = (jnp.linalg.norm(dyn_on_s, axis=-1) ** 2).flatten()
187
+ else:
188
+ mse_on_s = dyn_on_s**2
175
189
 
176
190
  ## Select the m points with higher dynamic loss
177
191
  higher_residual_idx = jax.lax.dynamic_slice(
@@ -188,7 +202,7 @@ def _rar_step_init(sample_size: Int, selected_sample_size: Int) -> tuple[
188
202
  new_times = jax.lax.dynamic_update_slice(
189
203
  data.times,
190
204
  higher_residual_points,
191
- (data.n_start + data.rar_iter_nb * selected_sample_size,),
205
+ (data.n_start + data.rar_iter_nb * selected_sample_size,), # type: ignore
192
206
  )
193
207
 
194
208
  data = eqx.tree_at(lambda m: m.times, data, new_times)
@@ -198,7 +212,7 @@ def _rar_step_init(sample_size: Int, selected_sample_size: Int) -> tuple[
198
212
  new_omega = jax.lax.dynamic_update_slice(
199
213
  data.omega,
200
214
  higher_residual_points,
201
- (data.n_start + data.rar_iter_nb * selected_sample_size, data.dim),
215
+ (data.n_start + data.rar_iter_nb * selected_sample_size, data.dim), # type: ignore
202
216
  )
203
217
 
204
218
  data = eqx.tree_at(lambda m: m.omega, data, new_omega)
@@ -207,7 +221,10 @@ def _rar_step_init(sample_size: Int, selected_sample_size: Int) -> tuple[
207
221
  new_domain = jax.lax.dynamic_update_slice(
208
222
  data.domain,
209
223
  higher_residual_points,
210
- (data.n_start + data.rar_iter_nb * selected_sample_size, 1 + data.dim),
224
+ (
225
+ data.n_start + data.rar_iter_nb * selected_sample_size, # type: ignore
226
+ 1 + data.dim,
227
+ ),
211
228
  )
212
229
 
213
230
  data = eqx.tree_at(lambda m: m.domain, data, new_domain)
@@ -246,7 +263,7 @@ def _rar_step_init(sample_size: Int, selected_sample_size: Int) -> tuple[
246
263
  # have side effects in this function that will be jitted
247
264
  return data
248
265
 
249
- def rar_step_false(operands: rar_operands) -> AnyDataGenerator:
266
+ def rar_step_false(operands: rar_operands) -> DataGeneratorWithRAR:
250
267
  _, _, data, i = operands
251
268
 
252
269
  # Add 1 only if we are after the burn in period