jinns-0.8.10-py3-none-any.whl → jinns-1.0.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. jinns/__init__.py +2 -0
  2. jinns/data/_Batchs.py +27 -0
  3. jinns/data/_DataGenerators.py +953 -1182
  4. jinns/data/__init__.py +4 -8
  5. jinns/experimental/__init__.py +0 -2
  6. jinns/experimental/_diffrax_solver.py +5 -5
  7. jinns/loss/_DynamicLoss.py +282 -305
  8. jinns/loss/_DynamicLossAbstract.py +321 -168
  9. jinns/loss/_LossODE.py +290 -307
  10. jinns/loss/_LossPDE.py +628 -1040
  11. jinns/loss/__init__.py +21 -5
  12. jinns/loss/_boundary_conditions.py +95 -96
  13. jinns/loss/{_Losses.py → _loss_utils.py} +104 -46
  14. jinns/loss/_loss_weights.py +59 -0
  15. jinns/loss/_operators.py +78 -72
  16. jinns/parameters/__init__.py +6 -0
  17. jinns/parameters/_derivative_keys.py +94 -0
  18. jinns/parameters/_params.py +115 -0
  19. jinns/plot/__init__.py +5 -0
  20. jinns/{data/_display.py → plot/_plot.py} +98 -75
  21. jinns/solver/_rar.py +193 -45
  22. jinns/solver/_solve.py +199 -144
  23. jinns/utils/__init__.py +3 -9
  24. jinns/utils/_containers.py +37 -43
  25. jinns/utils/_hyperpinn.py +226 -127
  26. jinns/utils/_pinn.py +183 -111
  27. jinns/utils/_save_load.py +121 -56
  28. jinns/utils/_spinn.py +117 -84
  29. jinns/utils/_types.py +64 -0
  30. jinns/utils/_utils.py +6 -160
  31. jinns/validation/_validation.py +52 -144
  32. {jinns-0.8.10.dist-info → jinns-1.0.0.dist-info}/METADATA +5 -4
  33. jinns-1.0.0.dist-info/RECORD +38 -0
  34. {jinns-0.8.10.dist-info → jinns-1.0.0.dist-info}/WHEEL +1 -1
  35. jinns/experimental/_sinuspinn.py +0 -135
  36. jinns/experimental/_spectralpinn.py +0 -87
  37. jinns/solver/_seq2seq.py +0 -157
  38. jinns/utils/_optim.py +0 -147
  39. jinns/utils/_utils_uspinn.py +0 -727
  40. jinns-0.8.10.dist-info/RECORD +0 -36
  41. {jinns-0.8.10.dist-info → jinns-1.0.0.dist-info}/LICENSE +0 -0
  42. {jinns-0.8.10.dist-info → jinns-1.0.0.dist-info}/top_level.txt +0 -0
jinns/solver/_rar.py CHANGED
@@ -1,21 +1,32 @@
1
+ from __future__ import (
2
+ annotations,
3
+ ) # https://docs.python.org/3/library/typing.html#constant
4
+
5
+ from typing import TYPE_CHECKING, Callable
6
+ from functools import partial
1
7
  import jax
2
8
  from jax import vmap
3
9
  import jax.numpy as jnp
10
+ import equinox as eqx
11
+ from jaxtyping import Int, Bool
12
+
13
+ from jinns.data._Batchs import *
14
+ from jinns.loss._LossODE import LossODE, SystemLossODE
15
+ from jinns.loss._LossPDE import LossPDEStatio, LossPDENonStatio, SystemLossPDE
4
16
  from jinns.data._DataGenerators import (
5
17
  DataGeneratorODE,
6
18
  CubicMeshPDEStatio,
7
19
  CubicMeshPDENonStatio,
8
20
  )
9
- from jinns.loss._LossPDE import LossPDEStatio, LossPDENonStatio, SystemLossPDE
10
- from jinns.loss._LossODE import LossODE, SystemLossODE
11
- from jinns.loss._DynamicLossAbstract import PDEStatio
12
-
13
- from functools import partial
14
21
  from jinns.utils._hyperpinn import HYPERPINN
15
22
  from jinns.utils._spinn import SPINN
16
23
 
17
24
 
18
- def _proceed_to_rar(data, i):
25
+ if TYPE_CHECKING:
26
+ from jinns.utils._types import *
27
+
28
+
29
+ def _proceed_to_rar(data: AnyDataGenerator, i: Int) -> Bool:
19
30
  """Utilility function with various check to ensure we can proceed with the rar_step.
20
31
  Return True if yes, and False otherwise"""
21
32
 
@@ -30,13 +41,13 @@ def _proceed_to_rar(data, i):
30
41
  # Memory allocation checks (depends on the type of DataGenerator)
31
42
  # check if we still have room to append new collocation points in the
32
43
  # allocated jnp.array (can concern `data.p_times` or `p_omega`)
33
- if isinstance(data, DataGeneratorODE) or isinstance(data, CubicMeshPDENonStatio):
44
+ if isinstance(data, (DataGeneratorODE, CubicMeshPDENonStatio)):
34
45
  check_list.append(
35
46
  data.rar_parameters["selected_sample_size_times"]
36
47
  <= jnp.count_nonzero(data.p_times == 0),
37
48
  )
38
49
 
39
- if isinstance(data, CubicMeshPDEStatio) or isinstance(data, CubicMeshPDENonStatio):
50
+ if isinstance(data, (CubicMeshPDEStatio, CubicMeshPDENonStatio)):
40
51
  # for now the above check are redundants but there may be a time when
41
52
  # we drop inheritence
42
53
  check_list.append(
@@ -49,7 +60,14 @@ def _proceed_to_rar(data, i):
49
60
 
50
61
 
51
62
  @partial(jax.jit, static_argnames=["_rar_step_true", "_rar_step_false"])
52
- def trigger_rar(i, loss, params, data, _rar_step_true, _rar_step_false):
63
+ def trigger_rar(
64
+ i: Int,
65
+ loss: AnyLoss,
66
+ params: AnyParams,
67
+ data: AnyDataGenerator,
68
+ _rar_step_true: Callable[[rar_operands], AnyDataGenerator],
69
+ _rar_step_false: Callable[[rar_operands], AnyDataGenerator],
70
+ ) -> tuple[AnyLoss, AnyParams, AnyDataGenerator]:
53
71
 
54
72
  if data.rar_parameters is None:
55
73
  # do nothing.
@@ -65,7 +83,13 @@ def trigger_rar(i, loss, params, data, _rar_step_true, _rar_step_false):
65
83
  return loss, params, data
66
84
 
67
85
 
68
- def init_rar(data):
86
+ def init_rar(
87
+ data: AnyDataGenerator,
88
+ ) -> tuple[
89
+ AnyDataGenerator,
90
+ Callable[[rar_operands], AnyDataGenerator],
91
+ Callable[[rar_operands], AnyDataGenerator],
92
+ ]:
69
93
  """
70
94
  Separated from the main rar, because the initialization to get _true and
71
95
  _false cannot be jit-ted.
@@ -100,13 +124,21 @@ def init_rar(data):
100
124
  data.rar_parameters["sample_size_omega"],
101
125
  data.rar_parameters["selected_sample_size_omega"],
102
126
  )
127
+ else:
128
+ raise ValueError(f"Wrong type for data got {type(data)}")
103
129
 
104
- data.rar_parameters["iter_from_last_sampling"] = 0
130
+ if isinstance(data, eqx.Module):
131
+ data = eqx.tree_at(lambda m: m.rar_iter_from_last_sampling, data, 0)
132
+ else:
133
+ data.rar_iter_from_last_sampling = 0
105
134
 
106
135
  return data, _rar_step_true, _rar_step_false
107
136
 
108
137
 
109
- def _rar_step_init(sample_size, selected_sample_size):
138
+ def _rar_step_init(sample_size: Int, selected_sample_size: Int) -> tuple[
139
+ Callable[[rar_operands], AnyDataGenerator],
140
+ Callable[[rar_operands], AnyDataGenerator],
141
+ ]:
110
142
  """
111
143
  This is a wrapper because the sampling size and
112
144
  selected_sample_size, must be treated as static
@@ -116,11 +148,17 @@ def _rar_step_init(sample_size, selected_sample_size):
116
148
  This is a kind of manual declaration of static argnums
117
149
  """
118
150
 
119
- def rar_step_true(operands):
151
+ def rar_step_true(operands: rar_operands) -> AnyDataGenerator:
120
152
  loss, params, data, i = operands
121
153
 
122
154
  if isinstance(data, DataGeneratorODE):
123
- new_omega_samples = data.sample_in_time_domain(sample_size)
155
+
156
+ if isinstance(data, eqx.Module):
157
+ new_key, subkey = jax.random.split(data.key)
158
+ new_omega_samples = data.sample_in_time_domain(subkey, sample_size)
159
+ data = eqx.tree_at(lambda m: m.key, data, new_key)
160
+ else:
161
+ new_omega_samples = data.sample_in_time_domain(sample_size)
124
162
 
125
163
  # We can have different types of Loss
126
164
  if isinstance(loss, LossODE):
@@ -162,17 +200,29 @@ def _rar_step_init(sample_size, selected_sample_size):
162
200
  ## add the new points in times
163
201
  # start indices of update can be dynamic but the the shape (length)
164
202
  # of the slice
165
- data.times = jax.lax.dynamic_update_slice(
203
+ new_times = jax.lax.dynamic_update_slice(
166
204
  data.times,
167
205
  higher_residual_points,
168
206
  (data.nt_start + data.rar_iter_nb * selected_sample_size,),
169
207
  )
170
208
 
209
+ if isinstance(data, eqx.Module):
210
+ data = eqx.tree_at(lambda m: m.times, data, new_times)
211
+ else:
212
+ data.times = new_times
171
213
  ## rearrange probabilities so that the probabilities of the new
172
214
  ## points are non-zero
173
215
  new_proba = 1 / (data.nt_start + data.rar_iter_nb * selected_sample_size)
174
216
  # the next work because nt_start is static
175
- data.p_times = data.p_times.at[: data.nt_start].set(new_proba)
217
+ new_p_times = data.p_times.at[: data.nt_start].set(new_proba)
218
+ if isinstance(data, eqx.Module):
219
+ data = eqx.tree_at(
220
+ lambda m: m.p_times,
221
+ data,
222
+ new_p_times,
223
+ )
224
+ else:
225
+ data.p_times = new_p_times
176
226
 
177
227
  # the next requires a fori_loop because the range is dynamic
178
228
  def update_slices(i, p):
@@ -182,16 +232,29 @@ def _rar_step_init(sample_size, selected_sample_size):
182
232
  ((data.nt_start + i * selected_sample_size),),
183
233
  )
184
234
 
185
- data.rar_iter_nb += 1
186
-
187
- data.p_times = jax.lax.fori_loop(
235
+ new_rar_iter_nb = data.rar_iter_nb + 1
236
+ new_p_times = jax.lax.fori_loop(
188
237
  0, data.rar_iter_nb, update_slices, data.p_times
189
238
  )
239
+ if isinstance(data, eqx.Module):
240
+ data = eqx.tree_at(
241
+ lambda m: (m.rar_iter_nb, m.p_times),
242
+ data,
243
+ (new_rar_iter_nb, new_p_times),
244
+ )
245
+ else:
246
+ data.rar_iter_nb = new_rar_iter_nb
247
+ data.p_times = new_p_times
190
248
 
191
249
  elif isinstance(data, CubicMeshPDEStatio) and not isinstance(
192
250
  data, CubicMeshPDENonStatio
193
251
  ):
194
- new_omega_samples = data.sample_in_omega_domain(sample_size)
252
+ if isinstance(data, eqx.Module):
253
+ new_key, *subkeys = jax.random.split(data.key, data.dim + 1)
254
+ new_omega_samples = data.sample_in_omega_domain(subkeys, sample_size)
255
+ data = eqx.tree_at(lambda m: m.key, data, new_key)
256
+ else:
257
+ new_omega_samples = data.sample_in_omega_domain(sample_size)
195
258
 
196
259
  # We can have different types of Loss
197
260
  if isinstance(loss, LossPDEStatio):
@@ -209,7 +272,7 @@ def _rar_step_init(sample_size, selected_sample_size):
209
272
  mse_on_s = (jnp.linalg.norm(dyn_on_s, axis=-1) ** 2).flatten()
210
273
  else:
211
274
  mse_on_s = dyn_on_s**2
212
- elif isinstance(loss, SystemLossPDE):
275
+ elif isinstance(loss, SystemLossODE):
213
276
  mse_on_s = 0
214
277
  for i in loss.dynamic_loss_dict.keys():
215
278
  # only the case LossPDEStatio here
@@ -237,17 +300,30 @@ def _rar_step_init(sample_size, selected_sample_size):
237
300
  ## add the new points in omega
238
301
  # start indices of update can be dynamic but not the shape (length)
239
302
  # of the slice
240
- data.omega = jax.lax.dynamic_update_slice(
303
+ new_omega = jax.lax.dynamic_update_slice(
241
304
  data.omega,
242
305
  higher_residual_points,
243
306
  (data.n_start + data.rar_iter_nb * selected_sample_size, data.dim),
244
307
  )
245
308
 
309
+ if isinstance(data, eqx.Module):
310
+ data = eqx.tree_at(lambda m: m.omega, data, new_omega)
311
+ else:
312
+ data.omega = new_omega
313
+
246
314
  ## rearrange probabilities so that the probabilities of the new
247
315
  ## points are non-zero
248
316
  new_proba = 1 / (data.n_start + data.rar_iter_nb * selected_sample_size)
249
317
  # the next work because n_start is static
250
- data.p_omega = data.p_omega.at[: data.n_start].set(new_proba)
318
+ new_p_omega = data.p_omega.at[: data.n_start].set(new_proba)
319
+ if isinstance(data, eqx.Module):
320
+ data = eqx.tree_at(
321
+ lambda m: m.p_omega,
322
+ data,
323
+ new_p_omega,
324
+ )
325
+ else:
326
+ data.p_omega = new_p_omega
251
327
 
252
328
  # the next requires a fori_loop because the range is dynamic
253
329
  def update_slices(i, p):
@@ -257,13 +333,24 @@ def _rar_step_init(sample_size, selected_sample_size):
257
333
  ((data.n_start + i * selected_sample_size),),
258
334
  )
259
335
 
260
- data.rar_iter_nb += 1
261
-
262
- data.p_omega = jax.lax.fori_loop(
336
+ new_rar_iter_nb = data.rar_iter_nb + 1
337
+ new_p_omega = jax.lax.fori_loop(
263
338
  0, data.rar_iter_nb, update_slices, data.p_omega
264
339
  )
340
+ if isinstance(data, eqx.Module):
341
+ data = eqx.tree_at(
342
+ lambda m: (m.rar_iter_nb, m.p_omega),
343
+ data,
344
+ (new_rar_iter_nb, new_p_omega),
345
+ )
346
+ else:
347
+ data.rar_iter_nb = new_rar_iter_nb
348
+ data.p_omega = new_p_omega
265
349
 
266
350
  elif isinstance(data, CubicMeshPDENonStatio):
351
+ if isinstance(loss.u, HYPERPINN) or isinstance(loss.u, SPINN):
352
+ raise NotImplementedError("RAR not implemented for hyperPINN and SPINN")
353
+
267
354
  # NOTE in this case sample_size and selected_sample_size
268
355
  # are tuples (times, omega) => we unpack them for clarity
269
356
  selected_sample_size_times, selected_sample_size_omega = (
@@ -271,17 +358,29 @@ def _rar_step_init(sample_size, selected_sample_size):
271
358
  )
272
359
  sample_size_times, sample_size_omega = sample_size
273
360
 
274
- new_times_samples = data.sample_in_time_domain(sample_size_times)
275
- new_omega_samples = data.sample_in_omega_domain(sample_size_omega)
361
+ if isinstance(data, eqx.Module):
362
+ new_key, subkey = jax.random.split(data.key)
363
+ new_times_samples = data.sample_in_time_domain(
364
+ subkey, sample_size_times
365
+ )
366
+ new_key, *subkeys = jax.random.split(new_key, data.dim + 1)
367
+ new_omega_samples = data.sample_in_omega_domain(
368
+ subkeys, sample_size_omega
369
+ )
370
+ data = eqx.tree_at(lambda m: m.key, data, new_key)
371
+ else:
372
+ new_times_samples = data.sample_in_time_domain(sample_size_times)
373
+ new_omega_samples = data.sample_in_omega_domain(sample_size_omega)
276
374
 
277
- if isinstance(loss.u, HYPERPINN) or isinstance(loss.u, SPINN):
278
- raise NotImplementedError("RAR not implemented for hyperPINN and SPINN")
375
+ if not data.cartesian_product:
376
+ times = new_times_samples
377
+ omega = new_omega_samples
279
378
  else:
280
379
  # do cartesian product on new points
281
- tile_omega = jnp.tile(
380
+ omega = jnp.tile(
282
381
  new_omega_samples, reps=(sample_size_times, 1)
283
382
  ) # it is tiled
284
- repeat_times = jnp.repeat(new_times_samples, sample_size_omega, axis=0)[
383
+ times = jnp.repeat(new_times_samples, sample_size_omega, axis=0)[
285
384
  ..., None
286
385
  ] # it is repeated + add an axis
287
386
 
@@ -291,7 +390,7 @@ def _rar_step_init(sample_size, selected_sample_size):
291
390
  (0, 0),
292
391
  0,
293
392
  )
294
- dyn_on_s = v_dyn_loss(repeat_times, tile_omega).reshape(
393
+ dyn_on_s = v_dyn_loss(times, omega).reshape(
295
394
  (sample_size_times, sample_size_omega)
296
395
  )
297
396
  mse_on_s = dyn_on_s**2
@@ -305,7 +404,7 @@ def _rar_step_init(sample_size, selected_sample_size):
305
404
  (0, 0),
306
405
  0,
307
406
  )
308
- dyn_on_s += v_dyn_loss(repeat_times, tile_omega).reshape(
407
+ dyn_on_s += v_dyn_loss(times, omega).reshape(
309
408
  (sample_size_times, sample_size_omega)
310
409
  )
311
410
 
@@ -338,14 +437,23 @@ def _rar_step_init(sample_size, selected_sample_size):
338
437
  ## add the new points in times
339
438
  # start indices of update can be dynamic but not the shape (length)
340
439
  # of the slice
341
- data.times = jax.lax.dynamic_update_slice(
440
+ new_times = jax.lax.dynamic_update_slice(
342
441
  data.times,
343
442
  higher_residual_points_times,
344
- (data.n_start + data.rar_iter_nb * selected_sample_size_times,),
443
+ (
444
+ data.n_start
445
+ + data.rar_iter_nb # NOTE typo here nt_start ?
446
+ * selected_sample_size_times,
447
+ ),
345
448
  )
346
449
 
450
+ if isinstance(data, eqx.Module):
451
+ data = eqx.tree_at(lambda m: m.times, data, new_times)
452
+ else:
453
+ data.times = new_times
454
+
347
455
  ## add the new points in omega
348
- data.omega = jax.lax.dynamic_update_slice(
456
+ new_omega = jax.lax.dynamic_update_slice(
349
457
  data.omega,
350
458
  higher_residual_points_omega,
351
459
  (
@@ -354,19 +462,38 @@ def _rar_step_init(sample_size, selected_sample_size):
354
462
  ),
355
463
  )
356
464
 
465
+ if isinstance(data, eqx.Module):
466
+ data = eqx.tree_at(lambda m: m.omega, data, new_omega)
467
+ else:
468
+ data.omega = new_omega
469
+
357
470
  ## rearrange probabilities so that the probabilities of the new
358
471
  ## points are non-zero
359
472
  new_p_times = 1 / (
360
473
  data.nt_start + data.rar_iter_nb * selected_sample_size_times
361
474
  )
362
475
  # the next work because nt_start is static
363
- data.p_times = data.p_times.at[: data.nt_start].set(new_p_times)
476
+ if isinstance(data, eqx.Module):
477
+ data = eqx.tree_at(
478
+ lambda m: m.p_times,
479
+ data,
480
+ data.p_times.at[: data.nt_start].set(new_p_times),
481
+ )
482
+ else:
483
+ data.p_times = data.p_times.at[: data.nt_start].set(new_p_times)
364
484
 
365
485
  # same for p_omega (work because n_start is static)
366
486
  new_p_omega = 1 / (
367
487
  data.n_start + data.rar_iter_nb * selected_sample_size_omega
368
488
  )
369
- data.p_omega = data.p_omega.at[: data.n_start].set(new_p_omega)
489
+ if isinstance(data, eqx.Module):
490
+ data = eqx.tree_at(
491
+ lambda m: m.p_omega,
492
+ data,
493
+ data.p_omega.at[: data.n_start].set(new_p_omega),
494
+ )
495
+ else:
496
+ data.p_omega = data.p_omega.at[: data.n_start].set(new_p_omega)
370
497
 
371
498
  # the part of data.p_* after n_start requires a fori_loop because
372
499
  # the range is dynamic
@@ -381,13 +508,13 @@ def _rar_step_init(sample_size, selected_sample_size):
381
508
 
382
509
  return update_slices
383
510
 
384
- data.rar_iter_nb += 1
511
+ new_rar_iter_nb = data.rar_iter_nb + 1
385
512
 
386
513
  ## update rest of p_times
387
514
  update_slices_times = create_update_slices(
388
515
  new_p_times, selected_sample_size_times
389
516
  )
390
- data.p_times = jax.lax.fori_loop(
517
+ new_p_times = jax.lax.fori_loop(
391
518
  0,
392
519
  data.rar_iter_nb,
393
520
  update_slices_times,
@@ -397,21 +524,34 @@ def _rar_step_init(sample_size, selected_sample_size):
397
524
  update_slices_omega = create_update_slices(
398
525
  new_p_omega, selected_sample_size_omega
399
526
  )
400
- data.p_omega = jax.lax.fori_loop(
527
+ new_p_omega = jax.lax.fori_loop(
401
528
  0,
402
529
  data.rar_iter_nb,
403
530
  update_slices_omega,
404
531
  data.p_omega,
405
532
  )
533
+ if isinstance(data, eqx.Module):
534
+ data = eqx.tree_at(
535
+ lambda m: (m.rar_iter_nb, m.p_omega, m.p_times),
536
+ data,
537
+ (new_rar_iter_nb, new_p_omega, new_p_times),
538
+ )
539
+ else:
540
+ data.rar_iter_nb = new_rar_iter_nb
541
+ data.p_times = new_p_times
542
+ data.p_omega = new_p_omega
406
543
 
407
544
  # update RAR parameters for all cases
408
- data.rar_iter_from_last_sampling = 0
545
+ if isinstance(data, eqx.Module):
546
+ data = eqx.tree_at(lambda m: m.rar_iter_from_last_sampling, data, 0)
547
+ else:
548
+ data.rar_iter_from_last_sampling = 0
409
549
 
410
550
  # NOTE must return data to be correctly updated because we cannot
411
551
  # have side effects in this function that will be jitted
412
552
  return data
413
553
 
414
- def rar_step_false(operands):
554
+ def rar_step_false(operands: rar_operands) -> AnyDataGenerator:
415
555
  _, _, data, i = operands
416
556
 
417
557
  # Add 1 only if we are after the burn in period
@@ -421,7 +561,15 @@ def _rar_step_init(sample_size, selected_sample_size):
421
561
  lambda: 1,
422
562
  )
423
563
 
424
- data.rar_iter_from_last_sampling += increment
564
+ new_rar_iter_from_last_sampling = data.rar_iter_from_last_sampling + increment
565
+ if isinstance(data, eqx.Module):
566
+ data = eqx.tree_at(
567
+ lambda m: m.rar_iter_from_last_sampling,
568
+ data,
569
+ new_rar_iter_from_last_sampling,
570
+ )
571
+ else:
572
+ data.rar_iter_from_last_sampling = new_rar_iter_from_last_sampling
425
573
  return data
426
574
 
427
575
  return rar_step_true, rar_step_false