athena-code 0.0.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of athena-code might be problematic. Click here for more details.

@@ -0,0 +1,633 @@
1
+ import re
2
+
3
+ import tree_sitter_python
4
+ from tree_sitter import Language, Parser
5
+
6
+ from athena.models import (
7
+ ClassInfo,
8
+ Entity,
9
+ FunctionInfo,
10
+ Location,
11
+ MethodInfo,
12
+ ModuleInfo,
13
+ Parameter,
14
+ Signature,
15
+ )
16
+ from athena.parsers.base import BaseParser
17
+
18
+
19
class PythonParser(BaseParser):
    """Parser for extracting entities from Python source code using tree-sitter.

    Reported entities are top-level functions, top-level classes, and methods
    defined directly in the body of a top-level class. Decorated definitions
    are included, and their extent starts at the first decorator. Nested
    functions and nested classes are not reported. Extents are the 0-based
    start/end row numbers reported by tree-sitter.
    """

    # Node types of the definitions this parser reports on.
    _DEFINITION_TYPES = ("function_definition", "class_definition")

    # Node types used for `*args` / `**kwargs` style parameter names.
    _SPLAT_PATTERNS = ("list_splat_pattern", "dictionary_splat_pattern")

    # Lower-cased string-prefix letters whose literals Python never stores as
    # __doc__: bytes (b), f-strings (f) and template strings (t).
    _NON_DOCSTRING_PREFIXES = frozenset("bft")

    def __init__(self):
        self.language = Language(tree_sitter_python.language())
        self.parser = Parser(self.language)
        # (source_code, its UTF-8 bytes) for the most recently used source; see _encode.
        self._encoded_source: tuple[str, bytes] | None = None

    # ------------------------------------------------------------------
    # Low-level helpers
    # ------------------------------------------------------------------

    def _encode(self, source_code: str) -> bytes:
        """Return source_code encoded as UTF-8, reusing the previous result.

        Every name lookup slices the encoded source, so caching avoids
        re-encoding the whole file once per extracted node. The string and its
        bytes are stored together as one tuple, so the cached bytes always
        belong to the cached string. Only the most recent source is kept.
        """
        # getattr: tolerate subclasses that do not call PythonParser.__init__.
        cached = getattr(self, "_encoded_source", None)
        if cached is None or cached[0] is not source_code:
            cached = (source_code, source_code.encode("utf8"))
            self._encoded_source = cached
        return cached[1]

    def _extract_text(self, source_code: str, start_byte: int, end_byte: int) -> str:
        """Extract text from source code using byte offsets.

        Tree-sitter reports offsets into the UTF-8 encoding of the source, not
        into the Python str, so slicing has to happen on the encoded bytes.

        Args:
            source_code: The source code string
            start_byte: Start byte offset
            end_byte: End byte offset

        Returns:
            The extracted text as a string
        """
        return self._encode(source_code)[start_byte:end_byte].decode("utf8")

    def _node_text(self, source_code: str, node) -> str:
        """Return the source text spanned by a tree-sitter node."""
        return self._extract_text(source_code, node.start_byte, node.end_byte)

    @staticmethod
    def _location(extent_node) -> Location:
        """Build a Location from the start and end rows of extent_node."""
        return Location(start=extent_node.start_point[0], end=extent_node.end_point[0])

    @classmethod
    def _unwrap_definition(cls, node):
        """Split a node into (definition_node, extent_node).

        For a plain function/class definition both elements are `node`. For a
        decorated_definition the first element is the wrapped definition and
        the second is `node` itself, so the extent covers the decorators.
        Returns (None, None) for any other node.
        """
        if node.type in cls._DEFINITION_TYPES:
            return node, node
        if node.type == "decorated_definition":
            for child in node.children:
                if child.type in cls._DEFINITION_TYPES:
                    return child, node
        return None, None

    def _iter_definitions(self, container, source_code: str, definition_type: str | None = None):
        """Yield (name, definition_node, extent_node) for definitions directly in container.

        Args:
            container: Node whose direct children are scanned (the module root
                or a class body block).
            source_code: Source code string for text extraction
            definition_type: "function_definition" or "class_definition" to
                restrict the results; None yields both kinds in source order.

        Definitions without a name node are skipped.
        """
        for child in container.children:
            definition, extent = self._unwrap_definition(child)
            if definition is None:
                continue
            if definition_type is not None and definition.type != definition_type:
                continue
            name_node = definition.child_by_field_name("name")
            if name_node is None:
                continue
            yield self._node_text(source_code, name_node), definition, extent

    # ------------------------------------------------------------------
    # Entity listing
    # ------------------------------------------------------------------

    def extract_entities(self, source_code: str, file_path: str) -> list[Entity]:
        """Extract functions, classes, and methods from Python source code.

        Args:
            source_code: Python source code to parse
            file_path: Relative path to the file

        Returns:
            List of Entity objects: all top-level functions, then all top-level
            classes, then all methods (each group in source order).
        """
        root = self.parser.parse(self._encode(source_code)).root_node
        return [
            *self._extract_functions(root, source_code, file_path),
            *self._extract_classes(root, source_code, file_path),
            *self._extract_methods(root, source_code, file_path),
        ]

    def _extract_functions(self, node, source_code: str, file_path: str) -> list[Entity]:
        """Extract top-level function definitions, including decorated ones."""
        return [
            Entity(kind="function", path=file_path, extent=self._location(extent), name=name)
            for name, _, extent in self._iter_definitions(node, source_code, "function_definition")
        ]

    def _extract_classes(self, node, source_code: str, file_path: str) -> list[Entity]:
        """Extract top-level class definitions, including decorated ones."""
        return [
            Entity(kind="class", path=file_path, extent=self._location(extent), name=name)
            for name, _, extent in self._iter_definitions(node, source_code, "class_definition")
        ]

    def _extract_methods(self, node, source_code: str, file_path: str) -> list[Entity]:
        """Extract methods (functions directly inside top-level classes), including decorated ones.

        Method entities are named "ClassName.method_name".
        """
        methods = []
        for class_name, class_node, _ in self._iter_definitions(node, source_code, "class_definition"):
            body = class_node.child_by_field_name("body")
            if body is None:
                continue
            for method_name, _, extent in self._iter_definitions(body, source_code, "function_definition"):
                methods.append(Entity(
                    kind="method",
                    path=file_path,
                    extent=self._location(extent),
                    name=f"{class_name}.{method_name}",
                ))
        return methods

    # ------------------------------------------------------------------
    # Docstrings and signatures
    # ------------------------------------------------------------------

    def _extract_docstring(self, node, source_code: str) -> str | None:
        """Extract docstring from function/class/module node.

        The docstring is the first statement of the body (for functions and
        classes) or of the module, when that statement is a single plain or
        raw string literal. Comments before it are ignored, matching Python.

        Args:
            node: Tree-sitter node (function_definition, class_definition, or module root)
            source_code: Source code string for text extraction

        Returns:
            Docstring content without quotes, or None if no docstring.
        """
        if node.type in self._DEFINITION_TYPES:
            body = node.child_by_field_name("body")
            if body is None:
                return None
            statements = body.children
        else:
            statements = node.children

        # Comments are not statements: a shebang, encoding line or license
        # header may precede a module docstring without stopping it from
        # being the docstring.
        first = next((child for child in statements if child.type != "comment"), None)
        if first is None or first.type != "expression_statement":
            return None

        # Only a lone string literal is a docstring; `"a", "b"` (a tuple) is not.
        expressions = [child for child in first.named_children if child.type != "comment"]
        if len(expressions) != 1 or expressions[0].type != "string":
            return None
        return self._string_literal_value(expressions[0], source_code)

    def _string_literal_value(self, string_node, source_code: str) -> str | None:
        """Return the text between the quotes of a docstring-candidate string node.

        Returns None for bytes, f-string and t-string literals, which Python
        never uses as a docstring. Escape sequences are returned as written in
        the source.
        """
        text = self._node_text(source_code, string_node)
        quote_positions = [pos for pos in (text.find('"'), text.find("'")) if pos >= 0]
        if not quote_positions:
            return None
        quote_at = min(quote_positions)
        if set(text[:quote_at].lower()) & self._NON_DOCSTRING_PREFIXES:
            return None

        # Usual shape: string_start, string_content, string_end.
        for child in string_node.children:
            if child.type == "string_content":
                return self._node_text(source_code, child)

        # No string_content child (e.g. an empty literal): strip the delimiters.
        literal = text[quote_at:]
        delimiter = 3 if literal.startswith(('"""', "'''")) else 1
        return literal[delimiter:-delimiter]

    def _extract_parameters(self, node, source_code: str) -> list[Parameter]:
        """Extract parameter list from function/method definition.

        `*args` / `**kwargs` keep their star prefix in the name, with or
        without a type annotation. Bare `*` and `/` separators and punctuation
        are not reported.

        Args:
            node: function_definition tree-sitter node
            source_code: Source code string for text extraction

        Returns:
            List of Parameter objects
        """
        parameters = []
        params_node = node.child_by_field_name("parameters")
        if params_node is None:
            return parameters

        for child in params_node.children:
            name = type_ = default = None
            kind = child.type

            if kind == "identifier" or kind in self._SPLAT_PATTERNS:
                # def foo(x), def foo(*args), def foo(**kwargs)
                name = self._node_text(source_code, child)

            elif kind == "typed_parameter":
                # def foo(x: int) / def foo(*args: int) / def foo(**kw: str):
                # the name part may be an identifier or a splat pattern.
                for sub in child.children:
                    if name is None and (sub.type == "identifier" or sub.type in self._SPLAT_PATTERNS):
                        name = self._node_text(source_code, sub)
                    elif sub.type == "type":
                        type_ = self._node_text(source_code, sub)

            elif kind in ("default_parameter", "typed_default_parameter"):
                # def foo(x=5) / def foo(x: int = 5); default_parameter has no "type" field.
                name_node = child.child_by_field_name("name")
                type_node = child.child_by_field_name("type")
                value_node = child.child_by_field_name("value")
                if name_node is not None:
                    name = self._node_text(source_code, name_node)
                if type_node is not None:
                    type_ = self._node_text(source_code, type_node)
                if value_node is not None:
                    default = self._node_text(source_code, value_node)

            if name:
                parameters.append(Parameter(name=name, type=type_, default=default))

        return parameters

    def _extract_return_type(self, node, source_code: str) -> str | None:
        """Extract return type annotation from function/method definition.

        Args:
            node: function_definition tree-sitter node
            source_code: Source code string for text extraction

        Returns:
            Return type as string, or None if no annotation.
        """
        return_type_node = node.child_by_field_name("return_type")
        if return_type_node is None:
            return None
        return self._node_text(source_code, return_type_node)

    def _format_signature(self, name: str, params: list[Parameter], return_type: str | None) -> str:
        """Format a signature as a string.

        Args:
            name: Function/method name
            params: List of Parameter objects
            return_type: Return type annotation or None

        Returns:
            Formatted signature string like "func(x: int = 5, y: str) -> bool"
        """
        param_strs = []
        for param in params:
            if param.type and param.default:
                param_strs.append(f"{param.name}: {param.type} = {param.default}")
            elif param.type:
                param_strs.append(f"{param.name}: {param.type}")
            elif param.default:
                param_strs.append(f"{param.name} = {param.default}")
            else:
                param_strs.append(param.name)

        sig = f"{name}({', '.join(param_strs)})"
        if return_type:
            sig += f" -> {return_type}"
        return sig

    # ------------------------------------------------------------------
    # Detailed entity info
    # ------------------------------------------------------------------

    def extract_entity_info(
        self,
        source_code: str,
        file_path: str,
        entity_name: str | None = None
    ) -> FunctionInfo | ClassInfo | MethodInfo | ModuleInfo | None:
        """Extract detailed information about a specific entity.

        Top-level definitions are searched in source order. A method can be
        looked up either by its bare name (first match in any top-level class)
        or by the "ClassName.method_name" form produced by extract_entities.

        Args:
            source_code: Python source code
            file_path: File path (for EntityInfo.path)
            entity_name: Entity name to find, or None for module-level info

        Returns:
            EntityInfo object, or None if entity not found
        """
        root_node = self.parser.parse(self._encode(source_code)).root_node

        if entity_name is None:
            docstring = self._extract_docstring(root_node, source_code)
            # Module extent is from the first to the last line of the file.
            lines = source_code.splitlines()
            extent = Location(start=0, end=len(lines) - 1 if lines else 0)
            return ModuleInfo(path=file_path, extent=extent, summary=docstring)

        for name, definition, extent_node in self._iter_definitions(root_node, source_code):
            if definition.type == "function_definition":
                if name == entity_name:
                    return self._build_entity_info_for_function(
                        definition, source_code, file_path, extent_node=extent_node
                    )
                continue

            # class_definition: the class itself, then its methods.
            if name == entity_name:
                return self._build_entity_info_for_class(
                    definition, source_code, file_path, extent_node=extent_node
                )
            method_info = self._find_method_info(definition, name, entity_name, source_code, file_path)
            if method_info is not None:
                return method_info

        return None

    def _find_method_info(
        self, class_node, class_name: str, entity_name: str, source_code: str, file_path: str
    ) -> MethodInfo | None:
        """Return MethodInfo for the method of class_node matching entity_name, if any.

        entity_name matches either the bare method name or "ClassName.method_name".
        """
        body = class_node.child_by_field_name("body")
        if body is None:
            return None
        qualified_prefix = f"{class_name}."
        for method_name, method_node, extent_node in self._iter_definitions(
            body, source_code, "function_definition"
        ):
            if entity_name in (method_name, qualified_prefix + method_name):
                return self._build_entity_info_for_function(
                    method_node, source_code, file_path, class_name=class_name, extent_node=extent_node
                )
        return None

    def _build_entity_info_for_function(
        self, node, source_code: str, file_path: str, class_name: str | None = None, extent_node=None
    ) -> FunctionInfo | MethodInfo:
        """Build FunctionInfo or MethodInfo for a function or method.

        Args:
            node: function_definition tree-sitter node
            source_code: Source code string
            file_path: Relative file path
            class_name: Class name if this is a method, None for top-level functions
            extent_node: Optional node to use for extent (e.g., decorated_definition to include decorators)

        Returns:
            FunctionInfo if class_name is falsy, MethodInfo (named "Class.method") otherwise
        """
        name_node = node.child_by_field_name("name")
        name = self._node_text(source_code, name_node) if name_node is not None else ""

        params = self._extract_parameters(node, source_code)
        return_type = self._extract_return_type(node, source_code)
        sig = Signature(name=name, args=params, return_type=return_type)
        docstring = self._extract_docstring(node, source_code)
        extent = self._location(extent_node if extent_node is not None else node)

        if class_name:
            return MethodInfo(
                name=f"{class_name}.{name}",
                path=file_path,
                extent=extent,
                sig=sig,
                summary=docstring
            )
        return FunctionInfo(
            path=file_path,
            extent=extent,
            sig=sig,
            summary=docstring
        )

    def _build_entity_info_for_class(self, node, source_code: str, file_path: str, extent_node=None) -> ClassInfo:
        """Build ClassInfo for a class, including formatted method signatures.

        Methods are the function definitions directly in the class body,
        decorated ones included (e.g. @property, @staticmethod).

        Args:
            node: class_definition tree-sitter node
            source_code: Source code string
            file_path: Relative file path
            extent_node: Optional node to use for extent (e.g., decorated_definition to include decorators)
        """
        docstring = self._extract_docstring(node, source_code)
        extent = self._location(extent_node if extent_node is not None else node)

        methods = []
        body = node.child_by_field_name("body")
        if body is not None:
            for method_name, method_node, _ in self._iter_definitions(body, source_code, "function_definition"):
                params = self._extract_parameters(method_node, source_code)
                return_type = self._extract_return_type(method_node, source_code)
                methods.append(self._format_signature(method_name, params, return_type))

        return ClassInfo(
            path=file_path,
            extent=extent,
            methods=methods,
            summary=docstring
        )

    # ------------------------------------------------------------------
    # @athena tag helpers
    # ------------------------------------------------------------------

    @staticmethod
    def parse_athena_tag(docstring: str) -> str | None:
        """Extract hash from @athena tag in docstring.

        The hash must be exactly 12 hex characters; a longer token such as a
        16-character hex string is not accepted.

        Args:
            docstring: Docstring content to parse

        Returns:
            12-character hex hash if tag found and valid, None otherwise
        """
        if not docstring:
            return None

        # \b rejects the 12-char prefix of a longer word-like token.
        match = re.search(r"@athena:\s*([0-9a-f]{12})\b", docstring, re.IGNORECASE)
        return match.group(1) if match else None

    @staticmethod
    def update_athena_tag(docstring: str, new_hash: str) -> str:
        """Update or insert @athena tag in docstring.

        If docstring is empty or None, creates a minimal docstring with the tag.
        If tag exists, updates it. If tag doesn't exist, appends it.

        Args:
            docstring: Existing docstring content (may be None or empty)
            new_hash: New 12-character hex hash to insert

        Returns:
            Updated docstring with @athena tag
        """
        if not docstring or not docstring.strip():
            return f"@athena: {new_hash}"

        # An existing tag is any token of at least 12 non-whitespace chars (not
        # just hex). \w* also consumes the rest of a longer token so no stale
        # suffix is left behind, while trailing punctuation is kept.
        pattern = r"@athena:\s*\S{12}\w*"
        if re.search(pattern, docstring, re.IGNORECASE):
            return re.sub(pattern, f"@athena: {new_hash}", docstring, flags=re.IGNORECASE)

        # Append on its own line.
        separator = "" if docstring.endswith("\n") else "\n"
        return f"{docstring}{separator}@athena: {new_hash}"

    @staticmethod
    def validate_athena_tag(tag: str) -> bool:
        """Validate that a tag is a proper 12-character hex hash.

        Args:
            tag: Tag string to validate (without @athena: prefix)

        Returns:
            True if valid 12-character hex hash, False otherwise
        """
        if not tag:
            return False
        # fullmatch: unlike ^...$ with re.match, a trailing "\n" is rejected.
        return re.fullmatch(r"[0-9a-f]{12}", tag, re.IGNORECASE) is not None