jinns 1.5.0__py3-none-any.whl → 1.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
jinns/__init__.py CHANGED
@@ -1,10 +1,3 @@
1
- # import jinns.data
2
- # import jinns.loss
3
- # import jinns.solver
4
- # import jinns.utils
5
- # import jinns.experimental
6
- # import jinns.parameters
7
- # import jinns.plot
8
1
  from jinns import data as data
9
2
  from jinns import loss as loss
10
3
  from jinns import solver as solver
@@ -16,3 +9,10 @@ from jinns import nn as nn
16
9
  from jinns.solver._solve import solve
17
10
 
18
11
  __all__ = ["nn", "solve"]
12
+
13
+ import warnings
14
+
15
+ warnings.filterwarnings(
16
+ action="ignore",
17
+ message=r"Using `field\(init=False\)`",
18
+ )
@@ -7,8 +7,10 @@ from __future__ import (
7
7
  ) # https://docs.python.org/3/library/typing.html#constant
8
8
  import warnings
9
9
  import equinox as eqx
10
+ import numpy as np
10
11
  import jax
11
12
  import jax.numpy as jnp
13
+ from scipy.stats import qmc
12
14
  from jaxtyping import Key, Array, Float
13
15
  from jinns.data._Batchs import PDENonStatioBatch
14
16
  from jinns.data._utils import (
@@ -65,11 +67,13 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
65
67
  The minimum value of the time domain to consider
66
68
  tmax : float
67
69
  The maximum value of the time domain to consider
68
- method : str, default="uniform"
70
+ method : Literal["uniform", "grid", "sobol", "halton"], default="uniform"
69
71
  Either `grid` or `uniform`, default is `uniform`.
70
72
  The method that generates the `nt` time points. `grid` means
71
73
  regularly spaced points over the domain. `uniform` means uniformly
72
- sampled points over the domain
74
+ sampled points over the domain.
75
+ **Note** that Sobol and Halton approaches use scipy modules and will not
76
+ be JIT compatible.
73
77
  rar_parameters : Dict[str, int], default=None
74
78
  Defaults to None: do not use Residual Adaptative Resampling.
75
79
  Otherwise a dictionary with keys
@@ -150,9 +154,11 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
150
154
  elif self.method == "uniform":
151
155
  self.key, domain_times = self.generate_time_data(self.key, self.n)
152
156
  self.domain = jnp.concatenate([domain_times, self.omega], axis=1)
157
+ elif self.method in ["sobol", "halton"]:
158
+ self.key, self.domain = self.qmc_in_time_omega_domain(self.key, self.n)
153
159
  else:
154
160
  raise ValueError(
155
- f'Bad value for method. Got {self.method}, expected "grid" or "uniform"'
161
+ f'Bad value for method. Got {self.method}, expected "grid" or "uniform" or "sobol" or "halton"'
156
162
  )
157
163
 
158
164
  if self.domain_batch_size is None:
@@ -182,21 +188,28 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
182
188
  "number of points per facets (nb//2*self.dim)"
183
189
  " cannot be lower than border batch size"
184
190
  )
185
- self.key, boundary_times = self.generate_time_data(
186
- self.key, self.nb // (2 * self.dim)
187
- )
188
- boundary_times = boundary_times.reshape(-1, 1, 1)
189
- boundary_times = jnp.repeat(
190
- boundary_times, self.omega_border.shape[-1], axis=2
191
- )
192
- if self.dim == 1:
193
- self.border = make_cartesian_product(
194
- boundary_times, self.omega_border[None, None]
191
+ if self.method in ["grid", "uniform"]:
192
+ self.key, boundary_times = self.generate_time_data(
193
+ self.key, self.nb // (2 * self.dim)
195
194
  )
195
+ boundary_times = boundary_times.reshape(-1, 1, 1)
196
+ boundary_times = jnp.repeat(
197
+ boundary_times, self.omega_border.shape[-1], axis=2
198
+ )
199
+ if self.dim == 1:
200
+ self.border = make_cartesian_product(
201
+ boundary_times, self.omega_border[None, None]
202
+ )
203
+ else:
204
+ self.border = jnp.concatenate(
205
+ [boundary_times, self.omega_border], axis=1
206
+ )
196
207
  else:
197
- self.border = jnp.concatenate(
198
- [boundary_times, self.omega_border], axis=1
208
+ self.key, self.border = self.qmc_in_time_omega_border_domain(
209
+ self.key,
210
+ self.nb, # type: ignore (see inside the fun)
199
211
  )
212
+
200
213
  if self.border_batch_size is None:
201
214
  self.curr_border_idx = 0
202
215
  else:
@@ -209,14 +222,30 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
209
222
  self.curr_border_idx = 0
210
223
 
211
224
  if self.ni is not None:
212
- perfect_sq = int(jnp.round(jnp.sqrt(self.ni)) ** 2)
213
- if self.ni != perfect_sq:
214
- warnings.warn(
215
- "Grid sampling is requested in dimension 2 with a non"
216
- f" perfect square dataset size (self.ni = {self.ni})."
217
- f" Modifying self.ni to self.ni = {perfect_sq}."
225
+ if self.method == "grid":
226
+ perfect_sq = int(jnp.round(jnp.sqrt(self.ni)) ** 2)
227
+ if self.ni != perfect_sq:
228
+ warnings.warn(
229
+ "Grid sampling is requested in dimension 2 with a non"
230
+ f" perfect square dataset size (self.ni = {self.ni})."
231
+ f" Modifying self.ni to self.ni = {perfect_sq}."
232
+ )
233
+ self.ni = perfect_sq
234
+ if self.method in ["sobol", "halton"]:
235
+ log2_n = jnp.log2(self.ni)
236
+ lower_pow = 2 ** jnp.floor(log2_n)
237
+ higher_pow = 2 ** jnp.ceil(log2_n)
238
+ closest_two_power = (
239
+ lower_pow
240
+ if (self.ni - lower_pow) < (higher_pow - self.ni)
241
+ else higher_pow
218
242
  )
219
- self.ni = perfect_sq
243
+ if self.n != closest_two_power:
244
+ warnings.warn(
245
+ f"QuasiMonteCarlo sampling with {self.method} requires sample size to be a power fo 2."
246
+ f"Modfiying self.n from {self.ni} to {closest_two_power}.",
247
+ )
248
+ self.ni = int(closest_two_power)
220
249
  self.key, self.initial = self.generate_omega_data(
221
250
  self.key, data_size=self.ni
222
251
  )
@@ -245,16 +274,115 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
245
274
  if self.method == "grid":
246
275
  partial_times = (self.tmax - self.tmin) / nt
247
276
  return key, jnp.arange(self.tmin, self.tmax, partial_times)[:, None]
248
- if self.method == "uniform":
277
+ elif self.method in ["uniform", "sobol", "halton"]:
249
278
  return key, self.sample_in_time_domain(subkey, nt)
250
279
  raise ValueError("Method " + self.method + " is not implemented.")
251
280
 
252
281
  def sample_in_time_domain(self, key: Key, nt: int) -> Float[Array, " nt 1"]:
253
- return jax.random.uniform(
254
- key,
255
- (nt, 1),
256
- minval=self.tmin,
257
- maxval=self.tmax,
282
+ return jax.random.uniform(key, (nt, 1), minval=self.tmin, maxval=self.tmax)
283
+
284
+ def qmc_in_time_omega_domain(
285
+ self, key: Key, sample_size: int
286
+ ) -> tuple[Key, Float[Array, "n 1+dim"]]:
287
+ """
288
+ Because in Quasi-Monte Carlo sampling we cannot concatenate two vectors generated independently
289
+ We generate time and omega samples jointly
290
+ """
291
+ key, subkey = jax.random.split(key, 2)
292
+ qmc_generator = qmc.Sobol if self.method == "sobol" else qmc.Halton
293
+ sampler = qmc_generator(
294
+ d=self.dim + 1, scramble=True, rng=np.random.default_rng(np.uint32(subkey))
295
+ )
296
+ samples = sampler.random(n=sample_size)
297
+ samples[:, 1:] = qmc.scale(
298
+ samples[:, 1:], l_bounds=self.min_pts, u_bounds=self.max_pts
299
+ ) # We scale omega domain to be in (min_pts, max_pts)
300
+ return key, jnp.array(samples)
301
+
302
+ def qmc_in_time_omega_border_domain(
303
+ self, key: Key, sample_size: int | None = None
304
+ ) -> tuple[Key, Float[Array, "n 1+dim"]] | None:
305
+ """
306
+ For each facet of the border we generate Quasi-MonteCarlo sequences jointy with time.
307
+
308
+ We need to do some type ignore in this function because we have lost
309
+ the type narrowing from post_init, type checkers only narrow at function level and because we cannot narrow a class attribute.
310
+ """
311
+ qmc_generator = qmc.Sobol if self.method == "sobol" else qmc.Halton
312
+ sample_size = self.nb if sample_size is None else sample_size
313
+ if sample_size is None:
314
+ return None
315
+ if self.dim == 1:
316
+ key, subkey = jax.random.split(key, 2)
317
+ qmc_seq = qmc_generator(
318
+ d=1, scramble=True, rng=np.random.default_rng(np.uint32(subkey))
319
+ )
320
+ boundary_times = jnp.array(
321
+ qmc_seq.random(self.nb // (2 * self.dim)) # type: ignore
322
+ )
323
+ boundary_times = boundary_times.reshape(-1, 1, 1)
324
+ boundary_times = jnp.repeat(
325
+ boundary_times,
326
+ self.omega_border.shape[-1], # type: ignore
327
+ axis=2,
328
+ )
329
+ return key, make_cartesian_product(
330
+ boundary_times,
331
+ self.omega_border[None, None], # type: ignore
332
+ )
333
+ if self.dim == 2:
334
+ # currently hard-coded the 4 edges for d==2
335
+ # TODO : find a general & efficient way to sample from the border
336
+ # (facets) of the hypercube in general dim.
337
+ key, *subkeys = jax.random.split(key, 5)
338
+ facet_n = sample_size // (2 * self.dim)
339
+
340
+ def generate_qmc_sample(key, min_val, max_val):
341
+ qmc_seq = qmc_generator(
342
+ d=2,
343
+ scramble=True,
344
+ rng=np.random.default_rng(np.uint32(key)),
345
+ )
346
+ u = qmc_seq.random(n=facet_n)
347
+ u[:, 1:2] = qmc.scale(u[:, 1:2], l_bounds=min_val, u_bounds=max_val)
348
+ return jnp.array(u)
349
+
350
+ xmin_sample = generate_qmc_sample(
351
+ subkeys[0], self.min_pts[1], self.max_pts[1]
352
+ ) # [t,x,y]
353
+ xmin = jnp.hstack(
354
+ [
355
+ xmin_sample[:, 0:1],
356
+ self.min_pts[0] * jnp.ones((facet_n, 1)),
357
+ xmin_sample[:, 1:2],
358
+ ]
359
+ )
360
+ xmax_sample = generate_qmc_sample(
361
+ subkeys[1], self.min_pts[1], self.max_pts[1]
362
+ )
363
+ xmax = jnp.hstack(
364
+ [
365
+ xmax_sample[:, 0:1],
366
+ self.max_pts[0] * jnp.ones((facet_n, 1)),
367
+ xmax_sample[:, 1:2],
368
+ ]
369
+ )
370
+ ymin = jnp.hstack(
371
+ [
372
+ generate_qmc_sample(subkeys[2], self.min_pts[0], self.max_pts[0]),
373
+ self.min_pts[1] * jnp.ones((facet_n, 1)),
374
+ ]
375
+ )
376
+ ymax = jnp.hstack(
377
+ [
378
+ generate_qmc_sample(subkeys[3], self.min_pts[0], self.max_pts[0]),
379
+ self.max_pts[1] * jnp.ones((facet_n, 1)),
380
+ ]
381
+ )
382
+ return key, jnp.stack([xmin, xmax, ymin, ymax], axis=-1)
383
+ raise NotImplementedError(
384
+ "Generation of the border of a cube in dimension > 2 is not "
385
+ + f"implemented yet. You are asking for generation in dimension d={self.dim}."
258
386
  )
259
387
 
260
388
  def _get_domain_operands(
@@ -8,8 +8,11 @@ from __future__ import (
8
8
  import warnings
9
9
  import equinox as eqx
10
10
  import jax
11
+ import numpy as np
11
12
  import jax.numpy as jnp
13
+ from scipy.stats import qmc
12
14
  from jaxtyping import Key, Array, Float
15
+ from typing import Literal
13
16
  from jinns.data._Batchs import PDEStatioBatch
14
17
  from jinns.data._utils import _check_and_set_rar_parameters, _reset_or_increment
15
18
  from jinns.data._AbstractDataGenerator import AbstractDataGenerator
@@ -50,11 +53,13 @@ class CubicMeshPDEStatio(AbstractDataGenerator):
50
53
  A tuple of maximum values of the domain along each dimension. For a sampling
51
54
  in `n` dimension, this represents $(x_{1, max}, x_{2,max}, ...,
52
55
  x_{n,max})$
53
- method : str, default="uniform"
54
- Either `grid` or `uniform`, default is `uniform`.
56
+ method : Literal["grid", "uniform", "sobol", "halton"], default="uniform"
57
+ Either "grid", "uniform", "sobol" or "halton", default is `uniform`.
55
58
  The method that generates the `nt` time points. `grid` means
56
59
  regularly spaced points over the domain. `uniform` means uniformly
57
- sampled points over the domain
60
+ sampled points over the domain.
61
+ **Note** that Sobol and Halton approaches use scipy modules and will not
62
+ be JIT compatible.
58
63
  rar_parameters : dict[str, int], default=None
59
64
  Defaults to None: do not use Residual Adaptative Resampling.
60
65
  Otherwise a dictionary with keys
@@ -94,7 +99,7 @@ class CubicMeshPDEStatio(AbstractDataGenerator):
94
99
  # shape in jax.lax.dynamic_slice
95
100
  min_pts: tuple[float, ...] = eqx.field(kw_only=True)
96
101
  max_pts: tuple[float, ...] = eqx.field(kw_only=True)
97
- method: str = eqx.field(
102
+ method: Literal["grid", "uniform", "sobol", "halton"] = eqx.field(
98
103
  kw_only=True, static=True, default_factory=lambda: "uniform"
99
104
  )
100
105
  rar_parameters: dict[str, int] = eqx.field(kw_only=True, default=None)
@@ -132,6 +137,22 @@ class CubicMeshPDEStatio(AbstractDataGenerator):
132
137
  )
133
138
  self.n = perfect_sq
134
139
 
140
+ if self.method in ["sobol", "halton"]:
141
+ log2_n = jnp.log2(self.n)
142
+ lower_pow = 2 ** jnp.floor(log2_n)
143
+ higher_pow = 2 ** jnp.ceil(log2_n)
144
+ closest_two_power = (
145
+ lower_pow
146
+ if (self.n - lower_pow) < (higher_pow - self.n)
147
+ else higher_pow
148
+ )
149
+ if self.n != closest_two_power:
150
+ warnings.warn(
151
+ f"QuasiMonteCarlo sampling with {self.method} requires sample size to be a power fo 2."
152
+ f"Modfiying self.n from {self.n} to {closest_two_power}.",
153
+ )
154
+ self.n = int(closest_two_power)
155
+
135
156
  if self.omega_batch_size is None:
136
157
  self.curr_omega_idx = 0
137
158
  else:
@@ -176,24 +197,48 @@ class CubicMeshPDEStatio(AbstractDataGenerator):
176
197
  def sample_in_omega_domain(
177
198
  self, keys: Key, sample_size: int
178
199
  ) -> Float[Array, " n dim"]:
200
+ if self.method == "uniform":
201
+ if self.dim == 1:
202
+ xmin, xmax = self.min_pts[0], self.max_pts[0]
203
+ return jax.random.uniform(
204
+ keys, shape=(sample_size, 1), minval=xmin, maxval=xmax
205
+ )
206
+
207
+ return jnp.concatenate(
208
+ [
209
+ jax.random.uniform(
210
+ keys[i],
211
+ (sample_size, 1),
212
+ minval=self.min_pts[i],
213
+ maxval=self.max_pts[i],
214
+ )
215
+ for i in range(self.dim)
216
+ ],
217
+ axis=-1,
218
+ )
219
+ else:
220
+ return self._qmc_in_omega_domain(keys, sample_size)
221
+
222
+ def _qmc_in_omega_domain(
223
+ self, subkey: Key, sample_size: int
224
+ ) -> Float[Array, "n dim"]:
225
+ qmc_generator = qmc.Sobol if self.method == "sobol" else qmc.Halton
179
226
  if self.dim == 1:
180
- xmin, xmax = self.min_pts[0], self.max_pts[0]
181
- return jax.random.uniform(
182
- keys, shape=(sample_size, 1), minval=xmin, maxval=xmax
227
+ qmc_seq = qmc_generator(
228
+ d=self.dim,
229
+ scramble=True,
230
+ rng=np.random.default_rng(np.uint32(subkey)),
183
231
  )
184
- # keys = jax.random.split(key, self.dim)
185
- return jnp.concatenate(
186
- [
187
- jax.random.uniform(
188
- keys[i],
189
- (sample_size, 1),
190
- minval=self.min_pts[i],
191
- maxval=self.max_pts[i],
192
- )
193
- for i in range(self.dim)
194
- ],
195
- axis=-1,
232
+ u = qmc_seq.random(n=sample_size)
233
+ return jnp.array(
234
+ qmc.scale(u, l_bounds=self.min_pts[0], u_bounds=self.max_pts[0])
235
+ )
236
+ sampler = qmc.Sobol(
237
+ d=self.dim, scramble=True, rng=np.random.default_rng(np.uint32(subkey))
196
238
  )
239
+ samples = sampler.random(n=sample_size)
240
+ samples = qmc.scale(samples, l_bounds=self.min_pts, u_bounds=self.max_pts)
241
+ return jnp.array(samples)
197
242
 
198
243
  def sample_in_omega_border_domain(
199
244
  self, keys: Key, sample_size: int | None = None
@@ -260,6 +305,62 @@ class CubicMeshPDEStatio(AbstractDataGenerator):
260
305
  + f"implemented yet. You are asking for generation in dimension d={self.dim}."
261
306
  )
262
307
 
308
+ def qmc_in_omega_border_domain(
309
+ self, keys: Key, sample_size: int | None = None
310
+ ) -> Float[Array, " 1 2"] | Float[Array, " (nb//4) 2 4"] | None:
311
+ qmc_generator = qmc.Sobol if self.method == "sobol" else qmc.Halton
312
+ sample_size = self.nb if sample_size is None else sample_size
313
+ if sample_size is None:
314
+ return None
315
+ if self.dim == 1:
316
+ xmin = self.min_pts[0]
317
+ xmax = self.max_pts[0]
318
+ return jnp.array([xmin, xmax]).astype(float)
319
+ if self.dim == 2:
320
+ # currently hard-coded the 4 edges for d==2
321
+ # TODO : find a general & efficient way to sample from the border
322
+ # (facets) of the hypercube in general dim.
323
+ facet_n = sample_size // (2 * self.dim)
324
+
325
+ def generate_qmc_sample(key, min_val, max_val):
326
+ qmc_seq = qmc_generator(
327
+ d=1,
328
+ scramble=True,
329
+ rng=np.random.default_rng(np.uint32(key)),
330
+ )
331
+ u = qmc_seq.random(n=facet_n)
332
+ return jnp.array(qmc.scale(u, l_bounds=min_val, u_bounds=max_val))
333
+
334
+ xmin = jnp.hstack(
335
+ [
336
+ self.min_pts[0] * jnp.ones((facet_n, 1)),
337
+ generate_qmc_sample(keys[0], self.min_pts[1], self.max_pts[1]),
338
+ ]
339
+ )
340
+ xmax = jnp.hstack(
341
+ [
342
+ self.max_pts[0] * jnp.ones((facet_n, 1)),
343
+ generate_qmc_sample(keys[1], self.min_pts[1], self.max_pts[1]),
344
+ ]
345
+ )
346
+ ymin = jnp.hstack(
347
+ [
348
+ generate_qmc_sample(keys[2], self.min_pts[0], self.max_pts[0]),
349
+ self.min_pts[1] * jnp.ones((facet_n, 1)),
350
+ ]
351
+ )
352
+ ymax = jnp.hstack(
353
+ [
354
+ generate_qmc_sample(keys[3], self.min_pts[0], self.max_pts[0]),
355
+ self.max_pts[1] * jnp.ones((facet_n, 1)),
356
+ ]
357
+ )
358
+ return jnp.stack([xmin, xmax, ymin, ymax], axis=-1)
359
+ raise NotImplementedError(
360
+ "Generation of the border of a cube in dimension > 2 is not "
361
+ + f"implemented yet. You are asking for generation in dimension d={self.dim}."
362
+ )
363
+
263
364
  def generate_omega_data(
264
365
  self, key: Key, data_size: int | None = None
265
366
  ) -> tuple[
@@ -290,8 +391,8 @@ class CubicMeshPDEStatio(AbstractDataGenerator):
290
391
  )
291
392
  xyz_ = [a.reshape((data_size, 1)) for a in xyz_]
292
393
  omega = jnp.concatenate(xyz_, axis=-1)
293
- elif self.method == "uniform":
294
- if self.dim == 1:
394
+ elif self.method in ["uniform", "sobol", "halton"]:
395
+ if self.dim == 1 or self.method in ["sobol", "halton"]:
295
396
  key, subkeys = jax.random.split(key, 2)
296
397
  else:
297
398
  key, *subkeys = jax.random.split(key, self.dim + 1)
@@ -317,10 +418,17 @@ class CubicMeshPDEStatio(AbstractDataGenerator):
317
418
  key, *subkeys = jax.random.split(key, 5)
318
419
  else:
319
420
  subkeys = None
320
- omega_border = self.sample_in_omega_border_domain(
321
- subkeys, sample_size=data_size
322
- )
323
421
 
422
+ if self.method in ["grid", "uniform"]:
423
+ omega_border = self.sample_in_omega_border_domain(
424
+ subkeys, sample_size=data_size
425
+ )
426
+ elif self.method in ["sobol", "halton"]:
427
+ omega_border = self.qmc_in_omega_border_domain(
428
+ subkeys, sample_size=data_size
429
+ )
430
+ else:
431
+ raise ValueError("Method " + self.method + " is not implemented.")
324
432
  return key, omega_border
325
433
 
326
434
  def _get_omega_operands(
jinns/loss/_LossODE.py CHANGED
@@ -19,6 +19,7 @@ from jaxtyping import Float, Array
19
19
  from jinns.loss._loss_utils import (
20
20
  dynamic_loss_apply,
21
21
  observations_loss_apply,
22
+ initial_condition_check,
22
23
  )
23
24
  from jinns.parameters._params import (
24
25
  _get_vmap_in_axes_params,
@@ -43,7 +44,7 @@ if TYPE_CHECKING:
43
44
 
44
45
 
45
46
  class _LossODEAbstract(AbstractLoss):
46
- """
47
+ r"""
47
48
  Parameters
48
49
  ----------
49
50
 
@@ -60,8 +61,15 @@ class _LossODEAbstract(AbstractLoss):
60
61
  Fields can be "nn_params", "eq_params" or "both". Those that should not
61
62
  be updated will have a `jax.lax.stop_gradient` called on them. Default
62
63
  is `"nn_params"` for each composant of the loss.
63
- initial_condition : tuple[float | Float[Array, " 1"], Float[Array, " dim"]], default=None
64
- tuple of length 2 with initial condition $(t_0, u_0)$.
64
+ initial_condition : tuple[
65
+ Float[Array, "n_cond "],
66
+ Float[Array, "n_cond dim"]
67
+ ] |
68
+ tuple[int | float | Float[Array, " "],
69
+ int | float | Float[Array, " dim"]
70
+ ] | None, default=None
71
+ Most of the time, a tuple of length 2 with initial condition $(t_0, u_0)$.
72
+ From jinns v1.5.1 we accept tuples of jnp arrays with shape (n_cond, 1) for t0 and (n_cond, dim) for u0. This is useful to include observed conditions at different time points, such as *e.g* final conditions. It was designed to implement $\mathcal{L}^{aux}$ from _Systems biology informed deep learning for inferring parameters and hidden dynamics_, Alireza Yazdani et al., 2020
65
73
  obs_slice : EllipsisType | slice | None, default=None
66
74
  Slice object specifying the begininning/ending
67
75
  slice of u output(s) that is observed. This is useful for
@@ -78,7 +86,9 @@ class _LossODEAbstract(AbstractLoss):
78
86
  derivative_keys: DerivativeKeysODE | None = eqx.field(kw_only=True, default=None)
79
87
  loss_weights: LossWeightsODE | None = eqx.field(kw_only=True, default=None)
80
88
  initial_condition: (
81
- tuple[float | Float[Array, " 1"], Float[Array, " dim"]] | None
89
+ tuple[Float[Array, " n_cond 1"], Float[Array, " n_cond dim"]]
90
+ | tuple[int | float | Float[Array, " "], int | float | Float[Array, " dim"]]
91
+ | None
82
92
  ) = eqx.field(kw_only=True, default=None)
83
93
  obs_slice: EllipsisType | slice | None = eqx.field(
84
94
  kw_only=True, default=None, static=True
@@ -112,20 +122,60 @@ class _LossODEAbstract(AbstractLoss):
112
122
  "Initial condition should be a tuple of len 2 with (t0, u0), "
113
123
  f"{self.initial_condition} was passed."
114
124
  )
115
- # some checks/reshaping for t0
125
+ # some checks/reshaping for t0 and u0
116
126
  t0, u0 = self.initial_condition
117
127
  if isinstance(t0, Array):
118
- if not t0.shape: # e.g. user input: jnp.array(0.)
119
- t0 = jnp.array([t0])
120
- elif t0.shape != (1,):
128
+ # at the end we want to end up with t0 of shape (:, 1) to account for
129
+ # possibly several data points
130
+ if t0.ndim <= 1:
131
+ # in this case we assume t0 belongs one (initial)
132
+ # condition
133
+ t0 = initial_condition_check(t0, dim_size=1)[
134
+ None, :
135
+ ] # make a (1, 1) here
136
+ if t0.ndim > 2:
137
+ raise ValueError(
138
+ "It t0 is an Array, it represents n_cond"
139
+ " imposed conditions and must be of shape (n_cond, 1)"
140
+ )
141
+ else:
142
+ # in this case t0 clearly represents one (initial) condition
143
+ t0 = initial_condition_check(t0, dim_size=1)[
144
+ None, :
145
+ ] # make a (1, 1) here
146
+ if isinstance(u0, Array):
147
+ # at the end we want to end up with u0 of shape (:, dim) to account for
148
+ # possibly several data points
149
+ if not u0.shape:
150
+ # in this case we assume u0 belongs to one (initial)
151
+ # condition
152
+ u0 = initial_condition_check(u0, dim_size=1)[
153
+ None, :
154
+ ] # make a (1, 1) here
155
+ elif u0.ndim == 1:
156
+ # in this case we assume u0 belongs to one (initial)
157
+ # condition
158
+ u0 = initial_condition_check(u0, dim_size=u0.shape[0])[
159
+ None, :
160
+ ] # make a (1, dim) here
161
+ if u0.ndim > 2:
121
162
  raise ValueError(
122
- f"Wrong t0 input (self.initial_condition[0]) It should be"
123
- f"a float or an array of shape (1,). Got shape: {t0.shape}"
163
+ "It u0 is an Array, it represents n_cond "
164
+ "imposed conditions and must be of shape (n_cond, dim)"
124
165
  )
125
- if isinstance(t0, float): # e.g. user input: 0.
126
- t0 = jnp.array([t0])
127
- if isinstance(t0, int): # e.g. user input: 0
128
- t0 = jnp.array([float(t0)])
166
+ else:
167
+ # at the end we want to end up with u0 of shape (:, dim) to account for
168
+ # possibly several data points
169
+ u0 = initial_condition_check(u0, dim_size=None)[
170
+ None, :
171
+ ] # make a (1, 1) here
172
+
173
+ if t0.shape[0] != u0.shape[0] or t0.ndim != u0.ndim:
174
+ raise ValueError(
175
+ "t0 and u0 must represent a same number of initial"
176
+ " conditial conditions"
177
+ )
178
+
129
179
  self.initial_condition = (t0, u0)
130
180
 
131
181
  if self.obs_slice is None:
@@ -259,28 +309,42 @@ class LossODE(_LossODEAbstract):
259
309
 
260
310
  # initial condition
261
311
  if self.initial_condition is not None:
262
- vmap_in_axes = (None,) + vmap_in_axes_params
263
- if not jax.tree_util.tree_leaves(vmap_in_axes):
264
- # test if only None in vmap_in_axes to avoid the value error:
265
- # `vmap must have at least one non-None value in in_axes`
266
- v_u = self.u
267
- else:
268
- v_u = vmap(self.u, (None,) + vmap_in_axes_params)
269
312
  t0, u0 = self.initial_condition
270
313
  u0 = jnp.array(u0)
271
- initial_condition_fun = lambda p: jnp.mean(
272
- jnp.sum(
273
- (
274
- v_u(
275
- t0,
276
- _set_derivatives(p, self.derivative_keys.initial_condition), # type: ignore
277
- )
278
- - u0
314
+
315
+ # first construct the plain init loss no vmaping
316
+ initial_condition_fun__ = lambda t, u, p: jnp.sum(
317
+ (
318
+ self.u(
319
+ t,
320
+ _set_derivatives(
321
+ p,
322
+ self.derivative_keys.initial_condition, # type: ignore
323
+ ),
279
324
  )
280
- ** 2,
281
- axis=-1,
325
+ - u
282
326
  )
327
+ ** 2,
328
+ axis=0,
283
329
  )
330
+ # now vmap over the number of conditions (first dim of t0 and u0)
331
+ # and take the mean
332
+ initial_condition_fun_ = lambda p: jnp.mean(
333
+ vmap(initial_condition_fun__, (0, 0, None))(t0, u0, p)
334
+ )
335
+ # now vmap over the the possible batch of parameters and take the
336
+ # average. Note that we then finally have a cartesian product
337
+ # between the batch of parameters (if any) and the number of
338
+ # conditions (if any)
339
+ if not jax.tree_util.tree_leaves(vmap_in_axes_params):
340
+ # if there is no parameter batch to vmap over we cannot call
341
+ # vmap because calling vmap must be done with at least one non
342
+ # None in_axes or out_axes
343
+ initial_condition_fun = initial_condition_fun_
344
+ else:
345
+ initial_condition_fun = lambda p: jnp.mean(
346
+ vmap(initial_condition_fun_, vmap_in_axes_params)(p)
347
+ )
284
348
  else:
285
349
  initial_condition_fun = None
286
350
 
jinns/loss/_LossPDE.py CHANGED
@@ -21,6 +21,7 @@ from jinns.loss._loss_utils import (
21
21
  normalization_loss_apply,
22
22
  observations_loss_apply,
23
23
  initial_condition_apply,
24
+ initial_condition_check,
24
25
  )
25
26
  from jinns.parameters._params import (
26
27
  _get_vmap_in_axes_params,
@@ -700,22 +701,12 @@ class LossPDENonStatio(LossPDEStatio):
700
701
  "case (e.g by. hardcoding it into the PINN output)."
701
702
  )
702
703
  # some checks for t0
703
- if isinstance(self.t0, Array):
704
- if not self.t0.shape: # e.g. user input: jnp.array(0.)
705
- self.t0 = jnp.array([self.t0])
706
- elif self.t0.shape != (1,):
707
- raise ValueError(
708
- f"Wrong self.t0 input. It should be"
709
- f"a float or an array of shape (1,). Got shape: {self.t0.shape}"
710
- )
711
- elif isinstance(self.t0, float): # e.g. user input: 0.
712
- self.t0 = jnp.array([self.t0])
713
- elif isinstance(self.t0, int): # e.g. user input: 0
714
- self.t0 = jnp.array([float(self.t0)])
715
- elif self.t0 is None:
716
- self.t0 = jnp.array([0])
704
+ t0 = self.t0
705
+ if t0 is None:
706
+ t0 = jnp.array([0])
717
707
  else:
718
- raise ValueError("Wrong value for t0")
708
+ t0 = initial_condition_check(t0, dim_size=1)
709
+ self.t0 = t0
719
710
 
720
711
  # witht the variables below we avoid memory overflow since a cartesian
721
712
  # product is taken
jinns/loss/_loss_utils.py CHANGED
@@ -308,3 +308,26 @@ def initial_condition_apply(
308
308
  else:
309
309
  raise ValueError(f"Bad type for u. Got {type(u)}, expected PINN or SPINN")
310
310
  return mse_initial_condition
311
+
312
+
313
+ def initial_condition_check(x, dim_size=None):
314
+ """
315
+ Make a (dim_size,) jnp array from an int, a float or a 0D jnp array
316
+
317
+ """
318
+ if isinstance(x, Array):
319
+ if not x.shape: # e.g. user input: jnp.array(0.)
320
+ x = jnp.array([x])
321
+ if dim_size is not None: # we check for the required dims_ize
322
+ if x.shape != (dim_size,):
323
+ raise ValueError(
324
+ f"Wrong dim_size. It should be({dim_size},). Got shape: {x.shape}"
325
+ )
326
+
327
+ elif isinstance(x, float): # e.g. user input: 0.
328
+ x = jnp.array([x])
329
+ elif isinstance(x, int): # e.g. user input: 0
330
+ x = jnp.array([float(x)])
331
+ else:
332
+ raise ValueError(f"Wrong value, expected Array, float or int, got {type(x)}")
333
+ return x
@@ -10,6 +10,14 @@ from jaxtyping import Array, PyTree, Float
10
10
  T = TypeVar("T") # the generic type for what is in the Params PyTree because we
11
11
  # have possibly Params of Arrays, boolean, ...
12
12
 
13
+ ### NOTE
14
+ ### We are taking derivatives with respect to Params eqx.Modules.
15
+ ### This has been shown to behave weirdly if some fields of eqx.Modules have
16
+ ### been set as `field(init=False)`, we then should never create such fields in
17
+ ### jinns' Params modules.
18
+ ### We currently have silenced the warning related to this (see jinns.__init__
19
+ ### see https://github.com/patrick-kidger/equinox/pull/1043/commits/f88e62ab809140334c2f987ed13eff0d80b8be13
20
+
13
21
 
14
22
  class Params(eqx.Module, Generic[T]):
15
23
  """
jinns/solver/_solve.py CHANGED
@@ -179,6 +179,7 @@ def solve(
179
179
  best_val_params
180
180
  The best parameters according to the validation criterion
181
181
  """
182
+ initialization_time = time.time()
182
183
  if n_iter < 1:
183
184
  raise ValueError("Cannot run jinns.solve for n_iter<1")
184
185
 
@@ -225,11 +226,6 @@ def solve(
225
226
  # get_batch with device_put, the latter is not jittable
226
227
  get_batch = _get_get_batch(obs_batch_sharding)
227
228
 
228
- # initialize the dict for stored parameter values
229
- # we need to get a loss_term to init stuff
230
- batch_ini, data, param_data, obs_data = get_batch(data, param_data, obs_data)
231
- _, loss_terms = loss(init_params, batch_ini)
232
-
233
229
  # initialize parameter tracking
234
230
  if tracked_params is None:
235
231
  tracked_params = jax.tree.map(lambda p: None, init_params)
@@ -247,6 +243,13 @@ def solve(
247
243
  # being a complex data structure
248
244
  )
249
245
 
246
+ # initialize the dict for stored parameter values
247
+ # we need to get a loss_term to init stuff
248
+ # NOTE: we use jax.eval_shape to avoid FLOPS since we only need the tree
249
+ # structure
250
+ batch_ini, data, param_data, obs_data = get_batch(data, param_data, obs_data)
251
+ _, loss_terms = jax.eval_shape(loss, init_params, batch_ini)
252
+
250
253
  # initialize the PyTree for stored loss values
251
254
  stored_loss_terms = jax.tree_util.tree_map(
252
255
  lambda _: jnp.zeros((n_iter)), loss_terms
@@ -475,6 +478,9 @@ def solve(
475
478
  key,
476
479
  )
477
480
 
481
+ if verbose:
482
+ print("Initialization time:", time.time() - initialization_time)
483
+
478
484
  # Main optimization loop. We use the LAX while loop (fully jitted) version
479
485
  # if no mixing devices. Otherwise we use the standard while loop. Here devices only
480
486
  # concern obs_batch, but it could lead to more complex scheme in the future
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: jinns
3
- Version: 1.5.0
3
+ Version: 1.5.1
4
4
  Summary: Physics Informed Neural Network with JAX
5
5
  Author-email: Hugo Gangloff <hugo.gangloff@inrae.fr>, Nicolas Jouvin <nicolas.jouvin@inrae.fr>
6
6
  Maintainer-email: Hugo Gangloff <hugo.gangloff@inrae.fr>, Nicolas Jouvin <nicolas.jouvin@inrae.fr>
@@ -1,8 +1,8 @@
1
- jinns/__init__.py,sha256=hyh3QKO2cQGK5cmvFYP0MrXb-tK_DM2T9CwLwO-sEX8,500
1
+ jinns/__init__.py,sha256=Tjp5z0Mnd1nscvJaXnxZb4lEYbI0cpN6OPgCZ-Swo74,453
2
2
  jinns/data/_AbstractDataGenerator.py,sha256=O61TBOyeOFKwf1xqKzFD4KwCWRDnm2XgyJ-kKY9fmB4,557
3
3
  jinns/data/_Batchs.py,sha256=-DlD6Qag3zs5QbKtKAOvOzV7JOpNOqAm_P8cwo1dIZg,1574
4
- jinns/data/_CubicMeshPDENonStatio.py,sha256=c_8czJpxSoEvgZ8LDpL2sqtF9dcW4ELNO4juEFMOxog,16400
5
- jinns/data/_CubicMeshPDEStatio.py,sha256=stZ0Kbb7_VwFmWUSPs0P6a6qRj2Tu67p7sxEfb1Ajks,17865
4
+ jinns/data/_CubicMeshPDENonStatio.py,sha256=4f21SgeNQsGJTz8-uehduU0X9TibRcs28Iq49Kv4nQQ,22250
5
+ jinns/data/_CubicMeshPDEStatio.py,sha256=DVrP4qHVAJMu915EYH2PKyzwoG0nIMixAEKR2fz6C58,22525
6
6
  jinns/data/_DataGeneratorODE.py,sha256=5RzUbQFEsooAZsocDw4wRgA_w5lJmDMuY4M6u79K-1c,7260
7
7
  jinns/data/_DataGeneratorObservations.py,sha256=jknepLsJatSJHFq5lLMD-fFHkPGj5q286LEjE-vH24k,7738
8
8
  jinns/data/_DataGeneratorParameter.py,sha256=IedX3jcOj7ZDW_18IAcRR75KVzQzo85z9SICIKDBJl4,8539
@@ -12,13 +12,13 @@ jinns/experimental/__init__.py,sha256=DT9e57zbjfzPeRnXemGUqnGd--MhV77FspChT0z4Yr
12
12
  jinns/experimental/_diffrax_solver.py,sha256=upMr3kTTNrxEiSUO_oLvCXcjS9lPxSjvbB81h3qlhaU,6813
13
13
  jinns/loss/_DynamicLoss.py,sha256=4mb7OCP-cGZ_mG2MQ-AniddDcuBT78p4bQI7rZpwte4,22722
14
14
  jinns/loss/_DynamicLossAbstract.py,sha256=QhHRgvtcT-ifHlOxTyXbjDtHk9UfPN2Si8s3v9nEQm4,12672
15
- jinns/loss/_LossODE.py,sha256=DeejnU2ytgrOxUnwuVkQDWWRKJAgNQyjacTx-jT0xPA,13796
16
- jinns/loss/_LossPDE.py,sha256=ycjWJ99SuXe9DV5nROSWyq--xcp2JJ2PGWxsdWyZZog,36942
15
+ jinns/loss/_LossODE.py,sha256=iVYDojaI6Co7S5CrU67_niopD4Bk7UBTuLzDiTHoWMc,16996
16
+ jinns/loss/_LossPDE.py,sha256=VT56oQ_33fLq46lIch0slNsxu4d97eQBOgRAPeFESts,36401
17
17
  jinns/loss/__init__.py,sha256=z5xYgBipNFf66__5BqQc6R_8r4F6A3TXL60YjsM8Osk,1287
18
18
  jinns/loss/_abstract_loss.py,sha256=DMxn0SQe9PW-pq3p5Oqvb0YK3_ulLDOnoIXzK219GV4,4576
19
19
  jinns/loss/_boundary_conditions.py,sha256=9HGw1cGLfmEilP4V4B2T0zl0YP1kNtrtXVLQNiBmWgc,12464
20
20
  jinns/loss/_loss_components.py,sha256=MMzaGlaRqESPjRzT0j0WU9HAqWQSbIXpGAqM1xQCZHw,1106
21
- jinns/loss/_loss_utils.py,sha256=R6PffBAtg6z9M8x1DFXmmqZpC095b9gZ_DB1phQxSuY,11168
21
+ jinns/loss/_loss_utils.py,sha256=eJ4JcBm396LHx7Tti88ZQrLcKqVL1oSfFGT23VNkytQ,11949
22
22
  jinns/loss/_loss_weight_updates.py,sha256=9Bwouh7shLyc_wrdzN6CYL0ZuQH81uEs-L6wCeiYFx8,6817
23
23
  jinns/loss/_loss_weights.py,sha256=kII5WddORgeommFTudT3CSvhICpo6nSe47LclUgu_78,2429
24
24
  jinns/loss/_operators.py,sha256=Ds5yRH7hu-jaGRp7PYbt821BgYuEvgWHufWhYgdMjw0,22909
@@ -34,12 +34,12 @@ jinns/nn/_spinn_mlp.py,sha256=uCL454sF0Tfj7KT-fdXPnvKJYRQOuq60N0r2b2VAB8Q,7606
34
34
  jinns/nn/_utils.py,sha256=9UXz73iHKHVQYPBPIEitrHYJzJ14dspRwPfLA8avx0c,1120
35
35
  jinns/parameters/__init__.py,sha256=O0n7y6R1LRmFzzugCxMFCMS2pgsuWSh-XHjfFViN_eg,265
36
36
  jinns/parameters/_derivative_keys.py,sha256=YlLDX49PfYhr2Tj--t3praiD8JOUTZU6PTmjbNZsbMc,19173
37
- jinns/parameters/_params.py,sha256=qn4IGMJhD9lDBqOWmGEMy4gXt5a6KHfirkYZwHO7Vwk,2633
37
+ jinns/parameters/_params.py,sha256=nv0WScbgUdmuC0bSF15VbnKypJ58pl6wynZAcYfuF6M,3081
38
38
  jinns/plot/__init__.py,sha256=KPHX0Um4FbciZO1yD8kjZbkaT8tT964Y6SE2xCQ4eDU,135
39
39
  jinns/plot/_plot.py,sha256=-A5auNeElaz2_8UzVQJQE4143ZFg0zgMjStU7kwttEY,11565
40
40
  jinns/solver/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
41
41
  jinns/solver/_rar.py,sha256=vSVTnCGCusI1vTZCvIkP2_G8we44G_42yZHx2sOK9DE,10291
42
- jinns/solver/_solve.py,sha256=oVHnuc7Z0V2-ZYgZtCx7xdFd7TpB9w-6AwafX-kgBE4,28379
42
+ jinns/solver/_solve.py,sha256=IsrmG2m48KkkYgvXYomlSbZ3hd1FySCj3rwlkovs-lI,28616
43
43
  jinns/solver/_utils.py,sha256=sM2UbVzYyjw24l4QSIR3IlynJTPGD_S08r8v0lXMxA8,5876
44
44
  jinns/utils/__init__.py,sha256=OEYWLCw8pKE7xoQREbd6SHvCjuw2QZHuVA6YwDcsBE8,53
45
45
  jinns/utils/_containers.py,sha256=YShcrPKfj5_I9mn3NMAS4Ea9MhhyL7fjv0e3MRbITHg,1837
@@ -47,9 +47,9 @@ jinns/utils/_types.py,sha256=jl_91HtcrtE6UHbdTrRI8iUmr2kBUL0oP0UNIKhAXYw,1170
47
47
  jinns/utils/_utils.py,sha256=M7NXX9ok-BkH5o_xo74PB1_Cc8XiDipSl51rq82dTH4,2821
48
48
  jinns/validation/__init__.py,sha256=FTyUO-v1b8Tv-FDSQsntrH7zl9E0ENexqKMT_dFRkYo,124
49
49
  jinns/validation/_validation.py,sha256=8p6sMKiBAvA6JNm65hjkMj0997LJ0BkyCREEh0AnPVE,4803
50
- jinns-1.5.0.dist-info/licenses/AUTHORS,sha256=7NwCj9nU-HNG1asvy4qhQ2w7oZHrn-Lk5_wK_Ve7a3M,80
51
- jinns-1.5.0.dist-info/licenses/LICENSE,sha256=BIAkGtXB59Q_BG8f6_OqtQ1BHPv60ggE9mpXJYz2dRM,11337
52
- jinns-1.5.0.dist-info/METADATA,sha256=jEp__DP39B1HiTYVhtVcWKPmzS22kSUD6jNVSmHFh8g,5314
53
- jinns-1.5.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
54
- jinns-1.5.0.dist-info/top_level.txt,sha256=RXbkr2hzy8WBE8aiRyrJYFqn3JeMJIhMdybLjjLTB9c,6
55
- jinns-1.5.0.dist-info/RECORD,,
50
+ jinns-1.5.1.dist-info/licenses/AUTHORS,sha256=7NwCj9nU-HNG1asvy4qhQ2w7oZHrn-Lk5_wK_Ve7a3M,80
51
+ jinns-1.5.1.dist-info/licenses/LICENSE,sha256=BIAkGtXB59Q_BG8f6_OqtQ1BHPv60ggE9mpXJYz2dRM,11337
52
+ jinns-1.5.1.dist-info/METADATA,sha256=K7Aii5ivFcczIwLlQtCPqzMJEfW86D1yW1q7qMtvWPE,5314
53
+ jinns-1.5.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
54
+ jinns-1.5.1.dist-info/top_level.txt,sha256=RXbkr2hzy8WBE8aiRyrJYFqn3JeMJIhMdybLjjLTB9c,6
55
+ jinns-1.5.1.dist-info/RECORD,,
File without changes