graphiti-core 0.3.0__py3-none-any.whl → 0.3.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of graphiti-core might be problematic. Click here for more details.

graphiti_core/errors.py CHANGED
@@ -1,3 +1,20 @@
1
+ """
2
+ Copyright 2024, Zep Software, Inc.
3
+
4
+ Licensed under the Apache License, Version 2.0 (the "License");
5
+ you may not use this file except in compliance with the License.
6
+ You may obtain a copy of the License at
7
+
8
+ http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ Unless required by applicable law or agreed to in writing, software
11
+ distributed under the License is distributed on an "AS IS" BASIS,
12
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ See the License for the specific language governing permissions and
14
+ limitations under the License.
15
+ """
16
+
17
+
1
18
  class GraphitiError(Exception):
2
19
  """Base exception class for Graphiti Core."""
3
20
 
@@ -16,3 +33,11 @@ class NodeNotFoundError(GraphitiError):
16
33
  def __init__(self, uuid: str):
17
34
  self.message = f'node {uuid} not found'
18
35
  super().__init__(self.message)
36
+
37
+
38
class SearchRerankerError(GraphitiError):
    """Raised when a search's reranking configuration is invalid.

    The search module raises it when several search methods are enabled without a
    reranker, or when node-distance reranking is requested without a center node.
    """

    def __init__(self, text: str):
        self.message = text
        super().__init__(self.message)
graphiti_core/graphiti.py CHANGED
@@ -24,14 +24,19 @@ from neo4j import AsyncGraphDatabase
24
24
 
25
25
  from graphiti_core.edges import EntityEdge, EpisodicEdge
26
26
  from graphiti_core.llm_client import LLMClient, OpenAIClient
27
- from graphiti_core.llm_client.utils import generate_embedding
28
27
  from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
29
- from graphiti_core.search.search import Reranker, SearchConfig, SearchMethod, hybrid_search
28
+ from graphiti_core.search.search import SearchConfig, search
29
+ from graphiti_core.search.search_config import DEFAULT_SEARCH_LIMIT, SearchResults
30
+ from graphiti_core.search.search_config_recipes import (
31
+ EDGE_HYBRID_SEARCH_NODE_DISTANCE,
32
+ EDGE_HYBRID_SEARCH_RRF,
33
+ NODE_HYBRID_SEARCH_NODE_DISTANCE,
34
+ NODE_HYBRID_SEARCH_RRF,
35
+ )
30
36
  from graphiti_core.search.search_utils import (
31
37
  RELEVANT_SCHEMA_LIMIT,
32
38
  get_relevant_edges,
33
39
  get_relevant_nodes,
34
- hybrid_node_search,
35
40
  )
36
41
  from graphiti_core.utils import (
37
42
  build_episodic_edges,
@@ -548,7 +553,7 @@ class Graphiti:
548
553
  query: str,
549
554
  center_node_uuid: str | None = None,
550
555
  group_ids: list[str | None] | None = None,
551
- num_results=10,
556
+ num_results=DEFAULT_SEARCH_LIMIT,
552
557
  ):
553
558
  """
554
559
  Perform a hybrid search on the knowledge graph.
@@ -564,7 +569,7 @@ class Graphiti:
564
569
  Facts will be reranked based on proximity to this node
565
570
  group_ids : list[str | None] | None, optional
566
571
  The graph partitions to return data from.
567
- num_results : int, optional
572
+ limit : int, optional
568
573
  The maximum number of results to return. Defaults to 10.
569
574
 
570
575
  Returns
@@ -581,21 +586,17 @@ class Graphiti:
581
586
  The search is performed using the current date and time as the reference
582
587
  point for temporal relevance.
583
588
  """
584
- reranker = Reranker.rrf if center_node_uuid is None else Reranker.node_distance
585
- search_config = SearchConfig(
586
- num_episodes=0,
587
- num_edges=num_results,
588
- num_nodes=0,
589
- group_ids=group_ids,
590
- search_methods=[SearchMethod.bm25, SearchMethod.cosine_similarity],
591
- reranker=reranker,
589
+ search_config = (
590
+ EDGE_HYBRID_SEARCH_RRF if center_node_uuid is None else EDGE_HYBRID_SEARCH_NODE_DISTANCE
592
591
  )
592
+ search_config.limit = num_results
593
+
593
594
  edges = (
594
- await hybrid_search(
595
+ await search(
595
596
  self.driver,
596
597
  self.llm_client.get_embedder(),
597
598
  query,
598
- datetime.now(),
599
+ group_ids,
599
600
  search_config,
600
601
  center_node_uuid,
601
602
  )
@@ -606,19 +607,20 @@ class Graphiti:
606
607
  async def _search(
607
608
  self,
608
609
  query: str,
609
- timestamp: datetime,
610
610
  config: SearchConfig,
611
+ group_ids: list[str | None] | None = None,
611
612
  center_node_uuid: str | None = None,
612
- ):
613
- return await hybrid_search(
614
- self.driver, self.llm_client.get_embedder(), query, timestamp, config, center_node_uuid
613
+ ) -> SearchResults:
614
+ return await search(
615
+ self.driver, self.llm_client.get_embedder(), query, group_ids, config, center_node_uuid
615
616
  )
616
617
 
617
618
  async def get_nodes_by_query(
618
619
  self,
619
620
  query: str,
621
+ center_node_uuid: str | None = None,
620
622
  group_ids: list[str | None] | None = None,
621
- limit: int = RELEVANT_SCHEMA_LIMIT,
623
+ limit: int = DEFAULT_SEARCH_LIMIT,
622
624
  ) -> list[EntityNode]:
623
625
  """
624
626
  Retrieve nodes from the graph database based on a text query.
@@ -629,7 +631,9 @@ class Graphiti:
629
631
  Parameters
630
632
  ----------
631
633
  query : str
632
- The text query to search for in the graph.
634
+ The text query to search for in the graph
635
+ center_node_uuid: str, optional
636
+ Facts will be reranked based on proximity to this node.
633
637
  group_ids : list[str | None] | None, optional
634
638
  The graph partitions to return data from.
635
639
  limit : int | None, optional
@@ -655,8 +659,12 @@ class Graphiti:
655
659
  If not specified, a default limit (defined in the search functions) will be used.
656
660
  """
657
661
  embedder = self.llm_client.get_embedder()
658
- query_embedding = await generate_embedding(embedder, query)
659
- relevant_nodes = await hybrid_node_search(
660
- [query], [query_embedding], self.driver, group_ids, limit
662
+ search_config = (
663
+ NODE_HYBRID_SEARCH_RRF if center_node_uuid is None else NODE_HYBRID_SEARCH_NODE_DISTANCE
661
664
  )
662
- return relevant_nodes
665
+ search_config.limit = limit
666
+
667
+ nodes = (
668
+ await search(self.driver, embedder, query, group_ids, search_config, center_node_uuid)
669
+ ).nodes
670
+ return nodes
graphiti_core/helpers.py CHANGED
@@ -1,3 +1,19 @@
1
+ """
2
+ Copyright 2024, Zep Software, Inc.
3
+
4
+ Licensed under the Apache License, Version 2.0 (the "License");
5
+ you may not use this file except in compliance with the License.
6
+ You may obtain a copy of the License at
7
+
8
+ http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ Unless required by applicable law or agreed to in writing, software
11
+ distributed under the License is distributed on an "AS IS" BASIS,
12
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ See the License for the specific language governing permissions and
14
+ limitations under the License.
15
+ """
16
+
1
17
  from datetime import datetime
2
18
 
3
19
  from neo4j import time as neo4j_time
@@ -1,3 +1,20 @@
1
+ """
2
+ Copyright 2024, Zep Software, Inc.
3
+
4
+ Licensed under the Apache License, Version 2.0 (the "License");
5
+ you may not use this file except in compliance with the License.
6
+ You may obtain a copy of the License at
7
+
8
+ http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ Unless required by applicable law or agreed to in writing, software
11
+ distributed under the License is distributed on an "AS IS" BASIS,
12
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ See the License for the specific language governing permissions and
14
+ limitations under the License.
15
+ """
16
+
17
+
1
18
  class RateLimitError(Exception):
2
19
  """Exception raised when the rate limit is exceeded."""
3
20
 
@@ -1,3 +1,19 @@
1
+ """
2
+ Copyright 2024, Zep Software, Inc.
3
+
4
+ Licensed under the Apache License, Version 2.0 (the "License");
5
+ you may not use this file except in compliance with the License.
6
+ You may obtain a copy of the License at
7
+
8
+ http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ Unless required by applicable law or agreed to in writing, software
11
+ distributed under the License is distributed on an "AS IS" BASIS,
12
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ See the License for the specific language governing permissions and
14
+ limitations under the License.
15
+ """
16
+
1
17
  import logging
2
18
  import typing
3
19
  from time import time
@@ -17,6 +33,6 @@ async def generate_embedding(
17
33
  embedding = embedding[:EMBEDDING_DIM]
18
34
 
19
35
  end = time()
20
- logger.debug(f'embedded text of length {len(text)} in {end-start} ms')
36
+ logger.debug(f'embedded text of length {len(text)} in {end - start} ms')
21
37
 
22
38
  return embedding
graphiti_core/nodes.py CHANGED
@@ -175,7 +175,7 @@ class EpisodicNode(Node):
175
175
  e.valid_at AS valid_at,
176
176
  e.uuid AS uuid,
177
177
  e.name AS name,
178
- e.group_id AS group_id
178
+ e.group_id AS group_id,
179
179
  e.source_description AS source_description,
180
180
  e.source AS source
181
181
  """,
@@ -230,7 +230,7 @@ class EntityNode(Node):
230
230
  n.uuid As uuid,
231
231
  n.name AS name,
232
232
  n.name_embedding AS name_embedding,
233
- n.group_id AS group_id
233
+ n.group_id AS group_id,
234
234
  n.created_at AS created_at,
235
235
  n.summary AS summary
236
236
  """,
@@ -307,7 +307,7 @@ class CommunityNode(Node):
307
307
  n.uuid As uuid,
308
308
  n.name AS name,
309
309
  n.name_embedding AS name_embedding,
310
- n.group_id AS group_id
310
+ n.group_id AS group_id,
311
311
  n.created_at AS created_at,
312
312
  n.summary AS summary
313
313
  """,
@@ -329,7 +329,7 @@ class CommunityNode(Node):
329
329
  n.uuid As uuid,
330
330
  n.name AS name,
331
331
  n.name_embedding AS name_embedding,
332
- n.group_id AS group_id
332
+ n.group_id AS group_id,
333
333
  n.created_at AS created_at,
334
334
  n.summary AS summary
335
335
  """,
@@ -1,3 +1,19 @@
1
+ """
2
+ Copyright 2024, Zep Software, Inc.
3
+
4
+ Licensed under the Apache License, Version 2.0 (the "License");
5
+ you may not use this file except in compliance with the License.
6
+ You may obtain a copy of the License at
7
+
8
+ http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ Unless required by applicable law or agreed to in writing, software
11
+ distributed under the License is distributed on an "AS IS" BASIS,
12
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ See the License for the specific language governing permissions and
14
+ limitations under the License.
15
+ """
16
+
1
17
  from typing import Any, Protocol, TypedDict
2
18
 
3
19
  from .models import Message, PromptFunction, PromptVersion
@@ -15,131 +15,228 @@ limitations under the License.
15
15
  """
16
16
 
17
17
  import logging
18
- from datetime import datetime
19
- from enum import Enum
20
18
  from time import time
21
19
 
22
20
  from neo4j import AsyncDriver
23
- from pydantic import BaseModel, Field
24
21
 
25
22
  from graphiti_core.edges import EntityEdge
23
+ from graphiti_core.errors import SearchRerankerError
26
24
  from graphiti_core.llm_client.config import EMBEDDING_DIM
27
- from graphiti_core.nodes import EntityNode, EpisodicNode
25
+ from graphiti_core.nodes import CommunityNode, EntityNode
26
+ from graphiti_core.search.search_config import (
27
+ DEFAULT_SEARCH_LIMIT,
28
+ CommunityReranker,
29
+ CommunitySearchConfig,
30
+ CommunitySearchMethod,
31
+ EdgeReranker,
32
+ EdgeSearchConfig,
33
+ EdgeSearchMethod,
34
+ NodeReranker,
35
+ NodeSearchConfig,
36
+ NodeSearchMethod,
37
+ SearchConfig,
38
+ SearchResults,
39
+ )
28
40
  from graphiti_core.search.search_utils import (
41
+ community_fulltext_search,
42
+ community_similarity_search,
29
43
  edge_fulltext_search,
30
44
  edge_similarity_search,
31
- get_mentioned_nodes,
32
45
  node_distance_reranker,
46
+ node_fulltext_search,
47
+ node_similarity_search,
33
48
  rrf,
34
49
  )
35
- from graphiti_core.utils import retrieve_episodes
36
- from graphiti_core.utils.maintenance.graph_data_operations import EPISODE_WINDOW_LEN
37
50
 
38
51
  logger = logging.getLogger(__name__)
39
52
 
40
53
 
41
- class SearchMethod(Enum):
42
- cosine_similarity = 'cosine_similarity'
43
- bm25 = 'bm25'
44
-
45
-
46
- class Reranker(Enum):
47
- rrf = 'reciprocal_rank_fusion'
48
- node_distance = 'node_distance'
54
async def search(
    driver: AsyncDriver,
    embedder,
    query: str,
    group_ids: list[str | None] | None,
    config: SearchConfig,
    center_node_uuid: str | None = None,
) -> SearchResults:
    """
    Run every search (edges, nodes, communities) enabled in ``config``.

    The enabled searches do not depend on each other, so they are awaited concurrently.

    Parameters
    ----------
    driver : AsyncDriver
        Neo4j driver used by the underlying queries.
    embedder
        Embedding client; the sub-searches call ``embedder.create(input=..., model=...)``.
    query : str
        Search text. Newlines are replaced with spaces before searching.
    group_ids : list[str | None] | None
        Graph partitions to search. An empty list is treated like ``None``.
    config : SearchConfig
        Which result types to search, how to search/rerank each, and the per-type limit.
    center_node_uuid : str | None, optional
        Center node required by node-distance rerankers.

    Returns
    -------
    SearchResults
        At most ``config.limit`` edges, nodes and communities. A result type whose
        sub-config is ``None`` is returned empty.

    Raises
    ------
    SearchRerankerError
        Propagated from the per-type searches when the reranker configuration is invalid.
    """
    import asyncio

    start = time()
    query = query.replace('\n', ' ')
    # if group_ids is empty, set it to None
    group_ids = group_ids if group_ids else None

    async def _disabled() -> list:
        # Stand-in for a result type whose sub-config is None.
        return []

    edges, nodes, communities = await asyncio.gather(
        edge_search(
            driver, embedder, query, group_ids, config.edge_config, center_node_uuid, config.limit
        )
        if config.edge_config is not None
        else _disabled(),
        node_search(
            driver, embedder, query, group_ids, config.node_config, center_node_uuid, config.limit
        )
        if config.node_config is not None
        else _disabled(),
        community_search(driver, embedder, query, group_ids, config.community_config, config.limit)
        if config.community_config is not None
        else _disabled(),
    )

    results = SearchResults(
        edges=edges[: config.limit],
        nodes=nodes[: config.limit],
        communities=communities[: config.limit],
    )

    end = time()

    logger.info(f'search returned context for query {query} in {(end - start) * 1000} ms')

    return results
64
99
 
65
100
 
66
- async def hybrid_search(
101
async def edge_search(
    driver: AsyncDriver,
    embedder,
    query: str,
    group_ids: list[str | None] | None,
    config: EdgeSearchConfig,
    center_node_uuid: str | None = None,
    limit=DEFAULT_SEARCH_LIMIT,
) -> list[EntityEdge]:
    """
    Search entity edges (facts) with the methods in ``config`` and rerank the hits.

    Each enabled method fetches ``2 * limit`` candidates; the caller trims the final list.

    Parameters
    ----------
    driver : AsyncDriver
        Neo4j driver.
    embedder
        Embedding client used for the cosine-similarity method.
    query : str
        Search text.
    group_ids : list[str | None] | None
        Graph partitions to search.
    config : EdgeSearchConfig
        Search methods to run and the reranker used to combine them.
    center_node_uuid : str | None, optional
        Required when ``config.reranker`` is ``EdgeReranker.node_distance``.
    limit : int, optional
        Target number of results.

    Returns
    -------
    list[EntityEdge]
        Edges in reranked order. With no reranker (single search method), the method's
        own ranking is kept.

    Raises
    ------
    SearchRerankerError
        If several methods are enabled without a reranker, or node-distance reranking
        is requested without ``center_node_uuid``.
    """
    search_results: list[list[EntityEdge]] = []

    if EdgeSearchMethod.bm25 in config.search_methods:
        text_search = await edge_fulltext_search(driver, query, None, None, group_ids, 2 * limit)
        search_results.append(text_search)

    if EdgeSearchMethod.cosine_similarity in config.search_methods:
        search_vector = (
            (await embedder.create(input=[query], model='text-embedding-3-small'))
            .data[0]
            .embedding[:EMBEDDING_DIM]
        )

        similarity_search = await edge_similarity_search(
            driver, search_vector, None, None, group_ids, 2 * limit
        )
        search_results.append(similarity_search)

    if len(search_results) > 1 and config.reranker is None:
        raise SearchRerankerError('Multiple edge searches enabled without a reranker')

    edge_uuid_map = {edge.uuid: edge for result in search_results for edge in result}

    reranked_uuids: list[str] = []
    if config.reranker is None:
        # At most one result list here: keep its order instead of returning nothing.
        reranked_uuids = [edge.uuid for result in search_results for edge in result]
    elif config.reranker == EdgeReranker.rrf:
        search_result_uuids = [[edge.uuid for edge in result] for result in search_results]

        reranked_uuids = rrf(search_result_uuids)
    elif config.reranker == EdgeReranker.node_distance:
        if center_node_uuid is None:
            raise SearchRerankerError('No center node provided for Node Distance reranker')

        # Several edges can share a source node, so keep *all* of them per source
        # (deduplicated, in first-seen order) rather than only the last one.
        source_to_edge_uuids: dict[str, list[str]] = {}
        for result in search_results:
            for edge in result:
                edge_uuids = source_to_edge_uuids.setdefault(edge.source_node_uuid, [])
                if edge.uuid not in edge_uuids:
                    edge_uuids.append(edge.uuid)

        source_uuids = [[edge.source_node_uuid for edge in result] for result in search_results]

        reranked_node_uuids = await node_distance_reranker(driver, source_uuids, center_node_uuid)

        seen: set[str] = set()
        for node_uuid in reranked_node_uuids:
            for edge_uuid in source_to_edge_uuids.get(node_uuid, []):
                if edge_uuid not in seen:
                    seen.add(edge_uuid)
                    reranked_uuids.append(edge_uuid)

    return [edge_uuid_map[uuid] for uuid in reranked_uuids]
154
+
155
+
156
async def node_search(
    driver: AsyncDriver,
    embedder,
    query: str,
    group_ids: list[str | None] | None,
    config: NodeSearchConfig,
    center_node_uuid: str | None = None,
    limit=DEFAULT_SEARCH_LIMIT,
) -> list[EntityNode]:
    """
    Search entity nodes with the methods in ``config`` and rerank the hits.

    Each enabled method fetches ``2 * limit`` candidates; the caller trims the final list.

    Parameters
    ----------
    driver : AsyncDriver
        Neo4j driver.
    embedder
        Embedding client used for the cosine-similarity method.
    query : str
        Search text.
    group_ids : list[str | None] | None
        Graph partitions to search.
    config : NodeSearchConfig
        Search methods to run and the reranker used to combine them.
    center_node_uuid : str | None, optional
        Required when ``config.reranker`` is ``NodeReranker.node_distance``.
    limit : int, optional
        Target number of results.

    Returns
    -------
    list[EntityNode]
        Nodes in reranked order. With no reranker (single search method), the method's
        own ranking is kept.

    Raises
    ------
    SearchRerankerError
        If several methods are enabled without a reranker, or node-distance reranking
        is requested without ``center_node_uuid``.
    """
    search_results: list[list[EntityNode]] = []

    if NodeSearchMethod.bm25 in config.search_methods:
        text_search = await node_fulltext_search(driver, query, group_ids, 2 * limit)
        search_results.append(text_search)

    if NodeSearchMethod.cosine_similarity in config.search_methods:
        search_vector = (
            (await embedder.create(input=[query], model='text-embedding-3-small'))
            .data[0]
            .embedding[:EMBEDDING_DIM]
        )

        similarity_search = await node_similarity_search(
            driver, search_vector, group_ids, 2 * limit
        )
        search_results.append(similarity_search)

    if len(search_results) > 1 and config.reranker is None:
        raise SearchRerankerError('Multiple node searches enabled without a reranker')

    search_result_uuids = [[node.uuid for node in result] for result in search_results]
    node_uuid_map = {node.uuid: node for result in search_results for node in result}

    reranked_uuids: list[str] = []
    if config.reranker is None:
        # At most one result list here: keep its order instead of returning nothing.
        reranked_uuids = [uuid for result_uuids in search_result_uuids for uuid in result_uuids]
    elif config.reranker == NodeReranker.rrf:
        reranked_uuids = rrf(search_result_uuids)
    elif config.reranker == NodeReranker.node_distance:
        if center_node_uuid is None:
            raise SearchRerankerError('No center node provided for Node Distance reranker')
        reranked_uuids = await node_distance_reranker(driver, search_result_uuids, center_node_uuid)

    return [node_uuid_map[uuid] for uuid in reranked_uuids]
122
200
 
123
- reranked_uuids: list[str] = []
124
- if config.reranker == Reranker.rrf:
125
- reranked_uuids = rrf(search_result_uuids)
126
- elif config.reranker == Reranker.node_distance:
127
- if center_node_uuid is None:
128
- logger.exception('No center node provided for Node Distance reranker')
129
- raise Exception('No center node provided for Node Distance reranker')
130
- reranked_uuids = await node_distance_reranker(
131
- driver, search_result_uuids, center_node_uuid
132
- )
133
-
134
- reranked_edges = [edge_uuid_map[uuid] for uuid in reranked_uuids]
135
- edges.extend(reranked_edges)
136
-
137
- context = SearchResults(
138
- episodes=episodes, nodes=nodes[: config.num_nodes], edges=edges[: config.num_edges]
139
- )
140
201
 
141
- end = time()
202
async def community_search(
    driver: AsyncDriver,
    embedder,
    query: str,
    group_ids: list[str | None] | None,
    config: CommunitySearchConfig,
    limit=DEFAULT_SEARCH_LIMIT,
) -> list[CommunityNode]:
    """
    Search community nodes with the methods in ``config`` and rerank the hits.

    Each enabled method fetches ``2 * limit`` candidates; the caller trims the final list.

    Parameters
    ----------
    driver : AsyncDriver
        Neo4j driver.
    embedder
        Embedding client used for the cosine-similarity method.
    query : str
        Search text.
    group_ids : list[str | None] | None
        Graph partitions to search.
    config : CommunitySearchConfig
        Search methods to run and the reranker used to combine them.
    limit : int, optional
        Target number of results.

    Returns
    -------
    list[CommunityNode]
        Communities in reranked order. With no reranker (single search method), the
        method's own ranking is kept.

    Raises
    ------
    SearchRerankerError
        If several methods are enabled without a reranker.
    """
    search_results: list[list[CommunityNode]] = []

    if CommunitySearchMethod.bm25 in config.search_methods:
        text_search = await community_fulltext_search(driver, query, group_ids, 2 * limit)
        search_results.append(text_search)

    if CommunitySearchMethod.cosine_similarity in config.search_methods:
        search_vector = (
            (await embedder.create(input=[query], model='text-embedding-3-small'))
            .data[0]
            .embedding[:EMBEDDING_DIM]
        )

        similarity_search = await community_similarity_search(
            driver, search_vector, group_ids, 2 * limit
        )
        search_results.append(similarity_search)

    if len(search_results) > 1 and config.reranker is None:
        raise SearchRerankerError('Multiple community searches enabled without a reranker')

    search_result_uuids = [[community.uuid for community in result] for result in search_results]
    community_uuid_map = {
        community.uuid: community for result in search_results for community in result
    }

    reranked_uuids: list[str] = []
    if config.reranker is None:
        # At most one result list here: keep its order instead of returning nothing.
        reranked_uuids = [uuid for result_uuids in search_result_uuids for uuid in result_uuids]
    elif config.reranker == CommunityReranker.rrf:
        reranked_uuids = rrf(search_result_uuids)

    return [community_uuid_map[uuid] for uuid in reranked_uuids]
@@ -0,0 +1,81 @@
1
+ """
2
+ Copyright 2024, Zep Software, Inc.
3
+
4
+ Licensed under the Apache License, Version 2.0 (the "License");
5
+ you may not use this file except in compliance with the License.
6
+ You may obtain a copy of the License at
7
+
8
+ http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ Unless required by applicable law or agreed to in writing, software
11
+ distributed under the License is distributed on an "AS IS" BASIS,
12
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ See the License for the specific language governing permissions and
14
+ limitations under the License.
15
+ """
16
+
17
+ from enum import Enum
18
+
19
+ from pydantic import BaseModel, Field
20
+
21
+ from graphiti_core.edges import EntityEdge
22
+ from graphiti_core.nodes import CommunityNode, EntityNode
23
+
24
# Default per-result-type cap used by SearchConfig.limit and the search functions.
DEFAULT_SEARCH_LIMIT = 10


class EdgeSearchMethod(Enum):
    """Retrieval strategies that can be enabled for edge (fact) search."""

    cosine_similarity = 'cosine_similarity'  # embedding similarity search
    bm25 = 'bm25'  # full-text index search


class NodeSearchMethod(Enum):
    """Retrieval strategies that can be enabled for entity-node search."""

    cosine_similarity = 'cosine_similarity'  # embedding similarity search
    bm25 = 'bm25'  # full-text index search


class CommunitySearchMethod(Enum):
    """Retrieval strategies that can be enabled for community search."""

    cosine_similarity = 'cosine_similarity'  # embedding similarity search
    bm25 = 'bm25'  # full-text index search


class EdgeReranker(Enum):
    """Ways to combine/reorder edge search results."""

    rrf = 'reciprocal_rank_fusion'
    node_distance = 'node_distance'  # needs a center node


class NodeReranker(Enum):
    """Ways to combine/reorder entity-node search results."""

    rrf = 'reciprocal_rank_fusion'
    node_distance = 'node_distance'  # needs a center node


class CommunityReranker(Enum):
    """Ways to combine/reorder community search results."""

    rrf = 'reciprocal_rank_fusion'
+
55
+
56
class EdgeSearchConfig(BaseModel):
    """How edges (facts) are searched and reranked."""

    # Retrieval methods to run; enabling more than one requires a reranker.
    search_methods: list[EdgeSearchMethod] = Field(...)
    # How results are merged/ordered; None is only usable with a single method.
    reranker: EdgeReranker | None = Field(...)


class NodeSearchConfig(BaseModel):
    """How entity nodes are searched and reranked."""

    # Retrieval methods to run; enabling more than one requires a reranker.
    search_methods: list[NodeSearchMethod] = Field(...)
    # How results are merged/ordered; None is only usable with a single method.
    reranker: NodeReranker | None = Field(...)


class CommunitySearchConfig(BaseModel):
    """How community nodes are searched and reranked."""

    # Retrieval methods to run; enabling more than one requires a reranker.
    search_methods: list[CommunitySearchMethod] = Field(...)
    # How results are merged/ordered; None is only usable with a single method.
    reranker: CommunityReranker | None = Field(...)


class SearchConfig(BaseModel):
    """Top-level search configuration; a None sub-config disables that result type."""

    edge_config: EdgeSearchConfig | None = Field(default=None)
    node_config: NodeSearchConfig | None = Field(default=None)
    community_config: CommunitySearchConfig | None = Field(default=None)
    # Maximum number of results kept for each result type.
    limit: int = Field(default=DEFAULT_SEARCH_LIMIT)


class SearchResults(BaseModel):
    """Results of a search, grouped by result type."""

    edges: list[EntityEdge]
    nodes: list[EntityNode]
    communities: list[CommunityNode]
@@ -0,0 +1,84 @@
1
+ """
2
+ Copyright 2024, Zep Software, Inc.
3
+
4
+ Licensed under the Apache License, Version 2.0 (the "License");
5
+ you may not use this file except in compliance with the License.
6
+ You may obtain a copy of the License at
7
+
8
+ http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ Unless required by applicable law or agreed to in writing, software
11
+ distributed under the License is distributed on an "AS IS" BASIS,
12
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ See the License for the specific language governing permissions and
14
+ limitations under the License.
15
+ """
16
+
17
+ from graphiti_core.search.search_config import (
18
+ CommunityReranker,
19
+ CommunitySearchConfig,
20
+ CommunitySearchMethod,
21
+ EdgeReranker,
22
+ EdgeSearchConfig,
23
+ EdgeSearchMethod,
24
+ NodeReranker,
25
+ NodeSearchConfig,
26
+ NodeSearchMethod,
27
+ SearchConfig,
28
+ )
29
+
30
# Ready-made SearchConfig recipes.
#
# NOTE: these are shared, mutable module-level objects. Copy a recipe before changing it
# (e.g. ``EDGE_HYBRID_SEARCH_RRF.model_copy(update={'limit': 5})``) rather than assigning
# to its fields, or the change leaks into every other user of the recipe.


def _hybrid_edge_config(reranker: EdgeReranker) -> EdgeSearchConfig:
    """Build a fresh edge config running both bm25 and cosine-similarity search."""
    return EdgeSearchConfig(
        search_methods=[EdgeSearchMethod.bm25, EdgeSearchMethod.cosine_similarity],
        reranker=reranker,
    )


def _hybrid_node_config(reranker: NodeReranker) -> NodeSearchConfig:
    """Build a fresh node config running both bm25 and cosine-similarity search."""
    return NodeSearchConfig(
        search_methods=[NodeSearchMethod.bm25, NodeSearchMethod.cosine_similarity],
        reranker=reranker,
    )


def _hybrid_community_config(reranker: CommunityReranker) -> CommunitySearchConfig:
    """Build a fresh community config running both bm25 and cosine-similarity search."""
    return CommunitySearchConfig(
        search_methods=[CommunitySearchMethod.bm25, CommunitySearchMethod.cosine_similarity],
        reranker=reranker,
    )


# Hybrid search with rrf reranking over edges, nodes, and communities.
COMBINED_HYBRID_SEARCH_RRF = SearchConfig(
    edge_config=_hybrid_edge_config(EdgeReranker.rrf),
    node_config=_hybrid_node_config(NodeReranker.rrf),
    community_config=_hybrid_community_config(CommunityReranker.rrf),
)

# Hybrid search over edges only, rrf reranking.
EDGE_HYBRID_SEARCH_RRF = SearchConfig(edge_config=_hybrid_edge_config(EdgeReranker.rrf))

# Hybrid search over edges only, reranked by distance to a center node.
EDGE_HYBRID_SEARCH_NODE_DISTANCE = SearchConfig(
    edge_config=_hybrid_edge_config(EdgeReranker.node_distance)
)

# Hybrid search over nodes only, rrf reranking.
NODE_HYBRID_SEARCH_RRF = SearchConfig(node_config=_hybrid_node_config(NodeReranker.rrf))

# Hybrid search over nodes only, reranked by distance to a center node.
NODE_HYBRID_SEARCH_NODE_DISTANCE = SearchConfig(
    node_config=_hybrid_node_config(NodeReranker.node_distance)
)

# Hybrid search over communities only, rrf reranking.
COMMUNITY_HYBRID_SEARCH_RRF = SearchConfig(
    community_config=_hybrid_community_config(CommunityReranker.rrf)
)
@@ -1,3 +1,19 @@
1
+ """
2
+ Copyright 2024, Zep Software, Inc.
3
+
4
+ Licensed under the Apache License, Version 2.0 (the "License");
5
+ you may not use this file except in compliance with the License.
6
+ You may obtain a copy of the License at
7
+
8
+ http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ Unless required by applicable law or agreed to in writing, software
11
+ distributed under the License is distributed on an "AS IS" BASIS,
12
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ See the License for the specific language governing permissions and
14
+ limitations under the License.
15
+ """
16
+
1
17
  import asyncio
2
18
  import logging
3
19
  import re
@@ -7,7 +23,13 @@ from time import time
7
23
  from neo4j import AsyncDriver, Query
8
24
 
9
25
  from graphiti_core.edges import EntityEdge, get_entity_edge_from_record
10
- from graphiti_core.nodes import EntityNode, EpisodicNode, get_entity_node_from_record
26
+ from graphiti_core.nodes import (
27
+ CommunityNode,
28
+ EntityNode,
29
+ EpisodicNode,
30
+ get_community_node_from_record,
31
+ get_entity_node_from_record,
32
+ )
11
33
 
12
34
  logger = logging.getLogger(__name__)
13
35
 
@@ -35,6 +57,128 @@ async def get_mentioned_nodes(driver: AsyncDriver, episodes: list[EpisodicNode])
35
57
  return nodes
36
58
 
37
59
 
60
async def edge_fulltext_search(
    driver: AsyncDriver,
    query: str,
    source_node_uuid: str | None,
    target_node_uuid: str | None,
    group_ids: list[str | None] | None = None,
    limit=RELEVANT_SCHEMA_LIMIT,
) -> list[EntityEdge]:
    """
    Full-text search over entity edges (facts) using the ``name_and_fact`` index.

    Parameters
    ----------
    driver : AsyncDriver
        Neo4j driver used to run the query.
    query : str
        Free-text query. Non-word, non-space characters are removed and a trailing
        ``~`` is appended (Lucene fuzzy-match syntax for the final term).
    source_node_uuid, target_node_uuid : str | None
        When given, only edges attached to that source / target entity are returned.
    group_ids : list[str | None] | None, optional
        Keep only edges whose ``group_id`` is in this list; ``None`` keeps only edges
        whose ``group_id`` is NULL.
    limit : int, optional
        Maximum number of edges returned.

    Returns
    -------
    list[EntityEdge]
        Matching edges ordered by descending full-text score. Empty when the query
        contains no searchable characters.
    """
    sanitized_query = re.sub(r'[^\w\s]', '', query)
    if not sanitized_query.strip():
        # A bare "~" is not a valid Lucene query; there is nothing to search for.
        return []
    fuzzy_query = sanitized_query + '~'

    # Only anchor an endpoint when its uuid was supplied.
    source_filter = ' {uuid: $source_uuid}' if source_node_uuid is not None else ''
    target_filter = ' {uuid: $target_uuid}' if target_node_uuid is not None else ''

    # The group filter is applied to the relationship in every variant, consistent
    # with edge_similarity_search.
    cypher_query = Query(f"""
    CALL db.index.fulltext.queryRelationships("name_and_fact", $query)
    YIELD relationship AS rel, score
    MATCH (n:Entity{source_filter})-[r {{uuid: rel.uuid}}]-(m:Entity{target_filter})
    WHERE CASE
        WHEN $group_ids IS NULL THEN r.group_id IS NULL
        ELSE r.group_id IN $group_ids
    END
    RETURN
        r.uuid AS uuid,
        r.group_id AS group_id,
        n.uuid AS source_node_uuid,
        m.uuid AS target_node_uuid,
        r.created_at AS created_at,
        r.name AS name,
        r.fact AS fact,
        r.fact_embedding AS fact_embedding,
        r.episodes AS episodes,
        r.expired_at AS expired_at,
        r.valid_at AS valid_at,
        r.invalid_at AS invalid_at
    ORDER BY score DESC LIMIT $limit
    """)

    records, _, _ = await driver.execute_query(
        cypher_query,
        query=fuzzy_query,
        source_uuid=source_node_uuid,
        target_uuid=target_node_uuid,
        group_ids=group_ids,
        limit=limit,
    )

    return [get_entity_edge_from_record(record) for record in records]
180
+
181
+
38
182
  async def edge_similarity_search(
39
183
  driver: AsyncDriver,
40
184
  search_vector: list[float],
@@ -43,13 +187,15 @@ async def edge_similarity_search(
43
187
  group_ids: list[str | None] | None = None,
44
188
  limit: int = RELEVANT_SCHEMA_LIMIT,
45
189
  ) -> list[EntityEdge]:
46
- group_ids = group_ids if group_ids is not None else [None]
47
190
  # vector similarity search over embedded facts
48
191
  query = Query("""
49
192
  CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector)
50
193
  YIELD relationship AS rel, score
51
194
  MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid})
52
- WHERE r.group_id IN $group_ids
195
+ WHERE CASE
196
+ WHEN $group_ids IS NULL THEN r.group_id IS NULL
197
+ ELSE r.group_id IN $group_ids
198
+ END
53
199
  RETURN
54
200
  r.uuid AS uuid,
55
201
  r.group_id AS group_id,
@@ -71,7 +217,10 @@ async def edge_similarity_search(
71
217
  CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector)
72
218
  YIELD relationship AS rel, score
73
219
  MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity)
74
- WHERE r.group_id IN $group_ids
220
+ WHERE CASE
221
+ WHEN $group_ids IS NULL THEN r.group_id IS NULL
222
+ ELSE r.group_id IN $group_ids
223
+ END
75
224
  RETURN
76
225
  r.uuid AS uuid,
77
226
  r.group_id AS group_id,
@@ -92,7 +241,10 @@ async def edge_similarity_search(
92
241
  CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector)
93
242
  YIELD relationship AS rel, score
94
243
  MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid})
95
- WHERE r.group_id IN $group_ids
244
+ WHERE CASE
245
+ WHEN $group_ids IS NULL THEN r.group_id IS NULL
246
+ ELSE r.group_id IN $group_ids
247
+ END
96
248
  RETURN
97
249
  r.uuid AS uuid,
98
250
  r.group_id AS group_id,
@@ -113,7 +265,10 @@ async def edge_similarity_search(
113
265
  CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector)
114
266
  YIELD relationship AS rel, score
115
267
  MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity)
116
- WHERE r.group_id IN $group_ids
268
+ WHERE CASE
269
+ WHEN $group_ids IS NULL THEN r.group_id IS NULL
270
+ ELSE r.group_id IN $group_ids
271
+ END
117
272
  RETURN
118
273
  r.uuid AS uuid,
119
274
  r.group_id AS group_id,
@@ -144,9 +299,44 @@ async def edge_similarity_search(
144
299
  return edges
145
300
 
146
301
 
147
- async def entity_similarity_search(
148
- search_vector: list[float],
302
+ async def node_fulltext_search(
303
+ driver: AsyncDriver,
304
+ query: str,
305
+ group_ids: list[str | None] | None = None,
306
+ limit=RELEVANT_SCHEMA_LIMIT,
307
+ ) -> list[EntityNode]:
308
+ # BM25 search to get top nodes
309
+ fuzzy_query = re.sub(r'[^\w\s]', '', query) + '~'
310
+ records, _, _ = await driver.execute_query(
311
+ """
312
+ CALL db.index.fulltext.queryNodes("name_and_summary", $query)
313
+ YIELD node AS n, score
314
+ WHERE CASE
315
+ WHEN $group_ids IS NULL THEN n.group_id IS NULL
316
+ ELSE n.group_id IN $group_ids
317
+ END
318
+ RETURN
319
+ n.uuid AS uuid,
320
+ n.group_id AS group_id,
321
+ n.name AS name,
322
+ n.name_embedding AS name_embedding,
323
+ n.created_at AS created_at,
324
+ n.summary AS summary
325
+ ORDER BY score DESC
326
+ LIMIT $limit
327
+ """,
328
+ query=fuzzy_query,
329
+ group_ids=group_ids,
330
+ limit=limit,
331
+ )
332
+ nodes = [get_entity_node_from_record(record) for record in records]
333
+
334
+ return nodes
335
+
336
+
337
+ async def node_similarity_search(
149
338
  driver: AsyncDriver,
339
+ search_vector: list[float],
150
340
  group_ids: list[str | None] | None = None,
151
341
  limit=RELEVANT_SCHEMA_LIMIT,
152
342
  ) -> list[EntityNode]:
@@ -176,28 +366,28 @@ async def entity_similarity_search(
176
366
  return nodes
177
367
 
178
368
 
179
- async def entity_fulltext_search(
180
- query: str,
369
+ async def community_fulltext_search(
181
370
  driver: AsyncDriver,
371
+ query: str,
182
372
  group_ids: list[str | None] | None = None,
183
373
  limit=RELEVANT_SCHEMA_LIMIT,
184
- ) -> list[EntityNode]:
374
+ ) -> list[CommunityNode]:
185
375
  group_ids = group_ids if group_ids is not None else [None]
186
376
 
187
- # BM25 search to get top nodes
377
+ # BM25 search to get top communities
188
378
  fuzzy_query = re.sub(r'[^\w\s]', '', query) + '~'
189
379
  records, _, _ = await driver.execute_query(
190
380
  """
191
- CALL db.index.fulltext.queryNodes("name_and_summary", $query)
192
- YIELD node AS n, score
193
- MATCH (n WHERE n.group_id in $group_ids)
381
+ CALL db.index.fulltext.queryNodes("community_name", $query)
382
+ YIELD node AS comm, score
383
+ MATCH (comm WHERE comm.group_id in $group_ids)
194
384
  RETURN
195
- n.uuid AS uuid,
196
- n.group_id AS group_id,
197
- n.name AS name,
198
- n.name_embedding AS name_embedding,
199
- n.created_at AS created_at,
200
- n.summary AS summary
385
+ comm.uuid AS uuid,
386
+ comm.group_id AS group_id,
387
+ comm.name AS name,
388
+ comm.name_embedding AS name_embedding,
389
+ comm.created_at AS created_at,
390
+ comm.summary AS summary
201
391
  ORDER BY score DESC
202
392
  LIMIT $limit
203
393
  """,
@@ -205,121 +395,41 @@ async def entity_fulltext_search(
205
395
  group_ids=group_ids,
206
396
  limit=limit,
207
397
  )
208
- nodes = [get_entity_node_from_record(record) for record in records]
398
+ communities = [get_community_node_from_record(record) for record in records]
209
399
 
210
- return nodes
400
+ return communities
211
401
 
212
402
 
213
- async def edge_fulltext_search(
403
+ async def community_similarity_search(
214
404
  driver: AsyncDriver,
215
- query: str,
216
- source_node_uuid: str | None,
217
- target_node_uuid: str | None,
405
+ search_vector: list[float],
218
406
  group_ids: list[str | None] | None = None,
219
407
  limit=RELEVANT_SCHEMA_LIMIT,
220
- ) -> list[EntityEdge]:
408
+ ) -> list[CommunityNode]:
221
409
  group_ids = group_ids if group_ids is not None else [None]
222
410
 
223
- # fulltext search over facts
224
- cypher_query = Query("""
225
- CALL db.index.fulltext.queryRelationships("name_and_fact", $query)
226
- YIELD relationship AS rel, score
227
- MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid})
228
- WHERE r.group_id IN $group_ids
229
- RETURN
230
- r.uuid AS uuid,
231
- r.group_id AS group_id,
232
- n.uuid AS source_node_uuid,
233
- m.uuid AS target_node_uuid,
234
- r.created_at AS created_at,
235
- r.name AS name,
236
- r.fact AS fact,
237
- r.fact_embedding AS fact_embedding,
238
- r.episodes AS episodes,
239
- r.expired_at AS expired_at,
240
- r.valid_at AS valid_at,
241
- r.invalid_at AS invalid_at
242
- ORDER BY score DESC LIMIT $limit
243
- """)
244
-
245
- if source_node_uuid is None and target_node_uuid is None:
246
- cypher_query = Query("""
247
- CALL db.index.fulltext.queryRelationships("name_and_fact", $query)
248
- YIELD relationship AS rel, score
249
- MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity)
250
- WHERE r.group_id IN $group_ids
251
- RETURN
252
- r.uuid AS uuid,
253
- r.group_id AS group_id,
254
- n.uuid AS source_node_uuid,
255
- m.uuid AS target_node_uuid,
256
- r.created_at AS created_at,
257
- r.name AS name,
258
- r.fact AS fact,
259
- r.fact_embedding AS fact_embedding,
260
- r.episodes AS episodes,
261
- r.expired_at AS expired_at,
262
- r.valid_at AS valid_at,
263
- r.invalid_at AS invalid_at
264
- ORDER BY score DESC LIMIT $limit
265
- """)
266
- elif source_node_uuid is None:
267
- cypher_query = Query("""
268
- CALL db.index.fulltext.queryRelationships("name_and_fact", $query)
269
- YIELD relationship AS rel, score
270
- MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid})
271
- WHERE r.group_id IN $group_ids
272
- RETURN
273
- r.uuid AS uuid,
274
- r.group_id AS group_id,
275
- n.uuid AS source_node_uuid,
276
- m.uuid AS target_node_uuid,
277
- r.created_at AS created_at,
278
- r.name AS name,
279
- r.fact AS fact,
280
- r.fact_embedding AS fact_embedding,
281
- r.episodes AS episodes,
282
- r.expired_at AS expired_at,
283
- r.valid_at AS valid_at,
284
- r.invalid_at AS invalid_at
285
- ORDER BY score DESC LIMIT $limit
286
- """)
287
- elif target_node_uuid is None:
288
- cypher_query = Query("""
289
- CALL db.index.fulltext.queryRelationships("name_and_fact", $query)
290
- YIELD relationship AS rel, score
291
- MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity)
292
- WHERE r.group_id IN $group_ids
293
- RETURN
294
- r.uuid AS uuid,
295
- r.group_id AS group_id,
296
- n.uuid AS source_node_uuid,
297
- m.uuid AS target_node_uuid,
298
- r.created_at AS created_at,
299
- r.name AS name,
300
- r.fact AS fact,
301
- r.fact_embedding AS fact_embedding,
302
- r.episodes AS episodes,
303
- r.expired_at AS expired_at,
304
- r.valid_at AS valid_at,
305
- r.invalid_at AS invalid_at
306
- ORDER BY score DESC LIMIT $limit
307
- """)
308
-
309
- fuzzy_query = re.sub(r'[^\w\s]', '', query) + '~'
310
-
411
+ # vector similarity search over entity names
311
412
  records, _, _ = await driver.execute_query(
312
- cypher_query,
313
- query=fuzzy_query,
314
- source_uuid=source_node_uuid,
315
- target_uuid=target_node_uuid,
413
+ """
414
+ CALL db.index.vector.queryNodes("community_name_embedding", $limit, $search_vector)
415
+ YIELD node AS comm, score
416
+ MATCH (comm WHERE comm.group_id IN $group_ids)
417
+ RETURN
418
+ comm.uuid As uuid,
419
+ comm.group_id AS group_id,
420
+ comm.name AS name,
421
+ comm.name_embedding AS name_embedding,
422
+ comm.created_at AS created_at,
423
+ comm.summary AS summary
424
+ ORDER BY score DESC
425
+ """,
426
+ search_vector=search_vector,
316
427
  group_ids=group_ids,
317
428
  limit=limit,
318
429
  )
430
+ communities = [get_community_node_from_record(record) for record in records]
319
431
 
320
- edges = [get_entity_edge_from_record(record) for record in records]
321
-
322
- return edges
432
+ return communities
323
433
 
324
434
 
325
435
  async def hybrid_node_search(
@@ -371,8 +481,8 @@ async def hybrid_node_search(
371
481
 
372
482
  results: list[list[EntityNode]] = list(
373
483
  await asyncio.gather(
374
- *[entity_fulltext_search(q, driver, group_ids, 2 * limit) for q in queries],
375
- *[entity_similarity_search(e, driver, group_ids, 2 * limit) for e in embeddings],
484
+ *[node_fulltext_search(driver, q, group_ids, 2 * limit) for q in queries],
485
+ *[node_similarity_search(driver, e, group_ids, 2 * limit) for e in embeddings],
376
486
  )
377
487
  )
378
488
 
@@ -490,24 +600,23 @@ def rrf(results: list[list[str]], rank_const=1) -> list[str]:
490
600
 
491
601
 
492
602
  async def node_distance_reranker(
493
- driver: AsyncDriver, results: list[list[str]], center_node_uuid: str
603
+ driver: AsyncDriver, node_uuids: list[list[str]], center_node_uuid: str
494
604
  ) -> list[str]:
495
605
  # use rrf as a preliminary ranker
496
- sorted_uuids = rrf(results)
606
+ sorted_uuids = rrf(node_uuids)
497
607
  scores: dict[str, float] = {}
498
608
 
499
609
  # Find the shortest path to center node
500
610
  query = Query("""
501
- MATCH (source:Entity)-[r:RELATES_TO {uuid: $edge_uuid}]->(target:Entity)
502
- MATCH p = SHORTEST 1 (center:Entity {uuid: $center_uuid})-[:RELATES_TO]-+(n:Entity {uuid: source.uuid})
503
- RETURN length(p) AS score, source.uuid AS source_uuid, target.uuid AS target_uuid
611
+ MATCH p = SHORTEST 1 (center:Entity {uuid: $center_uuid})-[:RELATES_TO]-+(n:Entity {uuid: $node_uuid})
612
+ RETURN length(p) AS score
504
613
  """)
505
614
 
506
615
  path_results = await asyncio.gather(
507
616
  *[
508
617
  driver.execute_query(
509
618
  query,
510
- edge_uuid=uuid,
619
+ node_uuid=uuid,
511
620
  center_uuid=center_node_uuid,
512
621
  )
513
622
  for uuid in sorted_uuids
@@ -518,15 +627,8 @@ async def node_distance_reranker(
518
627
  records = result[0]
519
628
  record = records[0] if len(records) > 0 else None
520
629
  distance: float = record['score'] if record is not None else float('inf')
521
- if record is not None and (
522
- record['source_uuid'] == center_node_uuid or record['target_uuid'] == center_node_uuid
523
- ):
524
- distance = 0
525
-
526
- if uuid in scores:
527
- scores[uuid] = min(distance, scores[uuid])
528
- else:
529
- scores[uuid] = distance
630
+ distance = 0 if uuid == center_node_uuid else distance
631
+ scores[uuid] = distance
530
632
 
531
633
  # rerank on shortest distance
532
634
  sorted_uuids.sort(key=lambda cur_uuid: scores[cur_uuid])
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: graphiti-core
3
- Version: 0.3.0
3
+ Version: 0.3.2
4
4
  Summary: A temporal graph building library
5
5
  License: Apache-2.0
6
6
  Author: Paul Paliychuk
@@ -1,21 +1,21 @@
1
1
  graphiti_core/__init__.py,sha256=e5SWFkRiaUwfprYIeIgVIh7JDedNiloZvd3roU-0aDY,55
2
2
  graphiti_core/edges.py,sha256=bKzlrIrzofggRckgL3RA3MKLTgCKwkPVMB-tVA6Vd_A,9130
3
- graphiti_core/errors.py,sha256=VnirEGvivs0DGdBapy8nGmwLwOeO8CdkFLi4s0dM4W0,499
4
- graphiti_core/graphiti.py,sha256=8Cs--MmvSYTziI73c5NjvX4KuB4p8PbzdeaQouj8L-A,24312
5
- graphiti_core/helpers.py,sha256=EAeC3RrcecjiTGN2vxergN5RHTy2_jhFXA5PQVT3toU,200
3
+ graphiti_core/errors.py,sha256=BOwL0VVnoUuMjK3EUYKvqefsbsYhRhcKcVWXaX9hanw,1259
4
+ graphiti_core/graphiti.py,sha256=ViKKrF84VENTIR6WFtGpZ3FCZqC9B9__lKVsiXPCjV8,24563
5
+ graphiti_core/helpers.py,sha256=qQqZJBkc_z5f3x5axPfCKK_QHLRybvWNFb57WXNENfQ,769
6
6
  graphiti_core/llm_client/__init__.py,sha256=PA80TSMeX-sUXITXEAxMDEt3gtfZgcJrGJUcyds1mSo,207
7
7
  graphiti_core/llm_client/anthropic_client.py,sha256=3zsOkewLFxBhKe90OkmpfkvrcwykgGwRoqII05Jno_Q,2410
8
8
  graphiti_core/llm_client/client.py,sha256=7-gEhOKxjdkllV_xS2Ikn-a4QzK9NE63CANnZgdn3VY,3438
9
9
  graphiti_core/llm_client/config.py,sha256=d1oZ9tt7QBQlbph7v-0HjItb6otK9_-IwF8kkRYL2rc,2359
10
- graphiti_core/llm_client/errors.py,sha256=wYz7pJDC7ppwYpoHpIJPcu--If7NObDy6vJSu__jtDc,244
10
+ graphiti_core/llm_client/errors.py,sha256=-qlWwv1X-UjfsFIiNl-7yJIYvPwi7z8srVRfX4-s6uk,814
11
11
  graphiti_core/llm_client/groq_client.py,sha256=clQvQ9-zCRoqK9NGMx9Icyl4lUXmM70lZgVquXikxBo,2334
12
12
  graphiti_core/llm_client/openai_client.py,sha256=VqzWdSrHuNfF2l1aRDua00NHhtP9UR7VNtLcu8h9vLc,2343
13
- graphiti_core/llm_client/utils.py,sha256=H8-Kwa5SyvIYDNIas8O4bHJ6jsOL49li44VoDEMyauY,555
14
- graphiti_core/nodes.py,sha256=b3R06tFdmKriTwe7evXa7K8uwjuW33mOApXiW404aKU,12150
13
+ graphiti_core/llm_client/utils.py,sha256=0KT4XxTVw3c0__HLDj3F8kNR4K_qY0hT0TH-pQZ_IZw,1126
14
+ graphiti_core/nodes.py,sha256=w2cbyA7g_0eSm7axFWraG4opYxQz7-mPCxkcNHdefJY,12154
15
15
  graphiti_core/prompts/__init__.py,sha256=EA-x9xUki9l8wnu2l8ek_oNf75-do5tq5hVq7Zbv8Kw,101
16
16
  graphiti_core/prompts/dedupe_edges.py,sha256=DUNHdIudj50FAjkla4nc68tSFSD2yjmYHBw-Bb7ph20,6529
17
17
  graphiti_core/prompts/dedupe_nodes.py,sha256=BZ9S-PB9SSGjc5Oo8ivdgA6rZx3OGOFhKtwrBlQ0bm0,7269
18
- graphiti_core/prompts/extract_edge_dates.py,sha256=G-Gnsyt8pYx9lFJEwlIsTdADF3ESDe26WSsrAGmvlYk,3086
18
+ graphiti_core/prompts/extract_edge_dates.py,sha256=oOCR8mC_3gI1bumrmIjUbkNO-WTuLTXXAalPDYnDXeM,3655
19
19
  graphiti_core/prompts/extract_edges.py,sha256=AQ8xYbAv_RKXAT6WMwXs1_GvUdLtM_lhLNbt3SkOAmk,5348
20
20
  graphiti_core/prompts/extract_nodes.py,sha256=VIr0Nh0mSiodI3iGOQFszh7DOni4mufOKJDuGkMysl8,6889
21
21
  graphiti_core/prompts/invalidate_edges.py,sha256=8SHt3iPTdmqk8A52LxgdMtI39w4USKqVDMOS2i6lRQ4,4342
@@ -24,8 +24,10 @@ graphiti_core/prompts/models.py,sha256=cvx_Bv5RMFUD_5IUawYrbpOKLPHogai7_bm7YXrSz
24
24
  graphiti_core/prompts/summarize_nodes.py,sha256=FLuZpGTABgcxuIDkx_IKH115nHEw0rIaFhcGlWveAMc,2357
25
25
  graphiti_core/py.typed,sha256=vlmmzQOt7bmeQl9L3XJP4W6Ry0iiELepnOrinKz5KQg,79
26
26
  graphiti_core/search/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
27
- graphiti_core/search/search.py,sha256=cr1-syRlRdijnLtbuQYWy_2G1CtAeIaz6BQ2kl_6FrY,4535
28
- graphiti_core/search/search_utils.py,sha256=jwFoN1z6XKpHw5xdReJpblA4eBDiLIlY4lR8rCgDP5o,19502
27
+ graphiti_core/search/search.py,sha256=BtyZBhwAt_IbU8dqm-DeRAIovkFDTdFly5IBGxs4yy8,8101
28
+ graphiti_core/search/search_config.py,sha256=nOLU_k2p_sM0-JBYci8rWhc-mERv8uWkDn0GOYqZjL8,2081
29
+ graphiti_core/search/search_config_recipes.py,sha256=CJIhYjXPgSm20cY9IkXQxArCgwLvjz-4xB7mr4NylWg,2857
30
+ graphiti_core/search/search_utils.py,sha256=vFxLMt0CB_1Avn32d1PFsJPtJ26MCEdoq-BSBx2uCGQ,22802
29
31
  graphiti_core/utils/__init__.py,sha256=cJAcMnBZdHBQmWrZdU1PQ1YmaL75bhVUkyVpIPuOyns,260
30
32
  graphiti_core/utils/bulk_utils.py,sha256=JtoYTZPCigPa3n2E43Oe7QhFZRTA_QKNGy1jVgklHag,12614
31
33
  graphiti_core/utils/maintenance/__init__.py,sha256=4b9sfxqyFZMLwxxS2lnQ6_wBr3xrJRIqfAWOidK8EK0,388
@@ -35,7 +37,7 @@ graphiti_core/utils/maintenance/graph_data_operations.py,sha256=d27efEVLvQTmoKE7
35
37
  graphiti_core/utils/maintenance/node_operations.py,sha256=WXJFU1AprYjmHSq6rZhTIX4JFHtF5W9LbzA2Tfksp5Q,8838
36
38
  graphiti_core/utils/maintenance/temporal_operations.py,sha256=BzfGDm96w4HcUEsaWTHUBt5S8dNmDQL1eX6AuBL-XFM,8135
37
39
  graphiti_core/utils/maintenance/utils.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
38
- graphiti_core-0.3.0.dist-info/LICENSE,sha256=KCUwCyDXuVEgmDWkozHyniRyWjnWUWjkuDHfU6o3JlA,11325
39
- graphiti_core-0.3.0.dist-info/METADATA,sha256=L96LEC27fgAsQ_uTtol9fagh5fykQfJxRd0fIiWcrus,9323
40
- graphiti_core-0.3.0.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
41
- graphiti_core-0.3.0.dist-info/RECORD,,
40
+ graphiti_core-0.3.2.dist-info/LICENSE,sha256=KCUwCyDXuVEgmDWkozHyniRyWjnWUWjkuDHfU6o3JlA,11325
41
+ graphiti_core-0.3.2.dist-info/METADATA,sha256=zmcAQu2r7J1odYWWcYQ91fpahTbTu4YZnxEXN-1Qge0,9323
42
+ graphiti_core-0.3.2.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
43
+ graphiti_core-0.3.2.dist-info/RECORD,,