code-memory 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- .github/workflows/ci.yml +71 -0
- .github/workflows/publish.yml +33 -0
- .gitignore +40 -0
- .python-version +1 -0
- CHANGELOG.md +43 -0
- CONTRIBUTING.md +133 -0
- LICENSE +21 -0
- Makefile +33 -0
- PKG-INFO +275 -0
- README.md +233 -0
- code_memory-0.1.0.dist-info/METADATA +275 -0
- code_memory-0.1.0.dist-info/RECORD +37 -0
- code_memory-0.1.0.dist-info/WHEEL +4 -0
- code_memory-0.1.0.dist-info/entry_points.txt +2 -0
- code_memory-0.1.0.dist-info/licenses/LICENSE +21 -0
- db.py +403 -0
- doc_parser.py +494 -0
- errors.py +115 -0
- git_search.py +313 -0
- logging_config.py +191 -0
- parser.py +392 -0
- prompts/milestone_1.xml +62 -0
- prompts/milestone_2.xml +246 -0
- prompts/milestone_3.xml +214 -0
- prompts/milestone_4.xml +453 -0
- prompts/milestone_5.xml +599 -0
- pyproject.toml +92 -0
- queries.py +446 -0
- server.py +299 -0
- tests/__init__.py +1 -0
- tests/conftest.py +192 -0
- tests/test_errors.py +112 -0
- tests/test_logging.py +169 -0
- tests/test_tools.py +114 -0
- tests/test_validation.py +216 -0
- uv.lock +1921 -0
- validation.py +316 -0
server.py
ADDED
|
@@ -0,0 +1,299 @@
|
|
|
1
|
+
"""
|
|
2
|
+
code-memory MCP Server
|
|
3
|
+
|
|
4
|
+
A deterministic, high-precision code intelligence layer exposed via the
|
|
5
|
+
Model Context Protocol (MCP). Uses a "Progressive Disclosure" routing
|
|
6
|
+
architecture:
|
|
7
|
+
|
|
8
|
+
1. "Who/Why?" → search_history (Git data)
|
|
9
|
+
2. "Where/What?" → search_code (AST data + hybrid retrieval)
|
|
10
|
+
3. "How?" → search_docs (Semantic / Fuzzy logic)
|
|
11
|
+
"""
|
|
12
|
+
|
|
13
|
+
from __future__ import annotations
|
|
14
|
+
|
|
15
|
+
from typing import Literal
|
|
16
|
+
|
|
17
|
+
from mcp.server.fastmcp import FastMCP
|
|
18
|
+
|
|
19
|
+
import db as db_mod
|
|
20
|
+
import doc_parser as doc_parser_mod
|
|
21
|
+
import errors
|
|
22
|
+
import logging_config
|
|
23
|
+
import parser as parser_mod
|
|
24
|
+
import queries
|
|
25
|
+
import validation as val
|
|
26
|
+
|
|
27
|
+
# ── Initialize logging ───────────────────────────────────────────────────
|
|
28
|
+
logger = logging_config.setup_logging()
|
|
29
|
+
tool_logger = logging_config.get_logger("tools")
|
|
30
|
+
|
|
31
|
+
# ── Initialize the FastMCP server ────────────────────────────────────────
|
|
32
|
+
mcp = FastMCP("code-memory")
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
# ── Tool 1: search_code ───────────────────────────────────────────────────
|
|
36
|
+
@mcp.tool()
def search_code(
    query: str,
    search_type: Literal["definition", "references", "file_structure"],
) -> dict:
    """Search the indexed codebase for definitions, references, or file
    structure.

    Uses hybrid retrieval (BM25 keyword search + dense vector semantic
    search) with Reciprocal Rank Fusion for definition queries.

    - **definition**: Find where a symbol is defined (hybrid search).
    - **references**: Find all cross-references to a symbol name.
    - **file_structure**: List all symbols in a file, ordered by line.

    Run ``index_codebase`` first to populate the search index."""
    with logging_config.ToolLogger("search_code", query=query, search_type=search_type) as log:
        try:
            # Validate inputs before touching the database.
            query = val.validate_query(query)
            search_type = val.validate_search_type(
                search_type, ["definition", "references", "file_structure"]
            )

            database = db_mod.get_db()

            # Each search type maps to one query function with the same
            # (query, database) signature, so dispatch through a table.
            handlers = {
                "definition": queries.find_definition,
                "references": queries.find_references,
                "file_structure": queries.get_file_structure,
            }
            handler = handlers.get(search_type)
            if handler is None:
                # Defensive fallback; validate_search_type should prevent this.
                return errors.format_error(
                    errors.ValidationError(f"Unknown search_type: {search_type}")
                )

            results = handler(query, database)
            log.set_result_count(len(results))
            return {"status": "ok", "search_type": search_type, "query": query, "results": results}

        except errors.CodeMemoryError as e:
            return e.to_dict()
        except Exception as e:
            return errors.format_error(e)
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
# ── Tool 2: index_codebase ────────────────────────────────────────────────
|
|
86
|
+
def _index_and_log(kind: str, directory: str, index_fn, database, count_key: str):
    """Run one indexing pass and log per-file outcomes.

    Factors out the previously duplicated code-pass / documentation-pass
    logging loops.

    Args:
        kind: Label for the IndexingLogger ("code" or "documentation").
        directory: Root directory being indexed (string form, for logs).
        index_fn: Callable ``(directory, database) -> list[dict]`` returning
            one result dict per file.
        database: Open database handle passed through to *index_fn*.
        count_key: Result-dict key holding the per-file item count
            ("symbols_indexed" or "chunks_indexed").

    Returns:
        Tuple ``(indexed, skipped)`` partitioning the per-file results.
    """
    idx_logger = logging_config.IndexingLogger(kind)
    # start() is called before the indexing run so the pass is bracketed
    # by start/complete in the logs, matching the original ordering.
    idx_logger.start(directory)

    results = index_fn(directory, database)
    for r in results:
        if r.get("skipped"):
            idx_logger.file_skipped(r.get("file", "unknown"), r.get("reason", "unknown"))
        else:
            idx_logger.file_indexed(r.get("file", "unknown"), r.get(count_key, 0))
    idx_logger.complete()

    indexed = [r for r in results if not r.get("skipped")]
    skipped = [r for r in results if r.get("skipped")]
    return indexed, skipped


@mcp.tool()
def index_codebase(directory: str = ".") -> dict:
    """Indexes or re-indexes source files and documentation in the given directory.

    Run this before using search_code or search_docs to ensure the database
    is up to date. Uses tree-sitter for language-agnostic structural extraction
    and generates embeddings for semantic search. Supports Python, JavaScript/
    TypeScript, Java, Kotlin, Go, Rust, C/C++, Ruby, and more.

    Also indexes markdown documentation files and extracts docstrings from
    indexed code symbols. Unchanged files (by mtime) are automatically skipped.

    Args:
        directory: The root directory to index (recursively).

    Returns:
        Summary of indexing results including code and documentation stats.
    """
    with logging_config.ToolLogger("index_codebase", directory=directory) as log:
        try:
            # Validate directory before any work.
            directory_path = val.validate_directory(directory)
            database = db_mod.get_db()

            # Two symmetric passes share one helper (see _index_and_log).
            indexed, skipped = _index_and_log(
                "code", str(directory_path),
                parser_mod.index_directory, database, "symbols_indexed",
            )
            doc_indexed, doc_skipped = _index_and_log(
                "documentation", str(directory_path),
                doc_parser_mod.index_doc_directory, database, "chunks_indexed",
            )

            # Docstrings are extracted from the symbols indexed above, so this
            # must run after the code pass.
            docstring_results = doc_parser_mod.extract_docstrings_from_code(database)

            total_symbols = sum(r.get("symbols_indexed", 0) for r in indexed)
            total_chunks = sum(r.get("chunks_indexed", 0) for r in doc_indexed)
            log.set_result_count(total_symbols + total_chunks + len(docstring_results))

            return {
                "status": "ok",
                "directory": str(directory_path),
                "code": {
                    "files_indexed": len(indexed),
                    "files_skipped": len(skipped),
                    "total_symbols": total_symbols,
                    "total_references": sum(r.get("references_indexed", 0) for r in indexed),
                },
                "documentation": {
                    "files_indexed": len(doc_indexed),
                    "files_skipped": len(doc_skipped),
                    "total_chunks": total_chunks,
                    "docstrings_extracted": len(docstring_results),
                },
                "details": {
                    "code": indexed,
                    "docs": doc_indexed,
                },
            }

        except errors.CodeMemoryError as e:
            return e.to_dict()
        except Exception as e:
            return errors.format_error(e)
|
|
173
|
+
|
|
174
|
+
|
|
175
|
+
# ── Tool 3: search_docs ────────────────────────────────────────────────────
|
|
176
|
+
@mcp.tool()
def search_docs(query: str, top_k: int = 10) -> dict:
    """Use this tool to understand the codebase conceptually. Ideal for
    'how does X work?', 'explain the architecture', or finding standard
    operating procedures in the documentation.

    Uses hybrid retrieval (BM25 keyword search + dense vector semantic
    search) with Reciprocal Rank Fusion over markdown documentation,
    README files, and docstrings extracted from code.

    Args:
        query: A natural language question about the codebase.
        top_k: Maximum number of results to return (default 10).

    Returns:
        Dictionary with 'results' key containing matching documentation
        chunks, each with source attribution (file, section, line numbers)
        and relevance score.
    """
    with logging_config.ToolLogger("search_docs", query=query, top_k=top_k) as log:
        try:
            # Sanitize both inputs up front.
            query = val.validate_query(query)
            top_k = val.validate_top_k(top_k)

            hits = queries.search_documentation(query, db_mod.get_db(), top_k=top_k)
            log.set_result_count(len(hits))

            response = {"status": "ok", "query": query}
            response["results"] = hits
            response["count"] = len(hits)
            return response

        except errors.CodeMemoryError as e:
            return e.to_dict()
        except Exception as e:
            return errors.format_error(e)
|
|
216
|
+
|
|
217
|
+
|
|
218
|
+
# ── Tool 4: search_history ─────────────────────────────────────────────────
|
|
219
|
+
@mcp.tool()
def search_history(
    query: str,
    search_type: Literal["commits", "file_history", "blame", "commit_detail"] = "commits",
    target_file: str | None = None,
    line_start: int | None = None,
    line_end: int | None = None,
) -> dict:
    """Search local Git history to debug regressions, understand developer
    intent, or find out WHY a specific change was made.

    **search_type options:**

    - ``commits`` — Search commit messages for *query* (case-insensitive).
      Optionally filter to commits that touched *target_file*.
    - ``file_history`` — Show the commit log for *target_file* (follows
      renames). *target_file* is required; *query* is ignored.
    - ``blame`` — Run ``git blame`` on *target_file*, optionally limited to
      *line_start*–*line_end*. *target_file* is required.
    - ``commit_detail`` — Get full metadata and diff for one commit.
      Pass the commit hash as *query*. Optionally set *target_file* to
      restrict the diff to that file.
    """
    with logging_config.ToolLogger("search_history", query=query, search_type=search_type,
                                   target_file=target_file) as log:
        try:
            # Imported lazily so the server process can start without the
            # git stack loaded until this tool is first used.
            import git_search as gs
            from git.exc import InvalidGitRepositoryError, NoSuchPathError

            # Validate inputs
            search_type = val.validate_search_type(
                search_type, ["commits", "file_history", "blame", "commit_detail"]
            )
            line_start, line_end = val.validate_line_range(line_start, line_end)

            # Resolve the repository from the current working directory.
            try:
                repo = gs.get_repo(".")
            except (InvalidGitRepositoryError, NoSuchPathError) as exc:
                # FIX: chain the cause so the underlying GitPython traceback
                # is preserved instead of discarded (PEP 3134).
                raise errors.GitError(f"Git repository not found: {exc}") from exc

            if search_type == "commits":
                query = val.validate_query(query, min_length=1)
                results = gs.search_commits(repo, query, target_file)
                log.set_result_count(len(results))
                return {"status": "ok", "search_type": "commits", "query": query, "results": results}

            elif search_type == "file_history":
                if not target_file:
                    raise errors.ValidationError("target_file is required for file_history search")
                results = gs.get_file_history(repo, target_file)
                log.set_result_count(len(results))
                return {"status": "ok", "search_type": "file_history", "target_file": target_file, "results": results}

            elif search_type == "blame":
                if not target_file:
                    raise errors.ValidationError("target_file is required for blame search")
                results = gs.get_blame(repo, target_file, line_start, line_end)
                log.set_result_count(len(results))
                return {"status": "ok", "search_type": "blame", "target_file": target_file, "results": results}

            elif search_type == "commit_detail":
                # query carries the commit hash here; gs.get_commit_detail
                # resolves it.
                result = gs.get_commit_detail(repo, query, target_file)
                return {"status": "ok", "search_type": "commit_detail", "result": result}

            # Unreachable when validate_search_type enforces the allowed set,
            # kept as a defensive fallback.
            return errors.format_error(errors.ValidationError(f"Unknown search_type: {search_type}"))

        except errors.CodeMemoryError as e:
            return e.to_dict()
        except Exception as e:
            return errors.format_error(e)
|
|
290
|
+
|
|
291
|
+
|
|
292
|
+
# ── Entrypoint ────────────────────────────────────────────────────────────
|
|
293
|
+
def main():
    """Entry point for the MCP server when installed as a package.

    Starts the FastMCP server loop; blocks until the client disconnects
    or the process is terminated.
    """
    mcp.run()


if __name__ == "__main__":
    # Allow running the server directly (``python server.py``) in addition
    # to the packaged console-script entry point.
    main()
|
tests/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""Test suite for code-memory."""
|
tests/conftest.py
ADDED
|
@@ -0,0 +1,192 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Shared test fixtures for code-memory tests.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
import os
|
|
8
|
+
import sys
|
|
9
|
+
import sqlite3
|
|
10
|
+
import tempfile
|
|
11
|
+
from pathlib import Path
|
|
12
|
+
|
|
13
|
+
# Add parent directory to path for imports
|
|
14
|
+
sys.path.insert(0, str(Path(__file__).parent.parent))
|
|
15
|
+
|
|
16
|
+
import pytest
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
@pytest.fixture
def temp_db():
    """Provide a temporary file-backed SQLite database for tests.

    A real file (not ``:memory:``) is used because the sqlite-vec extension
    is loaded into the connection when available. Yields an open
    ``sqlite3.Connection``; the connection is closed and the file removed
    during teardown.
    """
    # Use a temporary file for sqlite-vec compatibility
    with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
        db_path = f.name

    db = sqlite3.connect(db_path)

    # Load sqlite-vec if installed; tests that need vectors handle absence.
    try:
        import sqlite_vec
        db.enable_load_extension(True)
        sqlite_vec.load(db)
        db.enable_load_extension(False)
    except ImportError:
        pass  # sqlite-vec not available, skip vector tests

    # FIX: teardown must run even when the consuming test raises; a bare
    # ``yield`` previously leaked the connection and the on-disk temp file
    # on test failure.
    try:
        yield db
    finally:
        db.close()
        os.unlink(db_path)
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
@pytest.fixture
def temp_dir():
    """Yield a fresh temporary directory as a ``Path``, removed on teardown."""
    tmp = tempfile.TemporaryDirectory()
    try:
        yield Path(tmp.name)
    finally:
        # Equivalent to exiting the context manager: removes the tree even
        # when the consuming test raises.
        tmp.cleanup()
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
@pytest.fixture
def sample_python_file(temp_dir):
    """Create a sample Python file for parsing tests.

    Writes a module containing one class (three methods) and two top-level
    functions — all docstringed — into *temp_dir* and returns the path.
    """
    # NOTE: outer delimiter is ``'''`` so the embedded module can keep its
    # own triple-double-quoted docstrings.
    code = '''
"""Module docstring for testing."""

import os
from typing import Optional

class SampleClass:
    """A sample class for testing."""

    def __init__(self, name: str):
        """Initialize the sample class."""
        self.name = name

    def get_name(self) -> str:
        """Return the name."""
        return self.name

    def process_data(self, data: Optional[dict] = None) -> dict:
        """Process some data."""
        if data is None:
            data = {}
        return {"name": self.name, **data}


def standalone_function(x: int, y: int) -> int:
    """A standalone function that adds two numbers."""
    return x + y


def another_function(text: str) -> str:
    """Another function for testing."""
    return text.upper()
'''
    filepath = temp_dir / "sample.py"
    filepath.write_text(code)
    return filepath
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
@pytest.fixture
def sample_markdown_file(temp_dir):
    """Create a sample markdown file for documentation tests.

    Writes a README.md exercising h1/h2/h3 headings, a fenced code block,
    and ordered/unordered lists into *temp_dir*; returns the path.
    """
    content = """# Sample Documentation

This is a sample documentation file for testing.

## Installation

To install, run:

```bash
pip install code-memory
```

## Usage

Here's how to use the tool:

1. Index your codebase
2. Search for symbols

## Architecture

The system uses a Progressive Disclosure architecture.

### Components

- search_code: Find definitions
- search_docs: Find documentation
- search_history: Search git history
"""
    filepath = temp_dir / "README.md"
    filepath.write_text(content)
    return filepath
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
@pytest.fixture
def temp_git_repo(temp_dir):
    """Provide a temporary git repository (one committed file) for tests."""
    import subprocess

    def git(*args):
        # Every command runs quietly inside the temp dir and must succeed.
        subprocess.run(["git", *args], cwd=temp_dir, check=True, capture_output=True)

    # Initialize the repo with a throwaway identity for the commit.
    git("init")
    git("config", "user.email", "test@test.com")
    git("config", "user.name", "Test User")

    # Create and commit a single file so history queries have something
    # to find.
    test_file = temp_dir / "test.py"
    test_file.write_text("# Test file\nprint('hello')\n")
    git("add", ".")
    git("commit", "-m", "Initial commit")

    yield temp_dir
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
@pytest.fixture
def sample_symbols_db(temp_db):
    """Provide a database with sample symbols for search tests.

    Creates a minimal ``files``/``symbols`` schema on *temp_db*, inserts one
    file row and four symbol rows (a class, two methods, one function), and
    returns the populated connection.
    """
    # Create minimal schema
    temp_db.execute("""
        CREATE TABLE IF NOT EXISTS files (
            id INTEGER PRIMARY KEY,
            path TEXT UNIQUE NOT NULL,
            last_modified REAL NOT NULL,
            file_hash TEXT NOT NULL
        )
    """)

    temp_db.execute("""
        CREATE TABLE IF NOT EXISTS symbols (
            id INTEGER PRIMARY KEY,
            name TEXT NOT NULL,
            kind TEXT NOT NULL,
            file_id INTEGER NOT NULL,
            line_start INTEGER NOT NULL,
            line_end INTEGER NOT NULL,
            parent_symbol_id INTEGER,
            source_text TEXT NOT NULL
        )
    """)

    # Insert sample data
    cursor = temp_db.execute(
        "INSERT INTO files (path, last_modified, file_hash) VALUES (?, ?, ?)",
        ("/test/sample.py", 0.0, "abc123"),
    )
    # BUG FIX: ``lastrowid`` lives on the Cursor returned by execute(), not
    # on the Connection — reading ``temp_db.lastrowid`` raised AttributeError.
    file_id = cursor.lastrowid

    symbols = [
        ("SampleClass", "class", file_id, 5, 20, None, "class SampleClass: ..."),
        ("__init__", "method", file_id, 8, 10, 1, "def __init__(self, name): ..."),
        ("get_name", "method", file_id, 12, 14, 1, "def get_name(self): ..."),
        ("standalone_function", "function", file_id, 22, 24, None, "def standalone_function(x, y): ..."),
    ]

    for name, kind, fid, line_start, line_end, parent, source in symbols:
        temp_db.execute(
            "INSERT INTO symbols (name, kind, file_id, line_start, line_end, parent_symbol_id, source_text) VALUES (?, ?, ?, ?, ?, ?, ?)",
            (name, kind, fid, line_start, line_end, parent, source)
        )

    temp_db.commit()
    return temp_db
|
tests/test_errors.py
ADDED
|
@@ -0,0 +1,112 @@
|
|
|
1
|
+
"""Tests for error handling module."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import pytest
|
|
6
|
+
|
|
7
|
+
from errors import (
|
|
8
|
+
CodeMemoryError,
|
|
9
|
+
DatabaseError,
|
|
10
|
+
IndexingError,
|
|
11
|
+
GitError,
|
|
12
|
+
ValidationError,
|
|
13
|
+
EmbeddingError,
|
|
14
|
+
format_error,
|
|
15
|
+
)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class TestCodeMemoryError:
    """Tests for the base CodeMemoryError class."""

    def test_basic_error(self):
        """A bare error carries its message and stringifies to it."""
        err = CodeMemoryError("Test error")
        assert err.message == "Test error"
        assert str(err) == "Test error"

    def test_error_with_details(self):
        """Extra details are stored verbatim on the error."""
        err = CodeMemoryError("Test error", {"key": "value"})
        assert err.details == {"key": "value"}

    def test_to_dict(self):
        """to_dict() exposes the error flag, type name, and message."""
        d = CodeMemoryError("Test error").to_dict()
        assert d["error"] is True
        assert d["error_type"] == "CodeMemoryError"
        assert d["message"] == "Test error"

    def test_to_dict_with_details(self):
        """to_dict() carries the details mapping through unchanged."""
        d = CodeMemoryError("Test error", {"key": "value"}).to_dict()
        assert d["details"] == {"key": "value"}

    def test_to_dict_without_details(self):
        """An empty details dict serializes as None."""
        d = CodeMemoryError("Test error", {}).to_dict()
        assert d["details"] is None
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
class TestSpecializedErrors:
    """Tests for specialized error classes."""

    @staticmethod
    def _check(error, expected_type):
        """Shared assertions: subclass of the base error, correct type name."""
        assert isinstance(error, CodeMemoryError)
        assert error.to_dict()["error_type"] == expected_type

    def test_database_error(self):
        """Test DatabaseError."""
        self._check(DatabaseError("Connection failed"), "DatabaseError")

    def test_indexing_error(self):
        """Test IndexingError."""
        self._check(IndexingError("Parse failed"), "IndexingError")

    def test_git_error(self):
        """Test GitError."""
        self._check(GitError("Not a git repo"), "GitError")

    def test_validation_error(self):
        """Test ValidationError."""
        self._check(ValidationError("Invalid input"), "ValidationError")

    def test_embedding_error(self):
        """Test EmbeddingError."""
        self._check(EmbeddingError("Model load failed"), "EmbeddingError")
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
class TestFormatError:
    """Tests for the format_error function."""

    def test_format_code_memory_error(self):
        """CodeMemoryError subclasses keep type, message, and details."""
        result = format_error(ValidationError("Invalid input", {"field": "query"}))
        assert result["error"] is True
        for key, expected in [
            ("error_type", "ValidationError"),
            ("message", "Invalid input"),
            ("details", {"field": "query"}),
        ]:
            assert result[key] == expected

    def test_format_builtin_exception(self):
        """Built-in exceptions format with their class name and message."""
        result = format_error(ValueError("Something went wrong"))
        assert result["error"] is True
        assert result["error_type"] == "ValueError"
        assert result["message"] == "Something went wrong"

    def test_format_exception_without_message(self):
        """An exception with no message still yields an informative string."""
        result = format_error(RuntimeError())
        assert result["error"] is True
        assert "RuntimeError" in result["message"]
|