jinns 1.1.0__py3-none-any.whl → 1.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -11,11 +11,7 @@ import jax
11
11
  import jax.numpy as jnp
12
12
  from jax import vmap, grad
13
13
  import equinox as eqx
14
- from jinns.utils._utils import (
15
- _get_grid,
16
- _check_user_func_return,
17
- )
18
- from jinns.parameters._params import _get_vmap_in_axes_params
14
+ from jinns.utils._utils import get_grid, _subtract_with_check
19
15
  from jinns.data._Batchs import *
20
16
  from jinns.utils._pinn import PINN
21
17
  from jinns.utils._spinn import SPINN
@@ -26,12 +22,15 @@ if TYPE_CHECKING:
26
22
 
27
23
  def _compute_boundary_loss(
28
24
  boundary_condition_type: str,
29
- f: Callable,
25
+ f: Callable[
26
+ [Float[Array, "dim"] | Float[Array, "dim + 1"]], Float[Array, "dim_solution"]
27
+ ],
30
28
  batch: PDEStatioBatch | PDENonStatioBatch,
31
29
  u: eqx.Module,
32
- params: Params | ParamsDict,
30
+ params: AnyParams,
33
31
  facet: int,
34
32
  dim_to_apply: slice,
33
+ vmap_in_axes: tuple,
35
34
  ) -> float:
36
35
  r"""A generic function that will compute the mini-batch MSE of a
37
36
  boundary condition in the stationary case, resp. non-stationary, given by:
@@ -62,9 +61,9 @@ def _compute_boundary_loss(
62
61
  unitary outgoing vector normal to $\partial\Omega$
63
62
  f
64
63
  the function to be matched in the boundary condition. It should have
65
- one or two arguments only (other are ignored).
64
+ one argument only (for `t`, `x` or `t_x`) (other are ignored).
66
65
  batch
67
- a PDEStatioBatch or PDENonStatioBatch
66
+ the batch
68
67
  u
69
68
  a PINN
70
69
  params
@@ -75,233 +74,54 @@ def _compute_boundary_loss(
75
74
  dim_to_apply
76
75
  A `jnp.s_` object which indicates which dimension(s) of u will be forced
77
76
  to match the boundary condition
77
+ vmap_in_axes
78
+ A tuple object which specifies the in_axes of the vmapping
78
79
 
79
80
  Returns
80
81
  -------
81
82
  scalar
82
83
  the MSE computed on `batch`
83
84
  """
84
- mse = None
85
- if isinstance(batch, PDEStatioBatch):
86
- if boundary_condition_type.lower() in "dirichlet":
87
- mse = boundary_dirichlet_statio(f, batch, u, params, facet, dim_to_apply)
88
- elif any(
89
- boundary_condition_type.lower() in s
90
- for s in ["von neumann", "vn", "vonneumann"]
91
- ):
92
- mse = boundary_neumann_statio(f, batch, u, params, facet, dim_to_apply)
93
- elif isinstance(batch, PDENonStatioBatch):
94
- if boundary_condition_type.lower() in "dirichlet":
95
- mse = boundary_dirichlet_nonstatio(f, batch, u, params, facet, dim_to_apply)
96
- elif any(
97
- boundary_condition_type.lower() in s
98
- for s in ["von neumann", "vn", "vonneumann"]
99
- ):
100
- mse = boundary_neumann_nonstatio(f, batch, u, params, facet, dim_to_apply)
85
+ if boundary_condition_type.lower() in "dirichlet":
86
+ mse = boundary_dirichlet(f, batch, u, params, facet, dim_to_apply, vmap_in_axes)
87
+ elif any(
88
+ boundary_condition_type.lower() in s
89
+ for s in ["von neumann", "vn", "vonneumann"]
90
+ ):
91
+ mse = boundary_neumann(f, batch, u, params, facet, dim_to_apply, vmap_in_axes)
101
92
  else:
102
- raise ValueError("Wrong type of batch")
93
+ raise ValueError("Wrong type of initial condition")
103
94
  return mse
104
95
 
105
96
 
106
- def boundary_dirichlet_statio(
107
- f: Callable,
108
- batch: PDEStatioBatch,
109
- u: eqx.Module,
110
- params: Params | ParamsDict,
111
- facet: int,
112
- dim_to_apply: slice,
113
- ) -> float:
114
- r"""
115
- This omega boundary condition enforces a solution that is equal to f on
116
- border batch.
117
-
118
- __Note__: if using a batch.param_batch_dict, we need to resolve the
119
- vmapping axes here however params["eq_params"] has already been fed with
120
- the batch in the `evaluate()` of `LossPDEStatio`.
121
-
122
- Parameters
123
- ----------
124
- f
125
- the constraint function
126
- batch
127
- A PDEStatioBatch object.
128
- u
129
- The PINN
130
- params
131
- Params or ParamsDict
132
- dim_to_apply
133
- A jnp.s\_ object. The dimension of u on which to apply the boundary condition
134
- """
135
- _, border_batch = batch.inside_batch, batch.border_batch
136
- border_batch = border_batch[..., facet]
137
-
138
- if isinstance(u, PINN):
139
- vmap_in_axes_params = _get_vmap_in_axes_params(batch.param_batch_dict, params)
140
- vmap_in_axes_x = (0,)
141
-
142
- v_u_boundary = vmap(
143
- lambda dx, params: u(dx, params)[dim_to_apply] - f(dx),
144
- vmap_in_axes_x + vmap_in_axes_params,
145
- 0,
146
- )
147
-
148
- mse_u_boundary = jnp.sum((v_u_boundary(border_batch, params)) ** 2, axis=-1)
149
- elif isinstance(u, SPINN):
150
- values = u(border_batch, params)[..., dim_to_apply]
151
- x_grid = _get_grid(border_batch)
152
- boundaries = _check_user_func_return(f(x_grid), values.shape)
153
- res = values - boundaries
154
- mse_u_boundary = jnp.sum(
155
- res**2,
156
- axis=-1,
157
- )
158
- else:
159
- raise ValueError(f"Bad type for u. Got {type(u)}, expected PINN or SPINN")
160
- return mse_u_boundary
161
-
162
-
163
- def boundary_neumann_statio(
164
- f: Callable,
165
- batch: PDEStatioBatch,
166
- u: eqx.Module,
167
- params: Params | ParamsDict,
168
- facet: int,
169
- dim_to_apply: slice,
170
- ) -> float:
171
- r"""
172
- This omega boundary condition enforces a solution where $\nabla u\cdot
173
- n$ is equal to `f` on omega borders. $n$ is the unitary
174
- outgoing vector normal at border $\partial\Omega$.
175
-
176
- __Note__: if using a batch.param_batch_dict, we need to resolve the
177
- vmapping axes here however params["eq_params"] has already been fed with
178
- the batch in the `evaluate()` of `LossPDEStatio`.
179
-
180
- Parameters
181
- ----------
182
- f
183
- the constraint function
184
- batch
185
- A PDEStatioBatch object.
186
- u
187
- The PINN
188
- params
189
- The dictionary of parameters of the model.
190
- Typically, it is a dictionary of
191
- dictionaries: `eq_params` and `nn_params``, respectively the
192
- differential equation parameters and the neural network parameter
193
- facet
194
- An integer which represents the id of the facet which is currently
195
- considered (in the order provided wy the DataGenerator which is fixed)
196
- dim_to_apply
197
- A jnp.s\_ object. The dimension of u on which to apply the boundary condition
198
- """
199
- _, border_batch = batch.inside_batch, batch.border_batch
200
- border_batch = border_batch[..., facet]
201
-
202
- # We resort to the shape of the border_batch to determine the dimension as
203
- # described in the border_batch function
204
- if jnp.squeeze(border_batch).ndim == 0: # case 1D borders (just a scalar)
205
- n = jnp.array([1, -1]) # the unit vectors normal to the two borders
206
-
207
- else: # case 2D borders (because 3D borders are not supported yet)
208
- # they are in the order: left, right, bottom, top so we give the normal
209
- # outgoing vectors accordingly with shape in concordance with
210
- # border_batch shape (batch_size, ndim, nfacets)
211
- n = jnp.array([[-1, 1, 0, 0], [0, 0, -1, 1]])
212
-
213
- if isinstance(u, PINN):
214
- u_ = lambda x, params: jnp.squeeze(u(x, params)[dim_to_apply])
215
-
216
- vmap_in_axes_params = _get_vmap_in_axes_params(batch.param_batch_dict, params)
217
- vmap_in_axes_x = (0,)
218
-
219
- v_neumann = vmap(
220
- lambda dx, params: jnp.dot(
221
- grad(u_, 0)(dx, params),
222
- n[..., facet],
223
- )
224
- - f(dx),
225
- vmap_in_axes_x + vmap_in_axes_params,
226
- 0,
227
- )
228
- mse_u_boundary = jnp.sum((v_neumann(border_batch, params)) ** 2, axis=-1)
229
- elif isinstance(u, SPINN):
230
- # the gradient we see in the PINN case can get gradients wrt to x
231
- # dimensions at once. But it would be very inefficient in SPINN because
232
- # of the high dim output of u. So we do 2 explicit forward AD, handling all the
233
- # high dim output at once
234
- if border_batch.shape[0] == 1: # i.e. case 1D
235
- _, du_dx = jax.jvp(
236
- lambda x: u(
237
- x,
238
- params,
239
- )[..., dim_to_apply],
240
- (border_batch,),
241
- (jnp.ones_like(border_batch),),
242
- )
243
- values = du_dx * n[facet]
244
- elif border_batch.shape[-1] == 2:
245
- tangent_vec_0 = jnp.repeat(
246
- jnp.array([1.0, 0.0])[None], border_batch.shape[0], axis=0
247
- )
248
- tangent_vec_1 = jnp.repeat(
249
- jnp.array([0.0, 1.0])[None], border_batch.shape[0], axis=0
250
- )
251
- _, du_dx1 = jax.jvp(
252
- lambda x: u(
253
- x,
254
- params,
255
- ),
256
- (border_batch,),
257
- (tangent_vec_0,),
258
- )
259
- _, du_dx2 = jax.jvp(
260
- lambda x: u(
261
- x,
262
- params,
263
- ),
264
- (border_batch,),
265
- (tangent_vec_1,),
266
- )
267
- values = du_dx1 * n[0, facet] + du_dx2 * n[1, facet] # dot product
268
- # explicitly written
269
- else:
270
- raise ValueError("Not implemented, we'll do that with a loop")
271
-
272
- x_grid = _get_grid(border_batch)
273
- boundaries = _check_user_func_return(f(x_grid), values.shape)
274
- res = values - boundaries
275
- mse_u_boundary = jnp.sum(res**2, axis=-1)
276
- else:
277
- raise ValueError(f"Bad type for u. Got {type(u)}, expected PINN or SPINN")
278
- return mse_u_boundary
279
-
280
-
281
- def boundary_dirichlet_nonstatio(
282
- f: Callable,
283
- batch: PDENonStatioBatch,
97
+ def boundary_dirichlet(
98
+ f: Callable[
99
+ [Float[Array, "dim"] | Float[Array, "dim + 1"]], Float[Array, "dim_solution"]
100
+ ],
101
+ batch: PDEStatioBatch | PDENonStatioBatch,
284
102
  u: eqx.Module,
285
103
  params: Params | ParamsDict,
286
104
  facet: int,
287
105
  dim_to_apply: slice,
106
+ vmap_in_axes: tuple,
288
107
  ) -> float:
289
108
  r"""
290
109
  This omega boundary condition enforces a solution that is equal to `f`
291
- at `times_batch` x `omega borders`
110
+ at `times_batch` x `omega_border` (non stationary case) or at `omega_border`
111
+ (stationary case)
292
112
 
293
113
  __Note__: if using a batch.param_batch_dict, we need to resolve the
294
114
  vmapping axes here however params["eq_params"] has already been fed with
295
- the batch in the `evaluate()` of `LossPDENonStatio`.
115
+ the batch in the `evaluate()` of `LossPDE*`.
296
116
 
297
117
  Parameters
298
118
  ----------
299
119
  f
300
120
  the constraint function
301
121
  batch
302
- A PDENonStatioBatch object.
122
+ The batch
303
123
  u
304
- The PINN
124
+ The PINN or SPINN
305
125
  params
306
126
  The dictionary of parameters of the model.
307
127
  Typically, it is a dictionary of
@@ -312,49 +132,50 @@ def boundary_dirichlet_nonstatio(
312
132
  considered (in the order provided wy the DataGenerator which is fixed)
313
133
  dim_to_apply
314
134
  A jnp.s\_ object. The dimension of u on which to apply the boundary condition
135
+ vmap_in_axes
136
+ A tuple object which specifies the in_axes of the vmapping
315
137
  """
316
- times_batch = batch.times_x_border_batch[:, 0:1, facet]
317
- omega_border_batch = batch.times_x_border_batch[:, 1:, facet]
138
+ batch_array = batch.border_batch
139
+ batch_array = batch_array[..., facet]
318
140
 
319
141
  if isinstance(u, PINN):
320
- vmap_in_axes_params = _get_vmap_in_axes_params(batch.param_batch_dict, params)
321
- vmap_in_axes_x_t = (0, 0)
322
-
323
142
  v_u_boundary = vmap(
324
- lambda t, dx, params: u(
325
- t,
326
- dx,
327
- params,
328
- )[dim_to_apply]
329
- - f(t, dx),
330
- vmap_in_axes_x_t + vmap_in_axes_params,
143
+ lambda inputs, params: _subtract_with_check(
144
+ f(inputs),
145
+ u(
146
+ inputs,
147
+ params,
148
+ )[dim_to_apply],
149
+ cause="boundary condition fun",
150
+ ),
151
+ vmap_in_axes,
331
152
  0,
332
153
  )
333
- res = v_u_boundary(times_batch, omega_border_batch, params)
154
+ res = v_u_boundary(batch_array, params)
334
155
  mse_u_boundary = jnp.sum(
335
156
  res**2,
336
157
  axis=-1,
337
158
  )
338
159
  elif isinstance(u, SPINN):
339
- values = u(times_batch, omega_border_batch, params)[..., dim_to_apply]
340
- tx_grid = _get_grid(jnp.concatenate([times_batch, omega_border_batch], axis=-1))
341
- boundaries = _check_user_func_return(
342
- f(tx_grid[..., 0:1], tx_grid[..., 1:]), values.shape
343
- )
344
- res = values - boundaries
160
+ values = u(batch_array, params)[..., dim_to_apply]
161
+ grid = get_grid(batch_array)
162
+ res = _subtract_with_check(f(grid), values, cause="boundary condition fun")
345
163
  mse_u_boundary = jnp.sum(res**2, axis=-1)
346
164
  else:
347
165
  raise ValueError(f"Bad type for u. Got {type(u)}, expected PINN or SPINN")
348
166
  return mse_u_boundary
349
167
 
350
168
 
351
- def boundary_neumann_nonstatio(
352
- f: Callable,
353
- batch: PDENonStatioBatch,
169
+ def boundary_neumann(
170
+ f: Callable[
171
+ [Float[Array, "dim"] | Float[Array, "dim + 1"]], Float[Array, "dim_solution"]
172
+ ],
173
+ batch: PDEStatioBatch | PDENonStatioBatch,
354
174
  u: eqx.Module,
355
175
  params: Params | ParamsDict,
356
176
  facet: int,
357
177
  dim_to_apply: slice,
178
+ vmap_in_axes: tuple,
358
179
  ) -> float:
359
180
  r"""
360
181
  This omega boundary condition enforces a solution where $\nabla u\cdot
@@ -371,7 +192,7 @@ def boundary_neumann_nonstatio(
371
192
  f:
372
193
  the constraint function
373
194
  batch
374
- A PDENonStatioBatch object.
195
+ The batch
375
196
  u
376
197
  The PINN
377
198
  params
@@ -384,13 +205,15 @@ def boundary_neumann_nonstatio(
384
205
  considered (in the order provided wy the DataGenerator which is fixed)
385
206
  dim_to_apply
386
207
  A jnp.s\_ object. The dimension of u on which to apply the boundary condition
208
+ vmap_in_axes
209
+ A tuple object which specifies the in_axes of the vmapping
387
210
  """
388
- times_batch = batch.times_x_border_batch[:, 0:1, facet]
389
- omega_border_batch = batch.times_x_border_batch[:, 1:, facet]
211
+ batch_array = batch.border_batch
212
+ batch_array = batch_array[..., facet]
390
213
 
391
214
  # We resort to the shape of the border_batch to determine the dimension as
392
215
  # described in the border_batch function
393
- if jnp.squeeze(omega_border_batch).ndim == 0: # case 1D borders (just a scalar)
216
+ if jnp.squeeze(batch_array).ndim == 0: # case 1D borders (just a scalar)
394
217
  n = jnp.array([1, -1]) # the unit vectors normal to the two borders
395
218
 
396
219
  else: # case 2D borders (because 3D borders are not supported yet)
@@ -400,24 +223,41 @@ def boundary_neumann_nonstatio(
400
223
  n = jnp.array([[-1, 1, 0, 0], [0, 0, -1, 1]])
401
224
 
402
225
  if isinstance(u, PINN):
403
- vmap_in_axes_params = _get_vmap_in_axes_params(batch.param_batch_dict, params)
404
- vmap_in_axes_x_t = (0, 0)
405
226
 
406
- u_ = lambda t, x, params: jnp.squeeze(u(t, x, params)[dim_to_apply])
407
- v_neumann = vmap(
408
- lambda t, dx, params: jnp.dot(
409
- grad(u_, 1)(t, dx, params),
410
- n[..., facet],
227
+ u_ = lambda inputs, params: jnp.squeeze(u(inputs, params)[dim_to_apply])
228
+
229
+ if u.eq_type == "statio_PDE":
230
+ v_neumann = vmap(
231
+ lambda inputs, params: _subtract_with_check(
232
+ f(inputs),
233
+ jnp.dot(
234
+ grad(u_, 0)(inputs, params),
235
+ n[..., facet],
236
+ ),
237
+ cause="boundary condition fun",
238
+ ),
239
+ vmap_in_axes,
240
+ 0,
411
241
  )
412
- - f(t, dx),
413
- vmap_in_axes_x_t + vmap_in_axes_params,
414
- 0,
415
- )
242
+ elif u.eq_type == "nonstatio_PDE":
243
+ v_neumann = vmap(
244
+ lambda inputs, params: _subtract_with_check(
245
+ f(inputs),
246
+ jnp.dot(
247
+ grad(u_, 0)(inputs, params)[1:], # get rid of time dim
248
+ n[..., facet],
249
+ ),
250
+ cause="boundary condition fun",
251
+ ),
252
+ vmap_in_axes,
253
+ 0,
254
+ )
255
+ else:
256
+ raise ValueError("Wrong u.eq_type")
416
257
  mse_u_boundary = jnp.sum(
417
258
  (
418
259
  v_neumann(
419
- times_batch,
420
- omega_border_batch,
260
+ batch_array,
421
261
  params,
422
262
  )
423
263
  )
@@ -430,40 +270,69 @@ def boundary_neumann_nonstatio(
430
270
  # dimensions at once. But it would be very inefficient in SPINN because
431
271
  # of the high dim output of u. So we do 2 explicit forward AD, handling all the
432
272
  # high dim output at once
433
- if omega_border_batch.shape[0] == 1: # i.e. case 1D
434
- _, du_dx = jax.jvp(
435
- lambda x: u(times_batch, x, params)[..., dim_to_apply],
436
- (omega_border_batch,),
437
- (jnp.ones_like(omega_border_batch),),
438
- )
439
- values = du_dx * n[facet]
440
- elif omega_border_batch.shape[-1] == 2:
441
- tangent_vec_0 = jnp.repeat(
442
- jnp.array([1.0, 0.0])[None], omega_border_batch.shape[0], axis=0
443
- )
444
- tangent_vec_1 = jnp.repeat(
445
- jnp.array([0.0, 1.0])[None], omega_border_batch.shape[0], axis=0
446
- )
447
- _, du_dx1 = jax.jvp(
448
- lambda x: u(times_batch, x, params)[..., dim_to_apply],
449
- (omega_border_batch,),
450
- (tangent_vec_0,),
451
- )
452
- _, du_dx2 = jax.jvp(
453
- lambda x: u(times_batch, x, params)[..., dim_to_apply],
454
- (omega_border_batch,),
455
- (tangent_vec_1,),
456
- )
457
- values = du_dx1 * n[0, facet] + du_dx2 * n[1, facet] # dot product
458
- # explicitly written
273
+ if (batch_array.shape[0] == 1 and isinstance(batch, PDEStatioBatch)) or (
274
+ batch_array.shape[-1] == 2 and isinstance(batch, PDENonStatioBatch)
275
+ ):
276
+ if u.eq_type == "statio_PDE":
277
+ _, du_dx = jax.jvp(
278
+ lambda inputs: u(inputs, params)[..., dim_to_apply],
279
+ (batch_array,),
280
+ (jnp.ones_like(batch_array),),
281
+ )
282
+ values = du_dx * n[facet]
283
+ if u.eq_type == "nonstatio_PDE":
284
+ _, du_dx = jax.jvp(
285
+ lambda inputs: u(inputs, params)[..., dim_to_apply],
286
+ (batch_array,),
287
+ (jnp.ones_like(batch_array),),
288
+ )
289
+ values = du_dx[..., 1] * n[facet]
290
+ elif (batch_array.shape[-1] == 2 and isinstance(batch, PDEStatioBatch)) or (
291
+ batch_array.shape[-1] == 3 and isinstance(batch, PDENonStatioBatch)
292
+ ):
293
+ if u.eq_type == "statio_PDE":
294
+ tangent_vec_0 = jnp.repeat(
295
+ jnp.array([1.0, 0.0])[None], batch_array.shape[0], axis=0
296
+ )
297
+ tangent_vec_1 = jnp.repeat(
298
+ jnp.array([0.0, 1.0])[None], batch_array.shape[0], axis=0
299
+ )
300
+ _, du_dx1 = jax.jvp(
301
+ lambda inputs: u(inputs, params)[..., dim_to_apply],
302
+ (batch_array,),
303
+ (tangent_vec_0,),
304
+ )
305
+ _, du_dx2 = jax.jvp(
306
+ lambda inputs: u(inputs, params)[..., dim_to_apply],
307
+ (batch_array,),
308
+ (tangent_vec_1,),
309
+ )
310
+ values = du_dx1 * n[0, facet] + du_dx2 * n[1, facet] # dot product
311
+ if u.eq_type == "nonstatio_PDE":
312
+ tangent_vec_0 = jnp.repeat(
313
+ jnp.array([0.0, 1.0, 0.0])[None], batch_array.shape[0], axis=0
314
+ )
315
+ tangent_vec_1 = jnp.repeat(
316
+ jnp.array([0.0, 0.0, 1.0])[None], batch_array.shape[0], axis=0
317
+ )
318
+ _, du_dx1 = jax.jvp(
319
+ lambda inputs: u(inputs, params)[..., dim_to_apply],
320
+ (batch_array,),
321
+ (tangent_vec_0,),
322
+ )
323
+ _, du_dx2 = jax.jvp(
324
+ lambda inputs: u(inputs, params)[..., dim_to_apply],
325
+ (batch_array,),
326
+ (tangent_vec_1,),
327
+ )
328
+ values = (
329
+ du_dx1.squeeze() * n[0, facet] + du_dx2.squeeze() * n[1, facet]
330
+ ) # dot product
459
331
  else:
460
332
  raise ValueError("Not implemented, we'll do that with a loop")
461
333
 
462
- tx_grid = _get_grid(jnp.concatenate([times_batch, omega_border_batch], axis=-1))
463
- boundaries = _check_user_func_return(
464
- f(tx_grid[..., 0:1], tx_grid[..., 1:]), values.shape
465
- )
466
- res = values - boundaries
334
+ grid = get_grid(batch_array)
335
+ res = _subtract_with_check(f(grid), values, cause="boundary condition fun")
467
336
  mse_u_boundary = jnp.sum(
468
337
  res**2,
469
338
  axis=-1,