quantmllibrary 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (79) hide show
  1. quantml/__init__.py +74 -0
  2. quantml/autograd.py +154 -0
  3. quantml/cli/__init__.py +10 -0
  4. quantml/cli/run_experiment.py +385 -0
  5. quantml/config/__init__.py +28 -0
  6. quantml/config/config.py +259 -0
  7. quantml/data/__init__.py +33 -0
  8. quantml/data/cache.py +149 -0
  9. quantml/data/feature_store.py +234 -0
  10. quantml/data/futures.py +254 -0
  11. quantml/data/loaders.py +236 -0
  12. quantml/data/memory_optimizer.py +234 -0
  13. quantml/data/validators.py +390 -0
  14. quantml/experiments/__init__.py +23 -0
  15. quantml/experiments/logger.py +208 -0
  16. quantml/experiments/results.py +158 -0
  17. quantml/experiments/tracker.py +223 -0
  18. quantml/features/__init__.py +25 -0
  19. quantml/features/base.py +104 -0
  20. quantml/features/gap_features.py +124 -0
  21. quantml/features/registry.py +138 -0
  22. quantml/features/volatility_features.py +140 -0
  23. quantml/features/volume_features.py +142 -0
  24. quantml/functional.py +37 -0
  25. quantml/models/__init__.py +27 -0
  26. quantml/models/attention.py +258 -0
  27. quantml/models/dropout.py +130 -0
  28. quantml/models/gru.py +319 -0
  29. quantml/models/linear.py +112 -0
  30. quantml/models/lstm.py +353 -0
  31. quantml/models/mlp.py +286 -0
  32. quantml/models/normalization.py +289 -0
  33. quantml/models/rnn.py +154 -0
  34. quantml/models/tcn.py +238 -0
  35. quantml/online.py +209 -0
  36. quantml/ops.py +1707 -0
  37. quantml/optim/__init__.py +42 -0
  38. quantml/optim/adafactor.py +206 -0
  39. quantml/optim/adagrad.py +157 -0
  40. quantml/optim/adam.py +267 -0
  41. quantml/optim/lookahead.py +97 -0
  42. quantml/optim/quant_optimizer.py +228 -0
  43. quantml/optim/radam.py +192 -0
  44. quantml/optim/rmsprop.py +203 -0
  45. quantml/optim/schedulers.py +286 -0
  46. quantml/optim/sgd.py +181 -0
  47. quantml/py.typed +0 -0
  48. quantml/streaming.py +175 -0
  49. quantml/tensor.py +462 -0
  50. quantml/time_series.py +447 -0
  51. quantml/training/__init__.py +135 -0
  52. quantml/training/alpha_eval.py +203 -0
  53. quantml/training/backtest.py +280 -0
  54. quantml/training/backtest_analysis.py +168 -0
  55. quantml/training/cv.py +106 -0
  56. quantml/training/data_loader.py +177 -0
  57. quantml/training/ensemble.py +84 -0
  58. quantml/training/feature_importance.py +135 -0
  59. quantml/training/features.py +364 -0
  60. quantml/training/futures_backtest.py +266 -0
  61. quantml/training/gradient_clipping.py +206 -0
  62. quantml/training/losses.py +248 -0
  63. quantml/training/lr_finder.py +127 -0
  64. quantml/training/metrics.py +376 -0
  65. quantml/training/regularization.py +89 -0
  66. quantml/training/trainer.py +239 -0
  67. quantml/training/walk_forward.py +190 -0
  68. quantml/utils/__init__.py +51 -0
  69. quantml/utils/gradient_check.py +274 -0
  70. quantml/utils/logging.py +181 -0
  71. quantml/utils/ops_cpu.py +231 -0
  72. quantml/utils/profiling.py +364 -0
  73. quantml/utils/reproducibility.py +220 -0
  74. quantml/utils/serialization.py +335 -0
  75. quantmllibrary-0.1.0.dist-info/METADATA +536 -0
  76. quantmllibrary-0.1.0.dist-info/RECORD +79 -0
  77. quantmllibrary-0.1.0.dist-info/WHEEL +5 -0
  78. quantmllibrary-0.1.0.dist-info/licenses/LICENSE +22 -0
  79. quantmllibrary-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,231 @@
1
+ """
2
+ CPU-optimized operations for QuantML.
3
+
4
+ This module provides low-level CPU-optimized implementations of operations.
5
+ It uses NumPy when available for better performance, with pure Python fallbacks.
6
+ """
7
+
8
+ from typing import List, Union, Any
9
+ import math
10
+
11
+ # Try to import NumPy for optimized operations
12
+ try:
13
+ import numpy as np
14
+ HAS_NUMPY = True
15
+ except ImportError:
16
+ HAS_NUMPY = False
17
+
18
+
19
def add_cpu(a: Union[List, Any], b: Union[List, Any]) -> List:
    """
    CPU-optimized element-wise addition.

    Delegates to NumPy when it is installed; otherwise uses the pure
    Python fallback implementation.

    Args:
        a: First operand (scalar or nested list)
        b: Second operand (scalar or nested list)

    Returns:
        Result of addition
    """
    if not HAS_NUMPY:
        return _add_pure_python(a, b)
    return (np.array(a) + np.array(b)).tolist()
39
+
40
+
41
def mul_cpu(a: Union[List, Any], b: Union[List, Any]) -> List:
    """CPU-optimized element-wise multiplication (NumPy-backed when available)."""
    if not HAS_NUMPY:
        return _mul_pure_python(a, b)
    return (np.array(a) * np.array(b)).tolist()
50
+
51
+
52
def matmul_cpu(a: Union[List, Any], b: Union[List, Any]) -> List:
    """CPU-optimized matrix multiplication (NumPy-backed when available)."""
    if not HAS_NUMPY:
        return _matmul_pure_python(a, b)
    return np.matmul(np.array(a), np.array(b)).tolist()
61
+
62
+
63
def sum_cpu(a: Union[List, Any], axis: Union[int, None] = None) -> Union[float, List]:
    """
    CPU-optimized sum.

    Args:
        a: Input data (scalar or nested list)
        axis: Axis to reduce over, or None to sum all elements
              (annotation fixed: PEP 484 forbids implicit Optional `int = None`)

    Returns:
        Scalar float for a full reduction, list for an axis reduction
    """
    if HAS_NUMPY:
        result = np.sum(np.array(a), axis=axis)
        if isinstance(result, np.ndarray):
            return result.tolist()
        return float(result)
    else:
        return _sum_pure_python(a, axis)
73
+
74
+
75
def mean_cpu(a: Union[List, Any], axis: Union[int, None] = None) -> Union[float, List]:
    """
    CPU-optimized mean.

    Args:
        a: Input data (scalar or nested list)
        axis: Axis to reduce over, or None to average all elements
              (annotation fixed: PEP 484 forbids implicit Optional `int = None`)

    Returns:
        Scalar float for a full reduction, list for an axis reduction
    """
    if HAS_NUMPY:
        result = np.mean(np.array(a), axis=axis)
        if isinstance(result, np.ndarray):
            return result.tolist()
        return float(result)
    else:
        return _mean_pure_python(a, axis)
85
+
86
+
87
def std_cpu(a: Union[List, Any], axis: Union[int, None] = None) -> Union[float, List]:
    """
    CPU-optimized standard deviation (population std, ddof=0, as np.std).

    Args:
        a: Input data (scalar or nested list)
        axis: Axis to reduce over, or None for all elements
              (annotation fixed: PEP 484 forbids implicit Optional `int = None`)

    Returns:
        Scalar float for a full reduction, list for an axis reduction
    """
    if HAS_NUMPY:
        result = np.std(np.array(a), axis=axis)
        if isinstance(result, np.ndarray):
            return result.tolist()
        return float(result)
    else:
        return _std_pure_python(a, axis)
97
+
98
+
99
+ # Pure Python implementations (fallbacks)
100
+
101
+ def _add_pure_python(a: Any, b: Any) -> List:
102
+ """Pure Python addition."""
103
+ if isinstance(a, list) and isinstance(b, list):
104
+ if isinstance(a[0], list) and isinstance(b[0], list):
105
+ return [[a[i][j] + b[i][j] for j in range(len(a[i]))]
106
+ for i in range(len(a))]
107
+ else:
108
+ return [a[i] + b[i] for i in range(len(a))]
109
+ elif isinstance(a, list):
110
+ if isinstance(a[0], list):
111
+ return [[a[i][j] + b for j in range(len(a[i]))]
112
+ for i in range(len(a))]
113
+ return [a[i] + b for i in range(len(a))]
114
+ elif isinstance(b, list):
115
+ if isinstance(b[0], list):
116
+ return [[a + b[i][j] for j in range(len(b[i]))]
117
+ for i in range(len(b))]
118
+ return [a + b[i] for i in range(len(b))]
119
+ else:
120
+ return [[a + b]]
121
+
122
+
123
+ def _mul_pure_python(a: Any, b: Any) -> List:
124
+ """Pure Python multiplication."""
125
+ if isinstance(a, list) and isinstance(b, list):
126
+ if isinstance(a[0], list) and isinstance(b[0], list):
127
+ return [[a[i][j] * b[i][j] for j in range(len(a[i]))]
128
+ for i in range(len(a))]
129
+ else:
130
+ return [a[i] * b[i] for i in range(len(a))]
131
+ elif isinstance(a, list):
132
+ if isinstance(a[0], list):
133
+ return [[a[i][j] * b for j in range(len(a[i]))]
134
+ for i in range(len(a))]
135
+ return [a[i] * b for i in range(len(a))]
136
+ elif isinstance(b, list):
137
+ if isinstance(b[0], list):
138
+ return [[a * b[i][j] for j in range(len(b[i]))]
139
+ for i in range(len(b))]
140
+ return [a * b[i] for i in range(len(b))]
141
+ else:
142
+ return [[a * b]]
143
+
144
+
145
+ def _matmul_pure_python(a: List, b: List) -> List:
146
+ """Pure Python matrix multiplication."""
147
+ # Ensure 2D
148
+ a_2d = a if isinstance(a[0], list) else [a]
149
+ b_2d = b if isinstance(b[0], list) else [b]
150
+
151
+ m, n = len(a_2d), len(a_2d[0])
152
+ n2, p = len(b_2d), len(b_2d[0])
153
+
154
+ if n != n2:
155
+ raise ValueError(f"Incompatible dimensions: {m}x{n} x {n2}x{p}")
156
+
157
+ result = [[sum(a_2d[i][k] * b_2d[k][j] for k in range(n))
158
+ for j in range(p)]
159
+ for i in range(m)]
160
+ return result
161
+
162
+
163
+ def _sum_pure_python(a: Any, axis: int = None) -> Union[float, List]:
164
+ """Pure Python sum."""
165
+ if axis is None:
166
+ if isinstance(a, list):
167
+ if isinstance(a[0], list):
168
+ return sum(sum(row) for row in a)
169
+ return sum(a)
170
+ return float(a)
171
+ elif axis == 0:
172
+ if isinstance(a[0], list):
173
+ return [sum(a[i][j] for i in range(len(a)))
174
+ for j in range(len(a[0]))]
175
+ return [sum(a)]
176
+ elif axis == 1:
177
+ if isinstance(a[0], list):
178
+ return [sum(row) for row in a]
179
+ return a
180
+ else:
181
+ raise ValueError(f"Invalid axis: {axis}")
182
+
183
+
184
def _mean_pure_python(a: Any, axis: int = None) -> Union[float, List]:
    """Pure Python mean fallback, built on top of _sum_pure_python."""
    total = _sum_pure_python(a, axis)

    if axis is None:
        denom = 1.0
        if isinstance(a, list):
            denom = len(a) * len(a[0]) if isinstance(a[0], list) else len(a)
        return total / denom if denom > 0 else 0.0

    if axis == 0:
        # Column mean: divide each column sum by the number of rows.
        denom = len(a) if isinstance(a[0], list) else 1.0
    else:
        # Row mean: divide each row sum by the row length.
        denom = len(a[0]) if isinstance(a[0], list) else len(a)

    if isinstance(total, list):
        return [value / denom for value in total]
    return total / denom
205
+
206
+
207
+ def _std_pure_python(a: Any, axis: int = None) -> Union[float, List]:
208
+ """Pure Python standard deviation."""
209
+ m = _mean_pure_python(a, axis)
210
+
211
+ # Compute variance
212
+ if axis is None:
213
+ if isinstance(a, list):
214
+ if isinstance(a[0], list):
215
+ diff_sq = sum((a[i][j] - m) ** 2
216
+ for i in range(len(a))
217
+ for j in range(len(a[i])))
218
+ count = len(a) * len(a[0])
219
+ else:
220
+ diff_sq = sum((a[i] - m) ** 2 for i in range(len(a)))
221
+ count = len(a)
222
+ else:
223
+ diff_sq = (a - m) ** 2
224
+ count = 1.0
225
+ var = diff_sq / count if count > 0 else 0.0
226
+ return math.sqrt(var)
227
+ else:
228
+ # For axis-specific std, use simplified approach
229
+ # Full implementation would be more complex
230
+ return m # Placeholder - full implementation needed
231
+
@@ -0,0 +1,364 @@
1
+ """
2
+ Profiling and performance measurement utilities.
3
+
4
+ This module provides tools for measuring latency and performance of operations,
5
+ essential for optimizing quant trading systems.
6
+ """
7
+
8
+ import time
9
+ import functools
10
+ import sys
11
+ from typing import Callable, Any, Optional, Dict, List
12
+ from collections import defaultdict
13
+
14
+ # Try to import psutil for memory tracking
15
+ try:
16
+ import psutil
17
+ HAS_PSUTIL = True
18
+ except ImportError:
19
+ HAS_PSUTIL = False
20
+ psutil = None
21
+
22
+
23
def timing(func: Callable) -> Callable:
    """
    Decorator that reports how long the wrapped function takes to run.

    After every call it prints "<name> took X seconds" and passes the
    wrapped function's return value through unchanged.

    Args:
        func: Function to time

    Returns:
        Wrapped function that measures execution time

    Examples:
        >>> @timing
        ... def my_function():
        ...     return result
        >>> result = my_function()  # Prints: "my_function took 0.123 seconds"
    """
    @functools.wraps(func)
    def timed(*args, **kwargs):
        started = time.perf_counter()
        value = func(*args, **kwargs)
        elapsed = time.perf_counter() - started
        print(f"{func.__name__} took {elapsed:.6f} seconds")
        return value
    return timed
51
+
52
+
53
def measure_latency(func: Callable, *args, **kwargs) -> tuple:
    """
    Time a single call to *func*.

    Args:
        func: Function to measure
        *args: Positional arguments for function
        **kwargs: Keyword arguments for function

    Returns:
        Tuple of (result, latency_in_seconds)

    Examples:
        >>> result, latency = measure_latency(my_function, arg1, arg2)
        >>> print(f"Latency: {latency * 1000:.2f} ms")
    """
    t0 = time.perf_counter()
    outcome = func(*args, **kwargs)
    return outcome, time.perf_counter() - t0
74
+
75
+
76
def measure_latency_microseconds(func: Callable, *args, **kwargs) -> tuple:
    """
    Measure the latency of a single call in microseconds (useful for HFT).

    Args:
        func: Function to measure
        *args: Positional arguments
        **kwargs: Keyword arguments

    Returns:
        Tuple of (result, latency_in_microseconds)
    """
    outcome, seconds = measure_latency(func, *args, **kwargs)
    return outcome, seconds * 1_000_000
91
+
92
+
93
class PerformanceProfiler:
    """
    Aggregates latency samples across many calls.

    Stores the raw latencies per operation name and derives summary
    statistics (mean, min, max, count, total) on demand.

    Attributes:
        stats: Mapping of operation name to list of recorded latencies

    Examples:
        >>> profiler = PerformanceProfiler()
        >>> for _ in range(100):
        ...     profiler.record('my_func', measure_latency(my_function))
        >>> print(profiler.get_stats('my_func'))
    """

    def __init__(self):
        """Initialize with an empty measurement store."""
        self.stats = defaultdict(list)

    def record(self, name: str, latency: float):
        """
        Append one latency sample for an operation.

        Args:
            name: Function or operation name
            latency: Latency in seconds
        """
        self.stats[name].append(latency)

    def get_stats(self, name: str) -> Optional[dict]:
        """
        Summarize the recorded samples for one operation.

        Args:
            name: Function name

        Returns:
            Dictionary with mean, min, max, count, total, or None if no data
        """
        # .get() avoids inserting a key into the defaultdict on lookup.
        samples = self.stats.get(name)
        if not samples:
            return None

        total = sum(samples)
        return {
            'mean': total / len(samples),
            'min': min(samples),
            'max': max(samples),
            'count': len(samples),
            'total': total
        }

    def print_stats(self, name: Optional[str] = None):
        """
        Print statistics for one function, or for every tracked function.

        Args:
            name: Optional function name, or None for all
        """
        if name is None:
            for tracked in self.stats:
                self.print_stats(tracked)
            return
        summary = self.get_stats(name)
        if summary:
            print(f"{name}:")
            print(f" Mean: {summary['mean']*1000:.3f} ms")
            print(f" Min: {summary['min']*1000:.3f} ms")
            print(f" Max: {summary['max']*1000:.3f} ms")
            print(f" Count: {summary['count']}")

    def clear(self):
        """Drop every recorded sample."""
        self.stats.clear()
167
+
168
+
169
def benchmark(func: Callable, n_iterations: int = 100, *args, **kwargs) -> dict:
    """
    Benchmark a function over multiple iterations.

    Args:
        func: Function to benchmark
        n_iterations: Number of iterations (must be >= 1)
        *args: Positional arguments passed to func
        **kwargs: Keyword arguments passed to func

    Returns:
        Dictionary with mean, min, max, median, p95, p99 and count
        (all latencies in seconds)

    Raises:
        ValueError: If n_iterations is less than 1

    Examples:
        >>> stats = benchmark(my_function, 1000, arg1, arg2)
        >>> print(f"Mean latency: {stats['mean']*1000:.2f} ms")
    """
    # Previously n_iterations == 0 crashed with ZeroDivisionError; fail loudly.
    if n_iterations < 1:
        raise ValueError(f"n_iterations must be >= 1, got {n_iterations}")

    latencies = []
    for _ in range(n_iterations):
        start = time.perf_counter()
        func(*args, **kwargs)
        latencies.append(time.perf_counter() - start)

    # Sort once and reuse; the original sorted the list three times.
    ordered = sorted(latencies)
    n = len(ordered)
    return {
        'mean': sum(ordered) / n,
        'min': ordered[0],
        'max': ordered[-1],
        'median': ordered[n // 2],
        'p95': ordered[int(n * 0.95)],
        'p99': ordered[int(n * 0.99)],
        'count': n
    }
200
+
201
+
202
def get_memory_usage() -> Dict[str, float]:
    """
    Get current memory usage of this process.

    Uses psutil when available; otherwise falls back to the Unix-only
    `resource` module, and finally to zeros where neither works.

    Returns:
        Dictionary with 'rss_mb', 'vms_mb' and 'percent' keys (MB / percent)
    """
    if HAS_PSUTIL:
        process = psutil.Process()
        mem_info = process.memory_info()
        return {
            'rss_mb': mem_info.rss / 1024 / 1024,  # Resident Set Size
            'vms_mb': mem_info.vms / 1024 / 1024,  # Virtual Memory Size
            'percent': process.memory_percent()
        }
    else:
        try:
            import resource
            peak = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
            # ru_maxrss is kilobytes on Linux but BYTES on macOS; the old
            # code assumed KB everywhere and was off by ~1024x on macOS.
            divisor = 1024 * 1024 if sys.platform == 'darwin' else 1024
            return {'rss_mb': peak / divisor, 'vms_mb': 0.0, 'percent': 0.0}
        except (ImportError, AttributeError):
            # resource is unavailable (e.g. Windows) -- best-effort zeros.
            return {'rss_mb': 0.0, 'vms_mb': 0.0, 'percent': 0.0}
225
+
226
+
227
class PipelineProfiler:
    """
    Profiler for an entire pipeline (data loading, feature generation, training).

    Records wall-clock duration and memory deltas for each named stage.
    """

    def __init__(self):
        """Initialize with no stages and no memory snapshots."""
        self.stages = {}
        self.memory_snapshots = []

    def start_stage(self, stage_name: str):
        """Begin timing a stage, capturing a starting memory snapshot."""
        self.stages[stage_name] = {
            'start_time': time.perf_counter(),
            'start_memory': get_memory_usage()
        }

    def end_stage(self, stage_name: str):
        """Finish timing a stage; silently ignores unknown stage names."""
        if stage_name not in self.stages:
            return

        finished_at = time.perf_counter()
        memory_now = get_memory_usage()
        opened = self.stages[stage_name]

        # Replace the start record with the final stage summary.
        self.stages[stage_name] = {
            'duration': finished_at - opened['start_time'],
            'memory_delta': {
                'rss_mb': memory_now['rss_mb'] - opened['start_memory']['rss_mb'],
                'vms_mb': memory_now['vms_mb'] - opened['start_memory']['vms_mb']
            },
            'peak_memory': memory_now
        }

    def get_report(self) -> Dict[str, Any]:
        """Return stage details, total time, and the slowest stage name."""
        return {
            'stages': self.stages,
            'total_time': sum(s.get('duration', 0) for s in self.stages.values()),
            'bottleneck': self._identify_bottleneck()
        }

    def _identify_bottleneck(self) -> Optional[str]:
        """Return the stage with the longest duration (None if no stage ran)."""
        if not self.stages:
            return None

        slowest = None
        slowest_duration = 0
        for label, info in self.stages.items():
            # Strict > keeps the result None when every duration is zero,
            # matching the original behavior.
            if info.get('duration', 0) > slowest_duration:
                slowest_duration = info.get('duration', 0)
                slowest = label
        return slowest

    def print_report(self):
        """Pretty-print the profiling report to stdout."""
        report = self.get_report()

        print("\n" + "=" * 70)
        print("Pipeline Profiling Report")
        print("=" * 70)

        for stage_name, info in report['stages'].items():
            duration = info.get('duration', 0)
            memory_delta = info.get('memory_delta', {})
            print(f"\n{stage_name}:")
            print(f" Duration: {duration:.4f} seconds ({duration*1000:.2f} ms)")
            print(f" Memory Delta: {memory_delta.get('rss_mb', 0):.2f} MB")

        print(f"\nTotal Time: {report['total_time']:.4f} seconds")
        if report['bottleneck']:
            print(f"Bottleneck: {report['bottleneck']} ({report['stages'][report['bottleneck']]['duration']:.4f}s)")
        print("=" * 70 + "\n")
314
+
315
+
316
def profile_training_loop(
    trainer,
    n_epochs: int,
    log_interval: int = 10
) -> Dict[str, Any]:
    """
    Profile a training loop, recording per-epoch wall-clock times.

    Args:
        trainer: Trainer instance; its train_epoch() method is called once
            per epoch when present (otherwise epochs are timed as no-ops,
            assuming train_step is driven externally)
        n_epochs: Number of epochs
        log_interval: Log every N epochs (values < 1 disable logging)

    Returns:
        Training profile dictionary; timing fields are 0.0 when n_epochs == 0
    """
    profiler = PipelineProfiler()
    epoch_times = []

    profiler.start_stage('total_training')

    for epoch in range(n_epochs):
        epoch_start = time.perf_counter()

        # Train epoch (assuming trainer has train_epoch method)
        if hasattr(trainer, 'train_epoch'):
            trainer.train_epoch()

        epoch_times.append(time.perf_counter() - epoch_start)

        # Guard log_interval < 1: the original crashed on modulo-by-zero.
        if log_interval > 0 and (epoch + 1) % log_interval == 0:
            avg_time = sum(epoch_times[-log_interval:]) / log_interval
            print(f"Epoch {epoch+1}/{n_epochs}: {avg_time:.4f}s per epoch")

    profiler.end_stage('total_training')

    # Guard n_epochs == 0: the original raised ZeroDivisionError on the
    # average and ValueError on min()/max() of an empty list.
    n_done = len(epoch_times)
    return {
        'total_time': sum(epoch_times),
        'avg_epoch_time': sum(epoch_times) / n_done if n_done else 0.0,
        'min_epoch_time': min(epoch_times) if n_done else 0.0,
        'max_epoch_time': max(epoch_times) if n_done else 0.0,
        'epoch_times': epoch_times,
        'memory': get_memory_usage()
    }
364
+