prompture-0.0.33.dev2-py3-none-any.whl → prompture-0.0.34.dev1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55)
  1. prompture/__init__.py +112 -54
  2. prompture/_version.py +34 -0
  3. prompture/aio/__init__.py +74 -0
  4. prompture/async_conversation.py +484 -0
  5. prompture/async_core.py +803 -0
  6. prompture/async_driver.py +131 -0
  7. prompture/cache.py +469 -0
  8. prompture/callbacks.py +50 -0
  9. prompture/cli.py +7 -3
  10. prompture/conversation.py +504 -0
  11. prompture/core.py +475 -352
  12. prompture/cost_mixin.py +51 -0
  13. prompture/discovery.py +41 -36
  14. prompture/driver.py +125 -5
  15. prompture/drivers/__init__.py +63 -57
  16. prompture/drivers/airllm_driver.py +13 -20
  17. prompture/drivers/async_airllm_driver.py +26 -0
  18. prompture/drivers/async_azure_driver.py +117 -0
  19. prompture/drivers/async_claude_driver.py +107 -0
  20. prompture/drivers/async_google_driver.py +132 -0
  21. prompture/drivers/async_grok_driver.py +91 -0
  22. prompture/drivers/async_groq_driver.py +84 -0
  23. prompture/drivers/async_hugging_driver.py +61 -0
  24. prompture/drivers/async_lmstudio_driver.py +79 -0
  25. prompture/drivers/async_local_http_driver.py +44 -0
  26. prompture/drivers/async_ollama_driver.py +125 -0
  27. prompture/drivers/async_openai_driver.py +96 -0
  28. prompture/drivers/async_openrouter_driver.py +96 -0
  29. prompture/drivers/async_registry.py +80 -0
  30. prompture/drivers/azure_driver.py +36 -15
  31. prompture/drivers/claude_driver.py +86 -40
  32. prompture/drivers/google_driver.py +86 -58
  33. prompture/drivers/grok_driver.py +29 -38
  34. prompture/drivers/groq_driver.py +27 -32
  35. prompture/drivers/hugging_driver.py +6 -6
  36. prompture/drivers/lmstudio_driver.py +26 -13
  37. prompture/drivers/local_http_driver.py +6 -6
  38. prompture/drivers/ollama_driver.py +90 -23
  39. prompture/drivers/openai_driver.py +36 -15
  40. prompture/drivers/openrouter_driver.py +31 -31
  41. prompture/field_definitions.py +106 -96
  42. prompture/logging.py +80 -0
  43. prompture/model_rates.py +16 -15
  44. prompture/runner.py +49 -47
  45. prompture/session.py +117 -0
  46. prompture/settings.py +11 -1
  47. prompture/tools.py +172 -265
  48. prompture/validator.py +3 -3
  49. {prompture-0.0.33.dev2.dist-info → prompture-0.0.34.dev1.dist-info}/METADATA +18 -20
  50. prompture-0.0.34.dev1.dist-info/RECORD +54 -0
  51. prompture-0.0.33.dev2.dist-info/RECORD +0 -30
  52. {prompture-0.0.33.dev2.dist-info → prompture-0.0.34.dev1.dist-info}/WHEEL +0 -0
  53. {prompture-0.0.33.dev2.dist-info → prompture-0.0.34.dev1.dist-info}/entry_points.txt +0 -0
  54. {prompture-0.0.33.dev2.dist-info → prompture-0.0.34.dev1.dist-info}/licenses/LICENSE +0 -0
  55. {prompture-0.0.33.dev2.dist-info → prompture-0.0.34.dev1.dist-info}/top_level.txt +0 -0
prompture/cost_mixin.py ADDED
@@ -0,0 +1,51 @@
+"""Shared cost-calculation mixin for LLM drivers."""
+
+from __future__ import annotations
+
+from typing import Any
+
+
+class CostMixin:
+    """Mixin that provides ``_calculate_cost`` to sync and async drivers.
+
+    Drivers that charge per-token should inherit from this mixin alongside
+    their base class (``Driver`` or ``AsyncDriver``). Free/local drivers
+    (Ollama, LM Studio, LocalHTTP, HuggingFace, AirLLM) can skip it.
+    """
+
+    # Subclasses must define MODEL_PRICING as a class attribute.
+    MODEL_PRICING: dict[str, dict[str, Any]] = {}
+
+    # Divisor for hardcoded MODEL_PRICING rates.
+    # Most drivers use per-1K-token pricing (1_000).
+    # Grok uses per-1M-token pricing (1_000_000).
+    # Google uses per-1M-character pricing (1_000_000).
+    _PRICING_UNIT: int = 1_000
+
+    def _calculate_cost(
+        self,
+        provider: str,
+        model: str,
+        prompt_tokens: int | float,
+        completion_tokens: int | float,
+    ) -> float:
+        """Calculate USD cost for a generation call.
+
+        Resolution order:
+        1. Live rates from ``model_rates.get_model_rates()`` (per 1M tokens).
+        2. Hardcoded ``MODEL_PRICING`` on the driver class (unit set by ``_PRICING_UNIT``).
+        3. Zero if neither source has data.
+        """
+        from .model_rates import get_model_rates
+
+        live_rates = get_model_rates(provider, model)
+        if live_rates:
+            prompt_cost = (prompt_tokens / 1_000_000) * live_rates["input"]
+            completion_cost = (completion_tokens / 1_000_000) * live_rates["output"]
+        else:
+            unit = self._PRICING_UNIT
+            model_pricing = self.MODEL_PRICING.get(model, {"prompt": 0, "completion": 0})
+            prompt_cost = (prompt_tokens / unit) * model_pricing["prompt"]
+            completion_cost = (completion_tokens / unit) * model_pricing["completion"]
+
+        return round(prompt_cost + completion_cost, 6)
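With this mixin, every per-token driver shares one costing path. A minimal usage sketch follows; the FlatRateDriver class and its per-1K rates are hypothetical, invented for illustration, and get_model_rates is assumed to return nothing for an unknown provider so the hardcoded fallback branch runs:

from prompture.cost_mixin import CostMixin
from prompture.driver import Driver


class FlatRateDriver(CostMixin, Driver):
    # Hypothetical per-1K-token rates; real drivers ship their own tables.
    MODEL_PRICING = {"flat-1": {"prompt": 0.002, "completion": 0.004}}

    def generate(self, prompt, options):
        # (12 / 1000) * 0.002 + (3 / 1000) * 0.004 == 3.6e-05
        cost = self._calculate_cost("flat", "flat-1", prompt_tokens=12, completion_tokens=3)
        return {"text": "stub response", "meta": {"cost": cost}}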
prompture/discovery.py CHANGED
@@ -1,39 +1,40 @@
 """Discovery module for auto-detecting available models."""
+
+import logging
 import os
+
 import requests
-import logging
-from typing import List, Dict, Any, Set
 
 from .drivers import (
-    DRIVER_REGISTRY,
-    OpenAIDriver,
     AzureDriver,
     ClaudeDriver,
     GoogleDriver,
-    GroqDriver,
-    OpenRouterDriver,
     GrokDriver,
-    OllamaDriver,
+    GroqDriver,
     LMStudioDriver,
     LocalHTTPDriver,
+    OllamaDriver,
+    OpenAIDriver,
+    OpenRouterDriver,
 )
 from .settings import settings
 
 logger = logging.getLogger(__name__)
 
-def get_available_models() -> List[str]:
+
+def get_available_models() -> list[str]:
     """
     Auto-detects all available models based on configured drivers and environment variables.
-
+
     Iterates through supported providers and checks if they are configured (e.g. API key present).
     For static drivers, returns models from their MODEL_PRICING keys.
     For dynamic drivers (like Ollama), attempts to fetch available models from the endpoint.
-
+
     Returns:
        A list of unique model strings in the format "provider/model_id".
    """
-    available_models: Set[str] = set()
-    configured_providers: Set[str] = set()
+    available_models: set[str] = set()
+    configured_providers: set[str] = set()
 
     # Map of provider name to driver class
     # We need to map the registry keys to the actual classes to check MODEL_PRICING
@@ -57,16 +58,18 @@ def get_available_models() -> List[str]:
         # We can check this by looking at the settings or env vars that the driver uses.
         # A simple way is to try to instantiate it with defaults, but that might fail if keys are missing.
         # Instead, let's check the specific requirements for each known provider.
-
+
         is_configured = False
-
+
         if provider == "openai":
             if settings.openai_api_key or os.getenv("OPENAI_API_KEY"):
                 is_configured = True
         elif provider == "azure":
-            if (settings.azure_api_key or os.getenv("AZURE_API_KEY")) and \
-               (settings.azure_api_endpoint or os.getenv("AZURE_API_ENDPOINT")) and \
-               (settings.azure_deployment_id or os.getenv("AZURE_DEPLOYMENT_ID")):
+            if (
+                (settings.azure_api_key or os.getenv("AZURE_API_KEY"))
+                and (settings.azure_api_endpoint or os.getenv("AZURE_API_ENDPOINT"))
+                and (settings.azure_deployment_id or os.getenv("AZURE_DEPLOYMENT_ID"))
+            ):
                 is_configured = True
         elif provider == "claude":
             if settings.claude_api_key or os.getenv("CLAUDE_API_KEY"):
@@ -88,11 +91,10 @@ def get_available_models() -> List[str]:
             # We will check connectivity later
             is_configured = True
         elif provider == "lmstudio":
-            # LM Studio is similar to Ollama, defaults to localhost
-            is_configured = True
-        elif provider == "local_http":
-            if settings.local_http_endpoint or os.getenv("LOCAL_HTTP_ENDPOINT"):
-                is_configured = True
+            # LM Studio is similar to Ollama, defaults to localhost
+            is_configured = True
+        elif provider == "local_http" and (settings.local_http_endpoint or os.getenv("LOCAL_HTTP_ENDPOINT")):
+            is_configured = True
 
         if not is_configured:
             continue
@@ -102,34 +104,36 @@ def get_available_models() -> List[str]:
         # 2. Static Detection: Get models from MODEL_PRICING
         if hasattr(driver_cls, "MODEL_PRICING"):
             pricing = driver_cls.MODEL_PRICING
-            for model_id in pricing.keys():
+            for model_id in pricing:
                 # Skip "default" or generic keys if they exist
                 if model_id == "default":
                     continue
-
-                # For Azure, the model_id in pricing is usually the base model name,
-                # but the user needs to use the deployment ID.
-                # However, our Azure driver implementation uses the deployment_id from init
-                # as the "model" for the request, but expects the user to pass a model name
-                # that maps to pricing?
+
+                # For Azure, the model_id in pricing is usually the base model name,
+                # but the user needs to use the deployment ID.
+                # However, our Azure driver implementation uses the deployment_id from init
+                # as the "model" for the request, but expects the user to pass a model name
+                # that maps to pricing?
                 # Looking at AzureDriver:
                 # kwargs = {"model": self.deployment_id, ...}
                 # model = options.get("model", self.model) -> used for pricing lookup
-                # So we should list the keys in MODEL_PRICING as available "models"
+                # So we should list the keys in MODEL_PRICING as available "models"
                 # even though for Azure specifically it's a bit weird because of deployment IDs.
                 # But for general discovery, listing supported models is correct.
-
+
                 available_models.add(f"{provider}/{model_id}")
 
         # 3. Dynamic Detection: Specific logic for Ollama
         if provider == "ollama":
             try:
-                endpoint = settings.ollama_endpoint or os.getenv("OLLAMA_ENDPOINT", "http://localhost:11434/api/generate")
+                endpoint = settings.ollama_endpoint or os.getenv(
+                    "OLLAMA_ENDPOINT", "http://localhost:11434/api/generate"
+                )
                 # We need the base URL for tags, usually http://localhost:11434/api/tags
                 # The configured endpoint might be .../api/generate or .../api/chat
                 base_url = endpoint.split("/api/")[0]
                 tags_url = f"{base_url}/api/tags"
-
+
                 resp = requests.get(tags_url, timeout=2)
                 if resp.status_code == 200:
                     data = resp.json()
@@ -142,15 +146,16 @@ def get_available_models() -> List[str]:
                         available_models.add(f"ollama/{name}")
             except Exception as e:
                 logger.debug(f"Failed to fetch Ollama models: {e}")
-
+
         # Future: Add dynamic detection for LM Studio if they have an endpoint for listing models
-
+
         except Exception as e:
             logger.warning(f"Error detecting models for provider {provider}: {e}")
             continue
 
     # Enrich with live model list from models.dev cache
-    from .model_rates import get_all_provider_models, PROVIDER_MAP
+    from .model_rates import PROVIDER_MAP, get_all_provider_models
+
     for prompture_name, api_name in PROVIDER_MAP.items():
         if prompture_name in configured_providers:
             for model_id in get_all_provider_models(api_name):
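Since get_available_models() only needs environment configuration, discovery can be exercised from a short script. A minimal sketch, assuming at least one provider key is set (the key value here is a placeholder):

import os

from prompture.discovery import get_available_models

os.environ.setdefault("OPENAI_API_KEY", "sk-placeholder")  # illustrative only
for model in sorted(get_available_models()):
    print(model)  # e.g. "openai/gpt-4"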
prompture/driver.py CHANGED
@@ -1,7 +1,15 @@
-"""Driver base class for LLM adapters.
-"""
+"""Driver base class for LLM adapters."""
+
 from __future__ import annotations
-from typing import Any, Dict
+
+import logging
+import time
+from typing import Any
+
+from .callbacks import DriverCallbacks
+
+logger = logging.getLogger("prompture.driver")
+
 
 class Driver:
     """Adapter base. Implement generate(prompt, options) -> {"text": ... , "meta": {...}}
@@ -20,5 +28,117 @@ class Driver:
     additional provider-specific metadata while the core fields provide
     standardized access to token usage and cost information.
     """
-    def generate(self, prompt: str, options: Dict[str,Any]) -> Dict[str,Any]:
-        raise NotImplementedError
+
+    supports_json_mode: bool = False
+    supports_json_schema: bool = False
+    supports_messages: bool = False
+
+    callbacks: DriverCallbacks | None = None
+
+    def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
+        raise NotImplementedError
+
+    def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
+        """Generate a response from a list of conversation messages.
+
+        Each message is a dict with ``"role"`` (``"system"``, ``"user"``, or
+        ``"assistant"``) and ``"content"`` keys.
+
+        The default implementation flattens the messages into a single prompt
+        string and delegates to :meth:`generate`. Drivers that natively
+        support message arrays should override this method and set
+        ``supports_messages = True``.
+        """
+        prompt = self._flatten_messages(messages)
+        return self.generate(prompt, options)
+
+    # ------------------------------------------------------------------
+    # Hook-aware wrappers
+    # ------------------------------------------------------------------
+
+    def generate_with_hooks(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
+        """Wrap :meth:`generate` with on_request / on_response / on_error callbacks."""
+        driver_name = getattr(self, "model", self.__class__.__name__)
+        self._fire_callback(
+            "on_request",
+            {"prompt": prompt, "messages": None, "options": options, "driver": driver_name},
+        )
+        t0 = time.perf_counter()
+        try:
+            resp = self.generate(prompt, options)
+        except Exception as exc:
+            self._fire_callback(
+                "on_error",
+                {"error": exc, "prompt": prompt, "messages": None, "options": options, "driver": driver_name},
+            )
+            raise
+        elapsed_ms = (time.perf_counter() - t0) * 1000
+        self._fire_callback(
+            "on_response",
+            {
+                "text": resp.get("text", ""),
+                "meta": resp.get("meta", {}),
+                "driver": driver_name,
+                "elapsed_ms": elapsed_ms,
+            },
+        )
+        return resp
+
+    def generate_messages_with_hooks(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
+        """Wrap :meth:`generate_messages` with callbacks."""
+        driver_name = getattr(self, "model", self.__class__.__name__)
+        self._fire_callback(
+            "on_request",
+            {"prompt": None, "messages": messages, "options": options, "driver": driver_name},
+        )
+        t0 = time.perf_counter()
+        try:
+            resp = self.generate_messages(messages, options)
+        except Exception as exc:
+            self._fire_callback(
+                "on_error",
+                {"error": exc, "prompt": None, "messages": messages, "options": options, "driver": driver_name},
+            )
+            raise
+        elapsed_ms = (time.perf_counter() - t0) * 1000
+        self._fire_callback(
+            "on_response",
+            {
+                "text": resp.get("text", ""),
+                "meta": resp.get("meta", {}),
+                "driver": driver_name,
+                "elapsed_ms": elapsed_ms,
+            },
+        )
+        return resp
+
+    # ------------------------------------------------------------------
+    # Internal helpers
+    # ------------------------------------------------------------------
+
+    def _fire_callback(self, event: str, payload: dict[str, Any]) -> None:
+        """Invoke a single callback, swallowing and logging any exception."""
+        if self.callbacks is None:
+            return
+        cb = getattr(self.callbacks, event, None)
+        if cb is None:
+            return
+        try:
+            cb(payload)
+        except Exception:
+            logger.exception("Callback %s raised an exception", event)
+
+    @staticmethod
+    def _flatten_messages(messages: list[dict[str, str]]) -> str:
+        """Join messages into a single prompt string with role prefixes."""
+        parts: list[str] = []
+        for msg in messages:
+            role = msg.get("role", "user")
+            content = msg.get("content", "")
+            if role == "system":
+                parts.append(f"[System]: {content}")
+            elif role == "assistant":
+                parts.append(f"[Assistant]: {content}")
+            else:
+                parts.append(f"[User]: {content}")
+        return "\n\n".join(parts)
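The hook wrappers let callers attach logging or metrics without touching individual drivers. A sketch of the intended flow; the EchoDriver is hypothetical, and the keyword-argument constructor for DriverCallbacks is an assumption, since callbacks.py itself is not shown in this diff (only the on_request/on_response/on_error attributes read by _fire_callback are visible above):

from typing import Any

from prompture.callbacks import DriverCallbacks
from prompture.driver import Driver


class EchoDriver(Driver):
    """Hypothetical driver that echoes the prompt back."""

    def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
        return {"text": prompt, "meta": {}}


driver = EchoDriver()
driver.callbacks = DriverCallbacks(  # assumed constructor signature
    on_request=lambda p: print("request to", p["driver"]),
    on_response=lambda p: print(f"{p['elapsed_ms']:.1f} ms ->", p["text"]),
)
driver.generate_with_hooks("ping", {})  # fires on_request, then on_response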
prompture/drivers/__init__.py CHANGED
@@ -1,60 +1,49 @@
-from .openai_driver import OpenAIDriver
-from .local_http_driver import LocalHTTPDriver
-from .ollama_driver import OllamaDriver
-from .claude_driver import ClaudeDriver
+from typing import Optional
+
+from ..settings import settings
+from .airllm_driver import AirLLMDriver
+from .async_airllm_driver import AsyncAirLLMDriver
+from .async_azure_driver import AsyncAzureDriver
+from .async_claude_driver import AsyncClaudeDriver
+from .async_google_driver import AsyncGoogleDriver
+from .async_grok_driver import AsyncGrokDriver
+from .async_groq_driver import AsyncGroqDriver
+from .async_hugging_driver import AsyncHuggingFaceDriver
+from .async_lmstudio_driver import AsyncLMStudioDriver
+from .async_local_http_driver import AsyncLocalHTTPDriver
+from .async_ollama_driver import AsyncOllamaDriver
+from .async_openai_driver import AsyncOpenAIDriver
+from .async_openrouter_driver import AsyncOpenRouterDriver
+from .async_registry import ASYNC_DRIVER_REGISTRY, get_async_driver, get_async_driver_for_model
 from .azure_driver import AzureDriver
-from .lmstudio_driver import LMStudioDriver
+from .claude_driver import ClaudeDriver
 from .google_driver import GoogleDriver
+from .grok_driver import GrokDriver
 from .groq_driver import GroqDriver
+from .lmstudio_driver import LMStudioDriver
+from .local_http_driver import LocalHTTPDriver
+from .ollama_driver import OllamaDriver
+from .openai_driver import OpenAIDriver
 from .openrouter_driver import OpenRouterDriver
-from .grok_driver import GrokDriver
-from .airllm_driver import AirLLMDriver
-from ..settings import settings
-
 
 # Central registry: maps provider → factory function
 DRIVER_REGISTRY = {
-    "openai": lambda model=None: OpenAIDriver(
-        api_key=settings.openai_api_key,
-        model=model or settings.openai_model
-    ),
-    "ollama": lambda model=None: OllamaDriver(
-        endpoint=settings.ollama_endpoint,
-        model=model or settings.ollama_model
-    ),
-    "claude": lambda model=None: ClaudeDriver(
-        api_key=settings.claude_api_key,
-        model=model or settings.claude_model
-    ),
+    "openai": lambda model=None: OpenAIDriver(api_key=settings.openai_api_key, model=model or settings.openai_model),
+    "ollama": lambda model=None: OllamaDriver(endpoint=settings.ollama_endpoint, model=model or settings.ollama_model),
+    "claude": lambda model=None: ClaudeDriver(api_key=settings.claude_api_key, model=model or settings.claude_model),
     "lmstudio": lambda model=None: LMStudioDriver(
-        endpoint=settings.lmstudio_endpoint,
-        model=model or settings.lmstudio_model
+        endpoint=settings.lmstudio_endpoint, model=model or settings.lmstudio_model
     ),
     "azure": lambda model=None: AzureDriver(
-        api_key=settings.azure_api_key,
-        endpoint=settings.azure_api_endpoint,
-        deployment_id=settings.azure_deployment_id
-    ),
-    "local_http": lambda model=None: LocalHTTPDriver(
-        endpoint=settings.local_http_endpoint,
-        model=model
-    ),
-    "google": lambda model=None: GoogleDriver(
-        api_key=settings.google_api_key,
-        model=model or settings.google_model
-    ),
-    "groq": lambda model=None: GroqDriver(
-        api_key=settings.groq_api_key,
-        model=model or settings.groq_model
+        api_key=settings.azure_api_key, endpoint=settings.azure_api_endpoint, deployment_id=settings.azure_deployment_id
     ),
+    "local_http": lambda model=None: LocalHTTPDriver(endpoint=settings.local_http_endpoint, model=model),
+    "google": lambda model=None: GoogleDriver(api_key=settings.google_api_key, model=model or settings.google_model),
+    "groq": lambda model=None: GroqDriver(api_key=settings.groq_api_key, model=model or settings.groq_model),
     "openrouter": lambda model=None: OpenRouterDriver(
-        api_key=settings.openrouter_api_key,
-        model=model or settings.openrouter_model
-    ),
-    "grok": lambda model=None: GrokDriver(
-        api_key=settings.grok_api_key,
-        model=model or settings.grok_model
+        api_key=settings.openrouter_api_key, model=model or settings.openrouter_model
     ),
+    "grok": lambda model=None: GrokDriver(api_key=settings.grok_api_key, model=model or settings.grok_model),
     "airllm": lambda model=None: AirLLMDriver(
         model=model or settings.airllm_model,
         compression=settings.airllm_compression,
@@ -62,7 +51,7 @@ DRIVER_REGISTRY = {
 }
 
 
-def get_driver(provider_name: str = None):
+def get_driver(provider_name: Optional[str] = None):
     """
     Factory to get a driver instance based on the provider name (legacy style).
     Uses default model from settings if not overridden.
@@ -78,21 +67,21 @@ def get_driver_for_model(model_str: str):
     Factory to get a driver instance based on a full model string.
     Format: provider/model_id
     Example: "openai/gpt-4-turbo-preview"
-
+
     Args:
         model_str: Model identifier string. Can be either:
            - Full format: "provider/model" (e.g. "openai/gpt-4")
            - Provider only: "provider" (e.g. "openai")
-
+
     Returns:
         A configured driver instance for the specified provider/model.
-
+
     Raises:
         ValueError: If provider is invalid or format is incorrect.
     """
     if not isinstance(model_str, str):
         raise ValueError("Model string must be a string, got {type(model_str)}")
-
+
     if not model_str:
         raise ValueError("Model string cannot be empty")
 
@@ -104,23 +93,40 @@ def get_driver_for_model(model_str: str):
     # Validate provider
     if provider not in DRIVER_REGISTRY:
         raise ValueError(f"Unsupported provider '{provider}'")
-
+
     # Create driver with model ID if provided, otherwise use default
     return DRIVER_REGISTRY[provider](model_id)
 
 
 __all__ = [
-    "OpenAIDriver",
-    "LocalHTTPDriver",
-    "OllamaDriver",
-    "ClaudeDriver",
-    "LMStudioDriver",
+    # Async drivers
+    "ASYNC_DRIVER_REGISTRY",
+    # Sync drivers
+    "AirLLMDriver",
+    "AsyncAirLLMDriver",
+    "AsyncAzureDriver",
+    "AsyncClaudeDriver",
+    "AsyncGoogleDriver",
+    "AsyncGrokDriver",
+    "AsyncGroqDriver",
+    "AsyncHuggingFaceDriver",
+    "AsyncLMStudioDriver",
+    "AsyncLocalHTTPDriver",
+    "AsyncOllamaDriver",
+    "AsyncOpenAIDriver",
+    "AsyncOpenRouterDriver",
     "AzureDriver",
+    "ClaudeDriver",
     "GoogleDriver",
+    "GrokDriver",
     "GroqDriver",
+    "LMStudioDriver",
+    "LocalHTTPDriver",
+    "OllamaDriver",
+    "OpenAIDriver",
     "OpenRouterDriver",
-    "GrokDriver",
-    "AirLLMDriver",
+    "get_async_driver",
+    "get_async_driver_for_model",
    "get_driver",
    "get_driver_for_model",
 ]
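Both factory styles resolve through DRIVER_REGISTRY, so existing call sites keep working. Usage per the docstrings above (a provider still needs its settings or API key configured before the driver can make requests):

from prompture.drivers import get_driver, get_driver_for_model

driver = get_driver("openai")  # provider only: default model from settings
driver = get_driver_for_model("openai/gpt-4-turbo-preview")  # "provider/model" string
# get_driver_for_model raises ValueError for an empty string or an unsupported provider.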
prompture/drivers/airllm_driver.py CHANGED
@@ -1,6 +1,7 @@
 import logging
+from typing import Any, Optional
+
 from ..driver import Driver
-from typing import Any, Dict, Optional
 
 logger = logging.getLogger(__name__)
 
@@ -13,12 +14,9 @@ class AirLLMDriver(Driver):
     ``generate()`` call so the rest of Prompture works without it installed.
     """
 
-    MODEL_PRICING = {
-        "default": {"prompt": 0.0, "completion": 0.0}
-    }
+    MODEL_PRICING = {"default": {"prompt": 0.0, "completion": 0.0}}
 
-    def __init__(self, model: str = "meta-llama/Llama-2-7b-hf",
-                 compression: Optional[str] = None):
+    def __init__(self, model: str = "meta-llama/Llama-2-7b-hf", compression: Optional[str] = None):
         """
         Args:
             model: HuggingFace repo ID (e.g. ``"meta-llama/Llama-2-70b-hf"``).
@@ -26,7 +24,7 @@ class AirLLMDriver(Driver):
         """
         self.model = model
         self.compression = compression
-        self.options: Dict[str, Any] = {}
+        self.options: dict[str, Any] = {}
         self._llm = None
         self._tokenizer = None
 
@@ -42,9 +40,8 @@ class AirLLMDriver(Driver):
             from airllm import AutoModel
         except ImportError:
             raise ImportError(
-                "The 'airllm' package is required for the AirLLM driver. "
-                "Install it with: pip install prompture[airllm]"
-            )
+                "The 'airllm' package is required for the AirLLM driver. Install it with: pip install prompture[airllm]"
+            ) from None
 
         try:
             from transformers import AutoTokenizer
@@ -52,12 +49,11 @@ class AirLLMDriver(Driver):
             raise ImportError(
                 "The 'transformers' package is required for the AirLLM driver. "
                 "Install it with: pip install transformers"
-            )
+            ) from None
 
-        logger.info(f"Loading AirLLM model: {self.model} "
-                    f"(compression={self.compression})")
+        logger.info(f"Loading AirLLM model: {self.model} (compression={self.compression})")
 
-        load_kwargs: Dict[str, Any] = {}
+        load_kwargs: dict[str, Any] = {}
         if self.compression:
             load_kwargs["compression"] = self.compression
 
@@ -68,7 +64,7 @@ class AirLLMDriver(Driver):
     # ------------------------------------------------------------------
     # Driver interface
     # ------------------------------------------------------------------
-    def generate(self, prompt: str, options: Dict[str, Any] = None) -> Dict[str, Any]:
+    def generate(self, prompt: str, options: Optional[dict[str, Any]] = None) -> dict[str, Any]:
         self._ensure_loaded()
 
         merged_options = self.options.copy()
@@ -78,14 +74,11 @@ class AirLLMDriver(Driver):
         max_new_tokens = merged_options.get("max_new_tokens", 256)
 
         # Tokenize
-        input_ids = self._tokenizer(
-            prompt, return_tensors="pt"
-        ).input_ids
+        input_ids = self._tokenizer(prompt, return_tensors="pt").input_ids
 
         prompt_tokens = input_ids.shape[1]
 
-        logger.debug(f"AirLLM generating with max_new_tokens={max_new_tokens}, "
-                     f"prompt_tokens={prompt_tokens}")
+        logger.debug(f"AirLLM generating with max_new_tokens={max_new_tokens}, prompt_tokens={prompt_tokens}")
 
         # Generate
         output_ids = self._llm.generate(
prompture/drivers/async_airllm_driver.py ADDED
@@ -0,0 +1,26 @@
+"""Async AirLLM driver — wraps the sync GPU-bound driver with asyncio.to_thread."""
+
+from __future__ import annotations
+
+import asyncio
+from typing import Any
+
+from ..async_driver import AsyncDriver
+from .airllm_driver import AirLLMDriver
+
+
+class AsyncAirLLMDriver(AsyncDriver):
+    """Async wrapper around :class:`AirLLMDriver`.
+
+    AirLLM is GPU-bound with no native async API, so we delegate to
+    ``asyncio.to_thread()`` to avoid blocking the event loop.
+    """
+
+    MODEL_PRICING = AirLLMDriver.MODEL_PRICING
+
+    def __init__(self, model: str = "meta-llama/Llama-2-7b-hf", compression: str | None = None):
+        self.model = model
+        self._sync_driver = AirLLMDriver(model=model, compression=compression)
+
+    async def generate(self, prompt: str, options: dict[str, Any] | None = None) -> dict[str, Any]:
+        return await asyncio.to_thread(self._sync_driver.generate, prompt, options)
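Because each call is pushed to a worker thread, several generations can be awaited concurrently without blocking the event loop. A minimal sketch, assuming the airllm extra and the model weights are available locally:

import asyncio

from prompture.drivers import AsyncAirLLMDriver


async def main() -> None:
    driver = AsyncAirLLMDriver(model="meta-llama/Llama-2-7b-hf")
    # Each generate() offloads the sync driver via asyncio.to_thread().
    results = await asyncio.gather(
        driver.generate("Summarize asyncio in one line.", {"max_new_tokens": 64}),
        driver.generate("What is a mixin?", {"max_new_tokens": 64}),
    )
    for result in results:
        print(result["text"])


asyncio.run(main())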