prompture 0.0.33.dev2__py3-none-any.whl → 0.0.34.dev1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55)
  1. prompture/__init__.py +112 -54
  2. prompture/_version.py +34 -0
  3. prompture/aio/__init__.py +74 -0
  4. prompture/async_conversation.py +484 -0
  5. prompture/async_core.py +803 -0
  6. prompture/async_driver.py +131 -0
  7. prompture/cache.py +469 -0
  8. prompture/callbacks.py +50 -0
  9. prompture/cli.py +7 -3
  10. prompture/conversation.py +504 -0
  11. prompture/core.py +475 -352
  12. prompture/cost_mixin.py +51 -0
  13. prompture/discovery.py +41 -36
  14. prompture/driver.py +125 -5
  15. prompture/drivers/__init__.py +63 -57
  16. prompture/drivers/airllm_driver.py +13 -20
  17. prompture/drivers/async_airllm_driver.py +26 -0
  18. prompture/drivers/async_azure_driver.py +117 -0
  19. prompture/drivers/async_claude_driver.py +107 -0
  20. prompture/drivers/async_google_driver.py +132 -0
  21. prompture/drivers/async_grok_driver.py +91 -0
  22. prompture/drivers/async_groq_driver.py +84 -0
  23. prompture/drivers/async_hugging_driver.py +61 -0
  24. prompture/drivers/async_lmstudio_driver.py +79 -0
  25. prompture/drivers/async_local_http_driver.py +44 -0
  26. prompture/drivers/async_ollama_driver.py +125 -0
  27. prompture/drivers/async_openai_driver.py +96 -0
  28. prompture/drivers/async_openrouter_driver.py +96 -0
  29. prompture/drivers/async_registry.py +80 -0
  30. prompture/drivers/azure_driver.py +36 -15
  31. prompture/drivers/claude_driver.py +86 -40
  32. prompture/drivers/google_driver.py +86 -58
  33. prompture/drivers/grok_driver.py +29 -38
  34. prompture/drivers/groq_driver.py +27 -32
  35. prompture/drivers/hugging_driver.py +6 -6
  36. prompture/drivers/lmstudio_driver.py +26 -13
  37. prompture/drivers/local_http_driver.py +6 -6
  38. prompture/drivers/ollama_driver.py +90 -23
  39. prompture/drivers/openai_driver.py +36 -15
  40. prompture/drivers/openrouter_driver.py +31 -31
  41. prompture/field_definitions.py +106 -96
  42. prompture/logging.py +80 -0
  43. prompture/model_rates.py +16 -15
  44. prompture/runner.py +49 -47
  45. prompture/session.py +117 -0
  46. prompture/settings.py +11 -1
  47. prompture/tools.py +172 -265
  48. prompture/validator.py +3 -3
  49. {prompture-0.0.33.dev2.dist-info → prompture-0.0.34.dev1.dist-info}/METADATA +18 -20
  50. prompture-0.0.34.dev1.dist-info/RECORD +54 -0
  51. prompture-0.0.33.dev2.dist-info/RECORD +0 -30
  52. {prompture-0.0.33.dev2.dist-info → prompture-0.0.34.dev1.dist-info}/WHEEL +0 -0
  53. {prompture-0.0.33.dev2.dist-info → prompture-0.0.34.dev1.dist-info}/entry_points.txt +0 -0
  54. {prompture-0.0.33.dev2.dist-info → prompture-0.0.34.dev1.dist-info}/licenses/LICENSE +0 -0
  55. {prompture-0.0.33.dev2.dist-info → prompture-0.0.34.dev1.dist-info}/top_level.txt +0 -0
prompture/drivers/async_azure_driver.py
@@ -0,0 +1,117 @@
+"""Async Azure OpenAI driver. Requires the ``openai`` package (>=1.0.0)."""
+
+from __future__ import annotations
+
+import os
+from typing import Any
+
+try:
+    from openai import AsyncAzureOpenAI
+except Exception:
+    AsyncAzureOpenAI = None
+
+from ..async_driver import AsyncDriver
+from ..cost_mixin import CostMixin
+from .azure_driver import AzureDriver
+
+
+class AsyncAzureDriver(CostMixin, AsyncDriver):
+    supports_json_mode = True
+    supports_json_schema = True
+
+    MODEL_PRICING = AzureDriver.MODEL_PRICING
+
+    def __init__(
+        self,
+        api_key: str | None = None,
+        endpoint: str | None = None,
+        deployment_id: str | None = None,
+        model: str = "gpt-4o-mini",
+    ):
+        self.api_key = api_key or os.getenv("AZURE_API_KEY")
+        self.endpoint = endpoint or os.getenv("AZURE_API_ENDPOINT")
+        self.deployment_id = deployment_id or os.getenv("AZURE_DEPLOYMENT_ID")
+        self.api_version = os.getenv("AZURE_API_VERSION", "2023-07-01-preview")
+        self.model = model
+
+        if not self.api_key:
+            raise ValueError("Missing Azure API key (AZURE_API_KEY).")
+        if not self.endpoint:
+            raise ValueError("Missing Azure API endpoint (AZURE_API_ENDPOINT).")
+        if not self.deployment_id:
+            raise ValueError("Missing Azure deployment ID (AZURE_DEPLOYMENT_ID).")
+
+        if AsyncAzureOpenAI:
+            self.client = AsyncAzureOpenAI(
+                api_key=self.api_key,
+                api_version=self.api_version,
+                azure_endpoint=self.endpoint,
+            )
+        else:
+            self.client = None
+
+    supports_messages = True
+
+    async def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
+        messages = [{"role": "user", "content": prompt}]
+        return await self._do_generate(messages, options)
+
+    async def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
+        return await self._do_generate(messages, options)
+
+    async def _do_generate(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
+        if self.client is None:
+            raise RuntimeError("openai package (>=1.0.0) with AsyncAzureOpenAI not installed")
+
+        model = options.get("model", self.model)
+        model_info = self.MODEL_PRICING.get(model, {})
+        tokens_param = model_info.get("tokens_param", "max_tokens")
+        supports_temperature = model_info.get("supports_temperature", True)
+
+        opts = {"temperature": 1.0, "max_tokens": 512, **options}
+
+        kwargs = {
+            "model": self.deployment_id,
+            "messages": messages,
+        }
+        kwargs[tokens_param] = opts.get("max_tokens", 512)
+
+        if supports_temperature and "temperature" in opts:
+            kwargs["temperature"] = opts["temperature"]
+
+        # Native JSON mode support
+        if options.get("json_mode"):
+            json_schema = options.get("json_schema")
+            if json_schema:
+                kwargs["response_format"] = {
+                    "type": "json_schema",
+                    "json_schema": {
+                        "name": "extraction",
+                        "strict": True,
+                        "schema": json_schema,
+                    },
+                }
+            else:
+                kwargs["response_format"] = {"type": "json_object"}
+
+        resp = await self.client.chat.completions.create(**kwargs)
+
+        usage = getattr(resp, "usage", None)
+        prompt_tokens = getattr(usage, "prompt_tokens", 0)
+        completion_tokens = getattr(usage, "completion_tokens", 0)
+        total_tokens = getattr(usage, "total_tokens", 0)
+
+        total_cost = self._calculate_cost("azure", model, prompt_tokens, completion_tokens)
+
+        meta = {
+            "prompt_tokens": prompt_tokens,
+            "completion_tokens": completion_tokens,
+            "total_tokens": total_tokens,
+            "cost": total_cost,
+            "raw_response": resp.model_dump(),
+            "model_name": model,
+            "deployment_id": self.deployment_id,
+        }
+
+        text = resp.choices[0].message.content
+        return {"text": text, "meta": meta}
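A minimal usage sketch for the new async Azure driver, assuming the credentials validated in __init__ are available and importing by the module path this diff adds (package-level re-exports live in prompture/drivers/__init__.py, whose contents are not shown here):

import asyncio

from prompture.drivers.async_azure_driver import AsyncAzureDriver

async def main() -> None:
    # Falls back to AZURE_API_KEY / AZURE_API_ENDPOINT / AZURE_DEPLOYMENT_ID env vars
    driver = AsyncAzureDriver(model="gpt-4o-mini")
    # Passing json_schema together with json_mode selects strict structured output
    result = await driver.generate(
        "Extract name and age from: Alice is 30.",
        {
            "json_mode": True,
            "json_schema": {
                "type": "object",
                "properties": {"name": {"type": "string"}, "age": {"type": "integer"}},
                "required": ["name", "age"],
                "additionalProperties": False,
            },
            "max_tokens": 128,
        },
    )
    print(result["text"], result["meta"]["cost"])

asyncio.run(main())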
prompture/drivers/async_claude_driver.py
@@ -0,0 +1,107 @@
+"""Async Anthropic Claude driver. Requires the ``anthropic`` package."""
+
+from __future__ import annotations
+
+import json
+import os
+from typing import Any
+
+try:
+    import anthropic
+except Exception:
+    anthropic = None
+
+from ..async_driver import AsyncDriver
+from ..cost_mixin import CostMixin
+from .claude_driver import ClaudeDriver
+
+
+class AsyncClaudeDriver(CostMixin, AsyncDriver):
+    supports_json_mode = True
+    supports_json_schema = True
+
+    MODEL_PRICING = ClaudeDriver.MODEL_PRICING
+
+    def __init__(self, api_key: str | None = None, model: str = "claude-3-5-haiku-20241022"):
+        self.api_key = api_key or os.getenv("CLAUDE_API_KEY")
+        self.model = model or os.getenv("CLAUDE_MODEL_NAME", "claude-3-5-haiku-20241022")
+
+    supports_messages = True
+
+    async def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
+        messages = [{"role": "user", "content": prompt}]
+        return await self._do_generate(messages, options)
+
+    async def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
+        return await self._do_generate(messages, options)
+
+    async def _do_generate(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
+        if anthropic is None:
+            raise RuntimeError("anthropic package not installed")
+
+        opts = {**{"temperature": 0.0, "max_tokens": 512}, **options}
+        model = options.get("model", self.model)
+
+        client = anthropic.AsyncAnthropic(api_key=self.api_key)
+
+        # Anthropic requires system messages as a top-level parameter
+        system_content = None
+        api_messages = []
+        for msg in messages:
+            if msg.get("role") == "system":
+                system_content = msg.get("content", "")
+            else:
+                api_messages.append(msg)
+
+        # Build common kwargs
+        common_kwargs: dict[str, Any] = {
+            "model": model,
+            "messages": api_messages,
+            "temperature": opts["temperature"],
+            "max_tokens": opts["max_tokens"],
+        }
+        if system_content:
+            common_kwargs["system"] = system_content
+
+        # Native JSON mode: use tool-use for schema enforcement
+        if options.get("json_mode"):
+            json_schema = options.get("json_schema")
+            if json_schema:
+                tool_def = {
+                    "name": "extract_json",
+                    "description": "Extract structured data matching the schema",
+                    "input_schema": json_schema,
+                }
+                resp = await client.messages.create(
+                    **common_kwargs,
+                    tools=[tool_def],
+                    tool_choice={"type": "tool", "name": "extract_json"},
+                )
+                text = ""
+                for block in resp.content:
+                    if block.type == "tool_use":
+                        text = json.dumps(block.input)
+                        break
+            else:
+                resp = await client.messages.create(**common_kwargs)
+                text = resp.content[0].text
+        else:
+            resp = await client.messages.create(**common_kwargs)
+            text = resp.content[0].text
+
+        prompt_tokens = resp.usage.input_tokens
+        completion_tokens = resp.usage.output_tokens
+        total_tokens = prompt_tokens + completion_tokens
+
+        total_cost = self._calculate_cost("claude", model, prompt_tokens, completion_tokens)
+
+        meta = {
+            "prompt_tokens": prompt_tokens,
+            "completion_tokens": completion_tokens,
+            "total_tokens": total_tokens,
+            "cost": total_cost,
+            "raw_response": dict(resp),
+            "model_name": model,
+        }
+
+        return {"text": text, "meta": meta}
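A sketch of the Claude driver's message handling, showing how a system message is lifted into Anthropic's top-level system parameter and how json_schema is enforced through a forced extract_json tool call (prompt and schema here are illustrative):

import asyncio

from prompture.drivers.async_claude_driver import AsyncClaudeDriver

async def main() -> None:
    driver = AsyncClaudeDriver()  # falls back to CLAUDE_API_KEY
    result = await driver.generate_messages(
        [
            {"role": "system", "content": "You extract structured data."},
            {"role": "user", "content": "Bob is 42 years old."},
        ],
        {
            "json_mode": True,
            "json_schema": {"type": "object", "properties": {"age": {"type": "integer"}}},
        },
    )
    # result["text"] is json.dumps() of the tool_use block's input
    print(result["text"])

asyncio.run(main())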
prompture/drivers/async_google_driver.py
@@ -0,0 +1,132 @@
+"""Async Google Generative AI (Gemini) driver."""
+
+from __future__ import annotations
+
+import logging
+import os
+from typing import Any
+
+import google.generativeai as genai
+
+from ..async_driver import AsyncDriver
+from ..cost_mixin import CostMixin
+from .google_driver import GoogleDriver
+
+logger = logging.getLogger(__name__)
+
+
+class AsyncGoogleDriver(CostMixin, AsyncDriver):
+    """Async driver for Google's Generative AI API (Gemini)."""
+
+    supports_json_mode = True
+    supports_json_schema = True
+
+    MODEL_PRICING = GoogleDriver.MODEL_PRICING
+    _PRICING_UNIT = 1_000_000
+
+    def __init__(self, api_key: str | None = None, model: str = "gemini-1.5-pro"):
+        self.api_key = api_key or os.getenv("GOOGLE_API_KEY")
+        if not self.api_key:
+            raise ValueError("Google API key not found. Set GOOGLE_API_KEY env var or pass api_key to constructor")
+        self.model = model
+        genai.configure(api_key=self.api_key)
+        self.options: dict[str, Any] = {}
+
+    def _calculate_cost_chars(self, prompt_chars: int, completion_chars: int) -> float:
+        """Calculate cost from character counts (same logic as sync GoogleDriver)."""
+        from ..model_rates import get_model_rates
+
+        live_rates = get_model_rates("google", self.model)
+        if live_rates:
+            est_prompt_tokens = prompt_chars / 4
+            est_completion_tokens = completion_chars / 4
+            prompt_cost = (est_prompt_tokens / 1_000_000) * live_rates["input"]
+            completion_cost = (est_completion_tokens / 1_000_000) * live_rates["output"]
+        else:
+            model_pricing = self.MODEL_PRICING.get(self.model, {"prompt": 0, "completion": 0})
+            prompt_cost = (prompt_chars / 1_000_000) * model_pricing["prompt"]
+            completion_cost = (completion_chars / 1_000_000) * model_pricing["completion"]
+        return round(prompt_cost + completion_cost, 6)
+
+    supports_messages = True
+
+    async def generate(self, prompt: str, options: dict[str, Any] | None = None) -> dict[str, Any]:
+        messages = [{"role": "user", "content": prompt}]
+        return await self._do_generate(messages, options)
+
+    async def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
+        return await self._do_generate(messages, options)
+
+    async def _do_generate(
+        self, messages: list[dict[str, str]], options: dict[str, Any] | None = None
+    ) -> dict[str, Any]:
+        merged_options = self.options.copy()
+        if options:
+            merged_options.update(options)
+
+        generation_config = merged_options.get("generation_config", {})
+        safety_settings = merged_options.get("safety_settings", {})
+
+        if "temperature" in merged_options and "temperature" not in generation_config:
+            generation_config["temperature"] = merged_options["temperature"]
+        if "max_tokens" in merged_options and "max_output_tokens" not in generation_config:
+            generation_config["max_output_tokens"] = merged_options["max_tokens"]
+        if "top_p" in merged_options and "top_p" not in generation_config:
+            generation_config["top_p"] = merged_options["top_p"]
+        if "top_k" in merged_options and "top_k" not in generation_config:
+            generation_config["top_k"] = merged_options["top_k"]
+
+        # Native JSON mode support
+        if merged_options.get("json_mode"):
+            generation_config["response_mime_type"] = "application/json"
+            json_schema = merged_options.get("json_schema")
+            if json_schema:
+                generation_config["response_schema"] = json_schema
+
+        # Convert messages to Gemini format
+        system_instruction = None
+        contents: list[dict[str, Any]] = []
+        for msg in messages:
+            role = msg.get("role", "user")
+            content = msg.get("content", "")
+            if role == "system":
+                system_instruction = content
+            else:
+                gemini_role = "model" if role == "assistant" else "user"
+                contents.append({"role": gemini_role, "parts": [content]})
+
+        try:
+            model_kwargs: dict[str, Any] = {}
+            if system_instruction:
+                model_kwargs["system_instruction"] = system_instruction
+            model = genai.GenerativeModel(self.model, **model_kwargs)
+
+            gen_input: Any = contents if len(contents) != 1 else contents[0]["parts"][0]
+            response = await model.generate_content_async(
+                gen_input,
+                generation_config=generation_config if generation_config else None,
+                safety_settings=safety_settings if safety_settings else None,
+            )
+
+            if not response.text:
+                raise ValueError("Empty response from model")
+
+            total_prompt_chars = sum(len(msg.get("content", "")) for msg in messages)
+            completion_chars = len(response.text)
+
+            total_cost = self._calculate_cost_chars(total_prompt_chars, completion_chars)
+
+            meta = {
+                "prompt_chars": total_prompt_chars,
+                "completion_chars": completion_chars,
+                "total_chars": total_prompt_chars + completion_chars,
+                "cost": total_cost,
+                "raw_response": response.prompt_feedback if hasattr(response, "prompt_feedback") else None,
+                "model_name": self.model,
+            }
+
+            return {"text": response.text, "meta": meta}
+
+        except Exception as e:
+            logger.error(f"Google API request failed: {e}")
+            raise RuntimeError(f"Google API request failed: {e}") from e
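Worth noting: unlike the token-based drivers, this one estimates cost from character counts (roughly four characters per token when live rates are available), so meta reports prompt_chars rather than prompt_tokens. A minimal sketch, assuming GOOGLE_API_KEY is set:

import asyncio

from prompture.drivers.async_google_driver import AsyncGoogleDriver

async def main() -> None:
    driver = AsyncGoogleDriver(model="gemini-1.5-pro")
    # max_tokens is remapped to Gemini's max_output_tokens in generation_config
    result = await driver.generate("Summarize httpx in one sentence.", {"max_tokens": 64, "temperature": 0.2})
    meta = result["meta"]
    print(meta["prompt_chars"], meta["completion_chars"], meta["cost"])

asyncio.run(main())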
prompture/drivers/async_grok_driver.py
@@ -0,0 +1,91 @@
+"""Async xAI Grok driver using httpx."""
+
+from __future__ import annotations
+
+import os
+from typing import Any
+
+import httpx
+
+from ..async_driver import AsyncDriver
+from ..cost_mixin import CostMixin
+from .grok_driver import GrokDriver
+
+
+class AsyncGrokDriver(CostMixin, AsyncDriver):
+    supports_json_mode = True
+
+    MODEL_PRICING = GrokDriver.MODEL_PRICING
+    _PRICING_UNIT = 1_000_000
+
+    def __init__(self, api_key: str | None = None, model: str = "grok-4-fast-reasoning"):
+        self.api_key = api_key or os.getenv("GROK_API_KEY")
+        self.model = model
+        self.api_base = "https://api.x.ai/v1"
+
+    supports_messages = True
+
+    async def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
+        messages = [{"role": "user", "content": prompt}]
+        return await self._do_generate(messages, options)
+
+    async def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
+        return await self._do_generate(messages, options)
+
+    async def _do_generate(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
+        if not self.api_key:
+            raise RuntimeError("GROK_API_KEY environment variable is required")
+
+        model = options.get("model", self.model)
+
+        model_info = self.MODEL_PRICING.get(model, {})
+        tokens_param = model_info.get("tokens_param", "max_tokens")
+        supports_temperature = model_info.get("supports_temperature", True)
+
+        opts = {"temperature": 1.0, "max_tokens": 512, **options}
+
+        payload = {
+            "model": model,
+            "messages": messages,
+        }
+        payload[tokens_param] = opts.get("max_tokens", 512)
+
+        if supports_temperature and "temperature" in opts:
+            payload["temperature"] = opts["temperature"]
+
+        # Native JSON mode support
+        if options.get("json_mode"):
+            payload["response_format"] = {"type": "json_object"}
+
+        headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
+
+        async with httpx.AsyncClient() as client:
+            try:
+                response = await client.post(
+                    f"{self.api_base}/chat/completions", headers=headers, json=payload, timeout=120
+                )
+                response.raise_for_status()
+                resp = response.json()
+            except httpx.HTTPStatusError as e:
+                raise RuntimeError(f"Grok API request failed: {e!s}") from e
+            except Exception as e:
+                raise RuntimeError(f"Grok API request failed: {e!s}") from e
+
+        usage = resp.get("usage", {})
+        prompt_tokens = usage.get("prompt_tokens", 0)
+        completion_tokens = usage.get("completion_tokens", 0)
+        total_tokens = usage.get("total_tokens", 0)
+
+        total_cost = self._calculate_cost("grok", model, prompt_tokens, completion_tokens)
+
+        meta = {
+            "prompt_tokens": prompt_tokens,
+            "completion_tokens": completion_tokens,
+            "total_tokens": total_tokens,
+            "cost": total_cost,
+            "raw_response": resp,
+            "model_name": model,
+        }
+
+        text = resp["choices"][0]["message"]["content"]
+        return {"text": text, "meta": meta}
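Unlike the SDK-backed drivers, this one talks to https://api.x.ai/v1/chat/completions directly over httpx, so the missing-key check happens at call time rather than in the constructor. A minimal sketch:

import asyncio

from prompture.drivers.async_grok_driver import AsyncGrokDriver

async def main() -> None:
    driver = AsyncGrokDriver()  # construction succeeds; generate() raises without GROK_API_KEY
    result = await driver.generate(
        'Reply with the JSON object {"ok": true}.',
        {"json_mode": True, "max_tokens": 32},
    )
    print(result["text"], result["meta"]["total_tokens"])

asyncio.run(main())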
prompture/drivers/async_groq_driver.py
@@ -0,0 +1,84 @@
+"""Async Groq driver. Requires the ``groq`` package."""
+
+from __future__ import annotations
+
+import os
+from typing import Any
+
+try:
+    import groq
+except Exception:
+    groq = None
+
+from ..async_driver import AsyncDriver
+from ..cost_mixin import CostMixin
+from .groq_driver import GroqDriver
+
+
+class AsyncGroqDriver(CostMixin, AsyncDriver):
+    supports_json_mode = True
+
+    MODEL_PRICING = GroqDriver.MODEL_PRICING
+
+    def __init__(self, api_key: str | None = None, model: str = "llama2-70b-4096"):
+        self.api_key = api_key or os.getenv("GROQ_API_KEY")
+        self.model = model
+        if groq:
+            self.client = groq.AsyncClient(api_key=self.api_key)
+        else:
+            self.client = None
+
+    supports_messages = True
+
+    async def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
+        messages = [{"role": "user", "content": prompt}]
+        return await self._do_generate(messages, options)
+
+    async def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
+        return await self._do_generate(messages, options)
+
+    async def _do_generate(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
+        if self.client is None:
+            raise RuntimeError("groq package is not installed")
+
+        model = options.get("model", self.model)
+
+        model_info = self.MODEL_PRICING.get(model, {})
+        tokens_param = model_info.get("tokens_param", "max_tokens")
+        supports_temperature = model_info.get("supports_temperature", True)
+
+        opts = {"temperature": 0.7, "max_tokens": 512, **options}
+
+        kwargs = {
+            "model": model,
+            "messages": messages,
+        }
+        kwargs[tokens_param] = opts.get("max_tokens", 512)
+
+        if supports_temperature and "temperature" in opts:
+            kwargs["temperature"] = opts["temperature"]
+
+        # Native JSON mode support
+        if options.get("json_mode"):
+            kwargs["response_format"] = {"type": "json_object"}
+
+        resp = await self.client.chat.completions.create(**kwargs)
+
+        usage = getattr(resp, "usage", None)
+        prompt_tokens = getattr(usage, "prompt_tokens", 0)
+        completion_tokens = getattr(usage, "completion_tokens", 0)
+        total_tokens = getattr(usage, "total_tokens", 0)
+
+        total_cost = self._calculate_cost("groq", model, prompt_tokens, completion_tokens)
+
+        meta = {
+            "prompt_tokens": prompt_tokens,
+            "completion_tokens": completion_tokens,
+            "total_tokens": total_tokens,
+            "cost": total_cost,
+            "raw_response": resp.model_dump(),
+            "model_name": model,
+        }
+
+        text = resp.choices[0].message.content
+        return {"text": text, "meta": meta}
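A minimal sketch for the Groq driver, assuming the groq package is installed and GROQ_API_KEY is set (the model name is just the constructor default shown above):

import asyncio

from prompture.drivers.async_groq_driver import AsyncGroqDriver

async def main() -> None:
    driver = AsyncGroqDriver()  # defaults to model="llama2-70b-4096"
    result = await driver.generate("Name three HTTP methods.", {"max_tokens": 64, "temperature": 0.7})
    print(result["text"], result["meta"]["cost"])

asyncio.run(main())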
prompture/drivers/async_hugging_driver.py
@@ -0,0 +1,61 @@
+"""Async Hugging Face driver using httpx."""
+
+from __future__ import annotations
+
+import os
+from typing import Any
+
+import httpx
+
+from ..async_driver import AsyncDriver
+
+
+class AsyncHuggingFaceDriver(AsyncDriver):
+    MODEL_PRICING = {"default": {"prompt": 0.0, "completion": 0.0}}
+
+    def __init__(self, endpoint: str | None = None, token: str | None = None, model: str = "bert-base-uncased"):
+        self.endpoint = endpoint or os.getenv("HF_ENDPOINT")
+        self.token = token or os.getenv("HF_TOKEN")
+        self.model = model
+
+        if not self.endpoint:
+            raise ValueError("Hugging Face endpoint is not configured. Set HF_ENDPOINT or pass explicitly.")
+        if not self.token:
+            raise ValueError("Hugging Face token is not configured. Set HF_TOKEN or pass explicitly.")
+
+        self.headers = {"Authorization": f"Bearer {self.token}"}
+
+    async def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
+        payload = {
+            "inputs": prompt,
+            "parameters": options,
+        }
+
+        async with httpx.AsyncClient() as client:
+            try:
+                r = await client.post(
+                    self.endpoint, headers=self.headers, json=payload, timeout=options.get("timeout", 60)
+                )
+                r.raise_for_status()
+                response_data = r.json()
+            except Exception as e:
+                raise RuntimeError(f"AsyncHuggingFaceDriver request failed: {e}") from e
+
+        text = None
+        if isinstance(response_data, list) and response_data and "generated_text" in response_data[0]:
+            text = response_data[0]["generated_text"]
+        elif isinstance(response_data, dict) and "generated_text" in response_data:
+            text = response_data["generated_text"]
+        else:
+            text = str(response_data)
+
+        meta = {
+            "prompt_tokens": 0,
+            "completion_tokens": 0,
+            "total_tokens": 0,
+            "cost": 0.0,
+            "raw_response": response_data,
+            "model_name": options.get("model", self.model),
+        }
+
+        return {"text": text, "meta": meta}
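This driver posts the raw Hugging Face Inference payload ({"inputs": ..., "parameters": options}), so options must already be valid HF generation parameters, and token counts and cost in meta are always zero. A sketch with an illustrative endpoint URL (any HF Inference-compatible endpoint works; the one below is an assumption, not taken from this diff):

import asyncio

from prompture.drivers.async_hugging_driver import AsyncHuggingFaceDriver

async def main() -> None:
    # Endpoint and token fall back to HF_ENDPOINT / HF_TOKEN; both are mandatory
    driver = AsyncHuggingFaceDriver(endpoint="https://api-inference.huggingface.co/models/gpt2")
    result = await driver.generate("Once upon a time", {"max_new_tokens": 20})
    print(result["text"])

asyncio.run(main())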
prompture/drivers/async_lmstudio_driver.py
@@ -0,0 +1,79 @@
+"""Async LM Studio driver using httpx."""
+
+from __future__ import annotations
+
+import logging
+import os
+from typing import Any
+
+import httpx
+
+from ..async_driver import AsyncDriver
+
+logger = logging.getLogger(__name__)
+
+
+class AsyncLMStudioDriver(AsyncDriver):
+    supports_json_mode = True
+
+    MODEL_PRICING = {"default": {"prompt": 0.0, "completion": 0.0}}
+
+    def __init__(self, endpoint: str | None = None, model: str = "deepseek/deepseek-r1-0528-qwen3-8b"):
+        self.endpoint = endpoint or os.getenv("LMSTUDIO_ENDPOINT", "http://127.0.0.1:1234/v1/chat/completions")
+        self.model = model
+        self.options: dict[str, Any] = {}
+
+    supports_messages = True
+
+    async def generate(self, prompt: str, options: dict[str, Any] | None = None) -> dict[str, Any]:
+        messages = [{"role": "user", "content": prompt}]
+        return await self._do_generate(messages, options)
+
+    async def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
+        return await self._do_generate(messages, options)
+
+    async def _do_generate(
+        self, messages: list[dict[str, str]], options: dict[str, Any] | None = None
+    ) -> dict[str, Any]:
+        merged_options = self.options.copy()
+        if options:
+            merged_options.update(options)
+
+        payload = {
+            "model": merged_options.get("model", self.model),
+            "messages": messages,
+            "temperature": merged_options.get("temperature", 0.7),
+        }
+
+        # Native JSON mode support
+        if merged_options.get("json_mode"):
+            payload["response_format"] = {"type": "json_object"}
+
+        async with httpx.AsyncClient() as client:
+            try:
+                r = await client.post(self.endpoint, json=payload, timeout=120)
+                r.raise_for_status()
+                response_data = r.json()
+            except Exception as e:
+                raise RuntimeError(f"AsyncLMStudioDriver request failed: {e}") from e
+
+        if "choices" not in response_data or not response_data["choices"]:
+            raise ValueError(f"Unexpected response format: {response_data}")
+
+        text = response_data["choices"][0]["message"]["content"]
+
+        usage = response_data.get("usage", {})
+        prompt_tokens = usage.get("prompt_tokens", 0)
+        completion_tokens = usage.get("completion_tokens", 0)
+        total_tokens = usage.get("total_tokens", prompt_tokens + completion_tokens)
+
+        meta = {
+            "prompt_tokens": prompt_tokens,
+            "completion_tokens": completion_tokens,
+            "total_tokens": total_tokens,
+            "cost": 0.0,
+            "raw_response": response_data,
+            "model_name": merged_options.get("model", self.model),
+        }
+
+        return {"text": text, "meta": meta}
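Finally, a local sketch for the LM Studio driver, which only assumes an LM Studio server listening on the default endpoint baked into the constructor:

import asyncio

from prompture.drivers.async_lmstudio_driver import AsyncLMStudioDriver

async def main() -> None:
    # Defaults to http://127.0.0.1:1234/v1/chat/completions (LMSTUDIO_ENDPOINT overrides)
    driver = AsyncLMStudioDriver()
    result = await driver.generate("Say hi.", {"temperature": 0.2})
    print(result["text"])  # meta["cost"] is always 0.0 for local models

asyncio.run(main())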