graphiti-core 0.3.0__tar.gz → 0.3.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of graphiti-core might be problematic. Click here for more details.
- {graphiti_core-0.3.0 → graphiti_core-0.3.1}/PKG-INFO +1 -1
- graphiti_core-0.3.1/graphiti_core/errors.py +43 -0
- {graphiti_core-0.3.0 → graphiti_core-0.3.1}/graphiti_core/graphiti.py +33 -25
- graphiti_core-0.3.1/graphiti_core/helpers.py +23 -0
- graphiti_core-0.3.1/graphiti_core/llm_client/errors.py +23 -0
- graphiti_core-0.3.1/graphiti_core/llm_client/utils.py +38 -0
- {graphiti_core-0.3.0 → graphiti_core-0.3.1}/graphiti_core/prompts/extract_edge_dates.py +16 -0
- graphiti_core-0.3.1/graphiti_core/search/search.py +242 -0
- graphiti_core-0.3.1/graphiti_core/search/search_config.py +81 -0
- graphiti_core-0.3.1/graphiti_core/search/search_config_recipes.py +84 -0
- {graphiti_core-0.3.0 → graphiti_core-0.3.1}/graphiti_core/search/search_utils.py +242 -140
- {graphiti_core-0.3.0 → graphiti_core-0.3.1}/pyproject.toml +1 -1
- graphiti_core-0.3.0/graphiti_core/errors.py +0 -18
- graphiti_core-0.3.0/graphiti_core/helpers.py +0 -7
- graphiti_core-0.3.0/graphiti_core/llm_client/errors.py +0 -6
- graphiti_core-0.3.0/graphiti_core/llm_client/utils.py +0 -22
- graphiti_core-0.3.0/graphiti_core/search/search.py +0 -145
- {graphiti_core-0.3.0 → graphiti_core-0.3.1}/LICENSE +0 -0
- {graphiti_core-0.3.0 → graphiti_core-0.3.1}/README.md +0 -0
- {graphiti_core-0.3.0 → graphiti_core-0.3.1}/graphiti_core/__init__.py +0 -0
- {graphiti_core-0.3.0 → graphiti_core-0.3.1}/graphiti_core/edges.py +0 -0
- {graphiti_core-0.3.0 → graphiti_core-0.3.1}/graphiti_core/llm_client/__init__.py +0 -0
- {graphiti_core-0.3.0 → graphiti_core-0.3.1}/graphiti_core/llm_client/anthropic_client.py +0 -0
- {graphiti_core-0.3.0 → graphiti_core-0.3.1}/graphiti_core/llm_client/client.py +0 -0
- {graphiti_core-0.3.0 → graphiti_core-0.3.1}/graphiti_core/llm_client/config.py +0 -0
- {graphiti_core-0.3.0 → graphiti_core-0.3.1}/graphiti_core/llm_client/groq_client.py +0 -0
- {graphiti_core-0.3.0 → graphiti_core-0.3.1}/graphiti_core/llm_client/openai_client.py +0 -0
- {graphiti_core-0.3.0 → graphiti_core-0.3.1}/graphiti_core/nodes.py +0 -0
- {graphiti_core-0.3.0 → graphiti_core-0.3.1}/graphiti_core/prompts/__init__.py +0 -0
- {graphiti_core-0.3.0 → graphiti_core-0.3.1}/graphiti_core/prompts/dedupe_edges.py +0 -0
- {graphiti_core-0.3.0 → graphiti_core-0.3.1}/graphiti_core/prompts/dedupe_nodes.py +0 -0
- {graphiti_core-0.3.0 → graphiti_core-0.3.1}/graphiti_core/prompts/extract_edges.py +0 -0
- {graphiti_core-0.3.0 → graphiti_core-0.3.1}/graphiti_core/prompts/extract_nodes.py +0 -0
- {graphiti_core-0.3.0 → graphiti_core-0.3.1}/graphiti_core/prompts/invalidate_edges.py +0 -0
- {graphiti_core-0.3.0 → graphiti_core-0.3.1}/graphiti_core/prompts/lib.py +0 -0
- {graphiti_core-0.3.0 → graphiti_core-0.3.1}/graphiti_core/prompts/models.py +0 -0
- {graphiti_core-0.3.0 → graphiti_core-0.3.1}/graphiti_core/prompts/summarize_nodes.py +0 -0
- {graphiti_core-0.3.0 → graphiti_core-0.3.1}/graphiti_core/py.typed +0 -0
- {graphiti_core-0.3.0 → graphiti_core-0.3.1}/graphiti_core/search/__init__.py +0 -0
- {graphiti_core-0.3.0 → graphiti_core-0.3.1}/graphiti_core/utils/__init__.py +0 -0
- {graphiti_core-0.3.0 → graphiti_core-0.3.1}/graphiti_core/utils/bulk_utils.py +0 -0
- {graphiti_core-0.3.0 → graphiti_core-0.3.1}/graphiti_core/utils/maintenance/__init__.py +0 -0
- {graphiti_core-0.3.0 → graphiti_core-0.3.1}/graphiti_core/utils/maintenance/community_operations.py +0 -0
- {graphiti_core-0.3.0 → graphiti_core-0.3.1}/graphiti_core/utils/maintenance/edge_operations.py +0 -0
- {graphiti_core-0.3.0 → graphiti_core-0.3.1}/graphiti_core/utils/maintenance/graph_data_operations.py +0 -0
- {graphiti_core-0.3.0 → graphiti_core-0.3.1}/graphiti_core/utils/maintenance/node_operations.py +0 -0
- {graphiti_core-0.3.0 → graphiti_core-0.3.1}/graphiti_core/utils/maintenance/temporal_operations.py +0 -0
- {graphiti_core-0.3.0 → graphiti_core-0.3.1}/graphiti_core/utils/maintenance/utils.py +0 -0
|
@@ -0,0 +1,43 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Copyright 2024, Zep Software, Inc.
|
|
3
|
+
|
|
4
|
+
Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
you may not use this file except in compliance with the License.
|
|
6
|
+
You may obtain a copy of the License at
|
|
7
|
+
|
|
8
|
+
http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
|
|
10
|
+
Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
See the License for the specific language governing permissions and
|
|
14
|
+
limitations under the License.
|
|
15
|
+
"""
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class GraphitiError(Exception):
    """Common base class for the exceptions defined by Graphiti Core.

    Catching ``GraphitiError`` catches every library-specific error type
    derived from it.
    """
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class EdgeNotFoundError(GraphitiError):
    """Signals that no edge with the requested UUID exists.

    Parameters
    ----------
    uuid : str
        The edge identifier that could not be found; it is embedded in the
        message, which is also exposed as ``self.message``.
    """

    def __init__(self, uuid: str):
        message = f'edge {uuid} not found'
        self.message = message
        super().__init__(message)
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class NodeNotFoundError(GraphitiError):
    """Signals that no node with the requested UUID exists.

    Parameters
    ----------
    uuid : str
        The node identifier that could not be found; it is embedded in the
        message, which is also exposed as ``self.message``.
    """

    def __init__(self, uuid: str):
        message = f'node {uuid} not found'
        self.message = message
        super().__init__(message)
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
class SearchRerankerError(GraphitiError):
    """Raised when a search is misconfigured with respect to its reranker.

    Examples include enabling several search methods without choosing a
    reranker, or requesting node-distance reranking without supplying a
    center node.

    Parameters
    ----------
    text : str
        Human-readable description of the problem; also exposed as
        ``self.message``.
    """

    def __init__(self, text: str):
        self.message = text
        super().__init__(self.message)
|
|
@@ -24,14 +24,19 @@ from neo4j import AsyncGraphDatabase
|
|
|
24
24
|
|
|
25
25
|
from graphiti_core.edges import EntityEdge, EpisodicEdge
|
|
26
26
|
from graphiti_core.llm_client import LLMClient, OpenAIClient
|
|
27
|
-
from graphiti_core.llm_client.utils import generate_embedding
|
|
28
27
|
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
|
|
29
|
-
from graphiti_core.search.search import
|
|
28
|
+
from graphiti_core.search.search import SearchConfig, search
|
|
29
|
+
from graphiti_core.search.search_config import DEFAULT_SEARCH_LIMIT, SearchResults
|
|
30
|
+
from graphiti_core.search.search_config_recipes import (
|
|
31
|
+
EDGE_HYBRID_SEARCH_NODE_DISTANCE,
|
|
32
|
+
EDGE_HYBRID_SEARCH_RRF,
|
|
33
|
+
NODE_HYBRID_SEARCH_NODE_DISTANCE,
|
|
34
|
+
NODE_HYBRID_SEARCH_RRF,
|
|
35
|
+
)
|
|
30
36
|
from graphiti_core.search.search_utils import (
|
|
31
37
|
RELEVANT_SCHEMA_LIMIT,
|
|
32
38
|
get_relevant_edges,
|
|
33
39
|
get_relevant_nodes,
|
|
34
|
-
hybrid_node_search,
|
|
35
40
|
)
|
|
36
41
|
from graphiti_core.utils import (
|
|
37
42
|
build_episodic_edges,
|
|
@@ -548,7 +553,7 @@ class Graphiti:
|
|
|
548
553
|
query: str,
|
|
549
554
|
center_node_uuid: str | None = None,
|
|
550
555
|
group_ids: list[str | None] | None = None,
|
|
551
|
-
num_results=
|
|
556
|
+
num_results=DEFAULT_SEARCH_LIMIT,
|
|
552
557
|
):
|
|
553
558
|
"""
|
|
554
559
|
Perform a hybrid search on the knowledge graph.
|
|
@@ -564,7 +569,7 @@ class Graphiti:
|
|
|
564
569
|
Facts will be reranked based on proximity to this node
|
|
565
570
|
group_ids : list[str | None] | None, optional
|
|
566
571
|
The graph partitions to return data from.
|
|
567
|
-
|
|
572
|
+
limit : int, optional
|
|
568
573
|
The maximum number of results to return. Defaults to 10.
|
|
569
574
|
|
|
570
575
|
Returns
|
|
@@ -581,21 +586,17 @@ class Graphiti:
|
|
|
581
586
|
The search is performed using the current date and time as the reference
|
|
582
587
|
point for temporal relevance.
|
|
583
588
|
"""
|
|
584
|
-
|
|
585
|
-
|
|
586
|
-
num_episodes=0,
|
|
587
|
-
num_edges=num_results,
|
|
588
|
-
num_nodes=0,
|
|
589
|
-
group_ids=group_ids,
|
|
590
|
-
search_methods=[SearchMethod.bm25, SearchMethod.cosine_similarity],
|
|
591
|
-
reranker=reranker,
|
|
589
|
+
search_config = (
|
|
590
|
+
EDGE_HYBRID_SEARCH_RRF if center_node_uuid is None else EDGE_HYBRID_SEARCH_NODE_DISTANCE
|
|
592
591
|
)
|
|
592
|
+
search_config.limit = num_results
|
|
593
|
+
|
|
593
594
|
edges = (
|
|
594
|
-
await
|
|
595
|
+
await search(
|
|
595
596
|
self.driver,
|
|
596
597
|
self.llm_client.get_embedder(),
|
|
597
598
|
query,
|
|
598
|
-
|
|
599
|
+
group_ids,
|
|
599
600
|
search_config,
|
|
600
601
|
center_node_uuid,
|
|
601
602
|
)
|
|
@@ -606,19 +607,20 @@ class Graphiti:
|
|
|
606
607
|
async def _search(
|
|
607
608
|
self,
|
|
608
609
|
query: str,
|
|
609
|
-
timestamp: datetime,
|
|
610
610
|
config: SearchConfig,
|
|
611
|
+
group_ids: list[str | None] | None = None,
|
|
611
612
|
center_node_uuid: str | None = None,
|
|
612
|
-
):
|
|
613
|
-
return await
|
|
614
|
-
self.driver, self.llm_client.get_embedder(), query,
|
|
613
|
+
) -> SearchResults:
|
|
614
|
+
return await search(
|
|
615
|
+
self.driver, self.llm_client.get_embedder(), query, group_ids, config, center_node_uuid
|
|
615
616
|
)
|
|
616
617
|
|
|
617
618
|
async def get_nodes_by_query(
|
|
618
619
|
self,
|
|
619
620
|
query: str,
|
|
621
|
+
center_node_uuid: str | None = None,
|
|
620
622
|
group_ids: list[str | None] | None = None,
|
|
621
|
-
limit: int =
|
|
623
|
+
limit: int = DEFAULT_SEARCH_LIMIT,
|
|
622
624
|
) -> list[EntityNode]:
|
|
623
625
|
"""
|
|
624
626
|
Retrieve nodes from the graph database based on a text query.
|
|
@@ -629,7 +631,9 @@ class Graphiti:
|
|
|
629
631
|
Parameters
|
|
630
632
|
----------
|
|
631
633
|
query : str
|
|
632
|
-
The text query to search for in the graph
|
|
634
|
+
The text query to search for in the graph
|
|
635
|
+
center_node_uuid: str, optional
|
|
636
|
+
Facts will be reranked based on proximity to this node.
|
|
633
637
|
group_ids : list[str | None] | None, optional
|
|
634
638
|
The graph partitions to return data from.
|
|
635
639
|
limit : int | None, optional
|
|
@@ -655,8 +659,12 @@ class Graphiti:
|
|
|
655
659
|
If not specified, a default limit (defined in the search functions) will be used.
|
|
656
660
|
"""
|
|
657
661
|
embedder = self.llm_client.get_embedder()
|
|
658
|
-
|
|
659
|
-
|
|
660
|
-
[query], [query_embedding], self.driver, group_ids, limit
|
|
662
|
+
search_config = (
|
|
663
|
+
NODE_HYBRID_SEARCH_RRF if center_node_uuid is None else NODE_HYBRID_SEARCH_NODE_DISTANCE
|
|
661
664
|
)
|
|
662
|
-
|
|
665
|
+
search_config.limit = limit
|
|
666
|
+
|
|
667
|
+
nodes = (
|
|
668
|
+
await search(self.driver, embedder, query, group_ids, search_config, center_node_uuid)
|
|
669
|
+
).nodes
|
|
670
|
+
return nodes
|
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Copyright 2024, Zep Software, Inc.
|
|
3
|
+
|
|
4
|
+
Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
you may not use this file except in compliance with the License.
|
|
6
|
+
You may obtain a copy of the License at
|
|
7
|
+
|
|
8
|
+
http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
|
|
10
|
+
Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
See the License for the specific language governing permissions and
|
|
14
|
+
limitations under the License.
|
|
15
|
+
"""
|
|
16
|
+
|
|
17
|
+
from datetime import datetime
|
|
18
|
+
|
|
19
|
+
from neo4j import time as neo4j_time
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def parse_db_date(neo_date: neo4j_time.DateTime | None) -> datetime | None:
    """Convert a Neo4j ``DateTime`` into a native Python ``datetime``.

    Parameters
    ----------
    neo_date : neo4j.time.DateTime | None
        Value read from the database.

    Returns
    -------
    datetime | None
        ``neo_date.to_native()``; ``None`` when the input is falsy (e.g. ``None``).
    """
    if not neo_date:
        return None
    return neo_date.to_native()
|
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Copyright 2024, Zep Software, Inc.
|
|
3
|
+
|
|
4
|
+
Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
you may not use this file except in compliance with the License.
|
|
6
|
+
You may obtain a copy of the License at
|
|
7
|
+
|
|
8
|
+
http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
|
|
10
|
+
Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
See the License for the specific language governing permissions and
|
|
14
|
+
limitations under the License.
|
|
15
|
+
"""
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class RateLimitError(Exception):
    """Raised to report that a rate limit has been exceeded.

    Parameters
    ----------
    message : str, optional
        Description of the failure; stored on ``self.message`` and passed to
        ``Exception`` so ``str(err)`` returns it.
    """

    def __init__(self, message='Rate limit exceeded. Please try again later.'):
        super().__init__(message)
        self.message = message
|
|
@@ -0,0 +1,38 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Copyright 2024, Zep Software, Inc.
|
|
3
|
+
|
|
4
|
+
Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
you may not use this file except in compliance with the License.
|
|
6
|
+
You may obtain a copy of the License at
|
|
7
|
+
|
|
8
|
+
http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
|
|
10
|
+
Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
See the License for the specific language governing permissions and
|
|
14
|
+
limitations under the License.
|
|
15
|
+
"""
|
|
16
|
+
|
|
17
|
+
import logging
|
|
18
|
+
import typing
|
|
19
|
+
from time import time
|
|
20
|
+
|
|
21
|
+
from graphiti_core.llm_client.config import EMBEDDING_DIM
|
|
22
|
+
|
|
23
|
+
logger = logging.getLogger(__name__)


async def generate_embedding(
    embedder: typing.Any, text: str, model: str = 'text-embedding-3-small'
):
    """Embed one piece of text and truncate the vector to ``EMBEDDING_DIM``.

    Parameters
    ----------
    embedder : Any
        Async client whose ``create(input=[...], model=...)`` result exposes
        ``.data[0].embedding`` (presumably an OpenAI-style embeddings client —
        confirm against the caller).
    text : str
        Text to embed. Newlines are replaced with spaces before the request.
    model : str, optional
        Embedding model name forwarded to ``embedder.create``.

    Returns
    -------
    The first ``EMBEDDING_DIM`` components of the returned embedding.
    """
    start = time()

    # Newlines are flattened to spaces before the text is sent for embedding.
    text = text.replace('\n', ' ')
    response = await embedder.create(input=[text], model=model)
    embedding = response.data[0].embedding[:EMBEDDING_DIM]

    end = time()
    # time() is in seconds; scale so the logged figure really is milliseconds.
    logger.debug('embedded text of length %d in %s ms', len(text), (end - start) * 1000)

    return embedding
|
|
@@ -1,3 +1,19 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Copyright 2024, Zep Software, Inc.
|
|
3
|
+
|
|
4
|
+
Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
you may not use this file except in compliance with the License.
|
|
6
|
+
You may obtain a copy of the License at
|
|
7
|
+
|
|
8
|
+
http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
|
|
10
|
+
Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
See the License for the specific language governing permissions and
|
|
14
|
+
limitations under the License.
|
|
15
|
+
"""
|
|
16
|
+
|
|
1
17
|
from typing import Any, Protocol, TypedDict
|
|
2
18
|
|
|
3
19
|
from .models import Message, PromptFunction, PromptVersion
|
|
@@ -0,0 +1,242 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Copyright 2024, Zep Software, Inc.
|
|
3
|
+
|
|
4
|
+
Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
you may not use this file except in compliance with the License.
|
|
6
|
+
You may obtain a copy of the License at
|
|
7
|
+
|
|
8
|
+
http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
|
|
10
|
+
Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
See the License for the specific language governing permissions and
|
|
14
|
+
limitations under the License.
|
|
15
|
+
"""
|
|
16
|
+
|
|
17
|
+
import logging
|
|
18
|
+
from time import time
|
|
19
|
+
|
|
20
|
+
from neo4j import AsyncDriver
|
|
21
|
+
|
|
22
|
+
from graphiti_core.edges import EntityEdge
|
|
23
|
+
from graphiti_core.errors import SearchRerankerError
|
|
24
|
+
from graphiti_core.llm_client.config import EMBEDDING_DIM
|
|
25
|
+
from graphiti_core.nodes import CommunityNode, EntityNode
|
|
26
|
+
from graphiti_core.search.search_config import (
|
|
27
|
+
DEFAULT_SEARCH_LIMIT,
|
|
28
|
+
CommunityReranker,
|
|
29
|
+
CommunitySearchConfig,
|
|
30
|
+
CommunitySearchMethod,
|
|
31
|
+
EdgeReranker,
|
|
32
|
+
EdgeSearchConfig,
|
|
33
|
+
EdgeSearchMethod,
|
|
34
|
+
NodeReranker,
|
|
35
|
+
NodeSearchConfig,
|
|
36
|
+
NodeSearchMethod,
|
|
37
|
+
SearchConfig,
|
|
38
|
+
SearchResults,
|
|
39
|
+
)
|
|
40
|
+
from graphiti_core.search.search_utils import (
|
|
41
|
+
community_fulltext_search,
|
|
42
|
+
community_similarity_search,
|
|
43
|
+
edge_fulltext_search,
|
|
44
|
+
edge_similarity_search,
|
|
45
|
+
node_distance_reranker,
|
|
46
|
+
node_fulltext_search,
|
|
47
|
+
node_similarity_search,
|
|
48
|
+
rrf,
|
|
49
|
+
)
|
|
50
|
+
|
|
51
|
+
logger = logging.getLogger(__name__)


async def search(
    driver: AsyncDriver,
    embedder,
    query: str,
    group_ids: list[str | None] | None,
    config: SearchConfig,
    center_node_uuid: str | None = None,
) -> SearchResults:
    """Run every search kind enabled in ``config`` and bundle the results.

    Parameters
    ----------
    driver : AsyncDriver
        Neo4j driver used by the underlying searches.
    embedder
        Embedding client forwarded to the per-kind searches.
    query : str
        Search text; newlines are replaced with spaces.
    group_ids : list[str | None] | None
        Graph partitions to search; an empty list is treated as ``None``.
    config : SearchConfig
        A ``None`` edge/node/community sub-config skips that search kind;
        ``config.limit`` caps each returned list.
    center_node_uuid : str | None, optional
        Forwarded to edge and node search for node-distance reranking.

    Returns
    -------
    SearchResults
        Edges, nodes and communities, each truncated to ``config.limit``.
    """
    started_at = time()
    query = query.replace('\n', ' ')
    # An empty group list means "no group filter", same as None.
    if not group_ids:
        group_ids = None
    limit = config.limit

    edges = []
    if config.edge_config is not None:
        edges = await edge_search(
            driver, embedder, query, group_ids, config.edge_config, center_node_uuid, limit
        )

    nodes = []
    if config.node_config is not None:
        nodes = await node_search(
            driver, embedder, query, group_ids, config.node_config, center_node_uuid, limit
        )

    communities = []
    if config.community_config is not None:
        communities = await community_search(
            driver, embedder, query, group_ids, config.community_config, limit
        )

    results = SearchResults(
        edges=edges[:limit],
        nodes=nodes[:limit],
        communities=communities[:limit],
    )

    elapsed_ms = (time() - started_at) * 1000
    logger.info(f'search returned context for query {query} in {elapsed_ms} ms')

    return results
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
async def edge_search(
    driver: AsyncDriver,
    embedder,
    query: str,
    group_ids: list[str | None] | None,
    config: EdgeSearchConfig,
    center_node_uuid: str | None = None,
    limit=DEFAULT_SEARCH_LIMIT,
) -> list[EntityEdge]:
    """Search entity edges with the configured methods and rerank the candidates.

    Parameters
    ----------
    driver : AsyncDriver
        Neo4j driver used for the queries.
    embedder
        Async client whose ``create(input=[...], model=...)`` result exposes
        ``.data[0].embedding``; used only for cosine-similarity search.
    query : str
        Search text.
    group_ids : list[str | None] | None
        Graph partitions to search.
    config : EdgeSearchConfig
        Search methods to run and the reranker used to merge them.
    center_node_uuid : str | None, optional
        Required by the node-distance reranker.
    limit : int, optional
        Target result count. Each method fetches ``2 * limit`` candidates;
        ``search()`` truncates the final list.

    Returns
    -------
    list[EntityEdge]
        Candidate edges in reranked order. Without a reranker, the single
        method's results are returned unchanged.

    Raises
    ------
    SearchRerankerError
        If several methods are enabled without a reranker, or node-distance
        reranking is requested without ``center_node_uuid``.
    """
    search_results: list[list[EntityEdge]] = []

    if EdgeSearchMethod.bm25 in config.search_methods:
        text_search = await edge_fulltext_search(driver, query, None, None, group_ids, 2 * limit)
        search_results.append(text_search)

    if EdgeSearchMethod.cosine_similarity in config.search_methods:
        search_vector = (
            (await embedder.create(input=[query], model='text-embedding-3-small'))
            .data[0]
            .embedding[:EMBEDDING_DIM]
        )

        similarity_search = await edge_similarity_search(
            driver, search_vector, None, None, group_ids, 2 * limit
        )
        search_results.append(similarity_search)

    if len(search_results) > 1 and config.reranker is None:
        raise SearchRerankerError('Multiple edge searches enabled without a reranker')

    if config.reranker is None:
        # At most one candidate list exists here; with nothing to fuse, return
        # it as-is instead of falling through to an empty reranked list.
        return list(search_results[0]) if search_results else []

    edge_uuid_map = {edge.uuid: edge for result in search_results for edge in result}

    reranked_uuids: list[str] = []
    if config.reranker == EdgeReranker.rrf:
        search_result_uuids = [[edge.uuid for edge in result] for result in search_results]

        reranked_uuids = rrf(search_result_uuids)
    elif config.reranker == EdgeReranker.node_distance:
        if center_node_uuid is None:
            raise SearchRerankerError('No center node provided for Node Distance reranker')

        # Several edges can share a source node. Keep all of them per source
        # (first-seen order, no duplicates) instead of letting a plain dict
        # overwrite earlier edges with later ones.
        source_to_edge_uuids: dict[str, list[str]] = {}
        for result in search_results:
            for edge in result:
                edge_uuids = source_to_edge_uuids.setdefault(edge.source_node_uuid, [])
                if edge.uuid not in edge_uuids:
                    edge_uuids.append(edge.uuid)
        source_uuids = [[edge.source_node_uuid for edge in result] for result in search_results]

        reranked_node_uuids = await node_distance_reranker(driver, source_uuids, center_node_uuid)

        # Expand each ranked source node back into its edges, preserving rank.
        seen_edge_uuids: set[str] = set()
        for node_uuid in reranked_node_uuids:
            for edge_uuid in source_to_edge_uuids.get(node_uuid, []):
                if edge_uuid not in seen_edge_uuids:
                    seen_edge_uuids.add(edge_uuid)
                    reranked_uuids.append(edge_uuid)

    reranked_edges = [edge_uuid_map[uuid] for uuid in reranked_uuids]

    return reranked_edges
|
|
154
|
+
|
|
155
|
+
|
|
156
|
+
async def node_search(
    driver: AsyncDriver,
    embedder,
    query: str,
    group_ids: list[str | None] | None,
    config: NodeSearchConfig,
    center_node_uuid: str | None = None,
    limit=DEFAULT_SEARCH_LIMIT,
) -> list[EntityNode]:
    """Search entity nodes with the configured methods and rerank the candidates.

    Parameters
    ----------
    driver : AsyncDriver
        Neo4j driver used for the queries.
    embedder
        Async client whose ``create(input=[...], model=...)`` result exposes
        ``.data[0].embedding``; used only for cosine-similarity search.
    query : str
        Search text.
    group_ids : list[str | None] | None
        Graph partitions to search.
    config : NodeSearchConfig
        Search methods to run and the reranker used to merge them.
    center_node_uuid : str | None, optional
        Required by the node-distance reranker.
    limit : int, optional
        Target result count. Each method fetches ``2 * limit`` candidates;
        ``search()`` truncates the final list.

    Returns
    -------
    list[EntityNode]
        Candidate nodes in reranked order. Without a reranker, the single
        method's results are returned unchanged.

    Raises
    ------
    SearchRerankerError
        If several methods are enabled without a reranker, or node-distance
        reranking is requested without ``center_node_uuid``.
    """
    search_results: list[list[EntityNode]] = []

    if NodeSearchMethod.bm25 in config.search_methods:
        text_search = await node_fulltext_search(driver, query, group_ids, 2 * limit)
        search_results.append(text_search)

    if NodeSearchMethod.cosine_similarity in config.search_methods:
        search_vector = (
            (await embedder.create(input=[query], model='text-embedding-3-small'))
            .data[0]
            .embedding[:EMBEDDING_DIM]
        )

        similarity_search = await node_similarity_search(
            driver, search_vector, group_ids, 2 * limit
        )
        search_results.append(similarity_search)

    if len(search_results) > 1 and config.reranker is None:
        raise SearchRerankerError('Multiple node searches enabled without a reranker')

    if config.reranker is None:
        # At most one candidate list exists here; with nothing to fuse, return
        # it as-is instead of falling through to an empty reranked list.
        return list(search_results[0]) if search_results else []

    search_result_uuids = [[node.uuid for node in result] for result in search_results]
    node_uuid_map = {node.uuid: node for result in search_results for node in result}

    reranked_uuids: list[str] = []
    if config.reranker == NodeReranker.rrf:
        reranked_uuids = rrf(search_result_uuids)
    elif config.reranker == NodeReranker.node_distance:
        if center_node_uuid is None:
            raise SearchRerankerError('No center node provided for Node Distance reranker')
        reranked_uuids = await node_distance_reranker(driver, search_result_uuids, center_node_uuid)

    reranked_nodes = [node_uuid_map[uuid] for uuid in reranked_uuids]

    return reranked_nodes
|
|
200
|
+
|
|
201
|
+
|
|
202
|
+
async def community_search(
    driver: AsyncDriver,
    embedder,
    query: str,
    group_ids: list[str | None] | None,
    config: CommunitySearchConfig,
    limit=DEFAULT_SEARCH_LIMIT,
) -> list[CommunityNode]:
    """Search community nodes with the configured methods and rerank the candidates.

    Parameters
    ----------
    driver : AsyncDriver
        Neo4j driver used for the queries.
    embedder
        Async client whose ``create(input=[...], model=...)`` result exposes
        ``.data[0].embedding``; used only for cosine-similarity search.
    query : str
        Search text.
    group_ids : list[str | None] | None
        Graph partitions to search.
    config : CommunitySearchConfig
        Search methods to run and the reranker used to merge them.
    limit : int, optional
        Target result count. Each method fetches ``2 * limit`` candidates;
        ``search()`` truncates the final list.

    Returns
    -------
    list[CommunityNode]
        Candidate communities in reranked order. Without a reranker, the single
        method's results are returned unchanged.

    Raises
    ------
    SearchRerankerError
        If several methods are enabled without a reranker.
    """
    search_results: list[list[CommunityNode]] = []

    if CommunitySearchMethod.bm25 in config.search_methods:
        text_search = await community_fulltext_search(driver, query, group_ids, 2 * limit)
        search_results.append(text_search)

    if CommunitySearchMethod.cosine_similarity in config.search_methods:
        search_vector = (
            (await embedder.create(input=[query], model='text-embedding-3-small'))
            .data[0]
            .embedding[:EMBEDDING_DIM]
        )

        similarity_search = await community_similarity_search(
            driver, search_vector, group_ids, 2 * limit
        )
        search_results.append(similarity_search)

    if len(search_results) > 1 and config.reranker is None:
        raise SearchRerankerError('Multiple community searches enabled without a reranker')

    if config.reranker is None:
        # At most one candidate list exists here; with nothing to fuse, return
        # it as-is instead of falling through to an empty reranked list.
        return list(search_results[0]) if search_results else []

    search_result_uuids = [[community.uuid for community in result] for result in search_results]
    community_uuid_map = {
        community.uuid: community for result in search_results for community in result
    }

    reranked_uuids: list[str] = []
    if config.reranker == CommunityReranker.rrf:
        reranked_uuids = rrf(search_result_uuids)

    reranked_communities = [community_uuid_map[uuid] for uuid in reranked_uuids]

    return reranked_communities
|
|
@@ -0,0 +1,81 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Copyright 2024, Zep Software, Inc.
|
|
3
|
+
|
|
4
|
+
Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
you may not use this file except in compliance with the License.
|
|
6
|
+
You may obtain a copy of the License at
|
|
7
|
+
|
|
8
|
+
http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
|
|
10
|
+
Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
See the License for the specific language governing permissions and
|
|
14
|
+
limitations under the License.
|
|
15
|
+
"""
|
|
16
|
+
|
|
17
|
+
from enum import Enum
|
|
18
|
+
|
|
19
|
+
from pydantic import BaseModel, Field
|
|
20
|
+
|
|
21
|
+
from graphiti_core.edges import EntityEdge
|
|
22
|
+
from graphiti_core.nodes import CommunityNode, EntityNode
|
|
23
|
+
|
|
24
|
+
# Result cap used when no explicit limit is supplied (SearchConfig.limit and
# the search functions' `limit` parameters default to this).
DEFAULT_SEARCH_LIMIT: int = 10
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class EdgeSearchMethod(Enum):
    """Retrieval strategies that can be enabled for edge search."""

    # Embedding-based search: the query is embedded and matched by similarity.
    cosine_similarity = 'cosine_similarity'
    # Full-text (keyword) search.
    bm25 = 'bm25'
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class NodeSearchMethod(Enum):
    """Retrieval strategies that can be enabled for entity-node search."""

    # Embedding-based search: the query is embedded and matched by similarity.
    cosine_similarity = 'cosine_similarity'
    # Full-text (keyword) search.
    bm25 = 'bm25'
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class CommunitySearchMethod(Enum):
    """Retrieval strategies that can be enabled for community search."""

    # Embedding-based search: the query is embedded and matched by similarity.
    cosine_similarity = 'cosine_similarity'
    # Full-text (keyword) search.
    bm25 = 'bm25'
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class EdgeReranker(Enum):
    """Strategies for merging and ordering edge search candidates."""

    # Reciprocal rank fusion across the per-method result lists.
    rrf = 'reciprocal_rank_fusion'
    # Order by proximity of each edge's source node to a given center node.
    node_distance = 'node_distance'
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
class NodeReranker(Enum):
    """Strategies for merging and ordering entity-node search candidates."""

    # Reciprocal rank fusion across the per-method result lists.
    rrf = 'reciprocal_rank_fusion'
    # Order by proximity to a given center node.
    node_distance = 'node_distance'
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
class CommunityReranker(Enum):
    """Strategies for merging and ordering community search candidates."""

    # Reciprocal rank fusion across the per-method result lists.
    rrf = 'reciprocal_rank_fusion'
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
class EdgeSearchConfig(BaseModel):
    """Edge-search settings: which retrieval methods to run and how to merge them."""

    # Retrieval strategies to run; each contributes one candidate list.
    search_methods: list[EdgeSearchMethod]
    # How to combine the candidate lists; None is rejected by the search code
    # when more than one method is enabled.
    reranker: EdgeReranker | None
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
class NodeSearchConfig(BaseModel):
    """Entity-node search settings: which retrieval methods to run and how to merge them."""

    # Retrieval strategies to run; each contributes one candidate list.
    search_methods: list[NodeSearchMethod]
    # How to combine the candidate lists; None is rejected by the search code
    # when more than one method is enabled.
    reranker: NodeReranker | None
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
class CommunitySearchConfig(BaseModel):
    """Community search settings: which retrieval methods to run and how to merge them."""

    # Retrieval strategies to run; each contributes one candidate list.
    search_methods: list[CommunitySearchMethod]
    # How to combine the candidate lists; None is rejected by the search code
    # when more than one method is enabled.
    reranker: CommunityReranker | None
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
class SearchConfig(BaseModel):
    """Top-level search configuration.

    Each ``*_config`` left as ``None`` disables that kind of search; ``limit``
    caps the number of results returned for each kind.
    """

    edge_config: EdgeSearchConfig | None = None
    node_config: NodeSearchConfig | None = None
    community_config: CommunitySearchConfig | None = None
    limit: int = DEFAULT_SEARCH_LIMIT
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
class SearchResults(BaseModel):
    """Container for the results of a combined search, one list per result kind."""

    # Matching entity edges (facts), in ranked order.
    edges: list[EntityEdge]
    # Matching entity nodes, in ranked order.
    nodes: list[EntityNode]
    # Matching community nodes, in ranked order.
    communities: list[CommunityNode]
|
|
@@ -0,0 +1,84 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Copyright 2024, Zep Software, Inc.
|
|
3
|
+
|
|
4
|
+
Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
you may not use this file except in compliance with the License.
|
|
6
|
+
You may obtain a copy of the License at
|
|
7
|
+
|
|
8
|
+
http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
|
|
10
|
+
Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
See the License for the specific language governing permissions and
|
|
14
|
+
limitations under the License.
|
|
15
|
+
"""
|
|
16
|
+
|
|
17
|
+
from graphiti_core.search.search_config import (
|
|
18
|
+
CommunityReranker,
|
|
19
|
+
CommunitySearchConfig,
|
|
20
|
+
CommunitySearchMethod,
|
|
21
|
+
EdgeReranker,
|
|
22
|
+
EdgeSearchConfig,
|
|
23
|
+
EdgeSearchMethod,
|
|
24
|
+
NodeReranker,
|
|
25
|
+
NodeSearchConfig,
|
|
26
|
+
NodeSearchMethod,
|
|
27
|
+
SearchConfig,
|
|
28
|
+
)
|
|
29
|
+
|
|
30
|
+
# NOTE: each recipe below is a single shared, module-level SearchConfig
# instance. Copy a recipe before changing any of its fields (e.g. `limit`);
# assigning to the constant itself alters it for every later caller.

# Hybrid (BM25 + cosine similarity) search over edges, nodes and communities,
# with each kind's candidate lists fused by reciprocal rank fusion.
COMBINED_HYBRID_SEARCH_RRF = SearchConfig(
    edge_config=EdgeSearchConfig(
        reranker=EdgeReranker.rrf,
        search_methods=[EdgeSearchMethod.bm25, EdgeSearchMethod.cosine_similarity],
    ),
    node_config=NodeSearchConfig(
        reranker=NodeReranker.rrf,
        search_methods=[NodeSearchMethod.bm25, NodeSearchMethod.cosine_similarity],
    ),
    community_config=CommunitySearchConfig(
        reranker=CommunityReranker.rrf,
        search_methods=[CommunitySearchMethod.bm25, CommunitySearchMethod.cosine_similarity],
    ),
)

# Hybrid edge-only search, candidates fused by reciprocal rank fusion.
EDGE_HYBRID_SEARCH_RRF = SearchConfig(
    edge_config=EdgeSearchConfig(
        reranker=EdgeReranker.rrf,
        search_methods=[EdgeSearchMethod.bm25, EdgeSearchMethod.cosine_similarity],
    )
)

# Hybrid edge-only search, ordered by distance to a center node
# (the caller must supply center_node_uuid).
EDGE_HYBRID_SEARCH_NODE_DISTANCE = SearchConfig(
    edge_config=EdgeSearchConfig(
        reranker=EdgeReranker.node_distance,
        search_methods=[EdgeSearchMethod.bm25, EdgeSearchMethod.cosine_similarity],
    )
)

# Hybrid node-only search, candidates fused by reciprocal rank fusion.
NODE_HYBRID_SEARCH_RRF = SearchConfig(
    node_config=NodeSearchConfig(
        reranker=NodeReranker.rrf,
        search_methods=[NodeSearchMethod.bm25, NodeSearchMethod.cosine_similarity],
    )
)

# Hybrid node-only search, ordered by distance to a center node
# (the caller must supply center_node_uuid).
NODE_HYBRID_SEARCH_NODE_DISTANCE = SearchConfig(
    node_config=NodeSearchConfig(
        reranker=NodeReranker.node_distance,
        search_methods=[NodeSearchMethod.bm25, NodeSearchMethod.cosine_similarity],
    )
)

# Hybrid community-only search, candidates fused by reciprocal rank fusion.
COMMUNITY_HYBRID_SEARCH_RRF = SearchConfig(
    community_config=CommunitySearchConfig(
        reranker=CommunityReranker.rrf,
        search_methods=[CommunitySearchMethod.bm25, CommunitySearchMethod.cosine_similarity],
    )
)
|