ragbits-agents 1.4.0.dev202601300258__py3-none-any.whl → 1.4.0.dev202602030301__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -9,8 +9,12 @@ from ragbits.agents._main import (
9
9
  ToolCall,
10
10
  ToolCallResult,
11
11
  )
12
+ from ragbits.agents.hooks import (
13
+ EventType,
14
+ Hook,
15
+ HookManager,
16
+ )
12
17
  from ragbits.agents.post_processors.base import PostProcessor, StreamingPostProcessor
13
- from ragbits.agents.tool import requires_confirmation
14
18
  from ragbits.agents.tools import LongTermMemory, MemoryEntry, create_memory_tools
15
19
  from ragbits.agents.types import QuestionAnswerAgent, QuestionAnswerPromptInput, QuestionAnswerPromptOutput
16
20
 
@@ -22,6 +26,9 @@ __all__ = [
22
26
  "AgentResultStreaming",
23
27
  "AgentRunContext",
24
28
  "DownstreamAgentResult",
29
+ "EventType",
30
+ "Hook",
31
+ "HookManager",
25
32
  "LongTermMemory",
26
33
  "MemoryEntry",
27
34
  "PostProcessor",
@@ -32,5 +39,4 @@ __all__ = [
32
39
  "ToolCall",
33
40
  "ToolCallResult",
34
41
  "create_memory_tools",
35
- "requires_confirmation",
36
42
  ]
ragbits/agents/_main.py CHANGED
@@ -1,6 +1,4 @@
1
1
  import asyncio
2
- import hashlib
3
- import json
4
2
  import types
5
3
  import uuid
6
4
  from collections.abc import AsyncGenerator, AsyncIterator, Callable
@@ -32,6 +30,10 @@ from ragbits.agents.exceptions import (
32
30
  AgentToolNotAvailableError,
33
31
  AgentToolNotSupportedError,
34
32
  )
33
+ from ragbits.agents.hooks import (
34
+ Hook,
35
+ HookManager,
36
+ )
35
37
  from ragbits.agents.mcp.server import MCPServer, MCPServerStdio, MCPServerStreamableHttp
36
38
  from ragbits.agents.mcp.utils import get_tools
37
39
  from ragbits.agents.post_processors.base import (
@@ -40,7 +42,7 @@ from ragbits.agents.post_processors.base import (
40
42
  StreamingPostProcessor,
41
43
  stream_with_post_processing,
42
44
  )
43
- from ragbits.agents.tool import Tool, ToolCallResult, ToolChoice
45
+ from ragbits.agents.tool import Tool, ToolCallResult, ToolChoice, ToolReturn
44
46
  from ragbits.core.audit.traces import trace
45
47
  from ragbits.core.llms.base import (
46
48
  LLM,
@@ -65,10 +67,6 @@ with suppress(ImportError):
65
67
 
66
68
  from ragbits.core.llms import LiteLLM
67
69
 
68
- # Confirmation ID length: 16 hex chars provides sufficient uniqueness
69
- # while being compact for display and storage
70
- CONFIRMATION_ID_LENGTH = 16
71
-
72
70
  _Input = TypeVar("_Input", bound=BaseModel)
73
71
  _Output = TypeVar("_Output")
74
72
 
@@ -212,9 +210,10 @@ class AgentRunContext(BaseModel, Generic[DepsT]):
212
210
  """Whether to stream events from downstream agents when tools execute other agents."""
213
211
  downstream_agents: dict[str, "Agent"] = Field(default_factory=dict)
214
212
  """Registry of all agents that participated in this run"""
215
- confirmed_tools: list[dict[str, Any]] | None = Field(
216
- default=None,
217
- description="List of confirmed/declined tools from the frontend",
213
+ tool_confirmations: list[dict[str, Any]] = Field(
214
+ default_factory=list,
215
+ description="List of confirmed/declined tool executions. Each entry has 'confirmation_id' and 'confirmed' "
216
+ "(bool)",
218
217
  )
219
218
 
220
219
  def register_agent(self, agent: "Agent") -> None:
@@ -375,6 +374,7 @@ class Agent(
375
374
  keep_history: bool = False,
376
375
  tools: list[Callable] | None = None,
377
376
  mcp_servers: list[MCPServer] | None = None,
377
+ hooks: list[Hook] | None = None,
378
378
  default_options: AgentOptions[LLMClientOptionsT] | None = None,
379
379
  ) -> None:
380
380
  """
@@ -394,6 +394,7 @@ class Agent(
394
394
  keep_history: Whether to keep the history of the agent.
395
395
  tools: The tools available to the agent.
396
396
  mcp_servers: The MCP servers available to the agent.
397
+ hooks: List of tool hooks to register for tool lifecycle events.
397
398
  default_options: The default options for the agent run.
398
399
  """
399
400
  super().__init__(default_options)
@@ -416,6 +417,7 @@ class Agent(
416
417
  self.mcp_servers = mcp_servers or []
417
418
  self.history = history or []
418
419
  self.keep_history = keep_history
420
+ self.hook_manager = HookManager(hooks)
419
421
 
420
422
  if getattr(self, "system_prompt", None) and not getattr(self, "input_type", None):
421
423
  raise ValueError(
@@ -536,7 +538,9 @@ class Agent(
536
538
  ):
537
539
  if isinstance(result, ToolCallResult):
538
540
  tool_calls.append(result)
539
- prompt_with_history = prompt_with_history.add_tool_use_message(**result.__dict__)
541
+ prompt_with_history = prompt_with_history.add_tool_use_message(
542
+ id=result.id, name=result.name, arguments=result.arguments, result=result.result
543
+ )
540
544
 
541
545
  turn_count += 1
542
546
  else:
@@ -762,7 +766,9 @@ class Agent(
762
766
  elif isinstance(result, ToolCallResult):
763
767
  # Add ALL tool results to history (including pending confirmations)
764
768
  # This allows the agent to see them in the next turn
765
- prompt_with_history = prompt_with_history.add_tool_use_message(**result.__dict__)
769
+ prompt_with_history = prompt_with_history.add_tool_use_message(
770
+ id=result.id, name=result.name, arguments=result.arguments, result=result.result
771
+ )
766
772
  returned_tool_call = True
767
773
 
768
774
  # If we have pending confirmations, prepare for text-only summary generation
@@ -958,54 +964,37 @@ class Agent(
958
964
 
959
965
  tool = tools_mapping[tool_call.name]
960
966
 
961
- # Check if tool requires confirmation
962
- if tool.requires_confirmation:
963
- # Check if this tool has been confirmed in the context
964
- confirmed_tools = context.confirmed_tools or []
965
-
966
- # Generate a stable confirmation ID based on tool name and arguments
967
- confirmation_id = hashlib.sha256(
968
- f"{tool_call.name}:{json.dumps(tool_call.arguments, sort_keys=True)}".encode()
969
- ).hexdigest()[:CONFIRMATION_ID_LENGTH]
967
+ # Execute PRE_TOOL hooks with chaining
968
+ pre_tool_result = await self.hook_manager.execute_pre_tool(
969
+ tool_call=tool_call,
970
+ context=context,
971
+ )
970
972
 
971
- # Check if this specific tool call has been confirmed or declined
972
- is_confirmed = any(
973
- ct.get("confirmation_id") == confirmation_id and ct.get("confirmed") for ct in confirmed_tools
973
+ # Check decision
974
+ if pre_tool_result.decision == "deny":
975
+ yield ToolCallResult(
976
+ id=tool_call.id,
977
+ name=tool_call.name,
978
+ arguments=tool_call.arguments,
979
+ result=pre_tool_result.reason or "Tool execution denied",
974
980
  )
975
- is_declined = any(
976
- ct.get("confirmation_id") == confirmation_id and not ct.get("confirmed", True) for ct in confirmed_tools
981
+ return
982
+ # Handle "ask" decision from hooks
983
+ elif pre_tool_result.decision == "ask" and pre_tool_result.confirmation_request is not None:
984
+ yield pre_tool_result.confirmation_request
985
+
986
+ yield ToolCallResult(
987
+ id=tool_call.id,
988
+ name=tool_call.name,
989
+ arguments=tool_call.arguments,
990
+ result=pre_tool_result.reason or "Hook requires user confirmation",
977
991
  )
992
+ return
978
993
 
979
- if is_declined:
980
- # Tool was explicitly declined - skip execution entirely
981
- yield ToolCallResult(
982
- id=tool_call.id,
983
- name=tool_call.name,
984
- arguments=tool_call.arguments,
985
- result="❌ Action declined by user",
986
- )
987
- return
988
-
989
- if not is_confirmed:
990
- # Tool not confirmed yet - create and yield confirmation request
991
- request = ConfirmationRequest(
992
- confirmation_id=confirmation_id,
993
- tool_name=tool_call.name,
994
- tool_description=tool.description or "",
995
- arguments=tool_call.arguments,
996
- )
997
-
998
- # Yield confirmation request (will be streamed to frontend)
999
- yield request
994
+ # Always update arguments (chained from hooks)
995
+ tool_call.arguments = pre_tool_result.arguments
1000
996
 
1001
- # Yield a pending result and exit without executing
1002
- yield ToolCallResult(
1003
- id=tool_call.id,
1004
- name=tool_call.name,
1005
- arguments=tool_call.arguments,
1006
- result="⏳ Awaiting user confirmation",
1007
- )
1008
- return
997
+ tool_error: Exception | None = None
1009
998
 
1010
999
  with trace(agent_id=self.id, tool_name=tool_call.name, tool_arguments=tool_call.arguments) as outputs:
1011
1000
  try:
@@ -1019,35 +1008,50 @@ class Agent(
1019
1008
  else asyncio.to_thread(tool.on_tool_call, **call_args)
1020
1009
  )
1021
1010
 
1022
- if isinstance(tool_output, AgentResultStreaming):
1011
+ if isinstance(tool_output, ToolReturn):
1012
+ tool_return = tool_output
1013
+ elif isinstance(tool_output, AgentResultStreaming):
1023
1014
  async for downstream_item in tool_output:
1024
1015
  if context.stream_downstream_events:
1025
1016
  yield DownstreamAgentResult(agent_id=tool.id, item=downstream_item)
1026
-
1027
- tool_output = {
1028
- "content": tool_output.content,
1017
+ metadata = {
1029
1018
  "metadata": tool_output.metadata,
1030
1019
  "tool_calls": tool_output.tool_calls,
1031
1020
  "usage": tool_output.usage,
1032
1021
  }
1022
+ tool_return = ToolReturn(value=tool_output.content, metadata=metadata)
1023
+ else:
1024
+ tool_return = ToolReturn(value=tool_output, metadata=None)
1033
1025
 
1034
1026
  outputs.result = {
1035
- "tool_output": tool_output,
1027
+ "tool_output": tool_return.value,
1036
1028
  "tool_call_id": tool_call.id,
1037
1029
  }
1038
1030
 
1039
1031
  except Exception as e:
1032
+ tool_error = e
1040
1033
  outputs.result = {
1041
1034
  "error": str(e),
1042
1035
  "tool_call_id": tool_call.id,
1043
1036
  }
1044
- raise AgentToolExecutionError(tool_call.name, e) from e
1037
+
1038
+ # Execute POST_TOOL hooks with chaining
1039
+ post_tool_output = await self.hook_manager.execute_post_tool(
1040
+ tool_call=tool_call,
1041
+ tool_return=tool_return,
1042
+ error=tool_error,
1043
+ )
1044
+
1045
+ # Raise error after hooks have been executed
1046
+ if tool_error:
1047
+ raise AgentToolExecutionError(tool_call.name, tool_error) from tool_error
1045
1048
 
1046
1049
  yield ToolCallResult(
1047
1050
  id=tool_call.id,
1048
1051
  name=tool_call.name,
1049
1052
  arguments=tool_call.arguments,
1050
- result=tool_output,
1053
+ result=post_tool_output.tool_return.value if post_tool_output.tool_return else None,
1054
+ metadata=post_tool_output.tool_return.metadata if post_tool_output.tool_return else None,
1051
1055
  )
1052
1056
 
1053
1057
  @requires_dependencies(["a2a.types"], "a2a")
@@ -0,0 +1,77 @@
1
+ """
2
+ Hooks system for lifecycle events.
3
+
4
+ This module provides a comprehensive hook system that allows users to register
5
+ custom logic at various points in the execution lifecycle.
6
+
7
+ Available event types:
8
+ - PRE_TOOL: Before a tool is invoked
9
+ - POST_TOOL: After a tool completes
10
+
11
+ Example usage:
12
+
13
+ from ragbits.agents.hooks import (
14
+ EventType,
15
+ Hook,
16
+ PreToolInput,
17
+ PreToolOutput,
18
+ )
19
+
20
+ # Create a pre-tool hook callback
21
+ async def validate_input(input_data: PreToolInput) -> PreToolOutput:
22
+ if input_data.tool_call.name == "dangerous_tool":
23
+ return PreToolOutput(
24
+ arguments=input_data.tool_call.arguments,
25
+ decision="deny",
26
+ reason="This tool is not allowed"
27
+ )
28
+ return PreToolOutput(arguments=input_data.tool_call.arguments, decision="pass")
29
+
30
+ # Create hook instance with proper type annotation
31
+ hook: Hook[PreToolInput, PreToolOutput] = Hook(
32
+ event_type=EventType.PRE_TOOL,
33
+ callback=validate_input,
34
+ tool_names=["dangerous_tool"],
35
+ priority=10
36
+ )
37
+
38
+ # Register hooks with agent
39
+ agent = Agent(
40
+ ...,
41
+ hooks=[hook]
42
+ )
43
+ """
44
+
45
+ from ragbits.agents.hooks.base import Hook, HookInputT, HookOutputT
46
+ from ragbits.agents.hooks.confirmation import create_confirmation_hook
47
+ from ragbits.agents.hooks.manager import HookManager
48
+ from ragbits.agents.hooks.types import (
49
+ EventType,
50
+ PostToolHookCallback,
51
+ PostToolInput,
52
+ PostToolOutput,
53
+ PreToolHookCallback,
54
+ PreToolInput,
55
+ PreToolOutput,
56
+ )
57
+
58
# Public API of the hooks package, sorted alphabetically; the category of each
# export is noted inline because sorting interleaves the categories.
__all__ = [
    "EventType",  # event type enum
    "Hook",  # core class
    "HookInputT",  # type variable
    "HookManager",  # core class
    "HookOutputT",  # type variable
    "PostToolHookCallback",  # callback type alias
    "PostToolInput",  # hook input model
    "PostToolOutput",  # hook output model
    "PreToolHookCallback",  # callback type alias
    "PreToolInput",  # hook input model
    "PreToolOutput",  # hook output model
    "create_confirmation_hook",  # hook factory
]
@@ -0,0 +1,94 @@
1
+ """
2
+ Base classes for the hooks system.
3
+ """
4
+
5
+ from collections.abc import Awaitable, Callable
6
+ from typing import Generic, TypeVar
7
+
8
+ from ragbits.agents.hooks.types import EventType, HookEventIO
9
+
10
# Type parameters of Hook: the payload handed to its callback (input) and the
# payload the callback produces (output). Both must be HookEventIO subclasses.
HookInputT = TypeVar("HookInputT", bound=HookEventIO)
HookOutputT = TypeVar("HookOutputT", bound=HookEventIO)
12
+
13
+
14
class Hook(Generic[HookInputT, HookOutputT]):
    """
    Callback attached to one lifecycle event of tool execution.

    What a hook can do depends on the event it is bound to:
    - pre-tool hooks can validate or rewrite the arguments of a pending call,
      control access to it, or deny it outright;
    - post-tool hooks can inspect or replace the tool's output and react to an
      error raised by the tool.

    Attributes:
        event_type: Lifecycle event the hook is bound to (e.g. PRE_TOOL, POST_TOOL).
        callback: Async function awaited with the event input whenever the event fires.
        tool_names: Tools the hook is limited to. ``None`` means it applies to every tool.
        priority: Ordering among hooks of the same event; lower numbers run first (default: 100).

    Example:
        ```python
        from ragbits.agents.hooks import Hook, EventType, PreToolInput, PreToolOutput


        async def validate_input(input_data: PreToolInput) -> PreToolOutput:
            if input_data.tool_call.name == "dangerous_tool":
                return PreToolOutput(arguments=input_data.tool_call.arguments, decision="deny", reason="Not allowed")
            return PreToolOutput(arguments=input_data.tool_call.arguments, decision="pass")


        hook: Hook[PreToolInput, PreToolOutput] = Hook(
            event_type=EventType.PRE_TOOL, callback=validate_input, tool_names=["dangerous_tool"], priority=10
        )
        ```
    """

    def __init__(
        self,
        event_type: EventType,
        callback: Callable[[HookInputT], Awaitable[HookOutputT]],
        tool_names: list[str] | None = None,
        priority: int = 100,
    ) -> None:
        """
        Store the hook configuration.

        Args:
            event_type: Lifecycle event the hook is bound to (e.g. PRE_TOOL, POST_TOOL).
            callback: Async function awaited with the event input.
            tool_names: Tools the hook is limited to; ``None`` for all tools.
            priority: Ordering among hooks of the same event; lower numbers run first (default: 100).
        """
        self.event_type = event_type
        self.callback = callback
        self.tool_names = tool_names
        self.priority = priority

    def matches_tool(self, tool_name: str) -> bool:
        """
        Tell whether this hook should fire for the given tool.

        Args:
            tool_name: Name of the tool being processed.

        Returns:
            True when the hook is unrestricted (``tool_names`` is None) or lists ``tool_name``.
        """
        return self.tool_names is None or tool_name in self.tool_names

    async def execute(self, hook_input: HookInputT) -> HookOutputT:
        """
        Await the callback on ``hook_input`` and hand back whatever it produces.

        Args:
            hook_input: Event payload forwarded unchanged to the callback.

        Returns:
            The callback's result.
        """
        output = await self.callback(hook_input)
        return output
@@ -0,0 +1,51 @@
1
+ """
2
+ Helper functions for creating common hooks.
3
+
4
+ This module provides factory functions for creating commonly used hooks.
5
+ """
6
+
7
+ from ragbits.agents.hooks.base import Hook
8
+ from ragbits.agents.hooks.types import EventType, PreToolInput, PreToolOutput
9
+
10
+
11
def create_confirmation_hook(
    tool_names: list[str] | None = None, priority: int = 1
) -> Hook[PreToolInput, PreToolOutput]:
    """
    Build a pre-tool hook that makes the agent ask the user before running a tool.

    The hook always answers with the "ask" decision. The agent then emits a
    ConfirmationRequest and waits for the user to approve or decline the call.

    Args:
        tool_names: Tools that need confirmation. ``None`` means every tool.
        priority: Hook priority; the default of 1 makes it run before hooks with higher values.

    Returns:
        A PRE_TOOL hook that requires confirmation.

    Example:
        ```python
        from ragbits.agents import Agent
        from ragbits.agents.hooks.confirmation import create_confirmation_hook

        agent = Agent(
            tools=[delete_file, send_email], hooks=[create_confirmation_hook(tool_names=["delete_file", "send_email"])]
        )
        ```
    """

    # The function name is kept as-is on purpose: HookManager derives confirmation
    # IDs from ``hook.callback.__name__``, so renaming it would change those IDs.
    async def confirm_hook(input_data: PreToolInput) -> PreToolOutput:
        """Request confirmation for every call that reaches this hook."""
        call = input_data.tool_call
        return PreToolOutput(
            arguments=call.arguments,
            decision="ask",
            reason=f"Tool '{call.name}' requires user confirmation",
        )

    confirmation_hook: Hook[PreToolInput, PreToolOutput] = Hook(
        event_type=EventType.PRE_TOOL,
        callback=confirm_hook,
        tool_names=tool_names,
        priority=priority,
    )
    return confirmation_hook
@@ -0,0 +1,195 @@
1
+ """
2
+ Hook manager for organizing and executing hooks.
3
+
4
+ This module provides the HookManager class which handles registration,
5
+ organization, and execution of hooks during lifecycle events.
6
+ """
7
+
8
+ import hashlib
9
+ import json
10
+ from collections import defaultdict
11
+ from typing import TYPE_CHECKING
12
+
13
+ from ragbits.agents.confirmation import ConfirmationRequest
14
+ from ragbits.agents.hooks.base import Hook
15
+ from ragbits.agents.hooks.types import EventType, PostToolInput, PostToolOutput, PreToolInput, PreToolOutput
16
+ from ragbits.agents.tool import ToolReturn
17
+ from ragbits.core.llms.base import ToolCall
18
+
19
+ if TYPE_CHECKING:
20
+ from ragbits.agents._main import AgentRunContext
21
+
22
# How many hex characters of the SHA-256 digest make up a confirmation ID:
# compact enough to display and store, long enough to stay unique in practice.
CONFIRMATION_ID_LENGTH = 16
25
+
26
+
27
class HookManager:
    """
    Manages registration and execution of hooks for an agent.

    Hooks are bucketed by event type and kept sorted by priority (lower numbers
    run first). Execution is chained: each pre-tool hook receives the arguments
    produced by the previous one, and each post-tool hook receives the tool
    return produced by the previous one.
    """

    def __init__(self, hooks: list[Hook] | None = None) -> None:
        """
        Initialize the hook manager.

        Args:
            hooks: Initial list of hooks to register. ``None`` or an empty list registers nothing.
        """
        self._hooks: dict[EventType, list[Hook]] = defaultdict(list)

        for hook in hooks or []:
            self.register(hook)

    def register(self, hook: Hook) -> None:
        """
        Register a hook under its event type.

        The per-event list is re-sorted by priority after each insertion
        (lower numbers execute first). ``list.sort`` is stable, so hooks that
        share a priority keep their registration order.

        Args:
            hook: The hook to register.
        """
        bucket = self._hooks[hook.event_type]
        bucket.append(hook)
        bucket.sort(key=lambda h: h.priority)

    def get_hooks(self, event_type: EventType, tool_name: str | None) -> list[Hook]:
        """
        Get all hooks for a specific event type that match the tool name.

        Args:
            event_type: The type of event.
            tool_name: Optional tool name to filter hooks. If None, returns all hooks for the event type.

        Returns:
            A new list of matching hooks, sorted by priority. The internal registry list is never
            returned, so callers cannot corrupt it by mutating the result.
        """
        hooks = self._hooks.get(event_type, [])

        if tool_name is None:
            return list(hooks)

        return [hook for hook in hooks if hook.matches_tool(tool_name)]

    @staticmethod
    def _confirmation_id(hook: Hook, tool_name: str, arguments: dict) -> str:
        """
        Build the ID under which the user's answer for this hook/tool/arguments triple is stored.

        The ID is the first CONFIRMATION_ID_LENGTH hex characters of
        SHA-256(``"<callback name>:<tool name>:<arguments as sorted-key JSON>"``). For plain
        function callbacks and JSON-serializable arguments, ``getattr`` and ``default=str`` do
        not alter the hashed string.

        Args:
            hook: Hook that requested confirmation.
            tool_name: Name of the tool being called.
            arguments: Arguments the hook received.

        Returns:
            The confirmation ID as a lowercase hex string.
        """
        # Callables such as functools.partial objects have no __name__; fall back to their type name.
        hook_name = getattr(hook.callback, "__name__", type(hook.callback).__name__)
        # default=str: an earlier hook may have put non-JSON-serializable values into the arguments.
        serialized_arguments = json.dumps(arguments, sort_keys=True, default=str)
        raw_id = f"{hook_name}:{tool_name}:{serialized_arguments}"
        return hashlib.sha256(raw_id.encode()).hexdigest()[:CONFIRMATION_ID_LENGTH]

    @staticmethod
    def _find_confirmation(context: "AgentRunContext", confirmation_id: str) -> bool | None:
        """
        Look up the user's answer for a confirmation ID.

        Args:
            context: Run context whose ``tool_confirmations`` entries are dicts with
                ``confirmation_id`` and ``confirmed`` keys.
            confirmation_id: ID to look for.

        Returns:
            True if the first matching entry has a truthy ``confirmed``, False if it does not,
            None if no entry carries this ID.
        """
        for entry in context.tool_confirmations:
            if entry.get("confirmation_id") == confirmation_id:
                return bool(entry.get("confirmed"))
        return None

    async def execute_pre_tool(
        self,
        tool_call: ToolCall,
        context: "AgentRunContext",
    ) -> PreToolOutput:
        """
        Execute pre-tool hooks with proper chaining.

        Each hook sees the arguments produced by the previous hook. Execution
        stops at the first hook that returns "deny". It also stops at the first
        "ask" that has no answer recorded in ``context.tool_confirmations``.
        An approved "ask" behaves like "pass"; a declined one becomes "deny".

        Args:
            tool_call: The tool call to process.
            context: Agent run context containing tool_confirmations.

        Returns:
            PreToolOutput with final arguments and decision. For an unanswered "ask" it also
            carries the ConfirmationRequest to forward to the user.
        """
        current_arguments = tool_call.arguments

        for hook in self.get_hooks(EventType.PRE_TOOL, tool_call.name):
            # Computed from the arguments this hook receives, i.e. the ones shown to the user.
            confirmation_id = self._confirmation_id(hook, tool_call.name, current_arguments)

            # Create input with current state (chained from previous hook)
            hook_input = PreToolInput(
                tool_call=tool_call.model_copy(update={"arguments": current_arguments}),
            )
            result: PreToolOutput = await hook.execute(hook_input)

            if result.decision == "deny":
                return PreToolOutput(
                    arguments=current_arguments,
                    decision="deny",
                    reason=result.reason,
                )

            if result.decision == "ask":
                confirmed = self._find_confirmation(context, confirmation_id)

                if confirmed is None:
                    # No answer yet -> ask the user with a full ConfirmationRequest.
                    return PreToolOutput(
                        arguments=current_arguments,
                        decision="ask",
                        reason=result.reason,
                        confirmation_request=ConfirmationRequest(
                            confirmation_id=confirmation_id,
                            tool_name=tool_call.name,
                            tool_description=result.reason or "Hook requires user confirmation",
                            arguments=current_arguments,
                        ),
                    )

                if not confirmed:
                    # Declined -> deny. A dedicated message is used because the hook's own reason
                    # (always set for "ask") describes a pending request. Reusing it would make a
                    # declined call indistinguishable from one still awaiting confirmation.
                    return PreToolOutput(
                        arguments=current_arguments,
                        decision="deny",
                        reason="Tool execution declined by user",
                    )

                # Approved -> continue as "pass" with the arguments the user approved.
                result = PreToolOutput(arguments=current_arguments, decision="pass")

            # Chain arguments for next hook
            current_arguments = result.arguments

        # All hooks passed
        return PreToolOutput(arguments=current_arguments, decision="pass")

    async def execute_post_tool(
        self,
        tool_call: ToolCall,
        tool_return: ToolReturn | None,
        error: Exception | None,
    ) -> PostToolOutput:
        """
        Execute post-tool hooks with proper output chaining.

        Each hook sees the output produced by the previous hook; every hook receives the
        same ``error``.

        Args:
            tool_call: The tool call that was executed.
            tool_return: Object representing the output of the tool (with value passed to the LLM and metadata).
            error: Any error that occurred.

        Returns:
            PostToolOutput with final output.
        """
        current_output = tool_return

        for hook in self.get_hooks(EventType.POST_TOOL, tool_call.name):
            hook_input = PostToolInput(
                tool_call=tool_call,
                tool_return=current_output,
                error=error,
            )
            result: PostToolOutput = await hook.execute(hook_input)
            # Chain output for next hook
            current_output = result.tool_return

        return PostToolOutput(tool_return=current_output)
@@ -0,0 +1,142 @@
1
+ """
2
+ Type definitions for the hooks system.
3
+
4
+ This module contains all type definitions including EventType, callback types,
5
+ input types, and output types for the hooks system.
6
+ """
7
+
8
+ from collections.abc import Awaitable, Callable
9
+ from enum import Enum
10
+ from typing import Any, Literal, TypeAlias
11
+
12
+ from pydantic import BaseModel, Field, model_validator
13
+
14
+ from ragbits.agents.confirmation import ConfirmationRequest
15
+ from ragbits.agents.tool import ToolReturn
16
+ from ragbits.core.llms.base import ToolCall
17
+
18
+
19
class EventType(str, Enum):
    """
    Lifecycle events to which a hook can be attached.

    Attributes:
        PRE_TOOL: Fired right before a tool is invoked.
        POST_TOOL: Fired once a tool has finished running.
    """

    PRE_TOOL = "pre_tool"  # before invocation
    POST_TOOL = "post_tool"  # after completion
30
+
31
+
32
class HookEventIO(BaseModel):
    """
    Common base for every hook input and output model.

    Holds the one attribute all hook payloads share.

    Attributes:
        event_type: Event this payload belongs to.
    """

    # Lets subclasses declare fields of non-pydantic types (e.g. ``error: Exception``).
    model_config = {"arbitrary_types_allowed": True}

    event_type: EventType
45
+
46
+
47
class PreToolInput(HookEventIO):
    """
    Payload handed to pre-tool hook callbacks before the tool runs.

    With it, a hook can:
    - look at the pending tool call
    - rewrite the tool arguments
    - stop the call from executing

    Attributes:
        event_type: Fixed to EventType.PRE_TOOL; the field is frozen.
        tool_call: The pending tool call in full (name, arguments, id, type).
    """

    event_type: Literal[EventType.PRE_TOOL] = Field(default=EventType.PRE_TOOL, frozen=True)
    tool_call: ToolCall
63
+
64
+
65
class PostToolInput(HookEventIO):
    """
    Payload handed to post-tool hook callbacks once the tool has run.

    With it, a hook can:
    - look at what the tool returned
    - replace the tool output
    - react to an error raised by the tool

    Attributes:
        event_type: Fixed to EventType.POST_TOOL; the field is frozen.
        tool_call: The tool call as it was originally issued.
        tool_return: What the tool returned; None when an error occurred.
        error: Exception raised while running the tool; None on success.
    """

    event_type: Literal[EventType.POST_TOOL] = Field(default=EventType.POST_TOOL, frozen=True)
    tool_call: ToolCall
    tool_return: ToolReturn | None = None
    error: Exception | None = None
85
+
86
+
87
class PreToolOutput(HookEventIO):
    """
    Result a pre-tool hook callback returns to steer tool execution.

    ``arguments`` is always populated, holding either the untouched arguments
    or a modified version of them.

    Attributes:
        event_type: Fixed to EventType.PRE_TOOL; the field is frozen.
        arguments: Arguments the tool should run with (original or modified); always present.
        decision: "pass" to continue, "ask" to request user confirmation, "deny" to block the call.
        reason: Why the decision was made; mandatory for "ask" and "deny", may be None for "pass".
        confirmation_request: Confirmation request filled in by HookManager when the decision is "ask".
    """

    event_type: Literal[EventType.PRE_TOOL] = Field(default=EventType.PRE_TOOL, frozen=True)  # type: ignore[assignment]
    arguments: dict[str, Any]
    decision: Literal["pass", "ask", "deny"] = "pass"
    reason: str | None = None
    confirmation_request: ConfirmationRequest | None = None

    @model_validator(mode="after")
    def validate_reason(self) -> "PreToolOutput":
        """Reject "ask"/"deny" outputs that do not explain themselves through ``reason``."""
        reason_required = self.decision in {"ask", "deny"}
        if reason_required and not self.reason:
            raise ValueError(f"reason is required when decision='{self.decision}'")
        return self
114
+
115
+
116
class PostToolOutput(HookEventIO):
    """
    Result a post-tool hook callback returns.

    It always carries the tool output to use from this point on: either the
    original output or a replacement.

    Attributes:
        event_type: Fixed to EventType.POST_TOOL; the field is frozen.
        tool_return: Tool output to use (original or modified); None if the tool execution failed.

    Example:
        ```python
        # Pass through unchanged
        return PostToolOutput(tool_return=input.tool_return)

        # Modify output
        return PostToolOutput(tool_return=ToolReturn(value={"filtered": data}))
        ```
    """

    event_type: Literal[EventType.POST_TOOL] = Field(default=EventType.POST_TOOL, frozen=True)  # type: ignore[assignment]
    tool_return: ToolReturn | None
138
+
139
+
140
+ # Type aliases for hook callbacks
141
+ PreToolHookCallback: TypeAlias = Callable[["PreToolInput"], Awaitable["PreToolOutput"]]
142
+ PostToolHookCallback: TypeAlias = Callable[["PostToolInput"], Awaitable["PostToolOutput"]]
ragbits/agents/tool.py CHANGED
@@ -1,11 +1,10 @@
1
1
  from collections.abc import Callable
2
2
  from contextlib import suppress
3
3
  from dataclasses import dataclass
4
- from functools import wraps
5
- from typing import TYPE_CHECKING, Any, Literal, TypeVar, cast
4
+ from typing import TYPE_CHECKING, Any, Literal, cast
6
5
 
7
6
  from pydantic import BaseModel
8
- from typing_extensions import ParamSpec, Self
7
+ from typing_extensions import Self
9
8
 
10
9
  from ragbits.core.llms.base import LLMClientOptionsT
11
10
  from ragbits.core.prompt.prompt import PromptInputT, PromptOutputT
@@ -18,45 +17,18 @@ if TYPE_CHECKING:
18
17
  with suppress(ImportError):
19
18
  from pydantic_ai import Tool as PydanticAITool
20
19
 
21
- P = ParamSpec("P")
22
- T = TypeVar("T")
23
20
 
24
-
25
- def requires_confirmation(func: Callable[P, T]) -> Callable[P, T]:
21
@dataclass
class ToolReturn:
    """
    Tool result that separates what the LLM sees from data kept for the application.

    A tool that needs to hide part of its output from the LLM must return an
    instance of this class directly instead of a plain value.
    """

    # Handed to the LLM verbatim as the tool's result.
    value: Any
    # Never shown to the LLM; available to the application later on.
    metadata: Any = None
60
32
 
61
33
 
62
34
  @dataclass
@@ -72,7 +44,9 @@ class ToolCallResult:
72
44
  arguments: dict[str, Any]
73
45
  """Dictionary containing the arguments passed to the tool"""
74
46
  result: Any
75
- """The output from the tool call."""
47
+ """The output from the tool call passed to the LLM"""
48
+ metadata: Any = None
49
+ """Metadata returned from a tool that is not meant to be seen by the LLM"""
76
50
 
77
51
 
78
52
  @dataclass
@@ -92,35 +66,26 @@ class Tool:
92
66
  context_var_name: str | None = None
93
67
  """The name of the context variable that this tool accepts."""
94
68
  id: str | None = None
95
- requires_confirmation: bool = False
96
- """Whether this tool requires user confirmation before execution."""
97
69
 
98
70
  @classmethod
99
- def from_callable(cls, callable: Callable, requires_confirmation: bool = False) -> Self:
71
+ def from_callable(cls, callable: Callable) -> Self:
100
72
  """
101
73
  Create a Tool instance from a callable function.
102
74
 
103
75
  Args:
104
76
  callable: The function to convert into a Tool
105
- requires_confirmation: Whether this tool requires user confirmation before execution.
106
- If not provided, checks if the callable has been decorated with @requires_confirmation.
107
77
 
108
78
  Returns:
109
79
  A new Tool instance representing the callable function.
110
80
  """
111
81
  schema = convert_function_to_function_schema(callable)
112
82
 
113
- # Check if the callable has been decorated with @requires_confirmation
114
- # Priority: explicit parameter > decorator attribute
115
- needs_confirmation = requires_confirmation or getattr(callable, "_requires_confirmation", False)
116
-
117
83
  return cls(
118
84
  name=schema["function"]["name"],
119
85
  description=schema["function"]["description"],
120
86
  parameters=schema["function"]["parameters"],
121
87
  on_tool_call=callable,
122
88
  context_var_name=get_context_variable_name(callable),
123
- requires_confirmation=needs_confirmation,
124
89
  )
125
90
 
126
91
  def to_function_schema(self) -> dict[str, Any]:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: ragbits-agents
3
- Version: 1.4.0.dev202601300258
3
+ Version: 1.4.0.dev202602030301
4
4
  Summary: Building blocks for rapid development of GenAI applications
5
5
  Project-URL: Homepage, https://github.com/deepsense-ai/ragbits
6
6
  Project-URL: Bug Reports, https://github.com/deepsense-ai/ragbits/issues
@@ -22,7 +22,7 @@ Classifier: Programming Language :: Python :: 3.13
22
22
  Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
23
23
  Classifier: Topic :: Software Development :: Libraries :: Python Modules
24
24
  Requires-Python: >=3.10
25
- Requires-Dist: ragbits-core==1.4.0.dev202601300258
25
+ Requires-Dist: ragbits-core==1.4.0.dev202602030301
26
26
  Provides-Extra: a2a
27
27
  Requires-Dist: a2a-sdk<1.0.0,>=0.2.9; extra == 'a2a'
28
28
  Requires-Dist: fastapi<1.0.0,>=0.115.0; extra == 'a2a'
@@ -1,13 +1,18 @@
1
- ragbits/agents/__init__.py,sha256=SuM-RMiS_EkMsOgVXwhNtfZKSoKiA5B4QmaaOYWw7a4,996
2
- ragbits/agents/_main.py,sha256=5AGHPh5_pRwh0mv8Hwx9xuj8vs5edu5NURkONL_fmSc,51904
1
+ ragbits/agents/__init__.py,sha256=rKqDbbppy70HmDN6AtC3eXzyVWTTagkRpLvPjWulGuY,1040
2
+ ragbits/agents/_main.py,sha256=TF0qvhLQ3LS2AmjQ0xjoXl8W47vnFjV1-aypnCEsYPk,52090
3
3
  ragbits/agents/cli.py,sha256=xUS7k8IAn0479n165i4YVFKo4Jx0M2iCpaMILZi9xt8,17649
4
4
  ragbits/agents/confirmation.py,sha256=cwdd1feSSobxa7gxBvEZcL9e3tcLdc8CyfvvQwaHF1Y,619
5
5
  ragbits/agents/exceptions.py,sha256=TiompKlP1QRD4TZ7hpp52dyZqg0rjoq1t5QUTxFZud8,3552
6
6
  ragbits/agents/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
7
- ragbits/agents/tool.py,sha256=86a1ilbErXLFLgIQeLu0HBXUy5yUvkL91XXaYqgkZPE,7829
7
+ ragbits/agents/tool.py,sha256=c7TcexLMZSzWGS3nWKp5BX0zAipdkVAuRyvjCL220EE,6501
8
8
  ragbits/agents/types.py,sha256=_dzhn4HzrWuSK1dNVZEpznWTrxKMnKuTJfdUCwJWKuc,869
9
9
  ragbits/agents/a2a/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
10
10
  ragbits/agents/a2a/server.py,sha256=4a-cq87OeMoVOjEM_PbwTSHhPTpJv8r_xOVaZAzFIiI,3031
11
+ ragbits/agents/hooks/__init__.py,sha256=xvWTFLkbM6aRLEz-Ypw9MMDOKVIJqhZzN5XE0jEvy3o,1984
12
+ ragbits/agents/hooks/base.py,sha256=Rl_SUgnI-e8Rgrg7TAlxDsdC2nqtRAOIGpIsPV3Y_Sk,3137
13
+ ragbits/agents/hooks/confirmation.py,sha256=ykSRR2sF8lbxG4pJN5amtT6qstnDoTbrYWZaafoSa9Y,1648
14
+ ragbits/agents/hooks/manager.py,sha256=Exj-Q9Hajf61n_08pDVtI3ouTZ8EphNnSIxM5xx-EOI,7047
15
+ ragbits/agents/hooks/types.py,sha256=zQkLhrctjU8adELYb78ikX3Q0kqPp_LNrAvfJqHI7GU,4550
11
16
  ragbits/agents/mcp/__init__.py,sha256=7icbSe4FCBMgPADNNoeQBKiLY8J6gIlqFS77gFGr6Ks,405
12
17
  ragbits/agents/mcp/server.py,sha256=slob6ns-sbLmG-bQAbwz-23MWsaAENeKhE2eNdoI3Vs,16397
13
18
  ragbits/agents/mcp/utils.py,sha256=o_UXS-olS1uRdfZHTRJdG697Xmq3q77twfcHzmyRTY0,1394
@@ -20,6 +25,6 @@ ragbits/agents/tools/memory.py,sha256=IrRGpZvdylNbn_Z7FFFwun9Rx3OSQimjVynSt3WgUo
20
25
  ragbits/agents/tools/openai.py,sha256=SpoB6my_T6LwfHjrP7ivQL6H4KvwInVampXq-6nGAHE,4955
21
26
  ragbits/agents/tools/todo.py,sha256=R86_Hu0HIl5Ujp8B2ctOo1iSLAHmfrzyu6jnyCLs4uc,18576
22
27
  ragbits/agents/tools/types.py,sha256=6yNG7IjG476muiCIcXKRgDJSemCeO2ZgycpXLyfu8jc,441
23
- ragbits_agents-1.4.0.dev202601300258.dist-info/METADATA,sha256=X5VLysD79iwTOPRO6tHrBobl8xseUgyapwbsuWVPixc,2273
24
- ragbits_agents-1.4.0.dev202601300258.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
25
- ragbits_agents-1.4.0.dev202601300258.dist-info/RECORD,,
28
+ ragbits_agents-1.4.0.dev202602030301.dist-info/METADATA,sha256=6yp4rccNkmzSd4B8SM8Rq8TBFoSilNNEAgXqra9EwuE,2273
29
+ ragbits_agents-1.4.0.dev202602030301.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
30
+ ragbits_agents-1.4.0.dev202602030301.dist-info/RECORD,,