prompture 0.0.29.dev8__py3-none-any.whl → 0.0.38.dev2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (79)
  1. prompture/__init__.py +264 -23
  2. prompture/_version.py +34 -0
  3. prompture/agent.py +924 -0
  4. prompture/agent_types.py +156 -0
  5. prompture/aio/__init__.py +74 -0
  6. prompture/async_agent.py +880 -0
  7. prompture/async_conversation.py +789 -0
  8. prompture/async_core.py +803 -0
  9. prompture/async_driver.py +193 -0
  10. prompture/async_groups.py +551 -0
  11. prompture/cache.py +469 -0
  12. prompture/callbacks.py +55 -0
  13. prompture/cli.py +63 -4
  14. prompture/conversation.py +826 -0
  15. prompture/core.py +894 -263
  16. prompture/cost_mixin.py +51 -0
  17. prompture/discovery.py +187 -0
  18. prompture/driver.py +206 -5
  19. prompture/drivers/__init__.py +175 -67
  20. prompture/drivers/airllm_driver.py +109 -0
  21. prompture/drivers/async_airllm_driver.py +26 -0
  22. prompture/drivers/async_azure_driver.py +123 -0
  23. prompture/drivers/async_claude_driver.py +113 -0
  24. prompture/drivers/async_google_driver.py +316 -0
  25. prompture/drivers/async_grok_driver.py +97 -0
  26. prompture/drivers/async_groq_driver.py +90 -0
  27. prompture/drivers/async_hugging_driver.py +61 -0
  28. prompture/drivers/async_lmstudio_driver.py +148 -0
  29. prompture/drivers/async_local_http_driver.py +44 -0
  30. prompture/drivers/async_ollama_driver.py +135 -0
  31. prompture/drivers/async_openai_driver.py +102 -0
  32. prompture/drivers/async_openrouter_driver.py +102 -0
  33. prompture/drivers/async_registry.py +133 -0
  34. prompture/drivers/azure_driver.py +42 -9
  35. prompture/drivers/claude_driver.py +257 -34
  36. prompture/drivers/google_driver.py +295 -42
  37. prompture/drivers/grok_driver.py +35 -32
  38. prompture/drivers/groq_driver.py +33 -26
  39. prompture/drivers/hugging_driver.py +6 -6
  40. prompture/drivers/lmstudio_driver.py +97 -19
  41. prompture/drivers/local_http_driver.py +6 -6
  42. prompture/drivers/ollama_driver.py +168 -23
  43. prompture/drivers/openai_driver.py +184 -9
  44. prompture/drivers/openrouter_driver.py +37 -25
  45. prompture/drivers/registry.py +306 -0
  46. prompture/drivers/vision_helpers.py +153 -0
  47. prompture/field_definitions.py +106 -96
  48. prompture/group_types.py +147 -0
  49. prompture/groups.py +530 -0
  50. prompture/image.py +180 -0
  51. prompture/logging.py +80 -0
  52. prompture/model_rates.py +217 -0
  53. prompture/persistence.py +254 -0
  54. prompture/persona.py +482 -0
  55. prompture/runner.py +49 -47
  56. prompture/scaffold/__init__.py +1 -0
  57. prompture/scaffold/generator.py +84 -0
  58. prompture/scaffold/templates/Dockerfile.j2 +12 -0
  59. prompture/scaffold/templates/README.md.j2 +41 -0
  60. prompture/scaffold/templates/config.py.j2 +21 -0
  61. prompture/scaffold/templates/env.example.j2 +8 -0
  62. prompture/scaffold/templates/main.py.j2 +86 -0
  63. prompture/scaffold/templates/models.py.j2 +40 -0
  64. prompture/scaffold/templates/requirements.txt.j2 +5 -0
  65. prompture/serialization.py +218 -0
  66. prompture/server.py +183 -0
  67. prompture/session.py +117 -0
  68. prompture/settings.py +19 -1
  69. prompture/tools.py +219 -267
  70. prompture/tools_schema.py +254 -0
  71. prompture/validator.py +3 -3
  72. prompture-0.0.38.dev2.dist-info/METADATA +369 -0
  73. prompture-0.0.38.dev2.dist-info/RECORD +77 -0
  74. {prompture-0.0.29.dev8.dist-info → prompture-0.0.38.dev2.dist-info}/WHEEL +1 -1
  75. prompture-0.0.29.dev8.dist-info/METADATA +0 -368
  76. prompture-0.0.29.dev8.dist-info/RECORD +0 -27
  77. {prompture-0.0.29.dev8.dist-info → prompture-0.0.38.dev2.dist-info}/entry_points.txt +0 -0
  78. {prompture-0.0.29.dev8.dist-info → prompture-0.0.38.dev2.dist-info}/licenses/LICENSE +0 -0
  79. {prompture-0.0.29.dev8.dist-info → prompture-0.0.38.dev2.dist-info}/top_level.txt +0 -0
@@ -1,17 +1,28 @@
1
1
  """Minimal OpenAI driver (migrated to openai>=1.0.0).
2
2
  Requires the `openai` package. Uses OPENAI_API_KEY env var.
3
3
  """
4
+
5
+ import json
4
6
  import os
5
- from typing import Any, Dict
7
+ from collections.abc import Iterator
8
+ from typing import Any
9
+
6
10
  try:
7
11
  from openai import OpenAI
8
12
  except Exception:
9
13
  OpenAI = None
10
14
 
15
+ from ..cost_mixin import CostMixin
11
16
  from ..driver import Driver
12
17
 
13
18
 
14
- class OpenAIDriver(Driver):
19
+ class OpenAIDriver(CostMixin, Driver):
20
+ supports_json_mode = True
21
+ supports_json_schema = True
22
+ supports_tool_use = True
23
+ supports_streaming = True
24
+ supports_vision = True
25
+
15
26
  # Approximate pricing per 1K tokens (keep updated with OpenAI's official pricing)
16
27
  # Each model entry also defines which token parameter it supports and
17
28
  # whether it accepts temperature.
@@ -62,7 +73,21 @@ class OpenAIDriver(Driver):
62
73
  else:
63
74
  self.client = None
64
75
 
65
- def generate(self, prompt: str, options: Dict[str, Any]) -> Dict[str, Any]:
76
+ supports_messages = True
77
+
78
def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Normalize chat messages through the shared OpenAI vision helper."""
    # Imported lazily so the helper module is only loaded when messages
    # are actually prepared.
    from .vision_helpers import _prepare_openai_vision_messages

    return _prepare_openai_vision_messages(messages)
82
+
83
def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
    """Run a single-turn completion by wrapping *prompt* as one user message."""
    single_turn = [{"role": "user", "content": prompt}]
    return self._do_generate(single_turn, options)
86
+
87
def generate_messages(self, messages: list[dict[str, Any]], options: dict[str, Any]) -> dict[str, Any]:
    """Run a chat completion over pre-built messages, vision-normalized first."""
    prepared = self._prepare_messages(messages)
    return self._do_generate(prepared, options)
89
+
90
+ def _do_generate(self, messages: list[dict[str, Any]], options: dict[str, Any]) -> dict[str, Any]:
66
91
  if self.client is None:
67
92
  raise RuntimeError("openai package (>=1.0.0) is not installed")
68
93
 
@@ -79,7 +104,7 @@ class OpenAIDriver(Driver):
79
104
  # Base kwargs
80
105
  kwargs = {
81
106
  "model": model,
82
- "messages": [{"role": "user", "content": prompt}],
107
+ "messages": messages,
83
108
  }
84
109
 
85
110
  # Assign token limit with the correct parameter name
@@ -89,6 +114,21 @@ class OpenAIDriver(Driver):
89
114
  if supports_temperature and "temperature" in opts:
90
115
  kwargs["temperature"] = opts["temperature"]
91
116
 
117
+ # Native JSON mode support
118
+ if options.get("json_mode"):
119
+ json_schema = options.get("json_schema")
120
+ if json_schema:
121
+ kwargs["response_format"] = {
122
+ "type": "json_schema",
123
+ "json_schema": {
124
+ "name": "extraction",
125
+ "strict": True,
126
+ "schema": json_schema,
127
+ },
128
+ }
129
+ else:
130
+ kwargs["response_format"] = {"type": "json_object"}
131
+
92
132
  resp = self.client.chat.completions.create(**kwargs)
93
133
 
94
134
  # Extract usage info
@@ -97,11 +137,8 @@ class OpenAIDriver(Driver):
97
137
  completion_tokens = getattr(usage, "completion_tokens", 0)
98
138
  total_tokens = getattr(usage, "total_tokens", 0)
99
139
 
100
- # Calculate cost
101
- model_pricing = self.MODEL_PRICING.get(model, {"prompt": 0, "completion": 0})
102
- prompt_cost = (prompt_tokens / 1000) * model_pricing["prompt"]
103
- completion_cost = (completion_tokens / 1000) * model_pricing["completion"]
104
- total_cost = prompt_cost + completion_cost
140
+ # Calculate cost via shared mixin
141
+ total_cost = self._calculate_cost("openai", model, prompt_tokens, completion_tokens)
105
142
 
106
143
  # Standardized meta object
107
144
  meta = {
@@ -115,3 +152,141 @@ class OpenAIDriver(Driver):
115
152
 
116
153
  text = resp.choices[0].message.content
117
154
  return {"text": text, "meta": meta}
155
+
156
+ # ------------------------------------------------------------------
157
+ # Tool use
158
+ # ------------------------------------------------------------------
159
+
160
def generate_messages_with_tools(
    self,
    messages: list[dict[str, Any]],
    tools: list[dict[str, Any]],
    options: list[dict[str, Any]] | dict[str, Any],
) -> dict[str, Any]:
    """Generate a response that may include tool calls.

    Args:
        messages: Chat messages in OpenAI format.
        tools: Tool definitions forwarded verbatim to the API.
        options: Per-call overrides (``model``, ``max_tokens``, ``temperature``).

    Returns:
        Dict with ``text``, ``meta`` (token counts and cost), parsed
        ``tool_calls`` and the provider's ``stop_reason``.
    """
    if self.client is None:
        raise RuntimeError("openai package (>=1.0.0) is not installed")

    model = options.get("model", self.model)
    model_info = self.MODEL_PRICING.get(model, {})
    opts = {"temperature": 1.0, "max_tokens": 512, **options}

    # The pricing table records which token-limit parameter each model
    # accepts and whether it honours temperature.
    request: dict[str, Any] = {
        "model": model,
        "messages": messages,
        "tools": tools,
        model_info.get("tokens_param", "max_tokens"): opts.get("max_tokens", 512),
    }
    if model_info.get("supports_temperature", True) and "temperature" in opts:
        request["temperature"] = opts["temperature"]

    resp = self.client.chat.completions.create(**request)

    usage = getattr(resp, "usage", None)
    prompt_tokens = getattr(usage, "prompt_tokens", 0)
    completion_tokens = getattr(usage, "completion_tokens", 0)
    cost = self._calculate_cost("openai", model, prompt_tokens, completion_tokens)

    meta = {
        "prompt_tokens": prompt_tokens,
        "completion_tokens": completion_tokens,
        "total_tokens": getattr(usage, "total_tokens", 0),
        "cost": round(cost, 6),
        "raw_response": resp.model_dump(),
        "model_name": model,
    }

    choice = resp.choices[0]
    parsed_calls: list[dict[str, Any]] = []
    for call in choice.message.tool_calls or []:
        try:
            arguments = json.loads(call.function.arguments)
        except (json.JSONDecodeError, TypeError):
            # Malformed/absent JSON arguments degrade to an empty dict.
            arguments = {}
        parsed_calls.append({
            "id": call.id,
            "name": call.function.name,
            "arguments": arguments,
        })

    return {
        "text": choice.message.content or "",
        "meta": meta,
        "tool_calls": parsed_calls,
        "stop_reason": choice.finish_reason,
    }
227
+
228
+ # ------------------------------------------------------------------
229
+ # Streaming
230
+ # ------------------------------------------------------------------
231
+
232
def generate_messages_stream(
    self,
    messages: list[dict[str, Any]],
    options: dict[str, Any],
) -> Iterator[dict[str, Any]]:
    """Stream a chat completion, yielding delta chunks then a final summary.

    Yields ``{"type": "delta", "text": ...}`` for each content fragment,
    followed by one ``{"type": "done", ...}`` dict carrying the assembled
    text plus token/cost metadata.
    """
    if self.client is None:
        raise RuntimeError("openai package (>=1.0.0) is not installed")

    model = options.get("model", self.model)
    model_info = self.MODEL_PRICING.get(model, {})
    opts = {"temperature": 1.0, "max_tokens": 512, **options}

    request: dict[str, Any] = {
        "model": model,
        "messages": messages,
        "stream": True,
        # Request the trailing usage chunk so token counts can be reported.
        "stream_options": {"include_usage": True},
    }
    request[model_info.get("tokens_param", "max_tokens")] = opts.get("max_tokens", 512)
    if model_info.get("supports_temperature", True) and "temperature" in opts:
        request["temperature"] = opts["temperature"]

    pieces: list[str] = []
    prompt_tokens = 0
    completion_tokens = 0

    for chunk in self.client.chat.completions.create(**request):
        usage = getattr(chunk, "usage", None)
        if usage:
            # Usage arrives on the final chunk of the stream.
            prompt_tokens = usage.prompt_tokens or 0
            completion_tokens = usage.completion_tokens or 0

        if chunk.choices:
            fragment = getattr(chunk.choices[0].delta, "content", None) or ""
            if fragment:
                pieces.append(fragment)
                yield {"type": "delta", "text": fragment}

    cost = self._calculate_cost("openai", model, prompt_tokens, completion_tokens)
    yield {
        "type": "done",
        "text": "".join(pieces),
        "meta": {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": prompt_tokens + completion_tokens,
            "cost": round(cost, 6),
            "raw_response": {},
            "model_name": model,
        },
    }
@@ -1,14 +1,20 @@
1
1
  """OpenRouter driver implementation.
2
2
  Requires the `requests` package. Uses OPENROUTER_API_KEY env var.
3
3
  """
4
+
4
5
  import os
5
- from typing import Any, Dict
6
+ from typing import Any
7
+
6
8
  import requests
7
9
 
10
+ from ..cost_mixin import CostMixin
8
11
  from ..driver import Driver
9
12
 
10
13
 
11
- class OpenRouterDriver(Driver):
14
+ class OpenRouterDriver(CostMixin, Driver):
15
+ supports_json_mode = True
16
+ supports_vision = True
17
+
12
18
  # Approximate pricing per 1K tokens based on OpenRouter's pricing
13
19
  # https://openrouter.ai/docs#pricing
14
20
  MODEL_PRICING = {
@@ -40,7 +46,7 @@ class OpenRouterDriver(Driver):
40
46
 
41
47
  def __init__(self, api_key: str | None = None, model: str = "openai/gpt-3.5-turbo"):
42
48
  """Initialize OpenRouter driver.
43
-
49
+
44
50
  Args:
45
51
  api_key: OpenRouter API key. If not provided, will look for OPENROUTER_API_KEY env var
46
52
  model: Model to use. Defaults to openai/gpt-3.5-turbo
@@ -48,10 +54,10 @@ class OpenRouterDriver(Driver):
48
54
  self.api_key = api_key or os.getenv("OPENROUTER_API_KEY")
49
55
  if not self.api_key:
50
56
  raise ValueError("OpenRouter API key not found. Set OPENROUTER_API_KEY env var.")
51
-
57
+
52
58
  self.model = model
53
59
  self.base_url = "https://openrouter.ai/api/v1"
54
-
60
+
55
61
  # Required headers for OpenRouter
56
62
  self.headers = {
57
63
  "Authorization": f"Bearer {self.api_key}",
@@ -59,21 +65,26 @@ class OpenRouterDriver(Driver):
59
65
  "Content-Type": "application/json",
60
66
  }
61
67
 
62
- def generate(self, prompt: str, options: Dict[str, Any]) -> Dict[str, Any]:
63
- """Generate completion using OpenRouter API.
64
-
65
- Args:
66
- prompt: The prompt text
67
- options: Generation options
68
-
69
- Returns:
70
- Dict containing generated text and metadata
71
- """
68
+ supports_messages = True
69
+
70
def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Normalize chat messages through the shared OpenAI vision helper."""
    # OpenRouter speaks the OpenAI message format, so the same helper applies.
    from .vision_helpers import _prepare_openai_vision_messages

    return _prepare_openai_vision_messages(messages)
74
+
75
def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
    """Run a single-turn completion by wrapping *prompt* as one user message."""
    single_turn = [{"role": "user", "content": prompt}]
    return self._do_generate(single_turn, options)
78
+
79
def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
    """Run a chat completion over pre-built messages, vision-normalized first."""
    prepared = self._prepare_messages(messages)
    return self._do_generate(prepared, options)
81
+
82
+ def _do_generate(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
72
83
  if not self.api_key:
73
84
  raise RuntimeError("OpenRouter API key not found")
74
85
 
75
86
  model = options.get("model", self.model)
76
-
87
+
77
88
  # Lookup model-specific config
78
89
  model_info = self.MODEL_PRICING.get(model, {})
79
90
  tokens_param = model_info.get("tokens_param", "max_tokens")
@@ -85,7 +96,7 @@ class OpenRouterDriver(Driver):
85
96
  # Base request data
86
97
  data = {
87
98
  "model": model,
88
- "messages": [{"role": "user", "content": prompt}],
99
+ "messages": messages,
89
100
  }
90
101
 
91
102
  # Add token limit with correct parameter name
@@ -95,6 +106,10 @@ class OpenRouterDriver(Driver):
95
106
  if supports_temperature and "temperature" in opts:
96
107
  data["temperature"] = opts["temperature"]
97
108
 
109
+ # Native JSON mode support
110
+ if options.get("json_mode"):
111
+ data["response_format"] = {"type": "json_object"}
112
+
98
113
  try:
99
114
  response = requests.post(
100
115
  f"{self.base_url}/chat/completions",
@@ -110,11 +125,8 @@ class OpenRouterDriver(Driver):
110
125
  completion_tokens = usage.get("completion_tokens", 0)
111
126
  total_tokens = usage.get("total_tokens", 0)
112
127
 
113
- # Calculate cost
114
- model_pricing = self.MODEL_PRICING.get(model, {"prompt": 0, "completion": 0})
115
- prompt_cost = (prompt_tokens / 1000) * model_pricing["prompt"]
116
- completion_cost = (completion_tokens / 1000) * model_pricing["completion"]
117
- total_cost = prompt_cost + completion_cost
128
+ # Calculate cost via shared mixin
129
+ total_cost = self._calculate_cost("openrouter", model, prompt_tokens, completion_tokens)
118
130
 
119
131
  # Standardized meta object
120
132
  meta = {
@@ -130,11 +142,11 @@ class OpenRouterDriver(Driver):
130
142
  return {"text": text, "meta": meta}
131
143
 
132
144
  except requests.exceptions.RequestException as e:
133
- error_msg = f"OpenRouter API request failed: {str(e)}"
134
- if hasattr(e.response, 'json'):
145
+ error_msg = f"OpenRouter API request failed: {e!s}"
146
+ if hasattr(e.response, "json"):
135
147
  try:
136
148
  error_details = e.response.json()
137
149
  error_msg = f"{error_msg} - {error_details.get('error', {}).get('message', '')}"
138
150
  except Exception:
139
151
  pass
140
- raise RuntimeError(error_msg) from e
152
+ raise RuntimeError(error_msg) from e
@@ -0,0 +1,306 @@
1
+ """Driver registry with plugin support.
2
+
3
+ This module provides a public API for registering custom drivers and
4
+ supports auto-discovery of drivers via Python entry points.
5
+
6
+ Example usage:
7
+ # Register a custom driver
8
+ from prompture import register_driver
9
+
10
+ def my_driver_factory(model=None):
11
+ return MyCustomDriver(model=model)
12
+
13
+ register_driver("my_provider", my_driver_factory)
14
+
15
+ # Now you can use it
16
+ driver = get_driver_for_model("my_provider/my-model")
17
+
18
+ For entry point discovery, add to your package's pyproject.toml:
19
+ [project.entry-points."prompture.drivers"]
20
+ my_provider = "my_package.drivers:my_driver_factory"
21
+
22
+ [project.entry-points."prompture.async_drivers"]
23
+ my_provider = "my_package.drivers:my_async_driver_factory"
24
+ """
25
+
26
+ from __future__ import annotations
27
+
28
+ import logging
29
+ import sys
30
+ from typing import Callable
31
+
32
+ logger = logging.getLogger("prompture.drivers.registry")
33
+
34
+ # Type alias for driver factory functions
35
+ # A factory takes an optional model name and returns a driver instance
36
+ DriverFactory = Callable[[str | None], object]
37
+
38
+ # Internal registries - populated by built-in drivers and plugins
39
+ _SYNC_REGISTRY: dict[str, DriverFactory] = {}
40
+ _ASYNC_REGISTRY: dict[str, DriverFactory] = {}
41
+
42
+ # Track whether entry points have been loaded
43
+ _entry_points_loaded = False
44
+
45
+
46
def register_driver(name: str, factory: DriverFactory, *, overwrite: bool = False) -> None:
    """Register a custom sync driver factory under a provider name.

    Args:
        name: Provider name; lowercased before storage.
        factory: Callable taking an optional model name and returning a
            driver instance implementing the ``Driver`` interface
            (specifically ``generate()``).
        overwrite: When True, an existing registration may be replaced.

    Raises:
        ValueError: If ``name`` is already registered and ``overwrite`` is False.

    Example:
        >>> def my_factory(model=None):
        ...     return MyDriver(model=model or "default-model")
        >>> register_driver("my_provider", my_factory)
    """
    key = name.lower()
    if not overwrite and key in _SYNC_REGISTRY:
        raise ValueError(f"Driver '{key}' is already registered. Use overwrite=True to replace it.")
    _SYNC_REGISTRY[key] = factory
    logger.debug("Registered sync driver: %s", key)
72
+
73
+
74
def register_async_driver(name: str, factory: DriverFactory, *, overwrite: bool = False) -> None:
    """Register a custom async driver factory under a provider name.

    Args:
        name: Provider name; lowercased before storage.
        factory: Callable taking an optional model name and returning an
            async driver instance implementing the ``AsyncDriver``
            interface (specifically ``async generate()``).
        overwrite: When True, an existing registration may be replaced.

    Raises:
        ValueError: If ``name`` is already registered and ``overwrite`` is False.

    Example:
        >>> def my_async_factory(model=None):
        ...     return MyAsyncDriver(model=model or "default-model")
        >>> register_async_driver("my_provider", my_async_factory)
    """
    key = name.lower()
    if not overwrite and key in _ASYNC_REGISTRY:
        raise ValueError(f"Async driver '{key}' is already registered. Use overwrite=True to replace it.")
    _ASYNC_REGISTRY[key] = factory
    logger.debug("Registered async driver: %s", key)
100
+
101
+
102
def unregister_driver(name: str) -> bool:
    """Remove a sync driver registration.

    Args:
        name: Provider name to remove (case-insensitive).

    Returns:
        True when a registration was removed, False when none existed.
    """
    key = name.lower()
    try:
        del _SYNC_REGISTRY[key]
    except KeyError:
        return False
    logger.debug("Unregistered sync driver: %s", key)
    return True
117
+
118
+
119
def unregister_async_driver(name: str) -> bool:
    """Remove an async driver registration.

    Args:
        name: Provider name to remove (case-insensitive).

    Returns:
        True when a registration was removed, False when none existed.
    """
    key = name.lower()
    try:
        del _ASYNC_REGISTRY[key]
    except KeyError:
        return False
    logger.debug("Unregistered async driver: %s", key)
    return True
134
+
135
+
136
def list_registered_drivers() -> list[str]:
    """Return the names of all registered sync drivers, sorted alphabetically."""
    _ensure_entry_points_loaded()
    return sorted(_SYNC_REGISTRY)
140
+
141
+
142
def list_registered_async_drivers() -> list[str]:
    """Return the names of all registered async drivers, sorted alphabetically."""
    _ensure_entry_points_loaded()
    return sorted(_ASYNC_REGISTRY)
146
+
147
+
148
def is_driver_registered(name: str) -> bool:
    """Return True when a sync driver exists for ``name`` (case-insensitive)."""
    _ensure_entry_points_loaded()
    return name.lower() in _SYNC_REGISTRY
159
+
160
+
161
def is_async_driver_registered(name: str) -> bool:
    """Return True when an async driver exists for ``name`` (case-insensitive)."""
    _ensure_entry_points_loaded()
    return name.lower() in _ASYNC_REGISTRY
172
+
173
+
174
def get_driver_factory(name: str) -> DriverFactory:
    """Look up the sync driver factory registered for ``name``.

    Args:
        name: Provider name (case-insensitive).

    Returns:
        The registered factory callable.

    Raises:
        ValueError: If no sync driver is registered under that name.
    """
    _ensure_entry_points_loaded()
    key = name.lower()
    if key not in _SYNC_REGISTRY:
        raise ValueError(f"Unsupported provider '{key}'")
    return _SYNC_REGISTRY[key]
191
+
192
+
193
def get_async_driver_factory(name: str) -> DriverFactory:
    """Look up the async driver factory registered for ``name``.

    Args:
        name: Provider name (case-insensitive).

    Returns:
        The registered factory callable.

    Raises:
        ValueError: If no async driver is registered under that name.
    """
    _ensure_entry_points_loaded()
    key = name.lower()
    if key not in _ASYNC_REGISTRY:
        raise ValueError(f"Unsupported provider '{key}'")
    return _ASYNC_REGISTRY[key]
210
+
211
+
212
def load_entry_point_drivers() -> tuple[int, int]:
    """Load drivers from installed packages via entry points.

    Scans the ``prompture.drivers`` and ``prompture.async_drivers`` entry
    point groups and registers every factory found there. Names that are
    already registered (built-in drivers) take precedence and are skipped.
    A factory that fails to load is logged and ignored so one broken
    plugin cannot break discovery of the rest.

    Returns:
        A tuple of (sync_count, async_count) indicating how many drivers
        were loaded from entry points.

    Example pyproject.toml for a plugin package:
        [project.entry-points."prompture.drivers"]
        my_provider = "my_package.drivers:create_my_driver"

        [project.entry-points."prompture.async_drivers"]
        my_provider = "my_package.drivers:create_my_async_driver"
    """
    global _entry_points_loaded

    sync_count = 0
    async_count = 0

    # entry_points() gained keyword selection by group in Python 3.10
    # (importlib.metadata); on 3.8/3.9 it returns a dict-like mapping of
    # group name -> entry point list, so we select with .get() there.
    if sys.version_info >= (3, 10):
        from importlib.metadata import entry_points

        sync_eps = entry_points(group="prompture.drivers")
        async_eps = entry_points(group="prompture.async_drivers")
    else:
        from importlib.metadata import entry_points

        all_eps = entry_points()
        sync_eps = all_eps.get("prompture.drivers", [])
        async_eps = all_eps.get("prompture.async_drivers", [])

    # Load sync drivers
    for ep in sync_eps:
        try:
            # Skip if already registered (built-in drivers take precedence)
            if ep.name.lower() in _SYNC_REGISTRY:
                logger.debug("Skipping entry point driver '%s' (already registered)", ep.name)
                continue

            _SYNC_REGISTRY[ep.name.lower()] = ep.load()
            sync_count += 1
            logger.info("Loaded sync driver from entry point: %s", ep.name)
        except Exception:
            logger.exception("Failed to load sync driver entry point: %s", ep.name)

    # Load async drivers
    for ep in async_eps:
        try:
            # Skip if already registered (built-in drivers take precedence)
            if ep.name.lower() in _ASYNC_REGISTRY:
                logger.debug("Skipping entry point async driver '%s' (already registered)", ep.name)
                continue

            _ASYNC_REGISTRY[ep.name.lower()] = ep.load()
            async_count += 1
            logger.info("Loaded async driver from entry point: %s", ep.name)
        except Exception:
            logger.exception("Failed to load async driver entry point: %s", ep.name)

    _entry_points_loaded = True
    return (sync_count, async_count)
280
+
281
+
282
def _ensure_entry_points_loaded() -> None:
    """Lazily run entry-point discovery the first time a registry is read."""
    if _entry_points_loaded:
        return
    # load_entry_point_drivers() flips the module-level flag itself.
    load_entry_point_drivers()
287
+
288
+
289
def _get_sync_registry() -> dict[str, DriverFactory]:
    """Return the live sync registry dict (internal use by drivers/__init__.py).

    Triggers entry-point discovery first so plugin drivers are visible.
    """
    _ensure_entry_points_loaded()
    return _SYNC_REGISTRY
293
+
294
+
295
def _get_async_registry() -> dict[str, DriverFactory]:
    """Return the live async registry dict (internal use by drivers/async_registry.py).

    Triggers entry-point discovery first so plugin drivers are visible.
    """
    _ensure_entry_points_loaded()
    return _ASYNC_REGISTRY
299
+
300
+
301
def _reset_registries() -> None:
    """Empty both registries and re-arm entry-point loading (testing only)."""
    global _entry_points_loaded
    for registry in (_SYNC_REGISTRY, _ASYNC_REGISTRY):
        registry.clear()
    _entry_points_loaded = False