jinns 1.0.0__py3-none-any.whl → 1.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
jinns/solver/_solve.py CHANGED
@@ -7,6 +7,7 @@ from __future__ import (
     annotations,
 ) # https://docs.python.org/3/library/typing.html#constant
 
+import time
 from typing import TYPE_CHECKING, NamedTuple, Dict, Union
 from functools import partial
 import optax
@@ -16,6 +17,7 @@ import jax.numpy as jnp
 from jaxtyping import Int, Bool, Float, Array
 from jinns.solver._rar import init_rar, trigger_rar
 from jinns.utils._utils import _check_nan_in_pytree
+from jinns.solver._utils import _check_batch_size
 from jinns.utils._containers import *
 from jinns.data._DataGenerators import (
     DataGeneratorODE,
@@ -29,31 +31,6 @@ if TYPE_CHECKING:
     from jinns.utils._types import *
 
 
-def _check_batch_size(other_data, main_data, attr_name):
-    if (
-        (
-            isinstance(main_data, DataGeneratorODE)
-            and getattr(other_data, attr_name) != main_data.temporal_batch_size
-        )
-        or (
-            isinstance(main_data, CubicMeshPDEStatio)
-            and not isinstance(main_data, CubicMeshPDENonStatio)
-            and getattr(other_data, attr_name) != main_data.omega_batch_size
-        )
-        or (
-            isinstance(main_data, CubicMeshPDENonStatio)
-            and getattr(other_data, attr_name)
-            != main_data.omega_batch_size * main_data.temporal_batch_size
-        )
-    ):
-        raise ValueError(
-            "Optional other_data.param_batch_size must be"
-            " equal to main_data.temporal_batch_size or main_data.omega_batch_size or"
-            " the product of both dependeing on the type of the main"
-            " datagenerator"
-        )
-
-
 def solve(
     n_iter: Int,
     init_params: AnyParams,
@@ -167,10 +144,22 @@ def solve(
         The best parameters according to the validation criterion
     """
     if param_data is not None:
-        _check_batch_size(param_data, data, "param_batch_size")
-
-    if obs_data is not None:
-        _check_batch_size(obs_data, data, "obs_batch_size")
+        if param_data.param_batch_size is not None:
+            # We need to check that batch sizes will all be compliant for
+            # correct vectorization
+            _check_batch_size(param_data, data, "param_batch_size")
+        else:
+            # If DataGeneratorParameter does not have a batch size we will
+            # vectorization using `n`, and the same checks must be done
+            _check_batch_size(param_data, data, "n")
+
+    if obs_data is not None and param_data is not None:
+        # obs_data batch dimensions need only to be aligned with param_data
+        # batch dimensions if the latter exist
+        if obs_data.obs_batch_size is not None:
+            _check_batch_size(obs_data, param_data, "obs_batch_size")
+        else:
+            _check_batch_size(obs_data, param_data, "n")
 
     if opt_state is None:
         opt_state = optimizer.init(init_params)
@@ -224,6 +213,8 @@
     )
     optimization_extra = OptimizationExtraContainer(
         curr_seq=curr_seq,
+        best_iter_id=0,
+        best_val_criterion=jnp.nan,
         best_val_params=init_params,
     )
     loss_container = LossContainer(
@@ -323,16 +314,26 @@
                 validation_criterion
             )
 
-            # update best_val_params w.r.t val_loss if needed
-            best_val_params = jax.lax.cond(
+            # update best_val_params and best_val_criterion w.r.t val_loss if needed
+            (best_val_params, best_val_criterion, best_iter_id) = jax.lax.cond(
                 update_best_params,
-                lambda _: params, # update with current value
-                lambda operands: operands[0].best_val_params, # unchanged
+                lambda operands: (
+                    params,
+                    validation_criterion,
+                    i,
+                ), # update with current value
+                lambda operands: (
+                    operands[0].best_val_params,
+                    operands[0].best_val_criterion,
+                    operands[0].best_iter_id,
+                ), # unchanged
                 (optimization_extra,),
             )
         else:
             early_stopping = False
+            best_iter_id = 0
             best_val_params = params
+            best_val_criterion = jnp.nan
 
         # Trigger RAR
         loss, params, data = trigger_rar(
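For readers less familiar with `jax.lax.cond`, the change above swaps a single-value select for a tuple select: both branches must now return a `(params, criterion, iteration)` triple with the same pytree structure. A stand-alone sketch of that pattern, with toy placeholder values rather than jinns objects:

```python
# Toy jax.lax.cond pattern: select between the current (params, criterion, iteration)
# triple and the stored best one, depending on a boolean predicate.
import jax
import jax.numpy as jnp

def pick_best(update, params, criterion, i, best_params, best_criterion, best_iter):
    return jax.lax.cond(
        update,
        lambda operands: (params, criterion, i),  # take the current values
        lambda operands: operands,                # keep the stored best, unchanged
        (best_params, best_criterion, best_iter),
    )

best = pick_best(
    jnp.array(True),                              # e.g. the validation criterion improved
    jnp.ones(2), jnp.array(0.1), jnp.array(5),
    jnp.zeros(2), jnp.array(jnp.nan), jnp.array(0),
)
print(best)  # (Array([1., 1.]), Array(0.1), Array(5))
```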
@@ -358,7 +359,13 @@ def solve(
         i,
         loss,
         OptimizationContainer(params, last_non_nan_params, opt_state),
-        OptimizationExtraContainer(curr_seq, best_val_params, early_stopping),
+        OptimizationExtraContainer(
+            curr_seq,
+            best_iter_id,
+            best_val_criterion,
+            best_val_params,
+            early_stopping,
+        ),
         DataGeneratorContainer(data, param_data, obs_data),
         validation,
         LossContainer(stored_loss_terms, train_loss_values),
@@ -373,7 +380,20 @@ def solve(
         while break_fun(carry):
             carry = _one_iteration(carry)
     else:
-        carry = jax.lax.while_loop(break_fun, _one_iteration, carry)
+
+        def train_fun(carry):
+            return jax.lax.while_loop(break_fun, _one_iteration, carry)
+
+        start = time.time()
+        compiled_train_fun = jax.jit(train_fun).lower(carry).compile()
+        end = time.time()
+        print("\nCompilation took\n", end - start, "\n")
+
+        start = time.time()
+        carry = compiled_train_fun(carry)
+        jax.block_until_ready(carry)
+        end = time.time()
+        print("\nTraining took\n", end - start, "\n")
 
     (
         i,
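The new `else` branch no longer calls `jax.lax.while_loop` directly: it ahead-of-time compiles the whole training loop with `jax.jit(...).lower(...).compile()`, so compilation time and pure training time can be reported separately. A minimal, self-contained version of the same pattern, using a toy loop instead of the jinns carry:

```python
# AOT-compile a jitted while_loop once, then time the compiled call separately.
import time
import jax
import jax.numpy as jnp

def one_iteration(carry):
    i, x = carry
    return i + 1, x + 0.1 * jnp.sin(x)

def keep_going(carry):
    return carry[0] < 1000

def train_fun(carry):
    return jax.lax.while_loop(keep_going, one_iteration, carry)

carry = (jnp.array(0), jnp.ones(10))

start = time.time()
compiled_train_fun = jax.jit(train_fun).lower(carry).compile()  # trace + compile only
print("Compilation took", time.time() - start, "s")

start = time.time()
carry = compiled_train_fun(carry)
jax.block_until_ready(carry)  # wait for async dispatch before reading the clock
print("Training took", time.time() - start, "s")
```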
@@ -389,15 +409,30 @@ def solve(
 
     if verbose:
         jax.debug.print(
-            "Final iteration {i}: train loss value = {train_loss_val}",
+            "\nFinal iteration {i}: train loss value = {train_loss_val}",
             i=i,
             train_loss_val=loss_container.train_loss_values[i - 1],
         )
+
+    # get ready to return the parameters at last iteration...
+    # (by default arbitrary choice, this could be None)
+    validation_parameters = optimization.last_non_nan_params
     if validation is not None:
         jax.debug.print(
             "validation loss value = {validation_loss_val}",
             validation_loss_val=validation_crit_values[i - 1],
         )
+        if optimization_extra.early_stopping:
+            jax.debug.print(
+                "\n Returning a set of best parameters from early stopping"
+                " as last argument!\n"
+                " Best parameters from iteration {best_iter_id}"
+                " with validation loss criterion = {best_val_criterion}",
+                best_iter_id=optimization_extra.best_iter_id,
+                best_val_criterion=optimization_extra.best_val_criterion,
+            )
+        # ...but if early stopping, return the parameters at the best_iter_id
+        validation_parameters = optimization_extra.best_val_params
 
     return (
         optimization.last_non_nan_params,
@@ -408,7 +443,7 @@ def solve(
         optimization.opt_state,
         stored_objects.stored_params,
         validation_crit_values if validation is not None else None,
-        optimization_extra.best_val_params if validation is not None else None,
+        validation_parameters,
     )
 
 
@@ -531,7 +566,7 @@ def _get_break_fun(n_iter: Int, verbose: Bool) -> Callable[[main_carry], Bool]:
         string is not a valid JAX type that can be fed into the operands
         """
         if verbose:
-            jax.debug.print(f"Stopping main optimization loop, cause: {msg}")
+            jax.debug.print(f"\nStopping main optimization loop, cause: {msg}")
         return False
 
     def continue_while_loop(_):
jinns/solver/_utils.py ADDED
@@ -0,0 +1,122 @@
+from jinns.data._DataGenerators import (
+    DataGeneratorODE,
+    CubicMeshPDEStatio,
+    CubicMeshPDENonStatio,
+    DataGeneratorParameter,
+)
+
+
+def _check_batch_size(other_data, main_data, attr_name):
+    if isinstance(main_data, DataGeneratorODE):
+        if main_data.temporal_batch_size is not None:
+            if getattr(other_data, attr_name) != main_data.temporal_batch_size:
+                raise ValueError(
+                    f"{other_data.__class__}.{attr_name} must be equal"
+                    f" to {main_data.__class__}.temporal_batch_size for correct"
+                    " vectorization"
+                )
+        else:
+            if main_data.nt is not None:
+                if getattr(other_data, attr_name) != main_data.nt:
+                    raise ValueError(
+                        f"{other_data.__class__}.{attr_name} must be equal"
+                        f" to {main_data.__class__}.nt for correct"
+                        " vectorization"
+                    )
+    if isinstance(main_data, CubicMeshPDEStatio) and not isinstance(
+        main_data, CubicMeshPDENonStatio
+    ):
+        if main_data.omega_batch_size is not None:
+            if getattr(other_data, attr_name) != main_data.omega_batch_size:
+                raise ValueError(
+                    f"{other_data.__class__}.{attr_name} must be equal"
+                    f" to {main_data.__class__}.omega_batch_size for correct"
+                    " vectorization"
+                )
+        else:
+            if main_data.n is not None:
+                if getattr(other_data, attr_name) != main_data.n:
+                    raise ValueError(
+                        f"{other_data.__class__}.{attr_name} must be equal"
+                        f" to {main_data.__class__}.n for correct"
+                        " vectorization"
+                    )
+        if main_data.omega_border_batch_size is not None:
+            if getattr(other_data, attr_name) != main_data.omega_border_batch_size:
+                raise ValueError(
+                    f"{other_data.__class__}.{attr_name} must be equal"
+                    f" to {main_data.__class__}.omega_border_batch_size for correct"
+                    " vectorization"
+                )
+        else:
+            if main_data.nb is not None:
+                if getattr(other_data, attr_name) != main_data.nb:
+                    raise ValueError(
+                        f"{other_data.__class__}.{attr_name} must be equal"
+                        f" to {main_data.__class__}.nb for correct"
+                        " vectorization"
+                    )
+    if isinstance(main_data, CubicMeshPDENonStatio):
+        if main_data.domain_batch_size is not None:
+            if getattr(other_data, attr_name) != main_data.domain_batch_size:
+                raise ValueError(
+                    f"{other_data.__class__}.{attr_name} must be equal"
+                    f" to {main_data.__class__}.domain_batch_size for correct"
+                    " vectorization"
+                )
+        else:
+            if main_data.n is not None:
+                if getattr(other_data, attr_name) != main_data.n:
+                    raise ValueError(
+                        f"{other_data.__class__}.{attr_name} must be equal"
+                        f" to {main_data.__class__}.n for correct"
+                        " vectorization"
+                    )
+        if main_data.border_batch_size is not None:
+            if getattr(other_data, attr_name) != main_data.border_batch_size:
+                raise ValueError(
+                    f"{other_data.__class__}.{attr_name} must be equal"
+                    f" to {main_data.__class__}.border_batch_size for correct"
+                    " vectorization"
+                )
+        else:
+            if main_data.nb is not None:
+                if main_data.dim > 1 and getattr(other_data, attr_name) != (
+                    main_data.nb // 2**main_data.dim
+                ):
+                    raise ValueError(
+                        f"{other_data.__class__}.{attr_name} must be equal"
+                        f" to ({main_data.__class__}.nb // 2**{main_data.__class__}.dim)"
+                        " for correct vectorization"
+                    )
+        if main_data.initial_batch_size is not None:
+            if getattr(other_data, attr_name) != main_data.initial_batch_size:
+                raise ValueError(
+                    f"{other_data.__class__}.{attr_name} must be equal"
+                    f" to {main_data.__class__}.initial_batch_size for correct"
+                    " vectorization"
+                )
+        else:
+            if main_data.ni is not None:
+                if getattr(other_data, attr_name) != main_data.ni:
+                    raise ValueError(
+                        f"{other_data.__class__}.{attr_name} must be equal"
+                        f" to {main_data.__class__}.ni for correct"
+                        " vectorization"
+                    )
+    if isinstance(main_data, DataGeneratorParameter):
+        if main_data.param_batch_size is not None:
+            if getattr(other_data, attr_name) != main_data.param_batch_size:
+                raise ValueError(
+                    f"{other_data.__class__}.{attr_name} must be equal"
+                    f" to {main_data.__class__}.param_batch_size for correct"
+                    " vectorization"
+                )
+        else:
+            if main_data.n is not None:
+                if getattr(other_data, attr_name) != main_data.n:
+                    raise ValueError(
+                        f"{other_data.__class__}.{attr_name} must be equal"
+                        f" to {main_data.__class__}.n for correct"
+                        " vectorization"
+                    )
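The reason every branch of `_check_batch_size` insists on equal batch sizes is that batches from the main data generator and from the optional parameter/observation generators are consumed together along the same leading axis. A stand-alone sketch of the failure mode being guarded against, with hypothetical shapes and no jinns classes involved:

```python
# Why the batch sizes must match: the batches are zipped along axis 0 during
# vectorized evaluation, so a mismatch breaks vmap.
import jax
import jax.numpy as jnp

temporal_batch_size = 32                              # e.g. main_data.temporal_batch_size
t_batch = jnp.ones((temporal_batch_size, 1))          # collocation times
eq_param_batch = jnp.ones((temporal_batch_size, 2))   # one parameter sample per time point

# vmap maps over both batches along axis 0; this only works because the leading
# dimensions agree, which is exactly what _check_batch_size enforces up front.
residuals = jax.vmap(lambda t, p: jnp.sum(t * p))(t_batch, eq_param_batch)
print(residuals.shape)  # (32,)

# With eq_param_batch of shape (16, 2) the vmap above would fail with an
# "inconsistent sizes" error; the new explicit ValueError surfaces the problem
# earlier and with a clearer message.
```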
jinns/utils/__init__.py CHANGED
@@ -1,4 +1,6 @@
 from ._pinn import create_PINN, PINN
+from ._ppinn import create_PPINN, PPINN
 from ._spinn import create_SPINN, SPINN
 from ._hyperpinn import create_HYPERPINN, HYPERPINN
 from ._save_load import save_pinn, load_pinn
+from ._utils import get_grid
@@ -38,7 +38,9 @@ class OptimizationContainer(eqx.Module):
 
 class OptimizationExtraContainer(eqx.Module):
     curr_seq: int
-    best_val_params: Params
+    best_iter_id: int  # the best iteration number (that which achieves best_val_params and best_val_params)
+    best_val_criterion: float  # the best validation criterion at early stopping
+    best_val_params: Params  # the best parameter values at early stopping
     early_stopping: Bool = False
 
 
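Since `OptimizationExtraContainer` is an `equinox.Module`, the two new fields travel through `jax.lax.while_loop` and `jax.lax.cond` as ordinary pytree leaves and can only be "updated" by building a new instance. A toy sketch of that behaviour; the field names mirror the diff, everything else is made up for the example:

```python
# Hypothetical container mirroring the new fields; not the jinns class itself.
import equinox as eqx
import jax.numpy as jnp

class OptimizationExtraContainerSketch(eqx.Module):
    curr_seq: int
    best_iter_id: int
    best_val_criterion: float
    best_val_params: dict  # stand-in for the jinns Params pytree
    early_stopping: bool = False

extra = OptimizationExtraContainerSketch(
    curr_seq=0,
    best_iter_id=0,
    best_val_criterion=jnp.nan,
    best_val_params={"nn_params": jnp.zeros(3)},
)

# eqx.Module instances are frozen dataclasses and valid JAX pytrees, so the
# container can ride along in the training-loop carry; "mutation" means
# constructing a new instance, e.g. with eqx.tree_at.
extra = eqx.tree_at(lambda c: c.best_iter_id, extra, 10)
print(extra.best_iter_id)  # 10
```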
jinns/utils/_hyperpinn.py CHANGED
@@ -16,7 +16,7 @@ import equinox as eqx
 import numpy as onp
 
 from jinns.utils._pinn import PINN, _MLP
-from jinns.parameters._params import Params
+from jinns.parameters._params import Params, ParamsDict
 
 
 def _get_param_nb(
@@ -114,6 +114,7 @@ class HYPERPINN(PINN):
         )
         self.pinn_params_sum, self.pinn_params_cumsum = _get_param_nb(self.params)
 
+    @property
     def init_params(self) -> Params:
         """
         Returns an initial set of parameters
@@ -138,14 +139,20 @@ class HYPERPINN(PINN):
             is_leaf=lambda x: isinstance(x, jnp.ndarray),
         )
 
-    def eval_nn(
+    def __call__(
         self,
         inputs: Float[Array, "input_dim"],
-        params: Params | PyTree,
+        params: Params | ParamsDict | PyTree,
     ) -> Float[Array, "output_dim"]:
         """
-        Evaluate the HYPERPINN on some inputs with some params.
+        Evaluate the HyperPINN on some inputs with some params.
         """
+        if len(inputs.shape) == 0:
+            # This can happen often when the user directly provides some
+            # collocation points (eg for plotting, whithout using
+            # DataGenerators)
+            inputs = inputs[None]
+
         try:
             hyper = eqx.combine(params.nn_params, self.static_hyper)
         except (KeyError, AttributeError, TypeError) as e: # give more flexibility
@@ -190,7 +197,7 @@ def create_HYPERPINN(
    slice_solution: slice = None,
    shared_pinn_outputs: slice = None,
    eqx_list_hyper: tuple[tuple[Callable, int, int] | Callable, ...] = None,
-) -> HYPERPINN | list[HYPERPINN]:
+) -> tuple[HYPERPINN | list[HYPERPINN], PyTree | list[PyTree]]:
    r"""
    Utility function to create a standard PINN neural network with the equinox
    library.
@@ -274,6 +281,9 @@
         A HYPERPINN instance or, when `shared_pinn_ouput` is not None,
         a list of HYPERPINN instances with the same structure is returned,
         only differing by there final slicing of the network output.
+    hyperpinn.init_params
+        The initial set of parameters for the HyperPINN or a list of the latter
+        when `shared_pinn_ouput` is not None.
 
 
     Raises
@@ -389,7 +399,7 @@
                 output_slice=output_slice,
             )
             hyperpinns.append(hyperpinn)
-        return hyperpinns
+        return hyperpinns, [h.init_params for h in hyperpinns]
     with warnings.catch_warnings():
         # Catch the equinox warning because we put the number of
         # parameters as static while being jnp.Array. This this time
@@ -407,4 +417,4 @@
             hypernet_input_size=hypernet_input_size,
             output_slice=None,
         )
-    return hyperpinn
+    return hyperpinn, hyperpinn.init_params
jinns/utils/_pinn.py CHANGED
@@ -10,7 +10,7 @@ import equinox as eqx
 
 from jaxtyping import Array, Key, PyTree, Float
 
-from jinns.parameters._params import Params
+from jinns.parameters._params import Params, ParamsDict
 
 
 class _MLP(eqx.Module):
@@ -128,40 +128,27 @@ class PINN(eqx.Module):
     def __post_init__(self, mlp):
         self.params, self.static = eqx.partition(mlp, eqx.is_inexact_array)
 
+    @property
     def init_params(self) -> PyTree:
         """
         Returns an initial set of parameters
         """
         return self.params
 
-    def __call__(self, *args) -> Float[Array, "output_dim"]:
-        """
-        Calls `eval_nn` with rearranged arguments
-        """
-        if self.eq_type == "ODE":
-            (t, params) = args
-            if len(t.shape) == 0:
-                t = t[..., None] # Add mandatory dimension which can be lacking
-                # (eg. for the ODE batches) but this dimension can already
-                # exists (eg. for user provided observation times)
-            return self.eval_nn(t, params)
-        if self.eq_type == "statio_PDE":
-            (x, params) = args
-            return self.eval_nn(x, params)
-        if self.eq_type == "nonstatio_PDE":
-            (t, x, params) = args
-            t_x = jnp.concatenate([t, x], axis=-1)
-            return self.eval_nn(t_x, params)
-        raise ValueError("Wrong value for self.eq_type")
-
-    def eval_nn(
+    def __call__(
         self,
-        inputs: Float[Array, "input_dim"],
-        params: Params | PyTree,
+        inputs: Float[Array, "1"] | Float[Array, "dim"] | Float[Array, "1+dim"],
+        params: Params | ParamsDict | PyTree,
     ) -> Float[Array, "output_dim"]:
         """
         Evaluate the PINN on some inputs with some params.
         """
+        if len(inputs.shape) == 0:
+            # This can happen often when the user directly provides some
+            # collocation points (eg for plotting, whithout using
+            # DataGenerators)
+            inputs = inputs[None]
+
         try:
             model = eqx.combine(params.nn_params, self.static)
         except (KeyError, AttributeError, TypeError) as e: # give more flexibility
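The new shape guard only promotes 0-dimensional inputs to shape `(1,)`; already-batched inputs pass through untouched. A toy illustration of just that handling, using a stand-in callable rather than the jinns PINN:

```python
# Minimal sketch of the 0-d promotion done inside __call__: a scalar collocation
# point (shape ()) is promoted to a 1-d array of shape (1,) before the network
# sees it. `u` is a hypothetical stand-in, not the jinns class.
import jax.numpy as jnp

def u(inputs, params):
    if len(inputs.shape) == 0:
        inputs = inputs[None]               # () -> (1,)
    return jnp.tanh(params["w"] * inputs)   # toy "network"

params = {"w": jnp.array(2.0)}
print(u(jnp.array(0.5), params).shape)      # (1,)  scalar input accepted
print(u(jnp.array([0.5]), params).shape)    # (1,)  already-batched input unchanged
```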
@@ -193,7 +180,7 @@ def create_PINN(
    ] = None,
    shared_pinn_outputs: tuple[slice] = None,
    slice_solution: slice = None,
-) -> PINN | list[PINN]:
+) -> tuple[PINN | list[PINN], PyTree | list[PyTree]]:
    r"""
    Utility function to create a standard PINN neural network with the equinox
    library.
@@ -266,6 +253,9 @@ def create_PINN(
         A PINN instance or, when `shared_pinn_ouput` is not None,
         a list of PINN instances with the same structure is returned,
         only differing by there final slicing of the network output.
+    pinn.init_params
+        An initial set of parameters for the PINN or a list of the latter
+        when `shared_pinn_ouput` is not None.
 
     Raises
     ------
@@ -322,7 +312,7 @@ def create_PINN(
                 output_slice=output_slice,
             )
             pinns.append(pinn)
-        return pinns
+        return pinns, [p.init_params for p in pinns]
     pinn = PINN(
         mlp=mlp,
         slice_solution=slice_solution,
@@ -331,4 +321,4 @@ def create_PINN(
         output_transform=output_transform,
         output_slice=None,
     )
-    return pinn
+    return pinn, pinn.init_params
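
The net effect of the last few hunks on user code: `create_PINN` and `create_HYPERPINN` now return the network together with its initial parameters, and `init_params` is read as a property rather than called as a method. A toy equinox sketch of the same convention; the names are illustrative and the full jinns factory signatures are not shown in these hunks:

```python
# Hypothetical factory returning (module, initial parameters), with init_params
# exposed as a read-only property; mirrors the 1.2.0 convention, not the jinns API.
import equinox as eqx
import jax
import jax.numpy as jnp

class TinyNet(eqx.Module):
    params: jax.Array
    static_scale: float = eqx.field(static=True)

    @property
    def init_params(self) -> jax.Array:
        return self.params

    def __call__(self, x, params):
        # like a jinns PINN, the module is called with externally supplied params
        return jnp.tanh(self.static_scale * params @ x)

def create_tiny_net(key):
    net = TinyNet(params=jax.random.normal(key, (3,)), static_scale=1.0)
    # 1.2.0-style: hand back the initial parameters alongside the module
    return net, net.init_params

net, init_params = create_tiny_net(jax.random.PRNGKey(0))
print(net(jnp.ones(3), init_params))  # note: init_params is accessed without ()
```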