aiqtoolkit 1.2.0.dev0__py3-none-any.whl → 1.2.0rc2__py3-none-any.whl

This diff shows the changes between publicly available package versions as they appear in their respective public registries. It is provided for informational purposes only.

Potentially problematic release.


This version of aiqtoolkit might be problematic.

Files changed (220)
  1. aiq/agent/base.py +170 -8
  2. aiq/agent/dual_node.py +1 -1
  3. aiq/agent/react_agent/agent.py +146 -112
  4. aiq/agent/react_agent/prompt.py +1 -6
  5. aiq/agent/react_agent/register.py +36 -35
  6. aiq/agent/rewoo_agent/agent.py +36 -35
  7. aiq/agent/rewoo_agent/register.py +2 -2
  8. aiq/agent/tool_calling_agent/agent.py +3 -7
  9. aiq/agent/tool_calling_agent/register.py +1 -1
  10. aiq/authentication/__init__.py +14 -0
  11. aiq/authentication/api_key/__init__.py +14 -0
  12. aiq/authentication/api_key/api_key_auth_provider.py +92 -0
  13. aiq/authentication/api_key/api_key_auth_provider_config.py +124 -0
  14. aiq/authentication/api_key/register.py +26 -0
  15. aiq/authentication/exceptions/__init__.py +14 -0
  16. aiq/authentication/exceptions/api_key_exceptions.py +38 -0
  17. aiq/authentication/exceptions/auth_code_grant_exceptions.py +86 -0
  18. aiq/authentication/exceptions/call_back_exceptions.py +38 -0
  19. aiq/authentication/exceptions/request_exceptions.py +54 -0
  20. aiq/authentication/http_basic_auth/__init__.py +0 -0
  21. aiq/authentication/http_basic_auth/http_basic_auth_provider.py +81 -0
  22. aiq/authentication/http_basic_auth/register.py +30 -0
  23. aiq/authentication/interfaces.py +93 -0
  24. aiq/authentication/oauth2/__init__.py +14 -0
  25. aiq/authentication/oauth2/oauth2_auth_code_flow_provider.py +107 -0
  26. aiq/authentication/oauth2/oauth2_auth_code_flow_provider_config.py +39 -0
  27. aiq/authentication/oauth2/register.py +25 -0
  28. aiq/authentication/register.py +21 -0
  29. aiq/builder/builder.py +64 -2
  30. aiq/builder/component_utils.py +16 -3
  31. aiq/builder/context.py +37 -0
  32. aiq/builder/eval_builder.py +43 -2
  33. aiq/builder/function.py +44 -12
  34. aiq/builder/function_base.py +1 -1
  35. aiq/builder/intermediate_step_manager.py +6 -8
  36. aiq/builder/user_interaction_manager.py +3 -0
  37. aiq/builder/workflow.py +23 -18
  38. aiq/builder/workflow_builder.py +421 -61
  39. aiq/cli/commands/info/list_mcp.py +103 -16
  40. aiq/cli/commands/sizing/__init__.py +14 -0
  41. aiq/cli/commands/sizing/calc.py +294 -0
  42. aiq/cli/commands/sizing/sizing.py +27 -0
  43. aiq/cli/commands/start.py +2 -1
  44. aiq/cli/entrypoint.py +2 -0
  45. aiq/cli/register_workflow.py +80 -0
  46. aiq/cli/type_registry.py +151 -30
  47. aiq/data_models/api_server.py +124 -12
  48. aiq/data_models/authentication.py +231 -0
  49. aiq/data_models/common.py +35 -7
  50. aiq/data_models/component.py +17 -9
  51. aiq/data_models/component_ref.py +33 -0
  52. aiq/data_models/config.py +60 -3
  53. aiq/data_models/dataset_handler.py +2 -1
  54. aiq/data_models/embedder.py +1 -0
  55. aiq/data_models/evaluate.py +23 -0
  56. aiq/data_models/function_dependencies.py +8 -0
  57. aiq/data_models/interactive.py +10 -1
  58. aiq/data_models/intermediate_step.py +38 -5
  59. aiq/data_models/its_strategy.py +30 -0
  60. aiq/data_models/llm.py +1 -0
  61. aiq/data_models/memory.py +1 -0
  62. aiq/data_models/object_store.py +44 -0
  63. aiq/data_models/profiler.py +1 -0
  64. aiq/data_models/retry_mixin.py +35 -0
  65. aiq/data_models/span.py +187 -0
  66. aiq/data_models/telemetry_exporter.py +2 -2
  67. aiq/embedder/nim_embedder.py +2 -1
  68. aiq/embedder/openai_embedder.py +2 -1
  69. aiq/eval/config.py +19 -1
  70. aiq/eval/dataset_handler/dataset_handler.py +87 -2
  71. aiq/eval/evaluate.py +208 -27
  72. aiq/eval/evaluator/base_evaluator.py +73 -0
  73. aiq/eval/evaluator/evaluator_model.py +1 -0
  74. aiq/eval/intermediate_step_adapter.py +11 -5
  75. aiq/eval/rag_evaluator/evaluate.py +55 -15
  76. aiq/eval/rag_evaluator/register.py +6 -1
  77. aiq/eval/remote_workflow.py +7 -2
  78. aiq/eval/runners/__init__.py +14 -0
  79. aiq/eval/runners/config.py +39 -0
  80. aiq/eval/runners/multi_eval_runner.py +54 -0
  81. aiq/eval/trajectory_evaluator/evaluate.py +22 -65
  82. aiq/eval/tunable_rag_evaluator/evaluate.py +150 -168
  83. aiq/eval/tunable_rag_evaluator/register.py +2 -0
  84. aiq/eval/usage_stats.py +41 -0
  85. aiq/eval/utils/output_uploader.py +10 -1
  86. aiq/eval/utils/weave_eval.py +184 -0
  87. aiq/experimental/__init__.py +0 -0
  88. aiq/experimental/decorators/__init__.py +0 -0
  89. aiq/experimental/decorators/experimental_warning_decorator.py +130 -0
  90. aiq/experimental/inference_time_scaling/__init__.py +0 -0
  91. aiq/experimental/inference_time_scaling/editing/__init__.py +0 -0
  92. aiq/experimental/inference_time_scaling/editing/iterative_plan_refinement_editor.py +147 -0
  93. aiq/experimental/inference_time_scaling/editing/llm_as_a_judge_editor.py +204 -0
  94. aiq/experimental/inference_time_scaling/editing/motivation_aware_summarization.py +107 -0
  95. aiq/experimental/inference_time_scaling/functions/__init__.py +0 -0
  96. aiq/experimental/inference_time_scaling/functions/execute_score_select_function.py +105 -0
  97. aiq/experimental/inference_time_scaling/functions/its_tool_orchestration_function.py +205 -0
  98. aiq/experimental/inference_time_scaling/functions/its_tool_wrapper_function.py +146 -0
  99. aiq/experimental/inference_time_scaling/functions/plan_select_execute_function.py +224 -0
  100. aiq/experimental/inference_time_scaling/models/__init__.py +0 -0
  101. aiq/experimental/inference_time_scaling/models/editor_config.py +132 -0
  102. aiq/experimental/inference_time_scaling/models/its_item.py +48 -0
  103. aiq/experimental/inference_time_scaling/models/scoring_config.py +112 -0
  104. aiq/experimental/inference_time_scaling/models/search_config.py +120 -0
  105. aiq/experimental/inference_time_scaling/models/selection_config.py +154 -0
  106. aiq/experimental/inference_time_scaling/models/stage_enums.py +43 -0
  107. aiq/experimental/inference_time_scaling/models/strategy_base.py +66 -0
  108. aiq/experimental/inference_time_scaling/models/tool_use_config.py +41 -0
  109. aiq/experimental/inference_time_scaling/register.py +36 -0
  110. aiq/experimental/inference_time_scaling/scoring/__init__.py +0 -0
  111. aiq/experimental/inference_time_scaling/scoring/llm_based_agent_scorer.py +168 -0
  112. aiq/experimental/inference_time_scaling/scoring/llm_based_plan_scorer.py +168 -0
  113. aiq/experimental/inference_time_scaling/scoring/motivation_aware_scorer.py +111 -0
  114. aiq/experimental/inference_time_scaling/search/__init__.py +0 -0
  115. aiq/experimental/inference_time_scaling/search/multi_llm_planner.py +128 -0
  116. aiq/experimental/inference_time_scaling/search/multi_query_retrieval_search.py +122 -0
  117. aiq/experimental/inference_time_scaling/search/single_shot_multi_plan_planner.py +128 -0
  118. aiq/experimental/inference_time_scaling/selection/__init__.py +0 -0
  119. aiq/experimental/inference_time_scaling/selection/best_of_n_selector.py +63 -0
  120. aiq/experimental/inference_time_scaling/selection/llm_based_agent_output_selector.py +131 -0
  121. aiq/experimental/inference_time_scaling/selection/llm_based_output_merging_selector.py +159 -0
  122. aiq/experimental/inference_time_scaling/selection/llm_based_plan_selector.py +128 -0
  123. aiq/experimental/inference_time_scaling/selection/threshold_selector.py +58 -0
  124. aiq/front_ends/console/authentication_flow_handler.py +233 -0
  125. aiq/front_ends/console/console_front_end_plugin.py +11 -2
  126. aiq/front_ends/fastapi/auth_flow_handlers/__init__.py +0 -0
  127. aiq/front_ends/fastapi/auth_flow_handlers/http_flow_handler.py +27 -0
  128. aiq/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py +107 -0
  129. aiq/front_ends/fastapi/fastapi_front_end_config.py +93 -9
  130. aiq/front_ends/fastapi/fastapi_front_end_controller.py +68 -0
  131. aiq/front_ends/fastapi/fastapi_front_end_plugin.py +14 -1
  132. aiq/front_ends/fastapi/fastapi_front_end_plugin_worker.py +537 -52
  133. aiq/front_ends/fastapi/html_snippets/__init__.py +14 -0
  134. aiq/front_ends/fastapi/html_snippets/auth_code_grant_success.py +35 -0
  135. aiq/front_ends/fastapi/job_store.py +47 -25
  136. aiq/front_ends/fastapi/main.py +2 -0
  137. aiq/front_ends/fastapi/message_handler.py +108 -89
  138. aiq/front_ends/fastapi/step_adaptor.py +2 -1
  139. aiq/llm/aws_bedrock_llm.py +57 -0
  140. aiq/llm/nim_llm.py +2 -1
  141. aiq/llm/openai_llm.py +3 -2
  142. aiq/llm/register.py +1 -0
  143. aiq/meta/pypi.md +12 -12
  144. aiq/object_store/__init__.py +20 -0
  145. aiq/object_store/in_memory_object_store.py +74 -0
  146. aiq/object_store/interfaces.py +84 -0
  147. aiq/object_store/models.py +36 -0
  148. aiq/object_store/register.py +20 -0
  149. aiq/observability/__init__.py +14 -0
  150. aiq/observability/exporter/__init__.py +14 -0
  151. aiq/observability/exporter/base_exporter.py +449 -0
  152. aiq/observability/exporter/exporter.py +78 -0
  153. aiq/observability/exporter/file_exporter.py +33 -0
  154. aiq/observability/exporter/processing_exporter.py +269 -0
  155. aiq/observability/exporter/raw_exporter.py +52 -0
  156. aiq/observability/exporter/span_exporter.py +264 -0
  157. aiq/observability/exporter_manager.py +335 -0
  158. aiq/observability/mixin/__init__.py +14 -0
  159. aiq/observability/mixin/batch_config_mixin.py +26 -0
  160. aiq/observability/mixin/collector_config_mixin.py +23 -0
  161. aiq/observability/mixin/file_mixin.py +288 -0
  162. aiq/observability/mixin/file_mode.py +23 -0
  163. aiq/observability/mixin/resource_conflict_mixin.py +134 -0
  164. aiq/observability/mixin/serialize_mixin.py +61 -0
  165. aiq/observability/mixin/type_introspection_mixin.py +183 -0
  166. aiq/observability/processor/__init__.py +14 -0
  167. aiq/observability/processor/batching_processor.py +316 -0
  168. aiq/observability/processor/intermediate_step_serializer.py +28 -0
  169. aiq/observability/processor/processor.py +68 -0
  170. aiq/observability/register.py +36 -39
  171. aiq/observability/utils/__init__.py +14 -0
  172. aiq/observability/utils/dict_utils.py +236 -0
  173. aiq/observability/utils/time_utils.py +31 -0
  174. aiq/profiler/calc/__init__.py +14 -0
  175. aiq/profiler/calc/calc_runner.py +623 -0
  176. aiq/profiler/calc/calculations.py +288 -0
  177. aiq/profiler/calc/data_models.py +176 -0
  178. aiq/profiler/calc/plot.py +345 -0
  179. aiq/profiler/callbacks/langchain_callback_handler.py +22 -10
  180. aiq/profiler/data_models.py +24 -0
  181. aiq/profiler/inference_metrics_model.py +3 -0
  182. aiq/profiler/inference_optimization/bottleneck_analysis/nested_stack_analysis.py +8 -0
  183. aiq/profiler/inference_optimization/data_models.py +2 -2
  184. aiq/profiler/inference_optimization/llm_metrics.py +2 -2
  185. aiq/profiler/profile_runner.py +61 -21
  186. aiq/runtime/loader.py +9 -3
  187. aiq/runtime/runner.py +23 -9
  188. aiq/runtime/session.py +25 -7
  189. aiq/runtime/user_metadata.py +2 -3
  190. aiq/tool/chat_completion.py +74 -0
  191. aiq/tool/code_execution/README.md +152 -0
  192. aiq/tool/code_execution/code_sandbox.py +151 -72
  193. aiq/tool/code_execution/local_sandbox/.gitignore +1 -0
  194. aiq/tool/code_execution/local_sandbox/local_sandbox_server.py +139 -24
  195. aiq/tool/code_execution/local_sandbox/sandbox.requirements.txt +3 -1
  196. aiq/tool/code_execution/local_sandbox/start_local_sandbox.sh +27 -2
  197. aiq/tool/code_execution/register.py +7 -3
  198. aiq/tool/code_execution/test_code_execution_sandbox.py +414 -0
  199. aiq/tool/mcp/exceptions.py +142 -0
  200. aiq/tool/mcp/mcp_client.py +41 -6
  201. aiq/tool/mcp/mcp_tool.py +3 -2
  202. aiq/tool/register.py +1 -0
  203. aiq/tool/server_tools.py +6 -3
  204. aiq/utils/exception_handlers/automatic_retries.py +289 -0
  205. aiq/utils/exception_handlers/mcp.py +211 -0
  206. aiq/utils/io/model_processing.py +28 -0
  207. aiq/utils/log_utils.py +37 -0
  208. aiq/utils/string_utils.py +38 -0
  209. aiq/utils/type_converter.py +18 -2
  210. aiq/utils/type_utils.py +87 -0
  211. {aiqtoolkit-1.2.0.dev0.dist-info → aiqtoolkit-1.2.0rc2.dist-info}/METADATA +53 -21
  212. aiqtoolkit-1.2.0rc2.dist-info/RECORD +436 -0
  213. {aiqtoolkit-1.2.0.dev0.dist-info → aiqtoolkit-1.2.0rc2.dist-info}/WHEEL +1 -1
  214. {aiqtoolkit-1.2.0.dev0.dist-info → aiqtoolkit-1.2.0rc2.dist-info}/entry_points.txt +3 -0
  215. aiq/front_ends/fastapi/websocket.py +0 -148
  216. aiq/observability/async_otel_listener.py +0 -429
  217. aiqtoolkit-1.2.0.dev0.dist-info/RECORD +0 -316
  218. {aiqtoolkit-1.2.0.dev0.dist-info → aiqtoolkit-1.2.0rc2.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
  219. {aiqtoolkit-1.2.0.dev0.dist-info → aiqtoolkit-1.2.0rc2.dist-info}/licenses/LICENSE.md +0 -0
  220. {aiqtoolkit-1.2.0.dev0.dist-info → aiqtoolkit-1.2.0rc2.dist-info}/top_level.txt +0 -0
aiq/agent/base.py CHANGED
@@ -13,25 +13,32 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import asyncio
+import json
 import logging
 from abc import ABC
 from abc import abstractmethod
 from enum import Enum
+from typing import Any
 
 from colorama import Fore
 from langchain_core.callbacks import AsyncCallbackHandler
 from langchain_core.language_models import BaseChatModel
+from langchain_core.messages import AIMessage
+from langchain_core.messages import BaseMessage
+from langchain_core.messages import ToolMessage
+from langchain_core.runnables import RunnableConfig
 from langchain_core.tools import BaseTool
 from langgraph.graph.graph import CompiledGraph
 
-log = logging.getLogger(__name__)
+logger = logging.getLogger(__name__)
 
 TOOL_NOT_FOUND_ERROR_MESSAGE = "There is no tool named {tool_name}. Tool must be one of {tools}."
 INPUT_SCHEMA_MESSAGE = ". Arguments must be provided as a valid JSON object following this format: {schema}"
-NO_INPUT_ERROR_MESSAGE = "No human input recieved to the agent, Please ask a valid question."
+NO_INPUT_ERROR_MESSAGE = "No human input received to the agent, Please ask a valid question."
 
 AGENT_LOG_PREFIX = "[AGENT]"
-AGENT_RESPONSE_LOG_MESSAGE = f"\n{'-' * 30}\n" + \
+AGENT_CALL_LOG_MESSAGE = f"\n{'-' * 30}\n" + \
     AGENT_LOG_PREFIX + "\n" + \
     Fore.YELLOW + \
     "Agent input: %s\n" + \
@@ -40,7 +47,7 @@ AGENT_RESPONSE_LOG_MESSAGE = f"\n{'-' * 30}\n" + \
     Fore.RESET + \
     f"\n{'-' * 30}"
 
-TOOL_RESPONSE_LOG_MESSAGE = f"\n{'-' * 30}\n" + \
+TOOL_CALL_LOG_MESSAGE = f"\n{'-' * 30}\n" + \
     AGENT_LOG_PREFIX + "\n" + \
     Fore.WHITE + \
     "Calling tools: %s\n" + \
@@ -62,15 +69,170 @@ class BaseAgent(ABC):
     def __init__(self,
                  llm: BaseChatModel,
                  tools: list[BaseTool],
-                 callbacks: list[AsyncCallbackHandler] = None,
-                 detailed_logs: bool = False):
-        log.debug("Initializing Agent Graph")
+                 callbacks: list[AsyncCallbackHandler] | None = None,
+                 detailed_logs: bool = False) -> None:
+        logger.debug("Initializing Agent Graph")
         self.llm = llm
         self.tools = tools
         self.callbacks = callbacks or []
         self.detailed_logs = detailed_logs
         self.graph = None
 
+    async def _stream_llm(self,
+                          runnable: Any,
+                          inputs: dict[str, Any],
+                          config: RunnableConfig | None = None) -> AIMessage:
+        """
+        Stream from LLM runnable. Retry logic is handled automatically by the underlying LLM client.
+
+        Parameters
+        ----------
+        runnable : Any
+            The LLM runnable (prompt | llm or similar)
+        inputs : Dict[str, Any]
+            The inputs to pass to the runnable
+        config : RunnableConfig | None
+            The config to pass to the runnable (should include callbacks)
+
+        Returns
+        -------
+        AIMessage
+            The LLM response
+        """
+        output_message = ""
+        async for event in runnable.astream(inputs, config=config):
+            output_message += event.content
+
+        return AIMessage(content=output_message)
+
+    async def _call_llm(self, messages: list[BaseMessage]) -> AIMessage:
+        """
+        Call the LLM directly. Retry logic is handled automatically by the underlying LLM client.
+
+        Parameters
+        ----------
+        messages : list[BaseMessage]
+            The messages to send to the LLM
+
+        Returns
+        -------
+        AIMessage
+            The LLM response
+        """
+        response = await self.llm.ainvoke(messages)
+        return AIMessage(content=str(response.content))
+
+    async def _call_tool(self,
+                         tool: BaseTool,
+                         tool_input: dict[str, Any] | str,
+                         config: RunnableConfig | None = None,
+                         max_retries: int = 3) -> ToolMessage:
+        """
+        Call a tool with retry logic and error handling.
+
+        Parameters
+        ----------
+        tool : BaseTool
+            The tool to call
+        tool_input : Union[Dict[str, Any], str]
+            The input to pass to the tool
+        config : RunnableConfig | None
+            The config to pass to the tool
+        max_retries : int
+            Maximum number of retry attempts (default: 3)
+
+        Returns
+        -------
+        ToolMessage
+            The tool response
+        """
+        last_exception = None
+
+        for attempt in range(max_retries + 1):
+            try:
+                response = await tool.ainvoke(tool_input, config=config)
+
+                # Handle empty responses
+                if response is None or (isinstance(response, str) and response == ""):
+                    return ToolMessage(name=tool.name,
+                                       tool_call_id=tool.name,
+                                       content=f"The tool {tool.name} provided an empty response.")
+
+                return ToolMessage(name=tool.name, tool_call_id=tool.name, content=response)
+
+            except Exception as e:
+                last_exception = e
+                logger.warning("%s Tool call attempt %d/%d failed for tool %s: %s",
+                               AGENT_LOG_PREFIX,
+                               attempt + 1,
+                               max_retries + 1,
+                               tool.name,
+                               str(e))
+
+                # If this was the last attempt, don't sleep
+                if attempt == max_retries:
+                    break
+
+                # Exponential backoff: 2^attempt seconds
+                sleep_time = 2**attempt
+                logger.debug("%s Retrying tool call for %s in %d seconds...", AGENT_LOG_PREFIX, tool.name, sleep_time)
+                await asyncio.sleep(sleep_time)
+
+        # All retries exhausted, return error message
+        error_content = "Tool call failed after all retry attempts. Last error: %s" % str(last_exception)
+        logger.error("%s %s", AGENT_LOG_PREFIX, error_content)
+        return ToolMessage(name=tool.name, tool_call_id=tool.name, content=error_content, status="error")
+
+    def _log_tool_response(self, tool_name: str, tool_input: Any, tool_response: str, max_chars: int = 1000) -> None:
+        """
+        Log tool response with consistent formatting and length limits.
+
+        Parameters
+        ----------
+        tool_name : str
+            The name of the tool that was called
+        tool_input : Any
+            The input that was passed to the tool
+        tool_response : str
+            The response from the tool
+        max_chars : int
+            Maximum number of characters to log (default: 1000)
+        """
+        if self.detailed_logs:
+            # Truncate tool response if too long
+            display_response = tool_response[:max_chars] + "...(rest of response truncated)" if len(
+                tool_response) > max_chars else tool_response

+            # Format the tool input for display
+            tool_input_str = str(tool_input)
+
+            tool_response_log_message = TOOL_CALL_LOG_MESSAGE % (tool_name, tool_input_str, display_response)
+            logger.info(tool_response_log_message)
+
+    def _parse_json(self, json_string: str) -> dict[str, Any]:
+        """
+        Safely parse JSON with graceful error handling.
+        If JSON parsing fails, returns an empty dict or error info.
+
+        Parameters
+        ----------
+        json_string : str
+            The JSON string to parse
+
+        Returns
+        -------
+        Dict[str, Any]
+            The parsed JSON or error information
+        """
+        try:
+            return json.loads(json_string)
+        except json.JSONDecodeError as e:
+            logger.warning("%s JSON parsing failed, returning the original string: %s", AGENT_LOG_PREFIX, str(e))
+            return {"error": f"JSON parsing failed: {str(e)}", "original_string": json_string}
+        except Exception as e:
+            logger.warning("%s Unexpected error during JSON parsing: %s", AGENT_LOG_PREFIX, str(e))
+            return {"error": f"Unexpected parsing error: {str(e)}", "original_string": json_string}
+
     @abstractmethod
-    async def _build_graph(self, state_schema) -> CompiledGraph:
+    async def _build_graph(self, state_schema: type) -> CompiledGraph:
         pass
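
The new BaseAgent._call_tool helper retries a failed tool invocation up to max_retries times with exponential backoff (2**attempt seconds between attempts) and, once retries are exhausted, returns a ToolMessage with status="error" instead of raising. The standalone sketch below mirrors that retry pattern outside the package; call_with_retries and flaky are illustrative names, not aiqtoolkit APIs:

    import asyncio
    import random


    async def call_with_retries(fn, max_retries: int = 3) -> str:
        """Retry an async callable with exponential backoff, mirroring BaseAgent._call_tool."""
        last_exception = None
        for attempt in range(max_retries + 1):
            try:
                return await fn()
            except Exception as e:  # _call_tool also catches broadly and records the last error
                last_exception = e
                if attempt == max_retries:
                    break
                await asyncio.sleep(2**attempt)  # 1s, 2s, 4s, ... between attempts
        # after retries are exhausted, _call_tool returns an error-status ToolMessage instead of raising
        return f"Tool call failed after all retry attempts. Last error: {last_exception}"


    async def flaky() -> str:
        """Illustrative async 'tool' that fails most of the time."""
        if random.random() < 0.7:
            raise RuntimeError("transient failure")
        return "ok"


    print(asyncio.run(call_with_retries(flaky, max_retries=2)))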
aiq/agent/dual_node.py CHANGED
@@ -34,7 +34,7 @@ class DualNodeAgent(BaseAgent):
     def __init__(self,
                  llm: BaseChatModel,
                  tools: list[BaseTool],
-                 callbacks: list[AsyncCallbackHandler] = None,
+                 callbacks: list[AsyncCallbackHandler] | None = None,
                  detailed_logs: bool = False):
         super().__init__(llm=llm, tools=tools, callbacks=callbacks, detailed_logs=detailed_logs)
 
aiq/agent/react_agent/agent.py CHANGED
@@ -26,22 +26,25 @@ from langchain_core.messages.ai import AIMessage
 from langchain_core.messages.base import BaseMessage
 from langchain_core.messages.human import HumanMessage
 from langchain_core.messages.tool import ToolMessage
-from langchain_core.prompts.chat import ChatPromptTemplate
+from langchain_core.prompts import ChatPromptTemplate
+from langchain_core.prompts import MessagesPlaceholder
 from langchain_core.runnables.config import RunnableConfig
 from langchain_core.tools import BaseTool
 from pydantic import BaseModel
 from pydantic import Field
 
+from aiq.agent.base import AGENT_CALL_LOG_MESSAGE
 from aiq.agent.base import AGENT_LOG_PREFIX
-from aiq.agent.base import AGENT_RESPONSE_LOG_MESSAGE
 from aiq.agent.base import INPUT_SCHEMA_MESSAGE
 from aiq.agent.base import NO_INPUT_ERROR_MESSAGE
 from aiq.agent.base import TOOL_NOT_FOUND_ERROR_MESSAGE
-from aiq.agent.base import TOOL_RESPONSE_LOG_MESSAGE
 from aiq.agent.base import AgentDecision
 from aiq.agent.dual_node import DualNodeAgent
 from aiq.agent.react_agent.output_parser import ReActOutputParser
 from aiq.agent.react_agent.output_parser import ReActOutputParserException
+from aiq.agent.react_agent.prompt import SYSTEM_PROMPT
+from aiq.agent.react_agent.prompt import USER_PROMPT
+from aiq.agent.react_agent.register import ReActAgentWorkflowConfig
 
 logger = logging.getLogger(__name__)
 
@@ -63,13 +66,17 @@ class ReActAgentGraph(DualNodeAgent):
                  prompt: ChatPromptTemplate,
                  tools: list[BaseTool],
                  use_tool_schema: bool = True,
-                 callbacks: list[AsyncCallbackHandler] = None,
+                 callbacks: list[AsyncCallbackHandler] | None = None,
                  detailed_logs: bool = False,
-                 retry_parsing_errors: bool = True,
-                 max_retries: int = 1):
+                 retry_agent_response_parsing_errors: bool = True,
+                 parse_agent_response_max_retries: int = 1,
+                 tool_call_max_retries: int = 1,
+                 pass_tool_call_errors_to_agent: bool = True):
         super().__init__(llm=llm, tools=tools, callbacks=callbacks, detailed_logs=detailed_logs)
-        self.retry_parsing_errors = retry_parsing_errors
-        self.max_tries = (max_retries + 1) if retry_parsing_errors else 1
+        self.parse_agent_response_max_retries = (parse_agent_response_max_retries
+                                                 if retry_agent_response_parsing_errors else 1)
+        self.tool_call_max_retries = tool_call_max_retries
+        self.pass_tool_call_errors_to_agent = pass_tool_call_errors_to_agent
         logger.debug(
             "%s Filling the prompt variables 'tools' and 'tool_names', using the tools provided in the config.",
             AGENT_LOG_PREFIX)
@@ -87,12 +94,12 @@ class ReActAgentGraph(DualNodeAgent):
                     f"{INPUT_SCHEMA_MESSAGE.format(schema=tools[-1].input_schema.model_fields)}")
         prompt = prompt.partial(tools=tool_names_and_descriptions, tool_names=tool_names)
         # construct the ReAct Agent
-        llm = llm.bind(stop=["Observation:"])
-        self.agent = prompt | llm
+        bound_llm = llm.bind(stop=["Observation:"])  # type: ignore
+        self.agent = prompt | bound_llm
         self.tools_dict = {tool.name: tool for tool in tools}
         logger.debug("%s Initialized ReAct Agent Graph", AGENT_LOG_PREFIX)
 
-    def _get_tool(self, tool_name):
+    def _get_tool(self, tool_name: str):
         try:
             return self.tools_dict.get(tool_name)
         except Exception as ex:
@@ -109,26 +116,30 @@
             # keeping a working state allows us to resolve parsing errors without polluting the agent scratchpad
             # the agent "forgets" about the parsing error after solving it - prevents hallucinations in next cycles
             working_state = []
-            for attempt in range(1, self.max_tries + 1):
+            # Starting from attempt 1 instead of 0 for logging
+            for attempt in range(1, self.parse_agent_response_max_retries + 1):
                 # the first time we are invoking the ReAct Agent, it won't have any intermediate steps / agent thoughts
                 if len(state.agent_scratchpad) == 0 and len(working_state) == 0:
                     # the user input comes from the "messages" state channel
                     if len(state.messages) == 0:
                         raise RuntimeError('No input received in state: "messages"')
                     # to check is any human input passed or not, if no input passed Agent will return the state
-                    if state.messages[0].content.strip() == "":
+                    content = str(state.messages[0].content)
+                    if content.strip() == "":
                         logger.error("%s No human input passed to the agent.", AGENT_LOG_PREFIX)
                         state.messages += [AIMessage(content=NO_INPUT_ERROR_MESSAGE)]
                         return state
-                    question = state.messages[0].content
+                    question = content
                     logger.debug("%s Querying agent, attempt: %s", AGENT_LOG_PREFIX, attempt)
-                    output_message = ""
-                    async for event in self.agent.astream({"question": question},
-                                                          config=RunnableConfig(callbacks=self.callbacks)):
-                        output_message += event.content
-                    output_message = AIMessage(content=output_message)
+
+                    output_message = await self._stream_llm(
+                        self.agent,
+                        {"question": question},
+                        RunnableConfig(callbacks=self.callbacks)  # type: ignore
+                    )
+
                     if self.detailed_logs:
-                        logger.info(AGENT_RESPONSE_LOG_MESSAGE, question, output_message.content)
+                        logger.info(AGENT_CALL_LOG_MESSAGE, question, output_message.content)
                 else:
                     # ReAct Agents require agentic cycles
                     # in an agentic cycle, preserve the agent's thoughts from the previous cycles,
@@ -137,20 +148,20 @@
                     for index, intermediate_step in enumerate(state.agent_scratchpad):
                         agent_thoughts = AIMessage(content=intermediate_step.log)
                         agent_scratchpad.append(agent_thoughts)
-                        tool_response = HumanMessage(content=state.tool_responses[index].content)
+                        tool_response_content = str(state.tool_responses[index].content)
+                        tool_response = HumanMessage(content=tool_response_content)
                         agent_scratchpad.append(tool_response)
                     agent_scratchpad += working_state
-                    question = state.messages[0].content
+                    question = str(state.messages[0].content)
                     logger.debug("%s Querying agent, attempt: %s", AGENT_LOG_PREFIX, attempt)
-                    output_message = ""
-                    async for event in self.agent.astream({
-                        "question": question, "agent_scratchpad": agent_scratchpad
+
+                    output_message = await self._stream_llm(self.agent, {
+                        "question": question, "agent_scratchpad": agent_scratchpad
                     },
-                                                          config=RunnableConfig(callbacks=self.callbacks)):
-                        output_message += event.content
-                    output_message = AIMessage(content=output_message)
+                                                            RunnableConfig(callbacks=self.callbacks))
+
                     if self.detailed_logs:
-                        logger.info(AGENT_RESPONSE_LOG_MESSAGE, question, output_message.content)
+                        logger.info(AGENT_CALL_LOG_MESSAGE, question, output_message.content)
                     logger.debug("%s The agent's scratchpad (with tool result) was:\n%s",
                                  AGENT_LOG_PREFIX,
                                  agent_scratchpad)
@@ -158,11 +169,7 @@
                     # check if the agent has the final answer yet
                     logger.debug("%s Successfully obtained agent response. Parsing agent's response", AGENT_LOG_PREFIX)
                     agent_output = await ReActOutputParser().aparse(output_message.content)
-                    logger.debug("%s Successfully parsed agent's response", AGENT_LOG_PREFIX)
-                    if attempt > 1:
-                        logger.debug("%s Successfully parsed agent response after %s attempts",
-                                     AGENT_LOG_PREFIX,
-                                     attempt)
+                    logger.debug("%s Successfully parsed agent response after %s attempts", AGENT_LOG_PREFIX, attempt)
                     if isinstance(agent_output, AgentFinish):
                         final_answer = agent_output.return_values.get('output', output_message.content)
                         logger.debug("%s The agent has finished, and has the final answer", AGENT_LOG_PREFIX)
@@ -174,31 +181,33 @@
                         agent_output.log = output_message.content
                         logger.debug("%s The agent wants to call a tool: %s", AGENT_LOG_PREFIX, agent_output.tool)
                         state.agent_scratchpad += [agent_output]
+
                     return state
                 except ReActOutputParserException as ex:
                     # the agent output did not meet the expected ReAct output format. This can happen for a few reasons:
                     # the agent mentioned a tool, but already has the final answer, this can happen with Llama models
                     # - the ReAct Agent already has the answer, and is reflecting on how it obtained the answer
                     # the agent might have also missed Action or Action Input in its output
-                    logger.warning("%s Error parsing agent output\nObservation:%s\nAgent Output:\n%s",
-                                   AGENT_LOG_PREFIX,
-                                   ex.observation,
-                                   output_message.content)
-                    if attempt == self.max_tries:
-                        logger.exception(
+                    logger.debug("%s Error parsing agent output\nObservation:%s\nAgent Output:\n%s",
+                                 AGENT_LOG_PREFIX,
+                                 ex.observation,
+                                 output_message.content)
+                    if attempt == self.parse_agent_response_max_retries:
+                        logger.error(
                             "%s Failed to parse agent output after %d attempts, consider enabling or "
-                            "increasing max_retries",
+                            "increasing parse_agent_response_max_retries",
                             AGENT_LOG_PREFIX,
                             attempt,
                             exc_info=True)
                         # the final answer goes in the "messages" state channel
-                        output_message.content = ex.observation + '\n' + output_message.content
+                        combined_content = str(ex.observation) + '\n' + str(output_message.content)
+                        output_message.content = combined_content
                         state.messages += [output_message]
                         return state
                     # retry parsing errors, if configured
                     logger.info("%s Retrying ReAct Agent, including output parsing Observation", AGENT_LOG_PREFIX)
                     working_state.append(output_message)
-                    working_state.append(HumanMessage(content=ex.observation))
+                    working_state.append(HumanMessage(content=str(ex.observation)))
         except Exception as ex:
             logger.exception("%s Failed to call agent_node: %s", AGENT_LOG_PREFIX, ex, exc_info=True)
             raise ex
@@ -208,7 +217,8 @@
         logger.debug("%s Starting the ReAct Conditional Edge", AGENT_LOG_PREFIX)
         if len(state.messages) > 1:
             # the ReAct Agent has finished executing, the last agent output was AgentFinish
-            logger.debug("%s Final answer:\n%s", AGENT_LOG_PREFIX, state.messages[-1].content)
+            last_message_content = str(state.messages[-1].content)
+            logger.debug("%s Final answer:\n%s", AGENT_LOG_PREFIX, last_message_content)
             return AgentDecision.END
         # else the agent wants to call a tool
         agent_output = state.agent_scratchpad[-1]
@@ -223,76 +233,71 @@
         return AgentDecision.END
 
     async def tool_node(self, state: ReActGraphState):
-        try:
-            logger.debug("%s Starting the Tool Call Node", AGENT_LOG_PREFIX)
-            if len(state.agent_scratchpad) == 0:
-                raise RuntimeError('No tool input received in state: "agent_scratchpad"')
-            agent_thoughts = state.agent_scratchpad[-1]
-            # the agent can run any installed tool, simply install the tool and add it to the config file
-            requested_tool = self._get_tool(agent_thoughts.tool)
-            if not requested_tool:
-                configured_tool_names = list(self.tools_dict.keys())
-                logger.warning(
-                    "%s ReAct Agent wants to call tool %s. In the ReAct Agent's configuration within the config file,"
-                    "there is no tool with that name: %s",
-                    AGENT_LOG_PREFIX,
-                    agent_thoughts.tool,
-                    configured_tool_names)
-                tool_response = ToolMessage(name='agent_error',
-                                            tool_call_id='agent_error',
-                                            content=TOOL_NOT_FOUND_ERROR_MESSAGE.format(tool_name=agent_thoughts.tool,
-                                                                                        tools=configured_tool_names))
-                state.tool_responses += [tool_response]
-                return state
-
-            logger.debug("%s Calling tool %s with input: %s",
-                         AGENT_LOG_PREFIX,
-                         requested_tool.name,
-                         agent_thoughts.tool_input)
-
-            # Run the tool. Try to use structured input, if possible.
-            try:
-                tool_input_str = agent_thoughts.tool_input.strip().replace("'", '"')
-                tool_input_dict = json.loads(tool_input_str) if tool_input_str != 'None' else tool_input_str
-                logger.debug("%s Successfully parsed structured tool input from Action Input", AGENT_LOG_PREFIX)
-                tool_response = await requested_tool.ainvoke(tool_input_dict,
-                                                             config=RunnableConfig(callbacks=self.callbacks))
-                if self.detailed_logs:
-                    # The tool response can be very large, so we log only the first 1000 characters
-                    tool_response_str = str(tool_response)
-                    tool_response_str = tool_response_str[:1000] + "..." if len(
-                        tool_response_str) > 1000 else tool_response_str
-                    tool_response_log_message = TOOL_RESPONSE_LOG_MESSAGE % (
-                        requested_tool.name, tool_input_str, tool_response_str)
-                    logger.info(tool_response_log_message)
-            except JSONDecodeError as ex:
-                logger.warning(
-                    "%s Unable to parse structured tool input from Action Input. Using Action Input as is."
-                    "\nParsing error: %s",
-                    AGENT_LOG_PREFIX,
-                    ex,
-                    exc_info=True)
-                tool_input_str = agent_thoughts.tool_input
-                tool_response = await requested_tool.ainvoke(tool_input_str,
-                                                             config=RunnableConfig(callbacks=self.callbacks))
-
-            # some tools, such as Wikipedia, will return an empty response when no search results are found
-            if tool_response is None or tool_response == "":
-                tool_response = "The tool provided an empty response.\n"
-            # put the tool response in the graph state
-            tool_response = ToolMessage(name=agent_thoughts.tool,
-                                        tool_call_id=agent_thoughts.tool,
-                                        content=tool_response)
-            logger.debug("%s Called tool %s with input: %s\nThe tool returned: %s",
-                         AGENT_LOG_PREFIX,
-                         requested_tool.name,
-                         agent_thoughts.tool_input,
-                         tool_response.content)
+
+        logger.debug("%s Starting the Tool Call Node", AGENT_LOG_PREFIX)
+        if len(state.agent_scratchpad) == 0:
+            raise RuntimeError('No tool input received in state: "agent_scratchpad"')
+        agent_thoughts = state.agent_scratchpad[-1]
+        # the agent can run any installed tool, simply install the tool and add it to the config file
+        requested_tool = self._get_tool(agent_thoughts.tool)
+        if not requested_tool:
+            configured_tool_names = list(self.tools_dict.keys())
+            logger.warning(
+                "%s ReAct Agent wants to call tool %s. In the ReAct Agent's configuration within the config file,"
+                "there is no tool with that name: %s",
+                AGENT_LOG_PREFIX,
+                agent_thoughts.tool,
+                configured_tool_names)
+            tool_response = ToolMessage(name='agent_error',
+                                        tool_call_id='agent_error',
+                                        content=TOOL_NOT_FOUND_ERROR_MESSAGE.format(tool_name=agent_thoughts.tool,
+                                                                                    tools=configured_tool_names))
             state.tool_responses += [tool_response]
             return state
-        except Exception as ex:
-            logger.exception("%s Failed to call tool_node: %s", AGENT_LOG_PREFIX, ex, exc_info=ex)
-            raise ex
+
+        logger.debug("%s Calling tool %s with input: %s",
+                     AGENT_LOG_PREFIX,
+                     requested_tool.name,
+                     agent_thoughts.tool_input)
+
+        # Run the tool. Try to use structured input, if possible.
+        try:
+            tool_input_str = str(agent_thoughts.tool_input).strip().replace("'", '"')
+            tool_input_dict = json.loads(tool_input_str) if tool_input_str != 'None' else tool_input_str
+            logger.debug("%s Successfully parsed structured tool input from Action Input", AGENT_LOG_PREFIX)
+
+            tool_response = await self._call_tool(requested_tool,
+                                                  tool_input_dict,
+                                                  RunnableConfig(callbacks=self.callbacks),
+                                                  max_retries=self.tool_call_max_retries)
+
+            if self.detailed_logs:
+                self._log_tool_response(requested_tool.name, tool_input_dict, str(tool_response.content))
+
+        except JSONDecodeError as ex:
+            logger.debug(
+                "%s Unable to parse structured tool input from Action Input. Using Action Input as is."
+                "\nParsing error: %s",
+                AGENT_LOG_PREFIX,
+                ex,
+                exc_info=True)
+            tool_input_str = str(agent_thoughts.tool_input)
+
+            tool_response = await self._call_tool(requested_tool,
+                                                  tool_input_str,
+                                                  RunnableConfig(callbacks=self.callbacks),
+                                                  max_retries=self.tool_call_max_retries)
+
+            if self.detailed_logs:
+                self._log_tool_response(requested_tool.name, tool_input_str, str(tool_response.content))
+
+        if not self.pass_tool_call_errors_to_agent:
+            if tool_response.status == "error":
+                logger.error("%s Tool %s failed: %s", AGENT_LOG_PREFIX, requested_tool.name, tool_response.content)
+                raise RuntimeError("Tool call failed: " + str(tool_response.content))
+
+        state.tool_responses += [tool_response]
+        return state
 
     async def build_graph(self):
         try:
@@ -320,3 +325,32 @@
             logger.exception("%s %s", AGENT_LOG_PREFIX, error_text)
             raise ValueError(error_text)
         return True
+
+
+def create_react_agent_prompt(config: ReActAgentWorkflowConfig) -> ChatPromptTemplate:
+    """
+    Create a ReAct Agent prompt from the config.
+
+    Args:
+        config (ReActAgentWorkflowConfig): The config to use for the prompt.
+
+    Returns:
+        ChatPromptTemplate: The ReAct Agent prompt.
+    """
+    # the ReAct Agent prompt can be customized via config option system_prompt and additional_instructions.
+
+    if config.system_prompt:
+        prompt_str = config.system_prompt
+    else:
+        prompt_str = SYSTEM_PROMPT
+
+    if config.additional_instructions:
+        prompt_str += f" {config.additional_instructions}"
+
+    valid_prompt = ReActAgentGraph.validate_system_prompt(prompt_str)
+    if not valid_prompt:
+        logger.exception("%s Invalid system_prompt", AGENT_LOG_PREFIX)
+        raise ValueError("Invalid system_prompt")
+    prompt = ChatPromptTemplate([("system", prompt_str), ("user", USER_PROMPT),
+                                 MessagesPlaceholder(variable_name='agent_scratchpad', optional=True)])
+    return prompt
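
ReActAgentGraph now takes separate knobs for parsing retries (retry_agent_response_parsing_errors and parse_agent_response_max_retries, replacing retry_parsing_errors and max_retries) and for tool-call retries (tool_call_max_retries), plus pass_tool_call_errors_to_agent to decide whether tool failures are fed back to the agent as observations or raised. A minimal construction sketch, assuming llm (a BaseChatModel) and tools (a list of BaseTool) already exist and that any ReActAgentWorkflowConfig fields not shown here have usable defaults:

    from aiq.agent.react_agent.agent import ReActAgentGraph
    from aiq.agent.react_agent.agent import create_react_agent_prompt
    from aiq.agent.react_agent.register import ReActAgentWorkflowConfig


    async def build_react_agent(llm, tools):
        # only the prompt-related fields shown in this diff are set here
        config = ReActAgentWorkflowConfig(additional_instructions="Keep answers concise.")
        prompt = create_react_agent_prompt(config)

        agent = ReActAgentGraph(
            llm=llm,
            prompt=prompt,
            tools=tools,
            detailed_logs=True,
            # new in this release (replace retry_parsing_errors / max_retries):
            retry_agent_response_parsing_errors=True,
            parse_agent_response_max_retries=2,
            tool_call_max_retries=2,
            pass_tool_call_errors_to_agent=True,
        )
        await agent.build_graph()
        return agent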
aiq/agent/react_agent/prompt.py CHANGED
@@ -14,8 +14,6 @@
 # limitations under the License.
 
 # flake8: noqa
-from langchain_core.prompts.chat import ChatPromptTemplate
-from langchain_core.prompts.chat import MessagesPlaceholder
 
 SYSTEM_PROMPT = """
 Answer the following questions as best you can. You may ask the human to use the following tools:
@@ -37,10 +35,7 @@ Use the following format once you have the final answer:
 Thought: I now know the final answer
 Final Answer: the final answer to the original input question
 """
+
 USER_PROMPT = """
 Question: {question}
 """
-
-# This is the prompt - (ReAct Agent prompt)
-react_agent_prompt = ChatPromptTemplate([("system", SYSTEM_PROMPT), ("user", USER_PROMPT),
-                                         MessagesPlaceholder(variable_name='agent_scratchpad', optional=True)])
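
With the module-level react_agent_prompt gone, the system + user + optional agent_scratchpad structure is assembled per workflow by create_react_agent_prompt (see agent.py above). The snippet below exercises the same structure using langchain-core directly; the prompt strings and scratchpad contents are placeholders, not the package's actual prompts:

    from langchain_core.messages import AIMessage, HumanMessage
    from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder

    prompt = ChatPromptTemplate([
        ("system", "You are a ReAct-style agent. Tools:\n{tools}\nTool names: {tool_names}"),
        ("user", "Question: {question}"),
        MessagesPlaceholder(variable_name="agent_scratchpad", optional=True),
    ])

    # agent_node alternates agent thoughts (AIMessage) and tool results (HumanMessage)
    # when it rebuilds the scratchpad for each agentic cycle
    messages = prompt.format_messages(
        tools="search: look things up",
        tool_names="search",
        question="What is the capital of France?",
        agent_scratchpad=[
            AIMessage(content="Thought: I should search.\nAction: search\nAction Input: {'query': 'capital of France'}"),
            HumanMessage(content="Paris"),
        ],
    )
    for message in messages:
        print(type(message).__name__, "->", message.content)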