codegraph-gen 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,415 @@
1
+ import logging
2
+ from pathlib import Path
3
+ import tree_sitter
4
+ import tree_sitter_python
5
+ from codegraph_gen.parser.base import BaseParser, ExtractionResult, NodeSchema, EdgeSchema
6
+
7
# Module logger: receives file-read errors and notices about skipped syntax-error nodes.
logger = logging.getLogger(__name__)
8
+
9
+
10
class PythonParser(BaseParser):
    """Extract a code graph (nodes + edges) from Python source using tree-sitter.

    Node IDs:
      * the file itself: its workspace-relative path (``rel_path``);
      * a top-level class or function: ``"{rel_path}::{name}"``;
      * anything defined inside a class or function (methods, nested classes,
        nested functions): ``"{enclosing_id}.{name}"``.

    Edges emitted: ``contains`` (scope -> definition), ``inherits``
    (class -> base-class source text), ``imports`` (file -> module text, with an
    ``import_map`` of bound name -> imported name) and ``calls`` (enclosing
    scope -> callee source text). ``inherits`` and ``calls`` targets are raw
    source text; resolving them to node IDs is left to a later stage.
    """

    def __init__(self):
        self.language = tree_sitter.Language(tree_sitter_python.language())
        self.parser = tree_sitter.Parser(self.language)

    # ------------------------------------------------------------------
    # Small helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _text(node, source: bytes) -> str:
        """Return the source text covered by ``node``, replacing undecodable bytes."""
        return source[node.start_byte : node.end_byte].decode("utf-8", errors="replace")

    @staticmethod
    def _string_literal_contents(literal: str) -> str:
        """Return the characters between the delimiters of one string literal.

        Removes any prefix (``r``, ``b``, ``u``, ``f``, ... in either case) and
        one pair of single or triple quotes, e.g. ``r'''doc'''`` -> ``doc``.
        Quote characters that are part of the contents are preserved.
        """
        body = literal.lstrip("rRbBuUfFtT")
        # Triple quotes must be tried before single ones.
        for quote in ('"""', "'''", '"', "'"):
            if (
                len(body) >= 2 * len(quote)
                and body.startswith(quote)
                and body.endswith(quote)
            ):
                return body[len(quote) : -len(quote)]
        return body

    @staticmethod
    def _qualified_id(rel_path: str, parent_id: str, parent_type: str, name: str) -> str:
        """Build the node ID of a definition called ``name`` inside ``parent_id``.

        File-level definitions get ``"{rel_path}::{name}"``. Definitions nested in
        a class or function get ``"{parent_id}.{name}"``, so that e.g. two
        classes' inner ``Meta`` classes do not share one ID.
        """
        if parent_type == "file":
            return f"{rel_path}::{name}"
        return f"{parent_id}.{name}"

    def _get_docstring(self, node, source: bytes) -> str:
        """Return the docstring of a class, function or module node, or ``""``.

        The docstring is the string literal forming the first statement of the
        body; comments before it are skipped. A module node has no ``body``
        field, so its own children are scanned. Prefix and quote delimiters are
        removed, implicitly concatenated literals are joined, and surrounding
        whitespace is stripped.
        """
        body = node.child_by_field_name("body")
        if body is None:
            # Modules have no "body" field: the root node holds the statements.
            body = node

        for child in body.children:
            if child.type == "comment":
                continue
            if child.type != "expression_statement":
                break  # a docstring must be the first statement
            for sub in child.children:
                if sub.type == "string":
                    return self._string_literal_contents(self._text(sub, source)).strip()
                if sub.type == "concatenated_string":
                    parts = [
                        self._string_literal_contents(self._text(part, source))
                        for part in sub.children
                        if part.type == "string"
                    ]
                    return "".join(parts).strip()
            break  # first statement is an expression, but not a string
        return ""

    def _get_signature(self, node, source: bytes) -> str:
        """Return the header of a class/function definition without its colon.

        E.g. ``def hello(a, b) -> int`` or ``class Foo(Bar)``. When the header's
        ``:`` token is found, anything after it (such as a trailing comment) is
        excluded. Without a ``body`` field, the first source line of the node
        is returned.
        """
        body = node.child_by_field_name("body")
        if body is None:
            return self._text(node, source).split("\n")[0]

        end_byte = body.start_byte
        # The header ":" is a direct child of the definition; colons inside
        # parameters, annotations or lambdas belong to nested nodes.
        for child in node.children:
            if child.type == ":" and child.end_byte <= body.start_byte:
                end_byte = child.start_byte
        sig = source[node.start_byte : end_byte].decode("utf-8", errors="replace").strip()
        if sig.endswith(":"):
            sig = sig[:-1].strip()
        return sig

    # ------------------------------------------------------------------
    # Entry point
    # ------------------------------------------------------------------

    def parse_file(self, file_path: Path, workspace_dir: Path) -> ExtractionResult:
        """Parse one Python file and return the nodes and edges it defines.

        Args:
            file_path: The file to parse.
            workspace_dir: Root against which ``file_path`` is made relative;
                the relative path is the file node's ID and the prefix of every
                other node ID.

        Returns:
            An ``ExtractionResult``. It is empty if the file cannot be read.

        Raises:
            ValueError: if ``file_path`` is not inside ``workspace_dir``
                (propagated from ``Path.relative_to``).
        """
        try:
            source = file_path.read_bytes()
        except OSError as e:
            logger.error("Error reading file %s: %s", file_path, e)
            return ExtractionResult()

        root = self.parser.parse(source).root_node
        rel_path = str(file_path.relative_to(workspace_dir))
        result = ExtractionResult()

        # 1. The file node representing the module itself.
        file_node_id = rel_path
        result.nodes.append(
            NodeSchema(
                id=file_node_id,
                label=file_path.name,
                type="file",
                source_file=rel_path,
                line_start=1,
                line_end=len(source.splitlines()) or 1,
                signature=f"module {file_path.name}",
                docstring=self._get_docstring(root, source),
            )
        )

        # 2. Walk the tree. Each scope entry is (node_id, node_type) of an
        # enclosing file/class/function definition.
        scope_stack = [(file_node_id, "file")]
        # Iterative pre-order walk: a recursive walk would exceed Python's
        # recursion limit on deeply nested expressions (e.g. very long operator
        # chains). A ``None`` entry means "leave the scope entered by the
        # definition whose children have now all been processed".
        pending = [root]
        while pending:
            node = pending.pop()
            if node is None:
                scope_stack.pop()
                continue
            if node.type == "ERROR" or getattr(node, "is_error", False):
                # The whole error subtree is skipped, including definitions in it.
                logger.debug("Skipping syntax error node in Python AST: %s", node)
                continue

            entered = self._visit(node, source, rel_path, scope_stack[-1], result)
            if entered is not None:
                scope_stack.append(entered)
                pending.append(None)
            # Reversed so that children are popped, i.e. visited, in source order.
            pending.extend(reversed(node.children))

        return result

    # ------------------------------------------------------------------
    # Per-construct handlers
    # ------------------------------------------------------------------

    def _visit(self, node, source: bytes, rel_path: str, scope, result):
        """Record what ``node`` contributes to ``result``.

        ``scope`` is the innermost enclosing ``(node_id, node_type)``. Returns
        the ``(node_id, node_type)`` scope opened by a class or function
        definition, or ``None`` for every other node.
        """
        node_type = node.type
        if node_type == "class_definition":
            return self._handle_class(node, source, rel_path, scope, result)
        if node_type == "function_definition":
            return self._handle_function(node, source, rel_path, scope, result)
        if node_type == "import_statement":
            # The file node's ID is rel_path.
            self._handle_import(node, source, rel_path, result)
        elif node_type == "import_from_statement":
            self._handle_import_from(node, source, rel_path, result)
        elif node_type == "call":
            self._handle_call(node, source, scope[0], result)
        return None

    def _handle_class(self, node, source: bytes, rel_path: str, scope, result):
        """Emit a class node plus its ``contains`` and ``inherits`` edges; return its scope.

        Returns ``None`` (and emits nothing) when the definition has no name.
        """
        name_node = node.child_by_field_name("name")
        if name_node is None:
            return None
        class_name = self._text(name_node, source)
        parent_id, parent_type = scope
        class_id = self._qualified_id(rel_path, parent_id, parent_type, class_name)

        result.nodes.append(
            NodeSchema(
                id=class_id,
                label=class_name,
                type="class",
                source_file=rel_path,
                line_start=node.start_point[0] + 1,
                line_end=node.end_point[0] + 1,
                signature=self._get_signature(node, source),
                docstring=self._get_docstring(node, source),
            )
        )
        result.edges.append(
            EdgeSchema(source=parent_id, target=class_id, relation="contains")
        )

        superclasses = node.child_by_field_name("superclasses")
        if superclasses is not None:
            for child in superclasses.children:
                # Only plain names (Base / mod.Base); keyword args such as
                # metaclass=... and other expressions are ignored.
                if child.type in ("identifier", "attribute"):
                    result.edges.append(
                        EdgeSchema(
                            source=class_id,
                            target=self._text(child, source),
                            relation="inherits",
                        )
                    )
        return (class_id, "class")

    def _handle_function(self, node, source: bytes, rel_path: str, scope, result):
        """Emit a function/method node and its ``contains`` edge; return its scope.

        A definition directly inside a class is a ``method``; any other is a
        ``function``. Returns ``None`` (and emits nothing) when it has no name.
        """
        name_node = node.child_by_field_name("name")
        if name_node is None:
            return None
        func_name = self._text(name_node, source)
        parent_id, parent_type = scope
        sym_type = "method" if parent_type == "class" else "function"
        func_id = self._qualified_id(rel_path, parent_id, parent_type, func_name)

        result.nodes.append(
            NodeSchema(
                id=func_id,
                label=func_name,
                type=sym_type,
                source_file=rel_path,
                line_start=node.start_point[0] + 1,
                line_end=node.end_point[0] + 1,
                signature=self._get_signature(node, source),
                docstring=self._get_docstring(node, source),
                local_bindings=self._collect_local_bindings(node, source),
            )
        )
        result.edges.append(
            EdgeSchema(source=parent_id, target=func_id, relation="contains")
        )
        return (func_id, sym_type)

    def _collect_local_bindings(self, func_node, source: bytes) -> dict:
        """Guess the type name bound to local variables of one function.

        Recognised forms:
          * ``x: Foo`` and ``x: Foo = default`` parameters;
          * annotated assignments ``x: Foo = ...`` / ``x: Foo`` (annotation wins);
          * ``x = Foo(...)`` / ``x = mod.Foo(...)`` assignments;
          * ``with Foo(...) as x``.

        Nested function definitions are not entered. Nodes are visited in
        source (pre-)order, so later bindings overwrite earlier ones.

        Returns:
            ``{variable name: type name}``.
        """
        bindings = {}
        pending = [func_node]
        while pending:
            n = pending.pop()
            var_node, type_node = self._binding_parts(n)
            if var_node is not None and type_node is not None:
                type_name = self._extract_type_name(type_node, source)
                if type_name:
                    bindings[self._text(var_node, source)] = type_name
            pending.extend(
                reversed([c for c in n.children if c.type != "function_definition"])
            )
        return bindings

    @staticmethod
    def _binding_parts(n):
        """Return ``(variable node, type-or-call node)`` if ``n`` binds a local.

        Returns ``(None, None)`` or a pair containing ``None`` when ``n`` is not
        a recognised binding form.
        """
        if n.type == "typed_parameter":
            # The name is an unnamed identifier child; *args / **kwargs use
            # splat patterns instead and are therefore ignored.
            ident = next((c for c in n.children if c.type == "identifier"), None)
            return ident, n.child_by_field_name("type")
        if n.type == "typed_default_parameter":
            return n.child_by_field_name("name"), n.child_by_field_name("type")
        if n.type == "assignment":
            left = n.child_by_field_name("left")
            if left is None or left.type != "identifier":
                return None, None
            annotation = n.child_by_field_name("type")
            if annotation is not None:
                return left, annotation
            right = n.child_by_field_name("right")
            if right is not None and right.type == "call":
                return left, right
            return None, None
        if n.type == "as_pattern":
            call_node = None
            target = None
            for child in n.children:
                if child.type == "call":
                    call_node = child
                elif child.type == "as_pattern_target":
                    target = next(
                        (s for s in child.children if s.type == "identifier"), None
                    )
            return target, call_node
        return None, None

    def _extract_type_name(self, node, source: bytes):
        """Best-effort class name from a type annotation or constructor call.

        ``Foo`` -> ``"Foo"``, ``mod.Foo`` -> ``"Foo"``, ``Foo(...)`` -> ``"Foo"``.
        For other nodes (``type`` wrappers, generics, ...), the first child
        yielding a name is used. Returns ``None`` if nothing is found.
        """
        kind = node.type
        if kind == "identifier":
            return self._text(node, source)
        if kind == "attribute":
            attr = node.child_by_field_name("attribute")
            if attr is not None:
                return self._text(attr, source)
        elif kind == "call":
            func = node.child_by_field_name("function")
            if func is not None:
                return self._extract_type_name(func, source)
        for child in node.children:
            found = self._extract_type_name(child, source)
            if found:
                return found
        return None

    def _handle_import(self, node, source: bytes, file_node_id: str, result) -> None:
        """Emit one ``imports`` edge per module in ``import a, b.c as d``."""
        for child in node.children:
            if child.type == "dotted_name":
                module_name = self._text(child, source)
                import_map = {module_name: module_name}
            elif child.type == "aliased_import":
                name_node = child.child_by_field_name("name")
                alias_node = child.child_by_field_name("alias")
                if name_node is None or alias_node is None:
                    continue
                module_name = self._text(name_node, source)
                import_map = {self._text(alias_node, source): module_name}
            else:
                continue
            result.edges.append(
                EdgeSchema(
                    source=file_node_id,
                    target=module_name,
                    relation="imports",
                    import_map=import_map,
                )
            )

    def _handle_import_from(self, node, source: bytes, file_node_id: str, result) -> None:
        """Emit a single ``imports`` edge for ``from <module> import ...``.

        The target is the module text, prefixed with the text of any
        ``relative_source`` child. ``import_map`` maps each bound name to the
        imported name (``"*": "*"`` for a wildcard import). Nothing is emitted
        when no module text is found.
        """
        module_node = node.child_by_field_name("module_name")
        module_name = self._text(module_node, source) if module_node is not None else ""
        dots = next(
            (self._text(c, source) for c in node.children if c.type == "relative_source"),
            "",
        )
        target_module = dots + module_name
        if not target_module:
            return

        import_map = {}
        # Imported names are the children following the "import" keyword;
        # everything before it is the module part.
        after_import_kw = False
        for child in node.children:
            if child.type == "import":
                after_import_kw = True
                continue
            if not after_import_kw:
                continue
            if child.type == "wildcard_import":
                import_map["*"] = "*"
                continue
            items = child.children if child.type == "import_list" else (child,)
            for item in items:
                if item.type in ("dotted_name", "identifier"):
                    name = self._text(item, source)
                    import_map[name] = name
                elif item.type == "aliased_import":
                    name_node = item.child_by_field_name("name")
                    alias_node = item.child_by_field_name("alias")
                    if name_node is not None and alias_node is not None:
                        import_map[self._text(alias_node, source)] = self._text(
                            name_node, source
                        )

        result.edges.append(
            EdgeSchema(
                source=file_node_id,
                target=target_module,
                relation="imports",
                import_map=import_map,
            )
        )

    def _handle_call(self, node, source: bytes, caller_id: str, result) -> None:
        """Emit a ``calls`` edge from the enclosing scope to the callee's source text."""
        func_node = node.child_by_field_name("function")
        if func_node is not None:
            result.edges.append(
                EdgeSchema(
                    source=caller_id,
                    target=self._text(func_node, source),
                    relation="calls",
                )
            )