tricoder 1.2.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tricoder/__about__.py +6 -0
- tricoder/__init__.py +19 -0
- tricoder/calibration.py +276 -0
- tricoder/cli.py +890 -0
- tricoder/context_view.py +228 -0
- tricoder/data_loader.py +144 -0
- tricoder/extract.py +622 -0
- tricoder/fusion.py +203 -0
- tricoder/git_tracker.py +203 -0
- tricoder/gpu_utils.py +414 -0
- tricoder/graph_view.py +583 -0
- tricoder/model.py +476 -0
- tricoder/optimize.py +263 -0
- tricoder/subtoken_utils.py +196 -0
- tricoder/train.py +857 -0
- tricoder/typed_view.py +313 -0
- tricoder-1.2.8.dist-info/METADATA +306 -0
- tricoder-1.2.8.dist-info/RECORD +22 -0
- tricoder-1.2.8.dist-info/WHEEL +4 -0
- tricoder-1.2.8.dist-info/entry_points.txt +3 -0
- tricoder-1.2.8.dist-info/licenses/LICENSE +56 -0
- tricoder-1.2.8.dist-info/licenses/LICENSE_COMMERCIAL.md +68 -0
tricoder/model.py
ADDED
|
@@ -0,0 +1,476 @@
|
|
|
1
|
+
"""SymbolModel: main model class for loading and querying."""
|
|
2
|
+
import json
import os
from typing import Dict, List, Optional, Set

import numpy as np
from annoy import AnnoyIndex
from gensim.models.keyedvectors import KeyedVectors
|
|
9
|
+
|
|
10
|
+
# Default excluded keywords: Python keywords, builtins, and common library names.
# These don't provide value for code intelligence as they're language constructs
# rather than user-defined code patterns.
# Each entry is listed exactly once. Note that a few entries are mixed-case
# ('None', 'True', 'False', 'Exception', 'BaseException'); they only take effect
# where the consumer compares case-insensitively.
DEFAULT_EXCLUDED_KEYWORDS: Set[str] = {
    # Python keywords
    'import', 'from', 'as', 'def', 'class', 'if', 'else', 'elif', 'for', 'while',
    'try', 'except', 'finally', 'with', 'return', 'pass', 'break', 'continue',
    'yield', 'lambda', 'del', 'global', 'nonlocal', 'assert', 'raise', 'and',
    'or', 'not', 'in', 'is', 'None', 'True', 'False',

    # Common builtin functions/types
    'print', 'len', 'str', 'int', 'float', 'bool', 'list', 'dict', 'tuple',
    'set', 'frozenset', 'type', 'isinstance', 'hasattr', 'getattr', 'setattr',
    'delattr', 'dir', 'vars', 'locals', 'globals', 'eval', 'exec', 'compile',
    'open', 'file', 'range', 'enumerate', 'zip', 'map', 'filter', 'reduce',
    'sorted', 'reversed', 'iter', 'next', 'all', 'any', 'sum', 'max', 'min',
    'abs', 'round', 'divmod', 'pow', 'bin', 'hex', 'oct', 'ord', 'chr',
    'repr', 'ascii', 'format', 'hash', 'id', 'slice', 'super', 'property',
    'staticmethod', 'classmethod', 'object', 'Exception', 'BaseException',

    # Common standard library module names
    'os', 'sys', 'json', 're', 'datetime', 'time', 'random', 'math', 'collections',
    'itertools', 'functools', 'operator', 'string', 'textwrap', 'unicodedata',
    'stringprep', 'readline', 'rlcompleter', 'struct', 'codecs', 'types', 'copy',
    'pprint', 'reprlib', 'enum', 'numbers', 'cmath', 'decimal', 'fractions',
    'statistics', 'array', 'bisect', 'heapq', 'weakref', 'gc', 'inspect',
    'site', 'fpectl', 'atexit', 'traceback', 'future', 'importlib', 'pkgutil',
    'modulefinder', 'runpy', 'pickle', 'copyreg', 'shelve', 'marshal', 'dbm',
    'sqlite3', 'zlib', 'gzip', 'bz2', 'lzma', 'zipfile', 'tarfile', 'csv',
    'configparser', 'netrc', 'xdrlib', 'plistlib', 'hashlib', 'hmac', 'secrets',
    'io', 'argparse', 'getopt', 'logging', 'getpass', 'curses', 'platform',
    'errno', 'ctypes', 'threading', 'multiprocessing', 'concurrent', 'subprocess',
    'sched', 'queue', 'select', 'selectors', 'asyncio', 'socket', 'ssl', 'email',
    'urllib', 'http', 'html', 'xml', 'webbrowser', 'tkinter', 'turtle', 'cmd',
    'shlex', 'fileinput', 'linecache', 'shutil', 'tempfile', 'glob', 'fnmatch',
    'macpath', 'pathlib', 'stat', 'filecmp', 'mmap', 'ast', 'symtable', 'symbol',
    'token', 'tokenize', 'keyword', 'parser', 'dis', 'pickletools', 'doctest',
    'unittest', 'test', 'lib2to3', 'typing', 'pydoc', 'distutils', 'ensurepip',
    'venv', 'zipapp', 'faulthandler', 'pdb', 'profile', 'pstats', 'timeit',
    'trace', 'tracemalloc', 'warnings', 'contextlib', 'abc', '__future__',
    'zipimport',

    # Common dunder names (module attributes plus __init__)
    '__init__', '__main__', '__name__', '__file__', '__doc__', '__package__',
    '__builtins__', '__dict__', '__class__', '__module__', '__qualname__',

    # Common variable names that aren't useful
    'self', 'cls', 'args', 'kwargs', 'data', 'result', 'value', 'item',
    'key', 'val', 'obj', 'instance', 'other', 'x', 'y', 'z',
    'i', 'j', 'k', 'n', 'm', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w',
}
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
class SymbolModel:
    """Main model class for TriVector Code Intelligence.

    Holds the artifacts written by training (embeddings, temperature, node
    metadata, ANN index and optional auxiliary maps) and answers two kinds of
    lookups: embedding-similarity queries (``query``) and name/kind/type
    keyword search (``search_by_keywords``). Call ``load`` before querying.
    """

    # Vectors whose L2 norm is at or below this are treated as zero and are
    # left unnormalized to avoid dividing by (almost) zero.
    _NORM_EPS = 1e-10

    def __init__(self):
        # Every attribute below is populated by load().
        self.embeddings = None            # 2-D array, one row per embedded item
        self.tau = None                   # temperature dividing hybrid scores in query()
        self.node_map = None              # node id -> embedding row index
        self.node_metadata = None         # list of metadata dicts, each with an 'id' key
        self.pca_components = None        # fusion PCA components (not used by methods here)
        self.pca_mean = None              # fusion PCA mean (not used by methods here)
        self.svd_components_graph = None  # graph-view SVD components (not used by methods here)
        self.svd_components_types = None  # optional type-view SVD components
        self.word2vec_kv = None           # optional gensim KeyedVectors, memory-mapped read-only
        self.ann_index = None             # Annoy index ('angular') over embedding rows
        self.type_token_map = None        # optional type token -> index map
        self.embedding_dim = None         # number of columns of self.embeddings
        self.idx_to_node = None           # embedding row index -> node id (reverse of node_map)
        self.metadata_lookup = None       # node id -> metadata dict
        self.mean_norm = None             # reference embedding norm for the length penalty
        self.subtoken_to_idx = None       # subtoken -> embedding row index
        self.node_subtokens = None        # node id -> iterable of subtokens
        self.node_types = None            # node id -> {type token: value (presumably a count)}
        self.alpha = 0.05                 # Length penalty coefficient

    def load(self, model_dir: str):
        """
        Load model artifacts from a directory.

        Always loaded: embeddings.npy, tau.npy, metadata.json (keys
        'node_map' and 'node_metadata'), fusion_pca_components.npy,
        fusion_pca_mean.npy, svd_components.npy and ann_index.ann.
        Optional (defaults used when absent): svd_components_types.npy and
        word2vec.kv and type_token_map.json (-> None), mean_norm.npy
        (-> mean row norm of the embeddings, or 1.0 if there are none),
        subtoken_map.json, node_subtokens.json and node_types.json (-> {}).
        Optional artifacts are reset when absent, so re-loading a different
        directory never keeps stale values from a previous load.

        Args:
            model_dir: path to model directory

        Raises:
            FileNotFoundError: if ``model_dir`` is not a directory or
                embeddings.npy is missing. Errors from loading any other
                required artifact propagate from the underlying loader.
        """
        if not os.path.isdir(model_dir):
            raise FileNotFoundError(f"Model directory not found: {model_dir}")

        def artifact(name: str) -> str:
            return os.path.join(model_dir, name)

        embeddings_path = artifact('embeddings.npy')
        if not os.path.exists(embeddings_path):
            raise FileNotFoundError(
                f"Embeddings file not found: {embeddings_path}\n"
                f"This usually means training was interrupted or failed before completion.\n"
                f"Please retrain the model."
            )
        self.embeddings = np.load(embeddings_path)
        self.embedding_dim = self.embeddings.shape[1]

        # Temperature used to calibrate scores in query().
        self.tau = float(np.load(artifact('tau.npy')))

        with open(artifact('metadata.json'), 'r') as f:
            metadata = json.load(f)
        self.node_map = metadata['node_map']
        self.node_metadata = metadata['node_metadata']

        # Reverse mapping so ANN result indices can be turned back into ids.
        self.idx_to_node = {idx: node_id for node_id, idx in self.node_map.items()}
        self.metadata_lookup = {nm['id']: nm for nm in self.node_metadata}

        self.pca_components = np.load(artifact('fusion_pca_components.npy'))
        self.pca_mean = np.load(artifact('fusion_pca_mean.npy'))
        self.svd_components_graph = np.load(artifact('svd_components.npy'))

        types_svd_path = artifact('svd_components_types.npy')
        self.svd_components_types = (
            np.load(types_svd_path) if os.path.exists(types_svd_path) else None
        )

        w2v_path = artifact('word2vec.kv')
        self.word2vec_kv = (
            KeyedVectors.load(w2v_path, mmap='r') if os.path.exists(w2v_path) else None
        )

        self.type_token_map = self._load_json_if_exists(artifact('type_token_map.json'), None)

        mean_norm_path = artifact('mean_norm.npy')
        if os.path.exists(mean_norm_path):
            self.mean_norm = float(np.load(mean_norm_path))
        elif len(self.embeddings) > 0:
            # Not saved by training: derive it from the loaded embedding rows.
            self.mean_norm = float(np.mean(np.linalg.norm(self.embeddings, axis=1)))
        else:
            self.mean_norm = 1.0

        self.subtoken_to_idx = self._load_json_if_exists(artifact('subtoken_map.json'), {})
        self.node_subtokens = self._load_json_if_exists(artifact('node_subtokens.json'), {})
        self.node_types = self._load_json_if_exists(artifact('node_types.json'), {})

        self.ann_index = AnnoyIndex(self.embedding_dim, 'angular')
        self.ann_index.load(artifact('ann_index.ann'))

    @staticmethod
    def _load_json_if_exists(file_path: str, default):
        """Return the parsed JSON at ``file_path``, or ``default`` if the file is missing."""
        if not os.path.exists(file_path):
            return default
        with open(file_path, 'r') as f:
            return json.load(f)

    def expand_query_vector(self, node_id: str) -> Optional[np.ndarray]:
        """
        Expand a symbol's embedding with the mean of its subtoken embeddings.

        The symbol's own vector (weight 1.0) is blended with the average of
        the embedding rows of its known subtokens (weight 0.6). The weights
        are normalized to sum to 1 and the result is scaled to unit length.
        Type tokens are not blended in yet (see note in the body).

        Args:
            node_id: symbol ID to query

        Returns:
            Expanded, L2-normalized query vector, or None if ``node_id`` is
            unknown. A (near-)zero vector is returned unnormalized.
        """
        if node_id not in self.node_map:
            return None

        base_vector = self.embeddings[self.node_map[node_id]]
        vectors = [base_vector]  # symbol vector, weight 1.0
        weights = [1.0]

        subtoken_vectors = []
        if self.node_subtokens and self.subtoken_to_idx:
            n_rows = len(self.embeddings)
            for subtoken in self.node_subtokens.get(node_id, ()):
                subtoken_idx = self.subtoken_to_idx.get(subtoken)
                # Only accept indices that address a real row; a negative
                # index would silently wrap around to the end of the matrix.
                if subtoken_idx is not None and 0 <= subtoken_idx < n_rows:
                    subtoken_vectors.append(self.embeddings[subtoken_idx])
        if subtoken_vectors:
            vectors.append(np.mean(subtoken_vectors, axis=0))
            weights.append(0.6)

        # NOTE: type tokens (self.node_types / self.type_token_map) are not
        # blended in: their indices may refer to a separate embedding space,
        # and type embeddings are not stored with the model.

        if len(vectors) > 1:
            weights_array = np.asarray(weights) / np.sum(weights)
            # Accumulate into a fresh array so the embedding rows are never mutated.
            expanded = np.zeros_like(base_vector)
            for vec, weight in zip(vectors, weights_array):
                expanded += weight * vec
        else:
            expanded = base_vector

        norm = np.linalg.norm(expanded)
        if norm > self._NORM_EPS:
            expanded = expanded / norm
        return expanded

    def compute_hybrid_score(self, query_vec: np.ndarray, candidate_vec: np.ndarray,
                             candidate_norm_before_normalization: Optional[float] = None) -> float:
        """
        Compute hybrid similarity score with length penalty.

            score = dot(query_vec, candidate_vec)
                    - alpha * max(0, candidate_norm_before_normalization - mean_norm)

        The penalty term is omitted when the candidate norm or
        ``self.mean_norm`` is None.

        Args:
            query_vec: query embedding vector (normalized)
            candidate_vec: candidate embedding vector (normalized)
            candidate_norm_before_normalization: norm before normalization (for penalty)

        Returns:
            Hybrid similarity score
        """
        cosine_sim = np.dot(query_vec, candidate_vec)
        if candidate_norm_before_normalization is None or self.mean_norm is None:
            return cosine_sim
        # Only candidates longer than the reference norm are penalized.
        length_penalty = max(0, candidate_norm_before_normalization - self.mean_norm)
        return cosine_sim - self.alpha * length_penalty

    def query(self, node_id: str, top_k: int = 10) -> List[Dict]:
        """
        Query for similar symbols with query expansion and hybrid scoring.

        Args:
            node_id: symbol ID to query
            top_k: number of results to return

        Returns:
            Up to ``top_k`` result dicts with keys 'symbol', 'score' (hybrid
            score divided by tau when tau is set), 'hybrid_score', 'distance'
            (the ANN index's distance) and 'meta', sorted by 'hybrid_score'
            descending. Empty if ``node_id`` is unknown or ``top_k <= 0``.
        """
        if node_id not in self.node_map or top_k <= 0:
            return []

        node_idx = self.node_map[node_id]
        query_vector = self.expand_query_vector(node_id)
        if query_vector is None:
            query_vector = self.embeddings[node_idx]

        # One extra neighbour because the query symbol is usually among them.
        indices, distances = self.ann_index.get_nns_by_vector(
            query_vector, top_k + 1, include_distances=True
        )

        results = []
        for idx, dist in zip(indices, distances):
            if idx == node_idx:  # skip self
                continue
            node_id_result = self.idx_to_node.get(idx)
            if node_id_result is None:
                continue

            # Score against the unit-length candidate and pass its raw norm,
            # as compute_hybrid_score expects, so the length penalty applies.
            raw_vec = self.embeddings[idx]
            raw_norm = float(np.linalg.norm(raw_vec))
            candidate_vec = raw_vec / raw_norm if raw_norm > self._NORM_EPS else raw_vec
            hybrid_score = self.compute_hybrid_score(query_vector, candidate_vec, raw_norm)
            calibrated_score = hybrid_score / self.tau if self.tau else hybrid_score

            results.append({
                'symbol': node_id_result,
                'score': float(calibrated_score),
                'hybrid_score': float(hybrid_score),
                'distance': float(dist),
                'meta': self.metadata_lookup.get(node_id_result),
            })

        # Re-rank every retrieved candidate before truncating, so the hybrid
        # score (not just ANN order) decides which top_k results survive.
        results.sort(key=lambda r: r['hybrid_score'], reverse=True)
        return results[:top_k]

    def search_by_keywords(self, keywords: str, top_k: int = 10,
                           excluded_keywords: Optional[Set[str]] = None) -> List[Dict]:
        """
        Search for symbols by keywords (name, kind and type-token matching).

        Matching is case-insensitive, including against ``excluded_keywords``
        (an entry 'None' also excludes the query word / symbol name 'none').
        Query words found in the exclusion set are dropped; symbols whose
        name is in the exclusion set are skipped.

        Scores (first applicable rule wins):
          1.0   filtered query equals the name
          0.8   name starts with the filtered query
          0.6   name contains the filtered query
          multi-word queries:
          0.7   every word occurs in the name
          0.4-0.6  some word occurs in name/kind and some word in the type
                   tokens; scaled by the fraction of words matched anywhere
          0.3   every word occurs in the kind
          0.35  every word occurs in the type tokens
          single-word queries:
          0.4   word equals the kind
          0.2   word occurs in the kind
          0.3   word occurs in the type tokens

        Args:
            keywords: space-separated keywords or quoted string to search for
            top_k: number of results to return
            excluded_keywords: set of keywords to exclude (defaults to DEFAULT_EXCLUDED_KEYWORDS)

        Returns:
            Up to ``top_k`` matching symbol dicts with keys 'symbol', 'score'
            and 'meta', sorted by score descending (ties keep metadata order).
        """
        if not self.metadata_lookup or top_k <= 0:
            return []

        if excluded_keywords is None:
            excluded_keywords = DEFAULT_EXCLUDED_KEYWORDS
        # Query words and names are lower-cased below, so the exclusion set
        # must be too; otherwise entries like 'None' could never match.
        excluded = {kw.lower() for kw in excluded_keywords}

        words = [w for w in keywords.lower().split() if w not in excluded]
        if not words:
            return []
        phrase = ' '.join(words)

        matches = []
        for node_id, meta in self.metadata_lookup.items():
            if not meta:
                continue
            # `or ''` also guards against explicit nulls in the metadata JSON.
            name = (meta.get('name') or '').lower()
            kind = (meta.get('kind') or '').lower()
            if name in excluded:
                continue

            type_tokens_str = ' '.join(self._get_type_tokens(node_id))
            score = self._keyword_score(words, phrase, name, kind, type_tokens_str)
            if score > 0:
                matches.append({'symbol': node_id, 'score': score, 'meta': meta})

        matches.sort(key=lambda m: m['score'], reverse=True)
        return matches[:top_k]

    def _get_type_tokens(self, node_id: str) -> List[str]:
        """Return a symbol's lower-cased type tokens (empty if unknown).

        For composite types the comma-separated parts inside the outermost
        brackets are added too (e.g. "List[bool]" also yields "bool"), so a
        query like "bool variable" can match.
        """
        if not self.node_types or node_id not in self.node_types:
            return []

        type_tokens = []
        for type_token in self.node_types[node_id]:
            type_tokens.append(type_token.lower())
            start = type_token.find('[')
            end = type_token.rfind(']')
            if 0 <= start < end:
                for part in type_token[start + 1:end].split(','):
                    part = part.strip().lower()
                    if part and part not in type_tokens:
                        type_tokens.append(part)
        return type_tokens

    @staticmethod
    def _keyword_score(words: List[str], phrase: str, name: str, kind: str,
                       type_tokens_str: str) -> float:
        """Relevance of one symbol to the lower-cased, filtered query (0.0 = no match)."""
        if phrase == name:
            return 1.0  # Exact name match
        if name.startswith(phrase):
            return 0.8  # Name starts with keywords
        if phrase in name:
            return 0.6  # Keywords contained in name

        if len(words) > 1:
            if all(word in name for word in words):
                return 0.7  # Every word in the name
            in_name_or_kind = [word in name or word in kind for word in words]
            in_types = [word in type_tokens_str for word in words]
            if any(in_name_or_kind) and any(in_types):
                # Combined name/kind + type match (e.g. "bool variable"). Count
                # each query word once, even if it matches both, so the score
                # stays within 0.4-0.6.
                matched = sum(a or b for a, b in zip(in_name_or_kind, in_types))
                return 0.4 + 0.2 * matched / len(words)
            if all(word in kind for word in words):
                return 0.3
            if all(word in type_tokens_str for word in words):
                return 0.35
            return 0.0

        # Single-word query; the name checks above already covered it.
        word = words[0]
        if word == kind:
            return 0.4
        if word in kind:
            return 0.2
        if word in type_tokens_str:
            return 0.3
        return 0.0
|