jinns 0.9.0__py3-none-any.whl → 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jinns/__init__.py +2 -0
- jinns/data/_Batchs.py +27 -0
- jinns/data/_DataGenerators.py +904 -1203
- jinns/data/__init__.py +4 -8
- jinns/experimental/__init__.py +0 -2
- jinns/experimental/_diffrax_solver.py +5 -5
- jinns/loss/_DynamicLoss.py +282 -305
- jinns/loss/_DynamicLossAbstract.py +322 -167
- jinns/loss/_LossODE.py +324 -322
- jinns/loss/_LossPDE.py +652 -1027
- jinns/loss/__init__.py +21 -5
- jinns/loss/_boundary_conditions.py +87 -41
- jinns/loss/{_Losses.py → _loss_utils.py} +101 -45
- jinns/loss/_loss_weights.py +59 -0
- jinns/loss/_operators.py +78 -72
- jinns/parameters/__init__.py +6 -0
- jinns/parameters/_derivative_keys.py +521 -0
- jinns/parameters/_params.py +115 -0
- jinns/plot/__init__.py +5 -0
- jinns/{data/_display.py → plot/_plot.py} +98 -75
- jinns/solver/_rar.py +183 -39
- jinns/solver/_solve.py +151 -124
- jinns/utils/__init__.py +3 -9
- jinns/utils/_containers.py +37 -44
- jinns/utils/_hyperpinn.py +224 -119
- jinns/utils/_pinn.py +183 -111
- jinns/utils/_save_load.py +121 -56
- jinns/utils/_spinn.py +113 -86
- jinns/utils/_types.py +64 -0
- jinns/utils/_utils.py +6 -160
- jinns/validation/_validation.py +48 -140
- jinns-1.1.0.dist-info/AUTHORS +2 -0
- {jinns-0.9.0.dist-info → jinns-1.1.0.dist-info}/METADATA +5 -4
- jinns-1.1.0.dist-info/RECORD +39 -0
- {jinns-0.9.0.dist-info → jinns-1.1.0.dist-info}/WHEEL +1 -1
- jinns/experimental/_sinuspinn.py +0 -135
- jinns/experimental/_spectralpinn.py +0 -87
- jinns/solver/_seq2seq.py +0 -157
- jinns/utils/_optim.py +0 -147
- jinns/utils/_utils_uspinn.py +0 -727
- jinns-0.9.0.dist-info/RECORD +0 -36
- {jinns-0.9.0.dist-info → jinns-1.1.0.dist-info}/LICENSE +0 -0
- {jinns-0.9.0.dist-info → jinns-1.1.0.dist-info}/top_level.txt +0 -0
jinns/solver/_seq2seq.py
DELETED
|
@@ -1,157 +0,0 @@
|
|
|
1
|
-
"""
|
|
2
|
-
Implements Seq2Seq training as described in “Characterizing possible
|
|
3
|
-
failure modes in physics-informed neural networks”, A. S. Krishnapriyan,
|
|
4
|
-
NeurIPS 2021.
|
|
5
|
-
|
|
6
|
-
**Note:** we do not change tmin, we only let the interval grow longer.
|
|
7
|
-
Indeed we noticed some unlearning happening.
|
|
8
|
-
|
|
9
|
-
**Note:** using seq2seq might create some instability in training when
|
|
10
|
-
interval changes. Some of this instability comes from the fact that Tmax in
|
|
11
|
-
the dynamic loss rescaling must be the true and final (and potentially large and
|
|
12
|
-
unstable) one from the beginning if we want to be able to catch the real dynamic.
|
|
13
|
-
However it does offer some better results for learning on long time intervals.
|
|
14
|
-
|
|
15
|
-
**Note:** As this is experimental some changes in the future might be:
|
|
16
|
-
- to dig deeper and try to attenuate the instability
|
|
17
|
-
- to try to attenuate the discrepancy with the real dynamic when we
|
|
18
|
-
also change Tmax in dynamic loss (this requires treating the dynamic
|
|
19
|
-
loss as a dynamic attribute of a Loss class...).
|
|
20
|
-
- to investigate Tmax as input of the PINN
|
|
21
|
-
|
|
22
|
-
"""
|
|
23
|
-
|
|
24
|
-
import jax
|
|
25
|
-
from jax import jit
|
|
26
|
-
import jax.numpy as jnp
|
|
27
|
-
from jinns.data._DataGenerators import (
|
|
28
|
-
DataGeneratorODE,
|
|
29
|
-
_reset_batch_idx_and_permute,
|
|
30
|
-
)
|
|
31
|
-
from jinns.loss._LossODE import SystemLossODE, LossODE
|
|
32
|
-
from jinns.loss._LossPDE import LossPDENonStatio, LossPDEStatio, SystemLossPDE
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
# @partial(jit, static_argnames=["_update_seq2seq_true", "_update_seq2seq_false"])
|
|
36
|
-
@jit
def trigger_seq2seq(
    i,
    loss,
    params,
    data,
    opt_state,
    curr_seq,
    seq2seq,
):
    """
    Move to the next seq2seq time segment when iteration `i` requires it.

    Parameters
    ----------
    i
        The current iteration index. It is compared against
        `seq2seq["iter_steps"]`.
    loss
        The loss object. It is forwarded to the selected branch function.
    params
        The parameters of the model. They are forwarded to the selected
        branch function and returned as received.
    data
        The data generator. It is forwarded to the selected branch function.
    opt_state
        The optimizer state. It is forwarded to the selected branch function.
    curr_seq
        The index of the time segment we are currently training on.
    seq2seq
        Either `None` (seq2seq is disabled) or a dictionary from which the
        key 'iter_steps' is read here.

    Returns
    -------
    tuple
        `(loss, params, data, opt_state, curr_seq, seq2seq)`. When `seq2seq`
        is `None`, the inputs are returned as is, except `curr_seq` which is
        set to -1.
    """
    if seq2seq is None:
        # Do nothing if no seq2seq
        return loss, params, data, opt_state, -1, seq2seq

    # Number of entries of `iter_steps` strictly below `i`: as soon as it
    # exceeds `curr_seq` we fall in another time interval
    n_passed_steps = jnp.sum(seq2seq["iter_steps"] < i)
    branch_operands = (
        loss,
        seq2seq,
        data,
        params,
        curr_seq,
        opt_state,
    )
    curr_seq, loss, data, opt_state = jax.lax.cond(
        curr_seq < n_passed_steps,
        _update_seq2seq_SystemLossODE,  # only SystemLoss are handled for now
        _update_seq2seq_false,
        branch_operands,
    )

    return loss, params, data, opt_state, curr_seq, seq2seq
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
def initialize_seq2seq(loss, data, seq2seq, opt_state):
    """
    Helper function to set up Seq2Seq before going into scan or for loop
    in `jinns.solve()`.

    Parameters
    ----------
    loss
        The loss object. Only the combination of a `SystemLossODE` loss with
        a `DataGeneratorODE` data generator triggers the set up.
    data
        The data generator. In the supported case its `tmax`,
        `curr_omega_idx`, `_key` and `times` attributes are overwritten in
        place and `generate_time_data()` is called.
    seq2seq
        A dictionary from which the keys 'time_steps' and 'learning_rate'
        are read at index 0 (the first time segment).
    opt_state
        The optimizer state. In the supported case
        `opt_state.hyperparams["learning_rate"]` is overwritten in place.

    Returns
    -------
    tuple
        `(data, opt_state)`. Both are returned untouched when `loss` and
        `data` match none of the handled types.

    Raises
    ------
    NotImplementedError
        If `loss` is a `LossPDENonStatio`, `LossPDEStatio` or `SystemLossPDE`
        (and the ODE case above did not apply). `NotImplementedError` is a
        subclass of `RuntimeError`, so handlers written for the latter still
        catch it.
    """
    if isinstance(loss, SystemLossODE) and isinstance(data, DataGeneratorODE):
        curr_seq = 0
        # Note that boundaries for the first PINN are OK
        # set new boundaries for the batch generator
        data.tmax = seq2seq["time_steps"][curr_seq]

        jax.debug.print(
            "# -- Begin training on time segment [{tmin}, {tk}]",
            tmin=data.tmin,
            tk=seq2seq["time_steps"][curr_seq],
        )
        # and do not forget to regenerate the data
        data.curr_omega_idx = 0
        data.generate_time_data()
        data._key, data.times, _ = _reset_batch_idx_and_permute(
            (data._key, data.times, data.curr_omega_idx, None, data.p_times)
        )
        opt_state.hyperparams["learning_rate"] = seq2seq["learning_rate"][curr_seq]

    elif isinstance(loss, (LossPDENonStatio, LossPDEStatio, SystemLossPDE)):
        # Dedicated exception type for an unsupported feature
        raise NotImplementedError("Not implemented")

    return data, opt_state
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
def _update_seq2seq_SystemLossODE(operands):
    """
    Perform the updates needed to enter the next time segment in seq2seq
    learning mode, for a SystemLossODE.

    Parameters
    ----------
    operands
        A 6-tuple, unpacked in this order:

        loss
            A loss object. It is returned as received.
        seq2seq
            A dictionary from which the keys 'time_steps' and
            'learning_rate' are read at the new segment index.
        data
            A data generator object. Its `tmax`, `curr_omega_idx`, `_key`
            and `times` attributes are overwritten in place and its
            `generate_time_data()` method is called.
        params
            The parameters of the model. They are not used here.
        curr_seq
            An integer which represents which sequence we are currently in.
            It is incremented by one before anything else happens.
        opt_state
            The optimizer state.
            `opt_state.hyperparams["learning_rate"]` is overwritten in place.

    Returns
    -------
    tuple
        `(curr_seq, loss, data, opt_state)` with `curr_seq` incremented.
    """
    loss, seq2seq, data, _unused_params, curr_seq, opt_state = operands
    curr_seq += 1

    # End of the time segment we are entering
    segment_end = seq2seq["time_steps"][curr_seq]

    jax.debug.print(
        "# -- Entering training on time segment [{tmin}, {tk}]",
        tmin=data.tmin,
        tk=segment_end,
    )

    # The batch generator gets its new upper boundary...
    data.tmax = segment_end
    # ...and the data must be regenerated accordingly
    data.curr_omega_idx = 0
    data.generate_time_data()
    permute_operands = (
        data._key,
        data.times,
        data.curr_omega_idx,
        None,
        data.p_times,
    )
    data._key, data.times, _ = _reset_batch_idx_and_permute(permute_operands)

    opt_state.hyperparams["learning_rate"] = seq2seq["learning_rate"][curr_seq]
    return curr_seq, loss, data, opt_state
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
def _update_seq2seq_false(operands):
|
|
156
|
-
# basically returns (curr_seq, loss, data, opt_state) in this order
|
|
157
|
-
return (operands[-2], operands[0], operands[2], operands[-1])
|
jinns/utils/_optim.py
DELETED
|
@@ -1,147 +0,0 @@
|
|
|
1
|
-
"""
|
|
2
|
-
Implements utility functions for optimization
|
|
3
|
-
"""
|
|
4
|
-
|
|
5
|
-
import optax
|
|
6
|
-
import jax
|
|
7
|
-
import jax.numpy as jnp
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
def alternate_optimizer(list_first_params, list_second_params, n_iter, evry, tx1, tx2):
    """
    Alternately optimize on two sets of parameters for equal number of steps

    The steps are cut in consecutive blocks of `evry` steps; only the
    `n_iter // evry` complete blocks are considered. During the even blocks
    (block 0 included) the updates of the second set of parameters are
    multiplied by 0, during the odd blocks the updates of the first set of
    parameters are multiplied by 0. Outside of those blocks no update is
    zeroed out.

    Every step of a block is covered, including its first one: with strict
    inequalities on both ends, the first step of each block (and step 0)
    would freeze neither set. No error is raised either when `n_iter // evry`
    is too small for a set to have any frozen block.

    Parameters
    ----------
    list_first_params
        The first set of parameter to optimize on. A list of leaves from the `params` dict
    list_second_params
        The second set of parameter to optimize on. A list of leaves from the `params` dict
    n_iter
        The total number of iterations
    evry
        The number of iterations we spend optimizing a set of parameters before switching
    tx1
        An optax optimizer of the set of parameters 1
    tx2
        An optax optimizer of the set of parameters 2

    Returns
    -------
    tx
        An optax optimizer object
    """
    # Number of complete blocks of `evry` steps within `n_iter`
    n_blocks = n_iter // evry

    def map_nested_fn(fn):
        """Recursively apply `fn` to the key-value pairs of a nested dict"""

        def map_fn(nested_dict):
            return {
                k: (map_fn(v) if isinstance(v, dict) else fn(k, v))
                for k, v in nested_dict.items()
            }

        return map_fn

    # Each leaf of the parameter dict is labelled with its own key
    label_fn = map_nested_fn(lambda k, _: k)

    def _in_block_of_parity(step, parity):
        """Tell whether `step` lies in a complete block whose index has the given parity"""
        block = step // evry
        return jnp.logical_and(block < n_blocks, block % 2 == parity)

    def should_update_1(step):
        # the first set is blocked during the odd blocks
        return _in_block_of_parity(step, 1)

    def should_update_2(step):
        # the second set is blocked during the even blocks: it starts at
        # block 0 since this one is blocked first
        return _in_block_of_parity(step, 0)

    # When `should_update_*` is True, `optax.scale(0.0)` is applied, i.e. we
    # multiply the update by 0. not to take a step
    first_adam = optax.chain(
        tx1,
        optax.maybe_update(optax.scale(0.0), should_update_1),
    )
    second_adam = optax.chain(
        tx2,
        optax.maybe_update(optax.scale(0.0), should_update_2),
    )

    return optax.multi_transform(
        {k: first_adam for k in list_first_params}
        | {
            k: second_adam for k in list_second_params
        },  # those gradient transforms must correspond to leaves of parameter pytree
        label_fn,
    )
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
def delayed_optimizer(list_first_params, list_second_params, delay_steps, tx1, tx2):
    """
    Optimize on two sets of parameters, where the second set stays frozen
    for the first `delay_steps` steps

    Parameters
    ----------
    list_first_params
        The first set of parameter to optimize on. A list of leaves from the `params` dict
    list_second_params
        The second set of parameter to optimize on. A list of leaves from the `params` dict
    delay_steps
        The number of steps we wait before starting the optimization on the
        second set of parameters
    tx1
        An optax optimizer for the first set of parameters
    tx2
        An optax optimizer for the second set of parameters

    Returns
    -------
    tx
        An optax optimizer object
    """

    def label_fn(nested_dict):
        """Label each leaf of a nested dict with its own key, recursively"""
        labels = {}
        for name, value in nested_dict.items():
            labels[name] = label_fn(value) if isinstance(value, dict) else name
        return labels

    def still_delayed(step):
        return step < delay_steps

    # While `still_delayed` is True the update coming out of `tx2` is
    # multiplied by 0. so that no step is taken
    zero_out_update = optax.maybe_update(optax.scale(0.0), still_delayed)
    delayed_tx2 = optax.chain(tx2, zero_out_update)

    # those gradient transforms must correspond to leaves of parameter pytree
    transforms = {name: tx1 for name in list_first_params}
    transforms.update({name: delayed_tx2 for name in list_second_params})

    return optax.multi_transform(transforms, label_fn)
|