textpolicy 0.0.1__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- textpolicy/__init__.py +52 -0
- textpolicy/__main__.py +8 -0
- textpolicy/algorithms/__init__.py +54 -0
- textpolicy/algorithms/grpo.py +642 -0
- textpolicy/algorithms/gspo.py +582 -0
- textpolicy/buffer/__init__.py +23 -0
- textpolicy/buffer/buffer.py +244 -0
- textpolicy/buffer/episode.py +383 -0
- textpolicy/buffer/sampling.py +438 -0
- textpolicy/buffer/storage.py +255 -0
- textpolicy/cli.py +67 -0
- textpolicy/environment/__init__.py +79 -0
- textpolicy/environment/base.py +110 -0
- textpolicy/environment/environment.py +46 -0
- textpolicy/environment/factory.py +103 -0
- textpolicy/environment/gym.py +106 -0
- textpolicy/environment/task_suites.py +51 -0
- textpolicy/environment/text_generation.py +789 -0
- textpolicy/environment/vectorized.py +253 -0
- textpolicy/generation/__init__.py +62 -0
- textpolicy/generation/lora.py +411 -0
- textpolicy/generation/mlx_generation.py +557 -0
- textpolicy/generation/reload.py +253 -0
- textpolicy/rewards/__init__.py +137 -0
- textpolicy/rewards/adapters.py +387 -0
- textpolicy/rewards/basic.py +214 -0
- textpolicy/rewards/integrated_system.py +338 -0
- textpolicy/rewards/mlx_batch_processor.py +447 -0
- textpolicy/rewards/registry.py +293 -0
- textpolicy/rewards/rollout_rewards.py +410 -0
- textpolicy/rewards/verifiers.py +369 -0
- textpolicy/rollout/__init__.py +44 -0
- textpolicy/rollout/aggregator.py +145 -0
- textpolicy/rollout/base.py +108 -0
- textpolicy/rollout/rollout.py +142 -0
- textpolicy/rollout/runner.py +280 -0
- textpolicy/rollout/strategy.py +208 -0
- textpolicy/rollout/worker.py +194 -0
- textpolicy/training/__init__.py +14 -0
- textpolicy/training/metrics.py +242 -0
- textpolicy/training/rollout_manager.py +78 -0
- textpolicy/training/trainer.py +684 -0
- textpolicy/utils/__init__.py +40 -0
- textpolicy/utils/benchmarking.py +489 -0
- textpolicy/utils/data.py +60 -0
- textpolicy/utils/debug.py +170 -0
- textpolicy/utils/environment.py +349 -0
- textpolicy/utils/logging/__init__.py +22 -0
- textpolicy/utils/logging/base.py +48 -0
- textpolicy/utils/logging/console.py +61 -0
- textpolicy/utils/logging/factory.py +133 -0
- textpolicy/utils/logging/multi.py +83 -0
- textpolicy/utils/logging/tensorboard.py +65 -0
- textpolicy/utils/logging/wandb.py +72 -0
- textpolicy/utils/memory.py +118 -0
- textpolicy/utils/performance.py +464 -0
- textpolicy/utils/timing.py +171 -0
- textpolicy/validate.py +101 -0
- textpolicy/validation/__init__.py +13 -0
- textpolicy/validation/logprob_validation.py +315 -0
- textpolicy-0.1.0.dist-info/METADATA +99 -0
- textpolicy-0.1.0.dist-info/RECORD +66 -0
- textpolicy-0.1.0.dist-info/entry_points.txt +2 -0
- textpolicy-0.0.1.dist-info/METADATA +0 -10
- textpolicy-0.0.1.dist-info/RECORD +0 -6
- {textpolicy-0.0.1.dist-info → textpolicy-0.1.0.dist-info}/WHEEL +0 -0
- {textpolicy-0.0.1.dist-info → textpolicy-0.1.0.dist-info}/licenses/LICENSE +0 -0
- {textpolicy-0.0.1.dist-info → textpolicy-0.1.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,464 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Performance monitoring and optimization utilities for MLX-RL.
|
|
3
|
+
|
|
4
|
+
This module provides real-time performance monitoring, bottleneck detection,
|
|
5
|
+
and optimization recommendations for training pipelines.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import time
|
|
9
|
+
import statistics
|
|
10
|
+
from typing import Dict, List, Optional, Tuple, Any
|
|
11
|
+
from dataclasses import dataclass, field
|
|
12
|
+
from collections import deque
|
|
13
|
+
from contextlib import contextmanager
|
|
14
|
+
from enum import Enum
|
|
15
|
+
|
|
16
|
+
from .memory import MemoryMonitor, get_memory_stats
|
|
17
|
+
from .debug import debug_print
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class PerformanceCategory(Enum):
    """
    Buckets used to group timed operations.

    Every PerformanceMetrics entry carries one of these members; the summary
    produced by PerformanceMonitor groups per-category totals under the
    member's string value.
    """

    ENVIRONMENT = "environment"    # environment stepping / resets
    POLICY = "policy"              # policy forward passes
    TRAINING = "training"          # optimisation / update steps
    DATA_LOADING = "data_loading"  # batch preparation
    LOGGING = "logging"            # logger / metric sinks
    MEMORY = "memory"              # memory-related work
    OVERALL = "overall"            # default bucket when none is given
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
@dataclass
class PerformanceMetrics:
    """
    Running statistics for a single named operation.

    Aggregates (total / count / avg / min / max) cover every sample passed to
    ``update``; ``recent_times`` keeps only the last 100 samples. Samples are
    seconds when they come from a timed region, but ``update`` accepts any
    float value.
    """
    name: str
    category: PerformanceCategory
    total_time: float = 0.0  # sum of all samples
    call_count: int = 0  # number of samples folded in
    avg_time: float = 0.0  # total_time / call_count, refreshed by update()
    min_time: float = float('inf')  # stays inf until the first update()
    max_time: float = 0.0
    recent_times: deque = field(default_factory=lambda: deque(maxlen=100))  # sliding window of samples
    memory_usage: Dict[str, float] = field(default_factory=dict)  # latest memory figures passed to update()

    def update(self, duration: float, memory_stats: Optional[Dict[str, float]] = None):
        """
        Fold one new sample into the metrics.

        Args:
            duration: Sample value to record.
            memory_stats: Optional memory figures; each key overwrites the value
                previously stored under it in ``memory_usage``.
        """
        self.total_time += duration
        self.call_count += 1
        self.avg_time = self.total_time / self.call_count
        self.min_time = min(self.min_time, duration)
        self.max_time = max(self.max_time, duration)
        self.recent_times.append(duration)

        if memory_stats:
            self.memory_usage.update(memory_stats)

    @property
    def recent_avg_time(self) -> float:
        """Mean of the samples currently in the recent window (0.0 if empty)."""
        window = self.recent_times
        if not window:
            return 0.0
        # sum/len rather than statistics.mean: this property is read by the
        # monitor's alert check after measured calls, and statistics.mean is a
        # far slower pure-Python exact-arithmetic implementation.
        return sum(window) / len(window)

    @property
    def calls_per_second(self) -> float:
        """
        Calls per second implied by the recent window.

        Computed as window size divided by the summed sample durations; needs
        at least two samples and a positive sum, otherwise returns 0.0.
        """
        window = self.recent_times
        if len(window) < 2:
            return 0.0
        busy_time = sum(window)
        return len(window) / busy_time if busy_time > 0 else 0.0
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
class PerformanceAlert(Enum):
    """
    Kinds of alert a PerformanceMonitor can raise.

    The member's string value is what the monitor prints next to an alert's
    message in its summary output.
    """

    SLOW_OPERATION = "slow_operation"            # single call exceeded the time limit
    MEMORY_HIGH = "memory_high"                  # memory total exceeded the limit
    THROUGHPUT_DROP = "throughput_drop"          # recent average slower than overall average
    BOTTLENECK_DETECTED = "bottleneck_detected"  # NOTE(review): not raised by the monitor in this file
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
@dataclass
class Alert:
    """
    A single alert record produced by a PerformanceMonitor check.

    ``value`` and ``threshold`` share the unit of the check that fired:
    milliseconds for timing alerts, megabytes for memory alerts.
    """
    alert_type: PerformanceAlert  # which check fired
    message: str  # human-readable description
    metric_name: str  # operation name, or "memory_usage" for memory alerts
    value: float  # observed value that crossed the limit
    threshold: float  # limit the value was compared against
    timestamp: float = field(default_factory=time.time)  # creation time from time.time()
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
class PerformanceMonitor:
    """
    Real-time performance monitoring system for MLX-RL training.

    Features:
    - Automatic bottleneck detection
    - Memory usage tracking
    - Performance regression alerts
    - Training efficiency analysis
    - Optimization recommendations

    Per-operation statistics live in ``self.metrics`` (keyed by operation
    name); alerts are appended to ``self.alerts`` as checks fire.
    """

    def __init__(self, alert_thresholds: Optional[Dict[str, float]] = None):
        """
        Initialize performance monitor.

        Args:
            alert_thresholds: Custom alert thresholds; entries override the
                defaults below under the same keys.
        """
        self.metrics: Dict[str, PerformanceMetrics] = {}
        self.memory_monitor = MemoryMonitor()
        self.alerts: List[Alert] = []
        # Wall-clock origin used for total-runtime figures.
        self.start_time = time.time()

        # Default alert thresholds
        self.thresholds = {
            "slow_operation_ms": 100.0,  # Operations slower than 100ms
            "memory_usage_mb": 8000.0,  # Memory usage above 8GB
            "throughput_drop_ratio": 0.7,  # Throughput drops below 70% of recent average
            # NOTE(review): not read anywhere in this class.
            "memory_growth_rate": 100.0,  # Memory growing faster than 100MB/min
        }

        if alert_thresholds:
            self.thresholds.update(alert_thresholds)

    @contextmanager
    def measure(self, name: str, category: PerformanceCategory = PerformanceCategory.OVERALL,
                track_memory: bool = False):
        """
        Context manager for measuring operation performance.

        The elapsed time is recorded and alerts are checked even when the body
        raises, because the bookkeeping runs in a ``finally`` clause.

        Args:
            name: Operation name
            category: Performance category (only applied the first time
                ``name`` is seen)
            track_memory: Whether to snapshot memory stats before and after
                the body

        Example:
            monitor = PerformanceMonitor()
            with monitor.measure("policy_forward", PerformanceCategory.POLICY):
                action = policy(observation)
        """
        if name not in self.metrics:
            self.metrics[name] = PerformanceMetrics(name, category)

        memory_before = get_memory_stats() if track_memory else None
        start_time = time.perf_counter()

        try:
            yield
        finally:
            duration = time.perf_counter() - start_time
            memory_after = get_memory_stats() if track_memory else None

            # Per-key memory change across the body, when tracking is on.
            memory_delta = None
            if memory_before and memory_after:
                memory_delta = {
                    key: memory_after.get(key, 0) - memory_before.get(key, 0)
                    for key in memory_after.keys()
                }

            self.metrics[name].update(duration, memory_delta)
            # The memory alert looks at the absolute post-body stats, not the delta.
            self._check_alerts(name, duration, memory_after)

    def record_metric(self, name: str, value: float, category: PerformanceCategory = PerformanceCategory.OVERALL):
        """
        Record a custom metric value.

        The value is folded into the same statistics as timed operations; no
        alert checks are run for it.

        Args:
            name: Metric name
            value: Metric value
            category: Performance category
        """
        if name not in self.metrics:
            self.metrics[name] = PerformanceMetrics(name, category)

        self.metrics[name].update(value)

    def _emit_alert(self, alert: Alert):
        """Store an alert and echo its message through debug_print."""
        self.alerts.append(alert)
        debug_print(alert.message, "performance")

    def _check_alerts(self, metric_name: str, duration: float, memory_stats: Optional[Dict[str, float]]):
        """Check for performance alerts based on current metrics."""
        metric = self.metrics[metric_name]

        # Check for slow operations
        slow_limit_ms = self.thresholds["slow_operation_ms"]
        if duration * 1000 > slow_limit_ms:
            self._emit_alert(Alert(
                alert_type=PerformanceAlert.SLOW_OPERATION,
                message=f"Slow operation detected: {metric_name} took {duration*1000:.1f}ms",
                metric_name=metric_name,
                value=duration * 1000,
                threshold=slow_limit_ms
            ))

        # Check for throughput drops
        if len(metric.recent_times) >= 10:
            recent_avg = metric.recent_avg_time
            overall_avg = metric.avg_time
            drop_ratio = self.thresholds["throughput_drop_ratio"]
            # Same test as recent_avg > overall_avg / drop_ratio for a positive
            # ratio, written without division so a ratio of 0 cannot raise
            # ZeroDivisionError; a non-positive ratio disables this check.
            if drop_ratio > 0 and recent_avg * drop_ratio > overall_avg:
                self._emit_alert(Alert(
                    alert_type=PerformanceAlert.THROUGHPUT_DROP,
                    message=f"Throughput drop detected in {metric_name}: {recent_avg*1000:.1f}ms vs {overall_avg*1000:.1f}ms avg",
                    metric_name=metric_name,
                    value=recent_avg * 1000,
                    threshold=overall_avg * 1000
                ))

        # Check memory usage
        if memory_stats:
            total_memory = memory_stats.get("mlx_memory_mb", 0) + memory_stats.get("python_memory_mb", 0)
            memory_limit_mb = self.thresholds["memory_usage_mb"]
            if total_memory > memory_limit_mb:
                self._emit_alert(Alert(
                    alert_type=PerformanceAlert.MEMORY_HIGH,
                    message=f"High memory usage: {total_memory:.1f}MB",
                    metric_name="memory_usage",
                    value=total_memory,
                    threshold=memory_limit_mb
                ))

    def get_bottlenecks(self, min_time_ms: float = 1.0) -> List[Tuple[str, PerformanceMetrics]]:
        """
        Identify performance bottlenecks.

        Args:
            min_time_ms: Minimum average time in ms to consider as bottleneck

        Returns:
            List of (metric_name, metrics) tuples sorted by total time
        """
        bottlenecks = [
            (name, metric)
            for name, metric in self.metrics.items()
            if metric.avg_time * 1000 >= min_time_ms
        ]

        # Sort by total time spent (highest first)
        bottlenecks.sort(key=lambda x: x[1].total_time, reverse=True)
        return bottlenecks

    def get_summary(self) -> Dict[str, Any]:
        """
        Get performance monitoring summary.

        Returns:
            Dictionary with performance summary data
        """
        total_runtime = time.time() - self.start_time
        bottlenecks = self.get_bottlenecks()
        total_measured = sum(m.total_time for m in self.metrics.values())

        summary = {
            "total_runtime_seconds": total_runtime,
            "total_operations": sum(m.call_count for m in self.metrics.values()),
            "total_measured_time": total_measured,
            # Fraction of wall-clock runtime not covered by recorded samples.
            # Can be negative when measured regions overlap (nested measure()).
            "overhead_ratio": 1.0 - (total_measured / total_runtime) if total_runtime > 0 else 0.0,
            "num_alerts": len(self.alerts),
            "num_bottlenecks": len(bottlenecks),
            "top_bottleneck": bottlenecks[0][0] if bottlenecks else None,
            "memory_stats": get_memory_stats(),
            "categories": {}
        }

        # Group metrics by category
        for metric in self.metrics.values():
            category = metric.category.value
            if category not in summary["categories"]:
                summary["categories"][category] = {
                    "total_time": 0.0,
                    "call_count": 0,
                    "operations": []
                }

            summary["categories"][category]["total_time"] += metric.total_time
            summary["categories"][category]["call_count"] += metric.call_count
            summary["categories"][category]["operations"].append(metric.name)

        return summary

    def print_summary(self, detailed: bool = True):
        """
        Print performance monitoring summary.

        Args:
            detailed: Whether to show detailed metrics
        """
        summary = self.get_summary()

        print("\nPerformance Monitoring Summary")
        print("=" * 50)
        print(f"Total runtime: {summary['total_runtime_seconds']:.1f}s")
        print(f"Total operations: {summary['total_operations']}")
        print(f"Monitoring overhead: {summary['overhead_ratio']*100:.1f}%")
        print(f"Alerts: {summary['num_alerts']}")
        print(f"Bottlenecks: {summary['num_bottlenecks']}")

        if summary['top_bottleneck']:
            print(f"Top bottleneck: {summary['top_bottleneck']}")

        # Memory stats
        memory = summary['memory_stats']
        if memory:
            print("\nMemory Usage:")
            for key, value in memory.items():
                if key.endswith('_mb'):
                    print(f" {key}: {value:.1f} MB")

        if detailed and self.metrics:
            print("\nDetailed Metrics:")
            bottlenecks = self.get_bottlenecks(0.1)  # Show operations > 0.1ms

            print(f"{'Operation':<25} {'Calls':<8} {'Total (ms)':<12} {'Avg (ms)':<10} {'Recent (ms)':<12}")
            print("-" * 75)

            for name, metric in bottlenecks[:10]:  # Top 10 bottlenecks
                print(f"{name:<25} {metric.call_count:<8} {metric.total_time*1000:<12.1f} "
                      f"{metric.avg_time*1000:<10.2f} {metric.recent_avg_time*1000:<12.2f}")

        # Recent alerts
        if self.alerts:
            recent_alerts = [a for a in self.alerts if time.time() - a.timestamp < 300]  # Last 5 minutes
            if recent_alerts:
                print(f"\nRecent Alerts ({len(recent_alerts)}):")
                for alert in recent_alerts[-5:]:  # Last 5 alerts
                    print(f" {alert.alert_type.value}: {alert.message}")

    def print_bottlenecks(self, top_n: int = 10):
        """
        Print identified performance bottlenecks.

        Args:
            top_n: Number of top bottlenecks to show
        """
        bottlenecks = self.get_bottlenecks()

        if not bottlenecks:
            print("No performance bottlenecks detected.")
            return

        # Computed once instead of once per printed row; guarded below so an
        # all-zero total cannot raise ZeroDivisionError.
        total_measured = sum(m.total_time for m in self.metrics.values())

        print(f"\nTop {min(top_n, len(bottlenecks))} Performance Bottlenecks")
        print("=" * 60)

        for i, (name, metric) in enumerate(bottlenecks[:top_n], 1):
            percentage = (metric.total_time / total_measured) * 100 if total_measured else 0.0
            print(f"{i}. {name}")
            print(f" Total time: {metric.total_time*1000:.1f}ms ({percentage:.1f}% of total)")
            print(f" Calls: {metric.call_count} (avg: {metric.avg_time*1000:.2f}ms)")
            print(f" Range: {metric.min_time*1000:.2f}ms - {metric.max_time*1000:.2f}ms")

            # Recommendations
            if metric.avg_time > 0.1:  # > 100ms
                print(" High latency operation - consider optimization")
            elif metric.call_count > 1000 and metric.total_time > 1.0:
                print(" High frequency operation - consider caching/batching")

            print()

    def get_optimization_recommendations(self) -> List[str]:
        """
        Generate optimization recommendations based on performance data.

        Returns:
            List of optimization recommendation strings
        """
        recommendations = []
        bottlenecks = self.get_bottlenecks()

        if not bottlenecks:
            recommendations.append("No significant bottlenecks detected")
            return recommendations

        # Analyze top bottleneck
        top_name, top_metric = bottlenecks[0]

        if "environment" in top_name.lower() or top_metric.category == PerformanceCategory.ENVIRONMENT:
            if top_metric.avg_time > 0.01:  # > 10ms
                recommendations.append(f"Consider vectorizing environment '{top_name}' for better throughput")
            else:
                recommendations.append(f"Environment '{top_name}' is lightweight - vectorization may not help")

        if "policy" in top_name.lower() or top_metric.category == PerformanceCategory.POLICY:
            recommendations.append(f"Consider batching policy inference for '{top_name}'")
            recommendations.append("Profile MLX operations in policy forward pass")

        # Memory recommendations
        memory_stats = get_memory_stats()
        if memory_stats.get("mlx_memory_mb", 0) > 4000:
            recommendations.append("High MLX memory usage - consider reducing batch size")

        # General recommendations
        total_measured = sum(m.total_time for m in self.metrics.values())
        total_runtime = time.time() - self.start_time
        if total_runtime > 0 and (total_measured / total_runtime) < 0.5:
            recommendations.append("Low measurement coverage - add more performance monitoring")

        high_variance_ops = [
            name for name, metric in self.metrics.items()
            if len(metric.recent_times) > 10 and self._is_high_variance(metric.recent_times)
        ]

        if high_variance_ops:
            recommendations.append(f"High timing variance in: {', '.join(high_variance_ops[:3])}")

        return recommendations

    @staticmethod
    def _is_high_variance(samples) -> bool:
        """
        Return True when the largest sample exceeds 3x the smallest.

        Only a strictly positive minimum is compared. A 0.0 sample (for example
        from ``record_metric(name, 0.0)``) would otherwise make max/min raise
        ZeroDivisionError, and a ratio against a negative minimum is meaningless.
        """
        lowest = min(samples)
        if lowest <= 0:
            return False
        return max(samples) / lowest > 3

    def clear_alerts(self, older_than_seconds: Optional[float] = None):
        """
        Clear alerts, optionally only those older than specified time.

        Args:
            older_than_seconds: Only clear alerts older than this many seconds
        """
        if older_than_seconds is None:
            self.alerts.clear()
        else:
            cutoff_time = time.time() - older_than_seconds
            self.alerts = [a for a in self.alerts if a.timestamp > cutoff_time]

    def reset(self):
        """Reset all performance monitoring data."""
        self.metrics.clear()
        self.alerts.clear()
        self.start_time = time.time()
|
|
425
|
+
|
|
426
|
+
|
|
427
|
+
# Process-wide default monitor; monitor_performance() and the module-level
# summary / recommendation helpers fall back to it.
global_monitor: PerformanceMonitor = PerformanceMonitor()
|
|
429
|
+
|
|
430
|
+
|
|
431
|
+
@contextmanager
def monitor_performance(name: str, category: PerformanceCategory = PerformanceCategory.OVERALL,
                        monitor: Optional[PerformanceMonitor] = None, track_memory: bool = False):
    """
    Time a block of code on a PerformanceMonitor.

    Thin wrapper around ``PerformanceMonitor.measure`` that targets the
    module's ``global_monitor`` unless another monitor is supplied.

    Args:
        name: Operation name
        category: Performance category
        monitor: PerformanceMonitor instance (uses global if None)
        track_memory: Whether to track memory usage

    Example:
        with monitor_performance("training_step", PerformanceCategory.TRAINING):
            loss = trainer.update(batch)
    """
    target = global_monitor if monitor is None else monitor
    with target.measure(name, category, track_memory):
        yield
|
|
450
|
+
|
|
451
|
+
|
|
452
|
+
def get_performance_summary() -> Dict[str, Any]:
    """Return the summary dictionary of the shared ``global_monitor``."""
    summary = global_monitor.get_summary()
    return summary
|
|
455
|
+
|
|
456
|
+
|
|
457
|
+
def print_performance_summary(detailed: bool = True):
    """Print the shared ``global_monitor`` summary; ``detailed`` adds the per-operation table."""
    global_monitor.print_summary(detailed=detailed)
|
|
460
|
+
|
|
461
|
+
|
|
462
|
+
def get_optimization_recommendations() -> List[str]:
    """Return the optimization hints computed by the shared ``global_monitor``."""
    hints = global_monitor.get_optimization_recommendations()
    return hints
|
|
@@ -0,0 +1,171 @@
|
|
|
1
|
+
# mlx_rl/utils/timing.py
|
|
2
|
+
"""
|
|
3
|
+
Timing and performance measurement utilities.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
import time
|
|
7
|
+
from typing import Dict, Optional
|
|
8
|
+
from contextlib import contextmanager
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class Timer:
    """
    High-precision timer for performance measurement.

    Features:
    - Context manager support for easy use
    - Multiple named timers with aggregation; different names may run
      nested or interleaved, and the same name may be re-entered
    - Statistics tracking (mean, min, max, count)

    Durations come from ``time.perf_counter``. No device synchronization is
    done, so lazily evaluated array work is only counted if it is forced
    inside the timed region.
    """

    def __init__(self):
        """Initialize timer with empty statistics."""
        self.times: Dict[str, list] = {}
        # Start timestamps of running timers, per name. A stack per name lets
        # the same name be nested; stop() pops the most recent start.
        self._starts: Dict[str, list] = {}
        # Kept for backward compatibility: start time of the most recently
        # started timer that is still running, or None when none is running.
        self.current_start: Optional[float] = None

    def start(self, name: str = "default"):
        """
        Start timing a named operation.

        Args:
            name: Timer name for tracking multiple operations
        """
        if name not in self.times:
            self.times[name] = []
        started = time.perf_counter()
        self._starts.setdefault(name, []).append(started)
        self.current_start = started

    def stop(self, name: str = "default") -> float:
        """
        Stop timing and record the duration.

        Args:
            name: Timer name (must match start() call)

        Returns:
            Duration in seconds

        Raises:
            RuntimeError: If timer wasn't started
        """
        # Read the clock first so bookkeeping is not counted in the duration.
        now = time.perf_counter()
        pending = self._starts.get(name)
        if not pending:
            raise RuntimeError(f"Timer '{name}' was not started")

        duration = now - pending.pop()
        if not pending:
            del self._starts[name]
        # setdefault: reset() may have dropped this name while it was running.
        self.times.setdefault(name, []).append(duration)
        self.current_start = max(
            (s for stack in self._starts.values() for s in stack), default=None
        )
        return duration

    @contextmanager
    def time(self, name: str = "default"):
        """
        Context manager for timing code blocks.

        The duration is recorded even if the block raises.

        Args:
            name: Timer name

        Example:
            timer = Timer()
            with timer.time("policy_forward"):
                action = policy(obs)
        """
        self.start(name)
        try:
            yield
        finally:
            self.stop(name)

    def get_stats(self, name: str = "default") -> Dict[str, float]:
        """
        Get timing statistics for a named timer.

        Args:
            name: Timer name

        Returns:
            Dictionary with mean, min, max, total, count statistics
            (all zero when nothing has been recorded for ``name``)
        """
        if name not in self.times or not self.times[name]:
            return {"mean": 0.0, "min": 0.0, "max": 0.0, "total": 0.0, "count": 0}

        times = self.times[name]
        return {
            "mean": sum(times) / len(times),
            "min": min(times),
            "max": max(times),
            "total": sum(times),
            "count": len(times)
        }

    def get_all_stats(self) -> Dict[str, Dict[str, float]]:
        """Get statistics for all named timers."""
        return {name: self.get_stats(name) for name in self.times.keys()}

    def reset(self, name: Optional[str] = None):
        """
        Reset timer statistics.

        Running timers are left running; their next stop() records into a
        fresh list.

        Args:
            name: Specific timer to reset (None = reset all)
        """
        if name is None:
            self.times.clear()
        elif name in self.times:
            self.times[name].clear()

    def __str__(self) -> str:
        """String representation of all timer statistics."""
        lines = ["Timer Statistics:"]
        for name, stats in self.get_all_stats().items():
            lines.append(f" {name}: {stats['mean']:.4f}s avg ({stats['count']} calls)")
        return "\n".join(lines)
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
# Module-wide Timer used by time_it() when no explicit timer is passed.
global_timer: Timer = Timer()
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
@contextmanager
def time_it(name: str = "default", timer: Optional[Timer] = None):
    """
    Time a code block under ``name``.

    Delegates to ``Timer.time`` on the given timer, or on the module's
    ``global_timer`` when ``timer`` is None.

    Args:
        name: Timer name
        timer: Timer instance (uses global_timer if None)

    Example:
        with time_it("training_step"):
            loss = trainer.update(batch)
    """
    active = global_timer if timer is None else timer
    with active.time(name):
        yield
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
def benchmark_function(func, *args, iterations: int = 100, warmup: int = 10, **kwargs) -> Dict[str, float]:
    """
    Time repeated calls of ``func(*args, **kwargs)``.

    ``warmup`` untimed calls are made first, then ``iterations`` calls are
    each timed individually.

    Args:
        func: Function to benchmark
        *args: Function positional arguments
        iterations: Number of benchmark iterations
        warmup: Number of warmup iterations (excluded from timing)
        **kwargs: Function keyword arguments

    Returns:
        ``Timer.get_stats`` dictionary ("mean", "min", "max", "total",
        "count", durations in seconds); all zeros when ``iterations`` <= 0.
    """
    stopwatch = Timer()

    # Untimed warm-up calls.
    for _ in range(warmup):
        func(*args, **kwargs)

    # Each timed call is recorded as a separate sample under "benchmark".
    for _ in range(iterations):
        with stopwatch.time("benchmark"):
            func(*args, **kwargs)

    return stopwatch.get_stats("benchmark")
|