textpolicy 0.0.1-py3-none-any.whl → 0.1.0-py3-none-any.whl
This diff shows the contents of publicly released package versions as they appear in their public registries. It is provided for informational purposes only and reflects the changes between the two versions.
- textpolicy/__init__.py +52 -0
- textpolicy/__main__.py +8 -0
- textpolicy/algorithms/__init__.py +54 -0
- textpolicy/algorithms/grpo.py +642 -0
- textpolicy/algorithms/gspo.py +582 -0
- textpolicy/buffer/__init__.py +23 -0
- textpolicy/buffer/buffer.py +244 -0
- textpolicy/buffer/episode.py +383 -0
- textpolicy/buffer/sampling.py +438 -0
- textpolicy/buffer/storage.py +255 -0
- textpolicy/cli.py +67 -0
- textpolicy/environment/__init__.py +79 -0
- textpolicy/environment/base.py +110 -0
- textpolicy/environment/environment.py +46 -0
- textpolicy/environment/factory.py +103 -0
- textpolicy/environment/gym.py +106 -0
- textpolicy/environment/task_suites.py +51 -0
- textpolicy/environment/text_generation.py +789 -0
- textpolicy/environment/vectorized.py +253 -0
- textpolicy/generation/__init__.py +62 -0
- textpolicy/generation/lora.py +411 -0
- textpolicy/generation/mlx_generation.py +557 -0
- textpolicy/generation/reload.py +253 -0
- textpolicy/rewards/__init__.py +137 -0
- textpolicy/rewards/adapters.py +387 -0
- textpolicy/rewards/basic.py +214 -0
- textpolicy/rewards/integrated_system.py +338 -0
- textpolicy/rewards/mlx_batch_processor.py +447 -0
- textpolicy/rewards/registry.py +293 -0
- textpolicy/rewards/rollout_rewards.py +410 -0
- textpolicy/rewards/verifiers.py +369 -0
- textpolicy/rollout/__init__.py +44 -0
- textpolicy/rollout/aggregator.py +145 -0
- textpolicy/rollout/base.py +108 -0
- textpolicy/rollout/rollout.py +142 -0
- textpolicy/rollout/runner.py +280 -0
- textpolicy/rollout/strategy.py +208 -0
- textpolicy/rollout/worker.py +194 -0
- textpolicy/training/__init__.py +14 -0
- textpolicy/training/metrics.py +242 -0
- textpolicy/training/rollout_manager.py +78 -0
- textpolicy/training/trainer.py +684 -0
- textpolicy/utils/__init__.py +40 -0
- textpolicy/utils/benchmarking.py +489 -0
- textpolicy/utils/data.py +60 -0
- textpolicy/utils/debug.py +170 -0
- textpolicy/utils/environment.py +349 -0
- textpolicy/utils/logging/__init__.py +22 -0
- textpolicy/utils/logging/base.py +48 -0
- textpolicy/utils/logging/console.py +61 -0
- textpolicy/utils/logging/factory.py +133 -0
- textpolicy/utils/logging/multi.py +83 -0
- textpolicy/utils/logging/tensorboard.py +65 -0
- textpolicy/utils/logging/wandb.py +72 -0
- textpolicy/utils/memory.py +118 -0
- textpolicy/utils/performance.py +464 -0
- textpolicy/utils/timing.py +171 -0
- textpolicy/validate.py +101 -0
- textpolicy/validation/__init__.py +13 -0
- textpolicy/validation/logprob_validation.py +315 -0
- textpolicy-0.1.0.dist-info/METADATA +99 -0
- textpolicy-0.1.0.dist-info/RECORD +66 -0
- textpolicy-0.1.0.dist-info/entry_points.txt +2 -0
- textpolicy-0.0.1.dist-info/METADATA +0 -10
- textpolicy-0.0.1.dist-info/RECORD +0 -6
- {textpolicy-0.0.1.dist-info → textpolicy-0.1.0.dist-info}/WHEEL +0 -0
- {textpolicy-0.0.1.dist-info → textpolicy-0.1.0.dist-info}/licenses/LICENSE +0 -0
- {textpolicy-0.0.1.dist-info → textpolicy-0.1.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,170 @@
# textpolicy/utils/debug.py
"""
Debug utilities and configuration for TextPolicy.
"""

import os

class DebugConfig:
    """Global debug configuration for TextPolicy."""

    def __init__(self):
        # Read from environment variables with defaults
        self.policy_init = os.getenv('MLX_RL_DEBUG_POLICY_INIT', 'false').lower() == 'true'
        self.value_init = os.getenv('MLX_RL_DEBUG_VALUE_INIT', 'false').lower() == 'true'
        self.training = os.getenv('MLX_RL_DEBUG_TRAINING', 'false').lower() == 'true'
        self.gradients = os.getenv('MLX_RL_DEBUG_GRADIENTS', 'false').lower() == 'true'
        self.baseline_estimation = os.getenv('MLX_RL_DEBUG_BASELINE', 'false').lower() == 'true'

        # New performance and vectorization debug categories
        self.vectorization = os.getenv('MLX_RL_DEBUG_VECTORIZATION', 'false').lower() == 'true'
        self.environment = os.getenv('MLX_RL_DEBUG_ENVIRONMENT', 'false').lower() == 'true'
        self.performance = os.getenv('MLX_RL_DEBUG_PERFORMANCE', 'false').lower() == 'true'
        self.benchmarking = os.getenv('MLX_RL_DEBUG_BENCHMARKING', 'false').lower() == 'true'
        self.memory = os.getenv('MLX_RL_DEBUG_MEMORY', 'false').lower() == 'true'
        self.timing = os.getenv('MLX_RL_DEBUG_TIMING', 'false').lower() == 'true'

        # Overall debug level
        debug_level = os.getenv('MLX_RL_DEBUG', 'info').lower()
        self.enabled = debug_level in ['debug', 'verbose']
        self.verbose = debug_level == 'verbose'

    def should_debug(self, category: str) -> bool:
        """Check if debugging is enabled for a specific category."""
        if not self.enabled:
            return False

        category_map = {
            'policy_init': self.policy_init,
            'value_init': self.value_init,
            'training': self.training,
            'gradients': self.gradients,
            'baseline': self.baseline_estimation,
            'vectorization': self.vectorization,
            'environment': self.environment,
            'performance': self.performance,
            'benchmarking': self.benchmarking,
            'memory': self.memory,
            'timing': self.timing
        }

        return category_map.get(category, self.verbose)

# Global debug configuration instance
debug_config = DebugConfig()

def debug_print(message: str, category: str = 'general', force: bool = False):
    """Print debug message if debugging is enabled for the category."""
    if force or debug_config.should_debug(category):
        print(f"[DEBUG] {message}")

def error_print(message: str, category: str = 'general'):
    """Print error messages only if any debug mode is enabled."""
    if debug_config.enabled:
        print(f"[ERROR] {message}")

def info_print(message: str, category: str = 'general'):
    """Print info messages only if explicitly enabled for the category."""
    if debug_config.should_debug(category):
        print(f"[INFO] {message}")


def performance_debug(message: str, force: bool = False):
    """Debug print for performance-related messages."""
    debug_print(message, 'performance', force)


def vectorization_debug(message: str, force: bool = False):
    """Debug print for vectorization-related messages."""
    debug_print(message, 'vectorization', force)


def environment_debug(message: str, force: bool = False):
    """Debug print for environment-related messages."""
    debug_print(message, 'environment', force)


def benchmarking_debug(message: str, force: bool = False):
    """Debug print for benchmarking-related messages."""
    debug_print(message, 'benchmarking', force)


def memory_debug(message: str, force: bool = False):
    """Debug print for memory-related messages."""
    debug_print(message, 'memory', force)


def timing_debug(message: str, force: bool = False):
    """Debug print for timing-related messages."""
    debug_print(message, 'timing', force)


def get_debug_categories() -> list:
    """Get list of available debug categories."""
    return [
        'policy_init', 'value_init', 'training', 'gradients', 'baseline',
        'vectorization', 'environment', 'performance', 'benchmarking',
        'memory', 'timing', 'general'
    ]


def is_debug_enabled(category: str = 'general') -> bool:
    """Check if debug is enabled for a specific category."""
    return debug_config.should_debug(category)


def set_debug_level(level: str):
    """
    Set debug level programmatically.

    Args:
        level: Debug level ('info', 'debug', 'verbose')
    """
    os.environ['MLX_RL_DEBUG'] = level.lower()
    # Reinitialize global config
    global debug_config
    debug_config = DebugConfig()


def enable_category_debug(category: str, enabled: bool = True):
    """
    Enable/disable debug for a specific category.

    Args:
        category: Debug category name
        enabled: Whether to enable or disable
    """
    env_var = f'MLX_RL_DEBUG_{category.upper()}'
    os.environ[env_var] = 'true' if enabled else 'false'

    # Reinitialize global config
    global debug_config
    debug_config = DebugConfig()


def print_debug_status():
    """Print current debug configuration status."""
    print("MLX-RL Debug Configuration:")
    print("=" * 40)
    print(f"Overall enabled: {debug_config.enabled}")
    print(f"Verbose mode: {debug_config.verbose}")
    print()
    print("Category-specific settings:")

    categories = {
        'policy_init': debug_config.policy_init,
        'value_init': debug_config.value_init,
        'training': debug_config.training,
        'gradients': debug_config.gradients,
        'baseline': debug_config.baseline_estimation,
        'vectorization': debug_config.vectorization,
        'environment': debug_config.environment,
        'performance': debug_config.performance,
        'benchmarking': debug_config.benchmarking,
        'memory': debug_config.memory,
        'timing': debug_config.timing
    }

    for category, enabled in categories.items():
        status = "Enabled" if enabled else "Disabled"
        print(f"  {category:<15} {status}")
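For orientation, a minimal usage sketch of the module above. It assumes textpolicy 0.1.0 is installed and touches only names defined in this hunk (enable_category_debug, timing_debug, debug_print, print_debug_status); the environment-variable names come straight from DebugConfig.

import os

os.environ['MLX_RL_DEBUG'] = 'debug'   # overall level; read when DebugConfig() is built at import time

from textpolicy.utils.debug import (
    enable_category_debug, timing_debug, debug_print, print_debug_status,
)

enable_category_debug('timing')                       # sets MLX_RL_DEBUG_TIMING and rebuilds the global config
timing_debug("rollout step took 1.3 ms")              # printed: debug is on and 'timing' is enabled
debug_print("policy init", category='policy_init')    # suppressed: that category is still off
print_debug_status()                                  # dumps the per-category status table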
textpolicy/utils/environment.py
@@ -0,0 +1,349 @@
"""
Environment analysis and profiling utilities for MLX-RL.

This module provides tools to analyze environment characteristics and determine
optimal parallelization strategies based on empirical performance testing.
"""

import time
import numpy as np
from typing import List, Optional, Tuple, Union
from dataclasses import dataclass
from enum import Enum

try:
    MLX_AVAILABLE = True
except ImportError:
    MLX_AVAILABLE = False

from .timing import Timer
from .debug import debug_print


class EnvironmentType(Enum):
    """Classification of environment computational complexity."""
    ULTRA_LIGHT = "ultra_light"   # <1ms per step (CartPole, MountainCar)
    LIGHT = "light"               # 1-5ms per step
    MODERATE = "moderate"         # 5-20ms per step
    HEAVY = "heavy"               # 20-100ms per step
    ULTRA_HEAVY = "ultra_heavy"   # >100ms per step


@dataclass
class EnvironmentProfile:
    """Environment performance characteristics."""
    name: str
    avg_step_time_ms: float
    steps_per_second: float
    avg_episode_length: float
    environment_type: EnvironmentType
    vectorization_recommended: bool
    recommended_num_envs: Optional[int]
    reasoning: str


class EnvironmentAnalyzer:
    """
    Analyzes environment performance characteristics to guide optimization decisions.

    Based on empirical testing from MLX-RL vectorization analysis:
    - CartPole-v1: ~250k steps/sec (ultra_light)
    - MountainCar-v0: ~225k steps/sec (ultra_light)
    - Acrobot-v1: ~42k steps/sec (light)
    """

    # Performance thresholds based on empirical analysis
    STEP_TIME_THRESHOLDS = {
        EnvironmentType.ULTRA_LIGHT: 0.005,        # <5ms per step
        EnvironmentType.LIGHT: 0.02,               # 5-20ms per step
        EnvironmentType.MODERATE: 0.1,             # 20-100ms per step
        EnvironmentType.HEAVY: 0.5,                # 100-500ms per step
        EnvironmentType.ULTRA_HEAVY: float('inf')  # >500ms per step
    }

    def __init__(self, test_steps: int = 500, warmup_steps: int = 50):
        """
        Initialize environment analyzer.

        Args:
            test_steps: Number of steps to use for performance testing
            warmup_steps: Number of warmup steps (excluded from timing)
        """
        self.test_steps = test_steps
        self.warmup_steps = warmup_steps
        self.timer = Timer()

    def analyze_environment(self, env_name: str, **env_kwargs) -> EnvironmentProfile:
        """
        Analyze an environment's performance characteristics.

        Args:
            env_name: Environment name (e.g., "CartPole-v1")
            **env_kwargs: Additional environment creation arguments

        Returns:
            EnvironmentProfile with performance analysis and recommendations
        """
        debug_print(f"Analyzing environment: {env_name}", "environment")

        try:
            # Import here to avoid circular dependencies
            from textpolicy.environment import GymAdapter

            env = GymAdapter(env_name, **env_kwargs)

            # Warmup
            obs, info = env.reset()
            for _ in range(self.warmup_steps):
                action = self._sample_action(env)
                step_result = env.step(action)
                if step_result["terminated"] or step_result["truncated"]:
                    obs, info = env.reset()

            # Performance measurement
            step_times = []
            episode_lengths = []
            current_episode_length = 0

            obs, info = env.reset()

            start_time = time.perf_counter()
            for i in range(self.test_steps):
                action = self._sample_action(env)

                step_start = time.perf_counter()
                step_result = env.step(action)
                step_end = time.perf_counter()

                step_times.append(step_end - step_start)
                current_episode_length += 1

                if step_result["terminated"] or step_result["truncated"]:
                    episode_lengths.append(current_episode_length)
                    current_episode_length = 0
                    obs, info = env.reset()

            end_time = time.perf_counter()

            env.close()

            # Calculate metrics
            total_time = end_time - start_time
            avg_step_time = np.mean(step_times)
            steps_per_second = self.test_steps / total_time
            avg_episode_length = np.mean(episode_lengths) if episode_lengths else current_episode_length

            # Classify environment type
            env_type = self._classify_environment_type(avg_step_time)

            # Generate recommendations
            vectorization_recommended, recommended_num_envs, reasoning = self._generate_recommendations(
                env_type, avg_step_time, steps_per_second
            )

            return EnvironmentProfile(
                name=env_name,
                avg_step_time_ms=float(avg_step_time * 1000),
                steps_per_second=float(steps_per_second),
                avg_episode_length=float(avg_episode_length),
                environment_type=env_type,
                vectorization_recommended=vectorization_recommended,
                recommended_num_envs=recommended_num_envs,
                reasoning=reasoning
            )

        except Exception as e:
            debug_print(f"Error analyzing environment {env_name}: {e}", "environment")
            # Return minimal profile on error
            return EnvironmentProfile(
                name=env_name,
                avg_step_time_ms=0.0,
                steps_per_second=0.0,
                avg_episode_length=0.0,
                environment_type=EnvironmentType.MODERATE,
                vectorization_recommended=False,
                recommended_num_envs=None,
                reasoning=f"Analysis failed: {e}"
            )

    def _sample_action(self, env) -> Union[int, float, np.ndarray]:
        """Sample a random action from the environment's action space."""
        try:
            # Use gymnasium's action space sampling
            return env.action_space.sample()
        except Exception:
            # Fallback for discrete spaces
            if hasattr(env.action_space, 'n'):
                return np.random.randint(0, env.action_space.n)
            else:
                # Assume continuous space
                return np.random.random()

    def _classify_environment_type(self, avg_step_time: float) -> EnvironmentType:
        """Classify environment based on average step time."""
        for env_type, threshold in self.STEP_TIME_THRESHOLDS.items():
            if avg_step_time <= threshold:
                return env_type
        return EnvironmentType.ULTRA_HEAVY

    def _generate_recommendations(self, env_type: EnvironmentType, avg_step_time: float,
                                  steps_per_second: float) -> Tuple[bool, Optional[int], str]:
        """
        Generate vectorization recommendations based on environment analysis.

        Returns:
            Tuple of (should_vectorize, recommended_num_envs, reasoning)
        """
        if env_type == EnvironmentType.ULTRA_LIGHT:
            return False, None, (
                f"Ultra-lightweight environment ({avg_step_time*1000:.2f}ms/step, "
                f"{steps_per_second:.0f} steps/sec). Process overhead will dominate. "
                "Use single environment for optimal performance."
            )

        elif env_type == EnvironmentType.LIGHT:
            return False, None, (
                f"Lightweight environment ({avg_step_time*1000:.2f}ms/step). "
                "May benefit from vectorization with complex environments only. "
                "Test with 2-4 environments if training is I/O bound."
            )

        elif env_type == EnvironmentType.MODERATE:
            return True, 4, (
                f"Moderate complexity environment ({avg_step_time*1000:.2f}ms/step). "
                "Good candidate for vectorization. Start with 4 environments."
            )

        elif env_type == EnvironmentType.HEAVY:
            return True, 8, (
                f"Heavy computation environment ({avg_step_time*1000:.2f}ms/step). "
                "Excellent vectorization candidate. Recommended 8 environments."
            )

        else:  # ULTRA_HEAVY
            return True, 16, (
                f"Ultra-heavy environment ({avg_step_time*1000:.2f}ms/step). "
                "Excellent vectorization candidate. Consider 16+ environments."
            )


def analyze_environment(env_name: str, **env_kwargs) -> EnvironmentProfile:
    """
    Convenience function to analyze an environment.

    Args:
        env_name: Environment name
        **env_kwargs: Environment creation arguments

    Returns:
        EnvironmentProfile with analysis results
    """
    analyzer = EnvironmentAnalyzer()
    return analyzer.analyze_environment(env_name, **env_kwargs)


def should_vectorize(env_name: str, **env_kwargs) -> Tuple[bool, str]:
    """
    Quick recommendation on whether to vectorize an environment.

    Args:
        env_name: Environment name
        **env_kwargs: Environment creation arguments

    Returns:
        Tuple of (should_vectorize, reasoning)
    """
    profile = analyze_environment(env_name, **env_kwargs)
    return profile.vectorization_recommended, profile.reasoning


def print_environment_analysis(profile: EnvironmentProfile, detailed: bool = True):
    """
    Print environment analysis results in a readable format.

    Args:
        profile: EnvironmentProfile to display
        detailed: Whether to show detailed statistics
    """
    print(f"\nEnvironment Analysis: {profile.name}")
    print("=" * 50)

    if detailed:
        print("Performance Metrics:")
        print(f"  Average step time: {profile.avg_step_time_ms:.3f} ms")
        print(f"  Steps per second: {profile.steps_per_second:.0f}")
        print(f"  Average episode length: {profile.avg_episode_length:.1f}")
        print(f"  Environment type: {profile.environment_type.value}")
        print()

    print("Vectorization Recommendation:")
    if profile.vectorization_recommended:
        print("  Use vectorized environments")
        if profile.recommended_num_envs:
            print(f"  Recommended: {profile.recommended_num_envs} environments")
    else:
        print("  Use single environment")

    print("\nReasoning:")
    print(f"  {profile.reasoning}")


class EnvironmentBenchmark:
    """
    Benchmark multiple environments to compare their characteristics.

    Useful for comparing similar environments or testing modifications.
    """

    def __init__(self, analyzer: Optional[EnvironmentAnalyzer] = None):
        """Initialize with optional custom analyzer."""
        self.analyzer = analyzer or EnvironmentAnalyzer()
        self.profiles: List[EnvironmentProfile] = []

    def add_environment(self, env_name: str, **env_kwargs) -> EnvironmentProfile:
        """
        Add an environment to the benchmark.

        Args:
            env_name: Environment name
            **env_kwargs: Environment creation arguments

        Returns:
            EnvironmentProfile for the added environment
        """
        profile = self.analyzer.analyze_environment(env_name, **env_kwargs)
        self.profiles.append(profile)
        return profile

    def print_comparison(self):
        """Print comparison of all benchmarked environments."""
        if not self.profiles:
            print("No environments benchmarked yet.")
            return

        print(f"\nEnvironment Comparison ({len(self.profiles)} environments)")
        print("=" * 80)

        # Header
        print(f"{'Environment':<20} {'Step Time (ms)':<15} {'Steps/sec':<12} {'Type':<12} {'Vectorize':<10}")
        print("-" * 80)

        # Sort by steps per second (descending)
        sorted_profiles = sorted(self.profiles, key=lambda p: p.steps_per_second, reverse=True)

        for profile in sorted_profiles:
            vectorize_str = "Yes" if profile.vectorization_recommended else "No"
            print(f"{profile.name:<20} {profile.avg_step_time_ms:<15.3f} "
                  f"{profile.steps_per_second:<12.0f} {profile.environment_type.value:<12} {vectorize_str:<10}")

        print()

        # Summary statistics
        ultra_light = sum(1 for p in self.profiles if p.environment_type == EnvironmentType.ULTRA_LIGHT)
        light = sum(1 for p in self.profiles if p.environment_type == EnvironmentType.LIGHT)
        moderate_plus = len(self.profiles) - ultra_light - light

        print("Summary:")
        print(f"  Ultra-light environments: {ultra_light} (single process recommended)")
        print(f"  Light environments: {light} (test vectorization carefully)")
        print(f"  Moderate+ environments: {moderate_plus} (vectorization recommended)")
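A short sketch of how these helpers fit together, assuming a working Gymnasium install and the standard classic-control environment ids; only names defined in this hunk are used.

from textpolicy.utils.environment import (
    analyze_environment, should_vectorize, print_environment_analysis, EnvironmentBenchmark,
)

profile = analyze_environment("CartPole-v1")      # ~500 timed steps after a 50-step warmup
print_environment_analysis(profile)               # step time, steps/sec, type, recommendation

vectorize, why = should_vectorize("Acrobot-v1")   # quick yes/no plus the reasoning string
print(vectorize, why)

# Compare several environments side by side.
bench = EnvironmentBenchmark()
for env_id in ("CartPole-v1", "MountainCar-v0", "Acrobot-v1"):
    bench.add_environment(env_id)
bench.print_comparison()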
@@ -0,0 +1,22 @@
# textpolicy/utils/logging/__init__.py
"""
Logging utilities for TextPolicy.
"""

from .base import Logger
from .wandb import WandbLogger
from .tensorboard import TensorboardLogger
from .console import ConsoleLogger
from .multi import MultiLogger
from .factory import create_logger, create_multi_logger, create_auto_logger

__all__ = [
    'Logger',
    'WandbLogger',
    'TensorboardLogger',
    'ConsoleLogger',
    'MultiLogger',
    'create_logger',
    'create_multi_logger',
    'create_auto_logger',
]
@@ -0,0 +1,48 @@
# textpolicy/utils/logging/base.py
"""
Base logger interface and protocols for TextPolicy.
"""

from typing import Dict
from abc import ABC, abstractmethod


class Logger(ABC):
    """
    Abstract base class for logging backends.

    Defines the standard interface that all logging implementations must follow.
    Supports both training metrics and evaluation metrics with step tracking.
    """

    @abstractmethod
    def log_metrics(self, metrics: Dict[str, float], step: int):
        """
        Log training metrics.

        Args:
            metrics: Dictionary of metric names to values
            step: Training step number for time-series tracking
        """
        pass

    @abstractmethod
    def log_evaluation(self, metrics: Dict[str, float], step: int):
        """
        Log evaluation metrics.

        Args:
            metrics: Dictionary of evaluation metric names to values
            step: Training step number when evaluation was performed
        """
        pass

    @abstractmethod
    def finish(self):
        """
        Finish logging session and cleanup resources.

        Called at end of training to properly close connections,
        save final data, and cleanup any temporary resources.
        """
        pass
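To illustrate the contract, a sketch of a custom backend built on the Logger ABC above. JsonlLogger is a hypothetical name used only for this example; it is not part of the package.

import json
from typing import Dict

from textpolicy.utils.logging import Logger


class JsonlLogger(Logger):
    """Append each metrics dict as one JSON line, tagged with its phase and step."""

    def __init__(self, path: str):
        self._file = open(path, "a")

    def log_metrics(self, metrics: Dict[str, float], step: int):
        self._file.write(json.dumps({"phase": "train", "step": step, **metrics}) + "\n")

    def log_evaluation(self, metrics: Dict[str, float], step: int):
        self._file.write(json.dumps({"phase": "eval", "step": step, **metrics}) + "\n")

    def finish(self):
        self._file.close()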
@@ -0,0 +1,61 @@
# textpolicy/utils/logging/console.py
"""
Simple console logging for debugging and minimal setups.
"""

from typing import Dict
from .base import Logger


class ConsoleLogger(Logger):
    """
    Simple console logging for debugging and development.

    Features:
    - No external dependencies
    - Immediate output for debugging
    - Configurable verbosity
    - Minimal overhead

    Ideal for:
    - Development and debugging
    - CI/CD pipelines
    - Minimal deployment environments
    """

    def __init__(self, verbose: bool = True):
        """
        Initialize console logging.

        Args:
            verbose: Whether to print metrics to console
        """
        self.verbose = verbose

    def log_metrics(self, metrics: Dict[str, float], step: int):
        """
        Log training metrics to console.

        Args:
            metrics: Training metrics dictionary
            step: Training step number
        """
        if self.verbose:
            metrics_str = ", ".join(f"{k}: {v:.4f}" for k, v in metrics.items())
            print(f"Step {step} - Training: {metrics_str}")

    def log_evaluation(self, metrics: Dict[str, float], step: int):
        """
        Log evaluation metrics to console.

        Args:
            metrics: Evaluation metrics dictionary
            step: Training step when evaluation was performed
        """
        if self.verbose:
            metrics_str = ", ".join(f"{k}: {v:.4f}" for k, v in metrics.items())
            print(f"Step {step} - Evaluation: {metrics_str}")

    def finish(self):
        """No cleanup needed for console logging."""
        pass
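A minimal usage sketch for the console backend, assuming textpolicy 0.1.0 is installed; the expected output lines follow directly from the f-strings above.

from textpolicy.utils.logging import ConsoleLogger

logger = ConsoleLogger(verbose=True)
logger.log_metrics({"loss": 0.4213, "reward": 1.25}, step=100)
# -> Step 100 - Training: loss: 0.4213, reward: 1.2500
logger.log_evaluation({"mean_return": 195.0}, step=100)
# -> Step 100 - Evaluation: mean_return: 195.0000
logger.finish()  # no-op for the console backend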