kg-mcp 0.1.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kg_mcp/__init__.py +5 -0
- kg_mcp/__main__.py +8 -0
- kg_mcp/cli/__init__.py +3 -0
- kg_mcp/cli/setup.py +1100 -0
- kg_mcp/cli/status.py +344 -0
- kg_mcp/codegraph/__init__.py +3 -0
- kg_mcp/codegraph/indexer.py +296 -0
- kg_mcp/codegraph/model.py +170 -0
- kg_mcp/config.py +83 -0
- kg_mcp/kg/__init__.py +3 -0
- kg_mcp/kg/apply_schema.py +93 -0
- kg_mcp/kg/ingest.py +253 -0
- kg_mcp/kg/neo4j.py +155 -0
- kg_mcp/kg/repo.py +756 -0
- kg_mcp/kg/retrieval.py +225 -0
- kg_mcp/kg/schema.cypher +176 -0
- kg_mcp/llm/__init__.py +4 -0
- kg_mcp/llm/client.py +291 -0
- kg_mcp/llm/prompts/__init__.py +8 -0
- kg_mcp/llm/prompts/extractor.py +84 -0
- kg_mcp/llm/prompts/linker.py +117 -0
- kg_mcp/llm/schemas.py +248 -0
- kg_mcp/main.py +195 -0
- kg_mcp/mcp/__init__.py +3 -0
- kg_mcp/mcp/change_schemas.py +140 -0
- kg_mcp/mcp/prompts.py +223 -0
- kg_mcp/mcp/resources.py +218 -0
- kg_mcp/mcp/tools.py +537 -0
- kg_mcp/security/__init__.py +3 -0
- kg_mcp/security/auth.py +121 -0
- kg_mcp/security/origin.py +112 -0
- kg_mcp/utils.py +100 -0
- kg_mcp-0.1.8.dist-info/METADATA +86 -0
- kg_mcp-0.1.8.dist-info/RECORD +36 -0
- kg_mcp-0.1.8.dist-info/WHEEL +4 -0
- kg_mcp-0.1.8.dist-info/entry_points.txt +4 -0
|
@@ -0,0 +1,170 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Data models for code graph entities.
|
|
3
|
+
Represents files, symbols, and their relationships.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from dataclasses import dataclass, field
|
|
7
|
+
from datetime import datetime
|
|
8
|
+
from enum import Enum
|
|
9
|
+
from typing import List, Optional
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class SymbolKind(str, Enum):
    """Closed set of categories a code symbol can belong to.

    The ``str`` mix-in means every member is also a plain string equal to
    its value, so members can be compared with string literals directly.
    """

    # Containers
    FILE = "file"
    MODULE = "module"

    # Type-like definitions and callables
    CLASS = "class"
    FUNCTION = "function"
    METHOD = "method"

    # Value-like members
    PROPERTY = "property"
    VARIABLE = "variable"
    CONSTANT = "constant"

    # Additional type-level constructs
    INTERFACE = "interface"
    ENUM = "enum"
    TYPE_ALIAS = "type_alias"
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class ReferenceKind(str, Enum):
    """Closed set of ways one symbol can refer to another.

    Members are also plain strings (``str`` mix-in) equal to their value.
    """

    CALL = "call"
    IMPORT = "import"

    # Type-hierarchy relationships
    INHERIT = "inherit"
    IMPLEMENT = "implement"

    USE = "use"
    OVERRIDE = "override"
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
@dataclass
class SourceLocation:
    """A span inside a source file.

    Only the file path and the starting line are required. When no end
    line is supplied, the span is treated as ending on its starting line.
    """

    file_path: str
    start_line: int
    start_column: int = 0
    end_line: Optional[int] = None
    end_column: Optional[int] = None

    def __post_init__(self):
        """Default ``end_line`` to ``start_line`` when it was not given."""
        if self.end_line is not None:
            # An explicit end line (including 0) is kept untouched.
            return
        self.end_line = self.start_line
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
@dataclass
class Symbol:
    """A single code symbol such as a function, class or variable."""

    # Identity
    fqn: str  # Fully qualified name
    name: str  # Short name
    kind: SymbolKind
    location: SourceLocation

    # Optional descriptive details
    signature: Optional[str] = None
    docstring: Optional[str] = None
    # Parent symbol (e.g., the class a method belongs to)
    parent_fqn: Optional[str] = None
    # public, private, static, async, etc.; each instance gets its own list
    modifiers: List[str] = field(default_factory=list)

    @property
    def file_path(self) -> str:
        """Path of the file this symbol lives in, taken from its location."""
        where = self.location
        return where.file_path
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
@dataclass
class SymbolReference:
    """A directed reference from one symbol to another."""

    # Symbol making the reference
    source_fqn: str
    # Symbol being referenced
    target_fqn: str
    kind: ReferenceKind
    location: SourceLocation
    # Line of code with the reference, when available
    context: Optional[str] = None
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
@dataclass
class FileInfo:
    """Metadata describing one source file and the symbols found in it."""

    path: str
    language: str
    content_hash: str
    size_bytes: int
    line_count: int
    last_modified: datetime
    git_commit: Optional[str] = None
    # Symbols collected for this file; each instance gets its own list.
    symbols: List[Symbol] = field(default_factory=list)

    def add_symbol(self, symbol: Symbol) -> None:
        """Append ``symbol`` to the end of this file's symbol list."""
        collected = self.symbols
        collected.append(symbol)
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
@dataclass
class CodeGraphSnapshot:
    """A snapshot of a project's code graph: its files and references."""

    project_id: str
    timestamp: datetime
    files: List[FileInfo]
    references: List[SymbolReference]

    @property
    def total_symbols(self) -> int:
        """Number of symbols summed over every file in the snapshot."""
        count = 0
        for file_info in self.files:
            count += len(file_info.symbols)
        return count

    @property
    def total_files(self) -> int:
        """Number of files in the snapshot."""
        file_list = self.files
        return len(file_list)

    @property
    def total_references(self) -> int:
        """Number of symbol references in the snapshot."""
        reference_list = self.references
        return len(reference_list)
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
# Maps a file extension (leading dot included) to a language identifier.
LANGUAGE_EXTENSIONS = {
    # Python
    ".py": "python",
    ".pyi": "python",
    # JavaScript / TypeScript
    ".js": "javascript",
    ".jsx": "javascript",
    ".ts": "typescript",
    ".tsx": "typescript",
    # JVM
    ".java": "java",
    ".kt": "kotlin",
    # Systems languages
    ".go": "go",
    ".rs": "rust",
    ".c": "c",
    ".cpp": "cpp",
    ".cc": "cpp",
    ".h": "c",
    ".hpp": "cpp",
    ".cs": "csharp",
    # Scripting and other general-purpose languages
    ".rb": "ruby",
    ".php": "php",
    ".swift": "swift",
    ".scala": "scala",
    ".r": "r",
    # detect_language() lower-cases the suffix before the lookup, so this
    # upper-case key is only reachable by direct dictionary access.
    ".R": "r",
    ".sql": "sql",
    # Shells
    ".sh": "bash",
    ".bash": "bash",
    ".zsh": "zsh",
    # Data and configuration formats
    ".yaml": "yaml",
    ".yml": "yaml",
    ".json": "json",
    ".toml": "toml",
    ".xml": "xml",
    # Web
    ".html": "html",
    ".css": "css",
    ".scss": "scss",
    ".less": "less",
    # Documentation
    ".md": "markdown",
    ".rst": "rst",
}


def detect_language(file_path: str) -> str:
    """Return the language identifier for ``file_path`` based on its suffix.

    The final suffix of the path is lower-cased and looked up in
    ``LANGUAGE_EXTENSIONS``; ``"unknown"`` is returned when there is no match
    (including paths without a suffix).
    """
    from pathlib import Path

    suffix = Path(file_path).suffix
    normalized = suffix.lower()
    if normalized in LANGUAGE_EXTENSIONS:
        return LANGUAGE_EXTENSIONS[normalized]
    return "unknown"
|
kg_mcp/config.py
ADDED
|
@@ -0,0 +1,83 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Configuration management for MCP-KG-Memory server.
|
|
3
|
+
Uses pydantic-settings for environment variable handling.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from functools import lru_cache
|
|
7
|
+
from typing import List
|
|
8
|
+
|
|
9
|
+
from pydantic import Field
|
|
10
|
+
from pydantic_settings import BaseSettings, SettingsConfigDict
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class Settings(BaseSettings):
    """Application settings, populated by pydantic-settings.

    Values come from the environment and from a ``.env`` file (see
    ``model_config``); every field falls back to the default declared here.
    """

    model_config = SettingsConfigDict(
        env_file=".env",
        env_file_encoding="utf-8",
        case_sensitive=False,
        extra="ignore",
    )

    # ---- Neo4j ---------------------------------------------------------
    neo4j_uri: str = Field(
        default="bolt://localhost:7687",
        description="Neo4j Bolt URI",
    )
    neo4j_user: str = Field(default="neo4j", description="Neo4j username")
    # NOTE(review): hard-coded default credential; presumably meant to be
    # overridden through the environment - confirm before deployment.
    neo4j_password: str = Field(
        default="password123",
        description="Neo4j password",
    )
    neo4j_configured: str = Field(
        default="1",
        description="Is Neo4j configured (1/0)",
    )

    # ---- LLM (supports both direct Gemini and LiteLLM Gateway) ---------
    # Mode: 'gemini_direct', 'litellm', or 'both'
    llm_mode: str = Field(default="litellm", description="Operation mode")
    llm_primary: str = Field(
        default="litellm",
        description="Primary provider if both configured",
    )
    llm_provider: str = Field(
        default="litellm",
        description="Active provider tag",
    )

    # Gemini Direct
    gemini_api_key: str = Field(
        default="",
        description="Gemini API Key (direct)",
    )
    gemini_base_url: str = Field(
        default="https://generativelanguage.googleapis.com/",
        description="Gemini API Base URL",
    )
    gemini_model: str = Field(default="", description="Gemini Direct Model ID")

    # LiteLLM
    litellm_api_key: str = Field(
        default="",
        description="LiteLLM Gateway API Key",
    )
    litellm_base_url: str = Field(
        default="",
        description="LiteLLM Gateway Base URL",
    )
    litellm_model: str = Field(default="", description="LiteLLM Model ID")

    # General / Fallback Model
    llm_model: str = Field(
        default="gemini/gemini-1.5-flash",
        description="Default model identifier (legacy fallback)",
    )

    # Task-specific Routing (optional)
    kg_model_default: str = Field(
        default="",
        description="Default model for general tasks",
    )
    kg_model_fast: str = Field(
        default="",
        description="Fast model for high-throughput",
    )
    kg_model_reason: str = Field(
        default="",
        description="Reasoning model for complex tasks",
    )

    llm_temperature: float = Field(
        default=0.2,
        description="LLM temperature for extraction",
    )
    llm_max_tokens: int = Field(
        default=4096,
        description="Maximum tokens for LLM response",
    )

    # ---- MCP server ----------------------------------------------------
    mcp_host: str = Field(default="127.0.0.1", description="MCP server host")
    mcp_port: int = Field(default=8000, description="MCP server port")
    mcp_stateless: bool = Field(
        default=True,
        description="Run server in stateless mode",
    )

    # ---- Security ------------------------------------------------------
    kg_mcp_token: str = Field(
        default="",
        description="Bearer token for authentication",
    )
    kg_allowed_origins: str = Field(
        default="http://localhost:*,http://127.0.0.1:*",
        description="Comma-separated list of allowed origins",
    )

    # ---- Logging -------------------------------------------------------
    log_level: str = Field(default="INFO", description="Logging level")

    @property
    def allowed_origins_list(self) -> List[str]:
        """Split ``kg_allowed_origins`` on commas into a list of origins.

        Surrounding whitespace is stripped from every entry and entries
        that are empty after stripping are dropped.
        """
        trimmed = (piece.strip() for piece in self.kg_allowed_origins.split(","))
        return [piece for piece in trimmed if piece]
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
@lru_cache(maxsize=128)
def get_settings() -> Settings:
    """Return the shared ``Settings`` instance.

    The function takes no arguments, so the cache holds a single entry:
    ``Settings()`` is built on the first call and reused afterwards.
    """
    settings = Settings()
    return settings
|
kg_mcp/kg/__init__.py
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Apply Neo4j schema constraints and indexes.
|
|
3
|
+
Run this script to initialize the database schema.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
import asyncio
|
|
7
|
+
import logging
|
|
8
|
+
import sys
|
|
9
|
+
from pathlib import Path
|
|
10
|
+
|
|
11
|
+
from kg_mcp.kg.neo4j import get_neo4j_client, init_neo4j, close_neo4j
|
|
12
|
+
|
|
13
|
+
# Configure root logging at import time: INFO level, timestamped records.
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)

# Module-level logger named after this module.
logger = logging.getLogger(__name__)
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def _split_schema_statements(schema_content):
    """Split the text of a Cypher schema file into individual statements.

    Blank lines and lines starting with ``//`` are ignored. The remaining
    lines are accumulated until one ends with ``;``; the accumulated lines
    are then joined with single spaces and the trailing ``;`` is removed.

    Args:
        schema_content: Full text of the schema file.

    Returns:
        List of statement strings, in file order.
    """
    statements = []
    pending = []
    for raw_line in schema_content.split("\n"):
        line = raw_line.strip()
        # Skip comments and empty lines
        if not line or line.startswith("//"):
            continue
        pending.append(line)
        if line.endswith(";"):
            statements.append(" ".join(pending).rstrip(";"))
            pending = []

    # A final statement without a terminating ";" would otherwise be
    # dropped silently; keep it so it is executed like the others.
    if pending:
        statements.append(" ".join(pending))
    return statements


async def apply_schema() -> None:
    """Read ``schema.cypher`` next to this module and apply it to Neo4j.

    Each statement is executed on its own. A failure whose message contains
    "already exists" or "equivalent" is counted as a success (the constraint
    or index is already in place); any other failure is logged and counted
    as an error, and the remaining statements are still attempted.

    The Neo4j connection opened here is always closed again, including when
    the schema file is missing or an unexpected exception is raised.

    Raises:
        SystemExit: With exit code 1 when the schema file does not exist.
    """
    logger.info("Connecting to Neo4j...")
    await init_neo4j()

    try:
        client = get_neo4j_client()

        # Read schema file
        schema_path = Path(__file__).parent / "schema.cypher"
        if not schema_path.exists():
            logger.error(f"Schema file not found: {schema_path}")
            sys.exit(1)

        logger.info(f"Reading schema from {schema_path}")
        # Explicit encoding so the result does not depend on the locale.
        schema_content = schema_path.read_text(encoding="utf-8")

        full_statements = _split_schema_statements(schema_content)
        logger.info(f"Found {len(full_statements)} schema statements")

        # Execute each statement
        success_count = 0
        error_count = 0

        for i, stmt in enumerate(full_statements, 1):
            if not stmt.strip():
                continue

            try:
                logger.debug(f"Executing statement {i}: {stmt[:80]}...")
                await client.execute_query(stmt)
                success_count += 1
                logger.info(f"✓ Statement {i} applied successfully")
            except Exception as e:
                # Deliberately broad: one failing statement must not stop
                # the remaining ones from being applied.
                error_msg = str(e).lower()
                # Ignore "already exists" errors for constraints/indexes
                if "already exists" in error_msg or "equivalent" in error_msg:
                    logger.info(f"⊘ Statement {i} skipped (already exists)")
                    success_count += 1
                else:
                    logger.error(f"✗ Statement {i} failed: {e}")
                    error_count += 1

        logger.info("\nSchema application complete:")
        logger.info(f"  ✓ Success: {success_count}")
        logger.info(f"  ✗ Errors: {error_count}")
    finally:
        # Release the connection on every exit path.
        await close_neo4j()
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
def main() -> None:
    """Run ``apply_schema`` to completion on a fresh asyncio event loop."""
    schema_job = apply_schema()
    asyncio.run(schema_job)


# Allow running this module directly as a script.
if __name__ == "__main__":
    main()
|
kg_mcp/kg/ingest.py
ADDED
|
@@ -0,0 +1,253 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Ingest pipeline for processing user requests.
|
|
3
|
+
Orchestrates LLM extraction, linking, and Neo4j commit.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
import logging
|
|
7
|
+
from typing import Any, Dict, List, Optional
|
|
8
|
+
|
|
9
|
+
from kg_mcp.llm.client import get_llm_client
|
|
10
|
+
from kg_mcp.llm.schemas import ExtractionResult, LinkingResult
|
|
11
|
+
from kg_mcp.kg.repo import get_repository
|
|
12
|
+
|
|
13
|
+
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class IngestPipeline:
    """Pipeline for ingesting user interactions into the knowledge graph.

    A message is first stored as an interaction, then passed to the LLM
    client for entity extraction and linking, and finally the extracted
    entities are written through the repository.
    """

    # A merge suggestion is only honoured at or above this confidence;
    # below it the extracted entity is upserted as usual.
    MERGE_CONFIDENCE_THRESHOLD = 0.7

    # Each code artifact is linked to at most this many goals.
    MAX_GOALS_PER_ARTIFACT = 3

    def __init__(self):
        """Bind the LLM client and the repository used by the pipeline."""
        self.llm = get_llm_client()
        self.repo = get_repository()

    async def process_message(
        self,
        project_id: str,
        user_text: str,
        user_id: str = "default_user",
        files: Optional[List[str]] = None,
        diff: Optional[str] = None,
        symbols: Optional[List[str]] = None,
        tags: Optional[List[str]] = None,
    ) -> Dict[str, Any]:
        """
        Process a user message through the full pipeline.

        Args:
            project_id: Project ID to associate with this interaction
            user_text: The user's message text
            user_id: User ID for preferences
            files: Optional list of file paths involved
            diff: Optional code diff
            symbols: Optional list of code symbols
            tags: Optional tags for this interaction

        Returns:
            Dict with the keys ``interaction_id``, ``extracted`` and
            ``linking`` (model dumps of the LLM results),
            ``created_entities`` (entity type -> list of created IDs) and
            ``confidence`` (taken from the extraction result).
        """
        logger.info(f"Processing message for project {project_id}")

        # Step 0: Ensure project exists
        await self.repo.get_or_create_project(project_id)

        # Step 1: Create interaction record
        interaction = await self.repo.create_interaction(
            project_id=project_id,
            user_text=user_text,
            tags=tags,
        )
        interaction_id = interaction["id"]
        logger.info(f"Created interaction {interaction_id}")

        # Step 2: Extract entities using LLM
        extraction = await self.llm.extract_entities(
            user_text=user_text,
            files=files,
            diff=diff,
            symbols=symbols,
        )
        logger.info(
            f"Extracted: {len(extraction.goals)} goals, "
            f"{len(extraction.constraints)} constraints, "
            f"{len(extraction.preferences)} preferences, "
            f"{len(extraction.pain_points)} pain points, "
            f"{len(extraction.strategies)} strategies"
        )

        # Step 3: Get existing entities for linking
        existing_goals = await self.repo.get_all_goals(project_id)
        existing_preferences = await self.repo.get_preferences(user_id)
        recent_interactions = await self.repo.get_recent_interactions(project_id, limit=5)

        # Step 4: Link entities using LLM
        linking = await self.llm.link_entities(
            extraction=extraction,
            existing_goals=existing_goals,
            existing_preferences=existing_preferences,
            recent_interactions=recent_interactions,
        )
        logger.info(
            f"Linking: {len(linking.merge_suggestions)} merges, "
            f"{len(linking.relationships)} relationships"
        )

        # Step 5: Commit to Neo4j
        created_entities = await self._commit_to_graph(
            project_id=project_id,
            user_id=user_id,
            interaction_id=interaction_id,
            extraction=extraction,
            linking=linking,
        )

        return {
            "interaction_id": interaction_id,
            "extracted": extraction.model_dump(),
            "linking": linking.model_dump(),
            "created_entities": created_entities,
            "confidence": extraction.confidence,
        }

    async def _commit_to_graph(
        self,
        project_id: str,
        user_id: str,
        interaction_id: str,
        extraction: ExtractionResult,
        linking: LinkingResult,
    ) -> Dict[str, List[str]]:
        """
        Commit extracted entities to Neo4j.

        Entities are written in a fixed order: goals, constraints,
        preferences, pain points, strategies, code artifacts. Goals come
        first because later entity types are linked to goal IDs.

        Returns dict of entity type -> list of created IDs.
        """
        created: Dict[str, List[str]] = {
            "goals": [],
            "constraints": [],
            "preferences": [],
            "pain_points": [],
            "strategies": [],
            "code_artifacts": [],
        }

        # Process merge suggestions first to build ID mapping
        merge_map = self._build_merge_map(linking)

        goal_id_map, created["goals"] = await self._commit_goals(
            project_id, interaction_id, extraction.goals, merge_map
        )

        # Every constraint is attached to the first extracted goal that has
        # an ID (or to none). The goal map no longer changes from here on,
        # so the lookup is done once instead of once per constraint.
        constraint_goal_id = next(
            (
                goal_id_map[goal_extract.title]
                for goal_extract in extraction.goals
                if goal_extract.title in goal_id_map
            ),
            None,
        )
        created["constraints"] = await self._commit_constraints(
            project_id, extraction.constraints, constraint_goal_id
        )
        created["preferences"] = await self._commit_preferences(
            user_id, extraction.preferences
        )
        created["pain_points"] = await self._commit_pain_points(
            project_id, interaction_id, extraction.pain_points, goal_id_map
        )
        created["strategies"] = await self._commit_strategies(
            project_id, extraction.strategies, goal_id_map
        )
        created["code_artifacts"] = await self._commit_code_artifacts(
            project_id, extraction.code_references, goal_id_map
        )

        logger.info(f"Committed to graph: {created}")
        return created

    def _build_merge_map(self, linking: LinkingResult) -> Dict[str, str]:
        """Map new entity titles to existing IDs for confident merges."""
        merge_map: Dict[str, str] = {}  # new_title -> existing_id
        for merge in linking.merge_suggestions:
            if merge.confidence >= self.MERGE_CONFIDENCE_THRESHOLD:
                merge_map[merge.new_entity_title] = merge.existing_entity_id
                logger.info(
                    f"Merging '{merge.new_entity_title}' into existing "
                    f"'{merge.existing_entity_title}'"
                )
        return merge_map

    @staticmethod
    def _lookup_goal_id(goal_id_map: Dict[str, str], goal_title: Optional[str]) -> Optional[str]:
        """Return the ID for ``goal_title``, or None if empty or unknown."""
        if goal_title and goal_title in goal_id_map:
            return goal_id_map[goal_title]
        return None

    async def _commit_goals(self, project_id, interaction_id, goals, merge_map):
        """Upsert goals (or reuse merged IDs) and link each to the interaction.

        Returns a ``(title -> goal ID map, newly upserted goal IDs)`` pair.
        Goals resolved through ``merge_map`` appear in the map but not in the
        list of upserted IDs.
        """
        goal_id_map: Dict[str, str] = {}  # title -> id
        upserted_ids: List[str] = []
        for goal_extract in goals:
            title = goal_extract.title
            if title in merge_map:
                goal_id = merge_map[title]
                goal_id_map[title] = goal_id
            else:
                goal = await self.repo.upsert_goal(
                    project_id=project_id,
                    title=title,
                    description=goal_extract.description,
                    status=goal_extract.status,
                    priority=goal_extract.priority,
                )
                goal_id = goal["id"]
                goal_id_map[title] = goal_id
                upserted_ids.append(goal_id)

            # Link interaction to goal
            await self.repo.link_interaction_to_goal(interaction_id, goal_id)
        return goal_id_map, upserted_ids

    async def _commit_constraints(self, project_id, constraints, related_goal_id) -> List[str]:
        """Upsert every constraint, attached to ``related_goal_id``; return IDs."""
        constraint_ids: List[str] = []
        for constraint_extract in constraints:
            constraint = await self.repo.upsert_constraint(
                project_id=project_id,
                constraint_type=constraint_extract.type,
                description=constraint_extract.description,
                severity=constraint_extract.severity,
                goal_id=related_goal_id,
            )
            constraint_ids.append(constraint["id"])
        return constraint_ids

    async def _commit_preferences(self, user_id, preferences) -> List[str]:
        """Upsert every preference for ``user_id``; return the resulting IDs."""
        preference_ids: List[str] = []
        for pref_extract in preferences:
            pref = await self.repo.upsert_preference(
                user_id=user_id,
                category=pref_extract.category,
                preference=pref_extract.preference,
                strength=pref_extract.strength,
            )
            preference_ids.append(pref["id"])
        return preference_ids

    async def _commit_pain_points(
        self, project_id, interaction_id, pain_points, goal_id_map
    ) -> List[str]:
        """Upsert every pain point, resolving its related goal by title; return IDs."""
        pain_point_ids: List[str] = []
        for pp_extract in pain_points:
            pp = await self.repo.upsert_painpoint(
                project_id=project_id,
                description=pp_extract.description,
                severity=pp_extract.severity,
                related_goal_id=self._lookup_goal_id(goal_id_map, pp_extract.related_goal),
                interaction_id=interaction_id,
            )
            pain_point_ids.append(pp["id"])
        return pain_point_ids

    async def _commit_strategies(self, project_id, strategies, goal_id_map) -> List[str]:
        """Upsert every strategy, resolving its related goal by title; return IDs."""
        strategy_ids: List[str] = []
        for strategy_extract in strategies:
            strategy = await self.repo.upsert_strategy(
                project_id=project_id,
                title=strategy_extract.title,
                approach=strategy_extract.approach,
                rationale=strategy_extract.rationale,
                outcome=strategy_extract.outcome,
                outcome_reason=strategy_extract.outcome_reason,
                related_goal_id=self._lookup_goal_id(
                    goal_id_map, strategy_extract.related_goal
                ),
            )
            strategy_ids.append(strategy["id"])
        return strategy_ids

    async def _commit_code_artifacts(self, project_id, code_references, goal_id_map) -> List[str]:
        """Upsert every code reference as a file artifact; return the IDs."""
        # All artifacts are linked to the same leading goals, so the slice is
        # taken once; None is passed when there are no goals at all.
        related_goal_ids = list(goal_id_map.values())[: self.MAX_GOALS_PER_ARTIFACT]
        artifact_ids: List[str] = []
        for code_ref in code_references:
            artifact = await self.repo.upsert_code_artifact(
                project_id=project_id,
                path=code_ref.path,
                kind="file",
                symbol_fqn=code_ref.symbol,
                start_line=code_ref.start_line,
                end_line=code_ref.end_line,
                related_goal_ids=related_goal_ids if related_goal_ids else None,
            )
            artifact_ids.append(artifact["id"])
        return artifact_ids
|
|
242
|
+
|
|
243
|
+
|
|
244
|
+
# Lazily created, module-wide pipeline instance; populated on first use by
# get_ingest_pipeline().
_pipeline: Optional[IngestPipeline] = None


def get_ingest_pipeline() -> IngestPipeline:
    """Return the shared ``IngestPipeline``, creating it on the first call."""
    global _pipeline
    if _pipeline is not None:
        return _pipeline
    _pipeline = IngestPipeline()
    return _pipeline
|