tricoder-1.2.8-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
tricoder/model.py ADDED
@@ -0,0 +1,476 @@
+ """SymbolModel: main model class for loading and querying."""
+ import json
+ import os
+ from typing import Dict, List, Optional, Set
+
+ import numpy as np
+ from annoy import AnnoyIndex
+ from gensim.models.keyedvectors import KeyedVectors
+
+ # Default excluded keywords: Python keywords, builtins, and common library names.
+ # These don't provide value for code intelligence because they are language
+ # constructs rather than user-defined code patterns.
+ DEFAULT_EXCLUDED_KEYWORDS: Set[str] = {
+     # Python keywords
+     'import', 'from', 'as', 'def', 'class', 'if', 'else', 'elif', 'for', 'while',
+     'try', 'except', 'finally', 'with', 'return', 'pass', 'break', 'continue',
+     'yield', 'lambda', 'del', 'global', 'nonlocal', 'assert', 'raise', 'and',
+     'or', 'not', 'in', 'is', 'None', 'True', 'False',
+
+     # Common builtin functions/types (plus Python 2 holdovers such as 'file' and 'reduce')
+     'print', 'len', 'str', 'int', 'float', 'bool', 'list', 'dict', 'tuple',
+     'set', 'frozenset', 'type', 'isinstance', 'hasattr', 'getattr', 'setattr',
+     'delattr', 'dir', 'vars', 'locals', 'globals', 'eval', 'exec', 'compile',
+     'open', 'file', 'range', 'enumerate', 'zip', 'map', 'filter', 'reduce',
+     'sorted', 'reversed', 'iter', 'next', 'all', 'any', 'sum', 'max', 'min',
+     'abs', 'round', 'divmod', 'pow', 'bin', 'hex', 'oct', 'ord', 'chr',
+     'repr', 'ascii', 'format', 'hash', 'id', 'slice', 'super', 'property',
+     'staticmethod', 'classmethod', 'object', 'Exception', 'BaseException',
+
+     # Common standard library module names
+     'os', 'sys', 'json', 're', 'datetime', 'time', 'random', 'math', 'collections',
+     'itertools', 'functools', 'operator', 'string', 'textwrap', 'unicodedata',
+     'stringprep', 'readline', 'rlcompleter', 'struct', 'codecs', 'types', 'copy',
+     'pprint', 'reprlib', 'enum', 'numbers', 'cmath', 'decimal', 'fractions',
+     'statistics', 'array', 'bisect', 'heapq', 'weakref', 'gc', 'inspect',
+     'site', 'fpectl', 'atexit', 'traceback', 'future', 'importlib', 'pkgutil',
+     'modulefinder', 'runpy', 'pickle', 'copyreg', 'shelve', 'marshal', 'dbm',
+     'sqlite3', 'zlib', 'gzip', 'bz2', 'lzma', 'zipfile', 'tarfile', 'csv',
+     'configparser', 'netrc', 'xdrlib', 'plistlib', 'hashlib', 'hmac', 'secrets',
+     'io', 'argparse', 'getopt', 'logging', 'getpass', 'curses', 'platform',
+     'errno', 'ctypes', 'threading', 'multiprocessing', 'concurrent', 'subprocess',
+     'sched', 'queue', 'select', 'selectors', 'asyncio', 'socket', 'ssl', 'email',
+     'urllib', 'http', 'html', 'xml', 'webbrowser', 'tkinter', 'turtle', 'cmd',
+     'shlex', 'fileinput', 'linecache', 'shutil', 'tempfile', 'glob', 'fnmatch',
+     'macpath', 'pathlib', 'stat', 'filecmp', 'mmap', 'ast', 'symtable', 'symbol',
+     'token', 'tokenize', 'keyword', 'parser', 'dis', 'pickletools', 'doctest',
+     'unittest', 'test', 'lib2to3', 'typing', 'pydoc', 'distutils', 'ensurepip',
+     'venv', 'zipapp', 'faulthandler', 'pdb', 'profile', 'pstats', 'timeit',
+     'trace', 'tracemalloc', 'warnings', 'contextlib', 'abc', '__future__',
+     'zipimport',
+
+     # Common dunder attributes (excluded even though some may occasionally be useful)
+     '__init__', '__main__', '__name__', '__file__', '__doc__', '__package__',
+     '__builtins__', '__dict__', '__class__', '__module__', '__qualname__',
+
+     # Common variable names that aren't useful
+     'self', 'cls', 'args', 'kwargs', 'data', 'result', 'value', 'item',
+     'key', 'val', 'obj', 'instance', 'other', 'x', 'y', 'z',
+     'i', 'j', 'k', 'n', 'm', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w',
+ }
+
+
+ class SymbolModel:
+     """Main model class for TriVector Code Intelligence."""
+
+     def __init__(self):
+         self.embeddings = None
+         self.tau = None
+         self.node_map = None
+         self.node_metadata = None
+         self.pca_components = None
+         self.pca_mean = None
+         self.svd_components_graph = None
+         self.svd_components_types = None
+         self.word2vec_kv = None
+         self.ann_index = None
+         self.type_token_map = None
+         self.embedding_dim = None
+         self.idx_to_node = None
+         self.metadata_lookup = None
+         self.mean_norm = None
+         self.subtoken_to_idx = None
+         self.node_subtokens = None
+         self.node_types = None
+         self.alpha = 0.05  # Length penalty coefficient
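+         # alpha scales the length penalty term in compute_hybrid_score().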
+
+     def load(self, model_dir: str):
+         """
+         Load model from directory.
+
+         Args:
+             model_dir: path to model directory
+         """
+         # Check if model directory exists
+         if not os.path.exists(model_dir):
+             raise FileNotFoundError(f"Model directory not found: {model_dir}")
+
+         # Load embeddings
+         embeddings_path = os.path.join(model_dir, 'embeddings.npy')
+         if not os.path.exists(embeddings_path):
+             raise FileNotFoundError(
+                 f"Embeddings file not found: {embeddings_path}\n"
+                 f"This usually means training was interrupted or failed before completion.\n"
+                 f"Please retrain the model."
+             )
+         self.embeddings = np.load(embeddings_path)
+         self.embedding_dim = self.embeddings.shape[1]
+
+         # Load temperature
+         self.tau = float(np.load(os.path.join(model_dir, 'tau.npy')))
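+         # tau acts as a temperature divisor when calibrating scores in query().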
+
+         # Load metadata
+         with open(os.path.join(model_dir, 'metadata.json'), 'r') as f:
+             metadata = json.load(f)
+         self.node_map = metadata['node_map']
+         self.node_metadata = metadata['node_metadata']
+
+         # Create reverse mapping for efficient lookup
+         self.idx_to_node = {idx: node_id for node_id, idx in self.node_map.items()}
+
+         # Create metadata lookup
+         self.metadata_lookup = {nm['id']: nm for nm in self.node_metadata}
+
+         # Load PCA components
+         self.pca_components = np.load(os.path.join(model_dir, 'fusion_pca_components.npy'))
+         self.pca_mean = np.load(os.path.join(model_dir, 'fusion_pca_mean.npy'))
+
+         # Load SVD components
+         self.svd_components_graph = np.load(os.path.join(model_dir, 'svd_components.npy'))
+
+         # Load type SVD components if available
+         types_svd_path = os.path.join(model_dir, 'svd_components_types.npy')
+         if os.path.exists(types_svd_path):
+             self.svd_components_types = np.load(types_svd_path)
+
+         # Load Word2Vec
+         w2v_path = os.path.join(model_dir, 'word2vec.kv')
+         if os.path.exists(w2v_path):
+             self.word2vec_kv = KeyedVectors.load(w2v_path, mmap='r')
+
+         # Load type token map if available
+         type_map_path = os.path.join(model_dir, 'type_token_map.json')
+         if os.path.exists(type_map_path):
+             with open(type_map_path, 'r') as f:
+                 self.type_token_map = json.load(f)
+
+         # Load mean_norm for length penalty
+         mean_norm_path = os.path.join(model_dir, 'mean_norm.npy')
+         if os.path.exists(mean_norm_path):
+             self.mean_norm = float(np.load(mean_norm_path))
+         else:
+             # Compute from embeddings if not saved
+             if self.embeddings is not None:
+                 self.mean_norm = float(np.mean(np.linalg.norm(self.embeddings, axis=1)))
+             else:
+                 self.mean_norm = 1.0
+
+         # Load subtoken mapping if available
+         subtoken_map_path = os.path.join(model_dir, 'subtoken_map.json')
+         if os.path.exists(subtoken_map_path):
+             with open(subtoken_map_path, 'r') as f:
+                 self.subtoken_to_idx = json.load(f)
+         else:
+             self.subtoken_to_idx = {}
+
+         # Load node subtokens if available
+         node_subtokens_path = os.path.join(model_dir, 'node_subtokens.json')
+         if os.path.exists(node_subtokens_path):
+             with open(node_subtokens_path, 'r') as f:
+                 self.node_subtokens = json.load(f)
+         else:
+             self.node_subtokens = {}
+
+         # Load node types if available
+         node_types_path = os.path.join(model_dir, 'node_types.json')
+         if os.path.exists(node_types_path):
+             with open(node_types_path, 'r') as f:
+                 self.node_types = json.load(f)
+         else:
+             self.node_types = {}
+
+         # Load ANN index
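+         # Annoy requires the index to be opened with the same dimensionality
+         # and metric ('angular') that were used when it was built.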
+         self.ann_index = AnnoyIndex(self.embedding_dim, 'angular')
+         self.ann_index.load(os.path.join(model_dir, 'ann_index.ann'))
+
+     def expand_query_vector(self, node_id: str) -> Optional[np.ndarray]:
+         """
+         Expand query vector with subtokens and type tokens.
+
+         Args:
+             node_id: symbol ID to query
+
+         Returns:
+             Expanded and normalized query vector, or None if node_id is unknown
+         """
+         if node_id not in self.node_map:
+             return None
+
+         node_idx = self.node_map[node_id]
+         base_vector = self.embeddings[node_idx]
+
+         vectors_to_average = [base_vector]  # Start with symbol vector (weight 1.0)
+         weights = [1.0]
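+         # The expanded query is a weight-normalized average of the contributing
+         # vectors; with subtokens present:
+         #   q = (1.0 * v_symbol + 0.6 * mean(v_subtokens)) / (1.0 + 0.6)
+         # followed by L2 normalization below.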
+
+         # Add subtoken vectors (weight 0.6)
+         if self.node_subtokens and node_id in self.node_subtokens:
+             subtoken_vectors = []
+             for subtoken in self.node_subtokens[node_id]:
+                 if self.subtoken_to_idx and subtoken in self.subtoken_to_idx:
+                     subtoken_idx = self.subtoken_to_idx[subtoken]
+                     if subtoken_idx < len(self.embeddings):
+                         subtoken_vectors.append(self.embeddings[subtoken_idx])
+
+             if subtoken_vectors:
+                 subtoken_avg = np.mean(subtoken_vectors, axis=0)
+                 vectors_to_average.append(subtoken_avg)
+                 weights.append(0.6)
+
+         # Type token vectors (weight 0.4) are not added yet: type tokens live
+         # in a separate embedding space, and query-side type expansion would
+         # require storing type embeddings alongside the symbol embeddings.
+         # Until then, type expansion is skipped here.
+
+         # Weighted average
+         if len(vectors_to_average) > 1:
+             weights_array = np.array(weights)
+             weights_array = weights_array / weights_array.sum()  # Normalize weights
+
+             expanded = np.zeros_like(base_vector)
+             for vec, weight in zip(vectors_to_average, weights_array):
+                 expanded += weight * vec
+         else:
+             expanded = base_vector
+
+         # Normalize to unit length (guarding against a zero vector)
+         norm = np.linalg.norm(expanded)
+         if norm > 1e-10:
+             expanded = expanded / norm
+
+         return expanded
+
+     def compute_hybrid_score(self, query_vec: np.ndarray, candidate_vec: np.ndarray,
+                              candidate_norm_before_normalization: Optional[float] = None) -> float:
+         """
+         Compute hybrid similarity score with length penalty.
+
+         Args:
+             query_vec: query embedding vector (normalized)
+             candidate_vec: candidate embedding vector (normalized)
+             candidate_norm_before_normalization: norm before normalization (for penalty)
+
+         Returns:
+             Hybrid similarity score
+         """
+         # Cosine similarity
+         cosine_sim = np.dot(query_vec, candidate_vec)
+
+         # Length penalty
+         if candidate_norm_before_normalization is not None and self.mean_norm is not None:
+             length_penalty = max(0, candidate_norm_before_normalization - self.mean_norm)
+             score = cosine_sim - self.alpha * length_penalty
+         else:
+             score = cosine_sim
+
+         return score
+
+     def query(self, node_id: str, top_k: int = 10) -> List[Dict]:
+         """
+         Query for similar symbols with query expansion and hybrid scoring.
+
+         Args:
+             node_id: symbol ID to query
+             top_k: number of results to return
+
+         Returns:
+             List of result dictionaries with symbol, score, hybrid_score, distance, meta
+         """
+         if node_id not in self.node_map:
+             return []
+
+         node_idx = self.node_map[node_id]
+
+         # Expand query vector
+         query_vector = self.expand_query_vector(node_id)
+         if query_vector is None:
+             query_vector = self.embeddings[node_idx]
+
+         # ANN search
+         indices, distances = self.ann_index.get_nns_by_vector(
+             query_vector, top_k + 1, include_distances=True
+         )
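+         # Annoy's 'angular' metric returns sqrt(2 * (1 - cos_sim)): 0 means an
+         # identical direction, ~2 means opposite. We request top_k + 1 neighbors
+         # so that the query symbol itself can be dropped below.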
+
+         results = []
+         for idx, dist in zip(indices, distances):
+             # Skip self
+             if idx == node_idx:
+                 continue
+
+             # Find node_id for this index
+             node_id_result = self.idx_to_node.get(idx)
+             if node_id_result is None:
+                 continue
+
+             # Compute hybrid score
+             candidate_vec = self.embeddings[idx]
+             hybrid_score = self.compute_hybrid_score(query_vector, candidate_vec)
+
+             # Compute calibrated score (for probability)
+             calibrated_score = hybrid_score / self.tau if self.tau else hybrid_score
+
+             # Get metadata
+             meta = self.metadata_lookup.get(node_id_result)
+
+             results.append({
+                 'symbol': node_id_result,
+                 'score': float(calibrated_score),
+                 'hybrid_score': float(hybrid_score),
+                 'distance': float(dist),
+                 'meta': meta
+             })
+
+             if len(results) >= top_k:
+                 break
+
+         # Sort by hybrid score (descending)
+         results.sort(key=lambda x: x['hybrid_score'], reverse=True)
+
+         return results
+
+     def search_by_keywords(self, keywords: str, top_k: int = 10,
+                            excluded_keywords: Optional[Set[str]] = None) -> List[Dict]:
+         """
+         Search for symbols by keywords (name matching).
+
+         Args:
+             keywords: space-separated keywords or quoted string to search for
+             top_k: number of results to return
+             excluded_keywords: set of keywords to exclude (defaults to DEFAULT_EXCLUDED_KEYWORDS)
+
+         Returns:
+             List of matching symbol dictionaries with symbol, score, meta
+         """
+         if not self.metadata_lookup:
+             return []
+
+         # Use default excluded keywords if not provided
+         if excluded_keywords is None:
+             excluded_keywords = DEFAULT_EXCLUDED_KEYWORDS
+
+         # Normalize keywords (case-insensitive)
+         keywords_lower = keywords.lower().strip()
+         keyword_words = keywords_lower.split()
+
+         # Filter out excluded keywords from the search query
+         filtered_keyword_words = [w for w in keyword_words if w not in excluded_keywords]
+
+         # If all keywords were filtered out, return empty results
+         if not filtered_keyword_words:
+             return []
+
+         # Rebuild the keywords string from the filtered words
+         filtered_keywords_lower = ' '.join(filtered_keyword_words)
+
+         # Helper: collect type tokens for a symbol (if available)
+         def get_type_tokens(node_id: str) -> List[str]:
+             """Get all type tokens for a symbol, including expanded primitives."""
+             if not self.node_types or node_id not in self.node_types:
+                 return []
+
+             type_tokens = []
+             for type_token in self.node_types[node_id]:
+                 # Add the full type token
+                 type_tokens.append(type_token.lower())
+
+                 # Also parse composite types to extract primitives (e.g., "List[bool]" -> ["bool"]).
+                 # This allows matching "bool" when searching for "bool variable".
+                 if '[' in type_token and ']' in type_token:
+                     # Extract content between the outermost brackets
+                     start = type_token.find('[')
+                     end = type_token.rfind(']')
+                     if start < end:
+                         inner = type_token[start + 1:end].strip()
+                         # Split by comma and add the individual types
+                         for part in inner.split(','):
+                             part = part.strip().lower()
+                             if part and part not in type_tokens:
+                                 type_tokens.append(part)
+
+             return type_tokens
+
+         # Find matching symbols
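+         # Scoring tiers: exact name match 1.0; name prefix 0.8; all words in
+         # name 0.7; keywords contained in name 0.6; combined name/kind + type
+         # matches 0.4-0.6; kind- or type-only matches 0.2-0.35.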
+         matches = []
+         for node_id, meta in self.metadata_lookup.items():
+             if not meta:
+                 continue
+
+             name = meta.get('name', '').lower()
+             kind = meta.get('kind', '').lower()
+
+             # Skip symbols whose names are excluded keywords (they aren't useful)
+             if name in excluded_keywords:
+                 continue
+
+             # Get type tokens for this symbol
+             type_tokens = get_type_tokens(node_id)
+             type_tokens_str = ' '.join(type_tokens)  # Combined string for matching
+
+             # Calculate a simple relevance score
+             score = 0.0
+
+             # Check exact phrase match first (highest priority)
+             if filtered_keywords_lower == name:
+                 score = 1.0  # Exact name match
+             elif name.startswith(filtered_keywords_lower):
+                 score = 0.8  # Name starts with keywords
+             elif filtered_keywords_lower in name:
+                 score = 0.6  # Keywords contained in name
+             # For multi-word queries, check if all words appear in name, kind, or types
+             elif len(filtered_keyword_words) > 1:
+                 if all(word in name for word in filtered_keyword_words):
+                     score = 0.7  # All words appear in the name
+                 # Otherwise try matching words against name/kind + type (e.g., "bool variable")
+                 else:
+                     name_kind_matches = sum(1 for word in filtered_keyword_words if word in name or word in kind)
+                     type_matches = sum(1 for word in filtered_keyword_words if word in type_tokens_str)
+
+                     if name_kind_matches > 0 and type_matches > 0:
+                         # Combined match: name/kind + type (e.g., "bool variable")
+                         score = 0.4 + (0.2 * (name_kind_matches + type_matches) / len(filtered_keyword_words))
+                     elif all(word in kind for word in filtered_keyword_words):
+                         score = 0.3  # All words in kind
+                     elif all(word in type_tokens_str for word in filtered_keyword_words):
+                         score = 0.35  # All words in types
+             # Single-word queries
+             elif len(filtered_keyword_words) == 1:
+                 word = filtered_keyword_words[0]
+                 if word == name:
+                     score = 1.0  # Exact name match
+                 elif name.startswith(word):
+                     score = 0.8  # Name starts with word
+                 elif word in name:
+                     score = 0.6  # Word contained in name
+                 elif word == kind:
+                     score = 0.4  # Kind match
+                 elif word in kind:
+                     score = 0.2  # Word in kind
+                 elif word in type_tokens_str:
+                     score = 0.3  # Word in type tokens
+
+             if score > 0:
+                 matches.append({
+                     'symbol': node_id,
+                     'score': score,
+                     'meta': meta
+                 })
+
+         # Sort by relevance score (descending)
+         matches.sort(key=lambda x: x['score'], reverse=True)
+
+         return matches[:top_k]
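
A minimal usage sketch of the API above (the 'model/' directory and the symbol ID are hypothetical placeholders; the directory must contain the artifacts load() expects, e.g. embeddings.npy, tau.npy, metadata.json, ann_index.ann):

    from tricoder.model import SymbolModel

    model = SymbolModel()
    model.load('model/')

    # Nearest neighbors of a known symbol (ID format depends on the trained index)
    for hit in model.query('my_module.my_function', top_k=5):
        print(hit['symbol'], round(hit['hybrid_score'], 3), round(hit['distance'], 3))

    # Keyword search over symbol names, kinds, and parsed type tokens
    for hit in model.search_by_keywords('bool variable', top_k=5):
        print(hit['symbol'], hit['score'])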