prompture 0.0.34.dev2__py3-none-any.whl → 0.0.35.dev1__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
- prompture/__init__.py +4 -0
- prompture/_version.py +2 -2
- prompture/async_conversation.py +129 -6
- prompture/async_driver.py +40 -2
- prompture/callbacks.py +5 -0
- prompture/cli.py +56 -1
- prompture/conversation.py +132 -5
- prompture/driver.py +46 -3
- prompture/drivers/claude_driver.py +167 -2
- prompture/drivers/ollama_driver.py +68 -1
- prompture/drivers/openai_driver.py +144 -2
- prompture/scaffold/__init__.py +1 -0
- prompture/scaffold/generator.py +84 -0
- prompture/scaffold/templates/Dockerfile.j2 +12 -0
- prompture/scaffold/templates/README.md.j2 +41 -0
- prompture/scaffold/templates/config.py.j2 +21 -0
- prompture/scaffold/templates/env.example.j2 +8 -0
- prompture/scaffold/templates/main.py.j2 +86 -0
- prompture/scaffold/templates/models.py.j2 +40 -0
- prompture/scaffold/templates/requirements.txt.j2 +5 -0
- prompture/server.py +183 -0
- prompture/tools_schema.py +254 -0
- {prompture-0.0.34.dev2.dist-info → prompture-0.0.35.dev1.dist-info}/METADATA +7 -1
- {prompture-0.0.34.dev2.dist-info → prompture-0.0.35.dev1.dist-info}/RECORD +28 -17
- {prompture-0.0.34.dev2.dist-info → prompture-0.0.35.dev1.dist-info}/WHEEL +0 -0
- {prompture-0.0.34.dev2.dist-info → prompture-0.0.35.dev1.dist-info}/entry_points.txt +0 -0
- {prompture-0.0.34.dev2.dist-info → prompture-0.0.35.dev1.dist-info}/licenses/LICENSE +0 -0
- {prompture-0.0.34.dev2.dist-info → prompture-0.0.35.dev1.dist-info}/top_level.txt +0 -0
prompture/__init__.py
CHANGED
```diff
@@ -74,6 +74,7 @@ from .runner import run_suite_from_spec
 from .session import UsageSession
 from .settings import settings as _settings
 from .tools import clean_json_text, clean_toon_text
+from .tools_schema import ToolDefinition, ToolRegistry, tool_from_function
 from .validator import validate_against_schema
 
 # Load environment variables from .env file
@@ -128,6 +129,8 @@ __all__ = [
     "RedisCacheBackend",
     "ResponseCache",
     "SQLiteCacheBackend",
+    "ToolDefinition",
+    "ToolRegistry",
     "UsageSession",
     "add_field_definition",
     "add_field_definitions",
@@ -170,6 +173,7 @@ __all__ = [
     "reset_registry",
     "run_suite_from_spec",
    "stepwise_extract_with_model",
+    "tool_from_function",
    "unregister_async_driver",
    "unregister_driver",
    "validate_against_schema",
```
prompture/_version.py
CHANGED
```diff
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
 commit_id: COMMIT_ID
 __commit_id__: COMMIT_ID
 
-__version__ = version = '0.0.34.dev2'
-__version_tuple__ = version_tuple = (0, 0, 34, 'dev2')
+__version__ = version = '0.0.35.dev1'
+__version_tuple__ = version_tuple = (0, 0, 35, 'dev1')
 
 __commit_id__ = commit_id = None
```
prompture/async_conversation.py
CHANGED
```diff
@@ -4,9 +4,10 @@ from __future__ import annotations
 
 import json
 import logging
+from collections.abc import AsyncIterator
 from datetime import date, datetime
 from decimal import Decimal
-from typing import Any, Literal, Union
+from typing import Any, Callable, Literal, Union
 
 from pydantic import BaseModel
 
@@ -19,6 +20,7 @@ from .tools import (
     convert_value,
     get_field_default,
 )
+from .tools_schema import ToolRegistry
 
 logger = logging.getLogger("prompture.async_conversation")
 
@@ -43,6 +45,8 @@ class AsyncConversation:
         system_prompt: str | None = None,
         options: dict[str, Any] | None = None,
         callbacks: DriverCallbacks | None = None,
+        tools: ToolRegistry | None = None,
+        max_tool_rounds: int = 10,
     ) -> None:
         if model_name is None and driver is None:
             raise ValueError("Either model_name or driver must be provided")
@@ -58,7 +62,7 @@ class AsyncConversation:
         self._model_name = model_name or ""
         self._system_prompt = system_prompt
         self._options = dict(options) if options else {}
-        self._messages: list[dict[str, str]] = []
+        self._messages: list[dict[str, Any]] = []
         self._usage = {
             "prompt_tokens": 0,
             "completion_tokens": 0,
@@ -66,13 +70,15 @@
             "cost": 0.0,
             "turns": 0,
         }
+        self._tools = tools or ToolRegistry()
+        self._max_tool_rounds = max_tool_rounds
 
     # ------------------------------------------------------------------
     # Public helpers
     # ------------------------------------------------------------------
 
     @property
-    def messages(self) -> list[dict[str, str]]:
+    def messages(self) -> list[dict[str, Any]]:
         """Read-only view of the conversation history."""
         return list(self._messages)
 
@@ -91,6 +97,16 @@
             raise ValueError("role must be 'user' or 'assistant'")
         self._messages.append({"role": role, "content": content})
 
+    def register_tool(
+        self,
+        fn: Callable[..., Any],
+        *,
+        name: str | None = None,
+        description: str | None = None,
+    ) -> None:
+        """Register a Python function as a tool the LLM can call."""
+        self._tools.register(fn, name=name, description=description)
+
     def usage_summary(self) -> str:
         """Human-readable summary of accumulated usage."""
         u = self._usage
@@ -100,9 +116,9 @@
     # Core methods
     # ------------------------------------------------------------------
 
-    def _build_messages(self, user_content: str) -> list[dict[str, str]]:
+    def _build_messages(self, user_content: str) -> list[dict[str, Any]]:
         """Build the full messages array for an API call."""
-        msgs: list[dict[str, str]] = []
+        msgs: list[dict[str, Any]] = []
         if self._system_prompt:
             msgs.append({"role": "system", "content": self._system_prompt})
         msgs.extend(self._messages)
@@ -121,7 +137,14 @@
         content: str,
         options: dict[str, Any] | None = None,
     ) -> str:
-        """Send a message and get a raw text response (async)."""
+        """Send a message and get a raw text response (async).
+
+        If tools are registered and the driver supports tool use,
+        dispatches to the async tool execution loop.
+        """
+        if self._tools and getattr(self._driver, "supports_tool_use", False):
+            return await self._ask_with_tools(content, options)
+
         merged = {**self._options, **(options or {})}
         messages = self._build_messages(content)
         resp = await self._driver.generate_messages_with_hooks(messages, merged)
@@ -135,6 +158,106 @@
 
         return text
 
+    async def _ask_with_tools(
+        self,
+        content: str,
+        options: dict[str, Any] | None = None,
+    ) -> str:
+        """Async tool-use loop: send -> check tool_calls -> execute -> re-send."""
+        merged = {**self._options, **(options or {})}
+        tool_defs = self._tools.to_openai_format()
+
+        self._messages.append({"role": "user", "content": content})
+        msgs = self._build_messages_raw()
+
+        for _round in range(self._max_tool_rounds):
+            resp = await self._driver.generate_messages_with_tools(msgs, tool_defs, merged)
+
+            meta = resp.get("meta", {})
+            self._accumulate_usage(meta)
+
+            tool_calls = resp.get("tool_calls", [])
+            text = resp.get("text", "")
+
+            if not tool_calls:
+                self._messages.append({"role": "assistant", "content": text})
+                return text
+
+            assistant_msg: dict[str, Any] = {"role": "assistant", "content": text}
+            assistant_msg["tool_calls"] = [
+                {
+                    "id": tc["id"],
+                    "type": "function",
+                    "function": {"name": tc["name"], "arguments": json.dumps(tc["arguments"])},
+                }
+                for tc in tool_calls
+            ]
+            self._messages.append(assistant_msg)
+            msgs.append(assistant_msg)
+
+            for tc in tool_calls:
+                try:
+                    result = self._tools.execute(tc["name"], tc["arguments"])
+                    result_str = json.dumps(result) if not isinstance(result, str) else result
+                except Exception as exc:
+                    result_str = f"Error: {exc}"
+
+                tool_result_msg: dict[str, Any] = {
+                    "role": "tool",
+                    "tool_call_id": tc["id"],
+                    "content": result_str,
+                }
+                self._messages.append(tool_result_msg)
+                msgs.append(tool_result_msg)
+
+        raise RuntimeError(f"Tool execution loop exceeded {self._max_tool_rounds} rounds")
+
+    def _build_messages_raw(self) -> list[dict[str, Any]]:
+        """Build messages array from system prompt + full history (including tool messages)."""
+        msgs: list[dict[str, Any]] = []
+        if self._system_prompt:
+            msgs.append({"role": "system", "content": self._system_prompt})
+        msgs.extend(self._messages)
+        return msgs
+
+    # ------------------------------------------------------------------
+    # Streaming
+    # ------------------------------------------------------------------
+
+    async def ask_stream(
+        self,
+        content: str,
+        options: dict[str, Any] | None = None,
+    ) -> AsyncIterator[str]:
+        """Send a message and yield text chunks as they arrive (async).
+
+        Falls back to non-streaming :meth:`ask` if the driver doesn't
+        support streaming.
+        """
+        if not getattr(self._driver, "supports_streaming", False):
+            yield await self.ask(content, options)
+            return
+
+        merged = {**self._options, **(options or {})}
+        messages = self._build_messages(content)
+
+        self._messages.append({"role": "user", "content": content})
+
+        full_text = ""
+        async for chunk in self._driver.generate_messages_stream(messages, merged):
+            if chunk["type"] == "delta":
+                full_text += chunk["text"]
+                self._driver._fire_callback(
+                    "on_stream_delta",
+                    {"text": chunk["text"], "driver": getattr(self._driver, "model", self._driver.__class__.__name__)},
+                )
+                yield chunk["text"]
+            elif chunk["type"] == "done":
+                meta = chunk.get("meta", {})
+                self._accumulate_usage(meta)
+
+        self._messages.append({"role": "assistant", "content": full_text})
+
     async def ask_for_json(
         self,
         content: str,
```
prompture/async_driver.py
CHANGED
```diff
@@ -4,6 +4,7 @@ from __future__ import annotations
 
 import logging
 import time
+from collections.abc import AsyncIterator
 from typing import Any
 
 from .callbacks import DriverCallbacks
@@ -32,13 +33,15 @@ class AsyncDriver:
     supports_json_mode: bool = False
     supports_json_schema: bool = False
     supports_messages: bool = False
+    supports_tool_use: bool = False
+    supports_streaming: bool = False
 
     callbacks: DriverCallbacks | None = None
 
     async def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
         raise NotImplementedError
 
-    async def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
+    async def generate_messages(self, messages: list[dict[str, Any]], options: dict[str, Any]) -> dict[str, Any]:
         """Generate a response from a list of conversation messages (async).
 
         Default implementation flattens the messages into a single prompt
@@ -49,6 +52,41 @@
         prompt = Driver._flatten_messages(messages)
         return await self.generate(prompt, options)
 
+    # ------------------------------------------------------------------
+    # Tool use
+    # ------------------------------------------------------------------
+
+    async def generate_messages_with_tools(
+        self,
+        messages: list[dict[str, Any]],
+        tools: list[dict[str, Any]],
+        options: dict[str, Any],
+    ) -> dict[str, Any]:
+        """Generate a response that may include tool calls (async).
+
+        Returns a dict with keys: ``text``, ``meta``, ``tool_calls``, ``stop_reason``.
+        """
+        raise NotImplementedError(f"{self.__class__.__name__} does not support tool use")
+
+    # ------------------------------------------------------------------
+    # Streaming
+    # ------------------------------------------------------------------
+
+    async def generate_messages_stream(
+        self,
+        messages: list[dict[str, Any]],
+        options: dict[str, Any],
+    ) -> AsyncIterator[dict[str, Any]]:
+        """Yield response chunks incrementally (async).
+
+        Each chunk is a dict:
+        - ``{"type": "delta", "text": str}``
+        - ``{"type": "done", "text": str, "meta": dict}``
+        """
+        raise NotImplementedError(f"{self.__class__.__name__} does not support streaming")
+        # yield is needed to make this an async generator
+        yield  # pragma: no cover
+
     # ------------------------------------------------------------------
     # Hook-aware wrappers
     # ------------------------------------------------------------------
@@ -82,7 +120,7 @@ class AsyncDriver:
         return resp
 
     async def generate_messages_with_hooks(
-        self, messages: list[dict[str, str]], options: dict[str, Any]
+        self, messages: list[dict[str, Any]], options: dict[str, Any]
    ) -> dict[str, Any]:
         """Wrap :meth:`generate_messages` with callbacks."""
         driver_name = getattr(self, "model", self.__class__.__name__)
```
prompture/callbacks.py
CHANGED
```diff
@@ -27,6 +27,7 @@ from typing import Any, Callable
 OnRequestCallback = Callable[[dict[str, Any]], None]
 OnResponseCallback = Callable[[dict[str, Any]], None]
 OnErrorCallback = Callable[[dict[str, Any]], None]
+OnStreamDeltaCallback = Callable[[dict[str, Any]], None]
 
 
 @dataclass
@@ -43,8 +44,12 @@ class DriverCallbacks:
 
     ``on_error``
         ``{error, prompt, messages, options, driver}``
+
+    ``on_stream_delta``
+        ``{text, driver}``
     """
 
     on_request: OnRequestCallback | None = field(default=None)
     on_response: OnResponseCallback | None = field(default=None)
     on_error: OnErrorCallback | None = field(default=None)
+    on_stream_delta: OnStreamDeltaCallback | None = field(default=None)
```
prompture/cli.py
CHANGED
```diff
@@ -8,7 +8,7 @@ from .runner import run_suite_from_spec
 
 @click.group()
 def cli():
-    """
+    """Prompture CLI -- structured LLM output toolkit."""
     pass
 
 
@@ -25,3 +25,58 @@ def run(specfile, outfile):
     with open(outfile, "w", encoding="utf-8") as fh:
         json.dump(report, fh, indent=2, ensure_ascii=False)
     click.echo(f"Report saved to {outfile}")
+
+
+@cli.command()
+@click.option("--model", default="openai/gpt-4o-mini", help="Model string (provider/model).")
+@click.option("--system-prompt", default=None, help="System prompt for conversations.")
+@click.option("--host", default="0.0.0.0", help="Bind host.")
+@click.option("--port", default=8000, type=int, help="Bind port.")
+@click.option("--cors-origins", default=None, help="Comma-separated CORS origins (use * for all).")
+def serve(model, system_prompt, host, port, cors_origins):
+    """Start an API server wrapping AsyncConversation.
+
+    Requires the 'serve' extra: pip install prompture[serve]
+    """
+    try:
+        import uvicorn
+    except ImportError:
+        click.echo("Error: uvicorn not installed. Run: pip install prompture[serve]", err=True)
+        raise SystemExit(1) from None
+
+    from .server import create_app
+
+    origins = [o.strip() for o in cors_origins.split(",")] if cors_origins else None
+    app = create_app(
+        model_name=model,
+        system_prompt=system_prompt,
+        cors_origins=origins,
+    )
+
+    click.echo(f"Starting Prompture server on {host}:{port} with model {model}")
+    uvicorn.run(app, host=host, port=port)
+
+
+@cli.command()
+@click.argument("output_dir", type=click.Path())
+@click.option("--name", default="my_app", help="Project name.")
+@click.option("--model", default="openai/gpt-4o-mini", help="Default model string.")
+@click.option("--docker/--no-docker", default=True, help="Include Dockerfile.")
+def scaffold(output_dir, name, model, docker):
+    """Generate a standalone FastAPI project using Prompture.
+
+    Requires the 'scaffold' extra: pip install prompture[scaffold]
+    """
+    try:
+        from .scaffold.generator import scaffold_project
+    except ImportError:
+        click.echo("Error: jinja2 not installed. Run: pip install prompture[scaffold]", err=True)
+        raise SystemExit(1) from None
+
+    scaffold_project(
+        output_dir=output_dir,
+        project_name=name,
+        model_name=model,
+        include_docker=docker,
+    )
+    click.echo(f"Project scaffolded at {output_dir}")
```
prompture/conversation.py
CHANGED
```diff
@@ -4,9 +4,10 @@ from __future__ import annotations
 
 import json
 import logging
+from collections.abc import Iterator
 from datetime import date, datetime
 from decimal import Decimal
-from typing import Any, Literal, Union
+from typing import Any, Callable, Literal, Union
 
 from pydantic import BaseModel
 
@@ -19,6 +20,7 @@ from .tools import (
     convert_value,
     get_field_default,
 )
+from .tools_schema import ToolRegistry
 
 logger = logging.getLogger("prompture.conversation")
 
@@ -44,6 +46,8 @@ class Conversation:
         system_prompt: str | None = None,
         options: dict[str, Any] | None = None,
         callbacks: DriverCallbacks | None = None,
+        tools: ToolRegistry | None = None,
+        max_tool_rounds: int = 10,
     ) -> None:
         if model_name is None and driver is None:
             raise ValueError("Either model_name or driver must be provided")
@@ -59,7 +63,7 @@ class Conversation:
         self._model_name = model_name or ""
         self._system_prompt = system_prompt
         self._options = dict(options) if options else {}
-        self._messages: list[dict[str, str]] = []
+        self._messages: list[dict[str, Any]] = []
         self._usage = {
             "prompt_tokens": 0,
             "completion_tokens": 0,
@@ -67,13 +71,15 @@
             "cost": 0.0,
             "turns": 0,
         }
+        self._tools = tools or ToolRegistry()
+        self._max_tool_rounds = max_tool_rounds
 
     # ------------------------------------------------------------------
     # Public helpers
     # ------------------------------------------------------------------
 
     @property
-    def messages(self) -> list[dict[str, str]]:
+    def messages(self) -> list[dict[str, Any]]:
         """Read-only view of the conversation history."""
         return list(self._messages)
 
@@ -92,6 +98,16 @@
             raise ValueError("role must be 'user' or 'assistant'")
         self._messages.append({"role": role, "content": content})
 
+    def register_tool(
+        self,
+        fn: Callable[..., Any],
+        *,
+        name: str | None = None,
+        description: str | None = None,
+    ) -> None:
+        """Register a Python function as a tool the LLM can call."""
+        self._tools.register(fn, name=name, description=description)
+
     def usage_summary(self) -> str:
         """Human-readable summary of accumulated usage."""
         u = self._usage
@@ -101,9 +117,9 @@
     # Core methods
     # ------------------------------------------------------------------
 
-    def _build_messages(self, user_content: str) -> list[dict[str, str]]:
+    def _build_messages(self, user_content: str) -> list[dict[str, Any]]:
         """Build the full messages array for an API call."""
-        msgs: list[dict[str, str]] = []
+        msgs: list[dict[str, Any]] = []
         if self._system_prompt:
             msgs.append({"role": "system", "content": self._system_prompt})
         msgs.extend(self._messages)
@@ -125,7 +141,12 @@
         """Send a message and get a raw text response.
 
         Appends the user message and assistant response to history.
+        If tools are registered and the driver supports tool use,
+        dispatches to the tool execution loop.
         """
+        if self._tools and getattr(self._driver, "supports_tool_use", False):
+            return self._ask_with_tools(content, options)
+
         merged = {**self._options, **(options or {})}
         messages = self._build_messages(content)
         resp = self._driver.generate_messages_with_hooks(messages, merged)
@@ -140,6 +161,112 @@
 
         return text
 
+    def _ask_with_tools(
+        self,
+        content: str,
+        options: dict[str, Any] | None = None,
+    ) -> str:
+        """Execute the tool-use loop: send -> check tool_calls -> execute -> re-send."""
+        merged = {**self._options, **(options or {})}
+        tool_defs = self._tools.to_openai_format()
+
+        # Build messages including user content
+        self._messages.append({"role": "user", "content": content})
+        msgs = self._build_messages_raw()
+
+        for _round in range(self._max_tool_rounds):
+            resp = self._driver.generate_messages_with_tools(msgs, tool_defs, merged)
+
+            meta = resp.get("meta", {})
+            self._accumulate_usage(meta)
+
+            tool_calls = resp.get("tool_calls", [])
+            text = resp.get("text", "")
+
+            if not tool_calls:
+                # No tool calls -> final response
+                self._messages.append({"role": "assistant", "content": text})
+                return text
+
+            # Record assistant message with tool_calls
+            assistant_msg: dict[str, Any] = {"role": "assistant", "content": text}
+            assistant_msg["tool_calls"] = [
+                {
+                    "id": tc["id"],
+                    "type": "function",
+                    "function": {"name": tc["name"], "arguments": json.dumps(tc["arguments"])},
+                }
+                for tc in tool_calls
+            ]
+            self._messages.append(assistant_msg)
+            msgs.append(assistant_msg)
+
+            # Execute each tool call and append results
+            for tc in tool_calls:
+                try:
+                    result = self._tools.execute(tc["name"], tc["arguments"])
+                    result_str = json.dumps(result) if not isinstance(result, str) else result
+                except Exception as exc:
+                    result_str = f"Error: {exc}"
+
+                tool_result_msg: dict[str, Any] = {
+                    "role": "tool",
+                    "tool_call_id": tc["id"],
+                    "content": result_str,
+                }
+                self._messages.append(tool_result_msg)
+                msgs.append(tool_result_msg)
+
+        raise RuntimeError(f"Tool execution loop exceeded {self._max_tool_rounds} rounds")
+
+    def _build_messages_raw(self) -> list[dict[str, Any]]:
+        """Build messages array from system prompt + full history (including tool messages)."""
+        msgs: list[dict[str, Any]] = []
+        if self._system_prompt:
+            msgs.append({"role": "system", "content": self._system_prompt})
+        msgs.extend(self._messages)
+        return msgs
+
+    # ------------------------------------------------------------------
+    # Streaming
+    # ------------------------------------------------------------------
+
+    def ask_stream(
+        self,
+        content: str,
+        options: dict[str, Any] | None = None,
+    ) -> Iterator[str]:
+        """Send a message and yield text chunks as they arrive.
+
+        Falls back to non-streaming :meth:`ask` if the driver doesn't
+        support streaming. After iteration completes, the full response
+        is recorded in history.
+        """
+        if not getattr(self._driver, "supports_streaming", False):
+            yield self.ask(content, options)
+            return
+
+        merged = {**self._options, **(options or {})}
+        messages = self._build_messages(content)
+
+        self._messages.append({"role": "user", "content": content})
+
+        full_text = ""
+        for chunk in self._driver.generate_messages_stream(messages, merged):
+            if chunk["type"] == "delta":
+                full_text += chunk["text"]
+                # Fire stream delta callback
+                self._driver._fire_callback(
+                    "on_stream_delta",
+                    {"text": chunk["text"], "driver": getattr(self._driver, "model", self._driver.__class__.__name__)},
+                )
+                yield chunk["text"]
+            elif chunk["type"] == "done":
+                meta = chunk.get("meta", {})
+                self._accumulate_usage(meta)
+
+        self._messages.append({"role": "assistant", "content": full_text})
+
     def ask_for_json(
         self,
         content: str,
```