jinns 1.3.0__py3-none-any.whl → 1.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55)
  1. jinns/__init__.py +17 -7
  2. jinns/data/_AbstractDataGenerator.py +19 -0
  3. jinns/data/_Batchs.py +31 -12
  4. jinns/data/_CubicMeshPDENonStatio.py +431 -0
  5. jinns/data/_CubicMeshPDEStatio.py +464 -0
  6. jinns/data/_DataGeneratorODE.py +187 -0
  7. jinns/data/_DataGeneratorObservations.py +189 -0
  8. jinns/data/_DataGeneratorParameter.py +206 -0
  9. jinns/data/__init__.py +19 -9
  10. jinns/data/_utils.py +149 -0
  11. jinns/experimental/__init__.py +9 -0
  12. jinns/loss/_DynamicLoss.py +114 -187
  13. jinns/loss/_DynamicLossAbstract.py +74 -69
  14. jinns/loss/_LossODE.py +132 -348
  15. jinns/loss/_LossPDE.py +262 -549
  16. jinns/loss/__init__.py +32 -6
  17. jinns/loss/_abstract_loss.py +128 -0
  18. jinns/loss/_boundary_conditions.py +20 -19
  19. jinns/loss/_loss_components.py +43 -0
  20. jinns/loss/_loss_utils.py +85 -179
  21. jinns/loss/_loss_weight_updates.py +202 -0
  22. jinns/loss/_loss_weights.py +64 -40
  23. jinns/loss/_operators.py +84 -74
  24. jinns/nn/__init__.py +15 -0
  25. jinns/nn/_abstract_pinn.py +22 -0
  26. jinns/nn/_hyperpinn.py +94 -57
  27. jinns/nn/_mlp.py +50 -25
  28. jinns/nn/_pinn.py +33 -19
  29. jinns/nn/_ppinn.py +70 -34
  30. jinns/nn/_save_load.py +21 -51
  31. jinns/nn/_spinn.py +33 -16
  32. jinns/nn/_spinn_mlp.py +28 -22
  33. jinns/nn/_utils.py +38 -0
  34. jinns/parameters/__init__.py +8 -1
  35. jinns/parameters/_derivative_keys.py +116 -177
  36. jinns/parameters/_params.py +18 -46
  37. jinns/plot/__init__.py +2 -0
  38. jinns/plot/_plot.py +35 -34
  39. jinns/solver/_rar.py +80 -63
  40. jinns/solver/_solve.py +207 -92
  41. jinns/solver/_utils.py +4 -6
  42. jinns/utils/__init__.py +2 -0
  43. jinns/utils/_containers.py +16 -10
  44. jinns/utils/_types.py +20 -54
  45. jinns/utils/_utils.py +4 -11
  46. jinns/validation/__init__.py +2 -0
  47. jinns/validation/_validation.py +20 -19
  48. {jinns-1.3.0.dist-info → jinns-1.5.0.dist-info}/METADATA +8 -4
  49. jinns-1.5.0.dist-info/RECORD +55 -0
  50. {jinns-1.3.0.dist-info → jinns-1.5.0.dist-info}/WHEEL +1 -1
  51. jinns/data/_DataGenerators.py +0 -1634
  52. jinns-1.3.0.dist-info/RECORD +0 -44
  53. {jinns-1.3.0.dist-info → jinns-1.5.0.dist-info/licenses}/AUTHORS +0 -0
  54. {jinns-1.3.0.dist-info → jinns-1.5.0.dist-info/licenses}/LICENSE +0 -0
  55. {jinns-1.3.0.dist-info → jinns-1.5.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,189 @@
1
+ """
2
+ Define the DataGenerators modules
3
+ """
4
+
5
+ from __future__ import (
6
+ annotations,
7
+ ) # https://docs.python.org/3/library/typing.html#constant
8
+ import equinox as eqx
9
+ import jax
10
+ import jax.numpy as jnp
11
+ from jaxtyping import Key, Int, Array, Float
12
+ from jinns.data._Batchs import ObsBatchDict
13
+ from jinns.data._utils import _reset_or_increment
14
+ from jinns.data._AbstractDataGenerator import AbstractDataGenerator
15
+
16
+
17
class DataGeneratorObservations(AbstractDataGenerator):
    r"""
    Despite the class name, it is rather a dataloader for user-provided
    observations which are used in the observations loss.

    Parameters
    ----------
    key : Key
        Jax random key to shuffle batches
    obs_batch_size : int | None
        The size of the batch of randomly selected points among
        the `n` points. If None, no minibatches are used. Must not be larger
        than the number of observations.
    observed_pinn_in : Float[Array, " n_obs nb_pinn_in"]
        Observed values corresponding to the input of the PINN
        (eg. the time at which we recorded the observations). The first
        dimension must correspond to the number of observed_values.
        The second dimension depends on the input dimension of the PINN,
        that is `1` for ODE, `n_dim_x` for stationary PDE and `n_dim_x + 1`
        for non-stationary PDE. A 1D array of shape `(n_obs,)` is reshaped
        to `(n_obs, 1)`.
    observed_values : Float[Array, " n_obs nb_pinn_out"]
        Observed values that the PINN should learn to fit. The first
        dimension must be aligned with observed_pinn_in. A 1D array of shape
        `(n_obs,)` is reshaped to `(n_obs, 1)`.
    observed_eq_params : dict[str, Float[Array, " n_obs 1"]], default={}
        A dict with keys corresponding to
        the parameter name. The keys must match the keys in
        `params["eq_params"]`. The values are jnp.array with 2 dimensions
        with values corresponding to the parameter value for which we also
        have observed_pinn_in and observed_values. Hence the first
        dimension must be aligned with observed_pinn_in and observed_values.
        1D values are reshaped to `(n_obs, 1)` in an internal copy: the dict
        passed by the user is left unmodified.
        Optional argument.
    sharding_device : jax.sharding.Sharding, default=None
        Default None. An optional sharding object to constraint the storage
        of observed inputs, values and parameters. Typically, a
        SingleDeviceSharding(cpu_device) to avoid loading on GPU huge
        datasets of observations. Note that computations for **batches**
        can still be performed on other devices (*e.g.* GPU, TPU or
        any pre-defined Sharding) thanks to the `obs_batch_sharding`
        arguments of `jinns.solve()`. Read `jinns.solve()` doc for more info.

    Raises
    ------
    ValueError
        If the first axes of `observed_pinn_in`, `observed_values` and the
        values of `observed_eq_params` differ, if one of these arrays has
        more than 2 dimensions, or if `obs_batch_size` is larger than the
        number of observations.
    """

    key: Key
    obs_batch_size: int | None = eqx.field(static=True)
    observed_pinn_in: Float[Array, " n_obs nb_pinn_in"]
    observed_values: Float[Array, " n_obs nb_pinn_out"]
    observed_eq_params: dict[str, Float[Array, " n_obs 1"]] = eqx.field(
        static=True, default_factory=lambda: {}
    )
    sharding_device: jax.sharding.Sharding = eqx.field(static=True, default=None)

    # number of observations (first axis of observed_pinn_in)
    n: int = eqx.field(init=False, static=True)
    # start index, in `indices`, of the current minibatch
    curr_idx: int = eqx.field(init=False)
    # shuffled row indices into the observation arrays
    indices: Array = eqx.field(init=False)

    def __post_init__(self):
        n_obs = self.observed_pinn_in.shape[0]
        if self.observed_values.shape[0] != n_obs:
            raise ValueError(
                "self.observed_pinn_in and self.observed_values must have same first axis"
            )
        for v in self.observed_eq_params.values():
            if v.shape[0] != n_obs:
                raise ValueError(
                    "self.observed_pinn_in and the values of"
                    " self.observed_eq_params must have the same first axis"
                )
        if self.observed_pinn_in.ndim == 1:
            self.observed_pinn_in = self.observed_pinn_in[:, None]
        if self.observed_pinn_in.ndim > 2:
            raise ValueError("self.observed_pinn_in must have 2 dimensions")
        if self.observed_values.ndim == 1:
            self.observed_values = self.observed_values[:, None]
        if self.observed_values.ndim > 2:
            raise ValueError("self.observed_values must have 2 dimensions")
        # Build a new dict instead of writing the reshaped arrays back into
        # the user's dict, which would otherwise be modified as a side effect
        reshaped_eq_params = {}
        for k, v in self.observed_eq_params.items():
            if v.ndim > 2:
                raise ValueError(
                    "Each value of observed_eq_params must have 2 dimensions"
                )
            reshaped_eq_params[k] = v[:, None] if v.ndim == 1 else v
        self.observed_eq_params = reshaped_eq_params

        self.n = n_obs

        if self.obs_batch_size is not None and self.obs_batch_size > self.n:
            # jax.lax.dynamic_slice cannot take a slice larger than the
            # array, so fail early with an explicit message
            raise ValueError(
                f"obs_batch_size ({self.obs_batch_size}) cannot be larger than"
                f" the number of observations ({self.n})."
            )

        if self.sharding_device is not None:
            self.observed_pinn_in = jax.lax.with_sharding_constraint(
                self.observed_pinn_in, self.sharding_device
            )
            self.observed_values = jax.lax.with_sharding_constraint(
                self.observed_values, self.sharding_device
            )
            self.observed_eq_params = jax.lax.with_sharding_constraint(
                self.observed_eq_params, self.sharding_device
            )

        if self.obs_batch_size is not None:
            # to be sure there is a shuffling at first get_batch()
            self.curr_idx = self.n + self.obs_batch_size
        else:
            self.curr_idx = 0
        # For speed and to avoid duplicating data what is really
        # shuffled is a vector of indices
        if self.sharding_device is not None:
            self.indices = jax.lax.with_sharding_constraint(
                jnp.arange(self.n), self.sharding_device
            )
        else:
            self.indices = jnp.arange(self.n)

        # recall post_init is the only place with _init_ where we can set
        # self attribute in a in-place way
        self.key, _ = jax.random.split(self.key, 2)  # to make it equivalent to
        # the call to _reset_batch_idx_and_permute in legacy DG

    def _get_operands(self) -> tuple[Key, Int[Array, " n"], int, int | None, None]:
        """
        Pack the state consumed by `_reset_or_increment`: key, shuffled
        indices, current index, batch size and no RAR probability vector.
        """
        return (
            self.key,
            self.indices,
            self.curr_idx,
            self.obs_batch_size,
            None,
        )

    def obs_batch(
        self,
    ) -> tuple[DataGeneratorObservations, ObsBatchDict]:
        """
        Return an updated DataGeneratorObservations instance and an ObsBatchDict
        """
        if self.obs_batch_size is None or self.obs_batch_size == self.n:
            # Avoid unnecessary reshuffling
            return self, {
                "pinn_in": self.observed_pinn_in,
                "val": self.observed_values,
                "eq_params": self.observed_eq_params,
            }

        new_attributes = _reset_or_increment(
            self.curr_idx + self.obs_batch_size,
            self.n,
            self._get_operands(),  # type: ignore
            # ignore since the case self.obs_batch_size is None has been
            # handled above
        )
        new = eqx.tree_at(
            lambda m: (m.key, m.indices, m.curr_idx), self, new_attributes
        )

        # the minibatch is a window of the shuffled index vector
        minib_indices = jax.lax.dynamic_slice(
            new.indices,
            start_indices=(new.curr_idx,),
            slice_sizes=(new.obs_batch_size,),
        )

        obs_batch: ObsBatchDict = {
            "pinn_in": jnp.take(
                new.observed_pinn_in, minib_indices, unique_indices=True, axis=0
            ),
            "val": jnp.take(
                new.observed_values, minib_indices, unique_indices=True, axis=0
            ),
            "eq_params": jax.tree_util.tree_map(
                lambda a: jnp.take(a, minib_indices, unique_indices=True, axis=0),
                new.observed_eq_params,
            ),
        }
        return new, obs_batch

    def get_batch(
        self,
    ) -> tuple[DataGeneratorObservations, ObsBatchDict]:
        """
        Generic method to return a batch
        """
        return self.obs_batch()
@@ -0,0 +1,206 @@
1
+ """
2
+ Define the DataGenerators modules
3
+ """
4
+
5
+ from __future__ import (
6
+ annotations,
7
+ ) # https://docs.python.org/3/library/typing.html#constant
8
+ import equinox as eqx
9
+ import jax
10
+ import jax.numpy as jnp
11
+ from jaxtyping import Key, Array, Float
12
+ from jinns.data._utils import _reset_or_increment
13
+ from jinns.data._AbstractDataGenerator import AbstractDataGenerator
14
+
15
+
16
class DataGeneratorParameter(AbstractDataGenerator):
    r"""
    A data generator for additional unidimensional equation parameter(s).
    Mostly useful for metamodeling where batch of `params.eq_params` are fed
    to the network.

    Parameters
    ----------
    keys : Key | dict[str, Key]
        Jax random key to sample new parameter values and to shuffle batches
        or a dict of Jax random keys with key entries from param_ranges
        and user_data. When a single key is given, it is split into one
        subkey per parameter name, the names being taken in sorted order so
        that a given seed is reproducible across Python sessions.
    n : int
        The number of total points that will be divided in
        batches. Batches are made so that each data point is seen only
        once during 1 epoch.
    param_batch_size : int | None, default=None
        The size of the batch of randomly selected points among
        the `n` points. **Important**: no check is performed but
        `param_batch_size` must be the same as other collocation points
        batch_size (time, space or timexspace depending on the context). This is because we vmap the network on all its axes at once to compute the MSE. Also, `param_batch_size` will be the same for all parameters. If None, no mini-batches are used.
    param_ranges : dict[str, tuple[Float, Float]], default={}
        A dict. A dict of tuples (min, max), which
        represents the range of real numbers where to sample batches (of
        length `param_batch_size` among `n` points).
        The key corresponds to the parameter name. The keys must match the
        keys in `params["eq_params"]`.
        By providing several entries in this dictionary we can sample
        an arbitrary number of parameters.
        **Note** that we currently only support unidimensional parameters.
        This argument can be empty (or None) if we use `user_data`.
    method : str, default="uniform"
        Either `grid` or `uniform`, default is `uniform`. `grid` means
        `n` regularly spaced points over `[min, max)`. `uniform` means
        uniformly sampled points over the domain
    user_data : dict[str, Float[Array, " n"]] | None, default={}
        A dictionary containing user-provided data for parameters.
        The keys corresponds to the parameter name,
        and must match the keys in `params["eq_params"]`. Only
        unidimensional `jnp.array` are supported. Therefore, the array at
        `user_data[k]` must have shape `(n, 1)` or `(n,)`.
        Note that if the same key appears in `param_ranges` and `user_data`
        priority goes for the content in `user_data`.

    Raises
    ------
    ValueError
        If `param_batch_size` is larger than `n`, if a `user_data` entry does
        not have shape `(n, 1)` or `(n,)`, or if `method` is unknown.
    """

    keys: Key | dict[str, Key]
    n: int = eqx.field(static=True)
    param_batch_size: int | None = eqx.field(static=True, default=None)
    param_ranges: dict[str, tuple[Float, Float]] = eqx.field(
        static=True, default_factory=lambda: {}
    )
    method: str = eqx.field(static=True, default="uniform")
    user_data: dict[str, Float[Array, " n"]] | None = eqx.field(
        default_factory=lambda: {}
    )

    # per-parameter start index of the current minibatch (None without batching)
    curr_param_idx: dict[str, int] = eqx.field(init=False)
    # per-parameter samples of shape (n, 1), merged from user_data/param_ranges
    param_n_samples: dict[str, Array] = eqx.field(init=False)

    def __post_init__(self):
        if self.user_data is None:
            self.user_data = {}
        if self.param_ranges is None:
            self.param_ranges = {}
        if self.param_batch_size is not None and self.n < self.param_batch_size:
            raise ValueError(
                f"Number of data points ({self.n}) is smaller than the"
                f" number of batch points ({self.param_batch_size})."
            )
        if not isinstance(self.keys, dict):
            # Sort the names: iterating over a set of strings depends on
            # PYTHONHASHSEED, which would make the parameter <-> subkey
            # assignment (hence the sampled values) change between runs.
            all_keys = sorted(set().union(self.param_ranges, self.user_data))
            self.keys = dict(
                zip(all_keys, jax.random.split(self.keys, len(all_keys)))
            )

        if self.param_batch_size is None:
            self.curr_param_idx = None  # type: ignore
        else:
            self.curr_param_idx = {}
            for k in self.keys.keys():
                self.curr_param_idx[k] = self.n + self.param_batch_size
                # to be sure there is a shuffling at first get_batch()

        # The call to self.generate_data() creates
        # the dict self.param_n_samples and then we will only use this one
        # because it merges the scattered data between `user_data` and
        # `param_ranges`
        self.keys, self.param_n_samples = self.generate_data(self.keys)

    def generate_data(
        self, keys: dict[str, Key]
    ) -> tuple[dict[str, Key], dict[str, Float[Array, " n"]]]:
        """
        Generate parameter samples, either through generation
        or using user-provided data.

        Returns a new dict of (possibly updated) keys, leaving the `keys`
        argument unmodified, and the dict of samples of shape `(n, 1)`.
        """
        keys = dict(keys)  # do not mutate the caller's dict
        param_n_samples = {}

        all_keys = sorted(
            set().union(
                self.param_ranges,
                self.user_data,  # type: ignore this has been handled in post_init
            )
        )
        for k in all_keys:
            if self.user_data and k in self.user_data.keys():
                user_k = self.user_data[k]
                if user_k.shape == (self.n, 1):
                    param_n_samples[k] = user_k
                elif user_k.shape == (self.n,):
                    param_n_samples[k] = user_k[:, None]
                else:
                    raise ValueError(
                        "Wrong shape for user provided parameters"
                        f" in user_data dictionary at key='{k}'"
                    )
            else:
                if self.method == "grid":
                    xmin, xmax = self.param_ranges[k][0], self.param_ranges[k][1]
                    # linspace with endpoint=False gives exactly n points
                    # xmin + i * (xmax - xmin) / n, whereas arange with a
                    # float step may return n + 1 points due to rounding
                    # shape (n, 1)
                    param_n_samples[k] = jnp.linspace(
                        xmin, xmax, self.n, endpoint=False
                    )[:, None]
                elif self.method == "uniform":
                    xmin, xmax = self.param_ranges[k][0], self.param_ranges[k][1]
                    keys[k], subkey = jax.random.split(keys[k], 2)
                    param_n_samples[k] = jax.random.uniform(
                        subkey, shape=(self.n, 1), minval=xmin, maxval=xmax
                    )
                else:
                    raise ValueError("Method " + self.method + " is not implemented.")

        return keys, param_n_samples

    def _get_param_operands(
        self, k: str
    ) -> tuple[Key, Float[Array, " n"], int, int | None, None]:
        """
        Pack the state of parameter `k` in the operand layout expected by
        `_reset_or_increment`.
        """
        return (
            self.keys[k],
            self.param_n_samples[k],
            self.curr_param_idx[k],
            self.param_batch_size,
            None,
        )

    def param_batch(self):
        """
        Return an updated DataGeneratorParameter instance and a dictionary
        with batches of parameters.
        If all the batches have been seen, we reshuffle them,
        otherwise we just return the next unseen batch.
        """

        if self.param_batch_size is None or self.param_batch_size == self.n:
            return self, self.param_n_samples

        def _reset_or_increment_wrapper(param_k, idx_k, key_k):
            return _reset_or_increment(
                idx_k + self.param_batch_size,
                self.n,
                (key_k, param_k, idx_k, self.param_batch_size, None),  # type: ignore
                # ignore since the case self.param_batch_size is None has been
                # handled above
            )

        # one (key, samples, index) triplet per parameter name
        res = jax.tree_util.tree_map(
            _reset_or_increment_wrapper,
            self.param_n_samples,
            self.curr_param_idx,
            self.keys,
        )
        # we must transpose the pytrees because keys are merged in res
        # https://jax.readthedocs.io/en/latest/jax-101/05.1-pytrees.html#transposing-trees
        new_attributes = jax.tree_util.tree_transpose(
            jax.tree_util.tree_structure(self.keys),
            jax.tree_util.tree_structure([0, 0, 0]),
            res,
        )

        new = eqx.tree_at(
            lambda m: (m.keys, m.param_n_samples, m.curr_param_idx),
            self,
            new_attributes,
        )

        return new, jax.tree_util.tree_map(
            lambda p, q: jax.lax.dynamic_slice(
                p, start_indices=(q, 0), slice_sizes=(new.param_batch_size, 1)
            ),
            new.param_n_samples,
            new.curr_param_idx,
        )

    def get_batch(self):
        """
        Generic method to return a batch
        """
        return self.param_batch()
jinns/data/__init__.py CHANGED
@@ -1,11 +1,21 @@
1
- from ._DataGenerators import (
2
- DataGeneratorODE,
3
- CubicMeshPDEStatio,
4
- CubicMeshPDENonStatio,
5
- DataGeneratorObservations,
6
- DataGeneratorParameter,
7
- DataGeneratorObservationsMultiPINNs,
8
- )
1
+ from ._DataGeneratorODE import DataGeneratorODE
2
+ from ._CubicMeshPDEStatio import CubicMeshPDEStatio
3
+ from ._CubicMeshPDENonStatio import CubicMeshPDENonStatio
4
+ from ._DataGeneratorObservations import DataGeneratorObservations
5
+ from ._DataGeneratorParameter import DataGeneratorParameter
9
6
  from ._Batchs import ODEBatch, PDEStatioBatch, PDENonStatioBatch
10
7
 
11
- from ._DataGenerators import append_obs_batch, append_param_batch
8
+ from ._utils import append_obs_batch, append_param_batch
9
+
10
# Public API of `jinns.data`: the data generators, the batch containers and
# the helpers that attach extra batches to an existing batch.
__all__ = [
    # data generators
    "DataGeneratorODE",
    "CubicMeshPDEStatio",
    "CubicMeshPDENonStatio",
    "DataGeneratorParameter",
    "DataGeneratorObservations",
    # batch containers
    "ODEBatch",
    "PDEStatioBatch",
    "PDENonStatioBatch",
    # batch helpers
    "append_obs_batch",
    "append_param_batch",
]
jinns/data/_utils.py ADDED
@@ -0,0 +1,149 @@
1
+ """
2
+ Utility functions for DataGenerators
3
+ """
4
+
5
+ from __future__ import annotations
6
+
7
+ from typing import TYPE_CHECKING
8
+ import equinox as eqx
9
+ import jax
10
+ import jax.numpy as jnp
11
+ from jaxtyping import Key, Array, Float
12
+
13
+ if TYPE_CHECKING:
14
+ from jinns.utils._types import AnyBatch
15
+ from jinns.data._Batchs import ObsBatchDict
16
+
17
+
18
def append_param_batch(batch: AnyBatch, param_batch_dict: dict[str, Array]) -> AnyBatch:
    """
    Return `batch` with its `param_batch_dict` field set to `param_batch_dict`.

    The replacement goes through `eqx.tree_at`; `None` is declared a leaf so
    that the field can be located even when it currently holds `None`
    (presumably its value before any parameter batch is appended).
    """

    def _target(b):
        return b.param_batch_dict

    def _none_as_leaf(node):
        return node is None

    return eqx.tree_at(_target, batch, param_batch_dict, is_leaf=_none_as_leaf)
28
+
29
+
30
def append_obs_batch(batch: AnyBatch, obs_batch_dict: ObsBatchDict) -> AnyBatch:
    """
    Return `batch` with its `obs_batch_dict` field set to `obs_batch_dict`.

    The replacement goes through `eqx.tree_at`; `None` is declared a leaf so
    that the field can be located even when it currently holds `None`
    (presumably its value before any observation batch is appended).
    """

    def _target(b):
        return b.obs_batch_dict

    def _none_as_leaf(node):
        return node is None

    return eqx.tree_at(_target, batch, obs_batch_dict, is_leaf=_none_as_leaf)
37
+
38
+
39
def make_cartesian_product(
    b1: Float[Array, " batch_size dim1"], b2: Float[Array, " batch_size dim2"]
) -> Float[Array, " rows=batch_size*batch_size (dim1+dim2)"]:
    """
    Create the cartesian product of two batches (e.g. a time batch and a
    border omega batch) by repeating and tiling.

    Row `i * len(b2) + j` of the result is the concatenation of `b1[i]` and
    `b2[j]`.
    """
    # NOTE: the `rows=` prefix in the return annotation keeps jaxtyping from
    # runtime-checking the `*` expression, which it cannot interpret here.
    rows_1, rows_2 = b1.shape[0], b2.shape[0]
    # each row of b1 is repeated rows_2 times consecutively...
    left = jnp.repeat(b1, rows_2, axis=0)
    # ...while b2 as a whole is stacked rows_1 times along the first axis
    right = jnp.tile(b2, reps=(rows_1,) + (1,) * (b2.ndim - 1))
    return jnp.concatenate((left, right), axis=1)
53
+
54
+
55
+ def _reset_batch_idx_and_permute(
56
+ operands: tuple[Key, Float[Array, " n dimension"], int, None, Float[Array, " n"]],
57
+ ) -> tuple[Key, Float[Array, " n dimension"], int]:
58
+ key, domain, curr_idx, _, p = operands
59
+ # resetting counter
60
+ curr_idx = 0
61
+ # reshuffling
62
+ key, subkey = jax.random.split(key)
63
+ if p is None:
64
+ domain = jax.random.permutation(subkey, domain, axis=0, independent=False)
65
+ else:
66
+ # otherwise p is used to avoid collocation points not in n_start
67
+ # NOTE that replace=True to avoid undefined behaviour but then, the
68
+ # domain.shape[0] does not really grow as in the original RAR. instead,
69
+ # it always comprises the same number of points, but the points are
70
+ # updated
71
+ domain = jax.random.choice(
72
+ subkey, domain, shape=(domain.shape[0],), replace=True, p=p
73
+ )
74
+
75
+ # return updated
76
+ return (key, domain, curr_idx)
77
+
78
+
79
+ def _increment_batch_idx(
80
+ operands: tuple[Key, Float[Array, " n dimension"], int, int, Float[Array, " n"]],
81
+ ) -> tuple[Key, Float[Array, " n dimension"], int]:
82
+ key, domain, curr_idx, batch_size, _ = operands
83
+ # simply increases counter and get the batch
84
+ curr_idx += batch_size
85
+ return (key, domain, curr_idx)
86
+
87
+
88
def _reset_or_increment(
    bend: int,
    n_eff: int,
    operands: tuple[Key, Float[Array, " n dimension"], int, int, Float[Array, " n"]],
) -> tuple[Key, Float[Array, " n dimension"], int]:
    """
    Factorize the code of the jax.lax.cond which checks if we have seen all the
    batches in an epoch.

    `bend` is the index at which the next batch would start (current index
    plus batch size). If `bend >= n_eff` (n_eff is n when no RAR sampling),
    no unseen point is left to start a new batch from: we reshuffle and start
    from 0 again. Otherwise there are still unseen samples and the index is
    moved forward to `bend`.

    The comparison must include equality. With `bend > n_eff`, the case
    `bend == n_eff` would move the index to `n_eff`; callers then take the
    batch with `jax.lax.dynamic_slice`, which clamps an out-of-range start
    index back into range, so the previous batch would be returned a second
    time in the same epoch.

    Parameters
    ----------
    bend
        An integer. The new hypothetical index for the starting of the batch
    n_eff
        An integer. The number of points to see to complete an epoch
    operands
        A tuple. As passed to _reset_batch_idx_and_permute and
        _increment_batch_idx

    Returns
    -------
    res
        A tuple as returned by _reset_batch_idx_and_permute or
        _increment_batch_idx
    """
    return jax.lax.cond(
        bend >= n_eff, _reset_batch_idx_and_permute, _increment_batch_idx, operands
    )
119
+
120
+
121
+ def _check_and_set_rar_parameters(
122
+ rar_parameters: dict, n: int, n_start: int
123
+ ) -> tuple[int, Float[Array, " n"] | None, int | None, int | None]:
124
+ if rar_parameters is not None and n_start is None:
125
+ raise ValueError(
126
+ "n_start must be provided in the context of RAR sampling scheme"
127
+ )
128
+
129
+ if rar_parameters is not None:
130
+ # Default p is None. However, in the RAR sampling scheme we use 0
131
+ # probability to specify non-used collocation points (i.e. points
132
+ # above n_start). Thus, p is a vector of probability of shape (nt, 1).
133
+ p = jnp.zeros((n,))
134
+ p = p.at[:n_start].set(1 / n_start)
135
+ # set internal counter for the number of gradient steps since the
136
+ # last new collocation points have been added
137
+ # It is not 0 to ensure the first iteration of RAR happens just
138
+ # after start_iter. See the _proceed_to_rar() function in _rar.py
139
+ rar_iter_from_last_sampling = rar_parameters["update_every"] - 1
140
+ # set iternal counter for the number of times collocation points
141
+ # have been added
142
+ rar_iter_nb = 0
143
+ else:
144
+ n_start = n
145
+ p = None
146
+ rar_iter_from_last_sampling = None
147
+ rar_iter_nb = None
148
+
149
+ return n_start, p, rar_iter_from_last_sampling, rar_iter_nb
@@ -6,3 +6,12 @@ from ._diffrax_solver import (
6
6
  neumann_boundary_condition,
7
7
  plot_diffrax_solution,
8
8
  )
9
+
10
+ __all__ = [
11
+ "SpatialDiscretisation",
12
+ "reaction_diffusion_2d_vector_field",
13
+ "laplacian",
14
+ "dirichlet_boundary_condition",
15
+ "neumann_boundary_condition",
16
+ "plot_diffrax_solution",
17
+ ]