graphiti-core 0.10.5__py3-none-any.whl → 0.11.1__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of graphiti-core might be problematic. See the package's registry page for more details.

@@ -26,7 +26,7 @@ from typing_extensions import LiteralString
26
26
  from graphiti_core.edges import EntityEdge, get_entity_edge_from_record
27
27
  from graphiti_core.helpers import (
28
28
  DEFAULT_DATABASE,
29
- USE_PARALLEL_RUNTIME,
29
+ RUNTIME_QUERY,
30
30
  lucene_sanitize,
31
31
  normalize_l2,
32
32
  semaphore_gather,
@@ -207,10 +207,6 @@ async def edge_similarity_search(
207
207
  min_score: float = DEFAULT_MIN_SCORE,
208
208
  ) -> list[EntityEdge]:
209
209
  # vector similarity search over embedded facts
210
- runtime_query: LiteralString = (
211
- 'CYPHER runtime = parallel parallelRuntimeSupport=all\n' if USE_PARALLEL_RUNTIME else ''
212
- )
213
-
214
210
  query_params: dict[str, Any] = {}
215
211
 
216
212
  filter_query, filter_params = edge_search_filter_query_constructor(search_filter)
@@ -230,9 +226,10 @@ async def edge_similarity_search(
230
226
  group_filter_query += '\nAND (m.uuid IN [$source_uuid, $target_uuid])'
231
227
 
232
228
  query: LiteralString = (
233
- """
234
- MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity)
235
- """
229
+ RUNTIME_QUERY
230
+ + """
231
+ MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity)
232
+ """
236
233
  + group_filter_query
237
234
  + filter_query
238
235
  + """\nWITH DISTINCT r, vector.similarity.cosine(r.fact_embedding, $search_vector) AS score
@@ -256,7 +253,7 @@ async def edge_similarity_search(
256
253
  )
257
254
 
258
255
  records, _, _ = await driver.execute_query(
259
- runtime_query + query,
256
+ query,
260
257
  query_params,
261
258
  search_vector=search_vector,
262
259
  source_uuid=source_node_uuid,
@@ -344,10 +341,10 @@ async def node_fulltext_search(
344
341
 
345
342
  query = (
346
343
  """
347
- CALL db.index.fulltext.queryNodes("node_name_and_summary", $query, {limit: $limit})
348
- YIELD node AS n, score
349
- WHERE n:Entity
350
- """
344
+ CALL db.index.fulltext.queryNodes("node_name_and_summary", $query, {limit: $limit})
345
+ YIELD node AS n, score
346
+ WHERE n:Entity
347
+ """
351
348
  + filter_query
352
349
  + ENTITY_NODE_RETURN
353
350
  + """
@@ -378,10 +375,6 @@ async def node_similarity_search(
378
375
  min_score: float = DEFAULT_MIN_SCORE,
379
376
  ) -> list[EntityNode]:
380
377
  # vector similarity search over entity names
381
- runtime_query: LiteralString = (
382
- 'CYPHER runtime = parallel parallelRuntimeSupport=all\n' if USE_PARALLEL_RUNTIME else ''
383
- )
384
-
385
378
  query_params: dict[str, Any] = {}
386
379
 
387
380
  group_filter_query: LiteralString = ''
@@ -393,7 +386,7 @@ async def node_similarity_search(
393
386
  query_params.update(filter_params)
394
387
 
395
388
  records, _, _ = await driver.execute_query(
396
- runtime_query
389
+ RUNTIME_QUERY
397
390
  + """
398
391
  MATCH (n:Entity)
399
392
  """
@@ -542,10 +535,6 @@ async def community_similarity_search(
542
535
  min_score=DEFAULT_MIN_SCORE,
543
536
  ) -> list[CommunityNode]:
544
537
  # vector similarity search over entity names
545
- runtime_query: LiteralString = (
546
- 'CYPHER runtime = parallel parallelRuntimeSupport=all\n' if USE_PARALLEL_RUNTIME else ''
547
- )
548
-
549
538
  query_params: dict[str, Any] = {}
550
539
 
551
540
  group_filter_query: LiteralString = ''
@@ -554,7 +543,7 @@ async def community_similarity_search(
554
543
  query_params['group_ids'] = group_ids
555
544
 
556
545
  records, _, _ = await driver.execute_query(
557
- runtime_query
546
+ RUNTIME_QUERY
558
547
  + """
559
548
  MATCH (comm:Community)
560
549
  """
@@ -660,86 +649,223 @@ async def hybrid_node_search(
660
649
 
661
650
  async def get_relevant_nodes(
662
651
  driver: AsyncDriver,
663
- search_filter: SearchFilters,
664
652
  nodes: list[EntityNode],
665
- ) -> list[EntityNode]:
666
- """
667
- Retrieve relevant nodes based on the provided list of EntityNodes.
653
+ search_filter: SearchFilters,
654
+ min_score: float = DEFAULT_MIN_SCORE,
655
+ limit: int = RELEVANT_SCHEMA_LIMIT,
656
+ ) -> list[list[EntityNode]]:
657
+ if len(nodes) == 0:
658
+ return []
668
659
 
669
- This method performs a hybrid search using both the names and embeddings
670
- of the input nodes to find relevant nodes in the graph database.
660
+ group_id = nodes[0].group_id
671
661
 
672
- Parameters
673
- ----------
674
- nodes : list[EntityNode]
675
- A list of EntityNode objects to use as the basis for the search.
676
- driver : AsyncDriver
677
- The Neo4j driver instance for database operations.
662
+ # vector similarity search over entity names
663
+ query_params: dict[str, Any] = {}
678
664
 
679
- Returns
680
- -------
681
- list[EntityNode]
682
- A list of EntityNode objects that are deemed relevant based on the input nodes.
665
+ filter_query, filter_params = node_search_filter_query_constructor(search_filter)
666
+ query_params.update(filter_params)
683
667
 
684
- Notes
685
- -----
686
- This method uses the hybrid_node_search function to perform the search,
687
- which combines fulltext search and vector similarity search.
688
- It extracts the names and name embeddings (if available) from the input nodes
689
- to use as search criteria.
690
- """
691
- relevant_nodes = await hybrid_node_search(
692
- [node.name for node in nodes],
693
- [node.name_embedding for node in nodes if node.name_embedding is not None],
694
- driver,
695
- search_filter,
696
- [node.group_id for node in nodes],
668
+ query = (
669
+ RUNTIME_QUERY
670
+ + """UNWIND $nodes AS node
671
+ MATCH (n:Entity {group_id: $group_id})
672
+ """
673
+ + filter_query
674
+ + """
675
+ WITH node, n, vector.similarity.cosine(n.name_embedding, node.name_embedding) AS score
676
+ WHERE score > $min_score
677
+ WITH node, collect(n)[..$limit] AS top_vector_nodes, collect(n.uuid) AS vector_node_uuids
678
+
679
+ CALL db.index.fulltext.queryNodes("node_name_and_summary", 'group_id:"' + $group_id + '" AND ' + node.name, {limit: $limit})
680
+ YIELD node AS m
681
+ WHERE m.group_id = $group_id
682
+ WITH node, top_vector_nodes, vector_node_uuids, collect(m) AS fulltext_nodes
683
+
684
+ WITH node,
685
+ top_vector_nodes,
686
+ [m IN fulltext_nodes WHERE NOT m.uuid IN vector_node_uuids] AS filtered_fulltext_nodes
687
+
688
+ WITH node, top_vector_nodes + filtered_fulltext_nodes AS combined_nodes
689
+
690
+ UNWIND combined_nodes AS combined_node
691
+ WITH node, collect(DISTINCT combined_node) AS deduped_nodes
692
+
693
+ RETURN
694
+ node.uuid AS search_node_uuid,
695
+ [x IN deduped_nodes | {
696
+ uuid: x.uuid,
697
+ name: x.name,
698
+ name_embedding: x.name_embedding,
699
+ group_id: x.group_id,
700
+ created_at: x.created_at,
701
+ summary: x.summary,
702
+ labels: labels(x),
703
+ attributes: properties(x)
704
+ }] AS matches
705
+ """
706
+ )
707
+
708
+ results, _, _ = await driver.execute_query(
709
+ query,
710
+ query_params,
711
+ nodes=[
712
+ {
713
+ 'uuid': node.uuid,
714
+ 'name': lucene_sanitize(node.name),
715
+ 'name_embedding': node.name_embedding,
716
+ }
717
+ for node in nodes
718
+ ],
719
+ group_id=lucene_sanitize(group_id),
720
+ limit=limit,
721
+ min_score=min_score,
722
+ database_=DEFAULT_DATABASE,
723
+ routing_='r',
697
724
  )
698
725
 
726
+ relevant_nodes_dict: dict[str, list[EntityNode]] = {
727
+ result['search_node_uuid']: [
728
+ get_entity_node_from_record(record) for record in result['matches']
729
+ ]
730
+ for result in results
731
+ }
732
+
733
+ relevant_nodes = [relevant_nodes_dict.get(node.uuid, []) for node in nodes]
734
+
699
735
  return relevant_nodes
700
736
 
701
737
 
702
738
  async def get_relevant_edges(
703
739
  driver: AsyncDriver,
704
740
  edges: list[EntityEdge],
705
- source_node_uuid: str | None,
706
- target_node_uuid: str | None,
741
+ search_filter: SearchFilters,
742
+ min_score: float = DEFAULT_MIN_SCORE,
707
743
  limit: int = RELEVANT_SCHEMA_LIMIT,
708
- ) -> list[EntityEdge]:
709
- start = time()
710
- relevant_edges: list[EntityEdge] = []
711
- relevant_edge_uuids = set()
712
-
713
- results = await semaphore_gather(
714
- *[
715
- edge_similarity_search(
716
- driver,
717
- edge.fact_embedding,
718
- source_node_uuid,
719
- target_node_uuid,
720
- SearchFilters(),
721
- [edge.group_id],
722
- limit,
723
- )
724
- for edge in edges
725
- if edge.fact_embedding is not None
726
- ]
727
- )
744
+ ) -> list[list[EntityEdge]]:
745
+ if len(edges) == 0:
746
+ return []
728
747
 
729
- for result in results:
730
- for edge in result:
731
- if edge.uuid in relevant_edge_uuids:
732
- continue
748
+ query_params: dict[str, Any] = {}
733
749
 
734
- relevant_edge_uuids.add(edge.uuid)
735
- relevant_edges.append(edge)
750
+ filter_query, filter_params = edge_search_filter_query_constructor(search_filter)
751
+ query_params.update(filter_params)
736
752
 
737
- end = time()
738
- logger.debug(f'Found relevant edges: {relevant_edge_uuids} in {(end - start) * 1000} ms')
753
+ query = (
754
+ RUNTIME_QUERY
755
+ + """UNWIND $edges AS edge
756
+ MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
757
+ """
758
+ + filter_query
759
+ + """
760
+ WITH e, edge, vector.similarity.cosine(e.fact_embedding, edge.fact_embedding) AS score
761
+ WHERE score > $min_score
762
+ WITH edge, e, score
763
+ ORDER BY score DESC
764
+ RETURN edge.uuid AS search_edge_uuid,
765
+ collect({
766
+ uuid: e.uuid,
767
+ source_node_uuid: startNode(e).uuid,
768
+ target_node_uuid: endNode(e).uuid,
769
+ created_at: e.created_at,
770
+ name: e.name,
771
+ group_id: e.group_id,
772
+ fact: e.fact,
773
+ fact_embedding: e.fact_embedding,
774
+ episodes: e.episodes,
775
+ expired_at: e.expired_at,
776
+ valid_at: e.valid_at,
777
+ invalid_at: e.invalid_at
778
+ })[..$limit] AS matches
779
+ """
780
+ )
781
+
782
+ results, _, _ = await driver.execute_query(
783
+ query,
784
+ query_params,
785
+ edges=[edge.model_dump() for edge in edges],
786
+ limit=limit,
787
+ min_score=min_score,
788
+ database_=DEFAULT_DATABASE,
789
+ routing_='r',
790
+ )
791
+ relevant_edges_dict: dict[str, list[EntityEdge]] = {
792
+ result['search_edge_uuid']: [
793
+ get_entity_edge_from_record(record) for record in result['matches']
794
+ ]
795
+ for result in results
796
+ }
797
+
798
+ relevant_edges = [relevant_edges_dict.get(edge.uuid, []) for edge in edges]
739
799
 
740
800
  return relevant_edges
741
801
 
742
802
 
803
+ async def get_edge_invalidation_candidates(
804
+ driver: AsyncDriver,
805
+ edges: list[EntityEdge],
806
+ search_filter: SearchFilters,
807
+ min_score: float = DEFAULT_MIN_SCORE,
808
+ limit: int = RELEVANT_SCHEMA_LIMIT,
809
+ ) -> list[list[EntityEdge]]:
810
+ if len(edges) == 0:
811
+ return []
812
+
813
+ query_params: dict[str, Any] = {}
814
+
815
+ filter_query, filter_params = edge_search_filter_query_constructor(search_filter)
816
+ query_params.update(filter_params)
817
+
818
+ query = (
819
+ RUNTIME_QUERY
820
+ + """UNWIND $edges AS edge
821
+ MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
822
+ WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
823
+ """
824
+ + filter_query
825
+ + """
826
+ WITH edge, e, vector.similarity.cosine(e.fact_embedding, edge.fact_embedding) AS score
827
+ WHERE score > $min_score
828
+ WITH edge, e, score
829
+ ORDER BY score DESC
830
+ RETURN edge.uuid AS search_edge_uuid,
831
+ collect({
832
+ uuid: e.uuid,
833
+ source_node_uuid: startNode(e).uuid,
834
+ target_node_uuid: endNode(e).uuid,
835
+ created_at: e.created_at,
836
+ name: e.name,
837
+ group_id: e.group_id,
838
+ fact: e.fact,
839
+ fact_embedding: e.fact_embedding,
840
+ episodes: e.episodes,
841
+ expired_at: e.expired_at,
842
+ valid_at: e.valid_at,
843
+ invalid_at: e.invalid_at
844
+ })[..$limit] AS matches
845
+ """
846
+ )
847
+
848
+ results, _, _ = await driver.execute_query(
849
+ query,
850
+ query_params,
851
+ edges=[edge.model_dump() for edge in edges],
852
+ limit=limit,
853
+ min_score=min_score,
854
+ database_=DEFAULT_DATABASE,
855
+ routing_='r',
856
+ )
857
+ invalidation_edges_dict: dict[str, list[EntityEdge]] = {
858
+ result['search_edge_uuid']: [
859
+ get_entity_edge_from_record(record) for record in result['matches']
860
+ ]
861
+ for result in results
862
+ }
863
+
864
+ invalidation_edges = [invalidation_edges_dict.get(edge.uuid, []) for edge in edges]
865
+
866
+ return invalidation_edges
867
+
868
+
743
869
  # takes in a list of rankings of uuids
744
870
  def rrf(results: list[list[str]], rank_const=1, min_score: float = 0) -> list[str]:
745
871
  scores: dict[str, float] = defaultdict(float)
@@ -26,6 +26,7 @@ from pydantic import BaseModel
26
26
  from typing_extensions import Any
27
27
 
28
28
  from graphiti_core.edges import Edge, EntityEdge, EpisodicEdge
29
+ from graphiti_core.graphiti_types import GraphitiClients
29
30
  from graphiti_core.helpers import DEFAULT_DATABASE, semaphore_gather
30
31
  from graphiti_core.llm_client import LLMClient
31
32
  from graphiti_core.models.edges.edge_db_queries import (
@@ -128,16 +129,18 @@ async def add_nodes_and_edges_bulk_tx(
128
129
 
129
130
  await tx.run(EPISODIC_NODE_SAVE_BULK, episodes=episodes)
130
131
  await tx.run(ENTITY_NODE_SAVE_BULK, nodes=nodes)
131
- await tx.run(EPISODIC_EDGE_SAVE_BULK, episodic_edges=[dict(edge) for edge in episodic_edges])
132
- await tx.run(ENTITY_EDGE_SAVE_BULK, entity_edges=[dict(edge) for edge in entity_edges])
132
+ await tx.run(
133
+ EPISODIC_EDGE_SAVE_BULK, episodic_edges=[edge.model_dump() for edge in episodic_edges]
134
+ )
135
+ await tx.run(ENTITY_EDGE_SAVE_BULK, entity_edges=[edge.model_dump() for edge in entity_edges])
133
136
 
134
137
 
135
138
  async def extract_nodes_and_edges_bulk(
136
- llm_client: LLMClient, episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]]
139
+ clients: GraphitiClients, episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]]
137
140
  ) -> tuple[list[EntityNode], list[EntityEdge], list[EpisodicEdge]]:
138
141
  extracted_nodes_bulk = await semaphore_gather(
139
142
  *[
140
- extract_nodes(llm_client, episode, previous_episodes)
143
+ extract_nodes(clients, episode, previous_episodes)
141
144
  for episode, previous_episodes in episode_tuples
142
145
  ]
143
146
  )
@@ -150,7 +153,7 @@ async def extract_nodes_and_edges_bulk(
150
153
  extracted_edges_bulk = await semaphore_gather(
151
154
  *[
152
155
  extract_edges(
153
- llm_client,
156
+ clients,
154
157
  episode,
155
158
  extracted_nodes_bulk[i],
156
159
  previous_episodes_list[i],
@@ -189,7 +192,7 @@ async def dedupe_nodes_bulk(
189
192
 
190
193
  existing_nodes_chunks: list[list[EntityNode]] = list(
191
194
  await semaphore_gather(
192
- *[get_relevant_nodes(driver, SearchFilters(), node_chunk) for node_chunk in node_chunks]
195
+ *[get_relevant_nodes(driver, node_chunk, SearchFilters()) for node_chunk in node_chunks]
193
196
  )
194
197
  )
195
198
 
@@ -223,7 +226,7 @@ async def dedupe_edges_bulk(
223
226
 
224
227
  relevant_edges_chunks: list[list[EntityEdge]] = list(
225
228
  await semaphore_gather(
226
- *[get_relevant_edges(driver, edge_chunk, None, None) for edge_chunk in edge_chunks]
229
+ *[get_relevant_edges(driver, edge_chunk, SearchFilters()) for edge_chunk in edge_chunks]
227
230
  )
228
231
  )
229
232