openadapt-ml 0.2.0__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only.
Files changed (95)
  1. openadapt_ml/baselines/__init__.py +121 -0
  2. openadapt_ml/baselines/adapter.py +185 -0
  3. openadapt_ml/baselines/cli.py +314 -0
  4. openadapt_ml/baselines/config.py +448 -0
  5. openadapt_ml/baselines/parser.py +922 -0
  6. openadapt_ml/baselines/prompts.py +787 -0
  7. openadapt_ml/benchmarks/__init__.py +13 -115
  8. openadapt_ml/benchmarks/agent.py +265 -421
  9. openadapt_ml/benchmarks/azure.py +28 -19
  10. openadapt_ml/benchmarks/azure_ops_tracker.py +521 -0
  11. openadapt_ml/benchmarks/cli.py +1722 -4847
  12. openadapt_ml/benchmarks/trace_export.py +631 -0
  13. openadapt_ml/benchmarks/viewer.py +22 -5
  14. openadapt_ml/benchmarks/vm_monitor.py +530 -29
  15. openadapt_ml/benchmarks/waa_deploy/Dockerfile +47 -53
  16. openadapt_ml/benchmarks/waa_deploy/api_agent.py +21 -20
  17. openadapt_ml/cloud/azure_inference.py +3 -5
  18. openadapt_ml/cloud/lambda_labs.py +722 -307
  19. openadapt_ml/cloud/local.py +2038 -487
  20. openadapt_ml/cloud/ssh_tunnel.py +68 -26
  21. openadapt_ml/datasets/next_action.py +40 -30
  22. openadapt_ml/evals/grounding.py +8 -3
  23. openadapt_ml/evals/plot_eval_metrics.py +15 -13
  24. openadapt_ml/evals/trajectory_matching.py +41 -26
  25. openadapt_ml/experiments/demo_prompt/format_demo.py +16 -6
  26. openadapt_ml/experiments/demo_prompt/run_experiment.py +26 -16
  27. openadapt_ml/experiments/representation_shootout/__init__.py +70 -0
  28. openadapt_ml/experiments/representation_shootout/conditions.py +708 -0
  29. openadapt_ml/experiments/representation_shootout/config.py +390 -0
  30. openadapt_ml/experiments/representation_shootout/evaluator.py +659 -0
  31. openadapt_ml/experiments/representation_shootout/runner.py +687 -0
  32. openadapt_ml/experiments/waa_demo/runner.py +29 -14
  33. openadapt_ml/export/parquet.py +36 -24
  34. openadapt_ml/grounding/detector.py +18 -14
  35. openadapt_ml/ingest/__init__.py +8 -6
  36. openadapt_ml/ingest/capture.py +25 -22
  37. openadapt_ml/ingest/loader.py +7 -4
  38. openadapt_ml/ingest/synthetic.py +189 -100
  39. openadapt_ml/models/api_adapter.py +14 -4
  40. openadapt_ml/models/base_adapter.py +10 -2
  41. openadapt_ml/models/providers/__init__.py +288 -0
  42. openadapt_ml/models/providers/anthropic.py +266 -0
  43. openadapt_ml/models/providers/base.py +299 -0
  44. openadapt_ml/models/providers/google.py +376 -0
  45. openadapt_ml/models/providers/openai.py +342 -0
  46. openadapt_ml/models/qwen_vl.py +46 -19
  47. openadapt_ml/perception/__init__.py +35 -0
  48. openadapt_ml/perception/integration.py +399 -0
  49. openadapt_ml/retrieval/demo_retriever.py +50 -24
  50. openadapt_ml/retrieval/embeddings.py +9 -8
  51. openadapt_ml/retrieval/retriever.py +3 -1
  52. openadapt_ml/runtime/__init__.py +50 -0
  53. openadapt_ml/runtime/policy.py +18 -5
  54. openadapt_ml/runtime/safety_gate.py +471 -0
  55. openadapt_ml/schema/__init__.py +9 -0
  56. openadapt_ml/schema/converters.py +74 -27
  57. openadapt_ml/schema/episode.py +31 -18
  58. openadapt_ml/scripts/capture_screenshots.py +530 -0
  59. openadapt_ml/scripts/compare.py +85 -54
  60. openadapt_ml/scripts/demo_policy.py +4 -1
  61. openadapt_ml/scripts/eval_policy.py +15 -9
  62. openadapt_ml/scripts/make_gif.py +1 -1
  63. openadapt_ml/scripts/prepare_synthetic.py +3 -1
  64. openadapt_ml/scripts/train.py +21 -9
  65. openadapt_ml/segmentation/README.md +920 -0
  66. openadapt_ml/segmentation/__init__.py +97 -0
  67. openadapt_ml/segmentation/adapters/__init__.py +5 -0
  68. openadapt_ml/segmentation/adapters/capture_adapter.py +420 -0
  69. openadapt_ml/segmentation/annotator.py +610 -0
  70. openadapt_ml/segmentation/cache.py +290 -0
  71. openadapt_ml/segmentation/cli.py +674 -0
  72. openadapt_ml/segmentation/deduplicator.py +656 -0
  73. openadapt_ml/segmentation/frame_describer.py +788 -0
  74. openadapt_ml/segmentation/pipeline.py +340 -0
  75. openadapt_ml/segmentation/schemas.py +622 -0
  76. openadapt_ml/segmentation/segment_extractor.py +634 -0
  77. openadapt_ml/training/azure_ops_viewer.py +1097 -0
  78. openadapt_ml/training/benchmark_viewer.py +52 -41
  79. openadapt_ml/training/shared_ui.py +7 -7
  80. openadapt_ml/training/stub_provider.py +57 -35
  81. openadapt_ml/training/trainer.py +143 -86
  82. openadapt_ml/training/trl_trainer.py +70 -21
  83. openadapt_ml/training/viewer.py +323 -108
  84. openadapt_ml/training/viewer_components.py +180 -0
  85. {openadapt_ml-0.2.0.dist-info → openadapt_ml-0.2.1.dist-info}/METADATA +215 -14
  86. openadapt_ml-0.2.1.dist-info/RECORD +116 -0
  87. openadapt_ml/benchmarks/base.py +0 -366
  88. openadapt_ml/benchmarks/data_collection.py +0 -432
  89. openadapt_ml/benchmarks/live_tracker.py +0 -180
  90. openadapt_ml/benchmarks/runner.py +0 -418
  91. openadapt_ml/benchmarks/waa.py +0 -761
  92. openadapt_ml/benchmarks/waa_live.py +0 -619
  93. openadapt_ml-0.2.0.dist-info/RECORD +0 -86
  94. {openadapt_ml-0.2.0.dist-info → openadapt_ml-0.2.1.dist-info}/WHEEL +0 -0
  95. {openadapt_ml-0.2.0.dist-info → openadapt_ml-0.2.1.dist-info}/licenses/LICENSE +0 -0
openadapt_ml/perception/integration.py (new file)
@@ -0,0 +1,399 @@
+ """
+ Integration Bridge between openadapt-grounding and openadapt-ml
+
+ This module provides the UIElementGraph class which wraps parsed UI elements
+ from openadapt-grounding parsers and converts them to the openadapt-ml schema.
+
+ Types imported from openadapt-grounding:
+ - Parser: Protocol for UI element parsers (OmniParser, UITars, etc.)
+ - Element: A detected UI element with normalized bounds
+ - LocatorResult: Result of attempting to locate an element
+ - RegistryEntry: A stable element that survived temporal filtering
+ """
+
+ from __future__ import annotations
+
+ import uuid
+ from collections import Counter
+ from typing import Any, Literal, Optional, Union
+
+ from pydantic import BaseModel, Field
+
+ from openadapt_ml.schema.episode import BoundingBox, UIElement
+
+ # Lazy import for openadapt-grounding to make it an optional dependency
+ _grounding_available: Optional[bool] = None
+
+
+ def _check_grounding_available() -> bool:
+     """Check if openadapt-grounding is installed."""
+     global _grounding_available
+     if _grounding_available is None:
+         try:
+             import openadapt_grounding  # noqa: F401
+
+             _grounding_available = True
+         except ImportError:
+             _grounding_available = False
+     return _grounding_available
+
+
+ def _get_grounding_types():
+     """Import and return openadapt-grounding types.
+
+     Raises:
+         ImportError: If openadapt-grounding is not installed
+     """
+     if not _check_grounding_available():
+         raise ImportError(
+             "openadapt-grounding is required for perception integration. "
+             "Install it with: pip install openadapt-grounding"
+         )
+     from openadapt_grounding import Element, LocatorResult, Parser, RegistryEntry
+
+     return Element, LocatorResult, Parser, RegistryEntry
+
+
+ # Source types for UI element graphs
+ SourceType = Literal["omniparser", "uitars", "ax", "dom", "mixed"]
+
+
+ def element_to_ui_element(
+     element: Any,  # openadapt_grounding.Element
+     image_width: Optional[int] = None,
+     image_height: Optional[int] = None,
+     element_index: Optional[int] = None,
+     source: Optional[str] = None,
+ ) -> UIElement:
+     """Convert an openadapt-grounding Element to a schema UIElement.
+
+     Args:
+         element: Element from openadapt-grounding with normalized bounds
+         image_width: Image width in pixels (for coordinate conversion)
+         image_height: Image height in pixels (for coordinate conversion)
+         element_index: Optional index to use as element_id
+         source: Source parser name (stored in automation_id for reference)
+
+     Returns:
+         UIElement with converted bounds and properties
+     """
+     Element, _, _, _ = _get_grounding_types()
+
+     if not isinstance(element, Element):
+         raise TypeError(f"Expected Element, got {type(element)}")
+
+     # Extract normalized bounds (x, y, width, height)
+     norm_x, norm_y, norm_w, norm_h = element.bounds
+
+     # Convert to pixel coordinates if dimensions provided
+     if image_width is not None and image_height is not None:
+         bounds = BoundingBox(
+             x=int(norm_x * image_width),
+             y=int(norm_y * image_height),
+             width=int(norm_w * image_width),
+             height=int(norm_h * image_height),
+         )
+     else:
+         # Store normalized as integers (multiply by 10000 for precision)
+         # This allows using BoundingBox which requires int values
+         bounds = BoundingBox(
+             x=int(norm_x * 10000),
+             y=int(norm_y * 10000),
+             width=int(norm_w * 10000),
+             height=int(norm_h * 10000),
+         )
+
+     # Generate element_id from index if provided
+     element_id = str(element_index) if element_index is not None else None
+
+     return UIElement(
+         role=element.element_type,
+         name=element.text,
+         bounds=bounds,
+         element_id=element_id,
+         automation_id=source,  # Store source in automation_id for reference
+     )
+
+
+ def ui_element_to_element(
+     ui_element: UIElement,
+     image_width: Optional[int] = None,
+     image_height: Optional[int] = None,
+ ) -> Any:  # Returns openadapt_grounding.Element
+     """Convert a schema UIElement back to an openadapt-grounding Element.
+
+     Useful for evaluation or when passing elements back to grounding functions.
+
+     Args:
+         ui_element: UIElement from openadapt-ml schema
+         image_width: Image width in pixels (for coordinate normalization)
+         image_height: Image height in pixels (for coordinate normalization)
+
+     Returns:
+         Element with normalized bounds
+     """
+     Element, _, _, _ = _get_grounding_types()
+
+     if ui_element.bounds is None:
+         # Return element with zero bounds if no bounds available
+         return Element(
+             bounds=(0.0, 0.0, 0.0, 0.0),
+             text=ui_element.name,
+             element_type=ui_element.role or "unknown",
+             confidence=1.0,
+         )
+
+     # Convert pixel coordinates to normalized
+     if image_width is not None and image_height is not None:
+         norm_x = ui_element.bounds.x / image_width
+         norm_y = ui_element.bounds.y / image_height
+         norm_w = ui_element.bounds.width / image_width
+         norm_h = ui_element.bounds.height / image_height
+     else:
+         # Assume bounds are already in 10000-scale normalized format
+         norm_x = ui_element.bounds.x / 10000
+         norm_y = ui_element.bounds.y / 10000
+         norm_w = ui_element.bounds.width / 10000
+         norm_h = ui_element.bounds.height / 10000
+
+     return Element(
+         bounds=(norm_x, norm_y, norm_w, norm_h),
+         text=ui_element.name,
+         element_type=ui_element.role or "unknown",
+         confidence=1.0,
+     )
+
+
+ class UIElementGraph(BaseModel):
+     """A graph of UI elements parsed from a screenshot.
+
+     This class wraps a list of UIElement objects and tracks their source
+     (which parser produced them). It provides a bridge between openadapt-grounding
+     parsers and the openadapt-ml schema system.
+
+     Attributes:
+         graph_id: Unique identifier for this graph (UUID)
+         elements: List of UIElement objects
+         source: Primary source parser ("omniparser", "uitars", "ax", "dom", "mixed")
+         source_summary: Count of elements by source (e.g., {"omniparser": 15, "ax": 8})
+         timestamp_ms: Optional timestamp when the screenshot was captured
+         image_width: Original image width (for coordinate reference)
+         image_height: Original image height (for coordinate reference)
+
+     Example:
+         >>> from openadapt_grounding import OmniParserClient
+         >>> from openadapt_ml.perception import UIElementGraph
+         >>>
+         >>> parser = OmniParserClient(endpoint="http://localhost:8000")
+         >>> elements = parser.parse(image)
+         >>> graph = UIElementGraph.from_parser_output(elements, "omniparser")
+         >>> print(f"Found {len(graph.elements)} elements")
+     """
+
+     graph_id: str = Field(
+         default_factory=lambda: str(uuid.uuid4()),
+         description="Unique identifier for this graph",
+     )
+     elements: list[UIElement] = Field(
+         default_factory=list,
+         description="List of UI elements in the graph",
+     )
+     source: SourceType = Field(
+         default="mixed",
+         description="Primary source parser for the elements",
+     )
+     source_summary: dict[str, int] = Field(
+         default_factory=dict,
+         description="Count of elements by source",
+     )
+     timestamp_ms: Optional[int] = Field(
+         None,
+         description="Timestamp when the screenshot was captured (milliseconds)",
+     )
+     image_width: Optional[int] = Field(
+         None,
+         description="Original image width in pixels",
+     )
+     image_height: Optional[int] = Field(
+         None,
+         description="Original image height in pixels",
+     )
+
+     @classmethod
+     def from_parser_output(
+         cls,
+         elements: list[Any],  # list[openadapt_grounding.Element]
+         source: Union[SourceType, str],
+         image_width: Optional[int] = None,
+         image_height: Optional[int] = None,
+         timestamp_ms: Optional[int] = None,
+     ) -> "UIElementGraph":
+         """Create a UIElementGraph from parser output.
+
+         Args:
+             elements: List of Element objects from openadapt-grounding parser
+             source: Parser source name ("omniparser", "uitars", "ax", "dom")
+             image_width: Image width in pixels (for coordinate conversion)
+             image_height: Image height in pixels (for coordinate conversion)
+             timestamp_ms: Optional timestamp when screenshot was captured
+
+         Returns:
+             UIElementGraph with converted UIElement objects
+         """
+         # Validate source type
+         valid_sources = {"omniparser", "uitars", "ax", "dom", "mixed"}
+         if source not in valid_sources:
+             # Allow custom sources but warn
+             pass
+
+         # Convert elements
+         ui_elements = [
+             element_to_ui_element(
+                 element=el,
+                 image_width=image_width,
+                 image_height=image_height,
+                 element_index=i,
+                 source=source,
+             )
+             for i, el in enumerate(elements)
+         ]
+
+         # Build source summary
+         source_summary = {source: len(elements)}
+
+         return cls(
+             elements=ui_elements,
+             source=source if source in valid_sources else "mixed",
+             source_summary=source_summary,
+             timestamp_ms=timestamp_ms,
+             image_width=image_width,
+             image_height=image_height,
+         )
+
+     @classmethod
+     def merge(
+         cls,
+         graphs: list["UIElementGraph"],
+         deduplicate_iou_threshold: Optional[float] = None,
+     ) -> "UIElementGraph":
+         """Merge multiple UIElementGraphs into one.
+
+         Args:
+             graphs: List of UIElementGraph objects to merge
+             deduplicate_iou_threshold: If provided, remove duplicate elements
+                 with IoU greater than this threshold (0.0-1.0)
+
+         Returns:
+             New UIElementGraph with combined elements
+         """
+         if not graphs:
+             return cls()
+
+         # Combine all elements
+         all_elements: list[UIElement] = []
+         source_counts: Counter[str] = Counter()
+
+         for graph in graphs:
+             all_elements.extend(graph.elements)
+             for src, count in graph.source_summary.items():
+                 source_counts[src] += count
+
+         # Get image dimensions from first graph that has them
+         image_width = None
+         image_height = None
+         for graph in graphs:
+             if graph.image_width is not None:
+                 image_width = graph.image_width
+                 image_height = graph.image_height
+                 break
+
+         # TODO: Implement deduplication by IoU if threshold provided
+         # For now, just combine all elements
+
+         return cls(
+             elements=all_elements,
+             source="mixed" if len(source_counts) > 1 else list(source_counts.keys())[0],
+             source_summary=dict(source_counts),
+             timestamp_ms=graphs[0].timestamp_ms if graphs else None,
+             image_width=image_width,
+             image_height=image_height,
+         )
+
+     def to_dict(self) -> dict[str, Any]:
+         """Convert to JSON-serializable dictionary.
+
+         Returns:
+             Dictionary representation suitable for JSON serialization
+         """
+         return self.model_dump()
+
+     @classmethod
+     def from_dict(cls, data: dict[str, Any]) -> "UIElementGraph":
+         """Create from dictionary.
+
+         Args:
+             data: Dictionary representation
+
+         Returns:
+             UIElementGraph instance
+         """
+         return cls.model_validate(data)
+
+     def get_element_by_id(self, element_id: str) -> Optional[UIElement]:
+         """Find element by ID.
+
+         Args:
+             element_id: Element ID to search for
+
+         Returns:
+             UIElement if found, None otherwise
+         """
+         for element in self.elements:
+             if element.element_id == element_id:
+                 return element
+         return None
+
+     def get_elements_by_role(self, role: str) -> list[UIElement]:
+         """Find elements by role.
+
+         Args:
+             role: Role to filter by (e.g., "button", "textbox")
+
+         Returns:
+             List of matching UIElements
+         """
+         return [el for el in self.elements if el.role == role]
+
+     def get_elements_by_text(
+         self,
+         text: str,
+         exact: bool = False,
+     ) -> list[UIElement]:
+         """Find elements by text content.
+
+         Args:
+             text: Text to search for
+             exact: If True, require exact match; if False, use substring match
+
+         Returns:
+             List of matching UIElements
+         """
+         results = []
+         for el in self.elements:
+             if el.name is None:
+                 continue
+             if exact:
+                 if el.name == text:
+                     results.append(el)
+             else:
+                 if text.lower() in el.name.lower():
+                     results.append(el)
+         return results
+
+     def __len__(self) -> int:
+         """Return number of elements in the graph."""
+         return len(self.elements)
+
+     def __iter__(self):
+         """Iterate over elements."""
+         return iter(self.elements)
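Note on the TODO in merge() above: IoU-based deduplication is left unimplemented. A minimal sketch of what that filter could look like, as hypothetical helpers that are not part of the package (assumes all bounds share one coordinate space):

from openadapt_ml.schema.episode import BoundingBox, UIElement  # same imports as the module above

def _iou(a: BoundingBox, b: BoundingBox) -> float:
    """Intersection-over-union of two bounding boxes."""
    x1, y1 = max(a.x, b.x), max(a.y, b.y)
    x2 = min(a.x + a.width, b.x + b.width)
    y2 = min(a.y + a.height, b.y + b.height)
    inter = max(0, x2 - x1) * max(0, y2 - y1)
    union = a.width * a.height + b.width * b.height - inter
    return inter / union if union > 0 else 0.0

def _deduplicate_by_iou(elements: list[UIElement], threshold: float) -> list[UIElement]:
    """Keep an element unless it overlaps an already-kept element above threshold."""
    kept: list[UIElement] = []
    for el in elements:
        if el.bounds is None or all(
            k.bounds is None or _iou(el.bounds, k.bounds) <= threshold for k in kept
        ):
            kept.append(el)
    return kept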
openadapt_ml/retrieval/demo_retriever.py
@@ -32,7 +32,6 @@ Example usage:

  from __future__ import annotations

- import hashlib
  import json
  import logging
  from dataclasses import dataclass, field
@@ -206,11 +205,15 @@ class DemoRetriever:
          platform = self._detect_platform(episode, app_name, domain)

          # Extract action types
-         action_types = list(set(
-             step.action.type.value if hasattr(step.action.type, 'value') else str(step.action.type)
-             for step in episode.steps
-             if step.action
-         ))
+         action_types = list(
+             set(
+                 step.action.type.value
+                 if hasattr(step.action.type, "value")
+                 else str(step.action.type)
+                 for step in episode.steps
+                 if step.action
+             )
+         )

          # Extract key elements
          key_elements = self._extract_key_elements(episode)
@@ -297,9 +300,13 @@ class DemoRetriever:
              return

          if not self._demos:
-             raise ValueError("Cannot build index: no demos added. Use add_demo() first.")
+             raise ValueError(
+                 "Cannot build index: no demos added. Use add_demo() first."
+             )

-         logger.info(f"Building index for {len(self._demos)} demos using {self.embedding_method}...")
+         logger.info(
+             f"Building index for {len(self._demos)} demos using {self.embedding_method}..."
+         )

          # Initialize embedder if needed
          if self._embedder is None:
@@ -358,16 +365,21 @@
          }

          with open(path / "index.json", "w") as f:
-             json.dump({
-                 "embedding_method": self.embedding_method,
-                 "embedding_model": self.embedding_model,
-                 "demos": metadata,
-                 "embedder_state": embedder_state,
-             }, f, indent=2)
+             json.dump(
+                 {
+                     "embedding_method": self.embedding_method,
+                     "embedding_model": self.embedding_model,
+                     "demos": metadata,
+                     "embedder_state": embedder_state,
+                 },
+                 f,
+                 indent=2,
+             )

          # Save embeddings as numpy array
          try:
              import numpy as np
+
              if self._embeddings_matrix is not None:
                  np.save(path / "embeddings.npy", self._embeddings_matrix)
          except ImportError:
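For reference, the payload written by the json.dump call above has roughly this shape, sketched as a Python literal (values are illustrative; demos entries carry at least demo_id/goal/file_path per the loader below, and embedder_state for TF-IDF holds vocab/vocab_to_idx per the restore code further down):

index_payload = {
    "embedding_method": "tfidf",
    "embedding_model": "tfidf-default",  # illustrative model name
    "demos": [
        {"demo_id": "demo-001", "goal": "Save the file", "file_path": "demos/demo-001.json"},
    ],
    "embedder_state": {
        "vocab": ["click", "save", "file"],
        "vocab_to_idx": {"click": 0, "save": 1, "file": 2},
    },
}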
@@ -375,7 +387,11 @@

          logger.info(f"Index saved to {path}")

-     def load_index(self, path: Union[str, Path], episode_loader: Optional[Callable[[str], Episode]] = None) -> None:
+     def load_index(
+         self,
+         path: Union[str, Path],
+         episode_loader: Optional[Callable[[str], Episode]] = None,
+     ) -> None:
          """Load index from disk.

          Args:
@@ -395,6 +411,7 @@
          embeddings = None
          try:
              import numpy as np
+
              embeddings_path = path / "embeddings.npy"
              if embeddings_path.exists():
                  embeddings = np.load(embeddings_path)
@@ -409,11 +426,14 @@
              try:
                  episode = episode_loader(meta["file_path"])
              except Exception as e:
-                 logger.warning(f"Failed to load episode from {meta['file_path']}: {e}")
+                 logger.warning(
+                     f"Failed to load episode from {meta['file_path']}: {e}"
+                 )

              # Create placeholder episode if not loaded
              if episode is None:
                  from openadapt_ml.schema import Action, ActionType, Observation, Step
+
                  episode = Episode(
                      episode_id=meta["demo_id"],
                      instruction=meta["goal"],
@@ -454,6 +474,7 @@
          embedder_state = data.get("embedder_state", {})
          if embedder_state and self.embedding_method == "tfidf":
              from openadapt_ml.retrieval.embeddings import TFIDFEmbedder
+
              self._embedder = TFIDFEmbedder()
              self._embedder.vocab = embedder_state.get("vocab", [])
              self._embedder.vocab_to_idx = embedder_state.get("vocab_to_idx", {})
@@ -512,12 +533,14 @@
              bonus = self._compute_context_bonus(demo, app_context, domain_context)
              total_score = text_score + bonus

-             results.append(RetrievalResult(
-                 demo=demo,
-                 score=total_score,
-                 text_score=text_score,
-                 domain_bonus=bonus,
-             ))
+             results.append(
+                 RetrievalResult(
+                     demo=demo,
+                     score=total_score,
+                     text_score=text_score,
+                     domain_bonus=bonus,
+                 )
+             )

          # Sort by score (descending)
          results.sort(key=lambda r: r.score, reverse=True)
@@ -618,6 +641,7 @@
      def _format_action_minimal(self, action: Any) -> str:
          """Format action as minimal string."""
          from openadapt_ml.experiments.demo_prompt.format_demo import format_action
+
          return format_action(action)

      # =========================================================================
@@ -628,10 +652,12 @@
          """Initialize the embedding backend."""
          if self.embedding_method == "tfidf":
              from openadapt_ml.retrieval.embeddings import TFIDFEmbedder
+
              self._embedder = TFIDFEmbedder()

          elif self.embedding_method == "sentence_transformers":
              from openadapt_ml.retrieval.embeddings import SentenceTransformerEmbedder
+
              self._embedder = SentenceTransformerEmbedder(
                  model_name=self.embedding_model,
                  cache_dir=self.cache_dir / "st_cache",
@@ -639,6 +665,7 @@

          elif self.embedding_method == "openai":
              from openadapt_ml.retrieval.embeddings import OpenAIEmbedder
+
              self._embedder = OpenAIEmbedder(
                  model_name=self.embedding_model,
                  cache_dir=self.cache_dir / "openai_cache",
@@ -739,8 +766,7 @@
          if filter_tags:
              filter_tags_set = set(filter_tags)
              candidates = [
-                 d for d in candidates
-                 if filter_tags_set.issubset(set(d.tags))
+                 d for d in candidates if filter_tags_set.issubset(set(d.tags))
              ]

          return candidates
openadapt_ml/retrieval/embeddings.py
@@ -25,9 +25,9 @@ import logging
  import re
  from abc import ABC, abstractmethod
  from collections import Counter
- from math import log, sqrt
+ from math import log
  from pathlib import Path
- from typing import Any, Dict, List, Optional, Union
+ from typing import Any, Dict, List, Optional

  logger = logging.getLogger(__name__)

@@ -121,7 +121,7 @@ class TFIDFEmbedder(BaseEmbedder):
          Returns:
              List of tokens.
          """
-         tokens = re.findall(r'\b\w+\b', text.lower())
+         tokens = re.findall(r"\b\w+\b", text.lower())
          return tokens

      def _compute_tf(self, tokens: List[str]) -> Dict[str, float]:
@@ -169,8 +169,7 @@
          # Compute IDF: log(N / df) + 1
          n_docs = max(len(documents), 1)
          self.idf = {
-             term: log(n_docs / doc_freq.get(term, 1)) + 1
-             for term in self.vocab
+             term: log(n_docs / doc_freq.get(term, 1)) + 1 for term in self.vocab
          }

          self._is_fitted = True
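To make the smoothing concrete, a standalone sketch that reproduces the IDF formula above on toy documents (illustrative data, not package code):

import re
from collections import Counter
from math import log

docs = ["Click the Save button", "Click the Close button", "Type the file name"]
tokenized = [re.findall(r"\b\w+\b", d.lower()) for d in docs]

# Document frequency per term, then IDF with the same smoothing as above: log(N / df) + 1
n_docs = len(docs)
doc_freq = Counter(t for tokens in tokenized for t in set(tokens))
idf = {term: log(n_docs / doc_freq[term]) + 1 for term in doc_freq}

print(round(idf["click"], 3))  # in 2 of 3 docs -> log(3/2) + 1 ~= 1.405
print(round(idf["save"], 3))   # in 1 of 3 docs -> log(3/1) + 1 ~= 2.099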
@@ -440,6 +439,7 @@ class OpenAIEmbedder(BaseEmbedder):
                  cached = json.load(f)
              # Convert lists back to arrays
              import numpy as np
+
              for key, val in cached.items():
                  self._embedding_cache[key] = np.array(val, dtype=np.float32)
              logger.debug(f"Loaded {len(self._embedding_cache)} cached embeddings")
@@ -455,8 +455,7 @@
          try:
              # Convert arrays to lists for JSON
              cache_data = {
-                 key: val.tolist()
-                 for key, val in self._embedding_cache.items()
+                 key: val.tolist() for key, val in self._embedding_cache.items()
              }
              with open(cache_file, "w") as f:
                  json.dump(cache_data, f)
@@ -525,7 +524,9 @@

          # Process in batches
          for batch_start in range(0, len(uncached_texts), self.batch_size):
-             batch_texts = uncached_texts[batch_start:batch_start + self.batch_size]
+             batch_texts = uncached_texts[
+                 batch_start : batch_start + self.batch_size
+             ]

              try:
                  response = client.embeddings.create(
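The slicing reformatted above sits inside a cache-aware batching loop. The overall pattern, as a self-contained sketch (model name and in-memory cache are illustrative; the real class also persists its cache to JSON, as shown earlier):

from openai import OpenAI

client = OpenAI()
cache: dict[str, list[float]] = {}

def embed_all(texts: list[str], batch_size: int = 100) -> list[list[float]]:
    """Embed texts, skipping cached ones, in batches of batch_size."""
    uncached = [t for t in texts if t not in cache]
    for batch_start in range(0, len(uncached), batch_size):
        batch = uncached[batch_start : batch_start + batch_size]
        response = client.embeddings.create(model="text-embedding-3-small", input=batch)
        for text, item in zip(batch, response.data):
            cache[text] = item.embedding
    return [cache[t] for t in texts]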
openadapt_ml/retrieval/retriever.py
@@ -49,7 +49,9 @@ class DemoRetriever:
          if index.is_empty():
              raise ValueError("Cannot create retriever from empty index")
          if not index.is_fitted():
-             raise ValueError("Index must be built before retrieval (call index.build())")
+             raise ValueError(
+                 "Index must be built before retrieval (call index.build())"
+             )

          self.index = index
          self.domain_bonus = domain_bonus
openadapt_ml/runtime/__init__.py (new file)
@@ -0,0 +1,50 @@
+ """
+ Runtime module for GUI automation agents.
+
+ This module provides:
+ - AgentPolicy: Runtime policy wrapper for VLM-based action prediction
+ - SafetyGate: Deterministic safety checks for action validation
+
+ Example usage:
+     from openadapt_ml.runtime import AgentPolicy, SafetyGate, SafetyConfig, SafetyDecision
+
+     # Create policy and safety gate
+     policy = AgentPolicy(adapter)
+     gate = SafetyGate(SafetyConfig(confidence_threshold=0.8))
+
+     # In agent loop
+     action, thought, state, raw = policy.predict_action_from_sample(sample)
+     assessment = gate.assess(action, observation)
+
+     if assessment.decision == SafetyDecision.ALLOW:
+         execute(action)
+     elif assessment.decision == SafetyDecision.REQUIRE_CONFIRMATION:
+         if user_confirms():
+             execute(action)
+     else:  # BLOCK
+         log_blocked(assessment.reason)
+ """
+
+ from openadapt_ml.runtime.policy import (
+     AgentPolicy,
+     PolicyOutput,
+     parse_thought_state_action,
+ )
+ from openadapt_ml.runtime.safety_gate import (
+     SafetyAssessment,
+     SafetyConfig,
+     SafetyDecision,
+     SafetyGate,
+ )
+
+ __all__ = [
+     # Policy
+     "AgentPolicy",
+     "PolicyOutput",
+     "parse_thought_state_action",
+     # Safety Gate
+     "SafetyGate",
+     "SafetyConfig",
+     "SafetyDecision",
+     "SafetyAssessment",
+ ]