hanzo-mcp 0.1.21__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hanzo_mcp/__init__.py +3 -0
- hanzo_mcp/cli.py +155 -0
- hanzo_mcp/server.py +125 -0
- hanzo_mcp/tools/__init__.py +62 -0
- hanzo_mcp/tools/common/__init__.py +1 -0
- hanzo_mcp/tools/common/context.py +444 -0
- hanzo_mcp/tools/common/permissions.py +253 -0
- hanzo_mcp/tools/common/thinking.py +65 -0
- hanzo_mcp/tools/common/validation.py +124 -0
- hanzo_mcp/tools/filesystem/__init__.py +9 -0
- hanzo_mcp/tools/filesystem/file_operations.py +1050 -0
- hanzo_mcp/tools/jupyter/__init__.py +8 -0
- hanzo_mcp/tools/jupyter/notebook_operations.py +554 -0
- hanzo_mcp/tools/project/__init__.py +1 -0
- hanzo_mcp/tools/project/analysis.py +879 -0
- hanzo_mcp/tools/shell/__init__.py +1 -0
- hanzo_mcp/tools/shell/command_executor.py +1001 -0
- hanzo_mcp-0.1.21.dist-info/METADATA +168 -0
- hanzo_mcp-0.1.21.dist-info/RECORD +23 -0
- hanzo_mcp-0.1.21.dist-info/WHEEL +5 -0
- hanzo_mcp-0.1.21.dist-info/entry_points.txt +2 -0
- hanzo_mcp-0.1.21.dist-info/licenses/LICENSE +21 -0
- hanzo_mcp-0.1.21.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,444 @@
|
|
|
1
|
+
"""Enhanced Context for Hanzo Dev MCP tools.
|
|
2
|
+
|
|
3
|
+
This module provides an enhanced Context class that wraps the MCP Context
|
|
4
|
+
and adds additional functionality specific to Claude Code tools.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import json
import os
import time
import weakref
from collections.abc import Iterable
from pathlib import Path
from typing import Any, ClassVar, final

from mcp.server.fastmcp import Context as MCPContext
from mcp.server.lowlevel.helper_types import ReadResourceContents
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
@final
class ToolContext:
    """Enhanced context for Hanzo Dev MCP tools.

    This class wraps the MCP Context and adds additional functionality
    for tracking tool execution, progress reporting, and resource access.
    Log messages are prefixed with the current tool name (and execution ID,
    when one is set) before being forwarded to the wrapped MCP Context.
    """

    # Registry of live contexts, kept for debugging. It is a WeakSet so that
    # membership does not keep a context alive. With a plain ``set``, the
    # registry held a strong reference to every instance, so none could ever
    # be garbage-collected (their ``__del__`` cleanup never ran) and every
    # context leaked. Entries now vanish automatically when an instance is
    # collected.
    _active_contexts: ClassVar[weakref.WeakSet["ToolContext"]] = weakref.WeakSet()

    def __init__(self, mcp_context: MCPContext) -> None:
        """Initialize the tool context.

        Args:
            mcp_context: The underlying MCP Context
        """
        self._mcp_context: MCPContext = mcp_context
        self._tool_name: str | None = None
        self._execution_id: str | None = None

        # Register in the (weak) debugging registry.
        ToolContext._active_contexts.add(self)

    @property
    def mcp_context(self) -> MCPContext:
        """Get the underlying MCP Context.

        Returns:
            The MCP Context
        """
        return self._mcp_context

    @property
    def request_id(self) -> str:
        """Get the request ID from the MCP context.

        Returns:
            The request ID
        """
        return self._mcp_context.request_id

    @property
    def client_id(self) -> str | None:
        """Get the client ID from the MCP context.

        Returns:
            The client ID
        """
        return self._mcp_context.client_id

    def set_tool_info(self, tool_name: str, execution_id: str | None = None) -> None:
        """Set information about the currently executing tool.

        The values are used to prefix subsequent log messages.

        Args:
            tool_name: The name of the tool being executed
            execution_id: Optional unique execution ID
        """
        self._tool_name = tool_name
        self._execution_id = execution_id

    async def info(self, message: str) -> None:
        """Log an informational message.

        Args:
            message: The message to log
        """
        await self._mcp_context.info(self._format_message(message))

    async def debug(self, message: str) -> None:
        """Log a debug message.

        Args:
            message: The message to log
        """
        await self._mcp_context.debug(self._format_message(message))

    async def warning(self, message: str) -> None:
        """Log a warning message.

        Args:
            message: The message to log
        """
        await self._mcp_context.warning(self._format_message(message))

    async def error(self, message: str) -> None:
        """Log an error message.

        Args:
            message: The message to log
        """
        await self._mcp_context.error(self._format_message(message))

    def _format_message(self, message: str) -> str:
        """Format a message with tool information if available.

        Args:
            message: The original message

        Returns:
            ``"[tool:execution_id] message"`` when both are set,
            ``"[tool] message"`` when only the tool name is set, otherwise
            the message unchanged
        """
        if self._tool_name:
            if self._execution_id:
                return f"[{self._tool_name}:{self._execution_id}] {message}"
            return f"[{self._tool_name}] {message}"
        return message

    async def report_progress(self, current: int, total: int) -> None:
        """Report progress to the client.

        Args:
            current: Current progress value
            total: Total progress value
        """
        await self._mcp_context.report_progress(current, total)

    async def read_resource(self, uri: str) -> Iterable[ReadResourceContents]:
        """Read a resource via the MCP protocol.

        Args:
            uri: The resource URI

        Returns:
            The resource contents returned by the underlying MCP Context
        """
        return await self._mcp_context.read_resource(uri)
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
# Factory helper: wraps an MCP Context in a new ToolContext.
def create_tool_context(mcp_context: MCPContext) -> ToolContext:
    """Build a ToolContext around the given MCP Context.

    Args:
        mcp_context: The MCP Context to wrap

    Returns:
        A freshly constructed ToolContext
    """
    tool_context = ToolContext(mcp_context)
    return tool_context
|
|
163
|
+
|
|
164
|
+
|
|
165
|
+
@final
|
|
166
|
+
class DocumentContext:
|
|
167
|
+
"""Manages document context and codebase understanding."""
|
|
168
|
+
|
|
169
|
+
def __init__(self) -> None:
|
|
170
|
+
"""Initialize the document context."""
|
|
171
|
+
self.documents: dict[str, str] = {}
|
|
172
|
+
self.document_metadata: dict[str, dict[str, Any]] = {}
|
|
173
|
+
self.modified_times: dict[str, float] = {}
|
|
174
|
+
self.allowed_paths: set[Path] = set()
|
|
175
|
+
|
|
176
|
+
def add_allowed_path(self, path: str) -> None:
|
|
177
|
+
"""Add a path to the allowed paths.
|
|
178
|
+
|
|
179
|
+
Args:
|
|
180
|
+
path: The path to allow
|
|
181
|
+
"""
|
|
182
|
+
resolved_path: Path = Path(path).resolve()
|
|
183
|
+
self.allowed_paths.add(resolved_path)
|
|
184
|
+
|
|
185
|
+
def is_path_allowed(self, path: str) -> bool:
|
|
186
|
+
"""Check if a path is allowed.
|
|
187
|
+
|
|
188
|
+
Args:
|
|
189
|
+
path: The path to check
|
|
190
|
+
|
|
191
|
+
Returns:
|
|
192
|
+
True if the path is allowed, False otherwise
|
|
193
|
+
"""
|
|
194
|
+
resolved_path: Path = Path(path).resolve()
|
|
195
|
+
|
|
196
|
+
# Check if the path is within any allowed path
|
|
197
|
+
for allowed_path in self.allowed_paths:
|
|
198
|
+
try:
|
|
199
|
+
_ = resolved_path.relative_to(allowed_path)
|
|
200
|
+
return True
|
|
201
|
+
except ValueError:
|
|
202
|
+
continue
|
|
203
|
+
|
|
204
|
+
return False
|
|
205
|
+
|
|
206
|
+
def add_document(
|
|
207
|
+
self, path: str, content: str, metadata: dict[str, Any] | None = None
|
|
208
|
+
) -> None:
|
|
209
|
+
"""Add a document to the context.
|
|
210
|
+
|
|
211
|
+
Args:
|
|
212
|
+
path: The path of the document
|
|
213
|
+
content: The content of the document
|
|
214
|
+
metadata: Optional metadata about the document
|
|
215
|
+
"""
|
|
216
|
+
self.documents[path] = content
|
|
217
|
+
self.modified_times[path] = time.time()
|
|
218
|
+
|
|
219
|
+
if metadata:
|
|
220
|
+
self.document_metadata[path] = metadata
|
|
221
|
+
else:
|
|
222
|
+
# Try to infer metadata
|
|
223
|
+
self.document_metadata[path] = self._infer_metadata(path, content)
|
|
224
|
+
|
|
225
|
+
def get_document(self, path: str) -> str | None:
|
|
226
|
+
"""Get a document from the context.
|
|
227
|
+
|
|
228
|
+
Args:
|
|
229
|
+
path: The path of the document
|
|
230
|
+
|
|
231
|
+
Returns:
|
|
232
|
+
The document content, or None if not found
|
|
233
|
+
"""
|
|
234
|
+
return self.documents.get(path)
|
|
235
|
+
|
|
236
|
+
def get_document_metadata(self, path: str) -> dict[str, Any] | None:
|
|
237
|
+
"""Get document metadata.
|
|
238
|
+
|
|
239
|
+
Args:
|
|
240
|
+
path: The path of the document
|
|
241
|
+
|
|
242
|
+
Returns:
|
|
243
|
+
The document metadata, or None if not found
|
|
244
|
+
"""
|
|
245
|
+
return self.document_metadata.get(path)
|
|
246
|
+
|
|
247
|
+
def update_document(self, path: str, content: str) -> None:
|
|
248
|
+
"""Update a document in the context.
|
|
249
|
+
|
|
250
|
+
Args:
|
|
251
|
+
path: The path of the document
|
|
252
|
+
content: The new content of the document
|
|
253
|
+
"""
|
|
254
|
+
self.documents[path] = content
|
|
255
|
+
self.modified_times[path] = time.time()
|
|
256
|
+
|
|
257
|
+
# Update metadata
|
|
258
|
+
self.document_metadata[path] = self._infer_metadata(path, content)
|
|
259
|
+
|
|
260
|
+
def remove_document(self, path: str) -> None:
|
|
261
|
+
"""Remove a document from the context.
|
|
262
|
+
|
|
263
|
+
Args:
|
|
264
|
+
path: The path of the document
|
|
265
|
+
"""
|
|
266
|
+
if path in self.documents:
|
|
267
|
+
del self.documents[path]
|
|
268
|
+
|
|
269
|
+
if path in self.document_metadata:
|
|
270
|
+
del self.document_metadata[path]
|
|
271
|
+
|
|
272
|
+
if path in self.modified_times:
|
|
273
|
+
del self.modified_times[path]
|
|
274
|
+
|
|
275
|
+
def _infer_metadata(self, path: str, content: str) -> dict[str, Any]:
|
|
276
|
+
"""Infer metadata about a document.
|
|
277
|
+
|
|
278
|
+
Args:
|
|
279
|
+
path: The path of the document
|
|
280
|
+
content: The content of the document
|
|
281
|
+
|
|
282
|
+
Returns:
|
|
283
|
+
Inferred metadata
|
|
284
|
+
"""
|
|
285
|
+
extension: str = Path(path).suffix.lower()
|
|
286
|
+
|
|
287
|
+
metadata: dict[str, Any] = {
|
|
288
|
+
"extension": extension,
|
|
289
|
+
"size": len(content),
|
|
290
|
+
"line_count": content.count("\n") + 1,
|
|
291
|
+
}
|
|
292
|
+
|
|
293
|
+
# Infer language based on extension
|
|
294
|
+
language_map: dict[str, list[str]] = {
|
|
295
|
+
"python": [".py"],
|
|
296
|
+
"javascript": [".js", ".jsx"],
|
|
297
|
+
"typescript": [".ts", ".tsx"],
|
|
298
|
+
"java": [".java"],
|
|
299
|
+
"c++": [".c", ".cpp", ".h", ".hpp"],
|
|
300
|
+
"go": [".go"],
|
|
301
|
+
"rust": [".rs"],
|
|
302
|
+
"ruby": [".rb"],
|
|
303
|
+
"php": [".php"],
|
|
304
|
+
"html": [".html", ".htm"],
|
|
305
|
+
"css": [".css"],
|
|
306
|
+
"markdown": [".md"],
|
|
307
|
+
"json": [".json"],
|
|
308
|
+
"yaml": [".yaml", ".yml"],
|
|
309
|
+
"xml": [".xml"],
|
|
310
|
+
"sql": [".sql"],
|
|
311
|
+
"shell": [".sh", ".bash"],
|
|
312
|
+
}
|
|
313
|
+
|
|
314
|
+
# Find matching language
|
|
315
|
+
for language, extensions in language_map.items():
|
|
316
|
+
if extension in extensions:
|
|
317
|
+
metadata["language"] = language
|
|
318
|
+
break
|
|
319
|
+
else:
|
|
320
|
+
metadata["language"] = "text"
|
|
321
|
+
|
|
322
|
+
return metadata
|
|
323
|
+
|
|
324
|
+
def load_directory(
|
|
325
|
+
self,
|
|
326
|
+
directory: str,
|
|
327
|
+
recursive: bool = True,
|
|
328
|
+
exclude_patterns: list[str] | None = None,
|
|
329
|
+
) -> None:
|
|
330
|
+
"""Load all files in a directory into the context.
|
|
331
|
+
|
|
332
|
+
Args:
|
|
333
|
+
directory: The directory to load
|
|
334
|
+
recursive: Whether to load subdirectories
|
|
335
|
+
exclude_patterns: Patterns to exclude
|
|
336
|
+
"""
|
|
337
|
+
if not self.is_path_allowed(directory):
|
|
338
|
+
raise ValueError(f"Directory not allowed: {directory}")
|
|
339
|
+
|
|
340
|
+
dir_path: Path = Path(directory)
|
|
341
|
+
|
|
342
|
+
if not dir_path.exists() or not dir_path.is_dir():
|
|
343
|
+
raise ValueError(f"Not a valid directory: {directory}")
|
|
344
|
+
|
|
345
|
+
if exclude_patterns is None:
|
|
346
|
+
exclude_patterns = []
|
|
347
|
+
|
|
348
|
+
# Common directories and files to exclude
|
|
349
|
+
default_excludes: list[str] = [
|
|
350
|
+
"__pycache__",
|
|
351
|
+
".git",
|
|
352
|
+
".github",
|
|
353
|
+
".ssh",
|
|
354
|
+
".gnupg",
|
|
355
|
+
".config",
|
|
356
|
+
"node_modules",
|
|
357
|
+
"__pycache__",
|
|
358
|
+
".venv",
|
|
359
|
+
"venv",
|
|
360
|
+
"env",
|
|
361
|
+
".idea",
|
|
362
|
+
".vscode",
|
|
363
|
+
".DS_Store",
|
|
364
|
+
]
|
|
365
|
+
|
|
366
|
+
exclude_patterns.extend(default_excludes)
|
|
367
|
+
|
|
368
|
+
def should_exclude(path: Path) -> bool:
|
|
369
|
+
"""Check if a path should be excluded.
|
|
370
|
+
|
|
371
|
+
Args:
|
|
372
|
+
path: The path to check
|
|
373
|
+
|
|
374
|
+
Returns:
|
|
375
|
+
True if the path should be excluded, False otherwise
|
|
376
|
+
"""
|
|
377
|
+
for pattern in exclude_patterns:
|
|
378
|
+
if pattern.startswith("*"):
|
|
379
|
+
if path.name.endswith(pattern[1:]):
|
|
380
|
+
return True
|
|
381
|
+
elif pattern in str(path):
|
|
382
|
+
return True
|
|
383
|
+
return False
|
|
384
|
+
|
|
385
|
+
# Walk the directory
|
|
386
|
+
for root, dirs, files in os.walk(dir_path):
|
|
387
|
+
# Skip excluded directories
|
|
388
|
+
dirs[:] = [d for d in dirs if not should_exclude(Path(root) / d)]
|
|
389
|
+
|
|
390
|
+
# Process files
|
|
391
|
+
for file in files:
|
|
392
|
+
file_path: Path = Path(root) / file
|
|
393
|
+
|
|
394
|
+
if should_exclude(file_path):
|
|
395
|
+
continue
|
|
396
|
+
|
|
397
|
+
try:
|
|
398
|
+
with open(file_path, "r", encoding="utf-8") as f:
|
|
399
|
+
content: str = f.read()
|
|
400
|
+
|
|
401
|
+
# Add to context
|
|
402
|
+
self.add_document(str(file_path), content)
|
|
403
|
+
except UnicodeDecodeError:
|
|
404
|
+
# Skip binary files
|
|
405
|
+
continue
|
|
406
|
+
|
|
407
|
+
# Stop if not recursive
|
|
408
|
+
if not recursive:
|
|
409
|
+
break
|
|
410
|
+
|
|
411
|
+
def to_json(self) -> str:
|
|
412
|
+
"""Convert the context to a JSON string.
|
|
413
|
+
|
|
414
|
+
Returns:
|
|
415
|
+
A JSON string representation of the context
|
|
416
|
+
"""
|
|
417
|
+
data: dict[str, Any] = {
|
|
418
|
+
"documents": self.documents,
|
|
419
|
+
"metadata": self.document_metadata,
|
|
420
|
+
"modified_times": self.modified_times,
|
|
421
|
+
"allowed_paths": [str(p) for p in self.allowed_paths],
|
|
422
|
+
}
|
|
423
|
+
|
|
424
|
+
return json.dumps(data)
|
|
425
|
+
|
|
426
|
+
@classmethod
|
|
427
|
+
def from_json(cls, json_str: str) -> "DocumentContext":
|
|
428
|
+
"""Create a context from a JSON string.
|
|
429
|
+
|
|
430
|
+
Args:
|
|
431
|
+
json_str: The JSON string
|
|
432
|
+
|
|
433
|
+
Returns:
|
|
434
|
+
A new DocumentContext instance
|
|
435
|
+
"""
|
|
436
|
+
data: dict[str, Any] = json.loads(json_str)
|
|
437
|
+
|
|
438
|
+
context = cls()
|
|
439
|
+
context.documents = data.get("documents", {})
|
|
440
|
+
context.document_metadata = data.get("metadata", {})
|
|
441
|
+
context.modified_times = data.get("modified_times", {})
|
|
442
|
+
context.allowed_paths = set(Path(p) for p in data.get("allowed_paths", []))
|
|
443
|
+
|
|
444
|
+
return context
|
|
@@ -0,0 +1,253 @@
|
|
|
1
|
+
"""Permission system for the Hanzo Dev MCP server."""
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
import os
|
|
5
|
+
from collections.abc import Awaitable, Callable
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
from typing import Any, TypeVar, final
|
|
8
|
+
|
|
9
|
+
# Type variables for generic annotations.
# ``T`` is the result type of the awaitable returned by a decorated function
# (see ``Callable[..., Awaitable[T]]`` in ``PermissibleOperation.__call__``).
T = TypeVar("T")
# NOTE(review): ``P`` is not referenced elsewhere in this module; presumably
# kept for importers in other modules — verify before removing.
P = TypeVar("P")
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
@final
class PermissionManager:
    """Manages permissions for file and command operations.

    A path is allowed when its resolved form lies inside one of
    ``allowed_paths`` and it is not excluded. A path is excluded when it
    equals, or lies inside, one of ``excluded_paths``, or when it matches one
    of ``excluded_patterns``.
    """

    def __init__(self) -> None:
        """Initialize the permission manager with default paths and exclusions."""
        # Allowed paths: /tmp and /var (resolved) by default.
        # NOTE(review): /var is a broad default; kept for backward compatibility.
        self.allowed_paths: set[Path] = {
            Path("/tmp").resolve(),
            Path("/var").resolve(),
        }

        # Excluded paths (each also covers everything beneath it) and
        # name/wildcard exclusion patterns.
        self.excluded_paths: set[Path] = set()
        self.excluded_patterns: list[str] = []

        # Default excluded patterns
        self._add_default_exclusions()

    def _add_default_exclusions(self) -> None:
        """Add default exclusions for sensitive files and directories."""
        # Sensitive directories, matched as exact path components.
        sensitive_dirs: list[str] = [
            # ".git" is now allowed by default
            ".ssh",
            ".gnupg",
            ".config",
            "node_modules",
            "__pycache__",
            ".venv",
            "venv",
            "env",
            ".idea",
            ".vscode",
            ".DS_Store",
        ]
        self.excluded_patterns.extend(sensitive_dirs)

        # Sensitive file patterns ("*" patterns are shell-style wildcards).
        sensitive_patterns: list[str] = [
            ".env",
            "*.key",
            "*.pem",
            "*.crt",
            "*password*",
            "*secret*",
            "*.sqlite",
            "*.db",
            "*.sqlite3",
            "*.log",
        ]
        self.excluded_patterns.extend(sensitive_patterns)

    def add_allowed_path(self, path: str) -> None:
        """Add a path to the allowed paths.

        Args:
            path: The path to allow (stored in resolved form)
        """
        self.allowed_paths.add(Path(path).resolve())

    def remove_allowed_path(self, path: str) -> None:
        """Remove a path from the allowed paths; unknown paths are ignored.

        Args:
            path: The path to remove
        """
        self.allowed_paths.discard(Path(path).resolve())

    def exclude_path(self, path: str) -> None:
        """Exclude a path, and everything beneath it, from allowed operations.

        Args:
            path: The path to exclude
        """
        self.excluded_paths.add(Path(path).resolve())

    def add_exclusion_pattern(self, pattern: str) -> None:
        """Add an exclusion pattern.

        Args:
            pattern: The pattern to exclude; a leading ``*`` makes it a
                shell-style wildcard, otherwise it must equal a path component
        """
        self.excluded_patterns.append(pattern)

    def is_path_allowed(self, path: str) -> bool:
        """Check if a path is allowed.

        Args:
            path: The path to check

        Returns:
            True if the path is allowed, False otherwise
        """
        resolved_path: Path = Path(path).resolve()

        # Exclusions take precedence over allowed paths.
        if self._is_path_excluded(resolved_path):
            return False

        return any(
            resolved_path.is_relative_to(allowed_path)
            for allowed_path in self.allowed_paths
        )

    def _is_path_excluded(self, path: Path) -> bool:
        """Check if a path is excluded.

        Args:
            path: The path to check (``is_path_allowed`` passes it resolved)

        Returns:
            True if the path is excluded, False otherwise
        """
        from fnmatch import fnmatchcase

        # An excluded path also excludes everything beneath it. Previously
        # only an exact match counted, so files inside an excluded directory
        # remained accessible.
        if any(path.is_relative_to(excluded) for excluded in self.excluded_paths):
            return True

        path_str: str = str(path)

        # Path components, for exact directory/file name matches.
        path_parts = path_str.split(os.sep)

        for pattern in self.excluded_patterns:
            if pattern.startswith("*"):
                # Shell-style wildcard against the final name (or the whole
                # path when the pattern spans directories). The former
                # ``endswith(pattern[1:])`` test never matched patterns such as
                # "*password*" or "*secret*".
                target = path_str if os.sep in pattern else path.name
                if fnmatchcase(target, pattern):
                    return True
            elif pattern in path_parts:
                # Non-wildcard patterns must match a path component exactly.
                return True

        return False

    def to_json(self) -> str:
        """Convert the permission manager to a JSON string.

        Returns:
            A JSON string representation of the permission manager
        """
        data: dict[str, Any] = {
            "allowed_paths": [str(p) for p in self.allowed_paths],
            "excluded_paths": [str(p) for p in self.excluded_paths],
            "excluded_patterns": self.excluded_patterns,
        }

        return json.dumps(data)

    @classmethod
    def from_json(cls, json_str: str) -> "PermissionManager":
        """Create a permission manager from a JSON string.

        Allowed and excluded paths from the JSON are added on top of the
        defaults. ``excluded_patterns`` replaces the default patterns only
        when the key is present.

        Args:
            json_str: The JSON string

        Returns:
            A new PermissionManager instance
        """
        data: dict[str, Any] = json.loads(json_str)

        manager = cls()

        for path in data.get("allowed_paths", []):
            manager.add_allowed_path(path)

        for path in data.get("excluded_paths", []):
            manager.exclude_path(path)

        # A missing key previously reset the list to [], silently dropping
        # every default sensitive-file exclusion.
        if "excluded_patterns" in data:
            manager.excluded_patterns = list(data["excluded_patterns"])

        return manager
|
|
196
|
+
|
|
197
|
+
|
|
198
|
+
class PermissibleOperation:
    """A decorator for async operations that require permission.

    Before the decorated coroutine function runs, the target path is taken
    from its arguments and checked with the permission manager. A
    ``PermissionError`` is raised when the path is not allowed.
    """

    def __init__(
        self,
        permission_manager: PermissionManager,
        operation: str,
        get_path_fn: Callable[[list[Any], dict[str, Any]], str] | None = None,
    ) -> None:
        """Initialize the permissible operation.

        Args:
            permission_manager: The permission manager
            operation: The operation type (read, write, execute, etc.), used
                in the error message
            get_path_fn: Optional function to extract the path from args (as a
                list) and kwargs; by default the first positional argument, or
                else the first keyword argument value, is used
        """
        self.permission_manager: PermissionManager = permission_manager
        self.operation: str = operation
        self.get_path_fn: Callable[[list[Any], dict[str, Any]], str] | None = (
            get_path_fn
        )

    def __call__(
        self, func: Callable[..., Awaitable[T]]
    ) -> Callable[..., Awaitable[T]]:
        """Decorate the function.

        Args:
            func: The coroutine function to decorate

        Returns:
            The decorated function, which keeps ``func``'s name and docstring
        """
        import functools

        @functools.wraps(func)
        async def wrapper(*args: Any, **kwargs: Any) -> T:
            # Extract the path
            if self.get_path_fn is not None:
                # Pass args as a list and kwargs as a dict to the path function
                path = self.get_path_fn(list(args), kwargs)
            else:
                # Default to the first argument
                path = args[0] if args else next(iter(kwargs.values()), None)

            # Accept pathlib.Path and other os.PathLike objects as well as str.
            if isinstance(path, os.PathLike):
                path = os.fspath(path)

            if not isinstance(path, str):
                raise ValueError(f"Invalid path type: {type(path)}")

            # Check permission
            if not self.permission_manager.is_path_allowed(path):
                raise PermissionError(
                    f"Operation '{self.operation}' not allowed for path: {path}"
                )

            # Call the function with its original, unconverted arguments.
            return await func(*args, **kwargs)

        return wrapper
|