quantmllibrary 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- quantml/__init__.py +74 -0
- quantml/autograd.py +154 -0
- quantml/cli/__init__.py +10 -0
- quantml/cli/run_experiment.py +385 -0
- quantml/config/__init__.py +28 -0
- quantml/config/config.py +259 -0
- quantml/data/__init__.py +33 -0
- quantml/data/cache.py +149 -0
- quantml/data/feature_store.py +234 -0
- quantml/data/futures.py +254 -0
- quantml/data/loaders.py +236 -0
- quantml/data/memory_optimizer.py +234 -0
- quantml/data/validators.py +390 -0
- quantml/experiments/__init__.py +23 -0
- quantml/experiments/logger.py +208 -0
- quantml/experiments/results.py +158 -0
- quantml/experiments/tracker.py +223 -0
- quantml/features/__init__.py +25 -0
- quantml/features/base.py +104 -0
- quantml/features/gap_features.py +124 -0
- quantml/features/registry.py +138 -0
- quantml/features/volatility_features.py +140 -0
- quantml/features/volume_features.py +142 -0
- quantml/functional.py +37 -0
- quantml/models/__init__.py +27 -0
- quantml/models/attention.py +258 -0
- quantml/models/dropout.py +130 -0
- quantml/models/gru.py +319 -0
- quantml/models/linear.py +112 -0
- quantml/models/lstm.py +353 -0
- quantml/models/mlp.py +286 -0
- quantml/models/normalization.py +289 -0
- quantml/models/rnn.py +154 -0
- quantml/models/tcn.py +238 -0
- quantml/online.py +209 -0
- quantml/ops.py +1707 -0
- quantml/optim/__init__.py +42 -0
- quantml/optim/adafactor.py +206 -0
- quantml/optim/adagrad.py +157 -0
- quantml/optim/adam.py +267 -0
- quantml/optim/lookahead.py +97 -0
- quantml/optim/quant_optimizer.py +228 -0
- quantml/optim/radam.py +192 -0
- quantml/optim/rmsprop.py +203 -0
- quantml/optim/schedulers.py +286 -0
- quantml/optim/sgd.py +181 -0
- quantml/py.typed +0 -0
- quantml/streaming.py +175 -0
- quantml/tensor.py +462 -0
- quantml/time_series.py +447 -0
- quantml/training/__init__.py +135 -0
- quantml/training/alpha_eval.py +203 -0
- quantml/training/backtest.py +280 -0
- quantml/training/backtest_analysis.py +168 -0
- quantml/training/cv.py +106 -0
- quantml/training/data_loader.py +177 -0
- quantml/training/ensemble.py +84 -0
- quantml/training/feature_importance.py +135 -0
- quantml/training/features.py +364 -0
- quantml/training/futures_backtest.py +266 -0
- quantml/training/gradient_clipping.py +206 -0
- quantml/training/losses.py +248 -0
- quantml/training/lr_finder.py +127 -0
- quantml/training/metrics.py +376 -0
- quantml/training/regularization.py +89 -0
- quantml/training/trainer.py +239 -0
- quantml/training/walk_forward.py +190 -0
- quantml/utils/__init__.py +51 -0
- quantml/utils/gradient_check.py +274 -0
- quantml/utils/logging.py +181 -0
- quantml/utils/ops_cpu.py +231 -0
- quantml/utils/profiling.py +364 -0
- quantml/utils/reproducibility.py +220 -0
- quantml/utils/serialization.py +335 -0
- quantmllibrary-0.1.0.dist-info/METADATA +536 -0
- quantmllibrary-0.1.0.dist-info/RECORD +79 -0
- quantmllibrary-0.1.0.dist-info/WHEEL +5 -0
- quantmllibrary-0.1.0.dist-info/licenses/LICENSE +22 -0
- quantmllibrary-0.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,97 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Lookahead optimizer wrapper.
|
|
3
|
+
|
|
4
|
+
Lookahead is a wrapper that improves training stability by maintaining
|
|
5
|
+
a slow-moving average of parameters.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from typing import List, Optional, Any
|
|
9
|
+
from quantml.tensor import Tensor
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class Lookahead:
    """
    Lookahead optimizer wrapper.

    Wraps any base optimizer and maintains a slow-moving average of its
    parameters (Zhang et al. 2019, "Lookahead Optimizer: k steps forward,
    1 step back").  Every ``k`` fast steps, the slow weights are moved
    toward the current (fast) weights:

        slow <- slow + alpha * (fast - slow)

    Attributes:
        base_optimizer: The base optimizer to wrap
        k: Number of steps before updating slow weights
        alpha: Interpolation step size toward the fast weights
        slow_weights: Slow-moving parameter averages, keyed by id(param)

    Examples:
        >>> base_opt = Adam(lr=0.001)
        >>> optimizer = Lookahead(base_opt, k=5, alpha=0.5)
        >>> for param in model.parameters():
        >>>     optimizer.step(param)
    """

    def __init__(
        self,
        base_optimizer: Any,
        k: int = 5,
        alpha: float = 0.5
    ):
        """
        Initialize Lookahead optimizer.

        Args:
            base_optimizer: Base optimizer to wrap (SGD, Adam, etc.)
            k: Number of steps before updating slow weights
            alpha: Interpolation factor (0 < alpha < 1)
        """
        self.base_optimizer = base_optimizer
        self.k = k
        self.alpha = alpha
        self.slow_weights: dict = {}  # id(param) -> slow copy of its data
        self.step_count = 0

    def step(self, param: Optional[Tensor] = None):
        """Run one base-optimizer step; refresh slow weights every k steps."""
        self.base_optimizer.step(param)
        self.step_count += 1

        if self.step_count % self.k == 0:
            if param is not None:
                self._update_slow_weights(param)
            else:
                for p in self.base_optimizer.params:
                    self._update_slow_weights(p)

    def _update_slow_weights(self, param: Tensor):
        """Update the slow-moving average for ``param``.

        BUG FIX: the previous implementation ignored ``alpha`` for
        list/array parameters and simply overwrote the slow weights with
        the fast weights.  All parameter shapes are now interpolated
        elementwise with the Lookahead rule slow += alpha * (fast - slow).
        """
        param_id = id(param)

        if param_id not in self.slow_weights:
            # First sighting: slow weights start at the current fast weights.
            self.slow_weights[param_id] = self._clone(param.data)
        else:
            self.slow_weights[param_id] = self._interpolate(
                self.slow_weights[param_id], param.data
            )

    @staticmethod
    def _clone(data: Any) -> Any:
        """Best-effort copy so slow weights do not alias the fast weights."""
        return data.copy() if hasattr(data, 'copy') else data

    def _interpolate(self, slow: Any, fast: Any) -> Any:
        """Elementwise slow + alpha * (fast - slow); recurses into lists."""
        try:
            # Scalars and NumPy-style arrays support arithmetic directly.
            return slow + self.alpha * (fast - slow)
        except TypeError:
            # Plain (possibly nested) Python lists: interpolate per element.
            return [self._interpolate(s, f) for s, f in zip(slow, fast)]

    def sync_slow_weights(self):
        """Synchronize parameters with slow weights."""
        for param in self.base_optimizer.params:
            param_id = id(param)
            if param_id in self.slow_weights:
                param.data = self.slow_weights[param_id]

    def zero_grad(self, param: Optional[Tensor] = None):
        """Clear gradients via the base optimizer."""
        self.base_optimizer.zero_grad(param)

    def add_param_group(self, params: List[Tensor]):
        """Add a parameter group to optimize."""
        self.base_optimizer.add_param_group(params)
|
@@ -0,0 +1,228 @@
|
|
|
1
|
+
"""
|
|
2
|
+
QuantOptimizer - Custom optimizer for quant trading.
|
|
3
|
+
|
|
4
|
+
Adaptive learning rate based on market volatility and regime-aware parameter updates.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from typing import List, Optional, Dict, Any
|
|
8
|
+
from quantml.tensor import Tensor
|
|
9
|
+
from quantml import ops
|
|
10
|
+
|
|
11
|
+
# Try to import NumPy
|
|
12
|
+
try:
|
|
13
|
+
import numpy as np
|
|
14
|
+
HAS_NUMPY = True
|
|
15
|
+
except ImportError:
|
|
16
|
+
HAS_NUMPY = False
|
|
17
|
+
np = None
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class QuantOptimizer:
    """
    QuantOptimizer - Custom optimizer for quant trading.

    Features:
    - Adaptive learning rate based on market volatility
    - Regime-aware parameter updates
    - Per-feature learning rate scaling

    Attributes:
        lr: Base learning rate
        volatility_window: Window size for volatility calculation
        regime_threshold: Threshold for regime detection
        feature_lrs: Per-feature learning rate multipliers

    Examples:
        >>> optimizer = QuantOptimizer(lr=0.001, volatility_window=20)
        >>> for param in model.parameters():
        >>>     optimizer.step(param, market_volatility=0.02)
    """

    def __init__(
        self,
        params: Optional[List[Tensor]] = None,
        lr: float = 0.001,
        beta1: float = 0.9,
        beta2: float = 0.999,
        eps: float = 1e-8,
        weight_decay: float = 0.0,
        volatility_window: int = 20,
        regime_threshold: float = 0.5
    ):
        """
        Initialize QuantOptimizer.

        Args:
            params: Optional list of parameters to optimize
            lr: Base learning rate
            beta1: First moment decay rate
            beta2: Second moment decay rate
            eps: Small value for numerical stability
            weight_decay: Weight decay coefficient
            volatility_window: Window size for volatility calculation
            regime_threshold: Threshold for regime detection
        """
        self.params = params if params is not None else []
        self.lr = lr
        self.beta1 = beta1
        self.beta2 = beta2
        self.eps = eps
        self.weight_decay = weight_decay
        self.volatility_window = volatility_window
        self.regime_threshold = regime_threshold

        self.m: Dict[int, Any] = {}  # First-moment estimates keyed by id(param)
        self.v: Dict[int, Any] = {}  # Second-moment estimates keyed by id(param)
        self.step_count = 0
        self.volatility_history: List[float] = []
        self.feature_lrs: Dict[int, Any] = {}  # Per-parameter LR multipliers

    def step(self, param: Optional[Tensor] = None, market_volatility: Optional[float] = None):
        """
        Perform a single optimization step.

        Args:
            param: Optional single parameter to update
            market_volatility: Current market volatility (for adaptive LR)
        """
        # BUG FIX: advance the step counter *before* updating, and on every
        # call (not only full-batch calls).  Previously per-parameter calls
        # left step_count at 0, so the bias corrections 1 - beta**step_count
        # evaluated to 0 and the update divided by zero.
        self.step_count += 1

        if param is not None:
            self._update_param(param, market_volatility)
        else:
            for p in self.params:
                self._update_param(p, market_volatility)

        # Track recent volatility for regime detection.
        if market_volatility is not None:
            self.volatility_history.append(market_volatility)
            if len(self.volatility_history) > self.volatility_window:
                self.volatility_history.pop(0)

    def _update_param(self, param: Tensor, market_volatility: Optional[float] = None):
        """Update a single parameter with an Adam-style, volatility-adaptive step."""
        if not param.requires_grad:
            return

        if param.grad is None:
            return

        param_id = id(param)

        # Scale the base LR by the current volatility regime.
        adaptive_lr = self.lr
        if market_volatility is not None and len(self.volatility_history) > 1:
            avg_vol = sum(self.volatility_history) / len(self.volatility_history)
            if market_volatility > avg_vol * (1.0 + self.regime_threshold):
                # High-volatility regime - be more conservative.
                adaptive_lr = self.lr * 0.5
            elif market_volatility < avg_vol * (1.0 - self.regime_threshold):
                # Low-volatility regime - be more aggressive.
                adaptive_lr = self.lr * 1.5

        # Apply the per-parameter learning rate multiplier, if configured.
        if param_id in self.feature_lrs:
            adaptive_lr = adaptive_lr * self.feature_lrs[param_id]

        if np is not None:  # equivalent to the old HAS_NUMPY flag
            try:
                grad = param.grad
                if isinstance(grad, np.ndarray):
                    grad_arr = grad
                else:
                    grad_arr = np.array(grad, dtype=np.float64)

                param_arr = param.numpy if param.numpy is not None else np.array(param.data, dtype=np.float64)

                if param_id not in self.m:
                    self.m[param_id] = np.zeros_like(param_arr, dtype=np.float64)
                    self.v[param_id] = np.zeros_like(param_arr, dtype=np.float64)

                if self.weight_decay > 0:
                    grad_arr = grad_arr + self.weight_decay * param_arr

                # Adam-style moment updates (in place).
                m = self.m[param_id]
                v = self.v[param_id]
                m[:] = self.beta1 * m + (1.0 - self.beta1) * grad_arr
                v[:] = self.beta2 * v + (1.0 - self.beta2) * (grad_arr ** 2)

                # Bias correction; step_count >= 1 here (see step()), so both
                # denominators are strictly positive.
                bias_correction1 = 1.0 - (self.beta1 ** self.step_count)
                bias_correction2 = 1.0 - (self.beta2 ** self.step_count)
                m_hat = m / bias_correction1
                v_hat = v / bias_correction2

                # Update parameter with the adaptive LR.
                v_hat_sqrt = np.sqrt(v_hat) + self.eps
                update = m_hat / v_hat_sqrt
                param_update = adaptive_lr * update
                new_param_arr = param_arr - param_update
                param.data = new_param_arr

            except (ValueError, TypeError, AttributeError):
                self._update_param_fallback(param, adaptive_lr)
        else:
            self._update_param_fallback(param, adaptive_lr)

    def _update_param_fallback(self, param: Tensor, adaptive_lr: float):
        """Pure-Tensor Adam update used when NumPy is unavailable or fails."""
        if param.grad is None:
            return

        param_id = id(param)

        if param_id not in self.m:
            # Zero-initialise moments matching the parameter's 1-D/2-D shape.
            if isinstance(param.data[0], list):
                self.m[param_id] = Tensor([[0.0] * len(row) for row in param.data])
                self.v[param_id] = Tensor([[0.0] * len(row) for row in param.data])
            else:
                self.m[param_id] = Tensor([0.0] * len(param.data))
                self.v[param_id] = Tensor([0.0] * len(param.data))

        grad = Tensor(param.grad)
        if self.weight_decay > 0:
            grad = ops.add(grad, ops.mul(param, self.weight_decay))

        m_prev = self.m[param_id]
        m_new = ops.add(ops.mul(m_prev, self.beta1), ops.mul(grad, 1.0 - self.beta1))
        self.m[param_id] = m_new

        v_prev = self.v[param_id]
        grad_sq = ops.mul(grad, grad)
        v_new = ops.add(ops.mul(v_prev, self.beta2), ops.mul(grad_sq, 1.0 - self.beta2))
        self.v[param_id] = v_new

        # step_count >= 1 (incremented in step()), so corrections are nonzero.
        bias_correction1 = 1.0 - (self.beta1 ** self.step_count)
        bias_correction2 = 1.0 - (self.beta2 ** self.step_count)
        m_hat = ops.div(m_new, bias_correction1)
        v_hat = ops.div(v_new, bias_correction2)

        # NOTE(review): this path uses sqrt(v_hat + eps) while the NumPy path
        # uses sqrt(v_hat) + eps — numerically close but not identical.
        v_hat_sqrt = ops.pow(ops.add(v_hat, self.eps), 0.5)
        update = ops.div(m_hat, v_hat_sqrt)
        param_update = ops.mul(update, adaptive_lr)

        if param.requires_grad:
            # Detach so the in-place subtraction stays out of the graph.
            param_detached = param.detach()
            param_detached.sub_(param_update)
            param.data = param_detached.data
        else:
            param.sub_(param_update)

    def set_feature_lr(self, param: Tensor, multiplier: float):
        """Set learning rate multiplier for a specific parameter."""
        param_id = id(param)
        self.feature_lrs[param_id] = multiplier

    def zero_grad(self, param: Optional[Tensor] = None):
        """Clear gradients."""
        if param is not None:
            param.zero_grad()
        else:
            for p in self.params:
                p.zero_grad()

    def add_param_group(self, params: List[Tensor]):
        """Add a parameter group to optimize."""
        self.params.extend(params)
quantml/optim/radam.py
ADDED
|
@@ -0,0 +1,192 @@
|
|
|
1
|
+
"""
|
|
2
|
+
RAdam (Rectified Adam) optimizer implementation.
|
|
3
|
+
|
|
4
|
+
RAdam rectifies the variance of the adaptive learning rate in Adam.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from typing import List, Optional, Dict, Any
|
|
8
|
+
from quantml.tensor import Tensor
|
|
9
|
+
from quantml import ops
|
|
10
|
+
import math
|
|
11
|
+
|
|
12
|
+
# Try to import NumPy
|
|
13
|
+
try:
|
|
14
|
+
import numpy as np
|
|
15
|
+
HAS_NUMPY = True
|
|
16
|
+
except ImportError:
|
|
17
|
+
HAS_NUMPY = False
|
|
18
|
+
np = None
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class RAdam:
    """
    RAdam (Rectified Adam) optimizer.

    Rectifies the variance of the adaptive learning rate in Adam
    (Liu et al. 2020, "On the Variance of the Adaptive Learning Rate
    and Beyond"), providing better convergence for quant training.

    Attributes:
        lr: Learning rate
        betas: Tuple of (beta1, beta2) for moment estimates
        eps: Small value for numerical stability
        weight_decay: Weight decay coefficient
        m: First moment estimates
        v: Second moment estimates

    Examples:
        >>> optimizer = RAdam(lr=0.001, betas=(0.9, 0.999))
        >>> for param in model.parameters():
        >>>     optimizer.step(param)
    """

    def __init__(
        self,
        params: Optional[List[Tensor]] = None,
        lr: float = 0.001,
        betas: tuple = (0.9, 0.999),
        eps: float = 1e-8,
        weight_decay: float = 0.0
    ):
        """
        Initialize RAdam optimizer.

        Args:
            params: Optional list of parameters to optimize
            lr: Learning rate
            betas: Tuple of (beta1, beta2) for exponential decay rates
            eps: Small value to prevent division by zero
            weight_decay: Weight decay coefficient
        """
        self.params = params if params is not None else []
        self.lr = lr
        self.beta1, self.beta2 = betas
        self.eps = eps
        self.weight_decay = weight_decay
        self.m: Dict[int, Any] = {}  # First-moment estimates keyed by id(param)
        self.v: Dict[int, Any] = {}  # Second-moment estimates keyed by id(param)
        self.step_count = 0

    def step(self, param: Optional[Tensor] = None):
        """Perform a single optimization step.

        BUG FIX: the step counter is advanced *before* updating and on
        every call.  Previously the first update ran with step_count == 0,
        so beta2 ** step_count == 1 and the rectification term divided by
        zero (and per-parameter calls never advanced the counter at all).
        """
        self.step_count += 1
        if param is not None:
            self._update_param(param)
        else:
            for p in self.params:
                self._update_param(p)

    def _update_param(self, param: Tensor):
        """Update a single parameter using the RAdam algorithm."""
        if not param.requires_grad:
            return

        if param.grad is None:
            return

        param_id = id(param)

        if np is not None:  # equivalent to the old HAS_NUMPY flag
            try:
                grad = param.grad
                if isinstance(grad, np.ndarray):
                    grad_arr = grad
                else:
                    grad_arr = np.array(grad, dtype=np.float64)

                param_arr = param.numpy if param.numpy is not None else np.array(param.data, dtype=np.float64)

                if param_id not in self.m:
                    self.m[param_id] = np.zeros_like(param_arr, dtype=np.float64)
                    self.v[param_id] = np.zeros_like(param_arr, dtype=np.float64)

                if self.weight_decay > 0:
                    grad_arr = grad_arr + self.weight_decay * param_arr

                # Update moments (in place).
                m = self.m[param_id]
                v = self.v[param_id]
                m[:] = self.beta1 * m + (1.0 - self.beta1) * grad_arr
                v[:] = self.beta2 * v + (1.0 - self.beta2) * (grad_arr ** 2)

                # Variance rectification.  step_count >= 1 (see step()), so
                # 1 - beta2_t > 0 and rho_t is well defined.
                beta2_t = self.beta2 ** self.step_count
                rho_inf = 2.0 / (1.0 - self.beta2) - 1.0
                rho_t = rho_inf - 2.0 * self.step_count * beta2_t / (1.0 - beta2_t)

                if rho_t > 4.0:
                    # Variance is tractable: rectified adaptive update.
                    r_t = math.sqrt((rho_t - 4.0) * (rho_t - 2.0) * rho_inf / ((rho_inf - 4.0) * (rho_inf - 2.0) * rho_t))
                    m_hat = m / (1.0 - self.beta1 ** self.step_count)
                    v_hat = v / (1.0 - beta2_t)
                    update = r_t * m_hat / (np.sqrt(v_hat) + self.eps)
                else:
                    # Early steps: bias-corrected momentum (SGD-with-momentum-like).
                    m_hat = m / (1.0 - self.beta1 ** self.step_count)
                    update = m_hat

                param_update = self.lr * update
                new_param_arr = param_arr - param_update
                param.data = new_param_arr

            except (ValueError, TypeError, AttributeError):
                self._update_param_fallback(param)
        else:
            self._update_param_fallback(param)

    def _update_param_fallback(self, param: Tensor):
        """Pure-Tensor fallback: plain bias-corrected Adam (no rectification)."""
        if param.grad is None:
            return

        param_id = id(param)

        if param_id not in self.m:
            # Zero-initialise moments matching the parameter's 1-D/2-D shape.
            if isinstance(param.data[0], list):
                self.m[param_id] = Tensor([[0.0] * len(row) for row in param.data])
                self.v[param_id] = Tensor([[0.0] * len(row) for row in param.data])
            else:
                self.m[param_id] = Tensor([0.0] * len(param.data))
                self.v[param_id] = Tensor([0.0] * len(param.data))

        grad = Tensor(param.grad)
        if self.weight_decay > 0:
            grad = ops.add(grad, ops.mul(param, self.weight_decay))

        m_prev = self.m[param_id]
        m_new = ops.add(ops.mul(m_prev, self.beta1), ops.mul(grad, 1.0 - self.beta1))
        self.m[param_id] = m_new

        v_prev = self.v[param_id]
        grad_sq = ops.mul(grad, grad)
        v_new = ops.add(ops.mul(v_prev, self.beta2), ops.mul(grad_sq, 1.0 - self.beta2))
        self.v[param_id] = v_new

        # step_count >= 1 (incremented in step()), so corrections are nonzero.
        bias_correction1 = 1.0 - (self.beta1 ** self.step_count)
        bias_correction2 = 1.0 - (self.beta2 ** self.step_count)
        m_hat = ops.div(m_new, bias_correction1)
        v_hat = ops.div(v_new, bias_correction2)

        v_hat_sqrt = ops.pow(ops.add(v_hat, self.eps), 0.5)
        update = ops.div(m_hat, v_hat_sqrt)
        param_update = ops.mul(update, self.lr)

        if param.requires_grad:
            # Detach so the in-place subtraction stays out of the graph.
            param_detached = param.detach()
            param_detached.sub_(param_update)
            param.data = param_detached.data
        else:
            param.sub_(param_update)

    def zero_grad(self, param: Optional[Tensor] = None):
        """Clear gradients."""
        if param is not None:
            param.zero_grad()
        else:
            for p in self.params:
                p.zero_grad()

    def add_param_group(self, params: List[Tensor]):
        """Add a parameter group to optimize."""
        self.params.extend(params)