prompture 0.0.33.dev1__py3-none-any.whl → 0.0.34__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- prompture/__init__.py +133 -49
- prompture/_version.py +34 -0
- prompture/aio/__init__.py +74 -0
- prompture/async_conversation.py +484 -0
- prompture/async_core.py +803 -0
- prompture/async_driver.py +131 -0
- prompture/cache.py +469 -0
- prompture/callbacks.py +50 -0
- prompture/cli.py +7 -3
- prompture/conversation.py +504 -0
- prompture/core.py +475 -352
- prompture/cost_mixin.py +51 -0
- prompture/discovery.py +50 -35
- prompture/driver.py +125 -5
- prompture/drivers/__init__.py +171 -73
- prompture/drivers/airllm_driver.py +13 -20
- prompture/drivers/async_airllm_driver.py +26 -0
- prompture/drivers/async_azure_driver.py +117 -0
- prompture/drivers/async_claude_driver.py +107 -0
- prompture/drivers/async_google_driver.py +132 -0
- prompture/drivers/async_grok_driver.py +91 -0
- prompture/drivers/async_groq_driver.py +84 -0
- prompture/drivers/async_hugging_driver.py +61 -0
- prompture/drivers/async_lmstudio_driver.py +79 -0
- prompture/drivers/async_local_http_driver.py +44 -0
- prompture/drivers/async_ollama_driver.py +125 -0
- prompture/drivers/async_openai_driver.py +96 -0
- prompture/drivers/async_openrouter_driver.py +96 -0
- prompture/drivers/async_registry.py +129 -0
- prompture/drivers/azure_driver.py +36 -9
- prompture/drivers/claude_driver.py +86 -34
- prompture/drivers/google_driver.py +87 -51
- prompture/drivers/grok_driver.py +29 -32
- prompture/drivers/groq_driver.py +27 -26
- prompture/drivers/hugging_driver.py +6 -6
- prompture/drivers/lmstudio_driver.py +26 -13
- prompture/drivers/local_http_driver.py +6 -6
- prompture/drivers/ollama_driver.py +90 -23
- prompture/drivers/openai_driver.py +36 -9
- prompture/drivers/openrouter_driver.py +31 -25
- prompture/drivers/registry.py +306 -0
- prompture/field_definitions.py +106 -96
- prompture/logging.py +80 -0
- prompture/model_rates.py +217 -0
- prompture/runner.py +49 -47
- prompture/session.py +117 -0
- prompture/settings.py +14 -1
- prompture/tools.py +172 -265
- prompture/validator.py +3 -3
- {prompture-0.0.33.dev1.dist-info → prompture-0.0.34.dist-info}/METADATA +18 -20
- prompture-0.0.34.dist-info/RECORD +55 -0
- prompture-0.0.33.dev1.dist-info/RECORD +0 -29
- {prompture-0.0.33.dev1.dist-info → prompture-0.0.34.dist-info}/WHEEL +0 -0
- {prompture-0.0.33.dev1.dist-info → prompture-0.0.34.dist-info}/entry_points.txt +0 -0
- {prompture-0.0.33.dev1.dist-info → prompture-0.0.34.dist-info}/licenses/LICENSE +0 -0
- {prompture-0.0.33.dev1.dist-info → prompture-0.0.34.dist-info}/top_level.txt +0 -0
prompture/drivers/azure_driver.py
CHANGED
@@ -1,17 +1,23 @@
 """Driver for Azure OpenAI Service (migrated to openai>=1.0.0).
 Requires the `openai` package.
 """
+
 import os
-from typing import Any
+from typing import Any
+
 try:
     from openai import AzureOpenAI
 except Exception:
     AzureOpenAI = None

+from ..cost_mixin import CostMixin
 from ..driver import Driver


-class AzureDriver(Driver):
+class AzureDriver(CostMixin, Driver):
+    supports_json_mode = True
+    supports_json_schema = True
+
     # Pricing per 1K tokens (adjust if your Azure pricing differs from OpenAI defaults)
     MODEL_PRICING = {
         "gpt-5-mini": {
@@ -82,7 +88,16 @@ class AzureDriver(Driver):
         else:
             self.client = None

-    def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
+    supports_messages = True
+
+    def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
+        messages = [{"role": "user", "content": prompt}]
+        return self._do_generate(messages, options)
+
+    def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
+        return self._do_generate(messages, options)
+
+    def _do_generate(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
         if self.client is None:
             raise RuntimeError("openai package (>=1.0.0) with AzureOpenAI not installed")

@@ -96,13 +111,28 @@ class AzureDriver(Driver):
         # Build request kwargs
         kwargs = {
             "model": self.deployment_id,  # for Azure, use deployment name
-            "messages": [{"role": "user", "content": prompt}],
+            "messages": messages,
         }
         kwargs[tokens_param] = opts.get("max_tokens", 512)

         if supports_temperature and "temperature" in opts:
             kwargs["temperature"] = opts["temperature"]

+        # Native JSON mode support
+        if options.get("json_mode"):
+            json_schema = options.get("json_schema")
+            if json_schema:
+                kwargs["response_format"] = {
+                    "type": "json_schema",
+                    "json_schema": {
+                        "name": "extraction",
+                        "strict": True,
+                        "schema": json_schema,
+                    },
+                }
+            else:
+                kwargs["response_format"] = {"type": "json_object"}
+
         resp = self.client.chat.completions.create(**kwargs)

         # Extract usage
@@ -111,11 +141,8 @@ class AzureDriver(Driver):
         completion_tokens = getattr(usage, "completion_tokens", 0)
         total_tokens = getattr(usage, "total_tokens", 0)

-        # Calculate cost
-
-        prompt_cost = (prompt_tokens / 1000) * model_pricing["prompt"]
-        completion_cost = (completion_tokens / 1000) * model_pricing["completion"]
-        total_cost = prompt_cost + completion_cost
+        # Calculate cost via shared mixin
+        total_cost = self._calculate_cost("azure", model, prompt_tokens, completion_tokens)

         # Standardized meta object
         meta = {
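
Note: all four drivers in this release delegate pricing to the new prompture/cost_mixin.py (+51 lines), whose body is not included in this diff. Below is a hypothetical sketch of _calculate_cost, inferred only from the call sites, each driver's MODEL_PRICING table, and the _PRICING_UNIT = 1_000_000 override visible in the Google and Grok drivers; given the new prompture/model_rates.py (+217 lines) and the provider argument, the real implementation likely also consults live per-provider rates first, the way GoogleDriver._calculate_cost_chars does further down.

# Hypothetical reconstruction -- NOT the actual prompture/cost_mixin.py.
class CostMixin:
    # Azure and Claude tables are priced per 1K tokens (the assumed default);
    # Google and Grok override with _PRICING_UNIT = 1_000_000 for per-1M tables.
    _PRICING_UNIT = 1000
    MODEL_PRICING: dict[str, dict[str, float]] = {}

    def _calculate_cost(
        self, provider: str, model: str, prompt_tokens: int, completion_tokens: int
    ) -> float:
        # `provider` presumably keys a live-rate lookup (model_rates), omitted here.
        # Fall back to zero-cost rates for unknown models instead of raising.
        pricing = self.MODEL_PRICING.get(model, {"prompt": 0.0, "completion": 0.0})
        prompt_cost = (prompt_tokens / self._PRICING_UNIT) * pricing["prompt"]
        completion_cost = (completion_tokens / self._PRICING_UNIT) * pricing["completion"]
        return prompt_cost + completion_cost

Keeping the table on each driver class while sharing only the arithmetic would explain why every driver's MODEL_PRICING survives this refactor unchanged.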
prompture/drivers/claude_driver.py
CHANGED
@@ -1,75 +1,128 @@
 """Driver for Anthropic's Claude models. Requires the `anthropic` library.
 Use with API key in CLAUDE_API_KEY env var or provide directly.
 """
+
+import json
 import os
-from typing import Any
+from typing import Any
+
 try:
     import anthropic
 except Exception:
     anthropic = None

+from ..cost_mixin import CostMixin
 from ..driver import Driver

-class ClaudeDriver(Driver):
+
+class ClaudeDriver(CostMixin, Driver):
+    supports_json_mode = True
+    supports_json_schema = True
+
     # Claude pricing per 1000 tokens (prices should be kept current with Anthropic's pricing)
     MODEL_PRICING = {
         # Claude Opus 4.1
         "claude-opus-4-1-20250805": {
-            "prompt": 0.015,
-            "completion": 0.075,
+            "prompt": 0.015,  # $15 per 1M prompt tokens
+            "completion": 0.075,  # $75 per 1M completion tokens
         },
         # Claude Opus 4.0
         "claude-opus-4-20250514": {
-            "prompt": 0.015,
-            "completion": 0.075,
+            "prompt": 0.015,  # $15 per 1M prompt tokens
+            "completion": 0.075,  # $75 per 1M completion tokens
         },
         # Claude Sonnet 4.0
         "claude-sonnet-4-20250514": {
-            "prompt": 0.003,
-            "completion": 0.015,
+            "prompt": 0.003,  # $3 per 1M prompt tokens
+            "completion": 0.015,  # $15 per 1M completion tokens
         },
         # Claude Sonnet 3.7
         "claude-3-7-sonnet-20250219": {
-            "prompt": 0.003,
-            "completion": 0.015,
+            "prompt": 0.003,  # $3 per 1M prompt tokens
+            "completion": 0.015,  # $15 per 1M completion tokens
         },
         # Claude Haiku 3.5
         "claude-3-5-haiku-20241022": {
-            "prompt": 0.0008,
-            "completion": 0.004,
-        }
+            "prompt": 0.0008,  # $0.80 per 1M prompt tokens
+            "completion": 0.004,  # $4 per 1M completion tokens
+        },
     }

     def __init__(self, api_key: str | None = None, model: str = "claude-3-5-haiku-20241022"):
         self.api_key = api_key or os.getenv("CLAUDE_API_KEY")
         self.model = model or os.getenv("CLAUDE_MODEL_NAME", "claude-3-5-haiku-20241022")

-    def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
+    supports_messages = True
+
+    def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
+        messages = [{"role": "user", "content": prompt}]
+        return self._do_generate(messages, options)
+
+    def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
+        return self._do_generate(messages, options)
+
+    def _do_generate(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
         if anthropic is None:
             raise RuntimeError("anthropic package not installed")
-
+
         opts = {**{"temperature": 0.0, "max_tokens": 512}, **options}
         model = options.get("model", self.model)
-
+
         client = anthropic.Anthropic(api_key=self.api_key)
-
-
-
-
-
-
-
+
+        # Anthropic requires system messages as a top-level parameter
+        system_content = None
+        api_messages = []
+        for msg in messages:
+            if msg.get("role") == "system":
+                system_content = msg.get("content", "")
+            else:
+                api_messages.append(msg)
+
+        # Build common kwargs
+        common_kwargs: dict[str, Any] = {
+            "model": model,
+            "messages": api_messages,
+            "temperature": opts["temperature"],
+            "max_tokens": opts["max_tokens"],
+        }
+        if system_content:
+            common_kwargs["system"] = system_content
+
+        # Native JSON mode: use tool-use for schema enforcement
+        if options.get("json_mode"):
+            json_schema = options.get("json_schema")
+            if json_schema:
+                tool_def = {
+                    "name": "extract_json",
+                    "description": "Extract structured data matching the schema",
+                    "input_schema": json_schema,
+                }
+                resp = client.messages.create(
+                    **common_kwargs,
+                    tools=[tool_def],
+                    tool_choice={"type": "tool", "name": "extract_json"},
+                )
+                text = ""
+                for block in resp.content:
+                    if block.type == "tool_use":
+                        text = json.dumps(block.input)
+                        break
+            else:
+                resp = client.messages.create(**common_kwargs)
+                text = resp.content[0].text
+        else:
+            resp = client.messages.create(**common_kwargs)
+            text = resp.content[0].text
+
         # Extract token usage from Claude response
         prompt_tokens = resp.usage.input_tokens
         completion_tokens = resp.usage.output_tokens
         total_tokens = prompt_tokens + completion_tokens
-
-        # Calculate cost
-
-        prompt_cost = (prompt_tokens / 1000) * model_pricing["prompt"]
-        completion_cost = (completion_tokens / 1000) * model_pricing["completion"]
-        total_cost = prompt_cost + completion_cost
-
+
+        # Calculate cost via shared mixin
+        total_cost = self._calculate_cost("claude", model, prompt_tokens, completion_tokens)
+
         # Create standardized meta object
         meta = {
             "prompt_tokens": prompt_tokens,
@@ -77,8 +130,7 @@ class ClaudeDriver(Driver):
             "total_tokens": total_tokens,
             "cost": round(total_cost, 6),  # Round to 6 decimal places
             "raw_response": dict(resp),
-            "model_name": model
+            "model_name": model,
         }
-
-        text = resp.content[0].text
-        return {"text": text, "meta": meta}
+
+        return {"text": text, "meta": meta}
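
The tool-use branch above is the notable change: rather than asking Claude to emit JSON in free text, the driver registers a single extract_json tool whose input_schema is the caller's JSON schema, forces it with tool_choice, and serializes the tool call's input back into the returned text. A usage sketch follows; the import path and constructor defaults are assumptions, while ClaudeDriver, the json_mode/json_schema option keys, and the {"text": ..., "meta": ...} return shape are confirmed by the diff.

# Usage sketch, assuming CLAUDE_API_KEY is set in the environment.
from prompture.drivers.claude_driver import ClaudeDriver

schema = {
    "type": "object",
    "properties": {"name": {"type": "string"}, "age": {"type": "integer"}},
    "required": ["name", "age"],
}

driver = ClaudeDriver()
result = driver.generate_messages(
    [
        # The system message is hoisted into Anthropic's top-level `system` param.
        {"role": "system", "content": "Extract the person described."},
        {"role": "user", "content": "Ana is a 31-year-old engineer."},
    ],
    {"json_mode": True, "json_schema": schema},
)
print(result["text"])          # e.g. '{"name": "Ana", "age": 31}'
print(result["meta"]["cost"])  # computed via the shared CostMixin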
prompture/drivers/google_driver.py
CHANGED
@@ -1,60 +1,55 @@
-import os
 import logging
-import google.generativeai as genai
-from typing import Any, Dict
-from ..driver import Driver
-
 import os
-import
+from typing import Any, Optional
+
 import google.generativeai as genai
-
+
+from ..cost_mixin import CostMixin
 from ..driver import Driver

 logger = logging.getLogger(__name__)


-class GoogleDriver(Driver):
+class GoogleDriver(CostMixin, Driver):
     """Driver for Google's Generative AI API (Gemini)."""

+    supports_json_mode = True
+    supports_json_schema = True
+
     # Based on current Gemini pricing (as of 2025)
     # Source: https://cloud.google.com/vertex-ai/pricing#gemini_models
+    _PRICING_UNIT = 1_000_000
     MODEL_PRICING = {
         "gemini-1.5-pro": {
             "prompt": 0.00025,  # $0.25/1M chars input
-            "completion": 0.0005  # $0.50/1M chars output
+            "completion": 0.0005,  # $0.50/1M chars output
         },
         "gemini-1.5-pro-vision": {
             "prompt": 0.00025,  # $0.25/1M chars input
-            "completion": 0.0005  # $0.50/1M chars output
+            "completion": 0.0005,  # $0.50/1M chars output
         },
         "gemini-2.5-pro": {
             "prompt": 0.0004,  # $0.40/1M chars input
-            "completion": 0.0008  # $0.80/1M chars output
+            "completion": 0.0008,  # $0.80/1M chars output
         },
         "gemini-2.5-flash": {
             "prompt": 0.0004,  # $0.40/1M chars input
-            "completion": 0.0008  # $0.80/1M chars output
+            "completion": 0.0008,  # $0.80/1M chars output
         },
         "gemini-2.5-flash-lite": {
             "prompt": 0.0002,  # $0.20/1M chars input
-            "completion": 0.0004  # $0.40/1M chars output
+            "completion": 0.0004,  # $0.40/1M chars output
         },
-        "gemini-2.0-flash": {
+        "gemini-2.0-flash": {
             "prompt": 0.0004,  # $0.40/1M chars input
-            "completion": 0.0008  # $0.80/1M chars output
+            "completion": 0.0008,  # $0.80/1M chars output
         },
         "gemini-2.0-flash-lite": {
             "prompt": 0.0002,  # $0.20/1M chars input
-            "completion": 0.0004  # $0.40/1M chars output
-        },
-        "gemini-1.5-flash": {
-            "prompt": 0.00001875,
-            "completion": 0.000075
+            "completion": 0.0004,  # $0.40/1M chars output
         },
-        "gemini-1.5-flash-8b": {
-            "prompt": 0.00001,
-            "completion": 0.00004
-        }
+        "gemini-1.5-flash": {"prompt": 0.00001875, "completion": 0.000075},
+        "gemini-1.5-flash-8b": {"prompt": 0.00001, "completion": 0.00004},
     }

     def __init__(self, api_key: str | None = None, model: str = "gemini-1.5-pro"):
@@ -75,8 +70,8 @@ class GoogleDriver(Driver):

         # Configure google.generativeai
         genai.configure(api_key=self.api_key)
-        self.options: Dict[str, Any] = {}
-
+        self.options: dict[str, Any] = {}
+
         # Validate connection and model availability
         self._validate_connection()

@@ -90,16 +85,36 @@ class GoogleDriver(Driver):
             logger.warning(f"Could not validate connection to Google API: {e}")
             raise

-    def
-        """
+    def _calculate_cost_chars(self, prompt_chars: int, completion_chars: int) -> float:
+        """Calculate cost from character counts.

-
-
-            options: Additional options to pass to the model
-
-        Returns:
-            Dict containing generated text and metadata
+        Live rates use token-based pricing (estimate ~4 chars/token).
+        Hardcoded MODEL_PRICING uses per-1M-character rates.
         """
+        from ..model_rates import get_model_rates
+
+        live_rates = get_model_rates("google", self.model)
+        if live_rates:
+            est_prompt_tokens = prompt_chars / 4
+            est_completion_tokens = completion_chars / 4
+            prompt_cost = (est_prompt_tokens / 1_000_000) * live_rates["input"]
+            completion_cost = (est_completion_tokens / 1_000_000) * live_rates["output"]
+        else:
+            model_pricing = self.MODEL_PRICING.get(self.model, {"prompt": 0, "completion": 0})
+            prompt_cost = (prompt_chars / 1_000_000) * model_pricing["prompt"]
+            completion_cost = (completion_chars / 1_000_000) * model_pricing["completion"]
+        return round(prompt_cost + completion_cost, 6)
+
+    supports_messages = True
+
+    def generate(self, prompt: str, options: Optional[dict[str, Any]] = None) -> dict[str, Any]:
+        messages = [{"role": "user", "content": prompt}]
+        return self._do_generate(messages, options)
+
+    def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
+        return self._do_generate(messages, options)
+
+    def _do_generate(self, messages: list[dict[str, str]], options: Optional[dict[str, Any]] = None) -> dict[str, Any]:
         merged_options = self.options.copy()
         if options:
             merged_options.update(options)
@@ -107,7 +122,7 @@ class GoogleDriver(Driver):
         # Extract specific options for Google's API
         generation_config = merged_options.get("generation_config", {})
         safety_settings = merged_options.get("safety_settings", {})
-
+
         # Map common options to generation_config if not present
         if "temperature" in merged_options and "temperature" not in generation_config:
             generation_config["temperature"] = merged_options["temperature"]
@@ -118,36 +133,57 @@ class GoogleDriver(Driver):
         if "top_k" in merged_options and "top_k" not in generation_config:
             generation_config["top_k"] = merged_options["top_k"]

+        # Native JSON mode support
+        if merged_options.get("json_mode"):
+            generation_config["response_mime_type"] = "application/json"
+            json_schema = merged_options.get("json_schema")
+            if json_schema:
+                generation_config["response_schema"] = json_schema
+
+        # Convert messages to Gemini format
+        system_instruction = None
+        contents: list[dict[str, Any]] = []
+        for msg in messages:
+            role = msg.get("role", "user")
+            content = msg.get("content", "")
+            if role == "system":
+                system_instruction = content
+            else:
+                # Gemini uses "model" for assistant role
+                gemini_role = "model" if role == "assistant" else "user"
+                contents.append({"role": gemini_role, "parts": [content]})
+
         try:
             logger.debug(f"Initializing {self.model} for generation")
-            model = genai.GenerativeModel(self.model)
+            model_kwargs: dict[str, Any] = {}
+            if system_instruction:
+                model_kwargs["system_instruction"] = system_instruction
+            model = genai.GenerativeModel(self.model, **model_kwargs)

             # Generate response
-            logger.debug(f"Generating with
+            logger.debug(f"Generating with {len(contents)} content parts")
+            # If single user message, pass content directly for backward compatibility
+            gen_input: Any = contents if len(contents) != 1 else contents[0]["parts"][0]
             response = model.generate_content(
-
+                gen_input,
                 generation_config=generation_config if generation_config else None,
-                safety_settings=safety_settings if safety_settings else None
+                safety_settings=safety_settings if safety_settings else None,
             )
-
+
             if not response.text:
                 raise ValueError("Empty response from model")

             # Calculate token usage and cost
-
-            prompt_chars = len(prompt)
+            total_prompt_chars = sum(len(msg.get("content", "")) for msg in messages)
             completion_chars = len(response.text)
-
-            #
-
-            prompt_cost = (prompt_chars / 1_000_000) * model_pricing["prompt"]
-            completion_cost = (completion_chars / 1_000_000) * model_pricing["completion"]
-            total_cost = prompt_cost + completion_cost
+
+            # Google uses character-based cost estimation
+            total_cost = self._calculate_cost_chars(total_prompt_chars, completion_chars)

             meta = {
-                "prompt_chars": prompt_chars,
+                "prompt_chars": total_prompt_chars,
                 "completion_chars": completion_chars,
-                "total_chars": prompt_chars + completion_chars,
+                "total_chars": total_prompt_chars + completion_chars,
                 "cost": total_cost,
                 "raw_response": response.prompt_feedback if hasattr(response, "prompt_feedback") else None,
                 "model_name": self.model,
@@ -157,4 +193,4 @@ class GoogleDriver(Driver):

         except Exception as e:
             logger.error(f"Google API request failed: {e}")
-            raise RuntimeError(f"Google API request failed: {e}")
+            raise RuntimeError(f"Google API request failed: {e}") from e
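
The character-based fallback in _calculate_cost_chars is straightforward arithmetic; here is a worked example with the gemini-1.5-pro rates, exactly as the code computes them. (One unit caveat worth flagging: 0.00025 against a 1_000_000-character divisor yields $0.00025 per 1M characters, so the table values reconcile with their "$0.25/1M chars" comments only if read as per-1K-character rates.)

# Fallback path of _calculate_cost_chars for model="gemini-1.5-pro".
prompt_chars = 8_000        # ~2,000 tokens at the ~4 chars/token estimate
completion_chars = 2_000    # ~500 tokens
prompt_cost = (prompt_chars / 1_000_000) * 0.00025         # 0.000002
completion_cost = (completion_chars / 1_000_000) * 0.0005  # 0.000001
total_cost = round(prompt_cost + completion_cost, 6)       # 0.000003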
prompture/drivers/grok_driver.py
CHANGED
@@ -1,15 +1,21 @@
 """xAI Grok driver.
 Requires the `requests` package. Uses GROK_API_KEY env var.
 """
+
 import os
-from typing import Any
+from typing import Any
+
 import requests

+from ..cost_mixin import CostMixin
 from ..driver import Driver


-class GrokDriver(Driver):
+class GrokDriver(CostMixin, Driver):
+    supports_json_mode = True
+
     # Pricing per 1M tokens based on xAI's documentation
+    _PRICING_UNIT = 1_000_000
     MODEL_PRICING = {
         "grok-code-fast-1": {
             "prompt": 0.20,
@@ -72,19 +78,16 @@ class GrokDriver(Driver):
         self.model = model
         self.api_base = "https://api.x.ai/v1"

-    def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
-        """Generate completion using Grok API.
+    supports_messages = True

-
-
-
-
-
-
-
-
-            RuntimeError: If API key is missing or request fails
-        """
+    def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
+        messages = [{"role": "user", "content": prompt}]
+        return self._do_generate(messages, options)
+
+    def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
+        return self._do_generate(messages, options)
+
+    def _do_generate(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
         if not self.api_key:
             raise RuntimeError("GROK_API_KEY environment variable is required")

@@ -101,7 +104,7 @@ class GrokDriver(Driver):
         # Base request payload
         payload = {
             "model": model,
-            "messages": [{"role": "user", "content": prompt}],
+            "messages": messages,
         }

         # Add token limit with correct parameter name
@@ -111,33 +114,27 @@ class GrokDriver(Driver):
         if supports_temperature and "temperature" in opts:
             payload["temperature"] = opts["temperature"]

-        headers = {
-            "Authorization": f"Bearer {self.api_key}",
-            "Content-Type": "application/json"
-        }
+        # Native JSON mode support
+        if options.get("json_mode"):
+            payload["response_format"] = {"type": "json_object"}
+
+        headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}

         try:
-            response = requests.post(
-                f"{self.api_base}/chat/completions",
-                headers=headers,
-                json=payload
-            )
+            response = requests.post(f"{self.api_base}/chat/completions", headers=headers, json=payload)
             response.raise_for_status()
             resp = response.json()
         except requests.exceptions.RequestException as e:
-            raise RuntimeError(f"Grok API request failed: {str(e)}")
+            raise RuntimeError(f"Grok API request failed: {e!s}") from e

         # Extract usage info
         usage = resp.get("usage", {})
         prompt_tokens = usage.get("prompt_tokens", 0)
-        completion_tokens = usage.get("completion_tokens", 0)
+        completion_tokens = usage.get("completion_tokens", 0)
         total_tokens = usage.get("total_tokens", 0)

-        # Calculate cost
-
-        prompt_cost = (prompt_tokens / 1000000) * model_pricing["prompt"]
-        completion_cost = (completion_tokens / 1000000) * model_pricing["completion"]
-        total_cost = prompt_cost + completion_cost
+        # Calculate cost via shared mixin
+        total_cost = self._calculate_cost("grok", model, prompt_tokens, completion_tokens)

         # Standardized meta object
         meta = {
@@ -150,4 +147,4 @@ class GrokDriver(Driver):
         }

         text = resp["choices"][0]["message"]["content"]
-        return {"text": text, "meta": meta}
+        return {"text": text, "meta": meta}
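
Grok's JSON support is the thinnest of the four drivers: supports_json_schema is never set, so json_mode falls through to OpenAI-style {"type": "json_object"} with no schema enforcement. A usage sketch under assumed wiring (the diff confirms GrokDriver, the json_mode option, and the return shape; the constructor defaults are not shown):

# Usage sketch, assuming GROK_API_KEY is set in the environment.
from prompture.drivers.grok_driver import GrokDriver

driver = GrokDriver()  # assumed no-arg construction; __init__ is only partially shown
result = driver.generate(
    "Return a JSON object with keys 'city' and 'country' for Paris.",
    {"json_mode": True, "max_tokens": 128},
)
print(result["text"])                  # a JSON string, enforced server-side
print(result["meta"]["total_tokens"])  # usage as reported by the xAI API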