openadapt-ml 0.2.0__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (95)
  1. openadapt_ml/baselines/__init__.py +121 -0
  2. openadapt_ml/baselines/adapter.py +185 -0
  3. openadapt_ml/baselines/cli.py +314 -0
  4. openadapt_ml/baselines/config.py +448 -0
  5. openadapt_ml/baselines/parser.py +922 -0
  6. openadapt_ml/baselines/prompts.py +787 -0
  7. openadapt_ml/benchmarks/__init__.py +13 -115
  8. openadapt_ml/benchmarks/agent.py +265 -421
  9. openadapt_ml/benchmarks/azure.py +28 -19
  10. openadapt_ml/benchmarks/azure_ops_tracker.py +521 -0
  11. openadapt_ml/benchmarks/cli.py +1722 -4847
  12. openadapt_ml/benchmarks/trace_export.py +631 -0
  13. openadapt_ml/benchmarks/viewer.py +22 -5
  14. openadapt_ml/benchmarks/vm_monitor.py +530 -29
  15. openadapt_ml/benchmarks/waa_deploy/Dockerfile +47 -53
  16. openadapt_ml/benchmarks/waa_deploy/api_agent.py +21 -20
  17. openadapt_ml/cloud/azure_inference.py +3 -5
  18. openadapt_ml/cloud/lambda_labs.py +722 -307
  19. openadapt_ml/cloud/local.py +2038 -487
  20. openadapt_ml/cloud/ssh_tunnel.py +68 -26
  21. openadapt_ml/datasets/next_action.py +40 -30
  22. openadapt_ml/evals/grounding.py +8 -3
  23. openadapt_ml/evals/plot_eval_metrics.py +15 -13
  24. openadapt_ml/evals/trajectory_matching.py +41 -26
  25. openadapt_ml/experiments/demo_prompt/format_demo.py +16 -6
  26. openadapt_ml/experiments/demo_prompt/run_experiment.py +26 -16
  27. openadapt_ml/experiments/representation_shootout/__init__.py +70 -0
  28. openadapt_ml/experiments/representation_shootout/conditions.py +708 -0
  29. openadapt_ml/experiments/representation_shootout/config.py +390 -0
  30. openadapt_ml/experiments/representation_shootout/evaluator.py +659 -0
  31. openadapt_ml/experiments/representation_shootout/runner.py +687 -0
  32. openadapt_ml/experiments/waa_demo/runner.py +29 -14
  33. openadapt_ml/export/parquet.py +36 -24
  34. openadapt_ml/grounding/detector.py +18 -14
  35. openadapt_ml/ingest/__init__.py +8 -6
  36. openadapt_ml/ingest/capture.py +25 -22
  37. openadapt_ml/ingest/loader.py +7 -4
  38. openadapt_ml/ingest/synthetic.py +189 -100
  39. openadapt_ml/models/api_adapter.py +14 -4
  40. openadapt_ml/models/base_adapter.py +10 -2
  41. openadapt_ml/models/providers/__init__.py +288 -0
  42. openadapt_ml/models/providers/anthropic.py +266 -0
  43. openadapt_ml/models/providers/base.py +299 -0
  44. openadapt_ml/models/providers/google.py +376 -0
  45. openadapt_ml/models/providers/openai.py +342 -0
  46. openadapt_ml/models/qwen_vl.py +46 -19
  47. openadapt_ml/perception/__init__.py +35 -0
  48. openadapt_ml/perception/integration.py +399 -0
  49. openadapt_ml/retrieval/demo_retriever.py +50 -24
  50. openadapt_ml/retrieval/embeddings.py +9 -8
  51. openadapt_ml/retrieval/retriever.py +3 -1
  52. openadapt_ml/runtime/__init__.py +50 -0
  53. openadapt_ml/runtime/policy.py +18 -5
  54. openadapt_ml/runtime/safety_gate.py +471 -0
  55. openadapt_ml/schema/__init__.py +9 -0
  56. openadapt_ml/schema/converters.py +74 -27
  57. openadapt_ml/schema/episode.py +31 -18
  58. openadapt_ml/scripts/capture_screenshots.py +530 -0
  59. openadapt_ml/scripts/compare.py +85 -54
  60. openadapt_ml/scripts/demo_policy.py +4 -1
  61. openadapt_ml/scripts/eval_policy.py +15 -9
  62. openadapt_ml/scripts/make_gif.py +1 -1
  63. openadapt_ml/scripts/prepare_synthetic.py +3 -1
  64. openadapt_ml/scripts/train.py +21 -9
  65. openadapt_ml/segmentation/README.md +920 -0
  66. openadapt_ml/segmentation/__init__.py +97 -0
  67. openadapt_ml/segmentation/adapters/__init__.py +5 -0
  68. openadapt_ml/segmentation/adapters/capture_adapter.py +420 -0
  69. openadapt_ml/segmentation/annotator.py +610 -0
  70. openadapt_ml/segmentation/cache.py +290 -0
  71. openadapt_ml/segmentation/cli.py +674 -0
  72. openadapt_ml/segmentation/deduplicator.py +656 -0
  73. openadapt_ml/segmentation/frame_describer.py +788 -0
  74. openadapt_ml/segmentation/pipeline.py +340 -0
  75. openadapt_ml/segmentation/schemas.py +622 -0
  76. openadapt_ml/segmentation/segment_extractor.py +634 -0
  77. openadapt_ml/training/azure_ops_viewer.py +1097 -0
  78. openadapt_ml/training/benchmark_viewer.py +52 -41
  79. openadapt_ml/training/shared_ui.py +7 -7
  80. openadapt_ml/training/stub_provider.py +57 -35
  81. openadapt_ml/training/trainer.py +143 -86
  82. openadapt_ml/training/trl_trainer.py +70 -21
  83. openadapt_ml/training/viewer.py +323 -108
  84. openadapt_ml/training/viewer_components.py +180 -0
  85. {openadapt_ml-0.2.0.dist-info → openadapt_ml-0.2.1.dist-info}/METADATA +215 -14
  86. openadapt_ml-0.2.1.dist-info/RECORD +116 -0
  87. openadapt_ml/benchmarks/base.py +0 -366
  88. openadapt_ml/benchmarks/data_collection.py +0 -432
  89. openadapt_ml/benchmarks/live_tracker.py +0 -180
  90. openadapt_ml/benchmarks/runner.py +0 -418
  91. openadapt_ml/benchmarks/waa.py +0 -761
  92. openadapt_ml/benchmarks/waa_live.py +0 -619
  93. openadapt_ml-0.2.0.dist-info/RECORD +0 -86
  94. {openadapt_ml-0.2.0.dist-info → openadapt_ml-0.2.1.dist-info}/WHEEL +0 -0
  95. {openadapt_ml-0.2.0.dist-info → openadapt_ml-0.2.1.dist-info}/licenses/LICENSE +0 -0
@@ -3,7 +3,7 @@ from __future__ import annotations
3
3
  import json
4
4
  import re
5
5
  from dataclasses import dataclass
6
- from typing import Any, Dict, List, Optional, Tuple
6
+ from typing import Any, Dict, Optional, Tuple
7
7
 
8
8
  from PIL import Image
9
9
 
@@ -26,13 +26,16 @@ _TYPE_SOM_SIMPLE_RE = re.compile(r'TYPE\(["\']([^"\']*(?:\\.[^"\']*)*)["\']\)')
26
26
  @dataclass
27
27
  class PolicyOutput:
28
28
  """Result of a single policy step."""
29
+
29
30
  action: Action
30
31
  thought: Optional[str] = None
31
32
  state: Optional[Dict[str, Any]] = None
32
33
  raw_text: str = ""
33
34
 
34
35
 
35
- def parse_thought_state_action(text: str) -> Tuple[Optional[str], Optional[Dict[str, Any]], str]:
36
+ def parse_thought_state_action(
37
+ text: str,
38
+ ) -> Tuple[Optional[str], Optional[Dict[str, Any]], str]:
36
39
  """Parse Thought / State / Action blocks from model output.
37
40
 
38
41
  Expected format:
@@ -54,12 +57,18 @@ def parse_thought_state_action(text: str) -> Tuple[Optional[str], Optional[Dict[
54
57
  action_str: str = text.strip()
55
58
 
56
59
  # Extract Thought - find the LAST occurrence (model's response, not template)
57
- thought_matches = list(re.finditer(r"Thought:\s*(.+?)(?=State:|Action:|$)", text, re.DOTALL | re.IGNORECASE))
60
+ thought_matches = list(
61
+ re.finditer(
62
+ r"Thought:\s*(.+?)(?=State:|Action:|$)", text, re.DOTALL | re.IGNORECASE
63
+ )
64
+ )
58
65
  if thought_matches:
59
66
  thought = thought_matches[-1].group(1).strip()
60
67
 
61
68
  # Extract State (JSON on same line or next line) - last occurrence
62
- state_matches = list(re.finditer(r"State:\s*(\{.*?\})", text, re.DOTALL | re.IGNORECASE))
69
+ state_matches = list(
70
+ re.finditer(r"State:\s*(\{.*?\})", text, re.DOTALL | re.IGNORECASE)
71
+ )
63
72
  if state_matches:
64
73
  try:
65
74
  state = json.loads(state_matches[-1].group(1))
@@ -127,7 +136,11 @@ class AgentPolicy:
127
136
  idx = int(m.group(1))
128
137
  raw_text = m.group(2)
129
138
  unescaped = raw_text.replace('\\"', '"').replace("\\\\", "\\")
130
- return Action(type=ActionType.TYPE, text=unescaped, element=UIElement(element_id=str(idx)))
139
+ return Action(
140
+ type=ActionType.TYPE,
141
+ text=unescaped,
142
+ element=UIElement(element_id=str(idx)),
143
+ )
131
144
 
132
145
  # TYPE("text") - SoM style without index
133
146
  m = _TYPE_SOM_SIMPLE_RE.search(text)
@@ -0,0 +1,471 @@
1
+ """
2
+ Safety Gate Module - Deterministic safety checks for GUI automation actions.
3
+
4
+ The Safety Gate runs AFTER the policy decides but BEFORE execution, providing
5
+ a critical safety layer that prevents destructive or irreversible actions
6
+ without explicit human confirmation.
7
+
8
+ Example Usage:
9
+ from openadapt_ml.runtime import SafetyGate, SafetyConfig, SafetyDecision
10
+
11
+ # Create with default config
12
+ gate = SafetyGate()
13
+
14
+ # Or customize
15
+ config = SafetyConfig(
16
+ confidence_threshold=0.8,
17
+ loop_threshold=2,
18
+ expected_app="Chrome",
19
+ )
20
+ gate = SafetyGate(config)
21
+
22
+ # Evaluate action
23
+ assessment = gate.assess(action, observation, trace=history)
24
+
25
+ if assessment.decision == SafetyDecision.ALLOW:
26
+ execute(action)
27
+ elif assessment.decision == SafetyDecision.REQUIRE_CONFIRMATION:
28
+ if user_confirms(assessment.reason):
29
+ execute(action)
30
+ else: # BLOCK
31
+ log_warning(assessment.reason)
32
+ """
33
+
34
+ from __future__ import annotations
35
+
36
+ import hashlib
37
+ import re
38
+ from dataclasses import dataclass, field
39
+ from enum import Enum
40
+ from typing import Optional
41
+
42
+ from openadapt_ml.schema import Action, ActionType, Observation, Step
43
+
44
+
45
class SafetyDecision(str, Enum):
    """Possible safety gate decisions.

    The enum also subclasses ``str``, so each member compares equal to its
    literal value (e.g. ``SafetyDecision.ALLOW == "allow"``).
    """

    ALLOW = "allow"  # Action can proceed without intervention
    BLOCK = "block"  # Action must not proceed under any circumstances
    REQUIRE_CONFIRMATION = "require_confirmation"  # Human must approve
51
+
52
+
53
@dataclass
class SafetyAssessment:
    """Result of safety gate evaluation.

    Attributes:
        decision: The safety decision (ALLOW, BLOCK, or REQUIRE_CONFIRMATION)
        reason: Human-readable explanation of the decision
        triggered_rules: List of safety rule names that triggered this decision
        confidence: How certain the gate is about this decision (1.0 = certain)
    """

    decision: SafetyDecision
    reason: str
    triggered_rules: list[str]
    confidence: float = 1.0

    def __str__(self) -> str:
        """Return a compact ``SafetyAssessment(<decision value>: <reason>)`` string."""
        return f"SafetyAssessment({self.decision.value}: {self.reason})"
71
+
72
+
73
# Default blocklist patterns - actions that are ALWAYS blocked.
# Each entry is a regular expression; the gate compiles them case-insensitively.
DEFAULT_BLOCKLIST_PATTERNS = [
    # File/data destruction keywords
    r"\bdelete\b",
    r"\bremove\b",
    r"\bformat\b",
    r"\breset\b",
    # Mass-messaging keyword (not a destruction keyword, but blocked outright)
    r"\bbroadcast\b",
    # Database destruction
    r"\bdrop\s+table\b",
    r"\btruncate\b",
    # Shell destruction commands
    r"\brm\s+-rf\b",
    r"\bsudo\s+rm\b",
]

# Default irreversible action patterns - require confirmation
DEFAULT_IRREVERSIBLE_PATTERNS = [
    # Submission actions
    r"\bsubmit\b",
    r"\bsend\b",
    r"\bapply\b",
    r"\bconfirm\b",
    # Document closure (potentially with unsaved changes)
    r"\bclos(?:e|ing)\b",
    # Financial actions
    r"\bpurchase\b",
    r"\bcheckout\b",
    r"\bpay\b",
    r"\bbuy\b",
    r"\border\b",
]

# Default credential field patterns - typing to these requires confirmation.
# These have no word boundaries, so they match as substrings of a field name
# (e.g. "auth" also matches "author").
DEFAULT_CREDENTIAL_PATTERNS = [
    r"password",
    r"token",
    r"secret",
    r"api[_-]?key",
    # Already covered by the pattern above; kept so the list stays unchanged
    # for callers that inspect or copy it.
    r"apikey",
    r"credential",
    r"auth",
    r"private[_-]?key",
]
117
+
118
+
119
@dataclass
class SafetyConfig:
    """Configuration for safety gate behavior.

    All patterns are case-insensitive regular expressions.

    Attributes:
        blocklist_patterns: Patterns that trigger BLOCK decision
        irreversible_patterns: Patterns that trigger REQUIRE_CONFIRMATION
        confidence_threshold: Actions below this confidence require confirmation
        loop_threshold: Same state visited this many times triggers BLOCK
        credential_patterns: Field name patterns that indicate sensitive input
        credential_allowlist: Override patterns to allow typing in specific fields
        expected_app: Expected application name (None = don't check)
        expected_window_pattern: Regex for expected window title (None = don't check)
    """

    # Each instance gets its own copy of the module-level defaults, so
    # mutating one config's list never leaks into another config.
    blocklist_patterns: list[str] = field(
        default_factory=lambda: DEFAULT_BLOCKLIST_PATTERNS.copy()
    )
    irreversible_patterns: list[str] = field(
        default_factory=lambda: DEFAULT_IRREVERSIBLE_PATTERNS.copy()
    )
    confidence_threshold: float = 0.7
    loop_threshold: int = 3
    credential_patterns: list[str] = field(
        default_factory=lambda: DEFAULT_CREDENTIAL_PATTERNS.copy()
    )
    # Empty by default: no credential field is exempt from confirmation.
    credential_allowlist: list[str] = field(default_factory=list)
    expected_app: Optional[str] = None
    expected_window_pattern: Optional[str] = None
150
+
151
+
152
class SafetyGate:
    """Deterministic safety gate for GUI automation actions.

    The SafetyGate evaluates proposed actions against a set of safety rules
    and returns a SafetyAssessment indicating whether the action should be
    allowed, blocked, or require human confirmation.

    Safety checks are evaluated in priority order:
    1. Blocklist (immediate BLOCK)
    2. Loop Detection (BLOCK if loop detected)
    3. Credential Guard (REQUIRE_CONFIRMATION if typing to sensitive field)
    4. Irreversibility (REQUIRE_CONFIRMATION for irreversible actions)
    5. App/Window Mismatch (REQUIRE_CONFIRMATION if context mismatch)
    6. Confidence Threshold (REQUIRE_CONFIRMATION if low confidence)
    7. Default (ALLOW if no rules triggered)
    """

    def __init__(self, config: Optional[SafetyConfig] = None) -> None:
        """Initialize the safety gate with optional custom configuration.

        Args:
            config: Safety configuration. If None, uses default SafetyConfig.

        Raises:
            re.error: If any configured pattern is not a valid regular
                expression.
        """
        self.config = config or SafetyConfig()
        # Maps state hash -> number of times assess() has observed that state.
        self._state_visit_counts: dict[str, int] = {}

        # Pre-compile regex patterns for efficiency
        self._blocklist_re = self._compile_patterns(self.config.blocklist_patterns)
        self._irreversible_re = self._compile_patterns(
            self.config.irreversible_patterns
        )
        self._credential_re = self._compile_patterns(self.config.credential_patterns)
        self._credential_allowlist_re = self._compile_patterns(
            self.config.credential_allowlist
        )
        self._window_pattern_re = (
            re.compile(self.config.expected_window_pattern, re.IGNORECASE)
            if self.config.expected_window_pattern
            else None
        )

    def _compile_patterns(self, patterns: list[str]) -> list[re.Pattern]:
        """Compile a list of regex patterns (case-insensitive)."""
        return [re.compile(p, re.IGNORECASE) for p in patterns]

    def _matches_any(self, text: str, patterns: list[re.Pattern]) -> Optional[str]:
        """Return the source of the first pattern found in ``text``, else None."""
        for pattern in patterns:
            if pattern.search(text):
                return pattern.pattern
        return None

    def _compute_state_hash(self, observation: Observation) -> str:
        """Compute a hash representing the current state for loop detection.

        Uses window title, app name, and URL to identify state. This is a
        lightweight approach that doesn't require image comparison.

        Returns:
            The first 16 hex characters of the SHA-256 digest of the three
            fields joined with ``|`` (missing fields count as empty strings).
        """
        components = [
            observation.window_title or "",
            observation.app_name or "",
            observation.url or "",
        ]
        state_str = "|".join(components)
        return hashlib.sha256(state_str.encode()).hexdigest()[:16]

    def _get_action_text(self, action: Action) -> str:
        """Collect the action's textual fields into one space-joined string.

        Covers typed text, key, URL, app name, window title, and every
        string value found in ``action.raw``; empty fields are skipped.
        """
        parts = [
            value
            for value in (
                action.text,
                action.key,
                action.url,
                action.app_name,
                action.window_title,
            )
            if value
        ]

        # Include raw action data if present (string values only)
        if action.raw:
            parts.extend(v for v in action.raw.values() if isinstance(v, str))

        return " ".join(parts)

    def _get_target_field_name(self, action: Action, observation: Observation) -> str:
        """Get the name/label of the field targeted by a TYPE action.

        Joins name, role, and automation id of the action's target element
        and of the observation's focused element; returns "" if neither
        provides any of them.
        """
        parts = []
        for element in (action.element, observation.focused_element):
            if not element:
                continue
            for value in (element.name, element.role, element.automation_id):
                if value:
                    parts.append(value)
        return " ".join(parts)

    def _get_action_confidence(self, action: Action) -> Optional[float]:
        """Extract confidence score from action's raw metadata.

        Returns None when there is no ``confidence`` entry or it cannot be
        converted to ``float``.
        """
        if action.raw and "confidence" in action.raw:
            try:
                return float(action.raw["confidence"])
            except (TypeError, ValueError):
                pass
        return None

    def _confidence_is_insufficient(self, confidence: float) -> bool:
        """Return True when ``confidence`` does not reach the configured threshold.

        Written as a negated ``>=`` so that NaN, for which every comparison
        is False, counts as insufficient instead of slipping past the gate.
        """
        return not (confidence >= self.config.confidence_threshold)

    def _check_blocklist(self, action: Action) -> Optional[SafetyAssessment]:
        """Check if action text matches any blocklist pattern."""
        action_text = self._get_action_text(action)
        matched = self._matches_any(action_text, self._blocklist_re)

        if matched:
            return SafetyAssessment(
                decision=SafetyDecision.BLOCK,
                reason=f"Action contains blocked keyword pattern: {matched}",
                triggered_rules=["blocklist"],
                confidence=1.0,
            )
        return None

    def _check_loop_detection(
        self, observation: Observation
    ) -> Optional[SafetyAssessment]:
        """Check if we're stuck in a loop visiting the same state.

        Side effect: increments the visit counter for the observed state on
        every call, whether or not the threshold is reached.
        """
        state_hash = self._compute_state_hash(observation)
        self._state_visit_counts[state_hash] = (
            self._state_visit_counts.get(state_hash, 0) + 1
        )

        visit_count = self._state_visit_counts[state_hash]
        if visit_count >= self.config.loop_threshold:
            return SafetyAssessment(
                decision=SafetyDecision.BLOCK,
                reason=f"Loop detected: same state visited {visit_count} times "
                f"(threshold: {self.config.loop_threshold})",
                triggered_rules=["loop_detection"],
                confidence=1.0,
            )
        return None

    def _check_credential_guard(
        self, action: Action, observation: Observation
    ) -> Optional[SafetyAssessment]:
        """Check if TYPE action targets a credential/sensitive field."""
        if action.type != ActionType.TYPE:
            return None

        field_name = self._get_target_field_name(action, observation)
        if not field_name:
            return None

        # Check if field is in allowlist (override)
        if self._matches_any(field_name, self._credential_allowlist_re):
            return None

        # Check if field matches credential patterns
        matched = self._matches_any(field_name, self._credential_re)
        if matched:
            return SafetyAssessment(
                decision=SafetyDecision.REQUIRE_CONFIRMATION,
                reason=f"Typing into credential field matching '{matched}': {field_name}",
                triggered_rules=["credential_guard"],
                confidence=1.0,
            )
        return None

    def _check_irreversibility(self, action: Action) -> Optional[SafetyAssessment]:
        """Check if action appears to be irreversible (submit, send, etc.)."""
        action_text = self._get_action_text(action)
        matched = self._matches_any(action_text, self._irreversible_re)

        if matched:
            return SafetyAssessment(
                decision=SafetyDecision.REQUIRE_CONFIRMATION,
                reason=f"Action may be irreversible (matched pattern: {matched})",
                triggered_rules=["irreversibility"],
                confidence=0.9,
            )
        return None

    def _check_app_window_mismatch(
        self, observation: Observation
    ) -> Optional[SafetyAssessment]:
        """Check if current app/window doesn't match expected context.

        Both checks are optional and independent; when both fail, a single
        assessment lists both rules and both reasons.
        """
        triggered_rules = []
        reasons = []

        # Check app name (case-insensitive exact match)
        if self.config.expected_app:
            current_app = observation.app_name or ""
            if current_app.lower() != self.config.expected_app.lower():
                triggered_rules.append("app_mismatch")
                reasons.append(
                    f"Expected app '{self.config.expected_app}', got '{current_app}'"
                )

        # Check window title pattern (regex search, not full match)
        if self._window_pattern_re:
            current_title = observation.window_title or ""
            if not self._window_pattern_re.search(current_title):
                triggered_rules.append("window_mismatch")
                reasons.append(
                    f"Window title '{current_title}' doesn't match pattern "
                    f"'{self.config.expected_window_pattern}'"
                )

        if triggered_rules:
            return SafetyAssessment(
                decision=SafetyDecision.REQUIRE_CONFIRMATION,
                reason="; ".join(reasons),
                triggered_rules=triggered_rules,
                confidence=0.95,
            )
        return None

    def _check_confidence_threshold(self, action: Action) -> Optional[SafetyAssessment]:
        """Check if action confidence is below threshold.

        Actions that carry no usable confidence value are not flagged. A NaN
        confidence is flagged, so the gate fails closed on it.
        """
        confidence = self._get_action_confidence(action)

        if confidence is not None and self._confidence_is_insufficient(confidence):
            return SafetyAssessment(
                decision=SafetyDecision.REQUIRE_CONFIRMATION,
                reason=f"Action confidence ({confidence:.2f}) below threshold "
                f"({self.config.confidence_threshold})",
                triggered_rules=["confidence_threshold"],
                confidence=1.0,
            )
        return None

    def assess(
        self,
        action: Action,
        observation: Observation,
        trace: Optional[list[Step]] = None,
    ) -> SafetyAssessment:
        """Evaluate an action against all safety rules.

        Checks are evaluated in priority order. First triggered rule wins.

        Args:
            action: The action proposed by the policy
            observation: Current state observation
            trace: Execution history (currently unused, reserved for future use)

        Returns:
            SafetyAssessment with decision and reasoning
        """
        # Priority 1: Blocklist (BLOCK)
        assessment = self._check_blocklist(action)
        if assessment:
            return assessment

        # Priority 2: Loop Detection (BLOCK)
        # Reached only when the blocklist did not fire, so a blocklisted
        # action does not add to the state visit count.
        assessment = self._check_loop_detection(observation)
        if assessment:
            return assessment

        # Priority 3: Credential Guard (REQUIRE_CONFIRMATION)
        assessment = self._check_credential_guard(action, observation)
        if assessment:
            return assessment

        # Priority 4: Irreversibility (REQUIRE_CONFIRMATION)
        assessment = self._check_irreversibility(action)
        if assessment:
            return assessment

        # Priority 5: App/Window Mismatch (REQUIRE_CONFIRMATION)
        assessment = self._check_app_window_mismatch(observation)
        if assessment:
            return assessment

        # Priority 6: Confidence Threshold (REQUIRE_CONFIRMATION)
        assessment = self._check_confidence_threshold(action)
        if assessment:
            return assessment

        # Default: ALLOW
        return SafetyAssessment(
            decision=SafetyDecision.ALLOW,
            reason="All safety checks passed",
            triggered_rules=[],
            confidence=1.0,
        )

    def reset(self) -> None:
        """Clear internal state (loop detection history).

        Call this method between episodes to reset the state visit counts.
        """
        self._state_visit_counts.clear()

    def get_state_visit_counts(self) -> dict[str, int]:
        """Get a copy of the current state visit counts (for debugging/monitoring)."""
        return self._state_visit_counts.copy()
@@ -82,6 +82,13 @@ from openadapt_ml.schema.episode import (
82
82
  export_json_schema,
83
83
  )
84
84
 
85
+ # Perception integration (requires openadapt-grounding)
86
+ try:
87
+ from openadapt_ml.perception.integration import UIElementGraph
88
+ except ImportError:
89
+ # openadapt-grounding not installed, UIElementGraph unavailable
90
+ UIElementGraph = None # type: ignore
91
+
85
92
  __all__ = [
86
93
  # Version
87
94
  "SCHEMA_VERSION",
@@ -96,6 +103,8 @@ __all__ = [
96
103
  "Coordinates",
97
104
  "BoundingBox",
98
105
  "UIElement",
106
+ # Perception integration
107
+ "UIElementGraph",
99
108
  # Utilities
100
109
  "validate_episode",
101
110
  "load_episode",