kader 0.1.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cli/README.md +169 -0
- cli/__init__.py +5 -0
- cli/__main__.py +6 -0
- cli/app.py +707 -0
- cli/app.tcss +664 -0
- cli/utils.py +68 -0
- cli/widgets/__init__.py +13 -0
- cli/widgets/confirmation.py +309 -0
- cli/widgets/conversation.py +55 -0
- cli/widgets/loading.py +59 -0
- kader/__init__.py +22 -0
- kader/agent/__init__.py +8 -0
- kader/agent/agents.py +126 -0
- kader/agent/base.py +927 -0
- kader/agent/logger.py +170 -0
- kader/config.py +139 -0
- kader/memory/__init__.py +66 -0
- kader/memory/conversation.py +409 -0
- kader/memory/session.py +385 -0
- kader/memory/state.py +211 -0
- kader/memory/types.py +116 -0
- kader/prompts/__init__.py +9 -0
- kader/prompts/agent_prompts.py +27 -0
- kader/prompts/base.py +81 -0
- kader/prompts/templates/planning_agent.j2 +26 -0
- kader/prompts/templates/react_agent.j2 +18 -0
- kader/providers/__init__.py +9 -0
- kader/providers/base.py +581 -0
- kader/providers/mock.py +96 -0
- kader/providers/ollama.py +447 -0
- kader/tools/README.md +483 -0
- kader/tools/__init__.py +130 -0
- kader/tools/base.py +955 -0
- kader/tools/exec_commands.py +249 -0
- kader/tools/filesys.py +650 -0
- kader/tools/filesystem.py +607 -0
- kader/tools/protocol.py +456 -0
- kader/tools/rag.py +555 -0
- kader/tools/todo.py +210 -0
- kader/tools/utils.py +456 -0
- kader/tools/web.py +246 -0
- kader-0.1.5.dist-info/METADATA +321 -0
- kader-0.1.5.dist-info/RECORD +45 -0
- kader-0.1.5.dist-info/WHEEL +4 -0
- kader-0.1.5.dist-info/entry_points.txt +2 -0
kader/tools/base.py
ADDED
|
@@ -0,0 +1,955 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Base class for Agentic Tools.
|
|
3
|
+
|
|
4
|
+
A versatile, provider-agnostic base class for defining tools that can be used
|
|
5
|
+
with any LLM provider (OpenAI, Google, Anthropic, Mistral, and others).
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import json
|
|
9
|
+
from abc import ABC, abstractmethod
|
|
10
|
+
from dataclasses import dataclass, field
|
|
11
|
+
from enum import Enum
|
|
12
|
+
from typing import (
|
|
13
|
+
Any,
|
|
14
|
+
Callable,
|
|
15
|
+
Generic,
|
|
16
|
+
Literal,
|
|
17
|
+
TypeAlias,
|
|
18
|
+
TypeVar,
|
|
19
|
+
)
|
|
20
|
+
|
|
21
|
+
# Type Aliases
|
|
22
|
+
ParameterType: TypeAlias = Literal[
|
|
23
|
+
"string", "integer", "number", "boolean", "array", "object"
|
|
24
|
+
]
|
|
25
|
+
ToolResultStatus: TypeAlias = Literal["success", "error", "pending"]
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class ToolCategory(str, Enum):
    """
    Coarse grouping used to organize and filter tools.

    Members also subclass ``str``, so each one compares equal to its plain
    string value (e.g. ``ToolCategory.WEB == "web"``).
    """

    # Local / code-oriented tools
    FILE_SYSTEM = "file_system"
    CODE = "code"
    # Network / lookup tools
    WEB = "web"
    SEARCH = "search"
    DATABASE = "database"
    API = "api"
    # General-purpose and user-defined tools
    UTILITY = "utility"
    CUSTOM = "custom"
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
@dataclass
class ParameterSchema:
    """
    Declarative description of one tool parameter.

    Instances are rendered into JSON Schema by ``to_json_schema``; optional
    constraint fields left as ``None`` are simply omitted from the output.
    """

    name: str
    type: ParameterType
    description: str
    required: bool = True

    # Additional constraints (emitted only when set)
    enum: list[str] | None = None
    default: Any = None
    minimum: int | float | None = None
    maximum: int | float | None = None
    min_length: int | None = None
    max_length: int | None = None
    pattern: str | None = None

    # Element type, used only when ``type == "array"``
    items_type: ParameterType | None = None

    # Nested fields, used only when ``type == "object"``
    properties: list["ParameterSchema"] | None = None

    def to_json_schema(self) -> dict[str, Any]:
        """
        Render this parameter as a JSON Schema fragment (OpenAI/standard format).

        ``enum`` is included only when non-empty; the scalar constraints are
        included whenever they are not ``None``. Arrays with an ``items_type``
        gain an ``items`` entry, and objects with nested ``properties`` gain
        ``properties`` plus a ``required`` list built from the children.

        Returns:
            A dict suitable for use as a JSON Schema property definition.
        """
        schema: dict[str, Any] = {"type": self.type, "description": self.description}

        if self.enum:
            schema["enum"] = self.enum

        # (JSON Schema keyword, value) pairs in the order they are emitted.
        constraint_pairs = (
            ("default", self.default),
            ("minimum", self.minimum),
            ("maximum", self.maximum),
            ("minLength", self.min_length),
            ("maxLength", self.max_length),
            ("pattern", self.pattern),
        )
        schema.update(
            {keyword: value for keyword, value in constraint_pairs if value is not None}
        )

        if self.type == "array" and self.items_type:
            schema["items"] = {"type": self.items_type}

        children = self.properties if self.type == "object" else None
        if children:
            schema["properties"] = {
                child.name: child.to_json_schema() for child in children
            }
            schema["required"] = [child.name for child in children if child.required]

        return schema
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
@dataclass
class ToolSchema:
    """
    Provider-neutral definition of a tool: identity, parameters and metadata.

    The ``to_*_format`` methods render the same information in the shape each
    provider's tool-calling API expects.
    """

    name: str
    description: str
    parameters: list[ParameterSchema] = field(default_factory=list)

    # Optional metadata
    category: ToolCategory = ToolCategory.CUSTOM
    version: str = "1.0.0"
    deprecated: bool = False

    def to_json_schema(self) -> dict[str, Any]:
        """Return the parameter list as a JSON Schema ``object`` definition."""
        properties: dict[str, Any] = {}
        required: list[str] = []
        for param in self.parameters:
            properties[param.name] = param.to_json_schema()
            if param.required:
                required.append(param.name)
        return {"type": "object", "properties": properties, "required": required}

    def to_openai_format(self) -> dict[str, Any]:
        """Render as an OpenAI ``function`` tool definition."""
        function_spec = {
            "name": self.name,
            "description": self.description,
            "parameters": self.to_json_schema(),
        }
        return {"type": "function", "function": function_spec}

    def to_anthropic_format(self) -> dict[str, Any]:
        """Render as an Anthropic tool definition (schema under ``input_schema``)."""
        return dict(
            name=self.name,
            description=self.description,
            input_schema=self.to_json_schema(),
        )

    def to_google_format(self) -> dict[str, Any]:
        """Render as a Google (Gemini) function declaration."""
        return dict(
            name=self.name,
            description=self.description,
            parameters=self.to_json_schema(),
        )

    def to_mistral_format(self) -> dict[str, Any]:
        """Render for Mistral, which shares the OpenAI layout."""
        return self.to_openai_format()

    def to_ollama_format(self) -> dict[str, Any]:
        """Render for Ollama, which shares the OpenAI layout."""
        return self.to_openai_format()

    def to_provider_format(self, provider: str) -> dict[str, Any]:
        """
        Render for the named provider.

        Args:
            provider: Provider name, case-insensitive (openai, anthropic,
                google/gemini, mistral, ollama). Unrecognized names fall back
                to the OpenAI layout.

        Returns:
            Tool schema in the provider's format
        """
        renderers = {
            "openai": self.to_openai_format,
            "anthropic": self.to_anthropic_format,
            "google": self.to_google_format,
            "gemini": self.to_google_format,
            "mistral": self.to_mistral_format,
            "ollama": self.to_ollama_format,
        }
        render = renderers.get(provider.lower(), self.to_openai_format)
        return render()
|
|
187
|
+
|
|
188
|
+
|
|
189
|
+
@dataclass
|
|
190
|
+
class ToolCall:
|
|
191
|
+
"""Represents a tool call from an LLM."""
|
|
192
|
+
|
|
193
|
+
id: str # Unique identifier for the tool call
|
|
194
|
+
name: str # Name of the tool to call
|
|
195
|
+
arguments: dict[str, Any] # Parsed arguments
|
|
196
|
+
raw_arguments: str | None = None # Original JSON string (if available)
|
|
197
|
+
|
|
198
|
+
@classmethod
|
|
199
|
+
def from_openai(cls, tool_call: dict[str, Any]) -> "ToolCall":
|
|
200
|
+
"""Create from OpenAI tool call format."""
|
|
201
|
+
function = tool_call.get("function", {})
|
|
202
|
+
raw_args = function.get("arguments", "{}")
|
|
203
|
+
return cls(
|
|
204
|
+
id=tool_call.get("id", ""),
|
|
205
|
+
name=function.get("name", ""),
|
|
206
|
+
arguments=json.loads(raw_args) if raw_args else {},
|
|
207
|
+
raw_arguments=raw_args,
|
|
208
|
+
)
|
|
209
|
+
|
|
210
|
+
@classmethod
|
|
211
|
+
def from_anthropic(cls, tool_use: dict[str, Any]) -> "ToolCall":
|
|
212
|
+
"""Create from Anthropic tool use format."""
|
|
213
|
+
return cls(
|
|
214
|
+
id=tool_use.get("id", ""),
|
|
215
|
+
name=tool_use.get("name", ""),
|
|
216
|
+
arguments=tool_use.get("input", {}),
|
|
217
|
+
raw_arguments=json.dumps(tool_use.get("input", {})),
|
|
218
|
+
)
|
|
219
|
+
|
|
220
|
+
@classmethod
|
|
221
|
+
def from_google(cls, function_call: dict[str, Any]) -> "ToolCall":
|
|
222
|
+
"""Create from Google (Gemini) function call format."""
|
|
223
|
+
return cls(
|
|
224
|
+
id=function_call.get("id", ""),
|
|
225
|
+
name=function_call.get("name", ""),
|
|
226
|
+
arguments=function_call.get("args", {}),
|
|
227
|
+
raw_arguments=json.dumps(function_call.get("args", {})),
|
|
228
|
+
)
|
|
229
|
+
|
|
230
|
+
@classmethod
|
|
231
|
+
def from_provider(cls, tool_call: dict[str, Any], provider: str) -> "ToolCall":
|
|
232
|
+
"""
|
|
233
|
+
Create from a specific provider's format.
|
|
234
|
+
|
|
235
|
+
Args:
|
|
236
|
+
tool_call: Tool call data from the provider
|
|
237
|
+
provider: Provider name
|
|
238
|
+
|
|
239
|
+
Returns:
|
|
240
|
+
Normalized ToolCall instance
|
|
241
|
+
"""
|
|
242
|
+
parsers = {
|
|
243
|
+
"openai": cls.from_openai,
|
|
244
|
+
"anthropic": cls.from_anthropic,
|
|
245
|
+
"google": cls.from_google,
|
|
246
|
+
"gemini": cls.from_google,
|
|
247
|
+
"mistral": cls.from_openai, # Mistral uses OpenAI format
|
|
248
|
+
"ollama": cls.from_openai, # Ollama uses OpenAI format
|
|
249
|
+
}
|
|
250
|
+
|
|
251
|
+
parser = parsers.get(provider.lower())
|
|
252
|
+
if parser:
|
|
253
|
+
return parser(tool_call)
|
|
254
|
+
|
|
255
|
+
# Default to OpenAI format
|
|
256
|
+
return cls.from_openai(tool_call)
|
|
257
|
+
|
|
258
|
+
|
|
259
|
+
@dataclass
class ToolResult:
    """
    Outcome of executing a tool, renderable for any supported provider.

    ``content`` is the text sent back to the model. ``data`` optionally keeps
    the un-stringified value, and the ``error_*`` fields are populated by
    ``ToolResult.error``.
    """

    tool_call_id: str  # ID of the tool call this result is for
    content: str  # String content of the result
    status: ToolResultStatus = "success"

    # Structured data (optional)
    data: Any = None

    # Error information (if status is "error")
    error_type: str | None = None
    error_message: str | None = None

    # Execution metadata
    execution_time_ms: float | None = None

    def to_openai_format(self) -> dict[str, Any]:
        """Render as an OpenAI ``tool`` role message."""
        return dict(
            role="tool",
            tool_call_id=self.tool_call_id,
            content=self.content,
        )

    def to_anthropic_format(self) -> dict[str, Any]:
        """Render as an Anthropic ``tool_result`` block, flagging errors."""
        block: dict[str, Any] = dict(
            type="tool_result",
            tool_use_id=self.tool_call_id,
            content=self.content,
        )
        if self.status == "error":
            block["is_error"] = True
        return block

    def to_google_format(self) -> dict[str, Any]:
        """
        Render as a Google (Gemini) ``function_response`` payload.

        The function name is not stored on ToolResult, so ``name`` is left
        empty and must be filled in by the caller.
        """
        response_body = {"content": self.content, "status": self.status}
        return {"function_response": {"name": "", "response": response_body}}

    def to_provider_format(self, provider: str) -> dict[str, Any]:
        """
        Render for the named provider.

        Args:
            provider: Provider name, case-insensitive. Unrecognized names fall
                back to the OpenAI layout.

        Returns:
            Tool result in the provider's format
        """
        renderers = {
            "openai": self.to_openai_format,
            "anthropic": self.to_anthropic_format,
            "google": self.to_google_format,
            "gemini": self.to_google_format,
            "mistral": self.to_openai_format,
            "ollama": self.to_openai_format,
        }
        render = renderers.get(provider.lower(), self.to_openai_format)
        return render()

    @classmethod
    def success(cls, tool_call_id: str, content: str, data: Any = None) -> "ToolResult":
        """Build a ``"success"`` result carrying ``content`` and optional ``data``."""
        return cls(tool_call_id=tool_call_id, content=content, status="success", data=data)

    @classmethod
    def error(
        cls,
        tool_call_id: str,
        error_message: str,
        error_type: str = "ExecutionError",
    ) -> "ToolResult":
        """Build an ``"error"`` result whose content is ``"Error: <message>"``."""
        prefixed = f"Error: {error_message}"
        return cls(
            tool_call_id=tool_call_id,
            content=prefixed,
            status="error",
            error_type=error_type,
            error_message=error_message,
        )
|
|
358
|
+
|
|
359
|
+
|
|
360
|
+
# Type variable for tool return types (shared by BaseTool and FunctionTool)
T = TypeVar("T")


class BaseTool(ABC, Generic[T]):
    """
    Abstract base class for agentic tools.

    Provides a unified interface for defining tools that can be used with
    any LLM provider including OpenAI, Google, Anthropic, Mistral, and others.

    Subclasses must implement:
    - execute: Synchronous tool execution
    - aexecute: Asynchronous tool execution
    - get_interruption_message: Human-readable description of a pending call

    Example:
        class ReadFileTool(BaseTool[str]):
            def __init__(self):
                super().__init__(
                    name="read_file",
                    description="Read the contents of a file",
                    parameters=[
                        ParameterSchema(
                            name="path",
                            type="string",
                            description="Path to the file to read",
                        ),
                    ],
                )

            def execute(self, path: str) -> str:
                with open(path, "r", encoding="utf-8") as f:
                    return f.read()

            async def aexecute(self, path: str) -> str:
                return self.execute(path=path)

            def get_interruption_message(self, path: str) -> str:
                return f"execute {self.name}: {path}"
    """

    # Basic runtime type check for each JSON Schema parameter type. Built once
    # rather than per validated argument. "integer"/"number" reject bool
    # explicitly because bool is a subclass of int in Python.
    _TYPE_CHECKERS: dict[str, Callable[[Any], bool]] = {
        "string": lambda v: isinstance(v, str),
        "integer": lambda v: isinstance(v, int) and not isinstance(v, bool),
        "number": lambda v: isinstance(v, (int, float)) and not isinstance(v, bool),
        "boolean": lambda v: isinstance(v, bool),
        "array": lambda v: isinstance(v, list),
        "object": lambda v: isinstance(v, dict),
    }

    def __init__(
        self,
        name: str,
        description: str,
        parameters: list[ParameterSchema] | None = None,
        category: ToolCategory = ToolCategory.CUSTOM,
        version: str = "1.0.0",
    ) -> None:
        """
        Initialize the tool.

        Args:
            name: Unique name for the tool (used in function calls)
            description: Human-readable description of what the tool does
            parameters: List of parameter schemas
            category: Category for organization
            version: Version string for the tool
        """
        self._schema = ToolSchema(
            name=name,
            description=description,
            parameters=parameters or [],
            category=category,
            version=version,
        )

        # Execution tracking
        self._execution_count = 0
        self._total_execution_time_ms = 0.0
        self._last_execution_time_ms: float | None = None

        # Session Context
        self._session_id: str | None = None

    def set_session_id(self, session_id: str) -> None:
        """
        Set the session ID for the tool.

        Args:
            session_id: The session ID to associate with this tool instance.
        """
        self._session_id = session_id

    @property
    def name(self) -> str:
        """Get the tool name."""
        return self._schema.name

    @property
    def description(self) -> str:
        """Get the tool description."""
        return self._schema.description

    @property
    def schema(self) -> ToolSchema:
        """Get the full tool schema."""
        return self._schema

    @property
    def execution_count(self) -> int:
        """Get the total number of executions (successful or raising)."""
        return self._execution_count

    @property
    def average_execution_time_ms(self) -> float:
        """Get the average execution time in milliseconds (0.0 if never run)."""
        if self._execution_count == 0:
            return 0.0
        return self._total_execution_time_ms / self._execution_count

    def to_provider_format(self, provider: str) -> dict[str, Any]:
        """
        Get the tool definition in a specific provider's format.

        Args:
            provider: Provider name (openai, anthropic, google, mistral, ollama)

        Returns:
            Tool definition in the provider's format
        """
        return self._schema.to_provider_format(provider)

    def validate_arguments(self, arguments: dict[str, Any]) -> tuple[bool, list[str]]:
        """
        Validate the provided arguments against the schema.

        Only presence of required parameters and the basic JSON type of each
        supplied parameter are checked. Constraints such as ``enum``,
        ``minimum`` or ``pattern`` are not enforced here.

        Args:
            arguments: Dictionary of argument name to value

        Returns:
            Tuple of (is_valid, list of error messages). Missing-parameter
            errors are listed before type errors.
        """
        parameters = self._schema.parameters

        errors: list[str] = [
            f"Missing required parameter: {param.name}"
            for param in parameters
            if param.required and param.name not in arguments
        ]

        for param in parameters:
            if param.name not in arguments:
                continue
            value = arguments[param.name]
            checker = self._TYPE_CHECKERS.get(param.type)
            if checker and not checker(value):
                errors.append(
                    f"Parameter '{param.name}' should be {param.type}, "
                    f"got {type(value).__name__}"
                )

        return not errors, errors

    def _update_tracking(self, execution_time_ms: float) -> None:
        """Update execution tracking metrics."""
        self._execution_count += 1
        self._total_execution_time_ms += execution_time_ms
        self._last_execution_time_ms = execution_time_ms

    def reset_tracking(self) -> None:
        """Reset execution tracking metrics."""
        self._execution_count = 0
        self._total_execution_time_ms = 0.0
        self._last_execution_time_ms = None

    # -------------------------------------------------------------------------
    # Abstract Methods - Must be implemented by subclasses
    # -------------------------------------------------------------------------

    @abstractmethod
    def execute(self, **kwargs: Any) -> T:
        """
        Synchronously execute the tool.

        Args:
            **kwargs: Tool arguments matching the parameter schema

        Returns:
            The tool's result
        """
        ...

    @abstractmethod
    async def aexecute(self, **kwargs: Any) -> T:
        """
        Asynchronously execute the tool.

        Args:
            **kwargs: Tool arguments matching the parameter schema

        Returns:
            The tool's result
        """
        ...

    @abstractmethod
    def get_interruption_message(self, **kwargs: Any) -> str:
        """
        Get a human-readable message describing the tool action for user confirmation.

        This method should return a message that clearly describes what the tool
        is about to do, suitable for displaying to the user before execution.

        Args:
            **kwargs: Tool arguments matching the parameter schema

        Returns:
            A formatted string describing the action, e.g., "execute read_file: /path/to/file"
        """
        ...

    # -------------------------------------------------------------------------
    # Convenience Methods
    # -------------------------------------------------------------------------

    def run(self, tool_call: ToolCall) -> ToolResult:
        """
        Execute the tool from a ToolCall and return a ToolResult.

        Never raises for ordinary exceptions. Validation failures and
        exceptions from ``execute`` are returned as error results. Each
        execution attempt is counted exactly once in the tracking metrics.
        Validation failures are not counted.

        Args:
            tool_call: The tool call to execute

        Returns:
            ToolResult with the execution result
        """
        import time

        start_time = time.perf_counter()
        try:
            rejection = self._reject_invalid_arguments(tool_call)
            if rejection is not None:
                return rejection
            result = self.execute(**tool_call.arguments)
            execution_time_ms = (time.perf_counter() - start_time) * 1000
            content = self._result_to_content(result)
        except Exception as e:
            elapsed_ms = (time.perf_counter() - start_time) * 1000
            return self._error_result(tool_call, e, elapsed_ms)
        return self._success_result(tool_call, result, content, execution_time_ms)

    async def arun(self, tool_call: ToolCall) -> ToolResult:
        """
        Asynchronously execute the tool from a ToolCall.

        Same contract as ``run``, but awaits ``aexecute``.

        Args:
            tool_call: The tool call to execute

        Returns:
            ToolResult with the execution result
        """
        import time

        start_time = time.perf_counter()
        try:
            rejection = self._reject_invalid_arguments(tool_call)
            if rejection is not None:
                return rejection
            result = await self.aexecute(**tool_call.arguments)
            execution_time_ms = (time.perf_counter() - start_time) * 1000
            content = self._result_to_content(result)
        except Exception as e:
            elapsed_ms = (time.perf_counter() - start_time) * 1000
            return self._error_result(tool_call, e, elapsed_ms)
        return self._success_result(tool_call, result, content, execution_time_ms)

    def _reject_invalid_arguments(self, tool_call: ToolCall) -> ToolResult | None:
        """Return a ValidationError result if the call's arguments are invalid, else None."""
        is_valid, errors = self.validate_arguments(tool_call.arguments)
        if is_valid:
            return None
        return ToolResult.error(
            tool_call_id=tool_call.id,
            error_message="; ".join(errors),
            error_type="ValidationError",
        )

    @staticmethod
    def _result_to_content(result: Any) -> str:
        """Convert a tool's return value to the text content sent to the model."""
        if isinstance(result, str):
            return result
        try:
            return json.dumps(result)
        except (TypeError, ValueError):
            # json.dumps raises TypeError for unsupported types and ValueError
            # for circular references. The tool itself succeeded, so fall back
            # to a textual form instead of reporting a failure.
            return str(result)

    def _success_result(
        self,
        tool_call: ToolCall,
        result: Any,
        content: str,
        execution_time_ms: float,
    ) -> ToolResult:
        """Record a completed execution and wrap its result."""
        self._update_tracking(execution_time_ms)
        return ToolResult(
            tool_call_id=tool_call.id,
            content=content,
            status="success",
            data=result,
            execution_time_ms=execution_time_ms,
        )

    def _error_result(
        self,
        tool_call: ToolCall,
        error: Exception,
        execution_time_ms: float,
    ) -> ToolResult:
        """Record an execution attempt that raised and wrap the exception."""
        self._update_tracking(execution_time_ms)
        return ToolResult.error(
            tool_call_id=tool_call.id,
            error_message=str(error),
            error_type=type(error).__name__,
        )

    def __repr__(self) -> str:
        """String representation of the tool."""
        return f"{self.__class__.__name__}(name='{self.name}')"
|
|
689
|
+
|
|
690
|
+
|
|
691
|
+
class ToolRegistry:
    """
    Name-keyed collection of tools.

    Tool names are unique within a registry. Besides lookup, the registry can
    export every tool in a provider's format and dispatch ToolCalls to the
    matching tool. Unknown tool names become error results.

    Example:
        registry = ToolRegistry()
        registry.register(ReadFileTool())
        registry.register(WriteFileTool())

        tools = registry.to_provider_format("openai")
        tool = registry.get("read_file")
    """

    def __init__(self) -> None:
        """Create a registry holding no tools."""
        self._tools: dict[str, BaseTool] = {}

    def register(self, tool: BaseTool) -> None:
        """
        Add a tool under its own name.

        Args:
            tool: The tool to register

        Raises:
            ValueError: If a tool with the same name is already registered
        """
        already_present = tool.name in self._tools
        if already_present:
            raise ValueError(f"Tool '{tool.name}' is already registered")
        self._tools[tool.name] = tool

    def unregister(self, name: str) -> bool:
        """
        Remove the tool registered under ``name``.

        Args:
            name: Name of the tool to unregister

        Returns:
            True if a tool was removed, False if no tool had that name
        """
        try:
            del self._tools[name]
        except KeyError:
            return False
        return True

    def get(self, name: str) -> BaseTool | None:
        """
        Look up a tool by name.

        Args:
            name: Name of the tool

        Returns:
            The tool if found, None otherwise
        """
        return self._tools.get(name)

    def get_by_category(self, category: ToolCategory) -> list[BaseTool]:
        """
        Collect the tools whose schema category equals ``category``.

        Args:
            category: The category to filter by

        Returns:
            Matching tools, in registration order
        """
        return list(
            filter(lambda entry: entry.schema.category == category, self._tools.values())
        )

    @property
    def tools(self) -> list[BaseTool]:
        """All registered tools, in registration order."""
        return [*self._tools.values()]

    @property
    def names(self) -> list[str]:
        """All registered tool names, in registration order."""
        return [*self._tools]

    def to_provider_format(self, provider: str) -> list[dict[str, Any]]:
        """
        Export every registered tool in one provider's format.

        Args:
            provider: Provider name

        Returns:
            List of tool definitions in the provider's format
        """
        return [entry.to_provider_format(provider) for entry in self._tools.values()]

    @staticmethod
    def _missing_tool_result(tool_call: ToolCall) -> ToolResult:
        """Build the error result for a call naming an unregistered tool."""
        return ToolResult.error(
            tool_call_id=tool_call.id,
            error_message=f"Tool '{tool_call.name}' not found",
            error_type="ToolNotFoundError",
        )

    def run(self, tool_call: ToolCall) -> ToolResult:
        """
        Dispatch a tool call to the matching registered tool.

        Args:
            tool_call: The tool call to execute

        Returns:
            The tool's ToolResult, or a ToolNotFoundError result
        """
        target = self.get(tool_call.name)
        if target is not None:
            return target.run(tool_call)
        return self._missing_tool_result(tool_call)

    async def arun(self, tool_call: ToolCall) -> ToolResult:
        """
        Asynchronously dispatch a tool call to the matching registered tool.

        Args:
            tool_call: The tool call to execute

        Returns:
            The tool's ToolResult, or a ToolNotFoundError result
        """
        target = self.get(tool_call.name)
        if target is not None:
            return await target.arun(tool_call)
        return self._missing_tool_result(tool_call)

    def __len__(self) -> int:
        """Number of registered tools."""
        return len(self._tools)

    def __contains__(self, name: str) -> bool:
        """Whether a tool is registered under ``name``."""
        return name in self._tools

    def __iter__(self):
        """Iterate over the registered tools (not their names)."""
        return iter(self._tools.values())

    def __repr__(self) -> str:
        """Show the registered tool names."""
        registered = list(self._tools.keys())
        return f"ToolRegistry(tools={registered})"
|
|
840
|
+
|
|
841
|
+
|
|
842
|
+
def tool(
    name: str | None = None,
    description: str | None = None,
    parameters: list[ParameterSchema] | None = None,
    category: ToolCategory = ToolCategory.CUSTOM,
    version: str = "1.0.0",
) -> Callable[[Callable[..., T]], "FunctionTool[T]"]:
    """
    Decorator to create a tool from a function.

    Example:
        @tool(
            name="greet",
            description="Greet a user by name",
            parameters=[
                ParameterSchema(name="name", type="string", description="Name to greet"),
            ],
        )
        def greet(name: str) -> str:
            return f"Hello, {name}!"

    Args:
        name: Tool name (defaults to function name)
        description: Tool description (defaults to the function docstring with
            its source indentation removed, then to "Execute <name>")
        parameters: Parameter schemas
        category: Tool category
        version: Tool version (forwarded to FunctionTool)

    Returns:
        Decorator that creates a FunctionTool
    """
    import inspect

    def decorator(func: Callable[..., T]) -> "FunctionTool[T]":
        tool_name = name or func.__name__
        # Raw __doc__ keeps the function's source indentation. cleandoc strips
        # it (and surrounding blank lines) so the LLM sees tidy text.
        # Whitespace-only docstrings become "" and fall through to the default.
        docstring = inspect.cleandoc(func.__doc__) if func.__doc__ else ""
        tool_description = description or docstring or f"Execute {tool_name}"

        return FunctionTool(
            name=tool_name,
            description=tool_description,
            parameters=parameters,
            category=category,
            func=func,
            version=version,
        )

    return decorator
|
|
885
|
+
|
|
886
|
+
|
|
887
|
+
class FunctionTool(BaseTool[T]):
    """
    A tool created from a function.

    This is used by the @tool decorator to wrap functions as tools. The
    wrapped function may be a plain callable or an ``async def`` coroutine
    function; ``execute`` and ``aexecute`` both accept either kind.
    """

    def __init__(
        self,
        name: str,
        description: str,
        func: Callable[..., T],
        parameters: list[ParameterSchema] | None = None,
        category: ToolCategory = ToolCategory.CUSTOM,
        version: str = "1.0.0",
    ) -> None:
        """
        Initialize a function-based tool.

        Args:
            name: Tool name
            description: Tool description
            func: The function to wrap
            parameters: Parameter schemas
            category: Tool category
            version: Tool version
        """
        super().__init__(
            name=name,
            description=description,
            parameters=parameters,
            category=category,
            version=version,
        )
        self._func = func

    def execute(self, **kwargs: Any) -> T:
        """
        Execute the wrapped function synchronously.

        A plain function is simply called. A coroutine function is driven to
        completion with ``asyncio.run``. Returning the un-awaited coroutine
        object instead would leak it and make ``run`` fail on serialization.
        That is only possible when no event loop is running in the current
        thread. Inside a running loop, use ``aexecute``/``arun``.

        Raises:
            RuntimeError: If the wrapped function is a coroutine function and
                an event loop is already running in the current thread.
        """
        import asyncio
        import inspect

        if not inspect.iscoroutinefunction(self._func):
            return self._func(**kwargs)

        try:
            asyncio.get_running_loop()
        except RuntimeError:
            loop_running = False
        else:
            loop_running = True

        if loop_running:
            # Check before creating the coroutine so none is left un-awaited.
            raise RuntimeError(
                f"Tool '{self.name}' wraps a coroutine function and cannot be "
                "executed synchronously from a running event loop; "
                "use aexecute()/arun() instead"
            )
        return asyncio.run(self._func(**kwargs))

    async def aexecute(self, **kwargs: Any) -> T:
        """
        Execute the wrapped function asynchronously.

        If the function is a coroutine function, it will be awaited.
        Otherwise, it is run in a worker thread via ``asyncio.to_thread``.
        """
        import asyncio
        import inspect

        if inspect.iscoroutinefunction(self._func):
            return await self._func(**kwargs)
        else:
            return await asyncio.to_thread(self._func, **kwargs)

    def get_interruption_message(self, **kwargs: Any) -> str:
        """
        Get interruption message for user confirmation.

        For function-based tools, generates a message using the tool name
        and the first non-empty string argument value (if any).
        """
        # Try to find a meaningful argument to display
        for key, value in kwargs.items():
            if isinstance(value, str) and value:
                return f"execute {self.name}: {value}"

        # Fallback to just the tool name
        return f"execute {self.name}"
|