kader 0.1.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
kader/tools/base.py ADDED
@@ -0,0 +1,955 @@
1
+ """
2
+ Base class for Agentic Tools.
3
+
4
+ A versatile, provider-agnostic base class for defining tools that can be used
5
+ with any LLM provider (OpenAI, Google, Anthropic, Mistral, and others).
6
+ """
7
+
8
+ import json
9
+ from abc import ABC, abstractmethod
10
+ from dataclasses import dataclass, field
11
+ from enum import Enum
12
+ from typing import (
13
+ Any,
14
+ Callable,
15
+ Generic,
16
+ Literal,
17
+ TypeAlias,
18
+ TypeVar,
19
+ )
20
+
21
+ # Type Aliases
22
+ ParameterType: TypeAlias = Literal[
23
+ "string", "integer", "number", "boolean", "array", "object"
24
+ ]
25
+ ToolResultStatus: TypeAlias = Literal["success", "error", "pending"]
26
+
27
+
28
class ToolCategory(str, Enum):
    """Categories of tools for organization and filtering.

    Members also subclass ``str``, so each member compares equal to its
    plain string value.
    """

    FILE_SYSTEM = "file_system"
    CODE = "code"
    WEB = "web"
    SEARCH = "search"
    DATABASE = "database"
    API = "api"
    UTILITY = "utility"
    CUSTOM = "custom"
39
+
40
+
41
@dataclass
class ParameterSchema:
    """Schema for a single tool parameter."""

    name: str
    type: ParameterType
    description: str
    required: bool = True

    # Additional constraints; each is left out of the JSON Schema when unset
    enum: list[str] | None = None
    default: Any = None
    minimum: int | float | None = None
    maximum: int | float | None = None
    min_length: int | None = None
    max_length: int | None = None
    pattern: str | None = None

    # For array types
    items_type: ParameterType | None = None

    # For object types
    properties: list["ParameterSchema"] | None = None

    def to_json_schema(self) -> dict[str, Any]:
        """
        Convert to JSON Schema format (OpenAI/standard format).

        Returns:
            A dict that always contains "type" and "description", plus one
            key per constraint that is set. "items" is added only for array
            parameters with an items_type; "properties"/"required" only for
            object parameters with nested properties.
        """
        schema: dict[str, Any] = {
            "type": self.type,
            "description": self.description,
        }

        # An empty enum list is treated the same as no enum at all.
        if self.enum:
            schema["enum"] = self.enum

        # (JSON Schema keyword, value) pairs, in the order they are emitted.
        optional_keywords = (
            ("default", self.default),
            ("minimum", self.minimum),
            ("maximum", self.maximum),
            ("minLength", self.min_length),
            ("maxLength", self.max_length),
            ("pattern", self.pattern),
        )
        for keyword, value in optional_keywords:
            if value is not None:
                schema[keyword] = value

        # Array items
        if self.type == "array" and self.items_type:
            schema["items"] = {"type": self.items_type}

        # Nested object properties
        if self.type == "object" and self.properties:
            schema["properties"] = {
                prop.name: prop.to_json_schema() for prop in self.properties
            }
            schema["required"] = [
                prop.name for prop in self.properties if prop.required
            ]

        return schema
101
+
102
+
103
@dataclass
class ToolSchema:
    """Complete schema definition for a tool."""

    name: str
    description: str
    parameters: list[ParameterSchema] = field(default_factory=list)

    # Optional metadata
    category: ToolCategory = ToolCategory.CUSTOM
    version: str = "1.0.0"
    deprecated: bool = False

    def to_json_schema(self) -> dict[str, Any]:
        """
        Convert the parameter list to a JSON Schema object.

        Returns:
            A dict with "type": "object", a "properties" mapping keyed by
            parameter name, and a "required" list of required names.
        """
        return {
            "type": "object",
            "properties": {
                param.name: param.to_json_schema() for param in self.parameters
            },
            "required": [param.name for param in self.parameters if param.required],
        }

    def to_openai_format(self) -> dict[str, Any]:
        """Convert to OpenAI function calling format."""
        function_spec = {
            "name": self.name,
            "description": self.description,
            "parameters": self.to_json_schema(),
        }
        return {"type": "function", "function": function_spec}

    def to_anthropic_format(self) -> dict[str, Any]:
        """Convert to Anthropic tool format."""
        return {
            "name": self.name,
            "description": self.description,
            "input_schema": self.to_json_schema(),
        }

    def to_google_format(self) -> dict[str, Any]:
        """Convert to Google (Gemini) tool format."""
        return {
            "name": self.name,
            "description": self.description,
            "parameters": self.to_json_schema(),
        }

    def to_mistral_format(self) -> dict[str, Any]:
        """Convert to Mistral tool format (same as OpenAI)."""
        return self.to_openai_format()

    def to_ollama_format(self) -> dict[str, Any]:
        """Convert to Ollama tool format (same as OpenAI)."""
        return self.to_openai_format()

    def to_provider_format(self, provider: str) -> dict[str, Any]:
        """
        Convert to a specific provider's format.

        Args:
            provider: Provider name (openai, anthropic, google, gemini,
                mistral, ollama); matched case-insensitively.

        Returns:
            Tool schema in the provider's format. Unrecognized provider
            names fall back to the OpenAI format.
        """
        formatters = {
            "openai": self.to_openai_format,
            "anthropic": self.to_anthropic_format,
            "google": self.to_google_format,
            "gemini": self.to_google_format,
            "mistral": self.to_mistral_format,
            "ollama": self.to_ollama_format,
        }

        # Default to OpenAI format as it's most common
        formatter = formatters.get(provider.lower(), self.to_openai_format)
        return formatter()
187
+
188
+
189
+ @dataclass
190
+ class ToolCall:
191
+ """Represents a tool call from an LLM."""
192
+
193
+ id: str # Unique identifier for the tool call
194
+ name: str # Name of the tool to call
195
+ arguments: dict[str, Any] # Parsed arguments
196
+ raw_arguments: str | None = None # Original JSON string (if available)
197
+
198
+ @classmethod
199
+ def from_openai(cls, tool_call: dict[str, Any]) -> "ToolCall":
200
+ """Create from OpenAI tool call format."""
201
+ function = tool_call.get("function", {})
202
+ raw_args = function.get("arguments", "{}")
203
+ return cls(
204
+ id=tool_call.get("id", ""),
205
+ name=function.get("name", ""),
206
+ arguments=json.loads(raw_args) if raw_args else {},
207
+ raw_arguments=raw_args,
208
+ )
209
+
210
+ @classmethod
211
+ def from_anthropic(cls, tool_use: dict[str, Any]) -> "ToolCall":
212
+ """Create from Anthropic tool use format."""
213
+ return cls(
214
+ id=tool_use.get("id", ""),
215
+ name=tool_use.get("name", ""),
216
+ arguments=tool_use.get("input", {}),
217
+ raw_arguments=json.dumps(tool_use.get("input", {})),
218
+ )
219
+
220
+ @classmethod
221
+ def from_google(cls, function_call: dict[str, Any]) -> "ToolCall":
222
+ """Create from Google (Gemini) function call format."""
223
+ return cls(
224
+ id=function_call.get("id", ""),
225
+ name=function_call.get("name", ""),
226
+ arguments=function_call.get("args", {}),
227
+ raw_arguments=json.dumps(function_call.get("args", {})),
228
+ )
229
+
230
+ @classmethod
231
+ def from_provider(cls, tool_call: dict[str, Any], provider: str) -> "ToolCall":
232
+ """
233
+ Create from a specific provider's format.
234
+
235
+ Args:
236
+ tool_call: Tool call data from the provider
237
+ provider: Provider name
238
+
239
+ Returns:
240
+ Normalized ToolCall instance
241
+ """
242
+ parsers = {
243
+ "openai": cls.from_openai,
244
+ "anthropic": cls.from_anthropic,
245
+ "google": cls.from_google,
246
+ "gemini": cls.from_google,
247
+ "mistral": cls.from_openai, # Mistral uses OpenAI format
248
+ "ollama": cls.from_openai, # Ollama uses OpenAI format
249
+ }
250
+
251
+ parser = parsers.get(provider.lower())
252
+ if parser:
253
+ return parser(tool_call)
254
+
255
+ # Default to OpenAI format
256
+ return cls.from_openai(tool_call)
257
+
258
+
259
@dataclass
class ToolResult:
    """Result from executing a tool."""

    tool_call_id: str  # ID of the tool call this result is for
    content: str  # String content of the result
    status: ToolResultStatus = "success"

    # Structured data (optional)
    data: Any = None

    # Error information (if status is "error")
    error_type: str | None = None
    error_message: str | None = None

    # Execution metadata
    execution_time_ms: float | None = None

    def to_openai_format(self) -> dict[str, Any]:
        """Convert to OpenAI tool result format (a "tool" role message)."""
        return {
            "role": "tool",
            "tool_call_id": self.tool_call_id,
            "content": self.content,
        }

    def to_anthropic_format(self) -> dict[str, Any]:
        """
        Convert to Anthropic tool result format.

        Returns:
            A "tool_result" block; "is_error" is included only when the
            status is "error".
        """
        result: dict[str, Any] = {
            "type": "tool_result",
            "tool_use_id": self.tool_call_id,
            "content": self.content,
        }
        if self.status == "error":
            result["is_error"] = True
        return result

    def to_google_format(self, name: str = "") -> dict[str, Any]:
        """
        Convert to Google (Gemini) function response format.

        Args:
            name: Name of the function this response answers. The result
                does not store the tool name, so when omitted the field is
                left empty and needs to be filled by the caller.

        Returns:
            A dict with a single "function_response" entry.
        """
        return {
            "function_response": {
                "name": name,
                "response": {
                    "content": self.content,
                    "status": self.status,
                },
            },
        }

    def to_provider_format(self, provider: str) -> dict[str, Any]:
        """
        Convert to a specific provider's format.

        Args:
            provider: Provider name, matched case-insensitively

        Returns:
            Tool result in the provider's format. Unrecognized provider
            names fall back to the OpenAI format.
        """
        formatters = {
            "openai": self.to_openai_format,
            "anthropic": self.to_anthropic_format,
            "google": self.to_google_format,
            "gemini": self.to_google_format,
            "mistral": self.to_openai_format,
            "ollama": self.to_openai_format,
        }

        formatter = formatters.get(provider.lower(), self.to_openai_format)
        return formatter()

    @classmethod
    def success(cls, tool_call_id: str, content: str, data: Any = None) -> "ToolResult":
        """Create a successful tool result."""
        return cls(
            tool_call_id=tool_call_id,
            content=content,
            status="success",
            data=data,
        )

    @classmethod
    def error(
        cls,
        tool_call_id: str,
        error_message: str,
        error_type: str = "ExecutionError",
    ) -> "ToolResult":
        """
        Create an error tool result.

        The content is the error message prefixed with "Error: "; the
        unprefixed message is kept in error_message.
        """
        return cls(
            tool_call_id=tool_call_id,
            content=f"Error: {error_message}",
            status="error",
            error_type=error_type,
            error_message=error_message,
        )
358
+
359
+
360
+ # Type variable for tool return types
361
+ T = TypeVar("T")
362
+
363
+
364
class BaseTool(ABC, Generic[T]):
    """
    Abstract base class for agentic tools.

    Provides a unified interface for defining tools that can be used with
    any LLM provider including OpenAI, Google, Anthropic, Mistral, and others.

    Subclasses must implement:
        - execute: Synchronous tool execution
        - aexecute: Asynchronous tool execution
        - get_interruption_message: Confirmation text describing the action

    Example:
        class ReadFileTool(BaseTool[str]):
            def __init__(self):
                super().__init__(
                    name="read_file",
                    description="Read the contents of a file",
                    parameters=[
                        ParameterSchema(
                            name="path",
                            type="string",
                            description="Path to the file to read",
                        ),
                    ],
                )

            def execute(self, path: str) -> str:
                with open(path, "r") as f:
                    return f.read()

            async def aexecute(self, path: str) -> str:
                return self.execute(path)

            def get_interruption_message(self, path: str) -> str:
                return f"execute read_file: {path}"
    """

    def __init__(
        self,
        name: str,
        description: str,
        parameters: list[ParameterSchema] | None = None,
        category: ToolCategory = ToolCategory.CUSTOM,
        version: str = "1.0.0",
    ) -> None:
        """
        Initialize the tool.

        Args:
            name: Unique name for the tool (used in function calls)
            description: Human-readable description of what the tool does
            parameters: List of parameter schemas
            category: Category for organization
            version: Version string for the tool
        """
        self._schema = ToolSchema(
            name=name,
            description=description,
            parameters=parameters or [],
            category=category,
            version=version,
        )

        # Execution tracking
        self._execution_count = 0
        self._total_execution_time_ms = 0.0
        self._last_execution_time_ms: float | None = None

        # Session Context
        self._session_id: str | None = None

    def set_session_id(self, session_id: str) -> None:
        """
        Set the session ID for the tool.

        Args:
            session_id: The session ID to associate with this tool instance.
        """
        self._session_id = session_id

    @property
    def name(self) -> str:
        """Get the tool name."""
        return self._schema.name

    @property
    def description(self) -> str:
        """Get the tool description."""
        return self._schema.description

    @property
    def schema(self) -> ToolSchema:
        """Get the full tool schema."""
        return self._schema

    @property
    def execution_count(self) -> int:
        """Get the total number of executions."""
        return self._execution_count

    @property
    def average_execution_time_ms(self) -> float:
        """Get the average execution time in milliseconds (0.0 before any run)."""
        if self._execution_count == 0:
            return 0.0
        return self._total_execution_time_ms / self._execution_count

    def to_provider_format(self, provider: str) -> dict[str, Any]:
        """
        Get the tool definition in a specific provider's format.

        Args:
            provider: Provider name (openai, anthropic, google, mistral, ollama)

        Returns:
            Tool definition in the provider's format
        """
        return self._schema.to_provider_format(provider)

    def validate_arguments(self, arguments: dict[str, Any]) -> tuple[bool, list[str]]:
        """
        Validate the provided arguments against the schema.

        Only presence of required parameters and basic types are checked;
        arguments not declared in the schema are ignored.

        Args:
            arguments: Dictionary of argument name to value

        Returns:
            Tuple of (is_valid, list of error messages). Missing-parameter
            messages come before type-mismatch messages.
        """
        # Built once per call rather than once per parameter. bool is a
        # subclass of int, so it is excluded from the numeric checks.
        type_checks: dict[str, Callable[[Any], bool]] = {
            "string": lambda v: isinstance(v, str),
            "integer": lambda v: isinstance(v, int) and not isinstance(v, bool),
            "number": lambda v: isinstance(v, (int, float))
            and not isinstance(v, bool),
            "boolean": lambda v: isinstance(v, bool),
            "array": lambda v: isinstance(v, list),
            "object": lambda v: isinstance(v, dict),
        }
        parameters = self._schema.parameters

        # Check required parameters
        errors: list[str] = [
            f"Missing required parameter: {param.name}"
            for param in parameters
            if param.required and param.name not in arguments
        ]

        # Check parameter types (basic validation)
        for param in parameters:
            if param.name not in arguments:
                continue

            value = arguments[param.name]
            checker = type_checks.get(param.type)
            if checker and not checker(value):
                errors.append(
                    f"Parameter '{param.name}' should be {param.type}, got {type(value).__name__}"
                )

        return not errors, errors

    def _update_tracking(self, execution_time_ms: float) -> None:
        """Update execution tracking metrics."""
        self._execution_count += 1
        self._total_execution_time_ms += execution_time_ms
        self._last_execution_time_ms = execution_time_ms

    def reset_tracking(self) -> None:
        """Reset execution tracking metrics."""
        self._execution_count = 0
        self._total_execution_time_ms = 0.0
        self._last_execution_time_ms = None

    # -------------------------------------------------------------------------
    # Abstract Methods - Must be implemented by subclasses
    # -------------------------------------------------------------------------

    @abstractmethod
    def execute(self, **kwargs: Any) -> T:
        """
        Synchronously execute the tool.

        Args:
            **kwargs: Tool arguments matching the parameter schema

        Returns:
            The tool's result
        """
        ...

    @abstractmethod
    async def aexecute(self, **kwargs: Any) -> T:
        """
        Asynchronously execute the tool.

        Args:
            **kwargs: Tool arguments matching the parameter schema

        Returns:
            The tool's result
        """
        ...

    @abstractmethod
    def get_interruption_message(self, **kwargs: Any) -> str:
        """
        Get a human-readable message describing the tool action for user confirmation.

        This method should return a message that clearly describes what the tool
        is about to do, suitable for displaying to the user before execution.

        Args:
            **kwargs: Tool arguments matching the parameter schema

        Returns:
            A formatted string describing the action, e.g., "execute read_file: /path/to/file"
        """
        ...

    # -------------------------------------------------------------------------
    # Result helpers shared by run() and arun()
    # -------------------------------------------------------------------------

    @staticmethod
    def _elapsed_ms(start_time: float) -> float:
        """Return milliseconds elapsed since a time.perf_counter() reading."""
        import time

        return (time.perf_counter() - start_time) * 1000

    def _success_result(
        self, tool_call: ToolCall, result: Any, start_time: float
    ) -> ToolResult:
        """Record metrics for a completed execution and wrap its result."""
        execution_time_ms = self._elapsed_ms(start_time)

        # Serialize before recording metrics: if json.dumps raises, the
        # caller's error path records this execution, so it is counted once.
        content = result if isinstance(result, str) else json.dumps(result)
        self._update_tracking(execution_time_ms)

        return ToolResult(
            tool_call_id=tool_call.id,
            content=content,
            status="success",
            data=result,
            execution_time_ms=execution_time_ms,
        )

    def _failure_result(
        self, tool_call: ToolCall, exc: Exception, start_time: float
    ) -> ToolResult:
        """Record metrics for a failed execution and wrap the exception."""
        self._update_tracking(self._elapsed_ms(start_time))
        return ToolResult.error(
            tool_call_id=tool_call.id,
            error_message=str(exc),
            error_type=type(exc).__name__,
        )

    @staticmethod
    def _validation_error(tool_call: ToolCall, errors: list[str]) -> ToolResult:
        """Wrap validation messages in an error result (no metrics recorded)."""
        return ToolResult.error(
            tool_call_id=tool_call.id,
            error_message="; ".join(errors),
            error_type="ValidationError",
        )

    # -------------------------------------------------------------------------
    # Convenience Methods
    # -------------------------------------------------------------------------

    def run(self, tool_call: ToolCall) -> ToolResult:
        """
        Execute the tool from a ToolCall and return a ToolResult.

        Exceptions raised by validation, execution, or result serialization
        are caught and reported as an error result rather than propagated.

        Args:
            tool_call: The tool call to execute

        Returns:
            ToolResult with the execution result
        """
        import time

        start_time = time.perf_counter()

        try:
            # Validate arguments
            is_valid, errors = self.validate_arguments(tool_call.arguments)
            if not is_valid:
                return self._validation_error(tool_call, errors)

            # Execute the tool
            result = self.execute(**tool_call.arguments)
            return self._success_result(tool_call, result, start_time)

        except Exception as e:
            return self._failure_result(tool_call, e, start_time)

    async def arun(self, tool_call: ToolCall) -> ToolResult:
        """
        Asynchronously execute the tool from a ToolCall.

        Exceptions raised by validation, execution, or result serialization
        are caught and reported as an error result rather than propagated.

        Args:
            tool_call: The tool call to execute

        Returns:
            ToolResult with the execution result
        """
        import time

        start_time = time.perf_counter()

        try:
            # Validate arguments
            is_valid, errors = self.validate_arguments(tool_call.arguments)
            if not is_valid:
                return self._validation_error(tool_call, errors)

            # Execute the tool asynchronously
            result = await self.aexecute(**tool_call.arguments)
            return self._success_result(tool_call, result, start_time)

        except Exception as e:
            return self._failure_result(tool_call, e, start_time)

    def __repr__(self) -> str:
        """String representation of the tool."""
        return f"{self.__class__.__name__}(name='{self.name}')"
689
+
690
+
691
class ToolRegistry:
    """
    Registry for managing multiple tools.

    Provides a central location to register, retrieve, and manage tools.
    Tools are keyed by name and kept in registration order.

    Example:
        registry = ToolRegistry()
        registry.register(ReadFileTool())
        registry.register(WriteFileTool())

        tools = registry.to_provider_format("openai")
        tool = registry.get("read_file")
    """

    def __init__(self) -> None:
        """Initialize an empty tool registry."""
        self._tools: dict[str, BaseTool] = {}

    def register(self, tool: BaseTool) -> None:
        """
        Register a tool.

        Args:
            tool: The tool to register

        Raises:
            ValueError: If a tool with the same name is already registered
        """
        if tool.name in self._tools:
            raise ValueError(f"Tool '{tool.name}' is already registered")
        self._tools[tool.name] = tool

    def unregister(self, name: str) -> bool:
        """
        Unregister a tool by name.

        Args:
            name: Name of the tool to unregister

        Returns:
            True if the tool was unregistered, False if not found
        """
        if name not in self._tools:
            return False
        del self._tools[name]
        return True

    def get(self, name: str) -> BaseTool | None:
        """
        Get a tool by name.

        Args:
            name: Name of the tool

        Returns:
            The tool if found, None otherwise
        """
        return self._tools.get(name)

    def get_by_category(self, category: ToolCategory) -> list[BaseTool]:
        """
        Get all tools in a category.

        Args:
            category: The category to filter by

        Returns:
            List of tools in the category
        """
        return [
            tool for tool in self._tools.values() if tool.schema.category == category
        ]

    @property
    def tools(self) -> list[BaseTool]:
        """Get all registered tools (a new list on each access)."""
        return list(self._tools.values())

    @property
    def names(self) -> list[str]:
        """Get all registered tool names (a new list on each access)."""
        return list(self._tools)

    def to_provider_format(self, provider: str) -> list[dict[str, Any]]:
        """
        Get all tools in a specific provider's format.

        Args:
            provider: Provider name

        Returns:
            List of tool definitions in the provider's format
        """
        return [tool.to_provider_format(provider) for tool in self._tools.values()]

    @staticmethod
    def _not_found(tool_call: ToolCall) -> ToolResult:
        """Build the error result for a call that names an unregistered tool."""
        return ToolResult.error(
            tool_call_id=tool_call.id,
            error_message=f"Tool '{tool_call.name}' not found",
            error_type="ToolNotFoundError",
        )

    def run(self, tool_call: ToolCall) -> ToolResult:
        """
        Execute a tool call using the registry.

        Args:
            tool_call: The tool call to execute

        Returns:
            ToolResult with the execution result, or a "ToolNotFoundError"
            error result when no tool with that name is registered
        """
        tool = self.get(tool_call.name)
        if tool is None:
            return self._not_found(tool_call)
        return tool.run(tool_call)

    async def arun(self, tool_call: ToolCall) -> ToolResult:
        """
        Asynchronously execute a tool call using the registry.

        Args:
            tool_call: The tool call to execute

        Returns:
            ToolResult with the execution result, or a "ToolNotFoundError"
            error result when no tool with that name is registered
        """
        tool = self.get(tool_call.name)
        if tool is None:
            return self._not_found(tool_call)
        return await tool.arun(tool_call)

    def __len__(self) -> int:
        """Get the number of registered tools."""
        return len(self._tools)

    def __contains__(self, name: str) -> bool:
        """Check if a tool is registered."""
        return name in self._tools

    def __iter__(self):
        """Iterate over registered tools."""
        return iter(self._tools.values())

    def __repr__(self) -> str:
        """String representation of the registry."""
        return f"ToolRegistry(tools={list(self._tools.keys())})"
840
+
841
+
842
def tool(
    name: str | None = None,
    description: str | None = None,
    parameters: list[ParameterSchema] | None = None,
    category: ToolCategory = ToolCategory.CUSTOM,
) -> Callable[[Callable[..., T]], "FunctionTool[T]"]:
    """
    Decorator to create a tool from a function.

    Example:
        @tool(
            name="greet",
            description="Greet a user by name",
            parameters=[
                ParameterSchema(name="name", type="string", description="Name to greet"),
            ],
        )
        def greet(name: str) -> str:
            return f"Hello, {name}!"

    Args:
        name: Tool name (defaults to function name)
        description: Tool description (defaults to the function docstring,
            or "Execute <tool name>" when there is none)
        parameters: Parameter schemas
        category: Tool category

    Returns:
        Decorator that creates a FunctionTool
    """

    def decorator(func: Callable[..., T]) -> "FunctionTool[T]":
        import inspect

        tool_name = name or func.__name__
        # inspect.getdoc strips the source indentation from multi-line
        # docstrings, so the description is not padded with whitespace.
        tool_description = (
            description or inspect.getdoc(func) or f"Execute {tool_name}"
        )

        return FunctionTool(
            name=tool_name,
            description=tool_description,
            parameters=parameters,
            category=category,
            func=func,
        )

    return decorator
885
+
886
+
887
class FunctionTool(BaseTool[T]):
    """
    A tool created from a function.

    This is used by the @tool decorator to wrap functions as tools.
    """

    def __init__(
        self,
        name: str,
        description: str,
        func: Callable[..., T],
        parameters: list[ParameterSchema] | None = None,
        category: ToolCategory = ToolCategory.CUSTOM,
        version: str = "1.0.0",
    ) -> None:
        """
        Initialize a function-based tool.

        Args:
            name: Tool name
            description: Tool description
            func: The function to wrap
            parameters: Parameter schemas
            category: Tool category
            version: Tool version
        """
        super().__init__(
            name=name,
            description=description,
            parameters=parameters,
            category=category,
            version=version,
        )
        self._func = func

    def execute(self, **kwargs: Any) -> T:
        """
        Execute the wrapped function synchronously.

        The function is called directly and its return value passed back
        unchanged; a coroutine function is not awaited here, so use
        aexecute for those.
        """
        return self._func(**kwargs)

    async def aexecute(self, **kwargs: Any) -> T:
        """
        Execute the wrapped function asynchronously.

        If the function is a coroutine function, it will be awaited.
        Otherwise, it will be run in a separate thread via asyncio.to_thread.
        """
        import asyncio
        import inspect

        if inspect.iscoroutinefunction(self._func):
            return await self._func(**kwargs)
        return await asyncio.to_thread(self._func, **kwargs)

    def get_interruption_message(self, **kwargs: Any) -> str:
        """
        Get interruption message for user confirmation.

        For function-based tools, generates a message using the tool name
        and the first non-empty string argument value (if any).
        """
        # Try to find a meaningful argument to display
        for value in kwargs.values():
            if isinstance(value, str) and value:
                return f"execute {self.name}: {value}"

        # Fallback to just the tool name
        return f"execute {self.name}"