aiqtoolkit 1.2.0a20250706__py3-none-any.whl → 1.2.0a20250730__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of aiqtoolkit might be problematic; see the package's registry listing for more details.
- aiq/agent/base.py +171 -8
- aiq/agent/dual_node.py +1 -1
- aiq/agent/react_agent/agent.py +113 -113
- aiq/agent/react_agent/register.py +31 -14
- aiq/agent/rewoo_agent/agent.py +36 -35
- aiq/agent/rewoo_agent/register.py +2 -2
- aiq/agent/tool_calling_agent/agent.py +3 -7
- aiq/authentication/__init__.py +14 -0
- aiq/authentication/api_key/__init__.py +14 -0
- aiq/authentication/api_key/api_key_auth_provider.py +92 -0
- aiq/authentication/api_key/api_key_auth_provider_config.py +124 -0
- aiq/authentication/api_key/register.py +26 -0
- aiq/authentication/exceptions/__init__.py +14 -0
- aiq/authentication/exceptions/api_key_exceptions.py +38 -0
- aiq/authentication/exceptions/auth_code_grant_exceptions.py +86 -0
- aiq/authentication/exceptions/call_back_exceptions.py +38 -0
- aiq/authentication/exceptions/request_exceptions.py +54 -0
- aiq/authentication/http_basic_auth/__init__.py +0 -0
- aiq/authentication/http_basic_auth/http_basic_auth_provider.py +81 -0
- aiq/authentication/http_basic_auth/register.py +30 -0
- aiq/authentication/interfaces.py +93 -0
- aiq/authentication/oauth2/__init__.py +14 -0
- aiq/authentication/oauth2/oauth2_auth_code_flow_provider.py +107 -0
- aiq/authentication/oauth2/oauth2_auth_code_flow_provider_config.py +39 -0
- aiq/authentication/oauth2/register.py +25 -0
- aiq/authentication/register.py +21 -0
- aiq/builder/builder.py +64 -2
- aiq/builder/component_utils.py +16 -3
- aiq/builder/context.py +26 -0
- aiq/builder/eval_builder.py +43 -2
- aiq/builder/function.py +32 -4
- aiq/builder/function_base.py +1 -1
- aiq/builder/intermediate_step_manager.py +6 -8
- aiq/builder/user_interaction_manager.py +3 -0
- aiq/builder/workflow.py +23 -18
- aiq/builder/workflow_builder.py +420 -73
- aiq/cli/commands/info/list_mcp.py +103 -16
- aiq/cli/commands/sizing/__init__.py +14 -0
- aiq/cli/commands/sizing/calc.py +294 -0
- aiq/cli/commands/sizing/sizing.py +27 -0
- aiq/cli/commands/start.py +1 -0
- aiq/cli/entrypoint.py +2 -0
- aiq/cli/register_workflow.py +80 -0
- aiq/cli/type_registry.py +151 -30
- aiq/data_models/api_server.py +117 -11
- aiq/data_models/authentication.py +231 -0
- aiq/data_models/common.py +35 -7
- aiq/data_models/component.py +17 -9
- aiq/data_models/component_ref.py +33 -0
- aiq/data_models/config.py +60 -3
- aiq/data_models/embedder.py +1 -0
- aiq/data_models/function_dependencies.py +8 -0
- aiq/data_models/interactive.py +10 -1
- aiq/data_models/intermediate_step.py +15 -5
- aiq/data_models/its_strategy.py +30 -0
- aiq/data_models/llm.py +1 -0
- aiq/data_models/memory.py +1 -0
- aiq/data_models/object_store.py +44 -0
- aiq/data_models/retry_mixin.py +35 -0
- aiq/data_models/span.py +187 -0
- aiq/data_models/telemetry_exporter.py +2 -2
- aiq/embedder/nim_embedder.py +2 -1
- aiq/embedder/openai_embedder.py +2 -1
- aiq/eval/config.py +19 -1
- aiq/eval/dataset_handler/dataset_handler.py +75 -1
- aiq/eval/evaluate.py +53 -10
- aiq/eval/rag_evaluator/evaluate.py +23 -12
- aiq/eval/remote_workflow.py +7 -2
- aiq/eval/runners/__init__.py +14 -0
- aiq/eval/runners/config.py +39 -0
- aiq/eval/runners/multi_eval_runner.py +54 -0
- aiq/eval/usage_stats.py +6 -0
- aiq/eval/utils/weave_eval.py +5 -1
- aiq/experimental/__init__.py +0 -0
- aiq/experimental/decorators/__init__.py +0 -0
- aiq/experimental/decorators/experimental_warning_decorator.py +130 -0
- aiq/experimental/inference_time_scaling/__init__.py +0 -0
- aiq/experimental/inference_time_scaling/editing/__init__.py +0 -0
- aiq/experimental/inference_time_scaling/editing/iterative_plan_refinement_editor.py +147 -0
- aiq/experimental/inference_time_scaling/editing/llm_as_a_judge_editor.py +204 -0
- aiq/experimental/inference_time_scaling/editing/motivation_aware_summarization.py +107 -0
- aiq/experimental/inference_time_scaling/functions/__init__.py +0 -0
- aiq/experimental/inference_time_scaling/functions/execute_score_select_function.py +105 -0
- aiq/experimental/inference_time_scaling/functions/its_tool_orchestration_function.py +205 -0
- aiq/experimental/inference_time_scaling/functions/its_tool_wrapper_function.py +146 -0
- aiq/experimental/inference_time_scaling/functions/plan_select_execute_function.py +224 -0
- aiq/experimental/inference_time_scaling/models/__init__.py +0 -0
- aiq/experimental/inference_time_scaling/models/editor_config.py +132 -0
- aiq/experimental/inference_time_scaling/models/its_item.py +48 -0
- aiq/experimental/inference_time_scaling/models/scoring_config.py +112 -0
- aiq/experimental/inference_time_scaling/models/search_config.py +120 -0
- aiq/experimental/inference_time_scaling/models/selection_config.py +154 -0
- aiq/experimental/inference_time_scaling/models/stage_enums.py +43 -0
- aiq/experimental/inference_time_scaling/models/strategy_base.py +66 -0
- aiq/experimental/inference_time_scaling/models/tool_use_config.py +41 -0
- aiq/experimental/inference_time_scaling/register.py +36 -0
- aiq/experimental/inference_time_scaling/scoring/__init__.py +0 -0
- aiq/experimental/inference_time_scaling/scoring/llm_based_agent_scorer.py +168 -0
- aiq/experimental/inference_time_scaling/scoring/llm_based_plan_scorer.py +168 -0
- aiq/experimental/inference_time_scaling/scoring/motivation_aware_scorer.py +111 -0
- aiq/experimental/inference_time_scaling/search/__init__.py +0 -0
- aiq/experimental/inference_time_scaling/search/multi_llm_planner.py +128 -0
- aiq/experimental/inference_time_scaling/search/multi_query_retrieval_search.py +122 -0
- aiq/experimental/inference_time_scaling/search/single_shot_multi_plan_planner.py +128 -0
- aiq/experimental/inference_time_scaling/selection/__init__.py +0 -0
- aiq/experimental/inference_time_scaling/selection/best_of_n_selector.py +63 -0
- aiq/experimental/inference_time_scaling/selection/llm_based_agent_output_selector.py +131 -0
- aiq/experimental/inference_time_scaling/selection/llm_based_output_merging_selector.py +159 -0
- aiq/experimental/inference_time_scaling/selection/llm_based_plan_selector.py +128 -0
- aiq/experimental/inference_time_scaling/selection/threshold_selector.py +58 -0
- aiq/front_ends/console/authentication_flow_handler.py +233 -0
- aiq/front_ends/console/console_front_end_plugin.py +11 -2
- aiq/front_ends/fastapi/auth_flow_handlers/__init__.py +0 -0
- aiq/front_ends/fastapi/auth_flow_handlers/http_flow_handler.py +27 -0
- aiq/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py +107 -0
- aiq/front_ends/fastapi/fastapi_front_end_config.py +20 -0
- aiq/front_ends/fastapi/fastapi_front_end_controller.py +68 -0
- aiq/front_ends/fastapi/fastapi_front_end_plugin.py +14 -1
- aiq/front_ends/fastapi/fastapi_front_end_plugin_worker.py +353 -31
- aiq/front_ends/fastapi/html_snippets/__init__.py +14 -0
- aiq/front_ends/fastapi/html_snippets/auth_code_grant_success.py +35 -0
- aiq/front_ends/fastapi/main.py +2 -0
- aiq/front_ends/fastapi/message_handler.py +102 -84
- aiq/front_ends/fastapi/step_adaptor.py +2 -1
- aiq/llm/aws_bedrock_llm.py +2 -1
- aiq/llm/nim_llm.py +2 -1
- aiq/llm/openai_llm.py +2 -1
- aiq/object_store/__init__.py +20 -0
- aiq/object_store/in_memory_object_store.py +74 -0
- aiq/object_store/interfaces.py +84 -0
- aiq/object_store/models.py +36 -0
- aiq/object_store/register.py +20 -0
- aiq/observability/__init__.py +14 -0
- aiq/observability/exporter/__init__.py +14 -0
- aiq/observability/exporter/base_exporter.py +449 -0
- aiq/observability/exporter/exporter.py +78 -0
- aiq/observability/exporter/file_exporter.py +33 -0
- aiq/observability/exporter/processing_exporter.py +269 -0
- aiq/observability/exporter/raw_exporter.py +52 -0
- aiq/observability/exporter/span_exporter.py +264 -0
- aiq/observability/exporter_manager.py +335 -0
- aiq/observability/mixin/__init__.py +14 -0
- aiq/observability/mixin/batch_config_mixin.py +26 -0
- aiq/observability/mixin/collector_config_mixin.py +23 -0
- aiq/observability/mixin/file_mixin.py +288 -0
- aiq/observability/mixin/file_mode.py +23 -0
- aiq/observability/mixin/resource_conflict_mixin.py +134 -0
- aiq/observability/mixin/serialize_mixin.py +61 -0
- aiq/observability/mixin/type_introspection_mixin.py +183 -0
- aiq/observability/processor/__init__.py +14 -0
- aiq/observability/processor/batching_processor.py +316 -0
- aiq/observability/processor/intermediate_step_serializer.py +28 -0
- aiq/observability/processor/processor.py +68 -0
- aiq/observability/register.py +32 -116
- aiq/observability/utils/__init__.py +14 -0
- aiq/observability/utils/dict_utils.py +236 -0
- aiq/observability/utils/time_utils.py +31 -0
- aiq/profiler/calc/__init__.py +14 -0
- aiq/profiler/calc/calc_runner.py +623 -0
- aiq/profiler/calc/calculations.py +288 -0
- aiq/profiler/calc/data_models.py +176 -0
- aiq/profiler/calc/plot.py +345 -0
- aiq/profiler/data_models.py +2 -0
- aiq/profiler/profile_runner.py +16 -13
- aiq/runtime/loader.py +8 -2
- aiq/runtime/runner.py +23 -9
- aiq/runtime/session.py +16 -5
- aiq/tool/chat_completion.py +74 -0
- aiq/tool/code_execution/README.md +152 -0
- aiq/tool/code_execution/code_sandbox.py +151 -72
- aiq/tool/code_execution/local_sandbox/.gitignore +1 -0
- aiq/tool/code_execution/local_sandbox/local_sandbox_server.py +139 -24
- aiq/tool/code_execution/local_sandbox/sandbox.requirements.txt +3 -1
- aiq/tool/code_execution/local_sandbox/start_local_sandbox.sh +27 -2
- aiq/tool/code_execution/register.py +7 -3
- aiq/tool/code_execution/test_code_execution_sandbox.py +414 -0
- aiq/tool/mcp/exceptions.py +142 -0
- aiq/tool/mcp/mcp_client.py +17 -3
- aiq/tool/mcp/mcp_tool.py +1 -1
- aiq/tool/register.py +1 -0
- aiq/tool/server_tools.py +2 -2
- aiq/utils/exception_handlers/automatic_retries.py +289 -0
- aiq/utils/exception_handlers/mcp.py +211 -0
- aiq/utils/io/model_processing.py +28 -0
- aiq/utils/log_utils.py +37 -0
- aiq/utils/string_utils.py +38 -0
- aiq/utils/type_converter.py +18 -2
- aiq/utils/type_utils.py +87 -0
- {aiqtoolkit-1.2.0a20250706.dist-info → aiqtoolkit-1.2.0a20250730.dist-info}/METADATA +37 -9
- {aiqtoolkit-1.2.0a20250706.dist-info → aiqtoolkit-1.2.0a20250730.dist-info}/RECORD +195 -80
- {aiqtoolkit-1.2.0a20250706.dist-info → aiqtoolkit-1.2.0a20250730.dist-info}/entry_points.txt +3 -0
- aiq/front_ends/fastapi/websocket.py +0 -153
- aiq/observability/async_otel_listener.py +0 -470
- {aiqtoolkit-1.2.0a20250706.dist-info → aiqtoolkit-1.2.0a20250730.dist-info}/WHEEL +0 -0
- {aiqtoolkit-1.2.0a20250706.dist-info → aiqtoolkit-1.2.0a20250730.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
- {aiqtoolkit-1.2.0a20250706.dist-info → aiqtoolkit-1.2.0a20250730.dist-info}/licenses/LICENSE.md +0 -0
- {aiqtoolkit-1.2.0a20250706.dist-info → aiqtoolkit-1.2.0a20250730.dist-info}/top_level.txt +0 -0
aiq/agent/base.py
CHANGED
|
@@ -13,25 +13,32 @@
|
|
|
13
13
|
# See the License for the specific language governing permissions and
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
|
+
import asyncio
|
|
17
|
+
import json
|
|
16
18
|
import logging
|
|
17
19
|
from abc import ABC
|
|
18
20
|
from abc import abstractmethod
|
|
19
21
|
from enum import Enum
|
|
22
|
+
from typing import Any
|
|
20
23
|
|
|
21
24
|
from colorama import Fore
|
|
22
25
|
from langchain_core.callbacks import AsyncCallbackHandler
|
|
23
26
|
from langchain_core.language_models import BaseChatModel
|
|
27
|
+
from langchain_core.messages import AIMessage
|
|
28
|
+
from langchain_core.messages import BaseMessage
|
|
29
|
+
from langchain_core.messages import ToolMessage
|
|
30
|
+
from langchain_core.runnables import RunnableConfig
|
|
24
31
|
from langchain_core.tools import BaseTool
|
|
25
32
|
from langgraph.graph.graph import CompiledGraph
|
|
26
33
|
|
|
27
|
-
|
|
34
|
+
logger = logging.getLogger(__name__)
|
|
28
35
|
|
|
29
36
|
TOOL_NOT_FOUND_ERROR_MESSAGE = "There is no tool named {tool_name}. Tool must be one of {tools}."
|
|
30
37
|
INPUT_SCHEMA_MESSAGE = ". Arguments must be provided as a valid JSON object following this format: {schema}"
|
|
31
|
-
NO_INPUT_ERROR_MESSAGE = "No human input
|
|
38
|
+
NO_INPUT_ERROR_MESSAGE = "No human input received to the agent, Please ask a valid question."
|
|
32
39
|
|
|
33
40
|
AGENT_LOG_PREFIX = "[AGENT]"
|
|
34
|
-
|
|
41
|
+
AGENT_CALL_LOG_MESSAGE = f"\n{'-' * 30}\n" + \
|
|
35
42
|
AGENT_LOG_PREFIX + "\n" + \
|
|
36
43
|
Fore.YELLOW + \
|
|
37
44
|
"Agent input: %s\n" + \
|
|
@@ -40,7 +47,7 @@ AGENT_RESPONSE_LOG_MESSAGE = f"\n{'-' * 30}\n" + \
|
|
|
40
47
|
Fore.RESET + \
|
|
41
48
|
f"\n{'-' * 30}"
|
|
42
49
|
|
|
43
|
-
|
|
50
|
+
TOOL_CALL_LOG_MESSAGE = f"\n{'-' * 30}\n" + \
|
|
44
51
|
AGENT_LOG_PREFIX + "\n" + \
|
|
45
52
|
Fore.WHITE + \
|
|
46
53
|
"Calling tools: %s\n" + \
|
|
@@ -62,15 +69,171 @@ class BaseAgent(ABC):
|
|
|
62
69
|
def __init__(self,
|
|
63
70
|
llm: BaseChatModel,
|
|
64
71
|
tools: list[BaseTool],
|
|
65
|
-
callbacks: list[AsyncCallbackHandler] = None,
|
|
66
|
-
detailed_logs: bool = False):
|
|
67
|
-
|
|
72
|
+
callbacks: list[AsyncCallbackHandler] | None = None,
|
|
73
|
+
detailed_logs: bool = False) -> None:
|
|
74
|
+
logger.debug("Initializing Agent Graph")
|
|
68
75
|
self.llm = llm
|
|
69
76
|
self.tools = tools
|
|
70
77
|
self.callbacks = callbacks or []
|
|
71
78
|
self.detailed_logs = detailed_logs
|
|
72
79
|
self.graph = None
|
|
73
80
|
|
|
81
|
+
async def _stream_llm(self,
|
|
82
|
+
runnable: Any,
|
|
83
|
+
inputs: dict[str, Any],
|
|
84
|
+
config: RunnableConfig | None = None) -> AIMessage:
|
|
85
|
+
"""
|
|
86
|
+
Stream from LLM runnable. Retry logic is handled automatically by the underlying LLM client.
|
|
87
|
+
|
|
88
|
+
Parameters
|
|
89
|
+
----------
|
|
90
|
+
runnable : Any
|
|
91
|
+
The LLM runnable (prompt | llm or similar)
|
|
92
|
+
inputs : Dict[str, Any]
|
|
93
|
+
The inputs to pass to the runnable
|
|
94
|
+
config : RunnableConfig | None
|
|
95
|
+
The config to pass to the runnable (should include callbacks)
|
|
96
|
+
|
|
97
|
+
Returns
|
|
98
|
+
-------
|
|
99
|
+
AIMessage
|
|
100
|
+
The LLM response
|
|
101
|
+
"""
|
|
102
|
+
output_message = ""
|
|
103
|
+
async for event in runnable.astream(inputs, config=config):
|
|
104
|
+
output_message += event.content
|
|
105
|
+
|
|
106
|
+
return AIMessage(content=output_message)
|
|
107
|
+
|
|
108
|
+
async def _call_llm(self, messages: list[BaseMessage]) -> AIMessage:
|
|
109
|
+
"""
|
|
110
|
+
Call the LLM directly. Retry logic is handled automatically by the underlying LLM client.
|
|
111
|
+
|
|
112
|
+
Parameters
|
|
113
|
+
----------
|
|
114
|
+
messages : list[BaseMessage]
|
|
115
|
+
The messages to send to the LLM
|
|
116
|
+
|
|
117
|
+
Returns
|
|
118
|
+
-------
|
|
119
|
+
AIMessage
|
|
120
|
+
The LLM response
|
|
121
|
+
"""
|
|
122
|
+
response = await self.llm.ainvoke(messages)
|
|
123
|
+
return AIMessage(content=str(response.content))
|
|
124
|
+
|
|
125
|
+
async def _call_tool(self,
|
|
126
|
+
tool: BaseTool,
|
|
127
|
+
tool_input: dict[str, Any] | str,
|
|
128
|
+
config: RunnableConfig | None = None,
|
|
129
|
+
max_retries: int = 3) -> ToolMessage:
|
|
130
|
+
"""
|
|
131
|
+
Call a tool with retry logic and error handling.
|
|
132
|
+
|
|
133
|
+
Parameters
|
|
134
|
+
----------
|
|
135
|
+
tool : BaseTool
|
|
136
|
+
The tool to call
|
|
137
|
+
tool_input : Union[Dict[str, Any], str]
|
|
138
|
+
The input to pass to the tool
|
|
139
|
+
config : RunnableConfig | None
|
|
140
|
+
The config to pass to the tool
|
|
141
|
+
max_retries : int
|
|
142
|
+
Maximum number of retry attempts (default: 3)
|
|
143
|
+
|
|
144
|
+
Returns
|
|
145
|
+
-------
|
|
146
|
+
ToolMessage
|
|
147
|
+
The tool response
|
|
148
|
+
"""
|
|
149
|
+
last_exception = None
|
|
150
|
+
|
|
151
|
+
for attempt in range(1, max_retries + 1):
|
|
152
|
+
try:
|
|
153
|
+
response = await tool.ainvoke(tool_input, config=config)
|
|
154
|
+
|
|
155
|
+
# Handle empty responses
|
|
156
|
+
if response is None or (isinstance(response, str) and response == ""):
|
|
157
|
+
return ToolMessage(name=tool.name,
|
|
158
|
+
tool_call_id=tool.name,
|
|
159
|
+
content=f"The tool {tool.name} provided an empty response.")
|
|
160
|
+
|
|
161
|
+
return ToolMessage(name=tool.name, tool_call_id=tool.name, content=response)
|
|
162
|
+
|
|
163
|
+
except Exception as e:
|
|
164
|
+
last_exception = e
|
|
165
|
+
|
|
166
|
+
# If this was the last attempt, don't sleep
|
|
167
|
+
if attempt == max_retries:
|
|
168
|
+
break
|
|
169
|
+
|
|
170
|
+
logger.warning("%s Tool call attempt %d/%d failed for tool %s: %s",
|
|
171
|
+
AGENT_LOG_PREFIX,
|
|
172
|
+
attempt,
|
|
173
|
+
max_retries,
|
|
174
|
+
tool.name,
|
|
175
|
+
str(e))
|
|
176
|
+
|
|
177
|
+
# Exponential backoff: 2^attempt seconds
|
|
178
|
+
sleep_time = 2**attempt
|
|
179
|
+
logger.debug("%s Retrying tool call for %s in %d seconds...", AGENT_LOG_PREFIX, tool.name, sleep_time)
|
|
180
|
+
await asyncio.sleep(sleep_time)
|
|
181
|
+
|
|
182
|
+
# All retries exhausted, return error message
|
|
183
|
+
error_content = "Tool call failed after all retry attempts. Last error: %s" % str(last_exception)
|
|
184
|
+
logger.error("%s %s", AGENT_LOG_PREFIX, error_content)
|
|
185
|
+
return ToolMessage(name=tool.name, tool_call_id=tool.name, content=error_content, status="error")
|
|
186
|
+
|
|
187
|
+
def _log_tool_response(self, tool_name: str, tool_input: Any, tool_response: str, max_chars: int = 1000) -> None:
|
|
188
|
+
"""
|
|
189
|
+
Log tool response with consistent formatting and length limits.
|
|
190
|
+
|
|
191
|
+
Parameters
|
|
192
|
+
----------
|
|
193
|
+
tool_name : str
|
|
194
|
+
The name of the tool that was called
|
|
195
|
+
tool_input : Any
|
|
196
|
+
The input that was passed to the tool
|
|
197
|
+
tool_response : str
|
|
198
|
+
The response from the tool
|
|
199
|
+
max_chars : int
|
|
200
|
+
Maximum number of characters to log (default: 1000)
|
|
201
|
+
"""
|
|
202
|
+
if self.detailed_logs:
|
|
203
|
+
# Truncate tool response if too long
|
|
204
|
+
display_response = tool_response[:max_chars] + "...(rest of response truncated)" if len(
|
|
205
|
+
tool_response) > max_chars else tool_response
|
|
206
|
+
|
|
207
|
+
# Format the tool input for display
|
|
208
|
+
tool_input_str = str(tool_input)
|
|
209
|
+
|
|
210
|
+
tool_response_log_message = TOOL_CALL_LOG_MESSAGE % (tool_name, tool_input_str, display_response)
|
|
211
|
+
logger.info(tool_response_log_message)
|
|
212
|
+
|
|
213
|
+
def _parse_json(self, json_string: str) -> dict[str, Any]:
|
|
214
|
+
"""
|
|
215
|
+
Safely parse JSON with graceful error handling.
|
|
216
|
+
If JSON parsing fails, returns an empty dict or error info.
|
|
217
|
+
|
|
218
|
+
Parameters
|
|
219
|
+
----------
|
|
220
|
+
json_string : str
|
|
221
|
+
The JSON string to parse
|
|
222
|
+
|
|
223
|
+
Returns
|
|
224
|
+
-------
|
|
225
|
+
Dict[str, Any]
|
|
226
|
+
The parsed JSON or error information
|
|
227
|
+
"""
|
|
228
|
+
try:
|
|
229
|
+
return json.loads(json_string)
|
|
230
|
+
except json.JSONDecodeError as e:
|
|
231
|
+
logger.warning("%s JSON parsing failed, returning the original string: %s", AGENT_LOG_PREFIX, str(e))
|
|
232
|
+
return {"error": f"JSON parsing failed: {str(e)}", "original_string": json_string}
|
|
233
|
+
except Exception as e:
|
|
234
|
+
logger.warning("%s Unexpected error during JSON parsing: %s", AGENT_LOG_PREFIX, str(e))
|
|
235
|
+
return {"error": f"Unexpected parsing error: {str(e)}", "original_string": json_string}
|
|
236
|
+
|
|
74
237
|
@abstractmethod
|
|
75
|
-
async def _build_graph(self, state_schema) -> CompiledGraph:
|
|
238
|
+
async def _build_graph(self, state_schema: type) -> CompiledGraph:
|
|
76
239
|
pass
|
aiq/agent/dual_node.py
CHANGED
|
@@ -34,7 +34,7 @@ class DualNodeAgent(BaseAgent):
|
|
|
34
34
|
def __init__(self,
|
|
35
35
|
llm: BaseChatModel,
|
|
36
36
|
tools: list[BaseTool],
|
|
37
|
-
callbacks: list[AsyncCallbackHandler] = None,
|
|
37
|
+
callbacks: list[AsyncCallbackHandler] | None = None,
|
|
38
38
|
detailed_logs: bool = False):
|
|
39
39
|
super().__init__(llm=llm, tools=tools, callbacks=callbacks, detailed_logs=detailed_logs)
|
|
40
40
|
|
aiq/agent/react_agent/agent.py
CHANGED
|
@@ -33,12 +33,11 @@ from langchain_core.tools import BaseTool
|
|
|
33
33
|
from pydantic import BaseModel
|
|
34
34
|
from pydantic import Field
|
|
35
35
|
|
|
36
|
+
from aiq.agent.base import AGENT_CALL_LOG_MESSAGE
|
|
36
37
|
from aiq.agent.base import AGENT_LOG_PREFIX
|
|
37
|
-
from aiq.agent.base import AGENT_RESPONSE_LOG_MESSAGE
|
|
38
38
|
from aiq.agent.base import INPUT_SCHEMA_MESSAGE
|
|
39
39
|
from aiq.agent.base import NO_INPUT_ERROR_MESSAGE
|
|
40
40
|
from aiq.agent.base import TOOL_NOT_FOUND_ERROR_MESSAGE
|
|
41
|
-
from aiq.agent.base import TOOL_RESPONSE_LOG_MESSAGE
|
|
42
41
|
from aiq.agent.base import AgentDecision
|
|
43
42
|
from aiq.agent.dual_node import DualNodeAgent
|
|
44
43
|
from aiq.agent.react_agent.output_parser import ReActOutputParser
|
|
@@ -67,13 +66,17 @@ class ReActAgentGraph(DualNodeAgent):
|
|
|
67
66
|
prompt: ChatPromptTemplate,
|
|
68
67
|
tools: list[BaseTool],
|
|
69
68
|
use_tool_schema: bool = True,
|
|
70
|
-
callbacks: list[AsyncCallbackHandler] = None,
|
|
69
|
+
callbacks: list[AsyncCallbackHandler] | None = None,
|
|
71
70
|
detailed_logs: bool = False,
|
|
72
|
-
|
|
73
|
-
|
|
71
|
+
retry_agent_response_parsing_errors: bool = True,
|
|
72
|
+
parse_agent_response_max_retries: int = 1,
|
|
73
|
+
tool_call_max_retries: int = 1,
|
|
74
|
+
pass_tool_call_errors_to_agent: bool = True):
|
|
74
75
|
super().__init__(llm=llm, tools=tools, callbacks=callbacks, detailed_logs=detailed_logs)
|
|
75
|
-
self.
|
|
76
|
-
|
|
76
|
+
self.parse_agent_response_max_retries = (parse_agent_response_max_retries
|
|
77
|
+
if retry_agent_response_parsing_errors else 1)
|
|
78
|
+
self.tool_call_max_retries = tool_call_max_retries
|
|
79
|
+
self.pass_tool_call_errors_to_agent = pass_tool_call_errors_to_agent
|
|
77
80
|
logger.debug(
|
|
78
81
|
"%s Filling the prompt variables 'tools' and 'tool_names', using the tools provided in the config.",
|
|
79
82
|
AGENT_LOG_PREFIX)
|
|
@@ -91,12 +94,12 @@ class ReActAgentGraph(DualNodeAgent):
|
|
|
91
94
|
f"{INPUT_SCHEMA_MESSAGE.format(schema=tools[-1].input_schema.model_fields)}")
|
|
92
95
|
prompt = prompt.partial(tools=tool_names_and_descriptions, tool_names=tool_names)
|
|
93
96
|
# construct the ReAct Agent
|
|
94
|
-
|
|
95
|
-
self.agent = prompt |
|
|
97
|
+
bound_llm = llm.bind(stop=["Observation:"]) # type: ignore
|
|
98
|
+
self.agent = prompt | bound_llm
|
|
96
99
|
self.tools_dict = {tool.name: tool for tool in tools}
|
|
97
100
|
logger.debug("%s Initialized ReAct Agent Graph", AGENT_LOG_PREFIX)
|
|
98
101
|
|
|
99
|
-
def _get_tool(self, tool_name):
|
|
102
|
+
def _get_tool(self, tool_name: str):
|
|
100
103
|
try:
|
|
101
104
|
return self.tools_dict.get(tool_name)
|
|
102
105
|
except Exception as ex:
|
|
@@ -113,26 +116,30 @@ class ReActAgentGraph(DualNodeAgent):
|
|
|
113
116
|
# keeping a working state allows us to resolve parsing errors without polluting the agent scratchpad
|
|
114
117
|
# the agent "forgets" about the parsing error after solving it - prevents hallucinations in next cycles
|
|
115
118
|
working_state = []
|
|
116
|
-
|
|
119
|
+
# Starting from attempt 1 instead of 0 for logging
|
|
120
|
+
for attempt in range(1, self.parse_agent_response_max_retries + 1):
|
|
117
121
|
# the first time we are invoking the ReAct Agent, it won't have any intermediate steps / agent thoughts
|
|
118
122
|
if len(state.agent_scratchpad) == 0 and len(working_state) == 0:
|
|
119
123
|
# the user input comes from the "messages" state channel
|
|
120
124
|
if len(state.messages) == 0:
|
|
121
125
|
raise RuntimeError('No input received in state: "messages"')
|
|
122
126
|
# to check is any human input passed or not, if no input passed Agent will return the state
|
|
123
|
-
|
|
127
|
+
content = str(state.messages[0].content)
|
|
128
|
+
if content.strip() == "":
|
|
124
129
|
logger.error("%s No human input passed to the agent.", AGENT_LOG_PREFIX)
|
|
125
130
|
state.messages += [AIMessage(content=NO_INPUT_ERROR_MESSAGE)]
|
|
126
131
|
return state
|
|
127
|
-
question =
|
|
132
|
+
question = content
|
|
128
133
|
logger.debug("%s Querying agent, attempt: %s", AGENT_LOG_PREFIX, attempt)
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
+
|
|
135
|
+
output_message = await self._stream_llm(
|
|
136
|
+
self.agent,
|
|
137
|
+
{"question": question},
|
|
138
|
+
RunnableConfig(callbacks=self.callbacks) # type: ignore
|
|
139
|
+
)
|
|
140
|
+
|
|
134
141
|
if self.detailed_logs:
|
|
135
|
-
logger.info(
|
|
142
|
+
logger.info(AGENT_CALL_LOG_MESSAGE, question, output_message.content)
|
|
136
143
|
else:
|
|
137
144
|
# ReAct Agents require agentic cycles
|
|
138
145
|
# in an agentic cycle, preserve the agent's thoughts from the previous cycles,
|
|
@@ -141,20 +148,20 @@ class ReActAgentGraph(DualNodeAgent):
|
|
|
141
148
|
for index, intermediate_step in enumerate(state.agent_scratchpad):
|
|
142
149
|
agent_thoughts = AIMessage(content=intermediate_step.log)
|
|
143
150
|
agent_scratchpad.append(agent_thoughts)
|
|
144
|
-
|
|
151
|
+
tool_response_content = str(state.tool_responses[index].content)
|
|
152
|
+
tool_response = HumanMessage(content=tool_response_content)
|
|
145
153
|
agent_scratchpad.append(tool_response)
|
|
146
154
|
agent_scratchpad += working_state
|
|
147
|
-
question = state.messages[0].content
|
|
155
|
+
question = str(state.messages[0].content)
|
|
148
156
|
logger.debug("%s Querying agent, attempt: %s", AGENT_LOG_PREFIX, attempt)
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
157
|
+
|
|
158
|
+
output_message = await self._stream_llm(self.agent, {
|
|
159
|
+
"question": question, "agent_scratchpad": agent_scratchpad
|
|
152
160
|
},
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
output_message = AIMessage(content=output_message)
|
|
161
|
+
RunnableConfig(callbacks=self.callbacks))
|
|
162
|
+
|
|
156
163
|
if self.detailed_logs:
|
|
157
|
-
logger.info(
|
|
164
|
+
logger.info(AGENT_CALL_LOG_MESSAGE, question, output_message.content)
|
|
158
165
|
logger.debug("%s The agent's scratchpad (with tool result) was:\n%s",
|
|
159
166
|
AGENT_LOG_PREFIX,
|
|
160
167
|
agent_scratchpad)
|
|
@@ -162,11 +169,7 @@ class ReActAgentGraph(DualNodeAgent):
|
|
|
162
169
|
# check if the agent has the final answer yet
|
|
163
170
|
logger.debug("%s Successfully obtained agent response. Parsing agent's response", AGENT_LOG_PREFIX)
|
|
164
171
|
agent_output = await ReActOutputParser().aparse(output_message.content)
|
|
165
|
-
logger.debug("%s Successfully parsed agent
|
|
166
|
-
if attempt > 1:
|
|
167
|
-
logger.debug("%s Successfully parsed agent response after %s attempts",
|
|
168
|
-
AGENT_LOG_PREFIX,
|
|
169
|
-
attempt)
|
|
172
|
+
logger.debug("%s Successfully parsed agent response after %s attempts", AGENT_LOG_PREFIX, attempt)
|
|
170
173
|
if isinstance(agent_output, AgentFinish):
|
|
171
174
|
final_answer = agent_output.return_values.get('output', output_message.content)
|
|
172
175
|
logger.debug("%s The agent has finished, and has the final answer", AGENT_LOG_PREFIX)
|
|
@@ -178,31 +181,32 @@ class ReActAgentGraph(DualNodeAgent):
|
|
|
178
181
|
agent_output.log = output_message.content
|
|
179
182
|
logger.debug("%s The agent wants to call a tool: %s", AGENT_LOG_PREFIX, agent_output.tool)
|
|
180
183
|
state.agent_scratchpad += [agent_output]
|
|
184
|
+
|
|
181
185
|
return state
|
|
182
186
|
except ReActOutputParserException as ex:
|
|
183
187
|
# the agent output did not meet the expected ReAct output format. This can happen for a few reasons:
|
|
184
188
|
# the agent mentioned a tool, but already has the final answer, this can happen with Llama models
|
|
185
189
|
# - the ReAct Agent already has the answer, and is reflecting on how it obtained the answer
|
|
186
190
|
# the agent might have also missed Action or Action Input in its output
|
|
187
|
-
logger.
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
if attempt == self.
|
|
192
|
-
logger.
|
|
191
|
+
logger.debug("%s Error parsing agent output\nObservation:%s\nAgent Output:\n%s",
|
|
192
|
+
AGENT_LOG_PREFIX,
|
|
193
|
+
ex.observation,
|
|
194
|
+
output_message.content)
|
|
195
|
+
if attempt == self.parse_agent_response_max_retries:
|
|
196
|
+
logger.warning(
|
|
193
197
|
"%s Failed to parse agent output after %d attempts, consider enabling or "
|
|
194
|
-
"increasing
|
|
198
|
+
"increasing parse_agent_response_max_retries",
|
|
195
199
|
AGENT_LOG_PREFIX,
|
|
196
|
-
attempt
|
|
197
|
-
exc_info=True)
|
|
200
|
+
attempt)
|
|
198
201
|
# the final answer goes in the "messages" state channel
|
|
199
|
-
|
|
202
|
+
combined_content = str(ex.observation) + '\n' + str(output_message.content)
|
|
203
|
+
output_message.content = combined_content
|
|
200
204
|
state.messages += [output_message]
|
|
201
205
|
return state
|
|
202
206
|
# retry parsing errors, if configured
|
|
203
207
|
logger.info("%s Retrying ReAct Agent, including output parsing Observation", AGENT_LOG_PREFIX)
|
|
204
208
|
working_state.append(output_message)
|
|
205
|
-
working_state.append(HumanMessage(content=ex.observation))
|
|
209
|
+
working_state.append(HumanMessage(content=str(ex.observation)))
|
|
206
210
|
except Exception as ex:
|
|
207
211
|
logger.exception("%s Failed to call agent_node: %s", AGENT_LOG_PREFIX, ex, exc_info=True)
|
|
208
212
|
raise ex
|
|
@@ -212,7 +216,8 @@ class ReActAgentGraph(DualNodeAgent):
|
|
|
212
216
|
logger.debug("%s Starting the ReAct Conditional Edge", AGENT_LOG_PREFIX)
|
|
213
217
|
if len(state.messages) > 1:
|
|
214
218
|
# the ReAct Agent has finished executing, the last agent output was AgentFinish
|
|
215
|
-
|
|
219
|
+
last_message_content = str(state.messages[-1].content)
|
|
220
|
+
logger.debug("%s Final answer:\n%s", AGENT_LOG_PREFIX, last_message_content)
|
|
216
221
|
return AgentDecision.END
|
|
217
222
|
# else the agent wants to call a tool
|
|
218
223
|
agent_output = state.agent_scratchpad[-1]
|
|
@@ -227,76 +232,71 @@ class ReActAgentGraph(DualNodeAgent):
|
|
|
227
232
|
return AgentDecision.END
|
|
228
233
|
|
|
229
234
|
async def tool_node(self, state: ReActGraphState):
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
state.tool_responses += [tool_response]
|
|
250
|
-
return state
|
|
251
|
-
|
|
252
|
-
logger.debug("%s Calling tool %s with input: %s",
|
|
253
|
-
AGENT_LOG_PREFIX,
|
|
254
|
-
requested_tool.name,
|
|
255
|
-
agent_thoughts.tool_input)
|
|
256
|
-
|
|
257
|
-
# Run the tool. Try to use structured input, if possible.
|
|
258
|
-
try:
|
|
259
|
-
tool_input_str = agent_thoughts.tool_input.strip().replace("'", '"')
|
|
260
|
-
tool_input_dict = json.loads(tool_input_str) if tool_input_str != 'None' else tool_input_str
|
|
261
|
-
logger.debug("%s Successfully parsed structured tool input from Action Input", AGENT_LOG_PREFIX)
|
|
262
|
-
tool_response = await requested_tool.ainvoke(tool_input_dict,
|
|
263
|
-
config=RunnableConfig(callbacks=self.callbacks))
|
|
264
|
-
if self.detailed_logs:
|
|
265
|
-
# The tool response can be very large, so we log only the first 1000 characters
|
|
266
|
-
tool_response_str = str(tool_response)
|
|
267
|
-
tool_response_str = tool_response_str[:1000] + "..." if len(
|
|
268
|
-
tool_response_str) > 1000 else tool_response_str
|
|
269
|
-
tool_response_log_message = TOOL_RESPONSE_LOG_MESSAGE % (
|
|
270
|
-
requested_tool.name, tool_input_str, tool_response_str)
|
|
271
|
-
logger.info(tool_response_log_message)
|
|
272
|
-
except JSONDecodeError as ex:
|
|
273
|
-
logger.warning(
|
|
274
|
-
"%s Unable to parse structured tool input from Action Input. Using Action Input as is."
|
|
275
|
-
"\nParsing error: %s",
|
|
276
|
-
AGENT_LOG_PREFIX,
|
|
277
|
-
ex,
|
|
278
|
-
exc_info=True)
|
|
279
|
-
tool_input_str = agent_thoughts.tool_input
|
|
280
|
-
tool_response = await requested_tool.ainvoke(tool_input_str,
|
|
281
|
-
config=RunnableConfig(callbacks=self.callbacks))
|
|
282
|
-
|
|
283
|
-
# some tools, such as Wikipedia, will return an empty response when no search results are found
|
|
284
|
-
if tool_response is None or tool_response == "":
|
|
285
|
-
tool_response = "The tool provided an empty response.\n"
|
|
286
|
-
# put the tool response in the graph state
|
|
287
|
-
tool_response = ToolMessage(name=agent_thoughts.tool,
|
|
288
|
-
tool_call_id=agent_thoughts.tool,
|
|
289
|
-
content=tool_response)
|
|
290
|
-
logger.debug("%s Called tool %s with input: %s\nThe tool returned: %s",
|
|
291
|
-
AGENT_LOG_PREFIX,
|
|
292
|
-
requested_tool.name,
|
|
293
|
-
agent_thoughts.tool_input,
|
|
294
|
-
tool_response.content)
|
|
235
|
+
|
|
236
|
+
logger.debug("%s Starting the Tool Call Node", AGENT_LOG_PREFIX)
|
|
237
|
+
if len(state.agent_scratchpad) == 0:
|
|
238
|
+
raise RuntimeError('No tool input received in state: "agent_scratchpad"')
|
|
239
|
+
agent_thoughts = state.agent_scratchpad[-1]
|
|
240
|
+
# the agent can run any installed tool, simply install the tool and add it to the config file
|
|
241
|
+
requested_tool = self._get_tool(agent_thoughts.tool)
|
|
242
|
+
if not requested_tool:
|
|
243
|
+
configured_tool_names = list(self.tools_dict.keys())
|
|
244
|
+
logger.warning(
|
|
245
|
+
"%s ReAct Agent wants to call tool %s. In the ReAct Agent's configuration within the config file,"
|
|
246
|
+
"there is no tool with that name: %s",
|
|
247
|
+
AGENT_LOG_PREFIX,
|
|
248
|
+
agent_thoughts.tool,
|
|
249
|
+
configured_tool_names)
|
|
250
|
+
tool_response = ToolMessage(name='agent_error',
|
|
251
|
+
tool_call_id='agent_error',
|
|
252
|
+
content=TOOL_NOT_FOUND_ERROR_MESSAGE.format(tool_name=agent_thoughts.tool,
|
|
253
|
+
tools=configured_tool_names))
|
|
295
254
|
state.tool_responses += [tool_response]
|
|
296
255
|
return state
|
|
297
|
-
|
|
298
|
-
|
|
299
|
-
|
|
256
|
+
|
|
257
|
+
logger.debug("%s Calling tool %s with input: %s",
|
|
258
|
+
AGENT_LOG_PREFIX,
|
|
259
|
+
requested_tool.name,
|
|
260
|
+
agent_thoughts.tool_input)
|
|
261
|
+
|
|
262
|
+
# Run the tool. Try to use structured input, if possible.
|
|
263
|
+
try:
|
|
264
|
+
tool_input_str = str(agent_thoughts.tool_input).strip().replace("'", '"')
|
|
265
|
+
tool_input_dict = json.loads(tool_input_str) if tool_input_str != 'None' else tool_input_str
|
|
266
|
+
logger.debug("%s Successfully parsed structured tool input from Action Input", AGENT_LOG_PREFIX)
|
|
267
|
+
|
|
268
|
+
tool_response = await self._call_tool(requested_tool,
|
|
269
|
+
tool_input_dict,
|
|
270
|
+
RunnableConfig(callbacks=self.callbacks),
|
|
271
|
+
max_retries=self.tool_call_max_retries)
|
|
272
|
+
|
|
273
|
+
if self.detailed_logs:
|
|
274
|
+
self._log_tool_response(requested_tool.name, tool_input_dict, str(tool_response.content))
|
|
275
|
+
|
|
276
|
+
except JSONDecodeError as ex:
|
|
277
|
+
logger.debug(
|
|
278
|
+
"%s Unable to parse structured tool input from Action Input. Using Action Input as is."
|
|
279
|
+
"\nParsing error: %s",
|
|
280
|
+
AGENT_LOG_PREFIX,
|
|
281
|
+
ex,
|
|
282
|
+
exc_info=True)
|
|
283
|
+
tool_input_str = str(agent_thoughts.tool_input)
|
|
284
|
+
|
|
285
|
+
tool_response = await self._call_tool(requested_tool,
|
|
286
|
+
tool_input_str,
|
|
287
|
+
RunnableConfig(callbacks=self.callbacks),
|
|
288
|
+
max_retries=self.tool_call_max_retries)
|
|
289
|
+
|
|
290
|
+
if self.detailed_logs:
|
|
291
|
+
self._log_tool_response(requested_tool.name, tool_input_str, str(tool_response.content))
|
|
292
|
+
|
|
293
|
+
if not self.pass_tool_call_errors_to_agent:
|
|
294
|
+
if tool_response.status == "error":
|
|
295
|
+
logger.error("%s Tool %s failed: %s", AGENT_LOG_PREFIX, requested_tool.name, tool_response.content)
|
|
296
|
+
raise RuntimeError("Tool call failed: " + str(tool_response.content))
|
|
297
|
+
|
|
298
|
+
state.tool_responses += [tool_response]
|
|
299
|
+
return state
|
|
300
300
|
|
|
301
301
|
async def build_graph(self):
|
|
302
302
|
try:
|
|
@@ -15,6 +15,7 @@
|
|
|
15
15
|
|
|
16
16
|
import logging
|
|
17
17
|
|
|
18
|
+
from pydantic import AliasChoices
|
|
18
19
|
from pydantic import Field
|
|
19
20
|
|
|
20
21
|
from aiq.agent.base import AGENT_LOG_PREFIX
|
|
@@ -42,11 +43,24 @@ class ReActAgentWorkflowConfig(FunctionBaseConfig, name="react_agent"):
|
|
|
42
43
|
description="The list of tools to provide to the react agent.")
|
|
43
44
|
llm_name: LLMRef = Field(description="The LLM model to use with the react agent.")
|
|
44
45
|
verbose: bool = Field(default=False, description="Set the verbosity of the react agent's logging.")
|
|
45
|
-
|
|
46
|
-
|
|
46
|
+
retry_agent_response_parsing_errors: bool = Field(
|
|
47
|
+
default=True,
|
|
48
|
+
validation_alias=AliasChoices("retry_agent_response_parsing_errors", "retry_parsing_errors"),
|
|
49
|
+
description="Whether to retry when encountering parsing errors in the agent's response.")
|
|
50
|
+
parse_agent_response_max_retries: int = Field(
|
|
51
|
+
default=1,
|
|
52
|
+
validation_alias=AliasChoices("parse_agent_response_max_retries", "max_retries"),
|
|
53
|
+
description="Maximum number of times the Agent may retry parsing errors. "
|
|
54
|
+
"Prevents the Agent from getting into infinite hallucination loops.")
|
|
55
|
+
tool_call_max_retries: int = Field(default=1, description="The number of retries before raising a tool call error.")
|
|
56
|
+
max_tool_calls: int = Field(default=15,
|
|
57
|
+
validation_alias=AliasChoices("max_tool_calls", "max_iterations"),
|
|
58
|
+
description="Maximum number of tool calls before stopping the agent.")
|
|
59
|
+
pass_tool_call_errors_to_agent: bool = Field(
|
|
60
|
+
default=True,
|
|
61
|
+
description="Whether to pass tool call errors to agent. If False, failed tool calls will raise an exception.")
|
|
47
62
|
include_tool_input_schema_in_tool_description: bool = Field(
|
|
48
63
|
default=True, description="Specify inclusion of tool input schemas in the prompt.")
|
|
49
|
-
max_iterations: int = Field(default=15, description="Number of tool calls before stoping the react agent.")
|
|
50
64
|
description: str = Field(default="ReAct Agent Workflow", description="The description of this functions use.")
|
|
51
65
|
system_prompt: str | None = Field(
|
|
52
66
|
default=None,
|
|
@@ -80,13 +94,16 @@ async def react_agent_workflow(config: ReActAgentWorkflowConfig, builder: Builde
|
|
|
80
94
|
raise ValueError(f"No tools specified for ReAct Agent '{config.llm_name}'")
|
|
81
95
|
# configure callbacks, for sending intermediate steps
|
|
82
96
|
# construct the ReAct Agent Graph from the configured llm, prompt, and tools
|
|
83
|
-
graph: CompiledGraph = await ReActAgentGraph(
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
97
|
+
graph: CompiledGraph = await ReActAgentGraph(
|
|
98
|
+
llm=llm,
|
|
99
|
+
prompt=prompt,
|
|
100
|
+
tools=tools,
|
|
101
|
+
use_tool_schema=config.include_tool_input_schema_in_tool_description,
|
|
102
|
+
detailed_logs=config.verbose,
|
|
103
|
+
retry_agent_response_parsing_errors=config.retry_agent_response_parsing_errors,
|
|
104
|
+
parse_agent_response_max_retries=config.parse_agent_response_max_retries,
|
|
105
|
+
tool_call_max_retries=config.tool_call_max_retries,
|
|
106
|
+
pass_tool_call_errors_to_agent=config.pass_tool_call_errors_to_agent).build_graph()
|
|
90
107
|
|
|
91
108
|
async def _response_fn(input_message: AIQChatRequest) -> AIQChatResponse:
|
|
92
109
|
try:
|
|
@@ -101,7 +118,7 @@ async def react_agent_workflow(config: ReActAgentWorkflowConfig, builder: Builde
|
|
|
101
118
|
state = ReActGraphState(messages=messages)
|
|
102
119
|
|
|
103
120
|
# run the ReAct Agent Graph
|
|
104
|
-
state = await graph.ainvoke(state, config={'recursion_limit': (config.
|
|
121
|
+
state = await graph.ainvoke(state, config={'recursion_limit': (config.max_tool_calls + 1) * 2})
|
|
105
122
|
# setting recursion_limit: 4 allows 1 tool call
|
|
106
123
|
# - allows the ReAct Agent to perform 1 cycle / call 1 single tool,
|
|
107
124
|
# - but stops the agent when it tries to call a tool a second time
|
|
@@ -109,7 +126,7 @@ async def react_agent_workflow(config: ReActAgentWorkflowConfig, builder: Builde
|
|
|
109
126
|
# get and return the output from the state
|
|
110
127
|
state = ReActGraphState(**state)
|
|
111
128
|
output_message = state.messages[-1] # pylint: disable=E1136
|
|
112
|
-
return AIQChatResponse.from_string(output_message.content)
|
|
129
|
+
return AIQChatResponse.from_string(str(output_message.content))
|
|
113
130
|
|
|
114
131
|
except Exception as ex:
|
|
115
132
|
logger.exception("%s ReAct Agent failed with exception: %s", AGENT_LOG_PREFIX, ex, exc_info=ex)
|
|
@@ -123,10 +140,10 @@ async def react_agent_workflow(config: ReActAgentWorkflowConfig, builder: Builde
|
|
|
123
140
|
else:
|
|
124
141
|
|
|
125
142
|
async def _str_api_fn(input_message: str) -> str:
|
|
126
|
-
oai_input = GlobalTypeConverter.get().
|
|
143
|
+
oai_input = GlobalTypeConverter.get().try_convert(input_message, to_type=AIQChatRequest)
|
|
127
144
|
|
|
128
145
|
oai_output = await _response_fn(oai_input)
|
|
129
146
|
|
|
130
|
-
return GlobalTypeConverter.get().
|
|
147
|
+
return GlobalTypeConverter.get().try_convert(oai_output, to_type=str)
|
|
131
148
|
|
|
132
149
|
yield FunctionInfo.from_fn(_str_api_fn, description=config.description)
|