prompture 0.0.33.dev2__py3-none-any.whl → 0.0.34.dev1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- prompture/__init__.py +112 -54
- prompture/_version.py +34 -0
- prompture/aio/__init__.py +74 -0
- prompture/async_conversation.py +484 -0
- prompture/async_core.py +803 -0
- prompture/async_driver.py +131 -0
- prompture/cache.py +469 -0
- prompture/callbacks.py +50 -0
- prompture/cli.py +7 -3
- prompture/conversation.py +504 -0
- prompture/core.py +475 -352
- prompture/cost_mixin.py +51 -0
- prompture/discovery.py +41 -36
- prompture/driver.py +125 -5
- prompture/drivers/__init__.py +63 -57
- prompture/drivers/airllm_driver.py +13 -20
- prompture/drivers/async_airllm_driver.py +26 -0
- prompture/drivers/async_azure_driver.py +117 -0
- prompture/drivers/async_claude_driver.py +107 -0
- prompture/drivers/async_google_driver.py +132 -0
- prompture/drivers/async_grok_driver.py +91 -0
- prompture/drivers/async_groq_driver.py +84 -0
- prompture/drivers/async_hugging_driver.py +61 -0
- prompture/drivers/async_lmstudio_driver.py +79 -0
- prompture/drivers/async_local_http_driver.py +44 -0
- prompture/drivers/async_ollama_driver.py +125 -0
- prompture/drivers/async_openai_driver.py +96 -0
- prompture/drivers/async_openrouter_driver.py +96 -0
- prompture/drivers/async_registry.py +80 -0
- prompture/drivers/azure_driver.py +36 -15
- prompture/drivers/claude_driver.py +86 -40
- prompture/drivers/google_driver.py +86 -58
- prompture/drivers/grok_driver.py +29 -38
- prompture/drivers/groq_driver.py +27 -32
- prompture/drivers/hugging_driver.py +6 -6
- prompture/drivers/lmstudio_driver.py +26 -13
- prompture/drivers/local_http_driver.py +6 -6
- prompture/drivers/ollama_driver.py +90 -23
- prompture/drivers/openai_driver.py +36 -15
- prompture/drivers/openrouter_driver.py +31 -31
- prompture/field_definitions.py +106 -96
- prompture/logging.py +80 -0
- prompture/model_rates.py +16 -15
- prompture/runner.py +49 -47
- prompture/session.py +117 -0
- prompture/settings.py +11 -1
- prompture/tools.py +172 -265
- prompture/validator.py +3 -3
- {prompture-0.0.33.dev2.dist-info → prompture-0.0.34.dev1.dist-info}/METADATA +18 -20
- prompture-0.0.34.dev1.dist-info/RECORD +54 -0
- prompture-0.0.33.dev2.dist-info/RECORD +0 -30
- {prompture-0.0.33.dev2.dist-info → prompture-0.0.34.dev1.dist-info}/WHEEL +0 -0
- {prompture-0.0.33.dev2.dist-info → prompture-0.0.34.dev1.dist-info}/entry_points.txt +0 -0
- {prompture-0.0.33.dev2.dist-info → prompture-0.0.34.dev1.dist-info}/licenses/LICENSE +0 -0
- {prompture-0.0.33.dev2.dist-info → prompture-0.0.34.dev1.dist-info}/top_level.txt +0 -0
prompture/drivers/claude_driver.py
CHANGED

@@ -1,81 +1,128 @@
 """Driver for Anthropic's Claude models. Requires the `anthropic` library.
 Use with API key in CLAUDE_API_KEY env var or provide directly.
 """
+
+import json
 import os
-from typing import Any
+from typing import Any
+
 try:
     import anthropic
 except Exception:
     anthropic = None
 
+from ..cost_mixin import CostMixin
 from ..driver import Driver
 
-class ClaudeDriver(Driver):
+
+class ClaudeDriver(CostMixin, Driver):
+    supports_json_mode = True
+    supports_json_schema = True
+
     # Claude pricing per 1000 tokens (prices should be kept current with Anthropic's pricing)
     MODEL_PRICING = {
         # Claude Opus 4.1
         "claude-opus-4-1-20250805": {
-            "prompt": 0.015,
-            "completion": 0.075,
+            "prompt": 0.015,  # $15 per 1M prompt tokens
+            "completion": 0.075,  # $75 per 1M completion tokens
         },
         # Claude Opus 4.0
         "claude-opus-4-20250514": {
-            "prompt": 0.015,
-            "completion": 0.075,
+            "prompt": 0.015,  # $15 per 1M prompt tokens
+            "completion": 0.075,  # $75 per 1M completion tokens
         },
         # Claude Sonnet 4.0
         "claude-sonnet-4-20250514": {
-            "prompt": 0.003,
-            "completion": 0.015,
+            "prompt": 0.003,  # $3 per 1M prompt tokens
+            "completion": 0.015,  # $15 per 1M completion tokens
         },
         # Claude Sonnet 3.7
         "claude-3-7-sonnet-20250219": {
-            "prompt": 0.003,
-            "completion": 0.015,
+            "prompt": 0.003,  # $3 per 1M prompt tokens
+            "completion": 0.015,  # $15 per 1M completion tokens
         },
         # Claude Haiku 3.5
         "claude-3-5-haiku-20241022": {
-            "prompt": 0.0008,
-            "completion": 0.004,
-        }
+            "prompt": 0.0008,  # $0.80 per 1M prompt tokens
+            "completion": 0.004,  # $4 per 1M completion tokens
+        },
     }
 
     def __init__(self, api_key: str | None = None, model: str = "claude-3-5-haiku-20241022"):
        self.api_key = api_key or os.getenv("CLAUDE_API_KEY")
        self.model = model or os.getenv("CLAUDE_MODEL_NAME", "claude-3-5-haiku-20241022")
 
-
+    supports_messages = True
+
+    def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
+        messages = [{"role": "user", "content": prompt}]
+        return self._do_generate(messages, options)
+
+    def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
+        return self._do_generate(messages, options)
+
+    def _do_generate(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
         if anthropic is None:
             raise RuntimeError("anthropic package not installed")
-
+
         opts = {**{"temperature": 0.0, "max_tokens": 512}, **options}
         model = options.get("model", self.model)
-
+
         client = anthropic.Anthropic(api_key=self.api_key)
-
-
-
-
-
-
-
+
+        # Anthropic requires system messages as a top-level parameter
+        system_content = None
+        api_messages = []
+        for msg in messages:
+            if msg.get("role") == "system":
+                system_content = msg.get("content", "")
+            else:
+                api_messages.append(msg)
+
+        # Build common kwargs
+        common_kwargs: dict[str, Any] = {
+            "model": model,
+            "messages": api_messages,
+            "temperature": opts["temperature"],
+            "max_tokens": opts["max_tokens"],
+        }
+        if system_content:
+            common_kwargs["system"] = system_content
+
+        # Native JSON mode: use tool-use for schema enforcement
+        if options.get("json_mode"):
+            json_schema = options.get("json_schema")
+            if json_schema:
+                tool_def = {
+                    "name": "extract_json",
+                    "description": "Extract structured data matching the schema",
+                    "input_schema": json_schema,
+                }
+                resp = client.messages.create(
+                    **common_kwargs,
+                    tools=[tool_def],
+                    tool_choice={"type": "tool", "name": "extract_json"},
+                )
+                text = ""
+                for block in resp.content:
+                    if block.type == "tool_use":
+                        text = json.dumps(block.input)
+                        break
+            else:
+                resp = client.messages.create(**common_kwargs)
+                text = resp.content[0].text
+        else:
+            resp = client.messages.create(**common_kwargs)
+            text = resp.content[0].text
+
         # Extract token usage from Claude response
         prompt_tokens = resp.usage.input_tokens
         completion_tokens = resp.usage.output_tokens
         total_tokens = prompt_tokens + completion_tokens
-
-        # Calculate cost
-
-
-        if live_rates:
-            prompt_cost = (prompt_tokens / 1_000_000) * live_rates["input"]
-            completion_cost = (completion_tokens / 1_000_000) * live_rates["output"]
-        else:
-            model_pricing = self.MODEL_PRICING.get(model, {"prompt": 0, "completion": 0})
-            prompt_cost = (prompt_tokens / 1000) * model_pricing["prompt"]
-            completion_cost = (completion_tokens / 1000) * model_pricing["completion"]
-        total_cost = prompt_cost + completion_cost
-
+
+        # Calculate cost via shared mixin
+        total_cost = self._calculate_cost("claude", model, prompt_tokens, completion_tokens)
+
         # Create standardized meta object
         meta = {
             "prompt_tokens": prompt_tokens,

@@ -83,8 +130,7 @@ class ClaudeDriver(Driver):
             "total_tokens": total_tokens,
             "cost": round(total_cost, 6),  # Round to 6 decimal places
             "raw_response": dict(resp),
-            "model_name": model
+            "model_name": model,
         }
-
-        text
-        return {"text": text, "meta": meta}
+
+        return {"text": text, "meta": meta}
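The reworked ClaudeDriver routes both entry points through `_do_generate`, hoists system messages into Anthropic's top-level `system` parameter, and enforces JSON output through a forced `extract_json` tool call. A minimal usage sketch of the new surface follows; the driver methods and option keys come from the diff above, while the import path, schema, and concrete values are illustrative assumptions.

from prompture.drivers.claude_driver import ClaudeDriver  # assumed import path

# Schema and prompt values are invented for illustration.
schema = {
    "type": "object",
    "properties": {"name": {"type": "string"}, "age": {"type": "integer"}},
    "required": ["name", "age"],
}

driver = ClaudeDriver(model="claude-3-5-haiku-20241022")  # reads CLAUDE_API_KEY from the environment
result = driver.generate_messages(
    [
        {"role": "system", "content": "Extract structured data."},  # hoisted to the system= parameter
        {"role": "user", "content": "Alice is 34 years old."},
    ],
    {"json_mode": True, "json_schema": schema, "max_tokens": 256},
)
print(result["text"])          # JSON string built from the extract_json tool_use block
print(result["meta"]["cost"])  # cost computed via CostMixin._calculate_cost("claude", ...)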
prompture/drivers/google_driver.py
CHANGED

@@ -1,60 +1,55 @@
-import os
 import logging
-import google.generativeai as genai
-from typing import Any, Dict
-from ..driver import Driver
-
 import os
-import
+from typing import Any, Optional
+
 import google.generativeai as genai
-
+
+from ..cost_mixin import CostMixin
 from ..driver import Driver
 
 logger = logging.getLogger(__name__)
 
 
-class GoogleDriver(Driver):
+class GoogleDriver(CostMixin, Driver):
     """Driver for Google's Generative AI API (Gemini)."""
 
+    supports_json_mode = True
+    supports_json_schema = True
+
     # Based on current Gemini pricing (as of 2025)
     # Source: https://cloud.google.com/vertex-ai/pricing#gemini_models
+    _PRICING_UNIT = 1_000_000
     MODEL_PRICING = {
         "gemini-1.5-pro": {
             "prompt": 0.00025,  # $0.25/1M chars input
-            "completion": 0.0005  # $0.50/1M chars output
+            "completion": 0.0005,  # $0.50/1M chars output
         },
         "gemini-1.5-pro-vision": {
             "prompt": 0.00025,  # $0.25/1M chars input
-            "completion": 0.0005  # $0.50/1M chars output
+            "completion": 0.0005,  # $0.50/1M chars output
         },
         "gemini-2.5-pro": {
             "prompt": 0.0004,  # $0.40/1M chars input
-            "completion": 0.0008  # $0.80/1M chars output
+            "completion": 0.0008,  # $0.80/1M chars output
         },
         "gemini-2.5-flash": {
             "prompt": 0.0004,  # $0.40/1M chars input
-            "completion": 0.0008  # $0.80/1M chars output
+            "completion": 0.0008,  # $0.80/1M chars output
         },
         "gemini-2.5-flash-lite": {
             "prompt": 0.0002,  # $0.20/1M chars input
-            "completion": 0.0004  # $0.40/1M chars output
+            "completion": 0.0004,  # $0.40/1M chars output
         },
-
+        "gemini-2.0-flash": {
             "prompt": 0.0004,  # $0.40/1M chars input
-            "completion": 0.0008  # $0.80/1M chars output
+            "completion": 0.0008,  # $0.80/1M chars output
         },
         "gemini-2.0-flash-lite": {
             "prompt": 0.0002,  # $0.20/1M chars input
-            "completion": 0.0004  # $0.40/1M chars output
-        },
-        "gemini-1.5-flash": {
-            "prompt": 0.00001875,
-            "completion": 0.000075
+            "completion": 0.0004,  # $0.40/1M chars output
         },
-        "gemini-1.5-flash
-
-            "completion": 0.00004
-        }
+        "gemini-1.5-flash": {"prompt": 0.00001875, "completion": 0.000075},
+        "gemini-1.5-flash-8b": {"prompt": 0.00001, "completion": 0.00004},
     }
 
     def __init__(self, api_key: str | None = None, model: str = "gemini-1.5-pro"):

@@ -75,8 +70,8 @@ class GoogleDriver(Driver):
 
         # Configure google.generativeai
         genai.configure(api_key=self.api_key)
-        self.options:
-
+        self.options: dict[str, Any] = {}
+
         # Validate connection and model availability
         self._validate_connection()
 

@@ -90,16 +85,36 @@ class GoogleDriver(Driver):
             logger.warning(f"Could not validate connection to Google API: {e}")
             raise
 
-    def
-        """
-
-        Args:
-            prompt: The input prompt
-            options: Additional options to pass to the model
+    def _calculate_cost_chars(self, prompt_chars: int, completion_chars: int) -> float:
+        """Calculate cost from character counts.
 
-
-
+        Live rates use token-based pricing (estimate ~4 chars/token).
+        Hardcoded MODEL_PRICING uses per-1M-character rates.
         """
+        from ..model_rates import get_model_rates
+
+        live_rates = get_model_rates("google", self.model)
+        if live_rates:
+            est_prompt_tokens = prompt_chars / 4
+            est_completion_tokens = completion_chars / 4
+            prompt_cost = (est_prompt_tokens / 1_000_000) * live_rates["input"]
+            completion_cost = (est_completion_tokens / 1_000_000) * live_rates["output"]
+        else:
+            model_pricing = self.MODEL_PRICING.get(self.model, {"prompt": 0, "completion": 0})
+            prompt_cost = (prompt_chars / 1_000_000) * model_pricing["prompt"]
+            completion_cost = (completion_chars / 1_000_000) * model_pricing["completion"]
+        return round(prompt_cost + completion_cost, 6)
+
+    supports_messages = True
+
+    def generate(self, prompt: str, options: Optional[dict[str, Any]] = None) -> dict[str, Any]:
+        messages = [{"role": "user", "content": prompt}]
+        return self._do_generate(messages, options)
+
+    def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
+        return self._do_generate(messages, options)
+
+    def _do_generate(self, messages: list[dict[str, str]], options: Optional[dict[str, Any]] = None) -> dict[str, Any]:
         merged_options = self.options.copy()
         if options:
             merged_options.update(options)

@@ -107,7 +122,7 @@ class GoogleDriver(Driver):
         # Extract specific options for Google's API
         generation_config = merged_options.get("generation_config", {})
         safety_settings = merged_options.get("safety_settings", {})
-
+
         # Map common options to generation_config if not present
         if "temperature" in merged_options and "temperature" not in generation_config:
             generation_config["temperature"] = merged_options["temperature"]

@@ -118,44 +133,57 @@ class GoogleDriver(Driver):
         if "top_k" in merged_options and "top_k" not in generation_config:
             generation_config["top_k"] = merged_options["top_k"]
 
+        # Native JSON mode support
+        if merged_options.get("json_mode"):
+            generation_config["response_mime_type"] = "application/json"
+            json_schema = merged_options.get("json_schema")
+            if json_schema:
+                generation_config["response_schema"] = json_schema
+
+        # Convert messages to Gemini format
+        system_instruction = None
+        contents: list[dict[str, Any]] = []
+        for msg in messages:
+            role = msg.get("role", "user")
+            content = msg.get("content", "")
+            if role == "system":
+                system_instruction = content
+            else:
+                # Gemini uses "model" for assistant role
+                gemini_role = "model" if role == "assistant" else "user"
+                contents.append({"role": gemini_role, "parts": [content]})
+
         try:
             logger.debug(f"Initializing {self.model} for generation")
-
+            model_kwargs: dict[str, Any] = {}
+            if system_instruction:
+                model_kwargs["system_instruction"] = system_instruction
+            model = genai.GenerativeModel(self.model, **model_kwargs)
 
             # Generate response
-            logger.debug(f"Generating with
+            logger.debug(f"Generating with {len(contents)} content parts")
+            # If single user message, pass content directly for backward compatibility
+            gen_input: Any = contents if len(contents) != 1 else contents[0]["parts"][0]
             response = model.generate_content(
-
+                gen_input,
                 generation_config=generation_config if generation_config else None,
-                safety_settings=safety_settings if safety_settings else None
+                safety_settings=safety_settings if safety_settings else None,
             )
-
+
             if not response.text:
                 raise ValueError("Empty response from model")
 
             # Calculate token usage and cost
-
+            total_prompt_chars = sum(len(msg.get("content", "")) for msg in messages)
             completion_chars = len(response.text)
 
-            #
-
-            live_rates = get_model_rates("google", self.model)
-            if live_rates:
-                # models.dev reports token-based pricing; estimate tokens from chars (~4 chars/token)
-                est_prompt_tokens = prompt_chars / 4
-                est_completion_tokens = completion_chars / 4
-                prompt_cost = (est_prompt_tokens / 1_000_000) * live_rates["input"]
-                completion_cost = (est_completion_tokens / 1_000_000) * live_rates["output"]
-            else:
-                model_pricing = self.MODEL_PRICING.get(self.model, {"prompt": 0, "completion": 0})
-                prompt_cost = (prompt_chars / 1_000_000) * model_pricing["prompt"]
-                completion_cost = (completion_chars / 1_000_000) * model_pricing["completion"]
-            total_cost = prompt_cost + completion_cost
+            # Google uses character-based cost estimation
+            total_cost = self._calculate_cost_chars(total_prompt_chars, completion_chars)
 
             meta = {
-                "prompt_chars":
+                "prompt_chars": total_prompt_chars,
                 "completion_chars": completion_chars,
-                "total_chars":
+                "total_chars": total_prompt_chars + completion_chars,
                 "cost": total_cost,
                 "raw_response": response.prompt_feedback if hasattr(response, "prompt_feedback") else None,
                 "model_name": self.model,

@@ -165,4 +193,4 @@ class GoogleDriver(Driver):
 
         except Exception as e:
             logger.error(f"Google API request failed: {e}")
-            raise RuntimeError(f"Google API request failed: {e}")
+            raise RuntimeError(f"Google API request failed: {e}") from e
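GoogleDriver keeps its character-based accounting but moves it into `_calculate_cost_chars`, which prefers live per-token rates from `model_rates` (estimating tokens as roughly chars/4) and falls back to the hardcoded per-character table. A standalone sketch of that arithmetic as written in the diff; the character counts below are invented for illustration.

# Fallback path of _calculate_cost_chars, reproduced with illustrative numbers.
prompt_chars, completion_chars = 8_000, 2_000
pricing = {"prompt": 0.00025, "completion": 0.0005}  # the gemini-1.5-pro entry above

prompt_cost = (prompt_chars / 1_000_000) * pricing["prompt"]               # 0.000002
completion_cost = (completion_chars / 1_000_000) * pricing["completion"]   # 0.000001
print(round(prompt_cost + completion_cost, 6))                             # 0.000003

# The live-rate path instead estimates tokens from characters (~4 chars/token)
# and prices them per 1M tokens using the models.dev input/output rates.
est_prompt_tokens = prompt_chars / 4          # 2000.0
est_completion_tokens = completion_chars / 4  # 500.0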
prompture/drivers/grok_driver.py
CHANGED

@@ -1,15 +1,21 @@
 """xAI Grok driver.
 Requires the `requests` package. Uses GROK_API_KEY env var.
 """
+
 import os
-from typing import Any
+from typing import Any
+
 import requests
 
+from ..cost_mixin import CostMixin
 from ..driver import Driver
 
 
-class GrokDriver(Driver):
+class GrokDriver(CostMixin, Driver):
+    supports_json_mode = True
+
     # Pricing per 1M tokens based on xAI's documentation
+    _PRICING_UNIT = 1_000_000
     MODEL_PRICING = {
         "grok-code-fast-1": {
             "prompt": 0.20,

@@ -72,19 +78,16 @@ class GrokDriver(Driver):
         self.model = model
         self.api_base = "https://api.x.ai/v1"
 
-
-        """Generate completion using Grok API.
+    supports_messages = True
 
-
-
-
-
-
-
-
-
-            RuntimeError: If API key is missing or request fails
-        """
+    def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
+        messages = [{"role": "user", "content": prompt}]
+        return self._do_generate(messages, options)
+
+    def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
+        return self._do_generate(messages, options)
+
+    def _do_generate(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
         if not self.api_key:
             raise RuntimeError("GROK_API_KEY environment variable is required")
 

@@ -101,7 +104,7 @@ class GrokDriver(Driver):
         # Base request payload
         payload = {
             "model": model,
-            "messages":
+            "messages": messages,
         }
 
         # Add token limit with correct parameter name

@@ -111,39 +114,27 @@ class GrokDriver(Driver):
         if supports_temperature and "temperature" in opts:
             payload["temperature"] = opts["temperature"]
 
-
-
-        "
-
+        # Native JSON mode support
+        if options.get("json_mode"):
+            payload["response_format"] = {"type": "json_object"}
+
+        headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
 
         try:
-            response = requests.post(
-                f"{self.api_base}/chat/completions",
-                headers=headers,
-                json=payload
-            )
+            response = requests.post(f"{self.api_base}/chat/completions", headers=headers, json=payload)
             response.raise_for_status()
             resp = response.json()
         except requests.exceptions.RequestException as e:
-            raise RuntimeError(f"Grok API request failed: {
+            raise RuntimeError(f"Grok API request failed: {e!s}") from e
 
         # Extract usage info
         usage = resp.get("usage", {})
         prompt_tokens = usage.get("prompt_tokens", 0)
-        completion_tokens = usage.get("completion_tokens", 0)
+        completion_tokens = usage.get("completion_tokens", 0)
         total_tokens = usage.get("total_tokens", 0)
 
-        # Calculate cost
-
-        live_rates = get_model_rates("grok", model)
-        if live_rates:
-            prompt_cost = (prompt_tokens / 1_000_000) * live_rates["input"]
-            completion_cost = (completion_tokens / 1_000_000) * live_rates["output"]
-        else:
-            model_pricing = self.MODEL_PRICING.get(model, {"prompt": 0, "completion": 0})
-            prompt_cost = (prompt_tokens / 1_000_000) * model_pricing["prompt"]
-            completion_cost = (completion_tokens / 1_000_000) * model_pricing["completion"]
-        total_cost = prompt_cost + completion_cost
+        # Calculate cost via shared mixin
+        total_cost = self._calculate_cost("grok", model, prompt_tokens, completion_tokens)
 
         # Standardized meta object
         meta = {

@@ -156,4 +147,4 @@ class GrokDriver(Driver):
         }
 
         text = resp["choices"][0]["message"]["content"]
-        return {"text": text, "meta": meta}
+        return {"text": text, "meta": meta}
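For Grok, JSON mode is simply the OpenAI-style `response_format` flag on the chat-completions payload, and the full message list is now sent as-is instead of a single user message. A sketch of the request body the updated driver would POST to `https://api.x.ai/v1/chat/completions`; only the keys visible in the diff are certain, while the concrete values and the token-limit parameter name are illustrative assumptions.

payload = {
    "model": "grok-code-fast-1",
    "messages": [
        {"role": "system", "content": "Answer in JSON."},
        {"role": "user", "content": "List two primes."},
    ],
    "max_tokens": 512,                           # assumed token-limit parameter name for this model
    "temperature": 0.0,                          # only included when the model supports temperature
    "response_format": {"type": "json_object"},  # added when options["json_mode"] is truthy
}
headers = {"Authorization": "Bearer <GROK_API_KEY>", "Content-Type": "application/json"}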
prompture/drivers/groq_driver.py
CHANGED

@@ -1,18 +1,22 @@
 """Groq driver for prompture.
 Requires the `groq` package. Uses GROQ_API_KEY env var.
 """
+
 import os
-from typing import Any
+from typing import Any
 
 try:
     import groq
 except Exception:
     groq = None
 
+from ..cost_mixin import CostMixin
 from ..driver import Driver
 
 
-class GroqDriver(Driver):
+class GroqDriver(CostMixin, Driver):
+    supports_json_mode = True
+
     # Approximate pricing per 1K tokens (to be updated with official pricing)
     # Each model entry defines token parameters and temperature support
     MODEL_PRICING = {

@@ -32,7 +36,7 @@ class GroqDriver(Driver):
 
     def __init__(self, api_key: str | None = None, model: str = "llama2-70b-4096"):
         """Initialize Groq driver.
-
+
         Args:
             api_key: Groq API key (defaults to GROQ_API_KEY env var)
             model: Model to use (defaults to llama2-70b-4096)

@@ -44,20 +48,16 @@ class GroqDriver(Driver):
         else:
             self.client = None
 
-
-
-
-
-
-
-
-
-
-
-        Raises:
-            RuntimeError: If groq package is not installed
-            groq.error.*: Various Groq API errors
-        """
+    supports_messages = True
+
+    def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
+        messages = [{"role": "user", "content": prompt}]
+        return self._do_generate(messages, options)
+
+    def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
+        return self._do_generate(messages, options)
+
+    def _do_generate(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
         if self.client is None:
             raise RuntimeError("groq package is not installed")
 

@@ -74,7 +74,7 @@ class GroqDriver(Driver):
         # Base kwargs for API call
         kwargs = {
             "model": model,
-            "messages":
+            "messages": messages,
         }
 
         # Set token limit with correct parameter name

@@ -84,29 +84,24 @@ class GroqDriver(Driver):
         if supports_temperature and "temperature" in opts:
             kwargs["temperature"] = opts["temperature"]
 
+        # Native JSON mode support
+        if options.get("json_mode"):
+            kwargs["response_format"] = {"type": "json_object"}
+
         try:
             resp = self.client.chat.completions.create(**kwargs)
-        except Exception
+        except Exception:
             # Re-raise any Groq API errors
             raise
 
         # Extract usage statistics
         usage = getattr(resp, "usage", None)
         prompt_tokens = getattr(usage, "prompt_tokens", 0)
-        completion_tokens = getattr(usage, "completion_tokens", 0)
+        completion_tokens = getattr(usage, "completion_tokens", 0)
         total_tokens = getattr(usage, "total_tokens", 0)
 
-        # Calculate
-
-        live_rates = get_model_rates("groq", model)
-        if live_rates:
-            prompt_cost = (prompt_tokens / 1_000_000) * live_rates["input"]
-            completion_cost = (completion_tokens / 1_000_000) * live_rates["output"]
-        else:
-            model_pricing = self.MODEL_PRICING.get(model, {"prompt": 0, "completion": 0})
-            prompt_cost = (prompt_tokens / 1000) * model_pricing["prompt"]
-            completion_cost = (completion_tokens / 1000) * model_pricing["completion"]
-        total_cost = prompt_cost + completion_cost
+        # Calculate cost via shared mixin
+        total_cost = self._calculate_cost("groq", model, prompt_tokens, completion_tokens)
 
         # Standard metadata object
         meta = {

@@ -120,4 +115,4 @@ class GroqDriver(Driver):
 
         # Extract generated text
         text = resp.choices[0].message.content
-        return {"text": text, "meta": meta}
+        return {"text": text, "meta": meta}
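All four drivers above now delegate pricing to `CostMixin._calculate_cost(provider, model, prompt_tokens, completion_tokens)`. The mixin itself (prompture/cost_mixin.py, +51 lines) is not shown in this diff; the following is a rough reconstruction inferred from the per-driver code it replaces and from the `_PRICING_UNIT` overrides (1,000,000 for Grok and Google versus the per-1K tables in Claude and Groq). Treat it as an assumption, not the published implementation.

from typing import ClassVar


class CostMixin:
    # Assumed shape: each driver supplies MODEL_PRICING and may override _PRICING_UNIT.
    MODEL_PRICING: ClassVar[dict[str, dict[str, float]]] = {}
    _PRICING_UNIT: ClassVar[int] = 1000  # per-1K-token prices unless a driver overrides it

    def _calculate_cost(
        self, provider: str, model: str, prompt_tokens: int, completion_tokens: int
    ) -> float:
        # get_model_rates is used this way in the old per-driver code shown above.
        from prompture.model_rates import get_model_rates

        live_rates = get_model_rates(provider, model)  # live per-1M-token rates, if available
        if live_rates:
            prompt_cost = (prompt_tokens / 1_000_000) * live_rates["input"]
            completion_cost = (completion_tokens / 1_000_000) * live_rates["output"]
        else:
            pricing = self.MODEL_PRICING.get(model, {"prompt": 0, "completion": 0})
            prompt_cost = (prompt_tokens / self._PRICING_UNIT) * pricing["prompt"]
            completion_cost = (completion_tokens / self._PRICING_UNIT) * pricing["completion"]
        return prompt_cost + completion_cost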