codegraph-gen 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,497 @@
1
+ import logging
2
+ from pathlib import Path
3
+ import tree_sitter
4
+ import tree_sitter_rust
5
+ from codegraph_gen.parser.base import BaseParser, ExtractionResult, NodeSchema, EdgeSchema
6
+
7
+ logger = logging.getLogger(__name__)
8
+
9
+
10
+ class RustParser(BaseParser):
11
def __init__(self):
    """Load the tree-sitter Rust grammar and create a parser bound to it."""
    rust_language = tree_sitter.Language(tree_sitter_rust.language())
    self.language = rust_language
    self.parser = tree_sitter.Parser(rust_language)
14
+
15
+ def _get_docstring(self, node, source: bytes) -> str:
16
+ """Finds comments immediately preceding the node."""
17
+ docstring = ""
18
+ prev = node.prev_sibling
19
+ comments = []
20
+ while prev and prev.type in ("line_comment", "block_comment"):
21
+ comment_text = source[prev.start_byte : prev.end_byte].decode(
22
+ "utf-8", errors="replace"
23
+ )
24
+ # Strip comment markers (/// or //)
25
+ clean_text = comment_text.strip().lstrip("/").strip()
26
+ comments.append(clean_text)
27
+ prev = prev.prev_sibling
28
+
29
+ if comments:
30
+ docstring = "\n".join(reversed(comments))
31
+ return docstring
32
+
33
+ def _get_signature(self, node, source: bytes) -> str:
34
+ body = node.child_by_field_name("body")
35
+ if body:
36
+ end_byte = body.start_byte
37
+ sig_bytes = source[node.start_byte : end_byte]
38
+ sig = sig_bytes.decode("utf-8", errors="replace").strip()
39
+ if sig.endswith("{"):
40
+ sig = sig[:-1].strip()
41
+ return sig
42
+ return (
43
+ source[node.start_byte : node.end_byte]
44
+ .decode("utf-8", errors="replace")
45
+ .split("\n")[0]
46
+ )
47
+
48
def parse_file(self, file_path: Path, workspace_dir: Path) -> ExtractionResult:
    """Parse one Rust source file into graph nodes and edges.

    The file is parsed with ``self.parser`` and its syntax tree is walked
    once. The walk emits:

    * a ``file`` node for the file itself;
    * ``struct`` / ``enum`` / ``interface`` nodes for struct, enum and
      trait items, each linked from the file by a ``contains`` edge;
    * ``function`` nodes (outside an ``impl``) or ``method`` nodes (inside
      one) for function items, with a name -> type map of the parameters
      and ``let`` bindings whose type could be determined;
    * ``implements`` edges for ``impl Trait for Type`` blocks;
    * ``imports`` edges for ``use`` declarations, including grouped forms
      such as ``use a::{b, c::d, self}``;
    * ``calls`` edges from the enclosing function (or from the file when
      there is none) to every call / method-call expression.

    Subtrees rooted at a syntax-error node are skipped.

    Args:
        file_path: Rust file to read. Node ids are built from its path
            relative to ``workspace_dir``.
        workspace_dir: Directory that ``file_path`` is made relative to.

    Returns:
        The populated ``ExtractionResult``; an empty one when the file
        cannot be read (the error is logged).

    Raises:
        ValueError: Propagated from ``Path.relative_to`` when ``file_path``
            is not located under ``workspace_dir``.
    """
    try:
        source = file_path.read_bytes()
    except Exception as e:
        # Best effort: an unreadable file yields an empty result, not a crash.
        logger.error("Error reading file %s: %s", file_path, e)
        return ExtractionResult()

    tree = self.parser.parse(source)
    root = tree.root_node

    rel_path = str(file_path.relative_to(workspace_dir))
    result = ExtractionResult()
    file_node_id = rel_path

    symbol_kinds = {
        "struct_item": "struct",
        "enum_item": "enum",
        "trait_item": "interface",  # map trait to interface for consistency
    }
    # Node types that may appear as (part of) the argument of a `use`.
    # `scoped_use_list`, `use_list` and `self` cover grouped imports;
    # `use_path` / `self_literal` are kept for grammar builds that might
    # expose those names.
    # NOTE(review): confirm against the installed tree-sitter-rust
    # node-types.json before pruning either set.
    use_item_types = (
        "use_path",
        "use_list",
        "scoped_use_list",
        "identifier",
        "scoped_identifier",
        "use_as_clause",
        "self",
        "self_literal",
    )

    def text_of(n) -> str:
        # Source text covered by a syntax node.
        return source[n.start_byte : n.end_byte].decode("utf-8", errors="replace")

    def join_path(prefix: str, name: str) -> str:
        # Append `name` to a `::`-separated path prefix (which may be empty).
        return f"{prefix}::{name}" if prefix else name

    def get_impl_type(impl_node) -> str | None:
        # Text of the type an `impl` block is written for, if present.
        type_node = impl_node.child_by_field_name("type")
        if type_node:
            return text_of(type_node).strip()
        return None

    def extract_rust_type(type_node) -> str | None:
        # Reduce a type expression to its base type identifier.
        kind = type_node.type
        if kind == "type_identifier":
            return text_of(type_node)
        if kind in ("pointer_type", "reference_type", "sliced_type", "array_type"):
            for child in type_node.children:
                if child.type not in ("&", "*", "mut", "const"):
                    res = extract_rust_type(child)
                    if res:
                        return res
            return None
        if kind == "generic_type":
            type_id_node = type_node.child_by_field_name("type")
            if type_id_node:
                return extract_rust_type(type_id_node)
        return None

    def binding_name(pattern_node) -> str | None:
        # Variable name bound by a plain identifier or `mut` pattern.
        if pattern_node.type == "identifier":
            return text_of(pattern_node)
        if pattern_node.type == "mut_pattern":
            inner = pattern_node.child_by_field_name("pattern")
            if inner and inner.type == "identifier":
                return text_of(inner)
        return None

    def identifiers_in(n):
        # Yield the text of every identifier under `n`, in pre-order.
        if n.type == "identifier":
            yield text_of(n)
        for c in n.children:
            yield from identifiers_in(c)

    def infer_type_from_value(value_node, bindings) -> str | None:
        # Guess the type of a `let` initializer that has no annotation.
        kind = value_node.type
        if kind == "call_expression":
            # `Type::ctor(...)` -> the path in front of the last segment.
            func = value_node.child_by_field_name("function")
            if func and func.type == "scoped_identifier":
                path_node = func.child_by_field_name("path")
                if path_node:
                    return text_of(path_node)
            return None
        if kind == "struct_expression":
            name_node = value_node.child_by_field_name("name")
            if name_node:
                return extract_rust_type(name_node)
            return None
        if kind == "match_expression":
            subject_node = value_node.child_by_field_name("value")
            if not subject_node:
                # Fall back to the first child after the `match` keyword.
                for child in value_node.children:
                    if child.type in ("match_block", "{"):
                        break
                    if child.type != "match":
                        subject_node = child
                        break
            if subject_node:
                # Reuse the type of the first already-known variable that
                # appears in the matched expression.
                for ident in identifiers_in(subject_node):
                    if ident in bindings:
                        return bindings[ident]
        return None

    def collect_local_bindings(n, bindings) -> None:
        # Record name -> type for parameters and `let` bindings under `n`.
        if n.type == "parameter":
            pattern_node = n.child_by_field_name("pattern")
            type_node = n.child_by_field_name("type")
            if pattern_node and type_node:
                var_name = binding_name(pattern_node)
                if var_name:
                    t_name = extract_rust_type(type_node)
                    if t_name:
                        bindings[var_name] = t_name
        elif n.type == "let_declaration":
            pattern_node = n.child_by_field_name("pattern")
            type_node = n.child_by_field_name("type")
            value_node = n.child_by_field_name("value")
            var_name = binding_name(pattern_node) if pattern_node else None
            if var_name:
                type_name = None
                if type_node:
                    type_name = extract_rust_type(type_node)
                elif value_node:
                    type_name = infer_type_from_value(value_node, bindings)
                if type_name:
                    bindings[var_name] = type_name

        for child in n.children:
            collect_local_bindings(child, bindings)

    def add_import(target: str, alias: str, symbol: str) -> None:
        # Emit one `imports` edge from the file to `target`.
        result.edges.append(
            EdgeSchema(
                source=file_node_id,
                target=target,
                relation="imports",
                import_map={alias: symbol},
            )
        )

    def parse_use_item(n, prefix=""):
        # Emit import edges for one `use` clause; `prefix` is the path
        # accumulated from the enclosing grouped clauses.
        kind = n.type
        if kind == "scoped_use_list":
            # `path::{...}`: extend the prefix, then expand the group.
            path_node = n.child_by_field_name("path")
            list_node = n.child_by_field_name("list")
            inner_prefix = join_path(prefix, text_of(path_node)) if path_node else prefix
            if list_node:
                parse_use_item(list_node, inner_prefix)

        elif kind == "use_list":
            for sub in n.children:
                if sub.type in use_item_types:
                    parse_use_item(sub, prefix)

        elif kind == "use_path":
            parts = []
            use_list_node = None
            as_clause_node = None
            for child in n.children:
                if child.type == "use_list":
                    use_list_node = child
                elif child.type == "use_as_clause":
                    as_clause_node = child
                elif child.type in ("identifier", "scoped_identifier", "use_path"):
                    parts.append(text_of(child))

            full_path = join_path(prefix, "::".join(parts))

            if use_list_node:
                parse_use_item(use_list_node, full_path)
            elif as_clause_node:
                path_node = as_clause_node.child_by_field_name("path")
                alias_node = as_clause_node.child_by_field_name("alias")
                if path_node and alias_node:
                    item_path = join_path(full_path, text_of(path_node))
                    add_import(
                        item_path, text_of(alias_node), item_path.split("::")[-1]
                    )
            else:
                last_symbol = full_path.split("::")[-1]
                add_import(full_path, last_symbol, last_symbol)

        elif kind == "use_as_clause":
            path_node = n.child_by_field_name("path")
            alias_node = n.child_by_field_name("alias")
            if path_node and alias_node:
                full_path = join_path(prefix, text_of(path_node))
                add_import(full_path, text_of(alias_node), full_path.split("::")[-1])

        elif kind in ("identifier", "scoped_identifier"):
            full_path = join_path(prefix, text_of(n))
            last_symbol = full_path.split("::")[-1]
            add_import(full_path, last_symbol, last_symbol)

        elif kind in ("self", "self_literal"):
            # `use a::b::{self}` imports the module `a::b` itself.
            full_path = prefix
            last_symbol = full_path.split("::")[-1] if full_path else "self"
            add_import(full_path, last_symbol, last_symbol)

    def enclosing_caller_id(call_node) -> str:
        # Id of the nearest enclosing function/method, else the file id.
        curr = call_node.parent
        while curr:
            if curr.type == "function_item":
                c_name_node = curr.child_by_field_name("name")
                if c_name_node:
                    c_name = text_of(c_name_node)
                    # Check if inside an impl block
                    impl_node = curr.parent
                    while impl_node and impl_node.type != "impl_item":
                        impl_node = impl_node.parent
                    r_type = get_impl_type(impl_node) if impl_node else None
                    if r_type:
                        return f"{rel_path}::{r_type}.{c_name}"
                    return f"{rel_path}::{c_name}"
                break
            curr = curr.parent
        return file_node_id

    # Add file node
    result.nodes.append(
        NodeSchema(
            id=file_node_id,
            label=file_path.name,
            type="file",
            source_file=rel_path,
            line_start=1,
            line_end=len(source.splitlines()) or 1,
            signature=f"mod {file_path.stem}",
            docstring=self._get_docstring(root, source),
        )
    )

    def walk(node, current_impl_type=None):
        if node.type == "ERROR" or (hasattr(node, "is_error") and node.is_error):
            logger.debug("Skipping syntax error node in Rust AST: %s", node)
            return

        node_type = node.type
        pushed_impl = None

        if node_type in symbol_kinds:
            name_node = node.child_by_field_name("name")
            if name_node:
                item_name = text_of(name_node)
                item_id = f"{rel_path}::{item_name}"
                result.nodes.append(
                    NodeSchema(
                        id=item_id,
                        label=item_name,
                        type=symbol_kinds[node_type],
                        source_file=rel_path,
                        line_start=node.start_point[0] + 1,
                        line_end=node.end_point[0] + 1,
                        signature=self._get_signature(node, source),
                        docstring=self._get_docstring(node, source),
                    )
                )
                result.edges.append(
                    EdgeSchema(source=file_node_id, target=item_id, relation="contains")
                )

        elif node_type == "impl_item":
            impl_type = get_impl_type(node)
            if impl_type:
                # Functions found below this node become methods of impl_type.
                pushed_impl = impl_type
                type_id = f"{rel_path}::{impl_type}"

                # Link the type to the trait for `impl Trait for Type`.
                trait_node = node.child_by_field_name("trait")
                if trait_node:
                    result.edges.append(
                        EdgeSchema(
                            source=type_id,
                            target=text_of(trait_node),
                            relation="implements",
                        )
                    )

        elif node_type == "function_item":
            name_node = node.child_by_field_name("name")
            if name_node:
                func_name = text_of(name_node)

                if current_impl_type:
                    parent_id = f"{rel_path}::{current_impl_type}"
                    func_id = f"{parent_id}.{func_name}"
                    sym_type = "method"
                else:
                    parent_id = file_node_id
                    func_id = f"{rel_path}::{func_name}"
                    sym_type = "function"

                local_bindings = {}
                collect_local_bindings(node, local_bindings)

                result.nodes.append(
                    NodeSchema(
                        id=func_id,
                        label=func_name,
                        type=sym_type,
                        source_file=rel_path,
                        line_start=node.start_point[0] + 1,
                        line_end=node.end_point[0] + 1,
                        signature=self._get_signature(node, source),
                        docstring=self._get_docstring(node, source),
                        local_bindings=local_bindings,
                    )
                )
                result.edges.append(
                    EdgeSchema(source=parent_id, target=func_id, relation="contains")
                )

        elif node_type == "use_declaration":
            for child in node.children:
                if child.type in use_item_types:
                    parse_use_item(child)

        elif node_type in ("call_expression", "method_call_expression"):
            callee_name = None
            if node_type == "call_expression":
                func_node = node.child_by_field_name("function")
                if func_node:
                    callee_name = text_of(func_node)
            else:
                value_node = node.child_by_field_name("value")
                name_node = node.child_by_field_name("name")
                if value_node and name_node:
                    receiver = text_of(value_node).strip()
                    method = text_of(name_node).strip()
                    callee_name = f"{receiver}.{method}"

            if callee_name:
                result.edges.append(
                    EdgeSchema(
                        source=enclosing_caller_id(node),
                        target=callee_name,
                        relation="calls",
                    )
                )

        # Recurse children
        impl_context = pushed_impl if pushed_impl else current_impl_type
        for child in node.children:
            walk(child, impl_context)

    walk(root)
    return result