alita-sdk 0.3.457__py3-none-any.whl → 0.3.486__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of alita-sdk might be problematic.
Files changed (102)
  1. alita_sdk/cli/__init__.py +10 -0
  2. alita_sdk/cli/__main__.py +17 -0
  3. alita_sdk/cli/agent/__init__.py +5 -0
  4. alita_sdk/cli/agent/default.py +258 -0
  5. alita_sdk/cli/agent_executor.py +155 -0
  6. alita_sdk/cli/agent_loader.py +194 -0
  7. alita_sdk/cli/agent_ui.py +228 -0
  8. alita_sdk/cli/agents.py +3592 -0
  9. alita_sdk/cli/callbacks.py +647 -0
  10. alita_sdk/cli/cli.py +168 -0
  11. alita_sdk/cli/config.py +306 -0
  12. alita_sdk/cli/context/__init__.py +30 -0
  13. alita_sdk/cli/context/cleanup.py +198 -0
  14. alita_sdk/cli/context/manager.py +731 -0
  15. alita_sdk/cli/context/message.py +285 -0
  16. alita_sdk/cli/context/strategies.py +289 -0
  17. alita_sdk/cli/context/token_estimation.py +127 -0
  18. alita_sdk/cli/formatting.py +182 -0
  19. alita_sdk/cli/input_handler.py +419 -0
  20. alita_sdk/cli/inventory.py +1256 -0
  21. alita_sdk/cli/mcp_loader.py +315 -0
  22. alita_sdk/cli/toolkit.py +327 -0
  23. alita_sdk/cli/toolkit_loader.py +85 -0
  24. alita_sdk/cli/tools/__init__.py +43 -0
  25. alita_sdk/cli/tools/approval.py +224 -0
  26. alita_sdk/cli/tools/filesystem.py +1665 -0
  27. alita_sdk/cli/tools/planning.py +389 -0
  28. alita_sdk/cli/tools/terminal.py +414 -0
  29. alita_sdk/community/__init__.py +64 -8
  30. alita_sdk/community/inventory/__init__.py +224 -0
  31. alita_sdk/community/inventory/config.py +257 -0
  32. alita_sdk/community/inventory/enrichment.py +2137 -0
  33. alita_sdk/community/inventory/extractors.py +1469 -0
  34. alita_sdk/community/inventory/ingestion.py +3172 -0
  35. alita_sdk/community/inventory/knowledge_graph.py +1457 -0
  36. alita_sdk/community/inventory/parsers/__init__.py +218 -0
  37. alita_sdk/community/inventory/parsers/base.py +295 -0
  38. alita_sdk/community/inventory/parsers/csharp_parser.py +907 -0
  39. alita_sdk/community/inventory/parsers/go_parser.py +851 -0
  40. alita_sdk/community/inventory/parsers/html_parser.py +389 -0
  41. alita_sdk/community/inventory/parsers/java_parser.py +593 -0
  42. alita_sdk/community/inventory/parsers/javascript_parser.py +629 -0
  43. alita_sdk/community/inventory/parsers/kotlin_parser.py +768 -0
  44. alita_sdk/community/inventory/parsers/markdown_parser.py +362 -0
  45. alita_sdk/community/inventory/parsers/python_parser.py +604 -0
  46. alita_sdk/community/inventory/parsers/rust_parser.py +858 -0
  47. alita_sdk/community/inventory/parsers/swift_parser.py +832 -0
  48. alita_sdk/community/inventory/parsers/text_parser.py +322 -0
  49. alita_sdk/community/inventory/parsers/yaml_parser.py +370 -0
  50. alita_sdk/community/inventory/patterns/__init__.py +61 -0
  51. alita_sdk/community/inventory/patterns/ast_adapter.py +380 -0
  52. alita_sdk/community/inventory/patterns/loader.py +348 -0
  53. alita_sdk/community/inventory/patterns/registry.py +198 -0
  54. alita_sdk/community/inventory/presets.py +535 -0
  55. alita_sdk/community/inventory/retrieval.py +1403 -0
  56. alita_sdk/community/inventory/toolkit.py +169 -0
  57. alita_sdk/community/inventory/visualize.py +1370 -0
  58. alita_sdk/configurations/bitbucket.py +0 -3
  59. alita_sdk/runtime/clients/client.py +99 -26
  60. alita_sdk/runtime/langchain/assistant.py +4 -2
  61. alita_sdk/runtime/langchain/constants.py +2 -1
  62. alita_sdk/runtime/langchain/langraph_agent.py +134 -31
  63. alita_sdk/runtime/langchain/utils.py +1 -1
  64. alita_sdk/runtime/llms/preloaded.py +2 -6
  65. alita_sdk/runtime/toolkits/__init__.py +2 -0
  66. alita_sdk/runtime/toolkits/application.py +1 -1
  67. alita_sdk/runtime/toolkits/mcp.py +46 -36
  68. alita_sdk/runtime/toolkits/planning.py +171 -0
  69. alita_sdk/runtime/toolkits/tools.py +39 -6
  70. alita_sdk/runtime/tools/function.py +17 -5
  71. alita_sdk/runtime/tools/llm.py +249 -14
  72. alita_sdk/runtime/tools/planning/__init__.py +36 -0
  73. alita_sdk/runtime/tools/planning/models.py +246 -0
  74. alita_sdk/runtime/tools/planning/wrapper.py +607 -0
  75. alita_sdk/runtime/tools/vectorstore_base.py +41 -6
  76. alita_sdk/runtime/utils/mcp_oauth.py +80 -0
  77. alita_sdk/runtime/utils/streamlit.py +6 -10
  78. alita_sdk/runtime/utils/toolkit_utils.py +19 -4
  79. alita_sdk/tools/__init__.py +54 -27
  80. alita_sdk/tools/ado/repos/repos_wrapper.py +1 -2
  81. alita_sdk/tools/base_indexer_toolkit.py +150 -19
  82. alita_sdk/tools/bitbucket/__init__.py +2 -2
  83. alita_sdk/tools/chunkers/__init__.py +3 -1
  84. alita_sdk/tools/chunkers/sematic/markdown_chunker.py +95 -6
  85. alita_sdk/tools/chunkers/universal_chunker.py +269 -0
  86. alita_sdk/tools/code_indexer_toolkit.py +55 -22
  87. alita_sdk/tools/elitea_base.py +86 -21
  88. alita_sdk/tools/jira/__init__.py +1 -1
  89. alita_sdk/tools/jira/api_wrapper.py +91 -40
  90. alita_sdk/tools/non_code_indexer_toolkit.py +1 -0
  91. alita_sdk/tools/qtest/__init__.py +1 -1
  92. alita_sdk/tools/qtest/api_wrapper.py +871 -32
  93. alita_sdk/tools/sharepoint/api_wrapper.py +22 -2
  94. alita_sdk/tools/sharepoint/authorization_helper.py +17 -1
  95. alita_sdk/tools/vector_adapters/VectorStoreAdapter.py +8 -2
  96. alita_sdk/tools/zephyr_essential/api_wrapper.py +12 -13
  97. {alita_sdk-0.3.457.dist-info → alita_sdk-0.3.486.dist-info}/METADATA +146 -2
  98. {alita_sdk-0.3.457.dist-info → alita_sdk-0.3.486.dist-info}/RECORD +102 -40
  99. alita_sdk-0.3.486.dist-info/entry_points.txt +2 -0
  100. {alita_sdk-0.3.457.dist-info → alita_sdk-0.3.486.dist-info}/WHEEL +0 -0
  101. {alita_sdk-0.3.457.dist-info → alita_sdk-0.3.486.dist-info}/licenses/LICENSE +0 -0
  102. {alita_sdk-0.3.457.dist-info → alita_sdk-0.3.486.dist-info}/top_level.txt +0 -0
alita_sdk/runtime/tools/llm.py
@@ -1,3 +1,4 @@
+import asyncio
 import logging
 from traceback import format_exc
 from typing import Any, Optional, List, Union
@@ -12,6 +13,7 @@ from ..langchain.utils import create_pydantic_model, propagate_the_input_mapping
 
 logger = logging.getLogger(__name__)
 
+
 class LLMNode(BaseTool):
     """Enhanced LLM node with chat history and tool binding support"""
 
@@ -60,6 +62,47 @@ class LLMNode(BaseTool):
 
         return filtered_tools
 
+    def _get_tool_truncation_suggestions(self, tool_name: Optional[str]) -> str:
+        """
+        Get context-specific suggestions for how to reduce output from a tool.
+
+        First checks if the tool itself provides truncation suggestions via
+        `truncation_suggestions` attribute or `get_truncation_suggestions()` method.
+        Falls back to generic suggestions if the tool doesn't provide any.
+
+        Args:
+            tool_name: Name of the tool that caused the context overflow
+
+        Returns:
+            Formatted string with numbered suggestions for the specific tool
+        """
+        suggestions = None
+
+        # Try to get suggestions from the tool itself
+        if tool_name:
+            filtered_tools = self.get_filtered_tools()
+            for tool in filtered_tools:
+                if tool.name == tool_name:
+                    # Check for truncation_suggestions attribute
+                    if hasattr(tool, 'truncation_suggestions') and tool.truncation_suggestions:
+                        suggestions = tool.truncation_suggestions
+                        break
+                    # Check for get_truncation_suggestions method
+                    elif hasattr(tool, 'get_truncation_suggestions') and callable(tool.get_truncation_suggestions):
+                        suggestions = tool.get_truncation_suggestions()
+                        break
+
+        # Fall back to generic suggestions if tool doesn't provide any
+        if not suggestions:
+            suggestions = [
+                "Check if the tool has parameters to limit output size (e.g., max_items, max_results, max_depth)",
+                "Target a more specific path or query instead of broad searches",
+                "Break the operation into smaller, focused requests",
+            ]
+
+        # Format as numbered list
+        return "\n".join(f"{i+1}. {s}" for i, s in enumerate(suggestions))
+
     def invoke(
         self,
         state: Union[str, dict],
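
The lookup above defines a duck-typed contract rather than a formal interface: any bound tool can surface its own truncation hints through a `truncation_suggestions` attribute or a `get_truncation_suggestions()` method, and only tools that provide neither get the generic fallback list. A minimal sketch of a tool opting in — illustrative only, not code from this release; the class and hint strings are hypothetical:

from typing import List
from langchain_core.tools import BaseTool

class ListFilesTool(BaseTool):
    """Hypothetical tool exposing static truncation hints."""
    name: str = "list_files"
    description: str = "List files under a directory."
    # Found via the hasattr(tool, 'truncation_suggestions') check above
    truncation_suggestions: List[str] = [
        "Pass max_depth=1 to avoid recursing into subdirectories",
        "Use a narrower glob pattern to reduce the listing",
    ]

    def _run(self, path: str = ".", max_depth: int = 3) -> str:
        return f"(listing for {path} up to depth {max_depth})"
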
@@ -132,7 +175,9 @@ class LLMNode(BaseTool):
                 struct_model = create_pydantic_model(f"LLMOutput", struct_params)
                 completion = llm_client.invoke(messages, config=config)
                 if hasattr(completion, 'tool_calls') and completion.tool_calls:
-                    new_messages, _ = self.__perform_tool_calling(completion, messages, llm_client, config)
+                    new_messages, _ = self._run_async_in_sync_context(
+                        self.__perform_tool_calling(completion, messages, llm_client, config)
+                    )
                     llm = self.__get_struct_output_model(llm_client, struct_model)
                     completion = llm.invoke(new_messages, config=config)
                     result = completion.model_dump()
@@ -155,7 +200,9 @@ class LLMNode(BaseTool):
             # Handle both tool-calling and regular responses
             if hasattr(completion, 'tool_calls') and completion.tool_calls:
                 # Handle iterative tool-calling and execution
-                new_messages, current_completion = self.__perform_tool_calling(completion, messages, llm_client, config)
+                new_messages, current_completion = self._run_async_in_sync_context(
+                    self.__perform_tool_calling(completion, messages, llm_client, config)
+                )
 
                 output_msgs = {"messages": new_messages}
                 if self.output_variables:
@@ -190,9 +237,53 @@ class LLMNode(BaseTool):
     def _run(self, *args, **kwargs):
         # Legacy support for old interface
         return self.invoke(kwargs, **kwargs)
+
+    def _run_async_in_sync_context(self, coro):
+        """Run async coroutine from sync context.
+
+        For MCP tools with persistent sessions, we reuse the same event loop
+        that was used to create the MCP client and sessions (set by CLI).
+        """
+        try:
+            loop = asyncio.get_running_loop()
+            # Already in async context - run in thread with new loop
+            import threading
+
+            result_container = []
+
+            def run_in_thread():
+                new_loop = asyncio.new_event_loop()
+                asyncio.set_event_loop(new_loop)
+                try:
+                    result_container.append(new_loop.run_until_complete(coro))
+                finally:
+                    new_loop.close()
+
+            thread = threading.Thread(target=run_in_thread)
+            thread.start()
+            thread.join()
+            return result_container[0] if result_container else None
+
+        except RuntimeError:
+            # No event loop running - use/create persistent loop
+            # This loop is shared with MCP session creation for stateful tools
+            if not hasattr(self.__class__, '_persistent_loop') or \
+                    self.__class__._persistent_loop is None or \
+                    self.__class__._persistent_loop.is_closed():
+                self.__class__._persistent_loop = asyncio.new_event_loop()
+                logger.debug("Created persistent event loop for async tools")
+
+            loop = self.__class__._persistent_loop
+            asyncio.set_event_loop(loop)
+            return loop.run_until_complete(coro)
+
+    async def _arun(self, *args, **kwargs):
+        # Legacy async support
+        return self.invoke(kwargs, **kwargs)
 
-    def __perform_tool_calling(self, completion, messages, llm_client, config):
+    async def __perform_tool_calling(self, completion, messages, llm_client, config):
         # Handle iterative tool-calling and execution
+        logger.info(f"__perform_tool_calling called with {len(completion.tool_calls) if hasattr(completion, 'tool_calls') else 0} tool calls")
        new_messages = messages + [completion]
         iteration = 0
 
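
The `_run_async_in_sync_context` helper above is a common sync-to-async bridge: if the current thread already has a running loop, the coroutine is finished on a fresh loop in a worker thread; otherwise a lazily created, long-lived loop is reused so loop-bound resources (such as MCP sessions) stay attached to one loop across calls. A standalone sketch of the same pattern, using only the standard library:

import asyncio
import threading

_persistent_loop = None  # module-level stand-in for the class attribute above

def run_coro_from_sync(coro):
    """Run `coro` to completion from synchronous code (sketch)."""
    global _persistent_loop
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No loop running here: reuse (or lazily create) a persistent loop
        if _persistent_loop is None or _persistent_loop.is_closed():
            _persistent_loop = asyncio.new_event_loop()
        asyncio.set_event_loop(_persistent_loop)
        return _persistent_loop.run_until_complete(coro)
    # A loop is already running in this thread: finish the coroutine on a
    # brand-new loop in a worker thread to avoid re-entering the active loop.
    box = []
    def worker():
        loop = asyncio.new_event_loop()
        try:
            box.append(loop.run_until_complete(coro))
        finally:
            loop.close()
    t = threading.Thread(target=worker)
    t.start()
    t.join()
    return box[0] if box else None

One caveat the docstring hints at: coroutines that touch loop-bound state only work on the thread-with-new-loop path if that state is loop-agnostic.
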
@@ -230,9 +321,16 @@ class LLMNode(BaseTool):
                 if tool_to_execute:
                     try:
                         logger.info(f"Executing tool '{tool_name}' with args: {tool_args}")
-                        # Pass the underlying config to the tool execution invoke method
-                        # since it may be another agent, graph, etc. to see it properly in thinking steps
-                        tool_result = tool_to_execute.invoke(tool_args, config=config)
+
+                        # Try async invoke first (for MCP tools), fallback to sync
+                        tool_result = None
+                        try:
+                            # Try async invocation first
+                            tool_result = await tool_to_execute.ainvoke(tool_args, config=config)
+                        except NotImplementedError:
+                            # Tool doesn't support async, use sync invoke
+                            logger.debug(f"Tool '{tool_name}' doesn't support async, using sync invoke")
+                            tool_result = tool_to_execute.invoke(tool_args, config=config)
 
                         # Create tool message with result - preserve structured content
                         from langchain_core.messages import ToolMessage
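
The fallback above treats `NotImplementedError` as the signal that a tool is sync-only. A sketch of a tool that would take that path — hypothetical, not SDK code:

from langchain_core.tools import BaseTool

class SyncOnlyTool(BaseTool):
    """Hypothetical sync-only tool: ainvoke() surfaces NotImplementedError."""
    name: str = "sync_only"
    description: str = "A tool with no async implementation."

    def _run(self, query: str) -> str:
        return f"handled {query!r} synchronously"

    async def _arun(self, query: str) -> str:
        raise NotImplementedError("sync_only has no async path")

For such a tool, `await tool.ainvoke(...)` raises, and LLMNode drops down to the synchronous `invoke(...)` call.
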
@@ -256,7 +354,10 @@ class LLMNode(BaseTool):
                         new_messages.append(tool_message)
 
                     except Exception as e:
-                        logger.error(f"Error executing tool '{tool_name}': {e}")
+                        import traceback
+                        error_details = traceback.format_exc()
+                        # Use debug level to avoid duplicate output when CLI callbacks are active
+                        logger.debug(f"Error executing tool '{tool_name}': {e}\n{error_details}")
                         # Create error tool message
                         from langchain_core.messages import ToolMessage
                         tool_message = ToolMessage(
@@ -287,16 +388,150 @@ class LLMNode(BaseTool):
                     break
 
             except Exception as e:
-                logger.error(f"Error in LLM call during iteration {iteration}: {e}")
-                # Add error message and break the loop
-                error_msg = f"Error processing tool results in iteration {iteration}: {str(e)}"
-                new_messages.append(AIMessage(content=error_msg))
-                break
+                error_str = str(e).lower()
+
+                # Check for context window / token limit errors
+                is_context_error = any(indicator in error_str for indicator in [
+                    'context window', 'context_window', 'token limit', 'too long',
+                    'maximum context length', 'input is too long', 'exceeds the limit',
+                    'contextwindowexceedederror', 'max_tokens', 'content too large'
+                ])
+
+                # Check for Bedrock/Claude output limit errors
+                # These often manifest as "model identifier is invalid" when output exceeds limits
+                is_output_limit_error = any(indicator in error_str for indicator in [
+                    'model identifier is invalid',
+                    'bedrockexception',
+                    'output token',
+                    'response too large',
+                    'max_tokens_to_sample',
+                    'output_token_limit'
+                ])
+
+                if is_context_error or is_output_limit_error:
+                    error_type = "output limit" if is_output_limit_error else "context window"
+                    logger.warning(f"{error_type.title()} exceeded during tool execution iteration {iteration}")
+
+                    # Find the last tool message and its associated tool name
+                    last_tool_msg_idx = None
+                    last_tool_name = None
+                    last_tool_call_id = None
+
+                    # First, find the last tool message
+                    for i in range(len(new_messages) - 1, -1, -1):
+                        msg = new_messages[i]
+                        if hasattr(msg, 'tool_call_id') or (hasattr(msg, 'type') and getattr(msg, 'type', None) == 'tool'):
+                            last_tool_msg_idx = i
+                            last_tool_call_id = getattr(msg, 'tool_call_id', None)
+                            break
+
+                    # Find the tool name from the AIMessage that requested this tool call
+                    if last_tool_call_id:
+                        for i in range(last_tool_msg_idx - 1, -1, -1):
+                            msg = new_messages[i]
+                            if hasattr(msg, 'tool_calls') and msg.tool_calls:
+                                for tc in msg.tool_calls:
+                                    tc_id = tc.get('id', '') if isinstance(tc, dict) else getattr(tc, 'id', '')
+                                    if tc_id == last_tool_call_id:
+                                        last_tool_name = tc.get('name', '') if isinstance(tc, dict) else getattr(tc, 'name', '')
+                                        break
+                                if last_tool_name:
+                                    break
+
+                    # Build dynamic suggestion based on the tool that caused the overflow
+                    tool_suggestions = self._get_tool_truncation_suggestions(last_tool_name)
+
+                    # Truncate the problematic tool result if found
+                    if last_tool_msg_idx is not None:
+                        from langchain_core.messages import ToolMessage
+                        original_msg = new_messages[last_tool_msg_idx]
+                        tool_call_id = getattr(original_msg, 'tool_call_id', 'unknown')
+
+                        # Build error-specific guidance
+                        if is_output_limit_error:
+                            truncated_content = (
+                                f"⚠️ MODEL OUTPUT LIMIT EXCEEDED\n\n"
+                                f"The tool '{last_tool_name or 'unknown'}' returned data, but the model's response was too large.\n\n"
+                                f"IMPORTANT: You must provide a SMALLER, more focused response.\n"
+                                f"- Break down your response into smaller chunks\n"
+                                f"- Summarize instead of listing everything\n"
+                                f"- Focus on the most relevant information first\n"
+                                f"- If listing items, show only top 5-10 most important\n\n"
+                                f"Tool-specific tips:\n{tool_suggestions}\n\n"
+                                f"Please retry with a more concise response."
+                            )
+                        else:
+                            truncated_content = (
+                                f"⚠️ TOOL OUTPUT TRUNCATED - Context window exceeded\n\n"
+                                f"The tool '{last_tool_name or 'unknown'}' returned too much data for the model's context window.\n\n"
+                                f"To fix this:\n{tool_suggestions}\n\n"
+                                f"Please retry with more restrictive parameters."
+                            )
+
+                        truncated_msg = ToolMessage(
+                            content=truncated_content,
+                            tool_call_id=tool_call_id
+                        )
+                        new_messages[last_tool_msg_idx] = truncated_msg
+
+                        logger.info(f"Truncated large tool result from '{last_tool_name}' and continuing")
+                        # Continue to next iteration - the model will see the truncation message
+                        continue
+                    else:
+                        # Couldn't find tool message, add error and break
+                        if is_output_limit_error:
+                            error_msg = (
+                                "Model output limit exceeded. Please provide a more concise response. "
+                                "Break down your answer into smaller parts and summarize where possible."
+                            )
+                        else:
+                            error_msg = (
+                                "Context window exceeded. The conversation or tool results are too large. "
+                                "Try using tools with smaller output limits (e.g., max_items, max_depth parameters)."
+                            )
+                        new_messages.append(AIMessage(content=error_msg))
+                        break
+                else:
+                    logger.error(f"Error in LLM call during iteration {iteration}: {e}")
+                    # Add error message and break the loop
+                    error_msg = f"Error processing tool results in iteration {iteration}: {str(e)}"
+                    new_messages.append(AIMessage(content=error_msg))
+                    break
 
-        # Log completion status
+        # Handle max iterations
         if iteration >= self.steps_limit:
             logger.warning(f"Reached maximum iterations ({self.steps_limit}) for tool execution")
-            # Add a warning message to the chat
+
+            # CRITICAL: Check if the last message is an AIMessage with pending tool_calls
+            # that were not processed. If so, we need to add placeholder ToolMessages to prevent
+            # the "assistant message with 'tool_calls' must be followed by tool messages" error
+            # when the conversation continues.
+            if new_messages:
+                last_msg = new_messages[-1]
+                if hasattr(last_msg, 'tool_calls') and last_msg.tool_calls:
+                    from langchain_core.messages import ToolMessage
+                    pending_tool_calls = last_msg.tool_calls if hasattr(last_msg.tool_calls, '__iter__') else []
+
+                    # Check which tool_call_ids already have responses
+                    existing_tool_call_ids = set()
+                    for msg in new_messages:
+                        if hasattr(msg, 'tool_call_id'):
+                            existing_tool_call_ids.add(msg.tool_call_id)
+
+                    # Add placeholder responses for any tool calls without responses
+                    for tool_call in pending_tool_calls:
+                        tool_call_id = tool_call.get('id', '') if isinstance(tool_call, dict) else getattr(tool_call, 'id', '')
+                        tool_name = tool_call.get('name', '') if isinstance(tool_call, dict) else getattr(tool_call, 'name', '')
+
+                        if tool_call_id and tool_call_id not in existing_tool_call_ids:
+                            logger.info(f"Adding placeholder ToolMessage for interrupted tool call: {tool_name} ({tool_call_id})")
+                            placeholder_msg = ToolMessage(
+                                content=f"[Tool execution interrupted - step limit ({self.steps_limit}) reached before {tool_name} could be executed]",
+                                tool_call_id=tool_call_id
+                            )
+                            new_messages.append(placeholder_msg)
+
+            # Add warning message - CLI or calling code can detect this and prompt user
             warning_msg = f"Maximum tool execution iterations ({self.steps_limit}) reached. Stopping tool execution."
             new_messages.append(AIMessage(content=warning_msg))
         else:
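
The placeholder logic in the last hunk guards an API invariant: OpenAI-style chat endpoints reject a history in which an assistant message carries `tool_calls` that have no matching tool result messages. A minimal illustration of the repaired transcript shape (message contents invented for the example):

from langchain_core.messages import AIMessage, ToolMessage

history = [
    AIMessage(content="", tool_calls=[
        {"name": "list_files", "args": {"path": "."}, "id": "call_1"},
    ]),
    # Without a ToolMessage answering call_1, resending this history fails;
    # the placeholder added at the step limit keeps the transcript valid.
    ToolMessage(
        content="[Tool execution interrupted - step limit (25) reached before list_files could be executed]",
        tool_call_id="call_1",
    ),
]

The "(25)" step limit here is illustrative; the real message interpolates `self.steps_limit`.
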
alita_sdk/runtime/tools/planning/__init__.py (new file)
@@ -0,0 +1,36 @@
+"""
+Planning tools for runtime agents.
+
+Provides plan management for multi-step task execution with progress tracking.
+Supports two storage backends:
+1. PostgreSQL - when connection_string is provided (production/indexer_worker)
+2. Filesystem - when no connection string (local CLI usage)
+"""
+
+from .wrapper import (
+    PlanningWrapper,
+    PlanStep,
+    PlanState,
+    FilesystemStorage,
+    PostgresStorage,
+)
+from .models import (
+    AgentPlan,
+    PlanStatus,
+    ensure_plan_tables,
+    delete_plan_by_conversation_id,
+    cleanup_on_graceful_completion
+)
+
+__all__ = [
+    "PlanningWrapper",
+    "PlanStep",
+    "PlanState",
+    "FilesystemStorage",
+    "PostgresStorage",
+    "AgentPlan",
+    "PlanStatus",
+    "ensure_plan_tables",
+    "delete_plan_by_conversation_id",
+    "cleanup_on_graceful_completion",
+]
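
A hedged usage sketch for the new package, built from the Pydantic models defined in models.py (shown next); it assumes the `PlanState`/`PlanStep` re-exported from `.wrapper` are compatible with those models:

from alita_sdk.runtime.tools.planning.models import PlanState, PlanStep

plan = PlanState(
    title="Ship the release",
    steps=[
        PlanStep(description="Run the test suite", completed=True),
        PlanStep(description="Bump the version"),
    ],
)
print(plan.render())
# 📋 Ship the release
#  ☑ 1. Run the test suite (completed)
#  ☐ 2. Bump the version
#
# Progress: 1/2 steps completed
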
alita_sdk/runtime/tools/planning/models.py (new file)
@@ -0,0 +1,246 @@
+"""
+SQLAlchemy models for agent planning.
+
+Defines the AgentPlan table for storing execution plans with steps.
+Table is created automatically on toolkit initialization if it doesn't exist.
+"""
+
+import enum
+import logging
+import uuid
+from datetime import datetime
+from typing import List, Dict, Any, Optional
+
+from pydantic import BaseModel, Field
+from sqlalchemy import Column, String, DateTime, Text, Index, text
+from sqlalchemy.dialects.postgresql import UUID, JSONB
+from sqlalchemy.orm import declarative_base
+from sqlalchemy import create_engine
+from sqlalchemy.exc import ProgrammingError
+
+logger = logging.getLogger(__name__)
+
+Base = declarative_base()
+
+
+class PlanStatus(str, enum.Enum):
+    """Status of an execution plan."""
+    in_progress = "in_progress"
+    completed = "completed"
+    abandoned = "abandoned"
+
+
+class AgentPlan(Base):
+    """
+    Stores execution plans for agent tasks.
+
+    Created in the project-specific pgvector database.
+    Plans are scoped by conversation_id (from server or CLI session_id).
+    """
+    __tablename__ = "agent_plans"
+
+    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
+    conversation_id = Column(String(255), nullable=False, index=True)
+
+    # Plan metadata
+    title = Column(String(255), nullable=True)
+    status = Column(String(50), default=PlanStatus.in_progress.value)
+
+    # Plan content (JSONB for flexible step storage)
+    # Structure: {"steps": [{"description": "...", "completed": false}, ...]}
+    plan_data = Column(JSONB, nullable=False, default=dict)
+
+    # Timestamps
+    created_at = Column(DateTime, nullable=False, default=datetime.utcnow)
+    updated_at = Column(DateTime, nullable=True, onupdate=datetime.utcnow)
+
+
+# Pydantic models for tool input/output
+class PlanStep(BaseModel):
+    """A single step in a plan."""
+    description: str = Field(description="Step description")
+    completed: bool = Field(default=False, description="Whether step is completed")
+
+
+class PlanState(BaseModel):
+    """Current plan state for serialization."""
+    title: str = Field(default="", description="Plan title")
+    steps: List[PlanStep] = Field(default_factory=list, description="List of steps")
+    status: str = Field(default=PlanStatus.in_progress.value, description="Plan status")
+
+    def render(self) -> str:
+        """Render plan as formatted string with checkboxes."""
+        if not self.steps:
+            return "No plan currently set."
+
+        lines = []
+        if self.title:
+            lines.append(f"📋 {self.title}")
+
+        completed_count = 0
+        for i, step in enumerate(self.steps, 1):
+            checkbox = "☑" if step.completed else "☐"
+            status_text = " (completed)" if step.completed else ""
+            lines.append(f" {checkbox} {i}. {step.description}{status_text}")
+            if step.completed:
+                completed_count += 1
+
+        lines.append(f"\nProgress: {completed_count}/{len(self.steps)} steps completed")
+
+        return "\n".join(lines)
+
+    def to_dict(self) -> Dict[str, Any]:
+        """Convert to dictionary for JSONB storage."""
+        return {
+            "steps": [{"description": s.description, "completed": s.completed} for s in self.steps]
+        }
+
+    @classmethod
+    def from_dict(cls, data: Dict[str, Any], title: str = "", status: str = PlanStatus.in_progress.value) -> "PlanState":
+        """Create from dictionary (JSONB data)."""
+        steps_data = data.get("steps", [])
+        steps = [PlanStep(**s) if isinstance(s, dict) else s for s in steps_data]
+        return cls(title=title, steps=steps, status=status)
+
+
+def ensure_plan_tables(connection_string: str) -> bool:
+    """
+    Ensure the agent_plans table exists in the database.
+
+    Creates the table if it doesn't exist. Safe to call multiple times.
+
+    Args:
+        connection_string: PostgreSQL connection string
+
+    Returns:
+        True if table was created or already exists, False on error
+    """
+    try:
+        # Handle SecretStr if passed
+        if hasattr(connection_string, 'get_secret_value'):
+            connection_string = connection_string.get_secret_value()
+
+        if not connection_string:
+            logger.warning("No connection string provided for plan tables")
+            return False
+
+        engine = create_engine(connection_string)
+
+        # Create tables if they don't exist
+        Base.metadata.create_all(engine, checkfirst=True)
+
+        logger.debug("Agent plans table ensured")
+        return True
+
+    except Exception as e:
+        logger.error(f"Failed to ensure plan tables: {e}")
+        return False
+
+
+def delete_plan_by_conversation_id(connection_string: str, conversation_id: str) -> bool:
+    """
+    Delete a plan by conversation_id.
+
+    Args:
+        connection_string: PostgreSQL connection string
+        conversation_id: The conversation ID to delete plans for
+
+    Returns:
+        True if deletion successful, False otherwise
+    """
+    try:
+        if hasattr(connection_string, 'get_secret_value'):
+            connection_string = connection_string.get_secret_value()
+
+        if not connection_string or not conversation_id:
+            return False
+
+        engine = create_engine(connection_string)
+
+        with engine.connect() as conn:
+            result = conn.execute(
+                text("DELETE FROM agent_plans WHERE conversation_id = :conversation_id"),
+                {"conversation_id": conversation_id}
+            )
+            conn.commit()
+
+        logger.debug(f"Deleted plan for conversation_id: {conversation_id}")
+        return True
+
+    except Exception as e:
+        logger.error(f"Failed to delete plan for conversation_id {conversation_id}: {e}")
+        return False
+
+
+def cleanup_on_graceful_completion(
+    connection_string: str,
+    conversation_id: str,
+    thread_id: str = None,
+    delete_checkpoints: bool = True
+) -> dict:
+    """
+    Cleanup plans and optionally checkpoints after graceful agent completion.
+
+    This function is designed to be called after an agent completes successfully
+    (no exceptions, valid finish reason).
+
+    Args:
+        connection_string: PostgreSQL connection string
+        conversation_id: The conversation ID to cleanup plans for
+        thread_id: The thread ID to cleanup checkpoints for (optional)
+        delete_checkpoints: If True, also delete checkpoint data
+
+    Returns:
+        Dict with cleanup results: {'plan_deleted': bool, 'checkpoints_deleted': bool}
+    """
+    result = {'plan_deleted': False, 'checkpoints_deleted': False}
+
+    try:
+        if hasattr(connection_string, 'get_secret_value'):
+            connection_string = connection_string.get_secret_value()
+
+        if not connection_string or not conversation_id:
+            logger.warning("Missing connection_string or conversation_id for cleanup")
+            return result
+
+        engine = create_engine(connection_string)
+
+        with engine.connect() as conn:
+            # Delete plan by conversation_id
+            try:
+                conn.execute(
+                    text("DELETE FROM agent_plans WHERE conversation_id = :conversation_id"),
+                    {"conversation_id": conversation_id}
+                )
+                result['plan_deleted'] = True
+                logger.debug(f"Deleted plan for conversation_id: {conversation_id}")
+            except Exception as e:
+                # Table might not exist, which is fine
+                logger.debug(f"Could not delete plan (table may not exist): {e}")
+
+            # Delete checkpoints if requested (still uses thread_id as that's LangGraph's key)
+            if delete_checkpoints and thread_id:
+                checkpoint_tables = [
+                    "checkpoints",
+                    "checkpoint_writes",
+                    "checkpoint_blobs"
+                ]
+
+                for table in checkpoint_tables:
+                    try:
+                        conn.execute(
+                            text(f"DELETE FROM {table} WHERE thread_id = :thread_id"),
+                            {"thread_id": thread_id}
+                        )
+                        logger.debug(f"Deleted {table} for thread_id: {thread_id}")
+                    except Exception as e:
+                        logger.debug(f"Could not delete from {table}: {e}")
+
+                result['checkpoints_deleted'] = True
+
+            conn.commit()
+
+    except Exception as e:
+        logger.error(f"Failed to cleanup for conversation_id {conversation_id}: {e}")
+
+    return result
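
Since `ensure_plan_tables` and the cleanup helpers all take a raw connection string and return booleans/dicts instead of raising, wiring them up is a three-call affair. A hedged sketch (the DSN and IDs are placeholders):

from alita_sdk.runtime.tools.planning.models import (
    ensure_plan_tables,
    cleanup_on_graceful_completion,
)

DSN = "postgresql://user:pass@localhost:5432/project_db"  # placeholder

# Idempotent: creates agent_plans only if missing (checkfirst=True above)
if ensure_plan_tables(DSN):
    # ... run the agent; plans are keyed by conversation_id ...
    outcome = cleanup_on_graceful_completion(
        DSN,
        conversation_id="conv-123",
        thread_id="thread-123",      # also clears LangGraph checkpoint rows
        delete_checkpoints=True,
    )
    print(outcome)  # e.g. {'plan_deleted': True, 'checkpoints_deleted': True}
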