prompture-0.0.33.dev2-py3-none-any.whl → prompture-0.0.34.dev1-py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the changes between those versions as they appear in the registry.
- prompture/__init__.py +112 -54
- prompture/_version.py +34 -0
- prompture/aio/__init__.py +74 -0
- prompture/async_conversation.py +484 -0
- prompture/async_core.py +803 -0
- prompture/async_driver.py +131 -0
- prompture/cache.py +469 -0
- prompture/callbacks.py +50 -0
- prompture/cli.py +7 -3
- prompture/conversation.py +504 -0
- prompture/core.py +475 -352
- prompture/cost_mixin.py +51 -0
- prompture/discovery.py +41 -36
- prompture/driver.py +125 -5
- prompture/drivers/__init__.py +63 -57
- prompture/drivers/airllm_driver.py +13 -20
- prompture/drivers/async_airllm_driver.py +26 -0
- prompture/drivers/async_azure_driver.py +117 -0
- prompture/drivers/async_claude_driver.py +107 -0
- prompture/drivers/async_google_driver.py +132 -0
- prompture/drivers/async_grok_driver.py +91 -0
- prompture/drivers/async_groq_driver.py +84 -0
- prompture/drivers/async_hugging_driver.py +61 -0
- prompture/drivers/async_lmstudio_driver.py +79 -0
- prompture/drivers/async_local_http_driver.py +44 -0
- prompture/drivers/async_ollama_driver.py +125 -0
- prompture/drivers/async_openai_driver.py +96 -0
- prompture/drivers/async_openrouter_driver.py +96 -0
- prompture/drivers/async_registry.py +80 -0
- prompture/drivers/azure_driver.py +36 -15
- prompture/drivers/claude_driver.py +86 -40
- prompture/drivers/google_driver.py +86 -58
- prompture/drivers/grok_driver.py +29 -38
- prompture/drivers/groq_driver.py +27 -32
- prompture/drivers/hugging_driver.py +6 -6
- prompture/drivers/lmstudio_driver.py +26 -13
- prompture/drivers/local_http_driver.py +6 -6
- prompture/drivers/ollama_driver.py +90 -23
- prompture/drivers/openai_driver.py +36 -15
- prompture/drivers/openrouter_driver.py +31 -31
- prompture/field_definitions.py +106 -96
- prompture/logging.py +80 -0
- prompture/model_rates.py +16 -15
- prompture/runner.py +49 -47
- prompture/session.py +117 -0
- prompture/settings.py +11 -1
- prompture/tools.py +172 -265
- prompture/validator.py +3 -3
- {prompture-0.0.33.dev2.dist-info → prompture-0.0.34.dev1.dist-info}/METADATA +18 -20
- prompture-0.0.34.dev1.dist-info/RECORD +54 -0
- prompture-0.0.33.dev2.dist-info/RECORD +0 -30
- {prompture-0.0.33.dev2.dist-info → prompture-0.0.34.dev1.dist-info}/WHEEL +0 -0
- {prompture-0.0.33.dev2.dist-info → prompture-0.0.34.dev1.dist-info}/entry_points.txt +0 -0
- {prompture-0.0.33.dev2.dist-info → prompture-0.0.34.dev1.dist-info}/licenses/LICENSE +0 -0
- {prompture-0.0.33.dev2.dist-info → prompture-0.0.34.dev1.dist-info}/top_level.txt +0 -0
prompture/drivers/async_local_http_driver.py (new file)
@@ -0,0 +1,44 @@
+"""Async LocalHTTP driver using httpx."""
+
+from __future__ import annotations
+
+import os
+from typing import Any
+
+import httpx
+
+from ..async_driver import AsyncDriver
+
+
+class AsyncLocalHTTPDriver(AsyncDriver):
+    MODEL_PRICING = {"default": {"prompt": 0.0, "completion": 0.0}}
+
+    def __init__(self, endpoint: str | None = None, model: str = "local-model"):
+        self.endpoint = endpoint or os.getenv("LOCAL_HTTP_ENDPOINT", "http://localhost:8000/generate")
+        self.model = model
+
+    async def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
+        payload = {"prompt": prompt, "options": options}
+
+        async with httpx.AsyncClient() as client:
+            try:
+                r = await client.post(self.endpoint, json=payload, timeout=options.get("timeout", 30))
+                r.raise_for_status()
+                response_data = r.json()
+            except Exception as e:
+                raise RuntimeError(f"AsyncLocalHTTPDriver request failed: {e}") from e
+
+        if "text" in response_data and "meta" in response_data:
+            return response_data
+
+        meta = {
+            "prompt_tokens": response_data.get("prompt_tokens", 0),
+            "completion_tokens": response_data.get("completion_tokens", 0),
+            "total_tokens": response_data.get("total_tokens", 0),
+            "cost": 0.0,
+            "raw_response": response_data,
+            "model_name": options.get("model", self.model),
+        }
+
+        text = response_data.get("text") or response_data.get("response") or str(response_data)
+        return {"text": text, "meta": meta}
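For orientation, a minimal usage sketch of the new driver follows. The endpoint, prompt, and printed fields are illustrative; the module path is taken from the file list above.

import asyncio

from prompture.drivers.async_local_http_driver import AsyncLocalHTTPDriver


async def main():
    # With no endpoint argument, the driver falls back to the
    # LOCAL_HTTP_ENDPOINT env var, then http://localhost:8000/generate.
    driver = AsyncLocalHTTPDriver(endpoint="http://localhost:8000/generate")
    result = await driver.generate("Say hello", {"timeout": 10})
    # Every driver returns the standardized {"text", "meta"} dict shown above.
    print(result["text"], result["meta"]["total_tokens"])


asyncio.run(main())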
prompture/drivers/async_ollama_driver.py (new file)
@@ -0,0 +1,125 @@
+"""Async Ollama driver using httpx."""
+
+from __future__ import annotations
+
+import logging
+import os
+from typing import Any
+
+import httpx
+
+from ..async_driver import AsyncDriver
+
+logger = logging.getLogger(__name__)
+
+
+class AsyncOllamaDriver(AsyncDriver):
+    supports_json_mode = True
+
+    MODEL_PRICING = {"default": {"prompt": 0.0, "completion": 0.0}}
+
+    def __init__(self, endpoint: str | None = None, model: str = "llama3"):
+        self.endpoint = endpoint or os.getenv("OLLAMA_ENDPOINT", "http://localhost:11434/api/generate")
+        self.model = model
+        self.options: dict[str, Any] = {}
+
+    supports_messages = True
+
+    async def generate(self, prompt: str, options: dict[str, Any] | None = None) -> dict[str, Any]:
+        merged_options = self.options.copy()
+        if options:
+            merged_options.update(options)
+
+        payload = {
+            "prompt": prompt,
+            "model": merged_options.get("model", self.model),
+            "stream": False,
+        }
+
+        # Native JSON mode support
+        if merged_options.get("json_mode"):
+            payload["format"] = "json"
+
+        if "temperature" in merged_options:
+            payload["temperature"] = merged_options["temperature"]
+        if "top_p" in merged_options:
+            payload["top_p"] = merged_options["top_p"]
+        if "top_k" in merged_options:
+            payload["top_k"] = merged_options["top_k"]
+
+        async with httpx.AsyncClient() as client:
+            try:
+                r = await client.post(self.endpoint, json=payload, timeout=120)
+                r.raise_for_status()
+                response_data = r.json()
+            except httpx.HTTPStatusError as e:
+                raise RuntimeError(f"Ollama request failed: {e}") from e
+            except Exception as e:
+                raise RuntimeError(f"Ollama request failed: {e}") from e
+
+        prompt_tokens = response_data.get("prompt_eval_count", 0)
+        completion_tokens = response_data.get("eval_count", 0)
+        total_tokens = prompt_tokens + completion_tokens
+
+        meta = {
+            "prompt_tokens": prompt_tokens,
+            "completion_tokens": completion_tokens,
+            "total_tokens": total_tokens,
+            "cost": 0.0,
+            "raw_response": response_data,
+            "model_name": merged_options.get("model", self.model),
+        }
+
+        return {"text": response_data.get("response", ""), "meta": meta}
+
+    async def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
+        """Use Ollama's /api/chat endpoint for multi-turn conversations."""
+        merged_options = self.options.copy()
+        if options:
+            merged_options.update(options)
+
+        # Derive the chat endpoint from the generate endpoint
+        chat_endpoint = self.endpoint.replace("/api/generate", "/api/chat")
+
+        payload: dict[str, Any] = {
+            "model": merged_options.get("model", self.model),
+            "messages": messages,
+            "stream": False,
+        }
+
+        if merged_options.get("json_mode"):
+            payload["format"] = "json"
+
+        if "temperature" in merged_options:
+            payload["temperature"] = merged_options["temperature"]
+        if "top_p" in merged_options:
+            payload["top_p"] = merged_options["top_p"]
+        if "top_k" in merged_options:
+            payload["top_k"] = merged_options["top_k"]
+
+        async with httpx.AsyncClient() as client:
+            try:
+                r = await client.post(chat_endpoint, json=payload, timeout=120)
+                r.raise_for_status()
+                response_data = r.json()
+            except httpx.HTTPStatusError as e:
+                raise RuntimeError(f"Ollama chat request failed: {e}") from e
+            except Exception as e:
+                raise RuntimeError(f"Ollama chat request failed: {e}") from e
+
+        prompt_tokens = response_data.get("prompt_eval_count", 0)
+        completion_tokens = response_data.get("eval_count", 0)
+        total_tokens = prompt_tokens + completion_tokens
+
+        meta = {
+            "prompt_tokens": prompt_tokens,
+            "completion_tokens": completion_tokens,
+            "total_tokens": total_tokens,
+            "cost": 0.0,
+            "raw_response": response_data,
+            "model_name": merged_options.get("model", self.model),
+        }
+
+        message = response_data.get("message", {})
+        text = message.get("content", "")
+        return {"text": text, "meta": meta}
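A sketch of the new multi-turn chat path with native JSON mode, assuming a local Ollama server is running; the model name and messages are illustrative. Because supports_json_mode is set, passing json_mode in the options maps to Ollama's format="json".

import asyncio

from prompture.drivers.async_ollama_driver import AsyncOllamaDriver


async def main():
    driver = AsyncOllamaDriver(model="llama3")
    messages = [
        {"role": "system", "content": "Reply with a JSON object."},
        {"role": "user", "content": "List three primary colors."},
    ]
    # generate_messages posts to /api/chat, derived from the generate endpoint.
    result = await driver.generate_messages(messages, {"json_mode": True, "temperature": 0.2})
    print(result["text"])


asyncio.run(main())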
prompture/drivers/async_openai_driver.py (new file)
@@ -0,0 +1,96 @@
+"""Async OpenAI driver. Requires the ``openai`` package (>=1.0.0)."""
+
+from __future__ import annotations
+
+import os
+from typing import Any
+
+try:
+    from openai import AsyncOpenAI
+except Exception:
+    AsyncOpenAI = None
+
+from ..async_driver import AsyncDriver
+from ..cost_mixin import CostMixin
+from .openai_driver import OpenAIDriver
+
+
+class AsyncOpenAIDriver(CostMixin, AsyncDriver):
+    supports_json_mode = True
+    supports_json_schema = True
+
+    MODEL_PRICING = OpenAIDriver.MODEL_PRICING
+
+    def __init__(self, api_key: str | None = None, model: str = "gpt-4o-mini"):
+        self.api_key = api_key or os.getenv("OPENAI_API_KEY")
+        self.model = model
+        if AsyncOpenAI:
+            self.client = AsyncOpenAI(api_key=self.api_key)
+        else:
+            self.client = None
+
+    supports_messages = True
+
+    async def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
+        messages = [{"role": "user", "content": prompt}]
+        return await self._do_generate(messages, options)
+
+    async def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
+        return await self._do_generate(messages, options)
+
+    async def _do_generate(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
+        if self.client is None:
+            raise RuntimeError("openai package (>=1.0.0) is not installed")
+
+        model = options.get("model", self.model)
+
+        model_info = self.MODEL_PRICING.get(model, {})
+        tokens_param = model_info.get("tokens_param", "max_tokens")
+        supports_temperature = model_info.get("supports_temperature", True)
+
+        opts = {"temperature": 1.0, "max_tokens": 512, **options}
+
+        kwargs = {
+            "model": model,
+            "messages": messages,
+        }
+        kwargs[tokens_param] = opts.get("max_tokens", 512)
+
+        if supports_temperature and "temperature" in opts:
+            kwargs["temperature"] = opts["temperature"]
+
+        # Native JSON mode support
+        if options.get("json_mode"):
+            json_schema = options.get("json_schema")
+            if json_schema:
+                kwargs["response_format"] = {
+                    "type": "json_schema",
+                    "json_schema": {
+                        "name": "extraction",
+                        "strict": True,
+                        "schema": json_schema,
+                    },
+                }
+            else:
+                kwargs["response_format"] = {"type": "json_object"}
+
+        resp = await self.client.chat.completions.create(**kwargs)
+
+        usage = getattr(resp, "usage", None)
+        prompt_tokens = getattr(usage, "prompt_tokens", 0)
+        completion_tokens = getattr(usage, "completion_tokens", 0)
+        total_tokens = getattr(usage, "total_tokens", 0)
+
+        total_cost = self._calculate_cost("openai", model, prompt_tokens, completion_tokens)
+
+        meta = {
+            "prompt_tokens": prompt_tokens,
+            "completion_tokens": completion_tokens,
+            "total_tokens": total_tokens,
+            "cost": total_cost,
+            "raw_response": resp.model_dump(),
+            "model_name": model,
+        }
+
+        text = resp.choices[0].message.content
+        return {"text": text, "meta": meta}
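A sketch of strict schema extraction through the new driver; the schema and prompt are illustrative and OPENAI_API_KEY must be set. Note that OpenAI's strict json_schema mode requires "additionalProperties": false in the supplied schema.

import asyncio

from prompture.drivers.async_openai_driver import AsyncOpenAIDriver

SCHEMA = {
    "type": "object",
    "properties": {"name": {"type": "string"}},
    "required": ["name"],
    "additionalProperties": False,
}


async def main():
    driver = AsyncOpenAIDriver()  # reads OPENAI_API_KEY from the environment
    # json_mode + json_schema triggers the strict "json_schema" response_format branch.
    result = await driver.generate(
        "Extract the person's name from: 'Ada Lovelace wrote the notes.'",
        {"json_mode": True, "json_schema": SCHEMA},
    )
    print(result["text"], result["meta"]["cost"])


asyncio.run(main())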
prompture/drivers/async_openrouter_driver.py (new file)
@@ -0,0 +1,96 @@
+"""Async OpenRouter driver using httpx."""
+
+from __future__ import annotations
+
+import os
+from typing import Any
+
+import httpx
+
+from ..async_driver import AsyncDriver
+from ..cost_mixin import CostMixin
+from .openrouter_driver import OpenRouterDriver
+
+
+class AsyncOpenRouterDriver(CostMixin, AsyncDriver):
+    supports_json_mode = True
+
+    MODEL_PRICING = OpenRouterDriver.MODEL_PRICING
+
+    def __init__(self, api_key: str | None = None, model: str = "openai/gpt-3.5-turbo"):
+        self.api_key = api_key or os.getenv("OPENROUTER_API_KEY")
+        if not self.api_key:
+            raise ValueError("OpenRouter API key not found. Set OPENROUTER_API_KEY env var.")
+        self.model = model
+        self.base_url = "https://openrouter.ai/api/v1"
+        self.headers = {
+            "Authorization": f"Bearer {self.api_key}",
+            "HTTP-Referer": "https://github.com/jhd3197/prompture",
+            "Content-Type": "application/json",
+        }
+
+    supports_messages = True
+
+    async def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
+        messages = [{"role": "user", "content": prompt}]
+        return await self._do_generate(messages, options)
+
+    async def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
+        return await self._do_generate(messages, options)
+
+    async def _do_generate(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
+        model = options.get("model", self.model)
+
+        model_info = self.MODEL_PRICING.get(model, {})
+        tokens_param = model_info.get("tokens_param", "max_tokens")
+        supports_temperature = model_info.get("supports_temperature", True)
+
+        opts = {"temperature": 1.0, "max_tokens": 512, **options}
+
+        data = {
+            "model": model,
+            "messages": messages,
+        }
+        data[tokens_param] = opts.get("max_tokens", 512)
+
+        if supports_temperature and "temperature" in opts:
+            data["temperature"] = opts["temperature"]
+
+        # Native JSON mode support
+        if options.get("json_mode"):
+            data["response_format"] = {"type": "json_object"}
+
+        async with httpx.AsyncClient() as client:
+            try:
+                response = await client.post(
+                    f"{self.base_url}/chat/completions",
+                    headers=self.headers,
+                    json=data,
+                    timeout=120,
+                )
+                response.raise_for_status()
+                resp = response.json()
+            except httpx.HTTPStatusError as e:
+                error_msg = f"OpenRouter API request failed: {e!s}"
+                raise RuntimeError(error_msg) from e
+            except Exception as e:
+                raise RuntimeError(f"OpenRouter API request failed: {e!s}") from e
+
+        usage = resp.get("usage", {})
+        prompt_tokens = usage.get("prompt_tokens", 0)
+        completion_tokens = usage.get("completion_tokens", 0)
+        total_tokens = usage.get("total_tokens", 0)
+
+        total_cost = self._calculate_cost("openrouter", model, prompt_tokens, completion_tokens)
+
+        meta = {
+            "prompt_tokens": prompt_tokens,
+            "completion_tokens": completion_tokens,
+            "total_tokens": total_tokens,
+            "cost": total_cost,
+            "raw_response": resp,
+            "model_name": model,
+        }
+
+        text = resp["choices"][0]["message"]["content"]
+        return {"text": text, "meta": meta}
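A sketch of a per-call model override; the model id is illustrative, and the constructor raises unless OPENROUTER_API_KEY is set.

import asyncio

from prompture.drivers.async_openrouter_driver import AsyncOpenRouterDriver


async def main():
    driver = AsyncOpenRouterDriver()
    # "model" in the options overrides the instance default for this call only.
    result = await driver.generate("Ping?", {"model": "openai/gpt-4o-mini", "max_tokens": 64})
    print(result["meta"]["model_name"], result["text"])


asyncio.run(main())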
prompture/drivers/async_registry.py (new file)
@@ -0,0 +1,80 @@
+"""Async driver registry — mirrors the sync DRIVER_REGISTRY."""
+
+from __future__ import annotations
+
+from ..settings import settings
+from .async_airllm_driver import AsyncAirLLMDriver
+from .async_azure_driver import AsyncAzureDriver
+from .async_claude_driver import AsyncClaudeDriver
+from .async_google_driver import AsyncGoogleDriver
+from .async_grok_driver import AsyncGrokDriver
+from .async_groq_driver import AsyncGroqDriver
+from .async_lmstudio_driver import AsyncLMStudioDriver
+from .async_local_http_driver import AsyncLocalHTTPDriver
+from .async_ollama_driver import AsyncOllamaDriver
+from .async_openai_driver import AsyncOpenAIDriver
+from .async_openrouter_driver import AsyncOpenRouterDriver
+
+ASYNC_DRIVER_REGISTRY = {
+    "openai": lambda model=None: AsyncOpenAIDriver(
+        api_key=settings.openai_api_key, model=model or settings.openai_model
+    ),
+    "ollama": lambda model=None: AsyncOllamaDriver(
+        endpoint=settings.ollama_endpoint, model=model or settings.ollama_model
+    ),
+    "claude": lambda model=None: AsyncClaudeDriver(
+        api_key=settings.claude_api_key, model=model or settings.claude_model
+    ),
+    "lmstudio": lambda model=None: AsyncLMStudioDriver(
+        endpoint=settings.lmstudio_endpoint, model=model or settings.lmstudio_model
+    ),
+    "azure": lambda model=None: AsyncAzureDriver(
+        api_key=settings.azure_api_key, endpoint=settings.azure_api_endpoint, deployment_id=settings.azure_deployment_id
+    ),
+    "local_http": lambda model=None: AsyncLocalHTTPDriver(endpoint=settings.local_http_endpoint, model=model),
+    "google": lambda model=None: AsyncGoogleDriver(
+        api_key=settings.google_api_key, model=model or settings.google_model
+    ),
+    "groq": lambda model=None: AsyncGroqDriver(api_key=settings.groq_api_key, model=model or settings.groq_model),
+    "openrouter": lambda model=None: AsyncOpenRouterDriver(
+        api_key=settings.openrouter_api_key, model=model or settings.openrouter_model
+    ),
+    "grok": lambda model=None: AsyncGrokDriver(api_key=settings.grok_api_key, model=model or settings.grok_model),
+    "airllm": lambda model=None: AsyncAirLLMDriver(
+        model=model or settings.airllm_model,
+        compression=settings.airllm_compression,
+    ),
+}
+
+
+def get_async_driver(provider_name: str | None = None):
+    """Factory to get an async driver instance based on the provider name.
+
+    Uses default model from settings if not overridden.
+    """
+    provider = (provider_name or settings.ai_provider or "ollama").strip().lower()
+    if provider not in ASYNC_DRIVER_REGISTRY:
+        raise ValueError(f"Unknown provider: {provider_name}")
+    return ASYNC_DRIVER_REGISTRY[provider]()
+
+
+def get_async_driver_for_model(model_str: str):
+    """Factory to get an async driver instance based on a full model string.
+
+    Format: ``provider/model_id``
+    Example: ``"openai/gpt-4-turbo-preview"``
+    """
+    if not isinstance(model_str, str):
+        raise ValueError(f"Model string must be a string, got {type(model_str)}")
+
+    if not model_str:
+        raise ValueError("Model string cannot be empty")
+
+    parts = model_str.split("/", 1)
+    provider = parts[0].lower()
+    model_id = parts[1] if len(parts) > 1 else None
+
+    if provider not in ASYNC_DRIVER_REGISTRY:
+        raise ValueError(f"Unsupported provider '{provider}'")
+
+    return ASYNC_DRIVER_REGISTRY[provider](model_id)
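Both factory entry points, sketched; the provider and model strings are illustrative.

from prompture.drivers.async_registry import get_async_driver, get_async_driver_for_model

# By provider name: the model comes from settings unless overridden.
ollama_driver = get_async_driver("ollama")

# By full "provider/model_id" string, as documented in the docstring above.
openai_driver = get_async_driver_for_model("openai/gpt-4o-mini")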
prompture/drivers/azure_driver.py
@@ -1,17 +1,23 @@
 """Driver for Azure OpenAI Service (migrated to openai>=1.0.0).
 Requires the `openai` package.
 """
+
 import os
-from typing import Any
+from typing import Any
+
 try:
     from openai import AzureOpenAI
 except Exception:
     AzureOpenAI = None
 
+from ..cost_mixin import CostMixin
 from ..driver import Driver
 
 
-class AzureDriver(Driver):
+class AzureDriver(CostMixin, Driver):
+    supports_json_mode = True
+    supports_json_schema = True
+
     # Pricing per 1K tokens (adjust if your Azure pricing differs from OpenAI defaults)
     MODEL_PRICING = {
         "gpt-5-mini": {
@@ -82,7 +88,16 @@ class AzureDriver(Driver):
         else:
             self.client = None
 
-
+    supports_messages = True
+
+    def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
+        messages = [{"role": "user", "content": prompt}]
+        return self._do_generate(messages, options)
+
+    def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
+        return self._do_generate(messages, options)
+
+    def _do_generate(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
         if self.client is None:
             raise RuntimeError("openai package (>=1.0.0) with AzureOpenAI not installed")
 
@@ -96,13 +111,28 @@ class AzureDriver(Driver):
         # Build request kwargs
         kwargs = {
             "model": self.deployment_id,  # for Azure, use deployment name
-            "messages":
+            "messages": messages,
         }
         kwargs[tokens_param] = opts.get("max_tokens", 512)
 
         if supports_temperature and "temperature" in opts:
             kwargs["temperature"] = opts["temperature"]
 
+        # Native JSON mode support
+        if options.get("json_mode"):
+            json_schema = options.get("json_schema")
+            if json_schema:
+                kwargs["response_format"] = {
+                    "type": "json_schema",
+                    "json_schema": {
+                        "name": "extraction",
+                        "strict": True,
+                        "schema": json_schema,
+                    },
+                }
+            else:
+                kwargs["response_format"] = {"type": "json_object"}
+
         resp = self.client.chat.completions.create(**kwargs)
 
         # Extract usage
@@ -111,17 +141,8 @@ class AzureDriver(Driver):
         completion_tokens = getattr(usage, "completion_tokens", 0)
         total_tokens = getattr(usage, "total_tokens", 0)
 
-        # Calculate cost
-
-        live_rates = get_model_rates("azure", model)
-        if live_rates:
-            prompt_cost = (prompt_tokens / 1_000_000) * live_rates["input"]
-            completion_cost = (completion_tokens / 1_000_000) * live_rates["output"]
-        else:
-            model_pricing = self.MODEL_PRICING.get(model, {"prompt": 0, "completion": 0})
-            prompt_cost = (prompt_tokens / 1000) * model_pricing["prompt"]
-            completion_cost = (completion_tokens / 1000) * model_pricing["completion"]
-        total_cost = prompt_cost + completion_cost
+        # Calculate cost via shared mixin
+        total_cost = self._calculate_cost("azure", model, prompt_tokens, completion_tokens)
 
         # Standardized meta object
         meta = {