openadapt-ml 0.2.0__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (95)
  1. openadapt_ml/baselines/__init__.py +121 -0
  2. openadapt_ml/baselines/adapter.py +185 -0
  3. openadapt_ml/baselines/cli.py +314 -0
  4. openadapt_ml/baselines/config.py +448 -0
  5. openadapt_ml/baselines/parser.py +922 -0
  6. openadapt_ml/baselines/prompts.py +787 -0
  7. openadapt_ml/benchmarks/__init__.py +13 -115
  8. openadapt_ml/benchmarks/agent.py +265 -421
  9. openadapt_ml/benchmarks/azure.py +28 -19
  10. openadapt_ml/benchmarks/azure_ops_tracker.py +521 -0
  11. openadapt_ml/benchmarks/cli.py +1722 -4847
  12. openadapt_ml/benchmarks/trace_export.py +631 -0
  13. openadapt_ml/benchmarks/viewer.py +22 -5
  14. openadapt_ml/benchmarks/vm_monitor.py +530 -29
  15. openadapt_ml/benchmarks/waa_deploy/Dockerfile +47 -53
  16. openadapt_ml/benchmarks/waa_deploy/api_agent.py +21 -20
  17. openadapt_ml/cloud/azure_inference.py +3 -5
  18. openadapt_ml/cloud/lambda_labs.py +722 -307
  19. openadapt_ml/cloud/local.py +2038 -487
  20. openadapt_ml/cloud/ssh_tunnel.py +68 -26
  21. openadapt_ml/datasets/next_action.py +40 -30
  22. openadapt_ml/evals/grounding.py +8 -3
  23. openadapt_ml/evals/plot_eval_metrics.py +15 -13
  24. openadapt_ml/evals/trajectory_matching.py +41 -26
  25. openadapt_ml/experiments/demo_prompt/format_demo.py +16 -6
  26. openadapt_ml/experiments/demo_prompt/run_experiment.py +26 -16
  27. openadapt_ml/experiments/representation_shootout/__init__.py +70 -0
  28. openadapt_ml/experiments/representation_shootout/conditions.py +708 -0
  29. openadapt_ml/experiments/representation_shootout/config.py +390 -0
  30. openadapt_ml/experiments/representation_shootout/evaluator.py +659 -0
  31. openadapt_ml/experiments/representation_shootout/runner.py +687 -0
  32. openadapt_ml/experiments/waa_demo/runner.py +29 -14
  33. openadapt_ml/export/parquet.py +36 -24
  34. openadapt_ml/grounding/detector.py +18 -14
  35. openadapt_ml/ingest/__init__.py +8 -6
  36. openadapt_ml/ingest/capture.py +25 -22
  37. openadapt_ml/ingest/loader.py +7 -4
  38. openadapt_ml/ingest/synthetic.py +189 -100
  39. openadapt_ml/models/api_adapter.py +14 -4
  40. openadapt_ml/models/base_adapter.py +10 -2
  41. openadapt_ml/models/providers/__init__.py +288 -0
  42. openadapt_ml/models/providers/anthropic.py +266 -0
  43. openadapt_ml/models/providers/base.py +299 -0
  44. openadapt_ml/models/providers/google.py +376 -0
  45. openadapt_ml/models/providers/openai.py +342 -0
  46. openadapt_ml/models/qwen_vl.py +46 -19
  47. openadapt_ml/perception/__init__.py +35 -0
  48. openadapt_ml/perception/integration.py +399 -0
  49. openadapt_ml/retrieval/demo_retriever.py +50 -24
  50. openadapt_ml/retrieval/embeddings.py +9 -8
  51. openadapt_ml/retrieval/retriever.py +3 -1
  52. openadapt_ml/runtime/__init__.py +50 -0
  53. openadapt_ml/runtime/policy.py +18 -5
  54. openadapt_ml/runtime/safety_gate.py +471 -0
  55. openadapt_ml/schema/__init__.py +9 -0
  56. openadapt_ml/schema/converters.py +74 -27
  57. openadapt_ml/schema/episode.py +31 -18
  58. openadapt_ml/scripts/capture_screenshots.py +530 -0
  59. openadapt_ml/scripts/compare.py +85 -54
  60. openadapt_ml/scripts/demo_policy.py +4 -1
  61. openadapt_ml/scripts/eval_policy.py +15 -9
  62. openadapt_ml/scripts/make_gif.py +1 -1
  63. openadapt_ml/scripts/prepare_synthetic.py +3 -1
  64. openadapt_ml/scripts/train.py +21 -9
  65. openadapt_ml/segmentation/README.md +920 -0
  66. openadapt_ml/segmentation/__init__.py +97 -0
  67. openadapt_ml/segmentation/adapters/__init__.py +5 -0
  68. openadapt_ml/segmentation/adapters/capture_adapter.py +420 -0
  69. openadapt_ml/segmentation/annotator.py +610 -0
  70. openadapt_ml/segmentation/cache.py +290 -0
  71. openadapt_ml/segmentation/cli.py +674 -0
  72. openadapt_ml/segmentation/deduplicator.py +656 -0
  73. openadapt_ml/segmentation/frame_describer.py +788 -0
  74. openadapt_ml/segmentation/pipeline.py +340 -0
  75. openadapt_ml/segmentation/schemas.py +622 -0
  76. openadapt_ml/segmentation/segment_extractor.py +634 -0
  77. openadapt_ml/training/azure_ops_viewer.py +1097 -0
  78. openadapt_ml/training/benchmark_viewer.py +52 -41
  79. openadapt_ml/training/shared_ui.py +7 -7
  80. openadapt_ml/training/stub_provider.py +57 -35
  81. openadapt_ml/training/trainer.py +143 -86
  82. openadapt_ml/training/trl_trainer.py +70 -21
  83. openadapt_ml/training/viewer.py +323 -108
  84. openadapt_ml/training/viewer_components.py +180 -0
  85. {openadapt_ml-0.2.0.dist-info → openadapt_ml-0.2.1.dist-info}/METADATA +215 -14
  86. openadapt_ml-0.2.1.dist-info/RECORD +116 -0
  87. openadapt_ml/benchmarks/base.py +0 -366
  88. openadapt_ml/benchmarks/data_collection.py +0 -432
  89. openadapt_ml/benchmarks/live_tracker.py +0 -180
  90. openadapt_ml/benchmarks/runner.py +0 -418
  91. openadapt_ml/benchmarks/waa.py +0 -761
  92. openadapt_ml/benchmarks/waa_live.py +0 -619
  93. openadapt_ml-0.2.0.dist-info/RECORD +0 -86
  94. {openadapt_ml-0.2.0.dist-info → openadapt_ml-0.2.1.dist-info}/WHEEL +0 -0
  95. {openadapt_ml-0.2.0.dist-info → openadapt_ml-0.2.1.dist-info}/licenses/LICENSE +0 -0
openadapt_ml/segmentation/annotator.py
@@ -0,0 +1,610 @@
+"""VLM-based episode annotation for training data quality control.
+
+This module provides automatic annotation of extracted episodes using
+Vision-Language Models to determine which episodes are suitable for
+training ("gold") and which should be excluded.
+"""
+
+import json
+import logging
+from datetime import datetime
+from pathlib import Path
+from typing import Optional, Union
+
+from PIL import Image
+
+from openadapt_ml.segmentation.schemas import (
+    AnnotatedEpisodeLibrary,
+    Episode,
+    EpisodeAnnotation,
+    EpisodeExtractionResult,
+)
+
+logger = logging.getLogger(__name__)
+
+
+class EpisodeAnnotator:
+    """Annotates episodes using VLM analysis for training data quality.
+
+    This class examines episode frames and post-episode frames to:
+    1. Identify precise episode boundaries
+    2. Detect failure signals (errors, undos, repeated attempts)
+    3. Assess whether the workflow completed successfully
+    4. Generate is_gold recommendations
+
+    Example:
+        >>> annotator = EpisodeAnnotator(model="gemini-2.0-flash")
+        >>> library = annotator.annotate_episodes(
+        ...     episodes=extraction_result.episodes,
+        ...     recording_path="/path/to/recording",
+        ... )
+        >>> print(f"Found {library.gold_count} gold episodes")
+
+    Attributes:
+        model: VLM model identifier
+        lookback_frames: Number of frames to analyze before episode
+        lookahead_frames: Number of frames to analyze after episode
+        confidence_threshold: Minimum confidence to mark as gold
+    """
+
+    SUPPORTED_MODELS = [
+        "gemini-2.0-flash",
+        "gemini-2.0-pro",
+        "claude-sonnet-4-20250514",
+        "claude-3-5-haiku-20241022",
+        "gpt-4o",
+        "gpt-4o-mini",
+    ]
+
+    # Common failure signals to detect
+    FAILURE_INDICATORS = [
+        "error",
+        "failed",
+        "undo",
+        "cancel",
+        "retry",
+        "oops",
+        "wrong",
+        "delete",
+        "remove",
+        "revert",
+        "back",
+        "ctrl+z",
+        "cmd+z",
+    ]
+
+    def __init__(
+        self,
+        model: str = "gemini-2.0-flash",
+        lookback_frames: int = 3,
+        lookahead_frames: int = 10,
+        confidence_threshold: float = 0.7,
+        api_key: Optional[str] = None,
+    ) -> None:
+        """Initialize the episode annotator.
+
+        Args:
+            model: VLM model to use for analysis.
+            lookback_frames: Number of frames to check before episode start.
+            lookahead_frames: Number of frames to check after episode end.
+            confidence_threshold: Minimum confidence to recommend as gold.
+            api_key: API key for VLM provider (uses env var if not provided).
+        """
+        self.model = model
+        self.lookback_frames = lookback_frames
+        self.lookahead_frames = lookahead_frames
+        self.confidence_threshold = confidence_threshold
+        self._api_key = api_key
+        self._client = None
+
+    def _get_client(self):
+        """Get or create VLM client."""
+        if self._client is not None:
+            return self._client
+
+        from openadapt_ml.config import settings
+
+        if "gemini" in self.model.lower():
+            import google.generativeai as genai
+
+            api_key = self._api_key or settings.google_api_key
+            if not api_key:
+                raise ValueError("GOOGLE_API_KEY not set")
+            genai.configure(api_key=api_key)
+            self._client = genai.GenerativeModel(self.model)
+        elif "claude" in self.model.lower():
+            import anthropic
+
+            api_key = self._api_key or settings.anthropic_api_key
+            self._client = anthropic.Anthropic(api_key=api_key)
+        elif "gpt" in self.model.lower():
+            import openai
+
+            api_key = self._api_key or settings.openai_api_key
+            self._client = openai.OpenAI(api_key=api_key)
+        else:
+            raise ValueError(f"Unknown model: {self.model}")
+
+        return self._client
+
+    def _encode_image(self, image: Image.Image) -> str:
+        """Encode image for API calls."""
+        import base64
+        import io
+
+        buffer = io.BytesIO()
+        image.save(buffer, format="PNG")
+        b64 = base64.b64encode(buffer.getvalue()).decode()
+        return b64
+
+    def _get_annotation_prompt(
+        self,
+        episode: Episode,
+        has_post_frames: bool,
+    ) -> str:
+        """Generate prompt for episode annotation."""
+        return f"""You are analyzing a GUI workflow episode to determine if it should be included in a training dataset.
+
+## Episode Information
+- **Name**: {episode.name}
+- **Description**: {episode.description}
+- **Duration**: {episode.start_time_formatted} - {episode.end_time_formatted}
+- **Steps**: {", ".join(episode.step_summaries)}
+- **Application**: {episode.application}
+
+## Analysis Task
+
+Examine the provided screenshots and determine:
+
+1. **Boundary Accuracy**: Are the episode boundaries (start/end frames) correct?
+ - Does the first frame show the actual start of the workflow?
+ - Does the last frame show the actual completion?
+
+2. **Workflow Completeness**: Did the workflow complete successfully?
+ - Were all steps executed?
+ - Is there a clear completion state visible?
+
+3. **Failure Detection**: Look for any signs of failure:
+ - Error dialogs or messages
+ - User performing undo actions (Ctrl+Z, etc.)
+ - Repeated attempts at the same action
+ - User navigating back or canceling
+ - Signs of frustration (rapid clicking, erratic movements)
+
+{"4. **Post-Episode Analysis**: Examine frames AFTER the episode ended:" if has_post_frames else ""}
+{" - Are there error dialogs appearing after completion?" if has_post_frames else ""}
+{" - Does the user immediately undo or retry the task?" if has_post_frames else ""}
+{" - Is there evidence the workflow actually failed?" if has_post_frames else ""}
+
+## Response Format
+
+Respond with JSON:
+```json
+{{
+    "is_gold": true/false,
+    "confidence": 0.0-1.0,
+    "start_frame_correct": true/false,
+    "end_frame_correct": true/false,
+    "suggested_start_offset": 0,
+    "suggested_end_offset": 0,
+    "workflow_complete": true/false,
+    "failure_signals": ["list of detected issues"],
+    "exclusion_reason": "reason if not gold, null if gold",
+    "analysis_notes": "brief explanation of assessment"
+}}
+```
+
+**Guidelines for is_gold**:
+- TRUE if: Workflow completed successfully, no errors visible, episode is coherent and self-contained
+- FALSE if: Any errors detected, incomplete workflow, user had to retry, or evidence of failure in post-frames
+"""
+
+    def _call_vlm(
+        self,
+        prompt: str,
+        images: list[Image.Image],
+    ) -> dict:
+        """Call VLM with images and return parsed response."""
+        client = self._get_client()
+
+        if "gemini" in self.model.lower():
+            content = [prompt] + images
+            response = client.generate_content(content)
+            text = response.text
+        elif "claude" in self.model.lower():
+            content = []
+            for img in images:
+                b64 = self._encode_image(img)
+                content.append(
+                    {
+                        "type": "image",
+                        "source": {
+                            "type": "base64",
+                            "media_type": "image/png",
+                            "data": b64,
+                        },
+                    }
+                )
+            content.append({"type": "text", "text": prompt})
+            response = client.messages.create(
+                model=self.model,
+                max_tokens=2048,
+                messages=[{"role": "user", "content": content}],
+            )
+            text = response.content[0].text
+        elif "gpt" in self.model.lower():
+            content = []
+            for img in images:
+                b64 = self._encode_image(img)
+                content.append(
+                    {
+                        "type": "image_url",
+                        "image_url": {"url": f"data:image/png;base64,{b64}"},
+                    }
+                )
+            content.append({"type": "text", "text": prompt})
+            response = client.chat.completions.create(
+                model=self.model,
+                max_tokens=2048,
+                messages=[{"role": "user", "content": content}],
+            )
+            text = response.choices[0].message.content
+        else:
+            raise ValueError(f"Unknown model: {self.model}")
+
+        # Parse JSON from response
+        try:
+            start = text.find("{")
+            end = text.rfind("}") + 1
+            if start >= 0 and end > start:
+                return json.loads(text[start:end])
+        except json.JSONDecodeError:
+            pass
+
+        # Return default if parsing failed
+        return {
+            "is_gold": False,
+            "confidence": 0.3,
+            "failure_signals": ["Failed to parse VLM response"],
+            "exclusion_reason": "VLM response parsing failed",
+            "analysis_notes": text[:200],
+        }
+
+    def _load_frames(
+        self,
+        recording_path: Path,
+        frame_indices: list[int],
+    ) -> list[Image.Image]:
+        """Load frames from recording directory."""
+        images = []
+        screenshots_dir = recording_path / "screenshots"
+
+        if not screenshots_dir.exists():
+            # Try direct directory with numbered PNGs
+            png_files = sorted(recording_path.glob("*.png"))
+            if png_files:
+                for idx in frame_indices:
+                    if 0 <= idx < len(png_files):
+                        try:
+                            images.append(Image.open(png_files[idx]))
+                        except Exception as e:
+                            logger.warning(f"Failed to load frame {idx}: {e}")
+            return images
+
+        # Load from screenshots directory
+        for idx in frame_indices:
+            path = screenshots_dir / f"{idx:06d}.png"
+            if path.exists():
+                try:
+                    images.append(Image.open(path))
+                except Exception as e:
+                    logger.warning(f"Failed to load frame {idx}: {e}")
+
+        return images
+
+    def annotate_episode(
+        self,
+        episode: Episode,
+        recording_path: Union[str, Path],
+        total_frames: int,
+    ) -> EpisodeAnnotation:
+        """Annotate a single episode.
+
+        Args:
+            episode: Episode to annotate.
+            recording_path: Path to the recording directory.
+            total_frames: Total number of frames in the recording.
+
+        Returns:
+            EpisodeAnnotation with VLM-generated assessment.
+        """
+        recording_path = Path(recording_path)
+
+        # Determine frame ranges to analyze
+        start_frame = min(episode.frame_indices) if episode.frame_indices else 0
+        end_frame = max(episode.frame_indices) if episode.frame_indices else 0
+
+        # Get episode frames (sample if too many)
+        episode_frames = episode.frame_indices
+        if len(episode_frames) > 10:
+            # Sample: first 3, middle 4, last 3
+            sampled = (
+                episode_frames[:3]
+                + episode_frames[
+                    len(episode_frames) // 2 - 2 : len(episode_frames) // 2 + 2
+                ]
+                + episode_frames[-3:]
+            )
+            episode_frames = sorted(set(sampled))
+
+        # Get post-episode frames
+        post_start = end_frame + 1
+        post_end = min(end_frame + self.lookahead_frames + 1, total_frames)
+        post_frames = list(range(post_start, post_end))
+
+        # Load images
+        all_frames = episode_frames + post_frames
+        images = self._load_frames(recording_path, all_frames)
+
+        if not images:
+            logger.warning(f"No frames loaded for episode {episode.episode_id}")
+            return EpisodeAnnotation(
+                episode_id=episode.episode_id,
+                start_frame=start_frame,
+                end_frame=end_frame,
+                is_gold=False,
+                exclusion_reason="Failed to load episode frames",
+                confidence=0.0,
+                failure_signals=["No frames available for analysis"],
+            )
+
+        # Generate annotation
+        prompt = self._get_annotation_prompt(
+            episode=episode,
+            has_post_frames=len(post_frames) > 0,
+        )
+
+        result = self._call_vlm(prompt, images)
+
+        # Apply boundary adjustments
+        adjusted_start = start_frame + result.get("suggested_start_offset", 0)
+        adjusted_end = end_frame + result.get("suggested_end_offset", 0)
+
+        return EpisodeAnnotation(
+            episode_id=episode.episode_id,
+            start_frame=max(0, adjusted_start),
+            end_frame=min(total_frames - 1, adjusted_end),
+            is_gold=result.get("is_gold", False)
+            and result.get("confidence", 0) >= self.confidence_threshold,
+            exclusion_reason=result.get("exclusion_reason"),
+            confidence=result.get("confidence", 0.5),
+            failure_signals=result.get("failure_signals", []),
+        )
+
+    def annotate_episodes(
+        self,
+        episodes: list[Episode],
+        recording_path: Union[str, Path],
+        total_frames: Optional[int] = None,
+        progress_callback: Optional[callable] = None,
+    ) -> AnnotatedEpisodeLibrary:
+        """Annotate multiple episodes from a recording.
+
+        Args:
+            episodes: List of episodes to annotate.
+            recording_path: Path to the recording directory.
+            total_frames: Total number of frames (auto-detected if not provided).
+            progress_callback: Optional callback(current, total) for progress.
+
+        Returns:
+            AnnotatedEpisodeLibrary with all episodes and annotations.
+        """
+        recording_path = Path(recording_path)
+
+        # Auto-detect total frames if not provided
+        if total_frames is None:
+            screenshots_dir = recording_path / "screenshots"
+            if screenshots_dir.exists():
+                total_frames = len(list(screenshots_dir.glob("*.png")))
+            else:
+                total_frames = len(list(recording_path.glob("*.png")))
+
+        annotations = []
+        for i, episode in enumerate(episodes):
+            logger.info(f"Annotating episode {i + 1}/{len(episodes)}: {episode.name}")
+
+            annotation = self.annotate_episode(
+                episode=episode,
+                recording_path=recording_path,
+                total_frames=total_frames,
+            )
+            annotations.append(annotation)
+
+            if progress_callback:
+                progress_callback(i + 1, len(episodes))
+
+        # Build library
+        recording_ids = list(set(e.recording_id for e in episodes))
+
+        return AnnotatedEpisodeLibrary(
+            episodes=episodes,
+            annotations=annotations,
+            source_recordings=recording_ids,
+        )
+
+    def annotate_extraction_result(
+        self,
+        extraction_result: EpisodeExtractionResult,
+        recording_path: Union[str, Path],
+        total_frames: Optional[int] = None,
+        progress_callback: Optional[callable] = None,
+    ) -> AnnotatedEpisodeLibrary:
+        """Annotate all episodes from an extraction result.
+
+        Args:
+            extraction_result: Output from SegmentExtractor.
+            recording_path: Path to the recording directory.
+            total_frames: Total number of frames.
+            progress_callback: Optional callback for progress.
+
+        Returns:
+            AnnotatedEpisodeLibrary with annotations.
+        """
+        return self.annotate_episodes(
+            episodes=extraction_result.episodes,
+            recording_path=recording_path,
+            total_frames=total_frames,
+            progress_callback=progress_callback,
+        )
+
+
+def verify_annotation(
+    annotation: EpisodeAnnotation,
+    is_gold: bool,
+    notes: Optional[str] = None,
+    verified_by: Optional[str] = None,
+) -> EpisodeAnnotation:
+    """Update an annotation with human verification.
+
+    Args:
+        annotation: The annotation to verify.
+        is_gold: Human decision on gold status.
+        notes: Optional notes from the reviewer.
+        verified_by: Name/ID of the person verifying.
+
+    Returns:
+        Updated EpisodeAnnotation with human_verified=True.
+    """
+    return EpisodeAnnotation(
+        annotation_id=annotation.annotation_id,
+        episode_id=annotation.episode_id,
+        start_frame=annotation.start_frame,
+        end_frame=annotation.end_frame,
+        is_gold=is_gold,
+        exclusion_reason=annotation.exclusion_reason if not is_gold else None,
+        confidence=annotation.confidence,
+        human_verified=True,
+        notes=notes or annotation.notes,
+        failure_signals=annotation.failure_signals,
+        created_at=annotation.created_at,
+        verified_at=datetime.now(),
+        verified_by=verified_by,
+    )


+def export_gold_episodes(
+    library: AnnotatedEpisodeLibrary,
+    output_path: Union[str, Path],
+    recording_path: Optional[Union[str, Path]] = None,
+    format: str = "jsonl",
+    include_screenshots: bool = False,
+) -> int:
+    """Export gold episodes for fine-tuning.
+
+    Only exports episodes where is_gold=True AND human_verified=True.
+
+    Args:
+        library: AnnotatedEpisodeLibrary to export from.
+        output_path: Path to output file/directory.
+        recording_path: Path to recording (needed if include_screenshots=True).
+        format: Export format ("jsonl", "json", or "hf" for HuggingFace).
+        include_screenshots: Whether to include screenshots in export.
+
+    Returns:
+        Number of episodes exported.
+    """
+    output_path = Path(output_path)
+
+    # Get verified gold episodes
+    gold_episodes = library.get_verified_gold_episodes()
+
+    if not gold_episodes:
+        logger.warning("No verified gold episodes to export")
+        return 0
+
+    if format == "jsonl":
+        output_path.parent.mkdir(parents=True, exist_ok=True)
+        with open(output_path, "w") as f:
+            for episode, annotation in gold_episodes:
+                record = {
+                    "episode_id": str(episode.episode_id),
+                    "name": episode.name,
+                    "description": episode.description,
+                    "application": episode.application,
+                    "steps": episode.step_summaries,
+                    "start_frame": annotation.start_frame,
+                    "end_frame": annotation.end_frame,
+                    "recording_id": episode.recording_id,
+                    "annotation_confidence": annotation.confidence,
+                    "verified_by": annotation.verified_by,
+                    "notes": annotation.notes,
+                }
+                f.write(json.dumps(record) + "\n")
+
+    elif format == "json":
+        output_path.parent.mkdir(parents=True, exist_ok=True)
+        records = []
+        for episode, annotation in gold_episodes:
+            records.append(
+                {
+                    "episode_id": str(episode.episode_id),
+                    "name": episode.name,
+                    "description": episode.description,
+                    "application": episode.application,
+                    "steps": episode.step_summaries,
+                    "start_frame": annotation.start_frame,
+                    "end_frame": annotation.end_frame,
+                    "start_time": episode.start_time,
+                    "end_time": episode.end_time,
+                    "recording_id": episode.recording_id,
+                    "annotation_confidence": annotation.confidence,
+                    "verified_by": annotation.verified_by,
+                    "notes": annotation.notes,
+                }
+            )
+        output_path.write_text(json.dumps(records, indent=2))
+
+    elif format == "hf":
+        # Export in HuggingFace datasets format
+        output_path.mkdir(parents=True, exist_ok=True)
+        records = []
+        for episode, annotation in gold_episodes:
+            record = {
+                "episode_id": str(episode.episode_id),
+                "task_name": episode.name,
+                "task_description": episode.description,
+                "application": episode.application,
+                "steps": episode.step_summaries,
+                "frame_indices": list(
+                    range(annotation.start_frame, annotation.end_frame + 1)
+                ),
+                "recording_id": episode.recording_id,
+            }
+
+            if include_screenshots and recording_path:
+                # Load and save screenshots
+                episode_dir = output_path / str(episode.episode_id)
+                episode_dir.mkdir(parents=True, exist_ok=True)
+
+                screenshots_src = Path(recording_path) / "screenshots"
+                screenshot_paths = []
+                for idx in range(annotation.start_frame, annotation.end_frame + 1):
+                    src = screenshots_src / f"{idx:06d}.png"
+                    if src.exists():
+                        dst = episode_dir / f"frame_{idx:06d}.png"
+                        import shutil
+
+                        shutil.copy(src, dst)
+                        screenshot_paths.append(str(dst))
+                record["screenshot_paths"] = screenshot_paths
+
+            records.append(record)
+
+        # Save metadata
+        (output_path / "metadata.json").write_text(json.dumps(records, indent=2))
+
+    else:
+        raise ValueError(f"Unknown export format: {format}")
+
+    logger.info(f"Exported {len(gold_episodes)} gold episodes to {output_path}")
+    return len(gold_episodes)
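
Taken together, the new annotator module defines a three-stage flow: VLM annotation of extracted episodes, human verification of each annotation, and export of the verified gold set for fine-tuning. The sketch below shows how these pieces might be wired together. It is illustrative only: the extraction result, recording path, and reviewer ID are placeholders, and it assumes the library schema exposes its constructor fields (episodes, annotations, source_recordings) as attributes; only names that appear in the diff above are used.

# Illustrative sketch, not part of the package. `extraction_result`, the
# recording path, and the reviewer ID are assumed placeholders.
from openadapt_ml.segmentation.annotator import (
    EpisodeAnnotator,
    export_gold_episodes,
    verify_annotation,
)
from openadapt_ml.segmentation.schemas import AnnotatedEpisodeLibrary

# 1. VLM annotation of episodes produced by the segmentation pipeline.
annotator = EpisodeAnnotator(model="gemini-2.0-flash", confidence_threshold=0.7)
library = annotator.annotate_extraction_result(
    extraction_result=extraction_result,  # output of SegmentExtractor (assumed to exist)
    recording_path="/path/to/recording",
)
print(f"Gold candidates: {library.gold_count}")

# 2. Human verification: confirm or overturn each VLM recommendation.
verified = AnnotatedEpisodeLibrary(
    episodes=library.episodes,
    annotations=[
        verify_annotation(a, is_gold=a.is_gold, verified_by="reviewer-01")
        for a in library.annotations
    ],
    source_recordings=library.source_recordings,
)

# 3. Export: only episodes that are both is_gold and human_verified are written.
count = export_gold_episodes(verified, "exports/gold_episodes.jsonl", format="jsonl")
print(f"Exported {count} gold episodes")

Because export_gold_episodes filters on get_verified_gold_episodes(), skipping the verification pass produces an empty export even when the VLM marks episodes as gold, so the human-review step is required before training data is written out.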