uipath-langchain: uipath_langchain-0.1.24-py3-none-any.whl → uipath_langchain-0.1.28-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -18,6 +18,7 @@ class GuardrailAction(ABC):
18
18
  guardrail: BaseGuardrail,
19
19
  scope: GuardrailScope,
20
20
  execution_stage: ExecutionStage,
21
+ guarded_component_name: str,
21
22
  ) -> GuardrailActionNode:
22
23
  """Create and return the Action node to execute on validation failure."""
23
24
  ...
@@ -26,6 +26,7 @@ class BlockAction(GuardrailAction):
26
26
  guardrail: BaseGuardrail,
27
27
  scope: GuardrailScope,
28
28
  execution_stage: ExecutionStage,
29
+ guarded_component_name: str,
29
30
  ) -> GuardrailActionNode:
30
31
  raw_node_name = f"{scope.name}_{execution_stage.name}_{guardrail.name}_block"
31
32
  node_name = re.sub(r"\W+", "_", raw_node_name.lower()).strip("_")
@@ -4,7 +4,7 @@ import json
4
4
  import re
5
5
  from typing import Any, Dict, Literal
6
6
 
7
- from langchain_core.messages import AIMessage, ToolMessage
7
+ from langchain_core.messages import AIMessage, ToolCall, ToolMessage
8
8
  from langgraph.types import Command, interrupt
9
9
  from uipath.platform.common import CreateEscalation
10
10
  from uipath.platform.guardrails import (
@@ -34,6 +34,14 @@ class EscalateAction(GuardrailAction):
34
34
  version: int,
35
35
  assignee: str,
36
36
  ):
37
+ """Initialize EscalateAction with escalation app configuration.
38
+
39
+ Args:
40
+ app_name: Name of the escalation app.
41
+ app_folder_path: Folder path where the escalation app is located.
42
+ version: Version of the escalation app.
43
+ assignee: User or role assigned to handle the escalation.
44
+ """
37
45
  self.app_name = app_name
38
46
  self.app_folder_path = app_folder_path
39
47
  self.version = version
@@ -45,13 +53,27 @@ class EscalateAction(GuardrailAction):
45
53
  guardrail: BaseGuardrail,
46
54
  scope: GuardrailScope,
47
55
  execution_stage: ExecutionStage,
56
+ guarded_component_name: str,
48
57
  ) -> GuardrailActionNode:
58
+ """Create a HITL escalation node for the guardrail.
59
+
60
+ Args:
61
+ guardrail: The guardrail that triggered this escalation action.
62
+ scope: The guardrail scope (LLM/AGENT/TOOL).
63
+ execution_stage: The execution stage (PRE_EXECUTION or POST_EXECUTION).
64
+
65
+ Returns:
66
+ A tuple of (node_name, node_function) where the node function triggers
67
+ a HITL interruption and processes the escalation response.
68
+ """
49
69
  node_name = _get_node_name(execution_stage, guardrail, scope)
50
70
 
51
71
  async def _node(
52
72
  state: AgentGuardrailsGraphState,
53
73
  ) -> Dict[str, Any] | Command[Any]:
54
- input = _extract_escalation_content(state, scope, execution_stage)
74
+ input = _extract_escalation_content(
75
+ state, scope, execution_stage, guarded_component_name
76
+ )
55
77
  escalation_field = _execution_stage_to_escalation_field(execution_stage)
56
78
 
57
79
  data = {
@@ -75,7 +97,11 @@ class EscalateAction(GuardrailAction):
75
97
 
76
98
  if escalation_result.action == "Approve":
77
99
  return _process_escalation_response(
78
- state, escalation_result.data, scope, execution_stage
100
+ state,
101
+ escalation_result.data,
102
+ scope,
103
+ execution_stage,
104
+ guarded_component_name,
79
105
  )
80
106
 
81
107
  raise AgentTerminationException(
@@ -95,46 +121,58 @@ def _get_node_name(
95
121
  return node_name
96
122
 
97
123
 
98
- def _execution_stage_to_string(
124
+ def _process_escalation_response(
125
+ state: AgentGuardrailsGraphState,
126
+ escalation_result: Dict[str, Any],
127
+ scope: GuardrailScope,
99
128
  execution_stage: ExecutionStage,
100
- ) -> Literal["PreExecution", "PostExecution"]:
101
- """Convert ExecutionStage enum to string literal.
129
+ guarded_node_name: str,
130
+ ) -> Dict[str, Any] | Command[Any]:
131
+ """Process escalation response and route to appropriate handler based on scope.
102
132
 
103
133
  Args:
104
- execution_stage: The execution stage enum.
134
+ state: The current agent graph state.
135
+ escalation_result: The result from the escalation interrupt containing reviewed inputs/outputs.
136
+ scope: The guardrail scope (LLM/AGENT/TOOL).
137
+ execution_stage: The execution stage (PRE_EXECUTION or POST_EXECUTION).
105
138
 
106
139
  Returns:
107
- "PreExecution" for PRE_EXECUTION, "PostExecution" for POST_EXECUTION.
140
+ For LLM/TOOL scope: Command to update messages with reviewed inputs/outputs, or empty dict.
141
+ For AGENT scope: Empty dict (no message alteration).
108
142
  """
109
- if execution_stage == ExecutionStage.PRE_EXECUTION:
110
- return "PreExecution"
111
- return "PostExecution"
143
+ match scope:
144
+ case GuardrailScope.LLM:
145
+ return _process_llm_escalation_response(
146
+ state, escalation_result, execution_stage
147
+ )
148
+ case GuardrailScope.TOOL:
149
+ return _process_tool_escalation_response(
150
+ state, escalation_result, execution_stage, guarded_node_name
151
+ )
152
+ case GuardrailScope.AGENT:
153
+ return {}
112
154
 
113
155
 
114
- def _process_escalation_response(
156
+ def _process_llm_escalation_response(
115
157
  state: AgentGuardrailsGraphState,
116
158
  escalation_result: Dict[str, Any],
117
- scope: GuardrailScope,
118
159
  execution_stage: ExecutionStage,
119
160
  ) -> Dict[str, Any] | Command[Any]:
120
- """Process escalation response and update state based on guardrail scope.
161
+ """Process escalation response for LLM scope guardrails.
162
+
163
+ Updates message content or tool calls based on reviewed inputs/outputs from escalation.
121
164
 
122
165
  Args:
123
166
  state: The current agent graph state.
124
- escalation_result: The result from the escalation interrupt.
125
- scope: The guardrail scope (LLM/AGENT/TOOL).
126
- execution_stage: The hook type ("PreExecution" or "PostExecution").
167
+ escalation_result: The result from the escalation interrupt containing reviewed inputs/outputs.
168
+ execution_stage: The execution stage (PRE_EXECUTION or POST_EXECUTION).
127
169
 
128
170
  Returns:
129
- For LLM scope: Command to update messages with reviewed inputs/outputs.
130
- For non-LLM scope: Empty dict (no message alteration).
171
+ Command to update messages with reviewed inputs/outputs, or empty dict if no updates needed.
131
172
 
132
173
  Raises:
133
174
  AgentTerminationException: If escalation response processing fails.
134
175
  """
135
- if scope != GuardrailScope.LLM:
136
- return {}
137
-
138
176
  try:
139
177
  reviewed_field = (
140
178
  "ReviewedInputs"
@@ -200,33 +238,140 @@ def _process_escalation_response(
200
238
  ) from e
201
239
 
202
240
 
241
+ def _process_tool_escalation_response(
242
+ state: AgentGuardrailsGraphState,
243
+ escalation_result: Dict[str, Any],
244
+ execution_stage: ExecutionStage,
245
+ tool_name: str,
246
+ ) -> Dict[str, Any] | Command[Any]:
247
+ """Process escalation response for TOOL scope guardrails.
248
+
249
+ Updates the tool call arguments (PreExecution) or tool message content (PostExecution)
250
+ for the specific tool matching the tool_name. For PreExecution, finds the tool call
251
+ with the matching name and updates only that tool call's args with the reviewed dict.
252
+ For PostExecution, updates the tool message content.
253
+
254
+ Args:
255
+ state: The current agent graph state.
256
+ escalation_result: The result from the escalation interrupt containing reviewed inputs/outputs.
257
+ execution_stage: The execution stage (PRE_EXECUTION or POST_EXECUTION).
258
+ tool_name: Name of the tool to update. Only the tool call matching this name will be updated.
259
+
260
+ Returns:
261
+ Command to update messages with reviewed tool call args or content, or empty dict if no updates needed.
262
+
263
+ Raises:
264
+ AgentTerminationException: If escalation response processing fails.
265
+ """
266
+ try:
267
+ reviewed_field = (
268
+ "ReviewedInputs"
269
+ if execution_stage == ExecutionStage.PRE_EXECUTION
270
+ else "ReviewedOutputs"
271
+ )
272
+
273
+ msgs = state.messages.copy()
274
+ if not msgs or reviewed_field not in escalation_result:
275
+ return {}
276
+
277
+ last_message = msgs[-1]
278
+ if execution_stage == ExecutionStage.PRE_EXECUTION:
279
+ if not isinstance(last_message, AIMessage):
280
+ return {}
281
+
282
+ # Get reviewed tool calls args from escalation result
283
+ reviewed_inputs_json = escalation_result[reviewed_field]
284
+ if not reviewed_inputs_json:
285
+ return {}
286
+
287
+ reviewed_tool_calls_args = json.loads(reviewed_inputs_json)
288
+ if not isinstance(reviewed_tool_calls_args, dict):
289
+ return {}
290
+
291
+ # Find and update only the tool call with matching name
292
+ if last_message.tool_calls:
293
+ tool_calls = list(last_message.tool_calls)
294
+ for tool_call in tool_calls:
295
+ call_name = extract_tool_name(tool_call)
296
+ if call_name == tool_name:
297
+ # Update args for the matching tool call
298
+ if isinstance(reviewed_tool_calls_args, dict):
299
+ if isinstance(tool_call, dict):
300
+ tool_call["args"] = reviewed_tool_calls_args
301
+ else:
302
+ tool_call.args = reviewed_tool_calls_args
303
+ break
304
+ last_message.tool_calls = tool_calls
305
+ else:
306
+ if not isinstance(last_message, ToolMessage):
307
+ return {}
308
+
309
+ # PostExecution: update tool message content
310
+ reviewed_outputs_json = escalation_result[reviewed_field]
311
+ if reviewed_outputs_json:
312
+ last_message.content = reviewed_outputs_json
313
+
314
+ return Command[Any](update={"messages": msgs})
315
+ except Exception as e:
316
+ raise AgentTerminationException(
317
+ code=UiPathErrorCode.EXECUTION_ERROR,
318
+ title="Escalation rejected",
319
+ detail=str(e),
320
+ ) from e
321
+
322
+
203
323
  def _extract_escalation_content(
204
324
  state: AgentGuardrailsGraphState,
205
325
  scope: GuardrailScope,
206
326
  execution_stage: ExecutionStage,
327
+ guarded_node_name: str,
207
328
  ) -> str | list[str | Dict[str, Any]]:
208
329
  """Extract escalation content from state based on guardrail scope and execution stage.
209
330
 
210
331
  Args:
211
332
  state: The current agent graph state.
212
333
  scope: The guardrail scope (LLM/AGENT/TOOL).
213
- execution_stage: The execution stage enum.
334
+ execution_stage: The execution stage (PRE_EXECUTION or POST_EXECUTION).
214
335
 
215
336
  Returns:
216
- For non-LLM scope: Empty string.
217
- For LLM PreExecution: JSON string with message content.
218
- For LLM PostExecution: JSON array with tool call content and message content.
219
- """
220
- if scope != GuardrailScope.LLM:
221
- return ""
337
+ str or list[str | Dict[str, Any]]: For LLM scope, returns JSON string or list with message/tool call content.
338
+ For AGENT scope, returns empty string. For TOOL scope, returns JSON string or list with tool-specific content.
222
339
 
340
+ Raises:
341
+ AgentTerminationException: If no messages are found in state.
342
+ """
223
343
  if not state.messages:
224
344
  raise AgentTerminationException(
225
345
  code=UiPathErrorCode.EXECUTION_ERROR,
226
346
  title="Invalid state message",
227
- detail="No messages in state",
347
+ detail="No message found into agent state",
228
348
  )
229
349
 
350
+ match scope:
351
+ case GuardrailScope.LLM:
352
+ return _extract_llm_escalation_content(state, execution_stage)
353
+ case GuardrailScope.AGENT:
354
+ return _extract_agent_escalation_content(state, execution_stage)
355
+ case GuardrailScope.TOOL:
356
+ return _extract_tool_escalation_content(
357
+ state, execution_stage, guarded_node_name
358
+ )
359
+
360
+
361
+ def _extract_llm_escalation_content(
362
+ state: AgentGuardrailsGraphState, execution_stage: ExecutionStage
363
+ ) -> str | list[str | Dict[str, Any]]:
364
+ """Extract escalation content for LLM scope guardrails.
365
+
366
+ Args:
367
+ state: The current agent graph state.
368
+ execution_stage: The execution stage (PRE_EXECUTION or POST_EXECUTION).
369
+
370
+ Returns:
371
+ str or list[str | Dict[str, Any]]: For PreExecution, returns JSON string with message content or empty string.
372
+ For PostExecution, returns JSON string (array) with tool call content and message content.
373
+ Returns empty string if no content found.
374
+ """
230
375
  last_message = state.messages[-1]
231
376
  if execution_stage == ExecutionStage.PRE_EXECUTION:
232
377
  if isinstance(last_message, ToolMessage):
@@ -260,6 +405,70 @@ def _extract_escalation_content(
260
405
  return _message_text(last_message)
261
406
 
262
407
 
408
+ def _extract_agent_escalation_content(
409
+ state: AgentGuardrailsGraphState, execution_stage: ExecutionStage
410
+ ) -> str | list[str | Dict[str, Any]]:
411
+ """Extract escalation content for AGENT scope guardrails.
412
+
413
+ Args:
414
+ state: The current agent graph state.
415
+ execution_stage: The execution stage (PRE_EXECUTION or POST_EXECUTION).
416
+
417
+ Returns:
418
+ str: Empty string (AGENT scope guardrails do not extract escalation content).
419
+ """
420
+ return ""
421
+
422
+
423
+ def _extract_tool_escalation_content(
424
+ state: AgentGuardrailsGraphState, execution_stage: ExecutionStage, tool_name: str
425
+ ) -> str | list[str | Dict[str, Any]]:
426
+ """Extract escalation content for TOOL scope guardrails.
427
+
428
+ Args:
429
+ state: The current agent graph state.
430
+ execution_stage: The execution stage (PRE_EXECUTION or POST_EXECUTION).
431
+ tool_name: Optional tool name to filter tool calls. If provided, only extracts args for matching tool.
432
+
433
+ Returns:
434
+ str or list[str | Dict[str, Any]]: For PreExecution, returns JSON string with tool call arguments
435
+ for the specified tool name, or empty string if not found. For PostExecution, returns string with
436
+ tool message content, or empty string if message type doesn't match.
437
+ """
438
+ last_message = state.messages[-1]
439
+ if execution_stage == ExecutionStage.PRE_EXECUTION:
440
+ if not isinstance(last_message, AIMessage):
441
+ return ""
442
+ if not last_message.tool_calls:
443
+ return ""
444
+
445
+ # Find the tool call with matching name
446
+ for tool_call in last_message.tool_calls:
447
+ call_name = extract_tool_name(tool_call)
448
+ if call_name == tool_name:
449
+ # Extract args from the matching tool call
450
+ args = (
451
+ tool_call.get("args")
452
+ if isinstance(tool_call, dict)
453
+ else getattr(tool_call, "args", None)
454
+ )
455
+ if args is not None:
456
+ return json.dumps(args)
457
+ return ""
458
+ else:
459
+ if not isinstance(last_message, ToolMessage):
460
+ return ""
461
+ return last_message.content
462
+
463
+
464
+ def extract_tool_name(tool_call: ToolCall) -> Any | None:
465
+ return (
466
+ tool_call.get("name")
467
+ if isinstance(tool_call, dict)
468
+ else getattr(tool_call, "name", None)
469
+ )
470
+
471
+
263
472
  def _execution_stage_to_escalation_field(
264
473
  execution_stage: ExecutionStage,
265
474
  ) -> str:
@@ -272,3 +481,19 @@ def _execution_stage_to_escalation_field(
272
481
  "Inputs" for PRE_EXECUTION, "Outputs" for POST_EXECUTION.
273
482
  """
274
483
  return "Inputs" if execution_stage == ExecutionStage.PRE_EXECUTION else "Outputs"
484
+
485
+
486
+ def _execution_stage_to_string(
487
+ execution_stage: ExecutionStage,
488
+ ) -> Literal["PreExecution", "PostExecution"]:
489
+ """Convert ExecutionStage enum to string literal.
490
+
491
+ Args:
492
+ execution_stage: The execution stage enum.
493
+
494
+ Returns:
495
+ "PreExecution" for PRE_EXECUTION, "PostExecution" for POST_EXECUTION.
496
+ """
497
+ if execution_stage == ExecutionStage.PRE_EXECUTION:
498
+ return "PreExecution"
499
+ return "PostExecution"
@@ -31,6 +31,7 @@ class LogAction(GuardrailAction):
31
31
  guardrail: BaseGuardrail,
32
32
  scope: GuardrailScope,
33
33
  execution_stage: ExecutionStage,
34
+ guarded_component_name: str,
34
35
  ) -> GuardrailActionNode:
35
36
  """Create a guardrail action node that logs validation failures.
36
37
 
@@ -1,8 +1,9 @@
1
+ import json
1
2
  import logging
2
3
  import re
3
4
  from typing import Any, Callable
4
5
 
5
- from langchain_core.messages import AnyMessage, HumanMessage, SystemMessage
6
+ from langchain_core.messages import AIMessage, AnyMessage, HumanMessage, SystemMessage
6
7
  from langgraph.types import Command
7
8
  from uipath.platform import UiPath
8
9
  from uipath.platform.guardrails import (
@@ -108,11 +109,58 @@ def create_tool_guardrail_node(
108
109
  execution_stage: ExecutionStage,
109
110
  success_node: str,
110
111
  failure_node: str,
112
+ tool_name: str,
111
113
  ) -> tuple[str, Callable[[AgentGuardrailsGraphState], Any]]:
112
- # To be implemented in future PR
114
+ """Create a guardrail node for TOOL scope guardrails.
115
+
116
+ Args:
117
+ guardrail: The guardrail to evaluate.
118
+ execution_stage: The execution stage (PRE_EXECUTION or POST_EXECUTION).
119
+ success_node: Node to route to on validation pass.
120
+ failure_node: Node to route to on validation fail.
121
+ tool_name: Name of the tool to extract arguments from.
122
+
123
+ Returns:
124
+ A tuple of (node_name, node_function) for the guardrail evaluation node.
125
+ """
126
+
113
127
  def _payload_generator(state: AgentGuardrailsGraphState) -> str:
128
+ """Extract tool call arguments for the specified tool name.
129
+
130
+ Args:
131
+ state: The current agent graph state.
132
+
133
+ Returns:
134
+ JSON string of the tool call arguments, or empty string if not found.
135
+ """
114
136
  if not state.messages:
115
137
  return ""
138
+
139
+ if execution_stage == ExecutionStage.PRE_EXECUTION:
140
+ if not isinstance(state.messages[-1], AIMessage):
141
+ return ""
142
+ message = state.messages[-1]
143
+
144
+ if not message.tool_calls:
145
+ return ""
146
+
147
+ # Find the first tool call with matching name
148
+ for tool_call in message.tool_calls:
149
+ call_name = (
150
+ tool_call.get("name")
151
+ if isinstance(tool_call, dict)
152
+ else getattr(tool_call, "name", None)
153
+ )
154
+ if call_name == tool_name:
155
+ # Extract args from the tool call
156
+ args = (
157
+ tool_call.get("args")
158
+ if isinstance(tool_call, dict)
159
+ else getattr(tool_call, "args", None)
160
+ )
161
+ if args is not None:
162
+ return json.dumps(args)
163
+
116
164
  return _message_text(state.messages[-1])
117
165
 
118
166
  return _create_guardrail_node(
@@ -1,7 +1,9 @@
1
+ from functools import partial
1
2
  from typing import Any, Callable, Sequence
2
3
 
3
4
  from langgraph.constants import END, START
4
5
  from langgraph.graph import StateGraph
6
+ from langgraph.prebuilt import ToolNode
5
7
  from uipath.platform.guardrails import (
6
8
  BaseGuardrail,
7
9
  BuiltInValidatorGuardrail,
@@ -90,6 +92,7 @@ def _create_guardrails_subgraph(
90
92
  ExecutionStage.PRE_EXECUTION,
91
93
  node_factory,
92
94
  inner_name,
95
+ inner_name,
93
96
  )
94
97
  subgraph.add_edge(START, first_pre_exec_guardrail_node)
95
98
  else:
@@ -107,6 +110,7 @@ def _create_guardrails_subgraph(
107
110
  ExecutionStage.POST_EXECUTION,
108
111
  node_factory,
109
112
  END,
113
+ inner_node,
110
114
  )
111
115
  subgraph.add_edge(inner_name, first_post_exec_guardrail_node)
112
116
  else:
@@ -130,6 +134,7 @@ def _build_guardrail_node_chain(
130
134
  GuardrailActionNode,
131
135
  ],
132
136
  next_node: str,
137
+ guarded_node_name: str,
133
138
  ) -> str:
134
139
  """Recursively build a chain of guardrail nodes in reverse order.
135
140
 
@@ -157,7 +162,10 @@ def _build_guardrail_node_chain(
157
162
  remaining_guardrails = guardrails[:-1]
158
163
 
159
164
  fail_node_name, fail_node = action.action_node(
160
- guardrail=guardrail, scope=scope, execution_stage=execution_stage
165
+ guardrail=guardrail,
166
+ scope=scope,
167
+ execution_stage=execution_stage,
168
+ guarded_component_name=guarded_node_name,
161
169
  )
162
170
 
163
171
  # Create the guardrail evaluation node.
@@ -179,6 +187,7 @@ def _build_guardrail_node_chain(
179
187
  execution_stage,
180
188
  node_factory,
181
189
  guardrail_node_name,
190
+ guarded_node_name,
182
191
  )
183
192
 
184
193
  return previous_node_name
@@ -193,6 +202,9 @@ def create_llm_guardrails_subgraph(
193
202
  for (guardrail, _) in (guardrails or [])
194
203
  if GuardrailScope.LLM in guardrail.selector.scopes
195
204
  ]
205
+ if applicable_guardrails is None or len(applicable_guardrails) == 0:
206
+ return llm_node[1]
207
+
196
208
  return _create_guardrails_subgraph(
197
209
  main_inner_node=llm_node,
198
210
  guardrails=applicable_guardrails,
@@ -202,6 +214,24 @@ def create_llm_guardrails_subgraph(
202
214
  )
203
215
 
204
216
 
217
+ def create_tools_guardrails_subgraph(
218
+ tool_nodes: dict[str, ToolNode],
219
+ guardrails: Sequence[tuple[BaseGuardrail, GuardrailAction]] | None,
220
+ ) -> dict[str, ToolNode]:
221
+ """Create tool nodes with guardrails.
222
+ Args:
223
+ """
224
+ result: dict[str, ToolNode] = {}
225
+ for tool_name, tool_node in tool_nodes.items():
226
+ subgraph = create_tool_guardrails_subgraph(
227
+ (tool_name, tool_node),
228
+ guardrails,
229
+ )
230
+ result[tool_name] = subgraph
231
+
232
+ return result
233
+
234
+
205
235
  def create_agent_guardrails_subgraph(
206
236
  agent_node: tuple[str, Any],
207
237
  guardrails: Sequence[tuple[BaseGuardrail, GuardrailAction]] | None,
@@ -217,6 +247,9 @@ def create_agent_guardrails_subgraph(
217
247
  for (guardrail, _) in (guardrails or [])
218
248
  if GuardrailScope.AGENT in guardrail.selector.scopes
219
249
  ]
250
+ if applicable_guardrails is None or len(applicable_guardrails) == 0:
251
+ return agent_node[1]
252
+
220
253
  return _create_guardrails_subgraph(
221
254
  main_inner_node=agent_node,
222
255
  guardrails=applicable_guardrails,
@@ -238,10 +271,13 @@ def create_tool_guardrails_subgraph(
238
271
  and guardrail.selector.match_names is not None
239
272
  and tool_name in guardrail.selector.match_names
240
273
  ]
274
+ if applicable_guardrails is None or len(applicable_guardrails) == 0:
275
+ return tool_node[1]
276
+
241
277
  return _create_guardrails_subgraph(
242
278
  main_inner_node=tool_node,
243
279
  guardrails=applicable_guardrails,
244
280
  scope=GuardrailScope.TOOL,
245
281
  execution_stages=[ExecutionStage.PRE_EXECUTION, ExecutionStage.POST_EXECUTION],
246
- node_factory=create_tool_guardrail_node,
282
+ node_factory=partial(create_tool_guardrail_node, tool_name=tool_name),
247
283
  )
@@ -11,6 +11,7 @@ from uipath.platform.guardrails import BaseGuardrail
11
11
 
12
12
  from ..guardrails import create_llm_guardrails_subgraph
13
13
  from ..guardrails.actions import GuardrailAction
14
+ from ..guardrails.guardrails_subgraph import create_tools_guardrails_subgraph
14
15
  from ..tools import create_tool_node
15
16
  from .init_node import (
16
17
  create_init_node,
@@ -73,6 +74,9 @@ def create_agent(
73
74
 
74
75
  init_node = create_init_node(messages)
75
76
  tool_nodes = create_tool_node(agent_tools)
77
+ tool_nodes_with_guardrails = create_tools_guardrails_subgraph(
78
+ tool_nodes, guardrails
79
+ )
76
80
  terminate_node = create_terminate_node(output_schema)
77
81
 
78
82
  InnerAgentGraphState = create_state_with_input(
@@ -84,7 +88,7 @@ def create_agent(
84
88
  )
85
89
  builder.add_node(AgentGraphNode.INIT, init_node)
86
90
 
87
- for tool_name, tool_node in tool_nodes.items():
91
+ for tool_name, tool_node in tool_nodes_with_guardrails.items():
88
92
  builder.add_node(tool_name, tool_node)
89
93
 
90
94
  builder.add_node(AgentGraphNode.TERMINATE, terminate_node)
@@ -98,7 +102,7 @@ def create_agent(
98
102
  builder.add_node(AgentGraphNode.AGENT, llm_with_guardrails_subgraph)
99
103
  builder.add_edge(AgentGraphNode.INIT, AgentGraphNode.AGENT)
100
104
 
101
- tool_node_names = list(tool_nodes.keys())
105
+ tool_node_names = list(tool_nodes_with_guardrails.keys())
102
106
  builder.add_conditional_edges(
103
107
  AgentGraphNode.AGENT,
104
108
  route_agent,
@@ -4,6 +4,7 @@ from typing import Optional
4
4
 
5
5
  import httpx
6
6
  from langchain_openai import AzureChatOpenAI
7
+ from uipath._utils._ssl_context import get_httpx_client_kwargs
7
8
  from uipath.utils import EndpointManager
8
9
 
9
10
  from .supported_models import OpenAIModels
@@ -87,11 +88,11 @@ class UiPathChatOpenAI(AzureChatOpenAI):
87
88
  default_headers=self._build_headers(token),
88
89
  http_async_client=httpx.AsyncClient(
89
90
  transport=UiPathURLRewriteTransport(verify=True),
90
- verify=True,
91
+ **get_httpx_client_kwargs(),
91
92
  ),
92
93
  http_client=httpx.Client(
93
94
  transport=UiPathSyncURLRewriteTransport(verify=True),
94
- verify=True,
95
+ **get_httpx_client_kwargs(),
95
96
  ),
96
97
  api_key=token,
97
98
  api_version=api_version,