graphiti-core 0.3.16__py3-none-any.whl → 0.3.18__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package as they appear in their public registries. It is provided for informational purposes only.

Potentially problematic release.



graphiti_core/cross_encoder/bge_reranker_client.py ADDED
@@ -0,0 +1,45 @@
+ """
+ Copyright 2024, Zep Software, Inc.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+     http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ """
+
+ import asyncio
+ from typing import List, Tuple
+
+ from sentence_transformers import CrossEncoder
+
+ from graphiti_core.cross_encoder.client import CrossEncoderClient
+
+
+ class BGERerankerClient(CrossEncoderClient):
+     def __init__(self):
+         self.model = CrossEncoder('BAAI/bge-reranker-v2-m3')
+
+     async def rank(self, query: str, passages: List[str]) -> List[Tuple[str, float]]:
+         if not passages:
+             return []
+
+         input_pairs = [[query, passage] for passage in passages]
+
+         # Run the synchronous predict method in an executor
+         loop = asyncio.get_running_loop()
+         scores = await loop.run_in_executor(None, self.model.predict, input_pairs)
+
+         ranked_passages = sorted(
+             [(passage, float(score)) for passage, score in zip(passages, scores)],
+             key=lambda x: x[1],
+             reverse=True,
+         )
+
+         return ranked_passages
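BGERerankerClient gives a local alternative to the OpenAI-based reranker: it scores each (query, passage) pair with the BAAI/bge-reranker-v2-m3 cross-encoder via sentence-transformers, which is imported here but does not appear in the requirement lines visible in this diff, so it presumably has to be installed separately. A minimal usage sketch with an illustrative query and passages:

```python
import asyncio

from graphiti_core.cross_encoder.bge_reranker_client import BGERerankerClient


async def main() -> None:
    # Downloads the BAAI/bge-reranker-v2-m3 weights on first use; scores are raw cross-encoder outputs
    reranker = BGERerankerClient()
    ranked = await reranker.rank(
        query='What is graphiti-core?',
        passages=['A temporal graph building library.', 'A weather report for Tuesday.'],
    )
    for passage, score in ranked:
        print(f'{score:.3f}  {passage}')


asyncio.run(main())
```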
graphiti_core/cross_encoder/client.py ADDED
@@ -0,0 +1,41 @@
+ """
+ Copyright 2024, Zep Software, Inc.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+     http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ """
+
+ from abc import ABC, abstractmethod
+ from typing import List, Tuple
+
+
+ class CrossEncoderClient(ABC):
+     """
+     CrossEncoderClient is an abstract base class that defines the interface
+     for cross-encoder models used for ranking passages based on their relevance to a query.
+     It allows for different implementations of cross-encoder models to be used interchangeably.
+     """
+
+     @abstractmethod
+     async def rank(self, query: str, passages: List[str]) -> List[Tuple[str, float]]:
+         """
+         Rank the given passages based on their relevance to the query.
+
+         Args:
+             query (str): The query string.
+             passages (List[str]): A list of passages to rank.
+
+         Returns:
+             List[Tuple[str, float]]: A list of tuples containing the passage and its score,
+                 sorted in descending order of relevance.
+         """
+         pass
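Since the contract is a single async rank method returning (passage, score) pairs in descending order of relevance, any scoring backend can be dropped in. A toy, hypothetical implementation (not part of the package) that ranks by word overlap, just to illustrate the interface:

```python
from typing import List, Tuple

from graphiti_core.cross_encoder.client import CrossEncoderClient


class OverlapRerankerClient(CrossEncoderClient):
    """Toy reranker: scores passages by how many query words they share."""

    async def rank(self, query: str, passages: List[str]) -> List[Tuple[str, float]]:
        query_terms = set(query.lower().split())
        scored = [
            (passage, float(len(query_terms & set(passage.lower().split()))))
            for passage in passages
        ]
        # The interface expects descending order of relevance
        return sorted(scored, key=lambda pair: pair[1], reverse=True)
```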
graphiti_core/cross_encoder/openai_reranker_client.py ADDED
@@ -0,0 +1,113 @@
+ """
+ Copyright 2024, Zep Software, Inc.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+     http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ """
+
+ import asyncio
+ import logging
+ from typing import Any
+
+ import openai
+ from openai import AsyncOpenAI
+ from pydantic import BaseModel
+
+ from ..llm_client import LLMConfig, RateLimitError
+ from ..prompts import Message
+ from .client import CrossEncoderClient
+
+ logger = logging.getLogger(__name__)
+
+ DEFAULT_MODEL = 'gpt-4o-mini'
+
+
+ class BooleanClassifier(BaseModel):
+     isTrue: bool
+
+
+ class OpenAIRerankerClient(CrossEncoderClient):
+     def __init__(self, config: LLMConfig | None = None):
+         """
+         Initialize the OpenAIClient with the provided configuration, cache setting, and client.
+
+         Args:
+             config (LLMConfig | None): The configuration for the LLM client, including API key, model, base URL, temperature, and max tokens.
+             cache (bool): Whether to use caching for responses. Defaults to False.
+             client (Any | None): An optional async client instance to use. If not provided, a new AsyncOpenAI client is created.
+
+         """
+         if config is None:
+             config = LLMConfig()
+
+         self.config = config
+         self.client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
+
+     async def rank(self, query: str, passages: list[str]) -> list[tuple[str, float]]:
+         openai_messages_list: Any = [
+             [
+                 Message(
+                     role='system',
+                     content='You are an expert tasked with determining whether the passage is relevant to the query',
+                 ),
+                 Message(
+                     role='user',
+                     content=f"""
+                            Respond with "True" if PASSAGE is relevant to QUERY and "False" otherwise.
+                            <PASSAGE>
+                            {query}
+                            </PASSAGE>
+                            {passage}
+                            <QUERY>
+                            </QUERY>
+                            """,
+                 ),
+             ]
+             for passage in passages
+         ]
+         try:
+             responses = await asyncio.gather(
+                 *[
+                     self.client.chat.completions.create(
+                         model=DEFAULT_MODEL,
+                         messages=openai_messages,
+                         temperature=0,
+                         max_tokens=1,
+                         logit_bias={'6432': 1, '7983': 1},
+                         logprobs=True,
+                         top_logprobs=2,
+                     )
+                     for openai_messages in openai_messages_list
+                 ]
+             )
+
+             responses_top_logprobs = [
+                 response.choices[0].logprobs.content[0].top_logprobs
+                 if response.choices[0].logprobs is not None
+                 and response.choices[0].logprobs.content is not None
+                 else []
+                 for response in responses
+             ]
+             scores: list[float] = []
+             for top_logprobs in responses_top_logprobs:
+                 for logprob in top_logprobs:
+                     if bool(logprob.token):
+                         scores.append(logprob.logprob)
+
+             results = [(passage, score) for passage, score in zip(passages, scores)]
+             results.sort(reverse=True, key=lambda x: x[1])
+             return results
+         except openai.RateLimitError as e:
+             raise RateLimitError from e
+         except Exception as e:
+             logger.error(f'Error in generating LLM response: {e}')
+             raise
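The OpenAI reranker asks the model a boolean relevance question per passage, nudges generation toward the True/False tokens with logit_bias, and uses the returned log-probabilities as scores. A hedged usage sketch, assuming LLMConfig accepts an api_key argument as its docstring above suggests and that an OpenAI API key is available in the environment:

```python
import asyncio
import os

from graphiti_core.cross_encoder.openai_reranker_client import OpenAIRerankerClient
from graphiti_core.llm_client import LLMConfig


async def main() -> None:
    # One single-token completion is issued per passage, concurrently, against gpt-4o-mini by default
    reranker = OpenAIRerankerClient(config=LLMConfig(api_key=os.environ['OPENAI_API_KEY']))
    ranked = await reranker.rank(
        query='What does graphiti-core do?',
        passages=['A temporal graph building library.', 'A recipe for pancakes.'],
    )
    print(ranked)


asyncio.run(main())
```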
graphiti_core/embedder/openai.py CHANGED
@@ -42,7 +42,9 @@ class OpenAIEmbedder(EmbedderClient):
          self.client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
 
      async def create(
-         self, input: str | List[str] | Iterable[int] | Iterable[Iterable[int]]
+         self, input_data: str | List[str] | Iterable[int] | Iterable[Iterable[int]]
      ) -> list[float]:
-         result = await self.client.embeddings.create(input=input, model=self.config.embedding_model)
+         result = await self.client.embeddings.create(
+             input=input_data, model=self.config.embedding_model
+         )
          return result.data[0].embedding[: self.config.embedding_dim]
graphiti_core/embedder/voyage.py CHANGED
@@ -41,7 +41,18 @@ class VoyageAIEmbedder(EmbedderClient):
          self.client = voyageai.AsyncClient(api_key=config.api_key)
 
      async def create(
-         self, input: str | List[str] | Iterable[int] | Iterable[Iterable[int]]
+         self, input_data: str | List[str] | Iterable[int] | Iterable[Iterable[int]]
      ) -> list[float]:
-         result = await self.client.embed(input, model=self.config.embedding_model)
+         if isinstance(input_data, str):
+             input_list = [input_data]
+         elif isinstance(input_data, List):
+             input_list = [str(i) for i in input_data if i]
+         else:
+             input_list = [str(i) for i in input_data if i is not None]
+
+         input_list = [i for i in input_list if i]
+         if len(input_list) == 0:
+             return []
+
+         result = await self.client.embed(input_list, model=self.config.embedding_model)
          return result.embeddings[0][: self.config.embedding_dim]
graphiti_core/graphiti.py CHANGED
@@ -23,8 +23,11 @@ from dotenv import load_dotenv
  from neo4j import AsyncGraphDatabase
  from pydantic import BaseModel
 
+ from graphiti_core.cross_encoder.client import CrossEncoderClient
+ from graphiti_core.cross_encoder.openai_reranker_client import OpenAIRerankerClient
  from graphiti_core.edges import EntityEdge, EpisodicEdge
  from graphiti_core.embedder import EmbedderClient, OpenAIEmbedder
+ from graphiti_core.helpers import DEFAULT_DATABASE
  from graphiti_core.llm_client import LLMClient, OpenAIClient
  from graphiti_core.nodes import CommunityNode, EntityNode, EpisodeType, EpisodicNode
  from graphiti_core.search.search import SearchConfig, search
@@ -92,6 +95,7 @@ class Graphiti:
          password: str,
          llm_client: LLMClient | None = None,
          embedder: EmbedderClient | None = None,
+         cross_encoder: CrossEncoderClient | None = None,
          store_raw_episode_content: bool = True,
      ):
          """
@@ -131,7 +135,7 @@ class Graphiti:
          Graphiti if you're using the default OpenAIClient.
          """
          self.driver = AsyncGraphDatabase.driver(uri, auth=(user, password))
-         self.database = 'neo4j'
+         self.database = DEFAULT_DATABASE
          self.store_raw_episode_content = store_raw_episode_content
          if llm_client:
              self.llm_client = llm_client
@@ -141,6 +145,10 @@ class Graphiti:
              self.embedder = embedder
          else:
              self.embedder = OpenAIEmbedder()
+         if cross_encoder:
+             self.cross_encoder = cross_encoder
+         else:
+             self.cross_encoder = OpenAIRerankerClient()
 
      async def close(self):
          """
@@ -648,6 +656,7 @@ class Graphiti:
              await search(
                  self.driver,
                  self.embedder,
+                 self.cross_encoder,
                  query,
                  group_ids,
                  search_config,
@@ -663,8 +672,18 @@ class Graphiti:
          config: SearchConfig,
          group_ids: list[str] | None = None,
          center_node_uuid: str | None = None,
+         bfs_origin_node_uuids: list[str] | None = None,
      ) -> SearchResults:
-         return await search(self.driver, self.embedder, query, group_ids, config, center_node_uuid)
+         return await search(
+             self.driver,
+             self.embedder,
+             self.cross_encoder,
+             query,
+             group_ids,
+             config,
+             center_node_uuid,
+             bfs_origin_node_uuids,
+         )
 
      async def get_nodes_by_query(
          self,
@@ -716,7 +735,13 @@ class Graphiti:
 
          nodes = (
              await search(
-                 self.driver, self.embedder, query, group_ids, search_config, center_node_uuid
+                 self.driver,
+                 self.embedder,
+                 self.cross_encoder,
+                 query,
+                 group_ids,
+                 search_config,
+                 center_node_uuid,
              )
          ).nodes
          return nodes
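With these changes the reranker is injected the same way as the LLM client and embedder; when cross_encoder is omitted, Graphiti falls back to OpenAIRerankerClient. A hedged construction sketch with placeholder Neo4j credentials, swapping in the local BGE reranker:

```python
from graphiti_core.cross_encoder.bge_reranker_client import BGERerankerClient
from graphiti_core.graphiti import Graphiti

# Placeholder connection details for a local Neo4j instance; any CrossEncoderClient works here
graphiti = Graphiti(
    uri='bolt://localhost:7687',
    user='neo4j',
    password='password',
    cross_encoder=BGERerankerClient(),  # requires sentence-transformers to be installed
)
```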
graphiti_core/search/search.py CHANGED
@@ -21,6 +21,7 @@ from time import time
 
  from neo4j import AsyncDriver
 
+ from graphiti_core.cross_encoder.client import CrossEncoderClient
  from graphiti_core.edges import EntityEdge
  from graphiti_core.embedder import EmbedderClient
  from graphiti_core.errors import SearchRerankerError
@@ -39,10 +40,12 @@ from graphiti_core.search.search_config import (
  from graphiti_core.search.search_utils import (
      community_fulltext_search,
      community_similarity_search,
+     edge_bfs_search,
      edge_fulltext_search,
      edge_similarity_search,
      episode_mentions_reranker,
      maximal_marginal_relevance,
+     node_bfs_search,
      node_distance_reranker,
      node_fulltext_search,
      node_similarity_search,
@@ -55,40 +58,49 @@ logger = logging.getLogger(__name__)
  async def search(
      driver: AsyncDriver,
      embedder: EmbedderClient,
+     cross_encoder: CrossEncoderClient,
      query: str,
      group_ids: list[str] | None,
      config: SearchConfig,
      center_node_uuid: str | None = None,
+     bfs_origin_node_uuids: list[str] | None = None,
  ) -> SearchResults:
      start = time()
-     query = query.replace('\n', ' ')
+     query_vector = await embedder.create(input=[query.replace('\n', ' ')])
+
      # if group_ids is empty, set it to None
      group_ids = group_ids if group_ids else None
      edges, nodes, communities = await asyncio.gather(
          edge_search(
              driver,
-             embedder,
+             cross_encoder,
              query,
+             query_vector,
              group_ids,
              config.edge_config,
              center_node_uuid,
+             bfs_origin_node_uuids,
              config.limit,
          ),
          node_search(
              driver,
-             embedder,
+             cross_encoder,
              query,
+             query_vector,
              group_ids,
              config.node_config,
              center_node_uuid,
+             bfs_origin_node_uuids,
              config.limit,
          ),
          community_search(
              driver,
-             embedder,
+             cross_encoder,
              query,
+             query_vector,
              group_ids,
              config.community_config,
+             bfs_origin_node_uuids,
              config.limit,
          ),
      )
@@ -99,27 +111,27 @@ async def search(
          communities=communities,
      )
 
-     end = time()
+     latency = (time() - start) * 1000
 
-     logger.info(f'search returned context for query {query} in {(end - start) * 1000} ms')
+     logger.debug(f'search returned context for query {query} in {latency} ms')
 
      return results
 
 
  async def edge_search(
      driver: AsyncDriver,
-     embedder: EmbedderClient,
+     cross_encoder: CrossEncoderClient,
      query: str,
+     query_vector: list[float],
      group_ids: list[str] | None,
      config: EdgeSearchConfig | None,
      center_node_uuid: str | None = None,
+     bfs_origin_node_uuids: list[str] | None = None,
      limit=DEFAULT_SEARCH_LIMIT,
  ) -> list[EntityEdge]:
      if config is None:
          return []
 
-     query_vector = await embedder.create(input=[query])
-
      search_results: list[list[EntityEdge]] = list(
          await asyncio.gather(
              *[
@@ -127,6 +139,7 @@ async def edge_search(
                  edge_similarity_search(
                      driver, query_vector, None, None, group_ids, 2 * limit, config.sim_min_score
                  ),
+                 edge_bfs_search(driver, bfs_origin_node_uuids, config.bfs_max_depth, 2 * limit),
              ]
          )
      )
@@ -147,6 +160,15 @@ async def edge_search(
          reranked_uuids = maximal_marginal_relevance(
              query_vector, search_result_uuids_and_vectors, config.mmr_lambda
          )
+     elif config.reranker == EdgeReranker.cross_encoder:
+         search_result_uuids = [[edge.uuid for edge in result] for result in search_results]
+
+         rrf_result_uuids = rrf(search_result_uuids)
+         rrf_edges = [edge_uuid_map[uuid] for uuid in rrf_result_uuids][:limit]
+
+         fact_to_uuid_map = {edge.fact: edge.uuid for edge in rrf_edges}
+         reranked_facts = await cross_encoder.rank(query, list(fact_to_uuid_map.keys()))
+         reranked_uuids = [fact_to_uuid_map[fact] for fact, _ in reranked_facts]
      elif config.reranker == EdgeReranker.node_distance:
          if center_node_uuid is None:
              raise SearchRerankerError('No center node provided for Node Distance reranker')
@@ -177,18 +199,18 @@ async def edge_search(
 
  async def node_search(
      driver: AsyncDriver,
-     embedder: EmbedderClient,
+     cross_encoder: CrossEncoderClient,
      query: str,
+     query_vector: list[float],
      group_ids: list[str] | None,
      config: NodeSearchConfig | None,
      center_node_uuid: str | None = None,
+     bfs_origin_node_uuids: list[str] | None = None,
      limit=DEFAULT_SEARCH_LIMIT,
  ) -> list[EntityNode]:
      if config is None:
          return []
 
-     query_vector = await embedder.create(input=[query])
-
      search_results: list[list[EntityNode]] = list(
          await asyncio.gather(
              *[
@@ -196,6 +218,7 @@ async def node_search(
                  node_similarity_search(
                      driver, query_vector, group_ids, 2 * limit, config.sim_min_score
                  ),
+                 node_bfs_search(driver, bfs_origin_node_uuids, config.bfs_max_depth, 2 * limit),
              ]
          )
      )
@@ -215,6 +238,15 @@ async def node_search(
          reranked_uuids = maximal_marginal_relevance(
              query_vector, search_result_uuids_and_vectors, config.mmr_lambda
          )
+     elif config.reranker == NodeReranker.cross_encoder:
+         # use rrf as a preliminary reranker
+         rrf_result_uuids = rrf(search_result_uuids)
+         rrf_results = [node_uuid_map[uuid] for uuid in rrf_result_uuids][:limit]
+
+         summary_to_uuid_map = {node.summary: node.uuid for node in rrf_results}
+
+         reranked_summaries = await cross_encoder.rank(query, list(summary_to_uuid_map.keys()))
+         reranked_uuids = [summary_to_uuid_map[fact] for fact, _ in reranked_summaries]
      elif config.reranker == NodeReranker.episode_mentions:
          reranked_uuids = await episode_mentions_reranker(driver, search_result_uuids)
      elif config.reranker == NodeReranker.node_distance:
@@ -231,17 +263,17 @@ async def node_search(
 
  async def community_search(
      driver: AsyncDriver,
-     embedder: EmbedderClient,
+     cross_encoder: CrossEncoderClient,
      query: str,
+     query_vector: list[float],
      group_ids: list[str] | None,
      config: CommunitySearchConfig | None,
+     bfs_origin_node_uuids: list[str] | None = None,
      limit=DEFAULT_SEARCH_LIMIT,
  ) -> list[CommunityNode]:
      if config is None:
          return []
 
-     query_vector = await embedder.create(input=[query])
-
      search_results: list[list[CommunityNode]] = list(
          await asyncio.gather(
              *[
@@ -273,6 +305,12 @@ async def community_search(
          reranked_uuids = maximal_marginal_relevance(
              query_vector, search_result_uuids_and_vectors, config.mmr_lambda
          )
+     elif config.reranker == CommunityReranker.cross_encoder:
+         summary_to_uuid_map = {
+             node.summary: node.uuid for result in search_results for node in result
+         }
+         reranked_summaries = await cross_encoder.rank(query, list(summary_to_uuid_map.keys()))
+         reranked_uuids = [summary_to_uuid_map[fact] for fact, _ in reranked_summaries]
 
      reranked_communities = [community_uuid_map[uuid] for uuid in reranked_uuids]
 
graphiti_core/search/search_config.py CHANGED
@@ -20,7 +20,11 @@ from pydantic import BaseModel, Field
 
  from graphiti_core.edges import EntityEdge
  from graphiti_core.nodes import CommunityNode, EntityNode
- from graphiti_core.search.search_utils import DEFAULT_MIN_SCORE, DEFAULT_MMR_LAMBDA
+ from graphiti_core.search.search_utils import (
+     DEFAULT_MIN_SCORE,
+     DEFAULT_MMR_LAMBDA,
+     MAX_SEARCH_DEPTH,
+ )
 
  DEFAULT_SEARCH_LIMIT = 10
 
@@ -28,11 +32,13 @@ DEFAULT_SEARCH_LIMIT = 10
  class EdgeSearchMethod(Enum):
      cosine_similarity = 'cosine_similarity'
      bm25 = 'bm25'
+     bfs = 'breadth_first_search'
 
 
  class NodeSearchMethod(Enum):
      cosine_similarity = 'cosine_similarity'
      bm25 = 'bm25'
+     bfs = 'breadth_first_search'
 
 
  class CommunitySearchMethod(Enum):
@@ -45,6 +51,7 @@ class EdgeReranker(Enum):
      node_distance = 'node_distance'
      episode_mentions = 'episode_mentions'
      mmr = 'mmr'
+     cross_encoder = 'cross_encoder'
 
 
  class NodeReranker(Enum):
@@ -52,11 +59,13 @@ class NodeReranker(Enum):
      node_distance = 'node_distance'
      episode_mentions = 'episode_mentions'
      mmr = 'mmr'
+     cross_encoder = 'cross_encoder'
 
 
  class CommunityReranker(Enum):
      rrf = 'reciprocal_rank_fusion'
      mmr = 'mmr'
+     cross_encoder = 'cross_encoder'
 
 
  class EdgeSearchConfig(BaseModel):
@@ -64,6 +73,7 @@ class EdgeSearchConfig(BaseModel):
      reranker: EdgeReranker = Field(default=EdgeReranker.rrf)
      sim_min_score: float = Field(default=DEFAULT_MIN_SCORE)
      mmr_lambda: float = Field(default=DEFAULT_MMR_LAMBDA)
+     bfs_max_depth: int = Field(default=MAX_SEARCH_DEPTH)
 
 
  class NodeSearchConfig(BaseModel):
@@ -71,6 +81,7 @@ class NodeSearchConfig(BaseModel):
      reranker: NodeReranker = Field(default=NodeReranker.rrf)
      sim_min_score: float = Field(default=DEFAULT_MIN_SCORE)
      mmr_lambda: float = Field(default=DEFAULT_MMR_LAMBDA)
+     bfs_max_depth: int = Field(default=MAX_SEARCH_DEPTH)
 
 
  class CommunitySearchConfig(BaseModel):
@@ -78,6 +89,7 @@ class CommunitySearchConfig(BaseModel):
      reranker: CommunityReranker = Field(default=CommunityReranker.rrf)
      sim_min_score: float = Field(default=DEFAULT_MIN_SCORE)
      mmr_lambda: float = Field(default=DEFAULT_MMR_LAMBDA)
+     bfs_max_depth: int = Field(default=MAX_SEARCH_DEPTH)
 
 
  class SearchConfig(BaseModel):
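The new enum members and the bfs_max_depth field compose with the existing config models. A hedged sketch of a node-only config that uses them, assuming the other SearchConfig sections keep their defaults when omitted (as the existing single-section recipes suggest):

```python
from graphiti_core.search.search_config import (
    NodeReranker,
    NodeSearchConfig,
    NodeSearchMethod,
    SearchConfig,
)

# BM25 + cosine similarity + breadth-first expansion over nodes, reranked by the cross-encoder
NODE_BFS_CROSS_ENCODER = SearchConfig(
    node_config=NodeSearchConfig(
        search_methods=[
            NodeSearchMethod.bm25,
            NodeSearchMethod.cosine_similarity,
            NodeSearchMethod.bfs,
        ],
        reranker=NodeReranker.cross_encoder,
        bfs_max_depth=2,
    ),
)
```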
graphiti_core/search/search_config_recipes.py CHANGED
@@ -48,14 +48,41 @@ COMBINED_HYBRID_SEARCH_MMR = SearchConfig(
      edge_config=EdgeSearchConfig(
          search_methods=[EdgeSearchMethod.bm25, EdgeSearchMethod.cosine_similarity],
          reranker=EdgeReranker.mmr,
+         mmr_lambda=1,
      ),
      node_config=NodeSearchConfig(
          search_methods=[NodeSearchMethod.bm25, NodeSearchMethod.cosine_similarity],
          reranker=NodeReranker.mmr,
+         mmr_lambda=1,
      ),
      community_config=CommunitySearchConfig(
          search_methods=[CommunitySearchMethod.bm25, CommunitySearchMethod.cosine_similarity],
          reranker=CommunityReranker.mmr,
+         mmr_lambda=1,
+     ),
+ )
+
+ # Performs a full-text search, similarity search, and bfs with cross_encoder reranking over edges, nodes, and communities
+ COMBINED_HYBRID_SEARCH_CROSS_ENCODER = SearchConfig(
+     edge_config=EdgeSearchConfig(
+         search_methods=[
+             EdgeSearchMethod.bm25,
+             EdgeSearchMethod.cosine_similarity,
+             EdgeSearchMethod.bfs,
+         ],
+         reranker=EdgeReranker.cross_encoder,
+     ),
+     node_config=NodeSearchConfig(
+         search_methods=[
+             NodeSearchMethod.bm25,
+             NodeSearchMethod.cosine_similarity,
+             NodeSearchMethod.bfs,
+         ],
+         reranker=NodeReranker.cross_encoder,
+     ),
+     community_config=CommunitySearchConfig(
+         search_methods=[CommunitySearchMethod.bm25, CommunitySearchMethod.cosine_similarity],
+         reranker=CommunityReranker.cross_encoder,
      ),
  )
 
@@ -81,7 +108,6 @@ EDGE_HYBRID_SEARCH_NODE_DISTANCE = SearchConfig(
          search_methods=[EdgeSearchMethod.bm25, EdgeSearchMethod.cosine_similarity],
          reranker=EdgeReranker.node_distance,
      ),
-     limit=30,
  )
 
  # performs a hybrid search over edges with episode mention reranking
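The new COMBINED_HYBRID_SEARCH_CROSS_ENCODER recipe turns on the bfs search methods and cross_encoder rerankers in one SearchConfig. A hedged sketch of passing it to the module-level search() shown above; the driver, embedder, and cross_encoder arguments are assumed to be the same objects the Graphiti initializer builds, and the query is illustrative:

```python
from graphiti_core.search.search import search
from graphiti_core.search.search_config_recipes import COMBINED_HYBRID_SEARCH_CROSS_ENCODER


async def run_cross_encoder_search(driver, embedder, cross_encoder, origin_uuids=None):
    # bfs_origin_node_uuids seeds the breadth-first legs; when it is None they return no results
    return await search(
        driver,
        embedder,
        cross_encoder,
        query='acme corp acquisitions',
        group_ids=None,
        config=COMBINED_HYBRID_SEARCH_CROSS_ENCODER,
        bfs_origin_node_uuids=origin_uuids,
    )
```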
graphiti_core/search/search_utils.py CHANGED
@@ -37,6 +37,7 @@ logger = logging.getLogger(__name__)
  RELEVANT_SCHEMA_LIMIT = 3
  DEFAULT_MIN_SCORE = 0.6
  DEFAULT_MMR_LAMBDA = 0.5
+ MAX_SEARCH_DEPTH = 3
  MAX_QUERY_LENGTH = 128
 
 
@@ -79,21 +80,21 @@ async def get_mentioned_nodes(
      driver: AsyncDriver, episodes: list[EpisodicNode]
  ) -> list[EntityNode]:
      episode_uuids = [episode.uuid for episode in episodes]
-     async with driver.session(database=DEFAULT_DATABASE) as session:
-         result = await session.run(
-             """
-             MATCH (episode:Episodic)-[:MENTIONS]->(n:Entity) WHERE episode.uuid IN $uuids
-             RETURN DISTINCT
-                 n.uuid As uuid,
-                 n.group_id AS group_id,
-                 n.name AS name,
-                 n.name_embedding AS name_embedding,
-                 n.created_at AS created_at,
-                 n.summary AS summary
-             """,
-             {'uuids': episode_uuids},
-         )
-         records = [record async for record in result]
+     records, _, _ = await driver.execute_query(
+         """
+         MATCH (episode:Episodic)-[:MENTIONS]->(n:Entity) WHERE episode.uuid IN $uuids
+         RETURN DISTINCT
+             n.uuid As uuid,
+             n.group_id AS group_id,
+             n.name AS name,
+             n.name_embedding AS name_embedding,
+             n.created_at AS created_at,
+             n.summary AS summary
+         """,
+         uuids=episode_uuids,
+         database_=DEFAULT_DATABASE,
+         routing_='r',
+     )
 
      nodes = [get_entity_node_from_record(record) for record in records]
 
@@ -104,21 +105,21 @@ async def get_communities_by_nodes(
      driver: AsyncDriver, nodes: list[EntityNode]
  ) -> list[CommunityNode]:
      node_uuids = [node.uuid for node in nodes]
-     async with driver.session(database=DEFAULT_DATABASE) as session:
-         result = await session.run(
-             """
-             MATCH (c:Community)-[:HAS_MEMBER]->(n:Entity) WHERE n.uuid IN $uuids
-             RETURN DISTINCT
-                 c.uuid As uuid,
-                 c.group_id AS group_id,
-                 c.name AS name,
-                 c.name_embedding AS name_embedding
-                 c.created_at AS created_at,
-                 c.summary AS summary
-             """,
-             {'uuids': node_uuids},
-         )
-         records = [record async for record in result]
+     records, _, _ = await driver.execute_query(
+         """
+         MATCH (c:Community)-[:HAS_MEMBER]->(n:Entity) WHERE n.uuid IN $uuids
+         RETURN DISTINCT
+             c.uuid As uuid,
+             c.group_id AS group_id,
+             c.name AS name,
+             c.name_embedding AS name_embedding
+             c.created_at AS created_at,
+             c.summary AS summary
+         """,
+         uuids=node_uuids,
+         database_=DEFAULT_DATABASE,
+         routing_='r',
+     )
 
      communities = [get_community_node_from_record(record) for record in records]
 
@@ -141,8 +142,10 @@ async def edge_fulltext_search(
      cypher_query = Query("""
          CALL db.index.fulltext.queryRelationships("edge_name_and_fact", $query)
          YIELD relationship AS rel, score
-         MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity)
-         RETURN
+         MATCH (n:Entity)-[r {uuid: rel.uuid}]->(m:Entity)
+         WHERE ($source_uuid IS NULL OR n.uuid IN [$source_uuid, $target_uuid])
+         AND ($target_uuid IS NULL OR m.uuid IN [$source_uuid, $target_uuid])
+         RETURN
              r.uuid AS uuid,
              r.group_id AS group_id,
              n.uuid AS source_node_uuid,
@@ -158,18 +161,16 @@ async def edge_fulltext_search(
          ORDER BY score DESC LIMIT $limit
      """)
 
-     async with driver.session(database=DEFAULT_DATABASE) as session:
-         result = await session.run(
-             cypher_query,
-             {
-                 'query': fuzzy_query,
-                 'source_uuid': source_node_uuid,
-                 'target_uuid': target_node_uuid,
-                 'group_ids': group_ids,
-                 'limit': limit,
-             },
-         )
-         records = [record async for record in result]
+     records, _, _ = await driver.execute_query(
+         cypher_query,
+         query=fuzzy_query,
+         source_uuid=source_node_uuid,
+         target_uuid=target_node_uuid,
+         group_ids=group_ids,
+         limit=limit,
+         database_=DEFAULT_DATABASE,
+         routing_='r',
+     )
 
      edges = [get_entity_edge_from_record(record) for record in records]
 
@@ -188,17 +189,17 @@ async def edge_similarity_search(
      # vector similarity search over embedded facts
      query = Query("""
          CYPHER runtime = parallel parallelRuntimeSupport=all
-         MATCH (n:Entity)-[r:RELATES_TO]-(m:Entity)
+         MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity)
          WHERE ($group_ids IS NULL OR r.group_id IN $group_ids)
-         AND ($source_uuid IS NULL OR n.uuid = $source_uuid)
-         AND ($target_uuid IS NULL OR m.uuid = $target_uuid)
-         WITH n, r, m, vector.similarity.cosine(r.fact_embedding, $search_vector) AS score
+         AND ($source_uuid IS NULL OR n.uuid IN [$source_uuid, $target_uuid])
+         AND ($target_uuid IS NULL OR m.uuid IN [$source_uuid, $target_uuid])
+         WITH DISTINCT r, vector.similarity.cosine(r.fact_embedding, $search_vector) AS score
          WHERE score > $min_score
          RETURN
              r.uuid AS uuid,
              r.group_id AS group_id,
-             n.uuid AS source_node_uuid,
-             m.uuid AS target_node_uuid,
+             startNode(r).uuid AS source_node_uuid,
+             endNode(r).uuid AS target_node_uuid,
              r.created_at AS created_at,
              r.name AS name,
              r.fact AS fact,
@@ -211,19 +212,62 @@ async def edge_similarity_search(
          LIMIT $limit
      """)
 
-     async with driver.session(database=DEFAULT_DATABASE) as session:
-         result = await session.run(
-             query,
-             {
-                 'search_vector': search_vector,
-                 'source_uuid': source_node_uuid,
-                 'target_uuid': target_node_uuid,
-                 'group_ids': group_ids,
-                 'limit': limit,
-                 'min_score': min_score,
-             },
-         )
-         records = [record async for record in result]
+     records, _, _ = await driver.execute_query(
+         query,
+         search_vector=search_vector,
+         source_uuid=source_node_uuid,
+         target_uuid=target_node_uuid,
+         group_ids=group_ids,
+         limit=limit,
+         min_score=min_score,
+         database_=DEFAULT_DATABASE,
+         routing_='r',
+     )
+
+     edges = [get_entity_edge_from_record(record) for record in records]
+
+     return edges
+
+
+ async def edge_bfs_search(
+     driver: AsyncDriver,
+     bfs_origin_node_uuids: list[str] | None,
+     bfs_max_depth: int,
+     limit: int,
+ ) -> list[EntityEdge]:
+     # vector similarity search over embedded facts
+     if bfs_origin_node_uuids is None:
+         return []
+
+     query = Query("""
+         UNWIND $bfs_origin_node_uuids AS origin_uuid
+         MATCH path = (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,3}(n:Entity)
+         UNWIND relationships(path) AS rel
+         MATCH ()-[r:RELATES_TO {uuid: rel.uuid}]-()
+         RETURN DISTINCT
+             r.uuid AS uuid,
+             r.group_id AS group_id,
+             startNode(r).uuid AS source_node_uuid,
+             endNode(r).uuid AS target_node_uuid,
+             r.created_at AS created_at,
+             r.name AS name,
+             r.fact AS fact,
+             r.fact_embedding AS fact_embedding,
+             r.episodes AS episodes,
+             r.expired_at AS expired_at,
+             r.valid_at AS valid_at,
+             r.invalid_at AS invalid_at
+         LIMIT $limit
+     """)
+
+     records, _, _ = await driver.execute_query(
+         query,
+         bfs_origin_node_uuids=bfs_origin_node_uuids,
+         depth=bfs_max_depth,
+         limit=limit,
+         database_=DEFAULT_DATABASE,
+         routing_='r',
+     )
 
      edges = [get_entity_edge_from_record(record) for record in records]
 
@@ -241,28 +285,26 @@ async def node_fulltext_search(
      if fuzzy_query == '':
          return []
 
-     async with driver.session(database=DEFAULT_DATABASE) as session:
-         result = await session.run(
-             """
-             CALL db.index.fulltext.queryNodes("node_name_and_summary", $query)
-             YIELD node AS n, score
-             RETURN
-                 n.uuid AS uuid,
-                 n.group_id AS group_id,
-                 n.name AS name,
-                 n.name_embedding AS name_embedding,
-                 n.created_at AS created_at,
-                 n.summary AS summary
-             ORDER BY score DESC
-             LIMIT $limit
-             """,
-             {
-                 'query': fuzzy_query,
-                 'group_ids': group_ids,
-                 'limit': limit,
-             },
-         )
-         records = [record async for record in result]
+     records, _, _ = await driver.execute_query(
+         """
+         CALL db.index.fulltext.queryNodes("node_name_and_summary", $query)
+         YIELD node AS n, score
+         RETURN
+             n.uuid AS uuid,
+             n.group_id AS group_id,
+             n.name AS name,
+             n.name_embedding AS name_embedding,
+             n.created_at AS created_at,
+             n.summary AS summary
+         ORDER BY score DESC
+         LIMIT $limit
+         """,
+         query=fuzzy_query,
+         group_ids=group_ids,
+         limit=limit,
+         database_=DEFAULT_DATABASE,
+         routing_='r',
+     )
      nodes = [get_entity_node_from_record(record) for record in records]
 
      return nodes
@@ -276,32 +318,64 @@ async def node_similarity_search(
      min_score: float = DEFAULT_MIN_SCORE,
  ) -> list[EntityNode]:
      # vector similarity search over entity names
-     async with driver.session(database=DEFAULT_DATABASE) as session:
-         result = await session.run(
-             """
-             CYPHER runtime = parallel parallelRuntimeSupport=all
-             MATCH (n:Entity)
-             WHERE $group_ids IS NULL OR n.group_id IN $group_ids
-             WITH n, vector.similarity.cosine(n.name_embedding, $search_vector) AS score
-             WHERE score > $min_score
-             RETURN
-                 n.uuid As uuid,
-                 n.group_id AS group_id,
-                 n.name AS name,
-                 n.name_embedding AS name_embedding,
-                 n.created_at AS created_at,
-                 n.summary AS summary
-             ORDER BY score DESC
-             LIMIT $limit
-             """,
-             {
-                 'search_vector': search_vector,
-                 'group_ids': group_ids,
-                 'limit': limit,
-                 'min_score': min_score,
-             },
-         )
-         records = [record async for record in result]
+     records, _, _ = await driver.execute_query(
+         """
+         CYPHER runtime = parallel parallelRuntimeSupport=all
+         MATCH (n:Entity)
+         WHERE $group_ids IS NULL OR n.group_id IN $group_ids
+         WITH n, vector.similarity.cosine(n.name_embedding, $search_vector) AS score
+         WHERE score > $min_score
+         RETURN
+             n.uuid As uuid,
+             n.group_id AS group_id,
+             n.name AS name,
+             n.name_embedding AS name_embedding,
+             n.created_at AS created_at,
+             n.summary AS summary
+         ORDER BY score DESC
+         LIMIT $limit
+         """,
+         search_vector=search_vector,
+         group_ids=group_ids,
+         limit=limit,
+         min_score=min_score,
+         database_=DEFAULT_DATABASE,
+         routing_='r',
+     )
+     nodes = [get_entity_node_from_record(record) for record in records]
+
+     return nodes
+
+
+ async def node_bfs_search(
+     driver: AsyncDriver,
+     bfs_origin_node_uuids: list[str] | None,
+     bfs_max_depth: int,
+     limit: int,
+ ) -> list[EntityNode]:
+     # vector similarity search over entity names
+     if bfs_origin_node_uuids is None:
+         return []
+
+     records, _, _ = await driver.execute_query(
+         """
+         UNWIND $bfs_origin_node_uuids AS origin_uuid
+         MATCH (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,3}(n:Entity)
+         RETURN DISTINCT
+             n.uuid As uuid,
+             n.group_id AS group_id,
+             n.name AS name,
+             n.name_embedding AS name_embedding,
+             n.created_at AS created_at,
+             n.summary AS summary
+         LIMIT $limit
+         """,
+         bfs_origin_node_uuids=bfs_origin_node_uuids,
+         depth=bfs_max_depth,
+         limit=limit,
+         database_=DEFAULT_DATABASE,
+         routing_='r',
+     )
      nodes = [get_entity_node_from_record(record) for record in records]
 
      return nodes
@@ -318,28 +392,26 @@ async def community_fulltext_search(
      if fuzzy_query == '':
          return []
 
-     async with driver.session(database=DEFAULT_DATABASE) as session:
-         result = await session.run(
-             """
-             CALL db.index.fulltext.queryNodes("community_name", $query)
-             YIELD node AS comm, score
-             RETURN
-                 comm.uuid AS uuid,
-                 comm.group_id AS group_id,
-                 comm.name AS name,
-                 comm.name_embedding AS name_embedding,
-                 comm.created_at AS created_at,
-                 comm.summary AS summary
-             ORDER BY score DESC
-             LIMIT $limit
-             """,
-             {
-                 'query': fuzzy_query,
-                 'group_ids': group_ids,
-                 'limit': limit,
-             },
-         )
-         records = [record async for record in result]
+     records, _, _ = await driver.execute_query(
+         """
+         CALL db.index.fulltext.queryNodes("community_name", $query)
+         YIELD node AS comm, score
+         RETURN
+             comm.uuid AS uuid,
+             comm.group_id AS group_id,
+             comm.name AS name,
+             comm.name_embedding AS name_embedding,
+             comm.created_at AS created_at,
+             comm.summary AS summary
+         ORDER BY score DESC
+         LIMIT $limit
+         """,
+         query=fuzzy_query,
+         group_ids=group_ids,
+         limit=limit,
+         database_=DEFAULT_DATABASE,
+         routing_='r',
+     )
      communities = [get_community_node_from_record(record) for record in records]
 
      return communities
@@ -353,32 +425,30 @@ async def community_similarity_search(
      min_score=DEFAULT_MIN_SCORE,
  ) -> list[CommunityNode]:
      # vector similarity search over entity names
-     async with driver.session(database=DEFAULT_DATABASE) as session:
-         result = await session.run(
-             """
-             CYPHER runtime = parallel parallelRuntimeSupport=all
-             MATCH (comm:Community)
-             WHERE ($group_ids IS NULL OR comm.group_id IN $group_ids)
-             WITH comm, vector.similarity.cosine(comm.name_embedding, $search_vector) AS score
-             WHERE score > $min_score
-             RETURN
-                 comm.uuid As uuid,
-                 comm.group_id AS group_id,
-                 comm.name AS name,
-                 comm.name_embedding AS name_embedding,
-                 comm.created_at AS created_at,
-                 comm.summary AS summary
-             ORDER BY score DESC
-             LIMIT $limit
-             """,
-             {
-                 'search_vector': search_vector,
-                 'group_ids': group_ids,
-                 'limit': limit,
-                 'min_score': min_score,
-             },
-         )
-         records = [record async for record in result]
+     records, _, _ = await driver.execute_query(
+         """
+         CYPHER runtime = parallel parallelRuntimeSupport=all
+         MATCH (comm:Community)
+         WHERE ($group_ids IS NULL OR comm.group_id IN $group_ids)
+         WITH comm, vector.similarity.cosine(comm.name_embedding, $search_vector) AS score
+         WHERE score > $min_score
+         RETURN
+             comm.uuid As uuid,
+             comm.group_id AS group_id,
+             comm.name AS name,
+             comm.name_embedding AS name_embedding,
+             comm.created_at AS created_at,
+             comm.summary AS summary
+         ORDER BY score DESC
+         LIMIT $limit
+         """,
+         search_vector=search_vector,
+         group_ids=group_ids,
+         limit=limit,
+         min_score=min_score,
+         database_=DEFAULT_DATABASE,
+         routing_='r',
+     )
      communities = [get_community_node_from_record(record) for record in records]
 
      return communities
@@ -554,32 +624,27 @@ async def node_distance_reranker(
      driver: AsyncDriver, node_uuids: list[str], center_node_uuid: str
  ) -> list[str]:
      # filter out node_uuid center node node uuid
-     filtered_uuids = list(filter(lambda uuid: uuid != center_node_uuid, node_uuids))
+     filtered_uuids = list(filter(lambda node_uuid: node_uuid != center_node_uuid, node_uuids))
      scores: dict[str, float] = {}
 
      # Find the shortest path to center node
      query = Query("""
-         MATCH p = SHORTEST 1 (center:Entity {uuid: $center_uuid})-[:RELATES_TO]-+(n:Entity {uuid: $node_uuid})
-         RETURN length(p) AS score
+         UNWIND $node_uuids AS node_uuid
+         MATCH p = SHORTEST 1 (center:Entity {uuid: $center_uuid})-[:RELATES_TO]-+(n:Entity {uuid: node_uuid})
+         RETURN length(p) AS score, node_uuid AS uuid
      """)
 
-     path_results = await asyncio.gather(
-         *[
-             driver.execute_query(
-                 query,
-                 node_uuid=uuid,
-                 center_uuid=center_node_uuid,
-                 database_=DEFAULT_DATABASE,
-             )
-             for uuid in filtered_uuids
-         ]
+     path_results, _, _ = await driver.execute_query(
+         query,
+         node_uuids=filtered_uuids,
+         center_uuid=center_node_uuid,
+         database_=DEFAULT_DATABASE,
      )
 
-     for uuid, result in zip(filtered_uuids, path_results):
-         records = result[0]
-         record = records[0] if len(records) > 0 else None
-         distance: float = record['score'] if record is not None else float('inf')
-         scores[uuid] = distance
+     for result in path_results:
+         uuid = result['uuid']
+         score = result['score'] if 'score' in result else float('inf')
+         scores[uuid] = score
 
      # rerank on shortest distance
      filtered_uuids.sort(key=lambda cur_uuid: scores[cur_uuid])
@@ -596,25 +661,20 @@ async def episode_mentions_reranker(driver: AsyncDriver, node_uuids: list[list[s
      scores: dict[str, float] = {}
 
      # Find the shortest path to center node
-     query = Query("""
-         MATCH (episode:Episodic)-[r:MENTIONS]->(n:Entity {uuid: $node_uuid})
-         RETURN count(*) AS score
+     query = Query("""
+         UNWIND $node_uuids AS node_uuid
+         MATCH (episode:Episodic)-[r:MENTIONS]->(n:Entity {uuid: node_uuid})
+         RETURN count(*) AS score, n.uuid AS uuid
      """)
 
-     result_scores = await asyncio.gather(
-         *[
-             driver.execute_query(
-                 query,
-                 node_uuid=uuid,
-                 database_=DEFAULT_DATABASE,
-             )
-             for uuid in sorted_uuids
-         ]
+     results, _, _ = await driver.execute_query(
+         query,
+         node_uuids=sorted_uuids,
+         database_=DEFAULT_DATABASE,
      )
 
-     for uuid, result in zip(sorted_uuids, result_scores):
-         record = result[0][0]
-         scores[uuid] = record['score']
+     for result in results:
+         scores[result['uuid']] = result['score']
 
      # rerank on shortest distance
      sorted_uuids.sort(key=lambda cur_uuid: scores[cur_uuid])
@@ -635,4 +695,4 @@ def maximal_marginal_relevance(
 
      candidates_with_mmr.sort(reverse=True, key=lambda c: c[1])
 
-     return [candidate[0] for candidate in candidates_with_mmr]
+     return list(set([candidate[0] for candidate in candidates_with_mmr]))
graphiti_core-0.3.18.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: graphiti-core
- Version: 0.3.16
+ Version: 0.3.18
  Summary: A temporal graph building library
  License: Apache-2.0
  Author: Paul Paliychuk
@@ -14,10 +14,9 @@ Classifier: Programming Language :: Python :: 3.12
  Requires-Dist: diskcache (>=5.6.3,<6.0.0)
  Requires-Dist: neo4j (>=5.23.0,<6.0.0)
  Requires-Dist: numpy (>=1.0.0)
- Requires-Dist: openai (>=1.50.2,<2.0.0)
+ Requires-Dist: openai (>=1.52.2,<2.0.0)
  Requires-Dist: pydantic (>=2.8.2,<3.0.0)
  Requires-Dist: tenacity (<9.0.0)
- Requires-Dist: voyageai (>=0.2.3,<0.3.0)
  Description-Content-Type: text/markdown
 
  <div align="center">
graphiti_core-0.3.18.dist-info/RECORD CHANGED
@@ -1,11 +1,15 @@
  graphiti_core/__init__.py,sha256=e5SWFkRiaUwfprYIeIgVIh7JDedNiloZvd3roU-0aDY,55
+ graphiti_core/cross_encoder/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ graphiti_core/cross_encoder/bge_reranker_client.py,sha256=jsXBUHfFpGsNASHaRnfz1_miQ3x070DdU8QS4J3DciI,1466
+ graphiti_core/cross_encoder/client.py,sha256=PyFYYsALQAD9wu0gL5uquPsulmaBZ0AZkJmLq2DFA-c,1472
+ graphiti_core/cross_encoder/openai_reranker_client.py,sha256=ij1E1Y5G9GNP3h3h8nSUF-ZJrQ921B54uudZUsCUaDc,4063
  graphiti_core/edges.py,sha256=KgH1f-nwexEX3PCRaQHPqbD033EeiKo_s39mqZn43zk,13082
  graphiti_core/embedder/__init__.py,sha256=eWd-0sPxflnYXLoWNT9sxwCIFun5JNO9Fk4E-ZXXf8Y,164
  graphiti_core/embedder/client.py,sha256=Sd9CyYXaqRazdOH8opKackrTx-y9y-T54M78XTVMzxs,1006
- graphiti_core/embedder/openai.py,sha256=28cl4qQCQeu6EGxVVPw3lPesA-Z_Cpvuhozyc1jdqVg,1586
- graphiti_core/embedder/voyage.py,sha256=pGrSquGnSiYl4nXGnutbdWchtYgZb0Fi_yW3c90dPlI,1497
+ graphiti_core/embedder/openai.py,sha256=yYUYPymx_lBlxDTGrlc03yNhPFyGG-etM2sszRK2G2U,1618
+ graphiti_core/embedder/voyage.py,sha256=_eGFI5_NjNG8z7qG3jTWCdE7sAs1Yb8fiSZSJlQLD9o,1879
  graphiti_core/errors.py,sha256=ddHrHGQxhwkVAtSph4AV84UoOlgwZufMczXPwB7uqPo,1795
- graphiti_core/graphiti.py,sha256=BBYuSDgGj8FZKm6ldNntn8Dv7jFccFSZK1_kTDZNUQE,26945
+ graphiti_core/graphiti.py,sha256=c9Rh777TrHYffPF6qvFAfm-m-PA4kD8a3ZW_ShsZGxE,27714
  graphiti_core/helpers.py,sha256=kqC2TD8Auwty4sG7KH4BuRMX413oTChGaAT_XUt9ZjU,2108
  graphiti_core/llm_client/__init__.py,sha256=PA80TSMeX-sUXITXEAxMDEt3gtfZgcJrGJUcyds1mSo,207
  graphiti_core/llm_client/anthropic_client.py,sha256=4l2PbCjIoeRr7UJ2DUh2grYLTtE2vNaWlo72IIRQDeI,2405
@@ -34,10 +38,10 @@ graphiti_core/prompts/models.py,sha256=cvx_Bv5RMFUD_5IUawYrbpOKLPHogai7_bm7YXrSz
  graphiti_core/prompts/summarize_nodes.py,sha256=FLuZpGTABgcxuIDkx_IKH115nHEw0rIaFhcGlWveAMc,2357
  graphiti_core/py.typed,sha256=vlmmzQOt7bmeQl9L3XJP4W6Ry0iiELepnOrinKz5KQg,79
  graphiti_core/search/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- graphiti_core/search/search.py,sha256=2YvUjOWnPYVa2DvZAqOuKbcdCxjX5bSynpQICqFqqGU,9040
- graphiti_core/search/search_config.py,sha256=dWcanEmMoL42RHF-jcZO9C2G9BdqjkI9w-5xe9Wd2Xg,2737
- graphiti_core/search/search_config_recipes.py,sha256=FpASZLdyMdTSwY4ISHrjRUnFKVCego7Wd3j5RPN-ris,4907
- graphiti_core/search/search_utils.py,sha256=4OChWhtJXAtiOUeyZ3AoEWROY5JKJzEi-TzhlkZZfoo,21020
+ graphiti_core/search/search.py,sha256=F2Plut6YKb5CcBsa-UsbojXbDpL_iKMIuQh6zfuxGKY,11171
+ graphiti_core/search/search_config.py,sha256=UZN8jFA4pBlw2O5N1cuhVRBdTwMLR9N3Oyo6sQ4MDVw,3117
+ graphiti_core/search/search_config_recipes.py,sha256=20jS7veJExDnXA-ovJSUJfyDHKt7GW-nng-eoiT7ATA,5810
+ graphiti_core/search/search_utils.py,sha256=l8BR4GOo-A2eIXx4ybC18n6t6CeerN_9KQbYzCB6ix0,22551
  graphiti_core/utils/__init__.py,sha256=cJAcMnBZdHBQmWrZdU1PQ1YmaL75bhVUkyVpIPuOyns,260
  graphiti_core/utils/bulk_utils.py,sha256=JtoYTZPCigPa3n2E43Oe7QhFZRTA_QKNGy1jVgklHag,12614
  graphiti_core/utils/maintenance/__init__.py,sha256=TRY3wWWu5kn3Oahk_KKhltrWnh0NACw0FskjqF6OtlA,314
@@ -47,7 +51,7 @@ graphiti_core/utils/maintenance/graph_data_operations.py,sha256=w66_SLlvPapuG91Y
  graphiti_core/utils/maintenance/node_operations.py,sha256=h5nlRojbXOGJs-alpv6z6WnZ1UCixVGlAQYBQUqz8Bs,9030
  graphiti_core/utils/maintenance/temporal_operations.py,sha256=MvaRLWrBlDeYw8CQrKish1xbYcY5ovpfdqA2hSX7v5k,3367
  graphiti_core/utils/maintenance/utils.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- graphiti_core-0.3.16.dist-info/LICENSE,sha256=KCUwCyDXuVEgmDWkozHyniRyWjnWUWjkuDHfU6o3JlA,11325
- graphiti_core-0.3.16.dist-info/METADATA,sha256=CXX1YrYZICJQzNzufvTOCMIecS-i8I9ZnpGiuRiRhio,9437
- graphiti_core-0.3.16.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
- graphiti_core-0.3.16.dist-info/RECORD,,
+ graphiti_core-0.3.18.dist-info/LICENSE,sha256=KCUwCyDXuVEgmDWkozHyniRyWjnWUWjkuDHfU6o3JlA,11325
+ graphiti_core-0.3.18.dist-info/METADATA,sha256=D45OPLftoNd7wWJLtrewFJ1YkgcMLDADopI7P4jWwDg,9396
+ graphiti_core-0.3.18.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
+ graphiti_core-0.3.18.dist-info/RECORD,,