uipath-langchain 0.1.28__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (60)
  1. uipath_langchain/_cli/_templates/langgraph.json.template +2 -4
  2. uipath_langchain/_cli/cli_new.py +1 -2
  3. uipath_langchain/_utils/_request_mixin.py +8 -0
  4. uipath_langchain/_utils/_settings.py +3 -2
  5. uipath_langchain/agent/guardrails/__init__.py +0 -16
  6. uipath_langchain/agent/guardrails/actions/__init__.py +2 -0
  7. uipath_langchain/agent/guardrails/actions/block_action.py +1 -1
  8. uipath_langchain/agent/guardrails/actions/escalate_action.py +265 -138
  9. uipath_langchain/agent/guardrails/actions/filter_action.py +290 -0
  10. uipath_langchain/agent/guardrails/actions/log_action.py +1 -1
  11. uipath_langchain/agent/guardrails/guardrail_nodes.py +193 -42
  12. uipath_langchain/agent/guardrails/guardrails_factory.py +235 -14
  13. uipath_langchain/agent/guardrails/types.py +0 -12
  14. uipath_langchain/agent/guardrails/utils.py +177 -0
  15. uipath_langchain/agent/react/agent.py +24 -9
  16. uipath_langchain/agent/react/constants.py +1 -2
  17. uipath_langchain/agent/react/file_type_handler.py +123 -0
  18. uipath_langchain/agent/{guardrails → react/guardrails}/guardrails_subgraph.py +119 -25
  19. uipath_langchain/agent/react/init_node.py +16 -1
  20. uipath_langchain/agent/react/job_attachments.py +125 -0
  21. uipath_langchain/agent/react/json_utils.py +183 -0
  22. uipath_langchain/agent/react/jsonschema_pydantic_converter.py +76 -0
  23. uipath_langchain/agent/react/llm_node.py +41 -10
  24. uipath_langchain/agent/react/llm_with_files.py +76 -0
  25. uipath_langchain/agent/react/router.py +48 -37
  26. uipath_langchain/agent/react/types.py +19 -1
  27. uipath_langchain/agent/react/utils.py +30 -4
  28. uipath_langchain/agent/tools/__init__.py +7 -1
  29. uipath_langchain/agent/tools/context_tool.py +151 -1
  30. uipath_langchain/agent/tools/escalation_tool.py +46 -15
  31. uipath_langchain/agent/tools/integration_tool.py +20 -16
  32. uipath_langchain/agent/tools/internal_tools/__init__.py +5 -0
  33. uipath_langchain/agent/tools/internal_tools/analyze_files_tool.py +113 -0
  34. uipath_langchain/agent/tools/internal_tools/internal_tool_factory.py +54 -0
  35. uipath_langchain/agent/tools/mcp_tool.py +86 -0
  36. uipath_langchain/agent/tools/process_tool.py +8 -1
  37. uipath_langchain/agent/tools/static_args.py +18 -40
  38. uipath_langchain/agent/tools/tool_factory.py +13 -5
  39. uipath_langchain/agent/tools/tool_node.py +133 -4
  40. uipath_langchain/agent/tools/utils.py +31 -0
  41. uipath_langchain/agent/wrappers/__init__.py +6 -0
  42. uipath_langchain/agent/wrappers/job_attachment_wrapper.py +62 -0
  43. uipath_langchain/agent/wrappers/static_args_wrapper.py +34 -0
  44. uipath_langchain/chat/__init__.py +4 -0
  45. uipath_langchain/chat/bedrock.py +16 -0
  46. uipath_langchain/chat/mapper.py +60 -42
  47. uipath_langchain/chat/openai.py +56 -26
  48. uipath_langchain/chat/supported_models.py +9 -0
  49. uipath_langchain/chat/vertex.py +62 -46
  50. uipath_langchain/embeddings/embeddings.py +18 -12
  51. uipath_langchain/runtime/factory.py +10 -5
  52. uipath_langchain/runtime/runtime.py +38 -35
  53. uipath_langchain/runtime/schema.py +72 -16
  54. uipath_langchain/runtime/storage.py +178 -71
  55. {uipath_langchain-0.1.28.dist-info → uipath_langchain-0.3.1.dist-info}/METADATA +7 -4
  56. uipath_langchain-0.3.1.dist-info/RECORD +90 -0
  57. uipath_langchain-0.1.28.dist-info/RECORD +0 -76
  58. {uipath_langchain-0.1.28.dist-info → uipath_langchain-0.3.1.dist-info}/WHEEL +0 -0
  59. {uipath_langchain-0.1.28.dist-info → uipath_langchain-0.3.1.dist-info}/entry_points.txt +0 -0
  60. {uipath_langchain-0.1.28.dist-info → uipath_langchain-0.3.1.dist-info}/licenses/LICENSE +0 -0
@@ -1,34 +1,65 @@
1
- """LLM node implementation for LangGraph."""
1
+ """LLM node for ReAct Agent graph."""
2
2
 
3
- from typing import Sequence
3
+ from typing import Literal, Sequence
4
4
 
5
5
  from langchain_core.language_models import BaseChatModel
6
6
  from langchain_core.messages import AIMessage, AnyMessage
7
7
  from langchain_core.tools import BaseTool
8
8
 
9
- from .constants import MAX_SUCCESSIVE_COMPLETIONS
9
+ from .constants import MAX_CONSECUTIVE_THINKING_MESSAGES
10
10
  from .types import AgentGraphState
11
- from .utils import count_successive_completions
11
+ from .utils import count_consecutive_thinking_messages
12
+
13
+ OPENAI_COMPATIBLE_CHAT_MODELS = (
14
+ "UiPathChatOpenAI",
15
+ "AzureChatOpenAI",
16
+ "ChatOpenAI",
17
+ "UiPathChat",
18
+ "UiPathAzureChatOpenAI",
19
+ )
20
+
21
+
22
+ def _get_required_tool_choice_by_model(
23
+ model: BaseChatModel,
24
+ ) -> Literal["required", "any"]:
25
+ """Get the appropriate tool_choice value to enforce tool usage based on model type.
26
+
27
+ "required" - OpenAI compatible required tool_choice value
28
+ "any" - Vertex and Bedrock parameter for required tool_choice value
29
+ """
30
+ model_class_name = model.__class__.__name__
31
+ if model_class_name in OPENAI_COMPATIBLE_CHAT_MODELS:
32
+ return "required"
33
+ return "any"
12
34
 
13
35
 
14
36
  def create_llm_node(
15
37
  model: BaseChatModel,
16
38
  tools: Sequence[BaseTool] | None = None,
39
+ thinking_messages_limit: int = MAX_CONSECUTIVE_THINKING_MESSAGES,
17
40
  ):
18
- """Invoke LLM with tools and dynamically control tool_choice based on successive completions.
41
+ """Create LLM node with dynamic tool_choice enforcement.
19
42
 
20
- When successive completions reach the limit, tool_choice is set to "required" to force
21
- the LLM to use a tool and prevent infinite reasoning loops.
43
+ Controls when to force tool usage based on consecutive thinking steps
44
+ to prevent infinite loops and ensure progress.
45
+
46
+ Args:
47
+ model: The chat model to use
48
+ tools: Available tools to bind
49
+ thinking_messages_limit: Max consecutive LLM responses without tool calls
50
+ before enforcing tool usage. 0 = force tools every time.
22
51
  """
23
52
  bindable_tools = list(tools) if tools else []
24
53
  base_llm = model.bind_tools(bindable_tools) if bindable_tools else model
54
+ tool_choice_required_value = _get_required_tool_choice_by_model(model)
25
55
 
26
56
  async def llm_node(state: AgentGraphState):
27
57
  messages: list[AnyMessage] = state.messages
28
58
 
29
- successive_completions = count_successive_completions(messages)
30
- if successive_completions >= MAX_SUCCESSIVE_COMPLETIONS:
31
- llm = base_llm.bind(tool_choice="required")
59
+ consecutive_thinking_messages = count_consecutive_thinking_messages(messages)
60
+
61
+ if bindable_tools and consecutive_thinking_messages >= thinking_messages_limit:
62
+ llm = base_llm.bind(tool_choice=tool_choice_required_value)
32
63
  else:
33
64
  llm = base_llm
34
65
 
@@ -0,0 +1,76 @@
1
+ """LLM invocation with file attachments support."""
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Any
5
+
6
+ from langchain_core.language_models import BaseChatModel
7
+ from langchain_core.messages import AIMessage, AnyMessage, HumanMessage
8
+
9
+ from .file_type_handler import build_message_content_part_from_data
10
+
11
+
12
+ @dataclass
13
+ class FileInfo:
14
+ """File information for LLM file attachments."""
15
+
16
+ url: str
17
+ name: str
18
+ mime_type: str
19
+
20
+
21
+ def _get_model_name(model: BaseChatModel) -> str:
22
+ """Extract model name from a BaseChatModel instance."""
23
+ for attr in ["model_name", "_model_name", "model", "model_id"]:
24
+ value = getattr(model, attr, None)
25
+ if value and isinstance(value, str):
26
+ return value
27
+ raise ValueError(f"Model name not found in model {model}")
28
+
29
+
30
+ async def create_part_for_file(
31
+ file_info: FileInfo,
32
+ model: BaseChatModel,
33
+ ) -> dict[str, Any]:
34
+ """Create a provider-specific message content part for a file attachment.
35
+
36
+ Downloads the file from file_info.url and formats it for the model's provider.
37
+ """
38
+ model_name = _get_model_name(model)
39
+ return await build_message_content_part_from_data(
40
+ url=file_info.url,
41
+ filename=file_info.name,
42
+ mime_type=file_info.mime_type,
43
+ model=model_name,
44
+ )
45
+
46
+
47
+ async def llm_call_with_files(
48
+ messages: list[AnyMessage],
49
+ files: list[FileInfo],
50
+ model: BaseChatModel,
51
+ ) -> AIMessage:
52
+ """Invoke an LLM with file attachments.
53
+
54
+ Downloads files, creates provider-specific content parts, and appends them
55
+ as a HumanMessage. If no files are provided, equivalent to model.ainvoke().
56
+ """
57
+ if not files:
58
+ response = await model.ainvoke(messages)
59
+ if not isinstance(response, AIMessage):
60
+ raise TypeError(
61
+ f"LLM returned {type(response).__name__} instead of AIMessage"
62
+ )
63
+ return response
64
+
65
+ content_parts: list[str | dict[Any, Any]] = []
66
+ for file_info in files:
67
+ content_part = await create_part_for_file(file_info, model)
68
+ content_parts.append(content_part)
69
+
70
+ file_message = HumanMessage(content=content_parts)
71
+ all_messages = list(messages) + [file_message]
72
+
73
+ response = await model.ainvoke(all_messages)
74
+ if not isinstance(response, AIMessage):
75
+ raise TypeError(f"LLM returned {type(response).__name__} instead of AIMessage")
76
+ return response
@@ -6,9 +6,8 @@ from langchain_core.messages import AIMessage, AnyMessage, ToolCall
6
6
  from uipath.agent.react import END_EXECUTION_TOOL, RAISE_ERROR_TOOL
7
7
 
8
8
  from ..exceptions import AgentNodeRoutingException
9
- from .constants import MAX_SUCCESSIVE_COMPLETIONS
10
9
  from .types import AgentGraphNode, AgentGraphState
11
- from .utils import count_successive_completions
10
+ from .utils import count_consecutive_thinking_messages
12
11
 
13
12
  FLOW_CONTROL_TOOLS = [END_EXECUTION_TOOL.name, RAISE_ERROR_TOOL.name]
14
13
 
@@ -48,50 +47,62 @@ def __validate_last_message_is_AI(messages: list[AnyMessage]) -> AIMessage:
48
47
  return last_message
49
48
 
50
49
 
51
- def route_agent(
52
- state: AgentGraphState,
53
- ) -> list[str] | Literal[AgentGraphNode.AGENT, AgentGraphNode.TERMINATE]:
54
- """Route after agent: handles all routing logic including control flow detection.
50
+ def create_route_agent(thinking_messages_limit: int = 0):
51
+ """Create a routing function configured with thinking_messages_limit.
55
52
 
56
- Routing logic:
57
- 1. If multiple tool calls exist, filter out control flow tools (EndExecution, RaiseError)
58
- 2. If control flow tool(s) remain, route to TERMINATE
59
- 3. If regular tool calls remain, route to specific tool nodes (return list of tool names)
60
- 4. If no tool calls, handle successive completions
53
+ Args:
54
+ thinking_messages_limit: Max consecutive thinking messages before error
61
55
 
62
56
  Returns:
63
- - list[str]: Tool node names for parallel execution
64
- - AgentGraphNode.AGENT: For successive completions
65
- - AgentGraphNode.TERMINATE: For control flow termination
66
-
67
- Raises:
68
- AgentNodeRoutingException: When encountering unexpected state (empty messages, non-AIMessage, or excessive completions)
57
+ Routing function for LangGraph conditional edges
69
58
  """
70
- messages = state.messages
71
- last_message = __validate_last_message_is_AI(messages)
72
59
 
73
- tool_calls = list(last_message.tool_calls) if last_message.tool_calls else []
74
- tool_calls = __filter_control_flow_tool_calls(tool_calls)
60
+ def route_agent(
61
+ state: AgentGraphState,
62
+ ) -> list[str] | Literal[AgentGraphNode.AGENT, AgentGraphNode.TERMINATE]:
63
+ """Route after agent: handles all routing logic including control flow detection.
64
+
65
+ Routing logic:
66
+ 1. If multiple tool calls exist, filter out control flow tools (EndExecution, RaiseError)
67
+ 2. If control flow tool(s) remain, route to TERMINATE
68
+ 3. If regular tool calls remain, route to specific tool nodes (return list of tool names)
69
+ 4. If no tool calls, handle consecutive completions
70
+
71
+ Returns:
72
+ - list[str]: Tool node names for parallel execution
73
+ - AgentGraphNode.AGENT: For consecutive completions
74
+ - AgentGraphNode.TERMINATE: For control flow termination
75
+
76
+ Raises:
77
+ AgentNodeRoutingException: When encountering unexpected state (empty messages, non-AIMessage, or excessive completions)
78
+ """
79
+ messages = state.messages
80
+ last_message = __validate_last_message_is_AI(messages)
75
81
 
76
- if tool_calls and __has_control_flow_tool(tool_calls):
77
- return AgentGraphNode.TERMINATE
82
+ tool_calls = list(last_message.tool_calls) if last_message.tool_calls else []
83
+ tool_calls = __filter_control_flow_tool_calls(tool_calls)
78
84
 
79
- if tool_calls:
80
- return [tc["name"] for tc in tool_calls]
85
+ if tool_calls and __has_control_flow_tool(tool_calls):
86
+ return AgentGraphNode.TERMINATE
81
87
 
82
- successive_completions = count_successive_completions(messages)
88
+ if tool_calls:
89
+ return [tc["name"] for tc in tool_calls]
90
+
91
+ consecutive_thinking_messages = count_consecutive_thinking_messages(messages)
92
+
93
+ if consecutive_thinking_messages > thinking_messages_limit:
94
+ raise AgentNodeRoutingException(
95
+ f"Agent exceeded consecutive completions limit without producing tool calls "
96
+ f"(completions: {consecutive_thinking_messages}, max: {thinking_messages_limit}). "
97
+ f"This should not happen as tool_choice='required' is enforced at the limit."
98
+ )
99
+
100
+ if last_message.content:
101
+ return AgentGraphNode.AGENT
83
102
 
84
- if successive_completions > MAX_SUCCESSIVE_COMPLETIONS:
85
103
  raise AgentNodeRoutingException(
86
- f"Agent exceeded successive completions limit without producing tool calls "
87
- f"(completions: {successive_completions}, max: {MAX_SUCCESSIVE_COMPLETIONS}). "
88
- f"This should not happen as tool_choice='required' is enforced at the limit."
104
+ f"Agent produced empty response without tool calls "
105
+ f"(completions: {consecutive_thinking_messages}, has_content: False)"
89
106
  )
90
107
 
91
- if last_message.content:
92
- return AgentGraphNode.AGENT
93
-
94
- raise AgentNodeRoutingException(
95
- f"Agent produced empty response without tool calls "
96
- f"(completions: {successive_completions}, has_content: False)"
97
- )
108
+ return route_agent
@@ -1,9 +1,12 @@
1
1
  from enum import StrEnum
2
- from typing import Annotated
2
+ from typing import Annotated, Any, Optional
3
3
 
4
4
  from langchain_core.messages import AnyMessage
5
5
  from langgraph.graph.message import add_messages
6
6
  from pydantic import BaseModel, Field
7
+ from uipath.platform.attachments import Attachment
8
+
9
+ from uipath_langchain.agent.react.utils import add_job_attachments
7
10
 
8
11
 
9
12
  class AgentTerminationSource(StrEnum):
@@ -22,18 +25,33 @@ class AgentGraphState(BaseModel):
22
25
  """Agent Graph state for standard loop execution."""
23
26
 
24
27
  messages: Annotated[list[AnyMessage], add_messages] = []
28
+ job_attachments: Annotated[dict[str, Attachment], add_job_attachments] = {}
25
29
  termination: AgentTermination | None = None
26
30
 
27
31
 
32
+ class AgentGuardrailsGraphState(AgentGraphState):
33
+ """Agent Guardrails Graph state for guardrail subgraph."""
34
+
35
+ guardrail_validation_result: Optional[str] = None
36
+ agent_result: Optional[dict[str, Any]] = None
37
+
38
+
28
39
  class AgentGraphNode(StrEnum):
29
40
  INIT = "init"
41
+ GUARDED_INIT = "guarded-init"
30
42
  AGENT = "agent"
31
43
  LLM = "llm"
32
44
  TOOLS = "tools"
33
45
  TERMINATE = "terminate"
46
+ GUARDED_TERMINATE = "guarded-terminate"
34
47
 
35
48
 
36
49
  class AgentGraphConfig(BaseModel):
37
50
  recursion_limit: int = Field(
38
51
  default=50, ge=1, description="Maximum recursion limit for the agent graph"
39
52
  )
53
+ thinking_messages_limit: int = Field(
54
+ default=0,
55
+ ge=0,
56
+ description="Max consecutive thinking messages before enforcing tool usage. 0 = force tools every time.",
57
+ )
@@ -5,7 +5,9 @@ from typing import Any, Sequence
5
5
  from langchain_core.messages import AIMessage, BaseMessage
6
6
  from pydantic import BaseModel
7
7
  from uipath.agent.react import END_EXECUTION_TOOL
8
- from uipath.utils.dynamic_schema import jsonschema_to_pydantic
8
+ from uipath.platform.attachments import Attachment
9
+
10
+ from uipath_langchain.agent.react.jsonschema_pydantic_converter import create_model
9
11
 
10
12
 
11
13
  def resolve_input_model(
@@ -13,7 +15,7 @@ def resolve_input_model(
13
15
  ) -> type[BaseModel]:
14
16
  """Resolve the input model from the input schema."""
15
17
  if input_schema:
16
- return jsonschema_to_pydantic(input_schema)
18
+ return create_model(input_schema)
17
19
 
18
20
  return BaseModel
19
21
 
@@ -23,12 +25,12 @@ def resolve_output_model(
23
25
  ) -> type[BaseModel]:
24
26
  """Fallback to default end_execution tool schema when no agent output schema is provided."""
25
27
  if output_schema:
26
- return jsonschema_to_pydantic(output_schema)
28
+ return create_model(output_schema)
27
29
 
28
30
  return END_EXECUTION_TOOL.args_schema
29
31
 
30
32
 
31
- def count_successive_completions(messages: Sequence[BaseMessage]) -> int:
33
+ def count_consecutive_thinking_messages(messages: Sequence[BaseMessage]) -> int:
32
34
  """Count consecutive AIMessages without tool calls at end of message history."""
33
35
  if not messages:
34
36
  return 0
@@ -47,3 +49,27 @@ def count_successive_completions(messages: Sequence[BaseMessage]) -> int:
47
49
  count += 1
48
50
 
49
51
  return count
52
+
53
+
54
+ def add_job_attachments(
55
+ left: dict[str, Attachment], right: dict[str, Attachment]
56
+ ) -> dict[str, Attachment]:
57
+ """Merge attachment dictionaries, with right values taking precedence.
58
+
59
+ This reducer function merges two dictionaries of attachments by UUID string.
60
+ If the same UUID exists in both dictionaries, the value from 'right' takes precedence.
61
+
62
+ Args:
63
+ left: Existing dictionary of attachments keyed by UUID string
64
+ right: New dictionary of attachments to merge
65
+
66
+ Returns:
67
+ Merged dictionary with right values overriding left values for duplicate keys
68
+ """
69
+ if not right:
70
+ return left
71
+
72
+ if not left:
73
+ return right
74
+
75
+ return {**left, **right}
@@ -1,12 +1,14 @@
1
1
  """Tool creation and management for LowCode agents."""
2
2
 
3
3
  from .context_tool import create_context_tool
4
+ from .escalation_tool import create_escalation_tool
4
5
  from .integration_tool import create_integration_tool
6
+ from .mcp_tool import create_mcp_tools
5
7
  from .process_tool import create_process_tool
6
8
  from .tool_factory import (
7
9
  create_tools_from_resources,
8
10
  )
9
- from .tool_node import create_tool_node
11
+ from .tool_node import ToolWrapperMixin, UiPathToolNode, create_tool_node
10
12
 
11
13
  __all__ = [
12
14
  "create_tools_from_resources",
@@ -14,4 +16,8 @@ __all__ = [
14
16
  "create_context_tool",
15
17
  "create_process_tool",
16
18
  "create_integration_tool",
19
+ "create_escalation_tool",
20
+ "create_mcp_tools",
21
+ "UiPathToolNode",
22
+ "ToolWrapperMixin",
17
23
  ]
@@ -1,12 +1,24 @@
1
1
  """Context tool creation for semantic index retrieval."""
2
2
 
3
+ import uuid
3
4
  from typing import Any
4
5
 
5
6
  from langchain_core.documents import Document
6
7
  from langchain_core.tools import StructuredTool
8
+ from langgraph.types import interrupt
7
9
  from pydantic import BaseModel, Field
8
- from uipath.agent.models.agent import AgentContextResourceConfig
10
+ from uipath.agent.models.agent import (
11
+ AgentContextResourceConfig,
12
+ AgentContextRetrievalMode,
13
+ )
9
14
  from uipath.eval.mocks import mockable
15
+ from uipath.platform.common import CreateBatchTransform, CreateDeepRag
16
+ from uipath.platform.context_grounding import (
17
+ BatchTransformOutputColumn,
18
+ BatchTransformResponse,
19
+ CitationMode,
20
+ DeepRagResponse,
21
+ )
10
22
 
11
23
  from uipath_langchain.retrievers import ContextGroundingRetriever
12
24
 
@@ -16,6 +28,18 @@ from .utils import sanitize_tool_name
16
28
 
17
29
  def create_context_tool(resource: AgentContextResourceConfig) -> StructuredTool:
18
30
  tool_name = sanitize_tool_name(resource.name)
31
+ retrieval_mode = resource.settings.retrieval_mode.lower()
32
+ if retrieval_mode == AgentContextRetrievalMode.DEEP_RAG.value.lower():
33
+ return handle_deep_rag(tool_name, resource)
34
+ elif retrieval_mode == AgentContextRetrievalMode.BATCH_TRANSFORM.value.lower():
35
+ return handle_batch_transform(tool_name, resource)
36
+ else:
37
+ return handle_semantic_search(tool_name, resource)
38
+
39
+
40
+ def handle_semantic_search(
41
+ tool_name: str, resource: AgentContextResourceConfig
42
+ ) -> StructuredTool:
19
43
  retriever = ContextGroundingRetriever(
20
44
  index_name=resource.index_name,
21
45
  folder_path=resource.folder_path,
@@ -40,6 +64,7 @@ def create_context_tool(resource: AgentContextResourceConfig) -> StructuredTool:
40
64
  description=resource.description,
41
65
  input_schema=input_model.model_json_schema(),
42
66
  output_schema=output_model.model_json_schema(),
67
+ example_calls=[], # Examples cannot be provided for context.
43
68
  )
44
69
  async def context_tool_fn(query: str) -> dict[str, Any]:
45
70
  return {"documents": await retriever.ainvoke(query)}
@@ -51,3 +76,128 @@ def create_context_tool(resource: AgentContextResourceConfig) -> StructuredTool:
51
76
  coroutine=context_tool_fn,
52
77
  output_type=output_model,
53
78
  )
79
+
80
+
81
+ def handle_deep_rag(
82
+ tool_name: str, resource: AgentContextResourceConfig
83
+ ) -> StructuredTool:
84
+ ensure_valid_fields(resource)
85
+ # needed for type checking
86
+ assert resource.settings.query is not None
87
+ assert resource.settings.query.value is not None
88
+
89
+ index_name = resource.index_name
90
+ prompt = resource.settings.query.value
91
+ if not resource.settings.citation_mode:
92
+ raise ValueError("Citation mode is required for Deep RAG")
93
+ citation_mode = CitationMode(resource.settings.citation_mode.value)
94
+
95
+ input_model = None
96
+ output_model = DeepRagResponse
97
+
98
+ @mockable(
99
+ name=resource.name,
100
+ description=resource.description,
101
+ input_schema=input_model,
102
+ output_schema=output_model.model_json_schema(),
103
+ example_calls=[], # Examples cannot be provided for context.
104
+ )
105
+ async def context_tool_fn() -> dict[str, Any]:
106
+ # TODO: add glob pattern support
107
+ return interrupt(
108
+ CreateDeepRag(
109
+ name=f"task-{uuid.uuid4()}",
110
+ index_name=index_name,
111
+ prompt=prompt,
112
+ citation_mode=citation_mode,
113
+ )
114
+ )
115
+
116
+ return StructuredToolWithOutputType(
117
+ name=tool_name,
118
+ description=resource.description,
119
+ args_schema=input_model,
120
+ coroutine=context_tool_fn,
121
+ output_type=output_model,
122
+ )
123
+
124
+
125
+ def handle_batch_transform(
126
+ tool_name: str, resource: AgentContextResourceConfig
127
+ ) -> StructuredTool:
128
+ ensure_valid_fields(resource)
129
+
130
+ # needed for type checking
131
+ assert resource.settings.query is not None
132
+ assert resource.settings.query.value is not None
133
+
134
+ index_name = resource.index_name
135
+ prompt = resource.settings.query.value
136
+
137
+ index_folder_path = resource.folder_path
138
+ if not resource.settings.web_search_grounding:
139
+ raise ValueError("Web search grounding field is required for Batch Transform")
140
+ enable_web_search_grounding = (
141
+ resource.settings.web_search_grounding.value.lower() == "enabled"
142
+ )
143
+
144
+ batch_transform_output_columns: list[BatchTransformOutputColumn] = []
145
+ if (output_columns := resource.settings.output_columns) is None or not len(
146
+ output_columns
147
+ ):
148
+ raise ValueError(
149
+ "Batch transform requires at least one output column to be specified in settings.output_columns"
150
+ )
151
+
152
+ for column in output_columns:
153
+ batch_transform_output_columns.append(
154
+ BatchTransformOutputColumn(
155
+ name=column.name,
156
+ description=column.description,
157
+ )
158
+ )
159
+
160
+ class BatchTransformSchemaModel(BaseModel):
161
+ destination_path: str = Field(
162
+ ...,
163
+ description="The relative file path destination for the modified csv file",
164
+ )
165
+
166
+ input_model = BatchTransformSchemaModel
167
+ output_model = BatchTransformResponse
168
+
169
+ @mockable(
170
+ name=resource.name,
171
+ description=resource.description,
172
+ input_schema=input_model.model_json_schema(),
173
+ output_schema=output_model.model_json_schema(),
174
+ example_calls=[], # Examples cannot be provided for context.
175
+ )
176
+ async def context_tool_fn(destination_path: str) -> dict[str, Any]:
177
+ # TODO: storage_bucket_folder_path_prefix support
178
+ return interrupt(
179
+ CreateBatchTransform(
180
+ name=f"task-{uuid.uuid4()}",
181
+ index_name=index_name,
182
+ prompt=prompt,
183
+ destination_path=destination_path,
184
+ index_folder_path=index_folder_path,
185
+ enable_web_search_grounding=enable_web_search_grounding,
186
+ output_columns=batch_transform_output_columns,
187
+ )
188
+ )
189
+
190
+ return StructuredToolWithOutputType(
191
+ name=tool_name,
192
+ description=resource.description,
193
+ args_schema=input_model,
194
+ coroutine=context_tool_fn,
195
+ output_type=output_model,
196
+ )
197
+
198
+
199
+ def ensure_valid_fields(resource_config: AgentContextResourceConfig):
200
+ if not resource_config.settings.query:
201
+ raise ValueError("Query object is required")
202
+ if not resource_config.settings.query.value:
203
+ raise ValueError("Query prompt is required")