jinns 1.4.0__py3-none-any.whl → 1.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
jinns/__init__.py CHANGED
@@ -1,10 +1,3 @@
1
- # import jinns.data
2
- # import jinns.loss
3
- # import jinns.solver
4
- # import jinns.utils
5
- # import jinns.experimental
6
- # import jinns.parameters
7
- # import jinns.plot
8
1
  from jinns import data as data
9
2
  from jinns import loss as loss
10
3
  from jinns import solver as solver
@@ -16,3 +9,10 @@ from jinns import nn as nn
16
9
  from jinns.solver._solve import solve
17
10
 
18
11
  __all__ = ["nn", "solve"]
12
+
13
import warnings

# Process-wide filter: silence any warning whose message starts with
# "Using `field(init=False)`" (Python matches `message` as a regex against
# the start of the message, case-insensitively). Presumably this targets the
# warning raised for module fields declared with `field(init=False)` —
# NOTE(review): confirm the emitting library.
warnings.filterwarnings("ignore", message=r"Using `field\(init=False\)`")
@@ -7,8 +7,10 @@ from __future__ import (
7
7
  ) # https://docs.python.org/3/library/typing.html#constant
8
8
  import warnings
9
9
  import equinox as eqx
10
+ import numpy as np
10
11
  import jax
11
12
  import jax.numpy as jnp
13
+ from scipy.stats import qmc
12
14
  from jaxtyping import Key, Array, Float
13
15
  from jinns.data._Batchs import PDENonStatioBatch
14
16
  from jinns.data._utils import (
@@ -65,11 +67,13 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
65
67
  The minimum value of the time domain to consider
66
68
  tmax : float
67
69
  The maximum value of the time domain to consider
68
- method : str, default="uniform"
70
+ method : Literal["uniform", "grid", "sobol", "halton"], default="uniform"
69
71
  Either `grid` or `uniform`, default is `uniform`.
70
72
  The method that generates the `nt` time points. `grid` means
71
73
  regularly spaced points over the domain. `uniform` means uniformly
72
- sampled points over the domain
74
+ sampled points over the domain.
75
+ **Note** that Sobol and Halton approaches use scipy modules and will not
76
+ be JIT compatible.
73
77
  rar_parameters : Dict[str, int], default=None
74
78
  Defaults to None: do not use Residual Adaptative Resampling.
75
79
  Otherwise a dictionary with keys
@@ -150,9 +154,11 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
150
154
  elif self.method == "uniform":
151
155
  self.key, domain_times = self.generate_time_data(self.key, self.n)
152
156
  self.domain = jnp.concatenate([domain_times, self.omega], axis=1)
157
+ elif self.method in ["sobol", "halton"]:
158
+ self.key, self.domain = self.qmc_in_time_omega_domain(self.key, self.n)
153
159
  else:
154
160
  raise ValueError(
155
- f'Bad value for method. Got {self.method}, expected "grid" or "uniform"'
161
+ f'Bad value for method. Got {self.method}, expected "grid" or "uniform" or "sobol" or "halton"'
156
162
  )
157
163
 
158
164
  if self.domain_batch_size is None:
@@ -182,21 +188,28 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
182
188
  "number of points per facets (nb//2*self.dim)"
183
189
  " cannot be lower than border batch size"
184
190
  )
185
- self.key, boundary_times = self.generate_time_data(
186
- self.key, self.nb // (2 * self.dim)
187
- )
188
- boundary_times = boundary_times.reshape(-1, 1, 1)
189
- boundary_times = jnp.repeat(
190
- boundary_times, self.omega_border.shape[-1], axis=2
191
- )
192
- if self.dim == 1:
193
- self.border = make_cartesian_product(
194
- boundary_times, self.omega_border[None, None]
191
+ if self.method in ["grid", "uniform"]:
192
+ self.key, boundary_times = self.generate_time_data(
193
+ self.key, self.nb // (2 * self.dim)
195
194
  )
195
+ boundary_times = boundary_times.reshape(-1, 1, 1)
196
+ boundary_times = jnp.repeat(
197
+ boundary_times, self.omega_border.shape[-1], axis=2
198
+ )
199
+ if self.dim == 1:
200
+ self.border = make_cartesian_product(
201
+ boundary_times, self.omega_border[None, None]
202
+ )
203
+ else:
204
+ self.border = jnp.concatenate(
205
+ [boundary_times, self.omega_border], axis=1
206
+ )
196
207
  else:
197
- self.border = jnp.concatenate(
198
- [boundary_times, self.omega_border], axis=1
208
+ self.key, self.border = self.qmc_in_time_omega_border_domain(
209
+ self.key,
210
+ self.nb, # type: ignore (see inside the fun)
199
211
  )
212
+
200
213
  if self.border_batch_size is None:
201
214
  self.curr_border_idx = 0
202
215
  else:
@@ -209,14 +222,30 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
209
222
  self.curr_border_idx = 0
210
223
 
211
224
  if self.ni is not None:
212
- perfect_sq = int(jnp.round(jnp.sqrt(self.ni)) ** 2)
213
- if self.ni != perfect_sq:
214
- warnings.warn(
215
- "Grid sampling is requested in dimension 2 with a non"
216
- f" perfect square dataset size (self.ni = {self.ni})."
217
- f" Modifying self.ni to self.ni = {perfect_sq}."
225
+ if self.method == "grid":
226
+ perfect_sq = int(jnp.round(jnp.sqrt(self.ni)) ** 2)
227
+ if self.ni != perfect_sq:
228
+ warnings.warn(
229
+ "Grid sampling is requested in dimension 2 with a non"
230
+ f" perfect square dataset size (self.ni = {self.ni})."
231
+ f" Modifying self.ni to self.ni = {perfect_sq}."
232
+ )
233
+ self.ni = perfect_sq
234
+ if self.method in ["sobol", "halton"]:
235
+ log2_n = jnp.log2(self.ni)
236
+ lower_pow = 2 ** jnp.floor(log2_n)
237
+ higher_pow = 2 ** jnp.ceil(log2_n)
238
+ closest_two_power = (
239
+ lower_pow
240
+ if (self.ni - lower_pow) < (higher_pow - self.ni)
241
+ else higher_pow
218
242
  )
219
- self.ni = perfect_sq
243
+ if self.n != closest_two_power:
244
+ warnings.warn(
245
+ f"QuasiMonteCarlo sampling with {self.method} requires sample size to be a power fo 2."
246
+ f"Modfiying self.n from {self.ni} to {closest_two_power}.",
247
+ )
248
+ self.ni = int(closest_two_power)
220
249
  self.key, self.initial = self.generate_omega_data(
221
250
  self.key, data_size=self.ni
222
251
  )
@@ -245,16 +274,115 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
245
274
  if self.method == "grid":
246
275
  partial_times = (self.tmax - self.tmin) / nt
247
276
  return key, jnp.arange(self.tmin, self.tmax, partial_times)[:, None]
248
- if self.method == "uniform":
277
+ elif self.method in ["uniform", "sobol", "halton"]:
249
278
  return key, self.sample_in_time_domain(subkey, nt)
250
279
  raise ValueError("Method " + self.method + " is not implemented.")
251
280
 
252
def sample_in_time_domain(self, key: Key, nt: int) -> Float[Array, " nt 1"]:
    """
    Draw `nt` time points independently and uniformly in `[tmin, tmax)`.

    Returns a column array of shape `(nt, 1)`.
    """
    lower, upper = self.tmin, self.tmax
    return jax.random.uniform(key, shape=(nt, 1), minval=lower, maxval=upper)
283
+
284
+ def qmc_in_time_omega_domain(
285
+ self, key: Key, sample_size: int
286
+ ) -> tuple[Key, Float[Array, "n 1+dim"]]:
287
+ """
288
+ Because in Quasi-Monte Carlo sampling we cannot concatenate two vectors generated independently
289
+ We generate time and omega samples jointly
290
+ """
291
+ key, subkey = jax.random.split(key, 2)
292
+ qmc_generator = qmc.Sobol if self.method == "sobol" else qmc.Halton
293
+ sampler = qmc_generator(
294
+ d=self.dim + 1, scramble=True, rng=np.random.default_rng(np.uint32(subkey))
295
+ )
296
+ samples = sampler.random(n=sample_size)
297
+ samples[:, 1:] = qmc.scale(
298
+ samples[:, 1:], l_bounds=self.min_pts, u_bounds=self.max_pts
299
+ ) # We scale omega domain to be in (min_pts, max_pts)
300
+ return key, jnp.array(samples)
301
+
302
+ def qmc_in_time_omega_border_domain(
303
+ self, key: Key, sample_size: int | None = None
304
+ ) -> tuple[Key, Float[Array, "n 1+dim"]] | None:
305
+ """
306
+ For each facet of the border we generate Quasi-MonteCarlo sequences jointy with time.
307
+
308
+ We need to do some type ignore in this function because we have lost
309
+ the type narrowing from post_init, type checkers only narrow at function level and because we cannot narrow a class attribute.
310
+ """
311
+ qmc_generator = qmc.Sobol if self.method == "sobol" else qmc.Halton
312
+ sample_size = self.nb if sample_size is None else sample_size
313
+ if sample_size is None:
314
+ return None
315
+ if self.dim == 1:
316
+ key, subkey = jax.random.split(key, 2)
317
+ qmc_seq = qmc_generator(
318
+ d=1, scramble=True, rng=np.random.default_rng(np.uint32(subkey))
319
+ )
320
+ boundary_times = jnp.array(
321
+ qmc_seq.random(self.nb // (2 * self.dim)) # type: ignore
322
+ )
323
+ boundary_times = boundary_times.reshape(-1, 1, 1)
324
+ boundary_times = jnp.repeat(
325
+ boundary_times,
326
+ self.omega_border.shape[-1], # type: ignore
327
+ axis=2,
328
+ )
329
+ return key, make_cartesian_product(
330
+ boundary_times,
331
+ self.omega_border[None, None], # type: ignore
332
+ )
333
+ if self.dim == 2:
334
+ # currently hard-coded the 4 edges for d==2
335
+ # TODO : find a general & efficient way to sample from the border
336
+ # (facets) of the hypercube in general dim.
337
+ key, *subkeys = jax.random.split(key, 5)
338
+ facet_n = sample_size // (2 * self.dim)
339
+
340
+ def generate_qmc_sample(key, min_val, max_val):
341
+ qmc_seq = qmc_generator(
342
+ d=2,
343
+ scramble=True,
344
+ rng=np.random.default_rng(np.uint32(key)),
345
+ )
346
+ u = qmc_seq.random(n=facet_n)
347
+ u[:, 1:2] = qmc.scale(u[:, 1:2], l_bounds=min_val, u_bounds=max_val)
348
+ return jnp.array(u)
349
+
350
+ xmin_sample = generate_qmc_sample(
351
+ subkeys[0], self.min_pts[1], self.max_pts[1]
352
+ ) # [t,x,y]
353
+ xmin = jnp.hstack(
354
+ [
355
+ xmin_sample[:, 0:1],
356
+ self.min_pts[0] * jnp.ones((facet_n, 1)),
357
+ xmin_sample[:, 1:2],
358
+ ]
359
+ )
360
+ xmax_sample = generate_qmc_sample(
361
+ subkeys[1], self.min_pts[1], self.max_pts[1]
362
+ )
363
+ xmax = jnp.hstack(
364
+ [
365
+ xmax_sample[:, 0:1],
366
+ self.max_pts[0] * jnp.ones((facet_n, 1)),
367
+ xmax_sample[:, 1:2],
368
+ ]
369
+ )
370
+ ymin = jnp.hstack(
371
+ [
372
+ generate_qmc_sample(subkeys[2], self.min_pts[0], self.max_pts[0]),
373
+ self.min_pts[1] * jnp.ones((facet_n, 1)),
374
+ ]
375
+ )
376
+ ymax = jnp.hstack(
377
+ [
378
+ generate_qmc_sample(subkeys[3], self.min_pts[0], self.max_pts[0]),
379
+ self.max_pts[1] * jnp.ones((facet_n, 1)),
380
+ ]
381
+ )
382
+ return key, jnp.stack([xmin, xmax, ymin, ymax], axis=-1)
383
+ raise NotImplementedError(
384
+ "Generation of the border of a cube in dimension > 2 is not "
385
+ + f"implemented yet. You are asking for generation in dimension d={self.dim}."
258
386
  )
259
387
 
260
388
  def _get_domain_operands(
@@ -8,8 +8,11 @@ from __future__ import (
8
8
  import warnings
9
9
  import equinox as eqx
10
10
  import jax
11
+ import numpy as np
11
12
  import jax.numpy as jnp
13
+ from scipy.stats import qmc
12
14
  from jaxtyping import Key, Array, Float
15
+ from typing import Literal
13
16
  from jinns.data._Batchs import PDEStatioBatch
14
17
  from jinns.data._utils import _check_and_set_rar_parameters, _reset_or_increment
15
18
  from jinns.data._AbstractDataGenerator import AbstractDataGenerator
@@ -50,11 +53,13 @@ class CubicMeshPDEStatio(AbstractDataGenerator):
50
53
  A tuple of maximum values of the domain along each dimension. For a sampling
51
54
  in `n` dimension, this represents $(x_{1, max}, x_{2,max}, ...,
52
55
  x_{n,max})$
53
- method : str, default="uniform"
54
- Either `grid` or `uniform`, default is `uniform`.
56
+ method : Literal["grid", "uniform", "sobol", "halton"], default="uniform"
57
+ Either "grid", "uniform", "sobol" or "halton", default is `uniform`.
55
58
  The method that generates the `nt` time points. `grid` means
56
59
  regularly spaced points over the domain. `uniform` means uniformly
57
- sampled points over the domain
60
+ sampled points over the domain.
61
+ **Note** that Sobol and Halton approaches use scipy modules and will not
62
+ be JIT compatible.
58
63
  rar_parameters : dict[str, int], default=None
59
64
  Defaults to None: do not use Residual Adaptative Resampling.
60
65
  Otherwise a dictionary with keys
@@ -94,7 +99,7 @@ class CubicMeshPDEStatio(AbstractDataGenerator):
94
99
  # shape in jax.lax.dynamic_slice
95
100
  min_pts: tuple[float, ...] = eqx.field(kw_only=True)
96
101
  max_pts: tuple[float, ...] = eqx.field(kw_only=True)
97
- method: str = eqx.field(
102
+ method: Literal["grid", "uniform", "sobol", "halton"] = eqx.field(
98
103
  kw_only=True, static=True, default_factory=lambda: "uniform"
99
104
  )
100
105
  rar_parameters: dict[str, int] = eqx.field(kw_only=True, default=None)
@@ -132,6 +137,22 @@ class CubicMeshPDEStatio(AbstractDataGenerator):
132
137
  )
133
138
  self.n = perfect_sq
134
139
 
140
+ if self.method in ["sobol", "halton"]:
141
+ log2_n = jnp.log2(self.n)
142
+ lower_pow = 2 ** jnp.floor(log2_n)
143
+ higher_pow = 2 ** jnp.ceil(log2_n)
144
+ closest_two_power = (
145
+ lower_pow
146
+ if (self.n - lower_pow) < (higher_pow - self.n)
147
+ else higher_pow
148
+ )
149
+ if self.n != closest_two_power:
150
+ warnings.warn(
151
+ f"QuasiMonteCarlo sampling with {self.method} requires sample size to be a power fo 2."
152
+ f"Modfiying self.n from {self.n} to {closest_two_power}.",
153
+ )
154
+ self.n = int(closest_two_power)
155
+
135
156
  if self.omega_batch_size is None:
136
157
  self.curr_omega_idx = 0
137
158
  else:
@@ -176,24 +197,48 @@ class CubicMeshPDEStatio(AbstractDataGenerator):
176
def sample_in_omega_domain(
    self, keys: Key, sample_size: int
) -> Float[Array, " n dim"]:
    """
    Sample `sample_size` points in the hyper-rectangle `[min_pts, max_pts]`.

    With `method == "uniform"`, each coordinate is drawn independently and
    uniformly. `keys` is a single key when `dim == 1`, and an indexable
    collection with one key per dimension otherwise. Any other method is
    delegated to `self._qmc_in_omega_domain(keys, sample_size)`.
    """
    if self.method != "uniform":
        return self._qmc_in_omega_domain(keys, sample_size)

    if self.dim == 1:
        return jax.random.uniform(
            keys,
            shape=(sample_size, 1),
            minval=self.min_pts[0],
            maxval=self.max_pts[0],
        )

    per_axis = [
        jax.random.uniform(
            keys[axis],
            (sample_size, 1),
            minval=self.min_pts[axis],
            maxval=self.max_pts[axis],
        )
        for axis in range(self.dim)
    ]
    return jnp.concatenate(per_axis, axis=-1)
221
+
222
+ def _qmc_in_omega_domain(
223
+ self, subkey: Key, sample_size: int
224
+ ) -> Float[Array, "n dim"]:
225
+ qmc_generator = qmc.Sobol if self.method == "sobol" else qmc.Halton
179
226
  if self.dim == 1:
180
- xmin, xmax = self.min_pts[0], self.max_pts[0]
181
- return jax.random.uniform(
182
- keys, shape=(sample_size, 1), minval=xmin, maxval=xmax
227
+ qmc_seq = qmc_generator(
228
+ d=self.dim,
229
+ scramble=True,
230
+ rng=np.random.default_rng(np.uint32(subkey)),
183
231
  )
184
- # keys = jax.random.split(key, self.dim)
185
- return jnp.concatenate(
186
- [
187
- jax.random.uniform(
188
- keys[i],
189
- (sample_size, 1),
190
- minval=self.min_pts[i],
191
- maxval=self.max_pts[i],
192
- )
193
- for i in range(self.dim)
194
- ],
195
- axis=-1,
232
+ u = qmc_seq.random(n=sample_size)
233
+ return jnp.array(
234
+ qmc.scale(u, l_bounds=self.min_pts[0], u_bounds=self.max_pts[0])
235
+ )
236
+ sampler = qmc.Sobol(
237
+ d=self.dim, scramble=True, rng=np.random.default_rng(np.uint32(subkey))
196
238
  )
239
+ samples = sampler.random(n=sample_size)
240
+ samples = qmc.scale(samples, l_bounds=self.min_pts, u_bounds=self.max_pts)
241
+ return jnp.array(samples)
197
242
 
198
243
  def sample_in_omega_border_domain(
199
244
  self, keys: Key, sample_size: int | None = None
@@ -260,6 +305,62 @@ class CubicMeshPDEStatio(AbstractDataGenerator):
260
305
  + f"implemented yet. You are asking for generation in dimension d={self.dim}."
261
306
  )
262
307
 
308
+ def qmc_in_omega_border_domain(
309
+ self, keys: Key, sample_size: int | None = None
310
+ ) -> Float[Array, " 1 2"] | Float[Array, " (nb//4) 2 4"] | None:
311
+ qmc_generator = qmc.Sobol if self.method == "sobol" else qmc.Halton
312
+ sample_size = self.nb if sample_size is None else sample_size
313
+ if sample_size is None:
314
+ return None
315
+ if self.dim == 1:
316
+ xmin = self.min_pts[0]
317
+ xmax = self.max_pts[0]
318
+ return jnp.array([xmin, xmax]).astype(float)
319
+ if self.dim == 2:
320
+ # currently hard-coded the 4 edges for d==2
321
+ # TODO : find a general & efficient way to sample from the border
322
+ # (facets) of the hypercube in general dim.
323
+ facet_n = sample_size // (2 * self.dim)
324
+
325
+ def generate_qmc_sample(key, min_val, max_val):
326
+ qmc_seq = qmc_generator(
327
+ d=1,
328
+ scramble=True,
329
+ rng=np.random.default_rng(np.uint32(key)),
330
+ )
331
+ u = qmc_seq.random(n=facet_n)
332
+ return jnp.array(qmc.scale(u, l_bounds=min_val, u_bounds=max_val))
333
+
334
+ xmin = jnp.hstack(
335
+ [
336
+ self.min_pts[0] * jnp.ones((facet_n, 1)),
337
+ generate_qmc_sample(keys[0], self.min_pts[1], self.max_pts[1]),
338
+ ]
339
+ )
340
+ xmax = jnp.hstack(
341
+ [
342
+ self.max_pts[0] * jnp.ones((facet_n, 1)),
343
+ generate_qmc_sample(keys[1], self.min_pts[1], self.max_pts[1]),
344
+ ]
345
+ )
346
+ ymin = jnp.hstack(
347
+ [
348
+ generate_qmc_sample(keys[2], self.min_pts[0], self.max_pts[0]),
349
+ self.min_pts[1] * jnp.ones((facet_n, 1)),
350
+ ]
351
+ )
352
+ ymax = jnp.hstack(
353
+ [
354
+ generate_qmc_sample(keys[3], self.min_pts[0], self.max_pts[0]),
355
+ self.max_pts[1] * jnp.ones((facet_n, 1)),
356
+ ]
357
+ )
358
+ return jnp.stack([xmin, xmax, ymin, ymax], axis=-1)
359
+ raise NotImplementedError(
360
+ "Generation of the border of a cube in dimension > 2 is not "
361
+ + f"implemented yet. You are asking for generation in dimension d={self.dim}."
362
+ )
363
+
263
364
  def generate_omega_data(
264
365
  self, key: Key, data_size: int | None = None
265
366
  ) -> tuple[
@@ -290,8 +391,8 @@ class CubicMeshPDEStatio(AbstractDataGenerator):
290
391
  )
291
392
  xyz_ = [a.reshape((data_size, 1)) for a in xyz_]
292
393
  omega = jnp.concatenate(xyz_, axis=-1)
293
- elif self.method == "uniform":
294
- if self.dim == 1:
394
+ elif self.method in ["uniform", "sobol", "halton"]:
395
+ if self.dim == 1 or self.method in ["sobol", "halton"]:
295
396
  key, subkeys = jax.random.split(key, 2)
296
397
  else:
297
398
  key, *subkeys = jax.random.split(key, self.dim + 1)
@@ -317,10 +418,17 @@ class CubicMeshPDEStatio(AbstractDataGenerator):
317
418
  key, *subkeys = jax.random.split(key, 5)
318
419
  else:
319
420
  subkeys = None
320
- omega_border = self.sample_in_omega_border_domain(
321
- subkeys, sample_size=data_size
322
- )
323
421
 
422
+ if self.method in ["grid", "uniform"]:
423
+ omega_border = self.sample_in_omega_border_domain(
424
+ subkeys, sample_size=data_size
425
+ )
426
+ elif self.method in ["sobol", "halton"]:
427
+ omega_border = self.qmc_in_omega_border_domain(
428
+ subkeys, sample_size=data_size
429
+ )
430
+ else:
431
+ raise ValueError("Method " + self.method + " is not implemented.")
324
432
  return key, omega_border
325
433
 
326
434
  def _get_omega_operands(
@@ -6,11 +6,13 @@ from __future__ import (
6
6
  annotations,
7
7
  ) # https://docs.python.org/3/library/typing.html#constant
8
8
 
9
+ import warnings
9
10
  import abc
10
11
  from functools import partial
11
12
  from typing import Callable, TYPE_CHECKING, ClassVar, Generic, TypeVar
12
13
  import equinox as eqx
13
14
  from jaxtyping import Float, Array
15
+ import jax.numpy as jnp
14
16
 
15
17
 
16
18
  # See : https://docs.kidger.site/equinox/api/module/advanced_fields/#equinox.AbstractClassVar--known-issues
@@ -70,14 +72,31 @@ class DynamicLoss(eqx.Module, Generic[InputDim]):
70
72
  A value can be missing, in this case there is no heterogeneity (=None).
71
73
  Default None, meaning there is no heterogeneity in the equation
72
74
  parameters.
75
+ vectorial_dyn_loss_ponderation : Float[Array, " dim"], default=None
76
+ Add a different ponderation weight to each of the dimension to the
77
+ dynamic loss. This array must have the same dimension as the output of
78
+ the dynamic loss equation or an error is raised. Default is None which
79
+ means that a ponderation of 1 is applied on each dimension.
80
+ `vectorial_dyn_loss_ponderation`
81
+ is different from loss weights, which are attributes of Loss
82
+ classes and which implement scalar (and possibly dynamic)
83
+ ponderations for each term of the total loss.
84
+ `vectorial_dyn_loss_ponderation` can be used with loss weights.
73
85
  """
74
86
 
75
87
  _eq_type = AbstractClassVar[str] # class variable denoting the type of
76
88
  # differential equation
77
89
  Tmax: Float = eqx.field(kw_only=True, default=1)
78
- eq_params_heterogeneity: dict[str, Callable | None] = eqx.field(
90
+ eq_params_heterogeneity: dict[str, Callable | None] | None = eqx.field(
79
91
  kw_only=True, default=None, static=True
80
92
  )
93
+ vectorial_dyn_loss_ponderation: Float[Array, " dim"] | None = eqx.field(
94
+ kw_only=True, default=None
95
+ )
96
+
97
def __post_init__(self):
    """
    Fill in the default for `vectorial_dyn_loss_ponderation`.

    When no ponderation was given (`None`), it is replaced by the 0-d array
    `jnp.array(1.0)`, i.e. a neutral scalar weight. An explicitly provided
    ponderation is left untouched.
    """
    if self.vectorial_dyn_loss_ponderation is not None:
        return
    self.vectorial_dyn_loss_ponderation = jnp.array(1.0)
81
100
 
82
101
  def _eval_heterogeneous_parameters(
83
102
  self,
@@ -110,12 +129,21 @@ class DynamicLoss(eqx.Module, Generic[InputDim]):
110
129
  u: AbstractPINN,
111
130
  params: Params[Array],
112
131
  ) -> float:
113
- evaluation = self.equation(inputs, u, params)
132
+ evaluation = self.vectorial_dyn_loss_ponderation * self.equation(
133
+ inputs, u, params
134
+ )
114
135
  if len(evaluation.shape) == 0:
115
136
  raise ValueError(
116
137
  "The output of dynamic loss must be vectorial, "
117
138
  "i.e. of shape (d,) with d >= 1"
118
139
  )
140
+ if len(evaluation.shape) > 1:
141
+ warnings.warn(
142
+ "Return value from DynamicLoss' equation has more "
143
+ "than one dimension. This is in general a mistake (probably from "
144
+ "an unfortunate broadcast in jnp.array computations) resulting in "
145
+ "bad reduction operations in losses."
146
+ )
119
147
  return evaluation
120
148
 
121
149
  @abc.abstractmethod