prompture 0.0.38.dev2__tar.gz → 0.0.39__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (132)
  1. {prompture-0.0.38.dev2/prompture.egg-info → prompture-0.0.39}/PKG-INFO +1 -1
  2. prompture-0.0.39/VERSION +1 -0
  3. {prompture-0.0.38.dev2 → prompture-0.0.39}/prompture/__init__.py +9 -1
  4. {prompture-0.0.38.dev2 → prompture-0.0.39}/prompture/_version.py +2 -2
  5. {prompture-0.0.38.dev2 → prompture-0.0.39}/prompture/async_driver.py +39 -0
  6. {prompture-0.0.38.dev2 → prompture-0.0.39}/prompture/cost_mixin.py +37 -0
  7. {prompture-0.0.38.dev2 → prompture-0.0.39}/prompture/discovery.py +54 -39
  8. {prompture-0.0.38.dev2 → prompture-0.0.39}/prompture/driver.py +39 -0
  9. {prompture-0.0.38.dev2 → prompture-0.0.39}/prompture/drivers/async_azure_driver.py +4 -4
  10. prompture-0.0.39/prompture/drivers/async_claude_driver.py +282 -0
  11. {prompture-0.0.38.dev2 → prompture-0.0.39}/prompture/drivers/async_google_driver.py +10 -0
  12. {prompture-0.0.38.dev2 → prompture-0.0.39}/prompture/drivers/async_grok_driver.py +4 -4
  13. {prompture-0.0.38.dev2 → prompture-0.0.39}/prompture/drivers/async_groq_driver.py +4 -4
  14. prompture-0.0.39/prompture/drivers/async_openai_driver.py +253 -0
  15. {prompture-0.0.38.dev2 → prompture-0.0.39}/prompture/drivers/async_openrouter_driver.py +4 -4
  16. {prompture-0.0.38.dev2 → prompture-0.0.39}/prompture/drivers/azure_driver.py +3 -3
  17. {prompture-0.0.38.dev2 → prompture-0.0.39}/prompture/drivers/claude_driver.py +10 -0
  18. {prompture-0.0.38.dev2 → prompture-0.0.39}/prompture/drivers/google_driver.py +10 -0
  19. {prompture-0.0.38.dev2 → prompture-0.0.39}/prompture/drivers/grok_driver.py +4 -4
  20. {prompture-0.0.38.dev2 → prompture-0.0.39}/prompture/drivers/groq_driver.py +4 -4
  21. {prompture-0.0.38.dev2 → prompture-0.0.39}/prompture/drivers/openai_driver.py +19 -10
  22. {prompture-0.0.38.dev2 → prompture-0.0.39}/prompture/drivers/openrouter_driver.py +4 -4
  23. {prompture-0.0.38.dev2 → prompture-0.0.39}/prompture/model_rates.py +112 -2
  24. {prompture-0.0.38.dev2 → prompture-0.0.39/prompture.egg-info}/PKG-INFO +1 -1
  25. {prompture-0.0.38.dev2 → prompture-0.0.39}/prompture.egg-info/SOURCES.txt +1 -0
  26. prompture-0.0.38.dev2/prompture/drivers/async_claude_driver.py +0 -113
  27. prompture-0.0.38.dev2/prompture/drivers/async_openai_driver.py +0 -102
  28. {prompture-0.0.38.dev2 → prompture-0.0.39}/.claude/skills/add-driver/SKILL.md +0 -0
  29. {prompture-0.0.38.dev2 → prompture-0.0.39}/.claude/skills/add-driver/references/driver-template.md +0 -0
  30. {prompture-0.0.38.dev2 → prompture-0.0.39}/.claude/skills/add-example/SKILL.md +0 -0
  31. {prompture-0.0.38.dev2 → prompture-0.0.39}/.claude/skills/add-field/SKILL.md +0 -0
  32. {prompture-0.0.38.dev2 → prompture-0.0.39}/.claude/skills/add-test/SKILL.md +0 -0
  33. {prompture-0.0.38.dev2 → prompture-0.0.39}/.claude/skills/run-tests/SKILL.md +0 -0
  34. {prompture-0.0.38.dev2 → prompture-0.0.39}/.claude/skills/scaffold-extraction/SKILL.md +0 -0
  35. {prompture-0.0.38.dev2 → prompture-0.0.39}/.claude/skills/update-pricing/SKILL.md +0 -0
  36. {prompture-0.0.38.dev2 → prompture-0.0.39}/.env.copy +0 -0
  37. {prompture-0.0.38.dev2 → prompture-0.0.39}/.github/FUNDING.yml +0 -0
  38. {prompture-0.0.38.dev2 → prompture-0.0.39}/.github/scripts/update_docs_version.py +0 -0
  39. {prompture-0.0.38.dev2 → prompture-0.0.39}/.github/scripts/update_wrapper_version.py +0 -0
  40. {prompture-0.0.38.dev2 → prompture-0.0.39}/.github/workflows/dev.yml +0 -0
  41. {prompture-0.0.38.dev2 → prompture-0.0.39}/.github/workflows/documentation.yml +0 -0
  42. {prompture-0.0.38.dev2 → prompture-0.0.39}/.github/workflows/publish.yml +0 -0
  43. {prompture-0.0.38.dev2 → prompture-0.0.39}/CLAUDE.md +0 -0
  44. {prompture-0.0.38.dev2 → prompture-0.0.39}/LICENSE +0 -0
  45. {prompture-0.0.38.dev2 → prompture-0.0.39}/MANIFEST.in +0 -0
  46. {prompture-0.0.38.dev2 → prompture-0.0.39}/README.md +0 -0
  47. {prompture-0.0.38.dev2 → prompture-0.0.39}/ROADMAP.md +0 -0
  48. {prompture-0.0.38.dev2 → prompture-0.0.39}/docs/source/_static/custom.css +0 -0
  49. {prompture-0.0.38.dev2 → prompture-0.0.39}/docs/source/_templates/footer.html +0 -0
  50. {prompture-0.0.38.dev2 → prompture-0.0.39}/docs/source/api/core.rst +0 -0
  51. {prompture-0.0.38.dev2 → prompture-0.0.39}/docs/source/api/drivers.rst +0 -0
  52. {prompture-0.0.38.dev2 → prompture-0.0.39}/docs/source/api/field_definitions.rst +0 -0
  53. {prompture-0.0.38.dev2 → prompture-0.0.39}/docs/source/api/index.rst +0 -0
  54. {prompture-0.0.38.dev2 → prompture-0.0.39}/docs/source/api/runner.rst +0 -0
  55. {prompture-0.0.38.dev2 → prompture-0.0.39}/docs/source/api/tools.rst +0 -0
  56. {prompture-0.0.38.dev2 → prompture-0.0.39}/docs/source/api/validator.rst +0 -0
  57. {prompture-0.0.38.dev2 → prompture-0.0.39}/docs/source/conf.py +0 -0
  58. {prompture-0.0.38.dev2 → prompture-0.0.39}/docs/source/contributing.rst +0 -0
  59. {prompture-0.0.38.dev2 → prompture-0.0.39}/docs/source/examples.rst +0 -0
  60. {prompture-0.0.38.dev2 → prompture-0.0.39}/docs/source/field_definitions_reference.rst +0 -0
  61. {prompture-0.0.38.dev2 → prompture-0.0.39}/docs/source/index.rst +0 -0
  62. {prompture-0.0.38.dev2 → prompture-0.0.39}/docs/source/installation.rst +0 -0
  63. {prompture-0.0.38.dev2 → prompture-0.0.39}/docs/source/quickstart.rst +0 -0
  64. {prompture-0.0.38.dev2 → prompture-0.0.39}/docs/source/toon_input_guide.rst +0 -0
  65. {prompture-0.0.38.dev2 → prompture-0.0.39}/packages/README.md +0 -0
  66. {prompture-0.0.38.dev2 → prompture-0.0.39}/packages/llm_to_json/README.md +0 -0
  67. {prompture-0.0.38.dev2 → prompture-0.0.39}/packages/llm_to_json/llm_to_json/__init__.py +0 -0
  68. {prompture-0.0.38.dev2 → prompture-0.0.39}/packages/llm_to_json/pyproject.toml +0 -0
  69. {prompture-0.0.38.dev2 → prompture-0.0.39}/packages/llm_to_json/test.py +0 -0
  70. {prompture-0.0.38.dev2 → prompture-0.0.39}/packages/llm_to_toon/README.md +0 -0
  71. {prompture-0.0.38.dev2 → prompture-0.0.39}/packages/llm_to_toon/llm_to_toon/__init__.py +0 -0
  72. {prompture-0.0.38.dev2 → prompture-0.0.39}/packages/llm_to_toon/pyproject.toml +0 -0
  73. {prompture-0.0.38.dev2 → prompture-0.0.39}/packages/llm_to_toon/test.py +0 -0
  74. {prompture-0.0.38.dev2 → prompture-0.0.39}/prompture/agent.py +0 -0
  75. {prompture-0.0.38.dev2 → prompture-0.0.39}/prompture/agent_types.py +0 -0
  76. {prompture-0.0.38.dev2 → prompture-0.0.39}/prompture/aio/__init__.py +0 -0
  77. {prompture-0.0.38.dev2 → prompture-0.0.39}/prompture/async_agent.py +0 -0
  78. {prompture-0.0.38.dev2 → prompture-0.0.39}/prompture/async_conversation.py +0 -0
  79. {prompture-0.0.38.dev2 → prompture-0.0.39}/prompture/async_core.py +0 -0
  80. {prompture-0.0.38.dev2 → prompture-0.0.39}/prompture/async_groups.py +0 -0
  81. {prompture-0.0.38.dev2 → prompture-0.0.39}/prompture/cache.py +0 -0
  82. {prompture-0.0.38.dev2 → prompture-0.0.39}/prompture/callbacks.py +0 -0
  83. {prompture-0.0.38.dev2 → prompture-0.0.39}/prompture/cli.py +0 -0
  84. {prompture-0.0.38.dev2 → prompture-0.0.39}/prompture/conversation.py +0 -0
  85. {prompture-0.0.38.dev2 → prompture-0.0.39}/prompture/core.py +0 -0
  86. {prompture-0.0.38.dev2 → prompture-0.0.39}/prompture/drivers/__init__.py +0 -0
  87. {prompture-0.0.38.dev2 → prompture-0.0.39}/prompture/drivers/airllm_driver.py +0 -0
  88. {prompture-0.0.38.dev2 → prompture-0.0.39}/prompture/drivers/async_airllm_driver.py +0 -0
  89. {prompture-0.0.38.dev2 → prompture-0.0.39}/prompture/drivers/async_hugging_driver.py +0 -0
  90. {prompture-0.0.38.dev2 → prompture-0.0.39}/prompture/drivers/async_lmstudio_driver.py +0 -0
  91. {prompture-0.0.38.dev2 → prompture-0.0.39}/prompture/drivers/async_local_http_driver.py +0 -0
  92. {prompture-0.0.38.dev2 → prompture-0.0.39}/prompture/drivers/async_ollama_driver.py +0 -0
  93. {prompture-0.0.38.dev2 → prompture-0.0.39}/prompture/drivers/async_registry.py +0 -0
  94. {prompture-0.0.38.dev2 → prompture-0.0.39}/prompture/drivers/hugging_driver.py +0 -0
  95. {prompture-0.0.38.dev2 → prompture-0.0.39}/prompture/drivers/lmstudio_driver.py +0 -0
  96. {prompture-0.0.38.dev2 → prompture-0.0.39}/prompture/drivers/local_http_driver.py +0 -0
  97. {prompture-0.0.38.dev2 → prompture-0.0.39}/prompture/drivers/ollama_driver.py +0 -0
  98. {prompture-0.0.38.dev2 → prompture-0.0.39}/prompture/drivers/registry.py +0 -0
  99. {prompture-0.0.38.dev2 → prompture-0.0.39}/prompture/drivers/vision_helpers.py +0 -0
  100. {prompture-0.0.38.dev2 → prompture-0.0.39}/prompture/field_definitions.py +0 -0
  101. {prompture-0.0.38.dev2 → prompture-0.0.39}/prompture/group_types.py +0 -0
  102. {prompture-0.0.38.dev2 → prompture-0.0.39}/prompture/groups.py +0 -0
  103. {prompture-0.0.38.dev2 → prompture-0.0.39}/prompture/image.py +0 -0
  104. {prompture-0.0.38.dev2 → prompture-0.0.39}/prompture/logging.py +0 -0
  105. {prompture-0.0.38.dev2 → prompture-0.0.39}/prompture/persistence.py +0 -0
  106. {prompture-0.0.38.dev2 → prompture-0.0.39}/prompture/persona.py +0 -0
  107. {prompture-0.0.38.dev2 → prompture-0.0.39}/prompture/runner.py +0 -0
  108. {prompture-0.0.38.dev2 → prompture-0.0.39}/prompture/scaffold/__init__.py +0 -0
  109. {prompture-0.0.38.dev2 → prompture-0.0.39}/prompture/scaffold/generator.py +0 -0
  110. {prompture-0.0.38.dev2 → prompture-0.0.39}/prompture/scaffold/templates/Dockerfile.j2 +0 -0
  111. {prompture-0.0.38.dev2 → prompture-0.0.39}/prompture/scaffold/templates/README.md.j2 +0 -0
  112. {prompture-0.0.38.dev2 → prompture-0.0.39}/prompture/scaffold/templates/config.py.j2 +0 -0
  113. {prompture-0.0.38.dev2 → prompture-0.0.39}/prompture/scaffold/templates/env.example.j2 +0 -0
  114. {prompture-0.0.38.dev2 → prompture-0.0.39}/prompture/scaffold/templates/main.py.j2 +0 -0
  115. {prompture-0.0.38.dev2 → prompture-0.0.39}/prompture/scaffold/templates/models.py.j2 +0 -0
  116. {prompture-0.0.38.dev2 → prompture-0.0.39}/prompture/scaffold/templates/requirements.txt.j2 +0 -0
  117. {prompture-0.0.38.dev2 → prompture-0.0.39}/prompture/serialization.py +0 -0
  118. {prompture-0.0.38.dev2 → prompture-0.0.39}/prompture/server.py +0 -0
  119. {prompture-0.0.38.dev2 → prompture-0.0.39}/prompture/session.py +0 -0
  120. {prompture-0.0.38.dev2 → prompture-0.0.39}/prompture/settings.py +0 -0
  121. {prompture-0.0.38.dev2 → prompture-0.0.39}/prompture/tools.py +0 -0
  122. {prompture-0.0.38.dev2 → prompture-0.0.39}/prompture/tools_schema.py +0 -0
  123. {prompture-0.0.38.dev2 → prompture-0.0.39}/prompture/validator.py +0 -0
  124. {prompture-0.0.38.dev2 → prompture-0.0.39}/prompture.egg-info/dependency_links.txt +0 -0
  125. {prompture-0.0.38.dev2 → prompture-0.0.39}/prompture.egg-info/entry_points.txt +0 -0
  126. {prompture-0.0.38.dev2 → prompture-0.0.39}/prompture.egg-info/requires.txt +0 -0
  127. {prompture-0.0.38.dev2 → prompture-0.0.39}/prompture.egg-info/top_level.txt +0 -0
  128. {prompture-0.0.38.dev2 → prompture-0.0.39}/pyproject.toml +0 -0
  129. {prompture-0.0.38.dev2 → prompture-0.0.39}/requirements.txt +0 -0
  130. {prompture-0.0.38.dev2 → prompture-0.0.39}/setup.cfg +0 -0
  131. {prompture-0.0.38.dev2 → prompture-0.0.39}/test.py +0 -0
  132. {prompture-0.0.38.dev2 → prompture-0.0.39}/test_version_diagnosis.py +0 -0
PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: prompture
- Version: 0.0.38.dev2
+ Version: 0.0.39
  Summary: Ask LLMs to return structured JSON and run cross-model tests. API-first.
  Author-email: Juan Denis <juan@vene.co>
  License-Expression: MIT
VERSION
@@ -0,0 +1 @@
+ 0.0.39
prompture/__init__.py
@@ -111,7 +111,13 @@ from .image import (
      make_image,
  )
  from .logging import JSONFormatter, configure_logging
- from .model_rates import get_model_info, get_model_rates, refresh_rates_cache
+ from .model_rates import (
+     ModelCapabilities,
+     get_model_capabilities,
+     get_model_info,
+     get_model_rates,
+     refresh_rates_cache,
+ )
  from .persistence import ConversationStore
  from .persona import (
      PERSONAS,
@@ -213,6 +219,7 @@ __all__ = [
      "LocalHTTPDriver",
      "LoopGroup",
      "MemoryCacheBackend",
+     "ModelCapabilities",
      "ModelRetry",
      "OllamaDriver",
      "OpenAIDriver",
@@ -255,6 +262,7 @@ __all__ = [
      "get_driver_for_model",
      "get_field_definition",
      "get_field_names",
+     "get_model_capabilities",
      "get_model_info",
      "get_model_rates",
      "get_persona",
prompture/_version.py
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
  commit_id: COMMIT_ID
  __commit_id__: COMMIT_ID

- __version__ = version = '0.0.38.dev2'
- __version_tuple__ = version_tuple = (0, 0, 38, 'dev2')
+ __version__ = version = '0.0.39'
+ __version_tuple__ = version_tuple = (0, 0, 39)

  __commit_id__ = commit_id = None
prompture/async_driver.py
@@ -166,6 +166,45 @@ class AsyncDriver:
          except Exception:
              logger.exception("Callback %s raised an exception", event)

+     def _validate_model_capabilities(
+         self,
+         provider: str,
+         model: str,
+         *,
+         using_tool_use: bool = False,
+         using_json_schema: bool = False,
+         using_vision: bool = False,
+     ) -> None:
+         """Log warnings when the model may not support a requested feature.
+
+         Uses models.dev metadata as a secondary signal. Warnings only — the
+         API is the final authority and models.dev data may be stale.
+         """
+         from .model_rates import get_model_capabilities
+
+         caps = get_model_capabilities(provider, model)
+         if caps is None:
+             return
+
+         if using_tool_use and caps.supports_tool_use is False:
+             logger.warning(
+                 "Model %s/%s may not support tool use according to models.dev metadata",
+                 provider,
+                 model,
+             )
+         if using_json_schema and caps.supports_structured_output is False:
+             logger.warning(
+                 "Model %s/%s may not support structured output / JSON schema according to models.dev metadata",
+                 provider,
+                 model,
+             )
+         if using_vision and caps.supports_vision is False:
+             logger.warning(
+                 "Model %s/%s may not support vision/image inputs according to models.dev metadata",
+                 provider,
+                 model,
+             )
+
      def _check_vision_support(self, messages: list[dict[str, Any]]) -> None:
          """Raise if messages contain image blocks and the driver lacks vision support."""
          if self.supports_vision:
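
The hook is intended to be called by drivers just before issuing a request, as the Claude and Google driver changes further down do. A sketch of a hypothetical call site (the EchoDriver subclass is invented for illustration and assumes AsyncDriver can be subclassed this minimally):

    from typing import Any

    from prompture.async_driver import AsyncDriver


    class EchoDriver(AsyncDriver):  # hypothetical subclass for illustration
        async def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
            # Only logs warnings; the request proceeds regardless, since the
            # provider API remains the final authority.
            self._validate_model_capabilities(
                "openai",
                options.get("model", "gpt-4o-mini"),
                using_json_schema=bool(options.get("json_schema")),
            )
            return {"text": prompt, "meta": {}}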
prompture/cost_mixin.py
@@ -49,3 +49,40 @@ class CostMixin:
          completion_cost = (completion_tokens / unit) * model_pricing["completion"]

          return round(prompt_cost + completion_cost, 6)
+
+     def _get_model_config(self, provider: str, model: str) -> dict[str, Any]:
+         """Merge live models.dev capabilities with hardcoded ``MODEL_PRICING``.
+
+         Returns a dict with:
+         - ``tokens_param`` — always from hardcoded ``MODEL_PRICING`` (API-specific)
+         - ``supports_temperature`` — prefers live data, falls back to hardcoded, default ``True``
+         - ``context_window`` — from live data only (``None`` if unavailable)
+         - ``max_output_tokens`` — from live data only (``None`` if unavailable)
+         """
+         from .model_rates import get_model_capabilities
+
+         hardcoded = self.MODEL_PRICING.get(model, {})
+
+         # tokens_param is always from hardcoded config (API-specific, not in models.dev)
+         tokens_param = hardcoded.get("tokens_param", "max_tokens")
+
+         # Start with hardcoded supports_temperature, default True
+         supports_temperature = hardcoded.get("supports_temperature", True)
+
+         context_window: int | None = None
+         max_output_tokens: int | None = None
+
+         # Override with live data when available
+         caps = get_model_capabilities(provider, model)
+         if caps is not None:
+             if caps.supports_temperature is not None:
+                 supports_temperature = caps.supports_temperature
+             context_window = caps.context_window
+             max_output_tokens = caps.max_output_tokens
+
+         return {
+             "tokens_param": tokens_param,
+             "supports_temperature": supports_temperature,
+             "context_window": context_window,
+             "max_output_tokens": max_output_tokens,
+         }
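
Drivers consume the merged config when building request kwargs, as the Azure changes below illustrate. A minimal sketch, assuming CostMixin needs no constructor arguments (the DemoDriver class and its o1-mini pricing entry are hypothetical):

    from typing import Any

    from prompture.cost_mixin import CostMixin


    class DemoDriver(CostMixin):  # hypothetical driver for illustration
        # Hardcoded entries carry API-specific knobs that models.dev lacks.
        MODEL_PRICING: dict[str, dict[str, Any]] = {
            "o1-mini": {"tokens_param": "max_completion_tokens", "supports_temperature": False},
        }


    config = DemoDriver()._get_model_config("openai", "o1-mini")
    kwargs: dict[str, Any] = {"model": "o1-mini", config["tokens_param"]: 512}
    if config["supports_temperature"]:
        kwargs["temperature"] = 0.2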
prompture/discovery.py
@@ -1,7 +1,11 @@
  """Discovery module for auto-detecting available models."""

+ from __future__ import annotations
+
+ import dataclasses
  import logging
  import os
+ from typing import Any, overload

  import requests

@@ -22,23 +26,34 @@ from .settings import settings
  logger = logging.getLogger(__name__)


- def get_available_models() -> list[str]:
-     """
-     Auto-detects all available models based on configured drivers and environment variables.
+ @overload
+ def get_available_models(*, include_capabilities: bool = False) -> list[str]: ...
+
+
+ @overload
+ def get_available_models(*, include_capabilities: bool = True) -> list[dict[str, Any]]: ...

-     Iterates through supported providers and checks if they are configured (e.g. API key present).
-     For static drivers, returns models from their MODEL_PRICING keys.
-     For dynamic drivers (like Ollama), attempts to fetch available models from the endpoint.
+
+ def get_available_models(*, include_capabilities: bool = False) -> list[str] | list[dict[str, Any]]:
+     """Auto-detect available models based on configured drivers and environment variables.
+
+     Iterates through supported providers and checks if they are configured
+     (e.g. API key present). For static drivers, returns models from their
+     ``MODEL_PRICING`` keys. For dynamic drivers (like Ollama), attempts to
+     fetch available models from the endpoint.
+
+     Args:
+         include_capabilities: When ``True``, return enriched dicts with
+             ``model``, ``provider``, ``model_id``, and ``capabilities``
+             fields instead of plain ``"provider/model_id"`` strings.

      Returns:
-         A list of unique model strings in the format "provider/model_id".
+         A sorted list of unique model strings (default) or enriched dicts.
      """
      available_models: set[str] = set()
      configured_providers: set[str] = set()

      # Map of provider name to driver class
-     # We need to map the registry keys to the actual classes to check MODEL_PRICING
-     # and instantiate for dynamic checks if needed.
      provider_classes = {
          "openai": OpenAIDriver,
          "azure": AzureDriver,
@@ -54,11 +69,6 @@ def get_available_models() -> list[str]:

      for provider, driver_cls in provider_classes.items():
          try:
-             # 1. Check if the provider is configured (has API key or endpoint)
-             # We can check this by looking at the settings or env vars that the driver uses.
-             # A simple way is to try to instantiate it with defaults, but that might fail if keys are missing.
-             # Instead, let's check the specific requirements for each known provider.
-
              is_configured = False

              if provider == "openai":
@@ -87,13 +97,10 @@ def get_available_models() -> list[str]:
                  if settings.grok_api_key or os.getenv("GROK_API_KEY"):
                      is_configured = True
              elif provider == "ollama":
-                 # Ollama is always considered "configured" as it defaults to localhost
-                 # We will check connectivity later
                  is_configured = True
              elif provider == "lmstudio":
-                 # LM Studio is similar to Ollama, defaults to localhost
                  is_configured = True
-             elif provider == "local_http" and (settings.local_http_endpoint or os.getenv("LOCAL_HTTP_ENDPOINT")):
+             elif provider == "local_http" and os.getenv("LOCAL_HTTP_ENDPOINT"):
                  is_configured = True

              if not is_configured:
@@ -101,36 +108,20 @@ def get_available_models() -> list[str]:
                  continue
              configured_providers.add(provider)

-             # 2. Static Detection: Get models from MODEL_PRICING
+             # Static Detection: Get models from MODEL_PRICING
              if hasattr(driver_cls, "MODEL_PRICING"):
                  pricing = driver_cls.MODEL_PRICING
                  for model_id in pricing:
-                     # Skip "default" or generic keys if they exist
                      if model_id == "default":
                          continue
-
-                     # For Azure, the model_id in pricing is usually the base model name,
-                     # but the user needs to use the deployment ID.
-                     # However, our Azure driver implementation uses the deployment_id from init
-                     # as the "model" for the request, but expects the user to pass a model name
-                     # that maps to pricing?
-                     # Looking at AzureDriver:
-                     #     kwargs = {"model": self.deployment_id, ...}
-                     #     model = options.get("model", self.model) -> used for pricing lookup
-                     # So we should list the keys in MODEL_PRICING as available "models"
-                     # even though for Azure specifically it's a bit weird because of deployment IDs.
-                     # But for general discovery, listing supported models is correct.
-
                      available_models.add(f"{provider}/{model_id}")

-             # 3. Dynamic Detection: Specific logic for Ollama
+             # Dynamic Detection: Specific logic for Ollama
              if provider == "ollama":
                  try:
                      endpoint = settings.ollama_endpoint or os.getenv(
                          "OLLAMA_ENDPOINT", "http://localhost:11434/api/generate"
                      )
-                     # We need the base URL for tags, usually http://localhost:11434/api/tags
-                     # The configured endpoint might be .../api/generate or .../api/chat
                      base_url = endpoint.split("/api/")[0]
                      tags_url = f"{base_url}/api/tags"

@@ -141,8 +132,6 @@ def get_available_models() -> list[str]:
                      for model in models:
                          name = model.get("name")
                          if name:
-                             # Ollama model names often include tags like "llama3:latest"
-                             # We can keep them as is.
                              available_models.add(f"ollama/{name}")
              except Exception as e:
                  logger.debug(f"Failed to fetch Ollama models: {e}")
@@ -184,4 +173,30 @@ def get_available_models() -> list[str]:
          for model_id in get_all_provider_models(api_name):
              available_models.add(f"{prompture_name}/{model_id}")

-     return sorted(list(available_models))
+     sorted_models = sorted(available_models)
+
+     if not include_capabilities:
+         return sorted_models
+
+     # Build enriched dicts with capabilities from models.dev
+     from .model_rates import get_model_capabilities
+
+     enriched: list[dict[str, Any]] = []
+     for model_str in sorted_models:
+         parts = model_str.split("/", 1)
+         provider = parts[0]
+         model_id = parts[1] if len(parts) > 1 else parts[0]
+
+         caps = get_model_capabilities(provider, model_id)
+         caps_dict = dataclasses.asdict(caps) if caps is not None else None
+
+         enriched.append(
+             {
+                 "model": model_str,
+                 "provider": provider,
+                 "model_id": model_id,
+                 "capabilities": caps_dict,
+             }
+         )
+
+     return enriched
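
A usage sketch of the new flag (the model names are illustrative):

    from prompture.discovery import get_available_models

    # Default form: a sorted list like ["ollama/llama3:latest", "openai/gpt-4o", ...]
    models = get_available_models()

    # Enriched form: "capabilities" is a dataclasses.asdict() dict, or None
    # when models.dev has no entry for that provider/model pair.
    for entry in get_available_models(include_capabilities=True):
        caps = entry["capabilities"]
        if caps is not None:
            print(entry["provider"], entry["model_id"], caps.get("supports_tool_use"))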
prompture/driver.py
@@ -173,6 +173,45 @@ class Driver:
          except Exception:
              logger.exception("Callback %s raised an exception", event)

+     def _validate_model_capabilities(
+         self,
+         provider: str,
+         model: str,
+         *,
+         using_tool_use: bool = False,
+         using_json_schema: bool = False,
+         using_vision: bool = False,
+     ) -> None:
+         """Log warnings when the model may not support a requested feature.
+
+         Uses models.dev metadata as a secondary signal. Warnings only — the
+         API is the final authority and models.dev data may be stale.
+         """
+         from .model_rates import get_model_capabilities
+
+         caps = get_model_capabilities(provider, model)
+         if caps is None:
+             return
+
+         if using_tool_use and caps.supports_tool_use is False:
+             logger.warning(
+                 "Model %s/%s may not support tool use according to models.dev metadata",
+                 provider,
+                 model,
+             )
+         if using_json_schema and caps.supports_structured_output is False:
+             logger.warning(
+                 "Model %s/%s may not support structured output / JSON schema according to models.dev metadata",
+                 provider,
+                 model,
+             )
+         if using_vision and caps.supports_vision is False:
+             logger.warning(
+                 "Model %s/%s may not support vision/image inputs according to models.dev metadata",
+                 provider,
+                 model,
+             )
+
      def _check_vision_support(self, messages: list[dict[str, Any]]) -> None:
          """Raise if messages contain image blocks and the driver lacks vision support."""
          if self.supports_vision:
prompture/drivers/async_azure_driver.py
@@ -70,9 +70,9 @@ class AsyncAzureDriver(CostMixin, AsyncDriver):
              raise RuntimeError("openai package (>=1.0.0) with AsyncAzureOpenAI not installed")

          model = options.get("model", self.model)
-         model_info = self.MODEL_PRICING.get(model, {})
-         tokens_param = model_info.get("tokens_param", "max_tokens")
-         supports_temperature = model_info.get("supports_temperature", True)
+         model_config = self._get_model_config("azure", model)
+         tokens_param = model_config["tokens_param"]
+         supports_temperature = model_config["supports_temperature"]

          opts = {"temperature": 1.0, "max_tokens": 512, **options}

@@ -113,7 +113,7 @@ class AsyncAzureDriver(CostMixin, AsyncDriver):
              "prompt_tokens": prompt_tokens,
              "completion_tokens": completion_tokens,
              "total_tokens": total_tokens,
-             "cost": total_cost,
+             "cost": round(total_cost, 6),
              "raw_response": resp.model_dump(),
              "model_name": model,
              "deployment_id": self.deployment_id,
prompture/drivers/async_claude_driver.py
@@ -0,0 +1,282 @@
+ """Async Anthropic Claude driver. Requires the ``anthropic`` package."""
+
+ from __future__ import annotations
+
+ import json
+ import os
+ from collections.abc import AsyncIterator
+ from typing import Any
+
+ try:
+     import anthropic
+ except Exception:
+     anthropic = None
+
+ from ..async_driver import AsyncDriver
+ from ..cost_mixin import CostMixin
+ from .claude_driver import ClaudeDriver
+
+
+ class AsyncClaudeDriver(CostMixin, AsyncDriver):
+     supports_json_mode = True
+     supports_json_schema = True
+     supports_tool_use = True
+     supports_streaming = True
+     supports_vision = True
+
+     MODEL_PRICING = ClaudeDriver.MODEL_PRICING
+
+     def __init__(self, api_key: str | None = None, model: str = "claude-3-5-haiku-20241022"):
+         self.api_key = api_key or os.getenv("CLAUDE_API_KEY")
+         self.model = model or os.getenv("CLAUDE_MODEL_NAME", "claude-3-5-haiku-20241022")
+
+     supports_messages = True
+
+     def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+         from .vision_helpers import _prepare_claude_vision_messages
+
+         return _prepare_claude_vision_messages(messages)
+
+     async def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
+         messages = [{"role": "user", "content": prompt}]
+         return await self._do_generate(messages, options)
+
+     async def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
+         return await self._do_generate(self._prepare_messages(messages), options)
+
+     async def _do_generate(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
+         if anthropic is None:
+             raise RuntimeError("anthropic package not installed")
+
+         opts = {**{"temperature": 0.0, "max_tokens": 512}, **options}
+         model = options.get("model", self.model)
+
+         # Validate capabilities against models.dev metadata
+         self._validate_model_capabilities(
+             "claude",
+             model,
+             using_json_schema=bool(options.get("json_schema")),
+         )
+
+         client = anthropic.AsyncAnthropic(api_key=self.api_key)
+
+         # Anthropic requires system messages as a top-level parameter
+         system_content, api_messages = self._extract_system_and_messages(messages)
+
+         # Build common kwargs
+         common_kwargs: dict[str, Any] = {
+             "model": model,
+             "messages": api_messages,
+             "temperature": opts["temperature"],
+             "max_tokens": opts["max_tokens"],
+         }
+         if system_content:
+             common_kwargs["system"] = system_content
+
+         # Native JSON mode: use tool-use for schema enforcement
+         if options.get("json_mode"):
+             json_schema = options.get("json_schema")
+             if json_schema:
+                 tool_def = {
+                     "name": "extract_json",
+                     "description": "Extract structured data matching the schema",
+                     "input_schema": json_schema,
+                 }
+                 resp = await client.messages.create(
+                     **common_kwargs,
+                     tools=[tool_def],
+                     tool_choice={"type": "tool", "name": "extract_json"},
+                 )
+                 text = ""
+                 for block in resp.content:
+                     if block.type == "tool_use":
+                         text = json.dumps(block.input)
+                         break
+             else:
+                 resp = await client.messages.create(**common_kwargs)
+                 text = resp.content[0].text
+         else:
+             resp = await client.messages.create(**common_kwargs)
+             text = resp.content[0].text
+
+         prompt_tokens = resp.usage.input_tokens
+         completion_tokens = resp.usage.output_tokens
+         total_tokens = prompt_tokens + completion_tokens
+
+         total_cost = self._calculate_cost("claude", model, prompt_tokens, completion_tokens)
+
+         meta = {
+             "prompt_tokens": prompt_tokens,
+             "completion_tokens": completion_tokens,
+             "total_tokens": total_tokens,
+             "cost": round(total_cost, 6),
+             "raw_response": dict(resp),
+             "model_name": model,
+         }
+
+         return {"text": text, "meta": meta}
+
+     # ------------------------------------------------------------------
+     # Helpers
+     # ------------------------------------------------------------------
+
+     def _extract_system_and_messages(
+         self, messages: list[dict[str, Any]]
+     ) -> tuple[str | None, list[dict[str, Any]]]:
+         """Separate system message from conversation messages for Anthropic API."""
+         system_content = None
+         api_messages: list[dict[str, Any]] = []
+         for msg in messages:
+             if msg.get("role") == "system":
+                 system_content = msg.get("content", "")
+             else:
+                 api_messages.append(msg)
+         return system_content, api_messages
+
+     # ------------------------------------------------------------------
+     # Tool use
+     # ------------------------------------------------------------------
+
+     async def generate_messages_with_tools(
+         self,
+         messages: list[dict[str, Any]],
+         tools: list[dict[str, Any]],
+         options: dict[str, Any],
+     ) -> dict[str, Any]:
+         """Generate a response that may include tool calls (Anthropic)."""
+         if anthropic is None:
+             raise RuntimeError("anthropic package not installed")
+
+         opts = {**{"temperature": 0.0, "max_tokens": 512}, **options}
+         model = options.get("model", self.model)
+
+         self._validate_model_capabilities("claude", model, using_tool_use=True)
+
+         client = anthropic.AsyncAnthropic(api_key=self.api_key)
+
+         system_content, api_messages = self._extract_system_and_messages(messages)
+
+         # Convert tools from OpenAI format to Anthropic format if needed
+         anthropic_tools = []
+         for t in tools:
+             if "type" in t and t["type"] == "function":
+                 # OpenAI format -> Anthropic format
+                 fn = t["function"]
+                 anthropic_tools.append({
+                     "name": fn["name"],
+                     "description": fn.get("description", ""),
+                     "input_schema": fn.get("parameters", {"type": "object", "properties": {}}),
+                 })
+             elif "input_schema" in t:
+                 # Already Anthropic format
+                 anthropic_tools.append(t)
+             else:
+                 anthropic_tools.append(t)
+
+         kwargs: dict[str, Any] = {
+             "model": model,
+             "messages": api_messages,
+             "temperature": opts["temperature"],
+             "max_tokens": opts["max_tokens"],
+             "tools": anthropic_tools,
+         }
+         if system_content:
+             kwargs["system"] = system_content
+
+         resp = await client.messages.create(**kwargs)
+
+         prompt_tokens = resp.usage.input_tokens
+         completion_tokens = resp.usage.output_tokens
+         total_tokens = prompt_tokens + completion_tokens
+         total_cost = self._calculate_cost("claude", model, prompt_tokens, completion_tokens)
+
+         meta = {
+             "prompt_tokens": prompt_tokens,
+             "completion_tokens": completion_tokens,
+             "total_tokens": total_tokens,
+             "cost": round(total_cost, 6),
+             "raw_response": dict(resp),
+             "model_name": model,
+         }
+
+         text = ""
+         tool_calls_out: list[dict[str, Any]] = []
+         for block in resp.content:
+             if block.type == "text":
+                 text += block.text
+             elif block.type == "tool_use":
+                 tool_calls_out.append({
+                     "id": block.id,
+                     "name": block.name,
+                     "arguments": block.input,
+                 })
+
+         return {
+             "text": text,
+             "meta": meta,
+             "tool_calls": tool_calls_out,
+             "stop_reason": resp.stop_reason,
+         }
+
+     # ------------------------------------------------------------------
+     # Streaming
+     # ------------------------------------------------------------------
+
+     async def generate_messages_stream(
+         self,
+         messages: list[dict[str, Any]],
+         options: dict[str, Any],
+     ) -> AsyncIterator[dict[str, Any]]:
+         """Yield response chunks via Anthropic streaming API."""
+         if anthropic is None:
+             raise RuntimeError("anthropic package not installed")
+
+         opts = {**{"temperature": 0.0, "max_tokens": 512}, **options}
+         model = options.get("model", self.model)
+         client = anthropic.AsyncAnthropic(api_key=self.api_key)
+
+         system_content, api_messages = self._extract_system_and_messages(messages)
+
+         kwargs: dict[str, Any] = {
+             "model": model,
+             "messages": api_messages,
+             "temperature": opts["temperature"],
+             "max_tokens": opts["max_tokens"],
+         }
+         if system_content:
+             kwargs["system"] = system_content
+
+         full_text = ""
+         prompt_tokens = 0
+         completion_tokens = 0
+
+         async with client.messages.stream(**kwargs) as stream:
+             async for event in stream:
+                 if hasattr(event, "type"):
+                     if event.type == "content_block_delta" and hasattr(event, "delta"):
+                         delta_text = getattr(event.delta, "text", "")
+                         if delta_text:
+                             full_text += delta_text
+                             yield {"type": "delta", "text": delta_text}
+                     elif event.type == "message_delta" and hasattr(event, "usage"):
+                         completion_tokens = getattr(event.usage, "output_tokens", 0)
+                     elif event.type == "message_start" and hasattr(event, "message"):
+                         usage = getattr(event.message, "usage", None)
+                         if usage:
+                             prompt_tokens = getattr(usage, "input_tokens", 0)
+
+         total_tokens = prompt_tokens + completion_tokens
+         total_cost = self._calculate_cost("claude", model, prompt_tokens, completion_tokens)
+
+         yield {
+             "type": "done",
+             "text": full_text,
+             "meta": {
+                 "prompt_tokens": prompt_tokens,
+                 "completion_tokens": completion_tokens,
+                 "total_tokens": total_tokens,
+                 "cost": round(total_cost, 6),
+                 "raw_response": {},
+                 "model_name": model,
+             },
+         }
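
End to end, passing json_mode with a json_schema routes the request through Anthropic tool-use so the reply is forced to match the schema. A minimal sketch, assuming CLAUDE_API_KEY is set and the anthropic package is installed (the prompt and schema are illustrative):

    import asyncio

    from prompture.drivers.async_claude_driver import AsyncClaudeDriver


    async def main() -> None:
        driver = AsyncClaudeDriver()  # falls back to the CLAUDE_API_KEY env var
        result = await driver.generate(
            "Extract the person: Ada Lovelace, 36.",
            {
                "json_mode": True,
                "json_schema": {
                    "type": "object",
                    "properties": {"name": {"type": "string"}, "age": {"type": "integer"}},
                },
            },
        )
        print(result["text"], result["meta"]["cost"])


    asyncio.run(main())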
prompture/drivers/async_google_driver.py
@@ -169,6 +169,13 @@ class AsyncGoogleDriver(CostMixin, AsyncDriver):
      ) -> dict[str, Any]:
          gen_input, gen_kwargs, model_kwargs = self._build_generation_args(messages, options)

+         # Validate capabilities against models.dev metadata
+         self._validate_model_capabilities(
+             "google",
+             self.model,
+             using_json_schema=bool((options or {}).get("json_schema")),
+         )
+
          try:
              model = genai.GenerativeModel(self.model, **model_kwargs)
              response = await model.generate_content_async(gen_input, **gen_kwargs)
@@ -201,6 +208,9 @@ class AsyncGoogleDriver(CostMixin, AsyncDriver):
          options: dict[str, Any],
      ) -> dict[str, Any]:
          """Generate a response that may include tool/function calls (async)."""
+         model = options.get("model", self.model)
+         self._validate_model_capabilities("google", model, using_tool_use=True)
+
          gen_input, gen_kwargs, model_kwargs = self._build_generation_args(
              self._prepare_messages(messages), options
          )