uipath-langchain: uipath_langchain-0.1.28-py3-none-any.whl → uipath_langchain-0.1.34-py3-none-any.whl

This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
Files changed (33):
  1. uipath_langchain/_utils/_request_mixin.py +8 -0
  2. uipath_langchain/_utils/_settings.py +3 -2
  3. uipath_langchain/agent/guardrails/__init__.py +0 -16
  4. uipath_langchain/agent/guardrails/actions/__init__.py +2 -0
  5. uipath_langchain/agent/guardrails/actions/block_action.py +1 -1
  6. uipath_langchain/agent/guardrails/actions/escalate_action.py +17 -34
  7. uipath_langchain/agent/guardrails/actions/filter_action.py +55 -0
  8. uipath_langchain/agent/guardrails/actions/log_action.py +1 -1
  9. uipath_langchain/agent/guardrails/guardrail_nodes.py +161 -45
  10. uipath_langchain/agent/guardrails/guardrails_factory.py +200 -4
  11. uipath_langchain/agent/guardrails/types.py +0 -12
  12. uipath_langchain/agent/guardrails/utils.py +146 -0
  13. uipath_langchain/agent/react/agent.py +20 -7
  14. uipath_langchain/agent/react/constants.py +1 -2
  15. uipath_langchain/agent/{guardrails → react/guardrails}/guardrails_subgraph.py +57 -18
  16. uipath_langchain/agent/react/llm_node.py +41 -10
  17. uipath_langchain/agent/react/router.py +48 -37
  18. uipath_langchain/agent/react/types.py +15 -1
  19. uipath_langchain/agent/react/utils.py +1 -1
  20. uipath_langchain/agent/tools/__init__.py +2 -0
  21. uipath_langchain/agent/tools/mcp_tool.py +86 -0
  22. uipath_langchain/chat/__init__.py +4 -0
  23. uipath_langchain/chat/bedrock.py +16 -0
  24. uipath_langchain/chat/openai.py +56 -26
  25. uipath_langchain/chat/supported_models.py +9 -0
  26. uipath_langchain/chat/vertex.py +62 -46
  27. uipath_langchain/embeddings/embeddings.py +18 -12
  28. uipath_langchain/runtime/schema.py +72 -16
  29. {uipath_langchain-0.1.28.dist-info → uipath_langchain-0.1.34.dist-info}/METADATA +4 -2
  30. {uipath_langchain-0.1.28.dist-info → uipath_langchain-0.1.34.dist-info}/RECORD +33 -30
  31. {uipath_langchain-0.1.28.dist-info → uipath_langchain-0.1.34.dist-info}/WHEEL +0 -0
  32. {uipath_langchain-0.1.28.dist-info → uipath_langchain-0.1.34.dist-info}/entry_points.txt +0 -0
  33. {uipath_langchain-0.1.28.dist-info → uipath_langchain-0.1.34.dist-info}/licenses/LICENSE +0 -0
File: uipath_langchain/agent/guardrails/guardrails_factory.py
@@ -1,13 +1,28 @@
1
1
  import logging
2
- from typing import Sequence
2
+ import re
3
+ from typing import Callable, Sequence
3
4
 
4
5
  from uipath.agent.models.agent import (
6
+ AgentBooleanOperator,
7
+ AgentBooleanRule,
8
+ AgentCustomGuardrail,
5
9
  AgentGuardrail,
6
10
  AgentGuardrailBlockAction,
7
11
  AgentGuardrailEscalateAction,
8
12
  AgentGuardrailLogAction,
9
13
  AgentGuardrailSeverityLevel,
14
+ AgentNumberOperator,
15
+ AgentNumberRule,
10
16
  AgentUnknownGuardrail,
17
+ AgentWordOperator,
18
+ AgentWordRule,
19
+ )
20
+ from uipath.core.guardrails import (
21
+ BooleanRule,
22
+ DeterministicGuardrail,
23
+ NumberRule,
24
+ UniversalRule,
25
+ WordRule,
11
26
  )
12
27
  from uipath.platform.guardrails import BaseGuardrail
13
28
 
@@ -19,6 +34,180 @@ from uipath_langchain.agent.guardrails.actions import (
19
34
  )
20
35
 
21
36
 
37
+ def _assert_value_not_none(value: str | None, operator: AgentWordOperator) -> str:
38
+ """Assert value is not None and return as string."""
39
+ assert value is not None, f"value cannot be None for {operator.name} operator"
40
+ return value
41
+
42
+
43
+ def _create_word_rule_func(
44
+ operator: AgentWordOperator, value: str | None
45
+ ) -> Callable[[str], bool]:
46
+ """Create a callable function from AgentWordOperator and value.
47
+
48
+ Args:
49
+ operator: The word operator to convert.
50
+ value: The value to compare against (may be None for isEmpty/isNotEmpty).
51
+
52
+ Returns:
53
+ A callable that takes a string and returns a boolean.
54
+ """
55
+ match operator:
56
+ case AgentWordOperator.CONTAINS:
57
+ val = _assert_value_not_none(value, operator)
58
+ return lambda s: val.lower() in s.lower()
59
+ case AgentWordOperator.DOES_NOT_CONTAIN:
60
+ val = _assert_value_not_none(value, operator)
61
+ return lambda s: val.lower() not in s.lower()
62
+ case AgentWordOperator.EQUALS:
63
+ val = _assert_value_not_none(value, operator)
64
+ return lambda s: s == val
65
+ case AgentWordOperator.DOES_NOT_EQUAL:
66
+ val = _assert_value_not_none(value, operator)
67
+ return lambda s: s != val
68
+ case AgentWordOperator.STARTS_WITH:
69
+ val = _assert_value_not_none(value, operator)
70
+ return lambda s: s.startswith(val)
71
+ case AgentWordOperator.DOES_NOT_START_WITH:
72
+ val = _assert_value_not_none(value, operator)
73
+ return lambda s: not s.startswith(val)
74
+ case AgentWordOperator.ENDS_WITH:
75
+ val = _assert_value_not_none(value, operator)
76
+ return lambda s: s.endswith(val)
77
+ case AgentWordOperator.DOES_NOT_END_WITH:
78
+ val = _assert_value_not_none(value, operator)
79
+ return lambda s: not s.endswith(val)
80
+ case AgentWordOperator.IS_EMPTY:
81
+ return lambda s: len(s) == 0
82
+ case AgentWordOperator.IS_NOT_EMPTY:
83
+ return lambda s: len(s) > 0
84
+ case AgentWordOperator.MATCHES_REGEX:
85
+ val = _assert_value_not_none(value, operator)
86
+ pattern = re.compile(val)
87
+ return lambda s: bool(pattern.match(s))
88
+ case _:
89
+ raise ValueError(f"Unsupported word operator: {operator}")
90
+
91
+
92
+ def _create_number_rule_func(
93
+ operator: AgentNumberOperator, value: float
94
+ ) -> Callable[[float], bool]:
95
+ """Create a callable function from AgentNumberOperator and value.
96
+
97
+ Args:
98
+ operator: The number operator to convert.
99
+ value: The value to compare against.
100
+
101
+ Returns:
102
+ A callable that takes a float and returns a boolean.
103
+ """
104
+ match operator:
105
+ case AgentNumberOperator.EQUALS:
106
+ return lambda n: n == value
107
+ case AgentNumberOperator.DOES_NOT_EQUAL:
108
+ return lambda n: n != value
109
+ case AgentNumberOperator.GREATER_THAN:
110
+ return lambda n: n > value
111
+ case AgentNumberOperator.GREATER_THAN_OR_EQUAL:
112
+ return lambda n: n >= value
113
+ case AgentNumberOperator.LESS_THAN:
114
+ return lambda n: n < value
115
+ case AgentNumberOperator.LESS_THAN_OR_EQUAL:
116
+ return lambda n: n <= value
117
+ case _:
118
+ raise ValueError(f"Unsupported number operator: {operator}")
119
+
120
+
121
+ def _create_boolean_rule_func(
122
+ operator: AgentBooleanOperator, value: bool
123
+ ) -> Callable[[bool], bool]:
124
+ """Create a callable function from AgentBooleanOperator and value.
125
+
126
+ Args:
127
+ operator: The boolean operator to convert.
128
+ value: The value to compare against.
129
+
130
+ Returns:
131
+ A callable that takes a boolean and returns a boolean.
132
+ """
133
+ match operator:
134
+ case AgentBooleanOperator.EQUALS:
135
+ return lambda b: b == value
136
+ case _:
137
+ raise ValueError(f"Unsupported boolean operator: {operator}")
138
+
139
+
140
+ def _convert_agent_rule_to_deterministic(
141
+ agent_rule: AgentWordRule | AgentNumberRule | AgentBooleanRule | UniversalRule,
142
+ ) -> WordRule | NumberRule | BooleanRule | UniversalRule:
143
+ """Convert an Agent rule to its Deterministic equivalent.
144
+
145
+ Args:
146
+ agent_rule: The agent rule to convert.
147
+
148
+ Returns:
149
+ The corresponding deterministic rule with a callable function.
150
+ """
151
+ if isinstance(agent_rule, UniversalRule):
152
+ # UniversalRule is already compatible
153
+ return agent_rule
154
+
155
+ if isinstance(agent_rule, AgentWordRule):
156
+ return WordRule(
157
+ rule_type="word",
158
+ field_selector=agent_rule.field_selector,
159
+ detects_violation=_create_word_rule_func(
160
+ agent_rule.operator, agent_rule.value
161
+ ),
162
+ )
163
+
164
+ if isinstance(agent_rule, AgentNumberRule):
165
+ return NumberRule(
166
+ rule_type="number",
167
+ field_selector=agent_rule.field_selector,
168
+ detects_violation=_create_number_rule_func(
169
+ agent_rule.operator, agent_rule.value
170
+ ),
171
+ )
172
+
173
+ if isinstance(agent_rule, AgentBooleanRule):
174
+ return BooleanRule(
175
+ rule_type="boolean",
176
+ field_selector=agent_rule.field_selector,
177
+ detects_violation=_create_boolean_rule_func(
178
+ agent_rule.operator, agent_rule.value
179
+ ),
180
+ )
181
+
182
+ raise ValueError(f"Unsupported agent rule type: {type(agent_rule)}")
183
+
184
+
185
+ def _convert_agent_custom_guardrail_to_deterministic(
186
+ guardrail: AgentCustomGuardrail,
187
+ ) -> DeterministicGuardrail:
188
+ """Convert AgentCustomGuardrail to DeterministicGuardrail.
189
+
190
+ Args:
191
+ guardrail: The agent custom guardrail to convert.
192
+
193
+ Returns:
194
+ A DeterministicGuardrail with converted rules.
195
+ """
196
+ converted_rules = [
197
+ _convert_agent_rule_to_deterministic(rule) for rule in guardrail.rules
198
+ ]
199
+
200
+ return DeterministicGuardrail(
201
+ id=guardrail.id,
202
+ name=guardrail.name,
203
+ description=guardrail.description,
204
+ enabled_for_evals=guardrail.enabled_for_evals,
205
+ selector=guardrail.selector,
206
+ guardrail_type="custom",
207
+ rules=converted_rules,
208
+ )
209
+
210
+
22
211
  def build_guardrails_with_actions(
23
212
  guardrails: Sequence[AgentGuardrail] | None,
24
213
  ) -> list[tuple[BaseGuardrail, GuardrailAction]]:
@@ -38,10 +227,17 @@ def build_guardrails_with_actions(
38
227
  if isinstance(guardrail, AgentUnknownGuardrail):
39
228
  continue
40
229
 
230
+ # Convert AgentCustomGuardrail to DeterministicGuardrail
231
+ converted_guardrail: BaseGuardrail = guardrail
232
+ if isinstance(guardrail, AgentCustomGuardrail):
233
+ converted_guardrail = _convert_agent_custom_guardrail_to_deterministic(
234
+ guardrail
235
+ )
236
+
41
237
  action = guardrail.action
42
238
 
43
239
  if isinstance(action, AgentGuardrailBlockAction):
44
- result.append((guardrail, BlockAction(action.reason)))
240
+ result.append((converted_guardrail, BlockAction(action.reason)))
45
241
  elif isinstance(action, AgentGuardrailLogAction):
46
242
  severity_level_map = {
47
243
  AgentGuardrailSeverityLevel.ERROR: logging.ERROR,
@@ -51,14 +247,14 @@ def build_guardrails_with_actions(
51
247
  level = severity_level_map.get(action.severity_level, logging.INFO)
52
248
  result.append(
53
249
  (
54
- guardrail,
250
+ converted_guardrail,
55
251
  LogAction(message=action.message, level=level),
56
252
  )
57
253
  )
58
254
  elif isinstance(action, AgentGuardrailEscalateAction):
59
255
  result.append(
60
256
  (
61
- guardrail,
257
+ converted_guardrail,
62
258
  EscalateAction(
63
259
  app_name=action.app.name,
64
260
  app_folder_path=action.app.folder_name,
File: uipath_langchain/agent/guardrails/types.py
@@ -1,16 +1,4 @@
1
1
  from enum import Enum
2
- from typing import Annotated, Optional
3
-
4
- from langchain_core.messages import AnyMessage
5
- from langgraph.graph.message import add_messages
6
- from pydantic import BaseModel
7
-
8
-
9
- class AgentGuardrailsGraphState(BaseModel):
10
- """Agent Guardrails Graph state for guardrail subgraph."""
11
-
12
- messages: Annotated[list[AnyMessage], add_messages] = []
13
- guardrail_validation_result: Optional[str] = None
14
2
 
15
3
 
16
4
  class ExecutionStage(str, Enum):
File: uipath_langchain/agent/guardrails/utils.py (new file)
@@ -0,0 +1,146 @@
1
+ import json
2
+ import logging
3
+ from typing import Any
4
+
5
+ from langchain_core.messages import (
6
+ AIMessage,
7
+ AnyMessage,
8
+ HumanMessage,
9
+ SystemMessage,
10
+ ToolMessage,
11
+ )
12
+
13
+ from uipath_langchain.agent.guardrails.types import ExecutionStage
14
+ from uipath_langchain.agent.react.types import AgentGuardrailsGraphState
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
+
19
+ def _extract_tool_args_from_message(
20
+ message: AnyMessage, tool_name: str
21
+ ) -> dict[str, Any]:
22
+ """Extract tool call arguments from an AIMessage.
23
+
24
+ Args:
25
+ message: The message to extract from.
26
+ tool_name: Name of the tool to extract arguments from.
27
+
28
+ Returns:
29
+ Dict containing tool call arguments, or empty dict if not found.
30
+ """
31
+ if not isinstance(message, AIMessage):
32
+ return {}
33
+
34
+ if not message.tool_calls:
35
+ return {}
36
+
37
+ # Find the first tool call with matching name
38
+ for tool_call in message.tool_calls:
39
+ call_name = (
40
+ tool_call.get("name")
41
+ if isinstance(tool_call, dict)
42
+ else getattr(tool_call, "name", None)
43
+ )
44
+ if call_name == tool_name:
45
+ # Extract args from the tool call
46
+ args = (
47
+ tool_call.get("args")
48
+ if isinstance(tool_call, dict)
49
+ else getattr(tool_call, "args", None)
50
+ )
51
+ if args is not None:
52
+ # Args should already be a dict
53
+ if isinstance(args, dict):
54
+ return args
55
+ # If it's a JSON string, parse it
56
+ if isinstance(args, str):
57
+ try:
58
+ parsed = json.loads(args)
59
+ if isinstance(parsed, dict):
60
+ return parsed
61
+ except json.JSONDecodeError:
62
+ logger.warning(
63
+ "Failed to parse tool args as JSON for tool '%s': %s",
64
+ tool_name,
65
+ args[:100] if len(args) > 100 else args,
66
+ )
67
+ return {}
68
+
69
+ return {}
70
+
71
+
72
+ def _extract_tool_input_data(
73
+ state: AgentGuardrailsGraphState, tool_name: str, execution_stage: ExecutionStage
74
+ ) -> dict[str, Any]:
75
+ """Extract tool call arguments as dict for deterministic guardrails.
76
+
77
+ Args:
78
+ state: The current agent graph state.
79
+ tool_name: Name of the tool to extract arguments from.
80
+ execution_stage: PRE_EXECUTION or POST_EXECUTION.
81
+
82
+ Returns:
83
+ Dict containing tool call arguments, or empty dict if not found.
84
+ - For PRE_EXECUTION: extracts from last message
85
+ - For POST_EXECUTION: extracts from second-to-last message
86
+ """
87
+ if not state.messages:
88
+ return {}
89
+
90
+ # For PRE_EXECUTION, look at last message
91
+ # For POST_EXECUTION, look at second-to-last message (before the ToolMessage)
92
+ if execution_stage == ExecutionStage.PRE_EXECUTION:
93
+ if len(state.messages) < 1:
94
+ return {}
95
+ message = state.messages[-1]
96
+ else: # POST_EXECUTION
97
+ if len(state.messages) < 2:
98
+ return {}
99
+ message = state.messages[-2]
100
+
101
+ return _extract_tool_args_from_message(message, tool_name)
102
+
103
+
104
+ def _extract_tool_output_data(state: AgentGuardrailsGraphState) -> dict[str, Any]:
105
+ """Extract tool execution output as dict for POST_EXECUTION deterministic guardrails.
106
+
107
+ Args:
108
+ state: The current agent graph state.
109
+
110
+ Returns:
111
+ Dict containing tool output. If output is not valid JSON, wraps it in {"output": content}.
112
+ """
113
+ if not state.messages:
114
+ return {}
115
+
116
+ last_message = state.messages[-1]
117
+ if not isinstance(last_message, ToolMessage):
118
+ return {}
119
+
120
+ content = last_message.content
121
+ if not content:
122
+ return {}
123
+
124
+ # Try to parse as JSON first
125
+ if isinstance(content, str):
126
+ try:
127
+ parsed = json.loads(content)
128
+ if isinstance(parsed, dict):
129
+ return parsed
130
+ else:
131
+ # JSON array or primitive - wrap it
132
+ return {"output": parsed}
133
+ except json.JSONDecodeError:
134
+ logger.warning("Tool output is not valid JSON")
135
+ return {"output": content}
136
+ elif isinstance(content, dict):
137
+ return content
138
+ else:
139
+ # List or other type
140
+ return {"output": content}
141
+
142
+
143
+ def get_message_content(msg: AnyMessage) -> str:
144
+ if isinstance(msg, (HumanMessage, SystemMessage)):
145
+ return msg.content if isinstance(msg.content, str) else str(msg.content)
146
+ return str(getattr(msg, "content", "")) if hasattr(msg, "content") else ""
File: uipath_langchain/agent/react/agent.py
@@ -9,10 +9,14 @@ from langgraph.graph import StateGraph
9
9
  from pydantic import BaseModel
10
10
  from uipath.platform.guardrails import BaseGuardrail
11
11
 
12
- from ..guardrails import create_llm_guardrails_subgraph
13
12
  from ..guardrails.actions import GuardrailAction
14
- from ..guardrails.guardrails_subgraph import create_tools_guardrails_subgraph
15
13
  from ..tools import create_tool_node
14
+ from .guardrails.guardrails_subgraph import (
15
+ create_agent_init_guardrails_subgraph,
16
+ create_agent_terminate_guardrails_subgraph,
17
+ create_llm_guardrails_subgraph,
18
+ create_tools_guardrails_subgraph,
19
+ )
16
20
  from .init_node import (
17
21
  create_init_node,
18
22
  )
@@ -20,7 +24,7 @@ from .llm_node import (
20
24
  create_llm_node,
21
25
  )
22
26
  from .router import (
23
- route_agent,
27
+ create_route_agent,
24
28
  )
25
29
  from .terminate_node import (
26
30
  create_terminate_node,
@@ -54,7 +58,7 @@ def create_agent(
54
58
  config: AgentGraphConfig | None = None,
55
59
  guardrails: Sequence[tuple[BaseGuardrail, GuardrailAction]] | None = None,
56
60
  ) -> StateGraph[AgentGraphState, None, InputT, OutputT]:
57
- """Build agent graph with INIT -> AGENT(subgraph) <-> TOOLS loop, terminated by control flow tools.
61
+ """Build agent graph with INIT -> AGENT (subgraph) <-> TOOLS loop, terminated by control flow tools.
58
62
 
59
63
  The AGENT node is a subgraph that runs:
60
64
  - before-agent guardrail middlewares
@@ -86,16 +90,24 @@ def create_agent(
86
90
  builder: StateGraph[AgentGraphState, None, InputT, OutputT] = StateGraph(
87
91
  InnerAgentGraphState, input_schema=input_schema, output_schema=output_schema
88
92
  )
89
- builder.add_node(AgentGraphNode.INIT, init_node)
93
+ init_with_guardrails_subgraph = create_agent_init_guardrails_subgraph(
94
+ (AgentGraphNode.GUARDED_INIT, init_node),
95
+ guardrails,
96
+ )
97
+ builder.add_node(AgentGraphNode.INIT, init_with_guardrails_subgraph)
90
98
 
91
99
  for tool_name, tool_node in tool_nodes_with_guardrails.items():
92
100
  builder.add_node(tool_name, tool_node)
93
101
 
94
- builder.add_node(AgentGraphNode.TERMINATE, terminate_node)
102
+ terminate_with_guardrails_subgraph = create_agent_terminate_guardrails_subgraph(
103
+ (AgentGraphNode.GUARDED_TERMINATE, terminate_node),
104
+ guardrails,
105
+ )
106
+ builder.add_node(AgentGraphNode.TERMINATE, terminate_with_guardrails_subgraph)
95
107
 
96
108
  builder.add_edge(START, AgentGraphNode.INIT)
97
109
 
98
- llm_node = create_llm_node(model, llm_tools)
110
+ llm_node = create_llm_node(model, llm_tools, config.thinking_messages_limit)
99
111
  llm_with_guardrails_subgraph = create_llm_guardrails_subgraph(
100
112
  (AgentGraphNode.LLM, llm_node), guardrails
101
113
  )
@@ -103,6 +115,7 @@ def create_agent(
103
115
  builder.add_edge(AgentGraphNode.INIT, AgentGraphNode.AGENT)
104
116
 
105
117
  tool_node_names = list(tool_nodes_with_guardrails.keys())
118
+ route_agent = create_route_agent(config.thinking_messages_limit)
106
119
  builder.add_conditional_edges(
107
120
  AgentGraphNode.AGENT,
108
121
  route_agent,
File: uipath_langchain/agent/react/constants.py
@@ -1,2 +1 @@
1
- # Agent routing configuration
2
- MAX_SUCCESSIVE_COMPLETIONS = 1
1
+ MAX_CONSECUTIVE_THINKING_MESSAGES = 0
File: uipath_langchain/agent/{guardrails → react/guardrails}/guardrails_subgraph.py
@@ -10,15 +10,21 @@ from uipath.platform.guardrails import (
10
10
  GuardrailScope,
11
11
  )
12
12
 
13
- from uipath_langchain.agent.guardrails.types import ExecutionStage
14
-
15
- from .actions.base_action import GuardrailAction, GuardrailActionNode
16
- from .guardrail_nodes import (
17
- create_agent_guardrail_node,
13
+ from uipath_langchain.agent.guardrails.actions.base_action import (
14
+ GuardrailAction,
15
+ GuardrailActionNode,
16
+ )
17
+ from uipath_langchain.agent.guardrails.guardrail_nodes import (
18
+ create_agent_init_guardrail_node,
19
+ create_agent_terminate_guardrail_node,
18
20
  create_llm_guardrail_node,
19
21
  create_tool_guardrail_node,
20
22
  )
21
- from .types import AgentGuardrailsGraphState
23
+ from uipath_langchain.agent.guardrails.types import ExecutionStage
24
+ from uipath_langchain.agent.react.types import (
25
+ AgentGraphState,
26
+ AgentGuardrailsGraphState,
27
+ )
22
28
 
23
29
  _VALIDATOR_ALLOWED_STAGES = {
24
30
  "prompt_injection": {ExecutionStage.PRE_EXECUTION},
@@ -232,32 +238,65 @@ def create_tools_guardrails_subgraph(
232
238
  return result
233
239
 
234
240
 
235
- def create_agent_guardrails_subgraph(
236
- agent_node: tuple[str, Any],
241
+ def create_agent_init_guardrails_subgraph(
242
+ init_node: tuple[str, Any],
237
243
  guardrails: Sequence[tuple[BaseGuardrail, GuardrailAction]] | None,
238
- execution_stage: ExecutionStage,
239
244
  ):
240
- """Create a subgraph for AGENT-scoped guardrails that applies checks at the specified stage.
241
-
242
- This is intended for wrapping nodes like INIT or TERMINATE, where guardrails should run
243
- either before (pre-execution) or after (post-execution) the node logic.
244
- """
245
+ """Create a subgraph for INIT node that applies guardrails on the state messages."""
245
246
  applicable_guardrails = [
246
247
  (guardrail, _)
247
248
  for (guardrail, _) in (guardrails or [])
248
249
  if GuardrailScope.AGENT in guardrail.selector.scopes
249
250
  ]
250
251
  if applicable_guardrails is None or len(applicable_guardrails) == 0:
251
- return agent_node[1]
252
+ return init_node[1]
252
253
 
253
254
  return _create_guardrails_subgraph(
254
- main_inner_node=agent_node,
255
+ main_inner_node=init_node,
256
+ guardrails=applicable_guardrails,
257
+ scope=GuardrailScope.AGENT,
258
+ execution_stages=[ExecutionStage.POST_EXECUTION],
259
+ node_factory=create_agent_init_guardrail_node,
260
+ )
261
+
262
+
263
+ def create_agent_terminate_guardrails_subgraph(
264
+ terminate_node: tuple[str, Any],
265
+ guardrails: Sequence[tuple[BaseGuardrail, GuardrailAction]] | None,
266
+ ):
267
+ """Create a subgraph for TERMINATE node that applies guardrails on the agent result."""
268
+ node_name, node_func = terminate_node
269
+
270
+ def terminate_wrapper(state: Any) -> dict[str, Any]:
271
+ # Call original terminate node
272
+ result = node_func(state)
273
+ # Store result in state
274
+ return {"agent_result": result, "messages": state.messages}
275
+
276
+ applicable_guardrails = [
277
+ (guardrail, _)
278
+ for (guardrail, _) in (guardrails or [])
279
+ if GuardrailScope.AGENT in guardrail.selector.scopes
280
+ ]
281
+ if applicable_guardrails is None or len(applicable_guardrails) == 0:
282
+ return terminate_node[1]
283
+
284
+ subgraph = _create_guardrails_subgraph(
285
+ main_inner_node=(node_name, terminate_wrapper),
255
286
  guardrails=applicable_guardrails,
256
287
  scope=GuardrailScope.AGENT,
257
- execution_stages=[execution_stage],
258
- node_factory=create_agent_guardrail_node,
288
+ execution_stages=[ExecutionStage.POST_EXECUTION],
289
+ node_factory=create_agent_terminate_guardrail_node,
259
290
  )
260
291
 
292
+ async def run_terminate_subgraph(
293
+ state: AgentGraphState,
294
+ ) -> dict[str, Any]:
295
+ result_state = await subgraph.ainvoke(state)
296
+ return result_state["agent_result"]
297
+
298
+ return run_terminate_subgraph
299
+
261
300
 
262
301
  def create_tool_guardrails_subgraph(
263
302
  tool_node: tuple[str, Any],
File: uipath_langchain/agent/react/llm_node.py
@@ -1,34 +1,65 @@
1
- """LLM node implementation for LangGraph."""
1
+ """LLM node for ReAct Agent graph."""
2
2
 
3
- from typing import Sequence
3
+ from typing import Literal, Sequence
4
4
 
5
5
  from langchain_core.language_models import BaseChatModel
6
6
  from langchain_core.messages import AIMessage, AnyMessage
7
7
  from langchain_core.tools import BaseTool
8
8
 
9
- from .constants import MAX_SUCCESSIVE_COMPLETIONS
9
+ from .constants import MAX_CONSECUTIVE_THINKING_MESSAGES
10
10
  from .types import AgentGraphState
11
- from .utils import count_successive_completions
11
+ from .utils import count_consecutive_thinking_messages
12
+
13
+ OPENAI_COMPATIBLE_CHAT_MODELS = (
14
+ "UiPathChatOpenAI",
15
+ "AzureChatOpenAI",
16
+ "ChatOpenAI",
17
+ "UiPathChat",
18
+ "UiPathAzureChatOpenAI",
19
+ )
20
+
21
+
22
+ def _get_required_tool_choice_by_model(
23
+ model: BaseChatModel,
24
+ ) -> Literal["required", "any"]:
25
+ """Get the appropriate tool_choice value to enforce tool usage based on model type.
26
+
27
+ "required" - OpenAI compatible required tool_choice value
28
+ "any" - Vertex and Bedrock parameter for required tool_choice value
29
+ """
30
+ model_class_name = model.__class__.__name__
31
+ if model_class_name in OPENAI_COMPATIBLE_CHAT_MODELS:
32
+ return "required"
33
+ return "any"
12
34
 
13
35
 
14
36
  def create_llm_node(
15
37
  model: BaseChatModel,
16
38
  tools: Sequence[BaseTool] | None = None,
39
+ thinking_messages_limit: int = MAX_CONSECUTIVE_THINKING_MESSAGES,
17
40
  ):
18
- """Invoke LLM with tools and dynamically control tool_choice based on successive completions.
41
+ """Create LLM node with dynamic tool_choice enforcement.
19
42
 
20
- When successive completions reach the limit, tool_choice is set to "required" to force
21
- the LLM to use a tool and prevent infinite reasoning loops.
43
+ Controls when to force tool usage based on consecutive thinking steps
44
+ to prevent infinite loops and ensure progress.
45
+
46
+ Args:
47
+ model: The chat model to use
48
+ tools: Available tools to bind
49
+ thinking_messages_limit: Max consecutive LLM responses without tool calls
50
+ before enforcing tool usage. 0 = force tools every time.
22
51
  """
23
52
  bindable_tools = list(tools) if tools else []
24
53
  base_llm = model.bind_tools(bindable_tools) if bindable_tools else model
54
+ tool_choice_required_value = _get_required_tool_choice_by_model(model)
25
55
 
26
56
  async def llm_node(state: AgentGraphState):
27
57
  messages: list[AnyMessage] = state.messages
28
58
 
29
- successive_completions = count_successive_completions(messages)
30
- if successive_completions >= MAX_SUCCESSIVE_COMPLETIONS:
31
- llm = base_llm.bind(tool_choice="required")
59
+ consecutive_thinking_messages = count_consecutive_thinking_messages(messages)
60
+
61
+ if bindable_tools and consecutive_thinking_messages >= thinking_messages_limit:
62
+ llm = base_llm.bind(tool_choice=tool_choice_required_value)
32
63
  else:
33
64
  llm = base_llm
34
65