prompture 0.0.33.dev2__py3-none-any.whl → 0.0.34.dev1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55)
  1. prompture/__init__.py +112 -54
  2. prompture/_version.py +34 -0
  3. prompture/aio/__init__.py +74 -0
  4. prompture/async_conversation.py +484 -0
  5. prompture/async_core.py +803 -0
  6. prompture/async_driver.py +131 -0
  7. prompture/cache.py +469 -0
  8. prompture/callbacks.py +50 -0
  9. prompture/cli.py +7 -3
  10. prompture/conversation.py +504 -0
  11. prompture/core.py +475 -352
  12. prompture/cost_mixin.py +51 -0
  13. prompture/discovery.py +41 -36
  14. prompture/driver.py +125 -5
  15. prompture/drivers/__init__.py +63 -57
  16. prompture/drivers/airllm_driver.py +13 -20
  17. prompture/drivers/async_airllm_driver.py +26 -0
  18. prompture/drivers/async_azure_driver.py +117 -0
  19. prompture/drivers/async_claude_driver.py +107 -0
  20. prompture/drivers/async_google_driver.py +132 -0
  21. prompture/drivers/async_grok_driver.py +91 -0
  22. prompture/drivers/async_groq_driver.py +84 -0
  23. prompture/drivers/async_hugging_driver.py +61 -0
  24. prompture/drivers/async_lmstudio_driver.py +79 -0
  25. prompture/drivers/async_local_http_driver.py +44 -0
  26. prompture/drivers/async_ollama_driver.py +125 -0
  27. prompture/drivers/async_openai_driver.py +96 -0
  28. prompture/drivers/async_openrouter_driver.py +96 -0
  29. prompture/drivers/async_registry.py +80 -0
  30. prompture/drivers/azure_driver.py +36 -15
  31. prompture/drivers/claude_driver.py +86 -40
  32. prompture/drivers/google_driver.py +86 -58
  33. prompture/drivers/grok_driver.py +29 -38
  34. prompture/drivers/groq_driver.py +27 -32
  35. prompture/drivers/hugging_driver.py +6 -6
  36. prompture/drivers/lmstudio_driver.py +26 -13
  37. prompture/drivers/local_http_driver.py +6 -6
  38. prompture/drivers/ollama_driver.py +90 -23
  39. prompture/drivers/openai_driver.py +36 -15
  40. prompture/drivers/openrouter_driver.py +31 -31
  41. prompture/field_definitions.py +106 -96
  42. prompture/logging.py +80 -0
  43. prompture/model_rates.py +16 -15
  44. prompture/runner.py +49 -47
  45. prompture/session.py +117 -0
  46. prompture/settings.py +11 -1
  47. prompture/tools.py +172 -265
  48. prompture/validator.py +3 -3
  49. {prompture-0.0.33.dev2.dist-info → prompture-0.0.34.dev1.dist-info}/METADATA +18 -20
  50. prompture-0.0.34.dev1.dist-info/RECORD +54 -0
  51. prompture-0.0.33.dev2.dist-info/RECORD +0 -30
  52. {prompture-0.0.33.dev2.dist-info → prompture-0.0.34.dev1.dist-info}/WHEEL +0 -0
  53. {prompture-0.0.33.dev2.dist-info → prompture-0.0.34.dev1.dist-info}/entry_points.txt +0 -0
  54. {prompture-0.0.33.dev2.dist-info → prompture-0.0.34.dev1.dist-info}/licenses/LICENSE +0 -0
  55. {prompture-0.0.33.dev2.dist-info → prompture-0.0.34.dev1.dist-info}/top_level.txt +0 -0
prompture/drivers/async_local_http_driver.py
@@ -0,0 +1,44 @@
+ """Async LocalHTTP driver using httpx."""
+
+ from __future__ import annotations
+
+ import os
+ from typing import Any
+
+ import httpx
+
+ from ..async_driver import AsyncDriver
+
+
+ class AsyncLocalHTTPDriver(AsyncDriver):
+     MODEL_PRICING = {"default": {"prompt": 0.0, "completion": 0.0}}
+
+     def __init__(self, endpoint: str | None = None, model: str = "local-model"):
+         self.endpoint = endpoint or os.getenv("LOCAL_HTTP_ENDPOINT", "http://localhost:8000/generate")
+         self.model = model
+
+     async def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
+         payload = {"prompt": prompt, "options": options}
+
+         async with httpx.AsyncClient() as client:
+             try:
+                 r = await client.post(self.endpoint, json=payload, timeout=options.get("timeout", 30))
+                 r.raise_for_status()
+                 response_data = r.json()
+             except Exception as e:
+                 raise RuntimeError(f"AsyncLocalHTTPDriver request failed: {e}") from e
+
+         if "text" in response_data and "meta" in response_data:
+             return response_data
+
+         meta = {
+             "prompt_tokens": response_data.get("prompt_tokens", 0),
+             "completion_tokens": response_data.get("completion_tokens", 0),
+             "total_tokens": response_data.get("total_tokens", 0),
+             "cost": 0.0,
+             "raw_response": response_data,
+             "model_name": options.get("model", self.model),
+         }
+
+         text = response_data.get("text") or response_data.get("response") or str(response_data)
+         return {"text": text, "meta": meta}
prompture/drivers/async_ollama_driver.py
@@ -0,0 +1,125 @@
+ """Async Ollama driver using httpx."""
+
+ from __future__ import annotations
+
+ import logging
+ import os
+ from typing import Any
+
+ import httpx
+
+ from ..async_driver import AsyncDriver
+
+ logger = logging.getLogger(__name__)
+
+
+ class AsyncOllamaDriver(AsyncDriver):
+     supports_json_mode = True
+
+     MODEL_PRICING = {"default": {"prompt": 0.0, "completion": 0.0}}
+
+     def __init__(self, endpoint: str | None = None, model: str = "llama3"):
+         self.endpoint = endpoint or os.getenv("OLLAMA_ENDPOINT", "http://localhost:11434/api/generate")
+         self.model = model
+         self.options: dict[str, Any] = {}
+
+     supports_messages = True
+
+     async def generate(self, prompt: str, options: dict[str, Any] | None = None) -> dict[str, Any]:
+         merged_options = self.options.copy()
+         if options:
+             merged_options.update(options)
+
+         payload = {
+             "prompt": prompt,
+             "model": merged_options.get("model", self.model),
+             "stream": False,
+         }
+
+         # Native JSON mode support
+         if merged_options.get("json_mode"):
+             payload["format"] = "json"
+
+         if "temperature" in merged_options:
+             payload["temperature"] = merged_options["temperature"]
+         if "top_p" in merged_options:
+             payload["top_p"] = merged_options["top_p"]
+         if "top_k" in merged_options:
+             payload["top_k"] = merged_options["top_k"]
+
+         async with httpx.AsyncClient() as client:
+             try:
+                 r = await client.post(self.endpoint, json=payload, timeout=120)
+                 r.raise_for_status()
+                 response_data = r.json()
+             except httpx.HTTPStatusError as e:
+                 raise RuntimeError(f"Ollama request failed: {e}") from e
+             except Exception as e:
+                 raise RuntimeError(f"Ollama request failed: {e}") from e
+
+         prompt_tokens = response_data.get("prompt_eval_count", 0)
+         completion_tokens = response_data.get("eval_count", 0)
+         total_tokens = prompt_tokens + completion_tokens
+
+         meta = {
+             "prompt_tokens": prompt_tokens,
+             "completion_tokens": completion_tokens,
+             "total_tokens": total_tokens,
+             "cost": 0.0,
+             "raw_response": response_data,
+             "model_name": merged_options.get("model", self.model),
+         }
+
+         return {"text": response_data.get("response", ""), "meta": meta}
+
+     async def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
+         """Use Ollama's /api/chat endpoint for multi-turn conversations."""
+         merged_options = self.options.copy()
+         if options:
+             merged_options.update(options)
+
+         # Derive the chat endpoint from the generate endpoint
+         chat_endpoint = self.endpoint.replace("/api/generate", "/api/chat")
+
+         payload: dict[str, Any] = {
+             "model": merged_options.get("model", self.model),
+             "messages": messages,
+             "stream": False,
+         }
+
+         if merged_options.get("json_mode"):
+             payload["format"] = "json"
+
+         if "temperature" in merged_options:
+             payload["temperature"] = merged_options["temperature"]
+         if "top_p" in merged_options:
+             payload["top_p"] = merged_options["top_p"]
+         if "top_k" in merged_options:
+             payload["top_k"] = merged_options["top_k"]
+
+         async with httpx.AsyncClient() as client:
+             try:
+                 r = await client.post(chat_endpoint, json=payload, timeout=120)
+                 r.raise_for_status()
+                 response_data = r.json()
+             except httpx.HTTPStatusError as e:
+                 raise RuntimeError(f"Ollama chat request failed: {e}") from e
+             except Exception as e:
+                 raise RuntimeError(f"Ollama chat request failed: {e}") from e
+
+         prompt_tokens = response_data.get("prompt_eval_count", 0)
+         completion_tokens = response_data.get("eval_count", 0)
+         total_tokens = prompt_tokens + completion_tokens
+
+         meta = {
+             "prompt_tokens": prompt_tokens,
+             "completion_tokens": completion_tokens,
+             "total_tokens": total_tokens,
+             "cost": 0.0,
+             "raw_response": response_data,
+             "model_name": merged_options.get("model", self.model),
+         }
+
+         message = response_data.get("message", {})
+         text = message.get("content", "")
+         return {"text": text, "meta": meta}
prompture/drivers/async_openai_driver.py
@@ -0,0 +1,96 @@
+ """Async OpenAI driver. Requires the ``openai`` package (>=1.0.0)."""
+
+ from __future__ import annotations
+
+ import os
+ from typing import Any
+
+ try:
+     from openai import AsyncOpenAI
+ except Exception:
+     AsyncOpenAI = None
+
+ from ..async_driver import AsyncDriver
+ from ..cost_mixin import CostMixin
+ from .openai_driver import OpenAIDriver
+
+
+ class AsyncOpenAIDriver(CostMixin, AsyncDriver):
+     supports_json_mode = True
+     supports_json_schema = True
+
+     MODEL_PRICING = OpenAIDriver.MODEL_PRICING
+
+     def __init__(self, api_key: str | None = None, model: str = "gpt-4o-mini"):
+         self.api_key = api_key or os.getenv("OPENAI_API_KEY")
+         self.model = model
+         if AsyncOpenAI:
+             self.client = AsyncOpenAI(api_key=self.api_key)
+         else:
+             self.client = None
+
+     supports_messages = True
+
+     async def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
+         messages = [{"role": "user", "content": prompt}]
+         return await self._do_generate(messages, options)
+
+     async def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
+         return await self._do_generate(messages, options)
+
+     async def _do_generate(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
+         if self.client is None:
+             raise RuntimeError("openai package (>=1.0.0) is not installed")
+
+         model = options.get("model", self.model)
+
+         model_info = self.MODEL_PRICING.get(model, {})
+         tokens_param = model_info.get("tokens_param", "max_tokens")
+         supports_temperature = model_info.get("supports_temperature", True)
+
+         opts = {"temperature": 1.0, "max_tokens": 512, **options}
+
+         kwargs = {
+             "model": model,
+             "messages": messages,
+         }
+         kwargs[tokens_param] = opts.get("max_tokens", 512)
+
+         if supports_temperature and "temperature" in opts:
+             kwargs["temperature"] = opts["temperature"]
+
+         # Native JSON mode support
+         if options.get("json_mode"):
+             json_schema = options.get("json_schema")
+             if json_schema:
+                 kwargs["response_format"] = {
+                     "type": "json_schema",
+                     "json_schema": {
+                         "name": "extraction",
+                         "strict": True,
+                         "schema": json_schema,
+                     },
+                 }
+             else:
+                 kwargs["response_format"] = {"type": "json_object"}
+
+         resp = await self.client.chat.completions.create(**kwargs)
+
+         usage = getattr(resp, "usage", None)
+         prompt_tokens = getattr(usage, "prompt_tokens", 0)
+         completion_tokens = getattr(usage, "completion_tokens", 0)
+         total_tokens = getattr(usage, "total_tokens", 0)
+
+         total_cost = self._calculate_cost("openai", model, prompt_tokens, completion_tokens)
+
+         meta = {
+             "prompt_tokens": prompt_tokens,
+             "completion_tokens": completion_tokens,
+             "total_tokens": total_tokens,
+             "cost": total_cost,
+             "raw_response": resp.model_dump(),
+             "model_name": model,
+         }
+
+         text = resp.choices[0].message.content
+         return {"text": text, "meta": meta}
prompture/drivers/async_openrouter_driver.py
@@ -0,0 +1,96 @@
+ """Async OpenRouter driver using httpx."""
+
+ from __future__ import annotations
+
+ import os
+ from typing import Any
+
+ import httpx
+
+ from ..async_driver import AsyncDriver
+ from ..cost_mixin import CostMixin
+ from .openrouter_driver import OpenRouterDriver
+
+
+ class AsyncOpenRouterDriver(CostMixin, AsyncDriver):
+     supports_json_mode = True
+
+     MODEL_PRICING = OpenRouterDriver.MODEL_PRICING
+
+     def __init__(self, api_key: str | None = None, model: str = "openai/gpt-3.5-turbo"):
+         self.api_key = api_key or os.getenv("OPENROUTER_API_KEY")
+         if not self.api_key:
+             raise ValueError("OpenRouter API key not found. Set OPENROUTER_API_KEY env var.")
+         self.model = model
+         self.base_url = "https://openrouter.ai/api/v1"
+         self.headers = {
+             "Authorization": f"Bearer {self.api_key}",
+             "HTTP-Referer": "https://github.com/jhd3197/prompture",
+             "Content-Type": "application/json",
+         }
+
+     supports_messages = True
+
+     async def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
+         messages = [{"role": "user", "content": prompt}]
+         return await self._do_generate(messages, options)
+
+     async def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
+         return await self._do_generate(messages, options)
+
+     async def _do_generate(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
+         model = options.get("model", self.model)
+
+         model_info = self.MODEL_PRICING.get(model, {})
+         tokens_param = model_info.get("tokens_param", "max_tokens")
+         supports_temperature = model_info.get("supports_temperature", True)
+
+         opts = {"temperature": 1.0, "max_tokens": 512, **options}
+
+         data = {
+             "model": model,
+             "messages": messages,
+         }
+         data[tokens_param] = opts.get("max_tokens", 512)
+
+         if supports_temperature and "temperature" in opts:
+             data["temperature"] = opts["temperature"]
+
+         # Native JSON mode support
+         if options.get("json_mode"):
+             data["response_format"] = {"type": "json_object"}
+
+         async with httpx.AsyncClient() as client:
+             try:
+                 response = await client.post(
+                     f"{self.base_url}/chat/completions",
+                     headers=self.headers,
+                     json=data,
+                     timeout=120,
+                 )
+                 response.raise_for_status()
+                 resp = response.json()
+             except httpx.HTTPStatusError as e:
+                 error_msg = f"OpenRouter API request failed: {e!s}"
+                 raise RuntimeError(error_msg) from e
+             except Exception as e:
+                 raise RuntimeError(f"OpenRouter API request failed: {e!s}") from e
+
+         usage = resp.get("usage", {})
+         prompt_tokens = usage.get("prompt_tokens", 0)
+         completion_tokens = usage.get("completion_tokens", 0)
+         total_tokens = usage.get("total_tokens", 0)
+
+         total_cost = self._calculate_cost("openrouter", model, prompt_tokens, completion_tokens)
+
+         meta = {
+             "prompt_tokens": prompt_tokens,
+             "completion_tokens": completion_tokens,
+             "total_tokens": total_tokens,
+             "cost": total_cost,
+             "raw_response": resp,
+             "model_name": model,
+         }
+
+         text = resp["choices"][0]["message"]["content"]
+         return {"text": text, "meta": meta}
prompture/drivers/async_registry.py
@@ -0,0 +1,80 @@
+ """Async driver registry — mirrors the sync DRIVER_REGISTRY."""
+
+ from __future__ import annotations
+
+ from ..settings import settings
+ from .async_airllm_driver import AsyncAirLLMDriver
+ from .async_azure_driver import AsyncAzureDriver
+ from .async_claude_driver import AsyncClaudeDriver
+ from .async_google_driver import AsyncGoogleDriver
+ from .async_grok_driver import AsyncGrokDriver
+ from .async_groq_driver import AsyncGroqDriver
+ from .async_lmstudio_driver import AsyncLMStudioDriver
+ from .async_local_http_driver import AsyncLocalHTTPDriver
+ from .async_ollama_driver import AsyncOllamaDriver
+ from .async_openai_driver import AsyncOpenAIDriver
+ from .async_openrouter_driver import AsyncOpenRouterDriver
+
+ ASYNC_DRIVER_REGISTRY = {
+     "openai": lambda model=None: AsyncOpenAIDriver(
+         api_key=settings.openai_api_key, model=model or settings.openai_model
+     ),
+     "ollama": lambda model=None: AsyncOllamaDriver(
+         endpoint=settings.ollama_endpoint, model=model or settings.ollama_model
+     ),
+     "claude": lambda model=None: AsyncClaudeDriver(
+         api_key=settings.claude_api_key, model=model or settings.claude_model
+     ),
+     "lmstudio": lambda model=None: AsyncLMStudioDriver(
+         endpoint=settings.lmstudio_endpoint, model=model or settings.lmstudio_model
+     ),
+     "azure": lambda model=None: AsyncAzureDriver(
+         api_key=settings.azure_api_key, endpoint=settings.azure_api_endpoint, deployment_id=settings.azure_deployment_id
+     ),
+     "local_http": lambda model=None: AsyncLocalHTTPDriver(endpoint=settings.local_http_endpoint, model=model),
+     "google": lambda model=None: AsyncGoogleDriver(
+         api_key=settings.google_api_key, model=model or settings.google_model
+     ),
+     "groq": lambda model=None: AsyncGroqDriver(api_key=settings.groq_api_key, model=model or settings.groq_model),
+     "openrouter": lambda model=None: AsyncOpenRouterDriver(
+         api_key=settings.openrouter_api_key, model=model or settings.openrouter_model
+     ),
+     "grok": lambda model=None: AsyncGrokDriver(api_key=settings.grok_api_key, model=model or settings.grok_model),
+     "airllm": lambda model=None: AsyncAirLLMDriver(
+         model=model or settings.airllm_model,
+         compression=settings.airllm_compression,
+     ),
+ }
+
+
+ def get_async_driver(provider_name: str | None = None):
+     """Factory to get an async driver instance based on the provider name.
+
+     Uses the default model from settings if not overridden.
+     """
+     provider = (provider_name or settings.ai_provider or "ollama").strip().lower()
+     if provider not in ASYNC_DRIVER_REGISTRY:
+         raise ValueError(f"Unknown provider: {provider_name}")
+     return ASYNC_DRIVER_REGISTRY[provider]()
+
+
+ def get_async_driver_for_model(model_str: str):
+     """Factory to get an async driver instance based on a full model string.
+
+     Format: ``provider/model_id``
+     Example: ``"openai/gpt-4-turbo-preview"``
+     """
+     if not isinstance(model_str, str):
+         raise ValueError(f"Model string must be a string, got {type(model_str)}")
+
+     if not model_str:
+         raise ValueError("Model string cannot be empty")
+
+     parts = model_str.split("/", 1)
+     provider = parts[0].lower()
+     model_id = parts[1] if len(parts) > 1 else None
+
+     if provider not in ASYNC_DRIVER_REGISTRY:
+         raise ValueError(f"Unsupported provider '{provider}'")
+
+     return ASYNC_DRIVER_REGISTRY[provider](model_id)
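A minimal sketch of the two lookup paths the registry exposes. Provider defaults (endpoints, API keys, models) come from `prompture.settings`, so callers pass only what they want to override:

```python
from prompture.drivers.async_registry import (
    get_async_driver,
    get_async_driver_for_model,
)

# By provider name: uses the default model from settings.
driver = get_async_driver("ollama")

# By full model string, split once on "/" into provider and model_id.
openai_driver = get_async_driver_for_model("openai/gpt-4o-mini")
```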
prompture/drivers/azure_driver.py
@@ -1,17 +1,23 @@
  """Driver for Azure OpenAI Service (migrated to openai>=1.0.0).
  Requires the `openai` package.
  """
+
  import os
- from typing import Any, Dict
+ from typing import Any
+
  try:
      from openai import AzureOpenAI
  except Exception:
      AzureOpenAI = None

+ from ..cost_mixin import CostMixin
  from ..driver import Driver


- class AzureDriver(Driver):
+ class AzureDriver(CostMixin, Driver):
+     supports_json_mode = True
+     supports_json_schema = True
+
      # Pricing per 1K tokens (adjust if your Azure pricing differs from OpenAI defaults)
      MODEL_PRICING = {
          "gpt-5-mini": {
@@ -82,7 +88,16 @@ class AzureDriver(Driver):
          else:
              self.client = None

-     def generate(self, prompt: str, options: Dict[str, Any]) -> Dict[str, Any]:
+     supports_messages = True
+
+     def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
+         messages = [{"role": "user", "content": prompt}]
+         return self._do_generate(messages, options)
+
+     def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
+         return self._do_generate(messages, options)
+
+     def _do_generate(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
          if self.client is None:
              raise RuntimeError("openai package (>=1.0.0) with AzureOpenAI not installed")

@@ -96,13 +111,28 @@
          # Build request kwargs
          kwargs = {
              "model": self.deployment_id,  # for Azure, use deployment name
-             "messages": [{"role": "user", "content": prompt}],
+             "messages": messages,
          }
          kwargs[tokens_param] = opts.get("max_tokens", 512)

          if supports_temperature and "temperature" in opts:
              kwargs["temperature"] = opts["temperature"]

+         # Native JSON mode support
+         if options.get("json_mode"):
+             json_schema = options.get("json_schema")
+             if json_schema:
+                 kwargs["response_format"] = {
+                     "type": "json_schema",
+                     "json_schema": {
+                         "name": "extraction",
+                         "strict": True,
+                         "schema": json_schema,
+                     },
+                 }
+             else:
+                 kwargs["response_format"] = {"type": "json_object"}
+
          resp = self.client.chat.completions.create(**kwargs)

          # Extract usage
@@ -111,17 +141,8 @@
          completion_tokens = getattr(usage, "completion_tokens", 0)
          total_tokens = getattr(usage, "total_tokens", 0)

-         # Calculate cost try live rates first (per 1M tokens), fall back to hardcoded (per 1K tokens)
-         from ..model_rates import get_model_rates
-         live_rates = get_model_rates("azure", model)
-         if live_rates:
-             prompt_cost = (prompt_tokens / 1_000_000) * live_rates["input"]
-             completion_cost = (completion_tokens / 1_000_000) * live_rates["output"]
-         else:
-             model_pricing = self.MODEL_PRICING.get(model, {"prompt": 0, "completion": 0})
-             prompt_cost = (prompt_tokens / 1000) * model_pricing["prompt"]
-             completion_cost = (completion_tokens / 1000) * model_pricing["completion"]
-         total_cost = prompt_cost + completion_cost
+         # Calculate cost via shared mixin
+         total_cost = self._calculate_cost("azure", model, prompt_tokens, completion_tokens)

          # Standardized meta object
          meta = {
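The deleted inline block above shows exactly what the new `prompture/cost_mixin.py` consolidates across drivers. A sketch of the shared helper, assuming `_calculate_cost` mirrors the inline logic it replaces (live rates quoted per 1M tokens, the hardcoded `MODEL_PRICING` fallback per 1K tokens); the actual file is not shown in this diff:

```python
# Sketch only: assumes cost_mixin.py mirrors the removed inline logic.
from prompture.model_rates import get_model_rates


class CostMixin:
    MODEL_PRICING: dict = {}  # provided by the concrete driver class

    def _calculate_cost(self, provider: str, model: str,
                        prompt_tokens: int, completion_tokens: int) -> float:
        live_rates = get_model_rates(provider, model)
        if live_rates:
            # Live rates are per 1M tokens.
            return ((prompt_tokens / 1_000_000) * live_rates["input"]
                    + (completion_tokens / 1_000_000) * live_rates["output"])
        # Hardcoded fallback pricing is per 1K tokens.
        pricing = self.MODEL_PRICING.get(model, {"prompt": 0, "completion": 0})
        return ((prompt_tokens / 1000) * pricing["prompt"]
                + (completion_tokens / 1000) * pricing["completion"])
```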