jinns-0.8.5-py3-none-any.whl → jinns-0.8.7-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
jinns/loss/_LossODE.py CHANGED
@@ -19,6 +19,8 @@ from jinns.loss._Losses import (
19
19
  )
20
20
  from jinns.utils._pinn import PINN
21
21
 
22
+ _LOSS_WEIGHT_KEYS_ODE = ["observations", "dyn_loss", "initial_condition"]
23
+
22
24
 
23
25
  @register_pytree_node_class
24
26
  class LossODE:
@@ -128,6 +130,10 @@ class LossODE:
128
130
  if self.obs_slice is None:
129
131
  self.obs_slice = jnp.s_[...]
130
132
 
133
+ for k in _LOSS_WEIGHT_KEYS_ODE:
134
+ if k not in self.loss_weights.keys():
135
+ self.loss_weights[k] = 0
136
+
131
137
  def __call__(self, *args, **kwargs):
132
138
  return self.evaluate(*args, **kwargs)
133
139
 
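The practical effect of the new `_LOSS_WEIGHT_KEYS_ODE` padding loop: any loss-weight key the user omits is filled with 0, so the corresponding term simply drops out of the total loss. A standalone sketch of that idiom (not the jinns API itself, just the logic added above):

    _LOSS_WEIGHT_KEYS_ODE = ["observations", "dyn_loss", "initial_condition"]

    def pad_loss_weights(loss_weights, keys=_LOSS_WEIGHT_KEYS_ODE):
        # any key absent from the user-supplied dict gets a zero weight,
        # so that loss term contributes nothing to the total loss
        for k in keys:
            if k not in loss_weights:
                loss_weights[k] = 0
        return loss_weights

    print(pad_loss_weights({"dyn_loss": 1.0}))
    # {'dyn_loss': 1.0, 'observations': 0, 'initial_condition': 0}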
jinns/loss/_LossPDE.py CHANGED
@@ -31,6 +31,16 @@ _IMPLEMENTED_BOUNDARY_CONDITIONS = [
31
31
  "vonneumann",
32
32
  ]
33
33
 
34
+ _LOSS_WEIGHT_KEYS_PDESTATIO = [
35
+ "sobolev",
36
+ "observations",
37
+ "norm_loss",
38
+ "boundary_loss",
39
+ "dyn_loss",
40
+ ]
41
+
42
+ _LOSS_WEIGHT_KEYS_PDENONSTATIO = _LOSS_WEIGHT_KEYS_PDESTATIO + ["initial_condition"]
43
+
34
44
 
35
45
  @register_pytree_node_class
36
46
  class LossPDEAbstract:
@@ -269,8 +279,8 @@ class LossPDEStatio(LossPDEAbstract):
269
279
  the PINN object
270
280
  loss_weights
271
281
  a dictionary with values used to ponderate each term in the loss
272
- function. Valid keys are `dyn_loss`, `norm_loss`, `boundary_loss`
273
- and `observations`.
282
+ function. Valid keys are `dyn_loss`, `norm_loss`, `boundary_loss`,
283
+ `observations` and `sobolev`.
274
284
  Note that we can have jnp.arrays with the same dimension of
275
285
  `u` which then ponderates each output of `u`
276
286
  dynamic_loss
@@ -441,18 +451,12 @@ class LossPDEStatio(LossPDEAbstract):
441
451
  ) # we return a function, that way
442
452
  # the order of sobolev_m is static and the conditional in the recursive
443
453
  # function is properly set
444
- self.sobolev_m = self.sobolev_m
445
454
  else:
446
455
  self.sobolev_reg = None
447
456
 
448
- if self.normalization_loss is None:
449
- self.loss_weights["norm_loss"] = 0
450
-
451
- if self.omega_boundary_fun is None:
452
- self.loss_weights["boundary_loss"] = 0
453
-
454
- if self.sobolev_reg is None:
455
- self.loss_weights["sobolev"] = 0
457
+ for k in _LOSS_WEIGHT_KEYS_PDESTATIO:
458
+ if k not in self.loss_weights.keys():
459
+ self.loss_weights[k] = 0
456
460
 
457
461
  if (
458
462
  isinstance(self.omega_boundary_fun, dict)
@@ -533,7 +537,6 @@ class LossPDEStatio(LossPDEAbstract):
533
537
  )
534
538
  else:
535
539
  mse_norm_loss = jnp.array(0.0)
536
- self.loss_weights["norm_loss"] = 0
537
540
 
538
541
  # boundary part
539
542
  params_ = _set_derivatives(params, "boundary_loss", self.derivative_keys)
@@ -567,7 +570,6 @@ class LossPDEStatio(LossPDEAbstract):
567
570
  )
568
571
  else:
569
572
  mse_observation_loss = jnp.array(0.0)
570
- self.loss_weights["observations"] = 0
571
573
 
572
574
  # Sobolev regularization
573
575
  params_ = _set_derivatives(params, "sobolev", self.derivative_keys)
@@ -582,7 +584,6 @@ class LossPDEStatio(LossPDEAbstract):
582
584
  )
583
585
  else:
584
586
  mse_sobolev_loss = jnp.array(0.0)
585
- self.loss_weights["sobolev"] = 0
586
587
 
587
588
  # total loss
588
589
  total_loss = (
@@ -785,8 +786,9 @@ class LossPDENonStatio(LossPDEStatio):
785
786
  else:
786
787
  self.sobolev_reg = None
787
788
 
788
- if self.sobolev_reg is None:
789
- self.loss_weights["sobolev"] = 0
789
+ for k in _LOSS_WEIGHT_KEYS_PDENONSTATIO:
790
+ if k not in self.loss_weights.keys():
791
+ self.loss_weights[k] = 0
790
792
 
791
793
  def __call__(self, *args, **kwargs):
792
794
  return self.evaluate(*args, **kwargs)
@@ -924,7 +926,6 @@ class LossPDENonStatio(LossPDEStatio):
924
926
  )
925
927
  else:
926
928
  mse_observation_loss = jnp.array(0.0)
927
- self.loss_weights["observations"] = 0
928
929
 
929
930
  # Sobolev regularization
930
931
  params_ = _set_derivatives(params, "sobolev", self.derivative_keys)
@@ -939,7 +940,6 @@ class LossPDENonStatio(LossPDEStatio):
939
940
  )
940
941
  else:
941
942
  mse_sobolev_loss = jnp.array(0.0)
942
- self.loss_weights["sobolev"] = 0.0
943
943
 
944
944
  # total loss
945
945
  total_loss = (
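The PDE losses apply the same padding over the larger key sets defined above. A hedged sketch of a loss_weights dict using every documented key for LossPDENonStatio (the weight values are placeholders):

    loss_weights = {
        "dyn_loss": 1.0,
        "boundary_loss": 10.0,
        "norm_loss": 0.1,
        "observations": 5.0,
        "sobolev": 0.01,            # key now documented in the docstring above
        "initial_condition": 1.0,   # non-stationary case only
    }
    # any key left out is now set to 0 by the constructor, instead of being
    # zeroed later inside evaluate() as in 0.8.5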
jinns/solver/_solve.py CHANGED
@@ -3,9 +3,9 @@ This modules implements the main `solve()` function of jinns which
3
3
  handles the optimization process
4
4
  """
5
5
 
6
+ import copy
7
+ from functools import partial
6
8
  import optax
7
- from tqdm import tqdm
8
- from jax_tqdm import scan_tqdm
9
9
  import jax
10
10
  from jax import jit
11
11
  import jax.numpy as jnp
@@ -19,8 +19,32 @@ from jinns.data._DataGenerators import (
19
19
  append_param_batch,
20
20
  append_obs_batch,
21
21
  )
22
+ from jinns.utils._containers import *
22
23
 
23
- from functools import partial
24
+
25
+ def check_batch_size(other_data, main_data, attr_name):
26
+ if (
27
+ (
28
+ isinstance(main_data, DataGeneratorODE)
29
+ and getattr(other_data, attr_name) != main_data.temporal_batch_size
30
+ )
31
+ or (
32
+ isinstance(main_data, CubicMeshPDEStatio)
33
+ and not isinstance(main_data, CubicMeshPDENonStatio)
34
+ and getattr(other_data, attr_name) != main_data.omega_batch_size
35
+ )
36
+ or (
37
+ isinstance(main_data, CubicMeshPDENonStatio)
38
+ and getattr(other_data, attr_name)
39
+ != main_data.omega_batch_size * main_data.temporal_batch_size
40
+ )
41
+ ):
42
+ raise ValueError(
43
+ "Optional other_data.param_batch_size must be"
44
+ " equal to main_data.temporal_batch_size or main_data.omega_batch_size or"
45
+ " the product of both dependeing on the type of the main"
46
+ " datagenerator"
47
+ )
24
48
 
25
49
 
26
50
  def solve(
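The batch-size consistency checks previously written inline in solve() are factored into check_batch_size(). What it enforces, sketched with hypothetical generator objects standing in for the real ones:

    # main_data: DataGeneratorODE       -> other batch size must equal temporal_batch_size
    # main_data: CubicMeshPDEStatio     -> must equal omega_batch_size
    # main_data: CubicMeshPDENonStatio  -> must equal omega_batch_size * temporal_batch_size
    check_batch_size(param_data, main_data, "param_batch_size")
    check_batch_size(obs_data, main_data, "obs_batch_size")
    # a mismatch raises ValueError instead of silently mixing batch sizes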
@@ -35,6 +59,7 @@ def solve(
35
59
  tracked_params_key_list=None,
36
60
  param_data=None,
37
61
  obs_data=None,
62
+ validation=None,
38
63
  obs_batch_sharding=None,
39
64
  ):
40
65
  """
@@ -88,6 +113,20 @@ def solve(
88
113
  obs_data
89
114
  Default None. A DataGeneratorObservations object which can be used to
90
115
  sample minibatches of observations
116
+ validation
117
+ Default None. Otherwise, a callable ``eqx.Module`` which implements a
118
+ validation strategy. See documentation of :obj:`~jinns.validation.
119
+ _validation.AbstractValidationModule` for the general interface, and
120
+ :obj:`~jinns.validation._validation.ValidationLoss` for a practical
121
+ implementation of a vanilla validation strategy on a validation set of
122
+ collocation points.
123
+
124
+ **Note**: The ``__call__(self, params)`` method should have
125
+ the latter prescribed signature and return ``(validation [eqx.Module],
126
+ early_stop [bool], validation_criterion [Array])``. It is called every
127
+ ``validation.call_every`` iteration. Users are free to design any
128
+ validation strategy of their choice, and to decide on the early
129
+ stopping criterion.
91
130
  obs_batch_sharding
92
131
  Default None. An optional sharding object to constraint the obs_batch.
93
132
  Typically, a SingleDeviceSharding(gpu_device) when obs_data has been
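A minimal sketch of a custom validation module satisfying the interface documented above (the class name and criterion are hypothetical; only ``call_every`` and the ``__call__`` signature/return follow the documented contract):

    import equinox as eqx
    import jax.numpy as jnp

    class ConstantValidation(eqx.Module):
        # solve() consults this attribute to decide how often to call the module
        call_every: int = 100

        def __call__(self, params):
            # compute some validation criterion from params; a dummy value here
            criterion = jnp.array(0.0)
            early_stop = False
            # return self (possibly updated via eqx.tree_at) so solve() can
            # carry the module through the jitted loop
            return self, early_stop, criterion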
@@ -114,39 +153,14 @@ def solve(
114
153
  A dictionary. At each key an array of the values of the parameters
115
154
  given in tracked_params_key_list is stored
116
155
  """
117
- params = init_params
118
- last_non_nan_params = init_params.copy()
119
-
120
156
  if param_data is not None:
121
- if (
122
- (
123
- isinstance(data, DataGeneratorODE)
124
- and param_data.param_batch_size != data.temporal_batch_size
125
- and obs_data.obs_batch_size != data.temporal_batch_size
126
- )
127
- or (
128
- isinstance(data, CubicMeshPDEStatio)
129
- and not isinstance(data, CubicMeshPDENonStatio)
130
- and param_data.param_batch_size != data.omega_batch_size
131
- and obs_data.obs_batch_size != data.omega_batch_size
132
- )
133
- or (
134
- isinstance(data, CubicMeshPDENonStatio)
135
- and param_data.param_batch_size
136
- != data.omega_batch_size * data.temporal_batch_size
137
- and obs_data.obs_batch_size
138
- != data.omega_batch_size * data.temporal_batch_size
139
- )
140
- ):
141
- raise ValueError(
142
- "Optional param_data.param_batch_size must be"
143
- " equal to data.temporal_batch_size or data.omega_batch_size or"
144
- " the product of both dependeing on the type of the main"
145
- " datagenerator"
146
- )
157
+ check_batch_size(param_data, data, "param_batch_size")
158
+
159
+ if obs_data is not None:
160
+ check_batch_size(obs_data, data, "obs_batch_size")
147
161
 
148
162
  if opt_state is None:
149
- opt_state = optimizer.init(params)
163
+ opt_state = optimizer.init(init_params)
150
164
 
151
165
  # RAR sampling init (ouside scanned function to avoid dynamic slice error)
152
166
  # If RAR is not used the _rar_step_*() are juste None and data is unchanged
@@ -160,7 +174,7 @@ def solve(
160
174
  ), "data.method must be uniform if using seq2seq learning !"
161
175
  data, opt_state = initialize_seq2seq(loss, data, seq2seq, opt_state)
162
176
 
163
- total_loss_values = jnp.zeros((n_iter))
177
+ train_loss_values = jnp.zeros((n_iter))
164
178
  # depending on obs_batch_sharding we will get the simple get_batch or the
165
179
  # get_batch with device_put, the latter is not jittable
166
180
  get_batch = get_get_batch(obs_batch_sharding)
@@ -168,68 +182,125 @@ def solve(
168
182
  # initialize the dict for stored parameter values
169
183
  # we need to get a loss_term to init stuff
170
184
  batch_ini, data, param_data, obs_data = get_batch(data, param_data, obs_data)
171
- _, loss_terms = loss(params, batch_ini)
185
+ _, loss_terms = loss(init_params, batch_ini)
172
186
  if tracked_params_key_list is None:
173
187
  tracked_params_key_list = []
174
- tracked_params = _tracked_parameters(params, tracked_params_key_list)
188
+ tracked_params = _tracked_parameters(init_params, tracked_params_key_list)
175
189
  stored_params = jax.tree_util.tree_map(
176
190
  lambda tracked_param, param: (
177
191
  jnp.zeros((n_iter,) + param.shape) if tracked_param else None
178
192
  ),
179
193
  tracked_params,
180
- params,
194
+ init_params,
181
195
  )
182
196
 
183
197
  # initialize the dict for stored loss values
184
198
  stored_loss_terms = jax.tree_util.tree_map(
185
- lambda x: jnp.zeros((n_iter)), loss_terms
199
+ lambda _: jnp.zeros((n_iter)), loss_terms
200
+ )
201
+
202
+ train_data = DataGeneratorContainer(
203
+ data=data, param_data=param_data, obs_data=obs_data
204
+ )
205
+ optimization = OptimizationContainer(
206
+ params=init_params, last_non_nan_params=init_params.copy(), opt_state=opt_state
207
+ )
208
+ optimization_extra = OptimizationExtraContainer(
209
+ curr_seq=curr_seq,
210
+ seq2seq=seq2seq,
186
211
  )
212
+ loss_container = LossContainer(
213
+ stored_loss_terms=stored_loss_terms,
214
+ train_loss_values=train_loss_values,
215
+ )
216
+ stored_objects = StoredObjectContainer(
217
+ stored_params=stored_params,
218
+ )
219
+
220
+ if validation is not None:
221
+ validation_crit_values = jnp.zeros(n_iter)
222
+ else:
223
+ validation_crit_values = None
224
+
225
+ break_fun = get_break_fun(n_iter)
187
226
 
227
+ iteration = 0
188
228
  carry = (
189
- init_params,
190
- init_params.copy(),
191
- data,
192
- curr_seq,
193
- seq2seq,
194
- stored_params,
195
- stored_loss_terms,
229
+ iteration,
196
230
  loss,
197
- param_data,
198
- obs_data,
199
- opt_state,
200
- total_loss_values,
231
+ optimization,
232
+ optimization_extra,
233
+ train_data,
234
+ validation,
235
+ loss_container,
236
+ stored_objects,
237
+ validation_crit_values,
201
238
  )
202
239
 
203
- def one_iteration(carry, i):
240
+ def one_iteration(carry):
204
241
  (
205
- params,
206
- last_non_nan_params,
207
- data,
208
- curr_seq,
209
- seq2seq,
210
- stored_params,
211
- stored_loss_terms,
242
+ i,
212
243
  loss,
213
- param_data,
214
- obs_data,
215
- opt_state,
216
- total_loss_values,
244
+ optimization,
245
+ optimization_extra,
246
+ train_data,
247
+ validation,
248
+ loss_container,
249
+ stored_objects,
250
+ validation_crit_values,
217
251
  ) = carry
218
- batch, data, param_data, obs_data = get_batch(data, param_data, obs_data)
219
252
 
253
+ batch, data, param_data, obs_data = get_batch(
254
+ train_data.data, train_data.param_data, train_data.obs_data
255
+ )
256
+
257
+ # Gradient step
220
258
  (
221
259
  loss,
222
- loss_val,
260
+ train_loss_value,
223
261
  loss_terms,
224
262
  params,
225
263
  opt_state,
226
264
  last_non_nan_params,
227
265
  ) = gradient_step(
228
- loss, optimizer, batch, params, opt_state, last_non_nan_params
266
+ loss,
267
+ optimizer,
268
+ batch,
269
+ optimization.params,
270
+ optimization.opt_state,
271
+ optimization.last_non_nan_params,
229
272
  )
230
273
 
231
- # Print loss during optimization
232
- print_fn(i, loss_val, print_loss_every)
274
+ # Print train loss value during optimization
275
+ print_fn(i, train_loss_value, print_loss_every, prefix="[train] ")
276
+
277
+ if validation is not None:
278
+ # there is a jax.lax.cond because we do not necessarily call the
279
+ # validation step every iteration
280
+ (
281
+ validation, # always return `validation` for in-place mutation
282
+ early_stopping,
283
+ validation_criterion,
284
+ ) = jax.lax.cond(
285
+ i % validation.call_every == 0,
286
+ lambda operands: operands[0](*operands[1:]), # validation.__call__()
287
+ lambda operands: (
288
+ operands[0],
289
+ False,
290
+ validation_crit_values[i - 1],
291
+ ),
292
+ (
293
+ validation, # validation must be in operands
294
+ params,
295
+ ),
296
+ )
297
+ # Print validation loss value during optimization
298
+ print_fn(i, validation_criterion, print_loss_every, prefix="[validation] ")
299
+ validation_crit_values = validation_crit_values.at[i].set(
300
+ validation_criterion
301
+ )
302
+ else:
303
+ early_stopping = False
233
304
 
234
305
  # Trigger RAR
235
306
  loss, params, data = trigger_rar(
@@ -238,84 +309,98 @@ def solve(
238
309
 
239
310
  # Trigger seq2seq
240
311
  loss, params, data, opt_state, curr_seq, seq2seq = trigger_seq2seq(
241
- i, loss, params, data, opt_state, curr_seq, seq2seq
312
+ i,
313
+ loss,
314
+ params,
315
+ data,
316
+ opt_state,
317
+ optimization_extra.curr_seq,
318
+ optimization_extra.seq2seq,
242
319
  )
243
320
 
244
321
  # save loss value and selected parameters
245
- stored_params, stored_loss_terms, total_loss_values = store_loss_and_params(
322
+ stored_params, stored_loss_terms, train_loss_values = store_loss_and_params(
246
323
  i,
247
324
  params,
248
- stored_params,
249
- stored_loss_terms,
250
- total_loss_values,
251
- loss_val,
325
+ stored_objects.stored_params,
326
+ loss_container.stored_loss_terms,
327
+ loss_container.train_loss_values,
328
+ train_loss_value,
252
329
  loss_terms,
253
330
  tracked_params,
254
331
  )
332
+ i += 1
333
+
255
334
  return (
256
- params,
257
- last_non_nan_params,
258
- data,
259
- curr_seq,
260
- seq2seq,
261
- stored_params,
262
- stored_loss_terms,
335
+ i,
263
336
  loss,
264
- param_data,
265
- obs_data,
266
- opt_state,
267
- total_loss_values,
268
- ), None
337
+ OptimizationContainer(params, last_non_nan_params, opt_state),
338
+ OptimizationExtraContainer(curr_seq, seq2seq, early_stopping),
339
+ DataGeneratorContainer(data, param_data, obs_data),
340
+ validation,
341
+ LossContainer(stored_loss_terms, train_loss_values),
342
+ StoredObjectContainer(stored_params),
343
+ validation_crit_values,
344
+ )
269
345
 
270
- # Main optimization loop. We use the fully scanned (fully jitted) version
271
- # if no mixing devices. Otherwise we use the for loop. Here devices only
346
+ # Main optimization loop. We use the LAX while loop (fully jitted) version
347
+ # if no mixing devices. Otherwise we use the standard while loop. Here devices only
272
348
  # concern obs_batch, but it could lead to more complex scheme in the future
273
349
  if obs_batch_sharding is not None:
274
- for i in tqdm(range(n_iter)):
275
- carry, _ = one_iteration(carry, i)
350
+ while break_fun(carry):
351
+ carry = one_iteration(carry)
276
352
  else:
277
- carry, _ = jax.lax.scan(
278
- scan_tqdm(n_iter)(one_iteration),
279
- carry,
280
- jnp.arange(n_iter),
281
- )
353
+ carry = jax.lax.while_loop(break_fun, one_iteration, carry)
282
354
 
283
355
  (
284
- init_params,
285
- last_non_nan_params,
286
- data,
287
- curr_seq,
288
- seq2seq,
289
- stored_params,
290
- stored_loss_terms,
356
+ i,
291
357
  loss,
292
- param_data,
293
- obs_data,
294
- opt_state,
295
- total_loss_values,
358
+ optimization,
359
+ optimization_extra,
360
+ train_data,
361
+ validation,
362
+ loss_container,
363
+ stored_objects,
364
+ validation_crit_values,
296
365
  ) = carry
297
366
 
298
367
  jax.debug.print(
299
- "Iteration {i}: loss value = {total_loss_val}",
300
- i=n_iter,
301
- total_loss_val=total_loss_values[-1],
368
+ "Final iteration {i}: train loss value = {train_loss_val}",
369
+ i=i,
370
+ train_loss_val=loss_container.train_loss_values[i - 1],
302
371
  )
372
+ if validation is not None:
373
+ jax.debug.print(
374
+ "validation loss value = {validation_loss_val}",
375
+ validation_loss_val=validation_crit_values[i - 1],
376
+ )
303
377
 
378
+ if validation is None:
379
+ return (
380
+ optimization.last_non_nan_params,
381
+ loss_container.train_loss_values,
382
+ loss_container.stored_loss_terms,
383
+ train_data.data,
384
+ loss,
385
+ optimization.opt_state,
386
+ stored_objects.stored_params,
387
+ )
304
388
  return (
305
- last_non_nan_params,
306
- total_loss_values,
307
- stored_loss_terms,
308
- data,
389
+ optimization.last_non_nan_params,
390
+ loss_container.train_loss_values,
391
+ loss_container.stored_loss_terms,
392
+ train_data.data,
309
393
  loss,
310
- opt_state,
311
- stored_params,
394
+ optimization.opt_state,
395
+ stored_objects.stored_params,
396
+ validation_crit_values,
312
397
  )
313
398
 
314
399
 
315
400
  @partial(jit, static_argnames=["optimizer"])
316
401
  def gradient_step(loss, optimizer, batch, params, opt_state, last_non_nan_params):
317
402
  """
318
- loss and optimizer cannot be jit-ted.
403
+ optimizer cannot be jit-ted.
319
404
  """
320
405
  value_grad_loss = jax.value_and_grad(loss, has_aux=True)
321
406
  (loss_val, loss_terms), grads = value_grad_loss(params, batch)
@@ -340,14 +425,14 @@ def gradient_step(loss, optimizer, batch, params, opt_state, last_non_nan_params
340
425
  )
341
426
 
342
427
 
343
- @jit
344
- def print_fn(i, loss_val, print_loss_every):
428
+ @partial(jit, static_argnames=["prefix"])
429
+ def print_fn(i, loss_val, print_loss_every, prefix=""):
345
430
  # note that if the following is not jitted in the main lor loop, it is
346
431
  # super slow
347
432
  _ = jax.lax.cond(
348
433
  i % print_loss_every == 0,
349
434
  lambda _: jax.debug.print(
350
- "Iteration {i}: loss value = {loss_val}",
435
+ prefix + "Iteration {i}: loss value = {loss_val}",
351
436
  i=i,
352
437
  loss_val=loss_val,
353
438
  ),
@@ -362,8 +447,8 @@ def store_loss_and_params(
362
447
  params,
363
448
  stored_params,
364
449
  stored_loss_terms,
365
- total_loss_values,
366
- loss_val,
450
+ train_loss_values,
451
+ train_loss_val,
367
452
  loss_terms,
368
453
  tracked_params,
369
454
  ):
@@ -384,8 +469,66 @@ def store_loss_and_params(
384
469
  loss_terms,
385
470
  )
386
471
 
387
- total_loss_values = total_loss_values.at[i].set(loss_val)
388
- return stored_params, stored_loss_terms, total_loss_values
472
+ train_loss_values = train_loss_values.at[i].set(train_loss_val)
473
+ return (stored_params, stored_loss_terms, train_loss_values)
474
+
475
+
476
+ def get_break_fun(n_iter):
477
+ """
478
+ Wrapper to get the break_fun with appropriate `n_iter`
479
+ """
480
+
481
+ @jit
482
+ def break_fun(carry):
483
+ """
484
+ Function to break from the main optimization loop
485
+ We check several conditions
486
+ """
487
+
488
+ def stop_while_loop(msg):
489
+ """
490
+ Note that the message is wrapped in the jax.lax.cond because a
491
+ string is not a valid JAX type that can be fed into the operands
492
+ """
493
+ jax.debug.print(f"Stopping main optimization loop, cause: {msg}")
494
+ return False
495
+
496
+ def continue_while_loop(_):
497
+ return True
498
+
499
+ (i, _, optimization, optimization_extra, _, _, _, _, _) = carry
500
+
501
+ # Condition 1
502
+ bool_max_iter = jax.lax.cond(
503
+ i >= n_iter,
504
+ lambda _: stop_while_loop("max iteration is reached"),
505
+ continue_while_loop,
506
+ None,
507
+ )
508
+ # Condition 2
509
+ bool_nan_in_params = jax.lax.cond(
510
+ _check_nan_in_pytree(optimization.params),
511
+ lambda _: stop_while_loop(
512
+ "NaN values in parameters " "(returning last non NaN values)"
513
+ ),
514
+ continue_while_loop,
515
+ None,
516
+ )
517
+ # Condition 3
518
+ bool_early_stopping = jax.lax.cond(
519
+ optimization_extra.early_stopping,
520
+ lambda _: stop_while_loop("early stopping"),
521
+ continue_while_loop,
522
+ _,
523
+ )
524
+
525
+ # stop when one of the cond to continue is False
526
+ return jax.tree_util.tree_reduce(
527
+ lambda x, y: jnp.logical_and(jnp.array(x), jnp.array(y)),
528
+ (bool_max_iter, bool_nan_in_params, bool_early_stopping),
529
+ )
530
+
531
+ return break_fun
389
532
 
390
533
 
391
534
  def get_get_batch(obs_batch_sharding):
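A hedged sketch of a call to the reworked solve(): only a subset of the arguments is shown, keyword names are taken from the signature and body above, and the variables passed in are placeholders prepared elsewhere.

    import optax
    from jinns.solver._solve import solve   # import path as in this diff

    out = solve(
        init_params=init_params,
        data=train_data_generator,
        optimizer=optax.adam(1e-3),
        loss=loss,
        n_iter=5000,
        validation=validation,              # new optional argument in 0.8.7
    )
    (params, train_loss_values, stored_loss_terms, data, loss, opt_state,
     stored_params, validation_crit_values) = out
    # when validation is None, the same tuple is returned without the final
    # validation_crit_values entry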
jinns/utils/_containers.py ADDED
@@ -0,0 +1,57 @@
1
+ """
2
+ NamedTuples definition
3
+ """
4
+
5
+ from typing import Union, NamedTuple
6
+ from jaxtyping import PyTree
7
+ from jax.typing import ArrayLike
8
+ import optax
9
+ import jax.numpy as jnp
10
+ from jinns.loss._LossODE import LossODE, SystemLossODE
11
+ from jinns.loss._LossPDE import LossPDEStatio, LossPDENonStatio, SystemLossPDE
12
+ from jinns.data._DataGenerators import (
13
+ DataGeneratorODE,
14
+ CubicMeshPDEStatio,
15
+ CubicMeshPDENonStatio,
16
+ DataGeneratorParameter,
17
+ DataGeneratorObservations,
18
+ DataGeneratorObservationsMultiPINNs,
19
+ )
20
+
21
+
22
+ class DataGeneratorContainer(NamedTuple):
23
+ data: Union[DataGeneratorODE, CubicMeshPDEStatio, CubicMeshPDENonStatio]
24
+ param_data: Union[DataGeneratorParameter, None] = None
25
+ obs_data: Union[
26
+ DataGeneratorObservations, DataGeneratorObservationsMultiPINNs, None
27
+ ] = None
28
+
29
+
30
+ class ValidationContainer(NamedTuple):
31
+ loss: Union[
32
+ LossODE, SystemLossODE, LossPDEStatio, LossPDENonStatio, SystemLossPDE, None
33
+ ]
34
+ data: DataGeneratorContainer
35
+ hyperparams: PyTree = None
36
+ loss_values: Union[ArrayLike, None] = None
37
+
38
+
39
+ class OptimizationContainer(NamedTuple):
40
+ params: dict
41
+ last_non_nan_params: dict
42
+ opt_state: optax.OptState
43
+
44
+
45
+ class OptimizationExtraContainer(NamedTuple):
46
+ curr_seq: int
47
+ seq2seq: Union[dict, None]
48
+ early_stopping: bool = False
49
+
50
+
51
+ class LossContainer(NamedTuple):
52
+ stored_loss_terms: dict
53
+ train_loss_values: ArrayLike
54
+
55
+
56
+ class StoredObjectContainer(NamedTuple):
57
+ stored_params: Union[list, None]
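These containers are plain NamedTuples, which JAX already treats as pytree nodes, so they can travel through the jax.lax.while_loop carry in solve(). A trivial sketch, assuming the import path added in this release:

    import jax
    from jinns.utils._containers import OptimizationContainer

    opt = OptimizationContainer(params={"w": 1.0},
                                last_non_nan_params={"w": 1.0},
                                opt_state=None)
    # NamedTuple fields flatten transparently like any other pytree node
    print(jax.tree_util.tree_leaves(opt))   # [1.0, 1.0]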
jinns/utils/_hyperpinn.py CHANGED
@@ -216,13 +216,7 @@ def create_HYPERPINN(
216
216
 
217
217
  Returns
218
218
  -------
219
- init_fn
220
- A function which (re-)initializes the PINN parameters with the provided
221
- jax random key
222
- apply_fn
223
- A function to apply the neural network on given inputs for given
224
- parameters. A typical call will be of the form `u(t, params)` for
225
- ODE or `u(t, x, params)` for nD PDEs (`x` being multidimensional)
219
+ `u`, a :class:`.HyperPINN` object which inherits from `eqx.Module` (hence callable).
226
220
 
227
221
  Raises
228
222
  ------
@@ -289,7 +283,6 @@ def create_HYPERPINN(
289
283
 
290
284
  if shared_pinn_outputs is not None:
291
285
  hyperpinns = []
292
- static = None
293
286
  for output_slice in shared_pinn_outputs:
294
287
  hyperpinn = HYPERPINN(
295
288
  mlp,
@@ -302,11 +295,6 @@ def create_HYPERPINN(
302
295
  hypernet_input_size,
303
296
  output_slice,
304
297
  )
305
- # all the pinns are in fact the same so we share the same static
306
- if static is None:
307
- static = hyperpinn.static
308
- else:
309
- hyperpinn.static = static
310
298
  hyperpinns.append(hyperpinn)
311
299
  return hyperpinns
312
300
  hyperpinn = HYPERPINN(
jinns/utils/_pinn.py CHANGED
@@ -200,13 +200,7 @@ def create_PINN(
200
200
 
201
201
  Returns
202
202
  -------
203
- init_fn
204
- A function which (re-)initializes the PINN parameters with the provided
205
- jax random key
206
- apply_fn
207
- A function to apply the neural network on given inputs for given
208
- parameters. A typical call will be of the form `u(t, params)` for
209
- ODE or `u(t, x, params)` for nD PDEs (`x` being multidimensional)
203
+ `u`, a :class:`.PINN` object which inherits from `eqx.Module` (hence callable). This comes with a bound method :func:`u.init_params() <PINN.init_params>`. When `shared_pinn_outputs` is not None, a list of :class:`.PINN` objects with the same structure is returned, only differing by their final slicing of the network output.
210
204
 
211
205
  Raises
212
206
  ------
@@ -253,7 +247,6 @@ def create_PINN(
253
247
 
254
248
  if shared_pinn_outputs is not None:
255
249
  pinns = []
256
- static = None
257
250
  for output_slice in shared_pinn_outputs:
258
251
  pinn = PINN(
259
252
  mlp,
@@ -263,11 +256,6 @@ def create_PINN(
263
256
  output_transform,
264
257
  output_slice,
265
258
  )
266
- # all the pinns are in fact the same so we share the same static
267
- if static is None:
268
- static = pinn.static
269
- else:
270
- pinn.static = static
271
259
  pinns.append(pinn)
272
260
  return pinns
273
261
  pinn = PINN(mlp, slice_solution, eq_type, input_transform, output_transform, None)
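The documented return convention after this change, sketched (it mirrors the usage that appears in jinns/validation/_validation.py further down; subkey, eqx_list, t, x are placeholders prepared beforehand):

    u = jinns.utils.create_PINN(subkey, eqx_list, "nonstatio_PDE", 2)
    params = u.init_params()   # bound method replacing the old init_fn
    values = u(t, x, params)   # the returned PINN object is itself callable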
jinns/utils/_spinn.py CHANGED
@@ -194,13 +194,7 @@ def create_SPINN(key, d, r, eqx_list, eq_type, m=1):
194
194
 
195
195
  Returns
196
196
  -------
197
- init_fn
198
- A function which (re-)initializes the SPINN parameters with the provided
199
- jax random key
200
- apply_fn
201
- A function to apply the neural network on given inputs for given
202
- parameters. A typical call will be of the form `u(t, params)` for
203
- ODE or `u(t, x, params)` for nD PDEs (`x` being multidimensional)
197
+ `u`, a :class:`.SPINN` object which inherits from `eqx.Module` (hence callable).
204
198
 
205
199
  Raises
206
200
  ------
jinns/validation/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from ._validation import AbstractValidationModule, ValidationLoss
jinns/validation/_validation.py ADDED
@@ -0,0 +1,214 @@
1
+ """
2
+ Implements some validation functions and their associated hyperparameter
3
+ """
4
+
5
+ import copy
6
+ import abc
7
+ from typing import Union
8
+ import equinox as eqx
9
+ import jax
10
+ import jax.numpy as jnp
11
+ from jaxtyping import Array, Bool, PyTree, Int
12
+ import jinns
13
+ import jinns.data
14
+ from jinns.loss import LossODE, LossPDENonStatio, LossPDEStatio
15
+ from jinns.data._DataGenerators import (
16
+ DataGeneratorODE,
17
+ CubicMeshPDEStatio,
18
+ CubicMeshPDENonStatio,
19
+ DataGeneratorParameter,
20
+ DataGeneratorObservations,
21
+ DataGeneratorObservationsMultiPINNs,
22
+ append_obs_batch,
23
+ append_param_batch,
24
+ )
25
+ import jinns.loss
26
+
27
+ # Using eqx Module for the DataClass + Pytree inheritance
28
+ # Abstract class and abstract/final pattern is used
29
+ # see : https://docs.kidger.site/equinox/pattern/
30
+
31
+
32
+ class AbstractValidationModule(eqx.Module):
33
+ """Abstract class representing interface for any validation module. It must
34
+ 1. have a ``call_every`` attribute.
35
+ 2. implement a ``__call__`` returning ``(AbstractValidationModule, Bool, Array)``
36
+ """
37
+
38
+ call_every: eqx.AbstractVar[Int] # Mandatory for all validation step,
39
+ # it tells that the validation step is performed every call_every
40
+ # iterations.
41
+
42
+ @abc.abstractmethod
43
+ def __call__(
44
+ self, params: PyTree
45
+ ) -> tuple["AbstractValidationModule", Bool, Array]:
46
+ raise NotImplementedError
47
+
48
+
49
+ class ValidationLoss(AbstractValidationModule):
50
+ """
51
+ Implementation of a vanilla validation module returning the PINN loss
52
+ on a validation set of collocation points. This can be used as a baseline
53
+ for more complicated validation strategies.
54
+ """
55
+
56
+ loss: Union[callable, LossODE, LossPDEStatio, LossPDENonStatio] = eqx.field(
57
+ converter=copy.deepcopy
58
+ )
59
+ validation_data: Union[DataGeneratorODE, CubicMeshPDEStatio, CubicMeshPDENonStatio]
60
+ validation_param_data: Union[DataGeneratorParameter, None] = None
61
+ validation_obs_data: Union[
62
+ DataGeneratorObservations, DataGeneratorObservationsMultiPINNs, None
63
+ ] = None
64
+ call_every: Int = 250 # concrete typing
65
+ early_stopping: Bool = True # globally control if early stopping happens
66
+
67
+ patience: Union[Int] = 10
68
+ best_val_loss: Array = eqx.field(
69
+ converter=jnp.asarray, default_factory=lambda: jnp.array(jnp.inf)
70
+ )
71
+
72
+ counter: Array = eqx.field(
73
+ converter=jnp.asarray, default_factory=lambda: jnp.array(0.0)
74
+ )
75
+
76
+ def __call__(self, params) -> tuple["ValidationLoss", Bool, Array]:
77
+ # do in-place mutation
78
+ val_batch = self.validation_data.get_batch()
79
+ if self.validation_param_data is not None:
80
+ val_batch = append_param_batch(
81
+ val_batch, self.validation_param_data.get_batch()
82
+ )
83
+ if self.validation_obs_data is not None:
84
+ val_batch = append_obs_batch(
85
+ val_batch, self.validation_obs_data.get_batch()
86
+ )
87
+
88
+ validation_loss_value, _ = self.loss(params, val_batch)
89
+ (counter, best_val_loss) = jax.lax.cond(
90
+ validation_loss_value < self.best_val_loss,
91
+ lambda _: (jnp.array(0.0), validation_loss_value), # reset
92
+ lambda operands: (operands[0] + 1, operands[1]), # increment
93
+ (self.counter, self.best_val_loss),
94
+ )
95
+
96
+ # use eqx.tree_at to update attributes
97
+ # (https://github.com/patrick-kidger/equinox/issues/396)
98
+ new = eqx.tree_at(lambda t: t.counter, self, counter)
99
+ new = eqx.tree_at(lambda t: t.best_val_loss, new, best_val_loss)
100
+
101
+ bool_early_stopping = jax.lax.cond(
102
+ jnp.logical_and(
103
+ jnp.array(self.counter == self.patience),
104
+ jnp.array(self.early_stopping),
105
+ ),
106
+ lambda _: True,
107
+ lambda _: False,
108
+ None,
109
+ )
110
+ # return `new` cause no in-place modification of the eqx.Module
111
+ return (new, bool_early_stopping, validation_loss_value)
112
+
113
+
114
+ if __name__ == "__main__":
115
+ import jax
116
+ import jax.numpy as jnp
117
+ import jax.random as random
118
+ from jinns.loss import BurgerEquation
119
+
120
+ key = random.PRNGKey(1)
121
+ key, subkey = random.split(key)
122
+
123
+ n = 50
124
+ nb = 2 * 2 * 10
125
+ nt = 10
126
+ omega_batch_size = 10
127
+ omega_border_batch_size = 10
128
+ temporal_batch_size = 4
129
+ dim = 1
130
+ xmin = 0
131
+ xmax = 1
132
+ tmin, tmax = 0, 1
133
+ method = "uniform"
134
+
135
+ val_data = jinns.data.CubicMeshPDENonStatio(
136
+ subkey,
137
+ n,
138
+ nb,
139
+ nt,
140
+ omega_batch_size,
141
+ omega_border_batch_size,
142
+ temporal_batch_size,
143
+ dim,
144
+ (xmin,),
145
+ (xmax,),
146
+ tmin,
147
+ tmax,
148
+ method,
149
+ )
150
+
151
+ eqx_list = [
152
+ [eqx.nn.Linear, 2, 50],
153
+ [jax.nn.tanh],
154
+ [eqx.nn.Linear, 50, 50],
155
+ [jax.nn.tanh],
156
+ [eqx.nn.Linear, 50, 50],
157
+ [jax.nn.tanh],
158
+ [eqx.nn.Linear, 50, 50],
159
+ [jax.nn.tanh],
160
+ [eqx.nn.Linear, 50, 50],
161
+ [jax.nn.tanh],
162
+ [eqx.nn.Linear, 50, 2],
163
+ ]
164
+
165
+ key, subkey = random.split(key)
166
+ u = jinns.utils.create_PINN(
167
+ subkey, eqx_list, "nonstatio_PDE", 2, slice_solution=jnp.s_[:1]
168
+ )
169
+ init_nn_params = u.init_params()
170
+
171
+ dyn_loss = BurgerEquation()
172
+ loss_weights = {"dyn_loss": 1, "boundary_loss": 10, "observations": 10}
173
+
174
+ key, subkey = random.split(key)
175
+ loss = jinns.loss.LossPDENonStatio(
176
+ u=u,
177
+ loss_weights=loss_weights,
178
+ dynamic_loss=dyn_loss,
179
+ norm_key=subkey,
180
+ norm_borders=(-1, 1),
181
+ )
182
+ print(id(loss))
183
+ validation = ValidationLoss(
184
+ call_every=250,
185
+ early_stopping=True,
186
+ patience=1000,
187
+ loss=loss,
188
+ validation_data=val_data,
189
+ validation_param_data=None,
190
+ )
191
+ print(id(validation.loss) is not id(loss)) # should be True (deepcopy)
192
+
193
+ init_params = {"nn_params": init_nn_params, "eq_params": {"nu": 1.0}}
194
+
195
+ print(validation.loss is loss)
196
+ loss.evaluate(init_params, val_data.get_batch())
197
+ print(loss.norm_key)
198
+ print("Call validation once")
199
+ validation, _, _ = validation(init_params)
200
+ print(validation.loss is loss)
201
+ print(validation.loss.norm_key == loss.norm_key)
202
+ print("Crate new pytree from validation and call it once")
203
+ new_val = eqx.tree_at(lambda t: t.counter, validation, jnp.array(3.0))
204
+ print(validation.loss is new_val.loss) # FALSE
205
+ # test if attribute have been modified
206
+ new_val, _, _ = new_val(init_params)
207
+ print(f"{new_val.loss is loss=}")
208
+ print(f"{loss.norm_key=}")
209
+ print(f"{validation.loss.norm_key=}")
210
+ print(f"{new_val.loss.norm_key=}")
211
+ print(f"{new_val.loss.norm_key == loss.norm_key=}")
212
+ print(f"{new_val.loss.norm_key == validation.loss.norm_key=}")
213
+ print(new_val.counter)
214
+ print(validation.counter)
{jinns-0.8.5 → jinns-0.8.7}.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: jinns
3
- Version: 0.8.5
3
+ Version: 0.8.7
4
4
  Summary: Physics Informed Neural Network with JAX
5
5
  Author-email: Hugo Gangloff <hugo.gangloff@inrae.fr>, Nicolas Jouvin <nicolas.jouvin@inrae.fr>
6
6
  Maintainer-email: Hugo Gangloff <hugo.gangloff@inrae.fr>, Nicolas Jouvin <nicolas.jouvin@inrae.fr>
{jinns-0.8.5 → jinns-0.8.7}.dist-info/RECORD CHANGED
@@ -6,8 +6,8 @@ jinns/experimental/__init__.py,sha256=3jCIy2R2i_0Erwxg-HwISdH79Nt1XCXhS9yY1F5awi
6
6
  jinns/experimental/_diffrax_solver.py,sha256=sLT22byqh-6015_fhe1xtMWlFOYcCjzYKET4sLhA9R4,6818
7
7
  jinns/loss/_DynamicLoss.py,sha256=L4CVmmF0rTPbHntgqsLLHlnrlQgLHsetUocpJm7ZYag,27461
8
8
  jinns/loss/_DynamicLossAbstract.py,sha256=kTQlhLx7SBuH5dIDmYaE79sVHUZt1nUFa8LxPU5IHhM,8504
9
- jinns/loss/_LossODE.py,sha256=sxpgiDR6mfoREuc-qe0AkirOe5K_5oblaYCnodTNxoI,21912
10
- jinns/loss/_LossPDE.py,sha256=_yX3R-FrAScTn9_QfVC8PfDYRE4UQ5lnzITUYgNFitA,61766
9
+ jinns/loss/_LossODE.py,sha256=b9doBHoQwYvlgpqzrNO4dOaTN87LRvjHtHbz9bMoH7E,22119
10
+ jinns/loss/_LossPDE.py,sha256=purAEtc0e71kv9XnZUT-a7MrkDAkM_3tTI4xJPu6fH4,61629
11
11
  jinns/loss/_Losses.py,sha256=XOL3MFiKEd3ndsc78Qnpi1vbgR0B2HaAWOGGW2meDM8,11190
12
12
  jinns/loss/__init__.py,sha256=pFNYUxns-NPXBFdqrEVSiXkQLfCtKw-t2trlhvLzpYE,355
13
13
  jinns/loss/_boundary_conditions.py,sha256=YfSnLZ25hXqQ5KWAuxOrWSKkf_oBqAc9GQV4z7MjWyQ,17434
@@ -15,17 +15,20 @@ jinns/loss/_operators.py,sha256=zDGJqYqeYH7xd-4dtGX9PS-pf0uSOpUUXGo5SVjIJ4o,1106
15
15
  jinns/solver/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
16
16
  jinns/solver/_rar.py,sha256=K-0y1-ofOAo1n_Ea3QShSGCGKVYTwiaE_Bz9-DZMJm8,14525
17
17
  jinns/solver/_seq2seq.py,sha256=FL-42hTgmVl7O3hHh1ccFVw2bT8bW82hvlDRz971Chk,5620
18
- jinns/solver/_solve.py,sha256=r4jn6hx7_t-Y2rBWA2npUmWWnDg4iRbgYBHZDNn9tmY,13745
18
+ jinns/solver/_solve.py,sha256=mGi0zaT_fK_QpBjTxof5Ix4mmfmnPi66CNJ3GQFZuo4,19099
19
19
  jinns/utils/__init__.py,sha256=44ms5UR6vMw3Nf6u4RCAzPFs4fom_YbBnH9mfne8m6k,313
20
- jinns/utils/_hyperpinn.py,sha256=Mb5d6auzFfXcA81WgjiuhDBvAypAzVOENj_gUeqz6gI,11370
20
+ jinns/utils/_containers.py,sha256=eYD277fO7X4EfX7PUFCCl69r3JBfh1sCfq8LkL5gd6o,1495
21
+ jinns/utils/_hyperpinn.py,sha256=93hbiATdp5W4l1cu9Oe6O2c45o-ZF_z2u6FzNLyjnm4,10878
21
22
  jinns/utils/_optim.py,sha256=550kxH75TL30o1iKx1swJyP0KqyUPsJ7-imL1w65Qd0,4444
22
- jinns/utils/_pinn.py,sha256=N8LuB9Ql472O01USghkJkEOmx67DTjc279T8Lj-Lwd4,9722
23
+ jinns/utils/_pinn.py,sha256=mhA4-3PazyQTbWIx9oLaNwL0QDe8ZIBhbiy5J3kwa4I,9471
23
24
  jinns/utils/_save_load.py,sha256=qgZ23nUcB8-B5IZ2guuUWC4M7r5Lxd_Ms3staScdyJo,5668
24
- jinns/utils/_spinn.py,sha256=aeIC3DBY7f_N8HABjvBNv375dMyjll3zt6KjY2bEIkM,8058
25
+ jinns/utils/_spinn.py,sha256=SzOUt1KHtB9QOpghpvitnXN-KEqXUXbvabC5k0TnKEo,7793
25
26
  jinns/utils/_utils.py,sha256=8dgvWXX9NT7_7-zltWp0C9tG45ZFNwXxueyxPBb4hjo,6740
26
27
  jinns/utils/_utils_uspinn.py,sha256=qcKcOw3zrwWSQyGVj6fD8c9GinHt_U6JWN_k0auTtXM,26039
27
- jinns-0.8.5.dist-info/LICENSE,sha256=BIAkGtXB59Q_BG8f6_OqtQ1BHPv60ggE9mpXJYz2dRM,11337
28
- jinns-0.8.5.dist-info/METADATA,sha256=FrCf4ivCoMU3olFyCedvnVxCaBCiJbNfziUn1mVtyKo,2482
29
- jinns-0.8.5.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
30
- jinns-0.8.5.dist-info/top_level.txt,sha256=RXbkr2hzy8WBE8aiRyrJYFqn3JeMJIhMdybLjjLTB9c,6
31
- jinns-0.8.5.dist-info/RECORD,,
28
+ jinns/validation/__init__.py,sha256=Jv58mzgC3F7cRfXA6caicL1t_U0UAhbwLrmMNVg6E7s,66
29
+ jinns/validation/_validation.py,sha256=KfetbzB0xTNdBcYLwFWjEtP63Tf9wJirlhgqLTJDyy4,6761
30
+ jinns-0.8.7.dist-info/LICENSE,sha256=BIAkGtXB59Q_BG8f6_OqtQ1BHPv60ggE9mpXJYz2dRM,11337
31
+ jinns-0.8.7.dist-info/METADATA,sha256=L0P7JvMGKrJHx9OjrtFsmNKEwdKA_RlufAbOBf5l10I,2482
32
+ jinns-0.8.7.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
33
+ jinns-0.8.7.dist-info/top_level.txt,sha256=RXbkr2hzy8WBE8aiRyrJYFqn3JeMJIhMdybLjjLTB9c,6
34
+ jinns-0.8.7.dist-info/RECORD,,
File without changes
File without changes