code-memory 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
tests/test_logging.py ADDED
@@ -0,0 +1,169 @@
1
+ """Tests for logging configuration module."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import logging
6
+ import io
7
+
8
+ import pytest
9
+
10
+ import logging_config
11
+
12
+
13
class TestSetupLogging:
    """Tests for setup_logging function.

    Each test starts from a freshly reset module. The autouse fixture clears
    ``logging_config._initialized`` through ``monkeypatch``, so the flag is
    restored after the test. It also snapshots the shared ``code_memory``
    logger, so handlers and levels installed by ``setup_logging`` do not leak
    into other tests.
    """

    @pytest.fixture(autouse=True)
    def _fresh_logging_state(self, monkeypatch):
        """Reset setup_logging's init flag and restore the package logger afterwards."""
        # Force setup_logging to (re)configure instead of returning early.
        monkeypatch.setattr(logging_config, "_initialized", False, raising=False)
        package_logger = logging.getLogger("code_memory")
        saved_handlers = list(package_logger.handlers)
        saved_level = package_logger.level
        saved_propagate = package_logger.propagate
        yield
        # Close any handler attached during the test, then put the logger
        # back exactly as it was.
        for handler in package_logger.handlers:
            if handler not in saved_handlers:
                handler.close()
        package_logger.handlers[:] = saved_handlers
        package_logger.setLevel(saved_level)
        package_logger.propagate = saved_propagate

    def test_creates_logger(self):
        """Test that setup_logging creates a logger."""
        logger = logging_config.setup_logging(level="INFO")
        assert logger is not None
        assert logger.name == "code_memory"

    def test_respects_level(self):
        """Test that the requested log level is applied by setup_logging."""
        # The previous version configured an unrelated logger by hand and
        # never exercised logging_config at all.
        # NOTE(review): assumes setup_logging applies ``level`` to the returned
        # logger (not only to its handlers) — confirm against logging_config.
        logger = logging_config.setup_logging(level="DEBUG")
        assert logger.getEffectiveLevel() == logging.DEBUG

    def test_custom_level_string(self):
        """Test that a non-default level name string is honored."""
        logger = logging_config.setup_logging(level="WARNING")
        assert logger.getEffectiveLevel() == logging.WARNING
42
+
43
+
44
class TestGetLogger:
    """Checks for ``logging_config.get_logger``, the per-module logger factory."""

    def test_get_module_logger(self):
        """A module logger is namespaced under the ``code_memory`` package."""
        module_logger = logging_config.get_logger("test_module")
        expected_name = "code_memory.test_module"
        assert module_logger.name == expected_name

    def test_different_modules_different_loggers(self):
        """Distinct module names map to distinctly named loggers."""
        first, second = (
            logging_config.get_logger(module) for module in ("module1", "module2")
        )
        assert first.name != second.name
57
+
58
+
59
class TestToolLogger:
    """Tests for ToolLogger context manager.

    ``ToolLogger`` output is observed through the ``code_memory.tools`` logger.
    The ``tool_output`` fixture routes that logger's records (message text
    only) into an in-memory stream. Afterwards it restores the logger's
    original handlers and level, so no handler bound to a discarded StringIO
    outlives a test.
    """

    @pytest.fixture
    def tool_output(self):
        """Capture ``code_memory.tools`` messages at INFO into a StringIO."""
        stream = io.StringIO()
        handler = logging.StreamHandler(stream)
        handler.setFormatter(logging.Formatter("%(message)s"))

        logger = logging.getLogger("code_memory.tools")
        saved_handlers = list(logger.handlers)
        saved_level = logger.level
        # Replace (not append to) existing handlers so only our stream sees output.
        logger.handlers.clear()
        logger.setLevel(logging.INFO)
        logger.addHandler(handler)
        try:
            yield stream
        finally:
            logger.removeHandler(handler)
            handler.close()
            logger.handlers[:] = saved_handlers
            logger.setLevel(saved_level)

    def test_logs_invocation(self, tool_output):
        """Test that tool invocation is logged."""
        with logging_config.ToolLogger("test_tool", query="test"):
            pass
        assert "Tool invoked" in tool_output.getvalue()

    def test_logs_completion(self, tool_output):
        """Test that tool completion is logged."""
        with logging_config.ToolLogger("test_tool", query="test"):
            pass
        assert "Tool completed" in tool_output.getvalue()

    def test_logs_error_on_exception(self, tool_output):
        """Test that exceptions are logged and still propagate to the caller."""
        # pytest.raises (instead of try/except: pass) also verifies that
        # ToolLogger re-raises rather than silently swallowing the error.
        with pytest.raises(ValueError, match="Test error"):
            with logging_config.ToolLogger("test_tool", query="test"):
                raise ValueError("Test error")
        assert "Tool failed" in tool_output.getvalue()

    def test_result_count_logged(self, tool_output):
        """Test that result count is logged."""
        with logging_config.ToolLogger("test_tool") as log:
            log.set_result_count(5)
        assert "count=5" in tool_output.getvalue()
133
+
134
+
135
class TestIndexingLogger:
    """Checks for the counters kept by ``logging_config.IndexingLogger``."""

    def test_tracks_files_processed(self):
        """Each indexed file bumps ``files_processed`` and adds its item count."""
        tracker = logging_config.IndexingLogger("test")
        for path, item_count in (("file1.py", 3), ("file2.py", 2)):
            tracker.file_indexed(path, item_count)
        assert tracker.files_processed == 2
        assert tracker.items_indexed == 3 + 2

    def test_tracks_files_skipped(self):
        """A skipped file is counted in ``files_skipped``."""
        tracker = logging_config.IndexingLogger("test")
        tracker.file_skipped("file1.py", "unchanged")
        assert tracker.files_skipped == 1
151
+
152
+
153
class TestPreconfiguredLoggers:
    """Checks for the ready-made logger accessors exposed by ``logging_config``."""

    def test_get_server_logger(self):
        """The server logger's name mentions ``server``."""
        name = logging_config.get_server_logger().name
        assert "server" in name

    def test_get_db_logger(self):
        """The database logger's name mentions ``db``."""
        name = logging_config.get_db_logger().name
        assert "db" in name

    def test_get_query_logger(self):
        """The query logger's name mentions ``queries``."""
        name = logging_config.get_query_logger().name
        assert "queries" in name
tests/test_tools.py ADDED
@@ -0,0 +1,114 @@
1
+ """Integration tests for MCP tools."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import pytest
6
+ import sys
7
+
8
+
9
class TestSearchCodeValidation:
    """Input validation of the ``search_code`` tool: bad input yields a structured error dict."""

    def test_empty_query_returns_error(self):
        """An empty query is rejected with a ValidationError payload."""
        import server

        response = server.search_code("", "definition")
        error_type = response.get("error_type", "")
        assert response.get("error") is True
        assert "ValidationError" in error_type

    def test_invalid_search_type_returns_error(self):
        """An unknown search_type is rejected with a ValidationError payload."""
        import server

        response = server.search_code("test", "invalid_type")
        error_type = response.get("error_type", "")
        assert response.get("error") is True
        assert "ValidationError" in error_type
25
+
26
+
27
class TestSearchDocsValidation:
    """Tests for search_docs tool input validation."""

    def test_empty_query_returns_error(self):
        """Test that empty query returns structured error."""
        import server
        result = server.search_docs("")
        assert result.get("error") is True
        assert "ValidationError" in result.get("error_type", "")

    def test_invalid_top_k_returns_error(self):
        """Test that invalid top_k returns a structured ValidationError."""
        import server
        result = server.search_docs("test", top_k=-1)
        assert result.get("error") is True
        # Pin the error category like the sibling tests do. A negative top_k
        # is a validation failure (validate_top_k rejects values below
        # min_val), not an incidental runtime error.
        assert "ValidationError" in result.get("error_type", "")
42
+
43
+
44
class TestSearchHistoryValidation:
    """Input validation of the ``search_history`` tool."""

    @staticmethod
    def _assert_target_file_required(search_type):
        """Call search_history without a target file and expect a message naming it."""
        import server

        response = server.search_history("test", search_type=search_type, target_file=None)
        assert response.get("error") is True
        message = response.get("message", "").lower()
        assert "target_file" in message

    def test_invalid_search_type_returns_error(self):
        """An unknown search_type is rejected with a ValidationError payload."""
        import server

        response = server.search_history("test", search_type="invalid")
        error_type = response.get("error_type", "")
        assert response.get("error") is True
        assert "ValidationError" in error_type

    def test_file_history_requires_target_file(self):
        """``file_history`` mode cannot run without ``target_file``."""
        self._assert_target_file_required("file_history")

    def test_blame_requires_target_file(self):
        """``blame`` mode cannot run without ``target_file``."""
        self._assert_target_file_required("blame")

    def test_invalid_line_range_returns_error(self):
        """A blame request whose line_start exceeds line_end is rejected."""
        import server

        # NOTE(review): only ``error is True`` is checked, so this would also
        # pass if "server.py" were not found from the current working
        # directory. Consider asserting on the error message.
        response = server.search_history(
            "test",
            search_type="blame",
            target_file="server.py",
            line_start=10,
            line_end=5,
        )
        assert response.get("error") is True
80
+
81
+
82
class TestIndexCodebaseValidation:
    """Input validation of the ``index_codebase`` tool."""

    def test_nonexistent_directory_returns_error(self):
        """A path that does not exist is rejected with a ValidationError payload."""
        import server

        missing_dir = "/nonexistent/directory"
        response = server.index_codebase(missing_dir)
        error_type = response.get("error_type", "")
        assert response.get("error") is True
        assert "ValidationError" in error_type
91
+
92
+
93
class TestToolResponseStructure:
    """Tests for consistent tool response structure."""

    def test_success_response_has_status(self):
        """Test that tool responses always report success or failure explicitly."""
        import server
        # search_docs should work even without indexed content.
        result = server.search_docs("test query")
        assert isinstance(result, dict)
        # Either it succeeds or fails gracefully, but it must say which. The
        # previous version passed silently when neither key was present.
        assert "status" in result or "error" in result, (
            f"unstructured response: {result!r}"
        )
        if "status" in result:
            assert result["status"] in ("ok", "error")
        else:
            assert result["error"] is True

    def test_error_response_structure(self):
        """Test that error responses have consistent structure."""
        import server
        result = server.search_code("", "definition")
        assert "error" in result
        assert result["error"] is True
        assert "error_type" in result
        assert "message" in result
@@ -0,0 +1,216 @@
1
+ """Tests for input validation module."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import pytest
6
+
7
+ from errors import ValidationError
8
+ import validation as val
9
+
10
+
11
class TestValidateQuery:
    """Checks for ``validation.validate_query`` (trimming and length bounds)."""

    def test_valid_query(self):
        """A well-formed query is returned unchanged."""
        assert val.validate_query("test query") == "test query"

    def test_strips_whitespace(self):
        """Surrounding whitespace is trimmed off."""
        cleaned = val.validate_query(" test query ")
        assert cleaned == "test query"

    def test_empty_query_fails(self):
        """An empty string is rejected as too short."""
        with pytest.raises(ValidationError) as excinfo:
            val.validate_query("")
        message = str(excinfo.value.message)
        assert "too short" in message

    def test_whitespace_only_fails(self):
        """A whitespace-only string is rejected as too short."""
        with pytest.raises(ValidationError) as excinfo:
            val.validate_query(" ")
        message = str(excinfo.value.message)
        assert "too short" in message

    def test_long_query_fails(self):
        """A query one character over ``max_length`` is rejected as too long."""
        limit = 1000
        oversized = "a" * (limit + 1)
        with pytest.raises(ValidationError) as excinfo:
            val.validate_query(oversized, max_length=limit)
        message = str(excinfo.value.message)
        assert "too long" in message

    def test_custom_min_length(self):
        """A caller-supplied ``min_length`` is enforced."""
        with pytest.raises(ValidationError):
            val.validate_query("ab", min_length=5)
47
+
48
+
49
class TestValidateSearchType:
    """Checks for ``validation.validate_search_type`` against an allow-list."""

    def test_valid_search_type(self):
        """A member of the allow-list is returned as-is."""
        allowed = ["definition", "references"]
        assert val.validate_search_type("definition", allowed) == "definition"

    def test_invalid_search_type(self):
        """A value outside the allow-list is rejected."""
        allowed = ["definition", "references"]
        with pytest.raises(ValidationError) as excinfo:
            val.validate_search_type("invalid", allowed)
        message = str(excinfo.value.message)
        assert "Invalid search type" in message

    def test_empty_search_type(self):
        """An empty search type is rejected."""
        allowed = ["definition"]
        with pytest.raises(ValidationError):
            val.validate_search_type("", allowed)
67
+
68
+
69
class TestValidateLineNumber:
    """Checks for ``validation.validate_line_number`` (optional, 1-based)."""

    def test_valid_line_number(self):
        """A positive line number is returned unchanged."""
        checked = val.validate_line_number(10, "line_start")
        assert checked == 10

    def test_none_allowed(self):
        """``None`` means "not provided" and is passed through."""
        checked = val.validate_line_number(None, "line_start")
        assert checked is None

    def test_negative_fails(self):
        """Negative line numbers are rejected."""
        with pytest.raises(ValidationError):
            val.validate_line_number(-1, "line_start")

    def test_zero_fails(self):
        """Zero is below the default minimum and is rejected."""
        with pytest.raises(ValidationError):
            val.validate_line_number(0, "line_start")
91
+
92
+
93
class TestValidateLineRange:
    """Checks for ``validation.validate_line_range`` ordering rules."""

    def test_valid_range(self):
        """An ascending range comes back as a (start, end) pair."""
        first, last = val.validate_line_range(1, 10)
        assert (first, last) == (1, 10)

    def test_start_greater_than_end_fails(self):
        """A range whose start exceeds its end is rejected."""
        with pytest.raises(ValidationError):
            val.validate_line_range(10, 1)
106
+
107
+
108
class TestValidateTopK:
    """Checks for ``validation.validate_top_k`` defaults and bounds."""

    def test_valid_value(self):
        """An in-range value is returned unchanged."""
        k = val.validate_top_k(10)
        assert k == 10

    def test_none_uses_default(self):
        """``None`` falls back to the default of 10."""
        k = val.validate_top_k(None)
        assert k == 10

    def test_zero_uses_default(self):
        """Zero also falls back to the default of 10."""
        k = val.validate_top_k(0)
        assert k == 10

    def test_too_large_fails(self):
        """Values above ``max_val`` are rejected."""
        with pytest.raises(ValidationError):
            val.validate_top_k(200, max_val=100)

    def test_negative_fails(self):
        """Values below ``min_val`` are rejected rather than defaulted."""
        with pytest.raises(ValidationError):
            val.validate_top_k(-1, min_val=1, default=10)
135
+
136
+
137
class TestValidateDirectory:
    """Checks for ``validation.validate_directory`` existence and type checks."""

    def test_existing_directory(self, temp_dir):
        """An existing directory path is accepted and returned as a path."""
        validated = val.validate_directory(str(temp_dir))
        assert validated == temp_dir

    def test_nonexistent_fails(self):
        """A missing directory is reported as not found."""
        with pytest.raises(ValidationError) as excinfo:
            val.validate_directory("/nonexistent/path")
        message = str(excinfo.value.message)
        assert "not found" in message

    def test_file_not_directory_fails(self, temp_dir):
        """A regular file is not accepted where a directory is required."""
        regular_file = temp_dir / "test.txt"
        regular_file.write_text("test")
        with pytest.raises(ValidationError):
            val.validate_directory(str(regular_file))
157
+
158
+
159
class TestValidateFile:
    """Checks for ``validation.validate_file`` existence and type checks."""

    def test_existing_file(self, temp_dir):
        """An existing regular file is accepted and returned as a path."""
        regular_file = temp_dir / "test.txt"
        regular_file.write_text("test")
        validated = val.validate_file(str(regular_file))
        assert validated == regular_file

    def test_nonexistent_fails(self):
        """A missing file is reported as not found."""
        with pytest.raises(ValidationError) as excinfo:
            val.validate_file("/nonexistent/file.txt")
        message = str(excinfo.value.message)
        assert "not found" in message

    def test_directory_not_file_fails(self, temp_dir):
        """A directory is not accepted where a regular file is required."""
        with pytest.raises(ValidationError):
            val.validate_file(str(temp_dir))
179
+
180
+
181
class TestSanitizeFtsQuery:
    """Checks for ``validation.sanitize_fts_query``."""

    def test_simple_query(self):
        """Plain words pass through untouched."""
        sanitized = val.sanitize_fts_query("simple query")
        assert sanitized == "simple query"

    def test_escapes_quotes(self):
        """Quote characters survive sanitization."""
        sanitized = val.sanitize_fts_query('test "quoted"')
        # The former check ``'""' in r or '"' in r`` reduces to this, because
        # any string containing '""' also contains '"'.
        # NOTE(review): this does not actually prove escaping; consider pinning
        # the exact sanitized output once the intended FTS5 escaping is confirmed.
        assert '"' in sanitized
193
+
194
+
195
class TestValidateCommitHash:
    """Checks for ``validation.validate_commit_hash`` format and length rules."""

    def test_valid_full_hash(self):
        """A 40-character hex hash is accepted unchanged."""
        full_sha = "a" * 40
        assert val.validate_commit_hash(full_sha) == full_sha

    def test_valid_short_hash(self):
        """A 7-character abbreviated hash is accepted unchanged."""
        short_sha = "abc1234"
        assert val.validate_commit_hash(short_sha) == short_sha

    def test_invalid_format_fails(self):
        """Non-hex text is rejected."""
        with pytest.raises(ValidationError):
            val.validate_commit_hash("not-a-hash")

    def test_too_short_fails(self):
        """A 6-character hash is below the minimum length and is rejected."""
        with pytest.raises(ValidationError):
            val.validate_commit_hash("abc123")