openadapt-ml 0.2.0-py3-none-any.whl → 0.2.1-py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
- openadapt_ml/baselines/__init__.py +121 -0
- openadapt_ml/baselines/adapter.py +185 -0
- openadapt_ml/baselines/cli.py +314 -0
- openadapt_ml/baselines/config.py +448 -0
- openadapt_ml/baselines/parser.py +922 -0
- openadapt_ml/baselines/prompts.py +787 -0
- openadapt_ml/benchmarks/__init__.py +13 -115
- openadapt_ml/benchmarks/agent.py +265 -421
- openadapt_ml/benchmarks/azure.py +28 -19
- openadapt_ml/benchmarks/azure_ops_tracker.py +521 -0
- openadapt_ml/benchmarks/cli.py +1722 -4847
- openadapt_ml/benchmarks/trace_export.py +631 -0
- openadapt_ml/benchmarks/viewer.py +22 -5
- openadapt_ml/benchmarks/vm_monitor.py +530 -29
- openadapt_ml/benchmarks/waa_deploy/Dockerfile +47 -53
- openadapt_ml/benchmarks/waa_deploy/api_agent.py +21 -20
- openadapt_ml/cloud/azure_inference.py +3 -5
- openadapt_ml/cloud/lambda_labs.py +722 -307
- openadapt_ml/cloud/local.py +2038 -487
- openadapt_ml/cloud/ssh_tunnel.py +68 -26
- openadapt_ml/datasets/next_action.py +40 -30
- openadapt_ml/evals/grounding.py +8 -3
- openadapt_ml/evals/plot_eval_metrics.py +15 -13
- openadapt_ml/evals/trajectory_matching.py +41 -26
- openadapt_ml/experiments/demo_prompt/format_demo.py +16 -6
- openadapt_ml/experiments/demo_prompt/run_experiment.py +26 -16
- openadapt_ml/experiments/representation_shootout/__init__.py +70 -0
- openadapt_ml/experiments/representation_shootout/conditions.py +708 -0
- openadapt_ml/experiments/representation_shootout/config.py +390 -0
- openadapt_ml/experiments/representation_shootout/evaluator.py +659 -0
- openadapt_ml/experiments/representation_shootout/runner.py +687 -0
- openadapt_ml/experiments/waa_demo/runner.py +29 -14
- openadapt_ml/export/parquet.py +36 -24
- openadapt_ml/grounding/detector.py +18 -14
- openadapt_ml/ingest/__init__.py +8 -6
- openadapt_ml/ingest/capture.py +25 -22
- openadapt_ml/ingest/loader.py +7 -4
- openadapt_ml/ingest/synthetic.py +189 -100
- openadapt_ml/models/api_adapter.py +14 -4
- openadapt_ml/models/base_adapter.py +10 -2
- openadapt_ml/models/providers/__init__.py +288 -0
- openadapt_ml/models/providers/anthropic.py +266 -0
- openadapt_ml/models/providers/base.py +299 -0
- openadapt_ml/models/providers/google.py +376 -0
- openadapt_ml/models/providers/openai.py +342 -0
- openadapt_ml/models/qwen_vl.py +46 -19
- openadapt_ml/perception/__init__.py +35 -0
- openadapt_ml/perception/integration.py +399 -0
- openadapt_ml/retrieval/demo_retriever.py +50 -24
- openadapt_ml/retrieval/embeddings.py +9 -8
- openadapt_ml/retrieval/retriever.py +3 -1
- openadapt_ml/runtime/__init__.py +50 -0
- openadapt_ml/runtime/policy.py +18 -5
- openadapt_ml/runtime/safety_gate.py +471 -0
- openadapt_ml/schema/__init__.py +9 -0
- openadapt_ml/schema/converters.py +74 -27
- openadapt_ml/schema/episode.py +31 -18
- openadapt_ml/scripts/capture_screenshots.py +530 -0
- openadapt_ml/scripts/compare.py +85 -54
- openadapt_ml/scripts/demo_policy.py +4 -1
- openadapt_ml/scripts/eval_policy.py +15 -9
- openadapt_ml/scripts/make_gif.py +1 -1
- openadapt_ml/scripts/prepare_synthetic.py +3 -1
- openadapt_ml/scripts/train.py +21 -9
- openadapt_ml/segmentation/README.md +920 -0
- openadapt_ml/segmentation/__init__.py +97 -0
- openadapt_ml/segmentation/adapters/__init__.py +5 -0
- openadapt_ml/segmentation/adapters/capture_adapter.py +420 -0
- openadapt_ml/segmentation/annotator.py +610 -0
- openadapt_ml/segmentation/cache.py +290 -0
- openadapt_ml/segmentation/cli.py +674 -0
- openadapt_ml/segmentation/deduplicator.py +656 -0
- openadapt_ml/segmentation/frame_describer.py +788 -0
- openadapt_ml/segmentation/pipeline.py +340 -0
- openadapt_ml/segmentation/schemas.py +622 -0
- openadapt_ml/segmentation/segment_extractor.py +634 -0
- openadapt_ml/training/azure_ops_viewer.py +1097 -0
- openadapt_ml/training/benchmark_viewer.py +52 -41
- openadapt_ml/training/shared_ui.py +7 -7
- openadapt_ml/training/stub_provider.py +57 -35
- openadapt_ml/training/trainer.py +143 -86
- openadapt_ml/training/trl_trainer.py +70 -21
- openadapt_ml/training/viewer.py +323 -108
- openadapt_ml/training/viewer_components.py +180 -0
- {openadapt_ml-0.2.0.dist-info → openadapt_ml-0.2.1.dist-info}/METADATA +215 -14
- openadapt_ml-0.2.1.dist-info/RECORD +116 -0
- openadapt_ml/benchmarks/base.py +0 -366
- openadapt_ml/benchmarks/data_collection.py +0 -432
- openadapt_ml/benchmarks/live_tracker.py +0 -180
- openadapt_ml/benchmarks/runner.py +0 -418
- openadapt_ml/benchmarks/waa.py +0 -761
- openadapt_ml/benchmarks/waa_live.py +0 -619
- openadapt_ml-0.2.0.dist-info/RECORD +0 -86
- {openadapt_ml-0.2.0.dist-info → openadapt_ml-0.2.1.dist-info}/WHEEL +0 -0
- {openadapt_ml-0.2.0.dist-info → openadapt_ml-0.2.1.dist-info}/licenses/LICENSE +0 -0
openadapt_ml/segmentation/annotator.py (new file)

@@ -0,0 +1,610 @@
"""VLM-based episode annotation for training data quality control.

This module provides automatic annotation of extracted episodes using
Vision-Language Models to determine which episodes are suitable for
training ("gold") and which should be excluded.
"""

import json
import logging
from datetime import datetime
from pathlib import Path
from typing import Optional, Union

from PIL import Image

from openadapt_ml.segmentation.schemas import (
    AnnotatedEpisodeLibrary,
    Episode,
    EpisodeAnnotation,
    EpisodeExtractionResult,
)

logger = logging.getLogger(__name__)


class EpisodeAnnotator:
    """Annotates episodes using VLM analysis for training data quality.

    This class examines episode frames and post-episode frames to:
    1. Identify precise episode boundaries
    2. Detect failure signals (errors, undos, repeated attempts)
    3. Assess whether the workflow completed successfully
    4. Generate is_gold recommendations

    Example:
        >>> annotator = EpisodeAnnotator(model="gemini-2.0-flash")
        >>> library = annotator.annotate_episodes(
        ...     episodes=extraction_result.episodes,
        ...     recording_path="/path/to/recording",
        ... )
        >>> print(f"Found {library.gold_count} gold episodes")

    Attributes:
        model: VLM model identifier
        lookback_frames: Number of frames to analyze before episode
        lookahead_frames: Number of frames to analyze after episode
        confidence_threshold: Minimum confidence to mark as gold
    """

    SUPPORTED_MODELS = [
        "gemini-2.0-flash",
        "gemini-2.0-pro",
        "claude-sonnet-4-20250514",
        "claude-3-5-haiku-20241022",
        "gpt-4o",
        "gpt-4o-mini",
    ]

    # Common failure signals to detect
    FAILURE_INDICATORS = [
        "error",
        "failed",
        "undo",
        "cancel",
        "retry",
        "oops",
        "wrong",
        "delete",
        "remove",
        "revert",
        "back",
        "ctrl+z",
        "cmd+z",
    ]

    def __init__(
        self,
        model: str = "gemini-2.0-flash",
        lookback_frames: int = 3,
        lookahead_frames: int = 10,
        confidence_threshold: float = 0.7,
        api_key: Optional[str] = None,
    ) -> None:
        """Initialize the episode annotator.

        Args:
            model: VLM model to use for analysis.
            lookback_frames: Number of frames to check before episode start.
            lookahead_frames: Number of frames to check after episode end.
            confidence_threshold: Minimum confidence to recommend as gold.
            api_key: API key for VLM provider (uses env var if not provided).
        """
        self.model = model
        self.lookback_frames = lookback_frames
        self.lookahead_frames = lookahead_frames
        self.confidence_threshold = confidence_threshold
        self._api_key = api_key
        self._client = None

    def _get_client(self):
        """Get or create VLM client."""
        if self._client is not None:
            return self._client

        from openadapt_ml.config import settings

        if "gemini" in self.model.lower():
            import google.generativeai as genai

            api_key = self._api_key or settings.google_api_key
            if not api_key:
                raise ValueError("GOOGLE_API_KEY not set")
            genai.configure(api_key=api_key)
            self._client = genai.GenerativeModel(self.model)
        elif "claude" in self.model.lower():
            import anthropic

            api_key = self._api_key or settings.anthropic_api_key
            self._client = anthropic.Anthropic(api_key=api_key)
        elif "gpt" in self.model.lower():
            import openai

            api_key = self._api_key or settings.openai_api_key
            self._client = openai.OpenAI(api_key=api_key)
        else:
            raise ValueError(f"Unknown model: {self.model}")

        return self._client

    def _encode_image(self, image: Image.Image) -> str:
        """Encode image for API calls."""
        import base64
        import io

        buffer = io.BytesIO()
        image.save(buffer, format="PNG")
        b64 = base64.b64encode(buffer.getvalue()).decode()
        return b64

    def _get_annotation_prompt(
        self,
        episode: Episode,
        has_post_frames: bool,
    ) -> str:
        """Generate prompt for episode annotation."""
        return f"""You are analyzing a GUI workflow episode to determine if it should be included in a training dataset.

## Episode Information
- **Name**: {episode.name}
- **Description**: {episode.description}
- **Duration**: {episode.start_time_formatted} - {episode.end_time_formatted}
- **Steps**: {", ".join(episode.step_summaries)}
- **Application**: {episode.application}

## Analysis Task

Examine the provided screenshots and determine:

1. **Boundary Accuracy**: Are the episode boundaries (start/end frames) correct?
   - Does the first frame show the actual start of the workflow?
   - Does the last frame show the actual completion?

2. **Workflow Completeness**: Did the workflow complete successfully?
   - Were all steps executed?
   - Is there a clear completion state visible?

3. **Failure Detection**: Look for any signs of failure:
   - Error dialogs or messages
   - User performing undo actions (Ctrl+Z, etc.)
   - Repeated attempts at the same action
   - User navigating back or canceling
   - Signs of frustration (rapid clicking, erratic movements)

{"4. **Post-Episode Analysis**: Examine frames AFTER the episode ended:" if has_post_frames else ""}
{"   - Are there error dialogs appearing after completion?" if has_post_frames else ""}
{"   - Does the user immediately undo or retry the task?" if has_post_frames else ""}
{"   - Is there evidence the workflow actually failed?" if has_post_frames else ""}

## Response Format

Respond with JSON:
```json
{{
  "is_gold": true/false,
  "confidence": 0.0-1.0,
  "start_frame_correct": true/false,
  "end_frame_correct": true/false,
  "suggested_start_offset": 0,
  "suggested_end_offset": 0,
  "workflow_complete": true/false,
  "failure_signals": ["list of detected issues"],
  "exclusion_reason": "reason if not gold, null if gold",
  "analysis_notes": "brief explanation of assessment"
}}
```

**Guidelines for is_gold**:
- TRUE if: Workflow completed successfully, no errors visible, episode is coherent and self-contained
- FALSE if: Any errors detected, incomplete workflow, user had to retry, or evidence of failure in post-frames
"""

    def _call_vlm(
        self,
        prompt: str,
        images: list[Image.Image],
    ) -> dict:
        """Call VLM with images and return parsed response."""
        client = self._get_client()

        if "gemini" in self.model.lower():
            content = [prompt] + images
            response = client.generate_content(content)
            text = response.text
        elif "claude" in self.model.lower():
            content = []
            for img in images:
                b64 = self._encode_image(img)
                content.append(
                    {
                        "type": "image",
                        "source": {
                            "type": "base64",
                            "media_type": "image/png",
                            "data": b64,
                        },
                    }
                )
            content.append({"type": "text", "text": prompt})
            response = client.messages.create(
                model=self.model,
                max_tokens=2048,
                messages=[{"role": "user", "content": content}],
            )
            text = response.content[0].text
        elif "gpt" in self.model.lower():
            content = []
            for img in images:
                b64 = self._encode_image(img)
                content.append(
                    {
                        "type": "image_url",
                        "image_url": {"url": f"data:image/png;base64,{b64}"},
                    }
                )
            content.append({"type": "text", "text": prompt})
            response = client.chat.completions.create(
                model=self.model,
                max_tokens=2048,
                messages=[{"role": "user", "content": content}],
            )
            text = response.choices[0].message.content
        else:
            raise ValueError(f"Unknown model: {self.model}")

        # Parse JSON from response
        try:
            start = text.find("{")
            end = text.rfind("}") + 1
            if start >= 0 and end > start:
                return json.loads(text[start:end])
        except json.JSONDecodeError:
            pass

        # Return default if parsing failed
        return {
            "is_gold": False,
            "confidence": 0.3,
            "failure_signals": ["Failed to parse VLM response"],
            "exclusion_reason": "VLM response parsing failed",
            "analysis_notes": text[:200],
        }

    def _load_frames(
        self,
        recording_path: Path,
        frame_indices: list[int],
    ) -> list[Image.Image]:
        """Load frames from recording directory."""
        images = []
        screenshots_dir = recording_path / "screenshots"

        if not screenshots_dir.exists():
            # Try direct directory with numbered PNGs
            png_files = sorted(recording_path.glob("*.png"))
            if png_files:
                for idx in frame_indices:
                    if 0 <= idx < len(png_files):
                        try:
                            images.append(Image.open(png_files[idx]))
                        except Exception as e:
                            logger.warning(f"Failed to load frame {idx}: {e}")
            return images

        # Load from screenshots directory
        for idx in frame_indices:
            path = screenshots_dir / f"{idx:06d}.png"
            if path.exists():
                try:
                    images.append(Image.open(path))
                except Exception as e:
                    logger.warning(f"Failed to load frame {idx}: {e}")

        return images

    def annotate_episode(
        self,
        episode: Episode,
        recording_path: Union[str, Path],
        total_frames: int,
    ) -> EpisodeAnnotation:
        """Annotate a single episode.

        Args:
            episode: Episode to annotate.
            recording_path: Path to the recording directory.
            total_frames: Total number of frames in the recording.

        Returns:
            EpisodeAnnotation with VLM-generated assessment.
        """
        recording_path = Path(recording_path)

        # Determine frame ranges to analyze
        start_frame = min(episode.frame_indices) if episode.frame_indices else 0
        end_frame = max(episode.frame_indices) if episode.frame_indices else 0

        # Get episode frames (sample if too many)
        episode_frames = episode.frame_indices
        if len(episode_frames) > 10:
            # Sample: first 3, middle 4, last 3
            sampled = (
                episode_frames[:3]
                + episode_frames[
                    len(episode_frames) // 2 - 2 : len(episode_frames) // 2 + 2
                ]
                + episode_frames[-3:]
            )
            episode_frames = sorted(set(sampled))

        # Get post-episode frames
        post_start = end_frame + 1
        post_end = min(end_frame + self.lookahead_frames + 1, total_frames)
        post_frames = list(range(post_start, post_end))

        # Load images
        all_frames = episode_frames + post_frames
        images = self._load_frames(recording_path, all_frames)

        if not images:
            logger.warning(f"No frames loaded for episode {episode.episode_id}")
            return EpisodeAnnotation(
                episode_id=episode.episode_id,
                start_frame=start_frame,
                end_frame=end_frame,
                is_gold=False,
                exclusion_reason="Failed to load episode frames",
                confidence=0.0,
                failure_signals=["No frames available for analysis"],
            )

        # Generate annotation
        prompt = self._get_annotation_prompt(
            episode=episode,
            has_post_frames=len(post_frames) > 0,
        )

        result = self._call_vlm(prompt, images)

        # Apply boundary adjustments
        adjusted_start = start_frame + result.get("suggested_start_offset", 0)
        adjusted_end = end_frame + result.get("suggested_end_offset", 0)

        return EpisodeAnnotation(
            episode_id=episode.episode_id,
            start_frame=max(0, adjusted_start),
            end_frame=min(total_frames - 1, adjusted_end),
            is_gold=result.get("is_gold", False)
            and result.get("confidence", 0) >= self.confidence_threshold,
            exclusion_reason=result.get("exclusion_reason"),
            confidence=result.get("confidence", 0.5),
            failure_signals=result.get("failure_signals", []),
        )

    def annotate_episodes(
        self,
        episodes: list[Episode],
        recording_path: Union[str, Path],
        total_frames: Optional[int] = None,
        progress_callback: Optional[callable] = None,
    ) -> AnnotatedEpisodeLibrary:
        """Annotate multiple episodes from a recording.

        Args:
            episodes: List of episodes to annotate.
            recording_path: Path to the recording directory.
            total_frames: Total number of frames (auto-detected if not provided).
            progress_callback: Optional callback(current, total) for progress.

        Returns:
            AnnotatedEpisodeLibrary with all episodes and annotations.
        """
        recording_path = Path(recording_path)

        # Auto-detect total frames if not provided
        if total_frames is None:
            screenshots_dir = recording_path / "screenshots"
            if screenshots_dir.exists():
                total_frames = len(list(screenshots_dir.glob("*.png")))
            else:
                total_frames = len(list(recording_path.glob("*.png")))

        annotations = []
        for i, episode in enumerate(episodes):
            logger.info(f"Annotating episode {i + 1}/{len(episodes)}: {episode.name}")

            annotation = self.annotate_episode(
                episode=episode,
                recording_path=recording_path,
                total_frames=total_frames,
            )
            annotations.append(annotation)

            if progress_callback:
                progress_callback(i + 1, len(episodes))

        # Build library
        recording_ids = list(set(e.recording_id for e in episodes))

        return AnnotatedEpisodeLibrary(
            episodes=episodes,
            annotations=annotations,
            source_recordings=recording_ids,
        )

    def annotate_extraction_result(
        self,
        extraction_result: EpisodeExtractionResult,
        recording_path: Union[str, Path],
        total_frames: Optional[int] = None,
        progress_callback: Optional[callable] = None,
    ) -> AnnotatedEpisodeLibrary:
        """Annotate all episodes from an extraction result.

        Args:
            extraction_result: Output from SegmentExtractor.
            recording_path: Path to the recording directory.
            total_frames: Total number of frames.
            progress_callback: Optional callback for progress.

        Returns:
            AnnotatedEpisodeLibrary with annotations.
        """
        return self.annotate_episodes(
            episodes=extraction_result.episodes,
            recording_path=recording_path,
            total_frames=total_frames,
            progress_callback=progress_callback,
        )


def verify_annotation(
    annotation: EpisodeAnnotation,
    is_gold: bool,
    notes: Optional[str] = None,
    verified_by: Optional[str] = None,
) -> EpisodeAnnotation:
    """Update an annotation with human verification.

    Args:
        annotation: The annotation to verify.
        is_gold: Human decision on gold status.
        notes: Optional notes from the reviewer.
        verified_by: Name/ID of the person verifying.

    Returns:
        Updated EpisodeAnnotation with human_verified=True.
    """
    return EpisodeAnnotation(
        annotation_id=annotation.annotation_id,
        episode_id=annotation.episode_id,
        start_frame=annotation.start_frame,
        end_frame=annotation.end_frame,
        is_gold=is_gold,
        exclusion_reason=annotation.exclusion_reason if not is_gold else None,
        confidence=annotation.confidence,
        human_verified=True,
        notes=notes or annotation.notes,
        failure_signals=annotation.failure_signals,
        created_at=annotation.created_at,
        verified_at=datetime.now(),
        verified_by=verified_by,
    )


def export_gold_episodes(
    library: AnnotatedEpisodeLibrary,
    output_path: Union[str, Path],
    recording_path: Optional[Union[str, Path]] = None,
    format: str = "jsonl",
    include_screenshots: bool = False,
) -> int:
    """Export gold episodes for fine-tuning.

    Only exports episodes where is_gold=True AND human_verified=True.

    Args:
        library: AnnotatedEpisodeLibrary to export from.
        output_path: Path to output file/directory.
        recording_path: Path to recording (needed if include_screenshots=True).
        format: Export format ("jsonl", "json", or "hf" for HuggingFace).
        include_screenshots: Whether to include screenshots in export.

    Returns:
        Number of episodes exported.
    """
    output_path = Path(output_path)

    # Get verified gold episodes
    gold_episodes = library.get_verified_gold_episodes()

    if not gold_episodes:
        logger.warning("No verified gold episodes to export")
        return 0

    if format == "jsonl":
        output_path.parent.mkdir(parents=True, exist_ok=True)
        with open(output_path, "w") as f:
            for episode, annotation in gold_episodes:
                record = {
                    "episode_id": str(episode.episode_id),
                    "name": episode.name,
                    "description": episode.description,
                    "application": episode.application,
                    "steps": episode.step_summaries,
                    "start_frame": annotation.start_frame,
                    "end_frame": annotation.end_frame,
                    "recording_id": episode.recording_id,
                    "annotation_confidence": annotation.confidence,
                    "verified_by": annotation.verified_by,
                    "notes": annotation.notes,
                }
                f.write(json.dumps(record) + "\n")

    elif format == "json":
        output_path.parent.mkdir(parents=True, exist_ok=True)
        records = []
        for episode, annotation in gold_episodes:
            records.append(
                {
                    "episode_id": str(episode.episode_id),
                    "name": episode.name,
                    "description": episode.description,
                    "application": episode.application,
                    "steps": episode.step_summaries,
                    "start_frame": annotation.start_frame,
                    "end_frame": annotation.end_frame,
                    "start_time": episode.start_time,
                    "end_time": episode.end_time,
                    "recording_id": episode.recording_id,
                    "annotation_confidence": annotation.confidence,
                    "verified_by": annotation.verified_by,
                    "notes": annotation.notes,
                }
            )
        output_path.write_text(json.dumps(records, indent=2))

    elif format == "hf":
        # Export in HuggingFace datasets format
        output_path.mkdir(parents=True, exist_ok=True)
        records = []
        for episode, annotation in gold_episodes:
            record = {
                "episode_id": str(episode.episode_id),
                "task_name": episode.name,
                "task_description": episode.description,
                "application": episode.application,
                "steps": episode.step_summaries,
                "frame_indices": list(
                    range(annotation.start_frame, annotation.end_frame + 1)
                ),
                "recording_id": episode.recording_id,
            }

            if include_screenshots and recording_path:
                # Load and save screenshots
                episode_dir = output_path / str(episode.episode_id)
                episode_dir.mkdir(parents=True, exist_ok=True)

                screenshots_src = Path(recording_path) / "screenshots"
                screenshot_paths = []
                for idx in range(annotation.start_frame, annotation.end_frame + 1):
                    src = screenshots_src / f"{idx:06d}.png"
                    if src.exists():
                        dst = episode_dir / f"frame_{idx:06d}.png"
                        import shutil

                        shutil.copy(src, dst)
                        screenshot_paths.append(str(dst))
                record["screenshot_paths"] = screenshot_paths

            records.append(record)

        # Save metadata
        (output_path / "metadata.json").write_text(json.dumps(records, indent=2))

    else:
        raise ValueError(f"Unknown export format: {format}")

    logger.info(f"Exported {len(gold_episodes)} gold episodes to {output_path}")
    return len(gold_episodes)
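For orientation, the sketch below shows how this new module appears intended to be used end to end: run the VLM pass over extracted episodes, have a human verify the recommendations, then export the verified gold episodes for fine-tuning. It is a minimal sketch based only on the signatures in the diff above, not code from the package: the extraction step, the recording path, the reviewer id, and the rebuild of AnnotatedEpisodeLibrary with verified annotations are assumptions.

```python
# Sketch of the annotate -> verify -> export flow (assumptions marked inline).
from openadapt_ml.segmentation.annotator import (
    EpisodeAnnotator,
    export_gold_episodes,
    verify_annotation,
)
from openadapt_ml.segmentation.schemas import AnnotatedEpisodeLibrary

recording_path = "/path/to/recording"  # placeholder path
extraction_result = ...  # EpisodeExtractionResult from SegmentExtractor (not shown in this diff)

# 1. VLM pass: propose boundaries, failure signals, and is_gold recommendations.
annotator = EpisodeAnnotator(model="gemini-2.0-flash", confidence_threshold=0.7)
library = annotator.annotate_extraction_result(
    extraction_result=extraction_result,
    recording_path=recording_path,
    progress_callback=lambda done, total: print(f"annotated {done}/{total}"),
)

# 2. Human review: confirm or override each VLM recommendation.
#    Assumes the library exposes its annotations; here each one is accepted as-is.
verified = [
    verify_annotation(a, is_gold=a.is_gold, verified_by="reviewer@example.com")
    for a in library.annotations
]
library = AnnotatedEpisodeLibrary(  # assumption: rebuild the library with verified annotations
    episodes=library.episodes,
    annotations=verified,
    source_recordings=library.source_recordings,
)

# 3. Export: only episodes that are both is_gold and human_verified are written.
count = export_gold_episodes(library, output_path="gold_episodes.jsonl", format="jsonl")
print(f"exported {count} gold episodes")
```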