python-somax 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- python_somax-0.0.1.dist-info/LICENSE +201 -0
- python_somax-0.0.1.dist-info/METADATA +131 -0
- python_somax-0.0.1.dist-info/RECORD +19 -0
- python_somax-0.0.1.dist-info/WHEEL +5 -0
- python_somax-0.0.1.dist-info/top_level.txt +1 -0
- somax/__init__.py +19 -0
- somax/diagonal/__init__.py +0 -0
- somax/diagonal/adahessian.py +192 -0
- somax/diagonal/sophia_g.py +208 -0
- somax/diagonal/sophia_h.py +191 -0
- somax/gn/__init__.py +0 -0
- somax/gn/egn.py +567 -0
- somax/gn/sgn.py +237 -0
- somax/hf/__init__.py +0 -0
- somax/hf/newton_cg.py +339 -0
- somax/ng/__init__.py +0 -0
- somax/ng/swm_ng.py +411 -0
- somax/qn/__init__.py +0 -0
- somax/qn/sqn.py +311 -0
|
@@ -0,0 +1,208 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Sophia-G Optimizer
|
|
3
|
+
Paper: https://arxiv.org/abs/2305.14342
|
|
4
|
+
|
|
5
|
+
!! Sophia-G is only implemented for the Cross-Entropy loss function !!
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from typing import Any
|
|
9
|
+
from typing import Callable
|
|
10
|
+
from typing import NamedTuple, Tuple
|
|
11
|
+
from typing import Optional
|
|
12
|
+
import dataclasses
|
|
13
|
+
from functools import partial
|
|
14
|
+
|
|
15
|
+
import jax
|
|
16
|
+
import jax.lax as lax
|
|
17
|
+
import jax.numpy as jnp
|
|
18
|
+
from jax.flatten_util import ravel_pytree
|
|
19
|
+
from optax.losses import softmax_cross_entropy_with_integer_labels
|
|
20
|
+
from jaxopt._src import base
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class SophiaGState(NamedTuple):
    """Named tuple containing state information."""
    # Number of completed solver iterations (drives the lazy-Hessian refresh).
    iter_num: int
    # Step size recorded in the state (mirrors the solver's learning rate).
    stepsize: float
    # EMA of the gradient (first moment), stored as a flat vector.
    velocity_m: Any
    # EMA of the diagonal Hessian estimate (second moment), flat vector.
    velocity_v: Any
    # PRNG key consumed when sampling labels for the GNB Hessian estimator.
    hess_approx_rng: Any
|
30
|
+
|
|
31
|
+
|
|
32
|
+
@dataclasses.dataclass(eq=False)
class SophiaG(base.StochasticSolver):
    """Sophia-G optimizer.

    Paper: https://arxiv.org/abs/2305.14342

    Estimates the diagonal of the Hessian with the Gauss-Newton-Bartlett
    (GNB) estimator, which samples labels from the model's own logits and
    re-evaluates the gradient on them.  Because of that construction this
    implementation only supports the softmax cross-entropy loss with
    integer labels (see ``loss_with_aux_fun``).

    Attributes:
      predict_fun: callable mapping ``(params, features)`` to logits.
      learning_rate: step size for the parameter update.
      eval_hess_every_k: the Hessian-diagonal estimate is refreshed every
        k-th iteration and reused in between (lazy Hessian).
      beta1: EMA coefficient for the first moment (gradient).
      beta2: EMA coefficient for the second moment (Hessian diagonal).
      weight_decay: AdamW-style decoupled weight-decay coefficient in [0, 1).
      gamma: scaling of the Hessian diagonal in the preconditioner.
      clip_th: per-coordinate clipping threshold of the update direction.
      eps: small constant guarding the division for numerical stability.
      hess_approx_seed: PRNG seed for the GNB label sampling.
      pre_update: optional callable executed before each update.
      verbose: verbosity level.
      jit: whether the jaxopt base class jits the update.
      unroll: whether to unroll the optimization loop.
    """
    predict_fun: Callable

    # loss_from_logits_fun: Callable

    learning_rate: float = 1e-3

    # Lazy Hessian parameters
    eval_hess_every_k: int = 10

    # Momentum parameters
    beta1: float = 0.99  # proposed as default by levanter
    beta2: float = 0.99  # proposed as default by levanter

    # Regularization parameters
    weight_decay: float = 0.0  # L2 regularization coefficient

    gamma: float = 0.05  # clipping parameter
    clip_th: float = 1.  # clipping threshold
    eps: float = 1e-8  # term added to the denominator to improve numerical stability

    hess_approx_seed: int = 1337

    pre_update: Optional[Callable] = None

    verbose: int = 0

    jit: bool = True
    unroll: base.AutoOrBoolean = "auto"

    def __post_init__(self):
        super().__post_init__()

        self.reference_signature = self.predict_fun

        # Gradient of the mean cross-entropy loss; ``has_aux`` carries the
        # logits out so the GNB estimator can sample labels without an
        # extra forward pass.
        self.grad_fun = jax.grad(self.loss_with_aux_fun, has_aux=True)

        assert 0 <= self.weight_decay < 1, "Weight decay must be in [0, 1)"

    def update(
        self,
        params: Any,
        state: SophiaGState,
        *args,
        **kwargs,
    ) -> base.OptStep:
        """Performs one iteration of the solver.

        Args:
          params: pytree containing the parameters.
          state: named tuple containing the solver state.
          *args: additional positional arguments to be passed to ``fun``.
          **kwargs: additional keyword arguments to be passed to ``fun``.
        Returns:
          (params, state)
        """
        # ------- Step 1 -------
        # compute the exact gradient and the diagonal Hessian estimate,
        # refreshing the Hessian only every k-th iteration (lazy Hessian),
        # and apply temporal averaging (momentum) to both moments.
        # !! original Sophia paper has no bias correction: 1/(1 - beta1**i) !!
        next_rng_key, rng_key = jax.random.split(state.hess_approx_rng)
        inputs = (params, state, rng_key, args, kwargs)
        # BUGFIX: use the current ``lax.cond(pred, true_fun, false_fun,
        # *operands)`` signature; the previous five-argument form
        # ``cond(pred, operand, true_fun, operand, false_fun)`` was removed
        # from JAX and raises at trace time on modern versions.
        m_t, v_t = lax.cond(
            state.iter_num % self.eval_hess_every_k == 0,
            self.grad_and_hess_gnb,
            self.grad_and_hess_reuse,
            inputs,
        )

        # TODO apply bias correction for m_t?

        # ------- Step 2 -------
        # precondition and clip the update direction coordinate-wise.
        # Bounds are passed positionally: jnp.clip's ``a_min``/``a_max``
        # keywords are deprecated and removed in recent JAX releases.
        direction = jnp.clip(
            m_t / jnp.maximum(self.gamma * v_t, self.eps),
            -self.clip_th,
            self.clip_th,
        )

        # update parameters
        params_flat, pack_fn = ravel_pytree(params)
        next_params_flat = params_flat - self.learning_rate * direction

        # ------- Step 3 -------
        # AdamW-style (decoupled) weight decay; static Python branch, so it
        # is resolved once at trace time.
        if self.weight_decay > 0:
            next_params_flat -= self.learning_rate * self.weight_decay * params_flat

        # bookkeeping
        next_params = pack_fn(next_params_flat)

        next_state = SophiaGState(
            iter_num=state.iter_num + 1,
            stepsize=state.stepsize,
            velocity_m=m_t,
            velocity_v=v_t,
            hess_approx_rng=next_rng_key,
        )

        return base.OptStep(params=next_params, state=next_state)

    def init_state(self,
                   init_params: Any,
                   *args,
                   **kwargs) -> SophiaGState:
        """Returns the initial solver state for ``init_params``."""
        params_flat, pack_fn = ravel_pytree(init_params)
        # Both moments start at zero, matching the flat parameter vector.
        velocity_m = jnp.zeros_like(params_flat)
        velocity_v = jnp.zeros_like(params_flat)

        return SophiaGState(
            iter_num=0,
            stepsize=self.learning_rate,
            velocity_m=velocity_m,
            velocity_v=velocity_v,
            hess_approx_rng=jax.random.PRNGKey(self.hess_approx_seed),
        )

    def optimality_fun(self, params, *args, **kwargs):
        """Optimality function mapping compatible with ``@custom_root``."""
        # return self._grad_fun(params, *args, **kwargs)[0]
        raise NotImplementedError

    def __hash__(self):
        # We assume that the attribute values completely determine the solver.
        return hash(self.attribute_values())

    @staticmethod
    def resolve_features_and_targets(*args, **kwargs):
        """Extracts (features, targets) from the update call arguments.

        Convention: features are the first positional argument, targets the
        last one.  # TODO: add support for named arguments
        """
        features = args[0]
        targets = args[-1]
        return features, targets

    def loss_with_aux_fun(self, params, *args, **kwargs):
        """Mean cross-entropy loss; returns the logits as auxiliary output."""
        features, targets = self.resolve_features_and_targets(*args, **kwargs)
        logits = self.predict_fun(params, features)
        loss = softmax_cross_entropy_with_integer_labels(
            logits=logits, labels=targets).mean()
        return loss, logits

    def grad_and_hess_gnb(self, inputs):
        """Gradient plus a fresh GNB diagonal-Hessian estimate, with EMA."""
        params, state, rng_key, args, kwargs = inputs

        grads_tree, logits = self.grad_fun(params, *args, **kwargs)

        # Sample labels from the model's own predictive distribution.
        samples = jax.random.categorical(rng_key, logits=logits)

        # Gradient of the loss evaluated on the sampled labels.
        # TODO consider re-writing this part more elegantly
        sampled_args = (args[0], samples)
        grads_sampled_tree, _ = self.grad_fun(params, *sampled_args, **kwargs)

        # apply EMA to the first and second moments
        # m_t
        grads = ravel_pytree(grads_tree)[0]
        m_t = self.beta1 * state.velocity_m + (1 - self.beta1) * grads

        # v_t: GNB estimate is B * g_hat^2 (B = batch size), since the
        # gradient above is of the *mean* loss.
        b = logits.shape[0]
        grads_sampled = ravel_pytree(grads_sampled_tree)[0]
        v_t = b * grads_sampled * grads_sampled
        v_t = self.beta2 * state.velocity_v + (1 - self.beta2) * v_t

        return m_t, v_t

    def grad_and_hess_reuse(self, inputs):
        """Gradient with EMA; reuses the previous diagonal-Hessian estimate."""
        params, state, rng_key, args, kwargs = inputs

        grads_tree, _ = self.grad_fun(params, *args, **kwargs)

        # m_t
        grads = ravel_pytree(grads_tree)[0]
        m_t = self.beta1 * state.velocity_m + (1 - self.beta1) * grads

        # v_t: reuse the previous estimate
        v_t = state.velocity_v

        return m_t, v_t
|
@@ -0,0 +1,191 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Sophia-H Optimizer
|
|
3
|
+
Paper: https://arxiv.org/abs/2305.14342
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from typing import Any
|
|
7
|
+
from typing import Callable
|
|
8
|
+
from typing import NamedTuple, Tuple
|
|
9
|
+
from typing import Optional
|
|
10
|
+
import dataclasses
|
|
11
|
+
from functools import partial
|
|
12
|
+
|
|
13
|
+
import jax
|
|
14
|
+
import jax.lax as lax
|
|
15
|
+
import jax.numpy as jnp
|
|
16
|
+
from jax.flatten_util import ravel_pytree
|
|
17
|
+
from jaxopt._src import base
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class SophiaHState(NamedTuple):
    """Named tuple containing state information."""
    # Number of completed solver iterations (drives the lazy-Hessian refresh).
    iter_num: int
    # Step size recorded in the state (mirrors the solver's learning rate).
    stepsize: float
    # EMA of the gradient (first moment), stored as a flat vector.
    velocity_m: Any
    # EMA of the Hutchinson diagonal Hessian estimate (second moment).
    velocity_v: Any
    # PRNG key consumed when drawing the Rademacher probe vector.
    hess_approx_rng: Any
|
27
|
+
|
|
28
|
+
|
|
29
|
+
@dataclasses.dataclass(eq=False)
class SophiaH(base.StochasticSolver):
    """Sophia-H optimizer.

    Paper: https://arxiv.org/abs/2305.14342

    Estimates the diagonal of the Hessian with Hutchinson's estimator,
    ``diag(H) ~ E[z * (Hz)]`` for Rademacher-distributed ``z``, using one
    Hessian-vector product per (lazy) Hessian refresh.

    Attributes:
      loss_fun: callable mapping ``(params, *args, **kwargs)`` to a scalar loss.
      learning_rate: step size for the parameter update.
      eval_hess_every_k: the Hessian-diagonal estimate is refreshed every
        k-th iteration and reused in between (lazy Hessian).
      beta1: EMA coefficient for the first moment (gradient).
      beta2: EMA coefficient for the second moment (Hessian diagonal).
      weight_decay: AdamW-style decoupled weight-decay coefficient in [0, 1).
      gamma: scaling of the Hessian diagonal in the preconditioner.
      clip_th: per-coordinate clipping threshold of the update direction.
      eps: small constant guarding the division for numerical stability.
      hess_approx_seed: PRNG seed for the Rademacher probe vector.
      pre_update: optional callable executed before each update.
      verbose: verbosity level.
      jit: whether the jaxopt base class jits the update.
      unroll: whether to unroll the optimization loop.
    """
    loss_fun: Callable

    learning_rate: float = 0.85e-3  # proposed as default by levanter

    # Lazy Hessian parameters
    eval_hess_every_k: int = 10

    # Momentum parameters
    beta1: float = 0.965  # proposed as default by levanter
    beta2: float = 0.99  # proposed as default by levanter

    # Regularization parameters
    weight_decay: float = 0.0  # L2 regularization coefficient

    gamma: float = 0.01  # clipping parameter
    clip_th: float = 1.  # clipping threshold
    eps: float = 1e-8  # term added to the denominator to improve numerical stability

    hess_approx_seed: int = 1337

    pre_update: Optional[Callable] = None

    verbose: int = 0

    jit: bool = True
    unroll: base.AutoOrBoolean = "auto"

    def __post_init__(self):
        super().__post_init__()

        self.reference_signature = self.loss_fun

        # Plain gradient of the user-supplied loss (no auxiliary output).
        self.grad_fun = jax.grad(self.loss_fun)

        assert 0 <= self.weight_decay < 1, "Weight decay must be in [0, 1)"

    def update(
        self,
        params: Any,
        state: SophiaHState,
        *args,
        **kwargs,
    ) -> base.OptStep:
        """Performs one iteration of the solver.

        Args:
          params: pytree containing the parameters.
          state: named tuple containing the solver state.
          *args: additional positional arguments to be passed to ``fun``.
          **kwargs: additional keyword arguments to be passed to ``fun``.
        Returns:
          (params, state)
        """
        # ------- Step 1 -------
        # compute the exact gradient and the diagonal Hessian estimate,
        # refreshing the Hessian only every k-th iteration (lazy Hessian),
        # and apply temporal averaging (momentum) to both moments.
        # !! original Sophia paper has no bias correction: 1/(1 - beta1**i) !!
        next_rng_key, rng_key = jax.random.split(state.hess_approx_rng)
        inputs = (params, state, rng_key, args, kwargs)
        # BUGFIX: use the current ``lax.cond(pred, true_fun, false_fun,
        # *operands)`` signature; the previous five-argument form
        # ``cond(pred, operand, true_fun, operand, false_fun)`` was removed
        # from JAX and raises at trace time on modern versions.
        m_t, v_t = lax.cond(
            state.iter_num % self.eval_hess_every_k == 0,
            self.grad_and_hess_hutchinson,
            self.grad_and_hess_reuse,
            inputs,
        )

        # TODO apply bias correction for m_t?

        # ------- Step 2 -------
        # precondition and clip the update direction coordinate-wise.
        # Bounds are passed positionally: jnp.clip's ``a_min``/``a_max``
        # keywords are deprecated and removed in recent JAX releases.
        direction = jnp.clip(
            m_t / jnp.maximum(self.gamma * v_t, self.eps),
            -self.clip_th,
            self.clip_th,
        )

        # update parameters
        params_flat, pack_fn = ravel_pytree(params)
        next_params_flat = params_flat - self.learning_rate * direction

        # ------- Step 3 -------
        # AdamW-style (decoupled) weight decay; static Python branch, so it
        # is resolved once at trace time.
        if self.weight_decay > 0:
            next_params_flat -= self.learning_rate * self.weight_decay * params_flat

        # bookkeeping
        next_params = pack_fn(next_params_flat)

        next_state = SophiaHState(
            iter_num=state.iter_num + 1,
            stepsize=state.stepsize,
            velocity_m=m_t,
            velocity_v=v_t,
            hess_approx_rng=next_rng_key,
        )

        return base.OptStep(params=next_params, state=next_state)

    def init_state(self,
                   init_params: Any,
                   *args,
                   **kwargs) -> SophiaHState:
        """Returns the initial solver state for ``init_params``."""
        params_flat, pack_fn = ravel_pytree(init_params)
        # Both moments start at zero, matching the flat parameter vector.
        velocity_m = jnp.zeros_like(params_flat)
        velocity_v = jnp.zeros_like(params_flat)

        return SophiaHState(
            iter_num=0,
            stepsize=self.learning_rate,
            velocity_m=velocity_m,
            velocity_v=velocity_v,
            hess_approx_rng=jax.random.PRNGKey(self.hess_approx_seed),
        )

    def optimality_fun(self, params, *args, **kwargs):
        """Optimality function mapping compatible with ``@custom_root``."""
        # return self._grad_fun(params, *args, **kwargs)[0]
        raise NotImplementedError

    def __hash__(self):
        # We assume that the attribute values completely determine the solver.
        return hash(self.attribute_values())

    def grad_and_hess_hutchinson(self, inputs):
        """Gradient plus a fresh Hutchinson diagonal estimate, with EMA."""
        params, state, rng_key, args, kwargs = inputs

        # Close over the batch so jax.jvp sees a single-argument function of
        # the parameters only.  (Defined after unpacking ``args``/``kwargs``
        # so it does not rely on late binding.)
        def inner_grad_fun(p):
            return self.grad_fun(p, *args, **kwargs)

        # Draw a Rademacher probe vector z with the same pytree structure as
        # the parameters.
        params_flat, pack_fn = ravel_pytree(params)
        p_shape = params_flat.shape
        z = jax.random.rademacher(rng_key, shape=p_shape, dtype=jnp.float32)
        z_tree = pack_fn(z)

        # Forward-over-reverse: the JVP of the gradient yields the
        # Hessian-vector product Hz alongside the gradient itself.
        grads_tree, hz_tree = jax.jvp(inner_grad_fun, (params,), (z_tree,))

        # BUGFIX: Hutchinson's diagonal estimate is z * (Hz); the previous
        # code multiplied the *gradient* by Hz, which does not estimate
        # diag(H) and biases the preconditioner.
        diag_hess_tree = jax.tree_map(lambda z_i, hz_i: z_i * hz_i, z_tree, hz_tree)

        # apply EMA to the first and second moments
        # m_t
        grads = ravel_pytree(grads_tree)[0]
        m_t = self.beta1 * state.velocity_m + (1 - self.beta1) * grads

        # v_t
        v_t = ravel_pytree(diag_hess_tree)[0]
        v_t = self.beta2 * state.velocity_v + (1 - self.beta2) * v_t

        return m_t, v_t

    def grad_and_hess_reuse(self, inputs):
        """Gradient with EMA; reuses the previous diagonal-Hessian estimate."""
        params, state, rng_key, args, kwargs = inputs

        grads_tree = self.grad_fun(params, *args, **kwargs)

        # m_t
        grads = ravel_pytree(grads_tree)[0]
        m_t = self.beta1 * state.velocity_m + (1 - self.beta1) * grads

        # v_t: reuse the previous estimate
        v_t = state.velocity_v

        return m_t, v_t
somax/gn/__init__.py
ADDED
|
File without changes
|