tricoder 1.2.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tricoder/__about__.py +6 -0
- tricoder/__init__.py +19 -0
- tricoder/calibration.py +276 -0
- tricoder/cli.py +890 -0
- tricoder/context_view.py +228 -0
- tricoder/data_loader.py +144 -0
- tricoder/extract.py +622 -0
- tricoder/fusion.py +203 -0
- tricoder/git_tracker.py +203 -0
- tricoder/gpu_utils.py +414 -0
- tricoder/graph_view.py +583 -0
- tricoder/model.py +476 -0
- tricoder/optimize.py +263 -0
- tricoder/subtoken_utils.py +196 -0
- tricoder/train.py +857 -0
- tricoder/typed_view.py +313 -0
- tricoder-1.2.8.dist-info/METADATA +306 -0
- tricoder-1.2.8.dist-info/RECORD +22 -0
- tricoder-1.2.8.dist-info/WHEEL +4 -0
- tricoder-1.2.8.dist-info/entry_points.txt +3 -0
- tricoder-1.2.8.dist-info/licenses/LICENSE +56 -0
- tricoder-1.2.8.dist-info/licenses/LICENSE_COMMERCIAL.md +68 -0
tricoder/extract.py
ADDED
|
@@ -0,0 +1,622 @@
|
|
|
1
|
+
"""Symbol extraction utilities for TriCoder."""
|
|
2
|
+
# This module contains the symbol extraction logic
|
|
3
|
+
# Imported from extract_symbols.py for package distribution
|
|
4
|
+
|
|
5
|
+
import ast
|
|
6
|
+
import fnmatch
|
|
7
|
+
import json
|
|
8
|
+
import os
|
|
9
|
+
import re
|
|
10
|
+
from collections import defaultdict
|
|
11
|
+
from pathlib import Path
|
|
12
|
+
from typing import Dict, List, Tuple, Optional, Set
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class SymbolExtractor(ast.NodeVisitor):
    """AST visitor that extracts symbols and their relationships from one file.

    While walking the tree it collects three data sets:
      * ``symbols``      -- file/class/function/var symbol records (dicts)
      * ``edge_weights`` -- weighted relations between symbol IDs, keyed by
                           ``(src, dst, rel)`` and aggregated with ``max``
                           (consistent with the training code)
      * ``type_tokens``  -- occurrence counts of type-annotation strings,
                           keyed by ``(symbol_id, type_token)``
    """

    def __init__(self, file_path: str, excluded_keywords: Optional[Set[str]] = None):
        """
        Args:
            file_path: Path of the file being analyzed; embedded in symbol keys.
            excluded_keywords: Lower-case symbol names to exclude entirely.
        """
        self.file_path = file_path
        self.symbols: List[Dict] = []
        # Kept for interface compatibility; edges are aggregated in edge_weights.
        self.edges: List[Dict] = []
        self.current_class: Optional[str] = None     # symbol ID of enclosing class
        self.current_function: Optional[str] = None  # symbol ID of enclosing function
        self.imports: Dict[str, str] = {}            # imported name -> symbol ID
        self.symbol_counter = 0
        self.symbol_map: Dict[str, str] = {}         # "file:kind:name" -> symbol ID
        self.type_tokens = defaultdict(int)
        self.added_symbol_ids: Set[str] = set()  # Track which symbol IDs have been added
        # Track edge weights for aggregation: (src, dst, rel) -> weight
        self.edge_weights: Dict[Tuple[str, str, str], float] = {}
        self.excluded_keywords = excluded_keywords or set()  # Symbol names to exclude

    def _sanitize_name(self, name: str) -> str:
        """Sanitize a symbol name for use in an ID.

        Keeps only ``[a-zA-Z0-9_]``, collapses runs of underscores, trims the
        result to 50 characters and lower-cases it; empty results become
        ``'unnamed'``.
        """
        sanitized = re.sub(r'[^a-zA-Z0-9_]', '_', name)
        sanitized = re.sub(r'_+', '_', sanitized)
        sanitized = sanitized.strip('_')
        # Limit length to avoid very long IDs
        if len(sanitized) > 50:
            sanitized = sanitized[:50]
        if not sanitized:
            sanitized = 'unnamed'
        return sanitized.lower()

    def _get_symbol_id(self, name: str, kind: str) -> str:
        """Return a stable, unique ID of the form ``{kind}_{name}_{counter:04d}``.

        The same (file, kind, name) triple always yields the same ID within
        one extractor instance.
        """
        key = f"{self.file_path}:{kind}:{name}"
        if key not in self.symbol_map:
            self.symbol_counter += 1
            sanitized_name = self._sanitize_name(name)
            sanitized_kind = kind.lower()
            self.symbol_map[key] = f"{sanitized_kind}_{sanitized_name}_{self.symbol_counter:04d}"
        return self.symbol_map[key]

    def _add_symbol(self, name: str, kind: str, lineno: int,
                    extra_meta: Optional[Dict] = None) -> Optional[str]:
        """Record a symbol and return its ID.

        Returns ``None`` when *name* is excluded via ``excluded_keywords``;
        returns the existing ID without re-adding when the symbol was already
        recorded (prevents duplicates).
        """
        if name.lower() in self.excluded_keywords:
            return None

        symbol_id = self._get_symbol_id(name, kind)

        # Skip if symbol already added (prevents duplicates)
        if symbol_id in self.added_symbol_ids:
            return symbol_id

        meta = {
            "file": self.file_path,
            "lineno": lineno,
            "typing": []
        }
        if extra_meta:
            meta.update(extra_meta)

        self.symbols.append({
            "id": symbol_id,
            "kind": kind,
            "name": name,
            "meta": meta
        })
        self.added_symbol_ids.add(symbol_id)
        return symbol_id

    def _add_edge(self, src: Optional[str], dst: Optional[str], rel: str, weight: float = 1.0):
        """Add a relationship edge, aggregating weights for duplicate edges.

        Edges whose endpoints are ``None`` (i.e. an excluded symbol) are
        silently dropped; previously these produced edges with null endpoints.
        """
        if src is None or dst is None:
            return
        edge_key = (src, dst, rel)
        if edge_key in self.edge_weights:
            # Aggregate weights: use maximum (consistent with training code)
            self.edge_weights[edge_key] = max(self.edge_weights[edge_key], weight)
        else:
            self.edge_weights[edge_key] = weight

    def _add_type_token(self, symbol_id: Optional[str], type_token: str, count: int = 1):
        """Accumulate a type-annotation token; no-op for excluded (None) symbols."""
        if symbol_id is None:
            return
        self.type_tokens[(symbol_id, type_token)] += count

    def visit_Module(self, node):
        """Create the file-level symbol, then visit the module body."""
        file_symbol = self._add_symbol(
            os.path.basename(self.file_path),
            "file",
            node.lineno if hasattr(node, 'lineno') else 1
        )
        self.generic_visit(node)
        return file_symbol

    def visit_ClassDef(self, node):
        """Record a class symbol plus file-containment and inheritance edges."""
        bases = []
        for base in node.bases:
            if hasattr(ast, 'unparse'):  # Python 3.9+
                bases.append(ast.unparse(base))
            elif isinstance(base, ast.Name):
                bases.append(base.id)
            else:
                bases.append(str(base))

        class_id = self._add_symbol(
            node.name,
            "class",
            node.lineno,
            {"bases": bases}
        )

        # Link class to file (edge is dropped if the class name is excluded)
        file_symbol = self._get_symbol_id(os.path.basename(self.file_path), "file")
        self._add_edge(file_symbol, class_id, "defines_in_file", 1.0)

        # Inheritance edges for simple-name bases only (dotted bases skipped)
        for base in node.bases:
            if isinstance(base, ast.Name):
                base_id = self._get_symbol_id(base.id, "class")
                self._add_edge(class_id, base_id, "inherits", 1.0)

        old_class = self.current_class
        self.current_class = class_id
        self.generic_visit(node)
        self.current_class = old_class

    def visit_FunctionDef(self, node):
        """Record a function symbol, containment edge and annotation tokens."""
        func_id = self._add_symbol(
            node.name,
            "function",
            node.lineno,
            {"args": len(node.args.args)}  # positional args only (no *args/kwonly)
        )

        # Link function to containing class or file
        if self.current_class:
            self._add_edge(self.current_class, func_id, "defines_in_file", 1.0)
        else:
            file_symbol = self._get_symbol_id(os.path.basename(self.file_path), "file")
            self._add_edge(file_symbol, func_id, "defines_in_file", 1.0)

        # Extract return type annotation
        if node.returns:
            return_type = self._extract_type_annotation(node.returns)
            if return_type:
                self._add_type_token(func_id, return_type, 1)

        # Extract parameter type annotations
        for arg in node.args.args:
            if arg.annotation:
                param_type = self._extract_type_annotation(arg.annotation)
                if param_type:
                    self._add_type_token(func_id, param_type, 1)

        old_function = self.current_function
        self.current_function = func_id
        self.generic_visit(node)
        self.current_function = old_function

    def visit_AsyncFunctionDef(self, node):
        """Treat async functions exactly like regular function definitions."""
        self.visit_FunctionDef(node)

    def visit_Import(self, node):
        """Record ``import x`` statements as import edges from the file."""
        for alias in node.names:
            # For dotted imports without an alias, only the top package binds a name.
            import_name = alias.asname if alias.asname else alias.name.split('.')[0]
            import_id = self._get_symbol_id(import_name, "import")

            # Store import mapping
            self.imports[import_name] = import_id

            # Link import to file
            file_symbol = self._get_symbol_id(os.path.basename(self.file_path), "file")
            self._add_edge(file_symbol, import_id, "imports", 1.0)

    def visit_ImportFrom(self, node):
        """Record ``from m import x`` statements as import edges from the file."""
        for alias in node.names:
            import_name = alias.asname if alias.asname else alias.name
            import_id = self._get_symbol_id(import_name, "import")

            # Store import mapping
            self.imports[import_name] = import_id

            # Link import to file
            file_symbol = self._get_symbol_id(os.path.basename(self.file_path), "file")
            self._add_edge(file_symbol, import_id, "imports", 1.0)

    def visit_Call(self, node):
        """Record call edges from the enclosing function/class to the callee."""
        if self.current_function or self.current_class:
            caller_id = self.current_function if self.current_function else self.current_class

            # Extract called function/class name
            if isinstance(node.func, ast.Name):
                callee_id = self._get_symbol_id(node.func.id, "function")
                self._add_edge(caller_id, callee_id, "calls", 1.0)
            elif isinstance(node.func, ast.Attribute):
                # Handle method calls like obj.method()
                if isinstance(node.func.value, ast.Name):
                    obj_name = node.func.value.id
                    method_name = node.func.attr
                    method_id = self._get_symbol_id(f"{obj_name}.{method_name}", "function")
                    self._add_edge(caller_id, method_id, "calls", 1.0)

            # Weak co-occurrence edges with simple-name arguments
            for arg in node.args:
                if isinstance(arg, ast.Name):
                    var_id = self._get_symbol_id(arg.id, "var")
                    self._add_edge(caller_id, var_id, "cooccurs", 0.5)

        self.generic_visit(node)

    def visit_Assign(self, node):
        """Record plain assignments to simple names as variable symbols."""
        for target in node.targets:
            if isinstance(target, ast.Name):
                var_id = self._add_symbol(target.id, "var", node.lineno)

                # Link variable to containing function/class/file
                if self.current_function:
                    self._add_edge(self.current_function, var_id, "defines_in_file", 1.0)
                elif self.current_class:
                    self._add_edge(self.current_class, var_id, "defines_in_file", 1.0)
                else:
                    file_symbol = self._get_symbol_id(os.path.basename(self.file_path), "file")
                    self._add_edge(file_symbol, var_id, "defines_in_file", 1.0)

        self.generic_visit(node)

    def visit_AnnAssign(self, node):
        """Record annotated assignments (e.g. ``x: int = 5``) with type tokens."""
        if isinstance(node.target, ast.Name):
            var_id = self._add_symbol(node.target.id, "var", node.lineno)

            # Link variable to containing function/class/file
            if self.current_function:
                self._add_edge(self.current_function, var_id, "defines_in_file", 1.0)
            elif self.current_class:
                self._add_edge(self.current_class, var_id, "defines_in_file", 1.0)
            else:
                file_symbol = self._get_symbol_id(os.path.basename(self.file_path), "file")
                self._add_edge(file_symbol, var_id, "defines_in_file", 1.0)

            # Extract type annotation
            if node.annotation:
                var_type = self._extract_type_annotation(node.annotation)
                if var_type:
                    self._add_type_token(var_id, var_type, 1)

        self.generic_visit(node)

    def visit_Name(self, node):
        """Record a weak co-occurrence edge for every name read in a scope."""
        if isinstance(node.ctx, ast.Load):
            if self.current_function or self.current_class:
                referencer_id = self.current_function if self.current_function else self.current_class
                var_id = self._get_symbol_id(node.id, "var")
                self._add_edge(referencer_id, var_id, "cooccurs", 0.3)
        self.generic_visit(node)

    def _extract_type_annotation(self, node) -> Optional[str]:
        """Best-effort stringification of a type-annotation AST node.

        Uses ``ast.unparse`` on Python 3.9+; on older interpreters handles
        only the common shapes (Name, simple Subscript, dotted Attribute).
        Returns ``None`` for anything it cannot render.
        """
        try:
            if hasattr(ast, 'unparse'):  # Python 3.9+
                return ast.unparse(node)
            # Fallback for older Python versions
            if isinstance(node, ast.Name):
                return node.id
            if isinstance(node, ast.Subscript):
                if isinstance(node.value, ast.Name):
                    base = node.value.id
                    if isinstance(node.slice, ast.Name):
                        return f"{base}[{node.slice.id}]"
                    if isinstance(node.slice, ast.Index) and isinstance(node.slice.value, ast.Name):
                        return f"{base}[{node.slice.value.id}]"
                    return base
            elif isinstance(node, ast.Attribute):
                return f"{node.value.id}.{node.attr}" if isinstance(node.value, ast.Name) else None
        except Exception:
            # An unparsable annotation is treated as "no annotation" rather
            # than aborting extraction of the whole file.
            pass
        return None
|
|
316
|
+
|
|
317
|
+
|
|
318
|
+
class GitIgnoreMatcher:
    """Simple gitignore-style pattern matcher.

    Supports plain glob patterns, ``!`` negation, leading-``/`` anchored
    patterns and ``**`` recursive globs. Trailing slashes are stripped when
    loading, so directory patterns are matched like ordinary name patterns
    against every path component. This is a simplification of real gitignore
    semantics (e.g. anchoring of patterns containing ``/`` is not implemented).
    """

    def __init__(self, gitignore_path: Path, root_path: Path):
        """
        Args:
            gitignore_path: Path to the ``.gitignore`` file to load.
            root_path: Repository root; matched paths are made relative to it.
        """
        self.root_path = root_path.resolve()
        self.patterns = []  # list of (pattern, negated) tuples, in file order
        self.load_patterns(gitignore_path)

    def load_patterns(self, gitignore_path: Path):
        """Load patterns from a .gitignore file (missing file: no patterns)."""
        if not gitignore_path.exists():
            return

        with open(gitignore_path, 'r', encoding='utf-8', errors='ignore') as f:
            for line in f:
                line = line.strip()
                # Skip empty lines and comments
                if not line or line.startswith('#'):
                    continue

                # Handle negation (patterns starting with !)
                negated = line.startswith('!')
                pattern = line[1:] if negated else line

                # Normalize: drop any trailing '/' so directory patterns are
                # matched against path components like plain patterns.
                pattern = pattern.rstrip('/')

                self.patterns.append((pattern, negated))

    def matches(self, file_path: Path) -> bool:
        """Check whether *file_path* is ignored by the loaded patterns.

        Patterns are evaluated in file order and the last matching pattern
        wins, so a later ``!`` (negated) match re-includes the path. Paths
        outside ``root_path`` are never ignored.
        """
        file_path = file_path.resolve()

        try:
            rel_path = file_path.relative_to(self.root_path)
        except ValueError:
            # File is outside root, don't ignore
            return False

        rel_str = str(rel_path).replace('\\', '/')
        parts = rel_str.split('/')

        matched = False
        for pattern, negated in self.patterns:
            if self._match_pattern(pattern, rel_str, parts):
                # Negation pattern explicitly re-includes the path.
                matched = not negated

        return matched

    def _match_pattern(self, pattern: str, rel_str: str, parts: List[str]) -> bool:
        """Match a single (already normalized) gitignore pattern.

        Note: ``load_patterns`` strips trailing slashes, so the old
        directory-pattern branch keyed on ``pattern.endswith('/')`` was dead
        code and has been removed.
        """
        # Anchored pattern (leading '/'): match the first path component only.
        if pattern.startswith('/'):
            return self._match_glob(pattern[1:], parts[0] if parts else '')

        # '**' matches any number of directories: translate to a regex.
        if '**' in pattern:
            # Use a placeholder so the '*' expansion below does not clobber
            # the '.*' produced for '**' (the old code turned '**' into
            # '.[^/]*', which could not span multiple directories).
            regex_pattern = pattern.replace('**', '\x00')
            regex_pattern = regex_pattern.replace('*', '[^/]*').replace('?', '.')
            regex_pattern = regex_pattern.replace('\x00', '.*')
            try:
                return bool(re.search(regex_pattern, rel_str))
            except re.error:
                # Malformed translation: fall back to a plain glob match.
                return fnmatch.fnmatch(rel_str, pattern)

        # Plain glob pattern: ignored if any path component matches.
        return any(self._match_glob(pattern, part) for part in parts)

    def _match_glob(self, pattern: str, text: str) -> bool:
        """Match a glob pattern against a single path component."""
        return fnmatch.fnmatch(text, pattern)
|
|
405
|
+
|
|
406
|
+
|
|
407
|
+
def find_gitignore(root_path: Path) -> Optional[Path]:
    """Search *root_path* and its ancestors for the nearest ``.gitignore``.

    Walks upward one directory at a time and returns the first
    ``.gitignore`` encountered. The walk stops before the filesystem root
    itself, and ``None`` is returned when nothing is found.
    """
    directory = root_path.resolve()
    while directory.parent != directory:
        candidate = directory / '.gitignore'
        if candidate.exists():
            return candidate
        directory = directory.parent
    return None
|
|
416
|
+
|
|
417
|
+
|
|
418
|
+
def should_process_file(file_path: str, include_dirs: List[str],
                        exclude_dirs: List[str],
                        gitignore_matcher: Optional[GitIgnoreMatcher] = None) -> bool:
    """Decide whether a file passes the gitignore, exclude and include filters.

    A file is rejected when the gitignore matcher ignores it or when any
    exclude entry occurs as a substring of its path. Otherwise it is accepted
    when the include list is empty, or when any include entry occurs as a
    substring of its path.
    """
    path = Path(file_path)
    path_text = str(path)

    # Gitignore takes precedence over everything else (when enabled).
    if gitignore_matcher is not None and gitignore_matcher.matches(path):
        return False

    # Any exclude substring anywhere in the path rejects the file.
    if any(token in path_text for token in exclude_dirs):
        return False

    # An empty include list means "accept everything not excluded".
    return not include_dirs or any(token in path_text for token in include_dirs)
|
|
443
|
+
|
|
444
|
+
|
|
445
|
+
def extract_from_file(file_path: str, excluded_keywords: Set[str] = None) -> Tuple[List[Dict], List[Dict], List[Dict]]:
    """Extract symbols, edges, and types from a source file.

    Note: Currently only supports Python files (uses AST parsing).
    For other languages, additional parsers would need to be implemented.

    Args:
        file_path: Path to the source file
        excluded_keywords: Set of symbol names to exclude from extraction

    Returns:
        A ``(symbols, edges, types)`` triple of dict lists; all three are
        empty when the file cannot be read or parsed (the error is printed).
    """
    try:
        with open(file_path, 'r', encoding='utf-8', errors='ignore') as handle:
            source = handle.read()

        visitor = SymbolExtractor(file_path, excluded_keywords=excluded_keywords)
        visitor.visit(ast.parse(source, filename=file_path))

        # Flatten the aggregated edge-weight mapping into a list of records.
        edges = [
            {"src": src, "dst": dst, "rel": rel, "weight": weight}
            for (src, dst, rel), weight in visitor.edge_weights.items()
        ]

        # Flatten the type-token counters into a list of records.
        types = [
            {"symbol": symbol_id, "type_token": type_token, "count": count}
            for (symbol_id, type_token), count in visitor.type_tokens.items()
        ]

        return visitor.symbols, edges, types
    except Exception as exc:
        # Best-effort extraction: report the failure and keep going.
        print(f"Error processing {file_path}: {exc}")
        return [], [], []
|
|
486
|
+
|
|
487
|
+
|
|
488
|
+
def extract_from_directory(root_dir: str, include_dirs: List[str] = None,
                           exclude_dirs: List[str] = None,
                           extensions: List[str] = None,
                           excluded_keywords: Set[str] = None,
                           output_nodes: str = "nodes.jsonl",
                           output_edges: str = "edges.jsonl",
                           output_types: str = "types.jsonl",
                           use_gitignore: bool = True):
    """Extract symbols from directory recursively.

    Walks ``root_dir``, filters files by extension / include / exclude /
    .gitignore, runs ``extract_from_file`` on each survivor, deduplicates
    symbols by ID, and writes three JSONL output files.

    Args:
        root_dir: Root directory to scan
        include_dirs: List of directories to include (empty = all)
        exclude_dirs: List of directories to exclude
        extensions: List of file extensions to process
        excluded_keywords: Set of symbol names to exclude from extraction
        output_nodes: Output file for nodes
        output_edges: Output file for edges
        output_types: Output file for types
        use_gitignore: Whether to use .gitignore filtering
    """
    # Imported lazily so the extraction helpers stay usable without click.
    import click

    # None defaults are resolved here to avoid mutable default arguments.
    if include_dirs is None:
        include_dirs = []
    if exclude_dirs is None:
        exclude_dirs = ['.venv', '__pycache__', '.git', 'node_modules', '.pytest_cache']
    if extensions is None:
        extensions = ['py']  # Default to Python files
    if excluded_keywords is None:
        excluded_keywords = set()

    root_path = Path(root_dir).resolve()
    all_symbols = []
    all_edges = []
    all_types = []
    # Maps symbol ID -> first symbol record seen; used for cross-file dedup.
    seen_symbols = {}

    # Load gitignore if enabled (nearest .gitignore at or above root_path).
    gitignore_matcher = None
    if use_gitignore:
        gitignore_path = find_gitignore(root_path)
        if gitignore_path:
            gitignore_matcher = GitIgnoreMatcher(gitignore_path, root_path)
            click.echo(f"Using .gitignore from: {gitignore_path}")
        else:
            click.echo("No .gitignore found, skipping gitignore filtering")

    click.echo(f"Scanning directory: {root_dir}")
    click.echo(f"Include dirs: {include_dirs if include_dirs else 'all'}")
    click.echo(f"Exclude dirs: {exclude_dirs}")
    click.echo(f"Extensions: {', '.join(extensions)}")
    if excluded_keywords:
        click.echo(f"Excluded keywords: {len(excluded_keywords)} symbol names")
    click.echo(f"Gitignore: {'enabled' if use_gitignore else 'disabled'}")

    # Normalize extensions: ensure they start with dot
    normalized_extensions = [ext if ext.startswith('.') else f'.{ext}' for ext in extensions]

    source_files = []
    skipped_gitignore = 0

    for root, dirs, files in os.walk(root_path):
        # Filter out excluded directories. Removing entries from `dirs`
        # in-place prevents os.walk from descending into them.
        dirs_to_remove = []
        for d in dirs:
            dir_path = Path(root) / d
            # Check gitignore
            if gitignore_matcher and gitignore_matcher.matches(dir_path):
                dirs_to_remove.append(d)
                skipped_gitignore += 1
                continue
            # Check exclude patterns (substring match on the full path)
            if any(exclude in str(dir_path) for exclude in exclude_dirs):
                dirs_to_remove.append(d)

        # Mutate `dirs` after iteration to avoid skipping entries.
        for d in dirs_to_remove:
            dirs.remove(d)

        for file in files:
            # Check if file has one of the allowed extensions
            file_path_obj = Path(file)
            file_ext = file_path_obj.suffix.lower()
            if file_ext in normalized_extensions:
                file_path = Path(root) / file
                file_path_str = str(file_path)

                # Check gitignore
                if gitignore_matcher and gitignore_matcher.matches(file_path):
                    skipped_gitignore += 1
                    continue

                # Apply include/exclude filters (gitignore re-checked inside).
                if should_process_file(file_path_str, include_dirs, exclude_dirs, gitignore_matcher):
                    source_files.append(file_path_str)

    ext_display = ', '.join(extensions)
    click.echo(f"Found {len(source_files)} file(s) with extensions [{ext_display}] to process")
    if skipped_gitignore > 0:
        click.echo(f"Skipped {skipped_gitignore} files/directories due to .gitignore")

    with click.progressbar(source_files, label='Processing files') as bar:
        for file_path in bar:
            symbols, edges, types = extract_from_file(file_path, excluded_keywords=excluded_keywords)

            # Deduplicate symbols by ID (first occurrence wins); edges and
            # types are appended as-is.
            for symbol in symbols:
                symbol_id = symbol['id']
                if symbol_id not in seen_symbols:
                    seen_symbols[symbol_id] = symbol
                    all_symbols.append(symbol)

            all_edges.extend(edges)
            all_types.extend(types)

    excluded_msg = f" (excluded symbol names: {len(excluded_keywords)} keywords)" if excluded_keywords else ""
    click.echo(
        f"\nExtracted {len(all_symbols)} symbols, {len(all_edges)} edges, {len(all_types)} type tokens{excluded_msg}")

    # Write output files (one JSON object per line).
    click.echo(f"Writing {output_nodes}...")
    with open(output_nodes, 'w') as f:
        for symbol in all_symbols:
            f.write(json.dumps(symbol) + '\n')

    click.echo(f"Writing {output_edges}...")
    with open(output_edges, 'w') as f:
        for edge in all_edges:
            f.write(json.dumps(edge) + '\n')

    click.echo(f"Writing {output_types}...")
    with open(output_types, 'w') as f:
        for type_entry in all_types:
            f.write(json.dumps(type_entry) + '\n')

    click.echo(click.style("✓ Extraction complete!", fg='green'))
|