codetree-rag 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
codetree/parser.py ADDED
@@ -0,0 +1,352 @@
1
+ """Code parser using tree-sitter for AST extraction."""
2
+
3
+ import re
4
+ from dataclasses import dataclass, field
5
+ from pathlib import Path
6
+ from typing import Optional
7
+
8
# File suffixes (lower-case, including the leading dot) recognised per language.
LANGUAGE_EXTENSIONS = dict(
    python=[".py", ".pyi"],
    javascript=[".js", ".jsx", ".mjs"],
    typescript=[".ts", ".tsx"],
    go=[".go"],
    rust=[".rs"],
    java=[".java"],
    c=[".c", ".h"],
    cpp=[".cpp", ".hpp", ".cc", ".cxx"],
)
19
+
20
+
21
@dataclass
class CodeEntity:
    """A named construct (function, class, method, ...) found in a source file.

    Only ``name``, ``type``, ``start_line`` and ``end_line`` are required;
    every other field defaults to "nothing recorded".
    """

    # Identifier of the construct as written in the source.
    name: str
    # Kind of construct: function, class, method, variable.
    type: str
    # Line on which the construct starts.
    start_line: int
    # Line on which the construct ends.
    end_line: int
    # Documentation text attached to the construct, when one was extracted.
    docstring: Optional[str] = None
    # Human-readable declaration header, when one was extracted.
    signature: Optional[str] = None
    # Decorator lines attached to the construct.
    decorators: list[str] = field(default_factory=list)
    # Entities nested inside this one.
    children: list["CodeEntity"] = field(default_factory=list)
32
+
33
+
34
@dataclass
class FileInfo:
    """Structure extracted from one source file.

    ``path`` and ``language`` are required; the collections start empty and
    ``line_count`` starts at zero.
    """

    # Location of the file that was parsed.
    path: Path
    # Language name the file was parsed as.
    language: str
    # Import statements or imported paths found in the file.
    imports: list[str] = field(default_factory=list)
    # Functions (or methods) found in the file.
    functions: list[CodeEntity] = field(default_factory=list)
    # Classes found in the file.
    classes: list[CodeEntity] = field(default_factory=list)
    # Names of variables found in the file.
    variables: list[str] = field(default_factory=list)
    # Optional free-text summary of the file.
    summary: Optional[str] = None
    # Number of lines in the file.
    line_count: int = 0
45
+
46
+
47
+ class CodeParser:
48
+ """Parse code files to extract structure information."""
49
+
50
+ def __init__(self):
51
+ self._parsers = {}
52
+
53
def detect_language(self, file_path: Path) -> Optional[str]:
    """Return the language whose extension list contains the path's suffix.

    The suffix is lower-cased before the lookup in ``LANGUAGE_EXTENSIONS``.
    ``None`` is returned when no language lists it.
    """
    wanted = file_path.suffix.lower()
    return next(
        (
            name
            for name, suffixes in LANGUAGE_EXTENSIONS.items()
            if wanted in suffixes
        ),
        None,
    )
60
+
61
def parse_file(self, file_path: Path, content: Optional[str] = None) -> Optional[FileInfo]:
    """Parse one source file into a ``FileInfo``.

    Args:
        file_path: Path of the file; its suffix selects the language.
        content: Source text. When ``None`` the file is read as UTF-8.

    Returns:
        The parsed ``FileInfo``, or ``None`` when the suffix is not
        recognised or the file cannot be read or decoded. Languages without
        a dedicated parser get a ``FileInfo`` carrying only the line count.
    """
    language = self.detect_language(file_path)
    if not language:
        return None

    if content is None:
        try:
            content = file_path.read_text(encoding="utf-8")
        except (UnicodeDecodeError, IOError):
            return None

    lines = content.split("\n")

    # JavaScript and TypeScript share one parser that also needs the language.
    if language in ("javascript", "typescript"):
        return self._parse_javascript(file_path, content, lines, language)

    # Regex-based parsers, one per remaining supported language.
    dedicated = {
        "python": self._parse_python,
        "go": self._parse_go,
        "rust": self._parse_rust,
        "java": self._parse_java,
    }.get(language)
    if dedicated is not None:
        return dedicated(file_path, content, lines)

    # Basic fallback: the language is known but has no structural parser.
    return FileInfo(path=file_path, language=language, line_count=len(lines))
93
+
94
def _parse_python(self, file_path: Path, content: str, lines: list[str]) -> FileInfo:
    """Extract imports, functions, classes and constants from Python source.

    Parsing is regex based and only considers statements that start in the
    first column, so methods and nested definitions are not reported.
    ``end_line`` of every entity is set to its ``start_line``. When a
    definition is decorated, ``start_line`` is the line of the first
    decorator.

    Args:
        file_path: Path stored on the returned ``FileInfo``.
        content: Full source text.
        lines: ``content`` split into lines; only its length is used.

    Returns:
        A ``FileInfo`` with ``language="python"``.
    """
    # Content of a parenthesised list, tolerating one nested level such as a
    # call used as a default value: ``def f(a=g(), b=(1, 2))``.
    nested = r"(?:[^()]|\([^()]*\))*"
    decorator_block = rf"^(?P<decorators>(?:@[\w.]+(?:\({nested}\))?\s*\n)*)"

    def line_of(offset: int) -> int:
        """Return the 1-based line number of a character offset."""
        # Counting in place avoids copying the file prefix for every match.
        return content.count("\n", 0, offset) + 1

    def decorator_list(block: str) -> list[str]:
        """Split the matched decorator block into stripped, non-empty lines."""
        return [line.strip() for line in block.split("\n") if line.strip()]

    # Imports; a parenthesised list may span several lines.
    import_pattern = re.compile(
        r"^(?:from\s+[\w.]+\s+)?import\s+(?:\([^)]*\)|.+)", re.MULTILINE
    )
    imports = []
    for match in import_pattern.finditer(content):
        statement = match.group().strip()
        if "\n" in statement:
            # Keep multi-line imports as a single-line entry.
            statement = " ".join(statement.split())
        imports.append(statement)

    # Functions
    func_pattern = re.compile(
        decorator_block
        + rf"(?P<async>async\s+)?def\s+(?P<name>\w+)\s*\((?P<params>{nested})\)",
        re.MULTILINE,
    )
    functions = []
    for match in func_pattern.finditer(content):
        name = match.group("name")
        signature = f"def {name}({match.group('params')})"
        if match.group("async"):
            signature = "async " + signature
        start_line = line_of(match.start())
        functions.append(CodeEntity(
            name=name,
            type="function",
            start_line=start_line,
            end_line=start_line,  # Simplified
            signature=signature,
            decorators=decorator_list(match.group("decorators")),
            docstring=self._extract_python_docstring(content, match.end()),
        ))

    # Classes. The header colon is only looked ahead at, so the docstring
    # lookup below starts in front of it, exactly as it does for functions.
    class_pattern = re.compile(
        decorator_block
        + rf"class\s+(?P<name>\w+)(?:\((?P<bases>{nested})\))?(?=:)",
        re.MULTILINE,
    )
    classes = []
    for match in class_pattern.finditer(content):
        name = match.group("name")
        bases = match.group("bases") or ""
        start_line = line_of(match.start())
        classes.append(CodeEntity(
            name=name,
            type="class",
            start_line=start_line,
            end_line=start_line,  # Simplified
            signature=f"class {name}({bases})" if bases else f"class {name}",
            decorators=decorator_list(match.group("decorators")),
            docstring=self._extract_python_docstring(content, match.end()),
        ))

    # Module-level constants: UPPER_CASE names that are assigned, optionally
    # with an annotation. ``(?!=)`` keeps ``NAME == x`` comparisons out.
    var_pattern = re.compile(
        r"^([A-Z][A-Z_0-9]*)\s*(?::[^=\n]+)?=(?!=)", re.MULTILINE
    )
    variables = [match.group(1) for match in var_pattern.finditer(content)]

    return FileInfo(
        path=file_path,
        language="python",
        imports=imports,
        functions=functions,
        classes=classes,
        variables=variables,
        line_count=len(lines),
    )
177
+
178
+ def _extract_python_docstring(self, content: str, pos: int) -> Optional[str]:
179
+ """Extract Python docstring after a definition."""
180
+ remaining = content[pos:pos + 500]
181
+ # Look for triple-quoted string
182
+ match = re.search(r'^\s*:\s*\n\s*("""|\'\'\')(.+?)\1', remaining, re.DOTALL)
183
+ if match:
184
+ return match.group(2).strip()[:200] # Truncate
185
+ return None
186
+
187
def _parse_javascript(self, file_path: Path, content: str, lines: list[str], language: str) -> FileInfo:
    """Extract imports, functions and classes from JavaScript/TypeScript.

    Parsing is regex based: comments and string literals are not skipped,
    and ``end_line`` of every entity is set to its ``start_line``.
    Functions are reported in source order and carry a signature.

    Args:
        file_path: Path stored on the returned ``FileInfo``.
        content: Full source text.
        lines: ``content`` split into lines; only its length is used.
        language: Language name stored on the returned ``FileInfo``.

    Returns:
        A ``FileInfo`` with imports, functions and classes filled in.
    """
    def line_of(offset: int) -> int:
        """Return the 1-based line number of a character offset."""
        return content.count("\n", 0, offset) + 1

    def squeeze(text: str) -> str:
        """Collapse runs of whitespace so signatures stay on one line."""
        return " ".join(text.split())

    # Single-line import/export statements that end in a quoted module path.
    import_pattern = re.compile(r"^(?:import|export)\s+.+?['\"];?$", re.MULTILINE)
    imports = [match.group().strip() for match in import_pattern.finditer(content)]

    # ``\b`` keeps the keywords from matching inside longer identifiers.
    declaration_pattern = re.compile(
        r"\b(?:export\s+)?(?P<async>async\s+)?function\s+(?P<name>\w+)\s*"
        r"\((?P<params>[^)]*)\)"
    )
    arrow_pattern = re.compile(
        r"\b(?:export\s+)?(?P<keyword>const|let|var)\s+(?P<name>\w+)\s*=\s*"
        r"(?P<async>async\s+)?\((?P<params>[^)]*)\)"
        # Optional TypeScript return type between ``)`` and ``=>``.
        r"(?:\s*:\s*[^=;{}()]+?)?\s*=>"
    )

    functions = []
    for match in declaration_pattern.finditer(content):
        name = match.group("name")
        prefix = "async " if match.group("async") else ""
        start_line = line_of(match.start())
        functions.append(CodeEntity(
            name=name,
            type="function",
            start_line=start_line,
            end_line=start_line,
            signature=f"{prefix}function {name}({squeeze(match.group('params'))})",
        ))
    for match in arrow_pattern.finditer(content):
        name = match.group("name")
        prefix = "async " if match.group("async") else ""
        start_line = line_of(match.start())
        functions.append(CodeEntity(
            name=name,
            type="function",
            start_line=start_line,
            end_line=start_line,
            signature=(
                f"{match.group('keyword')} {name} = "
                f"{prefix}({squeeze(match.group('params'))}) =>"
            ),
        ))
    # Both passes scan the whole file; restore source order (stable sort).
    functions.sort(key=lambda entity: entity.start_line)

    # Classes; the base may be a dotted name such as ``React.Component``.
    class_pattern = re.compile(
        r"\b(?:export\s+)?class\s+(?P<name>\w+)(?:\s*<[^>{]*>)?"
        r"(?:\s+extends\s+(?P<base>[\w.$]+))?"
    )
    classes = []
    for match in class_pattern.finditer(content):
        name = match.group("name")
        base = match.group("base")
        start_line = line_of(match.start())
        classes.append(CodeEntity(
            name=name,
            type="class",
            start_line=start_line,
            end_line=start_line,
            signature=f"class {name}" + (f" extends {base}" if base else ""),
        ))

    return FileInfo(
        path=file_path,
        language=language,
        imports=imports,
        functions=functions,
        classes=classes,
        line_count=len(lines),
    )
238
+
239
def _parse_go(self, file_path: Path, content: str, lines: list[str]) -> FileInfo:
    """Extract imported package paths and function declarations from Go.

    Parsing is regex based and ``end_line`` of every entity is set to its
    ``start_line``. Import entries are the quoted package paths only; an
    alias, ``_`` or ``.`` in front of the path is dropped. Methods are
    reported as functions whose signature includes the receiver.

    Args:
        file_path: Path stored on the returned ``FileInfo``.
        content: Full source text.
        lines: ``content`` split into lines; only its length is used.

    Returns:
        A ``FileInfo`` with ``language="go"``.
    """
    def line_of(offset: int) -> int:
        """Return the 1-based line number of a character offset."""
        return content.count("\n", 0, offset) + 1

    # Either a parenthesised import block or a single (possibly aliased) import.
    import_pattern = re.compile(
        r'^[ \t]*import\s+(?:\((?P<block>[^)]*)\)|(?:[\w.]+\s+)?"(?P<single>[^"]+)")',
        re.MULTILINE,
    )
    # One entry of an import block. Comment lines start with ``//`` and so
    # cannot match, which keeps them out of the result.
    block_entry = re.compile(r'^[ \t]*(?:[\w.]+[ \t]+)?"([^"]+)"', re.MULTILINE)

    imports = []
    for match in import_pattern.finditer(content):
        block = match.group("block")
        if block is None:
            imports.append(match.group("single"))
        else:
            imports.extend(block_entry.findall(block))

    # Parameter list tolerating one nested level, e.g. ``f func(int) int``.
    nested = r"(?:[^()]|\([^()]*\))*"
    # Optional type-parameter list, e.g. ``[T any]`` or ``[S ~[]E, E any]``.
    type_params = r"(?:\[(?:[^\[\]]|\[[^\[\]]*\])*\])?"
    func_pattern = re.compile(
        r"^[ \t]*func\s+(?:\((?P<receiver>[^)]+)\)\s*)?(?P<name>\w+)\s*"
        + type_params
        + rf"\s*\((?P<params>{nested})\)",
        re.MULTILINE,
    )

    functions = []
    for match in func_pattern.finditer(content):
        name = match.group("name")
        params = match.group("params")
        receiver = match.group("receiver")
        if receiver:
            signature = f"func ({receiver}) {name}({params})"
        else:
            signature = f"func {name}({params})"
        start_line = line_of(match.start())
        functions.append(CodeEntity(
            name=name,
            type="function",
            start_line=start_line,
            end_line=start_line,
            signature=signature,
        ))

    return FileInfo(
        path=file_path,
        language="go",
        imports=imports,
        functions=functions,
        line_count=len(lines),
    )
276
+
277
def _parse_rust(self, file_path: Path, content: str, lines: list[str]) -> FileInfo:
    """Extract ``use`` statements and function names from Rust source.

    Parsing is regex based: comments and string literals are not skipped,
    and ``end_line`` of every entity is set to its ``start_line``. Only
    ``use`` statements that start in the first column (optionally prefixed
    with ``pub``) are reported.

    Args:
        file_path: Path stored on the returned ``FileInfo``.
        content: Full source text.
        lines: ``content`` split into lines; only its length is used.

    Returns:
        A ``FileInfo`` with ``language="rust"``.
    """
    # The character class covers what a use tree is made of (paths, braces,
    # commas, globs, ``as``), so a statement may span lines while ordinary
    # prose that happens to start with "use" cannot run on to a later ``;``.
    use_pattern = re.compile(
        r"^(?:pub(?:\([^)]*\))?\s+)?use\s+[\w:{}*,\s#$]+;", re.MULTILINE
    )
    imports = []
    for match in use_pattern.finditer(content):
        statement = match.group().strip()
        if "\n" in statement:
            # Keep multi-line use trees as a single-line entry.
            statement = " ".join(statement.split())
        imports.append(statement)

    # A function is recognised by ``fn name`` followed by the start of its
    # generics or parameters. Not matching the generics themselves keeps
    # nested ones such as ``<T: Into<String>>`` from hiding the function.
    func_pattern = re.compile(r"\b(?:pub\s+)?(?:async\s+)?fn\s+(\w+)\s*[<(]")
    functions = []
    for match in func_pattern.finditer(content):
        start_line = content.count("\n", 0, match.start()) + 1
        functions.append(CodeEntity(
            name=match.group(1),
            type="function",
            start_line=start_line,
            end_line=start_line,
        ))

    return FileInfo(
        path=file_path,
        language="rust",
        imports=imports,
        functions=functions,
        line_count=len(lines),
    )
306
+
307
def _parse_java(self, file_path: Path, content: str, lines: list[str]) -> FileInfo:
    """Extract imports, classes and method declarations from Java source.

    Parsing is regex based and ``end_line`` of every entity is set to its
    ``start_line``. Declarations are only recognised at the start of a line
    (after indentation, same-line annotations and modifiers), which keeps
    comment prose and call statements such as ``return foo(x);`` out of the
    result.

    Args:
        file_path: Path stored on the returned ``FileInfo``.
        content: Full source text.
        lines: ``content`` split into lines; only its length is used.

    Returns:
        A ``FileInfo`` with ``language="java"``; methods are stored in
        ``functions`` with ``type="method"``.
    """
    def line_of(offset: int) -> int:
        """Return the 1-based line number of a character offset."""
        return content.count("\n", 0, offset) + 1

    def squeeze(text: str) -> str:
        """Collapse runs of whitespace so signatures stay on one line."""
        return " ".join(text.split())

    # Building blocks shared by the class and method patterns.
    annotations = r"(?:@[\w.]+(?:\([^)]*\))?[ \t]+)*"
    modifiers = (
        r"(?:(?:public|protected|private|abstract|static|final|sealed|"
        r"strictfp|synchronized|native|default)\s+)*"
    )
    generics = r"<[^(){};]*>"
    type_ref = rf"[\w.]+(?:{generics})?(?:\[\])*"
    # Parameter list tolerating one nested level, e.g. ``@Named("x") String s``.
    nested = r"(?:[^()]|\([^()]*\))*"

    # Imports
    import_pattern = re.compile(r"^import\s+.+;", re.MULTILINE)
    imports = [match.group().strip() for match in import_pattern.finditer(content)]

    # Classes
    class_pattern = re.compile(
        rf"^[ \t]*{annotations}{modifiers}class\s+(?P<name>\w+)"
        r"(?:\s*<[^>{]*>)?(?:\s+extends\s+(?P<base>[\w.]+))?",
        re.MULTILINE,
    )
    classes = []
    for match in class_pattern.finditer(content):
        name = match.group("name")
        base = match.group("base")
        start_line = line_of(match.start())
        classes.append(CodeEntity(
            name=name,
            type="class",
            start_line=start_line,
            end_line=start_line,
            signature=f"class {name}" + (f" extends {base}" if base else ""),
        ))

    # Methods: a type, a name and a parameter list that is followed by a body
    # (``{``) or, for abstract/interface methods, by ``;``.
    method_pattern = re.compile(
        rf"^[ \t]*{annotations}{modifiers}(?:{generics}\s+)?"
        rf"(?P<ret>{type_ref})\s+(?P<name>\w+)\s*\((?P<params>{nested})\)"
        r"\s*(?:throws\s+[\w.,\s]+?)?\s*[{;]",
        re.MULTILINE,
    )
    # Keywords that can stand where a return type would, as in
    # ``return foo(x);`` or ``new Foo(x);`` -- those are statements.
    statement_keywords = frozenset(
        ("return", "new", "throw", "else", "assert", "yield", "do", "case")
    )
    control_keywords = ("if", "while", "for", "switch", "catch")

    functions = []
    for match in method_pattern.finditer(content):
        name = match.group("name")
        return_type = match.group("ret")
        if name in control_keywords or return_type in statement_keywords:
            continue
        start_line = line_of(match.start())
        functions.append(CodeEntity(
            name=name,
            type="method",
            start_line=start_line,
            end_line=start_line,
            signature=(
                f"{squeeze(return_type)} {name}({squeeze(match.group('params'))})"
            ),
        ))

    return FileInfo(
        path=file_path,
        language="java",
        imports=imports,
        functions=functions,
        classes=classes,
        line_count=len(lines),
    )
codetree/retriever.py ADDED
@@ -0,0 +1,192 @@
1
+ """Reasoning-based code retriever."""
2
+
3
+ import json
4
+ from pathlib import Path
5
+ from typing import Optional
6
+
7
+ from .config import Config
8
+ from .indexer import CodeIndex, TreeNode
9
+ from .llm import create_llm_client, LLMClient
10
+
11
+
12
+ RETRIEVAL_SYSTEM_PROMPT = """You are a code navigation expert. Your task is to analyze a code repository structure and identify the most relevant files and code sections to answer a user's question.
13
+
14
+ You will be given:
15
+ 1. A tree structure of the codebase showing directories, files, functions, and classes
16
+ 2. A user's question about the code
17
+
18
+ Your job is to:
19
+ 1. Analyze the question to understand what the user is looking for
20
+ 2. Navigate the tree structure using your reasoning
21
+ 3. Identify the most relevant files and specific functions/classes
22
+ 4. Return a JSON list of file paths that should be examined
23
+
24
+ Think step by step:
25
+ - What concepts does the question involve? (authentication, database, API, etc.)
26
+ - Which directories/modules are likely to contain relevant code?
27
+ - Which specific files have functions or classes related to the question?
28
+
29
+ Return your answer as JSON in this format:
30
+ {
31
+ "reasoning": "Brief explanation of your navigation logic",
32
+ "relevant_files": [
33
+ {"path": "path/to/file.py", "relevance": "why this file is relevant", "focus": ["function_name", "ClassName"]}
34
+ ]
35
+ }
36
+
37
+ Only include files that are truly relevant. Aim for 1-5 files maximum."""
38
+
39
+
40
+ ANSWER_SYSTEM_PROMPT = """You are a helpful code assistant. You have been given relevant code sections from a repository to answer a user's question.
41
+
42
+ Guidelines:
43
+ - Answer the question directly and concisely
44
+ - Reference specific code sections when relevant
45
+ - Include code snippets if they help explain the answer
46
+ - If the provided code doesn't fully answer the question, say so
47
+ - Use markdown formatting for code blocks"""
48
+
49
+
50
+ class CodeRetriever:
51
+ """Retrieves relevant code using LLM reasoning."""
52
+
53
def __init__(self, index: CodeIndex, config: Optional[Config] = None):
    """Bind the retriever to an index and create its LLM client.

    Args:
        index: Code index to search; its ``repo_path`` becomes the base
            directory for reading files.
        config: Configuration to use. ``Config.load()`` supplies one when
            this is omitted.
    """
    effective_config = config or Config.load()
    self.index = index
    self.config = effective_config
    self.repo_path = Path(index.repo_path)
    self.llm = create_llm_client(effective_config.llm)
58
+
59
def retrieve(self, query: str, max_files: int = 5) -> list[dict]:
    """Ask the LLM which files of the indexed repository are relevant.

    The compact tree of the index (depth 4) and the query are sent to the
    LLM, and the ``relevant_files`` list of the first JSON object in the
    reply that has such a key is returned.

    Args:
        query: The user's question.
        max_files: Upper bound on the number of returned entries.

    Returns:
        At most ``max_files`` dict entries as produced by the model. Plain
        string entries are wrapped as ``{"path": entry}`` and any other
        non-dict entries are dropped. An empty list is returned when the
        reply contains no usable JSON.
    """
    # Get compact tree representation
    tree_str = self.index.get_compact_tree(max_depth=4)

    # Ask LLM to identify relevant files
    messages = [
        {"role": "system", "content": RETRIEVAL_SYSTEM_PROMPT},
        {"role": "user", "content": f"""## Repository Structure

{tree_str}

## Question
{query}

Analyze the repository structure and identify the most relevant files to answer this question. Return JSON."""}
    ]

    response = self.llm.chat(messages)
    if not isinstance(response, str):
        return []

    # Models often wrap the JSON in prose or code fences. Try to decode an
    # object at every "{" instead of slicing from the first "{" to the last
    # "}", which breaks as soon as the surrounding text contains braces.
    decoder = json.JSONDecoder()
    start = response.find("{")
    while start >= 0:
        try:
            result, _ = decoder.raw_decode(response, start)
        except json.JSONDecodeError:
            result = None
        if isinstance(result, dict) and "relevant_files" in result:
            entries = result["relevant_files"]
            if not isinstance(entries, list):
                return []
            usable = []
            for entry in entries:
                if isinstance(entry, dict):
                    usable.append(entry)
                elif isinstance(entry, str):
                    usable.append({"path": entry})
            return usable[:max_files]
        start = response.find("{", start + 1)

    return []
91
+
92
+ def get_file_content(self, file_path: str, focus: Optional[list[str]] = None) -> Optional[str]:
93
+ """Get content of a file, optionally focusing on specific functions/classes."""
94
+ full_path = self.repo_path / file_path
95
+
96
+ if not full_path.exists():
97
+ return None
98
+
99
+ try:
100
+ content = full_path.read_text(encoding="utf-8")
101
+ except (UnicodeDecodeError, IOError):
102
+ return None
103
+
104
+ # If no focus specified, return full content (truncated)
105
+ if not focus:
106
+ lines = content.split("\n")
107
+ if len(lines) > 200:
108
+ return "\n".join(lines[:200]) + f"\n\n... ({len(lines) - 200} more lines)"
109
+ return content
110
+
111
+ # TODO: Extract only focused sections
112
+ # For now, return full content
113
+ return content
114
+
115
def query(self, question: str) -> str:
    """Answer a question about the codebase.

    Relevant files are selected with ``retrieve``, their contents are read
    with ``get_file_content``, and the LLM is asked to answer from that
    context.

    Args:
        question: The user's question.

    Returns:
        The LLM's answer, or a fixed explanatory message when no relevant
        file was identified or none of them could be read.
    """
    # Step 1: Retrieve relevant files
    relevant_files = self.retrieve(question)

    if not relevant_files:
        return "I couldn't identify any relevant files for your question. Please try rephrasing or being more specific."

    # Step 2: Get file contents
    context_parts = []
    seen_paths = set()
    for file_info in relevant_files:
        # The entries originate from model output, so tolerate shapes other
        # than the requested {"path": ..., "focus": [...]} objects.
        if isinstance(file_info, dict):
            path = file_info.get("path") or ""
            focus = file_info.get("focus") or []
        elif isinstance(file_info, str):
            path, focus = file_info, []
        else:
            continue

        if not isinstance(path, str) or not path or path in seen_paths:
            continue  # nothing to read, or the file is already included
        seen_paths.add(path)

        if isinstance(focus, str):
            focus = [focus]

        content = self.get_file_content(path, focus)
        if content:
            context_parts.append(f"## File: {path}\n\n```\n{content}\n```")

    if not context_parts:
        return "I found relevant files but couldn't read their contents."

    # Step 3: Generate answer
    context = "\n\n".join(context_parts)

    messages = [
        {"role": "system", "content": ANSWER_SYSTEM_PROMPT},
        {"role": "user", "content": f"""## Relevant Code

{context}

## Question
{question}

Please answer the question based on the code provided above."""}
    ]

    return self.llm.chat(messages)
152
+
153
+ def find_references(self, symbol: str) -> list[dict]:
154
+ """Find all references to a symbol across the codebase."""
155
+ references = []
156
+
157
+ def search_node(node: TreeNode):
158
+ if node.type == "file":
159
+ # Check functions
160
+ for func in node.functions:
161
+ if symbol.lower() in func.get("name", "").lower():
162
+ references.append({
163
+ "type": "function",
164
+ "file": node.path,
165
+ "name": func["name"],
166
+ "line": func.get("line"),
167
+ })
168
+
169
+ # Check classes
170
+ for cls in node.classes:
171
+ if symbol.lower() in cls.get("name", "").lower():
172
+ references.append({
173
+ "type": "class",
174
+ "file": node.path,
175
+ "name": cls["name"],
176
+ "line": cls.get("line"),
177
+ })
178
+
179
+ # Check imports
180
+ for imp in node.imports:
181
+ if symbol.lower() in imp.lower():
182
+ references.append({
183
+ "type": "import",
184
+ "file": node.path,
185
+ "statement": imp,
186
+ })
187
+ else:
188
+ for child in node.children:
189
+ search_node(child)
190
+
191
+ search_node(self.index.root)
192
+ return references