alita-sdk 0.3.457__py3-none-any.whl → 0.3.486__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of alita-sdk might be problematic.
- alita_sdk/cli/__init__.py +10 -0
- alita_sdk/cli/__main__.py +17 -0
- alita_sdk/cli/agent/__init__.py +5 -0
- alita_sdk/cli/agent/default.py +258 -0
- alita_sdk/cli/agent_executor.py +155 -0
- alita_sdk/cli/agent_loader.py +194 -0
- alita_sdk/cli/agent_ui.py +228 -0
- alita_sdk/cli/agents.py +3592 -0
- alita_sdk/cli/callbacks.py +647 -0
- alita_sdk/cli/cli.py +168 -0
- alita_sdk/cli/config.py +306 -0
- alita_sdk/cli/context/__init__.py +30 -0
- alita_sdk/cli/context/cleanup.py +198 -0
- alita_sdk/cli/context/manager.py +731 -0
- alita_sdk/cli/context/message.py +285 -0
- alita_sdk/cli/context/strategies.py +289 -0
- alita_sdk/cli/context/token_estimation.py +127 -0
- alita_sdk/cli/formatting.py +182 -0
- alita_sdk/cli/input_handler.py +419 -0
- alita_sdk/cli/inventory.py +1256 -0
- alita_sdk/cli/mcp_loader.py +315 -0
- alita_sdk/cli/toolkit.py +327 -0
- alita_sdk/cli/toolkit_loader.py +85 -0
- alita_sdk/cli/tools/__init__.py +43 -0
- alita_sdk/cli/tools/approval.py +224 -0
- alita_sdk/cli/tools/filesystem.py +1665 -0
- alita_sdk/cli/tools/planning.py +389 -0
- alita_sdk/cli/tools/terminal.py +414 -0
- alita_sdk/community/__init__.py +64 -8
- alita_sdk/community/inventory/__init__.py +224 -0
- alita_sdk/community/inventory/config.py +257 -0
- alita_sdk/community/inventory/enrichment.py +2137 -0
- alita_sdk/community/inventory/extractors.py +1469 -0
- alita_sdk/community/inventory/ingestion.py +3172 -0
- alita_sdk/community/inventory/knowledge_graph.py +1457 -0
- alita_sdk/community/inventory/parsers/__init__.py +218 -0
- alita_sdk/community/inventory/parsers/base.py +295 -0
- alita_sdk/community/inventory/parsers/csharp_parser.py +907 -0
- alita_sdk/community/inventory/parsers/go_parser.py +851 -0
- alita_sdk/community/inventory/parsers/html_parser.py +389 -0
- alita_sdk/community/inventory/parsers/java_parser.py +593 -0
- alita_sdk/community/inventory/parsers/javascript_parser.py +629 -0
- alita_sdk/community/inventory/parsers/kotlin_parser.py +768 -0
- alita_sdk/community/inventory/parsers/markdown_parser.py +362 -0
- alita_sdk/community/inventory/parsers/python_parser.py +604 -0
- alita_sdk/community/inventory/parsers/rust_parser.py +858 -0
- alita_sdk/community/inventory/parsers/swift_parser.py +832 -0
- alita_sdk/community/inventory/parsers/text_parser.py +322 -0
- alita_sdk/community/inventory/parsers/yaml_parser.py +370 -0
- alita_sdk/community/inventory/patterns/__init__.py +61 -0
- alita_sdk/community/inventory/patterns/ast_adapter.py +380 -0
- alita_sdk/community/inventory/patterns/loader.py +348 -0
- alita_sdk/community/inventory/patterns/registry.py +198 -0
- alita_sdk/community/inventory/presets.py +535 -0
- alita_sdk/community/inventory/retrieval.py +1403 -0
- alita_sdk/community/inventory/toolkit.py +169 -0
- alita_sdk/community/inventory/visualize.py +1370 -0
- alita_sdk/configurations/bitbucket.py +0 -3
- alita_sdk/runtime/clients/client.py +99 -26
- alita_sdk/runtime/langchain/assistant.py +4 -2
- alita_sdk/runtime/langchain/constants.py +2 -1
- alita_sdk/runtime/langchain/langraph_agent.py +134 -31
- alita_sdk/runtime/langchain/utils.py +1 -1
- alita_sdk/runtime/llms/preloaded.py +2 -6
- alita_sdk/runtime/toolkits/__init__.py +2 -0
- alita_sdk/runtime/toolkits/application.py +1 -1
- alita_sdk/runtime/toolkits/mcp.py +46 -36
- alita_sdk/runtime/toolkits/planning.py +171 -0
- alita_sdk/runtime/toolkits/tools.py +39 -6
- alita_sdk/runtime/tools/function.py +17 -5
- alita_sdk/runtime/tools/llm.py +249 -14
- alita_sdk/runtime/tools/planning/__init__.py +36 -0
- alita_sdk/runtime/tools/planning/models.py +246 -0
- alita_sdk/runtime/tools/planning/wrapper.py +607 -0
- alita_sdk/runtime/tools/vectorstore_base.py +41 -6
- alita_sdk/runtime/utils/mcp_oauth.py +80 -0
- alita_sdk/runtime/utils/streamlit.py +6 -10
- alita_sdk/runtime/utils/toolkit_utils.py +19 -4
- alita_sdk/tools/__init__.py +54 -27
- alita_sdk/tools/ado/repos/repos_wrapper.py +1 -2
- alita_sdk/tools/base_indexer_toolkit.py +150 -19
- alita_sdk/tools/bitbucket/__init__.py +2 -2
- alita_sdk/tools/chunkers/__init__.py +3 -1
- alita_sdk/tools/chunkers/sematic/markdown_chunker.py +95 -6
- alita_sdk/tools/chunkers/universal_chunker.py +269 -0
- alita_sdk/tools/code_indexer_toolkit.py +55 -22
- alita_sdk/tools/elitea_base.py +86 -21
- alita_sdk/tools/jira/__init__.py +1 -1
- alita_sdk/tools/jira/api_wrapper.py +91 -40
- alita_sdk/tools/non_code_indexer_toolkit.py +1 -0
- alita_sdk/tools/qtest/__init__.py +1 -1
- alita_sdk/tools/qtest/api_wrapper.py +871 -32
- alita_sdk/tools/sharepoint/api_wrapper.py +22 -2
- alita_sdk/tools/sharepoint/authorization_helper.py +17 -1
- alita_sdk/tools/vector_adapters/VectorStoreAdapter.py +8 -2
- alita_sdk/tools/zephyr_essential/api_wrapper.py +12 -13
- {alita_sdk-0.3.457.dist-info → alita_sdk-0.3.486.dist-info}/METADATA +146 -2
- {alita_sdk-0.3.457.dist-info → alita_sdk-0.3.486.dist-info}/RECORD +102 -40
- alita_sdk-0.3.486.dist-info/entry_points.txt +2 -0
- {alita_sdk-0.3.457.dist-info → alita_sdk-0.3.486.dist-info}/WHEEL +0 -0
- {alita_sdk-0.3.457.dist-info → alita_sdk-0.3.486.dist-info}/licenses/LICENSE +0 -0
- {alita_sdk-0.3.457.dist-info → alita_sdk-0.3.486.dist-info}/top_level.txt +0 -0
alita_sdk/runtime/tools/llm.py
CHANGED
@@ -1,3 +1,4 @@
+import asyncio
 import logging
 from traceback import format_exc
 from typing import Any, Optional, List, Union
@@ -12,6 +13,7 @@ from ..langchain.utils import create_pydantic_model, propagate_the_input_mapping
 
 logger = logging.getLogger(__name__)
 
+
 class LLMNode(BaseTool):
     """Enhanced LLM node with chat history and tool binding support"""
 
@@ -60,6 +62,47 @@
 
         return filtered_tools
 
+    def _get_tool_truncation_suggestions(self, tool_name: Optional[str]) -> str:
+        """
+        Get context-specific suggestions for how to reduce output from a tool.
+
+        First checks if the tool itself provides truncation suggestions via
+        `truncation_suggestions` attribute or `get_truncation_suggestions()` method.
+        Falls back to generic suggestions if the tool doesn't provide any.
+
+        Args:
+            tool_name: Name of the tool that caused the context overflow
+
+        Returns:
+            Formatted string with numbered suggestions for the specific tool
+        """
+        suggestions = None
+
+        # Try to get suggestions from the tool itself
+        if tool_name:
+            filtered_tools = self.get_filtered_tools()
+            for tool in filtered_tools:
+                if tool.name == tool_name:
+                    # Check for truncation_suggestions attribute
+                    if hasattr(tool, 'truncation_suggestions') and tool.truncation_suggestions:
+                        suggestions = tool.truncation_suggestions
+                        break
+                    # Check for get_truncation_suggestions method
+                    elif hasattr(tool, 'get_truncation_suggestions') and callable(tool.get_truncation_suggestions):
+                        suggestions = tool.get_truncation_suggestions()
+                        break
+
+        # Fall back to generic suggestions if tool doesn't provide any
+        if not suggestions:
+            suggestions = [
+                "Check if the tool has parameters to limit output size (e.g., max_items, max_results, max_depth)",
+                "Target a more specific path or query instead of broad searches",
+                "Break the operation into smaller, focused requests",
+            ]
+
+        # Format as numbered list
+        return "\n".join(f"{i+1}. {s}" for i, s in enumerate(suggestions))
+
     def invoke(
         self,
         state: Union[str, dict],
@@ -132,7 +175,9 @@
             struct_model = create_pydantic_model(f"LLMOutput", struct_params)
             completion = llm_client.invoke(messages, config=config)
             if hasattr(completion, 'tool_calls') and completion.tool_calls:
-                new_messages, _ = self.__perform_tool_calling(completion, messages, llm_client, config)
+                new_messages, _ = self._run_async_in_sync_context(
+                    self.__perform_tool_calling(completion, messages, llm_client, config)
+                )
                 llm = self.__get_struct_output_model(llm_client, struct_model)
                 completion = llm.invoke(new_messages, config=config)
                 result = completion.model_dump()
@@ -155,7 +200,9 @@
         # Handle both tool-calling and regular responses
         if hasattr(completion, 'tool_calls') and completion.tool_calls:
             # Handle iterative tool-calling and execution
-            new_messages, current_completion = self.__perform_tool_calling(completion, messages, llm_client, config)
+            new_messages, current_completion = self._run_async_in_sync_context(
+                self.__perform_tool_calling(completion, messages, llm_client, config)
+            )
 
             output_msgs = {"messages": new_messages}
             if self.output_variables:
@@ -190,9 +237,53 @@
     def _run(self, *args, **kwargs):
         # Legacy support for old interface
         return self.invoke(kwargs, **kwargs)
+
+    def _run_async_in_sync_context(self, coro):
+        """Run async coroutine from sync context.
+
+        For MCP tools with persistent sessions, we reuse the same event loop
+        that was used to create the MCP client and sessions (set by CLI).
+        """
+        try:
+            loop = asyncio.get_running_loop()
+            # Already in async context - run in thread with new loop
+            import threading
+
+            result_container = []
+
+            def run_in_thread():
+                new_loop = asyncio.new_event_loop()
+                asyncio.set_event_loop(new_loop)
+                try:
+                    result_container.append(new_loop.run_until_complete(coro))
+                finally:
+                    new_loop.close()
+
+            thread = threading.Thread(target=run_in_thread)
+            thread.start()
+            thread.join()
+            return result_container[0] if result_container else None
+
+        except RuntimeError:
+            # No event loop running - use/create persistent loop
+            # This loop is shared with MCP session creation for stateful tools
+            if not hasattr(self.__class__, '_persistent_loop') or \
+               self.__class__._persistent_loop is None or \
+               self.__class__._persistent_loop.is_closed():
+                self.__class__._persistent_loop = asyncio.new_event_loop()
+                logger.debug("Created persistent event loop for async tools")
+
+            loop = self.__class__._persistent_loop
+            asyncio.set_event_loop(loop)
+            return loop.run_until_complete(coro)
+
+    async def _arun(self, *args, **kwargs):
+        # Legacy async support
+        return self.invoke(kwargs, **kwargs)
 
-    def __perform_tool_calling(self, completion, messages, llm_client, config):
+    async def __perform_tool_calling(self, completion, messages, llm_client, config):
         # Handle iterative tool-calling and execution
+        logger.info(f"__perform_tool_calling called with {len(completion.tool_calls) if hasattr(completion, 'tool_calls') else 0} tool calls")
        new_messages = messages + [completion]
        iteration = 0
 
@@ -230,9 +321,16 @@
                 if tool_to_execute:
                     try:
                         logger.info(f"Executing tool '{tool_name}' with args: {tool_args}")
-
-                        #
-                        tool_result = tool_to_execute.invoke(tool_args, config=config)
+
+                        # Try async invoke first (for MCP tools), fallback to sync
+                        tool_result = None
+                        try:
+                            # Try async invocation first
+                            tool_result = await tool_to_execute.ainvoke(tool_args, config=config)
+                        except NotImplementedError:
+                            # Tool doesn't support async, use sync invoke
+                            logger.debug(f"Tool '{tool_name}' doesn't support async, using sync invoke")
+                            tool_result = tool_to_execute.invoke(tool_args, config=config)
 
                         # Create tool message with result - preserve structured content
                         from langchain_core.messages import ToolMessage
@@ -256,7 +354,10 @@
                         new_messages.append(tool_message)
 
                     except Exception as e:
-
+                        import traceback
+                        error_details = traceback.format_exc()
+                        # Use debug level to avoid duplicate output when CLI callbacks are active
+                        logger.debug(f"Error executing tool '{tool_name}': {e}\n{error_details}")
                         # Create error tool message
                         from langchain_core.messages import ToolMessage
                         tool_message = ToolMessage(
@@ -287,16 +388,150 @@
                         break
 
             except Exception as e:
-
-
-
-
-
+                error_str = str(e).lower()
+
+                # Check for context window / token limit errors
+                is_context_error = any(indicator in error_str for indicator in [
+                    'context window', 'context_window', 'token limit', 'too long',
+                    'maximum context length', 'input is too long', 'exceeds the limit',
+                    'contextwindowexceedederror', 'max_tokens', 'content too large'
+                ])
+
+                # Check for Bedrock/Claude output limit errors
+                # These often manifest as "model identifier is invalid" when output exceeds limits
+                is_output_limit_error = any(indicator in error_str for indicator in [
+                    'model identifier is invalid',
+                    'bedrockexception',
+                    'output token',
+                    'response too large',
+                    'max_tokens_to_sample',
+                    'output_token_limit'
+                ])
+
+                if is_context_error or is_output_limit_error:
+                    error_type = "output limit" if is_output_limit_error else "context window"
+                    logger.warning(f"{error_type.title()} exceeded during tool execution iteration {iteration}")
+
+                    # Find the last tool message and its associated tool name
+                    last_tool_msg_idx = None
+                    last_tool_name = None
+                    last_tool_call_id = None
+
+                    # First, find the last tool message
+                    for i in range(len(new_messages) - 1, -1, -1):
+                        msg = new_messages[i]
+                        if hasattr(msg, 'tool_call_id') or (hasattr(msg, 'type') and getattr(msg, 'type', None) == 'tool'):
+                            last_tool_msg_idx = i
+                            last_tool_call_id = getattr(msg, 'tool_call_id', None)
+                            break
+
+                    # Find the tool name from the AIMessage that requested this tool call
+                    if last_tool_call_id:
+                        for i in range(last_tool_msg_idx - 1, -1, -1):
+                            msg = new_messages[i]
+                            if hasattr(msg, 'tool_calls') and msg.tool_calls:
+                                for tc in msg.tool_calls:
+                                    tc_id = tc.get('id', '') if isinstance(tc, dict) else getattr(tc, 'id', '')
+                                    if tc_id == last_tool_call_id:
+                                        last_tool_name = tc.get('name', '') if isinstance(tc, dict) else getattr(tc, 'name', '')
+                                        break
+                                if last_tool_name:
+                                    break
+
+                    # Build dynamic suggestion based on the tool that caused the overflow
+                    tool_suggestions = self._get_tool_truncation_suggestions(last_tool_name)
+
+                    # Truncate the problematic tool result if found
+                    if last_tool_msg_idx is not None:
+                        from langchain_core.messages import ToolMessage
+                        original_msg = new_messages[last_tool_msg_idx]
+                        tool_call_id = getattr(original_msg, 'tool_call_id', 'unknown')
+
+                        # Build error-specific guidance
+                        if is_output_limit_error:
+                            truncated_content = (
+                                f"⚠️ MODEL OUTPUT LIMIT EXCEEDED\n\n"
+                                f"The tool '{last_tool_name or 'unknown'}' returned data, but the model's response was too large.\n\n"
+                                f"IMPORTANT: You must provide a SMALLER, more focused response.\n"
+                                f"- Break down your response into smaller chunks\n"
+                                f"- Summarize instead of listing everything\n"
+                                f"- Focus on the most relevant information first\n"
+                                f"- If listing items, show only top 5-10 most important\n\n"
+                                f"Tool-specific tips:\n{tool_suggestions}\n\n"
+                                f"Please retry with a more concise response."
+                            )
+                        else:
+                            truncated_content = (
+                                f"⚠️ TOOL OUTPUT TRUNCATED - Context window exceeded\n\n"
+                                f"The tool '{last_tool_name or 'unknown'}' returned too much data for the model's context window.\n\n"
+                                f"To fix this:\n{tool_suggestions}\n\n"
+                                f"Please retry with more restrictive parameters."
+                            )
+
+                        truncated_msg = ToolMessage(
+                            content=truncated_content,
+                            tool_call_id=tool_call_id
+                        )
+                        new_messages[last_tool_msg_idx] = truncated_msg
+
+                        logger.info(f"Truncated large tool result from '{last_tool_name}' and continuing")
+                        # Continue to next iteration - the model will see the truncation message
+                        continue
+                    else:
+                        # Couldn't find tool message, add error and break
+                        if is_output_limit_error:
+                            error_msg = (
+                                "Model output limit exceeded. Please provide a more concise response. "
+                                "Break down your answer into smaller parts and summarize where possible."
+                            )
+                        else:
+                            error_msg = (
+                                "Context window exceeded. The conversation or tool results are too large. "
+                                "Try using tools with smaller output limits (e.g., max_items, max_depth parameters)."
+                            )
+                        new_messages.append(AIMessage(content=error_msg))
+                        break
+                else:
+                    logger.error(f"Error in LLM call during iteration {iteration}: {e}")
+                    # Add error message and break the loop
+                    error_msg = f"Error processing tool results in iteration {iteration}: {str(e)}"
+                    new_messages.append(AIMessage(content=error_msg))
+                    break
 
-        #
+        # Handle max iterations
         if iteration >= self.steps_limit:
             logger.warning(f"Reached maximum iterations ({self.steps_limit}) for tool execution")
-
+
+            # CRITICAL: Check if the last message is an AIMessage with pending tool_calls
+            # that were not processed. If so, we need to add placeholder ToolMessages to prevent
+            # the "assistant message with 'tool_calls' must be followed by tool messages" error
+            # when the conversation continues.
+            if new_messages:
+                last_msg = new_messages[-1]
+                if hasattr(last_msg, 'tool_calls') and last_msg.tool_calls:
+                    from langchain_core.messages import ToolMessage
+                    pending_tool_calls = last_msg.tool_calls if hasattr(last_msg.tool_calls, '__iter__') else []
+
+                    # Check which tool_call_ids already have responses
+                    existing_tool_call_ids = set()
+                    for msg in new_messages:
+                        if hasattr(msg, 'tool_call_id'):
+                            existing_tool_call_ids.add(msg.tool_call_id)
+
+                    # Add placeholder responses for any tool calls without responses
+                    for tool_call in pending_tool_calls:
+                        tool_call_id = tool_call.get('id', '') if isinstance(tool_call, dict) else getattr(tool_call, 'id', '')
+                        tool_name = tool_call.get('name', '') if isinstance(tool_call, dict) else getattr(tool_call, 'name', '')
+
+                        if tool_call_id and tool_call_id not in existing_tool_call_ids:
+                            logger.info(f"Adding placeholder ToolMessage for interrupted tool call: {tool_name} ({tool_call_id})")
+                            placeholder_msg = ToolMessage(
+                                content=f"[Tool execution interrupted - step limit ({self.steps_limit}) reached before {tool_name} could be executed]",
+                                tool_call_id=tool_call_id
+                            )
+                            new_messages.append(placeholder_msg)
+
+            # Add warning message - CLI or calling code can detect this and prompt user
             warning_msg = f"Maximum tool execution iterations ({self.steps_limit}) reached. Stopping tool execution."
             new_messages.append(AIMessage(content=warning_msg))
         else:
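Aside on the new truncation hook above: `_get_tool_truncation_suggestions` only probes tools via `hasattr`, so any tool can opt in by exposing a `truncation_suggestions` attribute (or a `get_truncation_suggestions()` method). A minimal sketch of such a tool, assuming a standard `langchain_core` BaseTool; the tool name, its parameter, and the suggestion strings below are hypothetical, not taken from this package:

from typing import List

from langchain_core.tools import BaseTool


class ListFilesTool(BaseTool):
    """Hypothetical tool that advertises how to shrink its own output."""
    name: str = "list_files"
    description: str = "List files under a directory"

    # Probed via hasattr() by _get_tool_truncation_suggestions; a plain
    # list of strings is all the hook expects.
    truncation_suggestions: List[str] = [
        "Point 'path' at a specific subdirectory instead of the project root",
        "Request a bounded number of entries per call",
    ]

    def _run(self, path: str = ".") -> str:
        import os
        return "\n".join(sorted(os.listdir(path)))

A `get_truncation_suggestions()` method works equally well when the suggestions need to be computed at call time.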
alita_sdk/runtime/tools/planning/__init__.py
ADDED
@@ -0,0 +1,36 @@
+"""
+Planning tools for runtime agents.
+
+Provides plan management for multi-step task execution with progress tracking.
+Supports two storage backends:
+1. PostgreSQL - when connection_string is provided (production/indexer_worker)
+2. Filesystem - when no connection string (local CLI usage)
+"""
+
+from .wrapper import (
+    PlanningWrapper,
+    PlanStep,
+    PlanState,
+    FilesystemStorage,
+    PostgresStorage,
+)
+from .models import (
+    AgentPlan,
+    PlanStatus,
+    ensure_plan_tables,
+    delete_plan_by_conversation_id,
+    cleanup_on_graceful_completion
+)
+
+__all__ = [
+    "PlanningWrapper",
+    "PlanStep",
+    "PlanState",
+    "FilesystemStorage",
+    "PostgresStorage",
+    "AgentPlan",
+    "PlanStatus",
+    "ensure_plan_tables",
+    "delete_plan_by_conversation_id",
+    "cleanup_on_graceful_completion",
+]
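For orientation only (not part of the diff): the module docstring above implies storage selection hinges on whether a connection string is available. A sketch of that decision with assumed constructor signatures; the real ones live in wrapper.py, which this diff does not show:

from typing import Optional

from alita_sdk.runtime.tools.planning import FilesystemStorage, PostgresStorage


def make_plan_storage(connection_string: Optional[str] = None):
    # Assumed signatures; wrapper.py (not shown here) has the real ones.
    if connection_string:
        # Production / indexer_worker: plans persist in PostgreSQL
        return PostgresStorage(connection_string)
    # Local CLI usage: plans persist on the filesystem
    return FilesystemStorage()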
alita_sdk/runtime/tools/planning/models.py
ADDED
@@ -0,0 +1,246 @@
+"""
+SQLAlchemy models for agent planning.
+
+Defines the AgentPlan table for storing execution plans with steps.
+Table is created automatically on toolkit initialization if it doesn't exist.
+"""
+
+import enum
+import logging
+import uuid
+from datetime import datetime
+from typing import List, Dict, Any, Optional
+
+from pydantic import BaseModel, Field
+from sqlalchemy import Column, String, DateTime, Text, Index, text
+from sqlalchemy.dialects.postgresql import UUID, JSONB
+from sqlalchemy.orm import declarative_base
+from sqlalchemy import create_engine
+from sqlalchemy.exc import ProgrammingError
+
+logger = logging.getLogger(__name__)
+
+Base = declarative_base()
+
+
+class PlanStatus(str, enum.Enum):
+    """Status of an execution plan."""
+    in_progress = "in_progress"
+    completed = "completed"
+    abandoned = "abandoned"
+
+
+class AgentPlan(Base):
+    """
+    Stores execution plans for agent tasks.
+
+    Created in the project-specific pgvector database.
+    Plans are scoped by conversation_id (from server or CLI session_id).
+    """
+    __tablename__ = "agent_plans"
+
+    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
+    conversation_id = Column(String(255), nullable=False, index=True)
+
+    # Plan metadata
+    title = Column(String(255), nullable=True)
+    status = Column(String(50), default=PlanStatus.in_progress.value)
+
+    # Plan content (JSONB for flexible step storage)
+    # Structure: {"steps": [{"description": "...", "completed": false}, ...]}
+    plan_data = Column(JSONB, nullable=False, default=dict)
+
+    # Timestamps
+    created_at = Column(DateTime, nullable=False, default=datetime.utcnow)
+    updated_at = Column(DateTime, nullable=True, onupdate=datetime.utcnow)
+
+
+# Pydantic models for tool input/output
+class PlanStep(BaseModel):
+    """A single step in a plan."""
+    description: str = Field(description="Step description")
+    completed: bool = Field(default=False, description="Whether step is completed")
+
+
+class PlanState(BaseModel):
+    """Current plan state for serialization."""
+    title: str = Field(default="", description="Plan title")
+    steps: List[PlanStep] = Field(default_factory=list, description="List of steps")
+    status: str = Field(default=PlanStatus.in_progress.value, description="Plan status")
+
+    def render(self) -> str:
+        """Render plan as formatted string with checkboxes."""
+        if not self.steps:
+            return "No plan currently set."
+
+        lines = []
+        if self.title:
+            lines.append(f"📋 {self.title}")
+
+        completed_count = 0
+        for i, step in enumerate(self.steps, 1):
+            checkbox = "☑" if step.completed else "☐"
+            status_text = " (completed)" if step.completed else ""
+            lines.append(f"  {checkbox} {i}. {step.description}{status_text}")
+            if step.completed:
+                completed_count += 1
+
+        lines.append(f"\nProgress: {completed_count}/{len(self.steps)} steps completed")
+
+        return "\n".join(lines)
+
+    def to_dict(self) -> Dict[str, Any]:
+        """Convert to dictionary for JSONB storage."""
+        return {
+            "steps": [{"description": s.description, "completed": s.completed} for s in self.steps]
+        }
+
+    @classmethod
+    def from_dict(cls, data: Dict[str, Any], title: str = "", status: str = PlanStatus.in_progress.value) -> "PlanState":
+        """Create from dictionary (JSONB data)."""
+        steps_data = data.get("steps", [])
+        steps = [PlanStep(**s) if isinstance(s, dict) else s for s in steps_data]
+        return cls(title=title, steps=steps, status=status)
+
+
+def ensure_plan_tables(connection_string: str) -> bool:
+    """
+    Ensure the agent_plans table exists in the database.
+
+    Creates the table if it doesn't exist. Safe to call multiple times.
+
+    Args:
+        connection_string: PostgreSQL connection string
+
+    Returns:
+        True if table was created or already exists, False on error
+    """
+    try:
+        # Handle SecretStr if passed
+        if hasattr(connection_string, 'get_secret_value'):
+            connection_string = connection_string.get_secret_value()
+
+        if not connection_string:
+            logger.warning("No connection string provided for plan tables")
+            return False
+
+        engine = create_engine(connection_string)
+
+        # Create tables if they don't exist
+        Base.metadata.create_all(engine, checkfirst=True)
+
+        logger.debug("Agent plans table ensured")
+        return True
+
+    except Exception as e:
+        logger.error(f"Failed to ensure plan tables: {e}")
+        return False
+
+
+def delete_plan_by_conversation_id(connection_string: str, conversation_id: str) -> bool:
+    """
+    Delete a plan by conversation_id.
+
+    Args:
+        connection_string: PostgreSQL connection string
+        conversation_id: The conversation ID to delete plans for
+
+    Returns:
+        True if deletion successful, False otherwise
+    """
+    try:
+        if hasattr(connection_string, 'get_secret_value'):
+            connection_string = connection_string.get_secret_value()
+
+        if not connection_string or not conversation_id:
+            return False
+
+        engine = create_engine(connection_string)
+
+        with engine.connect() as conn:
+            result = conn.execute(
+                text("DELETE FROM agent_plans WHERE conversation_id = :conversation_id"),
+                {"conversation_id": conversation_id}
+            )
+            conn.commit()
+
+        logger.debug(f"Deleted plan for conversation_id: {conversation_id}")
+        return True
+
+    except Exception as e:
+        logger.error(f"Failed to delete plan for conversation_id {conversation_id}: {e}")
+        return False
+
+
+def cleanup_on_graceful_completion(
+    connection_string: str,
+    conversation_id: str,
+    thread_id: str = None,
+    delete_checkpoints: bool = True
+) -> dict:
+    """
+    Cleanup plans and optionally checkpoints after graceful agent completion.
+
+    This function is designed to be called after an agent completes successfully
+    (no exceptions, valid finish reason).
+
+    Args:
+        connection_string: PostgreSQL connection string
+        conversation_id: The conversation ID to cleanup plans for
+        thread_id: The thread ID to cleanup checkpoints for (optional)
+        delete_checkpoints: If True, also delete checkpoint data
+
+    Returns:
+        Dict with cleanup results: {'plan_deleted': bool, 'checkpoints_deleted': bool}
+    """
+    result = {'plan_deleted': False, 'checkpoints_deleted': False}
+
+    try:
+        if hasattr(connection_string, 'get_secret_value'):
+            connection_string = connection_string.get_secret_value()
+
+        if not connection_string or not conversation_id:
+            logger.warning("Missing connection_string or conversation_id for cleanup")
+            return result
+
+        engine = create_engine(connection_string)
+
+        with engine.connect() as conn:
+            # Delete plan by conversation_id
+            try:
+                conn.execute(
+                    text("DELETE FROM agent_plans WHERE conversation_id = :conversation_id"),
+                    {"conversation_id": conversation_id}
+                )
+                result['plan_deleted'] = True
+                logger.debug(f"Deleted plan for conversation_id: {conversation_id}")
+            except Exception as e:
+                # Table might not exist, which is fine
+                logger.debug(f"Could not delete plan (table may not exist): {e}")
+
+            # Delete checkpoints if requested (still uses thread_id as that's LangGraph's key)
+            if delete_checkpoints and thread_id:
+                checkpoint_tables = [
+                    "checkpoints",
+                    "checkpoint_writes",
+                    "checkpoint_blobs"
+                ]
+
+                for table in checkpoint_tables:
+                    try:
+                        conn.execute(
+                            text(f"DELETE FROM {table} WHERE thread_id = :thread_id"),
+                            {"thread_id": thread_id}
+                        )
+                        logger.debug(f"Deleted {table} for thread_id: {thread_id}")
+                    except Exception as e:
+                        logger.debug(f"Could not delete from {table}: {e}")
+
+                result['checkpoints_deleted'] = True
+
+            conn.commit()
+
    except Exception as e:
+        logger.error(f"Failed to cleanup for conversation_id {conversation_id}: {e}")
+
+    return result