mcp-code-indexer 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mcp_code_indexer/__init__.py +16 -0
- mcp_code_indexer/database/__init__.py +1 -0
- mcp_code_indexer/database/database.py +480 -0
- mcp_code_indexer/database/models.py +123 -0
- mcp_code_indexer/error_handler.py +365 -0
- mcp_code_indexer/file_scanner.py +375 -0
- mcp_code_indexer/logging_config.py +183 -0
- mcp_code_indexer/main.py +129 -0
- mcp_code_indexer/merge_handler.py +386 -0
- mcp_code_indexer/middleware/__init__.py +7 -0
- mcp_code_indexer/middleware/error_middleware.py +286 -0
- mcp_code_indexer/server/__init__.py +1 -0
- mcp_code_indexer/server/mcp_server.py +699 -0
- mcp_code_indexer/tiktoken_cache/9b5ad71b2ce5302211f9c61530b329a4922fc6a4 +100256 -0
- mcp_code_indexer/token_counter.py +243 -0
- mcp_code_indexer/tools/__init__.py +1 -0
- mcp_code_indexer-1.0.0.dist-info/METADATA +364 -0
- mcp_code_indexer-1.0.0.dist-info/RECORD +22 -0
- mcp_code_indexer-1.0.0.dist-info/WHEEL +5 -0
- mcp_code_indexer-1.0.0.dist-info/entry_points.txt +2 -0
- mcp_code_indexer-1.0.0.dist-info/licenses/LICENSE +21 -0
- mcp_code_indexer-1.0.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,386 @@
|
|
1
|
+
"""
|
2
|
+
Two-phase merge functionality for branch descriptions.
|
3
|
+
|
4
|
+
This module implements conflict detection and resolution for merging
|
5
|
+
file descriptions between branches with AI-assisted conflict resolution.
|
6
|
+
"""
|
7
|
+
|
8
|
+
import logging
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Tuple
from uuid import uuid4

from mcp_code_indexer.database.database import DatabaseManager
from mcp_code_indexer.database.models import FileDescription
from mcp_code_indexer.error_handler import ValidationError, DatabaseError
from mcp_code_indexer.logging_config import get_logger
|
17
|
+
|
18
|
+
logger = get_logger(__name__)
|
19
|
+
|
20
|
+
|
21
|
+
class MergeConflict:
    """A single description conflict discovered while merging two branches."""

    def __init__(
        self,
        file_path: str,
        source_branch: str,
        target_branch: str,
        source_description: str,
        target_description: str,
        conflict_id: Optional[str] = None
    ):
        """
        Create a conflict record.

        Args:
            file_path: Path of the file whose descriptions disagree
            source_branch: Branch the merge originates from
            target_branch: Branch the merge lands on
            source_description: Description as seen on the source branch
            target_description: Description as seen on the target branch
            conflict_id: Explicit identifier; a fresh UUID is generated when
                omitted (or falsy)
        """
        self.file_path = file_path
        self.source_branch = source_branch
        self.target_branch = target_branch
        self.source_description = source_description
        self.target_description = target_description
        # Any falsy id (None, "") falls back to a generated UUID.
        self.conflict_id = conflict_id or str(uuid4())
        # Filled in later, once a resolution has been chosen.
        self.resolution: Optional[str] = None

    def to_dict(self) -> Dict:
        """Serialize the conflict as a camelCase dictionary."""
        payload = {
            "conflictId": self.conflict_id,
            "filePath": self.file_path,
            "sourceBranch": self.source_branch,
            "targetBranch": self.target_branch,
            "sourceDescription": self.source_description,
            "targetDescription": self.target_description,
            "resolution": self.resolution,
        }
        return payload
|
63
|
+
|
64
|
+
|
65
|
+
class MergeSession:
    """Tracks the conflicts and resolution progress of one merge attempt."""

    def __init__(self, project_id: str, source_branch: str, target_branch: str):
        """
        Create a new session with a generated id and "pending" status.

        Args:
            project_id: Project identifier
            source_branch: Branch being merged from
            target_branch: Branch being merged into
        """
        self.session_id = str(uuid4())
        self.project_id = project_id
        self.source_branch = source_branch
        self.target_branch = target_branch
        self.conflicts: List[MergeConflict] = []
        self.created = datetime.utcnow()
        # Lifecycle: "pending" -> "resolved" | "aborted"
        self.status = "pending"

    def add_conflict(self, conflict: MergeConflict) -> None:
        """Record another conflict for this session."""
        self.conflicts.append(conflict)

    def get_conflict_count(self) -> int:
        """Total number of conflicts detected so far."""
        return len(self.conflicts)

    def get_resolved_count(self) -> int:
        """Number of conflicts that already carry a resolution."""
        return sum(1 for conflict in self.conflicts if conflict.resolution is not None)

    def is_fully_resolved(self) -> bool:
        """True once every detected conflict has a resolution."""
        return self.get_resolved_count() == self.get_conflict_count()

    def to_dict(self) -> Dict:
        """Serialize the session, including its conflicts, as camelCase."""
        return {
            "sessionId": self.session_id,
            "projectId": self.project_id,
            "sourceBranch": self.source_branch,
            "targetBranch": self.target_branch,
            "totalConflicts": self.get_conflict_count(),
            "resolvedConflicts": self.get_resolved_count(),
            "isFullyResolved": self.is_fully_resolved(),
            "created": self.created.isoformat(),
            "status": self.status,
            "conflicts": [conflict.to_dict() for conflict in self.conflicts],
        }
|
115
|
+
|
116
|
+
|
117
|
+
class MergeHandler:
    """
    Handles two-phase merge operations for file descriptions.

    Phase 1: Detect conflicts between source and target branches
    Phase 2: Apply resolutions and complete merge
    """

    def __init__(self, db_manager: DatabaseManager):
        """
        Initialize merge handler.

        Args:
            db_manager: Database manager instance
        """
        self.db_manager = db_manager
        # In-memory registry of sessions awaiting phase 2, keyed by session id.
        self._active_sessions: Dict[str, MergeSession] = {}

    async def start_merge_phase1(
        self,
        project_id: str,
        source_branch: str,
        target_branch: str
    ) -> MergeSession:
        """
        Phase 1: Detect merge conflicts.

        Args:
            project_id: Project identifier
            source_branch: Branch to merge from
            target_branch: Branch to merge into

        Returns:
            MergeSession with detected conflicts

        Raises:
            ValidationError: If branches are invalid
            DatabaseError: If database operation fails
        """
        if source_branch == target_branch:
            raise ValidationError("Source and target branches cannot be the same")

        logger.info(f"Starting merge phase 1: {source_branch} -> {target_branch}")

        try:
            # Get file descriptions from both branches.
            source_descriptions = await self.db_manager.get_all_file_descriptions(
                project_id, source_branch
            )
            target_descriptions = await self.db_manager.get_all_file_descriptions(
                project_id, target_branch
            )

            session = MergeSession(project_id, source_branch, target_branch)

            # Index both sides by file path for O(1) pairing.
            source_lookup = {desc.file_path: desc for desc in source_descriptions}
            target_lookup = {desc.file_path: desc for desc in target_descriptions}

            # A conflict is a file present in BOTH branches whose descriptions
            # differ; files present in only one branch merge cleanly.
            conflicts_found = 0
            all_files = set(source_lookup.keys()) | set(target_lookup.keys())

            for file_path in all_files:
                source_desc = source_lookup.get(file_path)
                target_desc = target_lookup.get(file_path)

                if source_desc and target_desc:
                    if source_desc.description != target_desc.description:
                        session.add_conflict(MergeConflict(
                            file_path=file_path,
                            source_branch=source_branch,
                            target_branch=target_branch,
                            source_description=source_desc.description,
                            target_description=target_desc.description
                        ))
                        conflicts_found += 1

            # Keep the session around so phase 2 can find it.
            self._active_sessions[session.session_id] = session

            logger.info(f"Merge phase 1 completed: {conflicts_found} conflicts found")

            return session

        except Exception as e:
            logger.error(f"Error in merge phase 1: {e}")
            raise DatabaseError(f"Failed to detect merge conflicts: {e}") from e

    async def complete_merge_phase2(
        self,
        session_id: str,
        conflict_resolutions: List[Dict[str, str]]
    ) -> Dict:
        """
        Phase 2: Apply resolutions and complete merge.

        Args:
            session_id: Merge session identifier
            conflict_resolutions: List of {conflictId, resolvedDescription}

        Returns:
            Merge result summary

        Raises:
            ValidationError: If session not found or resolutions invalid
            DatabaseError: If database operation fails
        """
        session = self._active_sessions.get(session_id)
        if not session:
            raise ValidationError(f"Merge session not found: {session_id}")

        logger.info(f"Starting merge phase 2 for session {session_id}")

        try:
            # Attach each supplied resolution to its conflict.
            resolution_lookup = {
                res["conflictId"]: res["resolvedDescription"]
                for res in conflict_resolutions
            }

            for conflict in session.conflicts:
                if conflict.conflict_id in resolution_lookup:
                    conflict.resolution = resolution_lookup[conflict.conflict_id]

            # Refuse to merge while any conflict is still unresolved.
            if not session.is_fully_resolved():
                unresolved = session.get_conflict_count() - session.get_resolved_count()
                raise ValidationError(
                    f"Not all conflicts resolved: {unresolved} remaining",
                    details={
                        "total_conflicts": session.get_conflict_count(),
                        "resolved_conflicts": session.get_resolved_count(),
                        "unresolved_conflicts": unresolved
                    }
                )

            # Build the target branch's new descriptions from the source
            # branch, substituting resolved text where a conflict existed.
            # (The original also fetched the target branch's descriptions
            # here, but never used them — that dead query is removed.)
            source_descriptions = await self.db_manager.get_all_file_descriptions(
                session.project_id, session.source_branch
            )

            # Dict lookup instead of a linear scan per source file.
            conflict_lookup = {c.file_path: c for c in session.conflicts}

            merged_descriptions = []
            for source_desc in source_descriptions:
                conflict = conflict_lookup.get(source_desc.file_path)
                # Resolved text wins; otherwise copy the source description.
                description = conflict.resolution if conflict else source_desc.description
                merged_descriptions.append(FileDescription(
                    project_id=session.project_id,
                    branch=session.target_branch,
                    file_path=source_desc.file_path,
                    description=description,
                    file_hash=source_desc.file_hash,
                    last_modified=datetime.utcnow(),
                    version=1,
                    source_project_id=source_desc.source_project_id
                ))

            # Batch update target branch.
            await self.db_manager.batch_create_file_descriptions(merged_descriptions)

            session.status = "resolved"

            result = {
                "success": True,
                "sessionId": session_id,
                "sourceBranch": session.source_branch,
                "targetBranch": session.target_branch,
                "totalConflicts": session.get_conflict_count(),
                "resolvedConflicts": session.get_resolved_count(),
                "mergedFiles": len(merged_descriptions),
                "message": f"Successfully merged {len(merged_descriptions)} files from {session.source_branch} to {session.target_branch}"
            }

            logger.info(f"Merge phase 2 completed successfully: {len(merged_descriptions)} files merged")

            # The session is finished; drop it from the registry.
            del self._active_sessions[session_id]

            return result

        except ValidationError:
            # BUG FIX: propagate validation problems unchanged, as documented.
            # The original wrapped the unresolved-conflicts ValidationError
            # raised above into a DatabaseError.
            session.status = "aborted"
            raise
        except Exception as e:
            session.status = "aborted"
            logger.error(f"Error in merge phase 2: {e}")
            raise DatabaseError(f"Failed to complete merge: {e}") from e

    def get_session(self, session_id: str) -> Optional[MergeSession]:
        """Get merge session by ID."""
        return self._active_sessions.get(session_id)

    def get_active_sessions(self) -> List[MergeSession]:
        """Get all active merge sessions."""
        return list(self._active_sessions.values())

    def abort_session(self, session_id: str) -> bool:
        """
        Abort a merge session.

        Args:
            session_id: Session to abort

        Returns:
            True if session was aborted
        """
        session = self._active_sessions.pop(session_id, None)
        if session is None:
            return False
        session.status = "aborted"
        logger.info(f"Merge session {session_id} aborted")
        return True

    def cleanup_old_sessions(self, max_age_hours: int = 24) -> int:
        """
        Clean up old merge sessions.

        Args:
            max_age_hours: Maximum age of sessions to keep

        Returns:
            Number of sessions cleaned up
        """
        # BUG FIX: the original called datetime.timedelta(...), but `datetime`
        # is the class imported via `from datetime import datetime`, so this
        # method always raised AttributeError. Use timedelta directly.
        cutoff_time = datetime.utcnow() - timedelta(hours=max_age_hours)
        old_sessions = [
            session_id for session_id, session in self._active_sessions.items()
            if session.created < cutoff_time
        ]

        for session_id in old_sessions:
            del self._active_sessions[session_id]

        if old_sessions:
            logger.info(f"Cleaned up {len(old_sessions)} old merge sessions")

        return len(old_sessions)
|
@@ -0,0 +1,286 @@
|
|
1
|
+
"""
|
2
|
+
Error handling middleware for MCP tools.
|
3
|
+
|
4
|
+
This module provides decorators and middleware functions to standardize
|
5
|
+
error handling across all MCP tool implementations.
|
6
|
+
"""
|
7
|
+
|
8
|
+
import asyncio
|
9
|
+
import functools
|
10
|
+
import time
|
11
|
+
from typing import Any, Callable, Dict, List
|
12
|
+
|
13
|
+
from mcp import types
|
14
|
+
|
15
|
+
from mcp_code_indexer.error_handler import ErrorHandler, MCPError
|
16
|
+
from mcp_code_indexer.logging_config import get_logger, log_tool_usage, log_performance_metrics
|
17
|
+
|
18
|
+
logger = get_logger(__name__)
|
19
|
+
|
20
|
+
|
21
|
+
class ToolMiddleware:
    """Wraps MCP tool handlers with standardized error handling and logging."""

    def __init__(self, error_handler: ErrorHandler):
        """Keep a reference to the shared error handler."""
        self.error_handler = error_handler

    def wrap_tool_handler(self, tool_name: str):
        """
        Build a decorator adding error handling and logging to a tool handler.

        Args:
            tool_name: Name of the MCP tool

        Returns:
            Decorator function
        """
        def decorator(func: Callable) -> Callable:
            @functools.wraps(func)
            async def wrapper(arguments: Dict[str, Any]) -> List[types.TextContent]:
                started = time.time()
                ok = False
                payload_size = 0

                try:
                    # Record the invocation before running the handler.
                    logger.info(f"Tool {tool_name} called", extra={
                        "structured_data": {
                            "tool_invocation": {
                                "tool_name": tool_name,
                                "arguments_count": len(arguments)
                            }
                        }
                    })

                    result = await func(arguments)

                    # Total size of all text content in the response.
                    if isinstance(result, list):
                        payload_size = sum(
                            len(item.text) if hasattr(item, 'text') else 0
                            for item in result
                        )

                    ok = True
                    log_performance_metrics(
                        logger,
                        f"tool_{tool_name}",
                        time.time() - started,
                        result_size=payload_size,
                        arguments_count=len(arguments)
                    )

                    return result

                except Exception as e:
                    # Record the failure and return it as an MCP error payload
                    # rather than letting the exception escape.
                    self.error_handler.log_error(
                        e,
                        context={"arguments_count": len(arguments)},
                        tool_name=tool_name
                    )
                    error_response = self.error_handler.create_mcp_error_response(
                        e, tool_name, arguments
                    )
                    return [error_response]

                finally:
                    # Usage is logged for both success and failure paths.
                    log_tool_usage(
                        logger,
                        tool_name,
                        arguments,
                        ok,
                        time.time() - started,
                        payload_size if ok else None
                    )

            return wrapper
        return decorator

    def validate_tool_arguments(self, required_fields: List[str], optional_fields: List[str] = None):
        """
        Build a decorator that validates tool arguments before the handler runs.

        Args:
            required_fields: Names that must be present
            optional_fields: Names that may be present; when given, any field
                outside required + optional is rejected

        Returns:
            Decorator function
        """
        def decorator(func: Callable) -> Callable:
            @functools.wraps(func)
            async def wrapper(arguments: Dict[str, Any]) -> Any:
                from ..error_handler import ValidationError

                # Every required field must be supplied.
                missing_fields = [name for name in required_fields if name not in arguments]
                if missing_fields:
                    raise ValidationError(
                        f"Missing required fields: {', '.join(missing_fields)}",
                        details={"missing_fields": missing_fields, "provided_fields": list(arguments.keys())}
                    )

                # Only reject extras when an explicit optional list was given.
                if optional_fields is not None:
                    all_fields = set(required_fields + optional_fields)
                    unexpected_fields = [name for name in arguments.keys() if name not in all_fields]
                    if unexpected_fields:
                        raise ValidationError(
                            f"Unexpected fields: {', '.join(unexpected_fields)}",
                            details={"unexpected_fields": unexpected_fields, "allowed_fields": list(all_fields)}
                        )

                return await func(arguments)

            return wrapper
        return decorator
|
146
|
+
|
147
|
+
|
148
|
+
class AsyncTaskManager:
    """Manages async tasks with proper error handling."""

    def __init__(self, error_handler: ErrorHandler):
        """Initialize task manager."""
        # Shared error handler consulted when a managed task fails.
        self.error_handler = error_handler
        # Tasks currently tracked; entries are removed as they complete.
        self._tasks: List[asyncio.Task] = []

    def create_task(self, coro, name: str = None) -> asyncio.Task:
        """
        Create a managed async task.

        Args:
            coro: Coroutine to run
            name: Optional task name for logging

        Returns:
            Created task
        """
        task = asyncio.create_task(coro, name=name)
        self._tasks.append(task)

        # Add done callback for error handling
        # NOTE(review): the callback spawns a *new* task to run the async
        # completion handler; that helper task is itself untracked, so any
        # exception it raises is unhandled — confirm this is intentional.
        task.add_done_callback(
            lambda t: asyncio.create_task(
                self._handle_task_completion(t, name or "unnamed_task")
            )
        )

        return task

    async def _handle_task_completion(self, task: asyncio.Task, task_name: str) -> None:
        """Handle task completion and errors."""
        try:
            # Cancelled tasks are skipped: exception() would raise on them.
            if task.done() and not task.cancelled():
                exception = task.exception()
                if exception:
                    # Delegate the failure to the shared error handler.
                    await self.error_handler.handle_async_task_error(
                        task, task_name
                    )
        except Exception as e:
            logger.error(f"Error handling task completion for {task_name}: {e}")
        finally:
            # Remove completed task from tracking
            if task in self._tasks:
                self._tasks.remove(task)

    async def wait_for_all(self, timeout: float = None) -> None:
        """
        Wait for all managed tasks to complete.

        Args:
            timeout: Maximum time to wait in seconds
        """
        if not self._tasks:
            return

        try:
            # return_exceptions=True keeps one failing task from aborting
            # the gather before the others finish.
            await asyncio.wait_for(
                asyncio.gather(*self._tasks, return_exceptions=True),
                timeout=timeout
            )
        except asyncio.TimeoutError:
            logger.warning(f"Timeout waiting for {len(self._tasks)} tasks")
            # Cancel remaining tasks
            for task in self._tasks:
                if not task.done():
                    task.cancel()
        except Exception as e:
            logger.error(f"Error waiting for tasks: {e}")

    def cancel_all(self) -> None:
        """Cancel all managed tasks."""
        for task in self._tasks:
            if not task.done():
                task.cancel()
        self._tasks.clear()

    @property
    def active_task_count(self) -> int:
        """Get count of active tasks."""
        return len([task for task in self._tasks if not task.done()])
|
230
|
+
|
231
|
+
|
232
|
+
def create_tool_middleware(error_handler: ErrorHandler) -> ToolMiddleware:
    """
    Build a ToolMiddleware bound to the given error handler.

    Args:
        error_handler: Error handler instance

    Returns:
        Configured ToolMiddleware
    """
    middleware = ToolMiddleware(error_handler)
    return middleware
|
243
|
+
|
244
|
+
|
245
|
+
# Convenience decorators for common patterns
|
246
|
+
|
247
|
+
def require_fields(*required_fields):
    """Decorator factory that rejects calls missing any of *required_fields*."""
    def decorator(func):
        @functools.wraps(func)
        async def wrapper(self, arguments: Dict[str, Any]):
            from ..error_handler import ValidationError

            missing = [name for name in required_fields if name not in arguments]
            if missing:
                raise ValidationError(f"Missing required fields: {', '.join(missing)}")

            return await func(self, arguments)
        return wrapper
    return decorator
|
261
|
+
|
262
|
+
|
263
|
+
def handle_file_operations(func):
    """Decorator that surfaces OS-level failures as FileSystemError."""
    @functools.wraps(func)
    async def wrapper(*args, **kwargs):
        try:
            return await func(*args, **kwargs)
        except (FileNotFoundError, PermissionError, OSError) as e:
            # The project error type is only needed on the failure path.
            from ..error_handler import FileSystemError
            raise FileSystemError(f"File operation failed: {e}") from e
    return wrapper
|
273
|
+
|
274
|
+
|
275
|
+
def handle_database_operations(func):
    """Decorator that rewraps database-related failures as DatabaseError."""
    @functools.wraps(func)
    async def wrapper(*args, **kwargs):
        try:
            return await func(*args, **kwargs)
        except Exception as e:
            # Heuristic: only errors whose message mentions the database
            # layer are rewrapped; everything else propagates untouched.
            if any(keyword in str(e).lower() for keyword in ["database", "sqlite", "sql"]):
                from ..error_handler import DatabaseError
                raise DatabaseError(f"Database operation failed: {e}") from e
            raise
    return wrapper
|
@@ -0,0 +1 @@
|
|
1
|
+
"""MCP server implementation modules."""
|