openadapt-ml 0.1.0__py3-none-any.whl → 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- openadapt_ml/baselines/__init__.py +121 -0
- openadapt_ml/baselines/adapter.py +185 -0
- openadapt_ml/baselines/cli.py +314 -0
- openadapt_ml/baselines/config.py +448 -0
- openadapt_ml/baselines/parser.py +922 -0
- openadapt_ml/baselines/prompts.py +787 -0
- openadapt_ml/benchmarks/__init__.py +13 -107
- openadapt_ml/benchmarks/agent.py +297 -374
- openadapt_ml/benchmarks/azure.py +62 -24
- openadapt_ml/benchmarks/azure_ops_tracker.py +521 -0
- openadapt_ml/benchmarks/cli.py +1874 -751
- openadapt_ml/benchmarks/trace_export.py +631 -0
- openadapt_ml/benchmarks/viewer.py +1236 -0
- openadapt_ml/benchmarks/vm_monitor.py +1111 -0
- openadapt_ml/benchmarks/waa_deploy/Dockerfile +216 -0
- openadapt_ml/benchmarks/waa_deploy/__init__.py +10 -0
- openadapt_ml/benchmarks/waa_deploy/api_agent.py +540 -0
- openadapt_ml/benchmarks/waa_deploy/start_waa_server.bat +53 -0
- openadapt_ml/cloud/azure_inference.py +3 -5
- openadapt_ml/cloud/lambda_labs.py +722 -307
- openadapt_ml/cloud/local.py +3194 -89
- openadapt_ml/cloud/ssh_tunnel.py +595 -0
- openadapt_ml/datasets/next_action.py +125 -96
- openadapt_ml/evals/grounding.py +32 -9
- openadapt_ml/evals/plot_eval_metrics.py +15 -13
- openadapt_ml/evals/trajectory_matching.py +120 -57
- openadapt_ml/experiments/demo_prompt/__init__.py +19 -0
- openadapt_ml/experiments/demo_prompt/format_demo.py +236 -0
- openadapt_ml/experiments/demo_prompt/results/experiment_20251231_002125.json +83 -0
- openadapt_ml/experiments/demo_prompt/results/experiment_n30_20251231_165958.json +1100 -0
- openadapt_ml/experiments/demo_prompt/results/multistep_20251231_025051.json +182 -0
- openadapt_ml/experiments/demo_prompt/run_experiment.py +541 -0
- openadapt_ml/experiments/representation_shootout/__init__.py +70 -0
- openadapt_ml/experiments/representation_shootout/conditions.py +708 -0
- openadapt_ml/experiments/representation_shootout/config.py +390 -0
- openadapt_ml/experiments/representation_shootout/evaluator.py +659 -0
- openadapt_ml/experiments/representation_shootout/runner.py +687 -0
- openadapt_ml/experiments/waa_demo/__init__.py +10 -0
- openadapt_ml/experiments/waa_demo/demos.py +357 -0
- openadapt_ml/experiments/waa_demo/runner.py +732 -0
- openadapt_ml/experiments/waa_demo/tasks.py +151 -0
- openadapt_ml/export/__init__.py +9 -0
- openadapt_ml/export/__main__.py +6 -0
- openadapt_ml/export/cli.py +89 -0
- openadapt_ml/export/parquet.py +277 -0
- openadapt_ml/grounding/detector.py +18 -14
- openadapt_ml/ingest/__init__.py +11 -10
- openadapt_ml/ingest/capture.py +97 -86
- openadapt_ml/ingest/loader.py +120 -69
- openadapt_ml/ingest/synthetic.py +344 -193
- openadapt_ml/models/api_adapter.py +14 -4
- openadapt_ml/models/base_adapter.py +10 -2
- openadapt_ml/models/providers/__init__.py +288 -0
- openadapt_ml/models/providers/anthropic.py +266 -0
- openadapt_ml/models/providers/base.py +299 -0
- openadapt_ml/models/providers/google.py +376 -0
- openadapt_ml/models/providers/openai.py +342 -0
- openadapt_ml/models/qwen_vl.py +46 -19
- openadapt_ml/perception/__init__.py +35 -0
- openadapt_ml/perception/integration.py +399 -0
- openadapt_ml/retrieval/README.md +226 -0
- openadapt_ml/retrieval/USAGE.md +391 -0
- openadapt_ml/retrieval/__init__.py +91 -0
- openadapt_ml/retrieval/demo_retriever.py +843 -0
- openadapt_ml/retrieval/embeddings.py +630 -0
- openadapt_ml/retrieval/index.py +194 -0
- openadapt_ml/retrieval/retriever.py +162 -0
- openadapt_ml/runtime/__init__.py +50 -0
- openadapt_ml/runtime/policy.py +27 -14
- openadapt_ml/runtime/safety_gate.py +471 -0
- openadapt_ml/schema/__init__.py +113 -0
- openadapt_ml/schema/converters.py +588 -0
- openadapt_ml/schema/episode.py +470 -0
- openadapt_ml/scripts/capture_screenshots.py +530 -0
- openadapt_ml/scripts/compare.py +102 -61
- openadapt_ml/scripts/demo_policy.py +4 -1
- openadapt_ml/scripts/eval_policy.py +19 -14
- openadapt_ml/scripts/make_gif.py +1 -1
- openadapt_ml/scripts/prepare_synthetic.py +16 -17
- openadapt_ml/scripts/train.py +98 -75
- openadapt_ml/segmentation/README.md +920 -0
- openadapt_ml/segmentation/__init__.py +97 -0
- openadapt_ml/segmentation/adapters/__init__.py +5 -0
- openadapt_ml/segmentation/adapters/capture_adapter.py +420 -0
- openadapt_ml/segmentation/annotator.py +610 -0
- openadapt_ml/segmentation/cache.py +290 -0
- openadapt_ml/segmentation/cli.py +674 -0
- openadapt_ml/segmentation/deduplicator.py +656 -0
- openadapt_ml/segmentation/frame_describer.py +788 -0
- openadapt_ml/segmentation/pipeline.py +340 -0
- openadapt_ml/segmentation/schemas.py +622 -0
- openadapt_ml/segmentation/segment_extractor.py +634 -0
- openadapt_ml/training/azure_ops_viewer.py +1097 -0
- openadapt_ml/training/benchmark_viewer.py +3255 -19
- openadapt_ml/training/shared_ui.py +7 -7
- openadapt_ml/training/stub_provider.py +57 -35
- openadapt_ml/training/trainer.py +255 -441
- openadapt_ml/training/trl_trainer.py +403 -0
- openadapt_ml/training/viewer.py +323 -108
- openadapt_ml/training/viewer_components.py +180 -0
- {openadapt_ml-0.1.0.dist-info → openadapt_ml-0.2.1.dist-info}/METADATA +312 -69
- openadapt_ml-0.2.1.dist-info/RECORD +116 -0
- openadapt_ml/benchmarks/base.py +0 -366
- openadapt_ml/benchmarks/data_collection.py +0 -432
- openadapt_ml/benchmarks/runner.py +0 -381
- openadapt_ml/benchmarks/waa.py +0 -704
- openadapt_ml/schemas/__init__.py +0 -53
- openadapt_ml/schemas/sessions.py +0 -122
- openadapt_ml/schemas/validation.py +0 -252
- openadapt_ml-0.1.0.dist-info/RECORD +0 -55
- {openadapt_ml-0.1.0.dist-info → openadapt_ml-0.2.1.dist-info}/WHEEL +0 -0
- {openadapt_ml-0.1.0.dist-info → openadapt_ml-0.2.1.dist-info}/licenses/LICENSE +0 -0
|
@@ -50,7 +50,9 @@ class ApiVLMAdapter(BaseVLMAdapter):
|
|
|
50
50
|
"Install with `uv sync --extra api`."
|
|
51
51
|
) from exc
|
|
52
52
|
|
|
53
|
-
key =
|
|
53
|
+
key = (
|
|
54
|
+
api_key or settings.anthropic_api_key or os.getenv("ANTHROPIC_API_KEY")
|
|
55
|
+
)
|
|
54
56
|
if not key:
|
|
55
57
|
raise RuntimeError(
|
|
56
58
|
"ANTHROPIC_API_KEY is required but not found. "
|
|
@@ -87,10 +89,14 @@ class ApiVLMAdapter(BaseVLMAdapter):
|
|
|
87
89
|
super().__init__(model=model, processor=processor, device=device)
|
|
88
90
|
|
|
89
91
|
def prepare_inputs(self, batch: List[Dict[str, Any]]) -> Dict[str, Any]: # type: ignore[override]
|
|
90
|
-
raise NotImplementedError(
|
|
92
|
+
raise NotImplementedError(
|
|
93
|
+
"ApiVLMAdapter does not support training (prepare_inputs)"
|
|
94
|
+
)
|
|
91
95
|
|
|
92
96
|
def compute_loss(self, inputs: Dict[str, Any]) -> torch.Tensor: # type: ignore[override]
|
|
93
|
-
raise NotImplementedError(
|
|
97
|
+
raise NotImplementedError(
|
|
98
|
+
"ApiVLMAdapter does not support training (compute_loss)"
|
|
99
|
+
)
|
|
94
100
|
|
|
95
101
|
def generate(self, sample: Dict[str, Any], max_new_tokens: int = 64) -> str: # type: ignore[override]
|
|
96
102
|
images = sample.get("images", [])
|
|
@@ -138,7 +144,11 @@ class ApiVLMAdapter(BaseVLMAdapter):
|
|
|
138
144
|
|
|
139
145
|
# Anthropic messages API returns a list of content blocks.
|
|
140
146
|
parts = getattr(resp, "content", [])
|
|
141
|
-
texts = [
|
|
147
|
+
texts = [
|
|
148
|
+
getattr(p, "text", "")
|
|
149
|
+
for p in parts
|
|
150
|
+
if getattr(p, "type", "") == "text"
|
|
151
|
+
]
|
|
142
152
|
return "\n".join([t for t in texts if t]).strip()
|
|
143
153
|
|
|
144
154
|
if self.provider == "openai":
|
|
@@ -14,7 +14,10 @@ def get_default_device() -> torch.device:
|
|
|
14
14
|
|
|
15
15
|
if torch.cuda.is_available():
|
|
16
16
|
return torch.device("cuda")
|
|
17
|
-
if
|
|
17
|
+
if (
|
|
18
|
+
getattr(torch.backends, "mps", None) is not None
|
|
19
|
+
and torch.backends.mps.is_available()
|
|
20
|
+
): # type: ignore[attr-defined]
|
|
18
21
|
return torch.device("mps")
|
|
19
22
|
return torch.device("cpu")
|
|
20
23
|
|
|
@@ -28,7 +31,12 @@ class BaseVLMAdapter(ABC):
|
|
|
28
31
|
- generating assistant text given a single sample at inference time
|
|
29
32
|
"""
|
|
30
33
|
|
|
31
|
-
def __init__(
|
|
34
|
+
def __init__(
|
|
35
|
+
self,
|
|
36
|
+
model: torch.nn.Module,
|
|
37
|
+
processor: Any,
|
|
38
|
+
device: Optional[torch.device] = None,
|
|
39
|
+
) -> None:
|
|
32
40
|
self.model = model
|
|
33
41
|
self.processor = processor
|
|
34
42
|
self.device = device or get_default_device()
|
|
@@ -0,0 +1,288 @@
|
|
|
1
|
+
"""API Provider implementations for VLM backends.
|
|
2
|
+
|
|
3
|
+
This module provides a unified interface for different API providers:
|
|
4
|
+
- Anthropic (Claude)
|
|
5
|
+
- OpenAI (GPT)
|
|
6
|
+
- Google (Gemini)
|
|
7
|
+
|
|
8
|
+
The provider abstraction allows switching between different VLM backends
|
|
9
|
+
without changing the calling code. Each provider handles:
|
|
10
|
+
- Client creation with API key management
|
|
11
|
+
- Message sending with vision support
|
|
12
|
+
- Image encoding in provider-specific formats
|
|
13
|
+
|
|
14
|
+
Usage:
|
|
15
|
+
from openadapt_ml.models.providers import get_provider
|
|
16
|
+
|
|
17
|
+
# Get a provider and send a message
|
|
18
|
+
provider = get_provider("anthropic")
|
|
19
|
+
client = provider.create_client(api_key)
|
|
20
|
+
response = provider.send_message(
|
|
21
|
+
client,
|
|
22
|
+
model="claude-opus-4-5-20251101",
|
|
23
|
+
system="You are a GUI agent.",
|
|
24
|
+
content=provider.build_content(
|
|
25
|
+
text="Click the submit button",
|
|
26
|
+
image=screenshot,
|
|
27
|
+
),
|
|
28
|
+
)
|
|
29
|
+
|
|
30
|
+
# Or use the quick_message helper
|
|
31
|
+
response = provider.quick_message(
|
|
32
|
+
api_key=key,
|
|
33
|
+
model="claude-opus-4-5-20251101",
|
|
34
|
+
prompt="What's in this image?",
|
|
35
|
+
image=screenshot,
|
|
36
|
+
)
|
|
37
|
+
|
|
38
|
+
Model Aliases:
|
|
39
|
+
Common model aliases are provided for convenience:
|
|
40
|
+
- "claude-opus-4.5" -> ("anthropic", "claude-opus-4-5-20251101")
|
|
41
|
+
- "gpt-5.2" -> ("openai", "gpt-5.2")
|
|
42
|
+
- "gemini-3-pro" -> ("google", "gemini-3-pro")
|
|
43
|
+
"""
|
|
44
|
+
|
|
45
|
+
from __future__ import annotations
|
|
46
|
+
|
|
47
|
+
from typing import TYPE_CHECKING
|
|
48
|
+
|
|
49
|
+
from openadapt_ml.models.providers.base import (
|
|
50
|
+
BaseAPIProvider,
|
|
51
|
+
ProviderError,
|
|
52
|
+
AuthenticationError,
|
|
53
|
+
RateLimitError,
|
|
54
|
+
ModelNotFoundError,
|
|
55
|
+
)
|
|
56
|
+
from openadapt_ml.models.providers.anthropic import AnthropicProvider
|
|
57
|
+
from openadapt_ml.models.providers.openai import OpenAIProvider
|
|
58
|
+
from openadapt_ml.models.providers.google import GoogleProvider
|
|
59
|
+
|
|
60
|
+
if TYPE_CHECKING:
|
|
61
|
+
from PIL import Image
|
|
62
|
+
|
|
63
|
+
__all__ = [
|
|
64
|
+
# Base classes and exceptions
|
|
65
|
+
"BaseAPIProvider",
|
|
66
|
+
"ProviderError",
|
|
67
|
+
"AuthenticationError",
|
|
68
|
+
"RateLimitError",
|
|
69
|
+
"ModelNotFoundError",
|
|
70
|
+
# Provider implementations
|
|
71
|
+
"AnthropicProvider",
|
|
72
|
+
"OpenAIProvider",
|
|
73
|
+
"GoogleProvider",
|
|
74
|
+
# Factory functions
|
|
75
|
+
"get_provider",
|
|
76
|
+
"get_provider_for_model",
|
|
77
|
+
"resolve_model_alias",
|
|
78
|
+
# Registries
|
|
79
|
+
"PROVIDERS",
|
|
80
|
+
"MODEL_ALIASES",
|
|
81
|
+
# Convenience functions
|
|
82
|
+
"quick_message",
|
|
83
|
+
"list_providers",
|
|
84
|
+
"list_models",
|
|
85
|
+
]
|
|
86
|
+
|
|
87
|
+
# Provider registry
|
|
88
|
+
PROVIDERS: dict[str, type[BaseAPIProvider]] = {
|
|
89
|
+
"anthropic": AnthropicProvider,
|
|
90
|
+
"openai": OpenAIProvider,
|
|
91
|
+
"google": GoogleProvider,
|
|
92
|
+
}
|
|
93
|
+
|
|
94
|
+
# Model aliases for convenience
|
|
95
|
+
# Maps friendly names to (provider, model_id) tuples
|
|
96
|
+
MODEL_ALIASES: dict[str, tuple[str, str]] = {
    # Anthropic
    "claude-opus-4.5": ("anthropic", "claude-opus-4-5-20251101"),
    "claude-sonnet-4.5": ("anthropic", "claude-sonnet-4-5-20250929"),
    # Claude 3.x model IDs put the version before the tier name
    # ("claude-3-5-haiku-..."), unlike the 4.x IDs above.
    "claude-haiku-3.5": ("anthropic", "claude-3-5-haiku-20241022"),
    # OpenAI
    "gpt-5.2": ("openai", "gpt-5.2"),
    "gpt-5.1": ("openai", "gpt-5.1"),
    "gpt-4o": ("openai", "gpt-4o"),
    "gpt-4o-mini": ("openai", "gpt-4o-mini"),
    # Google
    "gemini-3-pro": ("google", "gemini-3-pro"),
    "gemini-3-flash": ("google", "gemini-3-flash"),
    "gemini-2.5-pro": ("google", "gemini-2.5-pro"),
    "gemini-2.5-flash": ("google", "gemini-2.5-flash"),
}
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
def get_provider(provider_name: str) -> BaseAPIProvider:
    """Instantiate the provider registered under ``provider_name``.

    The name is lower-cased before it is matched against the keys of
    ``PROVIDERS``, so the lookup is case-insensitive. A new instance is
    created on every call.

    Args:
        provider_name: Provider identifier ('anthropic', 'openai', 'google').

    Returns:
        A freshly constructed provider instance.

    Raises:
        ValueError: If no provider is registered under that name.

    Example:
        >>> provider = get_provider("anthropic")
        >>> provider.name
        'anthropic'
    """
    registry_key = provider_name.lower()
    if registry_key in PROVIDERS:
        return PROVIDERS[registry_key]()

    # The error echoes the caller's original spelling, not the lowered key.
    available = ", ".join(PROVIDERS.keys())
    raise ValueError(f"Unknown provider: {provider_name}. Available: {available}")
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
def resolve_model_alias(alias: str) -> tuple[str, str]:
    """Resolve a model alias to (provider, model_id).

    Resolution order:

    1. Exact key match in ``MODEL_ALIASES``.
    2. Lower-cased key match in ``MODEL_ALIASES``, so ``"Claude-Opus-4.5"``
       resolves like ``"claude-opus-4.5"`` instead of being passed through
       as if it were a full model ID.
    3. Provider inference from a known, case-insensitive name prefix
       (``claude``, ``gpt``, ``gemini``); the input is then returned
       unchanged as the model ID.

    Args:
        alias: Model alias (e.g., 'claude-opus-4.5') or full model ID.

    Returns:
        Tuple of (provider_name, model_id).

    Raises:
        ValueError: If alias is not recognized and can't be inferred.

    Example:
        >>> resolve_model_alias("claude-opus-4.5")
        ('anthropic', 'claude-opus-4-5-20251101')
        >>> resolve_model_alias("gemini-3-pro")
        ('google', 'gemini-3-pro')
    """
    # Exact match first so already-correct callers see unchanged results.
    if alias in MODEL_ALIASES:
        return MODEL_ALIASES[alias]

    # Alias keys are all lower-case; retry with a case-folded key so the
    # alias table and the prefix inference below agree on case handling.
    alias_lower = alias.lower()
    if alias_lower in MODEL_ALIASES:
        return MODEL_ALIASES[alias_lower]

    # Infer the provider from the model-name prefix.
    prefix_to_provider = (
        ("claude", "anthropic"),
        ("gpt", "openai"),
        ("gemini", "google"),
    )
    for prefix, provider_name in prefix_to_provider:
        if alias_lower.startswith(prefix):
            return (provider_name, alias)

    raise ValueError(
        f"Unknown model alias: {alias}. "
        f"Available aliases: {', '.join(MODEL_ALIASES.keys())}. "
        "Or use a full model ID with a known prefix (claude-*, gpt-*, gemini-*)."
    )
|
|
175
|
+
|
|
176
|
+
|
|
177
|
+
def get_provider_for_model(model: str) -> tuple[BaseAPIProvider, str]:
    """Look up the provider that serves ``model`` and the resolved model ID.

    Combines ``resolve_model_alias`` (alias -> provider name + model ID)
    with ``get_provider`` (provider name -> instance).

    Args:
        model: Model alias or full model ID.

    Returns:
        Tuple of (provider_instance, resolved_model_id).

    Example:
        >>> provider, model_id = get_provider_for_model("claude-opus-4.5")
        >>> provider.name
        'anthropic'
        >>> model_id
        'claude-opus-4-5-20251101'
    """
    resolved = resolve_model_alias(model)
    backend = get_provider(resolved[0])
    return backend, resolved[1]
|
|
196
|
+
|
|
197
|
+
|
|
198
|
+
def quick_message(
    model: str,
    prompt: str,
    image: "Image | None" = None,
    system: str = "",
    api_key: str | None = None,
    max_tokens: int = 1024,
    temperature: float = 0.1,
) -> str:
    """Send a one-off prompt to any supported model.

    Resolves ``model`` to a provider and model ID, asks that provider to
    resolve the API key, and delegates to the provider's own
    ``quick_message`` with all remaining arguments passed through.

    Args:
        model: Model alias or full model ID.
        prompt: User prompt text.
        image: Optional image to include.
        system: Optional system prompt.
        api_key: Optional API key; key resolution is delegated to the
            provider's ``get_api_key``.
        max_tokens: Maximum tokens in response.
        temperature: Sampling temperature.

    Returns:
        Model response text.

    Raises:
        ValueError: If ``model`` cannot be resolved to a provider.
        AuthenticationError: Raised by the provider when no API key is
            available (per the provider contract; not raised here).
        ProviderError: Raised by the provider for API errors.

    Example:
        >>> response = quick_message(
        ...     model="claude-opus-4.5",
        ...     prompt="What's in this image?",
        ...     image=screenshot,
        ... )
    """
    backend, resolved_model = get_provider_for_model(model)
    request = {
        "api_key": backend.get_api_key(api_key),
        "model": resolved_model,
        "prompt": prompt,
        "image": image,
        "system": system,
        "max_tokens": max_tokens,
        "temperature": temperature,
    }
    return backend.quick_message(**request)
|
|
246
|
+
|
|
247
|
+
|
|
248
|
+
def list_providers() -> list[str]:
    """Return the registered provider names.

    Names are returned in the iteration order of the ``PROVIDERS`` registry.

    Returns:
        List of provider identifiers.

    Example:
        >>> list_providers()
        ['anthropic', 'openai', 'google']
    """
    return [provider_name for provider_name in PROVIDERS]
|
|
259
|
+
|
|
260
|
+
|
|
261
|
+
def list_models(provider: str | None = None) -> dict[str, dict]:
    """List available models, optionally filtered by provider.

    The returned mapping and its per-model property dicts are always fresh
    copies, so callers may modify the result without affecting a provider's
    own ``supported_models`` table.

    Args:
        provider: Optional provider name to filter by. A falsy value
            (``None`` or ``""``) lists models from every provider.

    Returns:
        Dict mapping model IDs to their properties. When no provider filter
        is given, each property dict additionally carries a ``"provider"``
        key naming the provider the model came from.

    Raises:
        ValueError: If ``provider`` is given but not registered.

    Example:
        >>> list_models("anthropic")
        {
            'claude-opus-4-5-20251101': {'context': 200000, 'description': 'SOTA computer use'},
            ...
        }
    """
    if provider:
        catalog = get_provider(provider).supported_models
        # Copy rather than hand out the provider's registry object itself.
        return {model_id: dict(props) for model_id, props in catalog.items()}

    # Combine models from all providers, tagging each with its origin.
    all_models: dict[str, dict] = {}
    for provider_name in PROVIDERS:
        catalog = get_provider(provider_name).supported_models
        for model_id, props in catalog.items():
            all_models[model_id] = {**props, "provider": provider_name}

    return all_models
|
|
@@ -0,0 +1,266 @@
|
|
|
1
|
+
"""Anthropic (Claude) API provider.
|
|
2
|
+
|
|
3
|
+
Supports Claude Opus 4.5, Sonnet 4.5, and other Claude models.
|
|
4
|
+
Implements the BaseAPIProvider interface for the Anthropic Messages API.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
import logging
|
|
10
|
+
from typing import TYPE_CHECKING, Any
|
|
11
|
+
|
|
12
|
+
from openadapt_ml.models.providers.base import (
|
|
13
|
+
BaseAPIProvider,
|
|
14
|
+
AuthenticationError,
|
|
15
|
+
ModelNotFoundError,
|
|
16
|
+
ProviderError,
|
|
17
|
+
RateLimitError,
|
|
18
|
+
)
|
|
19
|
+
|
|
20
|
+
if TYPE_CHECKING:
|
|
21
|
+
from PIL import Image
|
|
22
|
+
|
|
23
|
+
logger = logging.getLogger(__name__)
|
|
24
|
+
|
|
25
|
+
# Default models
|
|
26
|
+
DEFAULT_MODEL = "claude-sonnet-4-5-20250929"
|
|
27
|
+
|
|
28
|
+
# Supported models with their context windows
|
|
29
|
+
SUPPORTED_MODELS = {
    "claude-opus-4-5-20251101": {
        "context": 200_000,
        "description": "SOTA computer use",
    },
    "claude-sonnet-4-5-20250929": {"context": 200_000, "description": "Fast, cheaper"},
    "claude-sonnet-4-20250514": {"context": 200_000, "description": "Previous Sonnet"},
    # Claude 3.x model IDs put the version before the tier name.
    "claude-3-5-haiku-20241022": {
        "context": 200_000,
        "description": "Fastest, cheapest",
    },
}
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class AnthropicProvider(BaseAPIProvider):
    """Provider for Anthropic's Claude models.

    Implements vision support via base64-encoded images in the Messages API format.

    Supported models:
        - claude-opus-4-5-20251101: Most capable, best for complex GUI tasks
        - claude-sonnet-4-5-20250929: Fast and cost-effective
        - claude-3-5-haiku-20241022: Fastest, lowest cost

    Example:
        >>> provider = AnthropicProvider()
        >>> client = provider.create_client(api_key)
        >>> response = provider.send_message(
        ...     client,
        ...     model="claude-opus-4-5-20251101",
        ...     system="You are a GUI agent.",
        ...     content=[
        ...         {"type": "text", "text": "Click the submit button"},
        ...         provider.encode_image(screenshot),
        ...     ],
        ... )

    Attributes:
        name: Returns 'anthropic'.
    """

    @property
    def name(self) -> str:
        """Provider name."""
        return "anthropic"

    @property
    def env_key_name(self) -> str:
        """Environment variable name for API key."""
        return "ANTHROPIC_API_KEY"

    @property
    def default_model(self) -> str:
        """Default model to use."""
        return DEFAULT_MODEL

    @property
    def supported_models(self) -> dict[str, dict[str, Any]]:
        """Dictionary of supported models and their properties."""
        return SUPPORTED_MODELS

    def create_client(self, api_key: str) -> Any:
        """Create Anthropic client.

        Args:
            api_key: Anthropic API key.

        Returns:
            Anthropic client instance.

        Raises:
            ImportError: If anthropic package not installed.
            AuthenticationError: If the API key is empty or whitespace-only.
        """
        # Imported lazily so the package is only required when this provider
        # is actually used.
        try:
            from anthropic import Anthropic
        except ImportError as e:
            raise ImportError(
                "anthropic package is required for provider='anthropic'. "
                "Install with: uv add anthropic"
            ) from e

        if not api_key or not api_key.strip():
            raise AuthenticationError(
                "Anthropic API key cannot be empty. "
                "Get a key from https://console.anthropic.com/"
            )

        logger.debug("Creating Anthropic client")
        return Anthropic(api_key=api_key)

    def send_message(
        self,
        client: Any,
        model: str,
        system: str,
        content: list[dict[str, Any]],
        max_tokens: int = 1024,
        temperature: float = 0.1,
    ) -> str:
        """Send message using Anthropic Messages API.

        Args:
            client: Anthropic client from create_client().
            model: Model ID (e.g., 'claude-opus-4-5-20251101').
            system: System prompt. An empty value means "no system prompt";
                the field is then left out of the request entirely.
            content: List of content blocks (text and images).
            max_tokens: Max response tokens.
            temperature: Sampling temperature forwarded to the API.

        Returns:
            The text of all ``text`` content blocks in the response, joined
            with newlines and stripped. Non-text blocks are ignored.

        Raises:
            AuthenticationError: If API key is invalid.
            RateLimitError: If rate limit exceeded.
            ModelNotFoundError: If model doesn't exist.
            ProviderError: For other API errors.
        """
        logger.debug(
            "Sending message to %s with %d content blocks", model, len(content)
        )

        request: dict[str, Any] = {
            "model": model,
            "max_tokens": max_tokens,
            "temperature": temperature,
            "messages": [{"role": "user", "content": content}],
        }
        # Omit the key rather than sending an explicit null: the parameter is
        # optional, and leaving it out is valid for every SDK/API version.
        if system:
            request["system"] = system

        try:
            response = client.messages.create(**request)

            # Extract text from content blocks
            parts = getattr(response, "content", [])
            texts = [
                getattr(p, "text", "")
                for p in parts
                if getattr(p, "type", "") == "text"
            ]
            result = "\n".join([t for t in texts if t]).strip()

            logger.debug("Received response: %d chars", len(result))
            return result

        except Exception as e:
            raise self._translate_error(e, model) from e

    def _translate_error(self, error: Exception, model: str) -> Exception:
        """Map a raw client/SDK failure onto this package's provider exceptions.

        The HTTP status code (when the exception exposes ``status_code``) is
        checked alongside the original message-substring heuristics, so a
        failure is still classified if the message wording differs.

        Args:
            error: Exception raised while sending or parsing the request.
            model: Model ID that was requested; used in the not-found message.

        Returns:
            The exception instance the caller should raise.
        """
        status = getattr(error, "status_code", None)
        text = str(error).lower()

        if status == 401 or "authentication" in text or "api_key" in text:
            return AuthenticationError(f"Anthropic authentication failed: {error}")
        if status == 429 or "rate_limit" in text or "429" in text:
            return RateLimitError(f"Anthropic rate limit exceeded: {error}")
        if status == 404 or "model_not_found" in text or "not found" in text:
            return ModelNotFoundError(f"Model '{model}' not found: {error}")
        return ProviderError(f"Anthropic API error: {error}")

    def encode_image(self, image: "Image") -> dict[str, Any]:
        """Encode image for Anthropic API.

        Anthropic uses base64-encoded images with explicit source type.
        PNG format is used for lossless quality.

        Args:
            image: PIL Image.

        Returns:
            Image content block for Anthropic API in format:
            {
                "type": "image",
                "source": {
                    "type": "base64",
                    "media_type": "image/png",
                    "data": "<base64_string>"
                }
            }
        """
        return {
            "type": "image",
            "source": {
                "type": "base64",
                "media_type": "image/png",
                "data": self.image_to_base64(image, "PNG"),
            },
        }

    def encode_image_from_bytes(
        self,
        image_bytes: bytes,
        media_type: str = "image/png",
    ) -> dict[str, Any]:
        """Encode raw image bytes for Anthropic API.

        Useful when you already have image bytes and don't need PIL.

        Args:
            image_bytes: Raw image bytes.
            media_type: MIME type of the image.

        Returns:
            Image content block for Anthropic API.
        """
        import base64

        return {
            "type": "image",
            "source": {
                "type": "base64",
                "media_type": media_type,
                "data": base64.b64encode(image_bytes).decode("utf-8"),
            },
        }

    def encode_image_from_url(
        self, url: str, *, timeout: float = 30.0
    ) -> dict[str, Any]:
        """Fetch an image from a URL and encode it as a base64 content block.

        The URL is downloaded locally and the bytes are passed to
        ``encode_image_from_bytes``.

        Args:
            url: Image URL to fetch and encode.
            timeout: Seconds to wait on the connection before giving up, so
                an unresponsive host cannot block the caller indefinitely.

        Returns:
            Image content block for Anthropic API.

        Raises:
            ProviderError: If URL fetch fails.
        """
        import urllib.request

        # NOTE(review): urlopen also accepts non-HTTP schemes such as
        # file://. If ``url`` can come from untrusted input, restrict the
        # scheme before calling this method.
        try:
            with urllib.request.urlopen(url, timeout=timeout) as response:
                image_bytes = response.read()
                content_type = response.headers.get("Content-Type") or "image/png"
                # Keep only the bare MIME type: a header such as
                # "image/png; charset=binary" is not a valid media_type.
                media_type = content_type.split(";", 1)[0].strip().lower()
                return self.encode_image_from_bytes(
                    image_bytes, media_type or "image/png"
                )
        except Exception as e:
            raise ProviderError(f"Failed to fetch image from URL: {e}") from e
|