okb-1.0.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
okb/data/init.sql ADDED
@@ -0,0 +1,92 @@
+ -- Knowledge Base Schema for pgvector
+ CREATE EXTENSION IF NOT EXISTS vector;
+ 
+ -- Main documents table
+ CREATE TABLE documents (
+     id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
+     source_path TEXT NOT NULL,
+     source_type TEXT NOT NULL,  -- 'markdown', 'code', 'pdf', 'note'
+     title TEXT,
+     content TEXT NOT NULL,
+     metadata JSONB DEFAULT '{}',
+     created_at TIMESTAMPTZ DEFAULT NOW(),
+     updated_at TIMESTAMPTZ DEFAULT NOW(),
+     content_hash TEXT NOT NULL,  -- For deduplication/change detection
+     UNIQUE(content_hash)
+ );
+ 
+ -- Chunks for semantic search
+ CREATE TABLE chunks (
+     id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
+     document_id UUID REFERENCES documents(id) ON DELETE CASCADE,
+     chunk_index INTEGER NOT NULL,
+     content TEXT NOT NULL,         -- Original chunk text (for display)
+     embedding_text TEXT NOT NULL,  -- Contextualized text (what was embedded)
+     embedding vector(768),         -- nomic-embed-text-v1.5 dimension
+     token_count INTEGER,
+     metadata JSONB DEFAULT '{}',
+     created_at TIMESTAMPTZ DEFAULT NOW()
+ );
+ 
+ -- Optimized index for similarity search
+ -- Using HNSW for better query performance (slightly slower build than IVFFlat)
+ CREATE INDEX chunks_embedding_idx ON chunks
+     USING hnsw (embedding vector_cosine_ops)
+     WITH (m = 16, ef_construction = 64);
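+ -- Illustrative query (not part of the schema): the nearest-neighbour lookup
+ -- this index accelerates, assuming $1 is bound to a 768-dim query embedding.
+ -- pgvector's <=> operator is cosine distance, so 1 - distance is similarity:
+ --   SELECT id, content, 1 - (embedding <=> $1) AS similarity
+ --   FROM chunks
+ --   ORDER BY embedding <=> $1
+ --   LIMIT 5;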
+ 
+ -- Full-text search as fallback/hybrid
+ CREATE INDEX documents_content_fts ON documents
+     USING gin(to_tsvector('english', content));
+ 
+ CREATE INDEX chunks_content_fts ON chunks
+     USING gin(to_tsvector('english', content));
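+ -- Illustrative keyword query served by these GIN indexes (the quoted search
+ -- terms are placeholders):
+ --   SELECT id, ts_rank(to_tsvector('english', content), q) AS rank
+ --   FROM chunks, websearch_to_tsquery('english', 'pgvector hnsw') AS q
+ --   WHERE to_tsvector('english', content) @@ q
+ --   ORDER BY rank DESC
+ --   LIMIT 5;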
+ 
+ -- Source path index for updates
+ CREATE INDEX documents_source_path_idx ON documents(source_path);
+ 
+ -- Source type index for filtering
+ CREATE INDEX documents_source_type_idx ON documents(source_type);
+ 
+ -- Metadata GIN index for JSONB queries
+ CREATE INDEX documents_metadata_idx ON documents USING gin(metadata);
+ 
+ -- Function to update timestamp
+ CREATE OR REPLACE FUNCTION update_updated_at()
+ RETURNS TRIGGER AS $$
+ BEGIN
+     NEW.updated_at = NOW();
+     RETURN NEW;
+ END;
+ $$ LANGUAGE plpgsql;
+ 
+ CREATE TRIGGER documents_updated_at
+     BEFORE UPDATE ON documents
+     FOR EACH ROW
+     EXECUTE FUNCTION update_updated_at();
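+ -- Illustrative effect (not part of the schema): any UPDATE now refreshes the
+ -- timestamp automatically, e.g.
+ --   UPDATE documents SET title = 'Renamed' WHERE source_path = '/notes/a.md';
+ --   -- updated_at for that row is reset to NOW() by the trigger.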
+ 
+ -- Helper view for search results
+ CREATE VIEW search_results AS
+ SELECT
+     c.id AS chunk_id,
+     c.content,
+     c.chunk_index,
+     c.token_count,
+     c.embedding,
+     d.id AS document_id,
+     d.source_path,
+     d.source_type,
+     d.title,
+     d.metadata
+ FROM chunks c
+ JOIN documents d ON c.document_id = d.id;
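+ -- Illustrative use of the view (again with $1 as a bound query embedding):
+ -- it exposes document fields for filtering without an explicit join:
+ --   SELECT source_path, title, 1 - (embedding <=> $1) AS similarity
+ --   FROM search_results
+ --   WHERE source_type = 'markdown'
+ --   ORDER BY embedding <=> $1
+ --   LIMIT 5;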
+ 
+ -- Stats view
+ CREATE VIEW index_stats AS
+ SELECT
+     source_type,
+     COUNT(DISTINCT d.id) AS document_count,
+     COUNT(c.id) AS chunk_count,
+     SUM(c.token_count) AS total_tokens
+ FROM documents d
+ LEFT JOIN chunks c ON c.document_id = d.id
+ GROUP BY source_type;
okb/http_server.py ADDED
@@ -0,0 +1,463 @@
+ """HTTP transport server for MCP with token authentication.
+ 
+ This module provides an HTTP server that serves the OKB MCP server with
+ token-based authentication. Tokens can be passed via the Authorization header
+ or a query parameter. A single HTTP server can serve multiple databases,
+ with the token determining which database to use.
+ """
+ 
+ from __future__ import annotations
+ 
+ import sys
+ from typing import Any
+ 
+ from mcp.server import Server
+ from mcp.server.sse import SseServerTransport
+ from mcp.types import CallToolResult, TextContent, Tool
+ from starlette.applications import Starlette
+ from starlette.requests import Request
+ from starlette.responses import JSONResponse, Response
+ from starlette.routing import Mount, Route
+ 
+ from .config import config
+ from .local_embedder import warmup
+ from .mcp_server import (
+     KnowledgeBase,
+     format_actionable_items,
+     format_search_results,
+ )
+ from .tokens import OKBTokenVerifier, TokenInfo
+ 
+ # Permission sets
+ READ_ONLY_TOOLS = frozenset(
+     {
+         "search_knowledge",
+         "keyword_search",
+         "hybrid_search",
+         "get_document",
+         "list_sources",
+         "list_projects",
+         "recent_documents",
+         "get_actionable_items",
+         "get_database_info",
+     }
+ )
+ 
+ WRITE_TOOLS = frozenset(
+     {
+         "save_knowledge",
+         "delete_knowledge",
+         "set_database_description",
+         "add_todo",
+     }
+ )
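+ # A token's permissions level gates tool access: call_tool() below rejects
+ # WRITE_TOOLS invocations when token_info.permissions == "ro".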
+ 
+ 
+ def extract_token(request: Request) -> str | None:
+     """Extract token from Authorization header or query parameter."""
+     auth_header = request.headers.get("Authorization", "")
+     if auth_header.startswith("Bearer "):
+         return auth_header[7:]
+     if "token" in request.query_params:
+         return request.query_params["token"]
+     return None
+ 
+ 
+ class HTTPMCPServer:
+     """HTTP server for MCP with token authentication."""
+ 
+     def __init__(self):
+         self.knowledge_bases: dict[str, KnowledgeBase] = {}
+         self.server = Server("knowledge-base")
+         # Single shared transport instance for all connections
+         self.transport = SseServerTransport("/messages/")
+         # Map session_id (hex string) -> token_info
+         self.session_tokens: dict[str, TokenInfo] = {}
+         self._setup_handlers()
+ 
+     def _get_db_url(self, db_name: str) -> str:
+         """Get database URL by name."""
+         return config.get_database(db_name).url
+ 
+     def _setup_handlers(self):
+         """Set up MCP server handlers."""
+ 
+         @self.server.list_tools()
+         async def list_tools() -> list[Tool]:
+             """Define available tools for Claude Code."""
+             # Import the tool definitions from mcp_server
+             from .mcp_server import list_tools as get_tools
+ 
+             return await get_tools()
+ 
+         @self.server.call_tool()
+         async def call_tool(name: str, arguments: dict[str, Any]) -> CallToolResult:
+             """Handle tool invocations with permission checking."""
+             # Get auth context for the current request.
+             # It is set on the server by the transport handlers in create_app().
+             token_info: TokenInfo | None = getattr(self.server, "_current_token_info", None)
+ 
+             if token_info is None:
+                 return CallToolResult(
+                     content=[TextContent(type="text", text="Error: No authentication context")]
+                 )
+ 
+             # Check permissions
+             if name in WRITE_TOOLS and token_info.permissions == "ro":
+                 return CallToolResult(
+                     content=[
+                         TextContent(
+                             type="text",
+                             text=f"Error: Permission denied. Tool '{name}' requires write access.",
+                         )
+                     ]
+                 )
+ 
+             # Get or create knowledge base for this database
+             if token_info.database not in self.knowledge_bases:
+                 db_url = self._get_db_url(token_info.database)
+                 self.knowledge_bases[token_info.database] = KnowledgeBase(db_url)
+ 
+             kb = self.knowledge_bases[token_info.database]
+ 
+             # Execute the tool
+             return await self._execute_tool(kb, name, arguments)
+ 
+     async def _execute_tool(
+         self, kb: KnowledgeBase, name: str, arguments: dict[str, Any]
+     ) -> CallToolResult:
+         """Execute a tool on a specific knowledge base."""
+         try:
+             if name == "search_knowledge":
+                 results = kb.semantic_search(
+                     query=arguments["query"],
+                     limit=arguments.get("limit", 5),
+                     source_type=arguments.get("source_type"),
+                     project=arguments.get("project"),
+                     since=arguments.get("since"),
+                 )
+                 return CallToolResult(
+                     content=[TextContent(type="text", text=format_search_results(results))]
+                 )
+ 
+             elif name == "keyword_search":
+                 results = kb.keyword_search(
+                     query=arguments["query"],
+                     limit=arguments.get("limit", 5),
+                     source_type=arguments.get("source_type"),
+                     since=arguments.get("since"),
+                 )
+                 return CallToolResult(
+                     content=[
+                         TextContent(
+                             type="text", text=format_search_results(results, show_similarity=False)
+                         )
+                     ]
+                 )
+ 
+             elif name == "hybrid_search":
+                 results = kb.hybrid_search(
+                     query=arguments["query"],
+                     limit=arguments.get("limit", 5),
+                     source_type=arguments.get("source_type"),
+                     since=arguments.get("since"),
+                 )
+                 return CallToolResult(
+                     content=[
+                         TextContent(
+                             type="text", text=format_search_results(results, show_similarity=False)
+                         )
+                     ]
+                 )
+ 
+             elif name == "get_document":
+                 doc = kb.get_document(arguments["source_path"])
+                 if not doc:
+                     return CallToolResult(
+                         content=[TextContent(type="text", text="Document not found.")]
+                     )
+                 return CallToolResult(
+                     content=[TextContent(type="text", text=f"# {doc['title']}\n\n{doc['content']}")]
+                 )
+ 
+             elif name == "list_sources":
+                 sources = kb.list_sources()
+                 if not sources:
+                     return CallToolResult(
+                         content=[TextContent(type="text", text="No documents indexed yet.")]
+                     )
+                 output = ["## Indexed Sources\n"]
+                 for s in sources:
+                     tokens = s.get("total_tokens") or 0
+                     output.append(
+                         f"- **{s['source_type']}**: {s['document_count']} documents, "
+                         f"{s['chunk_count']} chunks (~{tokens:,} tokens)"
+                     )
+                 return CallToolResult(content=[TextContent(type="text", text="\n".join(output))])
+ 
+             elif name == "list_projects":
+                 projects = kb.list_projects()
+                 if not projects:
+                     return CallToolResult(
+                         content=[TextContent(type="text", text="No projects found.")]
+                     )
+                 project_list = "\n".join(f"- {p}" for p in projects)
+                 return CallToolResult(
+                     content=[TextContent(type="text", text=f"## Projects\n\n{project_list}")]
+                 )
+ 
+             elif name == "recent_documents":
+                 from .mcp_server import format_relative_time, get_document_date
+ 
+                 docs = kb.get_recent_documents(arguments.get("limit", 10))
+                 if not docs:
+                     return CallToolResult(
+                         content=[TextContent(type="text", text="No documents indexed yet.")]
+                     )
+                 output = ["## Recent Documents\n"]
+                 for d in docs:
+                     project = d["metadata"].get("project", "")
+                     project_str = f" [{project}]" if project else ""
+                     date_str = ""
+                     if doc_date := get_document_date(d["metadata"]):
+                         date_str = f" - {format_relative_time(doc_date)}"
+                     output.append(f"- **{d['title']}**{project_str} ({d['source_type']}){date_str}")
+                     output.append(f"  `{d['source_path']}`")
+                 return CallToolResult(content=[TextContent(type="text", text="\n".join(output))])
+ 
+             elif name == "save_knowledge":
+                 result = kb.save_knowledge(
+                     title=arguments["title"],
+                     content=arguments["content"],
+                     tags=arguments.get("tags"),
+                     project=arguments.get("project"),
+                 )
+                 if result["status"] == "duplicate":
+                     return CallToolResult(
+                         content=[
+                             TextContent(
+                                 type="text",
+                                 text=(
+                                     f"Duplicate content already exists:\n"
+                                     f"- Title: {result['existing_title']}\n"
+                                     f"- Path: `{result['existing_path']}`"
+                                 ),
+                             )
+                         ]
+                     )
+                 return CallToolResult(
+                     content=[
+                         TextContent(
+                             type="text",
+                             text=(
+                                 f"Knowledge saved successfully:\n"
+                                 f"- Title: {result['title']}\n"
+                                 f"- Path: `{result['source_path']}`\n"
+                                 f"- Tokens: ~{result['token_count']}"
+                             ),
+                         )
+                     ]
+                 )
+ 
+             elif name == "delete_knowledge":
+                 deleted = kb.delete_knowledge(arguments["source_path"])
+                 if deleted:
+                     return CallToolResult(
+                         content=[TextContent(type="text", text="Knowledge entry deleted.")]
+                     )
+                 return CallToolResult(
+                     content=[
+                         TextContent(
+                             type="text",
+                             text="Could not delete. Entry not found or not a Claude-saved entry.",
+                         )
+                     ]
+                 )
+ 
+             elif name == "get_actionable_items":
+                 items = kb.get_actionable_items(
+                     item_type=arguments.get("item_type"),
+                     status=arguments.get("status"),
+                     due_date=arguments.get("due_date"),
+                     event_date=arguments.get("event_date"),
+                     min_priority=arguments.get("min_priority"),
+                     limit=arguments.get("limit", 20),
+                 )
+                 return CallToolResult(
+                     content=[TextContent(type="text", text=format_actionable_items(items))]
+                 )
+ 
+             elif name == "get_database_info":
+                 # Get config-based info for the token's database
+                 token_info = getattr(self.server, "_current_token_info", None)
+                 db_config = config.get_database(token_info.database if token_info else None)
+                 info_parts = ["## Knowledge Base Info\n"]
+ 
+                 if db_config.description:
+                     info_parts.append(f"**Description (config):** {db_config.description}")
+                 if db_config.topics:
+                     info_parts.append(f"**Topics (config):** {', '.join(db_config.topics)}")
+ 
+                 # LLM-enhanced metadata
+                 try:
+                     metadata = kb.get_database_metadata()
+                     llm_desc = metadata.get("llm_description", {}).get("value")
+                     llm_topics = metadata.get("llm_topics", {}).get("value")
+                     if llm_desc:
+                         info_parts.append(f"**Description (LLM-enhanced):** {llm_desc}")
+                     if llm_topics:
+                         info_parts.append(f"**Topics (LLM-enhanced):** {', '.join(llm_topics)}")
+                 except Exception:
+                     pass
+ 
+                 sources = kb.list_sources()
+                 if sources:
+                     info_parts.append("\n### Content Statistics")
+                     for s in sources:
+                         tokens = s.get("total_tokens") or 0
+                         info_parts.append(
+                             f"- **{s['source_type']}**: {s['document_count']} documents, "
+                             f"{s['chunk_count']} chunks (~{tokens:,} tokens)"
+                         )
+ 
+                 projects = kb.list_projects()
+                 if projects:
+                     info_parts.append(f"\n### Projects\n{', '.join(projects)}")
+ 
+                 return CallToolResult(
+                     content=[TextContent(type="text", text="\n".join(info_parts))]
+                 )
+ 
+             elif name == "set_database_description":
+                 updated = []
+                 if "description" in arguments:
+                     kb.set_database_metadata("llm_description", arguments["description"])
+                     updated.append("description")
+                 if "topics" in arguments:
+                     kb.set_database_metadata("llm_topics", arguments["topics"])
+                     updated.append("topics")
+                 if updated:
+                     return CallToolResult(
+                         content=[
+                             TextContent(
+                                 type="text",
+                                 text=f"Updated database metadata: {', '.join(updated)}",
+                             )
+                         ]
+                     )
+                 return CallToolResult(
+                     content=[TextContent(type="text", text="No fields provided to update.")]
+                 )
+ 
+             else:
+                 return CallToolResult(
+                     content=[TextContent(type="text", text=f"Unknown tool: {name}")]
+                 )
+ 
+         except Exception as e:
+             return CallToolResult(content=[TextContent(type="text", text=f"Error: {e!s}")])
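+     # Example (illustrative) of a dispatch through the method above; the
+     # argument values are placeholders:
+     #   await self._execute_tool(
+     #       kb, "search_knowledge", {"query": "pgvector HNSW tuning", "limit": 3}
+     #   )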
+ 
+     def create_app(self) -> Starlette:
+         """Create the Starlette application."""
+         verifier = OKBTokenVerifier(self._get_db_url)
+ 
+         async def handle_sse(request: Request) -> Response:
+             """Handle SSE connections for MCP."""
+             # Verify token
+             token = extract_token(request)
+             if not token:
+                 return JSONResponse(
+                     {"error": "Missing token. Use Authorization header or ?token= parameter"},
+                     status_code=401,
+                 )
+ 
+             token_info = verifier.verify(token)
+             if not token_info:
+                 return JSONResponse(
+                     {"error": "Invalid or expired token"},
+                     status_code=401,
+                 )
+ 
+             # Track existing sessions before connecting
+             existing_sessions = set(self.transport._read_stream_writers.keys())
+ 
+             async with self.transport.connect_sse(
+                 request.scope, request.receive, request._send
+             ) as (read_stream, write_stream):
+                 # Find the new session ID by comparing before/after
+                 current_sessions = set(self.transport._read_stream_writers.keys())
+                 new_sessions = current_sessions - existing_sessions
+                 if not new_sessions:
+                     return JSONResponse(
+                         {"error": "Failed to establish session"},
+                         status_code=500,
+                     )
+                 session_id = new_sessions.pop()
+                 session_id_hex = session_id.hex
+ 
+                 # Store token mapping for this session
+                 self.session_tokens[session_id_hex] = token_info
+                 self.server._current_token_info = token_info
+ 
+                 try:
+                     await self.server.run(
+                         read_stream, write_stream, self.server.create_initialization_options()
+                     )
+                 finally:
+                     # Clean up session on disconnect
+                     self.session_tokens.pop(session_id_hex, None)
+ 
+             return Response()
+ 
+         async def handle_messages(scope, receive, send):
+             """Handle POST messages for MCP (raw ASGI handler)."""
+             request = Request(scope, receive)
+ 
+             # Look up session from query params
+             session_id = request.query_params.get("session_id")
+             if not session_id:
+                 response = JSONResponse({"error": "Missing session_id"}, status_code=400)
+                 await response(scope, receive, send)
+                 return
+ 
+             token_info = self.session_tokens.get(session_id)
+             if not token_info:
+                 response = JSONResponse({"error": "Invalid or expired session"}, status_code=401)
+                 await response(scope, receive, send)
+                 return
+ 
+             # Set current token info for tool calls
+             self.server._current_token_info = token_info
+ 
+             await self.transport.handle_post_message(scope, receive, send)
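+         # Illustrative client flow through the two handlers above: the SSE
+         # connection establishes a session, after which JSON-RPC messages are
+         # POSTed back referencing it, e.g.
+         #   POST /messages/?session_id=<hex id from the SSE handshake>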
+ 
+         async def health(request: Request) -> JSONResponse:
+             """Health check endpoint."""
+             return JSONResponse({"status": "ok"})
+ 
+         routes = [
+             Route("/health", health, methods=["GET"]),
+             Route("/sse", handle_sse, methods=["GET"]),
+             Mount("/messages", app=handle_messages),
+         ]
+ 
+         return Starlette(routes=routes)
+ 
+ 
+ def run_http_server(host: str = "127.0.0.1", port: int = 8080):
+     """Run the HTTP MCP server."""
+     import uvicorn
+ 
+     print("Warming up embedding model...", file=sys.stderr)
+     warmup()
+     print("Ready.", file=sys.stderr)
+ 
+     http_server = HTTPMCPServer()
+     app = http_server.create_app()
+ 
+     print(f"Starting HTTP MCP server on http://{host}:{port}", file=sys.stderr)
+     print("  SSE endpoint:      /sse", file=sys.stderr)
+     print("  Messages endpoint: /messages/", file=sys.stderr)
+     print("  Health endpoint:   /health", file=sys.stderr)
+ 
+     uvicorn.run(app, host=host, port=port, log_level="info")
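+ 
+ 
+ # Minimal invocation sketch (an assumption, not part of the module: the server
+ # appears intended to be started via the package's CLI rather than directly):
+ # if __name__ == "__main__":
+ #     run_http_server()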