athena-code 0.0.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of athena-code might be problematic. Click here for more details.

athena/sync.py ADDED
@@ -0,0 +1,577 @@
1
+ """Core sync logic for updating @athena tags in docstrings."""
2
+
3
+ import sys
4
+ from pathlib import Path
5
+
6
+ from athena.docstring_updater import update_docstring_in_source
7
+ from athena.entity_path import EntityPath, parse_entity_path, resolve_entity_path
8
+ from athena.hashing import compute_class_hash, compute_function_hash
9
+ from athena.models import EntityStatus, Location
10
+ from athena.parsers.python_parser import PythonParser
11
+
12
+
13
def should_exclude_path(path: Path, repo_root: Path) -> bool:
    """Decide whether sync operations must leave ``path`` alone.

    A path is excluded when any of the following holds:
    - one of its components is a virtualenv directory name
      (``.venv``, ``venv``, ``.virtualenv``, ``virtualenv``)
    - one of its components is ``site-packages``
    - it lives underneath ``sys.prefix``
    - it lives in ``src/athena/`` below the repository root
    - it lives in ``athena/`` below the repository root and that directory
      contains an ``__init__.py`` (guards against self-modification without
      excluding an unrelated user directory that happens to share the name)

    Args:
        path: Path to check
        repo_root: Repository root directory

    Returns:
        True if path should be excluded, False otherwise
    """
    # Both sides are resolved so symlinks cannot defeat the comparisons below.
    target = path.resolve()
    root = repo_root.resolve()
    components = target.parts

    # Virtualenv directories and site-packages, matched anywhere in the path.
    blocked_dir_names = {
        '.venv', 'venv', '.virtualenv', 'virtualenv', 'site-packages',
    }
    if not blocked_dir_names.isdisjoint(components):
        return True

    # Anything inside the running interpreter's installation prefix.
    if target.is_relative_to(Path(sys.prefix).resolve()):
        return True

    # The remaining rules only apply to paths inside the repository.
    if not target.is_relative_to(root):
        return False
    inside_repo = target.relative_to(root).parts

    if inside_repo[:2] == ('src', 'athena'):
        return True

    if inside_repo[:1] == ('athena',):
        package_dir = root / 'athena'
        # Only a real package (directory with __init__.py) is excluded.
        return package_dir.is_dir() and (package_dir / '__init__.py').exists()

    return False
68
+
69
+
70
+ def needs_update(current_hash: str | None, computed_hash: str, force: bool) -> bool:
71
+ """Determine if an entity needs hash update.
72
+
73
+ Args:
74
+ current_hash: Current hash from docstring (None if no tag)
75
+ computed_hash: Newly computed hash from code
76
+ force: If True, always update regardless of match
77
+
78
+ Returns:
79
+ True if update is needed, False otherwise
80
+ """
81
+ if force:
82
+ return True
83
+
84
+ if current_hash is None:
85
+ # No existing hash, needs update
86
+ return True
87
+
88
+ # Update if hashes don't match
89
+ return current_hash != computed_hash
90
+
91
+
92
def _parse_python_file(resolved_path: Path):
    """Read a Python file as UTF-8 and parse it with a fresh PythonParser.

    The text is decoded explicitly as UTF-8 because the very same text is
    handed to the parser as UTF-8 bytes; relying on the platform default
    encoding could decode the file differently from how it is parsed.

    Args:
        resolved_path: Path of the file to read

    Returns:
        Tuple ``(parser, source_code, tree)``. The tree is returned (rather
        than only its root node) so callers can keep it referenced while
        they work with nodes taken from it.
    """
    source_code = resolved_path.read_text(encoding="utf-8")
    parser = PythonParser()
    tree = parser.parser.parse(bytes(source_code, "utf8"))
    return parser, source_code, tree


def _definition_name(parser, node, source_code: str):
    """Return the text of ``node``'s ``name`` field, or None if it has none."""
    name_node = node.child_by_field_name("name")
    if not name_node:
        return None
    return parser._extract_text(
        source_code, name_node.start_byte, name_node.end_byte
    )


def _unwrap_definition(node):
    """Return ``(definition, extent)`` for a function/class node.

    For a plain ``function_definition`` or ``class_definition`` both values
    are the node itself. For a ``decorated_definition`` the definition is the
    first wrapped function/class child and the extent is the decorated node,
    so the extent also spans the decorators. Any other node yields
    ``(None, None)``.
    """
    definition_types = ("function_definition", "class_definition")
    if node.type in definition_types:
        return node, node
    if node.type == "decorated_definition":
        for inner in node.children:
            if inner.type in definition_types:
                return inner, node
    return None, None


def _find_method(parser, class_node, source_code: str, method_name: str):
    """Find a (possibly decorated) method by name in a class body.

    Only direct children of the class body are considered.

    Returns:
        ``(method_node, extent_node)``, or ``(None, None)`` if no method with
        that name exists in the class body.
    """
    body = class_node.child_by_field_name("body")
    if not body:
        return None, None

    for item in body.children:
        definition, extent = _unwrap_definition(item)
        # Nested classes are not methods; skip them.
        if definition is None or definition.type != "function_definition":
            continue
        name = _definition_name(parser, definition, source_code)
        if name is not None and name == method_name:
            return definition, extent

    return None, None


def _find_entity_nodes(parser, root_node, source_code: str, entity_path):
    """Locate the AST node for ``entity_path`` among the module's top level.

    Top-level functions and classes (decorated or not) are matched against
    ``entity_path.entity_name``. When ``entity_path.is_method`` is set, a
    class whose name equals ``entity_path.class_name`` is searched for a
    method named ``entity_path.method_name``. The first match in source
    order wins, for decorated and undecorated definitions alike.

    Returns:
        ``(entity_node, extent_node)`` where the extent node includes any
        decorators, or ``(None, None)`` if nothing matches.
    """
    for child in root_node.children:
        definition, extent = _unwrap_definition(child)
        if definition is None:
            continue

        name = _definition_name(parser, definition, source_code)
        if name is None:
            continue

        if name == entity_path.entity_name:
            return definition, extent

        if (
            definition.type == "class_definition"
            and entity_path.is_method
            and name == entity_path.class_name
        ):
            method_node, method_extent = _find_method(
                parser, definition, source_code, entity_path.method_name
            )
            if method_node is not None:
                return method_node, method_extent
            # Method not in this class: keep scanning later definitions.

    return None, None


def inspect_entity(entity_path_str: str, repo_root: Path) -> EntityStatus:
    """Inspect an entity and return its status information.

    This function analyzes an entity and computes its current state:
    - Resolves the entity path to a file
    - Finds the entity in the AST
    - Computes the hash from the AST
    - Extracts the recorded hash from the docstring
    - Determines entity kind and extent

    The extent is formatted as ``"<start>-<end>"`` using the row indices the
    parser reports for the entity's extent node (decorators included).

    Args:
        entity_path_str: Entity path string (e.g., "src/foo.py:Bar")
        repo_root: Repository root directory

    Returns:
        EntityStatus containing all status information

    Raises:
        FileNotFoundError: If entity file doesn't exist
        ValueError: If entity is not found in file, the path is excluded,
            or the entity has an unsupported node type
        NotImplementedError: For package/module level inspection
    """
    entity_path = parse_entity_path(entity_path_str)

    resolved_path = resolve_entity_path(entity_path, repo_root)
    if resolved_path is None or not resolved_path.exists():
        raise FileNotFoundError(f"Entity file not found: {entity_path.file_path}")

    if should_exclude_path(resolved_path, repo_root):
        raise ValueError(f"Cannot inspect excluded path: {entity_path.file_path}")

    if entity_path.is_package:
        raise NotImplementedError(
            "Package-level inspection not yet implemented. "
            "Use module or entity-level inspection instead."
        )

    # Reject module paths before touching the file: nothing below applies.
    if entity_path.is_module:
        raise NotImplementedError(
            "Module-level inspection not yet implemented. "
            "Use entity-level inspection instead (e.g., file.py:function_name)."
        )

    parser, source_code, tree = _parse_python_file(resolved_path)
    entity_node, entity_extent_node = _find_entity_nodes(
        parser, tree.root_node, source_code, entity_path
    )

    if entity_node is None:
        raise ValueError(f"Entity not found in file: {entity_path.entity_name}")

    if entity_node.type == "function_definition":
        computed_hash = compute_function_hash(entity_node, source_code)
        kind = "method" if entity_path.is_method else "function"
    elif entity_node.type == "class_definition":
        computed_hash = compute_class_hash(entity_node, source_code)
        kind = "class"
    else:
        raise ValueError(f"Unsupported entity type: {entity_node.type}")

    current_docstring = parser._extract_docstring(entity_node, source_code)
    if current_docstring:
        current_docstring = current_docstring.strip()
    # No docstring means no recorded hash.
    current_hash = (
        parser.parse_athena_tag(current_docstring) if current_docstring else None
    )

    extent_str = f"{entity_extent_node.start_point[0]}-{entity_extent_node.end_point[0]}"

    return EntityStatus(
        kind=kind,
        path=entity_path_str,
        extent=extent_str,
        recorded_hash=current_hash,
        calculated_hash=computed_hash
    )


def sync_entity(entity_path_str: str, force: bool, repo_root: Path) -> bool:
    """Sync hash tag for a single entity.

    Uses inspect_entity() to analyze the entity state, then, if an update is
    needed, re-reads the file, rewrites the entity's docstring with the
    calculated hash and writes the file back as UTF-8.

    Args:
        entity_path_str: Entity path string (e.g., "src/foo.py:Bar")
        force: Force update even if hash matches
        repo_root: Repository root directory

    Returns:
        True if entity was updated, False otherwise

    Raises:
        FileNotFoundError: If entity file doesn't exist
        ValueError: If entity is not found in file or path is invalid
        NotImplementedError: For package/module level sync
    """
    status = inspect_entity(entity_path_str, repo_root)

    if not needs_update(status.recorded_hash, status.calculated_hash, force):
        return False

    entity_path = parse_entity_path(entity_path_str)
    resolved_path = resolve_entity_path(entity_path, repo_root)

    parser, source_code, tree = _parse_python_file(resolved_path)
    entity_node, entity_extent_node = _find_entity_nodes(
        parser, tree.root_node, source_code, entity_path
    )

    # inspect_entity() found the entity a moment ago; if it is gone now the
    # file changed underneath us, so fail with the documented error type
    # instead of passing None to the parser helpers.
    if entity_node is None:
        raise ValueError(f"Entity not found in file: {entity_path.entity_name}")

    current_docstring = parser._extract_docstring(entity_node, source_code)
    if current_docstring:
        current_docstring = current_docstring.strip()

    updated_docstring = parser.update_athena_tag(current_docstring, status.calculated_hash)

    entity_location = Location(
        start=entity_extent_node.start_point[0], end=entity_extent_node.end_point[0]
    )

    updated_source = update_docstring_in_source(
        source_code, entity_location, updated_docstring
    )

    resolved_path.write_text(updated_source, encoding="utf-8")

    return True
447
+
448
+
449
def collect_sub_entities(entity_path: EntityPath, repo_root: Path) -> list[str]:
    """Collect all sub-entities for a given entity path.

    For modules: returns all functions, classes, and methods
    For packages: returns the entities of every ``.py`` file found
        recursively below the package, skipping excluded paths and
        ``__init__.py`` files
    For classes: returns all methods whose name starts with
        ``"<class name>."``
    For anything else (functions, methods): returns an empty list

    Source files are read as UTF-8 rather than with the platform default
    encoding, so the result does not depend on the machine's locale.

    Args:
        entity_path: Parsed entity path
        repo_root: Repository root directory

    Returns:
        List of entity path strings (``"<file path>:<entity name>"``) for all
        sub-entities; empty if the path cannot be resolved
    """
    resolved_path = resolve_entity_path(entity_path, repo_root)
    if resolved_path is None:
        return []

    parser = PythonParser()
    # Entity kinds reported for packages and modules. Method names as
    # returned by the parser already carry their class prefix.
    reported_kinds = ("function", "class", "method")

    # Handle package (directory with __init__.py)
    if entity_path.is_package:
        sub_entities = []
        for py_file in resolved_path.rglob("*.py"):
            # Skip excluded paths
            if should_exclude_path(py_file, repo_root):
                continue

            if py_file.name == "__init__.py":
                continue  # Skip __init__ for now

            # Entity paths are expressed relative to the repo root
            rel_path = py_file.relative_to(repo_root)

            source_code = py_file.read_text(encoding="utf-8")
            entities = parser.extract_entities(source_code, str(rel_path))

            sub_entities.extend(
                f"{rel_path}:{entity.name}"
                for entity in entities
                if entity.kind in reported_kinds
            )

        return sub_entities

    # Handle module (single .py file)
    if entity_path.is_module:
        source_code = resolved_path.read_text(encoding="utf-8")
        entities = parser.extract_entities(source_code, entity_path.file_path)

        return [
            f"{entity_path.file_path}:{entity.name}"
            for entity in entities
            if entity.kind in reported_kinds
        ]

    # Handle class (return all methods)
    if entity_path.is_class:
        source_code = resolved_path.read_text(encoding="utf-8")
        entities = parser.extract_entities(source_code, entity_path.file_path)
        method_prefix = f"{entity_path.entity_name}."

        return [
            f"{entity_path.file_path}:{entity.name}"
            for entity in entities
            if entity.kind == "method" and entity.name.startswith(method_prefix)
        ]

    # For methods or functions, no sub-entities
    return []
531
+
532
+
533
def sync_recursive(entity_path_str: str, force: bool, repo_root: Path) -> int:
    """Sync an entity together with everything nested beneath it.

    What gets synced depends on the kind of path:
    - package or module: only the sub-entities reported by
      collect_sub_entities()
    - class: the class itself, followed by its sub-entities
    - function or method: just that entity

    An entity whose sync raises ValueError or FileNotFoundError is reported
    with a printed warning and skipped; the remaining entities are still
    processed.

    Args:
        entity_path_str: Entity path string
        force: Force update even if hash matches
        repo_root: Repository root directory

    Returns:
        Number of entities updated
    """
    parsed_path = parse_entity_path(entity_path_str)

    # Work out the list of entity paths to process.
    if parsed_path.is_package or parsed_path.is_module:
        targets = collect_sub_entities(parsed_path, repo_root)
    elif parsed_path.is_class:
        targets = [entity_path_str, *collect_sub_entities(parsed_path, repo_root)]
    else:
        targets = [entity_path_str]

    updated_total = 0
    for entity in targets:
        try:
            was_updated = sync_entity(entity, force, repo_root)
        except (ValueError, FileNotFoundError) as e:
            # Best effort: one failing entity must not stop the others.
            print(f"Warning: Failed to sync {entity}: {e}")
            continue
        if was_updated:
            updated_total += 1

    return updated_total