graphiti-core 0.3.7__tar.gz → 0.3.9__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
- {graphiti_core-0.3.7 → graphiti_core-0.3.9}/PKG-INFO +3 -2
- {graphiti_core-0.3.7 → graphiti_core-0.3.9}/graphiti_core/edges.py +3 -3
- {graphiti_core-0.3.7 → graphiti_core-0.3.9}/graphiti_core/embedder/openai.py +1 -1
- {graphiti_core-0.3.7 → graphiti_core-0.3.9}/graphiti_core/embedder/voyage.py +1 -1
- {graphiti_core-0.3.7 → graphiti_core-0.3.9}/graphiti_core/graphiti.py +33 -14
- {graphiti_core-0.3.7 → graphiti_core-0.3.9}/graphiti_core/helpers.py +15 -1
- {graphiti_core-0.3.7 → graphiti_core-0.3.9}/graphiti_core/nodes.py +4 -2
- {graphiti_core-0.3.7 → graphiti_core-0.3.9}/graphiti_core/prompts/eval.py +28 -2
- {graphiti_core-0.3.7 → graphiti_core-0.3.9}/graphiti_core/prompts/extract_edge_dates.py +8 -9
- {graphiti_core-0.3.7 → graphiti_core-0.3.9}/graphiti_core/prompts/extract_edges.py +3 -2
- {graphiti_core-0.3.7 → graphiti_core-0.3.9}/graphiti_core/prompts/invalidate_edges.py +1 -1
- {graphiti_core-0.3.7 → graphiti_core-0.3.9}/graphiti_core/search/search.py +61 -45
- {graphiti_core-0.3.7 → graphiti_core-0.3.9}/graphiti_core/search/search_config.py +13 -3
- {graphiti_core-0.3.7 → graphiti_core-0.3.9}/graphiti_core/search/search_config_recipes.py +40 -0
- {graphiti_core-0.3.7 → graphiti_core-0.3.9}/graphiti_core/search/search_utils.py +98 -53
- {graphiti_core-0.3.7 → graphiti_core-0.3.9}/graphiti_core/utils/maintenance/__init__.py +0 -2
- {graphiti_core-0.3.7 → graphiti_core-0.3.9}/graphiti_core/utils/maintenance/community_operations.py +13 -25
- {graphiti_core-0.3.7 → graphiti_core-0.3.9}/graphiti_core/utils/maintenance/edge_operations.py +2 -8
- graphiti_core-0.3.9/graphiti_core/utils/maintenance/temporal_operations.py +95 -0
- {graphiti_core-0.3.7 → graphiti_core-0.3.9}/pyproject.toml +5 -4
- graphiti_core-0.3.7/graphiti_core/utils/maintenance/temporal_operations.py +0 -217
- {graphiti_core-0.3.7 → graphiti_core-0.3.9}/LICENSE +0 -0
- {graphiti_core-0.3.7 → graphiti_core-0.3.9}/README.md +0 -0
- {graphiti_core-0.3.7 → graphiti_core-0.3.9}/graphiti_core/__init__.py +0 -0
- {graphiti_core-0.3.7 → graphiti_core-0.3.9}/graphiti_core/embedder/__init__.py +0 -0
- {graphiti_core-0.3.7 → graphiti_core-0.3.9}/graphiti_core/embedder/client.py +0 -0
- {graphiti_core-0.3.7 → graphiti_core-0.3.9}/graphiti_core/errors.py +0 -0
- {graphiti_core-0.3.7 → graphiti_core-0.3.9}/graphiti_core/llm_client/__init__.py +0 -0
- {graphiti_core-0.3.7 → graphiti_core-0.3.9}/graphiti_core/llm_client/anthropic_client.py +0 -0
- {graphiti_core-0.3.7 → graphiti_core-0.3.9}/graphiti_core/llm_client/client.py +0 -0
- {graphiti_core-0.3.7 → graphiti_core-0.3.9}/graphiti_core/llm_client/config.py +0 -0
- {graphiti_core-0.3.7 → graphiti_core-0.3.9}/graphiti_core/llm_client/errors.py +0 -0
- {graphiti_core-0.3.7 → graphiti_core-0.3.9}/graphiti_core/llm_client/groq_client.py +0 -0
- {graphiti_core-0.3.7 → graphiti_core-0.3.9}/graphiti_core/llm_client/openai_client.py +0 -0
- {graphiti_core-0.3.7 → graphiti_core-0.3.9}/graphiti_core/llm_client/utils.py +0 -0
- {graphiti_core-0.3.7 → graphiti_core-0.3.9}/graphiti_core/prompts/__init__.py +0 -0
- {graphiti_core-0.3.7 → graphiti_core-0.3.9}/graphiti_core/prompts/dedupe_edges.py +0 -0
- {graphiti_core-0.3.7 → graphiti_core-0.3.9}/graphiti_core/prompts/dedupe_nodes.py +0 -0
- {graphiti_core-0.3.7 → graphiti_core-0.3.9}/graphiti_core/prompts/extract_nodes.py +0 -0
- {graphiti_core-0.3.7 → graphiti_core-0.3.9}/graphiti_core/prompts/lib.py +0 -0
- {graphiti_core-0.3.7 → graphiti_core-0.3.9}/graphiti_core/prompts/models.py +0 -0
- {graphiti_core-0.3.7 → graphiti_core-0.3.9}/graphiti_core/prompts/summarize_nodes.py +0 -0
- {graphiti_core-0.3.7 → graphiti_core-0.3.9}/graphiti_core/py.typed +0 -0
- {graphiti_core-0.3.7 → graphiti_core-0.3.9}/graphiti_core/search/__init__.py +0 -0
- {graphiti_core-0.3.7 → graphiti_core-0.3.9}/graphiti_core/utils/__init__.py +0 -0
- {graphiti_core-0.3.7 → graphiti_core-0.3.9}/graphiti_core/utils/bulk_utils.py +0 -0
- {graphiti_core-0.3.7 → graphiti_core-0.3.9}/graphiti_core/utils/maintenance/graph_data_operations.py +0 -0
- {graphiti_core-0.3.7 → graphiti_core-0.3.9}/graphiti_core/utils/maintenance/node_operations.py +0 -0
- {graphiti_core-0.3.7 → graphiti_core-0.3.9}/graphiti_core/utils/maintenance/utils.py +0 -0
PKG-INFO

```diff
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: graphiti-core
-Version: 0.3.7
+Version: 0.3.9
 Summary: A temporal graph building library
 License: Apache-2.0
 Author: Paul Paliychuk
@@ -14,9 +14,10 @@ Classifier: Programming Language :: Python :: 3.12
 Requires-Dist: diskcache (>=5.6.3,<6.0.0)
 Requires-Dist: neo4j (>=5.23.0,<6.0.0)
 Requires-Dist: numpy (>=1.0.0)
-Requires-Dist: openai (>=1.
+Requires-Dist: openai (>=1.50.2,<2.0.0)
 Requires-Dist: pydantic (>=2.8.2,<3.0.0)
 Requires-Dist: tenacity (<9.0.0)
+Requires-Dist: voyageai (>=0.2.3,<0.3.0)
 Description-Content-Type: text/markdown
 
 <div align="center">
```
graphiti_core/edges.py

```diff
@@ -188,9 +188,9 @@ class EntityEdge(Edge):
         MATCH (source:Entity {uuid: $source_uuid})
         MATCH (target:Entity {uuid: $target_uuid})
         MERGE (source)-[r:RELATES_TO {uuid: $uuid}]->(target)
-        SET r = {uuid: $uuid, name: $name, group_id: $group_id, fact: $fact,
-
-
+        SET r = {uuid: $uuid, name: $name, group_id: $group_id, fact: $fact, episodes: $episodes,
+        created_at: $created_at, expired_at: $expired_at, valid_at: $valid_at, invalid_at: $invalid_at}
+        WITH r CALL db.create.setRelationshipVectorProperty(r, "fact_embedding", $fact_embedding)
         RETURN r.uuid AS uuid""",
             source_uuid=self.source_node_uuid,
             target_uuid=self.target_node_uuid,
```
graphiti_core/embedder/openai.py

```diff
@@ -42,7 +42,7 @@ class OpenAIEmbedder(EmbedderClient):
         self.client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
 
     async def create(
-
+        self, input: str | List[str] | Iterable[int] | Iterable[Iterable[int]]
     ) -> list[float]:
         result = await self.client.embeddings.create(input=input, model=self.config.embedding_model)
         return result.data[0].embedding[: self.config.embedding_dim]
```
graphiti_core/embedder/voyage.py

```diff
@@ -41,7 +41,7 @@ class VoyageAIEmbedder(EmbedderClient):
         self.client = voyageai.AsyncClient(api_key=config.api_key)
 
     async def create(
-
+        self, input: str | List[str] | Iterable[int] | Iterable[Iterable[int]]
     ) -> list[float]:
         result = await self.client.embed(input, model=self.config.embedding_model)
         return result.embeddings[0][: self.config.embedding_dim]
```
graphiti_core/graphiti.py

```diff
@@ -21,11 +21,12 @@ from time import time
 
 from dotenv import load_dotenv
 from neo4j import AsyncGraphDatabase
+from pydantic import BaseModel
 
 from graphiti_core.edges import EntityEdge, EpisodicEdge
 from graphiti_core.embedder import EmbedderClient, OpenAIEmbedder
 from graphiti_core.llm_client import LLMClient, OpenAIClient
-from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
+from graphiti_core.nodes import CommunityNode, EntityNode, EpisodeType, EpisodicNode
 from graphiti_core.search.search import SearchConfig, search
 from graphiti_core.search.search_config import DEFAULT_SEARCH_LIMIT, SearchResults
 from graphiti_core.search.search_config_recipes import (
@@ -77,6 +78,12 @@ logger = logging.getLogger(__name__)
 load_dotenv()
 
 
+class AddEpisodeResults(BaseModel):
+    episode: EpisodicNode
+    nodes: list[EntityNode]
+    edges: list[EntityEdge]
+
+
 class Graphiti:
     def __init__(
         self,
@@ -245,7 +252,7 @@ class Graphiti:
         group_id: str = '',
         uuid: str | None = None,
         update_communities: bool = False,
-    ):
+    ) -> AddEpisodeResults:
         """
         Process an episode and update the graph.
 
@@ -451,6 +458,8 @@ class Graphiti:
             end = time()
             logger.info(f'Completed add_episode in {(end - start) * 1000} ms')
 
+            return AddEpisodeResults(episode=episode, nodes=nodes, edges=entity_edges)
+
         except Exception as e:
             raise e
 
```
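add_episode now returns the extracted graph elements instead of None. A minimal consumption sketch; the add_episode arguments below (name, episode_body, source_description, reference_time) and the Graphiti import path are assumptions carried over from earlier releases, not part of this diff:

```python
from datetime import datetime, timezone

from graphiti_core import Graphiti  # import path assumed


async def ingest(graphiti: Graphiti) -> None:
    # Only the AddEpisodeResults return type is established by this diff;
    # the argument names below are assumed from the 0.3.x API.
    results = await graphiti.add_episode(
        name='conversation-42',
        episode_body='Alice told Bob she moved to Berlin.',
        source_description='chat transcript',
        reference_time=datetime.now(timezone.utc),
        group_id='demo',
    )
    print(results.episode.uuid)              # the persisted EpisodicNode
    print([n.name for n in results.nodes])   # extracted EntityNodes
    print([e.fact for e in results.edges])   # extracted EntityEdges
```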
graphiti_core/graphiti.py (continued)

```diff
@@ -567,11 +576,20 @@ class Graphiti:
         except Exception as e:
             raise e
 
-    async def build_communities(self):
+    async def build_communities(self, group_ids: list[str] | None = None) -> list[CommunityNode]:
+        """
+        Use a community clustering algorithm to find communities of nodes. Create community nodes summarising
+        the content of these communities.
+        ----------
+        query : list[str] | None
+            Optional. Create communities only for the listed group_ids. If blank the entire graph will be used.
+        """
         # Clear existing communities
         await remove_communities(self.driver)
 
-        community_nodes, community_edges = await build_communities(
+        community_nodes, community_edges = await build_communities(
+            self.driver, self.llm_client, group_ids
+        )
 
         await asyncio.gather(
             *[node.generate_name_embedding(self.embedder) for node in community_nodes]
@@ -580,6 +598,8 @@ class Graphiti:
         await asyncio.gather(*[node.save(self.driver) for node in community_nodes])
         await asyncio.gather(*[edge.save(self.driver) for edge in community_edges])
 
+        return community_nodes
+
     async def search(
         self,
         query: str,
@@ -700,18 +720,17 @@ class Graphiti:
         ).nodes
         return nodes
 
+    async def get_episode_mentions(self, episode_uuids: list[str]) -> SearchResults:
+        episodes = await EpisodicNode.get_by_uuids(self.driver, episode_uuids)
 
-
-
-
-        edges_list = await asyncio.gather(
-            *[EntityEdge.get_by_uuids(self.driver, episode.entity_edges) for episode in episodes]
-        )
+        edges_list = await asyncio.gather(
+            *[EntityEdge.get_by_uuids(self.driver, episode.entity_edges) for episode in episodes]
+        )
 
-
+        edges: list[EntityEdge] = [edge for lst in edges_list for edge in lst]
 
-
+        nodes = await get_mentioned_nodes(self.driver, episodes)
 
-
+        communities = await get_communities_by_nodes(self.driver, nodes)
 
-
+        return SearchResults(edges=edges, nodes=nodes, communities=communities)
```
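The new get_episode_mentions and the group-scoped build_communities can be exercised along these lines (a sketch against the signatures above; the Graphiti import path and the 'demo' group id are assumptions):

```python
from graphiti_core import Graphiti  # import path assumed
from graphiti_core.search.search_config import SearchResults


async def inspect_episode(graphiti: Graphiti, episode_uuid: str) -> None:
    # Everything a stored episode touches, grouped into the SearchResults shape.
    mentions: SearchResults = await graphiti.get_episode_mentions([episode_uuid])
    print([edge.fact for edge in mentions.edges])
    print([node.name for node in mentions.nodes])
    print([community.name for community in mentions.communities])

    # Rebuild communities for a single group and inspect the returned CommunityNodes.
    communities = await graphiti.build_communities(group_ids=['demo'])
    print(f'{len(communities)} communities created')
```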
graphiti_core/helpers.py

```diff
@@ -16,6 +16,7 @@ limitations under the License.
 
 from datetime import datetime
 
+import numpy as np
 from neo4j import time as neo4j_time
 
 
@@ -25,7 +26,7 @@ def parse_db_date(neo_date: neo4j_time.DateTime | None) -> datetime | None:
 
 def lucene_sanitize(query: str) -> str:
     # Escape special characters from a query before passing into Lucene
-    # + - && || ! ( ) { } [ ] ^ " ~ * ? : \
+    # + - && || ! ( ) { } [ ] ^ " ~ * ? : \ /
     escape_map = str.maketrans(
         {
             '+': r'\+',
@@ -46,8 +47,21 @@ def lucene_sanitize(query: str) -> str:
             '?': r'\?',
             ':': r'\:',
             '\\': r'\\',
+            '/': r'\/',
         }
     )
 
     sanitized = query.translate(escape_map)
     return sanitized
+
+
+def normalize_l2(embedding: list[float]) -> list[float]:
+    embedding_array = np.array(embedding)
+    if embedding_array.ndim == 1:
+        norm = np.linalg.norm(embedding_array)
+        if norm == 0:
+            return embedding_array.tolist()
+        return (embedding_array / norm).tolist()
+    else:
+        norm = np.linalg.norm(embedding_array, 2, axis=1, keepdims=True)
+        return (np.where(norm == 0, embedding_array, embedding_array / norm)).tolist()
```
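Both helpers are small enough to try directly; a quick sketch of their behaviour (the import path follows the file above):

```python
from graphiti_core.helpers import lucene_sanitize, normalize_l2

# '/' is now escaped along with the other Lucene special characters.
print(lucene_sanitize('who is Alice/Bob?'))   # -> who is Alice\/Bob\?

# L2 normalization returns a unit-length vector; zero vectors pass through unchanged.
print(normalize_l2([3.0, 4.0]))               # -> [0.6, 0.8]
print(normalize_l2([0.0, 0.0]))               # -> [0.0, 0.0]
```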
graphiti_core/nodes.py

```diff
@@ -225,7 +225,8 @@ class EntityNode(Node):
         result = await driver.execute_query(
             """
         MERGE (n:Entity {uuid: $uuid})
-        SET n = {uuid: $uuid, name: $name,
+        SET n = {uuid: $uuid, name: $name, group_id: $group_id, summary: $summary, created_at: $created_at}
+        WITH n CALL db.create.setNodeVectorProperty(n, "name_embedding", $name_embedding)
         RETURN n.uuid AS uuid""",
             uuid=self.uuid,
             name=self.name,
@@ -308,7 +309,8 @@ class CommunityNode(Node):
         result = await driver.execute_query(
             """
         MERGE (n:Community {uuid: $uuid})
-        SET n = {uuid: $uuid, name: $name,
+        SET n = {uuid: $uuid, name: $name, group_id: $group_id, summary: $summary, created_at: $created_at}
+        WITH n CALL db.create.setNodeVectorProperty(n, "name_embedding", $name_embedding)
         RETURN n.uuid AS uuid""",
             uuid=self.uuid,
             name=self.name,
```
graphiti_core/prompts/eval.py

```diff
@@ -23,11 +23,33 @@ from .models import Message, PromptFunction, PromptVersion
 class Prompt(Protocol):
     qa_prompt: PromptVersion
     eval_prompt: PromptVersion
+    query_expansion: PromptVersion
 
 
 class Versions(TypedDict):
     qa_prompt: PromptFunction
     eval_prompt: PromptFunction
+    query_expansion: PromptFunction
+
+
+def query_expansion(context: dict[str, Any]) -> list[Message]:
+    sys_prompt = """You are an expert at rephrasing questions into queries used in a database retrieval system"""
+
+    user_prompt = f"""
+    Bob is asking Alice a question, are you able to rephrase the question into a simpler one about Alice in the third person
+    that maintains the relevant context?
+    <QUESTION>
+    {json.dumps(context['query'])}
+    </QUESTION>
+    respond with a JSON object in the following format:
+    {{
+        "query": "query optimized for database search"
+    }}
+    """
+    return [
+        Message(role='system', content=sys_prompt),
+        Message(role='user', content=user_prompt),
+    ]
 
 
 def qa_prompt(context: dict[str, Any]) -> list[Message]:
@@ -38,7 +60,7 @@ def qa_prompt(context: dict[str, Any]) -> list[Message]:
 You are given the following entity summaries and facts to help you determine the answer to your question.
 <ENTITY_SUMMARIES>
 {json.dumps(context['entity_summaries'])}
-</ENTITY_SUMMARIES
+</ENTITY_SUMMARIES>
 <FACTS>
 {json.dumps(context['facts'])}
 </FACTS>
@@ -87,4 +109,8 @@ def eval_prompt(context: dict[str, Any]) -> list[Message]:
     ]
 
 
-versions: Versions = {
+versions: Versions = {
+    'qa_prompt': qa_prompt,
+    'eval_prompt': eval_prompt,
+    'query_expansion': query_expansion,
+}
```
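Like the other prompt functions, query_expansion just builds the chat messages; a quick rendering sketch (import path follows the file above):

```python
from graphiti_core.prompts.eval import query_expansion

messages = query_expansion({'query': 'Alice, where did you go to school?'})
for message in messages:
    print(message.role)     # 'system', then 'user'
    print(message.content)  # the prompt text, with the question embedded as JSON
```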
graphiti_core/prompts/extract_edge_dates.py

```diff
@@ -37,7 +37,6 @@ def v1(context: dict[str, Any]) -> list[Message]:
             role='user',
             content=f"""
         Edge:
-        Edge Name: {context['edge_name']}
         Fact: {context['edge_fact']}
 
         Current Episode: {context['current_episode']}
@@ -56,17 +55,17 @@ def v1(context: dict[str, Any]) -> list[Message]:
         Guidelines:
         1. Use ISO 8601 format (YYYY-MM-DDTHH:MM:SSZ) for datetimes.
         2. Use the reference timestamp as the current time when determining the valid_at and invalid_at dates.
-        3. If
-        4.
-
-
-        7. If only a
+        3. If the fact is written in the present tense, use the Reference Timestamp for the valid_at date
+        4. If no temporal information is found that establishes or changes the relationship, leave the fields as null.
+        5. Do not infer dates from related events. Only use dates that are directly stated to establish or change the relationship.
+        6. For relative time mentions directly related to the relationship, calculate the actual datetime based on the reference timestamp.
+        7. If only a date is mentioned without a specific time, use 00:00:00 (midnight) for that date.
+        8. If only a year is mentioned, use January 1st of that year at 00:00:00.
         9. Always include the time zone offset (use Z for UTC if no specific time zone is mentioned).
         Respond with a JSON object:
         {{
-            "valid_at": "YYYY-MM-DDTHH:MM:
-            "invalid_at": "YYYY-MM-DDTHH:MM:
-            "explanation": "Brief explanation of why these dates were chosen or why they were set to null"
+            "valid_at": "YYYY-MM-DDTHH:MM:SS.SSSSSSZ or null",
+            "invalid_at": "YYYY-MM-DDTHH:MM:SS.SSSSSSZ or null",
         }}
         """,
         ),
```
graphiti_core/prompts/extract_edges.py

```diff
@@ -113,8 +113,9 @@ def v2(context: dict[str, Any]) -> list[Message]:
         2. Each edge should represent a clear relationship between two DISTINCT nodes.
         3. The relation_type should be a concise, all-caps description of the relationship (e.g., LOVES, IS_FRIENDS_WITH, WORKS_FOR).
         4. Provide a more detailed fact describing the relationship.
-        5.
-        6.
+        5. The fact should include any specific relevant information, including numeric information
+        6. Consider temporal aspects of relationships when relevant.
+        7. Avoid using the same node as the source and target of a relationship
 
         Respond with a JSON object in the following format:
         {{
```
graphiti_core/prompts/invalidate_edges.py

```diff
@@ -82,7 +82,7 @@ def v2(context: dict[str, Any]) -> list[Message]:
         Message(
             role='user',
             content=f"""
-        Based on the provided Existing Edges and a New Edge, determine which existing edges, if any, should be marked as invalidated due to
+        Based on the provided Existing Edges and a New Edge, determine which existing edges, if any, should be marked as invalidated due to invalidations with the New Edge.
 
         Existing Edges:
         {context['existing_edges']}
```
graphiti_core/search/search.py

```diff
@@ -29,13 +29,10 @@ from graphiti_core.search.search_config import (
     DEFAULT_SEARCH_LIMIT,
     CommunityReranker,
     CommunitySearchConfig,
-    CommunitySearchMethod,
     EdgeReranker,
     EdgeSearchConfig,
-    EdgeSearchMethod,
     NodeReranker,
     NodeSearchConfig,
-    NodeSearchMethod,
     SearchConfig,
     SearchResults,
 )
@@ -45,6 +42,7 @@ from graphiti_core.search.search_utils import (
     edge_fulltext_search,
     edge_similarity_search,
     episode_mentions_reranker,
+    maximal_marginal_relevance,
     node_distance_reranker,
     node_fulltext_search,
     node_similarity_search,
@@ -120,22 +118,18 @@ async def edge_search(
     if config is None:
         return []
 
-
+    query_vector = await embedder.create(input=[query])
 
-
-
-
-
-
-
-
-
-            driver, search_vector, None, None, group_ids, 2 * limit
+    search_results: list[list[EntityEdge]] = list(
+        await asyncio.gather(
+            *[
+                edge_fulltext_search(driver, query, None, None, group_ids, 2 * limit),
+                edge_similarity_search(
+                    driver, query_vector, None, None, group_ids, 2 * limit, config.sim_min_score
+                ),
+            ]
         )
-
-
-    if len(search_results) > 1 and config.reranker is None:
-        raise SearchRerankerError('Multiple edge searches enabled without a reranker')
+    )
 
     edge_uuid_map = {edge.uuid: edge for result in search_results for edge in result}
 
@@ -144,6 +138,15 @@ async def edge_search(
     search_result_uuids = [[edge.uuid for edge in result] for result in search_results]
 
         reranked_uuids = rrf(search_result_uuids)
+    elif config.reranker == EdgeReranker.mmr:
+        search_result_uuids_and_vectors = [
+            (edge.uuid, edge.fact_embedding if edge.fact_embedding is not None else [0.0] * 1024)
+            for result in search_results
+            for edge in result
+        ]
+        reranked_uuids = maximal_marginal_relevance(
+            query_vector, search_result_uuids_and_vectors, config.mmr_lambda
+        )
     elif config.reranker == EdgeReranker.node_distance:
         if center_node_uuid is None:
             raise SearchRerankerError('No center node provided for Node Distance reranker')
@@ -184,22 +187,18 @@ async def node_search(
     if config is None:
         return []
 
-
-
-    if NodeSearchMethod.bm25 in config.search_methods:
-        text_search = await node_fulltext_search(driver, query, group_ids, 2 * limit)
-        search_results.append(text_search)
+    query_vector = await embedder.create(input=[query])
 
-
-
-
-
-
+    search_results: list[list[EntityNode]] = list(
+        await asyncio.gather(
+            *[
+                node_fulltext_search(driver, query, group_ids, 2 * limit),
+                node_similarity_search(
+                    driver, query_vector, group_ids, 2 * limit, config.sim_min_score
+                ),
+            ]
         )
-
-
-    if len(search_results) > 1 and config.reranker is None:
-        raise SearchRerankerError('Multiple node searches enabled without a reranker')
+    )
 
     search_result_uuids = [[node.uuid for node in result] for result in search_results]
     node_uuid_map = {node.uuid: node for result in search_results for node in result}
@@ -207,6 +206,15 @@ async def node_search(
     reranked_uuids: list[str] = []
     if config.reranker == NodeReranker.rrf:
         reranked_uuids = rrf(search_result_uuids)
+    elif config.reranker == NodeReranker.mmr:
+        search_result_uuids_and_vectors = [
+            (node.uuid, node.name_embedding if node.name_embedding is not None else [0.0] * 1024)
+            for result in search_results
+            for node in result
+        ]
+        reranked_uuids = maximal_marginal_relevance(
+            query_vector, search_result_uuids_and_vectors, config.mmr_lambda
+        )
     elif config.reranker == NodeReranker.episode_mentions:
         reranked_uuids = await episode_mentions_reranker(driver, search_result_uuids)
     elif config.reranker == NodeReranker.node_distance:
@@ -232,22 +240,18 @@ async def community_search(
     if config is None:
         return []
 
-
-
-    if CommunitySearchMethod.bm25 in config.search_methods:
-        text_search = await community_fulltext_search(driver, query, group_ids, 2 * limit)
-        search_results.append(text_search)
-
-    if CommunitySearchMethod.cosine_similarity in config.search_methods:
-        search_vector = await embedder.create(input=[query])
+    query_vector = await embedder.create(input=[query])
 
-
-
+    search_results: list[list[CommunityNode]] = list(
+        await asyncio.gather(
+            *[
+                community_fulltext_search(driver, query, group_ids, 2 * limit),
+                community_similarity_search(
+                    driver, query_vector, group_ids, 2 * limit, config.sim_min_score
+                ),
+            ]
         )
-
-
-    if len(search_results) > 1 and config.reranker is None:
-        raise SearchRerankerError('Multiple node searches enabled without a reranker')
+    )
 
     search_result_uuids = [[community.uuid for community in result] for result in search_results]
     community_uuid_map = {
@@ -257,6 +261,18 @@ async def community_search(
     reranked_uuids: list[str] = []
     if config.reranker == CommunityReranker.rrf:
         reranked_uuids = rrf(search_result_uuids)
+    elif config.reranker == CommunityReranker.mmr:
+        search_result_uuids_and_vectors = [
+            (
+                community.uuid,
+                community.name_embedding if community.name_embedding is not None else [0.0] * 1024,
+            )
+            for result in search_results
+            for community in result
+        ]
+        reranked_uuids = maximal_marginal_relevance(
+            query_vector, search_result_uuids_and_vectors, config.mmr_lambda
+        )
 
     reranked_communities = [community_uuid_map[uuid] for uuid in reranked_uuids]
 
```
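The new mmr rerankers delegate to maximal_marginal_relevance from search_utils, whose body is not included in this diff. As a rough illustration of the technique only, a generic greedy MMR over (uuid, embedding) pairs, not the library's implementation:

```python
import numpy as np


def mmr_sketch(
    query_vector: list[float],
    candidates: list[tuple[str, list[float]]],
    mmr_lambda: float = 0.5,
) -> list[str]:
    """Greedy maximal marginal relevance: trade query relevance against
    similarity to already-selected results; returns uuids in reranked order."""
    if not candidates:
        return []
    query = np.array(query_vector, dtype=float)
    query /= np.linalg.norm(query) or 1.0
    uuids = [uuid for uuid, _ in candidates]
    vectors = np.array([vec for _, vec in candidates], dtype=float)
    norms = np.linalg.norm(vectors, axis=1, keepdims=True)
    vectors = vectors / np.where(norms == 0, 1.0, norms)

    relevance = vectors @ query  # cosine similarity to the query
    selected: list[int] = []
    remaining = list(range(len(uuids)))
    while remaining:
        if selected:
            redundancy = np.max(vectors[remaining] @ vectors[selected].T, axis=1)
        else:
            redundancy = np.zeros(len(remaining))
        scores = mmr_lambda * relevance[remaining] - (1 - mmr_lambda) * redundancy
        best = remaining[int(np.argmax(scores))]
        selected.append(best)
        remaining.remove(best)
    return [uuids[i] for i in selected]
```

Under the usual MMR convention, a higher lambda favours relevance to the query while a lower value favours diversity among the selected results; whether config.mmr_lambda follows exactly this convention is assumed, not shown in the diff.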
graphiti_core/search/search_config.py

```diff
@@ -20,6 +20,7 @@ from pydantic import BaseModel, Field
 
 from graphiti_core.edges import EntityEdge
 from graphiti_core.nodes import CommunityNode, EntityNode
+from graphiti_core.search.search_utils import DEFAULT_MIN_SCORE, DEFAULT_MMR_LAMBDA
 
 DEFAULT_SEARCH_LIMIT = 10
 
@@ -43,31 +44,40 @@ class EdgeReranker(Enum):
     rrf = 'reciprocal_rank_fusion'
     node_distance = 'node_distance'
     episode_mentions = 'episode_mentions'
+    mmr = 'mmr'
 
 
 class NodeReranker(Enum):
     rrf = 'reciprocal_rank_fusion'
     node_distance = 'node_distance'
     episode_mentions = 'episode_mentions'
+    mmr = 'mmr'
 
 
 class CommunityReranker(Enum):
     rrf = 'reciprocal_rank_fusion'
+    mmr = 'mmr'
 
 
 class EdgeSearchConfig(BaseModel):
     search_methods: list[EdgeSearchMethod]
-    reranker: EdgeReranker
+    reranker: EdgeReranker = Field(default=EdgeReranker.rrf)
+    sim_min_score: float = Field(default=DEFAULT_MIN_SCORE)
+    mmr_lambda: float = Field(default=DEFAULT_MMR_LAMBDA)
 
 
 class NodeSearchConfig(BaseModel):
     search_methods: list[NodeSearchMethod]
-    reranker: NodeReranker
+    reranker: NodeReranker = Field(default=NodeReranker.rrf)
+    sim_min_score: float = Field(default=DEFAULT_MIN_SCORE)
+    mmr_lambda: float = Field(default=DEFAULT_MMR_LAMBDA)
 
 
 class CommunitySearchConfig(BaseModel):
     search_methods: list[CommunitySearchMethod]
-    reranker: CommunityReranker
+    reranker: CommunityReranker = Field(default=CommunityReranker.rrf)
+    sim_min_score: float = Field(default=DEFAULT_MIN_SCORE)
+    mmr_lambda: float = Field(default=DEFAULT_MMR_LAMBDA)
 
 
 class SearchConfig(BaseModel):
```
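With the new defaults, only search_methods is required; reranker, sim_min_score, and mmr_lambda can be set per result type. A construction sketch using the fields above (the 0.7 value is illustrative):

```python
from graphiti_core.search.search_config import (
    EdgeReranker,
    EdgeSearchConfig,
    EdgeSearchMethod,
    SearchConfig,
)

# Edge-only hybrid search reranked with MMR, with an explicit diversity trade-off.
edge_mmr_config = SearchConfig(
    edge_config=EdgeSearchConfig(
        search_methods=[EdgeSearchMethod.bm25, EdgeSearchMethod.cosine_similarity],
        reranker=EdgeReranker.mmr,
        mmr_lambda=0.7,  # illustrative; defaults to DEFAULT_MMR_LAMBDA from search_utils
    )
)
```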
graphiti_core/search/search_config_recipes.py

```diff
@@ -43,6 +43,22 @@ COMBINED_HYBRID_SEARCH_RRF = SearchConfig(
     ),
 )
 
+# Performs a hybrid search with mmr reranking over edges, nodes, and communities
+COMBINED_HYBRID_SEARCH_MMR = SearchConfig(
+    edge_config=EdgeSearchConfig(
+        search_methods=[EdgeSearchMethod.bm25, EdgeSearchMethod.cosine_similarity],
+        reranker=EdgeReranker.mmr,
+    ),
+    node_config=NodeSearchConfig(
+        search_methods=[NodeSearchMethod.bm25, NodeSearchMethod.cosine_similarity],
+        reranker=NodeReranker.mmr,
+    ),
+    community_config=CommunitySearchConfig(
+        search_methods=[CommunitySearchMethod.bm25, CommunitySearchMethod.cosine_similarity],
+        reranker=CommunityReranker.mmr,
+    ),
+)
+
 # performs a hybrid search over edges with rrf reranking
 EDGE_HYBRID_SEARCH_RRF = SearchConfig(
     edge_config=EdgeSearchConfig(
@@ -51,6 +67,14 @@ EDGE_HYBRID_SEARCH_RRF = SearchConfig(
     )
 )
 
+# performs a hybrid search over edges with mmr reranking
+EDGE_HYBRID_SEARCH_mmr = SearchConfig(
+    edge_config=EdgeSearchConfig(
+        search_methods=[EdgeSearchMethod.bm25, EdgeSearchMethod.cosine_similarity],
+        reranker=EdgeReranker.mmr,
+    )
+)
+
 # performs a hybrid search over edges with node distance reranking
 EDGE_HYBRID_SEARCH_NODE_DISTANCE = SearchConfig(
     edge_config=EdgeSearchConfig(
@@ -75,6 +99,14 @@ NODE_HYBRID_SEARCH_RRF = SearchConfig(
     )
 )
 
+# performs a hybrid search over nodes with mmr reranking
+NODE_HYBRID_SEARCH_MMR = SearchConfig(
+    node_config=NodeSearchConfig(
+        search_methods=[NodeSearchMethod.bm25, NodeSearchMethod.cosine_similarity],
+        reranker=NodeReranker.mmr,
+    )
+)
+
 # performs a hybrid search over nodes with node distance reranking
 NODE_HYBRID_SEARCH_NODE_DISTANCE = SearchConfig(
     node_config=NodeSearchConfig(
@@ -98,3 +130,11 @@ COMMUNITY_HYBRID_SEARCH_RRF = SearchConfig(
         reranker=CommunityReranker.rrf,
     )
 )
+
+# performs a hybrid search over communities with mmr reranking
+COMMUNITY_HYBRID_SEARCH_MMR = SearchConfig(
+    community_config=CommunitySearchConfig(
+        search_methods=[CommunitySearchMethod.bm25, CommunitySearchMethod.cosine_similarity],
+        reranker=CommunityReranker.mmr,
+    )
+)
```
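The recipes are plain pydantic models, so a tuned variant can be derived instead of rebuilt from scratch (a sketch; model_copy is standard pydantic v2, and the override value is illustrative):

```python
from graphiti_core.search.search_config_recipes import COMBINED_HYBRID_SEARCH_MMR

# Start from the shipped MMR recipe and only adjust the edge-level trade-off.
tuned = COMBINED_HYBRID_SEARCH_MMR.model_copy(deep=True)
assert tuned.edge_config is not None
tuned.edge_config.mmr_lambda = 0.7  # field added to EdgeSearchConfig in this release
```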