jinns 1.6.1__py3-none-any.whl → 1.7.1__py3-none-any.whl

This diff shows the changes between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.
jinns/__init__.py CHANGED
@@ -7,8 +7,9 @@ from jinns import parameters as parameters
7
7
  from jinns import plot as plot
8
8
  from jinns import nn as nn
9
9
  from jinns.solver._solve import solve
10
+ from jinns.solver._solve_alternate import solve_alternate
10
11
 
11
- __all__ = ["nn", "solve"]
12
+ __all__ = ["nn", "solve", "solve_alternate"]
12
13
 
13
14
  import warnings
14
15
 
jinns/data/_Batchs.py CHANGED
@@ -26,14 +26,14 @@ class ObsBatchDict(TypedDict):
26
26
  class ODEBatch(eqx.Module):
27
27
  temporal_batch: Float[Array, " batch_size"]
28
28
  param_batch_dict: eqx.Module | None = eqx.field(default=None)
29
- obs_batch_dict: ObsBatchDict | None = eqx.field(default=None)
29
+ obs_batch_dict: tuple[ObsBatchDict, ...] | None = eqx.field(default=None)
30
30
 
31
31
 
32
32
  class PDEStatioBatch(eqx.Module):
33
33
  domain_batch: Float[Array, " batch_size dimension"]
34
34
  border_batch: Float[Array, " batch_size dimension n_facets"] | None
35
35
  param_batch_dict: eqx.Module | None
36
- obs_batch_dict: ObsBatchDict | None
36
+ obs_batch_dict: tuple[ObsBatchDict, ...] | None
37
37
 
38
38
  # rewrite __init__ to be able to use inheritance for the NonStatio case
39
39
  # below. That way PDENonStatioBatch is a subtype of PDEStatioBatch which
@@ -44,7 +44,7 @@ class PDEStatioBatch(eqx.Module):
44
44
  domain_batch: Float[Array, " batch_size dimension"],
45
45
  border_batch: Float[Array, " batch_size dimension n_facets"] | None,
46
46
  param_batch_dict: eqx.Module | None = None,
47
- obs_batch_dict: ObsBatchDict | None = None,
47
+ obs_batch_dict: tuple[ObsBatchDict, ...] | None = None,
48
48
  ):
49
49
  # TODO: document this ?
50
50
  self.domain_batch = domain_batch
@@ -67,7 +67,7 @@ class PDENonStatioBatch(PDEStatioBatch):
67
67
  border_batch: Float[Array, " batch_size dimension n_facets"] | None,
68
68
  initial_batch: Float[Array, " batch_size dimension"] | None,
69
69
  param_batch_dict: eqx.Module | None = None,
70
- obs_batch_dict: ObsBatchDict | None = None,
70
+ obs_batch_dict: tuple[ObsBatchDict, ...] | None = None,
71
71
  ):
72
72
  self.domain_batch = domain_batch
73
73
  self.border_batch = border_batch
@@ -77,7 +77,7 @@ class DataGeneratorODE(AbstractDataGenerator):
77
77
  nt: int,
78
78
  tmin: float,
79
79
  tmax: float,
80
- temporal_batch_size: int | None,
80
+ temporal_batch_size: int | None = None,
81
81
  method: str = "uniform",
82
82
  rar_parameters: None | dict[str, int] = None,
83
83
  n_start: None | int = None,