prompture 0.0.35__py3-none-any.whl → 0.0.38.dev2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- prompture/__init__.py +120 -2
- prompture/_version.py +2 -2
- prompture/agent.py +924 -0
- prompture/agent_types.py +156 -0
- prompture/async_agent.py +880 -0
- prompture/async_conversation.py +199 -17
- prompture/async_driver.py +24 -0
- prompture/async_groups.py +551 -0
- prompture/conversation.py +213 -18
- prompture/core.py +30 -12
- prompture/discovery.py +24 -1
- prompture/driver.py +38 -0
- prompture/drivers/__init__.py +5 -1
- prompture/drivers/async_azure_driver.py +7 -1
- prompture/drivers/async_claude_driver.py +7 -1
- prompture/drivers/async_google_driver.py +212 -28
- prompture/drivers/async_grok_driver.py +7 -1
- prompture/drivers/async_groq_driver.py +7 -1
- prompture/drivers/async_lmstudio_driver.py +74 -5
- prompture/drivers/async_ollama_driver.py +13 -3
- prompture/drivers/async_openai_driver.py +7 -1
- prompture/drivers/async_openrouter_driver.py +7 -1
- prompture/drivers/async_registry.py +5 -1
- prompture/drivers/azure_driver.py +7 -1
- prompture/drivers/claude_driver.py +7 -1
- prompture/drivers/google_driver.py +217 -33
- prompture/drivers/grok_driver.py +7 -1
- prompture/drivers/groq_driver.py +7 -1
- prompture/drivers/lmstudio_driver.py +73 -8
- prompture/drivers/ollama_driver.py +16 -5
- prompture/drivers/openai_driver.py +7 -1
- prompture/drivers/openrouter_driver.py +7 -1
- prompture/drivers/vision_helpers.py +153 -0
- prompture/group_types.py +147 -0
- prompture/groups.py +530 -0
- prompture/image.py +180 -0
- prompture/persistence.py +254 -0
- prompture/persona.py +482 -0
- prompture/serialization.py +218 -0
- prompture/settings.py +1 -0
- prompture-0.0.38.dev2.dist-info/METADATA +369 -0
- prompture-0.0.38.dev2.dist-info/RECORD +77 -0
- prompture-0.0.35.dist-info/METADATA +0 -464
- prompture-0.0.35.dist-info/RECORD +0 -66
- {prompture-0.0.35.dist-info → prompture-0.0.38.dev2.dist-info}/WHEEL +0 -0
- {prompture-0.0.35.dist-info → prompture-0.0.38.dev2.dist-info}/entry_points.txt +0 -0
- {prompture-0.0.35.dist-info → prompture-0.0.38.dev2.dist-info}/licenses/LICENSE +0 -0
- {prompture-0.0.35.dist-info → prompture-0.0.38.dev2.dist-info}/top_level.txt +0 -0
prompture/conversation.py
CHANGED

```diff
@@ -4,9 +4,11 @@ from __future__ import annotations
 
 import json
 import logging
+import uuid
 from collections.abc import Iterator
-from datetime import date, datetime
+from datetime import date, datetime, timezone
 from decimal import Decimal
+from pathlib import Path
 from typing import Any, Callable, Literal, Union
 
 from pydantic import BaseModel
@@ -15,6 +17,11 @@ from .callbacks import DriverCallbacks
 from .driver import Driver
 from .drivers import get_driver_for_model
 from .field_definitions import get_registry_snapshot
+from .image import ImageInput, make_image
+from .persistence import load_from_file, save_to_file
+from .persona import Persona, get_persona
+from .serialization import export_conversation, import_conversation
+from .session import UsageSession
 from .tools import (
     clean_json_text,
     convert_value,
@@ -44,13 +51,34 @@ class Conversation:
         *,
         driver: Driver | None = None,
         system_prompt: str | None = None,
+        persona: str | Persona | None = None,
         options: dict[str, Any] | None = None,
         callbacks: DriverCallbacks | None = None,
         tools: ToolRegistry | None = None,
         max_tool_rounds: int = 10,
+        conversation_id: str | None = None,
+        auto_save: str | Path | None = None,
+        tags: list[str] | None = None,
     ) -> None:
+        if system_prompt is not None and persona is not None:
+            raise ValueError("Cannot provide both 'system_prompt' and 'persona'. Use one or the other.")
+
+        # Resolve persona
+        resolved_persona: Persona | None = None
+        if persona is not None:
+            if isinstance(persona, str):
+                resolved_persona = get_persona(persona)
+                if resolved_persona is None:
+                    raise ValueError(f"Persona '{persona}' not found in registry.")
+            else:
+                resolved_persona = persona
+
         if model_name is None and driver is None:
-            raise ValueError("Either model_name or driver must be provided")
+            # Check persona for model_hint
+            if resolved_persona is not None and resolved_persona.model_hint:
+                model_name = resolved_persona.model_hint
+            else:
+                raise ValueError("Either model_name or driver must be provided")
 
         if driver is not None:
             self._driver = driver
@@ -61,8 +89,16 @@ class Conversation:
             self._driver.callbacks = callbacks
 
         self._model_name = model_name or ""
-        self._system_prompt = system_prompt
-        self._options = dict(options) if options else {}
+
+        # Apply persona: render system_prompt and merge settings
+        if resolved_persona is not None:
+            self._system_prompt = resolved_persona.render()
+            # Persona settings as defaults, explicit options override
+            self._options = {**resolved_persona.settings, **(dict(options) if options else {})}
+        else:
+            self._system_prompt = system_prompt
+            self._options = dict(options) if options else {}
+
         self._messages: list[dict[str, Any]] = []
         self._usage = {
             "prompt_tokens": 0,
@@ -74,6 +110,14 @@ class Conversation:
         self._tools = tools or ToolRegistry()
         self._max_tool_rounds = max_tool_rounds
 
+        # Persistence
+        self._conversation_id = conversation_id or str(uuid.uuid4())
+        self._auto_save = Path(auto_save) if auto_save else None
+        self._metadata: dict[str, Any] = {
+            "created_at": datetime.now(timezone.utc).isoformat(),
+            "tags": list(tags) if tags else [],
+        }
+
     # ------------------------------------------------------------------
     # Public helpers
     # ------------------------------------------------------------------
@@ -92,11 +136,12 @@ class Conversation:
         """Reset message history (keeps system_prompt and driver)."""
         self._messages.clear()
 
-    def add_context(self, role: str, content: str) -> None:
+    def add_context(self, role: str, content: str, images: list[ImageInput] | None = None) -> None:
         """Seed the history with a user or assistant message."""
         if role not in ("user", "assistant"):
             raise ValueError("role must be 'user' or 'assistant'")
-        self._messages.append({"role": role, "content": content})
+        msg_content = self._build_content_with_images(content, images)
+        self._messages.append({"role": role, "content": msg_content})
 
     def register_tool(
         self,
@@ -113,17 +158,149 @@ class Conversation:
         u = self._usage
         return f"Conversation: {u['total_tokens']:,} tokens across {u['turns']} turn(s) costing ${u['cost']:.4f}"
 
+    # ------------------------------------------------------------------
+    # Persistence properties
+    # ------------------------------------------------------------------
+
+    @property
+    def conversation_id(self) -> str:
+        """Unique identifier for this conversation."""
+        return self._conversation_id
+
+    @property
+    def tags(self) -> list[str]:
+        """Tags attached to this conversation."""
+        return self._metadata.get("tags", [])
+
+    @tags.setter
+    def tags(self, value: list[str]) -> None:
+        self._metadata["tags"] = list(value)
+
+    # ------------------------------------------------------------------
+    # Export / Import
+    # ------------------------------------------------------------------
+
+    def export(self, *, usage_session: UsageSession | None = None, strip_images: bool = False) -> dict[str, Any]:
+        """Export conversation state to a JSON-serializable dict."""
+        tools_metadata = (
+            [
+                {"name": td.name, "description": td.description, "parameters": td.parameters}
+                for td in self._tools.definitions
+            ]
+            if self._tools and self._tools.definitions
+            else None
+        )
+        return export_conversation(
+            model_name=self._model_name,
+            system_prompt=self._system_prompt,
+            options=self._options,
+            messages=self._messages,
+            usage=self._usage,
+            max_tool_rounds=self._max_tool_rounds,
+            tools_metadata=tools_metadata,
+            usage_session=usage_session,
+            metadata=self._metadata,
+            conversation_id=self._conversation_id,
+            strip_images=strip_images,
+        )
+
+    @classmethod
+    def from_export(
+        cls,
+        data: dict[str, Any],
+        *,
+        callbacks: DriverCallbacks | None = None,
+        tools: ToolRegistry | None = None,
+    ) -> Conversation:
+        """Reconstruct a :class:`Conversation` from an export dict.
+
+        The driver is reconstructed from the stored ``model_name``.
+        Callbacks and tool *functions* must be re-attached by the caller
+        (tool metadata — name/description/parameters — is preserved in
+        the export but executable functions cannot be serialized).
+        """
+        imported = import_conversation(data)
+
+        model_name = imported.get("model_name") or ""
+        if not model_name:
+            raise ValueError("Cannot restore conversation: export has no model_name")
+        conv = cls(
+            model_name=model_name,
+            system_prompt=imported.get("system_prompt"),
+            options=imported.get("options", {}),
+            callbacks=callbacks,
+            tools=tools,
+            max_tool_rounds=imported.get("max_tool_rounds", 10),
+            conversation_id=imported.get("conversation_id"),
+            tags=imported.get("metadata", {}).get("tags", []),
+        )
+        conv._messages = imported.get("messages", [])
+        conv._usage = imported.get(
+            "usage",
+            {
+                "prompt_tokens": 0,
+                "completion_tokens": 0,
+                "total_tokens": 0,
+                "cost": 0.0,
+                "turns": 0,
+            },
+        )
+        meta = imported.get("metadata", {})
+        if "created_at" in meta:
+            conv._metadata["created_at"] = meta["created_at"]
+        return conv
+
+    def save(self, path: str | Path, **kwargs: Any) -> None:
+        """Export and write to a JSON file.
+
+        Keyword arguments are forwarded to :meth:`export`.
+        """
+        save_to_file(self.export(**kwargs), path)
+
+    @classmethod
+    def load(
+        cls,
+        path: str | Path,
+        *,
+        callbacks: DriverCallbacks | None = None,
+        tools: ToolRegistry | None = None,
+    ) -> Conversation:
+        """Load a conversation from a JSON file."""
+        data = load_from_file(path)
+        return cls.from_export(data, callbacks=callbacks, tools=tools)
+
+    def _maybe_auto_save(self) -> None:
+        """Auto-save after each turn if configured. Errors are silently logged."""
+        if self._auto_save is None:
+            return
+        try:
+            self.save(self._auto_save)
+        except Exception:
+            logger.debug("Auto-save failed for conversation %s", self._conversation_id, exc_info=True)
+
     # ------------------------------------------------------------------
     # Core methods
     # ------------------------------------------------------------------
 
-    def _build_messages(self, user_content: str) -> list[dict[str, Any]]:
+    @staticmethod
+    def _build_content_with_images(text: str, images: list[ImageInput] | None = None) -> str | list[dict[str, Any]]:
+        """Return plain string when no images, or a list of content blocks."""
+        if not images:
+            return text
+        blocks: list[dict[str, Any]] = [{"type": "text", "text": text}]
+        for img in images:
+            ic = make_image(img)
+            blocks.append({"type": "image", "source": ic})
+        return blocks
+
+    def _build_messages(self, user_content: str, images: list[ImageInput] | None = None) -> list[dict[str, Any]]:
         """Build the full messages array for an API call."""
         msgs: list[dict[str, Any]] = []
         if self._system_prompt:
             msgs.append({"role": "system", "content": self._system_prompt})
         msgs.extend(self._messages)
-        msgs.append({"role": "user", "content": user_content})
+        content = self._build_content_with_images(user_content, images)
+        msgs.append({"role": "user", "content": content})
         return msgs
 
     def _accumulate_usage(self, meta: dict[str, Any]) -> None:
@@ -132,30 +309,39 @@ class Conversation:
         self._usage["total_tokens"] += meta.get("total_tokens", 0)
         self._usage["cost"] += meta.get("cost", 0.0)
         self._usage["turns"] += 1
+        self._maybe_auto_save()
 
     def ask(
         self,
         content: str,
         options: dict[str, Any] | None = None,
+        images: list[ImageInput] | None = None,
    ) -> str:
         """Send a message and get a raw text response.
 
         Appends the user message and assistant response to history.
         If tools are registered and the driver supports tool use,
         dispatches to the tool execution loop.
+
+        Args:
+            content: The text message to send.
+            options: Additional options for the driver.
+            images: Optional list of images to include (bytes, path, URL,
+                base64 string, or :class:`ImageContent`).
         """
         if self._tools and getattr(self._driver, "supports_tool_use", False):
-            return self._ask_with_tools(content, options)
+            return self._ask_with_tools(content, options, images=images)
 
         merged = {**self._options, **(options or {})}
-        messages = self._build_messages(content)
+        messages = self._build_messages(content, images=images)
         resp = self._driver.generate_messages_with_hooks(messages, merged)
 
         text = resp.get("text", "")
         meta = resp.get("meta", {})
 
-        # Record in history
-        self._messages.append({"role": "user", "content": content})
+        # Record in history — store content with images for context
+        user_content = self._build_content_with_images(content, images)
+        self._messages.append({"role": "user", "content": user_content})
         self._messages.append({"role": "assistant", "content": text})
         self._accumulate_usage(meta)
 
@@ -165,13 +351,15 @@ class Conversation:
         self,
         content: str,
         options: dict[str, Any] | None = None,
+        images: list[ImageInput] | None = None,
     ) -> str:
         """Execute the tool-use loop: send -> check tool_calls -> execute -> re-send."""
         merged = {**self._options, **(options or {})}
         tool_defs = self._tools.to_openai_format()
 
         # Build messages including user content
-        self._messages.append({"role": "user", "content": content})
+        user_content = self._build_content_with_images(content, images)
+        self._messages.append({"role": "user", "content": user_content})
         msgs = self._build_messages_raw()
 
         for _round in range(self._max_tool_rounds):
@@ -235,6 +423,7 @@ class Conversation:
         self,
         content: str,
         options: dict[str, Any] | None = None,
+        images: list[ImageInput] | None = None,
     ) -> Iterator[str]:
         """Send a message and yield text chunks as they arrive.
 
@@ -243,13 +432,14 @@ class Conversation:
         is recorded in history.
         """
         if not getattr(self._driver, "supports_streaming", False):
-            yield self.ask(content, options)
+            yield self.ask(content, options, images=images)
             return
 
         merged = {**self._options, **(options or {})}
-        messages = self._build_messages(content)
+        messages = self._build_messages(content, images=images)
 
-        self._messages.append({"role": "user", "content": content})
+        user_content = self._build_content_with_images(content, images)
+        self._messages.append({"role": "user", "content": user_content})
 
         full_text = ""
         for chunk in self._driver.generate_messages_stream(messages, merged):
@@ -276,6 +466,7 @@ class Conversation:
         options: dict[str, Any] | None = None,
         output_format: Literal["json", "toon"] = "json",
         json_mode: Literal["auto", "on", "off"] = "auto",
+        images: list[ImageInput] | None = None,
     ) -> dict[str, Any]:
         """Send a message with schema enforcement and get structured JSON back.
 
@@ -320,14 +511,16 @@ class Conversation:
 
         full_user_content = f"{content}\n\n{instruct}"
 
-        messages = self._build_messages(full_user_content)
+        messages = self._build_messages(full_user_content, images=images)
         resp = self._driver.generate_messages_with_hooks(messages, merged)
 
         text = resp.get("text", "")
         meta = resp.get("meta", {})
 
         # Store original content (without schema boilerplate) for cleaner context
-        self._messages.append({"role": "user", "content": content})
+        # Include images in history so subsequent turns can reference them
+        user_content = self._build_content_with_images(content, images)
+        self._messages.append({"role": "user", "content": user_content})
 
         # Parse JSON
         cleaned = clean_json_text(text)
@@ -383,6 +576,7 @@ class Conversation:
         output_format: Literal["json", "toon"] = "json",
         options: dict[str, Any] | None = None,
         json_mode: Literal["auto", "on", "off"] = "auto",
+        images: list[ImageInput] | None = None,
     ) -> dict[str, Any]:
         """Extract structured information into a Pydantic model with conversation context."""
         from .core import normalize_field_value
@@ -397,6 +591,7 @@ class Conversation:
             options=options,
             output_format=output_format,
             json_mode=json_mode,
+            images=images,
         )
 
         # Normalize field values
```
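Taken together, the new constructor arguments and persistence methods support a save-and-resume workflow. A minimal usage sketch follows; the model name, persona name, and file path are illustrative assumptions rather than values taken from this diff, and `Conversation` is assumed to be re-exported from the package root (`prompture/__init__.py` also changed in this release).

```python
from prompture import Conversation  # assumed re-export; otherwise import from prompture.conversation

# Start a conversation with auto-save enabled; _accumulate_usage() calls
# _maybe_auto_save() after every completed turn, writing the export to disk.
conv = Conversation(
    model_name="openai/gpt-4o-mini",   # hypothetical model id
    persona="research-assistant",      # hypothetical persona name registered in the persona registry
    auto_save="chat.json",
    tags=["demo"],
)
conv.ask("Summarize the attached screenshot.", images=["screenshot.png"])

# Later (or in another process): restore the conversation from disk.
# Tool functions and callbacks are not serialized and must be re-attached.
restored = Conversation.load("chat.json")
print(restored.conversation_id, restored.tags)
```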
prompture/core.py
CHANGED

```diff
@@ -21,6 +21,7 @@ from pydantic import BaseModel
 from .driver import Driver
 from .drivers import get_driver_for_model
 from .field_definitions import get_registry_snapshot
+from .image import ImageInput, make_image
 from .tools import (
     clean_json_text,
     convert_value,
@@ -30,6 +31,17 @@ from .tools import (
 logger = logging.getLogger("prompture.core")
 
 
+def _build_content_with_images(text: str, images: list[ImageInput] | None = None) -> str | list[dict[str, Any]]:
+    """Return plain string when no images, or a list of content blocks."""
+    if not images:
+        return text
+    blocks: list[dict[str, Any]] = [{"type": "text", "text": text}]
+    for img in images:
+        ic = make_image(img)
+        blocks.append({"type": "image", "source": ic})
+    return blocks
+
+
 def normalize_field_value(value: Any, field_type: type, field_def: dict[str, Any]) -> Any:
     """Normalize invalid values for fields based on their type and nullable status.
 
@@ -142,6 +154,7 @@ def render_output(
     model_name: str = "",
     options: dict[str, Any] | None = None,
     system_prompt: str | None = None,
+    images: list[ImageInput] | None = None,
 ) -> dict[str, Any]:
     """Sends a prompt to the driver and returns the raw output in the requested format.
 
@@ -186,12 +199,12 @@ def render_output(
 
     full_prompt = f"{content_prompt}\n\nSYSTEM INSTRUCTION: {instruct}"
 
-    # Use generate_messages when system_prompt is provided
-    if system_prompt is not None:
-        messages = [
-            {"role": "system", "content": system_prompt},
-            {"role": "user", "content": full_prompt},
-        ]
+    # Use generate_messages when system_prompt or images are provided
+    user_content = _build_content_with_images(full_prompt, images)
+    if system_prompt is not None or images:
+        messages = [{"role": "user", "content": user_content}]
+        if system_prompt is not None:
+            messages.insert(0, {"role": "system", "content": system_prompt})
         resp = driver.generate_messages(messages, options)
     else:
         resp = driver.generate(full_prompt, options)
@@ -232,6 +245,7 @@ def ask_for_json(
     cache: bool | None = None,
     json_mode: Literal["auto", "on", "off"] = "auto",
     system_prompt: str | None = None,
+    images: list[ImageInput] | None = None,
 ) -> dict[str, Any]:
     """Sends a prompt to the driver and returns structured output plus usage metadata.
 
@@ -327,12 +341,12 @@ def ask_for_json(
 
     full_prompt = f"{content_prompt}\n\n{instruct}"
 
-    # Use generate_messages when system_prompt is provided
-    if system_prompt is not None:
-        messages = [
-            {"role": "system", "content": system_prompt},
-            {"role": "user", "content": full_prompt},
-        ]
+    # Use generate_messages when system_prompt or images are provided
+    user_content = _build_content_with_images(full_prompt, images)
+    if system_prompt is not None or images:
+        messages = [{"role": "user", "content": user_content}]
+        if system_prompt is not None:
+            messages.insert(0, {"role": "system", "content": system_prompt})
         resp = driver.generate_messages(messages, options)
     else:
         resp = driver.generate(full_prompt, options)
@@ -411,6 +425,7 @@ def extract_and_jsonify(
     options: dict[str, Any] | None = None,
     json_mode: Literal["auto", "on", "off"] = "auto",
     system_prompt: str | None = None,
+    images: list[ImageInput] | None = None,
 ) -> dict[str, Any]:
     """Extracts structured information using automatic driver selection based on model name.
 
@@ -497,6 +512,7 @@ def extract_and_jsonify(
             output_format=actual_output_format,
             json_mode=json_mode,
            system_prompt=system_prompt,
+            images=images,
         )
     except (requests.exceptions.ConnectionError, requests.exceptions.HTTPError) as e:
         if "pytest" in sys.modules:
@@ -595,6 +611,7 @@ def extract_with_model(
     cache: bool | None = None,
     json_mode: Literal["auto", "on", "off"] = "auto",
     system_prompt: str | None = None,
+    images: list[ImageInput] | None = None,
 ) -> dict[str, Any]:
     """Extracts structured information into a Pydantic model instance.
 
@@ -684,6 +701,7 @@ def extract_with_model(
         options=options,
         json_mode=json_mode,
         system_prompt=system_prompt,
+        images=images,
     )
     logger.debug("[extract] Extraction completed successfully")
 
```

Note: the pre-change bodies of the two removed blocks in `render_output` and `ask_for_json` were not fully preserved in the rendered diff; the removed lines shown above are reconstructed from the surrounding context (the old path built a system+user message pair when `system_prompt` was set).
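The module-level `_build_content_with_images` helper defines the universal content shape that the new `images=` keyword threads through `render_output`, `ask_for_json`, `extract_and_jsonify`, and `extract_with_model`. A small sketch of that shape; the `ImageContent` fields come from `prompture/image.py`, which is not shown in this diff, so the image block value is abbreviated in the comment.

```python
from prompture.core import _build_content_with_images  # private helper added in this diff

# Without images the helper returns the prompt unchanged, so drivers keep receiving plain strings.
assert _build_content_with_images("Describe the chart") == "Describe the chart"

# With images it returns a list of content blocks, which routes the request
# through driver.generate_messages() instead of driver.generate().
blocks = _build_content_with_images("Describe the chart", images=["chart.png"])  # path assumed to exist
# blocks == [
#     {"type": "text", "text": "Describe the chart"},
#     {"type": "image", "source": <ImageContent built by make_image("chart.png")>},
# ]
print(len(blocks))  # 2
```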
prompture/discovery.py
CHANGED

```diff
@@ -147,7 +147,30 @@ def get_available_models() -> list[str]:
             except Exception as e:
                 logger.debug(f"Failed to fetch Ollama models: {e}")
 
-        #
+        # Dynamic Detection: LM Studio loaded models
+        if provider == "lmstudio":
+            try:
+                endpoint = settings.lmstudio_endpoint or os.getenv(
+                    "LMSTUDIO_ENDPOINT", "http://127.0.0.1:1234/v1/chat/completions"
+                )
+                base_url = endpoint.split("/v1/")[0]
+                models_url = f"{base_url}/v1/models"
+
+                headers: dict[str, str] = {}
+                api_key = settings.lmstudio_api_key or os.getenv("LMSTUDIO_API_KEY")
+                if api_key:
+                    headers["Authorization"] = f"Bearer {api_key}"
+
+                resp = requests.get(models_url, headers=headers, timeout=2)
+                if resp.status_code == 200:
+                    data = resp.json()
+                    models = data.get("data", [])
+                    for model in models:
+                        model_id = model.get("id")
+                        if model_id:
+                            available_models.add(f"lmstudio/{model_id}")
+            except Exception as e:
+                logger.debug(f"Failed to fetch LM Studio models: {e}")
 
     except Exception as e:
         logger.warning(f"Error detecting models for provider {provider}: {e}")
```
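With this change, `get_available_models()` also probes a running LM Studio server at `<endpoint>/v1/models` and adds each loaded model under the `lmstudio/` prefix. A usage sketch (assumes LM Studio is serving on the default endpoint; when it is not reachable the probe fails silently and simply contributes no entries):

```python
from prompture.discovery import get_available_models

models = get_available_models()
lmstudio_models = [m for m in models if m.startswith("lmstudio/")]
print(lmstudio_models)  # e.g. ["lmstudio/qwen2.5-7b-instruct"], depending on what is loaded locally
```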
prompture/driver.py
CHANGED

```diff
@@ -35,6 +35,7 @@ class Driver:
     supports_messages: bool = False
     supports_tool_use: bool = False
     supports_streaming: bool = False
+    supports_vision: bool = False
 
     callbacks: DriverCallbacks | None = None
 
@@ -52,6 +53,7 @@ class Driver:
         support message arrays should override this method and set
         ``supports_messages = True``.
         """
+        self._check_vision_support(messages)
         prompt = self._flatten_messages(messages)
         return self.generate(prompt, options)
 
@@ -171,6 +173,30 @@ class Driver:
         except Exception:
             logger.exception("Callback %s raised an exception", event)
 
+    def _check_vision_support(self, messages: list[dict[str, Any]]) -> None:
+        """Raise if messages contain image blocks and the driver lacks vision support."""
+        if self.supports_vision:
+            return
+        for msg in messages:
+            content = msg.get("content")
+            if isinstance(content, list):
+                for block in content:
+                    if isinstance(block, dict) and block.get("type") == "image":
+                        raise NotImplementedError(
+                            f"{self.__class__.__name__} does not support vision/image inputs. "
+                            "Use a vision-capable model."
+                        )
+
+    def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+        """Transform universal message format into provider-specific wire format.
+
+        Vision-capable drivers override this to convert the universal image
+        blocks into their provider-specific format. The base implementation
+        validates vision support and returns messages unchanged.
+        """
+        self._check_vision_support(messages)
+        return messages
+
     @staticmethod
     def _flatten_messages(messages: list[dict[str, Any]]) -> str:
         """Join messages into a single prompt string with role prefixes."""
@@ -178,6 +204,18 @@ class Driver:
         for msg in messages:
             role = msg.get("role", "user")
             content = msg.get("content", "")
+            # Handle content that is a list of blocks (vision messages)
+            if isinstance(content, list):
+                text_parts = []
+                for block in content:
+                    if isinstance(block, dict):
+                        if block.get("type") == "text":
+                            text_parts.append(block.get("text", ""))
+                        elif block.get("type") == "image":
+                            text_parts.append("[image]")
+                    elif isinstance(block, str):
+                        text_parts.append(block)
+                content = " ".join(text_parts)
             if role == "system":
                 parts.append(f"[System]: {content}")
             elif role == "assistant":
```
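The base class now gives subclasses two vision hooks: the `supports_vision` flag consulted by `_check_vision_support`, and `_prepare_messages` for converting the universal image blocks into a provider wire format. A minimal sketch of a hypothetical third-party driver opting in; the class name and the returned payload are illustrative, only the hook names come from this diff.

```python
from typing import Any

from prompture.driver import Driver


class MyVisionDriver(Driver):
    """Hypothetical driver that accepts the universal image blocks."""

    supports_messages = True
    supports_vision = True  # prevents _check_vision_support() from raising NotImplementedError

    def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
        # A real driver would convert {"type": "image", "source": ...} blocks to its
        # provider's format here; returning them unchanged is enough for this sketch.
        return messages

    def generate_messages(self, messages: list[dict[str, Any]], options: dict[str, Any]) -> dict[str, Any]:
        prepared = self._prepare_messages(messages)
        # The text/meta shape below mirrors how Conversation reads responses
        # (resp["text"], resp["meta"]); a real driver would call its API with `prepared`.
        return {"text": f"(echo) {self._flatten_messages(prepared)}", "meta": {}}
```

Non-vision drivers keep working unchanged: when falling back to plain-prompt generation, `_flatten_messages` now degrades image blocks to an `[image]` placeholder instead of failing.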
prompture/drivers/__init__.py
CHANGED

```diff
@@ -84,7 +84,11 @@ register_driver(
 )
 register_driver(
     "lmstudio",
-    lambda model=None: LMStudioDriver(
+    lambda model=None: LMStudioDriver(
+        endpoint=settings.lmstudio_endpoint,
+        model=model or settings.lmstudio_model,
+        api_key=settings.lmstudio_api_key,
+    ),
     overwrite=True,
 )
 register_driver(
```

Note: the single-line factory that was removed is truncated in the rendered diff; only its visible prefix is shown above.
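The LM Studio factory now forwards the endpoint, model, and API key from settings. The same `register_driver(..., overwrite=True)` mechanism can re-register a factory with explicit values. A sketch; the endpoint and model below are illustrative defaults, and `register_driver` is assumed to be importable from `prompture.drivers`, where it is used.

```python
from prompture.drivers import register_driver  # assumed import location
from prompture.drivers.lmstudio_driver import LMStudioDriver

# Re-register the "lmstudio" factory with explicit values instead of the settings-derived ones.
register_driver(
    "lmstudio",
    lambda model=None: LMStudioDriver(
        endpoint="http://127.0.0.1:1234/v1/chat/completions",
        model=model or "qwen2.5-7b-instruct",  # hypothetical default model id
        api_key=None,
    ),
    overwrite=True,
)
```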
prompture/drivers/async_azure_driver.py
CHANGED

```diff
@@ -18,6 +18,7 @@ from .azure_driver import AzureDriver
 class AsyncAzureDriver(CostMixin, AsyncDriver):
     supports_json_mode = True
     supports_json_schema = True
+    supports_vision = True
 
     MODEL_PRICING = AzureDriver.MODEL_PRICING
 
@@ -52,12 +53,17 @@ class AsyncAzureDriver(CostMixin, AsyncDriver):
 
     supports_messages = True
 
+    def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+        from .vision_helpers import _prepare_openai_vision_messages
+
+        return _prepare_openai_vision_messages(messages)
+
     async def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
         messages = [{"role": "user", "content": prompt}]
         return await self._do_generate(messages, options)
 
     async def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
-        return await self._do_generate(messages, options)
+        return await self._do_generate(self._prepare_messages(messages), options)
 
     async def _do_generate(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
         if self.client is None:
```
prompture/drivers/async_claude_driver.py
CHANGED

```diff
@@ -19,6 +19,7 @@ from .claude_driver import ClaudeDriver
 class AsyncClaudeDriver(CostMixin, AsyncDriver):
     supports_json_mode = True
     supports_json_schema = True
+    supports_vision = True
 
     MODEL_PRICING = ClaudeDriver.MODEL_PRICING
 
@@ -28,12 +29,17 @@ class AsyncClaudeDriver(CostMixin, AsyncDriver):
 
     supports_messages = True
 
+    def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+        from .vision_helpers import _prepare_claude_vision_messages
+
+        return _prepare_claude_vision_messages(messages)
+
     async def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
         messages = [{"role": "user", "content": prompt}]
         return await self._do_generate(messages, options)
 
     async def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
-        return await self._do_generate(messages, options)
+        return await self._do_generate(self._prepare_messages(messages), options)
 
     async def _do_generate(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
         if anthropic is None:
```
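Both async drivers follow the same pattern: declare `supports_vision = True` and route `generate_messages` through a provider-specific helper from `vision_helpers.py` before `_do_generate`. A sketch of the call path with a universal image message; the `AsyncClaudeDriver` constructor arguments are not shown in this diff, so the `model=` keyword and credential handling below are assumptions, and the image payload is abbreviated.

```python
import asyncio
from typing import Any

from prompture.drivers.async_claude_driver import AsyncClaudeDriver


async def main() -> None:
    driver = AsyncClaudeDriver(model="claude-sonnet-4-20250514")  # assumed constructor signature

    messages: list[dict[str, Any]] = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "What is in this image?"},
                # The source value would normally be an ImageContent built by prompture.image.make_image().
                {"type": "image", "source": "<ImageContent>"},
            ],
        }
    ]
    # generate_messages() now calls _prepare_messages(), which delegates to
    # _prepare_claude_vision_messages() before the request reaches _do_generate().
    resp = await driver.generate_messages(messages, options={})
    print(resp.get("text"))


asyncio.run(main())
```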