repgen-ai 0.1.0 (repgen_ai-0.1.0-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36)
  1. repgen/__init__.py +51 -0
  2. repgen/__pycache__/__init__.cpython-313.pyc +0 -0
  3. repgen/__pycache__/cli.cpython-313.pyc +0 -0
  4. repgen/__pycache__/core.cpython-313.pyc +0 -0
  5. repgen/__pycache__/server.cpython-313.pyc +0 -0
  6. repgen/__pycache__/utils.cpython-313.pyc +0 -0
  7. repgen/cli.py +375 -0
  8. repgen/core.py +239 -0
  9. repgen/retrieval/__init__.py +4 -0
  10. repgen/retrieval/__pycache__/__init__.cpython-313.pyc +0 -0
  11. repgen/retrieval/__pycache__/config.cpython-313.pyc +0 -0
  12. repgen/retrieval/__pycache__/pipeline.cpython-313.pyc +0 -0
  13. repgen/retrieval/config.py +53 -0
  14. repgen/retrieval/core/__init__.py +0 -0
  15. repgen/retrieval/core/__pycache__/__init__.cpython-313.pyc +0 -0
  16. repgen/retrieval/core/__pycache__/code_indexer.cpython-313.pyc +0 -0
  17. repgen/retrieval/core/__pycache__/dependency_analyzer.cpython-313.pyc +0 -0
  18. repgen/retrieval/core/__pycache__/module_analyzer.cpython-313.pyc +0 -0
  19. repgen/retrieval/core/__pycache__/training_code_detector.cpython-313.pyc +0 -0
  20. repgen/retrieval/core/__pycache__/utils.cpython-313.pyc +0 -0
  21. repgen/retrieval/core/code_indexer.py +138 -0
  22. repgen/retrieval/core/dependency_analyzer.py +121 -0
  23. repgen/retrieval/core/module_analyzer.py +65 -0
  24. repgen/retrieval/core/training_code_detector.py +240 -0
  25. repgen/retrieval/core/utils.py +52 -0
  26. repgen/retrieval/models/__init__.py +0 -0
  27. repgen/retrieval/models/__pycache__/__init__.cpython-313.pyc +0 -0
  28. repgen/retrieval/models/__pycache__/hybrid_search.cpython-313.pyc +0 -0
  29. repgen/retrieval/models/hybrid_search.py +151 -0
  30. repgen/retrieval/pipeline.py +166 -0
  31. repgen/server.py +111 -0
  32. repgen/utils.py +550 -0
  33. repgen_ai-0.1.0.dist-info/METADATA +199 -0
  34. repgen_ai-0.1.0.dist-info/RECORD +36 -0
  35. repgen_ai-0.1.0.dist-info/WHEEL +5 -0
  36. repgen_ai-0.1.0.dist-info/top_level.txt +1 -0
repgen/retrieval/config.py
@@ -0,0 +1,53 @@
+ import os
+ from pathlib import Path
+
+
+ class Config:
+     def __init__(self, repo_path: str, bug_report_path: str, output_dir: str):
+
+         self.BASE_DIR = Path(os.path.dirname(os.path.abspath(__file__))).parent
+
+         self.CODE_DIR = Path(repo_path).resolve()
+         self.BUG_REPORT_FILE = Path(bug_report_path).resolve()
+         self.OUTPUT_DIR = Path(output_dir).resolve()
+
+         # Unique project identifier based on repo name and bug report name
+         self.PROJECT_ID = f"{self.CODE_DIR.name}_{self.BUG_REPORT_FILE.stem}"
+         self.PROJECT_DIR = self.OUTPUT_DIR / self.PROJECT_ID
+
+         # --- Base paths for outputs ---
+         self.REFINED_BUG_REPORT_DIR = self.PROJECT_DIR / "refined_bug_report"
+         self.CONTEXT_DIR = self.PROJECT_DIR / "context"
+         self.PLANS_DIR = self.PROJECT_DIR / "plan"
+         self.REPRODUCTION_DIR = self.PROJECT_DIR / "reproduction_code"
+
+         # Set IN/OUT paths to the same consolidated directories
+         self.REFINED_BUG_REPORT_DIR_IN = self.REFINED_BUG_REPORT_DIR
+         self.REFINED_BUG_REPORT_DIR_OUT = self.REFINED_BUG_REPORT_DIR
+         self.CONTEXT_DIR_IN = self.CONTEXT_DIR
+         self.CONTEXT_DIR_OUT = self.CONTEXT_DIR
+         self.PLANS_DIR_IN = self.PLANS_DIR
+         self.PLANS_DIR_OUT = self.PLANS_DIR
+         self.REPRODUCTION_DIR_OUT = self.REPRODUCTION_DIR
+
+         # Model configurations
+         self.EMBEDDING_MODEL = (
+             "flax-sentence-embeddings/st-codesearch-distilroberta-base"
+         )
+         self.RERANKER_MODEL = "cross-encoder/mmarco-mMiniLMv2-L12-H384-v1"
+
+         # Search parameters
+         self.SEARCH_TOP_K = 200
+         self.RERANK_TOP_K = 20
+         self.ALPHA = 0.55
+
+         # Ensure directories exist
+         self._setup_directories()
+
+     def _setup_directories(self):
+         # Ensure base directories exist
+         self.PROJECT_DIR.mkdir(parents=True, exist_ok=True)
+         self.REFINED_BUG_REPORT_DIR.mkdir(parents=True, exist_ok=True)
+         self.CONTEXT_DIR.mkdir(parents=True, exist_ok=True)
+         self.PLANS_DIR.mkdir(parents=True, exist_ok=True)
+         self.REPRODUCTION_DIR.mkdir(parents=True, exist_ok=True)
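
For orientation, a minimal sketch of how a caller might construct this Config; the repo, bug-report, and output paths below are hypothetical, and constructing the object creates the output directories as a side effect:

    from repgen.retrieval.config import Config

    # Hypothetical paths; any repo checkout, bug-report file, and output directory work.
    cfg = Config(
        repo_path="./my_repo",
        bug_report_path="./bug_report.md",
        output_dir="./repgen_out",
    )
    print(cfg.PROJECT_ID)        # "my_repo_bug_report"
    print(cfg.REPRODUCTION_DIR)  # <output_dir>/<PROJECT_ID>/reproduction_code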
File without changes
repgen/retrieval/core/code_indexer.py
@@ -0,0 +1,138 @@
+ import ast
+ import logging
+ import time
+ from multiprocessing import Pool, cpu_count
+ from pathlib import Path
+ from typing import Any, Dict, List, Optional
+
+ from ..models.hybrid_search import HybridSearchIndex
+ from .utils import tokenize
+
+ logger = logging.getLogger(__name__)
+
+
+ class CodeIndexer:
+     def __init__(self, config):
+         self.config = config
+
+     def _load_file(self, file_path: Path) -> Optional[Dict[str, Any]]:
+         """Load a single file and return as document."""
+         try:
+             with open(file_path, "r", encoding="utf-8") as f:
+                 return {
+                     "page_content": f.read(),
+                     "metadata": {"source": str(file_path), "type": "file"},
+                 }
+         except Exception as e:
+             logger.warning(f"Error loading {file_path}: {e}")
+             return None
+
+     def _semantic_chunking(self, doc: Dict[str, Any]) -> List[Dict[str, Any]]:
+         """Split code into semantic chunks (functions, classes)."""
+         chunks = []
+         file_path = doc["metadata"]["source"]
+         code = doc["page_content"]
+
+         try:
+             tree = ast.parse(code)
+             current_chunk = []
+
+             for node in ast.walk(tree):
+                 if isinstance(
+                     node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)
+                 ):
+                     start_lineno = node.lineno - 1
+                     end_lineno = (
+                         node.end_lineno
+                         if hasattr(node, "end_lineno")
+                         else start_lineno + 1
+                     )
+                     chunk_code = "\n".join(code.splitlines()[start_lineno:end_lineno])
+                     current_chunk.append(chunk_code)
+
+                     if len(current_chunk) >= 3:
+                         chunk = {
+                             "page_content": "\n\n".join(current_chunk),
+                             "metadata": {
+                                 "source": file_path,
+                                 "type": "semantic_chunk",
+                                 "start_line": start_lineno + 1,
+                                 "end_line": end_lineno,
+                             },
+                         }
+                         chunks.append(chunk)
+                         current_chunk = []
+
+             if current_chunk:
+                 chunk = {
+                     "page_content": "\n\n".join(current_chunk),
+                     "metadata": {
+                         "source": file_path,
+                         "type": "semantic_chunk",
+                         "start_line": start_lineno + 1,
+                         "end_line": end_lineno,
+                     },
+                 }
+                 chunks.append(chunk)
+
+         except SyntaxError as e:
+             logger.warning(f"Syntax error in {file_path}: {e}")
+
+         return chunks
+
+     def _process_codebase(
+         self, code_dir: Path
+     ) -> tuple[List[Dict[str, Any]], List[List[str]]]:
+         """Process all files in the code directory."""
+         files = [
+             f
+             for f in code_dir.glob("**/*.py")
+             if "__pycache__" not in str(f)
+             and not any(part.startswith(".") for part in f.parts)
+         ]
+
+         with Pool(cpu_count()) as pool:
+             docs = pool.map(self._load_file, files)
+
+         valid_docs = [doc for doc in docs if doc is not None]
+         chunks = []
+         for doc in valid_docs:
+             chunks.extend(self._semantic_chunking(doc))
+
+         corpus = [tokenize(chunk["page_content"]) for chunk in chunks]
+         return chunks, corpus
+
+     def index_codebase(self, code_dir: Path) -> HybridSearchIndex:
+         """Index the codebase for hybrid search."""
+         start_time = time.time()
+         index_dir = self.config.PROJECT_DIR / "hybrid_index"
+
+         hybrid_index = HybridSearchIndex(
+             embedding_model=self.config.EMBEDDING_MODEL,
+             reranker_model=self.config.RERANKER_MODEL,
+             config=self.config,
+         )
+
+         if (index_dir / "documents.json").exists():
+             logger.info("Loading existing index")
+             hybrid_index.load_index(index_dir)
+             return hybrid_index
+
+         logger.info("Building new index")
+         chunks, corpus = self._process_codebase(code_dir)
+         hybrid_index.build_index(chunks, corpus)
+         hybrid_index.save_index(index_dir)
+
+         logger.info(f"Indexing completed in {time.time()-start_time:.2f}s")
+         return hybrid_index
+
+     def find_relevant_code(
+         self, bug_report: str, hybrid_index: HybridSearchIndex
+     ) -> List[dict]:
+         """Find code relevant to the bug report."""
+         return hybrid_index.search(
+             bug_report,
+             top_k=self.config.SEARCH_TOP_K,
+             rerank_top_k=self.config.RERANK_TOP_K,
+             alpha=self.config.ALPHA,
+         )
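
A hedged usage sketch of the indexer, assuming the Config shown above and the package installed; HybridSearchIndex itself lives in repgen/retrieval/models/hybrid_search.py (listed above but not part of this hunk), and the result shape is assumed to mirror the chunk dicts built by _semantic_chunking:

    from pathlib import Path
    from repgen.retrieval.config import Config
    from repgen.retrieval.core.code_indexer import CodeIndexer

    cfg = Config("./my_repo", "./bug_report.md", "./repgen_out")  # hypothetical paths
    indexer = CodeIndexer(cfg)

    # Builds a hybrid index under <PROJECT_DIR>/hybrid_index, or reloads it
    # if documents.json already exists there.
    index = indexer.index_codebase(cfg.CODE_DIR)

    bug_report = Path("./bug_report.md").read_text(encoding="utf-8")
    hits = indexer.find_relevant_code(bug_report, index)
    for hit in hits[:5]:
        # Assumed shape: same dicts as the semantic chunks (metadata + page_content).
        print(hit["metadata"]["source"], hit["metadata"].get("start_line"))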
repgen/retrieval/core/dependency_analyzer.py
@@ -0,0 +1,121 @@
+ import ast
+ import logging
+ from pathlib import Path
+ from typing import Any, Dict, List, Set, Tuple
+
+ logger = logging.getLogger(__name__)
+
+
+ class DependencyAnalyzer:
+     def __init__(self, config):
+         self.config = config
+         self.code_dir = Path(config.CODE_DIR).resolve()
+         self.ignore_modules = {
+             "os",
+             "sys",
+             "re",
+             "math",
+             "tensorflow",
+             "pytorch",
+             "numpy",
+             "pandas",
+         }
+
+     def _make_relative(self, path: Path) -> str:
+         """Convert path to relative path from the code directory."""
+         try:
+             return str(path.relative_to(self.code_dir))
+         except ValueError:
+             return str(path)
+
+     def _get_function_calls(self, node: ast.AST) -> Set[Tuple[str, str]]:
+         """Extract function calls from AST nodes."""
+         calls = set()
+
+         if isinstance(node, ast.Call):
+             if isinstance(node.func, ast.Attribute):
+                 # Handle method calls like x.y()
+                 if isinstance(node.func.value, ast.Name):
+                     calls.add((node.func.value.id, node.func.attr))
+             elif isinstance(node.func, ast.Name):
+                 # Handle direct function calls like x()
+                 calls.add(("", node.func.id))
+
+         return calls
+
+     def analyze_file_dependencies(self, file_path: Path) -> List[Dict[str, str]]:
+         """Analyze a file and return its external dependencies."""
+         try:
+             with open(file_path, "r", encoding="utf-8") as f:
+                 tree = ast.parse(f.read())
+         except Exception as e:
+             logger.debug(f"Could not parse {file_path}: {e}")
+             return []
+
+         imports = {}
+         dependencies = set()
+
+         # First pass: collect all imports
+         for node in ast.walk(tree):
+             if isinstance(node, ast.Import):
+                 for alias in node.names:
+                     module = alias.name.split(".")[0]
+                     if module not in self.ignore_modules:
+                         imports[alias.asname or alias.name.split(".")[0]] = module
+
+             elif isinstance(node, ast.ImportFrom):
+                 if node.module and node.level == 0:  # Only absolute imports
+                     module = node.module.split(".")[0]
+                     if module not in self.ignore_modules:
+                         for alias in node.names:
+                             imports[alias.asname or alias.name] = (
+                                 f"{module}.{alias.name}"
+                             )
+
+         # Second pass: find function calls
+         for node in ast.walk(tree):
+             for module, func in self._get_function_calls(node):
+                 if module in imports:
+                     # Resolve the actual module from imports
+                     actual_module = imports[module].split(".")[0]
+                     dependencies.add((actual_module, func))
+                 elif not module:  # Direct function call
+                     if func in imports:
+                         full_path = imports[func].split(".")
+                         if len(full_path) > 1:
+                             dependencies.add((full_path[0], full_path[1]))
+
+         # Convert to the desired output format
+         return [{"module": m, "function": f} for m, f in sorted(dependencies)]
+
+     def analyze_training_dependencies(
+         self, training_report: Dict[str, Any]
+     ) -> Dict[str, Any]:
+         """Analyze dependencies for all training files."""
+         report = []
+
+         for entry in training_report["training_files"]:
+             try:
+                 file_path = (self.code_dir / entry["file"]).resolve()
+                 if not file_path.exists():
+                     continue
+
+                 external_deps = self.analyze_file_dependencies(file_path)
+
+                 report.append(
+                     {
+                         "file": self._make_relative(file_path),
+                         "score": entry["score"],
+                         "contains_training": entry["contains_training"],
+                         "external_dependencies": external_deps,
+                     }
+                 )
+
+             except Exception as e:
+                 logger.warning(f"Error processing {entry.get('file')}: {e}")
+                 continue
+
+         return {
+             "bug_report": training_report["bug_report"],
+             "dependencies": sorted(report, key=lambda x: x["score"], reverse=True),
+         }
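
A small sketch of the per-file analysis in isolation, assuming the Config above; file paths are hypothetical, and calls into the ignored modules (os, sys, re, math, tensorflow, pytorch, numpy, pandas) are dropped:

    from pathlib import Path
    from repgen.retrieval.config import Config
    from repgen.retrieval.core.dependency_analyzer import DependencyAnalyzer

    cfg = Config("./my_repo", "./bug_report.md", "./repgen_out")  # hypothetical paths
    analyzer = DependencyAnalyzer(cfg)

    # For a file containing:
    #     from sklearn.datasets import load_iris
    #     data = load_iris()
    # the analyzer reports [{"module": "sklearn", "function": "load_iris"}].
    deps = analyzer.analyze_file_dependencies(Path("./my_repo/train.py"))
    print(deps)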
repgen/retrieval/core/module_analyzer.py
@@ -0,0 +1,65 @@
+ import os
+ from collections import defaultdict
+ from pathlib import Path
+ from typing import Any, Dict, List
+
+ from .utils import extract_module_path
+
+
+ class ModuleAnalyzer:
+     def __init__(self, config):
+         self.config = config
+         self.code_dir = Path(config.CODE_DIR)
+
+     def _make_relative(self, file_path: str) -> str:
+         """Convert absolute path to relative path from the code directory."""
+         try:
+             path = Path(file_path)
+             # Handle both str and Path inputs
+             if not isinstance(file_path, (str, Path)):
+                 return file_path
+
+             # Make relative to the code directory
+             relative_path = os.path.relpath(str(path), start=str(self.code_dir))
+             return relative_path.replace("\\", "/")  # Normalize to forward slashes
+         except (TypeError, ValueError):
+             return file_path  # Return original if conversion fails
+
+     def analyze_modules(
+         self, relevant_code: List[Dict[str, Any]]
+     ) -> List[Dict[str, Any]]:
+         """Group snippets by module and file using relative paths."""
+         module_map = defaultdict(lambda: defaultdict(list))
+
+         for snippet in relevant_code:
+             file_path = snippet["metadata"]["source"]
+             relative_path = self._make_relative(file_path)
+             module_path = extract_module_path(relative_path)
+
+             module_map[module_path][relative_path].append(
+                 {
+                     "start_line": snippet["metadata"].get("start_line"),
+                     "end_line": snippet["metadata"].get("end_line"),
+                     "code": snippet["page_content"],
+                 }
+             )
+
+         return [
+             {
+                 "module": module,
+                 "files": [
+                     {
+                         "path": file,
+                         "snippets": [
+                             {
+                                 "lines": f"{s['start_line']}-{s['end_line']}",
+                                 "code": s["code"],
+                             }
+                             for s in snippets
+                         ],
+                     }
+                     for file, snippets in files.items()
+                 ],
+             }
+             for module, files in module_map.items()
+         ]
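
A sketch of the expected input and output shape, assuming snippets in the format produced by CodeIndexer._semantic_chunking and the hypothetical paths used above:

    from repgen.retrieval.config import Config
    from repgen.retrieval.core.module_analyzer import ModuleAnalyzer

    cfg = Config("./my_repo", "./bug_report.md", "./repgen_out")  # hypothetical paths
    analyzer = ModuleAnalyzer(cfg)

    snippets = [  # shape matches the semantic chunks built by CodeIndexer
        {
            "page_content": "def train(model): ...",
            "metadata": {"source": "./my_repo/pkg/train.py", "start_line": 10, "end_line": 42},
        }
    ]
    grouped = analyzer.analyze_modules(snippets)
    # -> [{"module": "pkg", "files": [{"path": "pkg/train.py",
    #      "snippets": [{"lines": "10-42", "code": "def train(model): ..."}]}]}]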
repgen/retrieval/core/training_code_detector.py
@@ -0,0 +1,240 @@
+ import ast
+ import logging
+ import os
+ from pathlib import Path
+ from typing import Any, Dict, List, Tuple
+
+ import torch
+ from sentence_transformers import CrossEncoder
+ from transformers import AutoTokenizer
+
+ logger = logging.getLogger(__name__)
+
+
+ class TrainingLoopVisitor(ast.NodeVisitor):
+     """
+     AST visitor to detect training loops in:
+     - TensorFlow 1.x (Session.run in loops)
+     - TensorFlow 2.x (Keras APIs, GradientTape, and optimizer methods)
+     - PyTorch (optimizer steps in loops, DataLoader usage)
+     - PyTorch Lightning (training_step method)
+     """
+
+     def __init__(self):
+         # Framework flags
+         self.found_training = False
+
+         # Context tracking
+         self.loop_depth = 0  # Track nested loops
+         self.gradient_tape_depth = 0  # Track GradientTape contexts
+         self.current_class = None  # Track class context
+
+     # Context management ------------------------------------------------------
+     def visit_ClassDef(self, node):
+         """Track class context for LightningModule detection"""
+         self.current_class = node.name
+         self.generic_visit(node)
+         self.current_class = None
+
+     def visit_For(self, node):
+         self.loop_depth += 1
+         self.generic_visit(node)
+         self.loop_depth -= 1
+
+     def visit_While(self, node):
+         self.loop_depth += 1
+         self.generic_visit(node)
+         self.loop_depth -= 1
+
+     def visit_With(self, node):
+         """Handle GradientTape and PyTorch profiler contexts"""
+         # Check for GradientTape first
+         if any(self._is_gradient_tape(item.context_expr) for item in node.items):
+             self.gradient_tape_depth += 1
+             self.generic_visit(node)
+             self.gradient_tape_depth -= 1
+         else:
+             self.generic_visit(node)
+
+     # Core detection logic -----------------------------------------------------
+     def visit_Call(self, node):
+         """Analyze method calls for framework patterns"""
+         self._check_tf1_patterns(node)
+         self._check_tf2_patterns(node)
+         self._check_pytorch_patterns(node)
+         self.generic_visit(node)
+
+     def _check_tf1_patterns(self, node):
+         """TensorFlow 1.x: Session.run in loops with training ops"""
+         if self.loop_depth > 0:
+             if isinstance(node.func, ast.Attribute) and node.func.attr == "run":
+                 # Basic check for Session-like objects
+                 if isinstance(node.func.value, (ast.Name, ast.Attribute)):
+                     self.found_training = True
+
+     def _check_tf2_patterns(self, node):
+         """TensorFlow 2.x detection"""
+         if isinstance(node.func, ast.Attribute):
+             # High-level APIs
+             if node.func.attr in {"fit", "fit_generator", "train_on_batch"}:
+                 self.found_training = True
+
+             # Optimizer methods
+             if node.func.attr in {"apply_gradients", "minimize"}:
+                 self.found_training = True
+
+             # GradientTape context patterns
+             if self.gradient_tape_depth > 0:
+                 if node.func.attr in {"gradient", "watch"}:
+                     self.found_training = True
+
+     def _check_pytorch_patterns(self, node):
+         """PyTorch detection: Training steps in loops"""
+         if self.loop_depth > 0 and isinstance(node.func, ast.Attribute):
+             # Core training methods
+             if node.func.attr in {"backward", "step", "zero_grad"}:
+                 self.found_training = True
+
+             # Loss calculation patterns
+             if node.func.attr in {"item", "backward"} and self._is_loss_node(
+                 node.func.value
+             ):
+                 self.found_training = True
+
+             # DataLoader patterns
+             if node.func.attr in {"to", "cuda"} and self._is_dataloader_node(
+                 node.func.value
+             ):
+                 self.found_training = True
+
+     # Helper methods -----------------------------------------------------------
+     def _is_gradient_tape(self, node):
+         """Identify GradientTape usage (direct or via tf.GradientTape)"""
+         if isinstance(node, ast.Call):
+             func = node.func
+             return (
+                 isinstance(func, ast.Attribute) and func.attr == "GradientTape"
+             ) or (isinstance(func, ast.Name) and func.id == "GradientTape")
+         return False
+
+     def _is_loss_node(self, node):
+         """Heuristic to identify loss nodes (e.g., 'loss' in names)"""
+         if isinstance(node, ast.Name):
+             return "loss" in node.id.lower()
+         elif isinstance(node, ast.Attribute):
+             return "loss" in node.attr.lower()
+         return False
+
+     def _is_dataloader_node(self, node):
+         """Heuristic to identify DataLoader instances"""
+         if isinstance(node, ast.Name):
+             return "loader" in node.id.lower() or "dataloader" in node.id.lower()
+         elif isinstance(node, ast.Attribute):
+             return "loader" in node.attr.lower() or "dataloader" in node.attr.lower()
+         return False
+
+     # PyTorch Lightning support -----------------------------------------------
+     def visit_FunctionDef(self, node):
+         """Check for LightningModule training_step hooks"""
+         if self.current_class and node.name == "training_step":
+             self.found_training = True
+         self.generic_visit(node)
+
+
+ class TrainingCodeDetector:
+     def __init__(self, config):
+         self.config = config
+         self.reranker = CrossEncoder(config.RERANKER_MODEL)
+         self.tokenizer = AutoTokenizer.from_pretrained(config.RERANKER_MODEL)
+         self.device = "cuda" if torch.cuda.is_available() else "cpu"
+         self.reranker.to(self.device)
+
+     def contains_training(self, file_path: Path) -> bool:
+         """Check if a file contains training-related code using AST analysis."""
+         try:
+             with open(file_path, "r", encoding="utf-8") as f:
+                 code = f.read()
+             tree = ast.parse(code)
+         except Exception as e:
+             logger.error(f"Error parsing {file_path}: {str(e)}")
+             return False
+
+         visitor = TrainingLoopVisitor()
+         visitor.visit(tree)
+         return visitor.found_training
+
+     def _get_all_files(self, module_report: Dict[str, Any]) -> List[Path]:
+         """Get all Python files from modules in the report."""
+         # --- MODIFICATION: Ablation: No Module-Centric Partitioning ---
+         files = set()  # Use set to avoid duplicates
+
+         files = set()  # Use set to avoid duplicates
+
+         # Original: Walk all files in the retrieved modules' directories
+         for module in module_report["modules"]:
+             module_path = self.config.CODE_DIR / module["module"]
+             if not module_path.exists():
+                 continue
+
+             for root, _, filenames in os.walk(module_path):
+                 for filename in filenames:
+                     if filename.endswith(".py"):
+                         files.add(Path(root) / filename)
+
+         return list(files)
+
+     def _rank_files(
+         self, files: List[Path], bug_report: str
+     ) -> List[Tuple[Path, float]]:
+         """Rank files by relevance to the bug report."""
+         if not files:
+             return []
+
+         # Tokenize with proper truncation
+         features = self.tokenizer(
+             [bug_report[:100000]] * len(files),
+             [file.read_text(encoding="utf-8")[:100000] for file in files],
+             padding=True,
+             truncation="longest_first",
+             max_length=512,
+             return_tensors="pt",
+         ).to(self.device)
+
+         # Run cross-encoder
+         with torch.no_grad():
+             scores = self.reranker(**features).logits.squeeze().cpu().numpy()
+
+         # Normalize scores
+         if len(scores) > 1:
+             scores = (scores - scores.min()) / (scores.max() - scores.min() + 1e-8)
+
+         return sorted(zip(files, scores), key=lambda x: x[1], reverse=True)
+
+     def detect_training_code(
+         self, module_report: Dict[str, Any], bug_report_path: Path
+     ) -> Dict[str, Any]:
+         """Detect and rank training-related code."""
+         bug_report = bug_report_path.read_text(encoding="utf-8")
+         all_files = self._get_all_files(module_report)
+
+         training_files = []
+         training_files = []
+         for file in all_files:
+             if self.contains_training(file):
+                 training_files.append(file)
+
+         # Original ranking logic
+         ranked_files = self._rank_files(training_files, bug_report)
+         # --- END MODIFICATION ---
+
+         return {
+             "bug_report": module_report["bug_report"],
+             "training_files": [
+                 {
+                     "file": str(file.relative_to(self.config.CODE_DIR)),
+                     "score": float(score),
+                     "contains_training": True,
+                 }
+                 for file, score in ranked_files
+             ],
+         }
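
To illustrate the AST heuristic on its own, a minimal sketch; the loader, model, and optimizer names in the snippet are made up, and detection here only needs a .backward()/.step()-style call inside a loop:

    import ast
    from repgen.retrieval.core.training_code_detector import TrainingLoopVisitor

    snippet = """
    for batch in train_loader:
        loss = model(batch).sum()
        loss.backward()
        optimizer.step()
    """
    visitor = TrainingLoopVisitor()
    visitor.visit(ast.parse(snippet))
    print(visitor.found_training)  # True: backward()/step() calls inside a for loop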
repgen/retrieval/core/utils.py
@@ -0,0 +1,52 @@
+ import json
+ import logging
+ import os
+ import re
+ from pathlib import Path
+ from typing import Any, List
+
+
+ def extract_module_path(file_path: str, dataset_root: str = "dataset/") -> str:
+     """Extract full module path relative to dataset root."""
+     try:
+         file_path = os.path.normpath(file_path)
+         dataset_root = os.path.normpath(dataset_root)
+         parts = file_path.split(os.sep)
+
+         try:
+             root_idx = parts.index(dataset_root)
+         except ValueError:
+             return os.path.dirname(file_path)
+
+         module_parts = parts[root_idx + 1 : -1]  # Exclude filename
+         return os.path.join(*module_parts) if module_parts else "root"
+     except Exception as e:
+         logging.warning(f"Path parsing error {file_path}: {e}")
+         return "unknown"
+
+
+ def load_bug_report(bug_report_path: Path) -> str:
+     """Load bug report content."""
+     try:
+         with open(bug_report_path, "r", encoding="utf-8") as f:
+             return f.read()
+     except Exception as e:
+         raise ValueError(f"Error loading bug report: {str(e)}")
+
+
+ def save_json(data: Any, file_path: Path) -> None:
+     """Save data to JSON file."""
+     file_path.parent.mkdir(parents=True, exist_ok=True)
+     with open(file_path, "w", encoding="utf-8") as f:
+         json.dump(data, f, indent=2)
+
+
+ def load_json(file_path: Path) -> Any:
+     """Load data from JSON file."""
+     with open(file_path, "r", encoding="utf-8") as f:
+         return json.load(f)
+
+
+ def tokenize(text: str) -> List[str]:
+     """Basic tokenizer for code search."""
+     return re.findall(r"\b\w+[\w\d_]*\b", text.lower())
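
For reference, the tokenizer lowercases its input and splits on word boundaries, so identifiers keep their underscores:

    from repgen.retrieval.core.utils import tokenize

    print(tokenize("def Foo_bar(x): return x.baz()"))
    # ['def', 'foo_bar', 'x', 'return', 'x', 'baz']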
File without changes