openadapt-ml 0.2.0__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (95)
  1. openadapt_ml/baselines/__init__.py +121 -0
  2. openadapt_ml/baselines/adapter.py +185 -0
  3. openadapt_ml/baselines/cli.py +314 -0
  4. openadapt_ml/baselines/config.py +448 -0
  5. openadapt_ml/baselines/parser.py +922 -0
  6. openadapt_ml/baselines/prompts.py +787 -0
  7. openadapt_ml/benchmarks/__init__.py +13 -115
  8. openadapt_ml/benchmarks/agent.py +265 -421
  9. openadapt_ml/benchmarks/azure.py +28 -19
  10. openadapt_ml/benchmarks/azure_ops_tracker.py +521 -0
  11. openadapt_ml/benchmarks/cli.py +1722 -4847
  12. openadapt_ml/benchmarks/trace_export.py +631 -0
  13. openadapt_ml/benchmarks/viewer.py +22 -5
  14. openadapt_ml/benchmarks/vm_monitor.py +530 -29
  15. openadapt_ml/benchmarks/waa_deploy/Dockerfile +47 -53
  16. openadapt_ml/benchmarks/waa_deploy/api_agent.py +21 -20
  17. openadapt_ml/cloud/azure_inference.py +3 -5
  18. openadapt_ml/cloud/lambda_labs.py +722 -307
  19. openadapt_ml/cloud/local.py +2038 -487
  20. openadapt_ml/cloud/ssh_tunnel.py +68 -26
  21. openadapt_ml/datasets/next_action.py +40 -30
  22. openadapt_ml/evals/grounding.py +8 -3
  23. openadapt_ml/evals/plot_eval_metrics.py +15 -13
  24. openadapt_ml/evals/trajectory_matching.py +41 -26
  25. openadapt_ml/experiments/demo_prompt/format_demo.py +16 -6
  26. openadapt_ml/experiments/demo_prompt/run_experiment.py +26 -16
  27. openadapt_ml/experiments/representation_shootout/__init__.py +70 -0
  28. openadapt_ml/experiments/representation_shootout/conditions.py +708 -0
  29. openadapt_ml/experiments/representation_shootout/config.py +390 -0
  30. openadapt_ml/experiments/representation_shootout/evaluator.py +659 -0
  31. openadapt_ml/experiments/representation_shootout/runner.py +687 -0
  32. openadapt_ml/experiments/waa_demo/runner.py +29 -14
  33. openadapt_ml/export/parquet.py +36 -24
  34. openadapt_ml/grounding/detector.py +18 -14
  35. openadapt_ml/ingest/__init__.py +8 -6
  36. openadapt_ml/ingest/capture.py +25 -22
  37. openadapt_ml/ingest/loader.py +7 -4
  38. openadapt_ml/ingest/synthetic.py +189 -100
  39. openadapt_ml/models/api_adapter.py +14 -4
  40. openadapt_ml/models/base_adapter.py +10 -2
  41. openadapt_ml/models/providers/__init__.py +288 -0
  42. openadapt_ml/models/providers/anthropic.py +266 -0
  43. openadapt_ml/models/providers/base.py +299 -0
  44. openadapt_ml/models/providers/google.py +376 -0
  45. openadapt_ml/models/providers/openai.py +342 -0
  46. openadapt_ml/models/qwen_vl.py +46 -19
  47. openadapt_ml/perception/__init__.py +35 -0
  48. openadapt_ml/perception/integration.py +399 -0
  49. openadapt_ml/retrieval/demo_retriever.py +50 -24
  50. openadapt_ml/retrieval/embeddings.py +9 -8
  51. openadapt_ml/retrieval/retriever.py +3 -1
  52. openadapt_ml/runtime/__init__.py +50 -0
  53. openadapt_ml/runtime/policy.py +18 -5
  54. openadapt_ml/runtime/safety_gate.py +471 -0
  55. openadapt_ml/schema/__init__.py +9 -0
  56. openadapt_ml/schema/converters.py +74 -27
  57. openadapt_ml/schema/episode.py +31 -18
  58. openadapt_ml/scripts/capture_screenshots.py +530 -0
  59. openadapt_ml/scripts/compare.py +85 -54
  60. openadapt_ml/scripts/demo_policy.py +4 -1
  61. openadapt_ml/scripts/eval_policy.py +15 -9
  62. openadapt_ml/scripts/make_gif.py +1 -1
  63. openadapt_ml/scripts/prepare_synthetic.py +3 -1
  64. openadapt_ml/scripts/train.py +21 -9
  65. openadapt_ml/segmentation/README.md +920 -0
  66. openadapt_ml/segmentation/__init__.py +97 -0
  67. openadapt_ml/segmentation/adapters/__init__.py +5 -0
  68. openadapt_ml/segmentation/adapters/capture_adapter.py +420 -0
  69. openadapt_ml/segmentation/annotator.py +610 -0
  70. openadapt_ml/segmentation/cache.py +290 -0
  71. openadapt_ml/segmentation/cli.py +674 -0
  72. openadapt_ml/segmentation/deduplicator.py +656 -0
  73. openadapt_ml/segmentation/frame_describer.py +788 -0
  74. openadapt_ml/segmentation/pipeline.py +340 -0
  75. openadapt_ml/segmentation/schemas.py +622 -0
  76. openadapt_ml/segmentation/segment_extractor.py +634 -0
  77. openadapt_ml/training/azure_ops_viewer.py +1097 -0
  78. openadapt_ml/training/benchmark_viewer.py +52 -41
  79. openadapt_ml/training/shared_ui.py +7 -7
  80. openadapt_ml/training/stub_provider.py +57 -35
  81. openadapt_ml/training/trainer.py +143 -86
  82. openadapt_ml/training/trl_trainer.py +70 -21
  83. openadapt_ml/training/viewer.py +323 -108
  84. openadapt_ml/training/viewer_components.py +180 -0
  85. {openadapt_ml-0.2.0.dist-info → openadapt_ml-0.2.1.dist-info}/METADATA +215 -14
  86. openadapt_ml-0.2.1.dist-info/RECORD +116 -0
  87. openadapt_ml/benchmarks/base.py +0 -366
  88. openadapt_ml/benchmarks/data_collection.py +0 -432
  89. openadapt_ml/benchmarks/live_tracker.py +0 -180
  90. openadapt_ml/benchmarks/runner.py +0 -418
  91. openadapt_ml/benchmarks/waa.py +0 -761
  92. openadapt_ml/benchmarks/waa_live.py +0 -619
  93. openadapt_ml-0.2.0.dist-info/RECORD +0 -86
  94. {openadapt_ml-0.2.0.dist-info → openadapt_ml-0.2.1.dist-info}/WHEEL +0 -0
  95. {openadapt_ml-0.2.0.dist-info → openadapt_ml-0.2.1.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,299 @@
1
+ """Base provider abstraction for API-backed VLMs.
2
+
3
+ This module defines the interface that all API providers must implement.
4
+ Providers handle client creation, message sending, and image encoding
5
+ in a provider-specific way.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import base64
11
+ import io
12
+ import logging
13
+ import os
14
+ from abc import ABC, abstractmethod
15
+ from typing import TYPE_CHECKING, Any
16
+
17
+ if TYPE_CHECKING:
18
+ from PIL import Image
19
+
20
+ logger = logging.getLogger(__name__)
21
+
22
+
23
+ class ProviderError(Exception):
24
+ """Base exception for provider errors."""
25
+
26
+ pass
27
+
28
+
29
+ class AuthenticationError(ProviderError):
30
+ """Raised when API authentication fails."""
31
+
32
+ pass
33
+
34
+
35
+ class RateLimitError(ProviderError):
36
+ """Raised when API rate limit is exceeded."""
37
+
38
+ pass
39
+
40
+
41
+ class ModelNotFoundError(ProviderError):
42
+ """Raised when the specified model is not available."""
43
+
44
+ pass
45
+
46
+
47
+ class BaseAPIProvider(ABC):
48
+ """Abstract base class for API providers (Anthropic, OpenAI, Google).
49
+
50
+ Each provider implements client creation, message sending, and image encoding
51
+ in a provider-specific way.
52
+
53
+ Attributes:
54
+ name: Provider identifier ('anthropic', 'openai', 'google').
55
+
56
+ Example:
57
+ >>> provider = get_provider("anthropic")
58
+ >>> client = provider.create_client(api_key)
59
+ >>> response = provider.send_message(
60
+ ... client,
61
+ ... model="claude-opus-4-5-20251101",
62
+ ... system="You are a GUI agent.",
63
+ ... content=[
64
+ ... {"type": "text", "text": "Click the submit button"},
65
+ ... provider.encode_image(screenshot),
66
+ ... ],
67
+ ... )
68
+ """
69
+
70
+ @property
71
+ @abstractmethod
72
+ def name(self) -> str:
73
+ """Provider name (e.g., 'anthropic', 'openai', 'google')."""
74
+ ...
75
+
76
+ @property
77
+ def env_key_name(self) -> str:
78
+ """Environment variable name for API key.
79
+
80
+ Returns:
81
+ Environment variable name (e.g., 'ANTHROPIC_API_KEY').
82
+ """
83
+ return f"{self.name.upper()}_API_KEY"
84
+
85
+ def get_api_key(self, api_key: str | None = None) -> str:
86
+ """Get API key from parameter, settings, or environment.
87
+
88
+ Args:
89
+ api_key: Optional explicit API key.
90
+
91
+ Returns:
92
+ API key string.
93
+
94
+ Raises:
95
+ AuthenticationError: If no API key is available.
96
+ """
97
+ if api_key:
98
+ return api_key
99
+
100
+ # Try settings
101
+ from openadapt_ml.config import settings
102
+
103
+ settings_key = getattr(settings, f"{self.name}_api_key", None)
104
+ if settings_key:
105
+ return settings_key
106
+
107
+ # Try environment
108
+ env_key = os.getenv(self.env_key_name)
109
+ if env_key:
110
+ return env_key
111
+
112
+ raise AuthenticationError(
113
+ f"{self.env_key_name} is required but not found. "
114
+ f"Set it in .env file, environment variable, or pass api_key parameter."
115
+ )
116
+
117
+ @abstractmethod
118
+ def create_client(self, api_key: str) -> Any:
119
+ """Create and return an API client.
120
+
121
+ Args:
122
+ api_key: The API key for authentication.
123
+
124
+ Returns:
125
+ Provider-specific client object.
126
+
127
+ Raises:
128
+ ImportError: If required package is not installed.
129
+ AuthenticationError: If API key is invalid.
130
+ """
131
+ ...
132
+
133
+ @abstractmethod
134
+ def send_message(
135
+ self,
136
+ client: Any,
137
+ model: str,
138
+ system: str,
139
+ content: list[dict[str, Any]],
140
+ max_tokens: int = 1024,
141
+ temperature: float = 0.1,
142
+ ) -> str:
143
+ """Send a message to the API and return the response text.
144
+
145
+ Args:
146
+ client: The API client from create_client().
147
+ model: Model identifier (e.g., 'claude-opus-4-5-20251101').
148
+ system: System prompt.
149
+ content: List of content items (text and images).
150
+ max_tokens: Maximum tokens in response.
151
+ temperature: Sampling temperature.
152
+
153
+ Returns:
154
+ The model's text response.
155
+
156
+ Raises:
157
+ RateLimitError: If rate limit is exceeded.
158
+ ModelNotFoundError: If model is not available.
159
+ ProviderError: For other API errors.
160
+ """
161
+ ...
162
+
163
+ @abstractmethod
164
+ def encode_image(self, image: "Image") -> dict[str, Any]:
165
+ """Encode a PIL Image for the API.
166
+
167
+ Args:
168
+ image: PIL Image to encode.
169
+
170
+ Returns:
171
+ Provider-specific image representation for inclusion in content.
172
+ """
173
+ ...
174
+
175
+ def image_to_base64(self, image: "Image", format: str = "PNG") -> str:
176
+ """Convert PIL Image to base64 string.
177
+
178
+ Args:
179
+ image: PIL Image to convert.
180
+ format: Image format (PNG, JPEG, etc.).
181
+
182
+ Returns:
183
+ Base64-encoded string.
184
+ """
185
+ buffer = io.BytesIO()
186
+ image.save(buffer, format=format)
187
+ return base64.b64encode(buffer.getvalue()).decode("utf-8")
188
+
189
+ def get_media_type(self, format: str = "PNG") -> str:
190
+ """Get MIME type for image format.
191
+
192
+ Args:
193
+ format: Image format string.
194
+
195
+ Returns:
196
+ MIME type string.
197
+ """
198
+ format_map = {
199
+ "PNG": "image/png",
200
+ "JPEG": "image/jpeg",
201
+ "JPG": "image/jpeg",
202
+ "GIF": "image/gif",
203
+ "WEBP": "image/webp",
204
+ }
205
+ return format_map.get(format.upper(), "image/png")
206
+
207
+ def create_text_content(self, text: str) -> dict[str, Any]:
208
+ """Create a text content block.
209
+
210
+ Args:
211
+ text: Text content.
212
+
213
+ Returns:
214
+ Text content block.
215
+ """
216
+ return {"type": "text", "text": text}
217
+
218
+ def build_content(
219
+ self,
220
+ text: str | None = None,
221
+ image: "Image | None" = None,
222
+ additional_content: list[dict[str, Any]] | None = None,
223
+ ) -> list[dict[str, Any]]:
224
+ """Build a content list from text and/or image.
225
+
226
+ Convenience method for building content lists in the correct format.
227
+
228
+ Args:
229
+ text: Optional text content.
230
+ image: Optional PIL Image.
231
+ additional_content: Optional additional content blocks.
232
+
233
+ Returns:
234
+ List of content blocks.
235
+
236
+ Example:
237
+ >>> content = provider.build_content(
238
+ ... text="Click the button",
239
+ ... image=screenshot,
240
+ ... )
241
+ """
242
+ content = []
243
+
244
+ if text:
245
+ content.append(self.create_text_content(text))
246
+
247
+ if image is not None:
248
+ content.append(self.encode_image(image))
249
+
250
+ if additional_content:
251
+ content.extend(additional_content)
252
+
253
+ return content
254
+
255
+ def quick_message(
256
+ self,
257
+ api_key: str,
258
+ model: str,
259
+ prompt: str,
260
+ image: "Image | None" = None,
261
+ system: str = "",
262
+ max_tokens: int = 1024,
263
+ temperature: float = 0.1,
264
+ ) -> str:
265
+ """Send a quick message without managing client lifecycle.
266
+
267
+ Convenience method that creates a client, sends a message, and returns
268
+ the response in one call. Useful for one-off requests.
269
+
270
+ Args:
271
+ api_key: API key for authentication.
272
+ model: Model identifier.
273
+ prompt: User prompt text.
274
+ image: Optional image to include.
275
+ system: Optional system prompt.
276
+ max_tokens: Maximum tokens in response.
277
+ temperature: Sampling temperature.
278
+
279
+ Returns:
280
+ Model response text.
281
+
282
+ Example:
283
+ >>> response = provider.quick_message(
284
+ ... api_key=key,
285
+ ... model="claude-opus-4-5-20251101",
286
+ ... prompt="What's in this image?",
287
+ ... image=screenshot,
288
+ ... )
289
+ """
290
+ client = self.create_client(api_key)
291
+ content = self.build_content(text=prompt, image=image)
292
+ return self.send_message(
293
+ client=client,
294
+ model=model,
295
+ system=system,
296
+ content=content,
297
+ max_tokens=max_tokens,
298
+ temperature=temperature,
299
+ )
@@ -0,0 +1,376 @@
1
+ """Google (Gemini) API provider.
2
+
3
+ Supports Gemini 3 Pro, Gemini 3 Flash, and other Gemini models.
4
+ Implements the BaseAPIProvider interface for the Generative AI API.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import logging
10
+ from typing import TYPE_CHECKING, Any
11
+
12
+ from openadapt_ml.models.providers.base import (
13
+ BaseAPIProvider,
14
+ AuthenticationError,
15
+ ModelNotFoundError,
16
+ ProviderError,
17
+ RateLimitError,
18
+ )
19
+
20
+ if TYPE_CHECKING:
21
+ from PIL import Image
22
+
23
+ logger = logging.getLogger(__name__)
24
+
25
+ # Default models
26
+ DEFAULT_MODEL = "gemini-2.5-flash"
27
+
28
+ # Supported models with their properties
29
+ SUPPORTED_MODELS = {
30
+ "gemini-3-pro": {"context": 2_000_000, "description": "Most capable Gemini"},
31
+ "gemini-3-flash": {"context": 1_000_000, "description": "Fast inference"},
32
+ "gemini-2.5-pro": {"context": 2_000_000, "description": "Previous pro"},
33
+ "gemini-2.5-flash": {"context": 1_000_000, "description": "Fast previous gen"},
34
+ "gemini-2.0-flash": {"context": 1_000_000, "description": "Stable flash"},
35
+ "gemini-1.5-pro": {"context": 2_000_000, "description": "Legacy pro"},
36
+ "gemini-1.5-flash": {"context": 1_000_000, "description": "Legacy flash"},
37
+ }
38
+
39
+
40
+ class GoogleProvider(BaseAPIProvider):
41
+ """Provider for Google's Gemini models.
42
+
43
+ Implements vision support with native PIL Image handling. Unlike Anthropic
44
+ and OpenAI which require base64 encoding, Gemini accepts PIL Images directly.
45
+
46
+ Supported models:
47
+ - gemini-3-pro: Most capable, 2M context window
48
+ - gemini-3-flash: Fast inference, 1M context
49
+ - gemini-2.5-pro/flash: Previous generation
50
+ - gemini-2.0-flash: Stable release
51
+
52
+ Note:
53
+ Gemini supports PIL Images directly without base64 encoding.
54
+ The encode_image method returns the image wrapped in a dict for
55
+ consistency with other providers.
56
+
57
+ Example:
58
+ >>> provider = GoogleProvider()
59
+ >>> client = provider.create_client(api_key)
60
+ >>> response = provider.send_message(
61
+ ... client,
62
+ ... model="gemini-3-pro",
63
+ ... system="You are a GUI agent.",
64
+ ... content=[
65
+ ... {"type": "text", "text": "Click the submit button"},
66
+ ... provider.encode_image(screenshot),
67
+ ... ],
68
+ ... )
69
+
70
+ Attributes:
71
+ name: Returns 'google'.
72
+ """
73
+
74
+ @property
75
+ def name(self) -> str:
76
+ """Provider name."""
77
+ return "google"
78
+
79
+ @property
80
+ def env_key_name(self) -> str:
81
+ """Environment variable name for API key."""
82
+ return "GOOGLE_API_KEY"
83
+
84
+ @property
85
+ def default_model(self) -> str:
86
+ """Default model to use."""
87
+ return DEFAULT_MODEL
88
+
89
+ @property
90
+ def supported_models(self) -> dict[str, dict[str, Any]]:
91
+ """Dictionary of supported models and their properties."""
92
+ return SUPPORTED_MODELS
93
+
94
+ def create_client(self, api_key: str) -> Any:
95
+ """Create Google Generative AI client.
96
+
97
+ Unlike Anthropic/OpenAI, Gemini uses a global configure call.
98
+ We return a dict containing the configured genai module.
99
+
100
+ Args:
101
+ api_key: Google API key.
102
+
103
+ Returns:
104
+ Dict containing api_key and configured genai module.
105
+
106
+ Raises:
107
+ ImportError: If google-generativeai package not installed.
108
+ AuthenticationError: If API key is empty.
109
+ """
110
+ try:
111
+ import google.generativeai as genai
112
+ except ImportError as e:
113
+ raise ImportError(
114
+ "google-generativeai package is required for provider='google'. "
115
+ "Install with: uv add google-generativeai"
116
+ ) from e
117
+
118
+ if not api_key or not api_key.strip():
119
+ raise AuthenticationError(
120
+ "Google API key cannot be empty. "
121
+ "Get a key from https://makersuite.google.com/app/apikey"
122
+ )
123
+
124
+ logger.debug("Configuring Google Generative AI")
125
+ genai.configure(api_key=api_key)
126
+ return {"api_key": api_key, "genai": genai}
127
+
128
+ def send_message(
129
+ self,
130
+ client: Any,
131
+ model: str,
132
+ system: str,
133
+ content: list[dict[str, Any]],
134
+ max_tokens: int = 1024,
135
+ temperature: float = 0.1,
136
+ ) -> str:
137
+ """Send message using Gemini Generate Content API.
138
+
139
+ Args:
140
+ client: Client dict from create_client().
141
+ model: Model ID (e.g., 'gemini-3-pro').
142
+ system: System prompt (prepended to content as text).
143
+ content: List of content blocks.
144
+ max_tokens: Max response tokens.
145
+ temperature: Sampling temperature (0.0-2.0 for Gemini).
146
+
147
+ Returns:
148
+ Model response text.
149
+
150
+ Raises:
151
+ AuthenticationError: If API key is invalid.
152
+ RateLimitError: If rate limit exceeded.
153
+ ModelNotFoundError: If model doesn't exist.
154
+ ProviderError: For other API errors.
155
+ """
156
+ logger.debug(f"Sending message to {model} with {len(content)} content blocks")
157
+
158
+ genai = client["genai"]
159
+ model_instance = genai.GenerativeModel(model)
160
+
161
+ # Build content list for Gemini
162
+ gemini_content = []
163
+
164
+ # Add system prompt as first text if provided
165
+ if system:
166
+ gemini_content.append(f"System: {system}\n\n")
167
+
168
+ # Process content items
169
+ for item in content:
170
+ if item.get("type") == "text":
171
+ gemini_content.append(item.get("text", ""))
172
+ elif item.get("type") == "image":
173
+ # Gemini accepts PIL Images directly
174
+ image = item.get("image")
175
+ if image is not None:
176
+ gemini_content.append(image)
177
+
178
+ try:
179
+ response = model_instance.generate_content(
180
+ gemini_content,
181
+ generation_config=genai.GenerationConfig(
182
+ temperature=temperature,
183
+ max_output_tokens=max_tokens,
184
+ ),
185
+ )
186
+
187
+ result = response.text
188
+ logger.debug(f"Received response: {len(result)} chars")
189
+ return result
190
+
191
+ except Exception as e:
192
+ error_str = str(e).lower()
193
+
194
+ # Map common errors to specific exceptions
195
+ if (
196
+ "api_key" in error_str
197
+ or "authentication" in error_str
198
+ or "invalid" in error_str
199
+ ):
200
+ raise AuthenticationError(f"Google authentication failed: {e}") from e
201
+ elif "quota" in error_str or "rate" in error_str or "429" in error_str:
202
+ raise RateLimitError(f"Google rate limit/quota exceeded: {e}") from e
203
+ elif "not found" in error_str or "does not exist" in error_str:
204
+ raise ModelNotFoundError(f"Model '{model}' not found: {e}") from e
205
+ else:
206
+ raise ProviderError(f"Google API error: {e}") from e
207
+
208
+ def encode_image(self, image: "Image") -> dict[str, Any]:
209
+ """Encode image for Gemini API.
210
+
211
+ Gemini accepts PIL Images directly, no base64 encoding needed.
212
+ We wrap the image in a dict for API consistency.
213
+
214
+ Args:
215
+ image: PIL Image.
216
+
217
+ Returns:
218
+ Image content block containing the PIL Image:
219
+ {
220
+ "type": "image",
221
+ "image": <PIL.Image>
222
+ }
223
+ """
224
+ return {
225
+ "type": "image",
226
+ "image": image,
227
+ }
228
+
229
+ def encode_image_from_bytes(
230
+ self,
231
+ image_bytes: bytes,
232
+ media_type: str = "image/png",
233
+ ) -> dict[str, Any]:
234
+ """Encode raw image bytes for Gemini API.
235
+
236
+ Converts bytes to PIL Image for Gemini's native format.
237
+
238
+ Args:
239
+ image_bytes: Raw image bytes.
240
+ media_type: MIME type (used to verify format).
241
+
242
+ Returns:
243
+ Image content block with PIL Image.
244
+ """
245
+ import io
246
+
247
+ from PIL import Image as PILImage
248
+
249
+ image = PILImage.open(io.BytesIO(image_bytes))
250
+ return self.encode_image(image)
251
+
252
+ def encode_image_from_url(self, url: str) -> dict[str, Any]:
253
+ """Create image content block from URL.
254
+
255
+ Fetches the image and converts to PIL Image.
256
+
257
+ Args:
258
+ url: Image URL to fetch.
259
+
260
+ Returns:
261
+ Image content block with PIL Image.
262
+
263
+ Raises:
264
+ ProviderError: If URL fetch fails.
265
+ """
266
+ import io
267
+ import urllib.request
268
+
269
+ from PIL import Image as PILImage
270
+
271
+ try:
272
+ with urllib.request.urlopen(url) as response:
273
+ image_bytes = response.read()
274
+ image = PILImage.open(io.BytesIO(image_bytes))
275
+ return self.encode_image(image)
276
+ except Exception as e:
277
+ raise ProviderError(f"Failed to fetch image from URL: {e}") from e
278
+
279
+ def encode_image_as_base64(self, image: "Image") -> dict[str, Any]:
280
+ """Encode image as base64 for Gemini API.
281
+
282
+ While Gemini prefers PIL Images, it can also accept base64.
283
+ Use this for cases where you need to serialize the content.
284
+
285
+ Args:
286
+ image: PIL Image.
287
+
288
+ Returns:
289
+ Image content block with base64 data.
290
+ """
291
+ return {
292
+ "type": "image",
293
+ "inline_data": {
294
+ "mime_type": "image/png",
295
+ "data": self.image_to_base64(image, "PNG"),
296
+ },
297
+ }
298
+
299
+ def send_with_grounding(
300
+ self,
301
+ client: Any,
302
+ model: str,
303
+ prompt: str,
304
+ image: "Image",
305
+ max_tokens: int = 1024,
306
+ temperature: float = 0.1,
307
+ ) -> dict[str, Any]:
308
+ """Send message with grounding/bounding box detection.
309
+
310
+ Uses Gemini's native vision capabilities to detect UI elements
311
+ and return bounding boxes. Useful for Set-of-Marks processing.
312
+
313
+ Args:
314
+ client: Client dict from create_client().
315
+ model: Model ID.
316
+ prompt: Detection prompt.
317
+ image: Screenshot to analyze.
318
+ max_tokens: Max response tokens.
319
+ temperature: Sampling temperature.
320
+
321
+ Returns:
322
+ Dict with response text and any detected bounding boxes.
323
+
324
+ Example:
325
+ >>> result = provider.send_with_grounding(
326
+ ... client,
327
+ ... model="gemini-2.5-flash",
328
+ ... prompt="Find the login button",
329
+ ... image=screenshot,
330
+ ... )
331
+ >>> print(result["boxes"]) # List of bounding boxes
332
+ """
333
+ genai = client["genai"]
334
+ model_instance = genai.GenerativeModel(model)
335
+
336
+ grounding_prompt = f"""Analyze this screenshot and {prompt}
337
+
338
+ Return a JSON object with:
339
+ - "elements": array of detected elements with "label", "bbox" [x1,y1,x2,y2], "confidence"
340
+ - "description": brief description of what you found
341
+
342
+ Use pixel coordinates based on image dimensions: {image.width}x{image.height}
343
+
344
+ Return ONLY valid JSON."""
345
+
346
+ try:
347
+ response = model_instance.generate_content(
348
+ [grounding_prompt, image],
349
+ generation_config=genai.GenerationConfig(
350
+ temperature=temperature,
351
+ max_output_tokens=max_tokens,
352
+ ),
353
+ )
354
+
355
+ text = response.text
356
+
357
+ # Try to parse JSON response
358
+ import json
359
+ import re
360
+
361
+ json_match = re.search(r"\{[\s\S]*\}", text)
362
+ if json_match:
363
+ try:
364
+ data = json.loads(json_match.group())
365
+ return {
366
+ "text": text,
367
+ "elements": data.get("elements", []),
368
+ "description": data.get("description", ""),
369
+ }
370
+ except json.JSONDecodeError:
371
+ pass
372
+
373
+ return {"text": text, "elements": [], "description": ""}
374
+
375
+ except Exception as e:
376
+ raise ProviderError(f"Google grounding error: {e}") from e