openadapt-ml 0.1.0__py3-none-any.whl → 0.2.1__py3-none-any.whl
This diff compares publicly released versions of the package as published to a supported public registry. It is provided for informational purposes only and reflects the packages exactly as they appear in that registry.
- openadapt_ml/baselines/__init__.py +121 -0
- openadapt_ml/baselines/adapter.py +185 -0
- openadapt_ml/baselines/cli.py +314 -0
- openadapt_ml/baselines/config.py +448 -0
- openadapt_ml/baselines/parser.py +922 -0
- openadapt_ml/baselines/prompts.py +787 -0
- openadapt_ml/benchmarks/__init__.py +13 -107
- openadapt_ml/benchmarks/agent.py +297 -374
- openadapt_ml/benchmarks/azure.py +62 -24
- openadapt_ml/benchmarks/azure_ops_tracker.py +521 -0
- openadapt_ml/benchmarks/cli.py +1874 -751
- openadapt_ml/benchmarks/trace_export.py +631 -0
- openadapt_ml/benchmarks/viewer.py +1236 -0
- openadapt_ml/benchmarks/vm_monitor.py +1111 -0
- openadapt_ml/benchmarks/waa_deploy/Dockerfile +216 -0
- openadapt_ml/benchmarks/waa_deploy/__init__.py +10 -0
- openadapt_ml/benchmarks/waa_deploy/api_agent.py +540 -0
- openadapt_ml/benchmarks/waa_deploy/start_waa_server.bat +53 -0
- openadapt_ml/cloud/azure_inference.py +3 -5
- openadapt_ml/cloud/lambda_labs.py +722 -307
- openadapt_ml/cloud/local.py +3194 -89
- openadapt_ml/cloud/ssh_tunnel.py +595 -0
- openadapt_ml/datasets/next_action.py +125 -96
- openadapt_ml/evals/grounding.py +32 -9
- openadapt_ml/evals/plot_eval_metrics.py +15 -13
- openadapt_ml/evals/trajectory_matching.py +120 -57
- openadapt_ml/experiments/demo_prompt/__init__.py +19 -0
- openadapt_ml/experiments/demo_prompt/format_demo.py +236 -0
- openadapt_ml/experiments/demo_prompt/results/experiment_20251231_002125.json +83 -0
- openadapt_ml/experiments/demo_prompt/results/experiment_n30_20251231_165958.json +1100 -0
- openadapt_ml/experiments/demo_prompt/results/multistep_20251231_025051.json +182 -0
- openadapt_ml/experiments/demo_prompt/run_experiment.py +541 -0
- openadapt_ml/experiments/representation_shootout/__init__.py +70 -0
- openadapt_ml/experiments/representation_shootout/conditions.py +708 -0
- openadapt_ml/experiments/representation_shootout/config.py +390 -0
- openadapt_ml/experiments/representation_shootout/evaluator.py +659 -0
- openadapt_ml/experiments/representation_shootout/runner.py +687 -0
- openadapt_ml/experiments/waa_demo/__init__.py +10 -0
- openadapt_ml/experiments/waa_demo/demos.py +357 -0
- openadapt_ml/experiments/waa_demo/runner.py +732 -0
- openadapt_ml/experiments/waa_demo/tasks.py +151 -0
- openadapt_ml/export/__init__.py +9 -0
- openadapt_ml/export/__main__.py +6 -0
- openadapt_ml/export/cli.py +89 -0
- openadapt_ml/export/parquet.py +277 -0
- openadapt_ml/grounding/detector.py +18 -14
- openadapt_ml/ingest/__init__.py +11 -10
- openadapt_ml/ingest/capture.py +97 -86
- openadapt_ml/ingest/loader.py +120 -69
- openadapt_ml/ingest/synthetic.py +344 -193
- openadapt_ml/models/api_adapter.py +14 -4
- openadapt_ml/models/base_adapter.py +10 -2
- openadapt_ml/models/providers/__init__.py +288 -0
- openadapt_ml/models/providers/anthropic.py +266 -0
- openadapt_ml/models/providers/base.py +299 -0
- openadapt_ml/models/providers/google.py +376 -0
- openadapt_ml/models/providers/openai.py +342 -0
- openadapt_ml/models/qwen_vl.py +46 -19
- openadapt_ml/perception/__init__.py +35 -0
- openadapt_ml/perception/integration.py +399 -0
- openadapt_ml/retrieval/README.md +226 -0
- openadapt_ml/retrieval/USAGE.md +391 -0
- openadapt_ml/retrieval/__init__.py +91 -0
- openadapt_ml/retrieval/demo_retriever.py +843 -0
- openadapt_ml/retrieval/embeddings.py +630 -0
- openadapt_ml/retrieval/index.py +194 -0
- openadapt_ml/retrieval/retriever.py +162 -0
- openadapt_ml/runtime/__init__.py +50 -0
- openadapt_ml/runtime/policy.py +27 -14
- openadapt_ml/runtime/safety_gate.py +471 -0
- openadapt_ml/schema/__init__.py +113 -0
- openadapt_ml/schema/converters.py +588 -0
- openadapt_ml/schema/episode.py +470 -0
- openadapt_ml/scripts/capture_screenshots.py +530 -0
- openadapt_ml/scripts/compare.py +102 -61
- openadapt_ml/scripts/demo_policy.py +4 -1
- openadapt_ml/scripts/eval_policy.py +19 -14
- openadapt_ml/scripts/make_gif.py +1 -1
- openadapt_ml/scripts/prepare_synthetic.py +16 -17
- openadapt_ml/scripts/train.py +98 -75
- openadapt_ml/segmentation/README.md +920 -0
- openadapt_ml/segmentation/__init__.py +97 -0
- openadapt_ml/segmentation/adapters/__init__.py +5 -0
- openadapt_ml/segmentation/adapters/capture_adapter.py +420 -0
- openadapt_ml/segmentation/annotator.py +610 -0
- openadapt_ml/segmentation/cache.py +290 -0
- openadapt_ml/segmentation/cli.py +674 -0
- openadapt_ml/segmentation/deduplicator.py +656 -0
- openadapt_ml/segmentation/frame_describer.py +788 -0
- openadapt_ml/segmentation/pipeline.py +340 -0
- openadapt_ml/segmentation/schemas.py +622 -0
- openadapt_ml/segmentation/segment_extractor.py +634 -0
- openadapt_ml/training/azure_ops_viewer.py +1097 -0
- openadapt_ml/training/benchmark_viewer.py +3255 -19
- openadapt_ml/training/shared_ui.py +7 -7
- openadapt_ml/training/stub_provider.py +57 -35
- openadapt_ml/training/trainer.py +255 -441
- openadapt_ml/training/trl_trainer.py +403 -0
- openadapt_ml/training/viewer.py +323 -108
- openadapt_ml/training/viewer_components.py +180 -0
- {openadapt_ml-0.1.0.dist-info → openadapt_ml-0.2.1.dist-info}/METADATA +312 -69
- openadapt_ml-0.2.1.dist-info/RECORD +116 -0
- openadapt_ml/benchmarks/base.py +0 -366
- openadapt_ml/benchmarks/data_collection.py +0 -432
- openadapt_ml/benchmarks/runner.py +0 -381
- openadapt_ml/benchmarks/waa.py +0 -704
- openadapt_ml/schemas/__init__.py +0 -53
- openadapt_ml/schemas/sessions.py +0 -122
- openadapt_ml/schemas/validation.py +0 -252
- openadapt_ml-0.1.0.dist-info/RECORD +0 -55
- {openadapt_ml-0.1.0.dist-info → openadapt_ml-0.2.1.dist-info}/WHEEL +0 -0
- {openadapt_ml-0.1.0.dist-info → openadapt_ml-0.2.1.dist-info}/licenses/LICENSE +0 -0
openadapt_ml/models/providers/base.py
@@ -0,0 +1,299 @@
+"""Base provider abstraction for API-backed VLMs.
+
+This module defines the interface that all API providers must implement.
+Providers handle client creation, message sending, and image encoding
+in a provider-specific way.
+"""
+
+from __future__ import annotations
+
+import base64
+import io
+import logging
+import os
+from abc import ABC, abstractmethod
+from typing import TYPE_CHECKING, Any
+
+if TYPE_CHECKING:
+    from PIL import Image
+
+logger = logging.getLogger(__name__)
+
+
+class ProviderError(Exception):
+    """Base exception for provider errors."""
+
+    pass
+
+
+class AuthenticationError(ProviderError):
+    """Raised when API authentication fails."""
+
+    pass
+
+
+class RateLimitError(ProviderError):
+    """Raised when API rate limit is exceeded."""
+
+    pass
+
+
+class ModelNotFoundError(ProviderError):
+    """Raised when the specified model is not available."""
+
+    pass
+
+
+class BaseAPIProvider(ABC):
+    """Abstract base class for API providers (Anthropic, OpenAI, Google).
+
+    Each provider implements client creation, message sending, and image encoding
+    in a provider-specific way.
+
+    Attributes:
+        name: Provider identifier ('anthropic', 'openai', 'google').
+
+    Example:
+        >>> provider = get_provider("anthropic")
+        >>> client = provider.create_client(api_key)
+        >>> response = provider.send_message(
+        ...     client,
+        ...     model="claude-opus-4-5-20251101",
+        ...     system="You are a GUI agent.",
+        ...     content=[
+        ...         {"type": "text", "text": "Click the submit button"},
+        ...         provider.encode_image(screenshot),
+        ...     ],
+        ... )
+    """
+
+    @property
+    @abstractmethod
+    def name(self) -> str:
+        """Provider name (e.g., 'anthropic', 'openai', 'google')."""
+        ...
+
+    @property
+    def env_key_name(self) -> str:
+        """Environment variable name for API key.
+
+        Returns:
+            Environment variable name (e.g., 'ANTHROPIC_API_KEY').
+        """
+        return f"{self.name.upper()}_API_KEY"
+
+    def get_api_key(self, api_key: str | None = None) -> str:
+        """Get API key from parameter, settings, or environment.
+
+        Args:
+            api_key: Optional explicit API key.
+
+        Returns:
+            API key string.
+
+        Raises:
+            AuthenticationError: If no API key is available.
+        """
+        if api_key:
+            return api_key
+
+        # Try settings
+        from openadapt_ml.config import settings
+
+        settings_key = getattr(settings, f"{self.name}_api_key", None)
+        if settings_key:
+            return settings_key
+
+        # Try environment
+        env_key = os.getenv(self.env_key_name)
+        if env_key:
+            return env_key
+
+        raise AuthenticationError(
+            f"{self.env_key_name} is required but not found. "
+            f"Set it in .env file, environment variable, or pass api_key parameter."
+        )
+
+    @abstractmethod
+    def create_client(self, api_key: str) -> Any:
+        """Create and return an API client.
+
+        Args:
+            api_key: The API key for authentication.
+
+        Returns:
+            Provider-specific client object.
+
+        Raises:
+            ImportError: If required package is not installed.
+            AuthenticationError: If API key is invalid.
+        """
+        ...
+
+    @abstractmethod
+    def send_message(
+        self,
+        client: Any,
+        model: str,
+        system: str,
+        content: list[dict[str, Any]],
+        max_tokens: int = 1024,
+        temperature: float = 0.1,
+    ) -> str:
+        """Send a message to the API and return the response text.
+
+        Args:
+            client: The API client from create_client().
+            model: Model identifier (e.g., 'claude-opus-4-5-20251101').
+            system: System prompt.
+            content: List of content items (text and images).
+            max_tokens: Maximum tokens in response.
+            temperature: Sampling temperature.
+
+        Returns:
+            The model's text response.
+
+        Raises:
+            RateLimitError: If rate limit is exceeded.
+            ModelNotFoundError: If model is not available.
+            ProviderError: For other API errors.
+        """
+        ...
+
+    @abstractmethod
+    def encode_image(self, image: "Image") -> dict[str, Any]:
+        """Encode a PIL Image for the API.
+
+        Args:
+            image: PIL Image to encode.
+
+        Returns:
+            Provider-specific image representation for inclusion in content.
+        """
+        ...
+
+    def image_to_base64(self, image: "Image", format: str = "PNG") -> str:
+        """Convert PIL Image to base64 string.
+
+        Args:
+            image: PIL Image to convert.
+            format: Image format (PNG, JPEG, etc.).
+
+        Returns:
+            Base64-encoded string.
+        """
+        buffer = io.BytesIO()
+        image.save(buffer, format=format)
+        return base64.b64encode(buffer.getvalue()).decode("utf-8")
+
+    def get_media_type(self, format: str = "PNG") -> str:
+        """Get MIME type for image format.
+
+        Args:
+            format: Image format string.
+
+        Returns:
+            MIME type string.
+        """
+        format_map = {
+            "PNG": "image/png",
+            "JPEG": "image/jpeg",
+            "JPG": "image/jpeg",
+            "GIF": "image/gif",
+            "WEBP": "image/webp",
+        }
+        return format_map.get(format.upper(), "image/png")
+
+    def create_text_content(self, text: str) -> dict[str, Any]:
+        """Create a text content block.
+
+        Args:
+            text: Text content.
+
+        Returns:
+            Text content block.
+        """
+        return {"type": "text", "text": text}
+
+    def build_content(
+        self,
+        text: str | None = None,
+        image: "Image | None" = None,
+        additional_content: list[dict[str, Any]] | None = None,
+    ) -> list[dict[str, Any]]:
+        """Build a content list from text and/or image.
+
+        Convenience method for building content lists in the correct format.
+
+        Args:
+            text: Optional text content.
+            image: Optional PIL Image.
+            additional_content: Optional additional content blocks.
+
+        Returns:
+            List of content blocks.
+
+        Example:
+            >>> content = provider.build_content(
+            ...     text="Click the button",
+            ...     image=screenshot,
+            ... )
+        """
+        content = []
+
+        if text:
+            content.append(self.create_text_content(text))
+
+        if image is not None:
+            content.append(self.encode_image(image))
+
+        if additional_content:
+            content.extend(additional_content)
+
+        return content
+
+    def quick_message(
+        self,
+        api_key: str,
+        model: str,
+        prompt: str,
+        image: "Image | None" = None,
+        system: str = "",
+        max_tokens: int = 1024,
+        temperature: float = 0.1,
+    ) -> str:
+        """Send a quick message without managing client lifecycle.
+
+        Convenience method that creates a client, sends a message, and returns
+        the response in one call. Useful for one-off requests.
+
+        Args:
+            api_key: API key for authentication.
+            model: Model identifier.
+            prompt: User prompt text.
+            image: Optional image to include.
+            system: Optional system prompt.
+            max_tokens: Maximum tokens in response.
+            temperature: Sampling temperature.
+
+        Returns:
+            Model response text.
+
+        Example:
+            >>> response = provider.quick_message(
+            ...     api_key=key,
+            ...     model="claude-opus-4-5-20251101",
+            ...     prompt="What's in this image?",
+            ...     image=screenshot,
+            ... )
+        """
+        client = self.create_client(api_key)
+        content = self.build_content(text=prompt, image=image)
+        return self.send_message(
+            client=client,
+            model=model,
+            system=system,
+            content=content,
+            max_tokens=max_tokens,
+            temperature=temperature,
+        )
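
For context on the contract this file defines: a concrete provider only has to supply the four abstract members (name, create_client, send_message, encode_image); key lookup, content building, and quick_message all come from the base class. Below is a minimal sketch, not part of the package — the EchoProvider class is hypothetical and calls no real API — showing the minimum a subclass must implement:

    from typing import Any

    from PIL import Image

    from openadapt_ml.models.providers.base import BaseAPIProvider


    class EchoProvider(BaseAPIProvider):
        """Hypothetical offline provider: echoes prompts instead of calling an API."""

        @property
        def name(self) -> str:
            return "echo"  # env_key_name derives to ECHO_API_KEY

        def create_client(self, api_key: str) -> Any:
            # No real SDK client; any stand-in object satisfies the interface.
            return {"api_key": api_key}

        def send_message(
            self,
            client: Any,
            model: str,
            system: str,
            content: list[dict[str, Any]],
            max_tokens: int = 1024,
            temperature: float = 0.1,
        ) -> str:
            # A real provider would call its API here; we just echo the text blocks.
            texts = [c["text"] for c in content if c.get("type") == "text"]
            return f"[{model}] " + " ".join(texts)

        def encode_image(self, image: Image.Image) -> dict[str, Any]:
            # The base-class helpers cover the generic base64 path.
            return {
                "type": "image",
                "media_type": self.get_media_type("PNG"),
                "data": self.image_to_base64(image, "PNG"),
            }


    print(EchoProvider().quick_message(api_key="x", model="noop", prompt="hi"))
    # -> "[noop] hi"
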
openadapt_ml/models/providers/google.py
@@ -0,0 +1,376 @@
+"""Google (Gemini) API provider.
+
+Supports Gemini 3 Pro, Gemini 3 Flash, and other Gemini models.
+Implements the BaseAPIProvider interface for the Generative AI API.
+"""
+
+from __future__ import annotations
+
+import logging
+from typing import TYPE_CHECKING, Any
+
+from openadapt_ml.models.providers.base import (
+    BaseAPIProvider,
+    AuthenticationError,
+    ModelNotFoundError,
+    ProviderError,
+    RateLimitError,
+)
+
+if TYPE_CHECKING:
+    from PIL import Image
+
+logger = logging.getLogger(__name__)
+
+# Default models
+DEFAULT_MODEL = "gemini-2.5-flash"
+
+# Supported models with their properties
+SUPPORTED_MODELS = {
+    "gemini-3-pro": {"context": 2_000_000, "description": "Most capable Gemini"},
+    "gemini-3-flash": {"context": 1_000_000, "description": "Fast inference"},
+    "gemini-2.5-pro": {"context": 2_000_000, "description": "Previous pro"},
+    "gemini-2.5-flash": {"context": 1_000_000, "description": "Fast previous gen"},
+    "gemini-2.0-flash": {"context": 1_000_000, "description": "Stable flash"},
+    "gemini-1.5-pro": {"context": 2_000_000, "description": "Legacy pro"},
+    "gemini-1.5-flash": {"context": 1_000_000, "description": "Legacy flash"},
+}
+
+
+class GoogleProvider(BaseAPIProvider):
+    """Provider for Google's Gemini models.
+
+    Implements vision support with native PIL Image handling. Unlike Anthropic
+    and OpenAI which require base64 encoding, Gemini accepts PIL Images directly.
+
+    Supported models:
+        - gemini-3-pro: Most capable, 2M context window
+        - gemini-3-flash: Fast inference, 1M context
+        - gemini-2.5-pro/flash: Previous generation
+        - gemini-2.0-flash: Stable release
+
+    Note:
+        Gemini supports PIL Images directly without base64 encoding.
+        The encode_image method returns the image wrapped in a dict for
+        consistency with other providers.
+
+    Example:
+        >>> provider = GoogleProvider()
+        >>> client = provider.create_client(api_key)
+        >>> response = provider.send_message(
+        ...     client,
+        ...     model="gemini-3-pro",
+        ...     system="You are a GUI agent.",
+        ...     content=[
+        ...         {"type": "text", "text": "Click the submit button"},
+        ...         provider.encode_image(screenshot),
+        ...     ],
+        ... )
+
+    Attributes:
+        name: Returns 'google'.
+    """
+
+    @property
+    def name(self) -> str:
+        """Provider name."""
+        return "google"
+
+    @property
+    def env_key_name(self) -> str:
+        """Environment variable name for API key."""
+        return "GOOGLE_API_KEY"
+
+    @property
+    def default_model(self) -> str:
+        """Default model to use."""
+        return DEFAULT_MODEL
+
+    @property
+    def supported_models(self) -> dict[str, dict[str, Any]]:
+        """Dictionary of supported models and their properties."""
+        return SUPPORTED_MODELS
+
+    def create_client(self, api_key: str) -> Any:
+        """Create Google Generative AI client.
+
+        Unlike Anthropic/OpenAI, Gemini uses a global configure call.
+        We return a dict containing the configured genai module.
+
+        Args:
+            api_key: Google API key.
+
+        Returns:
+            Dict containing api_key and configured genai module.
+
+        Raises:
+            ImportError: If google-generativeai package not installed.
+            AuthenticationError: If API key is empty.
+        """
+        try:
+            import google.generativeai as genai
+        except ImportError as e:
+            raise ImportError(
+                "google-generativeai package is required for provider='google'. "
+                "Install with: uv add google-generativeai"
+            ) from e
+
+        if not api_key or not api_key.strip():
+            raise AuthenticationError(
+                "Google API key cannot be empty. "
+                "Get a key from https://makersuite.google.com/app/apikey"
+            )
+
+        logger.debug("Configuring Google Generative AI")
+        genai.configure(api_key=api_key)
+        return {"api_key": api_key, "genai": genai}
+
+    def send_message(
+        self,
+        client: Any,
+        model: str,
+        system: str,
+        content: list[dict[str, Any]],
+        max_tokens: int = 1024,
+        temperature: float = 0.1,
+    ) -> str:
+        """Send message using Gemini Generate Content API.
+
+        Args:
+            client: Client dict from create_client().
+            model: Model ID (e.g., 'gemini-3-pro').
+            system: System prompt (prepended to content as text).
+            content: List of content blocks.
+            max_tokens: Max response tokens.
+            temperature: Sampling temperature (0.0-2.0 for Gemini).
+
+        Returns:
+            Model response text.
+
+        Raises:
+            AuthenticationError: If API key is invalid.
+            RateLimitError: If rate limit exceeded.
+            ModelNotFoundError: If model doesn't exist.
+            ProviderError: For other API errors.
+        """
+        logger.debug(f"Sending message to {model} with {len(content)} content blocks")
+
+        genai = client["genai"]
+        model_instance = genai.GenerativeModel(model)
+
+        # Build content list for Gemini
+        gemini_content = []
+
+        # Add system prompt as first text if provided
+        if system:
+            gemini_content.append(f"System: {system}\n\n")
+
+        # Process content items
+        for item in content:
+            if item.get("type") == "text":
+                gemini_content.append(item.get("text", ""))
+            elif item.get("type") == "image":
+                # Gemini accepts PIL Images directly
+                image = item.get("image")
+                if image is not None:
+                    gemini_content.append(image)
+
+        try:
+            response = model_instance.generate_content(
+                gemini_content,
+                generation_config=genai.GenerationConfig(
+                    temperature=temperature,
+                    max_output_tokens=max_tokens,
+                ),
+            )
+
+            result = response.text
+            logger.debug(f"Received response: {len(result)} chars")
+            return result
+
+        except Exception as e:
+            error_str = str(e).lower()
+
+            # Map common errors to specific exceptions
+            if (
+                "api_key" in error_str
+                or "authentication" in error_str
+                or "invalid" in error_str
+            ):
+                raise AuthenticationError(f"Google authentication failed: {e}") from e
+            elif "quota" in error_str or "rate" in error_str or "429" in error_str:
+                raise RateLimitError(f"Google rate limit/quota exceeded: {e}") from e
+            elif "not found" in error_str or "does not exist" in error_str:
+                raise ModelNotFoundError(f"Model '{model}' not found: {e}") from e
+            else:
+                raise ProviderError(f"Google API error: {e}") from e
+
+    def encode_image(self, image: "Image") -> dict[str, Any]:
+        """Encode image for Gemini API.
+
+        Gemini accepts PIL Images directly, no base64 encoding needed.
+        We wrap the image in a dict for API consistency.
+
+        Args:
+            image: PIL Image.
+
+        Returns:
+            Image content block containing the PIL Image:
+            {
+                "type": "image",
+                "image": <PIL.Image>
+            }
+        """
+        return {
+            "type": "image",
+            "image": image,
+        }
+
+    def encode_image_from_bytes(
+        self,
+        image_bytes: bytes,
+        media_type: str = "image/png",
+    ) -> dict[str, Any]:
+        """Encode raw image bytes for Gemini API.
+
+        Converts bytes to PIL Image for Gemini's native format.
+
+        Args:
+            image_bytes: Raw image bytes.
+            media_type: MIME type (used to verify format).
+
+        Returns:
+            Image content block with PIL Image.
+        """
+        import io
+
+        from PIL import Image as PILImage
+
+        image = PILImage.open(io.BytesIO(image_bytes))
+        return self.encode_image(image)
+
+    def encode_image_from_url(self, url: str) -> dict[str, Any]:
+        """Create image content block from URL.
+
+        Fetches the image and converts to PIL Image.
+
+        Args:
+            url: Image URL to fetch.
+
+        Returns:
+            Image content block with PIL Image.
+
+        Raises:
+            ProviderError: If URL fetch fails.
+        """
+        import io
+        import urllib.request
+
+        from PIL import Image as PILImage
+
+        try:
+            with urllib.request.urlopen(url) as response:
+                image_bytes = response.read()
+            image = PILImage.open(io.BytesIO(image_bytes))
+            return self.encode_image(image)
+        except Exception as e:
+            raise ProviderError(f"Failed to fetch image from URL: {e}") from e
+
+    def encode_image_as_base64(self, image: "Image") -> dict[str, Any]:
+        """Encode image as base64 for Gemini API.
+
+        While Gemini prefers PIL Images, it can also accept base64.
+        Use this for cases where you need to serialize the content.
+
+        Args:
+            image: PIL Image.
+
+        Returns:
+            Image content block with base64 data.
+        """
+        return {
+            "type": "image",
+            "inline_data": {
+                "mime_type": "image/png",
+                "data": self.image_to_base64(image, "PNG"),
+            },
+        }
+
+    def send_with_grounding(
+        self,
+        client: Any,
+        model: str,
+        prompt: str,
+        image: "Image",
+        max_tokens: int = 1024,
+        temperature: float = 0.1,
+    ) -> dict[str, Any]:
+        """Send message with grounding/bounding box detection.
+
+        Uses Gemini's native vision capabilities to detect UI elements
+        and return bounding boxes. Useful for Set-of-Marks processing.
+
+        Args:
+            client: Client dict from create_client().
+            model: Model ID.
+            prompt: Detection prompt.
+            image: Screenshot to analyze.
+            max_tokens: Max response tokens.
+            temperature: Sampling temperature.
+
+        Returns:
+            Dict with response text and any detected bounding boxes.
+
+        Example:
+            >>> result = provider.send_with_grounding(
+            ...     client,
+            ...     model="gemini-2.5-flash",
+            ...     prompt="Find the login button",
+            ...     image=screenshot,
+            ... )
+            >>> print(result["elements"])  # List of detected elements
+        """
+        genai = client["genai"]
+        model_instance = genai.GenerativeModel(model)
+
+        grounding_prompt = f"""Analyze this screenshot and {prompt}
+
+Return a JSON object with:
+- "elements": array of detected elements with "label", "bbox" [x1,y1,x2,y2], "confidence"
+- "description": brief description of what you found
+
+Use pixel coordinates based on image dimensions: {image.width}x{image.height}
+
+Return ONLY valid JSON."""
+
+        try:
+            response = model_instance.generate_content(
+                [grounding_prompt, image],
+                generation_config=genai.GenerationConfig(
+                    temperature=temperature,
+                    max_output_tokens=max_tokens,
+                ),
+            )
+
+            text = response.text
+
+            # Try to parse JSON response
+            import json
+            import re
+
+            json_match = re.search(r"\{[\s\S]*\}", text)
+            if json_match:
+                try:
+                    data = json.loads(json_match.group())
+                    return {
+                        "text": text,
+                        "elements": data.get("elements", []),
+                        "description": data.get("description", ""),
+                    }
+                except json.JSONDecodeError:
+                    pass
+
+            return {"text": text, "elements": [], "description": ""}
+
+        except Exception as e:
+            raise ProviderError(f"Google grounding error: {e}") from e
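
For orientation, a short usage sketch of GoogleProvider, using only the methods shown above. Assumptions: GOOGLE_API_KEY is set, Pillow and google-generativeai are installed, and the screenshot path and prompts are illustrative:

    from PIL import Image

    from openadapt_ml.models.providers.google import GoogleProvider

    provider = GoogleProvider()
    api_key = provider.get_api_key()  # explicit arg -> settings -> GOOGLE_API_KEY
    client = provider.create_client(api_key)

    screenshot = Image.open("screenshot.png")  # hypothetical path

    # One-shot vision call; the system prompt is folded into the content list.
    answer = provider.quick_message(
        api_key=api_key,
        model=provider.default_model,  # "gemini-2.5-flash"
        prompt="Describe the visible UI controls.",
        image=screenshot,
    )
    print(answer)

    # Grounding call: returns {"text", "elements", "description"}; "elements"
    # is empty when the model's reply cannot be parsed as JSON.
    result = provider.send_with_grounding(
        client,
        model=provider.default_model,
        prompt="find the login button",
        image=screenshot,
    )
    for el in result["elements"]:
        print(el.get("label"), el.get("bbox"), el.get("confidence"))
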