alma-memory 0.4.0__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (77)
  1. alma/__init__.py +88 -44
  2. alma/confidence/__init__.py +1 -1
  3. alma/confidence/engine.py +92 -58
  4. alma/confidence/types.py +34 -14
  5. alma/config/loader.py +3 -2
  6. alma/consolidation/__init__.py +23 -0
  7. alma/consolidation/engine.py +678 -0
  8. alma/consolidation/prompts.py +84 -0
  9. alma/core.py +15 -15
  10. alma/domains/__init__.py +6 -6
  11. alma/domains/factory.py +12 -9
  12. alma/domains/schemas.py +17 -3
  13. alma/domains/types.py +8 -4
  14. alma/events/__init__.py +75 -0
  15. alma/events/emitter.py +284 -0
  16. alma/events/storage_mixin.py +246 -0
  17. alma/events/types.py +126 -0
  18. alma/events/webhook.py +425 -0
  19. alma/exceptions.py +49 -0
  20. alma/extraction/__init__.py +31 -0
  21. alma/extraction/auto_learner.py +264 -0
  22. alma/extraction/extractor.py +420 -0
  23. alma/graph/__init__.py +81 -0
  24. alma/graph/backends/__init__.py +18 -0
  25. alma/graph/backends/memory.py +236 -0
  26. alma/graph/backends/neo4j.py +417 -0
  27. alma/graph/base.py +159 -0
  28. alma/graph/extraction.py +198 -0
  29. alma/graph/store.py +860 -0
  30. alma/harness/__init__.py +4 -4
  31. alma/harness/base.py +18 -9
  32. alma/harness/domains.py +27 -11
  33. alma/initializer/__init__.py +1 -1
  34. alma/initializer/initializer.py +51 -43
  35. alma/initializer/types.py +25 -17
  36. alma/integration/__init__.py +9 -9
  37. alma/integration/claude_agents.py +10 -10
  38. alma/integration/helena.py +32 -22
  39. alma/integration/victor.py +57 -33
  40. alma/learning/__init__.py +27 -27
  41. alma/learning/forgetting.py +198 -148
  42. alma/learning/heuristic_extractor.py +40 -24
  43. alma/learning/protocols.py +62 -14
  44. alma/learning/validation.py +7 -2
  45. alma/mcp/__init__.py +4 -4
  46. alma/mcp/__main__.py +2 -1
  47. alma/mcp/resources.py +17 -16
  48. alma/mcp/server.py +102 -44
  49. alma/mcp/tools.py +174 -37
  50. alma/progress/__init__.py +3 -3
  51. alma/progress/tracker.py +26 -20
  52. alma/progress/types.py +8 -12
  53. alma/py.typed +0 -0
  54. alma/retrieval/__init__.py +11 -11
  55. alma/retrieval/cache.py +20 -21
  56. alma/retrieval/embeddings.py +4 -4
  57. alma/retrieval/engine.py +114 -35
  58. alma/retrieval/scoring.py +73 -63
  59. alma/session/__init__.py +2 -2
  60. alma/session/manager.py +5 -5
  61. alma/session/types.py +5 -4
  62. alma/storage/__init__.py +41 -0
  63. alma/storage/azure_cosmos.py +101 -31
  64. alma/storage/base.py +157 -4
  65. alma/storage/chroma.py +1443 -0
  66. alma/storage/file_based.py +56 -20
  67. alma/storage/pinecone.py +1080 -0
  68. alma/storage/postgresql.py +1452 -0
  69. alma/storage/qdrant.py +1306 -0
  70. alma/storage/sqlite_local.py +376 -31
  71. alma/types.py +62 -14
  72. alma_memory-0.5.0.dist-info/METADATA +905 -0
  73. alma_memory-0.5.0.dist-info/RECORD +76 -0
  74. {alma_memory-0.4.0.dist-info → alma_memory-0.5.0.dist-info}/WHEEL +1 -1
  75. alma_memory-0.4.0.dist-info/METADATA +0 -488
  76. alma_memory-0.4.0.dist-info/RECORD +0 -52
  77. {alma_memory-0.4.0.dist-info → alma_memory-0.5.0.dist-info}/top_level.txt +0 -0
alma/graph/store.py ADDED
@@ -0,0 +1,860 @@
1
+ """
2
+ ALMA Graph Memory Module.
3
+
4
+ Graph-based memory storage for capturing relationships between entities.
5
+ Supports Neo4j and in-memory graph for testing.
6
+
7
+ This module provides two APIs:
8
+ 1. GraphStore (high-level) - Full-featured graph store with traversal and query support
9
+ 2. GraphBackend (low-level) - Simple CRUD operations for pluggable backends
10
+
11
+ The GraphStore classes can optionally use GraphBackend implementations for storage,
12
+ enabling easy swapping of database backends while keeping the high-level API.
13
+ """
14
+
15
+ import json
16
+ import logging
17
+ import time
18
+ from abc import ABC, abstractmethod
19
+ from dataclasses import dataclass, field
20
+ from datetime import datetime, timezone
21
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional
22
+
23
+ if TYPE_CHECKING:
24
+ from alma.graph.base import GraphBackend
25
+
26
+ logger = logging.getLogger(__name__)
27
+
28
+
29
@dataclass
class Entity:
    """
    A node in the knowledge graph.

    Attributes:
        id: Identifier of the entity.
        name: Name of the entity.
        entity_type: Category label (person, organization, concept, tool, etc.).
        properties: Free-form key/value data; a new empty dict per instance
            when not supplied.
        created_at: Creation timestamp; defaults to the current time in UTC.
    """

    id: str
    name: str
    # person, organization, concept, tool, etc.
    entity_type: str
    properties: Dict[str, Any] = field(default_factory=dict)
    created_at: datetime = field(
        default_factory=lambda: datetime.now(tz=timezone.utc)
    )
38
+
39
+
40
@dataclass
class Relationship:
    """
    An edge in the knowledge graph.

    Attributes:
        id: Identifier of the relationship.
        source_id: ID of the entity the edge starts from.
        target_id: ID of the entity the edge points to.
        relation_type: Edge label (WORKS_AT, USES, KNOWS, CREATED_BY, etc.).
        properties: Free-form key/value data; a new empty dict per instance
            when not supplied.
        confidence: Confidence value for the edge; defaults to 1.0.
        created_at: Creation timestamp; defaults to the current time in UTC.
    """

    id: str
    source_id: str
    target_id: str
    # WORKS_AT, USES, KNOWS, CREATED_BY, etc.
    relation_type: str
    properties: Dict[str, Any] = field(default_factory=dict)
    confidence: float = 1.0
    created_at: datetime = field(
        default_factory=lambda: datetime.now(tz=timezone.utc)
    )
51
+
52
+
53
@dataclass
class GraphQuery:
    """
    A query against the knowledge graph.

    Attributes:
        entities: Entity names or IDs to search for.
        relation_types: Relation types to filter by; None means no filter.
        max_hops: Maximum traversal depth.
        limit: Maximum results.
    """

    # Entity names or IDs to search for
    entities: List[str]
    # Filter by relation types
    relation_types: Optional[List[str]] = None
    # Maximum traversal depth
    max_hops: int = 2
    # Maximum results
    limit: int = 20
61
+
62
+
63
@dataclass
class GraphResult:
    """
    Result from a graph query.

    Attributes:
        entities: Entities included in the result.
        relationships: Relationships included in the result.
        paths: Paths through the graph, each a list of entity IDs.
        query_time_ms: Time reported for the query, in milliseconds.
    """

    entities: List[Entity]
    relationships: List[Relationship]
    # Paths through the graph
    paths: List[List[str]]
    query_time_ms: int
71
+
72
+
73
class GraphStore(ABC):
    """
    Abstract base class for graph storage backends.

    This is the high-level API for graph operations including traversal
    and complex queries. For simple CRUD operations, see GraphBackend.

    Instances are context managers: leaving a ``with`` block calls
    ``close()`` and never suppresses the exception in flight.
    """

    @abstractmethod
    def add_entity(self, entity: Entity) -> str:
        """Add or update an entity."""

    @abstractmethod
    def add_relationship(self, relationship: Relationship) -> str:
        """Add or update a relationship."""

    @abstractmethod
    def get_entity(self, entity_id: str) -> Optional[Entity]:
        """Get an entity by ID."""

    @abstractmethod
    def find_entities(
        self,
        name: Optional[str] = None,
        entity_type: Optional[str] = None,
        limit: int = 10,
    ) -> List[Entity]:
        """Find entities by name or type."""

    @abstractmethod
    def get_relationships(
        self,
        entity_id: str,
        direction: str = "both",  # "outgoing", "incoming", "both"
        relation_type: Optional[str] = None,
    ) -> List[Relationship]:
        """Get relationships for an entity."""

    @abstractmethod
    def traverse(
        self,
        start_entity_id: str,
        max_hops: int = 2,
        relation_types: Optional[List[str]] = None,
    ) -> GraphResult:
        """Traverse the graph from a starting entity."""

    @abstractmethod
    def query(self, query: GraphQuery) -> GraphResult:
        """Execute a graph query."""

    @abstractmethod
    def delete_entity(self, entity_id: str) -> bool:
        """Delete an entity and its relationships."""

    def close(self) -> None:  # noqa: B027
        """Close the store connection. Override in subclasses if needed."""

    def __enter__(self):
        """Context manager entry; yields the store itself."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context manager exit - ensures close is called."""
        self.close()
        # False: do not swallow an exception raised inside the with-block.
        return False
148
+
149
+
150
class Neo4jGraphStore(GraphStore):
    """
    Neo4j graph storage backend.

    Requires neo4j Python driver: pip install neo4j

    Entities are stored as ``:Entity`` nodes keyed by ``id``. ``properties``
    dicts are stored as JSON strings and ``created_at`` as ISO-8601 strings
    on both nodes and relationships. Relationship types are normalised
    (``-`` and space become ``_``, upper-cased) identically when writing
    and when filtering, so a type written as ``"works at"`` can be queried
    with the same spelling.
    """

    def __init__(
        self,
        uri: str,
        username: str,
        password: str,
        database: str = "neo4j",
    ):
        """
        Initialize Neo4j connection.

        The driver itself is created lazily on first query.

        Args:
            uri: Neo4j connection URI (bolt:// or neo4j+s://)
            username: Database username
            password: Database password
            database: Database name
        """
        self.uri = uri
        self.username = username
        self.password = password
        self.database = database
        self._driver = None

    def _get_driver(self):
        """
        Lazy initialization of Neo4j driver.

        Raises:
            ImportError: If the ``neo4j`` package is not installed.
        """
        if self._driver is None:
            try:
                from neo4j import GraphDatabase

                self._driver = GraphDatabase.driver(
                    self.uri,
                    auth=(self.username, self.password),
                )
            except ImportError as err:
                raise ImportError(
                    "neo4j package required for Neo4j graph store. "
                    "Install with: pip install neo4j"
                ) from err
        return self._driver

    def _run_query(
        self, query: str, parameters: Optional[Dict[str, Any]] = None
    ) -> List[Dict]:
        """Execute a Cypher query and return each record as a dict."""
        driver = self._get_driver()
        with driver.session(database=self.database) as session:
            result = session.run(query, parameters or {})
            # Materialise inside the session; the result cannot be consumed
            # after the session is closed.
            return [dict(record) for record in result]

    @staticmethod
    def _sanitize_relation_type(relation_type: str) -> str:
        """
        Return a backtick-quoted Cypher relationship type.

        Relationship types cannot be passed as query parameters, so they
        are interpolated into the query text. The name is normalised the
        same way for writes and reads, and embedded backticks are doubled
        so the value cannot terminate the quoted identifier.

        Raises:
            ValueError: If ``relation_type`` is empty.
        """
        normalized = relation_type.replace("-", "_").replace(" ", "_").upper()
        if not normalized:
            raise ValueError("relation_type must be a non-empty string")
        return "`" + normalized.replace("`", "``") + "`"

    @staticmethod
    def _parse_created_at(value: Optional[str]) -> datetime:
        """Parse a stored ISO timestamp, falling back to the current UTC time."""
        return datetime.fromisoformat(value) if value else datetime.now(timezone.utc)

    @classmethod
    def _row_to_entity(cls, row: Dict[str, Any]) -> Entity:
        """Build an Entity from a row with id/name/entity_type/properties/created_at."""
        return Entity(
            id=row["id"],
            name=row["name"],
            entity_type=row["entity_type"],
            properties=json.loads(row["properties"]) if row["properties"] else {},
            created_at=cls._parse_created_at(row["created_at"]),
        )

    @classmethod
    def _relationship_from_path(cls, rel: Any) -> Relationship:
        """
        Build a Relationship from a relationship object returned in a path.

        Endpoints and type are taken from the object's ``start_node``,
        ``end_node`` and ``type`` attributes when present (these are not
        stored as relationship properties by ``add_relationship``); a plain
        mapping with ``source_id``/``target_id``/``type`` keys is accepted
        as a fallback.
        """
        start = getattr(rel, "start_node", None)
        end = getattr(rel, "end_node", None)
        source_id = start.get("id", "") if start is not None else rel.get("source_id", "")
        target_id = end.get("id", "") if end is not None else rel.get("target_id", "")
        rel_type = getattr(rel, "type", None) or rel.get("type", "RELATED")
        raw_properties = rel.get("properties")
        return Relationship(
            # Same synthetic-id format get_relationships uses when r.id is unset.
            id=rel.get("id") or f"{source_id}-{rel_type}-{target_id}",
            source_id=source_id,
            target_id=target_id,
            relation_type=rel_type,
            properties=json.loads(raw_properties) if raw_properties else {},
            confidence=rel.get("confidence") or 1.0,
            created_at=cls._parse_created_at(rel.get("created_at")),
        )

    def add_entity(self, entity: Entity) -> str:
        """Add or update an entity in Neo4j and return its ID."""
        query = """
        MERGE (e:Entity {id: $id})
        SET e.name = $name,
            e.entity_type = $entity_type,
            e.properties = $properties,
            e.created_at = $created_at
        RETURN e.id as id
        """
        result = self._run_query(
            query,
            {
                "id": entity.id,
                "name": entity.name,
                "entity_type": entity.entity_type,
                "properties": json.dumps(entity.properties),
                "created_at": entity.created_at.isoformat(),
            },
        )
        return result[0]["id"] if result else entity.id

    def add_relationship(self, relationship: Relationship) -> str:
        """
        Add or update a relationship in Neo4j and return its ID.

        Both endpoint entities must already exist; if either MATCH finds
        nothing no relationship is created and the given ID is returned.

        Raises:
            ValueError: If the relationship type is empty.
        """
        rel_type = self._sanitize_relation_type(relationship.relation_type)
        query = f"""
        MATCH (source:Entity {{id: $source_id}})
        MATCH (target:Entity {{id: $target_id}})
        MERGE (source)-[r:{rel_type}]->(target)
        SET r.id = $id,
            r.properties = $properties,
            r.confidence = $confidence,
            r.created_at = $created_at
        RETURN r.id as id
        """
        result = self._run_query(
            query,
            {
                "id": relationship.id,
                "source_id": relationship.source_id,
                "target_id": relationship.target_id,
                "properties": json.dumps(relationship.properties),
                "confidence": relationship.confidence,
                "created_at": relationship.created_at.isoformat(),
            },
        )
        return result[0]["id"] if result else relationship.id

    def get_entity(self, entity_id: str) -> Optional[Entity]:
        """Get an entity by ID, or None if no such node exists."""
        query = """
        MATCH (e:Entity {id: $id})
        RETURN e.id as id, e.name as name, e.entity_type as entity_type,
               e.properties as properties, e.created_at as created_at
        """
        result = self._run_query(query, {"id": entity_id})
        if not result:
            return None
        return self._row_to_entity(result[0])

    def find_entities(
        self,
        name: Optional[str] = None,
        entity_type: Optional[str] = None,
        limit: int = 10,
    ) -> List[Entity]:
        """
        Find entities by name or type.

        Args:
            name: Substring to look for in the entity name (Cypher CONTAINS).
            entity_type: Exact entity type to match.
            limit: Maximum number of entities to return.
        """
        conditions = []
        params: Dict[str, Any] = {"limit": limit}

        if name:
            conditions.append("e.name CONTAINS $name")
            params["name"] = name
        if entity_type:
            conditions.append("e.entity_type = $entity_type")
            params["entity_type"] = entity_type

        where_clause = f"WHERE {' AND '.join(conditions)}" if conditions else ""

        query = f"""
        MATCH (e:Entity)
        {where_clause}
        RETURN e.id as id, e.name as name, e.entity_type as entity_type,
               e.properties as properties, e.created_at as created_at
        LIMIT $limit
        """

        # Shares the row mapper with get_entity so the stored created_at is kept.
        return [self._row_to_entity(row) for row in self._run_query(query, params)]

    def get_relationships(
        self,
        entity_id: str,
        direction: str = "both",
        relation_type: Optional[str] = None,
    ) -> List[Relationship]:
        """
        Get relationships for an entity.

        Args:
            entity_id: ID of the entity whose edges are wanted.
            direction: "outgoing", "incoming" or "both"; any other value is
                treated as "both".
            relation_type: Optional relationship type to restrict to.

        Returns:
            Relationships whose ``source_id``/``target_id`` reflect the
            stored edge direction regardless of ``direction``.
        """
        rel_pattern = (
            f"[r:{self._sanitize_relation_type(relation_type)}]"
            if relation_type
            else "[r]"
        )
        if direction == "outgoing":
            pattern = f"(e)-{rel_pattern}->(other)"
        elif direction == "incoming":
            pattern = f"(e)<-{rel_pattern}-(other)"
        else:
            pattern = f"(e)-{rel_pattern}-(other)"

        # startNode/endNode give the real edge direction; using e/other would
        # report incoming edges backwards.
        query = f"""
        MATCH (e:Entity {{id: $entity_id}}){pattern[3:]}
        RETURN r.id as id, startNode(r).id as source_id, endNode(r).id as target_id,
               type(r) as relation_type, r.properties as properties,
               r.confidence as confidence, r.created_at as created_at
        """

        results = self._run_query(query, {"entity_id": entity_id})
        return [
            Relationship(
                id=r["id"] or f"{r['source_id']}-{r['relation_type']}-{r['target_id']}",
                source_id=r["source_id"],
                target_id=r["target_id"],
                relation_type=r["relation_type"],
                properties=json.loads(r["properties"]) if r["properties"] else {},
                confidence=r["confidence"] or 1.0,
                created_at=self._parse_created_at(r["created_at"]),
            )
            for r in results
        ]

    def traverse(
        self,
        start_entity_id: str,
        max_hops: int = 2,
        relation_types: Optional[List[str]] = None,
    ) -> GraphResult:
        """
        Traverse the graph from a starting entity.

        Matches undirected paths of 1..max_hops relationships starting at
        the entity; at most 100 paths are fetched.

        Args:
            start_entity_id: ID of the entity to start from.
            max_hops: Maximum number of relationships per path; values
                below 1 yield an empty result without querying.
            relation_types: Optional relationship types to restrict to.

        Returns:
            GraphResult with de-duplicated entities and relationships, one
            ID path per matched path, and the elapsed time in milliseconds.
        """
        start_time = time.time()

        entities: Dict[str, Entity] = {}
        relationships: Dict[str, Relationship] = {}
        paths: List[List[str]] = []

        # The bound is interpolated into the query text, so force an int.
        hops = int(max_hops)
        if hops >= 1:
            type_filter = ""
            if relation_types:
                type_filter = ":" + "|".join(
                    self._sanitize_relation_type(t) for t in relation_types
                )

            query = f"""
            MATCH path = (start:Entity {{id: $start_id}})-[r{type_filter}*1..{hops}]-(end:Entity)
            RETURN nodes(path) as nodes, relationships(path) as rels
            LIMIT 100
            """

            for row in self._run_query(query, {"start_id": start_entity_id}):
                path_ids = []
                for node in row["nodes"]:
                    node_id = node["id"]
                    if node_id not in entities:
                        raw_properties = node.get("properties")
                        entities[node_id] = Entity(
                            id=node_id,
                            name=node.get("name", ""),
                            entity_type=node.get("entity_type", "unknown"),
                            properties=(
                                json.loads(raw_properties) if raw_properties else {}
                            ),
                            created_at=self._parse_created_at(node.get("created_at")),
                        )
                    path_ids.append(node_id)
                paths.append(path_ids)

                for rel in row["rels"]:
                    relationship = self._relationship_from_path(rel)
                    relationships.setdefault(relationship.id, relationship)

        query_time_ms = int((time.time() - start_time) * 1000)

        return GraphResult(
            entities=list(entities.values()),
            relationships=list(relationships.values()),
            paths=paths,
            query_time_ms=query_time_ms,
        )

    def query(self, query: GraphQuery) -> GraphResult:
        """
        Execute a graph query.

        For each name in ``query.entities`` up to 5 matching entities are
        looked up and traversed; results are merged by ID. Entities and
        paths are truncated to ``query.limit``; relationships are not.
        """
        start_time = time.time()

        all_entities: Dict[str, Entity] = {}
        all_relationships: Dict[str, Relationship] = {}
        all_paths: List[List[str]] = []

        for entity_name in query.entities:
            for entity in self.find_entities(name=entity_name, limit=5):
                all_entities[entity.id] = entity
                result = self.traverse(
                    entity.id,
                    max_hops=query.max_hops,
                    relation_types=query.relation_types,
                )
                for e in result.entities:
                    all_entities[e.id] = e
                for r in result.relationships:
                    all_relationships[r.id] = r
                all_paths.extend(result.paths)

        return GraphResult(
            entities=list(all_entities.values())[: query.limit],
            relationships=list(all_relationships.values()),
            paths=all_paths[: query.limit],
            query_time_ms=int((time.time() - start_time) * 1000),
        )

    def delete_entity(self, entity_id: str) -> bool:
        """Delete an entity and its relationships; True if a node was deleted."""
        query = """
        MATCH (e:Entity {id: $id})
        DETACH DELETE e
        RETURN count(e) as deleted
        """
        result = self._run_query(query, {"id": entity_id})
        return result[0]["deleted"] > 0 if result else False

    def close(self):
        """Close the Neo4j driver connection; safe to call more than once."""
        if self._driver:
            self._driver.close()
            self._driver = None
451
+
452
+
453
class InMemoryGraphStore(GraphStore):
    """
    In-memory graph storage for testing and development.

    No external dependencies required.

    Entities and relationships are kept in dicts keyed by ID, with two
    adjacency indexes mapping an entity ID to the IDs of its outgoing and
    incoming relationships.
    """

    def __init__(self):
        self._entities: Dict[str, Entity] = {}
        self._relationships: Dict[str, Relationship] = {}
        self._outgoing: Dict[str, List[str]] = {}  # entity_id -> [rel_ids]
        self._incoming: Dict[str, List[str]] = {}  # entity_id -> [rel_ids]

    def _unindex(self, relationship: Relationship) -> None:
        """Remove a relationship's ID from both adjacency indexes."""
        for index, entity_id in (
            (self._outgoing, relationship.source_id),
            (self._incoming, relationship.target_id),
        ):
            rel_ids = index.get(entity_id)
            if rel_ids and relationship.id in rel_ids:
                rel_ids.remove(relationship.id)

    def add_entity(self, entity: Entity) -> str:
        """Add or replace an entity and return its ID."""
        self._entities[entity.id] = entity
        self._outgoing.setdefault(entity.id, [])
        self._incoming.setdefault(entity.id, [])
        return entity.id

    def add_relationship(self, relationship: Relationship) -> str:
        """
        Add or replace a relationship and return its ID.

        Endpoint entities are not required to exist. Re-adding an ID with
        different endpoints drops the old index entries first so the old
        endpoints no longer report the relationship.
        """
        previous = self._relationships.get(relationship.id)
        if previous is not None and (previous.source_id, previous.target_id) != (
            relationship.source_id,
            relationship.target_id,
        ):
            self._unindex(previous)

        self._relationships[relationship.id] = relationship

        outgoing = self._outgoing.setdefault(relationship.source_id, [])
        if relationship.id not in outgoing:
            outgoing.append(relationship.id)

        incoming = self._incoming.setdefault(relationship.target_id, [])
        if relationship.id not in incoming:
            incoming.append(relationship.id)

        return relationship.id

    def get_entity(self, entity_id: str) -> Optional[Entity]:
        """Get an entity by ID, or None if it is not stored."""
        return self._entities.get(entity_id)

    def find_entities(
        self,
        name: Optional[str] = None,
        entity_type: Optional[str] = None,
        limit: int = 10,
    ) -> List[Entity]:
        """
        Find entities by name or type.

        Args:
            name: Case-insensitive substring of the entity name.
            entity_type: Exact entity type to match.
            limit: Maximum number of entities to return; values <= 0
                return an empty list.
        """
        results: List[Entity] = []
        if limit <= 0:
            return results

        needle = name.lower() if name else None
        for entity in self._entities.values():
            if needle is not None and needle not in entity.name.lower():
                continue
            if entity_type and entity.entity_type != entity_type:
                continue
            results.append(entity)
            if len(results) >= limit:
                break
        return results

    def get_relationships(
        self,
        entity_id: str,
        direction: str = "both",
        relation_type: Optional[str] = None,
    ) -> List[Relationship]:
        """
        Get relationships for an entity.

        Args:
            entity_id: ID of the entity whose edges are wanted.
            direction: "outgoing", "incoming" or "both"; any other value
                yields an empty list.
            relation_type: Optional exact relation type to restrict to.

        Returns:
            Matching relationships in insertion order (outgoing first),
            each at most once.
        """
        rel_ids: List[str] = []
        if direction in ("outgoing", "both"):
            rel_ids.extend(self._outgoing.get(entity_id, []))
        if direction in ("incoming", "both"):
            rel_ids.extend(self._incoming.get(entity_id, []))

        results = []
        # dict.fromkeys de-duplicates (self-loops appear in both indexes)
        # while keeping a deterministic order, unlike a set.
        for rel_id in dict.fromkeys(rel_ids):
            rel = self._relationships.get(rel_id)
            if rel is None:
                continue
            if relation_type and rel.relation_type != relation_type:
                continue
            results.append(rel)
        return results

    def traverse(
        self,
        start_entity_id: str,
        max_hops: int = 2,
        relation_types: Optional[List[str]] = None,
    ) -> GraphResult:
        """
        Traverse the graph from a starting entity.

        Edges are followed in both directions, breadth-first, so every
        entity is reached by a shortest path and nothing further than
        ``max_hops`` edges from the start is visited.

        Args:
            start_entity_id: ID of the entity to start from.
            max_hops: Maximum number of edges to follow from the start;
                0 or less visits only the start entity.
            relation_types: Optional relation types to restrict to.

        Returns:
            GraphResult with the visited entities that are stored, the
            relationships examined on the way, one path (list of entity
            IDs from the start) per newly reached entity, and the elapsed
            time in milliseconds.
        """
        start_time = time.time()

        allowed_types = set(relation_types) if relation_types else None
        # Dicts double as insertion-ordered sets for stable output order.
        visited_entities: Dict[str, None] = {start_entity_id: None}
        visited_relationships: Dict[str, None] = {}
        paths: List[List[str]] = []

        frontier = [(start_entity_id, [start_entity_id])]
        for _ in range(max(max_hops, 0)):
            next_frontier = []
            for current_id, current_path in frontier:
                for rel in self.get_relationships(current_id, "both"):
                    if allowed_types is not None and rel.relation_type not in allowed_types:
                        continue

                    visited_relationships[rel.id] = None

                    next_id = (
                        rel.target_id if rel.source_id == current_id else rel.source_id
                    )
                    if next_id in visited_entities:
                        continue

                    visited_entities[next_id] = None
                    new_path = current_path + [next_id]
                    paths.append(new_path)
                    next_frontier.append((next_id, new_path))
            if not next_frontier:
                break
            frontier = next_frontier

        entities = [
            self._entities[eid] for eid in visited_entities if eid in self._entities
        ]
        relationships = [
            self._relationships[rid]
            for rid in visited_relationships
            if rid in self._relationships
        ]

        query_time_ms = int((time.time() - start_time) * 1000)

        return GraphResult(
            entities=entities,
            relationships=relationships,
            paths=paths,
            query_time_ms=query_time_ms,
        )

    def query(self, query: GraphQuery) -> GraphResult:
        """
        Execute a graph query.

        For each name in ``query.entities`` up to 5 matching entities are
        looked up and traversed; results are merged by ID. Entities and
        paths are truncated to ``query.limit``; relationships are not.
        """
        start_time = time.time()

        all_entities: Dict[str, Entity] = {}
        all_relationships: Dict[str, Relationship] = {}
        all_paths: List[List[str]] = []

        for entity_name in query.entities:
            for entity in self.find_entities(name=entity_name, limit=5):
                all_entities[entity.id] = entity
                result = self.traverse(
                    entity.id,
                    max_hops=query.max_hops,
                    relation_types=query.relation_types,
                )
                for e in result.entities:
                    all_entities[e.id] = e
                for r in result.relationships:
                    all_relationships[r.id] = r
                all_paths.extend(result.paths)

        return GraphResult(
            entities=list(all_entities.values())[: query.limit],
            relationships=list(all_relationships.values()),
            paths=all_paths[: query.limit],
            query_time_ms=int((time.time() - start_time) * 1000),
        )

    def delete_entity(self, entity_id: str) -> bool:
        """
        Delete an entity and its relationships.

        Returns:
            True if the entity existed and was removed, False otherwise.
        """
        if entity_id not in self._entities:
            return False

        # Snapshot the ID lists: _unindex mutates them while we iterate.
        attached = list(self._outgoing.get(entity_id, [])) + list(
            self._incoming.get(entity_id, [])
        )
        for rel_id in attached:
            rel = self._relationships.pop(rel_id, None)
            if rel is not None:
                self._unindex(rel)

        del self._entities[entity_id]
        self._outgoing.pop(entity_id, None)
        self._incoming.pop(entity_id, None)

        return True

    def clear(self) -> None:
        """Clear all stored data."""
        self._entities.clear()
        self._relationships.clear()
        self._outgoing.clear()
        self._incoming.clear()
647
+
648
+
649
class BackendGraphStore(GraphStore):
    """
    GraphStore implementation that delegates to a GraphBackend.

    This class bridges the high-level GraphStore API with the pluggable
    GraphBackend interface, enabling use of different database backends
    while maintaining the full GraphStore functionality.

    Example:
        from alma.graph import create_graph_backend, BackendGraphStore

        backend = create_graph_backend("neo4j", uri="bolt://localhost:7687", ...)
        store = BackendGraphStore(backend)
        store.add_entity(entity)
        result = store.traverse(entity.id)
    """

    def __init__(self, backend: "GraphBackend"):
        """
        Initialize with a GraphBackend.

        Args:
            backend: The GraphBackend implementation to use for storage.
        """
        self._backend = backend

    @property
    def backend(self) -> "GraphBackend":
        """Access the underlying backend."""
        return self._backend

    def add_entity(self, entity: Entity) -> str:
        """Add or update an entity via the backend."""
        return self._backend.add_entity(entity)

    def add_relationship(self, relationship: Relationship) -> str:
        """Add or update a relationship via the backend."""
        return self._backend.add_relationship(relationship)

    def get_entity(self, entity_id: str) -> Optional[Entity]:
        """Get an entity by ID via the backend."""
        return self._backend.get_entity(entity_id)

    def find_entities(
        self,
        name: Optional[str] = None,
        entity_type: Optional[str] = None,
        limit: int = 10,
    ) -> List[Entity]:
        """
        Find entities by name or type using the backend.

        Uses the backend's ``find_entities`` when it has one. Otherwise a
        name search goes through ``search_entities`` and, when
        ``entity_type`` is also given, the hits are filtered by type here;
        a type-only search goes through ``get_entities``.
        """
        if hasattr(self._backend, "find_entities"):
            return self._backend.find_entities(
                name=name, entity_type=entity_type, limit=limit
            )
        # Fallback to search_entities if find_entities not available
        if name:
            matches = self._backend.search_entities(query=name, top_k=limit)
            if entity_type:
                # search_entities takes no type filter, so apply it here
                # rather than silently ignoring the argument.
                matches = [e for e in matches if e.entity_type == entity_type]
            return matches
        return self._backend.get_entities(entity_type=entity_type, limit=limit)

    def get_relationships(
        self,
        entity_id: str,
        direction: str = "both",
        relation_type: Optional[str] = None,
    ) -> List[Relationship]:
        """
        Get relationships for an entity.

        Uses the backend's ``get_relationships_directional`` when it has
        one. Otherwise all of the entity's relationships are fetched and
        filtered here: "outgoing" keeps those whose source is the entity,
        "incoming" those whose target is the entity, any other value keeps
        everything.
        """
        if hasattr(self._backend, "get_relationships_directional"):
            return self._backend.get_relationships_directional(
                entity_id=entity_id, direction=direction, relation_type=relation_type
            )
        # Fallback - get all relationships and filter
        rels = self._backend.get_relationships(entity_id)
        if direction == "outgoing":
            rels = [r for r in rels if r.source_id == entity_id]
        elif direction == "incoming":
            rels = [r for r in rels if r.target_id == entity_id]
        if relation_type:
            rels = [r for r in rels if r.relation_type == relation_type]
        return rels

    def traverse(
        self,
        start_entity_id: str,
        max_hops: int = 2,
        relation_types: Optional[List[str]] = None,
    ) -> GraphResult:
        """
        Traverse the graph from a starting entity.

        Edges returned by the backend's ``get_relationships`` are followed
        breadth-first, so every entity is reached by a shortest path and
        nothing further than ``max_hops`` edges from the start is visited.

        Args:
            start_entity_id: ID of the entity to start from.
            max_hops: Maximum number of edges to follow from the start;
                0 or less visits only the start entity.
            relation_types: Optional relation types to restrict to.

        Returns:
            GraphResult with the visited entities the backend can resolve,
            the relationships examined on the way, one path (list of
            entity IDs from the start) per newly reached entity, and the
            elapsed time in milliseconds.
        """
        start_time = time.time()

        allowed_types = set(relation_types) if relation_types else None
        # Dict doubles as an insertion-ordered set for stable output order.
        visited_entities: Dict[str, None] = {start_entity_id: None}
        visited_relationships: Dict[str, Relationship] = {}
        paths: List[List[str]] = []

        frontier = [(start_entity_id, [start_entity_id])]
        for _ in range(max(max_hops, 0)):
            next_frontier = []
            for current_id, current_path in frontier:
                for rel in self._backend.get_relationships(current_id):
                    if allowed_types is not None and rel.relation_type not in allowed_types:
                        continue

                    visited_relationships[rel.id] = rel

                    next_id = (
                        rel.target_id if rel.source_id == current_id else rel.source_id
                    )
                    if next_id in visited_entities:
                        continue

                    visited_entities[next_id] = None
                    new_path = current_path + [next_id]
                    paths.append(new_path)
                    next_frontier.append((next_id, new_path))
            if not next_frontier:
                break
            frontier = next_frontier

        # Fetch all visited entities; IDs the backend cannot resolve are skipped.
        entities = []
        for eid in visited_entities:
            entity = self._backend.get_entity(eid)
            if entity:
                entities.append(entity)

        query_time_ms = int((time.time() - start_time) * 1000)

        return GraphResult(
            entities=entities,
            relationships=list(visited_relationships.values()),
            paths=paths,
            query_time_ms=query_time_ms,
        )

    def query(self, query: GraphQuery) -> GraphResult:
        """
        Execute a graph query.

        For each name in ``query.entities`` up to 5 matching entities are
        looked up and traversed; results are merged by ID. Entities and
        paths are truncated to ``query.limit``; relationships are not.
        """
        start_time = time.time()

        all_entities: Dict[str, Entity] = {}
        all_relationships: Dict[str, Relationship] = {}
        all_paths: List[List[str]] = []

        for entity_name in query.entities:
            for entity in self.find_entities(name=entity_name, limit=5):
                all_entities[entity.id] = entity
                result = self.traverse(
                    entity.id,
                    max_hops=query.max_hops,
                    relation_types=query.relation_types,
                )
                for e in result.entities:
                    all_entities[e.id] = e
                for r in result.relationships:
                    all_relationships[r.id] = r
                all_paths.extend(result.paths)

        return GraphResult(
            entities=list(all_entities.values())[: query.limit],
            relationships=list(all_relationships.values()),
            paths=all_paths[: query.limit],
            query_time_ms=int((time.time() - start_time) * 1000),
        )

    def delete_entity(self, entity_id: str) -> bool:
        """Delete an entity via the backend."""
        return self._backend.delete_entity(entity_id)

    def close(self) -> None:
        """Close the backend connection."""
        self._backend.close()
808
+
809
+
810
def create_graph_store(
    provider: str = "memory",
    **kwargs,
) -> GraphStore:
    """
    Factory function to create a graph store.

    Args:
        provider: "neo4j", "memory", or "backend"
        **kwargs: Provider-specific arguments
            For "neo4j": uri, username, password, database
            For "backend": backend (GraphBackend instance)

    Returns:
        Configured GraphStore instance. Any provider name other than the
        ones above (and "neptune") yields an in-memory store; a warning is
        logged in that case so a misspelled provider does not go unnoticed.

    Raises:
        NotImplementedError: If provider is "neptune".
        ValueError: If provider is "backend" and no backend is given.

    Note:
        Amazon Neptune support is planned for a future release.

    Example:
        # Create in-memory store
        store = create_graph_store("memory")

        # Create Neo4j store
        store = create_graph_store(
            "neo4j",
            uri="bolt://localhost:7687",
            username="neo4j",
            password="password"
        )

        # Create store using a backend
        from alma.graph import create_graph_backend
        backend = create_graph_backend("neo4j", uri="...", ...)
        store = create_graph_store("backend", backend=backend)
    """
    if provider == "neo4j":
        return Neo4jGraphStore(**kwargs)

    if provider == "neptune":
        # Neptune support is planned for a future release
        raise NotImplementedError(
            "Neptune support is not yet implemented. "
            "Use 'neo4j' or 'memory' providers instead."
        )

    if provider == "backend":
        backend = kwargs.get("backend")
        if backend is None:
            raise ValueError("'backend' argument required for 'backend' provider")
        return BackendGraphStore(backend)

    if provider != "memory":
        # Keep the historical fallback, but make it visible: previously a
        # typo such as "neo4J" silently produced a non-persistent store.
        logger.warning(
            "Unknown graph store provider %r; falling back to in-memory store",
            provider,
        )
    return InMemoryGraphStore()