okb 1.1.0a0__py3-none-any.whl → 1.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
okb/config.py CHANGED
@@ -53,7 +53,7 @@ class DatabaseConfig:
53
53
 
54
54
  name: str
55
55
  url: str
56
- managed: bool = True # Whether lkb manages this (Docker) or external
56
+ managed: bool = True # Whether okb manages this (Docker) or external
57
57
  default: bool = False
58
58
  description: str | None = None # Human-readable description for LLM context
59
59
  topics: list[str] | None = None # Topic keywords to help LLM route queries
@@ -259,6 +259,7 @@ DEFAULTS = {
259
259
  "yarn.lock",
260
260
  "uv.lock",
261
261
  "Cargo.lock",
262
+ "poetry.lock",
262
263
  "*.pyc",
263
264
  "*.pyo",
264
265
  "*.tmp",
@@ -281,7 +282,7 @@ DEFAULTS = {
281
282
  },
282
283
  "llm": {
283
284
  # LLM provider configuration
284
- # provider: None = disabled, "claude" = Anthropic API
285
+ # provider: None = disabled, "claude" = Anthropic API, "modal" = Modal GPU
285
286
  "provider": None,
286
287
  "model": "claude-haiku-4-5-20251001",
287
288
  "timeout": 30,
@@ -289,6 +290,38 @@ DEFAULTS = {
289
290
  # Bedrock settings (when use_bedrock is True)
290
291
  "use_bedrock": False,
291
292
  "aws_region": "us-west-2",
293
+ # Modal settings (when provider is "modal")
294
+ "modal_gpu": "L4", # GPU type: T4, L4, A10G, A100, etc.
295
+ },
296
+ "enrichment": {
297
+ # LLM-based document enrichment
298
+ "enabled": True,
299
+ "version": 1, # Increment to force re-enrichment
300
+ # What to extract
301
+ "extract_todos": True,
302
+ "extract_entities": True,
303
+ # Auto-create behavior
304
+ "auto_create_todos": True, # TODOs created immediately
305
+ "auto_create_entities": False, # Entities go to pending_entities table
306
+ # Confidence thresholds
307
+ "min_confidence_todo": 0.7,
308
+ "min_confidence_entity": 0.8,
309
+ # Auto-enrich during ingest (per source type)
310
+ "auto_enrich": {
311
+ "markdown": True,
312
+ "org": True,
313
+ "text": True,
314
+ "code": False, # Skip code files
315
+ "web": False, # Skip web pages
316
+ "todoist-task": False, # Already structured
317
+ },
318
+ # Entity consolidation settings
319
+ "consolidation": {
320
+ "cross_doc_min_mentions": 3, # Min docs for cross-doc detection
321
+ "embedding_similarity_threshold": 0.85, # For duplicate detection
322
+ "auto_merge_threshold": 0.95, # Auto-approve above this
323
+ "min_cluster_size": 3, # Min entities per cluster
324
+ },
292
325
  },
293
326
  }
294
327
 
@@ -349,12 +382,30 @@ class Config:
349
382
  llm_cache_responses: bool = True
350
383
  llm_use_bedrock: bool = False
351
384
  llm_aws_region: str = "us-west-2"
385
+ llm_modal_gpu: str = "L4"
386
+
387
+ # Enrichment settings (loaded from config in __post_init__)
388
+ enrichment_enabled: bool = True
389
+ enrichment_version: int = 1
390
+ enrichment_extract_todos: bool = True
391
+ enrichment_extract_entities: bool = True
392
+ enrichment_auto_create_todos: bool = True
393
+ enrichment_auto_create_entities: bool = False
394
+ enrichment_min_confidence_todo: float = 0.7
395
+ enrichment_min_confidence_entity: float = 0.8
396
+ enrichment_auto_enrich: dict[str, bool] = field(default_factory=dict)
397
+
398
+ # Consolidation settings (loaded from config in __post_init__)
399
+ consolidation_cross_doc_min_mentions: int = 3
400
+ consolidation_embedding_similarity_threshold: float = 0.85
401
+ consolidation_auto_merge_threshold: float = 0.95
402
+ consolidation_min_cluster_size: int = 3
352
403
 
353
404
  def __post_init__(self):
354
405
  """Load configuration from file and environment."""
355
406
  file_config = load_config_file()
356
407
 
357
- # Load and merge local config overlay (.lkbconf.yaml)
408
+ # Load and merge local config overlay (.okbconf.yaml)
358
409
  local_path = find_local_config()
359
410
  local_default_db: str | None = None
360
411
  if local_path:
@@ -417,7 +468,7 @@ class Config:
417
468
  else:
418
469
  # Legacy: single database_url (env > file > default)
419
470
  legacy_url = os.environ.get(
420
- "KB_DATABASE_URL",
471
+ "OKB_DATABASE_URL",
421
472
  file_config.get("database_url", DEFAULTS["databases"]["default"]["url"]),
422
473
  )
423
474
  self.databases["default"] = DatabaseConfig(
@@ -535,6 +586,55 @@ class Config:
535
586
  )
536
587
  self.llm_use_bedrock = llm_cfg.get("use_bedrock", DEFAULTS["llm"]["use_bedrock"])
537
588
  self.llm_aws_region = llm_cfg.get("aws_region", DEFAULTS["llm"]["aws_region"])
589
+ self.llm_modal_gpu = os.environ.get(
590
+ "OKB_MODAL_GPU",
591
+ llm_cfg.get("modal_gpu", DEFAULTS["llm"]["modal_gpu"]),
592
+ )
593
+
594
+ # Enrichment settings
595
+ enrich_cfg = file_config.get("enrichment", {})
596
+ self.enrichment_enabled = enrich_cfg.get("enabled", DEFAULTS["enrichment"]["enabled"])
597
+ self.enrichment_version = enrich_cfg.get("version", DEFAULTS["enrichment"]["version"])
598
+ self.enrichment_extract_todos = enrich_cfg.get(
599
+ "extract_todos", DEFAULTS["enrichment"]["extract_todos"]
600
+ )
601
+ self.enrichment_extract_entities = enrich_cfg.get(
602
+ "extract_entities", DEFAULTS["enrichment"]["extract_entities"]
603
+ )
604
+ self.enrichment_auto_create_todos = enrich_cfg.get(
605
+ "auto_create_todos", DEFAULTS["enrichment"]["auto_create_todos"]
606
+ )
607
+ self.enrichment_auto_create_entities = enrich_cfg.get(
608
+ "auto_create_entities", DEFAULTS["enrichment"]["auto_create_entities"]
609
+ )
610
+ self.enrichment_min_confidence_todo = enrich_cfg.get(
611
+ "min_confidence_todo", DEFAULTS["enrichment"]["min_confidence_todo"]
612
+ )
613
+ self.enrichment_min_confidence_entity = enrich_cfg.get(
614
+ "min_confidence_entity", DEFAULTS["enrichment"]["min_confidence_entity"]
615
+ )
616
+ self.enrichment_auto_enrich = enrich_cfg.get(
617
+ "auto_enrich", DEFAULTS["enrichment"]["auto_enrich"]
618
+ )
619
+
620
+ # Consolidation settings
621
+ consolidation_cfg = enrich_cfg.get("consolidation", {})
622
+ self.consolidation_cross_doc_min_mentions = consolidation_cfg.get(
623
+ "cross_doc_min_mentions",
624
+ DEFAULTS["enrichment"]["consolidation"]["cross_doc_min_mentions"],
625
+ )
626
+ self.consolidation_embedding_similarity_threshold = consolidation_cfg.get(
627
+ "embedding_similarity_threshold",
628
+ DEFAULTS["enrichment"]["consolidation"]["embedding_similarity_threshold"],
629
+ )
630
+ self.consolidation_auto_merge_threshold = consolidation_cfg.get(
631
+ "auto_merge_threshold",
632
+ DEFAULTS["enrichment"]["consolidation"]["auto_merge_threshold"],
633
+ )
634
+ self.consolidation_min_cluster_size = consolidation_cfg.get(
635
+ "min_cluster_size",
636
+ DEFAULTS["enrichment"]["consolidation"]["min_cluster_size"],
637
+ )
538
638
 
539
639
  def get_database(self, name: str | None = None) -> DatabaseConfig:
540
640
  """Get database config by name, or default if None."""
@@ -648,6 +748,24 @@ class Config:
648
748
  "cache_responses": self.llm_cache_responses,
649
749
  "use_bedrock": self.llm_use_bedrock,
650
750
  "aws_region": self.llm_aws_region,
751
+ "modal_gpu": self.llm_modal_gpu,
752
+ },
753
+ "enrichment": {
754
+ "enabled": self.enrichment_enabled,
755
+ "version": self.enrichment_version,
756
+ "extract_todos": self.enrichment_extract_todos,
757
+ "extract_entities": self.enrichment_extract_entities,
758
+ "auto_create_todos": self.enrichment_auto_create_todos,
759
+ "auto_create_entities": self.enrichment_auto_create_entities,
760
+ "min_confidence_todo": self.enrichment_min_confidence_todo,
761
+ "min_confidence_entity": self.enrichment_min_confidence_entity,
762
+ "auto_enrich": self.enrichment_auto_enrich,
763
+ "consolidation": {
764
+ "cross_doc_min_mentions": self.consolidation_cross_doc_min_mentions,
765
+ "embedding_similarity_threshold": self.consolidation_embedding_similarity_threshold,
766
+ "auto_merge_threshold": self.consolidation_auto_merge_threshold,
767
+ "min_cluster_size": self.consolidation_min_cluster_size,
768
+ },
651
769
  },
652
770
  }
653
771
 
okb/http_server.py CHANGED
@@ -1,9 +1,15 @@
1
1
  """HTTP transport server for MCP with token authentication.
2
2
 
3
- This module provides an HTTP server that serves the LKB MCP server with
4
- token-based authentication. Tokens can be passed via Authorization header
5
- or query parameter. A single HTTP server can serve multiple databases,
6
- with the token determining which database to use.
3
+ This module provides an HTTP server that serves the OKB MCP server with
4
+ token-based authentication using Streamable HTTP transport. Tokens can be
5
+ passed via Authorization header or query parameter. A single HTTP server
6
+ can serve multiple databases, with the token determining which database to use.
7
+
8
+ Transport: Streamable HTTP (RFC 9728 compliant)
9
+ - POST /mcp → send JSON-RPC messages, get SSE response
10
+ - GET /mcp → optional standalone SSE for server notifications
11
+ - DELETE /mcp → terminate session
12
+ - Session ID in Mcp-Session-Id header
7
13
  """
8
14
 
9
15
  from __future__ import annotations
@@ -12,12 +18,11 @@ import sys
12
18
  from typing import Any
13
19
 
14
20
  from mcp.server import Server
15
- from mcp.server.sse import SseServerTransport
21
+ from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
16
22
  from mcp.types import CallToolResult, TextContent, Tool
17
- from starlette.applications import Starlette
23
+ from starlette.middleware.cors import CORSMiddleware
18
24
  from starlette.requests import Request
19
- from starlette.responses import JSONResponse, Response
20
- from starlette.routing import Mount, Route
25
+ from starlette.responses import JSONResponse
21
26
 
22
27
  from .config import config
23
28
  from .local_embedder import warmup
@@ -37,9 +42,15 @@ READ_ONLY_TOOLS = frozenset(
37
42
  "get_document",
38
43
  "list_sources",
39
44
  "list_projects",
45
+ "list_documents_by_project",
40
46
  "recent_documents",
41
47
  "get_actionable_items",
42
48
  "get_database_info",
49
+ "list_sync_sources",
50
+ "list_pending_entities",
51
+ "list_pending_merges",
52
+ "get_topic_clusters",
53
+ "get_entity_relationships",
43
54
  }
44
55
  )
45
56
 
@@ -51,6 +62,15 @@ WRITE_TOOLS = frozenset(
51
62
  "add_todo",
52
63
  "trigger_sync",
53
64
  "trigger_rescan",
65
+ "enrich_document",
66
+ "approve_entity",
67
+ "reject_entity",
68
+ "analyze_knowledge_base",
69
+ "find_entity_duplicates",
70
+ "merge_entities",
71
+ "approve_merge",
72
+ "reject_merge",
73
+ "run_consolidation",
54
74
  }
55
75
  )
56
76
 
@@ -66,14 +86,14 @@ def extract_token(request: Request) -> str | None:
66
86
 
67
87
 
68
88
  class HTTPMCPServer:
69
- """HTTP server for MCP with token authentication."""
89
+ """HTTP server for MCP with token authentication using Streamable HTTP transport."""
70
90
 
71
91
  def __init__(self):
72
92
  self.knowledge_bases: dict[str, KnowledgeBase] = {}
73
93
  self.server = Server("knowledge-base")
74
- # Single shared transport instance for all connections
75
- self.transport = SseServerTransport("/messages/")
76
- # Map session_id (hex string) -> token_info
94
+ # Session manager handles all transport complexity
95
+ self.session_manager = StreamableHTTPSessionManager(app=self.server)
96
+ # Map mcp-session-id -> token_info
77
97
  self.session_tokens: dict[str, TokenInfo] = {}
78
98
  self._setup_handlers()
79
99
 
@@ -208,6 +228,24 @@ class HTTPMCPServer:
208
228
  content=[TextContent(type="text", text=f"## Projects\n\n{project_list}")]
209
229
  )
210
230
 
231
+ elif name == "list_documents_by_project":
232
+ project = arguments["project"]
233
+ limit = arguments.get("limit", 100)
234
+ docs = kb.list_documents_by_project(project, limit)
235
+ if not docs:
236
+ return CallToolResult(
237
+ content=[
238
+ TextContent(
239
+ type="text", text=f"No documents found for project '{project}'."
240
+ )
241
+ ]
242
+ )
243
+ output = [f"## Documents in '{project}' ({len(docs)} documents)\n"]
244
+ for d in docs:
245
+ output.append(f"- **{d['title'] or d['source_path']}** ({d['source_type']})")
246
+ output.append(f" - `{d['source_path']}`")
247
+ return CallToolResult(content=[TextContent(type="text", text="\n".join(output))])
248
+
211
249
  elif name == "recent_documents":
212
250
  from .mcp_server import format_relative_time, get_document_date
213
251
 
@@ -265,13 +303,13 @@ class HTTPMCPServer:
265
303
  deleted = kb.delete_knowledge(arguments["source_path"])
266
304
  if deleted:
267
305
  return CallToolResult(
268
- content=[TextContent(type="text", text="Knowledge entry deleted.")]
306
+ content=[TextContent(type="text", text="Document deleted.")]
269
307
  )
270
308
  return CallToolResult(
271
309
  content=[
272
310
  TextContent(
273
311
  type="text",
274
- text="Could not delete. Entry not found or not a Claude-saved entry.",
312
+ text="Could not delete. Document not found.",
275
313
  )
276
314
  ]
277
315
  )
@@ -394,6 +432,134 @@ class HTTPMCPServer:
394
432
  )
395
433
  return CallToolResult(content=[TextContent(type="text", text=result)])
396
434
 
435
+ elif name == "list_sync_sources":
436
+ from .mcp_server import _list_sync_sources
437
+
438
+ token_info = getattr(self.server, "_current_token_info", None)
439
+ db_name = token_info.database if token_info else config.get_database().name
440
+ result = _list_sync_sources(kb.db_url, db_name)
441
+ return CallToolResult(content=[TextContent(type="text", text=result)])
442
+
443
+ elif name == "enrich_document":
444
+ from .mcp_server import _enrich_document
445
+
446
+ result = _enrich_document(
447
+ kb.db_url,
448
+ source_path=arguments["source_path"],
449
+ extract_todos=arguments.get("extract_todos", True),
450
+ extract_entities=arguments.get("extract_entities", True),
451
+ auto_create_entities=arguments.get("auto_create_entities", False),
452
+ )
453
+ return CallToolResult(content=[TextContent(type="text", text=result)])
454
+
455
+ elif name == "list_pending_entities":
456
+ from .mcp_server import _list_pending_entities
457
+
458
+ result = _list_pending_entities(
459
+ kb.db_url,
460
+ entity_type=arguments.get("entity_type"),
461
+ limit=arguments.get("limit", 20),
462
+ )
463
+ return CallToolResult(content=[TextContent(type="text", text=result)])
464
+
465
+ elif name == "approve_entity":
466
+ from .mcp_server import _approve_entity
467
+
468
+ result = _approve_entity(kb.db_url, arguments["pending_id"])
469
+ return CallToolResult(content=[TextContent(type="text", text=result)])
470
+
471
+ elif name == "reject_entity":
472
+ from .mcp_server import _reject_entity
473
+
474
+ result = _reject_entity(kb.db_url, arguments["pending_id"])
475
+ return CallToolResult(content=[TextContent(type="text", text=result)])
476
+
477
+ elif name == "analyze_knowledge_base":
478
+ from .mcp_server import _analyze_knowledge_base
479
+
480
+ result = _analyze_knowledge_base(
481
+ kb.db_url,
482
+ project=arguments.get("project"),
483
+ sample_size=arguments.get("sample_size", 15),
484
+ auto_update=arguments.get("auto_update", True),
485
+ )
486
+ return CallToolResult(content=[TextContent(type="text", text=result)])
487
+
488
+ # Entity consolidation tools
489
+ elif name == "find_entity_duplicates":
490
+ from .mcp_server import _find_entity_duplicates
491
+
492
+ result = _find_entity_duplicates(
493
+ kb.db_url,
494
+ similarity_threshold=arguments.get("similarity_threshold", 0.85),
495
+ limit=arguments.get("limit", 50),
496
+ )
497
+ return CallToolResult(content=[TextContent(type="text", text=result)])
498
+
499
+ elif name == "merge_entities":
500
+ from .mcp_server import _merge_entities
501
+
502
+ result = _merge_entities(
503
+ kb.db_url,
504
+ canonical_path=arguments["canonical_path"],
505
+ duplicate_path=arguments["duplicate_path"],
506
+ )
507
+ return CallToolResult(content=[TextContent(type="text", text=result)])
508
+
509
+ elif name == "list_pending_merges":
510
+ from .mcp_server import _list_pending_merges
511
+
512
+ result = _list_pending_merges(
513
+ kb.db_url,
514
+ limit=arguments.get("limit", 50),
515
+ )
516
+ return CallToolResult(content=[TextContent(type="text", text=result)])
517
+
518
+ elif name == "approve_merge":
519
+ from .mcp_server import _approve_merge
520
+
521
+ result = _approve_merge(kb.db_url, arguments["merge_id"])
522
+ return CallToolResult(content=[TextContent(type="text", text=result)])
523
+
524
+ elif name == "reject_merge":
525
+ from .mcp_server import _reject_merge
526
+
527
+ result = _reject_merge(kb.db_url, arguments["merge_id"])
528
+ return CallToolResult(content=[TextContent(type="text", text=result)])
529
+
530
+ elif name == "get_topic_clusters":
531
+ from .mcp_server import _get_topic_clusters
532
+
533
+ result = _get_topic_clusters(
534
+ kb.db_url,
535
+ limit=arguments.get("limit", 20),
536
+ )
537
+ return CallToolResult(content=[TextContent(type="text", text=result)])
538
+
539
+ elif name == "get_entity_relationships":
540
+ from .mcp_server import _get_entity_relationships
541
+
542
+ result = _get_entity_relationships(
543
+ kb.db_url,
544
+ entity_name=arguments.get("entity_name"),
545
+ relationship_type=arguments.get("relationship_type"),
546
+ limit=arguments.get("limit", 50),
547
+ )
548
+ return CallToolResult(content=[TextContent(type="text", text=result)])
549
+
550
+ elif name == "run_consolidation":
551
+ from .mcp_server import _run_consolidation
552
+
553
+ result = _run_consolidation(
554
+ kb.db_url,
555
+ detect_duplicates=arguments.get("detect_duplicates", True),
556
+ detect_cross_doc=arguments.get("detect_cross_doc", True),
557
+ build_clusters=arguments.get("build_clusters", True),
558
+ extract_relationships=arguments.get("extract_relationships", True),
559
+ dry_run=arguments.get("dry_run", False),
560
+ )
561
+ return CallToolResult(content=[TextContent(type="text", text=result)])
562
+
397
563
  else:
398
564
  return CallToolResult(
399
565
  content=[TextContent(type="text", text=f"Unknown tool: {name}")]
@@ -402,95 +568,131 @@ class HTTPMCPServer:
402
568
  except Exception as e:
403
569
  return CallToolResult(content=[TextContent(type="text", text=f"Error: {e!s}")])
404
570
 
405
- def create_app(self) -> Starlette:
571
+ def create_app(self):
406
572
  """Create the Starlette application."""
407
573
  verifier = OKBTokenVerifier(self._get_db_url)
574
+ session_header_name = "mcp-session-id"
408
575
 
409
- async def handle_sse(request: Request) -> Response:
410
- """Handle SSE connections for MCP."""
411
- # Verify token
412
- token = extract_token(request)
413
- if not token:
414
- return JSONResponse(
415
- {"error": "Missing token. Use Authorization header or ?token= parameter"},
416
- status_code=401,
417
- )
576
+ def create_mcp_handler():
577
+ """Create an ASGI handler for MCP with auth."""
418
578
 
419
- token_info = verifier.verify(token)
420
- if not token_info:
421
- return JSONResponse(
422
- {"error": "Invalid or expired token"},
423
- status_code=401,
424
- )
579
+ async def handle_mcp(scope, receive, send):
580
+ """Handle all MCP requests (GET, POST, DELETE) with auth."""
581
+ request = Request(scope, receive)
425
582
 
426
- # Track existing sessions before connecting
427
- existing_sessions = set(self.transport._read_stream_writers.keys())
428
-
429
- async with self.transport.connect_sse(
430
- request.scope, request.receive, request._send
431
- ) as (read_stream, write_stream):
432
- # Find the new session ID by comparing before/after
433
- current_sessions = set(self.transport._read_stream_writers.keys())
434
- new_sessions = current_sessions - existing_sessions
435
- if not new_sessions:
436
- return JSONResponse(
437
- {"error": "Failed to establish session"},
438
- status_code=500,
583
+ # Extract and verify token
584
+ token = extract_token(request)
585
+ if not token:
586
+ response = JSONResponse(
587
+ {"error": "Missing token. Use Authorization header or ?token= param"},
588
+ status_code=401,
439
589
  )
440
- session_id = new_sessions.pop()
441
- session_id_hex = session_id.hex
442
-
443
- # Store token mapping for this session
444
- self.session_tokens[session_id_hex] = token_info
445
- self.server._current_token_info = token_info
446
-
447
- try:
448
- await self.server.run(
449
- read_stream, write_stream, self.server.create_initialization_options()
590
+ await response(scope, receive, send)
591
+ return
592
+
593
+ token_info = verifier.verify(token)
594
+ if not token_info:
595
+ response = JSONResponse(
596
+ {"error": "Invalid or expired token"},
597
+ status_code=401,
450
598
  )
451
- finally:
452
- # Clean up session on disconnect
453
- self.session_tokens.pop(session_id_hex, None)
454
-
455
- return Response()
456
-
457
- async def handle_messages(scope, receive, send):
458
- """Handle POST messages for MCP (raw ASGI handler)."""
459
- request = Request(scope, receive)
460
-
461
- # Look up session from query params
462
- session_id = request.query_params.get("session_id")
463
- if not session_id:
464
- response = JSONResponse({"error": "Missing session_id"}, status_code=400)
465
- await response(scope, receive, send)
466
- return
467
-
468
- token_info = self.session_tokens.get(session_id)
469
- if not token_info:
470
- response = JSONResponse({"error": "Invalid or expired session"}, status_code=401)
471
- await response(scope, receive, send)
472
- return
599
+ await response(scope, receive, send)
600
+ return
601
+
602
+ # Check if this is an existing session
603
+ session_id = request.headers.get(session_header_name)
604
+ if session_id:
605
+ # Verify token matches existing session (compare by hash and db, not object)
606
+ existing_token = self.session_tokens.get(session_id)
607
+ if existing_token:
608
+ # Token must match the one used to create the session
609
+ if (
610
+ existing_token.token_hash != token_info.token_hash
611
+ or existing_token.database != token_info.database
612
+ ):
613
+ response = JSONResponse(
614
+ {"error": "Token mismatch for session"},
615
+ status_code=401,
616
+ )
617
+ await response(scope, receive, send)
618
+ return
473
619
 
474
- # Set current token info for tool calls
475
- self.server._current_token_info = token_info
620
+ # Set current token info for tool calls
621
+ self.server._current_token_info = token_info
476
622
 
477
- await self.transport.handle_post_message(scope, receive, send)
623
+ # Wrap send to capture the session ID from response headers
624
+ captured_session_id = None
478
625
 
479
- async def health(request: Request) -> JSONResponse:
480
- """Health check endpoint."""
481
- return JSONResponse({"status": "ok"})
626
+ async def send_wrapper(message):
627
+ nonlocal captured_session_id
628
+ if message["type"] == "http.response.start":
629
+ headers = message.get("headers", [])
630
+ for name, value in headers:
631
+ header_name = (
632
+ name.lower() if isinstance(name, bytes) else name.lower().encode()
633
+ )
634
+ if header_name == session_header_name.encode():
635
+ captured_session_id = (
636
+ value.decode() if isinstance(value, bytes) else value
637
+ )
638
+ # Store immediately since SSE keeps connection open
639
+ if captured_session_id not in self.session_tokens:
640
+ self.session_tokens[captured_session_id] = token_info
641
+ break
642
+ await send(message)
643
+
644
+ # Delegate to session manager
645
+ await self.session_manager.handle_request(scope, receive, send_wrapper)
646
+
647
+ return handle_mcp
648
+
649
+ # Create the MCP handler ASGI app
650
+ mcp_handler = create_mcp_handler()
651
+
652
+ # Custom ASGI app that routes /mcp and /sse to MCP handler
653
+ async def router(scope, receive, send):
654
+ if scope["type"] == "http":
655
+ path = scope["path"].rstrip("/") # Handle trailing slash
656
+ if path in ("/mcp", "/sse"):
657
+ await mcp_handler(scope, receive, send)
658
+ return
659
+ elif path == "/health" or scope["path"] == "/health":
660
+ response = JSONResponse({"status": "ok"})
661
+ await response(scope, receive, send)
662
+ return
663
+ # 404 for unknown paths
664
+ response = JSONResponse({"error": "Not found"}, status_code=404)
665
+ await response(scope, receive, send)
666
+
667
+ # Wrap with lifespan handling
668
+ async def app_with_lifespan(scope, receive, send):
669
+ if scope["type"] == "lifespan":
670
+ async with self.session_manager.run():
671
+ # Handle lifespan protocol
672
+ while True:
673
+ message = await receive()
674
+ if message["type"] == "lifespan.startup":
675
+ await send({"type": "lifespan.startup.complete"})
676
+ elif message["type"] == "lifespan.shutdown":
677
+ await send({"type": "lifespan.shutdown.complete"})
678
+ return
679
+ else:
680
+ await router(scope, receive, send)
482
681
 
483
- routes = [
484
- Route("/health", health, methods=["GET"]),
485
- Route("/sse", handle_sse, methods=["GET"]),
486
- Mount("/messages", app=handle_messages),
487
- ]
682
+ # Add CORS for browser clients - wrap the raw ASGI app
683
+ app = CORSMiddleware(
684
+ app_with_lifespan,
685
+ allow_origins=["*"],
686
+ allow_methods=["GET", "POST", "DELETE"],
687
+ allow_headers=["Authorization", "Content-Type", session_header_name],
688
+ expose_headers=[session_header_name],
689
+ )
488
690
 
489
- return Starlette(routes=routes)
691
+ return app
490
692
 
491
693
 
492
694
  def run_http_server(host: str = "127.0.0.1", port: int = 8080):
493
- """Run the HTTP MCP server."""
695
+ """Run the HTTP MCP server with Streamable HTTP transport."""
494
696
  import uvicorn
495
697
 
496
698
  print("Warming up embedding model...", file=sys.stderr)
@@ -501,8 +703,9 @@ def run_http_server(host: str = "127.0.0.1", port: int = 8080):
501
703
  app = http_server.create_app()
502
704
 
503
705
  print(f"Starting HTTP MCP server on http://{host}:{port}", file=sys.stderr)
504
- print(" SSE endpoint: /sse", file=sys.stderr)
505
- print(" Messages endpoint: /messages/", file=sys.stderr)
706
+ print(" MCP endpoint: /mcp (GET, POST, DELETE)", file=sys.stderr)
707
+ print(" MCP endpoint: /sse (alias for /mcp)", file=sys.stderr)
506
708
  print(" Health endpoint: /health", file=sys.stderr)
709
+ print(" Transport: Streamable HTTP", file=sys.stderr)
507
710
 
508
711
  uvicorn.run(app, host=host, port=port, log_level="info")