jinns 1.4.0__py3-none-any.whl → 1.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jinns/__init__.py +7 -7
- jinns/data/_CubicMeshPDENonStatio.py +156 -28
- jinns/data/_CubicMeshPDEStatio.py +132 -24
- jinns/loss/_DynamicLossAbstract.py +30 -2
- jinns/loss/_LossODE.py +177 -64
- jinns/loss/_LossPDE.py +146 -68
- jinns/loss/__init__.py +4 -0
- jinns/loss/_abstract_loss.py +116 -3
- jinns/loss/_loss_components.py +43 -0
- jinns/loss/_loss_utils.py +34 -24
- jinns/loss/_loss_weight_updates.py +202 -0
- jinns/loss/_loss_weights.py +72 -16
- jinns/parameters/_params.py +8 -0
- jinns/solver/_solve.py +141 -46
- jinns/utils/_containers.py +5 -2
- jinns/utils/_types.py +12 -0
- {jinns-1.4.0.dist-info → jinns-1.5.1.dist-info}/METADATA +5 -2
- {jinns-1.4.0.dist-info → jinns-1.5.1.dist-info}/RECORD +22 -20
- {jinns-1.4.0.dist-info → jinns-1.5.1.dist-info}/WHEEL +1 -1
- {jinns-1.4.0.dist-info → jinns-1.5.1.dist-info}/licenses/AUTHORS +0 -0
- {jinns-1.4.0.dist-info → jinns-1.5.1.dist-info}/licenses/LICENSE +0 -0
- {jinns-1.4.0.dist-info → jinns-1.5.1.dist-info}/top_level.txt +0 -0
jinns/solver/_solve.py
CHANGED
|
@@ -14,7 +14,8 @@ import optax
|
|
|
14
14
|
import jax
|
|
15
15
|
from jax import jit
|
|
16
16
|
import jax.numpy as jnp
|
|
17
|
-
from jaxtyping import Float, Array
|
|
17
|
+
from jaxtyping import Float, Array, PyTree, Key
|
|
18
|
+
import equinox as eqx
|
|
18
19
|
from jinns.solver._rar import init_rar, trigger_rar
|
|
19
20
|
from jinns.utils._utils import _check_nan_in_pytree
|
|
20
21
|
from jinns.solver._utils import _check_batch_size
|
|
@@ -29,7 +30,8 @@ from jinns.data._utils import append_param_batch, append_obs_batch
|
|
|
29
30
|
|
|
30
31
|
if TYPE_CHECKING:
|
|
31
32
|
from jinns.parameters._params import Params
|
|
32
|
-
from jinns.utils._types import
|
|
33
|
+
from jinns.utils._types import AnyBatch
|
|
34
|
+
from jinns.loss._abstract_loss import AbstractLoss
|
|
33
35
|
from jinns.validation._validation import AbstractValidationModule
|
|
34
36
|
from jinns.data._DataGeneratorParameter import DataGeneratorParameter
|
|
35
37
|
from jinns.data._DataGeneratorObservations import DataGeneratorObservations
|
|
@@ -37,7 +39,7 @@ if TYPE_CHECKING:
|
|
|
37
39
|
|
|
38
40
|
main_carry: TypeAlias = tuple[
|
|
39
41
|
int,
|
|
40
|
-
|
|
42
|
+
AbstractLoss,
|
|
41
43
|
OptimizationContainer,
|
|
42
44
|
OptimizationExtraContainer,
|
|
43
45
|
DataGeneratorContainer,
|
|
@@ -45,6 +47,7 @@ if TYPE_CHECKING:
|
|
|
45
47
|
LossContainer,
|
|
46
48
|
StoredObjectContainer,
|
|
47
49
|
Float[Array, " n_iter"] | None,
|
|
50
|
+
Key | None,
|
|
48
51
|
]
|
|
49
52
|
|
|
50
53
|
|
|
@@ -52,7 +55,7 @@ def solve(
|
|
|
52
55
|
n_iter: int,
|
|
53
56
|
init_params: Params[Array],
|
|
54
57
|
data: AbstractDataGenerator,
|
|
55
|
-
loss:
|
|
58
|
+
loss: AbstractLoss,
|
|
56
59
|
optimizer: optax.GradientTransformation,
|
|
57
60
|
print_loss_every: int = 1000,
|
|
58
61
|
opt_state: optax.OptState | None = None,
|
|
@@ -63,14 +66,16 @@ def solve(
|
|
|
63
66
|
obs_batch_sharding: jax.sharding.Sharding | None = None,
|
|
64
67
|
verbose: bool = True,
|
|
65
68
|
ahead_of_time: bool = True,
|
|
69
|
+
key: Key = None,
|
|
66
70
|
) -> tuple[
|
|
67
71
|
Params[Array],
|
|
68
72
|
Float[Array, " n_iter"],
|
|
69
|
-
|
|
73
|
+
PyTree,
|
|
70
74
|
AbstractDataGenerator,
|
|
71
|
-
|
|
75
|
+
AbstractLoss,
|
|
72
76
|
optax.OptState,
|
|
73
77
|
Params[Array | None],
|
|
78
|
+
PyTree,
|
|
74
79
|
Float[Array, " n_iter"] | None,
|
|
75
80
|
Params[Array],
|
|
76
81
|
]:
|
|
@@ -141,6 +146,9 @@ def solve(
|
|
|
141
146
|
transformed (see https://jax.readthedocs.io/en/latest/aot.html#aot-compiled-functions-cannot-be-transformed).
|
|
142
147
|
When False, jinns does not provide any timing information (which would
|
|
143
148
|
be nonsense in a JIT transformed `solve()` function).
|
|
149
|
+
key
|
|
150
|
+
Default None. A JAX random key that can be used for diverse purpose in
|
|
151
|
+
the main iteration loop.
|
|
144
152
|
|
|
145
153
|
Returns
|
|
146
154
|
-------
|
|
@@ -150,8 +158,8 @@ def solve(
|
|
|
150
158
|
total_loss_values
|
|
151
159
|
An array of the total loss term along the gradient steps
|
|
152
160
|
stored_loss_terms
|
|
153
|
-
A
|
|
154
|
-
term
|
|
161
|
+
A PyTree with attributes being arrays of all the values for each loss
|
|
162
|
+
term
|
|
155
163
|
data
|
|
156
164
|
The input data object
|
|
157
165
|
loss
|
|
@@ -161,11 +169,20 @@ def solve(
|
|
|
161
169
|
stored_params
|
|
162
170
|
A Params objects with the stored values of the desired parameters (as
|
|
163
171
|
signified in tracked_params argument)
|
|
172
|
+
stored_weights_terms
|
|
173
|
+
A PyTree with attributes being arrays of all the values for each loss
|
|
174
|
+
weight. Note that if Loss.update_weight_method is None, we return None,
|
|
175
|
+
because loss weights are never updated and we can then save some
|
|
176
|
+
computations
|
|
164
177
|
validation_crit_values
|
|
165
178
|
An array containing the validation criterion values of the training
|
|
166
179
|
best_val_params
|
|
167
180
|
The best parameters according to the validation criterion
|
|
168
181
|
"""
|
|
182
|
+
initialization_time = time.time()
|
|
183
|
+
if n_iter < 1:
|
|
184
|
+
raise ValueError("Cannot run jinns.solve for n_iter<1")
|
|
185
|
+
|
|
169
186
|
if param_data is not None:
|
|
170
187
|
if param_data.param_batch_size is not None:
|
|
171
188
|
# We need to check that batch sizes will all be compliant for
|
|
@@ -209,11 +226,6 @@ def solve(
|
|
|
209
226
|
# get_batch with device_put, the latter is not jittable
|
|
210
227
|
get_batch = _get_get_batch(obs_batch_sharding)
|
|
211
228
|
|
|
212
|
-
# initialize the dict for stored parameter values
|
|
213
|
-
# we need to get a loss_term to init stuff
|
|
214
|
-
batch_ini, data, param_data, obs_data = get_batch(data, param_data, obs_data)
|
|
215
|
-
_, loss_terms = loss(init_params, batch_ini)
|
|
216
|
-
|
|
217
229
|
# initialize parameter tracking
|
|
218
230
|
if tracked_params is None:
|
|
219
231
|
tracked_params = jax.tree.map(lambda p: None, init_params)
|
|
@@ -231,11 +243,45 @@ def solve(
|
|
|
231
243
|
# being a complex data structure
|
|
232
244
|
)
|
|
233
245
|
|
|
234
|
-
# initialize the dict for stored
|
|
246
|
+
# initialize the dict for stored parameter values
|
|
247
|
+
# we need to get a loss_term to init stuff
|
|
248
|
+
# NOTE: we use jax.eval_shape to avoid FLOPS since we only need the tree
|
|
249
|
+
# structure
|
|
250
|
+
batch_ini, data, param_data, obs_data = get_batch(data, param_data, obs_data)
|
|
251
|
+
_, loss_terms = jax.eval_shape(loss, init_params, batch_ini)
|
|
252
|
+
|
|
253
|
+
# initialize the PyTree for stored loss values
|
|
235
254
|
stored_loss_terms = jax.tree_util.tree_map(
|
|
236
255
|
lambda _: jnp.zeros((n_iter)), loss_terms
|
|
237
256
|
)
|
|
238
257
|
|
|
258
|
+
# initialize the PyTree for stored loss weights values
|
|
259
|
+
if loss.update_weight_method is not None:
|
|
260
|
+
stored_weights_terms = eqx.tree_at(
|
|
261
|
+
lambda pt: jax.tree.leaves(
|
|
262
|
+
pt, is_leaf=lambda x: x is not None and eqx.is_inexact_array(x)
|
|
263
|
+
),
|
|
264
|
+
loss.loss_weights,
|
|
265
|
+
tuple(
|
|
266
|
+
jnp.zeros((n_iter))
|
|
267
|
+
for n in range(
|
|
268
|
+
len(
|
|
269
|
+
jax.tree.leaves(
|
|
270
|
+
loss.loss_weights,
|
|
271
|
+
is_leaf=lambda x: x is not None and eqx.is_inexact_array(x),
|
|
272
|
+
)
|
|
273
|
+
)
|
|
274
|
+
)
|
|
275
|
+
),
|
|
276
|
+
)
|
|
277
|
+
else:
|
|
278
|
+
stored_weights_terms = None
|
|
279
|
+
if loss.update_weight_method is not None and key is None:
|
|
280
|
+
raise ValueError(
|
|
281
|
+
"`key` argument must be passed to jinns.solve when"
|
|
282
|
+
" `loss.update_weight_method` is not None"
|
|
283
|
+
)
|
|
284
|
+
|
|
239
285
|
train_data = DataGeneratorContainer(
|
|
240
286
|
data=data, param_data=param_data, obs_data=obs_data
|
|
241
287
|
)
|
|
@@ -252,6 +298,7 @@ def solve(
|
|
|
252
298
|
)
|
|
253
299
|
loss_container = LossContainer(
|
|
254
300
|
stored_loss_terms=stored_loss_terms,
|
|
301
|
+
stored_weights_terms=stored_weights_terms,
|
|
255
302
|
train_loss_values=train_loss_values,
|
|
256
303
|
)
|
|
257
304
|
stored_objects = StoredObjectContainer(
|
|
@@ -276,6 +323,7 @@ def solve(
|
|
|
276
323
|
loss_container,
|
|
277
324
|
stored_objects,
|
|
278
325
|
validation_crit_values,
|
|
326
|
+
key,
|
|
279
327
|
)
|
|
280
328
|
|
|
281
329
|
def _one_iteration(carry: main_carry) -> main_carry:
|
|
@@ -289,24 +337,47 @@ def solve(
|
|
|
289
337
|
loss_container,
|
|
290
338
|
stored_objects,
|
|
291
339
|
validation_crit_values,
|
|
340
|
+
key,
|
|
292
341
|
) = carry
|
|
293
342
|
|
|
294
343
|
batch, data, param_data, obs_data = get_batch(
|
|
295
344
|
train_data.data, train_data.param_data, train_data.obs_data
|
|
296
345
|
)
|
|
297
346
|
|
|
298
|
-
#
|
|
347
|
+
# ---------------------------------------------------------------------
|
|
348
|
+
# The following part is the equivalent of a
|
|
349
|
+
# > train_loss_value, grads = jax.values_and_grad(total_loss.evaluate)(params, ...)
|
|
350
|
+
# but it is decomposed on individual loss terms so that we can use it
|
|
351
|
+
# if needed for updating loss weights.
|
|
352
|
+
# Since the total loss is a weighted sum of individual loss terms, so
|
|
353
|
+
# are its total gradients.
|
|
354
|
+
|
|
355
|
+
# Compute individual losses and individual gradients
|
|
356
|
+
loss_terms, grad_terms = loss.evaluate_by_terms(optimization.params, batch)
|
|
357
|
+
|
|
358
|
+
if loss.update_weight_method is not None:
|
|
359
|
+
key, subkey = jax.random.split(key) # type: ignore because key can
|
|
360
|
+
# still be None currently
|
|
361
|
+
# avoid computations of tree_at if no updates
|
|
362
|
+
loss = loss.update_weights(
|
|
363
|
+
i, loss_terms, loss_container.stored_loss_terms, grad_terms, subkey
|
|
364
|
+
)
|
|
365
|
+
|
|
366
|
+
# total grad
|
|
367
|
+
grads = loss.ponderate_and_sum_gradient(grad_terms)
|
|
368
|
+
|
|
369
|
+
# total loss
|
|
370
|
+
train_loss_value = loss.ponderate_and_sum_loss(loss_terms)
|
|
371
|
+
# ---------------------------------------------------------------------
|
|
372
|
+
|
|
373
|
+
# gradient step
|
|
299
374
|
(
|
|
300
|
-
loss,
|
|
301
|
-
train_loss_value,
|
|
302
|
-
loss_terms,
|
|
303
375
|
params,
|
|
304
376
|
opt_state,
|
|
305
377
|
last_non_nan_params,
|
|
306
378
|
) = _gradient_step(
|
|
307
|
-
|
|
379
|
+
grads,
|
|
308
380
|
optimizer,
|
|
309
|
-
batch,
|
|
310
381
|
optimization.params,
|
|
311
382
|
optimization.opt_state,
|
|
312
383
|
optimization.last_non_nan_params,
|
|
@@ -374,14 +445,14 @@ def solve(
|
|
|
374
445
|
)
|
|
375
446
|
|
|
376
447
|
# save loss value and selected parameters
|
|
377
|
-
|
|
448
|
+
stored_objects, loss_container = _store_loss_and_params(
|
|
378
449
|
i,
|
|
379
450
|
params,
|
|
380
451
|
stored_objects.stored_params,
|
|
381
|
-
loss_container
|
|
382
|
-
loss_container.train_loss_values,
|
|
452
|
+
loss_container,
|
|
383
453
|
train_loss_value,
|
|
384
454
|
loss_terms,
|
|
455
|
+
loss.loss_weights,
|
|
385
456
|
tracked_params,
|
|
386
457
|
)
|
|
387
458
|
|
|
@@ -401,11 +472,15 @@ def solve(
|
|
|
401
472
|
),
|
|
402
473
|
DataGeneratorContainer(data, param_data, obs_data),
|
|
403
474
|
validation,
|
|
404
|
-
|
|
405
|
-
|
|
475
|
+
loss_container,
|
|
476
|
+
stored_objects,
|
|
406
477
|
validation_crit_values,
|
|
478
|
+
key,
|
|
407
479
|
)
|
|
408
480
|
|
|
481
|
+
if verbose:
|
|
482
|
+
print("Initialization time:", time.time() - initialization_time)
|
|
483
|
+
|
|
409
484
|
# Main optimization loop. We use the LAX while loop (fully jitted) version
|
|
410
485
|
# if no mixing devices. Otherwise we use the standard while loop. Here devices only
|
|
411
486
|
# concern obs_batch, but it could lead to more complex scheme in the future
|
|
@@ -443,6 +518,7 @@ def solve(
|
|
|
443
518
|
loss_container,
|
|
444
519
|
stored_objects,
|
|
445
520
|
validation_crit_values,
|
|
521
|
+
key,
|
|
446
522
|
) = carry
|
|
447
523
|
|
|
448
524
|
if verbose:
|
|
@@ -480,6 +556,7 @@ def solve(
|
|
|
480
556
|
loss, # return the Loss if needed (no-inplace modif)
|
|
481
557
|
optimization.opt_state,
|
|
482
558
|
stored_objects.stored_params,
|
|
559
|
+
loss_container.stored_weights_terms,
|
|
483
560
|
validation_crit_values if validation is not None else None,
|
|
484
561
|
validation_parameters,
|
|
485
562
|
)
|
|
@@ -487,16 +564,12 @@ def solve(
|
|
|
487
564
|
|
|
488
565
|
@partial(jit, static_argnames=["optimizer"])
|
|
489
566
|
def _gradient_step(
|
|
490
|
-
|
|
567
|
+
grads: Params[Array],
|
|
491
568
|
optimizer: optax.GradientTransformation,
|
|
492
|
-
batch: AnyBatch,
|
|
493
569
|
params: Params[Array],
|
|
494
570
|
opt_state: optax.OptState,
|
|
495
571
|
last_non_nan_params: Params[Array],
|
|
496
572
|
) -> tuple[
|
|
497
|
-
AnyLoss,
|
|
498
|
-
float,
|
|
499
|
-
dict[str, float],
|
|
500
573
|
Params[Array],
|
|
501
574
|
optax.OptState,
|
|
502
575
|
Params[Array],
|
|
@@ -504,13 +577,12 @@ def _gradient_step(
|
|
|
504
577
|
"""
|
|
505
578
|
optimizer cannot be jit-ted.
|
|
506
579
|
"""
|
|
507
|
-
|
|
508
|
-
(loss_val, loss_terms), grads = value_grad_loss(params, batch)
|
|
580
|
+
|
|
509
581
|
updates, opt_state = optimizer.update(
|
|
510
|
-
grads,
|
|
582
|
+
grads, # type: ignore
|
|
511
583
|
opt_state,
|
|
512
584
|
params, # type: ignore
|
|
513
|
-
) # see optimizer.init for explaination
|
|
585
|
+
) # see optimizer.init for explaination for the ignore(s) here
|
|
514
586
|
params = optax.apply_updates(params, updates) # type: ignore
|
|
515
587
|
|
|
516
588
|
# check if any of the parameters is NaN
|
|
@@ -522,9 +594,6 @@ def _gradient_step(
|
|
|
522
594
|
)
|
|
523
595
|
|
|
524
596
|
return (
|
|
525
|
-
loss,
|
|
526
|
-
loss_val,
|
|
527
|
-
loss_terms,
|
|
528
597
|
params,
|
|
529
598
|
opt_state,
|
|
530
599
|
last_non_nan_params,
|
|
@@ -551,13 +620,13 @@ def _print_fn(i: int, loss_val: Float, print_loss_every: int, prefix: str = ""):
|
|
|
551
620
|
def _store_loss_and_params(
|
|
552
621
|
i: int,
|
|
553
622
|
params: Params[Array],
|
|
554
|
-
stored_params: Params[Array],
|
|
555
|
-
|
|
556
|
-
train_loss_values: Float[Array, " n_iter"],
|
|
623
|
+
stored_params: Params[Array | None],
|
|
624
|
+
loss_container: LossContainer,
|
|
557
625
|
train_loss_val: float,
|
|
558
|
-
loss_terms:
|
|
626
|
+
loss_terms: PyTree[Array],
|
|
627
|
+
weight_terms: PyTree[Array],
|
|
559
628
|
tracked_params: Params,
|
|
560
|
-
) -> tuple[
|
|
629
|
+
) -> tuple[StoredObjectContainer, LossContainer]:
|
|
561
630
|
stored_params = jax.tree_util.tree_map(
|
|
562
631
|
lambda stored_value, param, tracked_param: (
|
|
563
632
|
None
|
|
@@ -576,12 +645,38 @@ def _store_loss_and_params(
|
|
|
576
645
|
)
|
|
577
646
|
stored_loss_terms = jax.tree_util.tree_map(
|
|
578
647
|
lambda stored_term, loss_term: stored_term.at[i].set(loss_term),
|
|
579
|
-
stored_loss_terms,
|
|
648
|
+
loss_container.stored_loss_terms,
|
|
580
649
|
loss_terms,
|
|
581
650
|
)
|
|
582
651
|
|
|
583
|
-
|
|
584
|
-
|
|
652
|
+
if loss_container.stored_weights_terms is not None:
|
|
653
|
+
stored_weights_terms = jax.tree_util.tree_map(
|
|
654
|
+
lambda stored_term, weight_term: stored_term.at[i].set(weight_term),
|
|
655
|
+
jax.tree.leaves(
|
|
656
|
+
loss_container.stored_weights_terms,
|
|
657
|
+
is_leaf=lambda x: x is not None and eqx.is_inexact_array(x),
|
|
658
|
+
),
|
|
659
|
+
jax.tree.leaves(
|
|
660
|
+
weight_terms,
|
|
661
|
+
is_leaf=lambda x: x is not None and eqx.is_inexact_array(x),
|
|
662
|
+
),
|
|
663
|
+
)
|
|
664
|
+
stored_weights_terms = eqx.tree_at(
|
|
665
|
+
lambda pt: jax.tree.leaves(
|
|
666
|
+
pt, is_leaf=lambda x: x is not None and eqx.is_inexact_array(x)
|
|
667
|
+
),
|
|
668
|
+
loss_container.stored_weights_terms,
|
|
669
|
+
stored_weights_terms,
|
|
670
|
+
)
|
|
671
|
+
else:
|
|
672
|
+
stored_weights_terms = None
|
|
673
|
+
|
|
674
|
+
train_loss_values = loss_container.train_loss_values.at[i].set(train_loss_val)
|
|
675
|
+
loss_container = LossContainer(
|
|
676
|
+
stored_loss_terms, stored_weights_terms, train_loss_values
|
|
677
|
+
)
|
|
678
|
+
stored_objects = StoredObjectContainer(stored_params)
|
|
679
|
+
return stored_objects, loss_container
|
|
585
680
|
|
|
586
681
|
|
|
587
682
|
def _get_break_fun(n_iter: int, verbose: bool) -> Callable[[main_carry], bool]:
|
|
@@ -612,7 +707,7 @@ def _get_break_fun(n_iter: int, verbose: bool) -> Callable[[main_carry], bool]:
|
|
|
612
707
|
def continue_while_loop(_):
|
|
613
708
|
return True
|
|
614
709
|
|
|
615
|
-
(i, _, optimization, optimization_extra, _, _, _, _, _) = carry
|
|
710
|
+
(i, _, optimization, optimization_extra, _, _, _, _, _, _) = carry
|
|
616
711
|
|
|
617
712
|
# Condition 1
|
|
618
713
|
bool_max_iter = jax.lax.cond(
|
jinns/utils/_containers.py
CHANGED
|
@@ -6,7 +6,7 @@ from __future__ import (
|
|
|
6
6
|
annotations,
|
|
7
7
|
) # https://docs.python.org/3/library/typing.html#constant
|
|
8
8
|
|
|
9
|
-
from typing import TYPE_CHECKING
|
|
9
|
+
from typing import TYPE_CHECKING
|
|
10
10
|
from jaxtyping import PyTree, Array, Float, Bool
|
|
11
11
|
from optax import OptState
|
|
12
12
|
import equinox as eqx
|
|
@@ -48,7 +48,10 @@ class OptimizationExtraContainer(eqx.Module):
|
|
|
48
48
|
|
|
49
49
|
|
|
50
50
|
class LossContainer(eqx.Module):
|
|
51
|
-
|
|
51
|
+
# PyTree below refers to ODEComponents or PDEStatioComponents or
|
|
52
|
+
# PDENonStatioComponents
|
|
53
|
+
stored_loss_terms: PyTree[Float[Array, " n_iter"]]
|
|
54
|
+
stored_weights_terms: PyTree[Float[Array, " n_iter"]]
|
|
52
55
|
train_loss_values: Float[Array, " n_iter"]
|
|
53
56
|
|
|
54
57
|
|
jinns/utils/_types.py
CHANGED
|
@@ -9,6 +9,11 @@ if TYPE_CHECKING:
|
|
|
9
9
|
from jinns.data._Batchs import ODEBatch, PDEStatioBatch, PDENonStatioBatch
|
|
10
10
|
from jinns.loss._LossODE import LossODE
|
|
11
11
|
from jinns.loss._LossPDE import LossPDEStatio, LossPDENonStatio
|
|
12
|
+
from jinns.loss._loss_components import (
|
|
13
|
+
ODEComponents,
|
|
14
|
+
PDEStatioComponents,
|
|
15
|
+
PDENonStatioComponents,
|
|
16
|
+
)
|
|
12
17
|
|
|
13
18
|
# Here we define types available for the whole package
|
|
14
19
|
BoundaryConditionFun: TypeAlias = Callable[
|
|
@@ -17,3 +22,10 @@ if TYPE_CHECKING:
|
|
|
17
22
|
|
|
18
23
|
AnyBatch: TypeAlias = ODEBatch | PDENonStatioBatch | PDEStatioBatch
|
|
19
24
|
AnyLoss: TypeAlias = LossODE | LossPDEStatio | LossPDENonStatio
|
|
25
|
+
|
|
26
|
+
# here we would like a type from 3.12
|
|
27
|
+
# (https://typing.python.org/en/latest/spec/aliases.html#type-statement) so
|
|
28
|
+
# that we could have a generic AnyLossComponents
|
|
29
|
+
AnyLossComponents: TypeAlias = (
|
|
30
|
+
ODEComponents | PDEStatioComponents | PDENonStatioComponents
|
|
31
|
+
)
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: jinns
|
|
3
|
-
Version: 1.
|
|
3
|
+
Version: 1.5.1
|
|
4
4
|
Summary: Physics Informed Neural Network with JAX
|
|
5
5
|
Author-email: Hugo Gangloff <hugo.gangloff@inrae.fr>, Nicolas Jouvin <nicolas.jouvin@inrae.fr>
|
|
6
6
|
Maintainer-email: Hugo Gangloff <hugo.gangloff@inrae.fr>, Nicolas Jouvin <nicolas.jouvin@inrae.fr>
|
|
@@ -54,6 +54,9 @@ It can also be used for forward problems and hybrid-modeling.
|
|
|
54
54
|
|
|
55
55
|
- [Hyper PINNs](https://arxiv.org/pdf/2111.01008.pdf): useful for meta-modeling
|
|
56
56
|
|
|
57
|
+
- Other
|
|
58
|
+
- Adaptative Loss Weights are now implemented. Some SoftAdapt, LRAnnealing and ReLoBRaLo are available and users can implement their own strategy. See the [tutorial](https://mia_jinns.gitlab.io/jinns/Notebooks/Tutorials/implementing_your_own_PDE_problem/)
|
|
59
|
+
|
|
57
60
|
|
|
58
61
|
- **Get started**: check out our various notebooks on the [documentation](https://mia_jinns.gitlab.io/jinns/index.html).
|
|
59
62
|
|
|
@@ -113,7 +116,7 @@ pre-commit install
|
|
|
113
116
|
|
|
114
117
|
Don't hesitate to contribute and get your name on the list here !
|
|
115
118
|
|
|
116
|
-
**List of contributors:** Hugo Gangloff, Nicolas Jouvin, Lucia Clarotto, Inass Soukarieh
|
|
119
|
+
**List of contributors:** Hugo Gangloff, Nicolas Jouvin, Lucia Clarotto, Inass Soukarieh, Mohamed Badi
|
|
117
120
|
|
|
118
121
|
# Cite us
|
|
119
122
|
|
|
@@ -1,8 +1,8 @@
|
|
|
1
|
-
jinns/__init__.py,sha256=
|
|
1
|
+
jinns/__init__.py,sha256=Tjp5z0Mnd1nscvJaXnxZb4lEYbI0cpN6OPgCZ-Swo74,453
|
|
2
2
|
jinns/data/_AbstractDataGenerator.py,sha256=O61TBOyeOFKwf1xqKzFD4KwCWRDnm2XgyJ-kKY9fmB4,557
|
|
3
3
|
jinns/data/_Batchs.py,sha256=-DlD6Qag3zs5QbKtKAOvOzV7JOpNOqAm_P8cwo1dIZg,1574
|
|
4
|
-
jinns/data/_CubicMeshPDENonStatio.py,sha256=
|
|
5
|
-
jinns/data/_CubicMeshPDEStatio.py,sha256=
|
|
4
|
+
jinns/data/_CubicMeshPDENonStatio.py,sha256=4f21SgeNQsGJTz8-uehduU0X9TibRcs28Iq49Kv4nQQ,22250
|
|
5
|
+
jinns/data/_CubicMeshPDEStatio.py,sha256=DVrP4qHVAJMu915EYH2PKyzwoG0nIMixAEKR2fz6C58,22525
|
|
6
6
|
jinns/data/_DataGeneratorODE.py,sha256=5RzUbQFEsooAZsocDw4wRgA_w5lJmDMuY4M6u79K-1c,7260
|
|
7
7
|
jinns/data/_DataGeneratorObservations.py,sha256=jknepLsJatSJHFq5lLMD-fFHkPGj5q286LEjE-vH24k,7738
|
|
8
8
|
jinns/data/_DataGeneratorParameter.py,sha256=IedX3jcOj7ZDW_18IAcRR75KVzQzo85z9SICIKDBJl4,8539
|
|
@@ -11,14 +11,16 @@ jinns/data/_utils.py,sha256=XxaLIg_HIgcB7ACBIhTpHbCT1HXKcDaY1NABncAYX1c,5223
|
|
|
11
11
|
jinns/experimental/__init__.py,sha256=DT9e57zbjfzPeRnXemGUqnGd--MhV77FspChT0z4YrE,410
|
|
12
12
|
jinns/experimental/_diffrax_solver.py,sha256=upMr3kTTNrxEiSUO_oLvCXcjS9lPxSjvbB81h3qlhaU,6813
|
|
13
13
|
jinns/loss/_DynamicLoss.py,sha256=4mb7OCP-cGZ_mG2MQ-AniddDcuBT78p4bQI7rZpwte4,22722
|
|
14
|
-
jinns/loss/_DynamicLossAbstract.py,sha256=
|
|
15
|
-
jinns/loss/_LossODE.py,sha256=
|
|
16
|
-
jinns/loss/_LossPDE.py,sha256=
|
|
17
|
-
jinns/loss/__init__.py,sha256=
|
|
18
|
-
jinns/loss/_abstract_loss.py,sha256=
|
|
14
|
+
jinns/loss/_DynamicLossAbstract.py,sha256=QhHRgvtcT-ifHlOxTyXbjDtHk9UfPN2Si8s3v9nEQm4,12672
|
|
15
|
+
jinns/loss/_LossODE.py,sha256=iVYDojaI6Co7S5CrU67_niopD4Bk7UBTuLzDiTHoWMc,16996
|
|
16
|
+
jinns/loss/_LossPDE.py,sha256=VT56oQ_33fLq46lIch0slNsxu4d97eQBOgRAPeFESts,36401
|
|
17
|
+
jinns/loss/__init__.py,sha256=z5xYgBipNFf66__5BqQc6R_8r4F6A3TXL60YjsM8Osk,1287
|
|
18
|
+
jinns/loss/_abstract_loss.py,sha256=DMxn0SQe9PW-pq3p5Oqvb0YK3_ulLDOnoIXzK219GV4,4576
|
|
19
19
|
jinns/loss/_boundary_conditions.py,sha256=9HGw1cGLfmEilP4V4B2T0zl0YP1kNtrtXVLQNiBmWgc,12464
|
|
20
|
-
jinns/loss/
|
|
21
|
-
jinns/loss/
|
|
20
|
+
jinns/loss/_loss_components.py,sha256=MMzaGlaRqESPjRzT0j0WU9HAqWQSbIXpGAqM1xQCZHw,1106
|
|
21
|
+
jinns/loss/_loss_utils.py,sha256=eJ4JcBm396LHx7Tti88ZQrLcKqVL1oSfFGT23VNkytQ,11949
|
|
22
|
+
jinns/loss/_loss_weight_updates.py,sha256=9Bwouh7shLyc_wrdzN6CYL0ZuQH81uEs-L6wCeiYFx8,6817
|
|
23
|
+
jinns/loss/_loss_weights.py,sha256=kII5WddORgeommFTudT3CSvhICpo6nSe47LclUgu_78,2429
|
|
22
24
|
jinns/loss/_operators.py,sha256=Ds5yRH7hu-jaGRp7PYbt821BgYuEvgWHufWhYgdMjw0,22909
|
|
23
25
|
jinns/nn/__init__.py,sha256=gwE48oqB_FsSIE-hUvCLz0jPaqX350LBxzH6ueFWYk4,456
|
|
24
26
|
jinns/nn/_abstract_pinn.py,sha256=JUFjlV_nyheZw-max_tAUgFh6SspIbD5we_4bn70V6k,671
|
|
@@ -32,22 +34,22 @@ jinns/nn/_spinn_mlp.py,sha256=uCL454sF0Tfj7KT-fdXPnvKJYRQOuq60N0r2b2VAB8Q,7606
|
|
|
32
34
|
jinns/nn/_utils.py,sha256=9UXz73iHKHVQYPBPIEitrHYJzJ14dspRwPfLA8avx0c,1120
|
|
33
35
|
jinns/parameters/__init__.py,sha256=O0n7y6R1LRmFzzugCxMFCMS2pgsuWSh-XHjfFViN_eg,265
|
|
34
36
|
jinns/parameters/_derivative_keys.py,sha256=YlLDX49PfYhr2Tj--t3praiD8JOUTZU6PTmjbNZsbMc,19173
|
|
35
|
-
jinns/parameters/_params.py,sha256=
|
|
37
|
+
jinns/parameters/_params.py,sha256=nv0WScbgUdmuC0bSF15VbnKypJ58pl6wynZAcYfuF6M,3081
|
|
36
38
|
jinns/plot/__init__.py,sha256=KPHX0Um4FbciZO1yD8kjZbkaT8tT964Y6SE2xCQ4eDU,135
|
|
37
39
|
jinns/plot/_plot.py,sha256=-A5auNeElaz2_8UzVQJQE4143ZFg0zgMjStU7kwttEY,11565
|
|
38
40
|
jinns/solver/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
39
41
|
jinns/solver/_rar.py,sha256=vSVTnCGCusI1vTZCvIkP2_G8we44G_42yZHx2sOK9DE,10291
|
|
40
|
-
jinns/solver/_solve.py,sha256=
|
|
42
|
+
jinns/solver/_solve.py,sha256=IsrmG2m48KkkYgvXYomlSbZ3hd1FySCj3rwlkovs-lI,28616
|
|
41
43
|
jinns/solver/_utils.py,sha256=sM2UbVzYyjw24l4QSIR3IlynJTPGD_S08r8v0lXMxA8,5876
|
|
42
44
|
jinns/utils/__init__.py,sha256=OEYWLCw8pKE7xoQREbd6SHvCjuw2QZHuVA6YwDcsBE8,53
|
|
43
|
-
jinns/utils/_containers.py,sha256=
|
|
44
|
-
jinns/utils/_types.py,sha256=
|
|
45
|
+
jinns/utils/_containers.py,sha256=YShcrPKfj5_I9mn3NMAS4Ea9MhhyL7fjv0e3MRbITHg,1837
|
|
46
|
+
jinns/utils/_types.py,sha256=jl_91HtcrtE6UHbdTrRI8iUmr2kBUL0oP0UNIKhAXYw,1170
|
|
45
47
|
jinns/utils/_utils.py,sha256=M7NXX9ok-BkH5o_xo74PB1_Cc8XiDipSl51rq82dTH4,2821
|
|
46
48
|
jinns/validation/__init__.py,sha256=FTyUO-v1b8Tv-FDSQsntrH7zl9E0ENexqKMT_dFRkYo,124
|
|
47
49
|
jinns/validation/_validation.py,sha256=8p6sMKiBAvA6JNm65hjkMj0997LJ0BkyCREEh0AnPVE,4803
|
|
48
|
-
jinns-1.
|
|
49
|
-
jinns-1.
|
|
50
|
-
jinns-1.
|
|
51
|
-
jinns-1.
|
|
52
|
-
jinns-1.
|
|
53
|
-
jinns-1.
|
|
50
|
+
jinns-1.5.1.dist-info/licenses/AUTHORS,sha256=7NwCj9nU-HNG1asvy4qhQ2w7oZHrn-Lk5_wK_Ve7a3M,80
|
|
51
|
+
jinns-1.5.1.dist-info/licenses/LICENSE,sha256=BIAkGtXB59Q_BG8f6_OqtQ1BHPv60ggE9mpXJYz2dRM,11337
|
|
52
|
+
jinns-1.5.1.dist-info/METADATA,sha256=K7Aii5ivFcczIwLlQtCPqzMJEfW86D1yW1q7qMtvWPE,5314
|
|
53
|
+
jinns-1.5.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
54
|
+
jinns-1.5.1.dist-info/top_level.txt,sha256=RXbkr2hzy8WBE8aiRyrJYFqn3JeMJIhMdybLjjLTB9c,6
|
|
55
|
+
jinns-1.5.1.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|