sourcefire-0.2.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sourcefire/__init__.py ADDED
File without changes
@@ -0,0 +1,24 @@
+ """Pydantic request/response models for the Sourcefire API."""
+
+ from typing import Literal
+
+ from pydantic import BaseModel
+
+
+ class QueryRequest(BaseModel):
+     query: str
+     mode: Literal["debug", "feature", "explain"] = "debug"
+     model: Literal["gemini-2.5-flash", "gemini-2.5-pro"] = "gemini-2.5-flash"
+     history: list[dict] = []
+
+
+ class StatusResponse(BaseModel):
+     files_indexed: int
+     last_indexed: str
+     index_status: str
+     language: str = "generic"
+
+
+ class SourceResponse(BaseModel):
+     content: str
+     language: str
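
For reference, a minimal sketch of how these request/response models behave at the API boundary, assuming Pydantic v2 (the example values are hypothetical):

from pydantic import ValidationError
from sourcefire.api.models import QueryRequest

# Only `query` is required; mode, model, and history fill in from defaults.
req = QueryRequest(query="Why does indexing fail on startup?")
assert req.mode == "debug" and req.model == "gemini-2.5-flash"
assert req.history == []

# Literal fields reject values outside the declared set.
try:
    QueryRequest(query="q", mode="refactor")
except ValidationError as exc:
    print(exc.error_count(), "validation error")  # 1 validation error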
@@ -0,0 +1,166 @@
+ """FastAPI router for the Sourcefire API."""
+
+ from __future__ import annotations
+
+ import json
+ from pathlib import Path
+ from typing import Any, AsyncGenerator
+
+ from fastapi import APIRouter, HTTPException, Query
+ from sse_starlette.sse import EventSourceResponse
+
+ from sourcefire.api.models import QueryRequest, SourceResponse, StatusResponse
+
+ router = APIRouter(prefix="/api")
+
+ # ---------------------------------------------------------------------------
+ # Module-level dependency state — set once at startup via init_dependencies()
+ # ---------------------------------------------------------------------------
+
+ _collection: Any = None
+ _graph: Any = None
+ _profile: Any = None
+ _project_dir: Path | None = None
+ _gemini_api_key: str = ""
+ _index_status: dict[str, Any] = {
+     "files_indexed": 0,
+     "last_indexed": "never",
+     "index_status": "not_ready",
+     "language": "generic",
+ }
+
+
+ def init_dependencies(
+     collection: Any,
+     graph: Any,
+     index_status: dict[str, Any],
+     profile: Any = None,
+     project_dir: Path | None = None,
+     gemini_api_key: str = "",
+ ) -> None:
+     """Inject shared dependencies from the application lifespan."""
+     global _collection, _graph, _index_status, _profile, _project_dir, _gemini_api_key
+     _collection = collection
+     _graph = graph
+     _index_status = index_status
+     _profile = profile
+     _project_dir = project_dir
+     _gemini_api_key = gemini_api_key
+
+
+ # ---------------------------------------------------------------------------
+ # Language detection helper
+ # ---------------------------------------------------------------------------
+
+ _EXTENSION_TO_LANGUAGE: dict[str, str] = {
+     ".dart": "dart",
+     ".py": "python",
+     ".md": "markdown",
+     ".yaml": "yaml",
+     ".yml": "yaml",
+     ".json": "json",
+     ".ts": "typescript",
+     ".tsx": "typescript",
+     ".js": "javascript",
+     ".jsx": "javascript",
+     ".html": "html",
+     ".css": "css",
+     ".sh": "bash",
+     ".go": "go",
+     ".rs": "rust",
+     ".java": "java",
+     ".kt": "kotlin",
+     ".swift": "swift",
+     ".rb": "ruby",
+     ".php": "php",
+     ".c": "c",
+     ".cpp": "cpp",
+     ".h": "c",
+     ".hpp": "cpp",
+     ".toml": "toml",
+     ".xml": "xml",
+     ".sql": "sql",
+     ".graphql": "graphql",
+     ".proto": "protobuf",
+     ".tf": "hcl",
+     ".dockerfile": "dockerfile",
+ }
+
+
+ def _detect_language(file_path: Path) -> str:
+     name = file_path.name.lower()
+     if name == "dockerfile":
+         return "dockerfile"
+     if name == "makefile":
+         return "makefile"
+     return _EXTENSION_TO_LANGUAGE.get(file_path.suffix.lower(), "plaintext")
+
+
+ # ---------------------------------------------------------------------------
+ # Routes
+ # ---------------------------------------------------------------------------
+
+
+ @router.post("/query")
+ async def query(request: QueryRequest) -> EventSourceResponse:
+     """Stream a RAG response for the given query via Server-Sent Events."""
+     if not _gemini_api_key:
+         raise HTTPException(
+             status_code=503,
+             detail="GEMINI_API_KEY is not configured.",
+         )
+
+     from sourcefire.chain.rag_chain import stream_rag_response
+
+     async def _event_generator() -> AsyncGenerator[dict[str, str], None]:
+         async for chunk in stream_rag_response(
+             collection=_collection,
+             graph=_graph,
+             query=request.query,
+             mode=request.mode,
+             model=request.model,
+             history=request.history,
+             profile=_profile,
+             project_dir=_project_dir,
+             gemini_api_key=_gemini_api_key,
+         ):
+             yield {"data": json.dumps(chunk)}
+
+     return EventSourceResponse(_event_generator())
+
+
+ @router.get("/sources", response_model=SourceResponse)
+ async def sources(path: str = Query(..., description="Relative path within the codebase")) -> SourceResponse:
+     """Return the content and detected language of a source file."""
+     if _project_dir is None:
+         raise HTTPException(status_code=503, detail="Project directory not initialized.")
+
+     codebase_resolved = _project_dir.resolve()
+     full_path = (_project_dir / path).resolve()
+
+     if not full_path.is_relative_to(codebase_resolved):
+         raise HTTPException(status_code=400, detail="Path traversal detected.")
+
+     if not full_path.is_file():
+         raise HTTPException(status_code=404, detail=f"File not found: {path}")
+
+     try:
+         content = full_path.read_text(encoding="utf-8", errors="replace")
+     except OSError as exc:
+         raise HTTPException(status_code=500, detail=f"Could not read file: {exc}") from exc
+
+     return SourceResponse(
+         content=content,
+         language=_detect_language(full_path),
+     )
+
+
+ @router.get("/status", response_model=StatusResponse)
+ async def status() -> StatusResponse:
+     """Return current index status."""
+     return StatusResponse(
+         files_indexed=_index_status.get("files_indexed", 0),
+         last_indexed=str(_index_status.get("last_indexed", "never")),
+         index_status=str(_index_status.get("index_status", "not_ready")),
+         language=str(_index_status.get("language", "generic")),
+     )
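
For reference, a minimal client sketch for the streaming /api/query endpoint. It assumes the app is served on localhost:8000 and uses httpx (not a dependency of this package) to read the stream; sse_starlette frames each yielded dict as a "data: <json>" line:

import asyncio
import json

import httpx

async def consume_query_stream() -> None:
    payload = {"query": "Where is the index built?", "mode": "explain"}
    async with httpx.AsyncClient(timeout=None) as client:
        # Stream the response body line by line and decode each SSE data frame.
        async with client.stream("POST", "http://localhost:8000/api/query", json=payload) as resp:
            async for line in resp.aiter_lines():
                if line.startswith("data: "):
                    print(json.loads(line[len("data: "):]))

asyncio.run(consume_query_stream())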
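
One note on the containment check in sources(): Path.is_relative_to is used rather than a raw string-prefix comparison, because a prefix test wrongly accepts sibling directories that merely share a name prefix. A small sketch of the difference, with hypothetical paths:

from pathlib import Path

root = Path("/srv/project")
evil = Path("/srv/project-evil/secret.txt")

print(str(evil).startswith(str(root)))  # True: the prefix test accepts the sibling directory
print(evil.is_relative_to(root))        # False: the containment test correctly rejects it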
File without changes
@@ -0,0 +1,195 @@
+ """Prompt assembly and token budget management for Sourcefire."""
+
+ from __future__ import annotations
+
+ from pathlib import Path
+ from typing import Any
+
+ from sourcefire.config import MAX_HISTORY_PAIRS, MAX_TOKEN_BUDGET, RESPONSE_HEADROOM
+
+ # ---------------------------------------------------------------------------
+ # System template
+ # ---------------------------------------------------------------------------
+
+ from importlib.resources import files as _resource_files
+
+ _SYSTEM_MD_PATH = Path(str(_resource_files("sourcefire") / "prompts" / "system.md"))
+ _SYSTEM_TEMPLATE: str = _SYSTEM_MD_PATH.read_text(encoding="utf-8")
+
+ # ---------------------------------------------------------------------------
+ # Priority ordering (lower = higher priority)
+ # ---------------------------------------------------------------------------
+
+ _PRIORITY_ORDER: dict[str, int] = {"direct": 0, "semantic": 1, "graph": 2}
+
+ # ---------------------------------------------------------------------------
+ # Per-mode suffixes appended to the system prompt
+ # ---------------------------------------------------------------------------
+
+ _MODE_SUFFIXES: dict[str, str] = {
+     "debug": (
+         "\n\n## Mode: Debug\n"
+         "Diagnose the root cause from the retrieved context first. "
+         "If the stack trace references files not in the context, use tools to read them. "
+         "Trace the call chain. Show exact files and lines. Suggest a minimal fix."
+     ),
+     "feature": (
+         "\n\n## Mode: Feature\n"
+         "Use the retrieved context to identify the project structure and patterns. "
+         "Show where new code should live based on existing conventions you can see. "
+         "If you need to see similar features for reference, use semantic_code_search or find_files_by_name."
+     ),
+     "explain": (
+         "\n\n## Mode: Explain\n"
+         "Explain from the retrieved context first — it already contains the most relevant code. "
+         "Walk through files in dependency order. "
+         "Only use tools if you need to trace deeper connections not visible in the context."
+     ),
+ }
+
+ # ---------------------------------------------------------------------------
+ # Token utilities
+ # ---------------------------------------------------------------------------
+
+
+ def estimate_tokens(text: str) -> int:
+     """Rough token estimate: 1 token ~ 4 characters."""
+     return len(text) // 4
+
+
+ # ---------------------------------------------------------------------------
+ # Chunk truncation with token budget enforcement
+ # ---------------------------------------------------------------------------
+
+ _MAX_CHUNK_CHARS = 6_000  # ~1500 tokens per chunk
+
+
+ def truncate_chunks(
+     chunks: list[dict[str, Any]],
+     max_tokens: int,
+ ) -> list[dict[str, Any]]:
+     """Return a subset of *chunks* that fits within *max_tokens*."""
+     sorted_chunks = sorted(
+         chunks,
+         key=lambda c: (
+             _PRIORITY_ORDER.get(c.get("priority", "graph"), 2),
+             -float(c.get("relevance", 0.0)),
+         ),
+     )
+
+     capped: list[dict[str, Any]] = []
+     for chunk in sorted_chunks:
+         c = dict(chunk)
+         if len(c.get("code", "")) > _MAX_CHUNK_CHARS:
+             c["code"] = c["code"][:_MAX_CHUNK_CHARS]
+         capped.append(c)
+
+     result: list[dict[str, Any]] = []
+     used_tokens = 0
+     for chunk in capped:
+         chunk_tokens = estimate_tokens(chunk.get("code", ""))
+         if used_tokens + chunk_tokens <= max_tokens:
+             result.append(chunk)
+             used_tokens += chunk_tokens
+
+     return result
+
+
+ # ---------------------------------------------------------------------------
+ # Prompt assembly
+ # ---------------------------------------------------------------------------
+
+
+ def assemble_prompt(
+     mode: str,
+     query: str,
+     chunks: list[dict[str, Any]],
+     claude_md: str,
+     memory_content: str,
+     history: list[dict[str, str]],
+     model: str,
+     highlight_language: str = "text",
+ ) -> dict[str, Any]:
+     """Assemble the full prompt dict for the LLM call.
+
+     Returns a dict with keys:
+     - ``system`` — full system prompt
+     - ``context`` — formatted retrieved code chunks
+     - ``query`` — the user's question (unchanged)
+     - ``history`` — trimmed conversation history
+     - ``stats`` — token usage summary dict
+     """
+     # 1. Compute available context budget
+     token_budget = MAX_TOKEN_BUDGET.get(model, 100_000)
+     system_tokens = estimate_tokens(_SYSTEM_TEMPLATE + claude_md + memory_content)
+     query_tokens = estimate_tokens(query)
+     history_tokens = sum(estimate_tokens(m.get("content", "")) for m in history)
+     overhead = system_tokens + query_tokens + history_tokens + RESPONSE_HEADROOM
+     context_budget = max(0, token_budget - overhead)
+
+     # 2. Truncate chunks to fit context budget
+     kept_chunks = truncate_chunks(chunks, max_tokens=context_budget)
+
+     # 3. Build system prompt
+     mode_suffix = _MODE_SUFFIXES.get(mode, "")
+     system_parts = [
+         _SYSTEM_TEMPLATE,
+         "\n\n---\n\n## Project Rules (CLAUDE.md)\n\n",
+         claude_md,
+         "\n\n---\n\n## Developer Memory\n\n",
+         memory_content,
+         mode_suffix,
+     ]
+     system_prompt = "".join(system_parts)
+
+     # 4. Build context block — use the detected language for syntax highlighting
+     context_parts: list[str] = []
+     for chunk in kept_chunks:
+         filename = chunk.get("filename", "unknown")
+         location = chunk.get("location", "")
+         relevance = chunk.get("relevance", 0.0)
+         code = chunk.get("code", "")
+
+         # Detect per-file highlight language from extension
+         ext = Path(filename).suffix.lower()
+         _ext_map = {
+             ".py": "python", ".js": "javascript", ".ts": "typescript",
+             ".tsx": "typescript", ".jsx": "javascript", ".dart": "dart",
+             ".go": "go", ".rs": "rust", ".java": "java", ".md": "markdown",
+             ".yaml": "yaml", ".yml": "yaml", ".json": "json", ".html": "html",
+             ".css": "css", ".sh": "bash", ".toml": "toml",
+         }
+         lang = _ext_map.get(ext, highlight_language)
+
+         header = f"### {filename}"
+         if location:
+             header += f" ({location})"
+         header += f" [relevance: {relevance:.2f}]"
+         context_parts.append(f"{header}\n```{lang}\n{code}\n```")
+
+     context_block = "\n\n".join(context_parts)
+
+     # 5. Trim history to MAX_HISTORY_PAIRS
+     trimmed_history = history[-(MAX_HISTORY_PAIRS * 2):]
+
+     # 6. Stats
+     context_tokens = estimate_tokens(context_block)
+     stats = {
+         "model": model,
+         "token_budget": token_budget,
+         "system_tokens": system_tokens,
+         "context_tokens": context_tokens,
+         "query_tokens": query_tokens,
+         "history_tokens": history_tokens,
+         "total_estimated": system_tokens + context_tokens + query_tokens + history_tokens,
+         "chunks_used": len(kept_chunks),
+         "chunks_dropped": len(chunks) - len(kept_chunks),
+     }
+
+     return {
+         "system": system_prompt,
+         "context": context_block,
+         "query": query,
+         "history": trimmed_history,
+         "stats": stats,
+     }
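
A quick usage sketch for truncate_chunks, with hypothetical chunk dicts, showing that priority (direct, then semantic, then graph) wins over relevance and that the token budget cuts off the tail:

# assumes truncate_chunks from the module above is in scope
chunks = [
    {"code": "x = 1\n" * 200, "priority": "graph", "relevance": 0.9},         # ~300 tokens
    {"code": "def f(): ...\n" * 100, "priority": "direct", "relevance": 0.5},  # ~325 tokens
    {"code": "y = 2\n" * 200, "priority": "semantic", "relevance": 0.7},       # ~300 tokens
]

kept = truncate_chunks(chunks, max_tokens=700)
# "direct" is kept first despite its lower relevance, "semantic" fits in the
# remaining budget, and "graph" is dropped once the 700-token budget is spent.
print([c["priority"] for c in kept])  # ['direct', 'semantic']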
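
And a worked example of the budget arithmetic in assemble_prompt under the 4-characters-per-token heuristic. The MAX_TOKEN_BUDGET and RESPONSE_HEADROOM values below are illustrative, not the package's actual config:

# Hypothetical config: MAX_TOKEN_BUDGET["gemini-2.5-flash"] = 100_000, RESPONSE_HEADROOM = 8_000
system_tokens = 48_000 // 4    # 48k chars of system.md + CLAUDE.md + memory -> 12_000 tokens
query_tokens = 200 // 4        # a 200-char question -> 50 tokens
history_tokens = 20_000 // 4   # prior turns -> 5_000 tokens

overhead = system_tokens + query_tokens + history_tokens + 8_000  # 25_050 tokens
context_budget = max(0, 100_000 - overhead)                       # 74_950 tokens left for chunks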