openadapt-ml 0.2.0__py3-none-any.whl → 0.2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (95)
  1. openadapt_ml/baselines/__init__.py +121 -0
  2. openadapt_ml/baselines/adapter.py +185 -0
  3. openadapt_ml/baselines/cli.py +314 -0
  4. openadapt_ml/baselines/config.py +448 -0
  5. openadapt_ml/baselines/parser.py +922 -0
  6. openadapt_ml/baselines/prompts.py +787 -0
  7. openadapt_ml/benchmarks/__init__.py +13 -115
  8. openadapt_ml/benchmarks/agent.py +265 -421
  9. openadapt_ml/benchmarks/azure.py +28 -19
  10. openadapt_ml/benchmarks/azure_ops_tracker.py +521 -0
  11. openadapt_ml/benchmarks/cli.py +1722 -4847
  12. openadapt_ml/benchmarks/trace_export.py +631 -0
  13. openadapt_ml/benchmarks/viewer.py +22 -5
  14. openadapt_ml/benchmarks/vm_monitor.py +530 -29
  15. openadapt_ml/benchmarks/waa_deploy/Dockerfile +47 -53
  16. openadapt_ml/benchmarks/waa_deploy/api_agent.py +21 -20
  17. openadapt_ml/cloud/azure_inference.py +3 -5
  18. openadapt_ml/cloud/lambda_labs.py +722 -307
  19. openadapt_ml/cloud/local.py +2038 -487
  20. openadapt_ml/cloud/ssh_tunnel.py +68 -26
  21. openadapt_ml/datasets/next_action.py +40 -30
  22. openadapt_ml/evals/grounding.py +8 -3
  23. openadapt_ml/evals/plot_eval_metrics.py +15 -13
  24. openadapt_ml/evals/trajectory_matching.py +41 -26
  25. openadapt_ml/experiments/demo_prompt/format_demo.py +16 -6
  26. openadapt_ml/experiments/demo_prompt/run_experiment.py +26 -16
  27. openadapt_ml/experiments/representation_shootout/__init__.py +70 -0
  28. openadapt_ml/experiments/representation_shootout/conditions.py +708 -0
  29. openadapt_ml/experiments/representation_shootout/config.py +390 -0
  30. openadapt_ml/experiments/representation_shootout/evaluator.py +659 -0
  31. openadapt_ml/experiments/representation_shootout/runner.py +687 -0
  32. openadapt_ml/experiments/waa_demo/runner.py +29 -14
  33. openadapt_ml/export/parquet.py +36 -24
  34. openadapt_ml/grounding/detector.py +18 -14
  35. openadapt_ml/ingest/__init__.py +8 -6
  36. openadapt_ml/ingest/capture.py +25 -22
  37. openadapt_ml/ingest/loader.py +7 -4
  38. openadapt_ml/ingest/synthetic.py +189 -100
  39. openadapt_ml/models/api_adapter.py +14 -4
  40. openadapt_ml/models/base_adapter.py +10 -2
  41. openadapt_ml/models/providers/__init__.py +288 -0
  42. openadapt_ml/models/providers/anthropic.py +266 -0
  43. openadapt_ml/models/providers/base.py +299 -0
  44. openadapt_ml/models/providers/google.py +376 -0
  45. openadapt_ml/models/providers/openai.py +342 -0
  46. openadapt_ml/models/qwen_vl.py +46 -19
  47. openadapt_ml/perception/__init__.py +35 -0
  48. openadapt_ml/perception/integration.py +399 -0
  49. openadapt_ml/retrieval/demo_retriever.py +50 -24
  50. openadapt_ml/retrieval/embeddings.py +9 -8
  51. openadapt_ml/retrieval/retriever.py +3 -1
  52. openadapt_ml/runtime/__init__.py +50 -0
  53. openadapt_ml/runtime/policy.py +18 -5
  54. openadapt_ml/runtime/safety_gate.py +471 -0
  55. openadapt_ml/schema/__init__.py +9 -0
  56. openadapt_ml/schema/converters.py +74 -27
  57. openadapt_ml/schema/episode.py +31 -18
  58. openadapt_ml/scripts/capture_screenshots.py +530 -0
  59. openadapt_ml/scripts/compare.py +85 -54
  60. openadapt_ml/scripts/demo_policy.py +4 -1
  61. openadapt_ml/scripts/eval_policy.py +15 -9
  62. openadapt_ml/scripts/make_gif.py +1 -1
  63. openadapt_ml/scripts/prepare_synthetic.py +3 -1
  64. openadapt_ml/scripts/train.py +21 -9
  65. openadapt_ml/segmentation/README.md +920 -0
  66. openadapt_ml/segmentation/__init__.py +97 -0
  67. openadapt_ml/segmentation/adapters/__init__.py +5 -0
  68. openadapt_ml/segmentation/adapters/capture_adapter.py +420 -0
  69. openadapt_ml/segmentation/annotator.py +610 -0
  70. openadapt_ml/segmentation/cache.py +290 -0
  71. openadapt_ml/segmentation/cli.py +674 -0
  72. openadapt_ml/segmentation/deduplicator.py +656 -0
  73. openadapt_ml/segmentation/frame_describer.py +788 -0
  74. openadapt_ml/segmentation/pipeline.py +340 -0
  75. openadapt_ml/segmentation/schemas.py +622 -0
  76. openadapt_ml/segmentation/segment_extractor.py +634 -0
  77. openadapt_ml/training/azure_ops_viewer.py +1097 -0
  78. openadapt_ml/training/benchmark_viewer.py +52 -41
  79. openadapt_ml/training/shared_ui.py +7 -7
  80. openadapt_ml/training/stub_provider.py +57 -35
  81. openadapt_ml/training/trainer.py +143 -86
  82. openadapt_ml/training/trl_trainer.py +70 -21
  83. openadapt_ml/training/viewer.py +323 -108
  84. openadapt_ml/training/viewer_components.py +180 -0
  85. {openadapt_ml-0.2.0.dist-info → openadapt_ml-0.2.2.dist-info}/METADATA +215 -14
  86. openadapt_ml-0.2.2.dist-info/RECORD +116 -0
  87. openadapt_ml/benchmarks/base.py +0 -366
  88. openadapt_ml/benchmarks/data_collection.py +0 -432
  89. openadapt_ml/benchmarks/live_tracker.py +0 -180
  90. openadapt_ml/benchmarks/runner.py +0 -418
  91. openadapt_ml/benchmarks/waa.py +0 -761
  92. openadapt_ml/benchmarks/waa_live.py +0 -619
  93. openadapt_ml-0.2.0.dist-info/RECORD +0 -86
  94. {openadapt_ml-0.2.0.dist-info → openadapt_ml-0.2.2.dist-info}/WHEEL +0 -0
  95. {openadapt_ml-0.2.0.dist-info → openadapt_ml-0.2.2.dist-info}/licenses/LICENSE +0 -0
openadapt_ml/cloud/ssh_tunnel.py

@@ -51,9 +51,8 @@ import signal
 import socket
 import subprocess
 import time
-from dataclasses import dataclass, field
+from dataclasses import dataclass
 from pathlib import Path
-from typing import Any

 logger = logging.getLogger(__name__)
 
@@ -97,9 +96,11 @@ class SSHTunnelManager:
     """

     # Default tunnel configurations
+    # Note: WAA uses local_port=5001 to avoid conflicts with any local WAA server on 5000
+    # The remote port is still 5000 (where WAA Flask runs inside Windows)
     DEFAULT_TUNNELS = [
         TunnelConfig(name="vnc", local_port=8006, remote_port=8006),
-        TunnelConfig(name="waa", local_port=5000, remote_port=5000),
+        TunnelConfig(name="waa", local_port=5001, remote_port=5000),
     ]

     # Auto-reconnect settings
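Note: after this change, local clients reach the WAA server at localhost:5001 while the Flask server inside the Windows VM still listens on 5000. A minimal sketch of probing the remapped local end (illustrative only; this probe is not part of the package):

    import socket

    # Connect to the local end of the "waa" tunnel (localhost:5001 -> remote 5000).
    # A successful connect only shows the forward is listening locally; it does
    # not prove the WAA Flask server inside the VM is healthy.
    with socket.create_connection(("localhost", 5001), timeout=5):
        print("waa tunnel is accepting connections on localhost:5001")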
@@ -125,7 +126,9 @@ class SSHTunnelManager:
         self._current_vm_ip: str | None = None
         self._current_ssh_user: str | None = None
         self._auto_reconnect = auto_reconnect
-        self._reconnect_attempts: dict[str, int] = {}  # Track reconnect attempts per tunnel
+        self._reconnect_attempts: dict[
+            str, int
+        ] = {}  # Track reconnect attempts per tunnel

     def start_tunnels_for_vm(
         self,
@@ -198,7 +201,9 @@ class SSHTunnelManager:
                 pid=None,  # We don't know the PID of the external tunnel
             )
         else:
-            logger.warning(f"Port {config.local_port} already in use by unknown process")
+            logger.warning(
+                f"Port {config.local_port} already in use by unknown process"
+            )
             return TunnelStatus(
                 name=config.name,
                 active=False,
@@ -213,16 +218,25 @@ class SSHTunnelManager:
         # TCPKeepAlive=yes: Enable TCP-level keepalive as additional safeguard
         ssh_cmd = [
             "ssh",
-            "-o", "StrictHostKeyChecking=no",
-            "-o", "UserKnownHostsFile=/dev/null",
-            "-o", "LogLevel=ERROR",
-            "-o", "ServerAliveInterval=60",
-            "-o", "ServerAliveCountMax=10",
-            "-o", "TCPKeepAlive=yes",
-            "-o", "ExitOnForwardFailure=yes",
-            "-i", str(self.ssh_key_path),
+            "-o",
+            "StrictHostKeyChecking=no",
+            "-o",
+            "UserKnownHostsFile=/dev/null",
+            "-o",
+            "LogLevel=ERROR",
+            "-o",
+            "ServerAliveInterval=60",
+            "-o",
+            "ServerAliveCountMax=10",
+            "-o",
+            "TCPKeepAlive=yes",
+            "-o",
+            "ExitOnForwardFailure=yes",
+            "-i",
+            str(self.ssh_key_path),
             "-N",  # Don't execute remote command
-            "-L", f"{config.local_port}:{config.remote_host}:{config.remote_port}",
+            "-L",
+            f"{config.local_port}:{config.remote_host}:{config.remote_port}",
             f"{ssh_user}@{vm_ip}",
         ]
 
@@ -253,7 +267,9 @@ class SSHTunnelManager:

         # Tunnel started successfully
         self._active_tunnels[config.name] = (config, proc)
-        logger.info(f"Started tunnel {config.name}: localhost:{config.local_port} -> {vm_ip}:{config.remote_port}")
+        logger.info(
+            f"Started tunnel {config.name}: localhost:{config.local_port} -> {vm_ip}:{config.remote_port}"
+        )

         return TunnelStatus(
             name=config.name,
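Note: the options above make tunnel death observable: ServerAliveInterval=60 with ServerAliveCountMax=10 lets ssh give up after roughly ten minutes of unanswered keepalives, and ExitOnForwardFailure=yes turns a failed -L forward into a process exit. A rough sketch of the kind of liveness check that pairs with this (the real _check_tunnel_works used elsewhere in this file is not shown in the diff, so this body is an assumption):

    import socket

    def tunnel_accepts_connections(local_port: int, timeout: float = 3.0) -> bool:
        """Return True if something accepts TCP connections on the tunnel's local end."""
        try:
            with socket.create_connection(("127.0.0.1", local_port), timeout=timeout):
                return True
        except OSError:
            return False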
@@ -340,24 +356,36 @@ class SSHTunnelManager:
                     name=config.name,
                     active=True,
                     local_port=config.local_port,
-                    remote_endpoint=f"{self._current_vm_ip}:{config.remote_port}" if self._current_vm_ip else "unknown",
+                    remote_endpoint=f"{self._current_vm_ip}:{config.remote_port}"
+                    if self._current_vm_ip
+                    else "unknown",
                     pid=proc.pid,
                 )
             else:
                 # Process died - but check if port is still working
                 # (could be another tunnel on the same port)
                 del self._active_tunnels[config.name]
-                if self._is_port_in_use(config.local_port) and self._check_tunnel_works(config.local_port, config.remote_port):
+                if self._is_port_in_use(
+                    config.local_port
+                ) and self._check_tunnel_works(
+                    config.local_port, config.remote_port
+                ):
                     results[config.name] = TunnelStatus(
                         name=config.name,
                         active=True,
                         local_port=config.local_port,
-                        remote_endpoint=f"{self._current_vm_ip}:{config.remote_port}" if self._current_vm_ip else "external",
+                        remote_endpoint=f"{self._current_vm_ip}:{config.remote_port}"
+                        if self._current_vm_ip
+                        else "external",
                         pid=None,  # External tunnel, PID unknown
                     )
                 else:
                     # Tunnel is dead - mark for restart if auto_reconnect enabled
-                    if self._auto_reconnect and auto_restart and self._current_vm_ip:
+                    if (
+                        self._auto_reconnect
+                        and auto_restart
+                        and self._current_vm_ip
+                    ):
                         tunnels_to_restart.append(config)
                         results[config.name] = TunnelStatus(
                             name=config.name,
@@ -369,13 +397,19 @@ class SSHTunnelManager:
             else:
                 # Not tracked internally - but check if an external tunnel exists
                 # This handles tunnels started by other processes or after manager restart
-                if self._is_port_in_use(config.local_port) and self._check_tunnel_works(config.local_port, config.remote_port):
-                    logger.debug(f"Found working external tunnel on port {config.local_port}")
+                if self._is_port_in_use(config.local_port) and self._check_tunnel_works(
+                    config.local_port, config.remote_port
+                ):
+                    logger.debug(
+                        f"Found working external tunnel on port {config.local_port}"
+                    )
                     results[config.name] = TunnelStatus(
                         name=config.name,
                         active=True,
                         local_port=config.local_port,
-                        remote_endpoint=f"{self._current_vm_ip}:{config.remote_port}" if self._current_vm_ip else "external",
+                        remote_endpoint=f"{self._current_vm_ip}:{config.remote_port}"
+                        if self._current_vm_ip
+                        else "external",
                         pid=None,  # External tunnel, PID unknown
                     )
                 else:
@@ -390,16 +424,22 @@ class SSHTunnelManager:
         for config in tunnels_to_restart:
             attempts = self._reconnect_attempts.get(config.name, 0)
             if attempts < self.MAX_RECONNECT_ATTEMPTS:
-                logger.info(f"Auto-reconnecting tunnel {config.name} (attempt {attempts + 1}/{self.MAX_RECONNECT_ATTEMPTS})")
+                logger.info(
+                    f"Auto-reconnecting tunnel {config.name} (attempt {attempts + 1}/{self.MAX_RECONNECT_ATTEMPTS})"
+                )
                 time.sleep(self.RECONNECT_DELAY_SECONDS)
                 self._reconnect_attempts[config.name] = attempts + 1
-                status = self._start_tunnel(config, self._current_vm_ip, self._current_ssh_user or "azureuser")
+                status = self._start_tunnel(
+                    config, self._current_vm_ip, self._current_ssh_user or "azureuser"
+                )
                 results[config.name] = status
                 if status.active:
                     logger.info(f"Successfully reconnected tunnel {config.name}")
                     self._reconnect_attempts[config.name] = 0  # Reset on success
             else:
-                logger.warning(f"Tunnel {config.name} exceeded max reconnect attempts ({self.MAX_RECONNECT_ATTEMPTS})")
+                logger.warning(
+                    f"Tunnel {config.name} exceeded max reconnect attempts ({self.MAX_RECONNECT_ATTEMPTS})"
+                )
                 results[config.name] = TunnelStatus(
                     name=config.name,
                     active=False,
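Note: the restart loop is a capped retry: each dead tunnel gets at most MAX_RECONNECT_ATTEMPTS restarts with a fixed RECONNECT_DELAY_SECONDS pause, and the per-tunnel counter resets to zero on success, so a later outage gets a fresh budget. The pattern in isolation (a generic restatement, not code from the package):

    import time
    from typing import Callable

    def retry_with_cap(restart: Callable[[], bool], max_attempts: int, delay: float) -> bool:
        """Retry `restart` until it reports success or the attempt cap is hit."""
        for _ in range(max_attempts):
            time.sleep(delay)  # pause before each attempt, like RECONNECT_DELAY_SECONDS
            if restart():
                return True  # success; the caller resets its per-tunnel counter
        return False  # cap exceeded; the caller records an inactive status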
@@ -455,7 +495,9 @@ class SSHTunnelManager:
         """
         # If VM changed, stop old tunnels and reset reconnect attempts
         if self._current_vm_ip and self._current_vm_ip != vm_ip:
-            logger.info(f"VM IP changed from {self._current_vm_ip} to {vm_ip}, restarting tunnels")
+            logger.info(
+                f"VM IP changed from {self._current_vm_ip} to {vm_ip}, restarting tunnels"
+            )
             self.stop_all_tunnels()
             self.reset_reconnect_attempts()  # Fresh start for new VM
 
openadapt_ml/datasets/next_action.py

@@ -3,7 +3,6 @@ from __future__ import annotations
 from dataclasses import dataclass
 from typing import Any, Dict, List

-import torch
 from torch.utils.data import Dataset

 from openadapt_ml.schema import Action, ActionType, Episode, Step, UIElement
@@ -20,7 +19,7 @@ SYSTEM_PROMPT = (
     "- Example: An element in the middle of the screen would be approximately x=0.5, y=0.5\n\n"
     "ALLOWED ACTIONS (use exactly this format):\n"
     "- CLICK(x=0.XX, y=0.XX) → click at normalized coordinates\n"
-    "- TYPE(text=\"...\") → type text into the currently focused field\n"
+    '- TYPE(text="...") → type text into the currently focused field\n'
     "- WAIT() → wait for UI to update\n"
     "- DONE() → task is complete\n\n"
     "RESPONSE FORMAT (required):\n"
@@ -42,14 +41,14 @@ SYSTEM_PROMPT_SOM = (
     "[3] = Login button\n\n"
     "ALLOWED ACTIONS (use exactly this format):\n"
     "- CLICK([N]) → click element with number N to focus/activate it\n"
-    "- TYPE([N], \"text\") → type text into element N (e.g., TYPE([2], \"hello\"))\n"
+    '- TYPE([N], "text") → type text into element N (e.g., TYPE([2], "hello"))\n'
     "- WAIT() → wait for UI to update\n"
     "- DONE() → task is complete\n\n"
     "ACTION SEQUENCE FOR LOGIN:\n"
     "1. CLICK([1]) to focus username field\n"
-    "2. TYPE([1], \"username\") to enter username\n"
+    '2. TYPE([1], "username") to enter username\n'
     "3. CLICK([2]) to focus password field\n"
-    "4. TYPE([2], \"password\") to enter password\n"
+    '4. TYPE([2], "password") to enter password\n'
     "5. CLICK([3]) to submit login\n"
     "6. DONE() when login is complete\n\n"
     "RESPONSE FORMAT (required):\n"
@@ -74,20 +73,20 @@ SYSTEM_PROMPT_SOM_REGISTRATION = (
     "[6] = Register button\n\n"
     "ALLOWED ACTIONS (use exactly this format):\n"
     "- CLICK([N]) → click element with number N to focus/activate it\n"
-    "- TYPE([N], \"text\") → type text into element N (e.g., TYPE([2], \"hello\"))\n"
+    '- TYPE([N], "text") → type text into element N (e.g., TYPE([2], "hello"))\n'
     "- WAIT() → wait for UI to update\n"
     "- DONE() → task is complete\n\n"
     "ACTION SEQUENCE FOR REGISTRATION:\n"
     "1. CLICK([1]) to focus first name field\n"
-    "2. TYPE([1], \"name\") to enter first name\n"
+    '2. TYPE([1], "name") to enter first name\n'
     "3. CLICK([2]) to focus last name field\n"
-    "4. TYPE([2], \"name\") to enter last name\n"
+    '4. TYPE([2], "name") to enter last name\n'
     "5. CLICK([3]) to focus email field\n"
-    "6. TYPE([3], \"email\") to enter email\n"
+    '6. TYPE([3], "email") to enter email\n'
     "7. CLICK([4]) to focus password field\n"
-    "8. TYPE([4], \"pass\") to enter password\n"
+    '8. TYPE([4], "pass") to enter password\n'
     "9. CLICK([5]) to focus confirm password field\n"
-    "10. TYPE([5], \"pass\") to enter confirmation\n"
+    '10. TYPE([5], "pass") to enter confirmation\n'
     "11. CLICK([6]) to submit registration\n"
     "12. DONE() when registration is complete\n\n"
     "RESPONSE FORMAT (required):\n"
@@ -127,12 +126,12 @@ def format_action(action: Action, use_som: bool = False) -> str:
         if t == ActionType.CLICK and element_id is not None:
             return f"CLICK([{element_id}])"
         if t == ActionType.TYPE and action.text is not None:
-            escaped = action.text.replace("\\", "\\\\").replace("\"", "\\\"")
+            escaped = action.text.replace("\\", "\\\\").replace('"', '\\"')
             if element_id is not None:
-                return f"TYPE([{element_id}], \"{escaped}\")"
+                return f'TYPE([{element_id}], "{escaped}")'
             else:
                 # Fallback: TYPE without element reference (for focused field)
-                return f"TYPE(\"{escaped}\")"
+                return f'TYPE("{escaped}")'
         if t == ActionType.WAIT:
             return "WAIT()"
         if t == ActionType.DONE:
@@ -145,8 +144,8 @@ def format_action(action: Action, use_som: bool = False) -> str:
         x, y = action.normalized_coordinates
         return f"CLICK(x={x:.2f}, y={y:.2f})"
     if t == ActionType.TYPE and action.text is not None:
-        escaped = action.text.replace("\\", "\\\\").replace("\"", "\\\"")
-        return f"TYPE(text=\"{escaped}\")"
+        escaped = action.text.replace("\\", "\\\\").replace('"', '\\"')
+        return f'TYPE(text="{escaped}")'
     if t == ActionType.WAIT:
         return "WAIT()"
     if t == ActionType.DONE:
@@ -181,13 +180,15 @@ def parse_action_som(text: str) -> Action:
     match = re.match(r'TYPE\(\[(\d+)\],\s*["\'](.*)["\']\)', text, re.DOTALL)
     if match:
         idx = match.group(1)
-        content = match.group(2).replace("\\\"", "\"").replace("\\\\", "\\")
-        return Action(type=ActionType.TYPE, text=content, element=UIElement(element_id=idx))
+        content = match.group(2).replace('\\"', '"').replace("\\\\", "\\")
+        return Action(
+            type=ActionType.TYPE, text=content, element=UIElement(element_id=idx)
+        )

     # TYPE("text") - no element index
     match = re.match(r'TYPE\(["\'](.*)["\']\)', text, re.DOTALL)
     if match:
-        content = match.group(1).replace("\\\"", "\"").replace("\\\\", "\\")
+        content = match.group(1).replace('\\"', '"').replace("\\\\", "\\")
         return Action(type=ActionType.TYPE, text=content)

     # WAIT()
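Note: the quote flips in format_action and parse_action_som are formatter-only (Black prefers double-quoted strings); the escape/unescape logic is unchanged. A quick round trip through the SoM DSL, assuming these helpers are importable from openadapt_ml.datasets.next_action and that Action and UIElement accept the fields used in the hunks above:

    from openadapt_ml.datasets.next_action import format_action, parse_action_som
    from openadapt_ml.schema import Action, ActionType, UIElement

    original = Action(
        type=ActionType.TYPE,
        text='say "hi"',  # embedded quotes exercise the escaping
        element=UIElement(element_id="2"),
    )
    dsl = format_action(original, use_som=True)  # TYPE([2], "say \"hi\"")
    parsed = parse_action_som(dsl)
    assert parsed.text == original.text  # escaping and unescaping are inverses here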
@@ -202,7 +203,9 @@ def parse_action_som(text: str) -> Action:
     return Action(type=ActionType.FAIL, raw={"text": text})


-def _generate_generic_thought(step_index: int, step: Step, goal: str, total_steps: int) -> str:
+def _generate_generic_thought(
+    step_index: int, step: Step, goal: str, total_steps: int
+) -> str:
     """Generate a thought for real captures (non-synthetic scenarios).

     This creates action-appropriate thoughts that teach the model to output
@@ -239,7 +242,9 @@ def _generate_generic_thought(step_index: int, step: Step, goal: str, total_steps: int) -> str:
         return f"{progress} I need to scroll to reveal more content or reach the target element for '{goal}'."

     if t == ActionType.DRAG:
-        return f"{progress} I need to drag an element to complete this part of '{goal}'."
+        return (
+            f"{progress} I need to drag an element to complete this part of '{goal}'."
+        )

     if t == ActionType.KEY:
         return f"{progress} I need to press a key to continue the workflow."
@@ -269,9 +274,6 @@ def _generate_thought_for_step(
     actions back to the stated objective.
     """

-    action = step.action
-    t = action.type
-
     if scenario == "registration":
         return _generate_registration_thought(step_index, step, goal, total_steps)
     elif scenario == "login" and total_steps <= 7:
@@ -282,7 +284,9 @@
         return _generate_generic_thought(step_index, step, goal, total_steps)


-def _generate_login_thought(step_index: int, step: Step, goal: str, total_steps: int) -> str:
+def _generate_login_thought(
+    step_index: int, step: Step, goal: str, total_steps: int
+) -> str:
     """Generate thought for login scenario (6 steps)."""
     action = step.action
     t = action.type
@@ -336,7 +340,9 @@ def _generate_login_thought(step_index: int, step: Step, goal: str, total_steps: int) -> str:
     )


-def _generate_registration_thought(step_index: int, step: Step, goal: str, total_steps: int) -> str:
+def _generate_registration_thought(
+    step_index: int, step: Step, goal: str, total_steps: int
+) -> str:
     """Generate thought for registration scenario (12 steps)."""
     action = step.action
     t = action.type
@@ -469,7 +475,9 @@ def build_next_action_sft_samples(
                 history_text += f" {i}. {action_text}\n"
             history_text += f"\nThis is step {step_index + 1} of {total_steps}. "
         else:
-            history_text = f"This is step 1 of {total_steps} (no actions completed yet). "
+            history_text = (
+                f"This is step 1 of {total_steps} (no actions completed yet). "
+            )

         if use_som:
             user_content = (
@@ -477,7 +485,7 @@
                 f"{history_text}"
                 "Look at the screenshot and determine the NEXT action.\n\n"
                 "Thought: [which numbered element to interact with and why]\n"
-                "Action: [CLICK([N]) or TYPE([N], \"text\") or WAIT() or DONE()]"
+                'Action: [CLICK([N]) or TYPE([N], "text") or WAIT() or DONE()]'
             )
         else:
             user_content = (
@@ -485,13 +493,15 @@
                 f"{history_text}"
                 "Look at the screenshot and determine the NEXT action.\n\n"
                 "Thought: [what element to interact with and why]\n"
-                "Action: [CLICK(x=..., y=...) or TYPE(text=\"...\") or WAIT() or DONE()]"
+                'Action: [CLICK(x=..., y=...) or TYPE(text="...") or WAIT() or DONE()]'
             )

         # Provide a deterministic, semantically meaningful Thought while supervising
         # the exact DSL Action.
         action_text = format_action(step.action, use_som=use_som)
-        thought_text = _generate_thought_for_step(step_index, step, goal, scenario, total_steps)
+        thought_text = _generate_thought_for_step(
+            step_index, step, goal, scenario, total_steps
+        )
         assistant_content = f"Thought: {thought_text}\nAction: {action_text}"

         sample = {
openadapt_ml/evals/grounding.py

@@ -19,6 +19,7 @@ from typing import TYPE_CHECKING
 if TYPE_CHECKING:
     from PIL import Image

+    from openadapt_ml.data.types import Episode
     from openadapt_ml.grounding.base import GroundingModule, RegionCandidate

 
@@ -212,7 +213,7 @@ def evaluate_grounder_on_episode(
     """
     from PIL import Image

-    from openadapt_ml.schema import Episode, ActionType
+    from openadapt_ml.schema import ActionType

     test_cases = []
 
@@ -220,7 +221,9 @@
         action = step.action

         # Get action type as string for comparison
-        action_type_str = action.type.value if isinstance(action.type, ActionType) else action.type
+        action_type_str = (
+            action.type.value if isinstance(action.type, ActionType) else action.type
+        )

         # Only evaluate clicks with bboxes
         if action_type_str not in ("click", "double_click"):
@@ -250,7 +253,9 @@
         if action.normalized_coordinates:
             coords_x, coords_y = action.normalized_coordinates
             if coords_x is not None and coords_y is not None:
-                target_desc = step.reasoning or f"element at ({coords_x:.2f}, {coords_y:.2f})"
+                target_desc = (
+                    step.reasoning or f"element at ({coords_x:.2f}, {coords_y:.2f})"
+                )
             else:
                 target_desc = step.reasoning or "target element"
 
openadapt_ml/evals/plot_eval_metrics.py

@@ -73,7 +73,7 @@ def plot_eval_metrics(
     fig.suptitle(
         "VLM Model Comparison (Offline fine-tuned vs API models)",
         fontsize=12,
-        fontweight='bold',
+        fontweight="bold",
     )
     if num_metrics == 1:
         axes = [axes]
@@ -96,36 +96,38 @@
             hatches.append(hatch)

         x = range(num_models)
-        bars = ax.bar(x, values, tick_label=labels, color=colors, edgecolor='black', linewidth=1.2)
+        bars = ax.bar(
+            x, values, tick_label=labels, color=colors, edgecolor="black", linewidth=1.2
+        )

         # Apply hatch patterns
         for bar, hatch in zip(bars, hatches):
             bar.set_hatch(hatch)

-        ax.set_title(title, fontsize=11, fontweight='bold')
+        ax.set_title(title, fontsize=11, fontweight="bold")
         ax.set_ylabel(key, fontsize=9)
         ax.set_ylim(bottom=0.0)
         # Rotate x-axis labels to prevent crowding
-        ax.tick_params(axis='x', labelrotation=45, labelsize=8)
+        ax.tick_params(axis="x", labelrotation=45, labelsize=8)
         # Align labels to the right for better readability when rotated
         for tick in ax.get_xticklabels():
-            tick.set_horizontalalignment('right')
+            tick.set_horizontalalignment("right")

     fig.tight_layout()

     # Add legend explaining color coding and hatch patterns
     legend_elements = [
-        Patch(facecolor='#4A90E2', edgecolor='black', label='Qwen3-VL-2B'),
-        Patch(facecolor='#2E5C8A', edgecolor='black', label='Qwen3-VL-8B'),
-        Patch(facecolor='#FF6B35', edgecolor='black', label='Claude (API)'),
-        Patch(facecolor='#C1121F', edgecolor='black', label='GPT (API)'),
-        Patch(facecolor='gray', edgecolor='black', hatch='///', label='Fine-tuned'),
-        Patch(facecolor='gray', edgecolor='black', label='Base/Pretrained'),
+        Patch(facecolor="#4A90E2", edgecolor="black", label="Qwen3-VL-2B"),
+        Patch(facecolor="#2E5C8A", edgecolor="black", label="Qwen3-VL-8B"),
+        Patch(facecolor="#FF6B35", edgecolor="black", label="Claude (API)"),
+        Patch(facecolor="#C1121F", edgecolor="black", label="GPT (API)"),
+        Patch(facecolor="gray", edgecolor="black", hatch="///", label="Fine-tuned"),
+        Patch(facecolor="gray", edgecolor="black", label="Base/Pretrained"),
     ]

     fig.legend(
         handles=legend_elements,
-        loc='lower center',
+        loc="lower center",
         bbox_to_anchor=(0.5, -0.05),
         ncol=3,
         fontsize=9,
@@ -133,7 +135,7 @@
     )

     output_path.parent.mkdir(parents=True, exist_ok=True)
-    fig.savefig(output_path, dpi=150, bbox_inches='tight')
+    fig.savefig(output_path, dpi=150, bbox_inches="tight")
     plt.close(fig)

 
openadapt_ml/evals/trajectory_matching.py

@@ -15,10 +15,15 @@ class MilestoneSpec:
     A milestone is achieved when, at a specific step, the predicted action
     matches certain criteria (type match + optional coord threshold).
     """
+
     name: str
     step_index: int  # Which step in the episode (0-indexed)
-    expected_type: str  # Expected ground truth action type ("click", "type", "done", etc.)
-    coord_threshold: Optional[float] = None  # If set, coord error must be < this for clicks
+    expected_type: (
+        str  # Expected ground truth action type ("click", "type", "done", etc.)
+    )
+    coord_threshold: Optional[float] = (
+        None  # If set, coord error must be < this for clicks
+    )


 # Predefined milestone specs per scenario
@@ -28,7 +33,9 @@ class MilestoneSpec:
 LOGIN_MILESTONES = [
     MilestoneSpec("typed_username", step_index=1, expected_type="type"),
     MilestoneSpec("typed_password", step_index=3, expected_type="type"),
-    MilestoneSpec("clicked_login", step_index=4, expected_type="click", coord_threshold=0.10),
+    MilestoneSpec(
+        "clicked_login", step_index=4, expected_type="click", coord_threshold=0.10
+    ),
     MilestoneSpec("emitted_done", step_index=5, expected_type="done"),
 ]
 
@@ -81,14 +88,22 @@ class AggregateMetrics:
     action_type_accuracy: float
     mean_coord_error: Optional[float]
     coord_error_count: int
-    episode_success_rate: Optional[float]  # Strict: all steps must match (renamed from success_pred)
+    episode_success_rate: Optional[
+        float
+    ]  # Strict: all steps must match (renamed from success_pred)
     click_hit_rate: Optional[float]  # Point-based: within 5% of center
-    mean_episode_progress: Optional[float]  # Partial credit: avg(step_matches/step_total)
+    mean_episode_progress: Optional[
+        float
+    ]  # Partial credit: avg(step_matches/step_total)
     # New partial-credit metrics
-    mean_episode_step_score: Optional[float]  # Strict partial: avg(full_step_correct/step_total)
+    mean_episode_step_score: Optional[
+        float
+    ]  # Strict partial: avg(full_step_correct/step_total)
     weak_episode_success_rate: Optional[float]  # Semantic milestones all achieved
     state_success_rate: Optional[float] = None  # From model's State: {"success": true}
-    bbox_hit_rate: Optional[float] = None  # Bbox-based: click anywhere in element bounds
+    bbox_hit_rate: Optional[float] = (
+        None  # Bbox-based: click anywhere in element bounds
+    )
     element_accuracy: Optional[float] = None  # SoM element index accuracy

 
@@ -122,12 +137,7 @@ def compute_coordinate_error(pred_action: Action, gt_action: Action) -> Optional[float]:
     pred_x, pred_y = _get_normalized_coords(pred_action)
     gt_x, gt_y = _get_normalized_coords(gt_action)

-    if (
-        pred_x is None
-        or pred_y is None
-        or gt_x is None
-        or gt_y is None
-    ):
+    if pred_x is None or pred_y is None or gt_x is None or gt_y is None:
         return None

     dx = pred_x - gt_x
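Note: collapsing the guard leaves the metric untouched: coordinate error is the Euclidean distance between predicted and ground-truth points in normalized [0, 1] coordinates, which is what coord_threshold values like 0.10 are compared against. On toy numbers (math.hypot stands in for whatever distance expression follows dx and dy):

    import math

    pred_x, pred_y = 0.52, 0.48  # toy predicted click
    gt_x, gt_y = 0.50, 0.50      # toy ground-truth click
    coord_error = math.hypot(pred_x - gt_x, pred_y - gt_y)  # ~0.028, under a 0.10 threshold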
@@ -212,7 +222,9 @@ def evaluate_episode(
         sample = samples[sample_idx]
         sample_idx += 1

-        pred_action, _thought, pred_state, raw_text = policy.predict_action_from_sample(sample)
+        pred_action, _thought, pred_state, raw_text = policy.predict_action_from_sample(
+            sample
+        )
         gt_action = step.action

         # Get action types as strings for comparison
@@ -233,7 +245,6 @@

         coord_error: Optional[float] = None
         click_hit = False
-        bbox_hit = False
         element_hit = False

         # Helper to get element index - check element.element_id or raw field
@@ -273,7 +284,6 @@
                 bbox_total += 1
                 if in_bbox:
                     bbox_hits += 1
-                    bbox_hit = True

         # Full step correctness: type matches AND element/coord match for relevant actions
         if type_match:
@@ -291,11 +301,17 @@

         # Track semantic milestones using the milestone spec
         for milestone in milestones:
-            if step_idx == milestone.step_index and gt_type_str == milestone.expected_type:
+            if (
+                step_idx == milestone.step_index
+                and gt_type_str == milestone.expected_type
+            ):
                 if pred_type_str == milestone.expected_type:
                     # Check coord threshold if specified (for click actions)
                     if milestone.coord_threshold is not None:
-                        if coord_error is not None and coord_error < milestone.coord_threshold:
+                        if (
+                            coord_error is not None
+                            and coord_error < milestone.coord_threshold
+                        ):
                             milestones_achieved[milestone.name] = True
                     else:
                         # No coord threshold - type match is sufficient
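Note: a milestone such as clicked_login is credited only when the ground-truth step has the expected type, the predicted type matches, and, if a coord_threshold is set, the coordinate error beats it. The predicate restated compactly (a sketch of the logic above, not a drop-in replacement):

    from typing import Optional

    def milestone_hit(
        spec: MilestoneSpec,
        step_idx: int,
        gt_type: str,
        pred_type: str,
        coord_error: Optional[float],
    ) -> bool:
        """True when this step's prediction satisfies the milestone spec."""
        if step_idx != spec.step_index or gt_type != spec.expected_type:
            return False  # milestone does not apply at this step
        if pred_type != spec.expected_type:
            return False  # wrong predicted action type
        if spec.coord_threshold is None:
            return True  # type match alone suffices
        return coord_error is not None and coord_error < spec.coord_threshold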
@@ -428,18 +444,16 @@ def aggregate_metrics(episodes_metrics: List[EpisodeMetrics]) -> AggregateMetrics:

     # Partial credit: average episode progress (step_matches / step_total per episode)
     if eval_episodes:
-        episode_progress_scores = [
-            m.step_matches / m.step_total for m in eval_episodes
-        ]
-        mean_episode_progress = sum(episode_progress_scores) / len(episode_progress_scores)
+        episode_progress_scores = [m.step_matches / m.step_total for m in eval_episodes]
+        mean_episode_progress = sum(episode_progress_scores) / len(
+            episode_progress_scores
+        )
     else:
         mean_episode_progress = None

     # Strict partial: avg(full_step_correct / step_total) - requires type match + click hit
     if eval_episodes:
-        step_scores = [
-            m.full_step_correct / m.step_total for m in eval_episodes
-        ]
+        step_scores = [m.full_step_correct / m.step_total for m in eval_episodes]
         mean_episode_step_score = sum(step_scores) / len(step_scores)
     else:
         mean_episode_step_score = None
@@ -447,7 +461,8 @@
     # Weak episode success: all milestones achieved
     if eval_episodes:
         weak_success_count = sum(
-            1 for m in eval_episodes
+            1
+            for m in eval_episodes
             if m.milestones_achieved and all(m.milestones_achieved.values())
         )
         weak_episode_success_rate = weak_success_count / len(eval_episodes)
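Note: the episode-level rates in aggregate_metrics differ only in what counts as a correct step: mean_episode_progress credits type-level matches, mean_episode_step_score also requires the click to land, and episode_success_rate demands a perfect episode. On toy counts (the tuples are simplified stand-ins for EpisodeMetrics fields):

    # (step_matches, full_step_correct, step_total) for two toy episodes
    episodes = [(4, 3, 5), (5, 5, 5)]

    mean_episode_progress = sum(m / t for m, _, t in episodes) / len(episodes)    # 0.9
    mean_episode_step_score = sum(f / t for _, f, t in episodes) / len(episodes)  # 0.8
    episode_success_rate = sum(m == t for m, _, t in episodes) / len(episodes)    # 0.5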