hanzo-mcp 0.1.20__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of hanzo-mcp might be problematic. Click here for more details.
- hanzo_mcp/__init__.py +3 -0
- hanzo_mcp/cli.py +213 -0
- hanzo_mcp/server.py +149 -0
- hanzo_mcp/tools/__init__.py +81 -0
- hanzo_mcp/tools/agent/__init__.py +59 -0
- hanzo_mcp/tools/agent/agent_tool.py +474 -0
- hanzo_mcp/tools/agent/prompt.py +137 -0
- hanzo_mcp/tools/agent/tool_adapter.py +75 -0
- hanzo_mcp/tools/common/__init__.py +18 -0
- hanzo_mcp/tools/common/base.py +216 -0
- hanzo_mcp/tools/common/context.py +444 -0
- hanzo_mcp/tools/common/permissions.py +253 -0
- hanzo_mcp/tools/common/thinking_tool.py +123 -0
- hanzo_mcp/tools/common/validation.py +124 -0
- hanzo_mcp/tools/filesystem/__init__.py +89 -0
- hanzo_mcp/tools/filesystem/base.py +113 -0
- hanzo_mcp/tools/filesystem/content_replace.py +287 -0
- hanzo_mcp/tools/filesystem/directory_tree.py +286 -0
- hanzo_mcp/tools/filesystem/edit_file.py +287 -0
- hanzo_mcp/tools/filesystem/get_file_info.py +170 -0
- hanzo_mcp/tools/filesystem/read_files.py +198 -0
- hanzo_mcp/tools/filesystem/search_content.py +275 -0
- hanzo_mcp/tools/filesystem/write_file.py +162 -0
- hanzo_mcp/tools/jupyter/__init__.py +71 -0
- hanzo_mcp/tools/jupyter/base.py +284 -0
- hanzo_mcp/tools/jupyter/edit_notebook.py +295 -0
- hanzo_mcp/tools/jupyter/notebook_operations.py +514 -0
- hanzo_mcp/tools/jupyter/read_notebook.py +165 -0
- hanzo_mcp/tools/project/__init__.py +64 -0
- hanzo_mcp/tools/project/analysis.py +882 -0
- hanzo_mcp/tools/project/base.py +66 -0
- hanzo_mcp/tools/project/project_analyze.py +173 -0
- hanzo_mcp/tools/shell/__init__.py +58 -0
- hanzo_mcp/tools/shell/base.py +148 -0
- hanzo_mcp/tools/shell/command_executor.py +740 -0
- hanzo_mcp/tools/shell/run_command.py +204 -0
- hanzo_mcp/tools/shell/run_script.py +215 -0
- hanzo_mcp/tools/shell/script_tool.py +244 -0
- hanzo_mcp-0.1.20.dist-info/METADATA +111 -0
- hanzo_mcp-0.1.20.dist-info/RECORD +44 -0
- hanzo_mcp-0.1.20.dist-info/WHEEL +5 -0
- hanzo_mcp-0.1.20.dist-info/entry_points.txt +2 -0
- hanzo_mcp-0.1.20.dist-info/licenses/LICENSE +21 -0
- hanzo_mcp-0.1.20.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
"""Common utilities for Hanzo MCP tools."""
|
|
2
|
+
|
|
3
|
+
from mcp.server.fastmcp import FastMCP
|
|
4
|
+
|
|
5
|
+
from hanzo_mcp.tools.common.base import ToolRegistry
|
|
6
|
+
from hanzo_mcp.tools.common.thinking_tool import ThinkingTool
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def register_thinking_tool(
    mcp_server: FastMCP,
) -> None:
    """Register the thinking tool with the MCP server.

    A new ``ThinkingTool`` instance is created on each call and passed to
    ``ToolRegistry.register_tool``, which delegates to the tool's own
    ``register`` method.

    Args:
        mcp_server: The FastMCP server instance
    """
    ToolRegistry.register_tool(mcp_server, ThinkingTool())
|
|
@@ -0,0 +1,216 @@
|
|
|
1
|
+
"""Base classes for Hanzo MCP tools.
|
|
2
|
+
|
|
3
|
+
This module provides abstract base classes that define interfaces and common functionality
|
|
4
|
+
for all tools used in Hanzo MCP. These abstractions help ensure consistent tool
|
|
5
|
+
behavior and provide a foundation for tool registration and management.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from abc import ABC, abstractmethod
|
|
9
|
+
from typing import Any, final
|
|
10
|
+
|
|
11
|
+
from mcp.server.fastmcp import Context as MCPContext
|
|
12
|
+
from mcp.server.fastmcp import FastMCP
|
|
13
|
+
|
|
14
|
+
from hanzo_mcp.tools.common.context import DocumentContext
|
|
15
|
+
from hanzo_mcp.tools.common.permissions import PermissionManager
|
|
16
|
+
from hanzo_mcp.tools.common.validation import ValidationResult, validate_path_parameter
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class BaseTool(ABC):
    """Abstract base class for all Hanzo MCP tools.

    This class defines the core interface that all tools must implement, ensuring
    consistency in how tools are registered, documented, and called.
    """

    @property
    @abstractmethod
    def name(self) -> str:
        """Get the tool name.

        Returns:
            The tool name as it will appear in the MCP server
        """
        pass

    @property
    @abstractmethod
    def description(self) -> str:
        """Get the tool description.

        Returns:
            Detailed description of the tool's purpose and usage
        """
        pass

    @property
    def mcp_description(self) -> str:
        """Get the complete tool description for MCP.

        Combines the stripped ``description`` with two generated sections:

        - an ``Args:`` section built from the JSON-Schema ``properties`` in
          ``parameters``, where parameters not listed in ``required`` are
          marked "(optional)";
        - a generic ``Returns:`` section whose wording is picked from the
          tool name.

        Returns:
            Complete tool description including parameter details
        """
        # Start with the base description
        desc = self.description.strip()

        # Read the parameter spec once; subclasses may rebuild it on every access.
        parameters = self.parameters or {}
        properties = parameters.get("properties") or {}

        # Only add an Args section when there is at least one parameter, so an
        # empty (or None) "properties" entry no longer yields a dangling header.
        if properties:
            desc += "\n\nArgs:"

            # Hoisted out of the loop (the property used to be re-evaluated per
            # parameter); tolerate a None required list.
            required_names = set(self.required or ())

            for param_name, param_info in properties.items():
                # Prefer the schema title; otherwise convert snake_case to Title Case
                if "title" in param_info:
                    title = param_info["title"]
                else:
                    title = " ".join(word.capitalize() for word in param_name.split("_"))

                required_text = "" if param_name in required_names else " (optional)"
                desc += f"\n {param_name}: {title}{required_text}"

        # Add Returns section
        desc += "\n\nReturns:\n "

        # Generic return description chosen by substrings of the tool name.
        # NOTE(review): plain substring matching means e.g. "thread" counts as
        # "read" and "target" as "get"; consider matching "_"-separated words.
        name = self.name
        if "read" in name or "get" in name or "search" in name:
            desc += f"{name.replace('_', ' ').capitalize()} results"
        elif "write" in name or "edit" in name or "create" in name:
            desc += "Result of the operation"
        else:
            desc += "Tool execution results"

        return desc

    @property
    @abstractmethod
    def parameters(self) -> dict[str, Any]:
        """Get the parameter specifications for the tool.

        Returns:
            Dictionary containing parameter specifications in JSON Schema format
        """
        pass

    @property
    @abstractmethod
    def required(self) -> list[str]:
        """Get the list of required parameter names.

        Returns:
            List of parameter names that are required for the tool
        """
        pass

    @abstractmethod
    async def call(self, ctx: MCPContext, **params: Any) -> str:
        """Execute the tool with the given parameters.

        Args:
            ctx: MCP context for the tool call
            **params: Tool parameters provided by the caller

        Returns:
            Tool execution result as a string
        """
        pass

    @abstractmethod
    def register(self, mcp_server: FastMCP) -> None:
        """Register this tool with the MCP server.

        This method must be implemented by each tool class to create a wrapper function
        with explicitly defined parameters that calls this tool's call method.
        The wrapper function is then registered with the MCP server.

        Args:
            mcp_server: The FastMCP server instance
        """
        pass
|
|
141
|
+
|
|
142
|
+
|
|
143
|
+
class FileSystemTool(BaseTool, ABC):
    """Shared base for tools that work with files and directories.

    Holds the document context and permission manager that concrete
    filesystem tools rely on, and exposes thin helpers for validating path
    parameters and checking access permissions.
    """

    def __init__(
        self,
        document_context: DocumentContext,
        permission_manager: PermissionManager,
    ) -> None:
        """Store the collaborators used by filesystem tools.

        Args:
            document_context: Document context for tracking file contents
            permission_manager: Permission manager for access control
        """
        self.document_context: DocumentContext = document_context
        self.permission_manager: PermissionManager = permission_manager

    def validate_path(self, path: str, param_name: str = "path") -> ValidationResult:
        """Run the shared path-parameter validation on ``path``.

        Args:
            path: Path to validate
            param_name: Name of the parameter, used in error messages

        Returns:
            The ``ValidationResult`` produced by ``validate_path_parameter``
        """
        result = validate_path_parameter(path, param_name)
        return result

    def is_path_allowed(self, path: str) -> bool:
        """Ask the permission manager whether ``path`` may be accessed.

        Args:
            path: Path to check

        Returns:
            True if the permission manager allows the path, False otherwise
        """
        manager = self.permission_manager
        return manager.is_path_allowed(path)
|
|
186
|
+
|
|
187
|
+
|
|
188
|
+
@final
class ToolRegistry:
    """Registry helpers for attaching Hanzo MCP tools to a server.

    Registration is delegated to each tool's own ``register`` method, which
    knows how to expose the tool as an MCP tool function.
    """

    @staticmethod
    def register_tool(mcp_server: FastMCP, tool: BaseTool) -> None:
        """Attach a single tool to the MCP server.

        Args:
            mcp_server: The FastMCP server instance
            tool: The tool to register
        """
        # The tool owns the details of how it is exposed to the server.
        register = tool.register
        register(mcp_server)

    @staticmethod
    def register_tools(mcp_server: FastMCP, tools: list[BaseTool]) -> None:
        """Attach every tool in ``tools`` to the MCP server, in order.

        Args:
            mcp_server: The FastMCP server instance
            tools: List of tools to register
        """
        for each_tool in tools:
            ToolRegistry.register_tool(mcp_server, each_tool)
|
|
@@ -0,0 +1,444 @@
|
|
|
1
|
+
"""Enhanced Context for Hanzo MCP tools.
|
|
2
|
+
|
|
3
|
+
This module provides an enhanced Context class that wraps the MCP Context
|
|
4
|
+
and adds additional functionality specific to Hanzo tools.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import json
import os
import time
import weakref
from collections.abc import Iterable
from pathlib import Path
from typing import Any, ClassVar, final

from mcp.server.fastmcp import Context as MCPContext
from mcp.server.lowlevel.helper_types import ReadResourceContents
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
@final
class ToolContext:
    """Enhanced context for Hanzo MCP tools.

    This class wraps the MCP Context and adds additional functionality
    for tracking tool execution, progress reporting, and resource access.
    Log messages are prefixed with the current tool name (and execution ID,
    when set).
    """

    # Track all live contexts for debugging. This must be a WeakSet: a plain
    # ``set`` held a strong reference to every instance, so contexts (and the
    # MCP contexts they wrap) were never garbage collected, and the old
    # ``__del__`` that removed them could therefore never run. The WeakSet
    # drops an entry automatically once its context is collected.
    _active_contexts: ClassVar[weakref.WeakSet["ToolContext"]] = weakref.WeakSet()

    def __init__(self, mcp_context: MCPContext) -> None:
        """Initialize the tool context.

        Args:
            mcp_context: The underlying MCP Context
        """
        self._mcp_context: MCPContext = mcp_context
        self._tool_name: str | None = None
        self._execution_id: str | None = None

        # Add to active contexts (weakly referenced, see _active_contexts)
        ToolContext._active_contexts.add(self)

    @property
    def mcp_context(self) -> MCPContext:
        """Get the underlying MCP Context.

        Returns:
            The MCP Context
        """
        return self._mcp_context

    @property
    def request_id(self) -> str:
        """Get the request ID from the MCP context.

        Returns:
            The request ID
        """
        return self._mcp_context.request_id

    @property
    def client_id(self) -> str | None:
        """Get the client ID from the MCP context.

        Returns:
            The client ID
        """
        return self._mcp_context.client_id

    def set_tool_info(self, tool_name: str, execution_id: str | None = None) -> None:
        """Set information about the currently executing tool.

        Both values are used as a prefix by the logging helpers.

        Args:
            tool_name: The name of the tool being executed
            execution_id: Optional unique execution ID
        """
        self._tool_name = tool_name
        self._execution_id = execution_id

    async def info(self, message: str) -> None:
        """Log an informational message.

        Args:
            message: The message to log
        """
        await self._mcp_context.info(self._format_message(message))

    async def debug(self, message: str) -> None:
        """Log a debug message.

        Args:
            message: The message to log
        """
        await self._mcp_context.debug(self._format_message(message))

    async def warning(self, message: str) -> None:
        """Log a warning message.

        Args:
            message: The message to log
        """
        await self._mcp_context.warning(self._format_message(message))

    async def error(self, message: str) -> None:
        """Log an error message.

        Args:
            message: The message to log
        """
        await self._mcp_context.error(self._format_message(message))

    def _format_message(self, message: str) -> str:
        """Format a message with tool information if available.

        Args:
            message: The original message

        Returns:
            ``"[tool:execution_id] message"``, ``"[tool] message"``, or the
            message unchanged when no tool name has been set
        """
        if self._tool_name:
            if self._execution_id:
                return f"[{self._tool_name}:{self._execution_id}] {message}"
            return f"[{self._tool_name}] {message}"
        return message

    async def report_progress(self, current: int, total: int) -> None:
        """Report progress to the client.

        Args:
            current: Current progress value
            total: Total progress value
        """
        await self._mcp_context.report_progress(current, total)

    async def read_resource(self, uri: str) -> Iterable[ReadResourceContents]:
        """Read a resource via the MCP protocol.

        Args:
            uri: The resource URI

        Returns:
            The resource contents as returned by the underlying MCP context's
            ``read_resource`` (an iterable of ``ReadResourceContents``)
        """
        return await self._mcp_context.read_resource(uri)
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
# Factory: wrap an MCP Context in a new ToolContext
def create_tool_context(mcp_context: MCPContext) -> ToolContext:
    """Wrap ``mcp_context`` in a fresh ``ToolContext``.

    Args:
        mcp_context: The MCP Context to wrap

    Returns:
        A newly constructed ToolContext around ``mcp_context``
    """
    tool_context = ToolContext(mcp_context)
    return tool_context
|
|
163
|
+
|
|
164
|
+
|
|
165
|
+
@final
|
|
166
|
+
class DocumentContext:
|
|
167
|
+
"""Manages document context and codebase understanding."""
|
|
168
|
+
|
|
169
|
+
def __init__(self) -> None:
|
|
170
|
+
"""Initialize the document context."""
|
|
171
|
+
self.documents: dict[str, str] = {}
|
|
172
|
+
self.document_metadata: dict[str, dict[str, Any]] = {}
|
|
173
|
+
self.modified_times: dict[str, float] = {}
|
|
174
|
+
self.allowed_paths: set[Path] = set()
|
|
175
|
+
|
|
176
|
+
def add_allowed_path(self, path: str) -> None:
|
|
177
|
+
"""Add a path to the allowed paths.
|
|
178
|
+
|
|
179
|
+
Args:
|
|
180
|
+
path: The path to allow
|
|
181
|
+
"""
|
|
182
|
+
resolved_path: Path = Path(path).resolve()
|
|
183
|
+
self.allowed_paths.add(resolved_path)
|
|
184
|
+
|
|
185
|
+
def is_path_allowed(self, path: str) -> bool:
|
|
186
|
+
"""Check if a path is allowed.
|
|
187
|
+
|
|
188
|
+
Args:
|
|
189
|
+
path: The path to check
|
|
190
|
+
|
|
191
|
+
Returns:
|
|
192
|
+
True if the path is allowed, False otherwise
|
|
193
|
+
"""
|
|
194
|
+
resolved_path: Path = Path(path).resolve()
|
|
195
|
+
|
|
196
|
+
# Check if the path is within any allowed path
|
|
197
|
+
for allowed_path in self.allowed_paths:
|
|
198
|
+
try:
|
|
199
|
+
_ = resolved_path.relative_to(allowed_path)
|
|
200
|
+
return True
|
|
201
|
+
except ValueError:
|
|
202
|
+
continue
|
|
203
|
+
|
|
204
|
+
return False
|
|
205
|
+
|
|
206
|
+
def add_document(
|
|
207
|
+
self, path: str, content: str, metadata: dict[str, Any] | None = None
|
|
208
|
+
) -> None:
|
|
209
|
+
"""Add a document to the context.
|
|
210
|
+
|
|
211
|
+
Args:
|
|
212
|
+
path: The path of the document
|
|
213
|
+
content: The content of the document
|
|
214
|
+
metadata: Optional metadata about the document
|
|
215
|
+
"""
|
|
216
|
+
self.documents[path] = content
|
|
217
|
+
self.modified_times[path] = time.time()
|
|
218
|
+
|
|
219
|
+
if metadata:
|
|
220
|
+
self.document_metadata[path] = metadata
|
|
221
|
+
else:
|
|
222
|
+
# Try to infer metadata
|
|
223
|
+
self.document_metadata[path] = self._infer_metadata(path, content)
|
|
224
|
+
|
|
225
|
+
def get_document(self, path: str) -> str | None:
|
|
226
|
+
"""Get a document from the context.
|
|
227
|
+
|
|
228
|
+
Args:
|
|
229
|
+
path: The path of the document
|
|
230
|
+
|
|
231
|
+
Returns:
|
|
232
|
+
The document content, or None if not found
|
|
233
|
+
"""
|
|
234
|
+
return self.documents.get(path)
|
|
235
|
+
|
|
236
|
+
def get_document_metadata(self, path: str) -> dict[str, Any] | None:
|
|
237
|
+
"""Get document metadata.
|
|
238
|
+
|
|
239
|
+
Args:
|
|
240
|
+
path: The path of the document
|
|
241
|
+
|
|
242
|
+
Returns:
|
|
243
|
+
The document metadata, or None if not found
|
|
244
|
+
"""
|
|
245
|
+
return self.document_metadata.get(path)
|
|
246
|
+
|
|
247
|
+
def update_document(self, path: str, content: str) -> None:
|
|
248
|
+
"""Update a document in the context.
|
|
249
|
+
|
|
250
|
+
Args:
|
|
251
|
+
path: The path of the document
|
|
252
|
+
content: The new content of the document
|
|
253
|
+
"""
|
|
254
|
+
self.documents[path] = content
|
|
255
|
+
self.modified_times[path] = time.time()
|
|
256
|
+
|
|
257
|
+
# Update metadata
|
|
258
|
+
self.document_metadata[path] = self._infer_metadata(path, content)
|
|
259
|
+
|
|
260
|
+
def remove_document(self, path: str) -> None:
|
|
261
|
+
"""Remove a document from the context.
|
|
262
|
+
|
|
263
|
+
Args:
|
|
264
|
+
path: The path of the document
|
|
265
|
+
"""
|
|
266
|
+
if path in self.documents:
|
|
267
|
+
del self.documents[path]
|
|
268
|
+
|
|
269
|
+
if path in self.document_metadata:
|
|
270
|
+
del self.document_metadata[path]
|
|
271
|
+
|
|
272
|
+
if path in self.modified_times:
|
|
273
|
+
del self.modified_times[path]
|
|
274
|
+
|
|
275
|
+
def _infer_metadata(self, path: str, content: str) -> dict[str, Any]:
|
|
276
|
+
"""Infer metadata about a document.
|
|
277
|
+
|
|
278
|
+
Args:
|
|
279
|
+
path: The path of the document
|
|
280
|
+
content: The content of the document
|
|
281
|
+
|
|
282
|
+
Returns:
|
|
283
|
+
Inferred metadata
|
|
284
|
+
"""
|
|
285
|
+
extension: str = Path(path).suffix.lower()
|
|
286
|
+
|
|
287
|
+
metadata: dict[str, Any] = {
|
|
288
|
+
"extension": extension,
|
|
289
|
+
"size": len(content),
|
|
290
|
+
"line_count": content.count("\n") + 1,
|
|
291
|
+
}
|
|
292
|
+
|
|
293
|
+
# Infer language based on extension
|
|
294
|
+
language_map: dict[str, list[str]] = {
|
|
295
|
+
"python": [".py"],
|
|
296
|
+
"javascript": [".js", ".jsx"],
|
|
297
|
+
"typescript": [".ts", ".tsx"],
|
|
298
|
+
"java": [".java"],
|
|
299
|
+
"c++": [".c", ".cpp", ".h", ".hpp"],
|
|
300
|
+
"go": [".go"],
|
|
301
|
+
"rust": [".rs"],
|
|
302
|
+
"ruby": [".rb"],
|
|
303
|
+
"php": [".php"],
|
|
304
|
+
"html": [".html", ".htm"],
|
|
305
|
+
"css": [".css"],
|
|
306
|
+
"markdown": [".md"],
|
|
307
|
+
"json": [".json"],
|
|
308
|
+
"yaml": [".yaml", ".yml"],
|
|
309
|
+
"xml": [".xml"],
|
|
310
|
+
"sql": [".sql"],
|
|
311
|
+
"shell": [".sh", ".bash"],
|
|
312
|
+
}
|
|
313
|
+
|
|
314
|
+
# Find matching language
|
|
315
|
+
for language, extensions in language_map.items():
|
|
316
|
+
if extension in extensions:
|
|
317
|
+
metadata["language"] = language
|
|
318
|
+
break
|
|
319
|
+
else:
|
|
320
|
+
metadata["language"] = "text"
|
|
321
|
+
|
|
322
|
+
return metadata
|
|
323
|
+
|
|
324
|
+
def load_directory(
|
|
325
|
+
self,
|
|
326
|
+
directory: str,
|
|
327
|
+
recursive: bool = True,
|
|
328
|
+
exclude_patterns: list[str] | None = None,
|
|
329
|
+
) -> None:
|
|
330
|
+
"""Load all files in a directory into the context.
|
|
331
|
+
|
|
332
|
+
Args:
|
|
333
|
+
directory: The directory to load
|
|
334
|
+
recursive: Whether to load subdirectories
|
|
335
|
+
exclude_patterns: Patterns to exclude
|
|
336
|
+
"""
|
|
337
|
+
if not self.is_path_allowed(directory):
|
|
338
|
+
raise ValueError(f"Directory not allowed: {directory}")
|
|
339
|
+
|
|
340
|
+
dir_path: Path = Path(directory)
|
|
341
|
+
|
|
342
|
+
if not dir_path.exists() or not dir_path.is_dir():
|
|
343
|
+
raise ValueError(f"Not a valid directory: {directory}")
|
|
344
|
+
|
|
345
|
+
if exclude_patterns is None:
|
|
346
|
+
exclude_patterns = []
|
|
347
|
+
|
|
348
|
+
# Common directories and files to exclude
|
|
349
|
+
default_excludes: list[str] = [
|
|
350
|
+
"__pycache__",
|
|
351
|
+
".git",
|
|
352
|
+
".github",
|
|
353
|
+
".ssh",
|
|
354
|
+
".gnupg",
|
|
355
|
+
".config",
|
|
356
|
+
"node_modules",
|
|
357
|
+
"__pycache__",
|
|
358
|
+
".venv",
|
|
359
|
+
"venv",
|
|
360
|
+
"env",
|
|
361
|
+
".idea",
|
|
362
|
+
".vscode",
|
|
363
|
+
".DS_Store",
|
|
364
|
+
]
|
|
365
|
+
|
|
366
|
+
exclude_patterns.extend(default_excludes)
|
|
367
|
+
|
|
368
|
+
def should_exclude(path: Path) -> bool:
|
|
369
|
+
"""Check if a path should be excluded.
|
|
370
|
+
|
|
371
|
+
Args:
|
|
372
|
+
path: The path to check
|
|
373
|
+
|
|
374
|
+
Returns:
|
|
375
|
+
True if the path should be excluded, False otherwise
|
|
376
|
+
"""
|
|
377
|
+
for pattern in exclude_patterns:
|
|
378
|
+
if pattern.startswith("*"):
|
|
379
|
+
if path.name.endswith(pattern[1:]):
|
|
380
|
+
return True
|
|
381
|
+
elif pattern in str(path):
|
|
382
|
+
return True
|
|
383
|
+
return False
|
|
384
|
+
|
|
385
|
+
# Walk the directory
|
|
386
|
+
for root, dirs, files in os.walk(dir_path):
|
|
387
|
+
# Skip excluded directories
|
|
388
|
+
dirs[:] = [d for d in dirs if not should_exclude(Path(root) / d)]
|
|
389
|
+
|
|
390
|
+
# Process files
|
|
391
|
+
for file in files:
|
|
392
|
+
file_path: Path = Path(root) / file
|
|
393
|
+
|
|
394
|
+
if should_exclude(file_path):
|
|
395
|
+
continue
|
|
396
|
+
|
|
397
|
+
try:
|
|
398
|
+
with open(file_path, "r", encoding="utf-8") as f:
|
|
399
|
+
content: str = f.read()
|
|
400
|
+
|
|
401
|
+
# Add to context
|
|
402
|
+
self.add_document(str(file_path), content)
|
|
403
|
+
except UnicodeDecodeError:
|
|
404
|
+
# Skip binary files
|
|
405
|
+
continue
|
|
406
|
+
|
|
407
|
+
# Stop if not recursive
|
|
408
|
+
if not recursive:
|
|
409
|
+
break
|
|
410
|
+
|
|
411
|
+
def to_json(self) -> str:
|
|
412
|
+
"""Convert the context to a JSON string.
|
|
413
|
+
|
|
414
|
+
Returns:
|
|
415
|
+
A JSON string representation of the context
|
|
416
|
+
"""
|
|
417
|
+
data: dict[str, Any] = {
|
|
418
|
+
"documents": self.documents,
|
|
419
|
+
"metadata": self.document_metadata,
|
|
420
|
+
"modified_times": self.modified_times,
|
|
421
|
+
"allowed_paths": [str(p) for p in self.allowed_paths],
|
|
422
|
+
}
|
|
423
|
+
|
|
424
|
+
return json.dumps(data)
|
|
425
|
+
|
|
426
|
+
@classmethod
|
|
427
|
+
def from_json(cls, json_str: str) -> "DocumentContext":
|
|
428
|
+
"""Create a context from a JSON string.
|
|
429
|
+
|
|
430
|
+
Args:
|
|
431
|
+
json_str: The JSON string
|
|
432
|
+
|
|
433
|
+
Returns:
|
|
434
|
+
A new DocumentContext instance
|
|
435
|
+
"""
|
|
436
|
+
data: dict[str, Any] = json.loads(json_str)
|
|
437
|
+
|
|
438
|
+
context = cls()
|
|
439
|
+
context.documents = data.get("documents", {})
|
|
440
|
+
context.document_metadata = data.get("metadata", {})
|
|
441
|
+
context.modified_times = data.get("modified_times", {})
|
|
442
|
+
context.allowed_paths = set(Path(p) for p in data.get("allowed_paths", []))
|
|
443
|
+
|
|
444
|
+
return context
|