python-somax 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
somax/gn/sgn.py ADDED
@@ -0,0 +1,237 @@
1
+ """
2
+ Stochastic Gauss-Newton (SGN):
3
+ - approximate solution to the linear system via the CG method
4
+ - adaptive regularization (lambda)
5
+ Paper: https://arxiv.org/abs/2006.02409
6
+ """
7
+
8
+ from typing import Any
9
+ from typing import Callable
10
+ from typing import NamedTuple, Tuple
11
+ from typing import Optional
12
+ import dataclasses
13
+ from functools import partial
14
+
15
+ import jax
16
+ import jax.lax as lax
17
+ import jax.numpy as jnp
18
+ from jax.flatten_util import ravel_pytree
19
+ from jax.scipy.sparse.linalg import cg
20
+ from optax import sigmoid_binary_cross_entropy
21
+ from jaxopt.tree_util import tree_add_scalar_mul, tree_scalar_mul
22
+ from jaxopt._src import base
23
+ from jaxopt._src import loop
24
+
25
+
26
class SGNState(NamedTuple):
    """Named tuple containing state information."""
    # number of update() iterations performed so far
    iter_num: int
    # error: float
    # value: float
    # step size (alpha) applied to the computed direction
    stepsize: float
    # damping term (lambda) added to the Gauss-Newton matrix before CG
    regularizer: float
    # direction_inf_norm: float
34
+
35
+
36
@dataclasses.dataclass(eq=False)
class SGN(base.StochasticSolver):
    """Stochastic Gauss-Newton (SGN) solver.

    Each update approximately solves the damped Gauss-Newton system
    ``(J^T Q J + lambda * I) d = -g`` with conjugate gradients (CG) and
    steps ``params <- params + alpha * d``.
    Paper: https://arxiv.org/abs/2006.02409
    """
    # model forward pass: predict_fun(params, *inputs) -> predictions / logits
    predict_fun: Callable

    # loss as a function of params; if None, synthesized below from loss_type
    loss_fun: Optional[Callable] = None
    # 'mse' for regression; 'ce' / 'xe' for cross-entropy classification
    loss_type: str = 'mse'
    # maximum number of CG iterations per update
    maxcg: int = 10

    learning_rate: float = 1.0  # default value recommended by Gargiani et al.

    # used to scale the MSE Gauss-Newton product in gnhvp()
    batch_size: Optional[int] = None

    # n_classes == 1 selects the binary cross-entropy path
    n_classes: Optional[int] = None

    # Adaptive Regularization parameters
    adaptive_lambda: bool = False
    regularizer: float = 1e-3  # default value recommended by Gargiani et al.
    lambda_decrease_factor: float = 0.99  # default value recommended by Kiros
    lambda_increase_factor: float = 1.01  # default value recommended by Kiros

    # optional hook invoked by the jaxopt run loop before each update
    pre_update: Optional[Callable] = None

    verbose: int = 0

    jit: bool = True
    unroll: base.AutoOrBoolean = "auto"

    def __post_init__(self):
        """Resolve the loss functions from loss_type and build the gradient."""
        super().__post_init__()

        self.reference_signature = self.loss_fun

        # Regression (MSE)
        if self.loss_type == 'mse':
            if self.loss_fun is None:  # mostly the case for supervised learning
                self.loss_fun = self.mse

            # loss as a function of model outputs only (used inside gnhvp)
            self.loss_fun_from_preds = self.mse_from_predictions

        # Classification (Cross-Entropy)
        elif self.loss_type == 'ce' or self.loss_type == 'xe':
            if self.n_classes == 1:  # binary classification
                if self.loss_fun is None:  # mostly the case for supervised learning
                    self.loss_fun = self.ce_binary

                self.loss_fun_from_preds = self.ce_binary_from_logits

            else:
                if self.loss_fun is None:  # mostly the case for supervised learning
                    self.loss_fun = self.ce

                self.loss_fun_from_preds = self.ce_from_logits

        # NOTE(review): for any other loss_type, loss_fun_from_preds is never
        # assigned (and loss_fun may still be None, making jax.grad fail
        # below) — consider raising ValueError for unknown loss_type.
        self.grad_fun = jax.grad(self.loss_fun)

    def update(
        self,
        params: Any,
        state: SGNState,
        *args,
        **kwargs,
    ) -> base.OptStep:
        """Performs one iteration of the solver.

        Args:
          params: pytree containing the parameters.
          state: named tuple containing the solver state.
          *args: additional positional arguments to be passed to ``fun``.
          **kwargs: additional keyword arguments to be passed to ``fun``.
        Returns:
          (params, state)
        """

        # ---------- STEP 1: calculate direction with DG ---------- #
        # TODO analyze *args and **kwargs
        # split (x,y) pair into (x,) and (y,)
        if 'targets' in kwargs:
            targets = kwargs['targets']
            nn_args = args
        else:
            # convention: the last positional argument holds the targets
            targets = args[-1]
            nn_args = args[:-1]

        direction_tree, grad_loss_tree = self.calculate_direction(params, state, targets, *nn_args)

        # # ---------- STEP 2: update (next step) lambda ---------- #
        # TODO

        # fixed step: params + stepsize * direction
        next_params = tree_add_scalar_mul(params, state.stepsize, direction_tree)

        # construct the next state
        next_state = SGNState(
            iter_num=state.iter_num + 1,  # Next Iteration
            stepsize=state.stepsize,  # Current alpha
            regularizer=state.regularizer,  # Next lambda
        )

        return base.OptStep(params=next_params, state=next_state)

    def init_state(self,
                   init_params: Any,
                   *args,
                   **kwargs) -> SGNState:
        """Create the initial solver state (iteration 0, configured hyperparams)."""
        return SGNState(
            iter_num=0,
            stepsize=self.learning_rate,
            regularizer=self.regularizer,
        )

    def optimality_fun(self, params, *args, **kwargs):
        """Optimality function mapping compatible with ``@custom_root``."""
        # return self._grad_fun(params, *args, **kwargs)[0]
        raise NotImplementedError

    def gnhvp(self, params, vec, targets, *args):
        """Gauss-Newton Hessian-vector product: returns (J^T Q J) v.

        J is the Jacobian of predict_fun w.r.t. params; Q is the Hessian of
        the loss w.r.t. the network outputs. Computed matrix-free via
        jvp/vjp — the GN matrix is never materialized.
        """
        def inner_predict_fun(params):
            return self.predict_fun(params, *args)

        def hvp(network_outputs, vec, targets):
            # prep grad function to accept only "network_outputs"
            def inner_grad_fun(network_outputs):
                return jax.grad(self.loss_fun_from_preds)(network_outputs, targets)

            # forward-over-reverse: Q v as a jvp of the output-space gradient
            return jax.jvp(inner_grad_fun, (network_outputs,), (vec,))[1]

        # H_GN ~ J^T Q J, s.t. hvp is J^T Q J v
        network_outputs, Jv = jax.jvp(inner_predict_fun, (params,), (vec,))

        if self.loss_type == 'mse':
            # for MSE, Q is a scaled identity, so (J^T Q J) v reduces to
            # a scaled J^T (J v)
            _, JTJv_fun = jax.vjp(inner_predict_fun, params)
            JTJv = JTJv_fun(Jv)[0]
            # NOTE(review): mse_from_predictions averages over ALL elements
            # (jnp.mean), so 1/batch_size matches only for one scalar output
            # per example — confirm for multi-output models.
            return tree_scalar_mul(1 / self.batch_size, JTJv)
        else:
            QJv = hvp(network_outputs, Jv, targets)
            _, JTQJv_fun = jax.vjp(inner_predict_fun, params)
            JTQJv = JTQJv_fun(QJv)[0]
            return JTQJv

    def calculate_direction(self, params, state, targets, *args):
        """Solve (H_GN + lambda I) d = -grad approximately with CG.

        Returns:
          (direction pytree, gradient pytree)
        """
        def mvp(vec):
            # H_{GN}v
            # hv = jax.jvp(jax.grad(stand_alone_loss_fn), (params,), (vec,))[1]
            hv = self.gnhvp(params, vec, targets, *args)
            # add regularization, works since (H + lambda*I) v = Hv + lambda*v
            return tree_add_scalar_mul(hv, state.regularizer, vec)

        # --------- Start Here --------- #
        # calculate grad
        grad_tree = self.grad_fun(params, *args, targets)

        # CG iterations
        # TODO initial guess and preconditioner
        direction, _ = cg(
            A=mvp,
            b=tree_scalar_mul(-1, grad_tree),
            maxiter=self.maxcg,
            # x0=None, # initial guess
            # M=None, # preconditioner
        )

        return direction, grad_tree

    def mse(self, params, x, y):
        """Mean-squared-error loss as a function of params."""
        return self.mse_from_predictions(self.predict_fun(params, x), y)

    @staticmethod
    def mse_from_predictions(preds, y):
        """0.5 * mean squared residual between targets y and predictions."""
        return 0.5 * jnp.mean(jnp.square(y - preds))

    def ce(self, params, x, y):
        """Multi-class cross-entropy loss as a function of params."""
        logits = self.predict_fun(params, x)
        return self.ce_from_logits(logits, y)

    @staticmethod
    def ce_from_logits(logits, y):
        """Cross-entropy from raw logits; y is expected one-hot encoded."""
        # b x C
        # jax.nn.log_softmax combines exp() and log() in a numerically stable way.
        log_probs = jax.nn.log_softmax(logits)

        # b x 1
        # if y is one-hot encoded, this operation picks the log probability of the correct class
        residuals = jnp.sum(y * log_probs, axis=-1)

        # 1,
        # average over the batch
        return -jnp.mean(residuals)

    def ce_binary(self, params, x, y):
        """Binary cross-entropy loss as a function of params."""
        logits = self.predict_fun(params, x)
        return self.ce_binary_from_logits(logits, y)

    def ce_binary_from_logits(self, logits, y):
        """Sigmoid binary cross-entropy from raw (unsquashed) logits."""
        # b x 1
        loss = sigmoid_binary_cross_entropy(logits.ravel(), y)

        # 1,
        # average over the batch
        return jnp.mean(loss)

    def __hash__(self):
        # We assume that the attribute values completely determine the solver.
        return hash(self.attribute_values())
somax/hf/__init__.py ADDED
File without changes
somax/hf/newton_cg.py ADDED
@@ -0,0 +1,339 @@
1
+ """
2
+ Newton-CG Solver:
3
+ - approximate solution to the linear system via the CG method
4
+ - adaptive regularization (lambda)
5
+ - line search
6
+ """
7
+
8
+ from typing import Any
9
+ from typing import Callable
10
+ from typing import NamedTuple, Tuple
11
+ from typing import Optional
12
+ import dataclasses
13
+ from functools import partial
14
+
15
+ import jax
16
+ import jax.lax as lax
17
+ import jax.numpy as jnp
18
+ from jax.flatten_util import ravel_pytree
19
+ from jax.scipy.sparse.linalg import cg
20
+ from jaxopt.tree_util import tree_add_scalar_mul, tree_scalar_mul
21
+ from jaxopt._src import base
22
+ from jaxopt._src import loop
23
+
24
+
25
def wolfe_cond_violated(stepsize, coef, f_cur, f_next, direct_deriv):
    """Check failure of the Armijo sufficient-decrease condition.

    The step is acceptable when the observed decrease ``f_cur - f_next``
    (padded by machine epsilon for float tolerance) is at least
    ``-stepsize * coef * direct_deriv``. Returns True when it is not.
    """
    machine_eps = jnp.finfo(f_next.dtype).eps
    required_decrease = -stepsize * coef * direct_deriv
    achieved_decrease = (f_cur - f_next) + machine_eps
    return required_decrease > achieved_decrease
30
+
31
+
32
def curvature_cond_violated(stepsize, coef, f_cur, f_next, direct_deriv):
    """Check failure of the Goldstein upper-bound (curvature) condition.

    The step is acceptable when the observed decrease ``f_cur - f_next``
    does not exceed ``-stepsize * (1 - coef) * direct_deriv``; a larger
    decrease signals an overly conservative step. Returns True on failure.
    """
    upper_bound = -stepsize * (1. - coef) * direct_deriv
    actual_decrease = f_cur - f_next
    return actual_decrease > upper_bound
36
+
37
+
38
def armijo_line_search(loss_fun, unroll, jit,
                       goldstein, maxls,
                       params, f_cur, stepsize,
                       direction, direct_deriv,
                       coef, decrease_factor, increase_factor, max_stepsize,
                       args, targets):
    """Backtracking Armijo line search along ``direction``.

    Shrinks (and, with ``goldstein=True``, also grows) ``stepsize`` until the
    sufficient-decrease condition — and optionally the Goldstein upper bound —
    holds, for at most ``maxls`` iterations of a jaxopt while-loop.

    Args:
      loss_fun: loss called as ``loss_fun(params, *args, targets)``.
      unroll, jit: options forwarded to ``loop.while_loop``.
      goldstein: whether to additionally enforce the Goldstein condition.
      maxls: maximum number of line-search iterations.
      params: current parameter pytree.
      f_cur: loss at ``params``.
      stepsize: initial step size to try.
      direction: descent direction pytree.
      direct_deriv: directional derivative grad^T direction (expected <= 0).
      coef, decrease_factor, increase_factor, max_stepsize: Armijo constants.
      args, targets: extra inputs forwarded to ``loss_fun``.

    Returns:
      (stepsize, next_params, f_next) for the accepted step.
    """
    # given direction calculate next params
    next_params = tree_add_scalar_mul(params, stepsize, direction)

    # calculate loss at next params
    f_next = loss_fun(next_params, *args, targets)

    # grad_sqnorm = tree_l2_norm(grad, squared=True)

    def update_stepsize(t):
        """Multiply stepsize per factor, return new params and new value."""
        stepsize, factor = t
        stepsize = stepsize * factor
        # never exceed the configured maximum step size
        stepsize = jnp.minimum(stepsize, max_stepsize)

        next_params = tree_add_scalar_mul(params, stepsize, direction)

        f_next = loss_fun(next_params, *args, targets)

        return stepsize, next_params, f_next

    def body_fun(t):
        # loop state: (stepsize, next_params, f_next, violated)
        stepsize, next_params, f_next, _ = t

        violated = wolfe_cond_violated(stepsize, coef, f_cur, f_next, direct_deriv)

        # shrink the step only if the Armijo condition failed
        stepsize, next_params, f_next = lax.cond(
            violated,
            update_stepsize,
            lambda _: (stepsize, next_params, f_next),
            operand=(stepsize, decrease_factor),
        )

        # Python-level branch: `goldstein` is static, so this is resolved at
        # trace time and does not appear in the compiled graph when False.
        if goldstein:
            goldstein_violated = curvature_cond_violated(
                stepsize, coef, f_cur, f_next, direct_deriv)

            # grow the step if the decrease was larger than the Goldstein bound
            stepsize, next_params, f_next = lax.cond(
                goldstein_violated, update_stepsize,
                lambda _: (stepsize, next_params, f_next),
                operand=(stepsize, increase_factor),
            )

            violated = jnp.logical_or(violated, goldstein_violated)

        return stepsize, next_params, f_next, violated

    # start with violated=True so the loop body runs at least once
    init_val = stepsize, next_params, f_next, jnp.array(True)

    ret = loop.while_loop(cond_fun=lambda t: t[-1],  # check boolean violated
                          body_fun=body_fun,
                          init_val=init_val, maxiter=maxls,
                          unroll=unroll, jit=jit)

    return ret[:-1]  # remove boolean
98
+
99
+
100
class NewtonCGState(NamedTuple):
    """Named tuple containing state information."""
    # number of update() iterations performed so far
    iter_num: int
    # step size (alpha) accepted at the last iteration
    stepsize: float
    # damping term (lambda) added to the Hessian before CG
    regularizer: float
    # warm-start pytree for the next CG solve (None on the first iteration)
    cg_guess: Optional[Any]
106
+
107
+
108
@dataclasses.dataclass(eq=False)
class NewtonCG(base.StochasticSolver):
    """Newton-CG stochastic solver.

    Each update approximately solves the damped Newton system
    ``(H + lambda * I) d = -g`` with conjugate gradients (warm-started from
    the previous direction), optionally runs an Armijo line search for the
    step size, and optionally adapts lambda from the trust-region ratio rho.
    """
    # Jacobian of the residual function
    loss_fun: Callable
    # maximum number of CG iterations per update
    maxcg: int = 3

    # Either fixed alpha if line_search=False or max_alpha if line_search=True
    learning_rate: float = 1.0
    # NOTE(review): no type annotation, so dataclasses treats this as a plain
    # class attribute rather than an init field — it cannot be overridden via
    # the constructor. Annotate as `decay_coef: float = 0.1` if unintended.
    decay_coef = 0.1

    batch_size: Optional[int] = None

    n_classes: Optional[int] = None

    # Adaptive Regularization parameters
    adaptive_lambda: bool = True
    regularizer: float = 1.0
    lambda_decrease_factor: float = 0.99  # default value recommended by Kiros
    lambda_increase_factor: float = 1.01  # default value recommended by Kiros

    # Line Search parameters
    line_search: bool = True

    aggressiveness: float = 0.9  # default value recommended by Vaswani et al.
    decrease_factor: float = 0.8  # default value recommended by Vaswani et al.
    increase_factor: float = 1.5  # default value recommended by Vaswani et al.
    reset_option: str = 'increase'  # ['increase', 'goldstein', 'conservative']

    max_stepsize: float = 1.0
    # maximum number of line-search iterations
    maxls: int = 15

    # optional hook invoked by the jaxopt run loop before each update
    pre_update: Optional[Callable] = None

    verbose: int = 0

    jit: bool = True
    unroll: base.AutoOrBoolean = "auto"

    def __post_init__(self):
        """Validate line-search settings and build the (jitted) search fn."""
        super().__post_init__()

        self.reference_signature = self.loss_fun
        self.grad_fun = jax.grad(self.loss_fun)

        # set up line search
        if self.line_search:
            # !! learning rate is the maximum step size in case of line search
            self.max_stepsize = self.learning_rate

            options = ['increase', 'goldstein', 'conservative']

            if self.reset_option not in options:
                raise ValueError(f"'reset_option' should be one of {options}")
            if self.aggressiveness <= 0. or self.aggressiveness >= 1.:
                raise ValueError(f"'aggressiveness' must belong to open interval (0,1)")

            # Armijo coefficient: higher aggressiveness -> smaller coefficient
            self._coef = 1 - self.aggressiveness

            unroll = self._get_unroll_option()

            # bind the static configuration; goldstein/maxls stay positional
            armijo_with_fun = partial(armijo_line_search, self.loss_fun, unroll, self.jit)
            if self.jit:
                # goldstein (arg 0) and maxls (arg 1) must be static for jit
                jitted_armijo = jax.jit(armijo_with_fun, static_argnums=(0, 1))
                self._armijo_line_search = jitted_armijo
            else:
                self._armijo_line_search = armijo_with_fun

    def update(
        self,
        params: Any,
        state: NewtonCGState,
        *args,
        **kwargs,
    ) -> base.OptStep:
        """Performs one iteration of the solver.

        Args:
          params: pytree containing the parameters.
          state: named tuple containing the solver state.
          *args: additional positional arguments to be passed to ``fun``.
          **kwargs: additional keyword arguments to be passed to ``fun``.
        Returns:
          (params, state)
        """

        # ---------- STEP 1: calculate direction with DG ---------- #
        # TODO analyze *args and **kwargs
        # split (x,y) pair into (x,) and (y,)
        if 'targets' in kwargs:
            targets = kwargs['targets']
            nn_args = args
        else:
            # convention: the last positional argument holds the targets
            targets = args[-1]
            nn_args = args[:-1]

        direction_tree, grad_loss_tree = self.calculate_direction(params, state, targets, *nn_args)

        # ---------- STEP 2: line search for alpha ---------- #
        # None sentinels: filled lazily, reused by the adaptive-lambda step
        f_cur = None
        f_next = None
        direct_deriv = None
        if not self.line_search:
            # constant learning rate
            stepsize = state.stepsize
            next_params = tree_add_scalar_mul(params, state.stepsize, direction_tree)
        else:
            stepsize = self.reset_stepsize(state.stepsize)

            goldstein = self.reset_option == 'goldstein'

            f_cur = self.loss_fun(params, *nn_args, targets)

            # the directional derivative used for Armijo's line search
            direction, _ = ravel_pytree(direction_tree)
            grad_loss, _ = ravel_pytree(grad_loss_tree)
            direct_deriv = grad_loss.T @ direction

            stepsize, next_params, f_next = self._armijo_line_search(
                goldstein, self.maxls, params, f_cur, stepsize,
                direction_tree, direct_deriv, self._coef,
                self.decrease_factor, self.increase_factor,
                self.max_stepsize, nn_args, targets,
            )

        # ---------- STEP 3: update (next step) lambda ---------- #
        if not self.adaptive_lambda:
            # constant regularization
            regularizer_next = state.regularizer
        else:
            # numerator: in a good scenario, should be large and negative
            if f_cur is None:
                f_cur = self.loss_fun(params, *nn_args, targets)
            if f_next is None:
                f_next = self.loss_fun(next_params, *nn_args, targets)

            num = f_next - f_cur

            # denominator: predicted decrease of the local quadratic model
            Hv_tree = self.hvp(params, direction_tree, targets, *nn_args)

            # flattening stage
            Hv, _ = ravel_pytree(Hv_tree)

            if direct_deriv is None:
                direction, _ = ravel_pytree(direction_tree)
                grad_loss, _ = ravel_pytree(grad_loss_tree)
                direct_deriv = grad_loss.T @ direction

            denom = 0.5 * jnp.vdot(direction, Hv) + direct_deriv

            # negative denominator means that the direction is a descent direction
            rho = num / denom

            # trust-region style schedule: poor model agreement (rho < 1/4)
            # increases lambda, strong agreement (rho > 3/4) decreases it
            regularizer_next = lax.cond(
                rho < 0.25,
                lambda _: self.lambda_increase_factor * state.regularizer,
                lambda _: lax.cond(
                    rho > 0.75,
                    lambda _: self.lambda_decrease_factor * state.regularizer,
                    lambda _: state.regularizer,
                    None,
                ),
                None,
            )

        # construct the next state
        next_state = NewtonCGState(
            iter_num=state.iter_num + 1,  # Next Iteration
            stepsize=stepsize,  # Current alpha
            regularizer=regularizer_next,  # Next lambda
            cg_guess=tree_scalar_mul(self.decay_coef, direction_tree),  # Next CG guess
        )

        return base.OptStep(params=next_params, state=next_state)

    def init_state(self,
                   init_params: Any,
                   *args,
                   **kwargs) -> NewtonCGState:
        """Create the initial solver state (iteration 0, no CG warm start)."""
        return NewtonCGState(
            iter_num=0,
            stepsize=self.learning_rate,
            regularizer=self.regularizer,
            cg_guess=None,
        )

    def optimality_fun(self, params, *args, **kwargs):
        """Optimality function mapping compatible with ``@custom_root``."""
        # return self._grad_fun(params, *args, **kwargs)[0]
        raise NotImplementedError

    def reset_stepsize(self, stepsize):
        """Return new step size for current step, according to reset_option."""
        if self.reset_option == 'goldstein':
            return stepsize
        if self.reset_option == 'conservative':
            return stepsize
        # 'increase': optimistically grow the step, capped at max_stepsize
        stepsize = stepsize * self.increase_factor
        return jnp.minimum(stepsize, self.max_stepsize)

    def hvp(self, params, vec, targets, *args):
        """Exact Hessian-vector product H v via forward-over-reverse autodiff."""
        # prep grad function to accept only "params"
        def inner_grad_fun(params):
            return self.grad_fun(params, *args, targets)

        return jax.jvp(inner_grad_fun, (params,), (vec,))[1]

    def calculate_direction(self, params, state, targets, *args):
        """Solve (H + lambda I) d = -grad approximately with warm-started CG.

        Returns:
          (direction pytree, gradient pytree)
        """
        def mvp(vec):
            # Hv
            hv = self.hvp(params, vec, targets, *args)
            # add regularization, works since (H + lambda*I) v = Hv + lambda*v
            return tree_add_scalar_mul(hv, state.regularizer, vec)

        # --------- Start Here --------- #
        # calculate grad
        grad_tree = self.grad_fun(params, *args, targets)

        # CG iterations
        direction, _ = cg(
            A=mvp,
            b=tree_scalar_mul(-1, grad_tree),
            maxiter=self.maxcg,
            x0=state.cg_guess,  # initial guess
            # M=None, # preconditioner (see wiesler2013: preconditioner is not important)
        )

        return direction, grad_tree

    def __hash__(self):
        # We assume that the attribute values completely determine the solver.
        return hash(self.attribute_values())
somax/ng/__init__.py ADDED
File without changes