openadapt-ml 0.1.0__py3-none-any.whl → 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- openadapt_ml/baselines/__init__.py +121 -0
- openadapt_ml/baselines/adapter.py +185 -0
- openadapt_ml/baselines/cli.py +314 -0
- openadapt_ml/baselines/config.py +448 -0
- openadapt_ml/baselines/parser.py +922 -0
- openadapt_ml/baselines/prompts.py +787 -0
- openadapt_ml/benchmarks/__init__.py +13 -107
- openadapt_ml/benchmarks/agent.py +297 -374
- openadapt_ml/benchmarks/azure.py +62 -24
- openadapt_ml/benchmarks/azure_ops_tracker.py +521 -0
- openadapt_ml/benchmarks/cli.py +1874 -751
- openadapt_ml/benchmarks/trace_export.py +631 -0
- openadapt_ml/benchmarks/viewer.py +1236 -0
- openadapt_ml/benchmarks/vm_monitor.py +1111 -0
- openadapt_ml/benchmarks/waa_deploy/Dockerfile +216 -0
- openadapt_ml/benchmarks/waa_deploy/__init__.py +10 -0
- openadapt_ml/benchmarks/waa_deploy/api_agent.py +540 -0
- openadapt_ml/benchmarks/waa_deploy/start_waa_server.bat +53 -0
- openadapt_ml/cloud/azure_inference.py +3 -5
- openadapt_ml/cloud/lambda_labs.py +722 -307
- openadapt_ml/cloud/local.py +3194 -89
- openadapt_ml/cloud/ssh_tunnel.py +595 -0
- openadapt_ml/datasets/next_action.py +125 -96
- openadapt_ml/evals/grounding.py +32 -9
- openadapt_ml/evals/plot_eval_metrics.py +15 -13
- openadapt_ml/evals/trajectory_matching.py +120 -57
- openadapt_ml/experiments/demo_prompt/__init__.py +19 -0
- openadapt_ml/experiments/demo_prompt/format_demo.py +236 -0
- openadapt_ml/experiments/demo_prompt/results/experiment_20251231_002125.json +83 -0
- openadapt_ml/experiments/demo_prompt/results/experiment_n30_20251231_165958.json +1100 -0
- openadapt_ml/experiments/demo_prompt/results/multistep_20251231_025051.json +182 -0
- openadapt_ml/experiments/demo_prompt/run_experiment.py +541 -0
- openadapt_ml/experiments/representation_shootout/__init__.py +70 -0
- openadapt_ml/experiments/representation_shootout/conditions.py +708 -0
- openadapt_ml/experiments/representation_shootout/config.py +390 -0
- openadapt_ml/experiments/representation_shootout/evaluator.py +659 -0
- openadapt_ml/experiments/representation_shootout/runner.py +687 -0
- openadapt_ml/experiments/waa_demo/__init__.py +10 -0
- openadapt_ml/experiments/waa_demo/demos.py +357 -0
- openadapt_ml/experiments/waa_demo/runner.py +732 -0
- openadapt_ml/experiments/waa_demo/tasks.py +151 -0
- openadapt_ml/export/__init__.py +9 -0
- openadapt_ml/export/__main__.py +6 -0
- openadapt_ml/export/cli.py +89 -0
- openadapt_ml/export/parquet.py +277 -0
- openadapt_ml/grounding/detector.py +18 -14
- openadapt_ml/ingest/__init__.py +11 -10
- openadapt_ml/ingest/capture.py +97 -86
- openadapt_ml/ingest/loader.py +120 -69
- openadapt_ml/ingest/synthetic.py +344 -193
- openadapt_ml/models/api_adapter.py +14 -4
- openadapt_ml/models/base_adapter.py +10 -2
- openadapt_ml/models/providers/__init__.py +288 -0
- openadapt_ml/models/providers/anthropic.py +266 -0
- openadapt_ml/models/providers/base.py +299 -0
- openadapt_ml/models/providers/google.py +376 -0
- openadapt_ml/models/providers/openai.py +342 -0
- openadapt_ml/models/qwen_vl.py +46 -19
- openadapt_ml/perception/__init__.py +35 -0
- openadapt_ml/perception/integration.py +399 -0
- openadapt_ml/retrieval/README.md +226 -0
- openadapt_ml/retrieval/USAGE.md +391 -0
- openadapt_ml/retrieval/__init__.py +91 -0
- openadapt_ml/retrieval/demo_retriever.py +843 -0
- openadapt_ml/retrieval/embeddings.py +630 -0
- openadapt_ml/retrieval/index.py +194 -0
- openadapt_ml/retrieval/retriever.py +162 -0
- openadapt_ml/runtime/__init__.py +50 -0
- openadapt_ml/runtime/policy.py +27 -14
- openadapt_ml/runtime/safety_gate.py +471 -0
- openadapt_ml/schema/__init__.py +113 -0
- openadapt_ml/schema/converters.py +588 -0
- openadapt_ml/schema/episode.py +470 -0
- openadapt_ml/scripts/capture_screenshots.py +530 -0
- openadapt_ml/scripts/compare.py +102 -61
- openadapt_ml/scripts/demo_policy.py +4 -1
- openadapt_ml/scripts/eval_policy.py +19 -14
- openadapt_ml/scripts/make_gif.py +1 -1
- openadapt_ml/scripts/prepare_synthetic.py +16 -17
- openadapt_ml/scripts/train.py +98 -75
- openadapt_ml/segmentation/README.md +920 -0
- openadapt_ml/segmentation/__init__.py +97 -0
- openadapt_ml/segmentation/adapters/__init__.py +5 -0
- openadapt_ml/segmentation/adapters/capture_adapter.py +420 -0
- openadapt_ml/segmentation/annotator.py +610 -0
- openadapt_ml/segmentation/cache.py +290 -0
- openadapt_ml/segmentation/cli.py +674 -0
- openadapt_ml/segmentation/deduplicator.py +656 -0
- openadapt_ml/segmentation/frame_describer.py +788 -0
- openadapt_ml/segmentation/pipeline.py +340 -0
- openadapt_ml/segmentation/schemas.py +622 -0
- openadapt_ml/segmentation/segment_extractor.py +634 -0
- openadapt_ml/training/azure_ops_viewer.py +1097 -0
- openadapt_ml/training/benchmark_viewer.py +3255 -19
- openadapt_ml/training/shared_ui.py +7 -7
- openadapt_ml/training/stub_provider.py +57 -35
- openadapt_ml/training/trainer.py +255 -441
- openadapt_ml/training/trl_trainer.py +403 -0
- openadapt_ml/training/viewer.py +323 -108
- openadapt_ml/training/viewer_components.py +180 -0
- {openadapt_ml-0.1.0.dist-info → openadapt_ml-0.2.1.dist-info}/METADATA +312 -69
- openadapt_ml-0.2.1.dist-info/RECORD +116 -0
- openadapt_ml/benchmarks/base.py +0 -366
- openadapt_ml/benchmarks/data_collection.py +0 -432
- openadapt_ml/benchmarks/runner.py +0 -381
- openadapt_ml/benchmarks/waa.py +0 -704
- openadapt_ml/schemas/__init__.py +0 -53
- openadapt_ml/schemas/sessions.py +0 -122
- openadapt_ml/schemas/validation.py +0 -252
- openadapt_ml-0.1.0.dist-info/RECORD +0 -55
- {openadapt_ml-0.1.0.dist-info → openadapt_ml-0.2.1.dist-info}/WHEEL +0 -0
- {openadapt_ml-0.1.0.dist-info → openadapt_ml-0.2.1.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,471 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Safety Gate Module - Deterministic safety checks for GUI automation actions.
|
|
3
|
+
|
|
4
|
+
The Safety Gate runs AFTER the policy decides but BEFORE execution, providing
|
|
5
|
+
a critical safety layer that prevents destructive or irreversible actions
|
|
6
|
+
without explicit human confirmation.
|
|
7
|
+
|
|
8
|
+
Example Usage:
|
|
9
|
+
from openadapt_ml.runtime import SafetyGate, SafetyConfig, SafetyDecision
|
|
10
|
+
|
|
11
|
+
# Create with default config
|
|
12
|
+
gate = SafetyGate()
|
|
13
|
+
|
|
14
|
+
# Or customize
|
|
15
|
+
config = SafetyConfig(
|
|
16
|
+
confidence_threshold=0.8,
|
|
17
|
+
loop_threshold=2,
|
|
18
|
+
expected_app="Chrome",
|
|
19
|
+
)
|
|
20
|
+
gate = SafetyGate(config)
|
|
21
|
+
|
|
22
|
+
# Evaluate action
|
|
23
|
+
assessment = gate.assess(action, observation, trace=history)
|
|
24
|
+
|
|
25
|
+
if assessment.decision == SafetyDecision.ALLOW:
|
|
26
|
+
execute(action)
|
|
27
|
+
elif assessment.decision == SafetyDecision.REQUIRE_CONFIRMATION:
|
|
28
|
+
if user_confirms(assessment.reason):
|
|
29
|
+
execute(action)
|
|
30
|
+
else: # BLOCK
|
|
31
|
+
log_warning(assessment.reason)
|
|
32
|
+
"""
|
|
33
|
+
|
|
34
|
+
from __future__ import annotations
|
|
35
|
+
|
|
36
|
+
import hashlib
|
|
37
|
+
import re
|
|
38
|
+
from dataclasses import dataclass, field
|
|
39
|
+
from enum import Enum
|
|
40
|
+
from typing import Optional
|
|
41
|
+
|
|
42
|
+
from openadapt_ml.schema import Action, ActionType, Observation, Step
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class SafetyDecision(str, Enum):
    """Verdict the safety gate returns for a proposed action.

    The enum also derives from ``str``, so every member compares equal to
    its lowercase string value (e.g. ``SafetyDecision.ALLOW == "allow"``).
    """

    # The action can proceed without intervention.
    ALLOW = "allow"
    # The action must not proceed under any circumstances.
    BLOCK = "block"
    # A human must approve the action before it runs.
    REQUIRE_CONFIRMATION = "require_confirmation"
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
@dataclass
class SafetyAssessment:
    """Outcome of evaluating one action against the safety rules.

    Attributes:
        decision: The verdict (ALLOW, BLOCK, or REQUIRE_CONFIRMATION).
        reason: Human-readable explanation of the verdict.
        triggered_rules: Names of the safety rules that produced the verdict.
        confidence: How certain the gate is about the verdict
            (1.0 = certain); defaults to 1.0.
    """

    decision: SafetyDecision
    reason: str
    triggered_rules: list[str]
    confidence: float = 1.0

    def __str__(self) -> str:
        """Return a compact ``SafetyAssessment(<decision>: <reason>)`` summary."""
        verdict = self.decision.value
        return "SafetyAssessment({}: {})".format(verdict, self.reason)
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
# Regex sources used as the default for ``SafetyConfig.blocklist_patterns``.
# An action whose text matches any of these is blocked outright.
DEFAULT_BLOCKLIST_PATTERNS = [
    # Keywords associated with destroying files or data
    r"\bdelete\b",
    r"\bremove\b",
    r"\bformat\b",
    r"\breset\b",
    r"\bbroadcast\b",
    # Destructive SQL
    r"\bdrop\s+table\b",
    r"\btruncate\b",
    # Destructive shell commands
    r"\brm\s+-rf\b",
    r"\bsudo\s+rm\b",
]

# Regex sources used as the default for ``SafetyConfig.irreversible_patterns``.
# A match means the action needs human confirmation before it runs.
DEFAULT_IRREVERSIBLE_PATTERNS = [
    # Submitting or committing something
    r"\bsubmit\b",
    r"\bsend\b",
    r"\bapply\b",
    r"\bconfirm\b",
    # Closing a document (changes may be unsaved)
    r"\bclos(?:e|ing)\b",
    # Spending money
    r"\bpurchase\b",
    r"\bcheckout\b",
    r"\bpay\b",
    r"\bbuy\b",
    r"\border\b",
]

# Regex sources used as the default for ``SafetyConfig.credential_patterns``.
# Typing into a field whose name matches one of these needs confirmation.
# These are unanchored substrings (no ``\b``), so e.g. "auth" also matches
# inside longer words.
DEFAULT_CREDENTIAL_PATTERNS = [
    r"password",
    r"token",
    r"secret",
    r"api[_-]?key",
    r"apikey",
    r"credential",
    r"auth",
    r"private[_-]?key",
]
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
@dataclass
class SafetyConfig:
    """Tunable settings that control how the safety gate behaves.

    Every pattern list holds regular-expression source strings; the gate
    compiles them case-insensitively.

    Attributes:
        blocklist_patterns: Patterns that trigger a BLOCK decision.
        irreversible_patterns: Patterns that trigger REQUIRE_CONFIRMATION.
        confidence_threshold: Actions whose confidence is below this value
            require confirmation.
        loop_threshold: Visiting the same state this many times triggers
            BLOCK.
        credential_patterns: Field-name patterns that mark sensitive input.
        credential_allowlist: Override patterns that permit typing into
            specific fields.
        expected_app: Expected application name (None = don't check).
        expected_window_pattern: Regex the window title is expected to
            match (None = don't check).
    """

    # Each list default is a fresh copy, so instances never share (or mutate)
    # the module-level default lists.
    blocklist_patterns: list[str] = field(
        default_factory=lambda: list(DEFAULT_BLOCKLIST_PATTERNS)
    )
    irreversible_patterns: list[str] = field(
        default_factory=lambda: list(DEFAULT_IRREVERSIBLE_PATTERNS)
    )
    confidence_threshold: float = 0.7
    loop_threshold: int = 3
    credential_patterns: list[str] = field(
        default_factory=lambda: list(DEFAULT_CREDENTIAL_PATTERNS)
    )
    credential_allowlist: list[str] = field(default_factory=list)
    expected_app: Optional[str] = None
    expected_window_pattern: Optional[str] = None
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
class SafetyGate:
    """Deterministic safety gate for GUI automation actions.

    The SafetyGate evaluates proposed actions against a set of safety rules
    and returns a SafetyAssessment indicating whether the action should be
    allowed, blocked, or require human confirmation.

    Safety checks are evaluated in priority order:
    1. Blocklist (immediate BLOCK)
    2. Loop Detection (BLOCK if loop detected)
    3. Credential Guard (REQUIRE_CONFIRMATION if typing to sensitive field)
    4. Irreversibility (REQUIRE_CONFIRMATION for irreversible actions)
    5. App/Window Mismatch (REQUIRE_CONFIRMATION if context mismatch)
    6. Confidence Threshold (REQUIRE_CONFIRMATION if low confidence)
    7. Default (ALLOW if no rules triggered)

    The gate keeps per-state visit counts for loop detection; call
    ``reset()`` between episodes to clear them.
    """

    def __init__(self, config: Optional[SafetyConfig] = None) -> None:
        """Initialize the safety gate with optional custom configuration.

        Args:
            config: Safety configuration. If None, uses default SafetyConfig.
        """
        self.config = config if config is not None else SafetyConfig()
        # Maps state hash -> number of times assess() reached loop detection
        # with that state.
        self._state_visit_counts: dict[str, int] = {}

        # Compile every pattern list once so assess() only runs searches.
        self._blocklist_re = self._compile_patterns(self.config.blocklist_patterns)
        self._irreversible_re = self._compile_patterns(
            self.config.irreversible_patterns
        )
        self._credential_re = self._compile_patterns(self.config.credential_patterns)
        self._credential_allowlist_re = self._compile_patterns(
            self.config.credential_allowlist
        )
        self._window_pattern_re = (
            re.compile(self.config.expected_window_pattern, re.IGNORECASE)
            if self.config.expected_window_pattern
            else None
        )

    def _compile_patterns(self, patterns: list[str]) -> list[re.Pattern]:
        """Compile a list of regex patterns (case-insensitive)."""
        return [re.compile(source, re.IGNORECASE) for source in patterns]

    def _matches_any(self, text: str, patterns: list[re.Pattern]) -> Optional[str]:
        """Return the source of the first pattern found in ``text``, else None."""
        for pattern in patterns:
            if pattern.search(text):
                return pattern.pattern
        return None

    def _compute_state_hash(self, observation: Observation) -> str:
        """Compute a short hash identifying the current state for loop detection.

        The state is identified by window title, app name, and URL only
        (missing values count as empty strings), so no image comparison is
        needed.

        Returns:
            The first 16 hex characters of the SHA-256 of the joined fields.
        """
        state_str = "|".join(
            (
                observation.window_title or "",
                observation.app_name or "",
                observation.url or "",
            )
        )
        return hashlib.sha256(state_str.encode()).hexdigest()[:16]

    def _get_action_text(self, action: Action) -> str:
        """Collect the action's textual content into one string for matching.

        Includes, in order: typed text, key presses, URL, app name, window
        title, and every string value found at the top level of
        ``action.raw``. Parts are joined with single spaces.
        """
        parts = [
            value
            for value in (
                action.text,
                action.key,
                action.url,
                action.app_name,
                action.window_title,
            )
            if value
        ]
        if action.raw:
            # Only top-level string values are considered; nested or
            # non-string values are ignored.
            parts.extend(v for v in action.raw.values() if isinstance(v, str))
        return " ".join(parts)

    def _get_target_field_name(self, action: Action, observation: Observation) -> str:
        """Describe the field targeted by a TYPE action.

        Joins the name, role, and automation id of the action's target
        element, followed by the same attributes of the observation's
        focused element. Missing elements and empty attributes are skipped.

        Returns:
            The space-joined description, or "" when nothing is available.
        """
        parts = []
        for element in (action.element, observation.focused_element):
            if not element:
                continue
            for attr in ("name", "role", "automation_id"):
                value = getattr(element, attr)
                if value:
                    parts.append(value)
        return " ".join(parts)

    def _get_action_confidence(self, action: Action) -> Optional[float]:
        """Extract the confidence score from the action's raw metadata.

        Returns:
            ``action.raw["confidence"]`` converted to float, or None when
            the key is absent or the value cannot be converted.
        """
        if action.raw and "confidence" in action.raw:
            try:
                return float(action.raw["confidence"])
            except (TypeError, ValueError):
                # Unparseable confidence is treated as "not reported".
                pass
        return None

    @staticmethod
    def _is_below_threshold(confidence: float, threshold: float) -> bool:
        """Return True unless ``confidence`` is at or above ``threshold``.

        Written as a negated ``>=`` rather than ``<`` so that NaN, which
        compares False against every number, is treated as below the
        threshold instead of slipping through the gate.
        """
        return not confidence >= threshold

    def _check_blocklist(self, action: Action) -> Optional[SafetyAssessment]:
        """Return a BLOCK assessment if the action text hits the blocklist."""
        matched = self._matches_any(self._get_action_text(action), self._blocklist_re)
        if matched:
            return SafetyAssessment(
                decision=SafetyDecision.BLOCK,
                reason=f"Action contains blocked keyword pattern: {matched}",
                triggered_rules=["blocklist"],
                confidence=1.0,
            )
        return None

    def _check_loop_detection(
        self, observation: Observation
    ) -> Optional[SafetyAssessment]:
        """Record a visit to the current state and BLOCK on repeated visits.

        Every call increments the visit count for the observation's state
        hash; a BLOCK is returned once the count reaches
        ``config.loop_threshold``.
        """
        state_hash = self._compute_state_hash(observation)
        visit_count = self._state_visit_counts.get(state_hash, 0) + 1
        self._state_visit_counts[state_hash] = visit_count

        if visit_count >= self.config.loop_threshold:
            return SafetyAssessment(
                decision=SafetyDecision.BLOCK,
                reason=f"Loop detected: same state visited {visit_count} times "
                f"(threshold: {self.config.loop_threshold})",
                triggered_rules=["loop_detection"],
                confidence=1.0,
            )
        return None

    def _check_credential_guard(
        self, action: Action, observation: Observation
    ) -> Optional[SafetyAssessment]:
        """Require confirmation when a TYPE action targets a sensitive field.

        Non-TYPE actions and actions with no identifiable target field are
        not checked. A match against the credential allowlist takes
        precedence over the credential patterns.
        """
        if action.type != ActionType.TYPE:
            return None

        field_name = self._get_target_field_name(action, observation)
        if not field_name:
            return None

        # Allowlist overrides the credential patterns.
        if self._matches_any(field_name, self._credential_allowlist_re):
            return None

        matched = self._matches_any(field_name, self._credential_re)
        if matched:
            return SafetyAssessment(
                decision=SafetyDecision.REQUIRE_CONFIRMATION,
                reason=f"Typing into credential field matching '{matched}': {field_name}",
                triggered_rules=["credential_guard"],
                confidence=1.0,
            )
        return None

    def _check_irreversibility(self, action: Action) -> Optional[SafetyAssessment]:
        """Require confirmation if the action text looks irreversible."""
        matched = self._matches_any(
            self._get_action_text(action), self._irreversible_re
        )
        if matched:
            return SafetyAssessment(
                decision=SafetyDecision.REQUIRE_CONFIRMATION,
                reason=f"Action may be irreversible (matched pattern: {matched})",
                triggered_rules=["irreversibility"],
                confidence=0.9,
            )
        return None

    def _check_app_window_mismatch(
        self, observation: Observation
    ) -> Optional[SafetyAssessment]:
        """Require confirmation if the current app/window is not the expected one.

        The app name is compared case-insensitively for equality with
        ``config.expected_app``; the window title is searched with the
        compiled ``config.expected_window_pattern``. Each check is skipped
        when its expectation is not configured. Both mismatches can be
        reported in a single assessment.
        """
        triggered_rules = []
        reasons = []

        if self.config.expected_app:
            current_app = observation.app_name or ""
            if current_app.lower() != self.config.expected_app.lower():
                triggered_rules.append("app_mismatch")
                reasons.append(
                    f"Expected app '{self.config.expected_app}', got '{current_app}'"
                )

        if self._window_pattern_re:
            current_title = observation.window_title or ""
            if not self._window_pattern_re.search(current_title):
                triggered_rules.append("window_mismatch")
                reasons.append(
                    f"Window title '{current_title}' doesn't match pattern "
                    f"'{self.config.expected_window_pattern}'"
                )

        if not triggered_rules:
            return None
        return SafetyAssessment(
            decision=SafetyDecision.REQUIRE_CONFIRMATION,
            reason="; ".join(reasons),
            triggered_rules=triggered_rules,
            confidence=0.95,
        )

    def _check_confidence_threshold(self, action: Action) -> Optional[SafetyAssessment]:
        """Require confirmation if the reported confidence is below threshold.

        Actions that report no (or an unparseable) confidence are not
        checked. A NaN confidence is treated as below the threshold.
        """
        confidence = self._get_action_confidence(action)
        if confidence is None:
            return None

        if self._is_below_threshold(confidence, self.config.confidence_threshold):
            return SafetyAssessment(
                decision=SafetyDecision.REQUIRE_CONFIRMATION,
                reason=f"Action confidence ({confidence:.2f}) below threshold "
                f"({self.config.confidence_threshold})",
                triggered_rules=["confidence_threshold"],
                confidence=1.0,
            )
        return None

    def assess(
        self,
        action: Action,
        observation: Observation,
        trace: Optional[list[Step]] = None,
    ) -> SafetyAssessment:
        """Evaluate an action against all safety rules.

        Checks are evaluated in priority order. First triggered rule wins,
        and later checks are not run at all (so a blocklisted action does
        not count as a state visit for loop detection).

        Args:
            action: The action proposed by the policy
            observation: Current state observation
            trace: Execution history (currently unused, reserved for future use)

        Returns:
            SafetyAssessment with decision and reasoning
        """
        # Lambdas keep each check lazy: nothing runs until its turn comes.
        checks = (
            # Priority 1: Blocklist (BLOCK)
            lambda: self._check_blocklist(action),
            # Priority 2: Loop Detection (BLOCK)
            lambda: self._check_loop_detection(observation),
            # Priority 3: Credential Guard (REQUIRE_CONFIRMATION)
            lambda: self._check_credential_guard(action, observation),
            # Priority 4: Irreversibility (REQUIRE_CONFIRMATION)
            lambda: self._check_irreversibility(action),
            # Priority 5: App/Window Mismatch (REQUIRE_CONFIRMATION)
            lambda: self._check_app_window_mismatch(observation),
            # Priority 6: Confidence Threshold (REQUIRE_CONFIRMATION)
            lambda: self._check_confidence_threshold(action),
        )
        for run_check in checks:
            assessment = run_check()
            if assessment is not None:
                return assessment

        # Default: ALLOW
        return SafetyAssessment(
            decision=SafetyDecision.ALLOW,
            reason="All safety checks passed",
            triggered_rules=[],
            confidence=1.0,
        )

    def reset(self) -> None:
        """Clear internal state (loop detection history).

        Call this method between episodes to reset the state visit counts.
        """
        self._state_visit_counts.clear()

    def get_state_visit_counts(self) -> dict[str, int]:
        """Return a copy of the state visit counts (for debugging/monitoring)."""
        return self._state_visit_counts.copy()
|
|
@@ -0,0 +1,113 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Episode Schema - Canonical format for GUI trajectory data.
|
|
3
|
+
|
|
4
|
+
A standardized contract for representing GUI automation episodes, enabling
|
|
5
|
+
interoperability across training pipelines, benchmarks, and recording tools.
|
|
6
|
+
|
|
7
|
+
Installation:
|
|
8
|
+
pip install openadapt-ml
|
|
9
|
+
# or: uv add openadapt-ml
|
|
10
|
+
|
|
11
|
+
Basic Usage:
|
|
12
|
+
from openadapt_ml.schema import Episode, Step, Action, Observation, ActionType, load_episode, save_episode, validate_episode
|
|
13
|
+
|
|
14
|
+
# Create an episode
|
|
15
|
+
episode = Episode(
|
|
16
|
+
episode_id="demo_001",
|
|
17
|
+
instruction="Open Notepad and type Hello World",
|
|
18
|
+
steps=[
|
|
19
|
+
Step(
|
|
20
|
+
step_index=0,
|
|
21
|
+
observation=Observation(screenshot_path="step_0.png"),
|
|
22
|
+
action=Action(type=ActionType.CLICK, coordinates={"x": 100, "y": 200}),
|
|
23
|
+
),
|
|
24
|
+
Step(
|
|
25
|
+
step_index=1,
|
|
26
|
+
observation=Observation(screenshot_path="step_1.png"),
|
|
27
|
+
action=Action(type=ActionType.TYPE, text="Hello World"),
|
|
28
|
+
),
|
|
29
|
+
],
|
|
30
|
+
success=True,
|
|
31
|
+
)
|
|
32
|
+
|
|
33
|
+
# Save/load JSON
|
|
34
|
+
save_episode(episode, "episode.json")
|
|
35
|
+
episode = load_episode("episode.json")
|
|
36
|
+
|
|
37
|
+
# Validate external data
|
|
38
|
+
is_valid, error = validate_episode({"episode_id": "x", ...})
|
|
39
|
+
|
|
40
|
+
Coordinate Systems:
|
|
41
|
+
# Pixel coordinates (absolute)
|
|
42
|
+
Action(type=ActionType.CLICK, coordinates={"x": 512, "y": 384})
|
|
43
|
+
|
|
44
|
+
# Normalized coordinates (0.0-1.0, resolution-independent)
|
|
45
|
+
Action(type=ActionType.CLICK, normalized_coordinates=(0.5, 0.375))
|
|
46
|
+
|
|
47
|
+
# Both can coexist - use whichever fits your pipeline
|
|
48
|
+
|
|
49
|
+
Converting from Other Formats:
|
|
50
|
+
from openadapt_ml.schema.converters import from_waa_trajectory, to_waa_trajectory
|
|
51
|
+
|
|
52
|
+
# Convert Windows Agent Arena format
|
|
53
|
+
episode = from_waa_trajectory(trajectory_list, task_info_dict)
|
|
54
|
+
|
|
55
|
+
# Convert back
|
|
56
|
+
trajectory, task_info = to_waa_trajectory(episode)
|
|
57
|
+
|
|
58
|
+
JSON Schema Export:
|
|
59
|
+
# For external validation tools (e.g., JSON Schema validators, TypeScript codegen)
|
|
60
|
+
export_json_schema("episode.schema.json")
|
|
61
|
+
|
|
62
|
+
See Also:
|
|
63
|
+
- docs/schema/episode.schema.json - Full JSON Schema
|
|
64
|
+
- openadapt_ml.schema.episode - Model definitions
|
|
65
|
+
- openadapt_ml.schema.converters - Format converters
|
|
66
|
+
"""
|
|
67
|
+
|
|
68
|
+
from openadapt_ml.schema.episode import (
|
|
69
|
+
SCHEMA_VERSION,
|
|
70
|
+
Episode,
|
|
71
|
+
Step,
|
|
72
|
+
Action,
|
|
73
|
+
Observation,
|
|
74
|
+
ActionType,
|
|
75
|
+
BenchmarkSource,
|
|
76
|
+
Coordinates,
|
|
77
|
+
BoundingBox,
|
|
78
|
+
UIElement,
|
|
79
|
+
validate_episode,
|
|
80
|
+
load_episode,
|
|
81
|
+
save_episode,
|
|
82
|
+
export_json_schema,
|
|
83
|
+
)
|
|
84
|
+
|
|
85
|
+
# Optional perception integration (requires openadapt-grounding).
try:
    from openadapt_ml.perception.integration import UIElementGraph
except ImportError:
    # The import failed (openadapt-grounding not installed): keep the name
    # defined, bound to None, so importing it from this package still works.
    UIElementGraph = None  # type: ignore

# Public API of the package, grouped by role.
__all__ = [
    # Schema version
    "SCHEMA_VERSION",
    # Core models
    "Episode",
    "Step",
    "Action",
    "Observation",
    # Supporting models
    "ActionType",
    "BenchmarkSource",
    "Coordinates",
    "BoundingBox",
    "UIElement",
    # Perception integration (None when the optional import failed)
    "UIElementGraph",
    # Utilities
    "validate_episode",
    "load_episode",
    "save_episode",
    "export_json_schema",
]
|