textpolicy 0.0.1__py3-none-any.whl → 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68)
  1. textpolicy/__init__.py +53 -0
  2. textpolicy/__main__.py +8 -0
  3. textpolicy/algorithms/__init__.py +54 -0
  4. textpolicy/algorithms/grpo.py +642 -0
  5. textpolicy/algorithms/gspo.py +582 -0
  6. textpolicy/buffer/__init__.py +23 -0
  7. textpolicy/buffer/buffer.py +244 -0
  8. textpolicy/buffer/episode.py +383 -0
  9. textpolicy/buffer/sampling.py +438 -0
  10. textpolicy/buffer/storage.py +255 -0
  11. textpolicy/cli.py +67 -0
  12. textpolicy/environment/__init__.py +79 -0
  13. textpolicy/environment/base.py +110 -0
  14. textpolicy/environment/environment.py +46 -0
  15. textpolicy/environment/factory.py +103 -0
  16. textpolicy/environment/gym.py +106 -0
  17. textpolicy/environment/task_suites.py +51 -0
  18. textpolicy/environment/text_generation.py +797 -0
  19. textpolicy/environment/vectorized.py +253 -0
  20. textpolicy/generation/__init__.py +62 -0
  21. textpolicy/generation/lora.py +411 -0
  22. textpolicy/generation/mlx_generation.py +557 -0
  23. textpolicy/generation/reload.py +253 -0
  24. textpolicy/rewards/__init__.py +137 -0
  25. textpolicy/rewards/adapters.py +387 -0
  26. textpolicy/rewards/basic.py +214 -0
  27. textpolicy/rewards/integrated_system.py +338 -0
  28. textpolicy/rewards/mlx_batch_processor.py +447 -0
  29. textpolicy/rewards/registry.py +293 -0
  30. textpolicy/rewards/rollout_rewards.py +410 -0
  31. textpolicy/rewards/verifiers.py +369 -0
  32. textpolicy/rollout/__init__.py +44 -0
  33. textpolicy/rollout/aggregator.py +145 -0
  34. textpolicy/rollout/base.py +108 -0
  35. textpolicy/rollout/rollout.py +142 -0
  36. textpolicy/rollout/runner.py +280 -0
  37. textpolicy/rollout/strategy.py +208 -0
  38. textpolicy/rollout/worker.py +194 -0
  39. textpolicy/training/__init__.py +14 -0
  40. textpolicy/training/metrics.py +242 -0
  41. textpolicy/training/rollout_manager.py +78 -0
  42. textpolicy/training/trainer.py +684 -0
  43. textpolicy/utils/__init__.py +40 -0
  44. textpolicy/utils/benchmarking.py +489 -0
  45. textpolicy/utils/data.py +60 -0
  46. textpolicy/utils/debug.py +170 -0
  47. textpolicy/utils/environment.py +349 -0
  48. textpolicy/utils/logging/__init__.py +22 -0
  49. textpolicy/utils/logging/base.py +48 -0
  50. textpolicy/utils/logging/console.py +61 -0
  51. textpolicy/utils/logging/factory.py +133 -0
  52. textpolicy/utils/logging/multi.py +83 -0
  53. textpolicy/utils/logging/tensorboard.py +65 -0
  54. textpolicy/utils/logging/wandb.py +72 -0
  55. textpolicy/utils/memory.py +118 -0
  56. textpolicy/utils/performance.py +464 -0
  57. textpolicy/utils/timing.py +171 -0
  58. textpolicy/validate.py +101 -0
  59. textpolicy/validation/__init__.py +13 -0
  60. textpolicy/validation/logprob_validation.py +315 -0
  61. textpolicy-0.1.1.dist-info/METADATA +109 -0
  62. textpolicy-0.1.1.dist-info/RECORD +66 -0
  63. {textpolicy-0.0.1.dist-info → textpolicy-0.1.1.dist-info}/WHEEL +1 -1
  64. textpolicy-0.1.1.dist-info/entry_points.txt +2 -0
  65. textpolicy-0.0.1.dist-info/METADATA +0 -10
  66. textpolicy-0.0.1.dist-info/RECORD +0 -6
  67. {textpolicy-0.0.1.dist-info → textpolicy-0.1.1.dist-info}/licenses/LICENSE +0 -0
  68. {textpolicy-0.0.1.dist-info → textpolicy-0.1.1.dist-info}/top_level.txt +0 -0
textpolicy/utils/debug.py
@@ -0,0 +1,170 @@
+ # textpolicy/utils/debug.py
+ """
+ Debug utilities and configuration for TextPolicy.
+ """
+
+ import os
+
+ class DebugConfig:
+     """Global debug configuration for TextPolicy."""
+
+     def __init__(self):
+         # Read from environment variables with defaults
+         self.policy_init = os.getenv('MLX_RL_DEBUG_POLICY_INIT', 'false').lower() == 'true'
+         self.value_init = os.getenv('MLX_RL_DEBUG_VALUE_INIT', 'false').lower() == 'true'
+         self.training = os.getenv('MLX_RL_DEBUG_TRAINING', 'false').lower() == 'true'
+         self.gradients = os.getenv('MLX_RL_DEBUG_GRADIENTS', 'false').lower() == 'true'
+         self.baseline_estimation = os.getenv('MLX_RL_DEBUG_BASELINE', 'false').lower() == 'true'
+
+         # New performance and vectorization debug categories
+         self.vectorization = os.getenv('MLX_RL_DEBUG_VECTORIZATION', 'false').lower() == 'true'
+         self.environment = os.getenv('MLX_RL_DEBUG_ENVIRONMENT', 'false').lower() == 'true'
+         self.performance = os.getenv('MLX_RL_DEBUG_PERFORMANCE', 'false').lower() == 'true'
+         self.benchmarking = os.getenv('MLX_RL_DEBUG_BENCHMARKING', 'false').lower() == 'true'
+         self.memory = os.getenv('MLX_RL_DEBUG_MEMORY', 'false').lower() == 'true'
+         self.timing = os.getenv('MLX_RL_DEBUG_TIMING', 'false').lower() == 'true'
+
+         # Overall debug level
+         debug_level = os.getenv('MLX_RL_DEBUG', 'info').lower()
+         self.enabled = debug_level in ['debug', 'verbose']
+         self.verbose = debug_level == 'verbose'
+
+     def should_debug(self, category: str) -> bool:
+         """Check if debugging is enabled for a specific category."""
+         if not self.enabled:
+             return False
+
+         category_map = {
+             'policy_init': self.policy_init,
+             'value_init': self.value_init,
+             'training': self.training,
+             'gradients': self.gradients,
+             'baseline': self.baseline_estimation,
+             'vectorization': self.vectorization,
+             'environment': self.environment,
+             'performance': self.performance,
+             'benchmarking': self.benchmarking,
+             'memory': self.memory,
+             'timing': self.timing
+         }
+
+         return category_map.get(category, self.verbose)
+
+ # Global debug configuration instance
+ debug_config = DebugConfig()
+
+ def debug_print(message: str, category: str = 'general', force: bool = False):
+     """Print debug message if debugging is enabled for the category."""
+     if force or debug_config.should_debug(category):
+         print(f"[DEBUG] {message}")
+
+ def error_print(message: str, category: str = 'general'):
+     """Print error messages only if any debug mode is enabled."""
+     if debug_config.enabled:
+         print(f"[ERROR] {message}")
+
+ def info_print(message: str, category: str = 'general'):
+     """Print info messages only if explicitly enabled for the category."""
+     if debug_config.should_debug(category):
+         print(f"[INFO] {message}")
+
+
+ def performance_debug(message: str, force: bool = False):
+     """Debug print for performance-related messages."""
+     debug_print(message, 'performance', force)
+
+
+ def vectorization_debug(message: str, force: bool = False):
+     """Debug print for vectorization-related messages."""
+     debug_print(message, 'vectorization', force)
+
+
+ def environment_debug(message: str, force: bool = False):
+     """Debug print for environment-related messages."""
+     debug_print(message, 'environment', force)
+
+
+ def benchmarking_debug(message: str, force: bool = False):
+     """Debug print for benchmarking-related messages."""
+     debug_print(message, 'benchmarking', force)
+
+
+ def memory_debug(message: str, force: bool = False):
+     """Debug print for memory-related messages."""
+     debug_print(message, 'memory', force)
+
+
+ def timing_debug(message: str, force: bool = False):
+     """Debug print for timing-related messages."""
+     debug_print(message, 'timing', force)
+
+
+ def get_debug_categories() -> list:
+     """Get list of available debug categories."""
+     return [
+         'policy_init', 'value_init', 'training', 'gradients', 'baseline',
+         'vectorization', 'environment', 'performance', 'benchmarking',
+         'memory', 'timing', 'general'
+     ]
+
+
+ def is_debug_enabled(category: str = 'general') -> bool:
+     """Check if debug is enabled for a specific category."""
+     return debug_config.should_debug(category)
+
+
+ def set_debug_level(level: str):
+     """
+     Set debug level programmatically.
+
+     Args:
+         level: Debug level ('info', 'debug', 'verbose')
+     """
+     os.environ['MLX_RL_DEBUG'] = level.lower()
+     # Reinitialize global config
+     global debug_config
+     debug_config = DebugConfig()
+
+
+ def enable_category_debug(category: str, enabled: bool = True):
+     """
+     Enable/disable debug for a specific category.
+
+     Args:
+         category: Debug category name
+         enabled: Whether to enable or disable
+     """
+     env_var = f'MLX_RL_DEBUG_{category.upper()}'
+     os.environ[env_var] = 'true' if enabled else 'false'
+
+     # Reinitialize global config
+     global debug_config
+     debug_config = DebugConfig()
+
+
+ def print_debug_status():
+     """Print current debug configuration status."""
+     print("MLX-RL Debug Configuration:")
+     print("=" * 40)
+     print(f"Overall enabled: {debug_config.enabled}")
+     print(f"Verbose mode: {debug_config.verbose}")
+     print()
+     print("Category-specific settings:")
+
+     categories = {
+         'policy_init': debug_config.policy_init,
+         'value_init': debug_config.value_init,
+         'training': debug_config.training,
+         'gradients': debug_config.gradients,
+         'baseline': debug_config.baseline_estimation,
+         'vectorization': debug_config.vectorization,
+         'environment': debug_config.environment,
+         'performance': debug_config.performance,
+         'benchmarking': debug_config.benchmarking,
+         'memory': debug_config.memory,
+         'timing': debug_config.timing
+     }
+
+     for category, enabled in categories.items():
+         status = "Enabled" if enabled else "Disabled"
+         print(f" {category:<15} {status}")
textpolicy/utils/environment.py
@@ -0,0 +1,349 @@
+ """
+ Environment analysis and profiling utilities for MLX-RL.
+
+ This module provides tools to analyze environment characteristics and determine
+ optimal parallelization strategies based on empirical performance testing.
+ """
+
+ import time
+ import numpy as np
+ from typing import List, Optional, Tuple, Union
+ from dataclasses import dataclass
+ from enum import Enum
+
+ try:
+     # Assumed fix: the published block attempts no import, so the except
+     # branch could never run; an MLX import is the evident intent here.
+     import mlx.core  # noqa: F401
+     MLX_AVAILABLE = True
+ except ImportError:
+     MLX_AVAILABLE = False
+
+ from .timing import Timer
+ from .debug import debug_print
+
+
+ class EnvironmentType(Enum):
+     """Classification of environment computational complexity."""
+     # Bands match STEP_TIME_THRESHOLDS below (classification uses <=)
+     ULTRA_LIGHT = "ultra_light"  # <5ms per step (e.g. CartPole, MountainCar)
+     LIGHT = "light"              # 5-20ms per step
+     MODERATE = "moderate"        # 20-100ms per step
+     HEAVY = "heavy"              # 100-500ms per step
+     ULTRA_HEAVY = "ultra_heavy"  # >500ms per step
+
+
+ @dataclass
+ class EnvironmentProfile:
+     """Environment performance characteristics."""
+     name: str
+     avg_step_time_ms: float
+     steps_per_second: float
+     avg_episode_length: float
+     environment_type: EnvironmentType
+     vectorization_recommended: bool
+     recommended_num_envs: Optional[int]
+     reasoning: str
+
+
+ class EnvironmentAnalyzer:
+     """
+     Analyzes environment performance characteristics to guide optimization decisions.
+
+     Based on empirical testing from MLX-RL vectorization analysis:
+     - CartPole-v1: ~250k steps/sec (ultra_light)
+     - MountainCar-v0: ~225k steps/sec (ultra_light)
+     - Acrobot-v1: ~42k steps/sec (light)
+     """
+
+     # Performance thresholds based on empirical analysis
+     STEP_TIME_THRESHOLDS = {
+         EnvironmentType.ULTRA_LIGHT: 0.005,        # <5ms per step
+         EnvironmentType.LIGHT: 0.02,               # 5-20ms per step
+         EnvironmentType.MODERATE: 0.1,             # 20-100ms per step
+         EnvironmentType.HEAVY: 0.5,                # 100-500ms per step
+         EnvironmentType.ULTRA_HEAVY: float('inf')  # >500ms per step
+     }
+
+     def __init__(self, test_steps: int = 500, warmup_steps: int = 50):
+         """
+         Initialize environment analyzer.
+
+         Args:
+             test_steps: Number of steps to use for performance testing
+             warmup_steps: Number of warmup steps (excluded from timing)
+         """
+         self.test_steps = test_steps
+         self.warmup_steps = warmup_steps
+         self.timer = Timer()
+
+     def analyze_environment(self, env_name: str, **env_kwargs) -> EnvironmentProfile:
+         """
+         Analyze an environment's performance characteristics.
+
+         Args:
+             env_name: Environment name (e.g., "CartPole-v1")
+             **env_kwargs: Additional environment creation arguments
+
+         Returns:
+             EnvironmentProfile with performance analysis and recommendations
+         """
+         debug_print(f"Analyzing environment: {env_name}", "environment")
+
+         try:
+             # Import here to avoid circular dependencies
+             from textpolicy.environment import GymAdapter
+
+             env = GymAdapter(env_name, **env_kwargs)
+
+             # Warmup
+             obs, info = env.reset()
+             for _ in range(self.warmup_steps):
+                 action = self._sample_action(env)
+                 step_result = env.step(action)
+                 if step_result["terminated"] or step_result["truncated"]:
+                     obs, info = env.reset()
+
+             # Performance measurement
+             step_times = []
+             episode_lengths = []
+             current_episode_length = 0
+
+             obs, info = env.reset()
+
+             start_time = time.perf_counter()
+             for i in range(self.test_steps):
+                 action = self._sample_action(env)
+
+                 step_start = time.perf_counter()
+                 step_result = env.step(action)
+                 step_end = time.perf_counter()
+
+                 step_times.append(step_end - step_start)
+                 current_episode_length += 1
+
+                 if step_result["terminated"] or step_result["truncated"]:
+                     episode_lengths.append(current_episode_length)
+                     current_episode_length = 0
+                     obs, info = env.reset()
+
+             end_time = time.perf_counter()
+
+             env.close()
+
+             # Calculate metrics
+             total_time = end_time - start_time
+             avg_step_time = np.mean(step_times)
+             steps_per_second = self.test_steps / total_time
+             avg_episode_length = np.mean(episode_lengths) if episode_lengths else current_episode_length
+
+             # Classify environment type
+             env_type = self._classify_environment_type(avg_step_time)
+
+             # Generate recommendations
+             vectorization_recommended, recommended_num_envs, reasoning = self._generate_recommendations(
+                 env_type, avg_step_time, steps_per_second
+             )
+
+             return EnvironmentProfile(
+                 name=env_name,
+                 avg_step_time_ms=float(avg_step_time * 1000),
+                 steps_per_second=float(steps_per_second),
+                 avg_episode_length=float(avg_episode_length),
+                 environment_type=env_type,
+                 vectorization_recommended=vectorization_recommended,
+                 recommended_num_envs=recommended_num_envs,
+                 reasoning=reasoning
+             )
+
+         except Exception as e:
+             debug_print(f"Error analyzing environment {env_name}: {e}", "environment")
+             # Return minimal profile on error
+             return EnvironmentProfile(
+                 name=env_name,
+                 avg_step_time_ms=0.0,
+                 steps_per_second=0.0,
+                 avg_episode_length=0.0,
+                 environment_type=EnvironmentType.MODERATE,
+                 vectorization_recommended=False,
+                 recommended_num_envs=None,
+                 reasoning=f"Analysis failed: {e}"
+             )
+
+     def _sample_action(self, env) -> Union[int, float, np.ndarray]:
+         """Sample a random action from the environment's action space."""
+         try:
+             # Use gymnasium's action space sampling
+             return env.action_space.sample()
+         except Exception:
+             # Fallback for discrete spaces
+             if hasattr(env.action_space, 'n'):
+                 return np.random.randint(0, env.action_space.n)
+             else:
+                 # Assume continuous space
+                 return np.random.random()
+
+     def _classify_environment_type(self, avg_step_time: float) -> EnvironmentType:
+         """Classify environment based on average step time."""
+         for env_type, threshold in self.STEP_TIME_THRESHOLDS.items():
+             if avg_step_time <= threshold:
+                 return env_type
+         return EnvironmentType.ULTRA_HEAVY
+
+     def _generate_recommendations(self, env_type: EnvironmentType, avg_step_time: float,
+                                   steps_per_second: float) -> Tuple[bool, Optional[int], str]:
+         """
+         Generate vectorization recommendations based on environment analysis.
+
+         Returns:
+             Tuple of (should_vectorize, recommended_num_envs, reasoning)
+         """
+         if env_type == EnvironmentType.ULTRA_LIGHT:
+             return False, None, (
+                 f"Ultra-lightweight environment ({avg_step_time*1000:.2f}ms/step, "
+                 f"{steps_per_second:.0f} steps/sec). Process overhead will dominate. "
+                 "Use single environment for optimal performance."
+             )
+
+         elif env_type == EnvironmentType.LIGHT:
+             return False, None, (
+                 f"Lightweight environment ({avg_step_time*1000:.2f}ms/step). "
+                 "May benefit from vectorization with complex environments only. "
+                 "Test with 2-4 environments if training is I/O bound."
+             )
+
+         elif env_type == EnvironmentType.MODERATE:
+             return True, 4, (
+                 f"Moderate complexity environment ({avg_step_time*1000:.2f}ms/step). "
+                 "Good candidate for vectorization. Start with 4 environments."
+             )
+
+         elif env_type == EnvironmentType.HEAVY:
+             return True, 8, (
+                 f"Heavy computation environment ({avg_step_time*1000:.2f}ms/step). "
+                 "Excellent vectorization candidate. Recommended 8 environments."
+             )
+
+         else:  # ULTRA_HEAVY
+             return True, 16, (
+                 f"Ultra-heavy environment ({avg_step_time*1000:.2f}ms/step). "
+                 "Excellent vectorization candidate. Consider 16+ environments."
+             )
+
+
+ def analyze_environment(env_name: str, **env_kwargs) -> EnvironmentProfile:
+     """
+     Convenience function to analyze an environment.
+
+     Args:
+         env_name: Environment name
+         **env_kwargs: Environment creation arguments
+
+     Returns:
+         EnvironmentProfile with analysis results
+     """
+     analyzer = EnvironmentAnalyzer()
+     return analyzer.analyze_environment(env_name, **env_kwargs)
+
+
+ def should_vectorize(env_name: str, **env_kwargs) -> Tuple[bool, str]:
+     """
+     Quick recommendation on whether to vectorize an environment.
+
+     Args:
+         env_name: Environment name
+         **env_kwargs: Environment creation arguments
+
+     Returns:
+         Tuple of (should_vectorize, reasoning)
+     """
+     profile = analyze_environment(env_name, **env_kwargs)
+     return profile.vectorization_recommended, profile.reasoning
+
+
+ def print_environment_analysis(profile: EnvironmentProfile, detailed: bool = True):
+     """
+     Print environment analysis results in a readable format.
+
+     Args:
+         profile: EnvironmentProfile to display
+         detailed: Whether to show detailed statistics
+     """
+     print(f"\nEnvironment Analysis: {profile.name}")
+     print("=" * 50)
+
+     if detailed:
+         print("Performance Metrics:")
+         print(f" Average step time: {profile.avg_step_time_ms:.3f} ms")
+         print(f" Steps per second: {profile.steps_per_second:.0f}")
+         print(f" Average episode length: {profile.avg_episode_length:.1f}")
+         print(f" Environment type: {profile.environment_type.value}")
+         print()
+
+     print("Vectorization Recommendation:")
+     if profile.vectorization_recommended:
+         print(" Use vectorized environments")
+         if profile.recommended_num_envs:
+             print(f" Recommended: {profile.recommended_num_envs} environments")
+     else:
+         print(" Use single environment")
+
+     print("\nReasoning:")
+     print(f" {profile.reasoning}")
+
+
+ class EnvironmentBenchmark:
+     """
+     Benchmark multiple environments to compare their characteristics.
+
+     Useful for comparing similar environments or testing modifications.
+     """
+
+     def __init__(self, analyzer: Optional[EnvironmentAnalyzer] = None):
+         """Initialize with optional custom analyzer."""
+         self.analyzer = analyzer or EnvironmentAnalyzer()
+         self.profiles: List[EnvironmentProfile] = []
+
+     def add_environment(self, env_name: str, **env_kwargs) -> EnvironmentProfile:
+         """
+         Add an environment to the benchmark.
+
+         Args:
+             env_name: Environment name
+             **env_kwargs: Environment creation arguments
+
+         Returns:
+             EnvironmentProfile for the added environment
+         """
+         profile = self.analyzer.analyze_environment(env_name, **env_kwargs)
+         self.profiles.append(profile)
+         return profile
+
+     def print_comparison(self):
+         """Print comparison of all benchmarked environments."""
+         if not self.profiles:
+             print("No environments benchmarked yet.")
+             return
+
+         print(f"\nEnvironment Comparison ({len(self.profiles)} environments)")
+         print("=" * 80)
+
+         # Header
+         print(f"{'Environment':<20} {'Step Time (ms)':<15} {'Steps/sec':<12} {'Type':<12} {'Vectorize':<10}")
+         print("-" * 80)
+
+         # Sort by steps per second (descending)
+         sorted_profiles = sorted(self.profiles, key=lambda p: p.steps_per_second, reverse=True)
+
+         for profile in sorted_profiles:
+             vectorize_str = "Yes" if profile.vectorization_recommended else "No"
+             print(f"{profile.name:<20} {profile.avg_step_time_ms:<15.3f} "
+                   f"{profile.steps_per_second:<12.0f} {profile.environment_type.value:<12} {vectorize_str:<10}")
+
+         print()
+
+         # Summary statistics
+         ultra_light = sum(1 for p in self.profiles if p.environment_type == EnvironmentType.ULTRA_LIGHT)
+         light = sum(1 for p in self.profiles if p.environment_type == EnvironmentType.LIGHT)
+         moderate_plus = len(self.profiles) - ultra_light - light
+
+         print("Summary:")
+         print(f" Ultra-light environments: {ultra_light} (single process recommended)")
+         print(f" Light environments: {light} (test vectorization carefully)")
+         print(f" Moderate+ environments: {moderate_plus} (vectorization recommended)")
textpolicy/utils/logging/__init__.py
@@ -0,0 +1,22 @@
+ # textpolicy/utils/logging/__init__.py
+ """
+ Logging utilities for TextPolicy.
+ """
+
+ from .base import Logger
+ from .wandb import WandbLogger
+ from .tensorboard import TensorboardLogger
+ from .console import ConsoleLogger
+ from .multi import MultiLogger
+ from .factory import create_logger, create_multi_logger, create_auto_logger
+
+ __all__ = [
+     'Logger',
+     'WandbLogger',
+     'TensorboardLogger',
+     'ConsoleLogger',
+     'MultiLogger',
+     'create_logger',
+     'create_multi_logger',
+     'create_auto_logger',
+ ]
textpolicy/utils/logging/base.py
@@ -0,0 +1,48 @@
+ # textpolicy/utils/logging/base.py
+ """
+ Base logger interface and protocols for TextPolicy.
+ """
+
+ from typing import Dict
+ from abc import ABC, abstractmethod
+
+
+ class Logger(ABC):
+     """
+     Abstract base class for logging backends.
+
+     Defines the standard interface that all logging implementations must follow.
+     Supports both training metrics and evaluation metrics with step tracking.
+     """
+
+     @abstractmethod
+     def log_metrics(self, metrics: Dict[str, float], step: int):
+         """
+         Log training metrics.
+
+         Args:
+             metrics: Dictionary of metric names to values
+             step: Training step number for time-series tracking
+         """
+         pass
+
+     @abstractmethod
+     def log_evaluation(self, metrics: Dict[str, float], step: int):
+         """
+         Log evaluation metrics.
+
+         Args:
+             metrics: Dictionary of evaluation metric names to values
+             step: Training step number when evaluation was performed
+         """
+         pass
+
+     @abstractmethod
+     def finish(self):
+         """
+         Finish logging session and cleanup resources.
+
+         Called at end of training to properly close connections,
+         save final data, and cleanup any temporary resources.
+         """
+         pass
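Because `Logger` is a plain ABC, a new backend only has to implement the three methods above. A hypothetical JSONL backend as an illustration (the class name, file path, and record layout are invented for this sketch, not part of the package):

```python
import json
from typing import Dict

from textpolicy.utils.logging.base import Logger


class JsonlLogger(Logger):
    """Hypothetical backend: append each metrics dict as one JSON line."""

    def __init__(self, path: str = "metrics.jsonl"):
        self._fh = open(path, "a")

    def log_metrics(self, metrics: Dict[str, float], step: int):
        self._fh.write(json.dumps({"kind": "train", "step": step, **metrics}) + "\n")

    def log_evaluation(self, metrics: Dict[str, float], step: int):
        self._fh.write(json.dumps({"kind": "eval", "step": step, **metrics}) + "\n")

    def finish(self):
        # The ABC's contract: release resources at the end of training.
        self._fh.close()
```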
textpolicy/utils/logging/console.py
@@ -0,0 +1,61 @@
+ # textpolicy/utils/logging/console.py
+ """
+ Simple console logging for debugging and minimal setups.
+ """
+
+ from typing import Dict
+ from .base import Logger
+
+
+ class ConsoleLogger(Logger):
+     """
+     Simple console logging for debugging and development.
+
+     Features:
+     - No external dependencies
+     - Immediate output for debugging
+     - Configurable verbosity
+     - Minimal overhead
+
+     Ideal for:
+     - Development and debugging
+     - CI/CD pipelines
+     - Minimal deployment environments
+     """
+
+     def __init__(self, verbose: bool = True):
+         """
+         Initialize console logging.
+
+         Args:
+             verbose: Whether to print metrics to console
+         """
+         self.verbose = verbose
+
+     def log_metrics(self, metrics: Dict[str, float], step: int):
+         """
+         Log training metrics to console.
+
+         Args:
+             metrics: Training metrics dictionary
+             step: Training step number
+         """
+         if self.verbose:
+             metrics_str = ", ".join(f"{k}: {v:.4f}" for k, v in metrics.items())
+             print(f"Step {step} - Training: {metrics_str}")
+
+     def log_evaluation(self, metrics: Dict[str, float], step: int):
+         """
+         Log evaluation metrics to console.
+
+         Args:
+             metrics: Evaluation metrics dictionary
+             step: Training step when evaluation was performed
+         """
+         if self.verbose:
+             metrics_str = ", ".join(f"{k}: {v:.4f}" for k, v in metrics.items())
+             print(f"Step {step} - Evaluation: {metrics_str}")
+
+     def finish(self):
+         """No cleanup needed for console logging."""
+         pass
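End to end, the console backend is a one-liner to use; a sketch of the round trip, with the expected output derived from the f-strings above:

```python
from textpolicy.utils.logging.console import ConsoleLogger

logger = ConsoleLogger(verbose=True)
logger.log_metrics({"loss": 0.4321, "kl": 0.0123}, step=100)
# -> Step 100 - Training: loss: 0.4321, kl: 0.0123
logger.log_evaluation({"reward_mean": 1.5}, step=100)
# -> Step 100 - Evaluation: reward_mean: 1.5000
logger.finish()  # no-op for the console backend
```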