openadapt-ml 0.2.0__py3-none-any.whl → 0.2.2__py3-none-any.whl
This diff shows the changes between publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
- openadapt_ml/baselines/__init__.py +121 -0
- openadapt_ml/baselines/adapter.py +185 -0
- openadapt_ml/baselines/cli.py +314 -0
- openadapt_ml/baselines/config.py +448 -0
- openadapt_ml/baselines/parser.py +922 -0
- openadapt_ml/baselines/prompts.py +787 -0
- openadapt_ml/benchmarks/__init__.py +13 -115
- openadapt_ml/benchmarks/agent.py +265 -421
- openadapt_ml/benchmarks/azure.py +28 -19
- openadapt_ml/benchmarks/azure_ops_tracker.py +521 -0
- openadapt_ml/benchmarks/cli.py +1722 -4847
- openadapt_ml/benchmarks/trace_export.py +631 -0
- openadapt_ml/benchmarks/viewer.py +22 -5
- openadapt_ml/benchmarks/vm_monitor.py +530 -29
- openadapt_ml/benchmarks/waa_deploy/Dockerfile +47 -53
- openadapt_ml/benchmarks/waa_deploy/api_agent.py +21 -20
- openadapt_ml/cloud/azure_inference.py +3 -5
- openadapt_ml/cloud/lambda_labs.py +722 -307
- openadapt_ml/cloud/local.py +2038 -487
- openadapt_ml/cloud/ssh_tunnel.py +68 -26
- openadapt_ml/datasets/next_action.py +40 -30
- openadapt_ml/evals/grounding.py +8 -3
- openadapt_ml/evals/plot_eval_metrics.py +15 -13
- openadapt_ml/evals/trajectory_matching.py +41 -26
- openadapt_ml/experiments/demo_prompt/format_demo.py +16 -6
- openadapt_ml/experiments/demo_prompt/run_experiment.py +26 -16
- openadapt_ml/experiments/representation_shootout/__init__.py +70 -0
- openadapt_ml/experiments/representation_shootout/conditions.py +708 -0
- openadapt_ml/experiments/representation_shootout/config.py +390 -0
- openadapt_ml/experiments/representation_shootout/evaluator.py +659 -0
- openadapt_ml/experiments/representation_shootout/runner.py +687 -0
- openadapt_ml/experiments/waa_demo/runner.py +29 -14
- openadapt_ml/export/parquet.py +36 -24
- openadapt_ml/grounding/detector.py +18 -14
- openadapt_ml/ingest/__init__.py +8 -6
- openadapt_ml/ingest/capture.py +25 -22
- openadapt_ml/ingest/loader.py +7 -4
- openadapt_ml/ingest/synthetic.py +189 -100
- openadapt_ml/models/api_adapter.py +14 -4
- openadapt_ml/models/base_adapter.py +10 -2
- openadapt_ml/models/providers/__init__.py +288 -0
- openadapt_ml/models/providers/anthropic.py +266 -0
- openadapt_ml/models/providers/base.py +299 -0
- openadapt_ml/models/providers/google.py +376 -0
- openadapt_ml/models/providers/openai.py +342 -0
- openadapt_ml/models/qwen_vl.py +46 -19
- openadapt_ml/perception/__init__.py +35 -0
- openadapt_ml/perception/integration.py +399 -0
- openadapt_ml/retrieval/demo_retriever.py +50 -24
- openadapt_ml/retrieval/embeddings.py +9 -8
- openadapt_ml/retrieval/retriever.py +3 -1
- openadapt_ml/runtime/__init__.py +50 -0
- openadapt_ml/runtime/policy.py +18 -5
- openadapt_ml/runtime/safety_gate.py +471 -0
- openadapt_ml/schema/__init__.py +9 -0
- openadapt_ml/schema/converters.py +74 -27
- openadapt_ml/schema/episode.py +31 -18
- openadapt_ml/scripts/capture_screenshots.py +530 -0
- openadapt_ml/scripts/compare.py +85 -54
- openadapt_ml/scripts/demo_policy.py +4 -1
- openadapt_ml/scripts/eval_policy.py +15 -9
- openadapt_ml/scripts/make_gif.py +1 -1
- openadapt_ml/scripts/prepare_synthetic.py +3 -1
- openadapt_ml/scripts/train.py +21 -9
- openadapt_ml/segmentation/README.md +920 -0
- openadapt_ml/segmentation/__init__.py +97 -0
- openadapt_ml/segmentation/adapters/__init__.py +5 -0
- openadapt_ml/segmentation/adapters/capture_adapter.py +420 -0
- openadapt_ml/segmentation/annotator.py +610 -0
- openadapt_ml/segmentation/cache.py +290 -0
- openadapt_ml/segmentation/cli.py +674 -0
- openadapt_ml/segmentation/deduplicator.py +656 -0
- openadapt_ml/segmentation/frame_describer.py +788 -0
- openadapt_ml/segmentation/pipeline.py +340 -0
- openadapt_ml/segmentation/schemas.py +622 -0
- openadapt_ml/segmentation/segment_extractor.py +634 -0
- openadapt_ml/training/azure_ops_viewer.py +1097 -0
- openadapt_ml/training/benchmark_viewer.py +52 -41
- openadapt_ml/training/shared_ui.py +7 -7
- openadapt_ml/training/stub_provider.py +57 -35
- openadapt_ml/training/trainer.py +143 -86
- openadapt_ml/training/trl_trainer.py +70 -21
- openadapt_ml/training/viewer.py +323 -108
- openadapt_ml/training/viewer_components.py +180 -0
- {openadapt_ml-0.2.0.dist-info → openadapt_ml-0.2.2.dist-info}/METADATA +215 -14
- openadapt_ml-0.2.2.dist-info/RECORD +116 -0
- openadapt_ml/benchmarks/base.py +0 -366
- openadapt_ml/benchmarks/data_collection.py +0 -432
- openadapt_ml/benchmarks/live_tracker.py +0 -180
- openadapt_ml/benchmarks/runner.py +0 -418
- openadapt_ml/benchmarks/waa.py +0 -761
- openadapt_ml/benchmarks/waa_live.py +0 -619
- openadapt_ml-0.2.0.dist-info/RECORD +0 -86
- {openadapt_ml-0.2.0.dist-info → openadapt_ml-0.2.2.dist-info}/WHEEL +0 -0
- {openadapt_ml-0.2.0.dist-info → openadapt_ml-0.2.2.dist-info}/licenses/LICENSE +0 -0
openadapt_ml/segmentation/frame_describer.py
@@ -0,0 +1,788 @@
+"""Frame-level description using Vision-Language Models.
+
+This module processes recording frames with their associated actions
+to generate semantic descriptions of user behavior (Stage 1 of pipeline).
+"""
+
+import hashlib
+import json
+import logging
+from abc import ABC, abstractmethod
+from datetime import datetime
+from pathlib import Path
+from typing import Optional, Union
+
+from PIL import Image
+
+from openadapt_ml.segmentation.schemas import (
+    ActionTranscript,
+    ActionType,
+    FrameDescription,
+)
+
+logger = logging.getLogger(__name__)
+
+
+class VLMBackend(ABC):
+    """Abstract base class for VLM backend implementations."""
+
+    @abstractmethod
+    def describe_frame(
+        self,
+        image: Image.Image,
+        action_context: dict,
+        system_prompt: str,
+        user_prompt: str,
+    ) -> dict:
+        """Generate description for a single frame."""
+        pass
+
+    @abstractmethod
+    def describe_batch(
+        self,
+        images: list[Image.Image],
+        action_contexts: list[dict],
+        system_prompt: str,
+        user_prompt: str,
+    ) -> list[dict]:
+        """Generate descriptions for multiple frames (more efficient)."""
+        pass
+
+
+class GeminiBackend(VLMBackend):
+    """Google Gemini VLM backend."""
+
+    def __init__(
+        self,
+        model: str = "gemini-2.0-flash",
+        api_key: Optional[str] = None,
+    ):
+        self.model = model
+        self._api_key = api_key
+        self._client = None
+
+    def _get_client(self):
+        if self._client is None:
+            import google.generativeai as genai
+            from openadapt_ml.config import settings
+
+            api_key = self._api_key or settings.google_api_key
+            if not api_key:
+                raise ValueError("GOOGLE_API_KEY not set")
+            genai.configure(api_key=api_key)
+            self._client = genai.GenerativeModel(self.model)
+        return self._client
+
+    def describe_frame(
+        self,
+        image: Image.Image,
+        action_context: dict,
+        system_prompt: str,
+        user_prompt: str,
+    ) -> dict:
+        client = self._get_client()
+        full_prompt = f"{system_prompt}\n\n{user_prompt}"
+        response = client.generate_content([full_prompt, image])
+        return self._parse_response(response.text)
+
+    def describe_batch(
+        self,
+        images: list[Image.Image],
+        action_contexts: list[dict],
+        system_prompt: str,
+        user_prompt: str,
+    ) -> list[dict]:
+        # Gemini can handle multiple images in one call
+        client = self._get_client()
+        full_prompt = f"{system_prompt}\n\n{user_prompt}"
+        content = [full_prompt] + images
+        response = client.generate_content(content)
+        return self._parse_batch_response(response.text, len(images))
+
+    def _parse_response(self, text: str) -> dict:
+        """Parse JSON from response text."""
+        try:
+            # Find JSON in response
+            start = text.find("{")
+            end = text.rfind("}") + 1
+            if start >= 0 and end > start:
+                return json.loads(text[start:end])
+        except json.JSONDecodeError:
+            pass
+        return {"apparent_intent": text, "confidence": 0.5}
+
+    def _parse_batch_response(self, text: str, count: int) -> list[dict]:
+        """Parse batch JSON response."""
+        try:
+            start = text.find("{")
+            end = text.rfind("}") + 1
+            if start >= 0 and end > start:
+                data = json.loads(text[start:end])
+                if "frames" in data:
+                    return data["frames"]
+        except json.JSONDecodeError:
+            pass
+        return [
+            {"apparent_intent": f"Frame {i}", "confidence": 0.5} for i in range(count)
+        ]
+
+
+class ClaudeBackend(VLMBackend):
+    """Anthropic Claude VLM backend."""
+
+    def __init__(
+        self,
+        model: str = "claude-sonnet-4-20250514",
+        api_key: Optional[str] = None,
+    ):
+        self.model = model
+        self._api_key = api_key
+        self._client = None
+
+    def _get_client(self):
+        if self._client is None:
+            import anthropic
+            from openadapt_ml.config import settings
+
+            api_key = self._api_key or settings.anthropic_api_key
+            self._client = anthropic.Anthropic(api_key=api_key)
+        return self._client
+
+    def _encode_image(self, image: Image.Image) -> dict:
+        """Encode image for Claude API."""
+        import base64
+        import io
+
+        buffer = io.BytesIO()
+        image.save(buffer, format="PNG")
+        b64 = base64.b64encode(buffer.getvalue()).decode()
+        return {
+            "type": "image",
+            "source": {
+                "type": "base64",
+                "media_type": "image/png",
+                "data": b64,
+            },
+        }
+
+    def describe_frame(
+        self,
+        image: Image.Image,
+        action_context: dict,
+        system_prompt: str,
+        user_prompt: str,
+    ) -> dict:
+        client = self._get_client()
+        response = client.messages.create(
+            model=self.model,
+            max_tokens=1024,
+            system=system_prompt,
+            messages=[
+                {
+                    "role": "user",
+                    "content": [
+                        self._encode_image(image),
+                        {"type": "text", "text": user_prompt},
+                    ],
+                }
+            ],
+        )
+        return self._parse_response(response.content[0].text)
+
+    def describe_batch(
+        self,
+        images: list[Image.Image],
+        action_contexts: list[dict],
+        system_prompt: str,
+        user_prompt: str,
+    ) -> list[dict]:
+        client = self._get_client()
+        content = []
+        for img in images:
+            content.append(self._encode_image(img))
+        content.append({"type": "text", "text": user_prompt})
+
+        response = client.messages.create(
+            model=self.model,
+            max_tokens=4096,
+            system=system_prompt,
+            messages=[{"role": "user", "content": content}],
+        )
+        return self._parse_batch_response(response.content[0].text, len(images))
+
+    def _parse_response(self, text: str) -> dict:
+        try:
+            start = text.find("{")
+            end = text.rfind("}") + 1
+            if start >= 0 and end > start:
+                return json.loads(text[start:end])
+        except json.JSONDecodeError:
+            pass
+        return {"apparent_intent": text, "confidence": 0.5}
+
+    def _parse_batch_response(self, text: str, count: int) -> list[dict]:
+        try:
+            start = text.find("{")
+            end = text.rfind("}") + 1
+            if start >= 0 and end > start:
+                data = json.loads(text[start:end])
+                if "frames" in data:
+                    return data["frames"]
+        except json.JSONDecodeError:
+            pass
+        return [
+            {"apparent_intent": f"Frame {i}", "confidence": 0.5} for i in range(count)
+        ]
+
+
+class OpenAIBackend(VLMBackend):
+    """OpenAI GPT-4V backend."""
+
+    def __init__(
+        self,
+        model: str = "gpt-4o",
+        api_key: Optional[str] = None,
+    ):
+        self.model = model
+        self._api_key = api_key
+        self._client = None
+
+    def _get_client(self):
+        if self._client is None:
+            import openai
+            from openadapt_ml.config import settings
+
+            api_key = self._api_key or settings.openai_api_key
+            self._client = openai.OpenAI(api_key=api_key)
+        return self._client
+
+    def _encode_image(self, image: Image.Image) -> dict:
+        """Encode image for OpenAI API."""
+        import base64
+        import io
+
+        buffer = io.BytesIO()
+        image.save(buffer, format="PNG")
+        b64 = base64.b64encode(buffer.getvalue()).decode()
+        return {
+            "type": "image_url",
+            "image_url": {"url": f"data:image/png;base64,{b64}"},
+        }
+
+    def describe_frame(
+        self,
+        image: Image.Image,
+        action_context: dict,
+        system_prompt: str,
+        user_prompt: str,
+    ) -> dict:
+        client = self._get_client()
+        response = client.chat.completions.create(
+            model=self.model,
+            max_tokens=1024,
+            messages=[
+                {"role": "system", "content": system_prompt},
+                {
+                    "role": "user",
+                    "content": [
+                        self._encode_image(image),
+                        {"type": "text", "text": user_prompt},
+                    ],
+                },
+            ],
+        )
+        return self._parse_response(response.choices[0].message.content)
+
+    def describe_batch(
+        self,
+        images: list[Image.Image],
+        action_contexts: list[dict],
+        system_prompt: str,
+        user_prompt: str,
+    ) -> list[dict]:
+        client = self._get_client()
+        content = [self._encode_image(img) for img in images]
+        content.append({"type": "text", "text": user_prompt})
+
+        response = client.chat.completions.create(
+            model=self.model,
+            max_tokens=4096,
+            messages=[
+                {"role": "system", "content": system_prompt},
+                {"role": "user", "content": content},
+            ],
+        )
+        return self._parse_batch_response(
+            response.choices[0].message.content, len(images)
+        )
+
+    def _parse_response(self, text: str) -> dict:
+        try:
+            start = text.find("{")
+            end = text.rfind("}") + 1
+            if start >= 0 and end > start:
+                return json.loads(text[start:end])
+        except json.JSONDecodeError:
+            pass
+        return {"apparent_intent": text, "confidence": 0.5}
+
+    def _parse_batch_response(self, text: str, count: int) -> list[dict]:
+        try:
+            start = text.find("{")
+            end = text.rfind("}") + 1
+            if start >= 0 and end > start:
+                data = json.loads(text[start:end])
+                if "frames" in data:
+                    return data["frames"]
+        except json.JSONDecodeError:
+            pass
+        return [
+            {"apparent_intent": f"Frame {i}", "confidence": 0.5} for i in range(count)
+        ]
+
+
+def _format_timestamp(seconds: float) -> str:
+    """Format seconds as MM:SS.m"""
+    minutes = int(seconds // 60)
+    secs = seconds % 60
+    return f"{minutes:02d}:{secs:04.1f}"
+
+
+def _get_action_type(action_name: str) -> ActionType:
+    """Convert action name to ActionType enum."""
+    name_lower = action_name.lower()
+    if "double" in name_lower:
+        return ActionType.DOUBLE_CLICK
+    elif "right" in name_lower:
+        return ActionType.RIGHT_CLICK
+    elif "click" in name_lower:
+        return ActionType.CLICK
+    elif "type" in name_lower or "key" in name_lower:
+        return ActionType.TYPE
+    elif "scroll" in name_lower:
+        return ActionType.SCROLL
+    elif "drag" in name_lower:
+        return ActionType.DRAG
+    elif "hotkey" in name_lower:
+        return ActionType.HOTKEY
+    elif "move" in name_lower:
+        return ActionType.MOVE
+    return ActionType.CLICK
+
+
+class FrameDescriber:
+    """Generates semantic descriptions of recording frames using VLMs.
+
+    This class implements Stage 1 of the segmentation pipeline, converting
+    raw screenshots and action data into human-readable descriptions.
+
+    Example:
+        >>> describer = FrameDescriber(model="gemini-2.0-flash")
+        >>> transcript = describer.describe_recording(recording)
+        >>> print(transcript.to_transcript_text())
+        [00:00.0] User opens System Preferences from Apple menu
+        [00:02.5] User clicks Display settings icon
+        ...
+
+    Attributes:
+        model: VLM model identifier
+        batch_size: Number of frames to process per API call
+        cache_enabled: Whether to cache descriptions
+    """
+
+    SUPPORTED_MODELS = [
+        "gemini-2.0-flash",
+        "gemini-2.0-pro",
+        "claude-sonnet-4-20250514",
+        "claude-3-5-haiku-20241022",
+        "gpt-4o",
+        "gpt-4o-mini",
+    ]
+
+    def __init__(
+        self,
+        model: str = "gemini-2.0-flash",
+        batch_size: int = 10,
+        cache_enabled: bool = True,
+        cache_dir: Optional[Path] = None,
+        backend: Optional[VLMBackend] = None,
+    ) -> None:
+        """Initialize the frame describer.
+
+        Args:
+            model: VLM model to use.
+            batch_size: Number of frames per API call.
+            cache_enabled: Cache descriptions to avoid reprocessing.
+            cache_dir: Directory for cached descriptions.
+            backend: Custom VLM backend (for testing or custom models).
+        """
+        self.model = model
+        self.batch_size = batch_size
+        self.cache_enabled = cache_enabled
+        self.cache_dir = (
+            cache_dir or Path.home() / ".openadapt" / "cache" / "descriptions"
+        )
+        self._backend = backend or self._create_backend(model)
+
+    def _create_backend(self, model: str) -> VLMBackend:
+        """Create appropriate backend for the specified model."""
+        if "gemini" in model.lower():
+            return GeminiBackend(model=model)
+        elif "claude" in model.lower():
+            return ClaudeBackend(model=model)
+        elif "gpt" in model.lower():
+            return OpenAIBackend(model=model)
+        else:
+            raise ValueError(
+                f"Unknown model: {model}. Supported: {self.SUPPORTED_MODELS}"
+            )
+
+    def _get_system_prompt(self) -> str:
+        """Return system prompt for VLM."""
+        return """You are an expert at analyzing GUI screenshots and user actions. Your task is to describe what the user is doing in each screenshot, focusing on:
+
+1. **Context**: What application is open? What screen/view is visible?
+2. **Action**: What specific action did the user take? (click, type, scroll, etc.)
+3. **Intent**: What is the user trying to accomplish with this action?
+
+Provide descriptions that would help someone understand and reproduce the workflow.
+
+Guidelines:
+- Be specific about UI elements (e.g., "Night Shift toggle" not "a button")
+- Include relevant text visible on screen when it clarifies intent
+- Note any state changes visible in the screenshot
+- Keep descriptions concise but complete (1-2 sentences typically)"""
+
+    def _get_user_prompt(self, frames_data: list[dict]) -> str:
+        """Build user prompt for batch of frames."""
+        lines = ["Analyze the following screenshot(s) and action(s):\n"]
+
+        for i, frame in enumerate(frames_data, 1):
+            lines.append(f"## Frame {i} ({frame['timestamp_formatted']})")
+            lines.append("**Action performed**:")
+            lines.append(f"- Type: {frame['action']['name']}")
+            if frame["action"].get("mouse_x") is not None:
+                lines.append(
+                    f"- Location: ({int(frame['action']['mouse_x'])}, {int(frame['action']['mouse_y'])})"
+                )
+            if frame["action"].get("text"):
+                lines.append(f'- Text typed: "{frame["action"]["text"]}"')
+            lines.append("")
+
+        lines.append("""For each frame, provide a JSON response with this structure:
+```json
+{
+  "frames": [
+    {
+      "frame_index": 1,
+      "visible_application": "Application name",
+      "visible_elements": ["element1", "element2"],
+      "screen_context": "Brief description of the overall screen state",
+      "action_target": "Specific UI element targeted",
+      "apparent_intent": "What the user is trying to accomplish",
+      "confidence": 0.95
+    }
+  ]
+}
+```""")
+        return "\n".join(lines)
+
+    def _cache_key(self, image: Image.Image, action: dict) -> str:
+        """Generate cache key for a frame."""
+        import io
+
+        buffer = io.BytesIO()
+        image.save(buffer, format="PNG")
+        img_hash = hashlib.md5(buffer.getvalue()).hexdigest()[:12]
+        action_str = json.dumps(action, sort_keys=True)
+        action_hash = hashlib.md5(action_str.encode()).hexdigest()[:8]
+        return f"{img_hash}_{action_hash}"
+
+    def _load_cached(self, cache_key: str) -> Optional[dict]:
+        """Load cached description."""
+        if not self.cache_enabled:
+            return None
+        cache_file = self.cache_dir / f"{cache_key}.json"
+        if cache_file.exists():
+            try:
+                return json.loads(cache_file.read_text())
+            except Exception:
+                pass
+        return None
+
+    def _save_cached(self, cache_key: str, description: dict) -> None:
+        """Save description to cache."""
+        if not self.cache_enabled:
+            return
+        self.cache_dir.mkdir(parents=True, exist_ok=True)
+        cache_file = self.cache_dir / f"{cache_key}.json"
+        try:
+            cache_file.write_text(json.dumps(description))
+        except Exception as e:
+            logger.warning(f"Failed to cache description: {e}")
+
+    def describe_recording(
+        self,
+        recording_path: Union[str, Path],
+        progress_callback: Optional[callable] = None,
+    ) -> ActionTranscript:
+        """Generate descriptions for all frames in a recording.
+
+        Args:
+            recording_path: Path to recording directory or file.
+            progress_callback: Optional callback(current, total) for progress.
+
+        Returns:
+            ActionTranscript with descriptions for all frames.
+        """
+        recording_path = Path(recording_path)
+        if not recording_path.exists():
+            raise FileNotFoundError(f"Recording not found: {recording_path}")
+
+        # Load recording data
+        images, action_events = self._load_recording(recording_path)
+        recording_id = recording_path.name
+        recording_name = recording_path.stem
+
+        # Process in batches
+        frame_descriptions = []
+        total_frames = len(images)
+
+        for batch_start in range(0, total_frames, self.batch_size):
+            batch_end = min(batch_start + self.batch_size, total_frames)
+            batch_images = images[batch_start:batch_end]
+            batch_actions = action_events[batch_start:batch_end]
+
+            # Check cache first
+            batch_results = []
+            uncached_indices = []
+
+            for i, (img, action) in enumerate(zip(batch_images, batch_actions)):
+                cache_key = self._cache_key(img, action)
+                cached = self._load_cached(cache_key)
+                if cached:
+                    batch_results.append((i, cached))
+                else:
+                    uncached_indices.append(i)
+
+            # Process uncached frames
+            if uncached_indices:
+                uncached_images = [batch_images[i] for i in uncached_indices]
+                uncached_actions = [batch_actions[i] for i in uncached_indices]
+
+                frames_data = [
+                    {
+                        "timestamp_formatted": _format_timestamp(a.get("timestamp", 0)),
+                        "action": a,
+                    }
+                    for a in uncached_actions
+                ]
+
+                descriptions = self._backend.describe_batch(
+                    uncached_images,
+                    uncached_actions,
+                    self._get_system_prompt(),
+                    self._get_user_prompt(frames_data),
+                )
+
+                for i, desc in zip(uncached_indices, descriptions):
+                    batch_results.append((i, desc))
+                    cache_key = self._cache_key(batch_images[i], batch_actions[i])
+                    self._save_cached(cache_key, desc)
+
+            # Sort by index and create FrameDescriptions
+            batch_results.sort(key=lambda x: x[0])
+            for i, (idx, desc) in enumerate(batch_results):
+                frame_idx = batch_start + idx
+                action = batch_actions[idx]
+                timestamp = action.get("timestamp", 0)
+
+                frame_desc = FrameDescription(
+                    timestamp=timestamp,
+                    formatted_time=_format_timestamp(timestamp),
+                    visible_application=desc.get("visible_application", "Unknown"),
+                    visible_elements=desc.get("visible_elements", []),
+                    screen_context=desc.get("screen_context", ""),
+                    action_type=_get_action_type(action.get("name", "click")),
+                    action_target=desc.get("action_target"),
+                    action_value=action.get("text"),
+                    apparent_intent=desc.get("apparent_intent", "Unknown action"),
+                    confidence=desc.get("confidence", 0.5),
+                    frame_index=frame_idx,
+                    vlm_model=self.model,
+                )
+                frame_descriptions.append(frame_desc)
+
+            if progress_callback:
+                progress_callback(batch_end, total_frames)
+
+        # Calculate total duration
+        total_duration = 0
+        if frame_descriptions:
+            total_duration = max(f.timestamp for f in frame_descriptions)
+
+        return ActionTranscript(
+            recording_id=recording_id,
+            recording_name=recording_name,
+            frames=frame_descriptions,
+            total_duration=total_duration,
+            frame_count=len(frame_descriptions),
+            vlm_model=self.model,
+            processing_timestamp=datetime.now(),
+        )
+
+    def _load_recording(
+        self, recording_path: Path
+    ) -> tuple[list[Image.Image], list[dict]]:
+        """Load recording data from various formats."""
+        # Try to load from openadapt-capture SQLite format
+        if (recording_path / "capture.db").exists():
+            try:
+                from openadapt_ml.segmentation.adapters import CaptureAdapter
+
+                adapter = CaptureAdapter()
+                return adapter.load_recording(recording_path)
+            except Exception as e:
+                logger.warning(f"Failed to load via CaptureAdapter: {e}")
+                # Fall through to other formats
+
+        # Try to load from openadapt-capture format (events.json)
+        metadata_file = recording_path / "metadata.json"
+        if metadata_file.exists():
+            return self._load_capture_format(recording_path)
+
+        # Try loading from a single JSON file
+        if recording_path.suffix == ".json":
+            return self._load_json_format(recording_path)
+
+        # Try loading from directory with screenshots
+        return self._load_directory_format(recording_path)
+
+    def _load_capture_format(
+        self, recording_path: Path
+    ) -> tuple[list[Image.Image], list[dict]]:
+        """Load from openadapt-capture format."""
+        _metadata = json.loads((recording_path / "metadata.json").read_text())
+        # Note: _metadata contains recording_id, goal, timestamps but we load
+        # these at the transcript level, not per-frame
+        images = []
+        actions = []
+
+        screenshots_dir = recording_path / "screenshots"
+        events_file = recording_path / "events.json"
+
+        if events_file.exists():
+            events = json.loads(events_file.read_text())
+            for event in events:
+                screenshot_path = screenshots_dir / f"{event['frame_index']:06d}.png"
+                if screenshot_path.exists():
+                    images.append(Image.open(screenshot_path))
+                    actions.append(event)
+
+        return images, actions
+
+    def _load_json_format(
+        self, json_path: Path
+    ) -> tuple[list[Image.Image], list[dict]]:
+        """Load from JSON file with base64 screenshots."""
+        import base64
+        import io
+
+        data = json.loads(json_path.read_text())
+        images = []
+        actions = []
+
+        for frame in data.get("frames", []):
+            if "screenshot_base64" in frame:
+                img_data = base64.b64decode(frame["screenshot_base64"])
+                images.append(Image.open(io.BytesIO(img_data)))
+                actions.append(frame.get("action", {}))
+
+        return images, actions
+
+    def _load_directory_format(
+        self, dir_path: Path
+    ) -> tuple[list[Image.Image], list[dict]]:
+        """Load from directory with numbered screenshots."""
+        images = []
+        actions = []
+
+        # Find all PNG files
+        png_files = sorted(dir_path.glob("*.png"))
+        for i, png_file in enumerate(png_files):
+            images.append(Image.open(png_file))
+            # Create synthetic action event
+            actions.append(
+                {
+                    "name": "unknown",
+                    "timestamp": i * 1.0,  # Assume 1 second between frames
+                    "frame_index": i,
+                }
+            )
+
+        return images, actions
+
+    def describe_frame(
+        self,
+        image: Image.Image,
+        action_event: dict,
+        previous_context: Optional[str] = None,
+    ) -> FrameDescription:
+        """Generate description for a single frame."""
+        frames_data = [
+            {
+                "timestamp_formatted": _format_timestamp(
+                    action_event.get("timestamp", 0)
+                ),
+                "action": action_event,
+            }
+        ]
+
+        descriptions = self._backend.describe_batch(
+            [image],
+            [action_event],
+            self._get_system_prompt(),
+            self._get_user_prompt(frames_data),
+        )
+
+        desc = descriptions[0] if descriptions else {}
+        timestamp = action_event.get("timestamp", 0)
+
+        return FrameDescription(
+            timestamp=timestamp,
+            formatted_time=_format_timestamp(timestamp),
+            visible_application=desc.get("visible_application", "Unknown"),
+            visible_elements=desc.get("visible_elements", []),
+            screen_context=desc.get("screen_context", ""),
+            action_type=_get_action_type(action_event.get("name", "click")),
+            action_target=desc.get("action_target"),
+            action_value=action_event.get("text"),
+            apparent_intent=desc.get("apparent_intent", "Unknown action"),
+            confidence=desc.get("confidence", 0.5),
+            frame_index=action_event.get("frame_index", 0),
+            vlm_model=self.model,
+        )
+
+    def clear_cache(self, recording_id: Optional[str] = None) -> int:
+        """Clear cached descriptions.
+
+        Args:
+            recording_id: If specified, only clear cache for this recording.
+
+        Returns:
+            Number of cached items cleared.
+        """
+        if not self.cache_dir.exists():
+            return 0
+
+        count = 0
+        for cache_file in self.cache_dir.glob("*.json"):
+            if recording_id is None or recording_id in cache_file.name:
+                cache_file.unlink()
+                count += 1
+        return count
+
+    @property
+    def supported_models(self) -> list[str]:
+        """Return list of supported VLM models."""
+        return self.SUPPORTED_MODELS
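For orientation, here is a minimal usage sketch of the new `FrameDescriber` API added in this release, assembled from the class docstring and the signatures in the diff above. The recording path and progress callback are illustrative, and `ActionTranscript.to_transcript_text()` is assumed to behave as shown in the docstring example.

```python
from pathlib import Path

from openadapt_ml.segmentation.frame_describer import FrameDescriber

# Illustrative recording location; per _load_recording() above, a capture.db
# directory, an events.json + screenshots directory, a single JSON file,
# or a plain directory of numbered PNGs should all be accepted.
recording_path = Path.home() / "recordings" / "night_shift_demo"

describer = FrameDescriber(
    model="gemini-2.0-flash",  # any entry in FrameDescriber.SUPPORTED_MODELS
    batch_size=10,             # frames sent per VLM call
    cache_enabled=True,        # caches under ~/.openadapt/cache/descriptions
)

transcript = describer.describe_recording(
    recording_path,
    progress_callback=lambda done, total: print(f"described {done}/{total} frames"),
)

# to_transcript_text() is referenced in the FrameDescriber docstring example.
print(transcript.to_transcript_text())
```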