aiqtoolkit 1.2.0a20250707__py3-none-any.whl → 1.2.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


Files changed (197)
  1. aiq/agent/base.py +170 -8
  2. aiq/agent/dual_node.py +1 -1
  3. aiq/agent/react_agent/agent.py +112 -111
  4. aiq/agent/react_agent/register.py +31 -14
  5. aiq/agent/rewoo_agent/agent.py +36 -35
  6. aiq/agent/rewoo_agent/register.py +2 -2
  7. aiq/agent/tool_calling_agent/agent.py +3 -7
  8. aiq/authentication/__init__.py +14 -0
  9. aiq/authentication/api_key/__init__.py +14 -0
  10. aiq/authentication/api_key/api_key_auth_provider.py +92 -0
  11. aiq/authentication/api_key/api_key_auth_provider_config.py +124 -0
  12. aiq/authentication/api_key/register.py +26 -0
  13. aiq/authentication/exceptions/__init__.py +14 -0
  14. aiq/authentication/exceptions/api_key_exceptions.py +38 -0
  15. aiq/authentication/exceptions/auth_code_grant_exceptions.py +86 -0
  16. aiq/authentication/exceptions/call_back_exceptions.py +38 -0
  17. aiq/authentication/exceptions/request_exceptions.py +54 -0
  18. aiq/authentication/http_basic_auth/__init__.py +0 -0
  19. aiq/authentication/http_basic_auth/http_basic_auth_provider.py +81 -0
  20. aiq/authentication/http_basic_auth/register.py +30 -0
  21. aiq/authentication/interfaces.py +93 -0
  22. aiq/authentication/oauth2/__init__.py +14 -0
  23. aiq/authentication/oauth2/oauth2_auth_code_flow_provider.py +107 -0
  24. aiq/authentication/oauth2/oauth2_auth_code_flow_provider_config.py +39 -0
  25. aiq/authentication/oauth2/register.py +25 -0
  26. aiq/authentication/register.py +21 -0
  27. aiq/builder/builder.py +64 -2
  28. aiq/builder/component_utils.py +16 -3
  29. aiq/builder/context.py +26 -0
  30. aiq/builder/eval_builder.py +43 -2
  31. aiq/builder/function.py +32 -4
  32. aiq/builder/function_base.py +1 -1
  33. aiq/builder/intermediate_step_manager.py +6 -8
  34. aiq/builder/user_interaction_manager.py +3 -0
  35. aiq/builder/workflow.py +23 -18
  36. aiq/builder/workflow_builder.py +420 -73
  37. aiq/cli/commands/info/list_mcp.py +103 -16
  38. aiq/cli/commands/sizing/__init__.py +14 -0
  39. aiq/cli/commands/sizing/calc.py +294 -0
  40. aiq/cli/commands/sizing/sizing.py +27 -0
  41. aiq/cli/commands/start.py +1 -0
  42. aiq/cli/entrypoint.py +2 -0
  43. aiq/cli/register_workflow.py +80 -0
  44. aiq/cli/type_registry.py +151 -30
  45. aiq/data_models/api_server.py +123 -11
  46. aiq/data_models/authentication.py +231 -0
  47. aiq/data_models/common.py +35 -7
  48. aiq/data_models/component.py +17 -9
  49. aiq/data_models/component_ref.py +33 -0
  50. aiq/data_models/config.py +60 -3
  51. aiq/data_models/embedder.py +1 -0
  52. aiq/data_models/function_dependencies.py +8 -0
  53. aiq/data_models/interactive.py +10 -1
  54. aiq/data_models/intermediate_step.py +15 -5
  55. aiq/data_models/its_strategy.py +30 -0
  56. aiq/data_models/llm.py +1 -0
  57. aiq/data_models/memory.py +1 -0
  58. aiq/data_models/object_store.py +44 -0
  59. aiq/data_models/retry_mixin.py +35 -0
  60. aiq/data_models/span.py +187 -0
  61. aiq/data_models/telemetry_exporter.py +2 -2
  62. aiq/embedder/nim_embedder.py +2 -1
  63. aiq/embedder/openai_embedder.py +2 -1
  64. aiq/eval/config.py +19 -1
  65. aiq/eval/dataset_handler/dataset_handler.py +75 -1
  66. aiq/eval/evaluate.py +53 -10
  67. aiq/eval/rag_evaluator/evaluate.py +23 -12
  68. aiq/eval/remote_workflow.py +7 -2
  69. aiq/eval/runners/__init__.py +14 -0
  70. aiq/eval/runners/config.py +39 -0
  71. aiq/eval/runners/multi_eval_runner.py +54 -0
  72. aiq/eval/usage_stats.py +6 -0
  73. aiq/eval/utils/weave_eval.py +5 -1
  74. aiq/experimental/__init__.py +0 -0
  75. aiq/experimental/decorators/__init__.py +0 -0
  76. aiq/experimental/decorators/experimental_warning_decorator.py +130 -0
  77. aiq/experimental/inference_time_scaling/__init__.py +0 -0
  78. aiq/experimental/inference_time_scaling/editing/__init__.py +0 -0
  79. aiq/experimental/inference_time_scaling/editing/iterative_plan_refinement_editor.py +147 -0
  80. aiq/experimental/inference_time_scaling/editing/llm_as_a_judge_editor.py +204 -0
  81. aiq/experimental/inference_time_scaling/editing/motivation_aware_summarization.py +107 -0
  82. aiq/experimental/inference_time_scaling/functions/__init__.py +0 -0
  83. aiq/experimental/inference_time_scaling/functions/execute_score_select_function.py +105 -0
  84. aiq/experimental/inference_time_scaling/functions/its_tool_orchestration_function.py +205 -0
  85. aiq/experimental/inference_time_scaling/functions/its_tool_wrapper_function.py +146 -0
  86. aiq/experimental/inference_time_scaling/functions/plan_select_execute_function.py +224 -0
  87. aiq/experimental/inference_time_scaling/models/__init__.py +0 -0
  88. aiq/experimental/inference_time_scaling/models/editor_config.py +132 -0
  89. aiq/experimental/inference_time_scaling/models/its_item.py +48 -0
  90. aiq/experimental/inference_time_scaling/models/scoring_config.py +112 -0
  91. aiq/experimental/inference_time_scaling/models/search_config.py +120 -0
  92. aiq/experimental/inference_time_scaling/models/selection_config.py +154 -0
  93. aiq/experimental/inference_time_scaling/models/stage_enums.py +43 -0
  94. aiq/experimental/inference_time_scaling/models/strategy_base.py +66 -0
  95. aiq/experimental/inference_time_scaling/models/tool_use_config.py +41 -0
  96. aiq/experimental/inference_time_scaling/register.py +36 -0
  97. aiq/experimental/inference_time_scaling/scoring/__init__.py +0 -0
  98. aiq/experimental/inference_time_scaling/scoring/llm_based_agent_scorer.py +168 -0
  99. aiq/experimental/inference_time_scaling/scoring/llm_based_plan_scorer.py +168 -0
  100. aiq/experimental/inference_time_scaling/scoring/motivation_aware_scorer.py +111 -0
  101. aiq/experimental/inference_time_scaling/search/__init__.py +0 -0
  102. aiq/experimental/inference_time_scaling/search/multi_llm_planner.py +128 -0
  103. aiq/experimental/inference_time_scaling/search/multi_query_retrieval_search.py +122 -0
  104. aiq/experimental/inference_time_scaling/search/single_shot_multi_plan_planner.py +128 -0
  105. aiq/experimental/inference_time_scaling/selection/__init__.py +0 -0
  106. aiq/experimental/inference_time_scaling/selection/best_of_n_selector.py +63 -0
  107. aiq/experimental/inference_time_scaling/selection/llm_based_agent_output_selector.py +131 -0
  108. aiq/experimental/inference_time_scaling/selection/llm_based_output_merging_selector.py +159 -0
  109. aiq/experimental/inference_time_scaling/selection/llm_based_plan_selector.py +128 -0
  110. aiq/experimental/inference_time_scaling/selection/threshold_selector.py +58 -0
  111. aiq/front_ends/console/authentication_flow_handler.py +233 -0
  112. aiq/front_ends/console/console_front_end_plugin.py +11 -2
  113. aiq/front_ends/fastapi/auth_flow_handlers/__init__.py +0 -0
  114. aiq/front_ends/fastapi/auth_flow_handlers/http_flow_handler.py +27 -0
  115. aiq/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py +107 -0
  116. aiq/front_ends/fastapi/fastapi_front_end_config.py +20 -0
  117. aiq/front_ends/fastapi/fastapi_front_end_controller.py +68 -0
  118. aiq/front_ends/fastapi/fastapi_front_end_plugin.py +14 -1
  119. aiq/front_ends/fastapi/fastapi_front_end_plugin_worker.py +353 -31
  120. aiq/front_ends/fastapi/html_snippets/__init__.py +14 -0
  121. aiq/front_ends/fastapi/html_snippets/auth_code_grant_success.py +35 -0
  122. aiq/front_ends/fastapi/main.py +2 -0
  123. aiq/front_ends/fastapi/message_handler.py +102 -84
  124. aiq/front_ends/fastapi/step_adaptor.py +2 -1
  125. aiq/llm/aws_bedrock_llm.py +2 -1
  126. aiq/llm/nim_llm.py +2 -1
  127. aiq/llm/openai_llm.py +2 -1
  128. aiq/object_store/__init__.py +20 -0
  129. aiq/object_store/in_memory_object_store.py +74 -0
  130. aiq/object_store/interfaces.py +84 -0
  131. aiq/object_store/models.py +36 -0
  132. aiq/object_store/register.py +20 -0
  133. aiq/observability/__init__.py +14 -0
  134. aiq/observability/exporter/__init__.py +14 -0
  135. aiq/observability/exporter/base_exporter.py +449 -0
  136. aiq/observability/exporter/exporter.py +78 -0
  137. aiq/observability/exporter/file_exporter.py +33 -0
  138. aiq/observability/exporter/processing_exporter.py +269 -0
  139. aiq/observability/exporter/raw_exporter.py +52 -0
  140. aiq/observability/exporter/span_exporter.py +264 -0
  141. aiq/observability/exporter_manager.py +335 -0
  142. aiq/observability/mixin/__init__.py +14 -0
  143. aiq/observability/mixin/batch_config_mixin.py +26 -0
  144. aiq/observability/mixin/collector_config_mixin.py +23 -0
  145. aiq/observability/mixin/file_mixin.py +288 -0
  146. aiq/observability/mixin/file_mode.py +23 -0
  147. aiq/observability/mixin/resource_conflict_mixin.py +134 -0
  148. aiq/observability/mixin/serialize_mixin.py +61 -0
  149. aiq/observability/mixin/type_introspection_mixin.py +183 -0
  150. aiq/observability/processor/__init__.py +14 -0
  151. aiq/observability/processor/batching_processor.py +316 -0
  152. aiq/observability/processor/intermediate_step_serializer.py +28 -0
  153. aiq/observability/processor/processor.py +68 -0
  154. aiq/observability/register.py +32 -116
  155. aiq/observability/utils/__init__.py +14 -0
  156. aiq/observability/utils/dict_utils.py +236 -0
  157. aiq/observability/utils/time_utils.py +31 -0
  158. aiq/profiler/calc/__init__.py +14 -0
  159. aiq/profiler/calc/calc_runner.py +623 -0
  160. aiq/profiler/calc/calculations.py +288 -0
  161. aiq/profiler/calc/data_models.py +176 -0
  162. aiq/profiler/calc/plot.py +345 -0
  163. aiq/profiler/data_models.py +2 -0
  164. aiq/profiler/profile_runner.py +16 -13
  165. aiq/runtime/loader.py +8 -2
  166. aiq/runtime/runner.py +23 -9
  167. aiq/runtime/session.py +16 -5
  168. aiq/tool/chat_completion.py +74 -0
  169. aiq/tool/code_execution/README.md +152 -0
  170. aiq/tool/code_execution/code_sandbox.py +151 -72
  171. aiq/tool/code_execution/local_sandbox/.gitignore +1 -0
  172. aiq/tool/code_execution/local_sandbox/local_sandbox_server.py +139 -24
  173. aiq/tool/code_execution/local_sandbox/sandbox.requirements.txt +3 -1
  174. aiq/tool/code_execution/local_sandbox/start_local_sandbox.sh +27 -2
  175. aiq/tool/code_execution/register.py +7 -3
  176. aiq/tool/code_execution/test_code_execution_sandbox.py +414 -0
  177. aiq/tool/mcp/exceptions.py +142 -0
  178. aiq/tool/mcp/mcp_client.py +17 -3
  179. aiq/tool/mcp/mcp_tool.py +1 -1
  180. aiq/tool/register.py +1 -0
  181. aiq/tool/server_tools.py +2 -2
  182. aiq/utils/exception_handlers/automatic_retries.py +289 -0
  183. aiq/utils/exception_handlers/mcp.py +211 -0
  184. aiq/utils/io/model_processing.py +28 -0
  185. aiq/utils/log_utils.py +37 -0
  186. aiq/utils/string_utils.py +38 -0
  187. aiq/utils/type_converter.py +18 -2
  188. aiq/utils/type_utils.py +87 -0
  189. {aiqtoolkit-1.2.0a20250707.dist-info → aiqtoolkit-1.2.0rc1.dist-info}/METADATA +37 -9
  190. {aiqtoolkit-1.2.0a20250707.dist-info → aiqtoolkit-1.2.0rc1.dist-info}/RECORD +195 -80
  191. {aiqtoolkit-1.2.0a20250707.dist-info → aiqtoolkit-1.2.0rc1.dist-info}/entry_points.txt +3 -0
  192. aiq/front_ends/fastapi/websocket.py +0 -153
  193. aiq/observability/async_otel_listener.py +0 -470
  194. {aiqtoolkit-1.2.0a20250707.dist-info → aiqtoolkit-1.2.0rc1.dist-info}/WHEEL +0 -0
  195. {aiqtoolkit-1.2.0a20250707.dist-info → aiqtoolkit-1.2.0rc1.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
  196. {aiqtoolkit-1.2.0a20250707.dist-info → aiqtoolkit-1.2.0rc1.dist-info}/licenses/LICENSE.md +0 -0
  197. {aiqtoolkit-1.2.0a20250707.dist-info → aiqtoolkit-1.2.0rc1.dist-info}/top_level.txt +0 -0
aiq/agent/base.py CHANGED
@@ -13,25 +13,32 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

+ import asyncio
+ import json
  import logging
  from abc import ABC
  from abc import abstractmethod
  from enum import Enum
+ from typing import Any

  from colorama import Fore
  from langchain_core.callbacks import AsyncCallbackHandler
  from langchain_core.language_models import BaseChatModel
+ from langchain_core.messages import AIMessage
+ from langchain_core.messages import BaseMessage
+ from langchain_core.messages import ToolMessage
+ from langchain_core.runnables import RunnableConfig
  from langchain_core.tools import BaseTool
  from langgraph.graph.graph import CompiledGraph

- log = logging.getLogger(__name__)
+ logger = logging.getLogger(__name__)

  TOOL_NOT_FOUND_ERROR_MESSAGE = "There is no tool named {tool_name}. Tool must be one of {tools}."
  INPUT_SCHEMA_MESSAGE = ". Arguments must be provided as a valid JSON object following this format: {schema}"
- NO_INPUT_ERROR_MESSAGE = "No human input recieved to the agent, Please ask a valid question."
+ NO_INPUT_ERROR_MESSAGE = "No human input received to the agent, Please ask a valid question."

  AGENT_LOG_PREFIX = "[AGENT]"
- AGENT_RESPONSE_LOG_MESSAGE = f"\n{'-' * 30}\n" + \
+ AGENT_CALL_LOG_MESSAGE = f"\n{'-' * 30}\n" + \
      AGENT_LOG_PREFIX + "\n" + \
      Fore.YELLOW + \
      "Agent input: %s\n" + \
@@ -40,7 +47,7 @@ AGENT_RESPONSE_LOG_MESSAGE = f"\n{'-' * 30}\n" + \
      Fore.RESET + \
      f"\n{'-' * 30}"

- TOOL_RESPONSE_LOG_MESSAGE = f"\n{'-' * 30}\n" + \
+ TOOL_CALL_LOG_MESSAGE = f"\n{'-' * 30}\n" + \
      AGENT_LOG_PREFIX + "\n" + \
      Fore.WHITE + \
      "Calling tools: %s\n" + \
@@ -62,15 +69,170 @@ class BaseAgent(ABC):
      def __init__(self,
                   llm: BaseChatModel,
                   tools: list[BaseTool],
-                  callbacks: list[AsyncCallbackHandler] = None,
-                  detailed_logs: bool = False):
-         log.debug("Initializing Agent Graph")
+                  callbacks: list[AsyncCallbackHandler] | None = None,
+                  detailed_logs: bool = False) -> None:
+         logger.debug("Initializing Agent Graph")
          self.llm = llm
          self.tools = tools
          self.callbacks = callbacks or []
          self.detailed_logs = detailed_logs
          self.graph = None

+     async def _stream_llm(self,
+                           runnable: Any,
+                           inputs: dict[str, Any],
+                           config: RunnableConfig | None = None) -> AIMessage:
+         """
+         Stream from LLM runnable. Retry logic is handled automatically by the underlying LLM client.
+
+         Parameters
+         ----------
+         runnable : Any
+             The LLM runnable (prompt | llm or similar)
+         inputs : Dict[str, Any]
+             The inputs to pass to the runnable
+         config : RunnableConfig | None
+             The config to pass to the runnable (should include callbacks)
+
+         Returns
+         -------
+         AIMessage
+             The LLM response
+         """
+         output_message = ""
+         async for event in runnable.astream(inputs, config=config):
+             output_message += event.content
+
+         return AIMessage(content=output_message)
+
+     async def _call_llm(self, messages: list[BaseMessage]) -> AIMessage:
+         """
+         Call the LLM directly. Retry logic is handled automatically by the underlying LLM client.
+
+         Parameters
+         ----------
+         messages : list[BaseMessage]
+             The messages to send to the LLM
+
+         Returns
+         -------
+         AIMessage
+             The LLM response
+         """
+         response = await self.llm.ainvoke(messages)
+         return AIMessage(content=str(response.content))
+
+     async def _call_tool(self,
+                          tool: BaseTool,
+                          tool_input: dict[str, Any] | str,
+                          config: RunnableConfig | None = None,
+                          max_retries: int = 3) -> ToolMessage:
+         """
+         Call a tool with retry logic and error handling.
+
+         Parameters
+         ----------
+         tool : BaseTool
+             The tool to call
+         tool_input : Union[Dict[str, Any], str]
+             The input to pass to the tool
+         config : RunnableConfig | None
+             The config to pass to the tool
+         max_retries : int
+             Maximum number of retry attempts (default: 3)
+
+         Returns
+         -------
+         ToolMessage
+             The tool response
+         """
+         last_exception = None
+
+         for attempt in range(max_retries + 1):
+             try:
+                 response = await tool.ainvoke(tool_input, config=config)
+
+                 # Handle empty responses
+                 if response is None or (isinstance(response, str) and response == ""):
+                     return ToolMessage(name=tool.name,
+                                        tool_call_id=tool.name,
+                                        content=f"The tool {tool.name} provided an empty response.")
+
+                 return ToolMessage(name=tool.name, tool_call_id=tool.name, content=response)
+
+             except Exception as e:
+                 last_exception = e
+                 logger.warning("%s Tool call attempt %d/%d failed for tool %s: %s",
+                                AGENT_LOG_PREFIX,
+                                attempt + 1,
+                                max_retries + 1,
+                                tool.name,
+                                str(e))
+
+                 # If this was the last attempt, don't sleep
+                 if attempt == max_retries:
+                     break
+
+                 # Exponential backoff: 2^attempt seconds
+                 sleep_time = 2**attempt
+                 logger.debug("%s Retrying tool call for %s in %d seconds...", AGENT_LOG_PREFIX, tool.name, sleep_time)
+                 await asyncio.sleep(sleep_time)
+
+         # All retries exhausted, return error message
+         error_content = "Tool call failed after all retry attempts. Last error: %s" % str(last_exception)
+         logger.error("%s %s", AGENT_LOG_PREFIX, error_content)
+         return ToolMessage(name=tool.name, tool_call_id=tool.name, content=error_content, status="error")
+
+     def _log_tool_response(self, tool_name: str, tool_input: Any, tool_response: str, max_chars: int = 1000) -> None:
+         """
+         Log tool response with consistent formatting and length limits.
+
+         Parameters
+         ----------
+         tool_name : str
+             The name of the tool that was called
+         tool_input : Any
+             The input that was passed to the tool
+         tool_response : str
+             The response from the tool
+         max_chars : int
+             Maximum number of characters to log (default: 1000)
+         """
+         if self.detailed_logs:
+             # Truncate tool response if too long
+             display_response = tool_response[:max_chars] + "...(rest of response truncated)" if len(
+                 tool_response) > max_chars else tool_response
+
+             # Format the tool input for display
+             tool_input_str = str(tool_input)
+
+             tool_response_log_message = TOOL_CALL_LOG_MESSAGE % (tool_name, tool_input_str, display_response)
+             logger.info(tool_response_log_message)
+
+     def _parse_json(self, json_string: str) -> dict[str, Any]:
+         """
+         Safely parse JSON with graceful error handling.
+         If JSON parsing fails, returns an empty dict or error info.
+
+         Parameters
+         ----------
+         json_string : str
+             The JSON string to parse
+
+         Returns
+         -------
+         Dict[str, Any]
+             The parsed JSON or error information
+         """
+         try:
+             return json.loads(json_string)
+         except json.JSONDecodeError as e:
+             logger.warning("%s JSON parsing failed, returning the original string: %s", AGENT_LOG_PREFIX, str(e))
+             return {"error": f"JSON parsing failed: {str(e)}", "original_string": json_string}
+         except Exception as e:
+             logger.warning("%s Unexpected error during JSON parsing: %s", AGENT_LOG_PREFIX, str(e))
+             return {"error": f"Unexpected parsing error: {str(e)}", "original_string": json_string}
+
      @abstractmethod
-     async def _build_graph(self, state_schema) -> CompiledGraph:
+     async def _build_graph(self, state_schema: type) -> CompiledGraph:
          pass
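The new BaseAgent._call_tool helper retries a failed tool call with exponential backoff (2**attempt seconds between attempts) and, once the retries are exhausted, returns a ToolMessage with status="error" rather than raising. A minimal, self-contained sketch of the same retry pattern (the helper name call_with_backoff is hypothetical and not part of aiqtoolkit):

import asyncio

async def call_with_backoff(fn, max_retries: int = 3):
    # Mirrors _call_tool: try once, then retry with 1s, 2s, 4s, ... waits between attempts.
    last_exc = None
    for attempt in range(max_retries + 1):
        try:
            return await fn()
        except Exception as exc:  # kept broad on purpose, like _call_tool
            last_exc = exc
            if attempt == max_retries:
                break
            await asyncio.sleep(2 ** attempt)
    raise RuntimeError(f"failed after {max_retries + 1} attempts") from last_exc

With max_retries=3 this allows up to four attempts with waits of 1, 2, and 4 seconds between them; the difference is that _call_tool converts the final failure into an error ToolMessage instead of raising.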
aiq/agent/dual_node.py CHANGED
@@ -34,7 +34,7 @@ class DualNodeAgent(BaseAgent):
      def __init__(self,
                   llm: BaseChatModel,
                   tools: list[BaseTool],
-                  callbacks: list[AsyncCallbackHandler] = None,
+                  callbacks: list[AsyncCallbackHandler] | None = None,
                   detailed_logs: bool = False):
          super().__init__(llm=llm, tools=tools, callbacks=callbacks, detailed_logs=detailed_logs)

aiq/agent/react_agent/agent.py CHANGED
@@ -33,12 +33,11 @@ from langchain_core.tools import BaseTool
  from pydantic import BaseModel
  from pydantic import Field

+ from aiq.agent.base import AGENT_CALL_LOG_MESSAGE
  from aiq.agent.base import AGENT_LOG_PREFIX
- from aiq.agent.base import AGENT_RESPONSE_LOG_MESSAGE
  from aiq.agent.base import INPUT_SCHEMA_MESSAGE
  from aiq.agent.base import NO_INPUT_ERROR_MESSAGE
  from aiq.agent.base import TOOL_NOT_FOUND_ERROR_MESSAGE
- from aiq.agent.base import TOOL_RESPONSE_LOG_MESSAGE
  from aiq.agent.base import AgentDecision
  from aiq.agent.dual_node import DualNodeAgent
  from aiq.agent.react_agent.output_parser import ReActOutputParser
@@ -67,13 +66,17 @@ class ReActAgentGraph(DualNodeAgent):
                   prompt: ChatPromptTemplate,
                   tools: list[BaseTool],
                   use_tool_schema: bool = True,
-                  callbacks: list[AsyncCallbackHandler] = None,
+                  callbacks: list[AsyncCallbackHandler] | None = None,
                   detailed_logs: bool = False,
-                  retry_parsing_errors: bool = True,
-                  max_retries: int = 1):
+                  retry_agent_response_parsing_errors: bool = True,
+                  parse_agent_response_max_retries: int = 1,
+                  tool_call_max_retries: int = 1,
+                  pass_tool_call_errors_to_agent: bool = True):
          super().__init__(llm=llm, tools=tools, callbacks=callbacks, detailed_logs=detailed_logs)
-         self.retry_parsing_errors = retry_parsing_errors
-         self.max_tries = (max_retries + 1) if retry_parsing_errors else 1
+         self.parse_agent_response_max_retries = (parse_agent_response_max_retries
+                                                  if retry_agent_response_parsing_errors else 1)
+         self.tool_call_max_retries = tool_call_max_retries
+         self.pass_tool_call_errors_to_agent = pass_tool_call_errors_to_agent
          logger.debug(
              "%s Filling the prompt variables 'tools' and 'tool_names', using the tools provided in the config.",
              AGENT_LOG_PREFIX)
@@ -91,12 +94,12 @@ class ReActAgentGraph(DualNodeAgent):
                      f"{INPUT_SCHEMA_MESSAGE.format(schema=tools[-1].input_schema.model_fields)}")
          prompt = prompt.partial(tools=tool_names_and_descriptions, tool_names=tool_names)
          # construct the ReAct Agent
-         llm = llm.bind(stop=["Observation:"])
-         self.agent = prompt | llm
+         bound_llm = llm.bind(stop=["Observation:"])  # type: ignore
+         self.agent = prompt | bound_llm
          self.tools_dict = {tool.name: tool for tool in tools}
          logger.debug("%s Initialized ReAct Agent Graph", AGENT_LOG_PREFIX)

-     def _get_tool(self, tool_name):
+     def _get_tool(self, tool_name: str):
          try:
              return self.tools_dict.get(tool_name)
          except Exception as ex:
@@ -113,26 +116,30 @@ class ReActAgentGraph(DualNodeAgent):
          # keeping a working state allows us to resolve parsing errors without polluting the agent scratchpad
          # the agent "forgets" about the parsing error after solving it - prevents hallucinations in next cycles
          working_state = []
-         for attempt in range(1, self.max_tries + 1):
+         # Starting from attempt 1 instead of 0 for logging
+         for attempt in range(1, self.parse_agent_response_max_retries + 1):
              # the first time we are invoking the ReAct Agent, it won't have any intermediate steps / agent thoughts
              if len(state.agent_scratchpad) == 0 and len(working_state) == 0:
                  # the user input comes from the "messages" state channel
                  if len(state.messages) == 0:
                      raise RuntimeError('No input received in state: "messages"')
                  # to check is any human input passed or not, if no input passed Agent will return the state
-                 if state.messages[0].content.strip() == "":
+                 content = str(state.messages[0].content)
+                 if content.strip() == "":
                      logger.error("%s No human input passed to the agent.", AGENT_LOG_PREFIX)
                      state.messages += [AIMessage(content=NO_INPUT_ERROR_MESSAGE)]
                      return state
-                 question = state.messages[0].content
+                 question = content
                  logger.debug("%s Querying agent, attempt: %s", AGENT_LOG_PREFIX, attempt)
-                 output_message = ""
-                 async for event in self.agent.astream({"question": question},
-                                                       config=RunnableConfig(callbacks=self.callbacks)):
-                     output_message += event.content
-                 output_message = AIMessage(content=output_message)
+
+                 output_message = await self._stream_llm(
+                     self.agent,
+                     {"question": question},
+                     RunnableConfig(callbacks=self.callbacks)  # type: ignore
+                 )
+
                  if self.detailed_logs:
-                     logger.info(AGENT_RESPONSE_LOG_MESSAGE, question, output_message.content)
+                     logger.info(AGENT_CALL_LOG_MESSAGE, question, output_message.content)
              else:
                  # ReAct Agents require agentic cycles
                  # in an agentic cycle, preserve the agent's thoughts from the previous cycles,
@@ -141,20 +148,20 @@ class ReActAgentGraph(DualNodeAgent):
                  for index, intermediate_step in enumerate(state.agent_scratchpad):
                      agent_thoughts = AIMessage(content=intermediate_step.log)
                      agent_scratchpad.append(agent_thoughts)
-                     tool_response = HumanMessage(content=state.tool_responses[index].content)
+                     tool_response_content = str(state.tool_responses[index].content)
+                     tool_response = HumanMessage(content=tool_response_content)
                      agent_scratchpad.append(tool_response)
                  agent_scratchpad += working_state
-                 question = state.messages[0].content
+                 question = str(state.messages[0].content)
                  logger.debug("%s Querying agent, attempt: %s", AGENT_LOG_PREFIX, attempt)
-                 output_message = ""
-                 async for event in self.agent.astream({
-                         "question": question, "agent_scratchpad": agent_scratchpad
+
+                 output_message = await self._stream_llm(self.agent, {
+                     "question": question, "agent_scratchpad": agent_scratchpad
                  },
-                                                       config=RunnableConfig(callbacks=self.callbacks)):
-                     output_message += event.content
-                 output_message = AIMessage(content=output_message)
+                                                         RunnableConfig(callbacks=self.callbacks))
+
                  if self.detailed_logs:
-                     logger.info(AGENT_RESPONSE_LOG_MESSAGE, question, output_message.content)
+                     logger.info(AGENT_CALL_LOG_MESSAGE, question, output_message.content)
                      logger.debug("%s The agent's scratchpad (with tool result) was:\n%s",
                                   AGENT_LOG_PREFIX,
                                   agent_scratchpad)
@@ -162,11 +169,7 @@ class ReActAgentGraph(DualNodeAgent):
              # check if the agent has the final answer yet
              logger.debug("%s Successfully obtained agent response. Parsing agent's response", AGENT_LOG_PREFIX)
              agent_output = await ReActOutputParser().aparse(output_message.content)
-             logger.debug("%s Successfully parsed agent's response", AGENT_LOG_PREFIX)
-             if attempt > 1:
-                 logger.debug("%s Successfully parsed agent response after %s attempts",
-                              AGENT_LOG_PREFIX,
-                              attempt)
+             logger.debug("%s Successfully parsed agent response after %s attempts", AGENT_LOG_PREFIX, attempt)
              if isinstance(agent_output, AgentFinish):
                  final_answer = agent_output.return_values.get('output', output_message.content)
                  logger.debug("%s The agent has finished, and has the final answer", AGENT_LOG_PREFIX)
@@ -178,31 +181,33 @@
                  agent_output.log = output_message.content
                  logger.debug("%s The agent wants to call a tool: %s", AGENT_LOG_PREFIX, agent_output.tool)
                  state.agent_scratchpad += [agent_output]
+
                  return state
          except ReActOutputParserException as ex:
              # the agent output did not meet the expected ReAct output format. This can happen for a few reasons:
              # the agent mentioned a tool, but already has the final answer, this can happen with Llama models
              # - the ReAct Agent already has the answer, and is reflecting on how it obtained the answer
              # the agent might have also missed Action or Action Input in its output
-             logger.warning("%s Error parsing agent output\nObservation:%s\nAgent Output:\n%s",
-                            AGENT_LOG_PREFIX,
-                            ex.observation,
-                            output_message.content)
-             if attempt == self.max_tries:
-                 logger.exception(
+             logger.debug("%s Error parsing agent output\nObservation:%s\nAgent Output:\n%s",
+                          AGENT_LOG_PREFIX,
+                          ex.observation,
+                          output_message.content)
+             if attempt == self.parse_agent_response_max_retries:
+                 logger.error(
                      "%s Failed to parse agent output after %d attempts, consider enabling or "
-                     "increasing max_retries",
+                     "increasing parse_agent_response_max_retries",
                      AGENT_LOG_PREFIX,
                      attempt,
                      exc_info=True)
                  # the final answer goes in the "messages" state channel
-                 output_message.content = ex.observation + '\n' + output_message.content
+                 combined_content = str(ex.observation) + '\n' + str(output_message.content)
+                 output_message.content = combined_content
                  state.messages += [output_message]
                  return state
              # retry parsing errors, if configured
             logger.info("%s Retrying ReAct Agent, including output parsing Observation", AGENT_LOG_PREFIX)
             working_state.append(output_message)
-            working_state.append(HumanMessage(content=ex.observation))
+            working_state.append(HumanMessage(content=str(ex.observation)))
         except Exception as ex:
             logger.exception("%s Failed to call agent_node: %s", AGENT_LOG_PREFIX, ex, exc_info=True)
             raise ex
@@ -212,7 +217,8 @@
          logger.debug("%s Starting the ReAct Conditional Edge", AGENT_LOG_PREFIX)
          if len(state.messages) > 1:
              # the ReAct Agent has finished executing, the last agent output was AgentFinish
-             logger.debug("%s Final answer:\n%s", AGENT_LOG_PREFIX, state.messages[-1].content)
+             last_message_content = str(state.messages[-1].content)
+             logger.debug("%s Final answer:\n%s", AGENT_LOG_PREFIX, last_message_content)
              return AgentDecision.END
          # else the agent wants to call a tool
          agent_output = state.agent_scratchpad[-1]
@@ -227,76 +233,71 @@
              return AgentDecision.END

      async def tool_node(self, state: ReActGraphState):
-         try:
-             logger.debug("%s Starting the Tool Call Node", AGENT_LOG_PREFIX)
-             if len(state.agent_scratchpad) == 0:
-                 raise RuntimeError('No tool input received in state: "agent_scratchpad"')
-             agent_thoughts = state.agent_scratchpad[-1]
-             # the agent can run any installed tool, simply install the tool and add it to the config file
-             requested_tool = self._get_tool(agent_thoughts.tool)
-             if not requested_tool:
-                 configured_tool_names = list(self.tools_dict.keys())
-                 logger.warning(
-                     "%s ReAct Agent wants to call tool %s. In the ReAct Agent's configuration within the config file,"
-                     "there is no tool with that name: %s",
-                     AGENT_LOG_PREFIX,
-                     agent_thoughts.tool,
-                     configured_tool_names)
-                 tool_response = ToolMessage(name='agent_error',
-                                             tool_call_id='agent_error',
-                                             content=TOOL_NOT_FOUND_ERROR_MESSAGE.format(tool_name=agent_thoughts.tool,
-                                                                                         tools=configured_tool_names))
-                 state.tool_responses += [tool_response]
-                 return state
-
-             logger.debug("%s Calling tool %s with input: %s",
-                          AGENT_LOG_PREFIX,
-                          requested_tool.name,
-                          agent_thoughts.tool_input)
-
-             # Run the tool. Try to use structured input, if possible.
-             try:
-                 tool_input_str = agent_thoughts.tool_input.strip().replace("'", '"')
-                 tool_input_dict = json.loads(tool_input_str) if tool_input_str != 'None' else tool_input_str
-                 logger.debug("%s Successfully parsed structured tool input from Action Input", AGENT_LOG_PREFIX)
-                 tool_response = await requested_tool.ainvoke(tool_input_dict,
-                                                              config=RunnableConfig(callbacks=self.callbacks))
-                 if self.detailed_logs:
-                     # The tool response can be very large, so we log only the first 1000 characters
-                     tool_response_str = str(tool_response)
-                     tool_response_str = tool_response_str[:1000] + "..." if len(
-                         tool_response_str) > 1000 else tool_response_str
-                     tool_response_log_message = TOOL_RESPONSE_LOG_MESSAGE % (
-                         requested_tool.name, tool_input_str, tool_response_str)
-                     logger.info(tool_response_log_message)
-             except JSONDecodeError as ex:
-                 logger.warning(
-                     "%s Unable to parse structured tool input from Action Input. Using Action Input as is."
-                     "\nParsing error: %s",
-                     AGENT_LOG_PREFIX,
-                     ex,
-                     exc_info=True)
-                 tool_input_str = agent_thoughts.tool_input
-                 tool_response = await requested_tool.ainvoke(tool_input_str,
-                                                              config=RunnableConfig(callbacks=self.callbacks))
-
-             # some tools, such as Wikipedia, will return an empty response when no search results are found
-             if tool_response is None or tool_response == "":
-                 tool_response = "The tool provided an empty response.\n"
-             # put the tool response in the graph state
-             tool_response = ToolMessage(name=agent_thoughts.tool,
-                                         tool_call_id=agent_thoughts.tool,
-                                         content=tool_response)
-             logger.debug("%s Called tool %s with input: %s\nThe tool returned: %s",
-                          AGENT_LOG_PREFIX,
-                          requested_tool.name,
-                          agent_thoughts.tool_input,
-                          tool_response.content)
+
+         logger.debug("%s Starting the Tool Call Node", AGENT_LOG_PREFIX)
+         if len(state.agent_scratchpad) == 0:
+             raise RuntimeError('No tool input received in state: "agent_scratchpad"')
+         agent_thoughts = state.agent_scratchpad[-1]
+         # the agent can run any installed tool, simply install the tool and add it to the config file
+         requested_tool = self._get_tool(agent_thoughts.tool)
+         if not requested_tool:
+             configured_tool_names = list(self.tools_dict.keys())
+             logger.warning(
+                 "%s ReAct Agent wants to call tool %s. In the ReAct Agent's configuration within the config file,"
+                 "there is no tool with that name: %s",
+                 AGENT_LOG_PREFIX,
+                 agent_thoughts.tool,
+                 configured_tool_names)
+             tool_response = ToolMessage(name='agent_error',
+                                         tool_call_id='agent_error',
+                                         content=TOOL_NOT_FOUND_ERROR_MESSAGE.format(tool_name=agent_thoughts.tool,
+                                                                                     tools=configured_tool_names))
              state.tool_responses += [tool_response]
              return state
-         except Exception as ex:
-             logger.exception("%s Failed to call tool_node: %s", AGENT_LOG_PREFIX, ex, exc_info=ex)
-             raise ex
+
+         logger.debug("%s Calling tool %s with input: %s",
+                      AGENT_LOG_PREFIX,
+                      requested_tool.name,
+                      agent_thoughts.tool_input)
+
+         # Run the tool. Try to use structured input, if possible.
+         try:
+             tool_input_str = str(agent_thoughts.tool_input).strip().replace("'", '"')
+             tool_input_dict = json.loads(tool_input_str) if tool_input_str != 'None' else tool_input_str
+             logger.debug("%s Successfully parsed structured tool input from Action Input", AGENT_LOG_PREFIX)
+
+             tool_response = await self._call_tool(requested_tool,
+                                                   tool_input_dict,
+                                                   RunnableConfig(callbacks=self.callbacks),
+                                                   max_retries=self.tool_call_max_retries)
+
+             if self.detailed_logs:
+                 self._log_tool_response(requested_tool.name, tool_input_dict, str(tool_response.content))
+
+         except JSONDecodeError as ex:
+             logger.debug(
+                 "%s Unable to parse structured tool input from Action Input. Using Action Input as is."
+                 "\nParsing error: %s",
+                 AGENT_LOG_PREFIX,
+                 ex,
+                 exc_info=True)
+             tool_input_str = str(agent_thoughts.tool_input)
+
+             tool_response = await self._call_tool(requested_tool,
+                                                   tool_input_str,
+                                                   RunnableConfig(callbacks=self.callbacks),
+                                                   max_retries=self.tool_call_max_retries)
+
+             if self.detailed_logs:
+                 self._log_tool_response(requested_tool.name, tool_input_str, str(tool_response.content))
+
+         if not self.pass_tool_call_errors_to_agent:
+             if tool_response.status == "error":
+                 logger.error("%s Tool %s failed: %s", AGENT_LOG_PREFIX, requested_tool.name, tool_response.content)
+                 raise RuntimeError("Tool call failed: " + str(tool_response.content))
+
+         state.tool_responses += [tool_response]
+         return state

      async def build_graph(self):
          try:
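In the reworked tool_node above, the agent's Action Input is first normalized (single quotes replaced with double quotes) and parsed as JSON; if that fails, the raw string is handed to the tool instead. A rough, standalone illustration of that parsing step (the helper name and sample values are hypothetical, not from the package):

import json
from json import JSONDecodeError

def parse_action_input(raw: str):
    # Mirrors tool_node: normalize quotes, then prefer structured JSON input.
    normalized = str(raw).strip().replace("'", '"')
    try:
        return json.loads(normalized) if normalized != 'None' else normalized
    except JSONDecodeError:
        return raw  # fall back to passing the raw string to the tool

print(parse_action_input("{'city': 'Paris'}"))  # -> {'city': 'Paris'}
print(parse_action_input("plain text query"))   # -> plain text query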
aiq/agent/react_agent/register.py CHANGED
@@ -15,6 +15,7 @@

  import logging

+ from pydantic import AliasChoices
  from pydantic import Field

  from aiq.agent.base import AGENT_LOG_PREFIX
@@ -42,11 +43,24 @@ class ReActAgentWorkflowConfig(FunctionBaseConfig, name="react_agent"):
                              description="The list of tools to provide to the react agent.")
      llm_name: LLMRef = Field(description="The LLM model to use with the react agent.")
      verbose: bool = Field(default=False, description="Set the verbosity of the react agent's logging.")
-     retry_parsing_errors: bool = Field(default=True, description="Specify retrying when encountering parsing errors.")
-     max_retries: int = Field(default=1, description="Sent the number of retries before raising a parsing error.")
+     retry_agent_response_parsing_errors: bool = Field(
+         default=True,
+         validation_alias=AliasChoices("retry_agent_response_parsing_errors", "retry_parsing_errors"),
+         description="Whether to retry when encountering parsing errors in the agent's response.")
+     parse_agent_response_max_retries: int = Field(
+         default=1,
+         validation_alias=AliasChoices("parse_agent_response_max_retries", "max_retries"),
+         description="Maximum number of times the Agent may retry parsing errors. "
+         "Prevents the Agent from getting into infinite hallucination loops.")
+     tool_call_max_retries: int = Field(default=1, description="The number of retries before raising a tool call error.")
+     max_tool_calls: int = Field(default=15,
+                                 validation_alias=AliasChoices("max_tool_calls", "max_iterations"),
+                                 description="Maximum number of tool calls before stopping the agent.")
+     pass_tool_call_errors_to_agent: bool = Field(
+         default=True,
+         description="Whether to pass tool call errors to agent. If False, failed tool calls will raise an exception.")
      include_tool_input_schema_in_tool_description: bool = Field(
          default=True, description="Specify inclusion of tool input schemas in the prompt.")
-     max_iterations: int = Field(default=15, description="Number of tool calls before stoping the react agent.")
      description: str = Field(default="ReAct Agent Workflow", description="The description of this functions use.")
      system_prompt: str | None = Field(
          default=None,
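The renamed config fields keep their old keys working through pydantic validation aliases. A small standalone sketch of how AliasChoices accepts either key (DemoConfig is hypothetical, not from the package):

from pydantic import AliasChoices, BaseModel, Field

class DemoConfig(BaseModel):
    max_tool_calls: int = Field(default=15,
                                validation_alias=AliasChoices("max_tool_calls", "max_iterations"))

# Both the new and the legacy key populate the same field.
print(DemoConfig.model_validate({"max_iterations": 5}).max_tool_calls)  # 5
print(DemoConfig.model_validate({"max_tool_calls": 7}).max_tool_calls)  # 7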
@@ -80,13 +94,16 @@ async def react_agent_workflow(config: ReActAgentWorkflowConfig, builder: Builde
          raise ValueError(f"No tools specified for ReAct Agent '{config.llm_name}'")
      # configure callbacks, for sending intermediate steps
      # construct the ReAct Agent Graph from the configured llm, prompt, and tools
-     graph: CompiledGraph = await ReActAgentGraph(llm=llm,
-                                                  prompt=prompt,
-                                                  tools=tools,
-                                                  use_tool_schema=config.include_tool_input_schema_in_tool_description,
-                                                  detailed_logs=config.verbose,
-                                                  retry_parsing_errors=config.retry_parsing_errors,
-                                                  max_retries=config.max_retries).build_graph()
+     graph: CompiledGraph = await ReActAgentGraph(
+         llm=llm,
+         prompt=prompt,
+         tools=tools,
+         use_tool_schema=config.include_tool_input_schema_in_tool_description,
+         detailed_logs=config.verbose,
+         retry_agent_response_parsing_errors=config.retry_agent_response_parsing_errors,
+         parse_agent_response_max_retries=config.parse_agent_response_max_retries,
+         tool_call_max_retries=config.tool_call_max_retries,
+         pass_tool_call_errors_to_agent=config.pass_tool_call_errors_to_agent).build_graph()

      async def _response_fn(input_message: AIQChatRequest) -> AIQChatResponse:
          try:
@@ -101,7 +118,7 @@ async def react_agent_workflow(config: ReActAgentWorkflowConfig, builder: Builde
              state = ReActGraphState(messages=messages)

              # run the ReAct Agent Graph
-             state = await graph.ainvoke(state, config={'recursion_limit': (config.max_iterations + 1) * 2})
+             state = await graph.ainvoke(state, config={'recursion_limit': (config.max_tool_calls + 1) * 2})
              # setting recursion_limit: 4 allows 1 tool call
              # - allows the ReAct Agent to perform 1 cycle / call 1 single tool,
              # - but stops the agent when it tries to call a tool a second time
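Worked example of that formula: each agent cycle consumes two graph steps (one agent node, one tool node), so the limit is (max_tool_calls + 1) * 2. With max_tool_calls=1 the recursion_limit works out to 4, matching the in-code comment about allowing a single tool call; with the default max_tool_calls=15 it works out to 32.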
@@ -109,7 +126,7 @@ async def react_agent_workflow(config: ReActAgentWorkflowConfig, builder: Builde
              # get and return the output from the state
              state = ReActGraphState(**state)
              output_message = state.messages[-1]  # pylint: disable=E1136
-             return AIQChatResponse.from_string(output_message.content)
+             return AIQChatResponse.from_string(str(output_message.content))

          except Exception as ex:
              logger.exception("%s ReAct Agent failed with exception: %s", AGENT_LOG_PREFIX, ex, exc_info=ex)
@@ -123,10 +140,10 @@ async def react_agent_workflow(config: ReActAgentWorkflowConfig, builder: Builde
      else:

          async def _str_api_fn(input_message: str) -> str:
-             oai_input = GlobalTypeConverter.get().convert(input_message, to_type=AIQChatRequest)
+             oai_input = GlobalTypeConverter.get().try_convert(input_message, to_type=AIQChatRequest)

              oai_output = await _response_fn(oai_input)

-             return GlobalTypeConverter.get().convert(oai_output, to_type=str)
+             return GlobalTypeConverter.get().try_convert(oai_output, to_type=str)

          yield FunctionInfo.from_fn(_str_api_fn, description=config.description)