prompture 0.0.49__tar.gz → 0.0.50__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {prompture-0.0.49 → prompture-0.0.50}/.env.copy +11 -1
- {prompture-0.0.49 → prompture-0.0.50}/PKG-INFO +1 -1
- prompture-0.0.50/VERSION +1 -0
- {prompture-0.0.49 → prompture-0.0.50}/prompture/__init__.py +9 -0
- {prompture-0.0.49 → prompture-0.0.50}/prompture/_version.py +2 -2
- {prompture-0.0.49 → prompture-0.0.50}/prompture/discovery.py +10 -3
- {prompture-0.0.49 → prompture-0.0.50}/prompture/drivers/__init__.py +15 -1
- prompture-0.0.50/prompture/drivers/async_azure_driver.py +418 -0
- {prompture-0.0.49 → prompture-0.0.50}/prompture/drivers/async_registry.py +4 -1
- prompture-0.0.50/prompture/drivers/azure_config.py +146 -0
- prompture-0.0.50/prompture/drivers/azure_driver.py +494 -0
- {prompture-0.0.49 → prompture-0.0.50}/prompture/settings.py +11 -1
- {prompture-0.0.49 → prompture-0.0.50}/prompture.egg-info/PKG-INFO +1 -1
- {prompture-0.0.49 → prompture-0.0.50}/prompture.egg-info/SOURCES.txt +1 -0
- prompture-0.0.49/VERSION +0 -1
- prompture-0.0.49/prompture/drivers/async_azure_driver.py +0 -201
- prompture-0.0.49/prompture/drivers/azure_driver.py +0 -243
- {prompture-0.0.49 → prompture-0.0.50}/.claude/skills/add-driver/SKILL.md +0 -0
- {prompture-0.0.49 → prompture-0.0.50}/.claude/skills/add-driver/references/driver-template.md +0 -0
- {prompture-0.0.49 → prompture-0.0.50}/.claude/skills/add-example/SKILL.md +0 -0
- {prompture-0.0.49 → prompture-0.0.50}/.claude/skills/add-field/SKILL.md +0 -0
- {prompture-0.0.49 → prompture-0.0.50}/.claude/skills/add-persona/SKILL.md +0 -0
- {prompture-0.0.49 → prompture-0.0.50}/.claude/skills/add-test/SKILL.md +0 -0
- {prompture-0.0.49 → prompture-0.0.50}/.claude/skills/add-tool/SKILL.md +0 -0
- {prompture-0.0.49 → prompture-0.0.50}/.claude/skills/run-tests/SKILL.md +0 -0
- {prompture-0.0.49 → prompture-0.0.50}/.claude/skills/scaffold-extraction/SKILL.md +0 -0
- {prompture-0.0.49 → prompture-0.0.50}/.claude/skills/update-pricing/SKILL.md +0 -0
- {prompture-0.0.49 → prompture-0.0.50}/.github/FUNDING.yml +0 -0
- {prompture-0.0.49 → prompture-0.0.50}/.github/scripts/update_docs_version.py +0 -0
- {prompture-0.0.49 → prompture-0.0.50}/.github/scripts/update_wrapper_version.py +0 -0
- {prompture-0.0.49 → prompture-0.0.50}/.github/workflows/dev.yml +0 -0
- {prompture-0.0.49 → prompture-0.0.50}/.github/workflows/documentation.yml +0 -0
- {prompture-0.0.49 → prompture-0.0.50}/.github/workflows/publish.yml +0 -0
- {prompture-0.0.49 → prompture-0.0.50}/CLAUDE.md +0 -0
- {prompture-0.0.49 → prompture-0.0.50}/LICENSE +0 -0
- {prompture-0.0.49 → prompture-0.0.50}/MANIFEST.in +0 -0
- {prompture-0.0.49 → prompture-0.0.50}/README.md +0 -0
- {prompture-0.0.49 → prompture-0.0.50}/ROADMAP.md +0 -0
- {prompture-0.0.49 → prompture-0.0.50}/docs/source/_static/custom.css +0 -0
- {prompture-0.0.49 → prompture-0.0.50}/docs/source/_templates/footer.html +0 -0
- {prompture-0.0.49 → prompture-0.0.50}/docs/source/api/core.rst +0 -0
- {prompture-0.0.49 → prompture-0.0.50}/docs/source/api/drivers.rst +0 -0
- {prompture-0.0.49 → prompture-0.0.50}/docs/source/api/field_definitions.rst +0 -0
- {prompture-0.0.49 → prompture-0.0.50}/docs/source/api/index.rst +0 -0
- {prompture-0.0.49 → prompture-0.0.50}/docs/source/api/runner.rst +0 -0
- {prompture-0.0.49 → prompture-0.0.50}/docs/source/api/tools.rst +0 -0
- {prompture-0.0.49 → prompture-0.0.50}/docs/source/api/validator.rst +0 -0
- {prompture-0.0.49 → prompture-0.0.50}/docs/source/conf.py +0 -0
- {prompture-0.0.49 → prompture-0.0.50}/docs/source/contributing.rst +0 -0
- {prompture-0.0.49 → prompture-0.0.50}/docs/source/examples.rst +0 -0
- {prompture-0.0.49 → prompture-0.0.50}/docs/source/field_definitions_reference.rst +0 -0
- {prompture-0.0.49 → prompture-0.0.50}/docs/source/index.rst +0 -0
- {prompture-0.0.49 → prompture-0.0.50}/docs/source/installation.rst +0 -0
- {prompture-0.0.49 → prompture-0.0.50}/docs/source/quickstart.rst +0 -0
- {prompture-0.0.49 → prompture-0.0.50}/docs/source/toon_input_guide.rst +0 -0
- {prompture-0.0.49 → prompture-0.0.50}/packages/README.md +0 -0
- {prompture-0.0.49 → prompture-0.0.50}/packages/llm_to_json/README.md +0 -0
- {prompture-0.0.49 → prompture-0.0.50}/packages/llm_to_json/llm_to_json/__init__.py +0 -0
- {prompture-0.0.49 → prompture-0.0.50}/packages/llm_to_json/pyproject.toml +0 -0
- {prompture-0.0.49 → prompture-0.0.50}/packages/llm_to_json/test.py +0 -0
- {prompture-0.0.49 → prompture-0.0.50}/packages/llm_to_toon/README.md +0 -0
- {prompture-0.0.49 → prompture-0.0.50}/packages/llm_to_toon/llm_to_toon/__init__.py +0 -0
- {prompture-0.0.49 → prompture-0.0.50}/packages/llm_to_toon/pyproject.toml +0 -0
- {prompture-0.0.49 → prompture-0.0.50}/packages/llm_to_toon/test.py +0 -0
- {prompture-0.0.49 → prompture-0.0.50}/prompture/agent.py +0 -0
- {prompture-0.0.49 → prompture-0.0.50}/prompture/agent_types.py +0 -0
- {prompture-0.0.49 → prompture-0.0.50}/prompture/aio/__init__.py +0 -0
- {prompture-0.0.49 → prompture-0.0.50}/prompture/async_agent.py +0 -0
- {prompture-0.0.49 → prompture-0.0.50}/prompture/async_conversation.py +0 -0
- {prompture-0.0.49 → prompture-0.0.50}/prompture/async_core.py +0 -0
- {prompture-0.0.49 → prompture-0.0.50}/prompture/async_driver.py +0 -0
- {prompture-0.0.49 → prompture-0.0.50}/prompture/async_groups.py +0 -0
- {prompture-0.0.49 → prompture-0.0.50}/prompture/cache.py +0 -0
- {prompture-0.0.49 → prompture-0.0.50}/prompture/callbacks.py +0 -0
- {prompture-0.0.49 → prompture-0.0.50}/prompture/cli.py +0 -0
- {prompture-0.0.49 → prompture-0.0.50}/prompture/conversation.py +0 -0
- {prompture-0.0.49 → prompture-0.0.50}/prompture/core.py +0 -0
- {prompture-0.0.49 → prompture-0.0.50}/prompture/cost_mixin.py +0 -0
- {prompture-0.0.49 → prompture-0.0.50}/prompture/driver.py +0 -0
- {prompture-0.0.49 → prompture-0.0.50}/prompture/drivers/airllm_driver.py +0 -0
- {prompture-0.0.49 → prompture-0.0.50}/prompture/drivers/async_airllm_driver.py +0 -0
- {prompture-0.0.49 → prompture-0.0.50}/prompture/drivers/async_claude_driver.py +0 -0
- {prompture-0.0.49 → prompture-0.0.50}/prompture/drivers/async_google_driver.py +0 -0
- {prompture-0.0.49 → prompture-0.0.50}/prompture/drivers/async_grok_driver.py +0 -0
- {prompture-0.0.49 → prompture-0.0.50}/prompture/drivers/async_groq_driver.py +0 -0
- {prompture-0.0.49 → prompture-0.0.50}/prompture/drivers/async_hugging_driver.py +0 -0
- {prompture-0.0.49 → prompture-0.0.50}/prompture/drivers/async_lmstudio_driver.py +0 -0
- {prompture-0.0.49 → prompture-0.0.50}/prompture/drivers/async_local_http_driver.py +0 -0
- {prompture-0.0.49 → prompture-0.0.50}/prompture/drivers/async_modelscope_driver.py +0 -0
- {prompture-0.0.49 → prompture-0.0.50}/prompture/drivers/async_moonshot_driver.py +0 -0
- {prompture-0.0.49 → prompture-0.0.50}/prompture/drivers/async_ollama_driver.py +0 -0
- {prompture-0.0.49 → prompture-0.0.50}/prompture/drivers/async_openai_driver.py +0 -0
- {prompture-0.0.49 → prompture-0.0.50}/prompture/drivers/async_openrouter_driver.py +0 -0
- {prompture-0.0.49 → prompture-0.0.50}/prompture/drivers/async_zai_driver.py +0 -0
- {prompture-0.0.49 → prompture-0.0.50}/prompture/drivers/claude_driver.py +0 -0
- {prompture-0.0.49 → prompture-0.0.50}/prompture/drivers/google_driver.py +0 -0
- {prompture-0.0.49 → prompture-0.0.50}/prompture/drivers/grok_driver.py +0 -0
- {prompture-0.0.49 → prompture-0.0.50}/prompture/drivers/groq_driver.py +0 -0
- {prompture-0.0.49 → prompture-0.0.50}/prompture/drivers/hugging_driver.py +0 -0
- {prompture-0.0.49 → prompture-0.0.50}/prompture/drivers/lmstudio_driver.py +0 -0
- {prompture-0.0.49 → prompture-0.0.50}/prompture/drivers/local_http_driver.py +0 -0
- {prompture-0.0.49 → prompture-0.0.50}/prompture/drivers/modelscope_driver.py +0 -0
- {prompture-0.0.49 → prompture-0.0.50}/prompture/drivers/moonshot_driver.py +0 -0
- {prompture-0.0.49 → prompture-0.0.50}/prompture/drivers/ollama_driver.py +0 -0
- {prompture-0.0.49 → prompture-0.0.50}/prompture/drivers/openai_driver.py +0 -0
- {prompture-0.0.49 → prompture-0.0.50}/prompture/drivers/openrouter_driver.py +0 -0
- {prompture-0.0.49 → prompture-0.0.50}/prompture/drivers/registry.py +0 -0
- {prompture-0.0.49 → prompture-0.0.50}/prompture/drivers/vision_helpers.py +0 -0
- {prompture-0.0.49 → prompture-0.0.50}/prompture/drivers/zai_driver.py +0 -0
- {prompture-0.0.49 → prompture-0.0.50}/prompture/field_definitions.py +0 -0
- {prompture-0.0.49 → prompture-0.0.50}/prompture/group_types.py +0 -0
- {prompture-0.0.49 → prompture-0.0.50}/prompture/groups.py +0 -0
- {prompture-0.0.49 → prompture-0.0.50}/prompture/image.py +0 -0
- {prompture-0.0.49 → prompture-0.0.50}/prompture/ledger.py +0 -0
- {prompture-0.0.49 → prompture-0.0.50}/prompture/logging.py +0 -0
- {prompture-0.0.49 → prompture-0.0.50}/prompture/model_rates.py +0 -0
- {prompture-0.0.49 → prompture-0.0.50}/prompture/persistence.py +0 -0
- {prompture-0.0.49 → prompture-0.0.50}/prompture/persona.py +0 -0
- {prompture-0.0.49 → prompture-0.0.50}/prompture/runner.py +0 -0
- {prompture-0.0.49 → prompture-0.0.50}/prompture/scaffold/__init__.py +0 -0
- {prompture-0.0.49 → prompture-0.0.50}/prompture/scaffold/generator.py +0 -0
- {prompture-0.0.49 → prompture-0.0.50}/prompture/scaffold/templates/Dockerfile.j2 +0 -0
- {prompture-0.0.49 → prompture-0.0.50}/prompture/scaffold/templates/README.md.j2 +0 -0
- {prompture-0.0.49 → prompture-0.0.50}/prompture/scaffold/templates/config.py.j2 +0 -0
- {prompture-0.0.49 → prompture-0.0.50}/prompture/scaffold/templates/env.example.j2 +0 -0
- {prompture-0.0.49 → prompture-0.0.50}/prompture/scaffold/templates/main.py.j2 +0 -0
- {prompture-0.0.49 → prompture-0.0.50}/prompture/scaffold/templates/models.py.j2 +0 -0
- {prompture-0.0.49 → prompture-0.0.50}/prompture/scaffold/templates/requirements.txt.j2 +0 -0
- {prompture-0.0.49 → prompture-0.0.50}/prompture/serialization.py +0 -0
- {prompture-0.0.49 → prompture-0.0.50}/prompture/server.py +0 -0
- {prompture-0.0.49 → prompture-0.0.50}/prompture/session.py +0 -0
- {prompture-0.0.49 → prompture-0.0.50}/prompture/simulated_tools.py +0 -0
- {prompture-0.0.49 → prompture-0.0.50}/prompture/tools.py +0 -0
- {prompture-0.0.49 → prompture-0.0.50}/prompture/tools_schema.py +0 -0
- {prompture-0.0.49 → prompture-0.0.50}/prompture/validator.py +0 -0
- {prompture-0.0.49 → prompture-0.0.50}/prompture.egg-info/dependency_links.txt +0 -0
- {prompture-0.0.49 → prompture-0.0.50}/prompture.egg-info/entry_points.txt +0 -0
- {prompture-0.0.49 → prompture-0.0.50}/prompture.egg-info/requires.txt +0 -0
- {prompture-0.0.49 → prompture-0.0.50}/prompture.egg-info/top_level.txt +0 -0
- {prompture-0.0.49 → prompture-0.0.50}/pyproject.toml +0 -0
- {prompture-0.0.49 → prompture-0.0.50}/requirements.txt +0 -0
- {prompture-0.0.49 → prompture-0.0.50}/setup.cfg +0 -0
- {prompture-0.0.49 → prompture-0.0.50}/test.py +0 -0
- {prompture-0.0.49 → prompture-0.0.50}/test_version_diagnosis.py +0 -0
|
@@ -25,12 +25,22 @@ LMSTUDIO_ENDPOINT=http://127.0.0.1:1234/v1/chat/completions
|
|
|
25
25
|
LMSTUDIO_MODEL=deepseek/deepseek-r1-0528-qwen3-8b
|
|
26
26
|
LMSTUDIO_API_KEY=
|
|
27
27
|
|
|
28
|
-
# Azure OpenAI Configuration
|
|
28
|
+
# Azure OpenAI Configuration (default backend)
|
|
29
29
|
AZURE_API_KEY=
|
|
30
30
|
AZURE_API_ENDPOINT=
|
|
31
31
|
AZURE_DEPLOYMENT_ID=
|
|
32
32
|
AZURE_API_VERSION=
|
|
33
33
|
|
|
34
|
+
# Azure Claude Backend (optional, for claude-* models on Azure)
|
|
35
|
+
AZURE_CLAUDE_API_KEY=
|
|
36
|
+
AZURE_CLAUDE_ENDPOINT=
|
|
37
|
+
AZURE_CLAUDE_API_VERSION=
|
|
38
|
+
|
|
39
|
+
# Azure Mistral Backend (optional, for mistral-*/mixtral-* models on Azure)
|
|
40
|
+
AZURE_MISTRAL_API_KEY=
|
|
41
|
+
AZURE_MISTRAL_ENDPOINT=
|
|
42
|
+
AZURE_MISTRAL_API_VERSION=
|
|
43
|
+
|
|
34
44
|
# Additional Providers (not required for tests)
|
|
35
45
|
# HuggingFace Configuration
|
|
36
46
|
HF_ENDPOINT=
|
prompture-0.0.50/VERSION
ADDED
|
@@ -0,0 +1 @@
|
|
|
1
|
+
0.0.50
|
|
@@ -60,6 +60,8 @@ from .drivers import (
|
|
|
60
60
|
OllamaDriver,
|
|
61
61
|
OpenAIDriver,
|
|
62
62
|
OpenRouterDriver,
|
|
63
|
+
# Azure config API
|
|
64
|
+
clear_azure_configs,
|
|
63
65
|
get_driver,
|
|
64
66
|
get_driver_for_model,
|
|
65
67
|
# Plugin registration API
|
|
@@ -69,8 +71,11 @@ from .drivers import (
|
|
|
69
71
|
list_registered_drivers,
|
|
70
72
|
load_entry_point_drivers,
|
|
71
73
|
register_async_driver,
|
|
74
|
+
register_azure_config,
|
|
72
75
|
register_driver,
|
|
76
|
+
set_azure_config_resolver,
|
|
73
77
|
unregister_async_driver,
|
|
78
|
+
unregister_azure_config,
|
|
74
79
|
unregister_driver,
|
|
75
80
|
)
|
|
76
81
|
from .field_definitions import (
|
|
@@ -247,6 +252,7 @@ __all__ = [
|
|
|
247
252
|
"clean_json_text",
|
|
248
253
|
"clean_json_text_with_ai",
|
|
249
254
|
"clean_toon_text",
|
|
255
|
+
"clear_azure_configs",
|
|
250
256
|
"clear_persona_registry",
|
|
251
257
|
"clear_registry",
|
|
252
258
|
"configure_cache",
|
|
@@ -292,6 +298,7 @@ __all__ = [
|
|
|
292
298
|
"normalize_enum_value",
|
|
293
299
|
"refresh_rates_cache",
|
|
294
300
|
"register_async_driver",
|
|
301
|
+
"register_azure_config",
|
|
295
302
|
"register_driver",
|
|
296
303
|
"register_field",
|
|
297
304
|
"register_persona",
|
|
@@ -301,9 +308,11 @@ __all__ = [
|
|
|
301
308
|
"reset_registry",
|
|
302
309
|
"reset_trait_registry",
|
|
303
310
|
"run_suite_from_spec",
|
|
311
|
+
"set_azure_config_resolver",
|
|
304
312
|
"stepwise_extract_with_model",
|
|
305
313
|
"tool_from_function",
|
|
306
314
|
"unregister_async_driver",
|
|
315
|
+
"unregister_azure_config",
|
|
307
316
|
"unregister_driver",
|
|
308
317
|
"validate_against_schema",
|
|
309
318
|
"validate_enum_value",
|
|
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
|
|
|
28
28
|
commit_id: COMMIT_ID
|
|
29
29
|
__commit_id__: COMMIT_ID
|
|
30
30
|
|
|
31
|
-
__version__ = version = '0.0.
|
|
32
|
-
__version_tuple__ = version_tuple = (0, 0,
|
|
31
|
+
__version__ = version = '0.0.50'
|
|
32
|
+
__version_tuple__ = version_tuple = (0, 0, 50)
|
|
33
33
|
|
|
34
34
|
__commit_id__ = commit_id = None
|
|
@@ -89,10 +89,17 @@ def get_available_models(
|
|
|
89
89
|
if settings.openai_api_key or os.getenv("OPENAI_API_KEY"):
|
|
90
90
|
is_configured = True
|
|
91
91
|
elif provider == "azure":
|
|
92
|
+
from .drivers.azure_config import has_azure_config_resolver, has_registered_configs
|
|
93
|
+
|
|
92
94
|
if (
|
|
93
|
-
(
|
|
94
|
-
|
|
95
|
-
|
|
95
|
+
(
|
|
96
|
+
(settings.azure_api_key or os.getenv("AZURE_API_KEY"))
|
|
97
|
+
and (settings.azure_api_endpoint or os.getenv("AZURE_API_ENDPOINT"))
|
|
98
|
+
)
|
|
99
|
+
or (settings.azure_claude_api_key or os.getenv("AZURE_CLAUDE_API_KEY"))
|
|
100
|
+
or (settings.azure_mistral_api_key or os.getenv("AZURE_MISTRAL_API_KEY"))
|
|
101
|
+
or has_registered_configs()
|
|
102
|
+
or has_azure_config_resolver()
|
|
96
103
|
):
|
|
97
104
|
is_configured = True
|
|
98
105
|
elif provider == "claude":
|
|
@@ -44,6 +44,12 @@ from .async_openai_driver import AsyncOpenAIDriver
|
|
|
44
44
|
from .async_openrouter_driver import AsyncOpenRouterDriver
|
|
45
45
|
from .async_registry import ASYNC_DRIVER_REGISTRY, get_async_driver, get_async_driver_for_model
|
|
46
46
|
from .async_zai_driver import AsyncZaiDriver
|
|
47
|
+
from .azure_config import (
|
|
48
|
+
clear_azure_configs,
|
|
49
|
+
register_azure_config,
|
|
50
|
+
set_azure_config_resolver,
|
|
51
|
+
unregister_azure_config,
|
|
52
|
+
)
|
|
47
53
|
from .azure_driver import AzureDriver
|
|
48
54
|
from .claude_driver import ClaudeDriver
|
|
49
55
|
from .google_driver import GoogleDriver
|
|
@@ -100,7 +106,10 @@ register_driver(
|
|
|
100
106
|
register_driver(
|
|
101
107
|
"azure",
|
|
102
108
|
lambda model=None: AzureDriver(
|
|
103
|
-
api_key=settings.azure_api_key,
|
|
109
|
+
api_key=settings.azure_api_key,
|
|
110
|
+
endpoint=settings.azure_api_endpoint,
|
|
111
|
+
deployment_id=settings.azure_deployment_id,
|
|
112
|
+
model=model or "gpt-4o-mini",
|
|
104
113
|
),
|
|
105
114
|
overwrite=True,
|
|
106
115
|
)
|
|
@@ -249,6 +258,8 @@ __all__ = [
|
|
|
249
258
|
"OpenAIDriver",
|
|
250
259
|
"OpenRouterDriver",
|
|
251
260
|
"ZaiDriver",
|
|
261
|
+
# Azure config API
|
|
262
|
+
"clear_azure_configs",
|
|
252
263
|
"get_async_driver",
|
|
253
264
|
"get_async_driver_for_model",
|
|
254
265
|
# Factory functions
|
|
@@ -260,8 +271,11 @@ __all__ = [
|
|
|
260
271
|
"list_registered_drivers",
|
|
261
272
|
"load_entry_point_drivers",
|
|
262
273
|
"register_async_driver",
|
|
274
|
+
"register_azure_config",
|
|
263
275
|
# Registry functions (public API)
|
|
264
276
|
"register_driver",
|
|
277
|
+
"set_azure_config_resolver",
|
|
265
278
|
"unregister_async_driver",
|
|
279
|
+
"unregister_azure_config",
|
|
266
280
|
"unregister_driver",
|
|
267
281
|
]
|
|
@@ -0,0 +1,418 @@
|
|
|
1
|
+
"""Async Azure driver with multi-endpoint and multi-backend support.
|
|
2
|
+
|
|
3
|
+
Requires the ``openai`` package (>=1.0.0). Claude backend also requires ``anthropic``.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
import json
|
|
9
|
+
import os
|
|
10
|
+
from typing import Any
|
|
11
|
+
|
|
12
|
+
try:
|
|
13
|
+
from openai import AsyncAzureOpenAI
|
|
14
|
+
except Exception:
|
|
15
|
+
AsyncAzureOpenAI = None
|
|
16
|
+
|
|
17
|
+
try:
|
|
18
|
+
import anthropic
|
|
19
|
+
except Exception:
|
|
20
|
+
anthropic = None
|
|
21
|
+
|
|
22
|
+
from ..async_driver import AsyncDriver
|
|
23
|
+
from ..cost_mixin import CostMixin, prepare_strict_schema
|
|
24
|
+
from .azure_config import classify_backend, resolve_config
|
|
25
|
+
from .azure_driver import AzureDriver
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class AsyncAzureDriver(CostMixin, AsyncDriver):
    """Async Azure driver that routes each request to an OpenAI-compatible or Claude backend.

    The backend is chosen per request by ``classify_backend(model)``: models
    classified as ``"claude"`` are served through the Anthropic SDK pointed at the
    resolved Azure endpoint; every other model goes through ``AsyncAzureOpenAI``.
    Connection settings are resolved per model with ``resolve_config``. An
    ``azure_config`` entry in the request options is passed to it as an override,
    and the constructor arguments / ``AZURE_*`` environment variables serve as
    the default config.

    SDK clients are cached on the instance, keyed by their connection settings,
    so repeated calls against the same endpoint reuse a single client.
    """

    supports_json_mode = True
    supports_json_schema = True
    supports_tool_use = True
    supports_vision = True
    supports_messages = True

    MODEL_PRICING = AzureDriver.MODEL_PRICING

    # Used when neither the resolved config nor AZURE_API_VERSION provides a
    # non-empty API version.
    _DEFAULT_API_VERSION = "2024-02-15-preview"

    def __init__(
        self,
        api_key: str | None = None,
        endpoint: str | None = None,
        deployment_id: str | None = None,
        model: str = "gpt-4o-mini",
    ):
        """Create the driver.

        Args:
            api_key: Default Azure API key (falls back to ``AZURE_API_KEY``).
            endpoint: Default Azure endpoint (falls back to ``AZURE_API_ENDPOINT``).
            deployment_id: Default deployment name (falls back to
                ``AZURE_DEPLOYMENT_ID``; if still unset the model name is used
                as the deployment at request time).
            model: Model used when the request options do not name one.
        """
        self.model = model
        self._default_config = {
            "api_key": api_key or os.getenv("AZURE_API_KEY"),
            "endpoint": endpoint or os.getenv("AZURE_API_ENDPOINT"),
            "deployment_id": deployment_id or os.getenv("AZURE_DEPLOYMENT_ID"),
            # ``or`` rather than a getenv default so an empty ``AZURE_API_VERSION=``
            # line (as in the sample .env) still falls back to the default version.
            "api_version": os.getenv("AZURE_API_VERSION") or self._DEFAULT_API_VERSION,
        }
        # Keyed by (endpoint, api_key, api_version).
        self._openai_clients: dict[tuple[str | None, str | None, str], AsyncAzureOpenAI] = {}
        # Keyed by (endpoint, api_key).
        self._anthropic_clients: dict[tuple[str, str], Any] = {}

    # ------------------------------------------------------------------
    # Configuration and clients
    # ------------------------------------------------------------------

    def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
        """Convert messages (including image parts) to the OpenAI vision format."""
        from .vision_helpers import _prepare_openai_vision_messages

        return _prepare_openai_vision_messages(messages)

    def _resolve_model_config(self, model: str, options: dict[str, Any]) -> dict[str, Any]:
        """Resolve the Azure config for ``model`` using the priority chain.

        Pops ``azure_config`` from ``options`` and passes it to ``resolve_config``
        as an override. Callers inside this class pass a private copy of the
        request options, so the caller's dict is never mutated.
        """
        override = options.pop("azure_config", None)
        return resolve_config(model, override=override, default_config=self._default_config)

    def _get_openai_client(self, config: dict[str, Any]) -> AsyncAzureOpenAI:
        """Get or create an ``AsyncAzureOpenAI`` client for the given config.

        Raises:
            RuntimeError: if the ``openai`` package is not importable.
        """
        if AsyncAzureOpenAI is None:
            raise RuntimeError("openai package (>=1.0.0) with AsyncAzureOpenAI not installed")
        # A missing *or empty/None* api_version falls back to the default.
        api_version = config.get("api_version") or self._DEFAULT_API_VERSION
        # The version is part of the key: a client is bound to one api_version,
        # so configs that differ only by version must not share a client.
        cache_key = (config["endpoint"], config["api_key"], api_version)
        client = self._openai_clients.get(cache_key)
        if client is None:
            client = AsyncAzureOpenAI(
                api_key=config["api_key"],
                api_version=api_version,
                azure_endpoint=config["endpoint"],
            )
            self._openai_clients[cache_key] = client
        return client

    def _get_anthropic_client(self, config: dict[str, Any]) -> Any:
        """Get or create an ``anthropic.AsyncAnthropic`` client for the given Azure config.

        Raises:
            RuntimeError: if the ``anthropic`` package is not importable.
        """
        if anthropic is None:
            raise RuntimeError("anthropic package not installed (required for Claude on Azure)")
        cache_key = (config["endpoint"], config["api_key"])
        client = self._anthropic_clients.get(cache_key)
        if client is None:
            client = anthropic.AsyncAnthropic(
                base_url=config["endpoint"],
                api_key=config["api_key"],
            )
            self._anthropic_clients[cache_key] = client
        return client

    # ------------------------------------------------------------------
    # Plain generation
    # ------------------------------------------------------------------

    async def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
        """Generate a completion for a single user ``prompt``."""
        messages = [{"role": "user", "content": prompt}]
        return await self._do_generate(messages, options)

    async def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
        """Generate a completion for a full message list."""
        # NOTE(review): messages are converted with the OpenAI vision helper even
        # when the model routes to the Claude backend; image parts may need the
        # Anthropic format there — confirm against vision_helpers.
        return await self._do_generate(self._prepare_messages(messages), options)

    async def _do_generate(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
        """Resolve config and dispatch to the backend chosen by ``classify_backend``."""
        # Shallow copy so popping ``azure_config`` does not mutate the caller's
        # dict (which may be reused across calls).
        options = dict(options)
        model = options.get("model", self.model)
        config = self._resolve_model_config(model, options)
        if classify_backend(model) == "claude":
            return await self._generate_claude(messages, options, config, model)
        return await self._generate_openai(messages, options, config, model)

    async def _generate_openai(
        self,
        messages: list[dict[str, Any]],
        options: dict[str, Any],
        config: dict[str, Any],
        model: str,
    ) -> dict[str, Any]:
        """Generate via the Azure OpenAI (or Mistral OpenAI-compatible) endpoint.

        With ``json_mode`` set, requests a strict JSON-schema response when
        ``json_schema`` is given, otherwise a generic JSON object.
        """
        client = self._get_openai_client(config)
        deployment_id = config.get("deployment_id") or model

        kwargs = self._build_openai_kwargs(messages, options, model, deployment_id)
        if options.get("json_mode"):
            kwargs["response_format"] = self._openai_response_format(options.get("json_schema"))

        resp = await client.chat.completions.create(**kwargs)

        # ``message.content`` can be None; normalise to "" as the other paths do.
        text = resp.choices[0].message.content or ""
        return {"text": text, "meta": self._openai_meta(resp, model, deployment_id)}

    async def _generate_claude(
        self,
        messages: list[dict[str, Any]],
        options: dict[str, Any],
        config: dict[str, Any],
        model: str,
    ) -> dict[str, Any]:
        """Generate via the Anthropic SDK pointed at the Azure endpoint.

        With ``json_mode`` and a ``json_schema``, structured output is obtained by
        forcing a single ``extract_json`` tool call and serialising its input.
        """
        client = self._get_anthropic_client(config)
        request = self._build_claude_kwargs(messages, options, model)

        json_schema = options.get("json_schema")
        if options.get("json_mode") and json_schema:
            tool_def = {
                "name": "extract_json",
                "description": "Extract structured data matching the schema",
                "input_schema": json_schema,
            }
            resp = await client.messages.create(
                **request,
                tools=[tool_def],
                tool_choice={"type": "tool", "name": "extract_json"},
            )
            text = next(
                (json.dumps(block.input) for block in resp.content if block.type == "tool_use"),
                "",
            )
        else:
            resp = await client.messages.create(**request)
            # Join every text block instead of indexing content[0]: the response
            # may be empty or may not start with a text block.
            text = self._claude_text(resp)

        return {"text": text or "", "meta": self._claude_meta(resp, model)}

    # ------------------------------------------------------------------
    # Tool use
    # ------------------------------------------------------------------

    async def generate_messages_with_tools(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]],
        options: dict[str, Any],
    ) -> dict[str, Any]:
        """Generate a response that may include tool calls.

        Returns a dict with ``text``, ``meta``, ``tool_calls`` (list of
        ``{"id", "name", "arguments"}``) and ``stop_reason``.
        """
        # Shallow copy so popping ``azure_config`` does not mutate the caller's dict.
        options = dict(options)
        model = options.get("model", self.model)
        config = self._resolve_model_config(model, options)
        if classify_backend(model) == "claude":
            return await self._generate_claude_with_tools(messages, tools, options, config, model)
        return await self._generate_openai_with_tools(messages, tools, options, config, model)

    async def _generate_openai_with_tools(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]],
        options: dict[str, Any],
        config: dict[str, Any],
        model: str,
    ) -> dict[str, Any]:
        """Tool calling via the Azure OpenAI endpoint."""
        client = self._get_openai_client(config)
        deployment_id = config.get("deployment_id") or model

        self._validate_model_capabilities("azure", model, using_tool_use=True)

        kwargs = self._build_openai_kwargs(messages, options, model, deployment_id)
        kwargs["tools"] = tools

        resp = await client.chat.completions.create(**kwargs)

        choice = resp.choices[0]
        tool_calls_out: list[dict[str, Any]] = [
            {
                "id": tc.id,
                "name": tc.function.name,
                "arguments": self._parse_tool_arguments(tc.function.arguments),
            }
            for tc in choice.message.tool_calls or []
        ]

        return {
            "text": choice.message.content or "",
            "meta": self._openai_meta(resp, model, deployment_id),
            "tool_calls": tool_calls_out,
            "stop_reason": choice.finish_reason,
        }

    async def _generate_claude_with_tools(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]],
        options: dict[str, Any],
        config: dict[str, Any],
        model: str,
    ) -> dict[str, Any]:
        """Tool calling via the Anthropic SDK pointed at the Azure endpoint."""
        client = self._get_anthropic_client(config)

        kwargs = self._build_claude_kwargs(messages, options, model)
        kwargs["tools"] = self._to_anthropic_tools(tools)

        resp = await client.messages.create(**kwargs)

        text_parts: list[str] = []
        tool_calls_out: list[dict[str, Any]] = []
        for block in resp.content:
            if block.type == "text":
                text_parts.append(block.text)
            elif block.type == "tool_use":
                tool_calls_out.append({"id": block.id, "name": block.name, "arguments": block.input})

        return {
            "text": "".join(text_parts),
            "meta": self._claude_meta(resp, model),
            "tool_calls": tool_calls_out,
            "stop_reason": resp.stop_reason,
        }

    # ------------------------------------------------------------------
    # Request / response helpers
    # ------------------------------------------------------------------

    def _build_openai_kwargs(
        self,
        messages: list[dict[str, Any]],
        options: dict[str, Any],
        model: str,
        deployment_id: str,
    ) -> dict[str, Any]:
        """Build the chat.completions kwargs shared by the OpenAI-compatible paths."""
        model_config = self._get_model_config("azure", model)
        opts = {"temperature": 1.0, "max_tokens": 512, **options}
        kwargs: dict[str, Any] = {"model": deployment_id, "messages": messages}
        # The name of the token-limit parameter comes from the per-model config.
        kwargs[model_config["tokens_param"]] = opts.get("max_tokens", 512)
        if model_config["supports_temperature"] and "temperature" in opts:
            kwargs["temperature"] = opts["temperature"]
        return kwargs

    @staticmethod
    def _openai_response_format(json_schema: dict[str, Any] | None) -> dict[str, Any]:
        """Return a strict JSON-schema response_format, or plain JSON-object mode if no schema."""
        if json_schema:
            return {
                "type": "json_schema",
                "json_schema": {
                    "name": "extraction",
                    "strict": True,
                    "schema": prepare_strict_schema(json_schema),
                },
            }
        return {"type": "json_object"}

    def _openai_meta(self, resp: Any, model: str, deployment_id: str) -> dict[str, Any]:
        """Build usage/cost metadata for an OpenAI-compatible response (missing usage counts as 0)."""
        usage = getattr(resp, "usage", None)
        prompt_tokens = getattr(usage, "prompt_tokens", 0)
        completion_tokens = getattr(usage, "completion_tokens", 0)
        total_tokens = getattr(usage, "total_tokens", 0)
        total_cost = self._calculate_cost("azure", model, prompt_tokens, completion_tokens)
        return {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": total_tokens,
            "cost": round(total_cost, 6),
            "raw_response": resp.model_dump(),
            "model_name": model,
            "deployment_id": deployment_id,
        }

    @staticmethod
    def _parse_tool_arguments(raw: Any) -> dict[str, Any]:
        """Decode a tool call's JSON argument string, returning {} when it is not valid JSON."""
        try:
            return json.loads(raw)
        except (json.JSONDecodeError, TypeError):
            return {}

    @staticmethod
    def _split_system_messages(messages: list[dict[str, Any]]) -> tuple[Any, list[dict[str, Any]]]:
        """Separate system messages (Anthropic takes them as ``system``) from the rest.

        Multiple string system messages are joined with blank lines rather than
        keeping only the last one. If any system content is not a string, the
        last system content is used unchanged.
        """
        system_parts: list[Any] = []
        api_messages: list[dict[str, Any]] = []
        for msg in messages:
            if msg.get("role") == "system":
                system_parts.append(msg.get("content", ""))
            else:
                api_messages.append(msg)

        if not system_parts:
            return None, api_messages
        if all(isinstance(part, str) for part in system_parts):
            return "\n\n".join(part for part in system_parts if part), api_messages
        return system_parts[-1], api_messages

    def _build_claude_kwargs(
        self,
        messages: list[dict[str, Any]],
        options: dict[str, Any],
        model: str,
    ) -> dict[str, Any]:
        """Build the messages.create kwargs shared by the Claude paths."""
        opts = {"temperature": 0.0, "max_tokens": 512, **options}
        system_content, api_messages = self._split_system_messages(messages)
        kwargs: dict[str, Any] = {
            "model": model,
            "messages": api_messages,
            "temperature": opts["temperature"],
            "max_tokens": opts["max_tokens"],
        }
        if system_content:
            kwargs["system"] = system_content
        return kwargs

    @staticmethod
    def _to_anthropic_tools(tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
        """Convert OpenAI-style ``{"type": "function", ...}`` tools to Anthropic's shape.

        Any other tool definition is passed through unchanged.
        """
        converted: list[dict[str, Any]] = []
        for tool in tools:
            if tool.get("type") == "function":
                fn = tool["function"]
                converted.append(
                    {
                        "name": fn["name"],
                        "description": fn.get("description", ""),
                        "input_schema": fn.get("parameters", {"type": "object", "properties": {}}),
                    }
                )
            else:
                converted.append(tool)
        return converted

    @staticmethod
    def _claude_text(resp: Any) -> str:
        """Concatenate the text of every ``text`` block in an Anthropic response."""
        return "".join(block.text for block in resp.content if block.type == "text")

    def _claude_meta(self, resp: Any, model: str) -> dict[str, Any]:
        """Build usage/cost metadata for an Anthropic response."""
        prompt_tokens = resp.usage.input_tokens
        completion_tokens = resp.usage.output_tokens
        total_cost = self._calculate_cost("azure", model, prompt_tokens, completion_tokens)
        return {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": prompt_tokens + completion_tokens,
            "cost": round(total_cost, 6),
            "raw_response": dict(resp),
            "model_name": model,
        }
|
|
@@ -62,7 +62,10 @@ register_async_driver(
|
|
|
62
62
|
register_async_driver(
|
|
63
63
|
"azure",
|
|
64
64
|
lambda model=None: AsyncAzureDriver(
|
|
65
|
-
api_key=settings.azure_api_key,
|
|
65
|
+
api_key=settings.azure_api_key,
|
|
66
|
+
endpoint=settings.azure_api_endpoint,
|
|
67
|
+
deployment_id=settings.azure_deployment_id,
|
|
68
|
+
model=model or "gpt-4o-mini",
|
|
66
69
|
),
|
|
67
70
|
overwrite=True,
|
|
68
71
|
)
|