graphiti-core 0.3.16__tar.gz → 0.3.17__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of graphiti-core might be problematic. Click here for more details.
- {graphiti_core-0.3.16 → graphiti_core-0.3.17}/PKG-INFO +1 -2
- graphiti_core-0.3.17/graphiti_core/cross_encoder/bge_reranker_client.py +45 -0
- graphiti_core-0.3.17/graphiti_core/cross_encoder/client.py +41 -0
- graphiti_core-0.3.17/graphiti_core/cross_encoder/openai_reranker_client.py +113 -0
- {graphiti_core-0.3.16 → graphiti_core-0.3.17}/graphiti_core/graphiti.py +28 -3
- {graphiti_core-0.3.16 → graphiti_core-0.3.17}/graphiti_core/search/search.py +43 -15
- {graphiti_core-0.3.16 → graphiti_core-0.3.17}/graphiti_core/search/search_config.py +13 -1
- {graphiti_core-0.3.16 → graphiti_core-0.3.17}/graphiti_core/search/search_config_recipes.py +27 -1
- {graphiti_core-0.3.16 → graphiti_core-0.3.17}/graphiti_core/search/search_utils.py +247 -192
- graphiti_core-0.3.17/graphiti_core/utils/maintenance/utils.py +0 -0
- {graphiti_core-0.3.16 → graphiti_core-0.3.17}/pyproject.toml +4 -2
- {graphiti_core-0.3.16 → graphiti_core-0.3.17}/LICENSE +0 -0
- {graphiti_core-0.3.16 → graphiti_core-0.3.17}/README.md +0 -0
- {graphiti_core-0.3.16 → graphiti_core-0.3.17}/graphiti_core/__init__.py +0 -0
- {graphiti_core-0.3.16/graphiti_core/models → graphiti_core-0.3.17/graphiti_core/cross_encoder}/__init__.py +0 -0
- {graphiti_core-0.3.16 → graphiti_core-0.3.17}/graphiti_core/edges.py +0 -0
- {graphiti_core-0.3.16 → graphiti_core-0.3.17}/graphiti_core/embedder/__init__.py +0 -0
- {graphiti_core-0.3.16 → graphiti_core-0.3.17}/graphiti_core/embedder/client.py +0 -0
- {graphiti_core-0.3.16 → graphiti_core-0.3.17}/graphiti_core/embedder/openai.py +0 -0
- {graphiti_core-0.3.16 → graphiti_core-0.3.17}/graphiti_core/embedder/voyage.py +0 -0
- {graphiti_core-0.3.16 → graphiti_core-0.3.17}/graphiti_core/errors.py +0 -0
- {graphiti_core-0.3.16 → graphiti_core-0.3.17}/graphiti_core/helpers.py +0 -0
- {graphiti_core-0.3.16 → graphiti_core-0.3.17}/graphiti_core/llm_client/__init__.py +0 -0
- {graphiti_core-0.3.16 → graphiti_core-0.3.17}/graphiti_core/llm_client/anthropic_client.py +0 -0
- {graphiti_core-0.3.16 → graphiti_core-0.3.17}/graphiti_core/llm_client/client.py +0 -0
- {graphiti_core-0.3.16 → graphiti_core-0.3.17}/graphiti_core/llm_client/config.py +0 -0
- {graphiti_core-0.3.16 → graphiti_core-0.3.17}/graphiti_core/llm_client/errors.py +0 -0
- {graphiti_core-0.3.16 → graphiti_core-0.3.17}/graphiti_core/llm_client/groq_client.py +0 -0
- {graphiti_core-0.3.16 → graphiti_core-0.3.17}/graphiti_core/llm_client/openai_client.py +0 -0
- {graphiti_core-0.3.16 → graphiti_core-0.3.17}/graphiti_core/llm_client/utils.py +0 -0
- {graphiti_core-0.3.16/graphiti_core/models/edges → graphiti_core-0.3.17/graphiti_core/models}/__init__.py +0 -0
- {graphiti_core-0.3.16/graphiti_core/models/nodes → graphiti_core-0.3.17/graphiti_core/models/edges}/__init__.py +0 -0
- {graphiti_core-0.3.16 → graphiti_core-0.3.17}/graphiti_core/models/edges/edge_db_queries.py +0 -0
- {graphiti_core-0.3.16/graphiti_core/search → graphiti_core-0.3.17/graphiti_core/models/nodes}/__init__.py +0 -0
- {graphiti_core-0.3.16 → graphiti_core-0.3.17}/graphiti_core/models/nodes/node_db_queries.py +0 -0
- {graphiti_core-0.3.16 → graphiti_core-0.3.17}/graphiti_core/nodes.py +0 -0
- {graphiti_core-0.3.16 → graphiti_core-0.3.17}/graphiti_core/prompts/__init__.py +0 -0
- {graphiti_core-0.3.16 → graphiti_core-0.3.17}/graphiti_core/prompts/dedupe_edges.py +0 -0
- {graphiti_core-0.3.16 → graphiti_core-0.3.17}/graphiti_core/prompts/dedupe_nodes.py +0 -0
- {graphiti_core-0.3.16 → graphiti_core-0.3.17}/graphiti_core/prompts/eval.py +0 -0
- {graphiti_core-0.3.16 → graphiti_core-0.3.17}/graphiti_core/prompts/extract_edge_dates.py +0 -0
- {graphiti_core-0.3.16 → graphiti_core-0.3.17}/graphiti_core/prompts/extract_edges.py +0 -0
- {graphiti_core-0.3.16 → graphiti_core-0.3.17}/graphiti_core/prompts/extract_nodes.py +0 -0
- {graphiti_core-0.3.16 → graphiti_core-0.3.17}/graphiti_core/prompts/invalidate_edges.py +0 -0
- {graphiti_core-0.3.16 → graphiti_core-0.3.17}/graphiti_core/prompts/lib.py +0 -0
- {graphiti_core-0.3.16 → graphiti_core-0.3.17}/graphiti_core/prompts/models.py +0 -0
- {graphiti_core-0.3.16 → graphiti_core-0.3.17}/graphiti_core/prompts/summarize_nodes.py +0 -0
- {graphiti_core-0.3.16 → graphiti_core-0.3.17}/graphiti_core/py.typed +0 -0
- /graphiti_core-0.3.16/graphiti_core/utils/maintenance/utils.py → /graphiti_core-0.3.17/graphiti_core/search/__init__.py +0 -0
- {graphiti_core-0.3.16 → graphiti_core-0.3.17}/graphiti_core/utils/__init__.py +0 -0
- {graphiti_core-0.3.16 → graphiti_core-0.3.17}/graphiti_core/utils/bulk_utils.py +0 -0
- {graphiti_core-0.3.16 → graphiti_core-0.3.17}/graphiti_core/utils/maintenance/__init__.py +0 -0
- {graphiti_core-0.3.16 → graphiti_core-0.3.17}/graphiti_core/utils/maintenance/community_operations.py +0 -0
- {graphiti_core-0.3.16 → graphiti_core-0.3.17}/graphiti_core/utils/maintenance/edge_operations.py +0 -0
- {graphiti_core-0.3.16 → graphiti_core-0.3.17}/graphiti_core/utils/maintenance/graph_data_operations.py +0 -0
- {graphiti_core-0.3.16 → graphiti_core-0.3.17}/graphiti_core/utils/maintenance/node_operations.py +0 -0
- {graphiti_core-0.3.16 → graphiti_core-0.3.17}/graphiti_core/utils/maintenance/temporal_operations.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: graphiti-core
|
|
3
|
-
Version: 0.3.16
|
|
3
|
+
Version: 0.3.17
|
|
4
4
|
Summary: A temporal graph building library
|
|
5
5
|
License: Apache-2.0
|
|
6
6
|
Author: Paul Paliychuk
|
|
@@ -17,7 +17,6 @@ Requires-Dist: numpy (>=1.0.0)
|
|
|
17
17
|
Requires-Dist: openai (>=1.50.2,<2.0.0)
|
|
18
18
|
Requires-Dist: pydantic (>=2.8.2,<3.0.0)
|
|
19
19
|
Requires-Dist: tenacity (<9.0.0)
|
|
20
|
-
Requires-Dist: voyageai (>=0.2.3,<0.3.0)
|
|
21
20
|
Description-Content-Type: text/markdown
|
|
22
21
|
|
|
23
22
|
<div align="center">
|
|
@@ -0,0 +1,45 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Copyright 2024, Zep Software, Inc.
|
|
3
|
+
|
|
4
|
+
Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
you may not use this file except in compliance with the License.
|
|
6
|
+
You may obtain a copy of the License at
|
|
7
|
+
|
|
8
|
+
http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
|
|
10
|
+
Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
See the License for the specific language governing permissions and
|
|
14
|
+
limitations under the License.
|
|
15
|
+
"""
|
|
16
|
+
|
|
17
|
+
import asyncio
|
|
18
|
+
from typing import List, Tuple
|
|
19
|
+
|
|
20
|
+
from sentence_transformers import CrossEncoder
|
|
21
|
+
|
|
22
|
+
from graphiti_core.cross_encoder.client import CrossEncoderClient
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class BGERerankerClient(CrossEncoderClient):
|
|
26
|
+
def __init__(self):
|
|
27
|
+
self.model = CrossEncoder('BAAI/bge-reranker-v2-m3')
|
|
28
|
+
|
|
29
|
+
async def rank(self, query: str, passages: List[str]) -> List[Tuple[str, float]]:
|
|
30
|
+
if not passages:
|
|
31
|
+
return []
|
|
32
|
+
|
|
33
|
+
input_pairs = [[query, passage] for passage in passages]
|
|
34
|
+
|
|
35
|
+
# Run the synchronous predict method in an executor
|
|
36
|
+
loop = asyncio.get_running_loop()
|
|
37
|
+
scores = await loop.run_in_executor(None, self.model.predict, input_pairs)
|
|
38
|
+
|
|
39
|
+
ranked_passages = sorted(
|
|
40
|
+
[(passage, float(score)) for passage, score in zip(passages, scores)],
|
|
41
|
+
key=lambda x: x[1],
|
|
42
|
+
reverse=True,
|
|
43
|
+
)
|
|
44
|
+
|
|
45
|
+
return ranked_passages
|
|
@@ -0,0 +1,41 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Copyright 2024, Zep Software, Inc.
|
|
3
|
+
|
|
4
|
+
Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
you may not use this file except in compliance with the License.
|
|
6
|
+
You may obtain a copy of the License at
|
|
7
|
+
|
|
8
|
+
http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
|
|
10
|
+
Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
See the License for the specific language governing permissions and
|
|
14
|
+
limitations under the License.
|
|
15
|
+
"""
|
|
16
|
+
|
|
17
|
+
from abc import ABC, abstractmethod
|
|
18
|
+
from typing import List, Tuple
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class CrossEncoderClient(ABC):
|
|
22
|
+
"""
|
|
23
|
+
CrossEncoderClient is an abstract base class that defines the interface
|
|
24
|
+
for cross-encoder models used for ranking passages based on their relevance to a query.
|
|
25
|
+
It allows for different implementations of cross-encoder models to be used interchangeably.
|
|
26
|
+
"""
|
|
27
|
+
|
|
28
|
+
@abstractmethod
|
|
29
|
+
async def rank(self, query: str, passages: List[str]) -> List[Tuple[str, float]]:
|
|
30
|
+
"""
|
|
31
|
+
Rank the given passages based on their relevance to the query.
|
|
32
|
+
|
|
33
|
+
Args:
|
|
34
|
+
query (str): The query string.
|
|
35
|
+
passages (List[str]): A list of passages to rank.
|
|
36
|
+
|
|
37
|
+
Returns:
|
|
38
|
+
List[Tuple[str, float]]: A list of tuples containing the passage and its score,
|
|
39
|
+
sorted in descending order of relevance.
|
|
40
|
+
"""
|
|
41
|
+
pass
|
|
@@ -0,0 +1,113 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Copyright 2024, Zep Software, Inc.
|
|
3
|
+
|
|
4
|
+
Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
you may not use this file except in compliance with the License.
|
|
6
|
+
You may obtain a copy of the License at
|
|
7
|
+
|
|
8
|
+
http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
|
|
10
|
+
Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
See the License for the specific language governing permissions and
|
|
14
|
+
limitations under the License.
|
|
15
|
+
"""
|
|
16
|
+
|
|
17
|
+
import asyncio
|
|
18
|
+
import logging
|
|
19
|
+
from typing import Any
|
|
20
|
+
|
|
21
|
+
import openai
|
|
22
|
+
from openai import AsyncOpenAI
|
|
23
|
+
from pydantic import BaseModel
|
|
24
|
+
|
|
25
|
+
from ..llm_client import LLMConfig, RateLimitError
|
|
26
|
+
from ..prompts import Message
|
|
27
|
+
from .client import CrossEncoderClient
|
|
28
|
+
|
|
29
|
+
logger = logging.getLogger(__name__)
|
|
30
|
+
|
|
31
|
+
DEFAULT_MODEL = 'gpt-4o-mini'
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class BooleanClassifier(BaseModel):
|
|
35
|
+
isTrue: bool
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
class OpenAIRerankerClient(CrossEncoderClient):
|
|
39
|
+
def __init__(self, config: LLMConfig | None = None):
|
|
40
|
+
"""
|
|
41
|
+
Initialize the OpenAIClient with the provided configuration, cache setting, and client.
|
|
42
|
+
|
|
43
|
+
Args:
|
|
44
|
+
config (LLMConfig | None): The configuration for the LLM client, including API key, model, base URL, temperature, and max tokens.
|
|
45
|
+
cache (bool): Whether to use caching for responses. Defaults to False.
|
|
46
|
+
client (Any | None): An optional async client instance to use. If not provided, a new AsyncOpenAI client is created.
|
|
47
|
+
|
|
48
|
+
"""
|
|
49
|
+
if config is None:
|
|
50
|
+
config = LLMConfig()
|
|
51
|
+
|
|
52
|
+
self.config = config
|
|
53
|
+
self.client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
|
|
54
|
+
|
|
55
|
+
async def rank(self, query: str, passages: list[str]) -> list[tuple[str, float]]:
|
|
56
|
+
openai_messages_list: Any = [
|
|
57
|
+
[
|
|
58
|
+
Message(
|
|
59
|
+
role='system',
|
|
60
|
+
content='You are an expert tasked with determining whether the passage is relevant to the query',
|
|
61
|
+
),
|
|
62
|
+
Message(
|
|
63
|
+
role='user',
|
|
64
|
+
content=f"""
|
|
65
|
+
Respond with "True" if PASSAGE is relevant to QUERY and "False" otherwise.
|
|
66
|
+
<PASSAGE>
|
|
67
|
+
{query}
|
|
68
|
+
</PASSAGE>
|
|
69
|
+
{passage}
|
|
70
|
+
<QUERY>
|
|
71
|
+
</QUERY>
|
|
72
|
+
""",
|
|
73
|
+
),
|
|
74
|
+
]
|
|
75
|
+
for passage in passages
|
|
76
|
+
]
|
|
77
|
+
try:
|
|
78
|
+
responses = await asyncio.gather(
|
|
79
|
+
*[
|
|
80
|
+
self.client.chat.completions.create(
|
|
81
|
+
model=DEFAULT_MODEL,
|
|
82
|
+
messages=openai_messages,
|
|
83
|
+
temperature=0,
|
|
84
|
+
max_tokens=1,
|
|
85
|
+
logit_bias={'6432': 1, '7983': 1},
|
|
86
|
+
logprobs=True,
|
|
87
|
+
top_logprobs=2,
|
|
88
|
+
)
|
|
89
|
+
for openai_messages in openai_messages_list
|
|
90
|
+
]
|
|
91
|
+
)
|
|
92
|
+
|
|
93
|
+
responses_top_logprobs = [
|
|
94
|
+
response.choices[0].logprobs.content[0].top_logprobs
|
|
95
|
+
if response.choices[0].logprobs is not None
|
|
96
|
+
and response.choices[0].logprobs.content is not None
|
|
97
|
+
else []
|
|
98
|
+
for response in responses
|
|
99
|
+
]
|
|
100
|
+
scores: list[float] = []
|
|
101
|
+
for top_logprobs in responses_top_logprobs:
|
|
102
|
+
for logprob in top_logprobs:
|
|
103
|
+
if bool(logprob.token):
|
|
104
|
+
scores.append(logprob.logprob)
|
|
105
|
+
|
|
106
|
+
results = [(passage, score) for passage, score in zip(passages, scores)]
|
|
107
|
+
results.sort(reverse=True, key=lambda x: x[1])
|
|
108
|
+
return results
|
|
109
|
+
except openai.RateLimitError as e:
|
|
110
|
+
raise RateLimitError from e
|
|
111
|
+
except Exception as e:
|
|
112
|
+
logger.error(f'Error in generating LLM response: {e}')
|
|
113
|
+
raise
|
|
@@ -23,8 +23,11 @@ from dotenv import load_dotenv
|
|
|
23
23
|
from neo4j import AsyncGraphDatabase
|
|
24
24
|
from pydantic import BaseModel
|
|
25
25
|
|
|
26
|
+
from graphiti_core.cross_encoder.client import CrossEncoderClient
|
|
27
|
+
from graphiti_core.cross_encoder.openai_reranker_client import OpenAIRerankerClient
|
|
26
28
|
from graphiti_core.edges import EntityEdge, EpisodicEdge
|
|
27
29
|
from graphiti_core.embedder import EmbedderClient, OpenAIEmbedder
|
|
30
|
+
from graphiti_core.helpers import DEFAULT_DATABASE
|
|
28
31
|
from graphiti_core.llm_client import LLMClient, OpenAIClient
|
|
29
32
|
from graphiti_core.nodes import CommunityNode, EntityNode, EpisodeType, EpisodicNode
|
|
30
33
|
from graphiti_core.search.search import SearchConfig, search
|
|
@@ -92,6 +95,7 @@ class Graphiti:
|
|
|
92
95
|
password: str,
|
|
93
96
|
llm_client: LLMClient | None = None,
|
|
94
97
|
embedder: EmbedderClient | None = None,
|
|
98
|
+
cross_encoder: CrossEncoderClient | None = None,
|
|
95
99
|
store_raw_episode_content: bool = True,
|
|
96
100
|
):
|
|
97
101
|
"""
|
|
@@ -131,7 +135,7 @@ class Graphiti:
|
|
|
131
135
|
Graphiti if you're using the default OpenAIClient.
|
|
132
136
|
"""
|
|
133
137
|
self.driver = AsyncGraphDatabase.driver(uri, auth=(user, password))
|
|
134
|
-
self.database =
|
|
138
|
+
self.database = DEFAULT_DATABASE
|
|
135
139
|
self.store_raw_episode_content = store_raw_episode_content
|
|
136
140
|
if llm_client:
|
|
137
141
|
self.llm_client = llm_client
|
|
@@ -141,6 +145,10 @@ class Graphiti:
|
|
|
141
145
|
self.embedder = embedder
|
|
142
146
|
else:
|
|
143
147
|
self.embedder = OpenAIEmbedder()
|
|
148
|
+
if cross_encoder:
|
|
149
|
+
self.cross_encoder = cross_encoder
|
|
150
|
+
else:
|
|
151
|
+
self.cross_encoder = OpenAIRerankerClient()
|
|
144
152
|
|
|
145
153
|
async def close(self):
|
|
146
154
|
"""
|
|
@@ -648,6 +656,7 @@ class Graphiti:
|
|
|
648
656
|
await search(
|
|
649
657
|
self.driver,
|
|
650
658
|
self.embedder,
|
|
659
|
+
self.cross_encoder,
|
|
651
660
|
query,
|
|
652
661
|
group_ids,
|
|
653
662
|
search_config,
|
|
@@ -663,8 +672,18 @@ class Graphiti:
|
|
|
663
672
|
config: SearchConfig,
|
|
664
673
|
group_ids: list[str] | None = None,
|
|
665
674
|
center_node_uuid: str | None = None,
|
|
675
|
+
bfs_origin_node_uuids: list[str] | None = None,
|
|
666
676
|
) -> SearchResults:
|
|
667
|
-
return await search(
|
|
677
|
+
return await search(
|
|
678
|
+
self.driver,
|
|
679
|
+
self.embedder,
|
|
680
|
+
self.cross_encoder,
|
|
681
|
+
query,
|
|
682
|
+
group_ids,
|
|
683
|
+
config,
|
|
684
|
+
center_node_uuid,
|
|
685
|
+
bfs_origin_node_uuids,
|
|
686
|
+
)
|
|
668
687
|
|
|
669
688
|
async def get_nodes_by_query(
|
|
670
689
|
self,
|
|
@@ -716,7 +735,13 @@ class Graphiti:
|
|
|
716
735
|
|
|
717
736
|
nodes = (
|
|
718
737
|
await search(
|
|
719
|
-
self.driver,
|
|
738
|
+
self.driver,
|
|
739
|
+
self.embedder,
|
|
740
|
+
self.cross_encoder,
|
|
741
|
+
query,
|
|
742
|
+
group_ids,
|
|
743
|
+
search_config,
|
|
744
|
+
center_node_uuid,
|
|
720
745
|
)
|
|
721
746
|
).nodes
|
|
722
747
|
return nodes
|
|
@@ -21,6 +21,7 @@ from time import time
|
|
|
21
21
|
|
|
22
22
|
from neo4j import AsyncDriver
|
|
23
23
|
|
|
24
|
+
from graphiti_core.cross_encoder.client import CrossEncoderClient
|
|
24
25
|
from graphiti_core.edges import EntityEdge
|
|
25
26
|
from graphiti_core.embedder import EmbedderClient
|
|
26
27
|
from graphiti_core.errors import SearchRerankerError
|
|
@@ -39,6 +40,7 @@ from graphiti_core.search.search_config import (
|
|
|
39
40
|
from graphiti_core.search.search_utils import (
|
|
40
41
|
community_fulltext_search,
|
|
41
42
|
community_similarity_search,
|
|
43
|
+
edge_bfs_search,
|
|
42
44
|
edge_fulltext_search,
|
|
43
45
|
edge_similarity_search,
|
|
44
46
|
episode_mentions_reranker,
|
|
@@ -55,40 +57,49 @@ logger = logging.getLogger(__name__)
|
|
|
55
57
|
async def search(
|
|
56
58
|
driver: AsyncDriver,
|
|
57
59
|
embedder: EmbedderClient,
|
|
60
|
+
cross_encoder: CrossEncoderClient,
|
|
58
61
|
query: str,
|
|
59
62
|
group_ids: list[str] | None,
|
|
60
63
|
config: SearchConfig,
|
|
61
64
|
center_node_uuid: str | None = None,
|
|
65
|
+
bfs_origin_node_uuids: list[str] | None = None,
|
|
62
66
|
) -> SearchResults:
|
|
63
67
|
start = time()
|
|
64
|
-
|
|
68
|
+
query_vector = await embedder.create(input=[query.replace('\n', ' ')])
|
|
69
|
+
|
|
65
70
|
# if group_ids is empty, set it to None
|
|
66
71
|
group_ids = group_ids if group_ids else None
|
|
67
72
|
edges, nodes, communities = await asyncio.gather(
|
|
68
73
|
edge_search(
|
|
69
74
|
driver,
|
|
70
|
-
|
|
75
|
+
cross_encoder,
|
|
71
76
|
query,
|
|
77
|
+
query_vector,
|
|
72
78
|
group_ids,
|
|
73
79
|
config.edge_config,
|
|
74
80
|
center_node_uuid,
|
|
81
|
+
bfs_origin_node_uuids,
|
|
75
82
|
config.limit,
|
|
76
83
|
),
|
|
77
84
|
node_search(
|
|
78
85
|
driver,
|
|
79
|
-
|
|
86
|
+
cross_encoder,
|
|
80
87
|
query,
|
|
88
|
+
query_vector,
|
|
81
89
|
group_ids,
|
|
82
90
|
config.node_config,
|
|
83
91
|
center_node_uuid,
|
|
92
|
+
bfs_origin_node_uuids,
|
|
84
93
|
config.limit,
|
|
85
94
|
),
|
|
86
95
|
community_search(
|
|
87
96
|
driver,
|
|
88
|
-
|
|
97
|
+
cross_encoder,
|
|
89
98
|
query,
|
|
99
|
+
query_vector,
|
|
90
100
|
group_ids,
|
|
91
101
|
config.community_config,
|
|
102
|
+
bfs_origin_node_uuids,
|
|
92
103
|
config.limit,
|
|
93
104
|
),
|
|
94
105
|
)
|
|
@@ -99,27 +110,27 @@ async def search(
|
|
|
99
110
|
communities=communities,
|
|
100
111
|
)
|
|
101
112
|
|
|
102
|
-
|
|
113
|
+
latency = (time() - start) * 1000
|
|
103
114
|
|
|
104
|
-
logger.
|
|
115
|
+
logger.debug(f'search returned context for query {query} in {latency} ms')
|
|
105
116
|
|
|
106
117
|
return results
|
|
107
118
|
|
|
108
119
|
|
|
109
120
|
async def edge_search(
|
|
110
121
|
driver: AsyncDriver,
|
|
111
|
-
|
|
122
|
+
cross_encoder: CrossEncoderClient,
|
|
112
123
|
query: str,
|
|
124
|
+
query_vector: list[float],
|
|
113
125
|
group_ids: list[str] | None,
|
|
114
126
|
config: EdgeSearchConfig | None,
|
|
115
127
|
center_node_uuid: str | None = None,
|
|
128
|
+
bfs_origin_node_uuids: list[str] | None = None,
|
|
116
129
|
limit=DEFAULT_SEARCH_LIMIT,
|
|
117
130
|
) -> list[EntityEdge]:
|
|
118
131
|
if config is None:
|
|
119
132
|
return []
|
|
120
133
|
|
|
121
|
-
query_vector = await embedder.create(input=[query])
|
|
122
|
-
|
|
123
134
|
search_results: list[list[EntityEdge]] = list(
|
|
124
135
|
await asyncio.gather(
|
|
125
136
|
*[
|
|
@@ -127,6 +138,7 @@ async def edge_search(
|
|
|
127
138
|
edge_similarity_search(
|
|
128
139
|
driver, query_vector, None, None, group_ids, 2 * limit, config.sim_min_score
|
|
129
140
|
),
|
|
141
|
+
edge_bfs_search(driver, bfs_origin_node_uuids, config.bfs_max_depth),
|
|
130
142
|
]
|
|
131
143
|
)
|
|
132
144
|
)
|
|
@@ -147,6 +159,10 @@ async def edge_search(
|
|
|
147
159
|
reranked_uuids = maximal_marginal_relevance(
|
|
148
160
|
query_vector, search_result_uuids_and_vectors, config.mmr_lambda
|
|
149
161
|
)
|
|
162
|
+
elif config.reranker == EdgeReranker.cross_encoder:
|
|
163
|
+
fact_to_uuid_map = {edge.fact: edge.uuid for result in search_results for edge in result}
|
|
164
|
+
reranked_facts = await cross_encoder.rank(query, list(fact_to_uuid_map.keys()))
|
|
165
|
+
reranked_uuids = [fact_to_uuid_map[fact] for fact, _ in reranked_facts]
|
|
150
166
|
elif config.reranker == EdgeReranker.node_distance:
|
|
151
167
|
if center_node_uuid is None:
|
|
152
168
|
raise SearchRerankerError('No center node provided for Node Distance reranker')
|
|
@@ -177,18 +193,18 @@ async def edge_search(
|
|
|
177
193
|
|
|
178
194
|
async def node_search(
|
|
179
195
|
driver: AsyncDriver,
|
|
180
|
-
|
|
196
|
+
cross_encoder: CrossEncoderClient,
|
|
181
197
|
query: str,
|
|
198
|
+
query_vector: list[float],
|
|
182
199
|
group_ids: list[str] | None,
|
|
183
200
|
config: NodeSearchConfig | None,
|
|
184
201
|
center_node_uuid: str | None = None,
|
|
202
|
+
bfs_origin_node_uuids: list[str] | None = None,
|
|
185
203
|
limit=DEFAULT_SEARCH_LIMIT,
|
|
186
204
|
) -> list[EntityNode]:
|
|
187
205
|
if config is None:
|
|
188
206
|
return []
|
|
189
207
|
|
|
190
|
-
query_vector = await embedder.create(input=[query])
|
|
191
|
-
|
|
192
208
|
search_results: list[list[EntityNode]] = list(
|
|
193
209
|
await asyncio.gather(
|
|
194
210
|
*[
|
|
@@ -215,6 +231,12 @@ async def node_search(
|
|
|
215
231
|
reranked_uuids = maximal_marginal_relevance(
|
|
216
232
|
query_vector, search_result_uuids_and_vectors, config.mmr_lambda
|
|
217
233
|
)
|
|
234
|
+
elif config.reranker == NodeReranker.cross_encoder:
|
|
235
|
+
summary_to_uuid_map = {
|
|
236
|
+
node.summary: node.uuid for result in search_results for node in result
|
|
237
|
+
}
|
|
238
|
+
reranked_summaries = await cross_encoder.rank(query, list(summary_to_uuid_map.keys()))
|
|
239
|
+
reranked_uuids = [summary_to_uuid_map[fact] for fact, _ in reranked_summaries]
|
|
218
240
|
elif config.reranker == NodeReranker.episode_mentions:
|
|
219
241
|
reranked_uuids = await episode_mentions_reranker(driver, search_result_uuids)
|
|
220
242
|
elif config.reranker == NodeReranker.node_distance:
|
|
@@ -231,17 +253,17 @@ async def node_search(
|
|
|
231
253
|
|
|
232
254
|
async def community_search(
|
|
233
255
|
driver: AsyncDriver,
|
|
234
|
-
|
|
256
|
+
cross_encoder: CrossEncoderClient,
|
|
235
257
|
query: str,
|
|
258
|
+
query_vector: list[float],
|
|
236
259
|
group_ids: list[str] | None,
|
|
237
260
|
config: CommunitySearchConfig | None,
|
|
261
|
+
bfs_origin_node_uuids: list[str] | None = None,
|
|
238
262
|
limit=DEFAULT_SEARCH_LIMIT,
|
|
239
263
|
) -> list[CommunityNode]:
|
|
240
264
|
if config is None:
|
|
241
265
|
return []
|
|
242
266
|
|
|
243
|
-
query_vector = await embedder.create(input=[query])
|
|
244
|
-
|
|
245
267
|
search_results: list[list[CommunityNode]] = list(
|
|
246
268
|
await asyncio.gather(
|
|
247
269
|
*[
|
|
@@ -273,6 +295,12 @@ async def community_search(
|
|
|
273
295
|
reranked_uuids = maximal_marginal_relevance(
|
|
274
296
|
query_vector, search_result_uuids_and_vectors, config.mmr_lambda
|
|
275
297
|
)
|
|
298
|
+
elif config.reranker == CommunityReranker.cross_encoder:
|
|
299
|
+
summary_to_uuid_map = {
|
|
300
|
+
node.summary: node.uuid for result in search_results for node in result
|
|
301
|
+
}
|
|
302
|
+
reranked_summaries = await cross_encoder.rank(query, list(summary_to_uuid_map.keys()))
|
|
303
|
+
reranked_uuids = [summary_to_uuid_map[fact] for fact, _ in reranked_summaries]
|
|
276
304
|
|
|
277
305
|
reranked_communities = [community_uuid_map[uuid] for uuid in reranked_uuids]
|
|
278
306
|
|
|
@@ -20,7 +20,11 @@ from pydantic import BaseModel, Field
|
|
|
20
20
|
|
|
21
21
|
from graphiti_core.edges import EntityEdge
|
|
22
22
|
from graphiti_core.nodes import CommunityNode, EntityNode
|
|
23
|
-
from graphiti_core.search.search_utils import DEFAULT_MIN_SCORE, DEFAULT_MMR_LAMBDA
|
|
23
|
+
from graphiti_core.search.search_utils import (
|
|
24
|
+
DEFAULT_MIN_SCORE,
|
|
25
|
+
DEFAULT_MMR_LAMBDA,
|
|
26
|
+
MAX_SEARCH_DEPTH,
|
|
27
|
+
)
|
|
24
28
|
|
|
25
29
|
DEFAULT_SEARCH_LIMIT = 10
|
|
26
30
|
|
|
@@ -28,11 +32,13 @@ DEFAULT_SEARCH_LIMIT = 10
|
|
|
28
32
|
class EdgeSearchMethod(Enum):
|
|
29
33
|
cosine_similarity = 'cosine_similarity'
|
|
30
34
|
bm25 = 'bm25'
|
|
35
|
+
bfs = 'breadth_first_search'
|
|
31
36
|
|
|
32
37
|
|
|
33
38
|
class NodeSearchMethod(Enum):
|
|
34
39
|
cosine_similarity = 'cosine_similarity'
|
|
35
40
|
bm25 = 'bm25'
|
|
41
|
+
bfs = 'breadth_first_search'
|
|
36
42
|
|
|
37
43
|
|
|
38
44
|
class CommunitySearchMethod(Enum):
|
|
@@ -45,6 +51,7 @@ class EdgeReranker(Enum):
|
|
|
45
51
|
node_distance = 'node_distance'
|
|
46
52
|
episode_mentions = 'episode_mentions'
|
|
47
53
|
mmr = 'mmr'
|
|
54
|
+
cross_encoder = 'cross_encoder'
|
|
48
55
|
|
|
49
56
|
|
|
50
57
|
class NodeReranker(Enum):
|
|
@@ -52,11 +59,13 @@ class NodeReranker(Enum):
|
|
|
52
59
|
node_distance = 'node_distance'
|
|
53
60
|
episode_mentions = 'episode_mentions'
|
|
54
61
|
mmr = 'mmr'
|
|
62
|
+
cross_encoder = 'cross_encoder'
|
|
55
63
|
|
|
56
64
|
|
|
57
65
|
class CommunityReranker(Enum):
|
|
58
66
|
rrf = 'reciprocal_rank_fusion'
|
|
59
67
|
mmr = 'mmr'
|
|
68
|
+
cross_encoder = 'cross_encoder'
|
|
60
69
|
|
|
61
70
|
|
|
62
71
|
class EdgeSearchConfig(BaseModel):
|
|
@@ -64,6 +73,7 @@ class EdgeSearchConfig(BaseModel):
|
|
|
64
73
|
reranker: EdgeReranker = Field(default=EdgeReranker.rrf)
|
|
65
74
|
sim_min_score: float = Field(default=DEFAULT_MIN_SCORE)
|
|
66
75
|
mmr_lambda: float = Field(default=DEFAULT_MMR_LAMBDA)
|
|
76
|
+
bfs_max_depth: int = Field(default=MAX_SEARCH_DEPTH)
|
|
67
77
|
|
|
68
78
|
|
|
69
79
|
class NodeSearchConfig(BaseModel):
|
|
@@ -71,6 +81,7 @@ class NodeSearchConfig(BaseModel):
|
|
|
71
81
|
reranker: NodeReranker = Field(default=NodeReranker.rrf)
|
|
72
82
|
sim_min_score: float = Field(default=DEFAULT_MIN_SCORE)
|
|
73
83
|
mmr_lambda: float = Field(default=DEFAULT_MMR_LAMBDA)
|
|
84
|
+
bfs_max_depth: int = Field(default=MAX_SEARCH_DEPTH)
|
|
74
85
|
|
|
75
86
|
|
|
76
87
|
class CommunitySearchConfig(BaseModel):
|
|
@@ -78,6 +89,7 @@ class CommunitySearchConfig(BaseModel):
|
|
|
78
89
|
reranker: CommunityReranker = Field(default=CommunityReranker.rrf)
|
|
79
90
|
sim_min_score: float = Field(default=DEFAULT_MIN_SCORE)
|
|
80
91
|
mmr_lambda: float = Field(default=DEFAULT_MMR_LAMBDA)
|
|
92
|
+
bfs_max_depth: int = Field(default=MAX_SEARCH_DEPTH)
|
|
81
93
|
|
|
82
94
|
|
|
83
95
|
class SearchConfig(BaseModel):
|
|
@@ -48,14 +48,41 @@ COMBINED_HYBRID_SEARCH_MMR = SearchConfig(
|
|
|
48
48
|
edge_config=EdgeSearchConfig(
|
|
49
49
|
search_methods=[EdgeSearchMethod.bm25, EdgeSearchMethod.cosine_similarity],
|
|
50
50
|
reranker=EdgeReranker.mmr,
|
|
51
|
+
mmr_lambda=1,
|
|
51
52
|
),
|
|
52
53
|
node_config=NodeSearchConfig(
|
|
53
54
|
search_methods=[NodeSearchMethod.bm25, NodeSearchMethod.cosine_similarity],
|
|
54
55
|
reranker=NodeReranker.mmr,
|
|
56
|
+
mmr_lambda=1,
|
|
55
57
|
),
|
|
56
58
|
community_config=CommunitySearchConfig(
|
|
57
59
|
search_methods=[CommunitySearchMethod.bm25, CommunitySearchMethod.cosine_similarity],
|
|
58
60
|
reranker=CommunityReranker.mmr,
|
|
61
|
+
mmr_lambda=1,
|
|
62
|
+
),
|
|
63
|
+
)
|
|
64
|
+
|
|
65
|
+
# Performs a full-text search, similarity search, and bfs with cross_encoder reranking over edges, nodes, and communities
|
|
66
|
+
COMBINED_HYBRID_SEARCH_CROSS_ENCODER = SearchConfig(
|
|
67
|
+
edge_config=EdgeSearchConfig(
|
|
68
|
+
search_methods=[
|
|
69
|
+
EdgeSearchMethod.bm25,
|
|
70
|
+
EdgeSearchMethod.cosine_similarity,
|
|
71
|
+
EdgeSearchMethod.bfs,
|
|
72
|
+
],
|
|
73
|
+
reranker=EdgeReranker.cross_encoder,
|
|
74
|
+
),
|
|
75
|
+
node_config=NodeSearchConfig(
|
|
76
|
+
search_methods=[
|
|
77
|
+
NodeSearchMethod.bm25,
|
|
78
|
+
NodeSearchMethod.cosine_similarity,
|
|
79
|
+
NodeSearchMethod.bfs,
|
|
80
|
+
],
|
|
81
|
+
reranker=NodeReranker.cross_encoder,
|
|
82
|
+
),
|
|
83
|
+
community_config=CommunitySearchConfig(
|
|
84
|
+
search_methods=[CommunitySearchMethod.bm25, CommunitySearchMethod.cosine_similarity],
|
|
85
|
+
reranker=CommunityReranker.cross_encoder,
|
|
59
86
|
),
|
|
60
87
|
)
|
|
61
88
|
|
|
@@ -81,7 +108,6 @@ EDGE_HYBRID_SEARCH_NODE_DISTANCE = SearchConfig(
|
|
|
81
108
|
search_methods=[EdgeSearchMethod.bm25, EdgeSearchMethod.cosine_similarity],
|
|
82
109
|
reranker=EdgeReranker.node_distance,
|
|
83
110
|
),
|
|
84
|
-
limit=30,
|
|
85
111
|
)
|
|
86
112
|
|
|
87
113
|
# performs a hybrid search over edges with episode mention reranking
|