selectools 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- selectools-0.2.0.dist-info/METADATA +730 -0
- selectools-0.2.0.dist-info/RECORD +18 -0
- selectools-0.2.0.dist-info/WHEEL +5 -0
- selectools-0.2.0.dist-info/entry_points.txt +2 -0
- selectools-0.2.0.dist-info/licenses/LICENSE +165 -0
- selectools-0.2.0.dist-info/top_level.txt +1 -0
- toolcalling/__init__.py +27 -0
- toolcalling/agent.py +188 -0
- toolcalling/cli.py +194 -0
- toolcalling/env.py +41 -0
- toolcalling/examples/bbox.py +272 -0
- toolcalling/parser.py +114 -0
- toolcalling/prompt.py +44 -0
- toolcalling/providers/base.py +55 -0
- toolcalling/providers/openai_provider.py +122 -0
- toolcalling/providers/stubs.py +245 -0
- toolcalling/tools.py +233 -0
- toolcalling/types.py +76 -0
|
@@ -0,0 +1,272 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Bounding-box detection tool example using OpenAI Vision.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
import base64
|
|
8
|
+
import json
|
|
9
|
+
import os
|
|
10
|
+
import re
|
|
11
|
+
from pathlib import Path
|
|
12
|
+
from typing import Dict, Optional
|
|
13
|
+
|
|
14
|
+
from PIL import Image, ImageDraw, ImageFont
|
|
15
|
+
|
|
16
|
+
from ..tools import Tool, ToolParameter
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
# Base directory for resolving relative image paths: four levels above this
# file (toolcalling/examples/bbox.py -> examples -> toolcalling -> its parent
# -> that directory's parent).
# NOTE(review): in a src-layout checkout this is presumably the repository
# root, but in an installed wheel it lands above site-packages, and the wheel
# RECORD ships no assets/ directory — confirm the intended location.
PROJECT_ROOT = Path(__file__).resolve().parents[3]
# Directory searched first when an image path is relative.
ASSETS_DIR = PROJECT_ROOT / "assets"
# Environment variable that, when set, names a JSON file whose contents are
# used instead of calling the OpenAI vision API (offline / deterministic runs).
BBOX_MOCK_ENV = "TOOLCALLING_BBOX_MOCK_JSON"
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def _resolve_image_path(image_path: str) -> Path:
    """Turn a user-supplied image path into an absolute path.

    Absolute inputs are simply resolved. A relative input is looked up under
    ASSETS_DIR first. If no such file exists there, the relative path is
    resolved against the current working directory instead.
    """
    path = Path(image_path)
    if path.is_absolute():
        return path.resolve()
    under_assets = ASSETS_DIR / path
    chosen = under_assets if under_assets.exists() else path
    return chosen.resolve()
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def _load_openai_client():
    """Create an OpenAI SDK client from the OPENAI_API_KEY environment variable.

    Raises:
        RuntimeError: if the ``openai`` package cannot be imported, or if
            OPENAI_API_KEY is unset or empty.
    """
    try:
        from openai import OpenAI
    except ImportError as exc:
        raise RuntimeError("openai package is required for bounding-box detection.") from exc

    key = os.getenv("OPENAI_API_KEY")
    if key:
        return OpenAI(api_key=key)
    raise RuntimeError("Set OPENAI_API_KEY to run bounding-box detection.")
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def _bbox_failure(message: str) -> str:
    """Serialize a failed detection in the tool's JSON response shape."""
    return json.dumps(
        {
            "success": False,
            "message": message,
            "coordinates": None,
            "output_path": None,
        }
    )


def detect_bounding_box_impl(target_object: str, image_path: str) -> str:
    """
    Detect a target object in an image and draw a bounding box.

    Detection data comes from the JSON file named by the BBOX_MOCK_ENV
    environment variable when it is set, otherwise from the OpenAI vision call.

    Args:
        target_object: Free-text name of the object to locate.
        image_path: Absolute path, or a path relative to ASSETS_DIR / the CWD.

    Returns:
        A JSON string with ``success``, ``message``, ``coordinates`` and
        ``output_path``. On success it also carries ``confidence`` and
        ``description``. A missing image, bad detection data, invalid
        coordinates and drawing/saving errors are reported in this payload.

    Raises:
        RuntimeError: propagated from the OpenAI client setup when the
            ``openai`` package or OPENAI_API_KEY is missing.
    """
    resolved_path = _resolve_image_path(image_path)
    if not resolved_path.exists():
        return _bbox_failure(f"Image file not found: {resolved_path}")

    mock_json = os.getenv(BBOX_MOCK_ENV)
    if mock_json:
        detection_data = _load_mock_detection(Path(mock_json))
    else:
        detection_data = _call_openai_vision(target_object=target_object, image_path=resolved_path)

    if not detection_data:
        return _bbox_failure("No detection data returned.")
    # A mock file (or model reply) can be valid JSON that is not an object;
    # it must not reach the .get()/indexing below.
    if not isinstance(detection_data, dict):
        return _bbox_failure(f"Unexpected detection data (expected a JSON object): {detection_data!r}")

    if not detection_data.get("found"):
        return _bbox_failure(f"Could not find {target_object}: {detection_data.get('description', '')}")

    # The model may omit a coordinate or return a non-numeric value; report
    # that in the payload instead of letting KeyError/ValueError escape.
    try:
        x_min, y_min, x_max, y_max = (
            float(detection_data[key]) for key in ("x_min", "y_min", "x_max", "y_max")
        )
    except (KeyError, TypeError, ValueError) as exc:
        return _bbox_failure(f"Invalid coordinates returned ({exc!r}): {detection_data}")

    if not _coordinates_valid(x_min, y_min, x_max, y_max):
        return _bbox_failure(f"Invalid coordinates returned (must be between 0 and 1): {detection_data}")

    try:
        output_path, pixel_coordinates = _draw_box(resolved_path, target_object, x_min, y_min, x_max, y_max)
    except (OSError, ValueError) as exc:
        # OSError: unreadable/unsupported image or unwritable output location.
        # ValueError: e.g. recent Pillow versions rejecting an inverted box.
        return _bbox_failure(f"Failed to draw bounding box: {exc}")

    return json.dumps(
        {
            "success": True,
            "message": f"Detected {target_object}; output saved to {output_path}",
            "coordinates": {
                "normalized": {
                    "x_min": x_min,
                    "y_min": y_min,
                    "x_max": x_max,
                    "y_max": y_max,
                },
                "pixels": pixel_coordinates,
            },
            "output_path": str(output_path),
            "confidence": detection_data.get("confidence", "unknown"),
            "description": detection_data.get("description", ""),
        },
        indent=2,
    )
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
def _parse_detection_response(response_text: str) -> Dict[str, str]:
    """Extract the detection JSON object from a vision-model reply.

    The reply may wrap the object in prose or a Markdown code fence. Each
    ``{`` is tried in turn as the start of a JSON object, and the first one
    that decodes is returned. Braces appearing in text after the object
    therefore do not break parsing. If no brace-delimited object decodes,
    the whole reply is parsed as JSON.

    Raises:
        ValueError: if the reply is empty/None or contains no decodable JSON
            (``json.JSONDecodeError`` is a ``ValueError`` subclass).
    """
    if not response_text:
        raise ValueError("Vision model returned an empty response.")

    decoder = json.JSONDecoder()
    start = response_text.find("{")
    while start != -1:
        try:
            parsed, _end = decoder.raw_decode(response_text, start)
        except json.JSONDecodeError:
            start = response_text.find("{", start + 1)
            continue
        return parsed
    return json.loads(response_text)
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
def _load_mock_detection(mock_path: Path) -> Dict[str, str]:
    """Load a deterministic mock response for offline testing.

    A missing, unreadable, or malformed file is reported as
    ``{"found": False, "description": ...}``, the same shape as a failed
    detection, rather than raised.
    """
    if not mock_path.exists():
        return {"found": False, "description": f"Mock file not found at {mock_path}"}
    try:
        # Explicit UTF-8 so the result does not depend on the platform locale.
        return json.loads(mock_path.read_text(encoding="utf-8"))
    except (OSError, ValueError) as exc:
        # OSError: unreadable path; ValueError covers JSONDecodeError and
        # UnicodeDecodeError.
        return {"found": False, "description": f"Failed to read mock file: {exc}"}
|
|
141
|
+
|
|
142
|
+
|
|
143
|
+
def _call_openai_vision(target_object: str, image_path: Path) -> Optional[Dict[str, str]]:
    """Ask the OpenAI vision model for a normalized bounding box around ``target_object``.

    Returns the parsed detection dict. API errors, replies with no choices,
    empty replies and non-JSON replies are all reported as
    ``{"found": False, "description": ...}`` instead of raised.

    Raises:
        RuntimeError: from ``_load_openai_client`` when the ``openai`` package
            or OPENAI_API_KEY is missing.
    """
    import mimetypes  # local import: only this request needs it

    client = _load_openai_client()
    image_base64 = base64.b64encode(image_path.read_bytes()).decode("utf-8")
    # Declare the real image type in the data URL (e.g. PNG) instead of always
    # claiming JPEG; fall back to JPEG when the extension is not an image type.
    guessed_type, _encoding = mimetypes.guess_type(image_path.name)
    mime_type = guessed_type if guessed_type and guessed_type.startswith("image/") else "image/jpeg"

    prompt = f"""Analyze this image and locate the {target_object}.

Return ONLY a JSON object with the bounding box coordinates in normalized format (0.0 to 1.0):

{{
"found": true/false,
"x_min": 0.0-1.0,
"y_min": 0.0-1.0,
"x_max": 0.0-1.0,
"y_max": 0.0-1.0,
"confidence": "high/medium/low",
"description": "brief description of what you found"
}}

If the {target_object} is not found, set "found" to false and explain why in the description.
Coordinates should be normalized (0.0 = left/top edge, 1.0 = right/bottom edge).
Return ONLY the JSON object, no other text."""

    try:
        response = client.chat.completions.create(
            model="gpt-4o",
            messages=[
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": prompt},
                        {"type": "image_url", "image_url": {"url": f"data:{mime_type};base64,{image_base64}"}},
                    ],
                }
            ],
            max_tokens=500,
        )
    except Exception as exc:  # noqa: BLE001 - any SDK/transport failure becomes a "not found" result
        return {"found": False, "description": f"Vision API error: {exc}"}

    try:
        response_text = response.choices[0].message.content
        return _parse_detection_response(response_text)
    except (IndexError, TypeError, ValueError) as exc:
        # No choices, an empty (None) reply, or non-JSON text must not crash
        # the tool; report it the same way as an API error.
        return {"found": False, "description": f"Could not parse vision response: {exc}"}
|
|
184
|
+
|
|
185
|
+
|
|
186
|
+
def _coordinates_valid(x_min: float, y_min: float, x_max: float, y_max: float) -> bool:
    """Return True when the normalized box lies inside the image and is not inverted.

    Every coordinate must be within [0, 1], and each min edge must not exceed
    its max edge. An inverted box would otherwise reach the drawing step,
    where recent Pillow versions raise ValueError. NaN fails every comparison
    and is therefore rejected.
    """
    return 0 <= x_min <= x_max <= 1 and 0 <= y_min <= y_max <= 1
|
|
188
|
+
|
|
189
|
+
|
|
190
|
+
def _draw_box(image_path: Path, target_object: str, x_min: float, y_min: float, x_max: float, y_max: float):
    """Draw a labelled red box on a copy of the image and save it next to the source.

    Args:
        image_path: Image to annotate; the source file is not modified.
        target_object: Label text (drawn upper-cased).
        x_min, y_min, x_max, y_max: Normalized box edges in [0, 1].

    Returns:
        ``(output_path, pixel_coordinates)``. ``output_path`` is
        ``<stem>_with_bbox.png`` in the source's directory, and
        ``pixel_coordinates`` maps the four edge names to integer pixels.
    """
    # Copy the pixels and release the file handle right away (Image.open is
    # lazy and otherwise keeps the file open).
    with Image.open(image_path) as source:
        if source.mode in ("RGB", "RGBA"):
            image = source.copy()
        else:
            # In grayscale/bilevel modes the red annotations would come out
            # gray/black, and CMYK cannot be saved as PNG. Convert to a colour
            # mode, keeping an alpha channel when the source has transparency.
            has_alpha = "A" in source.getbands() or "transparency" in source.info
            image = source.convert("RGBA" if has_alpha else "RGB")
    width, height = image.size

    pixel_coordinates = {
        "x_min": int(x_min * width),
        "y_min": int(y_min * height),
        "x_max": int(x_max * width),
        "y_max": int(y_max * height),
    }

    draw = ImageDraw.Draw(image)
    # Stack 1-px rectangles growing outward: the outline is at least 3 px, or
    # 0.5% of the shorter image side when that is larger.
    thickness = max(3, int(min(width, height) * 0.005))
    for offset in range(thickness):
        draw.rectangle(
            [
                pixel_coordinates["x_min"] - offset,
                pixel_coordinates["y_min"] - offset,
                pixel_coordinates["x_max"] + offset,
                pixel_coordinates["y_max"] + offset,
            ],
            outline="red",
            width=1,
        )

    label = target_object.upper()
    try:
        font = ImageFont.truetype("arial.ttf", size=max(20, int(height * 0.03)))
    except (OSError, ImportError):
        # Arial may not be installed (or FreeType support may be missing);
        # fall back to Pillow's built-in default font.
        font = ImageFont.load_default()

    bbox = draw.textbbox((0, 0), label, font=font)
    text_width = bbox[2] - bbox[0]
    text_height = bbox[3] - bbox[1]
    # Place the label above the box; if that would go off the top edge, put
    # it just inside the box instead.
    label_x = pixel_coordinates["x_min"]
    label_y = pixel_coordinates["y_min"] - text_height - 10
    if label_y < 0:
        label_y = pixel_coordinates["y_min"] + 5

    padding = 5
    draw.rectangle(
        [
            label_x - padding,
            label_y - padding,
            label_x + text_width + padding,
            label_y + text_height + padding,
        ],
        fill="red",
    )
    draw.text((label_x, label_y), label, fill="white", font=font)

    output_path = image_path.parent / f"{image_path.stem}_with_bbox.png"
    image.save(output_path)
    return output_path, pixel_coordinates
|
|
244
|
+
|
|
245
|
+
|
|
246
|
+
def create_bounding_box_tool() -> Tool:
    """Build the ``detect_bounding_box`` Tool, backed by ``detect_bounding_box_impl``.

    The tool takes two required string parameters, ``target_object`` and
    ``image_path``.
    """
    target_param = ToolParameter(
        name="target_object",
        param_type=str,
        description="The object to locate in the image.",
        required=True,
    )
    image_param = ToolParameter(
        name="image_path",
        param_type=str,
        description="Path to the image file (absolute or relative to assets/).",
        required=True,
    )
    summary = (
        "Detects and draws a bounding box around a specific object in an image. "
        "Returns normalized and pixel coordinates plus the output image path."
    )
    return Tool(
        name="detect_bounding_box",
        description=summary,
        parameters=[target_param, image_param],
        function=detect_bounding_box_impl,
    )
|
|
270
|
+
|
|
271
|
+
|
|
272
|
+
# Public API of this example module: the Tool factory and its implementation.
__all__ = ["create_bounding_box_tool", "detect_bounding_box_impl"]
|
toolcalling/parser.py
ADDED
|
@@ -0,0 +1,114 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Parser for TOOL_CALL directives emitted by language models.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
import json
|
|
8
|
+
import re
|
|
9
|
+
from dataclasses import dataclass
|
|
10
|
+
from typing import Any, Dict, List, Optional
|
|
11
|
+
|
|
12
|
+
from .types import ToolCall
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
@dataclass
class ParseResult:
    """Result of attempting to parse a tool call.

    ``tool_call`` is ``None`` when no usable TOOL_CALL payload was found.
    ``raw_text`` is always the full text that was handed to the parser.
    """

    # Parsed call, or None when nothing usable was found.
    tool_call: Optional[ToolCall]
    # Unmodified input text passed to ToolCallParser.parse.
    raw_text: str
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class ToolCallParser:
    """Robustly extract TOOL_CALL directives from model output.

    Candidate JSON objects are gathered in three steps:
    1. From the text following each marker occurrence.
    2. From Markdown code fences that mention the marker or a tool key.
    3. If neither step yields anything, from the whole text.

    The first candidate that decodes to an object naming a tool wins.
    """

    def __init__(self, marker: str = "TOOL_CALL", max_payload_chars: int = 8000):
        # Literal text (not a regex) that introduces a tool call.
        self.marker = marker
        # Candidates longer than this are skipped; a falsy value disables the limit.
        self.max_payload_chars = max_payload_chars

    def parse(self, text: str) -> ParseResult:
        """
        Attempt to parse a tool call from the provided text.

        Supports fenced code blocks, inline JSON, and newline-heavy outputs.
        The tool name is read from ``tool_name``, ``tool`` or ``name``, and
        arguments from ``parameters`` or ``params`` (defaulting to ``{}``).

        Returns:
            ParseResult whose ``tool_call`` is None when no candidate matched.
        """
        for candidate in self._extract_candidate_blocks(text):
            if self.max_payload_chars and len(candidate) > self.max_payload_chars:
                continue
            tool_data = self._load_json(candidate)
            if not tool_data:
                continue
            tool_name = tool_data.get("tool_name") or tool_data.get("tool") or tool_data.get("name")
            parameters: Dict[str, Any] = tool_data.get("parameters") or tool_data.get("params") or {}
            if tool_name:
                return ParseResult(tool_call=ToolCall(tool_name=tool_name, parameters=parameters), raw_text=text)
        return ParseResult(tool_call=None, raw_text=text)

    def _extract_candidate_blocks(self, text: str) -> List[str]:
        """Pull out all JSON substrings that might contain the TOOL_CALL payload."""
        blocks: List[str] = []

        # The marker is literal text; escape it so characters such as "(" or
        # "." are not interpreted as regex syntax (or raise re.error).
        for match in re.finditer(re.escape(self.marker), text):
            blocks.extend(self._find_balanced_json(text[match.start():]))

        for block in re.findall(r"```.*?```", text, re.DOTALL):
            if self.marker in block or "tool_name" in block or "parameters" in block:
                blocks.extend(self._find_balanced_json(block.strip("` \n")))

        if not blocks:
            blocks.extend(self._find_balanced_json(text))

        # Deduplicate while preserving first-seen order.
        return list(dict.fromkeys(blocks))

    def _load_json(self, candidate: str) -> Optional[Dict[str, Any]]:
        """Decode ``candidate`` into a dict, trying progressively more lenient variants.

        Attempts, in order:
        1. The candidate exactly as extracted.
        2. The text after the first marker occurrence, with fence/colon
           residue stripped.
        3. Variant 2 with single quotes swapped for double quotes.

        ``strict=False`` allows literal control characters (e.g. newlines)
        inside JSON strings. Only JSON objects are returned, so callers can
        safely use ``.get``.
        """
        normalized = candidate
        if self.marker in normalized:
            normalized = normalized.split(self.marker, maxsplit=1)[-1]
        normalized = normalized.strip("` \n:")

        # Try the raw candidate first: a marker that merely appears inside a
        # string value must not cause valid JSON to be cut apart.
        attempts = [candidate, normalized, normalized.replace("'", '"')]
        for attempt in dict.fromkeys(attempts):
            try:
                data = json.loads(attempt, strict=False)
            except (json.JSONDecodeError, RecursionError):
                continue
            if isinstance(data, dict):
                return data
        return None

    def _find_balanced_json(self, text: str) -> List[str]:
        """Collect balanced ``{...}`` substrings from text, one per opening brace.

        Braces inside double-quoted strings (honouring backslash escapes) do
        not change the nesting depth, so a value such as ``"a } b"`` does not
        end a candidate early. Objects that never close are skipped.
        """
        candidates: List[str] = []
        for start in (m.start() for m in re.finditer(r"\{", text)):
            depth = 0
            in_string = False
            escaped = False
            for idx in range(start, len(text)):
                char = text[idx]
                if in_string:
                    if escaped:
                        escaped = False
                    elif char == "\\":
                        escaped = True
                    elif char == '"':
                        in_string = False
                elif char == '"':
                    in_string = True
                elif char == "{":
                    depth += 1
                elif char == "}":
                    depth -= 1
                    if depth == 0:
                        candidates.append(text[start : idx + 1])
                        break
        return candidates
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
# Public API of the parser module.
__all__ = ["ToolCallParser", "ParseResult"]
|
toolcalling/prompt.py
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Prompt templating for the TOOL_CALL contract.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
import json
|
|
8
|
+
from typing import List
|
|
9
|
+
|
|
10
|
+
from .tools import Tool
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
DEFAULT_SYSTEM_INSTRUCTIONS = """You are an assistant that can call tools when helpful.
|
|
14
|
+
|
|
15
|
+
Tool call contract:
|
|
16
|
+
- Emit TOOL_CALL with JSON: {"tool_name": "<name>", "parameters": {...}}
|
|
17
|
+
- Include every required parameter. Ask for missing details instead of guessing.
|
|
18
|
+
- Wait for tool results before giving a final answer.
|
|
19
|
+
- Do not invent tool outputs; only report what was returned.
|
|
20
|
+
- Keep tool payloads compact (<=8k chars) and emit one tool call at a time.
|
|
21
|
+
"""
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class PromptBuilder:
    """Assemble a system prompt from base instructions plus each tool's JSON schema."""

    def __init__(self, base_instructions: str = DEFAULT_SYSTEM_INSTRUCTIONS):
        # Stored as given; surrounding whitespace is stripped when building.
        self.base_instructions = base_instructions

    def build(self, tools: List[Tool]) -> str:
        """Render the prompt: instructions, pretty-printed tool schemas, then the usage nudge.

        Sections are separated by blank lines, and schemas appear in ``tools`` order.
        """
        schemas = "\n\n".join(json.dumps(tool.schema(), indent=2) for tool in tools)
        closing = (
            "If a relevant tool exists, respond with a TOOL_CALL first. "
            "When no tool is useful, answer directly."
        )
        sections = (
            self.base_instructions.strip(),
            f"Available tools (JSON schema):\n\n{schemas}",
            closing,
        )
        return "\n\n".join(sections)
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
# Public API of the prompt module.
__all__ = ["PromptBuilder", "DEFAULT_SYSTEM_INSTRUCTIONS"]
|
|
@@ -0,0 +1,55 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Provider abstraction for model-agnostic tool calling.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
from typing import Protocol, runtime_checkable
|
|
8
|
+
|
|
9
|
+
from ..types import Message
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class ProviderError(RuntimeError):
    """Signal that a provider adapter could not complete a request.

    Subclasses RuntimeError, so generic runtime-error handlers still catch it.
    """
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
@runtime_checkable
class Provider(Protocol):
    """Interface every provider adapter must satisfy.

    This is a structural protocol: any object with these attributes and
    methods qualifies. Thanks to ``@runtime_checkable``, ``isinstance`` checks
    work, but they verify only that the members exist, not their signatures.
    """

    # Short identifier for the backend (OpenAIProvider uses "openai").
    name: str
    # Whether this adapter supports stream().
    supports_streaming: bool

    def complete(
        self,
        *,
        model: str,
        system_prompt: str,
        messages: list[Message],
        temperature: float = 0.0,
        max_tokens: int = 1000,
        timeout: float | None = None,
    ) -> str:
        """Return assistant text given conversation state.

        Args:
            model: Backend model identifier.
            system_prompt: System instructions sent ahead of ``messages``.
            messages: Conversation history (presumably oldest first — confirm
                with the agent loop).
            temperature: Sampling temperature.
            max_tokens: Upper bound on generated tokens.
            timeout: Request timeout handed to the backend; None leaves the
                backend's own default in place.
        """
        ...

    def stream(
        self,
        *,
        model: str,
        system_prompt: str,
        messages: list[Message],
        temperature: float = 0.0,
        max_tokens: int = 1000,
        timeout: float | None = None,
    ):
        """
        Yield assistant text chunks for providers that support streaming.

        Takes the same arguments as ``complete``.

        Implementations should raise ProviderError if streaming is not supported
        or fails.
        """
        ...
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
# Public API of the provider base module.
__all__ = ["Provider", "ProviderError"]
|
|
@@ -0,0 +1,122 @@
|
|
|
1
|
+
"""
|
|
2
|
+
OpenAI provider adapter for the tool-calling library.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
import os
|
|
8
|
+
from typing import List
|
|
9
|
+
|
|
10
|
+
from ..env import load_default_env
|
|
11
|
+
from ..types import Message, Role
|
|
12
|
+
from .base import Provider, ProviderError
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class OpenAIProvider(Provider):
    """Adapter that speaks to OpenAI's Chat Completions API.

    SDK failures are re-raised as ProviderError, whether they happen when the
    request is made or while a stream is being consumed.
    """

    name = "openai"
    supports_streaming = True

    def __init__(self, api_key: str | None = None, default_model: str = "gpt-4o"):
        """Create the underlying OpenAI client.

        Args:
            api_key: Explicit key. When omitted, OPENAI_API_KEY is read from
                the environment after ``load_default_env()`` runs (presumably
                loading a .env file — confirm).
            default_model: Used when a call passes an empty ``model``.

        Raises:
            ProviderError: if no API key is available or ``openai`` is not installed.
        """
        load_default_env()
        self.api_key = api_key or os.getenv("OPENAI_API_KEY")
        if not self.api_key:
            raise ProviderError("OPENAI_API_KEY is not set. Set it in env or pass api_key.")

        try:
            from openai import OpenAI
        except ImportError as exc:
            raise ProviderError("openai package not installed. Install with `pip install openai`.") from exc

        self._client = OpenAI(api_key=self.api_key)
        self.default_model = default_model

    def complete(
        self,
        *,
        model: str,
        system_prompt: str,
        messages: List[Message],
        temperature: float = 0.0,
        max_tokens: int = 1000,
        timeout: float | None = None,
    ) -> str:
        """Run one non-streaming completion and return the assistant text ("" if it has none).

        Raises:
            ProviderError: if the request fails or the response has no choices.
        """
        formatted = self._format_messages(system_prompt=system_prompt, messages=messages)

        try:
            response = self._client.chat.completions.create(
                model=model or self.default_model,
                messages=formatted,
                temperature=temperature,
                max_tokens=max_tokens,
                timeout=timeout,
            )
        except Exception as exc:  # noqa: BLE001 - normalize every SDK error
            raise ProviderError(f"OpenAI completion failed: {exc}") from exc

        if not response.choices:
            raise ProviderError("OpenAI completion failed: response contained no choices")
        content = response.choices[0].message.content
        return content or ""

    def stream(
        self,
        *,
        model: str,
        system_prompt: str,
        messages: List[Message],
        temperature: float = 0.0,
        max_tokens: int = 1000,
        timeout: float | None = None,
    ):
        """Yield non-empty assistant text chunks as they arrive.

        Raises:
            ProviderError: if the request cannot be started, if the connection
                fails while chunks are being read, or if a chunk cannot be parsed.
        """
        formatted = self._format_messages(system_prompt=system_prompt, messages=messages)
        try:
            response = self._client.chat.completions.create(
                model=model or self.default_model,
                messages=formatted,
                temperature=temperature,
                max_tokens=max_tokens,
                stream=True,
                timeout=timeout,
            )
            chunks = iter(response)
        except Exception as exc:  # noqa: BLE001 - normalize every SDK error
            raise ProviderError(f"OpenAI streaming failed: {exc}") from exc

        while True:
            try:
                chunk = next(chunks)
            except StopIteration:
                return
            except Exception as exc:  # noqa: BLE001
                # Network/API errors can surface while the stream is iterated,
                # not only when the request is created.
                raise ProviderError(f"OpenAI streaming failed: {exc}") from exc
            try:
                text = self._chunk_text(chunk)
            except Exception as exc:  # noqa: BLE001
                raise ProviderError(f"OpenAI stream parsing failed: {exc}") from exc
            # Yield outside the try blocks so exceptions thrown into the
            # generator by the consumer are not re-wrapped as ProviderError.
            if text:
                yield text

    @staticmethod
    def _chunk_text(chunk) -> str:
        """Return the text carried by one streamed chunk, or "" when it carries none.

        Chunks with an empty ``choices`` list (e.g. the usage chunk sent when
        ``stream_options.include_usage`` is enabled) are skipped rather than
        treated as errors.
        """
        if not chunk.choices:
            return ""
        delta = chunk.choices[0].delta
        if not delta or not delta.content:
            return ""
        content = delta.content
        if isinstance(content, list):
            content = "".join([part.text for part in content if getattr(part, "text", None)])
        return content

    def _format_messages(self, system_prompt: str, messages: List[Message]):
        """Build the Chat Completions ``messages`` payload: the system prompt, then each message."""
        payload = [{"role": "system", "content": system_prompt}]
        for message in messages:
            role = message.role.value
            # Tool results are sent as assistant turns.
            # NOTE(review): presumably because this text-based TOOL_CALL
            # protocol has no tool_call_id, which OpenAI's "tool" role
            # requires — confirm.
            if role == Role.TOOL.value:
                role = Role.ASSISTANT.value
            payload.append(
                {
                    "role": role,
                    "content": self._format_content(message),
                }
            )
        return payload

    def _format_content(self, message: Message):
        """Return the message text, or a text + inline-image part list when an image is attached."""
        if message.image_base64:
            return [
                {"type": "text", "text": message.content},
                {
                    "type": "image_url",
                    # NOTE(review): always declared as JPEG regardless of the
                    # actual encoded image type.
                    "image_url": {"url": f"data:image/jpeg;base64,{message.image_base64}"},
                },
            ]
        return message.content
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
# Public API of the OpenAI provider module.
__all__ = ["OpenAIProvider"]
|