jinns 1.4.0__py3-none-any.whl → 1.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
jinns/solver/_solve.py CHANGED
@@ -14,7 +14,8 @@ import optax
 import jax
 from jax import jit
 import jax.numpy as jnp
-from jaxtyping import Float, Array
+from jaxtyping import Float, Array, PyTree, Key
+import equinox as eqx
 from jinns.solver._rar import init_rar, trigger_rar
 from jinns.utils._utils import _check_nan_in_pytree
 from jinns.solver._utils import _check_batch_size
@@ -29,7 +30,8 @@ from jinns.data._utils import append_param_batch, append_obs_batch
 
 if TYPE_CHECKING:
     from jinns.parameters._params import Params
-    from jinns.utils._types import AnyLoss, AnyBatch
+    from jinns.utils._types import AnyBatch
+    from jinns.loss._abstract_loss import AbstractLoss
     from jinns.validation._validation import AbstractValidationModule
     from jinns.data._DataGeneratorParameter import DataGeneratorParameter
     from jinns.data._DataGeneratorObservations import DataGeneratorObservations
@@ -37,7 +39,7 @@ if TYPE_CHECKING:
 
 main_carry: TypeAlias = tuple[
     int,
-    AnyLoss,
+    AbstractLoss,
     OptimizationContainer,
     OptimizationExtraContainer,
     DataGeneratorContainer,
@@ -45,6 +47,7 @@ if TYPE_CHECKING:
     LossContainer,
     StoredObjectContainer,
    Float[Array, " n_iter"] | None,
+    Key | None,
 ]
 
 
@@ -52,7 +55,7 @@ def solve(
     n_iter: int,
     init_params: Params[Array],
     data: AbstractDataGenerator,
-    loss: AnyLoss,
+    loss: AbstractLoss,
     optimizer: optax.GradientTransformation,
     print_loss_every: int = 1000,
     opt_state: optax.OptState | None = None,
@@ -63,14 +66,16 @@ def solve(
     obs_batch_sharding: jax.sharding.Sharding | None = None,
     verbose: bool = True,
     ahead_of_time: bool = True,
+    key: Key = None,
 ) -> tuple[
     Params[Array],
     Float[Array, " n_iter"],
-    dict[str, Float[Array, " n_iter"]],
+    PyTree,
     AbstractDataGenerator,
-    AnyLoss,
+    AbstractLoss,
     optax.OptState,
     Params[Array | None],
+    PyTree,
     Float[Array, " n_iter"] | None,
     Params[Array],
 ]:
@@ -141,6 +146,9 @@ def solve(
         transformed (see https://jax.readthedocs.io/en/latest/aot.html#aot-compiled-functions-cannot-be-transformed).
         When False, jinns does not provide any timing information (which would
         be nonsense in a JIT transformed `solve()` function).
+    key
+        Default None. A JAX random key that can be used for diverse purpose in
+        the main iteration loop.
 
     Returns
     -------
@@ -150,8 +158,8 @@ def solve(
     total_loss_values
         An array of the total loss term along the gradient steps
     stored_loss_terms
-        A dictionary. At each key an array of the values of a given loss
-        term is stored
+        A PyTree with attributes being arrays of all the values for each loss
+        term
     data
         The input data object
     loss
@@ -161,11 +169,20 @@ def solve(
     stored_params
         A Params objects with the stored values of the desired parameters (as
         signified in tracked_params argument)
+    stored_weights_terms
+        A PyTree with attributes being arrays of all the values for each loss
+        weight. Note that if Loss.update_weight_method is None, we return None,
+        because loss weights are never updated and we can then save some
+        computations
     validation_crit_values
         An array containing the validation criterion values of the training
     best_val_params
         The best parameters according to the validation criterion
     """
+    initialization_time = time.time()
+    if n_iter < 1:
+        raise ValueError("Cannot run jinns.solve for n_iter<1")
+
     if param_data is not None:
         if param_data.param_batch_size is not None:
             # We need to check that batch sizes will all be compliant for
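Taken together, the new `key` argument and the extra `stored_weights_terms` slot in the return tuple change how a call to `jinns.solve` is written and unpacked. A minimal sketch of the implied call pattern, assuming hypothetical setup objects (data generator, loss, initial parameters) built as in previous jinns versions and not shown here:

import jax
import optax
import jinns

# hypothetical setup, built as in previous jinns versions (not shown)
# init_params, train_data, loss = ...

(
    params,
    total_loss_values,
    stored_loss_terms,
    data,
    loss,
    opt_state,
    stored_params,
    stored_weights_terms,  # new slot; None when loss.update_weight_method is None
    validation_crit_values,
    best_val_params,
) = jinns.solve(
    n_iter=10_000,
    init_params=init_params,
    data=train_data,
    loss=loss,
    optimizer=optax.adam(1e-3),
    key=jax.random.key(0),  # required when loss.update_weight_method is not None
)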
@@ -209,11 +226,6 @@ def solve(
     # get_batch with device_put, the latter is not jittable
     get_batch = _get_get_batch(obs_batch_sharding)
 
-    # initialize the dict for stored parameter values
-    # we need to get a loss_term to init stuff
-    batch_ini, data, param_data, obs_data = get_batch(data, param_data, obs_data)
-    _, loss_terms = loss(init_params, batch_ini)
-
     # initialize parameter tracking
     if tracked_params is None:
         tracked_params = jax.tree.map(lambda p: None, init_params)
@@ -231,11 +243,45 @@ def solve(
         # being a complex data structure
     )
 
-    # initialize the dict for stored loss values
+    # initialize the dict for stored parameter values
+    # we need to get a loss_term to init stuff
+    # NOTE: we use jax.eval_shape to avoid FLOPS since we only need the tree
+    # structure
+    batch_ini, data, param_data, obs_data = get_batch(data, param_data, obs_data)
+    _, loss_terms = jax.eval_shape(loss, init_params, batch_ini)
+
+    # initialize the PyTree for stored loss values
     stored_loss_terms = jax.tree_util.tree_map(
         lambda _: jnp.zeros((n_iter)), loss_terms
     )
 
+    # initialize the PyTree for stored loss weights values
+    if loss.update_weight_method is not None:
+        stored_weights_terms = eqx.tree_at(
+            lambda pt: jax.tree.leaves(
+                pt, is_leaf=lambda x: x is not None and eqx.is_inexact_array(x)
+            ),
+            loss.loss_weights,
+            tuple(
+                jnp.zeros((n_iter))
+                for n in range(
+                    len(
+                        jax.tree.leaves(
+                            loss.loss_weights,
+                            is_leaf=lambda x: x is not None and eqx.is_inexact_array(x),
+                        )
+                    )
+                )
+            ),
+        )
+    else:
+        stored_weights_terms = None
+    if loss.update_weight_method is not None and key is None:
+        raise ValueError(
+            "`key` argument must be passed to jinns.solve when"
+            " `loss.update_weight_method` is not None"
+        )
+
     train_data = DataGeneratorContainer(
         data=data, param_data=param_data, obs_data=obs_data
     )
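Two JAX idioms in the hunk above are worth spelling out: `jax.eval_shape` traces the loss abstractly, so the PyTree structure of its terms is obtained without spending any FLOPs, and `eqx.tree_at` swaps every array leaf of `loss.loss_weights` for a fresh `(n_iter,)` history buffer while keeping the surrounding structure (including `None` leaves) intact. A toy sketch of both patterns, with made-up names and values rather than jinns objects:

import equinox as eqx
import jax
import jax.numpy as jnp

n_iter = 5
weights = {"dyn_loss": jnp.array(1.0), "observations": jnp.array(10.0), "unused": None}

# jax.eval_shape evaluates abstractly: we only get ShapeDtypeStruct leaves
# describing the output tree, no real computation happens
term_structure = jax.eval_shape(lambda w: jax.tree.map(lambda x: x**2, w), weights)

# eqx.tree_at replaces the selected leaves (every inexact-array leaf here) with
# preallocated history buffers, leaving the container and the None untouched
is_weight = lambda x: x is not None and eqx.is_inexact_array(x)
history = eqx.tree_at(
    lambda pt: jax.tree.leaves(pt, is_leaf=is_weight),
    weights,
    tuple(jnp.zeros((n_iter,)) for _ in jax.tree.leaves(weights, is_leaf=is_weight)),
)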
@@ -252,6 +298,7 @@ def solve(
     )
     loss_container = LossContainer(
         stored_loss_terms=stored_loss_terms,
+        stored_weights_terms=stored_weights_terms,
         train_loss_values=train_loss_values,
     )
     stored_objects = StoredObjectContainer(
@@ -276,6 +323,7 @@ def solve(
         loss_container,
         stored_objects,
         validation_crit_values,
+        key,
     )
 
     def _one_iteration(carry: main_carry) -> main_carry:
@@ -289,24 +337,47 @@ def solve(
             loss_container,
             stored_objects,
             validation_crit_values,
+            key,
         ) = carry
 
         batch, data, param_data, obs_data = get_batch(
             train_data.data, train_data.param_data, train_data.obs_data
         )
 
-        # Gradient step
+        # ---------------------------------------------------------------------
+        # The following part is the equivalent of a
+        # > train_loss_value, grads = jax.values_and_grad(total_loss.evaluate)(params, ...)
+        # but it is decomposed on individual loss terms so that we can use it
+        # if needed for updating loss weights.
+        # Since the total loss is a weighted sum of individual loss terms, so
+        # are its total gradients.
+
+        # Compute individual losses and individual gradients
+        loss_terms, grad_terms = loss.evaluate_by_terms(optimization.params, batch)
+
+        if loss.update_weight_method is not None:
+            key, subkey = jax.random.split(key)  # type: ignore because key can
+            # still be None currently
+            # avoid computations of tree_at if no updates
+            loss = loss.update_weights(
+                i, loss_terms, loss_container.stored_loss_terms, grad_terms, subkey
+            )
+
+        # total grad
+        grads = loss.ponderate_and_sum_gradient(grad_terms)
+
+        # total loss
+        train_loss_value = loss.ponderate_and_sum_loss(loss_terms)
+        # ---------------------------------------------------------------------
+
+        # gradient step
         (
-            loss,
-            train_loss_value,
-            loss_terms,
             params,
             opt_state,
             last_non_nan_params,
         ) = _gradient_step(
-            loss,
+            grads,
             optimizer,
-            batch,
             optimization.params,
             optimization.opt_state,
             optimization.last_non_nan_params,
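The comment block above is the heart of the refactor: since the total loss is a weighted sum of individual terms, its gradient is the same weighted sum of per-term gradients, so evaluating values and gradients term by term costs nothing extra and exposes exactly the quantities an adaptive weighting scheme needs. A self-contained toy illustration of that identity in plain JAX (made-up loss terms, not jinns code):

import jax
import jax.numpy as jnp

def loss_a(p):  # stands in for e.g. a residual term
    return jnp.sum(p**2)

def loss_b(p):  # stands in for e.g. an observation term
    return jnp.sum((p - 1.0) ** 2)

weights = {"a": 0.5, "b": 2.0}
params = jnp.ones((3,))

# one value_and_grad per term
terms = {
    "a": jax.value_and_grad(loss_a)(params),
    "b": jax.value_and_grad(loss_b)(params),
}

# ... an adaptive scheme could inspect the per-term values/gradients here ...

# the weighted sums recover value_and_grad of the total weighted loss exactly
total_value = sum(w * terms[k][0] for k, w in weights.items())
total_grad = sum(w * terms[k][1] for k, w in weights.items())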
@@ -374,14 +445,14 @@ def solve(
         )
 
         # save loss value and selected parameters
-        stored_params, stored_loss_terms, train_loss_values = _store_loss_and_params(
+        stored_objects, loss_container = _store_loss_and_params(
             i,
             params,
             stored_objects.stored_params,
-            loss_container.stored_loss_terms,
-            loss_container.train_loss_values,
+            loss_container,
             train_loss_value,
             loss_terms,
+            loss.loss_weights,
             tracked_params,
         )
@@ -401,11 +472,15 @@ def solve(
             ),
             DataGeneratorContainer(data, param_data, obs_data),
             validation,
-            LossContainer(stored_loss_terms, train_loss_values),
-            StoredObjectContainer(stored_params),
+            loss_container,
+            stored_objects,
             validation_crit_values,
+            key,
         )
 
+    if verbose:
+        print("Initialization time:", time.time() - initialization_time)
+
     # Main optimization loop. We use the LAX while loop (fully jitted) version
     # if no mixing devices. Otherwise we use the standard while loop. Here devices only
     # concern obs_batch, but it could lead to more complex scheme in the future
@@ -443,6 +518,7 @@ def solve(
         loss_container,
         stored_objects,
         validation_crit_values,
+        key,
     ) = carry
 
     if verbose:
@@ -480,6 +556,7 @@ def solve(
         loss,  # return the Loss if needed (no-inplace modif)
         optimization.opt_state,
         stored_objects.stored_params,
+        loss_container.stored_weights_terms,
         validation_crit_values if validation is not None else None,
         validation_parameters,
     )
@@ -487,16 +564,12 @@ def solve(
 
 @partial(jit, static_argnames=["optimizer"])
 def _gradient_step(
-    loss: AnyLoss,
+    grads: Params[Array],
     optimizer: optax.GradientTransformation,
-    batch: AnyBatch,
     params: Params[Array],
     opt_state: optax.OptState,
     last_non_nan_params: Params[Array],
 ) -> tuple[
-    AnyLoss,
-    float,
-    dict[str, float],
     Params[Array],
     optax.OptState,
     Params[Array],
@@ -504,13 +577,12 @@ def _gradient_step(
     """
     optimizer cannot be jit-ted.
     """
-    value_grad_loss = jax.value_and_grad(loss, has_aux=True)
-    (loss_val, loss_terms), grads = value_grad_loss(params, batch)
+
     updates, opt_state = optimizer.update(
-        grads,
+        grads,  # type: ignore
        opt_state,
         params,  # type: ignore
-    )  # see optimizer.init for explaination
+    )  # see optimizer.init for explaination for the ignore(s) here
     params = optax.apply_updates(params, updates)  # type: ignore
 
     # check if any of the parameters is NaN
@@ -522,9 +594,6 @@ def _gradient_step(
     )
 
     return (
-        loss,
-        loss_val,
-        loss_terms,
         params,
         opt_state,
         last_non_nan_params,
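With the gradients now summed upstream, `_gradient_step` reduces to the standard optax update cycle applied to precomputed gradients. A generic sketch of that cycle on a toy parameter PyTree (plain optax, not the jinns containers):

import jax.numpy as jnp
import optax

params = {"theta": jnp.ones((3,))}
grads = {"theta": jnp.full((3,), 0.1)}  # assumed computed elsewhere, as in the loop above

optimizer = optax.adam(1e-3)
opt_state = optimizer.init(params)

# same two calls as in _gradient_step: turn gradients into updates, then apply them
updates, opt_state = optimizer.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)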
@@ -551,13 +620,13 @@ def _print_fn(i: int, loss_val: Float, print_loss_every: int, prefix: str = ""):
 def _store_loss_and_params(
     i: int,
     params: Params[Array],
-    stored_params: Params[Array],
-    stored_loss_terms: dict[str, Float[Array, " n_iter"]],
-    train_loss_values: Float[Array, " n_iter"],
+    stored_params: Params[Array | None],
+    loss_container: LossContainer,
     train_loss_val: float,
-    loss_terms: dict[str, float],
+    loss_terms: PyTree[Array],
+    weight_terms: PyTree[Array],
     tracked_params: Params,
-) -> tuple[Params, dict[str, Float[Array, " n_iter"]], Float[Array, " n_iter"]]:
+) -> tuple[StoredObjectContainer, LossContainer]:
     stored_params = jax.tree_util.tree_map(
         lambda stored_value, param, tracked_param: (
             None
@@ -576,12 +645,38 @@ def _store_loss_and_params(
     )
     stored_loss_terms = jax.tree_util.tree_map(
         lambda stored_term, loss_term: stored_term.at[i].set(loss_term),
-        stored_loss_terms,
+        loss_container.stored_loss_terms,
         loss_terms,
     )
 
-    train_loss_values = train_loss_values.at[i].set(train_loss_val)
-    return (stored_params, stored_loss_terms, train_loss_values)
+    if loss_container.stored_weights_terms is not None:
+        stored_weights_terms = jax.tree_util.tree_map(
+            lambda stored_term, weight_term: stored_term.at[i].set(weight_term),
+            jax.tree.leaves(
+                loss_container.stored_weights_terms,
+                is_leaf=lambda x: x is not None and eqx.is_inexact_array(x),
+            ),
+            jax.tree.leaves(
+                weight_terms,
+                is_leaf=lambda x: x is not None and eqx.is_inexact_array(x),
+            ),
+        )
+        stored_weights_terms = eqx.tree_at(
+            lambda pt: jax.tree.leaves(
+                pt, is_leaf=lambda x: x is not None and eqx.is_inexact_array(x)
+            ),
+            loss_container.stored_weights_terms,
+            stored_weights_terms,
+        )
+    else:
+        stored_weights_terms = None
+
+    train_loss_values = loss_container.train_loss_values.at[i].set(train_loss_val)
+    loss_container = LossContainer(
+        stored_loss_terms, stored_weights_terms, train_loss_values
+    )
+    stored_objects = StoredObjectContainer(stored_params)
+    return stored_objects, loss_container
 
 
 def _get_break_fun(n_iter: int, verbose: bool) -> Callable[[main_carry], bool]:
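The storage above combines two small patterns: every history buffer is a preallocated `(n_iter,)` array written functionally at step `i` with `.at[i].set(...)`, and that write is applied to every term at once with a single `tree_map`. A toy sketch with made-up term names (not the jinns containers):

import jax
import jax.numpy as jnp

n_iter = 4
history = {"dyn_loss": jnp.zeros((n_iter,)), "observations": jnp.zeros((n_iter,))}
current = {"dyn_loss": jnp.array(0.3), "observations": jnp.array(1.2)}

i = 2
# functional, jit-friendly write of step i for every stored term at once
history = jax.tree_util.tree_map(lambda h, v: h.at[i].set(v), history, current)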
@@ -612,7 +707,7 @@ def _get_break_fun(n_iter: int, verbose: bool) -> Callable[[main_carry], bool]:
         def continue_while_loop(_):
             return True
 
-        (i, _, optimization, optimization_extra, _, _, _, _, _) = carry
+        (i, _, optimization, optimization_extra, _, _, _, _, _, _) = carry
 
         # Condition 1
         bool_max_iter = jax.lax.cond(
jinns/utils/_containers.py CHANGED
@@ -6,7 +6,7 @@ from __future__ import (
     annotations,
 )  # https://docs.python.org/3/library/typing.html#constant
 
-from typing import TYPE_CHECKING, Dict
+from typing import TYPE_CHECKING
 from jaxtyping import PyTree, Array, Float, Bool
 from optax import OptState
 import equinox as eqx
@@ -48,7 +48,10 @@ class OptimizationExtraContainer(eqx.Module):
 
 
 class LossContainer(eqx.Module):
-    stored_loss_terms: Dict[str, Float[Array, " n_iter"]]
+    # PyTree below refers to ODEComponents or PDEStatioComponents or
+    # PDENonStatioComponents
+    stored_loss_terms: PyTree[Float[Array, " n_iter"]]
+    stored_weights_terms: PyTree[Float[Array, " n_iter"]]
     train_loss_values: Float[Array, " n_iter"]
 
 
jinns/utils/_types.py CHANGED
@@ -9,6 +9,11 @@ if TYPE_CHECKING:
     from jinns.data._Batchs import ODEBatch, PDEStatioBatch, PDENonStatioBatch
     from jinns.loss._LossODE import LossODE
     from jinns.loss._LossPDE import LossPDEStatio, LossPDENonStatio
+    from jinns.loss._loss_components import (
+        ODEComponents,
+        PDEStatioComponents,
+        PDENonStatioComponents,
+    )
 
     # Here we define types available for the whole package
     BoundaryConditionFun: TypeAlias = Callable[
@@ -17,3 +22,10 @@ if TYPE_CHECKING:
 
     AnyBatch: TypeAlias = ODEBatch | PDENonStatioBatch | PDEStatioBatch
     AnyLoss: TypeAlias = LossODE | LossPDEStatio | LossPDENonStatio
+
+    # here we would like a type from 3.12
+    # (https://typing.python.org/en/latest/spec/aliases.html#type-statement) so
+    # that we could have a generic AnyLossComponents
+    AnyLossComponents: TypeAlias = (
+        ODEComponents | PDEStatioComponents | PDENonStatioComponents
+    )
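The comment in the new alias refers to the PEP 695 `type` statement. Under Python 3.12 a generic alias could be written directly, roughly as in the sketch below, assuming the three component classes were made generic over their leaf type (which this diff does not show):

# hypothetical Python 3.12+ syntax, not present in jinns
type AnyLossComponents[T] = (
    ODEComponents[T] | PDEStatioComponents[T] | PDENonStatioComponents[T]
)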
jinns-1.4.0.dist-info/METADATA → jinns-1.5.1.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: jinns
-Version: 1.4.0
+Version: 1.5.1
 Summary: Physics Informed Neural Network with JAX
 Author-email: Hugo Gangloff <hugo.gangloff@inrae.fr>, Nicolas Jouvin <nicolas.jouvin@inrae.fr>
 Maintainer-email: Hugo Gangloff <hugo.gangloff@inrae.fr>, Nicolas Jouvin <nicolas.jouvin@inrae.fr>
@@ -54,6 +54,9 @@ It can also be used for forward problems and hybrid-modeling.
 
 - [Hyper PINNs](https://arxiv.org/pdf/2111.01008.pdf): useful for meta-modeling
 
+- Other
+  - Adaptative Loss Weights are now implemented. Some SoftAdapt, LRAnnealing and ReLoBRaLo are available and users can implement their own strategy. See the [tutorial](https://mia_jinns.gitlab.io/jinns/Notebooks/Tutorials/implementing_your_own_PDE_problem/)
+
 
 - **Get started**: check out our various notebooks on the [documentation](https://mia_jinns.gitlab.io/jinns/index.html).
 
@@ -113,7 +116,7 @@ pre-commit install
 
 Don't hesitate to contribute and get your name on the list here !
 
-**List of contributors:** Hugo Gangloff, Nicolas Jouvin, Lucia Clarotto, Inass Soukarieh
+**List of contributors:** Hugo Gangloff, Nicolas Jouvin, Lucia Clarotto, Inass Soukarieh, Mohamed Badi
 
 # Cite us
 
jinns-1.4.0.dist-info/RECORD → jinns-1.5.1.dist-info/RECORD CHANGED
@@ -1,8 +1,8 @@
-jinns/__init__.py,sha256=hyh3QKO2cQGK5cmvFYP0MrXb-tK_DM2T9CwLwO-sEX8,500
+jinns/__init__.py,sha256=Tjp5z0Mnd1nscvJaXnxZb4lEYbI0cpN6OPgCZ-Swo74,453
 jinns/data/_AbstractDataGenerator.py,sha256=O61TBOyeOFKwf1xqKzFD4KwCWRDnm2XgyJ-kKY9fmB4,557
 jinns/data/_Batchs.py,sha256=-DlD6Qag3zs5QbKtKAOvOzV7JOpNOqAm_P8cwo1dIZg,1574
-jinns/data/_CubicMeshPDENonStatio.py,sha256=c_8czJpxSoEvgZ8LDpL2sqtF9dcW4ELNO4juEFMOxog,16400
-jinns/data/_CubicMeshPDEStatio.py,sha256=stZ0Kbb7_VwFmWUSPs0P6a6qRj2Tu67p7sxEfb1Ajks,17865
+jinns/data/_CubicMeshPDENonStatio.py,sha256=4f21SgeNQsGJTz8-uehduU0X9TibRcs28Iq49Kv4nQQ,22250
+jinns/data/_CubicMeshPDEStatio.py,sha256=DVrP4qHVAJMu915EYH2PKyzwoG0nIMixAEKR2fz6C58,22525
 jinns/data/_DataGeneratorODE.py,sha256=5RzUbQFEsooAZsocDw4wRgA_w5lJmDMuY4M6u79K-1c,7260
 jinns/data/_DataGeneratorObservations.py,sha256=jknepLsJatSJHFq5lLMD-fFHkPGj5q286LEjE-vH24k,7738
 jinns/data/_DataGeneratorParameter.py,sha256=IedX3jcOj7ZDW_18IAcRR75KVzQzo85z9SICIKDBJl4,8539
@@ -11,14 +11,16 @@ jinns/data/_utils.py,sha256=XxaLIg_HIgcB7ACBIhTpHbCT1HXKcDaY1NABncAYX1c,5223
 jinns/experimental/__init__.py,sha256=DT9e57zbjfzPeRnXemGUqnGd--MhV77FspChT0z4YrE,410
 jinns/experimental/_diffrax_solver.py,sha256=upMr3kTTNrxEiSUO_oLvCXcjS9lPxSjvbB81h3qlhaU,6813
 jinns/loss/_DynamicLoss.py,sha256=4mb7OCP-cGZ_mG2MQ-AniddDcuBT78p4bQI7rZpwte4,22722
-jinns/loss/_DynamicLossAbstract.py,sha256=HIs6TtE9ouvT5H3cBR52UWSkALTgRWRG_kB3s890b2U,11253
-jinns/loss/_LossODE.py,sha256=aCx5vD3CXmhws36gYre1iu_t29MefpTxW54gUxKej_Q,11856
-jinns/loss/_LossPDE.py,sha256=mANXuWjm02bKwfoIn-RH8vxPY-RG3jVErcSrwwK3HzM,33259
-jinns/loss/__init__.py,sha256=qnNRGjl6Tcga1koztMJnJ3eL8XNP0gRbNsXVgq4CkOI,1162
-jinns/loss/_abstract_loss.py,sha256=hf6ohNoSAqzskFyivCiuE2SCqhsV4UWLv65L4V8H3ys,407
+jinns/loss/_DynamicLossAbstract.py,sha256=QhHRgvtcT-ifHlOxTyXbjDtHk9UfPN2Si8s3v9nEQm4,12672
+jinns/loss/_LossODE.py,sha256=iVYDojaI6Co7S5CrU67_niopD4Bk7UBTuLzDiTHoWMc,16996
+jinns/loss/_LossPDE.py,sha256=VT56oQ_33fLq46lIch0slNsxu4d97eQBOgRAPeFESts,36401
+jinns/loss/__init__.py,sha256=z5xYgBipNFf66__5BqQc6R_8r4F6A3TXL60YjsM8Osk,1287
+jinns/loss/_abstract_loss.py,sha256=DMxn0SQe9PW-pq3p5Oqvb0YK3_ulLDOnoIXzK219GV4,4576
 jinns/loss/_boundary_conditions.py,sha256=9HGw1cGLfmEilP4V4B2T0zl0YP1kNtrtXVLQNiBmWgc,12464
-jinns/loss/_loss_utils.py,sha256=R_jhBHkTwGu41gWnhYRswunxdzetPZ9-Gmkghzorock,11745
-jinns/loss/_loss_weights.py,sha256=5BVZglM7Y3m_8muXcKT898fAC6_RbdLNQ7WWx3lOE9k,1077
+jinns/loss/_loss_components.py,sha256=MMzaGlaRqESPjRzT0j0WU9HAqWQSbIXpGAqM1xQCZHw,1106
+jinns/loss/_loss_utils.py,sha256=eJ4JcBm396LHx7Tti88ZQrLcKqVL1oSfFGT23VNkytQ,11949
+jinns/loss/_loss_weight_updates.py,sha256=9Bwouh7shLyc_wrdzN6CYL0ZuQH81uEs-L6wCeiYFx8,6817
+jinns/loss/_loss_weights.py,sha256=kII5WddORgeommFTudT3CSvhICpo6nSe47LclUgu_78,2429
 jinns/loss/_operators.py,sha256=Ds5yRH7hu-jaGRp7PYbt821BgYuEvgWHufWhYgdMjw0,22909
 jinns/nn/__init__.py,sha256=gwE48oqB_FsSIE-hUvCLz0jPaqX350LBxzH6ueFWYk4,456
 jinns/nn/_abstract_pinn.py,sha256=JUFjlV_nyheZw-max_tAUgFh6SspIbD5we_4bn70V6k,671
@@ -32,22 +34,22 @@ jinns/nn/_spinn_mlp.py,sha256=uCL454sF0Tfj7KT-fdXPnvKJYRQOuq60N0r2b2VAB8Q,7606
 jinns/nn/_utils.py,sha256=9UXz73iHKHVQYPBPIEitrHYJzJ14dspRwPfLA8avx0c,1120
 jinns/parameters/__init__.py,sha256=O0n7y6R1LRmFzzugCxMFCMS2pgsuWSh-XHjfFViN_eg,265
 jinns/parameters/_derivative_keys.py,sha256=YlLDX49PfYhr2Tj--t3praiD8JOUTZU6PTmjbNZsbMc,19173
-jinns/parameters/_params.py,sha256=qn4IGMJhD9lDBqOWmGEMy4gXt5a6KHfirkYZwHO7Vwk,2633
+jinns/parameters/_params.py,sha256=nv0WScbgUdmuC0bSF15VbnKypJ58pl6wynZAcYfuF6M,3081
 jinns/plot/__init__.py,sha256=KPHX0Um4FbciZO1yD8kjZbkaT8tT964Y6SE2xCQ4eDU,135
 jinns/plot/_plot.py,sha256=-A5auNeElaz2_8UzVQJQE4143ZFg0zgMjStU7kwttEY,11565
 jinns/solver/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 jinns/solver/_rar.py,sha256=vSVTnCGCusI1vTZCvIkP2_G8we44G_42yZHx2sOK9DE,10291
-jinns/solver/_solve.py,sha256=uPJsN4Pv_QEHYMlMdo29hlJXmWyCtf2aFZlj2M8Fl2U,24886
+jinns/solver/_solve.py,sha256=IsrmG2m48KkkYgvXYomlSbZ3hd1FySCj3rwlkovs-lI,28616
 jinns/solver/_utils.py,sha256=sM2UbVzYyjw24l4QSIR3IlynJTPGD_S08r8v0lXMxA8,5876
 jinns/utils/__init__.py,sha256=OEYWLCw8pKE7xoQREbd6SHvCjuw2QZHuVA6YwDcsBE8,53
-jinns/utils/_containers.py,sha256=XANkmGiXvFb7Qh8MtGuhcZQl4Fpw4woJcn17-y1-VHs,1690
-jinns/utils/_types.py,sha256=PEPVEZ4XGT-7gCIasUHDYpIrMP_Ke1KTXGloXJPlK_k,746
+jinns/utils/_containers.py,sha256=YShcrPKfj5_I9mn3NMAS4Ea9MhhyL7fjv0e3MRbITHg,1837
+jinns/utils/_types.py,sha256=jl_91HtcrtE6UHbdTrRI8iUmr2kBUL0oP0UNIKhAXYw,1170
 jinns/utils/_utils.py,sha256=M7NXX9ok-BkH5o_xo74PB1_Cc8XiDipSl51rq82dTH4,2821
 jinns/validation/__init__.py,sha256=FTyUO-v1b8Tv-FDSQsntrH7zl9E0ENexqKMT_dFRkYo,124
 jinns/validation/_validation.py,sha256=8p6sMKiBAvA6JNm65hjkMj0997LJ0BkyCREEh0AnPVE,4803
-jinns-1.4.0.dist-info/licenses/AUTHORS,sha256=7NwCj9nU-HNG1asvy4qhQ2w7oZHrn-Lk5_wK_Ve7a3M,80
-jinns-1.4.0.dist-info/licenses/LICENSE,sha256=BIAkGtXB59Q_BG8f6_OqtQ1BHPv60ggE9mpXJYz2dRM,11337
-jinns-1.4.0.dist-info/METADATA,sha256=MkB5xNjrdFcJHmlwhk_RRwNNuCdkHLYpAxB-7TYhykg,5031
-jinns-1.4.0.dist-info/WHEEL,sha256=0CuiUZ_p9E4cD6NyLD6UG80LBXYyiSYZOKDm5lp32xk,91
-jinns-1.4.0.dist-info/top_level.txt,sha256=RXbkr2hzy8WBE8aiRyrJYFqn3JeMJIhMdybLjjLTB9c,6
-jinns-1.4.0.dist-info/RECORD,,
+jinns-1.5.1.dist-info/licenses/AUTHORS,sha256=7NwCj9nU-HNG1asvy4qhQ2w7oZHrn-Lk5_wK_Ve7a3M,80
+jinns-1.5.1.dist-info/licenses/LICENSE,sha256=BIAkGtXB59Q_BG8f6_OqtQ1BHPv60ggE9mpXJYz2dRM,11337
+jinns-1.5.1.dist-info/METADATA,sha256=K7Aii5ivFcczIwLlQtCPqzMJEfW86D1yW1q7qMtvWPE,5314
+jinns-1.5.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+jinns-1.5.1.dist-info/top_level.txt,sha256=RXbkr2hzy8WBE8aiRyrJYFqn3JeMJIhMdybLjjLTB9c,6
+jinns-1.5.1.dist-info/RECORD,,
jinns-1.4.0.dist-info/WHEEL → jinns-1.5.1.dist-info/WHEEL CHANGED
@@ -1,5 +1,5 @@
 Wheel-Version: 1.0
-Generator: setuptools (80.3.1)
+Generator: setuptools (80.9.0)
 Root-Is-Purelib: true
 Tag: py3-none-any
 