prompture 0.0.35__py3-none-any.whl → 0.0.40.dev1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52)
  1. prompture/__init__.py +132 -3
  2. prompture/_version.py +2 -2
  3. prompture/agent.py +924 -0
  4. prompture/agent_types.py +156 -0
  5. prompture/async_agent.py +880 -0
  6. prompture/async_conversation.py +208 -17
  7. prompture/async_core.py +16 -0
  8. prompture/async_driver.py +63 -0
  9. prompture/async_groups.py +551 -0
  10. prompture/conversation.py +222 -18
  11. prompture/core.py +46 -12
  12. prompture/cost_mixin.py +37 -0
  13. prompture/discovery.py +132 -44
  14. prompture/driver.py +77 -0
  15. prompture/drivers/__init__.py +5 -1
  16. prompture/drivers/async_azure_driver.py +11 -5
  17. prompture/drivers/async_claude_driver.py +184 -9
  18. prompture/drivers/async_google_driver.py +222 -28
  19. prompture/drivers/async_grok_driver.py +11 -5
  20. prompture/drivers/async_groq_driver.py +11 -5
  21. prompture/drivers/async_lmstudio_driver.py +74 -5
  22. prompture/drivers/async_ollama_driver.py +13 -3
  23. prompture/drivers/async_openai_driver.py +162 -5
  24. prompture/drivers/async_openrouter_driver.py +11 -5
  25. prompture/drivers/async_registry.py +5 -1
  26. prompture/drivers/azure_driver.py +10 -4
  27. prompture/drivers/claude_driver.py +17 -1
  28. prompture/drivers/google_driver.py +227 -33
  29. prompture/drivers/grok_driver.py +11 -5
  30. prompture/drivers/groq_driver.py +11 -5
  31. prompture/drivers/lmstudio_driver.py +73 -8
  32. prompture/drivers/ollama_driver.py +16 -5
  33. prompture/drivers/openai_driver.py +26 -11
  34. prompture/drivers/openrouter_driver.py +11 -5
  35. prompture/drivers/vision_helpers.py +153 -0
  36. prompture/group_types.py +147 -0
  37. prompture/groups.py +530 -0
  38. prompture/image.py +180 -0
  39. prompture/ledger.py +252 -0
  40. prompture/model_rates.py +112 -2
  41. prompture/persistence.py +254 -0
  42. prompture/persona.py +482 -0
  43. prompture/serialization.py +218 -0
  44. prompture/settings.py +1 -0
  45. prompture-0.0.40.dev1.dist-info/METADATA +369 -0
  46. prompture-0.0.40.dev1.dist-info/RECORD +78 -0
  47. prompture-0.0.35.dist-info/METADATA +0 -464
  48. prompture-0.0.35.dist-info/RECORD +0 -66
  49. {prompture-0.0.35.dist-info → prompture-0.0.40.dev1.dist-info}/WHEEL +0 -0
  50. {prompture-0.0.35.dist-info → prompture-0.0.40.dev1.dist-info}/entry_points.txt +0 -0
  51. {prompture-0.0.35.dist-info → prompture-0.0.40.dev1.dist-info}/licenses/LICENSE +0 -0
  52. {prompture-0.0.35.dist-info → prompture-0.0.40.dev1.dist-info}/top_level.txt +0 -0
prompture/discovery.py CHANGED
@@ -1,7 +1,11 @@
 """Discovery module for auto-detecting available models."""
 
+from __future__ import annotations
+
+import dataclasses
 import logging
 import os
+from typing import Any, overload
 
 import requests
 
@@ -22,23 +26,40 @@ from .settings import settings
 logger = logging.getLogger(__name__)
 
 
-def get_available_models() -> list[str]:
-    """
-    Auto-detects all available models based on configured drivers and environment variables.
+@overload
+def get_available_models(*, include_capabilities: bool = False, verified_only: bool = False) -> list[str]: ...
+
+
+@overload
+def get_available_models(*, include_capabilities: bool = True, verified_only: bool = False) -> list[dict[str, Any]]: ...
+
+
+def get_available_models(
+    *,
+    include_capabilities: bool = False,
+    verified_only: bool = False,
+) -> list[str] | list[dict[str, Any]]:
+    """Auto-detect available models based on configured drivers and environment variables.
 
-    Iterates through supported providers and checks if they are configured (e.g. API key present).
-    For static drivers, returns models from their MODEL_PRICING keys.
-    For dynamic drivers (like Ollama), attempts to fetch available models from the endpoint.
+    Iterates through supported providers and checks if they are configured
+    (e.g. API key present). For static drivers, returns models from their
+    ``MODEL_PRICING`` keys. For dynamic drivers (like Ollama), attempts to
+    fetch available models from the endpoint.
+
+    Args:
+        include_capabilities: When ``True``, return enriched dicts with
+            ``model``, ``provider``, ``model_id``, and ``capabilities``
+            fields instead of plain ``"provider/model_id"`` strings.
+        verified_only: When ``True``, only return models that have been
+            successfully used (as recorded by the usage ledger).
 
     Returns:
-        A list of unique model strings in the format "provider/model_id".
+        A sorted list of unique model strings (default) or enriched dicts.
     """
     available_models: set[str] = set()
     configured_providers: set[str] = set()
 
     # Map of provider name to driver class
-    # We need to map the registry keys to the actual classes to check MODEL_PRICING
-    # and instantiate for dynamic checks if needed.
     provider_classes = {
         "openai": OpenAIDriver,
         "azure": AzureDriver,
@@ -54,11 +75,6 @@ def get_available_models() -> list[str]:
 
     for provider, driver_cls in provider_classes.items():
         try:
-            # 1. Check if the provider is configured (has API key or endpoint)
-            # We can check this by looking at the settings or env vars that the driver uses.
-            # A simple way is to try to instantiate it with defaults, but that might fail if keys are missing.
-            # Instead, let's check the specific requirements for each known provider.
-
             is_configured = False
 
             if provider == "openai":
@@ -86,14 +102,11 @@ def get_available_models() -> list[str]:
             elif provider == "grok":
                 if settings.grok_api_key or os.getenv("GROK_API_KEY"):
                     is_configured = True
-            elif provider == "ollama":
-                # Ollama is always considered "configured" as it defaults to localhost
-                # We will check connectivity later
-                is_configured = True
-            elif provider == "lmstudio":
-                # LM Studio is similar to Ollama, defaults to localhost
-                is_configured = True
-            elif provider == "local_http" and (settings.local_http_endpoint or os.getenv("LOCAL_HTTP_ENDPOINT")):
+            elif (
+                provider == "ollama"
+                or provider == "lmstudio"
+                or (provider == "local_http" and os.getenv("LOCAL_HTTP_ENDPOINT"))
+            ):
                 is_configured = True
 
             if not is_configured:
@@ -101,36 +114,20 @@ def get_available_models() -> list[str]:
 
             configured_providers.add(provider)
 
-            # 2. Static Detection: Get models from MODEL_PRICING
+            # Static Detection: Get models from MODEL_PRICING
             if hasattr(driver_cls, "MODEL_PRICING"):
                 pricing = driver_cls.MODEL_PRICING
                 for model_id in pricing:
-                    # Skip "default" or generic keys if they exist
                     if model_id == "default":
                         continue
-
-                    # For Azure, the model_id in pricing is usually the base model name,
-                    # but the user needs to use the deployment ID.
-                    # However, our Azure driver implementation uses the deployment_id from init
-                    # as the "model" for the request, but expects the user to pass a model name
-                    # that maps to pricing?
-                    # Looking at AzureDriver:
-                    # kwargs = {"model": self.deployment_id, ...}
-                    # model = options.get("model", self.model) -> used for pricing lookup
-                    # So we should list the keys in MODEL_PRICING as available "models"
-                    # even though for Azure specifically it's a bit weird because of deployment IDs.
-                    # But for general discovery, listing supported models is correct.
-
                     available_models.add(f"{provider}/{model_id}")
 
-            # 3. Dynamic Detection: Specific logic for Ollama
+            # Dynamic Detection: Specific logic for Ollama
            if provider == "ollama":
                 try:
                     endpoint = settings.ollama_endpoint or os.getenv(
                         "OLLAMA_ENDPOINT", "http://localhost:11434/api/generate"
                     )
-                    # We need the base URL for tags, usually http://localhost:11434/api/tags
-                    # The configured endpoint might be .../api/generate or .../api/chat
                     base_url = endpoint.split("/api/")[0]
                     tags_url = f"{base_url}/api/tags"
 
@@ -141,13 +138,34 @@ def get_available_models() -> list[str]:
                     for model in models:
                         name = model.get("name")
                         if name:
-                            # Ollama model names often include tags like "llama3:latest"
-                            # We can keep them as is.
                             available_models.add(f"ollama/{name}")
                 except Exception as e:
                     logger.debug(f"Failed to fetch Ollama models: {e}")
 
-            # Future: Add dynamic detection for LM Studio if they have an endpoint for listing models
+            # Dynamic Detection: LM Studio loaded models
+            if provider == "lmstudio":
+                try:
+                    endpoint = settings.lmstudio_endpoint or os.getenv(
+                        "LMSTUDIO_ENDPOINT", "http://127.0.0.1:1234/v1/chat/completions"
+                    )
+                    base_url = endpoint.split("/v1/")[0]
+                    models_url = f"{base_url}/v1/models"
+
+                    headers: dict[str, str] = {}
+                    api_key = settings.lmstudio_api_key or os.getenv("LMSTUDIO_API_KEY")
+                    if api_key:
+                        headers["Authorization"] = f"Bearer {api_key}"
+
+                    resp = requests.get(models_url, headers=headers, timeout=2)
+                    if resp.status_code == 200:
+                        data = resp.json()
+                        models = data.get("data", [])
+                        for model in models:
+                            model_id = model.get("id")
+                            if model_id:
+                                available_models.add(f"lmstudio/{model_id}")
+                except Exception as e:
+                    logger.debug(f"Failed to fetch LM Studio models: {e}")
 
         except Exception as e:
             logger.warning(f"Error detecting models for provider {provider}: {e}")
@@ -161,4 +179,74 @@ def get_available_models() -> list[str]:
         for model_id in get_all_provider_models(api_name):
             available_models.add(f"{prompture_name}/{model_id}")
 
-    return sorted(list(available_models))
+    sorted_models = sorted(available_models)
+
+    # --- verified_only filtering ---
+    verified_set: set[str] | None = None
+    if verified_only or include_capabilities:
+        try:
+            from .ledger import _get_ledger
+
+            ledger = _get_ledger()
+            verified_set = ledger.get_verified_models()
+        except Exception:
+            logger.debug("Could not load ledger for verified models", exc_info=True)
+            verified_set = set()
+
+    if verified_only and verified_set is not None:
+        sorted_models = [m for m in sorted_models if m in verified_set]
+
+    if not include_capabilities:
+        return sorted_models
+
+    # Build enriched dicts with capabilities from models.dev
+    from .model_rates import get_model_capabilities
+
+    # Fetch all ledger stats for annotation (keyed by model_name)
+    ledger_stats: dict[str, dict[str, Any]] = {}
+    try:
+        from .ledger import _get_ledger
+
+        for row in _get_ledger().get_all_stats():
+            name = row["model_name"]
+            if name not in ledger_stats:
+                ledger_stats[name] = row
+            else:
+                # Aggregate across API key hashes
+                existing = ledger_stats[name]
+                existing["use_count"] += row["use_count"]
+                existing["total_tokens"] += row["total_tokens"]
+                existing["total_cost"] += row["total_cost"]
+                if row["last_used"] > existing["last_used"]:
+                    existing["last_used"] = row["last_used"]
+    except Exception:
+        logger.debug("Could not load ledger stats for enrichment", exc_info=True)
+
+    enriched: list[dict[str, Any]] = []
+    for model_str in sorted_models:
+        parts = model_str.split("/", 1)
+        provider = parts[0]
+        model_id = parts[1] if len(parts) > 1 else parts[0]
+
+        caps = get_model_capabilities(provider, model_id)
+        caps_dict = dataclasses.asdict(caps) if caps is not None else None
+
+        entry: dict[str, Any] = {
+            "model": model_str,
+            "provider": provider,
+            "model_id": model_id,
+            "capabilities": caps_dict,
+            "verified": verified_set is not None and model_str in verified_set,
+        }
+
+        stats = ledger_stats.get(model_str)
+        if stats:
+            entry["last_used"] = stats["last_used"]
+            entry["use_count"] = stats["use_count"]
+        else:
+            entry["last_used"] = None
+            entry["use_count"] = 0
+
+        enriched.append(entry)
+
+    return enriched
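
A minimal usage sketch of the new discovery options (the printed fields are illustrative; actual values depend on the configured providers, models.dev metadata, and the local usage ledger):

    from prompture.discovery import get_available_models

    # Default behaviour, unchanged from 0.0.35: plain "provider/model_id" strings
    models = get_available_models()

    # Enriched dicts: capabilities from models.dev plus usage-ledger annotations
    for entry in get_available_models(include_capabilities=True):
        print(entry["model"], entry["verified"], entry["use_count"], entry["capabilities"])

    # Restrict to models the usage ledger has already seen succeed
    verified = get_available_models(verified_only=True)
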
prompture/driver.py CHANGED
@@ -35,6 +35,7 @@ class Driver:
     supports_messages: bool = False
     supports_tool_use: bool = False
     supports_streaming: bool = False
+    supports_vision: bool = False
 
     callbacks: DriverCallbacks | None = None
 
@@ -52,6 +53,7 @@ class Driver:
         support message arrays should override this method and set
         ``supports_messages = True``.
         """
+        self._check_vision_support(messages)
         prompt = self._flatten_messages(messages)
         return self.generate(prompt, options)
 
@@ -171,6 +173,69 @@ class Driver:
         except Exception:
             logger.exception("Callback %s raised an exception", event)
 
+    def _validate_model_capabilities(
+        self,
+        provider: str,
+        model: str,
+        *,
+        using_tool_use: bool = False,
+        using_json_schema: bool = False,
+        using_vision: bool = False,
+    ) -> None:
+        """Log warnings when the model may not support a requested feature.
+
+        Uses models.dev metadata as a secondary signal. Warnings only — the
+        API is the final authority and models.dev data may be stale.
+        """
+        from .model_rates import get_model_capabilities
+
+        caps = get_model_capabilities(provider, model)
+        if caps is None:
+            return
+
+        if using_tool_use and caps.supports_tool_use is False:
+            logger.warning(
+                "Model %s/%s may not support tool use according to models.dev metadata",
+                provider,
+                model,
+            )
+        if using_json_schema and caps.supports_structured_output is False:
+            logger.warning(
+                "Model %s/%s may not support structured output / JSON schema according to models.dev metadata",
+                provider,
+                model,
+            )
+        if using_vision and caps.supports_vision is False:
+            logger.warning(
+                "Model %s/%s may not support vision/image inputs according to models.dev metadata",
+                provider,
+                model,
+            )
+
+    def _check_vision_support(self, messages: list[dict[str, Any]]) -> None:
+        """Raise if messages contain image blocks and the driver lacks vision support."""
+        if self.supports_vision:
+            return
+        for msg in messages:
+            content = msg.get("content")
+            if isinstance(content, list):
+                for block in content:
+                    if isinstance(block, dict) and block.get("type") == "image":
+                        raise NotImplementedError(
+                            f"{self.__class__.__name__} does not support vision/image inputs. "
+                            "Use a vision-capable model."
+                        )
+
+    def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+        """Transform universal message format into provider-specific wire format.
+
+        Vision-capable drivers override this to convert the universal image
+        blocks into their provider-specific format. The base implementation
+        validates vision support and returns messages unchanged.
+        """
+        self._check_vision_support(messages)
+        return messages
+
     @staticmethod
     def _flatten_messages(messages: list[dict[str, Any]]) -> str:
         """Join messages into a single prompt string with role prefixes."""
@@ -178,6 +243,18 @@ class Driver:
         for msg in messages:
             role = msg.get("role", "user")
             content = msg.get("content", "")
+            # Handle content that is a list of blocks (vision messages)
+            if isinstance(content, list):
+                text_parts = []
+                for block in content:
+                    if isinstance(block, dict):
+                        if block.get("type") == "text":
+                            text_parts.append(block.get("text", ""))
+                        elif block.get("type") == "image":
+                            text_parts.append("[image]")
+                    elif isinstance(block, str):
+                        text_parts.append(block)
+                content = " ".join(text_parts)
             if role == "system":
                 parts.append(f"[System]: {content}")
            elif role == "assistant":
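
A short sketch of how the new vision plumbing in the base Driver behaves; the image block below shows only the "type" key that _check_vision_support and _flatten_messages actually inspect, with the payload fields (defined elsewhere in the new prompture/image.py) deliberately elided:

    messages = [
        {"role": "system", "content": "You are terse."},
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "What is in this picture?"},
                {"type": "image"},  # image payload fields elided; not shown in this diff
            ],
        },
    ]

    # On a driver with supports_vision = False, _check_vision_support() raises
    # NotImplementedError before any request is made. Vision-capable drivers
    # override _prepare_messages() to emit their provider's wire format, and
    # _flatten_messages() degrades each image block to the placeholder "[image]"
    # when a prompt-only driver has to stringify the conversation.
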
prompture/drivers/__init__.py CHANGED
@@ -84,7 +84,11 @@ register_driver(
 )
 register_driver(
     "lmstudio",
-    lambda model=None: LMStudioDriver(endpoint=settings.lmstudio_endpoint, model=model or settings.lmstudio_model),
+    lambda model=None: LMStudioDriver(
+        endpoint=settings.lmstudio_endpoint,
+        model=model or settings.lmstudio_model,
+        api_key=settings.lmstudio_api_key,
+    ),
     overwrite=True,
 )
 register_driver(
prompture/drivers/async_azure_driver.py CHANGED
@@ -18,6 +18,7 @@ from .azure_driver import AzureDriver
 class AsyncAzureDriver(CostMixin, AsyncDriver):
     supports_json_mode = True
     supports_json_schema = True
+    supports_vision = True
 
     MODEL_PRICING = AzureDriver.MODEL_PRICING
 
@@ -52,21 +53,26 @@ class AsyncAzureDriver(CostMixin, AsyncDriver):
 
     supports_messages = True
 
+    def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+        from .vision_helpers import _prepare_openai_vision_messages
+
+        return _prepare_openai_vision_messages(messages)
+
     async def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
         messages = [{"role": "user", "content": prompt}]
         return await self._do_generate(messages, options)
 
     async def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
-        return await self._do_generate(messages, options)
+        return await self._do_generate(self._prepare_messages(messages), options)
 
     async def _do_generate(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
         if self.client is None:
             raise RuntimeError("openai package (>=1.0.0) with AsyncAzureOpenAI not installed")
 
         model = options.get("model", self.model)
-        model_info = self.MODEL_PRICING.get(model, {})
-        tokens_param = model_info.get("tokens_param", "max_tokens")
-        supports_temperature = model_info.get("supports_temperature", True)
+        model_config = self._get_model_config("azure", model)
+        tokens_param = model_config["tokens_param"]
+        supports_temperature = model_config["supports_temperature"]
 
         opts = {"temperature": 1.0, "max_tokens": 512, **options}
 
@@ -107,7 +113,7 @@ class AsyncAzureDriver(CostMixin, AsyncDriver):
             "prompt_tokens": prompt_tokens,
             "completion_tokens": completion_tokens,
             "total_tokens": total_tokens,
-            "cost": total_cost,
+            "cost": round(total_cost, 6),
             "raw_response": resp.model_dump(),
             "model_name": model,
             "deployment_id": self.deployment_id,
prompture/drivers/async_claude_driver.py CHANGED
@@ -4,6 +4,7 @@ from __future__ import annotations
 
 import json
 import os
+from collections.abc import AsyncIterator
 from typing import Any
 
 try:
@@ -19,6 +20,9 @@ from .claude_driver import ClaudeDriver
 class AsyncClaudeDriver(CostMixin, AsyncDriver):
     supports_json_mode = True
     supports_json_schema = True
+    supports_tool_use = True
+    supports_streaming = True
+    supports_vision = True
 
     MODEL_PRICING = ClaudeDriver.MODEL_PRICING
 
@@ -28,12 +32,17 @@ class AsyncClaudeDriver(CostMixin, AsyncDriver):
 
     supports_messages = True
 
+    def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+        from .vision_helpers import _prepare_claude_vision_messages
+
+        return _prepare_claude_vision_messages(messages)
+
     async def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
         messages = [{"role": "user", "content": prompt}]
         return await self._do_generate(messages, options)
 
     async def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
-        return await self._do_generate(messages, options)
+        return await self._do_generate(self._prepare_messages(messages), options)
 
     async def _do_generate(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
         if anthropic is None:
@@ -42,16 +51,17 @@ class AsyncClaudeDriver(CostMixin, AsyncDriver):
         opts = {**{"temperature": 0.0, "max_tokens": 512}, **options}
         model = options.get("model", self.model)
 
+        # Validate capabilities against models.dev metadata
+        self._validate_model_capabilities(
+            "claude",
+            model,
+            using_json_schema=bool(options.get("json_schema")),
+        )
+
         client = anthropic.AsyncAnthropic(api_key=self.api_key)
 
         # Anthropic requires system messages as a top-level parameter
-        system_content = None
-        api_messages = []
-        for msg in messages:
-            if msg.get("role") == "system":
-                system_content = msg.get("content", "")
-            else:
-                api_messages.append(msg)
+        system_content, api_messages = self._extract_system_and_messages(messages)
 
         # Build common kwargs
         common_kwargs: dict[str, Any] = {
@@ -99,9 +109,174 @@ class AsyncClaudeDriver(CostMixin, AsyncDriver):
             "prompt_tokens": prompt_tokens,
             "completion_tokens": completion_tokens,
             "total_tokens": total_tokens,
-            "cost": total_cost,
+            "cost": round(total_cost, 6),
             "raw_response": dict(resp),
             "model_name": model,
         }
 
         return {"text": text, "meta": meta}
+
+    # ------------------------------------------------------------------
+    # Helpers
+    # ------------------------------------------------------------------
+
+    def _extract_system_and_messages(
+        self, messages: list[dict[str, Any]]
+    ) -> tuple[str | None, list[dict[str, Any]]]:
+        """Separate system message from conversation messages for Anthropic API."""
+        system_content = None
+        api_messages: list[dict[str, Any]] = []
+        for msg in messages:
+            if msg.get("role") == "system":
+                system_content = msg.get("content", "")
+            else:
+                api_messages.append(msg)
+        return system_content, api_messages
+
+    # ------------------------------------------------------------------
+    # Tool use
+    # ------------------------------------------------------------------
+
+    async def generate_messages_with_tools(
+        self,
+        messages: list[dict[str, Any]],
+        tools: list[dict[str, Any]],
+        options: dict[str, Any],
+    ) -> dict[str, Any]:
+        """Generate a response that may include tool calls (Anthropic)."""
+        if anthropic is None:
+            raise RuntimeError("anthropic package not installed")
+
+        opts = {**{"temperature": 0.0, "max_tokens": 512}, **options}
+        model = options.get("model", self.model)
+
+        self._validate_model_capabilities("claude", model, using_tool_use=True)
+
+        client = anthropic.AsyncAnthropic(api_key=self.api_key)
+
+        system_content, api_messages = self._extract_system_and_messages(messages)
+
+        # Convert tools from OpenAI format to Anthropic format if needed
+        anthropic_tools = []
+        for t in tools:
+            if "type" in t and t["type"] == "function":
+                # OpenAI format -> Anthropic format
+                fn = t["function"]
+                anthropic_tools.append({
+                    "name": fn["name"],
+                    "description": fn.get("description", ""),
+                    "input_schema": fn.get("parameters", {"type": "object", "properties": {}}),
+                })
+            elif "input_schema" in t:
+                # Already Anthropic format
+                anthropic_tools.append(t)
+            else:
+                anthropic_tools.append(t)
+
+        kwargs: dict[str, Any] = {
+            "model": model,
+            "messages": api_messages,
+            "temperature": opts["temperature"],
+            "max_tokens": opts["max_tokens"],
+            "tools": anthropic_tools,
+        }
+        if system_content:
+            kwargs["system"] = system_content
+
+        resp = await client.messages.create(**kwargs)
+
+        prompt_tokens = resp.usage.input_tokens
+        completion_tokens = resp.usage.output_tokens
+        total_tokens = prompt_tokens + completion_tokens
+        total_cost = self._calculate_cost("claude", model, prompt_tokens, completion_tokens)
+
+        meta = {
+            "prompt_tokens": prompt_tokens,
+            "completion_tokens": completion_tokens,
+            "total_tokens": total_tokens,
+            "cost": round(total_cost, 6),
+            "raw_response": dict(resp),
+            "model_name": model,
+        }
+
+        text = ""
+        tool_calls_out: list[dict[str, Any]] = []
+        for block in resp.content:
+            if block.type == "text":
+                text += block.text
+            elif block.type == "tool_use":
+                tool_calls_out.append({
+                    "id": block.id,
+                    "name": block.name,
+                    "arguments": block.input,
+                })
+
+        return {
+            "text": text,
+            "meta": meta,
+            "tool_calls": tool_calls_out,
+            "stop_reason": resp.stop_reason,
+        }
+
+    # ------------------------------------------------------------------
+    # Streaming
+    # ------------------------------------------------------------------
+
+    async def generate_messages_stream(
+        self,
+        messages: list[dict[str, Any]],
+        options: dict[str, Any],
+    ) -> AsyncIterator[dict[str, Any]]:
+        """Yield response chunks via Anthropic streaming API."""
+        if anthropic is None:
+            raise RuntimeError("anthropic package not installed")
+
+        opts = {**{"temperature": 0.0, "max_tokens": 512}, **options}
+        model = options.get("model", self.model)
+        client = anthropic.AsyncAnthropic(api_key=self.api_key)
+
+        system_content, api_messages = self._extract_system_and_messages(messages)
+
+        kwargs: dict[str, Any] = {
+            "model": model,
+            "messages": api_messages,
+            "temperature": opts["temperature"],
+            "max_tokens": opts["max_tokens"],
+        }
+        if system_content:
+            kwargs["system"] = system_content
+
+        full_text = ""
+        prompt_tokens = 0
+        completion_tokens = 0
+
+        async with client.messages.stream(**kwargs) as stream:
+            async for event in stream:
+                if hasattr(event, "type"):
+                    if event.type == "content_block_delta" and hasattr(event, "delta"):
+                        delta_text = getattr(event.delta, "text", "")
+                        if delta_text:
+                            full_text += delta_text
+                            yield {"type": "delta", "text": delta_text}
+                    elif event.type == "message_delta" and hasattr(event, "usage"):
+                        completion_tokens = getattr(event.usage, "output_tokens", 0)
+                    elif event.type == "message_start" and hasattr(event, "message"):
+                        usage = getattr(event.message, "usage", None)
+                        if usage:
+                            prompt_tokens = getattr(usage, "input_tokens", 0)
+
+        total_tokens = prompt_tokens + completion_tokens
+        total_cost = self._calculate_cost("claude", model, prompt_tokens, completion_tokens)
+
+        yield {
+            "type": "done",
+            "text": full_text,
+            "meta": {
+                "prompt_tokens": prompt_tokens,
+                "completion_tokens": completion_tokens,
+                "total_tokens": total_tokens,
+                "cost": round(total_cost, 6),
+                "raw_response": {},
+                "model_name": model,
+            },
+        }
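
A hedged end-to-end sketch of the AsyncClaudeDriver surface added above. The driver construction, tool name, and model string are illustrative assumptions; the method signatures, tool-call keys, and the "delta"/"done" chunk shapes match the code in this diff, and the OpenAI-style tool definition is the public format the conversion loop accepts:

    import asyncio
    from prompture.drivers.async_claude_driver import AsyncClaudeDriver

    async def main() -> None:
        # Constructor arguments are assumptions; the diff only shows self.api_key / self.model usage.
        driver = AsyncClaudeDriver(api_key="sk-ant-...", model="claude-sonnet-4-20250514")

        # Tool use: OpenAI-style definitions are converted to Anthropic's input_schema form.
        tools = [{
            "type": "function",
            "function": {
                "name": "get_weather",  # hypothetical tool
                "description": "Look up current weather for a city.",
                "parameters": {"type": "object", "properties": {"city": {"type": "string"}}},
            },
        }]
        result = await driver.generate_messages_with_tools(
            [{"role": "user", "content": "Weather in Paris?"}], tools, {}
        )
        for call in result["tool_calls"]:
            print(call["name"], call["arguments"])

        # Streaming: "delta" chunks carry incremental text, the final "done" chunk carries meta/cost.
        async for chunk in driver.generate_messages_stream(
            [{"role": "user", "content": "Say hi."}], {}
        ):
            if chunk["type"] == "delta":
                print(chunk["text"], end="", flush=True)
            elif chunk["type"] == "done":
                print("\ncost:", chunk["meta"]["cost"])

    asyncio.run(main())
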