prompture 0.0.29.dev8__py3-none-any.whl → 0.0.38.dev2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- prompture/__init__.py +264 -23
- prompture/_version.py +34 -0
- prompture/agent.py +924 -0
- prompture/agent_types.py +156 -0
- prompture/aio/__init__.py +74 -0
- prompture/async_agent.py +880 -0
- prompture/async_conversation.py +789 -0
- prompture/async_core.py +803 -0
- prompture/async_driver.py +193 -0
- prompture/async_groups.py +551 -0
- prompture/cache.py +469 -0
- prompture/callbacks.py +55 -0
- prompture/cli.py +63 -4
- prompture/conversation.py +826 -0
- prompture/core.py +894 -263
- prompture/cost_mixin.py +51 -0
- prompture/discovery.py +187 -0
- prompture/driver.py +206 -5
- prompture/drivers/__init__.py +175 -67
- prompture/drivers/airllm_driver.py +109 -0
- prompture/drivers/async_airllm_driver.py +26 -0
- prompture/drivers/async_azure_driver.py +123 -0
- prompture/drivers/async_claude_driver.py +113 -0
- prompture/drivers/async_google_driver.py +316 -0
- prompture/drivers/async_grok_driver.py +97 -0
- prompture/drivers/async_groq_driver.py +90 -0
- prompture/drivers/async_hugging_driver.py +61 -0
- prompture/drivers/async_lmstudio_driver.py +148 -0
- prompture/drivers/async_local_http_driver.py +44 -0
- prompture/drivers/async_ollama_driver.py +135 -0
- prompture/drivers/async_openai_driver.py +102 -0
- prompture/drivers/async_openrouter_driver.py +102 -0
- prompture/drivers/async_registry.py +133 -0
- prompture/drivers/azure_driver.py +42 -9
- prompture/drivers/claude_driver.py +257 -34
- prompture/drivers/google_driver.py +295 -42
- prompture/drivers/grok_driver.py +35 -32
- prompture/drivers/groq_driver.py +33 -26
- prompture/drivers/hugging_driver.py +6 -6
- prompture/drivers/lmstudio_driver.py +97 -19
- prompture/drivers/local_http_driver.py +6 -6
- prompture/drivers/ollama_driver.py +168 -23
- prompture/drivers/openai_driver.py +184 -9
- prompture/drivers/openrouter_driver.py +37 -25
- prompture/drivers/registry.py +306 -0
- prompture/drivers/vision_helpers.py +153 -0
- prompture/field_definitions.py +106 -96
- prompture/group_types.py +147 -0
- prompture/groups.py +530 -0
- prompture/image.py +180 -0
- prompture/logging.py +80 -0
- prompture/model_rates.py +217 -0
- prompture/persistence.py +254 -0
- prompture/persona.py +482 -0
- prompture/runner.py +49 -47
- prompture/scaffold/__init__.py +1 -0
- prompture/scaffold/generator.py +84 -0
- prompture/scaffold/templates/Dockerfile.j2 +12 -0
- prompture/scaffold/templates/README.md.j2 +41 -0
- prompture/scaffold/templates/config.py.j2 +21 -0
- prompture/scaffold/templates/env.example.j2 +8 -0
- prompture/scaffold/templates/main.py.j2 +86 -0
- prompture/scaffold/templates/models.py.j2 +40 -0
- prompture/scaffold/templates/requirements.txt.j2 +5 -0
- prompture/serialization.py +218 -0
- prompture/server.py +183 -0
- prompture/session.py +117 -0
- prompture/settings.py +19 -1
- prompture/tools.py +219 -267
- prompture/tools_schema.py +254 -0
- prompture/validator.py +3 -3
- prompture-0.0.38.dev2.dist-info/METADATA +369 -0
- prompture-0.0.38.dev2.dist-info/RECORD +77 -0
- {prompture-0.0.29.dev8.dist-info → prompture-0.0.38.dev2.dist-info}/WHEEL +1 -1
- prompture-0.0.29.dev8.dist-info/METADATA +0 -368
- prompture-0.0.29.dev8.dist-info/RECORD +0 -27
- {prompture-0.0.29.dev8.dist-info → prompture-0.0.38.dev2.dist-info}/entry_points.txt +0 -0
- {prompture-0.0.29.dev8.dist-info → prompture-0.0.38.dev2.dist-info}/licenses/LICENSE +0 -0
- {prompture-0.0.29.dev8.dist-info → prompture-0.0.38.dev2.dist-info}/top_level.txt +0 -0
|
@@ -1,17 +1,28 @@
|
|
|
1
1
|
"""Minimal OpenAI driver (migrated to openai>=1.0.0).
|
|
2
2
|
Requires the `openai` package. Uses OPENAI_API_KEY env var.
|
|
3
3
|
"""
|
|
4
|
+
|
|
5
|
+
import json
|
|
4
6
|
import os
|
|
5
|
-
from
|
|
7
|
+
from collections.abc import Iterator
|
|
8
|
+
from typing import Any
|
|
9
|
+
|
|
6
10
|
try:
|
|
7
11
|
from openai import OpenAI
|
|
8
12
|
except Exception:
|
|
9
13
|
OpenAI = None
|
|
10
14
|
|
|
15
|
+
from ..cost_mixin import CostMixin
|
|
11
16
|
from ..driver import Driver
|
|
12
17
|
|
|
13
18
|
|
|
14
|
-
class OpenAIDriver(Driver):
|
|
19
|
+
class OpenAIDriver(CostMixin, Driver):
|
|
20
|
+
supports_json_mode = True
|
|
21
|
+
supports_json_schema = True
|
|
22
|
+
supports_tool_use = True
|
|
23
|
+
supports_streaming = True
|
|
24
|
+
supports_vision = True
|
|
25
|
+
|
|
15
26
|
# Approximate pricing per 1K tokens (keep updated with OpenAI's official pricing)
|
|
16
27
|
# Each model entry also defines which token parameter it supports and
|
|
17
28
|
# whether it accepts temperature.
|
|
@@ -62,7 +73,21 @@ class OpenAIDriver(Driver):
|
|
|
62
73
|
else:
|
|
63
74
|
self.client = None
|
|
64
75
|
|
|
65
|
-
|
|
76
|
+
supports_messages = True
|
|
77
|
+
|
|
78
|
+
def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Run *messages* through the shared OpenAI vision helper.

    The helper is imported lazily, inside the method, and its return
    value is passed back unchanged.
    NOTE(review): the helper presumably rewrites image content into the
    OpenAI chat format -- confirm against ``vision_helpers``.
    """
    from .vision_helpers import _prepare_openai_vision_messages as to_vision_messages

    prepared = to_vision_messages(messages)
    return prepared
|
|
82
|
+
|
|
83
|
+
def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
    """Generate a completion for a single plain-text prompt.

    The prompt is wrapped in a one-element message list with the
    ``user`` role and handed to ``_do_generate`` together with
    *options*; whatever ``_do_generate`` returns is returned as is.
    """
    user_turn = {"role": "user", "content": prompt}
    return self._do_generate([user_turn], options)
|
|
86
|
+
|
|
87
|
+
def generate_messages(self, messages: list[dict[str, Any]], options: dict[str, Any]) -> dict[str, Any]:
    """Generate a completion from a full message list.

    The messages are first passed through ``_prepare_messages`` and the
    result is forwarded to ``_do_generate`` together with *options*.
    """
    prepared = self._prepare_messages(messages)
    return self._do_generate(prepared, options)
|
|
89
|
+
|
|
90
|
+
def _do_generate(self, messages: list[dict[str, Any]], options: dict[str, Any]) -> dict[str, Any]:
|
|
66
91
|
if self.client is None:
|
|
67
92
|
raise RuntimeError("openai package (>=1.0.0) is not installed")
|
|
68
93
|
|
|
@@ -79,7 +104,7 @@ class OpenAIDriver(Driver):
|
|
|
79
104
|
# Base kwargs
|
|
80
105
|
kwargs = {
|
|
81
106
|
"model": model,
|
|
82
|
-
"messages":
|
|
107
|
+
"messages": messages,
|
|
83
108
|
}
|
|
84
109
|
|
|
85
110
|
# Assign token limit with the correct parameter name
|
|
@@ -89,6 +114,21 @@ class OpenAIDriver(Driver):
|
|
|
89
114
|
if supports_temperature and "temperature" in opts:
|
|
90
115
|
kwargs["temperature"] = opts["temperature"]
|
|
91
116
|
|
|
117
|
+
# Native JSON mode support
|
|
118
|
+
if options.get("json_mode"):
|
|
119
|
+
json_schema = options.get("json_schema")
|
|
120
|
+
if json_schema:
|
|
121
|
+
kwargs["response_format"] = {
|
|
122
|
+
"type": "json_schema",
|
|
123
|
+
"json_schema": {
|
|
124
|
+
"name": "extraction",
|
|
125
|
+
"strict": True,
|
|
126
|
+
"schema": json_schema,
|
|
127
|
+
},
|
|
128
|
+
}
|
|
129
|
+
else:
|
|
130
|
+
kwargs["response_format"] = {"type": "json_object"}
|
|
131
|
+
|
|
92
132
|
resp = self.client.chat.completions.create(**kwargs)
|
|
93
133
|
|
|
94
134
|
# Extract usage info
|
|
@@ -97,11 +137,8 @@ class OpenAIDriver(Driver):
|
|
|
97
137
|
completion_tokens = getattr(usage, "completion_tokens", 0)
|
|
98
138
|
total_tokens = getattr(usage, "total_tokens", 0)
|
|
99
139
|
|
|
100
|
-
# Calculate cost
|
|
101
|
-
|
|
102
|
-
prompt_cost = (prompt_tokens / 1000) * model_pricing["prompt"]
|
|
103
|
-
completion_cost = (completion_tokens / 1000) * model_pricing["completion"]
|
|
104
|
-
total_cost = prompt_cost + completion_cost
|
|
140
|
+
# Calculate cost via shared mixin
|
|
141
|
+
total_cost = self._calculate_cost("openai", model, prompt_tokens, completion_tokens)
|
|
105
142
|
|
|
106
143
|
# Standardized meta object
|
|
107
144
|
meta = {
|
|
@@ -115,3 +152,141 @@ class OpenAIDriver(Driver):
|
|
|
115
152
|
|
|
116
153
|
text = resp.choices[0].message.content
|
|
117
154
|
return {"text": text, "meta": meta}
|
|
155
|
+
|
|
156
|
+
# ------------------------------------------------------------------
|
|
157
|
+
# Tool use
|
|
158
|
+
# ------------------------------------------------------------------
|
|
159
|
+
|
|
160
|
+
def generate_messages_with_tools(
    self,
    messages: list[dict[str, Any]],
    tools: list[dict[str, Any]],
    options: dict[str, Any],
) -> dict[str, Any]:
    """Run one chat completion that is allowed to answer with tool calls.

    Args:
        messages: Chat messages, sent to the API exactly as given.
        tools: Tool definitions, sent to the API exactly as given.
        options: Per-call overrides. ``model``, ``max_tokens`` and
            ``temperature`` are read; the defaults are ``self.model``,
            512 and 1.0.

    Returns:
        A dict with ``text`` (empty string when the message has no
        content), ``meta`` (token counts, rounded cost, raw response
        dump, model name), ``tool_calls`` (one ``id``/``name``/
        ``arguments`` dict per call) and ``stop_reason``.

    Raises:
        RuntimeError: If no OpenAI client is available.
    """
    if self.client is None:
        raise RuntimeError("openai package (>=1.0.0) is not installed")

    model = options.get("model", self.model)
    pricing_entry = self.MODEL_PRICING.get(model, {})
    tokens_param = pricing_entry.get("tokens_param", "max_tokens")
    supports_temperature = pricing_entry.get("supports_temperature", True)

    # Defaults first, caller-supplied options win.
    opts = {"temperature": 1.0, "max_tokens": 512, **options}

    request: dict[str, Any] = {"model": model, "messages": messages, "tools": tools}
    # The name of the token-limit parameter depends on the model entry.
    request[tokens_param] = opts["max_tokens"]
    if supports_temperature:
        request["temperature"] = opts["temperature"]

    resp = self.client.chat.completions.create(**request)

    # ``usage`` may be missing; getattr on None falls back to 0.
    usage = getattr(resp, "usage", None)
    prompt_tokens = getattr(usage, "prompt_tokens", 0)
    completion_tokens = getattr(usage, "completion_tokens", 0)
    total_tokens = getattr(usage, "total_tokens", 0)
    cost = self._calculate_cost("openai", model, prompt_tokens, completion_tokens)

    meta = {
        "prompt_tokens": prompt_tokens,
        "completion_tokens": completion_tokens,
        "total_tokens": total_tokens,
        "cost": round(cost, 6),
        "raw_response": resp.model_dump(),
        "model_name": model,
    }

    choice = resp.choices[0]
    message = choice.message

    def _decode_arguments(raw: Any) -> Any:
        # Arguments that are not valid JSON (or not a string at all)
        # are replaced by an empty dict rather than failing the call.
        try:
            return json.loads(raw)
        except (json.JSONDecodeError, TypeError):
            return {}

    tool_calls_out: list[dict[str, Any]] = [
        {
            "id": call.id,
            "name": call.function.name,
            "arguments": _decode_arguments(call.function.arguments),
        }
        for call in (message.tool_calls or [])
    ]

    return {
        "text": message.content or "",
        "meta": meta,
        "tool_calls": tool_calls_out,
        "stop_reason": choice.finish_reason,
    }
|
|
227
|
+
|
|
228
|
+
# ------------------------------------------------------------------
|
|
229
|
+
# Streaming
|
|
230
|
+
# ------------------------------------------------------------------
|
|
231
|
+
|
|
232
|
+
def generate_messages_stream(
    self,
    messages: list[dict[str, Any]],
    options: dict[str, Any],
) -> Iterator[dict[str, Any]]:
    """Stream a chat completion as a sequence of event dicts.

    Args:
        messages: Chat messages, sent to the API exactly as given.
        options: Per-call overrides. ``model``, ``max_tokens`` and
            ``temperature`` are read; the defaults are ``self.model``,
            512 and 1.0.

    Yields:
        ``{"type": "delta", "text": ...}`` for every non-empty content
        fragment, then one final ``{"type": "done", ...}`` event holding
        the concatenated text and a ``meta`` dict (token counts, rounded
        cost, an empty ``raw_response`` and the model name).

    Raises:
        RuntimeError: If no OpenAI client is available.
    """
    if self.client is None:
        raise RuntimeError("openai package (>=1.0.0) is not installed")

    model = options.get("model", self.model)
    pricing_entry = self.MODEL_PRICING.get(model, {})
    tokens_param = pricing_entry.get("tokens_param", "max_tokens")
    supports_temperature = pricing_entry.get("supports_temperature", True)

    # Defaults first, caller-supplied options win.
    opts = {"temperature": 1.0, "max_tokens": 512, **options}

    request: dict[str, Any] = {
        "model": model,
        "messages": messages,
        "stream": True,
        # Ask the API to report token usage in the stream.
        "stream_options": {"include_usage": True},
    }
    # The name of the token-limit parameter depends on the model entry.
    request[tokens_param] = opts["max_tokens"]
    if supports_temperature:
        request["temperature"] = opts["temperature"]

    fragments: list[str] = []
    prompt_tokens = completion_tokens = 0

    for chunk in self.client.chat.completions.create(**request):
        # Usage comes in the final chunk
        usage = getattr(chunk, "usage", None)
        if usage:
            prompt_tokens = usage.prompt_tokens or 0
            completion_tokens = usage.completion_tokens or 0

        if not chunk.choices:
            continue

        fragment = getattr(chunk.choices[0].delta, "content", None)
        if fragment:
            fragments.append(fragment)
            yield {"type": "delta", "text": fragment}

    cost = self._calculate_cost("openai", model, prompt_tokens, completion_tokens)

    yield {
        "type": "done",
        "text": "".join(fragments),
        "meta": {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": prompt_tokens + completion_tokens,
            "cost": round(cost, 6),
            "raw_response": {},
            "model_name": model,
        },
    }
|
|
@@ -1,14 +1,20 @@
|
|
|
1
1
|
"""OpenRouter driver implementation.
|
|
2
2
|
Requires the `requests` package. Uses OPENROUTER_API_KEY env var.
|
|
3
3
|
"""
|
|
4
|
+
|
|
4
5
|
import os
|
|
5
|
-
from typing import Any
|
|
6
|
+
from typing import Any
|
|
7
|
+
|
|
6
8
|
import requests
|
|
7
9
|
|
|
10
|
+
from ..cost_mixin import CostMixin
|
|
8
11
|
from ..driver import Driver
|
|
9
12
|
|
|
10
13
|
|
|
11
|
-
class OpenRouterDriver(Driver):
|
|
14
|
+
class OpenRouterDriver(CostMixin, Driver):
|
|
15
|
+
supports_json_mode = True
|
|
16
|
+
supports_vision = True
|
|
17
|
+
|
|
12
18
|
# Approximate pricing per 1K tokens based on OpenRouter's pricing
|
|
13
19
|
# https://openrouter.ai/docs#pricing
|
|
14
20
|
MODEL_PRICING = {
|
|
@@ -40,7 +46,7 @@ class OpenRouterDriver(Driver):
|
|
|
40
46
|
|
|
41
47
|
def __init__(self, api_key: str | None = None, model: str = "openai/gpt-3.5-turbo"):
|
|
42
48
|
"""Initialize OpenRouter driver.
|
|
43
|
-
|
|
49
|
+
|
|
44
50
|
Args:
|
|
45
51
|
api_key: OpenRouter API key. If not provided, will look for OPENROUTER_API_KEY env var
|
|
46
52
|
model: Model to use. Defaults to openai/gpt-3.5-turbo
|
|
@@ -48,10 +54,10 @@ class OpenRouterDriver(Driver):
|
|
|
48
54
|
self.api_key = api_key or os.getenv("OPENROUTER_API_KEY")
|
|
49
55
|
if not self.api_key:
|
|
50
56
|
raise ValueError("OpenRouter API key not found. Set OPENROUTER_API_KEY env var.")
|
|
51
|
-
|
|
57
|
+
|
|
52
58
|
self.model = model
|
|
53
59
|
self.base_url = "https://openrouter.ai/api/v1"
|
|
54
|
-
|
|
60
|
+
|
|
55
61
|
# Required headers for OpenRouter
|
|
56
62
|
self.headers = {
|
|
57
63
|
"Authorization": f"Bearer {self.api_key}",
|
|
@@ -59,21 +65,26 @@ class OpenRouterDriver(Driver):
|
|
|
59
65
|
"Content-Type": "application/json",
|
|
60
66
|
}
|
|
61
67
|
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
68
|
+
supports_messages = True
|
|
69
|
+
|
|
70
|
+
def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Run *messages* through the shared OpenAI-style vision helper.

    The helper is imported lazily, inside the method, and its return
    value is passed back unchanged.
    NOTE(review): the helper presumably rewrites image content into the
    OpenAI chat format that OpenRouter accepts -- confirm against
    ``vision_helpers``.
    """
    from .vision_helpers import _prepare_openai_vision_messages as to_vision_messages

    prepared = to_vision_messages(messages)
    return prepared
|
|
74
|
+
|
|
75
|
+
def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
    """Generate a completion for a single plain-text prompt.

    The prompt is wrapped in a one-element message list with the
    ``user`` role and handed to ``_do_generate`` together with
    *options*; whatever ``_do_generate`` returns is returned as is.
    """
    user_turn = {"role": "user", "content": prompt}
    return self._do_generate([user_turn], options)
|
|
78
|
+
|
|
79
|
+
def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
    """Generate a completion from a full message list.

    The messages are first passed through ``_prepare_messages`` and the
    result is forwarded to ``_do_generate`` together with *options*.
    """
    prepared = self._prepare_messages(messages)
    return self._do_generate(prepared, options)
|
|
81
|
+
|
|
82
|
+
def _do_generate(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
|
|
72
83
|
if not self.api_key:
|
|
73
84
|
raise RuntimeError("OpenRouter API key not found")
|
|
74
85
|
|
|
75
86
|
model = options.get("model", self.model)
|
|
76
|
-
|
|
87
|
+
|
|
77
88
|
# Lookup model-specific config
|
|
78
89
|
model_info = self.MODEL_PRICING.get(model, {})
|
|
79
90
|
tokens_param = model_info.get("tokens_param", "max_tokens")
|
|
@@ -85,7 +96,7 @@ class OpenRouterDriver(Driver):
|
|
|
85
96
|
# Base request data
|
|
86
97
|
data = {
|
|
87
98
|
"model": model,
|
|
88
|
-
"messages":
|
|
99
|
+
"messages": messages,
|
|
89
100
|
}
|
|
90
101
|
|
|
91
102
|
# Add token limit with correct parameter name
|
|
@@ -95,6 +106,10 @@ class OpenRouterDriver(Driver):
|
|
|
95
106
|
if supports_temperature and "temperature" in opts:
|
|
96
107
|
data["temperature"] = opts["temperature"]
|
|
97
108
|
|
|
109
|
+
# Native JSON mode support
|
|
110
|
+
if options.get("json_mode"):
|
|
111
|
+
data["response_format"] = {"type": "json_object"}
|
|
112
|
+
|
|
98
113
|
try:
|
|
99
114
|
response = requests.post(
|
|
100
115
|
f"{self.base_url}/chat/completions",
|
|
@@ -110,11 +125,8 @@ class OpenRouterDriver(Driver):
|
|
|
110
125
|
completion_tokens = usage.get("completion_tokens", 0)
|
|
111
126
|
total_tokens = usage.get("total_tokens", 0)
|
|
112
127
|
|
|
113
|
-
# Calculate cost
|
|
114
|
-
|
|
115
|
-
prompt_cost = (prompt_tokens / 1000) * model_pricing["prompt"]
|
|
116
|
-
completion_cost = (completion_tokens / 1000) * model_pricing["completion"]
|
|
117
|
-
total_cost = prompt_cost + completion_cost
|
|
128
|
+
# Calculate cost via shared mixin
|
|
129
|
+
total_cost = self._calculate_cost("openrouter", model, prompt_tokens, completion_tokens)
|
|
118
130
|
|
|
119
131
|
# Standardized meta object
|
|
120
132
|
meta = {
|
|
@@ -130,11 +142,11 @@ class OpenRouterDriver(Driver):
|
|
|
130
142
|
return {"text": text, "meta": meta}
|
|
131
143
|
|
|
132
144
|
except requests.exceptions.RequestException as e:
|
|
133
|
-
error_msg = f"OpenRouter API request failed: {
|
|
134
|
-
if hasattr(e.response,
|
|
145
|
+
error_msg = f"OpenRouter API request failed: {e!s}"
|
|
146
|
+
if hasattr(e.response, "json"):
|
|
135
147
|
try:
|
|
136
148
|
error_details = e.response.json()
|
|
137
149
|
error_msg = f"{error_msg} - {error_details.get('error', {}).get('message', '')}"
|
|
138
150
|
except Exception:
|
|
139
151
|
pass
|
|
140
|
-
raise RuntimeError(error_msg) from e
|
|
152
|
+
raise RuntimeError(error_msg) from e
|
|
@@ -0,0 +1,306 @@
|
|
|
1
|
+
"""Driver registry with plugin support.
|
|
2
|
+
|
|
3
|
+
This module provides a public API for registering custom drivers and
|
|
4
|
+
supports auto-discovery of drivers via Python entry points.
|
|
5
|
+
|
|
6
|
+
Example usage:
|
|
7
|
+
# Register a custom driver
|
|
8
|
+
from prompture import register_driver
|
|
9
|
+
|
|
10
|
+
def my_driver_factory(model=None):
|
|
11
|
+
return MyCustomDriver(model=model)
|
|
12
|
+
|
|
13
|
+
register_driver("my_provider", my_driver_factory)
|
|
14
|
+
|
|
15
|
+
# Now you can use it
|
|
16
|
+
driver = get_driver_for_model("my_provider/my-model")
|
|
17
|
+
|
|
18
|
+
For entry point discovery, add to your package's pyproject.toml:
|
|
19
|
+
[project.entry-points."prompture.drivers"]
|
|
20
|
+
my_provider = "my_package.drivers:my_driver_factory"
|
|
21
|
+
|
|
22
|
+
[project.entry-points."prompture.async_drivers"]
|
|
23
|
+
my_provider = "my_package.drivers:my_async_driver_factory"
|
|
24
|
+
"""
|
|
25
|
+
|
|
26
|
+
from __future__ import annotations
|
|
27
|
+
|
|
28
|
+
import logging
|
|
29
|
+
import sys
|
|
30
|
+
from typing import Callable
|
|
31
|
+
|
|
32
|
+
logger = logging.getLogger("prompture.drivers.registry")
|
|
33
|
+
|
|
34
|
+
# Type alias for driver factory functions
|
|
35
|
+
# A factory takes an optional model name and returns a driver instance
|
|
36
|
+
DriverFactory = Callable[[str | None], object]
|
|
37
|
+
|
|
38
|
+
# Internal registries - populated by built-in drivers and plugins
|
|
39
|
+
_SYNC_REGISTRY: dict[str, DriverFactory] = {}
|
|
40
|
+
_ASYNC_REGISTRY: dict[str, DriverFactory] = {}
|
|
41
|
+
|
|
42
|
+
# Track whether entry points have been loaded
|
|
43
|
+
_entry_points_loaded = False
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def register_driver(name: str, factory: DriverFactory, *, overwrite: bool = False) -> None:
    """Add a sync driver factory to the registry under a provider name.

    Args:
        name: Provider name (e.g. ``"my_provider"``); stored lowercased.
        factory: Callable taking an optional model name and returning a
            driver instance.
        overwrite: When True, silently replace an existing registration
            of the same name. Defaults to False.

    Raises:
        ValueError: If *name* is already registered and *overwrite* is
            False.

    Example:
        >>> def my_factory(model=None):
        ...     return MyDriver(model=model or "default-model")
        >>> register_driver("my_provider", my_factory)
    """
    name = name.lower()
    if not overwrite and name in _SYNC_REGISTRY:
        raise ValueError(f"Driver '{name}' is already registered. Use overwrite=True to replace it.")

    _SYNC_REGISTRY[name] = factory
    logger.debug("Registered sync driver: %s", name)
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
def register_async_driver(name: str, factory: DriverFactory, *, overwrite: bool = False) -> None:
    """Add an async driver factory to the registry under a provider name.

    Args:
        name: Provider name (e.g. ``"my_provider"``); stored lowercased.
        factory: Callable taking an optional model name and returning an
            async driver instance.
        overwrite: When True, silently replace an existing registration
            of the same name. Defaults to False.

    Raises:
        ValueError: If *name* is already registered as an async driver
            and *overwrite* is False.

    Example:
        >>> def my_async_factory(model=None):
        ...     return MyAsyncDriver(model=model or "default-model")
        >>> register_async_driver("my_provider", my_async_factory)
    """
    name = name.lower()
    if not overwrite and name in _ASYNC_REGISTRY:
        raise ValueError(f"Async driver '{name}' is already registered. Use overwrite=True to replace it.")

    _ASYNC_REGISTRY[name] = factory
    logger.debug("Registered async driver: %s", name)
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
def unregister_driver(name: str) -> bool:
    """Remove a sync driver registration.

    Args:
        name: Provider name; compared lowercased.

    Returns:
        True when an entry was removed, False when nothing was
        registered under that name.
    """
    key = name.lower()
    try:
        del _SYNC_REGISTRY[key]
    except KeyError:
        return False

    logger.debug("Unregistered sync driver: %s", key)
    return True
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
def unregister_async_driver(name: str) -> bool:
    """Remove an async driver registration.

    Args:
        name: Provider name; compared lowercased.

    Returns:
        True when an entry was removed, False when nothing was
        registered under that name.
    """
    key = name.lower()
    try:
        del _ASYNC_REGISTRY[key]
    except KeyError:
        return False

    logger.debug("Unregistered async driver: %s", key)
    return True
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
def list_registered_drivers() -> list[str]:
    """Return the names of all registered sync drivers in sorted order.

    Entry-point discovery is triggered first if it has not run yet.
    """
    _ensure_entry_points_loaded()
    names = sorted(_SYNC_REGISTRY)
    return names
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
def list_registered_async_drivers() -> list[str]:
    """Return the names of all registered async drivers in sorted order.

    Entry-point discovery is triggered first if it has not run yet.
    """
    _ensure_entry_points_loaded()
    names = sorted(_ASYNC_REGISTRY)
    return names
|
|
146
|
+
|
|
147
|
+
|
|
148
|
+
def is_driver_registered(name: str) -> bool:
    """Tell whether a sync driver exists for a provider name.

    Args:
        name: Provider name; compared lowercased.

    Returns:
        True if the name is present in the sync registry (after
        entry-point discovery has run).
    """
    _ensure_entry_points_loaded()
    key = name.lower()
    return key in _SYNC_REGISTRY
|
|
159
|
+
|
|
160
|
+
|
|
161
|
+
def is_async_driver_registered(name: str) -> bool:
    """Tell whether an async driver exists for a provider name.

    Args:
        name: Provider name; compared lowercased.

    Returns:
        True if the name is present in the async registry (after
        entry-point discovery has run).
    """
    _ensure_entry_points_loaded()
    key = name.lower()
    return key in _ASYNC_REGISTRY
|
|
172
|
+
|
|
173
|
+
|
|
174
|
+
def get_driver_factory(name: str) -> DriverFactory:
    """Look up the sync driver factory registered for a provider.

    Args:
        name: Provider name; compared lowercased.

    Returns:
        The registered factory callable.

    Raises:
        ValueError: If no sync driver is registered under *name*.
    """
    _ensure_entry_points_loaded()
    name = name.lower()
    if name in _SYNC_REGISTRY:
        return _SYNC_REGISTRY[name]
    raise ValueError(f"Unsupported provider '{name}'")
|
|
191
|
+
|
|
192
|
+
|
|
193
|
+
def get_async_driver_factory(name: str) -> DriverFactory:
    """Look up the async driver factory registered for a provider.

    Args:
        name: Provider name; compared lowercased.

    Returns:
        The registered factory callable.

    Raises:
        ValueError: If no async driver is registered under *name*.
    """
    _ensure_entry_points_loaded()
    name = name.lower()
    if name in _ASYNC_REGISTRY:
        return _ASYNC_REGISTRY[name]
    raise ValueError(f"Unsupported provider '{name}'")
|
|
210
|
+
|
|
211
|
+
|
|
212
|
+
def _load_entry_point_group(eps, registry: dict, *, kind: str, skip_label: str) -> int:
    """Register every not-yet-known factory of one entry-point group.

    Args:
        eps: Iterable of entry points belonging to a single group.
        registry: Registry dict to fill (name -> factory); mutated in place.
        kind: ``"sync"`` or ``"async"``, used in log messages.
        skip_label: Noun used in the "already registered" debug message.

    Returns:
        The number of factories added to *registry*.
    """
    loaded = 0
    for ep in eps:
        try:
            key = ep.name.lower()
            # Skip if already registered (built-in drivers take precedence)
            if key in registry:
                logger.debug("Skipping entry point %s '%s' (already registered)", skip_label, ep.name)
                continue

            registry[key] = ep.load()
            loaded += 1
            logger.info("Loaded %s driver from entry point: %s", kind, ep.name)
        except Exception:
            # Deliberately best-effort: one broken plugin is logged and
            # must not prevent the remaining plugins from loading.
            logger.exception("Failed to load %s driver entry point: %s", kind, ep.name)
    return loaded


def load_entry_point_drivers() -> tuple[int, int]:
    """Load drivers from installed packages via entry points.

    Scans the ``prompture.drivers`` and ``prompture.async_drivers``
    entry-point groups and registers every factory whose name is not
    already present in the corresponding registry. A plugin that fails
    to load is logged and skipped.

    Returns:
        A tuple of (sync_count, async_count) indicating how many drivers
        were loaded from entry points.

    Example pyproject.toml for a plugin package:
        [project.entry-points."prompture.drivers"]
        my_provider = "my_package.drivers:create_my_driver"

        [project.entry-points."prompture.async_drivers"]
        my_provider = "my_package.drivers:create_my_async_driver"
    """
    global _entry_points_loaded

    # importlib.metadata ships with the standard library since Python 3.8,
    # so no backport is imported; only the *query API* differs by version.
    from importlib.metadata import entry_points

    # Mark discovery as done before any plugin code is imported, so that a
    # plugin module which queries the registry at import time does not
    # re-enter this function through _ensure_entry_points_loaded().
    previously_loaded = _entry_points_loaded
    _entry_points_loaded = True

    try:
        if sys.version_info >= (3, 10):
            # Python 3.10+: entry points can be selected by group keyword.
            sync_eps = entry_points(group="prompture.drivers")
            async_eps = entry_points(group="prompture.async_drivers")
        else:
            # Python 3.8/3.9: entry_points() returns a dict keyed by group.
            all_eps = entry_points()
            sync_eps = all_eps.get("prompture.drivers", [])
            async_eps = all_eps.get("prompture.async_drivers", [])
    except Exception:
        # Discovery itself failed: restore the flag so a later call retries.
        _entry_points_loaded = previously_loaded
        raise

    sync_count = _load_entry_point_group(sync_eps, _SYNC_REGISTRY, kind="sync", skip_label="driver")
    async_count = _load_entry_point_group(async_eps, _ASYNC_REGISTRY, kind="async", skip_label="async driver")

    return (sync_count, async_count)
|
|
280
|
+
|
|
281
|
+
|
|
282
|
+
def _ensure_entry_points_loaded() -> None:
    """Run entry-point discovery unless the loaded flag is already set."""
    if _entry_points_loaded:
        return
    load_entry_point_drivers()
|
|
287
|
+
|
|
288
|
+
|
|
289
|
+
def _get_sync_registry() -> dict[str, DriverFactory]:
    """Return the live sync registry dict (internal helper for drivers/__init__.py).

    Entry-point discovery is triggered first if needed. The returned
    object is the registry itself, not a copy.
    """
    _ensure_entry_points_loaded()
    registry = _SYNC_REGISTRY
    return registry
|
|
293
|
+
|
|
294
|
+
|
|
295
|
+
def _get_async_registry() -> dict[str, DriverFactory]:
    """Return the live async registry dict (internal helper for drivers/async_registry.py).

    Entry-point discovery is triggered first if needed. The returned
    object is the registry itself, not a copy.
    """
    _ensure_entry_points_loaded()
    registry = _ASYNC_REGISTRY
    return registry
|
|
299
|
+
|
|
300
|
+
|
|
301
|
+
def _reset_registries() -> None:
    """Empty both registries and clear the loaded flag (for testing only).

    The registry dicts are cleared in place, so existing references to
    them stay valid. The next lookup will run entry-point discovery
    again.
    """
    global _entry_points_loaded
    for registry in (_SYNC_REGISTRY, _ASYNC_REGISTRY):
        registry.clear()
    _entry_points_loaded = False
|