prompture 0.0.33.dev1__py3-none-any.whl → 0.0.34__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- prompture/__init__.py +133 -49
- prompture/_version.py +34 -0
- prompture/aio/__init__.py +74 -0
- prompture/async_conversation.py +484 -0
- prompture/async_core.py +803 -0
- prompture/async_driver.py +131 -0
- prompture/cache.py +469 -0
- prompture/callbacks.py +50 -0
- prompture/cli.py +7 -3
- prompture/conversation.py +504 -0
- prompture/core.py +475 -352
- prompture/cost_mixin.py +51 -0
- prompture/discovery.py +50 -35
- prompture/driver.py +125 -5
- prompture/drivers/__init__.py +171 -73
- prompture/drivers/airllm_driver.py +13 -20
- prompture/drivers/async_airllm_driver.py +26 -0
- prompture/drivers/async_azure_driver.py +117 -0
- prompture/drivers/async_claude_driver.py +107 -0
- prompture/drivers/async_google_driver.py +132 -0
- prompture/drivers/async_grok_driver.py +91 -0
- prompture/drivers/async_groq_driver.py +84 -0
- prompture/drivers/async_hugging_driver.py +61 -0
- prompture/drivers/async_lmstudio_driver.py +79 -0
- prompture/drivers/async_local_http_driver.py +44 -0
- prompture/drivers/async_ollama_driver.py +125 -0
- prompture/drivers/async_openai_driver.py +96 -0
- prompture/drivers/async_openrouter_driver.py +96 -0
- prompture/drivers/async_registry.py +129 -0
- prompture/drivers/azure_driver.py +36 -9
- prompture/drivers/claude_driver.py +86 -34
- prompture/drivers/google_driver.py +87 -51
- prompture/drivers/grok_driver.py +29 -32
- prompture/drivers/groq_driver.py +27 -26
- prompture/drivers/hugging_driver.py +6 -6
- prompture/drivers/lmstudio_driver.py +26 -13
- prompture/drivers/local_http_driver.py +6 -6
- prompture/drivers/ollama_driver.py +90 -23
- prompture/drivers/openai_driver.py +36 -9
- prompture/drivers/openrouter_driver.py +31 -25
- prompture/drivers/registry.py +306 -0
- prompture/field_definitions.py +106 -96
- prompture/logging.py +80 -0
- prompture/model_rates.py +217 -0
- prompture/runner.py +49 -47
- prompture/session.py +117 -0
- prompture/settings.py +14 -1
- prompture/tools.py +172 -265
- prompture/validator.py +3 -3
- {prompture-0.0.33.dev1.dist-info → prompture-0.0.34.dist-info}/METADATA +18 -20
- prompture-0.0.34.dist-info/RECORD +55 -0
- prompture-0.0.33.dev1.dist-info/RECORD +0 -29
- {prompture-0.0.33.dev1.dist-info → prompture-0.0.34.dist-info}/WHEEL +0 -0
- {prompture-0.0.33.dev1.dist-info → prompture-0.0.34.dist-info}/entry_points.txt +0 -0
- {prompture-0.0.33.dev1.dist-info → prompture-0.0.34.dist-info}/licenses/LICENSE +0 -0
- {prompture-0.0.33.dev1.dist-info → prompture-0.0.34.dist-info}/top_level.txt +0 -0
prompture/drivers/airllm_driver.py

@@ -1,6 +1,7 @@
 import logging
+from typing import Any, Optional
+
 from ..driver import Driver
-from typing import Any, Dict, Optional
 
 logger = logging.getLogger(__name__)
 
@@ -13,12 +14,9 @@ class AirLLMDriver(Driver):
     ``generate()`` call so the rest of Prompture works without it installed.
     """
 
-    MODEL_PRICING = {
-        "default": {"prompt": 0.0, "completion": 0.0}
-    }
+    MODEL_PRICING = {"default": {"prompt": 0.0, "completion": 0.0}}
 
-    def __init__(self, model: str = "meta-llama/Llama-2-7b-hf",
-                 compression: Optional[str] = None):
+    def __init__(self, model: str = "meta-llama/Llama-2-7b-hf", compression: Optional[str] = None):
         """
         Args:
             model: HuggingFace repo ID (e.g. ``"meta-llama/Llama-2-70b-hf"``).
@@ -26,7 +24,7 @@ class AirLLMDriver(Driver):
         """
         self.model = model
         self.compression = compression
-        self.options:
+        self.options: dict[str, Any] = {}
         self._llm = None
         self._tokenizer = None
 
@@ -42,9 +40,8 @@ class AirLLMDriver(Driver):
             from airllm import AutoModel
         except ImportError:
             raise ImportError(
-                "The 'airllm' package is required for the AirLLM driver. "
-
-            )
+                "The 'airllm' package is required for the AirLLM driver. Install it with: pip install prompture[airllm]"
+            ) from None
 
         try:
             from transformers import AutoTokenizer
@@ -52,12 +49,11 @@
             raise ImportError(
                 "The 'transformers' package is required for the AirLLM driver. "
                 "Install it with: pip install transformers"
-            )
+            ) from None
 
-        logger.info(f"Loading AirLLM model: {self.model} "
-                    f"(compression={self.compression})")
+        logger.info(f"Loading AirLLM model: {self.model} (compression={self.compression})")
 
-        load_kwargs:
+        load_kwargs: dict[str, Any] = {}
         if self.compression:
             load_kwargs["compression"] = self.compression
 
@@ -68,7 +64,7 @@ class AirLLMDriver(Driver):
     # ------------------------------------------------------------------
     # Driver interface
    # ------------------------------------------------------------------
-    def generate(self, prompt: str, options:
+    def generate(self, prompt: str, options: Optional[dict[str, Any]] = None) -> dict[str, Any]:
         self._ensure_loaded()
 
         merged_options = self.options.copy()
@@ -78,14 +74,11 @@
         max_new_tokens = merged_options.get("max_new_tokens", 256)
 
         # Tokenize
-        input_ids = self._tokenizer(
-            prompt, return_tensors="pt"
-        ).input_ids
+        input_ids = self._tokenizer(prompt, return_tensors="pt").input_ids
 
         prompt_tokens = input_ids.shape[1]
 
-        logger.debug(f"AirLLM generating with max_new_tokens={max_new_tokens}, "
-                     f"prompt_tokens={prompt_tokens}")
+        logger.debug(f"AirLLM generating with max_new_tokens={max_new_tokens}, prompt_tokens={prompt_tokens}")
 
         # Generate
         output_ids = self._llm.generate(
prompture/drivers/async_airllm_driver.py (new file)

@@ -0,0 +1,26 @@
+"""Async AirLLM driver — wraps the sync GPU-bound driver with asyncio.to_thread."""
+
+from __future__ import annotations
+
+import asyncio
+from typing import Any
+
+from ..async_driver import AsyncDriver
+from .airllm_driver import AirLLMDriver
+
+
+class AsyncAirLLMDriver(AsyncDriver):
+    """Async wrapper around :class:`AirLLMDriver`.
+
+    AirLLM is GPU-bound with no native async API, so we delegate to
+    ``asyncio.to_thread()`` to avoid blocking the event loop.
+    """
+
+    MODEL_PRICING = AirLLMDriver.MODEL_PRICING
+
+    def __init__(self, model: str = "meta-llama/Llama-2-7b-hf", compression: str | None = None):
+        self.model = model
+        self._sync_driver = AirLLMDriver(model=model, compression=compression)
+
+    async def generate(self, prompt: str, options: dict[str, Any] | None = None) -> dict[str, Any]:
+        return await asyncio.to_thread(self._sync_driver.generate, prompt, options)
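
The new async AirLLM driver above does not reimplement inference; it hands the sync driver's blocking `generate()` to `asyncio.to_thread()` so the event loop stays responsive. A minimal sketch of that delegation pattern, with a dummy `slow_sync_generate` standing in for the real GPU-bound call:

```python
from __future__ import annotations

import asyncio
import time
from typing import Any


def slow_sync_generate(prompt: str, options: dict[str, Any] | None = None) -> dict[str, Any]:
    """Stand-in for a blocking, GPU-bound generate() call."""
    time.sleep(0.5)  # simulates blocking work
    return {"text": f"echo: {prompt}", "meta": {}}


async def main() -> None:
    # Offload the blocking call; the event loop stays free for other coroutines.
    result, _ = await asyncio.gather(
        asyncio.to_thread(slow_sync_generate, "hello", None),
        asyncio.sleep(0.1),  # another coroutine running concurrently
    )
    print(result["text"])


asyncio.run(main())
```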
prompture/drivers/async_azure_driver.py (new file)

@@ -0,0 +1,117 @@
+"""Async Azure OpenAI driver. Requires the ``openai`` package (>=1.0.0)."""
+
+from __future__ import annotations
+
+import os
+from typing import Any
+
+try:
+    from openai import AsyncAzureOpenAI
+except Exception:
+    AsyncAzureOpenAI = None
+
+from ..async_driver import AsyncDriver
+from ..cost_mixin import CostMixin
+from .azure_driver import AzureDriver
+
+
+class AsyncAzureDriver(CostMixin, AsyncDriver):
+    supports_json_mode = True
+    supports_json_schema = True
+
+    MODEL_PRICING = AzureDriver.MODEL_PRICING
+
+    def __init__(
+        self,
+        api_key: str | None = None,
+        endpoint: str | None = None,
+        deployment_id: str | None = None,
+        model: str = "gpt-4o-mini",
+    ):
+        self.api_key = api_key or os.getenv("AZURE_API_KEY")
+        self.endpoint = endpoint or os.getenv("AZURE_API_ENDPOINT")
+        self.deployment_id = deployment_id or os.getenv("AZURE_DEPLOYMENT_ID")
+        self.api_version = os.getenv("AZURE_API_VERSION", "2023-07-01-preview")
+        self.model = model
+
+        if not self.api_key:
+            raise ValueError("Missing Azure API key (AZURE_API_KEY).")
+        if not self.endpoint:
+            raise ValueError("Missing Azure API endpoint (AZURE_API_ENDPOINT).")
+        if not self.deployment_id:
+            raise ValueError("Missing Azure deployment ID (AZURE_DEPLOYMENT_ID).")
+
+        if AsyncAzureOpenAI:
+            self.client = AsyncAzureOpenAI(
+                api_key=self.api_key,
+                api_version=self.api_version,
+                azure_endpoint=self.endpoint,
+            )
+        else:
+            self.client = None
+
+    supports_messages = True
+
+    async def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
+        messages = [{"role": "user", "content": prompt}]
+        return await self._do_generate(messages, options)
+
+    async def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
+        return await self._do_generate(messages, options)
+
+    async def _do_generate(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
+        if self.client is None:
+            raise RuntimeError("openai package (>=1.0.0) with AsyncAzureOpenAI not installed")
+
+        model = options.get("model", self.model)
+        model_info = self.MODEL_PRICING.get(model, {})
+        tokens_param = model_info.get("tokens_param", "max_tokens")
+        supports_temperature = model_info.get("supports_temperature", True)
+
+        opts = {"temperature": 1.0, "max_tokens": 512, **options}
+
+        kwargs = {
+            "model": self.deployment_id,
+            "messages": messages,
+        }
+        kwargs[tokens_param] = opts.get("max_tokens", 512)
+
+        if supports_temperature and "temperature" in opts:
+            kwargs["temperature"] = opts["temperature"]
+
+        # Native JSON mode support
+        if options.get("json_mode"):
+            json_schema = options.get("json_schema")
+            if json_schema:
+                kwargs["response_format"] = {
+                    "type": "json_schema",
+                    "json_schema": {
+                        "name": "extraction",
+                        "strict": True,
+                        "schema": json_schema,
+                    },
+                }
+            else:
+                kwargs["response_format"] = {"type": "json_object"}
+
+        resp = await self.client.chat.completions.create(**kwargs)
+
+        usage = getattr(resp, "usage", None)
+        prompt_tokens = getattr(usage, "prompt_tokens", 0)
+        completion_tokens = getattr(usage, "completion_tokens", 0)
+        total_tokens = getattr(usage, "total_tokens", 0)
+
+        total_cost = self._calculate_cost("azure", model, prompt_tokens, completion_tokens)
+
+        meta = {
+            "prompt_tokens": prompt_tokens,
+            "completion_tokens": completion_tokens,
+            "total_tokens": total_tokens,
+            "cost": total_cost,
+            "raw_response": resp.model_dump(),
+            "model_name": model,
+            "deployment_id": self.deployment_id,
+        }
+
+        text = resp.choices[0].message.content
+        return {"text": text, "meta": meta}
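
Judging from the constructor and `generate()` signature in the hunk above, usage would look roughly like the sketch below. The schema and prompt are placeholders, and `AZURE_API_KEY` / `AZURE_API_ENDPOINT` / `AZURE_DEPLOYMENT_ID` must point at a real Azure OpenAI deployment:

```python
import asyncio

from prompture.drivers.async_azure_driver import AsyncAzureDriver


async def main() -> None:
    # Falls back to the AZURE_* environment variables when no arguments are passed.
    driver = AsyncAzureDriver(model="gpt-4o-mini")

    # json_mode + json_schema maps to response_format={"type": "json_schema", ...} above.
    schema = {
        "type": "object",
        "properties": {"name": {"type": "string"}, "age": {"type": "integer"}},
        "required": ["name", "age"],
        "additionalProperties": False,
    }
    result = await driver.generate(
        "Extract the person: 'Ada Lovelace, 36 years old'",
        {"json_mode": True, "json_schema": schema, "max_tokens": 256},
    )
    print(result["text"])          # JSON string matching the schema
    print(result["meta"]["cost"])  # cost computed via CostMixin


asyncio.run(main())
```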
prompture/drivers/async_claude_driver.py (new file)

@@ -0,0 +1,107 @@
+"""Async Anthropic Claude driver. Requires the ``anthropic`` package."""
+
+from __future__ import annotations
+
+import json
+import os
+from typing import Any
+
+try:
+    import anthropic
+except Exception:
+    anthropic = None
+
+from ..async_driver import AsyncDriver
+from ..cost_mixin import CostMixin
+from .claude_driver import ClaudeDriver
+
+
+class AsyncClaudeDriver(CostMixin, AsyncDriver):
+    supports_json_mode = True
+    supports_json_schema = True
+
+    MODEL_PRICING = ClaudeDriver.MODEL_PRICING
+
+    def __init__(self, api_key: str | None = None, model: str = "claude-3-5-haiku-20241022"):
+        self.api_key = api_key or os.getenv("CLAUDE_API_KEY")
+        self.model = model or os.getenv("CLAUDE_MODEL_NAME", "claude-3-5-haiku-20241022")
+
+    supports_messages = True
+
+    async def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
+        messages = [{"role": "user", "content": prompt}]
+        return await self._do_generate(messages, options)
+
+    async def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
+        return await self._do_generate(messages, options)
+
+    async def _do_generate(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
+        if anthropic is None:
+            raise RuntimeError("anthropic package not installed")
+
+        opts = {**{"temperature": 0.0, "max_tokens": 512}, **options}
+        model = options.get("model", self.model)
+
+        client = anthropic.AsyncAnthropic(api_key=self.api_key)
+
+        # Anthropic requires system messages as a top-level parameter
+        system_content = None
+        api_messages = []
+        for msg in messages:
+            if msg.get("role") == "system":
+                system_content = msg.get("content", "")
+            else:
+                api_messages.append(msg)
+
+        # Build common kwargs
+        common_kwargs: dict[str, Any] = {
+            "model": model,
+            "messages": api_messages,
+            "temperature": opts["temperature"],
+            "max_tokens": opts["max_tokens"],
+        }
+        if system_content:
+            common_kwargs["system"] = system_content
+
+        # Native JSON mode: use tool-use for schema enforcement
+        if options.get("json_mode"):
+            json_schema = options.get("json_schema")
+            if json_schema:
+                tool_def = {
+                    "name": "extract_json",
+                    "description": "Extract structured data matching the schema",
+                    "input_schema": json_schema,
+                }
+                resp = await client.messages.create(
+                    **common_kwargs,
+                    tools=[tool_def],
+                    tool_choice={"type": "tool", "name": "extract_json"},
+                )
+                text = ""
+                for block in resp.content:
+                    if block.type == "tool_use":
+                        text = json.dumps(block.input)
+                        break
+            else:
+                resp = await client.messages.create(**common_kwargs)
+                text = resp.content[0].text
+        else:
+            resp = await client.messages.create(**common_kwargs)
+            text = resp.content[0].text
+
+        prompt_tokens = resp.usage.input_tokens
+        completion_tokens = resp.usage.output_tokens
+        total_tokens = prompt_tokens + completion_tokens
+
+        total_cost = self._calculate_cost("claude", model, prompt_tokens, completion_tokens)
+
+        meta = {
+            "prompt_tokens": prompt_tokens,
+            "completion_tokens": completion_tokens,
+            "total_tokens": total_tokens,
+            "cost": total_cost,
+            "raw_response": dict(resp),
+            "model_name": model,
+        }
+
+        return {"text": text, "meta": meta}
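
When `json_mode` and `json_schema` are both set, the Claude driver above forces a `tool_use` block and returns `json.dumps(block.input)` as the text, so the caller can parse it directly. A hedged usage sketch (the schema and prompt are illustrative placeholders; `CLAUDE_API_KEY` must be set):

```python
import asyncio
import json

from prompture.drivers.async_claude_driver import AsyncClaudeDriver


async def main() -> None:
    driver = AsyncClaudeDriver()  # reads CLAUDE_API_KEY from the environment

    schema = {
        "type": "object",
        "properties": {"city": {"type": "string"}, "country": {"type": "string"}},
        "required": ["city", "country"],
    }
    result = await driver.generate(
        "Where is the Eiffel Tower?",
        {"json_mode": True, "json_schema": schema, "max_tokens": 256},
    )
    # With a schema, the returned text is the forced tool call's input as JSON.
    data = json.loads(result["text"])
    print(data["city"], data["country"])


asyncio.run(main())
```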
prompture/drivers/async_google_driver.py (new file)

@@ -0,0 +1,132 @@
+"""Async Google Generative AI (Gemini) driver."""
+
+from __future__ import annotations
+
+import logging
+import os
+from typing import Any
+
+import google.generativeai as genai
+
+from ..async_driver import AsyncDriver
+from ..cost_mixin import CostMixin
+from .google_driver import GoogleDriver
+
+logger = logging.getLogger(__name__)
+
+
+class AsyncGoogleDriver(CostMixin, AsyncDriver):
+    """Async driver for Google's Generative AI API (Gemini)."""
+
+    supports_json_mode = True
+    supports_json_schema = True
+
+    MODEL_PRICING = GoogleDriver.MODEL_PRICING
+    _PRICING_UNIT = 1_000_000
+
+    def __init__(self, api_key: str | None = None, model: str = "gemini-1.5-pro"):
+        self.api_key = api_key or os.getenv("GOOGLE_API_KEY")
+        if not self.api_key:
+            raise ValueError("Google API key not found. Set GOOGLE_API_KEY env var or pass api_key to constructor")
+        self.model = model
+        genai.configure(api_key=self.api_key)
+        self.options: dict[str, Any] = {}
+
+    def _calculate_cost_chars(self, prompt_chars: int, completion_chars: int) -> float:
+        """Calculate cost from character counts (same logic as sync GoogleDriver)."""
+        from ..model_rates import get_model_rates
+
+        live_rates = get_model_rates("google", self.model)
+        if live_rates:
+            est_prompt_tokens = prompt_chars / 4
+            est_completion_tokens = completion_chars / 4
+            prompt_cost = (est_prompt_tokens / 1_000_000) * live_rates["input"]
+            completion_cost = (est_completion_tokens / 1_000_000) * live_rates["output"]
+        else:
+            model_pricing = self.MODEL_PRICING.get(self.model, {"prompt": 0, "completion": 0})
+            prompt_cost = (prompt_chars / 1_000_000) * model_pricing["prompt"]
+            completion_cost = (completion_chars / 1_000_000) * model_pricing["completion"]
+        return round(prompt_cost + completion_cost, 6)
+
+    supports_messages = True
+
+    async def generate(self, prompt: str, options: dict[str, Any] | None = None) -> dict[str, Any]:
+        messages = [{"role": "user", "content": prompt}]
+        return await self._do_generate(messages, options)
+
+    async def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
+        return await self._do_generate(messages, options)
+
+    async def _do_generate(
+        self, messages: list[dict[str, str]], options: dict[str, Any] | None = None
+    ) -> dict[str, Any]:
+        merged_options = self.options.copy()
+        if options:
+            merged_options.update(options)
+
+        generation_config = merged_options.get("generation_config", {})
+        safety_settings = merged_options.get("safety_settings", {})
+
+        if "temperature" in merged_options and "temperature" not in generation_config:
+            generation_config["temperature"] = merged_options["temperature"]
+        if "max_tokens" in merged_options and "max_output_tokens" not in generation_config:
+            generation_config["max_output_tokens"] = merged_options["max_tokens"]
+        if "top_p" in merged_options and "top_p" not in generation_config:
+            generation_config["top_p"] = merged_options["top_p"]
+        if "top_k" in merged_options and "top_k" not in generation_config:
+            generation_config["top_k"] = merged_options["top_k"]
+
+        # Native JSON mode support
+        if merged_options.get("json_mode"):
+            generation_config["response_mime_type"] = "application/json"
+            json_schema = merged_options.get("json_schema")
+            if json_schema:
+                generation_config["response_schema"] = json_schema
+
+        # Convert messages to Gemini format
+        system_instruction = None
+        contents: list[dict[str, Any]] = []
+        for msg in messages:
+            role = msg.get("role", "user")
+            content = msg.get("content", "")
+            if role == "system":
+                system_instruction = content
+            else:
+                gemini_role = "model" if role == "assistant" else "user"
+                contents.append({"role": gemini_role, "parts": [content]})
+
+        try:
+            model_kwargs: dict[str, Any] = {}
+            if system_instruction:
+                model_kwargs["system_instruction"] = system_instruction
+            model = genai.GenerativeModel(self.model, **model_kwargs)
+
+            gen_input: Any = contents if len(contents) != 1 else contents[0]["parts"][0]
+            response = await model.generate_content_async(
+                gen_input,
+                generation_config=generation_config if generation_config else None,
+                safety_settings=safety_settings if safety_settings else None,
+            )
+
+            if not response.text:
+                raise ValueError("Empty response from model")
+
+            total_prompt_chars = sum(len(msg.get("content", "")) for msg in messages)
+            completion_chars = len(response.text)
+
+            total_cost = self._calculate_cost_chars(total_prompt_chars, completion_chars)
+
+            meta = {
+                "prompt_chars": total_prompt_chars,
+                "completion_chars": completion_chars,
+                "total_chars": total_prompt_chars + completion_chars,
+                "cost": total_cost,
+                "raw_response": response.prompt_feedback if hasattr(response, "prompt_feedback") else None,
+                "model_name": self.model,
+            }
+
+            return {"text": response.text, "meta": meta}
+
+        except Exception as e:
+            logger.error(f"Google API request failed: {e}")
+            raise RuntimeError(f"Google API request failed: {e}") from e
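
Unlike the token-based drivers, the Gemini driver estimates cost from character counts: with live rates it approximates tokens as characters / 4, and otherwise it falls back to `MODEL_PRICING` applied per million characters. A small standalone sketch of that fallback arithmetic, using made-up per-million-character rates:

```python
def fallback_cost_chars(prompt_chars: int, completion_chars: int,
                        pricing: dict[str, float]) -> float:
    """Mirror of the else-branch above: price charged per 1M characters."""
    prompt_cost = (prompt_chars / 1_000_000) * pricing["prompt"]
    completion_cost = (completion_chars / 1_000_000) * pricing["completion"]
    return round(prompt_cost + completion_cost, 6)


# Hypothetical rates: $0.50 per 1M prompt chars, $1.50 per 1M completion chars.
print(fallback_cost_chars(2_000, 500, {"prompt": 0.50, "completion": 1.50}))
# -> 0.00175  (2000/1e6 * 0.50 + 500/1e6 * 1.50)
```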
prompture/drivers/async_grok_driver.py (new file)

@@ -0,0 +1,91 @@
+"""Async xAI Grok driver using httpx."""
+
+from __future__ import annotations
+
+import os
+from typing import Any
+
+import httpx
+
+from ..async_driver import AsyncDriver
+from ..cost_mixin import CostMixin
+from .grok_driver import GrokDriver
+
+
+class AsyncGrokDriver(CostMixin, AsyncDriver):
+    supports_json_mode = True
+
+    MODEL_PRICING = GrokDriver.MODEL_PRICING
+    _PRICING_UNIT = 1_000_000
+
+    def __init__(self, api_key: str | None = None, model: str = "grok-4-fast-reasoning"):
+        self.api_key = api_key or os.getenv("GROK_API_KEY")
+        self.model = model
+        self.api_base = "https://api.x.ai/v1"
+
+    supports_messages = True
+
+    async def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
+        messages = [{"role": "user", "content": prompt}]
+        return await self._do_generate(messages, options)
+
+    async def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
+        return await self._do_generate(messages, options)
+
+    async def _do_generate(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
+        if not self.api_key:
+            raise RuntimeError("GROK_API_KEY environment variable is required")
+
+        model = options.get("model", self.model)
+
+        model_info = self.MODEL_PRICING.get(model, {})
+        tokens_param = model_info.get("tokens_param", "max_tokens")
+        supports_temperature = model_info.get("supports_temperature", True)
+
+        opts = {"temperature": 1.0, "max_tokens": 512, **options}
+
+        payload = {
+            "model": model,
+            "messages": messages,
+        }
+        payload[tokens_param] = opts.get("max_tokens", 512)
+
+        if supports_temperature and "temperature" in opts:
+            payload["temperature"] = opts["temperature"]
+
+        # Native JSON mode support
+        if options.get("json_mode"):
+            payload["response_format"] = {"type": "json_object"}
+
+        headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
+
+        async with httpx.AsyncClient() as client:
+            try:
+                response = await client.post(
+                    f"{self.api_base}/chat/completions", headers=headers, json=payload, timeout=120
+                )
+                response.raise_for_status()
+                resp = response.json()
+            except httpx.HTTPStatusError as e:
+                raise RuntimeError(f"Grok API request failed: {e!s}") from e
+            except Exception as e:
+                raise RuntimeError(f"Grok API request failed: {e!s}") from e
+
+        usage = resp.get("usage", {})
+        prompt_tokens = usage.get("prompt_tokens", 0)
+        completion_tokens = usage.get("completion_tokens", 0)
+        total_tokens = usage.get("total_tokens", 0)
+
+        total_cost = self._calculate_cost("grok", model, prompt_tokens, completion_tokens)
+
+        meta = {
+            "prompt_tokens": prompt_tokens,
+            "completion_tokens": completion_tokens,
+            "total_tokens": total_tokens,
+            "cost": total_cost,
+            "raw_response": resp,
+            "model_name": model,
+        }
+
+        text = resp["choices"][0]["message"]["content"]
+        return {"text": text, "meta": meta}
prompture/drivers/async_groq_driver.py (new file)

@@ -0,0 +1,84 @@
+"""Async Groq driver. Requires the ``groq`` package."""
+
+from __future__ import annotations
+
+import os
+from typing import Any
+
+try:
+    import groq
+except Exception:
+    groq = None
+
+from ..async_driver import AsyncDriver
+from ..cost_mixin import CostMixin
+from .groq_driver import GroqDriver
+
+
+class AsyncGroqDriver(CostMixin, AsyncDriver):
+    supports_json_mode = True
+
+    MODEL_PRICING = GroqDriver.MODEL_PRICING
+
+    def __init__(self, api_key: str | None = None, model: str = "llama2-70b-4096"):
+        self.api_key = api_key or os.getenv("GROQ_API_KEY")
+        self.model = model
+        if groq:
+            self.client = groq.AsyncClient(api_key=self.api_key)
+        else:
+            self.client = None
+
+    supports_messages = True
+
+    async def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
+        messages = [{"role": "user", "content": prompt}]
+        return await self._do_generate(messages, options)
+
+    async def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
+        return await self._do_generate(messages, options)
+
+    async def _do_generate(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
+        if self.client is None:
+            raise RuntimeError("groq package is not installed")
+
+        model = options.get("model", self.model)
+
+        model_info = self.MODEL_PRICING.get(model, {})
+        tokens_param = model_info.get("tokens_param", "max_tokens")
+        supports_temperature = model_info.get("supports_temperature", True)
+
+        opts = {"temperature": 0.7, "max_tokens": 512, **options}
+
+        kwargs = {
+            "model": model,
+            "messages": messages,
+        }
+        kwargs[tokens_param] = opts.get("max_tokens", 512)
+
+        if supports_temperature and "temperature" in opts:
+            kwargs["temperature"] = opts["temperature"]
+
+        # Native JSON mode support
+        if options.get("json_mode"):
+            kwargs["response_format"] = {"type": "json_object"}
+
+        resp = await self.client.chat.completions.create(**kwargs)
+
+        usage = getattr(resp, "usage", None)
+        prompt_tokens = getattr(usage, "prompt_tokens", 0)
+        completion_tokens = getattr(usage, "completion_tokens", 0)
+        total_tokens = getattr(usage, "total_tokens", 0)
+
+        total_cost = self._calculate_cost("groq", model, prompt_tokens, completion_tokens)
+
+        meta = {
+            "prompt_tokens": prompt_tokens,
+            "completion_tokens": completion_tokens,
+            "total_tokens": total_tokens,
+            "cost": total_cost,
+            "raw_response": resp.model_dump(),
+            "model_name": model,
+        }
+
+        text = resp.choices[0].message.content
+        return {"text": text, "meta": meta}
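
All of the new async drivers expose the same `await driver.generate(prompt, options)` surface, so independent requests can be fanned out concurrently. A sketch using the Groq driver from the last hunk (the model name and prompts are placeholders; `GROQ_API_KEY` must be set):

```python
import asyncio

from prompture.drivers.async_groq_driver import AsyncGroqDriver


async def main() -> None:
    driver = AsyncGroqDriver(model="llama2-70b-4096")  # reads GROQ_API_KEY

    prompts = [
        "Summarize asyncio in one line.",
        "Summarize httpx in one line.",
        "Summarize pydantic in one line.",
    ]

    # One coroutine per prompt; the event loop overlaps the network waits.
    results = await asyncio.gather(
        *(driver.generate(p, {"max_tokens": 64}) for p in prompts)
    )
    for r in results:
        print(r["meta"]["total_tokens"], r["text"])


asyncio.run(main())
```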