repgen-ai 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- repgen/__init__.py +51 -0
- repgen/__pycache__/__init__.cpython-313.pyc +0 -0
- repgen/__pycache__/cli.cpython-313.pyc +0 -0
- repgen/__pycache__/core.cpython-313.pyc +0 -0
- repgen/__pycache__/server.cpython-313.pyc +0 -0
- repgen/__pycache__/utils.cpython-313.pyc +0 -0
- repgen/cli.py +375 -0
- repgen/core.py +239 -0
- repgen/retrieval/__init__.py +4 -0
- repgen/retrieval/__pycache__/__init__.cpython-313.pyc +0 -0
- repgen/retrieval/__pycache__/config.cpython-313.pyc +0 -0
- repgen/retrieval/__pycache__/pipeline.cpython-313.pyc +0 -0
- repgen/retrieval/config.py +53 -0
- repgen/retrieval/core/__init__.py +0 -0
- repgen/retrieval/core/__pycache__/__init__.cpython-313.pyc +0 -0
- repgen/retrieval/core/__pycache__/code_indexer.cpython-313.pyc +0 -0
- repgen/retrieval/core/__pycache__/dependency_analyzer.cpython-313.pyc +0 -0
- repgen/retrieval/core/__pycache__/module_analyzer.cpython-313.pyc +0 -0
- repgen/retrieval/core/__pycache__/training_code_detector.cpython-313.pyc +0 -0
- repgen/retrieval/core/__pycache__/utils.cpython-313.pyc +0 -0
- repgen/retrieval/core/code_indexer.py +138 -0
- repgen/retrieval/core/dependency_analyzer.py +121 -0
- repgen/retrieval/core/module_analyzer.py +65 -0
- repgen/retrieval/core/training_code_detector.py +240 -0
- repgen/retrieval/core/utils.py +52 -0
- repgen/retrieval/models/__init__.py +0 -0
- repgen/retrieval/models/__pycache__/__init__.cpython-313.pyc +0 -0
- repgen/retrieval/models/__pycache__/hybrid_search.cpython-313.pyc +0 -0
- repgen/retrieval/models/hybrid_search.py +151 -0
- repgen/retrieval/pipeline.py +166 -0
- repgen/server.py +111 -0
- repgen/utils.py +550 -0
- repgen_ai-0.1.0.dist-info/METADATA +199 -0
- repgen_ai-0.1.0.dist-info/RECORD +36 -0
- repgen_ai-0.1.0.dist-info/WHEEL +5 -0
- repgen_ai-0.1.0.dist-info/top_level.txt +1 -0
repgen/retrieval/config.py
@@ -0,0 +1,53 @@
+import os
+from pathlib import Path
+
+
+class Config:
+    def __init__(self, repo_path: str, bug_report_path: str, output_dir: str):
+
+        self.BASE_DIR = Path(os.path.dirname(os.path.abspath(__file__))).parent
+
+        self.CODE_DIR = Path(repo_path).resolve()
+        self.BUG_REPORT_FILE = Path(bug_report_path).resolve()
+        self.OUTPUT_DIR = Path(output_dir).resolve()
+
+        # Unique project identifier based on repo name and bug report name
+        self.PROJECT_ID = f"{self.CODE_DIR.name}_{self.BUG_REPORT_FILE.stem}"
+        self.PROJECT_DIR = self.OUTPUT_DIR / self.PROJECT_ID
+
+        # --- Base paths for outputs ---
+        self.REFINED_BUG_REPORT_DIR = self.PROJECT_DIR / "refined_bug_report"
+        self.CONTEXT_DIR = self.PROJECT_DIR / "context"
+        self.PLANS_DIR = self.PROJECT_DIR / "plan"
+        self.REPRODUCTION_DIR = self.PROJECT_DIR / "reproduction_code"
+
+        # Set IN/OUT paths to the same consolidated directories
+        self.REFINED_BUG_REPORT_DIR_IN = self.REFINED_BUG_REPORT_DIR
+        self.REFINED_BUG_REPORT_DIR_OUT = self.REFINED_BUG_REPORT_DIR
+        self.CONTEXT_DIR_IN = self.CONTEXT_DIR
+        self.CONTEXT_DIR_OUT = self.CONTEXT_DIR
+        self.PLANS_DIR_IN = self.PLANS_DIR
+        self.PLANS_DIR_OUT = self.PLANS_DIR
+        self.REPRODUCTION_DIR_OUT = self.REPRODUCTION_DIR
+
+        # Model configurations
+        self.EMBEDDING_MODEL = (
+            "flax-sentence-embeddings/st-codesearch-distilroberta-base"
+        )
+        self.RERANKER_MODEL = "cross-encoder/mmarco-mMiniLMv2-L12-H384-v1"
+
+        # Search parameters
+        self.SEARCH_TOP_K = 200
+        self.RERANK_TOP_K = 20
+        self.ALPHA = 0.55
+
+        # Ensure directories exist
+        self._setup_directories()
+
+    def _setup_directories(self):
+        # Ensure base directories exist
+        self.PROJECT_DIR.mkdir(parents=True, exist_ok=True)
+        self.REFINED_BUG_REPORT_DIR.mkdir(parents=True, exist_ok=True)
+        self.CONTEXT_DIR.mkdir(parents=True, exist_ok=True)
+        self.PLANS_DIR.mkdir(parents=True, exist_ok=True)
+        self.REPRODUCTION_DIR.mkdir(parents=True, exist_ok=True)
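For orientation, a minimal usage sketch of the Config class above; the repository, bug-report, and output paths are hypothetical and not part of the package:

from repgen.retrieval.config import Config

cfg = Config(
    repo_path="./my_repo",              # repository under analysis (hypothetical)
    bug_report_path="./bug_report.md",  # bug report file (hypothetical)
    output_dir="./repgen_output",       # root for per-project artifacts (hypothetical)
)
print(cfg.PROJECT_ID)   # "<repo name>_<bug report stem>", e.g. "my_repo_bug_report"
print(cfg.PROJECT_DIR)  # <output_dir>/<PROJECT_ID>, created eagerly by _setup_directories()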
File without changes: repgen/retrieval/core/__init__.py
Binary files (no textual diff shown): repgen/retrieval/core/__pycache__/*.cpython-313.pyc (6 files)
repgen/retrieval/core/code_indexer.py
@@ -0,0 +1,138 @@
+import ast
+import logging
+import time
+from multiprocessing import Pool, cpu_count
+from pathlib import Path
+from typing import Any, Dict, List, Optional
+
+from ..models.hybrid_search import HybridSearchIndex
+from .utils import tokenize
+
+logger = logging.getLogger(__name__)
+
+
+class CodeIndexer:
+    def __init__(self, config):
+        self.config = config
+
+    def _load_file(self, file_path: Path) -> Optional[Dict[str, Any]]:
+        """Load a single file and return as document."""
+        try:
+            with open(file_path, "r", encoding="utf-8") as f:
+                return {
+                    "page_content": f.read(),
+                    "metadata": {"source": str(file_path), "type": "file"},
+                }
+        except Exception as e:
+            logger.warning(f"Error loading {file_path}: {e}")
+            return None
+
+    def _semantic_chunking(self, doc: Dict[str, Any]) -> List[Dict[str, Any]]:
+        """Split code into semantic chunks (functions, classes)."""
+        chunks = []
+        file_path = doc["metadata"]["source"]
+        code = doc["page_content"]
+
+        try:
+            tree = ast.parse(code)
+            current_chunk = []
+
+            for node in ast.walk(tree):
+                if isinstance(
+                    node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)
+                ):
+                    start_lineno = node.lineno - 1
+                    end_lineno = (
+                        node.end_lineno
+                        if hasattr(node, "end_lineno")
+                        else start_lineno + 1
+                    )
+                    chunk_code = "\n".join(code.splitlines()[start_lineno:end_lineno])
+                    current_chunk.append(chunk_code)
+
+                    if len(current_chunk) >= 3:
+                        chunk = {
+                            "page_content": "\n\n".join(current_chunk),
+                            "metadata": {
+                                "source": file_path,
+                                "type": "semantic_chunk",
+                                "start_line": start_lineno + 1,
+                                "end_line": end_lineno,
+                            },
+                        }
+                        chunks.append(chunk)
+                        current_chunk = []
+
+            if current_chunk:
+                chunk = {
+                    "page_content": "\n\n".join(current_chunk),
+                    "metadata": {
+                        "source": file_path,
+                        "type": "semantic_chunk",
+                        "start_line": start_lineno + 1,
+                        "end_line": end_lineno,
+                    },
+                }
+                chunks.append(chunk)
+
+        except SyntaxError as e:
+            logger.warning(f"Syntax error in {file_path}: {e}")
+
+        return chunks
+
+    def _process_codebase(
+        self, code_dir: Path
+    ) -> tuple[List[Dict[str, Any]], List[List[str]]]:
+        """Process all files in the code directory."""
+        files = [
+            f
+            for f in code_dir.glob("**/*.py")
+            if "__pycache__" not in str(f)
+            and not any(part.startswith(".") for part in f.parts)
+        ]
+
+        with Pool(cpu_count()) as pool:
+            docs = pool.map(self._load_file, files)
+
+        valid_docs = [doc for doc in docs if doc is not None]
+        chunks = []
+        for doc in valid_docs:
+            chunks.extend(self._semantic_chunking(doc))
+
+        corpus = [tokenize(chunk["page_content"]) for chunk in chunks]
+        return chunks, corpus
+
+    def index_codebase(self, code_dir: Path) -> HybridSearchIndex:
+        """Index the codebase for hybrid search."""
+        start_time = time.time()
+        index_dir = self.config.PROJECT_DIR / "hybrid_index"
+
+        hybrid_index = HybridSearchIndex(
+            embedding_model=self.config.EMBEDDING_MODEL,
+            reranker_model=self.config.RERANKER_MODEL,
+            config=self.config,
+        )
+
+        if (index_dir / "documents.json").exists():
+            logger.info("Loading existing index")
+            hybrid_index.load_index(index_dir)
+            return hybrid_index
+
+        logger.info("Building new index")
+        chunks, corpus = self._process_codebase(code_dir)
+        hybrid_index.build_index(chunks, corpus)
+        hybrid_index.save_index(index_dir)
+
+        logger.info(f"Indexing completed in {time.time()-start_time:.2f}s")
+        return hybrid_index
+
+    def find_relevant_code(
+        self, bug_report: str, hybrid_index: HybridSearchIndex
+    ) -> List[dict]:
+        """Find code relevant to the bug report."""
+        return hybrid_index.search(
+            bug_report,
+            top_k=self.config.SEARCH_TOP_K,
+            rerank_top_k=self.config.RERANK_TOP_K,
+            alpha=self.config.ALPHA,
+        )
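A hedged sketch of the indexing flow this class exposes. The Config arguments are hypothetical, and HybridSearchIndex.search() is assumed to return the chunk dictionaries built in _semantic_chunking (the hybrid_search.py hunk is not shown in this section):

from repgen.retrieval.config import Config
from repgen.retrieval.core.code_indexer import CodeIndexer
from repgen.retrieval.core.utils import load_bug_report

cfg = Config("./my_repo", "./bug_report.md", "./repgen_output")  # hypothetical paths
indexer = CodeIndexer(cfg)

# Builds the index on first use, reloads it from PROJECT_DIR/hybrid_index afterwards
index = indexer.index_codebase(cfg.CODE_DIR)

bug_report = load_bug_report(cfg.BUG_REPORT_FILE)
for hit in indexer.find_relevant_code(bug_report, index)[:5]:
    # Assumes each hit carries the chunk metadata set during indexing
    print(hit["metadata"]["source"], hit["metadata"].get("start_line"))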
repgen/retrieval/core/dependency_analyzer.py
@@ -0,0 +1,121 @@
+import ast
+import logging
+from pathlib import Path
+from typing import Any, Dict, List, Set, Tuple
+
+logger = logging.getLogger(__name__)
+
+
+class DependencyAnalyzer:
+    def __init__(self, config):
+        self.config = config
+        self.code_dir = Path(config.CODE_DIR).resolve()
+        self.ignore_modules = {
+            "os",
+            "sys",
+            "re",
+            "math",
+            "tensorflow",
+            "pytorch",
+            "numpy",
+            "pandas",
+        }
+
+    def _make_relative(self, path: Path) -> str:
+        """Convert path to relative path from the code directory."""
+        try:
+            return str(path.relative_to(self.code_dir))
+        except ValueError:
+            return str(path)
+
+    def _get_function_calls(self, node: ast.AST) -> Set[Tuple[str, str]]:
+        """Extract function calls from AST nodes."""
+        calls = set()
+
+        if isinstance(node, ast.Call):
+            if isinstance(node.func, ast.Attribute):
+                # Handle method calls like x.y()
+                if isinstance(node.func.value, ast.Name):
+                    calls.add((node.func.value.id, node.func.attr))
+            elif isinstance(node.func, ast.Name):
+                # Handle direct function calls like x()
+                calls.add(("", node.func.id))
+
+        return calls
+
+    def analyze_file_dependencies(self, file_path: Path) -> List[Dict[str, str]]:
+        """Analyze a file and return its external dependencies."""
+        try:
+            with open(file_path, "r", encoding="utf-8") as f:
+                tree = ast.parse(f.read())
+        except Exception as e:
+            logger.debug(f"Could not parse {file_path}: {e}")
+            return []
+
+        imports = {}
+        dependencies = set()
+
+        # First pass: collect all imports
+        for node in ast.walk(tree):
+            if isinstance(node, ast.Import):
+                for alias in node.names:
+                    module = alias.name.split(".")[0]
+                    if module not in self.ignore_modules:
+                        imports[alias.asname or alias.name.split(".")[0]] = module
+
+            elif isinstance(node, ast.ImportFrom):
+                if node.module and node.level == 0:  # Only absolute imports
+                    module = node.module.split(".")[0]
+                    if module not in self.ignore_modules:
+                        for alias in node.names:
+                            imports[alias.asname or alias.name] = (
+                                f"{module}.{alias.name}"
+                            )
+
+        # Second pass: find function calls
+        for node in ast.walk(tree):
+            for module, func in self._get_function_calls(node):
+                if module in imports:
+                    # Resolve the actual module from imports
+                    actual_module = imports[module].split(".")[0]
+                    dependencies.add((actual_module, func))
+                elif not module:  # Direct function call
+                    if func in imports:
+                        full_path = imports[func].split(".")
+                        if len(full_path) > 1:
+                            dependencies.add((full_path[0], full_path[1]))
+
+        # Convert to the desired output format
+        return [{"module": m, "function": f} for m, f in sorted(dependencies)]
+
+    def analyze_training_dependencies(
+        self, training_report: Dict[str, Any]
+    ) -> Dict[str, Any]:
+        """Analyze dependencies for all training files."""
+        report = []
+
+        for entry in training_report["training_files"]:
+            try:
+                file_path = (self.code_dir / entry["file"]).resolve()
+                if not file_path.exists():
+                    continue
+
+                external_deps = self.analyze_file_dependencies(file_path)
+
+                report.append(
+                    {
+                        "file": self._make_relative(file_path),
+                        "score": entry["score"],
+                        "contains_training": entry["contains_training"],
+                        "external_dependencies": external_deps,
+                    }
+                )
+
+            except Exception as e:
+                logger.warning(f"Error processing {entry.get('file')}: {e}")
+                continue
+
+        return {
+            "bug_report": training_report["bug_report"],
+            "dependencies": sorted(report, key=lambda x: x["score"], reverse=True),
+        }
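A small sketch of what analyze_file_dependencies() reports for a single file; the paths are hypothetical and the expected output follows from the two AST passes above:

from pathlib import Path
from repgen.retrieval.config import Config
from repgen.retrieval.core.dependency_analyzer import DependencyAnalyzer

cfg = Config("./my_repo", "./bug_report.md", "./repgen_output")  # hypothetical paths
analyzer = DependencyAnalyzer(cfg)

# For a file containing, say,
#     from sklearn.metrics import f1_score
#     f1_score(y_true, y_pred)
# the result would be [{"module": "sklearn", "function": "f1_score"}].
# Modules listed in ignore_modules (numpy, pandas, tensorflow, ...) are skipped entirely.
print(analyzer.analyze_file_dependencies(Path("./my_repo/train.py")))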
repgen/retrieval/core/module_analyzer.py
@@ -0,0 +1,65 @@
+import os
+from collections import defaultdict
+from pathlib import Path
+from typing import Any, Dict, List
+
+from .utils import extract_module_path
+
+
+class ModuleAnalyzer:
+    def __init__(self, config):
+        self.config = config
+        self.code_dir = Path(config.CODE_DIR)
+
+    def _make_relative(self, file_path: str) -> str:
+        """Convert absolute path to relative path from the code directory."""
+        try:
+            path = Path(file_path)
+            # Handle both str and Path inputs
+            if not isinstance(file_path, (str, Path)):
+                return file_path
+
+            # Make relative to the code directory
+            relative_path = os.path.relpath(str(path), start=str(self.code_dir))
+            return relative_path.replace("\\", "/")  # Normalize to forward slashes
+        except (TypeError, ValueError):
+            return file_path  # Return original if conversion fails
+
+    def analyze_modules(
+        self, relevant_code: List[Dict[str, Any]]
+    ) -> List[Dict[str, Any]]:
+        """Group snippets by module and file using relative paths."""
+        module_map = defaultdict(lambda: defaultdict(list))
+
+        for snippet in relevant_code:
+            file_path = snippet["metadata"]["source"]
+            relative_path = self._make_relative(file_path)
+            module_path = extract_module_path(relative_path)
+
+            module_map[module_path][relative_path].append(
+                {
+                    "start_line": snippet["metadata"].get("start_line"),
+                    "end_line": snippet["metadata"].get("end_line"),
+                    "code": snippet["page_content"],
+                }
+            )
+
+        return [
+            {
+                "module": module,
+                "files": [
+                    {
+                        "path": file,
+                        "snippets": [
+                            {
+                                "lines": f"{s['start_line']}-{s['end_line']}",
+                                "code": s["code"],
+                            }
+                            for s in snippets
+                        ],
+                    }
+                    for file, snippets in files.items()
+                ],
+            }
+            for module, files in module_map.items()
+        ]
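A sketch of the grouping step, fed one hand-made snippet in the same shape the indexer produces; all values are hypothetical:

from repgen.retrieval.config import Config
from repgen.retrieval.core.module_analyzer import ModuleAnalyzer

cfg = Config("./my_repo", "./bug_report.md", "./repgen_output")  # hypothetical paths
analyzer = ModuleAnalyzer(cfg)

snippet = {
    "page_content": "def train():\n    ...",
    "metadata": {
        "source": "./my_repo/models/trainer.py",  # hypothetical source file
        "start_line": 10,
        "end_line": 42,
    },
}
# Roughly: [{"module": "models", "files": [{"path": "models/trainer.py",
#            "snippets": [{"lines": "10-42", "code": "def train():\n    ..."}]}]}]
print(analyzer.analyze_modules([snippet]))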
repgen/retrieval/core/training_code_detector.py
@@ -0,0 +1,240 @@
+import ast
+import logging
+import os
+from pathlib import Path
+from typing import Any, Dict, List, Tuple
+
+import torch
+from sentence_transformers import CrossEncoder
+from transformers import AutoTokenizer
+
+logger = logging.getLogger(__name__)
+
+
+class TrainingLoopVisitor(ast.NodeVisitor):
+    """
+    AST visitor to detect training loops in:
+    - TensorFlow 1.x (Session.run in loops)
+    - TensorFlow 2.x (Keras APIs, GradientTape, and optimizer methods)
+    - PyTorch (optimizer steps in loops, DataLoader usage)
+    - PyTorch Lightning (training_step method)
+    """
+
+    def __init__(self):
+        # Framework flags
+        self.found_training = False
+
+        # Context tracking
+        self.loop_depth = 0  # Track nested loops
+        self.gradient_tape_depth = 0  # Track GradientTape contexts
+        self.current_class = None  # Track class context
+
+    # Context management ------------------------------------------------------
+    def visit_ClassDef(self, node):
+        """Track class context for LightningModule detection"""
+        self.current_class = node.name
+        self.generic_visit(node)
+        self.current_class = None
+
+    def visit_For(self, node):
+        self.loop_depth += 1
+        self.generic_visit(node)
+        self.loop_depth -= 1
+
+    def visit_While(self, node):
+        self.loop_depth += 1
+        self.generic_visit(node)
+        self.loop_depth -= 1
+
+    def visit_With(self, node):
+        """Handle GradientTape and PyTorch profiler contexts"""
+        # Check for GradientTape first
+        if any(self._is_gradient_tape(item.context_expr) for item in node.items):
+            self.gradient_tape_depth += 1
+            self.generic_visit(node)
+            self.gradient_tape_depth -= 1
+        else:
+            self.generic_visit(node)
+
+    # Core detection logic -----------------------------------------------------
+    def visit_Call(self, node):
+        """Analyze method calls for framework patterns"""
+        self._check_tf1_patterns(node)
+        self._check_tf2_patterns(node)
+        self._check_pytorch_patterns(node)
+        self.generic_visit(node)
+
+    def _check_tf1_patterns(self, node):
+        """TensorFlow 1.x: Session.run in loops with training ops"""
+        if self.loop_depth > 0:
+            if isinstance(node.func, ast.Attribute) and node.func.attr == "run":
+                # Basic check for Session-like objects
+                if isinstance(node.func.value, (ast.Name, ast.Attribute)):
+                    self.found_training = True
+
+    def _check_tf2_patterns(self, node):
+        """TensorFlow 2.x detection"""
+        if isinstance(node.func, ast.Attribute):
+            # High-level APIs
+            if node.func.attr in {"fit", "fit_generator", "train_on_batch"}:
+                self.found_training = True
+
+            # Optimizer methods
+            if node.func.attr in {"apply_gradients", "minimize"}:
+                self.found_training = True
+
+            # GradientTape context patterns
+            if self.gradient_tape_depth > 0:
+                if node.func.attr in {"gradient", "watch"}:
+                    self.found_training = True
+
+    def _check_pytorch_patterns(self, node):
+        """PyTorch detection: Training steps in loops"""
+        if self.loop_depth > 0 and isinstance(node.func, ast.Attribute):
+            # Core training methods
+            if node.func.attr in {"backward", "step", "zero_grad"}:
+                self.found_training = True
+
+            # Loss calculation patterns
+            if node.func.attr in {"item", "backward"} and self._is_loss_node(
+                node.func.value
+            ):
+                self.found_training = True
+
+            # DataLoader patterns
+            if node.func.attr in {"to", "cuda"} and self._is_dataloader_node(
+                node.func.value
+            ):
+                self.found_training = True
+
+    # Helper methods -----------------------------------------------------------
+    def _is_gradient_tape(self, node):
+        """Identify GradientTape usage (direct or via tf.GradientTape)"""
+        if isinstance(node, ast.Call):
+            func = node.func
+            return (
+                isinstance(func, ast.Attribute) and func.attr == "GradientTape"
+            ) or (isinstance(func, ast.Name) and func.id == "GradientTape")
+        return False
+
+    def _is_loss_node(self, node):
+        """Heuristic to identify loss nodes (e.g., 'loss' in names)"""
+        if isinstance(node, ast.Name):
+            return "loss" in node.id.lower()
+        elif isinstance(node, ast.Attribute):
+            return "loss" in node.attr.lower()
+        return False
+
+    def _is_dataloader_node(self, node):
+        """Heuristic to identify DataLoader instances"""
+        if isinstance(node, ast.Name):
+            return "loader" in node.id.lower() or "dataloader" in node.id.lower()
+        elif isinstance(node, ast.Attribute):
+            return "loader" in node.attr.lower() or "dataloader" in node.attr.lower()
+        return False
+
+    # PyTorch Lightning support -----------------------------------------------
+    def visit_FunctionDef(self, node):
+        """Check for LightningModule training_step hooks"""
+        if self.current_class and node.name == "training_step":
+            self.found_training = True
+        self.generic_visit(node)
+
+
+class TrainingCodeDetector:
+    def __init__(self, config):
+        self.config = config
+        self.reranker = CrossEncoder(config.RERANKER_MODEL)
+        self.tokenizer = AutoTokenizer.from_pretrained(config.RERANKER_MODEL)
+        self.device = "cuda" if torch.cuda.is_available() else "cpu"
+        self.reranker.to(self.device)
+
+    def contains_training(self, file_path: Path) -> bool:
+        """Check if a file contains training-related code using AST analysis."""
+        try:
+            with open(file_path, "r", encoding="utf-8") as f:
+                code = f.read()
+            tree = ast.parse(code)
+        except Exception as e:
+            logger.error(f"Error parsing {file_path}: {str(e)}")
+            return False
+
+        visitor = TrainingLoopVisitor()
+        visitor.visit(tree)
+        return visitor.found_training
+
+    def _get_all_files(self, module_report: Dict[str, Any]) -> List[Path]:
+        """Get all Python files from modules in the report."""
+        # --- MODIFICATION: Ablation: No Module-Centric Partitioning ---
+        files = set()  # Use set to avoid duplicates
+
+        files = set()  # Use set to avoid duplicates
+
+        # Original: Walk all files in the retrieved modules' directories
+        for module in module_report["modules"]:
+            module_path = self.config.CODE_DIR / module["module"]
+            if not module_path.exists():
+                continue
+
+            for root, _, filenames in os.walk(module_path):
+                for filename in filenames:
+                    if filename.endswith(".py"):
+                        files.add(Path(root) / filename)
+
+        return list(files)
+
+    def _rank_files(
+        self, files: List[Path], bug_report: str
+    ) -> List[Tuple[Path, float]]:
+        """Rank files by relevance to the bug report."""
+        if not files:
+            return []
+
+        # Tokenize with proper truncation
+        features = self.tokenizer(
+            [bug_report[:100000]] * len(files),
+            [file.read_text(encoding="utf-8")[:100000] for file in files],
+            padding=True,
+            truncation="longest_first",
+            max_length=512,
+            return_tensors="pt",
+        ).to(self.device)
+
+        # Run cross-encoder
+        with torch.no_grad():
+            scores = self.reranker(**features).logits.squeeze().cpu().numpy()
+
+        # Normalize scores
+        if len(scores) > 1:
+            scores = (scores - scores.min()) / (scores.max() - scores.min() + 1e-8)
+
+        return sorted(zip(files, scores), key=lambda x: x[1], reverse=True)
+
+    def detect_training_code(
+        self, module_report: Dict[str, Any], bug_report_path: Path
+    ) -> Dict[str, Any]:
+        """Detect and rank training-related code."""
+        bug_report = bug_report_path.read_text(encoding="utf-8")
+        all_files = self._get_all_files(module_report)
+
+        training_files = []
+        training_files = []
+        for file in all_files:
+            if self.contains_training(file):
+                training_files.append(file)
+
+        # Original ranking logic
+        ranked_files = self._rank_files(training_files, bug_report)
+        # --- END MODIFICATION ---
+
+        return {
+            "bug_report": module_report["bug_report"],
+            "training_files": [
+                {
+                    "file": str(file.relative_to(self.config.CODE_DIR)),
+                    "score": float(score),
+                    "contains_training": True,
+                }
+                for file, score in ranked_files
+            ],
+        }
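The AST heuristics above can be exercised on their own, without loading the cross-encoder; a minimal check on a toy PyTorch-style loop (assumes the package's torch / sentence-transformers / transformers dependencies are importable):

import ast
from repgen.retrieval.core.training_code_detector import TrainingLoopVisitor

sample = """
for batch in loader:
    loss = model(batch)
    loss.backward()
    optimizer.step()
"""
visitor = TrainingLoopVisitor()
visitor.visit(ast.parse(sample))
print(visitor.found_training)  # True: backward()/step() calls inside a loop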
repgen/retrieval/core/utils.py
@@ -0,0 +1,52 @@
+import json
+import logging
+import os
+import re
+from pathlib import Path
+from typing import Any, List
+
+
+def extract_module_path(file_path: str, dataset_root: str = "dataset/") -> str:
+    """Extract full module path relative to dataset root."""
+    try:
+        file_path = os.path.normpath(file_path)
+        dataset_root = os.path.normpath(dataset_root)
+        parts = file_path.split(os.sep)
+
+        try:
+            root_idx = parts.index(dataset_root)
+        except ValueError:
+            return os.path.dirname(file_path)
+
+        module_parts = parts[root_idx + 1 : -1]  # Exclude filename
+        return os.path.join(*module_parts) if module_parts else "root"
+    except Exception as e:
+        logging.warning(f"Path parsing error {file_path}: {e}")
+        return "unknown"
+
+
+def load_bug_report(bug_report_path: Path) -> str:
+    """Load bug report content."""
+    try:
+        with open(bug_report_path, "r", encoding="utf-8") as f:
+            return f.read()
+    except Exception as e:
+        raise ValueError(f"Error loading bug report: {str(e)}")
+
+
+def save_json(data: Any, file_path: Path) -> None:
+    """Save data to JSON file."""
+    file_path.parent.mkdir(parents=True, exist_ok=True)
+    with open(file_path, "w", encoding="utf-8") as f:
+        json.dump(data, f, indent=2)
+
+
+def load_json(file_path: Path) -> Any:
+    """Load data from JSON file."""
+    with open(file_path, "r", encoding="utf-8") as f:
+        return json.load(f)
+
+
+def tokenize(text: str) -> List[str]:
+    """Basic tokenizer for code search."""
+    return re.findall(r"\b\w+[\w\d_]*\b", text.lower())
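Two of the helpers above are pure functions and easy to sanity-check; the expected values follow directly from the code (POSIX path separators assumed):

from repgen.retrieval.core.utils import extract_module_path, tokenize

print(tokenize("HybridSearchIndex.search(top_k=200)"))
# -> ['hybridsearchindex', 'search', 'top_k', '200']

print(extract_module_path("dataset/models/resnet/train.py"))
# -> 'models/resnet' (components between the dataset root and the filename)

print(extract_module_path("models/train.py"))
# -> 'models' (falls back to the file's directory when no dataset root is found)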
File without changes: repgen/retrieval/models/__init__.py