prompture 0.0.33.dev1__py3-none-any.whl → 0.0.34__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56):
  1. prompture/__init__.py +133 -49
  2. prompture/_version.py +34 -0
  3. prompture/aio/__init__.py +74 -0
  4. prompture/async_conversation.py +484 -0
  5. prompture/async_core.py +803 -0
  6. prompture/async_driver.py +131 -0
  7. prompture/cache.py +469 -0
  8. prompture/callbacks.py +50 -0
  9. prompture/cli.py +7 -3
  10. prompture/conversation.py +504 -0
  11. prompture/core.py +475 -352
  12. prompture/cost_mixin.py +51 -0
  13. prompture/discovery.py +50 -35
  14. prompture/driver.py +125 -5
  15. prompture/drivers/__init__.py +171 -73
  16. prompture/drivers/airllm_driver.py +13 -20
  17. prompture/drivers/async_airllm_driver.py +26 -0
  18. prompture/drivers/async_azure_driver.py +117 -0
  19. prompture/drivers/async_claude_driver.py +107 -0
  20. prompture/drivers/async_google_driver.py +132 -0
  21. prompture/drivers/async_grok_driver.py +91 -0
  22. prompture/drivers/async_groq_driver.py +84 -0
  23. prompture/drivers/async_hugging_driver.py +61 -0
  24. prompture/drivers/async_lmstudio_driver.py +79 -0
  25. prompture/drivers/async_local_http_driver.py +44 -0
  26. prompture/drivers/async_ollama_driver.py +125 -0
  27. prompture/drivers/async_openai_driver.py +96 -0
  28. prompture/drivers/async_openrouter_driver.py +96 -0
  29. prompture/drivers/async_registry.py +129 -0
  30. prompture/drivers/azure_driver.py +36 -9
  31. prompture/drivers/claude_driver.py +86 -34
  32. prompture/drivers/google_driver.py +87 -51
  33. prompture/drivers/grok_driver.py +29 -32
  34. prompture/drivers/groq_driver.py +27 -26
  35. prompture/drivers/hugging_driver.py +6 -6
  36. prompture/drivers/lmstudio_driver.py +26 -13
  37. prompture/drivers/local_http_driver.py +6 -6
  38. prompture/drivers/ollama_driver.py +90 -23
  39. prompture/drivers/openai_driver.py +36 -9
  40. prompture/drivers/openrouter_driver.py +31 -25
  41. prompture/drivers/registry.py +306 -0
  42. prompture/field_definitions.py +106 -96
  43. prompture/logging.py +80 -0
  44. prompture/model_rates.py +217 -0
  45. prompture/runner.py +49 -47
  46. prompture/session.py +117 -0
  47. prompture/settings.py +14 -1
  48. prompture/tools.py +172 -265
  49. prompture/validator.py +3 -3
  50. {prompture-0.0.33.dev1.dist-info → prompture-0.0.34.dist-info}/METADATA +18 -20
  51. prompture-0.0.34.dist-info/RECORD +55 -0
  52. prompture-0.0.33.dev1.dist-info/RECORD +0 -29
  53. {prompture-0.0.33.dev1.dist-info → prompture-0.0.34.dist-info}/WHEEL +0 -0
  54. {prompture-0.0.33.dev1.dist-info → prompture-0.0.34.dist-info}/entry_points.txt +0 -0
  55. {prompture-0.0.33.dev1.dist-info → prompture-0.0.34.dist-info}/licenses/LICENSE +0 -0
  56. {prompture-0.0.33.dev1.dist-info → prompture-0.0.34.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,51 @@
1
+ """Shared cost-calculation mixin for LLM drivers."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Any
6
+
7
+
8
class CostMixin:
    """Mixin that provides ``_calculate_cost`` to sync and async drivers.

    Per-token billed drivers inherit this alongside their base class
    (``Driver`` or ``AsyncDriver``). Free/local drivers (Ollama, LM Studio,
    LocalHTTP, HuggingFace, AirLLM) can skip it.
    """

    # Hardcoded fallback rates; concrete driver subclasses override this.
    MODEL_PRICING: dict[str, dict[str, Any]] = {}

    # Divisor applied to MODEL_PRICING rates: 1_000 for per-1K-token pricing
    # (most drivers); 1_000_000 for per-1M units (Grok tokens, Google chars).
    _PRICING_UNIT: int = 1_000

    def _calculate_cost(
        self,
        provider: str,
        model: str,
        prompt_tokens: int | float,
        completion_tokens: int | float,
    ) -> float:
        """Return the USD cost of one generation call, rounded to 6 places.

        Resolution order:
        1. Live rates from ``model_rates.get_model_rates()`` (per 1M tokens).
        2. Hardcoded ``MODEL_PRICING`` on the driver class, divided by
           ``_PRICING_UNIT``.
        3. Zero if neither source has data.
        """
        from .model_rates import get_model_rates

        rates = get_model_rates(provider, model)
        if rates:
            divisor = 1_000_000
            input_rate, output_rate = rates["input"], rates["output"]
        else:
            divisor = self._PRICING_UNIT
            fallback = self.MODEL_PRICING.get(model, {"prompt": 0, "completion": 0})
            input_rate, output_rate = fallback["prompt"], fallback["completion"]

        in_cost = (prompt_tokens / divisor) * input_rate
        out_cost = (completion_tokens / divisor) * output_rate
        return round(in_cost + out_cost, 6)
prompture/discovery.py CHANGED
@@ -1,39 +1,41 @@
1
1
  """Discovery module for auto-detecting available models."""
2
+
3
+ import logging
2
4
  import os
5
+
3
6
  import requests
4
- import logging
5
- from typing import List, Dict, Any, Set
6
7
 
7
8
  from .drivers import (
8
- DRIVER_REGISTRY,
9
- OpenAIDriver,
10
9
  AzureDriver,
11
10
  ClaudeDriver,
12
11
  GoogleDriver,
13
- GroqDriver,
14
- OpenRouterDriver,
15
12
  GrokDriver,
16
- OllamaDriver,
13
+ GroqDriver,
17
14
  LMStudioDriver,
18
15
  LocalHTTPDriver,
16
+ OllamaDriver,
17
+ OpenAIDriver,
18
+ OpenRouterDriver,
19
19
  )
20
20
  from .settings import settings
21
21
 
22
22
  logger = logging.getLogger(__name__)
23
23
 
24
- def get_available_models() -> List[str]:
24
+
25
+ def get_available_models() -> list[str]:
25
26
  """
26
27
  Auto-detects all available models based on configured drivers and environment variables.
27
-
28
+
28
29
  Iterates through supported providers and checks if they are configured (e.g. API key present).
29
30
  For static drivers, returns models from their MODEL_PRICING keys.
30
31
  For dynamic drivers (like Ollama), attempts to fetch available models from the endpoint.
31
-
32
+
32
33
  Returns:
33
34
  A list of unique model strings in the format "provider/model_id".
34
35
  """
35
- available_models: Set[str] = set()
36
-
36
+ available_models: set[str] = set()
37
+ configured_providers: set[str] = set()
38
+
37
39
  # Map of provider name to driver class
38
40
  # We need to map the registry keys to the actual classes to check MODEL_PRICING
39
41
  # and instantiate for dynamic checks if needed.
@@ -56,16 +58,18 @@ def get_available_models() -> List[str]:
56
58
  # We can check this by looking at the settings or env vars that the driver uses.
57
59
  # A simple way is to try to instantiate it with defaults, but that might fail if keys are missing.
58
60
  # Instead, let's check the specific requirements for each known provider.
59
-
61
+
60
62
  is_configured = False
61
-
63
+
62
64
  if provider == "openai":
63
65
  if settings.openai_api_key or os.getenv("OPENAI_API_KEY"):
64
66
  is_configured = True
65
67
  elif provider == "azure":
66
- if (settings.azure_api_key or os.getenv("AZURE_API_KEY")) and \
67
- (settings.azure_api_endpoint or os.getenv("AZURE_API_ENDPOINT")) and \
68
- (settings.azure_deployment_id or os.getenv("AZURE_DEPLOYMENT_ID")):
68
+ if (
69
+ (settings.azure_api_key or os.getenv("AZURE_API_KEY"))
70
+ and (settings.azure_api_endpoint or os.getenv("AZURE_API_ENDPOINT"))
71
+ and (settings.azure_deployment_id or os.getenv("AZURE_DEPLOYMENT_ID"))
72
+ ):
69
73
  is_configured = True
70
74
  elif provider == "claude":
71
75
  if settings.claude_api_key or os.getenv("CLAUDE_API_KEY"):
@@ -87,46 +91,49 @@ def get_available_models() -> List[str]:
87
91
  # We will check connectivity later
88
92
  is_configured = True
89
93
  elif provider == "lmstudio":
90
- # LM Studio is similar to Ollama, defaults to localhost
91
- is_configured = True
92
- elif provider == "local_http":
93
- if settings.local_http_endpoint or os.getenv("LOCAL_HTTP_ENDPOINT"):
94
- is_configured = True
94
+ # LM Studio is similar to Ollama, defaults to localhost
95
+ is_configured = True
96
+ elif provider == "local_http" and (settings.local_http_endpoint or os.getenv("LOCAL_HTTP_ENDPOINT")):
97
+ is_configured = True
95
98
 
96
99
  if not is_configured:
97
100
  continue
98
101
 
102
+ configured_providers.add(provider)
103
+
99
104
  # 2. Static Detection: Get models from MODEL_PRICING
100
105
  if hasattr(driver_cls, "MODEL_PRICING"):
101
106
  pricing = driver_cls.MODEL_PRICING
102
- for model_id in pricing.keys():
107
+ for model_id in pricing:
103
108
  # Skip "default" or generic keys if they exist
104
109
  if model_id == "default":
105
110
  continue
106
-
107
- # For Azure, the model_id in pricing is usually the base model name,
108
- # but the user needs to use the deployment ID.
109
- # However, our Azure driver implementation uses the deployment_id from init
110
- # as the "model" for the request, but expects the user to pass a model name
111
- # that maps to pricing?
111
+
112
+ # For Azure, the model_id in pricing is usually the base model name,
113
+ # but the user needs to use the deployment ID.
114
+ # However, our Azure driver implementation uses the deployment_id from init
115
+ # as the "model" for the request, but expects the user to pass a model name
116
+ # that maps to pricing?
112
117
  # Looking at AzureDriver:
113
118
  # kwargs = {"model": self.deployment_id, ...}
114
119
  # model = options.get("model", self.model) -> used for pricing lookup
115
- # So we should list the keys in MODEL_PRICING as available "models"
120
+ # So we should list the keys in MODEL_PRICING as available "models"
116
121
  # even though for Azure specifically it's a bit weird because of deployment IDs.
117
122
  # But for general discovery, listing supported models is correct.
118
-
123
+
119
124
  available_models.add(f"{provider}/{model_id}")
120
125
 
121
126
  # 3. Dynamic Detection: Specific logic for Ollama
122
127
  if provider == "ollama":
123
128
  try:
124
- endpoint = settings.ollama_endpoint or os.getenv("OLLAMA_ENDPOINT", "http://localhost:11434/api/generate")
129
+ endpoint = settings.ollama_endpoint or os.getenv(
130
+ "OLLAMA_ENDPOINT", "http://localhost:11434/api/generate"
131
+ )
125
132
  # We need the base URL for tags, usually http://localhost:11434/api/tags
126
133
  # The configured endpoint might be .../api/generate or .../api/chat
127
134
  base_url = endpoint.split("/api/")[0]
128
135
  tags_url = f"{base_url}/api/tags"
129
-
136
+
130
137
  resp = requests.get(tags_url, timeout=2)
131
138
  if resp.status_code == 200:
132
139
  data = resp.json()
@@ -139,11 +146,19 @@ def get_available_models() -> List[str]:
139
146
  available_models.add(f"ollama/{name}")
140
147
  except Exception as e:
141
148
  logger.debug(f"Failed to fetch Ollama models: {e}")
142
-
149
+
143
150
  # Future: Add dynamic detection for LM Studio if they have an endpoint for listing models
144
-
151
+
145
152
  except Exception as e:
146
153
  logger.warning(f"Error detecting models for provider {provider}: {e}")
147
154
  continue
148
155
 
156
+ # Enrich with live model list from models.dev cache
157
+ from .model_rates import PROVIDER_MAP, get_all_provider_models
158
+
159
+ for prompture_name, api_name in PROVIDER_MAP.items():
160
+ if prompture_name in configured_providers:
161
+ for model_id in get_all_provider_models(api_name):
162
+ available_models.add(f"{prompture_name}/{model_id}")
163
+
149
164
  return sorted(list(available_models))
prompture/driver.py CHANGED
@@ -1,7 +1,15 @@
1
- """Driver base class for LLM adapters.
2
- """
1
+ """Driver base class for LLM adapters."""
2
+
3
3
  from __future__ import annotations
4
- from typing import Any, Dict
4
+
5
+ import logging
6
+ import time
7
+ from typing import Any
8
+
9
+ from .callbacks import DriverCallbacks
10
+
11
+ logger = logging.getLogger("prompture.driver")
12
+
5
13
 
6
14
  class Driver:
7
15
  """Adapter base. Implementar generate(prompt, options) -> {"text": ... , "meta": {...}}
@@ -20,5 +28,117 @@ class Driver:
20
28
  additional provider-specific metadata while the core fields provide
21
29
  standardized access to token usage and cost information.
22
30
  """
23
- def generate(self, prompt: str, options: Dict[str,Any]) -> Dict[str,Any]:
24
- raise NotImplementedError
31
+
32
    # Capability flags — subclasses flip these to advertise native support
    # (presumably JSON-mode output, JSON-schema output, and message arrays;
    # supports_messages is documented in generate_messages below — confirm
    # the JSON flags against the concrete drivers).
    supports_json_mode: bool = False
    supports_json_schema: bool = False
    supports_messages: bool = False

    # Optional observer hooks; read by _fire_callback, which the
    # *_with_hooks wrappers invoke around generate()/generate_messages().
    callbacks: DriverCallbacks | None = None
37
+
38
    def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
        """Produce a completion for *prompt*.

        Must be implemented by concrete drivers; per the class docstring the
        return value is a dict with ``"text"`` and ``"meta"`` keys.
        """
        raise NotImplementedError
40
+
41
+ def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
42
+ """Generate a response from a list of conversation messages.
43
+
44
+ Each message is a dict with ``"role"`` (``"system"``, ``"user"``, or
45
+ ``"assistant"``) and ``"content"`` keys.
46
+
47
+ The default implementation flattens the messages into a single prompt
48
+ string and delegates to :meth:`generate`. Drivers that natively
49
+ support message arrays should override this method and set
50
+ ``supports_messages = True``.
51
+ """
52
+ prompt = self._flatten_messages(messages)
53
+ return self.generate(prompt, options)
54
+
55
+ # ------------------------------------------------------------------
56
+ # Hook-aware wrappers
57
+ # ------------------------------------------------------------------
58
+
59
+ def generate_with_hooks(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
60
+ """Wrap :meth:`generate` with on_request / on_response / on_error callbacks."""
61
+ driver_name = getattr(self, "model", self.__class__.__name__)
62
+ self._fire_callback(
63
+ "on_request",
64
+ {"prompt": prompt, "messages": None, "options": options, "driver": driver_name},
65
+ )
66
+ t0 = time.perf_counter()
67
+ try:
68
+ resp = self.generate(prompt, options)
69
+ except Exception as exc:
70
+ self._fire_callback(
71
+ "on_error",
72
+ {"error": exc, "prompt": prompt, "messages": None, "options": options, "driver": driver_name},
73
+ )
74
+ raise
75
+ elapsed_ms = (time.perf_counter() - t0) * 1000
76
+ self._fire_callback(
77
+ "on_response",
78
+ {
79
+ "text": resp.get("text", ""),
80
+ "meta": resp.get("meta", {}),
81
+ "driver": driver_name,
82
+ "elapsed_ms": elapsed_ms,
83
+ },
84
+ )
85
+ return resp
86
+
87
+ def generate_messages_with_hooks(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
88
+ """Wrap :meth:`generate_messages` with callbacks."""
89
+ driver_name = getattr(self, "model", self.__class__.__name__)
90
+ self._fire_callback(
91
+ "on_request",
92
+ {"prompt": None, "messages": messages, "options": options, "driver": driver_name},
93
+ )
94
+ t0 = time.perf_counter()
95
+ try:
96
+ resp = self.generate_messages(messages, options)
97
+ except Exception as exc:
98
+ self._fire_callback(
99
+ "on_error",
100
+ {"error": exc, "prompt": None, "messages": messages, "options": options, "driver": driver_name},
101
+ )
102
+ raise
103
+ elapsed_ms = (time.perf_counter() - t0) * 1000
104
+ self._fire_callback(
105
+ "on_response",
106
+ {
107
+ "text": resp.get("text", ""),
108
+ "meta": resp.get("meta", {}),
109
+ "driver": driver_name,
110
+ "elapsed_ms": elapsed_ms,
111
+ },
112
+ )
113
+ return resp
114
+
115
+ # ------------------------------------------------------------------
116
+ # Internal helpers
117
+ # ------------------------------------------------------------------
118
+
119
+ def _fire_callback(self, event: str, payload: dict[str, Any]) -> None:
120
+ """Invoke a single callback, swallowing and logging any exception."""
121
+ if self.callbacks is None:
122
+ return
123
+ cb = getattr(self.callbacks, event, None)
124
+ if cb is None:
125
+ return
126
+ try:
127
+ cb(payload)
128
+ except Exception:
129
+ logger.exception("Callback %s raised an exception", event)
130
+
131
+ @staticmethod
132
+ def _flatten_messages(messages: list[dict[str, str]]) -> str:
133
+ """Join messages into a single prompt string with role prefixes."""
134
+ parts: list[str] = []
135
+ for msg in messages:
136
+ role = msg.get("role", "user")
137
+ content = msg.get("content", "")
138
+ if role == "system":
139
+ parts.append(f"[System]: {content}")
140
+ elif role == "assistant":
141
+ parts.append(f"[Assistant]: {content}")
142
+ else:
143
+ parts.append(f"[User]: {content}")
144
+ return "\n\n".join(parts)
@@ -1,76 +1,145 @@
1
- from .openai_driver import OpenAIDriver
2
- from .local_http_driver import LocalHTTPDriver
3
- from .ollama_driver import OllamaDriver
4
- from .claude_driver import ClaudeDriver
1
+ """Driver registry and factory functions.
2
+
3
+ This module provides:
4
+ - Built-in drivers for popular LLM providers
5
+ - A pluggable registry system for custom drivers
6
+ - Factory functions to instantiate drivers by provider/model name
7
+
8
+ Custom Driver Registration:
9
+ from prompture import register_driver
10
+
11
+ def my_driver_factory(model=None):
12
+ return MyCustomDriver(model=model)
13
+
14
+ register_driver("my_provider", my_driver_factory)
15
+
16
+ # Now you can use it
17
+ driver = get_driver_for_model("my_provider/my-model")
18
+
19
+ Entry Point Discovery:
20
+ Third-party packages can register drivers via entry points.
21
+ Add to your pyproject.toml:
22
+
23
+ [project.entry-points."prompture.drivers"]
24
+ my_provider = "my_package.drivers:my_driver_factory"
25
+ """
26
+
27
+ from typing import Optional
28
+
29
+ from ..settings import settings
30
+ from .airllm_driver import AirLLMDriver
31
+ from .async_airllm_driver import AsyncAirLLMDriver
32
+ from .async_azure_driver import AsyncAzureDriver
33
+ from .async_claude_driver import AsyncClaudeDriver
34
+ from .async_google_driver import AsyncGoogleDriver
35
+ from .async_grok_driver import AsyncGrokDriver
36
+ from .async_groq_driver import AsyncGroqDriver
37
+ from .async_hugging_driver import AsyncHuggingFaceDriver
38
+ from .async_lmstudio_driver import AsyncLMStudioDriver
39
+ from .async_local_http_driver import AsyncLocalHTTPDriver
40
+ from .async_ollama_driver import AsyncOllamaDriver
41
+ from .async_openai_driver import AsyncOpenAIDriver
42
+ from .async_openrouter_driver import AsyncOpenRouterDriver
43
+ from .async_registry import ASYNC_DRIVER_REGISTRY, get_async_driver, get_async_driver_for_model
5
44
  from .azure_driver import AzureDriver
6
- from .lmstudio_driver import LMStudioDriver
45
+ from .claude_driver import ClaudeDriver
7
46
  from .google_driver import GoogleDriver
47
+ from .grok_driver import GrokDriver
8
48
  from .groq_driver import GroqDriver
49
+ from .lmstudio_driver import LMStudioDriver
50
+ from .local_http_driver import LocalHTTPDriver
51
+ from .ollama_driver import OllamaDriver
52
+ from .openai_driver import OpenAIDriver
9
53
  from .openrouter_driver import OpenRouterDriver
10
- from .grok_driver import GrokDriver
11
- from .airllm_driver import AirLLMDriver
12
- from ..settings import settings
54
+ from .registry import (
55
+ _get_sync_registry,
56
+ get_async_driver_factory,
57
+ get_driver_factory,
58
+ is_async_driver_registered,
59
+ is_driver_registered,
60
+ list_registered_async_drivers,
61
+ list_registered_drivers,
62
+ load_entry_point_drivers,
63
+ register_async_driver,
64
+ register_driver,
65
+ unregister_async_driver,
66
+ unregister_driver,
67
+ )
13
68
 
14
-
15
- # Central registry: maps provider → factory function
16
- DRIVER_REGISTRY = {
17
- "openai": lambda model=None: OpenAIDriver(
18
- api_key=settings.openai_api_key,
19
- model=model or settings.openai_model
20
- ),
21
- "ollama": lambda model=None: OllamaDriver(
22
- endpoint=settings.ollama_endpoint,
23
- model=model or settings.ollama_model
24
- ),
25
- "claude": lambda model=None: ClaudeDriver(
26
- api_key=settings.claude_api_key,
27
- model=model or settings.claude_model
69
# Built-in synchronous driver factories, keyed by provider name. Each factory
# accepts an optional model override and falls back to the configured default
# from ``settings``.
_BUILTIN_SYNC_FACTORIES = {
    "openai": lambda model=None: OpenAIDriver(api_key=settings.openai_api_key, model=model or settings.openai_model),
    "ollama": lambda model=None: OllamaDriver(endpoint=settings.ollama_endpoint, model=model or settings.ollama_model),
    "claude": lambda model=None: ClaudeDriver(api_key=settings.claude_api_key, model=model or settings.claude_model),
    "lmstudio": lambda model=None: LMStudioDriver(
        endpoint=settings.lmstudio_endpoint, model=model or settings.lmstudio_model
    ),
    "azure": lambda model=None: AzureDriver(
        api_key=settings.azure_api_key,
        endpoint=settings.azure_api_endpoint,
        deployment_id=settings.azure_deployment_id,
    ),
    "local_http": lambda model=None: LocalHTTPDriver(endpoint=settings.local_http_endpoint, model=model),
    "google": lambda model=None: GoogleDriver(api_key=settings.google_api_key, model=model or settings.google_model),
    "groq": lambda model=None: GroqDriver(api_key=settings.groq_api_key, model=model or settings.groq_model),
    "openrouter": lambda model=None: OpenRouterDriver(
        api_key=settings.openrouter_api_key, model=model or settings.openrouter_model
    ),
    "grok": lambda model=None: GrokDriver(api_key=settings.grok_api_key, model=model or settings.grok_model),
    "airllm": lambda model=None: AirLLMDriver(
        model=model or settings.airllm_model,
        compression=settings.airllm_compression,
    ),
}

# Register built-in sync drivers. overwrite=True so repeated imports/reloads
# re-register cleanly instead of erroring on duplicates.
for _provider, _factory in _BUILTIN_SYNC_FACTORIES.items():
    register_driver(_provider, _factory, overwrite=True)

# Backwards compatibility: expose registry dict (read-only view recommended)
DRIVER_REGISTRY = _get_sync_registry()
133
 
65
def get_driver(provider_name: Optional[str] = None):
    """Return a driver instance for *provider_name* (legacy style).

    Falls back to ``settings.ai_provider`` and finally ``"ollama"`` when no
    provider is given. The driver is built with its default model from
    settings.
    """
    name = provider_name or settings.ai_provider or "ollama"
    factory = get_driver_factory(name.strip().lower())
    return factory()  # use default model from settings
74
143
 
75
144
 
76
145
def get_driver_for_model(model_str: str):
    """
    Factory to get a driver instance based on a full model string.
    Format: provider/model_id
    Example: "openai/gpt-4-turbo-preview"

    Args:
        model_str: Model identifier string. Can be either:
            - Full format: "provider/model" (e.g. "openai/gpt-4")
            - Provider only: "provider" (e.g. "openai")

    Returns:
        A configured driver instance for the specified provider/model.

    Raises:
        ValueError: If provider is invalid or format is incorrect.
    """
    if not isinstance(model_str, str):
        # BUG FIX: this message was a plain string, so "{type(model_str)}"
        # was emitted literally; it must be an f-string to interpolate.
        raise ValueError(f"Model string must be a string, got {type(model_str)}")

    if not model_str:
        raise ValueError("Model string cannot be empty")

    # Split on the first "/" only, so model ids that themselves contain
    # slashes (e.g. openrouter model paths) stay intact.
    parts = model_str.split("/", 1)
    provider = parts[0].lower()
    model_id = parts[1] if len(parts) > 1 else None

    # Get factory (validates provider exists)
    factory = get_driver_factory(provider)

    # Create driver with model ID if provided, otherwise use default
    return factory(model_id)
110
178
 
111
179
 
112
180
# Public names exported by ``from prompture.drivers import *``.
__all__ = [
    "ASYNC_DRIVER_REGISTRY",
    # Legacy registry dicts (for backwards compatibility)
    "DRIVER_REGISTRY",
    # Sync drivers
    "AirLLMDriver",
    # Async drivers
    "AsyncAirLLMDriver",
    "AsyncAzureDriver",
    "AsyncClaudeDriver",
    "AsyncGoogleDriver",
    "AsyncGrokDriver",
    "AsyncGroqDriver",
    "AsyncHuggingFaceDriver",
    "AsyncLMStudioDriver",
    "AsyncLocalHTTPDriver",
    "AsyncOllamaDriver",
    "AsyncOpenAIDriver",
    "AsyncOpenRouterDriver",
    "AzureDriver",
    "ClaudeDriver",
    "GoogleDriver",
    "GrokDriver",
    "GroqDriver",
    "LMStudioDriver",
    "LocalHTTPDriver",
    "OllamaDriver",
    "OpenAIDriver",
    "OpenRouterDriver",
    "get_async_driver",
    "get_async_driver_for_model",
    # Factory functions
    "get_driver",
    "get_driver_for_model",
    "is_async_driver_registered",
    "is_driver_registered",
    "list_registered_async_drivers",
    "list_registered_drivers",
    "load_entry_point_drivers",
    "register_async_driver",
    # Registry functions (public API)
    "register_driver",
    "unregister_async_driver",
    "unregister_driver",
]