prompture 0.0.34__py3-none-any.whl → 0.0.34.dev1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- prompture/__init__.py +0 -21
- prompture/_version.py +2 -2
- prompture/drivers/__init__.py +26 -118
- prompture/drivers/async_registry.py +31 -80
- {prompture-0.0.34.dist-info → prompture-0.0.34.dev1.dist-info}/METADATA +1 -1
- {prompture-0.0.34.dist-info → prompture-0.0.34.dev1.dist-info}/RECORD +10 -11
- prompture/drivers/registry.py +0 -306
- {prompture-0.0.34.dist-info → prompture-0.0.34.dev1.dist-info}/WHEEL +0 -0
- {prompture-0.0.34.dist-info → prompture-0.0.34.dev1.dist-info}/entry_points.txt +0 -0
- {prompture-0.0.34.dist-info → prompture-0.0.34.dev1.dist-info}/licenses/LICENSE +0 -0
- {prompture-0.0.34.dist-info → prompture-0.0.34.dev1.dist-info}/top_level.txt +0 -0
prompture/__init__.py
CHANGED
|
@@ -42,16 +42,6 @@ from .drivers import (
|
|
|
42
42
|
OpenRouterDriver,
|
|
43
43
|
get_driver,
|
|
44
44
|
get_driver_for_model,
|
|
45
|
-
# Plugin registration API
|
|
46
|
-
is_async_driver_registered,
|
|
47
|
-
is_driver_registered,
|
|
48
|
-
list_registered_async_drivers,
|
|
49
|
-
list_registered_drivers,
|
|
50
|
-
load_entry_point_drivers,
|
|
51
|
-
register_async_driver,
|
|
52
|
-
register_driver,
|
|
53
|
-
unregister_async_driver,
|
|
54
|
-
unregister_driver,
|
|
55
45
|
)
|
|
56
46
|
from .field_definitions import (
|
|
57
47
|
FIELD_DEFINITIONS,
|
|
@@ -153,25 +143,14 @@ __all__ = [
|
|
|
153
143
|
"get_model_rates",
|
|
154
144
|
"get_registry_snapshot",
|
|
155
145
|
"get_required_fields",
|
|
156
|
-
# Plugin registration API
|
|
157
|
-
"is_async_driver_registered",
|
|
158
|
-
"is_driver_registered",
|
|
159
|
-
"list_registered_async_drivers",
|
|
160
|
-
"list_registered_drivers",
|
|
161
|
-
"load_entry_point_drivers",
|
|
162
|
-
# Other exports
|
|
163
146
|
"manual_extract_and_jsonify",
|
|
164
147
|
"normalize_enum_value",
|
|
165
148
|
"refresh_rates_cache",
|
|
166
|
-
"register_async_driver",
|
|
167
|
-
"register_driver",
|
|
168
149
|
"register_field",
|
|
169
150
|
"render_output",
|
|
170
151
|
"reset_registry",
|
|
171
152
|
"run_suite_from_spec",
|
|
172
153
|
"stepwise_extract_with_model",
|
|
173
|
-
"unregister_async_driver",
|
|
174
|
-
"unregister_driver",
|
|
175
154
|
"validate_against_schema",
|
|
176
155
|
"validate_enum_value",
|
|
177
156
|
]
|
prompture/_version.py
CHANGED
|
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
|
|
|
28
28
|
commit_id: COMMIT_ID
|
|
29
29
|
__commit_id__: COMMIT_ID
|
|
30
30
|
|
|
31
|
-
__version__ = version = '0.0.34'
|
|
32
|
-
__version_tuple__ = version_tuple = (0, 0, 34)
|
|
31
|
+
__version__ = version = '0.0.34.dev1'
|
|
32
|
+
__version_tuple__ = version_tuple = (0, 0, 34, 'dev1')
|
|
33
33
|
|
|
34
34
|
__commit_id__ = commit_id = None
|
prompture/drivers/__init__.py
CHANGED
|
@@ -1,29 +1,3 @@
|
|
|
1
|
-
"""Driver registry and factory functions.
|
|
2
|
-
|
|
3
|
-
This module provides:
|
|
4
|
-
- Built-in drivers for popular LLM providers
|
|
5
|
-
- A pluggable registry system for custom drivers
|
|
6
|
-
- Factory functions to instantiate drivers by provider/model name
|
|
7
|
-
|
|
8
|
-
Custom Driver Registration:
|
|
9
|
-
from prompture import register_driver
|
|
10
|
-
|
|
11
|
-
def my_driver_factory(model=None):
|
|
12
|
-
return MyCustomDriver(model=model)
|
|
13
|
-
|
|
14
|
-
register_driver("my_provider", my_driver_factory)
|
|
15
|
-
|
|
16
|
-
# Now you can use it
|
|
17
|
-
driver = get_driver_for_model("my_provider/my-model")
|
|
18
|
-
|
|
19
|
-
Entry Point Discovery:
|
|
20
|
-
Third-party packages can register drivers via entry points.
|
|
21
|
-
Add to your pyproject.toml:
|
|
22
|
-
|
|
23
|
-
[project.entry-points."prompture.drivers"]
|
|
24
|
-
my_provider = "my_package.drivers:my_driver_factory"
|
|
25
|
-
"""
|
|
26
|
-
|
|
27
1
|
from typing import Optional
|
|
28
2
|
|
|
29
3
|
from ..settings import settings
|
|
@@ -51,85 +25,30 @@ from .local_http_driver import LocalHTTPDriver
|
|
|
51
25
|
from .ollama_driver import OllamaDriver
|
|
52
26
|
from .openai_driver import OpenAIDriver
|
|
53
27
|
from .openrouter_driver import OpenRouterDriver
|
|
54
|
-
from .registry import (
|
|
55
|
-
_get_sync_registry,
|
|
56
|
-
get_async_driver_factory,
|
|
57
|
-
get_driver_factory,
|
|
58
|
-
is_async_driver_registered,
|
|
59
|
-
is_driver_registered,
|
|
60
|
-
list_registered_async_drivers,
|
|
61
|
-
list_registered_drivers,
|
|
62
|
-
load_entry_point_drivers,
|
|
63
|
-
register_async_driver,
|
|
64
|
-
register_driver,
|
|
65
|
-
unregister_async_driver,
|
|
66
|
-
unregister_driver,
|
|
67
|
-
)
|
|
68
28
|
|
|
69
|
-
#
|
|
70
|
-
|
|
71
|
-
"openai",
|
|
72
|
-
lambda model=None:
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
lambda model=None:
|
|
78
|
-
overwrite=True,
|
|
79
|
-
)
|
|
80
|
-
register_driver(
|
|
81
|
-
"claude",
|
|
82
|
-
lambda model=None: ClaudeDriver(api_key=settings.claude_api_key, model=model or settings.claude_model),
|
|
83
|
-
overwrite=True,
|
|
84
|
-
)
|
|
85
|
-
register_driver(
|
|
86
|
-
"lmstudio",
|
|
87
|
-
lambda model=None: LMStudioDriver(endpoint=settings.lmstudio_endpoint, model=model or settings.lmstudio_model),
|
|
88
|
-
overwrite=True,
|
|
89
|
-
)
|
|
90
|
-
register_driver(
|
|
91
|
-
"azure",
|
|
92
|
-
lambda model=None: AzureDriver(
|
|
29
|
+
# Central registry: maps provider → factory function
|
|
30
|
+
DRIVER_REGISTRY = {
|
|
31
|
+
"openai": lambda model=None: OpenAIDriver(api_key=settings.openai_api_key, model=model or settings.openai_model),
|
|
32
|
+
"ollama": lambda model=None: OllamaDriver(endpoint=settings.ollama_endpoint, model=model or settings.ollama_model),
|
|
33
|
+
"claude": lambda model=None: ClaudeDriver(api_key=settings.claude_api_key, model=model or settings.claude_model),
|
|
34
|
+
"lmstudio": lambda model=None: LMStudioDriver(
|
|
35
|
+
endpoint=settings.lmstudio_endpoint, model=model or settings.lmstudio_model
|
|
36
|
+
),
|
|
37
|
+
"azure": lambda model=None: AzureDriver(
|
|
93
38
|
api_key=settings.azure_api_key, endpoint=settings.azure_api_endpoint, deployment_id=settings.azure_deployment_id
|
|
94
39
|
),
|
|
95
|
-
|
|
96
|
-
)
|
|
97
|
-
|
|
98
|
-
"
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
)
|
|
102
|
-
|
|
103
|
-
"google",
|
|
104
|
-
lambda model=None: GoogleDriver(api_key=settings.google_api_key, model=model or settings.google_model),
|
|
105
|
-
overwrite=True,
|
|
106
|
-
)
|
|
107
|
-
register_driver(
|
|
108
|
-
"groq",
|
|
109
|
-
lambda model=None: GroqDriver(api_key=settings.groq_api_key, model=model or settings.groq_model),
|
|
110
|
-
overwrite=True,
|
|
111
|
-
)
|
|
112
|
-
register_driver(
|
|
113
|
-
"openrouter",
|
|
114
|
-
lambda model=None: OpenRouterDriver(api_key=settings.openrouter_api_key, model=model or settings.openrouter_model),
|
|
115
|
-
overwrite=True,
|
|
116
|
-
)
|
|
117
|
-
register_driver(
|
|
118
|
-
"grok",
|
|
119
|
-
lambda model=None: GrokDriver(api_key=settings.grok_api_key, model=model or settings.grok_model),
|
|
120
|
-
overwrite=True,
|
|
121
|
-
)
|
|
122
|
-
register_driver(
|
|
123
|
-
"airllm",
|
|
124
|
-
lambda model=None: AirLLMDriver(
|
|
40
|
+
"local_http": lambda model=None: LocalHTTPDriver(endpoint=settings.local_http_endpoint, model=model),
|
|
41
|
+
"google": lambda model=None: GoogleDriver(api_key=settings.google_api_key, model=model or settings.google_model),
|
|
42
|
+
"groq": lambda model=None: GroqDriver(api_key=settings.groq_api_key, model=model or settings.groq_model),
|
|
43
|
+
"openrouter": lambda model=None: OpenRouterDriver(
|
|
44
|
+
api_key=settings.openrouter_api_key, model=model or settings.openrouter_model
|
|
45
|
+
),
|
|
46
|
+
"grok": lambda model=None: GrokDriver(api_key=settings.grok_api_key, model=model or settings.grok_model),
|
|
47
|
+
"airllm": lambda model=None: AirLLMDriver(
|
|
125
48
|
model=model or settings.airllm_model,
|
|
126
49
|
compression=settings.airllm_compression,
|
|
127
50
|
),
|
|
128
|
-
|
|
129
|
-
)
|
|
130
|
-
|
|
131
|
-
# Backwards compatibility: expose registry dict (read-only view recommended)
|
|
132
|
-
DRIVER_REGISTRY = _get_sync_registry()
|
|
51
|
+
}
|
|
133
52
|
|
|
134
53
|
|
|
135
54
|
def get_driver(provider_name: Optional[str] = None):
|
|
@@ -138,8 +57,9 @@ def get_driver(provider_name: Optional[str] = None):
|
|
|
138
57
|
Uses default model from settings if not overridden.
|
|
139
58
|
"""
|
|
140
59
|
provider = (provider_name or settings.ai_provider or "ollama").strip().lower()
|
|
141
|
-
|
|
142
|
-
|
|
60
|
+
if provider not in DRIVER_REGISTRY:
|
|
61
|
+
raise ValueError(f"Unknown provider: {provider_name}")
|
|
62
|
+
return DRIVER_REGISTRY[provider]() # use default model from settings
|
|
143
63
|
|
|
144
64
|
|
|
145
65
|
def get_driver_for_model(model_str: str):
|
|
@@ -170,20 +90,19 @@ def get_driver_for_model(model_str: str):
|
|
|
170
90
|
provider = parts[0].lower()
|
|
171
91
|
model_id = parts[1] if len(parts) > 1 else None
|
|
172
92
|
|
|
173
|
-
#
|
|
174
|
-
|
|
93
|
+
# Validate provider
|
|
94
|
+
if provider not in DRIVER_REGISTRY:
|
|
95
|
+
raise ValueError(f"Unsupported provider '{provider}'")
|
|
175
96
|
|
|
176
97
|
# Create driver with model ID if provided, otherwise use default
|
|
177
|
-
return
|
|
98
|
+
return DRIVER_REGISTRY[provider](model_id)
|
|
178
99
|
|
|
179
100
|
|
|
180
101
|
__all__ = [
|
|
102
|
+
# Async drivers
|
|
181
103
|
"ASYNC_DRIVER_REGISTRY",
|
|
182
|
-
# Legacy registry dicts (for backwards compatibility)
|
|
183
|
-
"DRIVER_REGISTRY",
|
|
184
104
|
# Sync drivers
|
|
185
105
|
"AirLLMDriver",
|
|
186
|
-
# Async drivers
|
|
187
106
|
"AsyncAirLLMDriver",
|
|
188
107
|
"AsyncAzureDriver",
|
|
189
108
|
"AsyncClaudeDriver",
|
|
@@ -208,17 +127,6 @@ __all__ = [
|
|
|
208
127
|
"OpenRouterDriver",
|
|
209
128
|
"get_async_driver",
|
|
210
129
|
"get_async_driver_for_model",
|
|
211
|
-
# Factory functions
|
|
212
130
|
"get_driver",
|
|
213
131
|
"get_driver_for_model",
|
|
214
|
-
"is_async_driver_registered",
|
|
215
|
-
"is_driver_registered",
|
|
216
|
-
"list_registered_async_drivers",
|
|
217
|
-
"list_registered_drivers",
|
|
218
|
-
"load_entry_point_drivers",
|
|
219
|
-
"register_async_driver",
|
|
220
|
-
# Registry functions (public API)
|
|
221
|
-
"register_driver",
|
|
222
|
-
"unregister_async_driver",
|
|
223
|
-
"unregister_driver",
|
|
224
132
|
]
|
|
@@ -1,15 +1,4 @@
|
|
|
1
|
-
"""Async driver registry — mirrors the sync DRIVER_REGISTRY.
|
|
2
|
-
|
|
3
|
-
This module provides async driver registration and factory functions.
|
|
4
|
-
Custom async drivers can be registered via the ``register_async_driver()``
|
|
5
|
-
function or discovered via entry points.
|
|
6
|
-
|
|
7
|
-
Entry Point Discovery:
|
|
8
|
-
Add to your pyproject.toml:
|
|
9
|
-
|
|
10
|
-
[project.entry-points."prompture.async_drivers"]
|
|
11
|
-
my_provider = "my_package.drivers:my_async_driver_factory"
|
|
12
|
-
"""
|
|
1
|
+
"""Async driver registry — mirrors the sync DRIVER_REGISTRY."""
|
|
13
2
|
|
|
14
3
|
from __future__ import annotations
|
|
15
4
|
|
|
@@ -25,78 +14,37 @@ from .async_local_http_driver import AsyncLocalHTTPDriver
|
|
|
25
14
|
from .async_ollama_driver import AsyncOllamaDriver
|
|
26
15
|
from .async_openai_driver import AsyncOpenAIDriver
|
|
27
16
|
from .async_openrouter_driver import AsyncOpenRouterDriver
|
|
28
|
-
from .registry import (
|
|
29
|
-
_get_async_registry,
|
|
30
|
-
get_async_driver_factory,
|
|
31
|
-
register_async_driver,
|
|
32
|
-
)
|
|
33
17
|
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
"
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
lambda model=None:
|
|
48
|
-
overwrite=True,
|
|
49
|
-
)
|
|
50
|
-
register_async_driver(
|
|
51
|
-
"lmstudio",
|
|
52
|
-
lambda model=None: AsyncLMStudioDriver(endpoint=settings.lmstudio_endpoint, model=model or settings.lmstudio_model),
|
|
53
|
-
overwrite=True,
|
|
54
|
-
)
|
|
55
|
-
register_async_driver(
|
|
56
|
-
"azure",
|
|
57
|
-
lambda model=None: AsyncAzureDriver(
|
|
18
|
+
ASYNC_DRIVER_REGISTRY = {
|
|
19
|
+
"openai": lambda model=None: AsyncOpenAIDriver(
|
|
20
|
+
api_key=settings.openai_api_key, model=model or settings.openai_model
|
|
21
|
+
),
|
|
22
|
+
"ollama": lambda model=None: AsyncOllamaDriver(
|
|
23
|
+
endpoint=settings.ollama_endpoint, model=model or settings.ollama_model
|
|
24
|
+
),
|
|
25
|
+
"claude": lambda model=None: AsyncClaudeDriver(
|
|
26
|
+
api_key=settings.claude_api_key, model=model or settings.claude_model
|
|
27
|
+
),
|
|
28
|
+
"lmstudio": lambda model=None: AsyncLMStudioDriver(
|
|
29
|
+
endpoint=settings.lmstudio_endpoint, model=model or settings.lmstudio_model
|
|
30
|
+
),
|
|
31
|
+
"azure": lambda model=None: AsyncAzureDriver(
|
|
58
32
|
api_key=settings.azure_api_key, endpoint=settings.azure_api_endpoint, deployment_id=settings.azure_deployment_id
|
|
59
33
|
),
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
lambda model=None:
|
|
65
|
-
|
|
66
|
-
)
|
|
67
|
-
register_async_driver(
|
|
68
|
-
"google",
|
|
69
|
-
lambda model=None: AsyncGoogleDriver(api_key=settings.google_api_key, model=model or settings.google_model),
|
|
70
|
-
overwrite=True,
|
|
71
|
-
)
|
|
72
|
-
register_async_driver(
|
|
73
|
-
"groq",
|
|
74
|
-
lambda model=None: AsyncGroqDriver(api_key=settings.groq_api_key, model=model or settings.groq_model),
|
|
75
|
-
overwrite=True,
|
|
76
|
-
)
|
|
77
|
-
register_async_driver(
|
|
78
|
-
"openrouter",
|
|
79
|
-
lambda model=None: AsyncOpenRouterDriver(
|
|
34
|
+
"local_http": lambda model=None: AsyncLocalHTTPDriver(endpoint=settings.local_http_endpoint, model=model),
|
|
35
|
+
"google": lambda model=None: AsyncGoogleDriver(
|
|
36
|
+
api_key=settings.google_api_key, model=model or settings.google_model
|
|
37
|
+
),
|
|
38
|
+
"groq": lambda model=None: AsyncGroqDriver(api_key=settings.groq_api_key, model=model or settings.groq_model),
|
|
39
|
+
"openrouter": lambda model=None: AsyncOpenRouterDriver(
|
|
80
40
|
api_key=settings.openrouter_api_key, model=model or settings.openrouter_model
|
|
81
41
|
),
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
register_async_driver(
|
|
85
|
-
"grok",
|
|
86
|
-
lambda model=None: AsyncGrokDriver(api_key=settings.grok_api_key, model=model or settings.grok_model),
|
|
87
|
-
overwrite=True,
|
|
88
|
-
)
|
|
89
|
-
register_async_driver(
|
|
90
|
-
"airllm",
|
|
91
|
-
lambda model=None: AsyncAirLLMDriver(
|
|
42
|
+
"grok": lambda model=None: AsyncGrokDriver(api_key=settings.grok_api_key, model=model or settings.grok_model),
|
|
43
|
+
"airllm": lambda model=None: AsyncAirLLMDriver(
|
|
92
44
|
model=model or settings.airllm_model,
|
|
93
45
|
compression=settings.airllm_compression,
|
|
94
46
|
),
|
|
95
|
-
|
|
96
|
-
)
|
|
97
|
-
|
|
98
|
-
# Backwards compatibility: expose registry dict
|
|
99
|
-
ASYNC_DRIVER_REGISTRY = _get_async_registry()
|
|
47
|
+
}
|
|
100
48
|
|
|
101
49
|
|
|
102
50
|
def get_async_driver(provider_name: str | None = None):
|
|
@@ -105,8 +53,9 @@ def get_async_driver(provider_name: str | None = None):
|
|
|
105
53
|
Uses default model from settings if not overridden.
|
|
106
54
|
"""
|
|
107
55
|
provider = (provider_name or settings.ai_provider or "ollama").strip().lower()
|
|
108
|
-
|
|
109
|
-
|
|
56
|
+
if provider not in ASYNC_DRIVER_REGISTRY:
|
|
57
|
+
raise ValueError(f"Unknown provider: {provider_name}")
|
|
58
|
+
return ASYNC_DRIVER_REGISTRY[provider]()
|
|
110
59
|
|
|
111
60
|
|
|
112
61
|
def get_async_driver_for_model(model_str: str):
|
|
@@ -125,5 +74,7 @@ def get_async_driver_for_model(model_str: str):
|
|
|
125
74
|
provider = parts[0].lower()
|
|
126
75
|
model_id = parts[1] if len(parts) > 1 else None
|
|
127
76
|
|
|
128
|
-
|
|
129
|
-
|
|
77
|
+
if provider not in ASYNC_DRIVER_REGISTRY:
|
|
78
|
+
raise ValueError(f"Unsupported provider '{provider}'")
|
|
79
|
+
|
|
80
|
+
return ASYNC_DRIVER_REGISTRY[provider](model_id)
|
|
@@ -1,5 +1,5 @@
|
|
|
1
|
-
prompture/__init__.py,sha256=
|
|
2
|
-
prompture/_version.py,sha256=
|
|
1
|
+
prompture/__init__.py,sha256=XLOuu8AnbAjEq7oKPyrIcXzoSFzfzjBfWmS4WjZlCpU,3945
|
|
2
|
+
prompture/_version.py,sha256=oWvGRxkg4wF7c8HSUd3ltOCQNJfrhcDmrYXTODvTKeA,719
|
|
3
3
|
prompture/async_conversation.py,sha256=Q4wJctjTffH55i5FnxV5tt0xB-WCJCfCmxeHLOisuP8,18807
|
|
4
4
|
prompture/async_core.py,sha256=s8G0nGUGR1Bf_BQG9_FcQRpveSnJKkEwcWNfbAJaSkg,29208
|
|
5
5
|
prompture/async_driver.py,sha256=8mV3wEQiGOGuegaAPJ6uT1lNXwFDmbUGKXZkHk9Baow,4630
|
|
@@ -20,7 +20,7 @@ prompture/settings.py,sha256=o-zsYpxRvSg-ICGWqqVNEoJG23GCMBLlkC7RPXpouSw,1976
|
|
|
20
20
|
prompture/tools.py,sha256=PmFbGHTWYWahpJOG6BLlM0Y-EG6S37IFW57C-8GdsXo,36449
|
|
21
21
|
prompture/validator.py,sha256=FY_VjIVEbjG2nwzh-r6l23Kt3UzaLyCis8_pZMNGHBA,993
|
|
22
22
|
prompture/aio/__init__.py,sha256=bKqTu4Jxld16aP_7SP9wU5au45UBIb041ORo4E4HzVo,1810
|
|
23
|
-
prompture/drivers/__init__.py,sha256=
|
|
23
|
+
prompture/drivers/__init__.py,sha256=5bNmiKaKYh1hanTa_SZ0akwEcOMVJoNSeMGBXb67VfM,5083
|
|
24
24
|
prompture/drivers/airllm_driver.py,sha256=SaTh7e7Plvuct_TfRqQvsJsKHvvM_3iVqhBtlciM-Kw,3858
|
|
25
25
|
prompture/drivers/async_airllm_driver.py,sha256=1hIWLXfyyIg9tXaOE22tLJvFyNwHnOi1M5BIKnV8ysk,908
|
|
26
26
|
prompture/drivers/async_azure_driver.py,sha256=vRp1PlOB87OLUbEZJEp7En3tvadG956Q6AV2o9UmyLA,4196
|
|
@@ -34,7 +34,7 @@ prompture/drivers/async_local_http_driver.py,sha256=qoigIf-w3_c2dbVdM6m1e2RMAWP4
|
|
|
34
34
|
prompture/drivers/async_ollama_driver.py,sha256=vRd2VIl412d6WVSo8vmZg0GBYUo7gBj-S2_55PpUWbk,4511
|
|
35
35
|
prompture/drivers/async_openai_driver.py,sha256=jHtSA_MeeIwGeE9o9F1ZsKTNgGGA7xF3WbGZgD8ACEU,3305
|
|
36
36
|
prompture/drivers/async_openrouter_driver.py,sha256=OKL4MfRAopXaMevf6A6WcAytyvWr0tWO_BmshdI0fSY,3516
|
|
37
|
-
prompture/drivers/async_registry.py,sha256=
|
|
37
|
+
prompture/drivers/async_registry.py,sha256=xI_QI8z56vg6qWy6mJTsIdtd16o3nilBZbVIbfArEKI,3393
|
|
38
38
|
prompture/drivers/azure_driver.py,sha256=4IAzdKqcORgVEDUj6itkVmJUg1ayo4HXSfqLKzIGnlM,5460
|
|
39
39
|
prompture/drivers/claude_driver.py,sha256=11nEiq0Ga-d-0vwxbWo-cAl4AjnYPLdRyPtucJs9xTA,4985
|
|
40
40
|
prompture/drivers/google_driver.py,sha256=Ysa1ZZEAPEKHCJFCBiJtB4K-sGvsYti4hGBv_85nowY,8454
|
|
@@ -46,10 +46,9 @@ prompture/drivers/local_http_driver.py,sha256=QJgEf9kAmy8YZ5fb8FHnWuhoDoZYNd8at4
|
|
|
46
46
|
prompture/drivers/ollama_driver.py,sha256=fNvHW5mp7cIwpZKCS5r7WAO-yTK01BKTKvotM_GJCE0,7229
|
|
47
47
|
prompture/drivers/openai_driver.py,sha256=-mgieUBlNgjz5B6ejRwMPoIGgvEW2p5LLwS9j0B9hno,4815
|
|
48
48
|
prompture/drivers/openrouter_driver.py,sha256=WH48KEkafuxFX6b55FzwT57tUlmbwYlHSeNsIxWvM4o,5141
|
|
49
|
-
prompture/
|
|
50
|
-
prompture-0.0.34.dist-info/
|
|
51
|
-
prompture-0.0.34.dist-info/
|
|
52
|
-
prompture-0.0.34.dist-info/
|
|
53
|
-
prompture-0.0.34.dist-info/
|
|
54
|
-
prompture-0.0.34.dist-info/
|
|
55
|
-
prompture-0.0.34.dist-info/RECORD,,
|
|
49
|
+
prompture-0.0.34.dev1.dist-info/licenses/LICENSE,sha256=0HgDepH7aaHNFhHF-iXuW6_GqDfYPnVkjtiCAZ4yS8I,1060
|
|
50
|
+
prompture-0.0.34.dev1.dist-info/METADATA,sha256=aEHoj6Hw91QHAzbZ40RoW2RQ9K0RckRWEnsj-lr7c3c,18231
|
|
51
|
+
prompture-0.0.34.dev1.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
|
|
52
|
+
prompture-0.0.34.dev1.dist-info/entry_points.txt,sha256=AFPG3lJR86g4IJMoWQUW5Ph7G6MLNWG3A2u2Tp9zkp8,48
|
|
53
|
+
prompture-0.0.34.dev1.dist-info/top_level.txt,sha256=to86zq_kjfdoLeAxQNr420UWqT0WzkKoZ509J7Qr2t4,10
|
|
54
|
+
prompture-0.0.34.dev1.dist-info/RECORD,,
|
prompture/drivers/registry.py
DELETED
|
@@ -1,306 +0,0 @@
|
|
|
1
|
-
"""Driver registry with plugin support.
|
|
2
|
-
|
|
3
|
-
This module provides a public API for registering custom drivers and
|
|
4
|
-
supports auto-discovery of drivers via Python entry points.
|
|
5
|
-
|
|
6
|
-
Example usage:
|
|
7
|
-
# Register a custom driver
|
|
8
|
-
from prompture import register_driver
|
|
9
|
-
|
|
10
|
-
def my_driver_factory(model=None):
|
|
11
|
-
return MyCustomDriver(model=model)
|
|
12
|
-
|
|
13
|
-
register_driver("my_provider", my_driver_factory)
|
|
14
|
-
|
|
15
|
-
# Now you can use it
|
|
16
|
-
driver = get_driver_for_model("my_provider/my-model")
|
|
17
|
-
|
|
18
|
-
For entry point discovery, add to your package's pyproject.toml:
|
|
19
|
-
[project.entry-points."prompture.drivers"]
|
|
20
|
-
my_provider = "my_package.drivers:my_driver_factory"
|
|
21
|
-
|
|
22
|
-
[project.entry-points."prompture.async_drivers"]
|
|
23
|
-
my_provider = "my_package.drivers:my_async_driver_factory"
|
|
24
|
-
"""
|
|
25
|
-
|
|
26
|
-
from __future__ import annotations
|
|
27
|
-
|
|
28
|
-
import logging
|
|
29
|
-
import sys
|
|
30
|
-
from typing import Callable
|
|
31
|
-
|
|
32
|
-
logger = logging.getLogger("prompture.drivers.registry")
|
|
33
|
-
|
|
34
|
-
# Type alias for driver factory functions
|
|
35
|
-
# A factory takes an optional model name and returns a driver instance
|
|
36
|
-
DriverFactory = Callable[[str | None], object]
|
|
37
|
-
|
|
38
|
-
# Internal registries - populated by built-in drivers and plugins
|
|
39
|
-
_SYNC_REGISTRY: dict[str, DriverFactory] = {}
|
|
40
|
-
_ASYNC_REGISTRY: dict[str, DriverFactory] = {}
|
|
41
|
-
|
|
42
|
-
# Track whether entry points have been loaded
|
|
43
|
-
_entry_points_loaded = False
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
def register_driver(name: str, factory: DriverFactory, *, overwrite: bool = False) -> None:
|
|
47
|
-
"""Register a custom driver factory for a provider name.
|
|
48
|
-
|
|
49
|
-
Args:
|
|
50
|
-
name: Provider name (e.g., "my_provider"). Will be lowercased.
|
|
51
|
-
factory: A callable that takes an optional model name and returns
|
|
52
|
-
a driver instance. The driver must implement the
|
|
53
|
-
``Driver`` interface (specifically ``generate()``).
|
|
54
|
-
overwrite: If True, allow overwriting an existing registration.
|
|
55
|
-
Defaults to False.
|
|
56
|
-
|
|
57
|
-
Raises:
|
|
58
|
-
ValueError: If a driver with this name is already registered
|
|
59
|
-
and overwrite=False.
|
|
60
|
-
|
|
61
|
-
Example:
|
|
62
|
-
>>> def my_factory(model=None):
|
|
63
|
-
... return MyDriver(model=model or "default-model")
|
|
64
|
-
>>> register_driver("my_provider", my_factory)
|
|
65
|
-
>>> driver = get_driver_for_model("my_provider/custom-model")
|
|
66
|
-
"""
|
|
67
|
-
name = name.lower()
|
|
68
|
-
if name in _SYNC_REGISTRY and not overwrite:
|
|
69
|
-
raise ValueError(f"Driver '{name}' is already registered. Use overwrite=True to replace it.")
|
|
70
|
-
_SYNC_REGISTRY[name] = factory
|
|
71
|
-
logger.debug("Registered sync driver: %s", name)
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
def register_async_driver(name: str, factory: DriverFactory, *, overwrite: bool = False) -> None:
|
|
75
|
-
"""Register a custom async driver factory for a provider name.
|
|
76
|
-
|
|
77
|
-
Args:
|
|
78
|
-
name: Provider name (e.g., "my_provider"). Will be lowercased.
|
|
79
|
-
factory: A callable that takes an optional model name and returns
|
|
80
|
-
an async driver instance. The driver must implement the
|
|
81
|
-
``AsyncDriver`` interface (specifically ``async generate()``).
|
|
82
|
-
overwrite: If True, allow overwriting an existing registration.
|
|
83
|
-
Defaults to False.
|
|
84
|
-
|
|
85
|
-
Raises:
|
|
86
|
-
ValueError: If an async driver with this name is already registered
|
|
87
|
-
and overwrite=False.
|
|
88
|
-
|
|
89
|
-
Example:
|
|
90
|
-
>>> def my_async_factory(model=None):
|
|
91
|
-
... return MyAsyncDriver(model=model or "default-model")
|
|
92
|
-
>>> register_async_driver("my_provider", my_async_factory)
|
|
93
|
-
>>> driver = get_async_driver_for_model("my_provider/custom-model")
|
|
94
|
-
"""
|
|
95
|
-
name = name.lower()
|
|
96
|
-
if name in _ASYNC_REGISTRY and not overwrite:
|
|
97
|
-
raise ValueError(f"Async driver '{name}' is already registered. Use overwrite=True to replace it.")
|
|
98
|
-
_ASYNC_REGISTRY[name] = factory
|
|
99
|
-
logger.debug("Registered async driver: %s", name)
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
def unregister_driver(name: str) -> bool:
|
|
103
|
-
"""Unregister a sync driver by name.
|
|
104
|
-
|
|
105
|
-
Args:
|
|
106
|
-
name: Provider name to unregister.
|
|
107
|
-
|
|
108
|
-
Returns:
|
|
109
|
-
True if the driver was unregistered, False if it wasn't registered.
|
|
110
|
-
"""
|
|
111
|
-
name = name.lower()
|
|
112
|
-
if name in _SYNC_REGISTRY:
|
|
113
|
-
del _SYNC_REGISTRY[name]
|
|
114
|
-
logger.debug("Unregistered sync driver: %s", name)
|
|
115
|
-
return True
|
|
116
|
-
return False
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
def unregister_async_driver(name: str) -> bool:
|
|
120
|
-
"""Unregister an async driver by name.
|
|
121
|
-
|
|
122
|
-
Args:
|
|
123
|
-
name: Provider name to unregister.
|
|
124
|
-
|
|
125
|
-
Returns:
|
|
126
|
-
True if the driver was unregistered, False if it wasn't registered.
|
|
127
|
-
"""
|
|
128
|
-
name = name.lower()
|
|
129
|
-
if name in _ASYNC_REGISTRY:
|
|
130
|
-
del _ASYNC_REGISTRY[name]
|
|
131
|
-
logger.debug("Unregistered async driver: %s", name)
|
|
132
|
-
return True
|
|
133
|
-
return False
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
def list_registered_drivers() -> list[str]:
|
|
137
|
-
"""Return a sorted list of registered sync driver names."""
|
|
138
|
-
_ensure_entry_points_loaded()
|
|
139
|
-
return sorted(_SYNC_REGISTRY.keys())
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
def list_registered_async_drivers() -> list[str]:
|
|
143
|
-
"""Return a sorted list of registered async driver names."""
|
|
144
|
-
_ensure_entry_points_loaded()
|
|
145
|
-
return sorted(_ASYNC_REGISTRY.keys())
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
def is_driver_registered(name: str) -> bool:
|
|
149
|
-
"""Check if a sync driver is registered.
|
|
150
|
-
|
|
151
|
-
Args:
|
|
152
|
-
name: Provider name to check.
|
|
153
|
-
|
|
154
|
-
Returns:
|
|
155
|
-
True if the driver is registered.
|
|
156
|
-
"""
|
|
157
|
-
_ensure_entry_points_loaded()
|
|
158
|
-
return name.lower() in _SYNC_REGISTRY
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
def is_async_driver_registered(name: str) -> bool:
|
|
162
|
-
"""Check if an async driver is registered.
|
|
163
|
-
|
|
164
|
-
Args:
|
|
165
|
-
name: Provider name to check.
|
|
166
|
-
|
|
167
|
-
Returns:
|
|
168
|
-
True if the async driver is registered.
|
|
169
|
-
"""
|
|
170
|
-
_ensure_entry_points_loaded()
|
|
171
|
-
return name.lower() in _ASYNC_REGISTRY
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
def get_driver_factory(name: str) -> DriverFactory:
|
|
175
|
-
"""Get a registered sync driver factory by name.
|
|
176
|
-
|
|
177
|
-
Args:
|
|
178
|
-
name: Provider name.
|
|
179
|
-
|
|
180
|
-
Returns:
|
|
181
|
-
The factory function.
|
|
182
|
-
|
|
183
|
-
Raises:
|
|
184
|
-
ValueError: If the driver is not registered.
|
|
185
|
-
"""
|
|
186
|
-
_ensure_entry_points_loaded()
|
|
187
|
-
name = name.lower()
|
|
188
|
-
if name not in _SYNC_REGISTRY:
|
|
189
|
-
raise ValueError(f"Unsupported provider '{name}'")
|
|
190
|
-
return _SYNC_REGISTRY[name]
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
def get_async_driver_factory(name: str) -> DriverFactory:
|
|
194
|
-
"""Get a registered async driver factory by name.
|
|
195
|
-
|
|
196
|
-
Args:
|
|
197
|
-
name: Provider name.
|
|
198
|
-
|
|
199
|
-
Returns:
|
|
200
|
-
The factory function.
|
|
201
|
-
|
|
202
|
-
Raises:
|
|
203
|
-
ValueError: If the async driver is not registered.
|
|
204
|
-
"""
|
|
205
|
-
_ensure_entry_points_loaded()
|
|
206
|
-
name = name.lower()
|
|
207
|
-
if name not in _ASYNC_REGISTRY:
|
|
208
|
-
raise ValueError(f"Unsupported provider '{name}'")
|
|
209
|
-
return _ASYNC_REGISTRY[name]
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
def load_entry_point_drivers() -> tuple[int, int]:
|
|
213
|
-
"""Load drivers from installed packages via entry points.
|
|
214
|
-
|
|
215
|
-
This function scans for packages that define entry points in the
|
|
216
|
-
``prompture.drivers`` and ``prompture.async_drivers`` groups.
|
|
217
|
-
|
|
218
|
-
Returns:
|
|
219
|
-
A tuple of (sync_count, async_count) indicating how many drivers
|
|
220
|
-
were loaded from entry points.
|
|
221
|
-
|
|
222
|
-
Example pyproject.toml for a plugin package:
|
|
223
|
-
[project.entry-points."prompture.drivers"]
|
|
224
|
-
my_provider = "my_package.drivers:create_my_driver"
|
|
225
|
-
|
|
226
|
-
[project.entry-points."prompture.async_drivers"]
|
|
227
|
-
my_provider = "my_package.drivers:create_my_async_driver"
|
|
228
|
-
"""
|
|
229
|
-
global _entry_points_loaded
|
|
230
|
-
|
|
231
|
-
sync_count = 0
|
|
232
|
-
async_count = 0
|
|
233
|
-
|
|
234
|
-
# Python 3.9+ has importlib.metadata in stdlib
|
|
235
|
-
# Python 3.8 needs importlib_metadata backport
|
|
236
|
-
if sys.version_info >= (3, 10):
|
|
237
|
-
from importlib.metadata import entry_points
|
|
238
|
-
|
|
239
|
-
sync_eps = entry_points(group="prompture.drivers")
|
|
240
|
-
async_eps = entry_points(group="prompture.async_drivers")
|
|
241
|
-
else:
|
|
242
|
-
from importlib.metadata import entry_points
|
|
243
|
-
|
|
244
|
-
all_eps = entry_points()
|
|
245
|
-
sync_eps = all_eps.get("prompture.drivers", [])
|
|
246
|
-
async_eps = all_eps.get("prompture.async_drivers", [])
|
|
247
|
-
|
|
248
|
-
# Load sync drivers
|
|
249
|
-
for ep in sync_eps:
|
|
250
|
-
try:
|
|
251
|
-
# Skip if already registered (built-in drivers take precedence)
|
|
252
|
-
if ep.name.lower() in _SYNC_REGISTRY:
|
|
253
|
-
logger.debug("Skipping entry point driver '%s' (already registered)", ep.name)
|
|
254
|
-
continue
|
|
255
|
-
|
|
256
|
-
factory = ep.load()
|
|
257
|
-
_SYNC_REGISTRY[ep.name.lower()] = factory
|
|
258
|
-
sync_count += 1
|
|
259
|
-
logger.info("Loaded sync driver from entry point: %s", ep.name)
|
|
260
|
-
except Exception:
|
|
261
|
-
logger.exception("Failed to load sync driver entry point: %s", ep.name)
|
|
262
|
-
|
|
263
|
-
# Load async drivers
|
|
264
|
-
for ep in async_eps:
|
|
265
|
-
try:
|
|
266
|
-
# Skip if already registered (built-in drivers take precedence)
|
|
267
|
-
if ep.name.lower() in _ASYNC_REGISTRY:
|
|
268
|
-
logger.debug("Skipping entry point async driver '%s' (already registered)", ep.name)
|
|
269
|
-
continue
|
|
270
|
-
|
|
271
|
-
factory = ep.load()
|
|
272
|
-
_ASYNC_REGISTRY[ep.name.lower()] = factory
|
|
273
|
-
async_count += 1
|
|
274
|
-
logger.info("Loaded async driver from entry point: %s", ep.name)
|
|
275
|
-
except Exception:
|
|
276
|
-
logger.exception("Failed to load async driver entry point: %s", ep.name)
|
|
277
|
-
|
|
278
|
-
_entry_points_loaded = True
|
|
279
|
-
return (sync_count, async_count)
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
def _ensure_entry_points_loaded() -> None:
|
|
283
|
-
"""Ensure entry points have been loaded (lazy initialization)."""
|
|
284
|
-
global _entry_points_loaded
|
|
285
|
-
if not _entry_points_loaded:
|
|
286
|
-
load_entry_point_drivers()
|
|
287
|
-
|
|
288
|
-
|
|
289
|
-
def _get_sync_registry() -> dict[str, DriverFactory]:
|
|
290
|
-
"""Get the internal sync registry dict (for internal use by drivers/__init__.py)."""
|
|
291
|
-
_ensure_entry_points_loaded()
|
|
292
|
-
return _SYNC_REGISTRY
|
|
293
|
-
|
|
294
|
-
|
|
295
|
-
def _get_async_registry() -> dict[str, DriverFactory]:
|
|
296
|
-
"""Get the internal async registry dict (for internal use by drivers/async_registry.py)."""
|
|
297
|
-
_ensure_entry_points_loaded()
|
|
298
|
-
return _ASYNC_REGISTRY
|
|
299
|
-
|
|
300
|
-
|
|
301
|
-
def _reset_registries() -> None:
|
|
302
|
-
"""Reset registries to empty state (for testing only)."""
|
|
303
|
-
global _entry_points_loaded
|
|
304
|
-
_SYNC_REGISTRY.clear()
|
|
305
|
-
_ASYNC_REGISTRY.clear()
|
|
306
|
-
_entry_points_loaded = False
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|