python-somax 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,208 @@
1
+ """
2
+ Sophia-G Optimizer
3
+ Paper: https://arxiv.org/abs/2305.14342
4
+
5
+ !! Sophia-G is only implemented for the Cross-Entropy loss function !!
6
+ """
7
+
8
+ from typing import Any
9
+ from typing import Callable
10
+ from typing import NamedTuple, Tuple
11
+ from typing import Optional
12
+ import dataclasses
13
+ from functools import partial
14
+
15
+ import jax
16
+ import jax.lax as lax
17
+ import jax.numpy as jnp
18
+ from jax.flatten_util import ravel_pytree
19
+ from optax.losses import softmax_cross_entropy_with_integer_labels
20
+ from jaxopt._src import base
21
+
22
+
23
class SophiaGState(NamedTuple):
    """Named tuple containing state information."""
    iter_num: int         # number of update steps performed so far
    stepsize: float       # step size (mirrors the configured learning rate)
    velocity_m: Any       # EMA of the flattened gradient (first moment m_t)
    velocity_v: Any       # EMA of the diagonal Hessian estimate (second moment v_t)
    hess_approx_rng: Any  # PRNG key consumed by the stochastic Hessian estimator
30
+
31
+
32
@dataclasses.dataclass(eq=False)
class SophiaG(base.StochasticSolver):
    """Sophia-G optimizer (https://arxiv.org/abs/2305.14342).

    Second-order stochastic solver that preconditions an EMA of the gradient
    with an EMA of a Gauss-Newton-Bartlett (GNB) diagonal Hessian estimate and
    clips the resulting direction elementwise.

    !! The GNB estimator is only valid for the cross-entropy loss, so the loss
    is hard-coded to ``softmax_cross_entropy_with_integer_labels`` !!

    Attributes:
      predict_fun: forward pass ``predict_fun(params, features) -> logits``.
      learning_rate: step size.
      eval_hess_every_k: re-estimate the Hessian diagonal every k-th iteration
        (lazy Hessian); the previous estimate is reused in between.
      beta1: EMA coefficient of the first moment (gradient).
      beta2: EMA coefficient of the second moment (Hessian diagonal).
      weight_decay: decoupled (AdamW-style) weight-decay coefficient in [0, 1).
      gamma: scaling of the Hessian estimate inside the preconditioner.
      clip_th: elementwise clipping threshold for the update direction.
      eps: lower bound on the denominator for numerical stability.
      hess_approx_seed: PRNG seed for the GNB label sampling.
      pre_update: optional function executed before every update.
      verbose: verbosity level.
      jit: whether to jit the optimization loop.
      unroll: whether to unroll the optimization loop.
    """
    predict_fun: Callable

    learning_rate: float = 1e-3

    # Lazy Hessian parameters
    eval_hess_every_k: int = 10

    # Momentum parameters
    beta1: float = 0.99  # proposed as default by levanter
    beta2: float = 0.99  # proposed as default by levanter

    # Regularization parameters
    weight_decay: float = 0.0  # L2 regularization coefficient

    gamma: float = 0.05  # clipping parameter
    clip_th: float = 1.  # clipping threshold
    eps: float = 1e-8  # term added to the denominator to improve numerical stability

    hess_approx_seed: int = 1337

    pre_update: Optional[Callable] = None

    verbose: int = 0

    jit: bool = True
    unroll: base.AutoOrBoolean = "auto"

    def __post_init__(self):
        super().__post_init__()

        self.reference_signature = self.predict_fun

        # Gradient of the mean cross-entropy loss; the aux output carries the
        # logits so the GNB estimator can reuse them without a second forward.
        self.grad_fun = jax.grad(self.loss_with_aux_fun, has_aux=True)

        assert 0 <= self.weight_decay < 1, "Weight decay must be in [0, 1)"

    def update(
        self,
        params: Any,
        state: SophiaGState,
        *args,
        **kwargs,
    ) -> base.OptStep:
        """Performs one iteration of the solver.

        Args:
          params: pytree containing the parameters.
          state: named tuple containing the solver state.
          *args: additional positional arguments to be passed to ``fun``.
          **kwargs: additional keyword arguments to be passed to ``fun``.
        Returns:
          (params, state)
        """
        # ------- Step 1 -------
        # Exact gradient and diagonal Hessian estimate with lazy evaluation:
        # H is re-estimated only every k-th iteration and reused otherwise;
        # EMA (momentum) is applied to both moments.
        # !! original Sophia paper has no bias correction 1/(1 - beta ** t) !!
        next_rng_key, rng_key = jax.random.split(state.hess_approx_rng)
        inputs = (params, state, rng_key, args, kwargs)
        # BUGFIX: modern ``lax.cond`` takes (pred, true_fun, false_fun,
        # *operands); the legacy 5-argument (pred, operand, fun, operand, fun)
        # form used previously was removed from JAX.
        m_t, v_t = lax.cond(
            state.iter_num % self.eval_hess_every_k == 0,
            self.grad_and_hess_gnb,
            self.grad_and_hess_reuse,
            inputs,
        )

        # TODO apply bias correction for m_t?

        # ------- Step 2 -------
        # Precondition the first moment and clip the direction elementwise.
        # (Positional bounds: the ``a_min``/``a_max`` keywords are deprecated.)
        direction = jnp.clip(
            m_t / jnp.maximum(self.gamma * v_t, self.eps),
            -self.clip_th,
            self.clip_th,
        )

        # ------- Step 3 -------
        # gradient step on the flattened parameters
        params_flat, pack_fn = ravel_pytree(params)
        next_params_flat = params_flat - self.learning_rate * direction

        # ------- Step 4 -------
        # AdamW-style (decoupled) weight decay.  A plain Python branch is safe
        # under jit because ``weight_decay`` is a static dataclass attribute.
        if self.weight_decay > 0:
            next_params_flat -= self.learning_rate * self.weight_decay * params_flat

        # bookkeeping
        next_params = pack_fn(next_params_flat)
        next_state = SophiaGState(
            iter_num=state.iter_num + 1,
            stepsize=state.stepsize,
            velocity_m=m_t,
            velocity_v=v_t,
            hess_approx_rng=next_rng_key,
        )
        return base.OptStep(params=next_params, state=next_state)

    def init_state(self,
                   init_params: Any,
                   *args,
                   **kwargs) -> SophiaGState:
        """Creates the initial state: zero moments and a fresh PRNG key."""
        params_flat, _ = ravel_pytree(init_params)
        return SophiaGState(
            iter_num=0,
            stepsize=self.learning_rate,
            velocity_m=jnp.zeros_like(params_flat),
            velocity_v=jnp.zeros_like(params_flat),
            hess_approx_rng=jax.random.PRNGKey(self.hess_approx_seed),
        )

    def optimality_fun(self, params, *args, **kwargs):
        """Optimality function mapping compatible with ``@custom_root``."""
        raise NotImplementedError

    def __hash__(self):
        # We assume that the attribute values completely determine the solver.
        return hash(self.attribute_values())

    @staticmethod
    def resolve_features_and_targets(*args, **kwargs):
        """Extracts ``(features, targets)`` from the positional arguments.

        Convention: features are the first positional argument, integer class
        targets the last one.  TODO: add support for named arguments.
        """
        return args[0], args[-1]

    def loss_with_aux_fun(self, params, *args, **kwargs):
        """Mean cross-entropy loss; returns the logits as auxiliary output."""
        features, targets = self.resolve_features_and_targets(*args, **kwargs)
        logits = self.predict_fun(params, features)
        loss = softmax_cross_entropy_with_integer_labels(
            logits=logits, labels=targets).mean()
        return loss, logits

    def grad_and_hess_gnb(self, inputs):
        """Gradient plus Gauss-Newton-Bartlett diagonal Hessian estimate.

        GNB estimator: sample labels from the model's own softmax, take the
        gradient w.r.t. the sampled labels, and use ``B * g_hat ** 2`` as the
        diagonal Hessian estimate (``B`` = batch size).  EMA is applied to
        both moments.
        """
        params, state, rng_key, args, kwargs = inputs

        grads_tree, logits = self.grad_fun(params, *args, **kwargs)

        # sample labels from the model's predictive distribution
        samples = jax.random.categorical(rng_key, logits=logits)

        # re-evaluate the gradient on (features, sampled labels)
        # TODO consider re-writing this part more elegantly
        sampled_args = (args[0], samples)
        grads_sampled_tree, _ = self.grad_fun(params, *sampled_args, **kwargs)

        # m_t: EMA of the gradient
        grads = ravel_pytree(grads_tree)[0]
        m_t = self.beta1 * state.velocity_m + (1 - self.beta1) * grads

        # v_t: EMA of the GNB estimate B * g_hat ** 2
        b = logits.shape[0]
        grads_sampled = ravel_pytree(grads_sampled_tree)[0]
        v_t = self.beta2 * state.velocity_v + (1 - self.beta2) * (
            b * grads_sampled * grads_sampled)

        return m_t, v_t

    def grad_and_hess_reuse(self, inputs):
        """Gradient EMA only; reuses the previous Hessian estimate."""
        params, state, _, args, kwargs = inputs

        grads_tree, _ = self.grad_fun(params, *args, **kwargs)

        # m_t: EMA of the gradient
        grads = ravel_pytree(grads_tree)[0]
        m_t = self.beta1 * state.velocity_m + (1 - self.beta1) * grads

        # v_t: reuse the previous estimate
        return m_t, state.velocity_v
@@ -0,0 +1,191 @@
1
+ """
2
+ Sophia-H Optimizer
3
+ Paper: https://arxiv.org/abs/2305.14342
4
+ """
5
+
6
+ from typing import Any
7
+ from typing import Callable
8
+ from typing import NamedTuple, Tuple
9
+ from typing import Optional
10
+ import dataclasses
11
+ from functools import partial
12
+
13
+ import jax
14
+ import jax.lax as lax
15
+ import jax.numpy as jnp
16
+ from jax.flatten_util import ravel_pytree
17
+ from jaxopt._src import base
18
+
19
+
20
class SophiaHState(NamedTuple):
    """Named tuple containing state information."""
    iter_num: int         # number of update steps performed so far
    stepsize: float       # step size (mirrors the configured learning rate)
    velocity_m: Any       # EMA of the flattened gradient (first moment m_t)
    velocity_v: Any       # EMA of the diagonal Hessian estimate (second moment v_t)
    hess_approx_rng: Any  # PRNG key consumed by the Hutchinson estimator
27
+
28
+
29
@dataclasses.dataclass(eq=False)
class SophiaH(base.StochasticSolver):
    """Sophia-H optimizer (https://arxiv.org/abs/2305.14342).

    Second-order stochastic solver that preconditions an EMA of the gradient
    with an EMA of a Hutchinson diagonal Hessian estimate and clips the
    resulting direction elementwise.  Works with any differentiable loss.

    Attributes:
      loss_fun: scalar loss ``loss_fun(params, *args, **kwargs)``.
      learning_rate: step size.
      eval_hess_every_k: re-estimate the Hessian diagonal every k-th iteration
        (lazy Hessian); the previous estimate is reused in between.
      beta1: EMA coefficient of the first moment (gradient).
      beta2: EMA coefficient of the second moment (Hessian diagonal).
      weight_decay: decoupled (AdamW-style) weight-decay coefficient in [0, 1).
      gamma: scaling of the Hessian estimate inside the preconditioner.
      clip_th: elementwise clipping threshold for the update direction.
      eps: lower bound on the denominator for numerical stability.
      hess_approx_seed: PRNG seed for the Rademacher probe vectors.
      pre_update: optional function executed before every update.
      verbose: verbosity level.
      jit: whether to jit the optimization loop.
      unroll: whether to unroll the optimization loop.
    """
    loss_fun: Callable

    learning_rate: float = 0.85e-3  # proposed as default by levanter

    # Lazy Hessian parameters
    eval_hess_every_k: int = 10

    # Momentum parameters
    beta1: float = 0.965  # proposed as default by levanter
    beta2: float = 0.99  # proposed as default by levanter

    # Regularization parameters
    weight_decay: float = 0.0  # L2 regularization coefficient

    gamma: float = 0.01  # clipping parameter
    clip_th: float = 1.  # clipping threshold
    eps: float = 1e-8  # term added to the denominator to improve numerical stability

    hess_approx_seed: int = 1337

    pre_update: Optional[Callable] = None

    verbose: int = 0

    jit: bool = True
    unroll: base.AutoOrBoolean = "auto"

    def __post_init__(self):
        super().__post_init__()

        self.reference_signature = self.loss_fun

        # gradient of the user-supplied scalar loss
        self.grad_fun = jax.grad(self.loss_fun)

        assert 0 <= self.weight_decay < 1, "Weight decay must be in [0, 1)"

    def update(
        self,
        params: Any,
        state: SophiaHState,
        *args,
        **kwargs,
    ) -> base.OptStep:
        """Performs one iteration of the solver.

        Args:
          params: pytree containing the parameters.
          state: named tuple containing the solver state.
          *args: additional positional arguments to be passed to ``fun``.
          **kwargs: additional keyword arguments to be passed to ``fun``.
        Returns:
          (params, state)
        """
        # ------- Step 1 -------
        # Exact gradient and diagonal Hessian estimate with lazy evaluation:
        # H is re-estimated only every k-th iteration and reused otherwise;
        # EMA (momentum) is applied to both moments.
        # !! original Sophia paper has no bias correction 1/(1 - beta ** t) !!
        next_rng_key, rng_key = jax.random.split(state.hess_approx_rng)
        inputs = (params, state, rng_key, args, kwargs)
        # BUGFIX: modern ``lax.cond`` takes (pred, true_fun, false_fun,
        # *operands); the legacy 5-argument (pred, operand, fun, operand, fun)
        # form used previously was removed from JAX.
        m_t, v_t = lax.cond(
            state.iter_num % self.eval_hess_every_k == 0,
            self.grad_and_hess_hutchinson,
            self.grad_and_hess_reuse,
            inputs,
        )

        # TODO apply bias correction for m_t?

        # ------- Step 2 -------
        # Precondition the first moment and clip the direction elementwise.
        # (Positional bounds: the ``a_min``/``a_max`` keywords are deprecated.)
        direction = jnp.clip(
            m_t / jnp.maximum(self.gamma * v_t, self.eps),
            -self.clip_th,
            self.clip_th,
        )

        # ------- Step 3 -------
        # gradient step on the flattened parameters
        params_flat, pack_fn = ravel_pytree(params)
        next_params_flat = params_flat - self.learning_rate * direction

        # ------- Step 4 -------
        # AdamW-style (decoupled) weight decay.  A plain Python branch is safe
        # under jit because ``weight_decay`` is a static dataclass attribute.
        if self.weight_decay > 0:
            next_params_flat -= self.learning_rate * self.weight_decay * params_flat

        # bookkeeping
        next_params = pack_fn(next_params_flat)
        next_state = SophiaHState(
            iter_num=state.iter_num + 1,
            stepsize=state.stepsize,
            velocity_m=m_t,
            velocity_v=v_t,
            hess_approx_rng=next_rng_key,
        )
        return base.OptStep(params=next_params, state=next_state)

    def init_state(self,
                   init_params: Any,
                   *args,
                   **kwargs) -> SophiaHState:
        """Creates the initial state: zero moments and a fresh PRNG key."""
        params_flat, _ = ravel_pytree(init_params)
        return SophiaHState(
            iter_num=0,
            stepsize=self.learning_rate,
            velocity_m=jnp.zeros_like(params_flat),
            velocity_v=jnp.zeros_like(params_flat),
            hess_approx_rng=jax.random.PRNGKey(self.hess_approx_seed),
        )

    def optimality_fun(self, params, *args, **kwargs):
        """Optimality function mapping compatible with ``@custom_root``."""
        raise NotImplementedError

    def __hash__(self):
        # We assume that the attribute values completely determine the solver.
        return hash(self.attribute_values())

    def grad_and_hess_hutchinson(self, inputs):
        """Gradient plus Hutchinson diagonal Hessian estimate.

        Hutchinson estimator: ``diag(H) ~ E[z * (H z)]`` with ``z`` a
        Rademacher vector.  ``H z`` is obtained as the jvp of the gradient
        along ``z`` (forward-over-reverse).  EMA is applied to both moments.
        """
        params, state, rng_key, args, kwargs = inputs

        # grad function closed over the current batch, accepting only params
        # (defined after ``args``/``kwargs`` are unpacked, so the closure is
        # unambiguous)
        def inner_grad_fun(p):
            return self.grad_fun(p, *args, **kwargs)

        # draw the Rademacher probe z in the flattened parameter space
        params_flat, pack_fn = ravel_pytree(params)
        z = jax.random.rademacher(rng_key, shape=params_flat.shape,
                                  dtype=jnp.float32)
        z_tree = pack_fn(z)

        # grads = dL/dp, hz = H z (jvp of the gradient along z)
        grads_tree, hz_tree = jax.jvp(inner_grad_fun, (params,), (z_tree,))
        # BUGFIX: the Hutchinson estimate is z * (H z), not grad * (H z)
        diag_hess_tree = jax.tree_map(lambda zz, hh: zz * hh, z_tree, hz_tree)

        # m_t: EMA of the gradient
        grads = ravel_pytree(grads_tree)[0]
        m_t = self.beta1 * state.velocity_m + (1 - self.beta1) * grads

        # v_t: EMA of the Hutchinson estimate
        v_t = ravel_pytree(diag_hess_tree)[0]
        v_t = self.beta2 * state.velocity_v + (1 - self.beta2) * v_t

        return m_t, v_t

    def grad_and_hess_reuse(self, inputs):
        """Gradient EMA only; reuses the previous Hessian estimate."""
        params, state, _, args, kwargs = inputs

        grads_tree = self.grad_fun(params, *args, **kwargs)

        # m_t: EMA of the gradient
        grads = ravel_pytree(grads_tree)[0]
        m_t = self.beta1 * state.velocity_m + (1 - self.beta1) * grads

        # v_t: reuse the previous estimate
        return m_t, state.velocity_v
somax/gn/__init__.py ADDED
File without changes