openadapt-ml 0.2.0__py3-none-any.whl → 0.2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (95):
  1. openadapt_ml/baselines/__init__.py +121 -0
  2. openadapt_ml/baselines/adapter.py +185 -0
  3. openadapt_ml/baselines/cli.py +314 -0
  4. openadapt_ml/baselines/config.py +448 -0
  5. openadapt_ml/baselines/parser.py +922 -0
  6. openadapt_ml/baselines/prompts.py +787 -0
  7. openadapt_ml/benchmarks/__init__.py +13 -115
  8. openadapt_ml/benchmarks/agent.py +265 -421
  9. openadapt_ml/benchmarks/azure.py +28 -19
  10. openadapt_ml/benchmarks/azure_ops_tracker.py +521 -0
  11. openadapt_ml/benchmarks/cli.py +1722 -4847
  12. openadapt_ml/benchmarks/trace_export.py +631 -0
  13. openadapt_ml/benchmarks/viewer.py +22 -5
  14. openadapt_ml/benchmarks/vm_monitor.py +530 -29
  15. openadapt_ml/benchmarks/waa_deploy/Dockerfile +47 -53
  16. openadapt_ml/benchmarks/waa_deploy/api_agent.py +21 -20
  17. openadapt_ml/cloud/azure_inference.py +3 -5
  18. openadapt_ml/cloud/lambda_labs.py +722 -307
  19. openadapt_ml/cloud/local.py +2038 -487
  20. openadapt_ml/cloud/ssh_tunnel.py +68 -26
  21. openadapt_ml/datasets/next_action.py +40 -30
  22. openadapt_ml/evals/grounding.py +8 -3
  23. openadapt_ml/evals/plot_eval_metrics.py +15 -13
  24. openadapt_ml/evals/trajectory_matching.py +41 -26
  25. openadapt_ml/experiments/demo_prompt/format_demo.py +16 -6
  26. openadapt_ml/experiments/demo_prompt/run_experiment.py +26 -16
  27. openadapt_ml/experiments/representation_shootout/__init__.py +70 -0
  28. openadapt_ml/experiments/representation_shootout/conditions.py +708 -0
  29. openadapt_ml/experiments/representation_shootout/config.py +390 -0
  30. openadapt_ml/experiments/representation_shootout/evaluator.py +659 -0
  31. openadapt_ml/experiments/representation_shootout/runner.py +687 -0
  32. openadapt_ml/experiments/waa_demo/runner.py +29 -14
  33. openadapt_ml/export/parquet.py +36 -24
  34. openadapt_ml/grounding/detector.py +18 -14
  35. openadapt_ml/ingest/__init__.py +8 -6
  36. openadapt_ml/ingest/capture.py +25 -22
  37. openadapt_ml/ingest/loader.py +7 -4
  38. openadapt_ml/ingest/synthetic.py +189 -100
  39. openadapt_ml/models/api_adapter.py +14 -4
  40. openadapt_ml/models/base_adapter.py +10 -2
  41. openadapt_ml/models/providers/__init__.py +288 -0
  42. openadapt_ml/models/providers/anthropic.py +266 -0
  43. openadapt_ml/models/providers/base.py +299 -0
  44. openadapt_ml/models/providers/google.py +376 -0
  45. openadapt_ml/models/providers/openai.py +342 -0
  46. openadapt_ml/models/qwen_vl.py +46 -19
  47. openadapt_ml/perception/__init__.py +35 -0
  48. openadapt_ml/perception/integration.py +399 -0
  49. openadapt_ml/retrieval/demo_retriever.py +50 -24
  50. openadapt_ml/retrieval/embeddings.py +9 -8
  51. openadapt_ml/retrieval/retriever.py +3 -1
  52. openadapt_ml/runtime/__init__.py +50 -0
  53. openadapt_ml/runtime/policy.py +18 -5
  54. openadapt_ml/runtime/safety_gate.py +471 -0
  55. openadapt_ml/schema/__init__.py +9 -0
  56. openadapt_ml/schema/converters.py +74 -27
  57. openadapt_ml/schema/episode.py +31 -18
  58. openadapt_ml/scripts/capture_screenshots.py +530 -0
  59. openadapt_ml/scripts/compare.py +85 -54
  60. openadapt_ml/scripts/demo_policy.py +4 -1
  61. openadapt_ml/scripts/eval_policy.py +15 -9
  62. openadapt_ml/scripts/make_gif.py +1 -1
  63. openadapt_ml/scripts/prepare_synthetic.py +3 -1
  64. openadapt_ml/scripts/train.py +21 -9
  65. openadapt_ml/segmentation/README.md +920 -0
  66. openadapt_ml/segmentation/__init__.py +97 -0
  67. openadapt_ml/segmentation/adapters/__init__.py +5 -0
  68. openadapt_ml/segmentation/adapters/capture_adapter.py +420 -0
  69. openadapt_ml/segmentation/annotator.py +610 -0
  70. openadapt_ml/segmentation/cache.py +290 -0
  71. openadapt_ml/segmentation/cli.py +674 -0
  72. openadapt_ml/segmentation/deduplicator.py +656 -0
  73. openadapt_ml/segmentation/frame_describer.py +788 -0
  74. openadapt_ml/segmentation/pipeline.py +340 -0
  75. openadapt_ml/segmentation/schemas.py +622 -0
  76. openadapt_ml/segmentation/segment_extractor.py +634 -0
  77. openadapt_ml/training/azure_ops_viewer.py +1097 -0
  78. openadapt_ml/training/benchmark_viewer.py +52 -41
  79. openadapt_ml/training/shared_ui.py +7 -7
  80. openadapt_ml/training/stub_provider.py +57 -35
  81. openadapt_ml/training/trainer.py +143 -86
  82. openadapt_ml/training/trl_trainer.py +70 -21
  83. openadapt_ml/training/viewer.py +323 -108
  84. openadapt_ml/training/viewer_components.py +180 -0
  85. {openadapt_ml-0.2.0.dist-info → openadapt_ml-0.2.2.dist-info}/METADATA +215 -14
  86. openadapt_ml-0.2.2.dist-info/RECORD +116 -0
  87. openadapt_ml/benchmarks/base.py +0 -366
  88. openadapt_ml/benchmarks/data_collection.py +0 -432
  89. openadapt_ml/benchmarks/live_tracker.py +0 -180
  90. openadapt_ml/benchmarks/runner.py +0 -418
  91. openadapt_ml/benchmarks/waa.py +0 -761
  92. openadapt_ml/benchmarks/waa_live.py +0 -619
  93. openadapt_ml-0.2.0.dist-info/RECORD +0 -86
  94. {openadapt_ml-0.2.0.dist-info → openadapt_ml-0.2.2.dist-info}/WHEEL +0 -0
  95. {openadapt_ml-0.2.0.dist-info → openadapt_ml-0.2.2.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,622 @@
1
+ """Data schemas for workflow segmentation.
2
+
3
+ This module defines the Pydantic models used throughout the
4
+ segmentation pipeline, ensuring type safety and validation.
5
+
6
+ In OpenAdapt terminology:
7
+ - "Episode" = A coherent workflow segment
8
+ - "Trajectory" = Sequence of observation-action pairs (full recording)
9
+ """
10
+
11
+ from datetime import datetime
12
+ from enum import Enum
13
+ from typing import Optional
14
+ from uuid import UUID, uuid4
15
+
16
+ from pydantic import BaseModel, Field, ConfigDict
17
+
18
+
19
class ActionType(str, Enum):
    """Types of user actions that can be captured.

    Mixing in ``str`` means each member compares equal to its literal
    value (e.g. ``ActionType.CLICK == "click"``) and serializes as a
    plain string.
    """

    CLICK = "click"
    DOUBLE_CLICK = "double_click"
    RIGHT_CLICK = "right_click"
    TYPE = "type"  # keyboard text entry
    SCROLL = "scroll"
    DRAG = "drag"
    HOTKEY = "hotkey"  # keyboard shortcut / chord
    MOVE = "move"  # pointer movement without a click
30
+
31
+
32
class FrameDescription(BaseModel):
    """Description of a single frame + action pair from VLM analysis.

    This is the output of Stage 1 for each frame in the recording.
    Each field's semantics are documented via ``Field(description=...)``.
    """

    # NOTE(review): no field here declares a non-Pydantic type; this may be
    # removable — confirm before changing.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    # Timing
    timestamp: float = Field(description="Timestamp in seconds from recording start")
    formatted_time: str = Field(description="Human-readable time format (MM:SS.m)")

    # Screen context
    visible_application: str = Field(
        description="Primary application visible on screen"
    )
    visible_elements: list[str] = Field(
        default_factory=list,
        description="Notable UI elements visible in the frame",
    )
    screen_context: str = Field(description="Brief description of overall screen state")

    # Action details
    action_type: ActionType = Field(description="Type of action performed")
    action_target: Optional[str] = Field(
        default=None,
        description="UI element that was the target of the action",
    )
    action_value: Optional[str] = Field(
        default=None,
        description="Value associated with action (e.g., typed text)",
    )

    # Semantic interpretation
    apparent_intent: str = Field(
        description="What the user appears to be trying to accomplish"
    )
    confidence: float = Field(
        ge=0.0, le=1.0, description="VLM confidence in this description"
    )

    # Metadata
    frame_index: int = Field(description="Index of this frame in the recording")
    vlm_model: str = Field(description="Model used for description generation")

    def to_transcript_line(self) -> str:
        """Format as a single transcript line: ``[<formatted_time>] <apparent_intent>``."""
        return f"[{self.formatted_time}] {self.apparent_intent}"
80
+
81
+
82
class ActionTranscript(BaseModel):
    """Complete transcript of a recording from VLM analysis.

    This is the full output of Stage 1: per-frame descriptions plus
    recording-level metadata.
    """

    recording_id: str = Field(description="Unique identifier for the source recording")
    recording_name: str = Field(description="Human-readable recording name")
    task_description: Optional[str] = Field(
        default=None,
        description="User-provided task description (if available)",
    )

    # Frame descriptions
    frames: list[FrameDescription] = Field(
        default_factory=list,
        description="Ordered list of frame descriptions",
    )

    # Processing metadata
    total_duration: float = Field(description="Total recording duration in seconds")
    frame_count: int = Field(description="Total number of frames processed")
    vlm_model: str = Field(description="Primary VLM model used")
    processing_timestamp: datetime = Field(
        default_factory=datetime.now,
        description="When this transcript was generated",
    )

    def to_transcript_text(self) -> str:
        """Return the transcript as plain text, one line per frame.

        Lines come from ``FrameDescription.to_transcript_line()`` and are
        joined with newlines in frame order.
        """
        # str.join over a generator replaces the manual append loop (PERF401).
        return "\n".join(frame.to_transcript_line() for frame in self.frames)

    @property
    def duration_formatted(self) -> str:
        """Return total_duration formatted as ``MM:SS.ss`` (e.g. ``01:23.50``)."""
        minutes = int(self.total_duration // 60)
        seconds = self.total_duration % 60
        return f"{minutes:02d}:{seconds:05.2f}"
123
+
124
+
125
class EpisodeStep(BaseModel):
    """A single step within an episode (workflow segment)."""

    description: str = Field(description="What this step accomplishes")
    # NOTE(review): timestamps appear to be seconds from recording start,
    # matching FrameDescription.timestamp — confirm with the producer.
    start_timestamp: float = Field(description="Start time in seconds")
    end_timestamp: float = Field(description="End time in seconds")
    frame_indices: list[int] = Field(
        default_factory=list,
        description="Indices of frames belonging to this step",
    )
135
+
136
+
137
class EpisodeBoundary(BaseModel):
    """Represents a boundary between episodes with confidence.

    Produced during episode extraction; a boundary marks where one
    workflow segment is judged to end and the next begins.
    """

    timestamp: float = Field(description="Time of the boundary")
    confidence: float = Field(
        ge=0.0,
        le=1.0,
        description="Confidence that this is a true episode boundary",
    )
    reason: str = Field(description="Explanation for why this is a boundary")
147
+
148
+
149
class Episode(BaseModel):
    """A coherent workflow segment (episode) extracted from a recording.

    This is the output of Stage 2 for each identified workflow.

    In OpenAdapt, an Episode represents a self-contained unit of work
    that can be used for:
    - Training data for fine-tuning
    - Demo conditioning context
    - Workflow library building
    """

    # NOTE(review): UUID is natively supported by Pydantic v2; this setting
    # may be unnecessary — confirm before removing.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    # Identification
    episode_id: UUID = Field(
        default_factory=uuid4,
        description="Unique identifier for this episode",
    )
    name: str = Field(
        description="Concise name for this workflow (e.g., 'Adjust Night Shift Settings')"
    )

    # Timing
    start_time: float = Field(description="Start timestamp in seconds")
    end_time: float = Field(description="End timestamp in seconds")
    start_time_formatted: str = Field(description="Formatted start time (MM:SS.m)")
    end_time_formatted: str = Field(description="Formatted end time (MM:SS.m)")

    # Content
    description: str = Field(
        description="Detailed description of what this workflow accomplishes"
    )
    steps: list[EpisodeStep] = Field(
        default_factory=list,
        description="Ordered list of steps in this workflow",
    )
    step_summaries: list[str] = Field(
        default_factory=list,
        description="Simple list of step descriptions for quick reference",
    )

    # Context
    application: str = Field(description="Primary application used in this workflow")
    prerequisites: list[str] = Field(
        default_factory=list,
        description="Conditions that must be true before starting",
    )
    outcomes: list[str] = Field(
        default_factory=list,
        description="Expected state changes after completion",
    )

    # Hierarchy (optional parent/child links between episodes)
    parent_episode_id: Optional[UUID] = Field(
        default=None,
        description="Parent episode if this is a subtask",
    )
    child_episode_ids: list[UUID] = Field(
        default_factory=list,
        description="Child episodes if this contains subtasks",
    )

    # Quality metrics
    boundary_confidence: float = Field(
        ge=0.0,
        le=1.0,
        description="Confidence in episode boundaries",
    )
    coherence_score: float = Field(
        ge=0.0,
        le=1.0,
        description="How coherent/self-contained this episode is",
    )

    # Source
    recording_id: str = Field(description="Source recording identifier")
    frame_indices: list[int] = Field(
        default_factory=list,
        description="Indices of frames in this episode",
    )

    @property
    def duration(self) -> float:
        """Episode duration in seconds.

        No validator enforces ``end_time >= start_time``, so this can be
        negative for malformed input.
        """
        return self.end_time - self.start_time

    @property
    def step_count(self) -> int:
        """Number of steps in this episode."""
        return len(self.steps)
240
+
241
+
242
class EpisodeExtractionResult(BaseModel):
    """Complete extraction result for a single recording.

    This is the full output of Stage 2: the episodes found in one
    recording, the boundaries between them, and quality metrics.
    """

    recording_id: str = Field(description="Source recording identifier")
    recording_name: str = Field(description="Human-readable recording name")

    # Extracted episodes
    episodes: list[Episode] = Field(
        default_factory=list,
        description="Extracted workflow episodes",
    )

    # Boundaries
    boundaries: list[EpisodeBoundary] = Field(
        default_factory=list,
        description="All identified episode boundaries",
    )

    # Processing metadata
    llm_model: str = Field(description="LLM model used for extraction")
    # NOTE(review): datetime.now is naive (no tz); timestamps from machines in
    # different timezones are not comparable — consider an aware default.
    processing_timestamp: datetime = Field(default_factory=datetime.now)

    # Quality metrics
    coverage: float = Field(
        ge=0.0,
        le=1.0,
        description="Fraction of recording covered by episodes",
    )
    avg_confidence: float = Field(
        ge=0.0,
        le=1.0,
        description="Average boundary confidence",
    )
278
+
279
+
280
class CanonicalEpisode(BaseModel):
    """A deduplicated, canonical episode definition.

    This represents a workflow type that may appear across multiple
    recordings; variants observed in individual recordings are merged
    into one canonical record with provenance tracking.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)

    # Identification
    canonical_id: UUID = Field(
        default_factory=uuid4,
        description="Unique identifier for this canonical episode",
    )
    canonical_name: str = Field(description="Standardized name for this workflow")

    # Variants (names/descriptions from the episodes that were merged)
    variant_names: list[str] = Field(
        default_factory=list,
        description="Alternative names from merged episodes",
    )
    variant_descriptions: list[str] = Field(
        default_factory=list,
        description="Alternative descriptions from merged episodes",
    )

    # Source tracking
    source_recordings: list[str] = Field(
        default_factory=list,
        description="Recording IDs containing this workflow",
    )
    source_episode_ids: list[UUID] = Field(
        default_factory=list,
        description="Original episode IDs that were merged",
    )
    occurrence_count: int = Field(
        ge=1,
        description="Number of times this workflow appears",
    )

    # Canonical definition
    canonical_description: str = Field(
        description="Best/merged description of this workflow"
    )
    canonical_steps: list[str] = Field(
        default_factory=list,
        description="Standardized step list",
    )

    # Embedding
    embedding: Optional[list[float]] = Field(
        default=None,
        description="Vector embedding for similarity search",
    )

    # Clustering metadata
    cluster_id: int = Field(default=0, description="Cluster ID from deduplication")
    cluster_centroid_distance: float = Field(
        default=0.0,
        ge=0.0,
        description="Distance from cluster centroid",
    )

    # Quality
    internal_similarity: float = Field(
        default=1.0,
        ge=0.0,
        le=1.0,
        description="Average similarity between merged variants",
    )
349
+
350
+
351
class EpisodeAnnotation(BaseModel):
    """Annotation for an episode indicating its quality for training.

    This model is used to mark episodes as "gold" (suitable for training)
    or exclude them with reasons. VLM-based auto-annotation can populate
    initial values, which humans can then verify.

    Attributes:
        annotation_id: Unique identifier for this annotation
        episode_id: ID of the Episode being annotated
        start_frame: Exact start frame index (refined from Episode)
        end_frame: Exact end frame index (refined from Episode)
        is_gold: Whether this episode should be included in training export
        exclusion_reason: Why this episode was excluded (if not gold)
        confidence: VLM confidence in the annotation (0-1)
        human_verified: Whether a human has confirmed this annotation
        notes: Optional human notes about the episode
        failure_signals: Detected failure signals from post-episode analysis
        created_at: When this annotation was created
        verified_at: When a human verified this annotation
        verified_by: Who verified this annotation
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)

    # Identification
    annotation_id: UUID = Field(
        default_factory=uuid4,
        description="Unique identifier for this annotation",
    )
    episode_id: UUID = Field(
        description="ID of the Episode being annotated",
    )

    # Refined boundaries
    # NOTE(review): no cross-field validator enforces end_frame >= start_frame.
    start_frame: int = Field(
        ge=0,
        description="Exact start frame index",
    )
    end_frame: int = Field(
        ge=0,
        description="Exact end frame index",
    )

    # Quality assessment
    is_gold: bool = Field(
        default=False,
        description="Should this episode be included in training export?",
    )
    exclusion_reason: Optional[str] = Field(
        default=None,
        description="Why this episode was excluded (e.g., 'task failed', 'incomplete', 'error visible')",
    )
    confidence: float = Field(
        ge=0.0,
        le=1.0,
        default=0.5,
        description="VLM confidence in the annotation",
    )

    # Human verification
    human_verified: bool = Field(
        default=False,
        description="Has a human confirmed this annotation?",
    )
    notes: Optional[str] = Field(
        default=None,
        description="Optional human notes about the episode",
    )

    # Failure detection
    failure_signals: list[str] = Field(
        default_factory=list,
        description="Detected failure signals from post-episode analysis",
    )

    # Metadata
    created_at: datetime = Field(
        default_factory=datetime.now,
        description="When this annotation was created",
    )
    verified_at: Optional[datetime] = Field(
        default=None,
        description="When a human verified this annotation",
    )
    verified_by: Optional[str] = Field(
        default=None,
        description="Who verified this annotation",
    )
440
+
441
+
442
class AnnotatedEpisodeLibrary(BaseModel):
    """Collection of episodes with their annotations.

    Supports reviewing, exporting, and managing annotated episodes:
    counting statistics, looking up episodes/annotations by ID, and
    filtering episode/annotation pairs by review state.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)

    # Identification
    library_id: UUID = Field(
        default_factory=uuid4,
        description="Unique identifier for this library",
    )
    created_at: datetime = Field(default_factory=datetime.now)

    # Content
    episodes: list[Episode] = Field(
        default_factory=list,
        description="All episodes in this library",
    )
    annotations: list[EpisodeAnnotation] = Field(
        default_factory=list,
        description="Annotations for episodes",
    )

    # Source tracking
    source_recordings: list[str] = Field(
        default_factory=list,
        description="Recording IDs that were processed",
    )

    # Statistics
    @property
    def total_episodes(self) -> int:
        """Count of all episodes held in this library."""
        return len(self.episodes)

    @property
    def annotated_count(self) -> int:
        """Count of distinct episodes that have at least one annotation."""
        return len({ann.episode_id for ann in self.annotations})

    @property
    def gold_count(self) -> int:
        """Count of annotations marked gold."""
        return sum(1 for ann in self.annotations if ann.is_gold)

    @property
    def verified_count(self) -> int:
        """Count of annotations a human has verified."""
        return sum(1 for ann in self.annotations if ann.human_verified)

    @property
    def export_ready_count(self) -> int:
        """Count of annotations that are both gold and human-verified."""
        return sum(
            1 for ann in self.annotations if ann.is_gold and ann.human_verified
        )

    def get_annotation(self, episode_id: UUID) -> Optional[EpisodeAnnotation]:
        """Return the annotation for ``episode_id``, or None if absent."""
        return next(
            (ann for ann in self.annotations if ann.episode_id == episode_id),
            None,
        )

    def get_episode(self, episode_id: UUID) -> Optional[Episode]:
        """Return the episode with ``episode_id``, or None if absent."""
        return next(
            (ep for ep in self.episodes if ep.episode_id == episode_id),
            None,
        )

    def _episodes_matching(self, predicate) -> list[tuple[Episode, EpisodeAnnotation]]:
        # Pair each annotation satisfying `predicate` with its episode,
        # skipping annotations whose episode is not in this library.
        pairs: list[tuple[Episode, EpisodeAnnotation]] = []
        for ann in self.annotations:
            if predicate(ann):
                ep = self.get_episode(ann.episode_id)
                if ep:
                    pairs.append((ep, ann))
        return pairs

    def get_gold_episodes(self) -> list[tuple[Episode, EpisodeAnnotation]]:
        """Return all gold episodes with their annotations."""
        return self._episodes_matching(lambda ann: ann.is_gold)

    def get_verified_gold_episodes(self) -> list[tuple[Episode, EpisodeAnnotation]]:
        """Return episodes that are both gold AND human-verified."""
        return self._episodes_matching(
            lambda ann: ann.is_gold and ann.human_verified
        )

    def get_pending_review(self) -> list[tuple[Episode, EpisodeAnnotation]]:
        """Return annotated episodes still awaiting human verification."""
        return self._episodes_matching(lambda ann: not ann.human_verified)

    def to_dict(self) -> dict:
        """Convert to dictionary for JSON serialization."""
        return self.model_dump(mode="json")

    @classmethod
    def from_dict(cls, data: dict) -> "AnnotatedEpisodeLibrary":
        """Create from dictionary."""
        return cls.model_validate(data)
552
+
553
+
554
class EpisodeLibrary(BaseModel):
    """Complete deduplicated episode library.

    This is the final output of Stage 3 - a library of canonical
    workflow episodes that can be used for training data curation,
    demo conditioning, and workflow retrieval.
    """

    # Library metadata
    library_id: UUID = Field(
        default_factory=uuid4,
        description="Unique identifier for this library version",
    )
    created_at: datetime = Field(default_factory=datetime.now)

    # Workflows
    episodes: list[CanonicalEpisode] = Field(
        default_factory=list,
        description="All canonical episodes",
    )

    # Statistics
    total_recordings_processed: int = Field(
        ge=0,
        description="Number of recordings analyzed",
    )
    total_episodes_extracted: int = Field(
        ge=0,
        description="Total episodes before deduplication",
    )
    unique_episode_count: int = Field(
        ge=0,
        description="Number of unique episodes after deduplication",
    )
    deduplication_ratio: float = Field(
        ge=0.0,
        le=1.0,
        description="Fraction of episodes that were duplicates",
    )

    # Processing parameters
    similarity_threshold: float = Field(
        ge=0.0,
        le=1.0,
        description="Threshold used for clustering",
    )
    embedding_model: str = Field(description="Model used for embeddings")

    def get_episode_by_name(self, name: str) -> Optional[CanonicalEpisode]:
        """Find an episode whose canonical name or any variant name equals
        ``name`` (case-insensitive); return None when nothing matches."""
        wanted = name.lower()  # hoist: lowercase the query once
        for candidate in self.episodes:
            if candidate.canonical_name.lower() == wanted:
                return candidate
            if any(v.lower() == wanted for v in candidate.variant_names):
                return candidate
        return None

    def get_episodes_for_recording(self, recording_id: str) -> list[CanonicalEpisode]:
        """Return every canonical episode observed in the given recording."""
        return [
            candidate
            for candidate in self.episodes
            if recording_id in candidate.source_recordings
        ]

    def to_dict(self) -> dict:
        """Convert to dictionary for JSON serialization."""
        return self.model_dump(mode="json")

    @classmethod
    def from_dict(cls, data: dict) -> "EpisodeLibrary":
        """Create from dictionary."""
        return cls.model_validate(data)