openadapt-ml 0.2.0__py3-none-any.whl → 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- openadapt_ml/baselines/__init__.py +121 -0
- openadapt_ml/baselines/adapter.py +185 -0
- openadapt_ml/baselines/cli.py +314 -0
- openadapt_ml/baselines/config.py +448 -0
- openadapt_ml/baselines/parser.py +922 -0
- openadapt_ml/baselines/prompts.py +787 -0
- openadapt_ml/benchmarks/__init__.py +13 -115
- openadapt_ml/benchmarks/agent.py +265 -421
- openadapt_ml/benchmarks/azure.py +28 -19
- openadapt_ml/benchmarks/azure_ops_tracker.py +521 -0
- openadapt_ml/benchmarks/cli.py +1722 -4847
- openadapt_ml/benchmarks/trace_export.py +631 -0
- openadapt_ml/benchmarks/viewer.py +22 -5
- openadapt_ml/benchmarks/vm_monitor.py +530 -29
- openadapt_ml/benchmarks/waa_deploy/Dockerfile +47 -53
- openadapt_ml/benchmarks/waa_deploy/api_agent.py +21 -20
- openadapt_ml/cloud/azure_inference.py +3 -5
- openadapt_ml/cloud/lambda_labs.py +722 -307
- openadapt_ml/cloud/local.py +2038 -487
- openadapt_ml/cloud/ssh_tunnel.py +68 -26
- openadapt_ml/datasets/next_action.py +40 -30
- openadapt_ml/evals/grounding.py +8 -3
- openadapt_ml/evals/plot_eval_metrics.py +15 -13
- openadapt_ml/evals/trajectory_matching.py +41 -26
- openadapt_ml/experiments/demo_prompt/format_demo.py +16 -6
- openadapt_ml/experiments/demo_prompt/run_experiment.py +26 -16
- openadapt_ml/experiments/representation_shootout/__init__.py +70 -0
- openadapt_ml/experiments/representation_shootout/conditions.py +708 -0
- openadapt_ml/experiments/representation_shootout/config.py +390 -0
- openadapt_ml/experiments/representation_shootout/evaluator.py +659 -0
- openadapt_ml/experiments/representation_shootout/runner.py +687 -0
- openadapt_ml/experiments/waa_demo/runner.py +29 -14
- openadapt_ml/export/parquet.py +36 -24
- openadapt_ml/grounding/detector.py +18 -14
- openadapt_ml/ingest/__init__.py +8 -6
- openadapt_ml/ingest/capture.py +25 -22
- openadapt_ml/ingest/loader.py +7 -4
- openadapt_ml/ingest/synthetic.py +189 -100
- openadapt_ml/models/api_adapter.py +14 -4
- openadapt_ml/models/base_adapter.py +10 -2
- openadapt_ml/models/providers/__init__.py +288 -0
- openadapt_ml/models/providers/anthropic.py +266 -0
- openadapt_ml/models/providers/base.py +299 -0
- openadapt_ml/models/providers/google.py +376 -0
- openadapt_ml/models/providers/openai.py +342 -0
- openadapt_ml/models/qwen_vl.py +46 -19
- openadapt_ml/perception/__init__.py +35 -0
- openadapt_ml/perception/integration.py +399 -0
- openadapt_ml/retrieval/demo_retriever.py +50 -24
- openadapt_ml/retrieval/embeddings.py +9 -8
- openadapt_ml/retrieval/retriever.py +3 -1
- openadapt_ml/runtime/__init__.py +50 -0
- openadapt_ml/runtime/policy.py +18 -5
- openadapt_ml/runtime/safety_gate.py +471 -0
- openadapt_ml/schema/__init__.py +9 -0
- openadapt_ml/schema/converters.py +74 -27
- openadapt_ml/schema/episode.py +31 -18
- openadapt_ml/scripts/capture_screenshots.py +530 -0
- openadapt_ml/scripts/compare.py +85 -54
- openadapt_ml/scripts/demo_policy.py +4 -1
- openadapt_ml/scripts/eval_policy.py +15 -9
- openadapt_ml/scripts/make_gif.py +1 -1
- openadapt_ml/scripts/prepare_synthetic.py +3 -1
- openadapt_ml/scripts/train.py +21 -9
- openadapt_ml/segmentation/README.md +920 -0
- openadapt_ml/segmentation/__init__.py +97 -0
- openadapt_ml/segmentation/adapters/__init__.py +5 -0
- openadapt_ml/segmentation/adapters/capture_adapter.py +420 -0
- openadapt_ml/segmentation/annotator.py +610 -0
- openadapt_ml/segmentation/cache.py +290 -0
- openadapt_ml/segmentation/cli.py +674 -0
- openadapt_ml/segmentation/deduplicator.py +656 -0
- openadapt_ml/segmentation/frame_describer.py +788 -0
- openadapt_ml/segmentation/pipeline.py +340 -0
- openadapt_ml/segmentation/schemas.py +622 -0
- openadapt_ml/segmentation/segment_extractor.py +634 -0
- openadapt_ml/training/azure_ops_viewer.py +1097 -0
- openadapt_ml/training/benchmark_viewer.py +52 -41
- openadapt_ml/training/shared_ui.py +7 -7
- openadapt_ml/training/stub_provider.py +57 -35
- openadapt_ml/training/trainer.py +143 -86
- openadapt_ml/training/trl_trainer.py +70 -21
- openadapt_ml/training/viewer.py +323 -108
- openadapt_ml/training/viewer_components.py +180 -0
- {openadapt_ml-0.2.0.dist-info → openadapt_ml-0.2.1.dist-info}/METADATA +215 -14
- openadapt_ml-0.2.1.dist-info/RECORD +116 -0
- openadapt_ml/benchmarks/base.py +0 -366
- openadapt_ml/benchmarks/data_collection.py +0 -432
- openadapt_ml/benchmarks/live_tracker.py +0 -180
- openadapt_ml/benchmarks/runner.py +0 -418
- openadapt_ml/benchmarks/waa.py +0 -761
- openadapt_ml/benchmarks/waa_live.py +0 -619
- openadapt_ml-0.2.0.dist-info/RECORD +0 -86
- {openadapt_ml-0.2.0.dist-info → openadapt_ml-0.2.1.dist-info}/WHEEL +0 -0
- {openadapt_ml-0.2.0.dist-info → openadapt_ml-0.2.1.dist-info}/licenses/LICENSE +0 -0
openadapt_ml/runtime/policy.py
CHANGED
|
@@ -3,7 +3,7 @@ from __future__ import annotations
|
|
|
3
3
|
import json
|
|
4
4
|
import re
|
|
5
5
|
from dataclasses import dataclass
|
|
6
|
-
from typing import Any, Dict,
|
|
6
|
+
from typing import Any, Dict, Optional, Tuple
|
|
7
7
|
|
|
8
8
|
from PIL import Image
|
|
9
9
|
|
|
@@ -26,13 +26,16 @@ _TYPE_SOM_SIMPLE_RE = re.compile(r'TYPE\(["\']([^"\']*(?:\\.[^"\']*)*)["\']\)')
|
|
|
26
26
|
@dataclass
class PolicyOutput:
    """Bundle of values produced by one policy step.

    Attributes:
        action: The action chosen for this step (required).
        thought: Optional thought text; ``None`` when not supplied.
        state: Optional state dictionary; ``None`` when not supplied.
        raw_text: Raw text associated with the step; defaults to ``""``.
    """

    # Only ``action`` is mandatory; every other field has a default.
    action: Action
    thought: Optional[str] = None
    state: Optional[Dict[str, Any]] = None
    raw_text: str = ""
|
|
33
34
|
|
|
34
35
|
|
|
35
|
-
def parse_thought_state_action(
|
|
36
|
+
def parse_thought_state_action(
|
|
37
|
+
text: str,
|
|
38
|
+
) -> Tuple[Optional[str], Optional[Dict[str, Any]], str]:
|
|
36
39
|
"""Parse Thought / State / Action blocks from model output.
|
|
37
40
|
|
|
38
41
|
Expected format:
|
|
@@ -54,12 +57,18 @@ def parse_thought_state_action(text: str) -> Tuple[Optional[str], Optional[Dict[
|
|
|
54
57
|
action_str: str = text.strip()
|
|
55
58
|
|
|
56
59
|
# Extract Thought - find the LAST occurrence (model's response, not template)
|
|
57
|
-
thought_matches = list(
|
|
60
|
+
thought_matches = list(
|
|
61
|
+
re.finditer(
|
|
62
|
+
r"Thought:\s*(.+?)(?=State:|Action:|$)", text, re.DOTALL | re.IGNORECASE
|
|
63
|
+
)
|
|
64
|
+
)
|
|
58
65
|
if thought_matches:
|
|
59
66
|
thought = thought_matches[-1].group(1).strip()
|
|
60
67
|
|
|
61
68
|
# Extract State (JSON on same line or next line) - last occurrence
|
|
62
|
-
state_matches = list(
|
|
69
|
+
state_matches = list(
|
|
70
|
+
re.finditer(r"State:\s*(\{.*?\})", text, re.DOTALL | re.IGNORECASE)
|
|
71
|
+
)
|
|
63
72
|
if state_matches:
|
|
64
73
|
try:
|
|
65
74
|
state = json.loads(state_matches[-1].group(1))
|
|
@@ -127,7 +136,11 @@ class AgentPolicy:
|
|
|
127
136
|
idx = int(m.group(1))
|
|
128
137
|
raw_text = m.group(2)
|
|
129
138
|
unescaped = raw_text.replace('\\"', '"').replace("\\\\", "\\")
|
|
130
|
-
return Action(
|
|
139
|
+
return Action(
|
|
140
|
+
type=ActionType.TYPE,
|
|
141
|
+
text=unescaped,
|
|
142
|
+
element=UIElement(element_id=str(idx)),
|
|
143
|
+
)
|
|
131
144
|
|
|
132
145
|
# TYPE("text") - SoM style without index
|
|
133
146
|
m = _TYPE_SOM_SIMPLE_RE.search(text)
|
|
@@ -0,0 +1,471 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Safety Gate Module - Deterministic safety checks for GUI automation actions.
|
|
3
|
+
|
|
4
|
+
The Safety Gate runs AFTER the policy decides but BEFORE execution, providing
|
|
5
|
+
a critical safety layer that prevents destructive or irreversible actions
|
|
6
|
+
without explicit human confirmation.
|
|
7
|
+
|
|
8
|
+
Example Usage:
|
|
9
|
+
from openadapt_ml.runtime import SafetyGate, SafetyConfig, SafetyDecision
|
|
10
|
+
|
|
11
|
+
# Create with default config
|
|
12
|
+
gate = SafetyGate()
|
|
13
|
+
|
|
14
|
+
# Or customize
|
|
15
|
+
config = SafetyConfig(
|
|
16
|
+
confidence_threshold=0.8,
|
|
17
|
+
loop_threshold=2,
|
|
18
|
+
expected_app="Chrome",
|
|
19
|
+
)
|
|
20
|
+
gate = SafetyGate(config)
|
|
21
|
+
|
|
22
|
+
# Evaluate action
|
|
23
|
+
assessment = gate.assess(action, observation, trace=history)
|
|
24
|
+
|
|
25
|
+
if assessment.decision == SafetyDecision.ALLOW:
|
|
26
|
+
execute(action)
|
|
27
|
+
elif assessment.decision == SafetyDecision.REQUIRE_CONFIRMATION:
|
|
28
|
+
if user_confirms(assessment.reason):
|
|
29
|
+
execute(action)
|
|
30
|
+
else: # BLOCK
|
|
31
|
+
log_warning(assessment.reason)
|
|
32
|
+
"""
|
|
33
|
+
|
|
34
|
+
from __future__ import annotations
|
|
35
|
+
|
|
36
|
+
import hashlib
|
|
37
|
+
import re
|
|
38
|
+
from dataclasses import dataclass, field
|
|
39
|
+
from enum import Enum
|
|
40
|
+
from typing import Optional
|
|
41
|
+
|
|
42
|
+
from openadapt_ml.schema import Action, ActionType, Observation, Step
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class SafetyDecision(str, Enum):
    """Verdicts the safety gate can return for a proposed action.

    Members also subclass ``str``, so each one compares equal to its
    lowercase string value.
    """

    # The action can proceed without intervention.
    ALLOW = "allow"
    # The action must not proceed under any circumstances.
    BLOCK = "block"
    # A human must approve the action before it runs.
    REQUIRE_CONFIRMATION = "require_confirmation"
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
@dataclass
class SafetyAssessment:
    """Outcome of running one action through the safety gate.

    Attributes:
        decision: Verdict reached (ALLOW, BLOCK, or REQUIRE_CONFIRMATION).
        reason: Human-readable explanation of the verdict.
        triggered_rules: Names of the safety rules that produced the verdict.
        confidence: How certain the gate is about the verdict (1.0 = certain).
    """

    decision: SafetyDecision
    reason: str
    triggered_rules: list[str]
    confidence: float = 1.0

    def __str__(self) -> str:
        """Render as ``SafetyAssessment(<decision value>: <reason>)``."""
        verdict = self.decision.value
        return f"SafetyAssessment({verdict}: {self.reason})"
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
# Regex sources used as the default blocklist. SafetyConfig copies this list
# into ``blocklist_patterns``; a match on an action's text yields BLOCK.
DEFAULT_BLOCKLIST_PATTERNS = [
    # Whole-word destructive verbs
    r"\bdelete\b",
    r"\bremove\b",
    r"\bformat\b",
    r"\breset\b",
    r"\bbroadcast\b",
    # SQL statements that destroy data
    r"\bdrop\s+table\b",
    r"\btruncate\b",
    # Destructive shell commands
    r"\brm\s+-rf\b",
    r"\bsudo\s+rm\b",
]

# Regex sources used as the default "irreversible action" list. A match on an
# action's text yields REQUIRE_CONFIRMATION rather than an outright block.
DEFAULT_IRREVERSIBLE_PATTERNS = [
    # Submitting or sending something
    r"\bsubmit\b",
    r"\bsend\b",
    r"\bapply\b",
    r"\bconfirm\b",
    # Closing a document (unsaved changes may be lost)
    r"\bclos(?:e|ing)\b",
    # Spending money
    r"\bpurchase\b",
    r"\bcheckout\b",
    r"\bpay\b",
    r"\bbuy\b",
    r"\border\b",
]

# Regex sources describing credential-like field names; typing into a field
# that matches one requires confirmation. None of these are word-anchored, so
# they match as substrings (e.g. "auth" also matches "author").
DEFAULT_CREDENTIAL_PATTERNS = [
    r"password",
    r"token",
    r"secret",
    r"api[_-]?key",
    r"apikey",  # already covered by the pattern on the previous line
    r"credential",
    r"auth",
    r"private[_-]?key",
]
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
@dataclass
class SafetyConfig:
    """Tunable settings consumed by the safety gate.

    Every pattern is a regular expression and is treated case-insensitively.

    Attributes:
        blocklist_patterns: Patterns whose match produces a BLOCK decision.
        irreversible_patterns: Patterns whose match produces
            REQUIRE_CONFIRMATION.
        confidence_threshold: Actions whose confidence is below this value
            require confirmation.
        loop_threshold: Number of visits to the same state at which the gate
            issues BLOCK.
        credential_patterns: Field-name patterns that mark sensitive input.
        credential_allowlist: Field-name patterns that override the credential
            patterns and allow typing.
        expected_app: Application name the gate expects (None = not checked).
        expected_window_pattern: Regex the window title is expected to match
            (None = not checked).
    """

    # The list-valued defaults are produced by factories so that every
    # instance receives its own copy of the module-level default lists.
    blocklist_patterns: list[str] = field(
        default_factory=lambda: list(DEFAULT_BLOCKLIST_PATTERNS)
    )
    irreversible_patterns: list[str] = field(
        default_factory=lambda: list(DEFAULT_IRREVERSIBLE_PATTERNS)
    )
    confidence_threshold: float = 0.7
    loop_threshold: int = 3
    credential_patterns: list[str] = field(
        default_factory=lambda: list(DEFAULT_CREDENTIAL_PATTERNS)
    )
    credential_allowlist: list[str] = field(default_factory=list)
    expected_app: Optional[str] = None
    expected_window_pattern: Optional[str] = None
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
class SafetyGate:
    """Rule-based safety gate for GUI automation actions.

    A proposed action is run through a fixed sequence of checks and the result
    is a SafetyAssessment saying whether the action is allowed, blocked, or
    needs human confirmation.

    Checks run in this priority order; the first one that triggers wins:
        1. Blocklist (BLOCK)
        2. Loop Detection (BLOCK)
        3. Credential Guard (REQUIRE_CONFIRMATION)
        4. Irreversibility (REQUIRE_CONFIRMATION)
        5. App/Window Mismatch (REQUIRE_CONFIRMATION)
        6. Confidence Threshold (REQUIRE_CONFIRMATION)
        7. Default (ALLOW when nothing triggered)

    The gate is stateful: it counts how often each state hash has been seen.
    Call ``reset()`` between episodes to clear those counts.
    """

    def __init__(self, config: Optional[SafetyConfig] = None) -> None:
        """Set up the gate and pre-compile every configured pattern.

        Args:
            config: Safety configuration. A default ``SafetyConfig()`` is
                created when this is None.
        """
        cfg = config or SafetyConfig()
        self.config = cfg
        self._state_visit_counts: dict[str, int] = {}

        # Compile once here so each assess() call only runs searches.
        compile_all = self._compile_patterns
        self._blocklist_re = compile_all(cfg.blocklist_patterns)
        self._irreversible_re = compile_all(cfg.irreversible_patterns)
        self._credential_re = compile_all(cfg.credential_patterns)
        self._credential_allowlist_re = compile_all(cfg.credential_allowlist)

        window_pattern = cfg.expected_window_pattern
        self._window_pattern_re = (
            re.compile(window_pattern, re.IGNORECASE) if window_pattern else None
        )

    def _compile_patterns(self, patterns: list[str]) -> list[re.Pattern]:
        """Return the given regex sources compiled with ``re.IGNORECASE``."""
        return [re.compile(source, re.IGNORECASE) for source in patterns]

    def _matches_any(self, text: str, patterns: list[re.Pattern]) -> Optional[str]:
        """Return the source of the first pattern found in ``text``, else None."""
        first_hit = next((rx for rx in patterns if rx.search(text)), None)
        return None if first_hit is None else first_hit.pattern

    def _compute_state_hash(self, observation: Observation) -> str:
        """Derive a short identifier for the observed state.

        The window title, app name and URL (missing values become "") are
        joined with "|" and hashed with SHA-256; the first 16 hex characters
        are returned. No screenshot data is involved.
        """
        title = observation.window_title or ""
        app = observation.app_name or ""
        url = observation.url or ""
        digest = hashlib.sha256("|".join((title, app, url)).encode())
        return digest.hexdigest()[:16]

    def _get_action_text(self, action: Action) -> str:
        """Collect the action's textual content for pattern matching.

        In order: typed text, key, URL, app name and window title (each only
        when truthy), followed by every string value found at the top level
        of ``action.raw``. The pieces are joined with single spaces.

        NOTE(review): ``action.element`` is not consulted here, so the name of
        a clicked element only reaches the blocklist / irreversibility checks
        if it also appears in one of the fields above - confirm this is
        intended.
        """
        candidates = (
            action.text,
            action.key,
            action.url,
            action.app_name,
            action.window_title,
        )
        pieces = [value for value in candidates if value]

        if action.raw:
            pieces.extend(
                value for value in action.raw.values() if isinstance(value, str)
            )

        return " ".join(pieces)

    def _get_target_field_name(self, action: Action, observation: Observation) -> str:
        """Describe the field a TYPE action is aimed at.

        Concatenates the truthy name, role and automation id of the action's
        target element, then of the observation's focused element, separated
        by spaces. Returns "" when nothing is available.
        """
        labels = []
        for element in (action.element, observation.focused_element):
            if not element:
                continue
            for label in (element.name, element.role, element.automation_id):
                if label:
                    labels.append(label)
        return " ".join(labels)

    def _get_action_confidence(self, action: Action) -> Optional[float]:
        """Read ``action.raw["confidence"]`` as a float.

        Returns None when ``raw`` is empty, has no "confidence" key, or the
        value cannot be converted to float.
        """
        raw = action.raw
        if not raw or "confidence" not in raw:
            return None
        try:
            return float(raw["confidence"])
        except (TypeError, ValueError):
            return None

    def _check_blocklist(self, action: Action) -> Optional[SafetyAssessment]:
        """Return a BLOCK assessment if the action text hits the blocklist."""
        matched = self._matches_any(
            self._get_action_text(action), self._blocklist_re
        )
        if not matched:
            return None
        return SafetyAssessment(
            decision=SafetyDecision.BLOCK,
            reason=f"Action contains blocked keyword pattern: {matched}",
            triggered_rules=["blocklist"],
            confidence=1.0,
        )

    def _check_loop_detection(
        self, observation: Observation
    ) -> Optional[SafetyAssessment]:
        """Count this visit and return BLOCK once the threshold is reached.

        Every call increments the counter for the observation's state hash,
        whether or not an assessment is returned.
        """
        state_hash = self._compute_state_hash(observation)
        visit_count = self._state_visit_counts.get(state_hash, 0) + 1
        self._state_visit_counts[state_hash] = visit_count

        if visit_count >= self.config.loop_threshold:
            return SafetyAssessment(
                decision=SafetyDecision.BLOCK,
                reason=f"Loop detected: same state visited {visit_count} times "
                f"(threshold: {self.config.loop_threshold})",
                triggered_rules=["loop_detection"],
                confidence=1.0,
            )
        return None

    def _check_credential_guard(
        self, action: Action, observation: Observation
    ) -> Optional[SafetyAssessment]:
        """Require confirmation when a TYPE action targets a sensitive field.

        Non-TYPE actions, actions with no discoverable field description, and
        fields matching the credential allowlist are all let through.
        """
        if action.type != ActionType.TYPE:
            return None

        target_label = self._get_target_field_name(action, observation)
        # The allowlist takes precedence over the credential patterns.
        if not target_label or self._matches_any(
            target_label, self._credential_allowlist_re
        ):
            return None

        matched = self._matches_any(target_label, self._credential_re)
        if not matched:
            return None
        return SafetyAssessment(
            decision=SafetyDecision.REQUIRE_CONFIRMATION,
            reason=f"Typing into credential field matching '{matched}': {target_label}",
            triggered_rules=["credential_guard"],
            confidence=1.0,
        )

    def _check_irreversibility(self, action: Action) -> Optional[SafetyAssessment]:
        """Require confirmation if the action text hits an irreversible pattern."""
        matched = self._matches_any(
            self._get_action_text(action), self._irreversible_re
        )
        if not matched:
            return None
        return SafetyAssessment(
            decision=SafetyDecision.REQUIRE_CONFIRMATION,
            reason=f"Action may be irreversible (matched pattern: {matched})",
            triggered_rules=["irreversibility"],
            confidence=0.9,
        )

    def _check_app_window_mismatch(
        self, observation: Observation
    ) -> Optional[SafetyAssessment]:
        """Require confirmation when the app or window is not the expected one.

        The app name is compared case-insensitively for equality; the window
        title is searched with the pre-compiled expected-window regex. Both
        mismatches may be reported in a single assessment.
        """
        findings = []  # (rule name, explanation) pairs, in check order

        expected_app = self.config.expected_app
        if expected_app:
            current_app = observation.app_name or ""
            if current_app.lower() != expected_app.lower():
                findings.append(
                    (
                        "app_mismatch",
                        f"Expected app '{expected_app}', got '{current_app}'",
                    )
                )

        if self._window_pattern_re:
            current_title = observation.window_title or ""
            if not self._window_pattern_re.search(current_title):
                findings.append(
                    (
                        "window_mismatch",
                        f"Window title '{current_title}' doesn't match pattern "
                        f"'{self.config.expected_window_pattern}'",
                    )
                )

        if not findings:
            return None
        return SafetyAssessment(
            decision=SafetyDecision.REQUIRE_CONFIRMATION,
            reason="; ".join(explanation for _, explanation in findings),
            triggered_rules=[rule for rule, _ in findings],
            confidence=0.95,
        )

    def _check_confidence_threshold(self, action: Action) -> Optional[SafetyAssessment]:
        """Require confirmation when the action's confidence is below threshold.

        Actions that carry no usable confidence value are not flagged.
        """
        confidence = self._get_action_confidence(action)
        if confidence is None:
            return None

        if confidence < self.config.confidence_threshold:
            return SafetyAssessment(
                decision=SafetyDecision.REQUIRE_CONFIRMATION,
                reason=f"Action confidence ({confidence:.2f}) below threshold "
                f"({self.config.confidence_threshold})",
                triggered_rules=["confidence_threshold"],
                confidence=1.0,
            )
        return None

    def assess(
        self,
        action: Action,
        observation: Observation,
        trace: Optional[list[Step]] = None,
    ) -> SafetyAssessment:
        """Run every safety rule against an action and return the verdict.

        Rules are tried in priority order and the first one that triggers
        decides the outcome; later rules are not evaluated.

        Args:
            action: The action proposed by the policy.
            observation: Current state observation.
            trace: Execution history. Accepted but not read by this method.

        Returns:
            The first triggered rule's SafetyAssessment, or an ALLOW
            assessment when no rule triggers.
        """
        # Thunks keep evaluation lazy: the loop-detection counter must only
        # advance when the blocklist check has not already decided the outcome.
        ordered_checks = (
            lambda: self._check_blocklist(action),
            lambda: self._check_loop_detection(observation),
            lambda: self._check_credential_guard(action, observation),
            lambda: self._check_irreversibility(action),
            lambda: self._check_app_window_mismatch(observation),
            lambda: self._check_confidence_threshold(action),
        )
        for run_check in ordered_checks:
            assessment = run_check()
            if assessment:
                return assessment

        return SafetyAssessment(
            decision=SafetyDecision.ALLOW,
            reason="All safety checks passed",
            triggered_rules=[],
            confidence=1.0,
        )

    def reset(self) -> None:
        """Forget all state visit counts.

        Intended to be called between episodes so loop detection starts fresh.
        """
        self._state_visit_counts.clear()

    def get_state_visit_counts(self) -> dict[str, int]:
        """Return a copy of the per-state visit counters (debugging/monitoring)."""
        return dict(self._state_visit_counts)
|
openadapt_ml/schema/__init__.py
CHANGED
|
@@ -82,6 +82,13 @@ from openadapt_ml.schema.episode import (
|
|
|
82
82
|
export_json_schema,
|
|
83
83
|
)
|
|
84
84
|
|
|
85
|
+
# Perception integration (requires openadapt-grounding)
|
|
86
|
+
try:
|
|
87
|
+
from openadapt_ml.perception.integration import UIElementGraph
|
|
88
|
+
except ImportError:
|
|
89
|
+
# openadapt-grounding not installed, UIElementGraph unavailable
|
|
90
|
+
UIElementGraph = None # type: ignore
|
|
91
|
+
|
|
85
92
|
__all__ = [
|
|
86
93
|
# Version
|
|
87
94
|
"SCHEMA_VERSION",
|
|
@@ -96,6 +103,8 @@ __all__ = [
|
|
|
96
103
|
"Coordinates",
|
|
97
104
|
"BoundingBox",
|
|
98
105
|
"UIElement",
|
|
106
|
+
# Perception integration
|
|
107
|
+
"UIElementGraph",
|
|
99
108
|
# Utilities
|
|
100
109
|
"validate_episode",
|
|
101
110
|
"load_episode",
|