tarang 4.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,464 @@
1
+ """
2
+ Symbol Graph - Lightweight Code Knowledge Graph.
3
+
4
+ A labeled property graph storing relationships between code symbols:
5
+ - Functions, classes, methods (Python/JS/TS)
6
+ - Tables, views, procedures, triggers, indexes (SQL)
7
+ - Calls, imports, inheritance, references relationships
8
+
9
+ Enables graph-augmented retrieval by expanding BM25 results
10
+ to include connected symbols.
11
+ """
12
+ from __future__ import annotations
13
+
14
+ import json
15
+ from dataclasses import dataclass, field
16
+ from pathlib import Path
17
+ from typing import Dict, List, Optional, Set
18
+
19
+ from .chunker import SymbolInfo
20
+
21
+
22
@dataclass
class SymbolNode:
    """One vertex of the symbol graph, describing a single code symbol.

    The ``id`` is the graph key and is deliberately left out of the
    serialized payload; it is supplied separately when deserializing.
    """

    # Graph key, e.g. "file.py:function_name" or "file.sql:table:users"
    id: str
    # "function" | "method" | "class" | "module" |
    # "table" | "view" | "procedure" | "trigger" | "index"
    type: str
    # File path
    file: str
    # Symbol name
    name: str
    # Function/class/table signature
    signature: str
    # Definition line
    line: int

    # Fields written by to_dict() / read by from_dict(), in output order.
    # Unannotated on purpose so the dataclass machinery ignores it.
    _PAYLOAD_KEYS = ("type", "file", "name", "signature", "line")

    def to_dict(self) -> Dict:
        """Return the node's fields (everything except ``id``) as a dict."""
        return {key: getattr(self, key) for key in self._PAYLOAD_KEYS}

    @classmethod
    def from_dict(cls, id: str, data: Dict) -> "SymbolNode":
        """Build a node from ``id`` plus a dict produced by :meth:`to_dict`.

        Raises:
            KeyError: if ``data`` lacks any of the serialized fields.
        """
        payload = {key: data[key] for key in cls._PAYLOAD_KEYS}
        return cls(id=id, **payload)
52
+
53
+
54
@dataclass
class SymbolEdges:
    """Adjacency record for one symbol in the graph.

    Each list holds the ids (or, for ``imports``, module names) on the
    other end of one relationship type. Forward and reverse directions are
    stored separately (e.g. ``calls`` / ``called_by``).
    """

    calls: List[str] = field(default_factory=list)  # Functions this calls
    called_by: List[str] = field(default_factory=list)  # Functions that call this
    imports: List[str] = field(default_factory=list)  # Modules imported
    imported_by: List[str] = field(default_factory=list)  # Files that import this
    inherits: List[str] = field(default_factory=list)  # Parent classes
    inherited_by: List[str] = field(default_factory=list)  # Child classes
    defines: List[str] = field(default_factory=list)  # Symbols defined in this scope
    defined_in: Optional[str] = None  # Parent scope
    # SQL relationships
    references: List[str] = field(default_factory=list)  # Tables referenced by views/funcs
    referenced_by: List[str] = field(default_factory=list)  # Views/funcs that reference this

    # Serialization order; matches the field declaration order above.
    # Unannotated on purpose so the dataclass machinery ignores it.
    _KEYS = (
        "calls",
        "called_by",
        "imports",
        "imported_by",
        "inherits",
        "inherited_by",
        "defines",
        "defined_in",
        "references",
        "referenced_by",
    )

    def to_dict(self) -> Dict:
        """Return a compact dict containing only the non-empty relationships.

        List values are copies, so mutating the returned dict never alters
        this object's edge lists.
        """
        result: Dict = {}
        for key in self._KEYS:
            value = getattr(self, key)
            if not value:
                continue
            # defined_in is a plain string; everything else is a list to copy.
            result[key] = value if key == "defined_in" else list(value)
        return result

    @classmethod
    def from_dict(cls, data: Dict) -> "SymbolEdges":
        """Build an edge record from a dict produced by :meth:`to_dict`.

        Missing keys (and explicit ``None`` lists) become empty lists. The
        lists are copied so the new object does not share state with
        ``data``.
        """
        kwargs: Dict = {}
        for key in cls._KEYS:
            if key == "defined_in":
                kwargs[key] = data.get(key)
            else:
                kwargs[key] = list(data.get(key) or [])
        return cls(**kwargs)
107
+
108
+
109
class SymbolGraph:
    """
    Lightweight Code Knowledge Graph.

    Stores symbols and their relationships as an adjacency list.
    Supports graph traversal for context expansion.
    """

    # edge_type accepted by _add_edge -> (forward attribute, reverse attribute)
    _EDGE_PAIRS = {
        "calls": ("calls", "called_by"),
        "inherits": ("inherits", "inherited_by"),
        "references": ("references", "referenced_by"),
    }

    # Relationship attributes followed by get_neighbors, in traversal order.
    _NEIGHBOR_KINDS = (
        "calls",
        "called_by",
        "inherits",
        "inherited_by",
        "defines",
        "defined_in",
        "references",
        "referenced_by",
    )

    # Relationship attributes reported by get_graph_context, in output order.
    _CONTEXT_KINDS = ("calls", "called_by", "inherits", "references", "referenced_by")

    def __init__(self):
        """Create an empty graph."""
        self._nodes: Dict[str, SymbolNode] = {}
        self._edges: Dict[str, SymbolEdges] = {}
        # Reverse index: name -> [symbol_ids]
        self._name_index: Dict[str, List[str]] = {}

    @property
    def is_empty(self) -> bool:
        """True when the graph holds no symbol nodes."""
        return not self._nodes

    def add_symbol(self, info: SymbolInfo) -> None:
        """
        Add a symbol to the graph (or overwrite the node with the same id).

        Call, inheritance and reference targets are resolved through the
        name index, so only symbols that were already added can be linked.
        Likewise a parent scope's ``defines`` list is only updated when the
        parent already has an edge record.

        Args:
            info: Symbol information from chunker
        """
        self._nodes[info.id] = SymbolNode(
            id=info.id,
            type=info.type,
            file=info.file,
            name=info.name,
            signature=info.signature,
            line=info.line,
        )

        # Keep any edge record that already exists for this id.
        edges = self._edges.get(info.id)
        if edges is None:
            edges = self._edges[info.id] = SymbolEdges()

        # Update name index
        ids_for_name = self._name_index.setdefault(info.name, [])
        if info.id not in ids_for_name:
            ids_for_name.append(info.id)

        # Process calls
        for call_name in info.calls:
            for target_id in self._resolve_call(call_name, info.file):
                self._add_edge(info.id, target_id, "calls")

        # Process imports (stored as inheritance for classes, references for SQL)
        if info.type == "class":
            for parent in info.imports:
                for parent_id in self._resolve_call(parent, info.file):
                    self._add_edge(info.id, parent_id, "inherits")
        elif info.type == "module":
            for module in info.imports:
                # De-duplicate like every other edge type, so adding the
                # same module symbol twice does not grow the list.
                if module not in edges.imports:
                    edges.imports.append(module)
        elif info.type in ("view", "procedure", "function", "trigger", "index"):
            # SQL: views/procedures/triggers/indexes reference tables
            for table_ref in info.imports:
                for target_id in self._resolve_call(table_ref, info.file):
                    self._add_edge(info.id, target_id, "references")

        # Process parent class relationship
        if info.parent_class:
            parent_id = f"{info.file}:{info.parent_class}"
            edges.defined_in = parent_id
            parent_edges = self._edges.get(parent_id)
            if parent_edges is not None and info.id not in parent_edges.defines:
                parent_edges.defines.append(info.id)

    def _resolve_call(self, call_name: str, current_file: str) -> List[str]:
        """
        Resolve a function call name to symbol IDs.

        Strategy:
        1. Look in name index for matching symbols
        2. Prefer symbols in same file
        3. Fall back to any matching symbol
        """
        candidates = self._name_index.get(call_name)
        if candidates is None:
            return []

        # Prefer same file
        prefix = current_file + ":"
        same_file = [c for c in candidates if c.startswith(prefix)]
        return same_file if same_file else candidates

    def _add_edge(self, source: str, target: str, edge_type: str) -> None:
        """Add an edge (and its reverse) between two symbols.

        ``edge_type`` is one of "calls", "inherits" or "references"; any
        other value only ensures both edge records exist.
        """
        # Ensure edges exist for both
        if source not in self._edges:
            self._edges[source] = SymbolEdges()
        if target not in self._edges:
            self._edges[target] = SymbolEdges()

        pair = self._EDGE_PAIRS.get(edge_type)
        if pair is None:
            return
        forward_attr, reverse_attr = pair

        forward = getattr(self._edges[source], forward_attr)
        if target not in forward:
            forward.append(target)
        reverse = getattr(self._edges[target], reverse_attr)
        if source not in reverse:
            reverse.append(source)

    def remove_file(self, file_path: str) -> None:
        """Remove all symbols whose id starts with ``file_path + ":"``."""
        prefix = file_path + ":"
        to_remove = [sid for sid in self._nodes if sid.startswith(prefix)]
        if not to_remove:
            return
        removed = set(to_remove)

        for sid in to_remove:
            node = self._nodes.pop(sid)

            # Remove from name index
            ids_for_name = self._name_index.get(node.name)
            if ids_for_name is not None:
                remaining = [s for s in ids_for_name if s != sid]
                if remaining:
                    self._name_index[node.name] = remaining
                else:
                    del self._name_index[node.name]

            # Remove edges
            self._edges.pop(sid, None)

        # Remove references from the surviving edge records in one pass
        # (instead of rescanning every record once per removed symbol).
        for edges in self._edges.values():
            edges.calls = [c for c in edges.calls if c not in removed]
            edges.called_by = [c for c in edges.called_by if c not in removed]
            edges.inherits = [c for c in edges.inherits if c not in removed]
            edges.inherited_by = [c for c in edges.inherited_by if c not in removed]
            edges.defines = [c for c in edges.defines if c not in removed]
            edges.references = [c for c in edges.references if c not in removed]
            edges.referenced_by = [c for c in edges.referenced_by if c not in removed]
            if edges.defined_in in removed:
                edges.defined_in = None

    def get_node(self, symbol_id: str) -> Optional[SymbolNode]:
        """Get a symbol node by ID."""
        return self._nodes.get(symbol_id)

    def get_edges(self, symbol_id: str) -> Optional[SymbolEdges]:
        """Get edges for a symbol."""
        return self._edges.get(symbol_id)

    def get_signature(self, symbol_id: str) -> Optional[str]:
        """Get just the signature for a symbol."""
        node = self._nodes.get(symbol_id)
        return node.signature if node else None

    def get_neighbors(
        self,
        symbol_id: str,
        hops: int = 1,
        edge_types: Optional[List[str]] = None
    ) -> List[SymbolNode]:
        """
        Get symbols within N hops of a symbol.

        Args:
            symbol_id: Starting symbol
            hops: Number of hops (1 = direct connections, 2 = 2 levels)
            edge_types: Edge types to follow (None = all)

        Returns:
            List of connected SymbolNodes in breadth-first discovery order
            (the starting symbol itself is never included).
        """
        if hops < 1 or symbol_id not in self._edges:
            return []

        if edge_types is None:
            followed = self._NEIGHBOR_KINDS
        else:
            followed = tuple(k for k in self._NEIGHBOR_KINDS if k in edge_types)

        visited: Set[str] = {symbol_id}
        # Levels are lists, not sets: iterating a set of strings depends on
        # hash randomization and would make the result order vary per run.
        current_level: List[str] = [symbol_id]
        result: List[SymbolNode] = []

        for _ in range(hops):
            next_level: List[str] = []

            for sid in current_level:
                edges = self._edges.get(sid)
                if edges is None:
                    continue

                for kind in followed:
                    linked = getattr(edges, kind)
                    if kind == "defined_in":
                        # Single optional id rather than a list.
                        linked = [linked] if linked else []

                    for neighbor in linked:
                        if neighbor in visited or neighbor not in self._nodes:
                            continue
                        visited.add(neighbor)
                        next_level.append(neighbor)
                        result.append(self._nodes[neighbor])

            if not next_level:
                break
            current_level = next_level

        return result

    def get_callers(self, symbol_id: str) -> List[SymbolNode]:
        """Get all symbols that call this symbol."""
        edges = self._edges.get(symbol_id)
        if not edges:
            return []
        return [self._nodes[sid] for sid in edges.called_by if sid in self._nodes]

    def get_callees(self, symbol_id: str) -> List[SymbolNode]:
        """Get all symbols that this symbol calls."""
        edges = self._edges.get(symbol_id)
        if not edges:
            return []
        return [self._nodes[sid] for sid in edges.calls if sid in self._nodes]

    def save(self, path: Path) -> None:
        """Save graph (nodes and edges) to a JSON file at ``path``."""
        data = {
            "nodes": {sid: node.to_dict() for sid, node in self._nodes.items()},
            "edges": {sid: edges.to_dict() for sid, edges in self._edges.items()},
        }

        with open(path, "w", encoding="utf-8") as f:
            json.dump(data, f, indent=2)

    def load(self, path: Path) -> bool:
        """Load graph from JSON file.

        Returns:
            True when the file was read and the graph replaced; False when
            the file is missing or could not be parsed, in which case the
            current graph contents are left untouched.
        """
        if not path.exists():
            return False

        try:
            with open(path, "r", encoding="utf-8") as f:
                data = json.load(f)

            nodes = {
                sid: SymbolNode.from_dict(sid, node_data)
                for sid, node_data in data.get("nodes", {}).items()
            }
            edges = {
                sid: SymbolEdges.from_dict(edge_data)
                for sid, edge_data in data.get("edges", {}).items()
            }

            # Rebuild name index
            name_index: Dict[str, List[str]] = {}
            for sid, node in nodes.items():
                name_index.setdefault(node.name, []).append(sid)
        except Exception:
            # Deliberately broad: the contract is "return False on any
            # unreadable or malformed file" rather than raising.
            return False

        # Commit only after everything parsed, so a failure halfway through
        # never leaves nodes, edges and the name index out of sync.
        self._nodes = nodes
        self._edges = edges
        self._name_index = name_index
        return True

    def stats(self) -> Dict:
        """Get graph statistics.

        ``total_edges`` counts forward relationships only (calls, inherits,
        defines, references).
        """
        if not self._nodes:
            return {
                "total_symbols": 0,
                "total_edges": 0,
                "symbol_types": {},
            }

        types: Dict[str, int] = {}
        for node in self._nodes.values():
            types[node.type] = types.get(node.type, 0) + 1

        total_edges = sum(
            len(e.calls) + len(e.inherits) + len(e.defines) + len(e.references)
            for e in self._edges.values()
        )

        return {
            "total_symbols": len(self._nodes),
            "total_edges": total_edges,
            "symbol_types": types,
        }

    def get_graph_context(self, symbol_ids: List[str]) -> Dict:
        """
        Get a summary of graph relationships for symbols.

        Useful for including in LLM context. Maps each symbol id that has
        at least one resolvable relationship to a dict of relationship
        name -> list of symbol names.
        """
        context: Dict = {}

        for sid in symbol_ids:
            edges = self._edges.get(sid)
            if not edges:
                continue

            summary: Dict[str, List[str]] = {}
            for kind in self._CONTEXT_KINDS:
                # Get human-readable names, skipping ids with no node
                names = [
                    self._nodes[target].name
                    for target in getattr(edges, kind)
                    if target in self._nodes
                ]
                if names:
                    summary[kind] = names

            if summary:
                context[sid] = summary

        return context