prompture-0.0.29.dev8-py3-none-any.whl → prompture-0.0.38.dev2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (79)
  1. prompture/__init__.py +264 -23
  2. prompture/_version.py +34 -0
  3. prompture/agent.py +924 -0
  4. prompture/agent_types.py +156 -0
  5. prompture/aio/__init__.py +74 -0
  6. prompture/async_agent.py +880 -0
  7. prompture/async_conversation.py +789 -0
  8. prompture/async_core.py +803 -0
  9. prompture/async_driver.py +193 -0
  10. prompture/async_groups.py +551 -0
  11. prompture/cache.py +469 -0
  12. prompture/callbacks.py +55 -0
  13. prompture/cli.py +63 -4
  14. prompture/conversation.py +826 -0
  15. prompture/core.py +894 -263
  16. prompture/cost_mixin.py +51 -0
  17. prompture/discovery.py +187 -0
  18. prompture/driver.py +206 -5
  19. prompture/drivers/__init__.py +175 -67
  20. prompture/drivers/airllm_driver.py +109 -0
  21. prompture/drivers/async_airllm_driver.py +26 -0
  22. prompture/drivers/async_azure_driver.py +123 -0
  23. prompture/drivers/async_claude_driver.py +113 -0
  24. prompture/drivers/async_google_driver.py +316 -0
  25. prompture/drivers/async_grok_driver.py +97 -0
  26. prompture/drivers/async_groq_driver.py +90 -0
  27. prompture/drivers/async_hugging_driver.py +61 -0
  28. prompture/drivers/async_lmstudio_driver.py +148 -0
  29. prompture/drivers/async_local_http_driver.py +44 -0
  30. prompture/drivers/async_ollama_driver.py +135 -0
  31. prompture/drivers/async_openai_driver.py +102 -0
  32. prompture/drivers/async_openrouter_driver.py +102 -0
  33. prompture/drivers/async_registry.py +133 -0
  34. prompture/drivers/azure_driver.py +42 -9
  35. prompture/drivers/claude_driver.py +257 -34
  36. prompture/drivers/google_driver.py +295 -42
  37. prompture/drivers/grok_driver.py +35 -32
  38. prompture/drivers/groq_driver.py +33 -26
  39. prompture/drivers/hugging_driver.py +6 -6
  40. prompture/drivers/lmstudio_driver.py +97 -19
  41. prompture/drivers/local_http_driver.py +6 -6
  42. prompture/drivers/ollama_driver.py +168 -23
  43. prompture/drivers/openai_driver.py +184 -9
  44. prompture/drivers/openrouter_driver.py +37 -25
  45. prompture/drivers/registry.py +306 -0
  46. prompture/drivers/vision_helpers.py +153 -0
  47. prompture/field_definitions.py +106 -96
  48. prompture/group_types.py +147 -0
  49. prompture/groups.py +530 -0
  50. prompture/image.py +180 -0
  51. prompture/logging.py +80 -0
  52. prompture/model_rates.py +217 -0
  53. prompture/persistence.py +254 -0
  54. prompture/persona.py +482 -0
  55. prompture/runner.py +49 -47
  56. prompture/scaffold/__init__.py +1 -0
  57. prompture/scaffold/generator.py +84 -0
  58. prompture/scaffold/templates/Dockerfile.j2 +12 -0
  59. prompture/scaffold/templates/README.md.j2 +41 -0
  60. prompture/scaffold/templates/config.py.j2 +21 -0
  61. prompture/scaffold/templates/env.example.j2 +8 -0
  62. prompture/scaffold/templates/main.py.j2 +86 -0
  63. prompture/scaffold/templates/models.py.j2 +40 -0
  64. prompture/scaffold/templates/requirements.txt.j2 +5 -0
  65. prompture/serialization.py +218 -0
  66. prompture/server.py +183 -0
  67. prompture/session.py +117 -0
  68. prompture/settings.py +19 -1
  69. prompture/tools.py +219 -267
  70. prompture/tools_schema.py +254 -0
  71. prompture/validator.py +3 -3
  72. prompture-0.0.38.dev2.dist-info/METADATA +369 -0
  73. prompture-0.0.38.dev2.dist-info/RECORD +77 -0
  74. {prompture-0.0.29.dev8.dist-info → prompture-0.0.38.dev2.dist-info}/WHEEL +1 -1
  75. prompture-0.0.29.dev8.dist-info/METADATA +0 -368
  76. prompture-0.0.29.dev8.dist-info/RECORD +0 -27
  77. {prompture-0.0.29.dev8.dist-info → prompture-0.0.38.dev2.dist-info}/entry_points.txt +0 -0
  78. {prompture-0.0.29.dev8.dist-info → prompture-0.0.38.dev2.dist-info}/licenses/LICENSE +0 -0
  79. {prompture-0.0.29.dev8.dist-info → prompture-0.0.38.dev2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,316 @@
+ """Async Google Generative AI (Gemini) driver."""
+
+ from __future__ import annotations
+
+ import logging
+ import os
+ import uuid
+ from collections.abc import AsyncIterator
+ from typing import Any
+
+ import google.generativeai as genai
+
+ from ..async_driver import AsyncDriver
+ from ..cost_mixin import CostMixin
+ from .google_driver import GoogleDriver
+
+ logger = logging.getLogger(__name__)
+
+
+ class AsyncGoogleDriver(CostMixin, AsyncDriver):
+     """Async driver for Google's Generative AI API (Gemini)."""
+
+     supports_json_mode = True
+     supports_json_schema = True
+     supports_vision = True
+     supports_tool_use = True
+     supports_streaming = True
+
+     MODEL_PRICING = GoogleDriver.MODEL_PRICING
+     _PRICING_UNIT = 1_000_000
+
+     def __init__(self, api_key: str | None = None, model: str = "gemini-1.5-pro"):
+         self.api_key = api_key or os.getenv("GOOGLE_API_KEY")
+         if not self.api_key:
+             raise ValueError("Google API key not found. Set GOOGLE_API_KEY env var or pass api_key to constructor")
+         self.model = model
+         genai.configure(api_key=self.api_key)
+         self.options: dict[str, Any] = {}
+
+     def _calculate_cost_chars(self, prompt_chars: int, completion_chars: int) -> float:
+         """Calculate cost from character counts (same logic as sync GoogleDriver)."""
+         from ..model_rates import get_model_rates
+
+         live_rates = get_model_rates("google", self.model)
+         if live_rates:
+             est_prompt_tokens = prompt_chars / 4
+             est_completion_tokens = completion_chars / 4
+             prompt_cost = (est_prompt_tokens / 1_000_000) * live_rates["input"]
+             completion_cost = (est_completion_tokens / 1_000_000) * live_rates["output"]
+         else:
+             model_pricing = self.MODEL_PRICING.get(self.model, {"prompt": 0, "completion": 0})
+             prompt_cost = (prompt_chars / 1_000_000) * model_pricing["prompt"]
+             completion_cost = (completion_chars / 1_000_000) * model_pricing["completion"]
+         return round(prompt_cost + completion_cost, 6)
+
+     def _extract_usage_metadata(self, response: Any, messages: list[dict[str, Any]]) -> dict[str, Any]:
+         """Extract token counts from response, falling back to character estimation."""
+         usage = getattr(response, "usage_metadata", None)
+         if usage:
+             prompt_tokens = getattr(usage, "prompt_token_count", 0) or 0
+             completion_tokens = getattr(usage, "candidates_token_count", 0) or 0
+             total_tokens = getattr(usage, "total_token_count", 0) or (prompt_tokens + completion_tokens)
+             cost = self._calculate_cost("google", self.model, prompt_tokens, completion_tokens)
+         else:
+             # Fallback: estimate from character counts
+             total_prompt_chars = 0
+             for msg in messages:
+                 c = msg.get("content", "")
+                 if isinstance(c, str):
+                     total_prompt_chars += len(c)
+                 elif isinstance(c, list):
+                     for part in c:
+                         if isinstance(part, str):
+                             total_prompt_chars += len(part)
+                         elif isinstance(part, dict) and "text" in part:
+                             total_prompt_chars += len(part["text"])
+             completion_chars = len(response.text) if response.text else 0
+             prompt_tokens = total_prompt_chars // 4
+             completion_tokens = completion_chars // 4
+             total_tokens = prompt_tokens + completion_tokens
+             cost = self._calculate_cost_chars(total_prompt_chars, completion_chars)
+
+         return {
+             "prompt_tokens": prompt_tokens,
+             "completion_tokens": completion_tokens,
+             "total_tokens": total_tokens,
+             "cost": round(cost, 6),
+         }
+
+     supports_messages = True
+
+     def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+         from .vision_helpers import _prepare_google_vision_messages
+
+         return _prepare_google_vision_messages(messages)
+
+     def _build_generation_args(
+         self, messages: list[dict[str, Any]], options: dict[str, Any] | None = None
+     ) -> tuple[Any, dict[str, Any], dict[str, Any]]:
+         """Parse messages and options into (gen_input, gen_kwargs, model_kwargs)."""
+         merged_options = self.options.copy()
+         if options:
+             merged_options.update(options)
+
+         generation_config = merged_options.get("generation_config", {})
+         safety_settings = merged_options.get("safety_settings", {})
+
+         if "temperature" in merged_options and "temperature" not in generation_config:
+             generation_config["temperature"] = merged_options["temperature"]
+         if "max_tokens" in merged_options and "max_output_tokens" not in generation_config:
+             generation_config["max_output_tokens"] = merged_options["max_tokens"]
+         if "top_p" in merged_options and "top_p" not in generation_config:
+             generation_config["top_p"] = merged_options["top_p"]
+         if "top_k" in merged_options and "top_k" not in generation_config:
+             generation_config["top_k"] = merged_options["top_k"]
+
+         # Native JSON mode support
+         if merged_options.get("json_mode"):
+             generation_config["response_mime_type"] = "application/json"
+             json_schema = merged_options.get("json_schema")
+             if json_schema:
+                 generation_config["response_schema"] = json_schema
+
+         # Convert messages to Gemini format
+         system_instruction = None
+         contents: list[dict[str, Any]] = []
+         for msg in messages:
+             role = msg.get("role", "user")
+             content = msg.get("content", "")
+             if role == "system":
+                 system_instruction = content if isinstance(content, str) else str(content)
+             else:
+                 gemini_role = "model" if role == "assistant" else "user"
+                 if msg.get("_vision_parts"):
+                     contents.append({"role": gemini_role, "parts": content})
+                 else:
+                     contents.append({"role": gemini_role, "parts": [content]})
+
+         # For a single message, unwrap only if it has exactly one string part
+         if len(contents) == 1:
+             parts = contents[0]["parts"]
+             if len(parts) == 1 and isinstance(parts[0], str):
+                 gen_input = parts[0]
+             else:
+                 gen_input = contents
+         else:
+             gen_input = contents
+
+         model_kwargs: dict[str, Any] = {}
+         if system_instruction:
+             model_kwargs["system_instruction"] = system_instruction
+
+         gen_kwargs: dict[str, Any] = {
+             "generation_config": generation_config if generation_config else None,
+             "safety_settings": safety_settings if safety_settings else None,
+         }
+
+         return gen_input, gen_kwargs, model_kwargs
+
+     async def generate(self, prompt: str, options: dict[str, Any] | None = None) -> dict[str, Any]:
+         messages = [{"role": "user", "content": prompt}]
+         return await self._do_generate(messages, options)
+
+     async def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
+         return await self._do_generate(self._prepare_messages(messages), options)
+
+     async def _do_generate(
+         self, messages: list[dict[str, str]], options: dict[str, Any] | None = None
+     ) -> dict[str, Any]:
+         gen_input, gen_kwargs, model_kwargs = self._build_generation_args(messages, options)
+
+         try:
+             model = genai.GenerativeModel(self.model, **model_kwargs)
+             response = await model.generate_content_async(gen_input, **gen_kwargs)
+
+             if not response.text:
+                 raise ValueError("Empty response from model")
+
+             usage_meta = self._extract_usage_metadata(response, messages)
+
+             meta = {
+                 **usage_meta,
+                 "raw_response": response.prompt_feedback if hasattr(response, "prompt_feedback") else None,
+                 "model_name": self.model,
+             }
+
+             return {"text": response.text, "meta": meta}
+
+         except Exception as e:
+             logger.error(f"Google API request failed: {e}")
+             raise RuntimeError(f"Google API request failed: {e}") from e
+
+     # ------------------------------------------------------------------
+     # Tool use
+     # ------------------------------------------------------------------
+
+     async def generate_messages_with_tools(
+         self,
+         messages: list[dict[str, Any]],
+         tools: list[dict[str, Any]],
+         options: dict[str, Any],
+     ) -> dict[str, Any]:
+         """Generate a response that may include tool/function calls (async)."""
+         gen_input, gen_kwargs, model_kwargs = self._build_generation_args(
+             self._prepare_messages(messages), options
+         )
+
+         # Convert tools from OpenAI format to Gemini function declarations
+         function_declarations = []
+         for t in tools:
+             if "type" in t and t["type"] == "function":
+                 fn = t["function"]
+                 decl = {
+                     "name": fn["name"],
+                     "description": fn.get("description", ""),
+                 }
+                 params = fn.get("parameters")
+                 if params:
+                     decl["parameters"] = params
+                 function_declarations.append(decl)
+             elif "name" in t:
+                 decl = {"name": t["name"], "description": t.get("description", "")}
+                 params = t.get("parameters") or t.get("input_schema")
+                 if params:
+                     decl["parameters"] = params
+                 function_declarations.append(decl)
+
+         try:
+             model = genai.GenerativeModel(self.model, **model_kwargs)
+
+             gemini_tools = [genai.types.Tool(function_declarations=function_declarations)]
+             response = await model.generate_content_async(gen_input, tools=gemini_tools, **gen_kwargs)
+
+             usage_meta = self._extract_usage_metadata(response, messages)
+             meta = {
+                 **usage_meta,
+                 "raw_response": response.prompt_feedback if hasattr(response, "prompt_feedback") else None,
+                 "model_name": self.model,
+             }
+
+             text = ""
+             tool_calls_out: list[dict[str, Any]] = []
+             stop_reason = "stop"
+
+             for candidate in response.candidates:
+                 for part in candidate.content.parts:
+                     if hasattr(part, "text") and part.text:
+                         text += part.text
+                     if hasattr(part, "function_call") and part.function_call.name:
+                         fc = part.function_call
+                         tool_calls_out.append({
+                             "id": str(uuid.uuid4()),
+                             "name": fc.name,
+                             "arguments": dict(fc.args) if fc.args else {},
+                         })
+
+                 finish_reason = getattr(candidate, "finish_reason", None)
+                 if finish_reason is not None:
+                     reason_map = {1: "stop", 2: "max_tokens", 3: "safety", 4: "recitation", 5: "other"}
+                     stop_reason = reason_map.get(finish_reason, "stop")
+
+             if tool_calls_out:
+                 stop_reason = "tool_use"
+
+             return {
+                 "text": text,
+                 "meta": meta,
+                 "tool_calls": tool_calls_out,
+                 "stop_reason": stop_reason,
+             }
+
+         except Exception as e:
+             logger.error(f"Google API tool call request failed: {e}")
+             raise RuntimeError(f"Google API tool call request failed: {e}") from e
+
+     # ------------------------------------------------------------------
+     # Streaming
+     # ------------------------------------------------------------------
+
+     async def generate_messages_stream(
+         self,
+         messages: list[dict[str, Any]],
+         options: dict[str, Any],
+     ) -> AsyncIterator[dict[str, Any]]:
+         """Yield response chunks via Gemini async streaming API."""
+         gen_input, gen_kwargs, model_kwargs = self._build_generation_args(
+             self._prepare_messages(messages), options
+         )
+
+         try:
+             model = genai.GenerativeModel(self.model, **model_kwargs)
+             response = await model.generate_content_async(gen_input, stream=True, **gen_kwargs)
+
+             full_text = ""
+             async for chunk in response:
+                 chunk_text = getattr(chunk, "text", None) or ""
+                 if chunk_text:
+                     full_text += chunk_text
+                     yield {"type": "delta", "text": chunk_text}
+
+             # After iteration completes, usage_metadata should be available
+             usage_meta = self._extract_usage_metadata(response, messages)
+
+             yield {
+                 "type": "done",
+                 "text": full_text,
+                 "meta": {
+                     **usage_meta,
+                     "raw_response": {},
+                     "model_name": self.model,
+                 },
+             }
+
+         except Exception as e:
+             logger.error(f"Google API streaming request failed: {e}")
+             raise RuntimeError(f"Google API streaming request failed: {e}") from e
@@ -0,0 +1,97 @@
+ """Async xAI Grok driver using httpx."""
+
+ from __future__ import annotations
+
+ import os
+ from typing import Any
+
+ import httpx
+
+ from ..async_driver import AsyncDriver
+ from ..cost_mixin import CostMixin
+ from .grok_driver import GrokDriver
+
+
+ class AsyncGrokDriver(CostMixin, AsyncDriver):
+     supports_json_mode = True
+     supports_vision = True
+
+     MODEL_PRICING = GrokDriver.MODEL_PRICING
+     _PRICING_UNIT = 1_000_000
+
+     def __init__(self, api_key: str | None = None, model: str = "grok-4-fast-reasoning"):
+         self.api_key = api_key or os.getenv("GROK_API_KEY")
+         self.model = model
+         self.api_base = "https://api.x.ai/v1"
+
+     supports_messages = True
+
+     def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+         from .vision_helpers import _prepare_openai_vision_messages
+
+         return _prepare_openai_vision_messages(messages)
+
+     async def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
+         messages = [{"role": "user", "content": prompt}]
+         return await self._do_generate(messages, options)
+
+     async def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
+         return await self._do_generate(self._prepare_messages(messages), options)
+
+     async def _do_generate(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
+         if not self.api_key:
+             raise RuntimeError("GROK_API_KEY environment variable is required")
+
+         model = options.get("model", self.model)
+
+         model_info = self.MODEL_PRICING.get(model, {})
+         tokens_param = model_info.get("tokens_param", "max_tokens")
+         supports_temperature = model_info.get("supports_temperature", True)
+
+         opts = {"temperature": 1.0, "max_tokens": 512, **options}
+
+         payload = {
+             "model": model,
+             "messages": messages,
+         }
+         payload[tokens_param] = opts.get("max_tokens", 512)
+
+         if supports_temperature and "temperature" in opts:
+             payload["temperature"] = opts["temperature"]
+
+         # Native JSON mode support
+         if options.get("json_mode"):
+             payload["response_format"] = {"type": "json_object"}
+
+         headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
+
+         async with httpx.AsyncClient() as client:
+             try:
+                 response = await client.post(
+                     f"{self.api_base}/chat/completions", headers=headers, json=payload, timeout=120
+                 )
+                 response.raise_for_status()
+                 resp = response.json()
+             except httpx.HTTPStatusError as e:
+                 raise RuntimeError(f"Grok API request failed: {e!s}") from e
+             except Exception as e:
+                 raise RuntimeError(f"Grok API request failed: {e!s}") from e
+
+         usage = resp.get("usage", {})
+         prompt_tokens = usage.get("prompt_tokens", 0)
+         completion_tokens = usage.get("completion_tokens", 0)
+         total_tokens = usage.get("total_tokens", 0)
+
+         total_cost = self._calculate_cost("grok", model, prompt_tokens, completion_tokens)
+
+         meta = {
+             "prompt_tokens": prompt_tokens,
+             "completion_tokens": completion_tokens,
+             "total_tokens": total_tokens,
+             "cost": total_cost,
+             "raw_response": resp,
+             "model_name": model,
+         }
+
+         text = resp["choices"][0]["message"]["content"]
+         return {"text": text, "meta": meta}
@@ -0,0 +1,90 @@
+ """Async Groq driver. Requires the ``groq`` package."""
+
+ from __future__ import annotations
+
+ import os
+ from typing import Any
+
+ try:
+     import groq
+ except Exception:
+     groq = None
+
+ from ..async_driver import AsyncDriver
+ from ..cost_mixin import CostMixin
+ from .groq_driver import GroqDriver
+
+
+ class AsyncGroqDriver(CostMixin, AsyncDriver):
+     supports_json_mode = True
+     supports_vision = True
+
+     MODEL_PRICING = GroqDriver.MODEL_PRICING
+
+     def __init__(self, api_key: str | None = None, model: str = "llama2-70b-4096"):
+         self.api_key = api_key or os.getenv("GROQ_API_KEY")
+         self.model = model
+         if groq:
+             self.client = groq.AsyncClient(api_key=self.api_key)
+         else:
+             self.client = None
+
+     supports_messages = True
+
+     def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+         from .vision_helpers import _prepare_openai_vision_messages
+
+         return _prepare_openai_vision_messages(messages)
+
+     async def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
+         messages = [{"role": "user", "content": prompt}]
+         return await self._do_generate(messages, options)
+
+     async def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
+         return await self._do_generate(self._prepare_messages(messages), options)
+
+     async def _do_generate(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
+         if self.client is None:
+             raise RuntimeError("groq package is not installed")
+
+         model = options.get("model", self.model)
+
+         model_info = self.MODEL_PRICING.get(model, {})
+         tokens_param = model_info.get("tokens_param", "max_tokens")
+         supports_temperature = model_info.get("supports_temperature", True)
+
+         opts = {"temperature": 0.7, "max_tokens": 512, **options}
+
+         kwargs = {
+             "model": model,
+             "messages": messages,
+         }
+         kwargs[tokens_param] = opts.get("max_tokens", 512)
+
+         if supports_temperature and "temperature" in opts:
+             kwargs["temperature"] = opts["temperature"]
+
+         # Native JSON mode support
+         if options.get("json_mode"):
+             kwargs["response_format"] = {"type": "json_object"}
+
+         resp = await self.client.chat.completions.create(**kwargs)
+
+         usage = getattr(resp, "usage", None)
+         prompt_tokens = getattr(usage, "prompt_tokens", 0)
+         completion_tokens = getattr(usage, "completion_tokens", 0)
+         total_tokens = getattr(usage, "total_tokens", 0)
+
+         total_cost = self._calculate_cost("groq", model, prompt_tokens, completion_tokens)
+
+         meta = {
+             "prompt_tokens": prompt_tokens,
+             "completion_tokens": completion_tokens,
+             "total_tokens": total_tokens,
+             "cost": total_cost,
+             "raw_response": resp.model_dump(),
+             "model_name": model,
+         }
+
+         text = resp.choices[0].message.content
+         return {"text": text, "meta": meta}
@@ -0,0 +1,61 @@
+ """Async Hugging Face driver using httpx."""
+
+ from __future__ import annotations
+
+ import os
+ from typing import Any
+
+ import httpx
+
+ from ..async_driver import AsyncDriver
+
+
+ class AsyncHuggingFaceDriver(AsyncDriver):
+     MODEL_PRICING = {"default": {"prompt": 0.0, "completion": 0.0}}
+
+     def __init__(self, endpoint: str | None = None, token: str | None = None, model: str = "bert-base-uncased"):
+         self.endpoint = endpoint or os.getenv("HF_ENDPOINT")
+         self.token = token or os.getenv("HF_TOKEN")
+         self.model = model
+
+         if not self.endpoint:
+             raise ValueError("Hugging Face endpoint is not configured. Set HF_ENDPOINT or pass explicitly.")
+         if not self.token:
+             raise ValueError("Hugging Face token is not configured. Set HF_TOKEN or pass explicitly.")
+
+         self.headers = {"Authorization": f"Bearer {self.token}"}
+
+     async def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
+         payload = {
+             "inputs": prompt,
+             "parameters": options,
+         }
+
+         async with httpx.AsyncClient() as client:
+             try:
+                 r = await client.post(
+                     self.endpoint, headers=self.headers, json=payload, timeout=options.get("timeout", 60)
+                 )
+                 r.raise_for_status()
+                 response_data = r.json()
+             except Exception as e:
+                 raise RuntimeError(f"AsyncHuggingFaceDriver request failed: {e}") from e
+
+         text = None
+         if isinstance(response_data, list) and response_data and "generated_text" in response_data[0]:
+             text = response_data[0]["generated_text"]
+         elif isinstance(response_data, dict) and "generated_text" in response_data:
+             text = response_data["generated_text"]
+         else:
+             text = str(response_data)
+
+         meta = {
+             "prompt_tokens": 0,
+             "completion_tokens": 0,
+             "total_tokens": 0,
+             "cost": 0.0,
+             "raw_response": response_data,
+             "model_name": options.get("model", self.model),
+         }
+
+         return {"text": text, "meta": meta}