python-somax 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- python_somax-0.0.1.dist-info/LICENSE +201 -0
- python_somax-0.0.1.dist-info/METADATA +131 -0
- python_somax-0.0.1.dist-info/RECORD +19 -0
- python_somax-0.0.1.dist-info/WHEEL +5 -0
- python_somax-0.0.1.dist-info/top_level.txt +1 -0
- somax/__init__.py +19 -0
- somax/diagonal/__init__.py +0 -0
- somax/diagonal/adahessian.py +192 -0
- somax/diagonal/sophia_g.py +208 -0
- somax/diagonal/sophia_h.py +191 -0
- somax/gn/__init__.py +0 -0
- somax/gn/egn.py +567 -0
- somax/gn/sgn.py +237 -0
- somax/hf/__init__.py +0 -0
- somax/hf/newton_cg.py +339 -0
- somax/ng/__init__.py +0 -0
- somax/ng/swm_ng.py +411 -0
- somax/qn/__init__.py +0 -0
- somax/qn/sqn.py +311 -0
somax/gn/sgn.py
ADDED
|
@@ -0,0 +1,237 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Stochastic Gauss-Newton (SGN):
|
|
3
|
+
- approximate solution to the linear system via the CG method
|
|
4
|
+
- adaptive regularization (lambda)
|
|
5
|
+
Paper: https://arxiv.org/abs/2006.02409
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from typing import Any
|
|
9
|
+
from typing import Callable
|
|
10
|
+
from typing import NamedTuple, Tuple
|
|
11
|
+
from typing import Optional
|
|
12
|
+
import dataclasses
|
|
13
|
+
from functools import partial
|
|
14
|
+
|
|
15
|
+
import jax
|
|
16
|
+
import jax.lax as lax
|
|
17
|
+
import jax.numpy as jnp
|
|
18
|
+
from jax.flatten_util import ravel_pytree
|
|
19
|
+
from jax.scipy.sparse.linalg import cg
|
|
20
|
+
from optax import sigmoid_binary_cross_entropy
|
|
21
|
+
from jaxopt.tree_util import tree_add_scalar_mul, tree_scalar_mul
|
|
22
|
+
from jaxopt._src import base
|
|
23
|
+
from jaxopt._src import loop
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class SGNState(NamedTuple):
    """Named tuple containing state information."""
    iter_num: int  # number of completed solver iterations
    # error: float
    # value: float
    stepsize: float  # step size (alpha) applied to the search direction
    regularizer: float  # damping strength (lambda) added to the GN system
    # direction_inf_norm: float
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
@dataclasses.dataclass(eq=False)
class SGN(base.StochasticSolver):
    """Stochastic Gauss-Newton (SGN) solver.

    Each step approximately solves the regularized Gauss-Newton system
    ``(J^T Q J + lambda I) d = -g`` with conjugate gradients and updates
    ``params <- params + stepsize * d``.
    Paper: https://arxiv.org/abs/2006.02409
    """

    # model forward pass: predict_fun(params, *inputs) -> outputs / logits
    predict_fun: Callable

    # full loss loss_fun(params, *inputs, targets); derived from loss_type when None
    loss_fun: Optional[Callable] = None
    loss_type: str = 'mse'  # 'mse' for regression, 'ce'/'xe' for classification
    maxcg: int = 10  # maximum number of CG iterations per step

    learning_rate: float = 1.0  # default value recommended by Gargiani et al.

    # NOTE(review): gnhvp's MSE branch divides by batch_size, so this must be
    # set when loss_type == 'mse' — confirm callers always provide it.
    batch_size: Optional[int] = None

    n_classes: Optional[int] = None  # 1 selects binary cross-entropy

    # Adaptive Regularization parameters
    # (the adaptation itself is still a TODO in update(); lambda stays constant)
    adaptive_lambda: bool = False
    regularizer: float = 1e-3  # default value recommended by Gargiani et al.
    lambda_decrease_factor: float = 0.99  # default value recommended by Kiros
    lambda_increase_factor: float = 1.01  # default value recommended by Kiros

    pre_update: Optional[Callable] = None

    verbose: int = 0

    jit: bool = True
    unroll: base.AutoOrBoolean = "auto"

    def __post_init__(self):
        super().__post_init__()

        self.reference_signature = self.loss_fun

        # Regression (MSE)
        if self.loss_type == 'mse':
            if self.loss_fun is None:  # mostly the case for supervised learning
                self.loss_fun = self.mse

            self.loss_fun_from_preds = self.mse_from_predictions

        # Classification (Cross-Entropy)
        elif self.loss_type == 'ce' or self.loss_type == 'xe':
            if self.n_classes == 1:  # binary classification
                if self.loss_fun is None:  # mostly the case for supervised learning
                    self.loss_fun = self.ce_binary

                self.loss_fun_from_preds = self.ce_binary_from_logits

            else:
                if self.loss_fun is None:  # mostly the case for supervised learning
                    self.loss_fun = self.ce

                self.loss_fun_from_preds = self.ce_from_logits

        else:
            # Fail fast: previously an unrecognized loss_type silently left
            # `loss_fun_from_preds` unset (and possibly `loss_fun` as None),
            # producing an opaque failure later in jax.grad / gnhvp.
            raise ValueError(
                f"Unknown loss_type {self.loss_type!r}; expected 'mse', 'ce' or 'xe'.")

        self.grad_fun = jax.grad(self.loss_fun)

    def update(
            self,
            params: Any,
            state: SGNState,
            *args,
            **kwargs,
    ) -> base.OptStep:
        """Performs one iteration of the solver.

        Args:
          params: pytree containing the parameters.
          state: named tuple containing the solver state.
          *args: additional positional arguments to be passed to ``fun``.
          **kwargs: additional keyword arguments to be passed to ``fun``.
        Returns:
          (params, state)
        """

        # ---------- STEP 1: calculate direction with CG ---------- #
        # split the (x, y) pair into network inputs and targets
        if 'targets' in kwargs:
            targets = kwargs['targets']
            nn_args = args
        else:
            # by convention the last positional argument holds the targets
            targets = args[-1]
            nn_args = args[:-1]

        direction_tree, grad_loss_tree = self.calculate_direction(params, state, targets, *nn_args)

        # ---------- STEP 2: update (next step) lambda ---------- #
        # TODO: adaptive lambda is not implemented yet; regularizer is constant.

        next_params = tree_add_scalar_mul(params, state.stepsize, direction_tree)

        # construct the next state
        next_state = SGNState(
            iter_num=state.iter_num + 1,  # Next Iteration
            stepsize=state.stepsize,  # Current alpha
            regularizer=state.regularizer,  # Next lambda
        )

        return base.OptStep(params=next_params, state=next_state)

    def init_state(self,
                   init_params: Any,
                   *args,
                   **kwargs) -> SGNState:
        """Create the initial solver state from the configured hyperparameters."""
        return SGNState(
            iter_num=0,
            stepsize=self.learning_rate,
            regularizer=self.regularizer,
        )

    def optimality_fun(self, params, *args, **kwargs):
        """Optimality function mapping compatible with ``@custom_root``."""
        # return self._grad_fun(params, *args, **kwargs)[0]
        raise NotImplementedError

    def gnhvp(self, params, vec, targets, *args):
        """Gauss-Newton Hessian-vector product ``(J^T Q J) v`` without forming J."""

        def inner_predict_fun(params):
            return self.predict_fun(params, *args)

        def hvp(network_outputs, vec, targets):
            # prep grad function to accept only "network_outputs"
            def inner_grad_fun(network_outputs):
                return jax.grad(self.loss_fun_from_preds)(network_outputs, targets)

            # Qv: Hessian of the loss w.r.t. network outputs applied to vec
            return jax.jvp(inner_grad_fun, (network_outputs,), (vec,))[1]

        # H_GN ~ J^T Q J, s.t. hvp is J^T Q J v
        # forward-mode pass gives both the outputs and Jv in one sweep
        network_outputs, Jv = jax.jvp(inner_predict_fun, (params,), (vec,))

        if self.loss_type == 'mse':
            # for MSE, Q is (a multiple of) the identity, so only J^T J v is needed
            _, JTJv_fun = jax.vjp(inner_predict_fun, params)
            JTJv = JTJv_fun(Jv)[0]
            return tree_scalar_mul(1 / self.batch_size, JTJv)
        else:
            QJv = hvp(network_outputs, Jv, targets)
            _, JTQJv_fun = jax.vjp(inner_predict_fun, params)
            JTQJv = JTQJv_fun(QJv)[0]
            return JTQJv

    def calculate_direction(self, params, state, targets, *args):
        """Solve ``(H_GN + lambda I) d = -grad`` approximately with CG.

        Returns the direction pytree and the gradient pytree.
        """

        def mvp(vec):
            # H_{GN}v
            hv = self.gnhvp(params, vec, targets, *args)
            # add regularization, works since (H + lambda*I) v = Hv + lambda*v
            return tree_add_scalar_mul(hv, state.regularizer, vec)

        # --------- Start Here --------- #
        # calculate grad
        grad_tree = self.grad_fun(params, *args, targets)

        # CG iterations
        # TODO initial guess and preconditioner
        direction, _ = cg(
            A=mvp,
            b=tree_scalar_mul(-1, grad_tree),
            maxiter=self.maxcg,
            # x0=None,  # initial guess
            # M=None,  # preconditioner
        )

        return direction, grad_tree

    def mse(self, params, x, y):
        """MSE loss evaluated from parameters (forward pass + mse_from_predictions)."""
        return self.mse_from_predictions(self.predict_fun(params, x), y)

    @staticmethod
    def mse_from_predictions(preds, y):
        """Half mean squared error between predictions and targets."""
        return 0.5 * jnp.mean(jnp.square(y - preds))

    def ce(self, params, x, y):
        """Multi-class cross-entropy evaluated from parameters."""
        logits = self.predict_fun(params, x)
        return self.ce_from_logits(logits, y)

    @staticmethod
    def ce_from_logits(logits, y):
        # b x C
        # jax.nn.log_softmax combines exp() and log() in a numerically stable way.
        log_probs = jax.nn.log_softmax(logits)

        # b x 1
        # if y is one-hot encoded, this operation picks the log probability of the correct class
        residuals = jnp.sum(y * log_probs, axis=-1)

        # 1,
        # average over the batch
        return -jnp.mean(residuals)

    def ce_binary(self, params, x, y):
        """Binary cross-entropy evaluated from parameters."""
        logits = self.predict_fun(params, x)
        return self.ce_binary_from_logits(logits, y)

    def ce_binary_from_logits(self, logits, y):
        # b x 1
        loss = sigmoid_binary_cross_entropy(logits.ravel(), y)

        # 1,
        # average over the batch
        return jnp.mean(loss)

    def __hash__(self):
        # We assume that the attribute values completely determine the solver.
        return hash(self.attribute_values())
|
somax/hf/__init__.py
ADDED
|
File without changes
|
somax/hf/newton_cg.py
ADDED
|
@@ -0,0 +1,339 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Newton-CG Solver:
|
|
3
|
+
- approximate solution to the linear system via the CG method
|
|
4
|
+
- adaptive regularization (lambda)
|
|
5
|
+
- line search
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from typing import Any
|
|
9
|
+
from typing import Callable
|
|
10
|
+
from typing import NamedTuple, Tuple
|
|
11
|
+
from typing import Optional
|
|
12
|
+
import dataclasses
|
|
13
|
+
from functools import partial
|
|
14
|
+
|
|
15
|
+
import jax
|
|
16
|
+
import jax.lax as lax
|
|
17
|
+
import jax.numpy as jnp
|
|
18
|
+
from jax.flatten_util import ravel_pytree
|
|
19
|
+
from jax.scipy.sparse.linalg import cg
|
|
20
|
+
from jaxopt.tree_util import tree_add_scalar_mul, tree_scalar_mul
|
|
21
|
+
from jaxopt._src import base
|
|
22
|
+
from jaxopt._src import loop
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def wolfe_cond_violated(stepsize, coef, f_cur, f_next, direct_deriv):
    """Return True when the Armijo sufficient-decrease condition is violated.

    The achieved decrease ``f_cur - f_next`` (padded by one machine epsilon of
    ``f_next``'s dtype to absorb round-off) must be at least the decrease
    prescribed by the directional derivative, ``-stepsize * coef * direct_deriv``.
    """
    machine_eps = jnp.finfo(f_next.dtype).eps
    achieved = (f_cur - f_next) + machine_eps
    required = -(stepsize * coef) * direct_deriv
    return achieved < required
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def curvature_cond_violated(stepsize, coef, f_cur, f_next, direct_deriv):
    """Return True when the Goldstein curvature (upper-bound) condition fails.

    The achieved decrease ``f_cur - f_next`` must not exceed the cap
    ``-stepsize * (1 - coef) * direct_deriv``; exceeding it means the step
    was too conservative and may be lengthened.
    """
    achieved = f_cur - f_next
    cap = -(1. - coef) * stepsize * direct_deriv
    return achieved > cap
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def armijo_line_search(loss_fun, unroll, jit,
                       goldstein, maxls,
                       params, f_cur, stepsize,
                       direction, direct_deriv,
                       coef, decrease_factor, increase_factor, max_stepsize,
                       args, targets):
    """Backtracking Armijo line search (optionally two-sided Goldstein).

    Starting from ``stepsize``, repeatedly shrinks the step by
    ``decrease_factor`` while the sufficient-decrease condition is violated;
    with ``goldstein=True`` it also grows the step by ``increase_factor`` when
    the curvature condition is violated. Runs at most ``maxls`` iterations via
    jaxopt's ``loop.while_loop``.

    Returns the tuple ``(stepsize, next_params, f_next)``.

    NOTE(review): the loop starts with the violated flag hard-coded to True,
    so at least one body iteration always runs — presumably intentional so the
    conditions are checked at least once; confirm.
    """
    # given direction calculate next params
    next_params = tree_add_scalar_mul(params, stepsize, direction)

    # calculate loss at next params
    f_next = loss_fun(next_params, *args, targets)

    # grad_sqnorm = tree_l2_norm(grad, squared=True)

    def update_stepsize(t):
        """Multiply stepsize per factor, return new params and new value."""
        stepsize, factor = t
        stepsize = stepsize * factor
        # never exceed the configured maximum step size
        stepsize = jnp.minimum(stepsize, max_stepsize)

        next_params = tree_add_scalar_mul(params, stepsize, direction)

        f_next = loss_fun(next_params, *args, targets)

        return stepsize, next_params, f_next

    def body_fun(t):
        # one line-search iteration: test conditions, rescale step if needed
        stepsize, next_params, f_next, _ = t

        violated = wolfe_cond_violated(stepsize, coef, f_cur, f_next, direct_deriv)

        # shrink the step only when the Armijo condition is violated
        stepsize, next_params, f_next = lax.cond(
            violated,
            update_stepsize,
            lambda _: (stepsize, next_params, f_next),
            operand=(stepsize, decrease_factor),
        )

        # Python-level branch: `goldstein` is static configuration, so this
        # selects which graph gets traced, not a runtime condition.
        if goldstein:
            goldstein_violated = curvature_cond_violated(
                stepsize, coef, f_cur, f_next, direct_deriv)

            # grow the step when the curvature (upper-bound) condition fails
            stepsize, next_params, f_next = lax.cond(
                goldstein_violated, update_stepsize,
                lambda _: (stepsize, next_params, f_next),
                operand=(stepsize, increase_factor),
            )

            # keep looping while either condition is violated
            violated = jnp.logical_or(violated, goldstein_violated)

        return stepsize, next_params, f_next, violated

    # seed the loop with "violated" so the body executes at least once
    init_val = stepsize, next_params, f_next, jnp.array(True)

    ret = loop.while_loop(cond_fun=lambda t: t[-1],  # check boolean violated
                          body_fun=body_fun,
                          init_val=init_val, maxiter=maxls,
                          unroll=unroll, jit=jit)

    return ret[:-1]  # remove boolean
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
class NewtonCGState(NamedTuple):
    """Named tuple containing state information."""
    iter_num: int  # number of completed solver iterations
    stepsize: float  # current step size (alpha)
    regularizer: float  # current damping strength (lambda)
    cg_guess: Optional[Any]  # warm-start guess for the next CG solve (None initially)
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
@dataclasses.dataclass(eq=False)
class NewtonCG(base.StochasticSolver):
    """Newton-CG solver.

    Approximately solves the regularized Newton system
    ``(H + lambda I) d = -g`` with conjugate gradients (warm-started from the
    decayed previous direction), optionally chooses the step size via an
    Armijo/Goldstein line search, and adapts lambda with a trust-region-style
    ratio test.
    """

    # loss to minimize: loss_fun(params, *inputs, targets) -> scalar
    loss_fun: Callable
    maxcg: int = 3  # maximum number of CG iterations per step

    # Either fixed alpha if line_search=False or max_alpha if line_search=True
    learning_rate: float = 1.0

    batch_size: Optional[int] = None

    n_classes: Optional[int] = None

    # Adaptive Regularization parameters
    adaptive_lambda: bool = True
    regularizer: float = 1.0  # initial lambda
    lambda_decrease_factor: float = 0.99  # default value recommended by Kiros
    lambda_increase_factor: float = 1.01  # default value recommended by Kiros

    # Line Search parameters
    line_search: bool = True

    aggressiveness: float = 0.9  # default value recommended by Vaswani et al.
    decrease_factor: float = 0.8  # default value recommended by Vaswani et al.
    increase_factor: float = 1.5  # default value recommended by Vaswani et al.
    reset_option: str = 'increase'  # ['increase', 'goldstein', 'conservative']

    max_stepsize: float = 1.0
    maxls: int = 15  # maximum number of line-search iterations

    pre_update: Optional[Callable] = None

    verbose: int = 0

    jit: bool = True
    unroll: base.AutoOrBoolean = "auto"

    # Decay applied to the previous direction when warm-starting CG.
    # FIX: this was `decay_coef = 0.1` without a type annotation, so
    # dataclasses treated it as a plain class attribute rather than a field
    # and it could not be configured through the constructor, unlike every
    # other hyperparameter. Declared last so existing positional/keyword
    # constructor calls keep working unchanged.
    decay_coef: float = 0.1

    def __post_init__(self):
        super().__post_init__()

        self.reference_signature = self.loss_fun
        self.grad_fun = jax.grad(self.loss_fun)

        # set up line search
        if self.line_search:
            # !! learning rate is the maximum step size in case of line search
            self.max_stepsize = self.learning_rate

            options = ['increase', 'goldstein', 'conservative']

            if self.reset_option not in options:
                raise ValueError(f"'reset_option' should be one of {options}")
            if self.aggressiveness <= 0. or self.aggressiveness >= 1.:
                raise ValueError(f"'aggressiveness' must belong to open interval (0,1)")

            self._coef = 1 - self.aggressiveness

            unroll = self._get_unroll_option()

            armijo_with_fun = partial(armijo_line_search, self.loss_fun, unroll, self.jit)
            if self.jit:
                # goldstein flag and maxls are trace-time constants
                jitted_armijo = jax.jit(armijo_with_fun, static_argnums=(0, 1))
                self._armijo_line_search = jitted_armijo
            else:
                self._armijo_line_search = armijo_with_fun

    def update(
            self,
            params: Any,
            state: NewtonCGState,
            *args,
            **kwargs,
    ) -> base.OptStep:
        """Performs one iteration of the solver.

        Args:
          params: pytree containing the parameters.
          state: named tuple containing the solver state.
          *args: additional positional arguments to be passed to ``fun``.
          **kwargs: additional keyword arguments to be passed to ``fun``.
        Returns:
          (params, state)
        """

        # ---------- STEP 1: calculate direction with CG ---------- #
        # split the (x, y) pair into network inputs and targets
        if 'targets' in kwargs:
            targets = kwargs['targets']
            nn_args = args
        else:
            # by convention the last positional argument holds the targets
            targets = args[-1]
            nn_args = args[:-1]

        direction_tree, grad_loss_tree = self.calculate_direction(params, state, targets, *nn_args)

        # ---------- STEP 2: line search for alpha ---------- #
        # cache values so the adaptive-lambda step below avoids recomputation
        f_cur = None
        f_next = None
        direct_deriv = None
        if not self.line_search:
            # constant learning rate
            stepsize = state.stepsize
            next_params = tree_add_scalar_mul(params, state.stepsize, direction_tree)
        else:
            stepsize = self.reset_stepsize(state.stepsize)

            goldstein = self.reset_option == 'goldstein'

            f_cur = self.loss_fun(params, *nn_args, targets)

            # the directional derivative used for Armijo's line search
            direction, _ = ravel_pytree(direction_tree)
            grad_loss, _ = ravel_pytree(grad_loss_tree)
            direct_deriv = grad_loss.T @ direction

            stepsize, next_params, f_next = self._armijo_line_search(
                goldstein, self.maxls, params, f_cur, stepsize,
                direction_tree, direct_deriv, self._coef,
                self.decrease_factor, self.increase_factor,
                self.max_stepsize, nn_args, targets,
            )

        # ---------- STEP 3: update (next step) lambda ---------- #
        if not self.adaptive_lambda:
            # constant regularization
            regularizer_next = state.regularizer
        else:
            # numerator: in a good scenario, should be large and negative
            if f_cur is None:
                f_cur = self.loss_fun(params, *nn_args, targets)
            if f_next is None:
                f_next = self.loss_fun(next_params, *nn_args, targets)

            num = f_next - f_cur

            # denominator: predicted decrease of the quadratic model
            Hv_tree = self.hvp(params, direction_tree, targets, *nn_args)

            # flattening stage
            Hv, _ = ravel_pytree(Hv_tree)

            if direct_deriv is None:
                direction, _ = ravel_pytree(direction_tree)
                grad_loss, _ = ravel_pytree(grad_loss_tree)
                direct_deriv = grad_loss.T @ direction

            denom = 0.5 * jnp.vdot(direction, Hv) + direct_deriv

            # negative denominator means that the direction is a descent direction
            # rho = actual reduction / predicted reduction (trust-region test)
            rho = num / denom

            # poor agreement -> increase lambda; strong agreement -> decrease it
            regularizer_next = lax.cond(
                rho < 0.25,
                lambda _: self.lambda_increase_factor * state.regularizer,
                lambda _: lax.cond(
                    rho > 0.75,
                    lambda _: self.lambda_decrease_factor * state.regularizer,
                    lambda _: state.regularizer,
                    None,
                ),
                None,
            )

        # construct the next state
        next_state = NewtonCGState(
            iter_num=state.iter_num + 1,  # Next Iteration
            stepsize=stepsize,  # Current alpha
            regularizer=regularizer_next,  # Next lambda
            cg_guess=tree_scalar_mul(self.decay_coef, direction_tree),  # Next CG guess
        )

        return base.OptStep(params=next_params, state=next_state)

    def init_state(self,
                   init_params: Any,
                   *args,
                   **kwargs) -> NewtonCGState:
        """Create the initial solver state from the configured hyperparameters."""
        return NewtonCGState(
            iter_num=0,
            stepsize=self.learning_rate,
            regularizer=self.regularizer,
            cg_guess=None,
        )

    def optimality_fun(self, params, *args, **kwargs):
        """Optimality function mapping compatible with ``@custom_root``."""
        # return self._grad_fun(params, *args, **kwargs)[0]
        raise NotImplementedError

    def reset_stepsize(self, stepsize):
        """Return new step size for current step, according to reset_option."""
        # 'goldstein' and 'conservative' keep the previous step size;
        # 'increase' optimistically grows it (capped at max_stepsize).
        if self.reset_option in ('goldstein', 'conservative'):
            return stepsize
        stepsize = stepsize * self.increase_factor
        return jnp.minimum(stepsize, self.max_stepsize)

    def hvp(self, params, vec, targets, *args):
        """Exact Hessian-vector product via forward-over-reverse differentiation."""
        # prep grad function to accept only "params"
        def inner_grad_fun(params):
            return self.grad_fun(params, *args, targets)

        return jax.jvp(inner_grad_fun, (params,), (vec,))[1]

    def calculate_direction(self, params, state, targets, *args):
        """Solve ``(H + lambda I) d = -grad`` approximately with warm-started CG.

        Returns the direction pytree and the gradient pytree.
        """

        def mvp(vec):
            # Hv
            hv = self.hvp(params, vec, targets, *args)
            # add regularization, works since (H + lambda*I) v = Hv + lambda*v
            return tree_add_scalar_mul(hv, state.regularizer, vec)

        # --------- Start Here --------- #
        # calculate grad
        grad_tree = self.grad_fun(params, *args, targets)

        # CG iterations
        direction, _ = cg(
            A=mvp,
            b=tree_scalar_mul(-1, grad_tree),
            maxiter=self.maxcg,
            x0=state.cg_guess,  # initial guess
            # M=None,  # preconditioner (see wiesler2013: preconditioner is not important)
        )

        return direction, grad_tree

    def __hash__(self):
        # We assume that the attribute values completely determine the solver.
        return hash(self.attribute_values())
|
somax/ng/__init__.py
ADDED
|
File without changes
|