uipath-langchain 0.1.28__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (60)
  1. uipath_langchain/_cli/_templates/langgraph.json.template +2 -4
  2. uipath_langchain/_cli/cli_new.py +1 -2
  3. uipath_langchain/_utils/_request_mixin.py +8 -0
  4. uipath_langchain/_utils/_settings.py +3 -2
  5. uipath_langchain/agent/guardrails/__init__.py +0 -16
  6. uipath_langchain/agent/guardrails/actions/__init__.py +2 -0
  7. uipath_langchain/agent/guardrails/actions/block_action.py +1 -1
  8. uipath_langchain/agent/guardrails/actions/escalate_action.py +265 -138
  9. uipath_langchain/agent/guardrails/actions/filter_action.py +290 -0
  10. uipath_langchain/agent/guardrails/actions/log_action.py +1 -1
  11. uipath_langchain/agent/guardrails/guardrail_nodes.py +193 -42
  12. uipath_langchain/agent/guardrails/guardrails_factory.py +235 -14
  13. uipath_langchain/agent/guardrails/types.py +0 -12
  14. uipath_langchain/agent/guardrails/utils.py +177 -0
  15. uipath_langchain/agent/react/agent.py +24 -9
  16. uipath_langchain/agent/react/constants.py +1 -2
  17. uipath_langchain/agent/react/file_type_handler.py +123 -0
  18. uipath_langchain/agent/{guardrails → react/guardrails}/guardrails_subgraph.py +119 -25
  19. uipath_langchain/agent/react/init_node.py +16 -1
  20. uipath_langchain/agent/react/job_attachments.py +125 -0
  21. uipath_langchain/agent/react/json_utils.py +183 -0
  22. uipath_langchain/agent/react/jsonschema_pydantic_converter.py +76 -0
  23. uipath_langchain/agent/react/llm_node.py +41 -10
  24. uipath_langchain/agent/react/llm_with_files.py +76 -0
  25. uipath_langchain/agent/react/router.py +48 -37
  26. uipath_langchain/agent/react/types.py +19 -1
  27. uipath_langchain/agent/react/utils.py +30 -4
  28. uipath_langchain/agent/tools/__init__.py +7 -1
  29. uipath_langchain/agent/tools/context_tool.py +151 -1
  30. uipath_langchain/agent/tools/escalation_tool.py +46 -15
  31. uipath_langchain/agent/tools/integration_tool.py +20 -16
  32. uipath_langchain/agent/tools/internal_tools/__init__.py +5 -0
  33. uipath_langchain/agent/tools/internal_tools/analyze_files_tool.py +113 -0
  34. uipath_langchain/agent/tools/internal_tools/internal_tool_factory.py +54 -0
  35. uipath_langchain/agent/tools/mcp_tool.py +86 -0
  36. uipath_langchain/agent/tools/process_tool.py +8 -1
  37. uipath_langchain/agent/tools/static_args.py +18 -40
  38. uipath_langchain/agent/tools/tool_factory.py +13 -5
  39. uipath_langchain/agent/tools/tool_node.py +133 -4
  40. uipath_langchain/agent/tools/utils.py +31 -0
  41. uipath_langchain/agent/wrappers/__init__.py +6 -0
  42. uipath_langchain/agent/wrappers/job_attachment_wrapper.py +62 -0
  43. uipath_langchain/agent/wrappers/static_args_wrapper.py +34 -0
  44. uipath_langchain/chat/__init__.py +4 -0
  45. uipath_langchain/chat/bedrock.py +16 -0
  46. uipath_langchain/chat/mapper.py +60 -42
  47. uipath_langchain/chat/openai.py +56 -26
  48. uipath_langchain/chat/supported_models.py +9 -0
  49. uipath_langchain/chat/vertex.py +62 -46
  50. uipath_langchain/embeddings/embeddings.py +18 -12
  51. uipath_langchain/runtime/factory.py +10 -5
  52. uipath_langchain/runtime/runtime.py +38 -35
  53. uipath_langchain/runtime/schema.py +72 -16
  54. uipath_langchain/runtime/storage.py +178 -71
  55. {uipath_langchain-0.1.28.dist-info → uipath_langchain-0.3.1.dist-info}/METADATA +7 -4
  56. uipath_langchain-0.3.1.dist-info/RECORD +90 -0
  57. uipath_langchain-0.1.28.dist-info/RECORD +0 -76
  58. {uipath_langchain-0.1.28.dist-info → uipath_langchain-0.3.1.dist-info}/WHEEL +0 -0
  59. {uipath_langchain-0.1.28.dist-info → uipath_langchain-0.3.1.dist-info}/entry_points.txt +0 -0
  60. {uipath_langchain-0.1.28.dist-info → uipath_langchain-0.3.1.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,290 @@
1
+ import re
2
+ from typing import Any
3
+
4
+ from langchain_core.messages import AIMessage, ToolMessage
5
+ from langgraph.types import Command
6
+ from uipath.core.guardrails.guardrails import FieldReference, FieldSource
7
+ from uipath.platform.guardrails import BaseGuardrail, GuardrailScope
8
+ from uipath.runtime.errors import UiPathErrorCategory, UiPathErrorCode
9
+
10
+ from uipath_langchain.agent.guardrails.types import ExecutionStage
11
+
12
+ from ...exceptions import AgentTerminationException
13
+ from ...react.types import AgentGuardrailsGraphState
14
+ from .base_action import GuardrailAction, GuardrailActionNode
15
+
16
+
17
class FilterAction(GuardrailAction):
    """Guardrail action that strips configured fields when a guardrail fails.

    Only the TOOL scope is handled: at PRE_EXECUTION the configured input
    fields are removed from the matching tool call's arguments, and at
    POST_EXECUTION the configured output fields are removed from the tool
    result. For any other scope the generated node raises
    ``AgentTerminationException`` when it runs.
    """

    def __init__(self, fields: list[FieldReference] | None = None):
        """Remember which fields this action removes.

        Args:
            fields: FieldReference objects naming the fields to filter. A
                falsy value (``None`` or an empty list) is stored as a new
                empty list.
        """
        self.fields = fields if fields else []

    def action_node(
        self,
        *,
        guardrail: BaseGuardrail,
        scope: GuardrailScope,
        execution_stage: ExecutionStage,
        guarded_component_name: str,
    ) -> GuardrailActionNode:
        """Build the graph node that applies this filter.

        Args:
            guardrail: The guardrail whose failure triggers this action.
            scope: The scope the guardrail is attached to.
            execution_stage: Whether the node runs before or after execution.
            guarded_component_name: Name of the guarded component (the tool
                name for TOOL scope).

        Returns:
            ``(node_name, node)``: ``node_name`` is the lower-cased
            ``<scope>_<stage>_<guardrail>_filter`` label with every run of
            non-word characters collapsed to ``_`` and outer underscores
            stripped; ``node`` is the async callable to register in the graph.
        """
        label = f"{scope.name}_{execution_stage.name}_{guardrail.name}_filter"
        node_name = re.sub(r"\W+", "_", label.lower()).strip("_")

        async def _node(
            _state: AgentGuardrailsGraphState,
        ) -> dict[str, Any] | Command[Any]:
            # Filtering is only implemented for tool calls; reject other scopes.
            if scope != GuardrailScope.TOOL:
                raise AgentTerminationException(
                    code=UiPathErrorCode.EXECUTION_ERROR,
                    title="Guardrail filter action not supported",
                    detail=f"FilterAction is not supported for scope [{scope.name}] at this time.",
                    category=UiPathErrorCategory.USER,
                )

            # self.fields is read at call time, so later reassignment is honoured.
            return _filter_tool_fields(
                _state,
                self.fields,
                execution_stage,
                guarded_component_name,
                guardrail.name,
            )

        return node_name, _node
74
+
75
+
76
def _filter_tool_fields(
    state: AgentGuardrailsGraphState,
    fields_to_filter: list[FieldReference],
    execution_stage: ExecutionStage,
    tool_name: str,
    guardrail_name: str,
) -> dict[str, Any] | Command[Any]:
    """Dispatch tool-field filtering to the input or output filter.

    PRE_EXECUTION filters the tool call's input arguments. Every other stage
    is treated as POST_EXECUTION and filters the tool's output.

    Args:
        state: The current agent graph state.
        fields_to_filter: FieldReference objects naming the fields to remove.
        execution_stage: The stage this filter runs in.
        tool_name: Name of the tool whose call is filtered (input stage only).
        guardrail_name: Name of the triggering guardrail. Accepted for
            interface compatibility; not currently used by this function.

    Returns:
        ``{}`` when there is nothing to filter; otherwise whatever the chosen
        input/output filter returns (a messages-update ``Command`` or ``{}``).

    Raises:
        AgentTerminationException: Wraps any exception raised while filtering.
    """
    try:
        if not fields_to_filter:
            return {}

        if execution_stage != ExecutionStage.PRE_EXECUTION:
            return _filter_tool_output_fields(state, fields_to_filter)
        return _filter_tool_input_fields(state, fields_to_filter, tool_name)

    except Exception as exc:
        raise AgentTerminationException(
            code=UiPathErrorCode.EXECUTION_ERROR,
            title="Filter action failed",
            detail=f"Failed to filter tool fields: {str(exc)}",
            category=UiPathErrorCategory.USER,
        ) from exc
118
+
119
+
120
def _filter_tool_input_fields(
    state: AgentGuardrailsGraphState,
    fields_to_filter: list[FieldReference],
    tool_name: str,
) -> dict[str, Any] | Command[Any]:
    """Filter specified input fields from tool call arguments (PRE_EXECUTION only).

    Finds the most recent ``AIMessage`` that carries tool calls, scanning
    backwards from the end of the message list. In the first tool call whose
    name equals ``tool_name``, it removes every top-level argument whose key
    equals the ``path`` of an INPUT ``FieldReference``.

    The messages held in ``state`` are not mutated. The filtered tool call and
    the ``AIMessage`` that contains it are copies, and the copy replaces the
    original at the same position in a new message list.

    Args:
        state: The current agent graph state.
        fields_to_filter: FieldReference objects; only those whose ``source``
            is ``FieldSource.INPUT`` are considered here.
        tool_name: Name of the tool whose call arguments are filtered.

    Returns:
        ``Command(update={"messages": msgs})``, where ``msgs`` is the full
        message list with the filtered ``AIMessage`` substituted, or ``{}``
        when nothing was removed.
    """
    import copy

    input_paths = [
        field_ref.path
        for field_ref in fields_to_filter
        if field_ref.source == FieldSource.INPUT
    ]
    if not input_paths:
        return {}

    msgs = list(state.messages)
    if not msgs:
        return {}

    # At PRE_EXECUTION the tool-calling AIMessage is expected to be the last
    # message. Scan backwards anyway, so an unexpected trailing message does
    # not defeat the filter.
    ai_index = next(
        (
            index
            for index in range(len(msgs) - 1, -1, -1)
            if isinstance(msgs[index], AIMessage) and msgs[index].tool_calls
        ),
        None,
    )
    if ai_index is None:
        return {}

    ai_message = msgs[ai_index]
    tool_calls = list(ai_message.tool_calls)

    for call_index, tool_call in enumerate(tool_calls):
        is_dict = isinstance(tool_call, dict)
        call_name = (
            tool_call.get("name") if is_dict else getattr(tool_call, "name", None)
        )
        if call_name != tool_name:
            continue

        # Only the first call with a matching name is examined.
        args = tool_call.get("args") if is_dict else getattr(tool_call, "args", None)
        if not args or not isinstance(args, dict):
            return {}

        filtered_args = args.copy()
        removed_any = False
        for path in input_paths:
            if path in filtered_args:
                del filtered_args[path]
                removed_any = True
        if not removed_any:
            return {}

        # Build a new tool call and a new message instead of editing objects
        # that the graph state still references.
        if is_dict:
            new_call = {**tool_call, "args": filtered_args}
        else:
            new_call = copy.copy(tool_call)
            new_call.args = filtered_args
        tool_calls[call_index] = new_call
        msgs[ai_index] = ai_message.model_copy(update={"tool_calls": tool_calls})
        return Command(update={"messages": msgs})

    return {}
208
+
209
+
210
def _filter_tool_output_fields(
    state: AgentGuardrailsGraphState,
    fields_to_filter: list[FieldReference],
) -> dict[str, Any] | Command[Any]:
    """Filter specified output fields from tool output (POST_EXECUTION only).

    Applies only when the last message is a ``ToolMessage`` whose content can
    be read as a dict: either a dict, JSON text, or a Python-literal dict repr.
    Top-level keys equal to the ``path`` of an OUTPUT ``FieldReference`` are
    removed. The result is re-serialized with ``json.dumps``, falling back to
    ``str()`` when the data contains values JSON cannot encode.

    The ``ToolMessage`` held in ``state`` is not mutated. A copy with the new
    content replaces it in a new message list.

    Args:
        state: The current agent graph state.
        fields_to_filter: FieldReference objects; only those whose ``source``
            is ``FieldSource.OUTPUT`` are considered here.

    Returns:
        ``Command(update={"messages": msgs})``, where ``msgs`` is the full
        message list with the filtered ``ToolMessage`` last, or ``{}`` when
        the content cannot be parsed or nothing was removed.
    """
    import json

    output_paths = [
        field_ref.path
        for field_ref in fields_to_filter
        if field_ref.source == FieldSource.OUTPUT
    ]
    if not output_paths:
        return {}

    msgs = list(state.messages)
    if not msgs:
        return {}

    last_message = msgs[-1]
    if not isinstance(last_message, ToolMessage):
        return {}

    content = last_message.content
    if not content:
        return {}

    output_data = _parse_tool_output_content(content)
    if output_data is None:
        # Not dict-shaped output; specific fields cannot be targeted.
        return {}

    filtered_output = output_data.copy()
    removed_any = False
    for path in output_paths:
        if path in filtered_output:
            del filtered_output[path]
            removed_any = True
    if not removed_any:
        return {}

    try:
        new_content = json.dumps(filtered_output)
    except (TypeError, ValueError):
        # ast.literal_eval can yield values JSON cannot encode (sets, bytes,
        # complex numbers, tuple keys). Emit the Python repr rather than
        # failing after the sensitive fields were already removed.
        new_content = str(filtered_output)

    # Replace the message with a copy instead of editing the state-owned object.
    msgs[-1] = last_message.model_copy(update={"content": new_content})
    return Command(update={"messages": msgs})


def _parse_tool_output_content(content: Any) -> dict[str, Any] | None:
    """Best-effort conversion of ToolMessage content into a dict.

    A dict is returned as-is. A string is tried with ``json.loads`` and then
    with ``ast.literal_eval`` (for Python-literal dict reprs). Returns ``None``
    in three cases: any other content type, text that neither parser accepts,
    or a parsed value that is not a dict.
    """
    import ast
    import json

    if isinstance(content, dict):
        return content
    if not isinstance(content, str):
        return None

    try:
        parsed = json.loads(content)
    except json.JSONDecodeError:
        try:
            # literal_eval never executes code, but pathological input can
            # still exhaust memory or recursion depth. Treat that as unparseable.
            parsed = ast.literal_eval(content)
        except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
            return None
    except RecursionError:
        return None

    return parsed if isinstance(parsed, dict) else None
@@ -6,7 +6,7 @@ from uipath.platform.guardrails import BaseGuardrail, GuardrailScope
6
6
 
7
7
  from uipath_langchain.agent.guardrails.types import ExecutionStage
8
8
 
9
- from ..types import AgentGuardrailsGraphState
9
+ from ...react.types import AgentGuardrailsGraphState
10
10
  from .base_action import GuardrailAction, GuardrailActionNode
11
11
 
12
12
  logger = logging.getLogger(__name__)
@@ -3,25 +3,108 @@ import logging
3
3
  import re
4
4
  from typing import Any, Callable
5
5
 
6
- from langchain_core.messages import AIMessage, AnyMessage, HumanMessage, SystemMessage
7
6
  from langgraph.types import Command
7
+ from uipath.core.guardrails import (
8
+ DeterministicGuardrail,
9
+ DeterministicGuardrailsService,
10
+ )
8
11
  from uipath.platform import UiPath
9
12
  from uipath.platform.guardrails import (
10
13
  BaseGuardrail,
14
+ BuiltInValidatorGuardrail,
11
15
  GuardrailScope,
12
16
  )
17
+ from uipath.runtime.errors import UiPathErrorCode
13
18
 
14
19
  from uipath_langchain.agent.guardrails.types import ExecutionStage
20
+ from uipath_langchain.agent.guardrails.utils import (
21
+ _extract_tool_args_from_message,
22
+ _extract_tool_output_data,
23
+ _extract_tools_args_from_message,
24
+ get_message_content,
25
+ )
26
+ from uipath_langchain.agent.react.types import AgentGuardrailsGraphState
15
27
 
16
- from .types import AgentGuardrailsGraphState
28
+ from ..exceptions import AgentTerminationException
17
29
 
18
30
  logger = logging.getLogger(__name__)
19
31
 
20
32
 
21
- def _message_text(msg: AnyMessage) -> str:
22
- if isinstance(msg, (HumanMessage, SystemMessage)):
23
- return msg.content if isinstance(msg.content, str) else str(msg.content)
24
- return str(getattr(msg, "content", "")) if hasattr(msg, "content") else ""
33
def _evaluate_deterministic_guardrail(
    state: AgentGuardrailsGraphState,
    guardrail: DeterministicGuardrail,
    execution_stage: ExecutionStage,
    input_data_extractor: Callable[[AgentGuardrailsGraphState], dict[str, Any]],
    output_data_extractor: Callable[[AgentGuardrailsGraphState], dict[str, Any]] | None,
):
    """Evaluate a deterministic guardrail with ``DeterministicGuardrailsService``.

    Input data is always extracted. For PRE_EXECUTION the guardrail is
    evaluated on the input data alone. Every other stage is evaluated as
    POST_EXECUTION on both input and output data. The output data is an empty
    dict when no output extractor is supplied.

    Args:
        state: The current agent graph state.
        guardrail: The deterministic guardrail to evaluate.
        execution_stage: The stage being evaluated.
        input_data_extractor: Builds the input-data dict from ``state``.
        output_data_extractor: Builds the output-data dict from ``state``;
            optional.

    Returns:
        Whatever the service's pre/post evaluation method returns.
    """
    service = DeterministicGuardrailsService()
    input_data = input_data_extractor(state)

    if execution_stage != ExecutionStage.PRE_EXECUTION:
        output_data: dict[str, Any] = {}
        if output_data_extractor:
            output_data = output_data_extractor(state)
        return service.evaluate_post_deterministic_guardrail(
            input_data=input_data,
            output_data=output_data,
            guardrail=guardrail,
        )

    return service.evaluate_pre_deterministic_guardrail(
        input_data=input_data, guardrail=guardrail
    )
66
+
67
+
68
def _evaluate_builtin_guardrail(
    state: AgentGuardrailsGraphState,
    guardrail: BuiltInValidatorGuardrail,
    payload_generator: Callable[[AgentGuardrailsGraphState], str],
):
    """Evaluate a built-in validator guardrail through the UiPath client.

    The payload text is built from ``state`` first. A ``UiPath`` client is
    then created, and the text is passed to its ``guardrails.evaluate_guardrail``.

    Args:
        state: The current agent graph state.
        guardrail: The built-in validator guardrail to evaluate.
        payload_generator: Turns ``state`` into the text to validate.

    Returns:
        The result of ``evaluate_guardrail``.
    """
    payload_text = payload_generator(state)
    guardrails_service = UiPath().guardrails
    return guardrails_service.evaluate_guardrail(payload_text, guardrail)
86
+
87
+
88
def _create_validation_command(
    result,
    success_node: str,
    failure_node: str,
) -> Command[Any]:
    """Translate a guardrail evaluation result into a routing ``Command``.

    Args:
        result: Evaluation result exposing ``validation_passed`` and, on
            failure, ``reason``.
        success_node: Node to route to when validation passes.
        failure_node: Node to route to when validation fails.

    Returns:
        A ``Command`` that routes to the appropriate node and sets
        ``guardrail_validation_result``: ``None`` on success, or
        ``result.reason`` on failure.
    """
    if result.validation_passed:
        return Command(goto=success_node, update={"guardrail_validation_result": None})

    return Command(
        goto=failure_node,
        update={"guardrail_validation_result": result.reason},
    )
25
108
 
26
109
 
27
110
  def _create_guardrail_node(
@@ -31,10 +114,15 @@ def _create_guardrail_node(
31
114
  payload_generator: Callable[[AgentGuardrailsGraphState], str],
32
115
  success_node: str,
33
116
  failure_node: str,
117
+ input_data_extractor: Callable[[AgentGuardrailsGraphState], dict[str, Any]]
118
+ | None = None,
119
+ output_data_extractor: Callable[[AgentGuardrailsGraphState], dict[str, Any]]
120
+ | None = None,
121
+ tool_name: str | None = None,
34
122
  ) -> tuple[str, Callable[[AgentGuardrailsGraphState], Any]]:
35
123
  """Private factory for guardrail evaluation nodes.
36
124
 
37
- Returns a node that evaluates the guardrail and routes via Command:
125
+ Returns a node with observability metadata attached as __metadata__ attribute:
38
126
  - goto success_node on validation pass
39
127
  - goto failure_node on validation fail
40
128
  """
@@ -44,19 +132,60 @@ def _create_guardrail_node(
44
132
  async def node(
45
133
  state: AgentGuardrailsGraphState,
46
134
  ):
47
- text = payload_generator(state)
48
135
  try:
49
- uipath = UiPath()
50
- result = uipath.guardrails.evaluate_guardrail(text, guardrail)
136
+ # Route to appropriate evaluation service based on guardrail type and scope
137
+ if (
138
+ isinstance(guardrail, DeterministicGuardrail)
139
+ and scope == GuardrailScope.TOOL
140
+ and input_data_extractor is not None
141
+ ):
142
+ result = _evaluate_deterministic_guardrail(
143
+ state,
144
+ guardrail,
145
+ execution_stage,
146
+ input_data_extractor,
147
+ output_data_extractor,
148
+ )
149
+ elif isinstance(guardrail, BuiltInValidatorGuardrail):
150
+ result = _evaluate_builtin_guardrail(
151
+ state, guardrail, payload_generator
152
+ )
153
+ else:
154
+ # Provide specific error message for DeterministicGuardrails with wrong scope
155
+ if isinstance(guardrail, DeterministicGuardrail):
156
+ raise AgentTerminationException(
157
+ code=UiPathErrorCode.EXECUTION_ERROR,
158
+ title="Invalid guardrail scope",
159
+ detail=f"DeterministicGuardrail '{guardrail.name}' can only be used with TOOL scope. "
160
+ f"Current scope: {scope.name}. "
161
+ f"Please configure this guardrail to use only TOOL scope.",
162
+ )
163
+ else:
164
+ raise AgentTerminationException(
165
+ code=UiPathErrorCode.EXECUTION_ERROR,
166
+ title="Unsupported guardrail type",
167
+ detail=f"Guardrail type '{type(guardrail).__name__}' is not supported. "
168
+ f"Expected DeterministicGuardrail (TOOL scope only) or BuiltInValidatorGuardrail.",
169
+ )
170
+
171
+ return _create_validation_command(result, success_node, failure_node)
172
+
51
173
  except Exception as exc:
52
- logger.error("Failed to evaluate guardrail: %s", exc)
174
+ logger.error(
175
+ "Failed to evaluate guardrail '%s': %s",
176
+ guardrail.name,
177
+ exc,
178
+ )
53
179
  raise
54
180
 
55
- if not result.validation_passed:
56
- return Command(
57
- goto=failure_node, update={"guardrail_validation_result": result.reason}
58
- )
59
- return Command(goto=success_node, update={"guardrail_validation_result": None})
181
+ # Attach observability metadata as function attribute
182
+ node.__metadata__ = { # type: ignore[attr-defined]
183
+ "guardrail_name": guardrail.name,
184
+ "guardrail_description": getattr(guardrail, "description", None),
185
+ "guardrail_scope": scope.value,
186
+ "guardrail_stage": execution_stage.value,
187
+ "tool_name": tool_name,
188
+ }
60
189
 
61
190
  return node_name, node
62
191
 
@@ -70,7 +199,11 @@ def create_llm_guardrail_node(
70
199
  def _payload_generator(state: AgentGuardrailsGraphState) -> str:
71
200
  if not state.messages:
72
201
  return ""
73
- return _message_text(state.messages[-1])
202
+ match execution_stage:
203
+ case ExecutionStage.PRE_EXECUTION:
204
+ return get_message_content(state.messages[-1])
205
+ case ExecutionStage.POST_EXECUTION:
206
+ return json.dumps(_extract_tools_args_from_message(state.messages[-1]))
74
207
 
75
208
  return _create_guardrail_node(
76
209
  guardrail,
@@ -82,17 +215,35 @@ def create_llm_guardrail_node(
82
215
  )
83
216
 
84
217
 
85
- def create_agent_guardrail_node(
218
+ def create_agent_init_guardrail_node(
86
219
  guardrail: BaseGuardrail,
87
220
  execution_stage: ExecutionStage,
88
221
  success_node: str,
89
222
  failure_node: str,
90
223
  ) -> tuple[str, Callable[[AgentGuardrailsGraphState], Any]]:
91
- # To be implemented in future PR
92
224
  def _payload_generator(state: AgentGuardrailsGraphState) -> str:
93
225
  if not state.messages:
94
226
  return ""
95
- return _message_text(state.messages[-1])
227
+ return get_message_content(state.messages[-1])
228
+
229
+ return _create_guardrail_node(
230
+ guardrail,
231
+ GuardrailScope.AGENT,
232
+ execution_stage,
233
+ _payload_generator,
234
+ success_node,
235
+ failure_node,
236
+ )
237
+
238
+
239
+ def create_agent_terminate_guardrail_node(
240
+ guardrail: BaseGuardrail,
241
+ execution_stage: ExecutionStage,
242
+ success_node: str,
243
+ failure_node: str,
244
+ ) -> tuple[str, Callable[[AgentGuardrailsGraphState], Any]]:
245
+ def _payload_generator(state: AgentGuardrailsGraphState) -> str:
246
+ return str(state.agent_result)
96
247
 
97
248
  return _create_guardrail_node(
98
249
  guardrail,
@@ -137,31 +288,28 @@ def create_tool_guardrail_node(
137
288
  return ""
138
289
 
139
290
  if execution_stage == ExecutionStage.PRE_EXECUTION:
140
- if not isinstance(state.messages[-1], AIMessage):
141
- return ""
142
- message = state.messages[-1]
291
+ last_message = state.messages[-1]
292
+ args_dict = _extract_tool_args_from_message(last_message, tool_name)
293
+ if args_dict:
294
+ return json.dumps(args_dict)
143
295
 
144
- if not message.tool_calls:
145
- return ""
296
+ return get_message_content(state.messages[-1])
146
297
 
147
- # Find the first tool call with matching name
148
- for tool_call in message.tool_calls:
149
- call_name = (
150
- tool_call.get("name")
151
- if isinstance(tool_call, dict)
152
- else getattr(tool_call, "name", None)
153
- )
154
- if call_name == tool_name:
155
- # Extract args from the tool call
156
- args = (
157
- tool_call.get("args")
158
- if isinstance(tool_call, dict)
159
- else getattr(tool_call, "args", None)
160
- )
161
- if args is not None:
162
- return json.dumps(args)
298
+ # Create closures for input/output data extraction (for deterministic guardrails)
299
+ def _input_data_extractor(state: AgentGuardrailsGraphState) -> dict[str, Any]:
300
+ if execution_stage == ExecutionStage.PRE_EXECUTION:
301
+ if len(state.messages) < 1:
302
+ return {}
303
+ message = state.messages[-1]
304
+ else: # POST_EXECUTION
305
+ if len(state.messages) < 2:
306
+ return {}
307
+ message = state.messages[-2]
308
+
309
+ return _extract_tool_args_from_message(message, tool_name)
163
310
 
164
- return _message_text(state.messages[-1])
311
+ def _output_data_extractor(state: AgentGuardrailsGraphState) -> dict[str, Any]:
312
+ return _extract_tool_output_data(state)
165
313
 
166
314
  return _create_guardrail_node(
167
315
  guardrail,
@@ -170,4 +318,7 @@ def create_tool_guardrail_node(
170
318
  _payload_generator,
171
319
  success_node,
172
320
  failure_node,
321
+ _input_data_extractor,
322
+ _output_data_extractor,
323
+ tool_name,
173
324
  )