prompture 0.0.29.dev8-py3-none-any.whl → 0.0.38.dev2-py3-none-any.whl

This diff compares the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
Files changed (79)
  1. prompture/__init__.py +264 -23
  2. prompture/_version.py +34 -0
  3. prompture/agent.py +924 -0
  4. prompture/agent_types.py +156 -0
  5. prompture/aio/__init__.py +74 -0
  6. prompture/async_agent.py +880 -0
  7. prompture/async_conversation.py +789 -0
  8. prompture/async_core.py +803 -0
  9. prompture/async_driver.py +193 -0
  10. prompture/async_groups.py +551 -0
  11. prompture/cache.py +469 -0
  12. prompture/callbacks.py +55 -0
  13. prompture/cli.py +63 -4
  14. prompture/conversation.py +826 -0
  15. prompture/core.py +894 -263
  16. prompture/cost_mixin.py +51 -0
  17. prompture/discovery.py +187 -0
  18. prompture/driver.py +206 -5
  19. prompture/drivers/__init__.py +175 -67
  20. prompture/drivers/airllm_driver.py +109 -0
  21. prompture/drivers/async_airllm_driver.py +26 -0
  22. prompture/drivers/async_azure_driver.py +123 -0
  23. prompture/drivers/async_claude_driver.py +113 -0
  24. prompture/drivers/async_google_driver.py +316 -0
  25. prompture/drivers/async_grok_driver.py +97 -0
  26. prompture/drivers/async_groq_driver.py +90 -0
  27. prompture/drivers/async_hugging_driver.py +61 -0
  28. prompture/drivers/async_lmstudio_driver.py +148 -0
  29. prompture/drivers/async_local_http_driver.py +44 -0
  30. prompture/drivers/async_ollama_driver.py +135 -0
  31. prompture/drivers/async_openai_driver.py +102 -0
  32. prompture/drivers/async_openrouter_driver.py +102 -0
  33. prompture/drivers/async_registry.py +133 -0
  34. prompture/drivers/azure_driver.py +42 -9
  35. prompture/drivers/claude_driver.py +257 -34
  36. prompture/drivers/google_driver.py +295 -42
  37. prompture/drivers/grok_driver.py +35 -32
  38. prompture/drivers/groq_driver.py +33 -26
  39. prompture/drivers/hugging_driver.py +6 -6
  40. prompture/drivers/lmstudio_driver.py +97 -19
  41. prompture/drivers/local_http_driver.py +6 -6
  42. prompture/drivers/ollama_driver.py +168 -23
  43. prompture/drivers/openai_driver.py +184 -9
  44. prompture/drivers/openrouter_driver.py +37 -25
  45. prompture/drivers/registry.py +306 -0
  46. prompture/drivers/vision_helpers.py +153 -0
  47. prompture/field_definitions.py +106 -96
  48. prompture/group_types.py +147 -0
  49. prompture/groups.py +530 -0
  50. prompture/image.py +180 -0
  51. prompture/logging.py +80 -0
  52. prompture/model_rates.py +217 -0
  53. prompture/persistence.py +254 -0
  54. prompture/persona.py +482 -0
  55. prompture/runner.py +49 -47
  56. prompture/scaffold/__init__.py +1 -0
  57. prompture/scaffold/generator.py +84 -0
  58. prompture/scaffold/templates/Dockerfile.j2 +12 -0
  59. prompture/scaffold/templates/README.md.j2 +41 -0
  60. prompture/scaffold/templates/config.py.j2 +21 -0
  61. prompture/scaffold/templates/env.example.j2 +8 -0
  62. prompture/scaffold/templates/main.py.j2 +86 -0
  63. prompture/scaffold/templates/models.py.j2 +40 -0
  64. prompture/scaffold/templates/requirements.txt.j2 +5 -0
  65. prompture/serialization.py +218 -0
  66. prompture/server.py +183 -0
  67. prompture/session.py +117 -0
  68. prompture/settings.py +19 -1
  69. prompture/tools.py +219 -267
  70. prompture/tools_schema.py +254 -0
  71. prompture/validator.py +3 -3
  72. prompture-0.0.38.dev2.dist-info/METADATA +369 -0
  73. prompture-0.0.38.dev2.dist-info/RECORD +77 -0
  74. {prompture-0.0.29.dev8.dist-info → prompture-0.0.38.dev2.dist-info}/WHEEL +1 -1
  75. prompture-0.0.29.dev8.dist-info/METADATA +0 -368
  76. prompture-0.0.29.dev8.dist-info/RECORD +0 -27
  77. {prompture-0.0.29.dev8.dist-info → prompture-0.0.38.dev2.dist-info}/entry_points.txt +0 -0
  78. {prompture-0.0.29.dev8.dist-info → prompture-0.0.38.dev2.dist-info}/licenses/LICENSE +0 -0
  79. {prompture-0.0.29.dev8.dist-info → prompture-0.0.38.dev2.dist-info}/top_level.txt +0 -0
prompture/drivers/async_lmstudio_driver.py
@@ -0,0 +1,148 @@
+ """Async LM Studio driver using httpx."""
+
+ from __future__ import annotations
+
+ import logging
+ import os
+ from typing import Any
+
+ import httpx
+
+ from ..async_driver import AsyncDriver
+
+ logger = logging.getLogger(__name__)
+
+
+ class AsyncLMStudioDriver(AsyncDriver):
+     supports_json_mode = True
+     supports_json_schema = True
+     supports_vision = True
+
+     MODEL_PRICING = {"default": {"prompt": 0.0, "completion": 0.0}}
+
+     def __init__(
+         self,
+         endpoint: str | None = None,
+         model: str = "deepseek/deepseek-r1-0528-qwen3-8b",
+         api_key: str | None = None,
+     ):
+         self.endpoint = endpoint or os.getenv("LMSTUDIO_ENDPOINT", "http://127.0.0.1:1234/v1/chat/completions")
+         self.model = model
+         self.options: dict[str, Any] = {}
+
+         # Derive base_url once for reuse across management endpoints
+         self.base_url = self.endpoint.split("/v1/")[0]
+
+         # API key for LM Studio 0.4.0+ authentication
+         self.api_key = api_key or os.getenv("LMSTUDIO_API_KEY")
+         self._headers = self._build_headers()
+
+     supports_messages = True
+
+     def _build_headers(self) -> dict[str, str]:
+         """Build request headers, including auth if an API key is configured."""
+         headers: dict[str, str] = {"Content-Type": "application/json"}
+         if self.api_key:
+             headers["Authorization"] = f"Bearer {self.api_key}"
+         return headers
+
+     def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+         from .vision_helpers import _prepare_openai_vision_messages
+
+         return _prepare_openai_vision_messages(messages)
+
+     async def generate(self, prompt: str, options: dict[str, Any] | None = None) -> dict[str, Any]:
+         messages = [{"role": "user", "content": prompt}]
+         return await self._do_generate(messages, options)
+
+     async def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
+         return await self._do_generate(self._prepare_messages(messages), options)
+
+     async def _do_generate(
+         self, messages: list[dict[str, str]], options: dict[str, Any] | None = None
+     ) -> dict[str, Any]:
+         merged_options = self.options.copy()
+         if options:
+             merged_options.update(options)
+
+         payload = {
+             "model": merged_options.get("model", self.model),
+             "messages": messages,
+             "temperature": merged_options.get("temperature", 0.7),
+         }
+
+         # Native JSON mode support (LM Studio requires json_schema, not json_object)
+         if merged_options.get("json_mode"):
+             json_schema = merged_options.get("json_schema")
+             if json_schema:
+                 payload["response_format"] = {
+                     "type": "json_schema",
+                     "json_schema": {
+                         "name": "extraction",
+                         "schema": json_schema,
+                     },
+                 }
+             else:
+                 # No schema provided — omit response_format entirely;
+                 # LM Studio rejects "json_object" type.
+                 pass
+
+         async with httpx.AsyncClient() as client:
+             try:
+                 r = await client.post(self.endpoint, json=payload, headers=self._headers, timeout=120)
+                 r.raise_for_status()
+                 response_data = r.json()
+             except Exception as e:
+                 raise RuntimeError(f"AsyncLMStudioDriver request failed: {e}") from e
+
+         if "choices" not in response_data or not response_data["choices"]:
+             raise ValueError(f"Unexpected response format: {response_data}")
+
+         text = response_data["choices"][0]["message"]["content"]
+
+         usage = response_data.get("usage", {})
+         prompt_tokens = usage.get("prompt_tokens", 0)
+         completion_tokens = usage.get("completion_tokens", 0)
+         total_tokens = usage.get("total_tokens", prompt_tokens + completion_tokens)
+
+         meta = {
+             "prompt_tokens": prompt_tokens,
+             "completion_tokens": completion_tokens,
+             "total_tokens": total_tokens,
+             "cost": 0.0,
+             "raw_response": response_data,
+             "model_name": merged_options.get("model", self.model),
+         }
+
+         return {"text": text, "meta": meta}
+
+     # -- Model management (LM Studio 0.4.0+) ----------------------------------
+
+     async def list_models(self) -> list[dict[str, Any]]:
+         """List currently loaded models via GET /v1/models (OpenAI-compatible)."""
+         url = f"{self.base_url}/v1/models"
+         async with httpx.AsyncClient() as client:
+             r = await client.get(url, headers=self._headers, timeout=10)
+             r.raise_for_status()
+             data = r.json()
+         return data.get("data", [])
+
+     async def load_model(self, model: str, context_length: int | None = None) -> dict[str, Any]:
+         """Load a model into LM Studio via POST /api/v1/models/load."""
+         url = f"{self.base_url}/api/v1/models/load"
+         payload: dict[str, Any] = {"model": model}
+         if context_length is not None:
+             payload["context_length"] = context_length
+         async with httpx.AsyncClient() as client:
+             r = await client.post(url, json=payload, headers=self._headers, timeout=120)
+             r.raise_for_status()
+             return r.json()
+
+     async def unload_model(self, model: str) -> dict[str, Any]:
+         """Unload a model from LM Studio via POST /api/v1/models/unload."""
+         url = f"{self.base_url}/api/v1/models/unload"
+         payload = {"instance_id": model}
+         async with httpx.AsyncClient() as client:
+             r = await client.post(url, json=payload, headers=self._headers, timeout=30)
+             r.raise_for_status()
+             return r.json()
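
For orientation, a minimal usage sketch of the driver added above. It assumes an LM Studio server is listening on the default local endpoint; the prompt and temperature are illustrative, and the printed fields follow the {"text", "meta"} envelope the driver returns:

    import asyncio

    from prompture.drivers.async_lmstudio_driver import AsyncLMStudioDriver


    async def main() -> None:
        # Defaults to http://127.0.0.1:1234/v1/chat/completions
        driver = AsyncLMStudioDriver()

        # GET /v1/models on the derived base_url (LM Studio 0.4.0+)
        loaded = await driver.list_models()
        print([m.get("id") for m in loaded])

        result = await driver.generate("Say hello in one word.", {"temperature": 0.2})
        print(result["text"], result["meta"]["total_tokens"])


    asyncio.run(main())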
prompture/drivers/async_local_http_driver.py
@@ -0,0 +1,44 @@
+ """Async LocalHTTP driver using httpx."""
+
+ from __future__ import annotations
+
+ import os
+ from typing import Any
+
+ import httpx
+
+ from ..async_driver import AsyncDriver
+
+
+ class AsyncLocalHTTPDriver(AsyncDriver):
+     MODEL_PRICING = {"default": {"prompt": 0.0, "completion": 0.0}}
+
+     def __init__(self, endpoint: str | None = None, model: str = "local-model"):
+         self.endpoint = endpoint or os.getenv("LOCAL_HTTP_ENDPOINT", "http://localhost:8000/generate")
+         self.model = model
+
+     async def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
+         payload = {"prompt": prompt, "options": options}
+
+         async with httpx.AsyncClient() as client:
+             try:
+                 r = await client.post(self.endpoint, json=payload, timeout=options.get("timeout", 30))
+                 r.raise_for_status()
+                 response_data = r.json()
+             except Exception as e:
+                 raise RuntimeError(f"AsyncLocalHTTPDriver request failed: {e}") from e
+
+         if "text" in response_data and "meta" in response_data:
+             return response_data
+
+         meta = {
+             "prompt_tokens": response_data.get("prompt_tokens", 0),
+             "completion_tokens": response_data.get("completion_tokens", 0),
+             "total_tokens": response_data.get("total_tokens", 0),
+             "cost": 0.0,
+             "raw_response": response_data,
+             "model_name": options.get("model", self.model),
+         }
+
+         text = response_data.get("text") or response_data.get("response") or str(response_data)
+         return {"text": text, "meta": meta}
prompture/drivers/async_ollama_driver.py
@@ -0,0 +1,135 @@
+ """Async Ollama driver using httpx."""
+
+ from __future__ import annotations
+
+ import logging
+ import os
+ from typing import Any
+
+ import httpx
+
+ from ..async_driver import AsyncDriver
+
+ logger = logging.getLogger(__name__)
+
+
+ class AsyncOllamaDriver(AsyncDriver):
+     supports_json_mode = True
+     supports_json_schema = True
+     supports_vision = True
+
+     MODEL_PRICING = {"default": {"prompt": 0.0, "completion": 0.0}}
+
+     def __init__(self, endpoint: str | None = None, model: str = "llama3"):
+         self.endpoint = endpoint or os.getenv("OLLAMA_ENDPOINT", "http://localhost:11434/api/generate")
+         self.model = model
+         self.options: dict[str, Any] = {}
+
+     supports_messages = True
+
+     def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+         from .vision_helpers import _prepare_ollama_vision_messages
+
+         return _prepare_ollama_vision_messages(messages)
+
+     async def generate(self, prompt: str, options: dict[str, Any] | None = None) -> dict[str, Any]:
+         merged_options = self.options.copy()
+         if options:
+             merged_options.update(options)
+
+         payload = {
+             "prompt": prompt,
+             "model": merged_options.get("model", self.model),
+             "stream": False,
+         }
+
+         # Native JSON mode / structured output support
+         if merged_options.get("json_mode"):
+             json_schema = merged_options.get("json_schema")
+             payload["format"] = json_schema if json_schema else "json"
+
+         if "temperature" in merged_options:
+             payload["temperature"] = merged_options["temperature"]
+         if "top_p" in merged_options:
+             payload["top_p"] = merged_options["top_p"]
+         if "top_k" in merged_options:
+             payload["top_k"] = merged_options["top_k"]
+
+         async with httpx.AsyncClient() as client:
+             try:
+                 r = await client.post(self.endpoint, json=payload, timeout=120)
+                 r.raise_for_status()
+                 response_data = r.json()
+             except httpx.HTTPStatusError as e:
+                 raise RuntimeError(f"Ollama request failed: {e}") from e
+             except Exception as e:
+                 raise RuntimeError(f"Ollama request failed: {e}") from e
+
+         prompt_tokens = response_data.get("prompt_eval_count", 0)
+         completion_tokens = response_data.get("eval_count", 0)
+         total_tokens = prompt_tokens + completion_tokens
+
+         meta = {
+             "prompt_tokens": prompt_tokens,
+             "completion_tokens": completion_tokens,
+             "total_tokens": total_tokens,
+             "cost": 0.0,
+             "raw_response": response_data,
+             "model_name": merged_options.get("model", self.model),
+         }
+
+         return {"text": response_data.get("response", ""), "meta": meta}
+
+     async def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
+         """Use Ollama's /api/chat endpoint for multi-turn conversations."""
+         messages = self._prepare_messages(messages)
+         merged_options = self.options.copy()
+         if options:
+             merged_options.update(options)
+
+         # Derive the chat endpoint from the generate endpoint
+         chat_endpoint = self.endpoint.replace("/api/generate", "/api/chat")
+
+         payload: dict[str, Any] = {
+             "model": merged_options.get("model", self.model),
+             "messages": messages,
+             "stream": False,
+         }
+
+         if merged_options.get("json_mode"):
+             json_schema = merged_options.get("json_schema")
+             payload["format"] = json_schema if json_schema else "json"
+
+         if "temperature" in merged_options:
+             payload["temperature"] = merged_options["temperature"]
+         if "top_p" in merged_options:
+             payload["top_p"] = merged_options["top_p"]
+         if "top_k" in merged_options:
+             payload["top_k"] = merged_options["top_k"]
+
+         async with httpx.AsyncClient() as client:
+             try:
+                 r = await client.post(chat_endpoint, json=payload, timeout=120)
+                 r.raise_for_status()
+                 response_data = r.json()
+             except httpx.HTTPStatusError as e:
+                 raise RuntimeError(f"Ollama chat request failed: {e}") from e
+             except Exception as e:
+                 raise RuntimeError(f"Ollama chat request failed: {e}") from e
+
+         prompt_tokens = response_data.get("prompt_eval_count", 0)
+         completion_tokens = response_data.get("eval_count", 0)
+         total_tokens = prompt_tokens + completion_tokens
+
+         meta = {
+             "prompt_tokens": prompt_tokens,
+             "completion_tokens": completion_tokens,
+             "total_tokens": total_tokens,
+             "cost": 0.0,
+             "raw_response": response_data,
+             "model_name": merged_options.get("model", self.model),
+         }
+
+         message = response_data.get("message", {})
+         text = message.get("content", "")
+         return {"text": text, "meta": meta}
prompture/drivers/async_openai_driver.py
@@ -0,0 +1,102 @@
+ """Async OpenAI driver. Requires the ``openai`` package (>=1.0.0)."""
+
+ from __future__ import annotations
+
+ import os
+ from typing import Any
+
+ try:
+     from openai import AsyncOpenAI
+ except Exception:
+     AsyncOpenAI = None
+
+ from ..async_driver import AsyncDriver
+ from ..cost_mixin import CostMixin
+ from .openai_driver import OpenAIDriver
+
+
+ class AsyncOpenAIDriver(CostMixin, AsyncDriver):
+     supports_json_mode = True
+     supports_json_schema = True
+     supports_vision = True
+
+     MODEL_PRICING = OpenAIDriver.MODEL_PRICING
+
+     def __init__(self, api_key: str | None = None, model: str = "gpt-4o-mini"):
+         self.api_key = api_key or os.getenv("OPENAI_API_KEY")
+         self.model = model
+         if AsyncOpenAI:
+             self.client = AsyncOpenAI(api_key=self.api_key)
+         else:
+             self.client = None
+
+     supports_messages = True
+
+     def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+         from .vision_helpers import _prepare_openai_vision_messages
+
+         return _prepare_openai_vision_messages(messages)
+
+     async def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
+         messages = [{"role": "user", "content": prompt}]
+         return await self._do_generate(messages, options)
+
+     async def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
+         return await self._do_generate(self._prepare_messages(messages), options)
+
+     async def _do_generate(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
+         if self.client is None:
+             raise RuntimeError("openai package (>=1.0.0) is not installed")
+
+         model = options.get("model", self.model)
+
+         model_info = self.MODEL_PRICING.get(model, {})
+         tokens_param = model_info.get("tokens_param", "max_tokens")
+         supports_temperature = model_info.get("supports_temperature", True)
+
+         opts = {"temperature": 1.0, "max_tokens": 512, **options}
+
+         kwargs = {
+             "model": model,
+             "messages": messages,
+         }
+         kwargs[tokens_param] = opts.get("max_tokens", 512)
+
+         if supports_temperature and "temperature" in opts:
+             kwargs["temperature"] = opts["temperature"]
+
+         # Native JSON mode support
+         if options.get("json_mode"):
+             json_schema = options.get("json_schema")
+             if json_schema:
+                 kwargs["response_format"] = {
+                     "type": "json_schema",
+                     "json_schema": {
+                         "name": "extraction",
+                         "strict": True,
+                         "schema": json_schema,
+                     },
+                 }
+             else:
+                 kwargs["response_format"] = {"type": "json_object"}
+
+         resp = await self.client.chat.completions.create(**kwargs)
+
+         usage = getattr(resp, "usage", None)
+         prompt_tokens = getattr(usage, "prompt_tokens", 0)
+         completion_tokens = getattr(usage, "completion_tokens", 0)
+         total_tokens = getattr(usage, "total_tokens", 0)
+
+         total_cost = self._calculate_cost("openai", model, prompt_tokens, completion_tokens)
+
+         meta = {
+             "prompt_tokens": prompt_tokens,
+             "completion_tokens": completion_tokens,
+             "total_tokens": total_tokens,
+             "cost": total_cost,
+             "raw_response": resp.model_dump(),
+             "model_name": model,
+         }
+
+         text = resp.choices[0].message.content
+         return {"text": text, "meta": meta}
prompture/drivers/async_openrouter_driver.py
@@ -0,0 +1,102 @@
+ """Async OpenRouter driver using httpx."""
+
+ from __future__ import annotations
+
+ import os
+ from typing import Any
+
+ import httpx
+
+ from ..async_driver import AsyncDriver
+ from ..cost_mixin import CostMixin
+ from .openrouter_driver import OpenRouterDriver
+
+
+ class AsyncOpenRouterDriver(CostMixin, AsyncDriver):
+     supports_json_mode = True
+     supports_vision = True
+
+     MODEL_PRICING = OpenRouterDriver.MODEL_PRICING
+
+     def __init__(self, api_key: str | None = None, model: str = "openai/gpt-3.5-turbo"):
+         self.api_key = api_key or os.getenv("OPENROUTER_API_KEY")
+         if not self.api_key:
+             raise ValueError("OpenRouter API key not found. Set OPENROUTER_API_KEY env var.")
+         self.model = model
+         self.base_url = "https://openrouter.ai/api/v1"
+         self.headers = {
+             "Authorization": f"Bearer {self.api_key}",
+             "HTTP-Referer": "https://github.com/jhd3197/prompture",
+             "Content-Type": "application/json",
+         }
+
+     supports_messages = True
+
+     def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+         from .vision_helpers import _prepare_openai_vision_messages
+
+         return _prepare_openai_vision_messages(messages)
+
+     async def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
+         messages = [{"role": "user", "content": prompt}]
+         return await self._do_generate(messages, options)
+
+     async def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
+         return await self._do_generate(self._prepare_messages(messages), options)
+
+     async def _do_generate(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
+         model = options.get("model", self.model)
+
+         model_info = self.MODEL_PRICING.get(model, {})
+         tokens_param = model_info.get("tokens_param", "max_tokens")
+         supports_temperature = model_info.get("supports_temperature", True)
+
+         opts = {"temperature": 1.0, "max_tokens": 512, **options}
+
+         data = {
+             "model": model,
+             "messages": messages,
+         }
+         data[tokens_param] = opts.get("max_tokens", 512)
+
+         if supports_temperature and "temperature" in opts:
+             data["temperature"] = opts["temperature"]
+
+         # Native JSON mode support
+         if options.get("json_mode"):
+             data["response_format"] = {"type": "json_object"}
+
+         async with httpx.AsyncClient() as client:
+             try:
+                 response = await client.post(
+                     f"{self.base_url}/chat/completions",
+                     headers=self.headers,
+                     json=data,
+                     timeout=120,
+                 )
+                 response.raise_for_status()
+                 resp = response.json()
+             except httpx.HTTPStatusError as e:
+                 error_msg = f"OpenRouter API request failed: {e!s}"
+                 raise RuntimeError(error_msg) from e
+             except Exception as e:
+                 raise RuntimeError(f"OpenRouter API request failed: {e!s}") from e
+
+         usage = resp.get("usage", {})
+         prompt_tokens = usage.get("prompt_tokens", 0)
+         completion_tokens = usage.get("completion_tokens", 0)
+         total_tokens = usage.get("total_tokens", 0)
+
+         total_cost = self._calculate_cost("openrouter", model, prompt_tokens, completion_tokens)
+
+         meta = {
+             "prompt_tokens": prompt_tokens,
+             "completion_tokens": completion_tokens,
+             "total_tokens": total_tokens,
+             "cost": total_cost,
+             "raw_response": resp,
+             "model_name": model,
+         }
+
+         text = resp["choices"][0]["message"]["content"]
+         return {"text": text, "meta": meta}
prompture/drivers/async_registry.py
@@ -0,0 +1,133 @@
+ """Async driver registry — mirrors the sync DRIVER_REGISTRY.
+
+ This module provides async driver registration and factory functions.
+ Custom async drivers can be registered via the ``register_async_driver()``
+ function or discovered via entry points.
+
+ Entry Point Discovery:
+     Add to your pyproject.toml:
+
+     [project.entry-points."prompture.async_drivers"]
+     my_provider = "my_package.drivers:my_async_driver_factory"
+ """
+
+ from __future__ import annotations
+
+ from ..settings import settings
+ from .async_airllm_driver import AsyncAirLLMDriver
+ from .async_azure_driver import AsyncAzureDriver
+ from .async_claude_driver import AsyncClaudeDriver
+ from .async_google_driver import AsyncGoogleDriver
+ from .async_grok_driver import AsyncGrokDriver
+ from .async_groq_driver import AsyncGroqDriver
+ from .async_lmstudio_driver import AsyncLMStudioDriver
+ from .async_local_http_driver import AsyncLocalHTTPDriver
+ from .async_ollama_driver import AsyncOllamaDriver
+ from .async_openai_driver import AsyncOpenAIDriver
+ from .async_openrouter_driver import AsyncOpenRouterDriver
+ from .registry import (
+     _get_async_registry,
+     get_async_driver_factory,
+     register_async_driver,
+ )
+
+ # Register built-in async drivers
+ register_async_driver(
+     "openai",
+     lambda model=None: AsyncOpenAIDriver(api_key=settings.openai_api_key, model=model or settings.openai_model),
+     overwrite=True,
+ )
+ register_async_driver(
+     "ollama",
+     lambda model=None: AsyncOllamaDriver(endpoint=settings.ollama_endpoint, model=model or settings.ollama_model),
+     overwrite=True,
+ )
+ register_async_driver(
+     "claude",
+     lambda model=None: AsyncClaudeDriver(api_key=settings.claude_api_key, model=model or settings.claude_model),
+     overwrite=True,
+ )
+ register_async_driver(
+     "lmstudio",
+     lambda model=None: AsyncLMStudioDriver(
+         endpoint=settings.lmstudio_endpoint,
+         model=model or settings.lmstudio_model,
+         api_key=settings.lmstudio_api_key,
+     ),
+     overwrite=True,
+ )
+ register_async_driver(
+     "azure",
+     lambda model=None: AsyncAzureDriver(
+         api_key=settings.azure_api_key, endpoint=settings.azure_api_endpoint, deployment_id=settings.azure_deployment_id
+     ),
+     overwrite=True,
+ )
+ register_async_driver(
+     "local_http",
+     lambda model=None: AsyncLocalHTTPDriver(endpoint=settings.local_http_endpoint, model=model),
+     overwrite=True,
+ )
+ register_async_driver(
+     "google",
+     lambda model=None: AsyncGoogleDriver(api_key=settings.google_api_key, model=model or settings.google_model),
+     overwrite=True,
+ )
+ register_async_driver(
+     "groq",
+     lambda model=None: AsyncGroqDriver(api_key=settings.groq_api_key, model=model or settings.groq_model),
+     overwrite=True,
+ )
+ register_async_driver(
+     "openrouter",
+     lambda model=None: AsyncOpenRouterDriver(
+         api_key=settings.openrouter_api_key, model=model or settings.openrouter_model
+     ),
+     overwrite=True,
+ )
+ register_async_driver(
+     "grok",
+     lambda model=None: AsyncGrokDriver(api_key=settings.grok_api_key, model=model or settings.grok_model),
+     overwrite=True,
+ )
+ register_async_driver(
+     "airllm",
+     lambda model=None: AsyncAirLLMDriver(
+         model=model or settings.airllm_model,
+         compression=settings.airllm_compression,
+     ),
+     overwrite=True,
+ )
+
+ # Backwards compatibility: expose registry dict
+ ASYNC_DRIVER_REGISTRY = _get_async_registry()
+
+
+ def get_async_driver(provider_name: str | None = None):
+     """Factory to get an async driver instance based on the provider name.
+
+     Uses default model from settings if not overridden.
+     """
+     provider = (provider_name or settings.ai_provider or "ollama").strip().lower()
+     factory = get_async_driver_factory(provider)
+     return factory()
+
+
+ def get_async_driver_for_model(model_str: str):
+     """Factory to get an async driver instance based on a full model string.
+
+     Format: ``provider/model_id``
+     Example: ``"openai/gpt-4-turbo-preview"``
+     """
+     if not isinstance(model_str, str):
+         raise ValueError(f"Model string must be a string, got {type(model_str)}")
+
+     if not model_str:
+         raise ValueError("Model string cannot be empty")
+
+     parts = model_str.split("/", 1)
+     provider = parts[0].lower()
+     model_id = parts[1] if len(parts) > 1 else None
+
+     factory = get_async_driver_factory(provider)
+     return factory(model_id)
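
The two factory entry points compose with registration as follows; a short sketch (the built-in lookups are real, MyAsyncDriver is hypothetical):

    from prompture.drivers.async_registry import (
        get_async_driver,
        get_async_driver_for_model,
        register_async_driver,
    )

    # Provider-only lookup; the model comes from settings defaults.
    ollama = get_async_driver("ollama")

    # "provider/model_id" lookup; the model id is passed through to the factory.
    openai = get_async_driver_for_model("openai/gpt-4-turbo-preview")

    # Custom registration (MyAsyncDriver would be an AsyncDriver subclass;
    # the factory must accept an optional model argument):
    # register_async_driver("my_provider", lambda model=None: MyAsyncDriver(model=model))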