jinns 0.8.10__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. jinns/__init__.py +2 -0
  2. jinns/data/_Batchs.py +27 -0
  3. jinns/data/_DataGenerators.py +953 -1182
  4. jinns/data/__init__.py +4 -8
  5. jinns/experimental/__init__.py +0 -2
  6. jinns/experimental/_diffrax_solver.py +5 -5
  7. jinns/loss/_DynamicLoss.py +282 -305
  8. jinns/loss/_DynamicLossAbstract.py +321 -168
  9. jinns/loss/_LossODE.py +290 -307
  10. jinns/loss/_LossPDE.py +628 -1040
  11. jinns/loss/__init__.py +21 -5
  12. jinns/loss/_boundary_conditions.py +95 -96
  13. jinns/loss/{_Losses.py → _loss_utils.py} +104 -46
  14. jinns/loss/_loss_weights.py +59 -0
  15. jinns/loss/_operators.py +78 -72
  16. jinns/parameters/__init__.py +6 -0
  17. jinns/parameters/_derivative_keys.py +94 -0
  18. jinns/parameters/_params.py +115 -0
  19. jinns/plot/__init__.py +5 -0
  20. jinns/{data/_display.py → plot/_plot.py} +98 -75
  21. jinns/solver/_rar.py +193 -45
  22. jinns/solver/_solve.py +199 -144
  23. jinns/utils/__init__.py +3 -9
  24. jinns/utils/_containers.py +37 -43
  25. jinns/utils/_hyperpinn.py +226 -127
  26. jinns/utils/_pinn.py +183 -111
  27. jinns/utils/_save_load.py +121 -56
  28. jinns/utils/_spinn.py +117 -84
  29. jinns/utils/_types.py +64 -0
  30. jinns/utils/_utils.py +6 -160
  31. jinns/validation/_validation.py +52 -144
  32. {jinns-0.8.10.dist-info → jinns-1.0.0.dist-info}/METADATA +5 -4
  33. jinns-1.0.0.dist-info/RECORD +38 -0
  34. {jinns-0.8.10.dist-info → jinns-1.0.0.dist-info}/WHEEL +1 -1
  35. jinns/experimental/_sinuspinn.py +0 -135
  36. jinns/experimental/_spectralpinn.py +0 -87
  37. jinns/solver/_seq2seq.py +0 -157
  38. jinns/utils/_optim.py +0 -147
  39. jinns/utils/_utils_uspinn.py +0 -727
  40. jinns-0.8.10.dist-info/RECORD +0 -36
  41. {jinns-0.8.10.dist-info → jinns-1.0.0.dist-info}/LICENSE +0 -0
  42. {jinns-0.8.10.dist-info → jinns-1.0.0.dist-info}/top_level.txt +0 -0
jinns/__init__.py CHANGED
@@ -3,4 +3,6 @@ import jinns.loss
3
3
  import jinns.solver
4
4
  import jinns.utils
5
5
  import jinns.experimental
6
+ import jinns.parameters
7
+ import jinns.plot
6
8
  from jinns.solver._solve import solve
jinns/data/_Batchs.py ADDED
@@ -0,0 +1,27 @@
1
+ import equinox as eqx
2
+ from jaxtyping import Float, Array
3
+
4
+
5
class ODEBatch(eqx.Module):
    """Container for one batch of data for an ODE problem.

    Attributes
    ----------
    temporal_batch
        Array of shape ``(batch_size,)`` per its jaxtyping annotation;
        presumably the sampled time points (TODO confirm against the
        ODE data generator).
    param_batch_dict
        Optional dict, ``None`` by default. Presumably holds batched
        parameter values; its schema is defined by the caller.
    obs_batch_dict
        Optional dict, ``None`` by default. Presumably holds batched
        observations; its schema is defined by the caller.
    """

    temporal_batch: Float[Array, "batch_size"]
    # Annotated ``dict | None`` because the default is ``None``; a bare
    # ``dict`` annotation contradicts that default for static and runtime
    # (jaxtyping/beartype) type checkers.
    param_batch_dict: dict | None = eqx.field(default=None)
    obs_batch_dict: dict | None = eqx.field(default=None)
9
+
10
+
11
class PDENonStatioBatch(eqx.Module):
    """Container for one batch of data for a non-stationary PDE problem.

    Attributes
    ----------
    times_x_inside_batch
        Interior points, shaped ``(batch_size, dimension)`` or
        ``(batch_size**2, dimension)`` per the annotation. The name
        suggests each row combines a time and a space coordinate —
        TODO confirm the column layout against the data generator.
    times_x_border_batch
        Border points, shaped ``(border_batch_size, dimension, n_facets)``
        or ``(border_batch_size**2, dimension, n_facets)`` per the
        annotation.
    param_batch_dict
        Optional dict, ``None`` by default (presumably batched parameter
        values; schema defined by the caller).
    obs_batch_dict
        Optional dict, ``None`` by default (presumably batched
        observations; schema defined by the caller).
    """

    times_x_inside_batch: (
        Float[Array, "batch_size dimension"] | Float[Array, "(batch_size**2) dimension"]
    )
    times_x_border_batch: (
        Float[Array, "border_batch_size dimension n_facets"]
        | Float[Array, "(border_batch_size**2) dimension n_facets"]
    )
    # ``dict | None`` matches the ``None`` default (a bare ``dict`` does not).
    param_batch_dict: dict | None = eqx.field(default=None)
    obs_batch_dict: dict | None = eqx.field(default=None)
21
+
22
+
23
class PDEStatioBatch(eqx.Module):
    """Container for one batch of data for a stationary PDE problem.

    Attributes
    ----------
    inside_batch
        Interior points, shaped ``(batch_size, dimension)`` per the
        annotation.
    border_batch
        Border points, shaped ``(border_batch_size, dimension, n_facets)``
        per the annotation.
    param_batch_dict
        Optional dict, ``None`` by default (presumably batched parameter
        values; schema defined by the caller).
    obs_batch_dict
        Optional dict, ``None`` by default (presumably batched
        observations; schema defined by the caller).
    """

    inside_batch: Float[Array, "batch_size dimension"]
    # The leading dim gets its own symbol: reusing ``batch_size`` here would
    # force the border batch to be exactly as large as the interior batch
    # under a jaxtyped check. The name matches ``PDENonStatioBatch``.
    border_batch: Float[Array, "border_batch_size dimension n_facets"]
    # ``dict | None`` matches the ``None`` default (a bare ``dict`` does not).
    param_batch_dict: dict | None = eqx.field(default=None)
    obs_batch_dict: dict | None = eqx.field(default=None)