textpolicy 0.0.1__py3-none-any.whl → 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68) hide show
  1. textpolicy/__init__.py +53 -0
  2. textpolicy/__main__.py +8 -0
  3. textpolicy/algorithms/__init__.py +54 -0
  4. textpolicy/algorithms/grpo.py +642 -0
  5. textpolicy/algorithms/gspo.py +582 -0
  6. textpolicy/buffer/__init__.py +23 -0
  7. textpolicy/buffer/buffer.py +244 -0
  8. textpolicy/buffer/episode.py +383 -0
  9. textpolicy/buffer/sampling.py +438 -0
  10. textpolicy/buffer/storage.py +255 -0
  11. textpolicy/cli.py +67 -0
  12. textpolicy/environment/__init__.py +79 -0
  13. textpolicy/environment/base.py +110 -0
  14. textpolicy/environment/environment.py +46 -0
  15. textpolicy/environment/factory.py +103 -0
  16. textpolicy/environment/gym.py +106 -0
  17. textpolicy/environment/task_suites.py +51 -0
  18. textpolicy/environment/text_generation.py +797 -0
  19. textpolicy/environment/vectorized.py +253 -0
  20. textpolicy/generation/__init__.py +62 -0
  21. textpolicy/generation/lora.py +411 -0
  22. textpolicy/generation/mlx_generation.py +557 -0
  23. textpolicy/generation/reload.py +253 -0
  24. textpolicy/rewards/__init__.py +137 -0
  25. textpolicy/rewards/adapters.py +387 -0
  26. textpolicy/rewards/basic.py +214 -0
  27. textpolicy/rewards/integrated_system.py +338 -0
  28. textpolicy/rewards/mlx_batch_processor.py +447 -0
  29. textpolicy/rewards/registry.py +293 -0
  30. textpolicy/rewards/rollout_rewards.py +410 -0
  31. textpolicy/rewards/verifiers.py +369 -0
  32. textpolicy/rollout/__init__.py +44 -0
  33. textpolicy/rollout/aggregator.py +145 -0
  34. textpolicy/rollout/base.py +108 -0
  35. textpolicy/rollout/rollout.py +142 -0
  36. textpolicy/rollout/runner.py +280 -0
  37. textpolicy/rollout/strategy.py +208 -0
  38. textpolicy/rollout/worker.py +194 -0
  39. textpolicy/training/__init__.py +14 -0
  40. textpolicy/training/metrics.py +242 -0
  41. textpolicy/training/rollout_manager.py +78 -0
  42. textpolicy/training/trainer.py +684 -0
  43. textpolicy/utils/__init__.py +40 -0
  44. textpolicy/utils/benchmarking.py +489 -0
  45. textpolicy/utils/data.py +60 -0
  46. textpolicy/utils/debug.py +170 -0
  47. textpolicy/utils/environment.py +349 -0
  48. textpolicy/utils/logging/__init__.py +22 -0
  49. textpolicy/utils/logging/base.py +48 -0
  50. textpolicy/utils/logging/console.py +61 -0
  51. textpolicy/utils/logging/factory.py +133 -0
  52. textpolicy/utils/logging/multi.py +83 -0
  53. textpolicy/utils/logging/tensorboard.py +65 -0
  54. textpolicy/utils/logging/wandb.py +72 -0
  55. textpolicy/utils/memory.py +118 -0
  56. textpolicy/utils/performance.py +464 -0
  57. textpolicy/utils/timing.py +171 -0
  58. textpolicy/validate.py +101 -0
  59. textpolicy/validation/__init__.py +13 -0
  60. textpolicy/validation/logprob_validation.py +315 -0
  61. textpolicy-0.1.1.dist-info/METADATA +109 -0
  62. textpolicy-0.1.1.dist-info/RECORD +66 -0
  63. {textpolicy-0.0.1.dist-info → textpolicy-0.1.1.dist-info}/WHEEL +1 -1
  64. textpolicy-0.1.1.dist-info/entry_points.txt +2 -0
  65. textpolicy-0.0.1.dist-info/METADATA +0 -10
  66. textpolicy-0.0.1.dist-info/RECORD +0 -6
  67. {textpolicy-0.0.1.dist-info → textpolicy-0.1.1.dist-info}/licenses/LICENSE +0 -0
  68. {textpolicy-0.0.1.dist-info → textpolicy-0.1.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,464 @@
1
+ """
2
+ Performance monitoring and optimization utilities for MLX-RL.
3
+
4
+ This module provides real-time performance monitoring, bottleneck detection,
5
+ and optimization recommendations for training pipelines.
6
+ """
7
+
8
+ import time
9
+ import statistics
10
+ from typing import Dict, List, Optional, Tuple, Any
11
+ from dataclasses import dataclass, field
12
+ from collections import deque
13
+ from contextlib import contextmanager
14
+ from enum import Enum
15
+
16
+ from .memory import MemoryMonitor, get_memory_stats
17
+ from .debug import debug_print
18
+
19
+
20
class PerformanceCategory(Enum):
    """Buckets used to group performance metrics by pipeline stage."""
    ENVIRONMENT = "environment"    # env.step / env.reset work
    POLICY = "policy"              # policy forward passes
    TRAINING = "training"          # optimizer / loss updates
    DATA_LOADING = "data_loading"  # batch construction and transfer
    LOGGING = "logging"            # metric emission
    MEMORY = "memory"              # allocation / cache management
    OVERALL = "overall"            # default catch-all bucket
29
+
30
+
31
@dataclass
class PerformanceMetrics:
    """Running timing statistics for one named operation."""
    name: str
    category: PerformanceCategory
    total_time: float = 0.0
    call_count: int = 0
    avg_time: float = 0.0
    min_time: float = float('inf')
    max_time: float = 0.0
    # Sliding window of the last 100 durations, used for recent-throughput stats.
    recent_times: deque = field(default_factory=lambda: deque(maxlen=100))
    memory_usage: Dict[str, float] = field(default_factory=dict)

    def update(self, duration: float, memory_stats: Optional[Dict[str, float]] = None):
        """Fold one new measurement (seconds) into the running statistics."""
        self.call_count += 1
        self.total_time += duration
        self.avg_time = self.total_time / self.call_count
        if duration < self.min_time:
            self.min_time = duration
        if duration > self.max_time:
            self.max_time = duration
        self.recent_times.append(duration)

        # Optional per-call memory deltas are merged key-by-key.
        if memory_stats:
            self.memory_usage.update(memory_stats)

    @property
    def recent_avg_time(self) -> float:
        """Mean duration over the sliding window; 0.0 when no data yet."""
        return statistics.mean(self.recent_times) if self.recent_times else 0.0

    @property
    def calls_per_second(self) -> float:
        """Throughput computed from the sliding window; 0.0 with < 2 samples."""
        if len(self.recent_times) < 2:
            return 0.0
        window_total = sum(self.recent_times)
        return len(self.recent_times) / window_total if window_total > 0 else 0.0
70
+
71
+
72
class PerformanceAlert(Enum):
    """Kinds of alerts the monitor can raise."""
    SLOW_OPERATION = "slow_operation"          # single call exceeded latency threshold
    MEMORY_HIGH = "memory_high"                # total memory above configured limit
    THROUGHPUT_DROP = "throughput_drop"        # recent average much slower than overall
    BOTTLENECK_DETECTED = "bottleneck_detected"  # operation dominates total runtime
78
+
79
+
80
@dataclass
class Alert:
    """One recorded performance alert."""
    alert_type: PerformanceAlert
    message: str
    metric_name: str
    value: float        # observed value that triggered the alert
    threshold: float    # configured limit the value was compared against
    # Wall-clock time of creation; defaults to "now".
    timestamp: float = field(default_factory=time.time)
89
+
90
+
91
class PerformanceMonitor:
    """
    Real-time performance monitoring system for MLX-RL training.

    Features:
    - Automatic bottleneck detection
    - Memory usage tracking
    - Performance regression alerts
    - Training efficiency analysis
    - Optimization recommendations
    """

    def __init__(self, alert_thresholds: Optional[Dict[str, float]] = None):
        """
        Initialize performance monitor.

        Args:
            alert_thresholds: Custom alert thresholds for different metrics;
                merged over the defaults below, so callers may override any subset.
        """
        self.metrics: Dict[str, PerformanceMetrics] = {}
        self.memory_monitor = MemoryMonitor()
        self.alerts: List[Alert] = []
        self.start_time = time.time()

        # Default alert thresholds
        self.thresholds = {
            "slow_operation_ms": 100.0,  # Operations slower than 100ms
            "memory_usage_mb": 8000.0,  # Memory usage above 8GB
            "throughput_drop_ratio": 0.7,  # Throughput drops below 70% of recent average
            "memory_growth_rate": 100.0  # Memory growing faster than 100MB/min
        }

        if alert_thresholds:
            self.thresholds.update(alert_thresholds)

    @contextmanager
    def measure(self, name: str, category: PerformanceCategory = PerformanceCategory.OVERALL,
                track_memory: bool = False):
        """
        Context manager for measuring operation performance.

        Args:
            name: Operation name
            category: Performance category (only used when the metric is first created)
            track_memory: Whether to track memory usage

        Example:
            monitor = PerformanceMonitor()
            with monitor.measure("policy_forward", PerformanceCategory.POLICY):
                action = policy(observation)
        """
        if name not in self.metrics:
            self.metrics[name] = PerformanceMetrics(name, category)

        memory_before = get_memory_stats() if track_memory else None
        start_time = time.perf_counter()

        try:
            yield
        finally:
            # Record even when the wrapped code raises, so failures still count.
            duration = time.perf_counter() - start_time
            memory_after = get_memory_stats() if track_memory else None

            # Calculate per-key memory delta if tracking (missing keys default to 0)
            memory_delta = None
            if memory_before and memory_after:
                memory_delta = {
                    key: memory_after.get(key, 0) - memory_before.get(key, 0)
                    for key in memory_after.keys()
                }

            self.metrics[name].update(duration, memory_delta)
            self._check_alerts(name, duration, memory_after)

    def record_metric(self, name: str, value: float, category: PerformanceCategory = PerformanceCategory.OVERALL):
        """
        Record a custom metric value.

        NOTE: the value is folded into the same min/avg/max machinery as
        timings, so mixing units under one name will skew those statistics.

        Args:
            name: Metric name
            value: Metric value
            category: Performance category
        """
        if name not in self.metrics:
            self.metrics[name] = PerformanceMetrics(name, category)

        self.metrics[name].update(value)

    def _check_alerts(self, metric_name: str, duration: float, memory_stats: Optional[Dict[str, float]]):
        """Check for performance alerts based on current metrics."""
        metric = self.metrics[metric_name]

        # Check for slow operations
        if duration * 1000 > self.thresholds["slow_operation_ms"]:
            alert = Alert(
                alert_type=PerformanceAlert.SLOW_OPERATION,
                message=f"Slow operation detected: {metric_name} took {duration*1000:.1f}ms",
                metric_name=metric_name,
                value=duration * 1000,
                threshold=self.thresholds["slow_operation_ms"]
            )
            self.alerts.append(alert)
            debug_print(alert.message, "performance")

        # Check for throughput drops once we have enough recent samples
        if len(metric.recent_times) >= 10:
            recent_avg = metric.recent_avg_time
            overall_avg = metric.avg_time
            if recent_avg > overall_avg * (1 / self.thresholds["throughput_drop_ratio"]):
                alert = Alert(
                    alert_type=PerformanceAlert.THROUGHPUT_DROP,
                    message=f"Throughput drop detected in {metric_name}: {recent_avg*1000:.1f}ms vs {overall_avg*1000:.1f}ms avg",
                    metric_name=metric_name,
                    value=recent_avg * 1000,
                    threshold=overall_avg * 1000
                )
                self.alerts.append(alert)
                debug_print(alert.message, "performance")

        # Check memory usage (only when the caller tracked memory)
        if memory_stats:
            total_memory = memory_stats.get("mlx_memory_mb", 0) + memory_stats.get("python_memory_mb", 0)
            if total_memory > self.thresholds["memory_usage_mb"]:
                alert = Alert(
                    alert_type=PerformanceAlert.MEMORY_HIGH,
                    message=f"High memory usage: {total_memory:.1f}MB",
                    metric_name="memory_usage",
                    value=total_memory,
                    threshold=self.thresholds["memory_usage_mb"]
                )
                self.alerts.append(alert)

    def get_bottlenecks(self, min_time_ms: float = 1.0) -> List[Tuple[str, PerformanceMetrics]]:
        """
        Identify performance bottlenecks.

        Args:
            min_time_ms: Minimum average time in ms to consider as bottleneck

        Returns:
            List of (metric_name, metrics) tuples sorted by total time (highest first)
        """
        bottlenecks = [
            (name, metric) for name, metric in self.metrics.items()
            if metric.avg_time * 1000 >= min_time_ms
        ]

        # Sort by total time spent (highest first)
        bottlenecks.sort(key=lambda x: x[1].total_time, reverse=True)
        return bottlenecks

    def get_summary(self) -> Dict[str, Any]:
        """
        Get performance monitoring summary.

        Returns:
            Dictionary with performance summary data
        """
        total_runtime = time.time() - self.start_time
        # Hoisted: used three times below.
        total_measured = sum(m.total_time for m in self.metrics.values())
        bottlenecks = self.get_bottlenecks()

        summary = {
            "total_runtime_seconds": total_runtime,
            "total_operations": sum(m.call_count for m in self.metrics.values()),
            "total_measured_time": total_measured,
            "overhead_ratio": 1.0 - (total_measured / total_runtime) if total_runtime > 0 else 0.0,
            "num_alerts": len(self.alerts),
            "num_bottlenecks": len(bottlenecks),
            "top_bottleneck": bottlenecks[0][0] if bottlenecks else None,
            "memory_stats": get_memory_stats(),
            "categories": {}
        }

        # Group metrics by category
        for metric in self.metrics.values():
            category = metric.category.value
            if category not in summary["categories"]:
                summary["categories"][category] = {
                    "total_time": 0.0,
                    "call_count": 0,
                    "operations": []
                }

            summary["categories"][category]["total_time"] += metric.total_time
            summary["categories"][category]["call_count"] += metric.call_count
            summary["categories"][category]["operations"].append(metric.name)

        return summary

    def print_summary(self, detailed: bool = True):
        """
        Print performance monitoring summary.

        Args:
            detailed: Whether to show detailed metrics
        """
        summary = self.get_summary()

        print("\nPerformance Monitoring Summary")
        print("=" * 50)
        print(f"Total runtime: {summary['total_runtime_seconds']:.1f}s")
        print(f"Total operations: {summary['total_operations']}")
        print(f"Monitoring overhead: {summary['overhead_ratio']*100:.1f}%")
        print(f"Alerts: {summary['num_alerts']}")
        print(f"Bottlenecks: {summary['num_bottlenecks']}")

        if summary['top_bottleneck']:
            print(f"Top bottleneck: {summary['top_bottleneck']}")

        # Memory stats
        memory = summary['memory_stats']
        if memory:
            print("\nMemory Usage:")
            for key, value in memory.items():
                if key.endswith('_mb'):
                    print(f"  {key}: {value:.1f} MB")

        if detailed and self.metrics:
            print("\nDetailed Metrics:")
            bottlenecks = self.get_bottlenecks(0.1)  # Show operations > 0.1ms

            print(f"{'Operation':<25} {'Calls':<8} {'Total (ms)':<12} {'Avg (ms)':<10} {'Recent (ms)':<12}")
            print("-" * 75)

            for name, metric in bottlenecks[:10]:  # Top 10 bottlenecks
                print(f"{name:<25} {metric.call_count:<8} {metric.total_time*1000:<12.1f} "
                      f"{metric.avg_time*1000:<10.2f} {metric.recent_avg_time*1000:<12.2f}")

        # Recent alerts
        if self.alerts:
            recent_alerts = [a for a in self.alerts if time.time() - a.timestamp < 300]  # Last 5 minutes
            if recent_alerts:
                print(f"\nRecent Alerts ({len(recent_alerts)}):")
                for alert in recent_alerts[-5:]:  # Last 5 alerts
                    print(f"  {alert.alert_type.value}: {alert.message}")

    def print_bottlenecks(self, top_n: int = 10):
        """
        Print identified performance bottlenecks.

        Args:
            top_n: Number of top bottlenecks to show
        """
        bottlenecks = self.get_bottlenecks()

        if not bottlenecks:
            print("No performance bottlenecks detected.")
            return

        print(f"\nTop {min(top_n, len(bottlenecks))} Performance Bottlenecks")
        print("=" * 60)

        # Hoisted out of the loop: the denominator is loop-invariant.
        # A non-empty bottleneck list guarantees this sum is > 0.
        grand_total = sum(m.total_time for m in self.metrics.values())

        for i, (name, metric) in enumerate(bottlenecks[:top_n], 1):
            percentage = (metric.total_time / grand_total) * 100
            print(f"{i}. {name}")
            print(f"   Total time: {metric.total_time*1000:.1f}ms ({percentage:.1f}% of total)")
            print(f"   Calls: {metric.call_count} (avg: {metric.avg_time*1000:.2f}ms)")
            print(f"   Range: {metric.min_time*1000:.2f}ms - {metric.max_time*1000:.2f}ms")

            # Recommendations
            if metric.avg_time > 0.1:  # > 100ms
                print("   High latency operation - consider optimization")
            elif metric.call_count > 1000 and metric.total_time > 1.0:
                print("   High frequency operation - consider caching/batching")

            print()

    def get_optimization_recommendations(self) -> List[str]:
        """
        Generate optimization recommendations based on performance data.

        Returns:
            List of optimization recommendation strings
        """
        recommendations = []
        bottlenecks = self.get_bottlenecks()

        if not bottlenecks:
            recommendations.append("No significant bottlenecks detected")
            return recommendations

        # Analyze top bottleneck
        top_name, top_metric = bottlenecks[0]

        if "environment" in top_name.lower() or top_metric.category == PerformanceCategory.ENVIRONMENT:
            if top_metric.avg_time > 0.01:  # > 10ms
                recommendations.append(f"Consider vectorizing environment '{top_name}' for better throughput")
            else:
                recommendations.append(f"Environment '{top_name}' is lightweight - vectorization may not help")

        if "policy" in top_name.lower() or top_metric.category == PerformanceCategory.POLICY:
            recommendations.append(f"Consider batching policy inference for '{top_name}'")
            recommendations.append("Profile MLX operations in policy forward pass")

        # Memory recommendations
        memory_stats = get_memory_stats()
        if memory_stats.get("mlx_memory_mb", 0) > 4000:
            recommendations.append("High MLX memory usage - consider reducing batch size")

        # General recommendations
        total_measured = sum(m.total_time for m in self.metrics.values())
        total_runtime = time.time() - self.start_time
        if total_runtime > 0 and (total_measured / total_runtime) < 0.5:
            recommendations.append("Low measurement coverage - add more performance monitoring")

        # BUG FIX: guard min(...) > 0 before dividing — perf_counter deltas for
        # trivially fast operations can round to 0.0, which previously raised
        # ZeroDivisionError here.
        high_variance_ops = [
            name for name, metric in self.metrics.items()
            if len(metric.recent_times) > 10
            and min(metric.recent_times) > 0
            and (max(metric.recent_times) / min(metric.recent_times)) > 3
        ]

        if high_variance_ops:
            recommendations.append(f"High timing variance in: {', '.join(high_variance_ops[:3])}")

        return recommendations

    def clear_alerts(self, older_than_seconds: Optional[float] = None):
        """
        Clear alerts, optionally only those older than specified time.

        Args:
            older_than_seconds: Only clear alerts older than this many seconds
        """
        if older_than_seconds is None:
            self.alerts.clear()
        else:
            cutoff_time = time.time() - older_than_seconds
            self.alerts = [a for a in self.alerts if a.timestamp > cutoff_time]

    def reset(self):
        """Reset all performance monitoring data."""
        self.metrics.clear()
        self.alerts.clear()
        self.start_time = time.time()
425
+
426
+
427
# Module-level singleton monitor; the convenience functions in this module
# (monitor_performance, get_performance_summary, etc.) all record into it.
global_monitor = PerformanceMonitor()
429
+
430
+
431
@contextmanager
def monitor_performance(name: str, category: PerformanceCategory = PerformanceCategory.OVERALL,
                        monitor: Optional[PerformanceMonitor] = None, track_memory: bool = False):
    """
    Measure a code block on a PerformanceMonitor without holding one explicitly.

    Args:
        name: Operation name
        category: Performance category
        monitor: PerformanceMonitor instance (uses the module-level global if None)
        track_memory: Whether to track memory usage

    Example:
        with monitor_performance("training_step", PerformanceCategory.TRAINING):
            loss = trainer.update(batch)
    """
    target = global_monitor if monitor is None else monitor
    with target.measure(name, category, track_memory):
        yield
450
+
451
+
452
def get_performance_summary() -> Dict[str, Any]:
    """Return the summary dictionary of the module-level global monitor."""
    return global_monitor.get_summary()
455
+
456
+
457
def print_performance_summary(detailed: bool = True):
    """Print the module-level global monitor's summary to stdout."""
    global_monitor.print_summary(detailed=detailed)
460
+
461
+
462
def get_optimization_recommendations() -> List[str]:
    """Return optimization recommendations from the module-level global monitor."""
    return global_monitor.get_optimization_recommendations()
@@ -0,0 +1,171 @@
1
+ # mlx_rl/utils/timing.py
2
+ """
3
+ Timing and performance measurement utilities.
4
+ """
5
+
6
+ import time
7
+ from typing import Dict, Optional
8
+ from contextlib import contextmanager
9
+
10
+
11
class Timer:
    """
    High-precision timer for performance measurement.

    Features:
    - Context manager support for easy use
    - Multiple named timers with aggregation; named timers may be nested
      or interleaved (e.g. timing "inner" while "outer" is still running)
    - Statistics tracking (mean, min, max, count)
    - MLX-optimized for Apple Silicon performance profiling
    """

    def __init__(self):
        """Initialize timer with empty statistics."""
        # Completed durations, keyed by timer name.
        self.times: Dict[str, list] = {}
        # BUG FIX: pending start timestamps keyed by name.  The previous
        # implementation kept a single shared start slot, so nesting or
        # interleaving named timers corrupted or lost measurements
        # (stopping an inner timer cleared the outer timer's start).
        self._start_times: Dict[str, float] = {}
        # Legacy attribute kept for backward compatibility: timestamp of the
        # most recent start() call, or None once all pending timers stopped.
        self.current_start: Optional[float] = None

    def start(self, name: str = "default"):
        """
        Start timing a named operation.

        Calling start() again for the same name restarts that timer.

        Args:
            name: Timer name for tracking multiple operations
        """
        if name not in self.times:
            self.times[name] = []
        now = time.perf_counter()
        self._start_times[name] = now
        self.current_start = now

    def stop(self, name: str = "default") -> float:
        """
        Stop timing and record the duration.

        Args:
            name: Timer name (must match start() call)

        Returns:
            Duration in seconds

        Raises:
            RuntimeError: If timer wasn't started
        """
        started = self._start_times.pop(name, None)
        if started is None:
            raise RuntimeError(f"Timer '{name}' was not started")

        duration = time.perf_counter() - started
        self.times[name].append(duration)
        if not self._start_times:
            self.current_start = None
        return duration

    @contextmanager
    def time(self, name: str = "default"):
        """
        Context manager for timing code blocks.

        Args:
            name: Timer name

        Example:
            timer = Timer()
            with timer.time("policy_forward"):
                action = policy(obs)
        """
        self.start(name)
        try:
            yield
        finally:
            # Record even if the timed block raises.
            self.stop(name)

    def get_stats(self, name: str = "default") -> Dict[str, float]:
        """
        Get timing statistics for a named timer.

        Args:
            name: Timer name

        Returns:
            Dictionary with mean, min, max, total, count statistics
            (all zeros for an unknown or empty timer)
        """
        if name not in self.times or not self.times[name]:
            return {"mean": 0.0, "min": 0.0, "max": 0.0, "total": 0.0, "count": 0}

        times = self.times[name]
        return {
            "mean": sum(times) / len(times),
            "min": min(times),
            "max": max(times),
            "total": sum(times),
            "count": len(times)
        }

    def get_all_stats(self) -> Dict[str, Dict[str, float]]:
        """Get statistics for all named timers."""
        return {name: self.get_stats(name) for name in self.times.keys()}

    def reset(self, name: Optional[str] = None):
        """
        Reset timer statistics.

        Args:
            name: Specific timer to reset (None = reset all)
        """
        if name is None:
            self.times.clear()
        elif name in self.times:
            self.times[name].clear()

    def __str__(self) -> str:
        """String representation of all timer statistics."""
        lines = ["Timer Statistics:"]
        for name, stats in self.get_all_stats().items():
            lines.append(f"  {name}: {stats['mean']:.4f}s avg ({stats['count']} calls)")
        return "\n".join(lines)
+ return "\n".join(lines)
122
+
123
+
124
# Module-level singleton timer used by time_it() when no explicit Timer is given.
global_timer = Timer()
126
+
127
+
128
@contextmanager
def time_it(name: str = "default", timer: Optional[Timer] = None):
    """
    Time a code block, recording into *timer* (or the module-level global_timer).

    Args:
        name: Timer name
        timer: Timer instance (uses global_timer if None)

    Example:
        with time_it("training_step"):
            loss = trainer.update(batch)
    """
    active = global_timer if timer is None else timer
    with active.time(name):
        yield
+
145
+
146
def benchmark_function(func, *args, iterations: int = 100, warmup: int = 10, **kwargs) -> Dict[str, float]:
    """
    Benchmark a function over repeated invocations.

    Args:
        func: Function to benchmark
        *args: Function positional arguments
        iterations: Number of timed benchmark iterations
        warmup: Number of warmup iterations (excluded from timing)
        **kwargs: Function keyword arguments

    Returns:
        Statistics dictionary (mean/min/max/total/count) over the timed runs
    """
    timer = Timer()

    # Untimed warmup runs so one-time setup costs don't skew the stats.
    for _ in range(warmup):
        func(*args, **kwargs)

    # Timed runs, aggregated under a single timer name.
    for _ in range(iterations):
        with timer.time("benchmark"):
            func(*args, **kwargs)

    return timer.get_stats("benchmark")