jinns 1.6.0__py3-none-any.whl → 1.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
jinns/solver/_utils.py CHANGED
@@ -1,7 +1,510 @@
+"""
+Common functions for _solve.py and _solve_alternate.py
+"""
+
+from __future__ import (
+    annotations,
+)  # https://docs.python.org/3/library/typing.html#constant
+
+from typing import TYPE_CHECKING, Callable
+from functools import partial
+import jax
+from jax import jit
+import jax.numpy as jnp
+import equinox as eqx
+from jaxtyping import PyTree, Float, Array, PRNGKeyArray
+import optax
+
+from jinns.data._utils import append_param_batch, append_obs_batch
+from jinns.utils._utils import _check_nan_in_pytree
 from jinns.data._DataGeneratorODE import DataGeneratorODE
 from jinns.data._CubicMeshPDEStatio import CubicMeshPDEStatio
 from jinns.data._CubicMeshPDENonStatio import CubicMeshPDENonStatio
 from jinns.data._DataGeneratorParameter import DataGeneratorParameter
+from jinns.parameters._params import Params
+from jinns.utils._containers import (
+    LossContainer,
+    StoredObjectContainer,
+)
+
+if TYPE_CHECKING:
+    from jinns.utils._types import AnyBatch, SolveCarry, SolveAlternateCarry
+    from jinns.loss._abstract_loss import AbstractLoss
+    from jinns.data._DataGeneratorObservations import DataGeneratorObservations
+    from jinns.data._AbstractDataGenerator import AbstractDataGenerator
+
+
+def _init_stored_weights_terms(loss, n_iter):
+    return eqx.tree_at(
+        lambda pt: jax.tree.leaves(
+            pt, is_leaf=lambda x: x is not None and eqx.is_inexact_array(x)
+        ),
+        loss.loss_weights,
+        tuple(
+            jnp.zeros((n_iter))
+            for n in range(
+                len(
+                    jax.tree.leaves(
+                        loss.loss_weights,
+                        is_leaf=lambda x: x is not None and eqx.is_inexact_array(x),
+                    )
+                )
+            )
+        ),
+    )
+
+
+def _init_stored_params(tracked_params, params, n_iter):
+    return jax.tree_util.tree_map(
+        lambda tracked_param, param: (
+            jnp.zeros((n_iter,) + jnp.asarray(param).shape)
+            if tracked_param is not None
+            else None
+        ),
+        tracked_params,
+        params,
+        is_leaf=lambda x: x is None,  # None values in tracked_params will not
+        # be traversed. Thus the user can provide something like
+        # ```
+        # tracked_params = jinns.parameters.Params(
+        #     nn_params=None,
+        #     eq_params={"nu": True})
+        # ```
+        # even when init_params.nn_params is a complex data structure.
+    )
+
+
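As a minimal illustration of the comment above (the shapes and the `init_params` value below are hypothetical, chosen only for the example), tracking a single equation parameter over `n_iter` iterations yields a storage pytree with one zero-filled array per tracked leaf and `None` elsewhere:

```python
import jax.numpy as jnp
import jinns

# Hypothetical parameters: no tracked network params, one scalar
# equation parameter "nu".
init_params = jinns.parameters.Params(
    nn_params=None, eq_params={"nu": jnp.array(0.5)}
)
tracked_params = jinns.parameters.Params(nn_params=None, eq_params={"nu": True})

stored = _init_stored_params(tracked_params, init_params, n_iter=100)
# stored.eq_params["nu"] is jnp.zeros((100,)); stored.nn_params stays None
```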
+@partial(jit, static_argnames=["optimizer", "params_mask", "with_loss_weight_update"])
+def _loss_evaluate_and_gradient_step(
+    i,
+    batch: AnyBatch,
+    loss: AbstractLoss,
+    params: Params[Array],
+    last_non_nan_params: Params[Array],
+    state: optax.OptState,
+    optimizer: optax.GradientTransformation,
+    loss_container: LossContainer,
+    key: PRNGKeyArray,
+    params_mask: Params[bool] | None = None,
+    opt_state_field_for_acceleration: str | None = None,
+    with_loss_weight_update: bool = True,
+):
+    """
+    The crux of this approach is partitioning and recombining the parameters
+    and the optimization state according to `params_mask`.
+
+    params_mask:
+        A jinns.parameters.Params object with booleans as leaves, specifying
+        the parameters over which optimization is enabled. This usually yields
+        significant computational gains. Internally, it is used as the
+        filter_spec of an eqx.partition call on the parameters. Note that this
+        differs from (and complements) DerivativeKeys, as the latter allows
+        for more granularity by freezing some gradients with respect to
+        different loss terms, but does not subset the optimized parameters
+        globally.
+
+    NOTE: in this function body, we change the naming convention for concision:
+        * `state` refers to the general optimizer state
+        * `opt_state` refers to the part of the optimizer state actually
+          involved in the parameter update, as defined by `params_mask`.
+        * `non_opt_state` refers to the optimizer state for the non-optimized
+          params.
+    """
+
+    (
+        opt_params,
+        opt_params_accel,
+        non_opt_params,
+        opt_state,
+        non_opt_state,
+    ) = _get_masked_optimization_stuff(
+        params, state, opt_state_field_for_acceleration, params_mask
+    )
+
+    # The following part is the equivalent of
+    # > train_loss_value, grads = jax.value_and_grad(total_loss.evaluate)(params, ...)
+    # but it is decomposed into individual loss terms so that we can use them
+    # if needed for updating loss weights.
+    # Since the total loss is a weighted sum of individual loss terms, so
+    # is its total gradient.
+
+    # 1. Compute individual losses and individual gradients
+    loss_terms, grad_terms = loss.evaluate_by_terms(
+        opt_params_accel
+        if opt_state_field_for_acceleration is not None
+        else opt_params,
+        batch,
+        non_opt_params=non_opt_params,
+    )
+
+    if loss.update_weight_method is not None and with_loss_weight_update:
+        key, subkey = jax.random.split(key)  # type: ignore because key can
+        # still be None currently
+        # avoid computations of tree_at if no updates
+        loss = loss.update_weights(
+            i, loss_terms, loss_container.stored_loss_terms, grad_terms, subkey
+        )
+
+    # 2. total grad
+    grads = loss.ponderate_and_sum_gradient(grad_terms)
+
+    # total loss
+    train_loss_value = loss.ponderate_and_sum_loss(loss_terms)
+
+    opt_grads, _ = grads.partition(
+        params_mask
+    )  # because the update cannot be made otherwise
+
+    # Here, we only apply the gradient step of the Optax optimizer to the
+    # parameters specified by params_mask: no dummy state filled with zero
+    # entries, all other entries of the pytrees are None thanks to params_mask.
+    opt_params, opt_state = _gradient_step(
+        opt_grads,
+        optimizer,
+        opt_params,  # NOTE that we never pass the accelerated
+        # params here; that would be an incorrect procedure
+        opt_state,
+    )
+
+    params, state = _get_unmasked_optimization_stuff(
+        opt_params,
+        non_opt_params,
+        state,
+        opt_state,
+        non_opt_state,
+        params_mask,
+    )
+
+    # check if any of the parameters is NaN
+    last_non_nan_params = jax.lax.cond(
+        _check_nan_in_pytree(params),
+        lambda _: last_non_nan_params,
+        lambda _: params,
+        None,
+    )
+    return train_loss_value, params, last_non_nan_params, state, loss, loss_terms
+
+
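For readers unfamiliar with the masking mechanism described in the docstring above, here is a self-contained sketch of how `eqx.partition` and `eqx.combine` split and recombine a pytree according to a boolean mask (a toy dict stands in for a `Params` object; none of the names below come from jinns):

```python
import equinox as eqx
import jax.numpy as jnp

# Toy pytree standing in for a Params object.
params = {"nn_params": jnp.ones((3,)), "eq_params": {"nu": jnp.array(0.5)}}
# Boolean mask with the same structure: optimize only eq_params["nu"].
mask = {"nn_params": False, "eq_params": {"nu": True}}

opt_params, non_opt_params = eqx.partition(params, mask)
# opt_params     == {"nn_params": None, "eq_params": {"nu": Array(0.5)}}
# non_opt_params == {"nn_params": Array([1., 1., 1.]), "eq_params": {"nu": None}}

# Recombining the two halves gives back the original pytree.
recombined = eqx.combine(opt_params, non_opt_params)
```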
+@partial(
+    jit,
+    static_argnames=["optimizer"],
+)
+def _gradient_step(
+    grads: Params[Array],
+    optimizer: optax.GradientTransformation,
+    params: Params[Array],
+    state: optax.OptState,
+) -> tuple[
+    Params[Array],
+    optax.OptState,
+]:
+    """
+    A plain gradient step, compatible with the masked update mechanism.
+    Note that `optimizer` cannot be jitted, hence it is a static argument.
+    """
+
+    updates, state = optimizer.update(
+        grads,  # type: ignore
+        state,
+        params,  # type: ignore
+    )  # Also see optimizer.init for explanation of type ignore
+    params = optax.apply_updates(params, updates)  # type: ignore
+
+    return (
+        params,
+        state,
+    )
+
+
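A short usage sketch of the helper above (the toy pytree and the choice of optimizer are illustrative assumptions, not taken from the package):

```python
import jax.numpy as jnp
import optax

# Toy parameters and gradients; any optax optimizer works.
optimizer = optax.adam(1e-3)
params = {"w": jnp.zeros((2,))}
state = optimizer.init(params)
grads = {"w": jnp.ones((2,))}

# One gradient step; optimizer is static, grads/params/state are traced.
params, state = _gradient_step(grads, optimizer, params, state)
```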
+@partial(jit, static_argnames=["params_mask"])
+def _get_masked_optimization_stuff(
+    params, state, state_field_for_acceleration, params_mask
+):
+    """
+    From the parameters `params` and the optimizer state `state`, we use the
+    parameter mask `params_mask` to retrieve the partitioned version of those
+    two objects: `opt_params` for the parameters that are optimized and
+    `non_opt_params` for those that are not. Same for `state`.
+
+    The argument `state_field_for_acceleration` can correspond to a field
+    inside the `state` module. If it is not None, an `opt_params_accel` object
+    is created that differs from `opt_params`. See
+    `opt_state_field_for_acceleration` in the `jinns.solve` docstring for more
+    details.
+
+    The opposite of `eqx.partition`, i.e. `eqx.combine`, is performed in the
+    loss `evaluate_by_terms()` method for the computations and in
+    `_get_unmasked_optimization_stuff` to reconstruct the objects after the
+    gradient step.
+    """
+    opt_params, non_opt_params = params.partition(params_mask)
+    opt_state = jax.tree.map(
+        lambda l: l.partition(params_mask)[0] if isinstance(l, Params) else l,
+        state,
+        is_leaf=lambda x: isinstance(x, Params),
+    )
+    non_opt_state = jax.tree.map(
+        lambda l: l.partition(params_mask)[1] if isinstance(l, Params) else l,
+        state,
+        is_leaf=lambda x: isinstance(x, Params),
+    )
+
+    # NOTE to enable optimization procedures with acceleration
+    if state_field_for_acceleration is not None:
+        opt_params_accel = getattr(opt_state, state_field_for_acceleration)
+    else:
+        opt_params_accel = opt_params
+
+    return (
+        opt_params,
+        opt_params_accel,
+        non_opt_params,
+        opt_state,
+        non_opt_state,
+    )
+
+
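To see what the `jax.tree.map(..., is_leaf=lambda x: isinstance(x, Params))` calls above operate on, here is a conceptual sketch with plain dicts standing in for `Params` (an assumed structure, for illustration only): an Adam state carries moment estimates shaped like the parameters, so partitioning the state amounts to partitioning those parameter-shaped leaves while leaving everything else, such as the step count, untouched:

```python
import equinox as eqx
import jax
import jax.numpy as jnp
import optax

params = {"w": jnp.zeros((2,)), "b": jnp.zeros(())}
mask = {"w": True, "b": False}  # optimize only "w"

state = optax.adam(1e-3).init(params)  # contains mu/nu dicts shaped like params

# Keep only the optimized part of every params-shaped leaf of the state.
opt_state = jax.tree.map(
    lambda leaf: eqx.partition(leaf, mask)[0] if isinstance(leaf, dict) else leaf,
    state,
    is_leaf=lambda x: isinstance(x, dict),
)
```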
+@partial(jit, static_argnames=["params_mask"])
+def _get_unmasked_optimization_stuff(
+    opt_params, non_opt_params, state, opt_state, non_opt_state, params_mask
+):
+    """
+    Reverse operation of `_get_masked_optimization_stuff`.
+    """
+    # NOTE the combine which closes the partitioned chunk
+    if params_mask is not None:
+        params = eqx.combine(opt_params, non_opt_params)
+        state = jax.tree.map(
+            lambda a, b, c: eqx.combine(b, c) if isinstance(a, Params) else b,
+            # NOTE else b in order to take all non-Params content from
+            # opt_state that may have been updated too
+            state,
+            opt_state,
+            non_opt_state,
+            is_leaf=lambda x: isinstance(x, Params),
+        )
+    else:
+        params = opt_params
+        state = opt_state
+
+    return params, state
+
+
+@partial(jit, static_argnames=["prefix"])
+def _print_fn(i: int, loss_val: Float, print_loss_every: int, prefix: str = ""):
+    # note that if the following is not jitted in the main for loop, it is
+    # super slow
+    _ = jax.lax.cond(
+        i % print_loss_every == 0,
+        lambda _: jax.debug.print(
+            prefix + "Iteration {i}: loss value = {loss_val}",
+            i=i,
+            loss_val=loss_val,
+        ),
+        lambda _: None,
+        (None,),
+    )
+
+
+@jit
+def _store_loss_and_params(
+    i: int,
+    params: Params[Array],
+    stored_params: Params[Array | None],
+    loss_container: LossContainer,
+    train_loss_val: float,
+    loss_terms: PyTree[Array],
+    weight_terms: PyTree[Array],
+    tracked_params: Params,
+) -> tuple[StoredObjectContainer, LossContainer]:
+    stored_params = jax.tree_util.tree_map(
+        lambda stored_value, param, tracked_param: (
+            None
+            if stored_value is None
+            else jax.lax.cond(
+                tracked_param,
+                lambda ope: ope[0].at[i].set(ope[1]),
+                lambda ope: ope[0],
+                (stored_value, param),
+            )
+        ),
+        stored_params,
+        params,
+        tracked_params,
+        is_leaf=lambda x: x is None,
+    )
+    stored_loss_terms = jax.tree_util.tree_map(
+        lambda stored_term, loss_term: stored_term.at[i].set(loss_term),
+        loss_container.stored_loss_terms,
+        loss_terms,
+    )
+
+    if loss_container.stored_weights_terms is not None:
+        stored_weights_terms = jax.tree_util.tree_map(
+            lambda stored_term, weight_term: stored_term.at[i].set(weight_term),
+            jax.tree.leaves(
+                loss_container.stored_weights_terms,
+                is_leaf=lambda x: x is not None and eqx.is_inexact_array(x),
+            ),
+            jax.tree.leaves(
+                weight_terms,
+                is_leaf=lambda x: x is not None and eqx.is_inexact_array(x),
+            ),
+        )
+        stored_weights_terms = eqx.tree_at(
+            lambda pt: jax.tree.leaves(
+                pt, is_leaf=lambda x: x is not None and eqx.is_inexact_array(x)
+            ),
+            loss_container.stored_weights_terms,
+            stored_weights_terms,
+        )
+    else:
+        stored_weights_terms = None
+
+    train_loss_values = loss_container.train_loss_values.at[i].set(train_loss_val)
+    loss_container = LossContainer(
+        stored_loss_terms, stored_weights_terms, train_loss_values
+    )
+    stored_objects = StoredObjectContainer(stored_params)
+    return stored_objects, loss_container
+
+
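The storage above relies on JAX's functional in-place updates; a one-line reminder of the pattern (generic JAX, not jinns-specific):

```python
import jax.numpy as jnp

stored = jnp.zeros((5,))
stored = stored.at[2].set(1.5)  # returns a new array with stored[2] == 1.5
```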
+def _get_break_fun(
+    n_iter: int,
+    verbose: bool,
+    conditions_str: tuple[str, ...] = (
+        "bool_max_iter",
+        "bool_nan_in_params",
+        "bool_early_stopping",
+    ),
+) -> Callable[[SolveCarry | SolveAlternateCarry], bool]:
+    """
+    Wrapper to get the break_fun with the appropriate `n_iter`.
+    The verbose argument controls printing (or not) when exiting the
+    optimisation loop. This is convenient if jinns.solve is itself called in a
+    loop and the user wants to avoid std output.
+    """
+
+    @jit
+    def break_fun(carry: tuple):
+        """
+        Break from the main optimization loop when one of the following
+        conditions is met: maximum number of iterations reached, NaN
+        appearing in the parameters, or early stopping criterion triggered.
+        """
+
+        def stop_while_loop(msg):
+            """
+            Note that the message is wrapped in the jax.lax.cond because a
+            string is not a valid JAX type that can be fed to the operands
+            """
+            if verbose:
+                jax.debug.print(f"\nStopping main optimization loop, cause: {msg}")
+            return False
+
+        def continue_while_loop(_):
+            return True
+
+        i = carry[0]
+        optimization = carry[2]
+        optimization_extra = carry[3]
+
+        conditions_bool = ()
+        if "bool_max_iter" in conditions_str:
+            # Condition 1
+            bool_max_iter = jax.lax.cond(
+                i >= n_iter,
+                lambda _: stop_while_loop("max iteration is reached"),
+                continue_while_loop,
+                None,
+            )
+            conditions_bool += (bool_max_iter,)
+        if "bool_nan_in_params" in conditions_str:
+            # Condition 2
+            bool_nan_in_params = jax.lax.cond(
+                _check_nan_in_pytree(optimization.params),
+                lambda _: stop_while_loop(
+                    "NaN values in parameters (returning last non NaN values)"
+                ),
+                continue_while_loop,
+                None,
+            )
+            conditions_bool += (bool_nan_in_params,)
+        if "bool_early_stopping" in conditions_str:
+            # Condition 3
+            bool_early_stopping = jax.lax.cond(
+                optimization_extra.early_stopping,
+                lambda _: stop_while_loop("early stopping"),
+                continue_while_loop,
+                None,
+            )
+            conditions_bool += (bool_early_stopping,)
+
+        # stop when one of the conditions to continue is False
+        return jax.tree_util.tree_reduce(
+            lambda x, y: jnp.logical_and(jnp.array(x), jnp.array(y)),
+            conditions_bool,
+        )
+
+    return break_fun
+
+
+def _build_get_batch(
+    obs_batch_sharding: jax.sharding.Sharding | None,
+) -> Callable[
+    [
+        AbstractDataGenerator,
+        DataGeneratorParameter | None,
+        DataGeneratorObservations | None,
+    ],
+    tuple[
+        AnyBatch,
+        AbstractDataGenerator,
+        DataGeneratorParameter | None,
+        DataGeneratorObservations | None,
+    ],
+]:
+    """
+    Return the get_batch function that will be used: either the jittable one
+    or the non-jittable one with sharding using jax.device_put()
+    """
+
+    def get_batch_sharding(data, param_data, obs_data):
+        """
+        This function is used at each loop iteration but it cannot be jitted
+        because of device_put
+        """
+        data, batch = data.get_batch()
+        if param_data is not None:
+            param_data, param_batch = param_data.get_batch()
+            batch = append_param_batch(batch, param_batch)
+        if obs_data is not None:
+            # This is the part that motivated the transition from scan to a
+            # for loop. Indeed we need to transfer obs_batch from CPU to GPU
+            # when we have huge observations that cannot fit on the GPU. Such
+            # a transfer is not meant to be jitted, i.e. placed in a scan loop
+            obs_data, obs_batch = obs_data.get_batch()
+            obs_batch = jax.device_put(obs_batch, obs_batch_sharding)
+            batch = append_obs_batch(batch, obs_batch)
+        return batch, data, param_data, obs_data
+
+    @jit
+    def get_batch(data, param_data, obs_data):
+        """
+        Original get_batch with no sharding
+        """
+        data, batch = data.get_batch()
+        if param_data is not None:
+            param_data, param_batch = param_data.get_batch()
+            batch = append_param_batch(batch, param_batch)
+        if obs_data is not None:
+            obs_data, obs_batch = obs_data.get_batch()
+            batch = append_obs_batch(batch, obs_batch)
+        return batch, data, param_data, obs_data
+
+    if obs_batch_sharding is not None:
+        return get_batch_sharding
+    return get_batch
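A minimal sketch of how the sharded variant might be requested (the device choice and the call below are illustrative assumptions, not the package's documented API):

```python
import jax

# Place observation batches on the default device at each iteration; passing
# None instead selects the fully jitted get_batch.
obs_batch_sharding = jax.sharding.SingleDeviceSharding(jax.devices()[0])
get_batch = _build_get_batch(obs_batch_sharding)
# get_batch(data, param_data, obs_data) -> (batch, data, param_data, obs_data)
```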
 
 
 def _check_batch_size(other_data, main_data, attr_name):
@@ -1,6 +1,8 @@
 from typing import Any
 import equinox as eqx
 
+from jinns.utils._ItemizableModule import ItemizableModule
+
 
 class DictToModuleMeta(type):
     """
@@ -42,7 +44,7 @@ class DictToModuleMeta(type):
         if self._class is None and class_name is not None:
             self._class = type(
                 class_name,
-                (eqx.Module,),
+                (ItemizableModule,),
                 {"__annotations__": {k: type(v) for k, v in d.items()}},
             )
         try:
@@ -37,13 +37,17 @@ class OptimizationContainer(eqx.Module):
     params: Params
     last_non_nan_params: Params
     opt_state: OptState
+    # params_mask: Params = eqx.field(static=True) # to make params_mask
+    # hashable JAX type. See _gradient_step docstring
 
 
 class OptimizationExtraContainer(eqx.Module):
-    curr_seq: int
-    best_iter_id: int  # the best iteration number (that which achieves best_val_params and best_val_params)
-    best_val_criterion: float  # the best validation criterion at early stopping
-    best_val_params: Params  # the best parameter values at early stopping
+    curr_seq: int | None
+    best_iter_id: (
+        int | None
+    )  # the best iteration number (the one that achieves best_val_criterion and best_val_params)
+    best_val_criterion: float | None  # the best validation criterion at early stopping
+    best_val_params: Params | None  # the best parameter values at early stopping
     early_stopping: Bool = False
 
 
jinns/utils/_types.py CHANGED
@@ -3,7 +3,7 @@ from __future__ import (
 )  # https://docs.python.org/3/library/typing.html#constant
 
 from typing import TypeAlias, TYPE_CHECKING, Callable, TypeVar
-from jaxtyping import Float, Array
+from jaxtyping import Float, Array, PRNGKeyArray
 
 from jinns.data._Batchs import ODEBatch, PDEStatioBatch, PDENonStatioBatch, ObsBatchDict
 from jinns.loss._loss_weights import (
@@ -11,6 +11,11 @@ from jinns.loss._loss_weights import (
     LossWeightsPDEStatio,
     LossWeightsPDENonStatio,
 )
+from jinns.parameters._derivative_keys import (
+    DerivativeKeysODE,
+    DerivativeKeysPDENonStatio,
+    DerivativeKeysPDEStatio,
+)
 from jinns.loss._loss_components import (
     ODEComponents,
     PDEStatioComponents,
@@ -19,6 +24,9 @@ from jinns.loss._loss_components import (
 
 AnyBatch: TypeAlias = ODEBatch | PDENonStatioBatch | PDEStatioBatch | ObsBatchDict
 
+AnyDerivativeKeys: TypeAlias = (
+    DerivativeKeysODE | DerivativeKeysPDEStatio | DerivativeKeysPDENonStatio
+)
 AnyLossWeights: TypeAlias = (
     LossWeightsODE | LossWeightsPDEStatio | LossWeightsPDENonStatio
 )
@@ -30,6 +38,15 @@ AnyLossComponents: TypeAlias = (
 )
 
 if TYPE_CHECKING:
+    from jinns.utils._containers import (
+        DataGeneratorContainer,
+        OptimizationContainer,
+        OptimizationExtraContainer,
+        LossContainer,
+        StoredObjectContainer,
+    )
+    from jinns.validation._validation import AbstractValidationModule
+    from jinns.loss._abstract_loss import AbstractLoss
     from jinns.loss._LossODE import LossODE
     from jinns.loss._LossPDE import LossPDEStatio, LossPDENonStatio
 
@@ -39,3 +56,27 @@ if TYPE_CHECKING:
     ]
 
     AnyLoss: TypeAlias = LossODE | LossPDEStatio | LossPDENonStatio
+
+    SolveCarry: TypeAlias = tuple[
+        int,
+        AbstractLoss,
+        OptimizationContainer,
+        OptimizationExtraContainer,
+        DataGeneratorContainer,
+        AbstractValidationModule | None,
+        LossContainer,
+        StoredObjectContainer,
+        Float[Array, " n_iter"] | None,
+        PRNGKeyArray | None,
+    ]
+
+    SolveAlternateCarry: TypeAlias = tuple[
+        int,
+        AbstractLoss,
+        OptimizationContainer,
+        OptimizationExtraContainer,
+        DataGeneratorContainer,
+        LossContainer,
+        StoredObjectContainer,
+        PRNGKeyArray | None,
+    ]
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: jinns
-Version: 1.6.0
+Version: 1.7.0
 Summary: Physics Informed Neural Network with JAX
 Author-email: Hugo Gangloff <hugo.gangloff@inrae.fr>, Nicolas Jouvin <nicolas.jouvin@inrae.fr>
 Maintainer-email: Hugo Gangloff <hugo.gangloff@inrae.fr>, Nicolas Jouvin <nicolas.jouvin@inrae.fr>
@@ -8,23 +8,26 @@ License: Apache License 2.0
 Project-URL: Repository, https://gitlab.com/mia_jinns/jinns
 Project-URL: Documentation, https://mia_jinns.gitlab.io/jinns/index.html
 Classifier: License :: OSI Approved :: Apache Software License
-Classifier: Development Status :: 4 - Beta
+Classifier: Development Status :: 5 - Production/Stable
 Classifier: Programming Language :: Python
 Requires-Python: >=3.11
 Description-Content-Type: text/markdown
 License-File: LICENSE
 License-File: AUTHORS
-Requires-Dist: numpy
-Requires-Dist: jax
-Requires-Dist: jaxopt
-Requires-Dist: optax
-Requires-Dist: equinox>0.11.3
-Requires-Dist: jax-tqdm
-Requires-Dist: diffrax
+Requires-Dist: numpy>=2.0.0
+Requires-Dist: jax>=0.8.1
+Requires-Dist: optax>=0.2.6
+Requires-Dist: equinox>=0.13.2
 Requires-Dist: matplotlib
+Requires-Dist: jaxtyping
 Provides-Extra: notebook
 Requires-Dist: jupyter; extra == "notebook"
 Requires-Dist: seaborn; extra == "notebook"
+Requires-Dist: pandas; extra == "notebook"
+Requires-Dist: pytest; extra == "notebook"
+Requires-Dist: pre-commit; extra == "notebook"
+Requires-Dist: pyright; extra == "notebook"
+Requires-Dist: diffrax; extra == "notebook"
 Dynamic: license-file
 
 jinns
@@ -32,12 +35,11 @@ jinns
 
 ![status](https://gitlab.com/mia_jinns/jinns/badges/main/pipeline.svg) ![coverage](https://gitlab.com/mia_jinns/jinns/badges/main/coverage.svg)
 
-Physics Informed Neural Networks with JAX. **jinns** is developed to estimate solutions of ODE and PDE problems using neural networks, with a strong focus on
+Physics Informed Neural Networks with JAX. **jinns** is a Python package for physics-informed neural networks (PINNs) in the [JAX](https://jax.readthedocs.io/en/latest/) ecosystem. It provides an intuitive and flexible interface for
 
-1. inverse problems: find equation parameters given noisy/indirect observations
-2. meta-modeling: solve for a parametric family of differential equations
-
-It can also be used for forward problems and hybrid-modeling.
+* forward problem: learning a PDE solution.
+* inverse problem: learning the parameters of a PDE. **New in jinns v1.7.0:** `jinns.solve_alternate()` for fine-grained and efficient inverse problems.
+* meta-modeling: learning a family of PDEs indexed by their parameters.
 
 **jinns** specific points: