textpolicy-0.0.1-py3-none-any.whl → textpolicy-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68)
  1. textpolicy/__init__.py +52 -0
  2. textpolicy/__main__.py +8 -0
  3. textpolicy/algorithms/__init__.py +54 -0
  4. textpolicy/algorithms/grpo.py +642 -0
  5. textpolicy/algorithms/gspo.py +582 -0
  6. textpolicy/buffer/__init__.py +23 -0
  7. textpolicy/buffer/buffer.py +244 -0
  8. textpolicy/buffer/episode.py +383 -0
  9. textpolicy/buffer/sampling.py +438 -0
  10. textpolicy/buffer/storage.py +255 -0
  11. textpolicy/cli.py +67 -0
  12. textpolicy/environment/__init__.py +79 -0
  13. textpolicy/environment/base.py +110 -0
  14. textpolicy/environment/environment.py +46 -0
  15. textpolicy/environment/factory.py +103 -0
  16. textpolicy/environment/gym.py +106 -0
  17. textpolicy/environment/task_suites.py +51 -0
  18. textpolicy/environment/text_generation.py +789 -0
  19. textpolicy/environment/vectorized.py +253 -0
  20. textpolicy/generation/__init__.py +62 -0
  21. textpolicy/generation/lora.py +411 -0
  22. textpolicy/generation/mlx_generation.py +557 -0
  23. textpolicy/generation/reload.py +253 -0
  24. textpolicy/rewards/__init__.py +137 -0
  25. textpolicy/rewards/adapters.py +387 -0
  26. textpolicy/rewards/basic.py +214 -0
  27. textpolicy/rewards/integrated_system.py +338 -0
  28. textpolicy/rewards/mlx_batch_processor.py +447 -0
  29. textpolicy/rewards/registry.py +293 -0
  30. textpolicy/rewards/rollout_rewards.py +410 -0
  31. textpolicy/rewards/verifiers.py +369 -0
  32. textpolicy/rollout/__init__.py +44 -0
  33. textpolicy/rollout/aggregator.py +145 -0
  34. textpolicy/rollout/base.py +108 -0
  35. textpolicy/rollout/rollout.py +142 -0
  36. textpolicy/rollout/runner.py +280 -0
  37. textpolicy/rollout/strategy.py +208 -0
  38. textpolicy/rollout/worker.py +194 -0
  39. textpolicy/training/__init__.py +14 -0
  40. textpolicy/training/metrics.py +242 -0
  41. textpolicy/training/rollout_manager.py +78 -0
  42. textpolicy/training/trainer.py +684 -0
  43. textpolicy/utils/__init__.py +40 -0
  44. textpolicy/utils/benchmarking.py +489 -0
  45. textpolicy/utils/data.py +60 -0
  46. textpolicy/utils/debug.py +170 -0
  47. textpolicy/utils/environment.py +349 -0
  48. textpolicy/utils/logging/__init__.py +22 -0
  49. textpolicy/utils/logging/base.py +48 -0
  50. textpolicy/utils/logging/console.py +61 -0
  51. textpolicy/utils/logging/factory.py +133 -0
  52. textpolicy/utils/logging/multi.py +83 -0
  53. textpolicy/utils/logging/tensorboard.py +65 -0
  54. textpolicy/utils/logging/wandb.py +72 -0
  55. textpolicy/utils/memory.py +118 -0
  56. textpolicy/utils/performance.py +464 -0
  57. textpolicy/utils/timing.py +171 -0
  58. textpolicy/validate.py +101 -0
  59. textpolicy/validation/__init__.py +13 -0
  60. textpolicy/validation/logprob_validation.py +315 -0
  61. textpolicy-0.1.0.dist-info/METADATA +99 -0
  62. textpolicy-0.1.0.dist-info/RECORD +66 -0
  63. textpolicy-0.1.0.dist-info/entry_points.txt +2 -0
  64. textpolicy-0.0.1.dist-info/METADATA +0 -10
  65. textpolicy-0.0.1.dist-info/RECORD +0 -6
  66. {textpolicy-0.0.1.dist-info → textpolicy-0.1.0.dist-info}/WHEEL +0 -0
  67. {textpolicy-0.0.1.dist-info → textpolicy-0.1.0.dist-info}/licenses/LICENSE +0 -0
  68. {textpolicy-0.0.1.dist-info → textpolicy-0.1.0.dist-info}/top_level.txt +0 -0
textpolicy/utils/__init__.py
@@ -0,0 +1,40 @@
+ # textpolicy/utils/__init__.py
+ """
+ General utilities for TextPolicy.
+
+ Organized by functionality:
+ - logging: Multiple backend logging support
+ - timing: Performance measurement and benchmarking
+ - memory: Memory monitoring and cleanup
+ - data: Data conversion and preprocessing
+ - math: Mathematical utilities and statistics
+ """
+
+ # Import logging utilities
+ from .logging import (
+     Logger, WandbLogger, TensorboardLogger, ConsoleLogger, MultiLogger,
+     create_logger, create_multi_logger, create_auto_logger
+ )
+
+ # Import other utilities
+ from .timing import Timer, global_timer, time_it, benchmark_function
+ from .memory import get_memory_stats, clear_memory, MemoryMonitor
+ from .data import to_mlx, to_numpy, batch_to_mlx
+
+ # Backwards compatibility for existing imports
+ from .logging import *
+
+ __all__ = [
+     # Logging utilities
+     'Logger', 'WandbLogger', 'TensorboardLogger', 'ConsoleLogger', 'MultiLogger',
+     'create_logger', 'create_multi_logger', 'create_auto_logger',
+
+     # Timing utilities
+     'Timer', 'global_timer', 'time_it', 'benchmark_function',
+
+     # Memory utilities
+     'get_memory_stats', 'clear_memory', 'MemoryMonitor',
+
+     # Data utilities
+     'to_mlx', 'to_numpy', 'batch_to_mlx',
+ ]
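
For orientation, here is a minimal usage sketch of the utils facade introduced above. It is not part of the diff; it assumes textpolicy 0.1.0 is installed and that the re-exports resolve exactly as listed in the module's __all__.

    # Sketch only (not shipped in the wheel): exercising the re-exported names.
    from textpolicy.utils import Timer, to_mlx, batch_to_mlx

    timer = Timer()  # timing helper re-exported from textpolicy.utils.timing
    batch = batch_to_mlx({"obs": [[1.0, 2.0], [3.0, 4.0]], "reward": [0.5, 1.0]})
    print(type(batch["obs"]))  # an mlx.core array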
textpolicy/utils/benchmarking.py
@@ -0,0 +1,489 @@
+ """
+ Performance benchmarking utilities for MLX-RL.
+
+ This module provides tools for comprehensive performance analysis, comparing
+ different configurations, and generating detailed performance reports.
+ """
+
+ import time
+ import statistics
+ from typing import Dict, List, Optional, Any, Callable
+ from dataclasses import dataclass, field
+ from contextlib import contextmanager
+ from enum import Enum
+
+ try:
+     import mlx.core as mx  # type: ignore
+     MLX_AVAILABLE = True
+ except ImportError:
+     MLX_AVAILABLE = False
+
+
+ from .timing import Timer, benchmark_function
+ from .debug import debug_print
+ from .environment import EnvironmentAnalyzer, EnvironmentProfile
+
+
+ class BenchmarkType(Enum):
+     """Types of benchmarks that can be performed."""
+     ENVIRONMENT = "environment"
+     TRAINING = "training"
+     POLICY = "policy"
+     VECTORIZATION = "vectorization"
+     MEMORY = "memory"
+
+
+ @dataclass
+ class BenchmarkResult:
+     """Results from a single benchmark run."""
+     name: str
+     benchmark_type: BenchmarkType
+     duration_seconds: float
+     iterations: int
+     metrics: Dict[str, float] = field(default_factory=dict)
+     metadata: Dict[str, Any] = field(default_factory=dict)
+
+     @property
+     def avg_time_per_iteration(self) -> float:
+         """Average time per iteration in seconds."""
+         return self.duration_seconds / max(1, self.iterations)
+
+     @property
+     def iterations_per_second(self) -> float:
+         """Number of iterations per second."""
+         return self.iterations / max(0.001, self.duration_seconds)
+
+
+ @dataclass
+ class ComparisonResult:
+     """Results from comparing multiple benchmarks."""
+     baseline: BenchmarkResult
+     comparisons: List[BenchmarkResult]
+
+     def get_speedup(self, result: BenchmarkResult) -> float:
+         """Calculate speedup compared to baseline."""
+         if self.baseline.avg_time_per_iteration == 0:
+             return 1.0
+         return self.baseline.avg_time_per_iteration / result.avg_time_per_iteration
+
+     def get_throughput_ratio(self, result: BenchmarkResult) -> float:
+         """Calculate throughput ratio compared to baseline."""
+         if self.baseline.iterations_per_second == 0:
+             return 1.0
+         return result.iterations_per_second / self.baseline.iterations_per_second
+
+
+ class PerformanceBenchmark:
+     """
+     Comprehensive performance benchmarking system for MLX-RL.
+
+     Features:
+     - Environment performance comparison
+     - Vectorization vs single environment benchmarks
+     - Training pipeline performance analysis
+     - Statistical significance testing
+     - Detailed reporting and visualization
+     """
+
+     def __init__(self, warmup_iterations: int = 10, min_duration: float = 1.0):
+         """
+         Initialize performance benchmark.
+
+         Args:
+             warmup_iterations: Number of warmup iterations before timing
+             min_duration: Minimum benchmark duration in seconds
+         """
+         self.warmup_iterations = warmup_iterations
+         self.min_duration = min_duration
+         self.timer = Timer()
+         self.results: List[BenchmarkResult] = []
+
+     def benchmark_environment_speed(self, env_name: str, iterations: int = 1000, **env_kwargs) -> BenchmarkResult:
+         """
+         Benchmark environment step performance.
+
+         Args:
+             env_name: Environment name
+             iterations: Number of steps to benchmark
+             **env_kwargs: Environment creation arguments
+
+         Returns:
+             BenchmarkResult with environment performance data
+         """
+         debug_print(f"Benchmarking environment speed: {env_name}", "benchmarking")
+
+         try:
+             from textpolicy.environment import GymAdapter
+
+             env = GymAdapter(env_name, **env_kwargs)
+
+             # Warmup
+             obs, info = env.reset()
+             for _ in range(self.warmup_iterations):
+                 action = env.action_space.sample()
+                 step_result = env.step(action)
+                 if step_result["terminated"] or step_result["truncated"]:
+                     obs, info = env.reset()
+
+             # Benchmark
+             obs, info = env.reset()
+             episode_lengths = []
+             current_episode_length = 0
+
+             start_time = time.perf_counter()
+             for i in range(iterations):
+                 action = env.action_space.sample()
+                 step_result = env.step(action)
+                 current_episode_length += 1
+
+                 if step_result["terminated"] or step_result["truncated"]:
+                     episode_lengths.append(current_episode_length)
+                     current_episode_length = 0
+                     obs, info = env.reset()
+
+             end_time = time.perf_counter()
+             duration = end_time - start_time
+
+             env.close()
+
+             # Calculate metrics
+             avg_episode_length = statistics.mean(episode_lengths) if episode_lengths else current_episode_length
+             steps_per_second = iterations / duration
+
+             result = BenchmarkResult(
+                 name=f"{env_name}_speed",
+                 benchmark_type=BenchmarkType.ENVIRONMENT,
+                 duration_seconds=duration,
+                 iterations=iterations,
+                 metrics={
+                     "steps_per_second": steps_per_second,
+                     "avg_step_time_ms": (duration / iterations) * 1000,
+                     "avg_episode_length": avg_episode_length,
+                     "episodes_completed": len(episode_lengths)
+                 },
+                 metadata={"env_name": env_name, "env_kwargs": env_kwargs}
+             )
+
+             self.results.append(result)
+             return result
+
+         except Exception as e:
+             debug_print(f"Error benchmarking environment {env_name}: {e}", "benchmarking")
+             raise
+
+     def benchmark_vectorization_comparison(self, env_name: str, num_envs: int = 4,
+                                            iterations: int = 1000, **env_kwargs) -> ComparisonResult:
+         """
+         Compare single vs vectorized environment performance.
+
+         Args:
+             env_name: Environment name
+             num_envs: Number of vectorized environments
+             iterations: Number of steps per environment
+             **env_kwargs: Environment creation arguments
+
+         Returns:
+             ComparisonResult with single vs vectorized performance
+         """
+         debug_print(f"Benchmarking vectorization: {env_name} (1 vs {num_envs} envs)", "benchmarking")
+
+         # Benchmark single environment
+         single_result = self.benchmark_environment_speed(env_name, iterations, **env_kwargs)
+         single_result.name = f"{env_name}_single"
+
+         # Benchmark vectorized environments
+         try:
+             from textpolicy.environment import make_vectorized_env
+
+             vectorized_env = make_vectorized_env(env_name, num_envs=num_envs, env_kwargs=env_kwargs)
+
+             # Warmup
+             obs, infos = vectorized_env.reset()
+             for _ in range(self.warmup_iterations):
+                 # Sample actions for each environment and create batch
+                 action_samples = [vectorized_env.action_space.sample() for _ in range(num_envs)]
+                 actions = mx.array(action_samples)
+                 step_results = vectorized_env.step(actions)
+                 # Check if any environments are done
+                 terminated = step_results["terminated"]
+                 truncated = step_results["truncated"]
+                 if mx.any(terminated | truncated):
+                     obs, infos = vectorized_env.reset()
+
+             # Benchmark
+             obs, infos = vectorized_env.reset()
+             start_time = time.perf_counter()
+
+             for i in range(iterations):
+                 # Sample actions for each environment and create batch
+                 action_samples = [vectorized_env.action_space.sample() for _ in range(num_envs)]
+                 actions = mx.array(action_samples)
+                 step_results = vectorized_env.step(actions)
+                 # Check if any environments are done
+                 terminated = step_results["terminated"]
+                 truncated = step_results["truncated"]
+                 if mx.any(terminated | truncated):
+                     obs, infos = vectorized_env.reset()
+
+             end_time = time.perf_counter()
+             duration = end_time - start_time
+
+             vectorized_env.close()
+
+             # Total steps = iterations * num_envs
+             total_steps = iterations * num_envs
+             steps_per_second = total_steps / duration
+
+             vectorized_result = BenchmarkResult(
+                 name=f"{env_name}_vectorized_{num_envs}",
+                 benchmark_type=BenchmarkType.VECTORIZATION,
+                 duration_seconds=duration,
+                 iterations=total_steps,
+                 metrics={
+                     "steps_per_second": steps_per_second,
+                     "avg_step_time_ms": (duration / total_steps) * 1000,
+                     "num_envs": num_envs,
+                     "parallelization_efficiency": steps_per_second / (single_result.metrics["steps_per_second"] * num_envs)
+                 },
+                 metadata={"env_name": env_name, "num_envs": num_envs, "env_kwargs": env_kwargs}
+             )
+
+             self.results.append(vectorized_result)
+
+             return ComparisonResult(baseline=single_result, comparisons=[vectorized_result])
+
+         except Exception as e:
+             debug_print(f"Error benchmarking vectorized environment: {e}", "benchmarking")
+             # Return comparison with just single environment
+             return ComparisonResult(baseline=single_result, comparisons=[])
+
+     def benchmark_function_performance(self, name: str, func: Callable, *args,
+                                        iterations: int = 100, **kwargs) -> BenchmarkResult:
+         """
+         Benchmark arbitrary function performance.
+
+         Args:
+             name: Benchmark name
+             func: Function to benchmark
+             *args: Function positional arguments
+             iterations: Number of iterations
+             **kwargs: Function keyword arguments
+
+         Returns:
+             BenchmarkResult with function performance data
+         """
+         stats = benchmark_function(func, *args, iterations=iterations,
+                                    warmup=self.warmup_iterations, **kwargs)
+
+         result = BenchmarkResult(
+             name=name,
+             benchmark_type=BenchmarkType.TRAINING,
+             duration_seconds=stats["total"],
+             iterations=iterations,
+             metrics={
+                 "avg_time_ms": stats["mean"] * 1000,
+                 "min_time_ms": stats["min"] * 1000,
+                 "max_time_ms": stats["max"] * 1000,
+                 "std_time_ms": 0.0  # Would need to calculate from raw data
+             }
+         )
+
+         self.results.append(result)
+         return result
+
+     def compare_environments(self, env_names: List[str], iterations: int = 1000,
+                              **shared_env_kwargs) -> List[BenchmarkResult]:
+         """
+         Compare performance across multiple environments.
+
+         Args:
+             env_names: List of environment names to compare
+             iterations: Number of steps per environment
+             **shared_env_kwargs: Shared environment creation arguments
+
+         Returns:
+             List of BenchmarkResults sorted by performance
+         """
+         debug_print(f"Comparing {len(env_names)} environments", "benchmarking")
+
+         results = []
+         for env_name in env_names:
+             try:
+                 result = self.benchmark_environment_speed(env_name, iterations, **shared_env_kwargs)
+                 results.append(result)
+             except Exception as e:
+                 debug_print(f"Failed to benchmark {env_name}: {e}", "benchmarking")
+                 continue
+
+         # Sort by steps per second (descending)
+         results.sort(key=lambda r: r.metrics.get("steps_per_second", 0), reverse=True)
+         return results
+
+     def get_results(self, benchmark_type: Optional[BenchmarkType] = None) -> List[BenchmarkResult]:
+         """
+         Get benchmark results, optionally filtered by type.
+
+         Args:
+             benchmark_type: Optional filter by benchmark type
+
+         Returns:
+             List of BenchmarkResults
+         """
+         if benchmark_type is None:
+             return self.results.copy()
+         return [r for r in self.results if r.benchmark_type == benchmark_type]
+
+     def clear_results(self):
+         """Clear all benchmark results."""
+         self.results.clear()
+
+     def print_results(self, benchmark_type: Optional[BenchmarkType] = None, detailed: bool = True):
+         """
+         Print benchmark results in a readable format.
+
+         Args:
+             benchmark_type: Optional filter by benchmark type
+             detailed: Whether to show detailed metrics
+         """
+         results = self.get_results(benchmark_type)
+
+         if not results:
+             print("No benchmark results available.")
+             return
+
+         print(f"\nBenchmark Results ({len(results)} tests)")
+         print("=" * 70)
+
+         if detailed:
+             for result in results:
+                 print(f"\n{result.name}")
+                 print(f"  Type: {result.benchmark_type.value}")
+                 print(f"  Duration: {result.duration_seconds:.3f}s")
+                 print(f"  Iterations: {result.iterations}")
+                 print(f"  Avg time/iteration: {result.avg_time_per_iteration*1000:.3f}ms")
+                 print(f"  Iterations/second: {result.iterations_per_second:.0f}")
+
+                 if result.metrics:
+                     print("  Metrics:")
+                     for key, value in result.metrics.items():
+                         if isinstance(value, float):
+                             print(f"    {key}: {value:.3f}")
+                         else:
+                             print(f"    {key}: {value}")
+         else:
+             # Compact table format
+             print(f"{'Name':<25} {'Type':<12} {'Time/iter (ms)':<15} {'Iter/sec':<12}")
+             print("-" * 70)
+             for result in results:
+                 print(f"{result.name:<25} {result.benchmark_type.value:<12} "
+                       f"{result.avg_time_per_iteration*1000:<15.3f} {result.iterations_per_second:<12.0f}")
+
+     def print_comparison(self, comparison: ComparisonResult):
+         """
+         Print detailed comparison results.
+
+         Args:
+             comparison: ComparisonResult to display
+         """
+         print("\nPerformance Comparison")
+         print("=" * 50)
+
+         baseline = comparison.baseline
+         print(f"Baseline: {baseline.name}")
+         print(f"  {baseline.iterations_per_second:.0f} iterations/sec")
+         print(f"  {baseline.avg_time_per_iteration*1000:.3f} ms/iteration")
+
+         print("\nComparisons:")
+         for result in comparison.comparisons:
+             speedup = comparison.get_speedup(result)
+             throughput_ratio = comparison.get_throughput_ratio(result)
+
+             print(f"  {result.name}:")
+             print(f"    {result.iterations_per_second:.0f} iterations/sec ({throughput_ratio:.2f}x throughput)")
+             print(f"    {result.avg_time_per_iteration*1000:.3f} ms/iteration ({speedup:.2f}x speedup)")
+
+             if speedup > 1.1:
+                 print(f"    {speedup:.1f}x faster than baseline")
+             elif speedup < 0.9:
+                 print(f"    {1/speedup:.1f}x slower than baseline")
+             else:
+                 print("    Similar performance to baseline")
+
+
+ def quick_environment_benchmark(env_name: str, iterations: int = 1000, **env_kwargs) -> EnvironmentProfile:
+     """
+     Quick benchmark and analysis of a single environment.
+
+     Args:
+         env_name: Environment name
+         iterations: Number of steps to benchmark
+         **env_kwargs: Environment creation arguments
+
+     Returns:
+         EnvironmentProfile with analysis and recommendations
+     """
+     # Use EnvironmentAnalyzer for detailed analysis
+     analyzer = EnvironmentAnalyzer(test_steps=iterations)
+     return analyzer.analyze_environment(env_name, **env_kwargs)
+
+
+ def quick_vectorization_test(env_name: str, num_envs_list: List[int] = [1, 2, 4, 8],
+                              iterations: int = 500, **env_kwargs) -> Dict[int, BenchmarkResult]:
+     """
+     Quick test of vectorization performance across different environment counts.
+
+     Args:
+         env_name: Environment name
+         num_envs_list: List of environment counts to test
+         iterations: Number of steps per test
+         **env_kwargs: Environment creation arguments
+
+     Returns:
+         Dictionary mapping num_envs to BenchmarkResult
+     """
+     benchmark = PerformanceBenchmark()
+     results = {}
+
+     for num_envs in num_envs_list:
+         if num_envs == 1:
+             result = benchmark.benchmark_environment_speed(env_name, iterations, **env_kwargs)
+         else:
+             comparison = benchmark.benchmark_vectorization_comparison(
+                 env_name, num_envs=num_envs, iterations=iterations//num_envs, **env_kwargs
+             )
+             result = comparison.comparisons[0] if comparison.comparisons else None
+
+         if result:
+             results[num_envs] = result
+
+     return results
+
+
+ @contextmanager
+ def benchmark_context(name: str, benchmark: Optional[PerformanceBenchmark] = None):
+     """
+     Context manager for benchmarking code blocks.
+
+     Args:
+         name: Benchmark name
+         benchmark: PerformanceBenchmark instance (creates new if None)
+
+     Example:
+         with benchmark_context("training_step") as bench:
+             loss = trainer.update(batch)
+     """
+     bench = benchmark or PerformanceBenchmark()
+     start_time = time.perf_counter()
+
+     try:
+         yield bench
+     finally:
+         duration = time.perf_counter() - start_time
+         result = BenchmarkResult(
+             name=name,
+             benchmark_type=BenchmarkType.TRAINING,
+             duration_seconds=duration,
+             iterations=1,
+             metrics={"duration_ms": duration * 1000}
+         )
+         bench.results.append(result)
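
The benchmarking module above can be driven from a few lines. The sketch below is illustrative only and not part of the diff: "CartPole-v1" is a placeholder environment id, and it assumes a Gym-compatible backend plus the textpolicy.environment adapters referenced by the module are importable.

    # Sketch only (not shipped in the wheel).
    from textpolicy.utils.benchmarking import PerformanceBenchmark, benchmark_context

    bench = PerformanceBenchmark(warmup_iterations=5)
    single = bench.benchmark_environment_speed("CartPole-v1", iterations=200)
    comparison = bench.benchmark_vectorization_comparison("CartPole-v1", num_envs=4, iterations=200)
    bench.print_comparison(comparison)

    # Time an arbitrary code block and record it alongside the other results.
    with benchmark_context("rollout_collection", benchmark=bench):
        pass  # e.g. collect a rollout here

    bench.print_results(detailed=False)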
textpolicy/utils/data.py
@@ -0,0 +1,60 @@
+ # textpolicy/utils/data.py
+ """
+ Data processing and conversion utilities.
+ """
+
+ from typing import Any, Dict
+ import mlx.core as mx  # type: ignore
+ import numpy as np
+
+
+ def to_mlx(data: Any) -> mx.array:
+     """
+     Convert various data types to MLX arrays.
+
+     Args:
+         data: Input data (numpy array, list, scalar, etc.)
+
+     Returns:
+         MLX array
+     """
+     if isinstance(data, mx.array):
+         return data
+     elif isinstance(data, np.ndarray):
+         # Ensure contiguous array for MLX compatibility (fixes "Invalid type ndarray" error)
+         if not data.flags.c_contiguous:
+             data = np.ascontiguousarray(data)
+         return mx.array(data)
+     elif isinstance(data, (list, tuple)):
+         return mx.array(data)
+     elif isinstance(data, (int, float)):
+         return mx.array(data)
+     else:
+         # Try direct conversion
+         return mx.array(data)
+
+
+ def to_numpy(data: mx.array) -> np.ndarray:
+     """
+     Convert MLX array to numpy array.
+
+     Args:
+         data: MLX array
+
+     Returns:
+         Numpy array
+     """
+     return np.array(data)
+
+
+ def batch_to_mlx(batch: Dict[str, Any]) -> Dict[str, mx.array]:
+     """
+     Convert batch dictionary to MLX arrays.
+
+     Args:
+         batch: Dictionary with various data types
+
+     Returns:
+         Dictionary with MLX arrays
+     """
+     return {key: to_mlx(value) for key, value in batch.items()}
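
A short sketch of how these converters are meant to be used; it is not part of the diff and assumes numpy and mlx are installed (data.py imports both unconditionally).

    # Sketch only (not shipped in the wheel): round-tripping data through MLX.
    import numpy as np
    from textpolicy.utils.data import to_mlx, to_numpy, batch_to_mlx

    view = np.arange(12, dtype=np.float32).reshape(3, 4)[:, ::2]   # non-contiguous view
    arr = to_mlx(view)                   # copied to a contiguous buffer before mx.array()
    batch = batch_to_mlx({"obs": view, "reward": [1.0, 0.5, 0.0]})
    rewards = to_numpy(batch["reward"])  # back to numpy for logging or plotting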