prompture 0.0.29.dev8__py3-none-any.whl → 0.0.35__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (67)
  1. prompture/__init__.py +146 -23
  2. prompture/_version.py +34 -0
  3. prompture/aio/__init__.py +74 -0
  4. prompture/async_conversation.py +607 -0
  5. prompture/async_core.py +803 -0
  6. prompture/async_driver.py +169 -0
  7. prompture/cache.py +469 -0
  8. prompture/callbacks.py +55 -0
  9. prompture/cli.py +63 -4
  10. prompture/conversation.py +631 -0
  11. prompture/core.py +876 -263
  12. prompture/cost_mixin.py +51 -0
  13. prompture/discovery.py +164 -0
  14. prompture/driver.py +168 -5
  15. prompture/drivers/__init__.py +173 -69
  16. prompture/drivers/airllm_driver.py +109 -0
  17. prompture/drivers/async_airllm_driver.py +26 -0
  18. prompture/drivers/async_azure_driver.py +117 -0
  19. prompture/drivers/async_claude_driver.py +107 -0
  20. prompture/drivers/async_google_driver.py +132 -0
  21. prompture/drivers/async_grok_driver.py +91 -0
  22. prompture/drivers/async_groq_driver.py +84 -0
  23. prompture/drivers/async_hugging_driver.py +61 -0
  24. prompture/drivers/async_lmstudio_driver.py +79 -0
  25. prompture/drivers/async_local_http_driver.py +44 -0
  26. prompture/drivers/async_ollama_driver.py +125 -0
  27. prompture/drivers/async_openai_driver.py +96 -0
  28. prompture/drivers/async_openrouter_driver.py +96 -0
  29. prompture/drivers/async_registry.py +129 -0
  30. prompture/drivers/azure_driver.py +36 -9
  31. prompture/drivers/claude_driver.py +251 -34
  32. prompture/drivers/google_driver.py +107 -38
  33. prompture/drivers/grok_driver.py +29 -32
  34. prompture/drivers/groq_driver.py +27 -26
  35. prompture/drivers/hugging_driver.py +6 -6
  36. prompture/drivers/lmstudio_driver.py +26 -13
  37. prompture/drivers/local_http_driver.py +6 -6
  38. prompture/drivers/ollama_driver.py +157 -23
  39. prompture/drivers/openai_driver.py +178 -9
  40. prompture/drivers/openrouter_driver.py +31 -25
  41. prompture/drivers/registry.py +306 -0
  42. prompture/field_definitions.py +106 -96
  43. prompture/logging.py +80 -0
  44. prompture/model_rates.py +217 -0
  45. prompture/runner.py +49 -47
  46. prompture/scaffold/__init__.py +1 -0
  47. prompture/scaffold/generator.py +84 -0
  48. prompture/scaffold/templates/Dockerfile.j2 +12 -0
  49. prompture/scaffold/templates/README.md.j2 +41 -0
  50. prompture/scaffold/templates/config.py.j2 +21 -0
  51. prompture/scaffold/templates/env.example.j2 +8 -0
  52. prompture/scaffold/templates/main.py.j2 +86 -0
  53. prompture/scaffold/templates/models.py.j2 +40 -0
  54. prompture/scaffold/templates/requirements.txt.j2 +5 -0
  55. prompture/server.py +183 -0
  56. prompture/session.py +117 -0
  57. prompture/settings.py +18 -1
  58. prompture/tools.py +219 -267
  59. prompture/tools_schema.py +254 -0
  60. prompture/validator.py +3 -3
  61. {prompture-0.0.29.dev8.dist-info → prompture-0.0.35.dist-info}/METADATA +117 -21
  62. prompture-0.0.35.dist-info/RECORD +66 -0
  63. {prompture-0.0.29.dev8.dist-info → prompture-0.0.35.dist-info}/WHEEL +1 -1
  64. prompture-0.0.29.dev8.dist-info/RECORD +0 -27
  65. {prompture-0.0.29.dev8.dist-info → prompture-0.0.35.dist-info}/entry_points.txt +0 -0
  66. {prompture-0.0.29.dev8.dist-info → prompture-0.0.35.dist-info}/licenses/LICENSE +0 -0
  67. {prompture-0.0.29.dev8.dist-info → prompture-0.0.35.dist-info}/top_level.txt +0 -0
prompture/drivers/airllm_driver.py
@@ -0,0 +1,109 @@
+ import logging
+ from typing import Any, Optional
+
+ from ..driver import Driver
+
+ logger = logging.getLogger(__name__)
+
+
+ class AirLLMDriver(Driver):
+     """Driver for AirLLM — run large models (70B+) on consumer GPUs via
+     layer-by-layer memory management.
+
+     The ``airllm`` package is a lazy dependency: it is imported on first
+     ``generate()`` call so the rest of Prompture works without it installed.
+     """
+
+     MODEL_PRICING = {"default": {"prompt": 0.0, "completion": 0.0}}
+
+     def __init__(self, model: str = "meta-llama/Llama-2-7b-hf", compression: Optional[str] = None):
+         """
+         Args:
+             model: HuggingFace repo ID (e.g. ``"meta-llama/Llama-2-70b-hf"``).
+             compression: Optional quantization mode — ``"4bit"`` or ``"8bit"``.
+         """
+         self.model = model
+         self.compression = compression
+         self.options: dict[str, Any] = {}
+         self._llm = None
+         self._tokenizer = None
+
+     # ------------------------------------------------------------------
+     # Lazy model loading
+     # ------------------------------------------------------------------
+     def _ensure_loaded(self):
+         """Load the AirLLM model and tokenizer on first use."""
+         if self._llm is not None:
+             return
+
+         try:
+             from airllm import AutoModel
+         except ImportError:
+             raise ImportError(
+                 "The 'airllm' package is required for the AirLLM driver. Install it with: pip install prompture[airllm]"
+             ) from None
+
+         try:
+             from transformers import AutoTokenizer
+         except ImportError:
+             raise ImportError(
+                 "The 'transformers' package is required for the AirLLM driver. "
+                 "Install it with: pip install transformers"
+             ) from None
+
+         logger.info(f"Loading AirLLM model: {self.model} (compression={self.compression})")
+
+         load_kwargs: dict[str, Any] = {}
+         if self.compression:
+             load_kwargs["compression"] = self.compression
+
+         self._llm = AutoModel.from_pretrained(self.model, **load_kwargs)
+         self._tokenizer = AutoTokenizer.from_pretrained(self.model)
+         logger.info("AirLLM model loaded successfully")
+
+     # ------------------------------------------------------------------
+     # Driver interface
+     # ------------------------------------------------------------------
+     def generate(self, prompt: str, options: Optional[dict[str, Any]] = None) -> dict[str, Any]:
+         self._ensure_loaded()
+
+         merged_options = self.options.copy()
+         if options:
+             merged_options.update(options)
+
+         max_new_tokens = merged_options.get("max_new_tokens", 256)
+
+         # Tokenize
+         input_ids = self._tokenizer(prompt, return_tensors="pt").input_ids
+
+         prompt_tokens = input_ids.shape[1]
+
+         logger.debug(f"AirLLM generating with max_new_tokens={max_new_tokens}, prompt_tokens={prompt_tokens}")
+
+         # Generate
+         output_ids = self._llm.generate(
+             input_ids,
+             max_new_tokens=max_new_tokens,
+         )
+
+         # Decode only the newly generated tokens (strip the prompt prefix)
+         new_tokens = output_ids[0, prompt_tokens:]
+         completion_tokens = len(new_tokens)
+         text = self._tokenizer.decode(new_tokens, skip_special_tokens=True)
+
+         total_tokens = prompt_tokens + completion_tokens
+
+         meta = {
+             "prompt_tokens": prompt_tokens,
+             "completion_tokens": completion_tokens,
+             "total_tokens": total_tokens,
+             "cost": 0.0,
+             "raw_response": {
+                 "model": self.model,
+                 "compression": self.compression,
+                 "max_new_tokens": max_new_tokens,
+             },
+             "model_name": self.model,
+         }
+
+         return {"text": text, "meta": meta}
prompture/drivers/async_airllm_driver.py
@@ -0,0 +1,26 @@
+ """Async AirLLM driver — wraps the sync GPU-bound driver with asyncio.to_thread."""
+
+ from __future__ import annotations
+
+ import asyncio
+ from typing import Any
+
+ from ..async_driver import AsyncDriver
+ from .airllm_driver import AirLLMDriver
+
+
+ class AsyncAirLLMDriver(AsyncDriver):
+     """Async wrapper around :class:`AirLLMDriver`.
+
+     AirLLM is GPU-bound with no native async API, so we delegate to
+     ``asyncio.to_thread()`` to avoid blocking the event loop.
+     """
+
+     MODEL_PRICING = AirLLMDriver.MODEL_PRICING
+
+     def __init__(self, model: str = "meta-llama/Llama-2-7b-hf", compression: str | None = None):
+         self.model = model
+         self._sync_driver = AirLLMDriver(model=model, compression=compression)
+
+     async def generate(self, prompt: str, options: dict[str, Any] | None = None) -> dict[str, Any]:
+         return await asyncio.to_thread(self._sync_driver.generate, prompt, options)
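
The async wrapper keeps the same contract; a sketch under the same assumptions:

# Hypothetical usage of the async wrapper.
import asyncio

from prompture.drivers.async_airllm_driver import AsyncAirLLMDriver

async def main() -> None:
    driver = AsyncAirLLMDriver(model="meta-llama/Llama-2-7b-hf")
    # The sync driver runs inside asyncio.to_thread, so the event loop stays responsive.
    result = await driver.generate("Hello", {"max_new_tokens": 32})
    print(result["text"])

asyncio.run(main())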
prompture/drivers/async_azure_driver.py
@@ -0,0 +1,117 @@
+ """Async Azure OpenAI driver. Requires the ``openai`` package (>=1.0.0)."""
+
+ from __future__ import annotations
+
+ import os
+ from typing import Any
+
+ try:
+     from openai import AsyncAzureOpenAI
+ except Exception:
+     AsyncAzureOpenAI = None
+
+ from ..async_driver import AsyncDriver
+ from ..cost_mixin import CostMixin
+ from .azure_driver import AzureDriver
+
+
+ class AsyncAzureDriver(CostMixin, AsyncDriver):
+     supports_json_mode = True
+     supports_json_schema = True
+
+     MODEL_PRICING = AzureDriver.MODEL_PRICING
+
+     def __init__(
+         self,
+         api_key: str | None = None,
+         endpoint: str | None = None,
+         deployment_id: str | None = None,
+         model: str = "gpt-4o-mini",
+     ):
+         self.api_key = api_key or os.getenv("AZURE_API_KEY")
+         self.endpoint = endpoint or os.getenv("AZURE_API_ENDPOINT")
+         self.deployment_id = deployment_id or os.getenv("AZURE_DEPLOYMENT_ID")
+         self.api_version = os.getenv("AZURE_API_VERSION", "2023-07-01-preview")
+         self.model = model
+
+         if not self.api_key:
+             raise ValueError("Missing Azure API key (AZURE_API_KEY).")
+         if not self.endpoint:
+             raise ValueError("Missing Azure API endpoint (AZURE_API_ENDPOINT).")
+         if not self.deployment_id:
+             raise ValueError("Missing Azure deployment ID (AZURE_DEPLOYMENT_ID).")
+
+         if AsyncAzureOpenAI:
+             self.client = AsyncAzureOpenAI(
+                 api_key=self.api_key,
+                 api_version=self.api_version,
+                 azure_endpoint=self.endpoint,
+             )
+         else:
+             self.client = None
+
+     supports_messages = True
+
+     async def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
+         messages = [{"role": "user", "content": prompt}]
+         return await self._do_generate(messages, options)
+
+     async def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
+         return await self._do_generate(messages, options)
+
+     async def _do_generate(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
+         if self.client is None:
+             raise RuntimeError("openai package (>=1.0.0) with AsyncAzureOpenAI not installed")
+
+         model = options.get("model", self.model)
+         model_info = self.MODEL_PRICING.get(model, {})
+         tokens_param = model_info.get("tokens_param", "max_tokens")
+         supports_temperature = model_info.get("supports_temperature", True)
+
+         opts = {"temperature": 1.0, "max_tokens": 512, **options}
+
+         kwargs = {
+             "model": self.deployment_id,
+             "messages": messages,
+         }
+         kwargs[tokens_param] = opts.get("max_tokens", 512)
+
+         if supports_temperature and "temperature" in opts:
+             kwargs["temperature"] = opts["temperature"]
+
+         # Native JSON mode support
+         if options.get("json_mode"):
+             json_schema = options.get("json_schema")
+             if json_schema:
+                 kwargs["response_format"] = {
+                     "type": "json_schema",
+                     "json_schema": {
+                         "name": "extraction",
+                         "strict": True,
+                         "schema": json_schema,
+                     },
+                 }
+             else:
+                 kwargs["response_format"] = {"type": "json_object"}
+
+         resp = await self.client.chat.completions.create(**kwargs)
+
+         usage = getattr(resp, "usage", None)
+         prompt_tokens = getattr(usage, "prompt_tokens", 0)
+         completion_tokens = getattr(usage, "completion_tokens", 0)
+         total_tokens = getattr(usage, "total_tokens", 0)
+
+         total_cost = self._calculate_cost("azure", model, prompt_tokens, completion_tokens)
+
+         meta = {
+             "prompt_tokens": prompt_tokens,
+             "completion_tokens": completion_tokens,
+             "total_tokens": total_tokens,
+             "cost": total_cost,
+             "raw_response": resp.model_dump(),
+             "model_name": model,
+             "deployment_id": self.deployment_id,
+         }
+
+         text = resp.choices[0].message.content
+         return {"text": text, "meta": meta}
prompture/drivers/async_claude_driver.py
@@ -0,0 +1,107 @@
+ """Async Anthropic Claude driver. Requires the ``anthropic`` package."""
+
+ from __future__ import annotations
+
+ import json
+ import os
+ from typing import Any
+
+ try:
+     import anthropic
+ except Exception:
+     anthropic = None
+
+ from ..async_driver import AsyncDriver
+ from ..cost_mixin import CostMixin
+ from .claude_driver import ClaudeDriver
+
+
+ class AsyncClaudeDriver(CostMixin, AsyncDriver):
+     supports_json_mode = True
+     supports_json_schema = True
+
+     MODEL_PRICING = ClaudeDriver.MODEL_PRICING
+
+     def __init__(self, api_key: str | None = None, model: str = "claude-3-5-haiku-20241022"):
+         self.api_key = api_key or os.getenv("CLAUDE_API_KEY")
+         self.model = model or os.getenv("CLAUDE_MODEL_NAME", "claude-3-5-haiku-20241022")
+
+     supports_messages = True
+
+     async def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
+         messages = [{"role": "user", "content": prompt}]
+         return await self._do_generate(messages, options)
+
+     async def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
+         return await self._do_generate(messages, options)
+
+     async def _do_generate(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
+         if anthropic is None:
+             raise RuntimeError("anthropic package not installed")
+
+         opts = {**{"temperature": 0.0, "max_tokens": 512}, **options}
+         model = options.get("model", self.model)
+
+         client = anthropic.AsyncAnthropic(api_key=self.api_key)
+
+         # Anthropic requires system messages as a top-level parameter
+         system_content = None
+         api_messages = []
+         for msg in messages:
+             if msg.get("role") == "system":
+                 system_content = msg.get("content", "")
+             else:
+                 api_messages.append(msg)
+
+         # Build common kwargs
+         common_kwargs: dict[str, Any] = {
+             "model": model,
+             "messages": api_messages,
+             "temperature": opts["temperature"],
+             "max_tokens": opts["max_tokens"],
+         }
+         if system_content:
+             common_kwargs["system"] = system_content
+
+         # Native JSON mode: use tool-use for schema enforcement
+         if options.get("json_mode"):
+             json_schema = options.get("json_schema")
+             if json_schema:
+                 tool_def = {
+                     "name": "extract_json",
+                     "description": "Extract structured data matching the schema",
+                     "input_schema": json_schema,
+                 }
+                 resp = await client.messages.create(
+                     **common_kwargs,
+                     tools=[tool_def],
+                     tool_choice={"type": "tool", "name": "extract_json"},
+                 )
+                 text = ""
+                 for block in resp.content:
+                     if block.type == "tool_use":
+                         text = json.dumps(block.input)
+                         break
+             else:
+                 resp = await client.messages.create(**common_kwargs)
+                 text = resp.content[0].text
+         else:
+             resp = await client.messages.create(**common_kwargs)
+             text = resp.content[0].text
+
+         prompt_tokens = resp.usage.input_tokens
+         completion_tokens = resp.usage.output_tokens
+         total_tokens = prompt_tokens + completion_tokens
+
+         total_cost = self._calculate_cost("claude", model, prompt_tokens, completion_tokens)
+
+         meta = {
+             "prompt_tokens": prompt_tokens,
+             "completion_tokens": completion_tokens,
+             "total_tokens": total_tokens,
+             "cost": total_cost,
+             "raw_response": dict(resp),
+             "model_name": model,
+         }
+
+         return {"text": text, "meta": meta}
prompture/drivers/async_google_driver.py
@@ -0,0 +1,132 @@
+ """Async Google Generative AI (Gemini) driver."""
+
+ from __future__ import annotations
+
+ import logging
+ import os
+ from typing import Any
+
+ import google.generativeai as genai
+
+ from ..async_driver import AsyncDriver
+ from ..cost_mixin import CostMixin
+ from .google_driver import GoogleDriver
+
+ logger = logging.getLogger(__name__)
+
+
+ class AsyncGoogleDriver(CostMixin, AsyncDriver):
+     """Async driver for Google's Generative AI API (Gemini)."""
+
+     supports_json_mode = True
+     supports_json_schema = True
+
+     MODEL_PRICING = GoogleDriver.MODEL_PRICING
+     _PRICING_UNIT = 1_000_000
+
+     def __init__(self, api_key: str | None = None, model: str = "gemini-1.5-pro"):
+         self.api_key = api_key or os.getenv("GOOGLE_API_KEY")
+         if not self.api_key:
+             raise ValueError("Google API key not found. Set GOOGLE_API_KEY env var or pass api_key to constructor")
+         self.model = model
+         genai.configure(api_key=self.api_key)
+         self.options: dict[str, Any] = {}
+
+     def _calculate_cost_chars(self, prompt_chars: int, completion_chars: int) -> float:
+         """Calculate cost from character counts (same logic as sync GoogleDriver)."""
+         from ..model_rates import get_model_rates
+
+         live_rates = get_model_rates("google", self.model)
+         if live_rates:
+             est_prompt_tokens = prompt_chars / 4
+             est_completion_tokens = completion_chars / 4
+             prompt_cost = (est_prompt_tokens / 1_000_000) * live_rates["input"]
+             completion_cost = (est_completion_tokens / 1_000_000) * live_rates["output"]
+         else:
+             model_pricing = self.MODEL_PRICING.get(self.model, {"prompt": 0, "completion": 0})
+             prompt_cost = (prompt_chars / 1_000_000) * model_pricing["prompt"]
+             completion_cost = (completion_chars / 1_000_000) * model_pricing["completion"]
+         return round(prompt_cost + completion_cost, 6)
+
+     supports_messages = True
+
+     async def generate(self, prompt: str, options: dict[str, Any] | None = None) -> dict[str, Any]:
+         messages = [{"role": "user", "content": prompt}]
+         return await self._do_generate(messages, options)
+
+     async def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
+         return await self._do_generate(messages, options)
+
+     async def _do_generate(
+         self, messages: list[dict[str, str]], options: dict[str, Any] | None = None
+     ) -> dict[str, Any]:
+         merged_options = self.options.copy()
+         if options:
+             merged_options.update(options)
+
+         generation_config = merged_options.get("generation_config", {})
+         safety_settings = merged_options.get("safety_settings", {})
+
+         if "temperature" in merged_options and "temperature" not in generation_config:
+             generation_config["temperature"] = merged_options["temperature"]
+         if "max_tokens" in merged_options and "max_output_tokens" not in generation_config:
+             generation_config["max_output_tokens"] = merged_options["max_tokens"]
+         if "top_p" in merged_options and "top_p" not in generation_config:
+             generation_config["top_p"] = merged_options["top_p"]
+         if "top_k" in merged_options and "top_k" not in generation_config:
+             generation_config["top_k"] = merged_options["top_k"]
+
+         # Native JSON mode support
+         if merged_options.get("json_mode"):
+             generation_config["response_mime_type"] = "application/json"
+             json_schema = merged_options.get("json_schema")
+             if json_schema:
+                 generation_config["response_schema"] = json_schema
+
+         # Convert messages to Gemini format
+         system_instruction = None
+         contents: list[dict[str, Any]] = []
+         for msg in messages:
+             role = msg.get("role", "user")
+             content = msg.get("content", "")
+             if role == "system":
+                 system_instruction = content
+             else:
+                 gemini_role = "model" if role == "assistant" else "user"
+                 contents.append({"role": gemini_role, "parts": [content]})
+
+         try:
+             model_kwargs: dict[str, Any] = {}
+             if system_instruction:
+                 model_kwargs["system_instruction"] = system_instruction
+             model = genai.GenerativeModel(self.model, **model_kwargs)
+
+             gen_input: Any = contents if len(contents) != 1 else contents[0]["parts"][0]
+             response = await model.generate_content_async(
+                 gen_input,
+                 generation_config=generation_config if generation_config else None,
+                 safety_settings=safety_settings if safety_settings else None,
+             )
+
+             if not response.text:
+                 raise ValueError("Empty response from model")
+
+             total_prompt_chars = sum(len(msg.get("content", "")) for msg in messages)
+             completion_chars = len(response.text)
+
+             total_cost = self._calculate_cost_chars(total_prompt_chars, completion_chars)
+
+             meta = {
+                 "prompt_chars": total_prompt_chars,
+                 "completion_chars": completion_chars,
+                 "total_chars": total_prompt_chars + completion_chars,
+                 "cost": total_cost,
+                 "raw_response": response.prompt_feedback if hasattr(response, "prompt_feedback") else None,
+                 "model_name": self.model,
+             }
+
+             return {"text": response.text, "meta": meta}
+
+         except Exception as e:
+             logger.error(f"Google API request failed: {e}")
+             raise RuntimeError(f"Google API request failed: {e}") from e
prompture/drivers/async_grok_driver.py
@@ -0,0 +1,91 @@
+ """Async xAI Grok driver using httpx."""
+
+ from __future__ import annotations
+
+ import os
+ from typing import Any
+
+ import httpx
+
+ from ..async_driver import AsyncDriver
+ from ..cost_mixin import CostMixin
+ from .grok_driver import GrokDriver
+
+
+ class AsyncGrokDriver(CostMixin, AsyncDriver):
+     supports_json_mode = True
+
+     MODEL_PRICING = GrokDriver.MODEL_PRICING
+     _PRICING_UNIT = 1_000_000
+
+     def __init__(self, api_key: str | None = None, model: str = "grok-4-fast-reasoning"):
+         self.api_key = api_key or os.getenv("GROK_API_KEY")
+         self.model = model
+         self.api_base = "https://api.x.ai/v1"
+
+     supports_messages = True
+
+     async def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
+         messages = [{"role": "user", "content": prompt}]
+         return await self._do_generate(messages, options)
+
+     async def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
+         return await self._do_generate(messages, options)
+
+     async def _do_generate(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
+         if not self.api_key:
+             raise RuntimeError("GROK_API_KEY environment variable is required")
+
+         model = options.get("model", self.model)
+
+         model_info = self.MODEL_PRICING.get(model, {})
+         tokens_param = model_info.get("tokens_param", "max_tokens")
+         supports_temperature = model_info.get("supports_temperature", True)
+
+         opts = {"temperature": 1.0, "max_tokens": 512, **options}
+
+         payload = {
+             "model": model,
+             "messages": messages,
+         }
+         payload[tokens_param] = opts.get("max_tokens", 512)
+
+         if supports_temperature and "temperature" in opts:
+             payload["temperature"] = opts["temperature"]
+
+         # Native JSON mode support
+         if options.get("json_mode"):
+             payload["response_format"] = {"type": "json_object"}
+
+         headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
+
+         async with httpx.AsyncClient() as client:
+             try:
+                 response = await client.post(
+                     f"{self.api_base}/chat/completions", headers=headers, json=payload, timeout=120
+                 )
+                 response.raise_for_status()
+                 resp = response.json()
+             except httpx.HTTPStatusError as e:
+                 raise RuntimeError(f"Grok API request failed: {e!s}") from e
+             except Exception as e:
+                 raise RuntimeError(f"Grok API request failed: {e!s}") from e
+
+         usage = resp.get("usage", {})
+         prompt_tokens = usage.get("prompt_tokens", 0)
+         completion_tokens = usage.get("completion_tokens", 0)
+         total_tokens = usage.get("total_tokens", 0)
+
+         total_cost = self._calculate_cost("grok", model, prompt_tokens, completion_tokens)
+
+         meta = {
+             "prompt_tokens": prompt_tokens,
+             "completion_tokens": completion_tokens,
+             "total_tokens": total_tokens,
+             "cost": total_cost,
+             "raw_response": resp,
+             "model_name": model,
+         }
+
+         text = resp["choices"][0]["message"]["content"]
+         return {"text": text, "meta": meta}
prompture/drivers/async_groq_driver.py
@@ -0,0 +1,84 @@
+ """Async Groq driver. Requires the ``groq`` package."""
+
+ from __future__ import annotations
+
+ import os
+ from typing import Any
+
+ try:
+     import groq
+ except Exception:
+     groq = None
+
+ from ..async_driver import AsyncDriver
+ from ..cost_mixin import CostMixin
+ from .groq_driver import GroqDriver
+
+
+ class AsyncGroqDriver(CostMixin, AsyncDriver):
+     supports_json_mode = True
+
+     MODEL_PRICING = GroqDriver.MODEL_PRICING
+
+     def __init__(self, api_key: str | None = None, model: str = "llama2-70b-4096"):
+         self.api_key = api_key or os.getenv("GROQ_API_KEY")
+         self.model = model
+         if groq:
+             self.client = groq.AsyncClient(api_key=self.api_key)
+         else:
+             self.client = None
+
+     supports_messages = True
+
+     async def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
+         messages = [{"role": "user", "content": prompt}]
+         return await self._do_generate(messages, options)
+
+     async def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
+         return await self._do_generate(messages, options)
+
+     async def _do_generate(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
+         if self.client is None:
+             raise RuntimeError("groq package is not installed")
+
+         model = options.get("model", self.model)
+
+         model_info = self.MODEL_PRICING.get(model, {})
+         tokens_param = model_info.get("tokens_param", "max_tokens")
+         supports_temperature = model_info.get("supports_temperature", True)
+
+         opts = {"temperature": 0.7, "max_tokens": 512, **options}
+
+         kwargs = {
+             "model": model,
+             "messages": messages,
+         }
+         kwargs[tokens_param] = opts.get("max_tokens", 512)
+
+         if supports_temperature and "temperature" in opts:
+             kwargs["temperature"] = opts["temperature"]
+
+         # Native JSON mode support
+         if options.get("json_mode"):
+             kwargs["response_format"] = {"type": "json_object"}
+
+         resp = await self.client.chat.completions.create(**kwargs)
+
+         usage = getattr(resp, "usage", None)
+         prompt_tokens = getattr(usage, "prompt_tokens", 0)
+         completion_tokens = getattr(usage, "completion_tokens", 0)
+         total_tokens = getattr(usage, "total_tokens", 0)
+
+         total_cost = self._calculate_cost("groq", model, prompt_tokens, completion_tokens)
+
+         meta = {
+             "prompt_tokens": prompt_tokens,
+             "completion_tokens": completion_tokens,
+             "total_tokens": total_tokens,
+             "cost": total_cost,
+             "raw_response": resp.model_dump(),
+             "model_name": model,
+         }
+
+         text = resp.choices[0].message.content
+         return {"text": text, "meta": meta}