prompture-0.0.29.dev8-py3-none-any.whl → prompture-0.0.38.dev2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (79)
  1. prompture/__init__.py +264 -23
  2. prompture/_version.py +34 -0
  3. prompture/agent.py +924 -0
  4. prompture/agent_types.py +156 -0
  5. prompture/aio/__init__.py +74 -0
  6. prompture/async_agent.py +880 -0
  7. prompture/async_conversation.py +789 -0
  8. prompture/async_core.py +803 -0
  9. prompture/async_driver.py +193 -0
  10. prompture/async_groups.py +551 -0
  11. prompture/cache.py +469 -0
  12. prompture/callbacks.py +55 -0
  13. prompture/cli.py +63 -4
  14. prompture/conversation.py +826 -0
  15. prompture/core.py +894 -263
  16. prompture/cost_mixin.py +51 -0
  17. prompture/discovery.py +187 -0
  18. prompture/driver.py +206 -5
  19. prompture/drivers/__init__.py +175 -67
  20. prompture/drivers/airllm_driver.py +109 -0
  21. prompture/drivers/async_airllm_driver.py +26 -0
  22. prompture/drivers/async_azure_driver.py +123 -0
  23. prompture/drivers/async_claude_driver.py +113 -0
  24. prompture/drivers/async_google_driver.py +316 -0
  25. prompture/drivers/async_grok_driver.py +97 -0
  26. prompture/drivers/async_groq_driver.py +90 -0
  27. prompture/drivers/async_hugging_driver.py +61 -0
  28. prompture/drivers/async_lmstudio_driver.py +148 -0
  29. prompture/drivers/async_local_http_driver.py +44 -0
  30. prompture/drivers/async_ollama_driver.py +135 -0
  31. prompture/drivers/async_openai_driver.py +102 -0
  32. prompture/drivers/async_openrouter_driver.py +102 -0
  33. prompture/drivers/async_registry.py +133 -0
  34. prompture/drivers/azure_driver.py +42 -9
  35. prompture/drivers/claude_driver.py +257 -34
  36. prompture/drivers/google_driver.py +295 -42
  37. prompture/drivers/grok_driver.py +35 -32
  38. prompture/drivers/groq_driver.py +33 -26
  39. prompture/drivers/hugging_driver.py +6 -6
  40. prompture/drivers/lmstudio_driver.py +97 -19
  41. prompture/drivers/local_http_driver.py +6 -6
  42. prompture/drivers/ollama_driver.py +168 -23
  43. prompture/drivers/openai_driver.py +184 -9
  44. prompture/drivers/openrouter_driver.py +37 -25
  45. prompture/drivers/registry.py +306 -0
  46. prompture/drivers/vision_helpers.py +153 -0
  47. prompture/field_definitions.py +106 -96
  48. prompture/group_types.py +147 -0
  49. prompture/groups.py +530 -0
  50. prompture/image.py +180 -0
  51. prompture/logging.py +80 -0
  52. prompture/model_rates.py +217 -0
  53. prompture/persistence.py +254 -0
  54. prompture/persona.py +482 -0
  55. prompture/runner.py +49 -47
  56. prompture/scaffold/__init__.py +1 -0
  57. prompture/scaffold/generator.py +84 -0
  58. prompture/scaffold/templates/Dockerfile.j2 +12 -0
  59. prompture/scaffold/templates/README.md.j2 +41 -0
  60. prompture/scaffold/templates/config.py.j2 +21 -0
  61. prompture/scaffold/templates/env.example.j2 +8 -0
  62. prompture/scaffold/templates/main.py.j2 +86 -0
  63. prompture/scaffold/templates/models.py.j2 +40 -0
  64. prompture/scaffold/templates/requirements.txt.j2 +5 -0
  65. prompture/serialization.py +218 -0
  66. prompture/server.py +183 -0
  67. prompture/session.py +117 -0
  68. prompture/settings.py +19 -1
  69. prompture/tools.py +219 -267
  70. prompture/tools_schema.py +254 -0
  71. prompture/validator.py +3 -3
  72. prompture-0.0.38.dev2.dist-info/METADATA +369 -0
  73. prompture-0.0.38.dev2.dist-info/RECORD +77 -0
  74. {prompture-0.0.29.dev8.dist-info → prompture-0.0.38.dev2.dist-info}/WHEEL +1 -1
  75. prompture-0.0.29.dev8.dist-info/METADATA +0 -368
  76. prompture-0.0.29.dev8.dist-info/RECORD +0 -27
  77. {prompture-0.0.29.dev8.dist-info → prompture-0.0.38.dev2.dist-info}/entry_points.txt +0 -0
  78. {prompture-0.0.29.dev8.dist-info → prompture-0.0.38.dev2.dist-info}/licenses/LICENSE +0 -0
  79. {prompture-0.0.29.dev8.dist-info → prompture-0.0.38.dev2.dist-info}/top_level.txt +0 -0
--- a/prompture/drivers/__init__.py
+++ b/prompture/drivers/__init__.py
@@ -1,71 +1,149 @@
-from .openai_driver import OpenAIDriver
-from .local_http_driver import LocalHTTPDriver
-from .ollama_driver import OllamaDriver
-from .claude_driver import ClaudeDriver
+"""Driver registry and factory functions.
+
+This module provides:
+- Built-in drivers for popular LLM providers
+- A pluggable registry system for custom drivers
+- Factory functions to instantiate drivers by provider/model name
+
+Custom Driver Registration:
+    from prompture import register_driver
+
+    def my_driver_factory(model=None):
+        return MyCustomDriver(model=model)
+
+    register_driver("my_provider", my_driver_factory)
+
+    # Now you can use it
+    driver = get_driver_for_model("my_provider/my-model")
+
+Entry Point Discovery:
+    Third-party packages can register drivers via entry points.
+    Add to your pyproject.toml:
+
+    [project.entry-points."prompture.drivers"]
+    my_provider = "my_package.drivers:my_driver_factory"
+"""
+
+from typing import Optional
+
+from ..settings import settings
+from .airllm_driver import AirLLMDriver
+from .async_airllm_driver import AsyncAirLLMDriver
+from .async_azure_driver import AsyncAzureDriver
+from .async_claude_driver import AsyncClaudeDriver
+from .async_google_driver import AsyncGoogleDriver
+from .async_grok_driver import AsyncGrokDriver
+from .async_groq_driver import AsyncGroqDriver
+from .async_hugging_driver import AsyncHuggingFaceDriver
+from .async_lmstudio_driver import AsyncLMStudioDriver
+from .async_local_http_driver import AsyncLocalHTTPDriver
+from .async_ollama_driver import AsyncOllamaDriver
+from .async_openai_driver import AsyncOpenAIDriver
+from .async_openrouter_driver import AsyncOpenRouterDriver
+from .async_registry import ASYNC_DRIVER_REGISTRY, get_async_driver, get_async_driver_for_model
 from .azure_driver import AzureDriver
-from .lmstudio_driver import LMStudioDriver
+from .claude_driver import ClaudeDriver
 from .google_driver import GoogleDriver
+from .grok_driver import GrokDriver
 from .groq_driver import GroqDriver
+from .lmstudio_driver import LMStudioDriver
+from .local_http_driver import LocalHTTPDriver
+from .ollama_driver import OllamaDriver
+from .openai_driver import OpenAIDriver
 from .openrouter_driver import OpenRouterDriver
-from .grok_driver import GrokDriver
-from ..settings import settings
-
+from .registry import (
+    _get_sync_registry,
+    get_async_driver_factory,
+    get_driver_factory,
+    is_async_driver_registered,
+    is_driver_registered,
+    list_registered_async_drivers,
+    list_registered_drivers,
+    load_entry_point_drivers,
+    register_async_driver,
+    register_driver,
+    unregister_async_driver,
+    unregister_driver,
+)
 
-# Central registry: maps provider → factory function
-DRIVER_REGISTRY = {
-    "openai": lambda model=None: OpenAIDriver(
-        api_key=settings.openai_api_key,
-        model=model or settings.openai_model
-    ),
-    "ollama": lambda model=None: OllamaDriver(
-        endpoint=settings.ollama_endpoint,
-        model=model or settings.ollama_model
-    ),
-    "claude": lambda model=None: ClaudeDriver(
-        api_key=settings.claude_api_key,
-        model=model or settings.claude_model
-    ),
-    "lmstudio": lambda model=None: LMStudioDriver(
+# Register built-in sync drivers
+register_driver(
+    "openai",
+    lambda model=None: OpenAIDriver(api_key=settings.openai_api_key, model=model or settings.openai_model),
+    overwrite=True,
+)
+register_driver(
+    "ollama",
+    lambda model=None: OllamaDriver(endpoint=settings.ollama_endpoint, model=model or settings.ollama_model),
+    overwrite=True,
+)
+register_driver(
+    "claude",
+    lambda model=None: ClaudeDriver(api_key=settings.claude_api_key, model=model or settings.claude_model),
+    overwrite=True,
+)
+register_driver(
+    "lmstudio",
+    lambda model=None: LMStudioDriver(
         endpoint=settings.lmstudio_endpoint,
-        model=model or settings.lmstudio_model
-    ),
-    "azure": lambda model=None: AzureDriver(
-        api_key=settings.azure_api_key,
-        endpoint=settings.azure_api_endpoint,
-        deployment_id=settings.azure_deployment_id
-    ),
-    "local_http": lambda model=None: LocalHTTPDriver(
-        endpoint=settings.local_http_endpoint,
-        model=model
-    ),
-    "google": lambda model=None: GoogleDriver(
-        api_key=settings.google_api_key,
-        model=model or settings.google_model
-    ),
-    "groq": lambda model=None: GroqDriver(
-        api_key=settings.groq_api_key,
-        model=model or settings.groq_model
+        model=model or settings.lmstudio_model,
+        api_key=settings.lmstudio_api_key,
     ),
-    "openrouter": lambda model=None: OpenRouterDriver(
-        api_key=settings.openrouter_api_key,
-        model=model or settings.openrouter_model
+    overwrite=True,
+)
+register_driver(
+    "azure",
+    lambda model=None: AzureDriver(
+        api_key=settings.azure_api_key, endpoint=settings.azure_api_endpoint, deployment_id=settings.azure_deployment_id
     ),
-    "grok": lambda model=None: GrokDriver(
-        api_key=settings.grok_api_key,
-        model=model or settings.grok_model
+    overwrite=True,
+)
+register_driver(
+    "local_http",
+    lambda model=None: LocalHTTPDriver(endpoint=settings.local_http_endpoint, model=model),
+    overwrite=True,
+)
+register_driver(
+    "google",
+    lambda model=None: GoogleDriver(api_key=settings.google_api_key, model=model or settings.google_model),
+    overwrite=True,
+)
+register_driver(
+    "groq",
+    lambda model=None: GroqDriver(api_key=settings.groq_api_key, model=model or settings.groq_model),
+    overwrite=True,
+)
+register_driver(
+    "openrouter",
+    lambda model=None: OpenRouterDriver(api_key=settings.openrouter_api_key, model=model or settings.openrouter_model),
+    overwrite=True,
+)
+register_driver(
+    "grok",
+    lambda model=None: GrokDriver(api_key=settings.grok_api_key, model=model or settings.grok_model),
+    overwrite=True,
+)
+register_driver(
+    "airllm",
+    lambda model=None: AirLLMDriver(
+        model=model or settings.airllm_model,
+        compression=settings.airllm_compression,
     ),
-}
+    overwrite=True,
+)
 
+# Backwards compatibility: expose registry dict (read-only view recommended)
+DRIVER_REGISTRY = _get_sync_registry()
 
-def get_driver(provider_name: str = None):
+
+def get_driver(provider_name: Optional[str] = None):
     """
     Factory to get a driver instance based on the provider name (legacy style).
     Uses default model from settings if not overridden.
     """
     provider = (provider_name or settings.ai_provider or "ollama").strip().lower()
-    if provider not in DRIVER_REGISTRY:
-        raise ValueError(f"Unknown provider: {provider_name}")
-    return DRIVER_REGISTRY[provider]()  # use default model from settings
+    factory = get_driver_factory(provider)
+    return factory()  # use default model from settings
 
 
 def get_driver_for_model(model_str: str):
@@ -73,21 +151,21 @@ def get_driver_for_model(model_str: str):
     Factory to get a driver instance based on a full model string.
     Format: provider/model_id
     Example: "openai/gpt-4-turbo-preview"
-
+
     Args:
         model_str: Model identifier string. Can be either:
             - Full format: "provider/model" (e.g. "openai/gpt-4")
             - Provider only: "provider" (e.g. "openai")
-
+
     Returns:
         A configured driver instance for the specified provider/model.
-
+
     Raises:
         ValueError: If provider is invalid or format is incorrect.
     """
     if not isinstance(model_str, str):
         raise ValueError("Model string must be a string, got {type(model_str)}")
-
+
     if not model_str:
         raise ValueError("Model string cannot be empty")
 
@@ -96,25 +174,55 @@ def get_driver_for_model(model_str: str):
     provider = parts[0].lower()
     model_id = parts[1] if len(parts) > 1 else None
 
-    # Validate provider
-    if provider not in DRIVER_REGISTRY:
-        raise ValueError(f"Unsupported provider '{provider}'")
-
+    # Get factory (validates provider exists)
+    factory = get_driver_factory(provider)
+
     # Create driver with model ID if provided, otherwise use default
-    return DRIVER_REGISTRY[provider](model_id)
+    return factory(model_id)
 
 
 __all__ = [
-    "OpenAIDriver",
-    "LocalHTTPDriver",
-    "OllamaDriver",
-    "ClaudeDriver",
-    "LMStudioDriver",
+    "ASYNC_DRIVER_REGISTRY",
+    # Legacy registry dicts (for backwards compatibility)
+    "DRIVER_REGISTRY",
+    # Sync drivers
+    "AirLLMDriver",
+    # Async drivers
+    "AsyncAirLLMDriver",
+    "AsyncAzureDriver",
+    "AsyncClaudeDriver",
+    "AsyncGoogleDriver",
+    "AsyncGrokDriver",
+    "AsyncGroqDriver",
+    "AsyncHuggingFaceDriver",
+    "AsyncLMStudioDriver",
+    "AsyncLocalHTTPDriver",
+    "AsyncOllamaDriver",
+    "AsyncOpenAIDriver",
+    "AsyncOpenRouterDriver",
     "AzureDriver",
+    "ClaudeDriver",
     "GoogleDriver",
+    "GrokDriver",
     "GroqDriver",
+    "LMStudioDriver",
+    "LocalHTTPDriver",
+    "OllamaDriver",
+    "OpenAIDriver",
     "OpenRouterDriver",
-    "GrokDriver",
+    "get_async_driver",
+    "get_async_driver_for_model",
+    # Factory functions
     "get_driver",
     "get_driver_for_model",
+    "is_async_driver_registered",
+    "is_driver_registered",
+    "list_registered_async_drivers",
+    "list_registered_drivers",
+    "load_entry_point_drivers",
+    "register_async_driver",
+    # Registry functions (public API)
+    "register_driver",
+    "unregister_async_driver",
+    "unregister_driver",
 ]
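
The new pluggable registry replaces the old module-level dict while keeping the same provider/model resolution. A minimal usage sketch of the exported API (`register_driver`, `is_driver_registered`, `get_driver_for_model` are the functions shown above; the `EchoDriver` class is hypothetical, and a real driver would subclass `prompture.driver.Driver`):

    from prompture import register_driver
    from prompture.drivers import get_driver_for_model, is_driver_registered

    class EchoDriver:  # hypothetical stand-in for a real Driver subclass
        def __init__(self, model=None):
            self.model = model

        def generate(self, prompt, options=None):
            # echoes the prompt back; a real driver would call an LLM here
            return {"text": prompt, "meta": {"model_name": self.model, "cost": 0.0}}

    # overwrite=True mirrors how the built-ins above are (re)registered
    register_driver("echo", lambda model=None: EchoDriver(model=model), overwrite=True)
    assert is_driver_registered("echo")

    driver = get_driver_for_model("echo/demo-model")  # "provider/model_id" format
    print(driver.generate("hello")["text"])           # -> "hello"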
--- /dev/null
+++ b/prompture/drivers/airllm_driver.py
@@ -0,0 +1,109 @@
+import logging
+from typing import Any, Optional
+
+from ..driver import Driver
+
+logger = logging.getLogger(__name__)
+
+
+class AirLLMDriver(Driver):
+    """Driver for AirLLM — run large models (70B+) on consumer GPUs via
+    layer-by-layer memory management.
+
+    The ``airllm`` package is a lazy dependency: it is imported on first
+    ``generate()`` call so the rest of Prompture works without it installed.
+    """
+
+    MODEL_PRICING = {"default": {"prompt": 0.0, "completion": 0.0}}
+
+    def __init__(self, model: str = "meta-llama/Llama-2-7b-hf", compression: Optional[str] = None):
+        """
+        Args:
+            model: HuggingFace repo ID (e.g. ``"meta-llama/Llama-2-70b-hf"``).
+            compression: Optional quantization mode — ``"4bit"`` or ``"8bit"``.
+        """
+        self.model = model
+        self.compression = compression
+        self.options: dict[str, Any] = {}
+        self._llm = None
+        self._tokenizer = None
+
+    # ------------------------------------------------------------------
+    # Lazy model loading
+    # ------------------------------------------------------------------
+    def _ensure_loaded(self):
+        """Load the AirLLM model and tokenizer on first use."""
+        if self._llm is not None:
+            return
+
+        try:
+            from airllm import AutoModel
+        except ImportError:
+            raise ImportError(
+                "The 'airllm' package is required for the AirLLM driver. Install it with: pip install prompture[airllm]"
+            ) from None
+
+        try:
+            from transformers import AutoTokenizer
+        except ImportError:
+            raise ImportError(
+                "The 'transformers' package is required for the AirLLM driver. "
+                "Install it with: pip install transformers"
+            ) from None
+
+        logger.info(f"Loading AirLLM model: {self.model} (compression={self.compression})")
+
+        load_kwargs: dict[str, Any] = {}
+        if self.compression:
+            load_kwargs["compression"] = self.compression
+
+        self._llm = AutoModel.from_pretrained(self.model, **load_kwargs)
+        self._tokenizer = AutoTokenizer.from_pretrained(self.model)
+        logger.info("AirLLM model loaded successfully")
+
+    # ------------------------------------------------------------------
+    # Driver interface
+    # ------------------------------------------------------------------
+    def generate(self, prompt: str, options: Optional[dict[str, Any]] = None) -> dict[str, Any]:
+        self._ensure_loaded()
+
+        merged_options = self.options.copy()
+        if options:
+            merged_options.update(options)
+
+        max_new_tokens = merged_options.get("max_new_tokens", 256)
+
+        # Tokenize
+        input_ids = self._tokenizer(prompt, return_tensors="pt").input_ids
+
+        prompt_tokens = input_ids.shape[1]
+
+        logger.debug(f"AirLLM generating with max_new_tokens={max_new_tokens}, prompt_tokens={prompt_tokens}")
+
+        # Generate
+        output_ids = self._llm.generate(
+            input_ids,
+            max_new_tokens=max_new_tokens,
+        )
+
+        # Decode only the newly generated tokens (strip the prompt prefix)
+        new_tokens = output_ids[0, prompt_tokens:]
+        completion_tokens = len(new_tokens)
+        text = self._tokenizer.decode(new_tokens, skip_special_tokens=True)
+
+        total_tokens = prompt_tokens + completion_tokens
+
+        meta = {
+            "prompt_tokens": prompt_tokens,
+            "completion_tokens": completion_tokens,
+            "total_tokens": total_tokens,
+            "cost": 0.0,
+            "raw_response": {
+                "model": self.model,
+                "compression": self.compression,
+                "max_new_tokens": max_new_tokens,
+            },
+            "model_name": self.model,
+        }
+
+        return {"text": text, "meta": meta}
--- /dev/null
+++ b/prompture/drivers/async_airllm_driver.py
@@ -0,0 +1,26 @@
+"""Async AirLLM driver — wraps the sync GPU-bound driver with asyncio.to_thread."""
+
+from __future__ import annotations
+
+import asyncio
+from typing import Any
+
+from ..async_driver import AsyncDriver
+from .airllm_driver import AirLLMDriver
+
+
+class AsyncAirLLMDriver(AsyncDriver):
+    """Async wrapper around :class:`AirLLMDriver`.
+
+    AirLLM is GPU-bound with no native async API, so we delegate to
+    ``asyncio.to_thread()`` to avoid blocking the event loop.
+    """
+
+    MODEL_PRICING = AirLLMDriver.MODEL_PRICING
+
+    def __init__(self, model: str = "meta-llama/Llama-2-7b-hf", compression: str | None = None):
+        self.model = model
+        self._sync_driver = AirLLMDriver(model=model, compression=compression)
+
+    async def generate(self, prompt: str, options: dict[str, Any] | None = None) -> dict[str, Any]:
+        return await asyncio.to_thread(self._sync_driver.generate, prompt, options)
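
Because generation is pushed into a worker thread via `asyncio.to_thread()`, other coroutines keep running while the GPU is busy. A minimal sketch under the same dependency assumptions as the sync driver:

    import asyncio

    from prompture.drivers import AsyncAirLLMDriver

    async def heartbeat():
        # keeps ticking while generate() runs in its worker thread,
        # demonstrating that the event loop is not blocked
        for _ in range(3):
            await asyncio.sleep(1)
            print("event loop still responsive")

    async def main():
        driver = AsyncAirLLMDriver(model="meta-llama/Llama-2-7b-hf")
        result, _ = await asyncio.gather(
            driver.generate("Say hello.", {"max_new_tokens": 16}),
            heartbeat(),
        )
        print(result["text"])

    asyncio.run(main())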
--- /dev/null
+++ b/prompture/drivers/async_azure_driver.py
@@ -0,0 +1,123 @@
+"""Async Azure OpenAI driver. Requires the ``openai`` package (>=1.0.0)."""
+
+from __future__ import annotations
+
+import os
+from typing import Any
+
+try:
+    from openai import AsyncAzureOpenAI
+except Exception:
+    AsyncAzureOpenAI = None
+
+from ..async_driver import AsyncDriver
+from ..cost_mixin import CostMixin
+from .azure_driver import AzureDriver
+
+
+class AsyncAzureDriver(CostMixin, AsyncDriver):
+    supports_json_mode = True
+    supports_json_schema = True
+    supports_vision = True
+
+    MODEL_PRICING = AzureDriver.MODEL_PRICING
+
+    def __init__(
+        self,
+        api_key: str | None = None,
+        endpoint: str | None = None,
+        deployment_id: str | None = None,
+        model: str = "gpt-4o-mini",
+    ):
+        self.api_key = api_key or os.getenv("AZURE_API_KEY")
+        self.endpoint = endpoint or os.getenv("AZURE_API_ENDPOINT")
+        self.deployment_id = deployment_id or os.getenv("AZURE_DEPLOYMENT_ID")
+        self.api_version = os.getenv("AZURE_API_VERSION", "2023-07-01-preview")
+        self.model = model
+
+        if not self.api_key:
+            raise ValueError("Missing Azure API key (AZURE_API_KEY).")
+        if not self.endpoint:
+            raise ValueError("Missing Azure API endpoint (AZURE_API_ENDPOINT).")
+        if not self.deployment_id:
+            raise ValueError("Missing Azure deployment ID (AZURE_DEPLOYMENT_ID).")
+
+        if AsyncAzureOpenAI:
+            self.client = AsyncAzureOpenAI(
+                api_key=self.api_key,
+                api_version=self.api_version,
+                azure_endpoint=self.endpoint,
+            )
+        else:
+            self.client = None
+
+    supports_messages = True
+
+    def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+        from .vision_helpers import _prepare_openai_vision_messages
+
+        return _prepare_openai_vision_messages(messages)
+
+    async def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
+        messages = [{"role": "user", "content": prompt}]
+        return await self._do_generate(messages, options)
+
+    async def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
+        return await self._do_generate(self._prepare_messages(messages), options)
+
+    async def _do_generate(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
+        if self.client is None:
+            raise RuntimeError("openai package (>=1.0.0) with AsyncAzureOpenAI not installed")
+
+        model = options.get("model", self.model)
+        model_info = self.MODEL_PRICING.get(model, {})
+        tokens_param = model_info.get("tokens_param", "max_tokens")
+        supports_temperature = model_info.get("supports_temperature", True)
+
+        opts = {"temperature": 1.0, "max_tokens": 512, **options}
+
+        kwargs = {
+            "model": self.deployment_id,
+            "messages": messages,
+        }
+        kwargs[tokens_param] = opts.get("max_tokens", 512)
+
+        if supports_temperature and "temperature" in opts:
+            kwargs["temperature"] = opts["temperature"]
+
+        # Native JSON mode support
+        if options.get("json_mode"):
+            json_schema = options.get("json_schema")
+            if json_schema:
+                kwargs["response_format"] = {
+                    "type": "json_schema",
+                    "json_schema": {
+                        "name": "extraction",
+                        "strict": True,
+                        "schema": json_schema,
+                    },
+                }
+            else:
+                kwargs["response_format"] = {"type": "json_object"}
+
+        resp = await self.client.chat.completions.create(**kwargs)
+
+        usage = getattr(resp, "usage", None)
+        prompt_tokens = getattr(usage, "prompt_tokens", 0)
+        completion_tokens = getattr(usage, "completion_tokens", 0)
+        total_tokens = getattr(usage, "total_tokens", 0)
+
+        total_cost = self._calculate_cost("azure", model, prompt_tokens, completion_tokens)
+
+        meta = {
+            "prompt_tokens": prompt_tokens,
+            "completion_tokens": completion_tokens,
+            "total_tokens": total_tokens,
+            "cost": total_cost,
+            "raw_response": resp.model_dump(),
+            "model_name": model,
+            "deployment_id": self.deployment_id,
+        }
+
+        text = resp.choices[0].message.content
+        return {"text": text, "meta": meta}
--- /dev/null
+++ b/prompture/drivers/async_claude_driver.py
@@ -0,0 +1,113 @@
+"""Async Anthropic Claude driver. Requires the ``anthropic`` package."""
+
+from __future__ import annotations
+
+import json
+import os
+from typing import Any
+
+try:
+    import anthropic
+except Exception:
+    anthropic = None
+
+from ..async_driver import AsyncDriver
+from ..cost_mixin import CostMixin
+from .claude_driver import ClaudeDriver
+
+
+class AsyncClaudeDriver(CostMixin, AsyncDriver):
+    supports_json_mode = True
+    supports_json_schema = True
+    supports_vision = True
+
+    MODEL_PRICING = ClaudeDriver.MODEL_PRICING
+
+    def __init__(self, api_key: str | None = None, model: str = "claude-3-5-haiku-20241022"):
+        self.api_key = api_key or os.getenv("CLAUDE_API_KEY")
+        self.model = model or os.getenv("CLAUDE_MODEL_NAME", "claude-3-5-haiku-20241022")
+
+    supports_messages = True
+
+    def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+        from .vision_helpers import _prepare_claude_vision_messages
+
+        return _prepare_claude_vision_messages(messages)
+
+    async def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
+        messages = [{"role": "user", "content": prompt}]
+        return await self._do_generate(messages, options)
+
+    async def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
+        return await self._do_generate(self._prepare_messages(messages), options)
+
+    async def _do_generate(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
+        if anthropic is None:
+            raise RuntimeError("anthropic package not installed")
+
+        opts = {**{"temperature": 0.0, "max_tokens": 512}, **options}
+        model = options.get("model", self.model)
+
+        client = anthropic.AsyncAnthropic(api_key=self.api_key)
+
+        # Anthropic requires system messages as a top-level parameter
+        system_content = None
+        api_messages = []
+        for msg in messages:
+            if msg.get("role") == "system":
+                system_content = msg.get("content", "")
+            else:
+                api_messages.append(msg)
+
+        # Build common kwargs
+        common_kwargs: dict[str, Any] = {
+            "model": model,
+            "messages": api_messages,
+            "temperature": opts["temperature"],
+            "max_tokens": opts["max_tokens"],
+        }
+        if system_content:
+            common_kwargs["system"] = system_content
+
+        # Native JSON mode: use tool-use for schema enforcement
+        if options.get("json_mode"):
+            json_schema = options.get("json_schema")
+            if json_schema:
+                tool_def = {
+                    "name": "extract_json",
+                    "description": "Extract structured data matching the schema",
+                    "input_schema": json_schema,
+                }
+                resp = await client.messages.create(
+                    **common_kwargs,
+                    tools=[tool_def],
+                    tool_choice={"type": "tool", "name": "extract_json"},
+                )
+                text = ""
+                for block in resp.content:
+                    if block.type == "tool_use":
+                        text = json.dumps(block.input)
+                        break
+            else:
+                resp = await client.messages.create(**common_kwargs)
+                text = resp.content[0].text
+        else:
+            resp = await client.messages.create(**common_kwargs)
+            text = resp.content[0].text
+
+        prompt_tokens = resp.usage.input_tokens
+        completion_tokens = resp.usage.output_tokens
+        total_tokens = prompt_tokens + completion_tokens
+
+        total_cost = self._calculate_cost("claude", model, prompt_tokens, completion_tokens)
+
+        meta = {
+            "prompt_tokens": prompt_tokens,
+            "completion_tokens": completion_tokens,
+            "total_tokens": total_tokens,
+            "cost": total_cost,
+            "raw_response": dict(resp),
+            "model_name": model,
+        }
+
+        return {"text": text, "meta": meta}