mcp-code-indexer 1.0.0 (mcp_code_indexer-1.0.0-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,386 @@
+ """
+ Two-phase merge functionality for branch descriptions.
+
+ This module implements conflict detection and resolution for merging
+ file descriptions between branches with AI-assisted conflict resolution.
+ """
+
+ import logging
+ from datetime import datetime, timedelta
+ from typing import Dict, List, Optional, Tuple
+ from uuid import uuid4
+
+ from mcp_code_indexer.database.database import DatabaseManager
+ from mcp_code_indexer.database.models import FileDescription
+ from mcp_code_indexer.error_handler import ValidationError, DatabaseError
+ from mcp_code_indexer.logging_config import get_logger
+
+ logger = get_logger(__name__)
+
+
+ class MergeConflict:
+     """Represents a merge conflict between file descriptions."""
+
+     def __init__(
+         self,
+         file_path: str,
+         source_branch: str,
+         target_branch: str,
+         source_description: str,
+         target_description: str,
+         conflict_id: Optional[str] = None
+     ):
+         """
+         Initialize merge conflict.
+
+         Args:
+             file_path: Path to the conflicted file
+             source_branch: Branch being merged from
+             target_branch: Branch being merged into
+             source_description: Description from the source branch
+             target_description: Description from the target branch
+             conflict_id: Optional conflict identifier
+         """
+         self.file_path = file_path
+         self.source_branch = source_branch
+         self.target_branch = target_branch
+         self.source_description = source_description
+         self.target_description = target_description
+         self.conflict_id = conflict_id or str(uuid4())
+         self.resolution: Optional[str] = None
+
+     def to_dict(self) -> Dict:
+         """Convert conflict to dictionary representation."""
+         return {
+             "conflictId": self.conflict_id,
+             "filePath": self.file_path,
+             "sourceBranch": self.source_branch,
+             "targetBranch": self.target_branch,
+             "sourceDescription": self.source_description,
+             "targetDescription": self.target_description,
+             "resolution": self.resolution
+         }
+
+
+ class MergeSession:
+     """Manages a merge session with conflicts and resolutions."""
+
+     def __init__(self, project_id: str, source_branch: str, target_branch: str):
+         """
+         Initialize merge session.
+
+         Args:
+             project_id: Project identifier
+             source_branch: Branch being merged from
+             target_branch: Branch being merged into
+         """
+         self.session_id = str(uuid4())
+         self.project_id = project_id
+         self.source_branch = source_branch
+         self.target_branch = target_branch
+         self.conflicts: List[MergeConflict] = []
+         self.created = datetime.utcnow()
+         self.status = "pending"  # pending, resolved, aborted
+
+     def add_conflict(self, conflict: MergeConflict) -> None:
+         """Add a conflict to the session."""
+         self.conflicts.append(conflict)
+
+     def get_conflict_count(self) -> int:
+         """Get the total number of conflicts."""
+         return len(self.conflicts)
+
+     def get_resolved_count(self) -> int:
+         """Get the number of resolved conflicts."""
+         return len([c for c in self.conflicts if c.resolution is not None])
+
+     def is_fully_resolved(self) -> bool:
+         """Check whether all conflicts are resolved."""
+         return self.get_resolved_count() == self.get_conflict_count()
+
+     def to_dict(self) -> Dict:
+         """Convert session to dictionary representation."""
+         return {
+             "sessionId": self.session_id,
+             "projectId": self.project_id,
+             "sourceBranch": self.source_branch,
+             "targetBranch": self.target_branch,
+             "totalConflicts": self.get_conflict_count(),
+             "resolvedConflicts": self.get_resolved_count(),
+             "isFullyResolved": self.is_fully_resolved(),
+             "created": self.created.isoformat(),
+             "status": self.status,
+             "conflicts": [conflict.to_dict() for conflict in self.conflicts]
+         }
+
+
+ class MergeHandler:
+     """
+     Handles two-phase merge operations for file descriptions.
+
+     Phase 1: Detect conflicts between source and target branches.
+     Phase 2: Apply resolutions and complete the merge.
+     """
+
+     def __init__(self, db_manager: DatabaseManager):
+         """
+         Initialize merge handler.
+
+         Args:
+             db_manager: Database manager instance
+         """
+         self.db_manager = db_manager
+         self._active_sessions: Dict[str, MergeSession] = {}
+
+     async def start_merge_phase1(
+         self,
+         project_id: str,
+         source_branch: str,
+         target_branch: str
+     ) -> MergeSession:
+         """
+         Phase 1: Detect merge conflicts.
+
+         Args:
+             project_id: Project identifier
+             source_branch: Branch to merge from
+             target_branch: Branch to merge into
+
+         Returns:
+             MergeSession with detected conflicts
+
+         Raises:
+             ValidationError: If branches are invalid
+             DatabaseError: If a database operation fails
+         """
+         if source_branch == target_branch:
+             raise ValidationError("Source and target branches cannot be the same")
+
+         logger.info(f"Starting merge phase 1: {source_branch} -> {target_branch}")
+
+         try:
+             # Get file descriptions from both branches
+             source_descriptions = await self.db_manager.get_all_file_descriptions(
+                 project_id, source_branch
+             )
+             target_descriptions = await self.db_manager.get_all_file_descriptions(
+                 project_id, target_branch
+             )
+
+             # Create session
+             session = MergeSession(project_id, source_branch, target_branch)
+
+             # Build lookup dictionaries keyed by file path
+             source_lookup = {desc.file_path: desc for desc in source_descriptions}
+             target_lookup = {desc.file_path: desc for desc in target_descriptions}
+
+             # Detect conflicts
+             conflicts_found = 0
+             all_files = set(source_lookup.keys()) | set(target_lookup.keys())
+
+             for file_path in all_files:
+                 source_desc = source_lookup.get(file_path)
+                 target_desc = target_lookup.get(file_path)
+
+                 # A conflict occurs when the file exists in both branches
+                 # with different descriptions
+                 if source_desc and target_desc:
+                     if source_desc.description != target_desc.description:
+                         conflict = MergeConflict(
+                             file_path=file_path,
+                             source_branch=source_branch,
+                             target_branch=target_branch,
+                             source_description=source_desc.description,
+                             target_description=target_desc.description
+                         )
+                         session.add_conflict(conflict)
+                         conflicts_found += 1
+
+             # Store session
+             self._active_sessions[session.session_id] = session
+
+             logger.info(f"Merge phase 1 completed: {conflicts_found} conflicts found")
+
+             return session
+
+         except Exception as e:
+             logger.error(f"Error in merge phase 1: {e}")
+             raise DatabaseError(f"Failed to detect merge conflicts: {e}") from e
+
+     async def complete_merge_phase2(
+         self,
+         session_id: str,
+         conflict_resolutions: List[Dict[str, str]]
+     ) -> Dict:
+         """
+         Phase 2: Apply resolutions and complete merge.
+
+         Args:
+             session_id: Merge session identifier
+             conflict_resolutions: List of {conflictId, resolvedDescription}
+
+         Returns:
+             Merge result summary
+
+         Raises:
+             ValidationError: If session not found or resolutions invalid
+             DatabaseError: If a database operation fails
+         """
+         session = self._active_sessions.get(session_id)
+         if not session:
+             raise ValidationError(f"Merge session not found: {session_id}")
+
+         logger.info(f"Starting merge phase 2 for session {session_id}")
+
+         try:
+             # Validate and apply resolutions
+             resolution_lookup = {res["conflictId"]: res["resolvedDescription"]
+                                  for res in conflict_resolutions}
+
+             resolved_count = 0
+             for conflict in session.conflicts:
+                 if conflict.conflict_id in resolution_lookup:
+                     conflict.resolution = resolution_lookup[conflict.conflict_id]
+                     resolved_count += 1
+
+             # Check that all conflicts are resolved
+             if not session.is_fully_resolved():
+                 unresolved = session.get_conflict_count() - session.get_resolved_count()
+                 raise ValidationError(
+                     f"Not all conflicts resolved: {unresolved} remaining",
+                     details={
+                         "total_conflicts": session.get_conflict_count(),
+                         "resolved_conflicts": session.get_resolved_count(),
+                         "unresolved_conflicts": unresolved
+                     }
+                 )
+
+             # Apply merge
+             merged_descriptions = []
+
+             # Get all descriptions from the source branch
+             source_descriptions = await self.db_manager.get_all_file_descriptions(
+                 session.project_id, session.source_branch
+             )
+
+             # Get existing target descriptions
+             target_descriptions = await self.db_manager.get_all_file_descriptions(
+                 session.project_id, session.target_branch
+             )
+
+             target_lookup = {desc.file_path: desc for desc in target_descriptions}
+
+             # Apply resolved descriptions
+             for source_desc in source_descriptions:
+                 resolved_conflict = next(
+                     (c for c in session.conflicts if c.file_path == source_desc.file_path),
+                     None
+                 )
+
+                 if resolved_conflict:
+                     # Use the resolved description
+                     new_desc = FileDescription(
+                         project_id=session.project_id,
+                         branch=session.target_branch,
+                         file_path=source_desc.file_path,
+                         description=resolved_conflict.resolution,
+                         file_hash=source_desc.file_hash,
+                         last_modified=datetime.utcnow(),
+                         version=1,
+                         source_project_id=source_desc.source_project_id
+                     )
+                 else:
+                     # No conflict, copy from source
+                     new_desc = FileDescription(
+                         project_id=session.project_id,
+                         branch=session.target_branch,
+                         file_path=source_desc.file_path,
+                         description=source_desc.description,
+                         file_hash=source_desc.file_hash,
+                         last_modified=datetime.utcnow(),
+                         version=1,
+                         source_project_id=source_desc.source_project_id
+                     )
+
+                 merged_descriptions.append(new_desc)
+
+             # Batch update the target branch
+             await self.db_manager.batch_create_file_descriptions(merged_descriptions)
+
+             # Mark the session as completed
+             session.status = "resolved"
+
+             result = {
+                 "success": True,
+                 "sessionId": session_id,
+                 "sourceBranch": session.source_branch,
+                 "targetBranch": session.target_branch,
+                 "totalConflicts": session.get_conflict_count(),
+                 "resolvedConflicts": session.get_resolved_count(),
+                 "mergedFiles": len(merged_descriptions),
+                 "message": f"Successfully merged {len(merged_descriptions)} files from {session.source_branch} to {session.target_branch}"
+             }
+
+             logger.info(f"Merge phase 2 completed successfully: {len(merged_descriptions)} files merged")
+
+             # Clean up the session
+             del self._active_sessions[session_id]
+
+             return result
+
+         except ValidationError:
+             # Propagate validation failures as documented instead of wrapping them
+             session.status = "aborted"
+             raise
+         except Exception as e:
+             session.status = "aborted"
+             logger.error(f"Error in merge phase 2: {e}")
+             raise DatabaseError(f"Failed to complete merge: {e}") from e
+
+     def get_session(self, session_id: str) -> Optional[MergeSession]:
+         """Get merge session by ID."""
+         return self._active_sessions.get(session_id)
+
+     def get_active_sessions(self) -> List[MergeSession]:
+         """Get all active merge sessions."""
+         return list(self._active_sessions.values())
+
+     def abort_session(self, session_id: str) -> bool:
+         """
+         Abort a merge session.
+
+         Args:
+             session_id: Session to abort
+
+         Returns:
+             True if the session was aborted
+         """
+         session = self._active_sessions.get(session_id)
+         if session:
+             session.status = "aborted"
+             del self._active_sessions[session_id]
+             logger.info(f"Merge session {session_id} aborted")
+             return True
+         return False
+
+     def cleanup_old_sessions(self, max_age_hours: int = 24) -> int:
+         """
+         Clean up old merge sessions.
+
+         Args:
+             max_age_hours: Maximum age, in hours, of sessions to keep
+
+         Returns:
+             Number of sessions cleaned up
+         """
+         cutoff_time = datetime.utcnow() - timedelta(hours=max_age_hours)
+         old_sessions = [
+             session_id for session_id, session in self._active_sessions.items()
+             if session.created < cutoff_time
+         ]
+
+         for session_id in old_sessions:
+             del self._active_sessions[session_id]
+
+         if old_sessions:
+             logger.info(f"Cleaned up {len(old_sessions)} old merge sessions")
+
+         return len(old_sessions)
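For orientation, here is a minimal sketch of how a caller might drive the two-phase flow defined above. It is illustrative only: the import path, the branch names, the db_manager argument, and the prefer-source resolution strategy are assumptions, not part of the published package.

# Hypothetical driver for MergeHandler's two-phase flow (illustrative only).
from mcp_code_indexer.merge_handler import MergeHandler  # module path assumed

async def merge_branch_descriptions(db_manager, project_id: str) -> dict:
    handler = MergeHandler(db_manager)

    # Phase 1: detect conflicts between the two branches (names are examples).
    session = await handler.start_merge_phase1(project_id, "feature/login", "main")
    if session.get_conflict_count() == 0:
        # No conflicts: phase 2 can run with an empty resolution list.
        return await handler.complete_merge_phase2(session.session_id, [])

    # Phase 2: supply one resolvedDescription per conflictId
    # (here we simply prefer the source description; a real caller would
    # typically have an AI agent or a human produce the merged text).
    resolutions = [
        {"conflictId": c.conflict_id, "resolvedDescription": c.source_description}
        for c in session.conflicts
    ]
    return await handler.complete_merge_phase2(session.session_id, resolutions)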
@@ -0,0 +1,7 @@
+ """
+ Middleware components for the MCP Code Indexer.
+ """
+
+ from .error_middleware import ToolMiddleware, AsyncTaskManager, create_tool_middleware
+
+ __all__ = ["ToolMiddleware", "AsyncTaskManager", "create_tool_middleware"]
@@ -0,0 +1,286 @@
+ """
+ Error handling middleware for MCP tools.
+
+ This module provides decorators and middleware functions to standardize
+ error handling across all MCP tool implementations.
+ """
+
+ import asyncio
+ import functools
+ import time
+ from typing import Any, Callable, Dict, List, Optional
+
+ from mcp import types
+
+ from mcp_code_indexer.error_handler import ErrorHandler, MCPError
+ from mcp_code_indexer.logging_config import get_logger, log_tool_usage, log_performance_metrics
+
+ logger = get_logger(__name__)
+
+
+ class ToolMiddleware:
+     """Middleware for MCP tool error handling and logging."""
+
+     def __init__(self, error_handler: ErrorHandler):
+         """Initialize middleware with an error handler."""
+         self.error_handler = error_handler
+
+     def wrap_tool_handler(self, tool_name: str):
+         """
+         Decorator to wrap tool handlers with error handling and logging.
+
+         Args:
+             tool_name: Name of the MCP tool
+
+         Returns:
+             Decorator function
+         """
+         def decorator(func: Callable) -> Callable:
+             @functools.wraps(func)
+             async def wrapper(arguments: Dict[str, Any]) -> List[types.TextContent]:
+                 start_time = time.time()
+                 success = False
+                 result_size = 0
+
+                 try:
+                     # Log tool invocation
+                     logger.info(f"Tool {tool_name} called", extra={
+                         "structured_data": {
+                             "tool_invocation": {
+                                 "tool_name": tool_name,
+                                 "arguments_count": len(arguments)
+                             }
+                         }
+                     })
+
+                     # Call the actual tool handler
+                     result = await func(arguments)
+
+                     # Calculate result size
+                     if isinstance(result, list):
+                         result_size = sum(len(item.text) if hasattr(item, 'text') else 0 for item in result)
+
+                     success = True
+                     duration = time.time() - start_time
+
+                     # Log performance metrics
+                     log_performance_metrics(
+                         logger,
+                         f"tool_{tool_name}",
+                         duration,
+                         result_size=result_size,
+                         arguments_count=len(arguments)
+                     )
+
+                     return result
+
+                 except Exception as e:
+                     duration = time.time() - start_time
+
+                     # Log the error
+                     self.error_handler.log_error(
+                         e,
+                         context={"arguments_count": len(arguments)},
+                         tool_name=tool_name
+                     )
+
+                     # Create error response
+                     error_response = self.error_handler.create_mcp_error_response(
+                         e, tool_name, arguments
+                     )
+
+                     return [error_response]
+
+                 finally:
+                     # Always log tool usage
+                     log_tool_usage(
+                         logger,
+                         tool_name,
+                         arguments,
+                         success,
+                         time.time() - start_time,
+                         result_size if success else None
+                     )
+
+             return wrapper
+         return decorator
+
+     def validate_tool_arguments(self, required_fields: List[str], optional_fields: Optional[List[str]] = None):
+         """
+         Decorator to validate tool arguments.
+
+         Args:
+             required_fields: List of required argument names
+             optional_fields: List of optional argument names
+
+         Returns:
+             Decorator function
+         """
+         def decorator(func: Callable) -> Callable:
+             @functools.wraps(func)
+             async def wrapper(arguments: Dict[str, Any]) -> Any:
+                 from ..error_handler import ValidationError
+
+                 # Check required fields
+                 missing_fields = [field for field in required_fields if field not in arguments]
+                 if missing_fields:
+                     raise ValidationError(
+                         f"Missing required fields: {', '.join(missing_fields)}",
+                         details={"missing_fields": missing_fields, "provided_fields": list(arguments.keys())}
+                     )
+
+                 # Check for unexpected fields if optional_fields is provided
+                 if optional_fields is not None:
+                     all_fields = set(required_fields + optional_fields)
+                     unexpected_fields = [field for field in arguments.keys() if field not in all_fields]
+                     if unexpected_fields:
+                         raise ValidationError(
+                             f"Unexpected fields: {', '.join(unexpected_fields)}",
+                             details={"unexpected_fields": unexpected_fields, "allowed_fields": list(all_fields)}
+                         )
+
+                 return await func(arguments)
+
+             return wrapper
+         return decorator
+
+
+ class AsyncTaskManager:
+     """Manages async tasks with proper error handling."""
+
+     def __init__(self, error_handler: ErrorHandler):
+         """Initialize task manager."""
+         self.error_handler = error_handler
+         self._tasks: List[asyncio.Task] = []
+
+     def create_task(self, coro, name: Optional[str] = None) -> asyncio.Task:
+         """
+         Create a managed async task.
+
+         Args:
+             coro: Coroutine to run
+             name: Optional task name for logging
+
+         Returns:
+             Created task
+         """
+         task = asyncio.create_task(coro, name=name)
+         self._tasks.append(task)
+
+         # Add done callback for error handling
+         task.add_done_callback(
+             lambda t: asyncio.create_task(
+                 self._handle_task_completion(t, name or "unnamed_task")
+             )
+         )
+
+         return task
+
+     async def _handle_task_completion(self, task: asyncio.Task, task_name: str) -> None:
+         """Handle task completion and errors."""
+         try:
+             if task.done() and not task.cancelled():
+                 exception = task.exception()
+                 if exception:
+                     await self.error_handler.handle_async_task_error(
+                         task, task_name
+                     )
+         except Exception as e:
+             logger.error(f"Error handling task completion for {task_name}: {e}")
+         finally:
+             # Remove the completed task from tracking
+             if task in self._tasks:
+                 self._tasks.remove(task)
+
+     async def wait_for_all(self, timeout: Optional[float] = None) -> None:
+         """
+         Wait for all managed tasks to complete.
+
+         Args:
+             timeout: Maximum time to wait in seconds
+         """
+         if not self._tasks:
+             return
+
+         try:
+             await asyncio.wait_for(
+                 asyncio.gather(*self._tasks, return_exceptions=True),
+                 timeout=timeout
+             )
+         except asyncio.TimeoutError:
+             logger.warning(f"Timeout waiting for {len(self._tasks)} tasks")
+             # Cancel remaining tasks
+             for task in self._tasks:
+                 if not task.done():
+                     task.cancel()
+         except Exception as e:
+             logger.error(f"Error waiting for tasks: {e}")
+
+     def cancel_all(self) -> None:
+         """Cancel all managed tasks."""
+         for task in self._tasks:
+             if not task.done():
+                 task.cancel()
+         self._tasks.clear()
+
+     @property
+     def active_task_count(self) -> int:
+         """Get the count of active tasks."""
+         return len([task for task in self._tasks if not task.done()])
+
+
+ def create_tool_middleware(error_handler: ErrorHandler) -> ToolMiddleware:
+     """
+     Create a tool middleware instance.
+
+     Args:
+         error_handler: Error handler instance
+
+     Returns:
+         Configured ToolMiddleware
+     """
+     return ToolMiddleware(error_handler)
+
+
+ # Convenience decorators for common patterns
+
+ def require_fields(*required_fields):
+     """Decorator that requires specific fields in arguments."""
+     def decorator(func):
+         @functools.wraps(func)
+         async def wrapper(self, arguments: Dict[str, Any]):
+             from ..error_handler import ValidationError
+
+             missing = [field for field in required_fields if field not in arguments]
+             if missing:
+                 raise ValidationError(f"Missing required fields: {', '.join(missing)}")
+
+             return await func(self, arguments)
+         return wrapper
+     return decorator
+
+
+ def handle_file_operations(func):
+     """Decorator for file operation error handling."""
+     @functools.wraps(func)
+     async def wrapper(*args, **kwargs):
+         try:
+             return await func(*args, **kwargs)
+         except (FileNotFoundError, PermissionError, OSError) as e:
+             from ..error_handler import FileSystemError
+             raise FileSystemError(f"File operation failed: {e}") from e
+     return wrapper
+
+
+ def handle_database_operations(func):
+     """Decorator for database operation error handling."""
+     @functools.wraps(func)
+     async def wrapper(*args, **kwargs):
+         try:
+             return await func(*args, **kwargs)
+         except Exception as e:
+             if any(keyword in str(e).lower() for keyword in ["database", "sqlite", "sql"]):
+                 from ..error_handler import DatabaseError
+                 raise DatabaseError(f"Database operation failed: {e}") from e
+             raise
+     return wrapper
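As a rough usage sketch (not part of the package), the middleware above might be wired around a tool handler like this. The tool name, argument field names, and handler body are hypothetical; only create_tool_middleware, wrap_tool_handler, and validate_tool_arguments come from the module above.

# Hypothetical wiring of ToolMiddleware around an MCP tool handler (illustrative only).
from typing import Any, Dict, List
from mcp import types

async def handle_get_file_description(arguments: Dict[str, Any]) -> List[types.TextContent]:
    # Placeholder body; a real handler would look up the description in the database.
    return [types.TextContent(type="text", text=f"description for {arguments['filePath']}")]

def build_wrapped_handler(error_handler):
    middleware = create_tool_middleware(error_handler)
    # Validation runs first, then the error/logging wrapper surrounds the whole call.
    wrapped = middleware.wrap_tool_handler("get_file_description")(
        middleware.validate_tool_arguments(["projectId", "filePath"], ["branch"])(
            handle_get_file_description
        )
    )
    return wrapped  # usage: await wrapped({"projectId": "p1", "filePath": "src/app.py"})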
@@ -0,0 +1 @@
+ """MCP server implementation modules."""