graphiti-core 0.3.0__tar.gz → 0.3.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of graphiti-core might be problematic. Click here for more details.

Files changed (48)
  1. {graphiti_core-0.3.0 → graphiti_core-0.3.1}/PKG-INFO +1 -1
  2. graphiti_core-0.3.1/graphiti_core/errors.py +43 -0
  3. {graphiti_core-0.3.0 → graphiti_core-0.3.1}/graphiti_core/graphiti.py +33 -25
  4. graphiti_core-0.3.1/graphiti_core/helpers.py +23 -0
  5. graphiti_core-0.3.1/graphiti_core/llm_client/errors.py +23 -0
  6. graphiti_core-0.3.1/graphiti_core/llm_client/utils.py +38 -0
  7. {graphiti_core-0.3.0 → graphiti_core-0.3.1}/graphiti_core/prompts/extract_edge_dates.py +16 -0
  8. graphiti_core-0.3.1/graphiti_core/search/search.py +242 -0
  9. graphiti_core-0.3.1/graphiti_core/search/search_config.py +81 -0
  10. graphiti_core-0.3.1/graphiti_core/search/search_config_recipes.py +84 -0
  11. {graphiti_core-0.3.0 → graphiti_core-0.3.1}/graphiti_core/search/search_utils.py +242 -140
  12. {graphiti_core-0.3.0 → graphiti_core-0.3.1}/pyproject.toml +1 -1
  13. graphiti_core-0.3.0/graphiti_core/errors.py +0 -18
  14. graphiti_core-0.3.0/graphiti_core/helpers.py +0 -7
  15. graphiti_core-0.3.0/graphiti_core/llm_client/errors.py +0 -6
  16. graphiti_core-0.3.0/graphiti_core/llm_client/utils.py +0 -22
  17. graphiti_core-0.3.0/graphiti_core/search/search.py +0 -145
  18. {graphiti_core-0.3.0 → graphiti_core-0.3.1}/LICENSE +0 -0
  19. {graphiti_core-0.3.0 → graphiti_core-0.3.1}/README.md +0 -0
  20. {graphiti_core-0.3.0 → graphiti_core-0.3.1}/graphiti_core/__init__.py +0 -0
  21. {graphiti_core-0.3.0 → graphiti_core-0.3.1}/graphiti_core/edges.py +0 -0
  22. {graphiti_core-0.3.0 → graphiti_core-0.3.1}/graphiti_core/llm_client/__init__.py +0 -0
  23. {graphiti_core-0.3.0 → graphiti_core-0.3.1}/graphiti_core/llm_client/anthropic_client.py +0 -0
  24. {graphiti_core-0.3.0 → graphiti_core-0.3.1}/graphiti_core/llm_client/client.py +0 -0
  25. {graphiti_core-0.3.0 → graphiti_core-0.3.1}/graphiti_core/llm_client/config.py +0 -0
  26. {graphiti_core-0.3.0 → graphiti_core-0.3.1}/graphiti_core/llm_client/groq_client.py +0 -0
  27. {graphiti_core-0.3.0 → graphiti_core-0.3.1}/graphiti_core/llm_client/openai_client.py +0 -0
  28. {graphiti_core-0.3.0 → graphiti_core-0.3.1}/graphiti_core/nodes.py +0 -0
  29. {graphiti_core-0.3.0 → graphiti_core-0.3.1}/graphiti_core/prompts/__init__.py +0 -0
  30. {graphiti_core-0.3.0 → graphiti_core-0.3.1}/graphiti_core/prompts/dedupe_edges.py +0 -0
  31. {graphiti_core-0.3.0 → graphiti_core-0.3.1}/graphiti_core/prompts/dedupe_nodes.py +0 -0
  32. {graphiti_core-0.3.0 → graphiti_core-0.3.1}/graphiti_core/prompts/extract_edges.py +0 -0
  33. {graphiti_core-0.3.0 → graphiti_core-0.3.1}/graphiti_core/prompts/extract_nodes.py +0 -0
  34. {graphiti_core-0.3.0 → graphiti_core-0.3.1}/graphiti_core/prompts/invalidate_edges.py +0 -0
  35. {graphiti_core-0.3.0 → graphiti_core-0.3.1}/graphiti_core/prompts/lib.py +0 -0
  36. {graphiti_core-0.3.0 → graphiti_core-0.3.1}/graphiti_core/prompts/models.py +0 -0
  37. {graphiti_core-0.3.0 → graphiti_core-0.3.1}/graphiti_core/prompts/summarize_nodes.py +0 -0
  38. {graphiti_core-0.3.0 → graphiti_core-0.3.1}/graphiti_core/py.typed +0 -0
  39. {graphiti_core-0.3.0 → graphiti_core-0.3.1}/graphiti_core/search/__init__.py +0 -0
  40. {graphiti_core-0.3.0 → graphiti_core-0.3.1}/graphiti_core/utils/__init__.py +0 -0
  41. {graphiti_core-0.3.0 → graphiti_core-0.3.1}/graphiti_core/utils/bulk_utils.py +0 -0
  42. {graphiti_core-0.3.0 → graphiti_core-0.3.1}/graphiti_core/utils/maintenance/__init__.py +0 -0
  43. {graphiti_core-0.3.0 → graphiti_core-0.3.1}/graphiti_core/utils/maintenance/community_operations.py +0 -0
  44. {graphiti_core-0.3.0 → graphiti_core-0.3.1}/graphiti_core/utils/maintenance/edge_operations.py +0 -0
  45. {graphiti_core-0.3.0 → graphiti_core-0.3.1}/graphiti_core/utils/maintenance/graph_data_operations.py +0 -0
  46. {graphiti_core-0.3.0 → graphiti_core-0.3.1}/graphiti_core/utils/maintenance/node_operations.py +0 -0
  47. {graphiti_core-0.3.0 → graphiti_core-0.3.1}/graphiti_core/utils/maintenance/temporal_operations.py +0 -0
  48. {graphiti_core-0.3.0 → graphiti_core-0.3.1}/graphiti_core/utils/maintenance/utils.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: graphiti-core
3
- Version: 0.3.0
3
+ Version: 0.3.1
4
4
  Summary: A temporal graph building library
5
5
  License: Apache-2.0
6
6
  Author: Paul Paliychuk
@@ -0,0 +1,43 @@
1
+ """
2
+ Copyright 2024, Zep Software, Inc.
3
+
4
+ Licensed under the Apache License, Version 2.0 (the "License");
5
+ you may not use this file except in compliance with the License.
6
+ You may obtain a copy of the License at
7
+
8
+ http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ Unless required by applicable law or agreed to in writing, software
11
+ distributed under the License is distributed on an "AS IS" BASIS,
12
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ See the License for the specific language governing permissions and
14
+ limitations under the License.
15
+ """
16
+
17
+
18
class GraphitiError(Exception):
    """Common base class for the exceptions defined by Graphiti Core.

    Catching ``GraphitiError`` handles every library-specific error type that
    derives from it.
    """
20
+
21
+
22
class EdgeNotFoundError(GraphitiError):
    """Signals that no edge exists for the requested UUID.

    The formatted text is stored on ``self.message`` and also passed to the
    base exception, so ``str(err)`` returns the same text.
    """

    def __init__(self, uuid: str):
        text = f'edge {uuid} not found'
        self.message = text
        super().__init__(text)
28
+
29
+
30
class NodeNotFoundError(GraphitiError):
    """Signals that no node exists for the requested UUID.

    The formatted text is stored on ``self.message`` and also passed to the
    base exception, so ``str(err)`` returns the same text.
    """

    def __init__(self, uuid: str):
        text = f'node {uuid} not found'
        self.message = text
        super().__init__(text)
36
+
37
+
38
class SearchRerankerError(GraphitiError):
    """Raised when a search cannot be reranked as configured.

    The search module raises it when several search methods are enabled
    without a reranker, or when the node-distance reranker is selected
    without a center node.

    Parameters
    ----------
    text : str
        Human-readable description of the problem; stored on ``self.message``.
    """

    def __init__(self, text: str):
        self.message = text
        super().__init__(self.message)
@@ -24,14 +24,19 @@ from neo4j import AsyncGraphDatabase
24
24
 
25
25
  from graphiti_core.edges import EntityEdge, EpisodicEdge
26
26
  from graphiti_core.llm_client import LLMClient, OpenAIClient
27
- from graphiti_core.llm_client.utils import generate_embedding
28
27
  from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
29
- from graphiti_core.search.search import Reranker, SearchConfig, SearchMethod, hybrid_search
28
+ from graphiti_core.search.search import SearchConfig, search
29
+ from graphiti_core.search.search_config import DEFAULT_SEARCH_LIMIT, SearchResults
30
+ from graphiti_core.search.search_config_recipes import (
31
+ EDGE_HYBRID_SEARCH_NODE_DISTANCE,
32
+ EDGE_HYBRID_SEARCH_RRF,
33
+ NODE_HYBRID_SEARCH_NODE_DISTANCE,
34
+ NODE_HYBRID_SEARCH_RRF,
35
+ )
30
36
  from graphiti_core.search.search_utils import (
31
37
  RELEVANT_SCHEMA_LIMIT,
32
38
  get_relevant_edges,
33
39
  get_relevant_nodes,
34
- hybrid_node_search,
35
40
  )
36
41
  from graphiti_core.utils import (
37
42
  build_episodic_edges,
@@ -548,7 +553,7 @@ class Graphiti:
548
553
  query: str,
549
554
  center_node_uuid: str | None = None,
550
555
  group_ids: list[str | None] | None = None,
551
- num_results=10,
556
+ num_results=DEFAULT_SEARCH_LIMIT,
552
557
  ):
553
558
  """
554
559
  Perform a hybrid search on the knowledge graph.
@@ -564,7 +569,7 @@ class Graphiti:
564
569
  Facts will be reranked based on proximity to this node
565
570
  group_ids : list[str | None] | None, optional
566
571
  The graph partitions to return data from.
567
- num_results : int, optional
572
+ limit : int, optional
568
573
  The maximum number of results to return. Defaults to 10.
569
574
 
570
575
  Returns
@@ -581,21 +586,17 @@ class Graphiti:
581
586
  The search is performed using the current date and time as the reference
582
587
  point for temporal relevance.
583
588
  """
584
- reranker = Reranker.rrf if center_node_uuid is None else Reranker.node_distance
585
- search_config = SearchConfig(
586
- num_episodes=0,
587
- num_edges=num_results,
588
- num_nodes=0,
589
- group_ids=group_ids,
590
- search_methods=[SearchMethod.bm25, SearchMethod.cosine_similarity],
591
- reranker=reranker,
589
+ search_config = (
590
+ EDGE_HYBRID_SEARCH_RRF if center_node_uuid is None else EDGE_HYBRID_SEARCH_NODE_DISTANCE
592
591
  )
592
+ search_config.limit = num_results
593
+
593
594
  edges = (
594
- await hybrid_search(
595
+ await search(
595
596
  self.driver,
596
597
  self.llm_client.get_embedder(),
597
598
  query,
598
- datetime.now(),
599
+ group_ids,
599
600
  search_config,
600
601
  center_node_uuid,
601
602
  )
@@ -606,19 +607,20 @@ class Graphiti:
606
607
  async def _search(
607
608
  self,
608
609
  query: str,
609
- timestamp: datetime,
610
610
  config: SearchConfig,
611
+ group_ids: list[str | None] | None = None,
611
612
  center_node_uuid: str | None = None,
612
- ):
613
- return await hybrid_search(
614
- self.driver, self.llm_client.get_embedder(), query, timestamp, config, center_node_uuid
613
+ ) -> SearchResults:
614
+ return await search(
615
+ self.driver, self.llm_client.get_embedder(), query, group_ids, config, center_node_uuid
615
616
  )
616
617
 
617
618
  async def get_nodes_by_query(
618
619
  self,
619
620
  query: str,
621
+ center_node_uuid: str | None = None,
620
622
  group_ids: list[str | None] | None = None,
621
- limit: int = RELEVANT_SCHEMA_LIMIT,
623
+ limit: int = DEFAULT_SEARCH_LIMIT,
622
624
  ) -> list[EntityNode]:
623
625
  """
624
626
  Retrieve nodes from the graph database based on a text query.
@@ -629,7 +631,9 @@ class Graphiti:
629
631
  Parameters
630
632
  ----------
631
633
  query : str
632
- The text query to search for in the graph.
634
+ The text query to search for in the graph
635
+ center_node_uuid: str, optional
636
+ Facts will be reranked based on proximity to this node.
633
637
  group_ids : list[str | None] | None, optional
634
638
  The graph partitions to return data from.
635
639
  limit : int | None, optional
@@ -655,8 +659,12 @@ class Graphiti:
655
659
  If not specified, a default limit (defined in the search functions) will be used.
656
660
  """
657
661
  embedder = self.llm_client.get_embedder()
658
- query_embedding = await generate_embedding(embedder, query)
659
- relevant_nodes = await hybrid_node_search(
660
- [query], [query_embedding], self.driver, group_ids, limit
662
+ search_config = (
663
+ NODE_HYBRID_SEARCH_RRF if center_node_uuid is None else NODE_HYBRID_SEARCH_NODE_DISTANCE
661
664
  )
662
- return relevant_nodes
665
+ search_config.limit = limit
666
+
667
+ nodes = (
668
+ await search(self.driver, embedder, query, group_ids, search_config, center_node_uuid)
669
+ ).nodes
670
+ return nodes
@@ -0,0 +1,23 @@
1
+ """
2
+ Copyright 2024, Zep Software, Inc.
3
+
4
+ Licensed under the Apache License, Version 2.0 (the "License");
5
+ you may not use this file except in compliance with the License.
6
+ You may obtain a copy of the License at
7
+
8
+ http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ Unless required by applicable law or agreed to in writing, software
11
+ distributed under the License is distributed on an "AS IS" BASIS,
12
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ See the License for the specific language governing permissions and
14
+ limitations under the License.
15
+ """
16
+
17
+ from datetime import datetime
18
+
19
+ from neo4j import time as neo4j_time
20
+
21
+
22
def parse_db_date(neo_date: neo4j_time.DateTime | None) -> datetime | None:
    """Convert a Neo4j ``DateTime`` into a Python ``datetime``.

    Falsy input (e.g. ``None``) yields ``None``; otherwise the value's
    ``to_native()`` conversion is returned.
    """
    if not neo_date:
        return None
    return neo_date.to_native()
@@ -0,0 +1,23 @@
1
+ """
2
+ Copyright 2024, Zep Software, Inc.
3
+
4
+ Licensed under the Apache License, Version 2.0 (the "License");
5
+ you may not use this file except in compliance with the License.
6
+ You may obtain a copy of the License at
7
+
8
+ http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ Unless required by applicable law or agreed to in writing, software
11
+ distributed under the License is distributed on an "AS IS" BASIS,
12
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ See the License for the specific language governing permissions and
14
+ limitations under the License.
15
+ """
16
+
17
+
18
class RateLimitError(Exception):
    """Raised when a rate limit is exceeded.

    Parameters
    ----------
    message : str, optional
        Text kept on ``self.message`` and passed to ``Exception``; defaults to
        a generic "try again later" notice.
    """

    def __init__(self, message='Rate limit exceeded. Please try again later.'):
        super().__init__(message)
        self.message = message
@@ -0,0 +1,38 @@
1
+ """
2
+ Copyright 2024, Zep Software, Inc.
3
+
4
+ Licensed under the Apache License, Version 2.0 (the "License");
5
+ you may not use this file except in compliance with the License.
6
+ You may obtain a copy of the License at
7
+
8
+ http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ Unless required by applicable law or agreed to in writing, software
11
+ distributed under the License is distributed on an "AS IS" BASIS,
12
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ See the License for the specific language governing permissions and
14
+ limitations under the License.
15
+ """
16
+
17
+ import logging
18
+ import typing
19
+ from time import time
20
+
21
+ from graphiti_core.llm_client.config import EMBEDDING_DIM
22
+
23
# Module-scoped logger, named after this module.
logger = logging.getLogger(name=__name__)
24
+
25
+
26
async def generate_embedding(
    embedder: typing.Any, text: str, model: str = 'text-embedding-3-small'
):
    """Embed ``text`` and truncate the vector to ``EMBEDDING_DIM`` entries.

    Parameters
    ----------
    embedder : Any
        Client exposing an awaitable ``create(input=[...], model=...)`` whose
        result has ``.data[0].embedding`` (presumably an OpenAI-style
        embeddings client — confirm against callers).
    text : str
        Text to embed; newlines are replaced with spaces first.
    model : str, optional
        Embedding model name forwarded to ``embedder.create``.

    Returns
    -------
    The first ``EMBEDDING_DIM`` entries of the returned embedding.
    """
    start = time()

    text = text.replace('\n', ' ')
    embedding = (await embedder.create(input=[text], model=model)).data[0].embedding
    embedding = embedding[:EMBEDDING_DIM]

    # time() returns seconds; convert so the logged unit ("ms") is accurate.
    elapsed_ms = (time() - start) * 1000
    logger.debug(f'embedded text of length {len(text)} in {elapsed_ms} ms')

    return embedding
@@ -1,3 +1,19 @@
1
+ """
2
+ Copyright 2024, Zep Software, Inc.
3
+
4
+ Licensed under the Apache License, Version 2.0 (the "License");
5
+ you may not use this file except in compliance with the License.
6
+ You may obtain a copy of the License at
7
+
8
+ http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ Unless required by applicable law or agreed to in writing, software
11
+ distributed under the License is distributed on an "AS IS" BASIS,
12
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ See the License for the specific language governing permissions and
14
+ limitations under the License.
15
+ """
16
+
1
17
  from typing import Any, Protocol, TypedDict
2
18
 
3
19
  from .models import Message, PromptFunction, PromptVersion
@@ -0,0 +1,242 @@
1
+ """
2
+ Copyright 2024, Zep Software, Inc.
3
+
4
+ Licensed under the Apache License, Version 2.0 (the "License");
5
+ you may not use this file except in compliance with the License.
6
+ You may obtain a copy of the License at
7
+
8
+ http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ Unless required by applicable law or agreed to in writing, software
11
+ distributed under the License is distributed on an "AS IS" BASIS,
12
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ See the License for the specific language governing permissions and
14
+ limitations under the License.
15
+ """
16
+
17
+ import logging
18
+ from time import time
19
+
20
+ from neo4j import AsyncDriver
21
+
22
+ from graphiti_core.edges import EntityEdge
23
+ from graphiti_core.errors import SearchRerankerError
24
+ from graphiti_core.llm_client.config import EMBEDDING_DIM
25
+ from graphiti_core.nodes import CommunityNode, EntityNode
26
+ from graphiti_core.search.search_config import (
27
+ DEFAULT_SEARCH_LIMIT,
28
+ CommunityReranker,
29
+ CommunitySearchConfig,
30
+ CommunitySearchMethod,
31
+ EdgeReranker,
32
+ EdgeSearchConfig,
33
+ EdgeSearchMethod,
34
+ NodeReranker,
35
+ NodeSearchConfig,
36
+ NodeSearchMethod,
37
+ SearchConfig,
38
+ SearchResults,
39
+ )
40
+ from graphiti_core.search.search_utils import (
41
+ community_fulltext_search,
42
+ community_similarity_search,
43
+ edge_fulltext_search,
44
+ edge_similarity_search,
45
+ node_distance_reranker,
46
+ node_fulltext_search,
47
+ node_similarity_search,
48
+ rrf,
49
+ )
50
+
51
# Module-scoped logger, named after this module.
logger = logging.getLogger(name=__name__)
52
+
53
+
54
async def search(
    driver: AsyncDriver,
    embedder,
    query: str,
    group_ids: list[str | None] | None,
    config: SearchConfig,
    center_node_uuid: str | None = None,
) -> SearchResults:
    """Run the edge, node and community searches enabled in ``config``.

    Each sub-search runs only when its section of ``config`` is not ``None``;
    disabled sections contribute an empty list. Every result list is cut to
    ``config.limit`` entries.

    Parameters
    ----------
    driver : AsyncDriver
        Neo4j driver used by the underlying search helpers.
    embedder
        Embedding client forwarded to the sub-searches.
    query : str
        Search text; newlines are replaced with spaces.
    group_ids : list[str | None] | None
        Graph partitions to search; an empty list is treated like ``None``.
    config : SearchConfig
        Which searches to run, how to rerank them, and the result limit.
    center_node_uuid : str | None, optional
        Forwarded to edge and node search (used by node-distance reranking).
    """
    started_at = time()
    query = query.replace('\n', ' ')
    # An empty group_ids list means "no partition filter".
    if not group_ids:
        group_ids = None

    edges = []
    if config.edge_config is not None:
        edges = await edge_search(
            driver, embedder, query, group_ids, config.edge_config, center_node_uuid, config.limit
        )

    nodes = []
    if config.node_config is not None:
        nodes = await node_search(
            driver, embedder, query, group_ids, config.node_config, center_node_uuid, config.limit
        )

    communities = []
    if config.community_config is not None:
        communities = await community_search(
            driver, embedder, query, group_ids, config.community_config, config.limit
        )

    results = SearchResults(
        edges=edges[: config.limit],
        nodes=nodes[: config.limit],
        communities=communities[: config.limit],
    )

    elapsed_ms = (time() - started_at) * 1000
    logger.info(f'search returned context for query {query} in {elapsed_ms} ms')

    return results
99
+
100
+
101
async def edge_search(
    driver: AsyncDriver,
    embedder,
    query: str,
    group_ids: list[str | None] | None,
    config: EdgeSearchConfig,
    center_node_uuid: str | None = None,
    limit=DEFAULT_SEARCH_LIMIT,
) -> list[EntityEdge]:
    """Retrieve entity edges for ``query`` with the methods enabled in ``config``.

    Each enabled method requests up to ``2 * limit`` candidates. The candidate
    lists are merged according to ``config.reranker``:

    * ``None``: only allowed with at most one method; that method's order is kept.
    * ``EdgeReranker.rrf``: reciprocal rank fusion over the candidate uuid lists.
    * ``EdgeReranker.node_distance``: edges are ordered by how the reranker
      orders their source nodes relative to ``center_node_uuid``. All edges that
      share a source node are kept.

    Raises
    ------
    SearchRerankerError
        If several methods are enabled without a reranker, or if node-distance
        reranking is requested without ``center_node_uuid``.
    """
    search_results: list[list[EntityEdge]] = []

    if EdgeSearchMethod.bm25 in config.search_methods:
        text_search = await edge_fulltext_search(driver, query, None, None, group_ids, 2 * limit)
        search_results.append(text_search)

    if EdgeSearchMethod.cosine_similarity in config.search_methods:
        search_vector = (
            (await embedder.create(input=[query], model='text-embedding-3-small'))
            .data[0]
            .embedding[:EMBEDDING_DIM]
        )

        similarity_search = await edge_similarity_search(
            driver, search_vector, None, None, group_ids, 2 * limit
        )
        search_results.append(similarity_search)

    if len(search_results) > 1 and config.reranker is None:
        raise SearchRerankerError('Multiple edge searches enabled without a reranker')

    edge_uuid_map = {edge.uuid: edge for result in search_results for edge in result}

    reranked_uuids: list[str] = []
    if config.reranker is None:
        # At most one result list can reach here (checked above). Keep its order;
        # previously this case fell through and returned nothing.
        reranked_uuids = list(
            dict.fromkeys(edge.uuid for result in search_results for edge in result)
        )
    elif config.reranker == EdgeReranker.rrf:
        search_result_uuids = [[edge.uuid for edge in result] for result in search_results]

        reranked_uuids = rrf(search_result_uuids)
    elif config.reranker == EdgeReranker.node_distance:
        if center_node_uuid is None:
            raise SearchRerankerError('No center node provided for Node Distance reranker')

        # Several edges may share a source node. Collect all of them (in first-seen
        # order) instead of letting later edges overwrite earlier ones.
        source_to_edge_uuids: dict[str, list[str]] = {}
        for result in search_results:
            for edge in result:
                edge_uuids = source_to_edge_uuids.setdefault(edge.source_node_uuid, [])
                if edge.uuid not in edge_uuids:
                    edge_uuids.append(edge.uuid)
        source_uuids = [[edge.source_node_uuid for edge in result] for result in search_results]

        reranked_node_uuids = await node_distance_reranker(driver, source_uuids, center_node_uuid)

        seen_edge_uuids: set[str] = set()
        for node_uuid in reranked_node_uuids:
            for edge_uuid in source_to_edge_uuids[node_uuid]:
                # Guard against the reranker repeating a node uuid.
                if edge_uuid not in seen_edge_uuids:
                    seen_edge_uuids.add(edge_uuid)
                    reranked_uuids.append(edge_uuid)

    reranked_edges = [edge_uuid_map[uuid] for uuid in reranked_uuids]

    return reranked_edges
154
+
155
+
156
async def node_search(
    driver: AsyncDriver,
    embedder,
    query: str,
    group_ids: list[str | None] | None,
    config: NodeSearchConfig,
    center_node_uuid: str | None = None,
    limit=DEFAULT_SEARCH_LIMIT,
) -> list[EntityNode]:
    """Retrieve entity nodes for ``query`` with the methods enabled in ``config``.

    Each enabled method requests up to ``2 * limit`` candidates. The candidate
    lists are merged according to ``config.reranker``:

    * ``None``: only allowed with at most one method; that method's order is kept.
    * ``NodeReranker.rrf``: reciprocal rank fusion over the candidate uuid lists.
    * ``NodeReranker.node_distance``: ordering returned by
      ``node_distance_reranker`` relative to ``center_node_uuid``.

    Raises
    ------
    SearchRerankerError
        If several methods are enabled without a reranker, or if node-distance
        reranking is requested without ``center_node_uuid``.
    """
    search_results: list[list[EntityNode]] = []

    if NodeSearchMethod.bm25 in config.search_methods:
        text_search = await node_fulltext_search(driver, query, group_ids, 2 * limit)
        search_results.append(text_search)

    if NodeSearchMethod.cosine_similarity in config.search_methods:
        search_vector = (
            (await embedder.create(input=[query], model='text-embedding-3-small'))
            .data[0]
            .embedding[:EMBEDDING_DIM]
        )

        similarity_search = await node_similarity_search(
            driver, search_vector, group_ids, 2 * limit
        )
        search_results.append(similarity_search)

    if len(search_results) > 1 and config.reranker is None:
        raise SearchRerankerError('Multiple node searches enabled without a reranker')

    search_result_uuids = [[node.uuid for node in result] for result in search_results]
    node_uuid_map = {node.uuid: node for result in search_results for node in result}

    reranked_uuids: list[str] = []
    if config.reranker is None:
        # At most one result list can reach here (checked above). Keep its order;
        # previously this case fell through and returned nothing.
        reranked_uuids = list(
            dict.fromkeys(uuid for uuids in search_result_uuids for uuid in uuids)
        )
    elif config.reranker == NodeReranker.rrf:
        reranked_uuids = rrf(search_result_uuids)
    elif config.reranker == NodeReranker.node_distance:
        if center_node_uuid is None:
            raise SearchRerankerError('No center node provided for Node Distance reranker')
        reranked_uuids = await node_distance_reranker(driver, search_result_uuids, center_node_uuid)

    reranked_nodes = [node_uuid_map[uuid] for uuid in reranked_uuids]

    return reranked_nodes
200
+
201
+
202
async def community_search(
    driver: AsyncDriver,
    embedder,
    query: str,
    group_ids: list[str | None] | None,
    config: CommunitySearchConfig,
    limit=DEFAULT_SEARCH_LIMIT,
) -> list[CommunityNode]:
    """Retrieve community nodes for ``query`` with the methods enabled in ``config``.

    Each enabled method requests up to ``2 * limit`` candidates. The candidate
    lists are merged according to ``config.reranker``:

    * ``None``: only allowed with at most one method; that method's order is kept.
    * ``CommunityReranker.rrf``: reciprocal rank fusion over the candidate uuid lists.

    Raises
    ------
    SearchRerankerError
        If several methods are enabled without a reranker.
    """
    search_results: list[list[CommunityNode]] = []

    if CommunitySearchMethod.bm25 in config.search_methods:
        text_search = await community_fulltext_search(driver, query, group_ids, 2 * limit)
        search_results.append(text_search)

    if CommunitySearchMethod.cosine_similarity in config.search_methods:
        search_vector = (
            (await embedder.create(input=[query], model='text-embedding-3-small'))
            .data[0]
            .embedding[:EMBEDDING_DIM]
        )

        similarity_search = await community_similarity_search(
            driver, search_vector, group_ids, 2 * limit
        )
        search_results.append(similarity_search)

    if len(search_results) > 1 and config.reranker is None:
        raise SearchRerankerError('Multiple community searches enabled without a reranker')

    search_result_uuids = [[community.uuid for community in result] for result in search_results]
    community_uuid_map = {
        community.uuid: community for result in search_results for community in result
    }

    reranked_uuids: list[str] = []
    if config.reranker is None:
        # At most one result list can reach here (checked above). Keep its order;
        # previously this case fell through and returned nothing.
        reranked_uuids = list(
            dict.fromkeys(uuid for uuids in search_result_uuids for uuid in uuids)
        )
    elif config.reranker == CommunityReranker.rrf:
        reranked_uuids = rrf(search_result_uuids)

    reranked_communities = [community_uuid_map[uuid] for uuid in reranked_uuids]

    return reranked_communities
@@ -0,0 +1,81 @@
1
+ """
2
+ Copyright 2024, Zep Software, Inc.
3
+
4
+ Licensed under the Apache License, Version 2.0 (the "License");
5
+ you may not use this file except in compliance with the License.
6
+ You may obtain a copy of the License at
7
+
8
+ http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ Unless required by applicable law or agreed to in writing, software
11
+ distributed under the License is distributed on an "AS IS" BASIS,
12
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ See the License for the specific language governing permissions and
14
+ limitations under the License.
15
+ """
16
+
17
+ from enum import Enum
18
+
19
+ from pydantic import BaseModel, Field
20
+
21
+ from graphiti_core.edges import EntityEdge
22
+ from graphiti_core.nodes import CommunityNode, EntityNode
23
+
24
# Result cap used when a search config or helper is not given an explicit limit.
DEFAULT_SEARCH_LIMIT: int = 10
25
+
26
+
27
class EdgeSearchMethod(Enum):
    """Retrieval strategies available when searching entity edges."""

    # Embedding (vector) similarity against the query embedding.
    cosine_similarity = 'cosine_similarity'
    # Full-text (BM25) keyword search.
    bm25 = 'bm25'
30
+
31
+
32
class NodeSearchMethod(Enum):
    """Retrieval strategies available when searching entity nodes."""

    # Embedding (vector) similarity against the query embedding.
    cosine_similarity = 'cosine_similarity'
    # Full-text (BM25) keyword search.
    bm25 = 'bm25'
35
+
36
+
37
class CommunitySearchMethod(Enum):
    """Retrieval strategies available when searching community nodes."""

    # Embedding (vector) similarity against the query embedding.
    cosine_similarity = 'cosine_similarity'
    # Full-text (BM25) keyword search.
    bm25 = 'bm25'
40
+
41
+
42
class EdgeReranker(Enum):
    """Ways to merge and order edge candidates from several search methods."""

    # Reciprocal rank fusion across the per-method result lists.
    rrf = 'reciprocal_rank_fusion'
    # Order by proximity of each edge's source node to a center node.
    node_distance = 'node_distance'
45
+
46
+
47
class NodeReranker(Enum):
    """Ways to merge and order node candidates from several search methods."""

    # Reciprocal rank fusion across the per-method result lists.
    rrf = 'reciprocal_rank_fusion'
    # Order by proximity to a center node.
    node_distance = 'node_distance'
50
+
51
+
52
class CommunityReranker(Enum):
    """Ways to merge and order community candidates from several search methods."""

    # Reciprocal rank fusion across the per-method result lists.
    rrf = 'reciprocal_rank_fusion'
54
+
55
+
56
class EdgeSearchConfig(BaseModel):
    """Selects which edge retrieval methods run and how their results are merged."""

    # Retrieval strategies to execute.
    search_methods: list[EdgeSearchMethod]
    # Merge strategy; required (non-None) when more than one search method is
    # listed, otherwise edge_search raises SearchRerankerError.
    reranker: EdgeReranker | None
59
+
60
+
61
class NodeSearchConfig(BaseModel):
    """Selects which node retrieval methods run and how their results are merged."""

    # Retrieval strategies to execute.
    search_methods: list[NodeSearchMethod]
    # Merge strategy; required (non-None) when more than one search method is
    # listed, otherwise node_search raises SearchRerankerError.
    reranker: NodeReranker | None
64
+
65
+
66
class CommunitySearchConfig(BaseModel):
    """Selects which community retrieval methods run and how their results are merged."""

    # Retrieval strategies to execute.
    search_methods: list[CommunitySearchMethod]
    # Merge strategy; required (non-None) when more than one search method is
    # listed, otherwise community_search raises SearchRerankerError.
    reranker: CommunityReranker | None
69
+
70
+
71
class SearchConfig(BaseModel):
    """Top-level search configuration.

    Each ``*_config`` section enables one kind of search when set; a section
    left as ``None`` is skipped. ``limit`` caps how many results of each kind
    are returned.
    """

    edge_config: EdgeSearchConfig | None = None
    node_config: NodeSearchConfig | None = None
    community_config: CommunitySearchConfig | None = None
    limit: int = DEFAULT_SEARCH_LIMIT
76
+
77
+
78
class SearchResults(BaseModel):
    """Container for the results of one search call, grouped by object kind."""

    # Matching entity edges (facts).
    edges: list[EntityEdge]
    # Matching entity nodes.
    nodes: list[EntityNode]
    # Matching community nodes.
    communities: list[CommunityNode]
@@ -0,0 +1,84 @@
1
+ """
2
+ Copyright 2024, Zep Software, Inc.
3
+
4
+ Licensed under the Apache License, Version 2.0 (the "License");
5
+ you may not use this file except in compliance with the License.
6
+ You may obtain a copy of the License at
7
+
8
+ http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ Unless required by applicable law or agreed to in writing, software
11
+ distributed under the License is distributed on an "AS IS" BASIS,
12
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ See the License for the specific language governing permissions and
14
+ limitations under the License.
15
+ """
16
+
17
+ from graphiti_core.search.search_config import (
18
+ CommunityReranker,
19
+ CommunitySearchConfig,
20
+ CommunitySearchMethod,
21
+ EdgeReranker,
22
+ EdgeSearchConfig,
23
+ EdgeSearchMethod,
24
+ NodeReranker,
25
+ NodeSearchConfig,
26
+ NodeSearchMethod,
27
+ SearchConfig,
28
+ )
29
+
30
def _hybrid_edge_config(reranker: EdgeReranker) -> EdgeSearchConfig:
    """Build an edge config that runs both BM25 and cosine-similarity retrieval."""
    return EdgeSearchConfig(
        search_methods=[EdgeSearchMethod.bm25, EdgeSearchMethod.cosine_similarity],
        reranker=reranker,
    )


def _hybrid_node_config(reranker: NodeReranker) -> NodeSearchConfig:
    """Build a node config that runs both BM25 and cosine-similarity retrieval."""
    return NodeSearchConfig(
        search_methods=[NodeSearchMethod.bm25, NodeSearchMethod.cosine_similarity],
        reranker=reranker,
    )


def _hybrid_community_config(reranker: CommunityReranker) -> CommunitySearchConfig:
    """Build a community config that runs both BM25 and cosine-similarity retrieval."""
    return CommunitySearchConfig(
        search_methods=[CommunitySearchMethod.bm25, CommunitySearchMethod.cosine_similarity],
        reranker=reranker,
    )


# NOTE: each recipe below is a single shared, mutable SearchConfig instance.
# Assigning to a recipe's fields (e.g. ``.limit``) affects every later user of
# that recipe, so callers needing different settings should work on a copy.

# Hybrid (BM25 + cosine) search over edges, nodes and communities, each fused with RRF.
COMBINED_HYBRID_SEARCH_RRF = SearchConfig(
    edge_config=_hybrid_edge_config(EdgeReranker.rrf),
    node_config=_hybrid_node_config(NodeReranker.rrf),
    community_config=_hybrid_community_config(CommunityReranker.rrf),
)

# Hybrid search over edges, fused with RRF.
EDGE_HYBRID_SEARCH_RRF = SearchConfig(edge_config=_hybrid_edge_config(EdgeReranker.rrf))

# Hybrid search over edges, reranked by distance from a center node.
EDGE_HYBRID_SEARCH_NODE_DISTANCE = SearchConfig(
    edge_config=_hybrid_edge_config(EdgeReranker.node_distance)
)

# Hybrid search over nodes, fused with RRF.
NODE_HYBRID_SEARCH_RRF = SearchConfig(node_config=_hybrid_node_config(NodeReranker.rrf))

# Hybrid search over nodes, reranked by distance from a center node.
NODE_HYBRID_SEARCH_NODE_DISTANCE = SearchConfig(
    node_config=_hybrid_node_config(NodeReranker.node_distance)
)

# Hybrid search over communities, fused with RRF.
COMMUNITY_HYBRID_SEARCH_RRF = SearchConfig(
    community_config=_hybrid_community_config(CommunityReranker.rrf)
)