graphiti-core 0.18.0__py3-none-any.whl → 0.18.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of graphiti-core might be problematic; see the package's registry page for more details.

@@ -23,7 +23,7 @@ import numpy as np
23
23
  from numpy._typing import NDArray
24
24
  from typing_extensions import LiteralString
25
25
 
26
- from graphiti_core.driver.driver import GraphDriver
26
+ from graphiti_core.driver.driver import GraphDriver, GraphProvider
27
27
  from graphiti_core.edges import EntityEdge, get_entity_edge_from_record
28
28
  from graphiti_core.graph_queries import (
29
29
  get_nodes_query,
@@ -36,6 +36,8 @@ from graphiti_core.helpers import (
36
36
  normalize_l2,
37
37
  semaphore_gather,
38
38
  )
39
+ from graphiti_core.models.edges.edge_db_queries import ENTITY_EDGE_RETURN
40
+ from graphiti_core.models.nodes.node_db_queries import COMMUNITY_NODE_RETURN, EPISODIC_NODE_RETURN
39
41
  from graphiti_core.nodes import (
40
42
  ENTITY_NODE_RETURN,
41
43
  CommunityNode,
@@ -100,20 +102,13 @@ async def get_mentioned_nodes(
100
102
  ) -> list[EntityNode]:
101
103
  episode_uuids = [episode.uuid for episode in episodes]
102
104
 
103
- query = """
104
- MATCH (episode:Episodic)-[:MENTIONS]->(n:Entity) WHERE episode.uuid IN $uuids
105
+ records, _, _ = await driver.execute_query(
106
+ """
107
+ MATCH (episode:Episodic)-[:MENTIONS]->(n:Entity)
108
+ WHERE episode.uuid IN $uuids
105
109
  RETURN DISTINCT
106
- n.uuid As uuid,
107
- n.group_id AS group_id,
108
- n.name AS name,
109
- n.created_at AS created_at,
110
- n.summary AS summary,
111
- labels(n) AS labels,
112
- properties(n) AS attributes
113
110
  """
114
-
115
- records, _, _ = await driver.execute_query(
116
- query,
111
+ + ENTITY_NODE_RETURN,
117
112
  uuids=episode_uuids,
118
113
  routing_='r',
119
114
  )
@@ -128,18 +123,13 @@ async def get_communities_by_nodes(
128
123
  ) -> list[CommunityNode]:
129
124
  node_uuids = [node.uuid for node in nodes]
130
125
 
131
- query = """
132
- MATCH (c:Community)-[:HAS_MEMBER]->(n:Entity) WHERE n.uuid IN $uuids
133
- RETURN DISTINCT
134
- c.uuid As uuid,
135
- c.group_id AS group_id,
136
- c.name AS name,
137
- c.created_at AS created_at,
138
- c.summary AS summary
139
- """
140
-
141
126
  records, _, _ = await driver.execute_query(
142
- query,
127
+ """
128
+ MATCH (n:Community)-[:HAS_MEMBER]->(m:Entity)
129
+ WHERE m.uuid IN $uuids
130
+ RETURN DISTINCT
131
+ """
132
+ + COMMUNITY_NODE_RETURN,
143
133
  uuids=node_uuids,
144
134
  routing_='r',
145
135
  )
@@ -164,38 +154,30 @@ async def edge_fulltext_search(
164
154
  filter_query, filter_params = edge_search_filter_query_constructor(search_filter)
165
155
 
166
156
  query = (
167
- get_relationships_query('edge_name_and_fact', db_type=driver.provider)
157
+ get_relationships_query('edge_name_and_fact', provider=driver.provider)
168
158
  + """
169
159
  YIELD relationship AS rel, score
170
- MATCH (n:Entity)-[r:RELATES_TO {uuid: rel.uuid}]->(m:Entity)
171
- WHERE r.group_id IN $group_ids """
160
+ MATCH (n:Entity)-[e:RELATES_TO {uuid: rel.uuid}]->(m:Entity)
161
+ WHERE e.group_id IN $group_ids """
172
162
  + filter_query
173
163
  + """
174
- WITH r, score, startNode(r) AS n, endNode(r) AS m
164
+ WITH e, score, n, m
175
165
  RETURN
176
- r.uuid AS uuid,
177
- r.group_id AS group_id,
178
- n.uuid AS source_node_uuid,
179
- m.uuid AS target_node_uuid,
180
- r.created_at AS created_at,
181
- r.name AS name,
182
- r.fact AS fact,
183
- r.episodes AS episodes,
184
- r.expired_at AS expired_at,
185
- r.valid_at AS valid_at,
186
- r.invalid_at AS invalid_at,
187
- properties(r) AS attributes
188
- ORDER BY score DESC LIMIT $limit
166
+ """
167
+ + ENTITY_EDGE_RETURN
168
+ + """
169
+ ORDER BY score DESC
170
+ LIMIT $limit
189
171
  """
190
172
  )
191
173
 
192
174
  records, _, _ = await driver.execute_query(
193
175
  query,
194
- params=filter_params,
195
176
  query=fuzzy_query,
196
177
  group_ids=group_ids,
197
178
  limit=limit,
198
179
  routing_='r',
180
+ **filter_params,
199
181
  )
200
182
 
201
183
  edges = [get_entity_edge_from_record(record) for record in records]
@@ -219,58 +201,47 @@ async def edge_similarity_search(
219
201
  filter_query, filter_params = edge_search_filter_query_constructor(search_filter)
220
202
  query_params.update(filter_params)
221
203
 
222
- group_filter_query: LiteralString = 'WHERE r.group_id IS NOT NULL'
204
+ group_filter_query: LiteralString = 'WHERE e.group_id IS NOT NULL'
223
205
  if group_ids is not None:
224
- group_filter_query += '\nAND r.group_id IN $group_ids'
206
+ group_filter_query += '\nAND e.group_id IN $group_ids'
225
207
  query_params['group_ids'] = group_ids
226
- query_params['source_node_uuid'] = source_node_uuid
227
- query_params['target_node_uuid'] = target_node_uuid
228
208
 
229
209
  if source_node_uuid is not None:
230
- group_filter_query += '\nAND (n.uuid IN [$source_uuid, $target_uuid])'
210
+ query_params['source_uuid'] = source_node_uuid
211
+ group_filter_query += '\nAND (n.uuid = $source_uuid)'
231
212
 
232
213
  if target_node_uuid is not None:
233
- group_filter_query += '\nAND (m.uuid IN [$source_uuid, $target_uuid])'
214
+ query_params['target_uuid'] = target_node_uuid
215
+ group_filter_query += '\nAND (m.uuid = $target_uuid)'
234
216
 
235
217
  query = (
236
218
  RUNTIME_QUERY
237
219
  + """
238
- MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity)
220
+ MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
239
221
  """
240
222
  + group_filter_query
241
223
  + filter_query
242
224
  + """
243
- WITH DISTINCT r, """
244
- + get_vector_cosine_func_query('r.fact_embedding', '$search_vector', driver.provider)
225
+ WITH DISTINCT e, n, m, """
226
+ + get_vector_cosine_func_query('e.fact_embedding', '$search_vector', driver.provider)
245
227
  + """ AS score
246
228
  WHERE score > $min_score
247
229
  RETURN
248
- r.uuid AS uuid,
249
- r.group_id AS group_id,
250
- startNode(r).uuid AS source_node_uuid,
251
- endNode(r).uuid AS target_node_uuid,
252
- r.created_at AS created_at,
253
- r.name AS name,
254
- r.fact AS fact,
255
- r.episodes AS episodes,
256
- r.expired_at AS expired_at,
257
- r.valid_at AS valid_at,
258
- r.invalid_at AS invalid_at,
259
- properties(r) AS attributes
230
+ """
231
+ + ENTITY_EDGE_RETURN
232
+ + """
260
233
  ORDER BY score DESC
261
234
  LIMIT $limit
262
235
  """
263
236
  )
264
- records, header, _ = await driver.execute_query(
237
+
238
+ records, _, _ = await driver.execute_query(
265
239
  query,
266
- params=query_params,
267
240
  search_vector=search_vector,
268
- source_uuid=source_node_uuid,
269
- target_uuid=target_node_uuid,
270
- group_ids=group_ids,
271
241
  limit=limit,
272
242
  min_score=min_score,
273
243
  routing_='r',
244
+ **query_params,
274
245
  )
275
246
 
276
247
  edges = [get_entity_edge_from_record(record) for record in records]
@@ -293,41 +264,31 @@ async def edge_bfs_search(
293
264
  filter_query, filter_params = edge_search_filter_query_constructor(search_filter)
294
265
 
295
266
  query = (
267
+ f"""
268
+ UNWIND $bfs_origin_node_uuids AS origin_uuid
269
+ MATCH path = (origin:Entity|Episodic {{uuid: origin_uuid}})-[:RELATES_TO|MENTIONS*1..{bfs_max_depth}]->(:Entity)
270
+ UNWIND relationships(path) AS rel
271
+ MATCH (n:Entity)-[e:RELATES_TO]-(m:Entity)
272
+ WHERE e.uuid = rel.uuid
273
+ AND e.group_id IN $group_ids
296
274
  """
297
- UNWIND $bfs_origin_node_uuids AS origin_uuid
298
- MATCH path = (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,3}(n:Entity)
299
- UNWIND relationships(path) AS rel
300
- MATCH (n:Entity)-[r:RELATES_TO]-(m:Entity)
301
- WHERE r.uuid = rel.uuid
302
- AND r.group_id IN $group_ids
303
- """
304
275
  + filter_query
305
- + """
306
- RETURN DISTINCT
307
- r.uuid AS uuid,
308
- r.group_id AS group_id,
309
- startNode(r).uuid AS source_node_uuid,
310
- endNode(r).uuid AS target_node_uuid,
311
- r.created_at AS created_at,
312
- r.name AS name,
313
- r.fact AS fact,
314
- r.episodes AS episodes,
315
- r.expired_at AS expired_at,
316
- r.valid_at AS valid_at,
317
- r.invalid_at AS invalid_at,
318
- properties(r) AS attributes
319
- LIMIT $limit
276
+ + """
277
+ RETURN DISTINCT
278
+ """
279
+ + ENTITY_EDGE_RETURN
280
+ + """
281
+ LIMIT $limit
320
282
  """
321
283
  )
322
284
 
323
285
  records, _, _ = await driver.execute_query(
324
286
  query,
325
- params=filter_params,
326
287
  bfs_origin_node_uuids=bfs_origin_node_uuids,
327
- depth=bfs_max_depth,
328
288
  group_ids=group_ids,
329
289
  limit=limit,
330
290
  routing_='r',
291
+ **filter_params,
331
292
  )
332
293
 
333
294
  edges = [get_entity_edge_from_record(record) for record in records]
@@ -352,23 +313,25 @@ async def node_fulltext_search(
352
313
  get_nodes_query(driver.provider, 'node_name_and_summary', '$query')
353
314
  + """
354
315
  YIELD node AS n, score
355
- WITH n, score
356
- LIMIT $limit
357
- WHERE n:Entity AND n.group_id IN $group_ids
316
+ WHERE n:Entity AND n.group_id IN $group_ids
358
317
  """
359
318
  + filter_query
360
- + ENTITY_NODE_RETURN
361
319
  + """
320
+ WITH n, score
362
321
  ORDER BY score DESC
322
+ LIMIT $limit
323
+ RETURN
363
324
  """
325
+ + ENTITY_NODE_RETURN
364
326
  )
365
- records, header, _ = await driver.execute_query(
327
+
328
+ records, _, _ = await driver.execute_query(
366
329
  query,
367
- params=filter_params,
368
330
  query=fuzzy_query,
369
331
  group_ids=group_ids,
370
332
  limit=limit,
371
333
  routing_='r',
334
+ **filter_params,
372
335
  )
373
336
 
374
337
  nodes = [get_entity_node_from_record(record) for record in records]
@@ -406,22 +369,23 @@ async def node_similarity_search(
406
369
  WITH n, """
407
370
  + get_vector_cosine_func_query('n.name_embedding', '$search_vector', driver.provider)
408
371
  + """ AS score
409
- WHERE score > $min_score"""
372
+ WHERE score > $min_score
373
+ RETURN
374
+ """
410
375
  + ENTITY_NODE_RETURN
411
376
  + """
412
377
  ORDER BY score DESC
413
378
  LIMIT $limit
414
- """
379
+ """
415
380
  )
416
381
 
417
- records, header, _ = await driver.execute_query(
382
+ records, _, _ = await driver.execute_query(
418
383
  query,
419
- params=query_params,
420
384
  search_vector=search_vector,
421
- group_ids=group_ids,
422
385
  limit=limit,
423
386
  min_score=min_score,
424
387
  routing_='r',
388
+ **query_params,
425
389
  )
426
390
 
427
391
  nodes = [get_entity_node_from_record(record) for record in records]
@@ -444,26 +408,29 @@ async def node_bfs_search(
444
408
  filter_query, filter_params = node_search_filter_query_constructor(search_filter)
445
409
 
446
410
  query = (
411
+ f"""
412
+ UNWIND $bfs_origin_node_uuids AS origin_uuid
413
+ MATCH (origin:Entity|Episodic {{uuid: origin_uuid}})-[:RELATES_TO|MENTIONS*1..{bfs_max_depth}]->(n:Entity)
414
+ WHERE n.group_id = origin.group_id
415
+ AND origin.group_id IN $group_ids
447
416
  """
448
- UNWIND $bfs_origin_node_uuids AS origin_uuid
449
- MATCH (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,3}(n:Entity)
450
- WHERE n.group_id = origin.group_id
451
- AND origin.group_id IN $group_ids
452
- """
453
417
  + filter_query
418
+ + """
419
+ RETURN
420
+ """
454
421
  + ENTITY_NODE_RETURN
455
422
  + """
456
423
  LIMIT $limit
457
424
  """
458
425
  )
426
+
459
427
  records, _, _ = await driver.execute_query(
460
428
  query,
461
- params=filter_params,
462
429
  bfs_origin_node_uuids=bfs_origin_node_uuids,
463
- depth=bfs_max_depth,
464
430
  group_ids=group_ids,
465
431
  limit=limit,
466
432
  routing_='r',
433
+ **filter_params,
467
434
  )
468
435
  nodes = [get_entity_node_from_record(record) for record in records]
469
436
 
@@ -489,16 +456,10 @@ async def episode_fulltext_search(
489
456
  MATCH (e:Episodic)
490
457
  WHERE e.uuid = episode.uuid
491
458
  AND e.group_id IN $group_ids
492
- RETURN
493
- e.content AS content,
494
- e.created_at AS created_at,
495
- e.valid_at AS valid_at,
496
- e.uuid AS uuid,
497
- e.name AS name,
498
- e.group_id AS group_id,
499
- e.source_description AS source_description,
500
- e.source AS source,
501
- e.entity_edges AS entity_edges
459
+ RETURN
460
+ """
461
+ + EPISODIC_NODE_RETURN
462
+ + """
502
463
  ORDER BY score DESC
503
464
  LIMIT $limit
504
465
  """
@@ -530,15 +491,12 @@ async def community_fulltext_search(
530
491
  query = (
531
492
  get_nodes_query(driver.provider, 'community_name', '$query')
532
493
  + """
533
- YIELD node AS comm, score
534
- WHERE comm.group_id IN $group_ids
494
+ YIELD node AS n, score
495
+ WHERE n.group_id IN $group_ids
535
496
  RETURN
536
- comm.uuid AS uuid,
537
- comm.group_id AS group_id,
538
- comm.name AS name,
539
- comm.created_at AS created_at,
540
- comm.summary AS summary,
541
- comm.name_embedding AS name_embedding
497
+ """
498
+ + COMMUNITY_NODE_RETURN
499
+ + """
542
500
  ORDER BY score DESC
543
501
  LIMIT $limit
544
502
  """
@@ -568,39 +526,37 @@ async def community_similarity_search(
568
526
 
569
527
  group_filter_query: LiteralString = ''
570
528
  if group_ids is not None:
571
- group_filter_query += 'WHERE comm.group_id IN $group_ids'
529
+ group_filter_query += 'WHERE n.group_id IN $group_ids'
572
530
  query_params['group_ids'] = group_ids
573
531
 
574
532
  query = (
575
533
  RUNTIME_QUERY
576
534
  + """
577
- MATCH (comm:Community)
578
- """
535
+ MATCH (n:Community)
536
+ """
579
537
  + group_filter_query
580
538
  + """
581
- WITH comm, """
582
- + get_vector_cosine_func_query('comm.name_embedding', '$search_vector', driver.provider)
539
+ WITH n,
540
+ """
541
+ + get_vector_cosine_func_query('n.name_embedding', '$search_vector', driver.provider)
583
542
  + """ AS score
584
- WHERE score > $min_score
585
- RETURN
586
- comm.uuid As uuid,
587
- comm.group_id AS group_id,
588
- comm.name AS name,
589
- comm.created_at AS created_at,
590
- comm.summary AS summary,
591
- comm.name_embedding AS name_embedding
592
- ORDER BY score DESC
593
- LIMIT $limit
543
+ WHERE score > $min_score
544
+ RETURN
545
+ """
546
+ + COMMUNITY_NODE_RETURN
547
+ + """
548
+ ORDER BY score DESC
549
+ LIMIT $limit
594
550
  """
595
551
  )
596
552
 
597
553
  records, _, _ = await driver.execute_query(
598
554
  query,
599
555
  search_vector=search_vector,
600
- group_ids=group_ids,
601
556
  limit=limit,
602
557
  min_score=min_score,
603
558
  routing_='r',
559
+ **query_params,
604
560
  )
605
561
  communities = [get_community_node_from_record(record) for record in records]
606
562
 
@@ -719,8 +675,8 @@ async def get_relevant_nodes(
719
675
  WHERE m.group_id = $group_id
720
676
  WITH node, top_vector_nodes, vector_node_uuids, collect(m) AS fulltext_nodes
721
677
 
722
- WITH node,
723
- top_vector_nodes,
678
+ WITH node,
679
+ top_vector_nodes,
724
680
  [m IN fulltext_nodes WHERE NOT m.uuid IN vector_node_uuids] AS filtered_fulltext_nodes
725
681
 
726
682
  WITH node, top_vector_nodes + filtered_fulltext_nodes AS combined_nodes
@@ -728,10 +684,10 @@ async def get_relevant_nodes(
728
684
  UNWIND combined_nodes AS combined_node
729
685
  WITH node, collect(DISTINCT combined_node) AS deduped_nodes
730
686
 
731
- RETURN
687
+ RETURN
732
688
  node.uuid AS search_node_uuid,
733
689
  [x IN deduped_nodes | {
734
- uuid: x.uuid,
690
+ uuid: x.uuid,
735
691
  name: x.name,
736
692
  name_embedding: x.name_embedding,
737
693
  group_id: x.group_id,
@@ -755,12 +711,12 @@ async def get_relevant_nodes(
755
711
 
756
712
  results, _, _ = await driver.execute_query(
757
713
  query,
758
- params=query_params,
759
714
  nodes=query_nodes,
760
715
  group_id=group_id,
761
716
  limit=limit,
762
717
  min_score=min_score,
763
718
  routing_='r',
719
+ **query_params,
764
720
  )
765
721
 
766
722
  relevant_nodes_dict: dict[str, list[EntityNode]] = {
@@ -825,11 +781,11 @@ async def get_relevant_edges(
825
781
 
826
782
  results, _, _ = await driver.execute_query(
827
783
  query,
828
- params=query_params,
829
784
  edges=[edge.model_dump() for edge in edges],
830
785
  limit=limit,
831
786
  min_score=min_score,
832
787
  routing_='r',
788
+ **query_params,
833
789
  )
834
790
 
835
791
  relevant_edges_dict: dict[str, list[EntityEdge]] = {
@@ -895,11 +851,11 @@ async def get_edge_invalidation_candidates(
895
851
 
896
852
  results, _, _ = await driver.execute_query(
897
853
  query,
898
- params=query_params,
899
854
  edges=[edge.model_dump() for edge in edges],
900
855
  limit=limit,
901
856
  min_score=min_score,
902
857
  routing_='r',
858
+ **query_params,
903
859
  )
904
860
  invalidation_edges_dict: dict[str, list[EntityEdge]] = {
905
861
  result['search_edge_uuid']: [
@@ -943,18 +899,17 @@ async def node_distance_reranker(
943
899
  scores: dict[str, float] = {center_node_uuid: 0.0}
944
900
 
945
901
  # Find the shortest path to center node
946
- query = """
902
+ results, header, _ = await driver.execute_query(
903
+ """
947
904
  UNWIND $node_uuids AS node_uuid
948
905
  MATCH (center:Entity {uuid: $center_uuid})-[:RELATES_TO]-(n:Entity {uuid: node_uuid})
949
906
  RETURN 1 AS score, node_uuid AS uuid
950
- """
951
- results, header, _ = await driver.execute_query(
952
- query,
907
+ """,
953
908
  node_uuids=filtered_uuids,
954
909
  center_uuid=center_node_uuid,
955
910
  routing_='r',
956
911
  )
957
- if driver.provider == 'falkordb':
912
+ if driver.provider == GraphProvider.FALKORDB:
958
913
  results = [dict(zip(header, row, strict=True)) for row in results]
959
914
 
960
915
  for result in results:
@@ -987,13 +942,12 @@ async def episode_mentions_reranker(
987
942
  scores: dict[str, float] = {}
988
943
 
989
944
  # Find the shortest path to center node
990
- query = """
991
- UNWIND $node_uuids AS node_uuid
945
+ results, _, _ = await driver.execute_query(
946
+ """
947
+ UNWIND $node_uuids AS node_uuid
992
948
  MATCH (episode:Episodic)-[r:MENTIONS]->(n:Entity {uuid: node_uuid})
993
949
  RETURN count(*) AS score, n.uuid AS uuid
994
- """
995
- results, _, _ = await driver.execute_query(
996
- query,
950
+ """,
997
951
  node_uuids=sorted_uuids,
998
952
  routing_='r',
999
953
  )
@@ -1053,15 +1007,16 @@ def maximal_marginal_relevance(
1053
1007
  async def get_embeddings_for_nodes(
1054
1008
  driver: GraphDriver, nodes: list[EntityNode]
1055
1009
  ) -> dict[str, list[float]]:
1056
- query: LiteralString = """MATCH (n:Entity)
1057
- WHERE n.uuid IN $node_uuids
1058
- RETURN DISTINCT
1059
- n.uuid AS uuid,
1060
- n.name_embedding AS name_embedding
1061
- """
1062
-
1063
1010
  results, _, _ = await driver.execute_query(
1064
- query, node_uuids=[node.uuid for node in nodes], routing_='r'
1011
+ """
1012
+ MATCH (n:Entity)
1013
+ WHERE n.uuid IN $node_uuids
1014
+ RETURN DISTINCT
1015
+ n.uuid AS uuid,
1016
+ n.name_embedding AS name_embedding
1017
+ """,
1018
+ node_uuids=[node.uuid for node in nodes],
1019
+ routing_='r',
1065
1020
  )
1066
1021
 
1067
1022
  embeddings_dict: dict[str, list[float]] = {}
@@ -1077,15 +1032,14 @@ async def get_embeddings_for_nodes(
1077
1032
  async def get_embeddings_for_communities(
1078
1033
  driver: GraphDriver, communities: list[CommunityNode]
1079
1034
  ) -> dict[str, list[float]]:
1080
- query: LiteralString = """MATCH (c:Community)
1081
- WHERE c.uuid IN $community_uuids
1082
- RETURN DISTINCT
1083
- c.uuid AS uuid,
1084
- c.name_embedding AS name_embedding
1085
- """
1086
-
1087
1035
  results, _, _ = await driver.execute_query(
1088
- query,
1036
+ """
1037
+ MATCH (c:Community)
1038
+ WHERE c.uuid IN $community_uuids
1039
+ RETURN DISTINCT
1040
+ c.uuid AS uuid,
1041
+ c.name_embedding AS name_embedding
1042
+ """,
1089
1043
  community_uuids=[community.uuid for community in communities],
1090
1044
  routing_='r',
1091
1045
  )
@@ -1103,15 +1057,14 @@ async def get_embeddings_for_communities(
1103
1057
  async def get_embeddings_for_edges(
1104
1058
  driver: GraphDriver, edges: list[EntityEdge]
1105
1059
  ) -> dict[str, list[float]]:
1106
- query: LiteralString = """MATCH (n:Entity)-[e:RELATES_TO]-(m:Entity)
1107
- WHERE e.uuid IN $edge_uuids
1108
- RETURN DISTINCT
1109
- e.uuid AS uuid,
1110
- e.fact_embedding AS fact_embedding
1111
- """
1112
-
1113
1060
  results, _, _ = await driver.execute_query(
1114
- query,
1061
+ """
1062
+ MATCH (n:Entity)-[e:RELATES_TO]-(m:Entity)
1063
+ WHERE e.uuid IN $edge_uuids
1064
+ RETURN DISTINCT
1065
+ e.uuid AS uuid,
1066
+ e.fact_embedding AS fact_embedding
1067
+ """,
1115
1068
  edge_uuids=[edge.uuid for edge in edges],
1116
1069
  routing_='r',
1117
1070
  )
@@ -25,17 +25,15 @@ from typing_extensions import Any
25
25
  from graphiti_core.driver.driver import GraphDriver, GraphDriverSession
26
26
  from graphiti_core.edges import Edge, EntityEdge, EpisodicEdge, create_entity_edge_embeddings
27
27
  from graphiti_core.embedder import EmbedderClient
28
- from graphiti_core.graph_queries import (
29
- get_entity_edge_save_bulk_query,
30
- get_entity_node_save_bulk_query,
31
- )
32
28
  from graphiti_core.graphiti_types import GraphitiClients
33
29
  from graphiti_core.helpers import normalize_l2, semaphore_gather
34
30
  from graphiti_core.models.edges.edge_db_queries import (
35
31
  EPISODIC_EDGE_SAVE_BULK,
32
+ get_entity_edge_save_bulk_query,
36
33
  )
37
34
  from graphiti_core.models.nodes.node_db_queries import (
38
35
  EPISODIC_NODE_SAVE_BULK,
36
+ get_entity_node_save_bulk_query,
39
37
  )
40
38
  from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode, create_entity_node_embeddings
41
39
  from graphiti_core.utils.maintenance.edge_operations import (
@@ -158,7 +156,7 @@ async def add_nodes_and_edges_bulk_tx(
158
156
  edges.append(edge_data)
159
157
 
160
158
  await tx.run(EPISODIC_NODE_SAVE_BULK, episodes=episodes)
161
- entity_node_save_bulk = get_entity_node_save_bulk_query(nodes, driver.provider)
159
+ entity_node_save_bulk = get_entity_node_save_bulk_query(driver.provider, nodes)
162
160
  await tx.run(entity_node_save_bulk, nodes=nodes)
163
161
  await tx.run(
164
162
  EPISODIC_EDGE_SAVE_BULK, episodic_edges=[edge.model_dump() for edge in episodic_edges]
@@ -171,9 +169,9 @@ async def extract_nodes_and_edges_bulk(
171
169
  clients: GraphitiClients,
172
170
  episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]],
173
171
  edge_type_map: dict[tuple[str, str], list[str]],
174
- entity_types: dict[str, BaseModel] | None = None,
172
+ entity_types: dict[str, type[BaseModel]] | None = None,
175
173
  excluded_entity_types: list[str] | None = None,
176
- edge_types: dict[str, BaseModel] | None = None,
174
+ edge_types: dict[str, type[BaseModel]] | None = None,
177
175
  ) -> tuple[list[list[EntityNode]], list[list[EntityEdge]]]:
178
176
  extracted_nodes_bulk: list[list[EntityNode]] = await semaphore_gather(
179
177
  *[
@@ -204,7 +202,7 @@ async def dedupe_nodes_bulk(
204
202
  clients: GraphitiClients,
205
203
  extracted_nodes: list[list[EntityNode]],
206
204
  episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]],
207
- entity_types: dict[str, BaseModel] | None = None,
205
+ entity_types: dict[str, type[BaseModel]] | None = None,
208
206
  ) -> tuple[dict[str, list[EntityNode]], dict[str, str]]:
209
207
  embedder = clients.embedder
210
208
  min_score = 0.8
@@ -292,7 +290,7 @@ async def dedupe_edges_bulk(
292
290
  extracted_edges: list[list[EntityEdge]],
293
291
  episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]],
294
292
  _entities: list[EntityNode],
295
- edge_types: dict[str, BaseModel],
293
+ edge_types: dict[str, type[BaseModel]],
296
294
  _edge_type_map: dict[tuple[str, str], list[str]],
297
295
  ) -> dict[str, list[EntityEdge]]:
298
296
  embedder = clients.embedder