aru-code 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
aru/tools/ast_tools.py ADDED
@@ -0,0 +1,422 @@
1
+ """AST-based code analysis tools using tree-sitter."""
2
+
3
+ import os
4
+ import re
5
+ from typing import Any
6
+
7
# Tree-sitter availability flag. It becomes True only after the Python grammar
# has been loaded AND a parser has been constructed, so whenever the flag is
# set ``_parser`` is usable.
_TREE_SITTER_AVAILABLE = False
_parser: Any = None
PY_LANGUAGE: Any = None

try:
    import tree_sitter_python as tspython
    from tree_sitter import Language, Parser

    PY_LANGUAGE = Language(tspython.language())
    _parser = Parser(PY_LANGUAGE)
    _TREE_SITTER_AVAILABLE = True
except (ImportError, AttributeError, TypeError, ValueError):
    # Packages missing, or installed versions that do not fit together (e.g. a
    # tree_sitter release with a different Language/Parser constructor, or a
    # grammar built for another ABI version). Fall back to the regex extractor
    # instead of failing the whole module import.
    PY_LANGUAGE = None
    _parser = None

# Language registry for future extension
SUPPORTED_EXTENSIONS = {".py"}
24
+
25
+
26
+ def _parse_python_tree(source: bytes) -> Any | None:
27
+ """Parse Python source code with tree-sitter."""
28
+ if not _TREE_SITTER_AVAILABLE or _parser is None:
29
+ return None
30
+ return _parser.parse(source)
31
+
32
+
33
def _extract_structure_treesitter(tree: Any, source: bytes, file_path: str) -> dict:
    """Extract the top-level structure of a module from a tree-sitter AST.

    Only direct children of the module root are inspected, so definitions
    nested inside functions, ``if`` blocks, etc. are not reported.

    Args:
        tree: Tree returned by ``_parse_python_tree``.
        source: The exact bytes that were parsed; node byte offsets index into it.
        file_path: Path of the analysed file (currently unused; kept for API
            compatibility).

    Returns:
        Dict with ``imports``, ``classes``, ``functions`` and ``globals`` lists.
        Every entry carries a 1-based ``line``; decorated classes/functions also
        carry ``decorators``.
    """
    structure: dict[str, list] = {
        "imports": [],
        "classes": [],
        "functions": [],
        "globals": [],
    }

    def node_text(node: Any) -> str:
        return source[node.start_byte:node.end_byte].decode("utf-8", errors="ignore")

    # ``from __future__ import ...`` is a distinct node type in tree-sitter-python.
    import_types = ("import_statement", "import_from_statement", "future_import_statement")

    for child in tree.root_node.children:
        node_type = child.type
        start_line = child.start_point[0] + 1  # tree-sitter rows are 0-based

        if node_type in import_types:
            structure["imports"].append({"text": node_text(child).strip(), "line": start_line})

        elif node_type == "class_definition":
            class_info = _extract_class(child, source)
            class_info["line"] = start_line
            structure["classes"].append(class_info)

        elif node_type == "function_definition":
            func_info = _extract_function(child, source)
            func_info["line"] = start_line
            structure["functions"].append(func_info)

        elif node_type == "decorated_definition":
            decorators = _extract_decorators(child, source)
            for sub in child.children:
                if sub.type == "class_definition":
                    info = _extract_class(sub, source)
                    bucket = "classes"
                elif sub.type == "function_definition":
                    info = _extract_function(sub, source)
                    bucket = "functions"
                else:
                    continue
                # Report the ``class``/``def`` line rather than the first decorator line.
                info["line"] = sub.start_point[0] + 1
                info["decorators"] = list(decorators)
                structure[bucket].append(info)

        elif node_type == "expression_statement":
            # Top-level assignments (globals)
            for sub in child.children:
                if sub.type == "assignment":
                    text = node_text(sub).strip()
                    name = text.split("=")[0].strip().split(":")[0].strip()
                    # Underscore-prefixed names are treated as private and skipped.
                    if name and not name.startswith("_"):
                        structure["globals"].append({"name": name, "line": start_line})

    return structure
94
+
95
+
96
def _extract_class(node: Any, source: bytes) -> dict:
    """Extract the name, base classes and methods from a ``class_definition`` node.

    Args:
        node: A tree-sitter ``class_definition`` node.
        source: Source bytes that the node's byte offsets refer to.

    Returns:
        ``{"name": str, "bases": list[str], "methods": list[dict]}``. Each base is
        the source text of one entry of the superclass list (keyword arguments
        such as ``metaclass=M`` included). Each method dict comes from
        ``_extract_function`` plus a 1-based ``line`` and, for decorated methods,
        ``decorators``.
    """
    def node_text(n: Any) -> str:
        return source[n.start_byte:n.end_byte].decode("utf-8", errors="ignore")

    # Punctuation/comment children of an argument_list that are not bases.
    skip_types = {"(", ")", ",", "comment"}

    name = ""
    bases: list[str] = []
    methods: list[dict] = []

    for child in node.children:
        if child.type == "identifier":
            name = node_text(child)
        elif child.type == "argument_list":
            # Walk the argument nodes instead of splitting the raw text on commas,
            # so entries like ``Generic[K, V]`` or ``make_base(a, b)`` stay intact.
            bases = [
                node_text(arg).strip()
                for arg in child.children
                if arg.type not in skip_types
            ]
        elif child.type == "block":
            for block_child in child.children:
                if block_child.type == "function_definition":
                    method_info = _extract_function(block_child, source)
                    method_info["line"] = block_child.start_point[0] + 1
                    methods.append(method_info)
                elif block_child.type == "decorated_definition":
                    for sub in block_child.children:
                        if sub.type == "function_definition":
                            method_info = _extract_function(sub, source)
                            method_info["line"] = sub.start_point[0] + 1
                            method_info["decorators"] = _extract_decorators(block_child, source)
                            methods.append(method_info)

    return {"name": name, "bases": bases, "methods": methods}
123
+
124
+
125
def _extract_function(node: Any, source: bytes) -> dict:
    """Extract the name, parameter names and docstring from a ``function_definition`` node.

    Args:
        node: A tree-sitter ``function_definition`` node.
        source: Source bytes that the node's byte offsets refer to.

    Returns:
        ``{"name": str, "params": list[str], "docstring": str}``.
        - Parameter entries keep their ``*``/``**`` prefix but drop annotations
          and defaults (``x: int = 1`` -> ``"x"``).
        - ``docstring`` is the first body statement's string literal with its
          prefix and quotes removed, or ``""`` if there is none.
    """
    def node_text(n: Any) -> str:
        return source[n.start_byte:n.end_byte].decode("utf-8", errors="ignore")

    def unquote(literal: str) -> str:
        # Drop an optional string prefix (r, u, b, f in either case), then the quotes.
        text = literal.lstrip("rRuUbBfF")
        for quote in ('"""', "'''", '"', "'"):
            if len(text) >= 2 * len(quote) and text.startswith(quote) and text.endswith(quote):
                return text[len(quote):-len(quote)].strip()
        return literal.strip()

    name = ""
    params: list[str] = []
    body = None

    for child in node.children:
        if child.type == "identifier":
            name = node_text(child)
        elif child.type == "parameters":
            # Inspect each parameter node on its own instead of splitting the whole
            # list on commas, which broke on ``d: dict[str, int]`` or ``p=(1, 2)``.
            for param in child.children:
                if param.type in ("(", ")", ",", "comment"):
                    continue
                param_name = node_text(param).split(":")[0].split("=")[0].strip()
                if param_name:
                    params.append(param_name)
        elif child.type == "block" and body is None:
            body = child

    # Docstring: the first non-comment statement of the body, if it is a string.
    docstring = ""
    if body is not None:
        first_stmt = next((c for c in body.children if c.type != "comment"), None)
        if first_stmt is not None and first_stmt.type == "expression_statement":
            for sc in first_stmt.children:
                if sc.type == "string":
                    docstring = unquote(node_text(sc))
                    break

    return {"name": name, "params": params, "docstring": docstring}
160
+
161
+
162
def _extract_decorators(node: Any, source: bytes) -> list[str]:
    """Collect the source text of every ``decorator`` child of ``node``.

    Args:
        node: A tree-sitter ``decorated_definition`` node.
        source: Source bytes that the node's byte offsets refer to.

    Returns:
        The decorators' texts in source order, each stripped of surrounding whitespace.
    """
    return [
        source[dec.start_byte:dec.end_byte].decode("utf-8", errors="ignore").strip()
        for dec in node.children
        if dec.type == "decorator"
    ]
170
+
171
+
172
def _extract_structure_regex(content: str) -> dict:
    """Fallback: extract code structure using regex (when tree-sitter is unavailable).

    Line-based heuristics:

    * Any line whose stripped text starts with ``import `` or ``from `` is an
      import (indented imports included).
    * A line matching ``class Name(...):`` opens a class.
    * ``def`` / ``async def`` lines become top-level functions when unindented,
      or methods of the most recent class while that class is still open.
    * A class is closed by the next unindented, non-comment line. This means
      functions nested in a later top-level function are not misattributed to
      it, though an unindented line inside a multi-line string also closes it.
    * Other indented definitions (nested functions) are ignored.
    * A signature split over several lines is still recorded, with only the
      parameters written on the ``def`` line itself.

    Args:
        content: Full text of the file.

    Returns:
        Dict with ``imports``, ``classes``, ``functions`` and ``globals`` lists
        (``globals`` is always empty here). Entries carry a 1-based ``line``.
    """
    structure: dict[str, list] = {
        "imports": [],
        "classes": [],
        "functions": [],
        "globals": [],
    }
    class_open = False

    for i, line in enumerate(content.split("\n"), 1):
        stripped = line.strip()
        indented = line.startswith(" ") or line.startswith("\t")

        # Any unindented code line ends the body of the current class.
        if stripped and not indented and not stripped.startswith("#"):
            class_open = False

        if stripped.startswith("import ") or stripped.startswith("from "):
            structure["imports"].append({"text": stripped, "line": i})

        elif stripped.startswith("class "):
            match = re.match(r"class\s+(\w+)(?:\((.*?)\))?:", stripped)
            if match:
                name = match.group(1)
                bases = [b.strip() for b in (match.group(2) or "").split(",") if b.strip()]
                structure["classes"].append({"name": name, "bases": bases, "methods": [], "line": i})
                class_open = True

        elif stripped.startswith(("def ", "async def ")):
            # ``[^)]*`` stops at the first ")" or at end of line, so a signature
            # continued on following lines still matches.
            match = re.match(r"(?:async\s+)?def\s+(\w+)\s*\(([^)]*)", stripped)
            if match:
                name = match.group(1)
                params = [p.strip().split(":")[0].split("=")[0].strip()
                          for p in match.group(2).split(",") if p.strip()]
                entry = {"name": name, "params": params, "line": i}
                if not indented:
                    structure["functions"].append(entry)
                elif class_open and structure["classes"]:
                    # Method of the class whose body we are still inside.
                    structure["classes"][-1]["methods"].append(entry)

    return structure
211
+
212
+
213
def _format_structure(structure: dict, file_path: str, total_lines: int) -> str:
    """Render an extracted structure dict as a Markdown-style text summary.

    Args:
        structure: Dict with ``imports``, ``classes``, ``functions`` and
            ``globals`` lists, as built by the ``_extract_structure_*`` helpers.
        file_path: Path shown in the heading; its extension selects the language label.
        total_lines: Line count shown in the heading.

    Returns:
        A heading, then one section per non-empty category, each followed by a
        blank line.
    """
    suffix = os.path.splitext(file_path)[1].lstrip(".")
    known_languages = {"py": "Python", "js": "JavaScript", "ts": "TypeScript"}
    lang = known_languages.get(suffix, suffix.upper() or "Unknown")

    def decorator_prefix(decorators: list) -> str:
        # Each decorator followed by one space, placed before the name.
        return "".join(f"{dec} " for dec in decorators)

    out = [f"## {file_path} ({lang}, {total_lines} lines)\n"]

    imports = structure["imports"]
    if imports:
        out.append("### Imports")
        out.extend(f" - {imp['text']} (line {imp['line']})" for imp in imports)
        out.append("")

    classes = structure["classes"]
    if classes:
        out.append("### Classes")
        for cls in classes:
            bases = cls.get("bases")
            bases_str = f"({', '.join(bases)})" if bases else ""
            out.append(f" - {decorator_prefix(cls.get('decorators', []))}{cls['name']}{bases_str} (line {cls['line']})")
            for method in cls.get("methods", []):
                arg_list = ", ".join(method["params"])
                joined = " ".join(method.get("decorators", []))
                lead = f" {joined} " if joined else " "
                out.append(f"{lead}- {method['name']}({arg_list}) - line {method['line']}")
        out.append("")

    functions = structure["functions"]
    if functions:
        out.append("### Functions")
        for func in functions:
            arg_list = ", ".join(func["params"])
            out.append(f" - {decorator_prefix(func.get('decorators', []))}{func['name']}({arg_list}) - line {func['line']}")
        out.append("")

    globals_ = structure["globals"]
    if globals_:
        out.append("### Globals")
        out.extend(f" - {g['name']} (line {g['line']})" for g in globals_)
        out.append("")

    return "\n".join(out)
256
+
257
+
258
def code_structure(file_path: str) -> str:
    """Analyze a file and return its structural overview: imports, classes, functions, and globals.

    Useful for quickly understanding what a file contains without reading its full content.
    Works best with Python files (using tree-sitter AST parsing), but falls back to
    regex-based extraction for other languages.

    Args:
        file_path: Path to the file to analyze.

    Returns:
        A text summary of the file. If the file cannot be read, an ``Error: ...``
        message is returned instead of raising.
    """
    try:
        with open(file_path, "r", encoding="utf-8", errors="ignore") as f:
            content = f.read()
    except FileNotFoundError:
        return f"Error: File not found: {file_path}"
    except Exception as e:  # any other read failure is reported as text, not raised
        return f"Error reading file: {e}"

    # A trailing newline terminates the last line rather than starting a new one,
    # and an empty file has zero lines.
    total_lines = content.count("\n")
    if content and not content.endswith("\n"):
        total_lines += 1
    _, ext = os.path.splitext(file_path)

    # Try tree-sitter for supported languages
    if ext in SUPPORTED_EXTENSIONS and _TREE_SITTER_AVAILABLE:
        source = content.encode("utf-8")
        tree = _parse_python_tree(source)
        if tree is not None:
            structure = _extract_structure_treesitter(tree, source, file_path)
            return _format_structure(structure, file_path, total_lines)

    # Fallback to regex
    structure = _extract_structure_regex(content)
    return _format_structure(structure, file_path, total_lines)
290
+
291
+
292
+ def _resolve_import_to_file(import_text: str, project_root: str) -> str | None:
293
+ """Try to resolve an import statement to a file path within the project."""
294
+ # Handle "from X import Y" and "import X"
295
+ match = re.match(r"(?:from\s+)?([\w.]+)", import_text)
296
+ if not match:
297
+ return None
298
+
299
+ module_path = match.group(1)
300
+ parts = module_path.split(".")
301
+
302
+ # Try as package (directory/__init__.py) and module (.py file)
303
+ candidates = [
304
+ os.path.join(*parts, "__init__.py"),
305
+ os.path.join(*parts) + ".py",
306
+ ]
307
+
308
+ # Also try relative to common src directories
309
+ for candidate in candidates:
310
+ full_path = os.path.join(project_root, candidate)
311
+ if os.path.isfile(full_path):
312
+ return candidate
313
+
314
+ return None
315
+
316
+
317
def _find_project_root(file_path: str) -> str:
    """Walk upward from ``file_path``'s directory to the nearest project root.

    A directory counts as a root if it contains any of ``pyproject.toml``,
    ``setup.py``, ``setup.cfg``, ``package.json`` or ``.git``. If the
    filesystem root is reached without a match, the current working directory
    is returned instead.
    """
    markers = ("pyproject.toml", "setup.py", "setup.cfg", "package.json", ".git")
    directory = os.path.abspath(os.path.dirname(file_path))
    while not any(os.path.exists(os.path.join(directory, m)) for m in markers):
        parent = os.path.dirname(directory)
        if parent == directory:
            return os.getcwd()
        directory = parent
    return directory
330
+
331
+
332
def _extract_import_texts(content: str, ext: str) -> list[str]:
    """Return the source text of the import statements found in ``content``.

    For ``.py`` files with tree-sitter available, the top-level ``import`` /
    ``from ... import`` statements of the parse tree are returned. Otherwise,
    including when the parse returns None, every line whose stripped text
    starts with ``import `` or ``from `` is returned.
    """
    if ext == ".py" and _TREE_SITTER_AVAILABLE:
        source = content.encode("utf-8")
        tree = _parse_python_tree(source)
        if tree is not None:
            return [
                source[node.start_byte:node.end_byte].decode("utf-8", errors="ignore").strip()
                for node in tree.root_node.children
                if node.type in ("import_statement", "import_from_statement")
            ]
    stripped_lines = (line.strip() for line in content.split("\n"))
    return [s for s in stripped_lines if s.startswith("import ") or s.startswith("from ")]


def find_dependencies(file_path: str, depth: int = 3) -> str:
    """Trace the import dependency tree of a file within the project.

    Resolves local imports (within the project) and shows which files depend on which.
    Skips stdlib and third-party packages. Useful for understanding how files are connected.

    Args:
        file_path: Path to the file to analyze.
        depth: Maximum recursion depth for tracing imports. Defaults to 3.

    Returns:
        A text tree rooted at ``file_path``.
        - A file that imports one of its own ancestors is marked ``(circular)``.
        - A file already expanded elsewhere is marked ``(see above)`` and is not
          expanded again.
        - If no local dependencies are found, a "No dependencies found" message
          is returned instead.
    """
    if not os.path.isfile(file_path):
        return f"Error: File not found: {file_path}"

    project_root = _find_project_root(file_path)
    rel_start = os.path.relpath(file_path, project_root).replace("\\", "/")

    expanded: set[str] = set()  # files whose subtree has already been printed
    chain: set[str] = set()     # files on the current root-to-node path (cycle check)
    tree_lines: list[str] = []

    def _local_deps(rel_path: str) -> list[str]:
        """Return the distinct project files imported by ``rel_path``, in import order."""
        try:
            with open(os.path.join(project_root, rel_path), "r", encoding="utf-8", errors="ignore") as f:
                content = f.read()
        except OSError:  # missing file, directory, permission problem, ...
            return []
        deps: list[str] = []
        seen: set[str] = set()
        for imp in _extract_import_texts(content, os.path.splitext(rel_path)[1]):
            resolved = _resolve_import_to_file(imp, project_root)
            if not resolved:
                continue
            # Normalise before comparing so the self-import check also works on Windows.
            normalized = resolved.replace("\\", "/")
            if normalized != rel_path and normalized not in seen:
                seen.add(normalized)
                deps.append(normalized)
        return deps

    def _trace(rel_path: str, current_depth: int, prefix: str = "", is_last: bool = True) -> None:
        if current_depth > depth:
            return
        connector = "└── " if is_last else "├── "

        if rel_path in chain or rel_path in expanded:
            marker = "circular" if rel_path in chain else "see above"
            tree_lines.append(f"{prefix}{connector}{rel_path} ({marker})")
            return

        expanded.add(rel_path)
        tree_lines.append(rel_path if current_depth == 0 else f"{prefix}{connector}{rel_path}")

        # Children at depth + 1 would be cut off anyway, so skip reading imports.
        deps = _local_deps(rel_path) if current_depth < depth else []
        # Children of the root start at column 0. Deeper levels extend the parent's
        # prefix by a column as wide as the 4-character connector.
        child_prefix = "" if current_depth == 0 else prefix + ("    " if is_last else "│   ")
        chain.add(rel_path)
        for i, dep in enumerate(deps):
            _trace(dep, current_depth + 1, child_prefix, i == len(deps) - 1)
        chain.discard(rel_path)

    _trace(rel_start, 0)

    # Only the root line was produced: the file imports nothing local.
    if len(tree_lines) <= 1:
        return f"No dependencies found for: {file_path}"

    return "\n".join(tree_lines)