openadapt-ml 0.2.0__py3-none-any.whl → 0.2.2__py3-none-any.whl
This diff shows the changes between package versions as they were publicly released to a supported registry. It is provided for informational purposes only and reflects the package contents exactly as they appear in the public registry.
- openadapt_ml/baselines/__init__.py +121 -0
- openadapt_ml/baselines/adapter.py +185 -0
- openadapt_ml/baselines/cli.py +314 -0
- openadapt_ml/baselines/config.py +448 -0
- openadapt_ml/baselines/parser.py +922 -0
- openadapt_ml/baselines/prompts.py +787 -0
- openadapt_ml/benchmarks/__init__.py +13 -115
- openadapt_ml/benchmarks/agent.py +265 -421
- openadapt_ml/benchmarks/azure.py +28 -19
- openadapt_ml/benchmarks/azure_ops_tracker.py +521 -0
- openadapt_ml/benchmarks/cli.py +1722 -4847
- openadapt_ml/benchmarks/trace_export.py +631 -0
- openadapt_ml/benchmarks/viewer.py +22 -5
- openadapt_ml/benchmarks/vm_monitor.py +530 -29
- openadapt_ml/benchmarks/waa_deploy/Dockerfile +47 -53
- openadapt_ml/benchmarks/waa_deploy/api_agent.py +21 -20
- openadapt_ml/cloud/azure_inference.py +3 -5
- openadapt_ml/cloud/lambda_labs.py +722 -307
- openadapt_ml/cloud/local.py +2038 -487
- openadapt_ml/cloud/ssh_tunnel.py +68 -26
- openadapt_ml/datasets/next_action.py +40 -30
- openadapt_ml/evals/grounding.py +8 -3
- openadapt_ml/evals/plot_eval_metrics.py +15 -13
- openadapt_ml/evals/trajectory_matching.py +41 -26
- openadapt_ml/experiments/demo_prompt/format_demo.py +16 -6
- openadapt_ml/experiments/demo_prompt/run_experiment.py +26 -16
- openadapt_ml/experiments/representation_shootout/__init__.py +70 -0
- openadapt_ml/experiments/representation_shootout/conditions.py +708 -0
- openadapt_ml/experiments/representation_shootout/config.py +390 -0
- openadapt_ml/experiments/representation_shootout/evaluator.py +659 -0
- openadapt_ml/experiments/representation_shootout/runner.py +687 -0
- openadapt_ml/experiments/waa_demo/runner.py +29 -14
- openadapt_ml/export/parquet.py +36 -24
- openadapt_ml/grounding/detector.py +18 -14
- openadapt_ml/ingest/__init__.py +8 -6
- openadapt_ml/ingest/capture.py +25 -22
- openadapt_ml/ingest/loader.py +7 -4
- openadapt_ml/ingest/synthetic.py +189 -100
- openadapt_ml/models/api_adapter.py +14 -4
- openadapt_ml/models/base_adapter.py +10 -2
- openadapt_ml/models/providers/__init__.py +288 -0
- openadapt_ml/models/providers/anthropic.py +266 -0
- openadapt_ml/models/providers/base.py +299 -0
- openadapt_ml/models/providers/google.py +376 -0
- openadapt_ml/models/providers/openai.py +342 -0
- openadapt_ml/models/qwen_vl.py +46 -19
- openadapt_ml/perception/__init__.py +35 -0
- openadapt_ml/perception/integration.py +399 -0
- openadapt_ml/retrieval/demo_retriever.py +50 -24
- openadapt_ml/retrieval/embeddings.py +9 -8
- openadapt_ml/retrieval/retriever.py +3 -1
- openadapt_ml/runtime/__init__.py +50 -0
- openadapt_ml/runtime/policy.py +18 -5
- openadapt_ml/runtime/safety_gate.py +471 -0
- openadapt_ml/schema/__init__.py +9 -0
- openadapt_ml/schema/converters.py +74 -27
- openadapt_ml/schema/episode.py +31 -18
- openadapt_ml/scripts/capture_screenshots.py +530 -0
- openadapt_ml/scripts/compare.py +85 -54
- openadapt_ml/scripts/demo_policy.py +4 -1
- openadapt_ml/scripts/eval_policy.py +15 -9
- openadapt_ml/scripts/make_gif.py +1 -1
- openadapt_ml/scripts/prepare_synthetic.py +3 -1
- openadapt_ml/scripts/train.py +21 -9
- openadapt_ml/segmentation/README.md +920 -0
- openadapt_ml/segmentation/__init__.py +97 -0
- openadapt_ml/segmentation/adapters/__init__.py +5 -0
- openadapt_ml/segmentation/adapters/capture_adapter.py +420 -0
- openadapt_ml/segmentation/annotator.py +610 -0
- openadapt_ml/segmentation/cache.py +290 -0
- openadapt_ml/segmentation/cli.py +674 -0
- openadapt_ml/segmentation/deduplicator.py +656 -0
- openadapt_ml/segmentation/frame_describer.py +788 -0
- openadapt_ml/segmentation/pipeline.py +340 -0
- openadapt_ml/segmentation/schemas.py +622 -0
- openadapt_ml/segmentation/segment_extractor.py +634 -0
- openadapt_ml/training/azure_ops_viewer.py +1097 -0
- openadapt_ml/training/benchmark_viewer.py +52 -41
- openadapt_ml/training/shared_ui.py +7 -7
- openadapt_ml/training/stub_provider.py +57 -35
- openadapt_ml/training/trainer.py +143 -86
- openadapt_ml/training/trl_trainer.py +70 -21
- openadapt_ml/training/viewer.py +323 -108
- openadapt_ml/training/viewer_components.py +180 -0
- {openadapt_ml-0.2.0.dist-info → openadapt_ml-0.2.2.dist-info}/METADATA +215 -14
- openadapt_ml-0.2.2.dist-info/RECORD +116 -0
- openadapt_ml/benchmarks/base.py +0 -366
- openadapt_ml/benchmarks/data_collection.py +0 -432
- openadapt_ml/benchmarks/live_tracker.py +0 -180
- openadapt_ml/benchmarks/runner.py +0 -418
- openadapt_ml/benchmarks/waa.py +0 -761
- openadapt_ml/benchmarks/waa_live.py +0 -619
- openadapt_ml-0.2.0.dist-info/RECORD +0 -86
- {openadapt_ml-0.2.0.dist-info → openadapt_ml-0.2.2.dist-info}/WHEEL +0 -0
- {openadapt_ml-0.2.0.dist-info → openadapt_ml-0.2.2.dist-info}/licenses/LICENSE +0 -0
openadapt_ml/cloud/ssh_tunnel.py
CHANGED
@@ -51,9 +51,8 @@ import signal
 import socket
 import subprocess
 import time
-from dataclasses import dataclass
+from dataclasses import dataclass
 from pathlib import Path
-from typing import Any

 logger = logging.getLogger(__name__)

@@ -97,9 +96,11 @@ class SSHTunnelManager:
     """

     # Default tunnel configurations
+    # Note: WAA uses local_port=5001 to avoid conflicts with any local WAA server on 5000
+    # The remote port is still 5000 (where WAA Flask runs inside Windows)
     DEFAULT_TUNNELS = [
         TunnelConfig(name="vnc", local_port=8006, remote_port=8006),
-        TunnelConfig(name="waa", local_port=5000, remote_port=5000),
+        TunnelConfig(name="waa", local_port=5001, remote_port=5000),
     ]

     # Auto-reconnect settings
@@ -125,7 +126,9 @@ class SSHTunnelManager:
         self._current_vm_ip: str | None = None
         self._current_ssh_user: str | None = None
         self._auto_reconnect = auto_reconnect
-        self._reconnect_attempts: dict[str, int] = {}  # Track reconnect attempts per tunnel
+        self._reconnect_attempts: dict[
+            str, int
+        ] = {}  # Track reconnect attempts per tunnel

     def start_tunnels_for_vm(
         self,
@@ -198,7 +201,9 @@ class SSHTunnelManager:
                 pid=None,  # We don't know the PID of the external tunnel
             )
         else:
-            logger.warning(f"Port {config.local_port} already in use by unknown process")
+            logger.warning(
+                f"Port {config.local_port} already in use by unknown process"
+            )
             return TunnelStatus(
                 name=config.name,
                 active=False,
@@ -213,16 +218,25 @@ class SSHTunnelManager:
         # TCPKeepAlive=yes: Enable TCP-level keepalive as additional safeguard
         ssh_cmd = [
             "ssh",
-            "-o",
-            "…",
-            "-o",
-            "…",
-            "-o",
-            "…",
-            "-o",
-            "…",
+            "-o",
+            "StrictHostKeyChecking=no",
+            "-o",
+            "UserKnownHostsFile=/dev/null",
+            "-o",
+            "LogLevel=ERROR",
+            "-o",
+            "ServerAliveInterval=60",
+            "-o",
+            "ServerAliveCountMax=10",
+            "-o",
+            "TCPKeepAlive=yes",
+            "-o",
+            "ExitOnForwardFailure=yes",
+            "-i",
+            str(self.ssh_key_path),
             "-N",  # Don't execute remote command
-            "-L", f"{config.local_port}:{config.remote_host}:{config.remote_port}",
+            "-L",
+            f"{config.local_port}:{config.remote_host}:{config.remote_port}",
             f"{ssh_user}@{vm_ip}",
         ]

@@ -253,7 +267,9 @@ class SSHTunnelManager:

         # Tunnel started successfully
         self._active_tunnels[config.name] = (config, proc)
-        logger.info(f"Started tunnel {config.name}: localhost:{config.local_port} -> {vm_ip}:{config.remote_port}")
+        logger.info(
+            f"Started tunnel {config.name}: localhost:{config.local_port} -> {vm_ip}:{config.remote_port}"
+        )

         return TunnelStatus(
             name=config.name,
@@ -340,24 +356,36 @@ class SSHTunnelManager:
                     name=config.name,
                     active=True,
                     local_port=config.local_port,
-                    remote_endpoint=f"{self._current_vm_ip}:{config.remote_port}" if self._current_vm_ip else "unknown",
+                    remote_endpoint=f"{self._current_vm_ip}:{config.remote_port}"
+                    if self._current_vm_ip
+                    else "unknown",
                     pid=proc.pid,
                 )
             else:
                 # Process died - but check if port is still working
                 # (could be another tunnel on the same port)
                 del self._active_tunnels[config.name]
-                if self._is_port_in_use(config.local_port) and self._check_tunnel_works(config.local_port, config.remote_port):
+                if self._is_port_in_use(
+                    config.local_port
+                ) and self._check_tunnel_works(
+                    config.local_port, config.remote_port
+                ):
                     results[config.name] = TunnelStatus(
                         name=config.name,
                         active=True,
                         local_port=config.local_port,
-                        remote_endpoint=f"{self._current_vm_ip}:{config.remote_port}" if self._current_vm_ip else "external",
+                        remote_endpoint=f"{self._current_vm_ip}:{config.remote_port}"
+                        if self._current_vm_ip
+                        else "external",
                         pid=None,  # External tunnel, PID unknown
                     )
                 else:
                     # Tunnel is dead - mark for restart if auto_reconnect enabled
-                    if self._auto_reconnect and auto_restart and self._current_vm_ip:
+                    if (
+                        self._auto_reconnect
+                        and auto_restart
+                        and self._current_vm_ip
+                    ):
                         tunnels_to_restart.append(config)
                         results[config.name] = TunnelStatus(
                             name=config.name,
@@ -369,13 +397,19 @@ class SSHTunnelManager:
             else:
                 # Not tracked internally - but check if an external tunnel exists
                 # This handles tunnels started by other processes or after manager restart
-                if self._is_port_in_use(config.local_port) and self._check_tunnel_works(
-                    config.local_port, config.remote_port):
+                if self._is_port_in_use(config.local_port) and self._check_tunnel_works(
+                    config.local_port, config.remote_port
+                ):
+                    logger.debug(
+                        f"Found working external tunnel on port {config.local_port}"
+                    )
                     results[config.name] = TunnelStatus(
                         name=config.name,
                         active=True,
                         local_port=config.local_port,
-                        remote_endpoint=f"{self._current_vm_ip}:{config.remote_port}" if self._current_vm_ip else "external",
+                        remote_endpoint=f"{self._current_vm_ip}:{config.remote_port}"
+                        if self._current_vm_ip
+                        else "external",
                         pid=None,  # External tunnel, PID unknown
                     )
                 else:
@@ -390,16 +424,22 @@ class SSHTunnelManager:
         for config in tunnels_to_restart:
             attempts = self._reconnect_attempts.get(config.name, 0)
             if attempts < self.MAX_RECONNECT_ATTEMPTS:
-                logger.info(f"Auto-reconnecting tunnel {config.name} (attempt {attempts + 1}/{self.MAX_RECONNECT_ATTEMPTS})")
+                logger.info(
+                    f"Auto-reconnecting tunnel {config.name} (attempt {attempts + 1}/{self.MAX_RECONNECT_ATTEMPTS})"
+                )
                 time.sleep(self.RECONNECT_DELAY_SECONDS)
                 self._reconnect_attempts[config.name] = attempts + 1
-                status = self._start_tunnel(config, self._current_vm_ip, self._current_ssh_user or "azureuser")
+                status = self._start_tunnel(
+                    config, self._current_vm_ip, self._current_ssh_user or "azureuser"
+                )
                 results[config.name] = status
                 if status.active:
                     logger.info(f"Successfully reconnected tunnel {config.name}")
                     self._reconnect_attempts[config.name] = 0  # Reset on success
             else:
-                logger.warning(f"Tunnel {config.name} exceeded max reconnect attempts ({self.MAX_RECONNECT_ATTEMPTS})")
+                logger.warning(
+                    f"Tunnel {config.name} exceeded max reconnect attempts ({self.MAX_RECONNECT_ATTEMPTS})"
+                )
                 results[config.name] = TunnelStatus(
                     name=config.name,
                     active=False,
@@ -455,7 +495,9 @@ class SSHTunnelManager:
         """
         # If VM changed, stop old tunnels and reset reconnect attempts
         if self._current_vm_ip and self._current_vm_ip != vm_ip:
-            logger.info(f"VM IP changed from {self._current_vm_ip} to {vm_ip}, restarting tunnels")
+            logger.info(
+                f"VM IP changed from {self._current_vm_ip} to {vm_ip}, restarting tunnels"
+            )
             self.stop_all_tunnels()
             self.reset_reconnect_attempts()  # Fresh start for new VM

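The ssh_tunnel.py changes read more clearly assembled than as a wrap-diff: the waa tunnel now binds localhost:5001, the key is passed explicitly with -i, and the keepalive options let an idle tunnel survive long benchmark runs. Below is a minimal standalone sketch of the command the new code builds; build_tunnel_command is a hypothetical helper for illustration (in the module the list is constructed inside SSHTunnelManager), and the key path and VM IP are examples.

import subprocess
from pathlib import Path


def build_tunnel_command(
    ssh_key_path: str,
    local_port: int,
    remote_host: str,
    remote_port: int,
    ssh_user: str,
    vm_ip: str,
) -> list[str]:
    # Mirrors the option set added in this release.
    return [
        "ssh",
        "-o", "StrictHostKeyChecking=no",
        "-o", "UserKnownHostsFile=/dev/null",
        "-o", "LogLevel=ERROR",
        "-o", "ServerAliveInterval=60",    # probe the server every 60s
        "-o", "ServerAliveCountMax=10",    # drop the link after ~10 missed probes
        "-o", "TCPKeepAlive=yes",          # TCP-level keepalive as a second safeguard
        "-o", "ExitOnForwardFailure=yes",  # fail fast if the -L bind is refused
        "-i", ssh_key_path,
        "-N",                              # forward only; no remote command
        "-L", f"{local_port}:{remote_host}:{remote_port}",
        f"{ssh_user}@{vm_ip}",
    ]


# The new WAA default: localhost:5001 -> port 5000 inside the VM.
key = str(Path.home() / ".ssh" / "id_rsa")  # example key path
proc = subprocess.Popen(
    build_tunnel_command(key, 5001, "localhost", 5000, "azureuser", "203.0.113.7")
)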
openadapt_ml/datasets/next_action.py
CHANGED

@@ -3,7 +3,6 @@ from __future__ import annotations
 from dataclasses import dataclass
 from typing import Any, Dict, List

-import torch
 from torch.utils.data import Dataset

 from openadapt_ml.schema import Action, ActionType, Episode, Step, UIElement
@@ -20,7 +19,7 @@ SYSTEM_PROMPT = (
     "- Example: An element in the middle of the screen would be approximately x=0.5, y=0.5\n\n"
     "ALLOWED ACTIONS (use exactly this format):\n"
     "- CLICK(x=0.XX, y=0.XX) → click at normalized coordinates\n"
-    "- TYPE(text=\"...\") → type text into the currently focused field\n"
+    '- TYPE(text="...") → type text into the currently focused field\n'
     "- WAIT() → wait for UI to update\n"
     "- DONE() → task is complete\n\n"
     "RESPONSE FORMAT (required):\n"
@@ -42,14 +41,14 @@ SYSTEM_PROMPT_SOM = (
     "[3] = Login button\n\n"
     "ALLOWED ACTIONS (use exactly this format):\n"
     "- CLICK([N]) → click element with number N to focus/activate it\n"
-    "- TYPE([N], \"text\") → type text into element N (e.g., TYPE([2], \"hello\"))\n"
+    '- TYPE([N], "text") → type text into element N (e.g., TYPE([2], "hello"))\n'
     "- WAIT() → wait for UI to update\n"
     "- DONE() → task is complete\n\n"
     "ACTION SEQUENCE FOR LOGIN:\n"
     "1. CLICK([1]) to focus username field\n"
-    "2. TYPE([1], \"username\") to enter username\n"
+    '2. TYPE([1], "username") to enter username\n'
     "3. CLICK([2]) to focus password field\n"
-    "4. TYPE([2], \"password\") to enter password\n"
+    '4. TYPE([2], "password") to enter password\n'
     "5. CLICK([3]) to submit login\n"
     "6. DONE() when login is complete\n\n"
     "RESPONSE FORMAT (required):\n"
@@ -74,20 +73,20 @@ SYSTEM_PROMPT_SOM_REGISTRATION = (
     "[6] = Register button\n\n"
     "ALLOWED ACTIONS (use exactly this format):\n"
     "- CLICK([N]) → click element with number N to focus/activate it\n"
-    "- TYPE([N], \"text\") → type text into element N (e.g., TYPE([2], \"hello\"))\n"
+    '- TYPE([N], "text") → type text into element N (e.g., TYPE([2], "hello"))\n'
     "- WAIT() → wait for UI to update\n"
     "- DONE() → task is complete\n\n"
     "ACTION SEQUENCE FOR REGISTRATION:\n"
     "1. CLICK([1]) to focus first name field\n"
-    "2. TYPE([1], \"name\") to enter first name\n"
+    '2. TYPE([1], "name") to enter first name\n'
     "3. CLICK([2]) to focus last name field\n"
-    "4. TYPE([2], \"name\") to enter last name\n"
+    '4. TYPE([2], "name") to enter last name\n'
     "5. CLICK([3]) to focus email field\n"
-    "6. TYPE([3], \"email\") to enter email\n"
+    '6. TYPE([3], "email") to enter email\n'
     "7. CLICK([4]) to focus password field\n"
-    "8. TYPE([4], \"pass\") to enter password\n"
+    '8. TYPE([4], "pass") to enter password\n'
     "9. CLICK([5]) to focus confirm password field\n"
-    "10. TYPE([5], \"pass\") to enter confirmation\n"
+    '10. TYPE([5], "pass") to enter confirmation\n'
     "11. CLICK([6]) to submit registration\n"
     "12. DONE() when registration is complete\n\n"
     "RESPONSE FORMAT (required):\n"
@@ -127,12 +126,12 @@ def format_action(action: Action, use_som: bool = False) -> str:
     if t == ActionType.CLICK and element_id is not None:
         return f"CLICK([{element_id}])"
     if t == ActionType.TYPE and action.text is not None:
-        escaped = action.text.replace("\\", "\\\\").replace("\"", "\\\"")
+        escaped = action.text.replace("\\", "\\\\").replace('"', '\\"')
         if element_id is not None:
-            return f"TYPE([{element_id}], \"{escaped}\")"
+            return f'TYPE([{element_id}], "{escaped}")'
         else:
             # Fallback: TYPE without element reference (for focused field)
-            return f"TYPE(\"{escaped}\")"
+            return f'TYPE("{escaped}")'
     if t == ActionType.WAIT:
         return "WAIT()"
     if t == ActionType.DONE:
@@ -145,8 +144,8 @@ def format_action(action: Action, use_som: bool = False) -> str:
         x, y = action.normalized_coordinates
         return f"CLICK(x={x:.2f}, y={y:.2f})"
     if t == ActionType.TYPE and action.text is not None:
-        escaped = action.text.replace("\\", "\\\\").replace("\"", "\\\"")
-        return f"TYPE(text=\"{escaped}\")"
+        escaped = action.text.replace("\\", "\\\\").replace('"', '\\"')
+        return f'TYPE(text="{escaped}")'
     if t == ActionType.WAIT:
         return "WAIT()"
     if t == ActionType.DONE:
@@ -181,13 +180,15 @@ def parse_action_som(text: str) -> Action:
     match = re.match(r'TYPE\(\[(\d+)\],\s*["\'](.*)["\']\)', text, re.DOTALL)
     if match:
         idx = match.group(1)
-        content = match.group(2).replace("\\\"", "\"").replace("\\\\", "\\")
-        return Action(type=ActionType.TYPE, text=content, element=UIElement(element_id=idx))
+        content = match.group(2).replace('\\"', '"').replace("\\\\", "\\")
+        return Action(
+            type=ActionType.TYPE, text=content, element=UIElement(element_id=idx)
+        )

     # TYPE("text") - no element index
     match = re.match(r'TYPE\(["\'](.*)["\']\)', text, re.DOTALL)
     if match:
-        content = match.group(1).replace("\\\"", "\"").replace("\\\\", "\\")
+        content = match.group(1).replace('\\"', '"').replace("\\\\", "\\")
         return Action(type=ActionType.TYPE, text=content)

     # WAIT()
@@ -202,7 +203,9 @@ def parse_action_som(text: str) -> Action:
     return Action(type=ActionType.FAIL, raw={"text": text})


-def _generate_generic_thought(step_index: int, step: Step, goal: str, total_steps: int) -> str:
+def _generate_generic_thought(
+    step_index: int, step: Step, goal: str, total_steps: int
+) -> str:
     """Generate a thought for real captures (non-synthetic scenarios).

     This creates action-appropriate thoughts that teach the model to output
@@ -239,7 +242,9 @@ def _generate_generic_thought(step_index: int, step: Step, goal: str, total_step
         return f"{progress} I need to scroll to reveal more content or reach the target element for '{goal}'."

     if t == ActionType.DRAG:
-        return f"{progress} I need to drag an element to complete this part of '{goal}'."
+        return (
+            f"{progress} I need to drag an element to complete this part of '{goal}'."
+        )

     if t == ActionType.KEY:
         return f"{progress} I need to press a key to continue the workflow."
@@ -269,9 +274,6 @@ def _generate_thought_for_step(
     actions back to the stated objective.
     """

-    action = step.action
-    t = action.type
-
     if scenario == "registration":
         return _generate_registration_thought(step_index, step, goal, total_steps)
     elif scenario == "login" and total_steps <= 7:
@@ -282,7 +284,9 @@ def _generate_thought_for_step(
         return _generate_generic_thought(step_index, step, goal, total_steps)


-def _generate_login_thought(step_index: int, step: Step, goal: str, total_steps: int) -> str:
+def _generate_login_thought(
+    step_index: int, step: Step, goal: str, total_steps: int
+) -> str:
     """Generate thought for login scenario (6 steps)."""
     action = step.action
     t = action.type
@@ -336,7 +340,9 @@ def _generate_login_thought(step_index: int, step: Step, goal: str, total_steps:
     )


-def _generate_registration_thought(step_index: int, step: Step, goal: str, total_steps: int) -> str:
+def _generate_registration_thought(
+    step_index: int, step: Step, goal: str, total_steps: int
+) -> str:
     """Generate thought for registration scenario (12 steps)."""
     action = step.action
     t = action.type
@@ -469,7 +475,9 @@ def build_next_action_sft_samples(
             history_text += f"  {i}. {action_text}\n"
         history_text += f"\nThis is step {step_index + 1} of {total_steps}. "
     else:
-        history_text = f"This is step 1 of {total_steps} (no actions completed yet). "
+        history_text = (
+            f"This is step 1 of {total_steps} (no actions completed yet). "
+        )

     if use_som:
         user_content = (
@@ -477,7 +485,7 @@
             f"{history_text}"
             "Look at the screenshot and determine the NEXT action.\n\n"
             "Thought: [which numbered element to interact with and why]\n"
-            "Action: [CLICK([N]) or TYPE([N], \"text\") or WAIT() or DONE()]"
+            'Action: [CLICK([N]) or TYPE([N], "text") or WAIT() or DONE()]'
         )
     else:
         user_content = (
@@ -485,13 +493,15 @@
             f"{history_text}"
             "Look at the screenshot and determine the NEXT action.\n\n"
             "Thought: [what element to interact with and why]\n"
-            "Action: [CLICK(x=..., y=...) or TYPE(text=\"...\") or WAIT() or DONE()]"
+            'Action: [CLICK(x=..., y=...) or TYPE(text="...") or WAIT() or DONE()]'
         )

     # Provide a deterministic, semantically meaningful Thought while supervising
     # the exact DSL Action.
     action_text = format_action(step.action, use_som=use_som)
-    thought_text = _generate_thought_for_step(step_index, step, goal, scenario, total_steps)
+    thought_text = _generate_thought_for_step(
+        step_index, step, goal, scenario, total_steps
+    )
     assistant_content = f"Thought: {thought_text}\nAction: {action_text}"

     sample = {
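Across these hunks the TYPE DSL keeps one invariant: format_action escapes backslashes and double quotes on the way into the action string, and parse_action_som reverses both replacements after the regex match. A minimal round-trip sketch of just that pair, as a standalone recreation for illustration rather than the module itself:

import re


def format_type(text: str, element_id: int | None = None) -> str:
    """Escape text into the TYPE(...) DSL (backslashes first, then quotes)."""
    escaped = text.replace("\\", "\\\\").replace('"', '\\"')
    if element_id is not None:
        return f'TYPE([{element_id}], "{escaped}")'
    return f'TYPE(text="{escaped}")'


def parse_type_som(action: str) -> tuple[str, str] | None:
    """Inverse: same regex shape as parse_action_som, then unescape."""
    match = re.match(r'TYPE\(\[(\d+)\],\s*["\'](.*)["\']\)', action, re.DOTALL)
    if not match:
        return None
    # Unescape in the order the diff uses: quotes first, then backslashes.
    content = match.group(2).replace('\\"', '"').replace("\\\\", "\\")
    return match.group(1), content


# Round trip: quoted text survives formatting and parsing unchanged.
assert parse_type_som(format_type('say "hi"', element_id=2)) == ("2", 'say "hi"')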
openadapt_ml/evals/grounding.py
CHANGED
@@ -19,6 +19,7 @@ from typing import TYPE_CHECKING
 if TYPE_CHECKING:
     from PIL import Image

+    from openadapt_ml.data.types import Episode
     from openadapt_ml.grounding.base import GroundingModule, RegionCandidate


@@ -212,7 +213,7 @@ def evaluate_grounder_on_episode(
     """
     from PIL import Image

-    from openadapt_ml.schema import …
+    from openadapt_ml.schema import ActionType

     test_cases = []

@@ -220,7 +221,9 @@ def evaluate_grounder_on_episode(
         action = step.action

         # Get action type as string for comparison
-        action_type_str = action.type.value if isinstance(action.type, ActionType) else action.type
+        action_type_str = (
+            action.type.value if isinstance(action.type, ActionType) else action.type
+        )

         # Only evaluate clicks with bboxes
         if action_type_str not in ("click", "double_click"):
@@ -250,7 +253,9 @@
         if action.normalized_coordinates:
             coords_x, coords_y = action.normalized_coordinates
             if coords_x is not None and coords_y is not None:
-                target_desc = step.reasoning or f"element at ({coords_x:.2f}, {coords_y:.2f})"
+                target_desc = (
+                    step.reasoning or f"element at ({coords_x:.2f}, {coords_y:.2f})"
+                )
             else:
                 target_desc = step.reasoning or "target element"

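The wrapped action_type_str assignment is this file's recurring normalization: a step's action may carry either an ActionType enum or a bare string, and membership tests like `in ("click", "double_click")` need one canonical form. A minimal sketch of the same pattern, with a stand-in enum (the real one lives in openadapt_ml.schema):

from enum import Enum


class ActionType(str, Enum):
    # Stand-in for the real enum in openadapt_ml.schema.
    CLICK = "click"
    DOUBLE_CLICK = "double_click"
    TYPE = "type"


def to_action_type_str(action_type: ActionType | str) -> str:
    # Same shape as the diff: unwrap enum values, pass plain strings through.
    return action_type.value if isinstance(action_type, ActionType) else action_type


assert to_action_type_str(ActionType.DOUBLE_CLICK) == "double_click"
assert to_action_type_str("click") == "click"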
openadapt_ml/evals/plot_eval_metrics.py
CHANGED

@@ -73,7 +73,7 @@ def plot_eval_metrics(
     fig.suptitle(
         "VLM Model Comparison (Offline fine-tuned vs API models)",
         fontsize=12,
-        fontweight='bold',
+        fontweight="bold",
     )
     if num_metrics == 1:
         axes = [axes]
@@ -96,36 +96,38 @@
         hatches.append(hatch)

         x = range(num_models)
-        bars = ax.bar(x, values, tick_label=labels, color=colors, edgecolor='black', linewidth=1.2)
+        bars = ax.bar(
+            x, values, tick_label=labels, color=colors, edgecolor="black", linewidth=1.2
+        )

         # Apply hatch patterns
         for bar, hatch in zip(bars, hatches):
             bar.set_hatch(hatch)

-        ax.set_title(title, fontsize=11, fontweight='bold')
+        ax.set_title(title, fontsize=11, fontweight="bold")
         ax.set_ylabel(key, fontsize=9)
         ax.set_ylim(bottom=0.0)
         # Rotate x-axis labels to prevent crowding
-        ax.tick_params(axis='x', labelrotation=45, labelsize=8)
+        ax.tick_params(axis="x", labelrotation=45, labelsize=8)
         # Align labels to the right for better readability when rotated
         for tick in ax.get_xticklabels():
-            tick.set_horizontalalignment('right')
+            tick.set_horizontalalignment("right")

     fig.tight_layout()

     # Add legend explaining color coding and hatch patterns
     legend_elements = [
-        Patch(facecolor='#4A90E2', edgecolor='black', label='Qwen3-VL-2B'),
-        Patch(facecolor='#2E5C8A', edgecolor='black', label='Qwen3-VL-8B'),
-        Patch(facecolor='#FF6B35', edgecolor='black', label='Claude (API)'),
-        Patch(facecolor='#C1121F', edgecolor='black', label='GPT (API)'),
-        Patch(facecolor='gray', edgecolor='black', hatch='///', label='Fine-tuned'),
-        Patch(facecolor='gray', edgecolor='black', label='Base/Pretrained'),
+        Patch(facecolor="#4A90E2", edgecolor="black", label="Qwen3-VL-2B"),
+        Patch(facecolor="#2E5C8A", edgecolor="black", label="Qwen3-VL-8B"),
+        Patch(facecolor="#FF6B35", edgecolor="black", label="Claude (API)"),
+        Patch(facecolor="#C1121F", edgecolor="black", label="GPT (API)"),
+        Patch(facecolor="gray", edgecolor="black", hatch="///", label="Fine-tuned"),
+        Patch(facecolor="gray", edgecolor="black", label="Base/Pretrained"),
     ]

     fig.legend(
         handles=legend_elements,
-        loc='lower center',
+        loc="lower center",
         bbox_to_anchor=(0.5, -0.05),
         ncol=3,
         fontsize=9,
@@ -133,7 +135,7 @@ def plot_eval_metrics(
     )

     output_path.parent.mkdir(parents=True, exist_ok=True)
-    fig.savefig(output_path, dpi=150, bbox_inches='tight')
+    fig.savefig(output_path, dpi=150, bbox_inches="tight")
     plt.close(fig)

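The rebuilt legend encodes two independent dimensions at once: facecolor identifies the model family, while the /// hatch separates fine-tuned bars from base ones, so a single legend covers the full color-by-hatch grid. A minimal sketch of that dual encoding (colors and labels taken from the diff; the bar heights are illustrative made-up values):

import matplotlib.pyplot as plt
from matplotlib.patches import Patch

fig, ax = plt.subplots()
# Illustrative heights: base vs fine-tuned for one model family.
bars = ax.bar([0, 1], [0.42, 0.61], color="#4A90E2", edgecolor="black")
bars[1].set_hatch("///")  # hatch marks the fine-tuned variant

legend_elements = [
    Patch(facecolor="#4A90E2", edgecolor="black", label="Qwen3-VL-2B"),
    Patch(facecolor="gray", edgecolor="black", hatch="///", label="Fine-tuned"),
    Patch(facecolor="gray", edgecolor="black", label="Base/Pretrained"),
]
fig.legend(
    handles=legend_elements,
    loc="lower center",
    bbox_to_anchor=(0.5, -0.05),
    ncol=3,
    fontsize=9,
)
fig.savefig("legend_demo.png", dpi=150, bbox_inches="tight")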
openadapt_ml/evals/trajectory_matching.py
CHANGED

@@ -15,10 +15,15 @@ class MilestoneSpec:
     A milestone is achieved when, at a specific step, the predicted action
     matches certain criteria (type match + optional coord threshold).
     """
+
     name: str
     step_index: int  # Which step in the episode (0-indexed)
-    expected_type: str  # Expected ground truth action type ("click", "type", "done", etc.)
-    coord_threshold: Optional[float] = None  # If set, coord error must be < this for clicks
+    expected_type: (
+        str  # Expected ground truth action type ("click", "type", "done", etc.)
+    )
+    coord_threshold: Optional[float] = (
+        None  # If set, coord error must be < this for clicks
+    )


 # Predefined milestone specs per scenario
@@ -28,7 +33,9 @@ class MilestoneSpec:
 LOGIN_MILESTONES = [
     MilestoneSpec("typed_username", step_index=1, expected_type="type"),
     MilestoneSpec("typed_password", step_index=3, expected_type="type"),
-    MilestoneSpec("clicked_login", step_index=4, expected_type="click", coord_threshold=0.10),
+    MilestoneSpec(
+        "clicked_login", step_index=4, expected_type="click", coord_threshold=0.10
+    ),
     MilestoneSpec("emitted_done", step_index=5, expected_type="done"),
 ]

@@ -81,14 +88,22 @@ class AggregateMetrics:
     action_type_accuracy: float
     mean_coord_error: Optional[float]
     coord_error_count: int
-    episode_success_rate: Optional[float]  # Strict: all steps must match (renamed from success_pred)
+    episode_success_rate: Optional[
+        float
+    ]  # Strict: all steps must match (renamed from success_pred)
     click_hit_rate: Optional[float]  # Point-based: within 5% of center
-    mean_episode_progress: Optional[float]  # Partial credit: avg(step_matches/step_total)
+    mean_episode_progress: Optional[
+        float
+    ]  # Partial credit: avg(step_matches/step_total)
     # New partial-credit metrics
-    mean_episode_step_score: Optional[float]  # Strict partial: avg(full_step_correct/step_total)
+    mean_episode_step_score: Optional[
+        float
+    ]  # Strict partial: avg(full_step_correct/step_total)
     weak_episode_success_rate: Optional[float]  # Semantic milestones all achieved
     state_success_rate: Optional[float] = None  # From model's State: {"success": true}
-    bbox_hit_rate: Optional[float] = None  # Bbox-based: click anywhere in element bounds
+    bbox_hit_rate: Optional[float] = (
+        None  # Bbox-based: click anywhere in element bounds
+    )
     element_accuracy: Optional[float] = None  # SoM element index accuracy


@@ -122,12 +137,7 @@ def compute_coordinate_error(pred_action: Action, gt_action: Action) -> Optional
     pred_x, pred_y = _get_normalized_coords(pred_action)
     gt_x, gt_y = _get_normalized_coords(gt_action)

-    if (
-        pred_x is None
-        or pred_y is None
-        or gt_x is None
-        or gt_y is None
-    ):
+    if pred_x is None or pred_y is None or gt_x is None or gt_y is None:
         return None

     dx = pred_x - gt_x
@@ -212,7 +222,9 @@ def evaluate_episode(
         sample = samples[sample_idx]
         sample_idx += 1

-        pred_action, _thought, pred_state, raw_text = policy.predict_action_from_sample(sample)
+        pred_action, _thought, pred_state, raw_text = policy.predict_action_from_sample(
+            sample
+        )
         gt_action = step.action

         # Get action types as strings for comparison
@@ -233,7 +245,6 @@

         coord_error: Optional[float] = None
         click_hit = False
-        bbox_hit = False
         element_hit = False

         # Helper to get element index - check element.element_id or raw field
@@ -273,7 +284,6 @@
                     bbox_total += 1
                     if in_bbox:
                         bbox_hits += 1
-                        bbox_hit = True

         # Full step correctness: type matches AND element/coord match for relevant actions
         if type_match:
@@ -291,11 +301,17 @@

         # Track semantic milestones using the milestone spec
         for milestone in milestones:
-            if step_idx == milestone.step_index and gt_type_str == milestone.expected_type:
+            if (
+                step_idx == milestone.step_index
+                and gt_type_str == milestone.expected_type
+            ):
                 if pred_type_str == milestone.expected_type:
                     # Check coord threshold if specified (for click actions)
                     if milestone.coord_threshold is not None:
-                        if coord_error is not None and coord_error < milestone.coord_threshold:
+                        if (
+                            coord_error is not None
+                            and coord_error < milestone.coord_threshold
+                        ):
                             milestones_achieved[milestone.name] = True
                         else:
                             # No coord threshold - type match is sufficient
@@ -428,18 +444,16 @@ def aggregate_metrics(episodes_metrics: List[EpisodeMetrics]) -> AggregateMetric

     # Partial credit: average episode progress (step_matches / step_total per episode)
     if eval_episodes:
-        episode_progress_scores = [
-            m.step_matches / m.step_total for m in eval_episodes
-        ]
-        mean_episode_progress = sum(episode_progress_scores) / len(episode_progress_scores)
+        episode_progress_scores = [m.step_matches / m.step_total for m in eval_episodes]
+        mean_episode_progress = sum(episode_progress_scores) / len(
+            episode_progress_scores
+        )
     else:
         mean_episode_progress = None

     # Strict partial: avg(full_step_correct / step_total) - requires type match + click hit
     if eval_episodes:
-        step_scores = [
-            m.full_step_correct / m.step_total for m in eval_episodes
-        ]
+        step_scores = [m.full_step_correct / m.step_total for m in eval_episodes]
         mean_episode_step_score = sum(step_scores) / len(step_scores)
     else:
         mean_episode_step_score = None
@@ -447,7 +461,8 @@
     # Weak episode success: all milestones achieved
     if eval_episodes:
         weak_success_count = sum(
-            1 for m in eval_episodes
+            1
+            for m in eval_episodes
             if m.milestones_achieved and all(m.milestones_achieved.values())
         )
         weak_episode_success_rate = weak_success_count / len(eval_episodes)