code-context-control 2.28.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cli/__init__.py +1 -0
- cli/_hook_utils.py +99 -0
- cli/c3.py +6152 -0
- cli/commands/__init__.py +1 -0
- cli/commands/common.py +312 -0
- cli/commands/parser.py +286 -0
- cli/docs.html +3178 -0
- cli/edits.html +878 -0
- cli/hook_auto_snapshot.py +142 -0
- cli/hook_c3_signal.py +61 -0
- cli/hook_c3read.py +116 -0
- cli/hook_edit_ledger.py +213 -0
- cli/hook_edit_unlock.py +170 -0
- cli/hook_filter.py +130 -0
- cli/hook_ghost_files.py +238 -0
- cli/hook_pretool_enforce.py +334 -0
- cli/hook_read.py +200 -0
- cli/hook_session_stats.py +62 -0
- cli/hook_terse_advisor.py +190 -0
- cli/hub.html +3764 -0
- cli/hub_server.py +1619 -0
- cli/mcp_proxy.py +428 -0
- cli/mcp_server.py +660 -0
- cli/server.py +2985 -0
- cli/tools/__init__.py +4 -0
- cli/tools/_helpers.py +65 -0
- cli/tools/agent.py +1165 -0
- cli/tools/compress.py +215 -0
- cli/tools/delegate.py +1184 -0
- cli/tools/edit.py +313 -0
- cli/tools/edits.py +118 -0
- cli/tools/filter.py +285 -0
- cli/tools/impact.py +163 -0
- cli/tools/memory.py +469 -0
- cli/tools/read.py +224 -0
- cli/tools/search.py +337 -0
- cli/tools/session.py +95 -0
- cli/tools/shell.py +193 -0
- cli/tools/status.py +306 -0
- cli/tools/validate.py +310 -0
- cli/ui/api.js +36 -0
- cli/ui/app.js +207 -0
- cli/ui/components/chat.js +758 -0
- cli/ui/components/dashboard.js +689 -0
- cli/ui/components/edits.js +220 -0
- cli/ui/components/instructions.js +481 -0
- cli/ui/components/memory.js +626 -0
- cli/ui/components/sessions.js +606 -0
- cli/ui/components/settings.js +1404 -0
- cli/ui/components/sidebar.js +156 -0
- cli/ui/icons.js +51 -0
- cli/ui/shared.js +119 -0
- cli/ui/theme.js +22 -0
- cli/ui.html +168 -0
- cli/ui_legacy.html +6797 -0
- cli/ui_nano.html +503 -0
- code_context_control-2.28.0.dist-info/METADATA +248 -0
- code_context_control-2.28.0.dist-info/RECORD +150 -0
- code_context_control-2.28.0.dist-info/WHEEL +5 -0
- code_context_control-2.28.0.dist-info/entry_points.txt +4 -0
- code_context_control-2.28.0.dist-info/licenses/LICENSE +201 -0
- code_context_control-2.28.0.dist-info/top_level.txt +5 -0
- core/__init__.py +75 -0
- core/config.py +269 -0
- core/ide.py +188 -0
- oracle/__init__.py +1 -0
- oracle/config.py +75 -0
- oracle/oracle.html +3900 -0
- oracle/oracle_server.py +663 -0
- oracle/services/__init__.py +1 -0
- oracle/services/c3_bridge.py +210 -0
- oracle/services/chat_engine.py +1103 -0
- oracle/services/chat_store.py +155 -0
- oracle/services/cross_memory.py +154 -0
- oracle/services/federated_graph.py +463 -0
- oracle/services/health_checker.py +117 -0
- oracle/services/insight_engine.py +307 -0
- oracle/services/memory_reader.py +106 -0
- oracle/services/memory_writer.py +182 -0
- oracle/services/ollama_bridge.py +332 -0
- oracle/services/project_scanner.py +87 -0
- oracle/services/review_agent.py +206 -0
- services/__init__.py +1 -0
- services/activity_log.py +93 -0
- services/agent_base.py +124 -0
- services/agents.py +1529 -0
- services/auto_memory.py +407 -0
- services/bench/__init__.py +6 -0
- services/bench/external/__init__.py +29 -0
- services/bench/external/aider_polyglot.py +405 -0
- services/bench/external/swe_bench.py +485 -0
- services/benchmark_dashboard.py +596 -0
- services/claude_md.py +785 -0
- services/compressor.py +592 -0
- services/context_snapshot.py +356 -0
- services/conversation_store.py +870 -0
- services/doc_index.py +537 -0
- services/e2e_benchmark.py +2884 -0
- services/e2e_evaluator.py +396 -0
- services/e2e_tasks.py +743 -0
- services/edit_ledger.py +459 -0
- services/embedding_index.py +341 -0
- services/error_reporting.py +123 -0
- services/file_memory.py +734 -0
- services/hub_service.py +585 -0
- services/indexer.py +712 -0
- services/memory.py +318 -0
- services/memory_consolidator.py +538 -0
- services/memory_graph.py +382 -0
- services/memory_grounder.py +304 -0
- services/memory_scorer.py +246 -0
- services/metrics.py +86 -0
- services/notifications.py +209 -0
- services/ollama_client.py +201 -0
- services/output_filter.py +488 -0
- services/parser.py +1238 -0
- services/project_manager.py +579 -0
- services/protocol.py +306 -0
- services/proxy_state.py +152 -0
- services/retrieval_broker.py +129 -0
- services/router.py +414 -0
- services/runtime.py +326 -0
- services/session_benchmark.py +1945 -0
- services/session_manager.py +1026 -0
- services/session_preloader.py +251 -0
- services/text_index.py +90 -0
- services/tool_classifier.py +176 -0
- services/transcript_index.py +340 -0
- services/validation_cache.py +155 -0
- services/vector_store.py +299 -0
- services/version_tracker.py +271 -0
- services/watcher.py +192 -0
- tui/__init__.py +0 -0
- tui/backend.py +59 -0
- tui/main.py +145 -0
- tui/screens/__init__.py +1 -0
- tui/screens/benchmark_view.py +109 -0
- tui/screens/claudemd_view.py +46 -0
- tui/screens/compress_view.py +52 -0
- tui/screens/index_view.py +74 -0
- tui/screens/init_view.py +82 -0
- tui/screens/mcp_view.py +73 -0
- tui/screens/optimize_view.py +41 -0
- tui/screens/pipe_view.py +46 -0
- tui/screens/projects_view.py +355 -0
- tui/screens/search_view.py +55 -0
- tui/screens/session_view.py +143 -0
- tui/screens/stats.py +158 -0
- tui/screens/ui_view.py +54 -0
- tui/theme.tcss +335 -0
|
@@ -0,0 +1,251 @@
|
|
|
1
|
+
"""SessionPreloader — First-prompt auto-retrieval for Local RAG Pipeline.
|
|
2
|
+
|
|
3
|
+
On the first c3_memory(action='recall') in a session, this module automatically
|
|
4
|
+
retrieves relevant doc chunks, code context, and session history, then merges
|
|
5
|
+
them into a pre-context block injected before the normal recall results.
|
|
6
|
+
|
|
7
|
+
This eliminates repeated discovery work across sessions for the same topics.
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
import logging
|
|
11
|
+
import re
|
|
12
|
+
from typing import Optional
|
|
13
|
+
|
|
14
|
+
from core import count_tokens
|
|
15
|
+
|
|
16
|
+
log = logging.getLogger(__name__)
|
|
17
|
+
|
|
18
|
+
# Budget cap for pre-context injection
|
|
19
|
+
_DEFAULT_MAX_PRECONTEXT_TOKENS = 400
|
|
20
|
+
|
|
21
|
+
# Minimum score threshold for doc chunks to be included
|
|
22
|
+
_MIN_DOC_SCORE = 0.05
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class SessionPreloader:
|
|
26
|
+
"""Auto-retrieves relevant project context on first recall of a session."""
|
|
27
|
+
|
|
28
|
+
def __init__(self, doc_index, embedding_index=None, session_mgr=None,
|
|
29
|
+
memory_store=None, config: Optional[dict] = None):
|
|
30
|
+
self.doc_index = doc_index
|
|
31
|
+
self.embedding_index = embedding_index
|
|
32
|
+
self.session_mgr = session_mgr
|
|
33
|
+
self.memory_store = memory_store
|
|
34
|
+
self._config = config or {}
|
|
35
|
+
self._preloaded_sessions: set = set() # session IDs that already got preloaded
|
|
36
|
+
|
|
37
|
+
@property
|
|
38
|
+
def max_tokens(self) -> int:
|
|
39
|
+
return self._config.get("max_precontext_tokens", _DEFAULT_MAX_PRECONTEXT_TOKENS)
|
|
40
|
+
|
|
41
|
+
@property
|
|
42
|
+
def enabled(self) -> bool:
|
|
43
|
+
return self._config.get("enabled", True)
|
|
44
|
+
|
|
45
|
+
def should_preload(self, session_id: str) -> bool:
|
|
46
|
+
"""Check if this session hasn't been preloaded yet."""
|
|
47
|
+
if not self.enabled:
|
|
48
|
+
return False
|
|
49
|
+
if not self.doc_index:
|
|
50
|
+
return False
|
|
51
|
+
if not self.doc_index.chunks:
|
|
52
|
+
return False
|
|
53
|
+
return session_id not in self._preloaded_sessions
|
|
54
|
+
|
|
55
|
+
def preload(self, query: str, session_id: str, top_k: int = 5) -> str:
|
|
56
|
+
"""Generate pre-context for the first recall in a session.
|
|
57
|
+
|
|
58
|
+
Returns a formatted string to prepend to the recall results,
|
|
59
|
+
or empty string if nothing relevant found.
|
|
60
|
+
"""
|
|
61
|
+
if not self.should_preload(session_id):
|
|
62
|
+
return ""
|
|
63
|
+
|
|
64
|
+
self._preloaded_sessions.add(session_id)
|
|
65
|
+
|
|
66
|
+
# Extract expanded signals from the query
|
|
67
|
+
signals = self._extract_signals(query)
|
|
68
|
+
# Skip preloading for simple/short queries — not worth 700+ token injection
|
|
69
|
+
if len(signals) < 3:
|
|
70
|
+
return ""
|
|
71
|
+
|
|
72
|
+
signal_query = " ".join(signals)
|
|
73
|
+
|
|
74
|
+
# Retrieve from doc index
|
|
75
|
+
doc_results = self.doc_index.search(signal_query, top_k=top_k * 2)
|
|
76
|
+
|
|
77
|
+
# Also try embedding-based search if available
|
|
78
|
+
if self.embedding_index and self.embedding_index.ready:
|
|
79
|
+
try:
|
|
80
|
+
embed_results = self.embedding_index.search(query, top_k=3)
|
|
81
|
+
# Convert to comparable format (embed results have different shape)
|
|
82
|
+
for er in embed_results:
|
|
83
|
+
if er.get("content") and er.get("score", 0) > 0.3:
|
|
84
|
+
doc_results.append({
|
|
85
|
+
"id": er.get("chunk_id", er.get("id", "")),
|
|
86
|
+
"doc_id": er.get("doc_id", ""),
|
|
87
|
+
"content": er["content"],
|
|
88
|
+
"tokens": count_tokens(er["content"]),
|
|
89
|
+
"kind": "code",
|
|
90
|
+
"source_type": "code_semantic",
|
|
91
|
+
"priority": 1.0,
|
|
92
|
+
"score": er["score"] * 0.8, # slightly discount code vs docs
|
|
93
|
+
"heading_path": [er.get("doc_id", "")],
|
|
94
|
+
})
|
|
95
|
+
except Exception:
|
|
96
|
+
pass
|
|
97
|
+
|
|
98
|
+
# Get recent session context as weak signal
|
|
99
|
+
session_context = self._get_session_signals()
|
|
100
|
+
|
|
101
|
+
# Rank, deduplicate, and budget-cap
|
|
102
|
+
precontext = self._build_precontext(doc_results, session_context)
|
|
103
|
+
|
|
104
|
+
if not precontext:
|
|
105
|
+
return ""
|
|
106
|
+
|
|
107
|
+
return precontext
|
|
108
|
+
|
|
109
|
+
def _extract_signals(self, query: str) -> list[str]:
|
|
110
|
+
"""Extract retrieval signals from the user's query."""
|
|
111
|
+
signals = []
|
|
112
|
+
|
|
113
|
+
# Direct tokens from query
|
|
114
|
+
tokens = re.findall(r"\w+", query.lower())
|
|
115
|
+
# Filter out very short/common words
|
|
116
|
+
stopwords = {"the", "a", "an", "is", "are", "was", "were", "be", "been",
|
|
117
|
+
"for", "to", "of", "in", "on", "at", "by", "with", "from",
|
|
118
|
+
"and", "or", "not", "this", "that", "it", "as", "do", "does",
|
|
119
|
+
"has", "have", "had", "can", "could", "will", "would", "should",
|
|
120
|
+
"may", "might", "about", "what", "how", "when", "where", "which",
|
|
121
|
+
"who", "all", "any", "some", "no", "my", "your", "our", "their"}
|
|
122
|
+
signals.extend(t for t in tokens if t not in stopwords and len(t) > 1)
|
|
123
|
+
|
|
124
|
+
# Extract file paths mentioned in query
|
|
125
|
+
file_patterns = re.findall(r"[\w/\\]+\.[\w]+", query)
|
|
126
|
+
for fp in file_patterns:
|
|
127
|
+
# Add stem words from file path
|
|
128
|
+
parts = re.split(r"[/\\_.\-]", fp)
|
|
129
|
+
signals.extend(p.lower() for p in parts if len(p) > 1)
|
|
130
|
+
|
|
131
|
+
return list(dict.fromkeys(signals)) # deduplicate preserving order
|
|
132
|
+
|
|
133
|
+
def _get_session_signals(self) -> str:
|
|
134
|
+
"""Get compressed context from recent sessions as weak signals."""
|
|
135
|
+
if not self.session_mgr:
|
|
136
|
+
return ""
|
|
137
|
+
|
|
138
|
+
try:
|
|
139
|
+
return self.session_mgr.get_session_context(n_sessions=2)
|
|
140
|
+
except Exception:
|
|
141
|
+
return ""
|
|
142
|
+
|
|
143
|
+
def _build_precontext(self, doc_results: list, session_context: str) -> str:
|
|
144
|
+
"""Build the final pre-context string within token budget."""
|
|
145
|
+
if not doc_results and not session_context:
|
|
146
|
+
return ""
|
|
147
|
+
|
|
148
|
+
budget = self.max_tokens
|
|
149
|
+
parts = []
|
|
150
|
+
used_tokens = 0
|
|
151
|
+
seen_docs = set()
|
|
152
|
+
|
|
153
|
+
# Header
|
|
154
|
+
header = "[session:pre-context] Auto-retrieved project context"
|
|
155
|
+
used_tokens += count_tokens(header) + 2 # +2 for newlines
|
|
156
|
+
|
|
157
|
+
# Sort by score descending
|
|
158
|
+
doc_results.sort(key=lambda x: x.get("score", 0), reverse=True)
|
|
159
|
+
|
|
160
|
+
# Filter low-score results
|
|
161
|
+
doc_results = [r for r in doc_results if r.get("score", 0) >= _MIN_DOC_SCORE]
|
|
162
|
+
|
|
163
|
+
# Group by source file for cleaner output
|
|
164
|
+
for result in doc_results:
|
|
165
|
+
doc_id = result.get("doc_id", "")
|
|
166
|
+
content = result.get("content", "").strip()
|
|
167
|
+
if not content:
|
|
168
|
+
continue
|
|
169
|
+
|
|
170
|
+
# Deduplicate by doc_id + heading
|
|
171
|
+
dedup_key = f"{doc_id}::{result.get('id', '')}"
|
|
172
|
+
if dedup_key in seen_docs:
|
|
173
|
+
continue
|
|
174
|
+
seen_docs.add(dedup_key)
|
|
175
|
+
|
|
176
|
+
# Format chunk
|
|
177
|
+
source_label = self._source_label(result)
|
|
178
|
+
chunk_text = f"\n## {source_label}\n{content}"
|
|
179
|
+
chunk_tokens = count_tokens(chunk_text)
|
|
180
|
+
|
|
181
|
+
if used_tokens + chunk_tokens > budget:
|
|
182
|
+
# Try to fit a truncated version
|
|
183
|
+
remaining = budget - used_tokens - 20 # margin
|
|
184
|
+
if remaining > 50:
|
|
185
|
+
lines = content.split("\n")
|
|
186
|
+
truncated = []
|
|
187
|
+
t = 0
|
|
188
|
+
for line in lines:
|
|
189
|
+
lt = count_tokens(line)
|
|
190
|
+
if t + lt > remaining:
|
|
191
|
+
break
|
|
192
|
+
truncated.append(line)
|
|
193
|
+
t += lt
|
|
194
|
+
if truncated:
|
|
195
|
+
chunk_text = f"\n## {source_label}\n" + "\n".join(truncated) + "\n..."
|
|
196
|
+
chunk_tokens = count_tokens(chunk_text)
|
|
197
|
+
parts.append(chunk_text)
|
|
198
|
+
used_tokens += chunk_tokens
|
|
199
|
+
break # budget exhausted
|
|
200
|
+
else:
|
|
201
|
+
parts.append(chunk_text)
|
|
202
|
+
used_tokens += chunk_tokens
|
|
203
|
+
|
|
204
|
+
# Add session context if budget remains
|
|
205
|
+
if session_context and used_tokens < budget - 100:
|
|
206
|
+
remaining = budget - used_tokens - 10
|
|
207
|
+
sc_tokens = count_tokens(session_context)
|
|
208
|
+
if sc_tokens > remaining:
|
|
209
|
+
# Truncate session context
|
|
210
|
+
lines = session_context.split("\n")
|
|
211
|
+
truncated = []
|
|
212
|
+
t = 0
|
|
213
|
+
for line in lines:
|
|
214
|
+
lt = count_tokens(line)
|
|
215
|
+
if t + lt > remaining:
|
|
216
|
+
break
|
|
217
|
+
truncated.append(line)
|
|
218
|
+
t += lt
|
|
219
|
+
session_context = "\n".join(truncated)
|
|
220
|
+
|
|
221
|
+
if session_context.strip():
|
|
222
|
+
parts.append(f"\n## Recent Session Context\n{session_context.strip()}")
|
|
223
|
+
|
|
224
|
+
if not parts:
|
|
225
|
+
return ""
|
|
226
|
+
|
|
227
|
+
chunk_count = len([p for p in parts if p.startswith("\n##")])
|
|
228
|
+
total_tokens = sum(count_tokens(p) for p in parts)
|
|
229
|
+
header = f"[session:pre-context] Auto-retrieved project context ({chunk_count} chunks, {total_tokens} tokens)"
|
|
230
|
+
|
|
231
|
+
return header + "\n" + "\n".join(parts) + "\n\n---\n"
|
|
232
|
+
|
|
233
|
+
def _source_label(self, result: dict) -> str:
|
|
234
|
+
"""Generate a human-readable source label for a chunk."""
|
|
235
|
+
doc_id = result.get("doc_id", "unknown")
|
|
236
|
+
source_type = result.get("source_type", "")
|
|
237
|
+
heading_path = result.get("heading_path", [])
|
|
238
|
+
|
|
239
|
+
if source_type == "markdown":
|
|
240
|
+
if len(heading_path) > 1:
|
|
241
|
+
return f"{heading_path[-1]} (from {doc_id})"
|
|
242
|
+
return f"From {doc_id}"
|
|
243
|
+
elif source_type == "docstring":
|
|
244
|
+
name = result.get("id", "").split("::")[-1] if "::" in result.get("id", "") else doc_id
|
|
245
|
+
return f"Docstring: {name} (from {doc_id})"
|
|
246
|
+
elif source_type == "config":
|
|
247
|
+
return f"Config: {doc_id}"
|
|
248
|
+
elif source_type == "code_semantic":
|
|
249
|
+
return f"Related code: {doc_id}"
|
|
250
|
+
else:
|
|
251
|
+
return f"From {doc_id}"
|
services/text_index.py
ADDED
|
@@ -0,0 +1,90 @@
|
|
|
1
|
+
"""Lightweight incremental text index used by local memory stores."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import math
|
|
6
|
+
import re
|
|
7
|
+
from collections import Counter
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class TextIndex:
    """Incremental TF-IDF index over small local document collections.

    Documents are stored by ID together with their per-document term counts;
    a shared document-frequency counter is kept in sync on every add, update,
    and removal so searches never need a full rebuild.
    """

    _CAMEL_BOUNDARY = re.compile(r"([a-z])([A-Z])")
    _TOKEN = re.compile(r"[a-zA-Z0-9]{2,}")

    def __init__(self):
        self._docs: dict[str, str] = {}     # doc_id -> raw text
        self._tf: dict[str, Counter] = {}   # doc_id -> term counts
        self._df: Counter = Counter()       # term -> number of docs containing it

    @staticmethod
    def tokenize(text: str) -> list[str]:
        """Split *text* into lower-case alphanumeric tokens of length >= 2.

        camelCase boundaries are split first. Underscores, hyphens, and every
        other non-alphanumeric character act as separators because the token
        pattern simply does not match them.
        """
        spaced = TextIndex._CAMEL_BOUNDARY.sub(r"\1 \2", text or "")
        return TextIndex._TOKEN.findall(spaced.lower())

    def __len__(self) -> int:
        """Return the number of indexed documents."""
        return len(self._docs)

    def clear(self) -> None:
        """Drop every document and all derived statistics."""
        self._docs.clear()
        self._tf.clear()
        self._df.clear()

    def ids(self) -> list[str]:
        """Return the indexed document IDs in insertion order."""
        return list(self._docs)

    def get_text(self, doc_id: str) -> str:
        """Return the stored text for *doc_id*, or an empty string if unknown."""
        return self._docs.get(doc_id, "")

    def rebuild(self, docs: dict[str, str]) -> None:
        """Replace the whole index with the given ``doc_id -> text`` mapping."""
        self.clear()
        for doc_id, text in docs.items():
            self.add_or_update(doc_id, text)

    def _forget_terms(self, term_counts) -> None:
        """Remove one document's contribution from the document-frequency counter."""
        for token in term_counts:
            self._df[token] -= 1
            if self._df[token] <= 0:
                # Drop exhausted entries so the counter does not accumulate zeros.
                del self._df[token]

    def add_or_update(self, doc_id: str, text: str) -> None:
        """Index *text* under *doc_id*, replacing any previous version."""
        text = text or ""
        previous = self._tf.get(doc_id)
        if previous:
            self._forget_terms(previous)

        term_counts = Counter(self.tokenize(text))
        self._docs[doc_id] = text
        self._tf[doc_id] = term_counts
        for token in term_counts:
            self._df[token] += 1

    def remove(self, doc_id: str) -> None:
        """Remove *doc_id* from the index; unknown IDs are ignored."""
        previous = self._tf.pop(doc_id, None)
        self._docs.pop(doc_id, None)
        if previous:
            self._forget_terms(previous)

    def search(self, query: str, top_k: int = 5) -> list[tuple[str, float]]:
        """Rank documents against *query* by summed TF-IDF.

        Args:
            query: Free text; tokenized the same way as documents.
            top_k: Maximum number of hits; zero or negative yields no hits.

        Returns:
            ``(doc_id, score)`` pairs in descending score order, containing
            only documents that share at least one term with the query.
        """
        if top_k <= 0:
            return []
        query_terms = Counter(self.tokenize(query))
        if not query_terms or not self._docs:
            return []

        total_docs = len(self._docs)
        # IDF depends only on the term, so compute it once per query term.
        idf_by_term = {
            term: math.log((total_docs + 1) / (self._df[term] + 1)) + 1.0
            for term in query_terms
            if term in self._df
        }
        if not idf_by_term:
            return []

        scores: dict[str, float] = {}
        for doc_id, term_counts in self._tf.items():
            total_terms = sum(term_counts.values()) or 1
            score = 0.0
            for term, idf in idf_by_term.items():
                if term not in term_counts:
                    continue
                tf = term_counts[term] / total_terms
                score += tf * idf * query_terms[term]
            if score > 0:
                scores[doc_id] = score

        return sorted(scores.items(), key=lambda item: item[1], reverse=True)[:top_k]
|
|
@@ -0,0 +1,176 @@
|
|
|
1
|
+
"""Tool classifier for the MCP proxy — dynamically filters visible tools.
|
|
2
|
+
|
|
3
|
+
Categorizes C3 tools into the groups defined in CATEGORIES and selects which groups are
|
|
4
|
+
relevant based on recent tool usage, keyword patterns, and optional SLM input.
|
|
5
|
+
"""
|
|
6
|
+
import re
|
|
7
|
+
from typing import Optional
|
|
8
|
+
|
|
9
|
+
from services.ollama_client import OllamaClient
|
|
10
|
+
|
|
11
|
+
# ── Tool Categories ────────────────────────────────────────
|
|
12
|
+
|
|
13
|
+
# Each category maps to:
#   "tools"    - tool names belonging to the category
#   "keywords" - compiled pattern that activates the category when it matches
#                recent text (None means no keyword trigger)
#   "priority" - sort key; lower values come first
CATEGORIES = {
    "core": {
        "tools": [
            "c3_search",
            "c3_compress",
            "c3_validate",
            "c3_filter",
            "c3_session",
            "c3_memory",
            "c3_read",
            "c3_impact",
            "c3_shell",
        ],
        "keywords": None,  # Always included
        "priority": 0,
    },
    "analysis": {
        "tools": ["c3_delegate"],
        "keywords": re.compile(
            r"hybrid|ollama|filter|route|summarize|llm|tier|raw\s*output|delegate",
            flags=re.IGNORECASE,
        ),
        "priority": 1,
    },
    "meta": {
        "tools": ["c3_status"],
        "keywords": re.compile(
            r"token|stats|optimi[sz]e|index|rebuild|notif|budget|context\s*status",
            flags=re.IGNORECASE,
        ),
        "priority": 2,
    },
}
|
|
43
|
+
|
|
44
|
+
# Reverse lookup: tool name -> category
|
|
45
|
+
# Flatten CATEGORIES into a direct tool-name lookup; if a tool were listed in
# more than one category, the category defined last would win.
_TOOL_TO_CATEGORY = {
    _tool: _cat
    for _cat, _info in CATEGORIES.items()
    for _tool in _info["tools"]
}
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
class ToolClassifier:
    """Selects which tool categories are active based on context.

    Categories become active because they are configured as always visible,
    because one of their tools was used recently, because their keyword
    pattern matches recent text, or because the optional SLM suggested them.
    The reason for each active category from the latest ``classify`` call is
    kept in ``classification_reasons``.
    """

    # Characters stripped from each piece of an SLM reply before matching.
    _REPLY_TRIM = " \t\r.:;*-`'\"()[]"

    def __init__(self, always_visible: Optional[list[str]] = None,
                 max_tools: int = 12,
                 use_slm: bool = True,
                 slm_model: str = "gemma3n:latest",
                 ollama: Optional[OllamaClient] = None):
        """Configure the classifier.

        Args:
            always_visible: Category names that are always active; the special
                name ``"all"`` activates every category. Defaults to ``["core"]``.
            max_tools: Upper bound on the number of tools returned by
                ``filter_tools``.
            use_slm: Whether the SLM refinement step may run.
            slm_model: Model name passed to ``ollama.generate``.
            ollama: Client used for the SLM step; the step is skipped without it.
        """
        self.always_visible = always_visible or ["core"]
        self.max_tools = max_tools
        self.use_slm = use_slm
        self.slm_model = slm_model
        self.ollama = ollama
        self.classification_reasons: dict[str, str] = {}

    def classify(self, recent_tool_names: list[str],
                 recent_text: str) -> list[str]:
        """Return the active category names, ordered by category priority.

        Args:
            recent_tool_names: Recently used tool names; only the last five
                are considered. ``None`` is treated as empty.
            recent_text: Recent text scanned for category keywords. ``None``
                is treated as empty.
        """
        # "all" shortcut: every category is always visible
        if "all" in self.always_visible:
            all_cats = sorted(CATEGORIES, key=lambda c: CATEGORIES[c].get("priority", 99))
            self.classification_reasons = {c: "always" for c in all_cats}
            return all_cats

        recent_tool_names = recent_tool_names or []
        recent_text = recent_text or ""

        active = set(self.always_visible)
        reasons: dict[str, str] = {cat: "always" for cat in self.always_visible}

        # Include categories of recently-used tools. Everything already in
        # `active` has a recorded reason, so only new categories need one.
        for name in recent_tool_names[-5:]:
            cat = _TOOL_TO_CATEGORY.get(name)
            if cat and cat not in active:
                reasons[cat] = "recent"
                active.add(cat)

        # Keyword scan over the categories that are still inactive
        for cat_name, cat_info in CATEGORIES.items():
            if cat_name in active:
                continue
            pattern = cat_info["keywords"]
            if pattern and pattern.search(recent_text):
                active.add(cat_name)
                reasons[cat_name] = "keyword"

        # SLM refinement if heuristic is narrow
        if (len(active) <= 2 and self.use_slm
                and self.ollama and recent_text.strip()):
            for cat in self._slm_classify(recent_text, active):
                reasons[cat] = "slm"
                active.add(cat)

        self.classification_reasons = reasons
        return sorted(active, key=lambda c: CATEGORIES.get(c, {}).get("priority", 99))

    def filter_tools(self, all_tools: list[dict],
                     active_categories: list[str]) -> list[dict]:
        """Filter a tools/list response to only include active categories.

        When more than ``max_tools`` tools remain, they are ordered by their
        category's priority and only the first ``max_tools`` are returned.
        Unknown category names are ignored.
        """
        allowed = set()
        for cat in active_categories:
            cat_info = CATEGORIES.get(cat)
            if cat_info:
                allowed.update(cat_info["tools"])

        filtered = [t for t in all_tools if t.get("name") in allowed]

        if len(filtered) > self.max_tools:
            def tool_priority(t):
                cat = _TOOL_TO_CATEGORY.get(t.get("name"), "")
                return CATEGORIES.get(cat, {}).get("priority", 99)

            # Stable sort: tools keep their relative order within a category.
            filtered.sort(key=tool_priority)
            filtered = filtered[:self.max_tools]

        return filtered

    def get_active_tool_count(self, active_categories: list[str]) -> int:
        """Count how many tools would be visible for given categories, capped at ``max_tools``."""
        count = sum(
            len(CATEGORIES[cat]["tools"])
            for cat in active_categories
            if CATEGORIES.get(cat)
        )
        return min(count, self.max_tools)

    # ── SLM Refinement ─────────────────────────────────────

    def _slm_classify(self, text: str, current: set[str]) -> list[str]:
        """Ask SLM which additional categories might be relevant.

        Only categories not already in *current* are offered to the model and
        only those can be returned, so an existing classification reason is
        never overwritten. Reply pieces are separated by commas or newlines
        and surrounding punctuation is ignored. Any failure, an empty reply,
        or a reply containing "NONE" yields an empty list.
        """
        available = [c for c in CATEGORIES if c not in current]
        if not available:
            return []

        prompt = (
            f"Given this context: {text[:200]}\n\n"
            f"Which of these tool categories are relevant? "
            f"Categories: {', '.join(available)}\n"
            f"- core: search, compress, read, filter, validate, session, memory\n"
            f"- analysis: delegate tasks to local LLM\n"
            f"- meta: status, budget, notifications, health\n\n"
            f"Reply with ONLY the category names, comma-separated. "
            f"If none are relevant, reply NONE."
        )

        try:
            reply = self.ollama.generate(
                prompt=prompt,
                model=self.slm_model,
                temperature=0.0,
                max_tokens=50,
            )
            if not reply or "NONE" in reply.upper():
                return []
            named = (
                piece.strip(self._REPLY_TRIM).lower()
                for piece in re.split(r"[,\n]+", reply)
            )
            # dict.fromkeys de-duplicates while keeping the reply's order.
            return [c for c in dict.fromkeys(named) if c in available]
        except Exception:
            # Best effort: the SLM is optional, so fall back to heuristics only.
            return []
|