graphiti-core 0.2.3__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between these package versions as they appear in their respective public registries.

Potentially problematic release.


This version of graphiti-core may be problematic; the original registry page links to more details.

@@ -15,131 +15,228 @@ limitations under the License.
15
15
  """
16
16
 
17
17
  import logging
18
- from datetime import datetime
19
- from enum import Enum
20
18
  from time import time
21
19
 
22
20
  from neo4j import AsyncDriver
23
- from pydantic import BaseModel, Field
24
21
 
25
22
  from graphiti_core.edges import EntityEdge
23
+ from graphiti_core.errors import SearchRerankerError
26
24
  from graphiti_core.llm_client.config import EMBEDDING_DIM
27
- from graphiti_core.nodes import EntityNode, EpisodicNode
25
+ from graphiti_core.nodes import CommunityNode, EntityNode
26
+ from graphiti_core.search.search_config import (
27
+ DEFAULT_SEARCH_LIMIT,
28
+ CommunityReranker,
29
+ CommunitySearchConfig,
30
+ CommunitySearchMethod,
31
+ EdgeReranker,
32
+ EdgeSearchConfig,
33
+ EdgeSearchMethod,
34
+ NodeReranker,
35
+ NodeSearchConfig,
36
+ NodeSearchMethod,
37
+ SearchConfig,
38
+ SearchResults,
39
+ )
28
40
  from graphiti_core.search.search_utils import (
41
+ community_fulltext_search,
42
+ community_similarity_search,
29
43
  edge_fulltext_search,
30
44
  edge_similarity_search,
31
- get_mentioned_nodes,
32
45
  node_distance_reranker,
46
+ node_fulltext_search,
47
+ node_similarity_search,
33
48
  rrf,
34
49
  )
35
- from graphiti_core.utils import retrieve_episodes
36
- from graphiti_core.utils.maintenance.graph_data_operations import EPISODE_WINDOW_LEN
37
50
 
38
51
  logger = logging.getLogger(__name__)
39
52
 
40
53
 
41
- class SearchMethod(Enum):
42
- cosine_similarity = 'cosine_similarity'
43
- bm25 = 'bm25'
44
-
45
-
46
- class Reranker(Enum):
47
- rrf = 'reciprocal_rank_fusion'
48
- node_distance = 'node_distance'
54
+ async def search(
55
+ driver: AsyncDriver,
56
+ embedder,
57
+ query: str,
58
+ group_ids: list[str | None] | None,
59
+ config: SearchConfig,
60
+ center_node_uuid: str | None = None,
61
+ ) -> SearchResults:
62
+ start = time()
63
+ query = query.replace('\n', ' ')
64
+ # if group_ids is empty, set it to None
65
+ group_ids = group_ids if group_ids else None
66
+ edges = (
67
+ await edge_search(
68
+ driver, embedder, query, group_ids, config.edge_config, center_node_uuid, config.limit
69
+ )
70
+ if config.edge_config is not None
71
+ else []
72
+ )
73
+ nodes = (
74
+ await node_search(
75
+ driver, embedder, query, group_ids, config.node_config, center_node_uuid, config.limit
76
+ )
77
+ if config.node_config is not None
78
+ else []
79
+ )
80
+ communities = (
81
+ await community_search(
82
+ driver, embedder, query, group_ids, config.community_config, config.limit
83
+ )
84
+ if config.community_config is not None
85
+ else []
86
+ )
49
87
 
88
+ results = SearchResults(
89
+ edges=edges[: config.limit],
90
+ nodes=nodes[: config.limit],
91
+ communities=communities[: config.limit],
92
+ )
50
93
 
51
- class SearchConfig(BaseModel):
52
- num_edges: int = Field(default=10)
53
- num_nodes: int = Field(default=10)
54
- num_episodes: int = EPISODE_WINDOW_LEN
55
- group_ids: list[str | None] | None
56
- search_methods: list[SearchMethod]
57
- reranker: Reranker | None
94
+ end = time()
58
95
 
96
+ logger.info(f'search returned context for query {query} in {(end - start) * 1000} ms')
59
97
 
60
- class SearchResults(BaseModel):
61
- episodes: list[EpisodicNode]
62
- nodes: list[EntityNode]
63
- edges: list[EntityEdge]
98
+ return results
64
99
 
65
100
 
66
- async def hybrid_search(
101
+ async def edge_search(
67
102
  driver: AsyncDriver,
68
103
  embedder,
69
104
  query: str,
70
- timestamp: datetime,
71
- config: SearchConfig,
105
+ group_ids: list[str | None] | None,
106
+ config: EdgeSearchConfig,
72
107
  center_node_uuid: str | None = None,
73
- ) -> SearchResults:
74
- start = time()
75
-
76
- episodes = []
77
- nodes = []
78
- edges = []
108
+ limit=DEFAULT_SEARCH_LIMIT,
109
+ ) -> list[EntityEdge]:
110
+ search_results: list[list[EntityEdge]] = []
79
111
 
80
- search_results = []
112
+ if EdgeSearchMethod.bm25 in config.search_methods:
113
+ text_search = await edge_fulltext_search(driver, query, None, None, group_ids, 2 * limit)
114
+ search_results.append(text_search)
81
115
 
82
- if config.num_episodes > 0:
83
- episodes.extend(await retrieve_episodes(driver, timestamp, config.num_episodes))
84
- nodes.extend(await get_mentioned_nodes(driver, episodes))
116
+ if EdgeSearchMethod.cosine_similarity in config.search_methods:
117
+ search_vector = (
118
+ (await embedder.create(input=[query], model='text-embedding-3-small'))
119
+ .data[0]
120
+ .embedding[:EMBEDDING_DIM]
121
+ )
85
122
 
86
- if SearchMethod.bm25 in config.search_methods:
87
- text_search = await edge_fulltext_search(
88
- driver, query, None, None, config.group_ids, 2 * config.num_edges
123
+ similarity_search = await edge_similarity_search(
124
+ driver, search_vector, None, None, group_ids, 2 * limit
89
125
  )
126
+ search_results.append(similarity_search)
127
+
128
+ if len(search_results) > 1 and config.reranker is None:
129
+ raise SearchRerankerError('Multiple edge searches enabled without a reranker')
130
+
131
+ edge_uuid_map = {edge.uuid: edge for result in search_results for edge in result}
132
+
133
+ reranked_uuids: list[str] = []
134
+ if config.reranker == EdgeReranker.rrf:
135
+ search_result_uuids = [[edge.uuid for edge in result] for result in search_results]
136
+
137
+ reranked_uuids = rrf(search_result_uuids)
138
+ elif config.reranker == EdgeReranker.node_distance:
139
+ if center_node_uuid is None:
140
+ raise SearchRerankerError('No center node provided for Node Distance reranker')
141
+
142
+ source_to_edge_uuid_map = {
143
+ edge.source_node_uuid: edge.uuid for result in search_results for edge in result
144
+ }
145
+ source_uuids = [[edge.source_node_uuid for edge in result] for result in search_results]
146
+
147
+ reranked_node_uuids = await node_distance_reranker(driver, source_uuids, center_node_uuid)
148
+
149
+ reranked_uuids = [source_to_edge_uuid_map[node_uuid] for node_uuid in reranked_node_uuids]
150
+
151
+ reranked_edges = [edge_uuid_map[uuid] for uuid in reranked_uuids]
152
+
153
+ return reranked_edges
154
+
155
+
156
+ async def node_search(
157
+ driver: AsyncDriver,
158
+ embedder,
159
+ query: str,
160
+ group_ids: list[str | None] | None,
161
+ config: NodeSearchConfig,
162
+ center_node_uuid: str | None = None,
163
+ limit=DEFAULT_SEARCH_LIMIT,
164
+ ) -> list[EntityNode]:
165
+ search_results: list[list[EntityNode]] = []
166
+
167
+ if NodeSearchMethod.bm25 in config.search_methods:
168
+ text_search = await node_fulltext_search(driver, query, group_ids, 2 * limit)
90
169
  search_results.append(text_search)
91
170
 
92
- if SearchMethod.cosine_similarity in config.search_methods:
93
- query_text = query.replace('\n', ' ')
171
+ if NodeSearchMethod.cosine_similarity in config.search_methods:
94
172
  search_vector = (
95
- (await embedder.create(input=[query_text], model='text-embedding-3-small'))
173
+ (await embedder.create(input=[query], model='text-embedding-3-small'))
96
174
  .data[0]
97
175
  .embedding[:EMBEDDING_DIM]
98
176
  )
99
177
 
100
- similarity_search = await edge_similarity_search(
101
- driver, search_vector, None, None, config.group_ids, 2 * config.num_edges
178
+ similarity_search = await node_similarity_search(
179
+ driver, search_vector, group_ids, 2 * limit
102
180
  )
103
181
  search_results.append(similarity_search)
104
182
 
105
183
  if len(search_results) > 1 and config.reranker is None:
106
- logger.exception('Multiple searches enabled without a reranker')
107
- raise Exception('Multiple searches enabled without a reranker')
184
+ raise SearchRerankerError('Multiple node searches enabled without a reranker')
108
185
 
109
- else:
110
- edge_uuid_map = {}
111
- search_result_uuids = []
186
+ search_result_uuids = [[node.uuid for node in result] for result in search_results]
187
+ node_uuid_map = {node.uuid: node for result in search_results for node in result}
112
188
 
113
- for result in search_results:
114
- result_uuids = []
115
- for edge in result:
116
- result_uuids.append(edge.uuid)
117
- edge_uuid_map[edge.uuid] = edge
189
+ reranked_uuids: list[str] = []
190
+ if config.reranker == NodeReranker.rrf:
191
+ reranked_uuids = rrf(search_result_uuids)
192
+ elif config.reranker == NodeReranker.node_distance:
193
+ if center_node_uuid is None:
194
+ raise SearchRerankerError('No center node provided for Node Distance reranker')
195
+ reranked_uuids = await node_distance_reranker(driver, search_result_uuids, center_node_uuid)
118
196
 
119
- search_result_uuids.append(result_uuids)
197
+ reranked_nodes = [node_uuid_map[uuid] for uuid in reranked_uuids]
120
198
 
121
- search_result_uuids = [[edge.uuid for edge in result] for result in search_results]
199
+ return reranked_nodes
122
200
 
123
- reranked_uuids: list[str] = []
124
- if config.reranker == Reranker.rrf:
125
- reranked_uuids = rrf(search_result_uuids)
126
- elif config.reranker == Reranker.node_distance:
127
- if center_node_uuid is None:
128
- logger.exception('No center node provided for Node Distance reranker')
129
- raise Exception('No center node provided for Node Distance reranker')
130
- reranked_uuids = await node_distance_reranker(
131
- driver, search_result_uuids, center_node_uuid
132
- )
133
-
134
- reranked_edges = [edge_uuid_map[uuid] for uuid in reranked_uuids]
135
- edges.extend(reranked_edges)
136
-
137
- context = SearchResults(
138
- episodes=episodes, nodes=nodes[: config.num_nodes], edges=edges[: config.num_edges]
139
- )
140
201
 
141
- end = time()
202
+ async def community_search(
203
+ driver: AsyncDriver,
204
+ embedder,
205
+ query: str,
206
+ group_ids: list[str | None] | None,
207
+ config: CommunitySearchConfig,
208
+ limit=DEFAULT_SEARCH_LIMIT,
209
+ ) -> list[CommunityNode]:
210
+ search_results: list[list[CommunityNode]] = []
211
+
212
+ if CommunitySearchMethod.bm25 in config.search_methods:
213
+ text_search = await community_fulltext_search(driver, query, group_ids, 2 * limit)
214
+ search_results.append(text_search)
142
215
 
143
- logger.info(f'search returned context for query {query} in {(end - start) * 1000} ms')
216
+ if CommunitySearchMethod.cosine_similarity in config.search_methods:
217
+ search_vector = (
218
+ (await embedder.create(input=[query], model='text-embedding-3-small'))
219
+ .data[0]
220
+ .embedding[:EMBEDDING_DIM]
221
+ )
222
+
223
+ similarity_search = await community_similarity_search(
224
+ driver, search_vector, group_ids, 2 * limit
225
+ )
226
+ search_results.append(similarity_search)
227
+
228
+ if len(search_results) > 1 and config.reranker is None:
229
+ raise SearchRerankerError('Multiple node searches enabled without a reranker')
230
+
231
+ search_result_uuids = [[community.uuid for community in result] for result in search_results]
232
+ community_uuid_map = {
233
+ community.uuid: community for result in search_results for community in result
234
+ }
235
+
236
+ reranked_uuids: list[str] = []
237
+ if config.reranker == CommunityReranker.rrf:
238
+ reranked_uuids = rrf(search_result_uuids)
239
+
240
+ reranked_communities = [community_uuid_map[uuid] for uuid in reranked_uuids]
144
241
 
145
- return context
242
+ return reranked_communities
@@ -0,0 +1,81 @@
1
+ """
2
+ Copyright 2024, Zep Software, Inc.
3
+
4
+ Licensed under the Apache License, Version 2.0 (the "License");
5
+ you may not use this file except in compliance with the License.
6
+ You may obtain a copy of the License at
7
+
8
+ http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ Unless required by applicable law or agreed to in writing, software
11
+ distributed under the License is distributed on an "AS IS" BASIS,
12
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ See the License for the specific language governing permissions and
14
+ limitations under the License.
15
+ """
16
+
17
+ from enum import Enum
18
+
19
+ from pydantic import BaseModel, Field
20
+
21
+ from graphiti_core.edges import EntityEdge
22
+ from graphiti_core.nodes import CommunityNode, EntityNode
23
+
24
+ DEFAULT_SEARCH_LIMIT = 10
25
+
26
+
27
+ class EdgeSearchMethod(Enum):
28
+ cosine_similarity = 'cosine_similarity'
29
+ bm25 = 'bm25'
30
+
31
+
32
+ class NodeSearchMethod(Enum):
33
+ cosine_similarity = 'cosine_similarity'
34
+ bm25 = 'bm25'
35
+
36
+
37
+ class CommunitySearchMethod(Enum):
38
+ cosine_similarity = 'cosine_similarity'
39
+ bm25 = 'bm25'
40
+
41
+
42
+ class EdgeReranker(Enum):
43
+ rrf = 'reciprocal_rank_fusion'
44
+ node_distance = 'node_distance'
45
+
46
+
47
+ class NodeReranker(Enum):
48
+ rrf = 'reciprocal_rank_fusion'
49
+ node_distance = 'node_distance'
50
+
51
+
52
+ class CommunityReranker(Enum):
53
+ rrf = 'reciprocal_rank_fusion'
54
+
55
+
56
+ class EdgeSearchConfig(BaseModel):
57
+ search_methods: list[EdgeSearchMethod]
58
+ reranker: EdgeReranker | None
59
+
60
+
61
+ class NodeSearchConfig(BaseModel):
62
+ search_methods: list[NodeSearchMethod]
63
+ reranker: NodeReranker | None
64
+
65
+
66
+ class CommunitySearchConfig(BaseModel):
67
+ search_methods: list[CommunitySearchMethod]
68
+ reranker: CommunityReranker | None
69
+
70
+
71
+ class SearchConfig(BaseModel):
72
+ edge_config: EdgeSearchConfig | None = Field(default=None)
73
+ node_config: NodeSearchConfig | None = Field(default=None)
74
+ community_config: CommunitySearchConfig | None = Field(default=None)
75
+ limit: int = Field(default=DEFAULT_SEARCH_LIMIT)
76
+
77
+
78
+ class SearchResults(BaseModel):
79
+ edges: list[EntityEdge]
80
+ nodes: list[EntityNode]
81
+ communities: list[CommunityNode]
@@ -0,0 +1,84 @@
1
+ """
2
+ Copyright 2024, Zep Software, Inc.
3
+
4
+ Licensed under the Apache License, Version 2.0 (the "License");
5
+ you may not use this file except in compliance with the License.
6
+ You may obtain a copy of the License at
7
+
8
+ http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ Unless required by applicable law or agreed to in writing, software
11
+ distributed under the License is distributed on an "AS IS" BASIS,
12
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ See the License for the specific language governing permissions and
14
+ limitations under the License.
15
+ """
16
+
17
+ from graphiti_core.search.search_config import (
18
+ CommunityReranker,
19
+ CommunitySearchConfig,
20
+ CommunitySearchMethod,
21
+ EdgeReranker,
22
+ EdgeSearchConfig,
23
+ EdgeSearchMethod,
24
+ NodeReranker,
25
+ NodeSearchConfig,
26
+ NodeSearchMethod,
27
+ SearchConfig,
28
+ )
29
+
30
+ # Performs a hybrid search with rrf reranking over edges, nodes, and communities
31
+ COMBINED_HYBRID_SEARCH_RRF = SearchConfig(
32
+ edge_config=EdgeSearchConfig(
33
+ search_methods=[EdgeSearchMethod.bm25, EdgeSearchMethod.cosine_similarity],
34
+ reranker=EdgeReranker.rrf,
35
+ ),
36
+ node_config=NodeSearchConfig(
37
+ search_methods=[NodeSearchMethod.bm25, NodeSearchMethod.cosine_similarity],
38
+ reranker=NodeReranker.rrf,
39
+ ),
40
+ community_config=CommunitySearchConfig(
41
+ search_methods=[CommunitySearchMethod.bm25, CommunitySearchMethod.cosine_similarity],
42
+ reranker=CommunityReranker.rrf,
43
+ ),
44
+ )
45
+
46
+ # performs a hybrid search over edges with rrf reranking
47
+ EDGE_HYBRID_SEARCH_RRF = SearchConfig(
48
+ edge_config=EdgeSearchConfig(
49
+ search_methods=[EdgeSearchMethod.bm25, EdgeSearchMethod.cosine_similarity],
50
+ reranker=EdgeReranker.rrf,
51
+ )
52
+ )
53
+
54
+ # performs a hybrid search over edges with node distance reranking
55
+ EDGE_HYBRID_SEARCH_NODE_DISTANCE = SearchConfig(
56
+ edge_config=EdgeSearchConfig(
57
+ search_methods=[EdgeSearchMethod.bm25, EdgeSearchMethod.cosine_similarity],
58
+ reranker=EdgeReranker.node_distance,
59
+ )
60
+ )
61
+
62
+ # performs a hybrid search over nodes with rrf reranking
63
+ NODE_HYBRID_SEARCH_RRF = SearchConfig(
64
+ node_config=NodeSearchConfig(
65
+ search_methods=[NodeSearchMethod.bm25, NodeSearchMethod.cosine_similarity],
66
+ reranker=NodeReranker.rrf,
67
+ )
68
+ )
69
+
70
+ # performs a hybrid search over nodes with node distance reranking
71
+ NODE_HYBRID_SEARCH_NODE_DISTANCE = SearchConfig(
72
+ node_config=NodeSearchConfig(
73
+ search_methods=[NodeSearchMethod.bm25, NodeSearchMethod.cosine_similarity],
74
+ reranker=NodeReranker.node_distance,
75
+ )
76
+ )
77
+
78
+ # performs a hybrid search over communities with rrf reranking
79
+ COMMUNITY_HYBRID_SEARCH_RRF = SearchConfig(
80
+ community_config=CommunitySearchConfig(
81
+ search_methods=[CommunitySearchMethod.bm25, CommunitySearchMethod.cosine_similarity],
82
+ reranker=CommunityReranker.rrf,
83
+ )
84
+ )