openadapt-ml 0.1.0__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. openadapt_ml/benchmarks/__init__.py +8 -0
  2. openadapt_ml/benchmarks/agent.py +90 -11
  3. openadapt_ml/benchmarks/azure.py +35 -6
  4. openadapt_ml/benchmarks/cli.py +4449 -201
  5. openadapt_ml/benchmarks/live_tracker.py +180 -0
  6. openadapt_ml/benchmarks/runner.py +41 -4
  7. openadapt_ml/benchmarks/viewer.py +1219 -0
  8. openadapt_ml/benchmarks/vm_monitor.py +610 -0
  9. openadapt_ml/benchmarks/waa.py +61 -4
  10. openadapt_ml/benchmarks/waa_deploy/Dockerfile +222 -0
  11. openadapt_ml/benchmarks/waa_deploy/__init__.py +10 -0
  12. openadapt_ml/benchmarks/waa_deploy/api_agent.py +539 -0
  13. openadapt_ml/benchmarks/waa_deploy/start_waa_server.bat +53 -0
  14. openadapt_ml/benchmarks/waa_live.py +619 -0
  15. openadapt_ml/cloud/local.py +1555 -1
  16. openadapt_ml/cloud/ssh_tunnel.py +553 -0
  17. openadapt_ml/datasets/next_action.py +87 -68
  18. openadapt_ml/evals/grounding.py +26 -8
  19. openadapt_ml/evals/trajectory_matching.py +84 -36
  20. openadapt_ml/experiments/demo_prompt/__init__.py +19 -0
  21. openadapt_ml/experiments/demo_prompt/format_demo.py +226 -0
  22. openadapt_ml/experiments/demo_prompt/results/experiment_20251231_002125.json +83 -0
  23. openadapt_ml/experiments/demo_prompt/results/experiment_n30_20251231_165958.json +1100 -0
  24. openadapt_ml/experiments/demo_prompt/results/multistep_20251231_025051.json +182 -0
  25. openadapt_ml/experiments/demo_prompt/run_experiment.py +531 -0
  26. openadapt_ml/experiments/waa_demo/__init__.py +10 -0
  27. openadapt_ml/experiments/waa_demo/demos.py +357 -0
  28. openadapt_ml/experiments/waa_demo/runner.py +717 -0
  29. openadapt_ml/experiments/waa_demo/tasks.py +151 -0
  30. openadapt_ml/export/__init__.py +9 -0
  31. openadapt_ml/export/__main__.py +6 -0
  32. openadapt_ml/export/cli.py +89 -0
  33. openadapt_ml/export/parquet.py +265 -0
  34. openadapt_ml/ingest/__init__.py +3 -4
  35. openadapt_ml/ingest/capture.py +89 -81
  36. openadapt_ml/ingest/loader.py +116 -68
  37. openadapt_ml/ingest/synthetic.py +221 -159
  38. openadapt_ml/retrieval/README.md +226 -0
  39. openadapt_ml/retrieval/USAGE.md +391 -0
  40. openadapt_ml/retrieval/__init__.py +91 -0
  41. openadapt_ml/retrieval/demo_retriever.py +817 -0
  42. openadapt_ml/retrieval/embeddings.py +629 -0
  43. openadapt_ml/retrieval/index.py +194 -0
  44. openadapt_ml/retrieval/retriever.py +160 -0
  45. openadapt_ml/runtime/policy.py +10 -10
  46. openadapt_ml/schema/__init__.py +104 -0
  47. openadapt_ml/schema/converters.py +541 -0
  48. openadapt_ml/schema/episode.py +457 -0
  49. openadapt_ml/scripts/compare.py +26 -16
  50. openadapt_ml/scripts/eval_policy.py +4 -5
  51. openadapt_ml/scripts/prepare_synthetic.py +14 -17
  52. openadapt_ml/scripts/train.py +81 -70
  53. openadapt_ml/training/benchmark_viewer.py +3225 -0
  54. openadapt_ml/training/trainer.py +120 -363
  55. openadapt_ml/training/trl_trainer.py +354 -0
  56. {openadapt_ml-0.1.0.dist-info → openadapt_ml-0.2.0.dist-info}/METADATA +102 -60
  57. openadapt_ml-0.2.0.dist-info/RECORD +86 -0
  58. openadapt_ml/schemas/__init__.py +0 -53
  59. openadapt_ml/schemas/sessions.py +0 -122
  60. openadapt_ml/schemas/validation.py +0 -252
  61. openadapt_ml-0.1.0.dist-info/RECORD +0 -55
  62. {openadapt_ml-0.1.0.dist-info → openadapt_ml-0.2.0.dist-info}/WHEEL +0 -0
  63. {openadapt_ml-0.1.0.dist-info → openadapt_ml-0.2.0.dist-info}/licenses/LICENSE +0 -0
openadapt_ml/datasets/next_action.py
@@ -6,7 +6,7 @@ from typing import Any, Dict, List
 import torch
 from torch.utils.data import Dataset
 
-from openadapt_ml.schemas.sessions import Action, Episode, Step
+from openadapt_ml.schema import Action, ActionType, Episode, Step, UIElement
 
 
 # Coordinate-based DSL system prompt (original)
@@ -97,6 +97,13 @@ SYSTEM_PROMPT_SOM_REGISTRATION = (
 )
 
 
+def _get_element_id(action: Action) -> str | None:
+    """Extract element ID from action's element field."""
+    if action.element is not None and action.element.element_id is not None:
+        return action.element.element_id
+    return None
+
+
 def format_action(action: Action, use_som: bool = False) -> str:
     """Serialize an Action into a simple textual command.
 
@@ -110,53 +117,55 @@ def format_action(action: Action, use_som: bool = False) -> str:
     Args:
         action: The action to format.
         use_som: If True, use Set-of-Marks (SoM) index-based format instead of
-            coordinate-based format. Requires element_index to be set.
+            coordinate-based format. Requires element with element_id to be set.
     """
 
     t = action.type
+    element_id = _get_element_id(action)
     if use_som:
         # SoM mode: use element indices instead of coordinates
-        if t == "click" and action.element_index is not None:
-            return f"CLICK([{action.element_index}])"
-        if t == "type" and action.text is not None:
+        if t == ActionType.CLICK and element_id is not None:
+            return f"CLICK([{element_id}])"
+        if t == ActionType.TYPE and action.text is not None:
             escaped = action.text.replace("\\", "\\\\").replace("\"", "\\\"")
-            if action.element_index is not None:
-                return f"TYPE([{action.element_index}], \"{escaped}\")"
+            if element_id is not None:
+                return f"TYPE([{element_id}], \"{escaped}\")"
             else:
                 # Fallback: TYPE without element reference (for focused field)
                 return f"TYPE(\"{escaped}\")"
-        if t == "wait":
+        if t == ActionType.WAIT:
            return "WAIT()"
-        if t == "done":
+        if t == ActionType.DONE:
            return "DONE()"
        # Fallback
-        return f"ACTION(type={t})"
+        return f"ACTION(type={t.value if isinstance(t, ActionType) else t})"
    else:
        # Coordinate mode (original)
-        if t == "click" and action.x is not None and action.y is not None:
-            return f"CLICK(x={action.x:.2f}, y={action.y:.2f})"
-        if t == "type" and action.text is not None:
+        if t == ActionType.CLICK and action.normalized_coordinates is not None:
+            x, y = action.normalized_coordinates
+            return f"CLICK(x={x:.2f}, y={y:.2f})"
+        if t == ActionType.TYPE and action.text is not None:
            escaped = action.text.replace("\\", "\\\\").replace("\"", "\\\"")
            return f"TYPE(text=\"{escaped}\")"
-        if t == "wait":
+        if t == ActionType.WAIT:
            return "WAIT()"
-        if t == "done":
+        if t == ActionType.DONE:
            return "DONE()"
        # Fallback
-        return f"ACTION(type={t})"
+        return f"ACTION(type={t.value if isinstance(t, ActionType) else t})"
 
 
 def parse_action_som(text: str) -> Action:
     """Parse a SoM-style action string into an Action object.
 
     Supported formats:
-    - CLICK([N]) click element N
-    - TYPE([N], "text") type text into element N
-    - TYPE("text") type text into focused field
-    - WAIT() wait
-    - DONE() done
+    - CLICK([N]) -> click element N
+    - TYPE([N], "text") -> type text into element N
+    - TYPE("text") -> type text into focused field
+    - WAIT() -> wait
+    - DONE() -> done
 
-    Returns Action with element_index set for click/type actions.
+    Returns Action with element set for click/type actions.
     """
     import re
 
@@ -165,32 +174,32 @@ def parse_action_som(text: str) -> Action:
     # CLICK([N])
     match = re.match(r"CLICK\(\[(\d+)\]\)", text)
     if match:
-        idx = int(match.group(1))
-        return Action(type="click", element_index=idx)
+        idx = match.group(1)
+        return Action(type=ActionType.CLICK, element=UIElement(element_id=idx))
 
     # TYPE([N], "text") or TYPE([N], 'text')
     match = re.match(r'TYPE\(\[(\d+)\],\s*["\'](.*)["\']\)', text, re.DOTALL)
     if match:
-        idx = int(match.group(1))
+        idx = match.group(1)
         content = match.group(2).replace("\\\"", "\"").replace("\\\\", "\\")
-        return Action(type="type", text=content, element_index=idx)
+        return Action(type=ActionType.TYPE, text=content, element=UIElement(element_id=idx))
 
     # TYPE("text") - no element index
     match = re.match(r'TYPE\(["\'](.*)["\']\)', text, re.DOTALL)
     if match:
         content = match.group(1).replace("\\\"", "\"").replace("\\\\", "\\")
-        return Action(type="type", text=content)
+        return Action(type=ActionType.TYPE, text=content)
 
     # WAIT()
     if text.upper().startswith("WAIT"):
-        return Action(type="wait")
+        return Action(type=ActionType.WAIT)
 
     # DONE()
     if text.upper().startswith("DONE"):
-        return Action(type="done")
+        return Action(type=ActionType.DONE)
 
     # Failed to parse
-    return Action(type="failed", raw={"text": text})
+    return Action(type=ActionType.FAIL, raw={"text": text})
 
 
 def _generate_generic_thought(step_index: int, step: Step, goal: str, total_steps: int) -> str:
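Example — the SoM DSL now round-trips through parse_action_som and format_action with string element IDs carried on UIElement instead of the old integer element_index. A minimal sketch, assuming these functions live in openadapt_ml/datasets/next_action.py (as the change list suggests) and that Action and UIElement accept the keyword arguments shown in the hunks above:

    from openadapt_ml.datasets.next_action import format_action, parse_action_som
    from openadapt_ml.schema import ActionType

    # Parsing keeps the element ID as a string on a UIElement
    action = parse_action_som('TYPE([7], "hello")')
    assert action.type == ActionType.TYPE
    assert action.element.element_id == "7"
    assert action.text == "hello"

    # SoM-mode formatting reproduces the index-based command exactly
    assert format_action(action, use_som=True) == 'TYPE([7], "hello")'

    # Unparseable input degrades to a FAIL action with the raw text preserved
    failed = parse_action_som("open the settings menu")
    assert failed.type == ActionType.FAIL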
@@ -205,10 +214,10 @@ def _generate_generic_thought(step_index: int, step: Step, goal: str, total_step
     # Progress context
     progress = f"Step {step_index + 1} of {total_steps}."
 
-    if t == "click":
-        if action.x is not None and action.y is not None:
+    if t == ActionType.CLICK:
+        if action.normalized_coordinates is not None:
             # Describe the click location relative to screen regions
-            x, y = action.x, action.y
+            x, y = action.normalized_coordinates
             h_pos = "left" if x < 0.33 else ("center" if x < 0.66 else "right")
             v_pos = "top" if y < 0.33 else ("middle" if y < 0.66 else "bottom")
             return (
@@ -217,28 +226,28 @@ def _generate_generic_thought(step_index: int, step: Step, goal: str, total_step
         )
         return f"{progress} I need to click on the relevant UI element to continue toward '{goal}'."
 
-    if t == "double_click":
+    if t == ActionType.DOUBLE_CLICK:
         return f"{progress} I need to double-click to select or activate this element for '{goal}'."
 
-    if t == "type":
+    if t == ActionType.TYPE:
         if action.text:
             # Don't reveal the actual text, just indicate typing is needed
             return f"{progress} I need to type text into the focused input field to continue toward '{goal}'."
         return f"{progress} I need to enter text in the current field."
 
-    if t == "scroll":
+    if t == ActionType.SCROLL:
         return f"{progress} I need to scroll to reveal more content or reach the target element for '{goal}'."
 
-    if t == "drag":
+    if t == ActionType.DRAG:
         return f"{progress} I need to drag an element to complete this part of '{goal}'."
 
-    if t == "key_press":
+    if t == ActionType.KEY:
         return f"{progress} I need to press a key to continue the workflow."
 
-    if t == "wait":
+    if t == ActionType.WAIT:
         return f"{progress} I should wait for the UI to update before the next action."
 
-    if t == "done":
+    if t == ActionType.DONE:
         return f"The goal '{goal}' has been achieved. The workflow is complete."
 
     # Fallback
@@ -279,42 +288,42 @@ def _generate_login_thought(step_index: int, step: Step, goal: str, total_steps:
     t = action.type
 
     # Step 0: click username field
-    if step_index == 0 and t == "click":
+    if step_index == 0 and t == ActionType.CLICK:
         return (
             "I see a login screen with empty username and password fields and a Login button. "
             f"To start logging in, I need to click on the username field to focus it ({goal})."
         )
 
     # Step 1: type username
-    if step_index == 1 and t == "type":
+    if step_index == 1 and t == ActionType.TYPE:
         return (
             "The username field is focused. To move toward the login goal, I should type the "
             "username into this field."
         )
 
     # Step 2: click password field
-    if step_index == 2 and t == "click":
+    if step_index == 2 and t == ActionType.CLICK:
         return (
             "The username has been entered. Next, I need to focus the password field so that I can "
             "enter the password for this login. I will click on the password input box."
         )
 
     # Step 3: type password
-    if step_index == 3 and t == "type":
+    if step_index == 3 and t == ActionType.TYPE:
         return (
             "The password field is focused. To continue the login process, I should type the "
             "password (which will appear as masked characters on the screen)."
         )
 
     # Step 4: click Login button
-    if step_index == 4 and t == "click":
+    if step_index == 4 and t == ActionType.CLICK:
         return (
             "Both the username and password have been entered. To submit the form and attempt the "
             "login, I should click the Login button."
         )
 
     # Step 5: DONE on logged-in screen
-    if step_index == 5 and t == "done":
+    if step_index == 5 and t == ActionType.DONE:
         return (
             "I now see a logged-in confirmation screen indicating the goal has been satisfied. "
             "The task is complete, so I should emit DONE()."
@@ -334,41 +343,41 @@ def _generate_registration_thought(step_index: int, step: Step, goal: str, total
 
     # Registration step mapping (pairs of click + type for 5 fields, then submit + done)
     thoughts = {
-        (0, "click"): (
+        (0, ActionType.CLICK): (
             "I see a registration form with empty fields for name, email, and password. "
             f"To start registration, I need to click on the First Name field ({goal})."
         ),
-        (1, "type"): (
+        (1, ActionType.TYPE): (
             "The First Name field is focused. I should type the first name."
         ),
-        (2, "click"): (
+        (2, ActionType.CLICK): (
             "First name entered. Now I need to focus the Last Name field to enter it."
         ),
-        (3, "type"): (
+        (3, ActionType.TYPE): (
             "The Last Name field is focused. I should type the last name."
         ),
-        (4, "click"): (
+        (4, ActionType.CLICK): (
             "Last name entered. Now I need to focus the Email field to enter the email address."
         ),
-        (5, "type"): (
+        (5, ActionType.TYPE): (
             "The Email field is focused. I should type the email address."
         ),
-        (6, "click"): (
+        (6, ActionType.CLICK): (
             "Email entered. Now I need to focus the Password field to create a password."
         ),
-        (7, "type"): (
+        (7, ActionType.TYPE): (
             "The Password field is focused. I should type the password."
         ),
-        (8, "click"): (
+        (8, ActionType.CLICK): (
             "Password entered. Now I need to focus the Confirm Password field to verify the password."
         ),
-        (9, "type"): (
+        (9, ActionType.TYPE): (
             "The Confirm Password field is focused. I should type the same password again."
         ),
-        (10, "click"): (
+        (10, ActionType.CLICK): (
             "All form fields are filled. I should click the Register button to submit the form."
         ),
-        (11, "done"): (
+        (11, ActionType.DONE): (
             "Registration is complete - I see a success screen. The task is finished."
         ),
     }
@@ -385,10 +394,16 @@ def _generate_registration_thought(step_index: int, step: Step, goal: str, total
 
 
 def _detect_scenario(episode: Episode) -> str:
-    """Detect scenario from episode workflow_id."""
-    workflow_id = episode.workflow_id or ""
-    if "registration" in workflow_id.lower():
+    """Detect scenario from episode task_id or metadata."""
+    # Check task_id first
+    task_id = episode.task_id or ""
+    if "registration" in task_id.lower():
         return "registration"
+    # Check metadata for workflow_id (backward compatibility)
+    if episode.metadata and "workflow_id" in episode.metadata:
+        workflow_id = episode.metadata.get("workflow_id", "")
+        if "registration" in str(workflow_id).lower():
+            return "registration"
     return "login"
 
 
@@ -417,7 +432,8 @@ def build_next_action_sft_samples(
     samples: List[Dict[str, Any]] = []
 
     for episode in episodes:
-        goal = episode.goal
+        # Use instruction as the goal (new schema field name)
+        goal = episode.instruction
         total_steps = len(episode.steps)
         scenario = _detect_scenario(episode)
 
@@ -430,18 +446,21 @@ def build_next_action_sft_samples(
         else:
             system_prompt = SYSTEM_PROMPT
 
-        for step_index, step in enumerate(episode.steps):
-            image_path = step.observation.image_path
+        for step in episode.steps:
+            # Use step_index from the Step model
+            step_index = step.step_index
+            # Use screenshot_path instead of image_path
+            image_path = step.observation.screenshot_path
             if not image_path:
                 # Skip steps without an associated image
                 continue
 
             # Build action history from previous steps
             action_history = []
-            for prev_idx in range(step_index):
-                prev_step = episode.steps[prev_idx]
-                prev_action_text = format_action(prev_step.action, use_som=use_som)
-                action_history.append(prev_action_text)
+            for prev_step in episode.steps:
+                if prev_step.step_index < step_index:
+                    prev_action_text = format_action(prev_step.action, use_som=use_som)
+                    action_history.append(prev_action_text)
 
             # Build history section for both modes - use actual step count
             if action_history:
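Example — the loop above depends on several 0.1.0 → 0.2.0 schema renames. A hedged summary of the renames visible in this diff, plus a small helper mirroring the skip-steps-without-images logic (field names come from the hunks, not from package docs):

    # 0.1.0 (openadapt_ml.schemas.sessions)  ->  0.2.0 (openadapt_ml.schema)
    #   episode.goal                         ->  episode.instruction
    #   episode.id                           ->  episode.episode_id
    #   step.observation.image_path          ->  step.observation.screenshot_path
    #   step.thought                         ->  step.reasoning
    #   enumerate(episode.steps)             ->  step.step_index (stored on Step)
    #   action.x, action.y                   ->  action.normalized_coordinates
    #   action.element_index                 ->  action.element.element_id
    #   action.bbox                          ->  action.element.bounds or action.raw["bbox"]

    def iter_steps_with_images(episode):
        """Yield (step_index, screenshot_path, action) under the new schema."""
        for step in episode.steps:
            if step.observation.screenshot_path:
                yield step.step_index, step.observation.screenshot_path, step.action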
openadapt_ml/evals/grounding.py
@@ -212,30 +212,48 @@ def evaluate_grounder_on_episode(
     """
     from PIL import Image
 
-    from openadapt_ml.schemas.sessions import Episode
+    from openadapt_ml.schema import Episode, ActionType
 
     test_cases = []
 
     for step in episode.steps:
         action = step.action
 
+        # Get action type as string for comparison
+        action_type_str = action.type.value if isinstance(action.type, ActionType) else action.type
+
         # Only evaluate clicks with bboxes
-        if action.type not in ("click", "double_click"):
+        if action_type_str not in ("click", "double_click"):
             continue
-        if action.bbox is None:
+
+        # Check for bbox - in new schema, bbox is in element.bounds or raw
+        bbox = None
+        if action.element and action.element.bounds:
+            b = action.element.bounds
+            bbox = (b.x, b.y, b.x + b.width, b.y + b.height)
+        elif action.raw and "bbox" in action.raw:
+            bbox = action.raw["bbox"]
+
+        if bbox is None:
             continue
-        if step.observation.image_path is None:
+        if step.observation.screenshot_path is None:
             continue
 
         # Load image
         try:
-            image = Image.open(step.observation.image_path)
+            image = Image.open(step.observation.screenshot_path)
         except Exception:
             continue
 
-        # Create target description from thought or action
-        target_desc = step.thought or f"element at ({action.x:.2f}, {action.y:.2f})"
+        # Create target description from reasoning or action coordinates
+        coords_x, coords_y = None, None
+        if action.normalized_coordinates:
+            coords_x, coords_y = action.normalized_coordinates
+        if coords_x is not None and coords_y is not None:
+            target_desc = step.reasoning or f"element at ({coords_x:.2f}, {coords_y:.2f})"
+        else:
+            target_desc = step.reasoning or "target element"
 
-        test_cases.append((image, target_desc, action.bbox))
+        test_cases.append((image, target_desc, bbox))
 
     return evaluate_grounder(grounder, test_cases, k=k)
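Example — the new schema stores element bounds in x/y/width/height form, while the grounder evaluation wants corner-form (x_min, y_min, x_max, y_max) boxes; the hunk above converts between the two. A standalone sketch of that conversion, using a stand-in Bounds dataclass (the real type lives in openadapt_ml.schema):

    from dataclasses import dataclass

    @dataclass
    class Bounds:
        """Stand-in for the schema's element bounds (x/y/width/height form)."""
        x: float
        y: float
        width: float
        height: float

    def to_corner_bbox(b: Bounds) -> tuple[float, float, float, float]:
        # Same arithmetic as the hunk: (x, y, w, h) -> (x_min, y_min, x_max, y_max)
        return (b.x, b.y, b.x + b.width, b.y + b.height)

    # An 80x24 box anchored at (100, 40) becomes corner form (100, 40, 180, 64)
    assert to_corner_bbox(Bounds(100, 40, 80, 24)) == (100, 40, 180, 64)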
openadapt_ml/evals/trajectory_matching.py
@@ -5,7 +5,7 @@ from dataclasses import dataclass, field
 from typing import Any, Callable, Dict, List, Optional
 
 from openadapt_ml.runtime.policy import AgentPolicy
-from openadapt_ml.schemas.sessions import Action, Episode
+from openadapt_ml.schema import Action, Episode, ActionType
 
 
 @dataclass
@@ -92,22 +92,46 @@ class AggregateMetrics:
     element_accuracy: Optional[float] = None  # SoM element index accuracy
 
 
+def _get_action_type_str(action: Action) -> str:
+    """Get action type as string, handling both enum and string types."""
+    return action.type.value if isinstance(action.type, ActionType) else action.type
+
+
+def _get_normalized_coords(action: Action) -> tuple[Optional[float], Optional[float]]:
+    """Extract normalized coordinates from action."""
+    if action.normalized_coordinates:
+        return action.normalized_coordinates
+    return None, None
+
+
+def _get_bbox(action: Action) -> Optional[tuple[float, float, float, float]]:
+    """Extract bounding box from action, checking element.bounds or raw."""
+    if action.element and action.element.bounds:
+        b = action.element.bounds
+        return (b.x, b.y, b.x + b.width, b.y + b.height)
+    elif action.raw and "bbox" in action.raw:
+        return action.raw["bbox"]
+    return None
+
+
 def compute_coordinate_error(pred_action: Action, gt_action: Action) -> Optional[float]:
     """Compute normalized L2 distance between predicted and ground-truth coords.
 
     Returns None if either action is missing coordinates.
     """
+    pred_x, pred_y = _get_normalized_coords(pred_action)
+    gt_x, gt_y = _get_normalized_coords(gt_action)
 
     if (
-        pred_action.x is None
-        or pred_action.y is None
-        or gt_action.x is None
-        or gt_action.y is None
+        pred_x is None
+        or pred_y is None
+        or gt_x is None
+        or gt_y is None
     ):
         return None
 
-    dx = pred_action.x - gt_action.x
-    dy = pred_action.y - gt_action.y
+    dx = pred_x - gt_x
+    dy = pred_y - gt_y
     return math.sqrt(dx * dx + dy * dy)
 
 
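Example — a worked check of the coordinate-error path, assuming these helpers ship in openadapt_ml.evals.trajectory_matching (per the change list) and that Action accepts normalized_coordinates and defaults it to None:

    import math

    from openadapt_ml.evals.trajectory_matching import compute_coordinate_error
    from openadapt_ml.schema import Action, ActionType

    pred = Action(type=ActionType.CLICK, normalized_coordinates=(0.50, 0.50))
    gt = Action(type=ActionType.CLICK, normalized_coordinates=(0.53, 0.54))

    # L2 distance in normalized units: sqrt(0.03**2 + 0.04**2) = 0.05
    err = compute_coordinate_error(pred, gt)
    assert err is not None and math.isclose(err, 0.05)

    # Missing coordinates on either side yield None rather than a distance
    assert compute_coordinate_error(pred, Action(type=ActionType.WAIT)) is None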
@@ -119,14 +143,16 @@ def is_click_in_bbox(pred_action: Action, gt_action: Action) -> Optional[bool]:
     - False if prediction is outside bbox
     - None if no bbox is available (fall back to coord distance)
     """
-    if gt_action.bbox is None:
+    gt_bbox = _get_bbox(gt_action)
+    if gt_bbox is None:
         return None
 
-    if pred_action.x is None or pred_action.y is None:
+    pred_x, pred_y = _get_normalized_coords(pred_action)
+    if pred_x is None or pred_y is None:
         return False
 
-    x_min, y_min, x_max, y_max = gt_action.bbox
-    return (x_min <= pred_action.x <= x_max) and (y_min <= pred_action.y <= y_max)
+    x_min, y_min, x_max, y_max = gt_bbox
+    return (x_min <= pred_x <= x_max) and (y_min <= pred_y <= y_max)
 
 
 def evaluate_episode(
@@ -177,7 +203,7 @@ def evaluate_episode(
 
     for step_idx, step in enumerate(episode.steps):
         # Skip steps without an image; the dataset builder does the same.
-        if not step.observation.image_path:
+        if not step.observation.screenshot_path:
             continue
 
         if sample_idx >= len(samples):
@@ -189,13 +215,17 @@ def evaluate_episode(
         pred_action, _thought, pred_state, raw_text = policy.predict_action_from_sample(sample)
         gt_action = step.action
 
+        # Get action types as strings for comparison
+        pred_type_str = _get_action_type_str(pred_action)
+        gt_type_str = _get_action_type_str(gt_action)
+
         # Track state-based success from final step
         if pred_state and isinstance(pred_state, dict):
             success_val = pred_state.get("success")
             if isinstance(success_val, bool):
                 last_state_success = success_val
 
-        type_match = pred_action.type == gt_action.type
+        type_match = pred_type_str == gt_type_str
         if type_match:
             step_matches += 1
         else:
@@ -206,14 +236,28 @@ def evaluate_episode(
         bbox_hit = False
         element_hit = False
 
+        # Helper to get element index - check element.element_id or raw field
+        def _get_element_index(action: Action) -> Optional[int]:
+            if action.element and action.element.element_id:
+                try:
+                    return int(action.element.element_id)
+                except (ValueError, TypeError):
+                    pass
+            if action.raw and "element_index" in action.raw:
+                return action.raw["element_index"]
+            return None
+
+        gt_element_index = _get_element_index(gt_action)
+        pred_element_index = _get_element_index(pred_action)
+
         # SoM mode: evaluate by element index for click/drag/type actions
-        if use_som and gt_action.type in {"click", "drag", "type"}:
-            if gt_action.element_index is not None:
+        if use_som and gt_type_str in {"click", "drag", "type"}:
+            if gt_element_index is not None:
                 element_total += 1
-                if pred_action.element_index == gt_action.element_index:
+                if pred_element_index == gt_element_index:
                     element_hits += 1
                     element_hit = True
-        elif gt_action.type in {"click", "drag"}:
+        elif gt_type_str in {"click", "drag"}:
             # Coordinate mode: evaluate by coordinate distance
             coord_error = compute_coordinate_error(pred_action, gt_action)
             if coord_error is not None:
@@ -233,11 +277,11 @@ def evaluate_episode(
 
         # Full step correctness: type matches AND element/coord match for relevant actions
         if type_match:
-            if use_som and gt_action.type in {"click", "drag", "type"}:
+            if use_som and gt_type_str in {"click", "drag", "type"}:
                 # SoM mode: require element index match
                 if element_hit:
                     full_step_correct += 1
-            elif gt_action.type in {"click", "drag"}:
+            elif gt_type_str in {"click", "drag"}:
                 # Coordinate mode: require click hit
                 if click_hit:
                     full_step_correct += 1
@@ -247,8 +291,8 @@ def evaluate_episode(
 
         # Track semantic milestones using the milestone spec
         for milestone in milestones:
-            if step_idx == milestone.step_index and gt_action.type == milestone.expected_type:
-                if pred_action.type == milestone.expected_type:
+            if step_idx == milestone.step_index and gt_type_str == milestone.expected_type:
+                if pred_type_str == milestone.expected_type:
                     # Check coord threshold if specified (for click actions)
                     if milestone.coord_threshold is not None:
                         if coord_error is not None and coord_error < milestone.coord_threshold:
@@ -258,9 +302,13 @@ def evaluate_episode(
                         milestones_achieved[milestone.name] = True
 
         # Ensure DONE is correct at the DONE step.
-        if gt_action.type == "done" and pred_action.type != "done":
+        if gt_type_str == "done" and pred_type_str != "done":
            success_pred = False
 
+        # Get normalized coordinates for logging
+        pred_x, pred_y = _get_normalized_coords(pred_action)
+        gt_x, gt_y = _get_normalized_coords(gt_action)
+
         # Optional logging of this step.
         if log_fn is not None and (log_limit is None or logged_count < log_limit):
             messages = sample.get("messages", [])
@@ -273,30 +321,30 @@ def evaluate_episode(
                     user_prompt = m.get("content")
 
             record: Dict[str, Any] = {
-                "episode_id": episode.id,
+                "episode_id": episode.episode_id,
                 "step_index": step_idx,
-                "goal": episode.goal,
+                "goal": episode.instruction,
                 "system_prompt": system_prompt,
                 "user_prompt": user_prompt,
                 "model_output_raw": raw_text,
                 "pred_action": {
-                    "type": pred_action.type,
-                    "x": pred_action.x,
-                    "y": pred_action.y,
+                    "type": pred_type_str,
+                    "x": pred_x,
+                    "y": pred_y,
                     "text": pred_action.text,
-                    "element_index": pred_action.element_index,
+                    "element_index": pred_element_index,
                 },
                 "ground_truth_action": {
-                    "type": gt_action.type,
-                    "x": gt_action.x,
-                    "y": gt_action.y,
+                    "type": gt_type_str,
+                    "x": gt_x,
+                    "y": gt_y,
                     "text": gt_action.text,
-                    "element_index": gt_action.element_index,
+                    "element_index": gt_element_index,
                 },
-                "correct_type": pred_action.type == gt_action.type,
+                "correct_type": pred_type_str == gt_type_str,
                 "coord_error_norm": coord_error,
-                "element_match": pred_action.element_index == gt_action.element_index
-                if gt_action.element_index is not None
+                "element_match": pred_element_index == gt_element_index
+                if gt_element_index is not None
                 else None,
             }
 
@@ -306,7 +354,7 @@ def evaluate_episode(
             step_total += 1
 
     metrics = EpisodeMetrics(
-        episode_id=episode.id,
+        episode_id=episode.episode_id,
         step_matches=step_matches,
         step_total=step_total,
         coord_errors=coord_errors,
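Example — with the helpers above, is_click_in_bbox can fall back to a corner-form bbox stashed in Action.raw when no UIElement bounds exist. A hedged sketch (module path and default-None fields are assumptions based on the hunks):

    from openadapt_ml.evals.trajectory_matching import is_click_in_bbox
    from openadapt_ml.schema import Action, ActionType

    # Ground truth carries only a raw corner-form bbox (no element bounds)
    gt = Action(type=ActionType.CLICK, raw={"bbox": (0.40, 0.10, 0.60, 0.20)})

    hit = Action(type=ActionType.CLICK, normalized_coordinates=(0.45, 0.15))
    miss = Action(type=ActionType.CLICK, normalized_coordinates=(0.90, 0.90))

    assert is_click_in_bbox(hit, gt) is True     # inside the box
    assert is_click_in_bbox(miss, gt) is False   # outside the box
    assert is_click_in_bbox(hit, Action(type=ActionType.CLICK)) is None  # no bbox available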
openadapt_ml/experiments/demo_prompt/__init__.py (new file)
@@ -0,0 +1,19 @@
+"""Demo-conditioned prompt experiment.
+
+Tests whether including a human demonstration in the prompt
+improves VLM agent performance on similar tasks.
+"""
+
+from openadapt_ml.experiments.demo_prompt.format_demo import (
+    format_episode_as_demo,
+    format_action,
+)
+from openadapt_ml.experiments.demo_prompt.run_experiment import (
+    DemoPromptExperiment,
+)
+
+__all__ = [
+    "format_episode_as_demo",
+    "format_action",
+    "DemoPromptExperiment",
+]
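Example — downstream use of the new experiment package. The internals of format_demo.py and run_experiment.py are not shown in this diff, so the call shapes below are illustrative assumptions only (demo_episode stands for an openadapt_ml.schema.Episode you already have):

    from openadapt_ml.experiments.demo_prompt import (
        DemoPromptExperiment,
        format_episode_as_demo,
    )

    # Render a recorded episode as an in-context demonstration string
    demo_text = format_episode_as_demo(demo_episode)

    # DemoPromptExperiment encapsulates the with/without-demo comparison
    experiment = DemoPromptExperiment()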