jinns-1.3.0-py3-none-any.whl → jinns-1.4.0-py3-none-any.whl
- jinns/__init__.py +17 -7
- jinns/data/_AbstractDataGenerator.py +19 -0
- jinns/data/_Batchs.py +31 -12
- jinns/data/_CubicMeshPDENonStatio.py +431 -0
- jinns/data/_CubicMeshPDEStatio.py +464 -0
- jinns/data/_DataGeneratorODE.py +187 -0
- jinns/data/_DataGeneratorObservations.py +189 -0
- jinns/data/_DataGeneratorParameter.py +206 -0
- jinns/data/__init__.py +19 -9
- jinns/data/_utils.py +149 -0
- jinns/experimental/__init__.py +9 -0
- jinns/loss/_DynamicLoss.py +114 -187
- jinns/loss/_DynamicLossAbstract.py +45 -68
- jinns/loss/_LossODE.py +71 -336
- jinns/loss/_LossPDE.py +146 -520
- jinns/loss/__init__.py +28 -6
- jinns/loss/_abstract_loss.py +15 -0
- jinns/loss/_boundary_conditions.py +20 -19
- jinns/loss/_loss_utils.py +78 -159
- jinns/loss/_loss_weights.py +12 -44
- jinns/loss/_operators.py +84 -74
- jinns/nn/__init__.py +15 -0
- jinns/nn/_abstract_pinn.py +22 -0
- jinns/nn/_hyperpinn.py +94 -57
- jinns/nn/_mlp.py +50 -25
- jinns/nn/_pinn.py +33 -19
- jinns/nn/_ppinn.py +70 -34
- jinns/nn/_save_load.py +21 -51
- jinns/nn/_spinn.py +33 -16
- jinns/nn/_spinn_mlp.py +28 -22
- jinns/nn/_utils.py +38 -0
- jinns/parameters/__init__.py +8 -1
- jinns/parameters/_derivative_keys.py +116 -177
- jinns/parameters/_params.py +18 -46
- jinns/plot/__init__.py +2 -0
- jinns/plot/_plot.py +35 -34
- jinns/solver/_rar.py +80 -63
- jinns/solver/_solve.py +89 -63
- jinns/solver/_utils.py +4 -6
- jinns/utils/__init__.py +2 -0
- jinns/utils/_containers.py +12 -9
- jinns/utils/_types.py +11 -57
- jinns/utils/_utils.py +4 -11
- jinns/validation/__init__.py +2 -0
- jinns/validation/_validation.py +20 -19
- {jinns-1.3.0.dist-info → jinns-1.4.0.dist-info}/METADATA +4 -3
- jinns-1.4.0.dist-info/RECORD +53 -0
- {jinns-1.3.0.dist-info → jinns-1.4.0.dist-info}/WHEEL +1 -1
- jinns/data/_DataGenerators.py +0 -1634
- jinns-1.3.0.dist-info/RECORD +0 -44
- {jinns-1.3.0.dist-info → jinns-1.4.0.dist-info/licenses}/AUTHORS +0 -0
- {jinns-1.3.0.dist-info → jinns-1.4.0.dist-info/licenses}/LICENSE +0 -0
- {jinns-1.3.0.dist-info → jinns-1.4.0.dist-info}/top_level.txt +0 -0
jinns/data/_DataGeneratorObservations.py
ADDED
@@ -0,0 +1,189 @@
+"""
+Define the DataGenerators modules
+"""
+
+from __future__ import (
+    annotations,
+)  # https://docs.python.org/3/library/typing.html#constant
+import equinox as eqx
+import jax
+import jax.numpy as jnp
+from jaxtyping import Key, Int, Array, Float
+from jinns.data._Batchs import ObsBatchDict
+from jinns.data._utils import _reset_or_increment
+from jinns.data._AbstractDataGenerator import AbstractDataGenerator
+
+
+class DataGeneratorObservations(AbstractDataGenerator):
+    r"""
+    Despite the class name, it is rather a dataloader for user-provided
+    observations which are used in the observations loss.
+
+    Parameters
+    ----------
+    key : Key
+        Jax random key to shuffle batches
+    obs_batch_size : int | None
+        The size of the batch of randomly selected points among
+        the `n` points. If None, no minibatches are used.
+    observed_pinn_in : Float[Array, " n_obs nb_pinn_in"]
+        Observed values corresponding to the input of the PINN
+        (e.g. the time at which we recorded the observations). The first
+        dimension must correspond to the number of observed values.
+        The second dimension depends on the input dimension of the PINN,
+        that is `1` for ODE, `n_dim_x` for stationary PDE and `n_dim_x + 1`
+        for non-stationary PDE.
+    observed_values : Float[Array, " n_obs nb_pinn_out"]
+        Observed values that the PINN should learn to fit. The first
+        dimension must be aligned with observed_pinn_in.
+    observed_eq_params : dict[str, Float[Array, " n_obs 1"]], default={}
+        A dict with keys corresponding to
+        the parameter name. The keys must match the keys in
+        `params["eq_params"]`. The values are jnp.array with 2 dimensions
+        with values corresponding to the parameter value for which we also
+        have observed_pinn_in and observed_values. Hence the first
+        dimension must be aligned with observed_pinn_in and observed_values.
+        Optional argument.
+    sharding_device : jax.sharding.Sharding, default=None
+        Default None. An optional sharding object to constrain the storage
+        of observed inputs, values and parameters. Typically, a
+        SingleDeviceSharding(cpu_device) to avoid loading huge datasets of
+        observations on GPU. Note that computations for **batches**
+        can still be performed on other devices (*e.g.* GPU, TPU or
+        any pre-defined Sharding) thanks to the `obs_batch_sharding`
+        argument of `jinns.solve()`. Read `jinns.solve()` doc for more info.
+    """
+
+    key: Key
+    obs_batch_size: int | None = eqx.field(static=True)
+    observed_pinn_in: Float[Array, " n_obs nb_pinn_in"]
+    observed_values: Float[Array, " n_obs nb_pinn_out"]
+    observed_eq_params: dict[str, Float[Array, " n_obs 1"]] = eqx.field(
+        static=True, default_factory=lambda: {}
+    )
+    sharding_device: jax.sharding.Sharding = eqx.field(static=True, default=None)
+
+    n: int = eqx.field(init=False, static=True)
+    curr_idx: int = eqx.field(init=False)
+    indices: Array = eqx.field(init=False)
+
+    def __post_init__(self):
+        if self.observed_pinn_in.shape[0] != self.observed_values.shape[0]:
+            raise ValueError(
+                "self.observed_pinn_in and self.observed_values must have same first axis"
+            )
+        for _, v in self.observed_eq_params.items():
+            if v.shape[0] != self.observed_pinn_in.shape[0]:
+                raise ValueError(
+                    "self.observed_pinn_in and the values of"
+                    " self.observed_eq_params must have the same first axis"
+                )
+        if len(self.observed_pinn_in.shape) == 1:
+            self.observed_pinn_in = self.observed_pinn_in[:, None]
+        if self.observed_pinn_in.ndim > 2:
+            raise ValueError("self.observed_pinn_in must have 2 dimensions")
+        if len(self.observed_values.shape) == 1:
+            self.observed_values = self.observed_values[:, None]
+        if self.observed_values.ndim > 2:
+            raise ValueError("self.observed_values must have 2 dimensions")
+        for k, v in self.observed_eq_params.items():
+            if len(v.shape) == 1:
+                self.observed_eq_params[k] = v[:, None]
+            if len(v.shape) > 2:
+                raise ValueError(
+                    "Each value of observed_eq_params must have 2 dimensions"
+                )
+
+        self.n = self.observed_pinn_in.shape[0]
+
+        if self.sharding_device is not None:
+            self.observed_pinn_in = jax.lax.with_sharding_constraint(
+                self.observed_pinn_in, self.sharding_device
+            )
+            self.observed_values = jax.lax.with_sharding_constraint(
+                self.observed_values, self.sharding_device
+            )
+            self.observed_eq_params = jax.lax.with_sharding_constraint(
+                self.observed_eq_params, self.sharding_device
+            )
+
+        if self.obs_batch_size is not None:
+            self.curr_idx = self.n + self.obs_batch_size
+            # to be sure there is a shuffling at first get_batch()
+        else:
+            self.curr_idx = 0
+        # For speed and to avoid duplicating data, what is really
+        # shuffled is a vector of indices
+        if self.sharding_device is not None:
+            self.indices = jax.lax.with_sharding_constraint(
+                jnp.arange(self.n), self.sharding_device
+            )
+        else:
+            self.indices = jnp.arange(self.n)
+
+        # recall __post_init__ is the only place, with __init__, where we can
+        # set self attributes in an in-place way
+        self.key, _ = jax.random.split(self.key, 2)  # to make it equivalent to
+        # the call to _reset_batch_idx_and_permute in legacy DG
+
+    def _get_operands(self) -> tuple[Key, Int[Array, " n"], int, int | None, None]:
+        return (
+            self.key,
+            self.indices,
+            self.curr_idx,
+            self.obs_batch_size,
+            None,
+        )
+
+    def obs_batch(
+        self,
+    ) -> tuple[DataGeneratorObservations, ObsBatchDict]:
+        """
+        Return an updated DataGeneratorObservations instance and an ObsBatchDict
+        """
+        if self.obs_batch_size is None or self.obs_batch_size == self.n:
+            # Avoid unnecessary reshuffling
+            return self, {
+                "pinn_in": self.observed_pinn_in,
+                "val": self.observed_values,
+                "eq_params": self.observed_eq_params,
+            }
+
+        new_attributes = _reset_or_increment(
+            self.curr_idx + self.obs_batch_size,
+            self.n,
+            self._get_operands(),  # type: ignore
+            # ignore since the case self.obs_batch_size is None has been
+            # handled above
+        )
+        new = eqx.tree_at(
+            lambda m: (m.key, m.indices, m.curr_idx), self, new_attributes
+        )
+
+        minib_indices = jax.lax.dynamic_slice(
+            new.indices,
+            start_indices=(new.curr_idx,),
+            slice_sizes=(new.obs_batch_size,),
+        )
+
+        obs_batch: ObsBatchDict = {
+            "pinn_in": jnp.take(
+                new.observed_pinn_in, minib_indices, unique_indices=True, axis=0
+            ),
+            "val": jnp.take(
+                new.observed_values, minib_indices, unique_indices=True, axis=0
+            ),
+            "eq_params": jax.tree_util.tree_map(
+                lambda a: jnp.take(a, minib_indices, unique_indices=True, axis=0),
+                new.observed_eq_params,
+            ),
+        }
+        return new, obs_batch
+
+    def get_batch(
+        self,
+    ) -> tuple[DataGeneratorObservations, ObsBatchDict]:
+        """
+        Generic method to return a batch
+        """
+        return self.obs_batch()
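For readers skimming the diff, here is a minimal usage sketch of the new `DataGeneratorObservations` class, based only on the signature and docstring shown above. The observation arrays (`t_obs`, `u_obs`) and the parameter name `"nu"` are hypothetical placeholders, not part of the package.

import jax
import jax.numpy as jnp
from jinns.data import DataGeneratorObservations

# Hypothetical observations: 100 time points (ODE setting, so nb_pinn_in = 1)
key = jax.random.PRNGKey(0)
t_obs = jnp.linspace(0.0, 1.0, 100)[:, None]   # (n_obs, nb_pinn_in)
u_obs = jnp.sin(2 * jnp.pi * t_obs)            # (n_obs, nb_pinn_out)
nu_obs = jnp.full((100, 1), 0.3)               # parameter observed alongside

obs_dg = DataGeneratorObservations(
    key=key,
    obs_batch_size=32,
    observed_pinn_in=t_obs,
    observed_values=u_obs,
    observed_eq_params={"nu": nu_obs},
)

# get_batch() returns the updated generator and a dict with keys
# "pinn_in", "val" and "eq_params"; here obs_batch["pinn_in"].shape == (32, 1)
obs_dg, obs_batch = obs_dg.get_batch()

Since the class is an immutable equinox module, the shuffling state lives in the returned instance, so the caller must keep the new generator for the next `get_batch()` call.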
jinns/data/_DataGeneratorParameter.py
ADDED
@@ -0,0 +1,206 @@
+"""
+Define the DataGenerators modules
+"""
+
+from __future__ import (
+    annotations,
+)  # https://docs.python.org/3/library/typing.html#constant
+import equinox as eqx
+import jax
+import jax.numpy as jnp
+from jaxtyping import Key, Array, Float
+from jinns.data._utils import _reset_or_increment
+from jinns.data._AbstractDataGenerator import AbstractDataGenerator
+
+
+class DataGeneratorParameter(AbstractDataGenerator):
+    r"""
+    A data generator for additional unidimensional equation parameter(s).
+    Mostly useful for metamodeling where batches of `params.eq_params` are fed
+    to the network.
+
+    Parameters
+    ----------
+    keys : Key | dict[str, Key]
+        Jax random key to sample new time points and to shuffle batches,
+        or a dict of Jax random keys with key entries from param_ranges
+    n : int
+        The total number of points that will be divided into
+        batches. Batches are made so that each data point is seen only
+        once during 1 epoch.
+    param_batch_size : int | None, default=None
+        The size of the batch of randomly selected points among
+        the `n` points. **Important**: no check is performed but
+        `param_batch_size` must be the same as the other collocation points
+        batch_size (time, space or time x space depending on the context). This is because we vmap the network on all its axes at once to compute the MSE. Also, `param_batch_size` will be the same for all parameters. If None, no mini-batches are used.
+    param_ranges : dict[str, tuple[Float, Float]] | None, default={}
+        A dict of tuples (min, max), which
+        represents the range of real numbers where to sample batches (of
+        length `param_batch_size` among `n` points).
+        The key corresponds to the parameter name. The keys must match the
+        keys in `params["eq_params"]`.
+        By providing several entries in this dictionary we can sample
+        an arbitrary number of parameters.
+        **Note** that we currently only support unidimensional parameters.
+        This argument can be None if we use `user_data`.
+    method : str, default="uniform"
+        Either `grid` or `uniform`, default is `uniform`. `grid` means
+        regularly spaced points over the domain. `uniform` means uniformly
+        sampled points over the domain.
+    user_data : dict[str, Float[Array, " n"]] | None, default={}
+        A dictionary containing user-provided data for parameters.
+        The keys correspond to the parameter name,
+        and must match the keys in `params["eq_params"]`. Only
+        unidimensional `jnp.array` are supported. Therefore, the array at
+        `user_data[k]` must have shape `(n, 1)` or `(n,)`.
+        Note that if the same key appears in `param_ranges` and `user_data`,
+        priority goes to the content in `user_data`.
+        Defaults to None.
+    """
+
+    keys: Key | dict[str, Key]
+    n: int = eqx.field(static=True)
+    param_batch_size: int | None = eqx.field(static=True, default=None)
+    param_ranges: dict[str, tuple[Float, Float]] = eqx.field(
+        static=True, default_factory=lambda: {}
+    )
+    method: str = eqx.field(static=True, default="uniform")
+    user_data: dict[str, Float[Array, " n"]] | None = eqx.field(
+        default_factory=lambda: {}
+    )
+
+    curr_param_idx: dict[str, int] = eqx.field(init=False)
+    param_n_samples: dict[str, Array] = eqx.field(init=False)
+
+    def __post_init__(self):
+        if self.user_data is None:
+            self.user_data = {}
+        if self.param_ranges is None:
+            self.param_ranges = {}
+        if self.param_batch_size is not None and self.n < self.param_batch_size:
+            raise ValueError(
+                f"Number of data points ({self.n}) is smaller than the"
+                f" number of batch points ({self.param_batch_size})."
+            )
+        if not isinstance(self.keys, dict):
+            all_keys = set().union(self.param_ranges, self.user_data)
+            self.keys = dict(zip(all_keys, jax.random.split(self.keys, len(all_keys))))
+
+        if self.param_batch_size is None:
+            self.curr_param_idx = None  # type: ignore
+        else:
+            self.curr_param_idx = {}
+            for k in self.keys.keys():
+                self.curr_param_idx[k] = self.n + self.param_batch_size
+                # to be sure there is a shuffling at first get_batch()
+
+        # The call to self.generate_data() creates
+        # the dict self.param_n_samples and then we will only use this one
+        # because it merges the scattered data between `user_data` and
+        # `param_ranges`
+        self.keys, self.param_n_samples = self.generate_data(self.keys)
+
+    def generate_data(
+        self, keys: dict[str, Key]
+    ) -> tuple[dict[str, Key], dict[str, Float[Array, " n"]]]:
+        """
+        Generate parameter samples, either by sampling
+        or using user-provided data.
+        """
+        param_n_samples = {}
+
+        all_keys = set().union(
+            self.param_ranges,
+            self.user_data,  # type: ignore this has been handled in post_init
+        )
+        for k in all_keys:
+            if self.user_data and k in self.user_data.keys():
+                if self.user_data[k].shape == (self.n, 1):
+                    param_n_samples[k] = self.user_data[k]
+                elif self.user_data[k].shape == (self.n,):
+                    param_n_samples[k] = self.user_data[k][:, None]
+                else:
+                    raise ValueError(
+                        "Wrong shape for user provided parameters"
+                        f" in user_data dictionary at key='{k}'"
+                    )
+            else:
+                if self.method == "grid":
+                    xmin, xmax = self.param_ranges[k][0], self.param_ranges[k][1]
+                    partial = (xmax - xmin) / self.n
+                    # shape (n, 1)
+                    param_n_samples[k] = jnp.arange(xmin, xmax, partial)[:, None]
+                elif self.method == "uniform":
+                    xmin, xmax = self.param_ranges[k][0], self.param_ranges[k][1]
+                    keys[k], subkey = jax.random.split(keys[k], 2)
+                    param_n_samples[k] = jax.random.uniform(
+                        subkey, shape=(self.n, 1), minval=xmin, maxval=xmax
+                    )
+                else:
+                    raise ValueError("Method " + self.method + " is not implemented.")
+
+        return keys, param_n_samples
+
+    def _get_param_operands(
+        self, k: str
+    ) -> tuple[Key, Float[Array, " n"], int, int | None, None]:
+        return (
+            self.keys[k],
+            self.param_n_samples[k],
+            self.curr_param_idx[k],
+            self.param_batch_size,
+            None,
+        )
+
+    def param_batch(self):
+        """
+        Return a dictionary with batches of parameters.
+        If all the batches have been seen, we reshuffle them,
+        otherwise we just return the next unseen batch.
+        """
+
+        if self.param_batch_size is None or self.param_batch_size == self.n:
+            return self, self.param_n_samples
+
+        def _reset_or_increment_wrapper(param_k, idx_k, key_k):
+            return _reset_or_increment(
+                idx_k + self.param_batch_size,
+                self.n,
+                (key_k, param_k, idx_k, self.param_batch_size, None),  # type: ignore
+                # ignore since the case self.param_batch_size is None has been
+                # handled above
+            )
+
+        res = jax.tree_util.tree_map(
+            _reset_or_increment_wrapper,
+            self.param_n_samples,
+            self.curr_param_idx,
+            self.keys,
+        )
+        # we must transpose the pytrees because keys are merged in res
+        # https://jax.readthedocs.io/en/latest/jax-101/05.1-pytrees.html#transposing-trees
+        new_attributes = jax.tree_util.tree_transpose(
+            jax.tree_util.tree_structure(self.keys),
+            jax.tree_util.tree_structure([0, 0, 0]),
+            res,
+        )
+
+        new = eqx.tree_at(
+            lambda m: (m.keys, m.param_n_samples, m.curr_param_idx),
+            self,
+            new_attributes,
+        )
+
+        return new, jax.tree_util.tree_map(
+            lambda p, q: jax.lax.dynamic_slice(
+                p, start_indices=(q, 0), slice_sizes=(new.param_batch_size, 1)
+            ),
+            new.param_n_samples,
+            new.curr_param_idx,
+        )
+
+    def get_batch(self):
+        """
+        Generic method to return a batch
+        """
+        return self.param_batch()
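Similarly, a minimal sketch of how the new `DataGeneratorParameter` might be used for metamodeling, assuming a hypothetical scalar equation parameter `"nu"`; `param_batch_size` is chosen to match the collocation batch size, as the docstring above requires.

import jax
from jinns.data import DataGeneratorParameter

# Sample the hypothetical parameter "nu" uniformly in [0.1, 1.0]
param_dg = DataGeneratorParameter(
    keys=jax.random.PRNGKey(1),
    n=1000,
    param_batch_size=32,
    param_ranges={"nu": (0.1, 1.0)},
    method="uniform",
)

# param_batch is a dict {"nu": array of shape (32, 1)}
param_dg, param_batch = param_dg.get_batch()

The resulting dict can then be attached to a collocation batch with `append_param_batch` from `jinns.data`, which fills the batch's `param_batch_dict` field.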
jinns/data/__init__.py
CHANGED
@@ -1,11 +1,21 @@
-from .
-
-
-
-
-    DataGeneratorParameter,
-    DataGeneratorObservationsMultiPINNs,
-)
+from ._DataGeneratorODE import DataGeneratorODE
+from ._CubicMeshPDEStatio import CubicMeshPDEStatio
+from ._CubicMeshPDENonStatio import CubicMeshPDENonStatio
+from ._DataGeneratorObservations import DataGeneratorObservations
+from ._DataGeneratorParameter import DataGeneratorParameter
 from ._Batchs import ODEBatch, PDEStatioBatch, PDENonStatioBatch
 
-from .
+from ._utils import append_obs_batch, append_param_batch
+
+__all__ = [
+    "DataGeneratorODE",
+    "CubicMeshPDEStatio",
+    "CubicMeshPDENonStatio",
+    "DataGeneratorParameter",
+    "DataGeneratorObservations",
+    "ODEBatch",
+    "PDEStatioBatch",
+    "PDENonStatioBatch",
+    "append_obs_batch",
+    "append_param_batch",
+]
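As the `__all__` list above suggests, the public entry points remain importable from `jinns.data` even though the monolithic `jinns/data/_DataGenerators.py` module was removed; a quick sanity-check sketch:

# The data generators formerly re-exported from _DataGenerators.py now come
# from their dedicated modules via jinns/data/__init__.py.
from jinns.data import (
    DataGeneratorODE,
    CubicMeshPDEStatio,
    CubicMeshPDENonStatio,
    DataGeneratorObservations,
    DataGeneratorParameter,
    ODEBatch,
    PDEStatioBatch,
    PDENonStatioBatch,
    append_obs_batch,
    append_param_batch,
)

Note that `DataGeneratorObservationsMultiPINNs`, which appeared in the removed import block, is not listed in the new `__all__` above.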
jinns/data/_utils.py
ADDED
@@ -0,0 +1,149 @@
+"""
+Utility functions for DataGenerators
+"""
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+import equinox as eqx
+import jax
+import jax.numpy as jnp
+from jaxtyping import Key, Array, Float
+
+if TYPE_CHECKING:
+    from jinns.utils._types import AnyBatch
+    from jinns.data._Batchs import ObsBatchDict
+
+
+def append_param_batch(batch: AnyBatch, param_batch_dict: dict[str, Array]) -> AnyBatch:
+    """
+    Utility function that fills the field `batch.param_batch_dict` of a batch object.
+    """
+    return eqx.tree_at(
+        lambda m: m.param_batch_dict,
+        batch,
+        param_batch_dict,
+        is_leaf=lambda x: x is None,
+    )
+
+
+def append_obs_batch(batch: AnyBatch, obs_batch_dict: ObsBatchDict) -> AnyBatch:
+    """
+    Utility function that fills the field `batch.obs_batch_dict` of a batch object
+    """
+    return eqx.tree_at(
+        lambda m: m.obs_batch_dict, batch, obs_batch_dict, is_leaf=lambda x: x is None
+    )
+
+
+def make_cartesian_product(
+    b1: Float[Array, " batch_size dim1"], b2: Float[Array, " batch_size dim2"]
+) -> Float[Array, " rows=batch_size*batch_size (dim1+dim2)"]:
+    # rows= serves to disable jaxtyping's runtime check, since jaxtyping does
+    # not like the star operator used in a way it does not expect
+    """
+    Create the cartesian product of a time batch and a border omega batch
+    by tiling and repeating
+    """
+    n1 = b1.shape[0]
+    n2 = b2.shape[0]
+    b1 = jnp.repeat(b1, n2, axis=0)
+    b2 = jnp.tile(b2, reps=(n1,) + tuple(1 for i in b2.shape[1:]))
+    return jnp.concatenate([b1, b2], axis=1)
+
+
+def _reset_batch_idx_and_permute(
+    operands: tuple[Key, Float[Array, " n dimension"], int, None, Float[Array, " n"]],
+) -> tuple[Key, Float[Array, " n dimension"], int]:
+    key, domain, curr_idx, _, p = operands
+    # resetting counter
+    curr_idx = 0
+    # reshuffling
+    key, subkey = jax.random.split(key)
+    if p is None:
+        domain = jax.random.permutation(subkey, domain, axis=0, independent=False)
+    else:
+        # otherwise p is used to avoid collocation points not in n_start
+        # NOTE that replace=True to avoid undefined behaviour but then, the
+        # domain.shape[0] does not really grow as in the original RAR. Instead,
+        # it always comprises the same number of points, but the points are
+        # updated
+        domain = jax.random.choice(
+            subkey, domain, shape=(domain.shape[0],), replace=True, p=p
+        )
+
+    # return updated
+    return (key, domain, curr_idx)
+
+
+def _increment_batch_idx(
+    operands: tuple[Key, Float[Array, " n dimension"], int, int, Float[Array, " n"]],
+) -> tuple[Key, Float[Array, " n dimension"], int]:
+    key, domain, curr_idx, batch_size, _ = operands
+    # simply increase the counter and get the batch
+    curr_idx += batch_size
+    return (key, domain, curr_idx)
+
+
+def _reset_or_increment(
+    bend: int,
+    n_eff: int,
+    operands: tuple[Key, Float[Array, " n dimension"], int, int, Float[Array, " n"]],
+) -> tuple[Key, Float[Array, " n dimension"], int]:
+    """
+    Factorize the code of the jax.lax.cond which checks if we have seen all the
+    batches in an epoch.
+    If bend > n_eff (i.e. n when no RAR sampling), we reshuffle and start from 0
+    again. Otherwise, if bend < n_eff, this means that at least *_batch_size
+    samples have not been seen yet and we can take a new batch.
+
+    Parameters
+    ----------
+    bend
+        An integer. The new hypothetical index for the start of the batch
+    n_eff
+        An integer. The number of points to see to complete an epoch
+    operands
+        A tuple. As passed to _reset_batch_idx_and_permute and
+        _increment_batch_idx
+
+    Returns
+    -------
+    res
+        A tuple as returned by _reset_batch_idx_and_permute or
+        _increment_batch_idx
+    """
+    return jax.lax.cond(
+        bend > n_eff, _reset_batch_idx_and_permute, _increment_batch_idx, operands
+    )
+
+
+def _check_and_set_rar_parameters(
+    rar_parameters: dict, n: int, n_start: int
+) -> tuple[int, Float[Array, " n"] | None, int | None, int | None]:
+    if rar_parameters is not None and n_start is None:
+        raise ValueError(
+            "n_start must be provided in the context of RAR sampling scheme"
+        )
+
+    if rar_parameters is not None:
+        # Default p is None. However, in the RAR sampling scheme we use 0
+        # probability to specify non-used collocation points (i.e. points
+        # above n_start). Thus, p is a vector of probabilities of shape (n,).
+        p = jnp.zeros((n,))
+        p = p.at[:n_start].set(1 / n_start)
+        # set internal counter for the number of gradient steps since the
+        # last new collocation points have been added
+        # It is not 0 to ensure the first iteration of RAR happens just
+        # after start_iter. See the _proceed_to_rar() function in _rar.py
+        rar_iter_from_last_sampling = rar_parameters["update_every"] - 1
+        # set internal counter for the number of times collocation points
+        # have been added
+        rar_iter_nb = 0
+    else:
+        n_start = n
+        p = None
+        rar_iter_from_last_sampling = None
+        rar_iter_nb = None
+
+    return n_start, p, rar_iter_from_last_sampling, rar_iter_nb
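To make the tiling/repeating logic of `make_cartesian_product` above concrete, here is a small worked example with hypothetical time and boundary batches (shapes chosen for illustration only):

import jax.numpy as jnp
from jinns.data._utils import make_cartesian_product

# 2 time points (dim 1) and 2 boundary points (dim 2)
t_batch = jnp.array([[0.0], [1.0]])
x_batch = jnp.array([[0.0, 0.0], [1.0, 1.0]])

tx = make_cartesian_product(t_batch, x_batch)
# t_batch is repeated (each row duplicated n2 times) and x_batch is tiled
# (the whole block stacked n1 times), then both are concatenated column-wise:
# tx.shape == (4, 3), with rows
# (0.0, 0.0, 0.0), (0.0, 1.0, 1.0), (1.0, 0.0, 0.0), (1.0, 1.0, 1.0)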
jinns/experimental/__init__.py
CHANGED
@@ -6,3 +6,12 @@ from ._diffrax_solver import (
     neumann_boundary_condition,
     plot_diffrax_solution,
 )
+
+__all__ = [
+    "SpatialDiscretisation",
+    "reaction_diffusion_2d_vector_field",
+    "laplacian",
+    "dirichlet_boundary_condition",
+    "neumann_boundary_condition",
+    "plot_diffrax_solution",
+]