sourcefire 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sourcefire/__init__.py +0 -0
- sourcefire/api/__init__.py +0 -0
- sourcefire/api/models.py +24 -0
- sourcefire/api/routes.py +166 -0
- sourcefire/chain/__init__.py +0 -0
- sourcefire/chain/prompts.py +195 -0
- sourcefire/chain/rag_chain.py +967 -0
- sourcefire/cli.py +293 -0
- sourcefire/config.py +148 -0
- sourcefire/db.py +196 -0
- sourcefire/indexer/__init__.py +0 -0
- sourcefire/indexer/embeddings.py +27 -0
- sourcefire/indexer/language_profiles.py +448 -0
- sourcefire/indexer/metadata.py +289 -0
- sourcefire/indexer/pipeline.py +406 -0
- sourcefire/init.py +189 -0
- sourcefire/prompts/system.md +28 -0
- sourcefire/retriever/__init__.py +0 -0
- sourcefire/retriever/graph.py +162 -0
- sourcefire/retriever/search.py +86 -0
- sourcefire/static/.DS_Store +0 -0
- sourcefire/static/app.js +414 -0
- sourcefire/static/index.html +102 -0
- sourcefire/static/styles.css +607 -0
- sourcefire/watcher.py +105 -0
- sourcefire-0.2.0.dist-info/METADATA +145 -0
- sourcefire-0.2.0.dist-info/RECORD +31 -0
- sourcefire-0.2.0.dist-info/WHEEL +5 -0
- sourcefire-0.2.0.dist-info/entry_points.txt +2 -0
- sourcefire-0.2.0.dist-info/licenses/LICENSE +21 -0
- sourcefire-0.2.0.dist-info/top_level.txt +1 -0
sourcefire/cli.py
ADDED
|
@@ -0,0 +1,293 @@
|
|
|
1
|
+
"""Sourcefire CLI — single command entry point."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import argparse
|
|
6
|
+
import asyncio
|
|
7
|
+
import fcntl
|
|
8
|
+
import os
|
|
9
|
+
import sys
|
|
10
|
+
import webbrowser
|
|
11
|
+
from contextlib import asynccontextmanager
|
|
12
|
+
from datetime import datetime, timezone
|
|
13
|
+
from pathlib import Path
|
|
14
|
+
|
|
15
|
+
import uvicorn
|
|
16
|
+
from dotenv import load_dotenv
|
|
17
|
+
from fastapi import FastAPI
|
|
18
|
+
from fastapi.responses import FileResponse
|
|
19
|
+
from fastapi.staticfiles import StaticFiles
|
|
20
|
+
|
|
21
|
+
load_dotenv()
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def parse_args() -> argparse.Namespace:
    """Parse the command-line options for the ``sourcefire`` entry point.

    Returns the populated namespace; all flags default to off and the
    port defaults to None (meaning: use the configured port, or 8000).
    """
    cli = argparse.ArgumentParser(
        prog="sourcefire",
        description="Sourcefire — AI-powered codebase RAG from your terminal",
    )
    cli.add_argument(
        "--port",
        type=int,
        default=None,
        help="Server port (default: from config or 8000)",
    )
    cli.add_argument(
        "--no-open",
        action="store_true",
        help="Don't auto-open browser",
    )
    cli.add_argument(
        "--reinit",
        action="store_true",
        help="Regenerate .sourcefire/config.toml via LLM",
    )
    cli.add_argument(
        "--verbose",
        action="store_true",
        help="Verbose logging",
    )
    return cli.parse_args()
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def discover_project() -> tuple[Path, Path]:
    """Walk up from cwd to find .sourcefire/, like git finds .git/.

    Returns (project_dir, sourcefire_dir).
    If no ancestor contains a .sourcefire directory, falls back to
    (cwd, cwd/.sourcefire).
    """
    start = Path.cwd().resolve()
    # Check the starting directory first, then every ancestor up to the
    # filesystem root — exactly the git-style discovery order.
    for directory in (start, *start.parents):
        marker = directory / ".sourcefire"
        if marker.is_dir():
            return directory, marker
    return start, start / ".sourcefire"
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def acquire_lock(lock_path: Path) -> int | None:
    """Acquire an exclusive, non-blocking file lock.

    Creates the parent directory if needed, opens (creating if absent)
    the lock file, and takes a BSD ``flock`` on it.

    Returns the open file descriptor on success, or None when the lock
    file cannot be opened or another holder already owns the lock.
    """
    lock_path.parent.mkdir(parents=True, exist_ok=True)
    try:
        fd = os.open(str(lock_path), os.O_CREAT | os.O_RDWR)
    except OSError:
        return None
    try:
        # LOCK_NB makes a contended lock raise BlockingIOError (a
        # subclass of OSError) instead of blocking.
        fcntl.flock(fd, fcntl.LOCK_EX | fcntl.LOCK_NB)
    except OSError:
        # FIX: the original left `fd` open when flock() failed, leaking a
        # descriptor every time a second instance probed the lock.
        os.close(fd)
        return None
    return fd
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def release_lock(fd: int, lock_path: Path) -> None:
    """Unlock and close the descriptor returned by ``acquire_lock``.

    Best-effort: any OSError (already-closed fd, stale lock) is
    swallowed, so calling this twice is harmless.
    """
    try:
        fcntl.flock(fd, fcntl.LOCK_UN)
        os.close(fd)
    except OSError:
        # The lock dies with the descriptor anyway — nothing to do.
        pass
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
# ---------------------------------------------------------------------------
|
|
77
|
+
# App state (shared between main() and lifespan)
|
|
78
|
+
# ---------------------------------------------------------------------------
|
|
79
|
+
|
|
80
|
+
# Module-level mailbox: main() fills this before uvicorn imports the app
# string ("sourcefire.cli:app"), and lifespan() reads it on startup.
_app_state: dict = {}
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
# ---------------------------------------------------------------------------
|
|
84
|
+
# Lifespan
|
|
85
|
+
# ---------------------------------------------------------------------------
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Start-up: index codebase, build graph, start watcher.

    Reads its inputs from the module-level ``_app_state`` dict that
    ``main()`` populates before the server starts. On shutdown, cancels
    the watcher task and persists the import graph.
    """
    config = _app_state["config"]
    project_dir = _app_state["project_dir"]
    api_key = _app_state["api_key"]
    args = _app_state["args"]

    # Deferred imports — presumably to keep module import light; confirm.
    from sourcefire.db import create_client, get_collection
    from sourcefire.indexer.language_profiles import get_profile
    from sourcefire.indexer.pipeline import run_indexing
    from sourcefire.retriever.graph import ImportGraph
    from sourcefire.api.routes import init_dependencies
    from sourcefire.watcher import watch_and_reindex

    print(f"[sourcefire] Project: {project_dir.name}")
    print(f"[sourcefire] Config: {config.config_path}")

    # Detect language ("auto" means let get_profile() decide).
    language_override = config.language if config.language != "auto" else None
    profile = get_profile(project_dir, language_override)
    lang_name = profile.language if profile else "generic"
    print(f"[sourcefire] Language: {lang_name}")

    # Create ChromaDB client
    client = create_client(config.chroma_dir)
    collection = get_collection(client)

    # Determine if this is a first run (empty collection)
    existing_count = collection.count()
    is_first_run = existing_count == 0

    # Run indexing: full index on first run, incremental otherwise.
    if is_first_run:
        print("[sourcefire] First run — full index...")
        stats = run_indexing(collection, config, client=client, full=True)
    else:
        print("[sourcefire] Checking for changes...")
        stats = run_indexing(collection, config, client=client, full=False)

    print(f"[sourcefire] Indexed: {stats['files']} files, {stats['chunks']} chunks")

    # Build import graph. Fresh edges from this indexing run win; only
    # when the run produced none do we fall back to the saved graph file.
    external_prefixes = profile.external_import_prefixes if profile else ()
    graph = ImportGraph(external_prefixes=external_prefixes)

    import_edges = stats.get("import_edges", {})
    if import_edges:
        for source_file, imports in import_edges.items():
            for imp in imports:
                # NOTE(review): uses a private ImportGraph classmethod —
                # consider exposing a public resolve API.
                resolved = ImportGraph._resolve_import(source_file, imp)
                graph.add_edge(source_file, resolved)
    elif config.graph_path.is_file():
        graph = ImportGraph.load(config.graph_path, external_prefixes=external_prefixes)

    print(f"[sourcefire] Import graph: {graph.node_count} nodes")

    # Build index status snapshot exposed via the API routes.
    index_status = {
        "files_indexed": stats.get("files", 0),
        "last_indexed": datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ"),
        "index_status": "ready",
        "language": lang_name,
    }

    # Inject dependencies into routes
    init_dependencies(collection, graph, index_status, profile, project_dir, api_key)

    # Start file watcher as a background task on the server's loop.
    watcher_task = asyncio.create_task(
        watch_and_reindex(config, collection, graph, profile)
    )

    # Open browser unless --no-open was given.
    url = f"http://{config.host}:{config.port}"
    print(f"[sourcefire] Ready — {url}")
    if not args.no_open:
        webbrowser.open(url)

    yield

    # Shutdown: stop the watcher and wait for its cancellation to land.
    print("[sourcefire] Shutting down...")
    watcher_task.cancel()
    try:
        await watcher_task
    except asyncio.CancelledError:
        pass

    # Save graph so incremental runs can reload it.
    graph.save(config.graph_path)
    print("[sourcefire] Graph saved.")
|
|
180
|
+
|
|
181
|
+
|
|
182
|
+
# ---------------------------------------------------------------------------
|
|
183
|
+
# App
|
|
184
|
+
# ---------------------------------------------------------------------------
|
|
185
|
+
|
|
186
|
+
# Locate the packaged static assets (HTML/JS/CSS) inside the installed wheel.
from importlib.resources import files as _resource_files

_static_dir = str(Path(_resource_files("sourcefire")) / "static")

app = FastAPI(
    title="Sourcefire",
    description="AI-powered codebase RAG. Created by Athar Wani.",
    version="0.2.0",
    lifespan=lifespan,
)

# Imported below the app definition (hence the E402 suppression);
# presumably to keep the routes module from loading before load_dotenv().
from sourcefire.api.routes import router  # noqa: E402

app.include_router(router)
app.mount("/static", StaticFiles(directory=_static_dir), name="static")


@app.get("/", include_in_schema=False)
async def root() -> FileResponse:
    # Serve the single-page UI shell bundled with the package.
    return FileResponse(os.path.join(_static_dir, "index.html"))
|
|
206
|
+
|
|
207
|
+
|
|
208
|
+
# ---------------------------------------------------------------------------
|
|
209
|
+
# Main
|
|
210
|
+
# ---------------------------------------------------------------------------
|
|
211
|
+
|
|
212
|
+
|
|
213
|
+
def main() -> None:
    """Sourcefire CLI entry point.

    Parses arguments, takes the per-project instance lock, ensures an
    API key and config exist, stashes state in ``_app_state``, then runs
    uvicorn. The lock is released on server exit.
    """
    args = parse_args()

    project_dir, sourcefire_dir = discover_project()

    # Acquire lock — guarantees a single running instance per project.
    lock_fd = acquire_lock(sourcefire_dir / ".lock")
    if lock_fd is None:
        print("Error: Another sourcefire instance is already running for this project.")
        sys.exit(1)

    # Check for API key — prompt interactively if missing
    api_key = os.getenv("GEMINI_API_KEY", "")
    if not api_key:
        print("No GEMINI_API_KEY found in environment.")
        try:
            api_key = input("Enter your Gemini API key: ").strip()
        except (EOFError, KeyboardInterrupt):
            print("\nAborted.")
            release_lock(lock_fd, sourcefire_dir / ".lock")
            sys.exit(1)

        if not api_key:
            print("Error: A Gemini API key is required to run Sourcefire.")
            release_lock(lock_fd, sourcefire_dir / ".lock")
            sys.exit(1)

        # Persist to .env in project root (append, so existing entries
        # are kept; NOTE(review): repeated prompts append duplicates).
        env_path = project_dir / ".env"
        with open(env_path, "a") as f:
            f.write(f"\nGEMINI_API_KEY={api_key}\n")
        os.environ["GEMINI_API_KEY"] = api_key
        print(f"API key saved to {env_path}")

    # Auto-init or reinit: first run generates config.toml, --reinit
    # regenerates just the include/exclude patterns.
    needs_init = not sourcefire_dir.exists() or not (sourcefire_dir / "config.toml").exists()

    if needs_init:
        from sourcefire.init import auto_init
        config = auto_init(
            project_dir=project_dir,
            sourcefire_dir=sourcefire_dir,
            api_key=api_key,
        )
    elif args.reinit:
        from sourcefire.config import load_config
        from sourcefire.init import reinit_patterns
        config = load_config(project_dir, sourcefire_dir)
        config = reinit_patterns(config, api_key=api_key)
    else:
        from sourcefire.config import load_config
        config = load_config(project_dir, sourcefire_dir)

    # Override port from CLI
    if args.port:
        config.port = args.port

    # Store state for lifespan access (see _app_state above).
    _app_state["config"] = config
    _app_state["project_dir"] = project_dir
    _app_state["sourcefire_dir"] = sourcefire_dir
    _app_state["api_key"] = api_key
    _app_state["args"] = args
    _app_state["lock_fd"] = lock_fd

    # Run server; the lock is always released, even on crash or Ctrl-C.
    try:
        uvicorn.run(
            "sourcefire.cli:app",
            host=config.host,
            port=config.port,
            reload=False,
            log_level="info" if args.verbose else "warning",
        )
    finally:
        release_lock(lock_fd, sourcefire_dir / ".lock")


if __name__ == "__main__":
    main()
|
sourcefire/config.py
ADDED
|
@@ -0,0 +1,148 @@
|
|
|
1
|
+
"""Configuration for Sourcefire — loaded from .sourcefire/config.toml."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import os
|
|
6
|
+
import tomllib
|
|
7
|
+
from dataclasses import dataclass, field
|
|
8
|
+
from pathlib import Path
|
|
9
|
+
|
|
10
|
+
import tomli_w
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
@dataclass
class SourcefireConfig:
    """All Sourcefire configuration for a project.

    Field groups mirror the TOML sections written by ``save_config``;
    the two path fields are resolved at runtime and never serialized.
    """

    # Resolved at runtime, not stored in TOML
    project_dir: Path = field(default_factory=Path.cwd)
    sourcefire_dir: Path = field(default_factory=lambda: Path.cwd() / ".sourcefire")

    # [project]
    project_name: str = ""
    language: str = "auto"  # "auto" = detect at startup (see cli.lifespan)

    # [indexer]
    include: list[str] = field(default_factory=list)
    exclude: list[str] = field(default_factory=list)
    chunk_size: int = 1000
    chunk_overlap: int = 300

    # [llm]
    provider: str = "gemini"
    model: str = "gemini-2.5-flash"
    api_key_env: str = "GEMINI_API_KEY"  # name of the env var, not the key

    # [server]
    host: str = "127.0.0.1"
    port: int = 8000

    # [retrieval]
    top_k: int = 8
    relevance_threshold: float = 0.3

    # Versioning
    config_version: int = 1

    @property
    def gemini_api_key(self) -> str:
        """API key read live from the environment; empty string if unset."""
        return os.getenv(self.api_key_env, "")

    @property
    def chroma_dir(self) -> Path:
        """Directory of the persistent ChromaDB store."""
        return self.sourcefire_dir / "chroma"

    @property
    def graph_path(self) -> Path:
        """Location of the serialized import graph."""
        return self.sourcefire_dir / "graph.json"

    @property
    def config_path(self) -> Path:
        """Location of the TOML config file."""
        return self.sourcefire_dir / "config.toml"

    @property
    def lock_path(self) -> Path:
        """Location of the single-instance lock file."""
        return self.sourcefire_dir / ".lock"
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
# Constants used by other modules
# Sentence-transformers model used for all chunk and query embeddings.
EMBEDDING_MODEL: str = "sentence-transformers/all-MiniLM-L6-v2"
# Per-model token budgets — presumably caps on prompt assembly in the
# RAG chain; confirm against chain/rag_chain.py.
MAX_TOKEN_BUDGET: dict[str, int] = {
    "gemini-2.5-flash": 100_000,
    "gemini-2.5-pro": 200_000,
}
# NOTE(review): assumed to bound retained chat turns and reserve room
# for the model's response — verify against the consuming modules.
MAX_HISTORY_PAIRS: int = 5
RESPONSE_HEADROOM: int = 8_000
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def default_config(project_dir: Path) -> SourcefireConfig:
    """Return a SourcefireConfig with sensible defaults for the given project.

    Everything except the project paths and name comes from the
    dataclass field defaults.
    """
    sourcefire_dir = project_dir / ".sourcefire"
    return SourcefireConfig(
        project_name=project_dir.name,
        project_dir=project_dir,
        sourcefire_dir=sourcefire_dir,
    )
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
def load_config(project_dir: Path, sourcefire_dir: Path) -> SourcefireConfig:
    """Load configuration from ``<sourcefire_dir>/config.toml``.

    Missing sections or keys fall back to the same defaults the
    dataclass declares. A missing file raises FileNotFoundError — the
    CLI runs auto-init before calling this.
    """
    config_path = sourcefire_dir / "config.toml"
    data = tomllib.loads(config_path.read_text(encoding="utf-8"))

    def section(name: str) -> dict:
        # Each TOML table is optional.
        return data.get(name, {})

    project = section("project")
    indexer = section("indexer")
    llm = section("llm")
    server = section("server")
    retrieval = section("retrieval")

    return SourcefireConfig(
        project_dir=project_dir,
        sourcefire_dir=sourcefire_dir,
        config_version=data.get("config_version", 1),
        project_name=project.get("name", project_dir.name),
        language=project.get("language", "auto"),
        include=indexer.get("include", []),
        exclude=indexer.get("exclude", []),
        chunk_size=indexer.get("chunk_size", 1000),
        chunk_overlap=indexer.get("chunk_overlap", 300),
        provider=llm.get("provider", "gemini"),
        model=llm.get("model", "gemini-2.5-flash"),
        api_key_env=llm.get("api_key_env", "GEMINI_API_KEY"),
        host=server.get("host", "127.0.0.1"),
        port=server.get("port", 8000),
        top_k=retrieval.get("top_k", 8),
        relevance_threshold=retrieval.get("relevance_threshold", 0.3),
    )
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
def save_config(config: SourcefireConfig) -> None:
    """Serialize *config* back to .sourcefire/config.toml.

    Section/key order matches load_config; runtime-only path fields are
    deliberately not written. Creates the .sourcefire directory first.
    """
    payload: dict = {"config_version": config.config_version}
    payload["project"] = {
        "name": config.project_name,
        "language": config.language,
    }
    payload["indexer"] = {
        "include": config.include,
        "exclude": config.exclude,
        "chunk_size": config.chunk_size,
        "chunk_overlap": config.chunk_overlap,
    }
    payload["llm"] = {
        "provider": config.provider,
        "model": config.model,
        "api_key_env": config.api_key_env,
    }
    payload["server"] = {
        "host": config.host,
        "port": config.port,
    }
    payload["retrieval"] = {
        "top_k": config.top_k,
        "relevance_threshold": config.relevance_threshold,
    }

    config.config_path.parent.mkdir(parents=True, exist_ok=True)
    config.config_path.write_text(tomli_w.dumps(payload), encoding="utf-8")
|
sourcefire/db.py
ADDED
|
@@ -0,0 +1,196 @@
|
|
|
1
|
+
"""ChromaDB wrapper for Sourcefire — async-safe via run_in_executor."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import asyncio
|
|
6
|
+
from functools import partial
|
|
7
|
+
from pathlib import Path
|
|
8
|
+
from typing import Any
|
|
9
|
+
|
|
10
|
+
import chromadb
|
|
11
|
+
|
|
12
|
+
COLLECTION_NAME = "code_chunks"
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def create_client(chroma_dir: Path) -> chromadb.ClientAPI:
    """Return a persistent ChromaDB client rooted at *chroma_dir*.

    The directory (and any missing parents) is created first.
    """
    chroma_dir.mkdir(parents=True, exist_ok=True)
    client = chromadb.PersistentClient(path=str(chroma_dir))
    return client
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def get_collection(client: chromadb.ClientAPI) -> chromadb.Collection:
    """Get or create the ``code_chunks`` collection.

    The HNSW index is configured for cosine distance, which
    query_similar() converts to a similarity score.
    """
    hnsw_settings = {"hnsw:space": "cosine"}
    return client.get_or_create_collection(
        name=COLLECTION_NAME,
        metadata=hnsw_settings,
    )
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def reset_collection(client: chromadb.ClientAPI) -> chromadb.Collection:
    """Delete and recreate the collection (for full re-index).

    A ValueError from delete (collection did not exist) is ignored so a
    reset on a fresh store behaves like plain creation.
    """
    try:
        client.delete_collection(COLLECTION_NAME)
    except ValueError:
        pass  # nothing to delete yet
    return get_collection(client)
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
# ---------------------------------------------------------------------------
|
|
39
|
+
# Sync operations
|
|
40
|
+
# ---------------------------------------------------------------------------
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def add_chunks(
    collection: chromadb.Collection,
    ids: list[str],
    documents: list[str],
    embeddings: list[list[float]],
    metadatas: list[dict[str, str]],
) -> None:
    """Insert a batch of pre-embedded chunks into *collection*.

    All four lists are parallel: element i of each describes one chunk.
    """
    payload = {
        "ids": ids,
        "documents": documents,
        "embeddings": embeddings,
        "metadatas": metadatas,
    }
    collection.add(**payload)
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def delete_file_chunks(collection: chromadb.Collection, filename: str) -> None:
    """Remove every chunk whose metadata ``filename`` equals *filename*."""
    selector = {"filename": filename}
    collection.delete(where=selector)
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
def query_similar(
    collection: chromadb.Collection,
    query_embedding: list[float],
    n_results: int = 8,
    where: dict | None = None,
) -> list[dict[str, Any]]:
    """Nearest-neighbour search over the chunk collection.

    Returns one flat dict per hit with the chunk text, its metadata
    fields, and a ``relevance`` score (1 - cosine distance). An empty
    result set yields an empty list. *where* is only forwarded when
    truthy, matching ChromaDB's expectation of an absent filter.
    """
    query_kwargs: dict[str, Any] = {
        "query_embeddings": [query_embedding],
        "n_results": n_results,
        "include": ["documents", "metadatas", "distances"],
    }
    if where:
        query_kwargs["where"] = where

    raw = collection.query(**query_kwargs)

    ids = raw["ids"]
    if not ids or not ids[0]:
        return []

    docs = raw["documents"]
    metas = raw["metadatas"]
    dists = raw["distances"]

    out: list[dict[str, Any]] = []
    for idx in range(len(ids[0])):
        meta = metas[0][idx] if metas else {}
        dist = dists[0][idx] if dists else 1.0
        out.append({
            "filename": meta.get("filename", ""),
            "location": meta.get("location", ""),
            "code": docs[0][idx] if docs else "",
            "feature": meta.get("feature", ""),
            "layer": meta.get("layer", ""),
            "file_type": meta.get("file_type", ""),
            # ChromaDB reports cosine *distance*; flip to similarity.
            "relevance": 1.0 - dist,
        })
    return out
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
def get_chunks_by_files(
    collection: chromadb.Collection,
    filenames: list[str],
) -> list[dict[str, Any]]:
    """Fetch every stored chunk belonging to any of *filenames*.

    An empty filename list short-circuits to an empty result without
    touching the database. Rows carry the same flat shape as
    query_similar() minus the relevance score.
    """
    if not filenames:
        return []

    raw = collection.get(
        where={"filename": {"$in": filenames}},
        include=["documents", "metadatas"],
    )
    if not raw["ids"]:
        return []

    metas = raw["metadatas"]
    docs = raw["documents"]
    out: list[dict[str, Any]] = []
    for idx in range(len(raw["ids"])):
        meta = metas[idx] if metas else {}
        out.append({
            "filename": meta.get("filename", ""),
            "location": meta.get("location", ""),
            "code": docs[idx] if docs else "",
            "feature": meta.get("feature", ""),
            "layer": meta.get("layer", ""),
            "file_type": meta.get("file_type", ""),
        })
    return out
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
def get_indexed_files(collection: chromadb.Collection) -> set[str]:
    """Return the distinct filenames currently present in *collection*.

    Metadata entries that are None or lack a ``filename`` key are
    skipped.
    """
    metadatas = collection.get(include=["metadatas"])["metadatas"]
    if not metadatas:
        return set()
    return {meta["filename"] for meta in metadatas if meta and "filename" in meta}
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
def get_stored_mtimes(collection: chromadb.Collection) -> dict[str, float]:
    """Map each indexed filename to its recorded ``mtime`` metadata.

    Entries missing either key, or whose mtime does not parse as a
    float, are skipped silently — best effort, used for change
    detection only.
    """
    stored: dict[str, float] = {}
    for meta in collection.get(include=["metadatas"])["metadatas"] or []:
        if not (meta and "filename" in meta and "mtime" in meta):
            continue
        try:
            stored[meta["filename"]] = float(meta["mtime"])
        except (TypeError, ValueError):
            continue
    return stored
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
# ---------------------------------------------------------------------------
|
|
160
|
+
# Async wrappers (for use in FastAPI routes / RAG chain)
|
|
161
|
+
# ---------------------------------------------------------------------------
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
async def async_query_similar(
    collection: chromadb.Collection,
    query_embedding: list[float],
    n_results: int = 8,
    where: dict | None = None,
) -> list[dict[str, Any]]:
    """Async wrapper for query_similar.

    Runs the blocking ChromaDB query in the default thread-pool
    executor so the event loop stays responsive.
    """
    # FIX: asyncio.get_event_loop() is deprecated inside coroutines
    # (DeprecationWarning since 3.10, slated for removal);
    # get_running_loop() is the correct call here — a coroutine is
    # guaranteed to have a running loop.
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(
        None, partial(query_similar, collection, query_embedding, n_results, where)
    )
|
|
175
|
+
|
|
176
|
+
|
|
177
|
+
async def async_get_chunks_by_files(
    collection: chromadb.Collection,
    filenames: list[str],
) -> list[dict[str, Any]]:
    """Async wrapper for get_chunks_by_files.

    Offloads the blocking ChromaDB read to the default executor.
    """
    # FIX: use get_running_loop() — get_event_loop() is deprecated
    # inside a coroutine (DeprecationWarning since 3.10).
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(
        None, partial(get_chunks_by_files, collection, filenames)
    )
|
|
186
|
+
|
|
187
|
+
|
|
188
|
+
async def async_delete_file_chunks(
    collection: chromadb.Collection,
    filename: str,
) -> None:
    """Async wrapper for delete_file_chunks.

    Offloads the blocking ChromaDB delete to the default executor.
    """
    # FIX: use get_running_loop() — get_event_loop() is deprecated
    # inside a coroutine (DeprecationWarning since 3.10).
    loop = asyncio.get_running_loop()
    await loop.run_in_executor(
        None, partial(delete_file_chunks, collection, filename)
    )
|
|
File without changes
|
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
"""Shared embedding module using sentence-transformers directly."""
|
|
2
|
+
|
|
3
|
+
from sentence_transformers import SentenceTransformer
|
|
4
|
+
from sourcefire.config import EMBEDDING_MODEL
|
|
5
|
+
|
|
6
|
+
_model: SentenceTransformer | None = None
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def get_model() -> SentenceTransformer:
    """Return the process-wide SentenceTransformer, loading it lazily.

    The first call downloads/loads the model (slow); later calls reuse
    the cached instance.
    """
    global _model
    if _model is not None:
        return _model
    print(f"Loading embedding model: {EMBEDDING_MODEL}...")
    _model = SentenceTransformer(EMBEDDING_MODEL)
    print("Embedding model loaded.")
    return _model
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def embed_text(text: str) -> list[float]:
    """Embed a single text string. Returns a plain list of floats."""
    return get_model().encode(text).tolist()
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def embed_batch(texts: list[str]) -> list[list[float]]:
    """Embed a batch of texts in one model call.

    More efficient than calling embed_text in a loop.
    """
    return get_model().encode(texts).tolist()
|