graphiti-core 0.3.16__tar.gz → 0.3.17__tar.gz

This diff shows the differences between two publicly available versions of a package, as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.

Potentially problematic release.


This version of graphiti-core might be problematic. Click here for more details.

Files changed (57)
  1. {graphiti_core-0.3.16 → graphiti_core-0.3.17}/PKG-INFO +1 -2
  2. graphiti_core-0.3.17/graphiti_core/cross_encoder/bge_reranker_client.py +45 -0
  3. graphiti_core-0.3.17/graphiti_core/cross_encoder/client.py +41 -0
  4. graphiti_core-0.3.17/graphiti_core/cross_encoder/openai_reranker_client.py +113 -0
  5. {graphiti_core-0.3.16 → graphiti_core-0.3.17}/graphiti_core/graphiti.py +28 -3
  6. {graphiti_core-0.3.16 → graphiti_core-0.3.17}/graphiti_core/search/search.py +43 -15
  7. {graphiti_core-0.3.16 → graphiti_core-0.3.17}/graphiti_core/search/search_config.py +13 -1
  8. {graphiti_core-0.3.16 → graphiti_core-0.3.17}/graphiti_core/search/search_config_recipes.py +27 -1
  9. {graphiti_core-0.3.16 → graphiti_core-0.3.17}/graphiti_core/search/search_utils.py +247 -192
  10. graphiti_core-0.3.17/graphiti_core/utils/maintenance/utils.py +0 -0
  11. {graphiti_core-0.3.16 → graphiti_core-0.3.17}/pyproject.toml +4 -2
  12. {graphiti_core-0.3.16 → graphiti_core-0.3.17}/LICENSE +0 -0
  13. {graphiti_core-0.3.16 → graphiti_core-0.3.17}/README.md +0 -0
  14. {graphiti_core-0.3.16 → graphiti_core-0.3.17}/graphiti_core/__init__.py +0 -0
  15. {graphiti_core-0.3.16/graphiti_core/models → graphiti_core-0.3.17/graphiti_core/cross_encoder}/__init__.py +0 -0
  16. {graphiti_core-0.3.16 → graphiti_core-0.3.17}/graphiti_core/edges.py +0 -0
  17. {graphiti_core-0.3.16 → graphiti_core-0.3.17}/graphiti_core/embedder/__init__.py +0 -0
  18. {graphiti_core-0.3.16 → graphiti_core-0.3.17}/graphiti_core/embedder/client.py +0 -0
  19. {graphiti_core-0.3.16 → graphiti_core-0.3.17}/graphiti_core/embedder/openai.py +0 -0
  20. {graphiti_core-0.3.16 → graphiti_core-0.3.17}/graphiti_core/embedder/voyage.py +0 -0
  21. {graphiti_core-0.3.16 → graphiti_core-0.3.17}/graphiti_core/errors.py +0 -0
  22. {graphiti_core-0.3.16 → graphiti_core-0.3.17}/graphiti_core/helpers.py +0 -0
  23. {graphiti_core-0.3.16 → graphiti_core-0.3.17}/graphiti_core/llm_client/__init__.py +0 -0
  24. {graphiti_core-0.3.16 → graphiti_core-0.3.17}/graphiti_core/llm_client/anthropic_client.py +0 -0
  25. {graphiti_core-0.3.16 → graphiti_core-0.3.17}/graphiti_core/llm_client/client.py +0 -0
  26. {graphiti_core-0.3.16 → graphiti_core-0.3.17}/graphiti_core/llm_client/config.py +0 -0
  27. {graphiti_core-0.3.16 → graphiti_core-0.3.17}/graphiti_core/llm_client/errors.py +0 -0
  28. {graphiti_core-0.3.16 → graphiti_core-0.3.17}/graphiti_core/llm_client/groq_client.py +0 -0
  29. {graphiti_core-0.3.16 → graphiti_core-0.3.17}/graphiti_core/llm_client/openai_client.py +0 -0
  30. {graphiti_core-0.3.16 → graphiti_core-0.3.17}/graphiti_core/llm_client/utils.py +0 -0
  31. {graphiti_core-0.3.16/graphiti_core/models/edges → graphiti_core-0.3.17/graphiti_core/models}/__init__.py +0 -0
  32. {graphiti_core-0.3.16/graphiti_core/models/nodes → graphiti_core-0.3.17/graphiti_core/models/edges}/__init__.py +0 -0
  33. {graphiti_core-0.3.16 → graphiti_core-0.3.17}/graphiti_core/models/edges/edge_db_queries.py +0 -0
  34. {graphiti_core-0.3.16/graphiti_core/search → graphiti_core-0.3.17/graphiti_core/models/nodes}/__init__.py +0 -0
  35. {graphiti_core-0.3.16 → graphiti_core-0.3.17}/graphiti_core/models/nodes/node_db_queries.py +0 -0
  36. {graphiti_core-0.3.16 → graphiti_core-0.3.17}/graphiti_core/nodes.py +0 -0
  37. {graphiti_core-0.3.16 → graphiti_core-0.3.17}/graphiti_core/prompts/__init__.py +0 -0
  38. {graphiti_core-0.3.16 → graphiti_core-0.3.17}/graphiti_core/prompts/dedupe_edges.py +0 -0
  39. {graphiti_core-0.3.16 → graphiti_core-0.3.17}/graphiti_core/prompts/dedupe_nodes.py +0 -0
  40. {graphiti_core-0.3.16 → graphiti_core-0.3.17}/graphiti_core/prompts/eval.py +0 -0
  41. {graphiti_core-0.3.16 → graphiti_core-0.3.17}/graphiti_core/prompts/extract_edge_dates.py +0 -0
  42. {graphiti_core-0.3.16 → graphiti_core-0.3.17}/graphiti_core/prompts/extract_edges.py +0 -0
  43. {graphiti_core-0.3.16 → graphiti_core-0.3.17}/graphiti_core/prompts/extract_nodes.py +0 -0
  44. {graphiti_core-0.3.16 → graphiti_core-0.3.17}/graphiti_core/prompts/invalidate_edges.py +0 -0
  45. {graphiti_core-0.3.16 → graphiti_core-0.3.17}/graphiti_core/prompts/lib.py +0 -0
  46. {graphiti_core-0.3.16 → graphiti_core-0.3.17}/graphiti_core/prompts/models.py +0 -0
  47. {graphiti_core-0.3.16 → graphiti_core-0.3.17}/graphiti_core/prompts/summarize_nodes.py +0 -0
  48. {graphiti_core-0.3.16 → graphiti_core-0.3.17}/graphiti_core/py.typed +0 -0
  49. /graphiti_core-0.3.16/graphiti_core/utils/maintenance/utils.py → /graphiti_core-0.3.17/graphiti_core/search/__init__.py +0 -0
  50. {graphiti_core-0.3.16 → graphiti_core-0.3.17}/graphiti_core/utils/__init__.py +0 -0
  51. {graphiti_core-0.3.16 → graphiti_core-0.3.17}/graphiti_core/utils/bulk_utils.py +0 -0
  52. {graphiti_core-0.3.16 → graphiti_core-0.3.17}/graphiti_core/utils/maintenance/__init__.py +0 -0
  53. {graphiti_core-0.3.16 → graphiti_core-0.3.17}/graphiti_core/utils/maintenance/community_operations.py +0 -0
  54. {graphiti_core-0.3.16 → graphiti_core-0.3.17}/graphiti_core/utils/maintenance/edge_operations.py +0 -0
  55. {graphiti_core-0.3.16 → graphiti_core-0.3.17}/graphiti_core/utils/maintenance/graph_data_operations.py +0 -0
  56. {graphiti_core-0.3.16 → graphiti_core-0.3.17}/graphiti_core/utils/maintenance/node_operations.py +0 -0
  57. {graphiti_core-0.3.16 → graphiti_core-0.3.17}/graphiti_core/utils/maintenance/temporal_operations.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: graphiti-core
3
- Version: 0.3.16
3
+ Version: 0.3.17
4
4
  Summary: A temporal graph building library
5
5
  License: Apache-2.0
6
6
  Author: Paul Paliychuk
@@ -17,7 +17,6 @@ Requires-Dist: numpy (>=1.0.0)
17
17
  Requires-Dist: openai (>=1.50.2,<2.0.0)
18
18
  Requires-Dist: pydantic (>=2.8.2,<3.0.0)
19
19
  Requires-Dist: tenacity (<9.0.0)
20
- Requires-Dist: voyageai (>=0.2.3,<0.3.0)
21
20
  Description-Content-Type: text/markdown
22
21
 
23
22
  <div align="center">
@@ -0,0 +1,45 @@
1
+ """
2
+ Copyright 2024, Zep Software, Inc.
3
+
4
+ Licensed under the Apache License, Version 2.0 (the "License");
5
+ you may not use this file except in compliance with the License.
6
+ You may obtain a copy of the License at
7
+
8
+ http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ Unless required by applicable law or agreed to in writing, software
11
+ distributed under the License is distributed on an "AS IS" BASIS,
12
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ See the License for the specific language governing permissions and
14
+ limitations under the License.
15
+ """
16
+
17
+ import asyncio
18
+ from typing import List, Tuple
19
+
20
+ from sentence_transformers import CrossEncoder
21
+
22
+ from graphiti_core.cross_encoder.client import CrossEncoderClient
23
+
24
+
25
+ class BGERerankerClient(CrossEncoderClient):
26
+ def __init__(self):
27
+ self.model = CrossEncoder('BAAI/bge-reranker-v2-m3')
28
+
29
+ async def rank(self, query: str, passages: List[str]) -> List[Tuple[str, float]]:
30
+ if not passages:
31
+ return []
32
+
33
+ input_pairs = [[query, passage] for passage in passages]
34
+
35
+ # Run the synchronous predict method in an executor
36
+ loop = asyncio.get_running_loop()
37
+ scores = await loop.run_in_executor(None, self.model.predict, input_pairs)
38
+
39
+ ranked_passages = sorted(
40
+ [(passage, float(score)) for passage, score in zip(passages, scores)],
41
+ key=lambda x: x[1],
42
+ reverse=True,
43
+ )
44
+
45
+ return ranked_passages
@@ -0,0 +1,41 @@
1
+ """
2
+ Copyright 2024, Zep Software, Inc.
3
+
4
+ Licensed under the Apache License, Version 2.0 (the "License");
5
+ you may not use this file except in compliance with the License.
6
+ You may obtain a copy of the License at
7
+
8
+ http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ Unless required by applicable law or agreed to in writing, software
11
+ distributed under the License is distributed on an "AS IS" BASIS,
12
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ See the License for the specific language governing permissions and
14
+ limitations under the License.
15
+ """
16
+
17
+ from abc import ABC, abstractmethod
18
+ from typing import List, Tuple
19
+
20
+
21
+ class CrossEncoderClient(ABC):
22
+ """
23
+ CrossEncoderClient is an abstract base class that defines the interface
24
+ for cross-encoder models used for ranking passages based on their relevance to a query.
25
+ It allows for different implementations of cross-encoder models to be used interchangeably.
26
+ """
27
+
28
+ @abstractmethod
29
+ async def rank(self, query: str, passages: List[str]) -> List[Tuple[str, float]]:
30
+ """
31
+ Rank the given passages based on their relevance to the query.
32
+
33
+ Args:
34
+ query (str): The query string.
35
+ passages (List[str]): A list of passages to rank.
36
+
37
+ Returns:
38
+ List[Tuple[str, float]]: A list of tuples containing the passage and its score,
39
+ sorted in descending order of relevance.
40
+ """
41
+ pass
@@ -0,0 +1,113 @@
1
+ """
2
+ Copyright 2024, Zep Software, Inc.
3
+
4
+ Licensed under the Apache License, Version 2.0 (the "License");
5
+ you may not use this file except in compliance with the License.
6
+ You may obtain a copy of the License at
7
+
8
+ http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ Unless required by applicable law or agreed to in writing, software
11
+ distributed under the License is distributed on an "AS IS" BASIS,
12
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ See the License for the specific language governing permissions and
14
+ limitations under the License.
15
+ """
16
+
17
+ import asyncio
18
+ import logging
19
+ from typing import Any
20
+
21
+ import openai
22
+ from openai import AsyncOpenAI
23
+ from pydantic import BaseModel
24
+
25
+ from ..llm_client import LLMConfig, RateLimitError
26
+ from ..prompts import Message
27
+ from .client import CrossEncoderClient
28
+
29
+ logger = logging.getLogger(__name__)
30
+
31
+ DEFAULT_MODEL = 'gpt-4o-mini'
32
+
33
+
34
+ class BooleanClassifier(BaseModel):
35
+ isTrue: bool
36
+
37
+
38
+ class OpenAIRerankerClient(CrossEncoderClient):
39
+ def __init__(self, config: LLMConfig | None = None):
40
+ """
41
+ Initialize the OpenAIClient with the provided configuration, cache setting, and client.
42
+
43
+ Args:
44
+ config (LLMConfig | None): The configuration for the LLM client, including API key, model, base URL, temperature, and max tokens.
45
+ cache (bool): Whether to use caching for responses. Defaults to False.
46
+ client (Any | None): An optional async client instance to use. If not provided, a new AsyncOpenAI client is created.
47
+
48
+ """
49
+ if config is None:
50
+ config = LLMConfig()
51
+
52
+ self.config = config
53
+ self.client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
54
+
55
+ async def rank(self, query: str, passages: list[str]) -> list[tuple[str, float]]:
56
+ openai_messages_list: Any = [
57
+ [
58
+ Message(
59
+ role='system',
60
+ content='You are an expert tasked with determining whether the passage is relevant to the query',
61
+ ),
62
+ Message(
63
+ role='user',
64
+ content=f"""
65
+ Respond with "True" if PASSAGE is relevant to QUERY and "False" otherwise.
66
+ <PASSAGE>
67
+ {query}
68
+ </PASSAGE>
69
+ {passage}
70
+ <QUERY>
71
+ </QUERY>
72
+ """,
73
+ ),
74
+ ]
75
+ for passage in passages
76
+ ]
77
+ try:
78
+ responses = await asyncio.gather(
79
+ *[
80
+ self.client.chat.completions.create(
81
+ model=DEFAULT_MODEL,
82
+ messages=openai_messages,
83
+ temperature=0,
84
+ max_tokens=1,
85
+ logit_bias={'6432': 1, '7983': 1},
86
+ logprobs=True,
87
+ top_logprobs=2,
88
+ )
89
+ for openai_messages in openai_messages_list
90
+ ]
91
+ )
92
+
93
+ responses_top_logprobs = [
94
+ response.choices[0].logprobs.content[0].top_logprobs
95
+ if response.choices[0].logprobs is not None
96
+ and response.choices[0].logprobs.content is not None
97
+ else []
98
+ for response in responses
99
+ ]
100
+ scores: list[float] = []
101
+ for top_logprobs in responses_top_logprobs:
102
+ for logprob in top_logprobs:
103
+ if bool(logprob.token):
104
+ scores.append(logprob.logprob)
105
+
106
+ results = [(passage, score) for passage, score in zip(passages, scores)]
107
+ results.sort(reverse=True, key=lambda x: x[1])
108
+ return results
109
+ except openai.RateLimitError as e:
110
+ raise RateLimitError from e
111
+ except Exception as e:
112
+ logger.error(f'Error in generating LLM response: {e}')
113
+ raise
@@ -23,8 +23,11 @@ from dotenv import load_dotenv
23
23
  from neo4j import AsyncGraphDatabase
24
24
  from pydantic import BaseModel
25
25
 
26
+ from graphiti_core.cross_encoder.client import CrossEncoderClient
27
+ from graphiti_core.cross_encoder.openai_reranker_client import OpenAIRerankerClient
26
28
  from graphiti_core.edges import EntityEdge, EpisodicEdge
27
29
  from graphiti_core.embedder import EmbedderClient, OpenAIEmbedder
30
+ from graphiti_core.helpers import DEFAULT_DATABASE
28
31
  from graphiti_core.llm_client import LLMClient, OpenAIClient
29
32
  from graphiti_core.nodes import CommunityNode, EntityNode, EpisodeType, EpisodicNode
30
33
  from graphiti_core.search.search import SearchConfig, search
@@ -92,6 +95,7 @@ class Graphiti:
92
95
  password: str,
93
96
  llm_client: LLMClient | None = None,
94
97
  embedder: EmbedderClient | None = None,
98
+ cross_encoder: CrossEncoderClient | None = None,
95
99
  store_raw_episode_content: bool = True,
96
100
  ):
97
101
  """
@@ -131,7 +135,7 @@ class Graphiti:
131
135
  Graphiti if you're using the default OpenAIClient.
132
136
  """
133
137
  self.driver = AsyncGraphDatabase.driver(uri, auth=(user, password))
134
- self.database = 'neo4j'
138
+ self.database = DEFAULT_DATABASE
135
139
  self.store_raw_episode_content = store_raw_episode_content
136
140
  if llm_client:
137
141
  self.llm_client = llm_client
@@ -141,6 +145,10 @@ class Graphiti:
141
145
  self.embedder = embedder
142
146
  else:
143
147
  self.embedder = OpenAIEmbedder()
148
+ if cross_encoder:
149
+ self.cross_encoder = cross_encoder
150
+ else:
151
+ self.cross_encoder = OpenAIRerankerClient()
144
152
 
145
153
  async def close(self):
146
154
  """
@@ -648,6 +656,7 @@ class Graphiti:
648
656
  await search(
649
657
  self.driver,
650
658
  self.embedder,
659
+ self.cross_encoder,
651
660
  query,
652
661
  group_ids,
653
662
  search_config,
@@ -663,8 +672,18 @@ class Graphiti:
663
672
  config: SearchConfig,
664
673
  group_ids: list[str] | None = None,
665
674
  center_node_uuid: str | None = None,
675
+ bfs_origin_node_uuids: list[str] | None = None,
666
676
  ) -> SearchResults:
667
- return await search(self.driver, self.embedder, query, group_ids, config, center_node_uuid)
677
+ return await search(
678
+ self.driver,
679
+ self.embedder,
680
+ self.cross_encoder,
681
+ query,
682
+ group_ids,
683
+ config,
684
+ center_node_uuid,
685
+ bfs_origin_node_uuids,
686
+ )
668
687
 
669
688
  async def get_nodes_by_query(
670
689
  self,
@@ -716,7 +735,13 @@ class Graphiti:
716
735
 
717
736
  nodes = (
718
737
  await search(
719
- self.driver, self.embedder, query, group_ids, search_config, center_node_uuid
738
+ self.driver,
739
+ self.embedder,
740
+ self.cross_encoder,
741
+ query,
742
+ group_ids,
743
+ search_config,
744
+ center_node_uuid,
720
745
  )
721
746
  ).nodes
722
747
  return nodes
@@ -21,6 +21,7 @@ from time import time
21
21
 
22
22
  from neo4j import AsyncDriver
23
23
 
24
+ from graphiti_core.cross_encoder.client import CrossEncoderClient
24
25
  from graphiti_core.edges import EntityEdge
25
26
  from graphiti_core.embedder import EmbedderClient
26
27
  from graphiti_core.errors import SearchRerankerError
@@ -39,6 +40,7 @@ from graphiti_core.search.search_config import (
39
40
  from graphiti_core.search.search_utils import (
40
41
  community_fulltext_search,
41
42
  community_similarity_search,
43
+ edge_bfs_search,
42
44
  edge_fulltext_search,
43
45
  edge_similarity_search,
44
46
  episode_mentions_reranker,
@@ -55,40 +57,49 @@ logger = logging.getLogger(__name__)
55
57
  async def search(
56
58
  driver: AsyncDriver,
57
59
  embedder: EmbedderClient,
60
+ cross_encoder: CrossEncoderClient,
58
61
  query: str,
59
62
  group_ids: list[str] | None,
60
63
  config: SearchConfig,
61
64
  center_node_uuid: str | None = None,
65
+ bfs_origin_node_uuids: list[str] | None = None,
62
66
  ) -> SearchResults:
63
67
  start = time()
64
- query = query.replace('\n', ' ')
68
+ query_vector = await embedder.create(input=[query.replace('\n', ' ')])
69
+
65
70
  # if group_ids is empty, set it to None
66
71
  group_ids = group_ids if group_ids else None
67
72
  edges, nodes, communities = await asyncio.gather(
68
73
  edge_search(
69
74
  driver,
70
- embedder,
75
+ cross_encoder,
71
76
  query,
77
+ query_vector,
72
78
  group_ids,
73
79
  config.edge_config,
74
80
  center_node_uuid,
81
+ bfs_origin_node_uuids,
75
82
  config.limit,
76
83
  ),
77
84
  node_search(
78
85
  driver,
79
- embedder,
86
+ cross_encoder,
80
87
  query,
88
+ query_vector,
81
89
  group_ids,
82
90
  config.node_config,
83
91
  center_node_uuid,
92
+ bfs_origin_node_uuids,
84
93
  config.limit,
85
94
  ),
86
95
  community_search(
87
96
  driver,
88
- embedder,
97
+ cross_encoder,
89
98
  query,
99
+ query_vector,
90
100
  group_ids,
91
101
  config.community_config,
102
+ bfs_origin_node_uuids,
92
103
  config.limit,
93
104
  ),
94
105
  )
@@ -99,27 +110,27 @@ async def search(
99
110
  communities=communities,
100
111
  )
101
112
 
102
- end = time()
113
+ latency = (time() - start) * 1000
103
114
 
104
- logger.info(f'search returned context for query {query} in {(end - start) * 1000} ms')
115
+ logger.debug(f'search returned context for query {query} in {latency} ms')
105
116
 
106
117
  return results
107
118
 
108
119
 
109
120
  async def edge_search(
110
121
  driver: AsyncDriver,
111
- embedder: EmbedderClient,
122
+ cross_encoder: CrossEncoderClient,
112
123
  query: str,
124
+ query_vector: list[float],
113
125
  group_ids: list[str] | None,
114
126
  config: EdgeSearchConfig | None,
115
127
  center_node_uuid: str | None = None,
128
+ bfs_origin_node_uuids: list[str] | None = None,
116
129
  limit=DEFAULT_SEARCH_LIMIT,
117
130
  ) -> list[EntityEdge]:
118
131
  if config is None:
119
132
  return []
120
133
 
121
- query_vector = await embedder.create(input=[query])
122
-
123
134
  search_results: list[list[EntityEdge]] = list(
124
135
  await asyncio.gather(
125
136
  *[
@@ -127,6 +138,7 @@ async def edge_search(
127
138
  edge_similarity_search(
128
139
  driver, query_vector, None, None, group_ids, 2 * limit, config.sim_min_score
129
140
  ),
141
+ edge_bfs_search(driver, bfs_origin_node_uuids, config.bfs_max_depth),
130
142
  ]
131
143
  )
132
144
  )
@@ -147,6 +159,10 @@ async def edge_search(
147
159
  reranked_uuids = maximal_marginal_relevance(
148
160
  query_vector, search_result_uuids_and_vectors, config.mmr_lambda
149
161
  )
162
+ elif config.reranker == EdgeReranker.cross_encoder:
163
+ fact_to_uuid_map = {edge.fact: edge.uuid for result in search_results for edge in result}
164
+ reranked_facts = await cross_encoder.rank(query, list(fact_to_uuid_map.keys()))
165
+ reranked_uuids = [fact_to_uuid_map[fact] for fact, _ in reranked_facts]
150
166
  elif config.reranker == EdgeReranker.node_distance:
151
167
  if center_node_uuid is None:
152
168
  raise SearchRerankerError('No center node provided for Node Distance reranker')
@@ -177,18 +193,18 @@ async def edge_search(
177
193
 
178
194
  async def node_search(
179
195
  driver: AsyncDriver,
180
- embedder: EmbedderClient,
196
+ cross_encoder: CrossEncoderClient,
181
197
  query: str,
198
+ query_vector: list[float],
182
199
  group_ids: list[str] | None,
183
200
  config: NodeSearchConfig | None,
184
201
  center_node_uuid: str | None = None,
202
+ bfs_origin_node_uuids: list[str] | None = None,
185
203
  limit=DEFAULT_SEARCH_LIMIT,
186
204
  ) -> list[EntityNode]:
187
205
  if config is None:
188
206
  return []
189
207
 
190
- query_vector = await embedder.create(input=[query])
191
-
192
208
  search_results: list[list[EntityNode]] = list(
193
209
  await asyncio.gather(
194
210
  *[
@@ -215,6 +231,12 @@ async def node_search(
215
231
  reranked_uuids = maximal_marginal_relevance(
216
232
  query_vector, search_result_uuids_and_vectors, config.mmr_lambda
217
233
  )
234
+ elif config.reranker == NodeReranker.cross_encoder:
235
+ summary_to_uuid_map = {
236
+ node.summary: node.uuid for result in search_results for node in result
237
+ }
238
+ reranked_summaries = await cross_encoder.rank(query, list(summary_to_uuid_map.keys()))
239
+ reranked_uuids = [summary_to_uuid_map[fact] for fact, _ in reranked_summaries]
218
240
  elif config.reranker == NodeReranker.episode_mentions:
219
241
  reranked_uuids = await episode_mentions_reranker(driver, search_result_uuids)
220
242
  elif config.reranker == NodeReranker.node_distance:
@@ -231,17 +253,17 @@ async def node_search(
231
253
 
232
254
  async def community_search(
233
255
  driver: AsyncDriver,
234
- embedder: EmbedderClient,
256
+ cross_encoder: CrossEncoderClient,
235
257
  query: str,
258
+ query_vector: list[float],
236
259
  group_ids: list[str] | None,
237
260
  config: CommunitySearchConfig | None,
261
+ bfs_origin_node_uuids: list[str] | None = None,
238
262
  limit=DEFAULT_SEARCH_LIMIT,
239
263
  ) -> list[CommunityNode]:
240
264
  if config is None:
241
265
  return []
242
266
 
243
- query_vector = await embedder.create(input=[query])
244
-
245
267
  search_results: list[list[CommunityNode]] = list(
246
268
  await asyncio.gather(
247
269
  *[
@@ -273,6 +295,12 @@ async def community_search(
273
295
  reranked_uuids = maximal_marginal_relevance(
274
296
  query_vector, search_result_uuids_and_vectors, config.mmr_lambda
275
297
  )
298
+ elif config.reranker == CommunityReranker.cross_encoder:
299
+ summary_to_uuid_map = {
300
+ node.summary: node.uuid for result in search_results for node in result
301
+ }
302
+ reranked_summaries = await cross_encoder.rank(query, list(summary_to_uuid_map.keys()))
303
+ reranked_uuids = [summary_to_uuid_map[fact] for fact, _ in reranked_summaries]
276
304
 
277
305
  reranked_communities = [community_uuid_map[uuid] for uuid in reranked_uuids]
278
306
 
@@ -20,7 +20,11 @@ from pydantic import BaseModel, Field
20
20
 
21
21
  from graphiti_core.edges import EntityEdge
22
22
  from graphiti_core.nodes import CommunityNode, EntityNode
23
- from graphiti_core.search.search_utils import DEFAULT_MIN_SCORE, DEFAULT_MMR_LAMBDA
23
+ from graphiti_core.search.search_utils import (
24
+ DEFAULT_MIN_SCORE,
25
+ DEFAULT_MMR_LAMBDA,
26
+ MAX_SEARCH_DEPTH,
27
+ )
24
28
 
25
29
  DEFAULT_SEARCH_LIMIT = 10
26
30
 
@@ -28,11 +32,13 @@ DEFAULT_SEARCH_LIMIT = 10
28
32
  class EdgeSearchMethod(Enum):
29
33
  cosine_similarity = 'cosine_similarity'
30
34
  bm25 = 'bm25'
35
+ bfs = 'breadth_first_search'
31
36
 
32
37
 
33
38
  class NodeSearchMethod(Enum):
34
39
  cosine_similarity = 'cosine_similarity'
35
40
  bm25 = 'bm25'
41
+ bfs = 'breadth_first_search'
36
42
 
37
43
 
38
44
  class CommunitySearchMethod(Enum):
@@ -45,6 +51,7 @@ class EdgeReranker(Enum):
45
51
  node_distance = 'node_distance'
46
52
  episode_mentions = 'episode_mentions'
47
53
  mmr = 'mmr'
54
+ cross_encoder = 'cross_encoder'
48
55
 
49
56
 
50
57
  class NodeReranker(Enum):
@@ -52,11 +59,13 @@ class NodeReranker(Enum):
52
59
  node_distance = 'node_distance'
53
60
  episode_mentions = 'episode_mentions'
54
61
  mmr = 'mmr'
62
+ cross_encoder = 'cross_encoder'
55
63
 
56
64
 
57
65
  class CommunityReranker(Enum):
58
66
  rrf = 'reciprocal_rank_fusion'
59
67
  mmr = 'mmr'
68
+ cross_encoder = 'cross_encoder'
60
69
 
61
70
 
62
71
  class EdgeSearchConfig(BaseModel):
@@ -64,6 +73,7 @@ class EdgeSearchConfig(BaseModel):
64
73
  reranker: EdgeReranker = Field(default=EdgeReranker.rrf)
65
74
  sim_min_score: float = Field(default=DEFAULT_MIN_SCORE)
66
75
  mmr_lambda: float = Field(default=DEFAULT_MMR_LAMBDA)
76
+ bfs_max_depth: int = Field(default=MAX_SEARCH_DEPTH)
67
77
 
68
78
 
69
79
  class NodeSearchConfig(BaseModel):
@@ -71,6 +81,7 @@ class NodeSearchConfig(BaseModel):
71
81
  reranker: NodeReranker = Field(default=NodeReranker.rrf)
72
82
  sim_min_score: float = Field(default=DEFAULT_MIN_SCORE)
73
83
  mmr_lambda: float = Field(default=DEFAULT_MMR_LAMBDA)
84
+ bfs_max_depth: int = Field(default=MAX_SEARCH_DEPTH)
74
85
 
75
86
 
76
87
  class CommunitySearchConfig(BaseModel):
@@ -78,6 +89,7 @@ class CommunitySearchConfig(BaseModel):
78
89
  reranker: CommunityReranker = Field(default=CommunityReranker.rrf)
79
90
  sim_min_score: float = Field(default=DEFAULT_MIN_SCORE)
80
91
  mmr_lambda: float = Field(default=DEFAULT_MMR_LAMBDA)
92
+ bfs_max_depth: int = Field(default=MAX_SEARCH_DEPTH)
81
93
 
82
94
 
83
95
  class SearchConfig(BaseModel):
@@ -48,14 +48,41 @@ COMBINED_HYBRID_SEARCH_MMR = SearchConfig(
48
48
  edge_config=EdgeSearchConfig(
49
49
  search_methods=[EdgeSearchMethod.bm25, EdgeSearchMethod.cosine_similarity],
50
50
  reranker=EdgeReranker.mmr,
51
+ mmr_lambda=1,
51
52
  ),
52
53
  node_config=NodeSearchConfig(
53
54
  search_methods=[NodeSearchMethod.bm25, NodeSearchMethod.cosine_similarity],
54
55
  reranker=NodeReranker.mmr,
56
+ mmr_lambda=1,
55
57
  ),
56
58
  community_config=CommunitySearchConfig(
57
59
  search_methods=[CommunitySearchMethod.bm25, CommunitySearchMethod.cosine_similarity],
58
60
  reranker=CommunityReranker.mmr,
61
+ mmr_lambda=1,
62
+ ),
63
+ )
64
+
65
+ # Performs a full-text search, similarity search, and bfs with cross_encoder reranking over edges, nodes, and communities
66
+ COMBINED_HYBRID_SEARCH_CROSS_ENCODER = SearchConfig(
67
+ edge_config=EdgeSearchConfig(
68
+ search_methods=[
69
+ EdgeSearchMethod.bm25,
70
+ EdgeSearchMethod.cosine_similarity,
71
+ EdgeSearchMethod.bfs,
72
+ ],
73
+ reranker=EdgeReranker.cross_encoder,
74
+ ),
75
+ node_config=NodeSearchConfig(
76
+ search_methods=[
77
+ NodeSearchMethod.bm25,
78
+ NodeSearchMethod.cosine_similarity,
79
+ NodeSearchMethod.bfs,
80
+ ],
81
+ reranker=NodeReranker.cross_encoder,
82
+ ),
83
+ community_config=CommunitySearchConfig(
84
+ search_methods=[CommunitySearchMethod.bm25, CommunitySearchMethod.cosine_similarity],
85
+ reranker=CommunityReranker.cross_encoder,
59
86
  ),
60
87
  )
61
88
 
@@ -81,7 +108,6 @@ EDGE_HYBRID_SEARCH_NODE_DISTANCE = SearchConfig(
81
108
  search_methods=[EdgeSearchMethod.bm25, EdgeSearchMethod.cosine_similarity],
82
109
  reranker=EdgeReranker.node_distance,
83
110
  ),
84
- limit=30,
85
111
  )
86
112
 
87
113
  # performs a hybrid search over edges with episode mention reranking