roma-debug 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
tests/test_engine.py ADDED
@@ -0,0 +1,296 @@
1
+ """Tests for the engine module."""
2
+
3
+ import pytest
4
+ from unittest.mock import patch, MagicMock
5
+ import json
6
+
7
+ from roma_debug.core.engine import (
8
+ analyze_error,
9
+ analyze_error_v2,
10
+ _build_prompt,
11
+ _build_prompt_v2,
12
+ _parse_json_response,
13
+ _normalize_filepath,
14
+ FixResult,
15
+ FixResultV2,
16
+ )
17
+
18
+
19
class TestBuildPrompt:
    """Checks on the text returned by ``_build_prompt`` and ``_build_prompt_v2``."""

    def test_builds_prompt_with_log_only(self):
        """With an empty context, expect the log section and no source section."""
        rendered = _build_prompt("ValueError: test error", "")

        for expected in ("## ERROR LOG", "ValueError: test error"):
            assert expected in rendered
        assert "## SOURCE CONTEXT" not in rendered

    def test_builds_prompt_with_context(self):
        """With a non-empty context, expect both the log and the source sections."""
        rendered = _build_prompt("ValueError: test error", "def func():\n pass")

        expected_fragments = (
            "## ERROR LOG",
            "ValueError: test error",
            "## SOURCE CONTEXT",
            "def func():",
        )
        for expected in expected_fragments:
            assert expected in rendered

    def test_v2_prompt(self):
        """The V2 prompt carries a log/traceback marker and mentions the root cause."""
        rendered = _build_prompt_v2(
            "ValueError: test",
            "## PRIMARY ERROR\nsome context",
        )

        has_log_header = "## ERROR LOG" in rendered
        assert has_log_header or "TRACEBACK" in rendered
        # Case-insensitive: only the presence of the phrase matters.
        assert "root cause" in rendered.lower()
54
+
55
+
56
class TestParseJsonResponse:
    """Checks on ``_parse_json_response`` for bare and fenced JSON input."""

    # The JSON object every positive case feeds to the parser.
    _PAYLOAD = '{"filepath": "test.py", "full_code_block": "code", "explanation": "fixed"}'

    def test_parses_plain_json(self):
        """A bare JSON object is decoded into a mapping with the same fields."""
        parsed = _parse_json_response(self._PAYLOAD)

        assert parsed["filepath"] == "test.py"
        assert parsed["full_code_block"] == "code"

    def test_parses_json_in_markdown(self):
        """A ```json fenced block is unwrapped before decoding."""
        fenced = "```json\n" + self._PAYLOAD + "\n```"

        parsed = _parse_json_response(fenced)

        assert parsed["filepath"] == "test.py"

    def test_parses_json_in_markdown_no_lang(self):
        """A fenced block without a language tag is unwrapped as well."""
        fenced = "```\n" + self._PAYLOAD + "\n```"

        parsed = _parse_json_response(fenced)

        assert parsed["filepath"] == "test.py"

    def test_raises_on_invalid_json(self):
        """Text that is not JSON is rejected with ``ValueError``."""
        not_json = "This is not JSON at all"

        with pytest.raises(ValueError):
            _parse_json_response(not_json)
91
+
92
+
93
class TestNormalizeFilepath:
    """Checks on ``_normalize_filepath`` for real, placeholder and empty paths."""

    def test_returns_valid_path(self):
        """Real-looking paths come back unchanged."""
        for candidate in ("src/main.py", "/app/test.py"):
            assert _normalize_filepath(candidate) == candidate

    def test_returns_none_for_placeholders(self):
        """Template-style placeholder paths are mapped to ``None``."""
        placeholders = (
            "path/to/file.py",
            "your_file.py",
            "example.py",
            "<filename>",
        )
        for candidate in placeholders:
            assert _normalize_filepath(candidate) is None

    def test_returns_none_for_empty(self):
        """``None``, the empty string and blank strings are mapped to ``None``."""
        for candidate in (None, "", " "):
            assert _normalize_filepath(candidate) is None
113
+
114
+
115
class TestFixResult:
    """Checks on the ``FixResult`` value object."""

    def test_to_dict(self):
        """``to_dict`` exposes the filepath, code block and explanation fields."""
        expected = {
            "filepath": "test.py",
            "full_code_block": "def fix(): pass",
            "explanation": "Fixed the bug",
        }
        fix = FixResult(
            **expected,
            raw_response="{}",
            model_used="gemini-2.5-flash",
        )

        as_dict = fix.to_dict()

        for key, value in expected.items():
            assert as_dict[key] == value
133
+
134
+
135
class TestFixResultV2:
    """Checks on the ``FixResultV2`` value object and its derived properties."""

    @staticmethod
    def _make_result(**extra):
        """Build a ``FixResultV2`` with the shared base fields plus ``extra``."""
        base = {
            "filepath": "main.py",
            "full_code_block": "code",
            "explanation": "fix",
            "raw_response": "{}",
            "model_used": "gemini",
        }
        return FixResultV2(**base, **extra)

    def test_has_root_cause(self):
        """``has_root_cause`` is True when the root cause lives in another file."""
        fix = self._make_result(
            root_cause_file="utils.py",
            root_cause_explanation="The bug is actually here",
        )

        assert fix.has_root_cause is True

    def test_no_root_cause_when_same_file(self):
        """``has_root_cause`` is False when root_cause_file equals filepath."""
        fix = self._make_result(root_cause_file="main.py")

        assert fix.has_root_cause is False

    def test_all_files_to_fix(self):
        """``all_files_to_fix`` lists the primary, root-cause and additional files."""
        from roma_debug.core.engine import AdditionalFix

        extra_fix = AdditionalFix(
            filepath="helpers.py", full_code_block="", explanation=""
        )
        fix = self._make_result(
            root_cause_file="utils.py",
            additional_fixes=[extra_fix],
        )

        listed = fix.all_files_to_fix
        for name in ("main.py", "utils.py", "helpers.py"):
            assert name in listed

    def test_v2_to_dict(self):
        """``to_dict`` includes the root-cause fields and the additional fixes."""
        from roma_debug.core.engine import AdditionalFix

        extra_fix = AdditionalFix(
            filepath="other.py",
            full_code_block="more code",
            explanation="also fix",
        )
        fix = self._make_result(
            root_cause_file="utils.py",
            root_cause_explanation="Root cause here",
            additional_fixes=[extra_fix],
        )

        as_dict = fix.to_dict()

        assert as_dict["root_cause_file"] == "utils.py"
        assert as_dict["root_cause_explanation"] == "Root cause here"
        extras = as_dict["additional_fixes"]
        assert len(extras) == 1
        assert extras[0]["filepath"] == "other.py"
209
+
210
+
211
class TestAnalyzeError:
    """Checks on ``analyze_error`` with the API client replaced by a mock."""

    @staticmethod
    def _client_replying(payload):
        """Return a mock client whose ``generate_content`` reply text is ``payload`` as JSON."""
        client = MagicMock()
        client.models.generate_content.return_value = MagicMock(
            text=json.dumps(payload)
        )
        return client

    @patch('roma_debug.core.engine._get_client')
    def test_returns_fix_result(self, mock_get_client):
        """A well-formed reply is turned into a populated ``FixResult``."""
        mock_get_client.return_value = self._client_replying({
            "filepath": "test.py",
            "full_code_block": "def fixed(): pass",
            "explanation": "Fixed the function"
        })

        outcome = analyze_error("ValueError: test", "def broken(): pass")

        assert isinstance(outcome, FixResult)
        assert outcome.filepath == "test.py"
        assert outcome.full_code_block == "def fixed(): pass"

    @patch('roma_debug.core.engine._get_client')
    def test_handles_null_filepath(self, mock_get_client):
        """A JSON ``null`` filepath in the reply surfaces as ``None``."""
        mock_get_client.return_value = self._client_replying({
            "filepath": None,
            "full_code_block": "general advice",
            "explanation": "This is a config error"
        })

        outcome = analyze_error("400 API key invalid", "")

        assert outcome.filepath is None
249
+
250
+
251
class TestAnalyzeErrorV2:
    """Checks on ``analyze_error_v2`` with the API client replaced by a mock."""

    @staticmethod
    def _client_replying(payload):
        """Return a mock client whose ``generate_content`` reply text is ``payload`` as JSON."""
        client = MagicMock()
        client.models.generate_content.return_value = MagicMock(
            text=json.dumps(payload)
        )
        return client

    @patch('roma_debug.core.engine._get_client')
    def test_returns_v2_result(self, mock_get_client):
        """A reply with root-cause fields is turned into a ``FixResultV2``."""
        mock_get_client.return_value = self._client_replying({
            "filepath": "main.py",
            "full_code_block": "fixed code",
            "explanation": "Fixed",
            "root_cause_file": "utils.py",
            "root_cause_explanation": "The bug was here",
            "additional_fixes": []
        })

        outcome = analyze_error_v2("Error trace", "context")

        assert isinstance(outcome, FixResultV2)
        assert outcome.root_cause_file == "utils.py"

    @patch('roma_debug.core.engine._get_client')
    def test_parses_additional_fixes(self, mock_get_client):
        """Each entry of ``additional_fixes`` is parsed, in reply order."""
        mock_get_client.return_value = self._client_replying({
            "filepath": "main.py",
            "full_code_block": "code1",
            "explanation": "Fix 1",
            "additional_fixes": [
                {"filepath": "utils.py", "full_code_block": "code2", "explanation": "Fix 2"},
                {"filepath": "helpers.py", "full_code_block": "code3", "explanation": "Fix 3"},
            ]
        })

        outcome = analyze_error_v2("Error", "context")

        extras = outcome.additional_fixes
        assert len(extras) == 2
        assert extras[0].filepath == "utils.py"
        assert extras[1].filepath == "helpers.py"