prompture 0.0.33.dev2__py3-none-any.whl → 0.0.34.dev1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- prompture/__init__.py +112 -54
- prompture/_version.py +34 -0
- prompture/aio/__init__.py +74 -0
- prompture/async_conversation.py +484 -0
- prompture/async_core.py +803 -0
- prompture/async_driver.py +131 -0
- prompture/cache.py +469 -0
- prompture/callbacks.py +50 -0
- prompture/cli.py +7 -3
- prompture/conversation.py +504 -0
- prompture/core.py +475 -352
- prompture/cost_mixin.py +51 -0
- prompture/discovery.py +41 -36
- prompture/driver.py +125 -5
- prompture/drivers/__init__.py +63 -57
- prompture/drivers/airllm_driver.py +13 -20
- prompture/drivers/async_airllm_driver.py +26 -0
- prompture/drivers/async_azure_driver.py +117 -0
- prompture/drivers/async_claude_driver.py +107 -0
- prompture/drivers/async_google_driver.py +132 -0
- prompture/drivers/async_grok_driver.py +91 -0
- prompture/drivers/async_groq_driver.py +84 -0
- prompture/drivers/async_hugging_driver.py +61 -0
- prompture/drivers/async_lmstudio_driver.py +79 -0
- prompture/drivers/async_local_http_driver.py +44 -0
- prompture/drivers/async_ollama_driver.py +125 -0
- prompture/drivers/async_openai_driver.py +96 -0
- prompture/drivers/async_openrouter_driver.py +96 -0
- prompture/drivers/async_registry.py +80 -0
- prompture/drivers/azure_driver.py +36 -15
- prompture/drivers/claude_driver.py +86 -40
- prompture/drivers/google_driver.py +86 -58
- prompture/drivers/grok_driver.py +29 -38
- prompture/drivers/groq_driver.py +27 -32
- prompture/drivers/hugging_driver.py +6 -6
- prompture/drivers/lmstudio_driver.py +26 -13
- prompture/drivers/local_http_driver.py +6 -6
- prompture/drivers/ollama_driver.py +90 -23
- prompture/drivers/openai_driver.py +36 -15
- prompture/drivers/openrouter_driver.py +31 -31
- prompture/field_definitions.py +106 -96
- prompture/logging.py +80 -0
- prompture/model_rates.py +16 -15
- prompture/runner.py +49 -47
- prompture/session.py +117 -0
- prompture/settings.py +11 -1
- prompture/tools.py +172 -265
- prompture/validator.py +3 -3
- {prompture-0.0.33.dev2.dist-info → prompture-0.0.34.dev1.dist-info}/METADATA +18 -20
- prompture-0.0.34.dev1.dist-info/RECORD +54 -0
- prompture-0.0.33.dev2.dist-info/RECORD +0 -30
- {prompture-0.0.33.dev2.dist-info → prompture-0.0.34.dev1.dist-info}/WHEEL +0 -0
- {prompture-0.0.33.dev2.dist-info → prompture-0.0.34.dev1.dist-info}/entry_points.txt +0 -0
- {prompture-0.0.33.dev2.dist-info → prompture-0.0.34.dev1.dist-info}/licenses/LICENSE +0 -0
- {prompture-0.0.33.dev2.dist-info → prompture-0.0.34.dev1.dist-info}/top_level.txt +0 -0
prompture/cost_mixin.py
ADDED
@@ -0,0 +1,51 @@
+"""Shared cost-calculation mixin for LLM drivers."""
+
+from __future__ import annotations
+
+from typing import Any
+
+
+class CostMixin:
+    """Mixin that provides ``_calculate_cost`` to sync and async drivers.
+
+    Drivers that charge per-token should inherit from this mixin alongside
+    their base class (``Driver`` or ``AsyncDriver``). Free/local drivers
+    (Ollama, LM Studio, LocalHTTP, HuggingFace, AirLLM) can skip it.
+    """
+
+    # Subclasses must define MODEL_PRICING as a class attribute.
+    MODEL_PRICING: dict[str, dict[str, Any]] = {}
+
+    # Divisor for hardcoded MODEL_PRICING rates.
+    # Most drivers use per-1K-token pricing (1_000).
+    # Grok uses per-1M-token pricing (1_000_000).
+    # Google uses per-1M-character pricing (1_000_000).
+    _PRICING_UNIT: int = 1_000
+
+    def _calculate_cost(
+        self,
+        provider: str,
+        model: str,
+        prompt_tokens: int | float,
+        completion_tokens: int | float,
+    ) -> float:
+        """Calculate USD cost for a generation call.
+
+        Resolution order:
+        1. Live rates from ``model_rates.get_model_rates()`` (per 1M tokens).
+        2. Hardcoded ``MODEL_PRICING`` on the driver class (unit set by ``_PRICING_UNIT``).
+        3. Zero if neither source has data.
+        """
+        from .model_rates import get_model_rates
+
+        live_rates = get_model_rates(provider, model)
+        if live_rates:
+            prompt_cost = (prompt_tokens / 1_000_000) * live_rates["input"]
+            completion_cost = (completion_tokens / 1_000_000) * live_rates["output"]
+        else:
+            unit = self._PRICING_UNIT
+            model_pricing = self.MODEL_PRICING.get(model, {"prompt": 0, "completion": 0})
+            prompt_cost = (prompt_tokens / unit) * model_pricing["prompt"]
+            completion_cost = (completion_tokens / unit) * model_pricing["completion"]
+
+        return round(prompt_cost + completion_cost, 6)
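The mixin prefers live per-1M-token rates and only falls back to the class-level table divided by `_PRICING_UNIT`. A minimal usage sketch of the fallback path; `MyPaidDriver`, its provider name, and its rates are hypothetical (not part of the package), assuming no live entry exists for the pair:

from prompture.cost_mixin import CostMixin
from prompture.driver import Driver


class MyPaidDriver(CostMixin, Driver):
    # Hypothetical per-1K-token rates (matching the default _PRICING_UNIT of 1_000).
    MODEL_PRICING = {"my-model": {"prompt": 0.03, "completion": 0.06}}


driver = MyPaidDriver()
# With no live rates: (500 / 1_000) * 0.03 + (200 / 1_000) * 0.06 = 0.027
print(driver._calculate_cost("myprovider", "my-model", 500, 200))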
prompture/discovery.py
CHANGED
@@ -1,39 +1,40 @@
 """Discovery module for auto-detecting available models."""
+
+import logging
 import os
+
 import requests
-import logging
-from typing import List, Dict, Any, Set
 
 from .drivers import (
-    DRIVER_REGISTRY,
-    OpenAIDriver,
     AzureDriver,
     ClaudeDriver,
     GoogleDriver,
-    GroqDriver,
-    OpenRouterDriver,
     GrokDriver,
-
+    GroqDriver,
     LMStudioDriver,
     LocalHTTPDriver,
+    OllamaDriver,
+    OpenAIDriver,
+    OpenRouterDriver,
 )
 from .settings import settings
 
 logger = logging.getLogger(__name__)
 
-def get_available_models() -> List[str]:
+
+def get_available_models() -> list[str]:
     """
     Auto-detects all available models based on configured drivers and environment variables.
-
+
     Iterates through supported providers and checks if they are configured (e.g. API key present).
     For static drivers, returns models from their MODEL_PRICING keys.
     For dynamic drivers (like Ollama), attempts to fetch available models from the endpoint.
-
+
     Returns:
         A list of unique model strings in the format "provider/model_id".
     """
-    available_models: Set[str] = set()
-    configured_providers: Set[str] = set()
+    available_models: set[str] = set()
+    configured_providers: set[str] = set()
 
     # Map of provider name to driver class
    # We need to map the registry keys to the actual classes to check MODEL_PRICING
@@ -57,16 +58,18 @@ def get_available_models() -> List[str]:
         # We can check this by looking at the settings or env vars that the driver uses.
         # A simple way is to try to instantiate it with defaults, but that might fail if keys are missing.
         # Instead, let's check the specific requirements for each known provider.
-
+
         is_configured = False
-
+
         if provider == "openai":
             if settings.openai_api_key or os.getenv("OPENAI_API_KEY"):
                 is_configured = True
         elif provider == "azure":
-            if (
-
-
+            if (
+                (settings.azure_api_key or os.getenv("AZURE_API_KEY"))
+                and (settings.azure_api_endpoint or os.getenv("AZURE_API_ENDPOINT"))
+                and (settings.azure_deployment_id or os.getenv("AZURE_DEPLOYMENT_ID"))
+            ):
                 is_configured = True
         elif provider == "claude":
             if settings.claude_api_key or os.getenv("CLAUDE_API_KEY"):
@@ -88,11 +91,10 @@ def get_available_models() -> List[str]:
             # We will check connectivity later
             is_configured = True
         elif provider == "lmstudio":
-
-
-        elif provider == "local_http":
-
-            is_configured = True
+            # LM Studio is similar to Ollama, defaults to localhost
+            is_configured = True
+        elif provider == "local_http" and (settings.local_http_endpoint or os.getenv("LOCAL_HTTP_ENDPOINT")):
+            is_configured = True
 
         if not is_configured:
             continue
@@ -102,34 +104,36 @@ def get_available_models() -> List[str]:
         # 2. Static Detection: Get models from MODEL_PRICING
         if hasattr(driver_cls, "MODEL_PRICING"):
             pricing = driver_cls.MODEL_PRICING
-            for model_id in pricing
+            for model_id in pricing:
                 # Skip "default" or generic keys if they exist
                 if model_id == "default":
                     continue
-
-                # For Azure, the model_id in pricing is usually the base model name,
-                # but the user needs to use the deployment ID.
-                # However, our Azure driver implementation uses the deployment_id from init
-                # as the "model" for the request, but expects the user to pass a model name
-                # that maps to pricing?
+
+                # For Azure, the model_id in pricing is usually the base model name,
+                # but the user needs to use the deployment ID.
+                # However, our Azure driver implementation uses the deployment_id from init
+                # as the "model" for the request, but expects the user to pass a model name
+                # that maps to pricing?
                 # Looking at AzureDriver:
                 # kwargs = {"model": self.deployment_id, ...}
                 # model = options.get("model", self.model) -> used for pricing lookup
-                # So we should list the keys in MODEL_PRICING as available "models"
+                # So we should list the keys in MODEL_PRICING as available "models"
                 # even though for Azure specifically it's a bit weird because of deployment IDs.
                 # But for general discovery, listing supported models is correct.
-
+
                 available_models.add(f"{provider}/{model_id}")
 
         # 3. Dynamic Detection: Specific logic for Ollama
         if provider == "ollama":
            try:
-                endpoint = settings.ollama_endpoint or os.getenv("OLLAMA_ENDPOINT", "http://localhost:11434/api/generate")
+                endpoint = settings.ollama_endpoint or os.getenv(
+                    "OLLAMA_ENDPOINT", "http://localhost:11434/api/generate"
+                )
                 # We need the base URL for tags, usually http://localhost:11434/api/tags
                 # The configured endpoint might be .../api/generate or .../api/chat
                 base_url = endpoint.split("/api/")[0]
                 tags_url = f"{base_url}/api/tags"
-
+
                 resp = requests.get(tags_url, timeout=2)
                 if resp.status_code == 200:
                     data = resp.json()
@@ -142,15 +146,16 @@ def get_available_models() -> List[str]:
                         available_models.add(f"ollama/{name}")
            except Exception as e:
                logger.debug(f"Failed to fetch Ollama models: {e}")
-
+
        # Future: Add dynamic detection for LM Studio if they have an endpoint for listing models
-
+
    except Exception as e:
        logger.warning(f"Error detecting models for provider {provider}: {e}")
        continue
 
    # Enrich with live model list from models.dev cache
-    from .model_rates import
+    from .model_rates import PROVIDER_MAP, get_all_provider_models
+
    for prompture_name, api_name in PROVIDER_MAP.items():
        if prompture_name in configured_providers:
            for model_id in get_all_provider_models(api_name):
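With the modernized signature, discovery stays a single call that merges the static MODEL_PRICING keys, a live Ollama tag query, and the models.dev cache. A usage sketch (the printed names are illustrative; the actual list depends on which API keys and local endpoints are configured):

from prompture.discovery import get_available_models

for model in get_available_models():
    print(model)  # e.g. "ollama/llama3" or "openai/gpt-4"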
prompture/driver.py
CHANGED
@@ -1,7 +1,15 @@
-"""Driver base class for LLM adapters.
-
+"""Driver base class for LLM adapters."""
+
 from __future__ import annotations
-
+
+import logging
+import time
+from typing import Any
+
+from .callbacks import DriverCallbacks
+
+logger = logging.getLogger("prompture.driver")
+
 
 class Driver:
     """Adapter base. Implementar generate(prompt, options) -> {"text": ... , "meta": {...}}
@@ -20,5 +28,117 @@ class Driver:
     additional provider-specific metadata while the core fields provide
     standardized access to token usage and cost information.
     """
-
-
+
+    supports_json_mode: bool = False
+    supports_json_schema: bool = False
+    supports_messages: bool = False
+
+    callbacks: DriverCallbacks | None = None
+
+    def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
+        raise NotImplementedError
+
+    def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
+        """Generate a response from a list of conversation messages.
+
+        Each message is a dict with ``"role"`` (``"system"``, ``"user"``, or
+        ``"assistant"``) and ``"content"`` keys.
+
+        The default implementation flattens the messages into a single prompt
+        string and delegates to :meth:`generate`. Drivers that natively
+        support message arrays should override this method and set
+        ``supports_messages = True``.
+        """
+        prompt = self._flatten_messages(messages)
+        return self.generate(prompt, options)
+
+    # ------------------------------------------------------------------
+    # Hook-aware wrappers
+    # ------------------------------------------------------------------
+
+    def generate_with_hooks(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
+        """Wrap :meth:`generate` with on_request / on_response / on_error callbacks."""
+        driver_name = getattr(self, "model", self.__class__.__name__)
+        self._fire_callback(
+            "on_request",
+            {"prompt": prompt, "messages": None, "options": options, "driver": driver_name},
+        )
+        t0 = time.perf_counter()
+        try:
+            resp = self.generate(prompt, options)
+        except Exception as exc:
+            self._fire_callback(
+                "on_error",
+                {"error": exc, "prompt": prompt, "messages": None, "options": options, "driver": driver_name},
+            )
+            raise
+        elapsed_ms = (time.perf_counter() - t0) * 1000
+        self._fire_callback(
+            "on_response",
+            {
+                "text": resp.get("text", ""),
+                "meta": resp.get("meta", {}),
+                "driver": driver_name,
+                "elapsed_ms": elapsed_ms,
+            },
+        )
+        return resp
+
+    def generate_messages_with_hooks(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
+        """Wrap :meth:`generate_messages` with callbacks."""
+        driver_name = getattr(self, "model", self.__class__.__name__)
+        self._fire_callback(
+            "on_request",
+            {"prompt": None, "messages": messages, "options": options, "driver": driver_name},
+        )
+        t0 = time.perf_counter()
+        try:
+            resp = self.generate_messages(messages, options)
+        except Exception as exc:
+            self._fire_callback(
+                "on_error",
+                {"error": exc, "prompt": None, "messages": messages, "options": options, "driver": driver_name},
+            )
+            raise
+        elapsed_ms = (time.perf_counter() - t0) * 1000
+        self._fire_callback(
+            "on_response",
+            {
+                "text": resp.get("text", ""),
+                "meta": resp.get("meta", {}),
+                "driver": driver_name,
+                "elapsed_ms": elapsed_ms,
+            },
+        )
+        return resp
+
+    # ------------------------------------------------------------------
+    # Internal helpers
+    # ------------------------------------------------------------------
+
+    def _fire_callback(self, event: str, payload: dict[str, Any]) -> None:
+        """Invoke a single callback, swallowing and logging any exception."""
+        if self.callbacks is None:
+            return
+        cb = getattr(self.callbacks, event, None)
+        if cb is None:
+            return
+        try:
+            cb(payload)
+        except Exception:
+            logger.exception("Callback %s raised an exception", event)
+
+    @staticmethod
+    def _flatten_messages(messages: list[dict[str, str]]) -> str:
+        """Join messages into a single prompt string with role prefixes."""
+        parts: list[str] = []
+        for msg in messages:
+            role = msg.get("role", "user")
+            content = msg.get("content", "")
+            if role == "system":
+                parts.append(f"[System]: {content}")
+            elif role == "assistant":
+                parts.append(f"[Assistant]: {content}")
+            else:
+                parts.append(f"[User]: {content}")
+        return "\n\n".join(parts)
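The hook wrappers let callers observe timing and errors without subclassing each driver. A minimal sketch, assuming `DriverCallbacks` (added in `prompture/callbacks.py` in this release) can be constructed with `on_request`/`on_response`/`on_error` callables, which are the attribute names `_fire_callback` reads via `getattr`; the `EchoDriver` is hypothetical:

from typing import Any

from prompture.callbacks import DriverCallbacks
from prompture.driver import Driver


class EchoDriver(Driver):
    # Hypothetical driver: echoes the prompt back.
    def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
        return {"text": prompt, "meta": {}}


driver = EchoDriver()
driver.callbacks = DriverCallbacks(
    on_request=lambda p: print("request via", p["driver"]),
    on_response=lambda p: print(f"done in {p['elapsed_ms']:.1f} ms"),
    on_error=lambda p: print("failed:", p["error"]),
)
driver.generate_with_hooks("hello", {})  # fires on_request, then on_response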
prompture/drivers/__init__.py
CHANGED
@@ -1,60 +1,49 @@
-from
-
-from
-from .
+from typing import Optional
+
+from ..settings import settings
+from .airllm_driver import AirLLMDriver
+from .async_airllm_driver import AsyncAirLLMDriver
+from .async_azure_driver import AsyncAzureDriver
+from .async_claude_driver import AsyncClaudeDriver
+from .async_google_driver import AsyncGoogleDriver
+from .async_grok_driver import AsyncGrokDriver
+from .async_groq_driver import AsyncGroqDriver
+from .async_hugging_driver import AsyncHuggingFaceDriver
+from .async_lmstudio_driver import AsyncLMStudioDriver
+from .async_local_http_driver import AsyncLocalHTTPDriver
+from .async_ollama_driver import AsyncOllamaDriver
+from .async_openai_driver import AsyncOpenAIDriver
+from .async_openrouter_driver import AsyncOpenRouterDriver
+from .async_registry import ASYNC_DRIVER_REGISTRY, get_async_driver, get_async_driver_for_model
 from .azure_driver import AzureDriver
-from .
+from .claude_driver import ClaudeDriver
 from .google_driver import GoogleDriver
+from .grok_driver import GrokDriver
 from .groq_driver import GroqDriver
+from .lmstudio_driver import LMStudioDriver
+from .local_http_driver import LocalHTTPDriver
+from .ollama_driver import OllamaDriver
+from .openai_driver import OpenAIDriver
 from .openrouter_driver import OpenRouterDriver
-from .grok_driver import GrokDriver
-from .airllm_driver import AirLLMDriver
-from ..settings import settings
-
 
 # Central registry: maps provider → factory function
 DRIVER_REGISTRY = {
-    "openai": lambda model=None: OpenAIDriver(
-        api_key=settings.openai_api_key,
-        model=model or settings.openai_model
-    ),
-    "ollama": lambda model=None: OllamaDriver(
-        endpoint=settings.ollama_endpoint,
-        model=model or settings.ollama_model
-    ),
-    "claude": lambda model=None: ClaudeDriver(
-        api_key=settings.claude_api_key,
-        model=model or settings.claude_model
-    ),
+    "openai": lambda model=None: OpenAIDriver(api_key=settings.openai_api_key, model=model or settings.openai_model),
+    "ollama": lambda model=None: OllamaDriver(endpoint=settings.ollama_endpoint, model=model or settings.ollama_model),
+    "claude": lambda model=None: ClaudeDriver(api_key=settings.claude_api_key, model=model or settings.claude_model),
     "lmstudio": lambda model=None: LMStudioDriver(
-        endpoint=settings.lmstudio_endpoint,
-        model=model or settings.lmstudio_model
+        endpoint=settings.lmstudio_endpoint, model=model or settings.lmstudio_model
     ),
     "azure": lambda model=None: AzureDriver(
-        api_key=settings.azure_api_key,
-        endpoint=settings.azure_api_endpoint,
-        deployment_id=settings.azure_deployment_id
-    ),
-    "local_http": lambda model=None: LocalHTTPDriver(
-        endpoint=settings.local_http_endpoint,
-        model=model
-    ),
-    "google": lambda model=None: GoogleDriver(
-        api_key=settings.google_api_key,
-        model=model or settings.google_model
-    ),
-    "groq": lambda model=None: GroqDriver(
-        api_key=settings.groq_api_key,
-        model=model or settings.groq_model
+        api_key=settings.azure_api_key, endpoint=settings.azure_api_endpoint, deployment_id=settings.azure_deployment_id
     ),
+    "local_http": lambda model=None: LocalHTTPDriver(endpoint=settings.local_http_endpoint, model=model),
+    "google": lambda model=None: GoogleDriver(api_key=settings.google_api_key, model=model or settings.google_model),
+    "groq": lambda model=None: GroqDriver(api_key=settings.groq_api_key, model=model or settings.groq_model),
     "openrouter": lambda model=None: OpenRouterDriver(
-        api_key=settings.openrouter_api_key,
-        model=model or settings.openrouter_model
-    ),
-    "grok": lambda model=None: GrokDriver(
-        api_key=settings.grok_api_key,
-        model=model or settings.grok_model
+        api_key=settings.openrouter_api_key, model=model or settings.openrouter_model
     ),
+    "grok": lambda model=None: GrokDriver(api_key=settings.grok_api_key, model=model or settings.grok_model),
     "airllm": lambda model=None: AirLLMDriver(
        model=model or settings.airllm_model,
        compression=settings.airllm_compression,
@@ -62,7 +51,7 @@ DRIVER_REGISTRY = {
 }
 
 
-def get_driver(provider_name: str = None):
+def get_driver(provider_name: Optional[str] = None):
     """
     Factory to get a driver instance based on the provider name (legacy style).
     Uses default model from settings if not overridden.
@@ -78,21 +67,21 @@ def get_driver_for_model(model_str: str):
     Factory to get a driver instance based on a full model string.
     Format: provider/model_id
     Example: "openai/gpt-4-turbo-preview"
-
+
     Args:
         model_str: Model identifier string. Can be either:
             - Full format: "provider/model" (e.g. "openai/gpt-4")
            - Provider only: "provider" (e.g. "openai")
-
+
    Returns:
        A configured driver instance for the specified provider/model.
-
+
    Raises:
        ValueError: If provider is invalid or format is incorrect.
    """
    if not isinstance(model_str, str):
        raise ValueError("Model string must be a string, got {type(model_str)}")
-
+
    if not model_str:
        raise ValueError("Model string cannot be empty")
 
@@ -104,23 +93,40 @@ def get_driver_for_model(model_str: str):
     # Validate provider
     if provider not in DRIVER_REGISTRY:
         raise ValueError(f"Unsupported provider '{provider}'")
-
+
     # Create driver with model ID if provided, otherwise use default
     return DRIVER_REGISTRY[provider](model_id)
 
 
 __all__ = [
-
-    "
-
-    "
-    "
+    # Async drivers
+    "ASYNC_DRIVER_REGISTRY",
+    # Sync drivers
+    "AirLLMDriver",
+    "AsyncAirLLMDriver",
+    "AsyncAzureDriver",
+    "AsyncClaudeDriver",
+    "AsyncGoogleDriver",
+    "AsyncGrokDriver",
+    "AsyncGroqDriver",
+    "AsyncHuggingFaceDriver",
+    "AsyncLMStudioDriver",
+    "AsyncLocalHTTPDriver",
+    "AsyncOllamaDriver",
+    "AsyncOpenAIDriver",
+    "AsyncOpenRouterDriver",
     "AzureDriver",
+    "ClaudeDriver",
     "GoogleDriver",
+    "GrokDriver",
     "GroqDriver",
+    "LMStudioDriver",
+    "LocalHTTPDriver",
+    "OllamaDriver",
+    "OpenAIDriver",
     "OpenRouterDriver",
-    "
-    "
+    "get_async_driver",
+    "get_async_driver_for_model",
     "get_driver",
     "get_driver_for_model",
 ]
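Both lookup styles survive the reformat, and the new async registry mirrors the sync API. A quick sketch of the factory functions (the model IDs are illustrative):

from prompture.drivers import get_driver, get_driver_for_model

driver = get_driver("ollama")                   # provider only; model comes from settings
driver = get_driver_for_model("ollama/llama3")  # "provider/model_id" form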
prompture/drivers/airllm_driver.py
CHANGED
@@ -1,6 +1,7 @@
 import logging
+from typing import Any, Optional
+
 from ..driver import Driver
-from typing import Any, Dict, Optional
 
 logger = logging.getLogger(__name__)
 
@@ -13,12 +14,9 @@ class AirLLMDriver(Driver):
     ``generate()`` call so the rest of Prompture works without it installed.
     """
 
-    MODEL_PRICING = {
-        "default": {"prompt": 0.0, "completion": 0.0}
-    }
+    MODEL_PRICING = {"default": {"prompt": 0.0, "completion": 0.0}}
 
-    def __init__(self, model: str = "meta-llama/Llama-2-7b-hf",
-                 compression: Optional[str] = None):
+    def __init__(self, model: str = "meta-llama/Llama-2-7b-hf", compression: Optional[str] = None):
         """
         Args:
             model: HuggingFace repo ID (e.g. ``"meta-llama/Llama-2-70b-hf"``).
@@ -26,7 +24,7 @@ class AirLLMDriver(Driver):
         """
         self.model = model
         self.compression = compression
-        self.options: Dict[str, Any] = {}
+        self.options: dict[str, Any] = {}
         self._llm = None
         self._tokenizer = None
 
@@ -42,9 +40,8 @@ class AirLLMDriver(Driver):
             from airllm import AutoModel
         except ImportError:
             raise ImportError(
-                "The 'airllm' package is required for the AirLLM driver. "
-
-            )
+                "The 'airllm' package is required for the AirLLM driver. Install it with: pip install prompture[airllm]"
+            ) from None
 
         try:
             from transformers import AutoTokenizer
@@ -52,12 +49,11 @@ class AirLLMDriver(Driver):
             raise ImportError(
                 "The 'transformers' package is required for the AirLLM driver. "
                 "Install it with: pip install transformers"
-            )
+            ) from None
 
-        logger.info(f"Loading AirLLM model: {self.model} "
-                    f"(compression={self.compression})")
+        logger.info(f"Loading AirLLM model: {self.model} (compression={self.compression})")
 
-        load_kwargs: Dict[str, Any] = {}
+        load_kwargs: dict[str, Any] = {}
         if self.compression:
             load_kwargs["compression"] = self.compression
 
@@ -68,7 +64,7 @@ class AirLLMDriver(Driver):
     # ------------------------------------------------------------------
     # Driver interface
     # ------------------------------------------------------------------
-    def generate(self, prompt: str, options: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
+    def generate(self, prompt: str, options: Optional[dict[str, Any]] = None) -> dict[str, Any]:
         self._ensure_loaded()
 
         merged_options = self.options.copy()
@@ -78,14 +74,11 @@ class AirLLMDriver(Driver):
         max_new_tokens = merged_options.get("max_new_tokens", 256)
 
         # Tokenize
-        input_ids = self._tokenizer(
-            prompt, return_tensors="pt"
-        ).input_ids
+        input_ids = self._tokenizer(prompt, return_tensors="pt").input_ids
 
         prompt_tokens = input_ids.shape[1]
 
-        logger.debug(f"AirLLM generating with max_new_tokens={max_new_tokens}, "
-                     f"prompt_tokens={prompt_tokens}")
+        logger.debug(f"AirLLM generating with max_new_tokens={max_new_tokens}, prompt_tokens={prompt_tokens}")
 
         # Generate
         output_ids = self._llm.generate(
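The `) from None` added to both ImportError re-raises suppresses implicit exception chaining, so users see only the actionable install hint instead of a nested traceback. The pattern in isolation (message shortened for illustration):

try:
    import airllm  # optional heavy dependency
except ImportError:
    # "from None" drops the original traceback; only this message remains.
    raise ImportError("The 'airllm' package is required for the AirLLM driver.") from None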
prompture/drivers/async_airllm_driver.py
ADDED
@@ -0,0 +1,26 @@
+"""Async AirLLM driver — wraps the sync GPU-bound driver with asyncio.to_thread."""
+
+from __future__ import annotations
+
+import asyncio
+from typing import Any
+
+from ..async_driver import AsyncDriver
+from .airllm_driver import AirLLMDriver
+
+
+class AsyncAirLLMDriver(AsyncDriver):
+    """Async wrapper around :class:`AirLLMDriver`.
+
+    AirLLM is GPU-bound with no native async API, so we delegate to
+    ``asyncio.to_thread()`` to avoid blocking the event loop.
+    """
+
+    MODEL_PRICING = AirLLMDriver.MODEL_PRICING
+
+    def __init__(self, model: str = "meta-llama/Llama-2-7b-hf", compression: str | None = None):
+        self.model = model
+        self._sync_driver = AirLLMDriver(model=model, compression=compression)
+
+    async def generate(self, prompt: str, options: dict[str, Any] | None = None) -> dict[str, Any]:
+        return await asyncio.to_thread(self._sync_driver.generate, prompt, options)