alma_memory-0.5.0-py3-none-any.whl → alma_memory-0.7.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- alma/__init__.py +296 -194
- alma/compression/__init__.py +33 -0
- alma/compression/pipeline.py +980 -0
- alma/confidence/__init__.py +47 -47
- alma/confidence/engine.py +540 -540
- alma/confidence/types.py +351 -351
- alma/config/loader.py +157 -157
- alma/consolidation/__init__.py +23 -23
- alma/consolidation/engine.py +678 -678
- alma/consolidation/prompts.py +84 -84
- alma/core.py +1189 -322
- alma/domains/__init__.py +30 -30
- alma/domains/factory.py +359 -359
- alma/domains/schemas.py +448 -448
- alma/domains/types.py +272 -272
- alma/events/__init__.py +75 -75
- alma/events/emitter.py +285 -284
- alma/events/storage_mixin.py +246 -246
- alma/events/types.py +126 -126
- alma/events/webhook.py +425 -425
- alma/exceptions.py +49 -49
- alma/extraction/__init__.py +31 -31
- alma/extraction/auto_learner.py +265 -264
- alma/extraction/extractor.py +420 -420
- alma/graph/__init__.py +106 -81
- alma/graph/backends/__init__.py +32 -18
- alma/graph/backends/kuzu.py +624 -0
- alma/graph/backends/memgraph.py +432 -0
- alma/graph/backends/memory.py +236 -236
- alma/graph/backends/neo4j.py +417 -417
- alma/graph/base.py +159 -159
- alma/graph/extraction.py +198 -198
- alma/graph/store.py +860 -860
- alma/harness/__init__.py +35 -35
- alma/harness/base.py +386 -386
- alma/harness/domains.py +705 -705
- alma/initializer/__init__.py +37 -37
- alma/initializer/initializer.py +418 -418
- alma/initializer/types.py +250 -250
- alma/integration/__init__.py +62 -62
- alma/integration/claude_agents.py +444 -432
- alma/integration/helena.py +423 -423
- alma/integration/victor.py +471 -471
- alma/learning/__init__.py +101 -86
- alma/learning/decay.py +878 -0
- alma/learning/forgetting.py +1446 -1446
- alma/learning/heuristic_extractor.py +390 -390
- alma/learning/protocols.py +374 -374
- alma/learning/validation.py +346 -346
- alma/mcp/__init__.py +123 -45
- alma/mcp/__main__.py +156 -156
- alma/mcp/resources.py +122 -122
- alma/mcp/server.py +955 -591
- alma/mcp/tools.py +3254 -511
- alma/observability/__init__.py +91 -0
- alma/observability/config.py +302 -0
- alma/observability/guidelines.py +170 -0
- alma/observability/logging.py +424 -0
- alma/observability/metrics.py +583 -0
- alma/observability/tracing.py +440 -0
- alma/progress/__init__.py +21 -21
- alma/progress/tracker.py +607 -607
- alma/progress/types.py +250 -250
- alma/retrieval/__init__.py +134 -53
- alma/retrieval/budget.py +525 -0
- alma/retrieval/cache.py +1304 -1061
- alma/retrieval/embeddings.py +202 -202
- alma/retrieval/engine.py +850 -366
- alma/retrieval/modes.py +365 -0
- alma/retrieval/progressive.py +560 -0
- alma/retrieval/scoring.py +344 -344
- alma/retrieval/trust_scoring.py +637 -0
- alma/retrieval/verification.py +797 -0
- alma/session/__init__.py +19 -19
- alma/session/manager.py +442 -399
- alma/session/types.py +288 -288
- alma/storage/__init__.py +101 -61
- alma/storage/archive.py +233 -0
- alma/storage/azure_cosmos.py +1259 -1048
- alma/storage/base.py +1083 -525
- alma/storage/chroma.py +1443 -1443
- alma/storage/constants.py +103 -0
- alma/storage/file_based.py +614 -619
- alma/storage/migrations/__init__.py +21 -0
- alma/storage/migrations/base.py +321 -0
- alma/storage/migrations/runner.py +323 -0
- alma/storage/migrations/version_stores.py +337 -0
- alma/storage/migrations/versions/__init__.py +11 -0
- alma/storage/migrations/versions/v1_0_0.py +373 -0
- alma/storage/migrations/versions/v1_1_0_workflow_context.py +551 -0
- alma/storage/pinecone.py +1080 -1080
- alma/storage/postgresql.py +1948 -1452
- alma/storage/qdrant.py +1306 -1306
- alma/storage/sqlite_local.py +3041 -1358
- alma/testing/__init__.py +46 -0
- alma/testing/factories.py +301 -0
- alma/testing/mocks.py +389 -0
- alma/types.py +292 -264
- alma/utils/__init__.py +19 -0
- alma/utils/tokenizer.py +521 -0
- alma/workflow/__init__.py +83 -0
- alma/workflow/artifacts.py +170 -0
- alma/workflow/checkpoint.py +311 -0
- alma/workflow/context.py +228 -0
- alma/workflow/outcomes.py +189 -0
- alma/workflow/reducers.py +393 -0
- {alma_memory-0.5.0.dist-info → alma_memory-0.7.0.dist-info}/METADATA +244 -72
- alma_memory-0.7.0.dist-info/RECORD +112 -0
- alma_memory-0.5.0.dist-info/RECORD +0 -76
- {alma_memory-0.5.0.dist-info → alma_memory-0.7.0.dist-info}/WHEEL +0 -0
- {alma_memory-0.5.0.dist-info → alma_memory-0.7.0.dist-info}/top_level.txt +0 -0
alma/graph/store.py CHANGED
@@ -1,860 +1,860 @@

The hunk spans the whole file: every line is marked as removed and re-added, and the removed and added contents shown are identical, so the file is listed once below.

"""
ALMA Graph Memory Module.

Graph-based memory storage for capturing relationships between entities.
Supports Neo4j and in-memory graph for testing.

This module provides two APIs:
1. GraphStore (high-level) - Full-featured graph store with traversal and query support
2. GraphBackend (low-level) - Simple CRUD operations for pluggable backends

The GraphStore classes can optionally use GraphBackend implementations for storage,
enabling easy swapping of database backends while keeping the high-level API.
"""

import json
import logging
import time
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import TYPE_CHECKING, Any, Dict, List, Optional

if TYPE_CHECKING:
    from alma.graph.base import GraphBackend

logger = logging.getLogger(__name__)


@dataclass
class Entity:
    """A node in the knowledge graph."""

    id: str
    name: str
    entity_type: str  # person, organization, concept, tool, etc.
    properties: Dict[str, Any] = field(default_factory=dict)
    created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))


@dataclass
class Relationship:
    """An edge in the knowledge graph."""

    id: str
    source_id: str
    target_id: str
    relation_type: str  # WORKS_AT, USES, KNOWS, CREATED_BY, etc.
    properties: Dict[str, Any] = field(default_factory=dict)
    confidence: float = 1.0
    created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))


@dataclass
class GraphQuery:
    """A query against the knowledge graph."""

    entities: List[str]  # Entity names or IDs to search for
    relation_types: Optional[List[str]] = None  # Filter by relation types
    max_hops: int = 2  # Maximum traversal depth
    limit: int = 20  # Maximum results


@dataclass
class GraphResult:
    """Result from a graph query."""

    entities: List[Entity]
    relationships: List[Relationship]
    paths: List[List[str]]  # Paths through the graph
    query_time_ms: int


class GraphStore(ABC):
    """
    Abstract base class for graph storage backends.

    This is the high-level API for graph operations including traversal
    and complex queries. For simple CRUD operations, see GraphBackend.
    """

    @abstractmethod
    def add_entity(self, entity: Entity) -> str:
        """Add or update an entity."""
        pass

    @abstractmethod
    def add_relationship(self, relationship: Relationship) -> str:
        """Add or update a relationship."""
        pass

    @abstractmethod
    def get_entity(self, entity_id: str) -> Optional[Entity]:
        """Get an entity by ID."""
        pass

    @abstractmethod
    def find_entities(
        self,
        name: Optional[str] = None,
        entity_type: Optional[str] = None,
        limit: int = 10,
    ) -> List[Entity]:
        """Find entities by name or type."""
        pass

    @abstractmethod
    def get_relationships(
        self,
        entity_id: str,
        direction: str = "both",  # "outgoing", "incoming", "both"
        relation_type: Optional[str] = None,
    ) -> List[Relationship]:
        """Get relationships for an entity."""
        pass

    @abstractmethod
    def traverse(
        self,
        start_entity_id: str,
        max_hops: int = 2,
        relation_types: Optional[List[str]] = None,
    ) -> GraphResult:
        """Traverse the graph from a starting entity."""
        pass

    @abstractmethod
    def query(self, query: GraphQuery) -> GraphResult:
        """Execute a graph query."""
        pass

    @abstractmethod
    def delete_entity(self, entity_id: str) -> bool:
        """Delete an entity and its relationships."""
        pass

    def close(self) -> None:  # noqa: B027
        """Close the store connection. Override in subclasses if needed."""
        pass

    def __enter__(self):
        """Context manager entry."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context manager exit - ensures close is called."""
        self.close()
        return False


class Neo4jGraphStore(GraphStore):
    """
    Neo4j graph storage backend.

    Requires neo4j Python driver: pip install neo4j
    """

    def __init__(
        self,
        uri: str,
        username: str,
        password: str,
        database: str = "neo4j",
    ):
        """
        Initialize Neo4j connection.

        Args:
            uri: Neo4j connection URI (bolt:// or neo4j+s://)
            username: Database username
            password: Database password
            database: Database name
        """
        self.uri = uri
        self.username = username
        self.password = password
        self.database = database
        self._driver = None

    def _get_driver(self):
        """Lazy initialization of Neo4j driver."""
        if self._driver is None:
            try:
                from neo4j import GraphDatabase

                self._driver = GraphDatabase.driver(
                    self.uri,
                    auth=(self.username, self.password),
                )
            except ImportError as err:
                raise ImportError(
                    "neo4j package required for Neo4j graph store. "
                    "Install with: pip install neo4j"
                ) from err
        return self._driver

    def _run_query(self, query: str, parameters: Dict = None) -> List[Dict]:
        """Execute a Cypher query."""
        driver = self._get_driver()
        with driver.session(database=self.database) as session:
            result = session.run(query, parameters or {})
            return [dict(record) for record in result]

    def add_entity(self, entity: Entity) -> str:
        """Add or update an entity in Neo4j."""
        query = """
            MERGE (e:Entity {id: $id})
            SET e.name = $name,
                e.entity_type = $entity_type,
                e.properties = $properties,
                e.created_at = $created_at
            RETURN e.id as id
        """
        result = self._run_query(
            query,
            {
                "id": entity.id,
                "name": entity.name,
                "entity_type": entity.entity_type,
                "properties": json.dumps(entity.properties),
                "created_at": entity.created_at.isoformat(),
            },
        )
        return result[0]["id"] if result else entity.id

    def add_relationship(self, relationship: Relationship) -> str:
        """Add or update a relationship in Neo4j."""
        # Sanitize relationship type for Cypher
        rel_type = (
            relationship.relation_type.replace("-", "_").replace(" ", "_").upper()
        )
        query = f"""
            MATCH (source:Entity {{id: $source_id}})
            MATCH (target:Entity {{id: $target_id}})
            MERGE (source)-[r:{rel_type}]->(target)
            SET r.id = $id,
                r.properties = $properties,
                r.confidence = $confidence,
                r.created_at = $created_at
            RETURN r.id as id
        """
        result = self._run_query(
            query,
            {
                "id": relationship.id,
                "source_id": relationship.source_id,
                "target_id": relationship.target_id,
                "properties": json.dumps(relationship.properties),
                "confidence": relationship.confidence,
                "created_at": relationship.created_at.isoformat(),
            },
        )
        return result[0]["id"] if result else relationship.id

    def get_entity(self, entity_id: str) -> Optional[Entity]:
        """Get an entity by ID."""
        query = """
            MATCH (e:Entity {id: $id})
            RETURN e.id as id, e.name as name, e.entity_type as entity_type,
                   e.properties as properties, e.created_at as created_at
        """
        result = self._run_query(query, {"id": entity_id})
        if not result:
            return None
        r = result[0]
        return Entity(
            id=r["id"],
            name=r["name"],
            entity_type=r["entity_type"],
            properties=json.loads(r["properties"]) if r["properties"] else {},
            created_at=(
                datetime.fromisoformat(r["created_at"])
                if r["created_at"]
                else datetime.now(timezone.utc)
            ),
        )

    def find_entities(
        self,
        name: Optional[str] = None,
        entity_type: Optional[str] = None,
        limit: int = 10,
    ) -> List[Entity]:
        """Find entities by name or type."""
        conditions = []
        params: Dict[str, Any] = {"limit": limit}

        if name:
            conditions.append("e.name CONTAINS $name")
            params["name"] = name
        if entity_type:
            conditions.append("e.entity_type = $entity_type")
            params["entity_type"] = entity_type

        where_clause = f"WHERE {' AND '.join(conditions)}" if conditions else ""

        query = f"""
            MATCH (e:Entity)
            {where_clause}
            RETURN e.id as id, e.name as name, e.entity_type as entity_type,
                   e.properties as properties, e.created_at as created_at
            LIMIT $limit
        """

        results = self._run_query(query, params)
        return [
            Entity(
                id=r["id"],
                name=r["name"],
                entity_type=r["entity_type"],
                properties=json.loads(r["properties"]) if r["properties"] else {},
            )
            for r in results
        ]

    def get_relationships(
        self,
        entity_id: str,
        direction: str = "both",
        relation_type: Optional[str] = None,
    ) -> List[Relationship]:
        """Get relationships for an entity."""
        if direction == "outgoing":
            pattern = "(e)-[r]->(other)"
        elif direction == "incoming":
            pattern = "(e)<-[r]-(other)"
        else:
            pattern = "(e)-[r]-(other)"

        type_filter = f":{relation_type}" if relation_type else ""

        query = f"""
            MATCH (e:Entity {{id: $entity_id}}){pattern.replace("[r]", f"[r{type_filter}]")}
            RETURN r.id as id, e.id as source_id, other.id as target_id,
                   type(r) as relation_type, r.properties as properties,
                   r.confidence as confidence, r.created_at as created_at
        """

        results = self._run_query(query, {"entity_id": entity_id})
        return [
            Relationship(
                id=r["id"] or f"{r['source_id']}-{r['relation_type']}-{r['target_id']}",
                source_id=r["source_id"],
                target_id=r["target_id"],
                relation_type=r["relation_type"],
                properties=json.loads(r["properties"]) if r["properties"] else {},
                confidence=r["confidence"] or 1.0,
            )
            for r in results
        ]

    def traverse(
        self,
        start_entity_id: str,
        max_hops: int = 2,
        relation_types: Optional[List[str]] = None,
    ) -> GraphResult:
        """Traverse the graph from a starting entity."""
        start_time = time.time()

        type_filter = ""
        if relation_types:
            type_filter = ":" + "|".join(relation_types)

        query = f"""
            MATCH path = (start:Entity {{id: $start_id}})-[r{type_filter}*1..{max_hops}]-(end:Entity)
            RETURN nodes(path) as nodes, relationships(path) as rels
            LIMIT 100
        """

        results = self._run_query(query, {"start_id": start_entity_id})

        entities = {}
        relationships = {}
        paths = []

        for r in results:
            path_ids = []
            for node in r["nodes"]:
                if node["id"] not in entities:
                    entities[node["id"]] = Entity(
                        id=node["id"],
                        name=node.get("name", ""),
                        entity_type=node.get("entity_type", "unknown"),
                    )
                path_ids.append(node["id"])
            paths.append(path_ids)

            for rel in r["rels"]:
                rel_id = rel.get("id", f"{rel['source_id']}-{rel['target_id']}")
                if rel_id not in relationships:
                    relationships[rel_id] = Relationship(
                        id=rel_id,
                        source_id=rel.get("source_id", ""),
                        target_id=rel.get("target_id", ""),
                        relation_type=rel.get("type", "RELATED"),
                    )

        query_time_ms = int((time.time() - start_time) * 1000)

        return GraphResult(
            entities=list(entities.values()),
            relationships=list(relationships.values()),
            paths=paths,
            query_time_ms=query_time_ms,
        )

    def query(self, query: GraphQuery) -> GraphResult:
        """Execute a graph query."""
        # Find starting entities
        all_entities = {}
        all_relationships = {}
        all_paths = []

        for entity_name in query.entities:
            entities = self.find_entities(name=entity_name, limit=5)
            for entity in entities:
                all_entities[entity.id] = entity
                result = self.traverse(
                    entity.id,
                    max_hops=query.max_hops,
                    relation_types=query.relation_types,
                )
                for e in result.entities:
                    all_entities[e.id] = e
                for r in result.relationships:
                    all_relationships[r.id] = r
                all_paths.extend(result.paths)

        return GraphResult(
            entities=list(all_entities.values())[: query.limit],
            relationships=list(all_relationships.values()),
            paths=all_paths[: query.limit],
            query_time_ms=0,
        )

    def delete_entity(self, entity_id: str) -> bool:
        """Delete an entity and its relationships."""
        query = """
            MATCH (e:Entity {id: $id})
            DETACH DELETE e
            RETURN count(e) as deleted
        """
        result = self._run_query(query, {"id": entity_id})
        return result[0]["deleted"] > 0 if result else False

    def close(self):
        """Close the Neo4j driver connection."""
        if self._driver:
            self._driver.close()
            self._driver = None


class InMemoryGraphStore(GraphStore):
    """
    In-memory graph storage for testing and development.

    No external dependencies required.
    """

    def __init__(self):
        self._entities: Dict[str, Entity] = {}
        self._relationships: Dict[str, Relationship] = {}
        self._outgoing: Dict[str, List[str]] = {}  # entity_id -> [rel_ids]
        self._incoming: Dict[str, List[str]] = {}  # entity_id -> [rel_ids]

    def add_entity(self, entity: Entity) -> str:
        self._entities[entity.id] = entity
        if entity.id not in self._outgoing:
            self._outgoing[entity.id] = []
        if entity.id not in self._incoming:
            self._incoming[entity.id] = []
        return entity.id

    def add_relationship(self, relationship: Relationship) -> str:
        self._relationships[relationship.id] = relationship

        if relationship.source_id not in self._outgoing:
            self._outgoing[relationship.source_id] = []
        if relationship.id not in self._outgoing[relationship.source_id]:
            self._outgoing[relationship.source_id].append(relationship.id)

        if relationship.target_id not in self._incoming:
            self._incoming[relationship.target_id] = []
        if relationship.id not in self._incoming[relationship.target_id]:
            self._incoming[relationship.target_id].append(relationship.id)

        return relationship.id

    def get_entity(self, entity_id: str) -> Optional[Entity]:
        return self._entities.get(entity_id)

    def find_entities(
        self,
        name: Optional[str] = None,
        entity_type: Optional[str] = None,
        limit: int = 10,
    ) -> List[Entity]:
        results = []
        for entity in self._entities.values():
            if name and name.lower() not in entity.name.lower():
                continue
            if entity_type and entity.entity_type != entity_type:
                continue
            results.append(entity)
            if len(results) >= limit:
                break
        return results

    def get_relationships(
        self,
        entity_id: str,
        direction: str = "both",
        relation_type: Optional[str] = None,
    ) -> List[Relationship]:
        rel_ids = set()

        if direction in ("outgoing", "both"):
            rel_ids.update(self._outgoing.get(entity_id, []))
        if direction in ("incoming", "both"):
            rel_ids.update(self._incoming.get(entity_id, []))

        results = []
        for rel_id in rel_ids:
            rel = self._relationships.get(rel_id)
            if rel:
                if relation_type and rel.relation_type != relation_type:
                    continue
                results.append(rel)
        return results

    def traverse(
        self,
        start_entity_id: str,
        max_hops: int = 2,
        relation_types: Optional[List[str]] = None,
    ) -> GraphResult:
        start_time = time.time()

        visited_entities = {start_entity_id}
        visited_relationships = set()
        paths = []

        def _traverse(current_id: str, depth: int, current_path: List[str]):
            if depth > max_hops:
                return

            for rel in self.get_relationships(current_id, "both"):
                if relation_types and rel.relation_type not in relation_types:
                    continue

                visited_relationships.add(rel.id)

                next_id = (
                    rel.target_id if rel.source_id == current_id else rel.source_id
                )

                if next_id not in visited_entities:
                    visited_entities.add(next_id)
                    new_path = current_path + [next_id]
                    paths.append(new_path)
                    _traverse(next_id, depth + 1, new_path)

        _traverse(start_entity_id, 0, [start_entity_id])

        entities = [
            self._entities[eid] for eid in visited_entities if eid in self._entities
        ]
        relationships = [
            self._relationships[rid]
            for rid in visited_relationships
            if rid in self._relationships
        ]

        query_time_ms = int((time.time() - start_time) * 1000)

        return GraphResult(
            entities=entities,
            relationships=relationships,
            paths=paths,
            query_time_ms=query_time_ms,
        )

    def query(self, query: GraphQuery) -> GraphResult:
        all_entities = {}
        all_relationships = {}
        all_paths = []

        for entity_name in query.entities:
            entities = self.find_entities(name=entity_name, limit=5)
            for entity in entities:
                all_entities[entity.id] = entity
                result = self.traverse(
                    entity.id,
                    max_hops=query.max_hops,
                    relation_types=query.relation_types,
                )
                for e in result.entities:
                    all_entities[e.id] = e
                for r in result.relationships:
                    all_relationships[r.id] = r
                all_paths.extend(result.paths)

        return GraphResult(
            entities=list(all_entities.values())[: query.limit],
            relationships=list(all_relationships.values()),
            paths=all_paths[: query.limit],
            query_time_ms=0,
        )

    def delete_entity(self, entity_id: str) -> bool:
        if entity_id not in self._entities:
            return False

        # Delete outgoing relationships
        for rel_id in list(self._outgoing.get(entity_id, [])):
            if rel_id in self._relationships:
                rel = self._relationships[rel_id]
                # Remove from target's incoming
                if rel.target_id in self._incoming:
                    if rel_id in self._incoming[rel.target_id]:
                        self._incoming[rel.target_id].remove(rel_id)
                del self._relationships[rel_id]

        # Delete incoming relationships
        for rel_id in list(self._incoming.get(entity_id, [])):
            if rel_id in self._relationships:
                rel = self._relationships[rel_id]
                # Remove from source's outgoing
                if rel.source_id in self._outgoing:
                    if rel_id in self._outgoing[rel.source_id]:
                        self._outgoing[rel.source_id].remove(rel_id)
                del self._relationships[rel_id]

        # Delete entity
        del self._entities[entity_id]
        self._outgoing.pop(entity_id, None)
        self._incoming.pop(entity_id, None)

        return True

    def clear(self) -> None:
        """Clear all stored data."""
        self._entities.clear()
        self._relationships.clear()
        self._outgoing.clear()
        self._incoming.clear()


class BackendGraphStore(GraphStore):
    """
    GraphStore implementation that delegates to a GraphBackend.

    This class bridges the high-level GraphStore API with the pluggable
    GraphBackend interface, enabling use of different database backends
    while maintaining the full GraphStore functionality.

    Example:
        from alma.graph import create_graph_backend, BackendGraphStore

        backend = create_graph_backend("neo4j", uri="bolt://localhost:7687", ...)
        store = BackendGraphStore(backend)
        store.add_entity(entity)
        result = store.traverse(entity.id)
    """

    def __init__(self, backend: "GraphBackend"):
        """
        Initialize with a GraphBackend.

        Args:
            backend: The GraphBackend implementation to use for storage.
        """
        self._backend = backend

    @property
    def backend(self) -> "GraphBackend":
        """Access the underlying backend."""
        return self._backend

    def add_entity(self, entity: Entity) -> str:
        return self._backend.add_entity(entity)

    def add_relationship(self, relationship: Relationship) -> str:
        return self._backend.add_relationship(relationship)

    def get_entity(self, entity_id: str) -> Optional[Entity]:
        return self._backend.get_entity(entity_id)

    def find_entities(
        self,
        name: Optional[str] = None,
        entity_type: Optional[str] = None,
        limit: int = 10,
    ) -> List[Entity]:
        """Find entities by name or type using the backend."""
        if hasattr(self._backend, "find_entities"):
            return self._backend.find_entities(
                name=name, entity_type=entity_type, limit=limit
            )
        # Fallback to search_entities if find_entities not available
        if name:
            return self._backend.search_entities(query=name, top_k=limit)
        return self._backend.get_entities(entity_type=entity_type, limit=limit)

    def get_relationships(
        self,
        entity_id: str,
        direction: str = "both",
        relation_type: Optional[str] = None,
    ) -> List[Relationship]:
        """Get relationships for an entity."""
        if hasattr(self._backend, "get_relationships_directional"):
            return self._backend.get_relationships_directional(
                entity_id=entity_id, direction=direction, relation_type=relation_type
            )
        # Fallback - get all relationships and filter
        all_rels = self._backend.get_relationships(entity_id)
        if relation_type:
            all_rels = [r for r in all_rels if r.relation_type == relation_type]
        return all_rels

    def traverse(
        self,
        start_entity_id: str,
        max_hops: int = 2,
        relation_types: Optional[List[str]] = None,
    ) -> GraphResult:
        """Traverse the graph from a starting entity."""
        start_time = time.time()

        visited_entities = {start_entity_id}
        visited_relationships: Dict[str, Relationship] = {}
        paths: List[List[str]] = []

        def _traverse(current_id: str, depth: int, current_path: List[str]):
            if depth > max_hops:
                return

            rels = self._backend.get_relationships(current_id)
            for rel in rels:
                if relation_types and rel.relation_type not in relation_types:
                    continue

                visited_relationships[rel.id] = rel

                next_id = (
                    rel.target_id if rel.source_id == current_id else rel.source_id
                )

                if next_id not in visited_entities:
                    visited_entities.add(next_id)
                    new_path = current_path + [next_id]
                    paths.append(new_path)
                    _traverse(next_id, depth + 1, new_path)

        _traverse(start_entity_id, 0, [start_entity_id])

        # Fetch all visited entities
        entities = []
        for eid in visited_entities:
            entity = self._backend.get_entity(eid)
            if entity:
                entities.append(entity)

        query_time_ms = int((time.time() - start_time) * 1000)

        return GraphResult(
            entities=entities,
            relationships=list(visited_relationships.values()),
            paths=paths,
            query_time_ms=query_time_ms,
        )

    def query(self, query: GraphQuery) -> GraphResult:
        """Execute a graph query."""
        all_entities: Dict[str, Entity] = {}
        all_relationships: Dict[str, Relationship] = {}
        all_paths: List[List[str]] = []

        for entity_name in query.entities:
            entities = self.find_entities(name=entity_name, limit=5)
            for entity in entities:
                all_entities[entity.id] = entity
                result = self.traverse(
                    entity.id,
                    max_hops=query.max_hops,
                    relation_types=query.relation_types,
                )
                for e in result.entities:
                    all_entities[e.id] = e
                for r in result.relationships:
                    all_relationships[r.id] = r
                all_paths.extend(result.paths)

        return GraphResult(
            entities=list(all_entities.values())[: query.limit],
            relationships=list(all_relationships.values()),
            paths=all_paths[: query.limit],
            query_time_ms=0,
        )

    def delete_entity(self, entity_id: str) -> bool:
        return self._backend.delete_entity(entity_id)

    def close(self) -> None:
        """Close the backend connection."""
        self._backend.close()


def create_graph_store(
    provider: str = "memory",
    **kwargs,
) -> GraphStore:
    """
    Factory function to create a graph store.

    Args:
        provider: "neo4j", "memory", or "backend"
        **kwargs: Provider-specific arguments
            For "neo4j": uri, username, password, database
            For "backend": backend (GraphBackend instance)

    Returns:
        Configured GraphStore instance

    Note:
        Amazon Neptune support is planned for a future release.

    Example:
        # Create in-memory store
        store = create_graph_store("memory")

        # Create Neo4j store
        store = create_graph_store(
            "neo4j",
            uri="bolt://localhost:7687",
            username="neo4j",
            password="password"
        )

        # Create store using a backend
        from alma.graph import create_graph_backend
        backend = create_graph_backend("neo4j", uri="...", ...)
        store = create_graph_store("backend", backend=backend)
    """
    if provider == "neo4j":
        return Neo4jGraphStore(**kwargs)
    elif provider == "neptune":
        # Neptune support is planned for a future release
        raise NotImplementedError(
            "Neptune support is not yet implemented. "
            "Use 'neo4j' or 'memory' providers instead."
        )
    elif provider == "backend":
        backend = kwargs.get("backend")
        if backend is None:
            raise ValueError("'backend' argument required for 'backend' provider")
        return BackendGraphStore(backend)
    else:
        return InMemoryGraphStore()
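
For orientation, here is a minimal usage sketch of the API in the listing above, using the in-memory provider. The entity IDs, names, and relation types are invented for the example and are not part of the package.

```python
# Illustrative sketch only, based on the store.py listing above.
from alma.graph.store import Entity, GraphQuery, Relationship, create_graph_store

# create_graph_store("memory") returns InMemoryGraphStore; no external services needed.
with create_graph_store("memory") as store:
    store.add_entity(Entity(id="e1", name="Ada", entity_type="person"))
    store.add_entity(Entity(id="e2", name="ALMA", entity_type="tool"))
    store.add_relationship(
        Relationship(id="r1", source_id="e1", target_id="e2", relation_type="USES")
    )

    # Walk outward from "Ada" up to two hops.
    result = store.traverse("e1", max_hops=2)
    print([e.name for e in result.entities], result.paths)

    # Or search by name and traverse in one call via GraphQuery.
    result = store.query(GraphQuery(entities=["Ada"], relation_types=["USES"]))
    print(len(result.relationships), "relationship(s) found")
```

The same calls work unchanged against Neo4jGraphStore or BackendGraphStore, since all three implement the GraphStore ABC shown above.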
1
|
+
"""
|
|
2
|
+
ALMA Graph Memory Module.
|
|
3
|
+
|
|
4
|
+
Graph-based memory storage for capturing relationships between entities.
|
|
5
|
+
Supports Neo4j and in-memory graph for testing.
|
|
6
|
+
|
|
7
|
+
This module provides two APIs:
|
|
8
|
+
1. GraphStore (high-level) - Full-featured graph store with traversal and query support
|
|
9
|
+
2. GraphBackend (low-level) - Simple CRUD operations for pluggable backends
|
|
10
|
+
|
|
11
|
+
The GraphStore classes can optionally use GraphBackend implementations for storage,
|
|
12
|
+
enabling easy swapping of database backends while keeping the high-level API.
|
|
13
|
+
"""
|
|
14
|
+
|
|
15
|
+
import json
|
|
16
|
+
import logging
|
|
17
|
+
import time
|
|
18
|
+
from abc import ABC, abstractmethod
|
|
19
|
+
from dataclasses import dataclass, field
|
|
20
|
+
from datetime import datetime, timezone
|
|
21
|
+
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
|
22
|
+
|
|
23
|
+
if TYPE_CHECKING:
|
|
24
|
+
from alma.graph.base import GraphBackend
|
|
25
|
+
|
|
26
|
+
logger = logging.getLogger(__name__)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
@dataclass
|
|
30
|
+
class Entity:
|
|
31
|
+
"""A node in the knowledge graph."""
|
|
32
|
+
|
|
33
|
+
id: str
|
|
34
|
+
name: str
|
|
35
|
+
entity_type: str # person, organization, concept, tool, etc.
|
|
36
|
+
properties: Dict[str, Any] = field(default_factory=dict)
|
|
37
|
+
created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
@dataclass
|
|
41
|
+
class Relationship:
|
|
42
|
+
"""An edge in the knowledge graph."""
|
|
43
|
+
|
|
44
|
+
id: str
|
|
45
|
+
source_id: str
|
|
46
|
+
target_id: str
|
|
47
|
+
relation_type: str # WORKS_AT, USES, KNOWS, CREATED_BY, etc.
|
|
48
|
+
properties: Dict[str, Any] = field(default_factory=dict)
|
|
49
|
+
confidence: float = 1.0
|
|
50
|
+
created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
@dataclass
|
|
54
|
+
class GraphQuery:
|
|
55
|
+
"""A query against the knowledge graph."""
|
|
56
|
+
|
|
57
|
+
entities: List[str] # Entity names or IDs to search for
|
|
58
|
+
relation_types: Optional[List[str]] = None # Filter by relation types
|
|
59
|
+
max_hops: int = 2 # Maximum traversal depth
|
|
60
|
+
limit: int = 20 # Maximum results
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
@dataclass
|
|
64
|
+
class GraphResult:
|
|
65
|
+
"""Result from a graph query."""
|
|
66
|
+
|
|
67
|
+
entities: List[Entity]
|
|
68
|
+
relationships: List[Relationship]
|
|
69
|
+
paths: List[List[str]] # Paths through the graph
|
|
70
|
+
query_time_ms: int
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
class GraphStore(ABC):
|
|
74
|
+
"""
|
|
75
|
+
Abstract base class for graph storage backends.
|
|
76
|
+
|
|
77
|
+
This is the high-level API for graph operations including traversal
|
|
78
|
+
and complex queries. For simple CRUD operations, see GraphBackend.
|
|
79
|
+
"""
|
|
80
|
+
|
|
81
|
+
@abstractmethod
|
|
82
|
+
def add_entity(self, entity: Entity) -> str:
|
|
83
|
+
"""Add or update an entity."""
|
|
84
|
+
pass
|
|
85
|
+
|
|
86
|
+
@abstractmethod
|
|
87
|
+
def add_relationship(self, relationship: Relationship) -> str:
|
|
88
|
+
"""Add or update a relationship."""
|
|
89
|
+
pass
|
|
90
|
+
|
|
91
|
+
@abstractmethod
|
|
92
|
+
def get_entity(self, entity_id: str) -> Optional[Entity]:
|
|
93
|
+
"""Get an entity by ID."""
|
|
94
|
+
pass
|
|
95
|
+
|
|
96
|
+
@abstractmethod
|
|
97
|
+
def find_entities(
|
|
98
|
+
self,
|
|
99
|
+
name: Optional[str] = None,
|
|
100
|
+
entity_type: Optional[str] = None,
|
|
101
|
+
limit: int = 10,
|
|
102
|
+
) -> List[Entity]:
|
|
103
|
+
"""Find entities by name or type."""
|
|
104
|
+
pass
|
|
105
|
+
|
|
106
|
+
@abstractmethod
|
|
107
|
+
def get_relationships(
|
|
108
|
+
self,
|
|
109
|
+
entity_id: str,
|
|
110
|
+
direction: str = "both", # "outgoing", "incoming", "both"
|
|
111
|
+
relation_type: Optional[str] = None,
|
|
112
|
+
) -> List[Relationship]:
|
|
113
|
+
"""Get relationships for an entity."""
|
|
114
|
+
pass
|
|
115
|
+
|
|
116
|
+
@abstractmethod
|
|
117
|
+
def traverse(
|
|
118
|
+
self,
|
|
119
|
+
start_entity_id: str,
|
|
120
|
+
max_hops: int = 2,
|
|
121
|
+
relation_types: Optional[List[str]] = None,
|
|
122
|
+
) -> GraphResult:
|
|
123
|
+
"""Traverse the graph from a starting entity."""
|
|
124
|
+
pass
|
|
125
|
+
|
|
126
|
+
@abstractmethod
|
|
127
|
+
def query(self, query: GraphQuery) -> GraphResult:
|
|
128
|
+
"""Execute a graph query."""
|
|
129
|
+
pass
|
|
130
|
+
|
|
131
|
+
@abstractmethod
|
|
132
|
+
def delete_entity(self, entity_id: str) -> bool:
|
|
133
|
+
"""Delete an entity and its relationships."""
|
|
134
|
+
pass
|
|
135
|
+
|
|
136
|
+
def close(self) -> None: # noqa: B027
|
|
137
|
+
"""Close the store connection. Override in subclasses if needed."""
|
|
138
|
+
pass
|
|
139
|
+
|
|
140
|
+
def __enter__(self):
|
|
141
|
+
"""Context manager entry."""
|
|
142
|
+
return self
|
|
143
|
+
|
|
144
|
+
def __exit__(self, exc_type, exc_val, exc_tb):
|
|
145
|
+
"""Context manager exit - ensures close is called."""
|
|
146
|
+
self.close()
|
|
147
|
+
return False
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
class Neo4jGraphStore(GraphStore):
|
|
151
|
+
"""
|
|
152
|
+
Neo4j graph storage backend.
|
|
153
|
+
|
|
154
|
+
Requires neo4j Python driver: pip install neo4j
|
|
155
|
+
"""
|
|
156
|
+
|
|
157
|
+
def __init__(
|
|
158
|
+
self,
|
|
159
|
+
uri: str,
|
|
160
|
+
username: str,
|
|
161
|
+
password: str,
|
|
162
|
+
database: str = "neo4j",
|
|
163
|
+
):
|
|
164
|
+
"""
|
|
165
|
+
Initialize Neo4j connection.
|
|
166
|
+
|
|
167
|
+
Args:
|
|
168
|
+
uri: Neo4j connection URI (bolt:// or neo4j+s://)
|
|
169
|
+
username: Database username
|
|
170
|
+
password: Database password
|
|
171
|
+
database: Database name
|
|
172
|
+
"""
|
|
173
|
+
self.uri = uri
|
|
174
|
+
self.username = username
|
|
175
|
+
self.password = password
|
|
176
|
+
self.database = database
|
|
177
|
+
self._driver = None
|
|
178
|
+
|
|
179
|
+
def _get_driver(self):
|
|
180
|
+
"""Lazy initialization of Neo4j driver."""
|
|
181
|
+
if self._driver is None:
|
|
182
|
+
try:
|
|
183
|
+
from neo4j import GraphDatabase
|
|
184
|
+
|
|
185
|
+
self._driver = GraphDatabase.driver(
|
|
186
|
+
self.uri,
|
|
187
|
+
auth=(self.username, self.password),
|
|
188
|
+
)
|
|
189
|
+
except ImportError as err:
|
|
190
|
+
raise ImportError(
|
|
191
|
+
"neo4j package required for Neo4j graph store. "
|
|
192
|
+
"Install with: pip install neo4j"
|
|
193
|
+
) from err
|
|
194
|
+
return self._driver
|
|
195
|
+
|
|
196
|
+
def _run_query(self, query: str, parameters: Dict = None) -> List[Dict]:
|
|
197
|
+
"""Execute a Cypher query."""
|
|
198
|
+
driver = self._get_driver()
|
|
199
|
+
with driver.session(database=self.database) as session:
|
|
200
|
+
result = session.run(query, parameters or {})
|
|
201
|
+
return [dict(record) for record in result]
|
|
202
|
+
|
|
203
|
+
def add_entity(self, entity: Entity) -> str:
|
|
204
|
+
"""Add or update an entity in Neo4j."""
|
|
205
|
+
query = """
|
|
206
|
+
MERGE (e:Entity {id: $id})
|
|
207
|
+
SET e.name = $name,
|
|
208
|
+
e.entity_type = $entity_type,
|
|
209
|
+
e.properties = $properties,
|
|
210
|
+
e.created_at = $created_at
|
|
211
|
+
RETURN e.id as id
|
|
212
|
+
"""
|
|
213
|
+
result = self._run_query(
|
|
214
|
+
query,
|
|
215
|
+
{
|
|
216
|
+
"id": entity.id,
|
|
217
|
+
"name": entity.name,
|
|
218
|
+
"entity_type": entity.entity_type,
|
|
219
|
+
"properties": json.dumps(entity.properties),
|
|
220
|
+
"created_at": entity.created_at.isoformat(),
|
|
221
|
+
},
|
|
222
|
+
)
|
|
223
|
+
return result[0]["id"] if result else entity.id
|
|
224
|
+
|
|
225
|
+
def add_relationship(self, relationship: Relationship) -> str:
|
|
226
|
+
"""Add or update a relationship in Neo4j."""
|
|
227
|
+
# Sanitize relationship type for Cypher
|
|
228
|
+
rel_type = (
|
|
229
|
+
relationship.relation_type.replace("-", "_").replace(" ", "_").upper()
|
|
230
|
+
)
|
|
231
|
+
query = f"""
|
|
232
|
+
MATCH (source:Entity {{id: $source_id}})
|
|
233
|
+
MATCH (target:Entity {{id: $target_id}})
|
|
234
|
+
MERGE (source)-[r:{rel_type}]->(target)
|
|
235
|
+
SET r.id = $id,
|
|
236
|
+
r.properties = $properties,
|
|
237
|
+
r.confidence = $confidence,
|
|
238
|
+
r.created_at = $created_at
|
|
239
|
+
RETURN r.id as id
|
|
240
|
+
"""
|
|
241
|
+
result = self._run_query(
|
|
242
|
+
query,
|
|
243
|
+
{
|
|
244
|
+
"id": relationship.id,
|
|
245
|
+
"source_id": relationship.source_id,
|
|
246
|
+
"target_id": relationship.target_id,
|
|
247
|
+
"properties": json.dumps(relationship.properties),
|
|
248
|
+
"confidence": relationship.confidence,
|
|
249
|
+
"created_at": relationship.created_at.isoformat(),
|
|
250
|
+
},
|
|
251
|
+
)
|
|
252
|
+
return result[0]["id"] if result else relationship.id
|
|
253
|
+
|
|
254
|
+
def get_entity(self, entity_id: str) -> Optional[Entity]:
|
|
255
|
+
"""Get an entity by ID."""
|
|
256
|
+
query = """
|
|
257
|
+
MATCH (e:Entity {id: $id})
|
|
258
|
+
RETURN e.id as id, e.name as name, e.entity_type as entity_type,
|
|
259
|
+
e.properties as properties, e.created_at as created_at
|
|
260
|
+
"""
|
|
261
|
+
result = self._run_query(query, {"id": entity_id})
|
|
262
|
+
if not result:
|
|
263
|
+
return None
|
|
264
|
+
r = result[0]
|
|
265
|
+
return Entity(
|
|
266
|
+
id=r["id"],
|
|
267
|
+
name=r["name"],
|
|
268
|
+
entity_type=r["entity_type"],
|
|
269
|
+
properties=json.loads(r["properties"]) if r["properties"] else {},
|
|
270
|
+
created_at=(
|
|
271
|
+
datetime.fromisoformat(r["created_at"])
|
|
272
|
+
if r["created_at"]
|
|
273
|
+
else datetime.now(timezone.utc)
|
|
274
|
+
),
|
|
275
|
+
)
|
|
276
|
+
|
|
277
|
+
def find_entities(
|
|
278
|
+
self,
|
|
279
|
+
name: Optional[str] = None,
|
|
280
|
+
entity_type: Optional[str] = None,
|
|
281
|
+
limit: int = 10,
|
|
282
|
+
) -> List[Entity]:
|
|
283
|
+
"""Find entities by name or type."""
|
|
284
|
+
conditions = []
|
|
285
|
+
params: Dict[str, Any] = {"limit": limit}
|
|
286
|
+
|
|
287
|
+
if name:
|
|
288
|
+
conditions.append("e.name CONTAINS $name")
|
|
289
|
+
params["name"] = name
|
|
290
|
+
if entity_type:
|
|
291
|
+
conditions.append("e.entity_type = $entity_type")
|
|
292
|
+
params["entity_type"] = entity_type
|
|
293
|
+
|
|
294
|
+
where_clause = f"WHERE {' AND '.join(conditions)}" if conditions else ""
|
|
295
|
+
|
|
296
|
+
query = f"""
|
|
297
|
+
MATCH (e:Entity)
|
|
298
|
+
{where_clause}
|
|
299
|
+
RETURN e.id as id, e.name as name, e.entity_type as entity_type,
|
|
300
|
+
e.properties as properties, e.created_at as created_at
|
|
301
|
+
LIMIT $limit
|
|
302
|
+
"""
|
|
303
|
+
|
|
304
|
+
results = self._run_query(query, params)
|
|
305
|
+
return [
|
|
306
|
+
Entity(
|
|
307
|
+
id=r["id"],
|
|
308
|
+
name=r["name"],
|
|
309
|
+
entity_type=r["entity_type"],
|
|
310
|
+
properties=json.loads(r["properties"]) if r["properties"] else {},
|
|
311
|
+
)
|
|
312
|
+
for r in results
|
|
313
|
+
]
|
|
314
|
+
|
|
315
|
+
def get_relationships(
|
|
316
|
+
self,
|
|
317
|
+
entity_id: str,
|
|
318
|
+
direction: str = "both",
|
|
319
|
+
relation_type: Optional[str] = None,
|
|
320
|
+
) -> List[Relationship]:
|
|
321
|
+
"""Get relationships for an entity."""
|
|
322
|
+
if direction == "outgoing":
|
|
323
|
+
pattern = "(e)-[r]->(other)"
|
|
324
|
+
elif direction == "incoming":
|
|
325
|
+
pattern = "(e)<-[r]-(other)"
|
|
326
|
+
else:
|
|
327
|
+
pattern = "(e)-[r]-(other)"
|
|
328
|
+
|
|
329
|
+
type_filter = f":{relation_type}" if relation_type else ""
|
|
330
|
+
|
|
331
|
+
query = f"""
|
|
332
|
+
MATCH (e:Entity {{id: $entity_id}}){pattern.replace("[r]", f"[r{type_filter}]")}
|
|
333
|
+
RETURN r.id as id, e.id as source_id, other.id as target_id,
|
|
334
|
+
type(r) as relation_type, r.properties as properties,
|
|
335
|
+
r.confidence as confidence, r.created_at as created_at
|
|
336
|
+
"""
|
|
337
|
+
|
|
338
|
+
results = self._run_query(query, {"entity_id": entity_id})
|
|
339
|
+
return [
|
|
340
|
+
Relationship(
|
|
341
|
+
id=r["id"] or f"{r['source_id']}-{r['relation_type']}-{r['target_id']}",
|
|
342
|
+
source_id=r["source_id"],
|
|
343
|
+
target_id=r["target_id"],
|
|
344
|
+
relation_type=r["relation_type"],
|
|
345
|
+
properties=json.loads(r["properties"]) if r["properties"] else {},
|
|
346
|
+
confidence=r["confidence"] or 1.0,
|
|
347
|
+
)
|
|
348
|
+
for r in results
|
|
349
|
+
]
|
|
350
|
+
|
|
351
|
+
def traverse(
|
|
352
|
+
self,
|
|
353
|
+
start_entity_id: str,
|
|
354
|
+
max_hops: int = 2,
|
|
355
|
+
relation_types: Optional[List[str]] = None,
|
|
356
|
+
) -> GraphResult:
|
|
357
|
+
"""Traverse the graph from a starting entity."""
|
|
358
|
+
start_time = time.time()
|
|
359
|
+
|
|
360
|
+
type_filter = ""
|
|
361
|
+
if relation_types:
|
|
362
|
+
type_filter = ":" + "|".join(relation_types)
|
|
363
|
+
|
|
364
|
+
query = f"""
|
|
365
|
+
MATCH path = (start:Entity {{id: $start_id}})-[r{type_filter}*1..{max_hops}]-(end:Entity)
|
|
366
|
+
RETURN nodes(path) as nodes, relationships(path) as rels
|
|
367
|
+
LIMIT 100
|
|
368
|
+
"""
|
|
369
|
+
|
|
370
|
+
results = self._run_query(query, {"start_id": start_entity_id})
|
|
371
|
+
|
|
372
|
+
entities = {}
|
|
373
|
+
relationships = {}
|
|
374
|
+
paths = []
|
|
375
|
+
|
|
376
|
+
for r in results:
|
|
377
|
+
path_ids = []
|
|
378
|
+
for node in r["nodes"]:
|
|
379
|
+
if node["id"] not in entities:
|
|
380
|
+
entities[node["id"]] = Entity(
|
|
381
|
+
id=node["id"],
|
|
382
|
+
name=node.get("name", ""),
|
|
383
|
+
entity_type=node.get("entity_type", "unknown"),
|
|
384
|
+
)
|
|
385
|
+
path_ids.append(node["id"])
|
|
386
|
+
paths.append(path_ids)
|
|
387
|
+
|
|
388
|
+
for rel in r["rels"]:
|
|
389
|
+
rel_id = rel.get("id", f"{rel['source_id']}-{rel['target_id']}")
|
|
390
|
+
if rel_id not in relationships:
|
|
391
|
+
relationships[rel_id] = Relationship(
|
|
392
|
+
id=rel_id,
|
|
393
|
+
source_id=rel.get("source_id", ""),
|
|
394
|
+
target_id=rel.get("target_id", ""),
|
|
395
|
+
relation_type=rel.get("type", "RELATED"),
|
|
396
|
+
)
|
|
397
|
+
|
|
398
|
+
query_time_ms = int((time.time() - start_time) * 1000)
|
|
399
|
+
|
|
400
|
+
return GraphResult(
|
|
401
|
+
entities=list(entities.values()),
|
|
402
|
+
relationships=list(relationships.values()),
|
|
403
|
+
paths=paths,
|
|
404
|
+
query_time_ms=query_time_ms,
|
|
405
|
+
)
|
|
406
|
+
|
|
407
|
+
def query(self, query: GraphQuery) -> GraphResult:
|
|
408
|
+
"""Execute a graph query."""
|
|
409
|
+
# Find starting entities
|
|
410
|
+
all_entities = {}
|
|
411
|
+
all_relationships = {}
|
|
412
|
+
all_paths = []
|
|
413
|
+
|
|
414
|
+
for entity_name in query.entities:
|
|
415
|
+
entities = self.find_entities(name=entity_name, limit=5)
|
|
416
|
+
for entity in entities:
|
|
417
|
+
all_entities[entity.id] = entity
|
|
418
|
+
result = self.traverse(
|
|
419
|
+
entity.id,
|
|
420
|
+
max_hops=query.max_hops,
|
|
421
|
+
relation_types=query.relation_types,
|
|
422
|
+
)
|
|
423
|
+
for e in result.entities:
|
|
424
|
+
all_entities[e.id] = e
|
|
425
|
+
for r in result.relationships:
|
|
426
|
+
all_relationships[r.id] = r
|
|
427
|
+
all_paths.extend(result.paths)
|
|
428
|
+
|
|
429
|
+
return GraphResult(
|
|
430
|
+
entities=list(all_entities.values())[: query.limit],
|
|
431
|
+
relationships=list(all_relationships.values()),
|
|
432
|
+
paths=all_paths[: query.limit],
|
|
433
|
+
query_time_ms=0,
|
|
434
|
+
)
|
|
435
|
+
|
|
436
|
+
def delete_entity(self, entity_id: str) -> bool:
|
|
437
|
+
"""Delete an entity and its relationships."""
|
|
438
|
+
query = """
|
|
439
|
+
MATCH (e:Entity {id: $id})
|
|
440
|
+
DETACH DELETE e
|
|
441
|
+
RETURN count(e) as deleted
|
|
442
|
+
"""
|
|
443
|
+
result = self._run_query(query, {"id": entity_id})
|
|
444
|
+
return result[0]["deleted"] > 0 if result else False
|
|
445
|
+
|
|
446
|
+
def close(self):
|
|
447
|
+
"""Close the Neo4j driver connection."""
|
|
448
|
+
if self._driver:
|
|
449
|
+
self._driver.close()
|
|
450
|
+
self._driver = None
|
|
451
|
+
|
|
452
|
+
|
|
453
|
+
class InMemoryGraphStore(GraphStore):
|
|
454
|
+
"""
|
|
455
|
+
In-memory graph storage for testing and development.
|
|
456
|
+
|
|
457
|
+
No external dependencies required.
|
|
458
|
+
"""
|
|
459
|
+
|
|
460
|
+
def __init__(self):
|
|
461
|
+
self._entities: Dict[str, Entity] = {}
|
|
462
|
+
self._relationships: Dict[str, Relationship] = {}
|
|
463
|
+
self._outgoing: Dict[str, List[str]] = {} # entity_id -> [rel_ids]
|
|
464
|
+
self._incoming: Dict[str, List[str]] = {} # entity_id -> [rel_ids]
|
|
465
|
+
|
|
466
|
+
def add_entity(self, entity: Entity) -> str:
|
|
467
+        self._entities[entity.id] = entity
+        if entity.id not in self._outgoing:
+            self._outgoing[entity.id] = []
+        if entity.id not in self._incoming:
+            self._incoming[entity.id] = []
+        return entity.id
+
+    def add_relationship(self, relationship: Relationship) -> str:
+        self._relationships[relationship.id] = relationship
+
+        if relationship.source_id not in self._outgoing:
+            self._outgoing[relationship.source_id] = []
+        if relationship.id not in self._outgoing[relationship.source_id]:
+            self._outgoing[relationship.source_id].append(relationship.id)
+
+        if relationship.target_id not in self._incoming:
+            self._incoming[relationship.target_id] = []
+        if relationship.id not in self._incoming[relationship.target_id]:
+            self._incoming[relationship.target_id].append(relationship.id)
+
+        return relationship.id
+
+    def get_entity(self, entity_id: str) -> Optional[Entity]:
+        return self._entities.get(entity_id)
+
+    def find_entities(
+        self,
+        name: Optional[str] = None,
+        entity_type: Optional[str] = None,
+        limit: int = 10,
+    ) -> List[Entity]:
+        results = []
+        for entity in self._entities.values():
+            if name and name.lower() not in entity.name.lower():
+                continue
+            if entity_type and entity.entity_type != entity_type:
+                continue
+            results.append(entity)
+            if len(results) >= limit:
+                break
+        return results
+
+    def get_relationships(
+        self,
+        entity_id: str,
+        direction: str = "both",
+        relation_type: Optional[str] = None,
+    ) -> List[Relationship]:
+        rel_ids = set()
+
+        if direction in ("outgoing", "both"):
+            rel_ids.update(self._outgoing.get(entity_id, []))
+        if direction in ("incoming", "both"):
+            rel_ids.update(self._incoming.get(entity_id, []))
+
+        results = []
+        for rel_id in rel_ids:
+            rel = self._relationships.get(rel_id)
+            if rel:
+                if relation_type and rel.relation_type != relation_type:
+                    continue
+                results.append(rel)
+        return results
+
+    def traverse(
+        self,
+        start_entity_id: str,
+        max_hops: int = 2,
+        relation_types: Optional[List[str]] = None,
+    ) -> GraphResult:
+        start_time = time.time()
+
+        visited_entities = {start_entity_id}
+        visited_relationships = set()
+        paths = []
+
+        def _traverse(current_id: str, depth: int, current_path: List[str]):
+            if depth > max_hops:
+                return
+
+            for rel in self.get_relationships(current_id, "both"):
+                if relation_types and rel.relation_type not in relation_types:
+                    continue
+
+                visited_relationships.add(rel.id)
+
+                next_id = (
+                    rel.target_id if rel.source_id == current_id else rel.source_id
+                )
+
+                if next_id not in visited_entities:
+                    visited_entities.add(next_id)
+                    new_path = current_path + [next_id]
+                    paths.append(new_path)
+                    _traverse(next_id, depth + 1, new_path)
+
+        _traverse(start_entity_id, 0, [start_entity_id])
+
+        entities = [
+            self._entities[eid] for eid in visited_entities if eid in self._entities
+        ]
+        relationships = [
+            self._relationships[rid]
+            for rid in visited_relationships
+            if rid in self._relationships
+        ]
+
+        query_time_ms = int((time.time() - start_time) * 1000)
+
+        return GraphResult(
+            entities=entities,
+            relationships=relationships,
+            paths=paths,
+            query_time_ms=query_time_ms,
+        )
+
+    def query(self, query: GraphQuery) -> GraphResult:
+        all_entities = {}
+        all_relationships = {}
+        all_paths = []
+
+        for entity_name in query.entities:
+            entities = self.find_entities(name=entity_name, limit=5)
+            for entity in entities:
+                all_entities[entity.id] = entity
+                result = self.traverse(
+                    entity.id,
+                    max_hops=query.max_hops,
+                    relation_types=query.relation_types,
+                )
+                for e in result.entities:
+                    all_entities[e.id] = e
+                for r in result.relationships:
+                    all_relationships[r.id] = r
+                all_paths.extend(result.paths)
+
+        return GraphResult(
+            entities=list(all_entities.values())[: query.limit],
+            relationships=list(all_relationships.values()),
+            paths=all_paths[: query.limit],
+            query_time_ms=0,
+        )
+
+    def delete_entity(self, entity_id: str) -> bool:
+        if entity_id not in self._entities:
+            return False
+
+        # Delete outgoing relationships
+        for rel_id in list(self._outgoing.get(entity_id, [])):
+            if rel_id in self._relationships:
+                rel = self._relationships[rel_id]
+                # Remove from target's incoming
+                if rel.target_id in self._incoming:
+                    if rel_id in self._incoming[rel.target_id]:
+                        self._incoming[rel.target_id].remove(rel_id)
+                del self._relationships[rel_id]
+
+        # Delete incoming relationships
+        for rel_id in list(self._incoming.get(entity_id, [])):
+            if rel_id in self._relationships:
+                rel = self._relationships[rel_id]
+                # Remove from source's outgoing
+                if rel.source_id in self._outgoing:
+                    if rel_id in self._outgoing[rel.source_id]:
+                        self._outgoing[rel.source_id].remove(rel_id)
+                del self._relationships[rel_id]
+
+        # Delete entity
+        del self._entities[entity_id]
+        self._outgoing.pop(entity_id, None)
+        self._incoming.pop(entity_id, None)
+
+        return True
+
+    def clear(self) -> None:
+        """Clear all stored data."""
+        self._entities.clear()
+        self._relationships.clear()
+        self._outgoing.clear()
+        self._incoming.clear()
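The in-memory store above keeps two adjacency indexes (`_outgoing` and `_incoming`), so `traverse` is a depth-first walk over those lists bounded by `max_hops`. Below is a minimal usage sketch, not taken from the package: it assumes `Entity` and `Relationship` can be constructed with the fields the store reads (`id`, `name`, `entity_type`; `id`, `source_id`, `target_id`, `relation_type`) and that both are exported from `alma.graph`.

```python
# Hypothetical usage sketch -- constructor fields for Entity/Relationship are
# assumed from the attributes the store reads; verify against alma.graph.
from alma.graph import Entity, InMemoryGraphStore, Relationship

store = InMemoryGraphStore()
store.add_entity(Entity(id="e1", name="ALMA", entity_type="project"))
store.add_entity(Entity(id="e2", name="Neo4j", entity_type="database"))
store.add_relationship(
    Relationship(id="r1", source_id="e1", target_id="e2", relation_type="uses")
)

# One hop reaches e2; paths are chains of entity ids starting at the seed.
result = store.traverse("e1", max_hops=1)
print(sorted(e.id for e in result.entities))   # ['e1', 'e2']
print(result.paths)                            # [['e1', 'e2']]
print(store.delete_entity("e1"))               # True, and r1 is removed as well
```

Note that `find_entities` is a linear scan with case-insensitive substring matching on `name`, which is fine for tests and small graphs but not indexed for large stores.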
+class BackendGraphStore(GraphStore):
+    """
+    GraphStore implementation that delegates to a GraphBackend.
+
+    This class bridges the high-level GraphStore API with the pluggable
+    GraphBackend interface, enabling use of different database backends
+    while maintaining the full GraphStore functionality.
+
+    Example:
+        from alma.graph import create_graph_backend, BackendGraphStore
+
+        backend = create_graph_backend("neo4j", uri="bolt://localhost:7687", ...)
+        store = BackendGraphStore(backend)
+        store.add_entity(entity)
+        result = store.traverse(entity.id)
+    """
+
+    def __init__(self, backend: "GraphBackend"):
+        """
+        Initialize with a GraphBackend.
+
+        Args:
+            backend: The GraphBackend implementation to use for storage.
+        """
+        self._backend = backend
+
+    @property
+    def backend(self) -> "GraphBackend":
+        """Access the underlying backend."""
+        return self._backend
+
+    def add_entity(self, entity: Entity) -> str:
+        return self._backend.add_entity(entity)
+
+    def add_relationship(self, relationship: Relationship) -> str:
+        return self._backend.add_relationship(relationship)
+
+    def get_entity(self, entity_id: str) -> Optional[Entity]:
+        return self._backend.get_entity(entity_id)
+
+    def find_entities(
+        self,
+        name: Optional[str] = None,
+        entity_type: Optional[str] = None,
+        limit: int = 10,
+    ) -> List[Entity]:
+        """Find entities by name or type using the backend."""
+        if hasattr(self._backend, "find_entities"):
+            return self._backend.find_entities(
+                name=name, entity_type=entity_type, limit=limit
+            )
+        # Fallback to search_entities if find_entities not available
+        if name:
+            return self._backend.search_entities(query=name, top_k=limit)
+        return self._backend.get_entities(entity_type=entity_type, limit=limit)
+
+    def get_relationships(
+        self,
+        entity_id: str,
+        direction: str = "both",
+        relation_type: Optional[str] = None,
+    ) -> List[Relationship]:
+        """Get relationships for an entity."""
+        if hasattr(self._backend, "get_relationships_directional"):
+            return self._backend.get_relationships_directional(
+                entity_id=entity_id, direction=direction, relation_type=relation_type
+            )
+        # Fallback - get all relationships and filter
+        all_rels = self._backend.get_relationships(entity_id)
+        if relation_type:
+            all_rels = [r for r in all_rels if r.relation_type == relation_type]
+        return all_rels
+
+    def traverse(
+        self,
+        start_entity_id: str,
+        max_hops: int = 2,
+        relation_types: Optional[List[str]] = None,
+    ) -> GraphResult:
+        """Traverse the graph from a starting entity."""
+        start_time = time.time()
+
+        visited_entities = {start_entity_id}
+        visited_relationships: Dict[str, Relationship] = {}
+        paths: List[List[str]] = []
+
+        def _traverse(current_id: str, depth: int, current_path: List[str]):
+            if depth > max_hops:
+                return
+
+            rels = self._backend.get_relationships(current_id)
+            for rel in rels:
+                if relation_types and rel.relation_type not in relation_types:
+                    continue
+
+                visited_relationships[rel.id] = rel
+
+                next_id = (
+                    rel.target_id if rel.source_id == current_id else rel.source_id
+                )
+
+                if next_id not in visited_entities:
+                    visited_entities.add(next_id)
+                    new_path = current_path + [next_id]
+                    paths.append(new_path)
+                    _traverse(next_id, depth + 1, new_path)
+
+        _traverse(start_entity_id, 0, [start_entity_id])
+
+        # Fetch all visited entities
+        entities = []
+        for eid in visited_entities:
+            entity = self._backend.get_entity(eid)
+            if entity:
+                entities.append(entity)
+
+        query_time_ms = int((time.time() - start_time) * 1000)
+
+        return GraphResult(
+            entities=entities,
+            relationships=list(visited_relationships.values()),
+            paths=paths,
+            query_time_ms=query_time_ms,
+        )
+
+    def query(self, query: GraphQuery) -> GraphResult:
+        """Execute a graph query."""
+        all_entities: Dict[str, Entity] = {}
+        all_relationships: Dict[str, Relationship] = {}
+        all_paths: List[List[str]] = []
+
+        for entity_name in query.entities:
+            entities = self.find_entities(name=entity_name, limit=5)
+            for entity in entities:
+                all_entities[entity.id] = entity
+                result = self.traverse(
+                    entity.id,
+                    max_hops=query.max_hops,
+                    relation_types=query.relation_types,
+                )
+                for e in result.entities:
+                    all_entities[e.id] = e
+                for r in result.relationships:
+                    all_relationships[r.id] = r
+                all_paths.extend(result.paths)
+
+        return GraphResult(
+            entities=list(all_entities.values())[: query.limit],
+            relationships=list(all_relationships.values()),
+            paths=all_paths[: query.limit],
+            query_time_ms=0,
+        )
+
+    def delete_entity(self, entity_id: str) -> bool:
+        return self._backend.delete_entity(entity_id)
+
+    def close(self) -> None:
+        """Close the backend connection."""
+        self._backend.close()
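`BackendGraphStore` keeps traversal and query logic in Python and delegates only primitive reads and writes, probing for optional fast paths (`find_entities`, `get_relationships_directional`) with `hasattr` and otherwise falling back to the generic backend calls. The sketch below shows the minimal surface a duck-typed backend needs for `add_*`, `get_entity`, `traverse`, and `delete_entity` to work; `ToyBackend` is hypothetical and not part of the package.

```python
# Hypothetical minimal backend -- implements only the calls BackendGraphStore
# makes on its fallback paths; the real protocol is GraphBackend in alma.graph.
from typing import Dict, List

from alma.graph import BackendGraphStore


class ToyBackend:
    def __init__(self) -> None:
        self._entities: Dict[str, object] = {}
        self._rels: Dict[str, object] = {}

    def add_entity(self, entity) -> str:
        self._entities[entity.id] = entity
        return entity.id

    def add_relationship(self, rel) -> str:
        self._rels[rel.id] = rel
        return rel.id

    def get_entity(self, entity_id: str):
        return self._entities.get(entity_id)

    def get_relationships(self, entity_id: str) -> List:
        # Undirected view: anything touching the entity, as traverse() expects.
        return [
            r for r in self._rels.values()
            if entity_id in (r.source_id, r.target_id)
        ]

    def delete_entity(self, entity_id: str) -> bool:
        return self._entities.pop(entity_id, None) is not None

    def close(self) -> None:
        pass


# ToyBackend has no get_relationships_directional, so the hasattr() check in
# BackendGraphStore.get_relationships() routes through the generic fallback.
store = BackendGraphStore(ToyBackend())
```

Backends that do implement the optional methods get the faster code paths automatically, without any change to calling code.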
+def create_graph_store(
+    provider: str = "memory",
+    **kwargs,
+) -> GraphStore:
+    """
+    Factory function to create a graph store.
+
+    Args:
+        provider: "neo4j", "memory", or "backend"
+        **kwargs: Provider-specific arguments
+            For "neo4j": uri, username, password, database
+            For "backend": backend (GraphBackend instance)
+
+    Returns:
+        Configured GraphStore instance
+
+    Note:
+        Amazon Neptune support is planned for a future release.
+
+    Example:
+        # Create in-memory store
+        store = create_graph_store("memory")
+
+        # Create Neo4j store
+        store = create_graph_store(
+            "neo4j",
+            uri="bolt://localhost:7687",
+            username="neo4j",
+            password="password"
+        )
+
+        # Create store using a backend
+        from alma.graph import create_graph_backend
+        backend = create_graph_backend("neo4j", uri="...", ...)
+        store = create_graph_store("backend", backend=backend)
+    """
+    if provider == "neo4j":
+        return Neo4jGraphStore(**kwargs)
+    elif provider == "neptune":
+        # Neptune support is planned for a future release
+        raise NotImplementedError(
+            "Neptune support is not yet implemented. "
+            "Use 'neo4j' or 'memory' providers instead."
+        )
+    elif provider == "backend":
+        backend = kwargs.get("backend")
+        if backend is None:
+            raise ValueError("'backend' argument required for 'backend' provider")
+        return BackendGraphStore(backend)
+    else:
+        return InMemoryGraphStore()
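The factory is a plain provider switch: `"neptune"` fails fast, `"backend"` requires an explicit `backend=` instance, and any other value (including the default `"memory"`) falls through to `InMemoryGraphStore`. A short sketch of those branches, assuming `create_graph_store` is re-exported from `alma.graph` (otherwise import it from `alma.graph.store`):

```python
# Hypothetical import path -- adjust if the factory is not re-exported
# from the alma.graph package namespace.
from alma.graph import create_graph_store

store = create_graph_store()            # default "memory" -> InMemoryGraphStore
print(type(store).__name__)             # InMemoryGraphStore

try:
    create_graph_store("neptune")       # explicitly unsupported for now
except NotImplementedError as exc:
    print(exc)

try:
    create_graph_store("backend")       # "backend" without backend=... is an error
except ValueError as exc:
    print(exc)
```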