python-somax 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- python_somax-0.0.1.dist-info/LICENSE +201 -0
- python_somax-0.0.1.dist-info/METADATA +131 -0
- python_somax-0.0.1.dist-info/RECORD +19 -0
- python_somax-0.0.1.dist-info/WHEEL +5 -0
- python_somax-0.0.1.dist-info/top_level.txt +1 -0
- somax/__init__.py +19 -0
- somax/diagonal/__init__.py +0 -0
- somax/diagonal/adahessian.py +192 -0
- somax/diagonal/sophia_g.py +208 -0
- somax/diagonal/sophia_h.py +191 -0
- somax/gn/__init__.py +0 -0
- somax/gn/egn.py +567 -0
- somax/gn/sgn.py +237 -0
- somax/hf/__init__.py +0 -0
- somax/hf/newton_cg.py +339 -0
- somax/ng/__init__.py +0 -0
- somax/ng/swm_ng.py +411 -0
- somax/qn/__init__.py +0 -0
- somax/qn/sqn.py +311 -0
somax/ng/swm_ng.py
ADDED
|
@@ -0,0 +1,411 @@
|
|
|
1
|
+
"""
|
|
2
|
+
SWM-NG Solver
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from typing import Any
|
|
6
|
+
from typing import Callable
|
|
7
|
+
from typing import NamedTuple, Tuple
|
|
8
|
+
from typing import Optional
|
|
9
|
+
import dataclasses
|
|
10
|
+
from functools import partial
|
|
11
|
+
|
|
12
|
+
import jax
|
|
13
|
+
import jax.lax as lax
|
|
14
|
+
import jax.numpy as jnp
|
|
15
|
+
from jax.flatten_util import ravel_pytree
|
|
16
|
+
from jaxopt.tree_util import tree_add_scalar_mul
|
|
17
|
+
from jaxopt._src import base
|
|
18
|
+
from jaxopt._src import loop
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def wolfe_cond_violated(stepsize, coef, f_cur, f_next, direct_deriv):
    """Return True when the sufficient-decrease (Armijo) condition fails.

    The trial step is acceptable when the observed loss decrease
    ``f_cur - f_next`` (padded by one machine epsilon of slack) covers the
    decrease prescribed by ``-stepsize * coef * direct_deriv``.  For a descent
    direction ``direct_deriv`` is negative, so the prescribed decrease is
    positive.
    """
    slack = jnp.finfo(f_next.dtype).eps
    observed_decrease = f_cur - f_next + slack
    required_decrease = -stepsize * coef * direct_deriv
    return required_decrease > observed_decrease
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def curvature_cond_violated(stepsize, coef, f_cur, f_next, direct_deriv):
    """Return True when the loss decreased MORE than the Goldstein upper bound.

    A decrease larger than ``-stepsize * (1 - coef) * direct_deriv`` signals
    that the step size is overly conservative and may be increased.
    """
    achieved_decrease = f_cur - f_next
    upper_bound = -stepsize * (1. - coef) * direct_deriv
    return achieved_decrease > upper_bound
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def armijo_line_search(loss_fun, unroll, jit,
                       goldstein, maxls,
                       params, f_cur, stepsize,
                       direction, direct_deriv,
                       coef, decrease_factor, increase_factor, max_stepsize,
                       args, targets):
    """Backtracking (Armijo) line search along ``direction``.

    When ``goldstein`` is True, a curvature (Goldstein) check is also applied
    that can grow the step size again.

    Args:
      loss_fun: evaluated as ``loss_fun(params, *args, targets)``.
      unroll: unrolling option forwarded to ``loop.while_loop``.
      jit: whether ``loop.while_loop`` may jit its body.
      goldstein: static Python bool; enables the curvature branch.
      maxls: maximum number of line-search iterations.
      params: current parameters (pytree).
      f_cur: loss value at ``params``.
      stepsize: initial trial step size.
      direction: search direction, same pytree structure as ``params``.
      direct_deriv: directional derivative of the loss along ``direction``.
      coef: sufficient-decrease coefficient.
      decrease_factor: backtracking shrink factor applied when the
        sufficient-decrease check fails.
      increase_factor: growth factor applied when the curvature check fails
        (goldstein mode only).
      max_stepsize: upper bound the step size is clipped to after any update.
      args: extra positional inputs forwarded to ``loss_fun``.
      targets: labels/targets passed as the last argument of ``loss_fun``.

    Returns:
      Tuple ``(stepsize, next_params, f_next)`` for the accepted step.
    """
    # given direction calculate next params (trial point for the initial step)
    next_params = tree_add_scalar_mul(params, stepsize, direction)

    # calculate loss at next params
    f_next = loss_fun(next_params, *args, targets)

    # grad_sqnorm = tree_l2_norm(grad, squared=True)

    def update_stepsize(t):
        """Multiply stepsize per factor, return new params and new value."""
        stepsize, factor = t
        stepsize = stepsize * factor
        # never exceed the configured maximum step size
        stepsize = jnp.minimum(stepsize, max_stepsize)

        next_params = tree_add_scalar_mul(params, stepsize, direction)

        f_next = loss_fun(next_params, *args, targets)

        return stepsize, next_params, f_next

    def body_fun(t):
        # One line-search iteration: shrink on Armijo violation, and (in
        # goldstein mode) grow on curvature violation.
        stepsize, next_params, f_next, _ = t

        violated = wolfe_cond_violated(stepsize, coef, f_cur, f_next, direct_deriv)

        # shrink the step only when the sufficient-decrease check failed;
        # the no-op lambda ignores its operand
        stepsize, next_params, f_next = lax.cond(
            violated,
            update_stepsize,
            lambda _: (stepsize, next_params, f_next),
            operand=(stepsize, decrease_factor),
        )

        if goldstein:  # static branch: resolved at trace time
            goldstein_violated = curvature_cond_violated(
                stepsize, coef, f_cur, f_next, direct_deriv)

            # grow the step when the decrease exceeded the curvature bound
            stepsize, next_params, f_next = lax.cond(
                goldstein_violated, update_stepsize,
                lambda _: (stepsize, next_params, f_next),
                operand=(stepsize, increase_factor),
            )

            # keep looping while either condition is violated
            violated = jnp.logical_or(violated, goldstein_violated)

        return stepsize, next_params, f_next, violated

    # "violated" starts True so the body runs at least once and re-checks
    # the conditions on the initial trial point
    init_val = stepsize, next_params, f_next, jnp.array(True)

    ret = loop.while_loop(cond_fun=lambda t: t[-1],  # check boolean violated
                          body_fun=body_fun,
                          init_val=init_val, maxiter=maxls,
                          unroll=unroll, jit=jit)

    return ret[:-1]  # remove boolean
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
def flatten_2d_jacobian(jac_tree):
    """Flatten a batched Jacobian pytree into a 2D ``(batch, n_params)`` array.

    Each leaf of ``jac_tree`` carries the batch on axis 0; every per-example
    sub-tree is raveled into a single flat row via ``ravel_pytree``.
    """
    ravel_row = lambda example_tree: ravel_pytree(example_tree)[0]
    return jax.vmap(ravel_row, in_axes=(0,))(jac_tree)
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
class SWMNGState(NamedTuple):
    """Named tuple containing state information carried between solver steps."""
    # number of completed iterations
    iter_num: int
    # error: float
    # value: float
    # step size used by the (optional) line search
    stepsize: float
    # current Levenberg-Marquardt-style damping (lambda)
    regularizer: float
    # direction_inf_norm: float
    # momentum velocity (flat array); None when momentum is disabled
    velocity: Optional[Any]
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
@dataclasses.dataclass(eq=False)
class SWMNG(base.StochasticSolver):
    """SWM-NG stochastic solver (jaxopt-style).

    Computes an update direction from the batch of per-example loss gradients
    using a Woodbury-style solve (see ``calculate_direction``), optionally
    combined with an Armijo line search, momentum, and adaptive damping.
    """

    # should be provided
    # predict_fun: Callable
    # per-example loss, called as loss_fun(params, *inputs, targets)
    loss_fun: Callable

    # filled during initialization
    jac_fun: Optional[Callable] = None

    # vectorization axis for Jacobian, None - no vectorization, 0 - batch axis of the tensor
    # default value (None, 0,) matches predict_fun(params, X)
    # NOTE(review): __post_init__ unconditionally overwrites this with
    # (None, 0, 0), so a user-supplied value is silently discarded — confirm.
    jac_axis: Tuple[Optional[int], ...] = (None, 0,)

    # Loss function parameters
    loss_type: str = 'mse'  # ['mse', 'ce']

    # Either fixed alpha if line_search=False or max_alpha if line_search=True
    learning_rate: Optional[float] = None

    # expected (fixed) batch size; used to build the damping matrix at init
    batch_size: Optional[int] = None

    n_classes: Optional[int] = None

    # Line Search parameters
    line_search: bool = False

    aggressiveness: float = 0.9  # default value recommended by Vaswani et al.
    decrease_factor: float = 0.8  # default value recommended by Vaswani et al.
    increase_factor: float = 1.5  # default value recommended by Vaswani et al.
    reset_option: str = 'increase'  # ['increase', 'goldstein', 'conservative']

    max_stepsize: float = 1.0
    maxls: int = 15  # maximum number of line-search iterations

    # Adaptive Regularization parameters
    adaptive_lambda: bool = False
    regularizer: float = 1.0
    # regularizer_eps: float = 1e-5
    lambda_decrease_factor: float = 0.99  # default value recommended by Kiros
    lambda_increase_factor: float = 1.01  # default value recommended by Kiros

    # Momentum parameters
    momentum: float = 0.0  # 0 disables momentum entirely

    pre_update: Optional[Callable] = None

    verbose: int = 0

    jit: bool = True
    unroll: base.AutoOrBoolean = "auto"

    def __post_init__(self):
        """Derive jitted helpers and constant matrices from the config."""
        super().__post_init__()

        self.reference_signature = self.loss_fun

        self.jac_axis = (None, 0, 0)  # default for classification
        # per-example gradient function: vmap of grad over the batch axis
        self.jac_fun = jax.vmap(jax.grad(self.loss_fun), in_axes=self.jac_axis)
        # fixed damping matrix b * lambda * I (b x b); built once from the
        # configured batch_size, so every batch must have exactly that size
        self.regularizer_array = self.batch_size * self.regularizer * jnp.eye(self.batch_size)

        # set up line search
        if self.line_search:
            # !! learning rate is the maximum step size in case of line search
            self.max_stepsize = self.learning_rate

            options = ['increase', 'goldstein', 'conservative']

            if self.reset_option not in options:
                raise ValueError(f"'reset_option' should be one of {options}")
            if self.aggressiveness <= 0. or self.aggressiveness >= 1.:
                raise ValueError(f"'aggressiveness' must belong to open interval (0,1)")

            # sufficient-decrease coefficient used by the Armijo check
            self._coef = 1 - self.aggressiveness

            unroll = self._get_unroll_option()

            # bind static configuration; remaining args are passed per step
            armijo_with_fun = partial(armijo_line_search, self.loss_fun, unroll, self.jit)
            if self.jit:
                # args 0, 1 of the partial are (goldstein, maxls): static
                jitted_armijo = jax.jit(armijo_with_fun, static_argnums=(0, 1))
                self._armijo_line_search = jitted_armijo
            else:
                self._armijo_line_search = armijo_with_fun

    def update(
            self,
            params: Any,
            state: SWMNGState,
            *args,
            **kwargs,
    ) -> base.OptStep:
        """Performs one iteration of the solver.

        Args:
          params: pytree containing the parameters.
          state: named tuple containing the solver state.
          *args: additional positional arguments to be passed to ``fun``.
          **kwargs: additional keyword arguments to be passed to ``fun``.
        Returns:
          (params, state)
        """

        # convert pytree to JAX array (w)
        params_flat, unflatten_fn = ravel_pytree(params)

        # ---------- STEP 1: calculate direction with SWM ---------- #
        # TODO analyze *args and **kwargs
        # split (x,y) pair into (x,) and (y,):
        # targets come either as a keyword or as the LAST positional argument
        if 'targets' in kwargs:
            targets = kwargs['targets']
            nn_args = args
        else:
            targets = args[-1]
            nn_args = args[:-1]

        # direction is a flat array; J is the (b x d) per-example gradient
        # matrix; Q is an optional inner weighting matrix (None here)
        direction, grad_loss, J, Q = self.calculate_direction(params, state, targets, *nn_args)

        # ---------- STEP 2: line search for alpha ---------- #
        f_cur = None
        f_next = None
        next_params = None
        if self.line_search:
            stepsize = self.reset_stepsize(state.stepsize)

            goldstein = self.reset_option == 'goldstein'

            f_cur = self.loss_fun(params, *nn_args, targets)

            # the directional derivative used for Armijo's line search
            direct_deriv = grad_loss.T @ direction

            # line search operates on the pytree representation
            direction_packed = unflatten_fn(direction)

            stepsize, next_params, f_next = self._armijo_line_search(
                goldstein, self.maxls, params, f_cur, stepsize, direction_packed, direct_deriv, self._coef,
                self.decrease_factor, self.increase_factor, self.max_stepsize, nn_args, targets, )
        else:
            stepsize = state.stepsize

        # ---------- STEP 3: momentum acceleration ---------- #
        if self.momentum == 0:
            next_velocity = None
        else:
            # next_params = params + stepsize*direction + momentum*(params - previous_params)

            if next_params is None:
                # no line search: take the plain step in flat space
                next_params_flat = params_flat + stepsize * direction
            else:
                # line search already produced the stepped params
                next_params_flat, _ = ravel_pytree(next_params)

            next_params_flat = next_params_flat + self.momentum * state.velocity
            # velocity = total displacement of this step (incl. momentum term)
            next_velocity = next_params_flat - params_flat

            next_params = unflatten_fn(next_params_flat)

        # ! params should be "packed" in a pytree before sending to the OptStep
        if next_params is None:
            # the only case is when LS=False, Momentum=0
            next_params_flat = params_flat + stepsize * direction
            next_params = unflatten_fn(next_params_flat)

        # ---------- STEP 4: update (next step) lambda ---------- #
        # Trust-region-style damping update from the ratio of actual to
        # predicted (quadratic-model) loss decrease.
        if self.adaptive_lambda:
            # f_cur can be already computed if line search is used
            f_cur = self.loss_fun(params, *nn_args, targets) if f_cur is None else f_cur

            # if momentum is used, we need to calculate f_next again to take into account the momentum term
            f_next = self.loss_fun(next_params, *nn_args, targets) if f_next is None or self.momentum > 0 else f_next

            # in a good scenario, should be large and negative
            num = f_next - f_cur

            if self.momentum > 0:
                delta_w = next_velocity
            else:
                delta_w = stepsize * direction

            b = targets.shape[0]

            # dimensions: (b x d) @ (d x 1) = (b x 1)
            mvp = J @ delta_w
            if Q is None:
                # predicted decrease under the Gauss-Newton-style model
                denom = grad_loss.T @ delta_w + 0.5 * mvp.T @ mvp / b
            else:
                denom = grad_loss.T @ delta_w + 0.5 * mvp.T @ Q @ mvp / b

            # negative denominator means that the direction is a descent direction

            rho = num / denom

            # rho < 0.25: poor model agreement -> increase damping;
            # rho > 0.75: good agreement -> decrease damping; else keep.
            # NOTE(review): calculate_direction uses the FIXED self.regularizer
            # / self.regularizer_array, so this updated value never feeds back
            # into the direction computation — confirm intended.
            regularizer_next = lax.cond(
                rho < 0.25,
                lambda _: self.lambda_increase_factor * state.regularizer,
                lambda _: lax.cond(
                    rho > 0.75,
                    lambda _: self.lambda_decrease_factor * state.regularizer,
                    lambda _: state.regularizer,
                    None,
                ),
                None,
            )
        else:
            regularizer_next = state.regularizer

        # construct the next state
        next_state = SWMNGState(
            iter_num=state.iter_num + 1,  # Next Iteration
            stepsize=stepsize,  # Current alpha
            regularizer=regularizer_next,  # Next lambda
            velocity=next_velocity,  # Next velocity
        )

        return base.OptStep(params=next_params, state=next_state)

    def init_state(self,
                   init_params: Any,
                   *args,
                   **kwargs) -> SWMNGState:
        """Build the initial solver state for ``init_params``."""
        if self.momentum == 0:
            velocity = None
        else:
            # flat zero vector with the same total size as the parameters
            velocity = jnp.zeros_like(ravel_pytree(init_params)[0])

        return SWMNGState(
            iter_num=jnp.asarray(0),
            stepsize=jnp.asarray(self.learning_rate),
            regularizer=jnp.asarray(self.regularizer),
            velocity=velocity,
        )

    def optimality_fun(self, params, *args, **kwargs):
        """Optimality function mapping compatible with ``@custom_root``."""
        # return self._grad_fun(params, *args, **kwargs)[0]
        raise NotImplementedError

    def reset_stepsize(self, stepsize):
        """Return new step size for current step, according to reset_option."""
        if self.reset_option == 'goldstein':
            # goldstein mode grows the step inside the line search itself
            return stepsize
        if self.reset_option == 'conservative':
            # keep the last accepted step size
            return stepsize
        # 'increase': optimistically grow the step, clipped to the maximum
        stepsize = stepsize * self.increase_factor
        return jnp.minimum(stepsize, self.max_stepsize)

    def calculate_direction(self, params, state, targets, *args):
        """Compute the SWM-NG update direction for the current batch.

        Returns a tuple ``(direction, grad_loss, L, Q)`` where ``direction``
        is the flat update direction, ``grad_loss`` the flat mean gradient,
        ``L`` the (b x d) matrix of per-example gradients, and ``Q`` is None
        (no inner weighting matrix for this variant).
        """
        # pytree of per-example gradients, batch on axis 0
        batch_loss_tree = self.jac_fun(params, *args, targets)
        # (b x d): one flat gradient row per example
        L = flatten_2d_jacobian(batch_loss_tree)

        # mean gradient over the batch
        # NOTE(review): divides by the configured self.batch_size, not by
        # L.shape[0] — a smaller final batch would be mis-scaled; confirm.
        grad_loss = jnp.sum(L, axis=0) / self.batch_size

        # Woodbury-style solve in the small (b x b) space:
        # (L L^T + b*lambda*I) temp = L g
        temp = jax.scipy.linalg.solve(L @ L.T + self.regularizer_array, L @ grad_loss, assume_a='sym')
        # map back to parameter space and rescale by the damping
        direction = (L.T @ temp - grad_loss) / self.regularizer

        return direction, grad_loss, L, None

    def __hash__(self):
        # We assume that the attribute values completely determine the solver.
        return hash(self.attribute_values())
|
somax/qn/__init__.py
ADDED
|
File without changes
|