code-memory 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
server.py ADDED
@@ -0,0 +1,299 @@
1
+ """
2
+ code-memory MCP Server
3
+
4
+ A deterministic, high-precision code intelligence layer exposed via the
5
+ Model Context Protocol (MCP). Uses a "Progressive Disclosure" routing
6
+ architecture:
7
+
8
+ 1. "Who/Why?" → search_history (Git data)
9
+ 2. "Where/What?" → search_code (AST data + hybrid retrieval)
10
+ 3. "How?" → search_docs (Documentation data + hybrid retrieval)
11
+ """
12
+
13
+ from __future__ import annotations
14
+
15
+ from typing import Literal
16
+
17
+ from mcp.server.fastmcp import FastMCP
18
+
19
+ import db as db_mod
20
+ import doc_parser as doc_parser_mod
21
+ import errors
22
+ import logging_config
23
+ import parser as parser_mod
24
+ import queries
25
+ import validation as val
26
+
27
# ── Initialize logging ───────────────────────────────────────────────────
# Runs once at import time, before any tool below is registered or called.
# `logger` is whatever logging_config.setup_logging() returns; `tool_logger`
# is the logger obtained from logging_config.get_logger under the name "tools".
logger = logging_config.setup_logging()
tool_logger = logging_config.get_logger("tools")

# ── Initialize the FastMCP server ────────────────────────────────────────
# Single server instance named "code-memory". The @mcp.tool() decorators
# below register each tool function on it, and main() starts it via mcp.run().
mcp = FastMCP("code-memory")
33
+
34
+
35
# ── Tool 1: search_code ───────────────────────────────────────────────────
@mcp.tool()
def search_code(
    query: str,
    search_type: Literal["definition", "references", "file_structure"],
) -> dict:
    """Search the indexed codebase for definitions, references, or file
    structure.

    Uses hybrid retrieval (BM25 keyword search + dense vector semantic
    search) with Reciprocal Rank Fusion for definition queries.

    - **definition**: Find where a symbol is defined (hybrid search).
    - **references**: Find all cross-references to a symbol name.
    - **file_structure**: List all symbols in a file, ordered by line.

    Run ``index_codebase`` first to populate the search index."""
    # NOTE: the docstring above is the MCP tool description shown to clients,
    # so implementation notes live in comments instead.
    #
    # Errors inside the try block are returned as dicts rather than raised:
    # CodeMemoryError subclasses via their to_dict(), anything else via
    # errors.format_error().
    with logging_config.ToolLogger("search_code", query=query, search_type=search_type) as log:
        try:
            query = val.validate_query(query)
            search_type = val.validate_search_type(
                search_type, ["definition", "references", "file_structure"]
            )

            database = db_mod.get_db()

            # One query function per accepted search_type. Built per call so
            # the attribute lookups on `queries` happen at call time.
            handlers = {
                "definition": queries.find_definition,
                "references": queries.find_references,
                "file_structure": queries.get_file_structure,
            }
            handler = handlers.get(search_type)
            if handler is None:
                # Defensive only: validate_search_type should already have
                # rejected anything outside the allowed list.
                return errors.format_error(
                    errors.ValidationError(f"Unknown search_type: {search_type}")
                )

            results = handler(query, database)
            log.set_result_count(len(results))
            return {"status": "ok", "search_type": search_type, "query": query, "results": results}

        except errors.CodeMemoryError as e:
            return e.to_dict()
        except Exception as e:
            return errors.format_error(e)
83
+
84
+
85
# ── Tool 2: index_codebase ────────────────────────────────────────────────
def _run_indexer(
    label: str,
    directory: str,
    index_fn,
    database,
    count_key: str,
) -> tuple[list[dict], list[dict]]:
    """Run one indexer over *directory* with per-file progress logging.

    Creates an ``IndexingLogger`` for *label*, calls
    ``index_fn(directory, database)``, reports every per-file result (skipped
    files with their ``"reason"``, indexed files with their *count_key*
    count) and then marks the logger complete.

    Args:
        label: Name passed to ``logging_config.IndexingLogger``.
        directory: Directory path handed to the logger and to *index_fn*.
        index_fn: Indexer callable taking ``(directory, database)`` and
            returning an iterable of per-file result dicts.
        database: Database handle forwarded to *index_fn*.
        count_key: Result key holding the per-file item count to log.

    Returns:
        ``(indexed, skipped)``: the result dicts split by their ``"skipped"``
        flag, each list in the order the indexer produced them.
    """
    indexing_logger = logging_config.IndexingLogger(label)
    indexing_logger.start(directory)

    indexed: list[dict] = []
    skipped: list[dict] = []
    # Single pass over the results, so a one-shot iterator returned by
    # index_fn is handled as well as a list.
    for r in index_fn(directory, database):
        file_name = r.get("file", "unknown")
        if r.get("skipped"):
            indexing_logger.file_skipped(file_name, r.get("reason", "unknown"))
            skipped.append(r)
        else:
            indexing_logger.file_indexed(file_name, r.get(count_key, 0))
            indexed.append(r)
    indexing_logger.complete()
    return indexed, skipped


@mcp.tool()
def index_codebase(directory: str = ".") -> dict:
    """Indexes or re-indexes source files and documentation in the given directory.

    Run this before using search_code or search_docs to ensure the database
    is up to date. Uses tree-sitter for language-agnostic structural extraction
    and generates embeddings for semantic search. Supports Python, JavaScript/
    TypeScript, Java, Kotlin, Go, Rust, C/C++, Ruby, and more.

    Also indexes markdown documentation files and extracts docstrings from
    indexed code symbols. Unchanged files (by mtime) are automatically skipped.

    Args:
        directory: The root directory to index (recursively).

    Returns:
        Summary of indexing results including code and documentation stats.
    """
    # NOTE: the docstring above is the MCP tool description shown to clients.
    # Errors inside the try block are returned as dicts rather than raised.
    with logging_config.ToolLogger("index_codebase", directory=directory) as log:
        try:
            directory_path = val.validate_directory(directory)
            dir_str = str(directory_path)

            database = db_mod.get_db()

            # Pass 1: source code (per-file symbol counts are logged).
            indexed, skipped = _run_indexer(
                "code", dir_str, parser_mod.index_directory, database, "symbols_indexed"
            )

            # Pass 2: documentation files (per-file chunk counts are logged).
            doc_indexed, doc_skipped = _run_indexer(
                "documentation",
                dir_str,
                doc_parser_mod.index_doc_directory,
                database,
                "chunks_indexed",
            )

            # Pass 3: docstrings taken from the indexed code symbols, so this
            # runs after the code pass above.
            docstring_results = doc_parser_mod.extract_docstrings_from_code(database)

            total_symbols = sum(r.get("symbols_indexed", 0) for r in indexed)
            total_chunks = sum(r.get("chunks_indexed", 0) for r in doc_indexed)
            log.set_result_count(total_symbols + total_chunks + len(docstring_results))

            return {
                "status": "ok",
                "directory": dir_str,
                "code": {
                    "files_indexed": len(indexed),
                    "files_skipped": len(skipped),
                    "total_symbols": total_symbols,
                    "total_references": sum(r.get("references_indexed", 0) for r in indexed),
                },
                "documentation": {
                    "files_indexed": len(doc_indexed),
                    "files_skipped": len(doc_skipped),
                    "total_chunks": total_chunks,
                    "docstrings_extracted": len(docstring_results),
                },
                "details": {
                    "code": indexed,
                    "docs": doc_indexed,
                },
            }

        except errors.CodeMemoryError as e:
            return e.to_dict()
        except Exception as e:
            return errors.format_error(e)
173
+
174
+
175
# ── Tool 3: search_docs ────────────────────────────────────────────────────
@mcp.tool()
def search_docs(query: str, top_k: int = 10) -> dict:
    """Use this tool to understand the codebase conceptually. Ideal for
    'how does X work?', 'explain the architecture', or finding standard
    operating procedures in the documentation.

    Uses hybrid retrieval (BM25 keyword search + dense vector semantic
    search) with Reciprocal Rank Fusion over markdown documentation,
    README files, and docstrings extracted from code.

    Args:
        query: A natural language question about the codebase.
        top_k: Maximum number of results to return (default 10).

    Returns:
        Dictionary with 'results' key containing matching documentation
        chunks, each with source attribution (file, section, line numbers)
        and relevance score.
    """
    # NOTE: the docstring above is the MCP tool description shown to clients.
    # The response echoes the *validated* query, not the raw argument.
    with logging_config.ToolLogger("search_docs", query=query, top_k=top_k) as log:
        try:
            clean_query = val.validate_query(query)
            limit = val.validate_top_k(top_k)

            database = db_mod.get_db()
            hits = queries.search_documentation(clean_query, database, top_k=limit)
            hit_count = len(hits)
            log.set_result_count(hit_count)

            return {"status": "ok", "query": clean_query, "results": hits, "count": hit_count}

        except errors.CodeMemoryError as e:
            return e.to_dict()
        except Exception as e:
            return errors.format_error(e)
216
+
217
+
218
# ── Tool 4: search_history ─────────────────────────────────────────────────
@mcp.tool()
def search_history(
    query: str,
    search_type: Literal["commits", "file_history", "blame", "commit_detail"] = "commits",
    target_file: str | None = None,
    line_start: int | None = None,
    line_end: int | None = None,
) -> dict:
    """Search local Git history to debug regressions, understand developer
    intent, or find out WHY a specific change was made.

    **search_type options:**

    - ``commits`` — Search commit messages for *query* (case-insensitive).
      Optionally filter to commits that touched *target_file*.
    - ``file_history`` — Show the commit log for *target_file* (follows
      renames). *target_file* is required; *query* is ignored.
    - ``blame`` — Run ``git blame`` on *target_file*, optionally limited to
      *line_start*–*line_end*. *target_file* is required.
    - ``commit_detail`` — Get full metadata and diff for one commit.
      Pass the commit hash as *query*. Optionally set *target_file* to
      restrict the diff to that file.
    """
    # NOTE: the docstring above is the MCP tool description shown to clients.
    # Errors inside the try block are returned as dicts rather than raised:
    # CodeMemoryError subclasses (ValidationError, GitError, ...) via
    # to_dict(), anything else via errors.format_error().
    with logging_config.ToolLogger(
        "search_history", query=query, search_type=search_type, target_file=target_file
    ) as log:
        try:
            # Imported lazily so that a missing git dependency surfaces as an
            # error dict from this tool (via the generic handler below) rather
            # than an ImportError when the server module is loaded.
            import git_search as gs
            from git.exc import InvalidGitRepositoryError, NoSuchPathError

            search_type = val.validate_search_type(
                search_type, ["commits", "file_history", "blame", "commit_detail"]
            )
            line_start, line_end = val.validate_line_range(line_start, line_end)

            # "." — presumably the server process's working directory.
            try:
                repo = gs.get_repo(".")
            except (InvalidGitRepositoryError, NoSuchPathError) as exc:
                # Chain the cause so the original git error stays in the traceback.
                raise errors.GitError(f"Git repository not found: {exc}") from exc

            if search_type == "commits":
                query = val.validate_query(query, min_length=1)
                results = gs.search_commits(repo, query, target_file)
                log.set_result_count(len(results))
                return {"status": "ok", "search_type": "commits", "query": query, "results": results}

            elif search_type == "file_history":
                if not target_file:
                    raise errors.ValidationError("target_file is required for file_history search")
                results = gs.get_file_history(repo, target_file)
                log.set_result_count(len(results))
                return {"status": "ok", "search_type": "file_history", "target_file": target_file, "results": results}

            elif search_type == "blame":
                if not target_file:
                    raise errors.ValidationError("target_file is required for blame search")
                results = gs.get_blame(repo, target_file, line_start, line_end)
                log.set_result_count(len(results))
                return {"status": "ok", "search_type": "blame", "target_file": target_file, "results": results}

            elif search_type == "commit_detail":
                # The commit hash travels in `query`. A missing or blank hash is
                # a caller error: reject it here, like the target_file checks
                # above, instead of handing an empty ref to git.
                commit_ref = (query or "").strip()
                if not commit_ref:
                    raise errors.ValidationError(
                        "query (commit hash) is required for commit_detail search"
                    )
                result = gs.get_commit_detail(repo, commit_ref, target_file)
                return {"status": "ok", "search_type": "commit_detail", "result": result}

            # Defensive only: validate_search_type should already have
            # rejected anything outside the allowed list.
            return errors.format_error(errors.ValidationError(f"Unknown search_type: {search_type}"))

        except errors.CodeMemoryError as e:
            return e.to_dict()
        except Exception as e:
            return errors.format_error(e)
290
+
291
+
292
# ── Entrypoint ────────────────────────────────────────────────────────────
def main() -> None:
    """Start the code-memory MCP server.

    Serves as the entry point when the package is installed, and is also
    invoked when this module is executed directly as a script.
    """
    mcp.run()


if __name__ == "__main__":
    main()
tests/__init__.py ADDED
@@ -0,0 +1 @@
1
+ """Test suite for code-memory."""
tests/conftest.py ADDED
@@ -0,0 +1,192 @@
1
+ """
2
+ Shared test fixtures for code-memory tests.
3
+ """
4
+
5
+ from __future__ import annotations
6
+
7
+ import os
8
+ import sys
9
+ import sqlite3
10
+ import tempfile
11
+ from pathlib import Path
12
+
13
+ # Add parent directory to path for imports
14
+ sys.path.insert(0, str(Path(__file__).parent.parent))
15
+
16
+ import pytest
17
+
18
+
19
@pytest.fixture
def temp_db():
    """Provide a SQLite connection backed by a temporary on-disk database file.

    A real file (rather than ``:memory:``) is used for sqlite-vec
    compatibility. The sqlite-vec extension is loaded on a best-effort basis.
    If the package is not installed, if this Python's ``sqlite3`` module was
    built without loadable-extension support, or if the extension fails to
    load, the connection is yielded without it so non-vector tests still run.

    The connection is closed and the file deleted on teardown, including
    when setup fails after the file was created.
    """
    with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
        db_path = f.name

    db = sqlite3.connect(db_path)
    try:
        try:
            import sqlite_vec

            db.enable_load_extension(True)
            sqlite_vec.load(db)
            db.enable_load_extension(False)
        except (ImportError, AttributeError, sqlite3.OperationalError):
            # ImportError: sqlite-vec is not installed.
            # AttributeError: sqlite3 lacks Connection.enable_load_extension
            #   (Python built without loadable-extension support).
            # OperationalError: the extension could not be loaded.
            pass  # sqlite-vec not available, skip vector tests

        yield db
    finally:
        db.close()
        os.unlink(db_path)
41
+
42
+
43
@pytest.fixture
def temp_dir():
    """Yield a fresh temporary directory as a ``Path``; it is deleted afterwards."""
    scratch = tempfile.TemporaryDirectory()
    try:
        yield Path(scratch.name)
    finally:
        scratch.cleanup()
48
+
49
+
50
@pytest.fixture
def sample_python_file(temp_dir):
    """Write a small Python module to ``sample.py`` inside *temp_dir* and
    return its path.

    The module has a module docstring, two imports, ``SampleClass`` with
    three documented methods, and two documented top-level functions.
    """
    source = '''
"""Module docstring for testing."""

import os
from typing import Optional

class SampleClass:
    """A sample class for testing."""

    def __init__(self, name: str):
        """Initialize the sample class."""
        self.name = name

    def get_name(self) -> str:
        """Return the name."""
        return self.name

    def process_data(self, data: Optional[dict] = None) -> dict:
        """Process some data."""
        if data is None:
            data = {}
        return {"name": self.name, **data}


def standalone_function(x: int, y: int) -> int:
    """A standalone function that adds two numbers."""
    return x + y


def another_function(text: str) -> str:
    """Another function for testing."""
    return text.upper()
'''
    target = temp_dir.joinpath("sample.py")
    target.write_text(source)
    return target
89
+
90
+
91
@pytest.fixture
def sample_markdown_file(temp_dir):
    """Write a ``README.md`` with nested headings, a fenced code block, and
    ordered and unordered lists into *temp_dir*, and return its path."""
    markdown = """# Sample Documentation

This is a sample documentation file for testing.

## Installation

To install, run:

```bash
pip install code-memory
```

## Usage

Here's how to use the tool:

1. Index your codebase
2. Search for symbols

## Architecture

The system uses a Progressive Disclosure architecture.

### Components

- search_code: Find definitions
- search_docs: Find documentation
- search_history: Search git history
"""
    readme = temp_dir.joinpath("README.md")
    readme.write_text(markdown)
    return readme
126
+
127
+
128
@pytest.fixture
def temp_git_repo(temp_dir):
    """Provide *temp_dir* initialised as a git repository with one commit.

    The repository gets a repo-local test identity and a single commit,
    "Initial commit", containing ``test.py``.
    """
    import subprocess

    def git(*args: str) -> None:
        # check=True fails the fixture on any git error; output is captured
        # to keep test logs quiet.
        subprocess.run(["git", *args], cwd=temp_dir, check=True, capture_output=True)

    git("init")
    git("config", "user.email", "test@test.com")
    git("config", "user.name", "Test User")

    test_file = temp_dir / "test.py"
    test_file.write_text("# Test file\nprint('hello')\n")
    git("add", ".")
    # --no-gpg-sign: with a global commit.gpgsign=true, git would otherwise
    # try to sign this throwaway commit and fail (or prompt) on machines
    # without a usable signing key.
    git("commit", "--no-gpg-sign", "-m", "Initial commit")

    yield temp_dir
145
+
146
+
147
@pytest.fixture
def sample_symbols_db(temp_db):
    """Provide *temp_db* with a minimal ``files``/``symbols`` schema and sample rows.

    One file (``/test/sample.py``) is inserted with four symbols:
    ``SampleClass``, its two methods ``__init__`` and ``get_name`` (whose
    ``parent_symbol_id`` points at ``SampleClass``), and
    ``standalone_function``, which has no parent.
    """
    temp_db.execute("""
        CREATE TABLE IF NOT EXISTS files (
            id INTEGER PRIMARY KEY,
            path TEXT UNIQUE NOT NULL,
            last_modified REAL NOT NULL,
            file_hash TEXT NOT NULL
        )
    """)

    temp_db.execute("""
        CREATE TABLE IF NOT EXISTS symbols (
            id INTEGER PRIMARY KEY,
            name TEXT NOT NULL,
            kind TEXT NOT NULL,
            file_id INTEGER NOT NULL,
            line_start INTEGER NOT NULL,
            line_end INTEGER NOT NULL,
            parent_symbol_id INTEGER,
            source_text TEXT NOT NULL
        )
    """)

    # lastrowid is a Cursor attribute: read it from the cursor that execute()
    # returns. sqlite3.Connection has no lastrowid attribute.
    file_cursor = temp_db.execute(
        "INSERT INTO files (path, last_modified, file_hash) VALUES (?, ?, ?)",
        ("/test/sample.py", 0.0, "abc123"),
    )
    file_id = file_cursor.lastrowid

    # (name, kind, line_start, line_end, parent symbol name or None, source_text)
    # Parents are listed before their children so their ids are known.
    symbols = [
        ("SampleClass", "class", 5, 20, None, "class SampleClass: ..."),
        ("__init__", "method", 8, 10, "SampleClass", "def __init__(self, name): ..."),
        ("get_name", "method", 12, 14, "SampleClass", "def get_name(self): ..."),
        ("standalone_function", "function", 22, 24, None, "def standalone_function(x, y): ..."),
    ]

    symbol_ids: dict[str, int] = {}
    for name, kind, line_start, line_end, parent_name, source in symbols:
        # Use the id actually assigned to the parent rather than assuming 1.
        parent_id = symbol_ids[parent_name] if parent_name is not None else None
        cursor = temp_db.execute(
            "INSERT INTO symbols (name, kind, file_id, line_start, line_end, parent_symbol_id, source_text) VALUES (?, ?, ?, ?, ?, ?, ?)",
            (name, kind, file_id, line_start, line_end, parent_id, source),
        )
        symbol_ids[name] = cursor.lastrowid

    temp_db.commit()
    return temp_db
tests/test_errors.py ADDED
@@ -0,0 +1,112 @@
1
+ """Tests for error handling module."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import pytest
6
+
7
+ from errors import (
8
+ CodeMemoryError,
9
+ DatabaseError,
10
+ IndexingError,
11
+ GitError,
12
+ ValidationError,
13
+ EmbeddingError,
14
+ format_error,
15
+ )
16
+
17
+
18
class TestCodeMemoryError:
    """Tests for the base CodeMemoryError class: message, details and to_dict()."""

    def test_basic_error(self):
        """The message is stored on .message and is also str(error)."""
        error = CodeMemoryError("Test error")
        assert error.message == "Test error"
        assert str(error) == "Test error"

    def test_error_with_details(self):
        """A details dict passed as the second argument is stored on .details."""
        error = CodeMemoryError("Test error", {"key": "value"})
        assert error.details == {"key": "value"}

    def test_to_dict(self):
        """to_dict() flags the error and reports the class name and message."""
        error = CodeMemoryError("Test error")
        result = error.to_dict()
        assert result["error"] is True
        assert result["error_type"] == "CodeMemoryError"
        assert result["message"] == "Test error"

    def test_to_dict_with_details(self):
        """Non-empty details are included unchanged in to_dict()."""
        error = CodeMemoryError("Test error", {"key": "value"})
        result = error.to_dict()
        assert result["details"] == {"key": "value"}

    def test_to_dict_without_details(self):
        """Empty details ({}) are reported as None in to_dict()."""
        error = CodeMemoryError("Test error", {})
        result = error.to_dict()
        assert result["details"] is None
51
+
52
+
53
class TestSpecializedErrors:
    """Each specialised error subclasses CodeMemoryError and reports its own
    class name as ``error_type`` in to_dict()."""

    @staticmethod
    def _check(error, expected_type):
        # Shared assertions: subclass relationship, then the serialized type name.
        assert isinstance(error, CodeMemoryError)
        assert error.to_dict()["error_type"] == expected_type

    def test_database_error(self):
        """Test DatabaseError."""
        self._check(DatabaseError("Connection failed"), "DatabaseError")

    def test_indexing_error(self):
        """Test IndexingError."""
        self._check(IndexingError("Parse failed"), "IndexingError")

    def test_git_error(self):
        """Test GitError."""
        self._check(GitError("Not a git repo"), "GitError")

    def test_validation_error(self):
        """Test ValidationError."""
        self._check(ValidationError("Invalid input"), "ValidationError")

    def test_embedding_error(self):
        """Test EmbeddingError."""
        self._check(EmbeddingError("Model load failed"), "EmbeddingError")
85
+
86
+
87
class TestFormatError:
    """format_error() converts any exception into the standard error dict."""

    @staticmethod
    def _assert_fields(payload, **expected):
        # The "error" flag must be the literal True (identity, not truthiness);
        # every other given field must compare equal.
        assert payload["error"] is True
        for field, value in expected.items():
            assert payload[field] == value

    def test_format_code_memory_error(self):
        """A CodeMemoryError subclass keeps its type, message and details."""
        self._assert_fields(
            format_error(ValidationError("Invalid input", {"field": "query"})),
            error_type="ValidationError",
            message="Invalid input",
            details={"field": "query"},
        )

    def test_format_builtin_exception(self):
        """A built-in exception is reported with its class name and message."""
        self._assert_fields(
            format_error(ValueError("Something went wrong")),
            error_type="ValueError",
            message="Something went wrong",
        )

    def test_format_exception_without_message(self):
        """An exception with no message still yields one naming its type."""
        payload = format_error(RuntimeError())
        self._assert_fields(payload)
        assert "RuntimeError" in payload["message"]