textpolicy 0.0.1__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- textpolicy/__init__.py +52 -0
- textpolicy/__main__.py +8 -0
- textpolicy/algorithms/__init__.py +54 -0
- textpolicy/algorithms/grpo.py +642 -0
- textpolicy/algorithms/gspo.py +582 -0
- textpolicy/buffer/__init__.py +23 -0
- textpolicy/buffer/buffer.py +244 -0
- textpolicy/buffer/episode.py +383 -0
- textpolicy/buffer/sampling.py +438 -0
- textpolicy/buffer/storage.py +255 -0
- textpolicy/cli.py +67 -0
- textpolicy/environment/__init__.py +79 -0
- textpolicy/environment/base.py +110 -0
- textpolicy/environment/environment.py +46 -0
- textpolicy/environment/factory.py +103 -0
- textpolicy/environment/gym.py +106 -0
- textpolicy/environment/task_suites.py +51 -0
- textpolicy/environment/text_generation.py +789 -0
- textpolicy/environment/vectorized.py +253 -0
- textpolicy/generation/__init__.py +62 -0
- textpolicy/generation/lora.py +411 -0
- textpolicy/generation/mlx_generation.py +557 -0
- textpolicy/generation/reload.py +253 -0
- textpolicy/rewards/__init__.py +137 -0
- textpolicy/rewards/adapters.py +387 -0
- textpolicy/rewards/basic.py +214 -0
- textpolicy/rewards/integrated_system.py +338 -0
- textpolicy/rewards/mlx_batch_processor.py +447 -0
- textpolicy/rewards/registry.py +293 -0
- textpolicy/rewards/rollout_rewards.py +410 -0
- textpolicy/rewards/verifiers.py +369 -0
- textpolicy/rollout/__init__.py +44 -0
- textpolicy/rollout/aggregator.py +145 -0
- textpolicy/rollout/base.py +108 -0
- textpolicy/rollout/rollout.py +142 -0
- textpolicy/rollout/runner.py +280 -0
- textpolicy/rollout/strategy.py +208 -0
- textpolicy/rollout/worker.py +194 -0
- textpolicy/training/__init__.py +14 -0
- textpolicy/training/metrics.py +242 -0
- textpolicy/training/rollout_manager.py +78 -0
- textpolicy/training/trainer.py +684 -0
- textpolicy/utils/__init__.py +40 -0
- textpolicy/utils/benchmarking.py +489 -0
- textpolicy/utils/data.py +60 -0
- textpolicy/utils/debug.py +170 -0
- textpolicy/utils/environment.py +349 -0
- textpolicy/utils/logging/__init__.py +22 -0
- textpolicy/utils/logging/base.py +48 -0
- textpolicy/utils/logging/console.py +61 -0
- textpolicy/utils/logging/factory.py +133 -0
- textpolicy/utils/logging/multi.py +83 -0
- textpolicy/utils/logging/tensorboard.py +65 -0
- textpolicy/utils/logging/wandb.py +72 -0
- textpolicy/utils/memory.py +118 -0
- textpolicy/utils/performance.py +464 -0
- textpolicy/utils/timing.py +171 -0
- textpolicy/validate.py +101 -0
- textpolicy/validation/__init__.py +13 -0
- textpolicy/validation/logprob_validation.py +315 -0
- textpolicy-0.1.0.dist-info/METADATA +99 -0
- textpolicy-0.1.0.dist-info/RECORD +66 -0
- textpolicy-0.1.0.dist-info/entry_points.txt +2 -0
- textpolicy-0.0.1.dist-info/METADATA +0 -10
- textpolicy-0.0.1.dist-info/RECORD +0 -6
- {textpolicy-0.0.1.dist-info → textpolicy-0.1.0.dist-info}/WHEEL +0 -0
- {textpolicy-0.0.1.dist-info → textpolicy-0.1.0.dist-info}/licenses/LICENSE +0 -0
- {textpolicy-0.0.1.dist-info → textpolicy-0.1.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,253 @@
+"""
+Vectorized environment wrapper for MLX-RL to improve training performance.
+"""
+
+# Import debug utilities for proper debug management
+from textpolicy.utils.debug import vectorization_debug
+
+from typing import Any, Dict, Optional, Tuple, cast
+
+import gymnasium as gym
+try:
+    import mlx.core as mx  # type: ignore
+except ImportError:
+    mx = None  # MLX is optional; vectorized data conversion will error if invoked without MLX installed
+import numpy as np
+
+
+class VectorizedEnvironment:
+    """
+    MLX-compatible vectorized environment wrapper.
+
+    This wrapper provides parallel environment execution using Gymnasium's
+    vectorized environments, with MLX-optimized data conversion.
+    """
+
+    def __init__(self, vec_env):
+        """
+        Initialize vectorized environment.
+
+        Args:
+            vec_env: Gymnasium vectorized environment
+        """
+        vectorization_debug("Initializing VectorizedEnvironment")
+        self.vec_env = vec_env
+        self.num_envs = vec_env.num_envs
+
+    def reset(self) -> Tuple[Any, Dict[str, Any]]:
+        """
+        Reset all environments.
+
+        Fixed to match Environment base class contract: returns (observation, info) tuple
+        instead of dict. This enables VectorizedEnvironment to be used as drop-in
+        replacement for GymAdapter in the training system.
+
+        Previously returned dict caused silent failures - training system would unpack
+        dict keys ("observation", "info") instead of actual values, passing strings
+        to policy instead of observation arrays.
+        """
+        observations, infos = self.vec_env.reset()
+
+        # Convert to MLX-compatible format and return as tuple per Environment interface
+        batched_obs = self._to_mlx_batch(observations)
+        return batched_obs, infos
+
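The docstring above describes the old dict-return bug. A minimal, self-contained illustration of why that failure was silent (not package code; the values are placeholders):

    # Tuple-unpacking a two-key dict yields its KEYS, not its values,
    # which is how the old contract handed the string "observation" to the policy.
    result = {"observation": [[0.0, 0.0]], "info": {}}
    first, second = result
    assert first == "observation" and second == "info"

With the fixed contract the values are unpacked directly, matching Gymnasium: obs, infos = vec_env.reset().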
+    def step(self, actions: mx.array) -> Dict[str, Any]:
+        vectorization_debug(f"VectorizedEnvironment.step: actions shape {actions.shape}")
+        # Convert MLX actions to numpy for Gymnasium with zero-copy view for Apple Silicon efficiency
+        # Use direct conversion - MLX arrays are naturally numpy-compatible via __array__ protocol
+        actions_np = np.array(actions, copy=False)
+        vectorization_debug(f"Converted actions: shape {actions_np.shape}")
+
+        # Ensure proper shape for vectorized env based on action space
+        # Check if action space is discrete or continuous
+        from gymnasium.spaces import Discrete, MultiDiscrete, Box
+
+        if isinstance(self.vec_env.action_space, (Discrete, MultiDiscrete)):
+            # Discrete action space
+            if actions_np.ndim == 0:
+                # Scalar action - expand to [num_envs]
+                actions_np = np.full(self.num_envs, actions_np.item())
+            elif actions_np.ndim == 1:
+                # 1D actions are already in the correct shape [num_envs]
+                pass
+            elif actions_np.ndim == 2 and actions_np.shape[1] == 1:
+                # 2D array with shape [num_envs, 1] - flatten to [num_envs]
+                actions_np = actions_np.flatten()
+        elif isinstance(self.vec_env.action_space, Box):
+            # Continuous action space
+            if actions_np.ndim == 0:
+                # Scalar action - this shouldn't happen for continuous spaces, but let's handle it
+                actions_np = np.full((self.num_envs, self.vec_env.action_space.shape[0]), actions_np.item())
+            elif actions_np.ndim == 1:
+                # Reshape 1D array to [num_envs, action_dim]
+                actions_np = actions_np.reshape(self.num_envs, -1)
+        else:
+            # Unknown action space type - try to handle gracefully
+            if hasattr(self.vec_env.action_space, 'shape') and len(self.vec_env.action_space.shape) > 0:
+                # Assume continuous action space
+                if actions_np.ndim == 0:
+                    actions_np = np.full((self.num_envs, self.vec_env.action_space.shape[0]), actions_np.item())
+                elif actions_np.ndim == 1:
+                    actions_np = actions_np.reshape(self.num_envs, -1)
+            else:
+                # Assume discrete action space
+                if actions_np.ndim == 0:
+                    actions_np = np.full(self.num_envs, actions_np.item())
+                elif actions_np.ndim == 1:
+                    pass
+                elif actions_np.ndim == 2 and actions_np.shape[1] == 1:
+                    actions_np = actions_np.flatten()
+
+        vectorization_debug(f"Calling vec_env.step with actions shape: {actions_np.shape}")
+        observations, rewards, terminated, truncated, infos = self.vec_env.step(actions_np)
+
+        return {
+            "observation": self._to_mlx_batch(observations),
+            "reward": mx.array(rewards, dtype=mx.float32),
+            "terminated": mx.array(terminated, dtype=mx.bool_),
+            "truncated": mx.array(truncated, dtype=mx.bool_),
+            "info": infos
+        }
+
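To make the shape-normalization branches above concrete, a NumPy-only sketch of what each case produces (assuming num_envs = 4 and, for the Box case, action_dim = 2; these values are illustrative):

    import numpy as np

    num_envs = 4

    # Discrete: a scalar action is broadcast to one action per environment -> (4,)
    assert np.full(num_envs, np.asarray(1).item()).shape == (4,)

    # Discrete: a column vector [num_envs, 1] is flattened to [num_envs] -> (4,)
    assert np.zeros((4, 1)).flatten().shape == (4,)

    # Box: a flat vector is reshaped to [num_envs, action_dim] -> (4, 2)
    assert np.zeros(8).reshape(num_envs, -1).shape == (4, 2)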
+    def close(self):
+        """Close all environments."""
+        self.vec_env.close()
+
+    def _to_mlx_batch(self, observations: np.ndarray) -> mx.array:
+        """
+        Convert numpy observations to MLX array batch.
+
+        Args:
+            observations: Numpy array [num_envs, obs_dim]
+
+        Returns:
+            MLX array with proper contiguous memory layout
+        """
+        # Ensure contiguous array for MLX compatibility
+        if not observations.flags.c_contiguous:
+            observations = np.ascontiguousarray(observations)
+
+        return mx.array(observations, dtype=mx.float32)
+
+    @property
+    def observation_space(self):
+        """Get observation space."""
+        return self.vec_env.observation_space
+
+    @property
+    def action_space(self):
+        """Get action space."""
+        return self.vec_env.action_space
+
+
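The contiguity check in _to_mlx_batch exists because NumPy views (transposes, strided slices) are not always C-contiguous. A quick illustration, independent of the package:

    import numpy as np

    obs = np.ones((8, 4)).T              # a transposed view: F-contiguous, not C-contiguous
    assert not obs.flags.c_contiguous
    obs = np.ascontiguousarray(obs)      # copy into a C-contiguous buffer before mx.array(...)
    assert obs.flags.c_contiguous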
+def make_vectorized_env(env_id: str, num_envs: int = 1, use_async: bool = True, env_kwargs: Optional[Dict] = None) -> VectorizedEnvironment:
+    """
+    Create a vectorized environment.
+
+    Args:
+        env_id: Environment ID
+        num_envs: Number of environments to create
+        use_async: Whether to use async vectorized env (default: True)
+        env_kwargs: Optional kwargs to pass to individual environments
+
+    Returns:
+        VectorizedEnvironment instance
+    """
+    vectorization_debug(f"Creating vectorized env: {env_id}, num_envs={num_envs}, async={use_async}")
+
+    # Handle edge case: num_envs must be positive
+    if num_envs <= 0:
+        raise ValueError(f"num_envs must be positive, got {num_envs}")
+
+    # Prepare environment creation arguments
+    vec_kwargs = {"vectorization_mode": "async" if use_async else "sync"}
+    if env_kwargs:
+        vec_kwargs["env_kwargs"] = env_kwargs
+
+    # Create the vectorized environment using Gymnasium
+    try:
+        vec_env = gym.make_vec(env_id, num_envs=num_envs, **vec_kwargs)
+        vectorization_debug(f"Created gym vec env: {type(vec_env)}")
+        return VectorizedEnvironment(vec_env)
+    except Exception as e:
+        vectorization_debug(f"Failed to create vec env: {e}")
+        raise
+
+
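A usage sketch for make_vectorized_env, assuming Gymnasium's standard CartPole-v1 environment is installed (the environment id and printed shapes are illustrative, not taken from the package's documentation):

    from textpolicy.environment.vectorized import make_vectorized_env

    vec_env = make_vectorized_env("CartPole-v1", num_envs=4, use_async=False)
    obs, infos = vec_env.reset()          # obs is an mx.array of shape (4, 4) for CartPole
    print(vec_env.num_envs, obs.shape)
    vec_env.close()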
+class VectorizedCollector:
+    """
+    Efficient data collection using vectorized environments.
+
+    This collector can gather experience from multiple environments in parallel,
+    significantly speeding up training for sample-hungry algorithms like PPO.
+    """
+
+    def __init__(self, vec_env: VectorizedEnvironment, policy):
+        """
+        Initialize vectorized collector.
+
+        Args:
+            vec_env: Vectorized environment
+            policy: Policy to collect data with
+        """
+        self.vec_env = vec_env
+        self.policy = policy
+        self.num_envs = vec_env.num_envs
+
+    def collect_batch(self, batch_size: int) -> Dict[str, mx.array]:
+        """
+        Collect a batch of experiences using vectorized environments.
+
+        Args:
+            batch_size: Total number of steps to collect
+
+        Returns:
+            Dictionary containing batched experiences
+        """
+        # Calculate steps per environment
+        steps_per_env = batch_size // self.num_envs
+        if batch_size % self.num_envs != 0:
+            steps_per_env += 1
+
+        # Storage for collected data
+        observations = []
+        actions = []
+        rewards = []
+        dones = []
+        values = []
+        logprobs = []
+
+        # Reset environments (now returns tuple per Environment interface)
+        current_obs, reset_info = self.vec_env.reset()
+
+        for step in range(steps_per_env):
+            # Get actions from policy
+            action_mx, info = self.policy(current_obs, deterministic=False)
+
+            # Store current step data
+            observations.append(current_obs)
+            actions.append(action_mx)
+            values.append(info.get("value", mx.zeros(self.num_envs)))
+            logprobs.append(info.get("logprob", mx.zeros(self.num_envs)))
+
+            # Step environments
+            step_result = self.vec_env.step(action_mx)
+
+            rewards.append(step_result["reward"])
+            dones.append(step_result["terminated"] | step_result["truncated"])
+
+            # Update observations
+            current_obs = step_result["observation"]
+
+            # Handle episode resets (vectorized environments handle this automatically)
+
+        # Stack all collected data
+        return {
+            "observations": mx.stack(observations[:batch_size]),
+            "actions": mx.stack(actions[:batch_size]),
+            "rewards": mx.stack(rewards[:batch_size]),
+            "dones": mx.stack(dones[:batch_size]),
+            "values": mx.stack(values[:batch_size]),
+            "logprobs": mx.stack(logprobs[:batch_size])
+        }
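The collection loop above implies the policy interface collect_batch expects: policy(obs, deterministic=False) returning (actions, info), where info may carry "value" and "logprob". A hedged end-to-end sketch with a random placeholder policy (requires MLX; reuses vec_env from the previous example, and the two-action space matches CartPole only by assumption):

    import mlx.core as mx
    from textpolicy.environment.vectorized import VectorizedCollector

    def random_policy(obs, deterministic=False):
        n = obs.shape[0]
        actions = mx.random.randint(0, 2, (n,))  # placeholder: two discrete actions
        return actions, {"value": mx.zeros(n), "logprob": mx.zeros(n)}

    collector = VectorizedCollector(vec_env, random_policy)
    batch = collector.collect_batch(batch_size=64)
    # Each entry is stacked along the step axis: [steps_per_env, num_envs, ...]
    print(batch["observations"].shape, batch["rewards"].shape)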
@@ -0,0 +1,62 @@
+# textpolicy/generation/__init__.py
+"""
+Pure MLX-LM text generation functions for RL training.
+
+Following TextPolicy design principles:
+- Pure function composition over classes
+- Zero abstraction cost for MLX optimization
+- Direct integration with GRPO trainer
+- LoRA/QLoRA support for memory efficiency
+
+All functions are pure and composable with our existing trainer system.
+"""
+
+# Core MLX-LM generation functions
+from .mlx_generation import (
+    load_model,
+    generate_tokens,
+    compute_logprobs,
+    encode,
+    decode,
+    create_policy,
+    compute_reward,
+)
+
+# LoRA/QLoRA pure functions
+from .lora import (
+    apply_lora,
+    freeze_base,
+    extract_params,
+    merge_weights,
+    create_lora_setup,
+    create_qlora_setup
+)
+
+# LoRA utility functions (advanced use only)
+from .reload import (
+    save_adapters,
+    reload_model
+)
+
+__all__ = [
+    # Core generation functions
+    "load_model",
+    "generate_tokens",
+    "compute_logprobs",
+    "encode",
+    "decode",
+    "create_policy",
+    "compute_reward",
+
+    # LoRA functions
+    "apply_lora",
+    "freeze_base",
+    "extract_params",
+    "merge_weights",
+    "create_lora_setup",
+    "create_qlora_setup",
+
+    # Advanced LoRA utilities
+    "save_adapters",
+    "reload_model",
+]
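These re-exports define the public surface of textpolicy.generation. A hedged composition sketch; the call signatures, return shapes, and arguments below are assumptions inferred from the function names only, not the package's documented API:

    from textpolicy.generation import (
        load_model, create_policy, apply_lora, freeze_base, save_adapters,
    )

    # Hypothetical flow: every signature here is an illustrative assumption.
    model, tokenizer = load_model("<mlx-model-path-or-repo>")   # assumed (model, tokenizer) return
    model = apply_lora(model)                                    # assumed to attach LoRA adapters
    freeze_base(model)                                           # assumed to freeze non-LoRA weights
    policy = create_policy(model, tokenizer)                     # assumed to yield a policy callable for the trainer
    save_adapters(model, "adapters.safetensors")                 # path argument is an assumption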