prompture 0.0.29.dev8__py3-none-any.whl → 0.0.38.dev2__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to their public registry, and is provided for informational purposes only.
- prompture/__init__.py +264 -23
- prompture/_version.py +34 -0
- prompture/agent.py +924 -0
- prompture/agent_types.py +156 -0
- prompture/aio/__init__.py +74 -0
- prompture/async_agent.py +880 -0
- prompture/async_conversation.py +789 -0
- prompture/async_core.py +803 -0
- prompture/async_driver.py +193 -0
- prompture/async_groups.py +551 -0
- prompture/cache.py +469 -0
- prompture/callbacks.py +55 -0
- prompture/cli.py +63 -4
- prompture/conversation.py +826 -0
- prompture/core.py +894 -263
- prompture/cost_mixin.py +51 -0
- prompture/discovery.py +187 -0
- prompture/driver.py +206 -5
- prompture/drivers/__init__.py +175 -67
- prompture/drivers/airllm_driver.py +109 -0
- prompture/drivers/async_airllm_driver.py +26 -0
- prompture/drivers/async_azure_driver.py +123 -0
- prompture/drivers/async_claude_driver.py +113 -0
- prompture/drivers/async_google_driver.py +316 -0
- prompture/drivers/async_grok_driver.py +97 -0
- prompture/drivers/async_groq_driver.py +90 -0
- prompture/drivers/async_hugging_driver.py +61 -0
- prompture/drivers/async_lmstudio_driver.py +148 -0
- prompture/drivers/async_local_http_driver.py +44 -0
- prompture/drivers/async_ollama_driver.py +135 -0
- prompture/drivers/async_openai_driver.py +102 -0
- prompture/drivers/async_openrouter_driver.py +102 -0
- prompture/drivers/async_registry.py +133 -0
- prompture/drivers/azure_driver.py +42 -9
- prompture/drivers/claude_driver.py +257 -34
- prompture/drivers/google_driver.py +295 -42
- prompture/drivers/grok_driver.py +35 -32
- prompture/drivers/groq_driver.py +33 -26
- prompture/drivers/hugging_driver.py +6 -6
- prompture/drivers/lmstudio_driver.py +97 -19
- prompture/drivers/local_http_driver.py +6 -6
- prompture/drivers/ollama_driver.py +168 -23
- prompture/drivers/openai_driver.py +184 -9
- prompture/drivers/openrouter_driver.py +37 -25
- prompture/drivers/registry.py +306 -0
- prompture/drivers/vision_helpers.py +153 -0
- prompture/field_definitions.py +106 -96
- prompture/group_types.py +147 -0
- prompture/groups.py +530 -0
- prompture/image.py +180 -0
- prompture/logging.py +80 -0
- prompture/model_rates.py +217 -0
- prompture/persistence.py +254 -0
- prompture/persona.py +482 -0
- prompture/runner.py +49 -47
- prompture/scaffold/__init__.py +1 -0
- prompture/scaffold/generator.py +84 -0
- prompture/scaffold/templates/Dockerfile.j2 +12 -0
- prompture/scaffold/templates/README.md.j2 +41 -0
- prompture/scaffold/templates/config.py.j2 +21 -0
- prompture/scaffold/templates/env.example.j2 +8 -0
- prompture/scaffold/templates/main.py.j2 +86 -0
- prompture/scaffold/templates/models.py.j2 +40 -0
- prompture/scaffold/templates/requirements.txt.j2 +5 -0
- prompture/serialization.py +218 -0
- prompture/server.py +183 -0
- prompture/session.py +117 -0
- prompture/settings.py +19 -1
- prompture/tools.py +219 -267
- prompture/tools_schema.py +254 -0
- prompture/validator.py +3 -3
- prompture-0.0.38.dev2.dist-info/METADATA +369 -0
- prompture-0.0.38.dev2.dist-info/RECORD +77 -0
- {prompture-0.0.29.dev8.dist-info → prompture-0.0.38.dev2.dist-info}/WHEEL +1 -1
- prompture-0.0.29.dev8.dist-info/METADATA +0 -368
- prompture-0.0.29.dev8.dist-info/RECORD +0 -27
- {prompture-0.0.29.dev8.dist-info → prompture-0.0.38.dev2.dist-info}/entry_points.txt +0 -0
- {prompture-0.0.29.dev8.dist-info → prompture-0.0.38.dev2.dist-info}/licenses/LICENSE +0 -0
- {prompture-0.0.29.dev8.dist-info → prompture-0.0.38.dev2.dist-info}/top_level.txt +0 -0
prompture/drivers/__init__.py
CHANGED
@@ -1,71 +1,149 @@
-
-
-
-
+"""Driver registry and factory functions.
+
+This module provides:
+- Built-in drivers for popular LLM providers
+- A pluggable registry system for custom drivers
+- Factory functions to instantiate drivers by provider/model name
+
+Custom Driver Registration:
+    from prompture import register_driver
+
+    def my_driver_factory(model=None):
+        return MyCustomDriver(model=model)
+
+    register_driver("my_provider", my_driver_factory)
+
+    # Now you can use it
+    driver = get_driver_for_model("my_provider/my-model")
+
+Entry Point Discovery:
+    Third-party packages can register drivers via entry points.
+    Add to your pyproject.toml:
+
+    [project.entry-points."prompture.drivers"]
+    my_provider = "my_package.drivers:my_driver_factory"
+"""
+
+from typing import Optional
+
+from ..settings import settings
+from .airllm_driver import AirLLMDriver
+from .async_airllm_driver import AsyncAirLLMDriver
+from .async_azure_driver import AsyncAzureDriver
+from .async_claude_driver import AsyncClaudeDriver
+from .async_google_driver import AsyncGoogleDriver
+from .async_grok_driver import AsyncGrokDriver
+from .async_groq_driver import AsyncGroqDriver
+from .async_hugging_driver import AsyncHuggingFaceDriver
+from .async_lmstudio_driver import AsyncLMStudioDriver
+from .async_local_http_driver import AsyncLocalHTTPDriver
+from .async_ollama_driver import AsyncOllamaDriver
+from .async_openai_driver import AsyncOpenAIDriver
+from .async_openrouter_driver import AsyncOpenRouterDriver
+from .async_registry import ASYNC_DRIVER_REGISTRY, get_async_driver, get_async_driver_for_model
 from .azure_driver import AzureDriver
-from .
+from .claude_driver import ClaudeDriver
 from .google_driver import GoogleDriver
+from .grok_driver import GrokDriver
 from .groq_driver import GroqDriver
+from .lmstudio_driver import LMStudioDriver
+from .local_http_driver import LocalHTTPDriver
+from .ollama_driver import OllamaDriver
+from .openai_driver import OpenAIDriver
 from .openrouter_driver import OpenRouterDriver
-from .
-
-
+from .registry import (
+    _get_sync_registry,
+    get_async_driver_factory,
+    get_driver_factory,
+    is_async_driver_registered,
+    is_driver_registered,
+    list_registered_async_drivers,
+    list_registered_drivers,
+    load_entry_point_drivers,
+    register_async_driver,
+    register_driver,
+    unregister_async_driver,
+    unregister_driver,
+)
 
-#
-
-"openai"
-
-
-
-
-
-
-
-
-
-
-),
-
+# Register built-in sync drivers
+register_driver(
+    "openai",
+    lambda model=None: OpenAIDriver(api_key=settings.openai_api_key, model=model or settings.openai_model),
+    overwrite=True,
+)
+register_driver(
+    "ollama",
+    lambda model=None: OllamaDriver(endpoint=settings.ollama_endpoint, model=model or settings.ollama_model),
+    overwrite=True,
+)
+register_driver(
+    "claude",
+    lambda model=None: ClaudeDriver(api_key=settings.claude_api_key, model=model or settings.claude_model),
+    overwrite=True,
+)
+register_driver(
+    "lmstudio",
+    lambda model=None: LMStudioDriver(
         endpoint=settings.lmstudio_endpoint,
-        model=model or settings.lmstudio_model
-
-    "azure": lambda model=None: AzureDriver(
-        api_key=settings.azure_api_key,
-        endpoint=settings.azure_api_endpoint,
-        deployment_id=settings.azure_deployment_id
-    ),
-    "local_http": lambda model=None: LocalHTTPDriver(
-        endpoint=settings.local_http_endpoint,
-        model=model
-    ),
-    "google": lambda model=None: GoogleDriver(
-        api_key=settings.google_api_key,
-        model=model or settings.google_model
-    ),
-    "groq": lambda model=None: GroqDriver(
-        api_key=settings.groq_api_key,
-        model=model or settings.groq_model
+        model=model or settings.lmstudio_model,
+        api_key=settings.lmstudio_api_key,
     ),
-
-
-
+    overwrite=True,
+)
+register_driver(
+    "azure",
+    lambda model=None: AzureDriver(
+        api_key=settings.azure_api_key, endpoint=settings.azure_api_endpoint, deployment_id=settings.azure_deployment_id
     ),
-
-
-
+    overwrite=True,
+)
+register_driver(
+    "local_http",
+    lambda model=None: LocalHTTPDriver(endpoint=settings.local_http_endpoint, model=model),
+    overwrite=True,
+)
+register_driver(
+    "google",
+    lambda model=None: GoogleDriver(api_key=settings.google_api_key, model=model or settings.google_model),
+    overwrite=True,
+)
+register_driver(
+    "groq",
+    lambda model=None: GroqDriver(api_key=settings.groq_api_key, model=model or settings.groq_model),
+    overwrite=True,
+)
+register_driver(
+    "openrouter",
+    lambda model=None: OpenRouterDriver(api_key=settings.openrouter_api_key, model=model or settings.openrouter_model),
+    overwrite=True,
+)
+register_driver(
+    "grok",
+    lambda model=None: GrokDriver(api_key=settings.grok_api_key, model=model or settings.grok_model),
+    overwrite=True,
+)
+register_driver(
+    "airllm",
+    lambda model=None: AirLLMDriver(
+        model=model or settings.airllm_model,
+        compression=settings.airllm_compression,
     ),
-
+    overwrite=True,
+)
 
+# Backwards compatibility: expose registry dict (read-only view recommended)
+DRIVER_REGISTRY = _get_sync_registry()
 
-
+
+def get_driver(provider_name: Optional[str] = None):
     """
     Factory to get a driver instance based on the provider name (legacy style).
     Uses default model from settings if not overridden.
     """
     provider = (provider_name or settings.ai_provider or "ollama").strip().lower()
-
-
-    return DRIVER_REGISTRY[provider]()  # use default model from settings
+    factory = get_driver_factory(provider)
+    return factory()  # use default model from settings
 
 
 def get_driver_for_model(model_str: str):
@@ -73,21 +151,21 @@ def get_driver_for_model(model_str: str):
     Factory to get a driver instance based on a full model string.
     Format: provider/model_id
     Example: "openai/gpt-4-turbo-preview"
-
+
     Args:
         model_str: Model identifier string. Can be either:
             - Full format: "provider/model" (e.g. "openai/gpt-4")
             - Provider only: "provider" (e.g. "openai")
-
+
     Returns:
         A configured driver instance for the specified provider/model.
-
+
     Raises:
         ValueError: If provider is invalid or format is incorrect.
     """
     if not isinstance(model_str, str):
         raise ValueError("Model string must be a string, got {type(model_str)}")
-
+
     if not model_str:
         raise ValueError("Model string cannot be empty")
 
@@ -96,25 +174,55 @@ def get_driver_for_model(model_str: str):
     provider = parts[0].lower()
     model_id = parts[1] if len(parts) > 1 else None
 
-    #
-
-
-
+    # Get factory (validates provider exists)
+    factory = get_driver_factory(provider)
+
     # Create driver with model ID if provided, otherwise use default
-    return
+    return factory(model_id)
 
 
 __all__ = [
-    "
-
-    "
-
-    "
+    "ASYNC_DRIVER_REGISTRY",
+    # Legacy registry dicts (for backwards compatibility)
+    "DRIVER_REGISTRY",
+    # Sync drivers
+    "AirLLMDriver",
+    # Async drivers
+    "AsyncAirLLMDriver",
+    "AsyncAzureDriver",
+    "AsyncClaudeDriver",
+    "AsyncGoogleDriver",
+    "AsyncGrokDriver",
+    "AsyncGroqDriver",
+    "AsyncHuggingFaceDriver",
+    "AsyncLMStudioDriver",
+    "AsyncLocalHTTPDriver",
+    "AsyncOllamaDriver",
+    "AsyncOpenAIDriver",
+    "AsyncOpenRouterDriver",
     "AzureDriver",
+    "ClaudeDriver",
     "GoogleDriver",
+    "GrokDriver",
     "GroqDriver",
+    "LMStudioDriver",
+    "LocalHTTPDriver",
+    "OllamaDriver",
+    "OpenAIDriver",
     "OpenRouterDriver",
-    "
+    "get_async_driver",
+    "get_async_driver_for_model",
+    # Factory functions
     "get_driver",
     "get_driver_for_model",
+    "is_async_driver_registered",
+    "is_driver_registered",
+    "list_registered_async_drivers",
+    "list_registered_drivers",
+    "load_entry_point_drivers",
+    "register_async_driver",
+    # Registry functions (public API)
+    "register_driver",
+    "unregister_async_driver",
+    "unregister_driver",
 ]
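
Illustrative sketch (not part of the published diff): registering a custom provider against the new registry API above, as the module docstring describes. MyCustomDriver is a hypothetical stand-in for a real driver class.

# MyCustomDriver is hypothetical; register_driver and get_driver_for_model
# are real exports per the diff above.
from prompture import register_driver
from prompture.drivers import get_driver_for_model

class MyCustomDriver:
    def __init__(self, model=None):
        self.model = model

    def generate(self, prompt, options=None):
        # A real driver would call an LLM and report token usage in "meta".
        return {"text": "ok", "meta": {"model_name": self.model}}

register_driver("my_provider", lambda model=None: MyCustomDriver(model=model))
driver = get_driver_for_model("my_provider/my-model")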
prompture/drivers/airllm_driver.py
ADDED
@@ -0,0 +1,109 @@
import logging
from typing import Any, Optional

from ..driver import Driver

logger = logging.getLogger(__name__)


class AirLLMDriver(Driver):
    """Driver for AirLLM — run large models (70B+) on consumer GPUs via
    layer-by-layer memory management.

    The ``airllm`` package is a lazy dependency: it is imported on first
    ``generate()`` call so the rest of Prompture works without it installed.
    """

    MODEL_PRICING = {"default": {"prompt": 0.0, "completion": 0.0}}

    def __init__(self, model: str = "meta-llama/Llama-2-7b-hf", compression: Optional[str] = None):
        """
        Args:
            model: HuggingFace repo ID (e.g. ``"meta-llama/Llama-2-70b-hf"``).
            compression: Optional quantization mode — ``"4bit"`` or ``"8bit"``.
        """
        self.model = model
        self.compression = compression
        self.options: dict[str, Any] = {}
        self._llm = None
        self._tokenizer = None

    # ------------------------------------------------------------------
    # Lazy model loading
    # ------------------------------------------------------------------
    def _ensure_loaded(self):
        """Load the AirLLM model and tokenizer on first use."""
        if self._llm is not None:
            return

        try:
            from airllm import AutoModel
        except ImportError:
            raise ImportError(
                "The 'airllm' package is required for the AirLLM driver. Install it with: pip install prompture[airllm]"
            ) from None

        try:
            from transformers import AutoTokenizer
        except ImportError:
            raise ImportError(
                "The 'transformers' package is required for the AirLLM driver. "
                "Install it with: pip install transformers"
            ) from None

        logger.info(f"Loading AirLLM model: {self.model} (compression={self.compression})")

        load_kwargs: dict[str, Any] = {}
        if self.compression:
            load_kwargs["compression"] = self.compression

        self._llm = AutoModel.from_pretrained(self.model, **load_kwargs)
        self._tokenizer = AutoTokenizer.from_pretrained(self.model)
        logger.info("AirLLM model loaded successfully")

    # ------------------------------------------------------------------
    # Driver interface
    # ------------------------------------------------------------------
    def generate(self, prompt: str, options: Optional[dict[str, Any]] = None) -> dict[str, Any]:
        self._ensure_loaded()

        merged_options = self.options.copy()
        if options:
            merged_options.update(options)

        max_new_tokens = merged_options.get("max_new_tokens", 256)

        # Tokenize
        input_ids = self._tokenizer(prompt, return_tensors="pt").input_ids

        prompt_tokens = input_ids.shape[1]

        logger.debug(f"AirLLM generating with max_new_tokens={max_new_tokens}, prompt_tokens={prompt_tokens}")

        # Generate
        output_ids = self._llm.generate(
            input_ids,
            max_new_tokens=max_new_tokens,
        )

        # Decode only the newly generated tokens (strip the prompt prefix)
        new_tokens = output_ids[0, prompt_tokens:]
        completion_tokens = len(new_tokens)
        text = self._tokenizer.decode(new_tokens, skip_special_tokens=True)

        total_tokens = prompt_tokens + completion_tokens

        meta = {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": total_tokens,
            "cost": 0.0,
            "raw_response": {
                "model": self.model,
                "compression": self.compression,
                "max_new_tokens": max_new_tokens,
            },
            "model_name": self.model,
        }

        return {"text": text, "meta": meta}
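
Illustrative usage sketch (not part of the published diff). It assumes the optional airllm and transformers dependencies are installed and a suitable GPU is available; weights download on the first call.

from prompture.drivers import AirLLMDriver

# Lazy loading means construction is cheap; the model loads on first generate().
driver = AirLLMDriver(model="meta-llama/Llama-2-7b-hf", compression="4bit")
result = driver.generate("Say hello.", {"max_new_tokens": 64})
print(result["text"])
print(result["meta"]["total_tokens"])  # prompt_tokens + completion_tokens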
prompture/drivers/async_airllm_driver.py
ADDED
@@ -0,0 +1,26 @@
"""Async AirLLM driver — wraps the sync GPU-bound driver with asyncio.to_thread."""

from __future__ import annotations

import asyncio
from typing import Any

from ..async_driver import AsyncDriver
from .airllm_driver import AirLLMDriver


class AsyncAirLLMDriver(AsyncDriver):
    """Async wrapper around :class:`AirLLMDriver`.

    AirLLM is GPU-bound with no native async API, so we delegate to
    ``asyncio.to_thread()`` to avoid blocking the event loop.
    """

    MODEL_PRICING = AirLLMDriver.MODEL_PRICING

    def __init__(self, model: str = "meta-llama/Llama-2-7b-hf", compression: str | None = None):
        self.model = model
        self._sync_driver = AirLLMDriver(model=model, compression=compression)

    async def generate(self, prompt: str, options: dict[str, Any] | None = None) -> dict[str, Any]:
        return await asyncio.to_thread(self._sync_driver.generate, prompt, options)
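
Illustrative usage sketch (not part of the published diff): because the sync generate() runs in a worker thread, other coroutines keep making progress while the GPU works.

import asyncio

from prompture.drivers import AsyncAirLLMDriver

async def main():
    driver = AsyncAirLLMDriver(model="meta-llama/Llama-2-7b-hf")
    # Offloaded via asyncio.to_thread(), so the event loop is not blocked.
    result = await driver.generate("Say hello.", {"max_new_tokens": 64})
    print(result["text"])

asyncio.run(main())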
prompture/drivers/async_azure_driver.py
ADDED
@@ -0,0 +1,123 @@
"""Async Azure OpenAI driver. Requires the ``openai`` package (>=1.0.0)."""

from __future__ import annotations

import os
from typing import Any

try:
    from openai import AsyncAzureOpenAI
except Exception:
    AsyncAzureOpenAI = None

from ..async_driver import AsyncDriver
from ..cost_mixin import CostMixin
from .azure_driver import AzureDriver


class AsyncAzureDriver(CostMixin, AsyncDriver):
    supports_json_mode = True
    supports_json_schema = True
    supports_vision = True

    MODEL_PRICING = AzureDriver.MODEL_PRICING

    def __init__(
        self,
        api_key: str | None = None,
        endpoint: str | None = None,
        deployment_id: str | None = None,
        model: str = "gpt-4o-mini",
    ):
        self.api_key = api_key or os.getenv("AZURE_API_KEY")
        self.endpoint = endpoint or os.getenv("AZURE_API_ENDPOINT")
        self.deployment_id = deployment_id or os.getenv("AZURE_DEPLOYMENT_ID")
        self.api_version = os.getenv("AZURE_API_VERSION", "2023-07-01-preview")
        self.model = model

        if not self.api_key:
            raise ValueError("Missing Azure API key (AZURE_API_KEY).")
        if not self.endpoint:
            raise ValueError("Missing Azure API endpoint (AZURE_API_ENDPOINT).")
        if not self.deployment_id:
            raise ValueError("Missing Azure deployment ID (AZURE_DEPLOYMENT_ID).")

        if AsyncAzureOpenAI:
            self.client = AsyncAzureOpenAI(
                api_key=self.api_key,
                api_version=self.api_version,
                azure_endpoint=self.endpoint,
            )
        else:
            self.client = None

    supports_messages = True

    def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
        from .vision_helpers import _prepare_openai_vision_messages

        return _prepare_openai_vision_messages(messages)

    async def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
        messages = [{"role": "user", "content": prompt}]
        return await self._do_generate(messages, options)

    async def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
        return await self._do_generate(self._prepare_messages(messages), options)

    async def _do_generate(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
        if self.client is None:
            raise RuntimeError("openai package (>=1.0.0) with AsyncAzureOpenAI not installed")

        model = options.get("model", self.model)
        model_info = self.MODEL_PRICING.get(model, {})
        tokens_param = model_info.get("tokens_param", "max_tokens")
        supports_temperature = model_info.get("supports_temperature", True)

        opts = {"temperature": 1.0, "max_tokens": 512, **options}

        kwargs = {
            "model": self.deployment_id,
            "messages": messages,
        }
        kwargs[tokens_param] = opts.get("max_tokens", 512)

        if supports_temperature and "temperature" in opts:
            kwargs["temperature"] = opts["temperature"]

        # Native JSON mode support
        if options.get("json_mode"):
            json_schema = options.get("json_schema")
            if json_schema:
                kwargs["response_format"] = {
                    "type": "json_schema",
                    "json_schema": {
                        "name": "extraction",
                        "strict": True,
                        "schema": json_schema,
                    },
                }
            else:
                kwargs["response_format"] = {"type": "json_object"}

        resp = await self.client.chat.completions.create(**kwargs)

        usage = getattr(resp, "usage", None)
        prompt_tokens = getattr(usage, "prompt_tokens", 0)
        completion_tokens = getattr(usage, "completion_tokens", 0)
        total_tokens = getattr(usage, "total_tokens", 0)

        total_cost = self._calculate_cost("azure", model, prompt_tokens, completion_tokens)

        meta = {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": total_tokens,
            "cost": total_cost,
            "raw_response": resp.model_dump(),
            "model_name": model,
            "deployment_id": self.deployment_id,
        }

        text = resp.choices[0].message.content
        return {"text": text, "meta": meta}
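
Illustrative sketch (not part of the published diff) of the JSON-schema path; the key, endpoint, and deployment values are placeholders.

import asyncio

from prompture.drivers import AsyncAzureDriver

schema = {
    "type": "object",
    "properties": {"name": {"type": "string"}},
    "required": ["name"],
    "additionalProperties": False,  # strict schema mode expects a closed object
}

async def main():
    driver = AsyncAzureDriver(
        api_key="...",  # placeholder; AZURE_API_KEY is read if omitted
        endpoint="https://example.openai.azure.com",
        deployment_id="my-deployment",
    )
    # json_mode + json_schema selects response_format={"type": "json_schema", ...}
    result = await driver.generate(
        "Extract the name: Alice went home.",
        {"json_mode": True, "json_schema": schema},
    )
    print(result["text"])  # JSON text matching the schema

asyncio.run(main())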
prompture/drivers/async_claude_driver.py
ADDED
@@ -0,0 +1,113 @@
"""Async Anthropic Claude driver. Requires the ``anthropic`` package."""

from __future__ import annotations

import json
import os
from typing import Any

try:
    import anthropic
except Exception:
    anthropic = None

from ..async_driver import AsyncDriver
from ..cost_mixin import CostMixin
from .claude_driver import ClaudeDriver


class AsyncClaudeDriver(CostMixin, AsyncDriver):
    supports_json_mode = True
    supports_json_schema = True
    supports_vision = True

    MODEL_PRICING = ClaudeDriver.MODEL_PRICING

    def __init__(self, api_key: str | None = None, model: str = "claude-3-5-haiku-20241022"):
        self.api_key = api_key or os.getenv("CLAUDE_API_KEY")
        self.model = model or os.getenv("CLAUDE_MODEL_NAME", "claude-3-5-haiku-20241022")

    supports_messages = True

    def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
        from .vision_helpers import _prepare_claude_vision_messages

        return _prepare_claude_vision_messages(messages)

    async def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
        messages = [{"role": "user", "content": prompt}]
        return await self._do_generate(messages, options)

    async def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
        return await self._do_generate(self._prepare_messages(messages), options)

    async def _do_generate(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
        if anthropic is None:
            raise RuntimeError("anthropic package not installed")

        opts = {**{"temperature": 0.0, "max_tokens": 512}, **options}
        model = options.get("model", self.model)

        client = anthropic.AsyncAnthropic(api_key=self.api_key)

        # Anthropic requires system messages as a top-level parameter
        system_content = None
        api_messages = []
        for msg in messages:
            if msg.get("role") == "system":
                system_content = msg.get("content", "")
            else:
                api_messages.append(msg)

        # Build common kwargs
        common_kwargs: dict[str, Any] = {
            "model": model,
            "messages": api_messages,
            "temperature": opts["temperature"],
            "max_tokens": opts["max_tokens"],
        }
        if system_content:
            common_kwargs["system"] = system_content

        # Native JSON mode: use tool-use for schema enforcement
        if options.get("json_mode"):
            json_schema = options.get("json_schema")
            if json_schema:
                tool_def = {
                    "name": "extract_json",
                    "description": "Extract structured data matching the schema",
                    "input_schema": json_schema,
                }
                resp = await client.messages.create(
                    **common_kwargs,
                    tools=[tool_def],
                    tool_choice={"type": "tool", "name": "extract_json"},
                )
                text = ""
                for block in resp.content:
                    if block.type == "tool_use":
                        text = json.dumps(block.input)
                        break
            else:
                resp = await client.messages.create(**common_kwargs)
                text = resp.content[0].text
        else:
            resp = await client.messages.create(**common_kwargs)
            text = resp.content[0].text

        prompt_tokens = resp.usage.input_tokens
        completion_tokens = resp.usage.output_tokens
        total_tokens = prompt_tokens + completion_tokens

        total_cost = self._calculate_cost("claude", model, prompt_tokens, completion_tokens)

        meta = {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": total_tokens,
            "cost": total_cost,
            "raw_response": dict(resp),
            "model_name": model,
        }

        return {"text": text, "meta": meta}
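
Illustrative sketch (not part of the published diff) of the tool-use JSON path; it assumes CLAUDE_API_KEY is set in the environment.

import asyncio

from prompture.drivers import AsyncClaudeDriver

schema = {
    "type": "object",
    "properties": {"name": {"type": "string"}},
    "required": ["name"],
}

async def main():
    driver = AsyncClaudeDriver()  # falls back to the CLAUDE_API_KEY env var
    result = await driver.generate(
        "Extract the name: Alice went home.",
        {"json_mode": True, "json_schema": schema},
    )
    # The forced extract_json tool's input comes back serialized as JSON text.
    print(result["text"])          # e.g. {"name": "Alice"}
    print(result["meta"]["cost"])  # computed via the CostMixin pricing table

asyncio.run(main())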