graphiti-core 0.11.6rc9__py3-none-any.whl → 0.12.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of graphiti-core might be problematic; the original page linked to more details.

Files changed (33):
  1. graphiti_core/cross_encoder/openai_reranker_client.py +1 -1
  2. graphiti_core/driver/__init__.py +17 -0
  3. graphiti_core/driver/driver.py +66 -0
  4. graphiti_core/driver/falkordb_driver.py +132 -0
  5. graphiti_core/driver/neo4j_driver.py +61 -0
  6. graphiti_core/edges.py +66 -40
  7. graphiti_core/embedder/azure_openai.py +64 -0
  8. graphiti_core/embedder/gemini.py +14 -3
  9. graphiti_core/graph_queries.py +149 -0
  10. graphiti_core/graphiti.py +41 -14
  11. graphiti_core/graphiti_types.py +2 -2
  12. graphiti_core/helpers.py +9 -4
  13. graphiti_core/llm_client/__init__.py +16 -0
  14. graphiti_core/llm_client/azure_openai_client.py +73 -0
  15. graphiti_core/llm_client/gemini_client.py +4 -1
  16. graphiti_core/models/edges/edge_db_queries.py +2 -4
  17. graphiti_core/nodes.py +31 -31
  18. graphiti_core/prompts/dedupe_edges.py +52 -1
  19. graphiti_core/prompts/dedupe_nodes.py +79 -4
  20. graphiti_core/prompts/extract_edges.py +50 -5
  21. graphiti_core/prompts/invalidate_edges.py +1 -1
  22. graphiti_core/search/search.py +6 -10
  23. graphiti_core/search/search_filters.py +23 -9
  24. graphiti_core/search/search_utils.py +250 -189
  25. graphiti_core/utils/bulk_utils.py +38 -11
  26. graphiti_core/utils/maintenance/community_operations.py +6 -7
  27. graphiti_core/utils/maintenance/edge_operations.py +149 -19
  28. graphiti_core/utils/maintenance/graph_data_operations.py +13 -42
  29. graphiti_core/utils/maintenance/node_operations.py +52 -71
  30. {graphiti_core-0.11.6rc9.dist-info → graphiti_core-0.12.0.dist-info}/METADATA +14 -5
  31. {graphiti_core-0.11.6rc9.dist-info → graphiti_core-0.12.0.dist-info}/RECORD +33 -26
  32. {graphiti_core-0.11.6rc9.dist-info → graphiti_core-0.12.0.dist-info}/LICENSE +0 -0
  33. {graphiti_core-0.11.6rc9.dist-info → graphiti_core-0.12.0.dist-info}/WHEEL +0 -0
@@ -20,22 +20,24 @@ from collections import defaultdict
20
20
  from datetime import datetime
21
21
  from math import ceil
22
22
 
23
- from neo4j import AsyncDriver, AsyncManagedTransaction
24
23
  from numpy import dot, sqrt
25
24
  from pydantic import BaseModel
26
25
  from typing_extensions import Any
27
26
 
27
+ from graphiti_core.driver.driver import GraphDriver, GraphDriverSession
28
28
  from graphiti_core.edges import Edge, EntityEdge, EpisodicEdge
29
29
  from graphiti_core.embedder import EmbedderClient
30
+ from graphiti_core.graph_queries import (
31
+ get_entity_edge_save_bulk_query,
32
+ get_entity_node_save_bulk_query,
33
+ )
30
34
  from graphiti_core.graphiti_types import GraphitiClients
31
35
  from graphiti_core.helpers import DEFAULT_DATABASE, semaphore_gather
32
36
  from graphiti_core.llm_client import LLMClient
33
37
  from graphiti_core.models.edges.edge_db_queries import (
34
- ENTITY_EDGE_SAVE_BULK,
35
38
  EPISODIC_EDGE_SAVE_BULK,
36
39
  )
37
40
  from graphiti_core.models.nodes.node_db_queries import (
38
- ENTITY_NODE_SAVE_BULK,
39
41
  EPISODIC_NODE_SAVE_BULK,
40
42
  )
41
43
  from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
@@ -73,7 +75,7 @@ class RawEpisode(BaseModel):
73
75
 
74
76
 
75
77
  async def retrieve_previous_episodes_bulk(
76
- driver: AsyncDriver, episodes: list[EpisodicNode]
78
+ driver: GraphDriver, episodes: list[EpisodicNode]
77
79
  ) -> list[tuple[EpisodicNode, list[EpisodicNode]]]:
78
80
  previous_episodes_list = await semaphore_gather(
79
81
  *[
@@ -91,14 +93,15 @@ async def retrieve_previous_episodes_bulk(
91
93
 
92
94
 
93
95
  async def add_nodes_and_edges_bulk(
94
- driver: AsyncDriver,
96
+ driver: GraphDriver,
95
97
  episodic_nodes: list[EpisodicNode],
96
98
  episodic_edges: list[EpisodicEdge],
97
99
  entity_nodes: list[EntityNode],
98
100
  entity_edges: list[EntityEdge],
99
101
  embedder: EmbedderClient,
100
102
  ):
101
- async with driver.session(database=DEFAULT_DATABASE) as session:
103
+ session = driver.session(database=DEFAULT_DATABASE)
104
+ try:
102
105
  await session.execute_write(
103
106
  add_nodes_and_edges_bulk_tx,
104
107
  episodic_nodes,
@@ -106,16 +109,20 @@ async def add_nodes_and_edges_bulk(
106
109
  entity_nodes,
107
110
  entity_edges,
108
111
  embedder,
112
+ driver=driver,
109
113
  )
114
+ finally:
115
+ await session.close()
110
116
 
111
117
 
112
118
  async def add_nodes_and_edges_bulk_tx(
113
- tx: AsyncManagedTransaction,
119
+ tx: GraphDriverSession,
114
120
  episodic_nodes: list[EpisodicNode],
115
121
  episodic_edges: list[EpisodicEdge],
116
122
  entity_nodes: list[EntityNode],
117
123
  entity_edges: list[EntityEdge],
118
124
  embedder: EmbedderClient,
125
+ driver: GraphDriver,
119
126
  ):
120
127
  episodes = [dict(episode) for episode in episodic_nodes]
121
128
  for episode in episodes:
@@ -137,16 +144,36 @@ async def add_nodes_and_edges_bulk_tx(
137
144
  entity_data['labels'] = list(set(node.labels + ['Entity']))
138
145
  nodes.append(entity_data)
139
146
 
147
+ edges: list[dict[str, Any]] = []
140
148
  for edge in entity_edges:
141
149
  if edge.fact_embedding is None:
142
150
  await edge.generate_embedding(embedder)
151
+ edge_data: dict[str, Any] = {
152
+ 'uuid': edge.uuid,
153
+ 'source_node_uuid': edge.source_node_uuid,
154
+ 'target_node_uuid': edge.target_node_uuid,
155
+ 'name': edge.name,
156
+ 'fact': edge.fact,
157
+ 'fact_embedding': edge.fact_embedding,
158
+ 'group_id': edge.group_id,
159
+ 'episodes': edge.episodes,
160
+ 'created_at': edge.created_at,
161
+ 'expired_at': edge.expired_at,
162
+ 'valid_at': edge.valid_at,
163
+ 'invalid_at': edge.invalid_at,
164
+ }
165
+
166
+ edge_data.update(edge.attributes or {})
167
+ edges.append(edge_data)
143
168
 
144
169
  await tx.run(EPISODIC_NODE_SAVE_BULK, episodes=episodes)
145
- await tx.run(ENTITY_NODE_SAVE_BULK, nodes=nodes)
170
+ entity_node_save_bulk = get_entity_node_save_bulk_query(nodes, driver.provider)
171
+ await tx.run(entity_node_save_bulk, nodes=nodes)
146
172
  await tx.run(
147
173
  EPISODIC_EDGE_SAVE_BULK, episodic_edges=[edge.model_dump() for edge in episodic_edges]
148
174
  )
149
- await tx.run(ENTITY_EDGE_SAVE_BULK, entity_edges=[edge.model_dump() for edge in entity_edges])
175
+ entity_edge_save_bulk = get_entity_edge_save_bulk_query(driver.provider)
176
+ await tx.run(entity_edge_save_bulk, entity_edges=edges)
150
177
 
151
178
 
152
179
  async def extract_nodes_and_edges_bulk(
@@ -193,7 +220,7 @@ async def extract_nodes_and_edges_bulk(
193
220
 
194
221
 
195
222
  async def dedupe_nodes_bulk(
196
- driver: AsyncDriver,
223
+ driver: GraphDriver,
197
224
  llm_client: LLMClient,
198
225
  extracted_nodes: list[EntityNode],
199
226
  ) -> tuple[list[EntityNode], dict[str, str]]:
@@ -229,7 +256,7 @@ async def dedupe_nodes_bulk(
229
256
 
230
257
 
231
258
  async def dedupe_edges_bulk(
232
- driver: AsyncDriver, llm_client: LLMClient, extracted_edges: list[EntityEdge]
259
+ driver: GraphDriver, llm_client: LLMClient, extracted_edges: list[EntityEdge]
233
260
  ) -> list[EntityEdge]:
234
261
  # First compress edges
235
262
  compressed_edges = await compress_edges(llm_client, extracted_edges)
@@ -2,9 +2,9 @@ import asyncio
2
2
  import logging
3
3
  from collections import defaultdict
4
4
 
5
- from neo4j import AsyncDriver
6
5
  from pydantic import BaseModel
7
6
 
7
+ from graphiti_core.driver.driver import GraphDriver
8
8
  from graphiti_core.edges import CommunityEdge
9
9
  from graphiti_core.embedder import EmbedderClient
10
10
  from graphiti_core.helpers import DEFAULT_DATABASE, semaphore_gather
@@ -26,7 +26,7 @@ class Neighbor(BaseModel):
26
26
 
27
27
 
28
28
  async def get_community_clusters(
29
- driver: AsyncDriver, group_ids: list[str] | None
29
+ driver: GraphDriver, group_ids: list[str] | None
30
30
  ) -> list[list[EntityNode]]:
31
31
  community_clusters: list[list[EntityNode]] = []
32
32
 
@@ -95,7 +95,6 @@ def label_propagation(projection: dict[str, list[Neighbor]]) -> list[list[str]]:
95
95
  community_candidates: dict[int, int] = defaultdict(int)
96
96
  for neighbor in neighbors:
97
97
  community_candidates[community_map[neighbor.node_uuid]] += neighbor.edge_count
98
-
99
98
  community_lst = [
100
99
  (count, community) for community, count in community_candidates.items()
101
100
  ]
@@ -194,7 +193,7 @@ async def build_community(
194
193
 
195
194
 
196
195
  async def build_communities(
197
- driver: AsyncDriver, llm_client: LLMClient, group_ids: list[str] | None
196
+ driver: GraphDriver, llm_client: LLMClient, group_ids: list[str] | None
198
197
  ) -> tuple[list[CommunityNode], list[CommunityEdge]]:
199
198
  community_clusters = await get_community_clusters(driver, group_ids)
200
199
 
@@ -219,7 +218,7 @@ async def build_communities(
219
218
  return community_nodes, community_edges
220
219
 
221
220
 
222
- async def remove_communities(driver: AsyncDriver):
221
+ async def remove_communities(driver: GraphDriver):
223
222
  await driver.execute_query(
224
223
  """
225
224
  MATCH (c:Community)
@@ -230,7 +229,7 @@ async def remove_communities(driver: AsyncDriver):
230
229
 
231
230
 
232
231
  async def determine_entity_community(
233
- driver: AsyncDriver, entity: EntityNode
232
+ driver: GraphDriver, entity: EntityNode
234
233
  ) -> tuple[CommunityNode | None, bool]:
235
234
  # Check if the node is already part of a community
236
235
  records, _, _ = await driver.execute_query(
@@ -291,7 +290,7 @@ async def determine_entity_community(
291
290
 
292
291
 
293
292
  async def update_community(
294
- driver: AsyncDriver, llm_client: LLMClient, embedder: EmbedderClient, entity: EntityNode
293
+ driver: GraphDriver, llm_client: LLMClient, embedder: EmbedderClient, entity: EntityNode
295
294
  ):
296
295
  community, is_new = await determine_entity_community(driver, entity)
297
296
 
@@ -18,6 +18,8 @@ import logging
18
18
  from datetime import datetime
19
19
  from time import time
20
20
 
21
+ from pydantic import BaseModel
22
+
21
23
  from graphiti_core.edges import (
22
24
  CommunityEdge,
23
25
  EntityEdge,
@@ -35,9 +37,6 @@ from graphiti_core.prompts.extract_edges import ExtractedEdges, MissingFacts
35
37
  from graphiti_core.search.search_filters import SearchFilters
36
38
  from graphiti_core.search.search_utils import get_edge_invalidation_candidates, get_relevant_edges
37
39
  from graphiti_core.utils.datetime_utils import ensure_utc, utc_now
38
- from graphiti_core.utils.maintenance.temporal_operations import (
39
- get_edge_contradictions,
40
- )
41
40
 
42
41
  logger = logging.getLogger(__name__)
43
42
 
@@ -86,20 +85,32 @@ async def extract_edges(
86
85
  nodes: list[EntityNode],
87
86
  previous_episodes: list[EpisodicNode],
88
87
  group_id: str = '',
88
+ edge_types: dict[str, BaseModel] | None = None,
89
89
  ) -> list[EntityEdge]:
90
90
  start = time()
91
91
 
92
92
  extract_edges_max_tokens = 16384
93
93
  llm_client = clients.llm_client
94
94
 
95
- node_uuids_by_name_map = {node.name: node.uuid for node in nodes}
95
+ edge_types_context = (
96
+ [
97
+ {
98
+ 'fact_type_name': type_name,
99
+ 'fact_type_description': type_model.__doc__,
100
+ }
101
+ for type_name, type_model in edge_types.items()
102
+ ]
103
+ if edge_types is not None
104
+ else []
105
+ )
96
106
 
97
107
  # Prepare context for LLM
98
108
  context = {
99
109
  'episode_content': episode.content,
100
- 'nodes': [node.name for node in nodes],
110
+ 'nodes': [{'id': idx, 'name': node.name} for idx, node in enumerate(nodes)],
101
111
  'previous_episodes': [ep.content for ep in previous_episodes],
102
112
  'reference_time': episode.valid_at,
113
+ 'edge_types': edge_types_context,
103
114
  'custom_prompt': '',
104
115
  }
105
116
 
@@ -148,6 +159,16 @@ async def extract_edges(
148
159
  valid_at_datetime = None
149
160
  invalid_at_datetime = None
150
161
 
162
+ source_node_idx = edge_data.get('source_entity_id', -1)
163
+ target_node_idx = edge_data.get('target_entity_id', -1)
164
+ if not (-1 < source_node_idx < len(nodes) and -1 < target_node_idx < len(nodes)):
165
+ logger.warning(
166
+ f'WARNING: source or target node not filled {edge_data.get("edge_name")}. source_node_uuid: {source_node_idx} and target_node_uuid: {target_node_idx} '
167
+ )
168
+ continue
169
+ source_node_uuid = nodes[source_node_idx].uuid
170
+ target_node_uuid = nodes[edge_data.get('target_entity_id')].uuid
171
+
151
172
  if valid_at:
152
173
  try:
153
174
  valid_at_datetime = ensure_utc(
@@ -164,12 +185,8 @@ async def extract_edges(
164
185
  except ValueError as e:
165
186
  logger.warning(f'WARNING: Error parsing invalid_at date: {e}. Input: {invalid_at}')
166
187
  edge = EntityEdge(
167
- source_node_uuid=node_uuids_by_name_map.get(
168
- edge_data.get('source_entity_name', ''), ''
169
- ),
170
- target_node_uuid=node_uuids_by_name_map.get(
171
- edge_data.get('target_entity_name', ''), ''
172
- ),
188
+ source_node_uuid=source_node_uuid,
189
+ target_node_uuid=target_node_uuid,
173
190
  name=edge_data.get('relation_type', ''),
174
191
  group_id=group_id,
175
192
  fact=edge_data.get('fact', ''),
@@ -236,16 +253,18 @@ async def resolve_extracted_edges(
236
253
  clients: GraphitiClients,
237
254
  extracted_edges: list[EntityEdge],
238
255
  episode: EpisodicNode,
256
+ entities: list[EntityNode],
257
+ edge_types: dict[str, BaseModel],
258
+ edge_type_map: dict[tuple[str, str], list[str]],
239
259
  ) -> tuple[list[EntityEdge], list[EntityEdge]]:
240
260
  driver = clients.driver
241
261
  llm_client = clients.llm_client
242
262
  embedder = clients.embedder
243
-
244
263
  await create_entity_edge_embeddings(embedder, extracted_edges)
245
264
 
246
265
  search_results: tuple[list[list[EntityEdge]], list[list[EntityEdge]]] = await semaphore_gather(
247
266
  get_relevant_edges(driver, extracted_edges, SearchFilters()),
248
- get_edge_invalidation_candidates(driver, extracted_edges, SearchFilters()),
267
+ get_edge_invalidation_candidates(driver, extracted_edges, SearchFilters(), 0.2),
249
268
  )
250
269
 
251
270
  related_edges_lists, edge_invalidation_candidates = search_results
@@ -254,15 +273,50 @@ async def resolve_extracted_edges(
254
273
  f'Related edges lists: {[(e.name, e.uuid) for edges_lst in related_edges_lists for e in edges_lst]}'
255
274
  )
256
275
 
276
+ # Build entity hash table
277
+ uuid_entity_map: dict[str, EntityNode] = {entity.uuid: entity for entity in entities}
278
+
279
+ # Determine which edge types are relevant for each edge
280
+ edge_types_lst: list[dict[str, BaseModel]] = []
281
+ for extracted_edge in extracted_edges:
282
+ source_node_labels = uuid_entity_map[extracted_edge.source_node_uuid].labels + ['Entity']
283
+ target_node_labels = uuid_entity_map[extracted_edge.target_node_uuid].labels + ['Entity']
284
+ label_tuples = [
285
+ (source_label, target_label)
286
+ for source_label in source_node_labels
287
+ for target_label in target_node_labels
288
+ ]
289
+
290
+ extracted_edge_types = {}
291
+ for label_tuple in label_tuples:
292
+ type_names = edge_type_map.get(label_tuple, [])
293
+ for type_name in type_names:
294
+ type_model = edge_types.get(type_name)
295
+ if type_model is None:
296
+ continue
297
+
298
+ extracted_edge_types[type_name] = type_model
299
+
300
+ edge_types_lst.append(extracted_edge_types)
301
+
257
302
  # resolve edges with related edges in the graph and find invalidation candidates
258
303
  results: list[tuple[EntityEdge, list[EntityEdge]]] = list(
259
304
  await semaphore_gather(
260
305
  *[
261
306
  resolve_extracted_edge(
262
- llm_client, extracted_edge, related_edges, existing_edges, episode
307
+ llm_client,
308
+ extracted_edge,
309
+ related_edges,
310
+ existing_edges,
311
+ episode,
312
+ extracted_edge_types,
263
313
  )
264
- for extracted_edge, related_edges, existing_edges in zip(
265
- extracted_edges, related_edges_lists, edge_invalidation_candidates, strict=True
314
+ for extracted_edge, related_edges, existing_edges, extracted_edge_types in zip(
315
+ extracted_edges,
316
+ related_edges_lists,
317
+ edge_invalidation_candidates,
318
+ edge_types_lst,
319
+ strict=True,
266
320
  )
267
321
  ]
268
322
  )
@@ -326,10 +380,86 @@ async def resolve_extracted_edge(
326
380
  related_edges: list[EntityEdge],
327
381
  existing_edges: list[EntityEdge],
328
382
  episode: EpisodicNode,
383
+ edge_types: dict[str, BaseModel] | None = None,
329
384
  ) -> tuple[EntityEdge, list[EntityEdge]]:
330
- resolved_edge, invalidation_candidates = await semaphore_gather(
331
- dedupe_extracted_edge(llm_client, extracted_edge, related_edges, episode),
332
- get_edge_contradictions(llm_client, extracted_edge, existing_edges),
385
+ if len(related_edges) == 0 and len(existing_edges) == 0:
386
+ return extracted_edge, []
387
+
388
+ start = time()
389
+
390
+ # Prepare context for LLM
391
+ related_edges_context = [
392
+ {'id': edge.uuid, 'fact': edge.fact} for i, edge in enumerate(related_edges)
393
+ ]
394
+
395
+ invalidation_edge_candidates_context = [
396
+ {'id': i, 'fact': existing_edge.fact} for i, existing_edge in enumerate(existing_edges)
397
+ ]
398
+
399
+ edge_types_context = (
400
+ [
401
+ {
402
+ 'fact_type_id': i,
403
+ 'fact_type_name': type_name,
404
+ 'fact_type_description': type_model.__doc__,
405
+ }
406
+ for i, (type_name, type_model) in enumerate(edge_types.items())
407
+ ]
408
+ if edge_types is not None
409
+ else []
410
+ )
411
+
412
+ context = {
413
+ 'existing_edges': related_edges_context,
414
+ 'new_edge': extracted_edge.fact,
415
+ 'edge_invalidation_candidates': invalidation_edge_candidates_context,
416
+ 'edge_types': edge_types_context,
417
+ }
418
+
419
+ llm_response = await llm_client.generate_response(
420
+ prompt_library.dedupe_edges.resolve_edge(context),
421
+ response_model=EdgeDuplicate,
422
+ model_size=ModelSize.small,
423
+ )
424
+
425
+ duplicate_fact_id: int = llm_response.get('duplicate_fact_id', -1)
426
+
427
+ resolved_edge = (
428
+ related_edges[duplicate_fact_id]
429
+ if 0 <= duplicate_fact_id < len(related_edges)
430
+ else extracted_edge
431
+ )
432
+
433
+ if duplicate_fact_id >= 0 and episode is not None:
434
+ resolved_edge.episodes.append(episode.uuid)
435
+
436
+ contradicted_facts: list[int] = llm_response.get('contradicted_facts', [])
437
+
438
+ invalidation_candidates: list[EntityEdge] = [existing_edges[i] for i in contradicted_facts]
439
+
440
+ fact_type: str = str(llm_response.get('fact_type'))
441
+ if fact_type.upper() != 'DEFAULT' and edge_types is not None:
442
+ resolved_edge.name = fact_type
443
+
444
+ edge_attributes_context = {
445
+ 'episode_content': episode.content,
446
+ 'reference_time': episode.valid_at,
447
+ 'fact': resolved_edge.fact,
448
+ }
449
+
450
+ edge_model = edge_types.get(fact_type)
451
+
452
+ edge_attributes_response = await llm_client.generate_response(
453
+ prompt_library.extract_edges.extract_attributes(edge_attributes_context),
454
+ response_model=edge_model, # type: ignore
455
+ model_size=ModelSize.small,
456
+ )
457
+
458
+ resolved_edge.attributes = edge_attributes_response
459
+
460
+ end = time()
461
+ logger.debug(
462
+ f'Resolved Edge: {extracted_edge.name} is {resolved_edge.name}, in {(end - start) * 1000} ms'
333
463
  )
334
464
 
335
465
  now = utc_now()
@@ -17,9 +17,10 @@ limitations under the License.
17
17
  import logging
18
18
  from datetime import datetime, timezone
19
19
 
20
- from neo4j import AsyncDriver
21
20
  from typing_extensions import LiteralString
22
21
 
22
+ from graphiti_core.driver.driver import GraphDriver
23
+ from graphiti_core.graph_queries import get_fulltext_indices, get_range_indices
23
24
  from graphiti_core.helpers import DEFAULT_DATABASE, semaphore_gather
24
25
  from graphiti_core.nodes import EpisodeType, EpisodicNode
25
26
 
@@ -28,7 +29,7 @@ EPISODE_WINDOW_LEN = 3
28
29
  logger = logging.getLogger(__name__)
29
30
 
30
31
 
31
- async def build_indices_and_constraints(driver: AsyncDriver, delete_existing: bool = False):
32
+ async def build_indices_and_constraints(driver: GraphDriver, delete_existing: bool = False):
32
33
  if delete_existing:
33
34
  records, _, _ = await driver.execute_query(
34
35
  """
@@ -47,39 +48,9 @@ async def build_indices_and_constraints(driver: AsyncDriver, delete_existing: bo
47
48
  for name in index_names
48
49
  ]
49
50
  )
51
+ range_indices: list[LiteralString] = get_range_indices(driver.provider)
50
52
 
51
- range_indices: list[LiteralString] = [
52
- 'CREATE INDEX entity_uuid IF NOT EXISTS FOR (n:Entity) ON (n.uuid)',
53
- 'CREATE INDEX episode_uuid IF NOT EXISTS FOR (n:Episodic) ON (n.uuid)',
54
- 'CREATE INDEX community_uuid IF NOT EXISTS FOR (n:Community) ON (n.uuid)',
55
- 'CREATE INDEX relation_uuid IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.uuid)',
56
- 'CREATE INDEX mention_uuid IF NOT EXISTS FOR ()-[e:MENTIONS]-() ON (e.uuid)',
57
- 'CREATE INDEX has_member_uuid IF NOT EXISTS FOR ()-[e:HAS_MEMBER]-() ON (e.uuid)',
58
- 'CREATE INDEX entity_group_id IF NOT EXISTS FOR (n:Entity) ON (n.group_id)',
59
- 'CREATE INDEX episode_group_id IF NOT EXISTS FOR (n:Episodic) ON (n.group_id)',
60
- 'CREATE INDEX relation_group_id IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.group_id)',
61
- 'CREATE INDEX mention_group_id IF NOT EXISTS FOR ()-[e:MENTIONS]-() ON (e.group_id)',
62
- 'CREATE INDEX name_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.name)',
63
- 'CREATE INDEX created_at_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.created_at)',
64
- 'CREATE INDEX created_at_episodic_index IF NOT EXISTS FOR (n:Episodic) ON (n.created_at)',
65
- 'CREATE INDEX valid_at_episodic_index IF NOT EXISTS FOR (n:Episodic) ON (n.valid_at)',
66
- 'CREATE INDEX name_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.name)',
67
- 'CREATE INDEX created_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.created_at)',
68
- 'CREATE INDEX expired_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.expired_at)',
69
- 'CREATE INDEX valid_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.valid_at)',
70
- 'CREATE INDEX invalid_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.invalid_at)',
71
- ]
72
-
73
- fulltext_indices: list[LiteralString] = [
74
- """CREATE FULLTEXT INDEX episode_content IF NOT EXISTS
75
- FOR (e:Episodic) ON EACH [e.content, e.source, e.source_description, e.group_id]""",
76
- """CREATE FULLTEXT INDEX node_name_and_summary IF NOT EXISTS
77
- FOR (n:Entity) ON EACH [n.name, n.summary, n.group_id]""",
78
- """CREATE FULLTEXT INDEX community_name IF NOT EXISTS
79
- FOR (n:Community) ON EACH [n.name, n.group_id]""",
80
- """CREATE FULLTEXT INDEX edge_name_and_fact IF NOT EXISTS
81
- FOR ()-[e:RELATES_TO]-() ON EACH [e.name, e.fact, e.group_id]""",
82
- ]
53
+ fulltext_indices: list[LiteralString] = get_fulltext_indices(driver.provider)
83
54
 
84
55
  index_queries: list[LiteralString] = range_indices + fulltext_indices
85
56
 
@@ -94,7 +65,7 @@ async def build_indices_and_constraints(driver: AsyncDriver, delete_existing: bo
94
65
  )
95
66
 
96
67
 
97
- async def clear_data(driver: AsyncDriver, group_ids: list[str] | None = None):
68
+ async def clear_data(driver: GraphDriver, group_ids: list[str] | None = None):
98
69
  async with driver.session(database=DEFAULT_DATABASE) as session:
99
70
 
100
71
  async def delete_all(tx):
@@ -113,7 +84,7 @@ async def clear_data(driver: AsyncDriver, group_ids: list[str] | None = None):
113
84
 
114
85
 
115
86
  async def retrieve_episodes(
116
- driver: AsyncDriver,
87
+ driver: GraphDriver,
117
88
  reference_time: datetime,
118
89
  last_n: int = EPISODE_WINDOW_LEN,
119
90
  group_ids: list[str] | None = None,
@@ -123,7 +94,7 @@ async def retrieve_episodes(
123
94
  Retrieve the last n episodic nodes from the graph.
124
95
 
125
96
  Args:
126
- driver (AsyncDriver): The Neo4j driver instance.
97
+ driver (Driver): The Neo4j driver instance.
127
98
  reference_time (datetime): The reference time to filter episodes. Only episodes with a valid_at timestamp
128
99
  less than or equal to this reference_time will be retrieved. This allows for
129
100
  querying the graph's state at a specific point in time.
@@ -140,8 +111,8 @@ async def retrieve_episodes(
140
111
 
141
112
  query: LiteralString = (
142
113
  """
143
- MATCH (e:Episodic) WHERE e.valid_at <= $reference_time
144
- """
114
+ MATCH (e:Episodic) WHERE e.valid_at <= $reference_time
115
+ """
145
116
  + group_id_filter
146
117
  + source_filter
147
118
  + """
@@ -157,8 +128,7 @@ async def retrieve_episodes(
157
128
  LIMIT $num_episodes
158
129
  """
159
130
  )
160
-
161
- result = await driver.execute_query(
131
+ result, _, _ = await driver.execute_query(
162
132
  query,
163
133
  reference_time=reference_time,
164
134
  source=source.name if source is not None else None,
@@ -166,6 +136,7 @@ async def retrieve_episodes(
166
136
  group_ids=group_ids,
167
137
  database_=DEFAULT_DATABASE,
168
138
  )
139
+
169
140
  episodes = [
170
141
  EpisodicNode(
171
142
  content=record['content'],
@@ -179,6 +150,6 @@ async def retrieve_episodes(
179
150
  name=record['name'],
180
151
  source_description=record['source_description'],
181
152
  )
182
- for record in result.records
153
+ for record in result
183
154
  ]
184
155
  return list(reversed(episodes)) # Return in chronological order