jinns-0.3.2-py3-none-any.whl → jinns-0.3.3-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
jinns/solver/_solve.py CHANGED
@@ -1,8 +1,8 @@
 from functools import partial
-from jaxopt import OptaxSolver
+from jaxopt import OptaxSolver, LBFGS
+from optax import GradientTransformation
 from jax_tqdm import scan_tqdm
 import jax
-from jax import jit
 import jax.numpy as jnp
 from jinns.solver._seq2seq import (
     _initialize_seq2seq,
@@ -24,7 +24,8 @@ def solve(
     init_params,
     data,
     loss,
-    optax_solver,
+    optimizer,
+    print_loss_every=1000,
     opt_state=None,
     seq2seq=None,
     tracked_params_key_list=None,
@@ -54,11 +55,19 @@ def solve(
         A loss object (e.g. a LossODE, SystemLossODE, LossPDEStatio [...]
         object). It must be jittable (e.g. implements via a pytree
         registration)
-    optax_solver
-        An optax solver (e.g. adam with a given step-size)
+    optimizer
+        Can be an `optax` optimizer (e.g. `optax.adam`); in that case it is
+        wrapped in `jaxopt.OptaxSolver`.
+        Can be a `jaxopt` optimizer (e.g. `jaxopt.BFGS`) which supports the
+        methods `init_state` and `update`.
+        Can be a string (currently only `lbfgs`); in that case a `jaxopt`
+        optimizer is created with default parameters.
+    print_loss_every
+        Integer. Default 1000. The rate at which we print the loss value in
+        the gradient step loop.
     opt_state
         Default None. Provide an optional initial state to the
-        optimizer
+        optimizer. Not valid for all optimizers.
     seq2seq
         Default None. A dictionary with keys 'times_steps'
         and 'iter_steps' which must have the same length. The first represents
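The new `optimizer` argument therefore accepts three forms. A minimal sketch of the three options follows; the `solve` call itself is left commented out because it needs a full jinns problem, and `n_iter`, `init_params`, `data` and `loss` stand for objects built with the usual jinns constructors (their exact positions in the signature are an assumption here):

```python
import optax
from jaxopt import LBFGS
from jinns.solver._solve import solve

# Form 1: an optax GradientTransformation; solve() wraps it in OptaxSolver.
optimizer = optax.adam(learning_rate=1e-3)

# Form 2: a jaxopt optimizer exposing init_state/update (loss built elsewhere).
# optimizer = LBFGS(fun=loss, has_aux=True, maxiter=n_iter)

# Form 3: a string; currently only "lbfgs", built with default parameters.
# optimizer = "lbfgs"

# solve(..., optimizer=optimizer, print_loss_every=500, ...)
```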
@@ -99,13 +108,16 @@ def solve(
     """
     params = init_params

-    # Wrap the optax solver with jaxopt
-    optax_solver = OptaxSolver(
-        opt=optax_solver,
-        fun=loss,
-        has_aux=True,  # because the objective has aux output
-        maxiter=n_iter,
-    )
+    if isinstance(optimizer, GradientTransformation):
+        optimizer = OptaxSolver(
+            opt=optimizer,
+            fun=loss,
+            has_aux=True,
+            maxiter=n_iter,
+        )
+    elif optimizer == "lbfgs":
+        optimizer = LBFGS(fun=loss, has_aux=True, maxiter=n_iter)
+    # else, we trust that the user has given a valid jaxopt optimizer

     if param_data is not None:
         if (
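The `isinstance` test above is what routes optax optimizers into the `OptaxSolver` branch: optax optimizers are instances of `optax.GradientTransformation`, while a string falls through to the `"lbfgs"` comparison. A quick, runnable check of that assumption:

```python
import optax
from optax import GradientTransformation

# optax optimizers are GradientTransformation instances, so they take the
# OptaxSolver branch in solve(); a plain string does not.
assert isinstance(optax.adam(1e-3), GradientTransformation)
assert not isinstance("lbfgs", GradientTransformation)
```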
@@ -134,7 +146,7 @@ def solve(
     batch = data.get_batch()
     if param_data is not None:
         batch = append_param_batch(batch, param_data.get_batch())
-    opt_state = optax_solver.init_state(params, batch=batch)
+    opt_state = optimizer.init_state(params, batch=batch)

     curr_seq = 0
     if seq2seq is not None:
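This `init_state` call, together with the `update` call in the next hunk, is the uniform jaxopt-style interface that both branches of the dispatch expose, which is what lets the rest of `solve` stay optimizer-agnostic. A self-contained sketch of that interface, using a toy stand-in objective rather than a real jinns loss:

```python
import jax.numpy as jnp
from jaxopt import LBFGS

# Toy (params, batch) -> (value, aux) objective mirroring the shape of a
# jinns loss; has_aux=True matches how solve() constructs its optimizers.
def toy_loss(params, batch):
    value = jnp.sum((params - batch) ** 2)
    return value, {"mse": value}

opt = LBFGS(fun=toy_loss, has_aux=True, maxiter=100)
params = jnp.zeros(3)
batch = jnp.ones(3)

# Same two calls as in the diff: extra keyword args are forwarded to fun.
state = opt.init_state(params, batch=batch)
params, state = opt.update(params=params, state=state, batch=batch)
```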
@@ -183,7 +195,7 @@ def solve(
         batch = carry["data"].get_batch()
         if carry["param_data"] is not None:
             batch = append_param_batch(batch, carry["param_data"].get_batch())
-        carry["params"], carry["opt_state"] = optax_solver.update(
+        carry["params"], carry["opt_state"] = optimizer.update(
            params=carry["params"], state=carry["state"], batch=batch
        )

@@ -197,6 +209,18 @@ def solve(

        total_loss_val, loss_terms = loss(carry["params"], batch)

+        # Print loss during optimization
+        _ = jax.lax.cond(
+            i % print_loss_every == 0,
+            lambda _: jax.debug.print(
+                "Iteration {i}: loss value = {total_loss_val}",
+                i=i,
+                total_loss_val=total_loss_val,
+            ),
+            lambda _: None,
+            (None,),
+        )
+
        # optional seq2seq
        if seq2seq is not None:
            carry = _seq2seq_triggerer(
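A plain Python `if i % print_loss_every == 0: print(...)` would not work at this point: inside the jitted scan loop, `i` and `total_loss_val` are traced values, so the branch has to go through `jax.lax.cond` and the printing through `jax.debug.print`, which survives jit. A self-contained sketch of the same pattern (`body` and the 0.9 decay are toy stand-ins, not jinns code):

```python
import jax
import jax.numpy as jnp

print_loss_every = 5

def body(carry, i):
    loss_val = carry * 0.9  # stand-in for the real loss computation
    # Traced condition: lax.cond selects the branch, debug.print does the I/O.
    jax.lax.cond(
        i % print_loss_every == 0,
        lambda _: jax.debug.print("Iteration {i}: loss value = {v}", i=i, v=loss_val),
        lambda _: None,
        None,
    )
    return loss_val, loss_val

final_loss, _ = jax.lax.scan(body, jnp.float32(1.0), jnp.arange(20))
```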
@@ -249,6 +273,12 @@ def solve(
        jnp.arange(n_iter),
    )

+    jax.debug.print(
+        "Iteration {i}: loss value = {total_loss_val}",
+        i=n_iter,
+        total_loss_val=accu[-1][-1],
+    )
+
    params = res["params"]
    last_non_nan_params = res["last_non_nan_params"]
    opt_state = res["state"]
{jinns-0.3.2.dist-info → jinns-0.3.3.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: jinns
-Version: 0.3.2
+Version: 0.3.3
 Summary: Physics Informed Neural Network with JAX
 Author-email: Hugo Gangloff <hugo.gangloff@inrae.fr>, Nicolas Jouvin <nicolas.jouvin@inrae.fr>
 Maintainer-email: Hugo Gangloff <hugo.gangloff@inrae.fr>, Nicolas Jouvin <nicolas.jouvin@inrae.fr>
{jinns-0.3.2.dist-info → jinns-0.3.3.dist-info}/RECORD RENAMED
@@ -12,11 +12,11 @@ jinns/loss/_operators.py,sha256=nza4gfBn3Ppx7BN9Eben7KXPZ4cDQDyl6pfC8sHnAlI,4380
 jinns/solver/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 jinns/solver/_rar.py,sha256=V9y07F6objmP6rPA305dIJ82h7kwP4AWBAktZ68b-38,13894
 jinns/solver/_seq2seq.py,sha256=XNL9e0fBj85Q86XfGDzq9dzxIkPPMwoJF38C8doNYtM,6032
-jinns/solver/_solve.py,sha256=XaY8I9dYGhFArB5GFWjcoe1D1nQOoXxpsUEPhrqaW0U,8896
+jinns/solver/_solve.py,sha256=pCctVnrkfOptGd1wH2w9jDdFiOMmNqgatqaoYtyE_LM,10071
 jinns/utils/__init__.py,sha256=-jDlwCjyEzWweswKdwLal3OhaUU3FVzK_Ge2S-7KHXs,149
 jinns/utils/_utils.py,sha256=bQm6z_xPKJj9BMCr2tXc44IA8JyGNN-PR5LNRhZ1fD8,20085
-jinns-0.3.2.dist-info/LICENSE,sha256=BIAkGtXB59Q_BG8f6_OqtQ1BHPv60ggE9mpXJYz2dRM,11337
-jinns-0.3.2.dist-info/METADATA,sha256=YAstgf9edypib3O8bdTN7VUG5e811vhD3IaLjfuK7JM,1821
-jinns-0.3.2.dist-info/WHEEL,sha256=yQN5g4mg4AybRjkgi-9yy4iQEFibGQmlz78Pik5Or-A,92
-jinns-0.3.2.dist-info/top_level.txt,sha256=RXbkr2hzy8WBE8aiRyrJYFqn3JeMJIhMdybLjjLTB9c,6
-jinns-0.3.2.dist-info/RECORD,,
+jinns-0.3.3.dist-info/LICENSE,sha256=BIAkGtXB59Q_BG8f6_OqtQ1BHPv60ggE9mpXJYz2dRM,11337
+jinns-0.3.3.dist-info/METADATA,sha256=GL7VeN4nQ5QuFqnKlnvagQiES6AyWsw9eeRVxnbu4vc,1821
+jinns-0.3.3.dist-info/WHEEL,sha256=yQN5g4mg4AybRjkgi-9yy4iQEFibGQmlz78Pik5Or-A,92
+jinns-0.3.3.dist-info/top_level.txt,sha256=RXbkr2hzy8WBE8aiRyrJYFqn3JeMJIhMdybLjjLTB9c,6
+jinns-0.3.3.dist-info/RECORD,,