jinns-1.5.0-py3-none-any.whl → jinns-1.6.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jinns/__init__.py +7 -7
- jinns/data/_AbstractDataGenerator.py +1 -1
- jinns/data/_Batchs.py +47 -13
- jinns/data/_CubicMeshPDENonStatio.py +203 -54
- jinns/data/_CubicMeshPDEStatio.py +190 -54
- jinns/data/_DataGeneratorODE.py +48 -22
- jinns/data/_DataGeneratorObservations.py +75 -32
- jinns/data/_DataGeneratorParameter.py +152 -101
- jinns/data/__init__.py +2 -1
- jinns/data/_utils.py +22 -10
- jinns/loss/_DynamicLoss.py +21 -20
- jinns/loss/_DynamicLossAbstract.py +51 -36
- jinns/loss/_LossODE.py +210 -191
- jinns/loss/_LossPDE.py +441 -368
- jinns/loss/_abstract_loss.py +60 -25
- jinns/loss/_loss_components.py +4 -25
- jinns/loss/_loss_utils.py +23 -0
- jinns/loss/_loss_weight_updates.py +6 -7
- jinns/loss/_loss_weights.py +34 -35
- jinns/nn/_abstract_pinn.py +0 -2
- jinns/nn/_hyperpinn.py +34 -23
- jinns/nn/_mlp.py +5 -4
- jinns/nn/_pinn.py +1 -16
- jinns/nn/_ppinn.py +5 -16
- jinns/nn/_save_load.py +11 -4
- jinns/nn/_spinn.py +1 -16
- jinns/nn/_spinn_mlp.py +5 -5
- jinns/nn/_utils.py +33 -38
- jinns/parameters/__init__.py +3 -1
- jinns/parameters/_derivative_keys.py +99 -41
- jinns/parameters/_params.py +58 -25
- jinns/solver/_solve.py +14 -8
- jinns/utils/_DictToModuleMeta.py +66 -0
- jinns/utils/_ItemizableModule.py +19 -0
- jinns/utils/__init__.py +2 -1
- jinns/utils/_types.py +25 -15
- {jinns-1.5.0.dist-info → jinns-1.6.0.dist-info}/METADATA +2 -2
- jinns-1.6.0.dist-info/RECORD +57 -0
- jinns-1.5.0.dist-info/RECORD +0 -55
- {jinns-1.5.0.dist-info → jinns-1.6.0.dist-info}/WHEEL +0 -0
- {jinns-1.5.0.dist-info → jinns-1.6.0.dist-info}/licenses/AUTHORS +0 -0
- {jinns-1.5.0.dist-info → jinns-1.6.0.dist-info}/licenses/LICENSE +0 -0
- {jinns-1.5.0.dist-info → jinns-1.6.0.dist-info}/top_level.txt +0 -0

jinns/data/_DataGeneratorObservations.py
CHANGED

@@ -8,10 +8,23 @@ from __future__ import (
 import equinox as eqx
 import jax
 import jax.numpy as jnp
-from
+from typing import TYPE_CHECKING, Self
+from jaxtyping import PRNGKeyArray, Int, Array, Float
 from jinns.data._Batchs import ObsBatchDict
 from jinns.data._utils import _reset_or_increment
 from jinns.data._AbstractDataGenerator import AbstractDataGenerator
+from jinns.parameters._params import EqParams
+
+if TYPE_CHECKING:
+    # imports only used in type hints
+    InputEqParams = (
+        dict[str, Float[Array, " n_obs"]] | dict[str, Float[Array, " n_obs 1"]]
+    ) | None
+
+# Note that the lambda functions used below are with type: ignore just
+# because the lambda are not type annotated, but there is no proper way
+# to do this and we should assign the lambda to a type hinted variable
+# before hand: this is not practical, let us not get mad at this
 
 
 class DataGeneratorObservations(AbstractDataGenerator):
@@ -21,7 +34,7 @@ class DataGeneratorObservations(AbstractDataGenerator):
 
     Parameters
     ----------
-    key :
+    key : PRNGKeyArray
         Jax random key to shuffle batches
     obs_batch_size : int | None
         The size of the batch of randomly selected points among
@@ -39,7 +52,8 @@
     observed_eq_params : dict[str, Float[Array, " n_obs 1"]], default={}
         A dict with keys corresponding to
        the parameter name. The keys must match the keys in
-        `params["eq_params"]
+        `params["eq_params"]`, ie., if only some parameters are observed, other
+        keys **must still appear with None as value**. The values are jnp.array with 2 dimensions
         with values corresponding to the parameter value for which we also
         have observed_pinn_in and observed_values. Hence the first
         dimension must be aligned with observed_pinn_in and observed_values.
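
A minimal sketch of the shape contract described in this docstring, with hypothetical parameter names ("nu" observed, "theta" present but unobserved, as the note about None values requires):

    import jax.numpy as jnp

    n_obs = 100
    # Only "nu" is observed; "theta" must still appear as a key, with None
    # as its value, to mirror the keys of params["eq_params"].
    observed_eq_params = {
        "nu": jnp.ones((n_obs, 1)),  # first axis aligned with observed_pinn_in
        "theta": None,
    }
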
@@ -54,30 +68,37 @@
     arguments of `jinns.solve()`. Read `jinns.solve()` doc for more info.
     """
 
-    key:
+    key: PRNGKeyArray
     obs_batch_size: int | None = eqx.field(static=True)
     observed_pinn_in: Float[Array, " n_obs nb_pinn_in"]
     observed_values: Float[Array, " n_obs nb_pinn_out"]
-    observed_eq_params:
-
-    )
-    sharding_device: jax.sharding.Sharding = eqx.field(static=True, default=None)
+    observed_eq_params: eqx.Module | None
+    sharding_device: jax.sharding.Sharding | None  # = eqx.field(static=True)
 
     n: int = eqx.field(init=False, static=True)
     curr_idx: int = eqx.field(init=False)
     indices: Array = eqx.field(init=False)
 
-    def
+    def __init__(
+        self,
+        *,
+        key: PRNGKeyArray,
+        obs_batch_size: int | None = None,
+        observed_pinn_in: Float[Array, " n_obs nb_pinn_in"],
+        observed_values: Float[Array, " n_obs nb_pinn_out"],
+        observed_eq_params: InputEqParams | None = None,
+        sharding_device: jax.sharding.Sharding | None = None,
+    ) -> None:
+        super().__init__()
+        self.key = key
+        self.obs_batch_size = obs_batch_size
+        self.observed_pinn_in = observed_pinn_in
+        self.observed_values = observed_values
+
         if self.observed_pinn_in.shape[0] != self.observed_values.shape[0]:
             raise ValueError(
                 "self.observed_pinn_in and self.observed_values must have same first axis"
             )
-        for _, v in self.observed_eq_params.items():
-            if v.shape[0] != self.observed_pinn_in.shape[0]:
-                raise ValueError(
-                    "self.observed_pinn_in and the values of"
-                    " self.observed_eq_params must have the same first axis"
-                )
         if len(self.observed_pinn_in.shape) == 1:
             self.observed_pinn_in = self.observed_pinn_in[:, None]
         if self.observed_pinn_in.ndim > 2:
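
A hedged usage sketch of the keyword-only constructor introduced above (array sizes and the `jinns.data` import path are illustrative; `observed_eq_params` and `sharding_device` are left at their None defaults):

    import jax
    import jax.numpy as jnp
    from jinns.data import DataGeneratorObservations

    key = jax.random.PRNGKey(0)
    obs_dg = DataGeneratorObservations(
        key=key,
        obs_batch_size=16,
        observed_pinn_in=jnp.linspace(0.0, 1.0, 128)[:, None],
        observed_values=jnp.ones((128, 1)),
    )
    obs_dg, obs_batch = obs_dg.get_batch()  # (updated generator, ObsBatchDict)
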
@@ -86,16 +107,32 @@
             self.observed_values = self.observed_values[:, None]
         if self.observed_values.ndim > 2:
             raise ValueError("self.observed_values must have 2 dimensions")
-
-
-
-
-
-
-
+
+        if observed_eq_params is not None:
+            for _, v in observed_eq_params.items():
+                if v.shape[0] != self.observed_pinn_in.shape[0]:
+                    raise ValueError(
+                        "self.observed_pinn_in and the values of"
+                        " self.observed_eq_params must have the same first axis"
+                    )
+            for k, v in observed_eq_params.items():
+                if len(v.shape) == 1:
+                    # Reshape to add an axis for 1-d Array
+                    observed_eq_params[k] = v[:, None]
+                if len(v.shape) > 2:
+                    raise ValueError(
+                        f"Each key of observed_eq_params must have 2"
+                        f"dimensions, key {k} had shape {v.shape}."
+                    )
+            # Convert the dict of observed parameters to the internal `EqParams`
+            # class used by Jinns.
+            self.observed_eq_params = EqParams(observed_eq_params, "EqParams")
+        else:
+            self.observed_eq_params = observed_eq_params
 
         self.n = self.observed_pinn_in.shape[0]
 
+        self.sharding_device = sharding_device
         if self.sharding_device is not None:
             self.observed_pinn_in = jax.lax.with_sharding_constraint(
                 self.observed_pinn_in, self.sharding_device
@@ -126,7 +163,9 @@
         self.key, _ = jax.random.split(self.key, 2)  # to make it equivalent to
         # the call to _reset_batch_idx_and_permute in legacy DG
 
-    def _get_operands(
+    def _get_operands(
+        self,
+    ) -> tuple[PRNGKeyArray, Int[Array, " n"], int, int | None, None]:
         return (
             self.key,
             self.indices,
@@ -137,17 +176,19 @@
 
     def obs_batch(
         self,
-    ) -> tuple[
+    ) -> tuple[Self, ObsBatchDict]:
         """
         Return an update DataGeneratorObservations instance and an ObsBatchDict
         """
         if self.obs_batch_size is None or self.obs_batch_size == self.n:
             # Avoid unnecessary reshuffling
-            return self,
-
-
-
-
+            return self, ObsBatchDict(
+                {
+                    "pinn_in": self.observed_pinn_in,
+                    "val": self.observed_values,
+                    "eq_params": self.observed_eq_params,
+                }
+            )
 
         new_attributes = _reset_or_increment(
             self.curr_idx + self.obs_batch_size,
@@ -157,7 +198,9 @@
             # handled above
         )
         new = eqx.tree_at(
-            lambda m: (m.key, m.indices, m.curr_idx),
+            lambda m: (m.key, m.indices, m.curr_idx),  # type: ignore
+            self,
+            new_attributes,
         )
 
         minib_indices = jax.lax.dynamic_slice(
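
The update above is out-of-place: eqx.tree_at returns a copy of the module with only the selected leaves replaced. A self-contained sketch of that pattern (the Counter class below is illustrative, not part of jinns):

    import equinox as eqx
    import jax.numpy as jnp


    class Counter(eqx.Module):
        idx: int
        data: jnp.ndarray


    c = Counter(idx=0, data=jnp.zeros(3))
    # Replace two leaves at once; `c` itself is left untouched.
    c2 = eqx.tree_at(lambda m: (m.idx, m.data), c, (1, jnp.ones(3)))
    print(c.idx, c2.idx)  # 0 1
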
@@ -174,7 +217,7 @@
                     new.observed_values, minib_indices, unique_indices=True, axis=0
                 ),
                 "eq_params": jax.tree_util.tree_map(
-                    lambda a: jnp.take(a, minib_indices, unique_indices=True, axis=0),
+                    lambda a: jnp.take(a, minib_indices, unique_indices=True, axis=0),  # type: ignore
                     new.observed_eq_params,
                 ),
             }
@@ -182,7 +225,7 @@
 
     def get_batch(
         self,
-    ) -> tuple[
+    ) -> tuple[Self, ObsBatchDict]:
         """
         Generic method to return a batch
         """

jinns/data/_DataGeneratorParameter.py
CHANGED

@@ -5,12 +5,23 @@ Define the DataGenerators modules
 from __future__ import (
     annotations,
 )  # https://docs.python.org/3/library/typing.html#constant
+from typing import Self
 import equinox as eqx
 import jax
 import jax.numpy as jnp
-from jaxtyping import
+from jaxtyping import PRNGKeyArray, Array, Float
 from jinns.data._utils import _reset_or_increment
 from jinns.data._AbstractDataGenerator import AbstractDataGenerator
+from jinns.utils._DictToModuleMeta import DictToModuleMeta
+
+
+class DGParams(metaclass=DictToModuleMeta):
+    """
+    However, static type checkers cannot know that DGParams inherit from
+    eqx.Module and explicit casting to the latter class will be needed
+    """
+
+    pass
 
 
 class DataGeneratorParameter(AbstractDataGenerator):
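
DictToModuleMeta is what lets DGParams be instantiated from a plain dict at runtime. A runnable stand-in for the resulting object (not jinns' actual metaclass machinery): an eqx.Module with one field per parameter name, so the values sit at fixed leaves of a pytree.

    import equinox as eqx
    import jax
    import jax.numpy as jnp


    class TwoParams(eqx.Module):
        # Hypothetical parameter names, mirroring a {"theta": ..., "nu": ...} dict.
        theta: jax.Array
        nu: jax.Array


    params = TwoParams(**{"theta": jnp.ones((4, 1)), "nu": jnp.zeros((4, 1))})
    print(jax.tree.leaves(params))  # the two arrays, in field declaration order
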
@@ -21,9 +32,8 @@ class DataGeneratorParameter(AbstractDataGenerator):
 
     Parameters
     ----------
-
-        Jax random key to sample new time points and to shuffle batches
-        or a dict of Jax random keys with key entries from param_ranges
+    key : PRNGKeyArray
+        Jax random key to sample new time points and to shuffle batches.
     n : int
         The number of total points that will be divided in
         batches. Batches are made so that each data point is seen only
@@ -58,71 +68,86 @@
         Defaults to None.
     """
 
-
+    key: PRNGKeyArray
     n: int = eqx.field(static=True)
-    param_batch_size: int | None = eqx.field(static=True
-    param_ranges: dict[str, tuple[Float, Float]] = eqx.field(
-
-
-
-
-
-    )
+    param_batch_size: int | None = eqx.field(static=True)
+    param_ranges: dict[str, tuple[Float, Float]] = eqx.field(static=True)
+    method: str = eqx.field(static=True)
+    user_data: dict[str, Float[Array, " n"]]
+
+    # --- Below fields are not passed as arguments to __init__
+    _all_params_keys: set[str] = eqx.field(init=False, static=True)
+    curr_param_idx: eqx.Module | None = eqx.field(init=False)
+    param_n_samples: eqx.Module = eqx.field(init=False)
 
-
-
+    def __init__(
+        self,
+        *,
+        key: PRNGKeyArray,
+        n: int,
+        param_batch_size: int | None,
+        param_ranges: dict[str, tuple[Float, Float]] = {},
+        method: str = "uniform",
+        user_data: dict[str, Float[Array, " n"]] = {},
+    ):
+        self.key = key
+        self.n = n
+        self.param_batch_size = param_batch_size
+        self.param_ranges = param_ranges
+        self.method = method
+        self.user_data = user_data
+
+        _all_keys = set().union(self.param_ranges, self.user_data)
+        self._all_params_keys = _all_keys
 
-    def __post_init__(self):
-        if self.user_data is None:
-            self.user_data = {}
-        if self.param_ranges is None:
-            self.param_ranges = {}
         if self.param_batch_size is not None and self.n < self.param_batch_size:
             raise ValueError(
                 f"Number of data points ({self.n}) is smaller than the"
                 f"number of batch points ({self.param_batch_size})."
             )
-
-
-
+
+        # NOTE from jinns > v1.5.1 we work with eqx.Module
+        # because eq_params is not a dict anymore.
+        # We have to use a different class from the publicly exposed EqParams
+        # because fields(EqParams) are not necessarily all present in the
+        # datagenerator, which would cause eqx.Module to error.
+
+        # 1) Call self.generate_data() to generate a dictionnary that merges the scattered data between `user_data` and `param_ranges`
+        self.key, _param_n_samples = self.generate_data(self.key)
+
+        # 2) Use the dictionnary to populate the field of the eqx.Module.
+        self.param_n_samples = DGParams(_param_n_samples, "DGParams")
 
         if self.param_batch_size is None:
-            self.curr_param_idx = None
+            self.curr_param_idx = None
         else:
-            self.
-            for k in self.
-                self.curr_param_idx[k] = self.n + self.param_batch_size
-                # to be sure there is a shuffling at first get_batch()
+            curr_idx = self.n + self.param_batch_size
+            param_keys_and_curr_idx = {k: curr_idx for k in self._all_params_keys}
 
-
-        # the dict self.param_n_samples and then we will only use this one
-        # because it merges the scattered data between `user_data` and
-        # `param_ranges`
-        self.keys, self.param_n_samples = self.generate_data(self.keys)
+            self.curr_param_idx = DGParams(param_keys_and_curr_idx)
 
     def generate_data(
-        self,
-    ) -> tuple[
+        self, key: PRNGKeyArray
+    ) -> tuple[PRNGKeyArray, dict[str, Float[Array, " n 1"]]]:
         """
         Generate parameter samples, either through generation
         or using user-provided data.
         """
         param_n_samples = {}
 
-
-
-
-        )
-        for k in all_keys:
+        # Some of the subkeys might not be used cause of user-provided data.
+        # This is not a big deal and simpler like that.
+        key, *subkeys = jax.random.split(key, len(self._all_params_keys) + 1)
+        for i, k in enumerate(self._all_params_keys):
             if self.user_data and k in self.user_data.keys():
-
-                param_n_samples[k] = self.user_data[k]
-
-
-
-                    raise ValueError(
+                try:
+                    param_n_samples[k] = self.user_data[k].reshape((self.n, 1))
+                except TypeError:
+                    shape = self.user_data[k].shape
+                    raise TypeError(
                         "Wrong shape for user provided parameters"
-                        f" in user_data dictionary at key='{k}'"
+                        f" in user_data dictionary at key='{k}' got {shape} "
+                        f"and expected {(self.n, 1)}."
                     )
             else:
                 if self.method == "grid":
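
One small point in the __init__ above: set().union accepts dicts and takes their keys, so _all_params_keys is simply the union of the parameter names coming from param_ranges and user_data. A tiny illustration with hypothetical names:

    param_ranges = {"theta": (10.0, 11.0)}
    user_data = {"nu": None}
    _all_params_keys = set().union(param_ranges, user_data)
    print(_all_params_keys)  # {'theta', 'nu'} (set order is not guaranteed)
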
@@ -132,75 +157,101 @@
                     param_n_samples[k] = jnp.arange(xmin, xmax, partial)[:, None]
                 elif self.method == "uniform":
                     xmin, xmax = self.param_ranges[k][0], self.param_ranges[k][1]
-                    keys[k], subkey = jax.random.split(keys[k], 2)
                     param_n_samples[k] = jax.random.uniform(
-
+                        subkeys[i], shape=(self.n, 1), minval=xmin, maxval=xmax
                     )
                 else:
                     raise ValueError("Method " + self.method + " is not implemented.")
 
-        return
-
-    def _get_param_operands(
-        self, k: str
-    ) -> tuple[Key, Float[Array, " n"], int, int | None, None]:
-        return (
-            self.keys[k],
-            self.param_n_samples[k],
-            self.curr_param_idx[k],
-            self.param_batch_size,
-            None,
-        )
+        return key, param_n_samples
 
-    def param_batch(self):
+    def param_batch(self) -> tuple[Self, eqx.Module]:
         """
-        Return
-        If all the batches have been seen, we reshuffle them
-        otherwise we just return the next unseen batch.
+        Return an `eqx.Module` with batches of parameters at its leafs.
+        If all the batches have been seen, we reshuffle them (or rather
+        their indices), otherwise we just return the next unseen batch.
         """
 
         if self.param_batch_size is None or self.param_batch_size == self.n:
+            # Full batch mode: nothing to do.
             return self, self.param_n_samples
+        else:
+
+            def _reset_or_increment_wrapper(
+                param_k: Array, idx_k: int, key_k: PRNGKeyArray
+            ):
+                everything_but_key = _reset_or_increment(
+                    idx_k + self.param_batch_size,  # type: ignore
+                    self.n,
+                    (key_k, param_k, idx_k, self.param_batch_size, None),  # type: ignore
+                )[1:]
+                return everything_but_key
 
-
-
-                    idx_k + self.param_batch_size,
-                    self.n,
-                    (key_k, param_k, idx_k, self.param_batch_size, None),  # type: ignore
-                    # ignore since the case self.param_batch_size is None has been
-                    # handled above
+            new_key, *subkeys = jax.random.split(
+                self.key, len(self._all_params_keys) + 1
             )
 
-
-
-
-
-
-
-
-
-
-
-
-        res
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+            # From PRNGKeyArray to a pytree of keys with adequate structure
+            subkeys = jax.tree.unflatten(
+                jax.tree.structure(self.param_n_samples), subkeys
+            )
+
+            res = jax.tree.map(
+                _reset_or_increment_wrapper,
+                self.param_n_samples,
+                self.curr_param_idx,
+                subkeys,
+            )
+            # we must transpose the pytrees because both params and curr_idx # are merged in res
+            # https://jax.readthedocs.io/en/latest/jax-101/05.1-pytrees.html#transposing-trees
+
+            new_attributes = jax.tree.transpose(
+                jax.tree.structure(self.param_n_samples),
+                jax.tree.structure([0, 0]),
+                res,
+            )
+
+            new = eqx.tree_at(
+                lambda m: (m.param_n_samples, m.curr_param_idx),
+                self,
+                new_attributes,
+            )
+
+            new = eqx.tree_at(lambda m: m.key, new, new_key)
+
+            return new, jax.tree_util.tree_map(
+                lambda p, q: jax.lax.dynamic_slice(
+                    p, start_indices=(q, 0), slice_sizes=(new.param_batch_size, 1)
+                ),
+                new.param_n_samples,
+                new.curr_param_idx,
+            )
+
+    def get_batch(self) -> tuple[Self, eqx.Module]:
         """
         Generic method to return a batch
         """
         return self.param_batch()
+
+
+if __name__ == "__main__":
+    key = jax.random.PRNGKey(2)
+    key, subkey = jax.random.split(key)
+
+    n = 64
+    param_batch_size = 32
+    method = "uniform"
+    param_ranges = {"theta": (10.0, 11.0)}
+    user_data = {"nu": jnp.ones((n, 1))}
+
+    x = DataGeneratorParameter(
+        key=subkey,
+        n=n,
+        param_batch_size=param_batch_size,
+        param_ranges=param_ranges,
+        method=method,
+        user_data=user_data,
+    )
+    print(x.key)
+    x, batch = x.get_batch()
+    print(x.key)
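
The transposition step in param_batch above turns "one (samples, index) pair per parameter name" into one pytree of samples and one pytree of indices. A small self-contained illustration of jax.tree.transpose with hypothetical leaves:

    import jax

    outer = jax.tree.structure({"theta": 0, "nu": 0})
    inner = jax.tree.structure([0, 0])
    res = {"theta": [1.0, 2], "nu": [3.0, 4]}  # name -> [new_samples, new_index]
    samples, idxs = jax.tree.transpose(outer, inner, res)
    print(samples)  # {'nu': 3.0, 'theta': 1.0}
    print(idxs)     # {'nu': 4, 'theta': 2}
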
jinns/data/__init__.py
CHANGED
@@ -2,7 +2,7 @@ from ._DataGeneratorODE import DataGeneratorODE
 from ._CubicMeshPDEStatio import CubicMeshPDEStatio
 from ._CubicMeshPDENonStatio import CubicMeshPDENonStatio
 from ._DataGeneratorObservations import DataGeneratorObservations
-from ._DataGeneratorParameter import DataGeneratorParameter
+from ._DataGeneratorParameter import DataGeneratorParameter, DGParams
 from ._Batchs import ODEBatch, PDEStatioBatch, PDENonStatioBatch
 
 from ._utils import append_obs_batch, append_param_batch
@@ -12,6 +12,7 @@ __all__ = [
     "CubicMeshPDEStatio",
     "CubicMeshPDENonStatio",
     "DataGeneratorParameter",
+    "DGParams",
     "DataGeneratorObservations",
     "ODEBatch",
     "PDEStatioBatch",
jinns/data/_utils.py
CHANGED
@@ -3,19 +3,19 @@ Utility functions for DataGenerators
 """
 
 from __future__ import annotations
-
+import warnings
 from typing import TYPE_CHECKING
 import equinox as eqx
 import jax
 import jax.numpy as jnp
-from jaxtyping import
+from jaxtyping import PRNGKeyArray, Array, Float
 
 if TYPE_CHECKING:
     from jinns.utils._types import AnyBatch
     from jinns.data._Batchs import ObsBatchDict
 
 
-def append_param_batch(batch: AnyBatch, param_batch_dict:
+def append_param_batch(batch: AnyBatch, param_batch_dict: eqx.Module) -> AnyBatch:
     """
     Utility function that fills the field `batch.param_batch_dict` of a batch object.
     """
@@ -53,8 +53,10 @@ def make_cartesian_product(
 
 
 def _reset_batch_idx_and_permute(
-    operands: tuple[
-
+    operands: tuple[
+        PRNGKeyArray, Float[Array, " n dimension"], int, None, Float[Array, " n"] | None
+    ],
+) -> tuple[PRNGKeyArray, Float[Array, " n dimension"], int]:
     key, domain, curr_idx, _, p = operands
     # resetting counter
     curr_idx = 0
@@ -77,8 +79,10 @@ def _reset_batch_idx_and_permute(
 
 
 def _increment_batch_idx(
-    operands: tuple[
-
+    operands: tuple[
+        PRNGKeyArray, Float[Array, " n dimension"], int, int, Float[Array, " n"] | None
+    ],
+) -> tuple[PRNGKeyArray, Float[Array, " n dimension"], int]:
     key, domain, curr_idx, batch_size, _ = operands
     # simply increases counter and get the batch
     curr_idx += batch_size
@@ -88,8 +92,10 @@ def _increment_batch_idx(
 def _reset_or_increment(
     bend: int,
     n_eff: int,
-    operands: tuple[
-
+    operands: tuple[
+        PRNGKeyArray, Float[Array, " n dimension"], int, int, Float[Array, " n"] | None
+    ],
+) -> tuple[PRNGKeyArray, Float[Array, " n dimension"], int]:
     """
     Factorize the code of the jax.lax.cond which checks if we have seen all the
     batches in an epoch
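
_reset_or_increment wraps a jax.lax.cond over the two helpers above: reset (and repermute) once an epoch's batches are exhausted, otherwise just advance the index. A self-contained sketch of that branching, with illustrative names and without the permutation logic:

    import jax

    def next_index(curr_idx: int, batch_size: int, n: int):
        return jax.lax.cond(
            curr_idx + batch_size > n,   # all batches seen this epoch?
            lambda i: i * 0,             # reset the counter
            lambda i: i + batch_size,    # move to the next batch
            curr_idx,
        )

    print(next_index(0, 32, 128))    # 32
    print(next_index(120, 32, 128))  # 0
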
@@ -119,7 +125,7 @@ def _reset_or_increment(
 
 
 def _check_and_set_rar_parameters(
-    rar_parameters: dict, n: int, n_start: int
+    rar_parameters: None | dict, n: int, n_start: None | int
 ) -> tuple[int, Float[Array, " n"] | None, int | None, int | None]:
     if rar_parameters is not None and n_start is None:
         raise ValueError(
@@ -127,6 +133,12 @@ def _check_and_set_rar_parameters(
         )
 
     if rar_parameters is not None:
+        if n_start is None:
+            n_start = 0
+            warnings.warn(
+                "You asked for RAR sampling but didn't provide"
+                f"a proper `n_start` {n_start=}. Setting it to 0."
+            )
         # Default p is None. However, in the RAR sampling scheme we use 0
         # probability to specify non-used collocation points (i.e. points
         # above n_start). Thus, p is a vector of probability of shape (nt, 1).