textpolicy 0.0.1__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- textpolicy/__init__.py +52 -0
- textpolicy/__main__.py +8 -0
- textpolicy/algorithms/__init__.py +54 -0
- textpolicy/algorithms/grpo.py +642 -0
- textpolicy/algorithms/gspo.py +582 -0
- textpolicy/buffer/__init__.py +23 -0
- textpolicy/buffer/buffer.py +244 -0
- textpolicy/buffer/episode.py +383 -0
- textpolicy/buffer/sampling.py +438 -0
- textpolicy/buffer/storage.py +255 -0
- textpolicy/cli.py +67 -0
- textpolicy/environment/__init__.py +79 -0
- textpolicy/environment/base.py +110 -0
- textpolicy/environment/environment.py +46 -0
- textpolicy/environment/factory.py +103 -0
- textpolicy/environment/gym.py +106 -0
- textpolicy/environment/task_suites.py +51 -0
- textpolicy/environment/text_generation.py +789 -0
- textpolicy/environment/vectorized.py +253 -0
- textpolicy/generation/__init__.py +62 -0
- textpolicy/generation/lora.py +411 -0
- textpolicy/generation/mlx_generation.py +557 -0
- textpolicy/generation/reload.py +253 -0
- textpolicy/rewards/__init__.py +137 -0
- textpolicy/rewards/adapters.py +387 -0
- textpolicy/rewards/basic.py +214 -0
- textpolicy/rewards/integrated_system.py +338 -0
- textpolicy/rewards/mlx_batch_processor.py +447 -0
- textpolicy/rewards/registry.py +293 -0
- textpolicy/rewards/rollout_rewards.py +410 -0
- textpolicy/rewards/verifiers.py +369 -0
- textpolicy/rollout/__init__.py +44 -0
- textpolicy/rollout/aggregator.py +145 -0
- textpolicy/rollout/base.py +108 -0
- textpolicy/rollout/rollout.py +142 -0
- textpolicy/rollout/runner.py +280 -0
- textpolicy/rollout/strategy.py +208 -0
- textpolicy/rollout/worker.py +194 -0
- textpolicy/training/__init__.py +14 -0
- textpolicy/training/metrics.py +242 -0
- textpolicy/training/rollout_manager.py +78 -0
- textpolicy/training/trainer.py +684 -0
- textpolicy/utils/__init__.py +40 -0
- textpolicy/utils/benchmarking.py +489 -0
- textpolicy/utils/data.py +60 -0
- textpolicy/utils/debug.py +170 -0
- textpolicy/utils/environment.py +349 -0
- textpolicy/utils/logging/__init__.py +22 -0
- textpolicy/utils/logging/base.py +48 -0
- textpolicy/utils/logging/console.py +61 -0
- textpolicy/utils/logging/factory.py +133 -0
- textpolicy/utils/logging/multi.py +83 -0
- textpolicy/utils/logging/tensorboard.py +65 -0
- textpolicy/utils/logging/wandb.py +72 -0
- textpolicy/utils/memory.py +118 -0
- textpolicy/utils/performance.py +464 -0
- textpolicy/utils/timing.py +171 -0
- textpolicy/validate.py +101 -0
- textpolicy/validation/__init__.py +13 -0
- textpolicy/validation/logprob_validation.py +315 -0
- textpolicy-0.1.0.dist-info/METADATA +99 -0
- textpolicy-0.1.0.dist-info/RECORD +66 -0
- textpolicy-0.1.0.dist-info/entry_points.txt +2 -0
- textpolicy-0.0.1.dist-info/METADATA +0 -10
- textpolicy-0.0.1.dist-info/RECORD +0 -6
- {textpolicy-0.0.1.dist-info → textpolicy-0.1.0.dist-info}/WHEEL +0 -0
- {textpolicy-0.0.1.dist-info → textpolicy-0.1.0.dist-info}/licenses/LICENSE +0 -0
- {textpolicy-0.0.1.dist-info → textpolicy-0.1.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,40 @@
|
|
|
1
|
+
# textpolicy/utils/__init__.py
"""
General utilities for TextPolicy.

Re-exported from this package, organized by functionality:
- logging: Multiple backend logging support
- timing: Performance measurement and benchmarking
- memory: Memory monitoring and cleanup
- data: Data conversion and preprocessing

Other ``textpolicy.utils`` submodules (``benchmarking``, ``debug``,
``environment``, ``performance``) are not re-exported here and must be
imported from their own modules.
"""

# Import logging utilities
from .logging import (
    Logger, WandbLogger, TensorboardLogger, ConsoleLogger, MultiLogger,
    create_logger, create_multi_logger, create_auto_logger
)

# Import other utilities
from .timing import Timer, global_timer, time_it, benchmark_function
from .memory import get_memory_stats, clear_memory, MemoryMonitor
from .data import to_mlx, to_numpy, batch_to_mlx

# Backwards compatibility for existing imports: also re-export whatever else
# ``textpolicy.utils.logging`` exposes to ``import *``.
from .logging import *  # noqa: F401,F403

# Public API of this package. Names only reachable through the star import
# above are importable but intentionally not listed here.
__all__ = [
    # Logging utilities
    'Logger', 'WandbLogger', 'TensorboardLogger', 'ConsoleLogger', 'MultiLogger',
    'create_logger', 'create_multi_logger', 'create_auto_logger',

    # Timing utilities
    'Timer', 'global_timer', 'time_it', 'benchmark_function',

    # Memory utilities
    'get_memory_stats', 'clear_memory', 'MemoryMonitor',

    # Data utilities
    'to_mlx', 'to_numpy', 'batch_to_mlx',
]
|
|
@@ -0,0 +1,489 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Performance benchmarking utilities for MLX-RL.
|
|
3
|
+
|
|
4
|
+
This module provides tools for comprehensive performance analysis, comparing
|
|
5
|
+
different configurations, and generating detailed performance reports.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import time
|
|
9
|
+
import statistics
|
|
10
|
+
from typing import Dict, List, Optional, Any, Callable
|
|
11
|
+
from dataclasses import dataclass, field
|
|
12
|
+
from contextlib import contextmanager
|
|
13
|
+
from enum import Enum
|
|
14
|
+
|
|
15
|
+
try:
|
|
16
|
+
import mlx.core as mx # type: ignore
|
|
17
|
+
MLX_AVAILABLE = True
|
|
18
|
+
except ImportError:
|
|
19
|
+
MLX_AVAILABLE = False
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
from .timing import Timer, benchmark_function
|
|
23
|
+
from .debug import debug_print
|
|
24
|
+
from .environment import EnvironmentAnalyzer, EnvironmentProfile
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class BenchmarkType(Enum):
    """Category tag stored on every BenchmarkResult.

    Member order and string values are part of the public contract
    (``result.benchmark_type.value`` is printed in reports).
    """

    # Raw single-environment step throughput.
    ENVIRONMENT = "environment"
    # Assigned to timed callables and timed ``with`` blocks in this module.
    TRAINING = "training"
    # Not assigned by any benchmark in this module.
    POLICY = "policy"
    # Vectorized-environment throughput runs.
    VECTORIZATION = "vectorization"
    # Not assigned by any benchmark in this module.
    MEMORY = "memory"
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
@dataclass
class BenchmarkResult:
    """Results from a single benchmark run."""

    name: str
    # String forward reference: no runtime dependency on the enum's definition order.
    benchmark_type: "BenchmarkType"
    # Wall-clock seconds spent in the timed section.
    duration_seconds: float
    # Number of timed iterations (steps or calls).
    iterations: int
    metrics: Dict[str, float] = field(default_factory=dict)
    metadata: Dict[str, Any] = field(default_factory=dict)

    @property
    def avg_time_per_iteration(self) -> float:
        """Average seconds per iteration (the total duration when no iterations ran)."""
        return self.duration_seconds / max(1, self.iterations)

    @property
    def iterations_per_second(self) -> float:
        """
        Number of iterations per second.

        Divides by the actual measured duration. The old 1 ms clamp is gone
        because it under-reported throughput for sub-millisecond runs.

        Returns:
            ``iterations / duration_seconds`` when the duration is positive;
            otherwise ``inf`` if any iterations ran, else ``0.0``.
        """
        if self.duration_seconds > 0:
            return self.iterations / self.duration_seconds
        return float("inf") if self.iterations > 0 else 0.0
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
@dataclass
class ComparisonResult:
    """A baseline benchmark together with the results compared against it."""

    baseline: "BenchmarkResult"
    comparisons: List["BenchmarkResult"]

    def get_speedup(self, result: "BenchmarkResult") -> float:
        """
        Calculate speedup compared to baseline.

        Speedup is ``baseline time/iteration ÷ result time/iteration``, so a value
        above 1 means ``result`` is faster.

        Returns:
            1.0 when the baseline time per iteration is zero (no meaningful ratio);
            ``inf`` when only ``result`` took zero time per iteration; otherwise
            the ratio.
        """
        baseline_time = self.baseline.avg_time_per_iteration
        result_time = result.avg_time_per_iteration
        if baseline_time == 0:
            return 1.0
        if result_time == 0:
            # The denominator is the result's time; guard it to avoid ZeroDivisionError.
            return float("inf")
        return baseline_time / result_time

    def get_throughput_ratio(self, result: "BenchmarkResult") -> float:
        """
        Calculate throughput ratio compared to baseline.

        Returns:
            ``result`` iterations/sec divided by baseline iterations/sec, or 1.0
            when the baseline throughput is zero.
        """
        if self.baseline.iterations_per_second == 0:
            return 1.0
        return result.iterations_per_second / self.baseline.iterations_per_second
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
class PerformanceBenchmark:
    """
    Performance benchmarking harness for TextPolicy environments and callables.

    Features:
    - Single-environment step throughput (random actions)
    - Single vs vectorized environment comparison
    - Timing of arbitrary callables
    - Plain-text reporting of collected results

    Every successful ``benchmark_*`` call appends its result(s) to
    ``self.results``.
    """

    def __init__(self, warmup_iterations: int = 10, min_duration: float = 1.0):
        """
        Initialize performance benchmark.

        Args:
            warmup_iterations: Number of untimed warmup steps/calls run before timing.
            min_duration: Minimum benchmark duration in seconds.
                NOTE(review): stored but not consulted by any method of this class.
        """
        self.warmup_iterations = warmup_iterations
        self.min_duration = min_duration
        # NOTE(review): not used by the methods of this class.
        self.timer = Timer()
        self.results: List[BenchmarkResult] = []

    def benchmark_environment_speed(self, env_name: str, iterations: int = 1000, **env_kwargs) -> BenchmarkResult:
        """
        Benchmark environment step performance using randomly sampled actions.

        Args:
            env_name: Environment name passed to ``GymAdapter``.
            iterations: Number of timed steps; must be >= 1.
            **env_kwargs: Environment creation arguments.

        Returns:
            BenchmarkResult with environment performance data (also appended to
            ``self.results``).

        Raises:
            ValueError: If ``iterations`` is less than 1.
            Exception: Any error from creating or stepping the environment is
                logged via ``debug_print`` and re-raised.
        """
        if iterations < 1:
            # Previously surfaced as a bare ZeroDivisionError when computing per-step time.
            raise ValueError(f"iterations must be >= 1, got {iterations}")

        debug_print(f"Benchmarking environment speed: {env_name}", "benchmarking")

        try:
            from textpolicy.environment import GymAdapter

            env = GymAdapter(env_name, **env_kwargs)
            try:
                # Warmup (untimed)
                env.reset()
                for _ in range(self.warmup_iterations):
                    action = env.action_space.sample()
                    step_result = env.step(action)
                    if step_result["terminated"] or step_result["truncated"]:
                        env.reset()

                # Benchmark
                env.reset()
                episode_lengths = []
                current_episode_length = 0

                start_time = time.perf_counter()
                for _ in range(iterations):
                    action = env.action_space.sample()
                    step_result = env.step(action)
                    current_episode_length += 1

                    if step_result["terminated"] or step_result["truncated"]:
                        episode_lengths.append(current_episode_length)
                        current_episode_length = 0
                        env.reset()

                duration = time.perf_counter() - start_time
            finally:
                # Release the environment even if reset/step raises mid-benchmark.
                env.close()

            # When no episode finished, report the length of the partial episode.
            avg_episode_length = statistics.mean(episode_lengths) if episode_lengths else current_episode_length
            steps_per_second = iterations / duration

            result = BenchmarkResult(
                name=f"{env_name}_speed",
                benchmark_type=BenchmarkType.ENVIRONMENT,
                duration_seconds=duration,
                iterations=iterations,
                metrics={
                    "steps_per_second": steps_per_second,
                    "avg_step_time_ms": (duration / iterations) * 1000,
                    "avg_episode_length": avg_episode_length,
                    "episodes_completed": len(episode_lengths)
                },
                metadata={"env_name": env_name, "env_kwargs": env_kwargs}
            )

            self.results.append(result)
            return result

        except Exception as e:
            debug_print(f"Error benchmarking environment {env_name}: {e}", "benchmarking")
            raise

    def benchmark_vectorization_comparison(self, env_name: str, num_envs: int = 4,
                                           iterations: int = 1000, **env_kwargs) -> ComparisonResult:
        """
        Compare single vs vectorized environment performance.

        Args:
            env_name: Environment name
            num_envs: Number of vectorized environments
            iterations: Number of steps per environment
            **env_kwargs: Environment creation arguments

        Returns:
            ComparisonResult with the single-environment run as baseline. If the
            vectorized run cannot be performed (MLX unavailable, creation or
            stepping fails), ``comparisons`` is empty.

        Raises:
            Whatever ``benchmark_environment_speed`` raises for the baseline run.
        """
        debug_print(f"Benchmarking vectorization: {env_name} (1 vs {num_envs} envs)", "benchmarking")

        # Benchmark single environment
        single_result = self.benchmark_environment_speed(env_name, iterations, **env_kwargs)
        single_result.name = f"{env_name}_single"

        if not MLX_AVAILABLE:
            # Batched actions are built with mx.array. Bail out before creating an
            # environment that would otherwise never be closed.
            debug_print("Error benchmarking vectorized environment: MLX is not available", "benchmarking")
            return ComparisonResult(baseline=single_result, comparisons=[])

        # Benchmark vectorized environments
        try:
            from textpolicy.environment import make_vectorized_env

            vectorized_env = make_vectorized_env(env_name, num_envs=num_envs, env_kwargs=env_kwargs)
            try:
                # Warmup (untimed)
                vectorized_env.reset()
                for _ in range(self.warmup_iterations):
                    # Sample one action per sub-environment and batch them.
                    actions = mx.array([vectorized_env.action_space.sample() for _ in range(num_envs)])
                    step_results = vectorized_env.step(actions)
                    # Reset all sub-environments as soon as any of them is done.
                    if mx.any(step_results["terminated"] | step_results["truncated"]):
                        vectorized_env.reset()

                # Benchmark
                vectorized_env.reset()
                start_time = time.perf_counter()
                for _ in range(iterations):
                    actions = mx.array([vectorized_env.action_space.sample() for _ in range(num_envs)])
                    step_results = vectorized_env.step(actions)
                    if mx.any(step_results["terminated"] | step_results["truncated"]):
                        vectorized_env.reset()
                duration = time.perf_counter() - start_time
            finally:
                # Release the vectorized environment even if stepping fails.
                vectorized_env.close()

            # Total steps = iterations * num_envs
            total_steps = iterations * num_envs
            steps_per_second = total_steps / duration

            vectorized_result = BenchmarkResult(
                name=f"{env_name}_vectorized_{num_envs}",
                benchmark_type=BenchmarkType.VECTORIZATION,
                duration_seconds=duration,
                iterations=total_steps,
                metrics={
                    "steps_per_second": steps_per_second,
                    "avg_step_time_ms": (duration / total_steps) * 1000,
                    "num_envs": num_envs,
                    # 1.0 means perfect linear scaling over the single-env throughput.
                    "parallelization_efficiency": steps_per_second / (single_result.metrics["steps_per_second"] * num_envs)
                },
                metadata={"env_name": env_name, "num_envs": num_envs, "env_kwargs": env_kwargs}
            )

            self.results.append(vectorized_result)

            return ComparisonResult(baseline=single_result, comparisons=[vectorized_result])

        except Exception as e:
            debug_print(f"Error benchmarking vectorized environment: {e}", "benchmarking")
            # Return comparison with just single environment
            return ComparisonResult(baseline=single_result, comparisons=[])

    def benchmark_function_performance(self, name: str, func: Callable, *args,
                                       iterations: int = 100, **kwargs) -> BenchmarkResult:
        """
        Benchmark arbitrary function performance.

        Delegates timing to ``benchmark_function`` from the timing module and
        reads its ``total``, ``mean``, ``min`` and ``max`` entries. These are
        assumed to be seconds, given the ``* 1000`` ms conversion (TODO confirm).

        Args:
            name: Benchmark name
            func: Function to benchmark
            *args: Function positional arguments
            iterations: Number of iterations
            **kwargs: Function keyword arguments

        Returns:
            BenchmarkResult with function performance data (also appended to
            ``self.results``).
        """
        stats = benchmark_function(func, *args, iterations=iterations,
                                   warmup=self.warmup_iterations, **kwargs)

        result = BenchmarkResult(
            name=name,
            benchmark_type=BenchmarkType.TRAINING,
            duration_seconds=stats["total"],
            iterations=iterations,
            metrics={
                "avg_time_ms": stats["mean"] * 1000,
                "min_time_ms": stats["min"] * 1000,
                "max_time_ms": stats["max"] * 1000,
                "std_time_ms": 0.0  # Would need to calculate from raw data
            }
        )

        self.results.append(result)
        return result

    def compare_environments(self, env_names: List[str], iterations: int = 1000,
                             **shared_env_kwargs) -> List[BenchmarkResult]:
        """
        Compare performance across multiple environments.

        Environments that fail to benchmark are logged and skipped.

        Args:
            env_names: List of environment names to compare
            iterations: Number of steps per environment
            **shared_env_kwargs: Shared environment creation arguments

        Returns:
            List of BenchmarkResults sorted by steps per second, fastest first.
        """
        debug_print(f"Comparing {len(env_names)} environments", "benchmarking")

        results = []
        for env_name in env_names:
            try:
                result = self.benchmark_environment_speed(env_name, iterations, **shared_env_kwargs)
                results.append(result)
            except Exception as e:
                debug_print(f"Failed to benchmark {env_name}: {e}", "benchmarking")
                continue

        # Sort by steps per second (descending)
        results.sort(key=lambda r: r.metrics.get("steps_per_second", 0), reverse=True)
        return results

    def get_results(self, benchmark_type: Optional[BenchmarkType] = None) -> List[BenchmarkResult]:
        """
        Get benchmark results, optionally filtered by type.

        Args:
            benchmark_type: Optional filter by benchmark type

        Returns:
            New list of BenchmarkResults (the internal list is not exposed).
        """
        if benchmark_type is None:
            return self.results.copy()
        return [r for r in self.results if r.benchmark_type == benchmark_type]

    def clear_results(self):
        """Clear all benchmark results."""
        self.results.clear()

    def print_results(self, benchmark_type: Optional[BenchmarkType] = None, detailed: bool = True):
        """
        Print benchmark results in a readable format.

        Args:
            benchmark_type: Optional filter by benchmark type
            detailed: Whether to show detailed metrics (otherwise a compact table)
        """
        results = self.get_results(benchmark_type)

        if not results:
            print("No benchmark results available.")
            return

        print(f"\nBenchmark Results ({len(results)} tests)")
        print("=" * 70)

        if detailed:
            for result in results:
                print(f"\n{result.name}")
                print(f" Type: {result.benchmark_type.value}")
                print(f" Duration: {result.duration_seconds:.3f}s")
                print(f" Iterations: {result.iterations}")
                print(f" Avg time/iteration: {result.avg_time_per_iteration*1000:.3f}ms")
                print(f" Iterations/second: {result.iterations_per_second:.0f}")

                if result.metrics:
                    print(" Metrics:")
                    for key, value in result.metrics.items():
                        if isinstance(value, float):
                            print(f" {key}: {value:.3f}")
                        else:
                            print(f" {key}: {value}")
        else:
            # Compact table format
            print(f"{'Name':<25} {'Type':<12} {'Time/iter (ms)':<15} {'Iter/sec':<12}")
            print("-" * 70)
            for result in results:
                print(f"{result.name:<25} {result.benchmark_type.value:<12} "
                      f"{result.avg_time_per_iteration*1000:<15.3f} {result.iterations_per_second:<12.0f}")

    def print_comparison(self, comparison: ComparisonResult):
        """
        Print detailed comparison results.

        Args:
            comparison: ComparisonResult to display
        """
        print("\nPerformance Comparison")
        print("=" * 50)

        baseline = comparison.baseline
        print(f"Baseline: {baseline.name}")
        print(f" {baseline.iterations_per_second:.0f} iterations/sec")
        print(f" {baseline.avg_time_per_iteration*1000:.3f} ms/iteration")

        print("\nComparisons:")
        for result in comparison.comparisons:
            speedup = comparison.get_speedup(result)
            throughput_ratio = comparison.get_throughput_ratio(result)

            print(f" {result.name}:")
            print(f" {result.iterations_per_second:.0f} iterations/sec ({throughput_ratio:.2f}x throughput)")
            print(f" {result.avg_time_per_iteration*1000:.3f} ms/iteration ({speedup:.2f}x speedup)")

            # Within ±10% of the baseline is reported as "similar".
            if speedup > 1.1:
                print(f" {speedup:.1f}x faster than baseline")
            elif speedup < 0.9:
                print(f" {1/speedup:.1f}x slower than baseline")
            else:
                print(" Similar performance to baseline")
|
|
411
|
+
|
|
412
|
+
|
|
413
|
+
def quick_environment_benchmark(env_name: str, iterations: int = 1000, **env_kwargs) -> EnvironmentProfile:
    """
    Profile a single environment with ``EnvironmentAnalyzer``.

    Args:
        env_name: Environment name forwarded to the analyzer.
        iterations: Used as the analyzer's ``test_steps``.
        **env_kwargs: Environment creation arguments forwarded to
            ``analyze_environment``.

    Returns:
        The EnvironmentProfile produced by ``EnvironmentAnalyzer.analyze_environment``.
    """
    return EnvironmentAnalyzer(test_steps=iterations).analyze_environment(env_name, **env_kwargs)
|
|
428
|
+
|
|
429
|
+
|
|
430
|
+
def quick_vectorization_test(env_name: str, num_envs_list: Optional[List[int]] = None,
                             iterations: int = 500, **env_kwargs) -> Dict[int, BenchmarkResult]:
    """
    Quick test of vectorization performance across different environment counts.

    For ``num_envs == 1`` a plain single-environment benchmark is run. For larger
    counts, the per-environment step budget is ``iterations // num_envs``, floored
    at 1, so the total number of steps stays roughly comparable.

    Args:
        env_name: Environment name
        num_envs_list: Environment counts to test (defaults to ``[1, 2, 4, 8]``)
        iterations: Approximate total number of steps per test
        **env_kwargs: Environment creation arguments

    Returns:
        Dictionary mapping num_envs to BenchmarkResult. Counts whose vectorized
        run failed are omitted.
    """
    # Avoid a shared mutable default argument.
    if num_envs_list is None:
        num_envs_list = [1, 2, 4, 8]

    benchmark = PerformanceBenchmark()
    results: Dict[int, BenchmarkResult] = {}

    for num_envs in num_envs_list:
        if num_envs == 1:
            result = benchmark.benchmark_environment_speed(env_name, iterations, **env_kwargs)
        else:
            # Never request zero steps per environment (e.g. iterations < num_envs);
            # that made the baseline run fail with a division by zero.
            per_env_iterations = max(1, iterations // num_envs)
            comparison = benchmark.benchmark_vectorization_comparison(
                env_name, num_envs=num_envs, iterations=per_env_iterations, **env_kwargs
            )
            result = comparison.comparisons[0] if comparison.comparisons else None

        if result is not None:
            results[num_envs] = result

    return results
|
|
460
|
+
|
|
461
|
+
|
|
462
|
+
@contextmanager
def benchmark_context(name: str, benchmark: Optional[PerformanceBenchmark] = None):
    """
    Time the body of a ``with`` statement and record it as a one-iteration result.

    Args:
        name: Name stored on the recorded BenchmarkResult.
        benchmark: PerformanceBenchmark that receives the result. A fresh
            instance is created when none (or a falsy value) is given.

    Yields:
        The PerformanceBenchmark that the result is appended to.

    The result (type ``TRAINING``, ``iterations=1``) is appended on exit, even
    when the body raises.

    Example:
        with benchmark_context("training_step") as bench:
            loss = trainer.update(batch)
    """
    target = benchmark if benchmark else PerformanceBenchmark()
    started = time.perf_counter()

    try:
        yield target
    finally:
        elapsed = time.perf_counter() - started
        target.results.append(
            BenchmarkResult(
                name=name,
                benchmark_type=BenchmarkType.TRAINING,
                duration_seconds=elapsed,
                iterations=1,
                metrics={"duration_ms": elapsed * 1000},
            )
        )
|
textpolicy/utils/data.py
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
1
|
+
# textpolicy/utils/data.py
|
|
2
|
+
"""
|
|
3
|
+
Data processing and conversion utilities.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from typing import Any, Dict
|
|
7
|
+
import mlx.core as mx # type: ignore
|
|
8
|
+
import numpy as np
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def to_mlx(data: Any) -> mx.array:
    """
    Coerce ``data`` into an MLX array.

    MLX arrays are returned unchanged. NumPy arrays that are not C-contiguous
    are first copied into contiguous memory, which the original author notes
    fixes an "Invalid type ndarray" error. Everything else (lists, tuples,
    Python scalars, arbitrary objects) is handed straight to ``mx.array``.

    Args:
        data: Input data (numpy array, list, scalar, etc.)

    Returns:
        MLX array
    """
    if isinstance(data, mx.array):
        return data

    needs_contiguous_copy = isinstance(data, np.ndarray) and not data.flags.c_contiguous
    source = np.ascontiguousarray(data) if needs_contiguous_copy else data
    return mx.array(source)
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def to_numpy(data: mx.array) -> np.ndarray:
    """
    Convert MLX array to numpy array.

    NumPy has no bfloat16 dtype, so per the MLX documentation a bfloat16 array
    must be cast before conversion. Such arrays are upcast to float32 here;
    all other dtypes are converted unchanged.

    Args:
        data: MLX array (any other input ``np.array`` accepts is passed through).

    Returns:
        Numpy array
    """
    if isinstance(data, mx.array) and data.dtype == mx.bfloat16:
        data = data.astype(mx.float32)
    return np.array(data)
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def batch_to_mlx(batch: Dict[str, Any]) -> Dict[str, mx.array]:
    """
    Apply ``to_mlx`` to every value of a batch dictionary.

    Args:
        batch: Mapping of field name to array-like data.

    Returns:
        New dictionary with the same keys, in the same order, and MLX-array values.
    """
    converted = {}
    for field_name, raw_value in batch.items():
        converted[field_name] = to_mlx(raw_value)
    return converted
|