cognitive-cache 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cognitive_cache/__init__.py +12 -0
- cognitive_cache/api.py +265 -0
- cognitive_cache/baselines/__init__.py +0 -0
- cognitive_cache/baselines/base.py +11 -0
- cognitive_cache/baselines/embedding_select.py +37 -0
- cognitive_cache/baselines/full_stuff.py +19 -0
- cognitive_cache/baselines/grep_select.py +32 -0
- cognitive_cache/baselines/llm_triage.py +42 -0
- cognitive_cache/baselines/random_select.py +26 -0
- cognitive_cache/cli.py +141 -0
- cognitive_cache/core/__init__.py +0 -0
- cognitive_cache/core/chunker.py +155 -0
- cognitive_cache/core/orderer.py +30 -0
- cognitive_cache/core/selector.py +112 -0
- cognitive_cache/core/value_function.py +96 -0
- cognitive_cache/indexer/__init__.py +0 -0
- cognitive_cache/indexer/git_analyzer.py +113 -0
- cognitive_cache/indexer/graph_builder.py +139 -0
- cognitive_cache/indexer/repo_indexer.py +103 -0
- cognitive_cache/indexer/token_counter.py +19 -0
- cognitive_cache/llm/__init__.py +0 -0
- cognitive_cache/llm/adapter.py +37 -0
- cognitive_cache/llm/claude_adapter.py +45 -0
- cognitive_cache/llm/llamacpp_adapter.py +72 -0
- cognitive_cache/llm/openai_adapter.py +45 -0
- cognitive_cache/mcp_server.py +100 -0
- cognitive_cache/models.py +59 -0
- cognitive_cache/py.typed +0 -0
- cognitive_cache/signals/__init__.py +0 -0
- cognitive_cache/signals/base.py +27 -0
- cognitive_cache/signals/change_recency.py +19 -0
- cognitive_cache/signals/embedding_sim.py +63 -0
- cognitive_cache/signals/file_role_prior.py +60 -0
- cognitive_cache/signals/graph_distance.py +41 -0
- cognitive_cache/signals/redundancy.py +29 -0
- cognitive_cache/signals/symbol_overlap.py +32 -0
- cognitive_cache-0.1.1.dist-info/METADATA +263 -0
- cognitive_cache-0.1.1.dist-info/RECORD +41 -0
- cognitive_cache-0.1.1.dist-info/WHEEL +4 -0
- cognitive_cache-0.1.1.dist-info/entry_points.txt +4 -0
- cognitive_cache-0.1.1.dist-info/licenses/LICENSE +190 -0
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
from cognitive_cache.api import RepoIndex, select_context, select_context_from_repo
|
|
2
|
+
from cognitive_cache.models import Source, Task, ScoredSource, SelectionResult
|
|
3
|
+
|
|
4
|
+
__all__ = [
|
|
5
|
+
"RepoIndex",
|
|
6
|
+
"select_context",
|
|
7
|
+
"select_context_from_repo",
|
|
8
|
+
"Source",
|
|
9
|
+
"Task",
|
|
10
|
+
"ScoredSource",
|
|
11
|
+
"SelectionResult",
|
|
12
|
+
]
|
cognitive_cache/api.py
ADDED
|
@@ -0,0 +1,265 @@
|
|
|
1
|
+
"""Public API for Cognitive Cache: RepoIndex and select_context.
|
|
2
|
+
|
|
3
|
+
This module is the single entry point for library consumers. The CLI
|
|
4
|
+
and MCP server are thin wrappers around RepoIndex and select_context;
|
|
5
|
+
everything else (indexing, scoring, selection) is internal machinery.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import os
|
|
9
|
+
import re
|
|
10
|
+
import subprocess
|
|
11
|
+
from dataclasses import dataclass
|
|
12
|
+
|
|
13
|
+
from cognitive_cache.models import Source, Task, SelectionResult
|
|
14
|
+
from cognitive_cache.indexer.repo_indexer import (
|
|
15
|
+
index_repo,
|
|
16
|
+
SOURCE_EXTENSIONS,
|
|
17
|
+
SKIP_DIRS,
|
|
18
|
+
)
|
|
19
|
+
from cognitive_cache.indexer.git_analyzer import GitAnalyzer
|
|
20
|
+
from cognitive_cache.indexer.graph_builder import build_dependency_graph
|
|
21
|
+
from cognitive_cache.signals.embedding_sim import EmbeddingSimilaritySignal
|
|
22
|
+
from cognitive_cache.core.value_function import ValueFunction
|
|
23
|
+
from cognitive_cache.core.selector import GreedySelector
|
|
24
|
+
from cognitive_cache.core.orderer import order_context
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def _extract_task_symbols(
    title: str, body: str, sources: list[Source]
) -> frozenset[str]:
    """Pick out the known symbols that a task description refers to.

    Every symbol declared by ``sources`` is a candidate. A candidate is kept
    when either of these holds (comparison is case-insensitive):

    1. Its name is at least 4 characters long and occurs anywhere in the
       lower-cased ``"{title} {body}"`` text (plain substring containment).
    2. Some identifier-like word of 6+ characters from the task text, that
       is not in the stop-word list, occurs inside the symbol name.

    Shorter task words are ignored because common terms such as 'error' or
    'class' would match far too many symbols.

    Returns:
        The matching symbols, in their original casing.
    """
    haystack = f"{title} {body}".lower()

    # Generic programming vocabulary that is too common to be a useful hint.
    ignored_words = {
        "return",
        "import",
        "should",
        "string",
        "number",
        "before",
        "after",
        "called",
        "values",
        "object",
        "update",
        "create",
        "delete",
        "method",
        "function",
        "default",
        "option",
        "options",
        "config",
        "module",
        "result",
        "response",
        "request",
        "handler",
        "callback",
        "parameter",
    }
    long_words = {
        word
        for word in re.findall(r"\b[a-z_][a-z0-9_]{5,}\b", haystack)
        if word not in ignored_words
    }

    # Union of the symbols declared by every source.
    candidates = set().union(*(src.symbols for src in sources))

    def _mentioned(name: str) -> bool:
        lowered = name.lower()
        # Strategy 1: the symbol itself shows up in the task text.
        if len(lowered) >= 4 and lowered in haystack:
            return True
        # Strategy 2: a long task word is embedded in the symbol name.
        return any(word in lowered for word in long_words)

    return frozenset(name for name in candidates if _mentioned(name))
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
def _find_entry_points(task_symbols: frozenset[str], sources: list[Source]) -> set[str]:
    """Return the paths of all sources that declare at least one task symbol."""
    return {src.path for src in sources if src.symbols & task_symbols}
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
def _get_head_commit(repo_path: str) -> str:
    """Return the output of ``git rev-parse HEAD`` run in ``repo_path``.

    Returns an empty string when git exits with a non-zero status, when the
    call takes longer than 10 seconds, or when the process cannot be started
    at all (for example git is not installed or ``repo_path`` is missing).
    """
    command = ["git", "rev-parse", "HEAD"]
    try:
        completed = subprocess.run(
            command,
            cwd=repo_path,
            capture_output=True,
            text=True,
            timeout=10,
        )
    except (subprocess.TimeoutExpired, OSError):
        # FileNotFoundError is a subclass of OSError, so it is covered here.
        return ""
    return completed.stdout.strip() if completed.returncode == 0 else ""
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
def _collect_mtimes(repo_path: str, sources: list[Source]) -> dict[str, float]:
    """Map each source's path to the modification time of its file on disk.

    Paths are joined onto ``repo_path`` for the lookup. Sources whose file
    cannot be stat'ed (``OSError``) are left out of the result.
    """
    snapshot: dict[str, float] = {}
    for src in sources:
        try:
            stamp = os.path.getmtime(os.path.join(repo_path, src.path))
        except OSError:
            # Best effort: an unreadable or missing file simply has no entry.
            continue
        snapshot[src.path] = stamp
    return snapshot
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
@dataclass
class RepoIndex:
    """Snapshot of an indexed repository used for context selection.

    Create one via ``RepoIndex.build(repo_path)`` and pass it to as many
    ``select_context`` calls as needed. ``refresh()`` compares the snapshot
    against the working tree and returns a rebuilt index only when
    something changed.
    """

    repo_path: str
    sources: list[Source]
    recency_data: dict[str, float]
    graph: object  # DependencyGraph
    embedding_signal: EmbeddingSimilaritySignal
    file_mtimes: dict[str, float]
    head_commit: str

    @classmethod
    def build(cls, repo_path: str) -> "RepoIndex":
        """Index ``repo_path`` from scratch and return the new snapshot.

        Raises:
            FileNotFoundError: if ``repo_path`` does not exist.
        """
        if not os.path.exists(repo_path):
            raise FileNotFoundError(f"Repository path does not exist: {repo_path}")

        indexed = index_repo(repo_path)
        dependency_graph = build_dependency_graph(indexed)
        recency = GitAnalyzer(repo_path).recency_scores()

        embedder = EmbeddingSimilaritySignal()
        if indexed:
            # The embedding signal is only fitted when there is something to fit.
            embedder.fit(indexed)

        mtimes = _collect_mtimes(repo_path, indexed)
        head = _get_head_commit(repo_path)
        return cls(
            repo_path=repo_path,
            sources=indexed,
            recency_data=recency,
            graph=dependency_graph,
            embedding_signal=embedder,
            file_mtimes=mtimes,
            head_commit=head,
        )

    def refresh(self) -> "RepoIndex":
        """Return an up-to-date index, reusing ``self`` when nothing changed.

        The repo counts as changed when HEAD differs from the stored commit
        or when the modification times of the source files found on disk
        differ from the stored snapshot. On change, sources, graph and
        embeddings are rebuilt; git recency data is re-fetched only when HEAD
        moved and is otherwise carried over from this index.
        """
        head_now = _get_head_commit(self.repo_path)
        head_moved = head_now != self.head_commit

        # Re-scan the tree with the same directory/extension filters and
        # record the mtime of every candidate source file.
        observed: dict[str, float] = {}
        for dirpath, dirnames, filenames in os.walk(self.repo_path):
            # In-place assignment prunes skipped directories from the walk.
            dirnames[:] = [name for name in dirnames if name not in SKIP_DIRS]
            for name in filenames:
                if os.path.splitext(name)[1] in SOURCE_EXTENSIONS:
                    absolute = os.path.join(dirpath, name)
                    relative = os.path.relpath(absolute, self.repo_path).replace(
                        "\\", "/"
                    )
                    try:
                        observed[relative] = os.path.getmtime(absolute)
                    except OSError:
                        pass

        if observed == self.file_mtimes and not head_moved:
            return self

        # Something changed, so rebuild the index.
        indexed = index_repo(self.repo_path)
        dependency_graph = build_dependency_graph(indexed)
        embedder = EmbeddingSimilaritySignal()
        if indexed:
            embedder.fit(indexed)

        if head_moved:
            recency = GitAnalyzer(self.repo_path).recency_scores()
        else:
            recency = self.recency_data

        return RepoIndex(
            repo_path=self.repo_path,
            sources=indexed,
            recency_data=recency,
            graph=dependency_graph,
            embedding_signal=embedder,
            file_mtimes=_collect_mtimes(self.repo_path, indexed),
            head_commit=head_now,
        )
|
|
220
|
+
|
|
221
|
+
|
|
222
|
+
def select_context(
    index: RepoIndex,
    task: str | Task,
    budget: int = 12_000,
) -> SelectionResult:
    """Choose the most valuable sources from ``index`` for ``task``.

    A plain-string ``task`` is used as the title of a new ``Task`` with an
    empty body, and its symbols are extracted from that text. A ``Task``
    instance is used as given, including its symbols.

    An index without sources yields an empty result straight away.
    Otherwise the sources are scored and picked by the greedy selector
    within ``budget``, and the picked entries are passed through
    ``order_context`` before the result is returned.
    """
    if not index.sources:
        return SelectionResult(selected=[], total_tokens=0, budget=budget)

    if isinstance(task, str):
        extracted = _extract_task_symbols(task, "", index.sources)
        task = Task(title=task, body="", symbols=extracted)

    value_function = ValueFunction(
        graph=index.graph,
        recency_data=index.recency_data,
        embedding_signal=index.embedding_signal,
        entry_points=_find_entry_points(task.symbols, index.sources),
    )
    selection = GreedySelector(value_function=value_function).select(
        index.sources, task, budget
    )
    selection.selected = order_context(selection.selected)
    return selection
|
|
251
|
+
|
|
252
|
+
|
|
253
|
+
def select_context_from_repo(
    repo_path: str,
    task: str | Task,
    budget: int = 12_000,
) -> SelectionResult:
    """Build a fresh index for ``repo_path`` and select context in one step.

    Intended for one-off queries. When several tasks target the same
    repository, build a ``RepoIndex`` once and call ``select_context`` with
    it instead, so the indexing work is not repeated.
    """
    return select_context(RepoIndex.build(repo_path), task, budget)
|
|
File without changes
|
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
"""Base interface for context selection strategies (baselines + our algorithm)."""
|
|
2
|
+
|
|
3
|
+
from abc import ABC, abstractmethod
|
|
4
|
+
|
|
5
|
+
from cognitive_cache.models import Source, Task, SelectionResult
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class BaselineStrategy(ABC):
    """Abstract interface shared by every context selection strategy."""

    @abstractmethod
    def select(self, sources: list[Source], task: Task, budget: int) -> SelectionResult:
        """Select from ``sources`` for ``task`` within ``budget``; subclasses implement this."""
|
|
@@ -0,0 +1,37 @@
|
|
|
1
|
+
"""Baseline 3: Embedding Similarity (simulates RAG). Top-k by TF-IDF cosine similarity."""
|
|
2
|
+
|
|
3
|
+
from sklearn.feature_extraction.text import TfidfVectorizer
|
|
4
|
+
from sklearn.metrics.pairwise import cosine_similarity
|
|
5
|
+
|
|
6
|
+
from cognitive_cache.models import Source, Task, ScoredSource, SelectionResult
|
|
7
|
+
from cognitive_cache.baselines.base import BaselineStrategy
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class EmbeddingStrategy(BaselineStrategy):
    """Baseline that ranks sources by TF-IDF cosine similarity to the task."""

    def select(self, sources: list[Source], task: Task, budget: int) -> SelectionResult:
        """Take sources in descending similarity order while they fit ``budget``.

        A source that would exceed the remaining budget is skipped, and
        lower-ranked sources are still considered afterwards. An empty
        result is returned when there are no sources or when vectorization
        raises ``ValueError``.
        """
        if not sources:
            return SelectionResult(selected=[], total_tokens=0, budget=budget)

        # Row 0 is the task text; rows 1..n line up with ``sources``.
        documents = [task.full_text]
        documents.extend(src.content for src in sources)
        try:
            tfidf = TfidfVectorizer(
                max_features=5000, token_pattern=r"(?u)\b\w+\b"
            ).fit_transform(documents)
        except ValueError:
            return SelectionResult(selected=[], total_tokens=0, budget=budget)

        scores = cosine_similarity(tfidf[0:1], tfidf[1:])[0]
        # Stable sort: equally similar sources keep their input order.
        order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)

        picked = []
        used = 0
        for position in order:
            candidate = sources[position]
            if used + candidate.token_count <= budget:
                picked.append(
                    ScoredSource(source=candidate, score=float(scores[position]))
                )
                used += candidate.token_count

        return SelectionResult(selected=picked, total_tokens=used, budget=budget)
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
"""Baseline 2: Full Stuff. Cram as many files as fit, in filesystem order."""
|
|
2
|
+
|
|
3
|
+
from cognitive_cache.models import Source, Task, ScoredSource, SelectionResult
|
|
4
|
+
from cognitive_cache.baselines.base import BaselineStrategy
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class FullStuffStrategy(BaselineStrategy):
    """Baseline that packs in files by path order until the budget is full."""

    def select(self, sources: list[Source], task: Task, budget: int) -> SelectionResult:
        """Walk sources sorted by path and keep each one that still fits.

        A source too large for the remaining budget is skipped, and later
        sources are still tried. Every kept source gets a score of 0.0.
        ``task`` is not consulted.
        """
        picked = []
        used = 0
        for candidate in sorted(sources, key=lambda src: src.path):
            if used + candidate.token_count <= budget:
                picked.append(ScoredSource(source=candidate, score=0.0))
                used += candidate.token_count

        return SelectionResult(selected=picked, total_tokens=used, budget=budget)
|
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
"""Baseline 4: Grep Strategy (simulates Claude Code / Copilot symbol search)."""
|
|
2
|
+
|
|
3
|
+
from cognitive_cache.models import Source, Task, ScoredSource, SelectionResult
|
|
4
|
+
from cognitive_cache.baselines.base import BaselineStrategy
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class GrepStrategy(BaselineStrategy):
    """Baseline that ranks sources by how many task keywords they contain."""

    def select(self, sources: list[Source], task: Task, budget: int) -> SelectionResult:
        """Select the sources with the most keyword hits that fit ``budget``.

        Keywords are the lower-cased task symbols, or, when the task has no
        symbols, the whitespace-separated words of its lower-cased full
        text. A keyword counts once per source if it occurs anywhere in the
        lower-cased content. Sources with no hit are dropped. The score of a
        selected source is its hit count divided by the number of keywords.
        """
        if task.symbols:
            needles = {symbol.lower() for symbol in task.symbols}
        else:
            needles = set(task.full_text.lower().split())

        hits = []
        for src in sources:
            haystack = src.content.lower()
            count = sum(needle in haystack for needle in needles)
            if count:
                hits.append((src, count))

        # Stable sort: sources with equal hit counts keep their input order.
        hits.sort(key=lambda pair: pair[1], reverse=True)

        picked = []
        used = 0
        for src, count in hits:
            if used + src.token_count <= budget:
                # max(..., 1) guards the division when there are no keywords.
                picked.append(
                    ScoredSource(source=src, score=count / max(len(needles), 1))
                )
                used += src.token_count

        return SelectionResult(selected=picked, total_tokens=used, budget=budget)
|
|
@@ -0,0 +1,42 @@
|
|
|
1
|
+
"""Baseline 5: LLM Triage (let the brain pick its own context).
|
|
2
|
+
|
|
3
|
+
Two-pass strategy — requires an LLM adapter. Only used during benchmark runs.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from cognitive_cache.models import Source, Task, ScoredSource, SelectionResult
|
|
7
|
+
from cognitive_cache.baselines.base import BaselineStrategy
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class LLMTriageStrategy(BaselineStrategy):
    """Baseline that asks an LLM which files look relevant to the task.

    The model is shown the task title, the task body and the list of all
    file paths, and replies with the paths it considers relevant.
    """

    def __init__(self, llm_adapter=None):
        """Store the adapter; ``select`` calls its ``complete(prompt, max_tokens=...)``."""
        # The adapter may be omitted at construction; select() rejects None.
        self._llm = llm_adapter

    def select(self, sources: list[Source], task: Task, budget: int) -> SelectionResult:
        """Select the sources named in the LLM's reply that fit ``budget``.

        Each reply line is stripped of surrounding whitespace, of leading and
        trailing ``-``/space characters and of backticks, then matched
        exactly against the known source paths. Unknown lines are ignored.
        A path the model repeats is only considered once, so a source is
        never selected or counted against the budget twice. Every selected
        source gets a score of 1.0.

        Raises:
            RuntimeError: if no LLM adapter was supplied.
        """
        if self._llm is None:
            raise RuntimeError("LLMTriageStrategy requires an LLM adapter")

        file_listing = "\n".join(s.path for s in sources)
        triage_prompt = (
            "You are helping fix a bug. Here is the issue:\n\n"
            f"Title: {task.title}\n"
            f"Body: {task.body}\n\n"
            "Here are all the files in the repository:\n\n"
            f"{file_listing}\n\n"
            "Which files are most likely relevant to this issue? "
            "List just the file paths, one per line. Pick at most 20 files."
        )

        response = self._llm.complete(triage_prompt, max_tokens=500)

        source_by_path = {s.path: s for s in sources}
        seen_paths = set()
        selected = []
        total = 0
        for line in response.strip().split("\n"):
            # Tolerate list markers ("- path") and inline-code backticks.
            path = line.strip().strip("- ").strip("`")
            if path in seen_paths:
                continue
            source = source_by_path.get(path)
            if source is None:
                continue
            # The running total never shrinks, so a source that did not fit
            # the first time would not fit on a repeat either.
            seen_paths.add(path)
            if total + source.token_count <= budget:
                selected.append(ScoredSource(source=source, score=1.0))
                total += source.token_count

        return SelectionResult(selected=selected, total_tokens=total, budget=budget)
|
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
"""Baseline 1: Random Selection. Lower bound."""
|
|
2
|
+
|
|
3
|
+
import random as random_module
|
|
4
|
+
|
|
5
|
+
from cognitive_cache.models import Source, Task, ScoredSource, SelectionResult
|
|
6
|
+
from cognitive_cache.baselines.base import BaselineStrategy
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class RandomStrategy(BaselineStrategy):
    """Baseline that picks sources in random order."""

    def __init__(self, seed: int | None = None):
        """Remember ``seed``; it seeds a new generator on every ``select`` call."""
        self._seed = seed

    def select(self, sources: list[Source], task: Task, budget: int) -> SelectionResult:
        """Shuffle a copy of ``sources`` and keep each one that still fits.

        A source too large for the remaining budget is skipped, and later
        sources are still tried. Every kept source gets a score of 0.0. The
        caller's list is not modified, and ``task`` is not consulted.
        """
        order = list(sources)
        random_module.Random(self._seed).shuffle(order)

        picked = []
        used = 0
        for candidate in order:
            if used + candidate.token_count <= budget:
                picked.append(ScoredSource(source=candidate, score=0.0))
                used += candidate.token_count

        return SelectionResult(selected=picked, total_tokens=used, budget=budget)
|
cognitive_cache/cli.py
ADDED
|
@@ -0,0 +1,141 @@
|
|
|
1
|
+
"""CLI entry point for cognitive-cache.
|
|
2
|
+
|
|
3
|
+
Usage:
|
|
4
|
+
cognitive-cache select --repo . --task "fix the login bug" --budget 12000
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import argparse
|
|
8
|
+
import json
|
|
9
|
+
import sys
|
|
10
|
+
|
|
11
|
+
from cognitive_cache.api import RepoIndex, select_context
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def _format_human_readable(result) -> str:
    """Render a selection result as a plain-text report.

    One row per selected file: the path left-aligned to 45 columns, the
    score with three decimals, and the six signal values (one decimal each,
    0.0 when a signal is absent) under short labels. The rows are followed
    by a blank line and a summary of file count and token usage.
    """
    # (signal key, short label) in display order.
    signal_labels = (
        ("symbol_overlap", "sym"),
        ("graph_distance", "graph"),
        ("change_recency", "recency"),
        ("embedding_sim", "embed"),
        ("file_role_prior", "role"),
        ("redundancy", "redund"),
    )

    rows = []
    for entry in result.selected:
        values = entry.signal_scores
        rendered = " ".join(
            f"{label}={values.get(name, 0.0):.1f}" for name, label in signal_labels
        )
        rows.append(f"{entry.source.path:<45} {entry.score:.3f} [{rendered}]")

    rows.append("")
    rows.append(
        f"{len(result.selected)} files selected | "
        f"{result.total_tokens:,} / {result.budget:,} tokens used"
    )
    return "\n".join(rows)
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def _format_json(result) -> str:
    """Serialize a selection result as an indented JSON document.

    The document lists each selected file (path, score and per-signal
    values rounded to four decimals, token count) followed by the total
    tokens, the budget and the remaining budget.
    """
    entries = []
    for entry in result.selected:
        rounded_signals = {
            name: round(value, 4) for name, value in entry.signal_scores.items()
        }
        entries.append(
            {
                "path": entry.source.path,
                "score": round(entry.score, 4),
                "signals": rounded_signals,
                "token_count": entry.source.token_count,
            }
        )

    payload = {
        "files": entries,
        "total_tokens": result.total_tokens,
        "budget": result.budget,
        "budget_remaining": result.budget_remaining,
    }
    return json.dumps(payload, indent=2)
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def _write_context_file(result, output_path: str):
    """Write the content of every selected source to ``output_path``.

    Each source is written as a ``# --- <path> ---`` header line, the
    source's content, and a blank-line separator. An existing file at
    ``output_path`` is overwritten.
    """
    # Explicit UTF-8: source files routinely contain non-ASCII text, and the
    # platform default encoding may fail to encode it or differ per machine.
    with open(output_path, "w", encoding="utf-8") as f:
        for ss in result.selected:
            f.write(f"# --- {ss.source.path} ---\n")
            f.write(ss.source.content)
            f.write("\n\n")
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def main():
    """Run the ``cognitive-cache`` command line interface.

    Only the ``select`` sub-command exists; any other invocation prints the
    help text and exits with status 1. ``select`` builds an index for
    ``--repo``, selects context for ``--task`` within ``--budget`` and
    prints the result as text or, with ``--json``, as JSON. With
    ``--output`` the selected file contents are also written to that file.

    Exit status is 1 when the repository path does not exist or contains no
    source files, and 0 (with a note on stderr) when nothing was selected.
    """
    cli = argparse.ArgumentParser(
        prog="cognitive-cache",
        description="Optimal context selection for LLMs",
    )
    commands = cli.add_subparsers(dest="command")

    select_cmd = commands.add_parser("select", help="Select context files for a task")
    select_cmd.add_argument("--repo", required=True, help="Path to the repository root")
    select_cmd.add_argument("--task", required=True, help="Task description")
    select_cmd.add_argument(
        "--budget", type=int, default=12000, help="Token budget (default: 12000)"
    )
    select_cmd.add_argument(
        "--json", action="store_true", dest="json_output", help="Output as JSON"
    )
    select_cmd.add_argument("--output", help="Write full context to this file")

    args = cli.parse_args()

    if args.command != "select":
        cli.print_help()
        sys.exit(1)

    try:
        index = RepoIndex.build(args.repo)
    except FileNotFoundError as exc:
        print(f"Error: {exc}", file=sys.stderr)
        sys.exit(1)

    if not index.sources:
        print("No source files found.", file=sys.stderr)
        sys.exit(1)

    result = select_context(index, args.task, budget=args.budget)

    if not result.selected:
        print("No files selected.", file=sys.stderr)
        sys.exit(0)

    # Warn if no symbols matched (vague task description)
    has_symbol_hit = any(
        entry.signal_scores.get("symbol_overlap", 0.0) != 0.0
        for entry in result.selected
    )
    if not has_symbol_hit:
        print(
            "Warning: No symbol matches found, results may be less precise.",
            file=sys.stderr,
        )

    render = _format_json if args.json_output else _format_human_readable
    print(render(result))

    if args.output:
        _write_context_file(result, args.output)
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
if __name__ == "__main__":
|
|
141
|
+
main()
|
|
File without changes
|