openadapt-ml 0.2.0__py3-none-any.whl → 0.2.2__py3-none-any.whl

This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registry.
Files changed (95)
  1. openadapt_ml/baselines/__init__.py +121 -0
  2. openadapt_ml/baselines/adapter.py +185 -0
  3. openadapt_ml/baselines/cli.py +314 -0
  4. openadapt_ml/baselines/config.py +448 -0
  5. openadapt_ml/baselines/parser.py +922 -0
  6. openadapt_ml/baselines/prompts.py +787 -0
  7. openadapt_ml/benchmarks/__init__.py +13 -115
  8. openadapt_ml/benchmarks/agent.py +265 -421
  9. openadapt_ml/benchmarks/azure.py +28 -19
  10. openadapt_ml/benchmarks/azure_ops_tracker.py +521 -0
  11. openadapt_ml/benchmarks/cli.py +1722 -4847
  12. openadapt_ml/benchmarks/trace_export.py +631 -0
  13. openadapt_ml/benchmarks/viewer.py +22 -5
  14. openadapt_ml/benchmarks/vm_monitor.py +530 -29
  15. openadapt_ml/benchmarks/waa_deploy/Dockerfile +47 -53
  16. openadapt_ml/benchmarks/waa_deploy/api_agent.py +21 -20
  17. openadapt_ml/cloud/azure_inference.py +3 -5
  18. openadapt_ml/cloud/lambda_labs.py +722 -307
  19. openadapt_ml/cloud/local.py +2038 -487
  20. openadapt_ml/cloud/ssh_tunnel.py +68 -26
  21. openadapt_ml/datasets/next_action.py +40 -30
  22. openadapt_ml/evals/grounding.py +8 -3
  23. openadapt_ml/evals/plot_eval_metrics.py +15 -13
  24. openadapt_ml/evals/trajectory_matching.py +41 -26
  25. openadapt_ml/experiments/demo_prompt/format_demo.py +16 -6
  26. openadapt_ml/experiments/demo_prompt/run_experiment.py +26 -16
  27. openadapt_ml/experiments/representation_shootout/__init__.py +70 -0
  28. openadapt_ml/experiments/representation_shootout/conditions.py +708 -0
  29. openadapt_ml/experiments/representation_shootout/config.py +390 -0
  30. openadapt_ml/experiments/representation_shootout/evaluator.py +659 -0
  31. openadapt_ml/experiments/representation_shootout/runner.py +687 -0
  32. openadapt_ml/experiments/waa_demo/runner.py +29 -14
  33. openadapt_ml/export/parquet.py +36 -24
  34. openadapt_ml/grounding/detector.py +18 -14
  35. openadapt_ml/ingest/__init__.py +8 -6
  36. openadapt_ml/ingest/capture.py +25 -22
  37. openadapt_ml/ingest/loader.py +7 -4
  38. openadapt_ml/ingest/synthetic.py +189 -100
  39. openadapt_ml/models/api_adapter.py +14 -4
  40. openadapt_ml/models/base_adapter.py +10 -2
  41. openadapt_ml/models/providers/__init__.py +288 -0
  42. openadapt_ml/models/providers/anthropic.py +266 -0
  43. openadapt_ml/models/providers/base.py +299 -0
  44. openadapt_ml/models/providers/google.py +376 -0
  45. openadapt_ml/models/providers/openai.py +342 -0
  46. openadapt_ml/models/qwen_vl.py +46 -19
  47. openadapt_ml/perception/__init__.py +35 -0
  48. openadapt_ml/perception/integration.py +399 -0
  49. openadapt_ml/retrieval/demo_retriever.py +50 -24
  50. openadapt_ml/retrieval/embeddings.py +9 -8
  51. openadapt_ml/retrieval/retriever.py +3 -1
  52. openadapt_ml/runtime/__init__.py +50 -0
  53. openadapt_ml/runtime/policy.py +18 -5
  54. openadapt_ml/runtime/safety_gate.py +471 -0
  55. openadapt_ml/schema/__init__.py +9 -0
  56. openadapt_ml/schema/converters.py +74 -27
  57. openadapt_ml/schema/episode.py +31 -18
  58. openadapt_ml/scripts/capture_screenshots.py +530 -0
  59. openadapt_ml/scripts/compare.py +85 -54
  60. openadapt_ml/scripts/demo_policy.py +4 -1
  61. openadapt_ml/scripts/eval_policy.py +15 -9
  62. openadapt_ml/scripts/make_gif.py +1 -1
  63. openadapt_ml/scripts/prepare_synthetic.py +3 -1
  64. openadapt_ml/scripts/train.py +21 -9
  65. openadapt_ml/segmentation/README.md +920 -0
  66. openadapt_ml/segmentation/__init__.py +97 -0
  67. openadapt_ml/segmentation/adapters/__init__.py +5 -0
  68. openadapt_ml/segmentation/adapters/capture_adapter.py +420 -0
  69. openadapt_ml/segmentation/annotator.py +610 -0
  70. openadapt_ml/segmentation/cache.py +290 -0
  71. openadapt_ml/segmentation/cli.py +674 -0
  72. openadapt_ml/segmentation/deduplicator.py +656 -0
  73. openadapt_ml/segmentation/frame_describer.py +788 -0
  74. openadapt_ml/segmentation/pipeline.py +340 -0
  75. openadapt_ml/segmentation/schemas.py +622 -0
  76. openadapt_ml/segmentation/segment_extractor.py +634 -0
  77. openadapt_ml/training/azure_ops_viewer.py +1097 -0
  78. openadapt_ml/training/benchmark_viewer.py +52 -41
  79. openadapt_ml/training/shared_ui.py +7 -7
  80. openadapt_ml/training/stub_provider.py +57 -35
  81. openadapt_ml/training/trainer.py +143 -86
  82. openadapt_ml/training/trl_trainer.py +70 -21
  83. openadapt_ml/training/viewer.py +323 -108
  84. openadapt_ml/training/viewer_components.py +180 -0
  85. {openadapt_ml-0.2.0.dist-info → openadapt_ml-0.2.2.dist-info}/METADATA +215 -14
  86. openadapt_ml-0.2.2.dist-info/RECORD +116 -0
  87. openadapt_ml/benchmarks/base.py +0 -366
  88. openadapt_ml/benchmarks/data_collection.py +0 -432
  89. openadapt_ml/benchmarks/live_tracker.py +0 -180
  90. openadapt_ml/benchmarks/runner.py +0 -418
  91. openadapt_ml/benchmarks/waa.py +0 -761
  92. openadapt_ml/benchmarks/waa_live.py +0 -619
  93. openadapt_ml-0.2.0.dist-info/RECORD +0 -86
  94. {openadapt_ml-0.2.0.dist-info → openadapt_ml-0.2.2.dist-info}/WHEEL +0 -0
  95. {openadapt_ml-0.2.0.dist-info → openadapt_ml-0.2.2.dist-info}/licenses/LICENSE +0 -0
@@ -1,8 +1,15 @@
-"""Agent interface for benchmark evaluation.
+"""ML-specific agents for benchmark evaluation.
 
-This module provides the BenchmarkAgent interface that agents must implement
-to be evaluated on benchmarks, plus adapters to wrap existing openadapt-ml
-components.
+This module provides agents that wrap openadapt-ml components (VLM adapters,
+policies, baselines) for benchmark evaluation.
+
+For standalone agents without ML dependencies, use openadapt_evals:
+    from openadapt_evals import ApiAgent, ScriptedAgent, RandomAgent
+
+ML-specific agents in this module:
+- PolicyAgent: Wraps openadapt_ml.runtime.policy.AgentPolicy
+- APIBenchmarkAgent: Uses openadapt_ml.models.api_adapter.ApiVLMAdapter
+- UnifiedBaselineAgent: Uses openadapt_ml.baselines adapters
 
 Example:
     from openadapt_ml.benchmarks import PolicyAgent
@@ -12,7 +19,7 @@ Example:
     agent = PolicyAgent(policy)
     results = evaluate_agent_on_benchmark(agent, benchmark_adapter)
 
-    # API-backed agents (GPT-5.1, Claude)
+    # API-backed agents (GPT-5.1, Claude) using openadapt-ml adapters
     from openadapt_ml.benchmarks import APIBenchmarkAgent
 
     agent = APIBenchmarkAgent(provider="anthropic")  # Uses Claude
@@ -22,13 +29,13 @@ Example:
 
 from __future__ import annotations
 
-import json
 import re
-from abc import ABC, abstractmethod
 from typing import TYPE_CHECKING, Any
 
-from openadapt_ml.benchmarks.base import (
+# Import base classes from openadapt-evals (canonical location)
+from openadapt_evals import (
     BenchmarkAction,
+    BenchmarkAgent,
     BenchmarkObservation,
     BenchmarkTask,
 )
@@ -36,43 +43,7 @@ from openadapt_ml.benchmarks.base import (
 if TYPE_CHECKING:
     from openadapt_ml.models.api_adapter import ApiVLMAdapter
     from openadapt_ml.runtime.policy import AgentPolicy
-    from openadapt_ml.schema import Action, ActionType
-
-
-class BenchmarkAgent(ABC):
-    """Abstract interface for agents evaluated on benchmarks.
-
-    Agents must implement the `act` method to receive observations
-    and return actions. The agent can maintain internal state across
-    steps within an episode.
-    """
-
-    @abstractmethod
-    def act(
-        self,
-        observation: BenchmarkObservation,
-        task: BenchmarkTask,
-        history: list[tuple[BenchmarkObservation, BenchmarkAction]] | None = None,
-    ) -> BenchmarkAction:
-        """Given observation and task, return next action.
-
-        Args:
-            observation: Current observation from the environment.
-            task: Task being performed.
-            history: Optional list of previous (observation, action) pairs.
-
-        Returns:
-            Action to execute.
-        """
-        pass
-
-    def reset(self) -> None:
-        """Reset agent state between episodes.
-
-        Called before starting a new task. Override to clear any
-        internal state.
-        """
-        pass
+    from openadapt_ml.schema import Action
 
 
 class PolicyAgent(BenchmarkAgent):
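
Note: the BenchmarkAgent ABC now ships in openadapt_evals, so downstream subclasses only change their import. A minimal sketch of a custom agent against the relocated interface, assuming openadapt_evals re-exports the same act/reset contract shown in the removed code (AlwaysDoneAgent is a hypothetical name):

    from openadapt_evals import (
        BenchmarkAction,
        BenchmarkAgent,
        BenchmarkObservation,
        BenchmarkTask,
    )

    class AlwaysDoneAgent(BenchmarkAgent):
        """Trivial agent that immediately reports the task complete."""

        def act(
            self,
            observation: BenchmarkObservation,
            task: BenchmarkTask,
            history: list[tuple[BenchmarkObservation, BenchmarkAction]] | None = None,
        ) -> BenchmarkAction:
            # The interface is unchanged by the move; only the import path differs.
            return BenchmarkAction(type="done")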
@@ -128,61 +99,37 @@ class PolicyAgent(BenchmarkAgent):
         task: BenchmarkTask,
         history: list[tuple[BenchmarkObservation, BenchmarkAction]] | None,
     ) -> dict:
-        """Build SFT-style sample from benchmark observation.
-
-        Args:
-            observation: Current observation.
-            task: Current task.
-            history: Action history.
-
-        Returns:
-            Sample dict with 'images' and 'messages'.
-        """
-        # Build user message content
+        """Build SFT-style sample from benchmark observation."""
         content_parts = [f"Goal: {task.instruction}"]
 
-        # Add accessibility tree if available and enabled
         if self.use_accessibility_tree and observation.accessibility_tree:
             tree_str = self._format_accessibility_tree(observation.accessibility_tree)
             content_parts.append(f"UI Elements:\n{tree_str}")
 
-        # Add context
         if observation.url:
             content_parts.append(f"URL: {observation.url}")
         if observation.window_title:
             content_parts.append(f"Window: {observation.window_title}")
 
-        # Add history if enabled
         if self.use_history and history:
             history_str = self._format_history(history)
             content_parts.append(f"Previous actions:\n{history_str}")
 
         content_parts.append("What action should be taken next?")
 
-        # Build sample
         sample = {
             "messages": [
                 {"role": "user", "content": "\n\n".join(content_parts)},
             ],
         }
 
-        # Add image if available
         if observation.screenshot_path:
             sample["images"] = [observation.screenshot_path]
 
         return sample
 
     def _format_accessibility_tree(self, tree: dict, indent: int = 0) -> str:
-        """Format accessibility tree for prompt.
-
-        Args:
-            tree: Accessibility tree dict.
-            indent: Current indentation level.
-
-        Returns:
-            Formatted string representation.
-        """
-        # Simple formatting - can be overridden for platform-specific formatting
+        """Format accessibility tree for prompt."""
        lines = []
        prefix = " " * indent
 
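
Note: _build_sample assembles a chat-style sample whose content parts are joined with blank lines, in the order goal, UI elements, URL, window, history, question. An illustrative result (all values hypothetical):

    sample = {
        "messages": [
            {
                "role": "user",
                "content": (
                    "Goal: Open the settings dialog\n\n"
                    "URL: https://example.com\n\n"
                    "What action should be taken next?"
                ),
            },
        ],
        "images": ["/tmp/step_003.png"],
    }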
@@ -203,29 +150,15 @@ class PolicyAgent(BenchmarkAgent):
     def _format_history(
         self, history: list[tuple[BenchmarkObservation, BenchmarkAction]]
     ) -> str:
-        """Format action history for prompt.
-
-        Args:
-            history: List of (observation, action) pairs.
-
-        Returns:
-            Formatted string.
-        """
+        """Format action history for prompt."""
         lines = []
-        for i, (obs, action) in enumerate(history[-5:], 1):  # Last 5 actions
+        for i, (obs, action) in enumerate(history[-5:], 1):
             action_str = self._action_to_string(action)
             lines.append(f"{i}. {action_str}")
         return "\n".join(lines)
 
     def _action_to_string(self, action: BenchmarkAction) -> str:
-        """Convert BenchmarkAction to string representation.
-
-        Args:
-            action: Action to convert.
-
-        Returns:
-            String representation.
-        """
+        """Convert BenchmarkAction to string representation."""
         if action.type == "click":
             if action.target_name:
                 return f"CLICK({action.target_name})"
@@ -250,29 +183,19 @@ class PolicyAgent(BenchmarkAgent):
     def _to_benchmark_action(
         self, action: Action, thought: str | None
     ) -> BenchmarkAction:
-        """Convert openadapt-ml Action to BenchmarkAction.
-
-        Args:
-            action: Action from policy.
-            thought: Optional thought/reasoning.
-
-        Returns:
-            BenchmarkAction.
-        """
-        # Extract normalized coordinates
+        """Convert openadapt-ml Action to BenchmarkAction."""
         x, y = None, None
         if action.normalized_coordinates is not None:
             x, y = action.normalized_coordinates
 
-        # Extract end coordinates for drag
         end_x, end_y = None, None
         if action.normalized_end is not None:
             end_x, end_y = action.normalized_end
 
-        # Extract action type value (enum -> string)
-        action_type = action.type.value if hasattr(action.type, 'value') else action.type
+        action_type = (
+            action.type.value if hasattr(action.type, "value") else action.type
+        )
 
-        # Extract element info if available
         target_node_id = None
         target_role = None
         target_name = None
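
The hasattr guard above tolerates both enum and plain-string action types. A standalone sketch of the idiom (this ActionType is a hypothetical stand-in for openadapt_ml.schema.ActionType):

    from enum import Enum

    class ActionType(Enum):
        CLICK = "click"

    def type_to_str(t: object) -> object:
        # Same idiom as _to_benchmark_action: accept enum or plain string.
        return t.value if hasattr(t, "value") else t

    assert type_to_str(ActionType.CLICK) == "click"
    assert type_to_str("click") == "click"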
@@ -310,192 +233,28 @@ class PolicyAgent(BenchmarkAgent):
 
     def reset(self) -> None:
         """Reset agent state."""
-        # PolicyAgent is stateless, nothing to reset
-        pass
-
-
-class ScriptedAgent(BenchmarkAgent):
-    """Agent that follows a predefined script of actions.
-
-    Useful for testing benchmark adapters or replaying trajectories.
-
-    Args:
-        actions: List of actions to execute in order.
-    """
-
-    def __init__(self, actions: list[BenchmarkAction]):
-        self.actions = actions
-        self._step = 0
-
-    def act(
-        self,
-        observation: BenchmarkObservation,
-        task: BenchmarkTask,
-        history: list[tuple[BenchmarkObservation, BenchmarkAction]] | None = None,
-    ) -> BenchmarkAction:
-        """Return the next scripted action.
-
-        Args:
-            observation: Ignored.
-            task: Ignored.
-            history: Ignored.
-
-        Returns:
-            Next action from script, or DONE if script exhausted.
-        """
-        if self._step < len(self.actions):
-            action = self.actions[self._step]
-            self._step += 1
-            return action
-        return BenchmarkAction(type="done")
-
-    def reset(self) -> None:
-        """Reset step counter."""
-        self._step = 0
-
-
-class RandomAgent(BenchmarkAgent):
-    """Agent that takes random actions.
-
-    Useful for baseline comparisons.
-
-    Args:
-        action_types: List of action types to randomly select from.
-        seed: Random seed for reproducibility.
-    """
-
-    def __init__(
-        self,
-        action_types: list[str] | None = None,
-        seed: int | None = None,
-    ):
-        import random
-
-        self.action_types = action_types or ["click", "type", "scroll", "done"]
-        self.rng = random.Random(seed)
-
-    def act(
-        self,
-        observation: BenchmarkObservation,
-        task: BenchmarkTask,
-        history: list[tuple[BenchmarkObservation, BenchmarkAction]] | None = None,
-    ) -> BenchmarkAction:
-        """Return a random action.
-
-        Args:
-            observation: Used to get viewport bounds.
-            task: Ignored.
-            history: Used to decide when to stop.
-
-        Returns:
-            Random action.
-        """
-        # Stop after many actions
-        if history and len(history) > 20:
-            return BenchmarkAction(type="done")
-
-        action_type = self.rng.choice(self.action_types)
-
-        if action_type == "click":
-            return BenchmarkAction(
-                type="click",
-                x=self.rng.random(),
-                y=self.rng.random(),
-            )
-        elif action_type == "type":
-            return BenchmarkAction(
-                type="type",
-                text="test",
-            )
-        elif action_type == "scroll":
-            return BenchmarkAction(
-                type="scroll",
-                scroll_direction=self.rng.choice(["up", "down"]),
-            )
-        else:
-            return BenchmarkAction(type="done")
-
-    def reset(self) -> None:
-        """Nothing to reset."""
         pass
 
 
-class SmartMockAgent(BenchmarkAgent):
-    """Agent designed to pass WAAMockAdapter evaluation.
-
-    Performs a fixed sequence of actions that satisfy the mock adapter's
-    success criteria. Use for validating the benchmark pipeline locally.
-
-    The mock adapter evaluates success based on:
-    - Clicking Submit (ID 4) - primary success path
-    - Typing something AND clicking OK (ID 1) - form submission path
-    - Calling DONE after at least 2 actions - reasonable completion
-
-    This agent clicks Submit (ID 4) which is the simplest success path.
-    """
-
-    def __init__(self):
-        """Initialize the agent."""
-        self._step = 0
-        # Simple action sequence: click Submit button (ID 4), then done
-        self._actions = [
-            BenchmarkAction(type="click", target_node_id="4"),  # Click Submit
-            BenchmarkAction(type="done"),
-        ]
-
-    def act(
-        self,
-        observation: BenchmarkObservation,
-        task: BenchmarkTask,
-        history: list[tuple[BenchmarkObservation, BenchmarkAction]] | None = None,
-    ) -> BenchmarkAction:
-        """Return the next scripted action.
-
-        Args:
-            observation: Ignored.
-            task: Ignored.
-            history: Ignored.
-
-        Returns:
-            Next action from script, or DONE if script exhausted.
-        """
-        if self._step < len(self._actions):
-            action = self._actions[self._step]
-            self._step += 1
-            return action
-        return BenchmarkAction(type="done")
-
-    def reset(self) -> None:
-        """Reset step counter."""
-        self._step = 0
-
-
 class APIBenchmarkAgent(BenchmarkAgent):
-    """Agent that uses hosted VLM APIs (Claude, GPT-5.1) for benchmark evaluation.
+    """Agent that uses hosted VLM APIs via openadapt-ml ApiVLMAdapter.
 
     This agent wraps ApiVLMAdapter to provide Claude or GPT-5.1 baselines
     for benchmark evaluation. It converts BenchmarkObservation to the
     API format and parses VLM responses into BenchmarkActions.
 
+    Note: For standalone API evaluation without openadapt-ml, use
+    openadapt_evals.ApiAgent instead (has P0 demo persistence fix).
+
     Args:
         provider: API provider - "anthropic" (Claude) or "openai" (GPT-5.1).
         api_key: Optional API key override. If not provided, uses env vars.
-        model: Optional model name override. Defaults to provider's best VLM.
+        model: Optional model name override.
         max_tokens: Maximum tokens for VLM response.
         use_accessibility_tree: Whether to include accessibility tree in prompt.
         use_history: Whether to include action history in prompt.
-
-    Example:
-        # Claude baseline
-        agent = APIBenchmarkAgent(provider="anthropic")
-        results = evaluate_agent_on_benchmark(agent, waa_adapter)
-
-        # GPT-5.1 baseline
-        agent = APIBenchmarkAgent(provider="openai")
-        results = evaluate_agent_on_benchmark(agent, waa_adapter)
     """
 
-    # System prompt for GUI automation
     SYSTEM_PROMPT = """You are a GUI automation agent. Given a screenshot and task instruction, determine the next action to take.
 
 Available actions:
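
Note: ScriptedAgent, RandomAgent, and SmartMockAgent are deleted here rather than rewritten; per the new module docstring they now live in openadapt_evals. A migration sketch, assuming the relocated classes keep the constructor signatures of the removed 0.2.0 code:

    from openadapt_evals import BenchmarkAction, RandomAgent, ScriptedAgent

    # Replay a fixed two-step trajectory (mirrors the removed SmartMockAgent).
    smoke_agent = ScriptedAgent(
        actions=[
            BenchmarkAction(type="click", target_node_id="4"),
            BenchmarkAction(type="done"),
        ]
    )

    # Seeded random baseline for comparison runs.
    random_agent = RandomAgent(seed=0)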
@@ -505,7 +264,7 @@ Available actions:
 - KEY(key) - Press a key (e.g., Enter, Tab, Escape)
 - KEY(modifier+key) - Press key combination (e.g., Ctrl+c, Alt+Tab)
 - SCROLL(direction) - Scroll up or down
-- DRAG(x1, y1, x2, y2) - Drag from (x1,y1) to (x2,y2) (pixel or normalized)
+- DRAG(x1, y1, x2, y2) - Drag from (x1,y1) to (x2,y2)
 - DONE() - Task is complete
 - ANSWER("response") - For QA tasks, provide the answer
@@ -554,32 +313,15 @@ Then output the action on a new line starting with "ACTION:"
         task: BenchmarkTask,
         history: list[tuple[BenchmarkObservation, BenchmarkAction]] | None = None,
     ) -> BenchmarkAction:
-        """Use VLM API to determine next action.
-
-        Args:
-            observation: Current observation with screenshot.
-            task: Task being performed.
-            history: Previous observations and actions.
-
-        Returns:
-            BenchmarkAction parsed from VLM response.
-        """
+        """Use VLM API to determine next action."""
         adapter = self._get_adapter()
-
-        # Build the sample for the API
         sample = self._build_sample(observation, task, history)
 
-        # Call the VLM API
         try:
             response = adapter.generate(sample, max_new_tokens=self.max_tokens)
         except Exception as e:
-            # On API error, return done to avoid infinite loops
-            return BenchmarkAction(
-                type="done",
-                raw_action={"error": str(e)},
-            )
+            return BenchmarkAction(type="done", raw_action={"error": str(e)})
 
-        # Parse the response into a BenchmarkAction
         return self._parse_response(response, observation)
 
     def _build_sample(
@@ -588,41 +330,26 @@ Then output the action on a new line starting with "ACTION:"
         task: BenchmarkTask,
         history: list[tuple[BenchmarkObservation, BenchmarkAction]] | None,
     ) -> dict[str, Any]:
-        """Build API sample from benchmark observation.
-
-        Args:
-            observation: Current observation.
-            task: Current task.
-            history: Action history.
-
-        Returns:
-            Sample dict with 'images' and 'messages'.
-        """
-        # Build user message content
+        """Build API sample from benchmark observation."""
         content_parts = [f"GOAL: {task.instruction}"]
 
-        # Add context
         if observation.url:
             content_parts.append(f"URL: {observation.url}")
         if observation.window_title:
             content_parts.append(f"Window: {observation.window_title}")
 
-        # Add accessibility tree if available and enabled
         if self.use_accessibility_tree and observation.accessibility_tree:
             tree_str = self._format_accessibility_tree(observation.accessibility_tree)
-            # Truncate if too long
             if len(tree_str) > 4000:
                 tree_str = tree_str[:4000] + "\n... (truncated)"
             content_parts.append(f"UI Elements:\n{tree_str}")
 
-        # Add history if enabled
         if self.use_history and history:
             history_str = self._format_history(history)
             content_parts.append(f"Previous actions:\n{history_str}")
 
         content_parts.append("\nWhat is the next action?")
 
-        # Build sample
         sample: dict[str, Any] = {
             "messages": [
                 {"role": "system", "content": self.SYSTEM_PROMPT},
@@ -630,22 +357,13 @@ Then output the action on a new line starting with "ACTION:"
             ],
         }
 
-        # Add image if available
         if observation.screenshot_path:
             sample["images"] = [observation.screenshot_path]
 
         return sample
 
     def _format_accessibility_tree(self, tree: dict, indent: int = 0) -> str:
-        """Format accessibility tree for prompt.
-
-        Args:
-            tree: Accessibility tree dict.
-            indent: Current indentation level.
-
-        Returns:
-            Formatted string representation.
-        """
+        """Format accessibility tree for prompt."""
        lines = []
        prefix = " " * indent
 
@@ -666,29 +384,15 @@ Then output the action on a new line starting with "ACTION:"
     def _format_history(
         self, history: list[tuple[BenchmarkObservation, BenchmarkAction]]
     ) -> str:
-        """Format action history for prompt.
-
-        Args:
-            history: List of (observation, action) pairs.
-
-        Returns:
-            Formatted string.
-        """
+        """Format action history for prompt."""
         lines = []
-        for i, (obs, action) in enumerate(history[-5:], 1):  # Last 5 actions
+        for i, (obs, action) in enumerate(history[-5:], 1):
             action_str = self._action_to_string(action)
             lines.append(f"{i}. {action_str}")
         return "\n".join(lines)
 
     def _action_to_string(self, action: BenchmarkAction) -> str:
-        """Convert BenchmarkAction to string representation.
-
-        Args:
-            action: Action to convert.
-
-        Returns:
-            String representation.
-        """
+        """Convert BenchmarkAction to string representation."""
         if action.type == "click":
             if action.target_node_id:
                 return f"CLICK([{action.target_node_id}])"
@@ -717,32 +421,14 @@ Then output the action on a new line starting with "ACTION:"
     def _parse_response(
         self, response: str, observation: BenchmarkObservation | None = None
     ) -> BenchmarkAction:
-        """Parse VLM response into BenchmarkAction.
-
-        Handles various response formats:
-        - ACTION: CLICK(0.5, 0.3)
-        - CLICK(0.5, 0.3)
-        - I'll click at coordinates (0.5, 0.3) -> CLICK(0.5, 0.3)
-
-        Args:
-            response: Raw VLM response text.
-            observation: Current observation (used for coordinate normalization).
-
-        Returns:
-            Parsed BenchmarkAction.
-        """
-        # Store raw response for debugging
+        """Parse VLM response into BenchmarkAction."""
         raw_action = {"response": response}
 
-        # Extract action line (look for ACTION: prefix or action pattern)
         action_line = None
-
-        # Try to find ACTION: prefix
         action_match = re.search(r"ACTION:\s*(.+)", response, re.IGNORECASE)
         if action_match:
             action_line = action_match.group(1).strip()
         else:
-            # Look for action pattern anywhere in response
             patterns = [
                 r"(CLICK\s*\([^)]+\))",
@@ -759,146 +445,304 @@ Then output the action on a new line starting with "ACTION:"
                 break
 
         if not action_line:
-            # Could not parse action, return done
             raw_action["parse_error"] = "No action pattern found"
             return BenchmarkAction(type="done", raw_action=raw_action)
 
-        # Parse CLICK action
+        # Parse CLICK([id])
         click_match = re.match(
             r"CLICK\s*\(\s*\[?(\d+)\]?\s*\)", action_line, re.IGNORECASE
         )
         if click_match:
-            # CLICK([id]) - element ID
             node_id = click_match.group(1)
             return BenchmarkAction(
-                type="click",
-                target_node_id=node_id,
-                raw_action=raw_action,
+                type="click", target_node_id=node_id, raw_action=raw_action
             )
 
+        # Parse CLICK(x, y)
         click_coords = re.match(
             r"CLICK\s*\(\s*([\d.]+)\s*,\s*([\d.]+)\s*\)", action_line, re.IGNORECASE
         )
         if click_coords:
-            # CLICK(x, y) - coordinates
             x = float(click_coords.group(1))
             y = float(click_coords.group(2))
-
-            # Normalize coordinates if they appear to be pixel values
-            # If x or y > 1.0, assume pixel coordinates and normalize using viewport
             if observation and observation.viewport and (x > 1.0 or y > 1.0):
                 width, height = observation.viewport
-                x_norm = x / width
-                y_norm = y / height
                 raw_action["original_coords"] = {"x": x, "y": y}
                 raw_action["normalized"] = True
-                x = x_norm
-                y = y_norm
-
-            return BenchmarkAction(
-                type="click",
-                x=x,
-                y=y,
-                raw_action=raw_action,
-            )
+                x, y = x / width, y / height
+            return BenchmarkAction(type="click", x=x, y=y, raw_action=raw_action)
 
-        # Parse TYPE action
+        # Parse TYPE
         type_match = re.match(
             r"TYPE\s*\(\s*[\"'](.+?)[\"']\s*\)", action_line, re.IGNORECASE
         )
         if type_match:
-            text = type_match.group(1)
             return BenchmarkAction(
-                type="type",
-                text=text,
-                raw_action=raw_action,
+                type="type", text=type_match.group(1), raw_action=raw_action
             )
 
-        # Parse KEY action
+        # Parse KEY
         key_match = re.match(r"KEY\s*\(\s*(.+?)\s*\)", action_line, re.IGNORECASE)
         if key_match:
             key_str = key_match.group(1)
-            # Handle modifier+key format
             if "+" in key_str:
                 parts = key_str.split("+")
-                key = parts[-1]
-                modifiers = parts[:-1]
                 return BenchmarkAction(
                     type="key",
-                    key=key,
-                    modifiers=modifiers,
+                    key=parts[-1],
+                    modifiers=parts[:-1],
                     raw_action=raw_action,
                 )
-            return BenchmarkAction(
-                type="key",
-                key=key_str,
-                raw_action=raw_action,
-            )
+            return BenchmarkAction(type="key", key=key_str, raw_action=raw_action)
 
-        # Parse SCROLL action
+        # Parse SCROLL
         scroll_match = re.match(
             r"SCROLL\s*\(\s*(up|down)\s*\)", action_line, re.IGNORECASE
         )
         if scroll_match:
-            direction = scroll_match.group(1).lower()
             return BenchmarkAction(
                 type="scroll",
-                scroll_direction=direction,
+                scroll_direction=scroll_match.group(1).lower(),
                 raw_action=raw_action,
             )
 
-        # Parse DRAG action
+        # Parse DRAG
         drag_match = re.match(
             r"DRAG\s*\(\s*([\d.]+)\s*,\s*([\d.]+)\s*,\s*([\d.]+)\s*,\s*([\d.]+)\s*\)",
             action_line,
             re.IGNORECASE,
         )
         if drag_match:
-            x = float(drag_match.group(1))
-            y = float(drag_match.group(2))
-            end_x = float(drag_match.group(3))
-            end_y = float(drag_match.group(4))
-
-            # Normalize coordinates if they appear to be pixel values
-            if observation and observation.viewport and (x > 1.0 or y > 1.0 or end_x > 1.0 or end_y > 1.0):
+            x, y = float(drag_match.group(1)), float(drag_match.group(2))
+            end_x, end_y = float(drag_match.group(3)), float(drag_match.group(4))
+            if (
+                observation
+                and observation.viewport
+                and (x > 1.0 or y > 1.0 or end_x > 1.0 or end_y > 1.0)
+            ):
                 width, height = observation.viewport
-                raw_action["original_coords"] = {"x": x, "y": y, "end_x": end_x, "end_y": end_y}
+                raw_action["original_coords"] = {
+                    "x": x,
+                    "y": y,
+                    "end_x": end_x,
+                    "end_y": end_y,
+                }
                 raw_action["normalized"] = True
-                x = x / width
-                y = y / height
-                end_x = end_x / width
-                end_y = end_y / height
-
+                x, y, end_x, end_y = (
+                    x / width,
+                    y / height,
+                    end_x / width,
+                    end_y / height,
+                )
             return BenchmarkAction(
-                type="drag",
-                x=x,
-                y=y,
-                end_x=end_x,
-                end_y=end_y,
-                raw_action=raw_action,
+                type="drag", x=x, y=y, end_x=end_x, end_y=end_y, raw_action=raw_action
             )
 
-        # Parse DONE action
+        # Parse DONE
         if re.match(r"DONE\s*\(\s*\)", action_line, re.IGNORECASE):
             return BenchmarkAction(type="done", raw_action=raw_action)
 
-        # Parse ANSWER action
+        # Parse ANSWER
         answer_match = re.match(
             r"ANSWER\s*\(\s*[\"'](.+?)[\"']\s*\)", action_line, re.IGNORECASE
         )
         if answer_match:
-            answer = answer_match.group(1)
             return BenchmarkAction(
-                type="answer",
-                answer=answer,
-                raw_action=raw_action,
+                type="answer", answer=answer_match.group(1), raw_action=raw_action
             )
 
-        # Unknown action format
         raw_action["parse_error"] = f"Unknown action format: {action_line}"
         return BenchmarkAction(type="done", raw_action=raw_action)
 
     def reset(self) -> None:
         """Reset agent state."""
-        # APIBenchmarkAgent is stateless, nothing to reset
         pass
+
+
+class UnifiedBaselineAgent(BenchmarkAgent):
+    """Agent that uses UnifiedBaselineAdapter for benchmark evaluation.
+
+    Provides unified interface for Claude, GPT, and Gemini baselines
+    across multiple tracks (A: coordinates, B: ReAct, C: SoM).
+
+    Args:
+        model_alias: Model alias (e.g., 'claude-opus-4.5', 'gpt-5.2').
+        track: Track type ('A', 'B', or 'C'). Defaults to 'A'.
+        api_key: Optional API key override.
+        temperature: Sampling temperature.
+        max_tokens: Maximum tokens for response.
+        demo: Optional demo text for prompts.
+        verbose: Whether to print debug output.
+    """
+
+    def __init__(
+        self,
+        model_alias: str = "claude-opus-4.5",
+        track: str = "A",
+        api_key: str | None = None,
+        temperature: float = 0.1,
+        max_tokens: int = 1024,
+        demo: str | None = None,
+        verbose: bool = False,
+    ):
+        self.model_alias = model_alias
+        self.track = track.upper()
+        self.api_key = api_key
+        self.temperature = temperature
+        self.max_tokens = max_tokens
+        self.demo = demo
+        self.verbose = verbose
+        self._adapter = None
+
+    def _get_adapter(self):
+        """Lazily initialize the UnifiedBaselineAdapter."""
+        if self._adapter is None:
+            from openadapt_ml.baselines import TrackConfig, UnifiedBaselineAdapter
+
+            track_configs = {
+                "A": TrackConfig.track_a(),
+                "B": TrackConfig.track_b(),
+                "C": TrackConfig.track_c(),
+            }
+            track_config = track_configs.get(self.track, TrackConfig.track_a())
+
+            self._adapter = UnifiedBaselineAdapter.from_alias(
+                self.model_alias,
+                track=track_config,
+                api_key=self.api_key,
+                temperature=self.temperature,
+                max_tokens=self.max_tokens,
+                demo=self.demo,
+                verbose=self.verbose,
+            )
+        return self._adapter
+
+    def act(
+        self,
+        observation: BenchmarkObservation,
+        task: BenchmarkTask,
+        history: list[tuple[BenchmarkObservation, BenchmarkAction]] | None = None,
+    ) -> BenchmarkAction:
+        """Use UnifiedBaselineAdapter to determine next action."""
+        from PIL import Image
+
+        adapter = self._get_adapter()
+
+        screenshot = None
+        if observation.screenshot_path:
+            try:
+                screenshot = Image.open(observation.screenshot_path)
+            except Exception as e:
+                if self.verbose:
+                    print(f"[UnifiedBaselineAgent] Failed to load screenshot: {e}")
+
+        a11y_tree = (
+            observation.accessibility_tree if observation.accessibility_tree else None
+        )
+
+        adapter_history = None
+        if history:
+            adapter_history = [
+                self._benchmark_action_to_dict(a) for _, a in history[-5:]
+            ]
+
+        try:
+            parsed_action = adapter.predict(
+                screenshot=screenshot,
+                goal=task.instruction,
+                a11y_tree=a11y_tree,
+                history=adapter_history,
+            )
+        except Exception as e:
+            if self.verbose:
+                print(f"[UnifiedBaselineAgent] Adapter error: {e}")
+            return BenchmarkAction(type="done", raw_action={"error": str(e)})
+
+        return self._parsed_to_benchmark_action(parsed_action, observation)
+
+    def _benchmark_action_to_dict(self, action: BenchmarkAction) -> dict[str, Any]:
+        """Convert BenchmarkAction to dict for history."""
+        result = {"type": action.type}
+        if action.x is not None:
+            result["x"] = action.x
+        if action.y is not None:
+            result["y"] = action.y
+        if action.text:
+            result["text"] = action.text
+        if action.key:
+            result["key"] = action.key
+        if action.target_node_id:
+            result["element_id"] = action.target_node_id
+        if action.scroll_direction:
+            result["direction"] = action.scroll_direction
+        return result
+
+    def _parsed_to_benchmark_action(
+        self, parsed_action, observation: BenchmarkObservation | None = None
+    ) -> BenchmarkAction:
+        """Convert ParsedAction to BenchmarkAction."""
+        raw_action = {
+            "raw_response": parsed_action.raw_response,
+            "thought": parsed_action.thought,
+        }
+
+        if not parsed_action.is_valid:
+            raw_action["parse_error"] = parsed_action.parse_error
+            return BenchmarkAction(type="done", raw_action=raw_action)
+
+        action_type = parsed_action.action_type
+
+        if action_type == "click":
+            if parsed_action.element_id is not None:
+                return BenchmarkAction(
+                    type="click",
+                    target_node_id=str(parsed_action.element_id),
+                    raw_action=raw_action,
+                )
+            elif parsed_action.x is not None and parsed_action.y is not None:
+                x, y = parsed_action.x, parsed_action.y
+                if observation and observation.viewport and (x > 1.0 or y > 1.0):
+                    width, height = observation.viewport
+                    raw_action["original_coords"] = {"x": x, "y": y}
+                    x, y = x / width, y / height
+                return BenchmarkAction(type="click", x=x, y=y, raw_action=raw_action)
+
+        elif action_type == "type":
+            return BenchmarkAction(
+                type="type", text=parsed_action.text, raw_action=raw_action
+            )
+
+        elif action_type == "key":
+            return BenchmarkAction(
+                type="key", key=parsed_action.key, raw_action=raw_action
+            )
+
+        elif action_type == "scroll":
+            return BenchmarkAction(
+                type="scroll",
+                scroll_direction=parsed_action.direction,
+                raw_action=raw_action,
+            )
+
+        elif action_type == "done":
+            return BenchmarkAction(type="done", raw_action=raw_action)
+
+        elif action_type == "drag":
+            return BenchmarkAction(
+                type="drag",
+                x=parsed_action.x,
+                y=parsed_action.y,
+                end_x=getattr(parsed_action, "end_x", None),
+                end_y=getattr(parsed_action, "end_y", None),
+                raw_action=raw_action,
+            )
+
+        raw_action["unknown_action"] = action_type
+        return BenchmarkAction(type="done", raw_action=raw_action)
+
+    def reset(self) -> None:
+        """Reset agent state."""
+        pass
+
+    def __repr__(self) -> str:
+        return f"UnifiedBaselineAgent(model={self.model_alias}, track={self.track})"