jinns 1.0.0__py3-none-any.whl → 1.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
jinns/loss/_loss_utils.py CHANGED
@@ -16,10 +16,9 @@ from jaxtyping import Float, Array, PyTree
  from jinns.loss._boundary_conditions import (
      _compute_boundary_loss,
  )
- from jinns.utils._utils import _check_user_func_return, _get_grid
- from jinns.data._DataGenerators import (
-     append_obs_batch,
- )
+ from jinns.utils._utils import _subtract_with_check, get_grid
+ from jinns.data._DataGenerators import append_obs_batch, make_cartesian_product
+ from jinns.parameters._params import _get_vmap_in_axes_params
  from jinns.utils._pinn import PINN
  from jinns.utils._spinn import SPINN
  from jinns.utils._hyperpinn import HYPERPINN
@@ -33,7 +32,11 @@ if TYPE_CHECKING:
  def dynamic_loss_apply(
      dyn_loss: DynamicLoss,
      u: eqx.Module,
-     batches: ODEBatch | PDEStatioBatch | PDENonStatioBatch,
+     batch: (
+         Float[Array, "batch_size 1"]
+         | Float[Array, "batch_size dim"]
+         | Float[Array, "batch_size 1+dim"]
+     ),
      params: Params | ParamsDict,
      vmap_axes: tuple[int | None, ...],
      loss_weight: float | Float[Array, "dyn_loss_dimension"],
@@ -45,16 +48,16 @@ def dynamic_loss_apply(
      """
      if u_type == PINN or u_type == HYPERPINN or isinstance(u, (PINN, HYPERPINN)):
          v_dyn_loss = vmap(
-             lambda *args: dyn_loss(
-                 *args[:-1], u, args[-1] # we must place the params at the end
+             lambda batch, params: dyn_loss(
+                 batch, u, params # we must place the params at the end
              ),
              vmap_axes,
              0,
          )
-         residuals = v_dyn_loss(*batches, params)
+         residuals = v_dyn_loss(batch, params)
          mse_dyn_loss = jnp.mean(jnp.sum(loss_weight * residuals**2, axis=-1))
      elif u_type == SPINN or isinstance(u, SPINN):
-         residuals = dyn_loss(*batches, u, params)
+         residuals = dyn_loss(batch, u, params)
          mse_dyn_loss = jnp.mean(jnp.sum(loss_weight * residuals**2, axis=-1))
      else:
          raise ValueError(f"Bad type for u. Got {type(u)}, expected PINN or SPINN")
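With this change, dynamic_loss_apply no longer unpacks a batches tuple: it receives a single array whose rows are t, x, or the concatenation (t, x), and the dynamic loss is vmapped over that one leading axis together with the (possibly batched) parameters. Below is a minimal sketch of the new calling convention using a toy pointwise residual in place of a real DynamicLoss; the residual, the shapes and the parameter pytree are illustrative assumptions, not jinns code.

    import jax.numpy as jnp
    from jax import vmap

    # toy stand-in for a PINN: maps one concatenated (t, x) row to a value
    u = lambda t_x, params: jnp.sin(t_x[1:])

    # toy pointwise residual exercising the dyn_loss(batch, u, params) signature
    def toy_dyn_loss(t_x, u, params):
        return u(t_x, params) - params["decay"] * t_x[1:]

    batch = jnp.concatenate(
        [jnp.zeros((8, 1)), jnp.linspace(0.0, 1.0, 8).reshape(8, 1)], axis=-1
    )  # shape (batch_size, 1 + dim)
    params = {"decay": 0.5}

    v_dyn_loss = vmap(lambda b, p: toy_dyn_loss(b, u, p), (0, None), 0)
    residuals = v_dyn_loss(batch, params)  # shape (batch_size, dim)
    mse_dyn_loss = jnp.mean(jnp.sum(residuals**2, axis=-1))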
@@ -64,35 +67,49 @@ def dynamic_loss_apply(

  def normalization_loss_apply(
      u: eqx.Module,
-     batches: ODEBatch | PDEStatioBatch | PDENonStatioBatch,
+     batches: (
+         tuple[Float[Array, "nb_norm_samples dim"]]
+         | tuple[
+             Float[Array, "nb_norm_time_slices 1"], Float[Array, "nb_norm_samples dim"]
+         ]
+     ),
      params: Params | ParamsDict,
-     vmap_axes: tuple[int | None, ...],
+     vmap_axes_params: tuple[int | None, ...],
      int_length: int,
      loss_weight: float,
  ) -> float:
-     # TODO merge stationary and non stationary cases
+     """
+     Note the squeezing on each result. We expect unidimensional *PINN since
+     they represent probability distributions
+     """
      if isinstance(u, (PINN, HYPERPINN)):
          if len(batches) == 1:
              v_u = vmap(
-                 lambda *args: u(*args)[u.slice_solution],
-                 vmap_axes,
+                 lambda b: u(b)[u.slice_solution],
+                 (0,) + vmap_axes_params,
                  0,
              )
-             mse_norm_loss = loss_weight * jnp.mean(
-                 jnp.abs(jnp.mean(v_u(*batches, params), axis=-1) * int_length - 1) ** 2
+             res = v_u(*batches, params)
+             mse_norm_loss = loss_weight * (
+                 jnp.abs(jnp.mean(res.squeeze()) * int_length - 1) ** 2
              )
          else:
+             # NOTE this cartesian product is costly
+             batches = make_cartesian_product(
+                 batches[0],
+                 batches[1],
+             ).reshape(batches[0].shape[0], batches[1].shape[0], -1)
              v_u = vmap(
                  vmap(
-                     lambda t, x, params_: u(t, x, params_),
-                     in_axes=(None, 0) + vmap_axes[2:],
+                     lambda t_x, params_: u(t_x, params_),
+                     in_axes=(0,) + vmap_axes_params,
                  ),
-                 in_axes=(0, None) + vmap_axes[2:],
+                 in_axes=(0,) + vmap_axes_params,
              )
-             res = v_u(*batches, params)
-             # the outer mean() below is for the times stamps
+             res = v_u(batches, params)
+             # Over all the times t, we perform a integration
              mse_norm_loss = loss_weight * jnp.mean(
-                 jnp.abs(jnp.mean(res, axis=(-2, -1)) * int_length - 1) ** 2
+                 jnp.abs(jnp.mean(res.squeeze(), axis=-1) * int_length - 1) ** 2
              )
      elif isinstance(u, SPINN):
          if len(batches) == 1:
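For non-stationary normalization, the time slices and spatial samples are now expanded into explicit (t, x) pairs before the nested vmap, instead of vmapping separately over t and x. The sketch below reproduces the shape of that cartesian product with plain jax.numpy; it is an illustrative stand-in for make_cartesian_product, whose implementation is not shown in this diff.

    import jax.numpy as jnp

    t = jnp.linspace(0.0, 1.0, 3).reshape(3, 1)    # (nb_norm_time_slices, 1)
    x = jnp.linspace(-1.0, 1.0, 4).reshape(4, 1)   # (nb_norm_samples, dim)

    # every (t, x) combination, reshaped as in the code above:
    # (nb_norm_time_slices, nb_norm_samples, 1 + dim)
    pairs = jnp.concatenate(
        [jnp.repeat(t, x.shape[0], axis=0), jnp.tile(x, (t.shape[0], 1))], axis=-1
    ).reshape(t.shape[0], x.shape[0], -1)

    # the nested vmap then evaluates u on each pair; for a density, the spatial
    # mean per time slice times int_length should stay close to 1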
@@ -101,8 +118,7 @@ def normalization_loss_apply(
                  loss_weight
                  * jnp.abs(
                      jnp.mean(
-                         jnp.mean(res, axis=-1),
-                         axis=tuple(range(res.ndim - 1)),
+                         res.squeeze(),
                      )
                      * int_length
                      - 1
@@ -112,12 +128,17 @@ def normalization_loss_apply(
          else:
              assert batches[1].shape[0] % batches[0].shape[0] == 0
              rep_t = batches[1].shape[0] // batches[0].shape[0]
-             res = u(jnp.repeat(batches[0], rep_t, axis=0), batches[1], params)
+             res = u(
+                 jnp.concatenate(
+                     [jnp.repeat(batches[0], rep_t, axis=0), batches[1]], axis=-1
+                 ),
+                 params,
+             )
              # the outer mean() below is for the times stamps
              mse_norm_loss = loss_weight * jnp.mean(
                  jnp.abs(
                      jnp.mean(
-                         jnp.mean(res, axis=-1),
+                         res.squeeze(),
                          axis=(d + 1 for d in range(res.ndim - 2)),
                      )
                      * int_length
@@ -140,23 +161,16 @@ def boundary_condition_apply(
      omega_boundary_dim: int,
      loss_weight: float | Float[Array, "boundary_cond_dim"],
  ) -> float:
+
+     vmap_in_axes = (0,) + _get_vmap_in_axes_params(batch.param_batch_dict, params)
+
      if isinstance(omega_boundary_fun, dict):
          # We must create the facet tree dictionary as we do not have the
          # enumerate from the for loop to pass the id integer
-         if (
-             isinstance(batch, PDEStatioBatch) and batch.border_batch.shape[-1] == 2
-         ) or (
-             isinstance(batch, PDENonStatioBatch)
-             and batch.times_x_border_batch.shape[-1] == 2
-         ):
+         if batch.border_batch.shape[-1] == 2:
              # 1D
              facet_tree = {"xmin": 0, "xmax": 1}
-         elif (
-             isinstance(batch, PDEStatioBatch) and batch.border_batch.shape[-1] == 4
-         ) or (
-             isinstance(batch, PDENonStatioBatch)
-             and batch.times_x_border_batch.shape[-1] == 4
-         ):
+         elif batch.border_batch.shape[-1] == 4:
              # 2D
              facet_tree = {"xmin": 0, "xmax": 1, "ymin": 2, "ymax": 3}
          else:
@@ -166,7 +180,10 @@ def boundary_condition_apply(
                  None
                  if c is None
                  else jnp.mean(
-                     loss_weight * _compute_boundary_loss(c, f, batch, u, params, fa, d)
+                     loss_weight
+                     * _compute_boundary_loss(
+                         c, f, batch, u, params, fa, d, vmap_in_axes
+                     )
                  )
              ),
              omega_boundary_condition,
@@ -180,10 +197,7 @@ def boundary_condition_apply(
          # Note that to keep the behaviour given in the comment above we neede
          # to specify is_leaf according to the note in the release of 0.4.29
      else:
-         if isinstance(batch, PDEStatioBatch):
-             facet_tuple = tuple(f for f in range(batch.border_batch.shape[-1]))
-         else:
-             facet_tuple = tuple(f for f in range(batch.times_x_border_batch.shape[-1]))
+         facet_tuple = tuple(f for f in range(batch.border_batch.shape[-1]))
          b_losses_by_facet = jax.tree_util.tree_map(
              lambda fa: jnp.mean(
                  loss_weight
@@ -195,6 +209,7 @@ def boundary_condition_apply(
                      params,
                      fa,
                      omega_boundary_dim,
+                     vmap_in_axes,
                  )
              ),
              facet_tuple,
@@ -225,8 +240,10 @@ def observations_loss_apply(
          mse_observation_loss = jnp.mean(
              jnp.sum(
                  loss_weight
-                 * (val - _check_user_func_return(observed_values, val.shape)) ** 2,
-                 # the reshape above avoids a potential missing (1,)
+                 * _subtract_with_check(
+                     observed_values, val, cause="user defined observed_values"
+                 )
+                 ** 2,
                  axis=-1,
              )
          )
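_subtract_with_check replaces the former _check_user_func_return followed by a plain subtraction. Its implementation is not part of this diff; the sketch below only illustrates the kind of shape-guarded subtraction the call sites suggest. The helper name (with a _sketch suffix), its signature and its error message are assumptions for illustration only.

    import jax.numpy as jnp

    def _subtract_with_check_sketch(expected, computed, cause=""):
        # hypothetical helper: refuse shapes that would silently broadcast, so a
        # user-supplied array cannot distort the MSE without warning
        expected = jnp.asarray(expected)
        if expected.shape != computed.shape:
            raise ValueError(
                f"{cause} has shape {expected.shape}, expected {computed.shape}"
            )
        return expected - computed

    observed_values = jnp.ones((16, 1))
    val = jnp.zeros((16, 1))
    res = _subtract_with_check_sketch(
        observed_values, val, cause="user defined observed_values"
    )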
@@ -243,33 +260,38 @@ def initial_condition_apply(
      params: Params | ParamsDict,
      vmap_axes: tuple[int | None, ...],
      initial_condition_fun: Callable,
-     n: int,
      loss_weight: float | Float[Array, "initial_condition_dimension"],
  ) -> float:
+     n = omega_batch.shape[0]
+     t0_omega_batch = jnp.concatenate([jnp.zeros((n, 1)), omega_batch], axis=1)
      if isinstance(u, (PINN, HYPERPINN)):
          v_u_t0 = vmap(
-             lambda x, params: initial_condition_fun(x) - u(jnp.zeros((1,)), x, params),
+             lambda t0_x, params: _subtract_with_check(
+                 initial_condition_fun(t0_x[1:]),
+                 u(t0_x, params),
+                 cause="Output of initial_condition_fun",
+             ),
              vmap_axes,
              0,
          )
-         res = v_u_t0(omega_batch, params) # NOTE take the tiled
+         res = v_u_t0(t0_omega_batch, params) # NOTE take the tiled
          # omega_batch (ie omega_batch_) to have the same batch
          # dimension as params to be able to vmap.
          # Recall that by convention:
          # param_batch_dict = times_batch_size * omega_batch_size
          mse_initial_condition = jnp.mean(jnp.sum(loss_weight * res**2, axis=-1))
      elif isinstance(u, SPINN):
-         values = lambda x: u(
-             jnp.repeat(jnp.zeros((1, 1)), n, axis=0),
-             x,
+         values = lambda t_x: u(
+             t_x,
              params,
          )[0]
-         omega_batch_grid = _get_grid(omega_batch)
-         v_ini = values(omega_batch)
-         ini = _check_user_func_return(
-             initial_condition_fun(omega_batch_grid), v_ini.shape
+         omega_batch_grid = get_grid(omega_batch)
+         v_ini = values(t0_omega_batch)
+         res = _subtract_with_check(
+             initial_condition_fun(omega_batch_grid),
+             v_ini,
+             cause="Output of initial_condition_fun",
          )
-         res = ini - v_ini
          mse_initial_condition = jnp.mean(jnp.sum(loss_weight * res**2, axis=-1))
      else:
          raise ValueError(f"Bad type for u. Got {type(u)}, expected PINN or SPINN")
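initial_condition_apply now infers n from omega_batch and prepends a column of zeros for the time coordinate, so the network is evaluated with the same concatenated (t, x) input used everywhere else. A minimal sketch of that construction, with illustrative array sizes:

    import jax.numpy as jnp

    omega_batch = jnp.linspace(-1.0, 1.0, 5).reshape(5, 1)  # (n, dim)
    n = omega_batch.shape[0]
    t0_omega_batch = jnp.concatenate([jnp.zeros((n, 1)), omega_batch], axis=1)
    # shape (n, 1 + dim): each row is (t=0, x), ready to be passed as t0_x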
@@ -297,12 +319,12 @@ def constraints_system_loss_apply(
      if isinstance(params_dict.nn_params, dict):

          def apply_u_constraint(
-             u_constraint, nn_params, loss_weights_for_u, obs_batch_u
+             u_constraint, nn_params, eq_params, loss_weights_for_u, obs_batch_u
          ):
              res_dict_for_u = u_constraint.evaluate(
                  Params(
                      nn_params=nn_params,
-                     eq_params=params_dict.eq_params,
+                     eq_params=eq_params,
                  ),
                  append_obs_batch(batch, obs_batch_u),
              )[1]
@@ -319,6 +341,11 @@ def constraints_system_loss_apply(
              apply_u_constraint,
              u_constraints_dict,
              params_dict.nn_params,
+             (
+                 params_dict.eq_params
+                 if params_dict.eq_params.keys() == params_dict.nn_params.keys()
+                 else {k: params_dict.eq_params for k in params_dict.nn_params.keys()}
+             ), # this manipulation is needed since we authorize eq_params not to have the same structure as nn_params in ParamsDict
              loss_weights_T,
              batch.obs_batch_dict,
              is_leaf=lambda x: (
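The last hunk broadcasts a shared eq_params pytree to one entry per network when nn_params is a dict whose keys do not match, so the tree_map over per-network constraints receives structurally aligned arguments. A small standalone illustration of that dict manipulation, with made-up parameter names:

    nn_params = {"u1": "weights of net 1", "u2": "weights of net 2"}
    eq_params = {"nu": 0.01}

    eq_params_per_net = (
        eq_params
        if eq_params.keys() == nn_params.keys()
        else {k: eq_params for k in nn_params.keys()}
    )
    # -> {"u1": {"nu": 0.01}, "u2": {"nu": 0.01}}, matching nn_params' structure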