uipath-langchain 0.1.24__py3-none-any.whl → 0.1.34__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35):
  1. uipath_langchain/_utils/_request_mixin.py +8 -0
  2. uipath_langchain/_utils/_settings.py +3 -2
  3. uipath_langchain/agent/guardrails/__init__.py +0 -16
  4. uipath_langchain/agent/guardrails/actions/__init__.py +2 -0
  5. uipath_langchain/agent/guardrails/actions/base_action.py +1 -0
  6. uipath_langchain/agent/guardrails/actions/block_action.py +2 -1
  7. uipath_langchain/agent/guardrails/actions/escalate_action.py +243 -35
  8. uipath_langchain/agent/guardrails/actions/filter_action.py +55 -0
  9. uipath_langchain/agent/guardrails/actions/log_action.py +2 -1
  10. uipath_langchain/agent/guardrails/guardrail_nodes.py +186 -22
  11. uipath_langchain/agent/guardrails/guardrails_factory.py +200 -4
  12. uipath_langchain/agent/guardrails/types.py +0 -12
  13. uipath_langchain/agent/guardrails/utils.py +146 -0
  14. uipath_langchain/agent/react/agent.py +25 -8
  15. uipath_langchain/agent/react/constants.py +1 -2
  16. uipath_langchain/agent/{guardrails → react/guardrails}/guardrails_subgraph.py +94 -19
  17. uipath_langchain/agent/react/llm_node.py +41 -10
  18. uipath_langchain/agent/react/router.py +48 -37
  19. uipath_langchain/agent/react/types.py +15 -1
  20. uipath_langchain/agent/react/utils.py +1 -1
  21. uipath_langchain/agent/tools/__init__.py +2 -0
  22. uipath_langchain/agent/tools/mcp_tool.py +86 -0
  23. uipath_langchain/chat/__init__.py +4 -0
  24. uipath_langchain/chat/bedrock.py +16 -0
  25. uipath_langchain/chat/openai.py +57 -26
  26. uipath_langchain/chat/supported_models.py +9 -0
  27. uipath_langchain/chat/vertex.py +271 -0
  28. uipath_langchain/embeddings/embeddings.py +18 -12
  29. uipath_langchain/runtime/schema.py +116 -23
  30. {uipath_langchain-0.1.24.dist-info → uipath_langchain-0.1.34.dist-info}/METADATA +9 -6
  31. {uipath_langchain-0.1.24.dist-info → uipath_langchain-0.1.34.dist-info}/RECORD +34 -31
  32. uipath_langchain/chat/gemini.py +0 -330
  33. {uipath_langchain-0.1.24.dist-info → uipath_langchain-0.1.34.dist-info}/WHEEL +0 -0
  34. {uipath_langchain-0.1.24.dist-info → uipath_langchain-0.1.34.dist-info}/entry_points.txt +0 -0
  35. {uipath_langchain-0.1.24.dist-info → uipath_langchain-0.1.34.dist-info}/licenses/LICENSE +0 -0
@@ -1,26 +1,109 @@
1
+ import json
1
2
  import logging
2
3
  import re
3
4
  from typing import Any, Callable
4
5
 
5
- from langchain_core.messages import AnyMessage, HumanMessage, SystemMessage
6
6
  from langgraph.types import Command
7
+ from uipath.core.guardrails import (
8
+ DeterministicGuardrail,
9
+ DeterministicGuardrailsService,
10
+ )
7
11
  from uipath.platform import UiPath
8
12
  from uipath.platform.guardrails import (
9
13
  BaseGuardrail,
14
+ BuiltInValidatorGuardrail,
10
15
  GuardrailScope,
11
16
  )
17
+ from uipath.runtime.errors import UiPathErrorCode
12
18
 
13
19
  from uipath_langchain.agent.guardrails.types import ExecutionStage
20
+ from uipath_langchain.agent.guardrails.utils import (
21
+ _extract_tool_input_data,
22
+ _extract_tool_output_data,
23
+ get_message_content,
24
+ )
25
+ from uipath_langchain.agent.react.types import AgentGuardrailsGraphState
14
26
 
15
- from .types import AgentGuardrailsGraphState
27
+ from ..exceptions import AgentTerminationException
16
28
 
17
29
  logger = logging.getLogger(__name__)
18
30
 
19
31
 
20
- def _message_text(msg: AnyMessage) -> str:
21
- if isinstance(msg, (HumanMessage, SystemMessage)):
22
- return msg.content if isinstance(msg.content, str) else str(msg.content)
23
- return str(getattr(msg, "content", "")) if hasattr(msg, "content") else ""
32
+ def _evaluate_deterministic_guardrail(
33
+ state: AgentGuardrailsGraphState,
34
+ guardrail: DeterministicGuardrail,
35
+ execution_stage: ExecutionStage,
36
+ input_data_extractor: Callable[[AgentGuardrailsGraphState], dict[str, Any]],
37
+ output_data_extractor: Callable[[AgentGuardrailsGraphState], dict[str, Any]] | None,
38
+ ):
39
+ """Evaluate deterministic guardrail.
40
+
41
+ Args:
42
+ state: The current agent graph state.
43
+ guardrail: The deterministic guardrail to evaluate.
44
+ execution_stage: The execution stage (PRE_EXECUTION or POST_EXECUTION).
45
+ input_data_extractor: Function to extract input data from state.
46
+ output_data_extractor: Function to extract output data from state (optional).
47
+
48
+ Returns:
49
+ The guardrail evaluation result.
50
+ """
51
+ service = DeterministicGuardrailsService()
52
+ input_data = input_data_extractor(state)
53
+
54
+ if execution_stage == ExecutionStage.PRE_EXECUTION:
55
+ return service.evaluate_pre_deterministic_guardrail(
56
+ input_data=input_data, guardrail=guardrail
57
+ )
58
+ else: # POST_EXECUTION
59
+ output_data = output_data_extractor(state) if output_data_extractor else {}
60
+ return service.evaluate_post_deterministic_guardrail(
61
+ input_data=input_data,
62
+ output_data=output_data,
63
+ guardrail=guardrail,
64
+ )
65
+
66
+
67
+ def _evaluate_builtin_guardrail(
68
+ state: AgentGuardrailsGraphState,
69
+ guardrail: BuiltInValidatorGuardrail,
70
+ payload_generator: Callable[[AgentGuardrailsGraphState], str],
71
+ ):
72
+ """Evaluate built-in validator guardrail.
73
+
74
+ Args:
75
+ state: The current agent graph state.
76
+ guardrail: The built-in validator guardrail to evaluate.
77
+ payload_generator: Function to generate payload text from state.
78
+
79
+ Returns:
80
+ The guardrail evaluation result.
81
+ """
82
+ text = payload_generator(state)
83
+ uipath = UiPath()
84
+ return uipath.guardrails.evaluate_guardrail(text, guardrail)
85
+
86
+
87
+ def _create_validation_command(
88
+ result,
89
+ success_node: str,
90
+ failure_node: str,
91
+ ) -> Command[Any]:
92
+ """Create command based on validation result.
93
+
94
+ Args:
95
+ result: The guardrail evaluation result.
96
+ success_node: Node to route to on validation pass.
97
+ failure_node: Node to route to on validation fail.
98
+
99
+ Returns:
100
+ Command to update state and route to appropriate node.
101
+ """
102
+ if not result.validation_passed:
103
+ return Command(
104
+ goto=failure_node, update={"guardrail_validation_result": result.reason}
105
+ )
106
+ return Command(goto=success_node, update={"guardrail_validation_result": None})
24
107
 
25
108
 
26
109
  def _create_guardrail_node(
@@ -30,6 +113,10 @@ def _create_guardrail_node(
30
113
  payload_generator: Callable[[AgentGuardrailsGraphState], str],
31
114
  success_node: str,
32
115
  failure_node: str,
116
+ input_data_extractor: Callable[[AgentGuardrailsGraphState], dict[str, Any]]
117
+ | None = None,
118
+ output_data_extractor: Callable[[AgentGuardrailsGraphState], dict[str, Any]]
119
+ | None = None,
33
120
  ) -> tuple[str, Callable[[AgentGuardrailsGraphState], Any]]:
34
121
  """Private factory for guardrail evaluation nodes.
35
122
 
@@ -43,19 +130,41 @@ def _create_guardrail_node(
43
130
  async def node(
44
131
  state: AgentGuardrailsGraphState,
45
132
  ):
46
- text = payload_generator(state)
47
133
  try:
48
- uipath = UiPath()
49
- result = uipath.guardrails.evaluate_guardrail(text, guardrail)
50
- except Exception as exc:
51
- logger.error("Failed to evaluate guardrail: %s", exc)
52
- raise
134
+ # Route to appropriate evaluation service based on guardrail type and scope
135
+ if (
136
+ isinstance(guardrail, DeterministicGuardrail)
137
+ and scope == GuardrailScope.TOOL
138
+ and input_data_extractor is not None
139
+ ):
140
+ result = _evaluate_deterministic_guardrail(
141
+ state,
142
+ guardrail,
143
+ execution_stage,
144
+ input_data_extractor,
145
+ output_data_extractor,
146
+ )
147
+ elif isinstance(guardrail, BuiltInValidatorGuardrail):
148
+ result = _evaluate_builtin_guardrail(
149
+ state, guardrail, payload_generator
150
+ )
151
+ else:
152
+ raise AgentTerminationException(
153
+ code=UiPathErrorCode.EXECUTION_ERROR,
154
+ title="Unsupported guardrail type",
155
+ detail=f"Guardrail type '{type(guardrail).__name__}' is not supported. "
156
+ f"Expected DeterministicGuardrail or BuiltInValidatorGuardrail.",
157
+ )
158
+
159
+ return _create_validation_command(result, success_node, failure_node)
53
160
 
54
- if not result.validation_passed:
55
- return Command(
56
- goto=failure_node, update={"guardrail_validation_result": result.reason}
161
+ except Exception as exc:
162
+ logger.error(
163
+ "Failed to evaluate guardrail '%s': %s",
164
+ guardrail.name,
165
+ exc,
57
166
  )
58
- return Command(goto=success_node, update={"guardrail_validation_result": None})
167
+ raise
59
168
 
60
169
  return node_name, node
61
170
 
@@ -69,7 +178,7 @@ def create_llm_guardrail_node(
69
178
  def _payload_generator(state: AgentGuardrailsGraphState) -> str:
70
179
  if not state.messages:
71
180
  return ""
72
- return _message_text(state.messages[-1])
181
+ return get_message_content(state.messages[-1])
73
182
 
74
183
  return _create_guardrail_node(
75
184
  guardrail,
@@ -81,17 +190,35 @@ def create_llm_guardrail_node(
81
190
  )
82
191
 
83
192
 
84
- def create_agent_guardrail_node(
193
+ def create_agent_init_guardrail_node(
85
194
  guardrail: BaseGuardrail,
86
195
  execution_stage: ExecutionStage,
87
196
  success_node: str,
88
197
  failure_node: str,
89
198
  ) -> tuple[str, Callable[[AgentGuardrailsGraphState], Any]]:
90
- # To be implemented in future PR
91
199
  def _payload_generator(state: AgentGuardrailsGraphState) -> str:
92
200
  if not state.messages:
93
201
  return ""
94
- return _message_text(state.messages[-1])
202
+ return get_message_content(state.messages[-1])
203
+
204
+ return _create_guardrail_node(
205
+ guardrail,
206
+ GuardrailScope.AGENT,
207
+ execution_stage,
208
+ _payload_generator,
209
+ success_node,
210
+ failure_node,
211
+ )
212
+
213
+
214
+ def create_agent_terminate_guardrail_node(
215
+ guardrail: BaseGuardrail,
216
+ execution_stage: ExecutionStage,
217
+ success_node: str,
218
+ failure_node: str,
219
+ ) -> tuple[str, Callable[[AgentGuardrailsGraphState], Any]]:
220
+ def _payload_generator(state: AgentGuardrailsGraphState) -> str:
221
+ return str(state.agent_result)
95
222
 
96
223
  return _create_guardrail_node(
97
224
  guardrail,
@@ -108,12 +235,47 @@ def create_tool_guardrail_node(
108
235
  execution_stage: ExecutionStage,
109
236
  success_node: str,
110
237
  failure_node: str,
238
+ tool_name: str,
111
239
  ) -> tuple[str, Callable[[AgentGuardrailsGraphState], Any]]:
112
- # To be implemented in future PR
240
+ """Create a guardrail node for TOOL scope guardrails.
241
+
242
+ Args:
243
+ guardrail: The guardrail to evaluate.
244
+ execution_stage: The execution stage (PRE_EXECUTION or POST_EXECUTION).
245
+ success_node: Node to route to on validation pass.
246
+ failure_node: Node to route to on validation fail.
247
+ tool_name: Name of the tool to extract arguments from.
248
+
249
+ Returns:
250
+ A tuple of (node_name, node_function) for the guardrail evaluation node.
251
+ """
252
+
113
253
  def _payload_generator(state: AgentGuardrailsGraphState) -> str:
254
+ """Extract tool call arguments for the specified tool name.
255
+
256
+ Args:
257
+ state: The current agent graph state.
258
+
259
+ Returns:
260
+ JSON string of the tool call arguments, or empty string if not found.
261
+ """
114
262
  if not state.messages:
115
263
  return ""
116
- return _message_text(state.messages[-1])
264
+
265
+ if execution_stage == ExecutionStage.PRE_EXECUTION:
266
+ # Extract tool args as dict and convert to JSON string
267
+ args_dict = _extract_tool_input_data(state, tool_name, execution_stage)
268
+ if args_dict:
269
+ return json.dumps(args_dict)
270
+
271
+ return get_message_content(state.messages[-1])
272
+
273
+ # Create closures for input/output data extraction (for deterministic guardrails)
274
+ def _input_data_extractor(state: AgentGuardrailsGraphState) -> dict[str, Any]:
275
+ return _extract_tool_input_data(state, tool_name, execution_stage)
276
+
277
+ def _output_data_extractor(state: AgentGuardrailsGraphState) -> dict[str, Any]:
278
+ return _extract_tool_output_data(state)
117
279
 
118
280
  return _create_guardrail_node(
119
281
  guardrail,
@@ -122,4 +284,6 @@ def create_tool_guardrail_node(
122
284
  _payload_generator,
123
285
  success_node,
124
286
  failure_node,
287
+ _input_data_extractor,
288
+ _output_data_extractor,
125
289
  )
@@ -1,13 +1,28 @@
1
1
  import logging
2
- from typing import Sequence
2
+ import re
3
+ from typing import Callable, Sequence
3
4
 
4
5
  from uipath.agent.models.agent import (
6
+ AgentBooleanOperator,
7
+ AgentBooleanRule,
8
+ AgentCustomGuardrail,
5
9
  AgentGuardrail,
6
10
  AgentGuardrailBlockAction,
7
11
  AgentGuardrailEscalateAction,
8
12
  AgentGuardrailLogAction,
9
13
  AgentGuardrailSeverityLevel,
14
+ AgentNumberOperator,
15
+ AgentNumberRule,
10
16
  AgentUnknownGuardrail,
17
+ AgentWordOperator,
18
+ AgentWordRule,
19
+ )
20
+ from uipath.core.guardrails import (
21
+ BooleanRule,
22
+ DeterministicGuardrail,
23
+ NumberRule,
24
+ UniversalRule,
25
+ WordRule,
11
26
  )
12
27
  from uipath.platform.guardrails import BaseGuardrail
13
28
 
@@ -19,6 +34,180 @@ from uipath_langchain.agent.guardrails.actions import (
19
34
  )
20
35
 
21
36
 
37
+ def _assert_value_not_none(value: str | None, operator: AgentWordOperator) -> str:
38
+ """Assert value is not None and return as string."""
39
+ assert value is not None, f"value cannot be None for {operator.name} operator"
40
+ return value
41
+
42
+
43
+ def _create_word_rule_func(
44
+ operator: AgentWordOperator, value: str | None
45
+ ) -> Callable[[str], bool]:
46
+ """Create a callable function from AgentWordOperator and value.
47
+
48
+ Args:
49
+ operator: The word operator to convert.
50
+ value: The value to compare against (may be None for isEmpty/isNotEmpty).
51
+
52
+ Returns:
53
+ A callable that takes a string and returns a boolean.
54
+ """
55
+ match operator:
56
+ case AgentWordOperator.CONTAINS:
57
+ val = _assert_value_not_none(value, operator)
58
+ return lambda s: val.lower() in s.lower()
59
+ case AgentWordOperator.DOES_NOT_CONTAIN:
60
+ val = _assert_value_not_none(value, operator)
61
+ return lambda s: val.lower() not in s.lower()
62
+ case AgentWordOperator.EQUALS:
63
+ val = _assert_value_not_none(value, operator)
64
+ return lambda s: s == val
65
+ case AgentWordOperator.DOES_NOT_EQUAL:
66
+ val = _assert_value_not_none(value, operator)
67
+ return lambda s: s != val
68
+ case AgentWordOperator.STARTS_WITH:
69
+ val = _assert_value_not_none(value, operator)
70
+ return lambda s: s.startswith(val)
71
+ case AgentWordOperator.DOES_NOT_START_WITH:
72
+ val = _assert_value_not_none(value, operator)
73
+ return lambda s: not s.startswith(val)
74
+ case AgentWordOperator.ENDS_WITH:
75
+ val = _assert_value_not_none(value, operator)
76
+ return lambda s: s.endswith(val)
77
+ case AgentWordOperator.DOES_NOT_END_WITH:
78
+ val = _assert_value_not_none(value, operator)
79
+ return lambda s: not s.endswith(val)
80
+ case AgentWordOperator.IS_EMPTY:
81
+ return lambda s: len(s) == 0
82
+ case AgentWordOperator.IS_NOT_EMPTY:
83
+ return lambda s: len(s) > 0
84
+ case AgentWordOperator.MATCHES_REGEX:
85
+ val = _assert_value_not_none(value, operator)
86
+ pattern = re.compile(val)
87
+ return lambda s: bool(pattern.match(s))
88
+ case _:
89
+ raise ValueError(f"Unsupported word operator: {operator}")
90
+
91
+
92
+ def _create_number_rule_func(
93
+ operator: AgentNumberOperator, value: float
94
+ ) -> Callable[[float], bool]:
95
+ """Create a callable function from AgentNumberOperator and value.
96
+
97
+ Args:
98
+ operator: The number operator to convert.
99
+ value: The value to compare against.
100
+
101
+ Returns:
102
+ A callable that takes a float and returns a boolean.
103
+ """
104
+ match operator:
105
+ case AgentNumberOperator.EQUALS:
106
+ return lambda n: n == value
107
+ case AgentNumberOperator.DOES_NOT_EQUAL:
108
+ return lambda n: n != value
109
+ case AgentNumberOperator.GREATER_THAN:
110
+ return lambda n: n > value
111
+ case AgentNumberOperator.GREATER_THAN_OR_EQUAL:
112
+ return lambda n: n >= value
113
+ case AgentNumberOperator.LESS_THAN:
114
+ return lambda n: n < value
115
+ case AgentNumberOperator.LESS_THAN_OR_EQUAL:
116
+ return lambda n: n <= value
117
+ case _:
118
+ raise ValueError(f"Unsupported number operator: {operator}")
119
+
120
+
121
+ def _create_boolean_rule_func(
122
+ operator: AgentBooleanOperator, value: bool
123
+ ) -> Callable[[bool], bool]:
124
+ """Create a callable function from AgentBooleanOperator and value.
125
+
126
+ Args:
127
+ operator: The boolean operator to convert.
128
+ value: The value to compare against.
129
+
130
+ Returns:
131
+ A callable that takes a boolean and returns a boolean.
132
+ """
133
+ match operator:
134
+ case AgentBooleanOperator.EQUALS:
135
+ return lambda b: b == value
136
+ case _:
137
+ raise ValueError(f"Unsupported boolean operator: {operator}")
138
+
139
+
140
+ def _convert_agent_rule_to_deterministic(
141
+ agent_rule: AgentWordRule | AgentNumberRule | AgentBooleanRule | UniversalRule,
142
+ ) -> WordRule | NumberRule | BooleanRule | UniversalRule:
143
+ """Convert an Agent rule to its Deterministic equivalent.
144
+
145
+ Args:
146
+ agent_rule: The agent rule to convert.
147
+
148
+ Returns:
149
+ The corresponding deterministic rule with a callable function.
150
+ """
151
+ if isinstance(agent_rule, UniversalRule):
152
+ # UniversalRule is already compatible
153
+ return agent_rule
154
+
155
+ if isinstance(agent_rule, AgentWordRule):
156
+ return WordRule(
157
+ rule_type="word",
158
+ field_selector=agent_rule.field_selector,
159
+ detects_violation=_create_word_rule_func(
160
+ agent_rule.operator, agent_rule.value
161
+ ),
162
+ )
163
+
164
+ if isinstance(agent_rule, AgentNumberRule):
165
+ return NumberRule(
166
+ rule_type="number",
167
+ field_selector=agent_rule.field_selector,
168
+ detects_violation=_create_number_rule_func(
169
+ agent_rule.operator, agent_rule.value
170
+ ),
171
+ )
172
+
173
+ if isinstance(agent_rule, AgentBooleanRule):
174
+ return BooleanRule(
175
+ rule_type="boolean",
176
+ field_selector=agent_rule.field_selector,
177
+ detects_violation=_create_boolean_rule_func(
178
+ agent_rule.operator, agent_rule.value
179
+ ),
180
+ )
181
+
182
+ raise ValueError(f"Unsupported agent rule type: {type(agent_rule)}")
183
+
184
+
185
+ def _convert_agent_custom_guardrail_to_deterministic(
186
+ guardrail: AgentCustomGuardrail,
187
+ ) -> DeterministicGuardrail:
188
+ """Convert AgentCustomGuardrail to DeterministicGuardrail.
189
+
190
+ Args:
191
+ guardrail: The agent custom guardrail to convert.
192
+
193
+ Returns:
194
+ A DeterministicGuardrail with converted rules.
195
+ """
196
+ converted_rules = [
197
+ _convert_agent_rule_to_deterministic(rule) for rule in guardrail.rules
198
+ ]
199
+
200
+ return DeterministicGuardrail(
201
+ id=guardrail.id,
202
+ name=guardrail.name,
203
+ description=guardrail.description,
204
+ enabled_for_evals=guardrail.enabled_for_evals,
205
+ selector=guardrail.selector,
206
+ guardrail_type="custom",
207
+ rules=converted_rules,
208
+ )
209
+
210
+
22
211
  def build_guardrails_with_actions(
23
212
  guardrails: Sequence[AgentGuardrail] | None,
24
213
  ) -> list[tuple[BaseGuardrail, GuardrailAction]]:
@@ -38,10 +227,17 @@ def build_guardrails_with_actions(
38
227
  if isinstance(guardrail, AgentUnknownGuardrail):
39
228
  continue
40
229
 
230
+ # Convert AgentCustomGuardrail to DeterministicGuardrail
231
+ converted_guardrail: BaseGuardrail = guardrail
232
+ if isinstance(guardrail, AgentCustomGuardrail):
233
+ converted_guardrail = _convert_agent_custom_guardrail_to_deterministic(
234
+ guardrail
235
+ )
236
+
41
237
  action = guardrail.action
42
238
 
43
239
  if isinstance(action, AgentGuardrailBlockAction):
44
- result.append((guardrail, BlockAction(action.reason)))
240
+ result.append((converted_guardrail, BlockAction(action.reason)))
45
241
  elif isinstance(action, AgentGuardrailLogAction):
46
242
  severity_level_map = {
47
243
  AgentGuardrailSeverityLevel.ERROR: logging.ERROR,
@@ -51,14 +247,14 @@ def build_guardrails_with_actions(
51
247
  level = severity_level_map.get(action.severity_level, logging.INFO)
52
248
  result.append(
53
249
  (
54
- guardrail,
250
+ converted_guardrail,
55
251
  LogAction(message=action.message, level=level),
56
252
  )
57
253
  )
58
254
  elif isinstance(action, AgentGuardrailEscalateAction):
59
255
  result.append(
60
256
  (
61
- guardrail,
257
+ converted_guardrail,
62
258
  EscalateAction(
63
259
  app_name=action.app.name,
64
260
  app_folder_path=action.app.folder_name,
@@ -1,16 +1,4 @@
1
1
  from enum import Enum
2
- from typing import Annotated, Optional
3
-
4
- from langchain_core.messages import AnyMessage
5
- from langgraph.graph.message import add_messages
6
- from pydantic import BaseModel
7
-
8
-
9
- class AgentGuardrailsGraphState(BaseModel):
10
- """Agent Guardrails Graph state for guardrail subgraph."""
11
-
12
- messages: Annotated[list[AnyMessage], add_messages] = []
13
- guardrail_validation_result: Optional[str] = None
14
2
 
15
3
 
16
4
  class ExecutionStage(str, Enum):
@@ -0,0 +1,146 @@
1
+ import json
2
+ import logging
3
+ from typing import Any
4
+
5
+ from langchain_core.messages import (
6
+ AIMessage,
7
+ AnyMessage,
8
+ HumanMessage,
9
+ SystemMessage,
10
+ ToolMessage,
11
+ )
12
+
13
+ from uipath_langchain.agent.guardrails.types import ExecutionStage
14
+ from uipath_langchain.agent.react.types import AgentGuardrailsGraphState
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
+
19
+ def _extract_tool_args_from_message(
20
+ message: AnyMessage, tool_name: str
21
+ ) -> dict[str, Any]:
22
+ """Extract tool call arguments from an AIMessage.
23
+
24
+ Args:
25
+ message: The message to extract from.
26
+ tool_name: Name of the tool to extract arguments from.
27
+
28
+ Returns:
29
+ Dict containing tool call arguments, or empty dict if not found.
30
+ """
31
+ if not isinstance(message, AIMessage):
32
+ return {}
33
+
34
+ if not message.tool_calls:
35
+ return {}
36
+
37
+ # Find the first tool call with matching name
38
+ for tool_call in message.tool_calls:
39
+ call_name = (
40
+ tool_call.get("name")
41
+ if isinstance(tool_call, dict)
42
+ else getattr(tool_call, "name", None)
43
+ )
44
+ if call_name == tool_name:
45
+ # Extract args from the tool call
46
+ args = (
47
+ tool_call.get("args")
48
+ if isinstance(tool_call, dict)
49
+ else getattr(tool_call, "args", None)
50
+ )
51
+ if args is not None:
52
+ # Args should already be a dict
53
+ if isinstance(args, dict):
54
+ return args
55
+ # If it's a JSON string, parse it
56
+ if isinstance(args, str):
57
+ try:
58
+ parsed = json.loads(args)
59
+ if isinstance(parsed, dict):
60
+ return parsed
61
+ except json.JSONDecodeError:
62
+ logger.warning(
63
+ "Failed to parse tool args as JSON for tool '%s': %s",
64
+ tool_name,
65
+ args[:100] if len(args) > 100 else args,
66
+ )
67
+ return {}
68
+
69
+ return {}
70
+
71
+
72
+ def _extract_tool_input_data(
73
+ state: AgentGuardrailsGraphState, tool_name: str, execution_stage: ExecutionStage
74
+ ) -> dict[str, Any]:
75
+ """Extract tool call arguments as dict for deterministic guardrails.
76
+
77
+ Args:
78
+ state: The current agent graph state.
79
+ tool_name: Name of the tool to extract arguments from.
80
+ execution_stage: PRE_EXECUTION or POST_EXECUTION.
81
+
82
+ Returns:
83
+ Dict containing tool call arguments, or empty dict if not found.
84
+ - For PRE_EXECUTION: extracts from last message
85
+ - For POST_EXECUTION: extracts from second-to-last message
86
+ """
87
+ if not state.messages:
88
+ return {}
89
+
90
+ # For PRE_EXECUTION, look at last message
91
+ # For POST_EXECUTION, look at second-to-last message (before the ToolMessage)
92
+ if execution_stage == ExecutionStage.PRE_EXECUTION:
93
+ if len(state.messages) < 1:
94
+ return {}
95
+ message = state.messages[-1]
96
+ else: # POST_EXECUTION
97
+ if len(state.messages) < 2:
98
+ return {}
99
+ message = state.messages[-2]
100
+
101
+ return _extract_tool_args_from_message(message, tool_name)
102
+
103
+
104
+ def _extract_tool_output_data(state: AgentGuardrailsGraphState) -> dict[str, Any]:
105
+ """Extract tool execution output as dict for POST_EXECUTION deterministic guardrails.
106
+
107
+ Args:
108
+ state: The current agent graph state.
109
+
110
+ Returns:
111
+ Dict containing tool output. If output is not valid JSON, wraps it in {"output": content}.
112
+ """
113
+ if not state.messages:
114
+ return {}
115
+
116
+ last_message = state.messages[-1]
117
+ if not isinstance(last_message, ToolMessage):
118
+ return {}
119
+
120
+ content = last_message.content
121
+ if not content:
122
+ return {}
123
+
124
+ # Try to parse as JSON first
125
+ if isinstance(content, str):
126
+ try:
127
+ parsed = json.loads(content)
128
+ if isinstance(parsed, dict):
129
+ return parsed
130
+ else:
131
+ # JSON array or primitive - wrap it
132
+ return {"output": parsed}
133
+ except json.JSONDecodeError:
134
+ logger.warning("Tool output is not valid JSON")
135
+ return {"output": content}
136
+ elif isinstance(content, dict):
137
+ return content
138
+ else:
139
+ # List or other type
140
+ return {"output": content}
141
+
142
+
143
+ def get_message_content(msg: AnyMessage) -> str:
144
+ if isinstance(msg, (HumanMessage, SystemMessage)):
145
+ return msg.content if isinstance(msg.content, str) else str(msg.content)
146
+ return str(getattr(msg, "content", "")) if hasattr(msg, "content") else ""