jinns 1.5.0__py3-none-any.whl → 1.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jinns/__init__.py +7 -7
- jinns/data/_CubicMeshPDENonStatio.py +156 -28
- jinns/data/_CubicMeshPDEStatio.py +132 -24
- jinns/loss/_LossODE.py +95 -31
- jinns/loss/_LossPDE.py +6 -15
- jinns/loss/_loss_utils.py +23 -0
- jinns/parameters/_params.py +8 -0
- jinns/solver/_solve.py +11 -5
- {jinns-1.5.0.dist-info → jinns-1.5.1.dist-info}/METADATA +1 -1
- {jinns-1.5.0.dist-info → jinns-1.5.1.dist-info}/RECORD +14 -14
- {jinns-1.5.0.dist-info → jinns-1.5.1.dist-info}/WHEEL +0 -0
- {jinns-1.5.0.dist-info → jinns-1.5.1.dist-info}/licenses/AUTHORS +0 -0
- {jinns-1.5.0.dist-info → jinns-1.5.1.dist-info}/licenses/LICENSE +0 -0
- {jinns-1.5.0.dist-info → jinns-1.5.1.dist-info}/top_level.txt +0 -0
jinns/__init__.py
CHANGED
|
@@ -1,10 +1,3 @@
|
|
|
1
|
-
# import jinns.data
|
|
2
|
-
# import jinns.loss
|
|
3
|
-
# import jinns.solver
|
|
4
|
-
# import jinns.utils
|
|
5
|
-
# import jinns.experimental
|
|
6
|
-
# import jinns.parameters
|
|
7
|
-
# import jinns.plot
|
|
8
1
|
from jinns import data as data
|
|
9
2
|
from jinns import loss as loss
|
|
10
3
|
from jinns import solver as solver
|
|
@@ -16,3 +9,10 @@ from jinns import nn as nn
|
|
|
16
9
|
from jinns.solver._solve import solve
|
|
17
10
|
|
|
18
11
|
__all__ = ["nn", "solve"]
|
|
12
|
+
|
|
13
|
+
import warnings
|
|
14
|
+
|
|
15
|
+
warnings.filterwarnings(
|
|
16
|
+
action="ignore",
|
|
17
|
+
message=r"Using `field\(init=False\)`",
|
|
18
|
+
)
|
|
@@ -7,8 +7,10 @@ from __future__ import (
|
|
|
7
7
|
) # https://docs.python.org/3/library/typing.html#constant
|
|
8
8
|
import warnings
|
|
9
9
|
import equinox as eqx
|
|
10
|
+
import numpy as np
|
|
10
11
|
import jax
|
|
11
12
|
import jax.numpy as jnp
|
|
13
|
+
from scipy.stats import qmc
|
|
12
14
|
from jaxtyping import Key, Array, Float
|
|
13
15
|
from jinns.data._Batchs import PDENonStatioBatch
|
|
14
16
|
from jinns.data._utils import (
|
|
@@ -65,11 +67,13 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
|
|
|
65
67
|
The minimum value of the time domain to consider
|
|
66
68
|
tmax : float
|
|
67
69
|
The maximum value of the time domain to consider
|
|
68
|
-
method :
|
|
70
|
+
method : Literal["uniform", "grid", "sobol", "halton"], default="uniform"
|
|
69
71
|
Either `grid`, `uniform`, `sobol` or `halton`, default is `uniform`.
|
|
70
72
|
The method that generates the `nt` time points. `grid` means
|
|
71
73
|
regularly spaced points over the domain. `uniform` means uniformly
|
|
72
|
-
sampled points over the domain
|
|
74
|
+
sampled points over the domain.
|
|
75
|
+
**Note** that Sobol and Halton approaches use scipy modules and will not
|
|
76
|
+
be JIT compatible.
|
|
73
77
|
rar_parameters : Dict[str, int], default=None
|
|
74
78
|
Defaults to None: do not use Residual Adaptative Resampling.
|
|
75
79
|
Otherwise a dictionary with keys
|
|
@@ -150,9 +154,11 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
|
|
|
150
154
|
elif self.method == "uniform":
|
|
151
155
|
self.key, domain_times = self.generate_time_data(self.key, self.n)
|
|
152
156
|
self.domain = jnp.concatenate([domain_times, self.omega], axis=1)
|
|
157
|
+
elif self.method in ["sobol", "halton"]:
|
|
158
|
+
self.key, self.domain = self.qmc_in_time_omega_domain(self.key, self.n)
|
|
153
159
|
else:
|
|
154
160
|
raise ValueError(
|
|
155
|
-
f'Bad value for method. Got {self.method}, expected "grid" or "uniform"'
|
|
161
|
+
f'Bad value for method. Got {self.method}, expected "grid" or "uniform" or "sobol" or "halton"'
|
|
156
162
|
)
|
|
157
163
|
|
|
158
164
|
if self.domain_batch_size is None:
|
|
@@ -182,21 +188,28 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
|
|
|
182
188
|
"number of points per facets (nb//2*self.dim)"
|
|
183
189
|
" cannot be lower than border batch size"
|
|
184
190
|
)
|
|
185
|
-
self.
|
|
186
|
-
self.key,
|
|
187
|
-
|
|
188
|
-
boundary_times = boundary_times.reshape(-1, 1, 1)
|
|
189
|
-
boundary_times = jnp.repeat(
|
|
190
|
-
boundary_times, self.omega_border.shape[-1], axis=2
|
|
191
|
-
)
|
|
192
|
-
if self.dim == 1:
|
|
193
|
-
self.border = make_cartesian_product(
|
|
194
|
-
boundary_times, self.omega_border[None, None]
|
|
191
|
+
if self.method in ["grid", "uniform"]:
|
|
192
|
+
self.key, boundary_times = self.generate_time_data(
|
|
193
|
+
self.key, self.nb // (2 * self.dim)
|
|
195
194
|
)
|
|
195
|
+
boundary_times = boundary_times.reshape(-1, 1, 1)
|
|
196
|
+
boundary_times = jnp.repeat(
|
|
197
|
+
boundary_times, self.omega_border.shape[-1], axis=2
|
|
198
|
+
)
|
|
199
|
+
if self.dim == 1:
|
|
200
|
+
self.border = make_cartesian_product(
|
|
201
|
+
boundary_times, self.omega_border[None, None]
|
|
202
|
+
)
|
|
203
|
+
else:
|
|
204
|
+
self.border = jnp.concatenate(
|
|
205
|
+
[boundary_times, self.omega_border], axis=1
|
|
206
|
+
)
|
|
196
207
|
else:
|
|
197
|
-
self.border =
|
|
198
|
-
|
|
208
|
+
self.key, self.border = self.qmc_in_time_omega_border_domain(
|
|
209
|
+
self.key,
|
|
210
|
+
self.nb, # type: ignore (see inside the fun)
|
|
199
211
|
)
|
|
212
|
+
|
|
200
213
|
if self.border_batch_size is None:
|
|
201
214
|
self.curr_border_idx = 0
|
|
202
215
|
else:
|
|
@@ -209,14 +222,30 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
|
|
|
209
222
|
self.curr_border_idx = 0
|
|
210
223
|
|
|
211
224
|
if self.ni is not None:
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
225
|
+
if self.method == "grid":
|
|
226
|
+
perfect_sq = int(jnp.round(jnp.sqrt(self.ni)) ** 2)
|
|
227
|
+
if self.ni != perfect_sq:
|
|
228
|
+
warnings.warn(
|
|
229
|
+
"Grid sampling is requested in dimension 2 with a non"
|
|
230
|
+
f" perfect square dataset size (self.ni = {self.ni})."
|
|
231
|
+
f" Modifying self.ni to self.ni = {perfect_sq}."
|
|
232
|
+
)
|
|
233
|
+
self.ni = perfect_sq
|
|
234
|
+
if self.method in ["sobol", "halton"]:
|
|
235
|
+
log2_n = jnp.log2(self.ni)
|
|
236
|
+
lower_pow = 2 ** jnp.floor(log2_n)
|
|
237
|
+
higher_pow = 2 ** jnp.ceil(log2_n)
|
|
238
|
+
closest_two_power = (
|
|
239
|
+
lower_pow
|
|
240
|
+
if (self.ni - lower_pow) < (higher_pow - self.ni)
|
|
241
|
+
else higher_pow
|
|
218
242
|
)
|
|
219
|
-
|
|
243
|
+
if self.n != closest_two_power:
|
|
244
|
+
warnings.warn(
|
|
245
|
+
f"QuasiMonteCarlo sampling with {self.method} requires sample size to be a power fo 2."
|
|
246
|
+
f"Modfiying self.n from {self.ni} to {closest_two_power}.",
|
|
247
|
+
)
|
|
248
|
+
self.ni = int(closest_two_power)
|
|
220
249
|
self.key, self.initial = self.generate_omega_data(
|
|
221
250
|
self.key, data_size=self.ni
|
|
222
251
|
)
|
|
@@ -245,16 +274,115 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
|
|
|
245
274
|
if self.method == "grid":
|
|
246
275
|
partial_times = (self.tmax - self.tmin) / nt
|
|
247
276
|
return key, jnp.arange(self.tmin, self.tmax, partial_times)[:, None]
|
|
248
|
-
|
|
277
|
+
elif self.method in ["uniform", "sobol", "halton"]:
|
|
249
278
|
return key, self.sample_in_time_domain(subkey, nt)
|
|
250
279
|
raise ValueError("Method " + self.method + " is not implemented.")
|
|
251
280
|
|
|
252
281
|
def sample_in_time_domain(self, key: Key, nt: int) -> Float[Array, " nt 1"]:
|
|
253
|
-
return jax.random.uniform(
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
282
|
+
return jax.random.uniform(key, (nt, 1), minval=self.tmin, maxval=self.tmax)
|
|
283
|
+
|
|
284
|
+
def qmc_in_time_omega_domain(
|
|
285
|
+
self, key: Key, sample_size: int
|
|
286
|
+
) -> tuple[Key, Float[Array, "n 1+dim"]]:
|
|
287
|
+
"""
|
|
288
|
+
Because in Quasi-Monte Carlo sampling we cannot concatenate two vectors generated independently
|
|
289
|
+
We generate time and omega samples jointly
|
|
290
|
+
"""
|
|
291
|
+
key, subkey = jax.random.split(key, 2)
|
|
292
|
+
qmc_generator = qmc.Sobol if self.method == "sobol" else qmc.Halton
|
|
293
|
+
sampler = qmc_generator(
|
|
294
|
+
d=self.dim + 1, scramble=True, rng=np.random.default_rng(np.uint32(subkey))
|
|
295
|
+
)
|
|
296
|
+
samples = sampler.random(n=sample_size)
|
|
297
|
+
samples[:, 1:] = qmc.scale(
|
|
298
|
+
samples[:, 1:], l_bounds=self.min_pts, u_bounds=self.max_pts
|
|
299
|
+
) # We scale omega domain to be in (min_pts, max_pts)
|
|
300
|
+
return key, jnp.array(samples)
|
|
301
|
+
|
|
302
|
+
def qmc_in_time_omega_border_domain(
|
|
303
|
+
self, key: Key, sample_size: int | None = None
|
|
304
|
+
) -> tuple[Key, Float[Array, "n 1+dim"]] | None:
|
|
305
|
+
"""
|
|
306
|
+
For each facet of the border we generate Quasi-Monte Carlo sequences jointly with time.
|
|
307
|
+
|
|
308
|
+
We need to do some type ignore in this function because we have lost
|
|
309
|
+
the type narrowing from post_init: type checkers only narrow at function level, and we cannot narrow a class attribute.
|
|
310
|
+
"""
|
|
311
|
+
qmc_generator = qmc.Sobol if self.method == "sobol" else qmc.Halton
|
|
312
|
+
sample_size = self.nb if sample_size is None else sample_size
|
|
313
|
+
if sample_size is None:
|
|
314
|
+
return None
|
|
315
|
+
if self.dim == 1:
|
|
316
|
+
key, subkey = jax.random.split(key, 2)
|
|
317
|
+
qmc_seq = qmc_generator(
|
|
318
|
+
d=1, scramble=True, rng=np.random.default_rng(np.uint32(subkey))
|
|
319
|
+
)
|
|
320
|
+
boundary_times = jnp.array(
|
|
321
|
+
qmc_seq.random(self.nb // (2 * self.dim)) # type: ignore
|
|
322
|
+
)
|
|
323
|
+
boundary_times = boundary_times.reshape(-1, 1, 1)
|
|
324
|
+
boundary_times = jnp.repeat(
|
|
325
|
+
boundary_times,
|
|
326
|
+
self.omega_border.shape[-1], # type: ignore
|
|
327
|
+
axis=2,
|
|
328
|
+
)
|
|
329
|
+
return key, make_cartesian_product(
|
|
330
|
+
boundary_times,
|
|
331
|
+
self.omega_border[None, None], # type: ignore
|
|
332
|
+
)
|
|
333
|
+
if self.dim == 2:
|
|
334
|
+
# currently hard-coded the 4 edges for d==2
|
|
335
|
+
# TODO : find a general & efficient way to sample from the border
|
|
336
|
+
# (facets) of the hypercube in general dim.
|
|
337
|
+
key, *subkeys = jax.random.split(key, 5)
|
|
338
|
+
facet_n = sample_size // (2 * self.dim)
|
|
339
|
+
|
|
340
|
+
def generate_qmc_sample(key, min_val, max_val):
|
|
341
|
+
qmc_seq = qmc_generator(
|
|
342
|
+
d=2,
|
|
343
|
+
scramble=True,
|
|
344
|
+
rng=np.random.default_rng(np.uint32(key)),
|
|
345
|
+
)
|
|
346
|
+
u = qmc_seq.random(n=facet_n)
|
|
347
|
+
u[:, 1:2] = qmc.scale(u[:, 1:2], l_bounds=min_val, u_bounds=max_val)
|
|
348
|
+
return jnp.array(u)
|
|
349
|
+
|
|
350
|
+
xmin_sample = generate_qmc_sample(
|
|
351
|
+
subkeys[0], self.min_pts[1], self.max_pts[1]
|
|
352
|
+
) # [t,x,y]
|
|
353
|
+
xmin = jnp.hstack(
|
|
354
|
+
[
|
|
355
|
+
xmin_sample[:, 0:1],
|
|
356
|
+
self.min_pts[0] * jnp.ones((facet_n, 1)),
|
|
357
|
+
xmin_sample[:, 1:2],
|
|
358
|
+
]
|
|
359
|
+
)
|
|
360
|
+
xmax_sample = generate_qmc_sample(
|
|
361
|
+
subkeys[1], self.min_pts[1], self.max_pts[1]
|
|
362
|
+
)
|
|
363
|
+
xmax = jnp.hstack(
|
|
364
|
+
[
|
|
365
|
+
xmax_sample[:, 0:1],
|
|
366
|
+
self.max_pts[0] * jnp.ones((facet_n, 1)),
|
|
367
|
+
xmax_sample[:, 1:2],
|
|
368
|
+
]
|
|
369
|
+
)
|
|
370
|
+
ymin = jnp.hstack(
|
|
371
|
+
[
|
|
372
|
+
generate_qmc_sample(subkeys[2], self.min_pts[0], self.max_pts[0]),
|
|
373
|
+
self.min_pts[1] * jnp.ones((facet_n, 1)),
|
|
374
|
+
]
|
|
375
|
+
)
|
|
376
|
+
ymax = jnp.hstack(
|
|
377
|
+
[
|
|
378
|
+
generate_qmc_sample(subkeys[3], self.min_pts[0], self.max_pts[0]),
|
|
379
|
+
self.max_pts[1] * jnp.ones((facet_n, 1)),
|
|
380
|
+
]
|
|
381
|
+
)
|
|
382
|
+
return key, jnp.stack([xmin, xmax, ymin, ymax], axis=-1)
|
|
383
|
+
raise NotImplementedError(
|
|
384
|
+
"Generation of the border of a cube in dimension > 2 is not "
|
|
385
|
+
+ f"implemented yet. You are asking for generation in dimension d={self.dim}."
|
|
258
386
|
)
|
|
259
387
|
|
|
260
388
|
def _get_domain_operands(
|
|
@@ -8,8 +8,11 @@ from __future__ import (
|
|
|
8
8
|
import warnings
|
|
9
9
|
import equinox as eqx
|
|
10
10
|
import jax
|
|
11
|
+
import numpy as np
|
|
11
12
|
import jax.numpy as jnp
|
|
13
|
+
from scipy.stats import qmc
|
|
12
14
|
from jaxtyping import Key, Array, Float
|
|
15
|
+
from typing import Literal
|
|
13
16
|
from jinns.data._Batchs import PDEStatioBatch
|
|
14
17
|
from jinns.data._utils import _check_and_set_rar_parameters, _reset_or_increment
|
|
15
18
|
from jinns.data._AbstractDataGenerator import AbstractDataGenerator
|
|
@@ -50,11 +53,13 @@ class CubicMeshPDEStatio(AbstractDataGenerator):
|
|
|
50
53
|
A tuple of maximum values of the domain along each dimension. For a sampling
|
|
51
54
|
in `n` dimension, this represents $(x_{1, max}, x_{2,max}, ...,
|
|
52
55
|
x_{n,max})$
|
|
53
|
-
method :
|
|
54
|
-
Either
|
|
56
|
+
method : Literal["grid", "uniform", "sobol", "halton"], default="uniform"
|
|
57
|
+
Either "grid", "uniform", "sobol" or "halton", default is `uniform`.
|
|
55
58
|
The method that generates the `nt` time points. `grid` means
|
|
56
59
|
regularly spaced points over the domain. `uniform` means uniformly
|
|
57
|
-
sampled points over the domain
|
|
60
|
+
sampled points over the domain.
|
|
61
|
+
**Note** that Sobol and Halton approaches use scipy modules and will not
|
|
62
|
+
be JIT compatible.
|
|
58
63
|
rar_parameters : dict[str, int], default=None
|
|
59
64
|
Defaults to None: do not use Residual Adaptative Resampling.
|
|
60
65
|
Otherwise a dictionary with keys
|
|
@@ -94,7 +99,7 @@ class CubicMeshPDEStatio(AbstractDataGenerator):
|
|
|
94
99
|
# shape in jax.lax.dynamic_slice
|
|
95
100
|
min_pts: tuple[float, ...] = eqx.field(kw_only=True)
|
|
96
101
|
max_pts: tuple[float, ...] = eqx.field(kw_only=True)
|
|
97
|
-
method:
|
|
102
|
+
method: Literal["grid", "uniform", "sobol", "halton"] = eqx.field(
|
|
98
103
|
kw_only=True, static=True, default_factory=lambda: "uniform"
|
|
99
104
|
)
|
|
100
105
|
rar_parameters: dict[str, int] = eqx.field(kw_only=True, default=None)
|
|
@@ -132,6 +137,22 @@ class CubicMeshPDEStatio(AbstractDataGenerator):
|
|
|
132
137
|
)
|
|
133
138
|
self.n = perfect_sq
|
|
134
139
|
|
|
140
|
+
if self.method in ["sobol", "halton"]:
|
|
141
|
+
log2_n = jnp.log2(self.n)
|
|
142
|
+
lower_pow = 2 ** jnp.floor(log2_n)
|
|
143
|
+
higher_pow = 2 ** jnp.ceil(log2_n)
|
|
144
|
+
closest_two_power = (
|
|
145
|
+
lower_pow
|
|
146
|
+
if (self.n - lower_pow) < (higher_pow - self.n)
|
|
147
|
+
else higher_pow
|
|
148
|
+
)
|
|
149
|
+
if self.n != closest_two_power:
|
|
150
|
+
warnings.warn(
|
|
151
|
+
f"QuasiMonteCarlo sampling with {self.method} requires sample size to be a power fo 2."
|
|
152
|
+
f"Modfiying self.n from {self.n} to {closest_two_power}.",
|
|
153
|
+
)
|
|
154
|
+
self.n = int(closest_two_power)
|
|
155
|
+
|
|
135
156
|
if self.omega_batch_size is None:
|
|
136
157
|
self.curr_omega_idx = 0
|
|
137
158
|
else:
|
|
@@ -176,24 +197,48 @@ class CubicMeshPDEStatio(AbstractDataGenerator):
|
|
|
176
197
|
def sample_in_omega_domain(
|
|
177
198
|
self, keys: Key, sample_size: int
|
|
178
199
|
) -> Float[Array, " n dim"]:
|
|
200
|
+
if self.method == "uniform":
|
|
201
|
+
if self.dim == 1:
|
|
202
|
+
xmin, xmax = self.min_pts[0], self.max_pts[0]
|
|
203
|
+
return jax.random.uniform(
|
|
204
|
+
keys, shape=(sample_size, 1), minval=xmin, maxval=xmax
|
|
205
|
+
)
|
|
206
|
+
|
|
207
|
+
return jnp.concatenate(
|
|
208
|
+
[
|
|
209
|
+
jax.random.uniform(
|
|
210
|
+
keys[i],
|
|
211
|
+
(sample_size, 1),
|
|
212
|
+
minval=self.min_pts[i],
|
|
213
|
+
maxval=self.max_pts[i],
|
|
214
|
+
)
|
|
215
|
+
for i in range(self.dim)
|
|
216
|
+
],
|
|
217
|
+
axis=-1,
|
|
218
|
+
)
|
|
219
|
+
else:
|
|
220
|
+
return self._qmc_in_omega_domain(keys, sample_size)
|
|
221
|
+
|
|
222
|
+
def _qmc_in_omega_domain(
|
|
223
|
+
self, subkey: Key, sample_size: int
|
|
224
|
+
) -> Float[Array, "n dim"]:
|
|
225
|
+
qmc_generator = qmc.Sobol if self.method == "sobol" else qmc.Halton
|
|
179
226
|
if self.dim == 1:
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
227
|
+
qmc_seq = qmc_generator(
|
|
228
|
+
d=self.dim,
|
|
229
|
+
scramble=True,
|
|
230
|
+
rng=np.random.default_rng(np.uint32(subkey)),
|
|
183
231
|
)
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
minval=self.min_pts[i],
|
|
191
|
-
maxval=self.max_pts[i],
|
|
192
|
-
)
|
|
193
|
-
for i in range(self.dim)
|
|
194
|
-
],
|
|
195
|
-
axis=-1,
|
|
232
|
+
u = qmc_seq.random(n=sample_size)
|
|
233
|
+
return jnp.array(
|
|
234
|
+
qmc.scale(u, l_bounds=self.min_pts[0], u_bounds=self.max_pts[0])
|
|
235
|
+
)
|
|
236
|
+
sampler = qmc.Sobol(
|
|
237
|
+
d=self.dim, scramble=True, rng=np.random.default_rng(np.uint32(subkey))
|
|
196
238
|
)
|
|
239
|
+
samples = sampler.random(n=sample_size)
|
|
240
|
+
samples = qmc.scale(samples, l_bounds=self.min_pts, u_bounds=self.max_pts)
|
|
241
|
+
return jnp.array(samples)
|
|
197
242
|
|
|
198
243
|
def sample_in_omega_border_domain(
|
|
199
244
|
self, keys: Key, sample_size: int | None = None
|
|
@@ -260,6 +305,62 @@ class CubicMeshPDEStatio(AbstractDataGenerator):
|
|
|
260
305
|
+ f"implemented yet. You are asking for generation in dimension d={self.dim}."
|
|
261
306
|
)
|
|
262
307
|
|
|
308
|
+
def qmc_in_omega_border_domain(
|
|
309
|
+
self, keys: Key, sample_size: int | None = None
|
|
310
|
+
) -> Float[Array, " 1 2"] | Float[Array, " (nb//4) 2 4"] | None:
|
|
311
|
+
qmc_generator = qmc.Sobol if self.method == "sobol" else qmc.Halton
|
|
312
|
+
sample_size = self.nb if sample_size is None else sample_size
|
|
313
|
+
if sample_size is None:
|
|
314
|
+
return None
|
|
315
|
+
if self.dim == 1:
|
|
316
|
+
xmin = self.min_pts[0]
|
|
317
|
+
xmax = self.max_pts[0]
|
|
318
|
+
return jnp.array([xmin, xmax]).astype(float)
|
|
319
|
+
if self.dim == 2:
|
|
320
|
+
# currently hard-coded the 4 edges for d==2
|
|
321
|
+
# TODO : find a general & efficient way to sample from the border
|
|
322
|
+
# (facets) of the hypercube in general dim.
|
|
323
|
+
facet_n = sample_size // (2 * self.dim)
|
|
324
|
+
|
|
325
|
+
def generate_qmc_sample(key, min_val, max_val):
|
|
326
|
+
qmc_seq = qmc_generator(
|
|
327
|
+
d=1,
|
|
328
|
+
scramble=True,
|
|
329
|
+
rng=np.random.default_rng(np.uint32(key)),
|
|
330
|
+
)
|
|
331
|
+
u = qmc_seq.random(n=facet_n)
|
|
332
|
+
return jnp.array(qmc.scale(u, l_bounds=min_val, u_bounds=max_val))
|
|
333
|
+
|
|
334
|
+
xmin = jnp.hstack(
|
|
335
|
+
[
|
|
336
|
+
self.min_pts[0] * jnp.ones((facet_n, 1)),
|
|
337
|
+
generate_qmc_sample(keys[0], self.min_pts[1], self.max_pts[1]),
|
|
338
|
+
]
|
|
339
|
+
)
|
|
340
|
+
xmax = jnp.hstack(
|
|
341
|
+
[
|
|
342
|
+
self.max_pts[0] * jnp.ones((facet_n, 1)),
|
|
343
|
+
generate_qmc_sample(keys[1], self.min_pts[1], self.max_pts[1]),
|
|
344
|
+
]
|
|
345
|
+
)
|
|
346
|
+
ymin = jnp.hstack(
|
|
347
|
+
[
|
|
348
|
+
generate_qmc_sample(keys[2], self.min_pts[0], self.max_pts[0]),
|
|
349
|
+
self.min_pts[1] * jnp.ones((facet_n, 1)),
|
|
350
|
+
]
|
|
351
|
+
)
|
|
352
|
+
ymax = jnp.hstack(
|
|
353
|
+
[
|
|
354
|
+
generate_qmc_sample(keys[3], self.min_pts[0], self.max_pts[0]),
|
|
355
|
+
self.max_pts[1] * jnp.ones((facet_n, 1)),
|
|
356
|
+
]
|
|
357
|
+
)
|
|
358
|
+
return jnp.stack([xmin, xmax, ymin, ymax], axis=-1)
|
|
359
|
+
raise NotImplementedError(
|
|
360
|
+
"Generation of the border of a cube in dimension > 2 is not "
|
|
361
|
+
+ f"implemented yet. You are asking for generation in dimension d={self.dim}."
|
|
362
|
+
)
|
|
363
|
+
|
|
263
364
|
def generate_omega_data(
|
|
264
365
|
self, key: Key, data_size: int | None = None
|
|
265
366
|
) -> tuple[
|
|
@@ -290,8 +391,8 @@ class CubicMeshPDEStatio(AbstractDataGenerator):
|
|
|
290
391
|
)
|
|
291
392
|
xyz_ = [a.reshape((data_size, 1)) for a in xyz_]
|
|
292
393
|
omega = jnp.concatenate(xyz_, axis=-1)
|
|
293
|
-
elif self.method
|
|
294
|
-
if self.dim == 1:
|
|
394
|
+
elif self.method in ["uniform", "sobol", "halton"]:
|
|
395
|
+
if self.dim == 1 or self.method in ["sobol", "halton"]:
|
|
295
396
|
key, subkeys = jax.random.split(key, 2)
|
|
296
397
|
else:
|
|
297
398
|
key, *subkeys = jax.random.split(key, self.dim + 1)
|
|
@@ -317,10 +418,17 @@ class CubicMeshPDEStatio(AbstractDataGenerator):
|
|
|
317
418
|
key, *subkeys = jax.random.split(key, 5)
|
|
318
419
|
else:
|
|
319
420
|
subkeys = None
|
|
320
|
-
omega_border = self.sample_in_omega_border_domain(
|
|
321
|
-
subkeys, sample_size=data_size
|
|
322
|
-
)
|
|
323
421
|
|
|
422
|
+
if self.method in ["grid", "uniform"]:
|
|
423
|
+
omega_border = self.sample_in_omega_border_domain(
|
|
424
|
+
subkeys, sample_size=data_size
|
|
425
|
+
)
|
|
426
|
+
elif self.method in ["sobol", "halton"]:
|
|
427
|
+
omega_border = self.qmc_in_omega_border_domain(
|
|
428
|
+
subkeys, sample_size=data_size
|
|
429
|
+
)
|
|
430
|
+
else:
|
|
431
|
+
raise ValueError("Method " + self.method + " is not implemented.")
|
|
324
432
|
return key, omega_border
|
|
325
433
|
|
|
326
434
|
def _get_omega_operands(
|
jinns/loss/_LossODE.py
CHANGED
|
@@ -19,6 +19,7 @@ from jaxtyping import Float, Array
|
|
|
19
19
|
from jinns.loss._loss_utils import (
|
|
20
20
|
dynamic_loss_apply,
|
|
21
21
|
observations_loss_apply,
|
|
22
|
+
initial_condition_check,
|
|
22
23
|
)
|
|
23
24
|
from jinns.parameters._params import (
|
|
24
25
|
_get_vmap_in_axes_params,
|
|
@@ -43,7 +44,7 @@ if TYPE_CHECKING:
|
|
|
43
44
|
|
|
44
45
|
|
|
45
46
|
class _LossODEAbstract(AbstractLoss):
|
|
46
|
-
"""
|
|
47
|
+
r"""
|
|
47
48
|
Parameters
|
|
48
49
|
----------
|
|
49
50
|
|
|
@@ -60,8 +61,15 @@ class _LossODEAbstract(AbstractLoss):
|
|
|
60
61
|
Fields can be "nn_params", "eq_params" or "both". Those that should not
|
|
61
62
|
be updated will have a `jax.lax.stop_gradient` called on them. Default
|
|
62
63
|
is `"nn_params"` for each composant of the loss.
|
|
63
|
-
initial_condition : tuple[
|
|
64
|
-
|
|
64
|
+
initial_condition : tuple[
|
|
65
|
+
Float[Array, "n_cond 1"],
|
|
66
|
+
Float[Array, "n_cond dim"]
|
|
67
|
+
] |
|
|
68
|
+
tuple[int | float | Float[Array, " "],
|
|
69
|
+
int | float | Float[Array, " dim"]
|
|
70
|
+
] | None, default=None
|
|
71
|
+
Most of the time, a tuple of length 2 with initial condition $(t_0, u_0)$.
|
|
72
|
+
From jinns v1.5.1 we accept tuples of jnp arrays with shape (n_cond, 1) for t0 and (n_cond, dim) for u0. This is useful to include observed conditions at different time points, such as final conditions. It was designed to implement $\mathcal{L}^{aux}$ from _Systems biology informed deep learning for inferring parameters and hidden dynamics_, Alireza Yazdani et al., 2020
|
|
65
73
|
obs_slice : EllipsisType | slice | None, default=None
|
|
66
74
|
Slice object specifying the begininning/ending
|
|
67
75
|
slice of u output(s) that is observed. This is useful for
|
|
@@ -78,7 +86,9 @@ class _LossODEAbstract(AbstractLoss):
|
|
|
78
86
|
derivative_keys: DerivativeKeysODE | None = eqx.field(kw_only=True, default=None)
|
|
79
87
|
loss_weights: LossWeightsODE | None = eqx.field(kw_only=True, default=None)
|
|
80
88
|
initial_condition: (
|
|
81
|
-
tuple[
|
|
89
|
+
tuple[Float[Array, " n_cond 1"], Float[Array, " n_cond dim"]]
|
|
90
|
+
| tuple[int | float | Float[Array, " "], int | float | Float[Array, " dim"]]
|
|
91
|
+
| None
|
|
82
92
|
) = eqx.field(kw_only=True, default=None)
|
|
83
93
|
obs_slice: EllipsisType | slice | None = eqx.field(
|
|
84
94
|
kw_only=True, default=None, static=True
|
|
@@ -112,20 +122,60 @@ class _LossODEAbstract(AbstractLoss):
|
|
|
112
122
|
"Initial condition should be a tuple of len 2 with (t0, u0), "
|
|
113
123
|
f"{self.initial_condition} was passed."
|
|
114
124
|
)
|
|
115
|
-
# some checks/reshaping for t0
|
|
125
|
+
# some checks/reshaping for t0 and u0
|
|
116
126
|
t0, u0 = self.initial_condition
|
|
117
127
|
if isinstance(t0, Array):
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
128
|
+
# at the end we want to end up with t0 of shape (:, 1) to account for
|
|
129
|
+
# possibly several data points
|
|
130
|
+
if t0.ndim <= 1:
|
|
131
|
+
# in this case we assume t0 belongs to one (initial)
|
|
132
|
+
# condition
|
|
133
|
+
t0 = initial_condition_check(t0, dim_size=1)[
|
|
134
|
+
None, :
|
|
135
|
+
] # make a (1, 1) here
|
|
136
|
+
if t0.ndim > 2:
|
|
137
|
+
raise ValueError(
|
|
138
|
+
"It t0 is an Array, it represents n_cond"
|
|
139
|
+
" imposed conditions and must be of shape (n_cond, 1)"
|
|
140
|
+
)
|
|
141
|
+
else:
|
|
142
|
+
# in this case t0 clearly represents one (initial) condition
|
|
143
|
+
t0 = initial_condition_check(t0, dim_size=1)[
|
|
144
|
+
None, :
|
|
145
|
+
] # make a (1, 1) here
|
|
146
|
+
if isinstance(u0, Array):
|
|
147
|
+
# at the end we want to end up with u0 of shape (:, dim) to account for
|
|
148
|
+
# possibly several data points
|
|
149
|
+
if not u0.shape:
|
|
150
|
+
# in this case we assume u0 belongs to one (initial)
|
|
151
|
+
# condition
|
|
152
|
+
u0 = initial_condition_check(u0, dim_size=1)[
|
|
153
|
+
None, :
|
|
154
|
+
] # make a (1, 1) here
|
|
155
|
+
elif u0.ndim == 1:
|
|
156
|
+
# in this case we assume u0 belongs to one (initial)
|
|
157
|
+
# condition
|
|
158
|
+
u0 = initial_condition_check(u0, dim_size=u0.shape[0])[
|
|
159
|
+
None, :
|
|
160
|
+
] # make a (1, dim) here
|
|
161
|
+
if u0.ndim > 2:
|
|
121
162
|
raise ValueError(
|
|
122
|
-
|
|
123
|
-
|
|
163
|
+
"It u0 is an Array, it represents n_cond "
|
|
164
|
+
"imposed conditions and must be of shape (n_cond, dim)"
|
|
124
165
|
)
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
166
|
+
else:
|
|
167
|
+
# at the end we want to end up with u0 of shape (:, dim) to account for
|
|
168
|
+
# possibly several data points
|
|
169
|
+
u0 = initial_condition_check(u0, dim_size=None)[
|
|
170
|
+
None, :
|
|
171
|
+
] # make a (1, 1) here
|
|
172
|
+
|
|
173
|
+
if t0.shape[0] != u0.shape[0] or t0.ndim != u0.ndim:
|
|
174
|
+
raise ValueError(
|
|
175
|
+
"t0 and u0 must represent a same number of initial"
|
|
176
|
+
" conditial conditions"
|
|
177
|
+
)
|
|
178
|
+
|
|
129
179
|
self.initial_condition = (t0, u0)
|
|
130
180
|
|
|
131
181
|
if self.obs_slice is None:
|
|
@@ -259,28 +309,42 @@ class LossODE(_LossODEAbstract):
|
|
|
259
309
|
|
|
260
310
|
# initial condition
|
|
261
311
|
if self.initial_condition is not None:
|
|
262
|
-
vmap_in_axes = (None,) + vmap_in_axes_params
|
|
263
|
-
if not jax.tree_util.tree_leaves(vmap_in_axes):
|
|
264
|
-
# test if only None in vmap_in_axes to avoid the value error:
|
|
265
|
-
# `vmap must have at least one non-None value in in_axes`
|
|
266
|
-
v_u = self.u
|
|
267
|
-
else:
|
|
268
|
-
v_u = vmap(self.u, (None,) + vmap_in_axes_params)
|
|
269
312
|
t0, u0 = self.initial_condition
|
|
270
313
|
u0 = jnp.array(u0)
|
|
271
|
-
|
|
272
|
-
|
|
273
|
-
|
|
274
|
-
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
|
|
278
|
-
|
|
314
|
+
|
|
315
|
+
# first construct the plain init loss without vmapping
|
|
316
|
+
initial_condition_fun__ = lambda t, u, p: jnp.sum(
|
|
317
|
+
(
|
|
318
|
+
self.u(
|
|
319
|
+
t,
|
|
320
|
+
_set_derivatives(
|
|
321
|
+
p,
|
|
322
|
+
self.derivative_keys.initial_condition, # type: ignore
|
|
323
|
+
),
|
|
279
324
|
)
|
|
280
|
-
|
|
281
|
-
axis=-1,
|
|
325
|
+
- u
|
|
282
326
|
)
|
|
327
|
+
** 2,
|
|
328
|
+
axis=0,
|
|
283
329
|
)
|
|
330
|
+
# now vmap over the number of conditions (first dim of t0 and u0)
|
|
331
|
+
# and take the mean
|
|
332
|
+
initial_condition_fun_ = lambda p: jnp.mean(
|
|
333
|
+
vmap(initial_condition_fun__, (0, 0, None))(t0, u0, p)
|
|
334
|
+
)
|
|
335
|
+
# now vmap over the possible batch of parameters and take the
|
|
336
|
+
# average. Note that we then finally have a cartesian product
|
|
337
|
+
# between the batch of parameters (if any) and the number of
|
|
338
|
+
# conditions (if any)
|
|
339
|
+
if not jax.tree_util.tree_leaves(vmap_in_axes_params):
|
|
340
|
+
# if there is no parameter batch to vmap over we cannot call
|
|
341
|
+
# vmap because calling vmap must be done with at least one non
|
|
342
|
+
# None in_axes or out_axes
|
|
343
|
+
initial_condition_fun = initial_condition_fun_
|
|
344
|
+
else:
|
|
345
|
+
initial_condition_fun = lambda p: jnp.mean(
|
|
346
|
+
vmap(initial_condition_fun_, vmap_in_axes_params)(p)
|
|
347
|
+
)
|
|
284
348
|
else:
|
|
285
349
|
initial_condition_fun = None
|
|
286
350
|
|
jinns/loss/_LossPDE.py
CHANGED
|
@@ -21,6 +21,7 @@ from jinns.loss._loss_utils import (
|
|
|
21
21
|
normalization_loss_apply,
|
|
22
22
|
observations_loss_apply,
|
|
23
23
|
initial_condition_apply,
|
|
24
|
+
initial_condition_check,
|
|
24
25
|
)
|
|
25
26
|
from jinns.parameters._params import (
|
|
26
27
|
_get_vmap_in_axes_params,
|
|
@@ -700,22 +701,12 @@ class LossPDENonStatio(LossPDEStatio):
|
|
|
700
701
|
"case (e.g by. hardcoding it into the PINN output)."
|
|
701
702
|
)
|
|
702
703
|
# some checks for t0
|
|
703
|
-
|
|
704
|
-
|
|
705
|
-
|
|
706
|
-
elif self.t0.shape != (1,):
|
|
707
|
-
raise ValueError(
|
|
708
|
-
f"Wrong self.t0 input. It should be"
|
|
709
|
-
f"a float or an array of shape (1,). Got shape: {self.t0.shape}"
|
|
710
|
-
)
|
|
711
|
-
elif isinstance(self.t0, float): # e.g. user input: 0.
|
|
712
|
-
self.t0 = jnp.array([self.t0])
|
|
713
|
-
elif isinstance(self.t0, int): # e.g. user input: 0
|
|
714
|
-
self.t0 = jnp.array([float(self.t0)])
|
|
715
|
-
elif self.t0 is None:
|
|
716
|
-
self.t0 = jnp.array([0])
|
|
704
|
+
t0 = self.t0
|
|
705
|
+
if t0 is None:
|
|
706
|
+
t0 = jnp.array([0])
|
|
717
707
|
else:
|
|
718
|
-
|
|
708
|
+
t0 = initial_condition_check(t0, dim_size=1)
|
|
709
|
+
self.t0 = t0
|
|
719
710
|
|
|
720
711
|
# with the variables below we avoid memory overflow since a cartesian
|
|
721
712
|
# product is taken
|
jinns/loss/_loss_utils.py
CHANGED
|
@@ -308,3 +308,26 @@ def initial_condition_apply(
|
|
|
308
308
|
else:
|
|
309
309
|
raise ValueError(f"Bad type for u. Got {type(u)}, expected PINN or SPINN")
|
|
310
310
|
return mse_initial_condition
|
|
311
|
+
|
|
312
|
+
|
|
313
|
+
def initial_condition_check(x, dim_size=None):
|
|
314
|
+
"""
|
|
315
|
+
Make a (dim_size,) jnp array from an int, a float or a 0D jnp array
|
|
316
|
+
|
|
317
|
+
"""
|
|
318
|
+
if isinstance(x, Array):
|
|
319
|
+
if not x.shape: # e.g. user input: jnp.array(0.)
|
|
320
|
+
x = jnp.array([x])
|
|
321
|
+
if dim_size is not None: # we check for the required dim_size
|
|
322
|
+
if x.shape != (dim_size,):
|
|
323
|
+
raise ValueError(
|
|
324
|
+
f"Wrong dim_size. It should be({dim_size},). Got shape: {x.shape}"
|
|
325
|
+
)
|
|
326
|
+
|
|
327
|
+
elif isinstance(x, float): # e.g. user input: 0.
|
|
328
|
+
x = jnp.array([x])
|
|
329
|
+
elif isinstance(x, int): # e.g. user input: 0
|
|
330
|
+
x = jnp.array([float(x)])
|
|
331
|
+
else:
|
|
332
|
+
raise ValueError(f"Wrong value, expected Array, float or int, got {type(x)}")
|
|
333
|
+
return x
|
jinns/parameters/_params.py
CHANGED
|
@@ -10,6 +10,14 @@ from jaxtyping import Array, PyTree, Float
|
|
|
10
10
|
T = TypeVar("T") # the generic type for what is in the Params PyTree because we
|
|
11
11
|
# have possibly Params of Arrays, boolean, ...
|
|
12
12
|
|
|
13
|
+
### NOTE
|
|
14
|
+
### We are taking derivatives with respect to Params eqx.Modules.
|
|
15
|
+
### This has been shown to behave weirdly if some fields of eqx.Modules have
|
|
16
|
+
### been set as `field(init=False)`, we then should never create such fields in
|
|
17
|
+
### jinns' Params modules.
|
|
18
|
+
### We currently have silenced the warning related to this (see jinns.__init__)
|
|
19
|
+
### see https://github.com/patrick-kidger/equinox/pull/1043/commits/f88e62ab809140334c2f987ed13eff0d80b8be13
|
|
20
|
+
|
|
13
21
|
|
|
14
22
|
class Params(eqx.Module, Generic[T]):
|
|
15
23
|
"""
|
jinns/solver/_solve.py
CHANGED
|
@@ -179,6 +179,7 @@ def solve(
|
|
|
179
179
|
best_val_params
|
|
180
180
|
The best parameters according to the validation criterion
|
|
181
181
|
"""
|
|
182
|
+
initialization_time = time.time()
|
|
182
183
|
if n_iter < 1:
|
|
183
184
|
raise ValueError("Cannot run jinns.solve for n_iter<1")
|
|
184
185
|
|
|
@@ -225,11 +226,6 @@ def solve(
|
|
|
225
226
|
# get_batch with device_put, the latter is not jittable
|
|
226
227
|
get_batch = _get_get_batch(obs_batch_sharding)
|
|
227
228
|
|
|
228
|
-
# initialize the dict for stored parameter values
|
|
229
|
-
# we need to get a loss_term to init stuff
|
|
230
|
-
batch_ini, data, param_data, obs_data = get_batch(data, param_data, obs_data)
|
|
231
|
-
_, loss_terms = loss(init_params, batch_ini)
|
|
232
|
-
|
|
233
229
|
# initialize parameter tracking
|
|
234
230
|
if tracked_params is None:
|
|
235
231
|
tracked_params = jax.tree.map(lambda p: None, init_params)
|
|
@@ -247,6 +243,13 @@ def solve(
|
|
|
247
243
|
# being a complex data structure
|
|
248
244
|
)
|
|
249
245
|
|
|
246
|
+
# initialize the dict for stored parameter values
|
|
247
|
+
# we need to get a loss_term to init stuff
|
|
248
|
+
# NOTE: we use jax.eval_shape to avoid FLOPS since we only need the tree
|
|
249
|
+
# structure
|
|
250
|
+
batch_ini, data, param_data, obs_data = get_batch(data, param_data, obs_data)
|
|
251
|
+
_, loss_terms = jax.eval_shape(loss, init_params, batch_ini)
|
|
252
|
+
|
|
250
253
|
# initialize the PyTree for stored loss values
|
|
251
254
|
stored_loss_terms = jax.tree_util.tree_map(
|
|
252
255
|
lambda _: jnp.zeros((n_iter)), loss_terms
|
|
@@ -475,6 +478,9 @@ def solve(
|
|
|
475
478
|
key,
|
|
476
479
|
)
|
|
477
480
|
|
|
481
|
+
if verbose:
|
|
482
|
+
print("Initialization time:", time.time() - initialization_time)
|
|
483
|
+
|
|
478
484
|
# Main optimization loop. We use the LAX while loop (fully jitted) version
|
|
479
485
|
# if no mixing devices. Otherwise we use the standard while loop. Here devices only
|
|
480
486
|
# concern obs_batch, but it could lead to more complex scheme in the future
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: jinns
|
|
3
|
-
Version: 1.5.0
|
|
3
|
+
Version: 1.5.1
|
|
4
4
|
Summary: Physics Informed Neural Network with JAX
|
|
5
5
|
Author-email: Hugo Gangloff <hugo.gangloff@inrae.fr>, Nicolas Jouvin <nicolas.jouvin@inrae.fr>
|
|
6
6
|
Maintainer-email: Hugo Gangloff <hugo.gangloff@inrae.fr>, Nicolas Jouvin <nicolas.jouvin@inrae.fr>
|
|
@@ -1,8 +1,8 @@
|
|
|
1
|
-
jinns/__init__.py,sha256=
|
|
1
|
+
jinns/__init__.py,sha256=Tjp5z0Mnd1nscvJaXnxZb4lEYbI0cpN6OPgCZ-Swo74,453
|
|
2
2
|
jinns/data/_AbstractDataGenerator.py,sha256=O61TBOyeOFKwf1xqKzFD4KwCWRDnm2XgyJ-kKY9fmB4,557
|
|
3
3
|
jinns/data/_Batchs.py,sha256=-DlD6Qag3zs5QbKtKAOvOzV7JOpNOqAm_P8cwo1dIZg,1574
|
|
4
|
-
jinns/data/_CubicMeshPDENonStatio.py,sha256=
|
|
5
|
-
jinns/data/_CubicMeshPDEStatio.py,sha256=
|
|
4
|
+
jinns/data/_CubicMeshPDENonStatio.py,sha256=4f21SgeNQsGJTz8-uehduU0X9TibRcs28Iq49Kv4nQQ,22250
|
|
5
|
+
jinns/data/_CubicMeshPDEStatio.py,sha256=DVrP4qHVAJMu915EYH2PKyzwoG0nIMixAEKR2fz6C58,22525
|
|
6
6
|
jinns/data/_DataGeneratorODE.py,sha256=5RzUbQFEsooAZsocDw4wRgA_w5lJmDMuY4M6u79K-1c,7260
|
|
7
7
|
jinns/data/_DataGeneratorObservations.py,sha256=jknepLsJatSJHFq5lLMD-fFHkPGj5q286LEjE-vH24k,7738
|
|
8
8
|
jinns/data/_DataGeneratorParameter.py,sha256=IedX3jcOj7ZDW_18IAcRR75KVzQzo85z9SICIKDBJl4,8539
|
|
@@ -12,13 +12,13 @@ jinns/experimental/__init__.py,sha256=DT9e57zbjfzPeRnXemGUqnGd--MhV77FspChT0z4Yr
|
|
|
12
12
|
jinns/experimental/_diffrax_solver.py,sha256=upMr3kTTNrxEiSUO_oLvCXcjS9lPxSjvbB81h3qlhaU,6813
|
|
13
13
|
jinns/loss/_DynamicLoss.py,sha256=4mb7OCP-cGZ_mG2MQ-AniddDcuBT78p4bQI7rZpwte4,22722
|
|
14
14
|
jinns/loss/_DynamicLossAbstract.py,sha256=QhHRgvtcT-ifHlOxTyXbjDtHk9UfPN2Si8s3v9nEQm4,12672
|
|
15
|
-
jinns/loss/_LossODE.py,sha256=
|
|
16
|
-
jinns/loss/_LossPDE.py,sha256=
|
|
15
|
+
jinns/loss/_LossODE.py,sha256=iVYDojaI6Co7S5CrU67_niopD4Bk7UBTuLzDiTHoWMc,16996
|
|
16
|
+
jinns/loss/_LossPDE.py,sha256=VT56oQ_33fLq46lIch0slNsxu4d97eQBOgRAPeFESts,36401
|
|
17
17
|
jinns/loss/__init__.py,sha256=z5xYgBipNFf66__5BqQc6R_8r4F6A3TXL60YjsM8Osk,1287
|
|
18
18
|
jinns/loss/_abstract_loss.py,sha256=DMxn0SQe9PW-pq3p5Oqvb0YK3_ulLDOnoIXzK219GV4,4576
|
|
19
19
|
jinns/loss/_boundary_conditions.py,sha256=9HGw1cGLfmEilP4V4B2T0zl0YP1kNtrtXVLQNiBmWgc,12464
|
|
20
20
|
jinns/loss/_loss_components.py,sha256=MMzaGlaRqESPjRzT0j0WU9HAqWQSbIXpGAqM1xQCZHw,1106
|
|
21
|
-
jinns/loss/_loss_utils.py,sha256=
|
|
21
|
+
jinns/loss/_loss_utils.py,sha256=eJ4JcBm396LHx7Tti88ZQrLcKqVL1oSfFGT23VNkytQ,11949
|
|
22
22
|
jinns/loss/_loss_weight_updates.py,sha256=9Bwouh7shLyc_wrdzN6CYL0ZuQH81uEs-L6wCeiYFx8,6817
|
|
23
23
|
jinns/loss/_loss_weights.py,sha256=kII5WddORgeommFTudT3CSvhICpo6nSe47LclUgu_78,2429
|
|
24
24
|
jinns/loss/_operators.py,sha256=Ds5yRH7hu-jaGRp7PYbt821BgYuEvgWHufWhYgdMjw0,22909
|
|
@@ -34,12 +34,12 @@ jinns/nn/_spinn_mlp.py,sha256=uCL454sF0Tfj7KT-fdXPnvKJYRQOuq60N0r2b2VAB8Q,7606
|
|
|
34
34
|
jinns/nn/_utils.py,sha256=9UXz73iHKHVQYPBPIEitrHYJzJ14dspRwPfLA8avx0c,1120
|
|
35
35
|
jinns/parameters/__init__.py,sha256=O0n7y6R1LRmFzzugCxMFCMS2pgsuWSh-XHjfFViN_eg,265
|
|
36
36
|
jinns/parameters/_derivative_keys.py,sha256=YlLDX49PfYhr2Tj--t3praiD8JOUTZU6PTmjbNZsbMc,19173
|
|
37
|
-
jinns/parameters/_params.py,sha256=
|
|
37
|
+
jinns/parameters/_params.py,sha256=nv0WScbgUdmuC0bSF15VbnKypJ58pl6wynZAcYfuF6M,3081
|
|
38
38
|
jinns/plot/__init__.py,sha256=KPHX0Um4FbciZO1yD8kjZbkaT8tT964Y6SE2xCQ4eDU,135
|
|
39
39
|
jinns/plot/_plot.py,sha256=-A5auNeElaz2_8UzVQJQE4143ZFg0zgMjStU7kwttEY,11565
|
|
40
40
|
jinns/solver/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
41
41
|
jinns/solver/_rar.py,sha256=vSVTnCGCusI1vTZCvIkP2_G8we44G_42yZHx2sOK9DE,10291
|
|
42
|
-
jinns/solver/_solve.py,sha256=
|
|
42
|
+
jinns/solver/_solve.py,sha256=IsrmG2m48KkkYgvXYomlSbZ3hd1FySCj3rwlkovs-lI,28616
|
|
43
43
|
jinns/solver/_utils.py,sha256=sM2UbVzYyjw24l4QSIR3IlynJTPGD_S08r8v0lXMxA8,5876
|
|
44
44
|
jinns/utils/__init__.py,sha256=OEYWLCw8pKE7xoQREbd6SHvCjuw2QZHuVA6YwDcsBE8,53
|
|
45
45
|
jinns/utils/_containers.py,sha256=YShcrPKfj5_I9mn3NMAS4Ea9MhhyL7fjv0e3MRbITHg,1837
|
|
@@ -47,9 +47,9 @@ jinns/utils/_types.py,sha256=jl_91HtcrtE6UHbdTrRI8iUmr2kBUL0oP0UNIKhAXYw,1170
|
|
|
47
47
|
jinns/utils/_utils.py,sha256=M7NXX9ok-BkH5o_xo74PB1_Cc8XiDipSl51rq82dTH4,2821
|
|
48
48
|
jinns/validation/__init__.py,sha256=FTyUO-v1b8Tv-FDSQsntrH7zl9E0ENexqKMT_dFRkYo,124
|
|
49
49
|
jinns/validation/_validation.py,sha256=8p6sMKiBAvA6JNm65hjkMj0997LJ0BkyCREEh0AnPVE,4803
|
|
50
|
-
jinns-1.5.
|
|
51
|
-
jinns-1.5.
|
|
52
|
-
jinns-1.5.
|
|
53
|
-
jinns-1.5.
|
|
54
|
-
jinns-1.5.
|
|
55
|
-
jinns-1.5.
|
|
50
|
+
jinns-1.5.1.dist-info/licenses/AUTHORS,sha256=7NwCj9nU-HNG1asvy4qhQ2w7oZHrn-Lk5_wK_Ve7a3M,80
|
|
51
|
+
jinns-1.5.1.dist-info/licenses/LICENSE,sha256=BIAkGtXB59Q_BG8f6_OqtQ1BHPv60ggE9mpXJYz2dRM,11337
|
|
52
|
+
jinns-1.5.1.dist-info/METADATA,sha256=K7Aii5ivFcczIwLlQtCPqzMJEfW86D1yW1q7qMtvWPE,5314
|
|
53
|
+
jinns-1.5.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
54
|
+
jinns-1.5.1.dist-info/top_level.txt,sha256=RXbkr2hzy8WBE8aiRyrJYFqn3JeMJIhMdybLjjLTB9c,6
|
|
55
|
+
jinns-1.5.1.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|