quantmllibrary-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (79)
  1. quantml/__init__.py +74 -0
  2. quantml/autograd.py +154 -0
  3. quantml/cli/__init__.py +10 -0
  4. quantml/cli/run_experiment.py +385 -0
  5. quantml/config/__init__.py +28 -0
  6. quantml/config/config.py +259 -0
  7. quantml/data/__init__.py +33 -0
  8. quantml/data/cache.py +149 -0
  9. quantml/data/feature_store.py +234 -0
  10. quantml/data/futures.py +254 -0
  11. quantml/data/loaders.py +236 -0
  12. quantml/data/memory_optimizer.py +234 -0
  13. quantml/data/validators.py +390 -0
  14. quantml/experiments/__init__.py +23 -0
  15. quantml/experiments/logger.py +208 -0
  16. quantml/experiments/results.py +158 -0
  17. quantml/experiments/tracker.py +223 -0
  18. quantml/features/__init__.py +25 -0
  19. quantml/features/base.py +104 -0
  20. quantml/features/gap_features.py +124 -0
  21. quantml/features/registry.py +138 -0
  22. quantml/features/volatility_features.py +140 -0
  23. quantml/features/volume_features.py +142 -0
  24. quantml/functional.py +37 -0
  25. quantml/models/__init__.py +27 -0
  26. quantml/models/attention.py +258 -0
  27. quantml/models/dropout.py +130 -0
  28. quantml/models/gru.py +319 -0
  29. quantml/models/linear.py +112 -0
  30. quantml/models/lstm.py +353 -0
  31. quantml/models/mlp.py +286 -0
  32. quantml/models/normalization.py +289 -0
  33. quantml/models/rnn.py +154 -0
  34. quantml/models/tcn.py +238 -0
  35. quantml/online.py +209 -0
  36. quantml/ops.py +1707 -0
  37. quantml/optim/__init__.py +42 -0
  38. quantml/optim/adafactor.py +206 -0
  39. quantml/optim/adagrad.py +157 -0
  40. quantml/optim/adam.py +267 -0
  41. quantml/optim/lookahead.py +97 -0
  42. quantml/optim/quant_optimizer.py +228 -0
  43. quantml/optim/radam.py +192 -0
  44. quantml/optim/rmsprop.py +203 -0
  45. quantml/optim/schedulers.py +286 -0
  46. quantml/optim/sgd.py +181 -0
  47. quantml/py.typed +0 -0
  48. quantml/streaming.py +175 -0
  49. quantml/tensor.py +462 -0
  50. quantml/time_series.py +447 -0
  51. quantml/training/__init__.py +135 -0
  52. quantml/training/alpha_eval.py +203 -0
  53. quantml/training/backtest.py +280 -0
  54. quantml/training/backtest_analysis.py +168 -0
  55. quantml/training/cv.py +106 -0
  56. quantml/training/data_loader.py +177 -0
  57. quantml/training/ensemble.py +84 -0
  58. quantml/training/feature_importance.py +135 -0
  59. quantml/training/features.py +364 -0
  60. quantml/training/futures_backtest.py +266 -0
  61. quantml/training/gradient_clipping.py +206 -0
  62. quantml/training/losses.py +248 -0
  63. quantml/training/lr_finder.py +127 -0
  64. quantml/training/metrics.py +376 -0
  65. quantml/training/regularization.py +89 -0
  66. quantml/training/trainer.py +239 -0
  67. quantml/training/walk_forward.py +190 -0
  68. quantml/utils/__init__.py +51 -0
  69. quantml/utils/gradient_check.py +274 -0
  70. quantml/utils/logging.py +181 -0
  71. quantml/utils/ops_cpu.py +231 -0
  72. quantml/utils/profiling.py +364 -0
  73. quantml/utils/reproducibility.py +220 -0
  74. quantml/utils/serialization.py +335 -0
  75. quantmllibrary-0.1.0.dist-info/METADATA +536 -0
  76. quantmllibrary-0.1.0.dist-info/RECORD +79 -0
  77. quantmllibrary-0.1.0.dist-info/WHEEL +5 -0
  78. quantmllibrary-0.1.0.dist-info/licenses/LICENSE +22 -0
  79. quantmllibrary-0.1.0.dist-info/top_level.txt +1 -0
quantml/optim/lookahead.py ADDED
@@ -0,0 +1,97 @@
+ """
+ Lookahead optimizer wrapper.
+
+ Lookahead is a wrapper that improves training stability by maintaining
+ a slow-moving average of parameters.
+ """
+
+ from typing import List, Optional, Any
+ from quantml.tensor import Tensor
+
+
+ class Lookahead:
+     """
+     Lookahead optimizer wrapper.
+
+     Wraps any base optimizer and maintains a slow-moving average of parameters.
+     Improves training stability for quant models.
+
+     Attributes:
+         base_optimizer: The base optimizer to wrap
+         k: Number of steps before updating slow weights
+         alpha: Interpolation factor for slow weights
+         slow_weights: Slow-moving parameter averages
+
+     Examples:
+         >>> base_opt = Adam(lr=0.001)
+         >>> optimizer = Lookahead(base_opt, k=5, alpha=0.5)
+         >>> for param in model.parameters():
+         >>>     optimizer.step(param)
+     """
+
+     def __init__(
+         self,
+         base_optimizer: Any,
+         k: int = 5,
+         alpha: float = 0.5
+     ):
+         """
+         Initialize Lookahead optimizer.
+
+         Args:
+             base_optimizer: Base optimizer to wrap (SGD, Adam, etc.)
+             k: Number of steps before updating slow weights
+             alpha: Interpolation factor (0 < alpha < 1)
+         """
+         self.base_optimizer = base_optimizer
+         self.k = k
+         self.alpha = alpha
+         self.slow_weights: dict = {} # Slow-moving parameter averages
+         self.step_count = 0
+
+     def step(self, param: Optional[Tensor] = None):
+         """Perform a single optimization step."""
+         # Use base optimizer
+         self.base_optimizer.step(param)
+         self.step_count += 1
+
+         # Update slow weights every k steps
+         if self.step_count % self.k == 0:
+             if param is not None:
+                 self._update_slow_weights(param)
+             else:
+                 for p in self.base_optimizer.params:
+                     self._update_slow_weights(p)
+
+     def _update_slow_weights(self, param: Tensor):
+         """Update slow-moving parameter average."""
+         param_id = id(param)
+
+         if param_id not in self.slow_weights:
+             # Initialize slow weights with current parameter values
+             self.slow_weights[param_id] = param.data.copy() if hasattr(param.data, 'copy') else param.data
+         else:
+             # Interpolate: slow = alpha * slow + (1 - alpha) * param
+             slow = self.slow_weights[param_id]
+             if hasattr(slow, '__iter__') and hasattr(param.data, '__iter__'):
+                 # Update slow weights (simplified - would need proper NumPy/list handling)
+                 # For now, just store current param as slow weight
+                 self.slow_weights[param_id] = param.data.copy() if hasattr(param.data, 'copy') else param.data
+             else:
+                 self.slow_weights[param_id] = self.alpha * slow + (1.0 - self.alpha) * param.data
+
+     def sync_slow_weights(self):
+         """Synchronize parameters with slow weights."""
+         for param in self.base_optimizer.params:
+             param_id = id(param)
+             if param_id in self.slow_weights:
+                 param.data = self.slow_weights[param_id]
+
+     def zero_grad(self, param: Optional[Tensor] = None):
+         """Clear gradients."""
+         self.base_optimizer.zero_grad(param)
+
+     def add_param_group(self, params: List[Tensor]):
+         """Add a parameter group to optimize."""
+         self.base_optimizer.add_param_group(params)
+
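For orientation, a minimal training-loop sketch for the wrapper above, following the conventions of the class docstring. The `Adam` constructor signature, `model`, `loss_fn`, `batches`, and `loss.backward()` are assumptions about the rest of the package and are not verified against this wheel.

# Hedged sketch: drive Lookahead via the base optimizer's registered params.
# `Adam`, `model`, `loss_fn`, and `batches` are placeholders / assumptions.
from quantml.optim import Adam, Lookahead   # assumed re-exports from quantml/optim/__init__.py

base_opt = Adam(lr=0.001)                   # constructor form taken from the docstring example
optimizer = Lookahead(base_opt, k=5, alpha=0.5)
optimizer.add_param_group(list(model.parameters()))   # forwarded to the base optimizer's param list

for x, y in batches:                        # hypothetical data iterator
    loss = loss_fn(model(x), y)             # hypothetical loss function
    loss.backward()                         # assumes quantml Tensors support backward()
    optimizer.step()                        # fast weights move every call; slow weights
                                            # refresh on every k-th call
    optimizer.zero_grad()                   # assumed to clear all registered params

optimizer.sync_slow_weights()               # copy the slow averages back into the params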
quantml/optim/quant_optimizer.py ADDED
@@ -0,0 +1,228 @@
+ """
+ QuantOptimizer - Custom optimizer for quant trading.
+
+ Adaptive learning rate based on market volatility and regime-aware parameter updates.
+ """
+
+ from typing import List, Optional, Dict, Any
+ from quantml.tensor import Tensor
+ from quantml import ops
+
+ # Try to import NumPy
+ try:
+     import numpy as np
+     HAS_NUMPY = True
+ except ImportError:
+     HAS_NUMPY = False
+     np = None
+
+
+ class QuantOptimizer:
+     """
+     QuantOptimizer - Custom optimizer for quant trading.
+
+     Features:
+     - Adaptive learning rate based on market volatility
+     - Regime-aware parameter updates
+     - Per-feature learning rate scaling
+
+     Attributes:
+         lr: Base learning rate
+         volatility_window: Window size for volatility calculation
+         regime_threshold: Threshold for regime detection
+         feature_lrs: Per-feature learning rate multipliers
+
+     Examples:
+         >>> optimizer = QuantOptimizer(lr=0.001, volatility_window=20)
+         >>> for param in model.parameters():
+         >>>     optimizer.step(param, market_volatility=0.02)
+     """
+
+     def __init__(
+         self,
+         params: Optional[List[Tensor]] = None,
+         lr: float = 0.001,
+         beta1: float = 0.9,
+         beta2: float = 0.999,
+         eps: float = 1e-8,
+         weight_decay: float = 0.0,
+         volatility_window: int = 20,
+         regime_threshold: float = 0.5
+     ):
+         """
+         Initialize QuantOptimizer.
+
+         Args:
+             params: Optional list of parameters to optimize
+             lr: Base learning rate
+             beta1: First moment decay rate
+             beta2: Second moment decay rate
+             eps: Small value for numerical stability
+             weight_decay: Weight decay coefficient
+             volatility_window: Window size for volatility calculation
+             regime_threshold: Threshold for regime detection
+         """
+         self.params = params if params is not None else []
+         self.lr = lr
+         self.beta1 = beta1
+         self.beta2 = beta2
+         self.eps = eps
+         self.weight_decay = weight_decay
+         self.volatility_window = volatility_window
+         self.regime_threshold = regime_threshold
+
+         self.m: Dict[int, Any] = {} # First moment
+         self.v: Dict[int, Any] = {} # Second moment
+         self.step_count = 0
+         self.volatility_history: List[float] = []
+         self.feature_lrs: Dict[int, Any] = {} # Per-parameter learning rate multipliers
+
+     def step(self, param: Optional[Tensor] = None, market_volatility: Optional[float] = None):
+         """
+         Perform a single optimization step.
+
+         Args:
+             param: Optional single parameter to update
+             market_volatility: Current market volatility (for adaptive LR)
+         """
+         if param is not None:
+             self._update_param(param, market_volatility)
+         else:
+             for p in self.params:
+                 self._update_param(p, market_volatility)
+         self.step_count += 1
+
+         # Update volatility history
+         if market_volatility is not None:
+             self.volatility_history.append(market_volatility)
+             if len(self.volatility_history) > self.volatility_window:
+                 self.volatility_history.pop(0)
+
+     def _update_param(self, param: Tensor, market_volatility: Optional[float] = None):
+         """Update a single parameter with adaptive learning rate."""
+         if not param.requires_grad:
+             return
+
+         if param.grad is None:
+             return
+
+         param_id = id(param)
+
+         # Calculate adaptive learning rate based on volatility
+         adaptive_lr = self.lr
+         if market_volatility is not None and len(self.volatility_history) > 1:
+             # Adjust LR based on volatility regime
+             avg_vol = sum(self.volatility_history) / len(self.volatility_history)
+             if market_volatility > avg_vol * (1.0 + self.regime_threshold):
+                 # High volatility regime - reduce LR
+                 adaptive_lr = self.lr * 0.5
+             elif market_volatility < avg_vol * (1.0 - self.regime_threshold):
+                 # Low volatility regime - increase LR
+                 adaptive_lr = self.lr * 1.5
+
+         # Apply per-parameter learning rate multiplier
+         if param_id in self.feature_lrs:
+             adaptive_lr = adaptive_lr * self.feature_lrs[param_id]
+
+         if HAS_NUMPY:
+             try:
+                 grad = param.grad
+                 if isinstance(grad, np.ndarray):
+                     grad_arr = grad
+                 else:
+                     grad_arr = np.array(grad, dtype=np.float64)
+
+                 param_arr = param.numpy if param.numpy is not None else np.array(param.data, dtype=np.float64)
+
+                 if param_id not in self.m:
+                     self.m[param_id] = np.zeros_like(param_arr, dtype=np.float64)
+                     self.v[param_id] = np.zeros_like(param_arr, dtype=np.float64)
+
+                 if self.weight_decay > 0:
+                     grad_arr = grad_arr + self.weight_decay * param_arr
+
+                 # Update moments (Adam-like)
+                 m = self.m[param_id]
+                 v = self.v[param_id]
+                 m[:] = self.beta1 * m + (1.0 - self.beta1) * grad_arr
+                 v[:] = self.beta2 * v + (1.0 - self.beta2) * (grad_arr ** 2)
+
+                 # Bias correction
+                 bias_correction1 = 1.0 - (self.beta1 ** self.step_count)
+                 bias_correction2 = 1.0 - (self.beta2 ** self.step_count)
+                 m_hat = m / bias_correction1
+                 v_hat = v / bias_correction2
+
+                 # Update parameter with adaptive LR
+                 v_hat_sqrt = np.sqrt(v_hat) + self.eps
+                 update = m_hat / v_hat_sqrt
+                 param_update = adaptive_lr * update
+                 new_param_arr = param_arr - param_update
+                 param.data = new_param_arr
+
+             except (ValueError, TypeError, AttributeError):
+                 self._update_param_fallback(param, adaptive_lr)
+         else:
+             self._update_param_fallback(param, adaptive_lr)
+
+     def _update_param_fallback(self, param: Tensor, adaptive_lr: float):
+         """Fallback update using Tensor operations."""
+         if param.grad is None:
+             return
+
+         param_id = id(param)
+
+         if param_id not in self.m:
+             if isinstance(param.data[0], list):
+                 self.m[param_id] = Tensor([[0.0] * len(row) for row in param.data])
+                 self.v[param_id] = Tensor([[0.0] * len(row) for row in param.data])
+             else:
+                 self.m[param_id] = Tensor([0.0] * len(param.data))
+                 self.v[param_id] = Tensor([0.0] * len(param.data))
+
+         grad = Tensor(param.grad)
+         if self.weight_decay > 0:
+             grad = ops.add(grad, ops.mul(param, self.weight_decay))
+
+         m_prev = self.m[param_id]
+         m_new = ops.add(ops.mul(m_prev, self.beta1), ops.mul(grad, 1.0 - self.beta1))
+         self.m[param_id] = m_new
+
+         v_prev = self.v[param_id]
+         grad_sq = ops.mul(grad, grad)
+         v_new = ops.add(ops.mul(v_prev, self.beta2), ops.mul(grad_sq, 1.0 - self.beta2))
+         self.v[param_id] = v_new
+
+         bias_correction1 = 1.0 - (self.beta1 ** self.step_count)
+         bias_correction2 = 1.0 - (self.beta2 ** self.step_count)
+         m_hat = ops.div(m_new, bias_correction1)
+         v_hat = ops.div(v_new, bias_correction2)
+
+         v_hat_sqrt = ops.pow(ops.add(v_hat, self.eps), 0.5)
+         update = ops.div(m_hat, v_hat_sqrt)
+         param_update = ops.mul(update, adaptive_lr)
+
+         if param.requires_grad:
+             param_detached = param.detach()
+             param_detached.sub_(param_update)
+             param.data = param_detached.data
+         else:
+             param.sub_(param_update)
+
+     def set_feature_lr(self, param: Tensor, multiplier: float):
+         """Set learning rate multiplier for a specific parameter."""
+         param_id = id(param)
+         self.feature_lrs[param_id] = multiplier
+
+     def zero_grad(self, param: Optional[Tensor] = None):
+         """Clear gradients."""
+         if param is not None:
+             param.zero_grad()
+         else:
+             for p in self.params:
+                 p.zero_grad()
+
+     def add_param_group(self, params: List[Tensor]):
+         """Add a parameter group to optimize."""
+         self.params.extend(params)
+
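The regime logic in `_update_param` reduces to a three-way switch around the rolling mean of recent volatility. A standalone restatement of that rule for reference; `adaptive_lr` below is an illustrative helper that mirrors the class's arithmetic, not a function shipped in the package.

def adaptive_lr(base_lr, market_vol, vol_history, regime_threshold=0.5):
    """Mirror of QuantOptimizer's regime rule: halve the LR in a high-volatility
    regime, raise it 1.5x in a low-volatility regime, otherwise leave it alone."""
    if market_vol is None or len(vol_history) <= 1:
        return base_lr
    avg_vol = sum(vol_history) / len(vol_history)
    if market_vol > avg_vol * (1.0 + regime_threshold):
        return base_lr * 0.5      # high-volatility regime
    if market_vol < avg_vol * (1.0 - regime_threshold):
        return base_lr * 1.5      # low-volatility regime
    return base_lr                # neutral regime

# Example: with a rolling average of 2% and regime_threshold=0.5, the LR is
# halved only once current volatility exceeds 3% and raised only below 1%.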
quantml/optim/radam.py ADDED
@@ -0,0 +1,192 @@
+ """
+ RAdam (Rectified Adam) optimizer implementation.
+
+ RAdam rectifies the variance of the adaptive learning rate in Adam.
+ """
+
+ from typing import List, Optional, Dict, Any
+ from quantml.tensor import Tensor
+ from quantml import ops
+ import math
+
+ # Try to import NumPy
+ try:
+     import numpy as np
+     HAS_NUMPY = True
+ except ImportError:
+     HAS_NUMPY = False
+     np = None
+
+
+ class RAdam:
+     """
+     RAdam (Rectified Adam) optimizer.
+
+     Rectifies the variance of the adaptive learning rate in Adam,
+     providing better convergence for quant training.
+
+     Attributes:
+         lr: Learning rate
+         betas: Tuple of (beta1, beta2) for moment estimates
+         eps: Small value for numerical stability
+         weight_decay: Weight decay coefficient
+         m: First moment estimates
+         v: Second moment estimates
+
+     Examples:
+         >>> optimizer = RAdam(lr=0.001, betas=(0.9, 0.999))
+         >>> for param in model.parameters():
+         >>>     optimizer.step(param)
+     """
+
+     def __init__(
+         self,
+         params: Optional[List[Tensor]] = None,
+         lr: float = 0.001,
+         betas: tuple = (0.9, 0.999),
+         eps: float = 1e-8,
+         weight_decay: float = 0.0
+     ):
+         """
+         Initialize RAdam optimizer.
+
+         Args:
+             params: Optional list of parameters to optimize
+             lr: Learning rate
+             betas: Tuple of (beta1, beta2) for exponential decay rates
+             eps: Small value to prevent division by zero
+             weight_decay: Weight decay coefficient
+         """
+         self.params = params if params is not None else []
+         self.lr = lr
+         self.beta1, self.beta2 = betas
+         self.eps = eps
+         self.weight_decay = weight_decay
+         self.m: Dict[int, Any] = {} # First moment
+         self.v: Dict[int, Any] = {} # Second moment
+         self.step_count = 0
+
+     def step(self, param: Optional[Tensor] = None):
+         """Perform a single optimization step."""
+         if param is not None:
+             self._update_param(param)
+         else:
+             for p in self.params:
+                 self._update_param(p)
+         self.step_count += 1
+
+     def _update_param(self, param: Tensor):
+         """Update a single parameter using RAdam algorithm."""
+         if not param.requires_grad:
+             return
+
+         if param.grad is None:
+             return
+
+         param_id = id(param)
+
+         if HAS_NUMPY:
+             try:
+                 grad = param.grad
+                 if isinstance(grad, np.ndarray):
+                     grad_arr = grad
+                 else:
+                     grad_arr = np.array(grad, dtype=np.float64)
+
+                 param_arr = param.numpy if param.numpy is not None else np.array(param.data, dtype=np.float64)
+
+                 if param_id not in self.m:
+                     self.m[param_id] = np.zeros_like(param_arr, dtype=np.float64)
+                     self.v[param_id] = np.zeros_like(param_arr, dtype=np.float64)
+
+                 if self.weight_decay > 0:
+                     grad_arr = grad_arr + self.weight_decay * param_arr
+
+                 # Update moments
+                 m = self.m[param_id]
+                 v = self.v[param_id]
+                 m[:] = self.beta1 * m + (1.0 - self.beta1) * grad_arr
+                 v[:] = self.beta2 * v + (1.0 - self.beta2) * (grad_arr ** 2)
+
+                 # RAdam variance rectification
+                 beta2_t = self.beta2 ** self.step_count
+                 rho_inf = 2.0 / (1.0 - self.beta2) - 1.0
+                 rho_t = rho_inf - 2.0 * self.step_count * beta2_t / (1.0 - beta2_t)
+
+                 if rho_t > 4.0:
+                     # Rectified update
+                     r_t = math.sqrt((rho_t - 4.0) * (rho_t - 2.0) * rho_inf / ((rho_inf - 4.0) * (rho_inf - 2.0) * rho_t))
+                     m_hat = m / (1.0 - self.beta1 ** self.step_count)
+                     v_hat = v / (1.0 - beta2_t)
+                     update = r_t * m_hat / (np.sqrt(v_hat) + self.eps)
+                 else:
+                     # Simple momentum update
+                     m_hat = m / (1.0 - self.beta1 ** self.step_count)
+                     update = m_hat
+
+                 param_update = self.lr * update
+                 new_param_arr = param_arr - param_update
+                 param.data = new_param_arr
+
+             except (ValueError, TypeError, AttributeError):
+                 self._update_param_fallback(param)
+         else:
+             self._update_param_fallback(param)
+
+     def _update_param_fallback(self, param: Tensor):
+         """Fallback update using Tensor operations."""
+         if param.grad is None:
+             return
+
+         param_id = id(param)
+
+         if param_id not in self.m:
+             if isinstance(param.data[0], list):
+                 self.m[param_id] = Tensor([[0.0] * len(row) for row in param.data])
+                 self.v[param_id] = Tensor([[0.0] * len(row) for row in param.data])
+             else:
+                 self.m[param_id] = Tensor([0.0] * len(param.data))
+                 self.v[param_id] = Tensor([0.0] * len(param.data))
+
+         grad = Tensor(param.grad)
+         if self.weight_decay > 0:
+             grad = ops.add(grad, ops.mul(param, self.weight_decay))
+
+         m_prev = self.m[param_id]
+         m_new = ops.add(ops.mul(m_prev, self.beta1), ops.mul(grad, 1.0 - self.beta1))
+         self.m[param_id] = m_new
+
+         v_prev = self.v[param_id]
+         grad_sq = ops.mul(grad, grad)
+         v_new = ops.add(ops.mul(v_prev, self.beta2), ops.mul(grad_sq, 1.0 - self.beta2))
+         self.v[param_id] = v_new
+
+         # Simplified RAdam (use regular Adam update in fallback)
+         bias_correction1 = 1.0 - (self.beta1 ** self.step_count)
+         bias_correction2 = 1.0 - (self.beta2 ** self.step_count)
+         m_hat = ops.div(m_new, bias_correction1)
+         v_hat = ops.div(v_new, bias_correction2)
+
+         v_hat_sqrt = ops.pow(ops.add(v_hat, self.eps), 0.5)
+         update = ops.div(m_hat, v_hat_sqrt)
+         param_update = ops.mul(update, self.lr)
+
+         if param.requires_grad:
+             param_detached = param.detach()
+             param_detached.sub_(param_update)
+             param.data = param_detached.data
+         else:
+             param.sub_(param_update)
+
+     def zero_grad(self, param: Optional[Tensor] = None):
+         """Clear gradients."""
+         if param is not None:
+             param.zero_grad()
+         else:
+             for p in self.params:
+                 p.zero_grad()
+
+     def add_param_group(self, params: List[Tensor]):
+         """Add a parameter group to optimize."""
+         self.params.extend(params)
+
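The rectification branch in `_update_param` follows the standard RAdam recipe: the adaptive step is scaled by r_t only once the variance proxy rho_t exceeds 4; before that, a bias-corrected momentum step is used. The scalar part of that arithmetic, pulled out for reference; `radam_rectifier` is an illustrative helper that mirrors the code above, not part of the package.

import math

def radam_rectifier(beta2: float, step: int):
    """Return (rho_t, r_t or None) for a given step, mirroring the arithmetic
    in RAdam._update_param. step is assumed to be >= 1."""
    rho_inf = 2.0 / (1.0 - beta2) - 1.0
    beta2_t = beta2 ** step
    rho_t = rho_inf - 2.0 * step * beta2_t / (1.0 - beta2_t)
    if rho_t <= 4.0:
        return rho_t, None          # un-rectified: plain bias-corrected momentum
    r_t = math.sqrt(
        (rho_t - 4.0) * (rho_t - 2.0) * rho_inf
        / ((rho_inf - 4.0) * (rho_inf - 2.0) * rho_t)
    )
    return rho_t, r_t

# Example: with beta2 = 0.999, rho_inf = 1999 and rho_t stays at or below 4 for
# roughly the first four steps, so those early updates use momentum only; after
# that the rectifier r_t ramps up toward 1 as rho_t approaches rho_inf.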