uipath-langchain 0.1.28__py3-none-any.whl → 0.1.34__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33) hide show
  1. uipath_langchain/_utils/_request_mixin.py +8 -0
  2. uipath_langchain/_utils/_settings.py +3 -2
  3. uipath_langchain/agent/guardrails/__init__.py +0 -16
  4. uipath_langchain/agent/guardrails/actions/__init__.py +2 -0
  5. uipath_langchain/agent/guardrails/actions/block_action.py +1 -1
  6. uipath_langchain/agent/guardrails/actions/escalate_action.py +17 -34
  7. uipath_langchain/agent/guardrails/actions/filter_action.py +55 -0
  8. uipath_langchain/agent/guardrails/actions/log_action.py +1 -1
  9. uipath_langchain/agent/guardrails/guardrail_nodes.py +161 -45
  10. uipath_langchain/agent/guardrails/guardrails_factory.py +200 -4
  11. uipath_langchain/agent/guardrails/types.py +0 -12
  12. uipath_langchain/agent/guardrails/utils.py +146 -0
  13. uipath_langchain/agent/react/agent.py +20 -7
  14. uipath_langchain/agent/react/constants.py +1 -2
  15. uipath_langchain/agent/{guardrails → react/guardrails}/guardrails_subgraph.py +57 -18
  16. uipath_langchain/agent/react/llm_node.py +41 -10
  17. uipath_langchain/agent/react/router.py +48 -37
  18. uipath_langchain/agent/react/types.py +15 -1
  19. uipath_langchain/agent/react/utils.py +1 -1
  20. uipath_langchain/agent/tools/__init__.py +2 -0
  21. uipath_langchain/agent/tools/mcp_tool.py +86 -0
  22. uipath_langchain/chat/__init__.py +4 -0
  23. uipath_langchain/chat/bedrock.py +16 -0
  24. uipath_langchain/chat/openai.py +56 -26
  25. uipath_langchain/chat/supported_models.py +9 -0
  26. uipath_langchain/chat/vertex.py +62 -46
  27. uipath_langchain/embeddings/embeddings.py +18 -12
  28. uipath_langchain/runtime/schema.py +72 -16
  29. {uipath_langchain-0.1.28.dist-info → uipath_langchain-0.1.34.dist-info}/METADATA +4 -2
  30. {uipath_langchain-0.1.28.dist-info → uipath_langchain-0.1.34.dist-info}/RECORD +33 -30
  31. {uipath_langchain-0.1.28.dist-info → uipath_langchain-0.1.34.dist-info}/WHEEL +0 -0
  32. {uipath_langchain-0.1.28.dist-info → uipath_langchain-0.1.34.dist-info}/entry_points.txt +0 -0
  33. {uipath_langchain-0.1.28.dist-info → uipath_langchain-0.1.34.dist-info}/licenses/LICENSE +0 -0
@@ -137,6 +137,8 @@ class UiPathRequestMixin(BaseModel):
137
137
  max_tokens: int | None = 1000
138
138
  frequency_penalty: float | None = None
139
139
  presence_penalty: float | None = None
140
+ agenthub_config: str | None = None
141
+ byo_connection_id: str | None = None
140
142
 
141
143
  logger: logging.Logger | None = None
142
144
  max_retries: int | None = 5
@@ -748,6 +750,12 @@ class UiPathRequestMixin(BaseModel):
748
750
  "Authorization": f"Bearer {self.access_token}",
749
751
  "X-UiPath-LlmGateway-TimeoutSeconds": str(self.default_request_timeout),
750
752
  }
753
+ if self.agenthub_config:
754
+ self._auth_headers["X-UiPath-AgentHub-Config"] = self.agenthub_config
755
+ if self.byo_connection_id:
756
+ self._auth_headers["X-UiPath-LlmGateway-ByoIsConnectionId"] = (
757
+ self.byo_connection_id
758
+ )
751
759
  if self.is_normalized and self.model_name:
752
760
  self._auth_headers["X-UiPath-LlmGateway-NormalizedApi-ModelName"] = (
753
761
  self.model_name
@@ -5,6 +5,7 @@ from typing import Any
5
5
  import httpx
6
6
  from pydantic import Field
7
7
  from pydantic_settings import BaseSettings
8
+ from uipath._utils._ssl_context import get_httpx_client_kwargs
8
9
 
9
10
 
10
11
  class UiPathCachedPathsSettings(BaseSettings):
@@ -58,7 +59,7 @@ def get_uipath_token_header(
58
59
  client_secret=settings.client_secret,
59
60
  grant_type="client_credentials",
60
61
  )
61
- with httpx.Client() as client:
62
+ with httpx.Client(**get_httpx_client_kwargs()) as client:
62
63
  res = client.post(url_get_token, data=token_credentials)
63
64
  res_json = res.json()
64
65
  uipath_token_header = res_json.get("access_token")
@@ -79,7 +80,7 @@ async def get_token_header_async(
79
80
  grant_type="client_credentials",
80
81
  )
81
82
 
82
- with httpx.Client() as client:
83
+ with httpx.Client(**get_httpx_client_kwargs()) as client:
83
84
  res_json = client.post(url_get_token, data=token_credentials).json()
84
85
  uipath_token_header = res_json.get("access_token")
85
86
 
@@ -1,21 +1,5 @@
1
- from .guardrail_nodes import (
2
- create_agent_guardrail_node,
3
- create_llm_guardrail_node,
4
- create_tool_guardrail_node,
5
- )
6
1
  from .guardrails_factory import build_guardrails_with_actions
7
- from .guardrails_subgraph import (
8
- create_agent_guardrails_subgraph,
9
- create_llm_guardrails_subgraph,
10
- create_tool_guardrails_subgraph,
11
- )
12
2
 
13
3
  __all__ = [
14
- "create_llm_guardrails_subgraph",
15
- "create_agent_guardrails_subgraph",
16
- "create_tool_guardrails_subgraph",
17
- "create_llm_guardrail_node",
18
- "create_agent_guardrail_node",
19
- "create_tool_guardrail_node",
20
4
  "build_guardrails_with_actions",
21
5
  ]
@@ -1,6 +1,7 @@
1
1
  from .base_action import GuardrailAction
2
2
  from .block_action import BlockAction
3
3
  from .escalate_action import EscalateAction
4
+ from .filter_action import FilterAction
4
5
  from .log_action import LogAction
5
6
 
6
7
  __all__ = [
@@ -8,4 +9,5 @@ __all__ = [
8
9
  "BlockAction",
9
10
  "LogAction",
10
11
  "EscalateAction",
12
+ "FilterAction",
11
13
  ]
@@ -6,7 +6,7 @@ from uipath.runtime.errors import UiPathErrorCategory, UiPathErrorCode
6
6
  from uipath_langchain.agent.guardrails.types import ExecutionStage
7
7
 
8
8
  from ...exceptions import AgentTerminationException
9
- from ..types import AgentGuardrailsGraphState
9
+ from ...react.types import AgentGuardrailsGraphState
10
10
  from .base_action import GuardrailAction, GuardrailActionNode
11
11
 
12
12
 
@@ -4,7 +4,7 @@ import json
4
4
  import re
5
5
  from typing import Any, Dict, Literal
6
6
 
7
- from langchain_core.messages import AIMessage, ToolCall, ToolMessage
7
+ from langchain_core.messages import AIMessage, ToolMessage
8
8
  from langgraph.types import Command, interrupt
9
9
  from uipath.platform.common import CreateEscalation
10
10
  from uipath.platform.guardrails import (
@@ -14,8 +14,9 @@ from uipath.platform.guardrails import (
14
14
  from uipath.runtime.errors import UiPathErrorCode
15
15
 
16
16
  from ...exceptions import AgentTerminationException
17
- from ..guardrail_nodes import _message_text
18
- from ..types import AgentGuardrailsGraphState, ExecutionStage
17
+ from ...react.types import AgentGuardrailsGraphState
18
+ from ..types import ExecutionStage
19
+ from ..utils import _extract_tool_args_from_message, get_message_content
19
20
  from .base_action import GuardrailAction, GuardrailActionNode
20
21
 
21
22
 
@@ -229,7 +230,7 @@ def _process_llm_escalation_response(
229
230
  if content_list:
230
231
  last_message.content = content_list[-1]
231
232
 
232
- return Command[Any](update={"messages": msgs})
233
+ return Command(update={"messages": msgs})
233
234
  except Exception as e:
234
235
  raise AgentTerminationException(
235
236
  code=UiPathErrorCode.EXECUTION_ERROR,
@@ -292,7 +293,11 @@ def _process_tool_escalation_response(
292
293
  if last_message.tool_calls:
293
294
  tool_calls = list(last_message.tool_calls)
294
295
  for tool_call in tool_calls:
295
- call_name = extract_tool_name(tool_call)
296
+ call_name = (
297
+ tool_call.get("name")
298
+ if isinstance(tool_call, dict)
299
+ else getattr(tool_call, "name", None)
300
+ )
296
301
  if call_name == tool_name:
297
302
  # Update args for the matching tool call
298
303
  if isinstance(reviewed_tool_calls_args, dict):
@@ -311,7 +316,7 @@ def _process_tool_escalation_response(
311
316
  if reviewed_outputs_json:
312
317
  last_message.content = reviewed_outputs_json
313
318
 
314
- return Command[Any](update={"messages": msgs})
319
+ return Command(update={"messages": msgs})
315
320
  except Exception as e:
316
321
  raise AgentTerminationException(
317
322
  code=UiPathErrorCode.EXECUTION_ERROR,
@@ -377,7 +382,7 @@ def _extract_llm_escalation_content(
377
382
  if isinstance(last_message, ToolMessage):
378
383
  return last_message.content
379
384
 
380
- content = _message_text(last_message)
385
+ content = get_message_content(last_message)
381
386
  return json.dumps(content) if content else ""
382
387
 
383
388
  # For AI messages, process tool calls if present
@@ -395,14 +400,14 @@ def _extract_llm_escalation_content(
395
400
  ):
396
401
  content_list.append(json.dumps(args["content"]))
397
402
 
398
- message_content = _message_text(last_message)
403
+ message_content = get_message_content(last_message)
399
404
  if message_content:
400
405
  content_list.append(message_content)
401
406
 
402
407
  return json.dumps(content_list)
403
408
 
404
409
  # Fallback for other message types
405
- return _message_text(last_message)
410
+ return get_message_content(last_message)
406
411
 
407
412
 
408
413
  def _extract_agent_escalation_content(
@@ -437,23 +442,9 @@ def _extract_tool_escalation_content(
437
442
  """
438
443
  last_message = state.messages[-1]
439
444
  if execution_stage == ExecutionStage.PRE_EXECUTION:
440
- if not isinstance(last_message, AIMessage):
441
- return ""
442
- if not last_message.tool_calls:
443
- return ""
444
-
445
- # Find the tool call with matching name
446
- for tool_call in last_message.tool_calls:
447
- call_name = extract_tool_name(tool_call)
448
- if call_name == tool_name:
449
- # Extract args from the matching tool call
450
- args = (
451
- tool_call.get("args")
452
- if isinstance(tool_call, dict)
453
- else getattr(tool_call, "args", None)
454
- )
455
- if args is not None:
456
- return json.dumps(args)
445
+ args = _extract_tool_args_from_message(last_message, tool_name)
446
+ if args:
447
+ return json.dumps(args)
457
448
  return ""
458
449
  else:
459
450
  if not isinstance(last_message, ToolMessage):
@@ -461,14 +452,6 @@ def _extract_tool_escalation_content(
461
452
  return last_message.content
462
453
 
463
454
 
464
- def extract_tool_name(tool_call: ToolCall) -> Any | None:
465
- return (
466
- tool_call.get("name")
467
- if isinstance(tool_call, dict)
468
- else getattr(tool_call, "name", None)
469
- )
470
-
471
-
472
455
  def _execution_stage_to_escalation_field(
473
456
  execution_stage: ExecutionStage,
474
457
  ) -> str:
@@ -0,0 +1,55 @@
1
+ import re
2
+ from typing import Any
3
+
4
+ from uipath.platform.guardrails import BaseGuardrail, GuardrailScope
5
+ from uipath.runtime.errors import UiPathErrorCategory, UiPathErrorCode
6
+
7
+ from uipath_langchain.agent.guardrails.types import ExecutionStage
8
+
9
+ from ...exceptions import AgentTerminationException
10
+ from ...react.types import AgentGuardrailsGraphState
11
+ from .base_action import GuardrailAction, GuardrailActionNode
12
+
13
+
14
+ class FilterAction(GuardrailAction):
15
+ """Action that filters inputs/outputs on guardrail failure.
16
+
17
+ For now, filtering is only supported for non-AGENT and non-LLM scopes.
18
+ If invoked for ``GuardrailScope.AGENT`` or ``GuardrailScope.LLM``, this action
19
+ raises an exception to indicate the operation is not supported yet.
20
+ """
21
+
22
+ def action_node(
23
+ self,
24
+ *,
25
+ guardrail: BaseGuardrail,
26
+ scope: GuardrailScope,
27
+ execution_stage: ExecutionStage,
28
+ guarded_component_name: str,
29
+ ) -> GuardrailActionNode:
30
+ """Create a guardrail action node that performs filtering.
31
+
32
+ Args:
33
+ guardrail: The guardrail responsible for the validation.
34
+ scope: The scope in which the guardrail applies.
35
+ execution_stage: Whether this runs before or after execution.
36
+ guarded_component_name: Name of the guarded component.
37
+
38
+ Returns:
39
+ A tuple containing the node name and the async node callable.
40
+ """
41
+ raw_node_name = f"{scope.name}_{execution_stage.name}_{guardrail.name}_filter"
42
+ node_name = re.sub(r"\W+", "_", raw_node_name.lower()).strip("_")
43
+
44
+ async def _node(_state: AgentGuardrailsGraphState) -> dict[str, Any]:
45
+ if scope in (GuardrailScope.AGENT, GuardrailScope.LLM):
46
+ raise AgentTerminationException(
47
+ code=UiPathErrorCode.EXECUTION_ERROR,
48
+ title="Guardrail filter action not supported",
49
+ detail=f"FilterAction is not supported for scope [{scope.name}] at this time.",
50
+ category=UiPathErrorCategory.USER,
51
+ )
52
+ # No-op for other scopes for now.
53
+ return {}
54
+
55
+ return node_name, _node
@@ -6,7 +6,7 @@ from uipath.platform.guardrails import BaseGuardrail, GuardrailScope
6
6
 
7
7
  from uipath_langchain.agent.guardrails.types import ExecutionStage
8
8
 
9
- from ..types import AgentGuardrailsGraphState
9
+ from ...react.types import AgentGuardrailsGraphState
10
10
  from .base_action import GuardrailAction, GuardrailActionNode
11
11
 
12
12
  logger = logging.getLogger(__name__)
@@ -3,25 +3,107 @@ import logging
3
3
  import re
4
4
  from typing import Any, Callable
5
5
 
6
- from langchain_core.messages import AIMessage, AnyMessage, HumanMessage, SystemMessage
7
6
  from langgraph.types import Command
7
+ from uipath.core.guardrails import (
8
+ DeterministicGuardrail,
9
+ DeterministicGuardrailsService,
10
+ )
8
11
  from uipath.platform import UiPath
9
12
  from uipath.platform.guardrails import (
10
13
  BaseGuardrail,
14
+ BuiltInValidatorGuardrail,
11
15
  GuardrailScope,
12
16
  )
17
+ from uipath.runtime.errors import UiPathErrorCode
13
18
 
14
19
  from uipath_langchain.agent.guardrails.types import ExecutionStage
20
+ from uipath_langchain.agent.guardrails.utils import (
21
+ _extract_tool_input_data,
22
+ _extract_tool_output_data,
23
+ get_message_content,
24
+ )
25
+ from uipath_langchain.agent.react.types import AgentGuardrailsGraphState
15
26
 
16
- from .types import AgentGuardrailsGraphState
27
+ from ..exceptions import AgentTerminationException
17
28
 
18
29
  logger = logging.getLogger(__name__)
19
30
 
20
31
 
21
- def _message_text(msg: AnyMessage) -> str:
22
- if isinstance(msg, (HumanMessage, SystemMessage)):
23
- return msg.content if isinstance(msg.content, str) else str(msg.content)
24
- return str(getattr(msg, "content", "")) if hasattr(msg, "content") else ""
32
+ def _evaluate_deterministic_guardrail(
33
+ state: AgentGuardrailsGraphState,
34
+ guardrail: DeterministicGuardrail,
35
+ execution_stage: ExecutionStage,
36
+ input_data_extractor: Callable[[AgentGuardrailsGraphState], dict[str, Any]],
37
+ output_data_extractor: Callable[[AgentGuardrailsGraphState], dict[str, Any]] | None,
38
+ ):
39
+ """Evaluate deterministic guardrail.
40
+
41
+ Args:
42
+ state: The current agent graph state.
43
+ guardrail: The deterministic guardrail to evaluate.
44
+ execution_stage: The execution stage (PRE_EXECUTION or POST_EXECUTION).
45
+ input_data_extractor: Function to extract input data from state.
46
+ output_data_extractor: Function to extract output data from state (optional).
47
+
48
+ Returns:
49
+ The guardrail evaluation result.
50
+ """
51
+ service = DeterministicGuardrailsService()
52
+ input_data = input_data_extractor(state)
53
+
54
+ if execution_stage == ExecutionStage.PRE_EXECUTION:
55
+ return service.evaluate_pre_deterministic_guardrail(
56
+ input_data=input_data, guardrail=guardrail
57
+ )
58
+ else: # POST_EXECUTION
59
+ output_data = output_data_extractor(state) if output_data_extractor else {}
60
+ return service.evaluate_post_deterministic_guardrail(
61
+ input_data=input_data,
62
+ output_data=output_data,
63
+ guardrail=guardrail,
64
+ )
65
+
66
+
67
+ def _evaluate_builtin_guardrail(
68
+ state: AgentGuardrailsGraphState,
69
+ guardrail: BuiltInValidatorGuardrail,
70
+ payload_generator: Callable[[AgentGuardrailsGraphState], str],
71
+ ):
72
+ """Evaluate built-in validator guardrail.
73
+
74
+ Args:
75
+ state: The current agent graph state.
76
+ guardrail: The built-in validator guardrail to evaluate.
77
+ payload_generator: Function to generate payload text from state.
78
+
79
+ Returns:
80
+ The guardrail evaluation result.
81
+ """
82
+ text = payload_generator(state)
83
+ uipath = UiPath()
84
+ return uipath.guardrails.evaluate_guardrail(text, guardrail)
85
+
86
+
87
+ def _create_validation_command(
88
+ result,
89
+ success_node: str,
90
+ failure_node: str,
91
+ ) -> Command[Any]:
92
+ """Create command based on validation result.
93
+
94
+ Args:
95
+ result: The guardrail evaluation result.
96
+ success_node: Node to route to on validation pass.
97
+ failure_node: Node to route to on validation fail.
98
+
99
+ Returns:
100
+ Command to update state and route to appropriate node.
101
+ """
102
+ if not result.validation_passed:
103
+ return Command(
104
+ goto=failure_node, update={"guardrail_validation_result": result.reason}
105
+ )
106
+ return Command(goto=success_node, update={"guardrail_validation_result": None})
25
107
 
26
108
 
27
109
  def _create_guardrail_node(
@@ -31,6 +113,10 @@ def _create_guardrail_node(
31
113
  payload_generator: Callable[[AgentGuardrailsGraphState], str],
32
114
  success_node: str,
33
115
  failure_node: str,
116
+ input_data_extractor: Callable[[AgentGuardrailsGraphState], dict[str, Any]]
117
+ | None = None,
118
+ output_data_extractor: Callable[[AgentGuardrailsGraphState], dict[str, Any]]
119
+ | None = None,
34
120
  ) -> tuple[str, Callable[[AgentGuardrailsGraphState], Any]]:
35
121
  """Private factory for guardrail evaluation nodes.
36
122
 
@@ -44,19 +130,41 @@ def _create_guardrail_node(
44
130
  async def node(
45
131
  state: AgentGuardrailsGraphState,
46
132
  ):
47
- text = payload_generator(state)
48
133
  try:
49
- uipath = UiPath()
50
- result = uipath.guardrails.evaluate_guardrail(text, guardrail)
51
- except Exception as exc:
52
- logger.error("Failed to evaluate guardrail: %s", exc)
53
- raise
134
+ # Route to appropriate evaluation service based on guardrail type and scope
135
+ if (
136
+ isinstance(guardrail, DeterministicGuardrail)
137
+ and scope == GuardrailScope.TOOL
138
+ and input_data_extractor is not None
139
+ ):
140
+ result = _evaluate_deterministic_guardrail(
141
+ state,
142
+ guardrail,
143
+ execution_stage,
144
+ input_data_extractor,
145
+ output_data_extractor,
146
+ )
147
+ elif isinstance(guardrail, BuiltInValidatorGuardrail):
148
+ result = _evaluate_builtin_guardrail(
149
+ state, guardrail, payload_generator
150
+ )
151
+ else:
152
+ raise AgentTerminationException(
153
+ code=UiPathErrorCode.EXECUTION_ERROR,
154
+ title="Unsupported guardrail type",
155
+ detail=f"Guardrail type '{type(guardrail).__name__}' is not supported. "
156
+ f"Expected DeterministicGuardrail or BuiltInValidatorGuardrail.",
157
+ )
158
+
159
+ return _create_validation_command(result, success_node, failure_node)
54
160
 
55
- if not result.validation_passed:
56
- return Command(
57
- goto=failure_node, update={"guardrail_validation_result": result.reason}
161
+ except Exception as exc:
162
+ logger.error(
163
+ "Failed to evaluate guardrail '%s': %s",
164
+ guardrail.name,
165
+ exc,
58
166
  )
59
- return Command(goto=success_node, update={"guardrail_validation_result": None})
167
+ raise
60
168
 
61
169
  return node_name, node
62
170
 
@@ -70,7 +178,7 @@ def create_llm_guardrail_node(
70
178
  def _payload_generator(state: AgentGuardrailsGraphState) -> str:
71
179
  if not state.messages:
72
180
  return ""
73
- return _message_text(state.messages[-1])
181
+ return get_message_content(state.messages[-1])
74
182
 
75
183
  return _create_guardrail_node(
76
184
  guardrail,
@@ -82,17 +190,35 @@ def create_llm_guardrail_node(
82
190
  )
83
191
 
84
192
 
85
- def create_agent_guardrail_node(
193
+ def create_agent_init_guardrail_node(
86
194
  guardrail: BaseGuardrail,
87
195
  execution_stage: ExecutionStage,
88
196
  success_node: str,
89
197
  failure_node: str,
90
198
  ) -> tuple[str, Callable[[AgentGuardrailsGraphState], Any]]:
91
- # To be implemented in future PR
92
199
  def _payload_generator(state: AgentGuardrailsGraphState) -> str:
93
200
  if not state.messages:
94
201
  return ""
95
- return _message_text(state.messages[-1])
202
+ return get_message_content(state.messages[-1])
203
+
204
+ return _create_guardrail_node(
205
+ guardrail,
206
+ GuardrailScope.AGENT,
207
+ execution_stage,
208
+ _payload_generator,
209
+ success_node,
210
+ failure_node,
211
+ )
212
+
213
+
214
+ def create_agent_terminate_guardrail_node(
215
+ guardrail: BaseGuardrail,
216
+ execution_stage: ExecutionStage,
217
+ success_node: str,
218
+ failure_node: str,
219
+ ) -> tuple[str, Callable[[AgentGuardrailsGraphState], Any]]:
220
+ def _payload_generator(state: AgentGuardrailsGraphState) -> str:
221
+ return str(state.agent_result)
96
222
 
97
223
  return _create_guardrail_node(
98
224
  guardrail,
@@ -137,31 +263,19 @@ def create_tool_guardrail_node(
137
263
  return ""
138
264
 
139
265
  if execution_stage == ExecutionStage.PRE_EXECUTION:
140
- if not isinstance(state.messages[-1], AIMessage):
141
- return ""
142
- message = state.messages[-1]
143
-
144
- if not message.tool_calls:
145
- return ""
146
-
147
- # Find the first tool call with matching name
148
- for tool_call in message.tool_calls:
149
- call_name = (
150
- tool_call.get("name")
151
- if isinstance(tool_call, dict)
152
- else getattr(tool_call, "name", None)
153
- )
154
- if call_name == tool_name:
155
- # Extract args from the tool call
156
- args = (
157
- tool_call.get("args")
158
- if isinstance(tool_call, dict)
159
- else getattr(tool_call, "args", None)
160
- )
161
- if args is not None:
162
- return json.dumps(args)
163
-
164
- return _message_text(state.messages[-1])
266
+ # Extract tool args as dict and convert to JSON string
267
+ args_dict = _extract_tool_input_data(state, tool_name, execution_stage)
268
+ if args_dict:
269
+ return json.dumps(args_dict)
270
+
271
+ return get_message_content(state.messages[-1])
272
+
273
+ # Create closures for input/output data extraction (for deterministic guardrails)
274
+ def _input_data_extractor(state: AgentGuardrailsGraphState) -> dict[str, Any]:
275
+ return _extract_tool_input_data(state, tool_name, execution_stage)
276
+
277
+ def _output_data_extractor(state: AgentGuardrailsGraphState) -> dict[str, Any]:
278
+ return _extract_tool_output_data(state)
165
279
 
166
280
  return _create_guardrail_node(
167
281
  guardrail,
@@ -170,4 +284,6 @@ def create_tool_guardrail_node(
170
284
  _payload_generator,
171
285
  success_node,
172
286
  failure_node,
287
+ _input_data_extractor,
288
+ _output_data_extractor,
173
289
  )