prompture 0.0.49__tar.gz → 0.0.50__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (144)
  1. {prompture-0.0.49 → prompture-0.0.50}/.env.copy +11 -1
  2. {prompture-0.0.49 → prompture-0.0.50}/PKG-INFO +1 -1
  3. prompture-0.0.50/VERSION +1 -0
  4. {prompture-0.0.49 → prompture-0.0.50}/prompture/__init__.py +9 -0
  5. {prompture-0.0.49 → prompture-0.0.50}/prompture/_version.py +2 -2
  6. {prompture-0.0.49 → prompture-0.0.50}/prompture/discovery.py +10 -3
  7. {prompture-0.0.49 → prompture-0.0.50}/prompture/drivers/__init__.py +15 -1
  8. prompture-0.0.50/prompture/drivers/async_azure_driver.py +418 -0
  9. {prompture-0.0.49 → prompture-0.0.50}/prompture/drivers/async_registry.py +4 -1
  10. prompture-0.0.50/prompture/drivers/azure_config.py +146 -0
  11. prompture-0.0.50/prompture/drivers/azure_driver.py +494 -0
  12. {prompture-0.0.49 → prompture-0.0.50}/prompture/settings.py +11 -1
  13. {prompture-0.0.49 → prompture-0.0.50}/prompture.egg-info/PKG-INFO +1 -1
  14. {prompture-0.0.49 → prompture-0.0.50}/prompture.egg-info/SOURCES.txt +1 -0
  15. prompture-0.0.49/VERSION +0 -1
  16. prompture-0.0.49/prompture/drivers/async_azure_driver.py +0 -201
  17. prompture-0.0.49/prompture/drivers/azure_driver.py +0 -243
  18. {prompture-0.0.49 → prompture-0.0.50}/.claude/skills/add-driver/SKILL.md +0 -0
  19. {prompture-0.0.49 → prompture-0.0.50}/.claude/skills/add-driver/references/driver-template.md +0 -0
  20. {prompture-0.0.49 → prompture-0.0.50}/.claude/skills/add-example/SKILL.md +0 -0
  21. {prompture-0.0.49 → prompture-0.0.50}/.claude/skills/add-field/SKILL.md +0 -0
  22. {prompture-0.0.49 → prompture-0.0.50}/.claude/skills/add-persona/SKILL.md +0 -0
  23. {prompture-0.0.49 → prompture-0.0.50}/.claude/skills/add-test/SKILL.md +0 -0
  24. {prompture-0.0.49 → prompture-0.0.50}/.claude/skills/add-tool/SKILL.md +0 -0
  25. {prompture-0.0.49 → prompture-0.0.50}/.claude/skills/run-tests/SKILL.md +0 -0
  26. {prompture-0.0.49 → prompture-0.0.50}/.claude/skills/scaffold-extraction/SKILL.md +0 -0
  27. {prompture-0.0.49 → prompture-0.0.50}/.claude/skills/update-pricing/SKILL.md +0 -0
  28. {prompture-0.0.49 → prompture-0.0.50}/.github/FUNDING.yml +0 -0
  29. {prompture-0.0.49 → prompture-0.0.50}/.github/scripts/update_docs_version.py +0 -0
  30. {prompture-0.0.49 → prompture-0.0.50}/.github/scripts/update_wrapper_version.py +0 -0
  31. {prompture-0.0.49 → prompture-0.0.50}/.github/workflows/dev.yml +0 -0
  32. {prompture-0.0.49 → prompture-0.0.50}/.github/workflows/documentation.yml +0 -0
  33. {prompture-0.0.49 → prompture-0.0.50}/.github/workflows/publish.yml +0 -0
  34. {prompture-0.0.49 → prompture-0.0.50}/CLAUDE.md +0 -0
  35. {prompture-0.0.49 → prompture-0.0.50}/LICENSE +0 -0
  36. {prompture-0.0.49 → prompture-0.0.50}/MANIFEST.in +0 -0
  37. {prompture-0.0.49 → prompture-0.0.50}/README.md +0 -0
  38. {prompture-0.0.49 → prompture-0.0.50}/ROADMAP.md +0 -0
  39. {prompture-0.0.49 → prompture-0.0.50}/docs/source/_static/custom.css +0 -0
  40. {prompture-0.0.49 → prompture-0.0.50}/docs/source/_templates/footer.html +0 -0
  41. {prompture-0.0.49 → prompture-0.0.50}/docs/source/api/core.rst +0 -0
  42. {prompture-0.0.49 → prompture-0.0.50}/docs/source/api/drivers.rst +0 -0
  43. {prompture-0.0.49 → prompture-0.0.50}/docs/source/api/field_definitions.rst +0 -0
  44. {prompture-0.0.49 → prompture-0.0.50}/docs/source/api/index.rst +0 -0
  45. {prompture-0.0.49 → prompture-0.0.50}/docs/source/api/runner.rst +0 -0
  46. {prompture-0.0.49 → prompture-0.0.50}/docs/source/api/tools.rst +0 -0
  47. {prompture-0.0.49 → prompture-0.0.50}/docs/source/api/validator.rst +0 -0
  48. {prompture-0.0.49 → prompture-0.0.50}/docs/source/conf.py +0 -0
  49. {prompture-0.0.49 → prompture-0.0.50}/docs/source/contributing.rst +0 -0
  50. {prompture-0.0.49 → prompture-0.0.50}/docs/source/examples.rst +0 -0
  51. {prompture-0.0.49 → prompture-0.0.50}/docs/source/field_definitions_reference.rst +0 -0
  52. {prompture-0.0.49 → prompture-0.0.50}/docs/source/index.rst +0 -0
  53. {prompture-0.0.49 → prompture-0.0.50}/docs/source/installation.rst +0 -0
  54. {prompture-0.0.49 → prompture-0.0.50}/docs/source/quickstart.rst +0 -0
  55. {prompture-0.0.49 → prompture-0.0.50}/docs/source/toon_input_guide.rst +0 -0
  56. {prompture-0.0.49 → prompture-0.0.50}/packages/README.md +0 -0
  57. {prompture-0.0.49 → prompture-0.0.50}/packages/llm_to_json/README.md +0 -0
  58. {prompture-0.0.49 → prompture-0.0.50}/packages/llm_to_json/llm_to_json/__init__.py +0 -0
  59. {prompture-0.0.49 → prompture-0.0.50}/packages/llm_to_json/pyproject.toml +0 -0
  60. {prompture-0.0.49 → prompture-0.0.50}/packages/llm_to_json/test.py +0 -0
  61. {prompture-0.0.49 → prompture-0.0.50}/packages/llm_to_toon/README.md +0 -0
  62. {prompture-0.0.49 → prompture-0.0.50}/packages/llm_to_toon/llm_to_toon/__init__.py +0 -0
  63. {prompture-0.0.49 → prompture-0.0.50}/packages/llm_to_toon/pyproject.toml +0 -0
  64. {prompture-0.0.49 → prompture-0.0.50}/packages/llm_to_toon/test.py +0 -0
  65. {prompture-0.0.49 → prompture-0.0.50}/prompture/agent.py +0 -0
  66. {prompture-0.0.49 → prompture-0.0.50}/prompture/agent_types.py +0 -0
  67. {prompture-0.0.49 → prompture-0.0.50}/prompture/aio/__init__.py +0 -0
  68. {prompture-0.0.49 → prompture-0.0.50}/prompture/async_agent.py +0 -0
  69. {prompture-0.0.49 → prompture-0.0.50}/prompture/async_conversation.py +0 -0
  70. {prompture-0.0.49 → prompture-0.0.50}/prompture/async_core.py +0 -0
  71. {prompture-0.0.49 → prompture-0.0.50}/prompture/async_driver.py +0 -0
  72. {prompture-0.0.49 → prompture-0.0.50}/prompture/async_groups.py +0 -0
  73. {prompture-0.0.49 → prompture-0.0.50}/prompture/cache.py +0 -0
  74. {prompture-0.0.49 → prompture-0.0.50}/prompture/callbacks.py +0 -0
  75. {prompture-0.0.49 → prompture-0.0.50}/prompture/cli.py +0 -0
  76. {prompture-0.0.49 → prompture-0.0.50}/prompture/conversation.py +0 -0
  77. {prompture-0.0.49 → prompture-0.0.50}/prompture/core.py +0 -0
  78. {prompture-0.0.49 → prompture-0.0.50}/prompture/cost_mixin.py +0 -0
  79. {prompture-0.0.49 → prompture-0.0.50}/prompture/driver.py +0 -0
  80. {prompture-0.0.49 → prompture-0.0.50}/prompture/drivers/airllm_driver.py +0 -0
  81. {prompture-0.0.49 → prompture-0.0.50}/prompture/drivers/async_airllm_driver.py +0 -0
  82. {prompture-0.0.49 → prompture-0.0.50}/prompture/drivers/async_claude_driver.py +0 -0
  83. {prompture-0.0.49 → prompture-0.0.50}/prompture/drivers/async_google_driver.py +0 -0
  84. {prompture-0.0.49 → prompture-0.0.50}/prompture/drivers/async_grok_driver.py +0 -0
  85. {prompture-0.0.49 → prompture-0.0.50}/prompture/drivers/async_groq_driver.py +0 -0
  86. {prompture-0.0.49 → prompture-0.0.50}/prompture/drivers/async_hugging_driver.py +0 -0
  87. {prompture-0.0.49 → prompture-0.0.50}/prompture/drivers/async_lmstudio_driver.py +0 -0
  88. {prompture-0.0.49 → prompture-0.0.50}/prompture/drivers/async_local_http_driver.py +0 -0
  89. {prompture-0.0.49 → prompture-0.0.50}/prompture/drivers/async_modelscope_driver.py +0 -0
  90. {prompture-0.0.49 → prompture-0.0.50}/prompture/drivers/async_moonshot_driver.py +0 -0
  91. {prompture-0.0.49 → prompture-0.0.50}/prompture/drivers/async_ollama_driver.py +0 -0
  92. {prompture-0.0.49 → prompture-0.0.50}/prompture/drivers/async_openai_driver.py +0 -0
  93. {prompture-0.0.49 → prompture-0.0.50}/prompture/drivers/async_openrouter_driver.py +0 -0
  94. {prompture-0.0.49 → prompture-0.0.50}/prompture/drivers/async_zai_driver.py +0 -0
  95. {prompture-0.0.49 → prompture-0.0.50}/prompture/drivers/claude_driver.py +0 -0
  96. {prompture-0.0.49 → prompture-0.0.50}/prompture/drivers/google_driver.py +0 -0
  97. {prompture-0.0.49 → prompture-0.0.50}/prompture/drivers/grok_driver.py +0 -0
  98. {prompture-0.0.49 → prompture-0.0.50}/prompture/drivers/groq_driver.py +0 -0
  99. {prompture-0.0.49 → prompture-0.0.50}/prompture/drivers/hugging_driver.py +0 -0
  100. {prompture-0.0.49 → prompture-0.0.50}/prompture/drivers/lmstudio_driver.py +0 -0
  101. {prompture-0.0.49 → prompture-0.0.50}/prompture/drivers/local_http_driver.py +0 -0
  102. {prompture-0.0.49 → prompture-0.0.50}/prompture/drivers/modelscope_driver.py +0 -0
  103. {prompture-0.0.49 → prompture-0.0.50}/prompture/drivers/moonshot_driver.py +0 -0
  104. {prompture-0.0.49 → prompture-0.0.50}/prompture/drivers/ollama_driver.py +0 -0
  105. {prompture-0.0.49 → prompture-0.0.50}/prompture/drivers/openai_driver.py +0 -0
  106. {prompture-0.0.49 → prompture-0.0.50}/prompture/drivers/openrouter_driver.py +0 -0
  107. {prompture-0.0.49 → prompture-0.0.50}/prompture/drivers/registry.py +0 -0
  108. {prompture-0.0.49 → prompture-0.0.50}/prompture/drivers/vision_helpers.py +0 -0
  109. {prompture-0.0.49 → prompture-0.0.50}/prompture/drivers/zai_driver.py +0 -0
  110. {prompture-0.0.49 → prompture-0.0.50}/prompture/field_definitions.py +0 -0
  111. {prompture-0.0.49 → prompture-0.0.50}/prompture/group_types.py +0 -0
  112. {prompture-0.0.49 → prompture-0.0.50}/prompture/groups.py +0 -0
  113. {prompture-0.0.49 → prompture-0.0.50}/prompture/image.py +0 -0
  114. {prompture-0.0.49 → prompture-0.0.50}/prompture/ledger.py +0 -0
  115. {prompture-0.0.49 → prompture-0.0.50}/prompture/logging.py +0 -0
  116. {prompture-0.0.49 → prompture-0.0.50}/prompture/model_rates.py +0 -0
  117. {prompture-0.0.49 → prompture-0.0.50}/prompture/persistence.py +0 -0
  118. {prompture-0.0.49 → prompture-0.0.50}/prompture/persona.py +0 -0
  119. {prompture-0.0.49 → prompture-0.0.50}/prompture/runner.py +0 -0
  120. {prompture-0.0.49 → prompture-0.0.50}/prompture/scaffold/__init__.py +0 -0
  121. {prompture-0.0.49 → prompture-0.0.50}/prompture/scaffold/generator.py +0 -0
  122. {prompture-0.0.49 → prompture-0.0.50}/prompture/scaffold/templates/Dockerfile.j2 +0 -0
  123. {prompture-0.0.49 → prompture-0.0.50}/prompture/scaffold/templates/README.md.j2 +0 -0
  124. {prompture-0.0.49 → prompture-0.0.50}/prompture/scaffold/templates/config.py.j2 +0 -0
  125. {prompture-0.0.49 → prompture-0.0.50}/prompture/scaffold/templates/env.example.j2 +0 -0
  126. {prompture-0.0.49 → prompture-0.0.50}/prompture/scaffold/templates/main.py.j2 +0 -0
  127. {prompture-0.0.49 → prompture-0.0.50}/prompture/scaffold/templates/models.py.j2 +0 -0
  128. {prompture-0.0.49 → prompture-0.0.50}/prompture/scaffold/templates/requirements.txt.j2 +0 -0
  129. {prompture-0.0.49 → prompture-0.0.50}/prompture/serialization.py +0 -0
  130. {prompture-0.0.49 → prompture-0.0.50}/prompture/server.py +0 -0
  131. {prompture-0.0.49 → prompture-0.0.50}/prompture/session.py +0 -0
  132. {prompture-0.0.49 → prompture-0.0.50}/prompture/simulated_tools.py +0 -0
  133. {prompture-0.0.49 → prompture-0.0.50}/prompture/tools.py +0 -0
  134. {prompture-0.0.49 → prompture-0.0.50}/prompture/tools_schema.py +0 -0
  135. {prompture-0.0.49 → prompture-0.0.50}/prompture/validator.py +0 -0
  136. {prompture-0.0.49 → prompture-0.0.50}/prompture.egg-info/dependency_links.txt +0 -0
  137. {prompture-0.0.49 → prompture-0.0.50}/prompture.egg-info/entry_points.txt +0 -0
  138. {prompture-0.0.49 → prompture-0.0.50}/prompture.egg-info/requires.txt +0 -0
  139. {prompture-0.0.49 → prompture-0.0.50}/prompture.egg-info/top_level.txt +0 -0
  140. {prompture-0.0.49 → prompture-0.0.50}/pyproject.toml +0 -0
  141. {prompture-0.0.49 → prompture-0.0.50}/requirements.txt +0 -0
  142. {prompture-0.0.49 → prompture-0.0.50}/setup.cfg +0 -0
  143. {prompture-0.0.49 → prompture-0.0.50}/test.py +0 -0
  144. {prompture-0.0.49 → prompture-0.0.50}/test_version_diagnosis.py +0 -0
@@ -25,12 +25,22 @@ LMSTUDIO_ENDPOINT=http://127.0.0.1:1234/v1/chat/completions
25
25
  LMSTUDIO_MODEL=deepseek/deepseek-r1-0528-qwen3-8b
26
26
  LMSTUDIO_API_KEY=
27
27
 
28
- # Azure OpenAI Configuration
28
+ # Azure OpenAI Configuration (default backend)
29
29
  AZURE_API_KEY=
30
30
  AZURE_API_ENDPOINT=
31
31
  AZURE_DEPLOYMENT_ID=
32
32
  AZURE_API_VERSION=
33
33
 
34
+ # Azure Claude Backend (optional, for claude-* models on Azure)
35
+ AZURE_CLAUDE_API_KEY=
36
+ AZURE_CLAUDE_ENDPOINT=
37
+ AZURE_CLAUDE_API_VERSION=
38
+
39
+ # Azure Mistral Backend (optional, for mistral-*/mixtral-* models on Azure)
40
+ AZURE_MISTRAL_API_KEY=
41
+ AZURE_MISTRAL_ENDPOINT=
42
+ AZURE_MISTRAL_API_VERSION=
43
+
34
44
  # Additional Providers (not required for tests)
35
45
  # HuggingFace Configuration
36
46
  HF_ENDPOINT=
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: prompture
3
- Version: 0.0.49
3
+ Version: 0.0.50
4
4
  Summary: Ask LLMs to return structured JSON and run cross-model tests. API-first.
5
5
  Author-email: Juan Denis <juan@vene.co>
6
6
  License-Expression: MIT
@@ -0,0 +1 @@
1
+ 0.0.50
@@ -60,6 +60,8 @@ from .drivers import (
60
60
  OllamaDriver,
61
61
  OpenAIDriver,
62
62
  OpenRouterDriver,
63
+ # Azure config API
64
+ clear_azure_configs,
63
65
  get_driver,
64
66
  get_driver_for_model,
65
67
  # Plugin registration API
@@ -69,8 +71,11 @@ from .drivers import (
69
71
  list_registered_drivers,
70
72
  load_entry_point_drivers,
71
73
  register_async_driver,
74
+ register_azure_config,
72
75
  register_driver,
76
+ set_azure_config_resolver,
73
77
  unregister_async_driver,
78
+ unregister_azure_config,
74
79
  unregister_driver,
75
80
  )
76
81
  from .field_definitions import (
@@ -247,6 +252,7 @@ __all__ = [
247
252
  "clean_json_text",
248
253
  "clean_json_text_with_ai",
249
254
  "clean_toon_text",
255
+ "clear_azure_configs",
250
256
  "clear_persona_registry",
251
257
  "clear_registry",
252
258
  "configure_cache",
@@ -292,6 +298,7 @@ __all__ = [
292
298
  "normalize_enum_value",
293
299
  "refresh_rates_cache",
294
300
  "register_async_driver",
301
+ "register_azure_config",
295
302
  "register_driver",
296
303
  "register_field",
297
304
  "register_persona",
@@ -301,9 +308,11 @@ __all__ = [
301
308
  "reset_registry",
302
309
  "reset_trait_registry",
303
310
  "run_suite_from_spec",
311
+ "set_azure_config_resolver",
304
312
  "stepwise_extract_with_model",
305
313
  "tool_from_function",
306
314
  "unregister_async_driver",
315
+ "unregister_azure_config",
307
316
  "unregister_driver",
308
317
  "validate_against_schema",
309
318
  "validate_enum_value",
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
28
28
  commit_id: COMMIT_ID
29
29
  __commit_id__: COMMIT_ID
30
30
 
31
- __version__ = version = '0.0.49'
32
- __version_tuple__ = version_tuple = (0, 0, 49)
31
+ __version__ = version = '0.0.50'
32
+ __version_tuple__ = version_tuple = (0, 0, 50)
33
33
 
34
34
  __commit_id__ = commit_id = None
@@ -89,10 +89,17 @@ def get_available_models(
89
89
  if settings.openai_api_key or os.getenv("OPENAI_API_KEY"):
90
90
  is_configured = True
91
91
  elif provider == "azure":
92
+ from .drivers.azure_config import has_azure_config_resolver, has_registered_configs
93
+
92
94
  if (
93
- (settings.azure_api_key or os.getenv("AZURE_API_KEY"))
94
- and (settings.azure_api_endpoint or os.getenv("AZURE_API_ENDPOINT"))
95
- and (settings.azure_deployment_id or os.getenv("AZURE_DEPLOYMENT_ID"))
95
+ (
96
+ (settings.azure_api_key or os.getenv("AZURE_API_KEY"))
97
+ and (settings.azure_api_endpoint or os.getenv("AZURE_API_ENDPOINT"))
98
+ )
99
+ or (settings.azure_claude_api_key or os.getenv("AZURE_CLAUDE_API_KEY"))
100
+ or (settings.azure_mistral_api_key or os.getenv("AZURE_MISTRAL_API_KEY"))
101
+ or has_registered_configs()
102
+ or has_azure_config_resolver()
96
103
  ):
97
104
  is_configured = True
98
105
  elif provider == "claude":
@@ -44,6 +44,12 @@ from .async_openai_driver import AsyncOpenAIDriver
44
44
  from .async_openrouter_driver import AsyncOpenRouterDriver
45
45
  from .async_registry import ASYNC_DRIVER_REGISTRY, get_async_driver, get_async_driver_for_model
46
46
  from .async_zai_driver import AsyncZaiDriver
47
+ from .azure_config import (
48
+ clear_azure_configs,
49
+ register_azure_config,
50
+ set_azure_config_resolver,
51
+ unregister_azure_config,
52
+ )
47
53
  from .azure_driver import AzureDriver
48
54
  from .claude_driver import ClaudeDriver
49
55
  from .google_driver import GoogleDriver
@@ -100,7 +106,10 @@ register_driver(
100
106
  register_driver(
101
107
  "azure",
102
108
  lambda model=None: AzureDriver(
103
- api_key=settings.azure_api_key, endpoint=settings.azure_api_endpoint, deployment_id=settings.azure_deployment_id
109
+ api_key=settings.azure_api_key,
110
+ endpoint=settings.azure_api_endpoint,
111
+ deployment_id=settings.azure_deployment_id,
112
+ model=model or "gpt-4o-mini",
104
113
  ),
105
114
  overwrite=True,
106
115
  )
@@ -249,6 +258,8 @@ __all__ = [
249
258
  "OpenAIDriver",
250
259
  "OpenRouterDriver",
251
260
  "ZaiDriver",
261
+ # Azure config API
262
+ "clear_azure_configs",
252
263
  "get_async_driver",
253
264
  "get_async_driver_for_model",
254
265
  # Factory functions
@@ -260,8 +271,11 @@ __all__ = [
260
271
  "list_registered_drivers",
261
272
  "load_entry_point_drivers",
262
273
  "register_async_driver",
274
+ "register_azure_config",
263
275
  # Registry functions (public API)
264
276
  "register_driver",
277
+ "set_azure_config_resolver",
265
278
  "unregister_async_driver",
279
+ "unregister_azure_config",
266
280
  "unregister_driver",
267
281
  ]
@@ -0,0 +1,418 @@
1
+ """Async Azure driver with multi-endpoint and multi-backend support.
2
+
3
+ Requires the ``openai`` package (>=1.0.0). Claude backend also requires ``anthropic``.
4
+ """
5
+
6
+ from __future__ import annotations
7
+
8
+ import json
9
+ import os
10
+ from typing import Any
11
+
12
+ try:
13
+ from openai import AsyncAzureOpenAI
14
+ except Exception:
15
+ AsyncAzureOpenAI = None
16
+
17
+ try:
18
+ import anthropic
19
+ except Exception:
20
+ anthropic = None
21
+
22
+ from ..async_driver import AsyncDriver
23
+ from ..cost_mixin import CostMixin, prepare_strict_schema
24
+ from .azure_config import classify_backend, resolve_config
25
+ from .azure_driver import AzureDriver
26
+
27
+
28
+ class AsyncAzureDriver(CostMixin, AsyncDriver):
29
+ supports_json_mode = True
30
+ supports_json_schema = True
31
+ supports_tool_use = True
32
+ supports_vision = True
33
+
34
+ MODEL_PRICING = AzureDriver.MODEL_PRICING
35
+
36
+ def __init__(
37
+ self,
38
+ api_key: str | None = None,
39
+ endpoint: str | None = None,
40
+ deployment_id: str | None = None,
41
+ model: str = "gpt-4o-mini",
42
+ ):
43
+ self.model = model
44
+ self._default_config = {
45
+ "api_key": api_key or os.getenv("AZURE_API_KEY"),
46
+ "endpoint": endpoint or os.getenv("AZURE_API_ENDPOINT"),
47
+ "deployment_id": deployment_id or os.getenv("AZURE_DEPLOYMENT_ID"),
48
+ "api_version": os.getenv("AZURE_API_VERSION", "2024-02-15-preview"),
49
+ }
50
+ self._openai_clients: dict[tuple[str, str], AsyncAzureOpenAI] = {}
51
+ self._anthropic_clients: dict[tuple[str, str], Any] = {}
52
+
53
+ supports_messages = True
54
+
55
+ def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
56
+ from .vision_helpers import _prepare_openai_vision_messages
57
+
58
+ return _prepare_openai_vision_messages(messages)
59
+
60
+ def _resolve_model_config(self, model: str, options: dict[str, Any]) -> dict[str, Any]:
61
+ """Resolve Azure config for this model using the priority chain."""
62
+ override = options.pop("azure_config", None)
63
+ return resolve_config(model, override=override, default_config=self._default_config)
64
+
65
+ def _get_openai_client(self, config: dict[str, Any]) -> AsyncAzureOpenAI:
66
+ """Get or create an AsyncAzureOpenAI client for the given config."""
67
+ if AsyncAzureOpenAI is None:
68
+ raise RuntimeError("openai package (>=1.0.0) with AsyncAzureOpenAI not installed")
69
+ cache_key = (config["endpoint"], config["api_key"])
70
+ if cache_key not in self._openai_clients:
71
+ self._openai_clients[cache_key] = AsyncAzureOpenAI(
72
+ api_key=config["api_key"],
73
+ api_version=config.get("api_version", "2024-02-15-preview"),
74
+ azure_endpoint=config["endpoint"],
75
+ )
76
+ return self._openai_clients[cache_key]
77
+
78
+ def _get_anthropic_client(self, config: dict[str, Any]) -> Any:
79
+ """Get or create an AsyncAnthropic client for the given Azure config."""
80
+ if anthropic is None:
81
+ raise RuntimeError("anthropic package not installed (required for Claude on Azure)")
82
+ cache_key = (config["endpoint"], config["api_key"])
83
+ if cache_key not in self._anthropic_clients:
84
+ self._anthropic_clients[cache_key] = anthropic.AsyncAnthropic(
85
+ base_url=config["endpoint"],
86
+ api_key=config["api_key"],
87
+ )
88
+ return self._anthropic_clients[cache_key]
89
+
90
+ async def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
91
+ messages = [{"role": "user", "content": prompt}]
92
+ return await self._do_generate(messages, options)
93
+
94
+ async def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
95
+ return await self._do_generate(self._prepare_messages(messages), options)
96
+
97
+ async def _do_generate(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
98
+ model = options.get("model", self.model)
99
+ config = self._resolve_model_config(model, options)
100
+ backend = classify_backend(model)
101
+
102
+ if backend == "claude":
103
+ return await self._generate_claude(messages, options, config, model)
104
+ else:
105
+ return await self._generate_openai(messages, options, config, model)
106
+
107
+ async def _generate_openai(
108
+ self,
109
+ messages: list[dict[str, Any]],
110
+ options: dict[str, Any],
111
+ config: dict[str, Any],
112
+ model: str,
113
+ ) -> dict[str, Any]:
114
+ """Generate via Azure OpenAI (or Mistral OpenAI-compat) endpoint."""
115
+ client = self._get_openai_client(config)
116
+ deployment_id = config.get("deployment_id") or model
117
+
118
+ model_config = self._get_model_config("azure", model)
119
+ tokens_param = model_config["tokens_param"]
120
+ supports_temperature = model_config["supports_temperature"]
121
+
122
+ opts = {"temperature": 1.0, "max_tokens": 512, **options}
123
+
124
+ kwargs = {
125
+ "model": deployment_id,
126
+ "messages": messages,
127
+ }
128
+ kwargs[tokens_param] = opts.get("max_tokens", 512)
129
+
130
+ if supports_temperature and "temperature" in opts:
131
+ kwargs["temperature"] = opts["temperature"]
132
+
133
+ if options.get("json_mode"):
134
+ json_schema = options.get("json_schema")
135
+ if json_schema:
136
+ schema_copy = prepare_strict_schema(json_schema)
137
+ kwargs["response_format"] = {
138
+ "type": "json_schema",
139
+ "json_schema": {
140
+ "name": "extraction",
141
+ "strict": True,
142
+ "schema": schema_copy,
143
+ },
144
+ }
145
+ else:
146
+ kwargs["response_format"] = {"type": "json_object"}
147
+
148
+ resp = await client.chat.completions.create(**kwargs)
149
+
150
+ usage = getattr(resp, "usage", None)
151
+ prompt_tokens = getattr(usage, "prompt_tokens", 0)
152
+ completion_tokens = getattr(usage, "completion_tokens", 0)
153
+ total_tokens = getattr(usage, "total_tokens", 0)
154
+
155
+ total_cost = self._calculate_cost("azure", model, prompt_tokens, completion_tokens)
156
+
157
+ meta = {
158
+ "prompt_tokens": prompt_tokens,
159
+ "completion_tokens": completion_tokens,
160
+ "total_tokens": total_tokens,
161
+ "cost": round(total_cost, 6),
162
+ "raw_response": resp.model_dump(),
163
+ "model_name": model,
164
+ "deployment_id": deployment_id,
165
+ }
166
+
167
+ text = resp.choices[0].message.content
168
+ return {"text": text, "meta": meta}
169
+
170
+ async def _generate_claude(
171
+ self,
172
+ messages: list[dict[str, Any]],
173
+ options: dict[str, Any],
174
+ config: dict[str, Any],
175
+ model: str,
176
+ ) -> dict[str, Any]:
177
+ """Generate via Anthropic SDK with Azure endpoint."""
178
+ client = self._get_anthropic_client(config)
179
+
180
+ opts = {**{"temperature": 0.0, "max_tokens": 512}, **options}
181
+
182
+ system_content = None
183
+ api_messages = []
184
+ for msg in messages:
185
+ if msg.get("role") == "system":
186
+ system_content = msg.get("content", "")
187
+ else:
188
+ api_messages.append(msg)
189
+
190
+ common_kwargs: dict[str, Any] = {
191
+ "model": model,
192
+ "messages": api_messages,
193
+ "temperature": opts["temperature"],
194
+ "max_tokens": opts["max_tokens"],
195
+ }
196
+ if system_content:
197
+ common_kwargs["system"] = system_content
198
+
199
+ if options.get("json_mode"):
200
+ json_schema = options.get("json_schema")
201
+ if json_schema:
202
+ tool_def = {
203
+ "name": "extract_json",
204
+ "description": "Extract structured data matching the schema",
205
+ "input_schema": json_schema,
206
+ }
207
+ resp = await client.messages.create(
208
+ **common_kwargs,
209
+ tools=[tool_def],
210
+ tool_choice={"type": "tool", "name": "extract_json"},
211
+ )
212
+ text = ""
213
+ for block in resp.content:
214
+ if block.type == "tool_use":
215
+ text = json.dumps(block.input)
216
+ break
217
+ else:
218
+ resp = await client.messages.create(**common_kwargs)
219
+ text = resp.content[0].text
220
+ else:
221
+ resp = await client.messages.create(**common_kwargs)
222
+ text = resp.content[0].text
223
+
224
+ prompt_tokens = resp.usage.input_tokens
225
+ completion_tokens = resp.usage.output_tokens
226
+ total_tokens = prompt_tokens + completion_tokens
227
+
228
+ total_cost = self._calculate_cost("azure", model, prompt_tokens, completion_tokens)
229
+
230
+ meta = {
231
+ "prompt_tokens": prompt_tokens,
232
+ "completion_tokens": completion_tokens,
233
+ "total_tokens": total_tokens,
234
+ "cost": round(total_cost, 6),
235
+ "raw_response": dict(resp),
236
+ "model_name": model,
237
+ }
238
+
239
+ text_result = text or ""
240
+ return {"text": text_result, "meta": meta}
241
+
242
+ # ------------------------------------------------------------------
243
+ # Tool use
244
+ # ------------------------------------------------------------------
245
+
246
+ async def generate_messages_with_tools(
247
+ self,
248
+ messages: list[dict[str, Any]],
249
+ tools: list[dict[str, Any]],
250
+ options: dict[str, Any],
251
+ ) -> dict[str, Any]:
252
+ """Generate a response that may include tool calls."""
253
+ model = options.get("model", self.model)
254
+ config = self._resolve_model_config(model, options)
255
+ backend = classify_backend(model)
256
+
257
+ if backend == "claude":
258
+ return await self._generate_claude_with_tools(messages, tools, options, config, model)
259
+ else:
260
+ return await self._generate_openai_with_tools(messages, tools, options, config, model)
261
+
262
+ async def _generate_openai_with_tools(
263
+ self,
264
+ messages: list[dict[str, Any]],
265
+ tools: list[dict[str, Any]],
266
+ options: dict[str, Any],
267
+ config: dict[str, Any],
268
+ model: str,
269
+ ) -> dict[str, Any]:
270
+ """Tool calling via Azure OpenAI endpoint."""
271
+ client = self._get_openai_client(config)
272
+ deployment_id = config.get("deployment_id") or model
273
+
274
+ model_config = self._get_model_config("azure", model)
275
+ tokens_param = model_config["tokens_param"]
276
+ supports_temperature = model_config["supports_temperature"]
277
+
278
+ self._validate_model_capabilities("azure", model, using_tool_use=True)
279
+
280
+ opts = {"temperature": 1.0, "max_tokens": 512, **options}
281
+
282
+ kwargs: dict[str, Any] = {
283
+ "model": deployment_id,
284
+ "messages": messages,
285
+ "tools": tools,
286
+ }
287
+ kwargs[tokens_param] = opts.get("max_tokens", 512)
288
+
289
+ if supports_temperature and "temperature" in opts:
290
+ kwargs["temperature"] = opts["temperature"]
291
+
292
+ resp = await client.chat.completions.create(**kwargs)
293
+
294
+ usage = getattr(resp, "usage", None)
295
+ prompt_tokens = getattr(usage, "prompt_tokens", 0)
296
+ completion_tokens = getattr(usage, "completion_tokens", 0)
297
+ total_tokens = getattr(usage, "total_tokens", 0)
298
+ total_cost = self._calculate_cost("azure", model, prompt_tokens, completion_tokens)
299
+
300
+ meta = {
301
+ "prompt_tokens": prompt_tokens,
302
+ "completion_tokens": completion_tokens,
303
+ "total_tokens": total_tokens,
304
+ "cost": round(total_cost, 6),
305
+ "raw_response": resp.model_dump(),
306
+ "model_name": model,
307
+ "deployment_id": deployment_id,
308
+ }
309
+
310
+ choice = resp.choices[0]
311
+ text = choice.message.content or ""
312
+ stop_reason = choice.finish_reason
313
+
314
+ tool_calls_out: list[dict[str, Any]] = []
315
+ if choice.message.tool_calls:
316
+ for tc in choice.message.tool_calls:
317
+ try:
318
+ args = json.loads(tc.function.arguments)
319
+ except (json.JSONDecodeError, TypeError):
320
+ args = {}
321
+ tool_calls_out.append(
322
+ {
323
+ "id": tc.id,
324
+ "name": tc.function.name,
325
+ "arguments": args,
326
+ }
327
+ )
328
+
329
+ return {
330
+ "text": text,
331
+ "meta": meta,
332
+ "tool_calls": tool_calls_out,
333
+ "stop_reason": stop_reason,
334
+ }
335
+
336
+ async def _generate_claude_with_tools(
337
+ self,
338
+ messages: list[dict[str, Any]],
339
+ tools: list[dict[str, Any]],
340
+ options: dict[str, Any],
341
+ config: dict[str, Any],
342
+ model: str,
343
+ ) -> dict[str, Any]:
344
+ """Tool calling via Anthropic SDK with Azure endpoint."""
345
+ client = self._get_anthropic_client(config)
346
+
347
+ opts = {**{"temperature": 0.0, "max_tokens": 512}, **options}
348
+
349
+ system_content = None
350
+ api_messages: list[dict[str, Any]] = []
351
+ for msg in messages:
352
+ if msg.get("role") == "system":
353
+ system_content = msg.get("content", "")
354
+ else:
355
+ api_messages.append(msg)
356
+
357
+ anthropic_tools = []
358
+ for t in tools:
359
+ if "type" in t and t["type"] == "function":
360
+ fn = t["function"]
361
+ anthropic_tools.append(
362
+ {
363
+ "name": fn["name"],
364
+ "description": fn.get("description", ""),
365
+ "input_schema": fn.get("parameters", {"type": "object", "properties": {}}),
366
+ }
367
+ )
368
+ elif "input_schema" in t:
369
+ anthropic_tools.append(t)
370
+ else:
371
+ anthropic_tools.append(t)
372
+
373
+ kwargs: dict[str, Any] = {
374
+ "model": model,
375
+ "messages": api_messages,
376
+ "temperature": opts["temperature"],
377
+ "max_tokens": opts["max_tokens"],
378
+ "tools": anthropic_tools,
379
+ }
380
+ if system_content:
381
+ kwargs["system"] = system_content
382
+
383
+ resp = await client.messages.create(**kwargs)
384
+
385
+ prompt_tokens = resp.usage.input_tokens
386
+ completion_tokens = resp.usage.output_tokens
387
+ total_tokens = prompt_tokens + completion_tokens
388
+ total_cost = self._calculate_cost("azure", model, prompt_tokens, completion_tokens)
389
+
390
+ meta = {
391
+ "prompt_tokens": prompt_tokens,
392
+ "completion_tokens": completion_tokens,
393
+ "total_tokens": total_tokens,
394
+ "cost": round(total_cost, 6),
395
+ "raw_response": dict(resp),
396
+ "model_name": model,
397
+ }
398
+
399
+ text = ""
400
+ tool_calls_out: list[dict[str, Any]] = []
401
+ for block in resp.content:
402
+ if block.type == "text":
403
+ text += block.text
404
+ elif block.type == "tool_use":
405
+ tool_calls_out.append(
406
+ {
407
+ "id": block.id,
408
+ "name": block.name,
409
+ "arguments": block.input,
410
+ }
411
+ )
412
+
413
+ return {
414
+ "text": text,
415
+ "meta": meta,
416
+ "tool_calls": tool_calls_out,
417
+ "stop_reason": resp.stop_reason,
418
+ }
@@ -62,7 +62,10 @@ register_async_driver(
62
62
  register_async_driver(
63
63
  "azure",
64
64
  lambda model=None: AsyncAzureDriver(
65
- api_key=settings.azure_api_key, endpoint=settings.azure_api_endpoint, deployment_id=settings.azure_deployment_id
65
+ api_key=settings.azure_api_key,
66
+ endpoint=settings.azure_api_endpoint,
67
+ deployment_id=settings.azure_deployment_id,
68
+ model=model or "gpt-4o-mini",
66
69
  ),
67
70
  overwrite=True,
68
71
  )