openadapt-ml 0.1.0__py3-none-any.whl → 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- openadapt_ml/baselines/__init__.py +121 -0
- openadapt_ml/baselines/adapter.py +185 -0
- openadapt_ml/baselines/cli.py +314 -0
- openadapt_ml/baselines/config.py +448 -0
- openadapt_ml/baselines/parser.py +922 -0
- openadapt_ml/baselines/prompts.py +787 -0
- openadapt_ml/benchmarks/__init__.py +13 -107
- openadapt_ml/benchmarks/agent.py +297 -374
- openadapt_ml/benchmarks/azure.py +62 -24
- openadapt_ml/benchmarks/azure_ops_tracker.py +521 -0
- openadapt_ml/benchmarks/cli.py +1874 -751
- openadapt_ml/benchmarks/trace_export.py +631 -0
- openadapt_ml/benchmarks/viewer.py +1236 -0
- openadapt_ml/benchmarks/vm_monitor.py +1111 -0
- openadapt_ml/benchmarks/waa_deploy/Dockerfile +216 -0
- openadapt_ml/benchmarks/waa_deploy/__init__.py +10 -0
- openadapt_ml/benchmarks/waa_deploy/api_agent.py +540 -0
- openadapt_ml/benchmarks/waa_deploy/start_waa_server.bat +53 -0
- openadapt_ml/cloud/azure_inference.py +3 -5
- openadapt_ml/cloud/lambda_labs.py +722 -307
- openadapt_ml/cloud/local.py +3194 -89
- openadapt_ml/cloud/ssh_tunnel.py +595 -0
- openadapt_ml/datasets/next_action.py +125 -96
- openadapt_ml/evals/grounding.py +32 -9
- openadapt_ml/evals/plot_eval_metrics.py +15 -13
- openadapt_ml/evals/trajectory_matching.py +120 -57
- openadapt_ml/experiments/demo_prompt/__init__.py +19 -0
- openadapt_ml/experiments/demo_prompt/format_demo.py +236 -0
- openadapt_ml/experiments/demo_prompt/results/experiment_20251231_002125.json +83 -0
- openadapt_ml/experiments/demo_prompt/results/experiment_n30_20251231_165958.json +1100 -0
- openadapt_ml/experiments/demo_prompt/results/multistep_20251231_025051.json +182 -0
- openadapt_ml/experiments/demo_prompt/run_experiment.py +541 -0
- openadapt_ml/experiments/representation_shootout/__init__.py +70 -0
- openadapt_ml/experiments/representation_shootout/conditions.py +708 -0
- openadapt_ml/experiments/representation_shootout/config.py +390 -0
- openadapt_ml/experiments/representation_shootout/evaluator.py +659 -0
- openadapt_ml/experiments/representation_shootout/runner.py +687 -0
- openadapt_ml/experiments/waa_demo/__init__.py +10 -0
- openadapt_ml/experiments/waa_demo/demos.py +357 -0
- openadapt_ml/experiments/waa_demo/runner.py +732 -0
- openadapt_ml/experiments/waa_demo/tasks.py +151 -0
- openadapt_ml/export/__init__.py +9 -0
- openadapt_ml/export/__main__.py +6 -0
- openadapt_ml/export/cli.py +89 -0
- openadapt_ml/export/parquet.py +277 -0
- openadapt_ml/grounding/detector.py +18 -14
- openadapt_ml/ingest/__init__.py +11 -10
- openadapt_ml/ingest/capture.py +97 -86
- openadapt_ml/ingest/loader.py +120 -69
- openadapt_ml/ingest/synthetic.py +344 -193
- openadapt_ml/models/api_adapter.py +14 -4
- openadapt_ml/models/base_adapter.py +10 -2
- openadapt_ml/models/providers/__init__.py +288 -0
- openadapt_ml/models/providers/anthropic.py +266 -0
- openadapt_ml/models/providers/base.py +299 -0
- openadapt_ml/models/providers/google.py +376 -0
- openadapt_ml/models/providers/openai.py +342 -0
- openadapt_ml/models/qwen_vl.py +46 -19
- openadapt_ml/perception/__init__.py +35 -0
- openadapt_ml/perception/integration.py +399 -0
- openadapt_ml/retrieval/README.md +226 -0
- openadapt_ml/retrieval/USAGE.md +391 -0
- openadapt_ml/retrieval/__init__.py +91 -0
- openadapt_ml/retrieval/demo_retriever.py +843 -0
- openadapt_ml/retrieval/embeddings.py +630 -0
- openadapt_ml/retrieval/index.py +194 -0
- openadapt_ml/retrieval/retriever.py +162 -0
- openadapt_ml/runtime/__init__.py +50 -0
- openadapt_ml/runtime/policy.py +27 -14
- openadapt_ml/runtime/safety_gate.py +471 -0
- openadapt_ml/schema/__init__.py +113 -0
- openadapt_ml/schema/converters.py +588 -0
- openadapt_ml/schema/episode.py +470 -0
- openadapt_ml/scripts/capture_screenshots.py +530 -0
- openadapt_ml/scripts/compare.py +102 -61
- openadapt_ml/scripts/demo_policy.py +4 -1
- openadapt_ml/scripts/eval_policy.py +19 -14
- openadapt_ml/scripts/make_gif.py +1 -1
- openadapt_ml/scripts/prepare_synthetic.py +16 -17
- openadapt_ml/scripts/train.py +98 -75
- openadapt_ml/segmentation/README.md +920 -0
- openadapt_ml/segmentation/__init__.py +97 -0
- openadapt_ml/segmentation/adapters/__init__.py +5 -0
- openadapt_ml/segmentation/adapters/capture_adapter.py +420 -0
- openadapt_ml/segmentation/annotator.py +610 -0
- openadapt_ml/segmentation/cache.py +290 -0
- openadapt_ml/segmentation/cli.py +674 -0
- openadapt_ml/segmentation/deduplicator.py +656 -0
- openadapt_ml/segmentation/frame_describer.py +788 -0
- openadapt_ml/segmentation/pipeline.py +340 -0
- openadapt_ml/segmentation/schemas.py +622 -0
- openadapt_ml/segmentation/segment_extractor.py +634 -0
- openadapt_ml/training/azure_ops_viewer.py +1097 -0
- openadapt_ml/training/benchmark_viewer.py +3255 -19
- openadapt_ml/training/shared_ui.py +7 -7
- openadapt_ml/training/stub_provider.py +57 -35
- openadapt_ml/training/trainer.py +255 -441
- openadapt_ml/training/trl_trainer.py +403 -0
- openadapt_ml/training/viewer.py +323 -108
- openadapt_ml/training/viewer_components.py +180 -0
- {openadapt_ml-0.1.0.dist-info → openadapt_ml-0.2.1.dist-info}/METADATA +312 -69
- openadapt_ml-0.2.1.dist-info/RECORD +116 -0
- openadapt_ml/benchmarks/base.py +0 -366
- openadapt_ml/benchmarks/data_collection.py +0 -432
- openadapt_ml/benchmarks/runner.py +0 -381
- openadapt_ml/benchmarks/waa.py +0 -704
- openadapt_ml/schemas/__init__.py +0 -53
- openadapt_ml/schemas/sessions.py +0 -122
- openadapt_ml/schemas/validation.py +0 -252
- openadapt_ml-0.1.0.dist-info/RECORD +0 -55
- {openadapt_ml-0.1.0.dist-info → openadapt_ml-0.2.1.dist-info}/WHEEL +0 -0
- {openadapt_ml-0.1.0.dist-info → openadapt_ml-0.2.1.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,194 @@
|
|
|
1
|
+
"""Demo index for storing and retrieving demonstrations."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from dataclasses import dataclass, field
|
|
6
|
+
from typing import Any, Dict, List, Optional
|
|
7
|
+
|
|
8
|
+
from openadapt_ml.retrieval.embeddings import TextEmbedder
|
|
9
|
+
from openadapt_ml.schema import Episode
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
@dataclass
class DemoMetadata:
    """One indexed demonstration plus the features used to retrieve it.

    Holds the source episode, optional app/domain labels, free-form metadata,
    and the text embedding that ``DemoIndex.build()`` fills in later.
    """

    # The demonstration itself.
    episode: Episode
    # Application label; None when unknown.
    app_name: Optional[str] = field(default=None)
    # Web domain label; None when unknown.
    domain: Optional[str] = field(default=None)
    # Arbitrary caller-supplied extras (a fresh dict per instance).
    metadata: Dict[str, Any] = field(default_factory=dict)

    # Filled in at index build time; empty until then.
    text_embedding: Dict[str, float] = field(default_factory=dict)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class DemoIndex:
    """Index of demonstrations used for similarity-based retrieval.

    Workflow: add episodes with :meth:`add` / :meth:`add_many`, then call
    :meth:`build` to fit the text embedder on every demo's instruction and
    cache each demo's embedding. Any later :meth:`add` marks the index as
    not fitted again, so :meth:`build` has to be re-run before retrieval.
    """

    def __init__(self) -> None:
        """Create an empty, unfitted index with its own ``TextEmbedder``."""
        self.demos: List[DemoMetadata] = []
        self.embedder = TextEmbedder()
        self._is_fitted = False

    def _extract_app_name(self, episode: Episode) -> Optional[str]:
        """Return the first non-empty ``observation.app_name`` in the episode.

        Args:
            episode: Episode to extract from.

        Returns:
            App name of the earliest step that has one, None otherwise.
        """
        for step in episode.steps:
            if step.observation and step.observation.app_name:
                return step.observation.app_name
        return None

    def _extract_domain(self, episode: Episode) -> Optional[str]:
        """Return the host name of the first usable URL in the episode.

        The host is lower-cased and stripped of any port, credentials, query
        string and leading ``www.``; for example ``"https://www.GitHub.com:443/x"``
        becomes ``"github.com"``. Steps whose URL has no host (for example
        ``file:///...`` or ``about:blank``) or cannot be parsed are skipped,
        and the search continues with the next step.

        Args:
            episode: Episode to extract from.

        Returns:
            Domain if found, None otherwise.
        """
        from urllib.parse import urlsplit

        for step in episode.steps:
            observation = step.observation
            if not (observation and observation.url):
                continue
            try:
                # ``hostname`` is already lower-cased and excludes port/userinfo.
                host = urlsplit(observation.url).hostname
            except ValueError:
                # Malformed URL (e.g. an unbalanced IPv6 bracket): try the next step.
                continue
            if host and host.startswith("www."):
                host = host[len("www."):]
            if host:
                return host
        return None

    def add(
        self,
        episode: Episode,
        app_name: Optional[str] = None,
        domain: Optional[str] = None,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Add an episode to the index.

        Adding invalidates any previous :meth:`build`; call it again before
        retrieving.

        Args:
            episode: Episode to add.
            app_name: Optional app name (auto-extracted from the episode's
                observations when None).
            domain: Optional domain (auto-extracted from the episode's URLs
                when None).
            metadata: Additional metadata for the episode.
        """
        if app_name is None:
            app_name = self._extract_app_name(episode)
        if domain is None:
            domain = self._extract_domain(episode)

        self.demos.append(
            DemoMetadata(
                episode=episode,
                app_name=app_name,
                domain=domain,
                metadata=metadata or {},
            )
        )
        # New data means the embedder vocabulary / cached embeddings are stale.
        self._is_fitted = False

    def add_many(self, episodes: List[Episode]) -> None:
        """Add multiple episodes to the index, auto-extracting app/domain.

        Args:
            episodes: List of episodes to add.
        """
        for episode in episodes:
            self.add(episode)

    def build(self) -> None:
        """Fit the embedder on all instructions and cache per-demo embeddings.

        This must be called after adding all demos and before retrieval.
        Building an empty index is a no-op and leaves it unfitted.
        """
        if not self.demos:
            return

        instructions = [demo.episode.instruction for demo in self.demos]
        self.embedder.fit(instructions)

        for demo, instruction in zip(self.demos, instructions):
            demo.text_embedding = self.embedder.embed(instruction)

        self._is_fitted = True

    def is_empty(self) -> bool:
        """Check if the index is empty.

        Returns:
            True if no demos have been added.
        """
        return len(self.demos) == 0

    def is_fitted(self) -> bool:
        """Check if the index has been built since the last addition.

        Returns:
            True if build() has run on a non-empty index and nothing has been
            added since.
        """
        return self._is_fitted

    def get_all_demos(self) -> List[DemoMetadata]:
        """Get all demos in the index.

        Returns:
            The index's internal list of DemoMetadata objects (not a copy).
        """
        return self.demos

    def get_apps(self) -> List[str]:
        """Get the sorted unique app names in the index.

        Returns:
            List of app names (excluding None).
        """
        return sorted({demo.app_name for demo in self.demos if demo.app_name is not None})

    def get_domains(self) -> List[str]:
        """Get the sorted unique domains in the index.

        Returns:
            List of domains (excluding None).
        """
        return sorted({demo.domain for demo in self.demos if demo.domain is not None})

    def __len__(self) -> int:
        """Return the number of demos in the index."""
        return len(self.demos)

    def __repr__(self) -> str:
        """Return a short summary such as ``DemoIndex(3 demos, fitted)``."""
        status = "fitted" if self._is_fitted else "not fitted"
        return f"DemoIndex({len(self.demos)} demos, {status})"
|
|
@@ -0,0 +1,162 @@
|
|
|
1
|
+
"""Demo retriever for finding similar demonstrations."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from dataclasses import dataclass
|
|
6
|
+
from typing import List, Optional
|
|
7
|
+
|
|
8
|
+
from openadapt_ml.retrieval.index import DemoIndex, DemoMetadata
|
|
9
|
+
from openadapt_ml.schema import Episode
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
@dataclass
class RetrievalResult:
    """A single retrieval result with score.

    Attributes:
        demo: The demo metadata.
        score: Retrieval score (higher is better); ``text_score + domain_bonus``.
        text_score: Text similarity component.
        domain_bonus: Domain match bonus applied (0.0 when nothing matched).
    """

    demo: DemoMetadata
    score: float
    text_score: float
    domain_bonus: float


class DemoRetriever:
    """Retrieves top-K similar demonstrations from an index.

    Each demo's score is the index embedder's cosine similarity between the
    task text and the demo's cached instruction embedding. A fixed bonus is
    added when ``app_context`` is a case-insensitive substring of the demo's
    app name or domain.
    """

    def __init__(
        self,
        index: DemoIndex,
        domain_bonus: float = 0.2,
    ) -> None:
        """Initialize the retriever.

        Args:
            index: DemoIndex to retrieve from.
            domain_bonus: Bonus score for domain match (default: 0.2).

        Raises:
            ValueError: If index is empty or not fitted.
        """
        if index.is_empty():
            raise ValueError("Cannot create retriever from empty index")
        if not index.is_fitted():
            raise ValueError(
                "Index must be built before retrieval (call index.build())"
            )

        self.index = index
        self.domain_bonus = domain_bonus

    def _compute_score(
        self,
        task: str,
        demo: DemoMetadata,
        app_context: Optional[str] = None,
        query_embedding=None,
    ) -> RetrievalResult:
        """Compute retrieval score for a demo.

        Args:
            task: Task description to match against.
            demo: Demo metadata to score.
            app_context: Optional app/domain context for bonus. None or an
                empty string means "no context" (no bonus).
            query_embedding: Precomputed embedding of ``task``. Pass it when
                scoring many demos against the same task, so the task is not
                re-embedded per demo; computed here when None.

        Returns:
            RetrievalResult with computed scores.
        """
        if query_embedding is None:
            query_embedding = self.index.embedder.embed(task)
        text_score = self.index.embedder.cosine_similarity(
            query_embedding,
            demo.text_embedding,
        )

        bonus = 0.0
        # An empty context must not match: "" is a substring of every string.
        if app_context:
            context = app_context.lower()
            app_match = bool(demo.app_name) and context in demo.app_name.lower()
            domain_match = bool(demo.domain) and context in demo.domain.lower()
            if app_match or domain_match:
                bonus = self.domain_bonus

        return RetrievalResult(
            demo=demo,
            score=text_score + bonus,
            text_score=text_score,
            domain_bonus=bonus,
        )

    def _rank(
        self,
        task: str,
        app_context: Optional[str],
        top_k: Optional[int],
    ) -> List[RetrievalResult]:
        """Score every indexed demo against ``task`` and return the best ``top_k``.

        Results are sorted by descending score. Ties keep index order because
        the sort is stable. ``top_k <= 0`` yields no results; ``None`` yields
        all demos.

        Raises:
            ValueError: If demos were added to the index after it was built.
        """
        if (top_k is not None and top_k <= 0) or self.index.is_empty():
            return []
        if not self.index.is_fitted():
            # Demos added after build() have no embedding yet and the
            # embedder vocabulary is stale, so scores would be meaningless.
            raise ValueError(
                "Index must be built before retrieval (call index.build())"
            )

        query_embedding = self.index.embedder.embed(task)
        results = [
            self._compute_score(task, demo, app_context, query_embedding)
            for demo in self.index.get_all_demos()
        ]
        results.sort(key=lambda result: result.score, reverse=True)
        return results[:top_k]

    def retrieve(
        self,
        task: str,
        app_context: Optional[str] = None,
        top_k: int = 3,
    ) -> List[Episode]:
        """Retrieve top-K most similar demos.

        Args:
            task: Task description to find demos for.
            app_context: Optional app/domain context (e.g., "Chrome", "github.com").
            top_k: Number of demos to retrieve (<= 0 returns an empty list).

        Returns:
            List of Episode objects, ordered by relevance (most similar first).

        Raises:
            ValueError: If the index was modified after being built.
        """
        return [result.demo.episode for result in self._rank(task, app_context, top_k)]

    def retrieve_with_scores(
        self,
        task: str,
        app_context: Optional[str] = None,
        top_k: int = 3,
    ) -> List[RetrievalResult]:
        """Retrieve top-K demos with their scores.

        Args:
            task: Task description to find demos for.
            app_context: Optional app/domain context.
            top_k: Number of demos to retrieve (<= 0 returns an empty list).

        Returns:
            List of RetrievalResult objects with scores, best first.

        Raises:
            ValueError: If the index was modified after being built.
        """
        return self._rank(task, app_context, top_k)
|
openadapt_ml/runtime/__init__.py
CHANGED
|
@@ -0,0 +1,50 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Runtime module for GUI automation agents.
|
|
3
|
+
|
|
4
|
+
This module provides:
|
|
5
|
+
- AgentPolicy: Runtime policy wrapper for VLM-based action prediction
|
|
6
|
+
- SafetyGate: Deterministic safety checks for action validation
|
|
7
|
+
|
|
8
|
+
Example usage:
|
|
9
|
+
from openadapt_ml.runtime import AgentPolicy, SafetyGate, SafetyConfig, SafetyDecision
|
|
10
|
+
|
|
11
|
+
# Create policy and safety gate
|
|
12
|
+
policy = AgentPolicy(adapter)
|
|
13
|
+
gate = SafetyGate(SafetyConfig(confidence_threshold=0.8))
|
|
14
|
+
|
|
15
|
+
# In agent loop
|
|
16
|
+
action, thought, state, raw = policy.predict_action_from_sample(sample)
|
|
17
|
+
assessment = gate.assess(action, observation)
|
|
18
|
+
|
|
19
|
+
if assessment.decision == SafetyDecision.ALLOW:
|
|
20
|
+
execute(action)
|
|
21
|
+
elif assessment.decision == SafetyDecision.REQUIRE_CONFIRMATION:
|
|
22
|
+
if user_confirms():
|
|
23
|
+
execute(action)
|
|
24
|
+
else: # BLOCK
|
|
25
|
+
log_blocked(assessment.reason)
|
|
26
|
+
"""
|
|
27
|
+
|
|
28
|
+
from openadapt_ml.runtime.policy import (
|
|
29
|
+
AgentPolicy,
|
|
30
|
+
PolicyOutput,
|
|
31
|
+
parse_thought_state_action,
|
|
32
|
+
)
|
|
33
|
+
from openadapt_ml.runtime.safety_gate import (
|
|
34
|
+
SafetyAssessment,
|
|
35
|
+
SafetyConfig,
|
|
36
|
+
SafetyDecision,
|
|
37
|
+
SafetyGate,
|
|
38
|
+
)
|
|
39
|
+
|
|
40
|
+
__all__ = [
|
|
41
|
+
# Policy
|
|
42
|
+
"AgentPolicy",
|
|
43
|
+
"PolicyOutput",
|
|
44
|
+
"parse_thought_state_action",
|
|
45
|
+
# Safety Gate
|
|
46
|
+
"SafetyGate",
|
|
47
|
+
"SafetyConfig",
|
|
48
|
+
"SafetyDecision",
|
|
49
|
+
"SafetyAssessment",
|
|
50
|
+
]
|
openadapt_ml/runtime/policy.py
CHANGED
|
@@ -3,16 +3,16 @@ from __future__ import annotations
|
|
|
3
3
|
import json
|
|
4
4
|
import re
|
|
5
5
|
from dataclasses import dataclass
|
|
6
|
-
from typing import Any, Dict,
|
|
6
|
+
from typing import Any, Dict, Optional, Tuple
|
|
7
7
|
|
|
8
8
|
from PIL import Image
|
|
9
9
|
|
|
10
10
|
from openadapt_ml.models.base_adapter import BaseVLMAdapter
|
|
11
|
-
from openadapt_ml.
|
|
11
|
+
from openadapt_ml.schema import Action, ActionType, UIElement
|
|
12
12
|
|
|
13
13
|
|
|
14
14
|
# Coordinate-based DSL patterns
|
|
15
|
-
_CLICK_RE = re.compile(r"CLICK\(x=([0-9]*\.?[0-9]+),\s*y=([0-9]*\.?[0-9]+)\)")
|
|
15
|
+
_CLICK_RE = re.compile(r"CLICK\(x=(-?[0-9]*\.?[0-9]+),\s*y=(-?[0-9]*\.?[0-9]+)\)")
|
|
16
16
|
_TYPE_RE = re.compile(r'TYPE\(text="([^"\\]*(?:\\.[^"\\]*)*)"\)')
|
|
17
17
|
_WAIT_RE = re.compile(r"\bWAIT\s*\(\s*\)")
|
|
18
18
|
_DONE_RE = re.compile(r"\bDONE\s*\(\s*\)")
|
|
@@ -26,13 +26,16 @@ _TYPE_SOM_SIMPLE_RE = re.compile(r'TYPE\(["\']([^"\']*(?:\\.[^"\']*)*)["\']\)')
|
|
|
26
26
|
@dataclass
class PolicyOutput:
    """Outcome of one policy step.

    Bundles the action to take with any optional reasoning/state the model
    produced and the raw generated text it was parsed from.
    """

    # The action selected for this step (required).
    action: Action
    # Optional free-text reasoning accompanying the action.
    thought: Optional[str] = None
    # Optional structured state dict; None when absent.
    state: Optional[Dict[str, Any]] = None
    # Unparsed model output; empty string by default.
    raw_text: str = ""
|
|
33
34
|
|
|
34
35
|
|
|
35
|
-
def parse_thought_state_action(
|
|
36
|
+
def parse_thought_state_action(
|
|
37
|
+
text: str,
|
|
38
|
+
) -> Tuple[Optional[str], Optional[Dict[str, Any]], str]:
|
|
36
39
|
"""Parse Thought / State / Action blocks from model output.
|
|
37
40
|
|
|
38
41
|
Expected format:
|
|
@@ -54,12 +57,18 @@ def parse_thought_state_action(text: str) -> Tuple[Optional[str], Optional[Dict[
|
|
|
54
57
|
action_str: str = text.strip()
|
|
55
58
|
|
|
56
59
|
# Extract Thought - find the LAST occurrence (model's response, not template)
|
|
57
|
-
thought_matches = list(
|
|
60
|
+
thought_matches = list(
|
|
61
|
+
re.finditer(
|
|
62
|
+
r"Thought:\s*(.+?)(?=State:|Action:|$)", text, re.DOTALL | re.IGNORECASE
|
|
63
|
+
)
|
|
64
|
+
)
|
|
58
65
|
if thought_matches:
|
|
59
66
|
thought = thought_matches[-1].group(1).strip()
|
|
60
67
|
|
|
61
68
|
# Extract State (JSON on same line or next line) - last occurrence
|
|
62
|
-
state_matches = list(
|
|
69
|
+
state_matches = list(
|
|
70
|
+
re.finditer(r"State:\s*(\{.*?\})", text, re.DOTALL | re.IGNORECASE)
|
|
71
|
+
)
|
|
63
72
|
if state_matches:
|
|
64
73
|
try:
|
|
65
74
|
state = json.loads(state_matches[-1].group(1))
|
|
@@ -119,7 +128,7 @@ class AgentPolicy:
|
|
|
119
128
|
m = _CLICK_SOM_RE.search(text)
|
|
120
129
|
if m:
|
|
121
130
|
idx = int(m.group(1))
|
|
122
|
-
return Action(type=
|
|
131
|
+
return Action(type=ActionType.CLICK, element=UIElement(element_id=str(idx)))
|
|
123
132
|
|
|
124
133
|
# TYPE([N], "text")
|
|
125
134
|
m = _TYPE_SOM_RE.search(text)
|
|
@@ -127,14 +136,18 @@ class AgentPolicy:
|
|
|
127
136
|
idx = int(m.group(1))
|
|
128
137
|
raw_text = m.group(2)
|
|
129
138
|
unescaped = raw_text.replace('\\"', '"').replace("\\\\", "\\")
|
|
130
|
-
return Action(
|
|
139
|
+
return Action(
|
|
140
|
+
type=ActionType.TYPE,
|
|
141
|
+
text=unescaped,
|
|
142
|
+
element=UIElement(element_id=str(idx)),
|
|
143
|
+
)
|
|
131
144
|
|
|
132
145
|
# TYPE("text") - SoM style without index
|
|
133
146
|
m = _TYPE_SOM_SIMPLE_RE.search(text)
|
|
134
147
|
if m:
|
|
135
148
|
raw_text = m.group(1)
|
|
136
149
|
unescaped = raw_text.replace('\\"', '"').replace("\\\\", "\\")
|
|
137
|
-
return Action(type=
|
|
150
|
+
return Action(type=ActionType.TYPE, text=unescaped)
|
|
138
151
|
|
|
139
152
|
# Coordinate-based patterns
|
|
140
153
|
# CLICK(x=..., y=...)
|
|
@@ -145,7 +158,7 @@ class AgentPolicy:
|
|
|
145
158
|
# Clamp to [0, 1]
|
|
146
159
|
x = max(0.0, min(1.0, x))
|
|
147
160
|
y = max(0.0, min(1.0, y))
|
|
148
|
-
return Action(type=
|
|
161
|
+
return Action(type=ActionType.CLICK, normalized_coordinates=(x, y))
|
|
149
162
|
|
|
150
163
|
# TYPE(text="...")
|
|
151
164
|
m = _TYPE_RE.search(text)
|
|
@@ -153,18 +166,18 @@ class AgentPolicy:
|
|
|
153
166
|
# Unescape the text content
|
|
154
167
|
raw_text = m.group(1)
|
|
155
168
|
unescaped = raw_text.replace('\\"', '"').replace("\\\\", "\\")
|
|
156
|
-
return Action(type=
|
|
169
|
+
return Action(type=ActionType.TYPE, text=unescaped)
|
|
157
170
|
|
|
158
171
|
# WAIT()
|
|
159
172
|
if _WAIT_RE.search(text):
|
|
160
|
-
return Action(type=
|
|
173
|
+
return Action(type=ActionType.WAIT)
|
|
161
174
|
|
|
162
175
|
# DONE()
|
|
163
176
|
if _DONE_RE.search(text):
|
|
164
|
-
return Action(type=
|
|
177
|
+
return Action(type=ActionType.DONE)
|
|
165
178
|
|
|
166
179
|
# Fallback
|
|
167
|
-
return Action(type=
|
|
180
|
+
return Action(type=ActionType.FAIL, raw={"text": text})
|
|
168
181
|
|
|
169
182
|
def predict_action_from_sample(
|
|
170
183
|
self, sample: Dict[str, Any], max_new_tokens: int = 150
|