textpolicy 0.0.1__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68)
  1. textpolicy/__init__.py +52 -0
  2. textpolicy/__main__.py +8 -0
  3. textpolicy/algorithms/__init__.py +54 -0
  4. textpolicy/algorithms/grpo.py +642 -0
  5. textpolicy/algorithms/gspo.py +582 -0
  6. textpolicy/buffer/__init__.py +23 -0
  7. textpolicy/buffer/buffer.py +244 -0
  8. textpolicy/buffer/episode.py +383 -0
  9. textpolicy/buffer/sampling.py +438 -0
  10. textpolicy/buffer/storage.py +255 -0
  11. textpolicy/cli.py +67 -0
  12. textpolicy/environment/__init__.py +79 -0
  13. textpolicy/environment/base.py +110 -0
  14. textpolicy/environment/environment.py +46 -0
  15. textpolicy/environment/factory.py +103 -0
  16. textpolicy/environment/gym.py +106 -0
  17. textpolicy/environment/task_suites.py +51 -0
  18. textpolicy/environment/text_generation.py +789 -0
  19. textpolicy/environment/vectorized.py +253 -0
  20. textpolicy/generation/__init__.py +62 -0
  21. textpolicy/generation/lora.py +411 -0
  22. textpolicy/generation/mlx_generation.py +557 -0
  23. textpolicy/generation/reload.py +253 -0
  24. textpolicy/rewards/__init__.py +137 -0
  25. textpolicy/rewards/adapters.py +387 -0
  26. textpolicy/rewards/basic.py +214 -0
  27. textpolicy/rewards/integrated_system.py +338 -0
  28. textpolicy/rewards/mlx_batch_processor.py +447 -0
  29. textpolicy/rewards/registry.py +293 -0
  30. textpolicy/rewards/rollout_rewards.py +410 -0
  31. textpolicy/rewards/verifiers.py +369 -0
  32. textpolicy/rollout/__init__.py +44 -0
  33. textpolicy/rollout/aggregator.py +145 -0
  34. textpolicy/rollout/base.py +108 -0
  35. textpolicy/rollout/rollout.py +142 -0
  36. textpolicy/rollout/runner.py +280 -0
  37. textpolicy/rollout/strategy.py +208 -0
  38. textpolicy/rollout/worker.py +194 -0
  39. textpolicy/training/__init__.py +14 -0
  40. textpolicy/training/metrics.py +242 -0
  41. textpolicy/training/rollout_manager.py +78 -0
  42. textpolicy/training/trainer.py +684 -0
  43. textpolicy/utils/__init__.py +40 -0
  44. textpolicy/utils/benchmarking.py +489 -0
  45. textpolicy/utils/data.py +60 -0
  46. textpolicy/utils/debug.py +170 -0
  47. textpolicy/utils/environment.py +349 -0
  48. textpolicy/utils/logging/__init__.py +22 -0
  49. textpolicy/utils/logging/base.py +48 -0
  50. textpolicy/utils/logging/console.py +61 -0
  51. textpolicy/utils/logging/factory.py +133 -0
  52. textpolicy/utils/logging/multi.py +83 -0
  53. textpolicy/utils/logging/tensorboard.py +65 -0
  54. textpolicy/utils/logging/wandb.py +72 -0
  55. textpolicy/utils/memory.py +118 -0
  56. textpolicy/utils/performance.py +464 -0
  57. textpolicy/utils/timing.py +171 -0
  58. textpolicy/validate.py +101 -0
  59. textpolicy/validation/__init__.py +13 -0
  60. textpolicy/validation/logprob_validation.py +315 -0
  61. textpolicy-0.1.0.dist-info/METADATA +99 -0
  62. textpolicy-0.1.0.dist-info/RECORD +66 -0
  63. textpolicy-0.1.0.dist-info/entry_points.txt +2 -0
  64. textpolicy-0.0.1.dist-info/METADATA +0 -10
  65. textpolicy-0.0.1.dist-info/RECORD +0 -6
  66. {textpolicy-0.0.1.dist-info → textpolicy-0.1.0.dist-info}/WHEEL +0 -0
  67. {textpolicy-0.0.1.dist-info → textpolicy-0.1.0.dist-info}/licenses/LICENSE +0 -0
  68. {textpolicy-0.0.1.dist-info → textpolicy-0.1.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,244 @@
+ # textpolicy/buffer/buffer.py
+ """
+ Coordinates storage and sampling for RL training.
+
+ The Buffer class provides a clean interface for episode-centric replay
+ buffer operations, optimized for on-policy RL algorithms.
+ """
+
+ from typing import Optional, Any, Dict
+ from .storage import BufferStorage
+ from .sampling import BufferSampler
+
+
+ class Buffer:
+     """
+     Episode-centric replay buffer for on-policy RL (e.g., PPO).
+
+     Stores full episodes and converts to tensors at sample time.
+     Prevents silent corruption from circular overwrite.
+
+     The buffer enforces clean rollouts:
+     - Episodes are either complete or not stored
+     - Optional fields (logprob, value) must be all-or-nothing
+     - No partial episodes, no fragmented trajectories
+
+     Designed for:
+     - Apple Silicon (MLX, unified memory)
+     - Multiprocessing (not threading)
+     - PPO, GAE, and other on-policy algorithms
+
+     Example:
+         buffer = Buffer(max_episodes=100)
+
+         # Collect data
+         buffer.add(obs=obs, act=act, rew=rew, next_obs=next_obs, done=done)
+
+         # Sample data
+         batch = buffer.sample_latest_steps(2048)  # Last 2k steps
+         batch = buffer.sample_episodes(10, order='desc')  # Last 10 episodes
+     """
+
+     def __init__(self, max_episodes: int = 100):
+         """
+         Initialize the buffer.
+
+         Args:
+             max_episodes: Maximum number of complete episodes to store.
+                 Oldest episodes are dropped when capacity is exceeded.
+         """
+         self.storage = BufferStorage(max_episodes)
+         self.sampler = BufferSampler(self.storage.episodes)
+
+     def add(
+         self,
+         obs: Any,
+         act: Any,
+         rew: Any,
+         next_obs: Any,
+         done: bool,
+         timeout: bool = False,
+         logprob: Optional[Any] = None,
+         value: Optional[Any] = None,
+         entropy: Optional[Any] = None
+     ):
+         """
+         Add a transition to the current episode.
+
+         Completes the episode and stores it if `done` or `timeout` is True.
+
+         Args:
+             obs: Observation
+             act: Action taken
+             rew: Reward received
+             next_obs: Next observation
+             done: Boolean indicating episode termination
+             timeout: Boolean indicating truncation (e.g. time limit)
+             logprob: Log probability of action (optional, all-or-nothing)
+             value: Estimated state value (optional, all-or-nothing)
+             entropy: Action entropy (optional, all-or-nothing)
+
+         Example:
+             buffer.add(
+                 obs=obs,
+                 act=action,
+                 rew=reward,
+                 next_obs=next_obs,
+                 done=done,
+                 timeout=timeout,
+                 logprob=logp.item(),
+                 value=value.item()
+             )
+         """
+         self.storage.add_transition(
+             obs=obs, act=act, rew=rew, next_obs=next_obs,
+             done=done, timeout=timeout,
+             logprob=logprob, value=value, entropy=entropy
+         )
+
+     def sample(self) -> Dict[str, Any]:
+         """
+         Sample all stored episodes as a single concatenated batch.
+
+         Returns:
+             Dict of MLX arrays with all transitions, in chronological order:
+             - Oldest episode → Newest episode
+             - Each episode: first step → last step
+
+         Raises:
+             ValueError: If buffer is empty
+         """
+         return self.sampler.sample_all()
+
+     def sample_latest_steps(self, n: int) -> Dict[str, Any]:
+         """
+         Sample the N most recent transitions across episodes.
+
+         Returns:
+             Dict of MLX arrays with the latest `n` steps,
+             in **chronological order** (oldest → newest).
+
+         Args:
+             n: Number of steps to sample (must be > 0)
+
+         Raises:
+             ValueError: If buffer is empty or n <= 0
+         """
+         return self.sampler.sample_latest_steps(n)
+
+     def sample_episodes(self, k: int, order: str = 'asc') -> Dict[str, Any]:
+         """
+         Sample up to k complete episodes.
+
+         Args:
+             k: Number of episodes to sample (must be > 0)
+             order: 'asc' for oldest first, 'desc' for newest first
+
+         Returns:
+             Dict of MLX arrays with concatenated transitions from selected episodes.
+
+         Raises:
+             ValueError: If buffer is empty, k <= 0, or invalid order
+         """
+         return self.sampler.sample_episodes(k, order)
+
+     def sample_sequences(
+         self,
+         batch_size: int,
+         seq_len: int,
+         recent_first: bool = True,
+         drop_incomplete: bool = True,
+         dreamerv3_mode: bool = False,
+     ) -> Dict[str, Any]:
+         """
+         Sample contiguous sequences of length `seq_len` for DreamerV3 RSSM training.
+
+         Returns tensors shaped [batch, time, ...] and avoids crossing episode boundaries.
+
+         This method is intentionally minimal and efficient to support Apple Silicon
+         memory patterns and avoids padding logic; set `drop_incomplete=True` to skip
+         short episodes.
+         """
+         return self.sampler.sample_sequences(
+             batch_size=batch_size,
+             seq_len=seq_len,
+             recent_first=recent_first,
+             drop_incomplete=drop_incomplete,
+             dreamerv3_mode=dreamerv3_mode,
+         )
+
+     def add_episode_from_dict(self, data: Dict[str, Any]):
+         """
+         Reconstruct and add an episode from a serialized dictionary.
+
+         This is used to deserialize episodes sent from RolloutWorker.
+
+         Args:
+             data: Dictionary containing episode data (e.g. from `episode.to_dict()`)
+                 Must include: obs, act, rew, next_obs, done, timeout
+                 Optional: logprob, value, entropy
+         """
+         self.storage.add_episode_from_dict(data)
+
+     def clear(self):
+         """
+         Reset the buffer: clear all stored episodes and reset the current episode.
+         """
+         self.storage.clear()
+
+     def ready(self, min_episodes: int = 1) -> bool:
+         """
+         Check if the buffer contains at least `min_episodes` complete episodes.
+
+         Args:
+             min_episodes: Minimum number of episodes required (default: 1)
+
+         Returns:
+             True if the buffer has enough episodes, False otherwise
+         """
+         return self.storage.ready(min_episodes)
+
+     def __len__(self) -> int:
+         """
+         Total number of steps in the buffer.
+
+         Returns:
+             Sum of steps across all stored episodes
+         """
+         return len(self.storage)
+
+     @property
+     def episodes(self):
+         """Access to underlying episodes for backwards compatibility."""
+         return self.storage.episodes
+
+     @property
+     def current_episode(self):
+         """Access to the current incomplete episode for backwards compatibility."""
+         return self.storage.current_episode
+
+     @property
+     def episode_count(self) -> int:
+         """Number of complete episodes currently stored."""
+         return self.storage.episode_count
+
+     def print_state(self, label: str = "Buffer State"):
+         """
+         Print current buffer state. Useful for debugging.
+
+         Args:
+             label: Label to display at the top
+         """
+         info = self.storage.get_storage_info()
+         stats = self.sampler.get_episode_statistics()
+
+         print("=" * 50)
+         print(f"{label}")
+         print(f"Episodes stored : {info['episode_count']} (max={info['max_episodes']})")
+         print(f"Total steps : {info['total_steps']}")
+         print(f"Capacity usage : {info['capacity_usage']:.1%}")
+         if info['episode_lengths']:
+             print(f"Episode lengths : {info['episode_lengths']}")
+             print(f"Mean length : {stats['mean_episode_length']:.1f}")
+             print(f"Mean reward : {stats['mean_episode_reward']:.2f}")
+         print("=" * 50)
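
Usage note (not part of the 0.1.0 wheel contents): the docstrings above describe a collect-then-sample workflow. The sketch below shows one way that workflow could look. It assumes `Buffer` is importable from `textpolicy.buffer` (the package's `__init__.py` is listed among the changed files but not shown in this hunk) and uses a `gymnasium` environment with a random placeholder policy; gymnasium's `terminated`/`truncated` flags map onto the buffer's `done`/`timeout` arguments.

    # Illustrative sketch only; assumes Buffer is re-exported by textpolicy.buffer.
    import gymnasium as gym
    from textpolicy.buffer import Buffer

    env = gym.make("CartPole-v1")
    buffer = Buffer(max_episodes=100)

    obs, _ = env.reset()
    while not buffer.ready(min_episodes=5):
        act = env.action_space.sample()  # placeholder for a real policy
        next_obs, rew, terminated, truncated, _ = env.step(act)
        # Optional fields (logprob, value, entropy) are omitted for every step,
        # which is allowed; providing them for only some steps raises ValueError.
        buffer.add(obs=obs, act=act, rew=rew, next_obs=next_obs,
                   done=terminated, timeout=truncated)
        obs = next_obs if not (terminated or truncated) else env.reset()[0]

    batch = buffer.sample()  # dict of MLX arrays, oldest episode -> newest
    print(batch['obs'].shape, batch['rew'].shape)

Because episodes are only stored once complete, `ready()` and `sample()` never see a partially collected trajectory.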
@@ -0,0 +1,383 @@
+ # textpolicy/buffer/episode.py
+ """
+ Single episode trajectory management.
+
+ The Episode class stores transitions as Python lists during rollout,
+ then converts to MLX arrays only at sampling time. This aims to be optimal
+ for Apple Silicon's unified memory architecture.
+ """
+
+ from typing import Optional, Any, Dict
+ import mlx.core as mx  # type: ignore
+
+
+ class Episode:
+     """
+     Represents a single complete episode trajectory.
+
+     Stores transitions as Python lists during rollout, then converts to MLX arrays
+     only at sampling time. This aims to be optimal for Apple Silicon's unified memory.
+
+     All optional fields (e.g. `logprob`, `value`) must be provided for **all steps**
+     or **none** — mixing will raise an error. This ensures tensor shape consistency.
+
+     Example:
+         ep = Episode()
+         ep.append(obs=1, act=0, rew=1, next_obs=2, done=False, logprob=0.1, value=1.5)
+         ep.append(obs=2, act=1, rew=2, next_obs=3, done=True, logprob=0.2, value=2.5)
+
+         batch = ep.to_tensor_dict()  # Returns dict of MLX arrays
+     """
+
+     def __init__(self):
+         """Initialize empty episode with required fields."""
+         # Required fields - always present
+         self.obs: list[Any] = []
+         self.act: list[Any] = []
+         self.rew: list[Any] = []
+         self.next_obs: list[Any] = []
+         self.done: list[bool] = []
+         self.timeout: list[bool] = []
+
+         # Optional fields - all-or-nothing consistency
+         self.logprob: Optional[list[Any]] = None
+         self.value: Optional[list[Any]] = None
+         self.entropy: Optional[list[Any]] = None
+
+     def append(
+         self,
+         obs,
+         act,
+         rew,
+         next_obs,
+         done,
+         timeout=False,
+         logprob=None,
+         value=None,
+         entropy=None
+     ):
+         """
+         Append a single environment transition to the episode.
+
+         Args:
+             obs: Observation from environment
+             act: Action taken
+             rew: Reward received
+             next_obs: Next observation
+             done: Boolean indicating episode termination
+             timeout: Boolean indicating truncation (e.g. time limit)
+             logprob: Log probability of action (optional, but must be all-or-nothing)
+             value: Estimated state value (optional, but must be all-or-nothing)
+             entropy: Action entropy (optional, but must be all-or-nothing)
+
+         Raises:
+             ValueError: If optional fields are inconsistent (some provided, some missing)
+
+         Example:
+             episode.append(obs=obs, act=act, rew=rew, next_obs=next_obs, done=done)
+         """
+         # Store required fields
+         self.obs.append(obs)
+         self.act.append(act)
+         self.rew.append(rew)
+         self.next_obs.append(next_obs)
+         self.done.append(done)
+         self.timeout.append(timeout)
+
+         # Handle logprob: must be all-or-nothing
+         if logprob is not None:
+             if self.logprob is None:
+                 self.logprob = []
+             self.logprob.append(logprob)
+         else:
+             if self.logprob is not None:
+                 raise ValueError(
+                     "This episode includes logprob, but one step is missing it. "
+                     "Either provide logprob for all steps or none."
+                 )
+
+         # Handle value: must be all-or-nothing
+         if value is not None:
+             if self.value is None:
+                 self.value = []
+             self.value.append(value)
+         else:
+             if self.value is not None:
+                 raise ValueError(
+                     "This episode includes value, but one step is missing it. "
+                     "Either provide value for all steps or none."
+                 )
+
+         # Handle entropy: must be all-or-nothing
+         if entropy is not None:
+             if self.entropy is None:
+                 self.entropy = []
+             self.entropy.append(entropy)
+         else:
+             if self.entropy is not None:
+                 raise ValueError(
+                     "This episode includes entropy, but one step is missing it. "
+                     "Either provide entropy for all steps or none."
+                 )
+
+     def __len__(self) -> int:
+         """Return the number of steps in this episode."""
+         return len(self.obs)
+
+     def to_tensor_dict(self) -> Dict[str, mx.array]:
+         """
+         Convert all stored data to MLX arrays for training.
+         Performed once at sample time for efficiency on Apple Silicon and MLX.
+
+         Returns:
+             Dict of MLX arrays with keys:
+             - 'obs': (T, *obs_shape) - observations
+             - 'act': (T, *act_shape) - actions
+             - 'rew': (T,) - rewards
+             - 'next_obs': (T, *obs_shape) - next observations
+             - 'done': (T,) - termination flags
+             - 'timeout': (T,) - truncation flags
+             - 'logprob': (T,) - log probabilities (if provided)
+             - 'value': (T,) - value estimates (if provided)
+             - 'entropy': (T,) - action entropy (if provided)
+
+         Notes:
+             This runs once at sample time and uses batched array conversion.
+         """
+         # Batched array conversion for improved memory efficiency
+         # Convert to numpy first, then a single MLX array
+         import numpy as np
+
+         # Convert required fields to MLX arrays - BATCHED APPROACH
+         try:
+             # Try numpy-based batched conversion first (most efficient)
+             obs_np = np.array(self.obs)
+             next_obs_np = np.array(self.next_obs)
+             act_np = np.array(self.act)
+
+             result = {
+                 'obs': mx.array(obs_np),  # Single batched conversion
+                 'act': mx.array(act_np),  # Single batched conversion
+                 'rew': mx.array(self.rew),  # Already efficient for scalars
+                 'next_obs': mx.array(next_obs_np),  # Single batched conversion
+                 'done': mx.array(self.done),  # Already efficient for booleans
+                 'timeout': mx.array(self.timeout),  # Already efficient for booleans
+             }
+         except (ValueError, TypeError):
+             # Batch conversion fallback with pre-allocation
+             # (for heterogeneous data types or complex structures)
+             try:
+                 # Try batch conversion first (faster for homogeneous data)
+                 import numpy as np
+                 result = {
+                     'obs': mx.array(np.array(self.obs)),
+                     'act': mx.array(np.array(self.act)),
+                     'rew': mx.array(self.rew),
+                     'next_obs': mx.array(np.array(self.next_obs)),
+                     'done': mx.array(self.done),
+                     'timeout': mx.array(self.timeout),
+                 }
+             except:
+                 # Fallback for heterogeneous data - try stacking first
+                 try:
+                     result = {
+                         'obs': mx.stack([mx.array(o) for o in self.obs]),
+                         'act': mx.stack([mx.array(a) for a in self.act]),
+                         'rew': mx.array(self.rew),
+                         'next_obs': mx.stack([mx.array(o) for o in self.next_obs]),
+                         'done': mx.array(self.done),
+                         'timeout': mx.array(self.timeout),
+                     }
+                 except:
+                     # Final fallback for truly heterogeneous shapes - return as list of arrays
+                     # This handles cases where observations have completely different shapes
+                     result = {
+                         'obs': [mx.array(o) for o in self.obs],
+                         'act': [mx.array(a) for a in self.act] if not all(isinstance(a, (int, float)) for a in self.act) else mx.array(self.act),
+                         'rew': mx.array(self.rew),
+                         'next_obs': [mx.array(o) for o in self.next_obs],
+                         'done': mx.array(self.done),
+                         'timeout': mx.array(self.timeout),
+                     }
+
+         # Add optional fields if present - handle variable-length sequences properly
+         if self.logprob is not None:
+             # Handle variable-length logprob sequences (common in text generation)
+             # Each transition may have different response lengths, so we flatten them
+             try:
+                 # Try direct conversion first (for uniform lengths)
+                 result['logprob'] = mx.array(self.logprob)
+             except ValueError as e:
+                 if "non-uniform length" in str(e):
+                     # Handle variable-length sequences by flattening
+                     # This preserves all logprob data while making it MLX-compatible
+                     flattened_logprobs = []
+                     for logprob_item in self.logprob:
+                         if hasattr(logprob_item, 'tolist'):  # MLX array
+                             flattened_logprobs.extend(logprob_item.tolist())
+                         elif isinstance(logprob_item, list):  # Python list
+                             flattened_logprobs.extend(logprob_item)
+                         else:  # Single value
+                             flattened_logprobs.append(float(logprob_item))
+                     result['logprob'] = mx.array(flattened_logprobs) if flattened_logprobs else mx.array([])
+                 else:
+                     # Re-raise other ValueError types
+                     raise
+
+         if self.value is not None:
+             # Apply same variable-length handling to value if needed
+             try:
+                 result['value'] = mx.array(self.value)
+             except ValueError as e:
+                 if "non-uniform length" in str(e):
+                     flattened_values = []
+                     for value_item in self.value:
+                         if hasattr(value_item, 'tolist'):  # MLX array
+                             flattened_values.extend(value_item.tolist())
+                         elif isinstance(value_item, list):  # Python list
+                             flattened_values.extend(value_item)
+                         else:  # Single value
+                             flattened_values.append(float(value_item))
+                     result['value'] = mx.array(flattened_values) if flattened_values else mx.array([])
+                 else:
+                     raise
+
+         if self.entropy is not None:
+             # Apply same variable-length handling to entropy if needed
+             try:
+                 result['entropy'] = mx.array(self.entropy)
+             except ValueError as e:
+                 if "non-uniform length" in str(e):
+                     flattened_entropy = []
+                     for entropy_item in self.entropy:
+                         if hasattr(entropy_item, 'tolist'):  # MLX array
+                             flattened_entropy.extend(entropy_item.tolist())
+                         elif isinstance(entropy_item, list):  # Python list
+                             flattened_entropy.extend(entropy_item)
+                         else:  # Single value
+                             flattened_entropy.append(float(entropy_item))
+                     result['entropy'] = mx.array(flattened_entropy) if flattened_entropy else mx.array([])
+                 else:
+                     raise
+
+         return result
+
+     def to_dict(self) -> Dict[str, Any]:
+         """
+         Convert episode to dictionary for serialization (multiprocessing).
+
+         Used for inter-process communication where MLX arrays can't be shared.
+         This preserves all data as Python-native types for queue transmission.
+
+         Returns:
+             Dictionary representation with all Python-native types.
+             This is the inverse of creating an episode from a dict.
+
+         Example:
+             # In worker process
+             ep_dict = episode.to_dict()
+             queue.put(ep_dict)
+
+             # In trainer process
+             buffer.add_episode_from_dict(ep_dict)
+         """
+         # Always include required fields
+         result = {
+             'obs': self.obs,
+             'act': self.act,
+             'rew': self.rew,
+             'next_obs': self.next_obs,
+             'done': self.done,
+             'timeout': self.timeout,
+         }
+
+         # Add optional fields if present
+         if self.logprob is not None:
+             result['logprob'] = self.logprob
+         if self.value is not None:
+             result['value'] = self.value
+         if self.entropy is not None:
+             result['entropy'] = self.entropy
+
+         return result
+
+     @classmethod
+     def from_dict(cls, data: Dict[str, Any]) -> 'Episode':
+         """
+         Create Episode from dictionary representation (for deserialization).
+
+         This is the inverse of to_dict() - reconstructs an Episode from
+         serialized dictionary data, typically used after inter-process
+         communication where Episode objects are transmitted as dicts.
+
+         Args:
+             data: Dictionary containing episode data with Python-native types
+
+         Returns:
+             New Episode instance with data from the dictionary
+
+         Example:
+             # Reconstruct episode from serialized data
+             episode = Episode.from_dict(ep_dict)
+         """
+         episode = cls()
+
+         # Reconstruct episode by appending each step
+         length = len(data['obs'])
+         for i in range(length):
+             step_data = {
+                 'obs': data['obs'][i],
+                 'act': data['act'][i],
+                 'rew': data['rew'][i],
+                 'next_obs': data['next_obs'][i],
+                 'done': data['done'][i],
+                 'timeout': data['timeout'][i] if i < len(data['timeout']) else False
+             }
+
+             # Add optional fields if present in the data
+             if 'logprob' in data and i < len(data['logprob']):
+                 step_data['logprob'] = data['logprob'][i]
+             if 'value' in data and i < len(data['value']):
+                 step_data['value'] = data['value'][i]
+             if 'entropy' in data and i < len(data['entropy']):
+                 step_data['entropy'] = data['entropy'][i]
+
+             episode.append(**step_data)
+
+         return episode
+
+     def validate_consistency(self):
+         """
+         Validate internal consistency of episode data.
+
+         Checks:
+         - All required fields have same length
+         - Optional fields have correct length if present
+         - Episode has at least one step
+
+         Raises:
+             ValueError: If episode data is inconsistent
+         """
+         if len(self) == 0:
+             raise ValueError("Episode is empty")
+
+         # Check required fields have consistent length
+         required_lengths = [
+             len(self.obs), len(self.act), len(self.rew),
+             len(self.next_obs), len(self.done), len(self.timeout)
+         ]
+
+         if not all(length == required_lengths[0] for length in required_lengths):
+             raise ValueError(f"Inconsistent required field lengths: {required_lengths}")
+
+         # Check optional fields have correct length if present
+         episode_length = len(self.obs)
+
+         if self.logprob is not None and len(self.logprob) != episode_length:
+             raise ValueError(f"logprob length {len(self.logprob)} != episode length {episode_length}")
+
+         if self.value is not None and len(self.value) != episode_length:
+             raise ValueError(f"value length {len(self.value)} != episode length {episode_length}")
+
+         if self.entropy is not None and len(self.entropy) != episode_length:
+             raise ValueError(f"entropy length {len(self.entropy)} != episode length {episode_length}")
+ raise ValueError(f"entropy length {len(self.entropy)} != episode length {episode_length}")