jinns 0.3.2__py3-none-any.whl → 0.3.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jinns/solver/_solve.py +45 -15
- {jinns-0.3.2.dist-info → jinns-0.3.3.dist-info}/METADATA +1 -1
- {jinns-0.3.2.dist-info → jinns-0.3.3.dist-info}/RECORD +6 -6
- {jinns-0.3.2.dist-info → jinns-0.3.3.dist-info}/LICENSE +0 -0
- {jinns-0.3.2.dist-info → jinns-0.3.3.dist-info}/WHEEL +0 -0
- {jinns-0.3.2.dist-info → jinns-0.3.3.dist-info}/top_level.txt +0 -0
jinns/solver/_solve.py
CHANGED
|
@@ -1,8 +1,8 @@
|
|
|
1
1
|
from functools import partial
|
|
2
|
-
from jaxopt import OptaxSolver
|
|
2
|
+
from jaxopt import OptaxSolver, LBFGS
|
|
3
|
+
from optax import GradientTransformation
|
|
3
4
|
from jax_tqdm import scan_tqdm
|
|
4
5
|
import jax
|
|
5
|
-
from jax import jit
|
|
6
6
|
import jax.numpy as jnp
|
|
7
7
|
from jinns.solver._seq2seq import (
|
|
8
8
|
_initialize_seq2seq,
|
|
@@ -24,7 +24,8 @@ def solve(
|
|
|
24
24
|
init_params,
|
|
25
25
|
data,
|
|
26
26
|
loss,
|
|
27
|
-
|
|
27
|
+
optimizer,
|
|
28
|
+
print_loss_every=1000,
|
|
28
29
|
opt_state=None,
|
|
29
30
|
seq2seq=None,
|
|
30
31
|
tracked_params_key_list=None,
|
|
@@ -54,11 +55,19 @@ def solve(
|
|
|
54
55
|
A loss object (e.g. a LossODE, SystemLossODE, LossPDEStatio [...]
|
|
55
56
|
object). It must be jittable (e.g. implements via a pytree
|
|
56
57
|
registration)
|
|
57
|
-
|
|
58
|
-
|
|
58
|
+
optimizer
|
|
59
|
+
Can be an `optax` optimizer (e.g. `optax.adam`).
|
|
60
|
+
In such case, it is wrapped in the `jaxopt.OptaxSolver` wrapper.
|
|
61
|
+
Can be a `jaxopt` optimizer (e.g. `jaxopt.BFGS`) which supports the
|
|
62
|
+
methods `init_state` and `update`.
|
|
63
|
+
Can be a string (currently only `bfgs`), in such case a `jaxopt`
|
|
64
|
+
optimizer is created with default parameters.
|
|
65
|
+
print_loss_every
|
|
66
|
+
Integer. Default 100. The rate at which we print the loss value in the
|
|
67
|
+
gradient step loop.
|
|
59
68
|
opt_state
|
|
60
69
|
Default None. Provide an optional initial optional state to the
|
|
61
|
-
optimizer
|
|
70
|
+
optimizer. Not valid for all optimizers.
|
|
62
71
|
seq2seq
|
|
63
72
|
Default None. A dictionary with keys 'times_steps'
|
|
64
73
|
and 'iter_steps' which mush have same length. The first represents
|
|
@@ -99,13 +108,16 @@ def solve(
|
|
|
99
108
|
"""
|
|
100
109
|
params = init_params
|
|
101
110
|
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
111
|
+
if isinstance(optimizer, GradientTransformation):
|
|
112
|
+
optimizer = OptaxSolver(
|
|
113
|
+
opt=optimizer,
|
|
114
|
+
fun=loss,
|
|
115
|
+
has_aux=True,
|
|
116
|
+
maxiter=n_iter,
|
|
117
|
+
)
|
|
118
|
+
elif optimizer == "lbfgs":
|
|
119
|
+
optimizer = LBFGS(fun=loss, has_aux=True, maxiter=n_iter)
|
|
120
|
+
# else, we trust that the user has given a valid jaxopt optimizer
|
|
109
121
|
|
|
110
122
|
if param_data is not None:
|
|
111
123
|
if (
|
|
@@ -134,7 +146,7 @@ def solve(
|
|
|
134
146
|
batch = data.get_batch()
|
|
135
147
|
if param_data is not None:
|
|
136
148
|
batch = append_param_batch(batch, param_data.get_batch())
|
|
137
|
-
opt_state =
|
|
149
|
+
opt_state = optimizer.init_state(params, batch=batch)
|
|
138
150
|
|
|
139
151
|
curr_seq = 0
|
|
140
152
|
if seq2seq is not None:
|
|
@@ -183,7 +195,7 @@ def solve(
|
|
|
183
195
|
batch = carry["data"].get_batch()
|
|
184
196
|
if carry["param_data"] is not None:
|
|
185
197
|
batch = append_param_batch(batch, carry["param_data"].get_batch())
|
|
186
|
-
carry["params"], carry["opt_state"] =
|
|
198
|
+
carry["params"], carry["opt_state"] = optimizer.update(
|
|
187
199
|
params=carry["params"], state=carry["state"], batch=batch
|
|
188
200
|
)
|
|
189
201
|
|
|
@@ -197,6 +209,18 @@ def solve(
|
|
|
197
209
|
|
|
198
210
|
total_loss_val, loss_terms = loss(carry["params"], batch)
|
|
199
211
|
|
|
212
|
+
# Print loss during optimization
|
|
213
|
+
_ = jax.lax.cond(
|
|
214
|
+
i % print_loss_every == 0,
|
|
215
|
+
lambda _: jax.debug.print(
|
|
216
|
+
"Iteration {i}: loss value = " "{total_loss_val}",
|
|
217
|
+
i=i,
|
|
218
|
+
total_loss_val=total_loss_val,
|
|
219
|
+
),
|
|
220
|
+
lambda _: None,
|
|
221
|
+
(None,),
|
|
222
|
+
)
|
|
223
|
+
|
|
200
224
|
# optionnal seq2seq
|
|
201
225
|
if seq2seq is not None:
|
|
202
226
|
carry = _seq2seq_triggerer(
|
|
@@ -249,6 +273,12 @@ def solve(
|
|
|
249
273
|
jnp.arange(n_iter),
|
|
250
274
|
)
|
|
251
275
|
|
|
276
|
+
jax.debug.print(
|
|
277
|
+
"Iteration {i}: loss value = " "{total_loss_val}",
|
|
278
|
+
i=n_iter,
|
|
279
|
+
total_loss_val=accu[-1][-1],
|
|
280
|
+
)
|
|
281
|
+
|
|
252
282
|
params = res["params"]
|
|
253
283
|
last_non_nan_params = res["last_non_nan_params"]
|
|
254
284
|
opt_state = res["state"]
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: jinns
|
|
3
|
-
Version: 0.3.2
|
|
3
|
+
Version: 0.3.3
|
|
4
4
|
Summary: Physics Informed Neural Network with JAX
|
|
5
5
|
Author-email: Hugo Gangloff <hugo.gangloff@inrae.fr>, Nicolas Jouvin <nicolas.jouvin@inrae.fr>
|
|
6
6
|
Maintainer-email: Hugo Gangloff <hugo.gangloff@inrae.fr>, Nicolas Jouvin <nicolas.jouvin@inrae.fr>
|
|
@@ -12,11 +12,11 @@ jinns/loss/_operators.py,sha256=nza4gfBn3Ppx7BN9Eben7KXPZ4cDQDyl6pfC8sHnAlI,4380
|
|
|
12
12
|
jinns/solver/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
13
13
|
jinns/solver/_rar.py,sha256=V9y07F6objmP6rPA305dIJ82h7kwP4AWBAktZ68b-38,13894
|
|
14
14
|
jinns/solver/_seq2seq.py,sha256=XNL9e0fBj85Q86XfGDzq9dzxIkPPMwoJF38C8doNYtM,6032
|
|
15
|
-
jinns/solver/_solve.py,sha256=
|
|
15
|
+
jinns/solver/_solve.py,sha256=pCctVnrkfOptGd1wH2w9jDdFiOMmNqgatqaoYtyE_LM,10071
|
|
16
16
|
jinns/utils/__init__.py,sha256=-jDlwCjyEzWweswKdwLal3OhaUU3FVzK_Ge2S-7KHXs,149
|
|
17
17
|
jinns/utils/_utils.py,sha256=bQm6z_xPKJj9BMCr2tXc44IA8JyGNN-PR5LNRhZ1fD8,20085
|
|
18
|
-
jinns-0.3.
|
|
19
|
-
jinns-0.3.
|
|
20
|
-
jinns-0.3.
|
|
21
|
-
jinns-0.3.
|
|
22
|
-
jinns-0.3.
|
|
18
|
+
jinns-0.3.3.dist-info/LICENSE,sha256=BIAkGtXB59Q_BG8f6_OqtQ1BHPv60ggE9mpXJYz2dRM,11337
|
|
19
|
+
jinns-0.3.3.dist-info/METADATA,sha256=GL7VeN4nQ5QuFqnKlnvagQiES6AyWsw9eeRVxnbu4vc,1821
|
|
20
|
+
jinns-0.3.3.dist-info/WHEEL,sha256=yQN5g4mg4AybRjkgi-9yy4iQEFibGQmlz78Pik5Or-A,92
|
|
21
|
+
jinns-0.3.3.dist-info/top_level.txt,sha256=RXbkr2hzy8WBE8aiRyrJYFqn3JeMJIhMdybLjjLTB9c,6
|
|
22
|
+
jinns-0.3.3.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|