jinns 1.1.0__py3-none-any.whl → 1.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
jinns/loss/_loss_utils.py CHANGED
@@ -16,13 +16,12 @@ from jaxtyping import Float, Array, PyTree
16
16
  from jinns.loss._boundary_conditions import (
17
17
  _compute_boundary_loss,
18
18
  )
19
- from jinns.utils._utils import _check_user_func_return, _get_grid
20
- from jinns.data._DataGenerators import (
21
- append_obs_batch,
22
- )
23
- from jinns.utils._pinn import PINN
24
- from jinns.utils._spinn import SPINN
25
- from jinns.utils._hyperpinn import HYPERPINN
19
+ from jinns.utils._utils import _subtract_with_check, get_grid
20
+ from jinns.data._DataGenerators import append_obs_batch, make_cartesian_product
21
+ from jinns.parameters._params import _get_vmap_in_axes_params
22
+ from jinns.nn._pinn import PINN
23
+ from jinns.nn._spinn import SPINN
24
+ from jinns.nn._hyperpinn import HyperPINN
26
25
  from jinns.data._Batchs import *
27
26
  from jinns.parameters._params import Params, ParamsDict
28
27
 
@@ -33,28 +32,32 @@ if TYPE_CHECKING:
33
32
  def dynamic_loss_apply(
34
33
  dyn_loss: DynamicLoss,
35
34
  u: eqx.Module,
36
- batches: ODEBatch | PDEStatioBatch | PDENonStatioBatch,
35
+ batch: (
36
+ Float[Array, "batch_size 1"]
37
+ | Float[Array, "batch_size dim"]
38
+ | Float[Array, "batch_size 1+dim"]
39
+ ),
37
40
  params: Params | ParamsDict,
38
41
  vmap_axes: tuple[int | None, ...],
39
42
  loss_weight: float | Float[Array, "dyn_loss_dimension"],
40
- u_type: PINN | HYPERPINN | None = None,
43
+ u_type: PINN | HyperPINN | None = None,
41
44
  ) -> float:
42
45
  """
43
46
  Sometimes when u is a lambda function a or dict we do not have access to
44
47
  its type here, hence the last argument
45
48
  """
46
- if u_type == PINN or u_type == HYPERPINN or isinstance(u, (PINN, HYPERPINN)):
49
+ if u_type == PINN or u_type == HyperPINN or isinstance(u, (PINN, HyperPINN)):
47
50
  v_dyn_loss = vmap(
48
- lambda *args: dyn_loss(
49
- *args[:-1], u, args[-1] # we must place the params at the end
51
+ lambda batch, params: dyn_loss(
52
+ batch, u, params # we must place the params at the end
50
53
  ),
51
54
  vmap_axes,
52
55
  0,
53
56
  )
54
- residuals = v_dyn_loss(*batches, params)
57
+ residuals = v_dyn_loss(batch, params)
55
58
  mse_dyn_loss = jnp.mean(jnp.sum(loss_weight * residuals**2, axis=-1))
56
59
  elif u_type == SPINN or isinstance(u, SPINN):
57
- residuals = dyn_loss(*batches, u, params)
60
+ residuals = dyn_loss(batch, u, params)
58
61
  mse_dyn_loss = jnp.mean(jnp.sum(loss_weight * residuals**2, axis=-1))
59
62
  else:
60
63
  raise ValueError(f"Bad type for u. Got {type(u)}, expected PINN or SPINN")
@@ -64,47 +67,65 @@ def dynamic_loss_apply(
64
67
 
65
68
  def normalization_loss_apply(
66
69
  u: eqx.Module,
67
- batches: ODEBatch | PDEStatioBatch | PDENonStatioBatch,
70
+ batches: (
71
+ tuple[Float[Array, "nb_norm_samples dim"]]
72
+ | tuple[
73
+ Float[Array, "nb_norm_time_slices 1"], Float[Array, "nb_norm_samples dim"]
74
+ ]
75
+ ),
68
76
  params: Params | ParamsDict,
69
- vmap_axes: tuple[int | None, ...],
70
- int_length: int,
77
+ vmap_axes_params: tuple[int | None, ...],
78
+ norm_weights: Float[Array, "nb_norm_samples"],
71
79
  loss_weight: float,
72
80
  ) -> float:
73
- # TODO merge stationary and non stationary cases
74
- if isinstance(u, (PINN, HYPERPINN)):
81
+ """
82
+ Note the squeezing on each result. We expect unidimensional *PINN since
83
+ they represent probability distributions
84
+ """
85
+ if isinstance(u, (PINN, HyperPINN)):
75
86
  if len(batches) == 1:
76
87
  v_u = vmap(
77
- lambda *args: u(*args)[u.slice_solution],
78
- vmap_axes,
88
+ lambda *b: u(*b)[u.slice_solution],
89
+ (0,) + vmap_axes_params,
79
90
  0,
80
91
  )
81
- mse_norm_loss = loss_weight * jnp.mean(
82
- jnp.abs(jnp.mean(v_u(*batches, params), axis=-1) * int_length - 1) ** 2
92
+ res = v_u(*batches, params)
93
+ assert res.shape[-1] == 1, "norm loss expects unidimensional *PINN"
94
+ # Monte-Carlo integration using importance sampling
95
+ mse_norm_loss = loss_weight * (
96
+ jnp.abs(jnp.mean(res.squeeze() * norm_weights) - 1) ** 2
83
97
  )
84
98
  else:
99
+ # NOTE this cartesian product is costly
100
+ batches = make_cartesian_product(
101
+ batches[0],
102
+ batches[1],
103
+ ).reshape(batches[0].shape[0], batches[1].shape[0], -1)
85
104
  v_u = vmap(
86
105
  vmap(
87
- lambda t, x, params_: u(t, x, params_),
88
- in_axes=(None, 0) + vmap_axes[2:],
106
+ lambda t_x, params_: u(t_x, params_),
107
+ in_axes=(0,) + vmap_axes_params,
89
108
  ),
90
- in_axes=(0, None) + vmap_axes[2:],
109
+ in_axes=(0,) + vmap_axes_params,
91
110
  )
92
- res = v_u(*batches, params)
93
- # the outer mean() below is for the times stamps
111
+ res = v_u(batches, params)
112
+ assert res.shape[-1] == 1, "norm loss expects unidimensional *PINN"
113
+ # For all times t, we perform an integration. Then we average the
114
+ # losses over times.
94
115
  mse_norm_loss = loss_weight * jnp.mean(
95
- jnp.abs(jnp.mean(res, axis=(-2, -1)) * int_length - 1) ** 2
116
+ jnp.abs(jnp.mean(res.squeeze() * norm_weights, axis=-1) - 1) ** 2
96
117
  )
97
118
  elif isinstance(u, SPINN):
98
119
  if len(batches) == 1:
99
120
  res = u(*batches, params)
121
+ assert res.shape[-1] == 1, "norm loss expects unidimensional *SPINN"
100
122
  mse_norm_loss = (
101
123
  loss_weight
102
124
  * jnp.abs(
103
125
  jnp.mean(
104
- jnp.mean(res, axis=-1),
105
- axis=tuple(range(res.ndim - 1)),
126
+ res.squeeze(),
106
127
  )
107
- * int_length
128
+ * norm_weights
108
129
  - 1
109
130
  )
110
131
  ** 2
@@ -112,15 +133,21 @@ def normalization_loss_apply(
112
133
  else:
113
134
  assert batches[1].shape[0] % batches[0].shape[0] == 0
114
135
  rep_t = batches[1].shape[0] // batches[0].shape[0]
115
- res = u(jnp.repeat(batches[0], rep_t, axis=0), batches[1], params)
136
+ res = u(
137
+ jnp.concatenate(
138
+ [jnp.repeat(batches[0], rep_t, axis=0), batches[1]], axis=-1
139
+ ),
140
+ params,
141
+ )
142
+ assert res.shape[-1] == 1, "norm loss expects unidimensional *SPINN"
116
143
  # the outer mean() below is for the times stamps
117
144
  mse_norm_loss = loss_weight * jnp.mean(
118
145
  jnp.abs(
119
146
  jnp.mean(
120
- jnp.mean(res, axis=-1),
147
+ res.squeeze(),
121
148
  axis=(d + 1 for d in range(res.ndim - 2)),
122
149
  )
123
- * int_length
150
+ * norm_weights
124
151
  - 1
125
152
  )
126
153
  ** 2
@@ -140,23 +167,16 @@ def boundary_condition_apply(
140
167
  omega_boundary_dim: int,
141
168
  loss_weight: float | Float[Array, "boundary_cond_dim"],
142
169
  ) -> float:
170
+
171
+ vmap_in_axes = (0,) + _get_vmap_in_axes_params(batch.param_batch_dict, params)
172
+
143
173
  if isinstance(omega_boundary_fun, dict):
144
174
  # We must create the facet tree dictionary as we do not have the
145
175
  # enumerate from the for loop to pass the id integer
146
- if (
147
- isinstance(batch, PDEStatioBatch) and batch.border_batch.shape[-1] == 2
148
- ) or (
149
- isinstance(batch, PDENonStatioBatch)
150
- and batch.times_x_border_batch.shape[-1] == 2
151
- ):
176
+ if batch.border_batch.shape[-1] == 2:
152
177
  # 1D
153
178
  facet_tree = {"xmin": 0, "xmax": 1}
154
- elif (
155
- isinstance(batch, PDEStatioBatch) and batch.border_batch.shape[-1] == 4
156
- ) or (
157
- isinstance(batch, PDENonStatioBatch)
158
- and batch.times_x_border_batch.shape[-1] == 4
159
- ):
179
+ elif batch.border_batch.shape[-1] == 4:
160
180
  # 2D
161
181
  facet_tree = {"xmin": 0, "xmax": 1, "ymin": 2, "ymax": 3}
162
182
  else:
@@ -166,7 +186,10 @@ def boundary_condition_apply(
166
186
  None
167
187
  if c is None
168
188
  else jnp.mean(
169
- loss_weight * _compute_boundary_loss(c, f, batch, u, params, fa, d)
189
+ loss_weight
190
+ * _compute_boundary_loss(
191
+ c, f, batch, u, params, fa, d, vmap_in_axes
192
+ )
170
193
  )
171
194
  ),
172
195
  omega_boundary_condition,
@@ -180,10 +203,7 @@ def boundary_condition_apply(
180
203
  # Note that to keep the behaviour given in the comment above we neede
181
204
  # to specify is_leaf according to the note in the release of 0.4.29
182
205
  else:
183
- if isinstance(batch, PDEStatioBatch):
184
- facet_tuple = tuple(f for f in range(batch.border_batch.shape[-1]))
185
- else:
186
- facet_tuple = tuple(f for f in range(batch.times_x_border_batch.shape[-1]))
206
+ facet_tuple = tuple(f for f in range(batch.border_batch.shape[-1]))
187
207
  b_losses_by_facet = jax.tree_util.tree_map(
188
208
  lambda fa: jnp.mean(
189
209
  loss_weight
@@ -195,6 +215,7 @@ def boundary_condition_apply(
195
215
  params,
196
216
  fa,
197
217
  omega_boundary_dim,
218
+ vmap_in_axes,
198
219
  )
199
220
  ),
200
221
  facet_tuple,
@@ -215,7 +236,7 @@ def observations_loss_apply(
215
236
  obs_slice: slice,
216
237
  ) -> float:
217
238
  # TODO implement for SPINN
218
- if isinstance(u, (PINN, HYPERPINN)):
239
+ if isinstance(u, (PINN, HyperPINN)):
219
240
  v_u = vmap(
220
241
  lambda *args: u(*args)[u.slice_solution],
221
242
  vmap_axes,
@@ -225,8 +246,10 @@ def observations_loss_apply(
225
246
  mse_observation_loss = jnp.mean(
226
247
  jnp.sum(
227
248
  loss_weight
228
- * (val - _check_user_func_return(observed_values, val.shape)) ** 2,
229
- # the reshape above avoids a potential missing (1,)
249
+ * _subtract_with_check(
250
+ observed_values, val, cause="user defined observed_values"
251
+ )
252
+ ** 2,
230
253
  axis=-1,
231
254
  )
232
255
  )
@@ -243,33 +266,38 @@ def initial_condition_apply(
243
266
  params: Params | ParamsDict,
244
267
  vmap_axes: tuple[int | None, ...],
245
268
  initial_condition_fun: Callable,
246
- n: int,
247
269
  loss_weight: float | Float[Array, "initial_condition_dimension"],
248
270
  ) -> float:
249
- if isinstance(u, (PINN, HYPERPINN)):
271
+ n = omega_batch.shape[0]
272
+ t0_omega_batch = jnp.concatenate([jnp.zeros((n, 1)), omega_batch], axis=1)
273
+ if isinstance(u, (PINN, HyperPINN)):
250
274
  v_u_t0 = vmap(
251
- lambda x, params: initial_condition_fun(x) - u(jnp.zeros((1,)), x, params),
275
+ lambda t0_x, params: _subtract_with_check(
276
+ initial_condition_fun(t0_x[1:]),
277
+ u(t0_x, params),
278
+ cause="Output of initial_condition_fun",
279
+ ),
252
280
  vmap_axes,
253
281
  0,
254
282
  )
255
- res = v_u_t0(omega_batch, params) # NOTE take the tiled
283
+ res = v_u_t0(t0_omega_batch, params) # NOTE take the tiled
256
284
  # omega_batch (ie omega_batch_) to have the same batch
257
285
  # dimension as params to be able to vmap.
258
286
  # Recall that by convention:
259
287
  # param_batch_dict = times_batch_size * omega_batch_size
260
288
  mse_initial_condition = jnp.mean(jnp.sum(loss_weight * res**2, axis=-1))
261
289
  elif isinstance(u, SPINN):
262
- values = lambda x: u(
263
- jnp.repeat(jnp.zeros((1, 1)), n, axis=0),
264
- x,
290
+ values = lambda t_x: u(
291
+ t_x,
265
292
  params,
266
293
  )[0]
267
- omega_batch_grid = _get_grid(omega_batch)
268
- v_ini = values(omega_batch)
269
- ini = _check_user_func_return(
270
- initial_condition_fun(omega_batch_grid), v_ini.shape
294
+ omega_batch_grid = get_grid(omega_batch)
295
+ v_ini = values(t0_omega_batch)
296
+ res = _subtract_with_check(
297
+ initial_condition_fun(omega_batch_grid),
298
+ v_ini,
299
+ cause="Output of initial_condition_fun",
271
300
  )
272
- res = ini - v_ini
273
301
  mse_initial_condition = jnp.mean(jnp.sum(loss_weight * res**2, axis=-1))
274
302
  else:
275
303
  raise ValueError(f"Bad type for u. Got {type(u)}, expected PINN or SPINN")