graphiti-core 0.2.3__py3-none-any.whl → 0.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of graphiti-core might be problematic. Click here for more details.
- graphiti_core/edges.py +68 -29
- graphiti_core/errors.py +43 -0
- graphiti_core/graphiti.py +51 -26
- graphiti_core/helpers.py +16 -0
- graphiti_core/llm_client/__init__.py +2 -1
- graphiti_core/llm_client/anthropic_client.py +9 -1
- graphiti_core/llm_client/client.py +17 -10
- graphiti_core/llm_client/errors.py +23 -0
- graphiti_core/llm_client/groq_client.py +4 -0
- graphiti_core/llm_client/openai_client.py +4 -0
- graphiti_core/llm_client/utils.py +17 -1
- graphiti_core/nodes.py +144 -20
- graphiti_core/prompts/extract_edge_dates.py +16 -0
- graphiti_core/prompts/extract_nodes.py +43 -1
- graphiti_core/prompts/lib.py +6 -0
- graphiti_core/prompts/summarize_nodes.py +79 -0
- graphiti_core/py.typed +1 -0
- graphiti_core/search/search.py +176 -79
- graphiti_core/search/search_config.py +81 -0
- graphiti_core/search/search_config_recipes.py +84 -0
- graphiti_core/search/search_utils.py +259 -152
- graphiti_core/utils/maintenance/community_operations.py +155 -0
- graphiti_core/utils/maintenance/edge_operations.py +20 -2
- graphiti_core/utils/maintenance/graph_data_operations.py +11 -0
- graphiti_core/utils/maintenance/node_operations.py +26 -1
- {graphiti_core-0.2.3.dist-info → graphiti_core-0.3.1.dist-info}/METADATA +8 -2
- graphiti_core-0.3.1.dist-info/RECORD +43 -0
- graphiti_core-0.2.3.dist-info/RECORD +0 -36
- {graphiti_core-0.2.3.dist-info → graphiti_core-0.3.1.dist-info}/LICENSE +0 -0
- {graphiti_core-0.2.3.dist-info → graphiti_core-0.3.1.dist-info}/WHEEL +0 -0
|
@@ -1,3 +1,19 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Copyright 2024, Zep Software, Inc.
|
|
3
|
+
|
|
4
|
+
Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
you may not use this file except in compliance with the License.
|
|
6
|
+
You may obtain a copy of the License at
|
|
7
|
+
|
|
8
|
+
http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
|
|
10
|
+
Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
See the License for the specific language governing permissions and
|
|
14
|
+
limitations under the License.
|
|
15
|
+
"""
|
|
16
|
+
|
|
1
17
|
import asyncio
|
|
2
18
|
import logging
|
|
3
19
|
import re
|
|
@@ -7,7 +23,13 @@ from time import time
|
|
|
7
23
|
from neo4j import AsyncDriver, Query
|
|
8
24
|
|
|
9
25
|
from graphiti_core.edges import EntityEdge, get_entity_edge_from_record
|
|
10
|
-
from graphiti_core.nodes import
|
|
26
|
+
from graphiti_core.nodes import (
|
|
27
|
+
CommunityNode,
|
|
28
|
+
EntityNode,
|
|
29
|
+
EpisodicNode,
|
|
30
|
+
get_community_node_from_record,
|
|
31
|
+
get_entity_node_from_record,
|
|
32
|
+
)
|
|
11
33
|
|
|
12
34
|
logger = logging.getLogger(__name__)
|
|
13
35
|
|
|
@@ -35,6 +57,128 @@ async def get_mentioned_nodes(driver: AsyncDriver, episodes: list[EpisodicNode])
|
|
|
35
57
|
return nodes
|
|
36
58
|
|
|
37
59
|
|
|
60
|
+
async def edge_fulltext_search(
    driver: AsyncDriver,
    query: str,
    source_node_uuid: str | None,
    target_node_uuid: str | None,
    group_ids: list[str | None] | None = None,
    limit=RELEVANT_SCHEMA_LIMIT,
) -> list[EntityEdge]:
    """Search the "name_and_fact" relationship fulltext index for edges.

    Args:
        driver: Neo4j async driver used to run the query.
        query: Free-text search string. Every character that is not a word
            character or whitespace is removed and '~' is appended before the
            string is handed to the index.
        source_node_uuid: When not None, the `n` endpoint of the matched
            relationship must have this uuid.
        target_node_uuid: When not None, the `m` endpoint of the matched
            relationship must have this uuid.
        group_ids: When None, only edges whose group_id is null are returned;
            otherwise only edges whose group_id is in this list.
        limit: Maximum number of edges to return.

    Returns:
        The matching edges, ordered by descending index score.
    """
    # Pin an endpoint of the (undirected) MATCH pattern only when its uuid was
    # supplied. This replaces four otherwise identical copies of the query.
    source_pattern = (
        '(n:Entity {uuid: $source_uuid})' if source_node_uuid is not None else '(n:Entity)'
    )
    target_pattern = (
        '(m:Entity {uuid: $target_uuid})' if target_node_uuid is not None else '(m:Entity)'
    )

    # fulltext search over facts
    # The group filter is always applied to the relationship (r.group_id), so
    # every endpoint combination agrees with the group_id that is returned.
    cypher_query = Query(
        """
        CALL db.index.fulltext.queryRelationships("name_and_fact", $query)
        YIELD relationship AS rel, score
        MATCH """
        + source_pattern
        + '-[r {uuid: rel.uuid}]-'
        + target_pattern
        + """
        WHERE CASE
            WHEN $group_ids IS NULL THEN r.group_id IS NULL
            ELSE r.group_id IN $group_ids
        END
        RETURN
            r.uuid AS uuid,
            r.group_id AS group_id,
            n.uuid AS source_node_uuid,
            m.uuid AS target_node_uuid,
            r.created_at AS created_at,
            r.name AS name,
            r.fact AS fact,
            r.fact_embedding AS fact_embedding,
            r.episodes AS episodes,
            r.expired_at AS expired_at,
            r.valid_at AS valid_at,
            r.invalid_at AS invalid_at
        ORDER BY score DESC LIMIT $limit
        """
    )

    fuzzy_query = re.sub(r'[^\w\s]', '', query) + '~'

    records, _, _ = await driver.execute_query(
        cypher_query,
        query=fuzzy_query,
        source_uuid=source_node_uuid,
        target_uuid=target_node_uuid,
        group_ids=group_ids,
        limit=limit,
    )

    return [get_entity_edge_from_record(record) for record in records]
|
|
180
|
+
|
|
181
|
+
|
|
38
182
|
async def edge_similarity_search(
|
|
39
183
|
driver: AsyncDriver,
|
|
40
184
|
search_vector: list[float],
|
|
@@ -43,13 +187,15 @@ async def edge_similarity_search(
|
|
|
43
187
|
group_ids: list[str | None] | None = None,
|
|
44
188
|
limit: int = RELEVANT_SCHEMA_LIMIT,
|
|
45
189
|
) -> list[EntityEdge]:
|
|
46
|
-
group_ids = group_ids if group_ids is not None else [None]
|
|
47
190
|
# vector similarity search over embedded facts
|
|
48
191
|
query = Query("""
|
|
49
192
|
CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector)
|
|
50
193
|
YIELD relationship AS rel, score
|
|
51
194
|
MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid})
|
|
52
|
-
WHERE
|
|
195
|
+
WHERE CASE
|
|
196
|
+
WHEN $group_ids IS NULL THEN r.group_id IS NULL
|
|
197
|
+
ELSE r.group_id IN $group_ids
|
|
198
|
+
END
|
|
53
199
|
RETURN
|
|
54
200
|
r.uuid AS uuid,
|
|
55
201
|
r.group_id AS group_id,
|
|
@@ -71,7 +217,10 @@ async def edge_similarity_search(
|
|
|
71
217
|
CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector)
|
|
72
218
|
YIELD relationship AS rel, score
|
|
73
219
|
MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity)
|
|
74
|
-
WHERE
|
|
220
|
+
WHERE CASE
|
|
221
|
+
WHEN $group_ids IS NULL THEN r.group_id IS NULL
|
|
222
|
+
ELSE r.group_id IN $group_ids
|
|
223
|
+
END
|
|
75
224
|
RETURN
|
|
76
225
|
r.uuid AS uuid,
|
|
77
226
|
r.group_id AS group_id,
|
|
@@ -92,7 +241,10 @@ async def edge_similarity_search(
|
|
|
92
241
|
CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector)
|
|
93
242
|
YIELD relationship AS rel, score
|
|
94
243
|
MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid})
|
|
95
|
-
WHERE
|
|
244
|
+
WHERE CASE
|
|
245
|
+
WHEN $group_ids IS NULL THEN r.group_id IS NULL
|
|
246
|
+
ELSE r.group_id IN $group_ids
|
|
247
|
+
END
|
|
96
248
|
RETURN
|
|
97
249
|
r.uuid AS uuid,
|
|
98
250
|
r.group_id AS group_id,
|
|
@@ -113,7 +265,10 @@ async def edge_similarity_search(
|
|
|
113
265
|
CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector)
|
|
114
266
|
YIELD relationship AS rel, score
|
|
115
267
|
MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity)
|
|
116
|
-
WHERE
|
|
268
|
+
WHERE CASE
|
|
269
|
+
WHEN $group_ids IS NULL THEN r.group_id IS NULL
|
|
270
|
+
ELSE r.group_id IN $group_ids
|
|
271
|
+
END
|
|
117
272
|
RETURN
|
|
118
273
|
r.uuid AS uuid,
|
|
119
274
|
r.group_id AS group_id,
|
|
@@ -144,9 +299,44 @@ async def edge_similarity_search(
|
|
|
144
299
|
return edges
|
|
145
300
|
|
|
146
301
|
|
|
147
|
-
async def
|
|
148
|
-
search_vector: list[float],
|
|
302
|
+
async def node_fulltext_search(
    driver: AsyncDriver,
    query: str,
    group_ids: list[str | None] | None = None,
    limit=RELEVANT_SCHEMA_LIMIT,
) -> list[EntityNode]:
    """Search the "name_and_summary" node fulltext index.

    Args:
        driver: Neo4j async driver used to run the query.
        query: Free-text search string; characters other than word characters
            and whitespace are removed and '~' is appended before searching.
        group_ids: When None, only nodes whose group_id is null are returned;
            otherwise only nodes whose group_id is in this list.
        limit: Maximum number of nodes to return.

    Returns:
        Matching nodes built with get_entity_node_from_record, ordered by
        descending index score.
    """
    # BM25 search to get top nodes
    sanitized_text = re.sub(r'[^\w\s]', '', query)
    index_query = sanitized_text + '~'

    cypher = """
        CALL db.index.fulltext.queryNodes("name_and_summary", $query)
        YIELD node AS n, score
        WHERE CASE
            WHEN $group_ids IS NULL THEN n.group_id IS NULL
            ELSE n.group_id IN $group_ids
        END
        RETURN
            n.uuid AS uuid,
            n.group_id AS group_id,
            n.name AS name,
            n.name_embedding AS name_embedding,
            n.created_at AS created_at,
            n.summary AS summary
        ORDER BY score DESC
        LIMIT $limit
        """

    rows, _, _ = await driver.execute_query(
        cypher,
        query=index_query,
        group_ids=group_ids,
        limit=limit,
    )

    return [get_entity_node_from_record(row) for row in rows]
|
|
335
|
+
|
|
336
|
+
|
|
337
|
+
async def node_similarity_search(
|
|
338
|
+
driver: AsyncDriver,
|
|
339
|
+
search_vector: list[float],
|
|
150
340
|
group_ids: list[str | None] | None = None,
|
|
151
341
|
limit=RELEVANT_SCHEMA_LIMIT,
|
|
152
342
|
) -> list[EntityNode]:
|
|
@@ -176,28 +366,28 @@ async def entity_similarity_search(
|
|
|
176
366
|
return nodes
|
|
177
367
|
|
|
178
368
|
|
|
179
|
-
async def
|
|
180
|
-
query: str,
|
|
369
|
+
async def community_fulltext_search(
|
|
181
370
|
driver: AsyncDriver,
|
|
371
|
+
query: str,
|
|
182
372
|
group_ids: list[str | None] | None = None,
|
|
183
373
|
limit=RELEVANT_SCHEMA_LIMIT,
|
|
184
|
-
) -> list[
|
|
374
|
+
) -> list[CommunityNode]:
|
|
185
375
|
group_ids = group_ids if group_ids is not None else [None]
|
|
186
376
|
|
|
187
|
-
# BM25 search to get top
|
|
377
|
+
# BM25 search to get top communities
|
|
188
378
|
fuzzy_query = re.sub(r'[^\w\s]', '', query) + '~'
|
|
189
379
|
records, _, _ = await driver.execute_query(
|
|
190
380
|
"""
|
|
191
|
-
CALL db.index.fulltext.queryNodes("
|
|
192
|
-
YIELD node AS
|
|
193
|
-
MATCH (
|
|
381
|
+
CALL db.index.fulltext.queryNodes("community_name", $query)
|
|
382
|
+
YIELD node AS comm, score
|
|
383
|
+
MATCH (comm WHERE comm.group_id in $group_ids)
|
|
194
384
|
RETURN
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
385
|
+
comm.uuid AS uuid,
|
|
386
|
+
comm.group_id AS group_id,
|
|
387
|
+
comm.name AS name,
|
|
388
|
+
comm.name_embedding AS name_embedding,
|
|
389
|
+
comm.created_at AS created_at,
|
|
390
|
+
comm.summary AS summary
|
|
201
391
|
ORDER BY score DESC
|
|
202
392
|
LIMIT $limit
|
|
203
393
|
""",
|
|
@@ -205,121 +395,41 @@ async def entity_fulltext_search(
|
|
|
205
395
|
group_ids=group_ids,
|
|
206
396
|
limit=limit,
|
|
207
397
|
)
|
|
208
|
-
|
|
398
|
+
communities = [get_community_node_from_record(record) for record in records]
|
|
209
399
|
|
|
210
|
-
return
|
|
400
|
+
return communities
|
|
211
401
|
|
|
212
402
|
|
|
213
|
-
async def
|
|
403
|
+
async def community_similarity_search(
|
|
214
404
|
driver: AsyncDriver,
|
|
215
|
-
|
|
216
|
-
source_node_uuid: str | None,
|
|
217
|
-
target_node_uuid: str | None,
|
|
405
|
+
search_vector: list[float],
|
|
218
406
|
group_ids: list[str | None] | None = None,
|
|
219
407
|
limit=RELEVANT_SCHEMA_LIMIT,
|
|
220
|
-
) -> list[
|
|
408
|
+
) -> list[CommunityNode]:
|
|
221
409
|
group_ids = group_ids if group_ids is not None else [None]
|
|
222
410
|
|
|
223
|
-
#
|
|
224
|
-
cypher_query = Query("""
|
|
225
|
-
CALL db.index.fulltext.queryRelationships("name_and_fact", $query)
|
|
226
|
-
YIELD relationship AS rel, score
|
|
227
|
-
MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid})
|
|
228
|
-
WHERE r.group_id IN $group_ids
|
|
229
|
-
RETURN
|
|
230
|
-
r.uuid AS uuid,
|
|
231
|
-
r.group_id AS group_id,
|
|
232
|
-
n.uuid AS source_node_uuid,
|
|
233
|
-
m.uuid AS target_node_uuid,
|
|
234
|
-
r.created_at AS created_at,
|
|
235
|
-
r.name AS name,
|
|
236
|
-
r.fact AS fact,
|
|
237
|
-
r.fact_embedding AS fact_embedding,
|
|
238
|
-
r.episodes AS episodes,
|
|
239
|
-
r.expired_at AS expired_at,
|
|
240
|
-
r.valid_at AS valid_at,
|
|
241
|
-
r.invalid_at AS invalid_at
|
|
242
|
-
ORDER BY score DESC LIMIT $limit
|
|
243
|
-
""")
|
|
244
|
-
|
|
245
|
-
if source_node_uuid is None and target_node_uuid is None:
|
|
246
|
-
cypher_query = Query("""
|
|
247
|
-
CALL db.index.fulltext.queryRelationships("name_and_fact", $query)
|
|
248
|
-
YIELD relationship AS rel, score
|
|
249
|
-
MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity)
|
|
250
|
-
WHERE r.group_id IN $group_ids
|
|
251
|
-
RETURN
|
|
252
|
-
r.uuid AS uuid,
|
|
253
|
-
r.group_id AS group_id,
|
|
254
|
-
n.uuid AS source_node_uuid,
|
|
255
|
-
m.uuid AS target_node_uuid,
|
|
256
|
-
r.created_at AS created_at,
|
|
257
|
-
r.name AS name,
|
|
258
|
-
r.fact AS fact,
|
|
259
|
-
r.fact_embedding AS fact_embedding,
|
|
260
|
-
r.episodes AS episodes,
|
|
261
|
-
r.expired_at AS expired_at,
|
|
262
|
-
r.valid_at AS valid_at,
|
|
263
|
-
r.invalid_at AS invalid_at
|
|
264
|
-
ORDER BY score DESC LIMIT $limit
|
|
265
|
-
""")
|
|
266
|
-
elif source_node_uuid is None:
|
|
267
|
-
cypher_query = Query("""
|
|
268
|
-
CALL db.index.fulltext.queryRelationships("name_and_fact", $query)
|
|
269
|
-
YIELD relationship AS rel, score
|
|
270
|
-
MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid})
|
|
271
|
-
WHERE r.group_id IN $group_ids
|
|
272
|
-
RETURN
|
|
273
|
-
r.uuid AS uuid,
|
|
274
|
-
r.group_id AS group_id,
|
|
275
|
-
n.uuid AS source_node_uuid,
|
|
276
|
-
m.uuid AS target_node_uuid,
|
|
277
|
-
r.created_at AS created_at,
|
|
278
|
-
r.name AS name,
|
|
279
|
-
r.fact AS fact,
|
|
280
|
-
r.fact_embedding AS fact_embedding,
|
|
281
|
-
r.episodes AS episodes,
|
|
282
|
-
r.expired_at AS expired_at,
|
|
283
|
-
r.valid_at AS valid_at,
|
|
284
|
-
r.invalid_at AS invalid_at
|
|
285
|
-
ORDER BY score DESC LIMIT $limit
|
|
286
|
-
""")
|
|
287
|
-
elif target_node_uuid is None:
|
|
288
|
-
cypher_query = Query("""
|
|
289
|
-
CALL db.index.fulltext.queryRelationships("name_and_fact", $query)
|
|
290
|
-
YIELD relationship AS rel, score
|
|
291
|
-
MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity)
|
|
292
|
-
WHERE r.group_id IN $group_ids
|
|
293
|
-
RETURN
|
|
294
|
-
r.uuid AS uuid,
|
|
295
|
-
r.group_id AS group_id,
|
|
296
|
-
n.uuid AS source_node_uuid,
|
|
297
|
-
m.uuid AS target_node_uuid,
|
|
298
|
-
r.created_at AS created_at,
|
|
299
|
-
r.name AS name,
|
|
300
|
-
r.fact AS fact,
|
|
301
|
-
r.fact_embedding AS fact_embedding,
|
|
302
|
-
r.episodes AS episodes,
|
|
303
|
-
r.expired_at AS expired_at,
|
|
304
|
-
r.valid_at AS valid_at,
|
|
305
|
-
r.invalid_at AS invalid_at
|
|
306
|
-
ORDER BY score DESC LIMIT $limit
|
|
307
|
-
""")
|
|
308
|
-
|
|
309
|
-
fuzzy_query = re.sub(r'[^\w\s]', '', query) + '~'
|
|
310
|
-
|
|
411
|
+
# vector similarity search over community names
|
|
311
412
|
records, _, _ = await driver.execute_query(
|
|
312
|
-
|
|
313
|
-
|
|
314
|
-
|
|
315
|
-
|
|
413
|
+
"""
|
|
414
|
+
CALL db.index.vector.queryNodes("community_name_embedding", $limit, $search_vector)
|
|
415
|
+
YIELD node AS comm, score
|
|
416
|
+
MATCH (comm WHERE comm.group_id IN $group_ids)
|
|
417
|
+
RETURN
|
|
418
|
+
comm.uuid As uuid,
|
|
419
|
+
comm.group_id AS group_id,
|
|
420
|
+
comm.name AS name,
|
|
421
|
+
comm.name_embedding AS name_embedding,
|
|
422
|
+
comm.created_at AS created_at,
|
|
423
|
+
comm.summary AS summary
|
|
424
|
+
ORDER BY score DESC
|
|
425
|
+
""",
|
|
426
|
+
search_vector=search_vector,
|
|
316
427
|
group_ids=group_ids,
|
|
317
428
|
limit=limit,
|
|
318
429
|
)
|
|
430
|
+
communities = [get_community_node_from_record(record) for record in records]
|
|
319
431
|
|
|
320
|
-
|
|
321
|
-
|
|
322
|
-
return edges
|
|
432
|
+
return communities
|
|
323
433
|
|
|
324
434
|
|
|
325
435
|
async def hybrid_node_search(
|
|
@@ -371,8 +481,8 @@ async def hybrid_node_search(
|
|
|
371
481
|
|
|
372
482
|
results: list[list[EntityNode]] = list(
|
|
373
483
|
await asyncio.gather(
|
|
374
|
-
*[
|
|
375
|
-
*[
|
|
484
|
+
*[node_fulltext_search(driver, q, group_ids, 2 * limit) for q in queries],
|
|
485
|
+
*[node_similarity_search(driver, e, group_ids, 2 * limit) for e in embeddings],
|
|
376
486
|
)
|
|
377
487
|
)
|
|
378
488
|
|
|
@@ -490,40 +600,37 @@ def rrf(results: list[list[str]], rank_const=1) -> list[str]:
|
|
|
490
600
|
|
|
491
601
|
|
|
492
602
|
async def node_distance_reranker(
|
|
493
|
-
driver: AsyncDriver,
|
|
603
|
+
driver: AsyncDriver, node_uuids: list[list[str]], center_node_uuid: str
|
|
494
604
|
) -> list[str]:
|
|
495
605
|
# use rrf as a preliminary ranker
|
|
496
|
-
sorted_uuids = rrf(
|
|
606
|
+
sorted_uuids = rrf(node_uuids)
|
|
497
607
|
scores: dict[str, float] = {}
|
|
498
608
|
|
|
499
|
-
|
|
500
|
-
|
|
501
|
-
|
|
502
|
-
|
|
503
|
-
|
|
504
|
-
MATCH p = SHORTEST 1 (center:Entity)-[:RELATES_TO*1..10]->(n:Entity)
|
|
505
|
-
WHERE center.uuid = $center_uuid AND n.uuid IN [source.uuid, target.uuid]
|
|
506
|
-
RETURN min(length(p)) AS score, source.uuid AS source_uuid, target.uuid AS target_uuid
|
|
507
|
-
""",
|
|
508
|
-
edge_uuid=uuid,
|
|
509
|
-
center_uuid=center_node_uuid,
|
|
510
|
-
)
|
|
511
|
-
distance = 0.01
|
|
609
|
+
# Find the shortest path to center node
|
|
610
|
+
query = Query("""
|
|
611
|
+
MATCH p = SHORTEST 1 (center:Entity {uuid: $center_uuid})-[:RELATES_TO]-+(n:Entity {uuid: $node_uuid})
|
|
612
|
+
RETURN length(p) AS score
|
|
613
|
+
""")
|
|
512
614
|
|
|
513
|
-
|
|
514
|
-
|
|
515
|
-
|
|
516
|
-
|
|
517
|
-
|
|
518
|
-
|
|
519
|
-
|
|
615
|
+
path_results = await asyncio.gather(
|
|
616
|
+
*[
|
|
617
|
+
driver.execute_query(
|
|
618
|
+
query,
|
|
619
|
+
node_uuid=uuid,
|
|
620
|
+
center_uuid=center_node_uuid,
|
|
621
|
+
)
|
|
622
|
+
for uuid in sorted_uuids
|
|
623
|
+
]
|
|
624
|
+
)
|
|
520
625
|
|
|
521
|
-
|
|
522
|
-
|
|
523
|
-
else
|
|
524
|
-
|
|
626
|
+
for uuid, result in zip(sorted_uuids, path_results):
|
|
627
|
+
records = result[0]
|
|
628
|
+
record = records[0] if len(records) > 0 else None
|
|
629
|
+
distance: float = record['score'] if record is not None else float('inf')
|
|
630
|
+
distance = 0 if uuid == center_node_uuid else distance
|
|
631
|
+
scores[uuid] = distance
|
|
525
632
|
|
|
526
633
|
# rerank on shortest distance
|
|
527
|
-
sorted_uuids.sort(
|
|
634
|
+
sorted_uuids.sort(key=lambda cur_uuid: scores[cur_uuid])
|
|
528
635
|
|
|
529
636
|
return sorted_uuids
|
|
@@ -0,0 +1,155 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import logging
|
|
3
|
+
from collections import defaultdict
|
|
4
|
+
from datetime import datetime
|
|
5
|
+
|
|
6
|
+
from neo4j import AsyncDriver
|
|
7
|
+
|
|
8
|
+
from graphiti_core.edges import CommunityEdge
|
|
9
|
+
from graphiti_core.llm_client import LLMClient
|
|
10
|
+
from graphiti_core.nodes import CommunityNode, EntityNode
|
|
11
|
+
from graphiti_core.prompts import prompt_library
|
|
12
|
+
from graphiti_core.utils.maintenance.edge_operations import build_community_edges
|
|
13
|
+
|
|
14
|
+
logger = logging.getLogger(__name__)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
async def build_community_projection(driver: AsyncDriver) -> str:
    """Create the GDS graph projection named "communities".

    The projection covers "Entity" nodes and undirected "RELATES_TO"
    relationships, with a `weight` property aggregated as COUNT.

    Returns:
        The graph name yielded by gds.graph.project (first result row).
    """
    projection_query = """
        CALL gds.graph.project("communities", "Entity",
            {RELATES_TO: {
                type: "RELATES_TO",
                orientation: "UNDIRECTED",
                properties: {weight: {property: "*", aggregation: "COUNT"}}
            }}
        )
        YIELD graphName AS graph, nodeProjection AS nodes, relationshipProjection AS edges
        """
    rows, _, _ = await driver.execute_query(projection_query)

    first_row = rows[0]
    return first_row['graph']
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
async def destroy_projection(driver: AsyncDriver, projection_name: str):
    """Drop the GDS graph projection called `projection_name`."""
    drop_query = """
        CALL gds.graph.drop($projection_name)
        """
    await driver.execute_query(drop_query, projection_name=projection_name)
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
async def get_community_clusters(
    driver: AsyncDriver, projection_name: str
) -> list[list[EntityNode]]:
    """Run Leiden community detection on a GDS projection and load each cluster.

    Args:
        driver: Neo4j async driver used to run the queries.
        projection_name: Name of the GDS graph projection to cluster.

    Returns:
        One list of EntityNode objects per community id reported by
        gds.leiden.stream, fetched with EntityNode.get_by_uuids.
    """
    # Use the caller-supplied projection instead of a hard-coded graph name so
    # the parameter is actually honoured.
    records, _, _ = await driver.execute_query(
        """
        CALL gds.leiden.stream($projection_name)
        YIELD nodeId, communityId
        RETURN gds.util.asNode(nodeId).uuid AS entity_uuid, communityId
        """,
        projection_name=projection_name,
    )

    # Group entity uuids by the community id assigned to them
    community_map: dict[int, list[str]] = defaultdict(list)
    for record in records:
        community_map[record['communityId']].append(record['entity_uuid'])

    # Fetch all clusters concurrently
    community_clusters: list[list[EntityNode]] = list(
        await asyncio.gather(
            *[EntityNode.get_by_uuids(driver, cluster) for cluster in community_map.values()]
        )
    )

    return community_clusters
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
async def summarize_pair(llm_client: LLMClient, summary_pair: tuple[str, str]) -> str:
    """Ask the LLM to combine two summaries into one.

    Returns the 'summary' field of the LLM response, or '' when it is absent.
    """
    # Prepare context for LLM
    node_summaries = [{'summary': text} for text in summary_pair]
    prompt = prompt_library.summarize_nodes.summarize_pair({'node_summaries': node_summaries})

    response = await llm_client.generate_response(prompt)

    return response.get('summary', '')
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
async def generate_summary_description(llm_client: LLMClient, summary: str) -> str:
    """Ask the LLM for a description of `summary`.

    Returns the 'description' field of the LLM response, or '' when it is absent.
    """
    prompt = prompt_library.summarize_nodes.summary_description({'summary': summary})

    response = await llm_client.generate_response(prompt)

    return response.get('description', '')
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
async def build_community(
    llm_client: LLMClient, community_cluster: list[EntityNode]
) -> tuple[CommunityNode, list[CommunityEdge]]:
    """Build a CommunityNode and its edges for one cluster of entities.

    The entity summaries are merged pairwise with the LLM, round after round,
    until a single summary is left. That summary becomes the community summary
    and the LLM-generated description of it becomes the community name.

    Args:
        llm_client: Client used for the pairwise merges and the final description.
        community_cluster: Entities belonging to the community; the group_id of
            the first entity is used for the community node.

    Returns:
        The new community node and the edges produced by build_community_edges.
    """
    summaries = [entity.summary for entity in community_cluster]

    while len(summaries) > 1:
        # With an odd count, hold the last summary back and carry it over
        # unmerged to the next round.
        leftover: str | None = summaries.pop() if len(summaries) % 2 == 1 else None

        # Pair the first half with the second half, element by element
        half = len(summaries) // 2
        merge_tasks = [
            summarize_pair(llm_client, (str(left), str(right)))
            for left, right in zip(summaries[:half], summaries[half:])
        ]
        merged: list[str] = list(await asyncio.gather(*merge_tasks))

        if leftover is not None:
            merged.append(leftover)
        summaries = merged

    summary = summaries[0]
    name = await generate_summary_description(llm_client, summary)
    now = datetime.now()
    community_node = CommunityNode(
        name=name,
        group_id=community_cluster[0].group_id,
        labels=['Community'],
        created_at=now,
        summary=summary,
    )
    community_edges = build_community_edges(community_cluster, community_node, now)

    logger.info((community_node, community_edges))

    return community_node, community_edges
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
async def build_communities(
    driver: AsyncDriver, llm_client: LLMClient
) -> tuple[list[CommunityNode], list[CommunityEdge]]:
    """Detect communities in the graph and build a node and edges for each.

    Creates a GDS projection, clusters it, builds every community concurrently
    and then drops the projection again.

    Args:
        driver: Neo4j async driver used for the projection and clustering queries.
        llm_client: Client used to summarise each community.

    Returns:
        A tuple of (all community nodes, all community edges).
    """
    projection = await build_community_projection(driver)

    try:
        community_clusters = await get_community_clusters(driver, projection)

        communities: list[tuple[CommunityNode, list[CommunityEdge]]] = list(
            await asyncio.gather(
                *[build_community(llm_client, cluster) for cluster in community_clusters]
            )
        )
    finally:
        # Drop the projection even when clustering or summarisation fails, so a
        # failed run does not leave it behind in the GDS graph catalog.
        await destroy_projection(driver, projection)

    community_nodes: list[CommunityNode] = []
    community_edges: list[CommunityEdge] = []
    for node, edges in communities:
        community_nodes.append(node)
        community_edges.extend(edges)

    return community_nodes, community_edges
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
async def remove_communities(driver: AsyncDriver):
    """Delete every node labelled Community together with its relationships."""
    delete_query = """
        MATCH (c:Community)
        DETACH DELETE c
        """
    await driver.execute_query(delete_query)
|