graphiti-core 0.17.4__py3-none-any.whl → 0.24.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- graphiti_core/cross_encoder/gemini_reranker_client.py +1 -1
- graphiti_core/cross_encoder/openai_reranker_client.py +1 -1
- graphiti_core/decorators.py +110 -0
- graphiti_core/driver/driver.py +62 -2
- graphiti_core/driver/falkordb_driver.py +215 -23
- graphiti_core/driver/graph_operations/graph_operations.py +191 -0
- graphiti_core/driver/kuzu_driver.py +182 -0
- graphiti_core/driver/neo4j_driver.py +61 -8
- graphiti_core/driver/neptune_driver.py +305 -0
- graphiti_core/driver/search_interface/search_interface.py +89 -0
- graphiti_core/edges.py +264 -132
- graphiti_core/embedder/azure_openai.py +10 -3
- graphiti_core/embedder/client.py +2 -1
- graphiti_core/graph_queries.py +114 -101
- graphiti_core/graphiti.py +582 -255
- graphiti_core/graphiti_types.py +2 -0
- graphiti_core/helpers.py +21 -14
- graphiti_core/llm_client/anthropic_client.py +142 -52
- graphiti_core/llm_client/azure_openai_client.py +57 -19
- graphiti_core/llm_client/client.py +83 -21
- graphiti_core/llm_client/config.py +1 -1
- graphiti_core/llm_client/gemini_client.py +75 -57
- graphiti_core/llm_client/openai_base_client.py +94 -50
- graphiti_core/llm_client/openai_client.py +28 -8
- graphiti_core/llm_client/openai_generic_client.py +91 -56
- graphiti_core/models/edges/edge_db_queries.py +259 -35
- graphiti_core/models/nodes/node_db_queries.py +311 -32
- graphiti_core/nodes.py +388 -164
- graphiti_core/prompts/dedupe_edges.py +42 -31
- graphiti_core/prompts/dedupe_nodes.py +56 -39
- graphiti_core/prompts/eval.py +4 -4
- graphiti_core/prompts/extract_edges.py +23 -14
- graphiti_core/prompts/extract_nodes.py +73 -32
- graphiti_core/prompts/prompt_helpers.py +39 -0
- graphiti_core/prompts/snippets.py +29 -0
- graphiti_core/prompts/summarize_nodes.py +23 -25
- graphiti_core/search/search.py +154 -74
- graphiti_core/search/search_config.py +39 -4
- graphiti_core/search/search_filters.py +109 -31
- graphiti_core/search/search_helpers.py +5 -6
- graphiti_core/search/search_utils.py +1360 -473
- graphiti_core/tracer.py +193 -0
- graphiti_core/utils/bulk_utils.py +216 -90
- graphiti_core/utils/datetime_utils.py +13 -0
- graphiti_core/utils/maintenance/community_operations.py +62 -38
- graphiti_core/utils/maintenance/dedup_helpers.py +262 -0
- graphiti_core/utils/maintenance/edge_operations.py +286 -126
- graphiti_core/utils/maintenance/graph_data_operations.py +44 -74
- graphiti_core/utils/maintenance/node_operations.py +320 -158
- graphiti_core/utils/maintenance/temporal_operations.py +11 -3
- graphiti_core/utils/ontology_utils/entity_types_utils.py +1 -1
- graphiti_core/utils/text_utils.py +53 -0
- {graphiti_core-0.17.4.dist-info → graphiti_core-0.24.3.dist-info}/METADATA +221 -87
- graphiti_core-0.24.3.dist-info/RECORD +86 -0
- {graphiti_core-0.17.4.dist-info → graphiti_core-0.24.3.dist-info}/WHEEL +1 -1
- graphiti_core-0.17.4.dist-info/RECORD +0 -77
- /graphiti_core/{utils/maintenance/utils.py → migrations/__init__.py} +0 -0
- {graphiti_core-0.17.4.dist-info → graphiti_core-0.24.3.dist-info}/licenses/LICENSE +0 -0
graphiti_core/llm_client/openai_generic_client.py:

```diff
@@ -17,7 +17,7 @@ limitations under the License.
 import json
 import logging
 import typing
-from typing import ClassVar
+from typing import Any, ClassVar
 
 import openai
 from openai import AsyncOpenAI
```
```diff
@@ -25,7 +25,7 @@ from openai.types.chat import ChatCompletionMessageParam
 from pydantic import BaseModel
 
 from ..prompts.models import Message
-from .client import LLMClient
+from .client import LLMClient, get_extraction_language_instruction
 from .config import DEFAULT_MAX_TOKENS, LLMConfig, ModelSize
 from .errors import RateLimitError, RefusalError
 
```
```diff
@@ -59,15 +59,20 @@ class OpenAIGenericClient(LLMClient):
     MAX_RETRIES: ClassVar[int] = 2
 
     def __init__(
-        self, config: LLMConfig | None = None, cache: bool = False, client: typing.Any = None
+        self,
+        config: LLMConfig | None = None,
+        cache: bool = False,
+        client: typing.Any = None,
+        max_tokens: int = 16384,
     ):
         """
-        Initialize the … (rest of the line truncated by the diff viewer)
+        Initialize the OpenAIGenericClient with the provided configuration, cache setting, and client.
 
         Args:
             config (LLMConfig | None): The configuration for the LLM client, including API key, model, base URL, temperature, and max tokens.
             cache (bool): Whether to use caching for responses. Defaults to False.
             client (Any | None): An optional async client instance to use. If not provided, a new AsyncOpenAI client is created.
+            max_tokens (int): The maximum number of tokens to generate. Defaults to 16384 (16K) for better compatibility with local models.
 
         """
         # removed caching to simplify the `generate_response` override
```
```diff
@@ -79,6 +84,9 @@ class OpenAIGenericClient(LLMClient):
 
         super().__init__(config, cache)
 
+        # Override max_tokens to support higher limits for local models
+        self.max_tokens = max_tokens
+
         if client is None:
             self.client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
         else:
```
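The two hunks above are the user-facing part of this file's rework: `max_tokens` is now an explicit constructor keyword with a 16K default aimed at local models. A minimal usage sketch, not from the diff; the endpoint, key, and model name are placeholders, and the `LLMConfig` fields are assumed from the docstring above:

```python
# Hypothetical usage of the new constructor signature shown above.
from graphiti_core.llm_client.config import LLMConfig
from graphiti_core.llm_client.openai_generic_client import OpenAIGenericClient

config = LLMConfig(
    api_key='sk-placeholder',                # any OpenAI-compatible endpoint
    base_url='http://localhost:11434/v1',    # e.g. a local model server
    model='llama3.1',
)
# max_tokens now defaults to 16384; pass a smaller cap for constrained models.
client = OpenAIGenericClient(config=config, cache=False, max_tokens=8192)
```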
```diff
@@ -99,12 +107,25 @@
             elif m.role == 'system':
                 openai_messages.append({'role': 'system', 'content': m.content})
         try:
+            # Prepare response format
+            response_format: dict[str, Any] = {'type': 'json_object'}
+            if response_model is not None:
+                schema_name = getattr(response_model, '__name__', 'structured_response')
+                json_schema = response_model.model_json_schema()
+                response_format = {
+                    'type': 'json_schema',
+                    'json_schema': {
+                        'name': schema_name,
+                        'schema': json_schema,
+                    },
+                }
+
             response = await self.client.chat.completions.create(
                 model=self.model or DEFAULT_MODEL,
                 messages=openai_messages,
                 temperature=self.temperature,
                 max_tokens=self.max_tokens,
-                response_format={'type': 'json_object'},
+                response_format=response_format,  # type: ignore[arg-type]
             )
             result = response.choices[0].message.content or ''
             return json.loads(result)
```
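The hunk above replaces the bare `{'type': 'json_object'}` with a strict JSON-schema payload whenever a `response_model` is supplied. A standalone sketch of how a Pydantic model maps to that payload; `ExtractedEntities` is an illustrative model, not from the diff:

```python
# Illustrative: building the json_schema response_format the branch above sends.
from pydantic import BaseModel


class ExtractedEntities(BaseModel):  # hypothetical response model
    names: list[str]


response_format = {
    'type': 'json_schema',
    'json_schema': {
        # Same fallback the new code uses when a model lacks __name__.
        'name': getattr(ExtractedEntities, '__name__', 'structured_response'),
        'schema': ExtractedEntities.model_json_schema(),
    },
}
```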
```diff
@@ -120,60 +141,74 @@
         response_model: type[BaseModel] | None = None,
         max_tokens: int | None = None,
         model_size: ModelSize = ModelSize.medium,
+        group_id: str | None = None,
+        prompt_name: str | None = None,
     ) -> dict[str, typing.Any]:
         if max_tokens is None:
             max_tokens = self.max_tokens
 
-        retry_count = 0
-        last_error = None
-
-        if response_model is not None:
-            serialized_model = json.dumps(response_model.model_json_schema())
-            messages[
-                -1
-            ].content += (
-                f'\n\nRespond with a JSON object in the following format:\n\n{serialized_model}'
-            )
-
         # Add multilingual extraction instructions
-        messages[0].content += … (rest of the line truncated by the diff viewer)
-… (old lines 140-158 elided by the diff viewer)
+        messages[0].content += get_extraction_language_instruction(group_id)
+
+        # Wrap entire operation in tracing span
+        with self.tracer.start_span('llm.generate') as span:
+            attributes = {
+                'llm.provider': 'openai',
+                'model.size': model_size.value,
+                'max_tokens': max_tokens,
+            }
+            if prompt_name:
+                attributes['prompt.name'] = prompt_name
+            span.add_attributes(attributes)
+
+            retry_count = 0
+            last_error = None
+
+            while retry_count <= self.MAX_RETRIES:
+                try:
+                    response = await self._generate_response(
+                        messages, response_model, max_tokens=max_tokens, model_size=model_size
+                    )
+                    return response
+                except (RateLimitError, RefusalError):
+                    # These errors should not trigger retries
+                    span.set_status('error', str(last_error))
                     raise
-… (old lines 160-179 elided by the diff viewer)
+                except (
+                    openai.APITimeoutError,
+                    openai.APIConnectionError,
+                    openai.InternalServerError,
+                ):
+                    # Let OpenAI's client handle these retries
+                    span.set_status('error', str(last_error))
+                    raise
+                except Exception as e:
+                    last_error = e
+
+                    # Don't retry if we've hit the max retries
+                    if retry_count >= self.MAX_RETRIES:
+                        logger.error(f'Max retries ({self.MAX_RETRIES}) exceeded. Last error: {e}')
+                        span.set_status('error', str(e))
+                        span.record_exception(e)
+                        raise
+
+                    retry_count += 1
+
+                    # Construct a detailed error message for the LLM
+                    error_context = (
+                        f'The previous response attempt was invalid. '
+                        f'Error type: {e.__class__.__name__}. '
+                        f'Error details: {str(e)}. '
+                        f'Please try again with a valid response, ensuring the output matches '
+                        f'the expected format and constraints.'
+                    )
+
+                    error_message = Message(role='user', content=error_context)
+                    messages.append(error_message)
+                    logger.warning(
+                        f'Retrying after application error (attempt {retry_count}/{self.MAX_RETRIES}): {e}'
+                    )
+
+            # If we somehow get here, raise the last error
+            span.set_status('error', str(last_error))
+            raise last_error or Exception('Max retries exceeded with no specific error')
```
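With this hunk, `generate_response` gains `group_id` (passed to `get_extraction_language_instruction`) and `prompt_name` (recorded on the `llm.generate` tracing span), and schema injection into the last message is replaced by the strict `response_format` shown earlier. A hedged call sketch; `client` and `ExtractedEntities` come from the hypothetical sketches above, and the message content and identifiers are placeholders:

```python
# Sketch only: exercising the new keyword arguments on generate_response.
import asyncio

from graphiti_core.prompts.models import Message


async def main() -> None:
    result = await client.generate_response(
        messages=[Message(role='user', content='Extract entities from: ...')],
        response_model=ExtractedEntities,
        group_id='tenant-42',         # drives the language-instruction lookup
        prompt_name='extract_nodes',  # attached to the 'llm.generate' span
    )
    print(result)


asyncio.run(main())
```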
graphiti_core/models/edges/edge_db_queries.py:

```diff
@@ -14,43 +14,267 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
 
+from graphiti_core.driver.driver import GraphProvider
+
 EPISODIC_EDGE_SAVE = """
-    … (old lines 18-24 elided by the diff viewer)
-    UNWIND $episodic_edges AS edge
-    MATCH (episode:Episodic {uuid: edge.source_node_uuid})
-    MATCH (node:Entity {uuid: edge.target_node_uuid})
-    MERGE (episode)-[r:MENTIONS {uuid: edge.uuid}]->(node)
-    SET r = {uuid: edge.uuid, group_id: edge.group_id, created_at: edge.created_at}
-    RETURN r.uuid AS uuid
+    MATCH (episode:Episodic {uuid: $episode_uuid})
+    MATCH (node:Entity {uuid: $entity_uuid})
+    MERGE (episode)-[e:MENTIONS {uuid: $uuid}]->(node)
+    SET
+        e.group_id = $group_id,
+        e.created_at = $created_at
+    RETURN e.uuid AS uuid
 """
 
-… (old lines 33-48 elided by the diff viewer)
+
+def get_episodic_edge_save_bulk_query(provider: GraphProvider) -> str:
+    if provider == GraphProvider.KUZU:
+        return """
+            MATCH (episode:Episodic {uuid: $source_node_uuid})
+            MATCH (node:Entity {uuid: $target_node_uuid})
+            MERGE (episode)-[e:MENTIONS {uuid: $uuid}]->(node)
+            SET
+                e.group_id = $group_id,
+                e.created_at = $created_at
+            RETURN e.uuid AS uuid
+        """
+
+    return """
+        UNWIND $episodic_edges AS edge
+        MATCH (episode:Episodic {uuid: edge.source_node_uuid})
+        MATCH (node:Entity {uuid: edge.target_node_uuid})
+        MERGE (episode)-[e:MENTIONS {uuid: edge.uuid}]->(node)
+        SET
+            e.group_id = edge.group_id,
+            e.created_at = edge.created_at
+        RETURN e.uuid AS uuid
+    """
+
+
+EPISODIC_EDGE_RETURN = """
+    e.uuid AS uuid,
+    e.group_id AS group_id,
+    n.uuid AS source_node_uuid,
+    m.uuid AS target_node_uuid,
+    e.created_at AS created_at
 """
 
-… (old lines 51-56 elided by the diff viewer)
+
+def get_entity_edge_save_query(provider: GraphProvider, has_aoss: bool = False) -> str:
+    match provider:
+        case GraphProvider.FALKORDB:
+            return """
+                MATCH (source:Entity {uuid: $edge_data.source_uuid})
+                MATCH (target:Entity {uuid: $edge_data.target_uuid})
+                MERGE (source)-[e:RELATES_TO {uuid: $edge_data.uuid}]->(target)
+                SET e = $edge_data
+                SET e.fact_embedding = vecf32($edge_data.fact_embedding)
+                RETURN e.uuid AS uuid
+            """
+        case GraphProvider.NEPTUNE:
+            return """
+                MATCH (source:Entity {uuid: $edge_data.source_uuid})
+                MATCH (target:Entity {uuid: $edge_data.target_uuid})
+                MERGE (source)-[e:RELATES_TO {uuid: $edge_data.uuid}]->(target)
+                SET e = removeKeyFromMap(removeKeyFromMap($edge_data, "fact_embedding"), "episodes")
+                SET e.fact_embedding = join([x IN coalesce($edge_data.fact_embedding, []) | toString(x) ], ",")
+                SET e.episodes = join($edge_data.episodes, ",")
+                RETURN $edge_data.uuid AS uuid
+            """
+        case GraphProvider.KUZU:
+            return """
+                MATCH (source:Entity {uuid: $source_uuid})
+                MATCH (target:Entity {uuid: $target_uuid})
+                MERGE (source)-[:RELATES_TO]->(e:RelatesToNode_ {uuid: $uuid})-[:RELATES_TO]->(target)
+                SET
+                    e.group_id = $group_id,
+                    e.created_at = $created_at,
+                    e.name = $name,
+                    e.fact = $fact,
+                    e.fact_embedding = $fact_embedding,
+                    e.episodes = $episodes,
+                    e.expired_at = $expired_at,
+                    e.valid_at = $valid_at,
+                    e.invalid_at = $invalid_at,
+                    e.attributes = $attributes
+                RETURN e.uuid AS uuid
+            """
+        case _:  # Neo4j
+            save_embedding_query = (
+                """WITH e CALL db.create.setRelationshipVectorProperty(e, "fact_embedding", $edge_data.fact_embedding)"""
+                if not has_aoss
+                else ''
+            )
+            return (
+                (
+                    """
+                    MATCH (source:Entity {uuid: $edge_data.source_uuid})
+                    MATCH (target:Entity {uuid: $edge_data.target_uuid})
+                    MERGE (source)-[e:RELATES_TO {uuid: $edge_data.uuid}]->(target)
+                    SET e = $edge_data
+                    """
+                    + save_embedding_query
+                )
+                + """
+                RETURN e.uuid AS uuid
+                """
+            )
+
+
+def get_entity_edge_save_bulk_query(provider: GraphProvider, has_aoss: bool = False) -> str:
+    match provider:
+        case GraphProvider.FALKORDB:
+            return """
+                UNWIND $entity_edges AS edge
+                MATCH (source:Entity {uuid: edge.source_node_uuid})
+                MATCH (target:Entity {uuid: edge.target_node_uuid})
+                MERGE (source)-[r:RELATES_TO {uuid: edge.uuid}]->(target)
+                SET r = edge
+                SET r.fact_embedding = vecf32(edge.fact_embedding)
+                WITH r, edge
+                RETURN edge.uuid AS uuid
+            """
+        case GraphProvider.NEPTUNE:
+            return """
+                UNWIND $entity_edges AS edge
+                MATCH (source:Entity {uuid: edge.source_node_uuid})
+                MATCH (target:Entity {uuid: edge.target_node_uuid})
+                MERGE (source)-[r:RELATES_TO {uuid: edge.uuid}]->(target)
+                SET r = removeKeyFromMap(removeKeyFromMap(edge, "fact_embedding"), "episodes")
+                SET r.fact_embedding = join([x IN coalesce(edge.fact_embedding, []) | toString(x) ], ",")
+                SET r.episodes = join(edge.episodes, ",")
+                RETURN edge.uuid AS uuid
+            """
+        case GraphProvider.KUZU:
+            return """
+                MATCH (source:Entity {uuid: $source_node_uuid})
+                MATCH (target:Entity {uuid: $target_node_uuid})
+                MERGE (source)-[:RELATES_TO]->(e:RelatesToNode_ {uuid: $uuid})-[:RELATES_TO]->(target)
+                SET
+                    e.group_id = $group_id,
+                    e.created_at = $created_at,
+                    e.name = $name,
+                    e.fact = $fact,
+                    e.fact_embedding = $fact_embedding,
+                    e.episodes = $episodes,
+                    e.expired_at = $expired_at,
+                    e.valid_at = $valid_at,
+                    e.invalid_at = $invalid_at,
+                    e.attributes = $attributes
+                RETURN e.uuid AS uuid
+            """
+        case _:
+            save_embedding_query = (
+                'WITH e, edge CALL db.create.setRelationshipVectorProperty(e, "fact_embedding", edge.fact_embedding)'
+                if not has_aoss
+                else ''
+            )
+            return (
+                """
+                UNWIND $entity_edges AS edge
+                MATCH (source:Entity {uuid: edge.source_node_uuid})
+                MATCH (target:Entity {uuid: edge.target_node_uuid})
+                MERGE (source)-[e:RELATES_TO {uuid: edge.uuid}]->(target)
+                SET e = edge
+                """
+                + save_embedding_query
+                + """
+                RETURN edge.uuid AS uuid
+                """
+            )
+
+
+def get_entity_edge_return_query(provider: GraphProvider) -> str:
+    # `fact_embedding` is not returned by default and must be manually loaded using `load_fact_embedding()`.
+
+    if provider == GraphProvider.NEPTUNE:
+        return """
+            e.uuid AS uuid,
+            n.uuid AS source_node_uuid,
+            m.uuid AS target_node_uuid,
+            e.group_id AS group_id,
+            e.name AS name,
+            e.fact AS fact,
+            split(e.episodes, ',') AS episodes,
+            e.created_at AS created_at,
+            e.expired_at AS expired_at,
+            e.valid_at AS valid_at,
+            e.invalid_at AS invalid_at,
+            properties(e) AS attributes
+        """
+
+    return """
+        e.uuid AS uuid,
+        n.uuid AS source_node_uuid,
+        m.uuid AS target_node_uuid,
+        e.group_id AS group_id,
+        e.created_at AS created_at,
+        e.name AS name,
+        e.fact AS fact,
+        e.episodes AS episodes,
+        e.expired_at AS expired_at,
+        e.valid_at AS valid_at,
+        e.invalid_at AS invalid_at,
+    """ + (
+        'e.attributes AS attributes'
+        if provider == GraphProvider.KUZU
+        else 'properties(e) AS attributes'
+    )
+
+
+def get_community_edge_save_query(provider: GraphProvider) -> str:
+    match provider:
+        case GraphProvider.FALKORDB:
+            return """
+                MATCH (community:Community {uuid: $community_uuid})
+                MATCH (node {uuid: $entity_uuid})
+                MERGE (community)-[e:HAS_MEMBER {uuid: $uuid}]->(node)
+                SET e = {uuid: $uuid, group_id: $group_id, created_at: $created_at}
+                RETURN e.uuid AS uuid
+            """
+        case GraphProvider.NEPTUNE:
+            return """
+                MATCH (community:Community {uuid: $community_uuid})
+                MATCH (node {uuid: $entity_uuid})
+                WHERE node:Entity OR node:Community
+                MERGE (community)-[r:HAS_MEMBER {uuid: $uuid}]->(node)
+                SET r.uuid= $uuid
+                SET r.group_id= $group_id
+                SET r.created_at= $created_at
+                RETURN r.uuid AS uuid
+            """
+        case GraphProvider.KUZU:
+            return """
+                MATCH (community:Community {uuid: $community_uuid})
+                MATCH (node:Entity {uuid: $entity_uuid})
+                MERGE (community)-[e:HAS_MEMBER {uuid: $uuid}]->(node)
+                SET
+                    e.group_id = $group_id,
+                    e.created_at = $created_at
+                RETURN e.uuid AS uuid
+                UNION
+                MATCH (community:Community {uuid: $community_uuid})
+                MATCH (node:Community {uuid: $entity_uuid})
+                MERGE (community)-[e:HAS_MEMBER {uuid: $uuid}]->(node)
+                SET
+                    e.group_id = $group_id,
+                    e.created_at = $created_at
+                RETURN e.uuid AS uuid
+            """
+        case _:  # Neo4j
+            return """
+                MATCH (community:Community {uuid: $community_uuid})
+                MATCH (node:Entity | Community {uuid: $entity_uuid})
+                MERGE (community)-[e:HAS_MEMBER {uuid: $uuid}]->(node)
+                SET e = {uuid: $uuid, group_id: $group_id, created_at: $created_at}
+                RETURN e.uuid AS uuid
+            """
+
+
+COMMUNITY_EDGE_RETURN = """
+    e.uuid AS uuid,
+    e.group_id AS group_id,
+    n.uuid AS source_node_uuid,
+    m.uuid AS target_node_uuid,
+    e.created_at AS created_at
+"""
```