textpolicy-0.0.1-py3-none-any.whl → textpolicy-0.1.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68)
  1. textpolicy/__init__.py +53 -0
  2. textpolicy/__main__.py +8 -0
  3. textpolicy/algorithms/__init__.py +54 -0
  4. textpolicy/algorithms/grpo.py +642 -0
  5. textpolicy/algorithms/gspo.py +582 -0
  6. textpolicy/buffer/__init__.py +23 -0
  7. textpolicy/buffer/buffer.py +244 -0
  8. textpolicy/buffer/episode.py +383 -0
  9. textpolicy/buffer/sampling.py +438 -0
  10. textpolicy/buffer/storage.py +255 -0
  11. textpolicy/cli.py +67 -0
  12. textpolicy/environment/__init__.py +79 -0
  13. textpolicy/environment/base.py +110 -0
  14. textpolicy/environment/environment.py +46 -0
  15. textpolicy/environment/factory.py +103 -0
  16. textpolicy/environment/gym.py +106 -0
  17. textpolicy/environment/task_suites.py +51 -0
  18. textpolicy/environment/text_generation.py +797 -0
  19. textpolicy/environment/vectorized.py +253 -0
  20. textpolicy/generation/__init__.py +62 -0
  21. textpolicy/generation/lora.py +411 -0
  22. textpolicy/generation/mlx_generation.py +557 -0
  23. textpolicy/generation/reload.py +253 -0
  24. textpolicy/rewards/__init__.py +137 -0
  25. textpolicy/rewards/adapters.py +387 -0
  26. textpolicy/rewards/basic.py +214 -0
  27. textpolicy/rewards/integrated_system.py +338 -0
  28. textpolicy/rewards/mlx_batch_processor.py +447 -0
  29. textpolicy/rewards/registry.py +293 -0
  30. textpolicy/rewards/rollout_rewards.py +410 -0
  31. textpolicy/rewards/verifiers.py +369 -0
  32. textpolicy/rollout/__init__.py +44 -0
  33. textpolicy/rollout/aggregator.py +145 -0
  34. textpolicy/rollout/base.py +108 -0
  35. textpolicy/rollout/rollout.py +142 -0
  36. textpolicy/rollout/runner.py +280 -0
  37. textpolicy/rollout/strategy.py +208 -0
  38. textpolicy/rollout/worker.py +194 -0
  39. textpolicy/training/__init__.py +14 -0
  40. textpolicy/training/metrics.py +242 -0
  41. textpolicy/training/rollout_manager.py +78 -0
  42. textpolicy/training/trainer.py +684 -0
  43. textpolicy/utils/__init__.py +40 -0
  44. textpolicy/utils/benchmarking.py +489 -0
  45. textpolicy/utils/data.py +60 -0
  46. textpolicy/utils/debug.py +170 -0
  47. textpolicy/utils/environment.py +349 -0
  48. textpolicy/utils/logging/__init__.py +22 -0
  49. textpolicy/utils/logging/base.py +48 -0
  50. textpolicy/utils/logging/console.py +61 -0
  51. textpolicy/utils/logging/factory.py +133 -0
  52. textpolicy/utils/logging/multi.py +83 -0
  53. textpolicy/utils/logging/tensorboard.py +65 -0
  54. textpolicy/utils/logging/wandb.py +72 -0
  55. textpolicy/utils/memory.py +118 -0
  56. textpolicy/utils/performance.py +464 -0
  57. textpolicy/utils/timing.py +171 -0
  58. textpolicy/validate.py +101 -0
  59. textpolicy/validation/__init__.py +13 -0
  60. textpolicy/validation/logprob_validation.py +315 -0
  61. textpolicy-0.1.1.dist-info/METADATA +109 -0
  62. textpolicy-0.1.1.dist-info/RECORD +66 -0
  63. {textpolicy-0.0.1.dist-info → textpolicy-0.1.1.dist-info}/WHEEL +1 -1
  64. textpolicy-0.1.1.dist-info/entry_points.txt +2 -0
  65. textpolicy-0.0.1.dist-info/METADATA +0 -10
  66. textpolicy-0.0.1.dist-info/RECORD +0 -6
  67. {textpolicy-0.0.1.dist-info → textpolicy-0.1.1.dist-info}/licenses/LICENSE +0 -0
  68. {textpolicy-0.0.1.dist-info → textpolicy-0.1.1.dist-info}/top_level.txt +0 -0
textpolicy/environment/vectorized.py
@@ -0,0 +1,253 @@
+ """
+ Vectorized environment wrapper for MLX-RL to improve training performance.
+ """
+
+ # Import debug utilities for proper debug management
+ from textpolicy.utils.debug import vectorization_debug
+
+ from typing import Any, Dict, Optional, Tuple, cast
+
+ import gymnasium as gym
+ try:
+     import mlx.core as mx  # type: ignore
+ except ImportError:
+     mx = None  # MLX is optional; vectorized data conversion will error if invoked without MLX installed
+ import numpy as np
+
+
+ class VectorizedEnvironment:
+     """
+     MLX-compatible vectorized environment wrapper.
+
+     This wrapper provides parallel environment execution using Gymnasium's
+     vectorized environments, with MLX-optimized data conversion.
+     """
+
+     def __init__(self, vec_env):
+         """
+         Initialize vectorized environment.
+
+         Args:
+             vec_env: Gymnasium vectorized environment
+         """
+         vectorization_debug("Initializing VectorizedEnvironment")
+         self.vec_env = vec_env
+         self.num_envs = vec_env.num_envs
+
+     def reset(self) -> Tuple[Any, Dict[str, Any]]:
+         """
+         Reset all environments.
+
+         Fixed to match Environment base class contract: returns (observation, info) tuple
+         instead of dict. This enables VectorizedEnvironment to be used as a drop-in
+         replacement for GymAdapter in the training system.
+
+         The previously returned dict caused silent failures - the training system would unpack
+         dict keys ("observation", "info") instead of actual values, passing strings
+         to the policy instead of observation arrays.
+         """
+         observations, infos = self.vec_env.reset()
+
+         # Convert to MLX-compatible format and return as tuple per Environment interface
+         batched_obs = self._to_mlx_batch(observations)
+         return batched_obs, infos
+
+     def step(self, actions: mx.array) -> Dict[str, Any]:
+         vectorization_debug(f"VectorizedEnvironment.step: actions shape {actions.shape}")
+         # Convert MLX actions to numpy for Gymnasium with zero-copy view for Apple Silicon efficiency
+         # Use direct conversion - MLX arrays are naturally numpy-compatible via __array__ protocol
+         actions_np = np.array(actions, copy=False)
+         vectorization_debug(f"Converted actions: shape {actions_np.shape}")
+
+         # Ensure proper shape for vectorized env based on action space
+         # Check if action space is discrete or continuous
+         from gymnasium.spaces import Discrete, MultiDiscrete, Box
+
+         if isinstance(self.vec_env.action_space, (Discrete, MultiDiscrete)):
+             # Discrete action space
+             if actions_np.ndim == 0:
+                 # Scalar action - expand to [num_envs]
+                 actions_np = np.full(self.num_envs, actions_np.item())
+             elif actions_np.ndim == 1:
+                 # 1D actions are already in the correct shape [num_envs]
+                 pass
+             elif actions_np.ndim == 2 and actions_np.shape[1] == 1:
+                 # 2D array with shape [num_envs, 1] - flatten to [num_envs]
+                 actions_np = actions_np.flatten()
+         elif isinstance(self.vec_env.action_space, Box):
+             # Continuous action space
+             if actions_np.ndim == 0:
+                 # Scalar action - this shouldn't happen for continuous spaces, but let's handle it
+                 actions_np = np.full((self.num_envs, self.vec_env.action_space.shape[0]), actions_np.item())
+             elif actions_np.ndim == 1:
+                 # Reshape 1D array to [num_envs, action_dim]
+                 actions_np = actions_np.reshape(self.num_envs, -1)
+         else:
+             # Unknown action space type - try to handle gracefully
+             if hasattr(self.vec_env.action_space, 'shape') and len(self.vec_env.action_space.shape) > 0:
+                 # Assume continuous action space
+                 if actions_np.ndim == 0:
+                     actions_np = np.full((self.num_envs, self.vec_env.action_space.shape[0]), actions_np.item())
+                 elif actions_np.ndim == 1:
+                     actions_np = actions_np.reshape(self.num_envs, -1)
+             else:
+                 # Assume discrete action space
+                 if actions_np.ndim == 0:
+                     actions_np = np.full(self.num_envs, actions_np.item())
+                 elif actions_np.ndim == 1:
+                     pass
+                 elif actions_np.ndim == 2 and actions_np.shape[1] == 1:
+                     actions_np = actions_np.flatten()
+
+         vectorization_debug(f"Calling vec_env.step with actions shape: {actions_np.shape}")
+         observations, rewards, terminated, truncated, infos = self.vec_env.step(actions_np)
+
+         return {
+             "observation": self._to_mlx_batch(observations),
+             "reward": mx.array(rewards, dtype=mx.float32),
+             "terminated": mx.array(terminated, dtype=mx.bool_),
+             "truncated": mx.array(truncated, dtype=mx.bool_),
+             "info": infos
+         }
+
+     def close(self):
+         """Close all environments."""
+         self.vec_env.close()
+
+     def _to_mlx_batch(self, observations: np.ndarray) -> mx.array:
+         """
+         Convert numpy observations to MLX array batch.
+
+         Args:
+             observations: Numpy array [num_envs, obs_dim]
+
+         Returns:
+             MLX array with proper contiguous memory layout
+         """
+         # Ensure contiguous array for MLX compatibility
+         if not observations.flags.c_contiguous:
+             observations = np.ascontiguousarray(observations)
+
+         return mx.array(observations, dtype=mx.float32)
+
+     @property
+     def observation_space(self):
+         """Get observation space."""
+         return self.vec_env.observation_space
+
+     @property
+     def action_space(self):
+         """Get action space."""
+         return self.vec_env.action_space
+
+
+ def make_vectorized_env(env_id: str, num_envs: int = 1, use_async: bool = True, env_kwargs: Optional[Dict] = None) -> VectorizedEnvironment:
+     """
+     Create a vectorized environment.
+
+     Args:
+         env_id: Environment ID
+         num_envs: Number of environments to create
+         use_async: Whether to use async vectorized env (default: True)
+         env_kwargs: Optional kwargs to pass to individual environments
+
+     Returns:
+         VectorizedEnvironment instance
+     """
+     vectorization_debug(f"Creating vectorized env: {env_id}, num_envs={num_envs}, async={use_async}")
+
+     # Handle edge case: num_envs must be positive
+     if num_envs <= 0:
+         raise ValueError(f"num_envs must be positive, got {num_envs}")
+
+     # Prepare environment creation arguments
+     vec_kwargs = {"vectorization_mode": "async" if use_async else "sync"}
+     if env_kwargs:
+         vec_kwargs["env_kwargs"] = env_kwargs
+
+     # Create the vectorized environment using Gymnasium
+     try:
+         vec_env = gym.make_vec(env_id, num_envs=num_envs, **vec_kwargs)
+         vectorization_debug(f"Created gym vec env: {type(vec_env)}")
+         return VectorizedEnvironment(vec_env)
+     except Exception as e:
+         vectorization_debug(f"Failed to create vec env: {e}")
+         raise
+
+
+ class VectorizedCollector:
+     """
+     Efficient data collection using vectorized environments.
+
+     This collector can gather experience from multiple environments in parallel,
+     significantly speeding up training for sample-hungry algorithms like PPO.
+     """
+
+     def __init__(self, vec_env: VectorizedEnvironment, policy):
+         """
+         Initialize vectorized collector.
+
+         Args:
+             vec_env: Vectorized environment
+             policy: Policy to collect data with
+         """
+         self.vec_env = vec_env
+         self.policy = policy
+         self.num_envs = vec_env.num_envs
+
+     def collect_batch(self, batch_size: int) -> Dict[str, mx.array]:
+         """
+         Collect a batch of experiences using vectorized environments.
+
+         Args:
+             batch_size: Total number of steps to collect
+
+         Returns:
+             Dictionary containing batched experiences
+         """
+         # Calculate steps per environment
+         steps_per_env = batch_size // self.num_envs
+         if batch_size % self.num_envs != 0:
+             steps_per_env += 1
+
+         # Storage for collected data
+         observations = []
+         actions = []
+         rewards = []
+         dones = []
+         values = []
+         logprobs = []
+
+         # Reset environments (now returns tuple per Environment interface)
+         current_obs, reset_info = self.vec_env.reset()
+
+         for step in range(steps_per_env):
+             # Get actions from policy
+             action_mx, info = self.policy(current_obs, deterministic=False)
+
+             # Store current step data
+             observations.append(current_obs)
+             actions.append(action_mx)
+             values.append(info.get("value", mx.zeros(self.num_envs)))
+             logprobs.append(info.get("logprob", mx.zeros(self.num_envs)))
+
+             # Step environments
+             step_result = self.vec_env.step(action_mx)
+
+             rewards.append(step_result["reward"])
+             dones.append(step_result["terminated"] | step_result["truncated"])
+
+             # Update observations
+             current_obs = step_result["observation"]
+
+             # Handle episode resets (vectorized environments handle this automatically)
+
+         # Stack all collected data
+         return {
+             "observations": mx.stack(observations[:batch_size]),
+             "actions": mx.stack(actions[:batch_size]),
+             "rewards": mx.stack(rewards[:batch_size]),
+             "dones": mx.stack(dones[:batch_size]),
+             "values": mx.stack(values[:batch_size]),
+             "logprobs": mx.stack(logprobs[:batch_size])
+         }
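
For reviewers who want to see how the two new classes are meant to fit together, here is a minimal usage sketch. It is not part of the diff: it assumes gymnasium and mlx are installed, uses the standard CartPole-v1 environment as an arbitrary example, and substitutes a hypothetical random_policy for a trained policy. The only contract taken from the code above is that the policy returns an (actions, info) pair and that VectorizedCollector drives reset() and step().

# Hypothetical usage sketch - not part of the textpolicy diff above.
import mlx.core as mx
from textpolicy.environment.vectorized import make_vectorized_env, VectorizedCollector

# Four CartPole instances; sync mode keeps the demo in a single process.
vec_env = make_vectorized_env("CartPole-v1", num_envs=4, use_async=False)

def random_policy(obs: mx.array, deterministic: bool = False):
    """Stand-in policy: uniform random discrete actions plus dummy value/logprob info."""
    actions = mx.random.randint(0, 2, (obs.shape[0],))
    info = {"value": mx.zeros(obs.shape[0]), "logprob": mx.zeros(obs.shape[0])}
    return actions, info

collector = VectorizedCollector(vec_env, random_policy)
batch = collector.collect_batch(batch_size=32)  # 32 total steps split across 4 envs
print(batch["observations"].shape, batch["rewards"].shape)
vec_env.close()

Since collect_batch stacks per-step arrays, the returned tensors carry a leading steps_per_env axis followed by a num_envs axis; a trainer consuming this batch would typically flatten those two axes. use_async=True remains the module's default; sync mode is chosen here only to keep the example in one process.
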
textpolicy/generation/__init__.py
@@ -0,0 +1,62 @@
+ # textpolicy/generation/__init__.py
+ """
+ Pure MLX-LM text generation functions for RL training.
+
+ Following TextPolicy design principles:
+ - Pure function composition over classes
+ - Zero abstraction cost for MLX optimization
+ - Direct integration with GRPO trainer
+ - LoRA/QLoRA support for memory efficiency
+
+ All functions are pure and composable with our existing trainer system.
+ """
+
+ # Core MLX-LM generation functions
+ from .mlx_generation import (
+     load_model,
+     generate_tokens,
+     compute_logprobs,
+     encode,
+     decode,
+     create_policy,
+     compute_reward,
+ )
+
+ # LoRA/QLoRA pure functions
+ from .lora import (
+     apply_lora,
+     freeze_base,
+     extract_params,
+     merge_weights,
+     create_lora_setup,
+     create_qlora_setup
+ )
+
+ # LoRA utility functions (advanced use only)
+ from .reload import (
+     save_adapters,
+     reload_model
+ )
+
+ __all__ = [
+     # Core generation functions
+     "load_model",
+     "generate_tokens",
+     "compute_logprobs",
+     "encode",
+     "decode",
+     "create_policy",
+     "compute_reward",
+
+     # LoRA functions
+     "apply_lora",
+     "freeze_base",
+     "extract_params",
+     "merge_weights",
+     "create_lora_setup",
+     "create_qlora_setup",
+
+     # Advanced LoRA utilities
+     "save_adapters",
+     "reload_model",
+ ]
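
The practical effect of this new __init__.py is that trainer code can import the whole generation and LoRA surface from one place. As a quick check of that surface, here is a hedged smoke-test sketch; it is not from the package, it assumes textpolicy and its MLX dependencies are installed, and it deliberately calls none of the functions because their signatures are not visible in this diff.

# Hypothetical smoke test: verify the re-exports promised by textpolicy/generation/__init__.py.
# Nothing here calls the functions - their signatures are not shown in this diff.
import textpolicy.generation as gen

expected = {
    "load_model", "generate_tokens", "compute_logprobs", "encode", "decode",
    "create_policy", "compute_reward",
    "apply_lora", "freeze_base", "extract_params", "merge_weights",
    "create_lora_setup", "create_qlora_setup",
    "save_adapters", "reload_model",
}
missing = expected - set(gen.__all__)
print("missing exports:", sorted(missing) or "none")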