jinns-0.9.0-py3-none-any.whl → jinns-1.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. jinns/__init__.py +2 -0
  2. jinns/data/_Batchs.py +27 -0
  3. jinns/data/_DataGenerators.py +904 -1203
  4. jinns/data/__init__.py +4 -8
  5. jinns/experimental/__init__.py +0 -2
  6. jinns/experimental/_diffrax_solver.py +5 -5
  7. jinns/loss/_DynamicLoss.py +282 -305
  8. jinns/loss/_DynamicLossAbstract.py +322 -167
  9. jinns/loss/_LossODE.py +324 -322
  10. jinns/loss/_LossPDE.py +652 -1027
  11. jinns/loss/__init__.py +21 -5
  12. jinns/loss/_boundary_conditions.py +87 -41
  13. jinns/loss/{_Losses.py → _loss_utils.py} +101 -45
  14. jinns/loss/_loss_weights.py +59 -0
  15. jinns/loss/_operators.py +78 -72
  16. jinns/parameters/__init__.py +6 -0
  17. jinns/parameters/_derivative_keys.py +521 -0
  18. jinns/parameters/_params.py +115 -0
  19. jinns/plot/__init__.py +5 -0
  20. jinns/{data/_display.py → plot/_plot.py} +98 -75
  21. jinns/solver/_rar.py +183 -39
  22. jinns/solver/_solve.py +151 -124
  23. jinns/utils/__init__.py +3 -9
  24. jinns/utils/_containers.py +37 -44
  25. jinns/utils/_hyperpinn.py +224 -119
  26. jinns/utils/_pinn.py +183 -111
  27. jinns/utils/_save_load.py +121 -56
  28. jinns/utils/_spinn.py +113 -86
  29. jinns/utils/_types.py +64 -0
  30. jinns/utils/_utils.py +6 -160
  31. jinns/validation/_validation.py +48 -140
  32. jinns-1.1.0.dist-info/AUTHORS +2 -0
  33. {jinns-0.9.0.dist-info → jinns-1.1.0.dist-info}/METADATA +5 -4
  34. jinns-1.1.0.dist-info/RECORD +39 -0
  35. {jinns-0.9.0.dist-info → jinns-1.1.0.dist-info}/WHEEL +1 -1
  36. jinns/experimental/_sinuspinn.py +0 -135
  37. jinns/experimental/_spectralpinn.py +0 -87
  38. jinns/solver/_seq2seq.py +0 -157
  39. jinns/utils/_optim.py +0 -147
  40. jinns/utils/_utils_uspinn.py +0 -727
  41. jinns-0.9.0.dist-info/RECORD +0 -36
  42. {jinns-0.9.0.dist-info → jinns-1.1.0.dist-info}/LICENSE +0 -0
  43. {jinns-0.9.0.dist-info → jinns-1.1.0.dist-info}/top_level.txt +0 -0
jinns/__init__.py CHANGED
@@ -3,4 +3,6 @@ import jinns.loss
3
3
  import jinns.solver
4
4
  import jinns.utils
5
5
  import jinns.experimental
6
+ import jinns.parameters
7
+ import jinns.plot
6
8
  from jinns.solver._solve import solve
jinns/data/_Batchs.py ADDED
@@ -0,0 +1,27 @@
1
+ import equinox as eqx
2
+ from jaxtyping import Float, Array
3
+
4
+
5
+ class ODEBatch(eqx.Module):
6
+ temporal_batch: Float[Array, "batch_size"]
7
+ param_batch_dict: dict = eqx.field(default=None)
8
+ obs_batch_dict: dict = eqx.field(default=None)
9
+
10
+
11
+ class PDENonStatioBatch(eqx.Module):
12
+ times_x_inside_batch: (
13
+ Float[Array, "batch_size dimension"] | Float[Array, "(batch_size**2) dimension"]
14
+ )
15
+ times_x_border_batch: (
16
+ Float[Array, "border_batch_size dimension n_facets"]
17
+ | Float[Array, "(border_batch_size**2) dimension n_facets"]
18
+ )
19
+ param_batch_dict: dict = eqx.field(default=None)
20
+ obs_batch_dict: dict = eqx.field(default=None)
21
+
22
+
23
+ class PDEStatioBatch(eqx.Module):
24
+ inside_batch: Float[Array, "batch_size dimension"]
25
+ border_batch: Float[Array, "batch_size dimension n_facets"]
26
+ param_batch_dict: dict = eqx.field(default=None)
27
+ obs_batch_dict: dict = eqx.field(default=None)