prompture 0.0.29.dev8__py3-none-any.whl → 0.0.35__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (67)
  1. prompture/__init__.py +146 -23
  2. prompture/_version.py +34 -0
  3. prompture/aio/__init__.py +74 -0
  4. prompture/async_conversation.py +607 -0
  5. prompture/async_core.py +803 -0
  6. prompture/async_driver.py +169 -0
  7. prompture/cache.py +469 -0
  8. prompture/callbacks.py +55 -0
  9. prompture/cli.py +63 -4
  10. prompture/conversation.py +631 -0
  11. prompture/core.py +876 -263
  12. prompture/cost_mixin.py +51 -0
  13. prompture/discovery.py +164 -0
  14. prompture/driver.py +168 -5
  15. prompture/drivers/__init__.py +173 -69
  16. prompture/drivers/airllm_driver.py +109 -0
  17. prompture/drivers/async_airllm_driver.py +26 -0
  18. prompture/drivers/async_azure_driver.py +117 -0
  19. prompture/drivers/async_claude_driver.py +107 -0
  20. prompture/drivers/async_google_driver.py +132 -0
  21. prompture/drivers/async_grok_driver.py +91 -0
  22. prompture/drivers/async_groq_driver.py +84 -0
  23. prompture/drivers/async_hugging_driver.py +61 -0
  24. prompture/drivers/async_lmstudio_driver.py +79 -0
  25. prompture/drivers/async_local_http_driver.py +44 -0
  26. prompture/drivers/async_ollama_driver.py +125 -0
  27. prompture/drivers/async_openai_driver.py +96 -0
  28. prompture/drivers/async_openrouter_driver.py +96 -0
  29. prompture/drivers/async_registry.py +129 -0
  30. prompture/drivers/azure_driver.py +36 -9
  31. prompture/drivers/claude_driver.py +251 -34
  32. prompture/drivers/google_driver.py +107 -38
  33. prompture/drivers/grok_driver.py +29 -32
  34. prompture/drivers/groq_driver.py +27 -26
  35. prompture/drivers/hugging_driver.py +6 -6
  36. prompture/drivers/lmstudio_driver.py +26 -13
  37. prompture/drivers/local_http_driver.py +6 -6
  38. prompture/drivers/ollama_driver.py +157 -23
  39. prompture/drivers/openai_driver.py +178 -9
  40. prompture/drivers/openrouter_driver.py +31 -25
  41. prompture/drivers/registry.py +306 -0
  42. prompture/field_definitions.py +106 -96
  43. prompture/logging.py +80 -0
  44. prompture/model_rates.py +217 -0
  45. prompture/runner.py +49 -47
  46. prompture/scaffold/__init__.py +1 -0
  47. prompture/scaffold/generator.py +84 -0
  48. prompture/scaffold/templates/Dockerfile.j2 +12 -0
  49. prompture/scaffold/templates/README.md.j2 +41 -0
  50. prompture/scaffold/templates/config.py.j2 +21 -0
  51. prompture/scaffold/templates/env.example.j2 +8 -0
  52. prompture/scaffold/templates/main.py.j2 +86 -0
  53. prompture/scaffold/templates/models.py.j2 +40 -0
  54. prompture/scaffold/templates/requirements.txt.j2 +5 -0
  55. prompture/server.py +183 -0
  56. prompture/session.py +117 -0
  57. prompture/settings.py +18 -1
  58. prompture/tools.py +219 -267
  59. prompture/tools_schema.py +254 -0
  60. prompture/validator.py +3 -3
  61. {prompture-0.0.29.dev8.dist-info → prompture-0.0.35.dist-info}/METADATA +117 -21
  62. prompture-0.0.35.dist-info/RECORD +66 -0
  63. {prompture-0.0.29.dev8.dist-info → prompture-0.0.35.dist-info}/WHEEL +1 -1
  64. prompture-0.0.29.dev8.dist-info/RECORD +0 -27
  65. {prompture-0.0.29.dev8.dist-info → prompture-0.0.35.dist-info}/entry_points.txt +0 -0
  66. {prompture-0.0.29.dev8.dist-info → prompture-0.0.35.dist-info}/licenses/LICENSE +0 -0
  67. {prompture-0.0.29.dev8.dist-info → prompture-0.0.35.dist-info}/top_level.txt +0 -0
prompture/cost_mixin.py ADDED
@@ -0,0 +1,51 @@
+"""Shared cost-calculation mixin for LLM drivers."""
+
+from __future__ import annotations
+
+from typing import Any
+
+
+class CostMixin:
+    """Mixin that provides ``_calculate_cost`` to sync and async drivers.
+
+    Drivers that charge per-token should inherit from this mixin alongside
+    their base class (``Driver`` or ``AsyncDriver``). Free/local drivers
+    (Ollama, LM Studio, LocalHTTP, HuggingFace, AirLLM) can skip it.
+    """
+
+    # Subclasses must define MODEL_PRICING as a class attribute.
+    MODEL_PRICING: dict[str, dict[str, Any]] = {}
+
+    # Divisor for hardcoded MODEL_PRICING rates.
+    # Most drivers use per-1K-token pricing (1_000).
+    # Grok uses per-1M-token pricing (1_000_000).
+    # Google uses per-1M-character pricing (1_000_000).
+    _PRICING_UNIT: int = 1_000
+
+    def _calculate_cost(
+        self,
+        provider: str,
+        model: str,
+        prompt_tokens: int | float,
+        completion_tokens: int | float,
+    ) -> float:
+        """Calculate USD cost for a generation call.
+
+        Resolution order:
+        1. Live rates from ``model_rates.get_model_rates()`` (per 1M tokens).
+        2. Hardcoded ``MODEL_PRICING`` on the driver class (unit set by ``_PRICING_UNIT``).
+        3. Zero if neither source has data.
+        """
+        from .model_rates import get_model_rates
+
+        live_rates = get_model_rates(provider, model)
+        if live_rates:
+            prompt_cost = (prompt_tokens / 1_000_000) * live_rates["input"]
+            completion_cost = (completion_tokens / 1_000_000) * live_rates["output"]
+        else:
+            unit = self._PRICING_UNIT
+            model_pricing = self.MODEL_PRICING.get(model, {"prompt": 0, "completion": 0})
+            prompt_cost = (prompt_tokens / unit) * model_pricing["prompt"]
+            completion_cost = (completion_tokens / unit) * model_pricing["completion"]
+
+        return round(prompt_cost + completion_cost, 6)
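
A driver opts in by listing CostMixin ahead of its base class and defining a pricing table. A minimal sketch of the pattern (ToyDriver, its rates, and the token counts are invented for illustration; real drivers ship their own MODEL_PRICING tables):

    from prompture.cost_mixin import CostMixin
    from prompture.driver import Driver

    class ToyDriver(CostMixin, Driver):
        # Hypothetical per-1K-token rates (the default _PRICING_UNIT).
        MODEL_PRICING = {"toy-1": {"prompt": 0.0005, "completion": 0.0015}}

        def generate(self, prompt, options):
            # A real driver would call its provider API here.
            usage = {"prompt_tokens": 120, "completion_tokens": 80}
            cost = self._calculate_cost("toy", "toy-1", **usage)
            return {"text": "...", "meta": {**usage, "cost": cost}}

With no live rates available for a made-up provider, the fallback path yields (120/1000)*0.0005 + (80/1000)*0.0015 = 0.00018 USD.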
prompture/discovery.py ADDED
@@ -0,0 +1,164 @@
+"""Discovery module for auto-detecting available models."""
+
+import logging
+import os
+
+import requests
+
+from .drivers import (
+    AzureDriver,
+    ClaudeDriver,
+    GoogleDriver,
+    GrokDriver,
+    GroqDriver,
+    LMStudioDriver,
+    LocalHTTPDriver,
+    OllamaDriver,
+    OpenAIDriver,
+    OpenRouterDriver,
+)
+from .settings import settings
+
+logger = logging.getLogger(__name__)
+
+
+def get_available_models() -> list[str]:
+    """
+    Auto-detects all available models based on configured drivers and environment variables.
+
+    Iterates through supported providers and checks if they are configured (e.g. API key present).
+    For static drivers, returns models from their MODEL_PRICING keys.
+    For dynamic drivers (like Ollama), attempts to fetch available models from the endpoint.
+
+    Returns:
+        A list of unique model strings in the format "provider/model_id".
+    """
+    available_models: set[str] = set()
+    configured_providers: set[str] = set()
+
+    # Map of provider name to driver class.
+    # We need to map the registry keys to the actual classes to check MODEL_PRICING
+    # and instantiate for dynamic checks if needed.
+    provider_classes = {
+        "openai": OpenAIDriver,
+        "azure": AzureDriver,
+        "claude": ClaudeDriver,
+        "google": GoogleDriver,
+        "groq": GroqDriver,
+        "openrouter": OpenRouterDriver,
+        "grok": GrokDriver,
+        "ollama": OllamaDriver,
+        "lmstudio": LMStudioDriver,
+        "local_http": LocalHTTPDriver,
+    }
+
+    for provider, driver_cls in provider_classes.items():
+        try:
+            # 1. Check if the provider is configured (has API key or endpoint).
+            # We can check this by looking at the settings or env vars that the driver uses.
+            # Instantiating with defaults might fail if keys are missing, so instead
+            # check the specific requirements for each known provider.
+
+            is_configured = False
+
+            if provider == "openai":
+                if settings.openai_api_key or os.getenv("OPENAI_API_KEY"):
+                    is_configured = True
+            elif provider == "azure":
+                if (
+                    (settings.azure_api_key or os.getenv("AZURE_API_KEY"))
+                    and (settings.azure_api_endpoint or os.getenv("AZURE_API_ENDPOINT"))
+                    and (settings.azure_deployment_id or os.getenv("AZURE_DEPLOYMENT_ID"))
+                ):
+                    is_configured = True
+            elif provider == "claude":
+                if settings.claude_api_key or os.getenv("CLAUDE_API_KEY"):
+                    is_configured = True
+            elif provider == "google":
+                if settings.google_api_key or os.getenv("GOOGLE_API_KEY"):
+                    is_configured = True
+            elif provider == "groq":
+                if settings.groq_api_key or os.getenv("GROQ_API_KEY"):
+                    is_configured = True
+            elif provider == "openrouter":
+                if settings.openrouter_api_key or os.getenv("OPENROUTER_API_KEY"):
+                    is_configured = True
+            elif provider == "grok":
+                if settings.grok_api_key or os.getenv("GROK_API_KEY"):
+                    is_configured = True
+            elif provider == "ollama":
+                # Ollama is always considered "configured" as it defaults to localhost.
+                # Connectivity is checked later.
+                is_configured = True
+            elif provider == "lmstudio":
+                # LM Studio is similar to Ollama, defaults to localhost
+                is_configured = True
+            elif provider == "local_http" and (settings.local_http_endpoint or os.getenv("LOCAL_HTTP_ENDPOINT")):
+                is_configured = True
+
+            if not is_configured:
+                continue
+
+            configured_providers.add(provider)
+
+            # 2. Static Detection: Get models from MODEL_PRICING
+            if hasattr(driver_cls, "MODEL_PRICING"):
+                pricing = driver_cls.MODEL_PRICING
+                for model_id in pricing:
+                    # Skip "default" or generic keys if they exist
+                    if model_id == "default":
+                        continue
+
+                    # For Azure, the model_id in pricing is usually the base model name,
+                    # while the request itself uses the deployment ID. Looking at AzureDriver:
+                    #   kwargs = {"model": self.deployment_id, ...}
+                    #   model = options.get("model", self.model) -> used for pricing lookup
+                    # So the keys in MODEL_PRICING are listed as available "models";
+                    # this is slightly awkward for Azure because of deployment IDs,
+                    # but for general discovery, listing supported models is correct.
+
+                    available_models.add(f"{provider}/{model_id}")
+
+            # 3. Dynamic Detection: Specific logic for Ollama
+            if provider == "ollama":
+                try:
+                    endpoint = settings.ollama_endpoint or os.getenv(
+                        "OLLAMA_ENDPOINT", "http://localhost:11434/api/generate"
+                    )
+                    # We need the base URL for tags, usually http://localhost:11434/api/tags
+                    # The configured endpoint might be .../api/generate or .../api/chat
+                    base_url = endpoint.split("/api/")[0]
+                    tags_url = f"{base_url}/api/tags"
+
+                    resp = requests.get(tags_url, timeout=2)
+                    if resp.status_code == 200:
+                        data = resp.json()
+                        models = data.get("models", [])
+                        for model in models:
+                            name = model.get("name")
+                            if name:
+                                # Ollama model names often include tags like "llama3:latest";
+                                # keep them as is.
+                                available_models.add(f"ollama/{name}")
+                except Exception as e:
+                    logger.debug(f"Failed to fetch Ollama models: {e}")
+
+            # Future: Add dynamic detection for LM Studio if they have an endpoint for listing models
+
+        except Exception as e:
+            logger.warning(f"Error detecting models for provider {provider}: {e}")
+            continue
+
+    # Enrich with live model list from models.dev cache
+    from .model_rates import PROVIDER_MAP, get_all_provider_models
+
+    for prompture_name, api_name in PROVIDER_MAP.items():
+        if prompture_name in configured_providers:
+            for model_id in get_all_provider_models(api_name):
+                available_models.add(f"{prompture_name}/{model_id}")
+
+    return sorted(list(available_models))
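
The discovery output plugs directly into the driver factories; a short sketch (the model names shown are illustrative and depend on what is configured locally):

    from prompture.discovery import get_available_models
    from prompture.drivers import get_driver_for_model

    models = get_available_models()   # e.g. ["ollama/llama3:latest", "openai/gpt-4o", ...]
    if models:
        driver = get_driver_for_model(models[0])  # "provider/model_id" resolves via the registry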
prompture/driver.py CHANGED
@@ -1,7 +1,16 @@
-"""Driver base class for LLM adapters.
-"""
+"""Driver base class for LLM adapters."""
+
 from __future__ import annotations
-from typing import Any, Dict
+
+import logging
+import time
+from collections.abc import Iterator
+from typing import Any
+
+from .callbacks import DriverCallbacks
+
+logger = logging.getLogger("prompture.driver")
+
 
 class Driver:
     """Adapter base. Implement generate(prompt, options) -> {"text": ... , "meta": {...}}
@@ -20,5 +29,159 @@ class Driver:
     additional provider-specific metadata while the core fields provide
     standardized access to token usage and cost information.
     """
-    def generate(self, prompt: str, options: Dict[str,Any]) -> Dict[str,Any]:
-        raise NotImplementedError
+
+    supports_json_mode: bool = False
+    supports_json_schema: bool = False
+    supports_messages: bool = False
+    supports_tool_use: bool = False
+    supports_streaming: bool = False
+
+    callbacks: DriverCallbacks | None = None
+
+    def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
+        raise NotImplementedError
+
+    def generate_messages(self, messages: list[dict[str, Any]], options: dict[str, Any]) -> dict[str, Any]:
+        """Generate a response from a list of conversation messages.
+
+        Each message is a dict with ``"role"`` (``"system"``, ``"user"``, or
+        ``"assistant"``) and ``"content"`` keys.
+
+        The default implementation flattens the messages into a single prompt
+        string and delegates to :meth:`generate`. Drivers that natively
+        support message arrays should override this method and set
+        ``supports_messages = True``.
+        """
+        prompt = self._flatten_messages(messages)
+        return self.generate(prompt, options)
+
+    # ------------------------------------------------------------------
+    # Tool use
+    # ------------------------------------------------------------------
+
+    def generate_messages_with_tools(
+        self,
+        messages: list[dict[str, Any]],
+        tools: list[dict[str, Any]],
+        options: dict[str, Any],
+    ) -> dict[str, Any]:
+        """Generate a response that may include tool calls.
+
+        Returns a dict with keys: ``text``, ``meta``, ``tool_calls``, ``stop_reason``.
+        ``tool_calls`` is a list of ``{"id": str, "name": str, "arguments": dict}``.
+
+        Drivers that support tool use should override this method and set
+        ``supports_tool_use = True``.
+        """
+        raise NotImplementedError(f"{self.__class__.__name__} does not support tool use")
+
+    # ------------------------------------------------------------------
+    # Streaming
+    # ------------------------------------------------------------------
+
+    def generate_messages_stream(
+        self,
+        messages: list[dict[str, Any]],
+        options: dict[str, Any],
+    ) -> Iterator[dict[str, Any]]:
+        """Yield response chunks incrementally.
+
+        Each chunk is a dict:
+        - ``{"type": "delta", "text": str}`` for content fragments
+        - ``{"type": "done", "text": str, "meta": dict}`` for the final summary
+
+        Drivers that support streaming should override this method and set
+        ``supports_streaming = True``.
+        """
+        raise NotImplementedError(f"{self.__class__.__name__} does not support streaming")
+
+    # ------------------------------------------------------------------
+    # Hook-aware wrappers
+    # ------------------------------------------------------------------
+
+    def generate_with_hooks(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
+        """Wrap :meth:`generate` with on_request / on_response / on_error callbacks."""
+        driver_name = getattr(self, "model", self.__class__.__name__)
+        self._fire_callback(
+            "on_request",
+            {"prompt": prompt, "messages": None, "options": options, "driver": driver_name},
+        )
+        t0 = time.perf_counter()
+        try:
+            resp = self.generate(prompt, options)
+        except Exception as exc:
+            self._fire_callback(
+                "on_error",
+                {"error": exc, "prompt": prompt, "messages": None, "options": options, "driver": driver_name},
+            )
+            raise
+        elapsed_ms = (time.perf_counter() - t0) * 1000
+        self._fire_callback(
+            "on_response",
+            {
+                "text": resp.get("text", ""),
+                "meta": resp.get("meta", {}),
+                "driver": driver_name,
+                "elapsed_ms": elapsed_ms,
+            },
+        )
+        return resp
+
+    def generate_messages_with_hooks(self, messages: list[dict[str, Any]], options: dict[str, Any]) -> dict[str, Any]:
+        """Wrap :meth:`generate_messages` with callbacks."""
+        driver_name = getattr(self, "model", self.__class__.__name__)
+        self._fire_callback(
+            "on_request",
+            {"prompt": None, "messages": messages, "options": options, "driver": driver_name},
+        )
+        t0 = time.perf_counter()
+        try:
+            resp = self.generate_messages(messages, options)
+        except Exception as exc:
+            self._fire_callback(
+                "on_error",
+                {"error": exc, "prompt": None, "messages": messages, "options": options, "driver": driver_name},
+            )
+            raise
+        elapsed_ms = (time.perf_counter() - t0) * 1000
+        self._fire_callback(
+            "on_response",
+            {
+                "text": resp.get("text", ""),
+                "meta": resp.get("meta", {}),
+                "driver": driver_name,
+                "elapsed_ms": elapsed_ms,
+            },
+        )
+        return resp
+
+    # ------------------------------------------------------------------
+    # Internal helpers
+    # ------------------------------------------------------------------
+
+    def _fire_callback(self, event: str, payload: dict[str, Any]) -> None:
+        """Invoke a single callback, swallowing and logging any exception."""
+        if self.callbacks is None:
+            return
+        cb = getattr(self.callbacks, event, None)
+        if cb is None:
+            return
+        try:
+            cb(payload)
+        except Exception:
+            logger.exception("Callback %s raised an exception", event)
+
+    @staticmethod
+    def _flatten_messages(messages: list[dict[str, Any]]) -> str:
+        """Join messages into a single prompt string with role prefixes."""
+        parts: list[str] = []
+        for msg in messages:
+            role = msg.get("role", "user")
+            content = msg.get("content", "")
+            if role == "system":
+                parts.append(f"[System]: {content}")
+            elif role == "assistant":
+                parts.append(f"[Assistant]: {content}")
+            else:
+                parts.append(f"[User]: {content}")
+        return "\n\n".join(parts)
prompture/drivers/__init__.py CHANGED
@@ -1,71 +1,145 @@
-from .openai_driver import OpenAIDriver
-from .local_http_driver import LocalHTTPDriver
-from .ollama_driver import OllamaDriver
-from .claude_driver import ClaudeDriver
+"""Driver registry and factory functions.
+
+This module provides:
+- Built-in drivers for popular LLM providers
+- A pluggable registry system for custom drivers
+- Factory functions to instantiate drivers by provider/model name
+
+Custom Driver Registration:
+    from prompture import register_driver
+
+    def my_driver_factory(model=None):
+        return MyCustomDriver(model=model)
+
+    register_driver("my_provider", my_driver_factory)
+
+    # Now you can use it
+    driver = get_driver_for_model("my_provider/my-model")
+
+Entry Point Discovery:
+    Third-party packages can register drivers via entry points.
+    Add to your pyproject.toml:
+
+    [project.entry-points."prompture.drivers"]
+    my_provider = "my_package.drivers:my_driver_factory"
+"""
+
+from typing import Optional
+
+from ..settings import settings
+from .airllm_driver import AirLLMDriver
+from .async_airllm_driver import AsyncAirLLMDriver
+from .async_azure_driver import AsyncAzureDriver
+from .async_claude_driver import AsyncClaudeDriver
+from .async_google_driver import AsyncGoogleDriver
+from .async_grok_driver import AsyncGrokDriver
+from .async_groq_driver import AsyncGroqDriver
+from .async_hugging_driver import AsyncHuggingFaceDriver
+from .async_lmstudio_driver import AsyncLMStudioDriver
+from .async_local_http_driver import AsyncLocalHTTPDriver
+from .async_ollama_driver import AsyncOllamaDriver
+from .async_openai_driver import AsyncOpenAIDriver
+from .async_openrouter_driver import AsyncOpenRouterDriver
+from .async_registry import ASYNC_DRIVER_REGISTRY, get_async_driver, get_async_driver_for_model
 from .azure_driver import AzureDriver
-from .lmstudio_driver import LMStudioDriver
+from .claude_driver import ClaudeDriver
 from .google_driver import GoogleDriver
+from .grok_driver import GrokDriver
 from .groq_driver import GroqDriver
+from .lmstudio_driver import LMStudioDriver
+from .local_http_driver import LocalHTTPDriver
+from .ollama_driver import OllamaDriver
+from .openai_driver import OpenAIDriver
 from .openrouter_driver import OpenRouterDriver
-from .grok_driver import GrokDriver
-from ..settings import settings
-
+from .registry import (
+    _get_sync_registry,
+    get_async_driver_factory,
+    get_driver_factory,
+    is_async_driver_registered,
+    is_driver_registered,
+    list_registered_async_drivers,
+    list_registered_drivers,
+    load_entry_point_drivers,
+    register_async_driver,
+    register_driver,
+    unregister_async_driver,
+    unregister_driver,
+)
 
-# Central registry: maps provider → factory function
-DRIVER_REGISTRY = {
-    "openai": lambda model=None: OpenAIDriver(
-        api_key=settings.openai_api_key,
-        model=model or settings.openai_model
-    ),
-    "ollama": lambda model=None: OllamaDriver(
-        endpoint=settings.ollama_endpoint,
-        model=model or settings.ollama_model
+# Register built-in sync drivers
+register_driver(
+    "openai",
+    lambda model=None: OpenAIDriver(api_key=settings.openai_api_key, model=model or settings.openai_model),
+    overwrite=True,
+)
+register_driver(
+    "ollama",
+    lambda model=None: OllamaDriver(endpoint=settings.ollama_endpoint, model=model or settings.ollama_model),
+    overwrite=True,
+)
+register_driver(
+    "claude",
+    lambda model=None: ClaudeDriver(api_key=settings.claude_api_key, model=model or settings.claude_model),
+    overwrite=True,
+)
+register_driver(
+    "lmstudio",
+    lambda model=None: LMStudioDriver(endpoint=settings.lmstudio_endpoint, model=model or settings.lmstudio_model),
+    overwrite=True,
+)
+register_driver(
+    "azure",
+    lambda model=None: AzureDriver(
+        api_key=settings.azure_api_key, endpoint=settings.azure_api_endpoint, deployment_id=settings.azure_deployment_id
     ),
-    "claude": lambda model=None: ClaudeDriver(
-        api_key=settings.claude_api_key,
-        model=model or settings.claude_model
+    overwrite=True,
+)
+register_driver(
+    "local_http",
+    lambda model=None: LocalHTTPDriver(endpoint=settings.local_http_endpoint, model=model),
+    overwrite=True,
+)
+register_driver(
+    "google",
+    lambda model=None: GoogleDriver(api_key=settings.google_api_key, model=model or settings.google_model),
+    overwrite=True,
+)
+register_driver(
+    "groq",
+    lambda model=None: GroqDriver(api_key=settings.groq_api_key, model=model or settings.groq_model),
+    overwrite=True,
+)
+register_driver(
+    "openrouter",
+    lambda model=None: OpenRouterDriver(api_key=settings.openrouter_api_key, model=model or settings.openrouter_model),
+    overwrite=True,
+)
+register_driver(
+    "grok",
+    lambda model=None: GrokDriver(api_key=settings.grok_api_key, model=model or settings.grok_model),
+    overwrite=True,
+)
+register_driver(
+    "airllm",
+    lambda model=None: AirLLMDriver(
+        model=model or settings.airllm_model,
+        compression=settings.airllm_compression,
     ),
-    "lmstudio": lambda model=None: LMStudioDriver(
-        endpoint=settings.lmstudio_endpoint,
-        model=model or settings.lmstudio_model
-    ),
-    "azure": lambda model=None: AzureDriver(
-        api_key=settings.azure_api_key,
-        endpoint=settings.azure_api_endpoint,
-        deployment_id=settings.azure_deployment_id
-    ),
-    "local_http": lambda model=None: LocalHTTPDriver(
-        endpoint=settings.local_http_endpoint,
-        model=model
-    ),
-    "google": lambda model=None: GoogleDriver(
-        api_key=settings.google_api_key,
-        model=model or settings.google_model
-    ),
-    "groq": lambda model=None: GroqDriver(
-        api_key=settings.groq_api_key,
-        model=model or settings.groq_model
-    ),
-    "openrouter": lambda model=None: OpenRouterDriver(
-        api_key=settings.openrouter_api_key,
-        model=model or settings.openrouter_model
-    ),
-    "grok": lambda model=None: GrokDriver(
-        api_key=settings.grok_api_key,
-        model=model or settings.grok_model
-    ),
-}
+    overwrite=True,
+)
 
+# Backwards compatibility: expose registry dict (read-only view recommended)
+DRIVER_REGISTRY = _get_sync_registry()
 
-def get_driver(provider_name: str = None):
+
+def get_driver(provider_name: Optional[str] = None):
     """
     Factory to get a driver instance based on the provider name (legacy style).
     Uses default model from settings if not overridden.
     """
     provider = (provider_name or settings.ai_provider or "ollama").strip().lower()
-    if provider not in DRIVER_REGISTRY:
-        raise ValueError(f"Unknown provider: {provider_name}")
-    return DRIVER_REGISTRY[provider]()  # use default model from settings
+    factory = get_driver_factory(provider)
+    return factory()  # use default model from settings
 
 
 def get_driver_for_model(model_str: str):
@@ -73,21 +147,21 @@ def get_driver_for_model(model_str: str):
     Factory to get a driver instance based on a full model string.
     Format: provider/model_id
     Example: "openai/gpt-4-turbo-preview"
-    
+
     Args:
         model_str: Model identifier string. Can be either:
             - Full format: "provider/model" (e.g. "openai/gpt-4")
            - Provider only: "provider" (e.g. "openai")
-    
+
     Returns:
         A configured driver instance for the specified provider/model.
-    
+
     Raises:
         ValueError: If provider is invalid or format is incorrect.
     """
     if not isinstance(model_str, str):
         raise ValueError("Model string must be a string, got {type(model_str)}")
-    
+
     if not model_str:
         raise ValueError("Model string cannot be empty")
 
@@ -96,25 +170,55 @@ def get_driver_for_model(model_str: str):
     provider = parts[0].lower()
     model_id = parts[1] if len(parts) > 1 else None
 
-    # Validate provider
-    if provider not in DRIVER_REGISTRY:
-        raise ValueError(f"Unsupported provider '{provider}'")
-
+    # Get factory (validates provider exists)
+    factory = get_driver_factory(provider)
+
     # Create driver with model ID if provided, otherwise use default
-    return DRIVER_REGISTRY[provider](model_id)
+    return factory(model_id)
 
 
 __all__ = [
-    "OpenAIDriver",
-    "LocalHTTPDriver",
-    "OllamaDriver",
-    "ClaudeDriver",
-    "LMStudioDriver",
+    "ASYNC_DRIVER_REGISTRY",
+    # Legacy registry dicts (for backwards compatibility)
+    "DRIVER_REGISTRY",
+    # Sync drivers
+    "AirLLMDriver",
+    # Async drivers
+    "AsyncAirLLMDriver",
+    "AsyncAzureDriver",
+    "AsyncClaudeDriver",
+    "AsyncGoogleDriver",
+    "AsyncGrokDriver",
+    "AsyncGroqDriver",
+    "AsyncHuggingFaceDriver",
+    "AsyncLMStudioDriver",
+    "AsyncLocalHTTPDriver",
+    "AsyncOllamaDriver",
+    "AsyncOpenAIDriver",
+    "AsyncOpenRouterDriver",
     "AzureDriver",
+    "ClaudeDriver",
     "GoogleDriver",
+    "GrokDriver",
     "GroqDriver",
+    "LMStudioDriver",
+    "LocalHTTPDriver",
+    "OllamaDriver",
+    "OpenAIDriver",
     "OpenRouterDriver",
-    "GrokDriver",
+    "get_async_driver",
+    "get_async_driver_for_model",
+    # Factory functions
    "get_driver",
    "get_driver_for_model",
+    "is_async_driver_registered",
+    "is_driver_registered",
+    "list_registered_async_drivers",
+    "list_registered_drivers",
+    "load_entry_point_drivers",
+    "register_async_driver",
+    # Registry functions (public API)
+    "register_driver",
+    "unregister_async_driver",
+    "unregister_driver",
 ]
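
Taken together, the new registry, the extended Driver base class, and CostMixin reduce a third-party driver to a few lines. A sketch using only the APIs shown in this diff (EchoDriver is invented for illustration):

    from prompture.cost_mixin import CostMixin
    from prompture.driver import Driver
    from prompture.drivers import get_driver_for_model, register_driver

    class EchoDriver(CostMixin, Driver):
        MODEL_PRICING = {"echo-1": {"prompt": 0.0, "completion": 0.0}}

        def __init__(self, model=None):
            self.model = model or "echo-1"

        def generate(self, prompt, options):
            # Echo the prompt back; report zero-cost usage.
            return {"text": prompt, "meta": {"prompt_tokens": 0, "completion_tokens": 0, "cost": 0.0}}

    register_driver("echo", lambda model=None: EchoDriver(model=model))
    driver = get_driver_for_model("echo/echo-1")

The same factory signature works for the entry-point route documented in the module docstring above.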