okb 1.1.0a0__py3-none-any.whl → 1.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
okb/http_server.py CHANGED
@@ -1,9 +1,15 @@
1
1
  """HTTP transport server for MCP with token authentication.
2
2
 
3
- This module provides an HTTP server that serves the LKB MCP server with
4
- token-based authentication. Tokens can be passed via Authorization header
5
- or query parameter. A single HTTP server can serve multiple databases,
6
- with the token determining which database to use.
3
+ This module provides an HTTP server that serves the OKB MCP server with
4
+ token-based authentication using Streamable HTTP transport. Tokens can be
5
+ passed via Authorization header or query parameter. A single HTTP server
6
+ can serve multiple databases, with the token determining which database to use.
7
+
8
+ Transport: Streamable HTTP (as defined by the MCP specification, 2025-03-26 revision)
9
+ - POST /mcp → send JSON-RPC messages, get SSE response
10
+ - GET /mcp → optional standalone SSE for server notifications
11
+ - DELETE /mcp → terminate session
12
+ - Session ID in Mcp-Session-Id header
7
13
  """
8
14
 
9
15
  from __future__ import annotations
@@ -12,12 +18,11 @@ import sys
12
18
  from typing import Any
13
19
 
14
20
  from mcp.server import Server
15
- from mcp.server.sse import SseServerTransport
21
+ from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
16
22
  from mcp.types import CallToolResult, TextContent, Tool
17
- from starlette.applications import Starlette
23
+ from starlette.middleware.cors import CORSMiddleware
18
24
  from starlette.requests import Request
19
- from starlette.responses import JSONResponse, Response
20
- from starlette.routing import Mount, Route
25
+ from starlette.responses import JSONResponse
21
26
 
22
27
  from .config import config
23
28
  from .local_embedder import warmup
@@ -37,9 +42,16 @@ READ_ONLY_TOOLS = frozenset(
37
42
  "get_document",
38
43
  "list_sources",
39
44
  "list_projects",
45
+ "list_documents_by_project",
46
+ "get_project_stats",
40
47
  "recent_documents",
41
48
  "get_actionable_items",
42
49
  "get_database_info",
50
+ "list_sync_sources",
51
+ "list_pending_entities",
52
+ "list_pending_merges",
53
+ "get_topic_clusters",
54
+ "get_entity_relationships",
43
55
  }
44
56
  )
45
57
 
@@ -51,6 +63,17 @@ WRITE_TOOLS = frozenset(
51
63
  "add_todo",
52
64
  "trigger_sync",
53
65
  "trigger_rescan",
66
+ "enrich_document",
67
+ "approve_entity",
68
+ "reject_entity",
69
+ "analyze_knowledge_base",
70
+ "find_entity_duplicates",
71
+ "merge_entities",
72
+ "approve_merge",
73
+ "reject_merge",
74
+ "run_consolidation",
75
+ "rename_project",
76
+ "set_document_project",
54
77
  }
55
78
  )
56
79
 
@@ -66,14 +89,14 @@ def extract_token(request: Request) -> str | None:
66
89
 
67
90
 
68
91
  class HTTPMCPServer:
69
- """HTTP server for MCP with token authentication."""
92
+ """HTTP server for MCP with token authentication using Streamable HTTP transport."""
70
93
 
71
94
  def __init__(self):
72
95
  self.knowledge_bases: dict[str, KnowledgeBase] = {}
73
96
  self.server = Server("knowledge-base")
74
- # Single shared transport instance for all connections
75
- self.transport = SseServerTransport("/messages/")
76
- # Map session_id (hex string) -> token_info
97
+ # Session manager handles all transport complexity
98
+ self.session_manager = StreamableHTTPSessionManager(app=self.server)
99
+ # Map mcp-session-id -> token_info
77
100
  self.session_tokens: dict[str, TokenInfo] = {}
78
101
  self._setup_handlers()
79
102
 
@@ -208,6 +231,83 @@ class HTTPMCPServer:
208
231
  content=[TextContent(type="text", text=f"## Projects\n\n{project_list}")]
209
232
  )
210
233
 
234
+ elif name == "list_documents_by_project":
235
+ project = arguments["project"]
236
+ limit = arguments.get("limit", 100)
237
+ docs = kb.list_documents_by_project(project, limit)
238
+ if not docs:
239
+ return CallToolResult(
240
+ content=[
241
+ TextContent(
242
+ type="text", text=f"No documents found for project '{project}'."
243
+ )
244
+ ]
245
+ )
246
+ output = [f"## Documents in '{project}' ({len(docs)} documents)\n"]
247
+ for d in docs:
248
+ output.append(f"- **{d['title'] or d['source_path']}** ({d['source_type']})")
249
+ output.append(f" - `{d['source_path']}`")
250
+ return CallToolResult(content=[TextContent(type="text", text="\n".join(output))])
251
+
252
+ elif name == "get_project_stats":
253
+ stats = kb.get_project_stats()
254
+ if not stats:
255
+ return CallToolResult(
256
+ content=[TextContent(type="text", text="No projects found.")]
257
+ )
258
+ output = ["## Project Statistics\n"]
259
+ for p in stats:
260
+ output.append(f"- **{p['project']}**: {p['document_count']} documents")
261
+ return CallToolResult(content=[TextContent(type="text", text="\n".join(output))])
262
+
263
+ elif name == "rename_project":
264
+ old_name = arguments["old_name"]
265
+ new_name = arguments["new_name"]
266
+ if old_name == new_name:
267
+ return CallToolResult(
268
+ content=[TextContent(type="text", text="Old and new names are the same.")]
269
+ )
270
+ count = kb.rename_project(old_name, new_name)
271
+ if count == 0:
272
+ return CallToolResult(
273
+ content=[
274
+ TextContent(
275
+ type="text", text=f"No documents found with project '{old_name}'."
276
+ )
277
+ ]
278
+ )
279
+ return CallToolResult(
280
+ content=[
281
+ TextContent(
282
+ type="text",
283
+ text=f"Renamed project '{old_name}' to '{new_name}' "
284
+ f"({count} documents updated).",
285
+ )
286
+ ]
287
+ )
288
+
289
+ elif name == "set_document_project":
290
+ source_path = arguments["source_path"]
291
+ project = arguments.get("project")
292
+ success = kb.set_document_project(source_path, project)
293
+ if not success:
294
+ return CallToolResult(
295
+ content=[
296
+ TextContent(type="text", text=f"Document not found: {source_path}")
297
+ ]
298
+ )
299
+ if project:
300
+ return CallToolResult(
301
+ content=[
302
+ TextContent(
303
+ type="text", text=f"Set project to '{project}' for {source_path}"
304
+ )
305
+ ]
306
+ )
307
+ return CallToolResult(
308
+ content=[TextContent(type="text", text=f"Cleared project for {source_path}")]
309
+ )
310
+
211
311
  elif name == "recent_documents":
212
312
  from .mcp_server import format_relative_time, get_document_date
213
313
 
@@ -265,13 +365,13 @@ class HTTPMCPServer:
265
365
  deleted = kb.delete_knowledge(arguments["source_path"])
266
366
  if deleted:
267
367
  return CallToolResult(
268
- content=[TextContent(type="text", text="Knowledge entry deleted.")]
368
+ content=[TextContent(type="text", text="Document deleted.")]
269
369
  )
270
370
  return CallToolResult(
271
371
  content=[
272
372
  TextContent(
273
373
  type="text",
274
- text="Could not delete. Entry not found or not a Claude-saved entry.",
374
+ text="Could not delete. Document not found.",
275
375
  )
276
376
  ]
277
377
  )
@@ -394,6 +494,134 @@ class HTTPMCPServer:
394
494
  )
395
495
  return CallToolResult(content=[TextContent(type="text", text=result)])
396
496
 
497
+ elif name == "list_sync_sources":
498
+ from .mcp_server import _list_sync_sources
499
+
500
+ token_info = getattr(self.server, "_current_token_info", None)
501
+ db_name = token_info.database if token_info else config.get_database().name
502
+ result = _list_sync_sources(kb.db_url, db_name)
503
+ return CallToolResult(content=[TextContent(type="text", text=result)])
504
+
505
+ elif name == "enrich_document":
506
+ from .mcp_server import _enrich_document
507
+
508
+ result = _enrich_document(
509
+ kb.db_url,
510
+ source_path=arguments["source_path"],
511
+ extract_todos=arguments.get("extract_todos", True),
512
+ extract_entities=arguments.get("extract_entities", True),
513
+ auto_create_entities=arguments.get("auto_create_entities", False),
514
+ )
515
+ return CallToolResult(content=[TextContent(type="text", text=result)])
516
+
517
+ elif name == "list_pending_entities":
518
+ from .mcp_server import _list_pending_entities
519
+
520
+ result = _list_pending_entities(
521
+ kb.db_url,
522
+ entity_type=arguments.get("entity_type"),
523
+ limit=arguments.get("limit", 20),
524
+ )
525
+ return CallToolResult(content=[TextContent(type="text", text=result)])
526
+
527
+ elif name == "approve_entity":
528
+ from .mcp_server import _approve_entity
529
+
530
+ result = _approve_entity(kb.db_url, arguments["pending_id"])
531
+ return CallToolResult(content=[TextContent(type="text", text=result)])
532
+
533
+ elif name == "reject_entity":
534
+ from .mcp_server import _reject_entity
535
+
536
+ result = _reject_entity(kb.db_url, arguments["pending_id"])
537
+ return CallToolResult(content=[TextContent(type="text", text=result)])
538
+
539
+ elif name == "analyze_knowledge_base":
540
+ from .mcp_server import _analyze_knowledge_base
541
+
542
+ result = _analyze_knowledge_base(
543
+ kb.db_url,
544
+ project=arguments.get("project"),
545
+ sample_size=arguments.get("sample_size", 15),
546
+ auto_update=arguments.get("auto_update", True),
547
+ )
548
+ return CallToolResult(content=[TextContent(type="text", text=result)])
549
+
550
+ # Entity consolidation tools
551
+ elif name == "find_entity_duplicates":
552
+ from .mcp_server import _find_entity_duplicates
553
+
554
+ result = _find_entity_duplicates(
555
+ kb.db_url,
556
+ similarity_threshold=arguments.get("similarity_threshold", 0.85),
557
+ limit=arguments.get("limit", 50),
558
+ )
559
+ return CallToolResult(content=[TextContent(type="text", text=result)])
560
+
561
+ elif name == "merge_entities":
562
+ from .mcp_server import _merge_entities
563
+
564
+ result = _merge_entities(
565
+ kb.db_url,
566
+ canonical_path=arguments["canonical_path"],
567
+ duplicate_path=arguments["duplicate_path"],
568
+ )
569
+ return CallToolResult(content=[TextContent(type="text", text=result)])
570
+
571
+ elif name == "list_pending_merges":
572
+ from .mcp_server import _list_pending_merges
573
+
574
+ result = _list_pending_merges(
575
+ kb.db_url,
576
+ limit=arguments.get("limit", 50),
577
+ )
578
+ return CallToolResult(content=[TextContent(type="text", text=result)])
579
+
580
+ elif name == "approve_merge":
581
+ from .mcp_server import _approve_merge
582
+
583
+ result = _approve_merge(kb.db_url, arguments["merge_id"])
584
+ return CallToolResult(content=[TextContent(type="text", text=result)])
585
+
586
+ elif name == "reject_merge":
587
+ from .mcp_server import _reject_merge
588
+
589
+ result = _reject_merge(kb.db_url, arguments["merge_id"])
590
+ return CallToolResult(content=[TextContent(type="text", text=result)])
591
+
592
+ elif name == "get_topic_clusters":
593
+ from .mcp_server import _get_topic_clusters
594
+
595
+ result = _get_topic_clusters(
596
+ kb.db_url,
597
+ limit=arguments.get("limit", 20),
598
+ )
599
+ return CallToolResult(content=[TextContent(type="text", text=result)])
600
+
601
+ elif name == "get_entity_relationships":
602
+ from .mcp_server import _get_entity_relationships
603
+
604
+ result = _get_entity_relationships(
605
+ kb.db_url,
606
+ entity_name=arguments.get("entity_name"),
607
+ relationship_type=arguments.get("relationship_type"),
608
+ limit=arguments.get("limit", 50),
609
+ )
610
+ return CallToolResult(content=[TextContent(type="text", text=result)])
611
+
612
+ elif name == "run_consolidation":
613
+ from .mcp_server import _run_consolidation
614
+
615
+ result = _run_consolidation(
616
+ kb.db_url,
617
+ detect_duplicates=arguments.get("detect_duplicates", True),
618
+ detect_cross_doc=arguments.get("detect_cross_doc", True),
619
+ build_clusters=arguments.get("build_clusters", True),
620
+ extract_relationships=arguments.get("extract_relationships", True),
621
+ dry_run=arguments.get("dry_run", False),
622
+ )
623
+ return CallToolResult(content=[TextContent(type="text", text=result)])
624
+
397
625
  else:
398
626
  return CallToolResult(
399
627
  content=[TextContent(type="text", text=f"Unknown tool: {name}")]
@@ -402,95 +630,131 @@ class HTTPMCPServer:
402
630
  except Exception as e:
403
631
  return CallToolResult(content=[TextContent(type="text", text=f"Error: {e!s}")])
404
632
 
405
def create_app(self):
    """Create the ASGI application that serves MCP over Streamable HTTP.

    Routing (a trailing slash is tolerated):
        ``/mcp``, ``/sse``  -> token-authenticated MCP handler (GET, POST, DELETE)
        ``/health``         -> ``{"status": "ok"}``
        any other HTTP path -> 404 JSON
    WebSocket connections are refused.

    The app also implements the ASGI lifespan protocol, so that
    ``self.session_manager.run()`` stays active for the lifetime of the server.
    It is wrapped in Starlette's ``CORSMiddleware`` for browser clients.

    Returns:
        The ASGI callable (a ``CORSMiddleware`` instance) to hand to uvicorn.
    """
    verifier = OKBTokenVerifier(self._get_db_url)
    session_header_name = "mcp-session-id"
    session_header_key = session_header_name.encode()
    # RFC 6750 section 3: a 401 for a bearer-protected resource carries a
    # WWW-Authenticate challenge so clients know which scheme to use.
    bearer_challenge = {"WWW-Authenticate": "Bearer"}

    async def send_json(scope, receive, send, payload, status_code, headers=None):
        """Send a one-shot JSON response on a raw ASGI HTTP connection."""
        response = JSONResponse(payload, status_code=status_code, headers=headers)
        await response(scope, receive, send)

    async def handle_mcp(scope, receive, send):
        """Authenticate an MCP request, bind it to its session, and delegate it."""
        request = Request(scope, receive)

        token = extract_token(request)
        if not token:
            await send_json(
                scope,
                receive,
                send,
                {"error": "Missing token. Use Authorization header or ?token= param"},
                401,
                bearer_challenge,
            )
            return

        token_info = verifier.verify(token)
        if not token_info:
            await send_json(
                scope,
                receive,
                send,
                {"error": "Invalid or expired token"},
                401,
                bearer_challenge,
            )
            return

        # A request for an existing session must present the same token that
        # created it. Compare by hash and database, not by object identity.
        # Otherwise one caller could drive another caller's session.
        session_id = request.headers.get(session_header_name)
        existing_token = self.session_tokens.get(session_id) if session_id else None
        if existing_token and (
            existing_token.token_hash != token_info.token_hash
            or existing_token.database != token_info.database
        ):
            await send_json(
                scope,
                receive,
                send,
                {"error": "Token mismatch for session"},
                401,
                bearer_challenge,
            )
            return

        # NOTE(review): this is one attribute on the single shared Server, and
        # every request overwrites it. If tool calls from different sessions
        # run concurrently, a call can observe another client's token (and
        # therefore another database). The tool handler should resolve the
        # token per session/request instead. TODO.
        self.server._current_token_info = token_info

        response_status = None

        async def send_wrapper(message):
            """Record the response status and bind newly issued session IDs."""
            nonlocal response_status
            if message["type"] == "http.response.start":
                response_status = message.get("status")
                for name, value in message.get("headers", []):
                    header_name = (
                        name.lower() if isinstance(name, bytes) else name.lower().encode()
                    )
                    if header_name != session_header_key:
                        continue
                    issued_id = value.decode() if isinstance(value, bytes) else value
                    # Bind before forwarding the headers. An SSE response keeps
                    # the connection open, so the client may issue its next
                    # request before this one finishes.
                    self.session_tokens.setdefault(issued_id, token_info)
                    break
            await send(message)

        await self.session_manager.handle_request(scope, receive, send_wrapper)

        # Drop the token binding once the session is gone. That is either a
        # successful DELETE, or a 404, which the MCP spec uses for
        # terminated/unknown sessions. Without this, session_tokens grows
        # without bound.
        if session_id and response_status is not None:
            deleted = scope["method"] == "DELETE" and 200 <= response_status < 300
            if deleted or response_status == 404:
                self.session_tokens.pop(session_id, None)

    async def router(scope, receive, send):
        """Dispatch HTTP requests by path and refuse WebSocket connections."""
        if scope["type"] == "websocket":
            # There are no WebSocket endpoints. Sending HTTP response messages
            # on a WebSocket scope is an ASGI protocol error, so close instead.
            await send({"type": "websocket.close", "code": 1000})
            return

        # Tolerate a trailing slash. "/" would otherwise strip to "".
        path = scope["path"].rstrip("/") or "/"
        if path in ("/mcp", "/sse"):
            # "/sse" is only an alias for the Streamable HTTP endpoint. It does
            # not implement the legacy HTTP+SSE (endpoint-event) protocol.
            await handle_mcp(scope, receive, send)
        elif path == "/health":
            await send_json(scope, receive, send, {"status": "ok"}, 200)
        else:
            await send_json(scope, receive, send, {"error": "Not found"}, 404)

    async def handle_lifespan(scope, receive, send):
        """Keep the session manager running between lifespan startup and shutdown."""
        started = False
        try:
            async with self.session_manager.run():
                await receive()  # ASGI: the first lifespan message is "lifespan.startup"
                await send({"type": "lifespan.startup.complete"})
                started = True
                while (await receive())["type"] != "lifespan.shutdown":
                    pass
        except Exception as exc:
            # Report the failure explicitly. Under uvicorn's default
            # lifespan="auto", an unreported startup error is treated as
            # "lifespan unsupported", and the server would run without the
            # session manager's task group.
            failure = "lifespan.shutdown.failed" if started else "lifespan.startup.failed"
            await send({"type": failure, "message": str(exc)})
            raise
        # Signal completion only after run() has exited, i.e. after the
        # session manager has torn down its sessions.
        await send({"type": "lifespan.shutdown.complete"})

    async def app_with_lifespan(scope, receive, send):
        """Route lifespan events to the lifespan handler and everything else to the router."""
        if scope["type"] == "lifespan":
            await handle_lifespan(scope, receive, send)
        else:
            await router(scope, receive, send)

    # CORS for browser clients. Besides the session header, MCP clients send
    # mcp-protocol-version on follow-up requests and last-event-id when
    # resuming a stream. Both must be allowed or the preflight fails.
    return CORSMiddleware(
        app_with_lifespan,
        allow_origins=["*"],
        allow_methods=["GET", "POST", "DELETE"],
        allow_headers=[
            "Authorization",
            "Content-Type",
            session_header_name,
            "mcp-protocol-version",
            "last-event-id",
        ],
        expose_headers=[session_header_name],
    )
490
754
 
491
755
 
492
756
  def run_http_server(host: str = "127.0.0.1", port: int = 8080):
493
- """Run the HTTP MCP server."""
757
+ """Run the HTTP MCP server with Streamable HTTP transport."""
494
758
  import uvicorn
495
759
 
496
760
  print("Warming up embedding model...", file=sys.stderr)
@@ -501,8 +765,9 @@ def run_http_server(host: str = "127.0.0.1", port: int = 8080):
501
765
  app = http_server.create_app()
502
766
 
503
767
  print(f"Starting HTTP MCP server on http://{host}:{port}", file=sys.stderr)
504
- print(" SSE endpoint: /sse", file=sys.stderr)
505
- print(" Messages endpoint: /messages/", file=sys.stderr)
768
+ print(" MCP endpoint: /mcp (GET, POST, DELETE)", file=sys.stderr)
769
+ print(" MCP endpoint: /sse (alias for /mcp)", file=sys.stderr)
506
770
  print(" Health endpoint: /health", file=sys.stderr)
771
+ print(" Transport: Streamable HTTP", file=sys.stderr)
507
772
 
508
773
  uvicorn.run(app, host=host, port=port, log_level="info")