jinns 1.5.1__py3-none-any.whl → 1.6.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jinns/data/_AbstractDataGenerator.py +1 -1
- jinns/data/_Batchs.py +47 -13
- jinns/data/_CubicMeshPDENonStatio.py +55 -34
- jinns/data/_CubicMeshPDEStatio.py +63 -35
- jinns/data/_DataGeneratorODE.py +48 -22
- jinns/data/_DataGeneratorObservations.py +86 -32
- jinns/data/_DataGeneratorParameter.py +152 -101
- jinns/data/__init__.py +2 -1
- jinns/data/_utils.py +22 -10
- jinns/loss/_DynamicLoss.py +21 -20
- jinns/loss/_DynamicLossAbstract.py +51 -36
- jinns/loss/_LossODE.py +139 -184
- jinns/loss/_LossPDE.py +440 -358
- jinns/loss/_abstract_loss.py +60 -25
- jinns/loss/_loss_components.py +4 -25
- jinns/loss/_loss_weight_updates.py +6 -7
- jinns/loss/_loss_weights.py +34 -35
- jinns/nn/_abstract_pinn.py +0 -2
- jinns/nn/_hyperpinn.py +34 -23
- jinns/nn/_mlp.py +5 -4
- jinns/nn/_pinn.py +1 -16
- jinns/nn/_ppinn.py +5 -16
- jinns/nn/_save_load.py +11 -4
- jinns/nn/_spinn.py +1 -16
- jinns/nn/_spinn_mlp.py +5 -5
- jinns/nn/_utils.py +33 -38
- jinns/parameters/__init__.py +3 -1
- jinns/parameters/_derivative_keys.py +99 -41
- jinns/parameters/_params.py +50 -25
- jinns/solver/_solve.py +3 -3
- jinns/utils/_DictToModuleMeta.py +66 -0
- jinns/utils/_ItemizableModule.py +19 -0
- jinns/utils/__init__.py +2 -1
- jinns/utils/_types.py +25 -15
- {jinns-1.5.1.dist-info → jinns-1.6.1.dist-info}/METADATA +2 -2
- jinns-1.6.1.dist-info/RECORD +57 -0
- jinns-1.5.1.dist-info/RECORD +0 -55
- {jinns-1.5.1.dist-info → jinns-1.6.1.dist-info}/WHEEL +0 -0
- {jinns-1.5.1.dist-info → jinns-1.6.1.dist-info}/licenses/AUTHORS +0 -0
- {jinns-1.5.1.dist-info → jinns-1.6.1.dist-info}/licenses/LICENSE +0 -0
- {jinns-1.5.1.dist-info → jinns-1.6.1.dist-info}/top_level.txt +0 -0
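Among the new files, jinns/utils/_DictToModuleMeta.py (+66 lines) backs the most visible behavioural change in the hunks below: dicts of equation parameters are now converted into eqx.Module pytrees through a DictToModuleMeta metaclass. That module itself is not shown in this diff, so the following is only a sketch inferred from the call sites below (class DGParams(metaclass=DictToModuleMeta) and DGParams(_param_n_samples, "DGParams")); every implementation detail here is an assumption, not jinns code:

from typing import Any

import equinox as eqx


class DictToModuleMeta(type):
    # HYPOTHETICAL sketch only -- the real jinns/utils/_DictToModuleMeta.py
    # is not part of the hunks shown in this diff.
    def __call__(cls, d: dict[str, Any], name: str | None = None):
        # Build an eqx.Module subclass on the fly, one field per dict key,
        # using eqx.Module's own metaclass so the dataclass machinery runs.
        module_cls = type(eqx.Module)(
            name or cls.__name__,
            (eqx.Module,),
            {"__annotations__": {k: Any for k in d}},
        )
        # Return an instance carrying the dict values as pytree leaves.
        return module_cls(**d)


class DGParams(metaclass=DictToModuleMeta):
    pass


p = DGParams({"theta": 1.0, "nu": 2.0}, "DGParams")
print(p.theta)  # 1.0 -- an eqx.Module instance, traversable by JAX

Whatever the real implementation looks like, the contract visible in the diffs is just this: calling such a class with a dict yields an eqx.Module pytree, which is what lets the reworked param_batch below apply jax.tree.map over parameters.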
jinns/data/_DataGeneratorODE.py
CHANGED

@@ -9,7 +9,7 @@ from typing import TYPE_CHECKING
 import equinox as eqx
 import jax
 import jax.numpy as jnp
-from jaxtyping import
+from jaxtyping import PRNGKeyArray, Array, Float
 from jinns.data._Batchs import ODEBatch
 from jinns.data._utils import _check_and_set_rar_parameters, _reset_or_increment
 from jinns.data._AbstractDataGenerator import AbstractDataGenerator
@@ -24,7 +24,7 @@ class DataGeneratorODE(AbstractDataGenerator):
 
     Parameters
     ----------
-    key :
+    key : PRNGKeyArray
         Jax random key to sample new time points and to shuffle batches
     nt : int
         The number of total time points that will be divided in
@@ -42,10 +42,10 @@ class DataGeneratorODE(AbstractDataGenerator):
         The method that generates the `nt` time points. `grid` means
         regularly spaced points over the domain. `uniform` means uniformly
         sampled points over the domain
-    rar_parameters : RarParameterDict, default=None
+    rar_parameters : None | RarParameterDict, default=None
        A TypedDict to specify the Residual Adaptative Resampling procedure. See
        the docstring from RarParameterDict
-    n_start : int, default=None
+    n_start : None | int, default=None
        Defaults to None. The effective size of nt used at start time.
        This value must be
        provided when rar_parameters is not None. Otherwise we set internally
@@ -54,25 +54,43 @@ class DataGeneratorODE(AbstractDataGenerator):
     then corresponds to the initial number of points we train the PINN.
     """
 
-    key:
-    nt: int = eqx.field(
-    tmin:
-    tmax:
-    temporal_batch_size: int | None = eqx.field(static=True
-    method: str = eqx.field(
-
-
-
-
-
-    # all the init=False fields are set in __post_init__
+    key: PRNGKeyArray
+    nt: int = eqx.field(static=True)
+    tmin: float
+    tmax: float
+    temporal_batch_size: int | None = eqx.field(static=True)
+    method: str = eqx.field(static=True)
+    rar_parameters: None | dict[str, int]
+    n_start: None | int
+
+    # --- Below fields are not passed as arguments to __init__
     p: Float[Array, " nt 1"] | None = eqx.field(init=False)
     rar_iter_from_last_sampling: int | None = eqx.field(init=False)
     rar_iter_nb: int | None = eqx.field(init=False)
     curr_time_idx: int = eqx.field(init=False)
     times: Float[Array, " nt 1"] = eqx.field(init=False)
 
-    def
+    def __init__(
+        self,
+        *,
+        key: PRNGKeyArray,
+        nt: int,
+        tmin: float,
+        tmax: float,
+        temporal_batch_size: int | None,
+        method: str = "uniform",
+        rar_parameters: None | dict[str, int] = None,
+        n_start: None | int = None,
+    ):
+        self.key = key
+        self.nt = nt
+        self.tmin = tmin
+        self.tmax = tmax
+        self.temporal_batch_size = temporal_batch_size
+        self.method = method
+        self.n_start = n_start
+        self.rar_parameters = rar_parameters
+
         (
             self.n_start,
             self.p,
@@ -97,7 +115,7 @@ class DataGeneratorODE(AbstractDataGenerator):
         # above way for the key.
 
     def sample_in_time_domain(
-        self, key:
+        self, key: PRNGKeyArray, sample_size: int | None = None
     ) -> Float[Array, " nt 1"]:
         return jax.random.uniform(
             key,
@@ -106,7 +124,9 @@ class DataGeneratorODE(AbstractDataGenerator):
             maxval=self.tmax,
         )
 
-    def generate_time_data(
+    def generate_time_data(
+        self, key: PRNGKeyArray
+    ) -> tuple[PRNGKeyArray, Float[Array, " nt"]]:
         """
         Construct a complete set of `self.nt` time points according to the
         specified `self.method`
@@ -125,7 +145,11 @@ class DataGeneratorODE(AbstractDataGenerator):
     def _get_time_operands(
         self,
     ) -> tuple[
-
+        PRNGKeyArray,
+        Float[Array, " nt 1"],
+        int,
+        int | None,
+        Float[Array, " nt 1"] | None,
     ]:
         return (
             self.key,
@@ -150,7 +174,7 @@ class DataGeneratorODE(AbstractDataGenerator):
             bend = bstart + self.temporal_batch_size
 
         # Compute the effective number of used collocation points
-        if self.rar_parameters is not None:
+        if self.rar_parameters is not None and self.n_start is not None:
             nt_eff = (
                 self.n_start
                 + self.rar_iter_nb  # type: ignore
@@ -167,7 +191,9 @@ class DataGeneratorODE(AbstractDataGenerator):
             # handled above
         )
         new = eqx.tree_at(
-            lambda m: (m.key, m.times, m.curr_time_idx),
+            lambda m: (m.key, m.times, m.curr_time_idx),  # type: ignore
+            self,
+            new_attributes,
         )
 
         # commands below are equivalent to
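Taken together, the changes to this file replace the old dataclass-style fields and __post_init__ with an explicit, keyword-only __init__. A minimal construction sketch against the new signature shown above; the argument values are illustrative and the jinns.data re-export is assumed:

import jax

from jinns.data import DataGeneratorODE

key = jax.random.PRNGKey(0)

# All arguments are keyword-only (note the bare `*` in __init__), so
# positional calls written against the old field order will now fail.
dg = DataGeneratorODE(
    key=key,
    nt=1000,
    tmin=0.0,
    tmax=1.0,
    temporal_batch_size=64,
    method="uniform",  # default; "grid" gives regularly spaced points instead
)
dg, batch = dg.get_batch()  # returns (updated generator, ODEBatch)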
jinns/data/_DataGeneratorObservations.py
CHANGED

@@ -8,10 +8,32 @@ from __future__ import (
 import equinox as eqx
 import jax
 import jax.numpy as jnp
-from
+from typing import TYPE_CHECKING, Self
+from jaxtyping import PRNGKeyArray, Int, Array, Float
 from jinns.data._Batchs import ObsBatchDict
 from jinns.data._utils import _reset_or_increment
 from jinns.data._AbstractDataGenerator import AbstractDataGenerator
+from jinns.utils._DictToModuleMeta import DictToModuleMeta
+
+if TYPE_CHECKING:
+    # imports only used in type hints
+    InputEqParams = (
+        dict[str, Float[Array, " n_obs"]] | dict[str, Float[Array, " n_obs 1"]]
+    ) | None
+
+# Note that the lambda functions used below are with type: ignore just
+# because the lambda are not type annotated, but there is no proper way
+# to do this and we should assign the lambda to a type hinted variable
+# before hand: this is not practical, let us not get mad at this
+
+
+class DGObservedParams(metaclass=DictToModuleMeta):
+    """
+    However, static type checkers cannot know that DGObservedParams inherit from
+    eqx.Module and explicit casting to the latter class will be needed
+    """
+
+    pass
 
 
 class DataGeneratorObservations(AbstractDataGenerator):
@@ -21,7 +43,7 @@ class DataGeneratorObservations(AbstractDataGenerator):
 
     Parameters
     ----------
-    key :
+    key : PRNGKeyArray
         Jax random key to shuffle batches
     obs_batch_size : int | None
         The size of the batch of randomly selected points among
@@ -39,7 +61,8 @@ class DataGeneratorObservations(AbstractDataGenerator):
     observed_eq_params : dict[str, Float[Array, " n_obs 1"]], default={}
         A dict with keys corresponding to
         the parameter name. The keys must match the keys in
-        `params["eq_params"]
+        `params["eq_params"]`, ie., if only some parameters are observed, other
+        keys **must still appear with None as value**. The values are jnp.array with 2 dimensions
         with values corresponding to the parameter value for which we also
         have observed_pinn_in and observed_values. Hence the first
         dimension must be aligned with observed_pinn_in and observed_values.
@@ -54,30 +77,37 @@ class DataGeneratorObservations(AbstractDataGenerator):
     arguments of `jinns.solve()`. Read `jinns.solve()` doc for more info.
     """
 
-    key:
+    key: PRNGKeyArray
     obs_batch_size: int | None = eqx.field(static=True)
     observed_pinn_in: Float[Array, " n_obs nb_pinn_in"]
     observed_values: Float[Array, " n_obs nb_pinn_out"]
-    observed_eq_params:
-
-    )
-    sharding_device: jax.sharding.Sharding = eqx.field(static=True, default=None)
+    observed_eq_params: eqx.Module | None
+    sharding_device: jax.sharding.Sharding | None  # = eqx.field(static=True)
 
     n: int = eqx.field(init=False, static=True)
     curr_idx: int = eqx.field(init=False)
     indices: Array = eqx.field(init=False)
 
-    def
+    def __init__(
+        self,
+        *,
+        key: PRNGKeyArray,
+        obs_batch_size: int | None = None,
+        observed_pinn_in: Float[Array, " n_obs nb_pinn_in"],
+        observed_values: Float[Array, " n_obs nb_pinn_out"],
+        observed_eq_params: InputEqParams | None = None,
+        sharding_device: jax.sharding.Sharding | None = None,
+    ) -> None:
+        super().__init__()
+        self.key = key
+        self.obs_batch_size = obs_batch_size
+        self.observed_pinn_in = observed_pinn_in
+        self.observed_values = observed_values
+
         if self.observed_pinn_in.shape[0] != self.observed_values.shape[0]:
             raise ValueError(
                 "self.observed_pinn_in and self.observed_values must have same first axis"
             )
-        for _, v in self.observed_eq_params.items():
-            if v.shape[0] != self.observed_pinn_in.shape[0]:
-                raise ValueError(
-                    "self.observed_pinn_in and the values of"
-                    " self.observed_eq_params must have the same first axis"
-                )
         if len(self.observed_pinn_in.shape) == 1:
             self.observed_pinn_in = self.observed_pinn_in[:, None]
         if self.observed_pinn_in.ndim > 2:
@@ -86,16 +116,34 @@ class DataGeneratorObservations(AbstractDataGenerator):
             self.observed_values = self.observed_values[:, None]
         if self.observed_values.ndim > 2:
             raise ValueError("self.observed_values must have 2 dimensions")
-
-
-
-
-
-
-
+
+        if observed_eq_params is not None:
+            for _, v in observed_eq_params.items():
+                if v.shape[0] != self.observed_pinn_in.shape[0]:
+                    raise ValueError(
+                        "self.observed_pinn_in and the values of"
+                        " self.observed_eq_params must have the same first axis"
+                    )
+            for k, v in observed_eq_params.items():
+                if len(v.shape) == 1:
+                    # Reshape to add an axis for 1-d Array
+                    observed_eq_params[k] = v[:, None]
+                if len(v.shape) > 2:
+                    raise ValueError(
+                        f"Each key of observed_eq_params must have 2"
+                        f"dimensions, key {k} had shape {v.shape}."
+                    )
+            # Convert the dict of observed parameters to the internal `EqParams`
+            # class used by Jinns.
+            self.observed_eq_params = DGObservedParams(
+                observed_eq_params, "DGObservedParams"
+            )
+        else:
+            self.observed_eq_params = observed_eq_params
 
         self.n = self.observed_pinn_in.shape[0]
 
+        self.sharding_device = sharding_device
         if self.sharding_device is not None:
             self.observed_pinn_in = jax.lax.with_sharding_constraint(
                 self.observed_pinn_in, self.sharding_device
@@ -126,7 +174,9 @@ class DataGeneratorObservations(AbstractDataGenerator):
         self.key, _ = jax.random.split(self.key, 2)  # to make it equivalent to
         # the call to _reset_batch_idx_and_permute in legacy DG
 
-    def _get_operands(
+    def _get_operands(
+        self,
+    ) -> tuple[PRNGKeyArray, Int[Array, " n"], int, int | None, None]:
         return (
             self.key,
             self.indices,
@@ -137,17 +187,19 @@ class DataGeneratorObservations(AbstractDataGenerator):
 
     def obs_batch(
         self,
-    ) -> tuple[
+    ) -> tuple[Self, ObsBatchDict]:
         """
         Return an update DataGeneratorObservations instance and an ObsBatchDict
         """
         if self.obs_batch_size is None or self.obs_batch_size == self.n:
             # Avoid unnecessary reshuffling
-            return self,
-
-
-
-
+            return self, ObsBatchDict(
+                {
+                    "pinn_in": self.observed_pinn_in,
+                    "val": self.observed_values,
+                    "eq_params": self.observed_eq_params,
+                }
+            )
 
         new_attributes = _reset_or_increment(
             self.curr_idx + self.obs_batch_size,
@@ -157,7 +209,9 @@ class DataGeneratorObservations(AbstractDataGenerator):
             # handled above
         )
         new = eqx.tree_at(
-            lambda m: (m.key, m.indices, m.curr_idx),
+            lambda m: (m.key, m.indices, m.curr_idx),  # type: ignore
+            self,
+            new_attributes,
         )
 
         minib_indices = jax.lax.dynamic_slice(
@@ -174,7 +228,7 @@ class DataGeneratorObservations(AbstractDataGenerator):
                 new.observed_values, minib_indices, unique_indices=True, axis=0
             ),
             "eq_params": jax.tree_util.tree_map(
-                lambda a: jnp.take(a, minib_indices, unique_indices=True, axis=0),
+                lambda a: jnp.take(a, minib_indices, unique_indices=True, axis=0),  # type: ignore
                 new.observed_eq_params,
            ),
        }
@@ -182,7 +236,7 @@ class DataGeneratorObservations(AbstractDataGenerator):
 
     def get_batch(
         self,
-    ) -> tuple[
+    ) -> tuple[Self, ObsBatchDict]:
         """
         Generic method to return a batch
         """
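The changes to this file follow the same pattern: a keyword-only __init__, and observed_eq_params accepted as a plain dict (or None, the new default) that gets validated, reshaped to two dimensions, and wrapped in a DGObservedParams module. A usage sketch under the same assumptions as the previous example (illustrative values, assumed jinns.data re-export):

import jax
import jax.numpy as jnp

from jinns.data import DataGeneratorObservations

key = jax.random.PRNGKey(1)
n_obs = 128

dg_obs = DataGeneratorObservations(
    key=key,
    obs_batch_size=32,
    observed_pinn_in=jnp.linspace(0.0, 1.0, n_obs)[:, None],
    observed_values=jnp.zeros((n_obs, 1)),
    # 1-D values are reshaped to (n_obs, 1) internally; the dict is then
    # wrapped in a DGObservedParams eqx.Module instead of kept as a dict.
    observed_eq_params={"nu": jnp.full((n_obs,), 0.5)},
)
dg_obs, obs_batch = dg_obs.get_batch()  # (updated generator, ObsBatchDict)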
jinns/data/_DataGeneratorParameter.py
CHANGED

@@ -5,12 +5,23 @@ Define the DataGenerators modules
 from __future__ import (
     annotations,
 )  # https://docs.python.org/3/library/typing.html#constant
+from typing import Self
 import equinox as eqx
 import jax
 import jax.numpy as jnp
-from jaxtyping import
+from jaxtyping import PRNGKeyArray, Array, Float
 from jinns.data._utils import _reset_or_increment
 from jinns.data._AbstractDataGenerator import AbstractDataGenerator
+from jinns.utils._DictToModuleMeta import DictToModuleMeta
+
+
+class DGParams(metaclass=DictToModuleMeta):
+    """
+    However, static type checkers cannot know that DGParams inherit from
+    eqx.Module and explicit casting to the latter class will be needed
+    """
+
+    pass
 
 
 class DataGeneratorParameter(AbstractDataGenerator):
@@ -21,9 +32,8 @@ class DataGeneratorParameter(AbstractDataGenerator):
 
     Parameters
     ----------
-
-        Jax random key to sample new time points and to shuffle batches
-        or a dict of Jax random keys with key entries from param_ranges
+    key : PRNGKeyArray
+        Jax random key to sample new time points and to shuffle batches.
     n : int
         The number of total points that will be divided in
         batches. Batches are made so that each data point is seen only
@@ -58,71 +68,86 @@ class DataGeneratorParameter(AbstractDataGenerator):
         Defaults to None.
     """
 
-
+    key: PRNGKeyArray
     n: int = eqx.field(static=True)
-    param_batch_size: int | None = eqx.field(static=True
-    param_ranges: dict[str, tuple[Float, Float]] = eqx.field(
-
-
-
-
-
-    )
+    param_batch_size: int | None = eqx.field(static=True)
+    param_ranges: dict[str, tuple[Float, Float]] = eqx.field(static=True)
+    method: str = eqx.field(static=True)
+    user_data: dict[str, Float[Array, " n"]]
+
+    # --- Below fields are not passed as arguments to __init__
+    _all_params_keys: set[str] = eqx.field(init=False, static=True)
+    curr_param_idx: eqx.Module | None = eqx.field(init=False)
+    param_n_samples: eqx.Module = eqx.field(init=False)
 
-
-
+    def __init__(
+        self,
+        *,
+        key: PRNGKeyArray,
+        n: int,
+        param_batch_size: int | None,
+        param_ranges: dict[str, tuple[Float, Float]] = {},
+        method: str = "uniform",
+        user_data: dict[str, Float[Array, " n"]] = {},
+    ):
+        self.key = key
+        self.n = n
+        self.param_batch_size = param_batch_size
+        self.param_ranges = param_ranges
+        self.method = method
+        self.user_data = user_data
+
+        _all_keys = set().union(self.param_ranges, self.user_data)
+        self._all_params_keys = _all_keys
 
-    def __post_init__(self):
-        if self.user_data is None:
-            self.user_data = {}
-        if self.param_ranges is None:
-            self.param_ranges = {}
         if self.param_batch_size is not None and self.n < self.param_batch_size:
             raise ValueError(
                 f"Number of data points ({self.n}) is smaller than the"
                 f"number of batch points ({self.param_batch_size})."
             )
-
-
-
+
+        # NOTE from jinns > v1.5.1 we work with eqx.Module
+        # because eq_params is not a dict anymore.
+        # We have to use a different class from the publicly exposed EqParams
+        # because fields(EqParams) are not necessarily all present in the
+        # datagenerator, which would cause eqx.Module to error.
+
+        # 1) Call self.generate_data() to generate a dictionnary that merges the scattered data between `user_data` and `param_ranges`
+        self.key, _param_n_samples = self.generate_data(self.key)
+
+        # 2) Use the dictionnary to populate the field of the eqx.Module.
+        self.param_n_samples = DGParams(_param_n_samples, "DGParams")
 
         if self.param_batch_size is None:
-            self.curr_param_idx = None
+            self.curr_param_idx = None
         else:
-            self.
-            for k in self.
-                self.curr_param_idx[k] = self.n + self.param_batch_size
-            # to be sure there is a shuffling at first get_batch()
+            curr_idx = self.n + self.param_batch_size
+            param_keys_and_curr_idx = {k: curr_idx for k in self._all_params_keys}
 
-
-        # the dict self.param_n_samples and then we will only use this one
-        # because it merges the scattered data between `user_data` and
-        # `param_ranges`
-        self.keys, self.param_n_samples = self.generate_data(self.keys)
+            self.curr_param_idx = DGParams(param_keys_and_curr_idx)
 
     def generate_data(
-        self,
-    ) -> tuple[
+        self, key: PRNGKeyArray
+    ) -> tuple[PRNGKeyArray, dict[str, Float[Array, " n 1"]]]:
         """
         Generate parameter samples, either through generation
         or using user-provided data.
         """
         param_n_samples = {}
 
-
-
-
-        )
-        for k in all_keys:
+        # Some of the subkeys might not be used cause of user-provided data.
+        # This is not a big deal and simpler like that.
+        key, *subkeys = jax.random.split(key, len(self._all_params_keys) + 1)
+        for i, k in enumerate(self._all_params_keys):
             if self.user_data and k in self.user_data.keys():
-
-                param_n_samples[k] = self.user_data[k]
-
-
-
-                raise ValueError(
+                try:
+                    param_n_samples[k] = self.user_data[k].reshape((self.n, 1))
+                except TypeError:
+                    shape = self.user_data[k].shape
+                    raise TypeError(
                         "Wrong shape for user provided parameters"
-                        f" in user_data dictionary at key='{k}'"
+                        f" in user_data dictionary at key='{k}' got {shape} "
+                        f"and expected {(self.n, 1)}."
                     )
             else:
                 if self.method == "grid":
@@ -132,75 +157,101 @@ class DataGeneratorParameter(AbstractDataGenerator):
                     param_n_samples[k] = jnp.arange(xmin, xmax, partial)[:, None]
                 elif self.method == "uniform":
                     xmin, xmax = self.param_ranges[k][0], self.param_ranges[k][1]
-                    keys[k], subkey = jax.random.split(keys[k], 2)
                     param_n_samples[k] = jax.random.uniform(
-
+                        subkeys[i], shape=(self.n, 1), minval=xmin, maxval=xmax
                     )
                 else:
                     raise ValueError("Method " + self.method + " is not implemented.")
 
-        return
-
-    def _get_param_operands(
-        self, k: str
-    ) -> tuple[Key, Float[Array, " n"], int, int | None, None]:
-        return (
-            self.keys[k],
-            self.param_n_samples[k],
-            self.curr_param_idx[k],
-            self.param_batch_size,
-            None,
-        )
+        return key, param_n_samples
 
-    def param_batch(self):
+    def param_batch(self) -> tuple[Self, eqx.Module]:
         """
-        Return
-        If all the batches have been seen, we reshuffle them
-        otherwise we just return the next unseen batch.
+        Return an `eqx.Module` with batches of parameters at its leafs.
+        If all the batches have been seen, we reshuffle them (or rather
+        their indices), otherwise we just return the next unseen batch.
         """
 
         if self.param_batch_size is None or self.param_batch_size == self.n:
+            # Full batch mode: nothing to do.
             return self, self.param_n_samples
+        else:
+
+            def _reset_or_increment_wrapper(
+                param_k: Array, idx_k: int, key_k: PRNGKeyArray
+            ):
+                everything_but_key = _reset_or_increment(
+                    idx_k + self.param_batch_size,  # type: ignore
+                    self.n,
+                    (key_k, param_k, idx_k, self.param_batch_size, None),  # type: ignore
+                )[1:]
+                return everything_but_key
 
-
-
-            idx_k + self.param_batch_size,
-            self.n,
-            (key_k, param_k, idx_k, self.param_batch_size, None),  # type: ignore
-            # ignore since the case self.param_batch_size is None has been
-            # handled above
+            new_key, *subkeys = jax.random.split(
+                self.key, len(self._all_params_keys) + 1
            )
 
-
-
-
-
-
-
-
-
-
-
-
-            res
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+            # From PRNGKeyArray to a pytree of keys with adequate structure
+            subkeys = jax.tree.unflatten(
+                jax.tree.structure(self.param_n_samples), subkeys
+            )
+
+            res = jax.tree.map(
+                _reset_or_increment_wrapper,
+                self.param_n_samples,
+                self.curr_param_idx,
+                subkeys,
+            )
+            # we must transpose the pytrees because both params and curr_idx are merged in res
+            # https://jax.readthedocs.io/en/latest/jax-101/05.1-pytrees.html#transposing-trees
+
+            new_attributes = jax.tree.transpose(
+                jax.tree.structure(self.param_n_samples),
+                jax.tree.structure([0, 0]),
+                res,
+            )
+
+            new = eqx.tree_at(
+                lambda m: (m.param_n_samples, m.curr_param_idx),
+                self,
+                new_attributes,
+            )
+
+            new = eqx.tree_at(lambda m: m.key, new, new_key)
+
+            return new, jax.tree_util.tree_map(
+                lambda p, q: jax.lax.dynamic_slice(
+                    p, start_indices=(q, 0), slice_sizes=(new.param_batch_size, 1)
+                ),
+                new.param_n_samples,
+                new.curr_param_idx,
+            )
+
+    def get_batch(self) -> tuple[Self, eqx.Module]:
         """
         Generic method to return a batch
         """
         return self.param_batch()
+
+
+if __name__ == "__main__":
+    key = jax.random.PRNGKey(2)
+    key, subkey = jax.random.split(key)
+
+    n = 64
+    param_batch_size = 32
+    method = "uniform"
+    param_ranges = {"theta": (10.0, 11.0)}
+    user_data = {"nu": jnp.ones((n, 1))}
+
+    x = DataGeneratorParameter(
+        key=subkey,
+        n=n,
+        param_batch_size=param_batch_size,
+        param_ranges=param_ranges,
+        method=method,
+        user_data=user_data,
+    )
+    print(x.key)
+    x, batch = x.get_batch()
+    print(x.key)