graphiti-core 0.4.2__py3-none-any.whl → 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release. This version of graphiti-core might be problematic.
- graphiti_core/cross_encoder/bge_reranker_client.py +1 -2
- graphiti_core/cross_encoder/client.py +3 -4
- graphiti_core/cross_encoder/openai_reranker_client.py +2 -2
- graphiti_core/edges.py +56 -7
- graphiti_core/embedder/client.py +3 -3
- graphiti_core/embedder/openai.py +2 -2
- graphiti_core/embedder/voyage.py +3 -3
- graphiti_core/graphiti.py +39 -37
- graphiti_core/helpers.py +26 -0
- graphiti_core/llm_client/anthropic_client.py +4 -1
- graphiti_core/llm_client/client.py +45 -5
- graphiti_core/llm_client/errors.py +8 -0
- graphiti_core/llm_client/groq_client.py +4 -1
- graphiti_core/llm_client/openai_client.py +71 -7
- graphiti_core/llm_client/openai_generic_client.py +163 -0
- graphiti_core/nodes.py +58 -8
- graphiti_core/prompts/dedupe_edges.py +20 -17
- graphiti_core/prompts/dedupe_nodes.py +15 -1
- graphiti_core/prompts/eval.py +17 -14
- graphiti_core/prompts/extract_edge_dates.py +15 -7
- graphiti_core/prompts/extract_edges.py +18 -19
- graphiti_core/prompts/extract_nodes.py +11 -21
- graphiti_core/prompts/invalidate_edges.py +13 -25
- graphiti_core/prompts/lib.py +5 -1
- graphiti_core/prompts/prompt_helpers.py +1 -0
- graphiti_core/prompts/summarize_nodes.py +17 -16
- graphiti_core/search/search.py +5 -5
- graphiti_core/search/search_utils.py +55 -14
- graphiti_core/utils/__init__.py +0 -15
- graphiti_core/utils/bulk_utils.py +22 -15
- graphiti_core/utils/datetime_utils.py +42 -0
- graphiti_core/utils/maintenance/community_operations.py +13 -9
- graphiti_core/utils/maintenance/edge_operations.py +32 -26
- graphiti_core/utils/maintenance/graph_data_operations.py +3 -4
- graphiti_core/utils/maintenance/node_operations.py +19 -13
- graphiti_core/utils/maintenance/temporal_operations.py +17 -9
- {graphiti_core-0.4.2.dist-info → graphiti_core-0.5.0.dist-info}/METADATA +1 -1
- graphiti_core-0.5.0.dist-info/RECORD +60 -0
- graphiti_core-0.4.2.dist-info/RECORD +0 -57
- {graphiti_core-0.4.2.dist-info → graphiti_core-0.5.0.dist-info}/LICENSE +0 -0
- {graphiti_core-0.4.2.dist-info → graphiti_core-0.5.0.dist-info}/WHEEL +0 -0
graphiti_core/prompts/extract_edges.py
CHANGED
@@ -17,9 +17,26 @@ limitations under the License.
 import json
 from typing import Any, Protocol, TypedDict
 
+from pydantic import BaseModel, Field
+
 from .models import Message, PromptFunction, PromptVersion
 
 
+class Edge(BaseModel):
+    relation_type: str = Field(..., description='RELATION_TYPE_IN_CAPS')
+    source_entity_name: str = Field(..., description='name of the source entity')
+    target_entity_name: str = Field(..., description='name of the target entity')
+    fact: str = Field(..., description='extracted factual information')
+
+
+class ExtractedEdges(BaseModel):
+    edges: list[Edge]
+
+
+class MissingFacts(BaseModel):
+    missing_facts: list[str] = Field(..., description="facts that weren't extracted")
+
+
 class Prompt(Protocol):
     edge: PromptVersion
     reflexion: PromptVersion
@@ -54,25 +71,12 @@ def edge(context: dict[str, Any]) -> list[Message]:
 
 Given the above MESSAGES and ENTITIES, extract all facts pertaining to the listed ENTITIES from the CURRENT MESSAGE.
 
-
 Guidelines:
 1. Extract facts only between the provided entities.
 2. Each fact should represent a clear relationship between two DISTINCT nodes.
 3. The relation_type should be a concise, all-caps description of the fact (e.g., LOVES, IS_FRIENDS_WITH, WORKS_FOR).
 4. Provide a more detailed fact containing all relevant information.
 5. Consider temporal aspects of relationships when relevant.
-
-Respond with a JSON object in the following format:
-{{
-    "edges": [
-        {{
-            "relation_type": "RELATION_TYPE_IN_CAPS",
-            "source_entity_name": "name of the source entity",
-            "target_entity_name": "name of the target entity",
-            "fact": "extracted factual information",
-        }}
-    ]
-}}
 """,
         ),
     ]
@@ -98,12 +102,7 @@ def reflexion(context: dict[str, Any]) -> list[Message]:
 </EXTRACTED FACTS>
 
 Given the above MESSAGES, list of EXTRACTED ENTITIES entities, and list of EXTRACTED FACTS;
-determine if any facts haven't been extracted
-
-Respond with a JSON object in the following format:
-{{
-    "missing_facts": [ "facts that weren't extracted", ...]
-}}
+determine if any facts haven't been extracted.
 """
     return [
         Message(role='system', content=sys_prompt),
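This is the pattern repeated across every prompt module in this release: the literal "Respond with a JSON object in the following format: ..." blocks are deleted from the prompt strings, and the expected output shape moves into Pydantic response models defined next to the prompts. A minimal sketch of what that enables (the model definitions are copied from the diff; the schema/validation calls are standard Pydantic v2, not code from this package):

from pydantic import BaseModel, Field

class Edge(BaseModel):
    relation_type: str = Field(..., description='RELATION_TYPE_IN_CAPS')
    source_entity_name: str = Field(..., description='name of the source entity')
    target_entity_name: str = Field(..., description='name of the target entity')
    fact: str = Field(..., description='extracted factual information')

class ExtractedEdges(BaseModel):
    edges: list[Edge]

# The JSON schema that used to live as literal prompt text can be generated
# from the model, and replies validated into typed objects on the way back:
schema = ExtractedEdges.model_json_schema()
result = ExtractedEdges.model_validate_json('{"edges": []}')
assert result.edges == []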
graphiti_core/prompts/extract_nodes.py
CHANGED
@@ -17,9 +17,19 @@ limitations under the License.
 import json
 from typing import Any, Protocol, TypedDict
 
+from pydantic import BaseModel, Field
+
 from .models import Message, PromptFunction, PromptVersion
 
 
+class ExtractedNodes(BaseModel):
+    extracted_node_names: list[str] = Field(..., description='Name of the extracted entity')
+
+
+class MissedEntities(BaseModel):
+    missed_entities: list[str] = Field(..., description="Names of entities that weren't extracted")
+
+
 class Prompt(Protocol):
     extract_message: PromptVersion
     extract_json: PromptVersion
@@ -56,11 +66,6 @@ Guidelines:
 4. DO NOT create nodes for temporal information like dates, times or years (these will be added to edges later).
 5. Be as explicit as possible in your node names, using full names.
 6. DO NOT extract entities mentioned only in PREVIOUS MESSAGES, those messages are only to provide context.
-
-Respond with a JSON object in the following format:
-{{
-    "extracted_node_names": ["Name of the extracted entity", ...],
-}}
 """
     return [
         Message(role='system', content=sys_prompt),
@@ -87,11 +92,6 @@ Given the above source description and JSON, extract relevant entity nodes from
 Guidelines:
 1. Always try to extract an entities that the JSON represents. This will often be something like a "name" or "user field
 2. Do NOT extract any properties that contain dates
-
-Respond with a JSON object in the following format:
-{{
-    "extracted_node_names": ["Name of the extracted entity", ...],
-}}
 """
     return [
         Message(role='system', content=sys_prompt),
@@ -116,11 +116,6 @@ Guidelines:
 2. Avoid creating nodes for relationships or actions.
 3. Avoid creating nodes for temporal information like dates, times or years (these will be added to edges later).
 4. Be as explicit as possible in your node names, using full names and avoiding abbreviations.
-
-Respond with a JSON object in the following format:
-{{
-    "extracted_node_names": ["Name of the extracted entity", ...],
-}}
 """
     return [
         Message(role='system', content=sys_prompt),
@@ -144,12 +139,7 @@ def reflexion(context: dict[str, Any]) -> list[Message]:
 </EXTRACTED ENTITIES>
 
 Given the above previous messages, current message, and list of extracted entities; determine if any entities haven't been
-extracted
-
-Respond with a JSON object in the following format:
-{{
-    "missed_entities": [ "name of entity that wasn't extracted", ...]
-}}
+extracted.
 """
     return [
         Message(role='system', content=sys_prompt),
graphiti_core/prompts/invalidate_edges.py
CHANGED
@@ -16,9 +16,22 @@ limitations under the License.
 
 from typing import Any, Protocol, TypedDict
 
+from pydantic import BaseModel, Field
+
 from .models import Message, PromptFunction, PromptVersion
 
 
+class InvalidatedEdge(BaseModel):
+    uuid: str = Field(..., description='The UUID of the edge to be invalidated')
+    fact: str = Field(..., description='Updated fact of the edge')
+
+
+class InvalidatedEdges(BaseModel):
+    invalidated_edges: list[InvalidatedEdge] = Field(
+        ..., description='List of edges that should be invalidated'
+    )
+
+
 class Prompt(Protocol):
     v1: PromptVersion
     v2: PromptVersion
@@ -56,18 +69,6 @@ def v1(context: dict[str, Any]) -> list[Message]:
 {context['new_edges']}
 
 Each edge is formatted as: "UUID | SOURCE_NODE - EDGE_NAME - TARGET_NODE (fact: EDGE_FACT), START_DATE (END_DATE, optional))"
-
-For each existing edge that should be invalidated, respond with a JSON object in the following format:
-{{
-    "invalidated_edges": [
-        {{
-            "edge_uuid": "The UUID of the edge to be invalidated (the part before the | character)",
-            "fact": "Updated fact of the edge"
-        }}
-    ]
-}}
-
-If no relationships need to be invalidated based on these strict criteria, return an empty list for "invalidated_edges".
 """,
         ),
     ]
@@ -89,19 +90,6 @@ def v2(context: dict[str, Any]) -> list[Message]:
 
 New Edge:
 {context['new_edge']}
-
-
-For each existing edge that should be invalidated, respond with a JSON object in the following format:
-{{
-    "invalidated_edges": [
-        {{
-            "uuid": "The UUID of the edge to be invalidated",
-            "fact": "Updated fact of the edge"
-        }}
-    ]
-}}
-
-If no relationships need to be invalidated based on these strict criteria, return an empty list for "invalidated_edges".
 """,
         ),
     ]
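Worth noting: the v1 prompt's deleted inline format used the key "edge_uuid", while the new InvalidatedEdge model names the field uuid, so for v1 this is a small response-contract change, not just a relocation. On the parsing side the models give callers typed results; a self-contained sketch (the reply string is invented for illustration):

from pydantic import BaseModel, Field

class InvalidatedEdge(BaseModel):
    uuid: str = Field(..., description='The UUID of the edge to be invalidated')
    fact: str = Field(..., description='Updated fact of the edge')

class InvalidatedEdges(BaseModel):
    invalidated_edges: list[InvalidatedEdge] = Field(
        ..., description='List of edges that should be invalidated'
    )

# Validating a model-shaped LLM reply yields typed objects instead of raw dicts:
reply = '{"invalidated_edges": [{"uuid": "edge-123", "fact": "Alice no longer works at Acme"}]}'
for edge in InvalidatedEdges.model_validate_json(reply).invalidated_edges:
    print(edge.uuid, edge.fact)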
graphiti_core/prompts/lib.py
CHANGED
@@ -74,6 +74,7 @@ from .invalidate_edges import (
     versions as invalidate_edges_versions,
 )
 from .models import Message, PromptFunction
+from .prompt_helpers import DO_NOT_ESCAPE_UNICODE
 from .summarize_nodes import Prompt as SummarizeNodesPrompt
 from .summarize_nodes import Versions as SummarizeNodesVersions
 from .summarize_nodes import versions as summarize_nodes_versions
@@ -106,7 +107,10 @@ class VersionWrapper:
         self.func = func
 
     def __call__(self, context: dict[str, Any]) -> list[Message]:
-        return self.func(context)
+        messages = self.func(context)
+        for message in messages:
+            message.content += DO_NOT_ESCAPE_UNICODE if message.role == 'system' else ''
+        return messages
 
 
 class PromptTypeWrapper:
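The VersionWrapper change means every prompt rendered through prompts/lib.py now gets the new DO_NOT_ESCAPE_UNICODE suffix appended to its system message. A tiny stand-alone illustration of the effect (the Message dataclass here is a stand-in for the library's own Message model):

from dataclasses import dataclass

DO_NOT_ESCAPE_UNICODE = '\nDo not escape unicode characters.\n'

@dataclass
class Message:
    role: str
    content: str

messages = [Message('system', 'Extract entities.'), Message('user', 'Kaffee ☕')]
for message in messages:
    message.content += DO_NOT_ESCAPE_UNICODE if message.role == 'system' else ''

assert messages[0].content.endswith(DO_NOT_ESCAPE_UNICODE)  # system message suffixed
assert messages[1].content == 'Kaffee ☕'                    # user message untouched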
graphiti_core/prompts/prompt_helpers.py
ADDED
@@ -0,0 +1 @@
+DO_NOT_ESCAPE_UNICODE = '\nDo not escape unicode characters.\n'
graphiti_core/prompts/summarize_nodes.py
CHANGED
@@ -17,9 +17,21 @@ limitations under the License.
 import json
 from typing import Any, Protocol, TypedDict
 
+from pydantic import BaseModel, Field
+
 from .models import Message, PromptFunction, PromptVersion
 
 
+class Summary(BaseModel):
+    summary: str = Field(
+        ..., description='Summary containing the important information from both summaries'
+    )
+
+
+class SummaryDescription(BaseModel):
+    description: str = Field(..., description='One sentence description of the provided summary')
+
+
 class Prompt(Protocol):
     summarize_pair: PromptVersion
     summarize_context: PromptVersion
@@ -42,14 +54,11 @@ def summarize_pair(context: dict[str, Any]) -> list[Message]:
             role='user',
             content=f"""
 Synthesize the information from the following two summaries into a single succinct summary.
+
+Summaries must be under 500 words.
 
 Summaries:
 {json.dumps(context['node_summaries'], indent=2)}
-
-Respond with a JSON object in the following format:
-{{
-    "summary": "Summary containing the important information from both summaries"
-}}
 """,
         ),
     ]
@@ -74,15 +83,11 @@ def summarize_context(context: dict[str, Any]) -> list[Message]:
 information from the provided MESSAGES. Your summary should also only contain information relevant to the
 provided ENTITY.
 
+Summaries must be under 500 words.
+
 <ENTITY>
 {context['node_name']}
 </ENTITY>
-
-
-Respond with a JSON object in the following format:
-{{
-    "summary": "Entity summary"
-}}
 """,
         ),
     ]
@@ -98,14 +103,10 @@ def summary_description(context: dict[str, Any]) -> list[Message]:
             role='user',
             content=f"""
 Create a short one sentence description of the summary that explains what kind of information is summarized.
+Summaries must be under 500 words.
 
 Summary:
 {json.dumps(context['summary'], indent=2)}
-
-Respond with a JSON object in the following format:
-{{
-    "description": "One sentence description of the provided summary"
-}}
 """,
         ),
     ]
graphiti_core/search/search.py
CHANGED
@@ -14,7 +14,6 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
 
-import asyncio
 import logging
 from collections import defaultdict
 from time import time
@@ -25,6 +24,7 @@ from graphiti_core.cross_encoder.client import CrossEncoderClient
 from graphiti_core.edges import EntityEdge
 from graphiti_core.embedder import EmbedderClient
 from graphiti_core.errors import SearchRerankerError
+from graphiti_core.helpers import semaphore_gather
 from graphiti_core.nodes import CommunityNode, EntityNode
 from graphiti_core.search.search_config import (
     DEFAULT_SEARCH_LIMIT,
@@ -78,7 +78,7 @@ async def search(
 
     # if group_ids is empty, set it to None
     group_ids = group_ids if group_ids else None
-    edges, nodes, communities = await asyncio.gather(
+    edges, nodes, communities = await semaphore_gather(
         edge_search(
             driver,
             cross_encoder,
@@ -141,7 +141,7 @@ async def edge_search(
         return []
 
     search_results: list[list[EntityEdge]] = list(
-        await asyncio.gather(
+        await semaphore_gather(
             *[
                 edge_fulltext_search(driver, query, group_ids, 2 * limit),
                 edge_similarity_search(
@@ -226,7 +226,7 @@ async def node_search(
         return []
 
     search_results: list[list[EntityNode]] = list(
-        await asyncio.gather(
+        await semaphore_gather(
             *[
                 node_fulltext_search(driver, query, group_ids, 2 * limit),
                 node_similarity_search(
@@ -295,7 +295,7 @@ async def community_search(
         return []
 
     search_results: list[list[CommunityNode]] = list(
-        await asyncio.gather(
+        await semaphore_gather(
            *[
                 community_fulltext_search(driver, query, group_ids, 2 * limit),
                 community_similarity_search(
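Every gather in the search path (and, below, in bulk_utils) now routes through semaphore_gather from graphiti_core/helpers.py (+26 lines in this release); the replaced lines were plain asyncio.gather calls, which the removed import asyncio confirms even where the extraction truncated them. The helper's implementation is not shown in this diff; a plausible shape, assuming a module-level concurrency cap:

import asyncio
from collections.abc import Coroutine
from typing import Any

SEMAPHORE_LIMIT = 20  # assumed default; the real cap lives in graphiti_core/helpers.py

async def semaphore_gather(
    *coroutines: Coroutine[Any, Any, Any], max_coroutines: int = SEMAPHORE_LIMIT
) -> list[Any]:
    # Drop-in replacement for asyncio.gather that bounds how many of the
    # supplied coroutines (LLM and Neo4j calls here) run concurrently.
    semaphore = asyncio.Semaphore(max_coroutines)

    async def _wrapped(coroutine: Coroutine[Any, Any, Any]) -> Any:
        async with semaphore:
            return await coroutine

    return list(await asyncio.gather(*(_wrapped(c) for c in coroutines)))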
graphiti_core/search/search_utils.py
CHANGED
@@ -14,10 +14,10 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
 
-import asyncio
 import logging
 from collections import defaultdict
 from time import time
+from typing import Any
 
 import numpy as np
 from neo4j import AsyncDriver, Query
@@ -29,6 +29,7 @@ from graphiti_core.helpers import (
     USE_PARALLEL_RUNTIME,
     lucene_sanitize,
     normalize_l2,
+    semaphore_gather,
 )
 from graphiti_core.nodes import (
     CommunityNode,
@@ -40,7 +41,7 @@ from graphiti_core.nodes import (
 
 logger = logging.getLogger(__name__)
 
-RELEVANT_SCHEMA_LIMIT =
+RELEVANT_SCHEMA_LIMIT = 10
 DEFAULT_MIN_SCORE = 0.6
 DEFAULT_MMR_LAMBDA = 0.5
 MAX_SEARCH_DEPTH = 3
@@ -191,12 +192,27 @@ async def edge_similarity_search(
         'CYPHER runtime = parallel parallelRuntimeSupport=all\n' if USE_PARALLEL_RUNTIME else ''
     )
 
-
-
-
-
-
-
+    query_params: dict[str, Any] = {}
+
+    group_filter_query: LiteralString = ''
+    if group_ids is not None:
+        group_filter_query += 'WHERE r.group_id IN $group_ids'
+        query_params['group_ids'] = group_ids
+        query_params['source_node_uuid'] = source_node_uuid
+        query_params['target_node_uuid'] = target_node_uuid
+
+    if source_node_uuid is not None:
+        group_filter_query += '\nAND (n.uuid IN [$source_uuid, $target_uuid])'
+
+    if target_node_uuid is not None:
+        group_filter_query += '\nAND (m.uuid IN [$source_uuid, $target_uuid])'
+
+    query: LiteralString = (
+        """
+        MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity)
+        """
+        + group_filter_query
+        + """\nWITH DISTINCT r, vector.similarity.cosine(r.fact_embedding, $search_vector) AS score
         WHERE score > $min_score
         RETURN
             r.uuid AS uuid,
@@ -214,9 +230,11 @@ async def edge_similarity_search(
         ORDER BY score DESC
         LIMIT $limit
         """
+    )
 
     records, _, _ = await driver.execute_query(
         runtime_query + query,
+        query_params,
         search_vector=search_vector,
         source_uuid=source_node_uuid,
         target_uuid=target_node_uuid,
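Two details in the hunk above are easy to miss. First, the query_params keys source_node_uuid and target_node_uuid don't correspond to any $-placeholder visible in the assembled query; the $source_uuid and $target_uuid placeholders are still satisfied by the keyword arguments passed to execute_query. Second, for orientation, this is roughly what the string concatenation produces when every filter is set (a reconstruction with approximated whitespace, not a line from the diff):

# Illustrative only: the query assembled when group_ids, source_node_uuid and
# target_node_uuid are all provided.
expected_query = (
    'MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity)\n'
    'WHERE r.group_id IN $group_ids\n'
    'AND (n.uuid IN [$source_uuid, $target_uuid])\n'
    'AND (m.uuid IN [$source_uuid, $target_uuid])\n'
    'WITH DISTINCT r, vector.similarity.cosine(r.fact_embedding, $search_vector) AS score\n'
    'WHERE score > $min_score\n'
    '...'
)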
@@ -325,11 +343,20 @@ async def node_similarity_search(
         'CYPHER runtime = parallel parallelRuntimeSupport=all\n' if USE_PARALLEL_RUNTIME else ''
     )
 
+    query_params: dict[str, Any] = {}
+
+    group_filter_query: LiteralString = ''
+    if group_ids is not None:
+        group_filter_query += 'WHERE n.group_id IN $group_ids'
+        query_params['group_ids'] = group_ids
+
     records, _, _ = await driver.execute_query(
         runtime_query
         + """
         MATCH (n:Entity)
-
+        """
+        + group_filter_query
+        + """
         WITH n, vector.similarity.cosine(n.name_embedding, $search_vector) AS score
         WHERE score > $min_score
         RETURN
@@ -342,6 +369,7 @@ async def node_similarity_search(
         ORDER BY score DESC
         LIMIT $limit
         """,
+        query_params,
         search_vector=search_vector,
         group_ids=group_ids,
         limit=limit,
@@ -436,11 +464,20 @@ async def community_similarity_search(
         'CYPHER runtime = parallel parallelRuntimeSupport=all\n' if USE_PARALLEL_RUNTIME else ''
     )
 
+    query_params: dict[str, Any] = {}
+
+    group_filter_query: LiteralString = ''
+    if group_ids is not None:
+        group_filter_query += 'WHERE comm.group_id IN $group_ids'
+        query_params['group_ids'] = group_ids
+
     records, _, _ = await driver.execute_query(
         runtime_query
         + """
         MATCH (comm:Community)
-
+        """
+        + group_filter_query
+        + """
         WITH comm, vector.similarity.cosine(comm.name_embedding, $search_vector) AS score
         WHERE score > $min_score
         RETURN
@@ -512,7 +549,7 @@ async def hybrid_node_search(
 
     start = time()
     results: list[list[EntityNode]] = list(
-        await asyncio.gather(
+        await semaphore_gather(
             *[node_fulltext_search(driver, q, group_ids, 2 * limit) for q in queries],
             *[node_similarity_search(driver, e, group_ids, 2 * limit) for e in embeddings],
         )
@@ -582,7 +619,7 @@ async def get_relevant_edges(
     relevant_edges: list[EntityEdge] = []
     relevant_edge_uuids = set()
 
-    results = await asyncio.gather(
+    results = await semaphore_gather(
         *[
             edge_similarity_search(
                 driver,
@@ -631,7 +668,7 @@ async def node_distance_reranker(
 ) -> list[str]:
     # filter out node_uuid center node node uuid
     filtered_uuids = list(filter(lambda node_uuid: node_uuid != center_node_uuid, node_uuids))
-    scores: dict[str, float] = {}
+    scores: dict[str, float] = {center_node_uuid: 0.0}
 
     # Find the shortest path to center node
     query = Query("""
@@ -649,9 +686,13 @@ async def node_distance_reranker(
 
     for result in path_results:
         uuid = result['uuid']
-        score = result['score']
+        score = result['score']
         scores[uuid] = score
 
+    for uuid in filtered_uuids:
+        if uuid not in scores:
+            scores[uuid] = float('inf')
+
     # rerank on shortest distance
     filtered_uuids.sort(key=lambda cur_uuid: scores[cur_uuid])
 
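The two additions to node_distance_reranker close a crash path: the center node now enters the score map with distance 0.0 up front, and any candidate uuid with no path to the center falls back to infinity rather than raising KeyError during the sort. In miniature:

center = 'c'
filtered_uuids = ['a', 'b']
scores = {center: 0.0, 'a': 2.0}  # 'b' has no path to the center node

for uuid in filtered_uuids:
    if uuid not in scores:
        scores[uuid] = float('inf')  # unreachable nodes sort last instead of crashing

filtered_uuids.sort(key=lambda cur_uuid: scores[cur_uuid])
assert filtered_uuids == ['a', 'b']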
graphiti_core/utils/__init__.py
CHANGED
@@ -1,15 +0,0 @@
-from .maintenance import (
-    build_episodic_edges,
-    clear_data,
-    extract_edges,
-    extract_nodes,
-    retrieve_episodes,
-)
-
-__all__ = [
-    'extract_edges',
-    'build_episodic_edges',
-    'extract_nodes',
-    'clear_data',
-    'retrieve_episodes',
-]
graphiti_core/utils/bulk_utils.py
CHANGED
@@ -14,11 +14,10 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
 
-import asyncio
 import logging
 import typing
 from collections import defaultdict
-from datetime import datetime
+from datetime import datetime
 from math import ceil
 
 from neo4j import AsyncDriver, AsyncManagedTransaction
@@ -26,6 +25,7 @@ from numpy import dot, sqrt
 from pydantic import BaseModel
 
 from graphiti_core.edges import Edge, EntityEdge, EpisodicEdge
+from graphiti_core.helpers import semaphore_gather
 from graphiti_core.llm_client import LLMClient
 from graphiti_core.models.edges.edge_db_queries import (
     ENTITY_EDGE_SAVE_BULK,
@@ -37,14 +37,17 @@ from graphiti_core.models.nodes.node_db_queries import (
 )
 from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
 from graphiti_core.search.search_utils import get_relevant_edges, get_relevant_nodes
-from graphiti_core.utils import
+from graphiti_core.utils.datetime_utils import utc_now
 from graphiti_core.utils.maintenance.edge_operations import (
     build_episodic_edges,
     dedupe_edge_list,
     dedupe_extracted_edges,
     extract_edges,
 )
-from graphiti_core.utils.maintenance.graph_data_operations import
+from graphiti_core.utils.maintenance.graph_data_operations import (
+    EPISODE_WINDOW_LEN,
+    retrieve_episodes,
+)
 from graphiti_core.utils.maintenance.node_operations import (
     dedupe_extracted_nodes,
     dedupe_node_list,
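graphiti_core/utils/datetime_utils.py is a new module in this release (+42 lines); only its utc_now import is visible in this diff, and the tails of the two truncated old import lines above did not survive extraction. A minimal sketch of the helper consistent with how it is used below (assumed shape, not the module's actual contents):

from datetime import datetime, timezone

def utc_now() -> datetime:
    # Timezone-aware "now", replacing scattered datetime.now(...) call sites.
    return datetime.now(timezone.utc)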
@@ -68,7 +71,7 @@ class RawEpisode(BaseModel):
 async def retrieve_previous_episodes_bulk(
     driver: AsyncDriver, episodes: list[EpisodicNode]
 ) -> list[tuple[EpisodicNode, list[EpisodicNode]]]:
-    previous_episodes_list = await asyncio.gather(
+    previous_episodes_list = await semaphore_gather(
         *[
             retrieve_episodes(
                 driver, episode.valid_at, last_n=EPISODE_WINDOW_LEN, group_ids=[episode.group_id]
@@ -115,7 +118,7 @@ async def add_nodes_and_edges_bulk_tx(
 async def extract_nodes_and_edges_bulk(
     llm_client: LLMClient, episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]]
 ) -> tuple[list[EntityNode], list[EntityEdge], list[EpisodicEdge]]:
-    extracted_nodes_bulk = await asyncio.gather(
+    extracted_nodes_bulk = await semaphore_gather(
         *[
             extract_nodes(llm_client, episode, previous_episodes)
             for episode, previous_episodes in episode_tuples
@@ -127,7 +130,7 @@ async def extract_nodes_and_edges_bulk(
         [episode[1] for episode in episode_tuples],
     )
 
-    extracted_edges_bulk = await asyncio.gather(
+    extracted_edges_bulk = await semaphore_gather(
         *[
             extract_edges(
                 llm_client,
@@ -168,13 +171,13 @@ async def dedupe_nodes_bulk(
     node_chunks = [nodes[i : i + CHUNK_SIZE] for i in range(0, len(nodes), CHUNK_SIZE)]
 
     existing_nodes_chunks: list[list[EntityNode]] = list(
-        await asyncio.gather(
+        await semaphore_gather(
             *[get_relevant_nodes(driver, node_chunk) for node_chunk in node_chunks]
         )
     )
 
     results: list[tuple[list[EntityNode], dict[str, str]]] = list(
-        await asyncio.gather(
+        await semaphore_gather(
             *[
                 dedupe_extracted_nodes(llm_client, node_chunk, existing_nodes_chunks[i])
                 for i, node_chunk in enumerate(node_chunks)
@@ -202,13 +205,13 @@ async def dedupe_edges_bulk(
     ]
 
     relevant_edges_chunks: list[list[EntityEdge]] = list(
-        await asyncio.gather(
+        await semaphore_gather(
             *[get_relevant_edges(driver, edge_chunk, None, None) for edge_chunk in edge_chunks]
         )
     )
 
     resolved_edge_chunks: list[list[EntityEdge]] = list(
-        await asyncio.gather(
+        await semaphore_gather(
             *[
                 dedupe_extracted_edges(llm_client, edge_chunk, relevant_edges_chunks[i])
                 for i, edge_chunk in enumerate(edge_chunks)
@@ -289,7 +292,9 @@ async def compress_nodes(
             # add both nodes to the shortest chunk
             node_chunks[-1].extend([n, m])
 
-    results = await asyncio.gather(*[dedupe_node_list(llm_client, chunk) for chunk in node_chunks])
+    results = await semaphore_gather(
+        *[dedupe_node_list(llm_client, chunk) for chunk in node_chunks]
+    )
 
     extended_map = dict(uuid_map)
     compressed_nodes: list[EntityNode] = []
@@ -312,7 +317,9 @@ async def compress_edges(llm_client: LLMClient, edges: list[EntityEdge]) -> list[EntityEdge]:
     # We build a map of the edges based on their source and target nodes.
     edge_chunks = chunk_edges_by_nodes(edges)
 
-    results = await asyncio.gather(*[dedupe_edge_list(llm_client, chunk) for chunk in edge_chunks])
+    results = await semaphore_gather(
+        *[dedupe_edge_list(llm_client, chunk) for chunk in edge_chunks]
+    )
 
     compressed_edges: list[EntityEdge] = []
     for edge_chunk in results:
@@ -365,7 +372,7 @@ async def extract_edge_dates_bulk(
         episode.uuid: (episode, previous_episodes) for episode, previous_episodes in episode_pairs
     }
 
-    results = await asyncio.gather(
+    results = await semaphore_gather(
         *[
             extract_edge_dates(
                 llm_client,
@@ -385,7 +392,7 @@ async def extract_edge_dates_bulk(
         edge.valid_at = valid_at
         edge.invalid_at = invalid_at
         if edge.invalid_at:
-            edge.expired_at =
+            edge.expired_at = utc_now()
 
     return edges
 