quantmllibrary-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (79)
  1. quantml/__init__.py +74 -0
  2. quantml/autograd.py +154 -0
  3. quantml/cli/__init__.py +10 -0
  4. quantml/cli/run_experiment.py +385 -0
  5. quantml/config/__init__.py +28 -0
  6. quantml/config/config.py +259 -0
  7. quantml/data/__init__.py +33 -0
  8. quantml/data/cache.py +149 -0
  9. quantml/data/feature_store.py +234 -0
  10. quantml/data/futures.py +254 -0
  11. quantml/data/loaders.py +236 -0
  12. quantml/data/memory_optimizer.py +234 -0
  13. quantml/data/validators.py +390 -0
  14. quantml/experiments/__init__.py +23 -0
  15. quantml/experiments/logger.py +208 -0
  16. quantml/experiments/results.py +158 -0
  17. quantml/experiments/tracker.py +223 -0
  18. quantml/features/__init__.py +25 -0
  19. quantml/features/base.py +104 -0
  20. quantml/features/gap_features.py +124 -0
  21. quantml/features/registry.py +138 -0
  22. quantml/features/volatility_features.py +140 -0
  23. quantml/features/volume_features.py +142 -0
  24. quantml/functional.py +37 -0
  25. quantml/models/__init__.py +27 -0
  26. quantml/models/attention.py +258 -0
  27. quantml/models/dropout.py +130 -0
  28. quantml/models/gru.py +319 -0
  29. quantml/models/linear.py +112 -0
  30. quantml/models/lstm.py +353 -0
  31. quantml/models/mlp.py +286 -0
  32. quantml/models/normalization.py +289 -0
  33. quantml/models/rnn.py +154 -0
  34. quantml/models/tcn.py +238 -0
  35. quantml/online.py +209 -0
  36. quantml/ops.py +1707 -0
  37. quantml/optim/__init__.py +42 -0
  38. quantml/optim/adafactor.py +206 -0
  39. quantml/optim/adagrad.py +157 -0
  40. quantml/optim/adam.py +267 -0
  41. quantml/optim/lookahead.py +97 -0
  42. quantml/optim/quant_optimizer.py +228 -0
  43. quantml/optim/radam.py +192 -0
  44. quantml/optim/rmsprop.py +203 -0
  45. quantml/optim/schedulers.py +286 -0
  46. quantml/optim/sgd.py +181 -0
  47. quantml/py.typed +0 -0
  48. quantml/streaming.py +175 -0
  49. quantml/tensor.py +462 -0
  50. quantml/time_series.py +447 -0
  51. quantml/training/__init__.py +135 -0
  52. quantml/training/alpha_eval.py +203 -0
  53. quantml/training/backtest.py +280 -0
  54. quantml/training/backtest_analysis.py +168 -0
  55. quantml/training/cv.py +106 -0
  56. quantml/training/data_loader.py +177 -0
  57. quantml/training/ensemble.py +84 -0
  58. quantml/training/feature_importance.py +135 -0
  59. quantml/training/features.py +364 -0
  60. quantml/training/futures_backtest.py +266 -0
  61. quantml/training/gradient_clipping.py +206 -0
  62. quantml/training/losses.py +248 -0
  63. quantml/training/lr_finder.py +127 -0
  64. quantml/training/metrics.py +376 -0
  65. quantml/training/regularization.py +89 -0
  66. quantml/training/trainer.py +239 -0
  67. quantml/training/walk_forward.py +190 -0
  68. quantml/utils/__init__.py +51 -0
  69. quantml/utils/gradient_check.py +274 -0
  70. quantml/utils/logging.py +181 -0
  71. quantml/utils/ops_cpu.py +231 -0
  72. quantml/utils/profiling.py +364 -0
  73. quantml/utils/reproducibility.py +220 -0
  74. quantml/utils/serialization.py +335 -0
  75. quantmllibrary-0.1.0.dist-info/METADATA +536 -0
  76. quantmllibrary-0.1.0.dist-info/RECORD +79 -0
  77. quantmllibrary-0.1.0.dist-info/WHEEL +5 -0
  78. quantmllibrary-0.1.0.dist-info/licenses/LICENSE +22 -0
  79. quantmllibrary-0.1.0.dist-info/top_level.txt +1 -0
quantml/training/walk_forward.py
@@ -0,0 +1,190 @@
+"""
+Walk-forward optimization for time-series model training.
+
+Walk-forward optimization is essential for quant trading to prevent lookahead bias
+and properly evaluate model performance on out-of-sample data.
+"""
+
+from typing import List, Tuple, Optional, Callable, Any
+from enum import Enum
+
+
+class WindowType(Enum):
+    """Type of walk-forward window."""
+    EXPANDING = "expanding"  # Training window grows over time
+    ROLLING = "rolling"      # Training window has fixed size
+
+
+class WalkForwardOptimizer:
+    """
+    Walk-forward optimization for time-series cross-validation.
+
+    This class manages train/test splits for time-series data, ensuring
+    no lookahead bias by only using past data for training.
+
+    Attributes:
+        window_type: Type of window (expanding or rolling)
+        train_size: Initial training window size
+        test_size: Test window size
+        step_size: Step size for moving forward
+        min_train_size: Minimum training window size
+
+    Examples:
+        >>> wfo = WalkForwardOptimizer(
+        ...     window_type=WindowType.EXPANDING,
+        ...     train_size=252,
+        ...     test_size=21
+        ... )
+        >>> for train_slice, test_slice in wfo.split(len(data), n_splits=5):
+        ...     train_data = data[train_slice]
+        ...     test_data = data[test_slice]
+    """
+
+    def __init__(
+        self,
+        window_type: WindowType = WindowType.EXPANDING,
+        train_size: int = 252,
+        test_size: int = 21,
+        step_size: Optional[int] = None,
+        min_train_size: Optional[int] = None
+    ):
+        """
+        Initialize walk-forward optimizer.
+
+        Args:
+            window_type: Type of window (expanding or rolling)
+            train_size: Initial training window size (in samples)
+            test_size: Test window size (in samples)
+            step_size: Step size for moving forward (default: test_size)
+            min_train_size: Minimum training window size (default: train_size)
+        """
+        self.window_type = window_type
+        self.train_size = train_size
+        self.test_size = test_size
+        self.step_size = step_size if step_size is not None else test_size
+        self.min_train_size = min_train_size if min_train_size is not None else train_size
+
+    def split(self, data_length: int, n_splits: Optional[int] = None) -> List[Tuple[slice, slice]]:
+        """
+        Generate train/test splits.
+
+        Args:
+            data_length: Total length of data
+            n_splits: Number of splits to generate (None for all possible)
+
+        Returns:
+            List of (train_slice, test_slice) tuples
+        """
+        splits = []
+        start_idx = self.train_size
+
+        split_count = 0
+        while start_idx + self.test_size <= data_length:
+            # Training window
+            if self.window_type == WindowType.EXPANDING:
+                train_start = 0
+                train_end = start_idx
+            else:  # ROLLING
+                train_start = start_idx - self.train_size
+                train_end = start_idx
+
+            # Test window
+            test_start = start_idx
+            test_end = min(start_idx + self.test_size, data_length)
+
+            # Ensure minimum training size
+            if train_end - train_start >= self.min_train_size:
+                train_slice = slice(train_start, train_end)
+                test_slice = slice(test_start, test_end)
+                splits.append((train_slice, test_slice))
+
+                split_count += 1
+                if n_splits is not None and split_count >= n_splits:
+                    break
+
+            # Move forward
+            start_idx += self.step_size
+
+        return splits
+
+    def get_splits(self, data_length: int, n_splits: Optional[int] = None) -> List[Tuple[List[int], List[int]]]:
+        """
+        Get train/test indices as lists.
+
+        Args:
+            data_length: Total length of data
+            n_splits: Number of splits to generate
+
+        Returns:
+            List of (train_indices, test_indices) tuples
+        """
+        splits = self.split(data_length, n_splits)
+        return [
+            (list(range(s[0].start, s[0].stop)), list(range(s[1].start, s[1].stop)))
+            for s in splits
+        ]
+
+
+def walk_forward_validation(
+    model: Any,
+    X: List,
+    y: List,
+    train_fn: Callable,
+    eval_fn: Callable,
+    window_type: WindowType = WindowType.EXPANDING,
+    train_size: int = 252,
+    test_size: int = 21,
+    n_splits: Optional[int] = None
+) -> List[dict]:
+    """
+    Perform walk-forward validation on a model.
+
+    Args:
+        model: Model to train and evaluate
+        X: Feature data
+        y: Target data
+        train_fn: Function to train the model: train_fn(model, X_train, y_train) -> trained_model
+        eval_fn: Function to evaluate it: eval_fn(model, X_test, y_test) -> metrics_dict
+        window_type: Type of window
+        train_size: Initial training window size
+        test_size: Test window size
+        n_splits: Number of splits
+
+    Returns:
+        List of evaluation metrics for each split
+    """
+    if len(X) != len(y):
+        raise ValueError("X and y must have the same length")
+
+    wfo = WalkForwardOptimizer(
+        window_type=window_type,
+        train_size=train_size,
+        test_size=test_size
+    )
+
+    results = []
+    splits = wfo.get_splits(len(X), n_splits)
+
+    for train_idx, test_idx in splits:
+        # Get train/test data
+        X_train = [X[i] for i in train_idx]
+        y_train = [y[i] for i in train_idx]
+        X_test = [X[i] for i in test_idx]
+        y_test = [y[i] for i in test_idx]
+
+        # Train model
+        trained_model = train_fn(model, X_train, y_train)
+
+        # Evaluate
+        metrics = eval_fn(trained_model, X_test, y_test)
+        metrics['train_size'] = len(train_idx)
+        metrics['test_size'] = len(test_idx)
+        metrics['train_start'] = train_idx[0]
+        metrics['train_end'] = train_idx[-1]
+        metrics['test_start'] = test_idx[0]
+        metrics['test_end'] = test_idx[-1]
+
+        results.append(metrics)
+
+    return results
+
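A minimal usage sketch of the splitter above, assuming the wheel is installed so that quantml.training.walk_forward imports as listed in the file table (the sample data is a stand-in, not from the package):

from quantml.training.walk_forward import WalkForwardOptimizer, WindowType

prices = list(range(500))  # stand-in for 500 daily observations

wfo = WalkForwardOptimizer(
    window_type=WindowType.ROLLING,  # fixed-size training window
    train_size=252,                  # roughly one trading year
    test_size=21,                    # roughly one trading month
)

# Each (train_slice, test_slice) pair is contiguous and ordered, so the
# model never sees observations from its own test window during training.
for train_slice, test_slice in wfo.split(len(prices), n_splits=3):
    train, test = prices[train_slice], prices[test_slice]
    print(train_slice, test_slice, len(train), len(test))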
quantml/utils/__init__.py
@@ -0,0 +1,51 @@
+"""
+QuantML Utilities
+
+This module provides utility functions for profiling, gradient checking,
+model serialization, and CPU-optimized operations.
+"""
+
+from quantml.utils.profiling import (
+    timing,
+    measure_latency,
+    measure_latency_microseconds,
+    PerformanceProfiler,
+    benchmark
+)
+
+from quantml.utils.gradient_check import (
+    check_gradients,
+    gradient_check_layer,
+    print_gradient_check_results,
+    quick_gradient_check
+)
+
+from quantml.utils.serialization import (
+    save_model,
+    load_model,
+    save_checkpoint,
+    load_checkpoint,
+    get_model_state_dict,
+    set_model_state_dict
+)
+
+__all__ = [
+    # Profiling
+    'timing',
+    'measure_latency',
+    'measure_latency_microseconds',
+    'PerformanceProfiler',
+    'benchmark',
+    # Gradient checking
+    'check_gradients',
+    'gradient_check_layer',
+    'print_gradient_check_results',
+    'quick_gradient_check',
+    # Serialization
+    'save_model',
+    'load_model',
+    'save_checkpoint',
+    'load_checkpoint',
+    'get_model_state_dict',
+    'set_model_state_dict',
+]
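Because this __init__ flattens the submodule exports, callers can import from the package rather than from individual files; a one-line sketch using only names taken from the diff above (their signatures live in the submodules, not shown here):

from quantml.utils import timing, check_gradients, save_model  # resolved via __all__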
quantml/utils/gradient_check.py
@@ -0,0 +1,274 @@
+"""
+Gradient checking utilities for verifying autograd correctness.
+
+This module provides numerical gradient checking to verify that the
+analytical gradients computed by the autograd engine are correct.
+"""
+
+from typing import Callable, List, Tuple, Optional, Union
+from quantml.tensor import Tensor
+
+
+def check_gradients(
+    func: Callable[[Tensor], Tensor],
+    inputs: Tensor,
+    eps: float = 1e-5,
+    rtol: float = 1e-3,
+    atol: float = 1e-5
+) -> Tuple[bool, List[dict]]:
+    """
+    Compare analytical gradients to numerical gradients.
+
+    Uses finite differences to compute numerical gradients and compares
+    them to the analytical gradients from backpropagation.
+
+    Args:
+        func: Function that takes a tensor and returns a scalar tensor
+        inputs: Input tensor to compute gradients for
+        eps: Small value for finite differences (default: 1e-5)
+        rtol: Relative tolerance for comparison (default: 1e-3)
+        atol: Absolute tolerance for comparison (default: 1e-5)
+
+    Returns:
+        Tuple of (passed, details) where:
+            - passed: True if all gradients match within tolerance
+            - details: List of dicts with 'index', 'numerical', 'analytical', 'diff'
+
+    Examples:
+        >>> def f(x):
+        ...     return ops.sum(ops.mul(x, x))  # sum(x^2)
+        >>> x = Tensor([[1.0, 2.0, 3.0]], requires_grad=True)
+        >>> passed, details = check_gradients(f, x)
+        >>> print(passed)  # True
+        >>> # Analytical gradient should be 2*x = [2.0, 4.0, 6.0]
+    """
+    # Create a copy with requires_grad=True
+    x = Tensor(inputs.data, requires_grad=True)
+
+    # Forward pass
+    y = func(x)
+
+    # Backward pass to get analytical gradients
+    y.backward()
+    analytical_grad = x.grad
+
+    # Get flattened data
+    if isinstance(x.data[0], list):
+        flat_data = [val for row in x.data for val in row]
+        shape_2d = True
+        rows = len(x.data)
+        cols = len(x.data[0])
+    else:
+        flat_data = list(x.data)
+        shape_2d = False
+        rows = 1
+        cols = len(x.data)
+
+    # Flatten analytical gradient
+    if analytical_grad is not None:
+        if isinstance(analytical_grad, list):
+            if analytical_grad and isinstance(analytical_grad[0], list):
+                flat_analytical = [val for row in analytical_grad for val in row]
+            else:
+                flat_analytical = list(analytical_grad)
+        else:
+            # NumPy array
+            flat_analytical = analytical_grad.flatten().tolist()
+    else:
+        flat_analytical = [0.0] * len(flat_data)
+
+    # Compute numerical gradients using central differences
+    numerical_grads = []
+    details = []
+    all_passed = True
+
+    for i in range(len(flat_data)):
+        # Perturb positively
+        data_plus = flat_data.copy()
+        data_plus[i] += eps
+
+        # Reshape to original shape
+        if shape_2d:
+            reshaped_plus = [data_plus[r * cols:(r + 1) * cols] for r in range(rows)]
+        else:
+            reshaped_plus = data_plus
+
+        x_plus = Tensor(reshaped_plus, requires_grad=False)
+        y_plus = func(x_plus)
+
+        # Extract scalar value
+        y_plus_val = _get_scalar(y_plus)
+
+        # Perturb negatively
+        data_minus = flat_data.copy()
+        data_minus[i] -= eps
+
+        if shape_2d:
+            reshaped_minus = [data_minus[r * cols:(r + 1) * cols] for r in range(rows)]
+        else:
+            reshaped_minus = data_minus
+
+        x_minus = Tensor(reshaped_minus, requires_grad=False)
+        y_minus = func(x_minus)
+        y_minus_val = _get_scalar(y_minus)
+
+        # Central difference
+        numerical = (y_plus_val - y_minus_val) / (2 * eps)
+        numerical_grads.append(numerical)
+
+        # Get analytical gradient for this index
+        analytical = float(flat_analytical[i]) if i < len(flat_analytical) else 0.0
+
+        # Compare
+        diff = abs(numerical - analytical)
+        rel_diff = diff / (abs(analytical) + atol) if abs(analytical) > atol else diff
+
+        passed = diff <= atol or rel_diff <= rtol
+        if not passed:
+            all_passed = False
+
+        # Compute 2D index if applicable
+        if shape_2d:
+            idx_tuple = (i // cols, i % cols)
+        else:
+            idx_tuple = (i,)
+
+        details.append({
+            'index': idx_tuple,
+            'numerical': numerical,
+            'analytical': analytical,
+            'diff': diff,
+            'rel_diff': rel_diff,
+            'passed': passed
+        })
+
+    return all_passed, details
+
+
+def _get_scalar(t: Tensor) -> float:
+    """Extract scalar value from tensor."""
+    data = t.data
+    if isinstance(data, list):
+        if len(data) == 0:
+            return 0.0
+        if isinstance(data[0], list):
+            return float(data[0][0])
+        return float(data[0])
+    # NumPy array or scalar
+    try:
+        return float(data.flat[0])
+    except (AttributeError, TypeError):
+        return float(data)
+
+
+def gradient_check_layer(
+    layer,
+    input_tensor: Tensor,
+    eps: float = 1e-5,
+    rtol: float = 1e-3,
+    atol: float = 1e-5
+) -> Tuple[bool, dict]:
+    """
+    Check gradients for a layer's parameters.
+
+    Verifies that the gradients w.r.t. the layer's weights and biases
+    are correctly computed.
+
+    Args:
+        layer: A layer with a forward() method and a parameters() method
+        input_tensor: Input to pass through the layer
+        eps: Small value for finite differences
+        rtol: Relative tolerance
+        atol: Absolute tolerance
+
+    Returns:
+        Tuple of (passed, results) where results contains gradient check
+        details for each parameter.
+
+    Examples:
+        >>> from quantml.models import Linear
+        >>> from quantml import ops
+        >>> layer = Linear(3, 2)
+        >>> x = Tensor([[1.0, 2.0, 3.0]])
+        >>> passed, results = gradient_check_layer(layer, x)
+    """
+    from quantml import ops
+
+    results = {}
+    all_passed = True
+
+    # Get parameters
+    params = layer.parameters()
+
+    for param_idx, param in enumerate(params):
+        param_name = f"param_{param_idx}"
+
+        # Define function that uses this parameter
+        def func_for_param(p):
+            # Temporarily replace parameter data
+            old_data = param.data
+            param._data = p.data if hasattr(p, 'data') else p
+            if hasattr(param, '_np_array'):
+                param._np_array = None  # Clear cached numpy
+
+            # Forward pass
+            out = layer.forward(input_tensor)
+
+            # Sum to get scalar
+            scalar_out = ops.sum(out)
+
+            # Restore
+            param._data = old_data
+            if hasattr(param, '_np_array'):
+                param._np_array = None  # Drop any cache built from the perturbed data
+
+            return scalar_out
+
+        # Check gradients for this parameter
+        passed, details = check_gradients(func_for_param, param, eps, rtol, atol)
+
+        if not passed:
+            all_passed = False
+
+        results[param_name] = {
+            'passed': passed,
+            'details': details
+        }
+
+    return all_passed, results
+
+
+def print_gradient_check_results(passed: bool, details: List[dict]) -> None:
+    """
+    Print formatted gradient check results.
+
+    Args:
+        passed: Overall pass/fail status
+        details: List of detail dicts from check_gradients
+    """
+    print("=" * 60)
+    print("GRADIENT CHECK RESULTS")
+    print("=" * 60)
+    print(f"Overall: {'PASSED ✓' if passed else 'FAILED ✗'}")
+    print("-" * 60)
+    print(f"{'Index':<15} {'Numerical':<15} {'Analytical':<15} {'Diff':<12} {'Status':<8}")
+    print("-" * 60)
+
+    for d in details:
+        status = "✓" if d['passed'] else "✗"
+        print(f"{str(d['index']):<15} {d['numerical']:<15.6f} {d['analytical']:<15.6f} {d['diff']:<12.2e} {status:<8}")
+
+    print("=" * 60)
+
+
+def quick_gradient_check(func: Callable, inputs: Tensor) -> bool:
+    """
+    Quick gradient check that returns True if gradients are correct.
+
+    Args:
+        func: Function to check
+        inputs: Input tensor
+
+    Returns:
+        True if gradients pass, False otherwise
+    """
+    passed, _ = check_gradients(func, inputs)
+    return passed
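A runnable sketch of check_gradients on f(x) = sum(x^2), whose analytical gradient is 2x. Import paths follow the file list; the behavior of ops.sum and ops.mul is assumed from the docstring examples above:

from quantml.tensor import Tensor
from quantml import ops
from quantml.utils.gradient_check import check_gradients, print_gradient_check_results

def f(x):
    # Scalar-valued function sum(x * x), so df/dx_i = 2 * x_i
    return ops.sum(ops.mul(x, x))

x = Tensor([[1.0, 2.0, 3.0]], requires_grad=True)
passed, details = check_gradients(f, x, eps=1e-5, rtol=1e-3)
print_gradient_check_results(passed, details)
# Expected analytical gradients: [2.0, 4.0, 6.0]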
quantml/utils/logging.py
@@ -0,0 +1,181 @@
+"""
+Centralized logging system for QuantML.
+
+Provides structured logging with file rotation and experiment tracking.
+"""
+
+import logging
+import sys
+import os
+from logging.handlers import RotatingFileHandler
+from typing import Optional, Dict, Any
+from datetime import datetime
+import json
+
+
+def setup_logger(
+    name: str = "quantml",
+    log_level: str = "INFO",
+    log_dir: Optional[str] = None,
+    log_file: Optional[str] = None,
+    console_output: bool = True
+) -> logging.Logger:
+    """
+    Set up a logger with file and console handlers.
+
+    Args:
+        name: Logger name
+        log_level: Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL)
+        log_dir: Directory for log files (default: ./logs)
+        log_file: Log file name (default: quantml_YYYYMMDD.log)
+        console_output: Whether to output to console
+
+    Returns:
+        Configured logger
+    """
+    logger = logging.getLogger(name)
+    logger.setLevel(getattr(logging, log_level.upper()))
+
+    # Remove existing handlers
+    logger.handlers.clear()
+
+    # Create formatters
+    detailed_formatter = logging.Formatter(
+        '%(asctime)s - %(name)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s',
+        datefmt='%Y-%m-%d %H:%M:%S'
+    )
+    simple_formatter = logging.Formatter(
+        '%(asctime)s - %(levelname)s - %(message)s',
+        datefmt='%H:%M:%S'
+    )
+
+    # Console handler
+    if console_output:
+        console_handler = logging.StreamHandler(sys.stdout)
+        console_handler.setLevel(logging.INFO)
+        console_handler.setFormatter(simple_formatter)
+        logger.addHandler(console_handler)
+
+    # File handler with rotation
+    if log_dir is None:
+        log_dir = "./logs"
+
+    os.makedirs(log_dir, exist_ok=True)
+
+    if log_file is None:
+        log_file = f"quantml_{datetime.now().strftime('%Y%m%d')}.log"
+
+    log_path = os.path.join(log_dir, log_file)
+
+    file_handler = RotatingFileHandler(
+        log_path,
+        maxBytes=10 * 1024 * 1024,  # 10MB
+        backupCount=5
+    )
+    file_handler.setLevel(logging.DEBUG)
+    file_handler.setFormatter(detailed_formatter)
+    logger.addHandler(file_handler)
+
+    return logger
+
+
+def log_experiment_start(
+    logger: logging.Logger,
+    config: Dict[str, Any],
+    experiment_id: Optional[str] = None
+):
+    """
+    Log experiment start with metadata.
+
+    Args:
+        logger: Logger instance
+        config: Experiment configuration dictionary
+        experiment_id: Optional experiment ID
+    """
+    logger.info("=" * 70)
+    logger.info("EXPERIMENT START")
+    logger.info("=" * 70)
+
+    if experiment_id:
+        logger.info(f"Experiment ID: {experiment_id}")
+
+    logger.info(f"Timestamp: {datetime.now().isoformat()}")
+    logger.info(f"Configuration: {json.dumps(config, indent=2)}")
+
+    # Try to get git hash if available
+    try:
+        import subprocess
+        git_hash = subprocess.check_output(
+            ['git', 'rev-parse', 'HEAD'],
+            stderr=subprocess.DEVNULL
+        ).decode().strip()
+        logger.info(f"Git commit: {git_hash}")
+    except Exception:
+        pass
+
+    logger.info("=" * 70)
+
+
+def log_experiment_end(
+    logger: logging.Logger,
+    metrics: Dict[str, Any],
+    experiment_id: Optional[str] = None
+):
+    """
+    Log experiment end with results.
+
+    Args:
+        logger: Logger instance
+        metrics: Experiment metrics dictionary
+        experiment_id: Optional experiment ID
+    """
+    logger.info("=" * 70)
+    logger.info("EXPERIMENT END")
+    logger.info("=" * 70)
+
+    if experiment_id:
+        logger.info(f"Experiment ID: {experiment_id}")
+
+    logger.info(f"Timestamp: {datetime.now().isoformat()}")
+    logger.info(f"Metrics: {json.dumps(metrics, indent=2)}")
+    logger.info("=" * 70)
+
+
+def log_training_progress(
+    logger: logging.Logger,
+    epoch: int,
+    train_loss: float,
+    val_loss: Optional[float] = None,
+    metrics: Optional[Dict[str, float]] = None
+):
+    """
+    Log training progress.
+
+    Args:
+        logger: Logger instance
+        epoch: Current epoch
+        train_loss: Training loss
+        val_loss: Validation loss (optional)
+        metrics: Additional metrics (optional)
+    """
+    msg = f"Epoch {epoch}: Train Loss = {train_loss:.6f}"
+    if val_loss is not None:
+        msg += f", Val Loss = {val_loss:.6f}"
+    if metrics:
+        metric_str = ", ".join([f"{k} = {v:.4f}" for k, v in metrics.items()])
+        msg += f", {metric_str}"
+    logger.info(msg)
+
+
+def get_logger(name: str = "quantml") -> logging.Logger:
+    """
+    Get or create a logger instance.
+
+    Args:
+        name: Logger name
+
+    Returns:
+        Logger instance
+    """
+    return logging.getLogger(name)
+