python-somax 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
somax/ng/swm_ng.py ADDED
@@ -0,0 +1,411 @@
1
+ """
2
+ SWM-NG Solver
3
+ """
4
+
5
+ from typing import Any
6
+ from typing import Callable
7
+ from typing import NamedTuple, Tuple
8
+ from typing import Optional
9
+ import dataclasses
10
+ from functools import partial
11
+
12
+ import jax
13
+ import jax.lax as lax
14
+ import jax.numpy as jnp
15
+ from jax.flatten_util import ravel_pytree
16
+ from jaxopt.tree_util import tree_add_scalar_mul
17
+ from jaxopt._src import base
18
+ from jaxopt._src import loop
19
+
20
+
21
def wolfe_cond_violated(stepsize, coef, f_cur, f_next, direct_deriv):
    """Return True when the sufficient-decrease (Armijo) condition fails.

    The trial step is acceptable when
    ``f_cur - f_next >= -stepsize * coef * direct_deriv``;
    ``direct_deriv`` is expected to be negative for a descent direction.
    A machine-epsilon slack keeps an exactly-zero decrease from being
    rejected spuriously.
    """
    machine_eps = jnp.finfo(f_next.dtype).eps
    actual_decrease = (f_cur - f_next) + machine_eps
    required_decrease = -coef * stepsize * direct_deriv
    return required_decrease > actual_decrease
26
+
27
+
28
def curvature_cond_violated(stepsize, coef, f_cur, f_next, direct_deriv):
    """Return True when the loss decreased *too much* (Goldstein upper test).

    The step violates the Goldstein condition when
    ``f_cur - f_next > -stepsize * (1 - coef) * direct_deriv``,
    signalling that a larger step size could be tried.
    """
    observed_decrease = f_cur - f_next
    upper_bound = -(1. - coef) * stepsize * direct_deriv
    return observed_decrease > upper_bound
32
+
33
+
34
def armijo_line_search(loss_fun, unroll, jit,
                       goldstein, maxls,
                       params, f_cur, stepsize,
                       direction, direct_deriv,
                       coef, decrease_factor, increase_factor, max_stepsize,
                       args, targets):
    """Backtracking Armijo line search along ``direction``.

    Shrinks ``stepsize`` by ``decrease_factor`` while the sufficient-decrease
    (Armijo) condition is violated.  When ``goldstein`` is True, additionally
    grows the step by ``increase_factor`` when the decrease is too large
    (two-sided Goldstein test).  Runs at most ``maxls`` iterations via
    jaxopt's ``loop.while_loop``.

    Args:
      loss_fun: evaluated as ``loss_fun(params, *args, targets)``.
      unroll, jit: options forwarded to ``loop.while_loop``.
      goldstein: Python (static) bool; enables the Goldstein test.
      maxls: maximum number of line-search iterations.
      params: current parameters (pytree).
      f_cur: loss at ``params``.
      stepsize: initial trial step size.
      direction: search direction, same pytree structure as ``params``.
      direct_deriv: directional derivative of the loss along ``direction``
        (negative for a descent direction).
      coef: sufficient-decrease coefficient.
      decrease_factor, increase_factor: multiplicative step-size updates.
      max_stepsize: upper bound on the step size.
      args, targets: extra inputs forwarded to ``loss_fun``.

    Returns:
      Tuple ``(stepsize, next_params, f_next)`` for the accepted step.
    """
    # given direction calculate next params
    next_params = tree_add_scalar_mul(params, stepsize, direction)

    # calculate loss at next params
    f_next = loss_fun(next_params, *args, targets)

    # grad_sqnorm = tree_l2_norm(grad, squared=True)

    def update_stepsize(t):
        """Multiply stepsize per factor, return new params and new value."""
        stepsize, factor = t
        stepsize = stepsize * factor
        # never exceed the configured maximum step size
        stepsize = jnp.minimum(stepsize, max_stepsize)

        next_params = tree_add_scalar_mul(params, stepsize, direction)

        f_next = loss_fun(next_params, *args, targets)

        return stepsize, next_params, f_next

    def body_fun(t):
        stepsize, next_params, f_next, _ = t

        # Armijo test: shrink the step when the loss did not decrease enough.
        violated = wolfe_cond_violated(stepsize, coef, f_cur, f_next, direct_deriv)

        stepsize, next_params, f_next = lax.cond(
            violated,
            update_stepsize,
            lambda _: (stepsize, next_params, f_next),
            operand=(stepsize, decrease_factor),
        )

        # `goldstein` is a Python bool, so this branch is fixed at trace time.
        if goldstein:
            # Goldstein test: grow the step when the decrease is *too* large.
            goldstein_violated = curvature_cond_violated(
                stepsize, coef, f_cur, f_next, direct_deriv)

            stepsize, next_params, f_next = lax.cond(
                goldstein_violated, update_stepsize,
                lambda _: (stepsize, next_params, f_next),
                operand=(stepsize, increase_factor),
            )

            # keep iterating while either condition is violated
            violated = jnp.logical_or(violated, goldstein_violated)

        return stepsize, next_params, f_next, violated

    # start with violated=True so the loop body executes at least once
    init_val = stepsize, next_params, f_next, jnp.array(True)

    ret = loop.while_loop(cond_fun=lambda t: t[-1],  # check boolean violated
                          body_fun=body_fun,
                          init_val=init_val, maxiter=maxls,
                          unroll=unroll, jit=jit)

    return ret[:-1]  # remove boolean
94
+
95
+
96
def flatten_2d_jacobian(jac_tree):
    """Flatten a batched Jacobian pytree into a 2-D ``(batch, n_params)`` array.

    Each per-sample slice (axis 0 of every leaf) is raveled into a single
    flat row via ``ravel_pytree``.
    """
    ravel_row = lambda sample: ravel_pytree(sample)[0]
    return jax.vmap(ravel_row, in_axes=(0,))(jac_tree)
98
+
99
+
100
class SWMNGState(NamedTuple):
    """Named tuple containing state information carried between solver steps."""
    # Number of completed update() iterations.
    iter_num: int
    # error: float
    # value: float
    # Step size (alpha) used by the most recent update.
    stepsize: float
    # Damping coefficient (lambda) to use on the next step.
    regularizer: float
    # direction_inf_norm: float
    # Flat momentum velocity (parameter delta); None when momentum == 0.
    velocity: Optional[Any]
109
+
110
+
111
@dataclasses.dataclass(eq=False)
class SWMNG(base.StochasticSolver):
    """SWM-NG stochastic solver (jaxopt-style).

    Builds an update direction from the batch of per-sample loss gradients
    using a Woodbury-identity solve (see ``calculate_direction``), optionally
    combined with Armijo line search, heavy-ball momentum, and adaptive
    damping of the regularizer.
    """

    # should be provided
    # predict_fun: Callable
    loss_fun: Callable

    # filled during initialization
    jac_fun: Optional[Callable] = None

    # vectorization axis for Jacobian, None - no vectorization, 0 - batch axis of the tensor
    # default value (None, 0,) matches predict_fun(params, X)
    # NOTE(review): __post_init__ unconditionally overwrites this with
    # (None, 0, 0), so a user-supplied value is currently ignored — confirm.
    jac_axis: Tuple[Optional[int], ...] = (None, 0,)

    # Loss function parameters
    loss_type: str = 'mse'  # ['mse', 'ce']

    # Either fixed alpha if line_search=False or max_alpha if line_search=True
    learning_rate: Optional[float] = None

    # Expected number of samples per batch; required (used in __post_init__).
    batch_size: Optional[int] = None

    n_classes: Optional[int] = None

    # Line Search parameters
    line_search: bool = False

    aggressiveness: float = 0.9  # default value recommended by Vaswani et al.
    decrease_factor: float = 0.8  # default value recommended by Vaswani et al.
    increase_factor: float = 1.5  # default value recommended by Vaswani et al.
    reset_option: str = 'increase'  # ['increase', 'goldstein', 'conservative']

    max_stepsize: float = 1.0
    maxls: int = 15

    # Adaptive Regularization parameters
    adaptive_lambda: bool = False
    regularizer: float = 1.0
    # regularizer_eps: float = 1e-5
    lambda_decrease_factor: float = 0.99  # default value recommended by Kiros
    lambda_increase_factor: float = 1.01  # default value recommended by Kiros

    # Momentum parameters
    momentum: float = 0.0

    pre_update: Optional[Callable] = None

    verbose: int = 0

    jit: bool = True
    unroll: base.AutoOrBoolean = "auto"

    def __post_init__(self):
        """Finish configuration: build the Jacobian function and line search."""
        super().__post_init__()

        self.reference_signature = self.loss_fun

        self.jac_axis = (None, 0, 0)  # default for classification
        # Per-sample gradient of the loss, vmapped over the batch axis.
        self.jac_fun = jax.vmap(jax.grad(self.loss_fun), in_axes=self.jac_axis)
        # Precomputed b * lambda * I_b added to L L^T in calculate_direction.
        # NOTE(review): frozen at the initial `regularizer`; not refreshed when
        # adaptive_lambda updates state.regularizer — confirm intended.
        self.regularizer_array = self.batch_size * self.regularizer * jnp.eye(self.batch_size)

        # set up line search
        if self.line_search:
            # !! learning rate is the maximum step size in case of line search
            self.max_stepsize = self.learning_rate

            options = ['increase', 'goldstein', 'conservative']

            if self.reset_option not in options:
                raise ValueError(f"'reset_option' should be one of {options}")
            if self.aggressiveness <= 0. or self.aggressiveness >= 1.:
                raise ValueError(f"'aggressiveness' must belong to open interval (0,1)")

            # sufficient-decrease coefficient used in the Armijo condition
            self._coef = 1 - self.aggressiveness

            unroll = self._get_unroll_option()

            armijo_with_fun = partial(armijo_line_search, self.loss_fun, unroll, self.jit)
            if self.jit:
                # goldstein flag and maxls stay Python-static under jit
                jitted_armijo = jax.jit(armijo_with_fun, static_argnums=(0, 1))
                self._armijo_line_search = jitted_armijo
            else:
                self._armijo_line_search = armijo_with_fun

    def update(
        self,
        params: Any,
        state: SWMNGState,
        *args,
        **kwargs,
    ) -> base.OptStep:
        """Performs one iteration of the solver.

        Args:
          params: pytree containing the parameters.
          state: named tuple containing the solver state.
          *args: additional positional arguments to be passed to ``fun``;
            the last positional argument is treated as the targets unless
            ``targets`` is passed as a keyword.
          **kwargs: additional keyword arguments to be passed to ``fun``.
        Returns:
          (params, state)
        """

        # convert pytree to JAX array (w)
        params_flat, unflatten_fn = ravel_pytree(params)

        # ---------- STEP 1: calculate direction with SWM ---------- #
        # TODO analyze *args and **kwargs
        # split (x,y) pair into (x,) and (y,)
        if 'targets' in kwargs:
            targets = kwargs['targets']
            nn_args = args
        else:
            targets = args[-1]
            nn_args = args[:-1]

        # direction is a flat array; J is the (b x d) Jacobian, Q is an
        # optional inner weighting matrix (None for this solver).
        direction, grad_loss, J, Q = self.calculate_direction(params, state, targets, *nn_args)

        # ---------- STEP 2: line search for alpha ---------- #
        f_cur = None
        f_next = None
        next_params = None
        if self.line_search:
            stepsize = self.reset_stepsize(state.stepsize)

            goldstein = self.reset_option == 'goldstein'

            f_cur = self.loss_fun(params, *nn_args, targets)

            # the directional derivative used for Armijo's line search
            direct_deriv = grad_loss.T @ direction

            # line search works on the pytree form of the direction
            direction_packed = unflatten_fn(direction)

            stepsize, next_params, f_next = self._armijo_line_search(
                goldstein, self.maxls, params, f_cur, stepsize, direction_packed, direct_deriv, self._coef,
                self.decrease_factor, self.increase_factor, self.max_stepsize, nn_args, targets, )
        else:
            stepsize = state.stepsize

        # ---------- STEP 3: momentum acceleration ---------- #
        if self.momentum == 0:
            next_velocity = None
        else:
            # next_params = params + stepsize*direction + momentum*(params - previous_params)

            if next_params is None:
                next_params_flat = params_flat + stepsize * direction
            else:
                # line search already produced next_params; flatten it
                next_params_flat, _ = ravel_pytree(next_params)

            next_params_flat = next_params_flat + self.momentum * state.velocity
            next_velocity = next_params_flat - params_flat

            next_params = unflatten_fn(next_params_flat)

        # ! params should be "packed" in a pytree before sending to the OptStep
        if next_params is None:
            # the only case is when LS=False, Momentum=0
            next_params_flat = params_flat + stepsize * direction
            next_params = unflatten_fn(next_params_flat)

        # ---------- STEP 4: update (next step) lambda ---------- #
        if self.adaptive_lambda:
            # f_cur can be already computed if line search is used
            f_cur = self.loss_fun(params, *nn_args, targets) if f_cur is None else f_cur

            # if momentum is used, we need to calculate f_next again to take into account the momentum term
            f_next = self.loss_fun(next_params, *nn_args, targets) if f_next is None or self.momentum > 0 else f_next

            # in a good scenario, should be large and negative
            num = f_next - f_cur

            # delta_w is the actual parameter change taken this step
            if self.momentum > 0:
                delta_w = next_velocity
            else:
                delta_w = stepsize * direction

            b = targets.shape[0]

            # dimensions: (b x d) @ (d x 1) = (b x 1)
            mvp = J @ delta_w
            if Q is None:
                denom = grad_loss.T @ delta_w + 0.5 * mvp.T @ mvp / b
            else:
                denom = grad_loss.T @ delta_w + 0.5 * mvp.T @ Q @ mvp / b

            # negative denominator means that the direction is a descent direction

            # trust-region style gain ratio: actual vs. predicted reduction
            rho = num / denom

            # rho < 0.25 -> increase lambda; rho > 0.75 -> decrease lambda
            regularizer_next = lax.cond(
                rho < 0.25,
                lambda _: self.lambda_increase_factor * state.regularizer,
                lambda _: lax.cond(
                    rho > 0.75,
                    lambda _: self.lambda_decrease_factor * state.regularizer,
                    lambda _: state.regularizer,
                    None,
                ),
                None,
            )
        else:
            regularizer_next = state.regularizer

        # construct the next state
        next_state = SWMNGState(
            iter_num=state.iter_num + 1,  # Next Iteration
            stepsize=stepsize,  # Current alpha
            regularizer=regularizer_next,  # Next lambda
            velocity=next_velocity,  # Next velocity
        )

        return base.OptStep(params=next_params, state=next_state)

    def init_state(self,
                   init_params: Any,
                   *args,
                   **kwargs) -> SWMNGState:
        """Build the initial solver state for ``init_params``."""
        if self.momentum == 0:
            velocity = None
        else:
            # velocity lives in the flat parameter space
            velocity = jnp.zeros_like(ravel_pytree(init_params)[0])

        return SWMNGState(
            iter_num=jnp.asarray(0),
            stepsize=jnp.asarray(self.learning_rate),
            regularizer=jnp.asarray(self.regularizer),
            velocity=velocity,
        )

    def optimality_fun(self, params, *args, **kwargs):
        """Optimality function mapping compatible with ``@custom_root``."""
        # return self._grad_fun(params, *args, **kwargs)[0]
        raise NotImplementedError

    def reset_stepsize(self, stepsize):
        """Return new step size for current step, according to reset_option."""
        # 'goldstein' and 'conservative' both reuse the previous step size;
        # only 'increase' grows it (capped at max_stepsize).
        if self.reset_option == 'goldstein':
            return stepsize
        if self.reset_option == 'conservative':
            return stepsize
        stepsize = stepsize * self.increase_factor
        return jnp.minimum(stepsize, self.max_stepsize)

    def calculate_direction(self, params, state, targets, *args):
        """Compute the SWM-NG direction from per-sample loss gradients.

        Solves (L L^T + b*lambda*I) t = L g in the (small) b x b batch space,
        then maps back: direction = (L^T t - g) / lambda — a Woodbury-identity
        evaluation of -(L^T L / b + lambda I)^{-1} g scaled per this scheme.

        Returns:
          (direction, grad_loss, L, None) where L is the (b x d) matrix of
          per-sample gradients and grad_loss is the batch-mean gradient.
        """
        batch_loss_tree = self.jac_fun(params, *args, targets)
        L = flatten_2d_jacobian(batch_loss_tree)

        # mean gradient over the batch
        grad_loss = jnp.sum(L, axis=0) / self.batch_size

        temp = jax.scipy.linalg.solve(L @ L.T + self.regularizer_array, L @ grad_loss, assume_a='sym')
        direction = (L.T @ temp - grad_loss) / self.regularizer

        return direction, grad_loss, L, None

    # def mse(self, params, x, y):
    #     # b x 1
    #     residuals = y - self.predict_fun(params, x)
    #
    #     # 1,
    #     # average over the batch
    #     return 0.5 * jnp.mean(jnp.square(residuals))
    #
    # def ce(self, params, x, y):
    #     # b x C
    #     logits = self.predict_fun(params, x)
    #
    #     # b x C
    #     # jax.nn.log_softmax combines exp() and log() in a numerically stable way.
    #     log_probs = jax.nn.log_softmax(logits)
    #
    #     # b x 1
    #     # if y is one-hot encoded, this operation picks the log probability of the correct class
    #     residuals = jnp.sum(y * log_probs, axis=-1)
    #
    #     # 1,
    #     # average over the batch
    #     return -jnp.mean(residuals)
    #
    # def ce_with_aux(self, params, x, y):
    #     # b x C
    #     logits = self.predict_fun(params, x)
    #
    #     # b x C
    #     # jax.nn.log_softmax combines exp() and log() in a numerically stable way.
    #     log_probs = jax.nn.log_softmax(logits)
    #
    #     # b x 1
    #     # if y is one-hot encoded, this operation picks the log probability of the correct class
    #     residuals = jnp.sum(y * log_probs, axis=-1)
    #
    #     # 1,
    #     # average over the batch
    #     return -jnp.mean(residuals), logits
    #
    # def predict_with_aux(self, params, *args):
    #     preds = self.predict_fun(params, *args)
    #     return preds, preds

    def __hash__(self):
        # We assume that the attribute values completely determine the solver.
        return hash(self.attribute_values())
somax/qn/__init__.py ADDED
File without changes