kg-mcp 0.1.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
kg_mcp/kg/neo4j.py ADDED
@@ -0,0 +1,155 @@
1
+ """
2
+ Neo4j driver and session management.
3
+ Provides async-compatible driver wrapper with connection pooling.
4
+ """
5
+
6
+ import logging
7
+ from contextlib import asynccontextmanager
8
+ from typing import Any, AsyncGenerator, Dict, List, Optional
9
+
10
+ from neo4j import AsyncGraphDatabase, AsyncDriver, AsyncSession
11
+ from neo4j.exceptions import ServiceUnavailable, Neo4jError
12
+
13
+ from kg_mcp.config import get_settings
14
+
15
# Module-level logger, named after this module's import path.
logger: logging.Logger = logging.getLogger(name=__name__)
16
+
17
+
18
class Neo4jClient:
    """Async Neo4j client with connection management.

    ``Neo4jClient()`` always returns the same object (see ``__new__``), so all
    callers share one driver until :meth:`close` is called.
    """

    _instance: Optional["Neo4jClient"] = None
    _driver: Optional[AsyncDriver] = None

    def __new__(cls) -> "Neo4jClient":
        """Return the single shared instance, creating it on the first call."""
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    async def connect(self) -> None:
        """Create the Neo4j driver and verify that the server is reachable.

        Does nothing if a driver already exists. Connection details come from
        ``get_settings()``. The pool settings are fixed here: lifetime 3600,
        pool size 50, acquisition timeout 60 (seconds, per the neo4j driver
        configuration docs).

        On any failure, the half-initialised driver is closed and discarded.
        A later call therefore retries from scratch instead of silently
        reusing a driver that never verified.

        Raises:
            ServiceUnavailable: If the server cannot be reached (also logged).
            Exception: Any other error from creating or verifying the driver
                is propagated unchanged.
        """
        if self._driver is not None:
            return

        settings = get_settings()
        try:
            self._driver = AsyncGraphDatabase.driver(
                settings.neo4j_uri,
                auth=(settings.neo4j_user, settings.neo4j_password),
                max_connection_lifetime=3600,
                max_connection_pool_size=50,
                connection_acquisition_timeout=60,
            )
            # The driver connects lazily; force a round-trip so bad
            # configuration or an unreachable server fails here.
            await self._driver.verify_connectivity()
        except Exception as exc:
            if isinstance(exc, ServiceUnavailable):
                logger.error("Failed to connect to Neo4j: %s", exc)
            # Don't keep (or leak) a driver that never verified.
            await self._discard_driver()
            raise
        logger.info("Connected to Neo4j at %s", settings.neo4j_uri)

    async def _discard_driver(self) -> None:
        """Forget the current driver and close it, ignoring errors from close."""
        driver, self._driver = self._driver, None
        if driver is None:
            return
        try:
            await driver.close()
        except Exception:
            # Best effort: the caller is already propagating the real error.
            logger.debug("Ignoring error while closing Neo4j driver", exc_info=True)

    async def close(self) -> None:
        """Close the Neo4j driver, if one is open.

        The driver reference is cleared before closing. Even if
        ``driver.close()`` raises, the client will not reuse it.
        """
        if self._driver is None:
            return
        driver, self._driver = self._driver, None
        await driver.close()
        logger.info("Neo4j connection closed")

    @asynccontextmanager
    async def session(self, database: str = "neo4j") -> AsyncGenerator[AsyncSession, None]:
        """Yield a session on *database*, connecting first if needed.

        The session is always closed when the ``async with`` block exits,
        including on error.
        """
        if self._driver is None:
            await self.connect()

        session = self._driver.session(database=database)
        try:
            yield session
        finally:
            await session.close()

    async def execute_query(
        self,
        query: str,
        parameters: Optional[Dict[str, Any]] = None,
        database: str = "neo4j",
    ) -> List[Dict[str, Any]]:
        """
        Execute a Cypher query and return results as a list of dicts.

        Args:
            query: Cypher query string
            parameters: Query parameters (``None`` means no parameters)
            database: Target database name

        Returns:
            List of records as dictionaries

        Raises:
            Neo4jError: If the server rejects or fails the query (also logged).
        """
        if self._driver is None:
            await self.connect()

        try:
            result = await self._driver.execute_query(
                query,
                parameters_=parameters or {},
                database_=database,
            )
        except Neo4jError as e:
            logger.error("Query execution failed: %s", e)
            raise
        return [dict(record) for record in result.records]

    async def execute_write(
        self,
        query: str,
        parameters: Optional[Dict[str, Any]] = None,
        database: str = "neo4j",
    ) -> Dict[str, Any]:
        """
        Execute a write query and return summary.

        Args:
            query: Cypher query string
            parameters: Query parameters (``None`` means no parameters)
            database: Target database name

        Returns:
            Summary with nodes/relationships created/deleted counts and the
            number of properties set
        """
        # session() connects on demand, so no separate connect check is needed.
        async with self.session(database) as session:
            result = await session.run(query, parameters or {})
            summary = await result.consume()

        counters = summary.counters
        return {
            "nodes_created": counters.nodes_created,
            "nodes_deleted": counters.nodes_deleted,
            "relationships_created": counters.relationships_created,
            "relationships_deleted": counters.relationships_deleted,
            "properties_set": counters.properties_set,
        }
132
+
133
+
134
# Module-level cache of the shared client; populated on first access.
_client: Optional[Neo4jClient] = None


def get_neo4j_client() -> Neo4jClient:
    """Return the shared Neo4jClient, constructing it the first time it is requested."""
    global _client
    if _client is not None:
        return _client
    _client = Neo4jClient()
    return _client
144
+
145
+
146
async def init_neo4j() -> None:
    """Open the shared Neo4j connection; intended to be awaited once at application startup."""
    await get_neo4j_client().connect()
150
+
151
+
152
async def close_neo4j() -> None:
    """Shut down the shared Neo4j connection; intended to be awaited at application shutdown."""
    await get_neo4j_client().close()