openadapt-ml 0.1.0__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (112):
  1. openadapt_ml/baselines/__init__.py +121 -0
  2. openadapt_ml/baselines/adapter.py +185 -0
  3. openadapt_ml/baselines/cli.py +314 -0
  4. openadapt_ml/baselines/config.py +448 -0
  5. openadapt_ml/baselines/parser.py +922 -0
  6. openadapt_ml/baselines/prompts.py +787 -0
  7. openadapt_ml/benchmarks/__init__.py +13 -107
  8. openadapt_ml/benchmarks/agent.py +297 -374
  9. openadapt_ml/benchmarks/azure.py +62 -24
  10. openadapt_ml/benchmarks/azure_ops_tracker.py +521 -0
  11. openadapt_ml/benchmarks/cli.py +1874 -751
  12. openadapt_ml/benchmarks/trace_export.py +631 -0
  13. openadapt_ml/benchmarks/viewer.py +1236 -0
  14. openadapt_ml/benchmarks/vm_monitor.py +1111 -0
  15. openadapt_ml/benchmarks/waa_deploy/Dockerfile +216 -0
  16. openadapt_ml/benchmarks/waa_deploy/__init__.py +10 -0
  17. openadapt_ml/benchmarks/waa_deploy/api_agent.py +540 -0
  18. openadapt_ml/benchmarks/waa_deploy/start_waa_server.bat +53 -0
  19. openadapt_ml/cloud/azure_inference.py +3 -5
  20. openadapt_ml/cloud/lambda_labs.py +722 -307
  21. openadapt_ml/cloud/local.py +3194 -89
  22. openadapt_ml/cloud/ssh_tunnel.py +595 -0
  23. openadapt_ml/datasets/next_action.py +125 -96
  24. openadapt_ml/evals/grounding.py +32 -9
  25. openadapt_ml/evals/plot_eval_metrics.py +15 -13
  26. openadapt_ml/evals/trajectory_matching.py +120 -57
  27. openadapt_ml/experiments/demo_prompt/__init__.py +19 -0
  28. openadapt_ml/experiments/demo_prompt/format_demo.py +236 -0
  29. openadapt_ml/experiments/demo_prompt/results/experiment_20251231_002125.json +83 -0
  30. openadapt_ml/experiments/demo_prompt/results/experiment_n30_20251231_165958.json +1100 -0
  31. openadapt_ml/experiments/demo_prompt/results/multistep_20251231_025051.json +182 -0
  32. openadapt_ml/experiments/demo_prompt/run_experiment.py +541 -0
  33. openadapt_ml/experiments/representation_shootout/__init__.py +70 -0
  34. openadapt_ml/experiments/representation_shootout/conditions.py +708 -0
  35. openadapt_ml/experiments/representation_shootout/config.py +390 -0
  36. openadapt_ml/experiments/representation_shootout/evaluator.py +659 -0
  37. openadapt_ml/experiments/representation_shootout/runner.py +687 -0
  38. openadapt_ml/experiments/waa_demo/__init__.py +10 -0
  39. openadapt_ml/experiments/waa_demo/demos.py +357 -0
  40. openadapt_ml/experiments/waa_demo/runner.py +732 -0
  41. openadapt_ml/experiments/waa_demo/tasks.py +151 -0
  42. openadapt_ml/export/__init__.py +9 -0
  43. openadapt_ml/export/__main__.py +6 -0
  44. openadapt_ml/export/cli.py +89 -0
  45. openadapt_ml/export/parquet.py +277 -0
  46. openadapt_ml/grounding/detector.py +18 -14
  47. openadapt_ml/ingest/__init__.py +11 -10
  48. openadapt_ml/ingest/capture.py +97 -86
  49. openadapt_ml/ingest/loader.py +120 -69
  50. openadapt_ml/ingest/synthetic.py +344 -193
  51. openadapt_ml/models/api_adapter.py +14 -4
  52. openadapt_ml/models/base_adapter.py +10 -2
  53. openadapt_ml/models/providers/__init__.py +288 -0
  54. openadapt_ml/models/providers/anthropic.py +266 -0
  55. openadapt_ml/models/providers/base.py +299 -0
  56. openadapt_ml/models/providers/google.py +376 -0
  57. openadapt_ml/models/providers/openai.py +342 -0
  58. openadapt_ml/models/qwen_vl.py +46 -19
  59. openadapt_ml/perception/__init__.py +35 -0
  60. openadapt_ml/perception/integration.py +399 -0
  61. openadapt_ml/retrieval/README.md +226 -0
  62. openadapt_ml/retrieval/USAGE.md +391 -0
  63. openadapt_ml/retrieval/__init__.py +91 -0
  64. openadapt_ml/retrieval/demo_retriever.py +843 -0
  65. openadapt_ml/retrieval/embeddings.py +630 -0
  66. openadapt_ml/retrieval/index.py +194 -0
  67. openadapt_ml/retrieval/retriever.py +162 -0
  68. openadapt_ml/runtime/__init__.py +50 -0
  69. openadapt_ml/runtime/policy.py +27 -14
  70. openadapt_ml/runtime/safety_gate.py +471 -0
  71. openadapt_ml/schema/__init__.py +113 -0
  72. openadapt_ml/schema/converters.py +588 -0
  73. openadapt_ml/schema/episode.py +470 -0
  74. openadapt_ml/scripts/capture_screenshots.py +530 -0
  75. openadapt_ml/scripts/compare.py +102 -61
  76. openadapt_ml/scripts/demo_policy.py +4 -1
  77. openadapt_ml/scripts/eval_policy.py +19 -14
  78. openadapt_ml/scripts/make_gif.py +1 -1
  79. openadapt_ml/scripts/prepare_synthetic.py +16 -17
  80. openadapt_ml/scripts/train.py +98 -75
  81. openadapt_ml/segmentation/README.md +920 -0
  82. openadapt_ml/segmentation/__init__.py +97 -0
  83. openadapt_ml/segmentation/adapters/__init__.py +5 -0
  84. openadapt_ml/segmentation/adapters/capture_adapter.py +420 -0
  85. openadapt_ml/segmentation/annotator.py +610 -0
  86. openadapt_ml/segmentation/cache.py +290 -0
  87. openadapt_ml/segmentation/cli.py +674 -0
  88. openadapt_ml/segmentation/deduplicator.py +656 -0
  89. openadapt_ml/segmentation/frame_describer.py +788 -0
  90. openadapt_ml/segmentation/pipeline.py +340 -0
  91. openadapt_ml/segmentation/schemas.py +622 -0
  92. openadapt_ml/segmentation/segment_extractor.py +634 -0
  93. openadapt_ml/training/azure_ops_viewer.py +1097 -0
  94. openadapt_ml/training/benchmark_viewer.py +3255 -19
  95. openadapt_ml/training/shared_ui.py +7 -7
  96. openadapt_ml/training/stub_provider.py +57 -35
  97. openadapt_ml/training/trainer.py +255 -441
  98. openadapt_ml/training/trl_trainer.py +403 -0
  99. openadapt_ml/training/viewer.py +323 -108
  100. openadapt_ml/training/viewer_components.py +180 -0
  101. {openadapt_ml-0.1.0.dist-info → openadapt_ml-0.2.1.dist-info}/METADATA +312 -69
  102. openadapt_ml-0.2.1.dist-info/RECORD +116 -0
  103. openadapt_ml/benchmarks/base.py +0 -366
  104. openadapt_ml/benchmarks/data_collection.py +0 -432
  105. openadapt_ml/benchmarks/runner.py +0 -381
  106. openadapt_ml/benchmarks/waa.py +0 -704
  107. openadapt_ml/schemas/__init__.py +0 -53
  108. openadapt_ml/schemas/sessions.py +0 -122
  109. openadapt_ml/schemas/validation.py +0 -252
  110. openadapt_ml-0.1.0.dist-info/RECORD +0 -55
  111. {openadapt_ml-0.1.0.dist-info → openadapt_ml-0.2.1.dist-info}/WHEEL +0 -0
  112. {openadapt_ml-0.1.0.dist-info → openadapt_ml-0.2.1.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,194 @@
1
+ """Demo index for storing and retrieving demonstrations."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from dataclasses import dataclass, field
6
+ from typing import Any, Dict, List, Optional
7
+
8
+ from openadapt_ml.retrieval.embeddings import TextEmbedder
9
+ from openadapt_ml.schema import Episode
10
+
11
+
12
@dataclass
class DemoMetadata:
    """One indexed demonstration together with its retrieval features.

    Fields are declared in constructor (positional) order: the episode first,
    then optional descriptive metadata, then features computed by the index.
    """

    # The demonstration episode being indexed.
    episode: Episode
    # Application name associated with the demo, if known.
    app_name: Optional[str] = field(default=None)
    # Web domain associated with the demo (e.g. "github.com"), if known.
    domain: Optional[str] = field(default=None)
    # Arbitrary extra information supplied by the caller.
    metadata: Dict[str, Any] = field(default_factory=dict)

    # Populated when the owning index is built; empty until then.
    text_embedding: Dict[str, float] = field(default_factory=dict)
26
+
27
+
28
class DemoIndex:
    """In-memory index of demonstrations used for similarity retrieval.

    Lifecycle: add episodes with ``add``/``add_many``, then call ``build()``
    to fit the text embedder on all instructions and compute each demo's
    embedding. Adding a demo after ``build()`` marks the index as not fitted
    again, so ``build()`` must be re-run before retrieval.
    """

    def __init__(self) -> None:
        """Create an empty, unfitted index with a fresh ``TextEmbedder``."""
        self.demos: List[DemoMetadata] = []
        self.embedder = TextEmbedder()
        self._is_fitted = False

    def _extract_app_name(self, episode: Episode) -> Optional[str]:
        """Return the first non-empty ``observation.app_name`` in the episode.

        Args:
            episode: Episode to inspect.

        Returns:
            The app name from the earliest step that has one, or None.
        """
        for step in episode.steps:
            if step.observation and step.observation.app_name:
                return step.observation.app_name
        return None

    def _extract_domain(self, episode: Episode) -> Optional[str]:
        """Return the host of the first absolute URL seen in the episode.

        The host is taken from ``urlsplit(url).hostname``, so it is lowercased
        and excludes any port and ``user:password@`` credentials. A leading
        ``www.`` is removed. URLs without a host (e.g. ``file:///...``) or that
        cannot be parsed are skipped, and the search continues with later
        steps.

        Args:
            episode: Episode to inspect.

        Returns:
            Domain such as ``"github.com"``, or None if no step has a usable URL.
        """
        from urllib.parse import urlsplit

        for step in episode.steps:
            observation = step.observation
            if not observation or not observation.url:
                continue
            url = observation.url
            # Only absolute "scheme://..." URLs are considered.
            if "://" not in url:
                continue
            try:
                host = urlsplit(url).hostname
            except ValueError:
                # Malformed netloc (e.g. unbalanced IPv6 brackets): try later steps.
                continue
            if host and host.startswith("www."):
                host = host[4:]
            if host:
                return host
        return None

    def add(
        self,
        episode: Episode,
        app_name: Optional[str] = None,
        domain: Optional[str] = None,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Add an episode to the index.

        Args:
            episode: Episode to add.
            app_name: App name; auto-extracted from observations when None.
            domain: Domain; auto-extracted from observation URLs when None.
            metadata: Additional metadata for the episode.
        """
        if app_name is None:
            app_name = self._extract_app_name(episode)
        if domain is None:
            domain = self._extract_domain(episode)

        demo_meta = DemoMetadata(
            episode=episode,
            app_name=app_name,
            domain=domain,
            metadata=metadata or {},
        )

        self.demos.append(demo_meta)
        # New data invalidates the fitted embedder and existing embeddings.
        self._is_fitted = False

    def add_many(self, episodes: List[Episode]) -> None:
        """Add several episodes, auto-extracting app name and domain for each.

        Args:
            episodes: Episodes to add.
        """
        for episode in episodes:
            self.add(episode)

    def build(self) -> None:
        """Fit the embedder on all instructions and embed every demo.

        Must be called after adding demos and before retrieval. Does nothing
        (and leaves the index unfitted) when the index is empty.
        """
        if not self.demos:
            return

        instruction_texts = [demo.episode.instruction for demo in self.demos]
        self.embedder.fit(instruction_texts)

        for demo in self.demos:
            demo.text_embedding = self.embedder.embed(demo.episode.instruction)

        self._is_fitted = True

    def is_empty(self) -> bool:
        """Return True if no demos have been added."""
        return len(self.demos) == 0

    def is_fitted(self) -> bool:
        """Return True if ``build()`` ran after the most recent ``add``."""
        return self._is_fitted

    def get_all_demos(self) -> List[DemoMetadata]:
        """Return the index's demo list (the live list, not a copy)."""
        return self.demos

    def get_apps(self) -> List[str]:
        """Return the sorted unique app names, excluding None."""
        apps = {demo.app_name for demo in self.demos if demo.app_name is not None}
        return sorted(apps)

    def get_domains(self) -> List[str]:
        """Return the sorted unique domains, excluding None."""
        domains = {demo.domain for demo in self.demos if demo.domain is not None}
        return sorted(domains)

    def __len__(self) -> int:
        """Return the number of demos in the index."""
        return len(self.demos)

    def __repr__(self) -> str:
        """Return e.g. ``DemoIndex(3 demos, fitted)``."""
        status = "fitted" if self._is_fitted else "not fitted"
        return f"DemoIndex({len(self.demos)} demos, {status})"
@@ -0,0 +1,162 @@
1
+ """Demo retriever for finding similar demonstrations."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from dataclasses import dataclass
6
+ from typing import List, Optional
7
+
8
+ from openadapt_ml.retrieval.index import DemoIndex, DemoMetadata
9
+ from openadapt_ml.schema import Episode
10
+
11
+
12
@dataclass
class RetrievalResult:
    """One scored hit produced by ``DemoRetriever``.

    ``score`` is the value results are ranked by (larger ranks first). It is
    the sum of the text-similarity part and any app/domain bonus.

    Attributes:
        demo: Metadata of the matched demonstration.
        score: Combined ranking score.
        text_score: Text-similarity contribution to the score.
        domain_bonus: App/domain match bonus that was added (0.0 if none).
    """

    demo: DemoMetadata
    score: float
    text_score: float
    domain_bonus: float
27
+
28
+
29
class DemoRetriever:
    """Retrieve the top-K demonstrations most similar to a task.

    A demo's score is the embedder's cosine similarity between the task and
    the demo's precomputed instruction embedding. ``domain_bonus`` is added
    when an ``app_context`` is given and it occurs (case-insensitively) in the
    demo's app name or domain.
    """

    def __init__(
        self,
        index: DemoIndex,
        domain_bonus: float = 0.2,
    ) -> None:
        """Initialize the retriever.

        Args:
            index: DemoIndex to retrieve from.
            domain_bonus: Bonus score for an app/domain match (default: 0.2).

        Raises:
            ValueError: If index is empty or not fitted.
        """
        if index.is_empty():
            raise ValueError("Cannot create retriever from empty index")
        if not index.is_fitted():
            raise ValueError(
                "Index must be built before retrieval (call index.build())"
            )

        self.index = index
        self.domain_bonus = domain_bonus

    def _compute_score(
        self,
        task: str,
        demo: DemoMetadata,
        app_context: Optional[str] = None,
        query_embedding: Optional[dict] = None,
    ) -> RetrievalResult:
        """Compute the retrieval score of one demo for a task.

        Args:
            task: Task description to match against.
            demo: Demo metadata to score.
            app_context: Optional app/domain context for the bonus. An empty or
                whitespace-only string is treated as no context.
            query_embedding: Precomputed embedding of ``task``. Pass it when
                scoring many demos, so the task is embedded only once.

        Returns:
            RetrievalResult with the computed scores.
        """
        if query_embedding is None:
            query_embedding = self.index.embedder.embed(task)
        text_score = self.index.embedder.cosine_similarity(
            query_embedding,
            demo.text_embedding,
        )

        bonus = 0.0
        # "" is a substring of every string, so an empty context would
        # otherwise grant the bonus to every demo with an app name or domain.
        context = app_context.strip().lower() if app_context else ""
        if context:
            app_match = bool(demo.app_name) and context in demo.app_name.lower()
            domain_match = bool(demo.domain) and context in demo.domain.lower()
            if app_match or domain_match:
                bonus = self.domain_bonus

        return RetrievalResult(
            demo=demo,
            score=text_score + bonus,
            text_score=text_score,
            domain_bonus=bonus,
        )

    def retrieve(
        self,
        task: str,
        app_context: Optional[str] = None,
        top_k: int = 3,
    ) -> List[Episode]:
        """Retrieve the top-K most similar demo episodes.

        Args:
            task: Task description to find demos for.
            app_context: Optional app/domain context (e.g., "Chrome", "github.com").
            top_k: Number of demos to retrieve; values <= 0 yield an empty list.

        Returns:
            Episode objects ordered by relevance (most similar first).

        Raises:
            ValueError: If demos were added to the index after it was built.
        """
        return [
            result.demo.episode
            for result in self.retrieve_with_scores(task, app_context, top_k)
        ]

    def retrieve_with_scores(
        self,
        task: str,
        app_context: Optional[str] = None,
        top_k: int = 3,
    ) -> List[RetrievalResult]:
        """Retrieve the top-K demos together with their scores.

        Args:
            task: Task description to find demos for.
            app_context: Optional app/domain context.
            top_k: Number of demos to retrieve; values <= 0 yield an empty list.

        Returns:
            RetrievalResult objects sorted by descending score. Ties keep index
            order.

        Raises:
            ValueError: If demos were added to the index after it was built.
        """
        if self.index.is_empty():
            return []
        # Demos added after build() have no embedding yet, and the embedder's
        # vocabulary is stale, so scoring them would silently be wrong.
        if not self.index.is_fitted():
            raise ValueError(
                "Index must be built before retrieval (call index.build())"
            )
        # A negative slice bound would drop results from the end instead.
        if top_k <= 0:
            return []

        # Embed the query once rather than once per demo.
        query_embedding = self.index.embedder.embed(task)
        results = [
            self._compute_score(
                task, demo, app_context, query_embedding=query_embedding
            )
            for demo in self.index.get_all_demos()
        ]

        # list.sort is stable, so equal scores keep their index order.
        results.sort(key=lambda r: r.score, reverse=True)
        return results[:top_k]
@@ -0,0 +1,50 @@
1
+ """
2
+ Runtime module for GUI automation agents.
3
+
4
+ This module provides:
5
+ - AgentPolicy: Runtime policy wrapper for VLM-based action prediction
6
+ - SafetyGate: Deterministic safety checks for action validation
7
+
8
+ Example usage:
9
+ from openadapt_ml.runtime import AgentPolicy, SafetyGate, SafetyConfig, SafetyDecision
10
+
11
+ # Create policy and safety gate
12
+ policy = AgentPolicy(adapter)
13
+ gate = SafetyGate(SafetyConfig(confidence_threshold=0.8))
14
+
15
+ # In agent loop
16
+ action, thought, state, raw = policy.predict_action_from_sample(sample)
17
+ assessment = gate.assess(action, observation)
18
+
19
+ if assessment.decision == SafetyDecision.ALLOW:
20
+ execute(action)
21
+ elif assessment.decision == SafetyDecision.REQUIRE_CONFIRMATION:
22
+ if user_confirms():
23
+ execute(action)
24
+ else: # BLOCK
25
+ log_blocked(assessment.reason)
26
+ """
27
+
28
+ from openadapt_ml.runtime.policy import (
29
+ AgentPolicy,
30
+ PolicyOutput,
31
+ parse_thought_state_action,
32
+ )
33
+ from openadapt_ml.runtime.safety_gate import (
34
+ SafetyAssessment,
35
+ SafetyConfig,
36
+ SafetyDecision,
37
+ SafetyGate,
38
+ )
39
+
40
__all__ = [
    # Re-exported from openadapt_ml.runtime.policy
    "AgentPolicy",
    "PolicyOutput",
    "parse_thought_state_action",
    # Re-exported from openadapt_ml.runtime.safety_gate
    "SafetyGate",
    "SafetyConfig",
    "SafetyDecision",
    "SafetyAssessment",
]
@@ -3,16 +3,16 @@ from __future__ import annotations
3
3
  import json
4
4
  import re
5
5
  from dataclasses import dataclass
6
- from typing import Any, Dict, List, Optional, Tuple
6
+ from typing import Any, Dict, Optional, Tuple
7
7
 
8
8
  from PIL import Image
9
9
 
10
10
  from openadapt_ml.models.base_adapter import BaseVLMAdapter
11
- from openadapt_ml.schemas.sessions import Action
11
+ from openadapt_ml.schema import Action, ActionType, UIElement
12
12
 
13
13
 
14
14
  # Coordinate-based DSL patterns
15
- _CLICK_RE = re.compile(r"CLICK\(x=([0-9]*\.?[0-9]+),\s*y=([0-9]*\.?[0-9]+)\)")
15
+ _CLICK_RE = re.compile(r"CLICK\(x=(-?[0-9]*\.?[0-9]+),\s*y=(-?[0-9]*\.?[0-9]+)\)")
16
16
  _TYPE_RE = re.compile(r'TYPE\(text="([^"\\]*(?:\\.[^"\\]*)*)"\)')
17
17
  _WAIT_RE = re.compile(r"\bWAIT\s*\(\s*\)")
18
18
  _DONE_RE = re.compile(r"\bDONE\s*\(\s*\)")
@@ -26,13 +26,16 @@ _TYPE_SOM_SIMPLE_RE = re.compile(r'TYPE\(["\']([^"\']*(?:\\.[^"\']*)*)["\']\)')
26
26
  @dataclass
27
27
  class PolicyOutput:
28
28
  """Result of a single policy step."""
29
+
29
30
  action: Action
30
31
  thought: Optional[str] = None
31
32
  state: Optional[Dict[str, Any]] = None
32
33
  raw_text: str = ""
33
34
 
34
35
 
35
- def parse_thought_state_action(text: str) -> Tuple[Optional[str], Optional[Dict[str, Any]], str]:
36
+ def parse_thought_state_action(
37
+ text: str,
38
+ ) -> Tuple[Optional[str], Optional[Dict[str, Any]], str]:
36
39
  """Parse Thought / State / Action blocks from model output.
37
40
 
38
41
  Expected format:
@@ -54,12 +57,18 @@ def parse_thought_state_action(text: str) -> Tuple[Optional[str], Optional[Dict[
54
57
  action_str: str = text.strip()
55
58
 
56
59
  # Extract Thought - find the LAST occurrence (model's response, not template)
57
- thought_matches = list(re.finditer(r"Thought:\s*(.+?)(?=State:|Action:|$)", text, re.DOTALL | re.IGNORECASE))
60
+ thought_matches = list(
61
+ re.finditer(
62
+ r"Thought:\s*(.+?)(?=State:|Action:|$)", text, re.DOTALL | re.IGNORECASE
63
+ )
64
+ )
58
65
  if thought_matches:
59
66
  thought = thought_matches[-1].group(1).strip()
60
67
 
61
68
  # Extract State (JSON on same line or next line) - last occurrence
62
- state_matches = list(re.finditer(r"State:\s*(\{.*?\})", text, re.DOTALL | re.IGNORECASE))
69
+ state_matches = list(
70
+ re.finditer(r"State:\s*(\{.*?\})", text, re.DOTALL | re.IGNORECASE)
71
+ )
63
72
  if state_matches:
64
73
  try:
65
74
  state = json.loads(state_matches[-1].group(1))
@@ -119,7 +128,7 @@ class AgentPolicy:
119
128
  m = _CLICK_SOM_RE.search(text)
120
129
  if m:
121
130
  idx = int(m.group(1))
122
- return Action(type="click", element_index=idx)
131
+ return Action(type=ActionType.CLICK, element=UIElement(element_id=str(idx)))
123
132
 
124
133
  # TYPE([N], "text")
125
134
  m = _TYPE_SOM_RE.search(text)
@@ -127,14 +136,18 @@ class AgentPolicy:
127
136
  idx = int(m.group(1))
128
137
  raw_text = m.group(2)
129
138
  unescaped = raw_text.replace('\\"', '"').replace("\\\\", "\\")
130
- return Action(type="type", text=unescaped, element_index=idx)
139
+ return Action(
140
+ type=ActionType.TYPE,
141
+ text=unescaped,
142
+ element=UIElement(element_id=str(idx)),
143
+ )
131
144
 
132
145
  # TYPE("text") - SoM style without index
133
146
  m = _TYPE_SOM_SIMPLE_RE.search(text)
134
147
  if m:
135
148
  raw_text = m.group(1)
136
149
  unescaped = raw_text.replace('\\"', '"').replace("\\\\", "\\")
137
- return Action(type="type", text=unescaped)
150
+ return Action(type=ActionType.TYPE, text=unescaped)
138
151
 
139
152
  # Coordinate-based patterns
140
153
  # CLICK(x=..., y=...)
@@ -145,7 +158,7 @@ class AgentPolicy:
145
158
  # Clamp to [0, 1]
146
159
  x = max(0.0, min(1.0, x))
147
160
  y = max(0.0, min(1.0, y))
148
- return Action(type="click", x=x, y=y)
161
+ return Action(type=ActionType.CLICK, normalized_coordinates=(x, y))
149
162
 
150
163
  # TYPE(text="...")
151
164
  m = _TYPE_RE.search(text)
@@ -153,18 +166,18 @@ class AgentPolicy:
153
166
  # Unescape the text content
154
167
  raw_text = m.group(1)
155
168
  unescaped = raw_text.replace('\\"', '"').replace("\\\\", "\\")
156
- return Action(type="type", text=unescaped)
169
+ return Action(type=ActionType.TYPE, text=unescaped)
157
170
 
158
171
  # WAIT()
159
172
  if _WAIT_RE.search(text):
160
- return Action(type="wait")
173
+ return Action(type=ActionType.WAIT)
161
174
 
162
175
  # DONE()
163
176
  if _DONE_RE.search(text):
164
- return Action(type="done")
177
+ return Action(type=ActionType.DONE)
165
178
 
166
179
  # Fallback
167
- return Action(type="failed", raw={"text": text})
180
+ return Action(type=ActionType.FAIL, raw={"text": text})
168
181
 
169
182
  def predict_action_from_sample(
170
183
  self, sample: Dict[str, Any], max_new_tokens: int = 150