textpolicy-0.0.1-py3-none-any.whl → textpolicy-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68)
  1. textpolicy/__init__.py +52 -0
  2. textpolicy/__main__.py +8 -0
  3. textpolicy/algorithms/__init__.py +54 -0
  4. textpolicy/algorithms/grpo.py +642 -0
  5. textpolicy/algorithms/gspo.py +582 -0
  6. textpolicy/buffer/__init__.py +23 -0
  7. textpolicy/buffer/buffer.py +244 -0
  8. textpolicy/buffer/episode.py +383 -0
  9. textpolicy/buffer/sampling.py +438 -0
  10. textpolicy/buffer/storage.py +255 -0
  11. textpolicy/cli.py +67 -0
  12. textpolicy/environment/__init__.py +79 -0
  13. textpolicy/environment/base.py +110 -0
  14. textpolicy/environment/environment.py +46 -0
  15. textpolicy/environment/factory.py +103 -0
  16. textpolicy/environment/gym.py +106 -0
  17. textpolicy/environment/task_suites.py +51 -0
  18. textpolicy/environment/text_generation.py +789 -0
  19. textpolicy/environment/vectorized.py +253 -0
  20. textpolicy/generation/__init__.py +62 -0
  21. textpolicy/generation/lora.py +411 -0
  22. textpolicy/generation/mlx_generation.py +557 -0
  23. textpolicy/generation/reload.py +253 -0
  24. textpolicy/rewards/__init__.py +137 -0
  25. textpolicy/rewards/adapters.py +387 -0
  26. textpolicy/rewards/basic.py +214 -0
  27. textpolicy/rewards/integrated_system.py +338 -0
  28. textpolicy/rewards/mlx_batch_processor.py +447 -0
  29. textpolicy/rewards/registry.py +293 -0
  30. textpolicy/rewards/rollout_rewards.py +410 -0
  31. textpolicy/rewards/verifiers.py +369 -0
  32. textpolicy/rollout/__init__.py +44 -0
  33. textpolicy/rollout/aggregator.py +145 -0
  34. textpolicy/rollout/base.py +108 -0
  35. textpolicy/rollout/rollout.py +142 -0
  36. textpolicy/rollout/runner.py +280 -0
  37. textpolicy/rollout/strategy.py +208 -0
  38. textpolicy/rollout/worker.py +194 -0
  39. textpolicy/training/__init__.py +14 -0
  40. textpolicy/training/metrics.py +242 -0
  41. textpolicy/training/rollout_manager.py +78 -0
  42. textpolicy/training/trainer.py +684 -0
  43. textpolicy/utils/__init__.py +40 -0
  44. textpolicy/utils/benchmarking.py +489 -0
  45. textpolicy/utils/data.py +60 -0
  46. textpolicy/utils/debug.py +170 -0
  47. textpolicy/utils/environment.py +349 -0
  48. textpolicy/utils/logging/__init__.py +22 -0
  49. textpolicy/utils/logging/base.py +48 -0
  50. textpolicy/utils/logging/console.py +61 -0
  51. textpolicy/utils/logging/factory.py +133 -0
  52. textpolicy/utils/logging/multi.py +83 -0
  53. textpolicy/utils/logging/tensorboard.py +65 -0
  54. textpolicy/utils/logging/wandb.py +72 -0
  55. textpolicy/utils/memory.py +118 -0
  56. textpolicy/utils/performance.py +464 -0
  57. textpolicy/utils/timing.py +171 -0
  58. textpolicy/validate.py +101 -0
  59. textpolicy/validation/__init__.py +13 -0
  60. textpolicy/validation/logprob_validation.py +315 -0
  61. textpolicy-0.1.0.dist-info/METADATA +99 -0
  62. textpolicy-0.1.0.dist-info/RECORD +66 -0
  63. textpolicy-0.1.0.dist-info/entry_points.txt +2 -0
  64. textpolicy-0.0.1.dist-info/METADATA +0 -10
  65. textpolicy-0.0.1.dist-info/RECORD +0 -6
  66. {textpolicy-0.0.1.dist-info → textpolicy-0.1.0.dist-info}/WHEEL +0 -0
  67. {textpolicy-0.0.1.dist-info → textpolicy-0.1.0.dist-info}/licenses/LICENSE +0 -0
  68. {textpolicy-0.0.1.dist-info → textpolicy-0.1.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,438 @@
+ # textpolicy/buffer/sampling.py
+ """
+ Buffer sampling methods for training data retrieval.
+
+ Designed for MLX and Apple Silicon with efficient tensor conversions.
+ """
+
+ from typing import List, Dict
+ import random
+ import mlx.core as mx  # type: ignore
+ from .episode import Episode
+
+
+ class BufferSampler:
+     """
+     Handles all data sampling operations for the buffer.
+
+     Provides multiple sampling strategies:
+     - Full buffer sampling
+     - Latest N steps sampling
+     - Episode-based sampling
+
+     Uses MLX tensor operations and Apple Silicon-friendly patterns.
+     """
+
+     def __init__(self, episodes: List[Episode]):
+         """
+         Initialize sampler with episode storage reference.
+
+         Args:
+             episodes: Reference to list of complete episodes
+         """
+         self.episodes = episodes
+
+     def sample_all(self) -> Dict[str, mx.array]:
+         """
+         Sample all stored episodes as a single concatenated batch.
+
+         Returns all transitions in chronological order:
+         - Oldest episode → newest episode
+         - Within each episode: first step → last step
+
+         Returns:
+             Dict of MLX arrays with all transitions
+
+         Raises:
+             ValueError: If buffer is empty
+         """
+         if not self.episodes:
+             raise ValueError("Buffer empty. No episodes to sample.")
+
+         # Concatenate all episodes by collecting data directly (bypass Episode validation).
+         # Episodes can have different optional fields, so collect what exists across all episodes.
+         all_obs = []
+         all_act = []
+         all_rew = []
+         all_next_obs = []
+         all_done = []
+         all_timeout = []
+         all_logprob = []
+         all_value = []
+         all_entropy = []
+
+         # Only include optional fields that exist in ALL episodes for consistent sampling.
+         # This matches the buffer's "all-or-nothing" design philosophy per episode.
+         has_logprob = all(episode.logprob is not None for episode in self.episodes)
+         has_value = all(episode.value is not None for episode in self.episodes)
+         has_entropy = all(episode.entropy is not None for episode in self.episodes)
+
+         # Collect transitions from all episodes
+         for episode in self.episodes:
+             for i in range(len(episode)):
+                 all_obs.append(episode.obs[i])
+                 all_act.append(episode.act[i])
+                 all_rew.append(episode.rew[i])
+                 all_next_obs.append(episode.next_obs[i])
+                 all_done.append(episode.done[i])
+                 all_timeout.append(episode.timeout[i])
+
+                 # Only collect optional fields that exist in ALL episodes
+                 if has_logprob:
+                     all_logprob.append(episode.logprob[i])
+                 if has_value:
+                     all_value.append(episode.value[i])
+                 if has_entropy:
+                     all_entropy.append(episode.entropy[i])
+
+         # Create Episode directly with collected data (bypassing validation during construction)
+         all_transitions = Episode()
+         all_transitions.obs = all_obs
+         all_transitions.act = all_act
+         all_transitions.rew = all_rew
+         all_transitions.next_obs = all_next_obs
+         all_transitions.done = all_done
+         all_transitions.timeout = all_timeout
+
+         # Only set optional fields if they exist in all episodes
+         if has_logprob:
+             all_transitions.logprob = all_logprob
+         if has_value:
+             all_transitions.value = all_value
+         if has_entropy:
+             all_transitions.entropy = all_entropy
+
+         return all_transitions.to_tensor_dict()
+
+     def sample_latest_steps(self, n: int) -> Dict[str, mx.array]:
+         """
+         Sample the n most recent transitions across episodes.
+
+         Useful for on-policy RL algorithms that use recent experience for stable training.
+
+         Args:
+             n: Number of steps to sample (must be > 0)
+
+         Returns:
+             Dict of MLX arrays with the latest n steps (fewer if the buffer holds fewer)
+             in chronological order (oldest → newest, ensuring temporal consistency)
+
+         Raises:
+             ValueError: If the buffer is empty, n <= 0, or no steps are available
+         """
+         if not self.episodes:
+             raise ValueError("Buffer is empty. No recent episodes to sample.")
+         if n <= 0:
+             raise ValueError("Number of steps (n) to sample must be greater than 0.")
+
+         # Collect steps in reverse chronological order (newest first)
+         steps = []
+         for episode in reversed(self.episodes):
+             for i in reversed(range(len(episode))):
+                 step_dict = {
+                     "obs": episode.obs[i],
+                     "act": episode.act[i],
+                     "rew": episode.rew[i],
+                     "next_obs": episode.next_obs[i],
+                     "done": episode.done[i],
+                     "timeout": episode.timeout[i],
+                 }
+
+                 # Add optional fields if present
+                 if episode.logprob is not None:
+                     step_dict["logprob"] = episode.logprob[i]
+                 if episode.value is not None:
+                     step_dict["value"] = episode.value[i]
+                 if episode.entropy is not None:
+                     step_dict["entropy"] = episode.entropy[i]
+
+                 steps.append(step_dict)
+
+                 # Stop when we have enough steps
+                 if len(steps) >= n:
+                     break
+
+             if len(steps) >= n:
+                 break
+
+         if not steps:
+             raise ValueError("No steps available to sample")
+
+         # Reverse to get chronological order (oldest → newest)
+         chronological_steps = list(reversed(steps[:n]))
+
+         # Batch-convert to MLX arrays:
+         # collect the values for each key first, then convert them in a single pass.
+         batch = {}
+         for key in chronological_steps[0].keys():
+             # Extract all values for this key at once
+             values = [step[key] for step in chronological_steps]
+             # One stack per key avoids building the batch array incrementally
+             batch[key] = mx.stack([mx.array(v) for v in values])
+
+         return batch
+
+     def sample_episodes(self, k: int, order: str = 'asc') -> Dict[str, mx.array]:
+         """
+         Sample up to k complete episodes.
+
+         Useful for episode-based training or evaluation analysis.
+
+         Args:
+             k: Number of episodes to sample (must be > 0)
+             order: 'asc' for oldest first, 'desc' for newest first
+
+         Returns:
+             Dict of MLX arrays with concatenated transitions from selected episodes
+
+         Raises:
+             ValueError: If buffer is empty, k <= 0, or invalid order
+         """
+         if not self.episodes:
+             raise ValueError("Buffer is empty. No episodes to sample.")
+         if k <= 0:
+             raise ValueError("k must be positive.")
+         if order not in ('asc', 'desc'):
+             raise ValueError("order must be 'asc' or 'desc'")
+
+         # Select episodes based on order
+         if order == 'asc':
+             selected_episodes = self.episodes[:k]  # Oldest k episodes
+         else:  # 'desc'
+             selected_episodes = self.episodes[-k:]  # Newest k episodes
+
+         if not selected_episodes:
+             raise ValueError("No episodes matched the criteria.")
+
+         # Concatenate all steps from selected episodes
+         all_transitions = Episode()
+         for episode in selected_episodes:
+             for i in range(len(episode)):
+                 all_transitions.append(
+                     obs=episode.obs[i],
+                     act=episode.act[i],
+                     rew=episode.rew[i],
+                     next_obs=episode.next_obs[i],
+                     done=episode.done[i],
+                     timeout=episode.timeout[i],
+                     logprob=episode.logprob[i] if episode.logprob is not None else None,
+                     value=episode.value[i] if episode.value is not None else None,
+                     entropy=episode.entropy[i] if episode.entropy is not None else None
+                 )
+
+         return all_transitions.to_tensor_dict()
+
+     def sample_sequences(
+         self,
+         batch_size: int,
+         seq_len: int,
+         recent_first: bool = True,
+         drop_incomplete: bool = True,
+         dreamerv3_mode: bool = False,
+     ) -> Dict[str, mx.array]:
+         """
+         Sample contiguous sequences of length `seq_len`.
+
+         Standard mode: sample without crossing episode boundaries (for most algorithms).
+         DreamerV3 mode: sample across episode boundaries to include terminals (for continue head supervision).
+
+         Args:
+             batch_size: Number of sequences to sample.
+             seq_len: Length of each sequence (T). Must be > 0.
+             recent_first: Prefer sampling from the most recent episodes first.
+             drop_incomplete: If True, skip episodes shorter than seq_len (standard mode only).
+             dreamerv3_mode: If True, sample across episode boundaries to include terminals.
+
+         Returns:
+             Dict of MLX arrays with keys obs, act, rew, next_obs, done, timeout,
+             plus optional keys (logprob, value, entropy) included only if present.
+
+         Raises:
+             ValueError: If the buffer is empty, inputs are invalid, or there is not enough data.
+         """
+         if not self.episodes:
+             raise ValueError("Buffer is empty. No episodes to sample.")
+         if batch_size <= 0 or seq_len <= 0:
+             raise ValueError("batch_size and seq_len must be positive.")
+
+         if dreamerv3_mode:
+             return self._sample_dreamerv3_sequences(batch_size, seq_len, recent_first)
+
+         # Choose episode order per recency preference.
+         episodes_iter: List[Episode]
+         episodes_iter = list(reversed(self.episodes)) if recent_first else list(self.episodes)
+
+         # Collect one contiguous window per episode until the batch is filled.
+         sequences = []
+         episodes_used: List[Episode] = []
+
+         for ep in episodes_iter:
+             n = len(ep)
+             if n < seq_len:
+                 if drop_incomplete:
+                     continue
+                 # Padding/masking path intentionally not implemented for simplicity/perf.
+                 continue
+
+             # Sample a random start index within the episode rather than always taking
+             # the tail, so terminal steps are not over-represented. Recency is still
+             # preferred at the episode level via episodes_iter, while bias within each
+             # episode is reduced.
+             max_start = n - seq_len
+             start = random.randint(0, max_start) if max_start > 0 else 0
+             end = start + seq_len  # fixed-length window [start, start + seq_len)
+             seq = {
+                 'obs': ep.obs[start:end],
+                 'act': ep.act[start:end],
+                 'rew': ep.rew[start:end],
+                 'next_obs': ep.next_obs[start:end],
+                 'done': ep.done[start:end],
+                 'timeout': ep.timeout[start:end],
+             }
+
+             # Optional fields: only include if present in the episode
+             if ep.logprob is not None:
+                 seq['logprob'] = ep.logprob[start:end]
+             if ep.value is not None:
+                 seq['value'] = ep.value[start:end]
+             if ep.entropy is not None:
+                 seq['entropy'] = ep.entropy[start:end]
+
+             sequences.append(seq)
+             episodes_used.append(ep)
+             if len(sequences) >= batch_size:
+                 break
+
+         if not sequences:
+             raise ValueError("No sequences available to sample.")
+
+         # Keep only the keys that exist across all sampled sequences for consistent batching.
+         all_keys = set(sequences[0].keys())
+         for s in sequences[1:]:
+             all_keys &= set(s.keys())
+
+         # Convert to MLX arrays with shape [B, T, ...]
+         batch: Dict[str, mx.array] = {}
+         for key in all_keys:
+             # First convert each sequence to [T, ...]
+             per_seq = []
+             for s in sequences:
+                 # One mx.array() per time step, then stack along time, keeping explicit
+                 # control of dimensions while limiting Python overhead.
+                 per_seq.append(mx.stack([mx.array(v) for v in s[key]], axis=0))  # [T, ...]
+             # Stack sequences along the batch dimension
+             batch[key] = mx.stack(per_seq, axis=0)  # [B, T, ...]
+
+         return batch
+
+     def _sample_dreamerv3_sequences(
+         self,
+         batch_size: int,
+         seq_len: int,
+         recent_first: bool
+     ) -> Dict[str, mx.array]:
+         """
+         Sample sequences for DreamerV3 by concatenating episodes to form continuous trajectories.
+
+         Unlike standard sampling, this allows sequences to span episode boundaries,
+         ensuring episode terminals appear within training sequences for continue head supervision.
+
+         This mirrors DreamerV3's original replay buffer behavior, where episodes are stored
+         continuously and sampling naturally includes episode boundaries.
+         """
+         # Create a continuous trajectory by concatenating recent episodes
+         episodes_iter = list(reversed(self.episodes)) if recent_first else list(self.episodes)
+
+         # Concatenate episodes into a continuous stream
+         continuous_data = {
+             'obs': [], 'act': [], 'rew': [], 'next_obs': [], 'done': [], 'timeout': [],
+             'logprob': [], 'value': [], 'entropy': []
+         }
+
+         for ep in episodes_iter:
+             continuous_data['obs'].extend(ep.obs)
+             continuous_data['act'].extend(ep.act)
+             continuous_data['rew'].extend(ep.rew)
+             continuous_data['next_obs'].extend(ep.next_obs)
+             continuous_data['done'].extend(ep.done)
+             continuous_data['timeout'].extend(ep.timeout)
+
+             # Optional fields - pad with zeros if the episode doesn't have them
+             if ep.logprob is not None:
+                 continuous_data['logprob'].extend(ep.logprob)
+             else:
+                 continuous_data['logprob'].extend([0.0] * len(ep.obs))
+
+             if ep.value is not None:
+                 continuous_data['value'].extend(ep.value)
+             else:
+                 continuous_data['value'].extend([0.0] * len(ep.obs))
+
+             if ep.entropy is not None:
+                 continuous_data['entropy'].extend(ep.entropy)
+             else:
+                 continuous_data['entropy'].extend([0.0] * len(ep.obs))
+
+         total_steps = len(continuous_data['obs'])
+         if total_steps < seq_len:
+             raise ValueError(f"Not enough continuous data: {total_steps} steps < {seq_len} required")
+
+         # Sample random sequences from the continuous trajectory
+         sequences = []
+         for _ in range(batch_size):
+             # Random start position that still allows a full sequence
+             max_start = total_steps - seq_len
+             start = random.randint(0, max_start) if max_start > 0 else 0
+             end = start + seq_len
+
+             # Extract sequence
+             seq = {}
+             for key in ['obs', 'act', 'rew', 'next_obs', 'done', 'timeout', 'logprob', 'value', 'entropy']:
+                 seq[key] = continuous_data[key][start:end]
+
+             sequences.append(seq)
+
+         # Convert to batch format [batch_size, seq_len, ...] using the same conversion as standard mode
+         batch = {}
+         for key in sequences[0].keys():
+             per_seq = []
+             for s in sequences:
+                 # Convert each sequence to an mx.array stacked along the time dimension
+                 per_seq.append(mx.stack([mx.array(v) for v in s[key]], axis=0))  # [T, ...]
+             # Stack sequences along the batch dimension
+             batch[key] = mx.stack(per_seq, axis=0)  # [B, T, ...]
+
+         return batch
+
+     def get_episode_statistics(self) -> Dict[str, float]:
+         """
+         Get statistical information about stored episodes.
+
+         Returns:
+             Dictionary with episode statistics for analysis
+         """
+         if not self.episodes:
+             return {
+                 'episode_count': 0,
+                 'total_steps': 0,
+                 'mean_episode_length': 0.0,
+                 'min_episode_length': 0,
+                 'max_episode_length': 0,
+                 'total_reward': 0.0,
+                 'mean_episode_reward': 0.0
+             }
+
+         episode_lengths = [len(ep) for ep in self.episodes]
+         episode_rewards = [sum(ep.rew) for ep in self.episodes]
+
+         return {
+             'episode_count': len(self.episodes),
+             'total_steps': sum(episode_lengths),
+             'mean_episode_length': sum(episode_lengths) / len(episode_lengths),
+             'min_episode_length': min(episode_lengths),
+             'max_episode_length': max(episode_lengths),
+             'total_reward': sum(episode_rewards),
+             'mean_episode_reward': sum(episode_rewards) / len(episode_rewards)
+         }
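
For orientation, here is a minimal usage sketch of the BufferSampler API added above. It is illustrative only and not part of the packaged code: it assumes Episode objects filled through the same append(...) keyword interface that sample_episodes relies on, and the observations, actions, and rewards are toy placeholders.

    # Hypothetical usage sketch - toy shapes and values, not shipped with the wheel.
    from textpolicy.buffer.episode import Episode
    from textpolicy.buffer.sampling import BufferSampler

    # Build two tiny episodes of four steps each via Episode.append(...).
    episodes = []
    for _ in range(2):
        ep = Episode()
        for t in range(4):
            ep.append(
                obs=[float(t)], act=0, rew=1.0,
                next_obs=[float(t + 1)],
                done=(t == 3), timeout=False,
            )
        episodes.append(ep)

    sampler = BufferSampler(episodes)
    batch = sampler.sample_all()                                  # all 8 transitions as mx.arrays
    recent = sampler.sample_latest_steps(n=5)                     # 5 newest steps, oldest -> newest
    windows = sampler.sample_sequences(batch_size=2, seq_len=3)   # arrays shaped [B=2, T=3, ...]
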
@@ -0,0 +1,255 @@
+ # textpolicy/buffer/storage.py
+ """
+ Buffer storage and capacity management.
+
+ Handles episode storage, capacity limits, and episode lifecycle management.
+ Designed for efficient memory usage on Apple Silicon.
+ """
+
+ from typing import List, Dict, Any
+ from .episode import Episode
+
+
+ class BufferStorage:
+     """
+     Manages episode storage with capacity limits and lifecycle.
+
+     Features:
+     - FIFO episode eviction when capacity exceeded
+     - Episode validation before storage
+     - Efficient storage for multiprocessing scenarios
+     - Memory-conscious design for Apple Silicon
+     """
+
+     def __init__(self, max_episodes: int = 100):
+         """
+         Initialize buffer storage.
+
+         Args:
+             max_episodes: Maximum number of complete episodes to store.
+                 Oldest episodes are dropped when capacity is exceeded.
+         """
+         self.max_episodes = max_episodes
+         self.episodes: List[Episode] = []
+         self.current_episode = Episode()
+
+     def add_transition(
+         self,
+         obs: Any,
+         act: Any,
+         rew: Any,
+         next_obs: Any,
+         done: bool,
+         timeout: bool = False,
+         **kwargs  # Additional fields like logprob, value, entropy
+     ):
+         """
+         Add a transition to the current episode.
+
+         Completes and stores the episode if done or timeout is True.
+
+         Args:
+             obs: Observation
+             act: Action taken
+             rew: Reward received
+             next_obs: Next observation
+             done: Boolean indicating episode termination
+             timeout: Boolean indicating truncation (e.g. time limit)
+             **kwargs: Optional fields (logprob, value, entropy)
+         """
+         # Add transition to current episode
+         self.current_episode.append(
+             obs=obs, act=act, rew=rew, next_obs=next_obs,
+             done=done, timeout=timeout, **kwargs
+         )
+
+         # Complete episode if terminated
+         if done or timeout:
+             self._complete_current_episode()
+
+     def _complete_current_episode(self):
+         """
+         Complete the current episode and start a new one.
+
+         Validates episode before storage and enforces capacity limits.
+         """
+         # Validate and store the episode only if it contains any steps
+         if len(self.current_episode) > 0:
+             self.current_episode.validate_consistency()
+
+             # Debug: episode data quality
+             self._debug_episode_data(self.current_episode)
+
+             # Add to storage
+             self.episodes.append(self.current_episode)
+
+             # Enforce capacity limit (FIFO eviction)
+             if len(self.episodes) > self.max_episodes:
+                 self.episodes.pop(0)  # Remove oldest episode
+
+         # Start new episode
+         self.current_episode = Episode()
+
+     def _debug_episode_data(self, episode):
+         """Debug episode data quality - only logs every 50th episode."""
+         if not hasattr(self, 'episode_debug_count'):
+             self.episode_debug_count = 0
+
+         self.episode_debug_count += 1
+
+         # Only debug every 50th episode to avoid spam
+         if self.episode_debug_count % 50 == 1:
+             try:
+                 # Type checking and safe conversion
+                 episode_count = int(self.episode_debug_count)
+
+                 rewards = episode.rew if hasattr(episode, 'rew') else []
+                 # Convert all rewards to float before summing to handle mixed types
+                 try:
+                     numeric_rewards = [float(r) for r in rewards]
+                     total_reward = sum(numeric_rewards)
+                 except (TypeError, ValueError):
+                     # Fallback: filter out non-numeric rewards
+                     numeric_rewards = []
+                     for r in rewards:
+                         try:
+                             numeric_rewards.append(float(r))
+                         except (TypeError, ValueError):
+                             continue
+                     total_reward = sum(numeric_rewards) if numeric_rewards else 0.0
+                 episode_length = len(episode)
+
+                 print(f"\nEPISODE DEBUG (Episode #{episode_count}):")
+                 print(f" Episode length: {episode_length}")
+                 print(f" Total reward: {total_reward:.2f}")
+
+                 if rewards:
+                     # Ensure rewards are numeric before formatting
+                     try:
+                         first_rewards = [float(r) for r in rewards[:10]]
+                         last_reward = float(rewards[-1])
+                         print(f" Reward sequence: {first_rewards}...")
+                         print(f" Last reward: {last_reward:.2f}")
+                     except (TypeError, ValueError) as reward_error:
+                         print(f" Reward formatting error: {reward_error}")
+                         print(f" Raw rewards: {rewards[:10]}...")
+
+                 # Check termination
+                 if hasattr(episode, 'done') and episode.done:
+                     final_done = episode.done[-1] if episode.done else False
+                     print(f" Final done: {final_done}")
+
+                 if hasattr(episode, 'timeout') and episode.timeout:
+                     final_timeout = episode.timeout[-1] if episode.timeout else False
+                     print(f" Final timeout: {final_timeout}")
+
+             except Exception as e:
+                 print(f" Episode debug error: {e}")
+                 # More detailed error info
+                 import traceback
+                 print(f" Error details: {traceback.format_exc()}")
+
+     def add_episode_from_dict(self, data: Dict[str, Any]):
+         """
+         Reconstruct and add an episode from serialized dictionary.
+
+         Used for multiprocessing: worker serializes episode to dict,
+         trainer deserializes and adds to buffer.
+
+         Args:
+             data: Dictionary containing episode data from episode.to_dict().
+                 Must include: obs, act, rew, next_obs, done, timeout
+                 Optional: logprob, value, entropy
+
+         Raises:
+             ValueError: If episode data is invalid or inconsistent
+         """
+         # Create new episode from dictionary
+         episode = Episode()
+
+         # Set required fields
+         episode.obs = data['obs']
+         episode.act = data['act']
+         episode.rew = data['rew']
+         episode.next_obs = data['next_obs']
+         episode.done = data['done']
+         episode.timeout = data['timeout']
+
+         # Set optional fields if present
+         episode.logprob = data.get('logprob', None)
+         episode.value = data.get('value', None)
+         episode.entropy = data.get('entropy', None)
+
+         # Validate consistency before adding
+         episode.validate_consistency()
+
+         # Add to storage
+         self.episodes.append(episode)
+
+         # Enforce capacity limit
+         if len(self.episodes) > self.max_episodes:
+             self.episodes.pop(0)
+
+     def clear(self):
+         """
+         Clear all stored episodes and reset current episode.
+
+         Used to reset buffer state between training runs.
+         """
+         self.episodes.clear()
+         self.current_episode = Episode()
+
+     def ready(self, min_episodes: int = 1) -> bool:
+         """
+         Check if buffer has enough complete episodes for training.
+
+         Args:
+             min_episodes: Minimum number of episodes required
+
+         Returns:
+             True if buffer has at least min_episodes complete episodes
+         """
+         return len(self.episodes) >= min_episodes
+
+     def total_steps(self) -> int:
+         """
+         Calculate total number of steps across all episodes.
+
+         Returns:
+             Sum of steps in all complete episodes
+         """
+         return sum(len(episode) for episode in self.episodes)
+
+     def get_episodes(self) -> List[Episode]:
+         """
+         Get read-only access to stored episodes.
+
+         Returns:
+             List of complete episodes (does not include current incomplete episode)
+         """
+         return self.episodes.copy()
+
+     def __len__(self) -> int:
+         """Return total number of steps in buffer."""
+         return self.total_steps()
+
+     @property
+     def episode_count(self) -> int:
+         """Number of complete episodes currently stored."""
+         return len(self.episodes)
+
+     def get_storage_info(self) -> Dict[str, Any]:
+         """
+         Get detailed storage information for debugging.
+
+         Returns:
+             Dictionary with storage statistics
+         """
+         return {
+             'episode_count': len(self.episodes),
+             'max_episodes': self.max_episodes,
+             'total_steps': self.total_steps(),
+             'current_episode_steps': len(self.current_episode),
+             'episode_lengths': [len(ep) for ep in self.episodes],
+             'capacity_usage': len(self.episodes) / self.max_episodes if self.max_episodes > 0 else 0.0
+         }
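
As a rough sketch of how the two new modules compose (again illustrative, not part of the diff): transitions stream into BufferStorage, which finalizes and stores an episode when done or timeout is True, and a BufferSampler can then be pointed at the stored episodes. The field values and the optional logprob/value/entropy extras below are placeholders.

    # Hypothetical usage sketch - toy values, not shipped with the wheel.
    from textpolicy.buffer.storage import BufferStorage
    from textpolicy.buffer.sampling import BufferSampler

    storage = BufferStorage(max_episodes=100)

    # Stream transitions; the episode is completed and stored once done/timeout is True.
    for t in range(5):
        storage.add_transition(
            obs=[float(t)], act=1, rew=0.5,
            next_obs=[float(t + 1)],
            done=(t == 4), timeout=False,
            logprob=-0.7, value=0.0, entropy=1.2,  # optional extras forwarded via **kwargs
        )

    assert storage.ready(min_episodes=1)
    print(storage.get_storage_info())  # episode count, lengths, capacity usage

    # Workers in other processes can ship serialized episodes instead:
    # storage.add_episode_from_dict(episode.to_dict())

    sampler = BufferSampler(storage.get_episodes())
    batch = sampler.sample_all()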