openadapt-ml 0.1.0__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (112)
  1. openadapt_ml/baselines/__init__.py +121 -0
  2. openadapt_ml/baselines/adapter.py +185 -0
  3. openadapt_ml/baselines/cli.py +314 -0
  4. openadapt_ml/baselines/config.py +448 -0
  5. openadapt_ml/baselines/parser.py +922 -0
  6. openadapt_ml/baselines/prompts.py +787 -0
  7. openadapt_ml/benchmarks/__init__.py +13 -107
  8. openadapt_ml/benchmarks/agent.py +297 -374
  9. openadapt_ml/benchmarks/azure.py +62 -24
  10. openadapt_ml/benchmarks/azure_ops_tracker.py +521 -0
  11. openadapt_ml/benchmarks/cli.py +1874 -751
  12. openadapt_ml/benchmarks/trace_export.py +631 -0
  13. openadapt_ml/benchmarks/viewer.py +1236 -0
  14. openadapt_ml/benchmarks/vm_monitor.py +1111 -0
  15. openadapt_ml/benchmarks/waa_deploy/Dockerfile +216 -0
  16. openadapt_ml/benchmarks/waa_deploy/__init__.py +10 -0
  17. openadapt_ml/benchmarks/waa_deploy/api_agent.py +540 -0
  18. openadapt_ml/benchmarks/waa_deploy/start_waa_server.bat +53 -0
  19. openadapt_ml/cloud/azure_inference.py +3 -5
  20. openadapt_ml/cloud/lambda_labs.py +722 -307
  21. openadapt_ml/cloud/local.py +3194 -89
  22. openadapt_ml/cloud/ssh_tunnel.py +595 -0
  23. openadapt_ml/datasets/next_action.py +125 -96
  24. openadapt_ml/evals/grounding.py +32 -9
  25. openadapt_ml/evals/plot_eval_metrics.py +15 -13
  26. openadapt_ml/evals/trajectory_matching.py +120 -57
  27. openadapt_ml/experiments/demo_prompt/__init__.py +19 -0
  28. openadapt_ml/experiments/demo_prompt/format_demo.py +236 -0
  29. openadapt_ml/experiments/demo_prompt/results/experiment_20251231_002125.json +83 -0
  30. openadapt_ml/experiments/demo_prompt/results/experiment_n30_20251231_165958.json +1100 -0
  31. openadapt_ml/experiments/demo_prompt/results/multistep_20251231_025051.json +182 -0
  32. openadapt_ml/experiments/demo_prompt/run_experiment.py +541 -0
  33. openadapt_ml/experiments/representation_shootout/__init__.py +70 -0
  34. openadapt_ml/experiments/representation_shootout/conditions.py +708 -0
  35. openadapt_ml/experiments/representation_shootout/config.py +390 -0
  36. openadapt_ml/experiments/representation_shootout/evaluator.py +659 -0
  37. openadapt_ml/experiments/representation_shootout/runner.py +687 -0
  38. openadapt_ml/experiments/waa_demo/__init__.py +10 -0
  39. openadapt_ml/experiments/waa_demo/demos.py +357 -0
  40. openadapt_ml/experiments/waa_demo/runner.py +732 -0
  41. openadapt_ml/experiments/waa_demo/tasks.py +151 -0
  42. openadapt_ml/export/__init__.py +9 -0
  43. openadapt_ml/export/__main__.py +6 -0
  44. openadapt_ml/export/cli.py +89 -0
  45. openadapt_ml/export/parquet.py +277 -0
  46. openadapt_ml/grounding/detector.py +18 -14
  47. openadapt_ml/ingest/__init__.py +11 -10
  48. openadapt_ml/ingest/capture.py +97 -86
  49. openadapt_ml/ingest/loader.py +120 -69
  50. openadapt_ml/ingest/synthetic.py +344 -193
  51. openadapt_ml/models/api_adapter.py +14 -4
  52. openadapt_ml/models/base_adapter.py +10 -2
  53. openadapt_ml/models/providers/__init__.py +288 -0
  54. openadapt_ml/models/providers/anthropic.py +266 -0
  55. openadapt_ml/models/providers/base.py +299 -0
  56. openadapt_ml/models/providers/google.py +376 -0
  57. openadapt_ml/models/providers/openai.py +342 -0
  58. openadapt_ml/models/qwen_vl.py +46 -19
  59. openadapt_ml/perception/__init__.py +35 -0
  60. openadapt_ml/perception/integration.py +399 -0
  61. openadapt_ml/retrieval/README.md +226 -0
  62. openadapt_ml/retrieval/USAGE.md +391 -0
  63. openadapt_ml/retrieval/__init__.py +91 -0
  64. openadapt_ml/retrieval/demo_retriever.py +843 -0
  65. openadapt_ml/retrieval/embeddings.py +630 -0
  66. openadapt_ml/retrieval/index.py +194 -0
  67. openadapt_ml/retrieval/retriever.py +162 -0
  68. openadapt_ml/runtime/__init__.py +50 -0
  69. openadapt_ml/runtime/policy.py +27 -14
  70. openadapt_ml/runtime/safety_gate.py +471 -0
  71. openadapt_ml/schema/__init__.py +113 -0
  72. openadapt_ml/schema/converters.py +588 -0
  73. openadapt_ml/schema/episode.py +470 -0
  74. openadapt_ml/scripts/capture_screenshots.py +530 -0
  75. openadapt_ml/scripts/compare.py +102 -61
  76. openadapt_ml/scripts/demo_policy.py +4 -1
  77. openadapt_ml/scripts/eval_policy.py +19 -14
  78. openadapt_ml/scripts/make_gif.py +1 -1
  79. openadapt_ml/scripts/prepare_synthetic.py +16 -17
  80. openadapt_ml/scripts/train.py +98 -75
  81. openadapt_ml/segmentation/README.md +920 -0
  82. openadapt_ml/segmentation/__init__.py +97 -0
  83. openadapt_ml/segmentation/adapters/__init__.py +5 -0
  84. openadapt_ml/segmentation/adapters/capture_adapter.py +420 -0
  85. openadapt_ml/segmentation/annotator.py +610 -0
  86. openadapt_ml/segmentation/cache.py +290 -0
  87. openadapt_ml/segmentation/cli.py +674 -0
  88. openadapt_ml/segmentation/deduplicator.py +656 -0
  89. openadapt_ml/segmentation/frame_describer.py +788 -0
  90. openadapt_ml/segmentation/pipeline.py +340 -0
  91. openadapt_ml/segmentation/schemas.py +622 -0
  92. openadapt_ml/segmentation/segment_extractor.py +634 -0
  93. openadapt_ml/training/azure_ops_viewer.py +1097 -0
  94. openadapt_ml/training/benchmark_viewer.py +3255 -19
  95. openadapt_ml/training/shared_ui.py +7 -7
  96. openadapt_ml/training/stub_provider.py +57 -35
  97. openadapt_ml/training/trainer.py +255 -441
  98. openadapt_ml/training/trl_trainer.py +403 -0
  99. openadapt_ml/training/viewer.py +323 -108
  100. openadapt_ml/training/viewer_components.py +180 -0
  101. {openadapt_ml-0.1.0.dist-info → openadapt_ml-0.2.1.dist-info}/METADATA +312 -69
  102. openadapt_ml-0.2.1.dist-info/RECORD +116 -0
  103. openadapt_ml/benchmarks/base.py +0 -366
  104. openadapt_ml/benchmarks/data_collection.py +0 -432
  105. openadapt_ml/benchmarks/runner.py +0 -381
  106. openadapt_ml/benchmarks/waa.py +0 -704
  107. openadapt_ml/schemas/__init__.py +0 -53
  108. openadapt_ml/schemas/sessions.py +0 -122
  109. openadapt_ml/schemas/validation.py +0 -252
  110. openadapt_ml-0.1.0.dist-info/RECORD +0 -55
  111. {openadapt_ml-0.1.0.dist-info → openadapt_ml-0.2.1.dist-info}/WHEEL +0 -0
  112. {openadapt_ml-0.1.0.dist-info → openadapt_ml-0.2.1.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,471 @@
1
"""
Safety Gate Module - Deterministic safety checks for GUI automation actions.

The Safety Gate runs AFTER the policy decides but BEFORE execution, providing
a critical safety layer that prevents destructive or irreversible actions
without explicit human confirmation.

Example Usage:
    from openadapt_ml.runtime import SafetyGate, SafetyConfig, SafetyDecision

    # Create with default config
    gate = SafetyGate()

    # Or customize
    config = SafetyConfig(
        confidence_threshold=0.8,
        loop_threshold=2,
        expected_app="Chrome",
    )
    gate = SafetyGate(config)

    # Evaluate action
    assessment = gate.assess(action, observation, trace=history)

    if assessment.decision == SafetyDecision.ALLOW:
        execute(action)
    elif assessment.decision == SafetyDecision.REQUIRE_CONFIRMATION:
        if user_confirms(assessment.reason):
            execute(action)
    else:  # BLOCK
        log_warning(assessment.reason)

    # Loop detection keeps per-gate state across assess() calls. Clear it
    # before starting a new episode with the same gate, otherwise states
    # visited in the previous episode count towards the loop threshold.
    gate.reset()
"""
33
+
34
+ from __future__ import annotations
35
+
36
+ import hashlib
37
+ import re
38
+ from dataclasses import dataclass, field
39
+ from enum import Enum
40
+ from typing import Optional
41
+
42
+ from openadapt_ml.schema import Action, ActionType, Observation, Step
43
+
44
+
45
class SafetyDecision(str, Enum):
    """Outcome of a safety gate evaluation.

    Members mix in ``str``, so each one compares equal to its lowercase
    string value (e.g. ``SafetyDecision.ALLOW == "allow"``).
    """

    # Safe to execute with no human involvement.
    ALLOW = "allow"
    # Must never be executed, regardless of confirmation.
    BLOCK = "block"
    # May be executed only after a human explicitly approves it.
    REQUIRE_CONFIRMATION = "require_confirmation"
51
+
52
+
53
@dataclass
class SafetyAssessment:
    """Verdict returned by the safety gate for a single proposed action.

    Attributes:
        decision: ALLOW, BLOCK, or REQUIRE_CONFIRMATION.
        reason: Human-readable explanation suitable for logs or prompts.
        triggered_rules: Names of the safety rules that produced this
            decision (empty when nothing triggered).
        confidence: Certainty of the gate in this decision; 1.0 means
            fully certain.
    """

    decision: SafetyDecision
    reason: str
    triggered_rules: list[str]
    confidence: float = 1.0

    def __str__(self) -> str:
        template = "SafetyAssessment({}: {})"
        return template.format(self.decision.value, self.reason)
71
+
72
+
73
# Blocklist: an action whose text matches any of these is ALWAYS blocked.
DEFAULT_BLOCKLIST_PATTERNS = [
    # Keywords tied to destroying files/data or wiping state
    r"\bdelete\b", r"\bremove\b", r"\bformat\b", r"\breset\b", r"\bbroadcast\b",
    # SQL statements that destroy tables or their contents
    r"\bdrop\s+table\b", r"\btruncate\b",
    # Destructive shell commands
    r"\brm\s+-rf\b", r"\bsudo\s+rm\b",
]

# Irreversible actions: matching actions need human confirmation.
DEFAULT_IRREVERSIBLE_PATTERNS = [
    # Submitting or committing something
    r"\bsubmit\b", r"\bsend\b", r"\bapply\b", r"\bconfirm\b",
    # Closing a document (may discard unsaved changes)
    r"\bclos(?:e|ing)\b",
    # Spending money
    r"\bpurchase\b", r"\bcheckout\b", r"\bpay\b", r"\bbuy\b", r"\border\b",
]

# Sensitive fields: typing into a field whose name matches needs confirmation.
DEFAULT_CREDENTIAL_PATTERNS = [
    r"password", r"token", r"secret",
    r"api[_-]?key", r"apikey",
    r"credential", r"auth",
    r"private[_-]?key",
]
117
+
118
+
119
@dataclass
class SafetyConfig:
    """Tunable settings for a SafetyGate.

    Every pattern is a regular expression matched case-insensitively.

    Attributes:
        blocklist_patterns: Matches here produce a BLOCK decision.
        irreversible_patterns: Matches here produce REQUIRE_CONFIRMATION.
        confidence_threshold: Actions whose confidence is below this value
            require confirmation.
        loop_threshold: Visiting the same state this many times is a BLOCK.
        credential_patterns: Field-name patterns that mark sensitive input.
        credential_allowlist: Field-name patterns exempt from the credential
            guard even if they match ``credential_patterns``.
        expected_app: Application name that must be active; None disables
            the check.
        expected_window_pattern: Regex the window title must match; None
            disables the check.
    """

    # Each instance gets its own copy so mutating one config's lists never
    # leaks into the module-level defaults or other configs.
    blocklist_patterns: list[str] = field(
        default_factory=lambda: list(DEFAULT_BLOCKLIST_PATTERNS)
    )
    irreversible_patterns: list[str] = field(
        default_factory=lambda: list(DEFAULT_IRREVERSIBLE_PATTERNS)
    )
    confidence_threshold: float = 0.7
    loop_threshold: int = 3
    credential_patterns: list[str] = field(
        default_factory=lambda: list(DEFAULT_CREDENTIAL_PATTERNS)
    )
    credential_allowlist: list[str] = field(default_factory=list)
    expected_app: Optional[str] = None
    expected_window_pattern: Optional[str] = None
150
+
151
+
152
class SafetyGate:
    """Deterministic safety gate for GUI automation actions.

    The SafetyGate evaluates proposed actions against a set of safety rules
    and returns a SafetyAssessment indicating whether the action should be
    allowed, blocked, or require human confirmation.

    Safety checks are evaluated in priority order; the first rule that
    triggers determines the result:
        1. Blocklist (immediate BLOCK)
        2. Loop Detection (BLOCK if loop detected)
        3. Credential Guard (REQUIRE_CONFIRMATION if typing to sensitive field)
        4. Irreversibility (REQUIRE_CONFIRMATION for irreversible actions)
        5. App/Window Mismatch (REQUIRE_CONFIRMATION if context mismatch)
        6. Confidence Threshold (REQUIRE_CONFIRMATION if low confidence)
        7. Default (ALLOW if no rules triggered)

    The gate is stateful: loop detection accumulates state-visit counts
    across ``assess`` calls until ``reset()`` is called. All regex patterns
    (including ``expected_window_pattern``) are compiled once in
    ``__init__``. Changing pattern settings on ``self.config`` afterwards
    therefore has no effect, while scalar settings (thresholds,
    ``expected_app``) are read on every call.
    """

    def __init__(self, config: Optional[SafetyConfig] = None) -> None:
        """Initialize the safety gate with optional custom configuration.

        Args:
            config: Safety configuration. If None, uses default SafetyConfig.

        Raises:
            re.error: If any configured pattern is not a valid regex.
        """
        self.config = config or SafetyConfig()
        # State hash (see _compute_state_hash) -> number of counted visits.
        self._state_visit_counts: dict[str, int] = {}

        # Pre-compile regex patterns once instead of on every assess() call.
        self._blocklist_re = self._compile_patterns(self.config.blocklist_patterns)
        self._irreversible_re = self._compile_patterns(
            self.config.irreversible_patterns
        )
        self._credential_re = self._compile_patterns(self.config.credential_patterns)
        self._credential_allowlist_re = self._compile_patterns(
            self.config.credential_allowlist
        )
        self._window_pattern_re = (
            re.compile(self.config.expected_window_pattern, re.IGNORECASE)
            if self.config.expected_window_pattern
            else None
        )

    def _compile_patterns(self, patterns: list[str]) -> list[re.Pattern]:
        """Compile a list of regex patterns (case-insensitive)."""
        return [re.compile(p, re.IGNORECASE) for p in patterns]

    def _matches_any(self, text: str, patterns: list[re.Pattern]) -> Optional[str]:
        """Return the source of the first pattern found in ``text``, else None."""
        for pattern in patterns:
            if pattern.search(text):
                return pattern.pattern
        return None

    def _compute_state_hash(self, observation: Observation) -> str:
        """Compute a hash representing the current state for loop detection.

        Uses window title, app name, and URL to identify state. This is a
        lightweight approach that doesn't require image comparison. Returns
        the first 16 hex characters of the SHA-256 of the "|"-joined fields.
        """
        components = [
            observation.window_title or "",
            observation.app_name or "",
            observation.url or "",
        ]
        state_str = "|".join(components)
        return hashlib.sha256(state_str.encode()).hexdigest()[:16]

    def _get_action_text(self, action: Action) -> str:
        """Extract text content from an action for pattern matching.

        Collects the following and joins them with single spaces:
        - typed text, key, URL, app name and window title;
        - the label of the targeted element (non-TYPE actions only);
        - every string value in ``action.raw``.
        """
        parts = []

        # Include typed text
        if action.text:
            parts.append(action.text)

        # Include key presses
        if action.key:
            parts.append(action.key)

        # Include URL for navigation
        if action.url:
            parts.append(action.url)

        # Include app name for open/close
        if action.app_name:
            parts.append(action.app_name)

        # Include window title for focus
        if action.window_title:
            parts.append(action.window_title)

        # Include the label of the targeted element (e.g. a "Delete" or
        # "Submit" button). Otherwise clicking a destructive or irreversible
        # control is never caught; only typing those words would be.
        # TYPE actions are excluded: there the element is the input field,
        # whose label describes the field rather than the effect of the
        # action. The credential guard inspects field names separately.
        element = action.element
        if element and element.name and action.type != ActionType.TYPE:
            parts.append(element.name)

        # Include raw action data if present (non-string values are skipped)
        if action.raw:
            for value in action.raw.values():
                if isinstance(value, str):
                    parts.append(value)

        return " ".join(parts)

    def _get_target_field_name(self, action: Action, observation: Observation) -> str:
        """Get the name/label of the field targeted by a TYPE action.

        Joins the name, role and automation id of both the action's target
        element and the observation's focused element (whichever are set).
        Returns "" when none are available.
        """
        parts = []

        # Check action's target element
        if action.element:
            if action.element.name:
                parts.append(action.element.name)
            if action.element.role:
                parts.append(action.element.role)
            if action.element.automation_id:
                parts.append(action.element.automation_id)

        # Check observation's focused element
        if observation.focused_element:
            if observation.focused_element.name:
                parts.append(observation.focused_element.name)
            if observation.focused_element.role:
                parts.append(observation.focused_element.role)
            if observation.focused_element.automation_id:
                parts.append(observation.focused_element.automation_id)

        return " ".join(parts)

    def _get_action_confidence(self, action: Action) -> Optional[float]:
        """Extract confidence score from action's raw metadata.

        Returns None when there is no ``"confidence"`` key or its value
        cannot be converted with ``float()``.
        """
        if action.raw and "confidence" in action.raw:
            try:
                return float(action.raw["confidence"])
            except (TypeError, ValueError):
                pass
        return None

    def _check_blocklist(self, action: Action) -> Optional[SafetyAssessment]:
        """Check if action text matches any blocklist pattern."""
        action_text = self._get_action_text(action)
        matched = self._matches_any(action_text, self._blocklist_re)

        if matched:
            return SafetyAssessment(
                decision=SafetyDecision.BLOCK,
                reason=f"Action contains blocked keyword pattern: {matched}",
                triggered_rules=["blocklist"],
                confidence=1.0,
            )
        return None

    def _check_loop_detection(
        self, observation: Observation
    ) -> Optional[SafetyAssessment]:
        """Check if we're stuck in a loop visiting the same state.

        Side effect: increments the visit count for the observation's state
        (unless the observation has no identifying context).
        """
        # Observations without window title, app name or URL (e.g.
        # screenshot-only) all hash to the same value. Counting them would
        # flag every episode as a loop after ``loop_threshold`` steps
        # regardless of progress. With no identifying context a loop cannot
        # be detected, so skip the check and don't record a visit.
        if not (observation.window_title or observation.app_name or observation.url):
            return None

        state_hash = self._compute_state_hash(observation)
        self._state_visit_counts[state_hash] = (
            self._state_visit_counts.get(state_hash, 0) + 1
        )

        visit_count = self._state_visit_counts[state_hash]
        if visit_count >= self.config.loop_threshold:
            return SafetyAssessment(
                decision=SafetyDecision.BLOCK,
                reason=f"Loop detected: same state visited {visit_count} times "
                f"(threshold: {self.config.loop_threshold})",
                triggered_rules=["loop_detection"],
                confidence=1.0,
            )
        return None

    def _check_credential_guard(
        self, action: Action, observation: Observation
    ) -> Optional[SafetyAssessment]:
        """Check if TYPE action targets a credential/sensitive field."""
        if action.type != ActionType.TYPE:
            return None

        field_name = self._get_target_field_name(action, observation)
        if not field_name:
            return None

        # Check if field is in allowlist (override)
        if self._matches_any(field_name, self._credential_allowlist_re):
            return None

        # Check if field matches credential patterns
        matched = self._matches_any(field_name, self._credential_re)
        if matched:
            return SafetyAssessment(
                decision=SafetyDecision.REQUIRE_CONFIRMATION,
                reason=f"Typing into credential field matching '{matched}': {field_name}",
                triggered_rules=["credential_guard"],
                confidence=1.0,
            )
        return None

    def _check_irreversibility(self, action: Action) -> Optional[SafetyAssessment]:
        """Check if action appears to be irreversible (submit, send, etc.)."""
        action_text = self._get_action_text(action)
        matched = self._matches_any(action_text, self._irreversible_re)

        if matched:
            return SafetyAssessment(
                decision=SafetyDecision.REQUIRE_CONFIRMATION,
                reason=f"Action may be irreversible (matched pattern: {matched})",
                triggered_rules=["irreversibility"],
                confidence=0.9,
            )
        return None

    def _check_app_window_mismatch(
        self, observation: Observation
    ) -> Optional[SafetyAssessment]:
        """Check if current app/window doesn't match expected context.

        A missing app name or window title is compared as "" and so counts
        as a mismatch whenever the corresponding expectation is configured.
        """
        triggered_rules = []
        reasons = []

        # Check app name (case-insensitive exact match)
        if self.config.expected_app:
            current_app = observation.app_name or ""
            if current_app.lower() != self.config.expected_app.lower():
                triggered_rules.append("app_mismatch")
                reasons.append(
                    f"Expected app '{self.config.expected_app}', got '{current_app}'"
                )

        # Check window title pattern (regex search, not full match)
        if self._window_pattern_re:
            current_title = observation.window_title or ""
            if not self._window_pattern_re.search(current_title):
                triggered_rules.append("window_mismatch")
                reasons.append(
                    f"Window title '{current_title}' doesn't match pattern "
                    f"'{self.config.expected_window_pattern}'"
                )

        if triggered_rules:
            return SafetyAssessment(
                decision=SafetyDecision.REQUIRE_CONFIRMATION,
                reason="; ".join(reasons),
                triggered_rules=triggered_rules,
                confidence=0.95,
            )
        return None

    def _check_confidence_threshold(self, action: Action) -> Optional[SafetyAssessment]:
        """Check if action confidence is below threshold.

        Actions without a parseable confidence are not checked. A NaN
        confidence is treated as below the threshold.
        """
        confidence = self._get_action_confidence(action)
        if confidence is None:
            return None

        threshold = self.config.confidence_threshold
        # Written as ``not (confidence >= threshold)`` rather than
        # ``confidence < threshold``. NaN compares False against everything,
        # so the ``<`` form would let a NaN confidence pass unchecked.
        if not (confidence >= threshold):
            return SafetyAssessment(
                decision=SafetyDecision.REQUIRE_CONFIRMATION,
                reason=f"Action confidence ({confidence:.2f}) below threshold "
                f"({threshold})",
                triggered_rules=["confidence_threshold"],
                confidence=1.0,
            )
        return None

    def assess(
        self,
        action: Action,
        observation: Observation,
        trace: Optional[list[Step]] = None,
    ) -> SafetyAssessment:
        """Evaluate an action against all safety rules.

        Checks are evaluated in priority order. First triggered rule wins.
        Note that the loop-detection visit count is updated as a side effect
        whenever the blocklist check does not trigger first.

        Args:
            action: The action proposed by the policy
            observation: Current state observation
            trace: Execution history (currently unused, reserved for future use)

        Returns:
            SafetyAssessment with decision and reasoning
        """
        # Priority 1: Blocklist (BLOCK)
        assessment = self._check_blocklist(action)
        if assessment:
            return assessment

        # Priority 2: Loop Detection (BLOCK)
        assessment = self._check_loop_detection(observation)
        if assessment:
            return assessment

        # Priority 3: Credential Guard (REQUIRE_CONFIRMATION)
        assessment = self._check_credential_guard(action, observation)
        if assessment:
            return assessment

        # Priority 4: Irreversibility (REQUIRE_CONFIRMATION)
        assessment = self._check_irreversibility(action)
        if assessment:
            return assessment

        # Priority 5: App/Window Mismatch (REQUIRE_CONFIRMATION)
        assessment = self._check_app_window_mismatch(observation)
        if assessment:
            return assessment

        # Priority 6: Confidence Threshold (REQUIRE_CONFIRMATION)
        assessment = self._check_confidence_threshold(action)
        if assessment:
            return assessment

        # Default: ALLOW
        return SafetyAssessment(
            decision=SafetyDecision.ALLOW,
            reason="All safety checks passed",
            triggered_rules=[],
            confidence=1.0,
        )

    def reset(self) -> None:
        """Clear internal state (loop detection history).

        Call this method between episodes to reset the state visit counts.
        """
        self._state_visit_counts.clear()

    def get_state_visit_counts(self) -> dict[str, int]:
        """Get a copy of the current state visit counts (for debugging/monitoring)."""
        return self._state_visit_counts.copy()
@@ -0,0 +1,113 @@
1
"""
Episode Schema - Canonical format for GUI trajectory data.

A standardized contract for representing GUI automation episodes, enabling
interoperability across training pipelines, benchmarks, and recording tools.

Installation:
    pip install openadapt-ml
    # or: uv add openadapt-ml

Basic Usage:
    from openadapt_ml.schema import (
        Action,
        ActionType,
        Episode,
        Observation,
        Step,
        load_episode,
        save_episode,
        validate_episode,
    )

    # Create an episode
    episode = Episode(
        episode_id="demo_001",
        instruction="Open Notepad and type Hello World",
        steps=[
            Step(
                step_index=0,
                observation=Observation(screenshot_path="step_0.png"),
                action=Action(type=ActionType.CLICK, coordinates={"x": 100, "y": 200}),
            ),
            Step(
                step_index=1,
                observation=Observation(screenshot_path="step_1.png"),
                action=Action(type=ActionType.TYPE, text="Hello World"),
            ),
        ],
        success=True,
    )

    # Save/load JSON
    save_episode(episode, "episode.json")
    episode = load_episode("episode.json")

    # Validate external data (e.g. a dict parsed from JSON)
    is_valid, error = validate_episode(data)

Coordinate Systems:
    # Pixel coordinates (absolute)
    Action(type=ActionType.CLICK, coordinates={"x": 512, "y": 384})

    # Normalized coordinates (0.0-1.0, resolution-independent)
    Action(type=ActionType.CLICK, normalized_coordinates=(0.5, 0.375))

    # Both can coexist - use whichever fits your pipeline

Converting from Other Formats:
    from openadapt_ml.schema.converters import (
        from_waa_trajectory,
        to_waa_trajectory,
    )

    # Convert Windows Agent Arena format
    episode = from_waa_trajectory(trajectory_list, task_info_dict)

    # Convert back
    trajectory, task_info = to_waa_trajectory(episode)

JSON Schema Export:
    from openadapt_ml.schema import export_json_schema

    # For external validation tools (e.g., JSON Schema validators, TypeScript codegen)
    export_json_schema("episode.schema.json")

See Also:
    - docs/schema/episode.schema.json - Full JSON Schema
    - openadapt_ml.schema.episode - Model definitions
    - openadapt_ml.schema.converters - Format converters
"""
67
+
68
+ from openadapt_ml.schema.episode import (
69
+ SCHEMA_VERSION,
70
+ Episode,
71
+ Step,
72
+ Action,
73
+ Observation,
74
+ ActionType,
75
+ BenchmarkSource,
76
+ Coordinates,
77
+ BoundingBox,
78
+ UIElement,
79
+ validate_episode,
80
+ load_episode,
81
+ save_episode,
82
+ export_json_schema,
83
+ )
84
+
85
# Optional perception integration. ``UIElementGraph`` needs the separate
# openadapt-grounding package. When the import fails the name is still
# exported, but is bound to None so callers can feature-check it.
try:
    from openadapt_ml.perception.integration import UIElementGraph
except ImportError:
    UIElementGraph = None  # type: ignore

__all__ = [
    "SCHEMA_VERSION",
    # Core episode models
    "Episode", "Step", "Action", "Observation",
    # Supporting models and enums
    "ActionType", "BenchmarkSource", "Coordinates", "BoundingBox", "UIElement",
    # Perception graph (None when openadapt-grounding is missing)
    "UIElementGraph",
    # Validation and I/O helpers
    "validate_episode", "load_episode", "save_episode", "export_json_schema",
]
+ ]