hanzo-mcp 0.3.4__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of hanzo-mcp might be problematic. Click here for more details.

Files changed (87)
  1. hanzo_mcp/__init__.py +1 -1
  2. hanzo_mcp/cli.py +123 -160
  3. hanzo_mcp/cli_enhanced.py +438 -0
  4. hanzo_mcp/config/__init__.py +19 -0
  5. hanzo_mcp/config/settings.py +388 -0
  6. hanzo_mcp/config/tool_config.py +197 -0
  7. hanzo_mcp/prompts/__init__.py +117 -0
  8. hanzo_mcp/prompts/compact_conversation.py +77 -0
  9. hanzo_mcp/prompts/create_release.py +38 -0
  10. hanzo_mcp/prompts/project_system.py +120 -0
  11. hanzo_mcp/prompts/project_todo_reminder.py +111 -0
  12. hanzo_mcp/prompts/utils.py +286 -0
  13. hanzo_mcp/server.py +120 -98
  14. hanzo_mcp/tools/__init__.py +107 -31
  15. hanzo_mcp/tools/agent/__init__.py +8 -11
  16. hanzo_mcp/tools/agent/agent_tool.py +290 -224
  17. hanzo_mcp/tools/agent/prompt.py +16 -13
  18. hanzo_mcp/tools/agent/tool_adapter.py +9 -9
  19. hanzo_mcp/tools/common/__init__.py +17 -16
  20. hanzo_mcp/tools/common/base.py +79 -110
  21. hanzo_mcp/tools/common/batch_tool.py +330 -0
  22. hanzo_mcp/tools/common/context.py +26 -292
  23. hanzo_mcp/tools/common/permissions.py +12 -12
  24. hanzo_mcp/tools/common/thinking_tool.py +153 -0
  25. hanzo_mcp/tools/common/validation.py +1 -63
  26. hanzo_mcp/tools/filesystem/__init__.py +88 -41
  27. hanzo_mcp/tools/filesystem/base.py +32 -24
  28. hanzo_mcp/tools/filesystem/content_replace.py +114 -107
  29. hanzo_mcp/tools/filesystem/directory_tree.py +129 -105
  30. hanzo_mcp/tools/filesystem/edit.py +279 -0
  31. hanzo_mcp/tools/filesystem/grep.py +458 -0
  32. hanzo_mcp/tools/filesystem/grep_ast_tool.py +250 -0
  33. hanzo_mcp/tools/filesystem/multi_edit.py +362 -0
  34. hanzo_mcp/tools/filesystem/read.py +255 -0
  35. hanzo_mcp/tools/filesystem/write.py +156 -0
  36. hanzo_mcp/tools/jupyter/__init__.py +41 -29
  37. hanzo_mcp/tools/jupyter/base.py +66 -57
  38. hanzo_mcp/tools/jupyter/{edit_notebook.py → notebook_edit.py} +162 -139
  39. hanzo_mcp/tools/jupyter/notebook_read.py +152 -0
  40. hanzo_mcp/tools/shell/__init__.py +29 -20
  41. hanzo_mcp/tools/shell/base.py +87 -45
  42. hanzo_mcp/tools/shell/bash_session.py +731 -0
  43. hanzo_mcp/tools/shell/bash_session_executor.py +295 -0
  44. hanzo_mcp/tools/shell/command_executor.py +435 -384
  45. hanzo_mcp/tools/shell/run_command.py +284 -131
  46. hanzo_mcp/tools/shell/run_command_windows.py +328 -0
  47. hanzo_mcp/tools/shell/session_manager.py +196 -0
  48. hanzo_mcp/tools/shell/session_storage.py +325 -0
  49. hanzo_mcp/tools/todo/__init__.py +66 -0
  50. hanzo_mcp/tools/todo/base.py +319 -0
  51. hanzo_mcp/tools/todo/todo_read.py +148 -0
  52. hanzo_mcp/tools/todo/todo_write.py +378 -0
  53. hanzo_mcp/tools/vector/__init__.py +95 -0
  54. hanzo_mcp/tools/vector/infinity_store.py +365 -0
  55. hanzo_mcp/tools/vector/project_manager.py +361 -0
  56. hanzo_mcp/tools/vector/vector_index.py +115 -0
  57. hanzo_mcp/tools/vector/vector_search.py +215 -0
  58. {hanzo_mcp-0.3.4.dist-info → hanzo_mcp-0.5.0.dist-info}/METADATA +35 -3
  59. hanzo_mcp-0.5.0.dist-info/RECORD +63 -0
  60. {hanzo_mcp-0.3.4.dist-info → hanzo_mcp-0.5.0.dist-info}/WHEEL +1 -1
  61. hanzo_mcp/tools/agent/base_provider.py +0 -73
  62. hanzo_mcp/tools/agent/litellm_provider.py +0 -45
  63. hanzo_mcp/tools/agent/lmstudio_agent.py +0 -385
  64. hanzo_mcp/tools/agent/lmstudio_provider.py +0 -219
  65. hanzo_mcp/tools/agent/provider_registry.py +0 -120
  66. hanzo_mcp/tools/common/error_handling.py +0 -86
  67. hanzo_mcp/tools/common/logging_config.py +0 -115
  68. hanzo_mcp/tools/common/session.py +0 -91
  69. hanzo_mcp/tools/common/think_tool.py +0 -123
  70. hanzo_mcp/tools/common/version_tool.py +0 -120
  71. hanzo_mcp/tools/filesystem/edit_file.py +0 -287
  72. hanzo_mcp/tools/filesystem/get_file_info.py +0 -170
  73. hanzo_mcp/tools/filesystem/read_files.py +0 -198
  74. hanzo_mcp/tools/filesystem/search_content.py +0 -275
  75. hanzo_mcp/tools/filesystem/write_file.py +0 -162
  76. hanzo_mcp/tools/jupyter/notebook_operations.py +0 -514
  77. hanzo_mcp/tools/jupyter/read_notebook.py +0 -165
  78. hanzo_mcp/tools/project/__init__.py +0 -64
  79. hanzo_mcp/tools/project/analysis.py +0 -882
  80. hanzo_mcp/tools/project/base.py +0 -66
  81. hanzo_mcp/tools/project/project_analyze.py +0 -173
  82. hanzo_mcp/tools/shell/run_script.py +0 -215
  83. hanzo_mcp/tools/shell/script_tool.py +0 -244
  84. hanzo_mcp-0.3.4.dist-info/RECORD +0 -53
  85. {hanzo_mcp-0.3.4.dist-info → hanzo_mcp-0.5.0.dist-info}/entry_points.txt +0 -0
  86. {hanzo_mcp-0.3.4.dist-info → hanzo_mcp-0.5.0.dist-info}/licenses/LICENSE +0 -0
  87. {hanzo_mcp-0.3.4.dist-info → hanzo_mcp-0.5.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,330 @@
1
+ """Batch tool implementation for Hanzo MCP.
2
+
3
+ This module provides the BatchTool that allows for executing multiple tools in
4
+ parallel or serial depending on their characteristics.
5
+ """
6
+
7
+ import asyncio
8
+ from typing import Annotated, Any, TypedDict, Unpack, final, override
9
+
10
+ from fastmcp import Context as MCPContext
11
+ from fastmcp import FastMCP
12
+ from fastmcp.server.dependencies import get_context
13
+ from pydantic import Field
14
+
15
+ from hanzo_mcp.tools.common.base import BaseTool
16
+ from hanzo_mcp.tools.common.context import create_tool_context
17
+
18
+
19
class InvocationItem(TypedDict):
    """One entry of a batch request.

    Attributes:
        tool_name: Name of the tool to invoke (must be non-empty)
        input: Arguments for that tool, keyed by parameter name
    """

    tool_name: Annotated[
        str,
        Field(min_length=1, description="The name of the tool to invoke"),
    ]
    input: Annotated[
        dict[str, Any],
        Field(description="The input to pass to the tool"),
    ]
40
+
41
+
42
# Annotated parameter type for the human-readable label of a batch call.
Description = Annotated[
    str,
    Field(
        min_length=1,
        description="A short (3-5 word) description of the batch operation",
    ),
]

# Annotated parameter type for the (non-empty) list of tool invocations.
Invocations = Annotated[
    list[InvocationItem],
    Field(
        min_length=1,
        description="The list of tool invocations to execute (required -- you MUST provide at least one tool invocation)",
    ),
]
57
+
58
+
59
class BatchToolParams(TypedDict):
    """Keyword parameters accepted by ``BatchTool.call``.

    Attributes:
        description: A short (3-5 word) description of the batch operation
        invocations: The tool invocations to execute; at least one is required
    """

    description: Description
    invocations: Invocations
69
+
70
+
71
@final
class BatchTool(BaseTool):
    """Tool for executing multiple tools in a single request.

    Every valid invocation is scheduled as its own asyncio task, so the
    underlying tool calls run concurrently. Invalid invocations produce an
    error entry instead of a task. Results are reported in the same order in
    which the invocations were supplied.
    """

    @property
    @override
    def name(self) -> str:
        """Get the tool name.

        Returns:
            Tool name
        """
        return "batch"

    @property
    @override
    def description(self) -> str:
        """Get the tool description.

        Returns:
            Tool description
        """
        return """Batch execution tool that runs multiple tool invocations in a single request.

Tools are executed in parallel when possible, and otherwise serially.
Takes a list of tool invocations (tool_name and input pairs).
Returns the collected results from all invocations.
Use this tool when you need to run multiple independent tool operations at once -- it is awesome for speeding up your workflow, reducing both context usage and latency.
Each tool will respect its own permissions and validation rules.
The tool's outputs are NOT shown to the user; to answer the user's query, you MUST send a message with the results after the tool call completes, otherwise the user will not see the results.

<batch_example>
When dispatching multiple agents to find necessary information.
batch(
  description="Update import statements across modules",
  invocations=[
    {tool_name: "dispatch_agent", input: {prompt: "Search for all instances of 'logger' configuration in /app/config directory"}},
    {tool_name: "dispatch_agent", input: {prompt: "Find all test files that reference 'UserService' in /app/tests"}},
  ]
)
</batch_example>

Common scenarios for effective batching:
1. Reading multiple related files in one operation
2. Performing a series of simple mechanical changes
3. Running multiple diagnostic commands
4. Dispatch multiple agents to complete the task

To make a batch call, provide the following:
1. description: A short (3-5 word) description of the batch operation
2. invocations: List of invocation [{"tool_name": "...", "input": "..."}], tool_name: The name of the tool to invoke, input: The input to pass to the tool


Available tools in batch call:
Tool: dispatch_agent,read,directory_tree,grep,grep_ast,run_command,notebook_read
Not available: think,write,edit,multi_edit,notebook_edit
"""

    def __init__(self, tools: dict[str, BaseTool]) -> None:
        """Initialize the batch tool.

        Args:
            tools: Dictionary mapping tool names to tool instances
        """
        self.tools = tools

    @staticmethod
    def _error_result(invocation: Any, error_message: str) -> dict[str, Any]:
        """Build the result entry reported for an invocation that failed."""
        return {"invocation": invocation, "result": f"Error: {error_message}"}

    def _validate_invocation(self, index: int, invocation: Any) -> str | None:
        """Return an error message for an unusable invocation, else None.

        Args:
            index: Position of the invocation in the batch (used in messages)
            invocation: The raw invocation item supplied by the caller
        """
        if not isinstance(invocation, dict):
            return f"Invocation {index} must be an object with 'tool_name' and 'input'"

        tool_name = invocation.get("tool_name", "")
        if not tool_name or not isinstance(tool_name, str):
            return f"Tool name is required in invocation {index}"
        if tool_name not in self.tools:
            return f"Tool '{tool_name}' not found"
        return None

    async def _execute_tool(
        self,
        ctx: MCPContext,
        tool_ctx: Any,
        tool_obj: BaseTool,
        tool_name: str,
        tool_input: dict[str, Any],
    ) -> dict[str, Any]:
        """Run one tool and wrap its output (or its failure) in a result entry.

        Exceptions raised by the tool are converted into an error result so a
        single failing invocation does not abort the rest of the batch.
        """
        invocation = {"tool_name": tool_name, "input": tool_input}
        try:
            await tool_ctx.info(f"Executing tool: {tool_name}")
            result = await tool_obj.call(ctx, **tool_input)
            await tool_ctx.info(f"Tool '{tool_name}' execution completed")
        except Exception as e:
            error_message = f"Error executing tool '{tool_name}': {str(e)}"
            await tool_ctx.error(error_message)
            return self._error_result(invocation, error_message)
        return {"invocation": invocation, "result": result}

    @override
    async def call(
        self,
        ctx: MCPContext,
        **params: Unpack[BatchToolParams],
    ) -> str:
        """Execute the tool with the given parameters.

        Args:
            ctx: MCP context
            **params: Tool parameters (``description`` and ``invocations``)

        Returns:
            The formatted results of all invocations, in invocation order, or
            an ``Error: ...`` string when a required parameter is missing.
        """
        tool_ctx = create_tool_context(ctx)
        tool_ctx.set_tool_info(self.name)

        # Extract parameters
        description = params.get("description")
        invocations: list[dict[str, Any]] = params.get("invocations", list())

        # Validate required parameters
        if not description:
            await tool_ctx.error(
                "Parameter 'description' is required but was None or empty"
            )
            return "Error: Parameter 'description' is required but was None or empty"

        if not invocations:
            await tool_ctx.error(
                "Parameter 'invocations' is required but was None or empty"
            )
            return "Error: Parameter 'invocations' is required but was None or empty"

        # Emptiness was rejected above, so only the container type is left to check.
        if not isinstance(invocations, list):
            await tool_ctx.error("Parameter 'invocations' must be a non-empty list")
            return "Error: Parameter 'invocations' must be a non-empty list"

        await tool_ctx.info(
            f"Executing batch operation: {description} ({len(invocations)} invocations)"
        )

        # Exactly one slot per invocation, in invocation order: either a
        # ready-made error entry (dict) or the task that is running the tool.
        slots: list[dict[str, Any] | asyncio.Task[dict[str, Any]]] = []

        for i, invocation in enumerate(invocations):
            error_message = self._validate_invocation(i, invocation)
            if error_message is not None:
                await tool_ctx.error(error_message)
                slots.append(self._error_result(invocation, error_message))
                continue

            tool_name: str = invocation["tool_name"]
            tool_input: dict[str, Any] = invocation.get("input", {})
            await tool_ctx.info(f"Creating task for tool: {tool_name}")

            try:
                # Scheduling starts the tool right away; tasks run concurrently.
                task = asyncio.create_task(
                    self._execute_tool(
                        ctx, tool_ctx, self.tools[tool_name], tool_name, tool_input
                    )
                )
            except Exception as e:
                error_message = f"Error scheduling tool '{tool_name}': {str(e)}"
                await tool_ctx.error(error_message)
                slots.append(self._error_result(invocation, error_message))
            else:
                slots.append(task)

        await tool_ctx.info(f"Waiting for {len(slots)} tool executions to complete")
        results: list[dict[str, Any]] = []

        # Await in invocation order. The tasks are already running, so this
        # costs no parallelism, keeps "Result N" aligned with invocation N,
        # and lets a failure be attributed to its invocation directly.
        for invocation, slot in zip(invocations, slots):
            if isinstance(slot, dict):
                results.append(slot)
                continue
            try:
                results.append(await slot)
            except Exception as e:
                failed_name = invocation.get("tool_name", "unknown")
                error_message = f"Unexpected error in tool '{failed_name}': {str(e)}"
                await tool_ctx.error(error_message)
                results.append(self._error_result(invocation, error_message))

        # Format the results
        formatted_results = self._format_results(results)
        await tool_ctx.info(
            f"Batch operation '{description}' completed with {len(results)} results"
        )

        return formatted_results

    def _format_results(self, results: list[dict[str, Any]]) -> str:
        """Format the results from multiple tool invocations.

        Args:
            results: List of entries with an ``invocation`` and a ``result`` key

        Returns:
            Formatted results string: one ``### Result N: <tool>`` header per
            entry followed by its output and a blank separator line.
        """
        formatted_parts: list[str] = []
        for i, result in enumerate(results):
            invocation = result["invocation"]
            # A rejected invocation may not be a dict at all.
            tool_name = (
                invocation.get("tool_name", "unknown")
                if isinstance(invocation, dict)
                else "unknown"
            )
            content = result["result"]
            if not isinstance(content, str):
                content = str(content)

            # Add the result header
            formatted_parts.append(f"### Result {i + 1}: {tool_name}")
            # Multi-line output goes into a fenced block; one-liners stay inline
            if "\n" in content:
                formatted_parts.append(f"```\n{content}\n```")
            else:
                formatted_parts.append(content)
            # Add a separator
            formatted_parts.append("")

        return "\n".join(formatted_parts)

    @override
    def register(self, mcp_server: FastMCP) -> None:
        """Register this batch tool with the MCP server.

        Creates a wrapper function with explicitly defined parameters that match
        the tool's parameter schema and registers it with the MCP server.

        Args:
            mcp_server: The FastMCP server instance
        """
        tool_self = self  # Create a reference to self for use in the closure

        @mcp_server.tool(name=self.name, description=self.description)
        async def batch(
            ctx: MCPContext,
            description: Description,
            invocations: Invocations,
        ) -> str:
            # The ctx argument is replaced by whatever get_context() returns.
            ctx = get_context()
            return await tool_self.call(
                ctx, description=description, invocations=invocations
            )
@@ -11,7 +11,7 @@ from collections.abc import Iterable
11
11
  from pathlib import Path
12
12
  from typing import Any, ClassVar, final
13
13
 
14
- from mcp.server.fastmcp import Context as MCPContext
14
+ from fastmcp import Context as MCPContext
15
15
  from mcp.server.lowlevel.helper_types import ReadResourceContents
16
16
 
17
17
 
@@ -87,7 +87,11 @@ class ToolContext:
87
87
  Args:
88
88
  message: The message to log
89
89
  """
90
- await self._mcp_context.info(self._format_message(message))
90
+ try:
91
+ await self._mcp_context.info(self._format_message(message))
92
+ except Exception:
93
+ # Silently ignore errors when client has disconnected
94
+ pass
91
95
 
92
96
  async def debug(self, message: str) -> None:
93
97
  """Log a debug message.
@@ -95,7 +99,11 @@ class ToolContext:
95
99
  Args:
96
100
  message: The message to log
97
101
  """
98
- await self._mcp_context.debug(self._format_message(message))
102
+ try:
103
+ await self._mcp_context.debug(self._format_message(message))
104
+ except Exception:
105
+ # Silently ignore errors when client has disconnected
106
+ pass
99
107
 
100
108
  async def warning(self, message: str) -> None:
101
109
  """Log a warning message.
@@ -103,7 +111,11 @@ class ToolContext:
103
111
  Args:
104
112
  message: The message to log
105
113
  """
106
- await self._mcp_context.warning(self._format_message(message))
114
+ try:
115
+ await self._mcp_context.warning(self._format_message(message))
116
+ except Exception:
117
+ # Silently ignore errors when client has disconnected
118
+ pass
107
119
 
108
120
  async def error(self, message: str) -> None:
109
121
  """Log an error message.
@@ -111,7 +123,11 @@ class ToolContext:
111
123
  Args:
112
124
  message: The message to log
113
125
  """
114
- await self._mcp_context.error(self._format_message(message))
126
+ try:
127
+ await self._mcp_context.error(self._format_message(message))
128
+ except Exception:
129
+ # Silently ignore errors when client has disconnected
130
+ pass
115
131
 
116
132
  def _format_message(self, message: str) -> str:
117
133
  """Format a message with tool information if available.
@@ -135,7 +151,11 @@ class ToolContext:
135
151
  current: Current progress value
136
152
  total: Total progress value
137
153
  """
138
- await self._mcp_context.report_progress(current, total)
154
+ try:
155
+ await self._mcp_context.report_progress(current, total)
156
+ except Exception:
157
+ # Silently ignore errors when client has disconnected
158
+ pass
139
159
 
140
160
  async def read_resource(self, uri: str) -> Iterable[ReadResourceContents]:
141
161
  """Read a resource via the MCP protocol.
@@ -160,289 +180,3 @@ def create_tool_context(mcp_context: MCPContext) -> ToolContext:
160
180
  A new ToolContext
161
181
  """
162
182
  return ToolContext(mcp_context)
163
-
164
-
165
- @final
166
- class DocumentContext:
167
- """Manages document context and codebase understanding."""
168
-
169
- def __init__(self) -> None:
170
- """Initialize the document context."""
171
- self.documents: dict[str, str] = {}
172
- self.document_metadata: dict[str, dict[str, Any]] = {}
173
- self.modified_times: dict[str, float] = {}
174
- self.allowed_paths: set[Path] = set()
175
-
176
- def add_allowed_path(self, path: str) -> None:
177
- """Add a path to the allowed paths.
178
-
179
- Args:
180
- path: The path to allow
181
- """
182
- # Expand user path (e.g., ~/ or $HOME)
183
- expanded_path = os.path.expanduser(path)
184
- resolved_path: Path = Path(expanded_path).resolve()
185
- self.allowed_paths.add(resolved_path)
186
-
187
- def is_path_allowed(self, path: str) -> bool:
188
- """Check if a path is allowed.
189
-
190
- Args:
191
- path: The path to check
192
-
193
- Returns:
194
- True if the path is allowed, False otherwise
195
- """
196
- # Expand user path (e.g., ~/ or $HOME)
197
- expanded_path = os.path.expanduser(path)
198
- resolved_path: Path = Path(expanded_path).resolve()
199
-
200
- # Check if the path is within any allowed path
201
- for allowed_path in self.allowed_paths:
202
- try:
203
- _ = resolved_path.relative_to(allowed_path)
204
- return True
205
- except ValueError:
206
- continue
207
-
208
- return False
209
-
210
- def add_document(
211
- self, path: str, content: str, metadata: dict[str, Any] | None = None
212
- ) -> None:
213
- """Add a document to the context.
214
-
215
- Args:
216
- path: The path of the document
217
- content: The content of the document
218
- metadata: Optional metadata about the document
219
- """
220
- self.documents[path] = content
221
- self.modified_times[path] = time.time()
222
-
223
- if metadata:
224
- self.document_metadata[path] = metadata
225
- else:
226
- # Try to infer metadata
227
- self.document_metadata[path] = self._infer_metadata(path, content)
228
-
229
- def get_document(self, path: str) -> str | None:
230
- """Get a document from the context.
231
-
232
- Args:
233
- path: The path of the document
234
-
235
- Returns:
236
- The document content, or None if not found
237
- """
238
- return self.documents.get(path)
239
-
240
- def get_document_metadata(self, path: str) -> dict[str, Any] | None:
241
- """Get document metadata.
242
-
243
- Args:
244
- path: The path of the document
245
-
246
- Returns:
247
- The document metadata, or None if not found
248
- """
249
- return self.document_metadata.get(path)
250
-
251
- def update_document(self, path: str, content: str) -> None:
252
- """Update a document in the context.
253
-
254
- Args:
255
- path: The path of the document
256
- content: The new content of the document
257
- """
258
- self.documents[path] = content
259
- self.modified_times[path] = time.time()
260
-
261
- # Update metadata
262
- self.document_metadata[path] = self._infer_metadata(path, content)
263
-
264
- def remove_document(self, path: str) -> None:
265
- """Remove a document from the context.
266
-
267
- Args:
268
- path: The path of the document
269
- """
270
- if path in self.documents:
271
- del self.documents[path]
272
-
273
- if path in self.document_metadata:
274
- del self.document_metadata[path]
275
-
276
- if path in self.modified_times:
277
- del self.modified_times[path]
278
-
279
- def _infer_metadata(self, path: str, content: str) -> dict[str, Any]:
280
- """Infer metadata about a document.
281
-
282
- Args:
283
- path: The path of the document
284
- content: The content of the document
285
-
286
- Returns:
287
- Inferred metadata
288
- """
289
- extension: str = Path(path).suffix.lower()
290
-
291
- metadata: dict[str, Any] = {
292
- "extension": extension,
293
- "size": len(content),
294
- "line_count": content.count("\n") + 1,
295
- }
296
-
297
- # Infer language based on extension
298
- language_map: dict[str, list[str]] = {
299
- "python": [".py"],
300
- "javascript": [".js", ".jsx"],
301
- "typescript": [".ts", ".tsx"],
302
- "java": [".java"],
303
- "c++": [".c", ".cpp", ".h", ".hpp"],
304
- "go": [".go"],
305
- "rust": [".rs"],
306
- "ruby": [".rb"],
307
- "php": [".php"],
308
- "html": [".html", ".htm"],
309
- "css": [".css"],
310
- "markdown": [".md"],
311
- "json": [".json"],
312
- "yaml": [".yaml", ".yml"],
313
- "xml": [".xml"],
314
- "sql": [".sql"],
315
- "shell": [".sh", ".bash"],
316
- }
317
-
318
- # Find matching language
319
- for language, extensions in language_map.items():
320
- if extension in extensions:
321
- metadata["language"] = language
322
- break
323
- else:
324
- metadata["language"] = "text"
325
-
326
- return metadata
327
-
328
- def load_directory(
329
- self,
330
- directory: str,
331
- recursive: bool = True,
332
- exclude_patterns: list[str] | None = None,
333
- ) -> None:
334
- """Load all files in a directory into the context.
335
-
336
- Args:
337
- directory: The directory to load
338
- recursive: Whether to load subdirectories
339
- exclude_patterns: Patterns to exclude
340
- """
341
- if not self.is_path_allowed(directory):
342
- raise ValueError(f"Directory not allowed: {directory}")
343
-
344
- dir_path: Path = Path(directory)
345
-
346
- if not dir_path.exists() or not dir_path.is_dir():
347
- raise ValueError(f"Not a valid directory: {directory}")
348
-
349
- if exclude_patterns is None:
350
- exclude_patterns = []
351
-
352
- # Common directories and files to exclude
353
- default_excludes: list[str] = [
354
- "__pycache__",
355
- ".git",
356
- ".github",
357
- ".ssh",
358
- ".gnupg",
359
- ".config",
360
- "node_modules",
361
- "__pycache__",
362
- ".venv",
363
- "venv",
364
- "env",
365
- ".idea",
366
- ".vscode",
367
- ".DS_Store",
368
- ]
369
-
370
- exclude_patterns.extend(default_excludes)
371
-
372
- def should_exclude(path: Path) -> bool:
373
- """Check if a path should be excluded.
374
-
375
- Args:
376
- path: The path to check
377
-
378
- Returns:
379
- True if the path should be excluded, False otherwise
380
- """
381
- for pattern in exclude_patterns:
382
- if pattern.startswith("*"):
383
- if path.name.endswith(pattern[1:]):
384
- return True
385
- elif pattern in str(path):
386
- return True
387
- return False
388
-
389
- # Walk the directory
390
- for root, dirs, files in os.walk(dir_path):
391
- # Skip excluded directories
392
- dirs[:] = [d for d in dirs if not should_exclude(Path(root) / d)]
393
-
394
- # Process files
395
- for file in files:
396
- file_path: Path = Path(root) / file
397
-
398
- if should_exclude(file_path):
399
- continue
400
-
401
- try:
402
- with open(file_path, "r", encoding="utf-8") as f:
403
- content: str = f.read()
404
-
405
- # Add to context
406
- self.add_document(str(file_path), content)
407
- except UnicodeDecodeError:
408
- # Skip binary files
409
- continue
410
-
411
- # Stop if not recursive
412
- if not recursive:
413
- break
414
-
415
- def to_json(self) -> str:
416
- """Convert the context to a JSON string.
417
-
418
- Returns:
419
- A JSON string representation of the context
420
- """
421
- data: dict[str, Any] = {
422
- "documents": self.documents,
423
- "metadata": self.document_metadata,
424
- "modified_times": self.modified_times,
425
- "allowed_paths": [str(p) for p in self.allowed_paths],
426
- }
427
-
428
- return json.dumps(data)
429
-
430
- @classmethod
431
- def from_json(cls, json_str: str) -> "DocumentContext":
432
- """Create a context from a JSON string.
433
-
434
- Args:
435
- json_str: The JSON string
436
-
437
- Returns:
438
- A new DocumentContext instance
439
- """
440
- data: dict[str, Any] = json.loads(json_str)
441
-
442
- context = cls()
443
- context.documents = data.get("documents", {})
444
- context.document_metadata = data.get("metadata", {})
445
- context.modified_times = data.get("modified_times", {})
446
- context.allowed_paths = set(Path(p) for p in data.get("allowed_paths", []))
447
-
448
- return context