quantmllibrary-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (79)
  1. quantml/__init__.py +74 -0
  2. quantml/autograd.py +154 -0
  3. quantml/cli/__init__.py +10 -0
  4. quantml/cli/run_experiment.py +385 -0
  5. quantml/config/__init__.py +28 -0
  6. quantml/config/config.py +259 -0
  7. quantml/data/__init__.py +33 -0
  8. quantml/data/cache.py +149 -0
  9. quantml/data/feature_store.py +234 -0
  10. quantml/data/futures.py +254 -0
  11. quantml/data/loaders.py +236 -0
  12. quantml/data/memory_optimizer.py +234 -0
  13. quantml/data/validators.py +390 -0
  14. quantml/experiments/__init__.py +23 -0
  15. quantml/experiments/logger.py +208 -0
  16. quantml/experiments/results.py +158 -0
  17. quantml/experiments/tracker.py +223 -0
  18. quantml/features/__init__.py +25 -0
  19. quantml/features/base.py +104 -0
  20. quantml/features/gap_features.py +124 -0
  21. quantml/features/registry.py +138 -0
  22. quantml/features/volatility_features.py +140 -0
  23. quantml/features/volume_features.py +142 -0
  24. quantml/functional.py +37 -0
  25. quantml/models/__init__.py +27 -0
  26. quantml/models/attention.py +258 -0
  27. quantml/models/dropout.py +130 -0
  28. quantml/models/gru.py +319 -0
  29. quantml/models/linear.py +112 -0
  30. quantml/models/lstm.py +353 -0
  31. quantml/models/mlp.py +286 -0
  32. quantml/models/normalization.py +289 -0
  33. quantml/models/rnn.py +154 -0
  34. quantml/models/tcn.py +238 -0
  35. quantml/online.py +209 -0
  36. quantml/ops.py +1707 -0
  37. quantml/optim/__init__.py +42 -0
  38. quantml/optim/adafactor.py +206 -0
  39. quantml/optim/adagrad.py +157 -0
  40. quantml/optim/adam.py +267 -0
  41. quantml/optim/lookahead.py +97 -0
  42. quantml/optim/quant_optimizer.py +228 -0
  43. quantml/optim/radam.py +192 -0
  44. quantml/optim/rmsprop.py +203 -0
  45. quantml/optim/schedulers.py +286 -0
  46. quantml/optim/sgd.py +181 -0
  47. quantml/py.typed +0 -0
  48. quantml/streaming.py +175 -0
  49. quantml/tensor.py +462 -0
  50. quantml/time_series.py +447 -0
  51. quantml/training/__init__.py +135 -0
  52. quantml/training/alpha_eval.py +203 -0
  53. quantml/training/backtest.py +280 -0
  54. quantml/training/backtest_analysis.py +168 -0
  55. quantml/training/cv.py +106 -0
  56. quantml/training/data_loader.py +177 -0
  57. quantml/training/ensemble.py +84 -0
  58. quantml/training/feature_importance.py +135 -0
  59. quantml/training/features.py +364 -0
  60. quantml/training/futures_backtest.py +266 -0
  61. quantml/training/gradient_clipping.py +206 -0
  62. quantml/training/losses.py +248 -0
  63. quantml/training/lr_finder.py +127 -0
  64. quantml/training/metrics.py +376 -0
  65. quantml/training/regularization.py +89 -0
  66. quantml/training/trainer.py +239 -0
  67. quantml/training/walk_forward.py +190 -0
  68. quantml/utils/__init__.py +51 -0
  69. quantml/utils/gradient_check.py +274 -0
  70. quantml/utils/logging.py +181 -0
  71. quantml/utils/ops_cpu.py +231 -0
  72. quantml/utils/profiling.py +364 -0
  73. quantml/utils/reproducibility.py +220 -0
  74. quantml/utils/serialization.py +335 -0
  75. quantmllibrary-0.1.0.dist-info/METADATA +536 -0
  76. quantmllibrary-0.1.0.dist-info/RECORD +79 -0
  77. quantmllibrary-0.1.0.dist-info/WHEEL +5 -0
  78. quantmllibrary-0.1.0.dist-info/licenses/LICENSE +22 -0
  79. quantmllibrary-0.1.0.dist-info/top_level.txt +1 -0
quantml/training/metrics.py
@@ -0,0 +1,376 @@
+ """
+ Financial and prediction metrics for quant model evaluation.
+
+ This module provides metrics commonly used in quantitative trading,
+ including the Sharpe ratio, Information Coefficient (IC), drawdown, and more.
+ """
+
+ from typing import Any, List, Union
+ import math
+
+ # Try to import NumPy; every metric falls back to pure Python without it.
+ try:
+     import numpy as np
+     HAS_NUMPY = True
+ except ImportError:
+     HAS_NUMPY = False
+     np = None
+
+
+ def sharpe_ratio(returns: Union[List, Any], risk_free_rate: float = 0.0, annualize: bool = True) -> float:
+     """
+     Calculate Sharpe ratio: (mean(returns) - risk_free_rate) / std(returns)
+
+     Args:
+         returns: List of returns
+         risk_free_rate: Risk-free rate (default: 0.0)
+         annualize: Whether to annualize (multiply by sqrt(252))
+
+     Returns:
+         Sharpe ratio
+     """
+     if len(returns) == 0:
+         return 0.0
+
+     if HAS_NUMPY:
+         try:
+             ret_arr = np.array(returns, dtype=np.float64)
+             mean_ret = np.mean(ret_arr)
+             std_ret = np.std(ret_arr)
+             if std_ret == 0:
+                 return 0.0
+             sharpe = (mean_ret - risk_free_rate) / std_ret
+             if annualize:
+                 sharpe *= math.sqrt(252)
+             return float(sharpe)
+         except (ValueError, TypeError):
+             pass
+
+     # Pure Python fallback
+     mean_ret = sum(returns) / len(returns)
+     variance = sum((r - mean_ret) ** 2 for r in returns) / len(returns)
+     std_ret = math.sqrt(variance) if variance > 0 else 0.0
+
+     if std_ret == 0:
+         return 0.0
+
+     sharpe = (mean_ret - risk_free_rate) / std_ret
+     if annualize:
+         sharpe *= math.sqrt(252)
+     return sharpe
+
+
+ def sortino_ratio(returns: Union[List, Any], risk_free_rate: float = 0.0, annualize: bool = True) -> float:
+     """
+     Calculate Sortino ratio: (mean(returns) - risk_free_rate) / downside_std(returns)
+
+     Args:
+         returns: List of returns
+         risk_free_rate: Risk-free rate
+         annualize: Whether to annualize (multiply by sqrt(252))
+
+     Returns:
+         Sortino ratio
+     """
+     if len(returns) == 0:
+         return 0.0
+
+     if HAS_NUMPY:
+         try:
+             ret_arr = np.array(returns, dtype=np.float64)
+             mean_ret = np.mean(ret_arr)
+             # Downside deviation: computed over negative returns only
+             downside = ret_arr[ret_arr < 0]
+             if len(downside) == 0:
+                 return float('inf') if mean_ret > risk_free_rate else 0.0
+             downside_std = np.std(downside)
+             if downside_std == 0:
+                 return 0.0
+             sortino = (mean_ret - risk_free_rate) / downside_std
+             if annualize:
+                 sortino *= math.sqrt(252)
+             return float(sortino)
+         except (ValueError, TypeError):
+             pass
+
+     # Pure Python fallback
+     mean_ret = sum(returns) / len(returns)
+     downside = [r for r in returns if r < 0]
+     if len(downside) == 0:
+         return float('inf') if mean_ret > risk_free_rate else 0.0
+
+     downside_mean = sum(downside) / len(downside)
+     downside_var = sum((d - downside_mean) ** 2 for d in downside) / len(downside)
+     downside_std = math.sqrt(downside_var) if downside_var > 0 else 0.0
+
+     if downside_std == 0:
+         return 0.0
+
+     sortino = (mean_ret - risk_free_rate) / downside_std
+     if annualize:
+         sortino *= math.sqrt(252)
+     return sortino
+
+
+ def calmar_ratio(returns: Union[List, Any], annualize: bool = True) -> float:
+     """
+     Calculate Calmar ratio: annual_return / max_drawdown
+
+     Args:
+         returns: List of returns
+         annualize: Whether to annualize returns (multiply the mean by 252)
+
+     Returns:
+         Calmar ratio
+     """
+     if len(returns) == 0:
+         return 0.0
+
+     annual_return = sum(returns) / len(returns)
+     if annualize:
+         annual_return *= 252
+
+     max_dd = max_drawdown(returns)
+     if max_dd == 0:
+         return 0.0
+
+     return abs(annual_return / max_dd)
+
+
+ def max_drawdown(returns: Union[List, Any]) -> float:
+     """
+     Calculate maximum drawdown.
+
+     Args:
+         returns: List of returns
+
+     Returns:
+         Maximum drawdown (as a positive value)
+     """
+     if len(returns) == 0:
+         return 0.0
+
+     if HAS_NUMPY:
+         try:
+             ret_arr = np.array(returns, dtype=np.float64)
+             # Cumulative returns
+             cum_ret = np.cumprod(1 + ret_arr)
+             # Running maximum
+             running_max = np.maximum.accumulate(cum_ret)
+             # Drawdown relative to the running peak
+             drawdown = (cum_ret - running_max) / running_max
+             return float(abs(np.min(drawdown)))
+         except (ValueError, TypeError):
+             pass
+
+     # Pure Python fallback
+     cum_ret = 1.0
+     running_max = 1.0
+     max_dd = 0.0
+
+     for ret in returns:
+         cum_ret *= (1 + ret)
+         running_max = max(running_max, cum_ret)
+         dd = (cum_ret - running_max) / running_max
+         max_dd = min(max_dd, dd)
+
+     return abs(max_dd)
+
+
+ def information_coefficient(predictions: Union[List, Any], actuals: Union[List, Any]) -> float:
+     """
+     Calculate Information Coefficient (IC): Pearson correlation between predictions and actuals.
+
+     Args:
+         predictions: Predicted values
+         actuals: Actual/realized values
+
+     Returns:
+         IC (correlation coefficient)
+     """
+     if len(predictions) != len(actuals) or len(predictions) == 0:
+         return 0.0
+
+     if HAS_NUMPY:
+         try:
+             pred_arr = np.array(predictions, dtype=np.float64)
+             actual_arr = np.array(actuals, dtype=np.float64)
+             # Pearson correlation
+             corr = np.corrcoef(pred_arr, actual_arr)[0, 1]
+             return float(corr) if not np.isnan(corr) else 0.0
+         except (ValueError, TypeError):
+             pass
+
+     # Pure Python fallback
+     n = len(predictions)
+     pred_mean = sum(predictions) / n
+     actual_mean = sum(actuals) / n
+
+     numerator = sum((predictions[i] - pred_mean) * (actuals[i] - actual_mean) for i in range(n))
+     pred_var = sum((p - pred_mean) ** 2 for p in predictions)
+     actual_var = sum((a - actual_mean) ** 2 for a in actuals)
+
+     denominator = math.sqrt(pred_var * actual_var)
+     if denominator == 0:
+         return 0.0
+
+     return numerator / denominator
+
+
+ def rank_ic(predictions: Union[List, Any], actuals: Union[List, Any]) -> float:
+     """
+     Calculate Rank IC: Spearman rank correlation.
+
+     Args:
+         predictions: Predicted values
+         actuals: Actual values
+
+     Returns:
+         Rank IC
+     """
+     if len(predictions) != len(actuals) or len(predictions) == 0:
+         return 0.0
+
+     if HAS_NUMPY:
+         try:
+             from scipy.stats import spearmanr
+             corr, _ = spearmanr(predictions, actuals)
+             return float(corr) if not np.isnan(corr) else 0.0
+         except ImportError:
+             # SciPy unavailable; fall back to manual rank correlation
+             pass
+
+     # Manual rank correlation: Pearson correlation of the ranks
+     pred_ranks = _get_ranks(predictions)
+     actual_ranks = _get_ranks(actuals)
+
+     return information_coefficient(pred_ranks, actual_ranks)
+
+
+ def _get_ranks(values: List) -> List:
+     """Get 1-based ranks of values (ties are broken by original order)."""
+     sorted_vals = sorted(enumerate(values), key=lambda x: x[1])
+     ranks = [0] * len(values)
+     for rank, (idx, _) in enumerate(sorted_vals, 1):
+         ranks[idx] = rank
+     return ranks
+
+
+ def hit_rate(predictions: Union[List, Any], actuals: Union[List, Any]) -> float:
+     """
+     Calculate hit rate: fraction of correct directional predictions.
+
+     Args:
+         predictions: Predicted values
+         actuals: Actual values
+
+     Returns:
+         Hit rate (0.0 to 1.0)
+     """
+     if len(predictions) != len(actuals) or len(predictions) == 0:
+         return 0.0
+
+     correct = 0
+     for i in range(len(predictions)):
+         # A value of zero counts as negative direction
+         pred_dir = 1 if predictions[i] > 0 else -1
+         actual_dir = 1 if actuals[i] > 0 else -1
+         if pred_dir == actual_dir:
+             correct += 1
+
+     return correct / len(predictions)
+
+
+ def var(returns: Union[List, Any], confidence: float = 0.05) -> float:
+     """
+     Calculate historical Value at Risk (VaR).
+
+     Args:
+         returns: List of returns
+         confidence: Tail probability (default: 0.05 for 95% VaR)
+
+     Returns:
+         VaR (usually a negative value)
+     """
+     if len(returns) == 0:
+         return 0.0
+
+     if HAS_NUMPY:
+         try:
+             ret_arr = np.array(returns, dtype=np.float64)
+             var_val = np.percentile(ret_arr, confidence * 100)
+             return float(var_val)
+         except (ValueError, TypeError):
+             pass
+
+     # Pure Python fallback: empirical quantile by index
+     sorted_returns = sorted(returns)
+     idx = int(len(sorted_returns) * confidence)
+     return sorted_returns[idx] if idx < len(sorted_returns) else sorted_returns[-1]
+
+
+ def cvar(returns: Union[List, Any], confidence: float = 0.05) -> float:
+     """
+     Calculate Conditional Value at Risk (CVaR) / Expected Shortfall:
+     the mean of the returns at or below the VaR threshold.
+
+     Args:
+         returns: List of returns
+         confidence: Tail probability (as in var())
+
+     Returns:
+         CVaR (usually a negative value)
+     """
+     if len(returns) == 0:
+         return 0.0
+
+     var_val = var(returns, confidence)
+
+     if HAS_NUMPY:
+         try:
+             ret_arr = np.array(returns, dtype=np.float64)
+             tail_losses = ret_arr[ret_arr <= var_val]
+             if len(tail_losses) == 0:
+                 return var_val
+             return float(np.mean(tail_losses))
+         except (ValueError, TypeError):
+             pass
+
+     # Pure Python fallback
+     tail_losses = [r for r in returns if r <= var_val]
+     if len(tail_losses) == 0:
+         return var_val
+     return sum(tail_losses) / len(tail_losses)
+
+
+ def turnover(positions: Union[List, Any]) -> float:
+     """
+     Calculate average portfolio turnover: the mean absolute position change
+     between consecutive periods (summed across assets for 2-D inputs).
+
+     Args:
+         positions: List of position values (can be weights; scalars or per-asset lists)
+
+     Returns:
+         Average turnover
+     """
+     if len(positions) < 2:
+         return 0.0
+
+     if HAS_NUMPY:
+         try:
+             pos_arr = np.array(positions, dtype=np.float64)
+             changes = np.abs(np.diff(pos_arr, axis=0))
+             return float(np.mean(np.sum(changes, axis=1) if changes.ndim > 1 else changes))
+         except (ValueError, TypeError):
+             pass
+
+     # Pure Python fallback
+     total_turnover = 0.0
+     for i in range(1, len(positions)):
+         if isinstance(positions[i], list) and isinstance(positions[i - 1], list):
+             change = sum(abs(positions[i][j] - positions[i - 1][j]) for j in range(len(positions[i])))
+         else:
+             change = abs(positions[i] - positions[i - 1])
+         total_turnover += change
+
+     return total_turnover / (len(positions) - 1)
+
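For orientation, here is a minimal sketch of how the metrics above compose into an evaluation pass. It relies only on the signatures shown in this hunk; the return and prediction series are made-up illustrative data, and the module path follows the file list (quantml/training/metrics.py).

from quantml.training.metrics import (
    sharpe_ratio, sortino_ratio, calmar_ratio, max_drawdown,
    information_coefficient, rank_ic, hit_rate, var, cvar,
)

# Made-up daily strategy returns and model predictions
returns = [0.002, -0.001, 0.003, -0.002, 0.001]
preds = [0.5, -0.2, 0.7, -0.4, 0.1]

print(sharpe_ratio(returns))                    # annualized by sqrt(252) by default
print(sortino_ratio(returns))                   # penalizes downside volatility only
print(calmar_ratio(returns))                    # annualized mean return / max drawdown
print(max_drawdown(returns))                    # reported as a positive magnitude
print(information_coefficient(preds, returns))  # Pearson IC
print(rank_ic(preds, returns))                  # Spearman rank IC
print(hit_rate(preds, returns))                 # fraction of correct directions
print(var(returns, confidence=0.05))            # 95% historical VaR
print(cvar(returns, confidence=0.05))           # expected shortfall below VaR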
quantml/training/regularization.py
@@ -0,0 +1,89 @@
+ """
+ Regularization utilities for quant models.
+
+ Provides dropout and other regularization techniques.
+ """
+
+ from quantml.tensor import Tensor
+ import random
+
+ # Try to import NumPy; the layer falls back to list-based masks without it.
+ try:
+     import numpy as np
+     HAS_NUMPY = True
+ except ImportError:
+     HAS_NUMPY = False
+     np = None
+
+
+ class Dropout:
+     """
+     Dropout layer for regularization.
+
+     Randomly sets a fraction of inputs to zero during training and scales
+     the survivors by 1 / (1 - p) so the expected activation is unchanged
+     (inverted dropout).
+     """
+
+     def __init__(self, p: float = 0.5):
+         """
+         Initialize dropout layer.
+
+         Args:
+             p: Probability of dropping out, in [0.0, 1.0)
+         """
+         if not 0.0 <= p < 1.0:
+             raise ValueError(f"dropout probability must be in [0.0, 1.0), got {p}")
+         self.p = p
+         self.training = True
+         self.mask = None
+
+     def forward(self, x: Tensor) -> Tensor:
+         """
+         Forward pass through dropout.
+
+         Args:
+             x: Input tensor
+
+         Returns:
+             Output tensor with dropout applied
+         """
+         if not self.training or self.p == 0.0:
+             return x
+
+         # Create dropout mask
+         if HAS_NUMPY:
+             try:
+                 x_arr = x.numpy if x.numpy is not None else np.array(x.data, dtype=np.float64)
+                 mask = (np.random.random(x_arr.shape) > self.p).astype(np.float64)
+                 mask = mask / (1.0 - self.p)  # Scale to maintain expected value
+                 self.mask = mask
+                 out_arr = x_arr * mask
+                 return Tensor(out_arr.tolist(), requires_grad=x.requires_grad)
+             except Exception:
+                 pass
+
+         # Fallback to list-based masks (surviving entries carry the scale directly)
+         scale = 1.0 / (1.0 - self.p)
+         if x.data and isinstance(x.data[0], list):
+             mask = [[scale if random.random() > self.p else 0.0 for _ in row] for row in x.data]
+             self.mask = mask
+             out_data = [[x.data[i][j] * mask[i][j] for j in range(len(x.data[i]))]
+                         for i in range(len(x.data))]
+         else:
+             mask = [scale if random.random() > self.p else 0.0 for _ in x.data]
+             self.mask = mask
+             out_data = [x.data[i] * mask[i] for i in range(len(x.data))]
+
+         return Tensor(out_data, requires_grad=x.requires_grad)
+
+     def eval(self):
+         """Set to evaluation mode (no dropout)."""
+         self.training = False
+
+     def train(self):
+         """Set to training mode (with dropout)."""
+         self.training = True
+
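A minimal usage sketch for the Dropout layer above, assuming Tensor accepts a plain list of floats (the forward pass itself constructs its output the same way):

from quantml.tensor import Tensor
from quantml.training.regularization import Dropout

drop = Dropout(p=0.5)

x = Tensor([1.0, 2.0, 3.0, 4.0])
y_train = drop.forward(x)   # ~half the entries zeroed, survivors scaled by 1 / (1 - p) = 2.0

drop.eval()                 # disable dropout
y_eval = drop.forward(x)    # identity: the input tensor is returned unchanged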
quantml/training/trainer.py
@@ -0,0 +1,239 @@
+ """
+ Training utilities for quant models.
+
+ This module provides the QuantTrainer class with a training loop, early stopping,
+ checkpointing, and metrics tracking.
+ """
+
+ from typing import Optional, Callable, Dict, Any, List
+ from quantml.tensor import Tensor
+ from quantml import ops
+ from quantml.training.gradient_clipping import GradientNormClipper, GradientValueClipper, AdaptiveClipper
+
+ # Try to import NumPy
+ try:
+     import numpy as np
+     HAS_NUMPY = True
+ except ImportError:
+     HAS_NUMPY = False
+     np = None
+
+
+ class QuantTrainer:
+     """
+     Trainer class for quant model training.
+
+     Provides a training loop with early stopping, checkpointing, and metrics tracking.
+
+     Attributes:
+         model: Model to train
+         optimizer: Optimizer
+         loss_fn: Loss function
+         metrics: List of metric functions to track
+
+     Examples:
+         >>> trainer = QuantTrainer(model, optimizer, loss_fn=mse_loss)
+         >>> trainer.train(X_train, y_train, X_val, y_val, epochs=100)
+     """
+
+     def __init__(
+         self,
+         model: Any,
+         optimizer: Any,
+         loss_fn: Callable,
+         metrics: Optional[List[Callable]] = None,
+         gradient_clipper: Optional[Any] = None,
+         accumulation_steps: int = 1
+     ):
+         """
+         Initialize trainer.
+
+         Args:
+             model: Model to train
+             optimizer: Optimizer (SGD, Adam, etc.)
+             loss_fn: Loss function
+             metrics: Optional list of metric functions
+             gradient_clipper: Optional gradient clipper (GradientNormClipper, etc.)
+             accumulation_steps: Number of steps to accumulate gradients before an optimizer step
+         """
+         self.model = model
+         self.optimizer = optimizer
+         self.loss_fn = loss_fn
+         self.metrics = metrics if metrics else []
+         self.gradient_clipper = gradient_clipper
+         self.accumulation_steps = accumulation_steps
+         self.accumulation_counter = 0
+         self.history = {'train_loss': [], 'val_loss': []}
+
+     def train_step(self, x: Tensor, y: Tensor) -> float:
+         """
+         Perform a single training step.
+
+         Args:
+             x: Input features
+             y: Targets
+
+         Returns:
+             Loss value
+         """
+         # Forward pass
+         pred = self.model.forward(x)
+
+         # Compute loss (scale by accumulation steps for correct averaging)
+         loss = self.loss_fn(pred, y)
+         if self.accumulation_steps > 1:
+             loss = ops.mul(loss, 1.0 / self.accumulation_steps)
+
+         # Backward pass
+         if loss.requires_grad:
+             loss.backward()
+
+         self.accumulation_counter += 1
+
+         # Apply gradient clipping if enabled, just before the optimizer step
+         if self.gradient_clipper is not None and self.accumulation_counter % self.accumulation_steps == 0:
+             params = self.model.parameters() if hasattr(self.model, 'parameters') else []
+             self.gradient_clipper(params)
+
+         # Optimizer step only after a full accumulation window
+         if self.accumulation_counter % self.accumulation_steps == 0:
+             self.optimizer.step()
+             self.model.zero_grad()
+             self.accumulation_counter = 0
+
+         # Get loss value (unscale for reporting)
+         loss_value = self._get_value(loss)
+         if self.accumulation_steps > 1:
+             loss_value = loss_value * self.accumulation_steps
+         return loss_value
+
+     def validate(self, x: Tensor, y: Tensor) -> Dict[str, float]:
+         """
+         Validate model on data.
+
+         Args:
+             x: Input features
+             y: Targets
+
+         Returns:
+             Dictionary of metrics
+         """
+         # Forward pass
+         pred = self.model.forward(x)
+
+         # Compute loss
+         loss = self.loss_fn(pred, y)
+         metrics_dict = {'loss': self._get_value(loss)}
+
+         # Compute additional metrics; a failing metric is skipped rather than
+         # aborting validation
+         for metric_fn in self.metrics:
+             try:
+                 metric_val = metric_fn(pred, y)
+                 if isinstance(metric_val, Tensor):
+                     metric_val = self._get_value(metric_val)
+                 metrics_dict[getattr(metric_fn, '__name__', 'metric')] = metric_val
+             except Exception:
+                 pass
+
+         return metrics_dict
+
+     def train(
+         self,
+         X_train: List,
+         y_train: List,
+         X_val: Optional[List] = None,
+         y_val: Optional[List] = None,
+         epochs: int = 100,
+         batch_size: int = 32,
+         early_stopping: Optional[Dict] = None,
+         verbose: bool = True
+     ) -> Dict[str, List]:
+         """
+         Train model.
+
+         Args:
+             X_train: Training features
+             y_train: Training targets
+             X_val: Validation features (optional)
+             y_val: Validation targets (optional)
+             epochs: Number of epochs
+             batch_size: Batch size
+             early_stopping: Early stopping config (patience, min_delta)
+             verbose: Whether to print progress
+
+         Returns:
+             Training history
+         """
+         from quantml.training.data_loader import QuantDataLoader
+
+         # shuffle=False preserves temporal order for time-series data
+         train_loader = QuantDataLoader(X_train, y_train, batch_size=batch_size, shuffle=False)
+
+         val_loader = None
+         if X_val is not None and y_val is not None:
+             val_loader = QuantDataLoader(X_val, y_val, batch_size=batch_size, shuffle=False)
+
+         # Early stopping state
+         best_val_loss = float('inf')
+         patience_counter = 0
+         patience = early_stopping.get('patience', 10) if early_stopping else None
+         min_delta = early_stopping.get('min_delta', 0.0) if early_stopping else 0.0
+
+         for epoch in range(epochs):
+             # Training
+             train_losses = []
+             for x_batch, y_batch in train_loader:
+                 loss = self.train_step(x_batch, y_batch)
+                 train_losses.append(loss)
+
+             avg_train_loss = sum(train_losses) / len(train_losses) if train_losses else 0.0
+             self.history['train_loss'].append(avg_train_loss)
+
+             # Validation
+             if val_loader is not None:
+                 val_metrics = self._validate_loader(val_loader)
+                 val_loss = val_metrics.get('loss', 0.0)
+                 self.history['val_loss'].append(val_loss)
+
+                 # Early stopping check
+                 if patience is not None:
+                     if val_loss < best_val_loss - min_delta:
+                         best_val_loss = val_loss
+                         patience_counter = 0
+                     else:
+                         patience_counter += 1
+                         if patience_counter >= patience:
+                             if verbose:
+                                 print(f"Early stopping at epoch {epoch + 1}")
+                             break
+
+             if verbose and (epoch + 1) % 10 == 0:
+                 print(f"Epoch {epoch + 1}/{epochs}: Train Loss = {avg_train_loss:.6f}", end="")
+                 if val_loader is not None:
+                     print(f", Val Loss = {val_loss:.6f}")
+                 else:
+                     print()
+
+         return self.history
+
+     def _validate_loader(self, loader) -> Dict[str, float]:
+         """Validate on a data loader."""
+         all_metrics = []
+         for x_batch, y_batch in loader:
+             metrics = self.validate(x_batch, y_batch)
+             all_metrics.append(metrics)
+
+         if not all_metrics:
+             return {}
+
+         # Average metrics across batches
+         avg_metrics = {}
+         for key in all_metrics[0].keys():
+             avg_metrics[key] = sum(m[key] for m in all_metrics) / len(all_metrics)
+
+         return avg_metrics
+
+     def _get_value(self, tensor: Tensor) -> float:
+         """Extract a scalar value from a tensor."""
+         if isinstance(tensor.data, list):
+             if isinstance(tensor.data[0], list):
+                 return float(tensor.data[0][0])
+             return float(tensor.data[0])
+         return float(tensor.data)
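Finally, a sketch of how QuantTrainer wires together, following its docstring example. The MLP and SGD class names, their constructor arguments, and the GradientNormClipper max_norm parameter are assumptions based on the module names in the file list (quantml/models/mlp.py, quantml/optim/sgd.py, quantml/training/gradient_clipping.py); mse_loss is the name used in the docstring. None of these signatures are confirmed by this diff.

# Sketch only: the class names and constructor arguments below are assumed, not confirmed.
from quantml.models.mlp import MLP                    # assumed class name
from quantml.optim.sgd import SGD                     # assumed class name
from quantml.training.losses import mse_loss          # name taken from the docstring example
from quantml.training.gradient_clipping import GradientNormClipper
from quantml.training.trainer import QuantTrainer

X_train = [[0.1 * j for j in range(8)] for _ in range(64)]   # toy feature rows
y_train = [[0.0] for _ in range(64)]                         # toy targets
X_val, y_val = X_train[:16], y_train[:16]

model = MLP(input_dim=8, hidden_dims=[16], output_dim=1)     # illustrative signature
optimizer = SGD(model.parameters(), lr=0.01)                 # illustrative signature

trainer = QuantTrainer(
    model,
    optimizer,
    loss_fn=mse_loss,
    gradient_clipper=GradientNormClipper(max_norm=1.0),      # illustrative signature
    accumulation_steps=4,                                    # optimizer steps every 4 batches
)
history = trainer.train(
    X_train, y_train, X_val, y_val,
    epochs=100, batch_size=32,
    early_stopping={'patience': 10, 'min_delta': 1e-5},
)
print(history['train_loss'][-1], history['val_loss'][-1])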