prompture 0.0.49__py3-none-any.whl → 0.0.50__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- prompture/__init__.py +9 -0
- prompture/_version.py +2 -2
- prompture/discovery.py +10 -3
- prompture/drivers/__init__.py +15 -1
- prompture/drivers/async_azure_driver.py +256 -39
- prompture/drivers/async_registry.py +4 -1
- prompture/drivers/azure_config.py +146 -0
- prompture/drivers/azure_driver.py +293 -42
- prompture/settings.py +11 -1
- {prompture-0.0.49.dist-info → prompture-0.0.50.dist-info}/METADATA +1 -1
- {prompture-0.0.49.dist-info → prompture-0.0.50.dist-info}/RECORD +15 -14
- {prompture-0.0.49.dist-info → prompture-0.0.50.dist-info}/WHEEL +0 -0
- {prompture-0.0.49.dist-info → prompture-0.0.50.dist-info}/entry_points.txt +0 -0
- {prompture-0.0.49.dist-info → prompture-0.0.50.dist-info}/licenses/LICENSE +0 -0
- {prompture-0.0.49.dist-info → prompture-0.0.50.dist-info}/top_level.txt +0 -0
prompture/__init__.py
CHANGED
|
@@ -60,6 +60,8 @@ from .drivers import (
|
|
|
60
60
|
OllamaDriver,
|
|
61
61
|
OpenAIDriver,
|
|
62
62
|
OpenRouterDriver,
|
|
63
|
+
# Azure config API
|
|
64
|
+
clear_azure_configs,
|
|
63
65
|
get_driver,
|
|
64
66
|
get_driver_for_model,
|
|
65
67
|
# Plugin registration API
|
|
@@ -69,8 +71,11 @@ from .drivers import (
|
|
|
69
71
|
list_registered_drivers,
|
|
70
72
|
load_entry_point_drivers,
|
|
71
73
|
register_async_driver,
|
|
74
|
+
register_azure_config,
|
|
72
75
|
register_driver,
|
|
76
|
+
set_azure_config_resolver,
|
|
73
77
|
unregister_async_driver,
|
|
78
|
+
unregister_azure_config,
|
|
74
79
|
unregister_driver,
|
|
75
80
|
)
|
|
76
81
|
from .field_definitions import (
|
|
@@ -247,6 +252,7 @@ __all__ = [
|
|
|
247
252
|
"clean_json_text",
|
|
248
253
|
"clean_json_text_with_ai",
|
|
249
254
|
"clean_toon_text",
|
|
255
|
+
"clear_azure_configs",
|
|
250
256
|
"clear_persona_registry",
|
|
251
257
|
"clear_registry",
|
|
252
258
|
"configure_cache",
|
|
@@ -292,6 +298,7 @@ __all__ = [
|
|
|
292
298
|
"normalize_enum_value",
|
|
293
299
|
"refresh_rates_cache",
|
|
294
300
|
"register_async_driver",
|
|
301
|
+
"register_azure_config",
|
|
295
302
|
"register_driver",
|
|
296
303
|
"register_field",
|
|
297
304
|
"register_persona",
|
|
@@ -301,9 +308,11 @@ __all__ = [
|
|
|
301
308
|
"reset_registry",
|
|
302
309
|
"reset_trait_registry",
|
|
303
310
|
"run_suite_from_spec",
|
|
311
|
+
"set_azure_config_resolver",
|
|
304
312
|
"stepwise_extract_with_model",
|
|
305
313
|
"tool_from_function",
|
|
306
314
|
"unregister_async_driver",
|
|
315
|
+
"unregister_azure_config",
|
|
307
316
|
"unregister_driver",
|
|
308
317
|
"validate_against_schema",
|
|
309
318
|
"validate_enum_value",
|
prompture/_version.py
CHANGED
|
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
|
|
|
28
28
|
commit_id: COMMIT_ID
|
|
29
29
|
__commit_id__: COMMIT_ID
|
|
30
30
|
|
|
31
|
-
__version__ = version = '0.0.
|
|
32
|
-
__version_tuple__ = version_tuple = (0, 0,
|
|
31
|
+
__version__ = version = '0.0.50'
|
|
32
|
+
__version_tuple__ = version_tuple = (0, 0, 50)
|
|
33
33
|
|
|
34
34
|
__commit_id__ = commit_id = None
|
prompture/discovery.py
CHANGED
|
@@ -89,10 +89,17 @@ def get_available_models(
|
|
|
89
89
|
if settings.openai_api_key or os.getenv("OPENAI_API_KEY"):
|
|
90
90
|
is_configured = True
|
|
91
91
|
elif provider == "azure":
|
|
92
|
+
from .drivers.azure_config import has_azure_config_resolver, has_registered_configs
|
|
93
|
+
|
|
92
94
|
if (
|
|
93
|
-
(
|
|
94
|
-
|
|
95
|
-
|
|
95
|
+
(
|
|
96
|
+
(settings.azure_api_key or os.getenv("AZURE_API_KEY"))
|
|
97
|
+
and (settings.azure_api_endpoint or os.getenv("AZURE_API_ENDPOINT"))
|
|
98
|
+
)
|
|
99
|
+
or (settings.azure_claude_api_key or os.getenv("AZURE_CLAUDE_API_KEY"))
|
|
100
|
+
or (settings.azure_mistral_api_key or os.getenv("AZURE_MISTRAL_API_KEY"))
|
|
101
|
+
or has_registered_configs()
|
|
102
|
+
or has_azure_config_resolver()
|
|
96
103
|
):
|
|
97
104
|
is_configured = True
|
|
98
105
|
elif provider == "claude":
|
prompture/drivers/__init__.py
CHANGED
|
@@ -44,6 +44,12 @@ from .async_openai_driver import AsyncOpenAIDriver
|
|
|
44
44
|
from .async_openrouter_driver import AsyncOpenRouterDriver
|
|
45
45
|
from .async_registry import ASYNC_DRIVER_REGISTRY, get_async_driver, get_async_driver_for_model
|
|
46
46
|
from .async_zai_driver import AsyncZaiDriver
|
|
47
|
+
from .azure_config import (
|
|
48
|
+
clear_azure_configs,
|
|
49
|
+
register_azure_config,
|
|
50
|
+
set_azure_config_resolver,
|
|
51
|
+
unregister_azure_config,
|
|
52
|
+
)
|
|
47
53
|
from .azure_driver import AzureDriver
|
|
48
54
|
from .claude_driver import ClaudeDriver
|
|
49
55
|
from .google_driver import GoogleDriver
|
|
@@ -100,7 +106,10 @@ register_driver(
|
|
|
100
106
|
register_driver(
|
|
101
107
|
"azure",
|
|
102
108
|
lambda model=None: AzureDriver(
|
|
103
|
-
api_key=settings.azure_api_key,
|
|
109
|
+
api_key=settings.azure_api_key,
|
|
110
|
+
endpoint=settings.azure_api_endpoint,
|
|
111
|
+
deployment_id=settings.azure_deployment_id,
|
|
112
|
+
model=model or "gpt-4o-mini",
|
|
104
113
|
),
|
|
105
114
|
overwrite=True,
|
|
106
115
|
)
|
|
@@ -249,6 +258,8 @@ __all__ = [
|
|
|
249
258
|
"OpenAIDriver",
|
|
250
259
|
"OpenRouterDriver",
|
|
251
260
|
"ZaiDriver",
|
|
261
|
+
# Azure config API
|
|
262
|
+
"clear_azure_configs",
|
|
252
263
|
"get_async_driver",
|
|
253
264
|
"get_async_driver_for_model",
|
|
254
265
|
# Factory functions
|
|
@@ -260,8 +271,11 @@ __all__ = [
|
|
|
260
271
|
"list_registered_drivers",
|
|
261
272
|
"load_entry_point_drivers",
|
|
262
273
|
"register_async_driver",
|
|
274
|
+
"register_azure_config",
|
|
263
275
|
# Registry functions (public API)
|
|
264
276
|
"register_driver",
|
|
277
|
+
"set_azure_config_resolver",
|
|
265
278
|
"unregister_async_driver",
|
|
279
|
+
"unregister_azure_config",
|
|
266
280
|
"unregister_driver",
|
|
267
281
|
]
|
|
@@ -1,4 +1,7 @@
|
|
|
1
|
-
"""Async Azure
|
|
1
|
+
"""Async Azure driver with multi-endpoint and multi-backend support.
|
|
2
|
+
|
|
3
|
+
Requires the ``openai`` package (>=1.0.0). Claude backend also requires ``anthropic``.
|
|
4
|
+
"""
|
|
2
5
|
|
|
3
6
|
from __future__ import annotations
|
|
4
7
|
|
|
@@ -11,8 +14,14 @@ try:
|
|
|
11
14
|
except Exception:
|
|
12
15
|
AsyncAzureOpenAI = None
|
|
13
16
|
|
|
17
|
+
try:
|
|
18
|
+
import anthropic
|
|
19
|
+
except Exception:
|
|
20
|
+
anthropic = None
|
|
21
|
+
|
|
14
22
|
from ..async_driver import AsyncDriver
|
|
15
23
|
from ..cost_mixin import CostMixin, prepare_strict_schema
|
|
24
|
+
from .azure_config import classify_backend, resolve_config
|
|
16
25
|
from .azure_driver import AzureDriver
|
|
17
26
|
|
|
18
27
|
|
|
@@ -31,27 +40,15 @@ class AsyncAzureDriver(CostMixin, AsyncDriver):
|
|
|
31
40
|
deployment_id: str | None = None,
|
|
32
41
|
model: str = "gpt-4o-mini",
|
|
33
42
|
):
|
|
34
|
-
self.api_key = api_key or os.getenv("AZURE_API_KEY")
|
|
35
|
-
self.endpoint = endpoint or os.getenv("AZURE_API_ENDPOINT")
|
|
36
|
-
self.deployment_id = deployment_id or os.getenv("AZURE_DEPLOYMENT_ID")
|
|
37
|
-
self.api_version = os.getenv("AZURE_API_VERSION", "2023-07-01-preview")
|
|
38
43
|
self.model = model
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
if AsyncAzureOpenAI:
|
|
48
|
-
self.client = AsyncAzureOpenAI(
|
|
49
|
-
api_key=self.api_key,
|
|
50
|
-
api_version=self.api_version,
|
|
51
|
-
azure_endpoint=self.endpoint,
|
|
52
|
-
)
|
|
53
|
-
else:
|
|
54
|
-
self.client = None
|
|
44
|
+
self._default_config = {
|
|
45
|
+
"api_key": api_key or os.getenv("AZURE_API_KEY"),
|
|
46
|
+
"endpoint": endpoint or os.getenv("AZURE_API_ENDPOINT"),
|
|
47
|
+
"deployment_id": deployment_id or os.getenv("AZURE_DEPLOYMENT_ID"),
|
|
48
|
+
"api_version": os.getenv("AZURE_API_VERSION", "2024-02-15-preview"),
|
|
49
|
+
}
|
|
50
|
+
self._openai_clients: dict[tuple[str, str], AsyncAzureOpenAI] = {}
|
|
51
|
+
self._anthropic_clients: dict[tuple[str, str], Any] = {}
|
|
55
52
|
|
|
56
53
|
supports_messages = True
|
|
57
54
|
|
|
@@ -60,6 +57,36 @@ class AsyncAzureDriver(CostMixin, AsyncDriver):
|
|
|
60
57
|
|
|
61
58
|
return _prepare_openai_vision_messages(messages)
|
|
62
59
|
|
|
60
|
+
def _resolve_model_config(self, model: str, options: dict[str, Any]) -> dict[str, Any]:
|
|
61
|
+
"""Resolve Azure config for this model using the priority chain."""
|
|
62
|
+
override = options.pop("azure_config", None)
|
|
63
|
+
return resolve_config(model, override=override, default_config=self._default_config)
|
|
64
|
+
|
|
65
|
+
def _get_openai_client(self, config: dict[str, Any]) -> AsyncAzureOpenAI:
|
|
66
|
+
"""Get or create an AsyncAzureOpenAI client for the given config."""
|
|
67
|
+
if AsyncAzureOpenAI is None:
|
|
68
|
+
raise RuntimeError("openai package (>=1.0.0) with AsyncAzureOpenAI not installed")
|
|
69
|
+
cache_key = (config["endpoint"], config["api_key"])
|
|
70
|
+
if cache_key not in self._openai_clients:
|
|
71
|
+
self._openai_clients[cache_key] = AsyncAzureOpenAI(
|
|
72
|
+
api_key=config["api_key"],
|
|
73
|
+
api_version=config.get("api_version", "2024-02-15-preview"),
|
|
74
|
+
azure_endpoint=config["endpoint"],
|
|
75
|
+
)
|
|
76
|
+
return self._openai_clients[cache_key]
|
|
77
|
+
|
|
78
|
+
def _get_anthropic_client(self, config: dict[str, Any]) -> Any:
|
|
79
|
+
"""Get or create an AsyncAnthropic client for the given Azure config."""
|
|
80
|
+
if anthropic is None:
|
|
81
|
+
raise RuntimeError("anthropic package not installed (required for Claude on Azure)")
|
|
82
|
+
cache_key = (config["endpoint"], config["api_key"])
|
|
83
|
+
if cache_key not in self._anthropic_clients:
|
|
84
|
+
self._anthropic_clients[cache_key] = anthropic.AsyncAnthropic(
|
|
85
|
+
base_url=config["endpoint"],
|
|
86
|
+
api_key=config["api_key"],
|
|
87
|
+
)
|
|
88
|
+
return self._anthropic_clients[cache_key]
|
|
89
|
+
|
|
63
90
|
async def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
|
|
64
91
|
messages = [{"role": "user", "content": prompt}]
|
|
65
92
|
return await self._do_generate(messages, options)
|
|
@@ -68,10 +95,26 @@ class AsyncAzureDriver(CostMixin, AsyncDriver):
|
|
|
68
95
|
return await self._do_generate(self._prepare_messages(messages), options)
|
|
69
96
|
|
|
70
97
|
async def _do_generate(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
|
|
71
|
-
if self.client is None:
|
|
72
|
-
raise RuntimeError("openai package (>=1.0.0) with AsyncAzureOpenAI not installed")
|
|
73
|
-
|
|
74
98
|
model = options.get("model", self.model)
|
|
99
|
+
config = self._resolve_model_config(model, options)
|
|
100
|
+
backend = classify_backend(model)
|
|
101
|
+
|
|
102
|
+
if backend == "claude":
|
|
103
|
+
return await self._generate_claude(messages, options, config, model)
|
|
104
|
+
else:
|
|
105
|
+
return await self._generate_openai(messages, options, config, model)
|
|
106
|
+
|
|
107
|
+
async def _generate_openai(
|
|
108
|
+
self,
|
|
109
|
+
messages: list[dict[str, Any]],
|
|
110
|
+
options: dict[str, Any],
|
|
111
|
+
config: dict[str, Any],
|
|
112
|
+
model: str,
|
|
113
|
+
) -> dict[str, Any]:
|
|
114
|
+
"""Generate via Azure OpenAI (or Mistral OpenAI-compat) endpoint."""
|
|
115
|
+
client = self._get_openai_client(config)
|
|
116
|
+
deployment_id = config.get("deployment_id") or model
|
|
117
|
+
|
|
75
118
|
model_config = self._get_model_config("azure", model)
|
|
76
119
|
tokens_param = model_config["tokens_param"]
|
|
77
120
|
supports_temperature = model_config["supports_temperature"]
|
|
@@ -79,7 +122,7 @@ class AsyncAzureDriver(CostMixin, AsyncDriver):
|
|
|
79
122
|
opts = {"temperature": 1.0, "max_tokens": 512, **options}
|
|
80
123
|
|
|
81
124
|
kwargs = {
|
|
82
|
-
"model":
|
|
125
|
+
"model": deployment_id,
|
|
83
126
|
"messages": messages,
|
|
84
127
|
}
|
|
85
128
|
kwargs[tokens_param] = opts.get("max_tokens", 512)
|
|
@@ -87,7 +130,6 @@ class AsyncAzureDriver(CostMixin, AsyncDriver):
|
|
|
87
130
|
if supports_temperature and "temperature" in opts:
|
|
88
131
|
kwargs["temperature"] = opts["temperature"]
|
|
89
132
|
|
|
90
|
-
# Native JSON mode support
|
|
91
133
|
if options.get("json_mode"):
|
|
92
134
|
json_schema = options.get("json_schema")
|
|
93
135
|
if json_schema:
|
|
@@ -103,7 +145,7 @@ class AsyncAzureDriver(CostMixin, AsyncDriver):
|
|
|
103
145
|
else:
|
|
104
146
|
kwargs["response_format"] = {"type": "json_object"}
|
|
105
147
|
|
|
106
|
-
resp = await
|
|
148
|
+
resp = await client.chat.completions.create(**kwargs)
|
|
107
149
|
|
|
108
150
|
usage = getattr(resp, "usage", None)
|
|
109
151
|
prompt_tokens = getattr(usage, "prompt_tokens", 0)
|
|
@@ -119,12 +161,84 @@ class AsyncAzureDriver(CostMixin, AsyncDriver):
|
|
|
119
161
|
"cost": round(total_cost, 6),
|
|
120
162
|
"raw_response": resp.model_dump(),
|
|
121
163
|
"model_name": model,
|
|
122
|
-
"deployment_id":
|
|
164
|
+
"deployment_id": deployment_id,
|
|
123
165
|
}
|
|
124
166
|
|
|
125
167
|
text = resp.choices[0].message.content
|
|
126
168
|
return {"text": text, "meta": meta}
|
|
127
169
|
|
|
170
|
+
async def _generate_claude(
|
|
171
|
+
self,
|
|
172
|
+
messages: list[dict[str, Any]],
|
|
173
|
+
options: dict[str, Any],
|
|
174
|
+
config: dict[str, Any],
|
|
175
|
+
model: str,
|
|
176
|
+
) -> dict[str, Any]:
|
|
177
|
+
"""Generate via Anthropic SDK with Azure endpoint."""
|
|
178
|
+
client = self._get_anthropic_client(config)
|
|
179
|
+
|
|
180
|
+
opts = {**{"temperature": 0.0, "max_tokens": 512}, **options}
|
|
181
|
+
|
|
182
|
+
system_content = None
|
|
183
|
+
api_messages = []
|
|
184
|
+
for msg in messages:
|
|
185
|
+
if msg.get("role") == "system":
|
|
186
|
+
system_content = msg.get("content", "")
|
|
187
|
+
else:
|
|
188
|
+
api_messages.append(msg)
|
|
189
|
+
|
|
190
|
+
common_kwargs: dict[str, Any] = {
|
|
191
|
+
"model": model,
|
|
192
|
+
"messages": api_messages,
|
|
193
|
+
"temperature": opts["temperature"],
|
|
194
|
+
"max_tokens": opts["max_tokens"],
|
|
195
|
+
}
|
|
196
|
+
if system_content:
|
|
197
|
+
common_kwargs["system"] = system_content
|
|
198
|
+
|
|
199
|
+
if options.get("json_mode"):
|
|
200
|
+
json_schema = options.get("json_schema")
|
|
201
|
+
if json_schema:
|
|
202
|
+
tool_def = {
|
|
203
|
+
"name": "extract_json",
|
|
204
|
+
"description": "Extract structured data matching the schema",
|
|
205
|
+
"input_schema": json_schema,
|
|
206
|
+
}
|
|
207
|
+
resp = await client.messages.create(
|
|
208
|
+
**common_kwargs,
|
|
209
|
+
tools=[tool_def],
|
|
210
|
+
tool_choice={"type": "tool", "name": "extract_json"},
|
|
211
|
+
)
|
|
212
|
+
text = ""
|
|
213
|
+
for block in resp.content:
|
|
214
|
+
if block.type == "tool_use":
|
|
215
|
+
text = json.dumps(block.input)
|
|
216
|
+
break
|
|
217
|
+
else:
|
|
218
|
+
resp = await client.messages.create(**common_kwargs)
|
|
219
|
+
text = resp.content[0].text
|
|
220
|
+
else:
|
|
221
|
+
resp = await client.messages.create(**common_kwargs)
|
|
222
|
+
text = resp.content[0].text
|
|
223
|
+
|
|
224
|
+
prompt_tokens = resp.usage.input_tokens
|
|
225
|
+
completion_tokens = resp.usage.output_tokens
|
|
226
|
+
total_tokens = prompt_tokens + completion_tokens
|
|
227
|
+
|
|
228
|
+
total_cost = self._calculate_cost("azure", model, prompt_tokens, completion_tokens)
|
|
229
|
+
|
|
230
|
+
meta = {
|
|
231
|
+
"prompt_tokens": prompt_tokens,
|
|
232
|
+
"completion_tokens": completion_tokens,
|
|
233
|
+
"total_tokens": total_tokens,
|
|
234
|
+
"cost": round(total_cost, 6),
|
|
235
|
+
"raw_response": dict(resp),
|
|
236
|
+
"model_name": model,
|
|
237
|
+
}
|
|
238
|
+
|
|
239
|
+
text_result = text or ""
|
|
240
|
+
return {"text": text_result, "meta": meta}
|
|
241
|
+
|
|
128
242
|
# ------------------------------------------------------------------
|
|
129
243
|
# Tool use
|
|
130
244
|
# ------------------------------------------------------------------
|
|
@@ -136,10 +250,27 @@ class AsyncAzureDriver(CostMixin, AsyncDriver):
|
|
|
136
250
|
options: dict[str, Any],
|
|
137
251
|
) -> dict[str, Any]:
|
|
138
252
|
"""Generate a response that may include tool calls."""
|
|
139
|
-
if self.client is None:
|
|
140
|
-
raise RuntimeError("openai package (>=1.0.0) with AsyncAzureOpenAI not installed")
|
|
141
|
-
|
|
142
253
|
model = options.get("model", self.model)
|
|
254
|
+
config = self._resolve_model_config(model, options)
|
|
255
|
+
backend = classify_backend(model)
|
|
256
|
+
|
|
257
|
+
if backend == "claude":
|
|
258
|
+
return await self._generate_claude_with_tools(messages, tools, options, config, model)
|
|
259
|
+
else:
|
|
260
|
+
return await self._generate_openai_with_tools(messages, tools, options, config, model)
|
|
261
|
+
|
|
262
|
+
async def _generate_openai_with_tools(
|
|
263
|
+
self,
|
|
264
|
+
messages: list[dict[str, Any]],
|
|
265
|
+
tools: list[dict[str, Any]],
|
|
266
|
+
options: dict[str, Any],
|
|
267
|
+
config: dict[str, Any],
|
|
268
|
+
model: str,
|
|
269
|
+
) -> dict[str, Any]:
|
|
270
|
+
"""Tool calling via Azure OpenAI endpoint."""
|
|
271
|
+
client = self._get_openai_client(config)
|
|
272
|
+
deployment_id = config.get("deployment_id") or model
|
|
273
|
+
|
|
143
274
|
model_config = self._get_model_config("azure", model)
|
|
144
275
|
tokens_param = model_config["tokens_param"]
|
|
145
276
|
supports_temperature = model_config["supports_temperature"]
|
|
@@ -149,7 +280,7 @@ class AsyncAzureDriver(CostMixin, AsyncDriver):
|
|
|
149
280
|
opts = {"temperature": 1.0, "max_tokens": 512, **options}
|
|
150
281
|
|
|
151
282
|
kwargs: dict[str, Any] = {
|
|
152
|
-
"model":
|
|
283
|
+
"model": deployment_id,
|
|
153
284
|
"messages": messages,
|
|
154
285
|
"tools": tools,
|
|
155
286
|
}
|
|
@@ -158,7 +289,7 @@ class AsyncAzureDriver(CostMixin, AsyncDriver):
|
|
|
158
289
|
if supports_temperature and "temperature" in opts:
|
|
159
290
|
kwargs["temperature"] = opts["temperature"]
|
|
160
291
|
|
|
161
|
-
resp = await
|
|
292
|
+
resp = await client.chat.completions.create(**kwargs)
|
|
162
293
|
|
|
163
294
|
usage = getattr(resp, "usage", None)
|
|
164
295
|
prompt_tokens = getattr(usage, "prompt_tokens", 0)
|
|
@@ -173,7 +304,7 @@ class AsyncAzureDriver(CostMixin, AsyncDriver):
|
|
|
173
304
|
"cost": round(total_cost, 6),
|
|
174
305
|
"raw_response": resp.model_dump(),
|
|
175
306
|
"model_name": model,
|
|
176
|
-
"deployment_id":
|
|
307
|
+
"deployment_id": deployment_id,
|
|
177
308
|
}
|
|
178
309
|
|
|
179
310
|
choice = resp.choices[0]
|
|
@@ -187,11 +318,13 @@ class AsyncAzureDriver(CostMixin, AsyncDriver):
|
|
|
187
318
|
args = json.loads(tc.function.arguments)
|
|
188
319
|
except (json.JSONDecodeError, TypeError):
|
|
189
320
|
args = {}
|
|
190
|
-
tool_calls_out.append(
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
321
|
+
tool_calls_out.append(
|
|
322
|
+
{
|
|
323
|
+
"id": tc.id,
|
|
324
|
+
"name": tc.function.name,
|
|
325
|
+
"arguments": args,
|
|
326
|
+
}
|
|
327
|
+
)
|
|
195
328
|
|
|
196
329
|
return {
|
|
197
330
|
"text": text,
|
|
@@ -199,3 +332,87 @@ class AsyncAzureDriver(CostMixin, AsyncDriver):
|
|
|
199
332
|
"tool_calls": tool_calls_out,
|
|
200
333
|
"stop_reason": stop_reason,
|
|
201
334
|
}
|
|
335
|
+
|
|
336
|
+
async def _generate_claude_with_tools(
|
|
337
|
+
self,
|
|
338
|
+
messages: list[dict[str, Any]],
|
|
339
|
+
tools: list[dict[str, Any]],
|
|
340
|
+
options: dict[str, Any],
|
|
341
|
+
config: dict[str, Any],
|
|
342
|
+
model: str,
|
|
343
|
+
) -> dict[str, Any]:
|
|
344
|
+
"""Tool calling via Anthropic SDK with Azure endpoint."""
|
|
345
|
+
client = self._get_anthropic_client(config)
|
|
346
|
+
|
|
347
|
+
opts = {**{"temperature": 0.0, "max_tokens": 512}, **options}
|
|
348
|
+
|
|
349
|
+
system_content = None
|
|
350
|
+
api_messages: list[dict[str, Any]] = []
|
|
351
|
+
for msg in messages:
|
|
352
|
+
if msg.get("role") == "system":
|
|
353
|
+
system_content = msg.get("content", "")
|
|
354
|
+
else:
|
|
355
|
+
api_messages.append(msg)
|
|
356
|
+
|
|
357
|
+
anthropic_tools = []
|
|
358
|
+
for t in tools:
|
|
359
|
+
if "type" in t and t["type"] == "function":
|
|
360
|
+
fn = t["function"]
|
|
361
|
+
anthropic_tools.append(
|
|
362
|
+
{
|
|
363
|
+
"name": fn["name"],
|
|
364
|
+
"description": fn.get("description", ""),
|
|
365
|
+
"input_schema": fn.get("parameters", {"type": "object", "properties": {}}),
|
|
366
|
+
}
|
|
367
|
+
)
|
|
368
|
+
elif "input_schema" in t:
|
|
369
|
+
anthropic_tools.append(t)
|
|
370
|
+
else:
|
|
371
|
+
anthropic_tools.append(t)
|
|
372
|
+
|
|
373
|
+
kwargs: dict[str, Any] = {
|
|
374
|
+
"model": model,
|
|
375
|
+
"messages": api_messages,
|
|
376
|
+
"temperature": opts["temperature"],
|
|
377
|
+
"max_tokens": opts["max_tokens"],
|
|
378
|
+
"tools": anthropic_tools,
|
|
379
|
+
}
|
|
380
|
+
if system_content:
|
|
381
|
+
kwargs["system"] = system_content
|
|
382
|
+
|
|
383
|
+
resp = await client.messages.create(**kwargs)
|
|
384
|
+
|
|
385
|
+
prompt_tokens = resp.usage.input_tokens
|
|
386
|
+
completion_tokens = resp.usage.output_tokens
|
|
387
|
+
total_tokens = prompt_tokens + completion_tokens
|
|
388
|
+
total_cost = self._calculate_cost("azure", model, prompt_tokens, completion_tokens)
|
|
389
|
+
|
|
390
|
+
meta = {
|
|
391
|
+
"prompt_tokens": prompt_tokens,
|
|
392
|
+
"completion_tokens": completion_tokens,
|
|
393
|
+
"total_tokens": total_tokens,
|
|
394
|
+
"cost": round(total_cost, 6),
|
|
395
|
+
"raw_response": dict(resp),
|
|
396
|
+
"model_name": model,
|
|
397
|
+
}
|
|
398
|
+
|
|
399
|
+
text = ""
|
|
400
|
+
tool_calls_out: list[dict[str, Any]] = []
|
|
401
|
+
for block in resp.content:
|
|
402
|
+
if block.type == "text":
|
|
403
|
+
text += block.text
|
|
404
|
+
elif block.type == "tool_use":
|
|
405
|
+
tool_calls_out.append(
|
|
406
|
+
{
|
|
407
|
+
"id": block.id,
|
|
408
|
+
"name": block.name,
|
|
409
|
+
"arguments": block.input,
|
|
410
|
+
}
|
|
411
|
+
)
|
|
412
|
+
|
|
413
|
+
return {
|
|
414
|
+
"text": text,
|
|
415
|
+
"meta": meta,
|
|
416
|
+
"tool_calls": tool_calls_out,
|
|
417
|
+
"stop_reason": resp.stop_reason,
|
|
418
|
+
}
|
|
@@ -62,7 +62,10 @@ register_async_driver(
|
|
|
62
62
|
register_async_driver(
|
|
63
63
|
"azure",
|
|
64
64
|
lambda model=None: AsyncAzureDriver(
|
|
65
|
-
api_key=settings.azure_api_key,
|
|
65
|
+
api_key=settings.azure_api_key,
|
|
66
|
+
endpoint=settings.azure_api_endpoint,
|
|
67
|
+
deployment_id=settings.azure_deployment_id,
|
|
68
|
+
model=model or "gpt-4o-mini",
|
|
66
69
|
),
|
|
67
70
|
overwrite=True,
|
|
68
71
|
)
|
|
@@ -0,0 +1,146 @@
|
|
|
1
|
+
"""Azure per-model configuration resolution.
|
|
2
|
+
|
|
3
|
+
Supports multiple Azure endpoints, API keys, and deployment names for
|
|
4
|
+
different models, as well as routing to different API backends (OpenAI,
|
|
5
|
+
Claude, Mistral) based on the model prefix.
|
|
6
|
+
|
|
7
|
+
Usage::
|
|
8
|
+
|
|
9
|
+
from prompture.drivers.azure_config import (
|
|
10
|
+
register_azure_config,
|
|
11
|
+
set_azure_config_resolver,
|
|
12
|
+
resolve_config,
|
|
13
|
+
classify_backend,
|
|
14
|
+
)
|
|
15
|
+
|
|
16
|
+
# Register per-model configs
|
|
17
|
+
register_azure_config("gpt-4o", {
|
|
18
|
+
"endpoint": "https://my-eastus.openai.azure.com/",
|
|
19
|
+
"api_key": "key-eastus",
|
|
20
|
+
"deployment_id": "gpt-4o",
|
|
21
|
+
})
|
|
22
|
+
|
|
23
|
+
# Or use a resolver callback
|
|
24
|
+
set_azure_config_resolver(lambda model: my_db.get_config(model))
|
|
25
|
+
"""
|
|
26
|
+
|
|
27
|
+
from __future__ import annotations
|
|
28
|
+
|
|
29
|
+
import threading
|
|
30
|
+
from typing import Any, Callable
|
|
31
|
+
|
|
32
|
+
# Model prefix → backend type
|
|
33
|
+
AZURE_BACKEND_MAP: dict[str, str] = {
|
|
34
|
+
"gpt-": "openai",
|
|
35
|
+
"o1-": "openai",
|
|
36
|
+
"o3-": "openai",
|
|
37
|
+
"o4-": "openai",
|
|
38
|
+
"claude-": "claude",
|
|
39
|
+
"mistral-": "mistral",
|
|
40
|
+
"mixtral-": "mistral",
|
|
41
|
+
}
|
|
42
|
+
|
|
43
|
+
_lock = threading.Lock()
|
|
44
|
+
_config_registry: dict[str, dict[str, Any]] = {}
|
|
45
|
+
_config_resolver: Callable[[str], dict[str, Any]] | None = None
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def classify_backend(model: str) -> str:
|
|
49
|
+
"""Determine API backend for a model. Default: ``'openai'``."""
|
|
50
|
+
model_lower = model.lower()
|
|
51
|
+
for prefix, backend in AZURE_BACKEND_MAP.items():
|
|
52
|
+
if model_lower.startswith(prefix):
|
|
53
|
+
return backend
|
|
54
|
+
return "openai"
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def register_azure_config(name: str, config: dict[str, Any]) -> None:
|
|
58
|
+
"""Register a named Azure config (deployment name, region, etc.).
|
|
59
|
+
|
|
60
|
+
Args:
|
|
61
|
+
name: Model name key (e.g. ``"gpt-4o"``).
|
|
62
|
+
config: Dict with ``endpoint``, ``api_key``, and optionally
|
|
63
|
+
``deployment_id``, ``api_version``.
|
|
64
|
+
"""
|
|
65
|
+
with _lock:
|
|
66
|
+
_config_registry[name] = config
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
def unregister_azure_config(name: str) -> None:
|
|
70
|
+
"""Remove a previously registered Azure config."""
|
|
71
|
+
with _lock:
|
|
72
|
+
_config_registry.pop(name, None)
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
def clear_azure_configs() -> None:
|
|
76
|
+
"""Remove all registered Azure configs."""
|
|
77
|
+
with _lock:
|
|
78
|
+
_config_registry.clear()
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def set_azure_config_resolver(
|
|
82
|
+
resolver: Callable[[str], dict[str, Any]] | None,
|
|
83
|
+
) -> None:
|
|
84
|
+
"""Set a callback that resolves config per deployment/model name.
|
|
85
|
+
|
|
86
|
+
Pass ``None`` to clear the resolver.
|
|
87
|
+
"""
|
|
88
|
+
global _config_resolver
|
|
89
|
+
with _lock:
|
|
90
|
+
_config_resolver = resolver
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
def has_azure_config_resolver() -> bool:
|
|
94
|
+
"""Return ``True`` if a config resolver callback is registered."""
|
|
95
|
+
with _lock:
|
|
96
|
+
return _config_resolver is not None
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
def has_registered_configs() -> bool:
|
|
100
|
+
"""Return ``True`` if any named configs are registered."""
|
|
101
|
+
with _lock:
|
|
102
|
+
return len(_config_registry) > 0
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
def resolve_config(
|
|
106
|
+
model: str,
|
|
107
|
+
override: dict[str, Any] | None = None,
|
|
108
|
+
default_config: dict[str, Any] | None = None,
|
|
109
|
+
) -> dict[str, Any]:
|
|
110
|
+
"""Resolve Azure config for a model using priority chain.
|
|
111
|
+
|
|
112
|
+
Priority:
|
|
113
|
+
1. Per-call ``override`` (highest)
|
|
114
|
+
2. Resolver callback (if registered)
|
|
115
|
+
3. Registry lookup (by model name)
|
|
116
|
+
4. ``default_config`` (env vars fallback)
|
|
117
|
+
|
|
118
|
+
Raises:
|
|
119
|
+
ValueError: If no config could be resolved.
|
|
120
|
+
"""
|
|
121
|
+
# 1. Per-call override
|
|
122
|
+
if override:
|
|
123
|
+
return override
|
|
124
|
+
|
|
125
|
+
# 2. Resolver callback
|
|
126
|
+
with _lock:
|
|
127
|
+
resolver = _config_resolver
|
|
128
|
+
if resolver:
|
|
129
|
+
resolved = resolver(model)
|
|
130
|
+
if resolved:
|
|
131
|
+
return resolved
|
|
132
|
+
|
|
133
|
+
# 3. Registry lookup (by model name)
|
|
134
|
+
with _lock:
|
|
135
|
+
if model in _config_registry:
|
|
136
|
+
return _config_registry[model]
|
|
137
|
+
|
|
138
|
+
# 4. Default (env vars) — only use if it has at least an endpoint or api_key
|
|
139
|
+
if default_config and (default_config.get("endpoint") or default_config.get("api_key")):
|
|
140
|
+
return default_config
|
|
141
|
+
|
|
142
|
+
raise ValueError(
|
|
143
|
+
f"No Azure config found for '{model}'. "
|
|
144
|
+
"Set env vars, register a config with register_azure_config(), "
|
|
145
|
+
"or provide azure_config in options."
|
|
146
|
+
)
|
|
@@ -1,5 +1,12 @@
|
|
|
1
|
-
"""Driver for Azure OpenAI Service
|
|
2
|
-
|
|
1
|
+
"""Driver for Azure OpenAI Service with multi-endpoint and multi-backend support.
|
|
2
|
+
|
|
3
|
+
Supports:
|
|
4
|
+
- Multiple Azure endpoints with per-model config resolution
|
|
5
|
+
- OpenAI models (gpt-*, o1-*, o3-*, o4-*) via AzureOpenAI SDK
|
|
6
|
+
- Claude models (claude-*) via Anthropic SDK with Azure endpoint
|
|
7
|
+
- Mistral models (mistral-*, mixtral-*) via OpenAI-compatible protocol
|
|
8
|
+
|
|
9
|
+
Requires the ``openai`` package. Claude backend also requires ``anthropic``.
|
|
3
10
|
"""
|
|
4
11
|
|
|
5
12
|
import json
|
|
@@ -11,8 +18,14 @@ try:
|
|
|
11
18
|
except Exception:
|
|
12
19
|
AzureOpenAI = None
|
|
13
20
|
|
|
21
|
+
try:
|
|
22
|
+
import anthropic
|
|
23
|
+
except Exception:
|
|
24
|
+
anthropic = None
|
|
25
|
+
|
|
14
26
|
from ..cost_mixin import CostMixin, prepare_strict_schema
|
|
15
27
|
from ..driver import Driver
|
|
28
|
+
from .azure_config import classify_backend, resolve_config
|
|
16
29
|
|
|
17
30
|
|
|
18
31
|
class AzureDriver(CostMixin, Driver):
|
|
@@ -59,6 +72,32 @@ class AzureDriver(CostMixin, Driver):
|
|
|
59
72
|
"tokens_param": "max_tokens",
|
|
60
73
|
"supports_temperature": True,
|
|
61
74
|
},
|
|
75
|
+
# Claude models on Azure
|
|
76
|
+
"claude-sonnet-4-20250514": {
|
|
77
|
+
"prompt": 0.003,
|
|
78
|
+
"completion": 0.015,
|
|
79
|
+
"tokens_param": "max_tokens",
|
|
80
|
+
"supports_temperature": True,
|
|
81
|
+
},
|
|
82
|
+
"claude-3-7-sonnet-20250219": {
|
|
83
|
+
"prompt": 0.003,
|
|
84
|
+
"completion": 0.015,
|
|
85
|
+
"tokens_param": "max_tokens",
|
|
86
|
+
"supports_temperature": True,
|
|
87
|
+
},
|
|
88
|
+
"claude-3-5-haiku-20241022": {
|
|
89
|
+
"prompt": 0.0008,
|
|
90
|
+
"completion": 0.004,
|
|
91
|
+
"tokens_param": "max_tokens",
|
|
92
|
+
"supports_temperature": True,
|
|
93
|
+
},
|
|
94
|
+
# Mistral models on Azure
|
|
95
|
+
"mistral-large-latest": {
|
|
96
|
+
"prompt": 0.004,
|
|
97
|
+
"completion": 0.012,
|
|
98
|
+
"tokens_param": "max_tokens",
|
|
99
|
+
"supports_temperature": True,
|
|
100
|
+
},
|
|
62
101
|
}
|
|
63
102
|
|
|
64
103
|
def __init__(
|
|
@@ -68,28 +107,17 @@ class AzureDriver(CostMixin, Driver):
|
|
|
68
107
|
deployment_id: str | None = None,
|
|
69
108
|
model: str = "gpt-4o-mini",
|
|
70
109
|
):
|
|
71
|
-
self.api_key = api_key or os.getenv("AZURE_API_KEY")
|
|
72
|
-
self.endpoint = endpoint or os.getenv("AZURE_API_ENDPOINT")
|
|
73
|
-
self.deployment_id = deployment_id or os.getenv("AZURE_DEPLOYMENT_ID")
|
|
74
|
-
self.api_version = os.getenv("AZURE_API_VERSION", "2023-07-01-preview")
|
|
75
110
|
self.model = model
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
self.client = AzureOpenAI(
|
|
87
|
-
api_key=self.api_key,
|
|
88
|
-
api_version=self.api_version,
|
|
89
|
-
azure_endpoint=self.endpoint,
|
|
90
|
-
)
|
|
91
|
-
else:
|
|
92
|
-
self.client = None
|
|
111
|
+
# Store default config from env vars (may be partial/None)
|
|
112
|
+
self._default_config = {
|
|
113
|
+
"api_key": api_key or os.getenv("AZURE_API_KEY"),
|
|
114
|
+
"endpoint": endpoint or os.getenv("AZURE_API_ENDPOINT"),
|
|
115
|
+
"deployment_id": deployment_id or os.getenv("AZURE_DEPLOYMENT_ID"),
|
|
116
|
+
"api_version": os.getenv("AZURE_API_VERSION", "2024-02-15-preview"),
|
|
117
|
+
}
|
|
118
|
+
# Client caches: (endpoint, key) → client instance
|
|
119
|
+
self._openai_clients: dict[tuple[str, str], AzureOpenAI] = {}
|
|
120
|
+
self._anthropic_clients: dict[tuple[str, str], Any] = {}
|
|
93
121
|
|
|
94
122
|
supports_messages = True
|
|
95
123
|
|
|
@@ -98,6 +126,36 @@ class AzureDriver(CostMixin, Driver):
|
|
|
98
126
|
|
|
99
127
|
return _prepare_openai_vision_messages(messages)
|
|
100
128
|
|
|
129
|
+
def _resolve_model_config(self, model: str, options: dict[str, Any]) -> dict[str, Any]:
|
|
130
|
+
"""Resolve Azure config for this model using the priority chain."""
|
|
131
|
+
override = options.pop("azure_config", None)
|
|
132
|
+
return resolve_config(model, override=override, default_config=self._default_config)
|
|
133
|
+
|
|
134
|
+
def _get_openai_client(self, config: dict[str, Any]) -> "AzureOpenAI":
|
|
135
|
+
"""Get or create an AzureOpenAI client for the given config."""
|
|
136
|
+
if AzureOpenAI is None:
|
|
137
|
+
raise RuntimeError("openai package (>=1.0.0) with AzureOpenAI not installed")
|
|
138
|
+
cache_key = (config["endpoint"], config["api_key"])
|
|
139
|
+
if cache_key not in self._openai_clients:
|
|
140
|
+
self._openai_clients[cache_key] = AzureOpenAI(
|
|
141
|
+
api_key=config["api_key"],
|
|
142
|
+
api_version=config.get("api_version", "2024-02-15-preview"),
|
|
143
|
+
azure_endpoint=config["endpoint"],
|
|
144
|
+
)
|
|
145
|
+
return self._openai_clients[cache_key]
|
|
146
|
+
|
|
147
|
+
def _get_anthropic_client(self, config: dict[str, Any]) -> Any:
|
|
148
|
+
"""Get or create an Anthropic client for the given Azure config."""
|
|
149
|
+
if anthropic is None:
|
|
150
|
+
raise RuntimeError("anthropic package not installed (required for Claude on Azure)")
|
|
151
|
+
cache_key = (config["endpoint"], config["api_key"])
|
|
152
|
+
if cache_key not in self._anthropic_clients:
|
|
153
|
+
self._anthropic_clients[cache_key] = anthropic.Anthropic(
|
|
154
|
+
base_url=config["endpoint"],
|
|
155
|
+
api_key=config["api_key"],
|
|
156
|
+
)
|
|
157
|
+
return self._anthropic_clients[cache_key]
|
|
158
|
+
|
|
101
159
|
def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
|
|
102
160
|
messages = [{"role": "user", "content": prompt}]
|
|
103
161
|
return self._do_generate(messages, options)
|
|
@@ -106,19 +164,35 @@ class AzureDriver(CostMixin, Driver):
|
|
|
106
164
|
return self._do_generate(self._prepare_messages(messages), options)
|
|
107
165
|
|
|
108
166
|
def _do_generate(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
|
|
109
|
-
if self.client is None:
|
|
110
|
-
raise RuntimeError("openai package (>=1.0.0) with AzureOpenAI not installed")
|
|
111
|
-
|
|
112
167
|
model = options.get("model", self.model)
|
|
168
|
+
config = self._resolve_model_config(model, options)
|
|
169
|
+
backend = classify_backend(model)
|
|
170
|
+
|
|
171
|
+
if backend == "claude":
|
|
172
|
+
return self._generate_claude(messages, options, config, model)
|
|
173
|
+
else:
|
|
174
|
+
# Both "openai" and "mistral" use the OpenAI-compatible protocol
|
|
175
|
+
return self._generate_openai(messages, options, config, model)
|
|
176
|
+
|
|
177
|
+
def _generate_openai(
|
|
178
|
+
self,
|
|
179
|
+
messages: list[dict[str, Any]],
|
|
180
|
+
options: dict[str, Any],
|
|
181
|
+
config: dict[str, Any],
|
|
182
|
+
model: str,
|
|
183
|
+
) -> dict[str, Any]:
|
|
184
|
+
"""Generate via Azure OpenAI (or Mistral OpenAI-compat) endpoint."""
|
|
185
|
+
client = self._get_openai_client(config)
|
|
186
|
+
deployment_id = config.get("deployment_id") or model
|
|
187
|
+
|
|
113
188
|
model_config = self._get_model_config("azure", model)
|
|
114
189
|
tokens_param = model_config["tokens_param"]
|
|
115
190
|
supports_temperature = model_config["supports_temperature"]
|
|
116
191
|
|
|
117
192
|
opts = {"temperature": 1.0, "max_tokens": 512, **options}
|
|
118
193
|
|
|
119
|
-
# Build request kwargs
|
|
120
194
|
kwargs = {
|
|
121
|
-
"model":
|
|
195
|
+
"model": deployment_id,
|
|
122
196
|
"messages": messages,
|
|
123
197
|
}
|
|
124
198
|
kwargs[tokens_param] = opts.get("max_tokens", 512)
|
|
@@ -142,7 +216,7 @@ class AzureDriver(CostMixin, Driver):
|
|
|
142
216
|
else:
|
|
143
217
|
kwargs["response_format"] = {"type": "json_object"}
|
|
144
218
|
|
|
145
|
-
resp =
|
|
219
|
+
resp = client.chat.completions.create(**kwargs)
|
|
146
220
|
|
|
147
221
|
# Extract usage
|
|
148
222
|
usage = getattr(resp, "usage", None)
|
|
@@ -153,7 +227,6 @@ class AzureDriver(CostMixin, Driver):
|
|
|
153
227
|
# Calculate cost via shared mixin
|
|
154
228
|
total_cost = self._calculate_cost("azure", model, prompt_tokens, completion_tokens)
|
|
155
229
|
|
|
156
|
-
# Standardized meta object
|
|
157
230
|
meta = {
|
|
158
231
|
"prompt_tokens": prompt_tokens,
|
|
159
232
|
"completion_tokens": completion_tokens,
|
|
@@ -161,12 +234,86 @@ class AzureDriver(CostMixin, Driver):
|
|
|
161
234
|
"cost": round(total_cost, 6),
|
|
162
235
|
"raw_response": resp.model_dump(),
|
|
163
236
|
"model_name": model,
|
|
164
|
-
"deployment_id":
|
|
237
|
+
"deployment_id": deployment_id,
|
|
165
238
|
}
|
|
166
239
|
|
|
167
240
|
text = resp.choices[0].message.content
|
|
168
241
|
return {"text": text, "meta": meta}
|
|
169
242
|
|
|
243
|
+
def _generate_claude(
|
|
244
|
+
self,
|
|
245
|
+
messages: list[dict[str, Any]],
|
|
246
|
+
options: dict[str, Any],
|
|
247
|
+
config: dict[str, Any],
|
|
248
|
+
model: str,
|
|
249
|
+
) -> dict[str, Any]:
|
|
250
|
+
"""Generate via Anthropic SDK with Azure endpoint."""
|
|
251
|
+
client = self._get_anthropic_client(config)
|
|
252
|
+
|
|
253
|
+
opts = {**{"temperature": 0.0, "max_tokens": 512}, **options}
|
|
254
|
+
|
|
255
|
+
# Anthropic requires system messages as a top-level parameter
|
|
256
|
+
system_content = None
|
|
257
|
+
api_messages = []
|
|
258
|
+
for msg in messages:
|
|
259
|
+
if msg.get("role") == "system":
|
|
260
|
+
system_content = msg.get("content", "")
|
|
261
|
+
else:
|
|
262
|
+
api_messages.append(msg)
|
|
263
|
+
|
|
264
|
+
common_kwargs: dict[str, Any] = {
|
|
265
|
+
"model": model,
|
|
266
|
+
"messages": api_messages,
|
|
267
|
+
"temperature": opts["temperature"],
|
|
268
|
+
"max_tokens": opts["max_tokens"],
|
|
269
|
+
}
|
|
270
|
+
if system_content:
|
|
271
|
+
common_kwargs["system"] = system_content
|
|
272
|
+
|
|
273
|
+
# Native JSON mode: use tool-use for schema enforcement
|
|
274
|
+
if options.get("json_mode"):
|
|
275
|
+
json_schema = options.get("json_schema")
|
|
276
|
+
if json_schema:
|
|
277
|
+
tool_def = {
|
|
278
|
+
"name": "extract_json",
|
|
279
|
+
"description": "Extract structured data matching the schema",
|
|
280
|
+
"input_schema": json_schema,
|
|
281
|
+
}
|
|
282
|
+
resp = client.messages.create(
|
|
283
|
+
**common_kwargs,
|
|
284
|
+
tools=[tool_def],
|
|
285
|
+
tool_choice={"type": "tool", "name": "extract_json"},
|
|
286
|
+
)
|
|
287
|
+
text = ""
|
|
288
|
+
for block in resp.content:
|
|
289
|
+
if block.type == "tool_use":
|
|
290
|
+
text = json.dumps(block.input)
|
|
291
|
+
break
|
|
292
|
+
else:
|
|
293
|
+
resp = client.messages.create(**common_kwargs)
|
|
294
|
+
text = resp.content[0].text
|
|
295
|
+
else:
|
|
296
|
+
resp = client.messages.create(**common_kwargs)
|
|
297
|
+
text = resp.content[0].text
|
|
298
|
+
|
|
299
|
+
prompt_tokens = resp.usage.input_tokens
|
|
300
|
+
completion_tokens = resp.usage.output_tokens
|
|
301
|
+
total_tokens = prompt_tokens + completion_tokens
|
|
302
|
+
|
|
303
|
+
total_cost = self._calculate_cost("azure", model, prompt_tokens, completion_tokens)
|
|
304
|
+
|
|
305
|
+
meta = {
|
|
306
|
+
"prompt_tokens": prompt_tokens,
|
|
307
|
+
"completion_tokens": completion_tokens,
|
|
308
|
+
"total_tokens": total_tokens,
|
|
309
|
+
"cost": round(total_cost, 6),
|
|
310
|
+
"raw_response": dict(resp),
|
|
311
|
+
"model_name": model,
|
|
312
|
+
}
|
|
313
|
+
|
|
314
|
+
text_result = text or ""
|
|
315
|
+
return {"text": text_result, "meta": meta}
|
|
316
|
+
|
|
170
317
|
# ------------------------------------------------------------------
|
|
171
318
|
# Tool use
|
|
172
319
|
# ------------------------------------------------------------------
|
|
@@ -178,10 +325,27 @@ class AzureDriver(CostMixin, Driver):
|
|
|
178
325
|
options: dict[str, Any],
|
|
179
326
|
) -> dict[str, Any]:
|
|
180
327
|
"""Generate a response that may include tool calls."""
|
|
181
|
-
if self.client is None:
|
|
182
|
-
raise RuntimeError("openai package (>=1.0.0) with AzureOpenAI not installed")
|
|
183
|
-
|
|
184
328
|
model = options.get("model", self.model)
|
|
329
|
+
config = self._resolve_model_config(model, options)
|
|
330
|
+
backend = classify_backend(model)
|
|
331
|
+
|
|
332
|
+
if backend == "claude":
|
|
333
|
+
return self._generate_claude_with_tools(messages, tools, options, config, model)
|
|
334
|
+
else:
|
|
335
|
+
return self._generate_openai_with_tools(messages, tools, options, config, model)
|
|
336
|
+
|
|
337
|
+
def _generate_openai_with_tools(
|
|
338
|
+
self,
|
|
339
|
+
messages: list[dict[str, Any]],
|
|
340
|
+
tools: list[dict[str, Any]],
|
|
341
|
+
options: dict[str, Any],
|
|
342
|
+
config: dict[str, Any],
|
|
343
|
+
model: str,
|
|
344
|
+
) -> dict[str, Any]:
|
|
345
|
+
"""Tool calling via Azure OpenAI endpoint."""
|
|
346
|
+
client = self._get_openai_client(config)
|
|
347
|
+
deployment_id = config.get("deployment_id") or model
|
|
348
|
+
|
|
185
349
|
model_config = self._get_model_config("azure", model)
|
|
186
350
|
tokens_param = model_config["tokens_param"]
|
|
187
351
|
supports_temperature = model_config["supports_temperature"]
|
|
@@ -191,7 +355,7 @@ class AzureDriver(CostMixin, Driver):
|
|
|
191
355
|
opts = {"temperature": 1.0, "max_tokens": 512, **options}
|
|
192
356
|
|
|
193
357
|
kwargs: dict[str, Any] = {
|
|
194
|
-
"model":
|
|
358
|
+
"model": deployment_id,
|
|
195
359
|
"messages": messages,
|
|
196
360
|
"tools": tools,
|
|
197
361
|
}
|
|
@@ -200,7 +364,7 @@ class AzureDriver(CostMixin, Driver):
|
|
|
200
364
|
if supports_temperature and "temperature" in opts:
|
|
201
365
|
kwargs["temperature"] = opts["temperature"]
|
|
202
366
|
|
|
203
|
-
resp =
|
|
367
|
+
resp = client.chat.completions.create(**kwargs)
|
|
204
368
|
|
|
205
369
|
usage = getattr(resp, "usage", None)
|
|
206
370
|
prompt_tokens = getattr(usage, "prompt_tokens", 0)
|
|
@@ -215,7 +379,7 @@ class AzureDriver(CostMixin, Driver):
|
|
|
215
379
|
"cost": round(total_cost, 6),
|
|
216
380
|
"raw_response": resp.model_dump(),
|
|
217
381
|
"model_name": model,
|
|
218
|
-
"deployment_id":
|
|
382
|
+
"deployment_id": deployment_id,
|
|
219
383
|
}
|
|
220
384
|
|
|
221
385
|
choice = resp.choices[0]
|
|
@@ -229,11 +393,13 @@ class AzureDriver(CostMixin, Driver):
|
|
|
229
393
|
args = json.loads(tc.function.arguments)
|
|
230
394
|
except (json.JSONDecodeError, TypeError):
|
|
231
395
|
args = {}
|
|
232
|
-
tool_calls_out.append(
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
|
|
236
|
-
|
|
396
|
+
tool_calls_out.append(
|
|
397
|
+
{
|
|
398
|
+
"id": tc.id,
|
|
399
|
+
"name": tc.function.name,
|
|
400
|
+
"arguments": args,
|
|
401
|
+
}
|
|
402
|
+
)
|
|
237
403
|
|
|
238
404
|
return {
|
|
239
405
|
"text": text,
|
|
@@ -241,3 +407,88 @@ class AzureDriver(CostMixin, Driver):
|
|
|
241
407
|
"tool_calls": tool_calls_out,
|
|
242
408
|
"stop_reason": stop_reason,
|
|
243
409
|
}
|
|
410
|
+
|
|
411
|
+
def _generate_claude_with_tools(
|
|
412
|
+
self,
|
|
413
|
+
messages: list[dict[str, Any]],
|
|
414
|
+
tools: list[dict[str, Any]],
|
|
415
|
+
options: dict[str, Any],
|
|
416
|
+
config: dict[str, Any],
|
|
417
|
+
model: str,
|
|
418
|
+
) -> dict[str, Any]:
|
|
419
|
+
"""Tool calling via Anthropic SDK with Azure endpoint."""
|
|
420
|
+
client = self._get_anthropic_client(config)
|
|
421
|
+
|
|
422
|
+
opts = {**{"temperature": 0.0, "max_tokens": 512}, **options}
|
|
423
|
+
|
|
424
|
+
system_content = None
|
|
425
|
+
api_messages: list[dict[str, Any]] = []
|
|
426
|
+
for msg in messages:
|
|
427
|
+
if msg.get("role") == "system":
|
|
428
|
+
system_content = msg.get("content", "")
|
|
429
|
+
else:
|
|
430
|
+
api_messages.append(msg)
|
|
431
|
+
|
|
432
|
+
# Convert tools from OpenAI format to Anthropic format if needed
|
|
433
|
+
anthropic_tools = []
|
|
434
|
+
for t in tools:
|
|
435
|
+
if "type" in t and t["type"] == "function":
|
|
436
|
+
fn = t["function"]
|
|
437
|
+
anthropic_tools.append(
|
|
438
|
+
{
|
|
439
|
+
"name": fn["name"],
|
|
440
|
+
"description": fn.get("description", ""),
|
|
441
|
+
"input_schema": fn.get("parameters", {"type": "object", "properties": {}}),
|
|
442
|
+
}
|
|
443
|
+
)
|
|
444
|
+
elif "input_schema" in t:
|
|
445
|
+
anthropic_tools.append(t)
|
|
446
|
+
else:
|
|
447
|
+
anthropic_tools.append(t)
|
|
448
|
+
|
|
449
|
+
kwargs: dict[str, Any] = {
|
|
450
|
+
"model": model,
|
|
451
|
+
"messages": api_messages,
|
|
452
|
+
"temperature": opts["temperature"],
|
|
453
|
+
"max_tokens": opts["max_tokens"],
|
|
454
|
+
"tools": anthropic_tools,
|
|
455
|
+
}
|
|
456
|
+
if system_content:
|
|
457
|
+
kwargs["system"] = system_content
|
|
458
|
+
|
|
459
|
+
resp = client.messages.create(**kwargs)
|
|
460
|
+
|
|
461
|
+
prompt_tokens = resp.usage.input_tokens
|
|
462
|
+
completion_tokens = resp.usage.output_tokens
|
|
463
|
+
total_tokens = prompt_tokens + completion_tokens
|
|
464
|
+
total_cost = self._calculate_cost("azure", model, prompt_tokens, completion_tokens)
|
|
465
|
+
|
|
466
|
+
meta = {
|
|
467
|
+
"prompt_tokens": prompt_tokens,
|
|
468
|
+
"completion_tokens": completion_tokens,
|
|
469
|
+
"total_tokens": total_tokens,
|
|
470
|
+
"cost": round(total_cost, 6),
|
|
471
|
+
"raw_response": dict(resp),
|
|
472
|
+
"model_name": model,
|
|
473
|
+
}
|
|
474
|
+
|
|
475
|
+
text = ""
|
|
476
|
+
tool_calls_out: list[dict[str, Any]] = []
|
|
477
|
+
for block in resp.content:
|
|
478
|
+
if block.type == "text":
|
|
479
|
+
text += block.text
|
|
480
|
+
elif block.type == "tool_use":
|
|
481
|
+
tool_calls_out.append(
|
|
482
|
+
{
|
|
483
|
+
"id": block.id,
|
|
484
|
+
"name": block.name,
|
|
485
|
+
"arguments": block.input,
|
|
486
|
+
}
|
|
487
|
+
)
|
|
488
|
+
|
|
489
|
+
return {
|
|
490
|
+
"text": text,
|
|
491
|
+
"meta": meta,
|
|
492
|
+
"tool_calls": tool_calls_out,
|
|
493
|
+
"stop_reason": resp.stop_reason,
|
|
494
|
+
}
|
prompture/settings.py
CHANGED
|
@@ -25,11 +25,21 @@ class Settings(BaseSettings):
|
|
|
25
25
|
ollama_endpoint: str = "http://localhost:11434/api/generate"
|
|
26
26
|
ollama_model: str = "llama2"
|
|
27
27
|
|
|
28
|
-
# Azure
|
|
28
|
+
# Azure (default / OpenAI backend)
|
|
29
29
|
azure_api_key: Optional[str] = None
|
|
30
30
|
azure_api_endpoint: Optional[str] = None
|
|
31
31
|
azure_deployment_id: Optional[str] = None
|
|
32
32
|
|
|
33
|
+
# Azure - Claude backend (optional)
|
|
34
|
+
azure_claude_api_key: Optional[str] = None
|
|
35
|
+
azure_claude_endpoint: Optional[str] = None
|
|
36
|
+
azure_claude_api_version: Optional[str] = None
|
|
37
|
+
|
|
38
|
+
# Azure - Mistral backend (optional)
|
|
39
|
+
azure_mistral_api_key: Optional[str] = None
|
|
40
|
+
azure_mistral_endpoint: Optional[str] = None
|
|
41
|
+
azure_mistral_api_version: Optional[str] = None
|
|
42
|
+
|
|
33
43
|
# LM Studio
|
|
34
44
|
lmstudio_endpoint: str = "http://127.0.0.1:1234/v1/chat/completions"
|
|
35
45
|
lmstudio_model: str = "deepseek/deepseek-r1-0528-qwen3-8b"
|
|
@@ -1,5 +1,5 @@
|
|
|
1
|
-
prompture/__init__.py,sha256=
|
|
2
|
-
prompture/_version.py,sha256=
|
|
1
|
+
prompture/__init__.py,sha256=qpslzeo6Lll0oaed5ed7-SjI27wz5GS4lbMPLKxmSPU,7710
|
|
2
|
+
prompture/_version.py,sha256=sBIk0sWwfkCODSG03QjKVbEilvZaZDxI-tCNjui50TI,706
|
|
3
3
|
prompture/agent.py,sha256=-8qdo_Lz20GGssCe5B_QPxb5Kct71YtKHh5vZgrSYik,34748
|
|
4
4
|
prompture/agent_types.py,sha256=Icl16PQI-ThGLMFCU43adtQA6cqETbsPn4KssKBI4xc,4664
|
|
5
5
|
prompture/async_agent.py,sha256=_6_IRb-LGzZxGxfPVy43SIWByUoQfN-5XnUWahVP6r8,33110
|
|
@@ -13,7 +13,7 @@ prompture/cli.py,sha256=tNiIddRmgC1BomjY5O1VVVAwvqHVzF8IHmQrM-cG2wQ,2902
|
|
|
13
13
|
prompture/conversation.py,sha256=uxstayJjgY6a39DtU0YxQl0Dt3JBo2UVCyMPJW95MNI,36428
|
|
14
14
|
prompture/core.py,sha256=5FHwX7fNPwFHMbFCMvV-RH7LpPpTToLAmcyDnKbrN0E,57202
|
|
15
15
|
prompture/cost_mixin.py,sha256=JbdmSWFP3om7rfQQbRgVh_HboGmWHbvbtjSrNjGN4NU,4621
|
|
16
|
-
prompture/discovery.py,sha256=
|
|
16
|
+
prompture/discovery.py,sha256=5wlooJedAwaceUC21Tw5qHtBrQXbAvzGmte1lIOhToI,10642
|
|
17
17
|
prompture/driver.py,sha256=wE7K3vnqeCVT5pEEBP-3uZ6e-YyU6TXtnEKRSB25eOc,10410
|
|
18
18
|
prompture/field_definitions.py,sha256=PLvxq2ot-ngJ8JbWkkZ-XLtM1wvjUQ3TL01vSEo-a6E,21368
|
|
19
19
|
prompture/group_types.py,sha256=lr8f5kA5IY5cJ_K06OGBaziEf9fPMIRYgtVT12q3aiQ,4769
|
|
@@ -28,16 +28,16 @@ prompture/runner.py,sha256=lHe2L2jqY1pDXoKNPJALN9lAm-Q8QOY8C8gw-vM9VrM,4213
|
|
|
28
28
|
prompture/serialization.py,sha256=m4cdAQJspitMcfwRgecElkY2SBt3BjEwubbhS3W-0s0,7433
|
|
29
29
|
prompture/server.py,sha256=W6Kn6Et8nG5twXjD2wKn_N9yplGjz5Z-2naeI_UPd1Y,6198
|
|
30
30
|
prompture/session.py,sha256=FldK3cKq_jO0-beukVOhIiwsYWb6U_lLBlAERx95aaM,3821
|
|
31
|
-
prompture/settings.py,sha256=
|
|
31
|
+
prompture/settings.py,sha256=fGuIxesF3S4t145d-0JGgWEa7UprhmVGBuqZJwy9Cv0,2967
|
|
32
32
|
prompture/simulated_tools.py,sha256=oL6W6hAEKXZHBfb8b-UDPfm3V4nSqXu7eG8IpvwtqKg,3901
|
|
33
33
|
prompture/tools.py,sha256=PmFbGHTWYWahpJOG6BLlM0Y-EG6S37IFW57C-8GdsXo,36449
|
|
34
34
|
prompture/tools_schema.py,sha256=wuVfPyCKVWlhUDRsXWArtGpxkQRqNWyKeLJuXn_6X8k,8986
|
|
35
35
|
prompture/validator.py,sha256=FY_VjIVEbjG2nwzh-r6l23Kt3UzaLyCis8_pZMNGHBA,993
|
|
36
36
|
prompture/aio/__init__.py,sha256=bKqTu4Jxld16aP_7SP9wU5au45UBIb041ORo4E4HzVo,1810
|
|
37
|
-
prompture/drivers/__init__.py,sha256=
|
|
37
|
+
prompture/drivers/__init__.py,sha256=ueloUcIxOK02u6ICcECZ6_DLWOkobjVWrHPL5uRJSWs,8518
|
|
38
38
|
prompture/drivers/airllm_driver.py,sha256=SaTh7e7Plvuct_TfRqQvsJsKHvvM_3iVqhBtlciM-Kw,3858
|
|
39
39
|
prompture/drivers/async_airllm_driver.py,sha256=1hIWLXfyyIg9tXaOE22tLJvFyNwHnOi1M5BIKnV8ysk,908
|
|
40
|
-
prompture/drivers/async_azure_driver.py,sha256=
|
|
40
|
+
prompture/drivers/async_azure_driver.py,sha256=iRR06MqmLOO--JZT8EryCHbimJTYI3q7pyX1zQp8Mns,15319
|
|
41
41
|
prompture/drivers/async_claude_driver.py,sha256=k6D6aEgcy8HYbuCsoqDknh7aTfw_cJrV7kDMqCA0OSg,11746
|
|
42
42
|
prompture/drivers/async_google_driver.py,sha256=LTUgCXJjzuTDGzsCsmY2-xH2KdTLJD7htwO49ZNFOdE,13711
|
|
43
43
|
prompture/drivers/async_grok_driver.py,sha256=lj160GHARe0fqTms4ovWhkpgt0idsGt55xnuc6JlH1w,7413
|
|
@@ -50,9 +50,10 @@ prompture/drivers/async_moonshot_driver.py,sha256=a9gr3T_4NiDFd7foM1mSHJRvXYb43i
|
|
|
50
50
|
prompture/drivers/async_ollama_driver.py,sha256=Li2ZKZrItxKLkbIuugF8LChlnN3xtXtIoc92Ek8_wMc,9121
|
|
51
51
|
prompture/drivers/async_openai_driver.py,sha256=COa_JE-AgKowKJpmRnfDJp4RSQKZel_7WswxOzvLksM,9044
|
|
52
52
|
prompture/drivers/async_openrouter_driver.py,sha256=N7s72HuXHLs_RWmJO9P3pCayWE98ommfqVeAfru8Bl0,11758
|
|
53
|
-
prompture/drivers/async_registry.py,sha256=
|
|
53
|
+
prompture/drivers/async_registry.py,sha256=K5BZOsK1rb-8OgQZiwhIf8-8Q1E1aLl9MgHlxBVPIiE,5278
|
|
54
54
|
prompture/drivers/async_zai_driver.py,sha256=zXHxske1CtK8dDTGY-D_kiyZZ_NfceNTJlyTpKn0R4c,10727
|
|
55
|
-
prompture/drivers/
|
|
55
|
+
prompture/drivers/azure_config.py,sha256=tA0dWtN8qYxE00_ftWJtJ84Cboq2MMC7yOoR--o8u0s,4046
|
|
56
|
+
prompture/drivers/azure_driver.py,sha256=v1fiPb1vZrZwWLypOH2iIghi3BRvOiFXMTdTevhbtMQ,17867
|
|
56
57
|
prompture/drivers/claude_driver.py,sha256=TOJMhCSAyF8yRmKyVl0pACJUBrxMZHHnQE12iBijCCQ,13474
|
|
57
58
|
prompture/drivers/google_driver.py,sha256=Zck5VUsW37kDgohXz3cUWRmZ88OfhmTpVD-qzAVMp-8,16318
|
|
58
59
|
prompture/drivers/grok_driver.py,sha256=fxl5Gx9acFq7BlOh_N9U66oJvG3y8YX4QuSAgZWHJmU,8963
|
|
@@ -77,9 +78,9 @@ prompture/scaffold/templates/env.example.j2,sha256=eESKr1KWgyrczO6d-nwAhQwSpf_G-
|
|
|
77
78
|
prompture/scaffold/templates/main.py.j2,sha256=TEgc5OvsZOEX0JthkSW1NI_yLwgoeVN_x97Ibg-vyWY,2632
|
|
78
79
|
prompture/scaffold/templates/models.py.j2,sha256=JrZ99GCVK6TKWapskVRSwCssGrTu5cGZ_r46fOhY2GE,858
|
|
79
80
|
prompture/scaffold/templates/requirements.txt.j2,sha256=m3S5fi1hq9KG9l_9j317rjwWww0a43WMKd8VnUWv2A4,102
|
|
80
|
-
prompture-0.0.
|
|
81
|
-
prompture-0.0.
|
|
82
|
-
prompture-0.0.
|
|
83
|
-
prompture-0.0.
|
|
84
|
-
prompture-0.0.
|
|
85
|
-
prompture-0.0.
|
|
81
|
+
prompture-0.0.50.dist-info/licenses/LICENSE,sha256=0HgDepH7aaHNFhHF-iXuW6_GqDfYPnVkjtiCAZ4yS8I,1060
|
|
82
|
+
prompture-0.0.50.dist-info/METADATA,sha256=5q88yZTm91Z3jwvxZcbTOJJsBj1eQuLKNHCIlrmf5Vk,12148
|
|
83
|
+
prompture-0.0.50.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
|
|
84
|
+
prompture-0.0.50.dist-info/entry_points.txt,sha256=AFPG3lJR86g4IJMoWQUW5Ph7G6MLNWG3A2u2Tp9zkp8,48
|
|
85
|
+
prompture-0.0.50.dist-info/top_level.txt,sha256=to86zq_kjfdoLeAxQNr420UWqT0WzkKoZ509J7Qr2t4,10
|
|
86
|
+
prompture-0.0.50.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|