graphiti-core 0.3.16__py3-none-any.whl → 0.3.18__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of graphiti-core might be problematic. Click here for more details.
- graphiti_core/cross_encoder/__init__.py +0 -0
- graphiti_core/cross_encoder/bge_reranker_client.py +45 -0
- graphiti_core/cross_encoder/client.py +41 -0
- graphiti_core/cross_encoder/openai_reranker_client.py +113 -0
- graphiti_core/embedder/openai.py +4 -2
- graphiti_core/embedder/voyage.py +13 -2
- graphiti_core/graphiti.py +28 -3
- graphiti_core/search/search.py +53 -15
- graphiti_core/search/search_config.py +13 -1
- graphiti_core/search/search_config_recipes.py +27 -1
- graphiti_core/search/search_utils.py +253 -193
- {graphiti_core-0.3.16.dist-info → graphiti_core-0.3.18.dist-info}/METADATA +2 -3
- {graphiti_core-0.3.16.dist-info → graphiti_core-0.3.18.dist-info}/RECORD +15 -11
- {graphiti_core-0.3.16.dist-info → graphiti_core-0.3.18.dist-info}/LICENSE +0 -0
- {graphiti_core-0.3.16.dist-info → graphiti_core-0.3.18.dist-info}/WHEEL +0 -0
|
File without changes
|
|
@@ -0,0 +1,45 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Copyright 2024, Zep Software, Inc.
|
|
3
|
+
|
|
4
|
+
Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
you may not use this file except in compliance with the License.
|
|
6
|
+
You may obtain a copy of the License at
|
|
7
|
+
|
|
8
|
+
http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
|
|
10
|
+
Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
See the License for the specific language governing permissions and
|
|
14
|
+
limitations under the License.
|
|
15
|
+
"""
|
|
16
|
+
|
|
17
|
+
import asyncio
|
|
18
|
+
from typing import List, Tuple
|
|
19
|
+
|
|
20
|
+
from sentence_transformers import CrossEncoder
|
|
21
|
+
|
|
22
|
+
from graphiti_core.cross_encoder.client import CrossEncoderClient
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class BGERerankerClient(CrossEncoderClient):
    """Reranker that scores (query, passage) pairs with a sentence-transformers CrossEncoder."""

    def __init__(self):
        # Instantiate the 'BAAI/bge-reranker-v2-m3' cross-encoder once per client.
        self.model = CrossEncoder('BAAI/bge-reranker-v2-m3')

    async def rank(self, query: str, passages: List[str]) -> List[Tuple[str, float]]:
        """
        Score every passage against the query and return them best-first.

        Args:
            query (str): The query string.
            passages (List[str]): The passages to score.

        Returns:
            List[Tuple[str, float]]: (passage, score) tuples sorted by score,
            highest first. An empty input yields an empty list.
        """
        if passages:
            pairs = [[query, text] for text in passages]

            # predict() is a blocking call, so hand it to the default executor
            # instead of running it on the event loop thread.
            raw_scores = await asyncio.get_running_loop().run_in_executor(
                None, self.model.predict, pairs
            )

            scored = [(text, float(value)) for text, value in zip(passages, raw_scores)]
            scored.sort(key=lambda item: item[1], reverse=True)
            return scored

        return []
|
|
@@ -0,0 +1,41 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Copyright 2024, Zep Software, Inc.
|
|
3
|
+
|
|
4
|
+
Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
you may not use this file except in compliance with the License.
|
|
6
|
+
You may obtain a copy of the License at
|
|
7
|
+
|
|
8
|
+
http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
|
|
10
|
+
Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
See the License for the specific language governing permissions and
|
|
14
|
+
limitations under the License.
|
|
15
|
+
"""
|
|
16
|
+
|
|
17
|
+
from abc import ABC, abstractmethod
|
|
18
|
+
from typing import List, Tuple
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class CrossEncoderClient(ABC):
    """
    Abstract interface for cross-encoder rerankers.

    Implementations score passages by their relevance to a query, so that
    different cross-encoder backends can be swapped without changing callers.
    """

    @abstractmethod
    async def rank(self, query: str, passages: List[str]) -> List[Tuple[str, float]]:
        """
        Order the passages by how relevant each one is to the query.

        Args:
            query (str): The query string.
            passages (List[str]): The passages to rank.

        Returns:
            List[Tuple[str, float]]: (passage, score) tuples, most relevant first.
        """
        ...
|
|
@@ -0,0 +1,113 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Copyright 2024, Zep Software, Inc.
|
|
3
|
+
|
|
4
|
+
Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
you may not use this file except in compliance with the License.
|
|
6
|
+
You may obtain a copy of the License at
|
|
7
|
+
|
|
8
|
+
http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
|
|
10
|
+
Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
See the License for the specific language governing permissions and
|
|
14
|
+
limitations under the License.
|
|
15
|
+
"""
|
|
16
|
+
|
|
17
|
+
import asyncio
|
|
18
|
+
import logging
|
|
19
|
+
from typing import Any
|
|
20
|
+
|
|
21
|
+
import openai
|
|
22
|
+
from openai import AsyncOpenAI
|
|
23
|
+
from pydantic import BaseModel
|
|
24
|
+
|
|
25
|
+
from ..llm_client import LLMConfig, RateLimitError
|
|
26
|
+
from ..prompts import Message
|
|
27
|
+
from .client import CrossEncoderClient
|
|
28
|
+
|
|
29
|
+
logger = logging.getLogger(__name__)
|
|
30
|
+
|
|
31
|
+
DEFAULT_MODEL = 'gpt-4o-mini'
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class BooleanClassifier(BaseModel):
    """Pydantic model carrying a single boolean verdict."""

    # The verdict itself.
    isTrue: bool
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
class OpenAIRerankerClient(CrossEncoderClient):
    """
    Reranker that asks an OpenAI chat model whether each passage is relevant to the query.

    Every passage gets its own single-token completion; the log-probability of the
    returned "True"/"False" token is converted into a relevance score in [0, 1].
    """

    def __init__(self, config: LLMConfig | None = None):
        """
        Initialize the reranker and its AsyncOpenAI client.

        Args:
            config (LLMConfig | None): Client configuration. Only `api_key` and `base_url`
                are read here. A default `LLMConfig()` is created when omitted.
        """
        if config is None:
            config = LLMConfig()

        self.config = config
        self.client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url)

    @staticmethod
    def _relevance_score(top_logprobs: Any) -> float:
        """Turn the top log-probabilities of one completion into P(relevant)."""
        from math import exp

        # Use the first candidate that is an actual verdict token.
        # A "False" verdict with probability p means relevance is 1 - p.
        for candidate in top_logprobs or []:
            verdict = candidate.token.strip().lower()
            probability = exp(candidate.logprob)
            if verdict == 'true':
                return probability
            if verdict == 'false':
                return 1.0 - probability

        # No usable verdict came back: treat the passage as not relevant.
        return 0.0

    async def rank(self, query: str, passages: list[str]) -> list[tuple[str, float]]:
        """
        Rank the given passages based on their relevance to the query.

        Args:
            query (str): The query string.
            passages (list[str]): A list of passages to rank.

        Returns:
            list[tuple[str, float]]: One (passage, score) tuple per input passage,
            sorted in descending order of score. The score is the model's
            probability that the passage is relevant.

        Raises:
            RateLimitError: If the OpenAI API reports a rate limit.
        """
        if not passages:
            return []

        openai_messages_list: Any = [
            [
                Message(
                    role='system',
                    content='You are an expert tasked with determining whether the passage is relevant to the query',
                ),
                Message(
                    role='user',
                    content=f"""
                           Respond with "True" if PASSAGE is relevant to QUERY and "False" otherwise.
                           <PASSAGE>
                           {passage}
                           </PASSAGE>
                           <QUERY>
                           {query}
                           </QUERY>
                           """,
                ),
            ]
            for passage in passages
        ]
        try:
            # NOTE(review): the logit_bias keys look like the token ids of "True"/"False"
            # for DEFAULT_MODEL's tokenizer - confirm before changing the model.
            responses = await asyncio.gather(
                *[
                    self.client.chat.completions.create(
                        model=DEFAULT_MODEL,
                        messages=openai_messages,
                        temperature=0,
                        max_tokens=1,
                        logit_bias={'6432': 1, '7983': 1},
                        logprobs=True,
                        top_logprobs=2,
                    )
                    for openai_messages in openai_messages_list
                ]
            )

            # Exactly one score per response, so scores stay aligned with passages.
            scores: list[float] = []
            for response in responses:
                logprobs = response.choices[0].logprobs
                content = logprobs.content if logprobs is not None else None
                top_logprobs = content[0].top_logprobs if content else []
                scores.append(self._relevance_score(top_logprobs))

            results = list(zip(passages, scores))
            results.sort(reverse=True, key=lambda x: x[1])
            return results
        except openai.RateLimitError as e:
            raise RateLimitError from e
        except Exception as e:
            logger.error(f'Error in generating LLM response: {e}')
            raise
|
graphiti_core/embedder/openai.py
CHANGED
|
@@ -42,7 +42,9 @@ class OpenAIEmbedder(EmbedderClient):
|
|
|
42
42
|
self.client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
|
|
43
43
|
|
|
44
44
|
async def create(
|
|
45
|
-
self,
|
|
45
|
+
self, input_data: str | List[str] | Iterable[int] | Iterable[Iterable[int]]
|
|
46
46
|
) -> list[float]:
|
|
47
|
-
result = await self.client.embeddings.create(
|
|
47
|
+
result = await self.client.embeddings.create(
|
|
48
|
+
input=input_data, model=self.config.embedding_model
|
|
49
|
+
)
|
|
48
50
|
return result.data[0].embedding[: self.config.embedding_dim]
|
graphiti_core/embedder/voyage.py
CHANGED
|
@@ -41,7 +41,18 @@ class VoyageAIEmbedder(EmbedderClient):
|
|
|
41
41
|
self.client = voyageai.AsyncClient(api_key=config.api_key)
|
|
42
42
|
|
|
43
43
|
async def create(
|
|
44
|
-
self,
|
|
44
|
+
self, input_data: str | List[str] | Iterable[int] | Iterable[Iterable[int]]
|
|
45
45
|
) -> list[float]:
|
|
46
|
-
|
|
46
|
+
if isinstance(input_data, str):
|
|
47
|
+
input_list = [input_data]
|
|
48
|
+
elif isinstance(input_data, List):
|
|
49
|
+
input_list = [str(i) for i in input_data if i]
|
|
50
|
+
else:
|
|
51
|
+
input_list = [str(i) for i in input_data if i is not None]
|
|
52
|
+
|
|
53
|
+
input_list = [i for i in input_list if i]
|
|
54
|
+
if len(input_list) == 0:
|
|
55
|
+
return []
|
|
56
|
+
|
|
57
|
+
result = await self.client.embed(input_list, model=self.config.embedding_model)
|
|
47
58
|
return result.embeddings[0][: self.config.embedding_dim]
|
graphiti_core/graphiti.py
CHANGED
|
@@ -23,8 +23,11 @@ from dotenv import load_dotenv
|
|
|
23
23
|
from neo4j import AsyncGraphDatabase
|
|
24
24
|
from pydantic import BaseModel
|
|
25
25
|
|
|
26
|
+
from graphiti_core.cross_encoder.client import CrossEncoderClient
|
|
27
|
+
from graphiti_core.cross_encoder.openai_reranker_client import OpenAIRerankerClient
|
|
26
28
|
from graphiti_core.edges import EntityEdge, EpisodicEdge
|
|
27
29
|
from graphiti_core.embedder import EmbedderClient, OpenAIEmbedder
|
|
30
|
+
from graphiti_core.helpers import DEFAULT_DATABASE
|
|
28
31
|
from graphiti_core.llm_client import LLMClient, OpenAIClient
|
|
29
32
|
from graphiti_core.nodes import CommunityNode, EntityNode, EpisodeType, EpisodicNode
|
|
30
33
|
from graphiti_core.search.search import SearchConfig, search
|
|
@@ -92,6 +95,7 @@ class Graphiti:
|
|
|
92
95
|
password: str,
|
|
93
96
|
llm_client: LLMClient | None = None,
|
|
94
97
|
embedder: EmbedderClient | None = None,
|
|
98
|
+
cross_encoder: CrossEncoderClient | None = None,
|
|
95
99
|
store_raw_episode_content: bool = True,
|
|
96
100
|
):
|
|
97
101
|
"""
|
|
@@ -131,7 +135,7 @@ class Graphiti:
|
|
|
131
135
|
Graphiti if you're using the default OpenAIClient.
|
|
132
136
|
"""
|
|
133
137
|
self.driver = AsyncGraphDatabase.driver(uri, auth=(user, password))
|
|
134
|
-
self.database =
|
|
138
|
+
self.database = DEFAULT_DATABASE
|
|
135
139
|
self.store_raw_episode_content = store_raw_episode_content
|
|
136
140
|
if llm_client:
|
|
137
141
|
self.llm_client = llm_client
|
|
@@ -141,6 +145,10 @@ class Graphiti:
|
|
|
141
145
|
self.embedder = embedder
|
|
142
146
|
else:
|
|
143
147
|
self.embedder = OpenAIEmbedder()
|
|
148
|
+
if cross_encoder:
|
|
149
|
+
self.cross_encoder = cross_encoder
|
|
150
|
+
else:
|
|
151
|
+
self.cross_encoder = OpenAIRerankerClient()
|
|
144
152
|
|
|
145
153
|
async def close(self):
|
|
146
154
|
"""
|
|
@@ -648,6 +656,7 @@ class Graphiti:
|
|
|
648
656
|
await search(
|
|
649
657
|
self.driver,
|
|
650
658
|
self.embedder,
|
|
659
|
+
self.cross_encoder,
|
|
651
660
|
query,
|
|
652
661
|
group_ids,
|
|
653
662
|
search_config,
|
|
@@ -663,8 +672,18 @@ class Graphiti:
|
|
|
663
672
|
config: SearchConfig,
|
|
664
673
|
group_ids: list[str] | None = None,
|
|
665
674
|
center_node_uuid: str | None = None,
|
|
675
|
+
bfs_origin_node_uuids: list[str] | None = None,
|
|
666
676
|
) -> SearchResults:
|
|
667
|
-
return await search(
|
|
677
|
+
return await search(
|
|
678
|
+
self.driver,
|
|
679
|
+
self.embedder,
|
|
680
|
+
self.cross_encoder,
|
|
681
|
+
query,
|
|
682
|
+
group_ids,
|
|
683
|
+
config,
|
|
684
|
+
center_node_uuid,
|
|
685
|
+
bfs_origin_node_uuids,
|
|
686
|
+
)
|
|
668
687
|
|
|
669
688
|
async def get_nodes_by_query(
|
|
670
689
|
self,
|
|
@@ -716,7 +735,13 @@ class Graphiti:
|
|
|
716
735
|
|
|
717
736
|
nodes = (
|
|
718
737
|
await search(
|
|
719
|
-
self.driver,
|
|
738
|
+
self.driver,
|
|
739
|
+
self.embedder,
|
|
740
|
+
self.cross_encoder,
|
|
741
|
+
query,
|
|
742
|
+
group_ids,
|
|
743
|
+
search_config,
|
|
744
|
+
center_node_uuid,
|
|
720
745
|
)
|
|
721
746
|
).nodes
|
|
722
747
|
return nodes
|
graphiti_core/search/search.py
CHANGED
|
@@ -21,6 +21,7 @@ from time import time
|
|
|
21
21
|
|
|
22
22
|
from neo4j import AsyncDriver
|
|
23
23
|
|
|
24
|
+
from graphiti_core.cross_encoder.client import CrossEncoderClient
|
|
24
25
|
from graphiti_core.edges import EntityEdge
|
|
25
26
|
from graphiti_core.embedder import EmbedderClient
|
|
26
27
|
from graphiti_core.errors import SearchRerankerError
|
|
@@ -39,10 +40,12 @@ from graphiti_core.search.search_config import (
|
|
|
39
40
|
from graphiti_core.search.search_utils import (
|
|
40
41
|
community_fulltext_search,
|
|
41
42
|
community_similarity_search,
|
|
43
|
+
edge_bfs_search,
|
|
42
44
|
edge_fulltext_search,
|
|
43
45
|
edge_similarity_search,
|
|
44
46
|
episode_mentions_reranker,
|
|
45
47
|
maximal_marginal_relevance,
|
|
48
|
+
node_bfs_search,
|
|
46
49
|
node_distance_reranker,
|
|
47
50
|
node_fulltext_search,
|
|
48
51
|
node_similarity_search,
|
|
@@ -55,40 +58,49 @@ logger = logging.getLogger(__name__)
|
|
|
55
58
|
async def search(
|
|
56
59
|
driver: AsyncDriver,
|
|
57
60
|
embedder: EmbedderClient,
|
|
61
|
+
cross_encoder: CrossEncoderClient,
|
|
58
62
|
query: str,
|
|
59
63
|
group_ids: list[str] | None,
|
|
60
64
|
config: SearchConfig,
|
|
61
65
|
center_node_uuid: str | None = None,
|
|
66
|
+
bfs_origin_node_uuids: list[str] | None = None,
|
|
62
67
|
) -> SearchResults:
|
|
63
68
|
start = time()
|
|
64
|
-
|
|
69
|
+
query_vector = await embedder.create(input=[query.replace('\n', ' ')])
|
|
70
|
+
|
|
65
71
|
# if group_ids is empty, set it to None
|
|
66
72
|
group_ids = group_ids if group_ids else None
|
|
67
73
|
edges, nodes, communities = await asyncio.gather(
|
|
68
74
|
edge_search(
|
|
69
75
|
driver,
|
|
70
|
-
|
|
76
|
+
cross_encoder,
|
|
71
77
|
query,
|
|
78
|
+
query_vector,
|
|
72
79
|
group_ids,
|
|
73
80
|
config.edge_config,
|
|
74
81
|
center_node_uuid,
|
|
82
|
+
bfs_origin_node_uuids,
|
|
75
83
|
config.limit,
|
|
76
84
|
),
|
|
77
85
|
node_search(
|
|
78
86
|
driver,
|
|
79
|
-
|
|
87
|
+
cross_encoder,
|
|
80
88
|
query,
|
|
89
|
+
query_vector,
|
|
81
90
|
group_ids,
|
|
82
91
|
config.node_config,
|
|
83
92
|
center_node_uuid,
|
|
93
|
+
bfs_origin_node_uuids,
|
|
84
94
|
config.limit,
|
|
85
95
|
),
|
|
86
96
|
community_search(
|
|
87
97
|
driver,
|
|
88
|
-
|
|
98
|
+
cross_encoder,
|
|
89
99
|
query,
|
|
100
|
+
query_vector,
|
|
90
101
|
group_ids,
|
|
91
102
|
config.community_config,
|
|
103
|
+
bfs_origin_node_uuids,
|
|
92
104
|
config.limit,
|
|
93
105
|
),
|
|
94
106
|
)
|
|
@@ -99,27 +111,27 @@ async def search(
|
|
|
99
111
|
communities=communities,
|
|
100
112
|
)
|
|
101
113
|
|
|
102
|
-
|
|
114
|
+
latency = (time() - start) * 1000
|
|
103
115
|
|
|
104
|
-
logger.
|
|
116
|
+
logger.debug(f'search returned context for query {query} in {latency} ms')
|
|
105
117
|
|
|
106
118
|
return results
|
|
107
119
|
|
|
108
120
|
|
|
109
121
|
async def edge_search(
|
|
110
122
|
driver: AsyncDriver,
|
|
111
|
-
|
|
123
|
+
cross_encoder: CrossEncoderClient,
|
|
112
124
|
query: str,
|
|
125
|
+
query_vector: list[float],
|
|
113
126
|
group_ids: list[str] | None,
|
|
114
127
|
config: EdgeSearchConfig | None,
|
|
115
128
|
center_node_uuid: str | None = None,
|
|
129
|
+
bfs_origin_node_uuids: list[str] | None = None,
|
|
116
130
|
limit=DEFAULT_SEARCH_LIMIT,
|
|
117
131
|
) -> list[EntityEdge]:
|
|
118
132
|
if config is None:
|
|
119
133
|
return []
|
|
120
134
|
|
|
121
|
-
query_vector = await embedder.create(input=[query])
|
|
122
|
-
|
|
123
135
|
search_results: list[list[EntityEdge]] = list(
|
|
124
136
|
await asyncio.gather(
|
|
125
137
|
*[
|
|
@@ -127,6 +139,7 @@ async def edge_search(
|
|
|
127
139
|
edge_similarity_search(
|
|
128
140
|
driver, query_vector, None, None, group_ids, 2 * limit, config.sim_min_score
|
|
129
141
|
),
|
|
142
|
+
edge_bfs_search(driver, bfs_origin_node_uuids, config.bfs_max_depth, 2 * limit),
|
|
130
143
|
]
|
|
131
144
|
)
|
|
132
145
|
)
|
|
@@ -147,6 +160,15 @@ async def edge_search(
|
|
|
147
160
|
reranked_uuids = maximal_marginal_relevance(
|
|
148
161
|
query_vector, search_result_uuids_and_vectors, config.mmr_lambda
|
|
149
162
|
)
|
|
163
|
+
elif config.reranker == EdgeReranker.cross_encoder:
|
|
164
|
+
search_result_uuids = [[edge.uuid for edge in result] for result in search_results]
|
|
165
|
+
|
|
166
|
+
rrf_result_uuids = rrf(search_result_uuids)
|
|
167
|
+
rrf_edges = [edge_uuid_map[uuid] for uuid in rrf_result_uuids][:limit]
|
|
168
|
+
|
|
169
|
+
fact_to_uuid_map = {edge.fact: edge.uuid for edge in rrf_edges}
|
|
170
|
+
reranked_facts = await cross_encoder.rank(query, list(fact_to_uuid_map.keys()))
|
|
171
|
+
reranked_uuids = [fact_to_uuid_map[fact] for fact, _ in reranked_facts]
|
|
150
172
|
elif config.reranker == EdgeReranker.node_distance:
|
|
151
173
|
if center_node_uuid is None:
|
|
152
174
|
raise SearchRerankerError('No center node provided for Node Distance reranker')
|
|
@@ -177,18 +199,18 @@ async def edge_search(
|
|
|
177
199
|
|
|
178
200
|
async def node_search(
|
|
179
201
|
driver: AsyncDriver,
|
|
180
|
-
|
|
202
|
+
cross_encoder: CrossEncoderClient,
|
|
181
203
|
query: str,
|
|
204
|
+
query_vector: list[float],
|
|
182
205
|
group_ids: list[str] | None,
|
|
183
206
|
config: NodeSearchConfig | None,
|
|
184
207
|
center_node_uuid: str | None = None,
|
|
208
|
+
bfs_origin_node_uuids: list[str] | None = None,
|
|
185
209
|
limit=DEFAULT_SEARCH_LIMIT,
|
|
186
210
|
) -> list[EntityNode]:
|
|
187
211
|
if config is None:
|
|
188
212
|
return []
|
|
189
213
|
|
|
190
|
-
query_vector = await embedder.create(input=[query])
|
|
191
|
-
|
|
192
214
|
search_results: list[list[EntityNode]] = list(
|
|
193
215
|
await asyncio.gather(
|
|
194
216
|
*[
|
|
@@ -196,6 +218,7 @@ async def node_search(
|
|
|
196
218
|
node_similarity_search(
|
|
197
219
|
driver, query_vector, group_ids, 2 * limit, config.sim_min_score
|
|
198
220
|
),
|
|
221
|
+
node_bfs_search(driver, bfs_origin_node_uuids, config.bfs_max_depth, 2 * limit),
|
|
199
222
|
]
|
|
200
223
|
)
|
|
201
224
|
)
|
|
@@ -215,6 +238,15 @@ async def node_search(
|
|
|
215
238
|
reranked_uuids = maximal_marginal_relevance(
|
|
216
239
|
query_vector, search_result_uuids_and_vectors, config.mmr_lambda
|
|
217
240
|
)
|
|
241
|
+
elif config.reranker == NodeReranker.cross_encoder:
|
|
242
|
+
# use rrf as a preliminary reranker
|
|
243
|
+
rrf_result_uuids = rrf(search_result_uuids)
|
|
244
|
+
rrf_results = [node_uuid_map[uuid] for uuid in rrf_result_uuids][:limit]
|
|
245
|
+
|
|
246
|
+
summary_to_uuid_map = {node.summary: node.uuid for node in rrf_results}
|
|
247
|
+
|
|
248
|
+
reranked_summaries = await cross_encoder.rank(query, list(summary_to_uuid_map.keys()))
|
|
249
|
+
reranked_uuids = [summary_to_uuid_map[fact] for fact, _ in reranked_summaries]
|
|
218
250
|
elif config.reranker == NodeReranker.episode_mentions:
|
|
219
251
|
reranked_uuids = await episode_mentions_reranker(driver, search_result_uuids)
|
|
220
252
|
elif config.reranker == NodeReranker.node_distance:
|
|
@@ -231,17 +263,17 @@ async def node_search(
|
|
|
231
263
|
|
|
232
264
|
async def community_search(
|
|
233
265
|
driver: AsyncDriver,
|
|
234
|
-
|
|
266
|
+
cross_encoder: CrossEncoderClient,
|
|
235
267
|
query: str,
|
|
268
|
+
query_vector: list[float],
|
|
236
269
|
group_ids: list[str] | None,
|
|
237
270
|
config: CommunitySearchConfig | None,
|
|
271
|
+
bfs_origin_node_uuids: list[str] | None = None,
|
|
238
272
|
limit=DEFAULT_SEARCH_LIMIT,
|
|
239
273
|
) -> list[CommunityNode]:
|
|
240
274
|
if config is None:
|
|
241
275
|
return []
|
|
242
276
|
|
|
243
|
-
query_vector = await embedder.create(input=[query])
|
|
244
|
-
|
|
245
277
|
search_results: list[list[CommunityNode]] = list(
|
|
246
278
|
await asyncio.gather(
|
|
247
279
|
*[
|
|
@@ -273,6 +305,12 @@ async def community_search(
|
|
|
273
305
|
reranked_uuids = maximal_marginal_relevance(
|
|
274
306
|
query_vector, search_result_uuids_and_vectors, config.mmr_lambda
|
|
275
307
|
)
|
|
308
|
+
elif config.reranker == CommunityReranker.cross_encoder:
|
|
309
|
+
summary_to_uuid_map = {
|
|
310
|
+
node.summary: node.uuid for result in search_results for node in result
|
|
311
|
+
}
|
|
312
|
+
reranked_summaries = await cross_encoder.rank(query, list(summary_to_uuid_map.keys()))
|
|
313
|
+
reranked_uuids = [summary_to_uuid_map[fact] for fact, _ in reranked_summaries]
|
|
276
314
|
|
|
277
315
|
reranked_communities = [community_uuid_map[uuid] for uuid in reranked_uuids]
|
|
278
316
|
|
|
@@ -20,7 +20,11 @@ from pydantic import BaseModel, Field
|
|
|
20
20
|
|
|
21
21
|
from graphiti_core.edges import EntityEdge
|
|
22
22
|
from graphiti_core.nodes import CommunityNode, EntityNode
|
|
23
|
-
from graphiti_core.search.search_utils import
|
|
23
|
+
from graphiti_core.search.search_utils import (
|
|
24
|
+
DEFAULT_MIN_SCORE,
|
|
25
|
+
DEFAULT_MMR_LAMBDA,
|
|
26
|
+
MAX_SEARCH_DEPTH,
|
|
27
|
+
)
|
|
24
28
|
|
|
25
29
|
DEFAULT_SEARCH_LIMIT = 10
|
|
26
30
|
|
|
@@ -28,11 +32,13 @@ DEFAULT_SEARCH_LIMIT = 10
|
|
|
28
32
|
class EdgeSearchMethod(Enum):
|
|
29
33
|
cosine_similarity = 'cosine_similarity'
|
|
30
34
|
bm25 = 'bm25'
|
|
35
|
+
bfs = 'breadth_first_search'
|
|
31
36
|
|
|
32
37
|
|
|
33
38
|
class NodeSearchMethod(Enum):
|
|
34
39
|
cosine_similarity = 'cosine_similarity'
|
|
35
40
|
bm25 = 'bm25'
|
|
41
|
+
bfs = 'breadth_first_search'
|
|
36
42
|
|
|
37
43
|
|
|
38
44
|
class CommunitySearchMethod(Enum):
|
|
@@ -45,6 +51,7 @@ class EdgeReranker(Enum):
|
|
|
45
51
|
node_distance = 'node_distance'
|
|
46
52
|
episode_mentions = 'episode_mentions'
|
|
47
53
|
mmr = 'mmr'
|
|
54
|
+
cross_encoder = 'cross_encoder'
|
|
48
55
|
|
|
49
56
|
|
|
50
57
|
class NodeReranker(Enum):
|
|
@@ -52,11 +59,13 @@ class NodeReranker(Enum):
|
|
|
52
59
|
node_distance = 'node_distance'
|
|
53
60
|
episode_mentions = 'episode_mentions'
|
|
54
61
|
mmr = 'mmr'
|
|
62
|
+
cross_encoder = 'cross_encoder'
|
|
55
63
|
|
|
56
64
|
|
|
57
65
|
class CommunityReranker(Enum):
|
|
58
66
|
rrf = 'reciprocal_rank_fusion'
|
|
59
67
|
mmr = 'mmr'
|
|
68
|
+
cross_encoder = 'cross_encoder'
|
|
60
69
|
|
|
61
70
|
|
|
62
71
|
class EdgeSearchConfig(BaseModel):
|
|
@@ -64,6 +73,7 @@ class EdgeSearchConfig(BaseModel):
|
|
|
64
73
|
reranker: EdgeReranker = Field(default=EdgeReranker.rrf)
|
|
65
74
|
sim_min_score: float = Field(default=DEFAULT_MIN_SCORE)
|
|
66
75
|
mmr_lambda: float = Field(default=DEFAULT_MMR_LAMBDA)
|
|
76
|
+
bfs_max_depth: int = Field(default=MAX_SEARCH_DEPTH)
|
|
67
77
|
|
|
68
78
|
|
|
69
79
|
class NodeSearchConfig(BaseModel):
|
|
@@ -71,6 +81,7 @@ class NodeSearchConfig(BaseModel):
|
|
|
71
81
|
reranker: NodeReranker = Field(default=NodeReranker.rrf)
|
|
72
82
|
sim_min_score: float = Field(default=DEFAULT_MIN_SCORE)
|
|
73
83
|
mmr_lambda: float = Field(default=DEFAULT_MMR_LAMBDA)
|
|
84
|
+
bfs_max_depth: int = Field(default=MAX_SEARCH_DEPTH)
|
|
74
85
|
|
|
75
86
|
|
|
76
87
|
class CommunitySearchConfig(BaseModel):
|
|
@@ -78,6 +89,7 @@ class CommunitySearchConfig(BaseModel):
|
|
|
78
89
|
reranker: CommunityReranker = Field(default=CommunityReranker.rrf)
|
|
79
90
|
sim_min_score: float = Field(default=DEFAULT_MIN_SCORE)
|
|
80
91
|
mmr_lambda: float = Field(default=DEFAULT_MMR_LAMBDA)
|
|
92
|
+
bfs_max_depth: int = Field(default=MAX_SEARCH_DEPTH)
|
|
81
93
|
|
|
82
94
|
|
|
83
95
|
class SearchConfig(BaseModel):
|
|
@@ -48,14 +48,41 @@ COMBINED_HYBRID_SEARCH_MMR = SearchConfig(
|
|
|
48
48
|
edge_config=EdgeSearchConfig(
|
|
49
49
|
search_methods=[EdgeSearchMethod.bm25, EdgeSearchMethod.cosine_similarity],
|
|
50
50
|
reranker=EdgeReranker.mmr,
|
|
51
|
+
mmr_lambda=1,
|
|
51
52
|
),
|
|
52
53
|
node_config=NodeSearchConfig(
|
|
53
54
|
search_methods=[NodeSearchMethod.bm25, NodeSearchMethod.cosine_similarity],
|
|
54
55
|
reranker=NodeReranker.mmr,
|
|
56
|
+
mmr_lambda=1,
|
|
55
57
|
),
|
|
56
58
|
community_config=CommunitySearchConfig(
|
|
57
59
|
search_methods=[CommunitySearchMethod.bm25, CommunitySearchMethod.cosine_similarity],
|
|
58
60
|
reranker=CommunityReranker.mmr,
|
|
61
|
+
mmr_lambda=1,
|
|
62
|
+
),
|
|
63
|
+
)
|
|
64
|
+
|
|
65
|
+
# Performs a full-text search, similarity search, and bfs with cross_encoder reranking over edges, nodes, and communities
|
|
66
|
+
COMBINED_HYBRID_SEARCH_CROSS_ENCODER = SearchConfig(
|
|
67
|
+
edge_config=EdgeSearchConfig(
|
|
68
|
+
search_methods=[
|
|
69
|
+
EdgeSearchMethod.bm25,
|
|
70
|
+
EdgeSearchMethod.cosine_similarity,
|
|
71
|
+
EdgeSearchMethod.bfs,
|
|
72
|
+
],
|
|
73
|
+
reranker=EdgeReranker.cross_encoder,
|
|
74
|
+
),
|
|
75
|
+
node_config=NodeSearchConfig(
|
|
76
|
+
search_methods=[
|
|
77
|
+
NodeSearchMethod.bm25,
|
|
78
|
+
NodeSearchMethod.cosine_similarity,
|
|
79
|
+
NodeSearchMethod.bfs,
|
|
80
|
+
],
|
|
81
|
+
reranker=NodeReranker.cross_encoder,
|
|
82
|
+
),
|
|
83
|
+
community_config=CommunitySearchConfig(
|
|
84
|
+
search_methods=[CommunitySearchMethod.bm25, CommunitySearchMethod.cosine_similarity],
|
|
85
|
+
reranker=CommunityReranker.cross_encoder,
|
|
59
86
|
),
|
|
60
87
|
)
|
|
61
88
|
|
|
@@ -81,7 +108,6 @@ EDGE_HYBRID_SEARCH_NODE_DISTANCE = SearchConfig(
|
|
|
81
108
|
search_methods=[EdgeSearchMethod.bm25, EdgeSearchMethod.cosine_similarity],
|
|
82
109
|
reranker=EdgeReranker.node_distance,
|
|
83
110
|
),
|
|
84
|
-
limit=30,
|
|
85
111
|
)
|
|
86
112
|
|
|
87
113
|
# performs a hybrid search over edges with episode mention reranking
|
|
@@ -37,6 +37,7 @@ logger = logging.getLogger(__name__)
|
|
|
37
37
|
RELEVANT_SCHEMA_LIMIT = 3
|
|
38
38
|
DEFAULT_MIN_SCORE = 0.6
|
|
39
39
|
DEFAULT_MMR_LAMBDA = 0.5
|
|
40
|
+
MAX_SEARCH_DEPTH = 3
|
|
40
41
|
MAX_QUERY_LENGTH = 128
|
|
41
42
|
|
|
42
43
|
|
|
@@ -79,21 +80,21 @@ async def get_mentioned_nodes(
|
|
|
79
80
|
driver: AsyncDriver, episodes: list[EpisodicNode]
|
|
80
81
|
) -> list[EntityNode]:
|
|
81
82
|
episode_uuids = [episode.uuid for episode in episodes]
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
83
|
+
records, _, _ = await driver.execute_query(
|
|
84
|
+
"""
|
|
85
|
+
MATCH (episode:Episodic)-[:MENTIONS]->(n:Entity) WHERE episode.uuid IN $uuids
|
|
86
|
+
RETURN DISTINCT
|
|
87
|
+
n.uuid As uuid,
|
|
88
|
+
n.group_id AS group_id,
|
|
89
|
+
n.name AS name,
|
|
90
|
+
n.name_embedding AS name_embedding,
|
|
91
|
+
n.created_at AS created_at,
|
|
92
|
+
n.summary AS summary
|
|
93
|
+
""",
|
|
94
|
+
uuids=episode_uuids,
|
|
95
|
+
database_=DEFAULT_DATABASE,
|
|
96
|
+
routing_='r',
|
|
97
|
+
)
|
|
97
98
|
|
|
98
99
|
nodes = [get_entity_node_from_record(record) for record in records]
|
|
99
100
|
|
|
@@ -104,21 +105,21 @@ async def get_communities_by_nodes(
|
|
|
104
105
|
driver: AsyncDriver, nodes: list[EntityNode]
|
|
105
106
|
) -> list[CommunityNode]:
|
|
106
107
|
node_uuids = [node.uuid for node in nodes]
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
108
|
+
records, _, _ = await driver.execute_query(
|
|
109
|
+
"""
|
|
110
|
+
MATCH (c:Community)-[:HAS_MEMBER]->(n:Entity) WHERE n.uuid IN $uuids
|
|
111
|
+
RETURN DISTINCT
|
|
112
|
+
c.uuid As uuid,
|
|
113
|
+
c.group_id AS group_id,
|
|
114
|
+
c.name AS name,
|
|
115
|
+
c.name_embedding AS name_embedding
|
|
116
|
+
c.created_at AS created_at,
|
|
117
|
+
c.summary AS summary
|
|
118
|
+
""",
|
|
119
|
+
uuids=node_uuids,
|
|
120
|
+
database_=DEFAULT_DATABASE,
|
|
121
|
+
routing_='r',
|
|
122
|
+
)
|
|
122
123
|
|
|
123
124
|
communities = [get_community_node_from_record(record) for record in records]
|
|
124
125
|
|
|
@@ -141,8 +142,10 @@ async def edge_fulltext_search(
|
|
|
141
142
|
cypher_query = Query("""
|
|
142
143
|
CALL db.index.fulltext.queryRelationships("edge_name_and_fact", $query)
|
|
143
144
|
YIELD relationship AS rel, score
|
|
144
|
-
MATCH (n:Entity)-[r {uuid: rel.uuid}]
|
|
145
|
-
|
|
145
|
+
MATCH (n:Entity)-[r {uuid: rel.uuid}]->(m:Entity)
|
|
146
|
+
WHERE ($source_uuid IS NULL OR n.uuid IN [$source_uuid, $target_uuid])
|
|
147
|
+
AND ($target_uuid IS NULL OR m.uuid IN [$source_uuid, $target_uuid])
|
|
148
|
+
RETURN
|
|
146
149
|
r.uuid AS uuid,
|
|
147
150
|
r.group_id AS group_id,
|
|
148
151
|
n.uuid AS source_node_uuid,
|
|
@@ -158,18 +161,16 @@ async def edge_fulltext_search(
|
|
|
158
161
|
ORDER BY score DESC LIMIT $limit
|
|
159
162
|
""")
|
|
160
163
|
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
)
|
|
172
|
-
records = [record async for record in result]
|
|
164
|
+
records, _, _ = await driver.execute_query(
|
|
165
|
+
cypher_query,
|
|
166
|
+
query=fuzzy_query,
|
|
167
|
+
source_uuid=source_node_uuid,
|
|
168
|
+
target_uuid=target_node_uuid,
|
|
169
|
+
group_ids=group_ids,
|
|
170
|
+
limit=limit,
|
|
171
|
+
database_=DEFAULT_DATABASE,
|
|
172
|
+
routing_='r',
|
|
173
|
+
)
|
|
173
174
|
|
|
174
175
|
edges = [get_entity_edge_from_record(record) for record in records]
|
|
175
176
|
|
|
@@ -188,17 +189,17 @@ async def edge_similarity_search(
|
|
|
188
189
|
# vector similarity search over embedded facts
|
|
189
190
|
query = Query("""
|
|
190
191
|
CYPHER runtime = parallel parallelRuntimeSupport=all
|
|
191
|
-
MATCH (n:Entity)-[r:RELATES_TO]
|
|
192
|
+
MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity)
|
|
192
193
|
WHERE ($group_ids IS NULL OR r.group_id IN $group_ids)
|
|
193
|
-
AND ($source_uuid IS NULL OR n.uuid
|
|
194
|
-
AND ($target_uuid IS NULL OR m.uuid
|
|
195
|
-
WITH
|
|
194
|
+
AND ($source_uuid IS NULL OR n.uuid IN [$source_uuid, $target_uuid])
|
|
195
|
+
AND ($target_uuid IS NULL OR m.uuid IN [$source_uuid, $target_uuid])
|
|
196
|
+
WITH DISTINCT r, vector.similarity.cosine(r.fact_embedding, $search_vector) AS score
|
|
196
197
|
WHERE score > $min_score
|
|
197
198
|
RETURN
|
|
198
199
|
r.uuid AS uuid,
|
|
199
200
|
r.group_id AS group_id,
|
|
200
|
-
|
|
201
|
-
|
|
201
|
+
startNode(r).uuid AS source_node_uuid,
|
|
202
|
+
endNode(r).uuid AS target_node_uuid,
|
|
202
203
|
r.created_at AS created_at,
|
|
203
204
|
r.name AS name,
|
|
204
205
|
r.fact AS fact,
|
|
@@ -211,19 +212,62 @@ async def edge_similarity_search(
|
|
|
211
212
|
LIMIT $limit
|
|
212
213
|
""")
|
|
213
214
|
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
215
|
+
records, _, _ = await driver.execute_query(
|
|
216
|
+
query,
|
|
217
|
+
search_vector=search_vector,
|
|
218
|
+
source_uuid=source_node_uuid,
|
|
219
|
+
target_uuid=target_node_uuid,
|
|
220
|
+
group_ids=group_ids,
|
|
221
|
+
limit=limit,
|
|
222
|
+
min_score=min_score,
|
|
223
|
+
database_=DEFAULT_DATABASE,
|
|
224
|
+
routing_='r',
|
|
225
|
+
)
|
|
226
|
+
|
|
227
|
+
edges = [get_entity_edge_from_record(record) for record in records]
|
|
228
|
+
|
|
229
|
+
return edges
|
|
230
|
+
|
|
231
|
+
|
|
232
|
+
async def edge_bfs_search(
|
|
233
|
+
driver: AsyncDriver,
|
|
234
|
+
bfs_origin_node_uuids: list[str] | None,
|
|
235
|
+
bfs_max_depth: int,
|
|
236
|
+
limit: int,
|
|
237
|
+
) -> list[EntityEdge]:
|
|
238
|
+
# vector similarity search over embedded facts
|
|
239
|
+
if bfs_origin_node_uuids is None:
|
|
240
|
+
return []
|
|
241
|
+
|
|
242
|
+
query = Query("""
|
|
243
|
+
UNWIND $bfs_origin_node_uuids AS origin_uuid
|
|
244
|
+
MATCH path = (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,3}(n:Entity)
|
|
245
|
+
UNWIND relationships(path) AS rel
|
|
246
|
+
MATCH ()-[r:RELATES_TO {uuid: rel.uuid}]-()
|
|
247
|
+
RETURN DISTINCT
|
|
248
|
+
r.uuid AS uuid,
|
|
249
|
+
r.group_id AS group_id,
|
|
250
|
+
startNode(r).uuid AS source_node_uuid,
|
|
251
|
+
endNode(r).uuid AS target_node_uuid,
|
|
252
|
+
r.created_at AS created_at,
|
|
253
|
+
r.name AS name,
|
|
254
|
+
r.fact AS fact,
|
|
255
|
+
r.fact_embedding AS fact_embedding,
|
|
256
|
+
r.episodes AS episodes,
|
|
257
|
+
r.expired_at AS expired_at,
|
|
258
|
+
r.valid_at AS valid_at,
|
|
259
|
+
r.invalid_at AS invalid_at
|
|
260
|
+
LIMIT $limit
|
|
261
|
+
""")
|
|
262
|
+
|
|
263
|
+
records, _, _ = await driver.execute_query(
|
|
264
|
+
query,
|
|
265
|
+
bfs_origin_node_uuids=bfs_origin_node_uuids,
|
|
266
|
+
depth=bfs_max_depth,
|
|
267
|
+
limit=limit,
|
|
268
|
+
database_=DEFAULT_DATABASE,
|
|
269
|
+
routing_='r',
|
|
270
|
+
)
|
|
227
271
|
|
|
228
272
|
edges = [get_entity_edge_from_record(record) for record in records]
|
|
229
273
|
|
|
@@ -241,28 +285,26 @@ async def node_fulltext_search(
|
|
|
241
285
|
if fuzzy_query == '':
|
|
242
286
|
return []
|
|
243
287
|
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
)
|
|
265
|
-
records = [record async for record in result]
|
|
288
|
+
records, _, _ = await driver.execute_query(
|
|
289
|
+
"""
|
|
290
|
+
CALL db.index.fulltext.queryNodes("node_name_and_summary", $query)
|
|
291
|
+
YIELD node AS n, score
|
|
292
|
+
RETURN
|
|
293
|
+
n.uuid AS uuid,
|
|
294
|
+
n.group_id AS group_id,
|
|
295
|
+
n.name AS name,
|
|
296
|
+
n.name_embedding AS name_embedding,
|
|
297
|
+
n.created_at AS created_at,
|
|
298
|
+
n.summary AS summary
|
|
299
|
+
ORDER BY score DESC
|
|
300
|
+
LIMIT $limit
|
|
301
|
+
""",
|
|
302
|
+
query=fuzzy_query,
|
|
303
|
+
group_ids=group_ids,
|
|
304
|
+
limit=limit,
|
|
305
|
+
database_=DEFAULT_DATABASE,
|
|
306
|
+
routing_='r',
|
|
307
|
+
)
|
|
266
308
|
nodes = [get_entity_node_from_record(record) for record in records]
|
|
267
309
|
|
|
268
310
|
return nodes
|
|
@@ -276,32 +318,64 @@ async def node_similarity_search(
|
|
|
276
318
|
min_score: float = DEFAULT_MIN_SCORE,
|
|
277
319
|
) -> list[EntityNode]:
|
|
278
320
|
# vector similarity search over entity names
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
|
|
284
|
-
|
|
285
|
-
|
|
286
|
-
|
|
287
|
-
|
|
288
|
-
|
|
289
|
-
|
|
290
|
-
|
|
291
|
-
|
|
292
|
-
|
|
293
|
-
|
|
294
|
-
|
|
295
|
-
|
|
296
|
-
|
|
297
|
-
|
|
298
|
-
|
|
299
|
-
|
|
300
|
-
|
|
301
|
-
|
|
302
|
-
|
|
303
|
-
|
|
304
|
-
|
|
321
|
+
records, _, _ = await driver.execute_query(
|
|
322
|
+
"""
|
|
323
|
+
CYPHER runtime = parallel parallelRuntimeSupport=all
|
|
324
|
+
MATCH (n:Entity)
|
|
325
|
+
WHERE $group_ids IS NULL OR n.group_id IN $group_ids
|
|
326
|
+
WITH n, vector.similarity.cosine(n.name_embedding, $search_vector) AS score
|
|
327
|
+
WHERE score > $min_score
|
|
328
|
+
RETURN
|
|
329
|
+
n.uuid As uuid,
|
|
330
|
+
n.group_id AS group_id,
|
|
331
|
+
n.name AS name,
|
|
332
|
+
n.name_embedding AS name_embedding,
|
|
333
|
+
n.created_at AS created_at,
|
|
334
|
+
n.summary AS summary
|
|
335
|
+
ORDER BY score DESC
|
|
336
|
+
LIMIT $limit
|
|
337
|
+
""",
|
|
338
|
+
search_vector=search_vector,
|
|
339
|
+
group_ids=group_ids,
|
|
340
|
+
limit=limit,
|
|
341
|
+
min_score=min_score,
|
|
342
|
+
database_=DEFAULT_DATABASE,
|
|
343
|
+
routing_='r',
|
|
344
|
+
)
|
|
345
|
+
nodes = [get_entity_node_from_record(record) for record in records]
|
|
346
|
+
|
|
347
|
+
return nodes
|
|
348
|
+
|
|
349
|
+
|
|
350
|
+
async def node_bfs_search(
|
|
351
|
+
driver: AsyncDriver,
|
|
352
|
+
bfs_origin_node_uuids: list[str] | None,
|
|
353
|
+
bfs_max_depth: int,
|
|
354
|
+
limit: int,
|
|
355
|
+
) -> list[EntityNode]:
|
|
356
|
+
# vector similarity search over entity names
|
|
357
|
+
if bfs_origin_node_uuids is None:
|
|
358
|
+
return []
|
|
359
|
+
|
|
360
|
+
records, _, _ = await driver.execute_query(
|
|
361
|
+
"""
|
|
362
|
+
UNWIND $bfs_origin_node_uuids AS origin_uuid
|
|
363
|
+
MATCH (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,3}(n:Entity)
|
|
364
|
+
RETURN DISTINCT
|
|
365
|
+
n.uuid As uuid,
|
|
366
|
+
n.group_id AS group_id,
|
|
367
|
+
n.name AS name,
|
|
368
|
+
n.name_embedding AS name_embedding,
|
|
369
|
+
n.created_at AS created_at,
|
|
370
|
+
n.summary AS summary
|
|
371
|
+
LIMIT $limit
|
|
372
|
+
""",
|
|
373
|
+
bfs_origin_node_uuids=bfs_origin_node_uuids,
|
|
374
|
+
depth=bfs_max_depth,
|
|
375
|
+
limit=limit,
|
|
376
|
+
database_=DEFAULT_DATABASE,
|
|
377
|
+
routing_='r',
|
|
378
|
+
)
|
|
305
379
|
nodes = [get_entity_node_from_record(record) for record in records]
|
|
306
380
|
|
|
307
381
|
return nodes
|
|
@@ -318,28 +392,26 @@ async def community_fulltext_search(
|
|
|
318
392
|
if fuzzy_query == '':
|
|
319
393
|
return []
|
|
320
394
|
|
|
321
|
-
|
|
322
|
-
|
|
323
|
-
|
|
324
|
-
|
|
325
|
-
|
|
326
|
-
|
|
327
|
-
|
|
328
|
-
|
|
329
|
-
|
|
330
|
-
|
|
331
|
-
|
|
332
|
-
|
|
333
|
-
|
|
334
|
-
|
|
335
|
-
|
|
336
|
-
|
|
337
|
-
|
|
338
|
-
|
|
339
|
-
|
|
340
|
-
|
|
341
|
-
)
|
|
342
|
-
records = [record async for record in result]
|
|
395
|
+
records, _, _ = await driver.execute_query(
|
|
396
|
+
"""
|
|
397
|
+
CALL db.index.fulltext.queryNodes("community_name", $query)
|
|
398
|
+
YIELD node AS comm, score
|
|
399
|
+
RETURN
|
|
400
|
+
comm.uuid AS uuid,
|
|
401
|
+
comm.group_id AS group_id,
|
|
402
|
+
comm.name AS name,
|
|
403
|
+
comm.name_embedding AS name_embedding,
|
|
404
|
+
comm.created_at AS created_at,
|
|
405
|
+
comm.summary AS summary
|
|
406
|
+
ORDER BY score DESC
|
|
407
|
+
LIMIT $limit
|
|
408
|
+
""",
|
|
409
|
+
query=fuzzy_query,
|
|
410
|
+
group_ids=group_ids,
|
|
411
|
+
limit=limit,
|
|
412
|
+
database_=DEFAULT_DATABASE,
|
|
413
|
+
routing_='r',
|
|
414
|
+
)
|
|
343
415
|
communities = [get_community_node_from_record(record) for record in records]
|
|
344
416
|
|
|
345
417
|
return communities
|
|
@@ -353,32 +425,30 @@ async def community_similarity_search(
|
|
|
353
425
|
min_score=DEFAULT_MIN_SCORE,
|
|
354
426
|
) -> list[CommunityNode]:
|
|
355
427
|
# vector similarity search over entity names
|
|
356
|
-
|
|
357
|
-
|
|
358
|
-
|
|
359
|
-
|
|
360
|
-
|
|
361
|
-
|
|
362
|
-
|
|
363
|
-
|
|
364
|
-
|
|
365
|
-
|
|
366
|
-
|
|
367
|
-
|
|
368
|
-
|
|
369
|
-
|
|
370
|
-
|
|
371
|
-
|
|
372
|
-
|
|
373
|
-
|
|
374
|
-
|
|
375
|
-
|
|
376
|
-
|
|
377
|
-
|
|
378
|
-
|
|
379
|
-
|
|
380
|
-
)
|
|
381
|
-
records = [record async for record in result]
|
|
428
|
+
records, _, _ = await driver.execute_query(
|
|
429
|
+
"""
|
|
430
|
+
CYPHER runtime = parallel parallelRuntimeSupport=all
|
|
431
|
+
MATCH (comm:Community)
|
|
432
|
+
WHERE ($group_ids IS NULL OR comm.group_id IN $group_ids)
|
|
433
|
+
WITH comm, vector.similarity.cosine(comm.name_embedding, $search_vector) AS score
|
|
434
|
+
WHERE score > $min_score
|
|
435
|
+
RETURN
|
|
436
|
+
comm.uuid As uuid,
|
|
437
|
+
comm.group_id AS group_id,
|
|
438
|
+
comm.name AS name,
|
|
439
|
+
comm.name_embedding AS name_embedding,
|
|
440
|
+
comm.created_at AS created_at,
|
|
441
|
+
comm.summary AS summary
|
|
442
|
+
ORDER BY score DESC
|
|
443
|
+
LIMIT $limit
|
|
444
|
+
""",
|
|
445
|
+
search_vector=search_vector,
|
|
446
|
+
group_ids=group_ids,
|
|
447
|
+
limit=limit,
|
|
448
|
+
min_score=min_score,
|
|
449
|
+
database_=DEFAULT_DATABASE,
|
|
450
|
+
routing_='r',
|
|
451
|
+
)
|
|
382
452
|
communities = [get_community_node_from_record(record) for record in records]
|
|
383
453
|
|
|
384
454
|
return communities
|
|
@@ -554,32 +624,27 @@ async def node_distance_reranker(
|
|
|
554
624
|
driver: AsyncDriver, node_uuids: list[str], center_node_uuid: str
|
|
555
625
|
) -> list[str]:
|
|
556
626
|
# filter out node_uuid center node node uuid
|
|
557
|
-
filtered_uuids = list(filter(lambda
|
|
627
|
+
filtered_uuids = list(filter(lambda node_uuid: node_uuid != center_node_uuid, node_uuids))
|
|
558
628
|
scores: dict[str, float] = {}
|
|
559
629
|
|
|
560
630
|
# Find the shortest path to center node
|
|
561
631
|
query = Query("""
|
|
562
|
-
|
|
563
|
-
|
|
632
|
+
UNWIND $node_uuids AS node_uuid
|
|
633
|
+
MATCH p = SHORTEST 1 (center:Entity {uuid: $center_uuid})-[:RELATES_TO]-+(n:Entity {uuid: node_uuid})
|
|
634
|
+
RETURN length(p) AS score, node_uuid AS uuid
|
|
564
635
|
""")
|
|
565
636
|
|
|
566
|
-
path_results = await
|
|
567
|
-
|
|
568
|
-
|
|
569
|
-
|
|
570
|
-
|
|
571
|
-
center_uuid=center_node_uuid,
|
|
572
|
-
database_=DEFAULT_DATABASE,
|
|
573
|
-
)
|
|
574
|
-
for uuid in filtered_uuids
|
|
575
|
-
]
|
|
637
|
+
path_results, _, _ = await driver.execute_query(
|
|
638
|
+
query,
|
|
639
|
+
node_uuids=filtered_uuids,
|
|
640
|
+
center_uuid=center_node_uuid,
|
|
641
|
+
database_=DEFAULT_DATABASE,
|
|
576
642
|
)
|
|
577
643
|
|
|
578
|
-
for
|
|
579
|
-
|
|
580
|
-
|
|
581
|
-
|
|
582
|
-
scores[uuid] = distance
|
|
644
|
+
for result in path_results:
|
|
645
|
+
uuid = result['uuid']
|
|
646
|
+
score = result['score'] if 'score' in result else float('inf')
|
|
647
|
+
scores[uuid] = score
|
|
583
648
|
|
|
584
649
|
# rerank on shortest distance
|
|
585
650
|
filtered_uuids.sort(key=lambda cur_uuid: scores[cur_uuid])
|
|
@@ -596,25 +661,20 @@ async def episode_mentions_reranker(driver: AsyncDriver, node_uuids: list[list[s
|
|
|
596
661
|
scores: dict[str, float] = {}
|
|
597
662
|
|
|
598
663
|
# Find the shortest path to center node
|
|
599
|
-
query = Query("""
|
|
600
|
-
|
|
601
|
-
|
|
664
|
+
query = Query("""
|
|
665
|
+
UNWIND $node_uuids AS node_uuid
|
|
666
|
+
MATCH (episode:Episodic)-[r:MENTIONS]->(n:Entity {uuid: node_uuid})
|
|
667
|
+
RETURN count(*) AS score, n.uuid AS uuid
|
|
602
668
|
""")
|
|
603
669
|
|
|
604
|
-
|
|
605
|
-
|
|
606
|
-
|
|
607
|
-
|
|
608
|
-
node_uuid=uuid,
|
|
609
|
-
database_=DEFAULT_DATABASE,
|
|
610
|
-
)
|
|
611
|
-
for uuid in sorted_uuids
|
|
612
|
-
]
|
|
670
|
+
results, _, _ = await driver.execute_query(
|
|
671
|
+
query,
|
|
672
|
+
node_uuids=sorted_uuids,
|
|
673
|
+
database_=DEFAULT_DATABASE,
|
|
613
674
|
)
|
|
614
675
|
|
|
615
|
-
for
|
|
616
|
-
|
|
617
|
-
scores[uuid] = record['score']
|
|
676
|
+
for result in results:
|
|
677
|
+
scores[result['uuid']] = result['score']
|
|
618
678
|
|
|
619
679
|
# rerank on shortest distance
|
|
620
680
|
sorted_uuids.sort(key=lambda cur_uuid: scores[cur_uuid])
|
|
@@ -635,4 +695,4 @@ def maximal_marginal_relevance(
|
|
|
635
695
|
|
|
636
696
|
candidates_with_mmr.sort(reverse=True, key=lambda c: c[1])
|
|
637
697
|
|
|
638
|
-
return [candidate[0] for candidate in candidates_with_mmr]
|
|
698
|
+
return list(set([candidate[0] for candidate in candidates_with_mmr]))
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: graphiti-core
|
|
3
|
-
Version: 0.3.
|
|
3
|
+
Version: 0.3.18
|
|
4
4
|
Summary: A temporal graph building library
|
|
5
5
|
License: Apache-2.0
|
|
6
6
|
Author: Paul Paliychuk
|
|
@@ -14,10 +14,9 @@ Classifier: Programming Language :: Python :: 3.12
|
|
|
14
14
|
Requires-Dist: diskcache (>=5.6.3,<6.0.0)
|
|
15
15
|
Requires-Dist: neo4j (>=5.23.0,<6.0.0)
|
|
16
16
|
Requires-Dist: numpy (>=1.0.0)
|
|
17
|
-
Requires-Dist: openai (>=1.
|
|
17
|
+
Requires-Dist: openai (>=1.52.2,<2.0.0)
|
|
18
18
|
Requires-Dist: pydantic (>=2.8.2,<3.0.0)
|
|
19
19
|
Requires-Dist: tenacity (<9.0.0)
|
|
20
|
-
Requires-Dist: voyageai (>=0.2.3,<0.3.0)
|
|
21
20
|
Description-Content-Type: text/markdown
|
|
22
21
|
|
|
23
22
|
<div align="center">
|
|
@@ -1,11 +1,15 @@
|
|
|
1
1
|
graphiti_core/__init__.py,sha256=e5SWFkRiaUwfprYIeIgVIh7JDedNiloZvd3roU-0aDY,55
|
|
2
|
+
graphiti_core/cross_encoder/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
3
|
+
graphiti_core/cross_encoder/bge_reranker_client.py,sha256=jsXBUHfFpGsNASHaRnfz1_miQ3x070DdU8QS4J3DciI,1466
|
|
4
|
+
graphiti_core/cross_encoder/client.py,sha256=PyFYYsALQAD9wu0gL5uquPsulmaBZ0AZkJmLq2DFA-c,1472
|
|
5
|
+
graphiti_core/cross_encoder/openai_reranker_client.py,sha256=ij1E1Y5G9GNP3h3h8nSUF-ZJrQ921B54uudZUsCUaDc,4063
|
|
2
6
|
graphiti_core/edges.py,sha256=KgH1f-nwexEX3PCRaQHPqbD033EeiKo_s39mqZn43zk,13082
|
|
3
7
|
graphiti_core/embedder/__init__.py,sha256=eWd-0sPxflnYXLoWNT9sxwCIFun5JNO9Fk4E-ZXXf8Y,164
|
|
4
8
|
graphiti_core/embedder/client.py,sha256=Sd9CyYXaqRazdOH8opKackrTx-y9y-T54M78XTVMzxs,1006
|
|
5
|
-
graphiti_core/embedder/openai.py,sha256=
|
|
6
|
-
graphiti_core/embedder/voyage.py,sha256=
|
|
9
|
+
graphiti_core/embedder/openai.py,sha256=yYUYPymx_lBlxDTGrlc03yNhPFyGG-etM2sszRK2G2U,1618
|
|
10
|
+
graphiti_core/embedder/voyage.py,sha256=_eGFI5_NjNG8z7qG3jTWCdE7sAs1Yb8fiSZSJlQLD9o,1879
|
|
7
11
|
graphiti_core/errors.py,sha256=ddHrHGQxhwkVAtSph4AV84UoOlgwZufMczXPwB7uqPo,1795
|
|
8
|
-
graphiti_core/graphiti.py,sha256=
|
|
12
|
+
graphiti_core/graphiti.py,sha256=c9Rh777TrHYffPF6qvFAfm-m-PA4kD8a3ZW_ShsZGxE,27714
|
|
9
13
|
graphiti_core/helpers.py,sha256=kqC2TD8Auwty4sG7KH4BuRMX413oTChGaAT_XUt9ZjU,2108
|
|
10
14
|
graphiti_core/llm_client/__init__.py,sha256=PA80TSMeX-sUXITXEAxMDEt3gtfZgcJrGJUcyds1mSo,207
|
|
11
15
|
graphiti_core/llm_client/anthropic_client.py,sha256=4l2PbCjIoeRr7UJ2DUh2grYLTtE2vNaWlo72IIRQDeI,2405
|
|
@@ -34,10 +38,10 @@ graphiti_core/prompts/models.py,sha256=cvx_Bv5RMFUD_5IUawYrbpOKLPHogai7_bm7YXrSz
|
|
|
34
38
|
graphiti_core/prompts/summarize_nodes.py,sha256=FLuZpGTABgcxuIDkx_IKH115nHEw0rIaFhcGlWveAMc,2357
|
|
35
39
|
graphiti_core/py.typed,sha256=vlmmzQOt7bmeQl9L3XJP4W6Ry0iiELepnOrinKz5KQg,79
|
|
36
40
|
graphiti_core/search/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
37
|
-
graphiti_core/search/search.py,sha256=
|
|
38
|
-
graphiti_core/search/search_config.py,sha256=
|
|
39
|
-
graphiti_core/search/search_config_recipes.py,sha256=
|
|
40
|
-
graphiti_core/search/search_utils.py,sha256=
|
|
41
|
+
graphiti_core/search/search.py,sha256=F2Plut6YKb5CcBsa-UsbojXbDpL_iKMIuQh6zfuxGKY,11171
|
|
42
|
+
graphiti_core/search/search_config.py,sha256=UZN8jFA4pBlw2O5N1cuhVRBdTwMLR9N3Oyo6sQ4MDVw,3117
|
|
43
|
+
graphiti_core/search/search_config_recipes.py,sha256=20jS7veJExDnXA-ovJSUJfyDHKt7GW-nng-eoiT7ATA,5810
|
|
44
|
+
graphiti_core/search/search_utils.py,sha256=l8BR4GOo-A2eIXx4ybC18n6t6CeerN_9KQbYzCB6ix0,22551
|
|
41
45
|
graphiti_core/utils/__init__.py,sha256=cJAcMnBZdHBQmWrZdU1PQ1YmaL75bhVUkyVpIPuOyns,260
|
|
42
46
|
graphiti_core/utils/bulk_utils.py,sha256=JtoYTZPCigPa3n2E43Oe7QhFZRTA_QKNGy1jVgklHag,12614
|
|
43
47
|
graphiti_core/utils/maintenance/__init__.py,sha256=TRY3wWWu5kn3Oahk_KKhltrWnh0NACw0FskjqF6OtlA,314
|
|
@@ -47,7 +51,7 @@ graphiti_core/utils/maintenance/graph_data_operations.py,sha256=w66_SLlvPapuG91Y
|
|
|
47
51
|
graphiti_core/utils/maintenance/node_operations.py,sha256=h5nlRojbXOGJs-alpv6z6WnZ1UCixVGlAQYBQUqz8Bs,9030
|
|
48
52
|
graphiti_core/utils/maintenance/temporal_operations.py,sha256=MvaRLWrBlDeYw8CQrKish1xbYcY5ovpfdqA2hSX7v5k,3367
|
|
49
53
|
graphiti_core/utils/maintenance/utils.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
50
|
-
graphiti_core-0.3.
|
|
51
|
-
graphiti_core-0.3.
|
|
52
|
-
graphiti_core-0.3.
|
|
53
|
-
graphiti_core-0.3.
|
|
54
|
+
graphiti_core-0.3.18.dist-info/LICENSE,sha256=KCUwCyDXuVEgmDWkozHyniRyWjnWUWjkuDHfU6o3JlA,11325
|
|
55
|
+
graphiti_core-0.3.18.dist-info/METADATA,sha256=D45OPLftoNd7wWJLtrewFJ1YkgcMLDADopI7P4jWwDg,9396
|
|
56
|
+
graphiti_core-0.3.18.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
|
|
57
|
+
graphiti_core-0.3.18.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|