textpolicy 0.0.1__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- textpolicy/__init__.py +52 -0
- textpolicy/__main__.py +8 -0
- textpolicy/algorithms/__init__.py +54 -0
- textpolicy/algorithms/grpo.py +642 -0
- textpolicy/algorithms/gspo.py +582 -0
- textpolicy/buffer/__init__.py +23 -0
- textpolicy/buffer/buffer.py +244 -0
- textpolicy/buffer/episode.py +383 -0
- textpolicy/buffer/sampling.py +438 -0
- textpolicy/buffer/storage.py +255 -0
- textpolicy/cli.py +67 -0
- textpolicy/environment/__init__.py +79 -0
- textpolicy/environment/base.py +110 -0
- textpolicy/environment/environment.py +46 -0
- textpolicy/environment/factory.py +103 -0
- textpolicy/environment/gym.py +106 -0
- textpolicy/environment/task_suites.py +51 -0
- textpolicy/environment/text_generation.py +789 -0
- textpolicy/environment/vectorized.py +253 -0
- textpolicy/generation/__init__.py +62 -0
- textpolicy/generation/lora.py +411 -0
- textpolicy/generation/mlx_generation.py +557 -0
- textpolicy/generation/reload.py +253 -0
- textpolicy/rewards/__init__.py +137 -0
- textpolicy/rewards/adapters.py +387 -0
- textpolicy/rewards/basic.py +214 -0
- textpolicy/rewards/integrated_system.py +338 -0
- textpolicy/rewards/mlx_batch_processor.py +447 -0
- textpolicy/rewards/registry.py +293 -0
- textpolicy/rewards/rollout_rewards.py +410 -0
- textpolicy/rewards/verifiers.py +369 -0
- textpolicy/rollout/__init__.py +44 -0
- textpolicy/rollout/aggregator.py +145 -0
- textpolicy/rollout/base.py +108 -0
- textpolicy/rollout/rollout.py +142 -0
- textpolicy/rollout/runner.py +280 -0
- textpolicy/rollout/strategy.py +208 -0
- textpolicy/rollout/worker.py +194 -0
- textpolicy/training/__init__.py +14 -0
- textpolicy/training/metrics.py +242 -0
- textpolicy/training/rollout_manager.py +78 -0
- textpolicy/training/trainer.py +684 -0
- textpolicy/utils/__init__.py +40 -0
- textpolicy/utils/benchmarking.py +489 -0
- textpolicy/utils/data.py +60 -0
- textpolicy/utils/debug.py +170 -0
- textpolicy/utils/environment.py +349 -0
- textpolicy/utils/logging/__init__.py +22 -0
- textpolicy/utils/logging/base.py +48 -0
- textpolicy/utils/logging/console.py +61 -0
- textpolicy/utils/logging/factory.py +133 -0
- textpolicy/utils/logging/multi.py +83 -0
- textpolicy/utils/logging/tensorboard.py +65 -0
- textpolicy/utils/logging/wandb.py +72 -0
- textpolicy/utils/memory.py +118 -0
- textpolicy/utils/performance.py +464 -0
- textpolicy/utils/timing.py +171 -0
- textpolicy/validate.py +101 -0
- textpolicy/validation/__init__.py +13 -0
- textpolicy/validation/logprob_validation.py +315 -0
- textpolicy-0.1.0.dist-info/METADATA +99 -0
- textpolicy-0.1.0.dist-info/RECORD +66 -0
- textpolicy-0.1.0.dist-info/entry_points.txt +2 -0
- textpolicy-0.0.1.dist-info/METADATA +0 -10
- textpolicy-0.0.1.dist-info/RECORD +0 -6
- {textpolicy-0.0.1.dist-info → textpolicy-0.1.0.dist-info}/WHEEL +0 -0
- {textpolicy-0.0.1.dist-info → textpolicy-0.1.0.dist-info}/licenses/LICENSE +0 -0
- {textpolicy-0.0.1.dist-info → textpolicy-0.1.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,438 @@
# textpolicy/buffer/sampling.py
"""
Buffer sampling methods for training data retrieval.

Designed for MLX and Apple Silicon with efficient tensor conversions.
"""

from typing import List, Dict
import random
import mlx.core as mx  # type: ignore
from .episode import Episode


class BufferSampler:
    """
    Handles all data sampling operations for the buffer.

    Provides multiple sampling strategies:
    - Full buffer sampling
    - Latest N steps sampling
    - Episode-based sampling

    Uses MLX tensor operations and Apple Silicon-friendly patterns.
    """

    def __init__(self, episodes: List[Episode]):
        """
        Initialize sampler with episode storage reference.

        Args:
            episodes: Reference to list of complete episodes
        """
        self.episodes = episodes

    def sample_all(self) -> Dict[str, mx.array]:
        """
        Sample all stored episodes as a single concatenated batch.

        Returns all transitions in chronological order:
        - Oldest episode → Newest episode
        - Each episode: first step → last step

        Returns:
            Dict of MLX arrays with all transitions

        Raises:
            ValueError: If buffer is empty
        """
        if not self.episodes:
            raise ValueError("Buffer empty. No episodes to sample.")

        # Concatenate all episodes by collecting data directly (bypass Episode validation).
        # Episodes can have different optional fields, so collect what exists across all episodes.
        all_obs = []
        all_act = []
        all_rew = []
        all_next_obs = []
        all_done = []
        all_timeout = []
        all_logprob = []
        all_value = []
        all_entropy = []

        # Only include optional fields that exist in ALL episodes for consistent sampling.
        # This matches the buffer's "all-or-nothing" design philosophy per episode.
        has_logprob = all(episode.logprob is not None for episode in self.episodes)
        has_value = all(episode.value is not None for episode in self.episodes)
        has_entropy = all(episode.entropy is not None for episode in self.episodes)

        # Collect transitions from all episodes
        for episode in self.episodes:
            for i in range(len(episode)):
                all_obs.append(episode.obs[i])
                all_act.append(episode.act[i])
                all_rew.append(episode.rew[i])
                all_next_obs.append(episode.next_obs[i])
                all_done.append(episode.done[i])
                all_timeout.append(episode.timeout[i])

                # Only collect optional fields that exist in ALL episodes
                if has_logprob:
                    all_logprob.append(episode.logprob[i])
                if has_value:
                    all_value.append(episode.value[i])
                if has_entropy:
                    all_entropy.append(episode.entropy[i])

        # Create Episode directly with collected data (bypassing validation during construction)
        all_transitions = Episode()
        all_transitions.obs = all_obs
        all_transitions.act = all_act
        all_transitions.rew = all_rew
        all_transitions.next_obs = all_next_obs
        all_transitions.done = all_done
        all_transitions.timeout = all_timeout

        # Only set optional fields that exist in ALL episodes
        if has_logprob:
            all_transitions.logprob = all_logprob
        if has_value:
            all_transitions.value = all_value
        if has_entropy:
            all_transitions.entropy = all_entropy

        return all_transitions.to_tensor_dict()

    def sample_latest_steps(self, n: int) -> Dict[str, mx.array]:
        """
        Sample the N most recent transitions across episodes.

        Useful for on-policy RL algorithms that use recent experience for stable training.

        Args:
            n: Number of steps to sample (must be > 0)

        Returns:
            Dict of MLX arrays with the latest n steps in chronological order
            (oldest → newest, ensuring temporal consistency)

        Raises:
            ValueError: If buffer is empty, n <= 0, or not enough steps
        """
        if not self.episodes:
            raise ValueError("Buffer is empty. No recent episodes to sample.")
        if n <= 0:
            raise ValueError("Number of steps (n) to sample must be greater than 0.")

        # Collect steps in reverse chronological order (newest first)
        steps = []
        for episode in reversed(self.episodes):
            for i in reversed(range(len(episode))):
                step_dict = {
                    "obs": episode.obs[i],
                    "act": episode.act[i],
                    "rew": episode.rew[i],
                    "next_obs": episode.next_obs[i],
                    "done": episode.done[i],
                    "timeout": episode.timeout[i],
                }

                # Add optional fields if present
                if episode.logprob is not None:
                    step_dict["logprob"] = episode.logprob[i]
                if episode.value is not None:
                    step_dict["value"] = episode.value[i]
                if episode.entropy is not None:
                    step_dict["entropy"] = episode.entropy[i]

                steps.append(step_dict)

                # Stop when we have enough steps
                if len(steps) >= n:
                    break

            if len(steps) >= n:
                break

        if not steps:
            raise ValueError("No steps available to sample")

        # Reverse to get chronological order (oldest → newest)
        chronological_steps = list(reversed(steps[:n]))

        # Batch convert to MLX arrays:
        # collect the values first, then convert key by key for memory efficiency.
        batch = {}
        for key in chronological_steps[0].keys():
            # Extract all values for this key at once
            values = [step[key] for step in chronological_steps]
            # One stacked conversion per key is cheaper than growing an array incrementally
            batch[key] = mx.stack([mx.array(v) for v in values])

        return batch

    def sample_episodes(self, k: int, order: str = 'asc') -> Dict[str, mx.array]:
        """
        Sample up to k complete episodes.

        Useful for episode-based training or evaluation analysis.

        Args:
            k: Number of episodes to sample (must be > 0)
            order: 'asc' for oldest first, 'desc' for newest first

        Returns:
            Dict of MLX arrays with concatenated transitions from selected episodes

        Raises:
            ValueError: If buffer is empty, k <= 0, or invalid order
        """
        if not self.episodes:
            raise ValueError("Buffer is empty. No episodes to sample.")
        if k <= 0:
            raise ValueError("k must be positive.")
        if order not in ('asc', 'desc'):
            raise ValueError("order must be 'asc' or 'desc'")

        # Select episodes based on order
        if order == 'asc':
            selected_episodes = self.episodes[:k]  # Oldest k episodes
        else:  # 'desc'
            selected_episodes = self.episodes[-k:]  # Newest k episodes

        if not selected_episodes:
            raise ValueError("No episodes matched the criteria.")

        # Concatenate all steps from selected episodes
        all_transitions = Episode()
        for episode in selected_episodes:
            for i in range(len(episode)):
                all_transitions.append(
                    obs=episode.obs[i],
                    act=episode.act[i],
                    rew=episode.rew[i],
                    next_obs=episode.next_obs[i],
                    done=episode.done[i],
                    timeout=episode.timeout[i],
                    logprob=episode.logprob[i] if episode.logprob is not None else None,
                    value=episode.value[i] if episode.value is not None else None,
                    entropy=episode.entropy[i] if episode.entropy is not None else None
                )

        return all_transitions.to_tensor_dict()

    def sample_sequences(
        self,
        batch_size: int,
        seq_len: int,
        recent_first: bool = True,
        drop_incomplete: bool = True,
        dreamerv3_mode: bool = False,
    ) -> Dict[str, mx.array]:
        """
        Sample contiguous sequences of length `seq_len`.

        Standard mode: sample without crossing episode boundaries (for most algorithms).
        DreamerV3 mode: sample across episode boundaries to include terminals (for continue head supervision).

        Args:
            batch_size: Number of sequences to sample.
            seq_len: Length of each sequence (T). Must be > 0.
            recent_first: Prefer sampling from most recent episodes first.
            drop_incomplete: If True, skip episodes shorter than seq_len (standard mode only).
            dreamerv3_mode: If True, sample across episode boundaries to include terminals.

        Returns:
            Dict of MLX arrays with keys: obs, act, rew, next_obs, done, timeout,
            plus optional keys (logprob, value, entropy) included only if present.

        Raises:
            ValueError: If the buffer is empty, the inputs are invalid, or there is not enough data.
        """
        if not self.episodes:
            raise ValueError("Buffer is empty. No episodes to sample.")
        if batch_size <= 0 or seq_len <= 0:
            raise ValueError("batch_size and seq_len must be positive.")

        if dreamerv3_mode:
            return self._sample_dreamerv3_sequences(batch_size, seq_len, recent_first)

        # Choose episode order per recency preference.
        episodes_iter: List[Episode]
        episodes_iter = list(reversed(self.episodes)) if recent_first else list(self.episodes)

        # Collect one contiguous window per episode until the batch is filled.
        sequences = []
        episodes_used: List[Episode] = []

        for ep in episodes_iter:
            n = len(ep)
            if n < seq_len:
                if drop_incomplete:
                    continue
                # Padding/masking path intentionally not implemented for simplicity/perf.
                continue

            # Sample a random start index within the episode rather than always taking the
            # tail; tail-only windows would overrepresent terminal steps (a bias consumers
            # such as DreamerV3 are sensitive to). Recency is still preserved at the
            # episode level via episodes_iter.
            max_start = n - seq_len
            start = random.randint(0, max_start) if max_start > 0 else 0
            end = start + seq_len  # fixed-length window [start, start + seq_len)
            seq = {
                'obs': ep.obs[start:end],
                'act': ep.act[start:end],
                'rew': ep.rew[start:end],
                'next_obs': ep.next_obs[start:end],
                'done': ep.done[start:end],
                'timeout': ep.timeout[start:end],
            }

            # Optional fields: only include if present in the episode
            if ep.logprob is not None:
                seq['logprob'] = ep.logprob[start:end]
            if ep.value is not None:
                seq['value'] = ep.value[start:end]
            if ep.entropy is not None:
                seq['entropy'] = ep.entropy[start:end]

            sequences.append(seq)
            episodes_used.append(ep)
            if len(sequences) >= batch_size:
                break

        if not sequences:
            raise ValueError("No sequences available to sample.")

        # Determine the keys that exist across all sampled sequences for consistent batching.
        all_keys = set(sequences[0].keys())
        for s in sequences[1:]:
            all_keys &= set(s.keys())

        # Convert to MLX arrays with shape [B, T, ...]
        batch: Dict[str, mx.array] = {}
        for key in all_keys:
            # First convert each sequence to [T, ...]
            per_seq = []
            for s in sequences:
                # One mx.array() per time step, then a single stack along the time axis,
                # to minimize Python overhead while keeping explicit control of dimensions.
                per_seq.append(mx.stack([mx.array(v) for v in s[key]], axis=0))  # [T, ...]
            # Stack sequences along the batch dimension
            batch[key] = mx.stack(per_seq, axis=0)  # [B, T, ...]

        return batch

    def _sample_dreamerv3_sequences(
        self,
        batch_size: int,
        seq_len: int,
        recent_first: bool
    ) -> Dict[str, mx.array]:
        """
        Sample sequences for DreamerV3 by concatenating episodes to form continuous trajectories.

        Unlike standard sampling, this allows sequences to span episode boundaries,
        ensuring episode terminals appear within training sequences for continue head supervision.

        This mirrors DreamerV3's original replay buffer behavior where episodes are stored
        continuously and sampling naturally includes episode boundaries.
        """
        import random

        # Create continuous trajectory by concatenating recent episodes
        episodes_iter = list(reversed(self.episodes)) if recent_first else list(self.episodes)

        # Concatenate episodes into a continuous stream
        continuous_data = {
            'obs': [], 'act': [], 'rew': [], 'next_obs': [], 'done': [], 'timeout': [],
            'logprob': [], 'value': [], 'entropy': []
        }

        for ep in episodes_iter:
            continuous_data['obs'].extend(ep.obs)
            continuous_data['act'].extend(ep.act)
            continuous_data['rew'].extend(ep.rew)
            continuous_data['next_obs'].extend(ep.next_obs)
            continuous_data['done'].extend(ep.done)
            continuous_data['timeout'].extend(ep.timeout)

            # Optional fields - extend with zeros if episode doesn't have them
            if ep.logprob is not None:
                continuous_data['logprob'].extend(ep.logprob)
            else:
                continuous_data['logprob'].extend([0.0] * len(ep.obs))

            if ep.value is not None:
                continuous_data['value'].extend(ep.value)
            else:
                continuous_data['value'].extend([0.0] * len(ep.obs))

            if ep.entropy is not None:
                continuous_data['entropy'].extend(ep.entropy)
            else:
                continuous_data['entropy'].extend([0.0] * len(ep.obs))

        total_steps = len(continuous_data['obs'])
        if total_steps < seq_len:
            raise ValueError(f"Not enough continuous data: {total_steps} steps < {seq_len} required")

        # Sample random sequences from the continuous trajectory
        sequences = []
        for _ in range(batch_size):
            # Random start position that allows full sequence
            max_start = total_steps - seq_len
            start = random.randint(0, max_start) if max_start > 0 else 0
            end = start + seq_len

            # Extract sequence
            seq = {}
            for key in ['obs', 'act', 'rew', 'next_obs', 'done', 'timeout', 'logprob', 'value', 'entropy']:
                seq[key] = continuous_data[key][start:end]

            sequences.append(seq)

        # Convert to batch format [batch_size, seq_len, ...] using the same efficient conversion as standard mode
        batch = {}
        for key in sequences[0].keys():
            per_seq = []
            for s in sequences:
                # Convert sequence to mx.array and stack along time dimension
                per_seq.append(mx.stack([mx.array(v) for v in s[key]], axis=0))  # [T, ...]
            # Stack sequences along batch dimension
            batch[key] = mx.stack(per_seq, axis=0)  # [B, T, ...]

        return batch

    def get_episode_statistics(self) -> Dict[str, float]:
        """
        Get statistical information about stored episodes.

        Returns:
            Dictionary with episode statistics for analysis
        """
        if not self.episodes:
            return {
                'episode_count': 0,
                'total_steps': 0,
                'mean_episode_length': 0.0,
                'min_episode_length': 0,
                'max_episode_length': 0,
                'total_reward': 0.0,
                'mean_episode_reward': 0.0
            }

        episode_lengths = [len(ep) for ep in self.episodes]
        episode_rewards = [sum(ep.rew) for ep in self.episodes]

        return {
            'episode_count': len(self.episodes),
            'total_steps': sum(episode_lengths),
            'mean_episode_length': sum(episode_lengths) / len(episode_lengths),
            'min_episode_length': min(episode_lengths),
            'max_episode_length': max(episode_lengths),
            'total_reward': sum(episode_rewards),
            'mean_episode_reward': sum(episode_rewards) / len(episode_rewards)
        }
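The sketch below is illustrative only and is not part of the wheel: it drives BufferSampler with two toy episodes built through the Episode.append signature used elsewhere in this release, assuming the modules import directly from the paths shown in the file list above and that scalar observations and actions are acceptable placeholders.

# --- usage sketch (illustrative, not shipped in textpolicy 0.1.0) ---
from textpolicy.buffer.episode import Episode
from textpolicy.buffer.sampling import BufferSampler

# Build two toy 8-step episodes with the same append signature the storage layer uses.
episodes = []
for _ in range(2):
    ep = Episode()
    for t in range(8):
        ep.append(
            obs=float(t), act=0, rew=1.0,
            next_obs=float(t + 1),
            done=(t == 7), timeout=False,
        )
    episodes.append(ep)

sampler = BufferSampler(episodes)
full_batch = sampler.sample_all()                        # every transition, oldest → newest
recent = sampler.sample_latest_steps(10)                 # 10 most recent steps, chronological
newest_ep = sampler.sample_episodes(k=1, order='desc')   # newest complete episode only

# Fixed-length windows come back as [B, T, ...]; DreamerV3 mode may cross episode boundaries.
windows = sampler.sample_sequences(batch_size=4, seq_len=5, dreamerv3_mode=True)
print(windows['obs'].shape)  # (4, 5) for these scalar observations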
@@ -0,0 +1,255 @@
# textpolicy/buffer/storage.py
"""
Buffer storage and capacity management.

Handles episode storage, capacity limits, and episode lifecycle management.
Designed for efficient memory usage on Apple Silicon.
"""

from typing import List, Dict, Any
from .episode import Episode


class BufferStorage:
    """
    Manages episode storage with capacity limits and lifecycle.

    Features:
    - FIFO episode eviction when capacity exceeded
    - Episode validation before storage
    - Efficient storage for multiprocessing scenarios
    - Memory-conscious design for Apple Silicon
    """

    def __init__(self, max_episodes: int = 100):
        """
        Initialize buffer storage.

        Args:
            max_episodes: Maximum number of complete episodes to store.
                Oldest episodes are dropped when capacity is exceeded.
        """
        self.max_episodes = max_episodes
        self.episodes: List[Episode] = []
        self.current_episode = Episode()

    def add_transition(
        self,
        obs: Any,
        act: Any,
        rew: Any,
        next_obs: Any,
        done: bool,
        timeout: bool = False,
        **kwargs  # Additional fields like logprob, value, entropy
    ):
        """
        Add a transition to the current episode.

        Completes and stores the episode if done or timeout is True.

        Args:
            obs: Observation
            act: Action taken
            rew: Reward received
            next_obs: Next observation
            done: Boolean indicating episode termination
            timeout: Boolean indicating truncation (e.g. time limit)
            **kwargs: Optional fields (logprob, value, entropy)
        """
        # Add transition to current episode
        self.current_episode.append(
            obs=obs, act=act, rew=rew, next_obs=next_obs,
            done=done, timeout=timeout, **kwargs
        )

        # Complete episode if terminated
        if done or timeout:
            self._complete_current_episode()

    def _complete_current_episode(self):
        """
        Complete the current episode and start a new one.

        Validates the episode before storage and enforces capacity limits.
        """
        # Validate and store only non-empty episodes
        if len(self.current_episode) > 0:
            self.current_episode.validate_consistency()

            # Debug: episode data quality
            self._debug_episode_data(self.current_episode)

            # Add to storage
            self.episodes.append(self.current_episode)

            # Enforce capacity limit (FIFO eviction)
            if len(self.episodes) > self.max_episodes:
                self.episodes.pop(0)  # Remove oldest episode

        # Start new episode
        self.current_episode = Episode()

    def _debug_episode_data(self, episode):
        """Debug what stored episodes look like - prints only for every 50th episode."""
        if not hasattr(self, 'episode_debug_count'):
            self.episode_debug_count = 0

        self.episode_debug_count += 1

        # Only debug every 50th episode to avoid spam
        if self.episode_debug_count % 50 == 1:
            try:
                # Type checking and safe conversion
                episode_count = int(self.episode_debug_count)

                rewards = episode.rew if hasattr(episode, 'rew') else []
                # Convert all rewards to float before summing to handle mixed types
                try:
                    numeric_rewards = [float(r) for r in rewards]
                    total_reward = sum(numeric_rewards)
                except (TypeError, ValueError):
                    # Fallback: filter out non-numeric rewards
                    numeric_rewards = []
                    for r in rewards:
                        try:
                            numeric_rewards.append(float(r))
                        except (TypeError, ValueError):
                            continue
                    total_reward = sum(numeric_rewards) if numeric_rewards else 0.0
                episode_length = len(episode)

                print(f"\nEPISODE DEBUG (Episode #{episode_count}):")
                print(f"  Episode length: {episode_length}")
                print(f"  Total reward: {total_reward:.2f}")

                if rewards:
                    # Ensure rewards are numeric before formatting
                    try:
                        first_rewards = [float(r) for r in rewards[:10]]
                        last_reward = float(rewards[-1])
                        print(f"  Reward sequence: {first_rewards}...")
                        print(f"  Last reward: {last_reward:.2f}")
                    except (TypeError, ValueError) as reward_error:
                        print(f"  Reward formatting error: {reward_error}")
                        print(f"  Raw rewards: {rewards[:10]}...")

                # Check termination
                if hasattr(episode, 'done') and episode.done:
                    final_done = episode.done[-1] if episode.done else False
                    print(f"  Final done: {final_done}")

                if hasattr(episode, 'timeout') and episode.timeout:
                    final_timeout = episode.timeout[-1] if episode.timeout else False
                    print(f"  Final timeout: {final_timeout}")

            except Exception as e:
                print(f"  Episode debug error: {e}")
                # More detailed error info
                import traceback
                print(f"  Error details: {traceback.format_exc()}")

    def add_episode_from_dict(self, data: Dict[str, Any]):
        """
        Reconstruct and add an episode from serialized dictionary.

        Used for multiprocessing: worker serializes episode to dict,
        trainer deserializes and adds to buffer.

        Args:
            data: Dictionary containing episode data from episode.to_dict()
                Must include: obs, act, rew, next_obs, done, timeout
                Optional: logprob, value, entropy

        Raises:
            ValueError: If episode data is invalid or inconsistent
        """
        # Create new episode from dictionary
        episode = Episode()

        # Set required fields
        episode.obs = data['obs']
        episode.act = data['act']
        episode.rew = data['rew']
        episode.next_obs = data['next_obs']
        episode.done = data['done']
        episode.timeout = data['timeout']

        # Set optional fields if present
        episode.logprob = data.get('logprob', None)
        episode.value = data.get('value', None)
        episode.entropy = data.get('entropy', None)

        # Validate consistency before adding
        episode.validate_consistency()

        # Add to storage
        self.episodes.append(episode)

        # Enforce capacity limit
        if len(self.episodes) > self.max_episodes:
            self.episodes.pop(0)

    def clear(self):
        """
        Clear all stored episodes and reset current episode.

        Used to reset buffer state between training runs.
        """
        self.episodes.clear()
        self.current_episode = Episode()

    def ready(self, min_episodes: int = 1) -> bool:
        """
        Check if buffer has enough complete episodes for training.

        Args:
            min_episodes: Minimum number of episodes required

        Returns:
            True if buffer has at least min_episodes complete episodes
        """
        return len(self.episodes) >= min_episodes

    def total_steps(self) -> int:
        """
        Calculate total number of steps across all episodes.

        Returns:
            Sum of steps in all complete episodes
        """
        return sum(len(episode) for episode in self.episodes)

    def get_episodes(self) -> List[Episode]:
        """
        Get read-only access to stored episodes.

        Returns:
            List of complete episodes (does not include current incomplete episode)
        """
        return self.episodes.copy()

    def __len__(self) -> int:
        """Return total number of steps in buffer."""
        return self.total_steps()

    @property
    def episode_count(self) -> int:
        """Number of complete episodes currently stored."""
        return len(self.episodes)

    def get_storage_info(self) -> Dict[str, Any]:
        """
        Get detailed storage information for debugging.

        Returns:
            Dictionary with storage statistics
        """
        return {
            'episode_count': len(self.episodes),
            'max_episodes': self.max_episodes,
            'total_steps': self.total_steps(),
            'current_episode_steps': len(self.current_episode),
            'episode_lengths': [len(ep) for ep in self.episodes],
            'capacity_usage': len(self.episodes) / self.max_episodes if self.max_episodes > 0 else 0.0
        }