repgen_ai-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36)
  1. repgen/__init__.py +51 -0
  2. repgen/__pycache__/__init__.cpython-313.pyc +0 -0
  3. repgen/__pycache__/cli.cpython-313.pyc +0 -0
  4. repgen/__pycache__/core.cpython-313.pyc +0 -0
  5. repgen/__pycache__/server.cpython-313.pyc +0 -0
  6. repgen/__pycache__/utils.cpython-313.pyc +0 -0
  7. repgen/cli.py +375 -0
  8. repgen/core.py +239 -0
  9. repgen/retrieval/__init__.py +4 -0
  10. repgen/retrieval/__pycache__/__init__.cpython-313.pyc +0 -0
  11. repgen/retrieval/__pycache__/config.cpython-313.pyc +0 -0
  12. repgen/retrieval/__pycache__/pipeline.cpython-313.pyc +0 -0
  13. repgen/retrieval/config.py +53 -0
  14. repgen/retrieval/core/__init__.py +0 -0
  15. repgen/retrieval/core/__pycache__/__init__.cpython-313.pyc +0 -0
  16. repgen/retrieval/core/__pycache__/code_indexer.cpython-313.pyc +0 -0
  17. repgen/retrieval/core/__pycache__/dependency_analyzer.cpython-313.pyc +0 -0
  18. repgen/retrieval/core/__pycache__/module_analyzer.cpython-313.pyc +0 -0
  19. repgen/retrieval/core/__pycache__/training_code_detector.cpython-313.pyc +0 -0
  20. repgen/retrieval/core/__pycache__/utils.cpython-313.pyc +0 -0
  21. repgen/retrieval/core/code_indexer.py +138 -0
  22. repgen/retrieval/core/dependency_analyzer.py +121 -0
  23. repgen/retrieval/core/module_analyzer.py +65 -0
  24. repgen/retrieval/core/training_code_detector.py +240 -0
  25. repgen/retrieval/core/utils.py +52 -0
  26. repgen/retrieval/models/__init__.py +0 -0
  27. repgen/retrieval/models/__pycache__/__init__.cpython-313.pyc +0 -0
  28. repgen/retrieval/models/__pycache__/hybrid_search.cpython-313.pyc +0 -0
  29. repgen/retrieval/models/hybrid_search.py +151 -0
  30. repgen/retrieval/pipeline.py +166 -0
  31. repgen/server.py +111 -0
  32. repgen/utils.py +550 -0
  33. repgen_ai-0.1.0.dist-info/METADATA +199 -0
  34. repgen_ai-0.1.0.dist-info/RECORD +36 -0
  35. repgen_ai-0.1.0.dist-info/WHEEL +5 -0
  36. repgen_ai-0.1.0.dist-info/top_level.txt +1 -0
repgen/retrieval/models/hybrid_search.py ADDED
@@ -0,0 +1,151 @@
+ import json
+ import logging
+ from pathlib import Path
+ from typing import Any, List, Optional, Tuple
+
+ import numpy as np
+ import torch
+ from annoy import AnnoyIndex
+ from rank_bm25 import BM25Okapi
+ from sentence_transformers import CrossEncoder, SentenceTransformer
+ from sklearn.preprocessing import normalize
+ from transformers import AutoTokenizer
+
+ from ..core.utils import tokenize
+
+ # Suppress logs from transformers and sentence_transformers
+ logging.getLogger("transformers").setLevel(logging.ERROR)
+ logging.getLogger("sentence_transformers").setLevel(logging.ERROR)
+
+
+ class HybridSearchIndex:
+     def __init__(
+         self,
+         embedding_model: str,
+         reranker_model: str,
+         device: str = "cuda" if torch.cuda.is_available() else "cpu",
+         config: Optional[Any] = None,
+     ):
+         self.device = device
+         self.encoder = SentenceTransformer(embedding_model, device=device)
+         self.cross_encoder = CrossEncoder(reranker_model, device=device)
+         self.tokenizer = AutoTokenizer.from_pretrained(reranker_model)
+         self.max_seq_length = 512
+         self.bm25 = None
+         self.embeddings = None
+         self.code_chunks = None
+         self.annoy_index = None
+         self.config = config
+
+     def build_index(self, code_chunks: List[dict], corpus: List[List[str]]) -> None:
+         """Build the hybrid search index."""
+         self.code_chunks = code_chunks
+         self.bm25 = BM25Okapi(corpus)
+
+         texts = [chunk["page_content"] for chunk in code_chunks]
+         self.embeddings = self.encoder.encode(
+             texts, convert_to_tensor=True, show_progress_bar=False
+         )
+         self.embeddings = normalize(self.embeddings.cpu().numpy())
+
+         self._build_annoy_index()
+
+     def _build_annoy_index(self) -> None:
+         """Build Annoy index for approximate nearest neighbor search."""
+         dim = self.embeddings.shape[1]
+         self.annoy_index = AnnoyIndex(dim, "angular")
+         for i, vec in enumerate(self.embeddings):
+             self.annoy_index.add_item(i, vec)
+         self.annoy_index.build(n_trees=50)
+
+     def save_index(self, index_dir: Path) -> None:
+         """Save the index to disk."""
+         index_dir.mkdir(exist_ok=True)
+
+         with open(index_dir / "documents.json", "w") as f:
+             json.dump(self.code_chunks, f)
+
+         np.save(index_dir / "embeddings.npy", self.embeddings)
+         self.annoy_index.save(str(index_dir / "annoy_index.ann"))
+
+     def load_index(self, index_dir: Path) -> None:
+         """Load the index from disk."""
+         with open(index_dir / "documents.json", "r") as f:
+             self.code_chunks = json.load(f)
+
+         self.embeddings = np.load(index_dir / "embeddings.npy")
+
+         corpus = [tokenize(doc["page_content"]) for doc in self.code_chunks]
+         self.bm25 = BM25Okapi(corpus)
+
+         self.annoy_index = AnnoyIndex(self.embeddings.shape[1], "angular")
+         self.annoy_index.load(str(index_dir / "annoy_index.ann"))
+
+     def semantic_search(
+         self, query_embedding: np.ndarray, top_k: int
+     ) -> Tuple[np.ndarray, np.ndarray]:
+         """Perform semantic search using the index."""
+         indices, distances = self.annoy_index.get_nns_by_vector(
+             query_embedding.flatten(), top_k, include_distances=True
+         )
+         return np.array(indices), 1 - np.array(distances)
+
+     def search(
+         self,
+         query: str,
+         top_k: int = 200,
+         alpha: float = 0.55,
+         rerank_top_k: int = 20,
+         ann_top_k: int = 200,
+     ) -> List[dict]:
+         """Perform hybrid search with BM25 and semantic search."""
+         if not query or not isinstance(query, str):
+             return []
+
+         if self.config:
+             alpha = self.config.ALPHA
+             rerank_top_k = self.config.RERANK_TOP_K
+
+         query = query[:100000]  # Truncate very long queries
+
+         # BM25 search
+         query_tokens = tokenize(query)
+         bm25_scores = np.array(self.bm25.get_scores(query_tokens))
+         bm25_scores = (bm25_scores - np.min(bm25_scores)) / (
+             np.max(bm25_scores) - np.min(bm25_scores) + 1e-6
+         )
+
+         # Semantic search
+         query_embedding = self.encoder.encode(query, convert_to_tensor=True)
+         query_embedding = normalize(query_embedding.cpu().numpy().reshape(1, -1))
+         ann_indices, ann_scores = self.semantic_search(query_embedding, ann_top_k)
+         ann_indices = np.array(ann_indices, dtype=int)
+
+         if len(ann_indices) == 0:
+             return []
+
+         # Combine scores
+         combined_scores = (1 - alpha) * bm25_scores[ann_indices] + alpha * ann_scores
+         combined_indices_sorted = ann_indices[np.argsort(combined_scores)[::-1]]
+         top_combined_indices = combined_indices_sorted[:rerank_top_k]
+
+         # Prepare for cross-encoder
+         top_chunks = [self.code_chunks[i] for i in top_combined_indices]
+
+         # Tokenize with proper truncation
+         features = self.tokenizer(
+             [query] * len(top_chunks),
+             [chunk["page_content"][:100000] for chunk in top_chunks],
+             padding=True,
+             truncation="longest_first",
+             max_length=self.max_seq_length,
+             return_tensors="pt",
+         ).to(self.device)
+
+         # Run cross-encoder
+         with torch.no_grad():
+             rerank_scores = self.cross_encoder.model(**features).logits.squeeze()
+
+         # Sort by cross-encoder scores
+         reranked_indices = np.argsort(rerank_scores.cpu().numpy())[::-1]
+         return [top_chunks[i] for i in reranked_indices]
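The class above exposes a small build/save/load/search API. A minimal usage sketch (not part of the package; the chunk data and the embedding and reranker model names are placeholders chosen for illustration):

from pathlib import Path

from repgen.retrieval.core.utils import tokenize
from repgen.retrieval.models.hybrid_search import HybridSearchIndex

# Hypothetical code chunks; each dict needs a "page_content" key, and any
# extra keys (such as "path") are carried through to the search results.
code_chunks = [
    {"path": "train.py", "page_content": "def train(model, loader):\n    ..."},
    {"path": "data.py", "page_content": "class DataModule:\n    ..."},
]
corpus = [tokenize(chunk["page_content"]) for chunk in code_chunks]

index = HybridSearchIndex(
    embedding_model="sentence-transformers/all-MiniLM-L6-v2",  # placeholder model
    reranker_model="cross-encoder/ms-marco-MiniLM-L-6-v2",     # placeholder model
)
index.build_index(code_chunks, corpus)
index.save_index(Path("./hybrid_index"))  # writes documents.json, embeddings.npy, annoy_index.ann

hits = index.search("training loop crashes with CUDA out of memory", rerank_top_k=5)
for chunk in hits:
    print(chunk["path"])

In search(), BM25 scores are min-max normalised and blended with the Annoy similarity scores using alpha over the ANN candidate set, and the cross-encoder then reranks the top rerank_top_k candidates.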
repgen/retrieval/pipeline.py ADDED
@@ -0,0 +1,166 @@
+ import json
+ import logging
+ from typing import Any, Dict, Optional
+
+ from .config import Config
+ from .core.code_indexer import CodeIndexer
+ from .core.dependency_analyzer import DependencyAnalyzer
+ from .core.module_analyzer import ModuleAnalyzer
+ from .core.training_code_detector import TrainingCodeDetector
+ from .core.utils import load_bug_report
+
+ logger = logging.getLogger(__name__)
+
+
+ class RetrievalPipeline:
+     def __init__(
+         self,
+         repo_path: str,
+         bug_report_path: str,
+         output_dir: str,
+         config: Config = None,
+     ):
+         self.config = config or Config(
+             repo_path=repo_path, bug_report_path=bug_report_path, output_dir=output_dir
+         )
+         self.code_indexer = CodeIndexer(self.config)
+         self.module_analyzer = ModuleAnalyzer(self.config)
+         self.training_detector = TrainingCodeDetector(self.config)
+         self.dependency_analyzer = DependencyAnalyzer(self.config)
+
+     def run_pipeline(self) -> Optional[Dict[str, Any]]:
+         """Run full pipeline for the configured bug report and repository."""
+         try:
+             # 1. Setup paths
+             bug_report_path = self.config.BUG_REPORT_FILE
+             code_dir = self.config.CODE_DIR
+
+             if not bug_report_path.exists():
+                 raise FileNotFoundError(f"Bug report not found: {bug_report_path}")
+             if not code_dir.exists():
+                 raise FileNotFoundError(f"Code directory not found: {code_dir}")
+
+             # 2. Index codebase
+             logger.info(f"Indexing codebase: {code_dir.name}")
+             hybrid_index = self.code_indexer.index_codebase(code_dir)
+
+             # 3. Find relevant code
+             logger.info(f"Finding relevant code for bug report: {bug_report_path.name}")
+             bug_report = load_bug_report(bug_report_path)
+             relevant_code = self.code_indexer.find_relevant_code(
+                 bug_report, hybrid_index
+             )
+             if not relevant_code:
+                 logger.warning(
+                     f"No relevant code found for bug report: {bug_report_path.name}"
+                 )
+                 return None
+
+             # 4. Analyze modules
+             logger.info("Analyzing modules")
+             module_report = {
+                 "bug_report": self.config.PROJECT_ID,
+                 "modules": self.module_analyzer.analyze_modules(relevant_code),
+             }
+             # 5. Detect training code
+             logger.info("Detecting training code")
+             training_report = self.training_detector.detect_training_code(
+                 module_report, bug_report_path
+             )
+
+             # 6. Analyze dependencies
+             logger.info("Analyzing dependencies")
+             dependency_report = self.dependency_analyzer.analyze_training_dependencies(
+                 training_report
+             )
+
+             # 7. Generate final context
+             logger.info("Generating final context")
+             self.create_context_files(
+                 self.config.PROJECT_ID, module_report, dependency_report
+             )
+             return {
+                 "status": "success",
+                 "context_dir": str(self.config.CONTEXT_DIR_OUT),
+             }
+         except Exception as e:
+             logger.error(f"Pipeline failed: {str(e)}")
+             raise
+
+     def create_context_files(self, project_id, module_report, dependency_report):
+         # Create the context directory if it doesn't exist
+         context_dir = self.config.CONTEXT_DIR_OUT
+
+         dependencies = dependency_report.get("dependencies", [])
+         if not dependencies:
+             for i, module in enumerate(module_report["modules"], 1):
+                 context = {
+                     "bug_report": project_id,
+                     "module": module,
+                     "module_snippets": [],
+                 }
+
+                 # Find relevant module snippets for this file
+                 for file in module["files"]:
+                     context["module_snippets"].append(
+                         {"file": file["path"], "snippets": file["snippets"]}
+                     )
+
+                 # Create the output file path
+                 output_filename = f"{project_id}_module_{i}.json"
+                 filename = context_dir / output_filename
+
+                 # Write to file
+                 with open(filename, "w") as f:
+                     json.dump(context, f, indent=2)
+
+         else:
+             for i, dep in enumerate(dependencies[:5], 1):
+                 training_file_path = self.config.CODE_DIR / dep["file"]
+                 try:
+                     with open(training_file_path, "r") as f:
+                         training_file_content = f.read()
+                 except FileNotFoundError:
+                     print(f"Warning: File not found - {training_file_path}")
+                     training_file_content = (
+                         f"Content not available for {training_file_path}"
+                     )
+
+                 # Create the context structure
+                 context = {
+                     "bug_report": project_id,
+                     "rank": i,
+                     "score": dep["score"],
+                     "main_file": {
+                         "path": dep["file"],
+                         "content": training_file_content,
+                     },
+                     "module_snippets": [],
+                     "dependencies": dep["external_dependencies"],
+                 }
+
+                 # Find relevant module snippets for this file
+                 for module in module_report["modules"]:
+                     for file in module["files"]:
+                         if file["path"] == dep["file"]:
+                             context["module_snippets"].append(
+                                 {"file": file["path"], "snippets": file["snippets"]}
+                             )
+                         else:
+                             # Check if this module file is a dependency
+                             for ext_dep in dep["external_dependencies"]:
+                                 if ext_dep["module"].replace("/", "_") in file["path"]:
+                                     context["module_snippets"].append(
+                                         {
+                                             "file": file["path"],
+                                             "snippets": file["snippets"],
+                                         }
+                                     )
+
+                 # Create the output file path
+                 output_filename = f"{project_id}_{i}.json"
+                 filename = context_dir / output_filename
+
+                 # Write to file
+                 with open(filename, "w") as f:
+                     json.dump(context, f, indent=2)
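For reference, a minimal way to drive the pipeline end to end (illustrative only; the paths are placeholders, and Config is assumed to derive BUG_REPORT_FILE, CODE_DIR, PROJECT_ID and CONTEXT_DIR_OUT from the arguments shown):

from repgen.retrieval.pipeline import RetrievalPipeline

pipeline = RetrievalPipeline(
    repo_path="./checkouts/example-project",  # repository to index
    bug_report_path="./reports/bug_001.md",   # bug report to localize
    output_dir="./repgen_output",             # where context JSON files are written
)

result = pipeline.run_pipeline()
if result and result["status"] == "success":
    print("Context files written to", result["context_dir"])
else:
    print("No relevant code found for this bug report")

Each context file pairs the bug report ID either with a ranked training file plus its dependencies or, when no dependencies are detected, with the per-module snippets.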
repgen/server.py ADDED
@@ -0,0 +1,111 @@
+ import uuid
+ from typing import Any, Dict, Optional
+
+ from fastapi import BackgroundTasks, FastAPI, HTTPException
+ from fastapi.middleware.cors import CORSMiddleware
+ from pydantic import BaseModel
+
+ from .core import RepGenService
+
+ app = FastAPI(title="RepGen API")
+
+ # Configure CORS
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=["*"],
+     allow_credentials=True,
+     allow_methods=["*"],
+     allow_headers=["*"],
+ )
+
+ # In-memory store for task status
+ # (In production, use Redis/db)
+ tasks: Dict[str, Dict[str, Any]] = {}
+
+
+ class ReproductionRequest(BaseModel):
+     bug_report: str
+     repo_url: str
+     backend: str = "openai"  # Default
+     model: str = "gpt-4o"
+     commit: Optional[str] = None
+     api_key: Optional[str] = None
+
+
+ def run_reproduction_task(task_id: str, req: ReproductionRequest):
+     service = RepGenService(output_dir="./repgen_results")
+
+     tasks[task_id]["status"] = "running"
+
+     def progress_cb(stage, msg, data=None):
+         tasks[task_id]["logs"].append(f"[{stage}] {msg}")
+         tasks[task_id]["stage"] = stage
+         if data:
+             if "artifacts" not in tasks[task_id]:
+                 tasks[task_id]["artifacts"] = {}
+             if data["type"] == "refined_report":
+                 tasks[task_id]["artifacts"]["refined_report"] = data["content"]
+             elif data["type"] == "plan":
+                 tasks[task_id]["artifacts"]["plan"] = data["content"]
+             elif data["type"] == "code":
+                 tasks[task_id]["artifacts"]["code"] = data["content"]
+                 tasks[task_id]["artifacts"]["code_path"] = data["path"]
+             elif data["type"] == "context":
+                 tasks[task_id]["artifacts"]["context"] = data[
+                     "content"
+                 ]  # Maybe accumulate? For now, last one.
+
+     try:
+         result = service.run_reproduction(
+             bug_report_source=req.bug_report,
+             repo_source=req.repo_url,
+             backend=req.backend,
+             model=req.model,
+             commit=req.commit,
+             api_key=req.api_key,
+             progress_callback=progress_cb,
+         )
+
+         if result["success"]:
+             tasks[task_id]["status"] = "completed"
+             tasks[task_id]["result"] = result[
+                 "files"
+             ]  # This will now be a list of dicts {path, content}
+         else:
+             tasks[task_id]["status"] = "failed"
+             tasks[task_id]["error"] = result["error"]
+
+     except Exception as e:
+         tasks[task_id]["status"] = "failed"
+         tasks[task_id]["error"] = str(e)
+
+
+ @app.post("/api/reproduce")
+ async def start_reproduction(
+     req: ReproductionRequest, background_tasks: BackgroundTasks
+ ):
+     task_id = str(uuid.uuid4())
+     tasks[task_id] = {
+         "status": "pending",
+         "stage": "init",
+         "logs": [],
+         "artifacts": {},
+         "result": None,
+         "error": None,
+     }
+
+     background_tasks.add_task(run_reproduction_task, task_id, req)
+     return {"task_id": task_id}
+
+
+ @app.get("/api/status/{task_id}")
+ async def get_status(task_id: str):
+     if task_id not in tasks:
+         raise HTTPException(status_code=404, detail="Task not found")
+     return tasks[task_id]
+
+
+ # Serve UI (We will build this next)
+ # app.mount("/", StaticFiles(directory="ui/dist", html=True), name="ui")
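A quick way to exercise the API locally (a sketch; the port, URL, and payload values are placeholders, and the request fields mirror ReproductionRequest above):

# Start the server first:
#   uvicorn repgen.server:app --host 0.0.0.0 --port 8000

import time

import requests

payload = {
    "bug_report": "Training crashes with a CUDA out-of-memory error after epoch 3 ...",
    "repo_url": "https://github.com/example/project",
    "backend": "openai",
    "model": "gpt-4o",
}
task_id = requests.post("http://localhost:8000/api/reproduce", json=payload).json()["task_id"]

# Poll /api/status/{task_id} until the background task finishes.
while True:
    status = requests.get(f"http://localhost:8000/api/status/{task_id}").json()
    if status["status"] in ("completed", "failed"):
        break
    time.sleep(5)

print(status["status"], status.get("error"))

Because task state lives in the in-memory tasks dict, status is lost on restart; the comment in the module already flags Redis or a database for production use.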