prompture 0.0.34.dev2__py3-none-any.whl → 0.0.35.dev1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- prompture/__init__.py +4 -0
- prompture/_version.py +2 -2
- prompture/async_conversation.py +129 -6
- prompture/async_driver.py +40 -2
- prompture/callbacks.py +5 -0
- prompture/cli.py +56 -1
- prompture/conversation.py +132 -5
- prompture/driver.py +46 -3
- prompture/drivers/claude_driver.py +167 -2
- prompture/drivers/ollama_driver.py +68 -1
- prompture/drivers/openai_driver.py +144 -2
- prompture/scaffold/__init__.py +1 -0
- prompture/scaffold/generator.py +84 -0
- prompture/scaffold/templates/Dockerfile.j2 +12 -0
- prompture/scaffold/templates/README.md.j2 +41 -0
- prompture/scaffold/templates/config.py.j2 +21 -0
- prompture/scaffold/templates/env.example.j2 +8 -0
- prompture/scaffold/templates/main.py.j2 +86 -0
- prompture/scaffold/templates/models.py.j2 +40 -0
- prompture/scaffold/templates/requirements.txt.j2 +5 -0
- prompture/server.py +183 -0
- prompture/tools_schema.py +254 -0
- {prompture-0.0.34.dev2.dist-info → prompture-0.0.35.dev1.dist-info}/METADATA +7 -1
- {prompture-0.0.34.dev2.dist-info → prompture-0.0.35.dev1.dist-info}/RECORD +28 -17
- {prompture-0.0.34.dev2.dist-info → prompture-0.0.35.dev1.dist-info}/WHEEL +0 -0
- {prompture-0.0.34.dev2.dist-info → prompture-0.0.35.dev1.dist-info}/entry_points.txt +0 -0
- {prompture-0.0.34.dev2.dist-info → prompture-0.0.35.dev1.dist-info}/licenses/LICENSE +0 -0
- {prompture-0.0.34.dev2.dist-info → prompture-0.0.35.dev1.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
"""Configuration for {{ project_name }}."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from pydantic_settings import BaseSettings
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class Settings(BaseSettings):
    """Application settings, read from the process environment and ``.env``."""

    # Model string handed to AsyncConversation for every new conversation.
    model_name: str = "{{ model_name }}"
    # System prompt used for every new conversation.
    system_prompt: str = "You are a helpful assistant."
    # Origins passed to CORSMiddleware in main.py.
    cors_origins: list[str] = ["*"]

    # Provider API keys (loaded from environment / .env)
    openai_api_key: str = ""
    claude_api_key: str = ""
    google_api_key: str = ""

    # ``extra="ignore"``: pydantic-settings forbids unknown keys by default, so
    # any .env entry without a matching field (e.g. a key for a provider not
    # listed above) would make ``Settings()`` raise at import time.
    model_config = {
        "env_file": ".env",
        "env_file_encoding": "utf-8",
        "extra": "ignore",
    }


settings = Settings()
|
|
@@ -0,0 +1,86 @@
|
|
|
1
|
+
"""{{ project_name }} -- FastAPI server powered by Prompture."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import json
|
|
6
|
+
import uuid
|
|
7
|
+
from typing import Any
|
|
8
|
+
|
|
9
|
+
from fastapi import FastAPI, HTTPException
|
|
10
|
+
from fastapi.middleware.cors import CORSMiddleware
|
|
11
|
+
|
|
12
|
+
from .config import settings
|
|
13
|
+
from .models import (
|
|
14
|
+
ChatRequest,
|
|
15
|
+
ChatResponse,
|
|
16
|
+
ConversationHistory,
|
|
17
|
+
ExtractRequest,
|
|
18
|
+
ExtractResponse,
|
|
19
|
+
)
|
|
20
|
+
|
|
21
|
+
from prompture import AsyncConversation
|
|
22
|
+
|
|
23
|
+
app = FastAPI(title="{{ project_name }}", version="0.1.0")

# Allowed origins come from Settings.cors_origins (default ["*"]).
# NOTE(review): allow_credentials=True together with a wildcard origin lets any
# site issue credentialed requests; narrow cors_origins before exposing this
# server publicly — confirm the intended deployment.
app.add_middleware(
    CORSMiddleware,
    allow_origins=settings.cors_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# In-memory store mapping conversation id -> AsyncConversation. It lives only
# in this process; entries are removed only by DELETE /v1/conversations/{id}.
_conversations: dict[str, AsyncConversation] = {}
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def _get_or_create_conversation(conv_id: str | None) -> tuple[str, AsyncConversation]:
    """Return ``(id, conversation)``, creating and storing a new one if needed.

    A known *conv_id* returns the stored conversation. An unknown non-empty
    *conv_id* is honoured: a fresh conversation is registered under that id.
    A missing/empty *conv_id* gets a random 12-hex-character id.
    """
    existing = _conversations.get(conv_id) if conv_id else None
    if existing is not None:
        return conv_id, existing

    key = conv_id if conv_id else uuid.uuid4().hex[:12]
    conversation = AsyncConversation(
        model_name=settings.model_name,
        system_prompt=settings.system_prompt,
    )
    _conversations[key] = conversation
    return key, conversation
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
@app.post("/v1/chat", response_model=ChatResponse)
async def chat(request: ChatRequest):
    """Forward one user message to a conversation and return the reply.

    The conversation is looked up (or created) from ``request.conversation_id``;
    ``request.options`` is passed through unchanged to ``AsyncConversation.ask``.

    NOTE(review): ``request.stream`` is not read here, so ``stream=true`` still
    produces a single, non-streamed JSON response.
    """
    conversation_id, conversation = _get_or_create_conversation(request.conversation_id)
    reply = await conversation.ask(request.message, request.options)
    response = ChatResponse(
        message=reply,
        conversation_id=conversation_id,
        usage=conversation.usage,
    )
    return response
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
@app.post("/v1/extract", response_model=ExtractResponse)
async def extract(request: ExtractRequest):
    """Extract a JSON object from ``request.text`` according to its schema.

    The schema (sent as ``"schema"``) is handed to ``ask_for_json`` as
    ``json_schema``; only the ``"json_object"`` entry of its result is returned.
    """
    conversation_id, conversation = _get_or_create_conversation(request.conversation_id)
    outcome = await conversation.ask_for_json(
        content=request.text,
        json_schema=request.schema_def,
    )
    extracted = outcome["json_object"]
    return ExtractResponse(
        json_object=extracted,
        conversation_id=conversation_id,
        usage=conversation.usage,
    )
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
@app.get("/v1/conversations/{conversation_id}", response_model=ConversationHistory)
async def get_conversation(conversation_id: str):
    """Return the stored messages and usage of one conversation (404 if unknown)."""
    conversation = _conversations.get(conversation_id)
    if conversation is None:
        raise HTTPException(status_code=404, detail="Conversation not found")
    history = ConversationHistory(
        conversation_id=conversation_id,
        messages=conversation.messages,
        usage=conversation.usage,
    )
    return history
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
@app.delete("/v1/conversations/{conversation_id}")
async def delete_conversation(conversation_id: str):
    """Drop a conversation from the in-memory store (404 if unknown)."""
    removed = _conversations.pop(conversation_id, None)
    if removed is None:
        raise HTTPException(status_code=404, detail="Conversation not found")
    return {"status": "deleted", "conversation_id": conversation_id}
|
|
@@ -0,0 +1,40 @@
|
|
|
1
|
+
"""Pydantic request/response models for {{ project_name }}."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
from pydantic import BaseModel, Field
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class ChatRequest(BaseModel):
    """Request body for the chat endpoint."""

    message: str  # user message sent to the conversation
    # Id of a conversation to continue; when omitted a new conversation is created.
    conversation_id: str | None = None
    # NOTE(review): the generated main.py chat endpoint does not read this flag.
    stream: bool = False
    # Extra options forwarded as-is to ``AsyncConversation.ask``.
    options: dict[str, Any] | None = None
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class ChatResponse(BaseModel):
    """Response body of the chat endpoint."""

    message: str  # assistant reply text
    conversation_id: str  # id to send back to continue this conversation
    usage: dict[str, Any]  # the conversation's ``usage`` mapping, read after the reply
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class ExtractRequest(BaseModel):
    """Request body for the extraction endpoint."""

    text: str  # input text, passed to ``ask_for_json`` as ``content``
    # JSON Schema for the extracted object. Sent on the wire as ``"schema"`` but
    # stored as ``schema_def`` — presumably to avoid clashing with BaseModel's
    # own ``schema`` attribute.
    schema_def: dict[str, Any] = Field(..., alias="schema")
    conversation_id: str | None = None  # optional conversation to reuse

    # Accept both the alias (``schema``) and the field name (``schema_def``).
    model_config = {"populate_by_name": True}
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class ExtractResponse(BaseModel):
    """Response body of the extraction endpoint."""

    json_object: dict[str, Any]  # the ``"json_object"`` entry returned by ``ask_for_json``
    conversation_id: str  # id of the conversation used for the extraction
    usage: dict[str, Any]  # the conversation's ``usage`` mapping after extraction
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class ConversationHistory(BaseModel):
    """Response body of the conversation-lookup endpoint."""

    conversation_id: str  # id that was looked up
    messages: list[dict[str, Any]]  # the conversation's ``messages`` list
    usage: dict[str, Any]  # the conversation's ``usage`` mapping
|
prompture/server.py
ADDED
|
@@ -0,0 +1,183 @@
|
|
|
1
|
+
"""Built-in API server wrapping AsyncConversation.
|
|
2
|
+
|
|
3
|
+
Provides a FastAPI application with chat, extraction, and model
|
|
4
|
+
listing endpoints. ``fastapi``, ``uvicorn``, and ``sse-starlette``
|
|
5
|
+
are lazy-imported so the module is importable without them installed.
|
|
6
|
+
|
|
7
|
+
Usage::
|
|
8
|
+
|
|
9
|
+
from prompture.server import create_app
|
|
10
|
+
app = create_app(model_name="openai/gpt-4o-mini")
|
|
11
|
+
"""
|
|
12
|
+
|
|
13
|
+
import json
|
|
14
|
+
import logging
|
|
15
|
+
import uuid
|
|
16
|
+
from typing import Any, Optional
|
|
17
|
+
|
|
18
|
+
logger = logging.getLogger("prompture.server")


def create_app(
    model_name: str = "openai/gpt-4o-mini",
    system_prompt: Optional[str] = None,
    tools: Any = None,
    cors_origins: Optional[list[str]] = None,
) -> Any:
    """Create and return a FastAPI application.

    The app exposes ``POST /v1/chat`` (optionally streamed as SSE),
    ``POST /v1/extract``, ``GET``/``DELETE /v1/conversations/{id}`` and
    ``GET /v1/models``. Conversations live in an in-memory dict private to
    this app instance; they are only removed via the DELETE endpoint.

    Parameters:
        model_name: Default model string (``provider/model``).
        system_prompt: Optional system prompt for new conversations.
        tools: Optional :class:`~prompture.tools_schema.ToolRegistry`.
        cors_origins: CORS allowed origins. ``["*"]`` to allow all. When
            ``None`` or empty, no CORS middleware is installed.

    Returns:
        A ``fastapi.FastAPI`` instance.

    Raises:
        ImportError: If ``fastapi``/``pydantic`` (the ``serve`` extra) are missing.
    """
    try:
        from fastapi import FastAPI, HTTPException
        from fastapi.middleware.cors import CORSMiddleware
        from pydantic import BaseModel, Field
    except ImportError as exc:
        raise ImportError(
            "The 'serve' extra is required: pip install prompture[serve]"
        ) from exc

    from .async_conversation import AsyncConversation
    from .tools_schema import ToolRegistry

    # ---- Pydantic request/response models ----

    class ChatRequest(BaseModel):
        message: str
        conversation_id: Optional[str] = None
        stream: bool = False
        options: Optional[dict[str, Any]] = None

    class ChatResponse(BaseModel):
        message: str
        conversation_id: str
        usage: dict[str, Any]

    class ExtractRequest(BaseModel):
        text: str
        # Wire name is "schema"; stored as schema_def.
        schema_def: dict[str, Any] = Field(..., alias="schema")
        conversation_id: Optional[str] = None

        model_config = {"populate_by_name": True}

    class ExtractResponse(BaseModel):
        json_object: dict[str, Any]
        conversation_id: str
        usage: dict[str, Any]

    class ModelInfo(BaseModel):
        models: list[str]

    class ConversationHistory(BaseModel):
        conversation_id: str
        messages: list[dict[str, Any]]
        usage: dict[str, Any]

    # ---- App ----

    app = FastAPI(title="Prompture API", version="0.1.0")

    if cors_origins:
        app.add_middleware(
            CORSMiddleware,
            allow_origins=cors_origins,
            allow_credentials=True,
            allow_methods=["*"],
            allow_headers=["*"],
        )

    # In-memory conversation store (per app instance, never auto-evicted)
    _conversations: dict[str, AsyncConversation] = {}

    tool_registry: Optional[ToolRegistry] = tools

    def _get_or_create_conversation(conv_id: Optional[str]) -> tuple[str, AsyncConversation]:
        """Return the stored conversation for *conv_id*, or create and store one."""
        if conv_id and conv_id in _conversations:
            return conv_id, _conversations[conv_id]
        new_id = conv_id or uuid.uuid4().hex[:12]
        conv = AsyncConversation(
            model_name=model_name,
            system_prompt=system_prompt,
            tools=tool_registry,
        )
        _conversations[new_id] = conv
        return new_id, conv

    # ---- Endpoints ----

    @app.post("/v1/chat", response_model=ChatResponse)
    async def chat(chat_req: ChatRequest):
        conv_id, conv = _get_or_create_conversation(chat_req.conversation_id)

        if chat_req.stream:
            # SSE streaming
            try:
                from sse_starlette.sse import EventSourceResponse
            except ImportError:
                raise HTTPException(
                    status_code=501,
                    detail="Streaming requires sse-starlette: pip install prompture[serve]",
                ) from None

            async def event_generator():
                async for chunk in conv.ask_stream(chat_req.message, chat_req.options):
                    yield {"data": json.dumps({"text": chunk})}
                # Final event carries the conversation id and the usage read
                # after the stream is exhausted.
                yield {"data": json.dumps({"text": "", "done": True, "conversation_id": conv_id, "usage": conv.usage})}

            return EventSourceResponse(event_generator())

        text = await conv.ask(chat_req.message, chat_req.options)
        return ChatResponse(message=text, conversation_id=conv_id, usage=conv.usage)

    @app.post("/v1/extract", response_model=ExtractResponse)
    async def extract(extract_req: ExtractRequest):
        conv_id, conv = _get_or_create_conversation(extract_req.conversation_id)
        result = await conv.ask_for_json(
            content=extract_req.text,
            json_schema=extract_req.schema_def,
        )
        return ExtractResponse(
            json_object=result["json_object"],
            conversation_id=conv_id,
            usage=conv.usage,
        )

    @app.get("/v1/conversations/{conversation_id}", response_model=ConversationHistory)
    async def get_conversation(conversation_id: str):
        if conversation_id not in _conversations:
            raise HTTPException(status_code=404, detail="Conversation not found")
        conv = _conversations[conversation_id]
        return ConversationHistory(
            conversation_id=conversation_id,
            messages=conv.messages,
            usage=conv.usage,
        )

    @app.delete("/v1/conversations/{conversation_id}")
    async def delete_conversation(conversation_id: str):
        if conversation_id not in _conversations:
            raise HTTPException(status_code=404, detail="Conversation not found")
        del _conversations[conversation_id]
        return {"status": "deleted", "conversation_id": conversation_id}

    @app.get("/v1/models", response_model=ModelInfo)
    async def list_models():
        try:
            # Import inside the try so that any discovery failure (including a
            # broken import) degrades to the configured model instead of a 500.
            from .discovery import get_available_models

            models = get_available_models()
            model_names = [m["id"] if isinstance(m, dict) else str(m) for m in models]
        except Exception:
            # Best-effort endpoint: keep the fallback, but do not hide the error.
            logger.warning(
                "Model discovery failed; falling back to %r", model_name, exc_info=True
            )
            model_names = [model_name]
        return ModelInfo(models=model_names)

    return app
|
|
@@ -0,0 +1,254 @@
|
|
|
1
|
+
"""Function calling / tool use support for Prompture.
|
|
2
|
+
|
|
3
|
+
Provides :class:`ToolDefinition` for describing callable tools,
|
|
4
|
+
:class:`ToolRegistry` for managing a collection of tools, and
|
|
5
|
+
:func:`tool_from_function` to auto-generate tool schemas from type hints.
|
|
6
|
+
|
|
7
|
+
Example::
|
|
8
|
+
|
|
9
|
+
from prompture import ToolRegistry
|
|
10
|
+
|
|
11
|
+
registry = ToolRegistry()
|
|
12
|
+
|
|
13
|
+
@registry.tool
|
|
14
|
+
def get_weather(city: str, units: str = "celsius") -> str:
|
|
15
|
+
\"\"\"Get the current weather for a city.\"\"\"
|
|
16
|
+
return f"Weather in {city}: 22 {units}"
|
|
17
|
+
|
|
18
|
+
# Or register explicitly
|
|
19
|
+
registry.register(get_weather)
|
|
20
|
+
"""
|
|
21
|
+
|
|
22
|
+
from __future__ import annotations
|
|
23
|
+
|
|
24
|
+
import inspect
import logging
import types
from dataclasses import dataclass, field
from typing import Any, Callable, Union, get_args, get_origin, get_type_hints
|
|
28
|
+
|
|
29
|
+
logger = logging.getLogger("prompture.tools_schema")
|
|
30
|
+
|
|
31
|
+
# Mapping from Python types to JSON Schema types
|
|
32
|
+
_TYPE_MAP: dict[type, str] = {
|
|
33
|
+
str: "string",
|
|
34
|
+
int: "integer",
|
|
35
|
+
float: "number",
|
|
36
|
+
bool: "boolean",
|
|
37
|
+
list: "array",
|
|
38
|
+
dict: "object",
|
|
39
|
+
}
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def _python_type_to_json_schema(annotation: Any) -> dict[str, Any]:
|
|
43
|
+
"""Convert a Python type annotation to a JSON Schema snippet."""
|
|
44
|
+
if annotation is inspect.Parameter.empty or annotation is None:
|
|
45
|
+
return {"type": "string"}
|
|
46
|
+
|
|
47
|
+
# Handle Optional[X] (Union[X, None])
|
|
48
|
+
origin = getattr(annotation, "__origin__", None)
|
|
49
|
+
args = getattr(annotation, "__args__", ())
|
|
50
|
+
|
|
51
|
+
if origin is type(None):
|
|
52
|
+
return {"type": "string"}
|
|
53
|
+
|
|
54
|
+
# Union types (Optional)
|
|
55
|
+
if origin is not None and hasattr(origin, "__name__") and origin.__name__ == "Union":
|
|
56
|
+
non_none = [a for a in args if a is not type(None)]
|
|
57
|
+
if len(non_none) == 1:
|
|
58
|
+
return _python_type_to_json_schema(non_none[0])
|
|
59
|
+
|
|
60
|
+
# list[X]
|
|
61
|
+
if origin is list and args:
|
|
62
|
+
return {"type": "array", "items": _python_type_to_json_schema(args[0])}
|
|
63
|
+
|
|
64
|
+
# dict[str, X]
|
|
65
|
+
if origin is dict:
|
|
66
|
+
return {"type": "object"}
|
|
67
|
+
|
|
68
|
+
# Simple types
|
|
69
|
+
json_type = _TYPE_MAP.get(annotation, "string")
|
|
70
|
+
return {"type": json_type}
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
@dataclass
class ToolDefinition:
    """A single tool that the LLM may ask to have executed.

    Attributes:
        name: Identifier the model uses to request this tool.
        description: Short explanation of the tool shown to the model.
        parameters: JSON Schema describing the tool's arguments.
        function: Python callable that implements the tool.
    """

    name: str
    description: str
    parameters: dict[str, Any]
    function: Callable[..., Any]

    # ------------------------------------------------------------------
    # Provider-specific serialisation
    # ------------------------------------------------------------------

    def to_openai_format(self) -> dict[str, Any]:
        """Return this tool as one element of OpenAI's ``tools`` list."""
        spec = {
            "name": self.name,
            "description": self.description,
            "parameters": self.parameters,
        }
        return {"type": "function", "function": spec}

    def to_anthropic_format(self) -> dict[str, Any]:
        """Return this tool as one element of Anthropic's ``tools`` list."""
        return dict(
            name=self.name,
            description=self.description,
            input_schema=self.parameters,
        )
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
def tool_from_function(fn: Callable[..., Any], *, name: str | None = None, description: str | None = None) -> ToolDefinition:
    """Build a :class:`ToolDefinition` by inspecting *fn*'s signature and docstring.

    Each named parameter becomes a property of an ``object`` JSON Schema.
    Parameters without a default are listed under ``"required"``; that key is
    omitted when nothing is required. ``self`` and the variadic ``*args`` /
    ``**kwargs`` parameters are skipped.

    Parameters:
        fn: The callable to wrap.
        name: Override the tool name (defaults to ``fn.__name__``).
        description: Override the description. It defaults to the first line
            of the docstring, or ``"Call <name>"`` when there is none.

    Returns:
        A :class:`ToolDefinition` whose ``function`` is *fn* itself.
    """
    tool_name = name or fn.__name__
    tool_desc = description or (inspect.getdoc(fn) or "").split("\n")[0] or f"Call {tool_name}"

    sig = inspect.signature(fn)
    try:
        hints = get_type_hints(fn)
    except Exception:
        # e.g. an unresolvable forward reference: fall back to the raw
        # ``param.annotation`` values below.
        hints = {}

    properties: dict[str, Any] = {}
    required: list[str] = []

    for param_name, param in sig.parameters.items():
        if param_name == "self":
            continue
        # Variadic parameters have no single name a model could supply, and
        # their default is always ``Parameter.empty``. They would otherwise be
        # advertised as required properties called "args"/"kwargs".
        if param.kind in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD):
            continue

        annotation = hints.get(param_name, param.annotation)
        prop = _python_type_to_json_schema(annotation)

        # Use parameter name as description fallback
        prop.setdefault("description", f"Parameter: {param_name}")

        properties[param_name] = prop

        if param.default is inspect.Parameter.empty:
            required.append(param_name)

    parameters: dict[str, Any] = {
        "type": "object",
        "properties": properties,
    }
    if required:
        parameters["required"] = required

    return ToolDefinition(
        name=tool_name,
        description=tool_desc,
        parameters=parameters,
        function=fn,
    )
|
|
160
|
+
|
|
161
|
+
|
|
162
|
+
@dataclass
class ToolRegistry:
    """Name-indexed collection of :class:`ToolDefinition` objects.

    Tools can be added with the :meth:`tool` decorator, with :meth:`register`,
    or as ready-made definitions through :meth:`add`::

        registry = ToolRegistry()

        @registry.tool
        def my_func(x: int) -> str:
            ...

        registry.register(another_func)

    Adding a tool under a name that is already present replaces the old one.
    An empty registry is falsy.
    """

    _tools: dict[str, ToolDefinition] = field(default_factory=dict)

    # -- registration ----------------------------------------------------

    def register(
        self,
        fn: Callable[..., Any],
        *,
        name: str | None = None,
        description: str | None = None,
    ) -> ToolDefinition:
        """Wrap *fn* via ``tool_from_function``, store it, and return the definition."""
        definition = tool_from_function(fn, name=name, description=description)
        self._tools[definition.name] = definition
        return definition

    def tool(self, fn: Callable[..., Any]) -> Callable[..., Any]:
        """Decorator form of :meth:`register`; hands back *fn* untouched."""
        self.register(fn)
        return fn

    def add(self, tool_def: ToolDefinition) -> None:
        """Store an already-built :class:`ToolDefinition` under its own name."""
        self._tools[tool_def.name] = tool_def

    # -- lookup ----------------------------------------------------------

    def get(self, name: str) -> ToolDefinition | None:
        """Return the definition registered as *name*, or ``None``."""
        return self._tools.get(name)

    def __contains__(self, name: str) -> bool:
        return name in self._tools

    def __len__(self) -> int:
        return len(self._tools)

    def __bool__(self) -> bool:
        return len(self._tools) > 0

    @property
    def names(self) -> list[str]:
        """Registered tool names, in insertion order."""
        return [*self._tools]

    @property
    def definitions(self) -> list[ToolDefinition]:
        """Registered definitions, in insertion order."""
        return [*self._tools.values()]

    # -- serialisation ---------------------------------------------------

    def to_openai_format(self) -> list[dict[str, Any]]:
        """Every tool rendered with :meth:`ToolDefinition.to_openai_format`."""
        return [entry.to_openai_format() for entry in self._tools.values()]

    def to_anthropic_format(self) -> list[dict[str, Any]]:
        """Every tool rendered with :meth:`ToolDefinition.to_anthropic_format`."""
        return [entry.to_anthropic_format() for entry in self._tools.values()]

    # -- execution -------------------------------------------------------

    def execute(self, name: str, arguments: dict[str, Any]) -> Any:
        """Call the tool registered as *name* with ``**arguments``.

        Raises:
            KeyError: If no tool with *name* is registered.
        """
        if (definition := self._tools.get(name)) is None:
            raise KeyError(f"Tool not registered: {name!r}")
        return definition.function(**arguments)
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: prompture
|
|
3
|
-
Version: 0.0.
|
|
3
|
+
Version: 0.0.35.dev1
|
|
4
4
|
Summary: Ask LLMs to return structured JSON and run cross-model tests. API-first.
|
|
5
5
|
Author-email: Juan Denis <juan@vene.co>
|
|
6
6
|
License-Expression: MIT
|
|
@@ -37,6 +37,12 @@ Provides-Extra: airllm
|
|
|
37
37
|
Requires-Dist: airllm>=2.8.0; extra == "airllm"
|
|
38
38
|
Provides-Extra: redis
|
|
39
39
|
Requires-Dist: redis>=4.0; extra == "redis"
|
|
40
|
+
Provides-Extra: serve
|
|
41
|
+
Requires-Dist: fastapi>=0.100; extra == "serve"
|
|
42
|
+
Requires-Dist: uvicorn[standard]>=0.20; extra == "serve"
|
|
43
|
+
Requires-Dist: sse-starlette>=1.6; extra == "serve"
|
|
44
|
+
Provides-Extra: scaffold
|
|
45
|
+
Requires-Dist: jinja2>=3.0; extra == "scaffold"
|
|
40
46
|
Dynamic: license-file
|
|
41
47
|
|
|
42
48
|
# Prompture
|