uipath-langchain 0.1.28__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (60)
  1. uipath_langchain/_cli/_templates/langgraph.json.template +2 -4
  2. uipath_langchain/_cli/cli_new.py +1 -2
  3. uipath_langchain/_utils/_request_mixin.py +8 -0
  4. uipath_langchain/_utils/_settings.py +3 -2
  5. uipath_langchain/agent/guardrails/__init__.py +0 -16
  6. uipath_langchain/agent/guardrails/actions/__init__.py +2 -0
  7. uipath_langchain/agent/guardrails/actions/block_action.py +1 -1
  8. uipath_langchain/agent/guardrails/actions/escalate_action.py +265 -138
  9. uipath_langchain/agent/guardrails/actions/filter_action.py +290 -0
  10. uipath_langchain/agent/guardrails/actions/log_action.py +1 -1
  11. uipath_langchain/agent/guardrails/guardrail_nodes.py +193 -42
  12. uipath_langchain/agent/guardrails/guardrails_factory.py +235 -14
  13. uipath_langchain/agent/guardrails/types.py +0 -12
  14. uipath_langchain/agent/guardrails/utils.py +177 -0
  15. uipath_langchain/agent/react/agent.py +24 -9
  16. uipath_langchain/agent/react/constants.py +1 -2
  17. uipath_langchain/agent/react/file_type_handler.py +123 -0
  18. uipath_langchain/agent/{guardrails → react/guardrails}/guardrails_subgraph.py +119 -25
  19. uipath_langchain/agent/react/init_node.py +16 -1
  20. uipath_langchain/agent/react/job_attachments.py +125 -0
  21. uipath_langchain/agent/react/json_utils.py +183 -0
  22. uipath_langchain/agent/react/jsonschema_pydantic_converter.py +76 -0
  23. uipath_langchain/agent/react/llm_node.py +41 -10
  24. uipath_langchain/agent/react/llm_with_files.py +76 -0
  25. uipath_langchain/agent/react/router.py +48 -37
  26. uipath_langchain/agent/react/types.py +19 -1
  27. uipath_langchain/agent/react/utils.py +30 -4
  28. uipath_langchain/agent/tools/__init__.py +7 -1
  29. uipath_langchain/agent/tools/context_tool.py +151 -1
  30. uipath_langchain/agent/tools/escalation_tool.py +46 -15
  31. uipath_langchain/agent/tools/integration_tool.py +20 -16
  32. uipath_langchain/agent/tools/internal_tools/__init__.py +5 -0
  33. uipath_langchain/agent/tools/internal_tools/analyze_files_tool.py +113 -0
  34. uipath_langchain/agent/tools/internal_tools/internal_tool_factory.py +54 -0
  35. uipath_langchain/agent/tools/mcp_tool.py +86 -0
  36. uipath_langchain/agent/tools/process_tool.py +8 -1
  37. uipath_langchain/agent/tools/static_args.py +18 -40
  38. uipath_langchain/agent/tools/tool_factory.py +13 -5
  39. uipath_langchain/agent/tools/tool_node.py +133 -4
  40. uipath_langchain/agent/tools/utils.py +31 -0
  41. uipath_langchain/agent/wrappers/__init__.py +6 -0
  42. uipath_langchain/agent/wrappers/job_attachment_wrapper.py +62 -0
  43. uipath_langchain/agent/wrappers/static_args_wrapper.py +34 -0
  44. uipath_langchain/chat/__init__.py +4 -0
  45. uipath_langchain/chat/bedrock.py +16 -0
  46. uipath_langchain/chat/mapper.py +60 -42
  47. uipath_langchain/chat/openai.py +56 -26
  48. uipath_langchain/chat/supported_models.py +9 -0
  49. uipath_langchain/chat/vertex.py +62 -46
  50. uipath_langchain/embeddings/embeddings.py +18 -12
  51. uipath_langchain/runtime/factory.py +10 -5
  52. uipath_langchain/runtime/runtime.py +38 -35
  53. uipath_langchain/runtime/schema.py +72 -16
  54. uipath_langchain/runtime/storage.py +178 -71
  55. {uipath_langchain-0.1.28.dist-info → uipath_langchain-0.3.1.dist-info}/METADATA +7 -4
  56. uipath_langchain-0.3.1.dist-info/RECORD +90 -0
  57. uipath_langchain-0.1.28.dist-info/RECORD +0 -76
  58. {uipath_langchain-0.1.28.dist-info → uipath_langchain-0.3.1.dist-info}/WHEEL +0 -0
  59. {uipath_langchain-0.1.28.dist-info → uipath_langchain-0.3.1.dist-info}/entry_points.txt +0 -0
  60. {uipath_langchain-0.1.28.dist-info → uipath_langchain-0.3.1.dist-info}/licenses/LICENSE +0 -0
@@ -1,10 +1,11 @@
1
1
  from __future__ import annotations
2
2
 
3
+ import ast
3
4
  import json
4
5
  import re
5
- from typing import Any, Dict, Literal
6
+ from typing import Any, Dict, Literal, cast
6
7
 
7
- from langchain_core.messages import AIMessage, ToolCall, ToolMessage
8
+ from langchain_core.messages import AIMessage, AnyMessage, BaseMessage, ToolMessage
8
9
  from langgraph.types import Command, interrupt
9
10
  from uipath.platform.common import CreateEscalation
10
11
  from uipath.platform.guardrails import (
@@ -14,8 +15,9 @@ from uipath.platform.guardrails import (
14
15
  from uipath.runtime.errors import UiPathErrorCode
15
16
 
16
17
  from ...exceptions import AgentTerminationException
17
- from ..guardrail_nodes import _message_text
18
- from ..types import AgentGuardrailsGraphState, ExecutionStage
18
+ from ...react.types import AgentGuardrailsGraphState
19
+ from ..types import ExecutionStage
20
+ from ..utils import _extract_tool_args_from_message, get_message_content
19
21
  from .base_action import GuardrailAction, GuardrailActionNode
20
22
 
21
23
 
@@ -71,25 +73,59 @@ class EscalateAction(GuardrailAction):
71
73
  async def _node(
72
74
  state: AgentGuardrailsGraphState,
73
75
  ) -> Dict[str, Any] | Command[Any]:
74
- input = _extract_escalation_content(
75
- state, scope, execution_stage, guarded_component_name
76
- )
77
- escalation_field = _execution_stage_to_escalation_field(execution_stage)
76
+ # Validate message count based on execution stage
77
+ _validate_message_count(state, execution_stage)
78
78
 
79
- data = {
79
+ # Build base data dictionary with common fields
80
+ data: Dict[str, Any] = {
80
81
  "GuardrailName": guardrail.name,
81
82
  "GuardrailDescription": guardrail.description,
82
83
  "Component": scope.name.lower(),
83
84
  "ExecutionStage": _execution_stage_to_string(execution_stage),
84
85
  "GuardrailResult": state.guardrail_validation_result,
85
- escalation_field: input,
86
86
  }
87
87
 
88
+ # Add stage-specific fields
89
+ if execution_stage == ExecutionStage.PRE_EXECUTION:
90
+ # PRE_EXECUTION: Only Inputs field from last message
91
+ input_content = _extract_escalation_content(
92
+ state.messages[-1],
93
+ state,
94
+ scope,
95
+ execution_stage,
96
+ guarded_component_name,
97
+ )
98
+ data["Inputs"] = input_content
99
+ else: # POST_EXECUTION
100
+ if scope == GuardrailScope.AGENT:
101
+ input_message = state.messages[1]
102
+ else:
103
+ input_message = state.messages[-2]
104
+ input_content = _extract_escalation_content(
105
+ input_message,
106
+ state,
107
+ scope,
108
+ ExecutionStage.PRE_EXECUTION,
109
+ guarded_component_name,
110
+ )
111
+
112
+ # Extract Outputs from last message using POST_EXECUTION logic
113
+ output_content = _extract_escalation_content(
114
+ state.messages[-1],
115
+ state,
116
+ scope,
117
+ execution_stage,
118
+ guarded_component_name,
119
+ )
120
+
121
+ data["Inputs"] = input_content
122
+ data["Outputs"] = output_content
123
+
88
124
  escalation_result = interrupt(
89
125
  CreateEscalation(
90
126
  app_name=self.app_name,
91
127
  app_folder_path=self.app_folder_path,
92
- title=self.app_name,
128
+ title="Agents Guardrail Task",
93
129
  data=data,
94
130
  assignee=self.assignee,
95
131
  )
@@ -107,17 +143,50 @@ class EscalateAction(GuardrailAction):
107
143
  raise AgentTerminationException(
108
144
  code=UiPathErrorCode.EXECUTION_ERROR,
109
145
  title="Escalation rejected",
110
- detail=f"Action was rejected after reviewing the task created by guardrail [{guardrail.name}]. Please contact your administrator.",
146
+ detail=f"Please contact your administrator. Action was rejected after reviewing the task created by guardrail [{guardrail.name}], with reason: {escalation_result.data['Reason']}",
111
147
  )
112
148
 
113
149
  return node_name, _node
114
150
 
115
151
 
152
+ def _validate_message_count(
153
+ state: AgentGuardrailsGraphState,
154
+ execution_stage: ExecutionStage,
155
+ ) -> None:
156
+ """Validate that state has the required number of messages for the execution stage.
157
+
158
+ Args:
159
+ state: The current agent graph state.
160
+ execution_stage: The execution stage (PRE_EXECUTION or POST_EXECUTION).
161
+
162
+ Raises:
163
+ AgentTerminationException: If the state doesn't have enough messages.
164
+ """
165
+ required_messages = 1 if execution_stage == ExecutionStage.PRE_EXECUTION else 2
166
+ actual_messages = len(state.messages)
167
+
168
+ if actual_messages < required_messages:
169
+ stage_name = (
170
+ "PRE_EXECUTION"
171
+ if execution_stage == ExecutionStage.PRE_EXECUTION
172
+ else "POST_EXECUTION"
173
+ )
174
+ detail = f"{stage_name} requires at least {required_messages} message{'s' if required_messages > 1 else ''} in state, but found {actual_messages}."
175
+ if execution_stage == ExecutionStage.POST_EXECUTION:
176
+ detail += " Cannot extract Inputs from previous message."
177
+
178
+ raise AgentTerminationException(
179
+ code=UiPathErrorCode.EXECUTION_ERROR,
180
+ title=f"Invalid state for {stage_name}",
181
+ detail=detail,
182
+ )
183
+
184
+
116
185
  def _get_node_name(
117
186
  execution_stage: ExecutionStage, guardrail: BaseGuardrail, scope: GuardrailScope
118
187
  ) -> str:
119
- sanitized = re.sub(r"\W+", "_", guardrail.name).strip("_").lower()
120
- node_name = f"{sanitized}_hitl_{execution_stage.name.lower()}_{scope.lower()}"
188
+ raw_node_name = f"{scope.name}_{execution_stage.name}_{guardrail.name}_hitl"
189
+ node_name = re.sub(r"\W+", "_", raw_node_name.lower()).strip("_")
121
190
  return node_name
122
191
 
123
192
 
@@ -137,8 +206,8 @@ def _process_escalation_response(
137
206
  execution_stage: The execution stage (PRE_EXECUTION or POST_EXECUTION).
138
207
 
139
208
  Returns:
140
- For LLM/TOOL scope: Command to update messages with reviewed inputs/outputs, or empty dict.
141
- For AGENT scope: Empty dict (no message alteration).
209
+ Command updates for the state (e.g., updating messages / tool calls / agent_result),
210
+ or an empty dict if no update is needed.
142
211
  """
143
212
  match scope:
144
213
  case GuardrailScope.LLM:
@@ -150,8 +219,71 @@ def _process_escalation_response(
150
219
  state, escalation_result, execution_stage, guarded_node_name
151
220
  )
152
221
  case GuardrailScope.AGENT:
222
+ return _process_agent_escalation_response(
223
+ state, escalation_result, execution_stage
224
+ )
225
+
226
+
227
+ def _process_agent_escalation_response(
228
+ state: AgentGuardrailsGraphState,
229
+ escalation_result: Dict[str, Any],
230
+ execution_stage: ExecutionStage,
231
+ ) -> Dict[str, Any] | Command[Any]:
232
+ """Process escalation response for AGENT scope guardrails.
233
+
234
+ For AGENT scope:
235
+ - PRE_EXECUTION: updates the last message content using `ReviewedInputs`
236
+ - POST_EXECUTION: updates `agent_result` using `ReviewedOutputs`
237
+
238
+ Args:
239
+ state: The current agent graph state.
240
+ escalation_result: The result from the escalation interrupt containing reviewed inputs/outputs.
241
+ execution_stage: The execution stage (PRE_EXECUTION or POST_EXECUTION).
242
+
243
+ Returns:
244
+ Command to update state, or empty dict if no updates are needed.
245
+
246
+ Raises:
247
+ AgentTerminationException: If escalation response processing fails.
248
+ """
249
+ try:
250
+ reviewed_field = get_reviewed_field_name(execution_stage)
251
+ if reviewed_field not in escalation_result:
252
+ return {}
253
+
254
+ reviewed_value = escalation_result.get(reviewed_field)
255
+ if not reviewed_value:
153
256
  return {}
154
257
 
258
+ try:
259
+ parsed = json.loads(reviewed_value)
260
+ except json.JSONDecodeError:
261
+ parsed = reviewed_value
262
+
263
+ if execution_stage == ExecutionStage.PRE_EXECUTION:
264
+ msgs = state.messages.copy()
265
+ if not msgs:
266
+ return {}
267
+ msgs[-1].content = parsed
268
+ return Command(update={"messages": msgs})
269
+
270
+ # POST_EXECUTION: update agent_result
271
+ return Command(update={"agent_result": parsed})
272
+ except Exception as e:
273
+ raise AgentTerminationException(
274
+ code=UiPathErrorCode.EXECUTION_ERROR,
275
+ title="Escalation rejected",
276
+ detail=str(e),
277
+ ) from e
278
+
279
+
280
+ def get_reviewed_field_name(execution_stage):
281
+ return (
282
+ "ReviewedInputs"
283
+ if execution_stage == ExecutionStage.PRE_EXECUTION
284
+ else "ReviewedOutputs"
285
+ )
286
+
155
287
 
156
288
  def _process_llm_escalation_response(
157
289
  state: AgentGuardrailsGraphState,
@@ -174,11 +306,7 @@ def _process_llm_escalation_response(
174
306
  AgentTerminationException: If escalation response processing fails.
175
307
  """
176
308
  try:
177
- reviewed_field = (
178
- "ReviewedInputs"
179
- if execution_stage == ExecutionStage.PRE_EXECUTION
180
- else "ReviewedOutputs"
181
- )
309
+ reviewed_field = get_reviewed_field_name(execution_stage)
182
310
 
183
311
  msgs = state.messages.copy()
184
312
  if not msgs or reviewed_field not in escalation_result:
@@ -195,41 +323,56 @@ def _process_llm_escalation_response(
195
323
  if not reviewed_outputs_json:
196
324
  return {}
197
325
 
198
- content_list = json.loads(reviewed_outputs_json)
199
- if not content_list:
326
+ reviewed_tool_calls_list = json.loads(reviewed_outputs_json)
327
+ if not reviewed_tool_calls_list:
200
328
  return {}
201
329
 
330
+ # Track if tool calls were successfully processed
331
+ tool_calls_processed = False
332
+
202
333
  # For AI messages, process tool calls if present
203
334
  if isinstance(last_message, AIMessage):
204
335
  ai_message: AIMessage = last_message
205
- content_index = 0
206
336
 
207
- if ai_message.tool_calls:
337
+ if ai_message.tool_calls and isinstance(reviewed_tool_calls_list, list):
208
338
  tool_calls = list(ai_message.tool_calls)
209
- for tool_call in tool_calls:
210
- args = tool_call["args"]
339
+
340
+ # Create a name-to-args mapping from reviewed tool call data
341
+ reviewed_tool_calls_map = {}
342
+ for reviewed_data in reviewed_tool_calls_list:
211
343
  if (
212
- isinstance(args, dict)
213
- and "content" in args
214
- and args["content"] is not None
344
+ isinstance(reviewed_data, dict)
345
+ and "name" in reviewed_data
346
+ and "args" in reviewed_data
215
347
  ):
216
- if content_index < len(content_list):
217
- updated_content = json.loads(
218
- content_list[content_index]
219
- )
220
- args["content"] = updated_content
221
- tool_call["args"] = args
222
- content_index += 1
223
- ai_message.tool_calls = tool_calls
224
-
225
- if len(content_list) > content_index:
226
- ai_message.content = content_list[-1]
227
- else:
228
- # Fallback for other message types
229
- if content_list:
230
- last_message.content = content_list[-1]
231
-
232
- return Command[Any](update={"messages": msgs})
348
+ reviewed_tool_calls_map[reviewed_data["name"]] = (
349
+ reviewed_data["args"]
350
+ )
351
+
352
+ # Update tool calls with reviewed args by matching name
353
+ if reviewed_tool_calls_map:
354
+ for tool_call in tool_calls:
355
+ tool_name = (
356
+ tool_call.get("name")
357
+ if isinstance(tool_call, dict)
358
+ else getattr(tool_call, "name", None)
359
+ )
360
+ if tool_name and tool_name in reviewed_tool_calls_map:
361
+ if isinstance(tool_call, dict):
362
+ tool_call["args"] = reviewed_tool_calls_map[
363
+ tool_name
364
+ ]
365
+ else:
366
+ tool_call.args = reviewed_tool_calls_map[tool_name]
367
+
368
+ ai_message.tool_calls = tool_calls
369
+ tool_calls_processed = True
370
+
371
+ # Fallback: update message content if tool_calls weren't processed
372
+ if not tool_calls_processed:
373
+ last_message.content = reviewed_outputs_json
374
+
375
+ return Command(update={"messages": msgs})
233
376
  except Exception as e:
234
377
  raise AgentTerminationException(
235
378
  code=UiPathErrorCode.EXECUTION_ERROR,
@@ -264,11 +407,7 @@ def _process_tool_escalation_response(
264
407
  AgentTerminationException: If escalation response processing fails.
265
408
  """
266
409
  try:
267
- reviewed_field = (
268
- "ReviewedInputs"
269
- if execution_stage == ExecutionStage.PRE_EXECUTION
270
- else "ReviewedOutputs"
271
- )
410
+ reviewed_field = get_reviewed_field_name(execution_stage)
272
411
 
273
412
  msgs = state.messages.copy()
274
413
  if not msgs or reviewed_field not in escalation_result:
@@ -292,7 +431,11 @@ def _process_tool_escalation_response(
292
431
  if last_message.tool_calls:
293
432
  tool_calls = list(last_message.tool_calls)
294
433
  for tool_call in tool_calls:
295
- call_name = extract_tool_name(tool_call)
434
+ call_name = (
435
+ tool_call.get("name")
436
+ if isinstance(tool_call, dict)
437
+ else getattr(tool_call, "name", None)
438
+ )
296
439
  if call_name == tool_name:
297
440
  # Update args for the matching tool call
298
441
  if isinstance(reviewed_tool_calls_args, dict):
@@ -311,7 +454,7 @@ def _process_tool_escalation_response(
311
454
  if reviewed_outputs_json:
312
455
  last_message.content = reviewed_outputs_json
313
456
 
314
- return Command[Any](update={"messages": msgs})
457
+ return Command(update={"messages": msgs})
315
458
  except Exception as e:
316
459
  raise AgentTerminationException(
317
460
  code=UiPathErrorCode.EXECUTION_ERROR,
@@ -321,50 +464,65 @@ def _process_tool_escalation_response(
321
464
 
322
465
 
323
466
  def _extract_escalation_content(
467
+ message: BaseMessage,
324
468
  state: AgentGuardrailsGraphState,
325
469
  scope: GuardrailScope,
326
470
  execution_stage: ExecutionStage,
327
471
  guarded_node_name: str,
328
472
  ) -> str | list[str | Dict[str, Any]]:
329
- """Extract escalation content from state based on guardrail scope and execution stage.
473
+ """Extract escalation content from a message based on guardrail scope and execution stage.
330
474
 
331
475
  Args:
332
- state: The current agent graph state.
476
+ message: The message to extract content from.
333
477
  scope: The guardrail scope (LLM/AGENT/TOOL).
334
478
  execution_stage: The execution stage (PRE_EXECUTION or POST_EXECUTION).
479
+ guarded_node_name: Name of the guarded component.
335
480
 
336
481
  Returns:
337
482
  str or list[str | Dict[str, Any]]: For LLM scope, returns JSON string or list with message/tool call content.
338
483
  For AGENT scope, returns empty string. For TOOL scope, returns JSON string or list with tool-specific content.
339
-
340
- Raises:
341
- AgentTerminationException: If no messages are found in state.
342
484
  """
343
- if not state.messages:
344
- raise AgentTerminationException(
345
- code=UiPathErrorCode.EXECUTION_ERROR,
346
- title="Invalid state message",
347
- detail="No message found into agent state",
348
- )
349
-
350
485
  match scope:
351
486
  case GuardrailScope.LLM:
352
- return _extract_llm_escalation_content(state, execution_stage)
487
+ return _extract_llm_escalation_content(message, execution_stage)
353
488
  case GuardrailScope.AGENT:
354
- return _extract_agent_escalation_content(state, execution_stage)
489
+ return _extract_agent_escalation_content(message, state, execution_stage)
355
490
  case GuardrailScope.TOOL:
356
491
  return _extract_tool_escalation_content(
357
- state, execution_stage, guarded_node_name
492
+ message, execution_stage, guarded_node_name
358
493
  )
359
494
 
360
495
 
496
+ def _extract_agent_escalation_content(
497
+ message: BaseMessage,
498
+ state: AgentGuardrailsGraphState,
499
+ execution_stage: ExecutionStage,
500
+ ) -> str | list[str | Dict[str, Any]]:
501
+ """Extract escalation content for AGENT scope guardrails.
502
+
503
+ Args:
504
+ message: The message used to extract the agent input content.
505
+ state: The current agent guardrails graph state. Used to read `agent_result` for POST_EXECUTION.
506
+ execution_stage: PRE_EXECUTION or POST_EXECUTION.
507
+
508
+ Returns:
509
+ - PRE_EXECUTION: the agent input string (from message content).
510
+ - POST_EXECUTION: a JSON-serialized representation of `state.agent_result`.
511
+ """
512
+ if execution_stage == ExecutionStage.PRE_EXECUTION:
513
+ return get_message_content(cast(AnyMessage, message))
514
+
515
+ output_content = state.agent_result or ""
516
+ return json.dumps(output_content)
517
+
518
+
361
519
  def _extract_llm_escalation_content(
362
- state: AgentGuardrailsGraphState, execution_stage: ExecutionStage
520
+ message: BaseMessage, execution_stage: ExecutionStage
363
521
  ) -> str | list[str | Dict[str, Any]]:
364
522
  """Extract escalation content for LLM scope guardrails.
365
523
 
366
524
  Args:
367
- state: The current agent graph state.
525
+ message: The message to extract content from.
368
526
  execution_stage: The execution stage (PRE_EXECUTION or POST_EXECUTION).
369
527
 
370
528
  Returns:
@@ -372,61 +530,37 @@ def _extract_llm_escalation_content(
372
530
  For PostExecution, returns JSON string (array) with tool call content and message content.
373
531
  Returns empty string if no content found.
374
532
  """
375
- last_message = state.messages[-1]
376
533
  if execution_stage == ExecutionStage.PRE_EXECUTION:
377
- if isinstance(last_message, ToolMessage):
378
- return last_message.content
534
+ if isinstance(message, ToolMessage):
535
+ return message.content
379
536
 
380
- content = _message_text(last_message)
381
- return json.dumps(content) if content else ""
537
+ return get_message_content(cast(AnyMessage, message))
382
538
 
383
539
  # For AI messages, process tool calls if present
384
- if isinstance(last_message, AIMessage):
385
- ai_message: AIMessage = last_message
386
- content_list: list[str] = []
540
+ if isinstance(message, AIMessage):
541
+ ai_message: AIMessage = message
387
542
 
388
543
  if ai_message.tool_calls:
544
+ content_list: list[Dict[str, Any]] = []
389
545
  for tool_call in ai_message.tool_calls:
390
- args = tool_call["args"]
391
- if (
392
- isinstance(args, dict)
393
- and "content" in args
394
- and args["content"] is not None
395
- ):
396
- content_list.append(json.dumps(args["content"]))
397
-
398
- message_content = _message_text(last_message)
399
- if message_content:
400
- content_list.append(message_content)
401
-
402
- return json.dumps(content_list)
546
+ tool_call_data = {
547
+ "name": tool_call.get("name"),
548
+ "args": tool_call.get("args"),
549
+ }
550
+ content_list.append(tool_call_data)
551
+ return json.dumps(content_list)
403
552
 
404
553
  # Fallback for other message types
405
- return _message_text(last_message)
406
-
407
-
408
- def _extract_agent_escalation_content(
409
- state: AgentGuardrailsGraphState, execution_stage: ExecutionStage
410
- ) -> str | list[str | Dict[str, Any]]:
411
- """Extract escalation content for AGENT scope guardrails.
412
-
413
- Args:
414
- state: The current agent graph state.
415
- execution_stage: The execution stage (PRE_EXECUTION or POST_EXECUTION).
416
-
417
- Returns:
418
- str: Empty string (AGENT scope guardrails do not extract escalation content).
419
- """
420
- return ""
554
+ return get_message_content(cast(AnyMessage, message))
421
555
 
422
556
 
423
557
  def _extract_tool_escalation_content(
424
- state: AgentGuardrailsGraphState, execution_stage: ExecutionStage, tool_name: str
558
+ message: BaseMessage, execution_stage: ExecutionStage, tool_name: str
425
559
  ) -> str | list[str | Dict[str, Any]]:
426
560
  """Extract escalation content for TOOL scope guardrails.
427
561
 
428
562
  Args:
429
- state: The current agent graph state.
563
+ message: The message to extract content from.
430
564
  execution_stage: The execution stage (PRE_EXECUTION or POST_EXECUTION).
431
565
  tool_name: Optional tool name to filter tool calls. If provided, only extracts args for matching tool.
432
566
 
@@ -435,38 +569,31 @@ def _extract_tool_escalation_content(
435
569
  for the specified tool name, or empty string if not found. For PostExecution, returns string with
436
570
  tool message content, or empty string if message type doesn't match.
437
571
  """
438
- last_message = state.messages[-1]
439
572
  if execution_stage == ExecutionStage.PRE_EXECUTION:
440
- if not isinstance(last_message, AIMessage):
441
- return ""
442
- if not last_message.tool_calls:
443
- return ""
444
-
445
- # Find the tool call with matching name
446
- for tool_call in last_message.tool_calls:
447
- call_name = extract_tool_name(tool_call)
448
- if call_name == tool_name:
449
- # Extract args from the matching tool call
450
- args = (
451
- tool_call.get("args")
452
- if isinstance(tool_call, dict)
453
- else getattr(tool_call, "args", None)
454
- )
455
- if args is not None:
456
- return json.dumps(args)
573
+ args = _extract_tool_args_from_message(cast(AnyMessage, message), tool_name)
574
+ if args:
575
+ return json.dumps(args)
457
576
  return ""
458
577
  else:
459
- if not isinstance(last_message, ToolMessage):
578
+ if not isinstance(message, ToolMessage):
460
579
  return ""
461
- return last_message.content
462
-
463
-
464
- def extract_tool_name(tool_call: ToolCall) -> Any | None:
465
- return (
466
- tool_call.get("name")
467
- if isinstance(tool_call, dict)
468
- else getattr(tool_call, "name", None)
469
- )
580
+ content = message.content
581
+
582
+ # If content is already dict/list, serialize to JSON
583
+ if isinstance(content, (dict, list)):
584
+ return json.dumps(content)
585
+
586
+ # If content is a string that looks like a Python literal, convert to JSON
587
+ if isinstance(content, str):
588
+ try:
589
+ # Try to parse as Python literal and convert to JSON
590
+ parsed_content = ast.literal_eval(content)
591
+ return json.dumps(parsed_content)
592
+ except (ValueError, SyntaxError):
593
+ # If parsing fails, return as-is
594
+ pass
595
+
596
+ return content
470
597
 
471
598
 
472
599
  def _execution_stage_to_escalation_field(