torchzero 0.4.1__py3-none-any.whl → 0.4.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_identical.py +1 -1
- torchzero/__init__.py +3 -1
- torchzero/_minimize/__init__.py +0 -0
- torchzero/_minimize/methods.py +95 -0
- torchzero/_minimize/minimize.py +518 -0
- torchzero/core/__init__.py +5 -5
- torchzero/core/chain.py +2 -1
- torchzero/core/functional.py +2 -1
- torchzero/core/module.py +75 -4
- torchzero/core/transform.py +6 -5
- torchzero/linalg/eigh.py +116 -68
- torchzero/linalg/linear_operator.py +1 -0
- torchzero/linalg/orthogonalize.py +60 -5
- torchzero/linalg/sketch.py +39 -0
- torchzero/modules/__init__.py +1 -0
- torchzero/modules/adaptive/adagrad.py +2 -0
- torchzero/modules/adaptive/adam.py +5 -1
- torchzero/modules/adaptive/adan.py +3 -0
- torchzero/modules/adaptive/ggt.py +20 -18
- torchzero/modules/adaptive/lion.py +3 -1
- torchzero/modules/adaptive/mars.py +6 -5
- torchzero/modules/adaptive/msam.py +3 -0
- torchzero/modules/adaptive/rmsprop.py +2 -0
- torchzero/modules/adaptive/rprop.py +9 -7
- torchzero/modules/adaptive/shampoo.py +9 -1
- torchzero/modules/adaptive/soap.py +32 -29
- torchzero/modules/basis/__init__.py +2 -0
- torchzero/modules/basis/ggt_basis.py +199 -0
- torchzero/modules/basis/soap_basis.py +254 -0
- torchzero/modules/clipping/ema_clipping.py +32 -27
- torchzero/modules/clipping/growth_clipping.py +1 -0
- torchzero/modules/experimental/__init__.py +1 -6
- torchzero/modules/experimental/coordinate_momentum.py +2 -0
- torchzero/modules/experimental/cubic_adam.py +4 -0
- torchzero/modules/grad_approximation/__init__.py +3 -2
- torchzero/modules/least_squares/gn.py +6 -0
- torchzero/modules/misc/gradient_accumulation.py +1 -0
- torchzero/modules/misc/misc.py +6 -0
- torchzero/modules/momentum/averaging.py +6 -0
- torchzero/modules/momentum/momentum.py +13 -9
- torchzero/modules/ops/__init__.py +0 -1
- torchzero/modules/ops/accumulate.py +4 -0
- torchzero/modules/ops/higher_level.py +6 -1
- torchzero/modules/second_order/inm.py +4 -0
- torchzero/modules/second_order/newton.py +11 -3
- torchzero/modules/second_order/newton_cg.py +7 -3
- torchzero/modules/second_order/nystrom.py +14 -19
- torchzero/modules/second_order/rsn.py +37 -6
- torchzero/modules/trust_region/trust_region.py +2 -1
- torchzero/utils/benchmarks/logistic.py +33 -18
- torchzero/utils/optuna_tools.py +1 -1
- torchzero/utils/params.py +13 -1
- torchzero/utils/tensorlist.py +2 -2
- {torchzero-0.4.1.dist-info → torchzero-0.4.3.dist-info}/METADATA +1 -1
- {torchzero-0.4.1.dist-info → torchzero-0.4.3.dist-info}/RECORD +58 -55
- torchzero/modules/experimental/adanystrom.py +0 -258
- torchzero/modules/experimental/common_directions_whiten.py +0 -142
- torchzero/modules/experimental/eigen_sr1.py +0 -182
- torchzero/modules/experimental/eigengrad.py +0 -207
- /torchzero/modules/{experimental → grad_approximation}/spsa1.py +0 -0
- {torchzero-0.4.1.dist-info → torchzero-0.4.3.dist-info}/WHEEL +0 -0
- {torchzero-0.4.1.dist-info → torchzero-0.4.3.dist-info}/top_level.txt +0 -0
@@ -30,6 +30,7 @@ class EMASquared(TensorTransform):
     def __init__(self, beta:float=0.999, amsgrad=False, pow:float=2):
         defaults = dict(beta=beta,pow=pow,amsgrad=amsgrad)
         super().__init__(defaults)
+        self.add_projected_keys("grad_sq", "exp_avg_sq", "max_exp_avg_sq")
 
     @torch.no_grad
     def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
@@ -57,7 +58,7 @@ class SqrtEMASquared(TensorTransform):
     def __init__(self, beta:float=0.999, amsgrad=False, debiased: bool = False, pow:float=2,):
         defaults = dict(beta=beta,pow=pow,amsgrad=amsgrad,debiased=debiased)
         super().__init__(defaults)
-
+        self.add_projected_keys("grad_sq", "exp_avg_sq", "max_exp_avg_sq")
 
     @torch.no_grad
     def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
@@ -141,6 +142,8 @@ class CenteredEMASquared(TensorTransform):
     def __init__(self, beta: float = 0.99, amsgrad=False, pow:float=2):
         defaults = dict(beta=beta, amsgrad=amsgrad, pow=pow)
         super().__init__(defaults, uses_grad=False)
+        self.add_projected_keys("grad", "exp_avg")
+        self.add_projected_keys("grad_sq", "exp_avg_sq", "max_exp_avg_sq")
 
     @torch.no_grad
     def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
@@ -175,6 +178,8 @@ class CenteredSqrtEMASquared(TensorTransform):
     def __init__(self, beta: float = 0.99, amsgrad=False, debiased: bool = False, pow:float=2):
         defaults = dict(beta=beta, amsgrad=amsgrad, debiased=debiased, pow=pow)
         super().__init__(defaults, uses_grad=False)
+        self.add_projected_keys("grad", "exp_avg")
+        self.add_projected_keys("grad_sq", "exp_avg_sq", "max_exp_avg_sq")
 
     @torch.no_grad
     def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
@@ -35,6 +35,8 @@ class ImprovedNewton(Transform):
         self,
         damping: float = 0,
         eigval_fn: Callable[[torch.Tensor], torch.Tensor] | None = None,
+        eigv_tol: float | None = None,
+        truncate: int | None = None,
         update_freq: int = 1,
         precompute_inverse: bool | None = None,
         use_lstsq: bool = False,
@@ -89,6 +91,8 @@ class ImprovedNewton(Transform):
             state = self.global_state,
             damping = fs["damping"],
             eigval_fn = fs["eigval_fn"],
+            eigv_tol = fs["eigv_tol"],
+            truncate = fs["truncate"],
             precompute_inverse = precompute_inverse,
             use_lstsq = fs["use_lstsq"]
         )
@@ -7,6 +7,7 @@ from ...core import Chainable, Transform, Objective, HessianMethod
 from ...utils import vec_to_tensors_
 from ...linalg.linear_operator import Dense, DenseWithInverse, Eigendecomposition
 from ...linalg import torch_linalg
+from ...linalg.eigh import regularize_eigh
 
 def _try_lu_solve(H: torch.Tensor, g: torch.Tensor):
     try:
@@ -30,6 +31,8 @@ def _newton_update_state_(
     H: torch.Tensor,
     damping: float,
     eigval_fn: Callable | None,
+    eigv_tol: float | None,
+    truncate: int | None,
     precompute_inverse: bool,
     use_lstsq: bool,
 ):
@@ -39,10 +42,11 @@ def _newton_update_state_(
         reg = torch.eye(H.size(0), device=H.device, dtype=H.dtype).mul_(damping)
         H += reg
 
-    # if
-    if
+    # if any args require eigendecomp, we don't need H or H_inv, we store factors
+    if any(i is not None for i in [eigval_fn, eigv_tol, truncate]):
         L, Q = torch_linalg.eigh(H, retry_float64=True)
-        L = eigval_fn(L)
+        if eigval_fn is not None: L = eigval_fn(L)
+        L, Q = regularize_eigh(L, Q, truncate=truncate, tol=eigv_tol)
         state["L"] = L
         state["Q"] = Q
         return
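The eigv_tol and truncate arguments threaded through the Newton modules are forwarded to regularize_eigh, which drops small or excess eigenpairs from the factored Hessian before the factors are stored. A minimal sketch of that kind of eigenvalue regularization, assuming a (tol, truncate) convention matching the call sites above (the actual torchzero.linalg.eigh.regularize_eigh may differ in details):

import torch

def regularize_eigh_sketch(L: torch.Tensor, Q: torch.Tensor,
                           truncate: int | None = None, tol: float | None = None):
    # L: eigenvalues (ascending, as returned by torch.linalg.eigh); Q: eigenvectors.
    if tol is not None:
        # drop eigenpairs whose magnitude is below a relative tolerance
        keep = L.abs() > tol * L.abs().max()
        L, Q = L[keep], Q[:, keep]
    if truncate is not None and truncate < L.numel():
        # keep only the `truncate` largest-magnitude eigenpairs
        idx = L.abs().topk(truncate).indices
        L, Q = L[idx], Q[:, idx]
    return L, Q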
@@ -216,6 +220,8 @@ class Newton(Transform):
         self,
         damping: float = 0,
         eigval_fn: Callable[[torch.Tensor], torch.Tensor] | None = None,
+        eigv_tol: float | None = None,
+        truncate: int | None = None,
         update_freq: int = 1,
         precompute_inverse: bool | None = None,
         use_lstsq: bool = False,
@@ -242,6 +248,8 @@ class Newton(Transform):
             H=H,
             damping = fs["damping"],
             eigval_fn = fs["eigval_fn"],
+            eigv_tol = fs["eigv_tol"],
+            truncate = fs["truncate"],
             precompute_inverse = precompute_inverse,
             use_lstsq = fs["use_lstsq"]
         )
@@ -226,7 +226,8 @@ class NewtonCGSteihaug(Transform):
         tol: float = 1e-8,
         reg: float = 1e-8,
         solver: Literal['cg', "minres"] = 'cg',
-        adapt_tol: bool =
+        adapt_tol: bool = False,
+        terminate_on_tr: bool = True,
         npc_terminate: bool = False,
 
         # hvp settings
@@ -272,7 +273,6 @@ class NewtonCGSteihaug(Transform):
         npc_terminate=fs["npc_terminate"]
         miniter=fs["miniter"]
         max_history=fs["max_history"]
-        adapt_tol=fs["adapt_tol"]
 
 
         # ------------------------------- trust region ------------------------------- #
@@ -294,9 +294,13 @@ class NewtonCGSteihaug(Transform):
         finfo = torch.finfo(orig_params[0].dtype)
         if trust_radius < finfo.tiny * 2:
             trust_radius = self.global_state['trust_radius'] = init
-
+
+            if fs["adapt_tol"]:
                 self.global_state["tol_mul"] = self.global_state.get("tol_mul", 1) * 0.1
 
+            if fs["terminate_on_tr"]:
+                objective.should_terminate = True
+
         elif trust_radius > finfo.max / 2:
             trust_radius = self.global_state['trust_radius'] = init
 
@@ -5,7 +5,7 @@ import torch
 
 from ...core import Chainable, Transform, HVPMethod
 from ...utils import TensorList, vec_to_tensors
-from ...linalg import nystrom_pcg, nystrom_sketch_and_solve, nystrom_approximation, cg, regularize_eigh, OrthogonalizeMethod
+from ...linalg import nystrom_pcg, nystrom_sketch_and_solve, nystrom_approximation, cg, regularize_eigh, OrthogonalizeMethod, orthogonalize
 from ...linalg.linear_operator import Eigendecomposition, ScaledIdentity
 
 class NystromSketchAndSolve(Transform):
@@ -75,7 +75,7 @@ class NystromSketchAndSolve(Transform):
     """
     def __init__(
         self,
-        rank: int,
+        rank: int = 100,
         reg: float | None = 1e-2,
         eigv_tol: float = 0,
         truncate: int | None = None,
@@ -109,17 +109,15 @@ class NystromSketchAndSolve(Transform):
 
         generator = self.get_generator(params[0].device, seed=fs['seed'])
         try:
+            Omega = torch.randn([ndim, min(fs["rank"], ndim)], device=device, dtype=dtype, generator=generator)
+            Omega = orthogonalize(Omega, fs["orthogonalize_method"])
+            HOmega = H_mm(Omega)
+
             # compute the approximation
             L, Q = nystrom_approximation(
-
-
-                ndim=ndim,
-                rank=min(fs["rank"], ndim),
+                Omega=Omega,
+                AOmega=HOmega,
                 eigv_tol=fs["eigv_tol"],
-                orthogonalize_method=fs["orthogonalize_method"],
-                dtype=dtype,
-                device=device,
-                generator=generator,
             )
 
             # regularize
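Both Nyström modules now build the test matrix themselves: draw a Gaussian Omega, orthogonalize it, apply the Hessian operator to get HOmega, and pass the pair to nystrom_approximation(Omega=..., AOmega=...). For reference, a single-pass Nyström approximation can be computed from exactly such a pair; the sketch below follows the stable variant of Tropp et al. (2017) and is illustrative only, since torchzero's nystrom_approximation may handle the shift and tolerances differently:

import torch

def nystrom_from_sketch(Omega: torch.Tensor, AOmega: torch.Tensor):
    # A ~= Q diag(L) Q^T from a test matrix Omega (assumed orthonormal) and AOmega = A @ Omega
    nu = torch.finfo(AOmega.dtype).eps * torch.linalg.matrix_norm(AOmega, ord='fro')
    Y = AOmega + nu * Omega                                    # shift for numerical stability
    C = torch.linalg.cholesky(Omega.T @ Y)                     # lower Cholesky factor of the core matrix
    B = torch.linalg.solve_triangular(C, Y.T, upper=False).T   # B = Y @ C^{-T}
    Q, S, _ = torch.linalg.svd(B, full_matrices=False)
    L = (S**2 - nu).clamp_min(0)                               # undo the shift on the eigenvalues
    return L, Q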
@@ -225,7 +223,7 @@ class NystromPCG(Transform):
     """
    def __init__(
         self,
-        rank: int,
+        rank: int = 100,
         maxiter=None,
         tol=1e-8,
         reg: float = 1e-6,
@@ -260,16 +258,13 @@ class NystromPCG(Transform):
         generator = self.get_generator(device, seed=fs['seed'])
 
         try:
+            Omega = torch.randn(ndim, min(fs["rank"], ndim), device=device, dtype=dtype, generator=generator)
+            HOmega = H_mm(orthogonalize(Omega, fs["orthogonalize_method"]))
+            # compute the approximation
             L, Q = nystrom_approximation(
-
-
-                ndim=ndim,
-                rank=min(fs["rank"], ndim),
+                Omega=Omega,
+                AOmega=HOmega,
                 eigv_tol=fs["eigv_tol"],
-                orthogonalize_method=fs["orthogonalize_method"],
-                dtype=dtype,
-                device=device,
-                generator=generator,
             )
 
             self.global_state["L"] = L
@@ -25,9 +25,23 @@ def _orthonormal_sketch(m, n, dtype, device, generator):
     return _qr_orthonormalize(torch.randn(m, n, dtype=dtype, device=device, generator=generator))
 
 def _rademacher_sketch(m, n, dtype, device, generator):
-    rademacher = torch.bernoulli(torch.full((m,n), 0.5), generator = generator).mul_(2).sub_(1)
+    rademacher = torch.bernoulli(torch.full((m,n), 0.5, device=device, dtype=dtype), generator = generator).mul_(2).sub_(1)
     return rademacher.mul_(1 / math.sqrt(m))
 
+def _row_sketch(m, n, dtype, device, generator):
+    weights = torch.ones(m, dtype=dtype, device=device)
+    indices = torch.multinomial(weights, n, replacement=False, generator=generator)
+
+    P = torch.zeros(m, n, dtype=dtype, device=device)
+    P[indices, range(n)] = 1
+    return P
+
+def _topk_rows(grad, m, n, dtype, device, generator):
+    _, indices = torch.topk(grad.abs(), n)
+    P = torch.zeros(m, n, dtype=dtype, device=device)
+    P[indices, range(n)] = 1
+    return P
+
 class SubspaceNewton(Transform):
     """Subspace Newton. Performs a Newton step in a subspace (random or spanned by past gradients).
 
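The two new sketches are column-selection matrices, so the sketched Hessian S.T @ H @ S is just a principal submatrix of H indexed by the selected coordinates (random for "rows", largest-gradient for "topk"). A small self-contained check of that identity, for illustration only:

import torch

H = torch.randn(6, 6)
H = H + H.T                                   # symmetric stand-in for a Hessian
idx = torch.tensor([0, 3, 5])                 # selected coordinates
S = torch.zeros(6, 3)
S[idx, torch.arange(3)] = 1.0                 # one-hot columns, like _row_sketch / _topk_rows
assert torch.allclose(S.T @ H @ S, H[idx][:, idx])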
@@ -37,7 +51,9 @@ class SubspaceNewton(Transform):
         sketch_type (str, optional):
             - "common_directions" - uses history steepest descent directions as the basis[2]. It is orthonormalized on-line using Gram-Schmidt (default).
             - "orthonormal" - random orthonormal basis. Orthonormality is necessary to use linear operator based modules such as trust region, but it can be slower to compute.
-            - "
+            - "rows" - samples random rows.
+            - "topk" - samples top-rank rows with largest gradient magnitude.
+            - "rademacher" - approximately orthonormal (if dimension is large) scaled random rademacher basis.
             - "mixed" - random orthonormal basis but with four directions set to gradient, slow and fast gradient EMAs, and previous update direction.
         damping (float, optional): hessian damping (scale of identity matrix added to hessian). Defaults to 0.
         hvp_method (str, optional):
@@ -93,13 +109,15 @@ class SubspaceNewton(Transform):
 
     def __init__(
         self,
-        sketch_size: int,
-        sketch_type: Literal["orthonormal", "common_directions", "mixed", "rademacher"] = "common_directions",
+        sketch_size: int = 100,
+        sketch_type: Literal["orthonormal", "common_directions", "mixed", "rademacher", "rows", "topk"] = "common_directions",
         damping:float=0,
         eigval_fn: Callable[[torch.Tensor], torch.Tensor] | None = None,
+        eigv_tol: float | None = None,
+        truncate: int | None = None,
         update_freq: int = 1,
         precompute_inverse: bool = False,
-        use_lstsq: bool =
+        use_lstsq: bool = False,
         hvp_method: HVPMethod = "batched_autograd",
         h: float = 1e-2,
         seed: int | None = None,
@@ -131,6 +149,14 @@ class SubspaceNewton(Transform):
         elif sketch_type == 'orthonormal':
             S = _orthonormal_sketch(ndim, sketch_size, device=device, dtype=dtype, generator=generator)
 
+        elif sketch_type == "rows":
+            S = _row_sketch(ndim, sketch_size, device=device, dtype=dtype, generator=generator)
+
+        elif sketch_type == "topk":
+            g_list = objective.get_grads(create_graph=hvp_method in ("batched_autograd", "autograd"))
+            g = torch.cat([t.ravel() for t in g_list])
+            S = _topk_rows(g, ndim, sketch_size, device=device, dtype=dtype, generator=generator)
+
         elif sketch_type == 'common_directions':
             # Wang, Po-Wei, Ching-pei Lee, and Chih-Jen Lin. "The common-directions method for regularized empirical risk minimization." Journal of Machine Learning Research 20.58 (2019): 1-49.
             g_list = objective.get_grads(create_graph=hvp_method in ("batched_autograd", "autograd"))
@@ -189,6 +215,10 @@ class SubspaceNewton(Transform):
         else:
             raise ValueError(f'Unknown sketch_type {sketch_type}')
 
+        # print(f'{S.shape = }')
+        # I = torch.eye(S.size(1), device=S.device, dtype=S.dtype)
+        # print(f'{torch.nn.functional.mse_loss(S.T @ S, I) = }')
+
         # form sketched hessian
         HS, _ = objective.hessian_matrix_product(S, rgrad=None, at_x0=True,
                                                  hvp_method=fs["hvp_method"], h=fs["h"])
@@ -200,9 +230,10 @@ class SubspaceNewton(Transform):
             H = H_sketched,
             damping = fs["damping"],
             eigval_fn = fs["eigval_fn"],
+            eigv_tol = fs["eigv_tol"],
+            truncate = fs["truncate"],
             precompute_inverse = fs["precompute_inverse"],
             use_lstsq = fs["use_lstsq"]
-
         )
 
         self.global_state["S"] = S
@@ -1,7 +1,7 @@
 import math
 import warnings
 from abc import ABC, abstractmethod
-from collections.abc import Callable, Mapping, Sequence
+from collections.abc import Callable, Mapping, Sequence, MutableMapping
 from functools import partial
 from typing import Any, Literal, Protocol, cast, final, overload
 
@@ -203,6 +203,7 @@ def fixed_radius(
 ) -> tuple[float, bool]:
     return init, True
 
+
 _RADIUS_KEYS = Literal['default', 'fixed']
 _RADIUS_STRATEGIES: dict[_RADIUS_KEYS, _RadiusStrategy] = {
     "default": default_radius,
@@ -5,39 +5,54 @@ import numpy as np
 import torch
 import tqdm
 
-
-
-
+def generate_correlated_logistic_data(
+    n_samples=100_000,
+    n_features=32,
+    n_classes=10,
+    n_correlated=768,
+    correlation=0.99,
+    seed=0
+) -> tuple[np.ndarray, np.ndarray]:
+    assert n_classes >= 2
     generator = np.random.default_rng(seed)
 
-    # ------------------------------------- X ------------------------------------ #
     X = generator.standard_normal(size=(n_samples, n_features))
-    weights = generator.uniform(-2, 2, n_features)
+    weights = generator.uniform(-2, 2, size=(n_features, n_classes))
+
+    used_pairs = set()
+    n_correlated = min(n_correlated, n_features * (n_features - 1) // 2)
 
-
-    for i in range(n_correlated_pairs):
+    for _ in range(n_correlated):
         idxs = None
         while idxs is None or idxs in used_pairs:
-
+            pair = generator.choice(n_features, size=2, replace=False)
+            pair.sort()
+            idxs = tuple(pair)
 
-        used_pairs.
+        used_pairs.add(idxs)
         idx1, idx2 = idxs
 
         noise = generator.standard_normal(n_samples) * np.sqrt(1 - correlation**2)
         X[:, idx2] = correlation * X[:, idx1] + noise
 
         w = generator.integers(1, 51)
-
-        weights[
+        cls = generator.integers(0, n_classes)
+        weights[idx1, cls] = w
+        weights[idx2, cls] = -w
 
-    # ---------------------------------- logits ---------------------------------- #
     logits = X @ weights
-    probabilities = 1 / (1 + np.exp(-logits))
-    y = generator.binomial(1, probabilities).astype(np.float32)
 
-
-
-
+    logits -= logits.max(axis=1, keepdims=True)
+    exp_logits = np.exp(logits)
+    probabilities = exp_logits / exp_logits.sum(axis=1, keepdims=True)
+
+    y_one_hot = generator.multinomial(1, pvals=probabilities)
+    y = np.argmax(y_one_hot, axis=1)
+
+    X -= X.mean(0, keepdims=True)
+    X /= X.std(0, keepdims=True)
+
+    return X, y.astype(np.int64)
 
 
 # if __name__ == '__main__':
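The rewritten generator now produces multi-class labels (int64 class indices sampled from a softmax over per-class logits), so downstream code should treat it as a multinomial rather than binary problem. An illustrative usage sketch, assuming the defaults above:

import torch
import torch.nn.functional as F

X, y = generate_correlated_logistic_data(n_samples=10_000, n_features=32, n_classes=10)
X_t = torch.as_tensor(X, dtype=torch.float32)
y_t = torch.as_tensor(y)                       # int64 class indices in [0, n_classes)
logits = torch.nn.Linear(32, 10)(X_t)          # simple multinomial logistic model
loss = F.cross_entropy(logits, y_t)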
@@ -101,7 +116,7 @@ def run_logistic_regression(X: torch.Tensor, y: torch.Tensor, opt_fn, max_steps:
         # this is for tests
         if _assert_on_evaluated_same_params:
             for p in evaluated_params:
-                assert not _tensorlist_equal(p, model.parameters()), f"evaluated same parameters on epoch {epoch}"
+                assert not _tensorlist_equal(p, model.parameters()), f"{optimizer} evaluated same parameters on epoch {epoch}"
 
             evaluated_params.append([p.clone() for p in model.parameters()])
 
torchzero/utils/optuna_tools.py CHANGED

@@ -27,7 +27,7 @@ def get_momentum(trial: optuna.Trial, prefix: str, conditional: bool=True) -> li
             m = NAG(beta, dampening, lerp)
             if debiased: m = Chain(m, Debias(beta1=beta))
         else:
-            m = EMA(beta, dampening,
+            m = EMA(beta, dampening, debias=debiased, lerp=lerp)
         return [m]
     return []
 
torchzero/utils/params.py CHANGED

@@ -3,7 +3,7 @@ from collections.abc import Sequence, Iterable, Mapping
 import warnings
 import torch, numpy as np
 
-
+from .torch_tools import set_storage_
 
 Params = Iterable[torch.Tensor | tuple[str, torch.Tensor] | Mapping[str, Any]]
 
@@ -147,3 +147,15 @@ def _set_update_and_grad_(
 
     return param_groups
 
+
+def _set_fake_params_(fake_params: Iterable[torch.Tensor], storage: Iterable[torch.Tensor]):
+    """sets ``fake_params`` storage to ``storage`` while they remain the same python object"""
+    for fake_p, s in zip(fake_params, storage):
+        fake_p.set_(s.view_as(s).requires_grad_()) # pyright: ignore[reportArgumentType]
+
+def _empty_fake_param_storage_(fake_params: Iterable[torch.Tensor]):
+    """sets ``fake_params`` storage to empty while they remain the same python object"""
+    for p in fake_params:
+        set_storage_(p, torch.empty(0, device=p.device, dtype=p.dtype))
+
+
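The new helpers rely on torch.Tensor.set_ (directly and via set_storage_) to swap a tensor's underlying storage while keeping the same Python object, so any references held elsewhere keep pointing at the "fake" parameter. A minimal demonstration of that behaviour (illustrative, not torchzero code):

import torch

p = torch.zeros(3)
ref = p                          # e.g. a reference kept inside an optimizer's param group
p.set_(torch.arange(3.0))        # repoint p at new storage in place
assert ref is p                  # still the same Python object
assert torch.equal(ref, torch.arange(3.0))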
torchzero/utils/tensorlist.py CHANGED

@@ -330,10 +330,10 @@ class TensorList(list[torch.Tensor | Any]):
 
     def global_vector_norm(self, ord:float = 2) -> torch.Tensor:
         # return torch.linalg.vector_norm(self.to_vec(), ord = ord) # pylint:disable = not-callable
-        if ord == 1: return self.global_sum()
-        if ord % 2 == 0: return self.pow(ord).global_sum().pow(1/ord)
         if ord == torch.inf: return self.abs().global_max()
         if ord == -torch.inf: return self.abs().global_min()
+        if ord == 1: return self.abs().global_sum()
+        if ord % 2 == 0: return self.pow(ord).global_sum().pow(1/ord)
         if ord == 0: return (self != 0).global_sum().to(self[0].dtype)
 
         return self.abs().pow_(ord).global_sum().pow(1/ord)