openadapt-ml 0.1.0__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff shows the content changes between publicly available package versions as published to their respective public registries. It is provided for informational purposes only.
Files changed (112)
  1. openadapt_ml/baselines/__init__.py +121 -0
  2. openadapt_ml/baselines/adapter.py +185 -0
  3. openadapt_ml/baselines/cli.py +314 -0
  4. openadapt_ml/baselines/config.py +448 -0
  5. openadapt_ml/baselines/parser.py +922 -0
  6. openadapt_ml/baselines/prompts.py +787 -0
  7. openadapt_ml/benchmarks/__init__.py +13 -107
  8. openadapt_ml/benchmarks/agent.py +297 -374
  9. openadapt_ml/benchmarks/azure.py +62 -24
  10. openadapt_ml/benchmarks/azure_ops_tracker.py +521 -0
  11. openadapt_ml/benchmarks/cli.py +1874 -751
  12. openadapt_ml/benchmarks/trace_export.py +631 -0
  13. openadapt_ml/benchmarks/viewer.py +1236 -0
  14. openadapt_ml/benchmarks/vm_monitor.py +1111 -0
  15. openadapt_ml/benchmarks/waa_deploy/Dockerfile +216 -0
  16. openadapt_ml/benchmarks/waa_deploy/__init__.py +10 -0
  17. openadapt_ml/benchmarks/waa_deploy/api_agent.py +540 -0
  18. openadapt_ml/benchmarks/waa_deploy/start_waa_server.bat +53 -0
  19. openadapt_ml/cloud/azure_inference.py +3 -5
  20. openadapt_ml/cloud/lambda_labs.py +722 -307
  21. openadapt_ml/cloud/local.py +3194 -89
  22. openadapt_ml/cloud/ssh_tunnel.py +595 -0
  23. openadapt_ml/datasets/next_action.py +125 -96
  24. openadapt_ml/evals/grounding.py +32 -9
  25. openadapt_ml/evals/plot_eval_metrics.py +15 -13
  26. openadapt_ml/evals/trajectory_matching.py +120 -57
  27. openadapt_ml/experiments/demo_prompt/__init__.py +19 -0
  28. openadapt_ml/experiments/demo_prompt/format_demo.py +236 -0
  29. openadapt_ml/experiments/demo_prompt/results/experiment_20251231_002125.json +83 -0
  30. openadapt_ml/experiments/demo_prompt/results/experiment_n30_20251231_165958.json +1100 -0
  31. openadapt_ml/experiments/demo_prompt/results/multistep_20251231_025051.json +182 -0
  32. openadapt_ml/experiments/demo_prompt/run_experiment.py +541 -0
  33. openadapt_ml/experiments/representation_shootout/__init__.py +70 -0
  34. openadapt_ml/experiments/representation_shootout/conditions.py +708 -0
  35. openadapt_ml/experiments/representation_shootout/config.py +390 -0
  36. openadapt_ml/experiments/representation_shootout/evaluator.py +659 -0
  37. openadapt_ml/experiments/representation_shootout/runner.py +687 -0
  38. openadapt_ml/experiments/waa_demo/__init__.py +10 -0
  39. openadapt_ml/experiments/waa_demo/demos.py +357 -0
  40. openadapt_ml/experiments/waa_demo/runner.py +732 -0
  41. openadapt_ml/experiments/waa_demo/tasks.py +151 -0
  42. openadapt_ml/export/__init__.py +9 -0
  43. openadapt_ml/export/__main__.py +6 -0
  44. openadapt_ml/export/cli.py +89 -0
  45. openadapt_ml/export/parquet.py +277 -0
  46. openadapt_ml/grounding/detector.py +18 -14
  47. openadapt_ml/ingest/__init__.py +11 -10
  48. openadapt_ml/ingest/capture.py +97 -86
  49. openadapt_ml/ingest/loader.py +120 -69
  50. openadapt_ml/ingest/synthetic.py +344 -193
  51. openadapt_ml/models/api_adapter.py +14 -4
  52. openadapt_ml/models/base_adapter.py +10 -2
  53. openadapt_ml/models/providers/__init__.py +288 -0
  54. openadapt_ml/models/providers/anthropic.py +266 -0
  55. openadapt_ml/models/providers/base.py +299 -0
  56. openadapt_ml/models/providers/google.py +376 -0
  57. openadapt_ml/models/providers/openai.py +342 -0
  58. openadapt_ml/models/qwen_vl.py +46 -19
  59. openadapt_ml/perception/__init__.py +35 -0
  60. openadapt_ml/perception/integration.py +399 -0
  61. openadapt_ml/retrieval/README.md +226 -0
  62. openadapt_ml/retrieval/USAGE.md +391 -0
  63. openadapt_ml/retrieval/__init__.py +91 -0
  64. openadapt_ml/retrieval/demo_retriever.py +843 -0
  65. openadapt_ml/retrieval/embeddings.py +630 -0
  66. openadapt_ml/retrieval/index.py +194 -0
  67. openadapt_ml/retrieval/retriever.py +162 -0
  68. openadapt_ml/runtime/__init__.py +50 -0
  69. openadapt_ml/runtime/policy.py +27 -14
  70. openadapt_ml/runtime/safety_gate.py +471 -0
  71. openadapt_ml/schema/__init__.py +113 -0
  72. openadapt_ml/schema/converters.py +588 -0
  73. openadapt_ml/schema/episode.py +470 -0
  74. openadapt_ml/scripts/capture_screenshots.py +530 -0
  75. openadapt_ml/scripts/compare.py +102 -61
  76. openadapt_ml/scripts/demo_policy.py +4 -1
  77. openadapt_ml/scripts/eval_policy.py +19 -14
  78. openadapt_ml/scripts/make_gif.py +1 -1
  79. openadapt_ml/scripts/prepare_synthetic.py +16 -17
  80. openadapt_ml/scripts/train.py +98 -75
  81. openadapt_ml/segmentation/README.md +920 -0
  82. openadapt_ml/segmentation/__init__.py +97 -0
  83. openadapt_ml/segmentation/adapters/__init__.py +5 -0
  84. openadapt_ml/segmentation/adapters/capture_adapter.py +420 -0
  85. openadapt_ml/segmentation/annotator.py +610 -0
  86. openadapt_ml/segmentation/cache.py +290 -0
  87. openadapt_ml/segmentation/cli.py +674 -0
  88. openadapt_ml/segmentation/deduplicator.py +656 -0
  89. openadapt_ml/segmentation/frame_describer.py +788 -0
  90. openadapt_ml/segmentation/pipeline.py +340 -0
  91. openadapt_ml/segmentation/schemas.py +622 -0
  92. openadapt_ml/segmentation/segment_extractor.py +634 -0
  93. openadapt_ml/training/azure_ops_viewer.py +1097 -0
  94. openadapt_ml/training/benchmark_viewer.py +3255 -19
  95. openadapt_ml/training/shared_ui.py +7 -7
  96. openadapt_ml/training/stub_provider.py +57 -35
  97. openadapt_ml/training/trainer.py +255 -441
  98. openadapt_ml/training/trl_trainer.py +403 -0
  99. openadapt_ml/training/viewer.py +323 -108
  100. openadapt_ml/training/viewer_components.py +180 -0
  101. {openadapt_ml-0.1.0.dist-info → openadapt_ml-0.2.1.dist-info}/METADATA +312 -69
  102. openadapt_ml-0.2.1.dist-info/RECORD +116 -0
  103. openadapt_ml/benchmarks/base.py +0 -366
  104. openadapt_ml/benchmarks/data_collection.py +0 -432
  105. openadapt_ml/benchmarks/runner.py +0 -381
  106. openadapt_ml/benchmarks/waa.py +0 -704
  107. openadapt_ml/schemas/__init__.py +0 -53
  108. openadapt_ml/schemas/sessions.py +0 -122
  109. openadapt_ml/schemas/validation.py +0 -252
  110. openadapt_ml-0.1.0.dist-info/RECORD +0 -55
  111. {openadapt_ml-0.1.0.dist-info → openadapt_ml-0.2.1.dist-info}/WHEEL +0 -0
  112. {openadapt_ml-0.1.0.dist-info → openadapt_ml-0.2.1.dist-info}/licenses/LICENSE +0 -0
openadapt_ml/datasets/next_action.py +125 -96

@@ -3,10 +3,9 @@ from __future__ import annotations
 from dataclasses import dataclass
 from typing import Any, Dict, List
 
-import torch
 from torch.utils.data import Dataset
 
-from openadapt_ml.schemas.sessions import Action, Episode, Step
+from openadapt_ml.schema import Action, ActionType, Episode, Step, UIElement
 
 
 # Coordinate-based DSL system prompt (original)
@@ -20,7 +19,7 @@ SYSTEM_PROMPT = (
     "- Example: An element in the middle of the screen would be approximately x=0.5, y=0.5\n\n"
     "ALLOWED ACTIONS (use exactly this format):\n"
     "- CLICK(x=0.XX, y=0.XX) → click at normalized coordinates\n"
-    "- TYPE(text=\"...\") → type text into the currently focused field\n"
+    '- TYPE(text="...") → type text into the currently focused field\n'
     "- WAIT() → wait for UI to update\n"
     "- DONE() → task is complete\n\n"
     "RESPONSE FORMAT (required):\n"
@@ -42,14 +41,14 @@ SYSTEM_PROMPT_SOM = (
     "[3] = Login button\n\n"
     "ALLOWED ACTIONS (use exactly this format):\n"
     "- CLICK([N]) → click element with number N to focus/activate it\n"
-    "- TYPE([N], \"text\") → type text into element N (e.g., TYPE([2], \"hello\"))\n"
+    '- TYPE([N], "text") → type text into element N (e.g., TYPE([2], "hello"))\n'
     "- WAIT() → wait for UI to update\n"
     "- DONE() → task is complete\n\n"
     "ACTION SEQUENCE FOR LOGIN:\n"
     "1. CLICK([1]) to focus username field\n"
-    "2. TYPE([1], \"username\") to enter username\n"
+    '2. TYPE([1], "username") to enter username\n'
     "3. CLICK([2]) to focus password field\n"
-    "4. TYPE([2], \"password\") to enter password\n"
+    '4. TYPE([2], "password") to enter password\n'
     "5. CLICK([3]) to submit login\n"
     "6. DONE() when login is complete\n\n"
     "RESPONSE FORMAT (required):\n"
@@ -74,20 +73,20 @@ SYSTEM_PROMPT_SOM_REGISTRATION = (
     "[6] = Register button\n\n"
     "ALLOWED ACTIONS (use exactly this format):\n"
     "- CLICK([N]) → click element with number N to focus/activate it\n"
-    "- TYPE([N], \"text\") → type text into element N (e.g., TYPE([2], \"hello\"))\n"
+    '- TYPE([N], "text") → type text into element N (e.g., TYPE([2], "hello"))\n'
     "- WAIT() → wait for UI to update\n"
     "- DONE() → task is complete\n\n"
     "ACTION SEQUENCE FOR REGISTRATION:\n"
     "1. CLICK([1]) to focus first name field\n"
-    "2. TYPE([1], \"name\") to enter first name\n"
+    '2. TYPE([1], "name") to enter first name\n'
     "3. CLICK([2]) to focus last name field\n"
-    "4. TYPE([2], \"name\") to enter last name\n"
+    '4. TYPE([2], "name") to enter last name\n'
     "5. CLICK([3]) to focus email field\n"
-    "6. TYPE([3], \"email\") to enter email\n"
+    '6. TYPE([3], "email") to enter email\n'
     "7. CLICK([4]) to focus password field\n"
-    "8. TYPE([4], \"pass\") to enter password\n"
+    '8. TYPE([4], "pass") to enter password\n'
     "9. CLICK([5]) to focus confirm password field\n"
-    "10. TYPE([5], \"pass\") to enter confirmation\n"
+    '10. TYPE([5], "pass") to enter confirmation\n'
     "11. CLICK([6]) to submit registration\n"
     "12. DONE() when registration is complete\n\n"
     "RESPONSE FORMAT (required):\n"
@@ -97,6 +96,13 @@ SYSTEM_PROMPT_SOM_REGISTRATION = (
 )
 
 
+def _get_element_id(action: Action) -> str | None:
+    """Extract element ID from action's element field."""
+    if action.element is not None and action.element.element_id is not None:
+        return action.element.element_id
+    return None
+
+
 def format_action(action: Action, use_som: bool = False) -> str:
     """Serialize an Action into a simple textual command.
 
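Note: these hunks replace 0.1.0's bare element_index integer with a structured UIElement reference. A minimal sketch of the new construction style, using only names visible in this diff (the exact constructor signatures are assumptions, not the package's documented API):

    # Sketch: Action, ActionType, and UIElement come from the new import line above.
    from openadapt_ml.schema import Action, ActionType, UIElement

    # 0.1.0 style (removed): Action(type="click", element_index=3)
    # 0.2.1 style: the element reference is a UIElement and the id is a string.
    click = Action(type=ActionType.CLICK, element=UIElement(element_id="3"))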
@@ -110,53 +116,55 @@ def format_action(action: Action, use_som: bool = False) -> str:
     Args:
         action: The action to format.
         use_som: If True, use Set-of-Marks (SoM) index-based format instead of
-            coordinate-based format. Requires element_index to be set.
+            coordinate-based format. Requires element with element_id to be set.
     """
 
     t = action.type
+    element_id = _get_element_id(action)
     if use_som:
         # SoM mode: use element indices instead of coordinates
-        if t == "click" and action.element_index is not None:
-            return f"CLICK([{action.element_index}])"
-        if t == "type" and action.text is not None:
-            escaped = action.text.replace("\\", "\\\\").replace("\"", "\\\"")
-            if action.element_index is not None:
-                return f"TYPE([{action.element_index}], \"{escaped}\")"
+        if t == ActionType.CLICK and element_id is not None:
+            return f"CLICK([{element_id}])"
+        if t == ActionType.TYPE and action.text is not None:
+            escaped = action.text.replace("\\", "\\\\").replace('"', '\\"')
+            if element_id is not None:
+                return f'TYPE([{element_id}], "{escaped}")'
             else:
                 # Fallback: TYPE without element reference (for focused field)
-                return f"TYPE(\"{escaped}\")"
-        if t == "wait":
+                return f'TYPE("{escaped}")'
+        if t == ActionType.WAIT:
             return "WAIT()"
-        if t == "done":
+        if t == ActionType.DONE:
             return "DONE()"
         # Fallback
-        return f"ACTION(type={t})"
+        return f"ACTION(type={t.value if isinstance(t, ActionType) else t})"
     else:
         # Coordinate mode (original)
-        if t == "click" and action.x is not None and action.y is not None:
-            return f"CLICK(x={action.x:.2f}, y={action.y:.2f})"
-        if t == "type" and action.text is not None:
-            escaped = action.text.replace("\\", "\\\\").replace("\"", "\\\"")
-            return f"TYPE(text=\"{escaped}\")"
-        if t == "wait":
+        if t == ActionType.CLICK and action.normalized_coordinates is not None:
+            x, y = action.normalized_coordinates
+            return f"CLICK(x={x:.2f}, y={y:.2f})"
+        if t == ActionType.TYPE and action.text is not None:
+            escaped = action.text.replace("\\", "\\\\").replace('"', '\\"')
+            return f'TYPE(text="{escaped}")'
+        if t == ActionType.WAIT:
             return "WAIT()"
-        if t == "done":
+        if t == ActionType.DONE:
             return "DONE()"
         # Fallback
-        return f"ACTION(type={t})"
+        return f"ACTION(type={t.value if isinstance(t, ActionType) else t})"
 
 
 def parse_action_som(text: str) -> Action:
     """Parse a SoM-style action string into an Action object.
 
     Supported formats:
-    - CLICK([N]) click element N
-    - TYPE([N], "text") type text into element N
-    - TYPE("text") type text into focused field
-    - WAIT() wait
-    - DONE() done
+    - CLICK([N]) -> click element N
+    - TYPE([N], "text") -> type text into element N
+    - TYPE("text") -> type text into focused field
+    - WAIT() -> wait
+    - DONE() -> done
 
-    Returns Action with element_index set for click/type actions.
+    Returns Action with element set for click/type actions.
     """
     import re
 
@@ -165,35 +173,39 @@ def parse_action_som(text: str) -> Action:
     # CLICK([N])
     match = re.match(r"CLICK\(\[(\d+)\]\)", text)
     if match:
-        idx = int(match.group(1))
-        return Action(type="click", element_index=idx)
+        idx = match.group(1)
+        return Action(type=ActionType.CLICK, element=UIElement(element_id=idx))
 
     # TYPE([N], "text") or TYPE([N], 'text')
     match = re.match(r'TYPE\(\[(\d+)\],\s*["\'](.*)["\']\)', text, re.DOTALL)
     if match:
-        idx = int(match.group(1))
-        content = match.group(2).replace("\\\"", "\"").replace("\\\\", "\\")
-        return Action(type="type", text=content, element_index=idx)
+        idx = match.group(1)
+        content = match.group(2).replace('\\"', '"').replace("\\\\", "\\")
+        return Action(
+            type=ActionType.TYPE, text=content, element=UIElement(element_id=idx)
+        )
 
     # TYPE("text") - no element index
     match = re.match(r'TYPE\(["\'](.*)["\']\)', text, re.DOTALL)
     if match:
-        content = match.group(1).replace("\\\"", "\"").replace("\\\\", "\\")
-        return Action(type="type", text=content)
+        content = match.group(1).replace('\\"', '"').replace("\\\\", "\\")
+        return Action(type=ActionType.TYPE, text=content)
 
     # WAIT()
     if text.upper().startswith("WAIT"):
-        return Action(type="wait")
+        return Action(type=ActionType.WAIT)
 
     # DONE()
     if text.upper().startswith("DONE"):
-        return Action(type="done")
+        return Action(type=ActionType.DONE)
 
     # Failed to parse
-    return Action(type="failed", raw={"text": text})
+    return Action(type=ActionType.FAIL, raw={"text": text})
 
 
-def _generate_generic_thought(step_index: int, step: Step, goal: str, total_steps: int) -> str:
+def _generate_generic_thought(
+    step_index: int, step: Step, goal: str, total_steps: int
+) -> str:
     """Generate a thought for real captures (non-synthetic scenarios).
 
     This creates action-appropriate thoughts that teach the model to output
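Note: the format_action and parse_action_som hunks change together, so the DSL round-trip is preserved under the new schema (and element IDs now stay strings; the int() casts were dropped). A hedged round-trip sketch, with the module path taken from the file list above:

    from openadapt_ml.datasets.next_action import format_action, parse_action_som
    from openadapt_ml.schema import Action, ActionType, UIElement

    action = Action(
        type=ActionType.TYPE, text='say "hi"', element=UIElement(element_id="2")
    )
    text = format_action(action, use_som=True)  # 'TYPE([2], "say \"hi\"")'
    parsed = parse_action_som(text)
    assert parsed.element.element_id == "2"  # string, not int, in 0.2.1
    assert parsed.text == 'say "hi"'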
@@ -205,10 +217,10 @@ def _generate_generic_thought(step_index: int, step: Step, goal: str, total_step
     # Progress context
     progress = f"Step {step_index + 1} of {total_steps}."
 
-    if t == "click":
-        if action.x is not None and action.y is not None:
+    if t == ActionType.CLICK:
+        if action.normalized_coordinates is not None:
             # Describe the click location relative to screen regions
-            x, y = action.x, action.y
+            x, y = action.normalized_coordinates
             h_pos = "left" if x < 0.33 else ("center" if x < 0.66 else "right")
             v_pos = "top" if y < 0.33 else ("middle" if y < 0.66 else "bottom")
             return (
@@ -217,28 +229,30 @@ def _generate_generic_thought(step_index: int, step: Step, goal: str, total_step
             )
         return f"{progress} I need to click on the relevant UI element to continue toward '{goal}'."
 
-    if t == "double_click":
+    if t == ActionType.DOUBLE_CLICK:
         return f"{progress} I need to double-click to select or activate this element for '{goal}'."
 
-    if t == "type":
+    if t == ActionType.TYPE:
         if action.text:
             # Don't reveal the actual text, just indicate typing is needed
             return f"{progress} I need to type text into the focused input field to continue toward '{goal}'."
         return f"{progress} I need to enter text in the current field."
 
-    if t == "scroll":
+    if t == ActionType.SCROLL:
         return f"{progress} I need to scroll to reveal more content or reach the target element for '{goal}'."
 
-    if t == "drag":
-        return f"{progress} I need to drag an element to complete this part of '{goal}'."
+    if t == ActionType.DRAG:
+        return (
+            f"{progress} I need to drag an element to complete this part of '{goal}'."
+        )
 
-    if t == "key_press":
+    if t == ActionType.KEY:
         return f"{progress} I need to press a key to continue the workflow."
 
-    if t == "wait":
+    if t == ActionType.WAIT:
         return f"{progress} I should wait for the UI to update before the next action."
 
-    if t == "done":
+    if t == ActionType.DONE:
         return f"The goal '{goal}' has been achieved. The workflow is complete."
 
     # Fallback
@@ -260,9 +274,6 @@ def _generate_thought_for_step(
     actions back to the stated objective.
     """
 
-    action = step.action
-    t = action.type
-
     if scenario == "registration":
         return _generate_registration_thought(step_index, step, goal, total_steps)
     elif scenario == "login" and total_steps <= 7:
@@ -273,48 +284,50 @@ def _generate_thought_for_step(
         return _generate_generic_thought(step_index, step, goal, total_steps)
 
 
-def _generate_login_thought(step_index: int, step: Step, goal: str, total_steps: int) -> str:
+def _generate_login_thought(
+    step_index: int, step: Step, goal: str, total_steps: int
+) -> str:
     """Generate thought for login scenario (6 steps)."""
     action = step.action
     t = action.type
 
     # Step 0: click username field
-    if step_index == 0 and t == "click":
+    if step_index == 0 and t == ActionType.CLICK:
         return (
             "I see a login screen with empty username and password fields and a Login button. "
             f"To start logging in, I need to click on the username field to focus it ({goal})."
         )
 
     # Step 1: type username
-    if step_index == 1 and t == "type":
+    if step_index == 1 and t == ActionType.TYPE:
         return (
             "The username field is focused. To move toward the login goal, I should type the "
             "username into this field."
         )
 
     # Step 2: click password field
-    if step_index == 2 and t == "click":
+    if step_index == 2 and t == ActionType.CLICK:
         return (
             "The username has been entered. Next, I need to focus the password field so that I can "
             "enter the password for this login. I will click on the password input box."
         )
 
     # Step 3: type password
-    if step_index == 3 and t == "type":
+    if step_index == 3 and t == ActionType.TYPE:
         return (
             "The password field is focused. To continue the login process, I should type the "
             "password (which will appear as masked characters on the screen)."
         )
 
     # Step 4: click Login button
-    if step_index == 4 and t == "click":
+    if step_index == 4 and t == ActionType.CLICK:
         return (
             "Both the username and password have been entered. To submit the form and attempt the "
             "login, I should click the Login button."
         )
 
     # Step 5: DONE on logged-in screen
-    if step_index == 5 and t == "done":
+    if step_index == 5 and t == ActionType.DONE:
         return (
             "I now see a logged-in confirmation screen indicating the goal has been satisfied. "
             "The task is complete, so I should emit DONE()."
@@ -327,48 +340,50 @@ def _generate_login_thought(step_index: int, step: Step, goal: str, total_steps:
         )
 
 
-def _generate_registration_thought(step_index: int, step: Step, goal: str, total_steps: int) -> str:
+def _generate_registration_thought(
+    step_index: int, step: Step, goal: str, total_steps: int
+) -> str:
     """Generate thought for registration scenario (12 steps)."""
     action = step.action
     t = action.type
 
     # Registration step mapping (pairs of click + type for 5 fields, then submit + done)
     thoughts = {
-        (0, "click"): (
+        (0, ActionType.CLICK): (
             "I see a registration form with empty fields for name, email, and password. "
             f"To start registration, I need to click on the First Name field ({goal})."
         ),
-        (1, "type"): (
+        (1, ActionType.TYPE): (
             "The First Name field is focused. I should type the first name."
         ),
-        (2, "click"): (
+        (2, ActionType.CLICK): (
             "First name entered. Now I need to focus the Last Name field to enter it."
         ),
-        (3, "type"): (
+        (3, ActionType.TYPE): (
            "The Last Name field is focused. I should type the last name."
         ),
-        (4, "click"): (
+        (4, ActionType.CLICK): (
             "Last name entered. Now I need to focus the Email field to enter the email address."
         ),
-        (5, "type"): (
+        (5, ActionType.TYPE): (
             "The Email field is focused. I should type the email address."
         ),
-        (6, "click"): (
+        (6, ActionType.CLICK): (
             "Email entered. Now I need to focus the Password field to create a password."
         ),
-        (7, "type"): (
+        (7, ActionType.TYPE): (
             "The Password field is focused. I should type the password."
         ),
-        (8, "click"): (
+        (8, ActionType.CLICK): (
             "Password entered. Now I need to focus the Confirm Password field to verify the password."
         ),
-        (9, "type"): (
+        (9, ActionType.TYPE): (
             "The Confirm Password field is focused. I should type the same password again."
        ),
-        (10, "click"): (
+        (10, ActionType.CLICK): (
             "All form fields are filled. I should click the Register button to submit the form."
         ),
-        (11, "done"): (
+        (11, ActionType.DONE): (
             "Registration is complete - I see a success screen. The task is finished."
         ),
     }
@@ -385,10 +400,16 @@ def _generate_registration_thought(step_index: int, step: Step, goal: str, total
 
 
 def _detect_scenario(episode: Episode) -> str:
-    """Detect scenario from episode workflow_id."""
-    workflow_id = episode.workflow_id or ""
-    if "registration" in workflow_id.lower():
+    """Detect scenario from episode task_id or metadata."""
+    # Check task_id first
+    task_id = episode.task_id or ""
+    if "registration" in task_id.lower():
         return "registration"
+    # Check metadata for workflow_id (backward compatibility)
+    if episode.metadata and "workflow_id" in episode.metadata:
+        workflow_id = episode.metadata.get("workflow_id", "")
+        if "registration" in str(workflow_id).lower():
+            return "registration"
     return "login"
 
 
@@ -417,7 +438,8 @@ def build_next_action_sft_samples(
     samples: List[Dict[str, Any]] = []
 
     for episode in episodes:
-        goal = episode.goal
+        # Use instruction as the goal (new schema field name)
+        goal = episode.instruction
         total_steps = len(episode.steps)
         scenario = _detect_scenario(episode)
 
@@ -430,18 +452,21 @@ def build_next_action_sft_samples(
         else:
             system_prompt = SYSTEM_PROMPT
 
-        for step_index, step in enumerate(episode.steps):
-            image_path = step.observation.image_path
+        for step in episode.steps:
+            # Use step_index from the Step model
+            step_index = step.step_index
+            # Use screenshot_path instead of image_path
+            image_path = step.observation.screenshot_path
             if not image_path:
                 # Skip steps without an associated image
                 continue
 
             # Build action history from previous steps
             action_history = []
-            for prev_idx in range(step_index):
-                prev_step = episode.steps[prev_idx]
-                prev_action_text = format_action(prev_step.action, use_som=use_som)
-                action_history.append(prev_action_text)
+            for prev_step in episode.steps:
+                if prev_step.step_index < step_index:
+                    prev_action_text = format_action(prev_step.action, use_som=use_som)
+                    action_history.append(prev_action_text)
 
             # Build history section for both modes - use actual step count
             if action_history:
@@ -450,7 +475,9 @@ def build_next_action_sft_samples(
                     history_text += f" {i}. {action_text}\n"
                 history_text += f"\nThis is step {step_index + 1} of {total_steps}. "
             else:
-                history_text = f"This is step 1 of {total_steps} (no actions completed yet). "
+                history_text = (
+                    f"This is step 1 of {total_steps} (no actions completed yet). "
+                )
 
             if use_som:
                 user_content = (
@@ -458,7 +485,7 @@ def build_next_action_sft_samples(
                     f"{history_text}"
                     "Look at the screenshot and determine the NEXT action.\n\n"
                     "Thought: [which numbered element to interact with and why]\n"
-                    "Action: [CLICK([N]) or TYPE([N], \"text\") or WAIT() or DONE()]"
+                    'Action: [CLICK([N]) or TYPE([N], "text") or WAIT() or DONE()]'
                 )
             else:
                 user_content = (
@@ -466,13 +493,15 @@ def build_next_action_sft_samples(
                     f"{history_text}"
                     "Look at the screenshot and determine the NEXT action.\n\n"
                     "Thought: [what element to interact with and why]\n"
-                    "Action: [CLICK(x=..., y=...) or TYPE(text=\"...\") or WAIT() or DONE()]"
+                    'Action: [CLICK(x=..., y=...) or TYPE(text="...") or WAIT() or DONE()]'
                 )
 
             # Provide a deterministic, semantically meaningful Thought while supervising
             # the exact DSL Action.
             action_text = format_action(step.action, use_som=use_som)
-            thought_text = _generate_thought_for_step(step_index, step, goal, scenario, total_steps)
+            thought_text = _generate_thought_for_step(
+                step_index, step, goal, scenario, total_steps
+            )
             assistant_content = f"Thought: {thought_text}\nAction: {action_text}"
 
             sample = {
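Note: build_next_action_sft_samples now reads the goal from episode.instruction and the per-step index from step.step_index, but the supervised target is unchanged: a generated Thought plus the exact DSL action. A minimal sketch of the assistant turn (passing normalized_coordinates to the constructor is an assumption; the hunks only show it being read):

    from openadapt_ml.datasets.next_action import format_action
    from openadapt_ml.schema import Action, ActionType

    action = Action(type=ActionType.CLICK, normalized_coordinates=(0.5, 0.5))
    action_text = format_action(action, use_som=False)  # "CLICK(x=0.50, y=0.50)"
    thought_text = "Step 1 of 6. I need to click on the relevant UI element to continue toward 'log in'."
    assistant_content = f"Thought: {thought_text}\nAction: {action_text}"

Also note that the history loop now filters on prev_step.step_index < step_index instead of slicing episode.steps by position, so history is keyed to the recorded step indices rather than storage order.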
openadapt_ml/evals/grounding.py +32 -9

@@ -19,6 +19,7 @@ from typing import TYPE_CHECKING
 if TYPE_CHECKING:
     from PIL import Image
 
+    from openadapt_ml.data.types import Episode
     from openadapt_ml.grounding.base import GroundingModule, RegionCandidate
 
 
@@ -212,30 +213,52 @@ def evaluate_grounder_on_episode(
     """
     from PIL import Image
 
-    from openadapt_ml.schemas.sessions import Episode
+    from openadapt_ml.schema import ActionType
 
     test_cases = []
 
     for step in episode.steps:
         action = step.action
 
+        # Get action type as string for comparison
+        action_type_str = (
+            action.type.value if isinstance(action.type, ActionType) else action.type
+        )
+
         # Only evaluate clicks with bboxes
-        if action.type not in ("click", "double_click"):
+        if action_type_str not in ("click", "double_click"):
             continue
-        if action.bbox is None:
+
+        # Check for bbox - in new schema, bbox is in element.bounds or raw
+        bbox = None
+        if action.element and action.element.bounds:
+            b = action.element.bounds
+            bbox = (b.x, b.y, b.x + b.width, b.y + b.height)
+        elif action.raw and "bbox" in action.raw:
+            bbox = action.raw["bbox"]
+
+        if bbox is None:
             continue
-        if step.observation.image_path is None:
+        if step.observation.screenshot_path is None:
             continue
 
         # Load image
         try:
-            image = Image.open(step.observation.image_path)
+            image = Image.open(step.observation.screenshot_path)
         except Exception:
             continue
 
-        # Create target description from thought or action
-        target_desc = step.thought or f"element at ({action.x:.2f}, {action.y:.2f})"
-
-        test_cases.append((image, target_desc, action.bbox))
+        # Create target description from reasoning or action coordinates
+        coords_x, coords_y = None, None
+        if action.normalized_coordinates:
+            coords_x, coords_y = action.normalized_coordinates
+        if coords_x is not None and coords_y is not None:
+            target_desc = (
+                step.reasoning or f"element at ({coords_x:.2f}, {coords_y:.2f})"
+            )
+        else:
+            target_desc = step.reasoning or "target element"
+
+        test_cases.append((image, target_desc, bbox))
 
     return evaluate_grounder(grounder, test_cases, k=k)
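Note: the bbox fallback above converts the new element.bounds shape (corner plus size) into the two-corner tuple that evaluate_grounder expects, with action.raw["bbox"] kept as an escape hatch. A standalone sketch of that conversion; the Bounds stand-in only mirrors the fields the hunk reads:

    from typing import NamedTuple

    class Bounds(NamedTuple):
        # Stand-in for the schema's bounds object (x, y, width, height).
        x: float
        y: float
        width: float
        height: float

    def to_bbox(b: Bounds) -> tuple[float, float, float, float]:
        # (x, y, w, h) -> (x1, y1, x2, y2), as in the hunk above.
        return (b.x, b.y, b.x + b.width, b.y + b.height)

    assert to_bbox(Bounds(10, 20, 30, 40)) == (10, 20, 40, 60)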
openadapt_ml/evals/plot_eval_metrics.py +15 -13

@@ -73,7 +73,7 @@ def plot_eval_metrics(
     fig.suptitle(
         "VLM Model Comparison (Offline fine-tuned vs API models)",
         fontsize=12,
-        fontweight='bold',
+        fontweight="bold",
     )
     if num_metrics == 1:
         axes = [axes]
@@ -96,36 +96,38 @@ def plot_eval_metrics(
             hatches.append(hatch)
 
         x = range(num_models)
-        bars = ax.bar(x, values, tick_label=labels, color=colors, edgecolor='black', linewidth=1.2)
+        bars = ax.bar(
+            x, values, tick_label=labels, color=colors, edgecolor="black", linewidth=1.2
+        )
 
         # Apply hatch patterns
         for bar, hatch in zip(bars, hatches):
             bar.set_hatch(hatch)
 
-        ax.set_title(title, fontsize=11, fontweight='bold')
+        ax.set_title(title, fontsize=11, fontweight="bold")
         ax.set_ylabel(key, fontsize=9)
         ax.set_ylim(bottom=0.0)
         # Rotate x-axis labels to prevent crowding
-        ax.tick_params(axis='x', labelrotation=45, labelsize=8)
+        ax.tick_params(axis="x", labelrotation=45, labelsize=8)
         # Align labels to the right for better readability when rotated
         for tick in ax.get_xticklabels():
-            tick.set_horizontalalignment('right')
+            tick.set_horizontalalignment("right")
 
     fig.tight_layout()
 
     # Add legend explaining color coding and hatch patterns
     legend_elements = [
-        Patch(facecolor='#4A90E2', edgecolor='black', label='Qwen3-VL-2B'),
-        Patch(facecolor='#2E5C8A', edgecolor='black', label='Qwen3-VL-8B'),
-        Patch(facecolor='#FF6B35', edgecolor='black', label='Claude (API)'),
-        Patch(facecolor='#C1121F', edgecolor='black', label='GPT (API)'),
-        Patch(facecolor='gray', edgecolor='black', hatch='///', label='Fine-tuned'),
-        Patch(facecolor='gray', edgecolor='black', label='Base/Pretrained'),
+        Patch(facecolor="#4A90E2", edgecolor="black", label="Qwen3-VL-2B"),
+        Patch(facecolor="#2E5C8A", edgecolor="black", label="Qwen3-VL-8B"),
+        Patch(facecolor="#FF6B35", edgecolor="black", label="Claude (API)"),
+        Patch(facecolor="#C1121F", edgecolor="black", label="GPT (API)"),
+        Patch(facecolor="gray", edgecolor="black", hatch="///", label="Fine-tuned"),
+        Patch(facecolor="gray", edgecolor="black", label="Base/Pretrained"),
     ]
 
     fig.legend(
         handles=legend_elements,
-        loc='lower center',
+        loc="lower center",
         bbox_to_anchor=(0.5, -0.05),
         ncol=3,
         fontsize=9,
@@ -133,7 +135,7 @@ def plot_eval_metrics(
     )
 
     output_path.parent.mkdir(parents=True, exist_ok=True)
-    fig.savefig(output_path, dpi=150, bbox_inches='tight')
+    fig.savefig(output_path, dpi=150, bbox_inches="tight")
     plt.close(fig)
 