memctrl 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
memctrl/mcp_server.py ADDED
@@ -0,0 +1,231 @@
1
+ """MemCtrl — MCP server for AI assistant integration.
2
+
3
+ Exposes memory operations as MCP tools:
4
+ memctrl_query — Retrieve relevant memories with trace
5
+ memctrl_add — Store a new memory
6
+ memctrl_trigger — Fire a trigger event
7
+ memctrl_tree — Get full memory tree
8
+ memctrl_audit — Get trigger audit log
9
+
10
+ MCP config for Claude Code:
11
+ {
12
+ "mcpServers": {
13
+ "memctrl": {
14
+ "command": "memctrl",
15
+ "args": ["serve"],
16
+ "env": {}
17
+ }
18
+ }
19
+ }
20
+ """
21
+
22
+ from __future__ import annotations
23
+
24
+ import asyncio
25
+ import json
26
+ import os
27
+ from typing import Any, Dict, List, Optional
28
+
29
+ # ---------------------------------------------------------------------------
30
+ # MCP imports (optional — graceful degradation if mcp not installed)
31
+ # ---------------------------------------------------------------------------
32
+
33
try:
    from mcp.server import Server
    from mcp.server.stdio import stdio_server
    from mcp.types import Tool, TextContent
except ImportError:
    # The optional ``mcp`` dependency is missing. Record that, and define
    # inert stand-ins so the module-level tool definitions below can still
    # be constructed and this module stays importable.
    HAS_MCP = False

    class Tool:  # type: ignore
        """Placeholder for ``mcp.types.Tool``; accepts and ignores keywords."""

        def __init__(self, **kwargs):
            pass

    class TextContent:  # type: ignore
        """Placeholder for ``mcp.types.TextContent``; ignores its arguments."""

        def __init__(self, type="", text=""):
            pass

    class Server:  # type: ignore
        """Placeholder for ``mcp.server.Server``; ignores the server name."""

        def __init__(self, name):
            pass
else:
    # All three imports succeeded: the real MCP classes are in scope.
    HAS_MCP = True
50
+
51
+
52
+ # ---------------------------------------------------------------------------
53
+ # Tool schemas
54
+ # ---------------------------------------------------------------------------
55
+
56
def _object_schema(properties, required=None):
    """Build a JSON-Schema ``object`` description for a tool's input.

    ``required`` is only emitted when given, so tools without mandatory
    arguments get a schema with just ``type`` and ``properties``.
    """
    schema = {"type": "object", "properties": properties}
    if required is not None:
        schema["required"] = required
    return schema


# Tool definitions advertised to MCP clients by the server's list_tools
# handler. Each entry pairs a tool name with the JSON schema of its arguments.
MCP_TOOLS = [
    # Retrieval: natural-language query, optionally restricted to one layer.
    Tool(
        name="memctrl_query",
        description=(
            "Query the memory tree for relevant facts about the project, "
            "session, or user preferences. Returns facts with reasoning trace."
        ),
        inputSchema=_object_schema(
            {
                "query": {"type": "string", "description": "Natural language query"},
                "layer": {
                    "type": "string",
                    "description": "Optional layer filter (project/session/user)",
                },
            },
            required=["query"],
        ),
    ),
    # Storage: content plus the layer it belongs to.
    Tool(
        name="memctrl_add",
        description="Add a memory to the store",
        inputSchema=_object_schema(
            {
                "content": {"type": "string", "description": "Memory content"},
                "layer": {
                    "type": "string",
                    "description": "Target layer",
                    "enum": ["project", "session", "user"],
                },
                "source": {"type": "string", "default": "mcp"},
            },
            required=["content", "layer"],
        ),
    ),
    # Triggers: named event with an optional free-form context object.
    Tool(
        name="memctrl_trigger",
        description="Fire a trigger event (e.g., on_session_end)",
        inputSchema=_object_schema(
            {
                "event": {"type": "string", "description": "Event name"},
                "context": {"type": "object", "default": {}},
            },
            required=["event"],
        ),
    ),
    # Tree dump: takes no arguments.
    Tool(
        name="memctrl_tree",
        description="Get the full memory tree as JSON",
        inputSchema=_object_schema({}),
    ),
    # Audit log: optional cap on the number of entries returned.
    Tool(
        name="memctrl_audit",
        description="Get trigger audit log",
        inputSchema=_object_schema(
            {
                "limit": {"type": "integer", "default": 50},
            }
        ),
    ),
]
113
+
114
+
115
+ # ---------------------------------------------------------------------------
116
+ # Server
117
+ # ---------------------------------------------------------------------------
118
+
119
async def serve_mcp(host: str = "127.0.0.1", port: int = 8080) -> None:
    """Start the MemCtrl MCP server on the stdio transport.

    Registers two handlers on an MCP ``Server`` named ``"memctrl"``: one that
    lists ``MCP_TOOLS`` and one that dispatches tool calls to the memory
    store, tree builder, retriever and rule engine. Every tool reply is a
    single text item whose payload is a JSON document; failures are reported
    as ``{"error": "..."}`` rather than raised.

    Args:
        host: Not used by this function (the transport is stdio); kept so
            existing callers that pass it keep working.
        port: Not used by this function; kept for the same reason.

    Returns:
        None. Returns immediately, after printing an error, when the ``mcp``
        package is not installed.
    """
    if not HAS_MCP:
        import sys

        # Diagnostics go to stderr: for a stdio MCP server, stdout is the
        # protocol channel and must not carry free-form text.
        print(
            "ERROR: MCP package not installed. Install: pip install mcp",
            file=sys.stderr,
        )
        return

    # Imported lazily so that importing this module does not pull in the
    # storage/retrieval stack unless the server is actually started.
    from memctrl.store import MemoryStore
    from memctrl.rules import RuleEngine
    from memctrl.tree import MemoryTreeBuilder
    from memctrl.retriever import MemoryRetriever

    server = Server("memctrl")

    def _reply(payload: dict) -> list[TextContent]:
        """Wrap a JSON-serializable payload as a single text content item."""
        return [TextContent(type="text", text=json.dumps(payload))]

    @server.list_tools()
    async def list_tools() -> list[Tool]:
        return MCP_TOOLS

    @server.call_tool()
    async def call_tool(name: str, arguments: dict) -> list[TextContent]:
        arguments = arguments or {}

        try:
            # A store is opened per call, bound to MEMCTRL_DB_PATH when set.
            # Opening it inside the try turns setup failures into JSON errors.
            # NOTE(review): the store is never explicitly closed here; confirm
            # whether MemoryStore needs cleanup.
            store = MemoryStore(os.environ.get("MEMCTRL_DB_PATH"))

            if name == "memctrl_add":
                mid = store.insert_memory(
                    layer=arguments["layer"],
                    content=arguments["content"],
                    source=arguments.get("source", "mcp"),
                )
                return _reply({"id": mid, "status": "stored"})

            if name == "memctrl_query":
                # Read the required argument first so a malformed request
                # fails before any tree is built.
                query = arguments["query"]
                memories = store.list_memories(arguments.get("layer"))
                if not memories:
                    return _reply({"facts": [], "trace": ["no_memories"]})

                mem_dicts = [m.to_dict() for m in memories]
                memory_lookup = {m.id: m.to_dict() for m in memories}

                tree = await MemoryTreeBuilder().build_tree(mem_dicts)
                result = await MemoryRetriever().retrieve(
                    query,
                    tree.to_dict(),
                    memory_lookup=memory_lookup,
                )
                return _reply(result.to_dict())

            if name == "memctrl_tree":
                mem_dicts = [m.to_dict() for m in store.list_memories()]
                tree = await MemoryTreeBuilder().build_tree(mem_dicts)
                return _reply(tree.to_dict())

            if name == "memctrl_audit":
                logs = store.get_trigger_log(arguments.get("limit", 50))
                return _reply({"logs": [entry.to_dict() for entry in logs]})

            if name == "memctrl_trigger":
                engine = RuleEngine()
                # load() is called before firing, as before; its return
                # value is not needed here.
                engine.load()
                ids = engine.fire_trigger(
                    arguments["event"],
                    arguments.get("context", {}),
                    store,
                )
                return _reply({
                    "status": "fired",
                    "event": arguments["event"],
                    "affected": len(ids),
                })

            # Keep the reply JSON-shaped like every other branch so clients
            # can always parse the payload.
            return _reply({"error": f"Unknown tool: {name}"})

        except Exception as exc:
            # Tool-boundary catch-all: report the failure to the client
            # instead of tearing down the server loop.
            return _reply({"error": str(exc)})

    # stdio_server() takes no server argument; it yields the stream pair that
    # Server.run() consumes.
    async with stdio_server() as (read_stream, write_stream):
        init_options = server.create_initialization_options()
        await server.run(read_stream, write_stream, init_options)
228
+
229
+
230
if __name__ == "__main__":
    # Script entry point: drive the server coroutine on a fresh event loop.
    asyncio.run(serve_mcp())
memctrl/retriever.py ADDED
@@ -0,0 +1,267 @@
1
+ """MemCtrl — PageIndex-style reasoning-based retrieval.
2
+
3
+ Uses LLM to traverse memory tree (titles + summaries only) rather than
4
+ vector similarity. Returns facts WITH reasoning trace.
5
+
6
+ Research: PageIndex (VectifyAI) achieves 98.7% accuracy on FinanceBench
7
+ by replacing vector search with LLM tree traversal. Each retrieval:
8
+ 1. Scan tree structure (titles + summaries)
9
+ 2. LLM reasons which branches are relevant
10
+ 3. Traverse selected branches
11
+ 4. Return facts + full trace
12
+ """
13
+
14
+ from __future__ import annotations
15
+
16
+ import json
17
+ import re
18
+ from dataclasses import dataclass, field
19
+ from datetime import datetime
20
+ from typing import Any, Callable, Coroutine, Dict, List, Optional
21
+
22
+ # Type alias
23
+ LLMCallable = Callable[[str, bool], Coroutine[Any, Any, str]]
24
+
25
+
26
@dataclass
class RetrievalResult:
    """Outcome of one retrieval: the facts found and how they were reached."""

    # Memory contents selected for the query.
    facts: List[str] = field(default_factory=list)
    # Labels describing the path or strategy that produced the facts.
    trace: List[str] = field(default_factory=list)
    # Score attached to the result by the retriever; defaults to 0.0.
    confidence: float = 0.0
    # Source label recorded for each returned fact.
    sources: List[str] = field(default_factory=list)

    def to_dict(self) -> dict:
        """Return the four fields as a plain dict, in declaration order."""
        exported = ("facts", "trace", "confidence", "sources")
        return {key: getattr(self, key) for key in exported}
42
+
43
+
44
class MemoryRetriever:
    """PageIndex-style tree traversal for memory retrieval.

    Algorithm:
    1. Strip leaf memory text from tree (keep structure only)
    2. LLM reads tree titles/summaries, decides which branches relevant
    3. Traverse into selected branches
    4. Collect memory facts from leaves
    5. Return facts + trace showing path taken

    Without an LLM client, or whenever the LLM path cannot produce a usable
    answer, retrieval falls back to keyword scoring.
    """

    def __init__(self, llm_client: Optional[LLMCallable] = None):
        """Store the optional async LLM callable used for tree traversal."""
        self.llm_client = llm_client

    # --- Public API ---

    async def retrieve(
        self,
        query: str,
        tree: dict,
        top_k: int = 5,
        memory_lookup: Optional[Dict[str, dict]] = None,
    ) -> RetrievalResult:
        """Retrieve relevant memories with reasoning trace.

        Args:
            query: Natural language question.
            tree: Memory tree serialized as nested dicts; nodes are read via
                the keys ``id``, ``title``, ``summary``, ``layer``,
                ``memory_ids`` and ``children``.
            top_k: Maximum number of facts to return.
            memory_lookup: Mapping of memory id to memory dict, used to
                resolve ``memory_ids`` to content and source.

        Returns:
            RetrievalResult with facts, trace, confidence and sources. An
            empty tree or lookup yields an empty result traced as
            ``"empty_tree"``.
        """
        if not tree or not memory_lookup:
            return RetrievalResult(facts=[], trace=["empty_tree"], confidence=0.0)

        if self.llm_client:
            return await self._llm_retrieve(query, tree, memory_lookup, top_k)
        return self._keyword_retrieve(query, tree, memory_lookup, top_k)

    # --- LLM-based retrieval ---

    async def _llm_retrieve(
        self,
        query: str,
        tree: dict,
        memory_lookup: Dict[str, dict],
        top_k: int,
    ) -> RetrievalResult:
        """Ask the LLM which nodes are relevant and collect their memories.

        Falls back to keyword retrieval when the LLM call fails, when its
        reply is not a JSON object with a non-empty ``relevant_nodes`` list,
        or when the selected nodes yield no facts.
        """
        # 1. Strip leaf content, keep structure
        stripped = self._strip_leaves(tree)

        # 2. Build prompt and ask LLM
        prompt = self._build_retrieval_prompt(query, stripped)

        try:
            response = await self.llm_client(prompt, json_mode=True)
            parsed = json.loads(response)
        except Exception:
            # Deliberate best-effort: any client or parsing failure degrades
            # to keyword search instead of failing the retrieval.
            return self._keyword_retrieve(query, tree, memory_lookup, top_k)

        # Valid JSON is not necessarily the requested object shape (the model
        # may return a list or a bare string); treat that as a failed answer.
        if not isinstance(parsed, dict):
            return self._keyword_retrieve(query, tree, memory_lookup, top_k)

        relevant_node_ids = parsed.get("relevant_nodes", [])
        if not isinstance(relevant_node_ids, list) or not relevant_node_ids:
            return self._keyword_retrieve(query, tree, memory_lookup, top_k)

        # Only accept a numeric confidence; otherwise use the default and
        # keep the value inside [0, 1].
        confidence = parsed.get("confidence", 0.8)
        if isinstance(confidence, bool) or not isinstance(confidence, (int, float)):
            confidence = 0.8
        confidence = min(max(float(confidence), 0.0), 1.0)

        # 3. Collect memories from selected nodes
        facts, sources = self._collect_from_nodes(
            relevant_node_ids, tree, memory_lookup
        )
        if not facts:
            # The selected ids matched nothing usable (e.g. unknown node ids).
            return self._keyword_retrieve(query, tree, memory_lookup, top_k)

        # 4. Build trace from the titles of the first three selected nodes
        trace = ["root"]
        for nid in relevant_node_ids[:3]:
            node = self._find_node(tree, nid)
            if node:
                trace.append(node.get("title", nid))

        return RetrievalResult(
            facts=facts[:top_k],
            trace=trace,
            confidence=confidence,
            sources=sources[:top_k],
        )

    def _strip_leaves(self, tree: dict) -> dict:
        """Return a copy of the tree without memory ids, keeping structure.

        Each node keeps ``id``, ``title``, ``layer`` and ``summary``; its
        ``memory_ids`` list is replaced by a ``memory_count``.
        """
        return {
            "id": tree.get("id", ""),
            "title": tree.get("title", ""),
            "layer": tree.get("layer", ""),
            "summary": tree.get("summary", ""),
            "memory_count": len(tree.get("memory_ids", [])),
            "children": [self._strip_leaves(c) for c in tree.get("children", [])],
        }

    def _build_retrieval_prompt(self, query: str, stripped_tree: dict) -> str:
        """Build LLM prompt for tree-based retrieval."""
        tree_json = json.dumps(stripped_tree, indent=2)
        return (
            "You are a memory retrieval expert. Given a user query and a "
            "hierarchical memory tree, identify which tree nodes are most "
            "likely to contain relevant information.\n\n"
            f"Query: {query}\n\n"
            "Memory Tree Structure:\n"
            f"{tree_json}\n\n"
            "Return ONLY JSON in this exact format:\n"
            '{\n'
            '  "thinking": "reason about which branches are relevant",\n'
            '  "relevant_nodes": ["node_id_1", "node_id_2"],\n'
            '  "confidence": 0.9\n'
            '}'
        )

    def _collect_from_nodes(
        self,
        node_ids: List[str],
        tree: dict,
        memory_lookup: Dict[str, dict],
    ) -> tuple[List[str], List[str]]:
        """Collect facts and sources from the given nodes and their subtrees.

        Nodes are visited in the order given, each subtree in pre-order. A
        memory is emitted at most once even when the selected nodes overlap
        (for example a parent and one of its children).

        Returns:
            ``(facts, sources)`` as parallel lists; a memory without a
            ``source`` is reported as ``"unknown"``.
        """
        facts: List[str] = []
        sources: List[str] = []
        seen: set = set()

        def gather(node: dict) -> None:
            # Direct memories of this node first, then its descendants.
            for mid in node.get("memory_ids", []):
                if mid in seen:
                    continue
                mem = memory_lookup.get(mid)
                if mem and mem.get("content"):
                    seen.add(mid)
                    facts.append(mem["content"])
                    sources.append(mem.get("source", "unknown"))
            for child in node.get("children", []):
                gather(child)

        for nid in node_ids:
            node = self._find_node(tree, nid)
            if node:
                gather(node)

        return facts, sources

    def _find_node(self, tree: dict, node_id: str) -> Optional[dict]:
        """Depth-first search for the node whose ``id`` equals ``node_id``."""
        if tree.get("id") == node_id:
            return tree
        for child in tree.get("children", []):
            found = self._find_node(child, node_id)
            if found:
                return found
        return None

    # --- Keyword fallback (no LLM) ---

    def _keyword_retrieve(
        self,
        query: str,
        tree: dict,
        memory_lookup: Dict[str, dict],
        top_k: int,
    ) -> RetrievalResult:
        """Fallback: score memories by keyword overlap with the query.

        Query words of three or more characters are matched as substrings
        against the node title (weight 3), node summary (weight 2) and the
        memory content (weight 1). Only memories with at least one match are
        kept; a depth bonus of ``1 / (depth + 1)`` is then added so shallower
        nodes win among otherwise similar matches.
        """
        query_words = set(re.findall(r"\b\w{3,}\b", query.lower()))
        if not query_words:
            return RetrievalResult(facts=[], trace=["no_keywords"], confidence=0.0)

        # mem_id -> (score, content, source); best score per memory wins.
        scored_memories: Dict[str, tuple[float, str, str]] = {}

        def hits(text: Optional[str]) -> int:
            # Tolerate missing/None fields instead of failing on .lower().
            lowered = (text or "").lower()
            return sum(1 for w in query_words if w in lowered)

        def score_node(node: dict, depth: int = 0) -> None:
            node_score = hits(node.get("title")) * 3 + hits(node.get("summary")) * 2
            depth_bonus = 1.0 / (depth + 1)

            for mid in node.get("memory_ids", []):
                mem = memory_lookup.get(mid)
                if not mem:
                    continue
                content = mem.get("content") or ""
                if not content:
                    # Nothing to return as a fact.
                    continue
                keyword_score = node_score + hits(content)
                if keyword_score <= 0:
                    # The depth bonus alone must not turn a non-match into a
                    # result; it only ranks real matches.
                    continue
                total = keyword_score + depth_bonus
                if total > scored_memories.get(mid, (0, "", ""))[0]:
                    scored_memories[mid] = (total, content, mem.get("source") or "")

            for child in node.get("children", []):
                score_node(child, depth + 1)

        score_node(tree)

        # Rank on score, then content; the source is never compared, so
        # mixed or missing source values cannot break the sort.
        ranked = sorted(
            scored_memories.values(),
            key=lambda item: (item[0], item[1]),
            reverse=True,
        )
        top = ranked[:top_k]

        if not top:
            return RetrievalResult(
                facts=[], trace=["root", "no_match"], confidence=0.0
            )

        facts = [item[1] for item in top]
        sources = [item[2] for item in top]
        avg_score = sum(item[0] for item in top) / len(top)
        confidence = min(avg_score / 10, 1.0)  # Normalize

        # Simple trace: strategy marker plus a preview of the best fact.
        trace = ["root", "keyword_search", facts[0][:30]]

        return RetrievalResult(
            facts=facts, trace=trace, confidence=round(confidence, 2),
            sources=sources,
        )