torchzero 0.3.10__py3-none-any.whl → 0.3.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_identical.py +2 -3
- tests/test_opts.py +140 -100
- tests/test_tensorlist.py +8 -7
- tests/test_vars.py +1 -0
- torchzero/__init__.py +1 -1
- torchzero/core/__init__.py +2 -2
- torchzero/core/module.py +335 -50
- torchzero/core/reformulation.py +65 -0
- torchzero/core/transform.py +197 -70
- torchzero/modules/__init__.py +13 -4
- torchzero/modules/adaptive/__init__.py +30 -0
- torchzero/modules/adaptive/adagrad.py +356 -0
- torchzero/modules/adaptive/adahessian.py +224 -0
- torchzero/modules/{optimizers → adaptive}/adam.py +6 -8
- torchzero/modules/adaptive/adan.py +96 -0
- torchzero/modules/adaptive/adaptive_heavyball.py +54 -0
- torchzero/modules/adaptive/aegd.py +54 -0
- torchzero/modules/adaptive/esgd.py +171 -0
- torchzero/modules/{optimizers → adaptive}/lion.py +1 -1
- torchzero/modules/{experimental/spectral.py → adaptive/lmadagrad.py} +94 -71
- torchzero/modules/adaptive/mars.py +79 -0
- torchzero/modules/adaptive/matrix_momentum.py +146 -0
- torchzero/modules/adaptive/msam.py +188 -0
- torchzero/modules/{optimizers → adaptive}/muon.py +29 -5
- torchzero/modules/adaptive/natural_gradient.py +175 -0
- torchzero/modules/{optimizers → adaptive}/orthograd.py +1 -1
- torchzero/modules/{optimizers → adaptive}/rmsprop.py +7 -4
- torchzero/modules/{optimizers → adaptive}/rprop.py +42 -10
- torchzero/modules/adaptive/sam.py +163 -0
- torchzero/modules/{optimizers → adaptive}/shampoo.py +47 -9
- torchzero/modules/{optimizers → adaptive}/soap.py +52 -65
- torchzero/modules/adaptive/sophia_h.py +185 -0
- torchzero/modules/clipping/clipping.py +115 -25
- torchzero/modules/clipping/ema_clipping.py +31 -17
- torchzero/modules/clipping/growth_clipping.py +8 -7
- torchzero/modules/conjugate_gradient/__init__.py +11 -0
- torchzero/modules/conjugate_gradient/cg.py +355 -0
- torchzero/modules/experimental/__init__.py +13 -19
- torchzero/modules/{projections → experimental}/dct.py +11 -11
- torchzero/modules/{projections → experimental}/fft.py +10 -10
- torchzero/modules/experimental/gradmin.py +4 -3
- torchzero/modules/experimental/l_infinity.py +111 -0
- torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +5 -42
- torchzero/modules/experimental/newton_solver.py +79 -17
- torchzero/modules/experimental/newtonnewton.py +32 -15
- torchzero/modules/experimental/reduce_outward_lr.py +4 -4
- torchzero/modules/experimental/scipy_newton_cg.py +105 -0
- torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +13 -55
- torchzero/modules/functional.py +52 -6
- torchzero/modules/grad_approximation/fdm.py +30 -4
- torchzero/modules/grad_approximation/forward_gradient.py +16 -4
- torchzero/modules/grad_approximation/grad_approximator.py +51 -10
- torchzero/modules/grad_approximation/rfdm.py +321 -52
- torchzero/modules/higher_order/__init__.py +1 -1
- torchzero/modules/higher_order/higher_order_newton.py +164 -93
- torchzero/modules/least_squares/__init__.py +1 -0
- torchzero/modules/least_squares/gn.py +161 -0
- torchzero/modules/line_search/__init__.py +4 -4
- torchzero/modules/line_search/_polyinterp.py +289 -0
- torchzero/modules/line_search/adaptive.py +124 -0
- torchzero/modules/line_search/backtracking.py +95 -57
- torchzero/modules/line_search/line_search.py +171 -22
- torchzero/modules/line_search/scipy.py +3 -3
- torchzero/modules/line_search/strong_wolfe.py +327 -199
- torchzero/modules/misc/__init__.py +35 -0
- torchzero/modules/misc/debug.py +48 -0
- torchzero/modules/misc/escape.py +62 -0
- torchzero/modules/misc/gradient_accumulation.py +136 -0
- torchzero/modules/misc/homotopy.py +59 -0
- torchzero/modules/misc/misc.py +383 -0
- torchzero/modules/misc/multistep.py +194 -0
- torchzero/modules/misc/regularization.py +167 -0
- torchzero/modules/misc/split.py +123 -0
- torchzero/modules/{ops → misc}/switch.py +45 -4
- torchzero/modules/momentum/__init__.py +1 -5
- torchzero/modules/momentum/averaging.py +9 -9
- torchzero/modules/momentum/cautious.py +51 -19
- torchzero/modules/momentum/momentum.py +37 -2
- torchzero/modules/ops/__init__.py +11 -31
- torchzero/modules/ops/accumulate.py +6 -10
- torchzero/modules/ops/binary.py +81 -34
- torchzero/modules/{momentum/ema.py → ops/higher_level.py} +16 -39
- torchzero/modules/ops/multi.py +82 -21
- torchzero/modules/ops/reduce.py +16 -8
- torchzero/modules/ops/unary.py +29 -13
- torchzero/modules/ops/utility.py +30 -18
- torchzero/modules/projections/__init__.py +2 -4
- torchzero/modules/projections/cast.py +51 -0
- torchzero/modules/projections/galore.py +3 -1
- torchzero/modules/projections/projection.py +190 -96
- torchzero/modules/quasi_newton/__init__.py +9 -14
- torchzero/modules/quasi_newton/damping.py +105 -0
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -0
- torchzero/modules/quasi_newton/lbfgs.py +286 -173
- torchzero/modules/quasi_newton/lsr1.py +185 -106
- torchzero/modules/quasi_newton/quasi_newton.py +816 -268
- torchzero/modules/restarts/__init__.py +7 -0
- torchzero/modules/restarts/restars.py +252 -0
- torchzero/modules/second_order/__init__.py +3 -2
- torchzero/modules/second_order/multipoint.py +238 -0
- torchzero/modules/second_order/newton.py +292 -68
- torchzero/modules/second_order/newton_cg.py +365 -15
- torchzero/modules/second_order/nystrom.py +104 -1
- torchzero/modules/smoothing/__init__.py +1 -1
- torchzero/modules/smoothing/laplacian.py +14 -4
- torchzero/modules/smoothing/sampling.py +300 -0
- torchzero/modules/step_size/__init__.py +2 -0
- torchzero/modules/step_size/adaptive.py +387 -0
- torchzero/modules/step_size/lr.py +154 -0
- torchzero/modules/termination/__init__.py +14 -0
- torchzero/modules/termination/termination.py +207 -0
- torchzero/modules/trust_region/__init__.py +5 -0
- torchzero/modules/trust_region/cubic_regularization.py +170 -0
- torchzero/modules/trust_region/dogleg.py +92 -0
- torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
- torchzero/modules/trust_region/trust_cg.py +97 -0
- torchzero/modules/trust_region/trust_region.py +350 -0
- torchzero/modules/variance_reduction/__init__.py +1 -0
- torchzero/modules/variance_reduction/svrg.py +208 -0
- torchzero/modules/weight_decay/__init__.py +1 -1
- torchzero/modules/weight_decay/weight_decay.py +94 -11
- torchzero/modules/wrappers/optim_wrapper.py +29 -1
- torchzero/modules/zeroth_order/__init__.py +1 -0
- torchzero/modules/zeroth_order/cd.py +359 -0
- torchzero/optim/root.py +65 -0
- torchzero/optim/utility/split.py +8 -8
- torchzero/optim/wrappers/directsearch.py +39 -3
- torchzero/optim/wrappers/fcmaes.py +24 -15
- torchzero/optim/wrappers/mads.py +5 -6
- torchzero/optim/wrappers/nevergrad.py +16 -1
- torchzero/optim/wrappers/nlopt.py +0 -2
- torchzero/optim/wrappers/optuna.py +3 -3
- torchzero/optim/wrappers/scipy.py +86 -25
- torchzero/utils/__init__.py +40 -4
- torchzero/utils/compile.py +1 -1
- torchzero/utils/derivatives.py +126 -114
- torchzero/utils/linalg/__init__.py +9 -2
- torchzero/utils/linalg/linear_operator.py +329 -0
- torchzero/utils/linalg/matrix_funcs.py +2 -2
- torchzero/utils/linalg/orthogonalize.py +2 -1
- torchzero/utils/linalg/qr.py +2 -2
- torchzero/utils/linalg/solve.py +369 -58
- torchzero/utils/metrics.py +83 -0
- torchzero/utils/numberlist.py +2 -0
- torchzero/utils/python_tools.py +16 -0
- torchzero/utils/tensorlist.py +134 -51
- torchzero/utils/torch_tools.py +9 -4
- torchzero-0.3.13.dist-info/METADATA +14 -0
- torchzero-0.3.13.dist-info/RECORD +166 -0
- {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/top_level.txt +0 -1
- docs/source/conf.py +0 -57
- torchzero/modules/experimental/absoap.py +0 -250
- torchzero/modules/experimental/adadam.py +0 -112
- torchzero/modules/experimental/adamY.py +0 -125
- torchzero/modules/experimental/adasoap.py +0 -172
- torchzero/modules/experimental/diagonal_higher_order_newton.py +0 -225
- torchzero/modules/experimental/eigendescent.py +0 -117
- torchzero/modules/experimental/etf.py +0 -172
- torchzero/modules/experimental/soapy.py +0 -163
- torchzero/modules/experimental/structured_newton.py +0 -111
- torchzero/modules/experimental/subspace_preconditioners.py +0 -138
- torchzero/modules/experimental/tada.py +0 -38
- torchzero/modules/line_search/trust_region.py +0 -73
- torchzero/modules/lr/__init__.py +0 -2
- torchzero/modules/lr/adaptive.py +0 -93
- torchzero/modules/lr/lr.py +0 -63
- torchzero/modules/momentum/matrix_momentum.py +0 -166
- torchzero/modules/ops/debug.py +0 -25
- torchzero/modules/ops/misc.py +0 -418
- torchzero/modules/ops/split.py +0 -75
- torchzero/modules/optimizers/__init__.py +0 -18
- torchzero/modules/optimizers/adagrad.py +0 -155
- torchzero/modules/optimizers/sophia_h.py +0 -129
- torchzero/modules/quasi_newton/cg.py +0 -268
- torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
- torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +0 -266
- torchzero/modules/quasi_newton/olbfgs.py +0 -196
- torchzero/modules/smoothing/gaussian.py +0 -164
- torchzero-0.3.10.dist-info/METADATA +0 -379
- torchzero-0.3.10.dist-info/RECORD +0 -139
- torchzero-0.3.10.dist-info/licenses/LICENSE +0 -21
- {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/WHEEL +0 -0
torchzero/modules/quasi_newton/diagonal_quasi_newton.py (new file)
@@ -0,0 +1,167 @@
+from typing import Any, Literal
+from collections.abc import Callable
+
+import torch
+from ...core import Chainable
+from .quasi_newton import (
+    HessianUpdateStrategy,
+    _HessianUpdateStrategyDefaults,
+    _InverseHessianUpdateStrategyDefaults,
+)
+
+from ..functional import safe_clip
+
+
+def diagonal_bfgs_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
+    sy = s.dot(y)
+    if sy < tol: return H
+
+    sy_sq = safe_clip(sy**2)
+
+    num1 = (sy + (y * H * y)) * s*s
+    term1 = num1.div_(sy_sq)
+    num2 = (H * y * s).add_(s * y * H)
+    term2 = num2.div_(sy)
+    H += term1.sub_(term2)
+    return H
+
+class DiagonalBFGS(_InverseHessianUpdateStrategyDefaults):
+    """Diagonal BFGS. This is simply BFGS with only the diagonal being updated and used. It doesn't satisfy the secant equation but may still be useful."""
+    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
+        return diagonal_bfgs_H_(H=H, s=s, y=y, tol=setting['tol'])
+
+    def initialize_P(self, size:int, device, dtype, is_inverse:bool): return torch.ones(size, device=device, dtype=dtype)
+
+def diagonal_sr1_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol:float):
+    z = s - H*y
+    denom = z.dot(y)
+
+    z_norm = torch.linalg.norm(z) # pylint:disable=not-callable
+    y_norm = torch.linalg.norm(y) # pylint:disable=not-callable
+
+    # if y_norm*z_norm < tol: return H
+
+    # check as in Nocedal, Wright. “Numerical optimization” 2nd p.146
+    if denom.abs() <= tol * y_norm * z_norm: return H # pylint:disable=not-callable
+    H += (z*z).div_(safe_clip(denom))
+    return H
+class DiagonalSR1(_InverseHessianUpdateStrategyDefaults):
+    """Diagonal SR1. This is simply SR1 with only the diagonal being updated and used. It doesn't satisfy the secant equation but may still be useful."""
+    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
+        return diagonal_sr1_(H=H, s=s, y=y, tol=setting['tol'])
+    def update_B(self, B, s, y, p, g, p_prev, g_prev, state, setting):
+        return diagonal_sr1_(H=B, s=y, y=s, tol=setting['tol'])
+
+    def initialize_P(self, size:int, device, dtype, is_inverse:bool): return torch.ones(size, device=device, dtype=dtype)
+
+
+
+# Zhu M., Nazareth J. L., Wolkowicz H. The quasi-Cauchy relation and diagonal updating //SIAM Journal on Optimization. – 1999. – Т. 9. – №. 4. – С. 1192-1204.
+def diagonal_qc_B_(B:torch.Tensor, s: torch.Tensor, y:torch.Tensor):
+    denom = safe_clip((s**4).sum())
+    num = s.dot(y) - (s*B).dot(s)
+    B += s**2 * (num/denom)
+    return B
+
+class DiagonalQuasiCauchi(_HessianUpdateStrategyDefaults):
+    """Diagonal quasi-cauchi method.
+
+    Reference:
+        Zhu M., Nazareth J. L., Wolkowicz H. The quasi-Cauchy relation and diagonal updating //SIAM Journal on Optimization. – 1999. – Т. 9. – №. 4. – С. 1192-1204.
+    """
+    def update_B(self, B, s, y, p, g, p_prev, g_prev, state, setting):
+        return diagonal_qc_B_(B=B, s=s, y=y)
+
+    def initialize_P(self, size:int, device, dtype, is_inverse:bool): return torch.ones(size, device=device, dtype=dtype)
+
+# Leong, Wah June, Sharareh Enshaei, and Sie Long Kek. "Diagonal quasi-Newton methods via least change updating principle with weighted Frobenius norm." Numerical Algorithms 86 (2021): 1225-1241.
+def diagonal_wqc_B_(B:torch.Tensor, s: torch.Tensor, y:torch.Tensor):
+    E_sq = s**2 * B**2
+    denom = safe_clip((s*E_sq).dot(s))
+    num = s.dot(y) - (s*B).dot(s)
+    B += E_sq * (num/denom)
+    return B
+
+class DiagonalWeightedQuasiCauchi(_HessianUpdateStrategyDefaults):
+    """Diagonal quasi-cauchi method.
+
+    Reference:
+        Leong, Wah June, Sharareh Enshaei, and Sie Long Kek. "Diagonal quasi-Newton methods via least change updating principle with weighted Frobenius norm." Numerical Algorithms 86 (2021): 1225-1241.
+    """
+    def update_B(self, B, s, y, p, g, p_prev, g_prev, state, setting):
+        return diagonal_wqc_B_(B=B, s=s, y=y)
+
+    def initialize_P(self, size:int, device, dtype, is_inverse:bool): return torch.ones(size, device=device, dtype=dtype)
+
+def _truncate(B: torch.Tensor, lb, ub):
+    return torch.where((B>lb).logical_and(B<ub), B, 1)
+
+# Andrei, Neculai. "A diagonal quasi-Newton updating method for unconstrained optimization." Numerical Algorithms 81.2 (2019): 575-590.
+def dnrtr_B_(B:torch.Tensor, s: torch.Tensor, y:torch.Tensor):
+    denom = safe_clip((s**4).sum())
+    num = s.dot(y) + s.dot(s) - (s*B).dot(s)
+    B += s**2 * (num/denom) - 1
+    return B
+
+class DNRTR(HessianUpdateStrategy):
+    """Diagonal quasi-newton method.
+
+    Reference:
+        Andrei, Neculai. "A diagonal quasi-Newton updating method for unconstrained optimization." Numerical Algorithms 81.2 (2019): 575-590.
+    """
+    def __init__(
+        self,
+        lb: float = 1e-2,
+        ub: float = 1e5,
+        init_scale: float | Literal["auto"] = "auto",
+        tol: float = 1e-32,
+        ptol: float | None = 1e-32,
+        ptol_restart: bool = False,
+        gtol: float | None = 1e-32,
+        restart_interval: int | None | Literal['auto'] = None,
+        beta: float | None = None,
+        update_freq: int = 1,
+        scale_first: bool = False,
+        concat_params: bool = True,
+        inner: Chainable | None = None,
+    ):
+        defaults = dict(lb=lb, ub=ub)
+        super().__init__(
+            defaults=defaults,
+            init_scale=init_scale,
+            tol=tol,
+            ptol=ptol,
+            ptol_restart=ptol_restart,
+            gtol=gtol,
+            restart_interval=restart_interval,
+            beta=beta,
+            update_freq=update_freq,
+            scale_first=scale_first,
+            concat_params=concat_params,
+            inverse=False,
+            inner=inner,
+        )
+
+    def update_B(self, B, s, y, p, g, p_prev, g_prev, state, setting):
+        return diagonal_wqc_B_(B=B, s=s, y=y)
+
+    def modify_B(self, B, state, setting):
+        return _truncate(B, setting['lb'], setting['ub'])
+
+    def initialize_P(self, size:int, device, dtype, is_inverse:bool): return torch.ones(size, device=device, dtype=dtype)
+
+# Nosrati, Mahsa, and Keyvan Amini. "A new diagonal quasi-Newton algorithm for unconstrained optimization problems." Applications of Mathematics 69.4 (2024): 501-512.
+def new_dqn_B_(B:torch.Tensor, s: torch.Tensor, y:torch.Tensor):
+    denom = safe_clip((s**4).sum())
+    num = s.dot(y)
+    B += s**2 * (num/denom)
+    return B
+
+class NewDQN(DNRTR):
+    """Diagonal quasi-newton method.
+
+    Reference:
+        Nosrati, Mahsa, and Keyvan Amini. "A new diagonal quasi-Newton algorithm for unconstrained optimization problems." Applications of Mathematics 69.4 (2024): 501-512.
+    """
+    def update_B(self, B, s, y, p, g, p_prev, g_prev, state, setting):
+        return new_dqn_B_(B=B, s=s, y=y)
torchzero/modules/quasi_newton/lbfgs.py
@@ -1,162 +1,257 @@
 from collections import deque
-from
+from collections.abc import Sequence
+from typing import overload
+
 import torch

-from ...core import
-from ...utils import TensorList, as_tensorlist,
+from ...core import Chainable, Transform
+from ...utils import TensorList, as_tensorlist, unpack_states
+from ...utils.linalg.linear_operator import LinearOperator
+from ..functional import initial_step_size
+from .damping import DampingStrategyType, apply_damping
+
+
+@torch.no_grad
+def _make_M(S:torch.Tensor, Y:torch.Tensor, B_0:torch.Tensor):
+    m,n = S.size()
+
+    M = torch.zeros((2 * m, 2 * m), device=S.device, dtype=S.dtype)
+
+    # top-left is B S^T S
+    M[:m, :m] = B_0 * S @ S.mT
+
+    # anti-diagonal is L^T and L
+    L = (S @ Y.mT).tril_(-1)
+
+    M[m:, :m] = L.mT
+    M[:m, m:] = L
+
+    # bottom-right
+    D_diag = (S * Y).sum(1).neg()
+    M[m:, m:] = D_diag.diag_embed()
+
+    return M
+
+
+@torch.no_grad
+def lbfgs_Bx(x: torch.Tensor, S: torch.Tensor, Y: torch.Tensor, sy_history, M=None):
+    """L-BFGS hessian-vector product based on compact representation,
+    returns (Bx, M), where M is an internal matrix that depends on S and Y so it can be reused."""
+    m = len(S)
+    if m == 0: return x.clone()
+
+    # initial scaling
+    y = Y[-1]
+    sy = sy_history[-1]
+    yy = y.dot(y)
+    B_0 = yy / sy
+    Bx = x * B_0
+
+    Psi = torch.zeros(2 * m, device=x.device, dtype=x.dtype)
+    Psi[:m] = B_0 * S@x
+    Psi[m:] = Y@x
+
+    if M is None: M = _make_M(S, Y, B_0)
+
+    # solve Mu = p
+    u, info = torch.linalg.solve_ex(M, Psi) # pylint:disable=not-callable
+    if info != 0:
+        return Bx
+
+    # Bx
+    u_S = u[:m]
+    u_Y = u[m:]
+    SuS = (S * u_S.unsqueeze(-1)).sum(0)
+    YuY = (Y * u_Y.unsqueeze(-1)).sum(0)
+    return Bx - (B_0 * SuS + YuY), M
+
+
+@overload
+def lbfgs_Hx(
+    x: torch.Tensor,
+    s_history: Sequence[torch.Tensor] | torch.Tensor,
+    y_history: Sequence[torch.Tensor] | torch.Tensor,
+    sy_history: Sequence[torch.Tensor] | torch.Tensor,
+) -> torch.Tensor: ...
+@overload
+def lbfgs_Hx(
+    x: TensorList,
+    s_history: Sequence[TensorList],
+    y_history: Sequence[TensorList],
+    sy_history: Sequence[torch.Tensor] | torch.Tensor,
+) -> TensorList: ...
+def lbfgs_Hx(
+    x,
+    s_history: Sequence | torch.Tensor,
+    y_history: Sequence | torch.Tensor,
+    sy_history: Sequence[torch.Tensor] | torch.Tensor,
+):
+    """L-BFGS inverse-hessian-vector product, works with tensors and TensorLists"""
+    x = x.clone()
+    if len(s_history) == 0: return x
+
+    # 1st loop
+    alpha_list = []
+    for s_i, y_i, sy_i in zip(reversed(s_history), reversed(y_history), reversed(sy_history)):
+        p_i = 1 / sy_i
+        alpha = p_i * s_i.dot(x)
+        alpha_list.append(alpha)
+        x.sub_(y_i, alpha=alpha)
+
+    # scaled initial hessian inverse
+    # H_0 = (s.y/y.y) * I, and z = H_0 @ q
+    sy = sy_history[-1]
+    y = y_history[-1]
+    Hx = x * (sy / y.dot(y))
+
+    # 2nd loop
+    for s_i, y_i, sy_i, alpha_i in zip(s_history, y_history, sy_history, reversed(alpha_list)):
+        p_i = 1 / sy_i
+        beta_i = p_i * y_i.dot(Hx)
+        Hx.add_(s_i, alpha = alpha_i - beta_i)
+
+    return Hx
+
+
+class LBFGSLinearOperator(LinearOperator):
+    def __init__(self, s_history: Sequence[torch.Tensor] | torch.Tensor, y_history: Sequence[torch.Tensor] | torch.Tensor, sy_history: Sequence[torch.Tensor] | torch.Tensor):
+        super().__init__()
+        if len(s_history) == 0:
+            self.S = self.Y = self.yy = None
+        else:
+            self.S = s_history
+            self.Y = y_history

+        self.sy_history = sy_history
+        self.M = None

-def
-
-
-
-
-    eigval_bounds = (0.01, 1.5)
-):
-    # adaptive damping Al-Baali, M.: Quasi-Wolfe conditions for quasi-Newton methods for large-scale optimization. In: 40th Workshop on Large Scale Nonlinear Optimization, Erice, Italy, June 22–July 1 (2004)
-    sigma_l, sigma_h = eigval_bounds
-    u = ys_k / s_k.dot(s_k)
-    if u <= sigma_l < 1: tau = min((1-sigma_l)/(1-u), init_damping)
-    elif u >= sigma_h > 1: tau = min((sigma_h-1)/(u-1), init_damping)
-    else: tau = init_damping
-    y_k = tau * y_k + (1-tau) * s_k
-    ys_k = s_k.dot(y_k)
-
-    return s_k, y_k, ys_k
-
-def lbfgs(
-    tensors_: TensorList,
-    s_history: deque[TensorList],
-    y_history: deque[TensorList],
-    sy_history: deque[torch.Tensor],
-    y_k: TensorList | None,
-    ys_k: torch.Tensor | None,
-    z_beta: float | None,
-    z_ema: TensorList | None,
-    step: int,
-):
-    if len(s_history) == 0 or y_k is None or ys_k is None:
-
-        # initial step size guess modified from pytorch L-BFGS
-        scale_factor = 1 / TensorList(tensors_).abs().global_sum().clip(min=1)
-        scale_factor = scale_factor.clip(min=torch.finfo(tensors_[0].dtype).eps)
-        return tensors_.mul_(scale_factor)
-
-    else:
-        # 1st loop
-        alpha_list = []
-        q = tensors_.clone()
-        for s_i, y_i, ys_i in zip(reversed(s_history), reversed(y_history), reversed(sy_history)):
-            p_i = 1 / ys_i # this is also denoted as ρ (rho)
-            alpha = p_i * s_i.dot(q)
-            alpha_list.append(alpha)
-            q.sub_(y_i, alpha=alpha) # pyright: ignore[reportArgumentType]
-
-        # calculate z
-        # s.y/y.y is also this weird y-looking symbol I couldn't find
-        # z is it times q
-        # actually H0 = (s.y/y.y) * I, and z = H0 @ q
-        z = q * (ys_k / (y_k.dot(y_k)))
-
-        # an attempt into adding momentum, lerping initial z seems stable compared to other variables
-        if z_beta is not None:
-            assert z_ema is not None
-            if step == 0: z_ema.copy_(z)
-            else: z_ema.lerp(z, 1-z_beta)
-            z = z_ema
-
-        # 2nd loop
-        for s_i, y_i, ys_i, alpha_i in zip(s_history, y_history, sy_history, reversed(alpha_list)):
-            p_i = 1 / ys_i
-            beta_i = p_i * y_i.dot(z)
-            z.add_(s_i, alpha = alpha_i - beta_i)
-
-        return z
-
-def _lerp_params_update_(
-    self_: Module,
-    params: list[torch.Tensor],
-    update: list[torch.Tensor],
-    params_beta: list[float | None],
-    grads_beta: list[float | None],
-):
-    for i, (p, u, p_beta, u_beta) in enumerate(zip(params.copy(), update.copy(), params_beta, grads_beta)):
-        if p_beta is not None or u_beta is not None:
-            state = self_.state[p]
+    def _get_S(self):
+        if self.S is None: return None
+        if not isinstance(self.S, torch.Tensor):
+            self.S = torch.stack(tuple(self.S))
+        return self.S

-
-
-
-
+    def _get_Y(self):
+        if self.Y is None: return None
+        if not isinstance(self.Y, torch.Tensor):
+            self.Y = torch.stack(tuple(self.Y))
+        return self.Y

-
-
-
-
+    def solve(self, b):
+        S = self._get_S(); Y = self._get_Y()
+        if S is None or Y is None: return b.clone()
+        return lbfgs_Hx(b, S, Y, self.sy_history)

-
+    def matvec(self, x):
+        S = self._get_S(); Y = self._get_Y()
+        if S is None or Y is None: return x.clone()
+        Bx, self.M = lbfgs_Bx(x, S, Y, self.sy_history, M=self.M)
+        return Bx

-
-
+    def size(self):
+        if self.S is None: raise RuntimeError()
+        n = len(self.S[0])
+        return (n, n)
+
+
+class LBFGS(Transform):
+    """Limited-memory BFGS algorithm. A line search or trust region is recommended.

     Args:
-        history_size (int, optional):
-
-
-
-
-
-
-
-
-
-
-
-
+        history_size (int, optional):
+            number of past parameter differences and gradient differences to store. Defaults to 10.
+        ptol (float | None, optional):
+            skips updating the history if maximum absolute value of
+            parameter difference is less than this value. Defaults to 1e-10.
+        ptol_restart (bool, optional):
+            If true, whenever parameter difference is less then ``ptol``,
+            L-BFGS state will be reset. Defaults to None.
+        gtol (float | None, optional):
+            skips updating the history if if maximum absolute value of
+            gradient difference is less than this value. Defaults to 1e-10.
+        ptol_restart (bool, optional):
+            If true, whenever gradient difference is less then ``gtol``,
+            L-BFGS state will be reset. Defaults to None.
+        sy_tol (float | None, optional):
+            history will not be updated whenever s⋅y is less than this value (negative s⋅y means negative curvature)
+        scale_first (bool, optional):
+            makes first step, when hessian approximation is not available,
+            small to reduce number of line search iterations. Defaults to True.
         update_freq (int, optional):
-            how often to update L-BFGS history. Defaults to 1.
-
-
-        tol_reset (bool, optional):
-            If true, whenever gradient difference is less then `tol`, the history will be reset. Defaults to None.
+            how often to update L-BFGS history. Larger values may be better for stochastic optimization. Defaults to 1.
+        damping (DampingStrategyType, optional):
+            damping to use, can be "powell" or "double". Defaults to None.
         inner (Chainable | None, optional):
             optional inner modules applied after updating L-BFGS history and before preconditioning. Defaults to None.
+
+    ## Examples:
+
+    L-BFGS with line search
+    ```python
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.LBFGS(100),
+        tz.m.Backtracking()
+    )
+    ```
+
+    L-BFGS with trust region
+    ```python
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.TrustCG(tz.m.LBFGS())
+    )
+    ```
     """
     def __init__(
         self,
         history_size=10,
-
-
-
-
-
-
+        ptol: float | None = 1e-32,
+        ptol_restart: bool = False,
+        gtol: float | None = 1e-32,
+        gtol_restart: bool = False,
+        sy_tol: float = 1e-32,
+        scale_first:bool=True,
         update_freq = 1,
-
-        tol_reset: bool = False,
+        damping: DampingStrategyType = None,
         inner: Chainable | None = None,
     ):
-        defaults = dict(
-
+        defaults = dict(
+            history_size=history_size,
+            scale_first=scale_first,
+            ptol=ptol,
+            gtol=gtol,
+            ptol_restart=ptol_restart,
+            gtol_restart=gtol_restart,
+            sy_tol=sy_tol,
+            damping = damping,
+        )
+        super().__init__(defaults, uses_grad=False, inner=inner, update_freq=update_freq)

         self.global_state['s_history'] = deque(maxlen=history_size)
         self.global_state['y_history'] = deque(maxlen=history_size)
         self.global_state['sy_history'] = deque(maxlen=history_size)

-
-        self.set_child('inner', inner)
-
-    def reset(self):
+    def _reset_self(self):
         self.state.clear()
         self.global_state['step'] = 0
         self.global_state['s_history'].clear()
         self.global_state['y_history'].clear()
         self.global_state['sy_history'].clear()

+    def reset(self):
+        self._reset_self()
+        for c in self.children.values(): c.reset()
+
+    def reset_for_online(self):
+        super().reset_for_online()
+        self.clear_state_keys('p_prev', 'g_prev')
+        self.global_state.pop('step', None)
+
     @torch.no_grad
-    def
-
-
+    def update_tensors(self, tensors, params, grads, loss, states, settings):
+        p = as_tensorlist(params)
+        g = as_tensorlist(tensors)
         step = self.global_state.get('step', 0)
         self.global_state['step'] = step + 1

@@ -165,65 +260,83 @@ class LBFGS(Module):
         y_history: deque[TensorList] = self.global_state['y_history']
         sy_history: deque[torch.Tensor] = self.global_state['sy_history']

-
-
-
+        ptol = self.defaults['ptol']
+        gtol = self.defaults['gtol']
+        ptol_restart = self.defaults['ptol_restart']
+        gtol_restart = self.defaults['gtol_restart']
+        sy_tol = self.defaults['sy_tol']
+        damping = self.defaults['damping']

-
-        prev_l_params, prev_l_grad = self.get_state(params, 'prev_l_params', 'prev_l_grad', cls=TensorList)
+        p_prev, g_prev = unpack_states(states, tensors, 'p_prev', 'g_prev', cls=TensorList)

-        # 1st step - there are no previous params and grads,
+        # 1st step - there are no previous params and grads, lbfgs will do normalized SGD step
         if step == 0:
-
+            s = None; y = None; sy = None
         else:
-
-
-
+            s = p - p_prev
+            y = g - g_prev
+
+            if damping is not None:
+                s, y = apply_damping(damping, s=s, y=y, g=g, H=self.get_H())
+
+            sy = s.dot(y)
+            # damping to be added here
+
+        below_tol = False
+        # tolerance on parameter difference to avoid exploding after converging
+        if ptol is not None:
+            if s is not None and s.abs().global_max() <= ptol:
+                if ptol_restart:
+                    self._reset_self()
+                sy = None
+                below_tol = True
+
+        # tolerance on gradient difference to avoid exploding when there is no curvature
+        if gtol is not None:
+            if y is not None and y.abs().global_max() <= gtol:
+                if gtol_restart: self._reset_self()
+                sy = None
+                below_tol = True
+
+        # store previous params and grads
+        if not below_tol:
+            p_prev.copy_(p)
+            g_prev.copy_(g)

-
-
+        # update effective preconditioning state
+        if sy is not None and sy > sy_tol:
+            assert s is not None and y is not None and sy is not None

-
-
+            s_history.append(s)
+            y_history.append(y)
+            sy_history.append(sy)

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        var.update = update # may have been updated by inner module, probably makes sense to use it here?
-        if tol_reset: self.reset()
-        return var
-
-        # lerp initial H^-1 @ q guess
-        z_ema = None
-        if z_beta is not None:
-            z_ema = self.get_state(var.params, 'z_ema', cls=TensorList)
+    def get_H(self, var=...):
+        s_history = [tl.to_vec() for tl in self.global_state['s_history']]
+        y_history = [tl.to_vec() for tl in self.global_state['y_history']]
+        sy_history = self.global_state['sy_history']
+        return LBFGSLinearOperator(s_history, y_history, sy_history)
+
+    @torch.no_grad
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        scale_first = self.defaults['scale_first']
+
+        tensors = as_tensorlist(tensors)
+
+        s_history = self.global_state['s_history']
+        y_history = self.global_state['y_history']
+        sy_history = self.global_state['sy_history']

         # precondition
-        dir =
-
+        dir = lbfgs_Hx(
+            x=tensors,
             s_history=s_history,
             y_history=y_history,
             sy_history=sy_history,
-            y_k=y_k,
-            ys_k=ys_k,
-            z_beta = z_beta,
-            z_ema = z_ema,
-            step=step
         )

-
-
-
+        # scale 1st step
+        if scale_first and self.global_state.get('step', 1) == 1:
+            dir *= initial_step_size(dir, eps=1e-7)

+        return dir