openadapt-ml 0.1.0__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- openadapt_ml/benchmarks/__init__.py +8 -0
- openadapt_ml/benchmarks/agent.py +90 -11
- openadapt_ml/benchmarks/azure.py +35 -6
- openadapt_ml/benchmarks/cli.py +4449 -201
- openadapt_ml/benchmarks/live_tracker.py +180 -0
- openadapt_ml/benchmarks/runner.py +41 -4
- openadapt_ml/benchmarks/viewer.py +1219 -0
- openadapt_ml/benchmarks/vm_monitor.py +610 -0
- openadapt_ml/benchmarks/waa.py +61 -4
- openadapt_ml/benchmarks/waa_deploy/Dockerfile +222 -0
- openadapt_ml/benchmarks/waa_deploy/__init__.py +10 -0
- openadapt_ml/benchmarks/waa_deploy/api_agent.py +539 -0
- openadapt_ml/benchmarks/waa_deploy/start_waa_server.bat +53 -0
- openadapt_ml/benchmarks/waa_live.py +619 -0
- openadapt_ml/cloud/local.py +1555 -1
- openadapt_ml/cloud/ssh_tunnel.py +553 -0
- openadapt_ml/datasets/next_action.py +87 -68
- openadapt_ml/evals/grounding.py +26 -8
- openadapt_ml/evals/trajectory_matching.py +84 -36
- openadapt_ml/experiments/demo_prompt/__init__.py +19 -0
- openadapt_ml/experiments/demo_prompt/format_demo.py +226 -0
- openadapt_ml/experiments/demo_prompt/results/experiment_20251231_002125.json +83 -0
- openadapt_ml/experiments/demo_prompt/results/experiment_n30_20251231_165958.json +1100 -0
- openadapt_ml/experiments/demo_prompt/results/multistep_20251231_025051.json +182 -0
- openadapt_ml/experiments/demo_prompt/run_experiment.py +531 -0
- openadapt_ml/experiments/waa_demo/__init__.py +10 -0
- openadapt_ml/experiments/waa_demo/demos.py +357 -0
- openadapt_ml/experiments/waa_demo/runner.py +717 -0
- openadapt_ml/experiments/waa_demo/tasks.py +151 -0
- openadapt_ml/export/__init__.py +9 -0
- openadapt_ml/export/__main__.py +6 -0
- openadapt_ml/export/cli.py +89 -0
- openadapt_ml/export/parquet.py +265 -0
- openadapt_ml/ingest/__init__.py +3 -4
- openadapt_ml/ingest/capture.py +89 -81
- openadapt_ml/ingest/loader.py +116 -68
- openadapt_ml/ingest/synthetic.py +221 -159
- openadapt_ml/retrieval/README.md +226 -0
- openadapt_ml/retrieval/USAGE.md +391 -0
- openadapt_ml/retrieval/__init__.py +91 -0
- openadapt_ml/retrieval/demo_retriever.py +817 -0
- openadapt_ml/retrieval/embeddings.py +629 -0
- openadapt_ml/retrieval/index.py +194 -0
- openadapt_ml/retrieval/retriever.py +160 -0
- openadapt_ml/runtime/policy.py +10 -10
- openadapt_ml/schema/__init__.py +104 -0
- openadapt_ml/schema/converters.py +541 -0
- openadapt_ml/schema/episode.py +457 -0
- openadapt_ml/scripts/compare.py +26 -16
- openadapt_ml/scripts/eval_policy.py +4 -5
- openadapt_ml/scripts/prepare_synthetic.py +14 -17
- openadapt_ml/scripts/train.py +81 -70
- openadapt_ml/training/benchmark_viewer.py +3225 -0
- openadapt_ml/training/trainer.py +120 -363
- openadapt_ml/training/trl_trainer.py +354 -0
- {openadapt_ml-0.1.0.dist-info → openadapt_ml-0.2.0.dist-info}/METADATA +102 -60
- openadapt_ml-0.2.0.dist-info/RECORD +86 -0
- openadapt_ml/schemas/__init__.py +0 -53
- openadapt_ml/schemas/sessions.py +0 -122
- openadapt_ml/schemas/validation.py +0 -252
- openadapt_ml-0.1.0.dist-info/RECORD +0 -55
- {openadapt_ml-0.1.0.dist-info → openadapt_ml-0.2.0.dist-info}/WHEEL +0 -0
- {openadapt_ml-0.1.0.dist-info → openadapt_ml-0.2.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -6,7 +6,7 @@ from typing import Any, Dict, List
|
|
|
6
6
|
import torch
|
|
7
7
|
from torch.utils.data import Dataset
|
|
8
8
|
|
|
9
|
-
from openadapt_ml.
|
|
9
|
+
from openadapt_ml.schema import Action, ActionType, Episode, Step, UIElement
|
|
10
10
|
|
|
11
11
|
|
|
12
12
|
# Coordinate-based DSL system prompt (original)
|
|
@@ -97,6 +97,13 @@ SYSTEM_PROMPT_SOM_REGISTRATION = (
|
|
|
97
97
|
)
|
|
98
98
|
|
|
99
99
|
|
|
100
|
+
def _get_element_id(action: Action) -> str | None:
|
|
101
|
+
"""Extract element ID from action's element field."""
|
|
102
|
+
if action.element is not None and action.element.element_id is not None:
|
|
103
|
+
return action.element.element_id
|
|
104
|
+
return None
|
|
105
|
+
|
|
106
|
+
|
|
100
107
|
def format_action(action: Action, use_som: bool = False) -> str:
|
|
101
108
|
"""Serialize an Action into a simple textual command.
|
|
102
109
|
|
|
@@ -110,53 +117,55 @@ def format_action(action: Action, use_som: bool = False) -> str:
|
|
|
110
117
|
Args:
|
|
111
118
|
action: The action to format.
|
|
112
119
|
use_som: If True, use Set-of-Marks (SoM) index-based format instead of
|
|
113
|
-
coordinate-based format. Requires
|
|
120
|
+
coordinate-based format. Requires element with element_id to be set.
|
|
114
121
|
"""
|
|
115
122
|
|
|
116
123
|
t = action.type
|
|
124
|
+
element_id = _get_element_id(action)
|
|
117
125
|
if use_som:
|
|
118
126
|
# SoM mode: use element indices instead of coordinates
|
|
119
|
-
if t ==
|
|
120
|
-
return f"CLICK([{
|
|
121
|
-
if t ==
|
|
127
|
+
if t == ActionType.CLICK and element_id is not None:
|
|
128
|
+
return f"CLICK([{element_id}])"
|
|
129
|
+
if t == ActionType.TYPE and action.text is not None:
|
|
122
130
|
escaped = action.text.replace("\\", "\\\\").replace("\"", "\\\"")
|
|
123
|
-
if
|
|
124
|
-
return f"TYPE([{
|
|
131
|
+
if element_id is not None:
|
|
132
|
+
return f"TYPE([{element_id}], \"{escaped}\")"
|
|
125
133
|
else:
|
|
126
134
|
# Fallback: TYPE without element reference (for focused field)
|
|
127
135
|
return f"TYPE(\"{escaped}\")"
|
|
128
|
-
if t ==
|
|
136
|
+
if t == ActionType.WAIT:
|
|
129
137
|
return "WAIT()"
|
|
130
|
-
if t ==
|
|
138
|
+
if t == ActionType.DONE:
|
|
131
139
|
return "DONE()"
|
|
132
140
|
# Fallback
|
|
133
|
-
return f"ACTION(type={t})"
|
|
141
|
+
return f"ACTION(type={t.value if isinstance(t, ActionType) else t})"
|
|
134
142
|
else:
|
|
135
143
|
# Coordinate mode (original)
|
|
136
|
-
if t ==
|
|
137
|
-
|
|
138
|
-
|
|
144
|
+
if t == ActionType.CLICK and action.normalized_coordinates is not None:
|
|
145
|
+
x, y = action.normalized_coordinates
|
|
146
|
+
return f"CLICK(x={x:.2f}, y={y:.2f})"
|
|
147
|
+
if t == ActionType.TYPE and action.text is not None:
|
|
139
148
|
escaped = action.text.replace("\\", "\\\\").replace("\"", "\\\"")
|
|
140
149
|
return f"TYPE(text=\"{escaped}\")"
|
|
141
|
-
if t ==
|
|
150
|
+
if t == ActionType.WAIT:
|
|
142
151
|
return "WAIT()"
|
|
143
|
-
if t ==
|
|
152
|
+
if t == ActionType.DONE:
|
|
144
153
|
return "DONE()"
|
|
145
154
|
# Fallback
|
|
146
|
-
return f"ACTION(type={t})"
|
|
155
|
+
return f"ACTION(type={t.value if isinstance(t, ActionType) else t})"
|
|
147
156
|
|
|
148
157
|
|
|
149
158
|
def parse_action_som(text: str) -> Action:
|
|
150
159
|
"""Parse a SoM-style action string into an Action object.
|
|
151
160
|
|
|
152
161
|
Supported formats:
|
|
153
|
-
- CLICK([N])
|
|
154
|
-
- TYPE([N], "text")
|
|
155
|
-
- TYPE("text")
|
|
156
|
-
- WAIT()
|
|
157
|
-
- DONE()
|
|
162
|
+
- CLICK([N]) -> click element N
|
|
163
|
+
- TYPE([N], "text") -> type text into element N
|
|
164
|
+
- TYPE("text") -> type text into focused field
|
|
165
|
+
- WAIT() -> wait
|
|
166
|
+
- DONE() -> done
|
|
158
167
|
|
|
159
|
-
Returns Action with
|
|
168
|
+
Returns Action with element set for click/type actions.
|
|
160
169
|
"""
|
|
161
170
|
import re
|
|
162
171
|
|
|
@@ -165,32 +174,32 @@ def parse_action_som(text: str) -> Action:
|
|
|
165
174
|
# CLICK([N])
|
|
166
175
|
match = re.match(r"CLICK\(\[(\d+)\]\)", text)
|
|
167
176
|
if match:
|
|
168
|
-
idx =
|
|
169
|
-
return Action(type=
|
|
177
|
+
idx = match.group(1)
|
|
178
|
+
return Action(type=ActionType.CLICK, element=UIElement(element_id=idx))
|
|
170
179
|
|
|
171
180
|
# TYPE([N], "text") or TYPE([N], 'text')
|
|
172
181
|
match = re.match(r'TYPE\(\[(\d+)\],\s*["\'](.*)["\']\)', text, re.DOTALL)
|
|
173
182
|
if match:
|
|
174
|
-
idx =
|
|
183
|
+
idx = match.group(1)
|
|
175
184
|
content = match.group(2).replace("\\\"", "\"").replace("\\\\", "\\")
|
|
176
|
-
return Action(type=
|
|
185
|
+
return Action(type=ActionType.TYPE, text=content, element=UIElement(element_id=idx))
|
|
177
186
|
|
|
178
187
|
# TYPE("text") - no element index
|
|
179
188
|
match = re.match(r'TYPE\(["\'](.*)["\']\)', text, re.DOTALL)
|
|
180
189
|
if match:
|
|
181
190
|
content = match.group(1).replace("\\\"", "\"").replace("\\\\", "\\")
|
|
182
|
-
return Action(type=
|
|
191
|
+
return Action(type=ActionType.TYPE, text=content)
|
|
183
192
|
|
|
184
193
|
# WAIT()
|
|
185
194
|
if text.upper().startswith("WAIT"):
|
|
186
|
-
return Action(type=
|
|
195
|
+
return Action(type=ActionType.WAIT)
|
|
187
196
|
|
|
188
197
|
# DONE()
|
|
189
198
|
if text.upper().startswith("DONE"):
|
|
190
|
-
return Action(type=
|
|
199
|
+
return Action(type=ActionType.DONE)
|
|
191
200
|
|
|
192
201
|
# Failed to parse
|
|
193
|
-
return Action(type=
|
|
202
|
+
return Action(type=ActionType.FAIL, raw={"text": text})
|
|
194
203
|
|
|
195
204
|
|
|
196
205
|
def _generate_generic_thought(step_index: int, step: Step, goal: str, total_steps: int) -> str:
|
|
@@ -205,10 +214,10 @@ def _generate_generic_thought(step_index: int, step: Step, goal: str, total_step
|
|
|
205
214
|
# Progress context
|
|
206
215
|
progress = f"Step {step_index + 1} of {total_steps}."
|
|
207
216
|
|
|
208
|
-
if t ==
|
|
209
|
-
if action.
|
|
217
|
+
if t == ActionType.CLICK:
|
|
218
|
+
if action.normalized_coordinates is not None:
|
|
210
219
|
# Describe the click location relative to screen regions
|
|
211
|
-
x, y = action.
|
|
220
|
+
x, y = action.normalized_coordinates
|
|
212
221
|
h_pos = "left" if x < 0.33 else ("center" if x < 0.66 else "right")
|
|
213
222
|
v_pos = "top" if y < 0.33 else ("middle" if y < 0.66 else "bottom")
|
|
214
223
|
return (
|
|
@@ -217,28 +226,28 @@ def _generate_generic_thought(step_index: int, step: Step, goal: str, total_step
|
|
|
217
226
|
)
|
|
218
227
|
return f"{progress} I need to click on the relevant UI element to continue toward '{goal}'."
|
|
219
228
|
|
|
220
|
-
if t ==
|
|
229
|
+
if t == ActionType.DOUBLE_CLICK:
|
|
221
230
|
return f"{progress} I need to double-click to select or activate this element for '{goal}'."
|
|
222
231
|
|
|
223
|
-
if t ==
|
|
232
|
+
if t == ActionType.TYPE:
|
|
224
233
|
if action.text:
|
|
225
234
|
# Don't reveal the actual text, just indicate typing is needed
|
|
226
235
|
return f"{progress} I need to type text into the focused input field to continue toward '{goal}'."
|
|
227
236
|
return f"{progress} I need to enter text in the current field."
|
|
228
237
|
|
|
229
|
-
if t ==
|
|
238
|
+
if t == ActionType.SCROLL:
|
|
230
239
|
return f"{progress} I need to scroll to reveal more content or reach the target element for '{goal}'."
|
|
231
240
|
|
|
232
|
-
if t ==
|
|
241
|
+
if t == ActionType.DRAG:
|
|
233
242
|
return f"{progress} I need to drag an element to complete this part of '{goal}'."
|
|
234
243
|
|
|
235
|
-
if t ==
|
|
244
|
+
if t == ActionType.KEY:
|
|
236
245
|
return f"{progress} I need to press a key to continue the workflow."
|
|
237
246
|
|
|
238
|
-
if t ==
|
|
247
|
+
if t == ActionType.WAIT:
|
|
239
248
|
return f"{progress} I should wait for the UI to update before the next action."
|
|
240
249
|
|
|
241
|
-
if t ==
|
|
250
|
+
if t == ActionType.DONE:
|
|
242
251
|
return f"The goal '{goal}' has been achieved. The workflow is complete."
|
|
243
252
|
|
|
244
253
|
# Fallback
|
|
@@ -279,42 +288,42 @@ def _generate_login_thought(step_index: int, step: Step, goal: str, total_steps:
|
|
|
279
288
|
t = action.type
|
|
280
289
|
|
|
281
290
|
# Step 0: click username field
|
|
282
|
-
if step_index == 0 and t ==
|
|
291
|
+
if step_index == 0 and t == ActionType.CLICK:
|
|
283
292
|
return (
|
|
284
293
|
"I see a login screen with empty username and password fields and a Login button. "
|
|
285
294
|
f"To start logging in, I need to click on the username field to focus it ({goal})."
|
|
286
295
|
)
|
|
287
296
|
|
|
288
297
|
# Step 1: type username
|
|
289
|
-
if step_index == 1 and t ==
|
|
298
|
+
if step_index == 1 and t == ActionType.TYPE:
|
|
290
299
|
return (
|
|
291
300
|
"The username field is focused. To move toward the login goal, I should type the "
|
|
292
301
|
"username into this field."
|
|
293
302
|
)
|
|
294
303
|
|
|
295
304
|
# Step 2: click password field
|
|
296
|
-
if step_index == 2 and t ==
|
|
305
|
+
if step_index == 2 and t == ActionType.CLICK:
|
|
297
306
|
return (
|
|
298
307
|
"The username has been entered. Next, I need to focus the password field so that I can "
|
|
299
308
|
"enter the password for this login. I will click on the password input box."
|
|
300
309
|
)
|
|
301
310
|
|
|
302
311
|
# Step 3: type password
|
|
303
|
-
if step_index == 3 and t ==
|
|
312
|
+
if step_index == 3 and t == ActionType.TYPE:
|
|
304
313
|
return (
|
|
305
314
|
"The password field is focused. To continue the login process, I should type the "
|
|
306
315
|
"password (which will appear as masked characters on the screen)."
|
|
307
316
|
)
|
|
308
317
|
|
|
309
318
|
# Step 4: click Login button
|
|
310
|
-
if step_index == 4 and t ==
|
|
319
|
+
if step_index == 4 and t == ActionType.CLICK:
|
|
311
320
|
return (
|
|
312
321
|
"Both the username and password have been entered. To submit the form and attempt the "
|
|
313
322
|
"login, I should click the Login button."
|
|
314
323
|
)
|
|
315
324
|
|
|
316
325
|
# Step 5: DONE on logged-in screen
|
|
317
|
-
if step_index == 5 and t ==
|
|
326
|
+
if step_index == 5 and t == ActionType.DONE:
|
|
318
327
|
return (
|
|
319
328
|
"I now see a logged-in confirmation screen indicating the goal has been satisfied. "
|
|
320
329
|
"The task is complete, so I should emit DONE()."
|
|
@@ -334,41 +343,41 @@ def _generate_registration_thought(step_index: int, step: Step, goal: str, total
|
|
|
334
343
|
|
|
335
344
|
# Registration step mapping (pairs of click + type for 5 fields, then submit + done)
|
|
336
345
|
thoughts = {
|
|
337
|
-
(0,
|
|
346
|
+
(0, ActionType.CLICK): (
|
|
338
347
|
"I see a registration form with empty fields for name, email, and password. "
|
|
339
348
|
f"To start registration, I need to click on the First Name field ({goal})."
|
|
340
349
|
),
|
|
341
|
-
(1,
|
|
350
|
+
(1, ActionType.TYPE): (
|
|
342
351
|
"The First Name field is focused. I should type the first name."
|
|
343
352
|
),
|
|
344
|
-
(2,
|
|
353
|
+
(2, ActionType.CLICK): (
|
|
345
354
|
"First name entered. Now I need to focus the Last Name field to enter it."
|
|
346
355
|
),
|
|
347
|
-
(3,
|
|
356
|
+
(3, ActionType.TYPE): (
|
|
348
357
|
"The Last Name field is focused. I should type the last name."
|
|
349
358
|
),
|
|
350
|
-
(4,
|
|
359
|
+
(4, ActionType.CLICK): (
|
|
351
360
|
"Last name entered. Now I need to focus the Email field to enter the email address."
|
|
352
361
|
),
|
|
353
|
-
(5,
|
|
362
|
+
(5, ActionType.TYPE): (
|
|
354
363
|
"The Email field is focused. I should type the email address."
|
|
355
364
|
),
|
|
356
|
-
(6,
|
|
365
|
+
(6, ActionType.CLICK): (
|
|
357
366
|
"Email entered. Now I need to focus the Password field to create a password."
|
|
358
367
|
),
|
|
359
|
-
(7,
|
|
368
|
+
(7, ActionType.TYPE): (
|
|
360
369
|
"The Password field is focused. I should type the password."
|
|
361
370
|
),
|
|
362
|
-
(8,
|
|
371
|
+
(8, ActionType.CLICK): (
|
|
363
372
|
"Password entered. Now I need to focus the Confirm Password field to verify the password."
|
|
364
373
|
),
|
|
365
|
-
(9,
|
|
374
|
+
(9, ActionType.TYPE): (
|
|
366
375
|
"The Confirm Password field is focused. I should type the same password again."
|
|
367
376
|
),
|
|
368
|
-
(10,
|
|
377
|
+
(10, ActionType.CLICK): (
|
|
369
378
|
"All form fields are filled. I should click the Register button to submit the form."
|
|
370
379
|
),
|
|
371
|
-
(11,
|
|
380
|
+
(11, ActionType.DONE): (
|
|
372
381
|
"Registration is complete - I see a success screen. The task is finished."
|
|
373
382
|
),
|
|
374
383
|
}
|
|
@@ -385,10 +394,16 @@ def _generate_registration_thought(step_index: int, step: Step, goal: str, total
|
|
|
385
394
|
|
|
386
395
|
|
|
387
396
|
def _detect_scenario(episode: Episode) -> str:
|
|
388
|
-
"""Detect scenario from episode
|
|
389
|
-
|
|
390
|
-
|
|
397
|
+
"""Detect scenario from episode task_id or metadata."""
|
|
398
|
+
# Check task_id first
|
|
399
|
+
task_id = episode.task_id or ""
|
|
400
|
+
if "registration" in task_id.lower():
|
|
391
401
|
return "registration"
|
|
402
|
+
# Check metadata for workflow_id (backward compatibility)
|
|
403
|
+
if episode.metadata and "workflow_id" in episode.metadata:
|
|
404
|
+
workflow_id = episode.metadata.get("workflow_id", "")
|
|
405
|
+
if "registration" in str(workflow_id).lower():
|
|
406
|
+
return "registration"
|
|
392
407
|
return "login"
|
|
393
408
|
|
|
394
409
|
|
|
@@ -417,7 +432,8 @@ def build_next_action_sft_samples(
|
|
|
417
432
|
samples: List[Dict[str, Any]] = []
|
|
418
433
|
|
|
419
434
|
for episode in episodes:
|
|
420
|
-
goal
|
|
435
|
+
# Use instruction as the goal (new schema field name)
|
|
436
|
+
goal = episode.instruction
|
|
421
437
|
total_steps = len(episode.steps)
|
|
422
438
|
scenario = _detect_scenario(episode)
|
|
423
439
|
|
|
@@ -430,18 +446,21 @@ def build_next_action_sft_samples(
|
|
|
430
446
|
else:
|
|
431
447
|
system_prompt = SYSTEM_PROMPT
|
|
432
448
|
|
|
433
|
-
for
|
|
434
|
-
|
|
449
|
+
for step in episode.steps:
|
|
450
|
+
# Use step_index from the Step model
|
|
451
|
+
step_index = step.step_index
|
|
452
|
+
# Use screenshot_path instead of image_path
|
|
453
|
+
image_path = step.observation.screenshot_path
|
|
435
454
|
if not image_path:
|
|
436
455
|
# Skip steps without an associated image
|
|
437
456
|
continue
|
|
438
457
|
|
|
439
458
|
# Build action history from previous steps
|
|
440
459
|
action_history = []
|
|
441
|
-
for
|
|
442
|
-
prev_step
|
|
443
|
-
|
|
444
|
-
|
|
460
|
+
for prev_step in episode.steps:
|
|
461
|
+
if prev_step.step_index < step_index:
|
|
462
|
+
prev_action_text = format_action(prev_step.action, use_som=use_som)
|
|
463
|
+
action_history.append(prev_action_text)
|
|
445
464
|
|
|
446
465
|
# Build history section for both modes - use actual step count
|
|
447
466
|
if action_history:
|
openadapt_ml/evals/grounding.py
CHANGED
|
@@ -212,30 +212,48 @@ def evaluate_grounder_on_episode(
|
|
|
212
212
|
"""
|
|
213
213
|
from PIL import Image
|
|
214
214
|
|
|
215
|
-
from openadapt_ml.
|
|
215
|
+
from openadapt_ml.schema import Episode, ActionType
|
|
216
216
|
|
|
217
217
|
test_cases = []
|
|
218
218
|
|
|
219
219
|
for step in episode.steps:
|
|
220
220
|
action = step.action
|
|
221
221
|
|
|
222
|
+
# Get action type as string for comparison
|
|
223
|
+
action_type_str = action.type.value if isinstance(action.type, ActionType) else action.type
|
|
224
|
+
|
|
222
225
|
# Only evaluate clicks with bboxes
|
|
223
|
-
if
|
|
226
|
+
if action_type_str not in ("click", "double_click"):
|
|
224
227
|
continue
|
|
225
|
-
|
|
228
|
+
|
|
229
|
+
# Check for bbox - in new schema, bbox is in element.bounds or raw
|
|
230
|
+
bbox = None
|
|
231
|
+
if action.element and action.element.bounds:
|
|
232
|
+
b = action.element.bounds
|
|
233
|
+
bbox = (b.x, b.y, b.x + b.width, b.y + b.height)
|
|
234
|
+
elif action.raw and "bbox" in action.raw:
|
|
235
|
+
bbox = action.raw["bbox"]
|
|
236
|
+
|
|
237
|
+
if bbox is None:
|
|
226
238
|
continue
|
|
227
|
-
if step.observation.
|
|
239
|
+
if step.observation.screenshot_path is None:
|
|
228
240
|
continue
|
|
229
241
|
|
|
230
242
|
# Load image
|
|
231
243
|
try:
|
|
232
|
-
image = Image.open(step.observation.
|
|
244
|
+
image = Image.open(step.observation.screenshot_path)
|
|
233
245
|
except Exception:
|
|
234
246
|
continue
|
|
235
247
|
|
|
236
|
-
# Create target description from
|
|
237
|
-
|
|
248
|
+
# Create target description from reasoning or action coordinates
|
|
249
|
+
coords_x, coords_y = None, None
|
|
250
|
+
if action.normalized_coordinates:
|
|
251
|
+
coords_x, coords_y = action.normalized_coordinates
|
|
252
|
+
if coords_x is not None and coords_y is not None:
|
|
253
|
+
target_desc = step.reasoning or f"element at ({coords_x:.2f}, {coords_y:.2f})"
|
|
254
|
+
else:
|
|
255
|
+
target_desc = step.reasoning or "target element"
|
|
238
256
|
|
|
239
|
-
test_cases.append((image, target_desc,
|
|
257
|
+
test_cases.append((image, target_desc, bbox))
|
|
240
258
|
|
|
241
259
|
return evaluate_grounder(grounder, test_cases, k=k)
|
|
@@ -5,7 +5,7 @@ from dataclasses import dataclass, field
|
|
|
5
5
|
from typing import Any, Callable, Dict, List, Optional
|
|
6
6
|
|
|
7
7
|
from openadapt_ml.runtime.policy import AgentPolicy
|
|
8
|
-
from openadapt_ml.
|
|
8
|
+
from openadapt_ml.schema import Action, Episode, ActionType
|
|
9
9
|
|
|
10
10
|
|
|
11
11
|
@dataclass
|
|
@@ -92,22 +92,46 @@ class AggregateMetrics:
|
|
|
92
92
|
element_accuracy: Optional[float] = None # SoM element index accuracy
|
|
93
93
|
|
|
94
94
|
|
|
95
|
+
def _get_action_type_str(action: Action) -> str:
|
|
96
|
+
"""Get action type as string, handling both enum and string types."""
|
|
97
|
+
return action.type.value if isinstance(action.type, ActionType) else action.type
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
def _get_normalized_coords(action: Action) -> tuple[Optional[float], Optional[float]]:
|
|
101
|
+
"""Extract normalized coordinates from action."""
|
|
102
|
+
if action.normalized_coordinates:
|
|
103
|
+
return action.normalized_coordinates
|
|
104
|
+
return None, None
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
def _get_bbox(action: Action) -> Optional[tuple[float, float, float, float]]:
|
|
108
|
+
"""Extract bounding box from action, checking element.bounds or raw."""
|
|
109
|
+
if action.element and action.element.bounds:
|
|
110
|
+
b = action.element.bounds
|
|
111
|
+
return (b.x, b.y, b.x + b.width, b.y + b.height)
|
|
112
|
+
elif action.raw and "bbox" in action.raw:
|
|
113
|
+
return action.raw["bbox"]
|
|
114
|
+
return None
|
|
115
|
+
|
|
116
|
+
|
|
95
117
|
def compute_coordinate_error(pred_action: Action, gt_action: Action) -> Optional[float]:
|
|
96
118
|
"""Compute normalized L2 distance between predicted and ground-truth coords.
|
|
97
119
|
|
|
98
120
|
Returns None if either action is missing coordinates.
|
|
99
121
|
"""
|
|
122
|
+
pred_x, pred_y = _get_normalized_coords(pred_action)
|
|
123
|
+
gt_x, gt_y = _get_normalized_coords(gt_action)
|
|
100
124
|
|
|
101
125
|
if (
|
|
102
|
-
|
|
103
|
-
or
|
|
104
|
-
or
|
|
105
|
-
or
|
|
126
|
+
pred_x is None
|
|
127
|
+
or pred_y is None
|
|
128
|
+
or gt_x is None
|
|
129
|
+
or gt_y is None
|
|
106
130
|
):
|
|
107
131
|
return None
|
|
108
132
|
|
|
109
|
-
dx =
|
|
110
|
-
dy =
|
|
133
|
+
dx = pred_x - gt_x
|
|
134
|
+
dy = pred_y - gt_y
|
|
111
135
|
return math.sqrt(dx * dx + dy * dy)
|
|
112
136
|
|
|
113
137
|
|
|
@@ -119,14 +143,16 @@ def is_click_in_bbox(pred_action: Action, gt_action: Action) -> Optional[bool]:
|
|
|
119
143
|
- False if prediction is outside bbox
|
|
120
144
|
- None if no bbox is available (fall back to coord distance)
|
|
121
145
|
"""
|
|
122
|
-
|
|
146
|
+
gt_bbox = _get_bbox(gt_action)
|
|
147
|
+
if gt_bbox is None:
|
|
123
148
|
return None
|
|
124
149
|
|
|
125
|
-
|
|
150
|
+
pred_x, pred_y = _get_normalized_coords(pred_action)
|
|
151
|
+
if pred_x is None or pred_y is None:
|
|
126
152
|
return False
|
|
127
153
|
|
|
128
|
-
x_min, y_min, x_max, y_max =
|
|
129
|
-
return (x_min <=
|
|
154
|
+
x_min, y_min, x_max, y_max = gt_bbox
|
|
155
|
+
return (x_min <= pred_x <= x_max) and (y_min <= pred_y <= y_max)
|
|
130
156
|
|
|
131
157
|
|
|
132
158
|
def evaluate_episode(
|
|
@@ -177,7 +203,7 @@ def evaluate_episode(
|
|
|
177
203
|
|
|
178
204
|
for step_idx, step in enumerate(episode.steps):
|
|
179
205
|
# Skip steps without an image; the dataset builder does the same.
|
|
180
|
-
if not step.observation.
|
|
206
|
+
if not step.observation.screenshot_path:
|
|
181
207
|
continue
|
|
182
208
|
|
|
183
209
|
if sample_idx >= len(samples):
|
|
@@ -189,13 +215,17 @@ def evaluate_episode(
|
|
|
189
215
|
pred_action, _thought, pred_state, raw_text = policy.predict_action_from_sample(sample)
|
|
190
216
|
gt_action = step.action
|
|
191
217
|
|
|
218
|
+
# Get action types as strings for comparison
|
|
219
|
+
pred_type_str = _get_action_type_str(pred_action)
|
|
220
|
+
gt_type_str = _get_action_type_str(gt_action)
|
|
221
|
+
|
|
192
222
|
# Track state-based success from final step
|
|
193
223
|
if pred_state and isinstance(pred_state, dict):
|
|
194
224
|
success_val = pred_state.get("success")
|
|
195
225
|
if isinstance(success_val, bool):
|
|
196
226
|
last_state_success = success_val
|
|
197
227
|
|
|
198
|
-
type_match =
|
|
228
|
+
type_match = pred_type_str == gt_type_str
|
|
199
229
|
if type_match:
|
|
200
230
|
step_matches += 1
|
|
201
231
|
else:
|
|
@@ -206,14 +236,28 @@ def evaluate_episode(
|
|
|
206
236
|
bbox_hit = False
|
|
207
237
|
element_hit = False
|
|
208
238
|
|
|
239
|
+
# Helper to get element index - check element.element_id or raw field
|
|
240
|
+
def _get_element_index(action: Action) -> Optional[int]:
|
|
241
|
+
if action.element and action.element.element_id:
|
|
242
|
+
try:
|
|
243
|
+
return int(action.element.element_id)
|
|
244
|
+
except (ValueError, TypeError):
|
|
245
|
+
pass
|
|
246
|
+
if action.raw and "element_index" in action.raw:
|
|
247
|
+
return action.raw["element_index"]
|
|
248
|
+
return None
|
|
249
|
+
|
|
250
|
+
gt_element_index = _get_element_index(gt_action)
|
|
251
|
+
pred_element_index = _get_element_index(pred_action)
|
|
252
|
+
|
|
209
253
|
# SoM mode: evaluate by element index for click/drag/type actions
|
|
210
|
-
if use_som and
|
|
211
|
-
if
|
|
254
|
+
if use_som and gt_type_str in {"click", "drag", "type"}:
|
|
255
|
+
if gt_element_index is not None:
|
|
212
256
|
element_total += 1
|
|
213
|
-
if
|
|
257
|
+
if pred_element_index == gt_element_index:
|
|
214
258
|
element_hits += 1
|
|
215
259
|
element_hit = True
|
|
216
|
-
elif
|
|
260
|
+
elif gt_type_str in {"click", "drag"}:
|
|
217
261
|
# Coordinate mode: evaluate by coordinate distance
|
|
218
262
|
coord_error = compute_coordinate_error(pred_action, gt_action)
|
|
219
263
|
if coord_error is not None:
|
|
@@ -233,11 +277,11 @@ def evaluate_episode(
|
|
|
233
277
|
|
|
234
278
|
# Full step correctness: type matches AND element/coord match for relevant actions
|
|
235
279
|
if type_match:
|
|
236
|
-
if use_som and
|
|
280
|
+
if use_som and gt_type_str in {"click", "drag", "type"}:
|
|
237
281
|
# SoM mode: require element index match
|
|
238
282
|
if element_hit:
|
|
239
283
|
full_step_correct += 1
|
|
240
|
-
elif
|
|
284
|
+
elif gt_type_str in {"click", "drag"}:
|
|
241
285
|
# Coordinate mode: require click hit
|
|
242
286
|
if click_hit:
|
|
243
287
|
full_step_correct += 1
|
|
@@ -247,8 +291,8 @@ def evaluate_episode(
|
|
|
247
291
|
|
|
248
292
|
# Track semantic milestones using the milestone spec
|
|
249
293
|
for milestone in milestones:
|
|
250
|
-
if step_idx == milestone.step_index and
|
|
251
|
-
if
|
|
294
|
+
if step_idx == milestone.step_index and gt_type_str == milestone.expected_type:
|
|
295
|
+
if pred_type_str == milestone.expected_type:
|
|
252
296
|
# Check coord threshold if specified (for click actions)
|
|
253
297
|
if milestone.coord_threshold is not None:
|
|
254
298
|
if coord_error is not None and coord_error < milestone.coord_threshold:
|
|
@@ -258,9 +302,13 @@ def evaluate_episode(
|
|
|
258
302
|
milestones_achieved[milestone.name] = True
|
|
259
303
|
|
|
260
304
|
# Ensure DONE is correct at the DONE step.
|
|
261
|
-
if
|
|
305
|
+
if gt_type_str == "done" and pred_type_str != "done":
|
|
262
306
|
success_pred = False
|
|
263
307
|
|
|
308
|
+
# Get normalized coordinates for logging
|
|
309
|
+
pred_x, pred_y = _get_normalized_coords(pred_action)
|
|
310
|
+
gt_x, gt_y = _get_normalized_coords(gt_action)
|
|
311
|
+
|
|
264
312
|
# Optional logging of this step.
|
|
265
313
|
if log_fn is not None and (log_limit is None or logged_count < log_limit):
|
|
266
314
|
messages = sample.get("messages", [])
|
|
@@ -273,30 +321,30 @@ def evaluate_episode(
|
|
|
273
321
|
user_prompt = m.get("content")
|
|
274
322
|
|
|
275
323
|
record: Dict[str, Any] = {
|
|
276
|
-
"episode_id": episode.
|
|
324
|
+
"episode_id": episode.episode_id,
|
|
277
325
|
"step_index": step_idx,
|
|
278
|
-
"goal": episode.
|
|
326
|
+
"goal": episode.instruction,
|
|
279
327
|
"system_prompt": system_prompt,
|
|
280
328
|
"user_prompt": user_prompt,
|
|
281
329
|
"model_output_raw": raw_text,
|
|
282
330
|
"pred_action": {
|
|
283
|
-
"type":
|
|
284
|
-
"x":
|
|
285
|
-
"y":
|
|
331
|
+
"type": pred_type_str,
|
|
332
|
+
"x": pred_x,
|
|
333
|
+
"y": pred_y,
|
|
286
334
|
"text": pred_action.text,
|
|
287
|
-
"element_index":
|
|
335
|
+
"element_index": pred_element_index,
|
|
288
336
|
},
|
|
289
337
|
"ground_truth_action": {
|
|
290
|
-
"type":
|
|
291
|
-
"x":
|
|
292
|
-
"y":
|
|
338
|
+
"type": gt_type_str,
|
|
339
|
+
"x": gt_x,
|
|
340
|
+
"y": gt_y,
|
|
293
341
|
"text": gt_action.text,
|
|
294
|
-
"element_index":
|
|
342
|
+
"element_index": gt_element_index,
|
|
295
343
|
},
|
|
296
|
-
"correct_type":
|
|
344
|
+
"correct_type": pred_type_str == gt_type_str,
|
|
297
345
|
"coord_error_norm": coord_error,
|
|
298
|
-
"element_match":
|
|
299
|
-
if
|
|
346
|
+
"element_match": pred_element_index == gt_element_index
|
|
347
|
+
if gt_element_index is not None
|
|
300
348
|
else None,
|
|
301
349
|
}
|
|
302
350
|
|
|
@@ -306,7 +354,7 @@ def evaluate_episode(
|
|
|
306
354
|
step_total += 1
|
|
307
355
|
|
|
308
356
|
metrics = EpisodeMetrics(
|
|
309
|
-
episode_id=episode.
|
|
357
|
+
episode_id=episode.episode_id,
|
|
310
358
|
step_matches=step_matches,
|
|
311
359
|
step_total=step_total,
|
|
312
360
|
coord_errors=coord_errors,
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
"""Demo-conditioned prompt experiment.
|
|
2
|
+
|
|
3
|
+
Tests whether including a human demonstration in the prompt
|
|
4
|
+
improves VLM agent performance on similar tasks.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from openadapt_ml.experiments.demo_prompt.format_demo import (
|
|
8
|
+
format_episode_as_demo,
|
|
9
|
+
format_action,
|
|
10
|
+
)
|
|
11
|
+
from openadapt_ml.experiments.demo_prompt.run_experiment import (
|
|
12
|
+
DemoPromptExperiment,
|
|
13
|
+
)
|
|
14
|
+
|
|
15
|
+
__all__ = [
|
|
16
|
+
"format_episode_as_demo",
|
|
17
|
+
"format_action",
|
|
18
|
+
"DemoPromptExperiment",
|
|
19
|
+
]
|