jinns-0.9.0-py3-none-any.whl → jinns-1.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. jinns/__init__.py +2 -0
  2. jinns/data/_Batchs.py +27 -0
  3. jinns/data/_DataGenerators.py +904 -1203
  4. jinns/data/__init__.py +4 -8
  5. jinns/experimental/__init__.py +0 -2
  6. jinns/experimental/_diffrax_solver.py +5 -5
  7. jinns/loss/_DynamicLoss.py +282 -305
  8. jinns/loss/_DynamicLossAbstract.py +322 -167
  9. jinns/loss/_LossODE.py +324 -322
  10. jinns/loss/_LossPDE.py +652 -1027
  11. jinns/loss/__init__.py +21 -5
  12. jinns/loss/_boundary_conditions.py +87 -41
  13. jinns/loss/{_Losses.py → _loss_utils.py} +101 -45
  14. jinns/loss/_loss_weights.py +59 -0
  15. jinns/loss/_operators.py +78 -72
  16. jinns/parameters/__init__.py +6 -0
  17. jinns/parameters/_derivative_keys.py +521 -0
  18. jinns/parameters/_params.py +115 -0
  19. jinns/plot/__init__.py +5 -0
  20. jinns/{data/_display.py → plot/_plot.py} +98 -75
  21. jinns/solver/_rar.py +183 -39
  22. jinns/solver/_solve.py +151 -124
  23. jinns/utils/__init__.py +3 -9
  24. jinns/utils/_containers.py +37 -44
  25. jinns/utils/_hyperpinn.py +224 -119
  26. jinns/utils/_pinn.py +183 -111
  27. jinns/utils/_save_load.py +121 -56
  28. jinns/utils/_spinn.py +113 -86
  29. jinns/utils/_types.py +64 -0
  30. jinns/utils/_utils.py +6 -160
  31. jinns/validation/_validation.py +48 -140
  32. jinns-1.1.0.dist-info/AUTHORS +2 -0
  33. {jinns-0.9.0.dist-info → jinns-1.1.0.dist-info}/METADATA +5 -4
  34. jinns-1.1.0.dist-info/RECORD +39 -0
  35. {jinns-0.9.0.dist-info → jinns-1.1.0.dist-info}/WHEEL +1 -1
  36. jinns/experimental/_sinuspinn.py +0 -135
  37. jinns/experimental/_spectralpinn.py +0 -87
  38. jinns/solver/_seq2seq.py +0 -157
  39. jinns/utils/_optim.py +0 -147
  40. jinns/utils/_utils_uspinn.py +0 -727
  41. jinns-0.9.0.dist-info/RECORD +0 -36
  42. {jinns-0.9.0.dist-info → jinns-1.1.0.dist-info}/LICENSE +0 -0
  43. {jinns-0.9.0.dist-info → jinns-1.1.0.dist-info}/top_level.txt +0 -0
jinns/solver/_seq2seq.py DELETED
@@ -1,157 +0,0 @@
- """
- Implements Seq2Seq training as described in “Characterizing possible
- failure modes in physics-informed neural networks”, A. S. Krishnapriyan,
- NeurIPS 2021.
-
- **Note:** we do not change tmin, we only let the interval grow longer.
- Indeed, we noticed some unlearning happening.
-
- **Note:** using seq2seq might create some instability in training when the
- interval changes. Some of this instability comes from the fact that Tmax in
- the dynamic loss rescaling must be the true and final (and potentially large and
- unstable) one from the beginning if we want to be able to catch the real dynamics.
- However, it does offer better results for learning on long time intervals.
-
- **Note:** As this is experimental, some future changes might be:
-     - to dig deeper and try to attenuate the instability
-     - to try to attenuate the discrepancy with the real dynamics when we
-       also change Tmax in the dynamic loss (this requires treating the dynamic
-       loss as a dynamic attribute of a Loss class...).
-     - to investigate Tmax as an input of the PINN
-
- """
-
- import jax
- from jax import jit
- import jax.numpy as jnp
- from jinns.data._DataGenerators import (
-     DataGeneratorODE,
-     _reset_batch_idx_and_permute,
- )
- from jinns.loss._LossODE import SystemLossODE, LossODE
- from jinns.loss._LossPDE import LossPDENonStatio, LossPDEStatio, SystemLossPDE
-
-
- # @partial(jit, static_argnames=["_update_seq2seq_true", "_update_seq2seq_false"])
- @jit
- def trigger_seq2seq(
-     i,
-     loss,
-     params,
-     data,
-     opt_state,
-     curr_seq,
-     seq2seq,
-     # _update_seq2seq_true,
-     # _update_seq2seq_false,
- ):
-     if seq2seq is not None:
-         curr_seq, loss, data, opt_state = jax.lax.cond(
-             curr_seq < jnp.sum(seq2seq["iter_steps"] < i),
-             # check if we fall in another time interval
-             _update_seq2seq_SystemLossODE,
-             # only SystemLoss are handled for now
-             _update_seq2seq_false,
-             (
-                 loss,
-                 seq2seq,
-                 data,
-                 params,
-                 curr_seq,
-                 opt_state,
-             ),
-         )
-     else:
-         # Do nothing if no seq2seq
-         curr_seq = -1
-
-     return loss, params, data, opt_state, curr_seq, seq2seq
-
-
- def initialize_seq2seq(loss, data, seq2seq, opt_state):
-     """
-     Helper function to set up Seq2Seq before going into the scan or for loop
-     in `jinns.solve()`.
-     """
-     if isinstance(loss, SystemLossODE) and isinstance(data, DataGeneratorODE):
-         curr_seq = 0
-         # Note that boundaries for the first PINN are OK
-         # set new boundaries for the batch generator
-         data.tmax = seq2seq["time_steps"][curr_seq]
-
-         jax.debug.print(
-             "# -- Begin training on time segment [{tmin}, {tk}]",
-             tmin=data.tmin,
-             tk=seq2seq["time_steps"][curr_seq],
-         )
-         # and do not forget to regenerate the data
-         data.curr_omega_idx = 0
-         data.generate_time_data()
-         data._key, data.times, _ = _reset_batch_idx_and_permute(
-             (data._key, data.times, data.curr_omega_idx, None, data.p_times)
-         )
-         opt_state.hyperparams["learning_rate"] = seq2seq["learning_rate"][curr_seq]
-
-     elif isinstance(loss, (LossPDENonStatio, LossPDEStatio, SystemLossPDE)):
-         raise RuntimeError("Not implemented")
-
-     return data, opt_state
-
-
- def _update_seq2seq_SystemLossODE(operands):
-     """
-     Make all the necessary updates for a SystemLossODE in seq2seq learning mode
-
-     Parameters
-     ----------
-     operands
-         A tuple which comprises:
-
-         loss
-             A loss object (e.g. a LossODE, SystemLossODE, LossPDEStatio [...]
-             object). It must be jittable (e.g. implemented via pytree
-             registration)
-         seq2seq
-             A dictionary with keys 'time_steps'
-             and 'iter_steps' which must have the same length. The first gives
-             the time steps that define the successive time intervals upon
-             which we perform the incremental learning. The second gives
-             the number of iterations we perform in each time interval.
-         data
-             A DataGenerator object which implements a `get_batch()`
-             method which returns a 3-tuple (omega grid, omega border, time grid).
-             It must be jittable (e.g. implemented via pytree
-             registration)
-         params
-             The dictionary of parameters of the model.
-             Typically, it is a dictionary of
-             dictionaries: `eq_params` and `nn_params`, respectively the
-             differential equation parameters and the neural network parameters
-         curr_seq
-             An integer which represents which sequence we are currently in
-     """
-     loss, seq2seq, data, params, curr_seq, opt_state = operands
-     curr_seq += 1
-
-     jax.debug.print(
-         "# -- Entering training on time segment [{tmin}, {tk}]",
-         tmin=data.tmin,
-         tk=seq2seq["time_steps"][curr_seq],
-     )
-
-     # set new boundaries for the batch generator
-     data.tmax = seq2seq["time_steps"][curr_seq]
-     # and do not forget to regenerate the data
-     data.curr_omega_idx = 0
-     data.generate_time_data()
-     data._key, data.times, _ = _reset_batch_idx_and_permute(
-         (data._key, data.times, data.curr_omega_idx, None, data.p_times)
-     )
-
-     opt_state.hyperparams["learning_rate"] = seq2seq["learning_rate"][curr_seq]
-     return curr_seq, loss, data, opt_state
-
-
- def _update_seq2seq_false(operands):
-     # basically returns (curr_seq, loss, data, opt_state) in this order
-     return (operands[-2], operands[0], operands[2], operands[-1])
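For orientation, a minimal usage sketch of the removed helpers follows (it is not part of the diff). It assumes the jinns 0.9.0 API: `loss` (a SystemLossODE), `data` (a DataGeneratorODE), `params` and `opt_state` are built elsewhere, with `opt_state` coming from an optimizer wrapped in `optax.inject_hyperparams` so that `opt_state.hyperparams["learning_rate"]` exists, as the code above requires. The schedule values are purely illustrative.

# Minimal sketch, not taken from the package, based only on the keys and
# signatures visible in the deleted file above.
import jax.numpy as jnp

# Equal-length schedules: the k-th entries give the upper time bound of the
# k-th segment, the iteration at which we switch to it, and its learning rate.
seq2seq = {
    "time_steps": jnp.array([0.5, 1.0, 2.0, 4.0]),        # illustrative values
    "iter_steps": jnp.array([0, 5_000, 10_000, 15_000]),
    "learning_rate": jnp.array([1e-3, 1e-3, 5e-4, 1e-4]),
}

# `loss`, `data`, `params`, `opt_state` and `n_iter` are assumed to be created
# elsewhere with the jinns 0.9.0 API and are not shown here.
#
# data, opt_state = initialize_seq2seq(loss, data, seq2seq, opt_state)
# curr_seq = 0
# for i in range(n_iter):
#     loss, params, data, opt_state, curr_seq, seq2seq = trigger_seq2seq(
#         i, loss, params, data, opt_state, curr_seq, seq2seq
#     )
#     ...  # regular gradient step on the current time segment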
jinns/utils/_optim.py DELETED
@@ -1,147 +0,0 @@
- """
- Implements utility functions for optimization
- """
-
- import optax
- import jax
- import jax.numpy as jnp
-
-
- def alternate_optimizer(list_first_params, list_second_params, n_iter, evry, tx1, tx2):
-     """
-     Alternately optimize two sets of parameters for an equal number of steps
-
-     Parameters
-     ----------
-     list_first_params
-         The first set of parameters to optimize on. A list of leaves from the `params` dict
-     list_second_params
-         The second set of parameters to optimize on. A list of leaves from the `params` dict
-     n_iter
-         The total number of iterations
-     evry
-         The number of iterations we spend optimizing one set of parameters before switching
-     tx1
-         An optax optimizer for the set of parameters 1
-     tx2
-         An optax optimizer for the set of parameters 2
-
-     Returns
-     -------
-     tx
-         An optax optimizer object
-     """
-
-     def map_nested_fn(fn):
-         """Recursively apply `fn` to the key-value pairs of a nested dict"""
-
-         def map_fn(nested_dict):
-             return {
-                 k: (map_fn(v) if isinstance(v, dict) else fn(k, v))
-                 for k, v in nested_dict.items()
-             }
-
-         return map_fn
-
-     label_fn = map_nested_fn(lambda k, _: k)
-
-     def should_update_1(step):
-         return jax.tree_util.tree_reduce(
-             lambda x, y: jnp.logical_or(x, y),
-             [
-                 jnp.logical_and((step > i * evry), (step < (i + 1) * evry))
-                 for i in range(1, n_iter // evry, 2)
-             ],
-         )
-
-     def should_update_2(step):
-         return jax.tree_util.tree_reduce(
-             lambda x, y: jnp.logical_or(x, y),
-             [
-                 jnp.logical_and((step > i * evry), (step < (i + 1) * evry))
-                 for i in range(
-                     0, n_iter // evry, 2
-                 )  # starts at 0 since this one is blocked first
-             ],
-         )
-
-     first_adam = optax.chain(
-         tx1,
-         optax.maybe_update(
-             optax.scale(0.0), should_update_1
-         ),  # when should_update_1 is True, the update is multiplied by 0
-         # so that no step is taken
-     )
-     second_adam = optax.chain(
-         tx2,
-         optax.maybe_update(
-             optax.scale(0.0), should_update_2
-         ),  # when should_update_2 is True, the update is multiplied by 0
-         # so that no step is taken
-     )
-
-     return optax.multi_transform(
-         {k: first_adam for k in list_first_params}
-         | {
-             k: second_adam for k in list_second_params
-         },  # these gradient transforms must correspond to leaves of the parameter pytree
-         label_fn,
-     )
-
-
- def delayed_optimizer(list_first_params, list_second_params, delay_steps, tx1, tx2):
-     """
-     Optimize two sets of parameters, with the optimization on the second set of
-     parameters starting only after `delay_steps` steps of freezing
-
-     Parameters
-     ----------
-     list_first_params
-         The first set of parameters to optimize on. A list of leaves from the `params` dict
-     list_second_params
-         The second set of parameters to optimize on. A list of leaves from the `params` dict
-     delay_steps
-         The number of steps we wait before starting the optimization on the
-         second set of parameters
-     tx1
-         An optax optimizer for the set of parameters 1
-     tx2
-         An optax optimizer for the set of parameters 2
-
-     Returns
-     -------
-     tx
-         An optax optimizer object
-     """
-
-     def map_nested_fn(fn):
-         """Recursively apply `fn` to the key-value pairs of a nested dict"""
-
-         def map_fn(nested_dict):
-             return {
-                 k: (map_fn(v) if isinstance(v, dict) else fn(k, v))
-                 for k, v in nested_dict.items()
-             }
-
-         return map_fn
-
-     label_fn = map_nested_fn(lambda k, _: k)
-
-     def should_update_2(step):
-         return step < delay_steps
-
-     delayed_tx2 = optax.chain(
-         tx2,
-         optax.maybe_update(
-             optax.scale(0.0), should_update_2
-         ),  # when should_update_2 is True, the update is multiplied by 0
-         # so that no step is taken
-     )
-
-     return optax.multi_transform(
-         {k: tx1 for k in list_first_params}
-         | {
-             k: delayed_tx2 for k in list_second_params
-         },  # these gradient transforms must correspond to leaves of the parameter pytree
-         label_fn,
-     )
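As a usage illustration only (not taken from the package), the sketch below shows how the removed `delayed_optimizer` could have been combined with standard optax optimizers. The key names "nn_params" and "eq_params" are hypothetical; with the `label_fn` defined above, every leaf of the nested `params` dict is labelled by its innermost key, and each such key must appear in one of the two lists for `optax.multi_transform` to accept it.

# Hypothetical usage sketch; delayed_optimizer is the helper deleted above and
# the key names are only illustrative labels for leaves of the params dict.
import optax

tx = delayed_optimizer(
    list_first_params=["nn_params"],   # updated by tx1 from the first step
    list_second_params=["eq_params"],  # updates zeroed while step < 1000
    delay_steps=1000,
    tx1=optax.adam(1e-3),
    tx2=optax.adam(1e-2),
)
# tx is an ordinary optax GradientTransformation:
# opt_state = tx.init(params)
# updates, opt_state = tx.update(grads, opt_state, params)

The same call pattern applies to `alternate_optimizer`, with `n_iter` and `evry` controlling the length of the alternating optimization windows.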