jinns-0.8.10-py3-none-any.whl → jinns-0.9.0-py3-none-any.whl

@@ -17,9 +17,8 @@ class ODEBatch(NamedTuple):
 
 
  class PDENonStatioBatch(NamedTuple):
- inside_batch: ArrayLike
- border_batch: ArrayLike
- temporal_batch: ArrayLike
+ times_x_inside_batch: ArrayLike
+ times_x_border_batch: ArrayLike
  param_batch_dict: dict = None
  obs_batch_dict: dict = None
 
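The renamed fields merge the former time and space batches into single arrays whose first column holds `t` and whose remaining columns hold `x` (this is how the loss code further down slices them). A made-up illustration of that layout, not actual jinns internals:

```python
import jax.numpy as jnp

# Hypothetical times_x_inside_batch for dim = 2: column 0 is t, columns 1: are x.
t_x = jnp.array(
    [[0.0, 1.0, 4.0],
     [0.1, 2.0, 5.0],
     [0.2, 3.0, 6.0]]
)
times_batch = t_x[:, 0:1]  # shape (3, 1)
omega_batch = t_x[:, 1:]   # shape (3, 2)
```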
@@ -47,6 +46,18 @@ def append_obs_batch(batch, obs_batch_dict):
  return batch._replace(obs_batch_dict=obs_batch_dict)
 
 
+ def make_cartesian_product(b1, b2):
+ """
+ Create the cartesian product of a time batch and an omega or omega border batch
+ by tiling and repeating
+ """
+ n1 = b1.shape[0]
+ n2 = b2.shape[0]
+ b1 = jnp.repeat(b1, n2, axis=0)
+ b2 = jnp.tile(b2, reps=(n1,) + tuple(1 for i in b2.shape[1:]))
+ return jnp.concatenate([b1, b2], axis=1)
+
+
  def _reset_batch_idx_and_permute(operands):
  key, domain, curr_idx, _, p = operands
  # resetting counter
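A quick, hypothetical shape check of what the new `make_cartesian_product` helper computes via the repeat/tile trick above (illustration only):

```python
import jax.numpy as jnp

t = jnp.array([[0.0], [0.5]])                            # time batch, shape (2, 1)
x = jnp.array([[1.0, 10.0], [2.0, 20.0], [3.0, 30.0]])   # omega batch, shape (3, 2)

n1, n2 = t.shape[0], x.shape[0]
t_rep = jnp.repeat(t, n2, axis=0)              # (6, 1): each time repeated 3 times
x_til = jnp.tile(x, reps=(n1, 1))              # (6, 2): the omega batch stacked twice
t_x = jnp.concatenate([t_rep, x_til], axis=1)  # (6, 3): every (t, x) pair exactly once
```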
@@ -476,10 +487,9 @@ class CubicMeshPDEStatio(DataGeneratorPDEAbstract):
  # always set to 2.
  self.nb = 2
  self.omega_border_batch_size = 2
- # warnings.warn("We are in 1-D case => omega_border_batch_size is "
- # "ignored since borders of Omega are singletons."
- # " self.border_batch() will return [xmin, xmax]"
- # )
+ # We are in 1-D case => omega_border_batch_size is
+ # ignored since borders of Omega are singletons.
+ # self.border_batch() will return [xmin, xmax]
  else:
  if nb % (2 * self.dim) != 0 or nb < 2 * self.dim:
  raise ValueError(
@@ -829,6 +839,7 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
  rar_parameters=None,
  n_start=None,
  nt_start=None,
+ cartesian_product=True,
  data_exists=False,
  ):
  r"""
@@ -899,6 +910,10 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
  Defaults to None. A RAR hyper-parameter. Same as ``n_start`` but
  for times collocation point. See also ``DataGeneratorODE``
  documentation.
+ cartesian_product
+ Defaults to True. Whether we return the cartesian product of the
+ temporal batch with the inside and border batches. If False we just
+ return their concatenation.
  data_exists
  Must be left to `False` when created by the user. Avoids the
  regeneration of :math:`\Omega`, :math:`\partial\Omega` and
@@ -923,6 +938,30 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
  self.tmax = tmax
  self.nt = nt
 
+ self.cartesian_product = cartesian_product
+ if not self.cartesian_product:
+ if self.temporal_batch_size != self.omega_batch_size:
+ raise ValueError(
+ "If stacking is requested between the time and "
+ "inside batches of collocation points, self.temporal_batch_size "
+ "must then be equal to self.omega_batch_size"
+ )
+ if (
+ self.dim > 1
+ and self.omega_border_batch_size is not None
+ and self.temporal_batch_size != self.omega_border_batch_size
+ ):
+ raise ValueError(
+ "If dim > 1 and stacking is requested between the time and "
+ "inside batches of collocation points, self.temporal_batch_size "
+ "must then be equal to self.omega_border_batch_size"
+ )
+ # Note if self.dim == 1:
+ # print(
+ # "Cartesian product is not requested but will be "
+ # "executed anyway since dim=1"
+ # )
+
  # Set-up for timewise RAR (some quantity are already set-up by super())
  (
  self.nt_start,
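These checks exist because, when the cartesian product is disabled, the time and space batches are simply concatenated column-wise and must therefore share the same number of rows. A rough sketch of the two layouts with hypothetical batch sizes (not jinns code):

```python
import jax.numpy as jnp

temporal_batch_size, omega_batch_size, dim = 4, 4, 2
t = jnp.zeros((temporal_batch_size, 1))
x = jnp.ones((omega_batch_size, dim))

# cartesian_product=False: plain column-wise concatenation, hence the equal
# batch size requirement; the combined batch keeps 4 rows.
t_x_stacked = jnp.concatenate([t, x], axis=1)            # shape (4, 3)

# cartesian_product=True: every (t, x) pair, i.e.
# temporal_batch_size * omega_batch_size = 16 rows.
t_x_cartesian = jnp.concatenate(
    [jnp.repeat(t, omega_batch_size, axis=0),
     jnp.tile(x, reps=(temporal_batch_size, 1))],
    axis=1,
)                                                        # shape (16, 3)
```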
@@ -1003,11 +1042,26 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
  Generic method to return a batch. Here we call `self.inside_batch()`,
  `self.border_batch()` and `self.temporal_batch()`
  """
- return PDENonStatioBatch(
- inside_batch=self.inside_batch(),
- border_batch=self.border_batch(),
- temporal_batch=self.temporal_batch(),
- )
+ x = self.inside_batch()
+ dx = self.border_batch()
+ t = self.temporal_batch().reshape(self.temporal_batch_size, 1)
+
+ if self.cartesian_product:
+ t_x = make_cartesian_product(t, x)
+ else:
+ t_x = jnp.concatenate([t, x], axis=1)
+
+ if dx is not None:
+ t_ = t.reshape(self.temporal_batch_size, 1, 1)
+ t_ = jnp.repeat(t_, dx.shape[-1], axis=2)
+ if self.cartesian_product or self.dim == 1:
+ t_dx = make_cartesian_product(t_, dx)
+ else:
+ t_dx = jnp.concatenate([t_, dx], axis=1)
+ else:
+ t_dx = None
+
+ return PDENonStatioBatch(times_x_inside_batch=t_x, times_x_border_batch=t_dx)
 
  def tree_flatten(self):
  children = (
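For the border batch, the time column is first given a trailing facet axis so that the result can later be sliced per facet (`[:, 0:1, facet]` and `[:, 1:, facet]` in the boundary-condition code below). A shape-only sketch under assumed sizes:

```python
import jax.numpy as jnp

nt, nb, dim, n_facets = 4, 2, 2, 4        # hypothetical sizes for a 2-D domain
t = jnp.linspace(0.0, 1.0, nt).reshape(nt, 1)
dx = jnp.ones((nb, dim, n_facets))        # border batch with one block per facet

t_ = jnp.repeat(t.reshape(nt, 1, 1), n_facets, axis=2)   # (4, 1, 4)

# Same repeat/tile cartesian product as above, applied along axis 0 and
# concatenated along axis 1: channel 0 is time, channels 1: are space.
t_rep = jnp.repeat(t_, nb, axis=0)                       # (8, 1, 4)
dx_til = jnp.tile(dx, reps=(nt, 1, 1))                   # (8, 2, 4)
t_dx = jnp.concatenate([t_rep, dx_til], axis=1)          # (8, 3, 4)

times_on_facet_0 = t_dx[:, 0:1, 0]                       # (8, 1)
border_on_facet_0 = t_dx[:, 1:, 0]                       # (8, 2)
```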
@@ -1041,6 +1095,7 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
  "rar_parameters",
  "n_start",
  "nt_start",
+ "cartesian_product",
  ]
  }
  return (children, aux_data)
@@ -1121,12 +1176,17 @@ class DataGeneratorParameter:
  once during 1 epoch.
  param_batch_size
  An integer. The size of the batch of randomly selected points among
- the `n` points. `param_batch_size` will be the same for all the
- additional batch(es) of parameter(s). `param_batch_size` must be
- equal to `temporal_batch_size` or `omega_batch_size` or the product
- of both whether the present DataGeneratorParameter instance
- complements and ODEBatch, a PDEStatioBatch or a PDENonStatioBatch,
- respectively.
+ the `n` points. `param_batch_size` will be the same for all
+ additional batches of parameters.
+ NOTE: no check is done BUT users should be careful that
+ `param_batch_size` must be equal to `temporal_batch_size` or
+ `omega_batch_size` or the product of both. In the first case, the
+ present DataGeneratorParameter instance complements an ODEBatch, a
+ PDEStatioBatch or a PDENonStatioBatch (with self.cartesian_product
+ = False). In the second case, `param_batch_size` =
+ `temporal_batch_size * omega_batch_size` if the present
+ DataGeneratorParameter complements a PDENonStatioBatch
+ with self.cartesian_product = True.
  param_ranges
  A dict. A dict of tuples (min, max), which
  reprensents the range of real numbers where to sample batches (of
@@ -1335,13 +1395,18 @@ class DataGeneratorObservations:
  key
  Jax random key to sample new time points and to shuffle batches
  obs_batch_size
- An integer. The size of the batch of randomly selected observations
- `obs_batch_size` will be the same for all the
- elements of the obs dict. `obs_batch_size` must be
- equal to `temporal_batch_size` or `omega_batch_size` or the product
- of both whether the present DataGeneratorParameter instance
- complements and ODEBatch, a PDEStatioBatch or a PDENonStatioBatch,
- respectively.
+ An integer. The size of the batch of randomly selected points among
+ the `n` points. `obs_batch_size` will be the same for all
+ elements of the obs dict.
+ NOTE: no check is done BUT users should be careful that
+ `obs_batch_size` must be equal to `temporal_batch_size` or
+ `omega_batch_size` or the product of both. In the first case, the
+ present DataGeneratorObservations instance complements an ODEBatch,
+ PDEStatioBatch or a PDENonStatioBatch (with self.cartesian_product
+ = False). In the second case, `obs_batch_size` =
+ `temporal_batch_size * omega_batch_size` if the present
+ DataGeneratorObservations complements a PDENonStatioBatch
+ with self.cartesian_product = True.
  observed_pinn_in
  A jnp.array with 2 dimensions.
  Observed values corresponding to the input of the PINN
@@ -1549,11 +1614,16 @@ class DataGeneratorObservationsMultiPINNs:
  obs_batch_size
  An integer. The size of the batch of randomly selected observations
  `obs_batch_size` will be the same for all the
- elements of the obs dict. `obs_batch_size` must be
- equal to `temporal_batch_size` or `omega_batch_size` or the product
- of both whether the present DataGeneratorParameter instance
- complements and ODEBatch, a PDEStatioBatch or a PDENonStatioBatch,
- respectively.
+ elements of the obs dict.
+ NOTE: no check is done BUT users should be careful that
+ `obs_batch_size` must be equal to `temporal_batch_size` or
+ `omega_batch_size` or the product of both. In the first case, the
+ present DataGeneratorObservations instance complements an ODEBatch,
+ PDEStatioBatch or a PDENonStatioBatch (with self.cartesian_product
+ = False). In the second case, `obs_batch_size` =
+ `temporal_batch_size * omega_batch_size` if the present
+ DataGeneratorObservations complements a PDENonStatioBatch
+ with self.cartesian_product = True.
  observed_pinn_in_dict
  A dict of observed_pinn_in as defined in DataGeneratorObservations.
  Keys must be that of `u_dict`.
jinns/loss/_LossODE.py CHANGED
@@ -459,7 +459,7 @@ class SystemLossODE:
 
  Parameters
  ---------
- params
+ params_dict
  A dictionary of dictionaries of parameters of the model.
  Typically, it is a dictionary of dictionaries of
  dictionaries: `eq_params` and `nn_params``, respectively the
@@ -489,7 +489,7 @@
  # and update vmap_in_axes
  if batch.param_batch_dict is not None:
  # update params with the batches of generated params
- params = _update_eq_params_dict(params, batch.param_batch_dict)
+ params_dict = _update_eq_params_dict(params_dict, batch.param_batch_dict)
 
  vmap_in_axes_params = _get_vmap_in_axes_params(
  batch.param_batch_dict, params_dict
jinns/loss/_LossPDE.py CHANGED
@@ -818,17 +818,9 @@ class LossPDENonStatio(LossPDEStatio):
  inputs/outputs/parameters
  """
 
- omega_batch, omega_border_batch, times_batch = (
- batch.inside_batch,
- batch.border_batch,
- batch.temporal_batch,
- )
+ times_batch = batch.times_x_inside_batch[:, 0:1]
+ omega_batch = batch.times_x_inside_batch[:, 1:]
  n = omega_batch.shape[0]
- nt = times_batch.shape[0]
- times_batch = times_batch.reshape(nt, 1)
-
- def rep_times(k):
- return jnp.repeat(times_batch, k, axis=0)
 
  vmap_in_axes_x_t = (0, 0)
 
@@ -844,10 +836,6 @@
 
  vmap_in_axes_params = _get_vmap_in_axes_params(batch.param_batch_dict, params)
 
- if isinstance(self.u, PINN):
- omega_batch = jnp.tile(omega_batch, reps=(nt, 1)) # it is tiled
- times_batch = rep_times(n) # it is repeated
-
  # dynamic part
  params_ = _set_derivatives(params, "dyn_loss", self.derivative_keys)
  if self.dynamic_loss is not None:
@@ -1372,27 +1360,12 @@ class SystemLossPDE:
 
  if isinstance(batch, PDEStatioBatch):
  omega_batch, _ = batch.inside_batch, batch.border_batch
- n = omega_batch.shape[0]
  vmap_in_axes_x_or_x_t = (0,)
 
  batches = (omega_batch,)
  elif isinstance(batch, PDENonStatioBatch):
- omega_batch, _, times_batch = (
- batch.inside_batch,
- batch.border_batch,
- batch.temporal_batch,
- )
- n = omega_batch.shape[0]
- nt = times_batch.shape[0]
- times_batch = times_batch.reshape(nt, 1)
-
- def rep_times(k):
- return jnp.repeat(times_batch, k, axis=0)
-
- # Moreover...
- if isinstance(list(self.u_dict.values())[0], PINN):
- omega_batch = jnp.tile(omega_batch, reps=(nt, 1)) # it is tiled
- times_batch = rep_times(n) # it is repeated
+ times_batch = batch.times_x_inside_batch[:, 0:1]
+ omega_batch = batch.times_x_inside_batch[:, 1:]
  batches = (omega_batch, times_batch)
  vmap_in_axes_x_or_x_t = (0, 0)
 
jinns/loss/_Losses.py CHANGED
@@ -120,16 +120,23 @@ def boundary_condition_apply(
  else:
  raise ValueError("Other border batches are not implemented")
  b_losses_by_facet = jax.tree_util.tree_map(
- lambda c, f, fa, d: jnp.mean(
- loss_weight * _compute_boundary_loss(c, f, batch, u, params, fa, d)
+ lambda c, f, fa, d: (
+ None
+ if c is None
+ else jnp.mean(
+ loss_weight * _compute_boundary_loss(c, f, batch, u, params, fa, d)
+ )
  ),
  omega_boundary_condition,
  omega_boundary_fun,
  facet_tree,
  omega_boundary_dim,
+ is_leaf=lambda x: x is None,
  ) # when exploring leaves with None value (no condition) the returned
  # mse is None and we get rid of the None leaves of b_losses_by_facet
  # with the tree_leaves below
+ # Note that to keep the behaviour given in the comment above we needed
+ # to specify is_leaf according to the note in the release of 0.4.29
  else:
  facet_tuple = tuple(f for f in range(batch[1].shape[-1]))
  b_losses_by_facet = jax.tree_util.tree_map(
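The `is_leaf` argument matters because `jax.tree_util.tree_map` treats `None` as an empty subtree by default, so facets without a boundary condition would never reach the lambda (and mapping them against companion trees that do hold values there would fail). A minimal, standalone illustration of the pattern with a hypothetical dict:

```python
import jax

conditions = {"xmin": None, "xmax": "dirichlet"}   # None = no condition on that facet
values = {"xmin": 1.0, "xmax": 2.0}

# With is_leaf, None is handed to the lambda as a genuine leaf and can be
# propagated; without it, tree_map would consider the structures incompatible.
out = jax.tree_util.tree_map(
    lambda c, v: None if c is None else v * 10,
    conditions,
    values,
    is_leaf=lambda x: x is None,
)
# out == {"xmin": None, "xmax": 20.0}
```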
@@ -279,23 +279,10 @@ def boundary_dirichlet_nonstatio(f, batch, u, params, facet, dim_to_apply):
  dim_to_apply
  A jnp.s\_ object. The dimension of u on which to apply the boundary condition
  """
- _, omega_border_batch, times_batch = (
- batch.inside_batch,
- batch.border_batch,
- batch.temporal_batch,
- )
- nt = times_batch.shape[0]
- times_batch = times_batch.reshape(nt, 1)
- omega_border_batch = omega_border_batch[..., facet]
+ times_batch = batch.times_x_border_batch[:, 0:1, facet]
+ omega_border_batch = batch.times_x_border_batch[:, 1:, facet]
 
  if isinstance(u, PINN):
- tile_omega_border_batch = jnp.tile(
- omega_border_batch, reps=(times_batch.shape[0], 1)
- )
-
- def rep_times(k):
- return jnp.repeat(times_batch, k, axis=0)
-
  vmap_in_axes_params = _get_vmap_in_axes_params(batch.param_batch_dict, params)
  vmap_in_axes_x_t = (0, 0)
 
@@ -309,26 +296,13 @@ def boundary_dirichlet_nonstatio(f, batch, u, params, facet, dim_to_apply):
  vmap_in_axes_x_t + vmap_in_axes_params,
  0,
  )
- res = v_u_boundary(
- rep_times(omega_border_batch.shape[0]), tile_omega_border_batch, params
- )
+ res = v_u_boundary(times_batch, omega_border_batch, params)
  mse_u_boundary = jnp.sum(
  res**2,
  axis=-1,
  )
  elif isinstance(u, SPINN):
- tile_omega_border_batch = jnp.tile(
- omega_border_batch, reps=(times_batch.shape[0], 1)
- )
-
- if omega_border_batch.shape[0] == 1:
- omega_border_batch = jnp.tile(
- omega_border_batch, reps=(times_batch.shape[0], 1)
- )
- # otherwise we require batches to have same shape and we do not need
- # this operation
-
- values = u(times_batch, tile_omega_border_batch, params)[..., dim_to_apply]
+ values = u(times_batch, omega_border_batch, params)[..., dim_to_apply]
  tx_grid = _get_grid(jnp.concatenate([times_batch, omega_border_batch], axis=-1))
  boundaries = _check_user_func_return(
  f(tx_grid[..., 0:1], tx_grid[..., 1:]), values.shape
@@ -367,14 +341,8 @@ def boundary_neumann_nonstatio(f, batch, u, params, facet, dim_to_apply):
  dim_to_apply
  A jnp.s\_ object. The dimension of u on which to apply the boundary condition
  """
- _, omega_border_batch, times_batch = (
- batch.inside_batch,
- batch.border_batch,
- batch.temporal_batch,
- )
- nt = times_batch.shape[0]
- times_batch = times_batch.reshape(nt, 1)
- omega_border_batch = omega_border_batch[..., facet]
+ times_batch = batch.times_x_border_batch[:, 0:1, facet]
+ omega_border_batch = batch.times_x_border_batch[:, 1:, facet]
 
  # We resort to the shape of the border_batch to determine the dimension as
  # described in the border_batch function
@@ -388,13 +356,6 @@ def boundary_neumann_nonstatio(f, batch, u, params, facet, dim_to_apply):
  n = jnp.array([[-1, 1, 0, 0], [0, 0, -1, 1]])
 
  if isinstance(u, PINN):
- tile_omega_border_batch = jnp.tile(
- omega_border_batch, reps=(times_batch.shape[0], 1)
- )
-
- def rep_times(k):
- return jnp.repeat(times_batch, k, axis=0)
-
  vmap_in_axes_params = _get_vmap_in_axes_params(batch.param_batch_dict, params)
  vmap_in_axes_x_t = (0, 0)
 
@@ -411,8 +372,8 @@ def boundary_neumann_nonstatio(f, batch, u, params, facet, dim_to_apply):
  mse_u_boundary = jnp.sum(
  (
  v_neumann(
- rep_times(omega_border_batch.shape[0]),
- tile_omega_border_batch,
+ times_batch,
+ omega_border_batch,
  params,
  )
  )
@@ -421,14 +382,6 @@ def boundary_neumann_nonstatio(f, batch, u, params, facet, dim_to_apply):
  )
  )
 
  elif isinstance(u, SPINN):
- if omega_border_batch.shape[0] == 1:
- omega_border_batch = jnp.tile(
- omega_border_batch, reps=(times_batch.shape[0], 1)
- )
- # ie case 1D
- # otherwise we require batches to have same shape and we do not need
- # this operation
-
  # the gradient we see in the PINN case can get gradients wrt to x
  # dimensions at once. But it would be very inefficient in SPINN because
jinns/solver/_rar.py CHANGED
@@ -264,6 +264,9 @@ def _rar_step_init(sample_size, selected_sample_size):
  )
 
  elif isinstance(data, CubicMeshPDENonStatio):
+ if isinstance(loss.u, HYPERPINN) or isinstance(loss.u, SPINN):
+ raise NotImplementedError("RAR not implemented for hyperPINN and SPINN")
+
  # NOTE in this case sample_size and selected_sample_size
  # are tuples (times, omega) => we unpack them for clarity
  selected_sample_size_times, selected_sample_size_omega = (
@@ -274,14 +277,15 @@ def _rar_step_init(sample_size, selected_sample_size):
  new_times_samples = data.sample_in_time_domain(sample_size_times)
  new_omega_samples = data.sample_in_omega_domain(sample_size_omega)
 
- if isinstance(loss.u, HYPERPINN) or isinstance(loss.u, SPINN):
- raise NotImplementedError("RAR not implemented for hyperPINN and SPINN")
+ if not data.cartesian_product:
+ times = new_times_samples
+ omega = new_omega_samples
  else:
  # do cartesian product on new points
- tile_omega = jnp.tile(
+ omega = jnp.tile(
  new_omega_samples, reps=(sample_size_times, 1)
  ) # it is tiled
- repeat_times = jnp.repeat(new_times_samples, sample_size_omega, axis=0)[
+ times = jnp.repeat(new_times_samples, sample_size_omega, axis=0)[
  ..., None
  ] # it is repeated + add an axis
 
@@ -291,7 +295,7 @@ def _rar_step_init(sample_size, selected_sample_size):
  (0, 0),
  0,
  )
- dyn_on_s = v_dyn_loss(repeat_times, tile_omega).reshape(
+ dyn_on_s = v_dyn_loss(times, omega).reshape(
  (sample_size_times, sample_size_omega)
  )
  mse_on_s = dyn_on_s**2
@@ -305,7 +309,7 @@ def _rar_step_init(sample_size, selected_sample_size):
  (0, 0),
  0,
  )
- dyn_on_s += v_dyn_loss(repeat_times, tile_omega).reshape(
+ dyn_on_s += v_dyn_loss(times, omega).reshape(
  (sample_size_times, sample_size_omega)
  )
 
jinns/solver/_solve.py CHANGED
@@ -61,6 +61,7 @@ def solve(
  obs_data=None,
  validation=None,
  obs_batch_sharding=None,
+ verbose=True,
  ):
  """
  Performs the optimization process via stochastic gradient descent
@@ -132,6 +133,9 @@ def solve(
  Typically, a SingleDeviceSharding(gpu_device) when obs_data has been
  created with sharding_device=SingleDeviceSharding(cpu_device) to avoid
  loading on GPU huge datasets of observations
+ verbose:
+ Boolean, default True. If False, no std output (loss or cause of
+ exiting the optimization loop) will be produced.
 
  Returns
  -------
@@ -203,11 +207,14 @@ def solve(
  data=data, param_data=param_data, obs_data=obs_data
  )
  optimization = OptimizationContainer(
- params=init_params, last_non_nan_params=init_params.copy(), opt_state=opt_state
+ params=init_params,
+ last_non_nan_params=init_params.copy(),
+ opt_state=opt_state,
  )
  optimization_extra = OptimizationExtraContainer(
  curr_seq=curr_seq,
  seq2seq=seq2seq,
+ best_val_params=init_params.copy(),
  )
  loss_container = LossContainer(
  stored_loss_terms=stored_loss_terms,
@@ -222,7 +229,7 @@ def solve(
  else:
  validation_crit_values = None
 
- break_fun = get_break_fun(n_iter)
+ break_fun = get_break_fun(n_iter, verbose)
 
  iteration = 0
  carry = (
@@ -272,7 +279,8 @@ def solve(
  )
 
  # Print train loss value during optimization
- print_fn(i, train_loss_value, print_loss_every, prefix="[train] ")
+ if verbose:
+ print_fn(i, train_loss_value, print_loss_every, prefix="[train] ")
 
  if validation is not None:
  # there is a jax.lax.cond because we do not necesarily call the
@@ -281,6 +289,7 @@ def solve(
  validation, # always return `validation` for in-place mutation
  early_stopping,
  validation_criterion,
+ update_best_params,
  ) = jax.lax.cond(
  i % validation.call_every == 0,
  lambda operands: operands[0](*operands[1:]), # validation.__call__()
@@ -288,6 +297,7 @@ def solve(
  operands[0],
  False,
  validation_crit_values[i - 1],
+ False,
  ),
  (
  validation, # validation must be in operands
@@ -295,12 +305,24 @@ def solve(
  ),
  )
  # Print validation loss value during optimization
- print_fn(i, validation_criterion, print_loss_every, prefix="[validation] ")
+ if verbose:
+ print_fn(
+ i, validation_criterion, print_loss_every, prefix="[validation] "
+ )
  validation_crit_values = validation_crit_values.at[i].set(
  validation_criterion
  )
+
+ # update best_val_params w.r.t val_loss if needed
+ best_val_params = jax.lax.cond(
+ update_best_params,
+ lambda _: params, # update with current value
+ lambda operands: operands[0].best_val_params, # unchanged
+ (optimization_extra,),
+ )
  else:
  early_stopping = False
+ best_val_params = params
 
  # Trigger RAR
  loss, params, data = trigger_rar(
@@ -329,13 +351,17 @@ def solve(
  loss_terms,
  tracked_params,
  )
+
+ # increment iteration number
  i += 1
 
  return (
  i,
  loss,
  OptimizationContainer(params, last_non_nan_params, opt_state),
- OptimizationExtraContainer(curr_seq, seq2seq, early_stopping),
+ OptimizationExtraContainer(
+ curr_seq, seq2seq, best_val_params, early_stopping
+ ),
  DataGeneratorContainer(data, param_data, obs_data),
  validation,
  LossContainer(stored_loss_terms, train_loss_values),
@@ -364,36 +390,28 @@ def solve(
  validation_crit_values,
  ) = carry
 
- jax.debug.print(
- "Final iteration {i}: train loss value = {train_loss_val}",
- i=i,
- train_loss_val=loss_container.train_loss_values[i - 1],
- )
+ if verbose:
+ jax.debug.print(
+ "Final iteration {i}: train loss value = {train_loss_val}",
+ i=i,
+ train_loss_val=loss_container.train_loss_values[i - 1],
+ )
  if validation is not None:
  jax.debug.print(
  "validation loss value = {validation_loss_val}",
  validation_loss_val=validation_crit_values[i - 1],
  )
 
- if validation is None:
- return (
- optimization.last_non_nan_params,
- loss_container.train_loss_values,
- loss_container.stored_loss_terms,
- train_data.data,
- loss,
- optimization.opt_state,
- stored_objects.stored_params,
- )
  return (
  optimization.last_non_nan_params,
  loss_container.train_loss_values,
  loss_container.stored_loss_terms,
- train_data.data,
- loss,
+ train_data.data, # return the DataGenerator if needed (no in-place modif)
+ loss, # return the Loss if needed (no in-place modif)
  optimization.opt_state,
  stored_objects.stored_params,
- validation_crit_values,
+ validation_crit_values if validation is not None else None,
+ optimization_extra.best_val_params if validation is not None else None,
  )
 
 
@@ -453,15 +471,20 @@ def store_loss_and_params(
  tracked_params,
  ):
  stored_params = jax.tree_util.tree_map(
- lambda stored_value, param, tracked_param: jax.lax.cond(
- tracked_param,
- lambda ope: ope[0].at[i].set(ope[1]),
- lambda ope: ope[0],
- (stored_value, param),
+ lambda stored_value, param, tracked_param: (
+ None
+ if stored_value is None
+ else jax.lax.cond(
+ tracked_param,
+ lambda ope: ope[0].at[i].set(ope[1]),
+ lambda ope: ope[0],
+ (stored_value, param),
+ )
  ),
  stored_params,
  params,
  tracked_params,
+ is_leaf=lambda x: x is None,
  )
  stored_loss_terms = jax.tree_util.tree_map(
  lambda stored_term, loss_term: stored_term.at[i].set(loss_term),
@@ -473,16 +496,20 @@ def store_loss_and_params(
  return (stored_params, stored_loss_terms, train_loss_values)
 
 
- def get_break_fun(n_iter):
+ def get_break_fun(n_iter, verbose: str):
  """
- Wrapper to get the break_fun with appropriate `n_iter`
+ Wrapper to get the break_fun with appropriate `n_iter`.
+ The verbose argument is here to control printing (or not) when exiting
+ the optimisation loop. It can be convenient if jinns.solve is itself
+ called in a loop and users want to avoid std output.
  """
 
  @jit
- def break_fun(carry):
+ def break_fun(carry: tuple):
  """
  Function to break from the main optimization loop
- We check several conditions
+ We check the following conditions: maximum number of iterations, NaN
+ appearing in the parameters, and the early stopping criterion.
  """
 
  def stop_while_loop(msg):
@@ -490,7 +517,8 @@ def get_break_fun(n_iter):
  Note that the message is wrapped in the jax.lax.cond because a
  string is not a valid JAX type that can be fed into the operands
  """
- jax.debug.print(f"Stopping main optimization loop, cause: {msg}")
+ if verbose:
+ jax.debug.print(f"Stopping main optimization loop, cause: {msg}")
  return False
 
  def continue_while_loop(_):
@@ -45,6 +45,7 @@ class OptimizationContainer(NamedTuple):
  class OptimizationExtraContainer(NamedTuple):
  curr_seq: int
  seq2seq: Union[dict, None]
+ best_val_params: dict
  early_stopping: bool = False
 
 
jinns/utils/_hyperpinn.py CHANGED
@@ -78,15 +78,9 @@ class HYPERPINN(PINN):
  parameters of the pinn (`self.params`)
  """
  pinn_params_flat = eqx.tree_at(
- lambda p: tree_leaves(p, is_leaf=lambda x: isinstance(x, jnp.ndarray)),
+ lambda p: tree_leaves(p, is_leaf=eqx.is_array),
  self.params,
- [hyper_output[0 : self.pinn_params_cumsum[0]]]
- + [
- hyper_output[
- self.pinn_params_cumsum[i] : self.pinn_params_cumsum[i + 1]
- ]
- for i in range(len(self.pinn_params_cumsum) - 1)
- ],
+ jnp.split(hyper_output, self.pinn_params_cumsum[:-1]),
  )
 
  return tree_map(
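The rewrite replaces the hand-rolled slicing with `jnp.split` at the cumulative parameter sizes. A standalone sketch of that splitting step, with hypothetical block sizes rather than the actual PINN layer shapes:

```python
import numpy as np
import jax.numpy as jnp

# Hypothetical flattened parameter block sizes (not the real jinns layer shapes).
sizes = [4, 6, 2]
cumsum = np.cumsum(sizes)            # array([ 4, 10, 12])
hyper_output = jnp.arange(12.0)      # flat output of the hypernetwork

# Splitting at cumsum[:-1] recovers one flat array per parameter block, which
# eqx.tree_at can then plug back into the PINN parameter pytree.
blocks = jnp.split(hyper_output, cumsum[:-1])
assert [b.shape for b in blocks] == [(4,), (6,), (2,)]
```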
jinns/utils/_spinn.py CHANGED
@@ -81,6 +81,9 @@ class SPINN(eqx.Module):
  Basically a wrapper around the `__call__` function to be able to give a type to
  our former `self.u`
  The function create_SPINN has the role to population the `__call__` function
+
+ **NOTE**: SPINNs with `t` and `x` as inputs are best used with a
+ DataGenerator with `self.cartesian_product=False` for memory consideration
  """
 
  d: int
@@ -191,6 +194,9 @@ def create_SPINN(key, d, r, eqx_list, eq_type, m=1):
  then sum groups of `r` embedding dimensions to compute each output.
  Default is 1.
 
+ **NOTE**: SPINNs with `t` and `x` as inputs are best used with a
+ DataGenerator with `self.cartesian_product=False` for memory consideration
+
 
  Returns
  -------
@@ -42,7 +42,7 @@ class AbstractValidationModule(eqx.Module):
  @abc.abstractmethod
  def __call__(
  self, params: PyTree
- ) -> tuple["AbstractValidationModule", Bool, Array]:
+ ) -> tuple["AbstractValidationModule", Bool, Array, Bool]:
  raise NotImplementedError
 
 
@@ -86,10 +86,10 @@ class ValidationLoss(AbstractValidationModule):
  )
 
  validation_loss_value, _ = self.loss(params, val_batch)
- (counter, best_val_loss) = jax.lax.cond(
+ (counter, best_val_loss, update_best_params) = jax.lax.cond(
  validation_loss_value < self.best_val_loss,
- lambda _: (jnp.array(0.0), validation_loss_value), # reset
- lambda operands: (operands[0] + 1, operands[1]), # increment
+ lambda _: (jnp.array(0.0), validation_loss_value, True), # reset
+ lambda operands: (operands[0] + 1, operands[1], False), # increment
  (self.counter, self.best_val_loss),
  )
 
@@ -108,7 +108,7 @@ class ValidationLoss(AbstractValidationModule):
  None,
  )
  # return `new` cause no in-place modification of the eqx.Module
- return (new, bool_early_stopping, validation_loss_value)
+ return (new, bool_early_stopping, validation_loss_value, update_best_params)
 
 
  if __name__ == "__main__":
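The extra `update_best_params` flag follows the usual `jax.lax.cond` pattern of returning one more element from both branches. A minimal sketch of that pattern, assuming a simple counter/best-loss pair rather than the actual ValidationLoss fields:

```python
import jax
import jax.numpy as jnp

def validation_step(counter, best, new_value):
    # Hypothetical mirror of the pattern above: reset the patience counter and
    # flag a parameter update when the validation loss improves, otherwise
    # increment the counter and keep the previous best value.
    return jax.lax.cond(
        new_value < best,
        lambda _: (jnp.array(0.0), new_value, True),
        lambda ops: (ops[0] + 1.0, ops[1], False),
        (counter, best),
    )

counter, best, update_best_params = validation_step(
    jnp.array(2.0), jnp.array(0.5), jnp.array(0.3)
)
# counter == 0.0, best == 0.3, update_best_params == True
```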
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: jinns
- Version: 0.8.10
+ Version: 0.9.0
  Summary: Physics Informed Neural Network with JAX
  Author-email: Hugo Gangloff <hugo.gangloff@inrae.fr>, Nicolas Jouvin <nicolas.jouvin@inrae.fr>
  Maintainer-email: Hugo Gangloff <hugo.gangloff@inrae.fr>, Nicolas Jouvin <nicolas.jouvin@inrae.fr>
@@ -20,6 +20,7 @@ Requires-Dist: optax
  Requires-Dist: equinox
  Requires-Dist: jax-tqdm
  Requires-Dist: diffrax
+ Requires-Dist: matplotlib
  Provides-Extra: notebook
  Requires-Dist: jupyter ; extra == 'notebook'
  Requires-Dist: matplotlib ; extra == 'notebook'
@@ -1,5 +1,5 @@
  jinns/__init__.py,sha256=T2XlmLbYqcXTumPJL00cJ80W98We5LH8Yg_Lss_exl4,139
- jinns/data/_DataGenerators.py,sha256=_um-giHQ8mCILUOJHX231njHTHZp4S7EcGrUs7R1dUs,61829
+ jinns/data/_DataGenerators.py,sha256=aR-L2ZTHnogU0CFRemtcg1PHlQCp_UZFPnXvxE8hBj0,65069
  jinns/data/__init__.py,sha256=yBOmoavSD-cABp4XcjQY1zsEVO0mDyIhi2MJ5WNp0l8,326
  jinns/data/_display.py,sha256=vlqggDCgVMEwdGBtjVmZaTQORU6imSfDkssn2XCtITI,10392
  jinns/experimental/__init__.py,sha256=qWbhC7Z8UgLWy0t-zU7RYze6v13-FngiCYXu-2bRVFQ,296
@@ -8,29 +8,29 @@ jinns/experimental/_sinuspinn.py,sha256=hxSzscwMV2LayWOqenIlT1zqEVVrE5Y8CKf7bHX5
  jinns/experimental/_spectralpinn.py,sha256=-4795pa7AYtRNSE-ugan3gHh64mtu2VdrRG5AS_J9Eg,2654
  jinns/loss/_DynamicLoss.py,sha256=L4CVmmF0rTPbHntgqsLLHlnrlQgLHsetUocpJm7ZYag,27461
  jinns/loss/_DynamicLossAbstract.py,sha256=kTQlhLx7SBuH5dIDmYaE79sVHUZt1nUFa8LxPU5IHhM,8504
- jinns/loss/_LossODE.py,sha256=Y1mxryPVFf7ruqw_mGNACLExfx4iQT4R2bZP3s5rg4c,22172
- jinns/loss/_LossPDE.py,sha256=purAEtc0e71kv9XnZUT-a7MrkDAkM_3tTI4xJPu6fH4,61629
- jinns/loss/_Losses.py,sha256=XOL3MFiKEd3ndsc78Qnpi1vbgR0B2HaAWOGGW2meDM8,11190
+ jinns/loss/_LossODE.py,sha256=Ava8kcAzj4-GiUkHJfZyhq15vGw7ABXk8Zpa3ynMlmY,22187
+ jinns/loss/_LossPDE.py,sha256=zsm5gxIU2aRq_UdBOkQa6CyVXU6lqu740W8iAghDHV8,60707
+ jinns/loss/_Losses.py,sha256=4tRnOgT31ZMDmdAKRuJv1e4Ob5zS5a54YnIVOQr0-uc,11480
  jinns/loss/__init__.py,sha256=pFNYUxns-NPXBFdqrEVSiXkQLfCtKw-t2trlhvLzpYE,355
- jinns/loss/_boundary_conditions.py,sha256=YfSnLZ25hXqQ5KWAuxOrWSKkf_oBqAc9GQV4z7MjWyQ,17434
+ jinns/loss/_boundary_conditions.py,sha256=lqva0LKyPWls0zSTD-UgWfh_Aul_Q88NTN2Lz7KMM1M,15978
  jinns/loss/_operators.py,sha256=zDGJqYqeYH7xd-4dtGX9PS-pf0uSOpUUXGo5SVjIJ4o,11069
  jinns/solver/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- jinns/solver/_rar.py,sha256=IYP-jdbM0rbjBtxislrBYBuj49p9_QDOqejZKCHrKg8,17072
+ jinns/solver/_rar.py,sha256=wxoq9yNGd3UR6669FhqXJgD-LeKPBJ01t38alXculN8,17164
  jinns/solver/_seq2seq.py,sha256=S6IPfsXpS_fbqIqAy01eUM7GBSBSkRzURan_J-iXXzI,5632
- jinns/solver/_solve.py,sha256=mGi0zaT_fK_QpBjTxof5Ix4mmfmnPi66CNJ3GQFZuo4,19099
+ jinns/solver/_solve.py,sha256=QGaiH9EXgk8t19250qB1KhYMNuIMlv0YL6RCM710K58,20376
  jinns/utils/__init__.py,sha256=44ms5UR6vMw3Nf6u4RCAzPFs4fom_YbBnH9mfne8m6k,313
- jinns/utils/_containers.py,sha256=eYD277fO7X4EfX7PUFCCl69r3JBfh1sCfq8LkL5gd6o,1495
- jinns/utils/_hyperpinn.py,sha256=93hbiATdp5W4l1cu9Oe6O2c45o-ZF_z2u6FzNLyjnm4,10878
+ jinns/utils/_containers.py,sha256=nH67kpgN3Ir6M31rqApNVR7xHV7bWzFtQsrDAYuQHDo,1521
+ jinns/utils/_hyperpinn.py,sha256=gzT6keKvlLPZ5VnBxvW-IBDyLuAtPLgNdhtPulosmfc,10638
  jinns/utils/_optim.py,sha256=550kxH75TL30o1iKx1swJyP0KqyUPsJ7-imL1w65Qd0,4444
  jinns/utils/_pinn.py,sha256=mhA4-3PazyQTbWIx9oLaNwL0QDe8ZIBhbiy5J3kwa4I,9471
  jinns/utils/_save_load.py,sha256=qgZ23nUcB8-B5IZ2guuUWC4M7r5Lxd_Ms3staScdyJo,5668
- jinns/utils/_spinn.py,sha256=SzOUt1KHtB9QOpghpvitnXN-KEqXUXbvabC5k0TnKEo,7793
+ jinns/utils/_spinn.py,sha256=2CKTQv2PCvprJJiUlKv3eeo6SIZ5ZCTdajL0D1sul90,8093
  jinns/utils/_utils.py,sha256=8dgvWXX9NT7_7-zltWp0C9tG45ZFNwXxueyxPBb4hjo,6740
  jinns/utils/_utils_uspinn.py,sha256=qcKcOw3zrwWSQyGVj6fD8c9GinHt_U6JWN_k0auTtXM,26039
  jinns/validation/__init__.py,sha256=Jv58mzgC3F7cRfXA6caicL1t_U0UAhbwLrmMNVg6E7s,66
- jinns/validation/_validation.py,sha256=KfetbzB0xTNdBcYLwFWjEtP63Tf9wJirlhgqLTJDyy4,6761
- jinns-0.8.10.dist-info/LICENSE,sha256=BIAkGtXB59Q_BG8f6_OqtQ1BHPv60ggE9mpXJYz2dRM,11337
- jinns-0.8.10.dist-info/METADATA,sha256=5lGoyi2W9MRamQdHVgZnYflJtp__zWDGyTiYgmfGc6g,2483
- jinns-0.8.10.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
- jinns-0.8.10.dist-info/top_level.txt,sha256=RXbkr2hzy8WBE8aiRyrJYFqn3JeMJIhMdybLjjLTB9c,6
- jinns-0.8.10.dist-info/RECORD,,
+ jinns/validation/_validation.py,sha256=ob2b9l2txJDujM8QR3RrS_eUrrz2RPsSadM07zXJJPs,6820
+ jinns-0.9.0.dist-info/LICENSE,sha256=BIAkGtXB59Q_BG8f6_OqtQ1BHPv60ggE9mpXJYz2dRM,11337
+ jinns-0.9.0.dist-info/METADATA,sha256=e-K6ITCYu6DogJd5P6jg5C1vc7MNgHEoxdVmY4H-LcQ,2508
+ jinns-0.9.0.dist-info/WHEEL,sha256=y4mX-SOX4fYIkonsAGA5N0Oy-8_gI4FXw5HNI1xqvWg,91
+ jinns-0.9.0.dist-info/top_level.txt,sha256=RXbkr2hzy8WBE8aiRyrJYFqn3JeMJIhMdybLjjLTB9c,6
+ jinns-0.9.0.dist-info/RECORD,,
@@ -1,5 +1,5 @@
  Wheel-Version: 1.0
- Generator: bdist_wheel (0.43.0)
+ Generator: setuptools (70.2.0)
  Root-Is-Purelib: true
  Tag: py3-none-any