prompture 0.0.36.dev1__py3-none-any.whl → 0.0.37.dev1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- prompture/__init__.py +120 -2
- prompture/_version.py +2 -2
- prompture/agent.py +925 -0
- prompture/agent_types.py +156 -0
- prompture/async_agent.py +879 -0
- prompture/async_conversation.py +199 -17
- prompture/async_driver.py +24 -0
- prompture/async_groups.py +551 -0
- prompture/conversation.py +213 -18
- prompture/core.py +30 -12
- prompture/discovery.py +24 -1
- prompture/driver.py +38 -0
- prompture/drivers/__init__.py +5 -1
- prompture/drivers/async_azure_driver.py +7 -1
- prompture/drivers/async_claude_driver.py +7 -1
- prompture/drivers/async_google_driver.py +24 -4
- prompture/drivers/async_grok_driver.py +7 -1
- prompture/drivers/async_groq_driver.py +7 -1
- prompture/drivers/async_lmstudio_driver.py +59 -3
- prompture/drivers/async_ollama_driver.py +7 -0
- prompture/drivers/async_openai_driver.py +7 -1
- prompture/drivers/async_openrouter_driver.py +7 -1
- prompture/drivers/async_registry.py +5 -1
- prompture/drivers/azure_driver.py +7 -1
- prompture/drivers/claude_driver.py +7 -1
- prompture/drivers/google_driver.py +24 -4
- prompture/drivers/grok_driver.py +7 -1
- prompture/drivers/groq_driver.py +7 -1
- prompture/drivers/lmstudio_driver.py +58 -6
- prompture/drivers/ollama_driver.py +7 -0
- prompture/drivers/openai_driver.py +7 -1
- prompture/drivers/openrouter_driver.py +7 -1
- prompture/drivers/vision_helpers.py +153 -0
- prompture/group_types.py +147 -0
- prompture/groups.py +530 -0
- prompture/image.py +180 -0
- prompture/persistence.py +254 -0
- prompture/persona.py +482 -0
- prompture/serialization.py +218 -0
- prompture/settings.py +1 -0
- {prompture-0.0.36.dev1.dist-info → prompture-0.0.37.dev1.dist-info}/METADATA +1 -1
- prompture-0.0.37.dev1.dist-info/RECORD +77 -0
- prompture-0.0.36.dev1.dist-info/RECORD +0 -66
- {prompture-0.0.36.dev1.dist-info → prompture-0.0.37.dev1.dist-info}/WHEEL +0 -0
- {prompture-0.0.36.dev1.dist-info → prompture-0.0.37.dev1.dist-info}/entry_points.txt +0 -0
- {prompture-0.0.36.dev1.dist-info → prompture-0.0.37.dev1.dist-info}/licenses/LICENSE +0 -0
- {prompture-0.0.36.dev1.dist-info → prompture-0.0.37.dev1.dist-info}/top_level.txt +0 -0
prompture/drivers/async_google_driver.py
CHANGED

```diff
@@ -20,6 +20,7 @@ class AsyncGoogleDriver(CostMixin, AsyncDriver):
 
     supports_json_mode = True
     supports_json_schema = True
+    supports_vision = True
 
     MODEL_PRICING = GoogleDriver.MODEL_PRICING
     _PRICING_UNIT = 1_000_000
@@ -50,12 +51,17 @@ class AsyncGoogleDriver(CostMixin, AsyncDriver):
 
     supports_messages = True
 
+    def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+        from .vision_helpers import _prepare_google_vision_messages
+
+        return _prepare_google_vision_messages(messages)
+
     async def generate(self, prompt: str, options: dict[str, Any] | None = None) -> dict[str, Any]:
         messages = [{"role": "user", "content": prompt}]
         return await self._do_generate(messages, options)
 
     async def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
-        return await self._do_generate(messages, options)
+        return await self._do_generate(self._prepare_messages(messages), options)
 
     async def _do_generate(
         self, messages: list[dict[str, str]], options: dict[str, Any] | None = None
@@ -90,10 +96,14 @@ class AsyncGoogleDriver(CostMixin, AsyncDriver):
             role = msg.get("role", "user")
             content = msg.get("content", "")
             if role == "system":
-                system_instruction = content
+                system_instruction = content if isinstance(content, str) else str(content)
             else:
                 gemini_role = "model" if role == "assistant" else "user"
-                contents.append({"role": gemini_role, "parts": [content]})
+                if msg.get("_vision_parts"):
+                    # Already converted to Gemini parts by _prepare_messages
+                    contents.append({"role": gemini_role, "parts": content})
+                else:
+                    contents.append({"role": gemini_role, "parts": [content]})
 
         try:
             model_kwargs: dict[str, Any] = {}
@@ -111,7 +121,17 @@ class AsyncGoogleDriver(CostMixin, AsyncDriver):
             if not response.text:
                 raise ValueError("Empty response from model")
 
-            total_prompt_chars =
+            total_prompt_chars = 0
+            for msg in messages:
+                c = msg.get("content", "")
+                if isinstance(c, str):
+                    total_prompt_chars += len(c)
+                elif isinstance(c, list):
+                    for part in c:
+                        if isinstance(part, str):
+                            total_prompt_chars += len(part)
+                        elif isinstance(part, dict) and "text" in part:
+                            total_prompt_chars += len(part["text"])
             completion_chars = len(response.text)
 
             total_cost = self._calculate_cost_chars(total_prompt_chars, completion_chars)
```
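The conversion above now distinguishes pre-converted multimodal turns, flagged with `_vision_parts` by `_prepare_messages`, from plain string turns. Below is a minimal standalone replica of that mapping; the function name and message shapes are illustrative, not part of the package's public API:

```python
from typing import Any


def to_gemini_contents(messages: list[dict[str, Any]]) -> tuple[str | None, list[dict[str, Any]]]:
    """Replicates the system/role/parts mapping from _do_generate above."""
    system_instruction: str | None = None
    contents: list[dict[str, Any]] = []
    for msg in messages:
        role = msg.get("role", "user")
        content = msg.get("content", "")
        if role == "system":
            system_instruction = content if isinstance(content, str) else str(content)
        else:
            gemini_role = "model" if role == "assistant" else "user"
            if msg.get("_vision_parts"):
                # _prepare_messages already produced a list of Gemini parts
                contents.append({"role": gemini_role, "parts": content})
            else:
                contents.append({"role": gemini_role, "parts": [content]})
    return system_instruction, contents


system, contents = to_gemini_contents([
    {"role": "system", "content": "Be terse."},
    {"role": "user", "content": "Hello"},
    {"role": "user", "content": [{"text": "Describe the image."}], "_vision_parts": True},
])
assert system == "Be terse."
assert contents[0] == {"role": "user", "parts": ["Hello"]}
assert contents[1] == {"role": "user", "parts": [{"text": "Describe the image."}]}
```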
prompture/drivers/async_grok_driver.py
CHANGED

```diff
@@ -14,6 +14,7 @@ from .grok_driver import GrokDriver
 
 class AsyncGrokDriver(CostMixin, AsyncDriver):
     supports_json_mode = True
+    supports_vision = True
 
     MODEL_PRICING = GrokDriver.MODEL_PRICING
     _PRICING_UNIT = 1_000_000
@@ -25,12 +26,17 @@ class AsyncGrokDriver(CostMixin, AsyncDriver):
 
     supports_messages = True
 
+    def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+        from .vision_helpers import _prepare_openai_vision_messages
+
+        return _prepare_openai_vision_messages(messages)
+
     async def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
         messages = [{"role": "user", "content": prompt}]
         return await self._do_generate(messages, options)
 
     async def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
-        return await self._do_generate(messages, options)
+        return await self._do_generate(self._prepare_messages(messages), options)
 
     async def _do_generate(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
         if not self.api_key:
```
prompture/drivers/async_groq_driver.py
CHANGED

```diff
@@ -17,6 +17,7 @@ from .groq_driver import GroqDriver
 
 class AsyncGroqDriver(CostMixin, AsyncDriver):
     supports_json_mode = True
+    supports_vision = True
 
     MODEL_PRICING = GroqDriver.MODEL_PRICING
 
@@ -30,12 +31,17 @@ class AsyncGroqDriver(CostMixin, AsyncDriver):
 
     supports_messages = True
 
+    def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+        from .vision_helpers import _prepare_openai_vision_messages
+
+        return _prepare_openai_vision_messages(messages)
+
     async def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
         messages = [{"role": "user", "content": prompt}]
         return await self._do_generate(messages, options)
 
     async def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
-        return await self._do_generate(messages, options)
+        return await self._do_generate(self._prepare_messages(messages), options)
 
     async def _do_generate(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
         if self.client is None:
```
prompture/drivers/async_lmstudio_driver.py
CHANGED

```diff
@@ -15,22 +15,47 @@ logger = logging.getLogger(__name__)
 
 class AsyncLMStudioDriver(AsyncDriver):
     supports_json_mode = True
+    supports_vision = True
 
     MODEL_PRICING = {"default": {"prompt": 0.0, "completion": 0.0}}
 
-    def __init__(
+    def __init__(
+        self,
+        endpoint: str | None = None,
+        model: str = "deepseek/deepseek-r1-0528-qwen3-8b",
+        api_key: str | None = None,
+    ):
         self.endpoint = endpoint or os.getenv("LMSTUDIO_ENDPOINT", "http://127.0.0.1:1234/v1/chat/completions")
         self.model = model
         self.options: dict[str, Any] = {}
 
+        # Derive base_url once for reuse across management endpoints
+        self.base_url = self.endpoint.split("/v1/")[0]
+
+        # API key for LM Studio 0.4.0+ authentication
+        self.api_key = api_key or os.getenv("LMSTUDIO_API_KEY")
+        self._headers = self._build_headers()
+
     supports_messages = True
 
+    def _build_headers(self) -> dict[str, str]:
+        """Build request headers, including auth if an API key is configured."""
+        headers: dict[str, str] = {"Content-Type": "application/json"}
+        if self.api_key:
+            headers["Authorization"] = f"Bearer {self.api_key}"
+        return headers
+
+    def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+        from .vision_helpers import _prepare_openai_vision_messages
+
+        return _prepare_openai_vision_messages(messages)
+
     async def generate(self, prompt: str, options: dict[str, Any] | None = None) -> dict[str, Any]:
         messages = [{"role": "user", "content": prompt}]
         return await self._do_generate(messages, options)
 
     async def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
-        return await self._do_generate(messages, options)
+        return await self._do_generate(self._prepare_messages(messages), options)
 
     async def _do_generate(
         self, messages: list[dict[str, str]], options: dict[str, Any] | None = None
@@ -51,7 +76,7 @@ class AsyncLMStudioDriver(AsyncDriver):
 
         async with httpx.AsyncClient() as client:
             try:
-                r = await client.post(self.endpoint, json=payload, timeout=120)
+                r = await client.post(self.endpoint, json=payload, headers=self._headers, timeout=120)
                 r.raise_for_status()
                 response_data = r.json()
             except Exception as e:
@@ -77,3 +102,34 @@ class AsyncLMStudioDriver(AsyncDriver):
         }
 
         return {"text": text, "meta": meta}
+
+    # -- Model management (LM Studio 0.4.0+) ----------------------------------
+
+    async def list_models(self) -> list[dict[str, Any]]:
+        """List currently loaded models via GET /v1/models (OpenAI-compatible)."""
+        url = f"{self.base_url}/v1/models"
+        async with httpx.AsyncClient() as client:
+            r = await client.get(url, headers=self._headers, timeout=10)
+            r.raise_for_status()
+            data = r.json()
+            return data.get("data", [])
+
+    async def load_model(self, model: str, context_length: int | None = None) -> dict[str, Any]:
+        """Load a model into LM Studio via POST /api/v1/models/load."""
+        url = f"{self.base_url}/api/v1/models/load"
+        payload: dict[str, Any] = {"model": model}
+        if context_length is not None:
+            payload["context_length"] = context_length
+        async with httpx.AsyncClient() as client:
+            r = await client.post(url, json=payload, headers=self._headers, timeout=120)
+            r.raise_for_status()
+            return r.json()
+
+    async def unload_model(self, model: str) -> dict[str, Any]:
+        """Unload a model from LM Studio via POST /api/v1/models/unload."""
+        url = f"{self.base_url}/api/v1/models/unload"
+        payload = {"model": model}
+        async with httpx.AsyncClient() as client:
+            r = await client.post(url, json=payload, headers=self._headers, timeout=30)
+            r.raise_for_status()
+            return r.json()
```
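A usage sketch for the new management surface, assuming a running LM Studio 0.4.0+ server on the default endpoint; the model id and API key are placeholders:

```python
import asyncio

from prompture.drivers.async_lmstudio_driver import AsyncLMStudioDriver


async def main() -> None:
    # api_key is only needed when the server enforces authentication
    driver = AsyncLMStudioDriver(api_key="sk-local-example")
    await driver.load_model("qwen/qwen3-8b", context_length=8192)  # placeholder model id
    for m in await driver.list_models():
        print(m.get("id"))
    await driver.unload_model("qwen/qwen3-8b")


asyncio.run(main())
```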
prompture/drivers/async_ollama_driver.py
CHANGED

```diff
@@ -15,6 +15,7 @@ logger = logging.getLogger(__name__)
 
 class AsyncOllamaDriver(AsyncDriver):
     supports_json_mode = True
+    supports_vision = True
 
     MODEL_PRICING = {"default": {"prompt": 0.0, "completion": 0.0}}
 
@@ -25,6 +26,11 @@ class AsyncOllamaDriver(AsyncDriver):
 
     supports_messages = True
 
+    def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+        from .vision_helpers import _prepare_ollama_vision_messages
+
+        return _prepare_ollama_vision_messages(messages)
+
     async def generate(self, prompt: str, options: dict[str, Any] | None = None) -> dict[str, Any]:
         merged_options = self.options.copy()
         if options:
@@ -74,6 +80,7 @@ class AsyncOllamaDriver(AsyncDriver):
 
     async def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
         """Use Ollama's /api/chat endpoint for multi-turn conversations."""
+        messages = self._prepare_messages(messages)
         merged_options = self.options.copy()
         if options:
             merged_options.update(options)
```
prompture/drivers/async_openai_driver.py
CHANGED

```diff
@@ -18,6 +18,7 @@ from .openai_driver import OpenAIDriver
 class AsyncOpenAIDriver(CostMixin, AsyncDriver):
     supports_json_mode = True
     supports_json_schema = True
+    supports_vision = True
 
     MODEL_PRICING = OpenAIDriver.MODEL_PRICING
 
@@ -31,12 +32,17 @@ class AsyncOpenAIDriver(CostMixin, AsyncDriver):
 
     supports_messages = True
 
+    def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+        from .vision_helpers import _prepare_openai_vision_messages
+
+        return _prepare_openai_vision_messages(messages)
+
     async def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
         messages = [{"role": "user", "content": prompt}]
         return await self._do_generate(messages, options)
 
     async def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
-        return await self._do_generate(messages, options)
+        return await self._do_generate(self._prepare_messages(messages), options)
 
     async def _do_generate(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
         if self.client is None:
```
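`_prepare_openai_vision_messages` is shared by the OpenAI, Azure, Grok, Groq, OpenRouter, and LM Studio drivers. Its body is not part of this diff; the sketch below only shows the public OpenAI chat-completions multimodal shape such messages presumably end up in:

```python
# Assumed target shape (standard OpenAI content parts); the helper's exact
# contract lives in prompture/drivers/vision_helpers.py, not shown in this diff.
message = {
    "role": "user",
    "content": [
        {"type": "text", "text": "Describe this chart."},
        {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}},  # placeholder data URL
    ],
}
```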
prompture/drivers/async_openrouter_driver.py
CHANGED

```diff
@@ -14,6 +14,7 @@ from .openrouter_driver import OpenRouterDriver
 
 class AsyncOpenRouterDriver(CostMixin, AsyncDriver):
     supports_json_mode = True
+    supports_vision = True
 
     MODEL_PRICING = OpenRouterDriver.MODEL_PRICING
 
@@ -31,12 +32,17 @@ class AsyncOpenRouterDriver(CostMixin, AsyncDriver):
 
     supports_messages = True
 
+    def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+        from .vision_helpers import _prepare_openai_vision_messages
+
+        return _prepare_openai_vision_messages(messages)
+
     async def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
         messages = [{"role": "user", "content": prompt}]
         return await self._do_generate(messages, options)
 
     async def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
-        return await self._do_generate(messages, options)
+        return await self._do_generate(self._prepare_messages(messages), options)
 
     async def _do_generate(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
         model = options.get("model", self.model)
```
prompture/drivers/async_registry.py
CHANGED

```diff
@@ -49,7 +49,11 @@ register_async_driver(
 )
 register_async_driver(
     "lmstudio",
-    lambda model=None: AsyncLMStudioDriver(
+    lambda model=None: AsyncLMStudioDriver(
+        endpoint=settings.lmstudio_endpoint,
+        model=model or settings.lmstudio_model,
+        api_key=settings.lmstudio_api_key,
+    ),
     overwrite=True,
 )
 register_async_driver(
```
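The registry now builds the async LM Studio driver from `settings`, including the new `api_key`. A hedged sketch of overriding that registration with explicit values, assuming `register_async_driver` is importable from this module as the diff's usage suggests:

```python
from prompture.drivers.async_lmstudio_driver import AsyncLMStudioDriver
from prompture.drivers.async_registry import register_async_driver

register_async_driver(
    "lmstudio",
    lambda model=None: AsyncLMStudioDriver(
        endpoint="http://127.0.0.1:1234/v1/chat/completions",
        model=model or "deepseek/deepseek-r1-0528-qwen3-8b",
        api_key=None,  # or an LM Studio API key string
    ),
    overwrite=True,
)
```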
prompture/drivers/azure_driver.py
CHANGED

```diff
@@ -17,6 +17,7 @@ from ..driver import Driver
 class AzureDriver(CostMixin, Driver):
     supports_json_mode = True
     supports_json_schema = True
+    supports_vision = True
 
     # Pricing per 1K tokens (adjust if your Azure pricing differs from OpenAI defaults)
     MODEL_PRICING = {
@@ -90,12 +91,17 @@ class AzureDriver(CostMixin, Driver):
 
     supports_messages = True
 
+    def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+        from .vision_helpers import _prepare_openai_vision_messages
+
+        return _prepare_openai_vision_messages(messages)
+
     def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
         messages = [{"role": "user", "content": prompt}]
         return self._do_generate(messages, options)
 
     def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
-        return self._do_generate(messages, options)
+        return self._do_generate(self._prepare_messages(messages), options)
 
     def _do_generate(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
         if self.client is None:
```
prompture/drivers/claude_driver.py
CHANGED

```diff
@@ -21,6 +21,7 @@ class ClaudeDriver(CostMixin, Driver):
     supports_json_schema = True
     supports_tool_use = True
     supports_streaming = True
+    supports_vision = True
 
     # Claude pricing per 1000 tokens (prices should be kept current with Anthropic's pricing)
     MODEL_PRICING = {
@@ -57,12 +58,17 @@ class ClaudeDriver(CostMixin, Driver):
 
     supports_messages = True
 
+    def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+        from .vision_helpers import _prepare_claude_vision_messages
+
+        return _prepare_claude_vision_messages(messages)
+
     def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
         messages = [{"role": "user", "content": prompt}]
         return self._do_generate(messages, options)
 
     def generate_messages(self, messages: list[dict[str, Any]], options: dict[str, Any]) -> dict[str, Any]:
-        return self._do_generate(messages, options)
+        return self._do_generate(self._prepare_messages(messages), options)
 
     def _do_generate(self, messages: list[dict[str, Any]], options: dict[str, Any]) -> dict[str, Any]:
         if anthropic is None:
```
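`_prepare_claude_vision_messages` presumably targets Anthropic's documented content-block format; its body is not in this diff, so the shape below is an assumption based on the public Messages API:

```python
# Assumed target shape: Anthropic multimodal content blocks.
message = {
    "role": "user",
    "content": [
        {
            "type": "image",
            "source": {
                "type": "base64",
                "media_type": "image/jpeg",
                "data": "...",  # placeholder base64 payload
            },
        },
        {"type": "text", "text": "What is in this photo?"},
    ],
}
```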
prompture/drivers/google_driver.py
CHANGED

```diff
@@ -15,6 +15,7 @@ class GoogleDriver(CostMixin, Driver):
 
     supports_json_mode = True
     supports_json_schema = True
+    supports_vision = True
 
     # Based on current Gemini pricing (as of 2025)
     # Source: https://cloud.google.com/vertex-ai/pricing#gemini_models
@@ -107,12 +108,17 @@ class GoogleDriver(CostMixin, Driver):
 
     supports_messages = True
 
+    def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+        from .vision_helpers import _prepare_google_vision_messages
+
+        return _prepare_google_vision_messages(messages)
+
     def generate(self, prompt: str, options: Optional[dict[str, Any]] = None) -> dict[str, Any]:
         messages = [{"role": "user", "content": prompt}]
         return self._do_generate(messages, options)
 
     def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
-        return self._do_generate(messages, options)
+        return self._do_generate(self._prepare_messages(messages), options)
 
     def _do_generate(self, messages: list[dict[str, str]], options: Optional[dict[str, Any]] = None) -> dict[str, Any]:
         merged_options = self.options.copy()
@@ -147,11 +153,15 @@ class GoogleDriver(CostMixin, Driver):
             role = msg.get("role", "user")
             content = msg.get("content", "")
             if role == "system":
-                system_instruction = content
+                system_instruction = content if isinstance(content, str) else str(content)
             else:
                 # Gemini uses "model" for assistant role
                 gemini_role = "model" if role == "assistant" else "user"
-                contents.append({"role": gemini_role, "parts": [content]})
+                if msg.get("_vision_parts"):
+                    # Already converted to Gemini parts by _prepare_messages
+                    contents.append({"role": gemini_role, "parts": content})
+                else:
+                    contents.append({"role": gemini_role, "parts": [content]})
 
         try:
             logger.debug(f"Initializing {self.model} for generation")
@@ -174,7 +184,17 @@ class GoogleDriver(CostMixin, Driver):
                 raise ValueError("Empty response from model")
 
             # Calculate token usage and cost
-            total_prompt_chars =
+            total_prompt_chars = 0
+            for msg in messages:
+                c = msg.get("content", "")
+                if isinstance(c, str):
+                    total_prompt_chars += len(c)
+                elif isinstance(c, list):
+                    for part in c:
+                        if isinstance(part, str):
+                            total_prompt_chars += len(part)
+                        elif isinstance(part, dict) and "text" in part:
+                            total_prompt_chars += len(part["text"])
             completion_chars = len(response.text)
 
             # Google uses character-based cost estimation
```
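The rewritten character count walks both plain string content and part lists, so image parts add nothing to the prompt-character total that feeds the cost estimate. A standalone replica of the loop with a worked check:

```python
from typing import Any


def count_prompt_chars(messages: list[dict[str, Any]]) -> int:
    # Mirrors the counting introduced in _do_generate above.
    total = 0
    for msg in messages:
        c = msg.get("content", "")
        if isinstance(c, str):
            total += len(c)
        elif isinstance(c, list):
            for part in c:
                if isinstance(part, str):
                    total += len(part)
                elif isinstance(part, dict) and "text" in part:
                    total += len(part["text"])
    return total


# "Hello" (5) + "Hi" (2) = 7; the image part contributes no characters.
assert count_prompt_chars([
    {"role": "user", "content": "Hello"},
    {"role": "user", "content": [{"text": "Hi"}, {"inline_data": {"mime_type": "image/png", "data": "..."}}]},
]) == 7
```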
prompture/drivers/grok_driver.py
CHANGED

```diff
@@ -13,6 +13,7 @@ from ..driver import Driver
 
 class GrokDriver(CostMixin, Driver):
     supports_json_mode = True
+    supports_vision = True
 
     # Pricing per 1M tokens based on xAI's documentation
     _PRICING_UNIT = 1_000_000
@@ -80,12 +81,17 @@ class GrokDriver(CostMixin, Driver):
 
     supports_messages = True
 
+    def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+        from .vision_helpers import _prepare_openai_vision_messages
+
+        return _prepare_openai_vision_messages(messages)
+
     def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
         messages = [{"role": "user", "content": prompt}]
         return self._do_generate(messages, options)
 
     def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
-        return self._do_generate(messages, options)
+        return self._do_generate(self._prepare_messages(messages), options)
 
     def _do_generate(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
         if not self.api_key:
```
prompture/drivers/groq_driver.py
CHANGED

```diff
@@ -16,6 +16,7 @@ from ..driver import Driver
 
 class GroqDriver(CostMixin, Driver):
     supports_json_mode = True
+    supports_vision = True
 
     # Approximate pricing per 1K tokens (to be updated with official pricing)
     # Each model entry defines token parameters and temperature support
@@ -50,12 +51,17 @@ class GroqDriver(CostMixin, Driver):
 
     supports_messages = True
 
+    def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+        from .vision_helpers import _prepare_openai_vision_messages
+
+        return _prepare_openai_vision_messages(messages)
+
     def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
         messages = [{"role": "user", "content": prompt}]
         return self._do_generate(messages, options)
 
     def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
-        return self._do_generate(messages, options)
+        return self._do_generate(self._prepare_messages(messages), options)
 
     def _do_generate(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
         if self.client is None:
```
prompture/drivers/lmstudio_driver.py
CHANGED

```diff
@@ -12,27 +12,46 @@ logger = logging.getLogger(__name__)
 
 class LMStudioDriver(Driver):
     supports_json_mode = True
+    supports_vision = True
 
     # LM Studio is local – costs are always zero.
     MODEL_PRICING = {"default": {"prompt": 0.0, "completion": 0.0}}
 
-    def __init__(
+    def __init__(
+        self,
+        endpoint: str | None = None,
+        model: str = "deepseek/deepseek-r1-0528-qwen3-8b",
+        api_key: str | None = None,
+    ):
         # Allow override via env var
         self.endpoint = endpoint or os.getenv("LMSTUDIO_ENDPOINT", "http://127.0.0.1:1234/v1/chat/completions")
         self.model = model
         self.options: dict[str, Any] = {}
 
+        # Derive base_url once for reuse across management endpoints
+        self.base_url = self.endpoint.split("/v1/")[0]
+
+        # API key for LM Studio 0.4.0+ authentication
+        self.api_key = api_key or os.getenv("LMSTUDIO_API_KEY")
+        self._headers = self._build_headers()
+
         # Validate connection to LM Studio server
         self._validate_connection()
 
+    def _build_headers(self) -> dict[str, str]:
+        """Build request headers, including auth if an API key is configured."""
+        headers: dict[str, str] = {"Content-Type": "application/json"}
+        if self.api_key:
+            headers["Authorization"] = f"Bearer {self.api_key}"
+        return headers
+
     def _validate_connection(self):
         """Validate connection to the LM Studio server."""
         try:
-            base_url = self.endpoint.split("/v1/")[0]
-            health_url = f"{base_url}/v1/models"
+            health_url = f"{self.base_url}/v1/models"
 
             logger.debug(f"Validating connection to LM Studio server at: {health_url}")
-            response = requests.get(health_url, timeout=5)
+            response = requests.get(health_url, headers=self._headers, timeout=5)
             response.raise_for_status()
             logger.debug("Connection to LM Studio server validated successfully")
         except requests.exceptions.RequestException as e:
@@ -40,12 +59,17 @@ class LMStudioDriver(Driver):
 
     supports_messages = True
 
+    def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+        from .vision_helpers import _prepare_openai_vision_messages
+
+        return _prepare_openai_vision_messages(messages)
+
     def generate(self, prompt: str, options: Optional[dict[str, Any]] = None) -> dict[str, Any]:
         messages = [{"role": "user", "content": prompt}]
         return self._do_generate(messages, options)
 
     def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
-        return self._do_generate(messages, options)
+        return self._do_generate(self._prepare_messages(messages), options)
 
     def _do_generate(self, messages: list[dict[str, str]], options: Optional[dict[str, Any]] = None) -> dict[str, Any]:
         merged_options = self.options.copy()
@@ -66,7 +90,7 @@ class LMStudioDriver(Driver):
         logger.debug(f"Sending request to LM Studio endpoint: {self.endpoint}")
         logger.debug(f"Request payload: {payload}")
 
-        r = requests.post(self.endpoint, json=payload, timeout=120)
+        r = requests.post(self.endpoint, json=payload, headers=self._headers, timeout=120)
         r.raise_for_status()
 
         response_data = r.json()
@@ -104,3 +128,31 @@ class LMStudioDriver(Driver):
         }
 
         return {"text": text, "meta": meta}
+
+    # -- Model management (LM Studio 0.4.0+) ----------------------------------
+
+    def list_models(self) -> list[dict[str, Any]]:
+        """List currently loaded models via GET /v1/models (OpenAI-compatible)."""
+        url = f"{self.base_url}/v1/models"
+        r = requests.get(url, headers=self._headers, timeout=10)
+        r.raise_for_status()
+        data = r.json()
+        return data.get("data", [])
+
+    def load_model(self, model: str, context_length: int | None = None) -> dict[str, Any]:
+        """Load a model into LM Studio via POST /api/v1/models/load."""
+        url = f"{self.base_url}/api/v1/models/load"
+        payload: dict[str, Any] = {"model": model}
+        if context_length is not None:
+            payload["context_length"] = context_length
+        r = requests.post(url, json=payload, headers=self._headers, timeout=120)
+        r.raise_for_status()
+        return r.json()
+
+    def unload_model(self, model: str) -> dict[str, Any]:
+        """Unload a model from LM Studio via POST /api/v1/models/unload."""
+        url = f"{self.base_url}/api/v1/models/unload"
+        payload = {"model": model}
+        r = requests.post(url, json=payload, headers=self._headers, timeout=30)
+        r.raise_for_status()
+        return r.json()
```
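The synchronous driver gains the same auth and management surface. A usage sketch, assuming a running LM Studio server (the constructor validates the connection on instantiation); the model id and key are placeholders:

```python
from prompture.drivers.lmstudio_driver import LMStudioDriver

driver = LMStudioDriver(api_key="sk-local-example")  # sends Authorization: Bearer ... when set
driver.load_model("qwen/qwen3-8b", context_length=4096)
print([m.get("id") for m in driver.list_models()])
driver.unload_model("qwen/qwen3-8b")
```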
prompture/drivers/ollama_driver.py
CHANGED

```diff
@@ -14,6 +14,7 @@ logger = logging.getLogger(__name__)
 class OllamaDriver(Driver):
     supports_json_mode = True
     supports_streaming = True
+    supports_vision = True
 
     # Ollama is free – costs are always zero.
     MODEL_PRICING = {"default": {"prompt": 0.0, "completion": 0.0}}
@@ -46,6 +47,11 @@ class OllamaDriver(Driver):
 
     supports_messages = True
 
+    def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+        from .vision_helpers import _prepare_ollama_vision_messages
+
+        return _prepare_ollama_vision_messages(messages)
+
     def generate(self, prompt: str, options: Optional[dict[str, Any]] = None) -> dict[str, Any]:
         # Merge instance options with call-specific options
         merged_options = self.options.copy()
@@ -190,6 +196,7 @@ class OllamaDriver(Driver):
 
     def generate_messages(self, messages: list[dict[str, Any]], options: dict[str, Any]) -> dict[str, Any]:
         """Use Ollama's /api/chat endpoint for multi-turn conversations."""
+        messages = self._prepare_messages(messages)
         merged_options = self.options.copy()
         if options:
             merged_options.update(options)
```
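`_prepare_ollama_vision_messages` presumably maps image inputs onto Ollama's documented /api/chat shape, where base64 images ride in an `images` list beside the text; the helper's body is not shown in this diff, so this is an assumed target shape:

```python
message = {
    "role": "user",
    "content": "What is in this picture?",
    "images": ["..."],  # placeholder base64-encoded image bytes
}
```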
prompture/drivers/openai_driver.py
CHANGED

```diff
@@ -21,6 +21,7 @@ class OpenAIDriver(CostMixin, Driver):
     supports_json_schema = True
     supports_tool_use = True
     supports_streaming = True
+    supports_vision = True
 
     # Approximate pricing per 1K tokens (keep updated with OpenAI's official pricing)
     # Each model entry also defines which token parameter it supports and
@@ -74,12 +75,17 @@ class OpenAIDriver(CostMixin, Driver):
 
     supports_messages = True
 
+    def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+        from .vision_helpers import _prepare_openai_vision_messages
+
+        return _prepare_openai_vision_messages(messages)
+
     def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
         messages = [{"role": "user", "content": prompt}]
         return self._do_generate(messages, options)
 
     def generate_messages(self, messages: list[dict[str, Any]], options: dict[str, Any]) -> dict[str, Any]:
-        return self._do_generate(messages, options)
+        return self._do_generate(self._prepare_messages(messages), options)
 
     def _do_generate(self, messages: list[dict[str, Any]], options: dict[str, Any]) -> dict[str, Any]:
         if self.client is None:
```