graphiti-core 0.2.3__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of graphiti-core might be problematic. Click here for more details.

@@ -1,3 +1,19 @@
1
+ """
2
+ Copyright 2024, Zep Software, Inc.
3
+
4
+ Licensed under the Apache License, Version 2.0 (the "License");
5
+ you may not use this file except in compliance with the License.
6
+ You may obtain a copy of the License at
7
+
8
+ http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ Unless required by applicable law or agreed to in writing, software
11
+ distributed under the License is distributed on an "AS IS" BASIS,
12
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ See the License for the specific language governing permissions and
14
+ limitations under the License.
15
+ """
16
+
1
17
  import asyncio
2
18
  import logging
3
19
  import re
@@ -7,7 +23,13 @@ from time import time
7
23
  from neo4j import AsyncDriver, Query
8
24
 
9
25
  from graphiti_core.edges import EntityEdge, get_entity_edge_from_record
10
- from graphiti_core.nodes import EntityNode, EpisodicNode, get_entity_node_from_record
26
+ from graphiti_core.nodes import (
27
+ CommunityNode,
28
+ EntityNode,
29
+ EpisodicNode,
30
+ get_community_node_from_record,
31
+ get_entity_node_from_record,
32
+ )
11
33
 
12
34
  logger = logging.getLogger(__name__)
13
35
 
@@ -35,6 +57,128 @@ async def get_mentioned_nodes(driver: AsyncDriver, episodes: list[EpisodicNode])
35
57
  return nodes
36
58
 
37
59
 
60
+ async def edge_fulltext_search(
61
+ driver: AsyncDriver,
62
+ query: str,
63
+ source_node_uuid: str | None,
64
+ target_node_uuid: str | None,
65
+ group_ids: list[str | None] | None = None,
66
+ limit=RELEVANT_SCHEMA_LIMIT,
67
+ ) -> list[EntityEdge]:
68
+ # fulltext search over facts
69
+ cypher_query = Query("""
70
+ CALL db.index.fulltext.queryRelationships("name_and_fact", $query)
71
+ YIELD relationship AS rel, score
72
+ MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid})
73
+ WHERE CASE
74
+ WHEN $group_ids IS NULL THEN n.group_id IS NULL
75
+ ELSE n.group_id IN $group_ids
76
+ END
77
+ RETURN
78
+ r.uuid AS uuid,
79
+ r.group_id AS group_id,
80
+ n.uuid AS source_node_uuid,
81
+ m.uuid AS target_node_uuid,
82
+ r.created_at AS created_at,
83
+ r.name AS name,
84
+ r.fact AS fact,
85
+ r.fact_embedding AS fact_embedding,
86
+ r.episodes AS episodes,
87
+ r.expired_at AS expired_at,
88
+ r.valid_at AS valid_at,
89
+ r.invalid_at AS invalid_at
90
+ ORDER BY score DESC LIMIT $limit
91
+ """)
92
+
93
+ if source_node_uuid is None and target_node_uuid is None:
94
+ cypher_query = Query("""
95
+ CALL db.index.fulltext.queryRelationships("name_and_fact", $query)
96
+ YIELD relationship AS rel, score
97
+ MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity)
98
+ WHERE CASE
99
+ WHEN $group_ids IS NULL THEN r.group_id IS NULL
100
+ ELSE r.group_id IN $group_ids
101
+ END
102
+ RETURN
103
+ r.uuid AS uuid,
104
+ r.group_id AS group_id,
105
+ n.uuid AS source_node_uuid,
106
+ m.uuid AS target_node_uuid,
107
+ r.created_at AS created_at,
108
+ r.name AS name,
109
+ r.fact AS fact,
110
+ r.fact_embedding AS fact_embedding,
111
+ r.episodes AS episodes,
112
+ r.expired_at AS expired_at,
113
+ r.valid_at AS valid_at,
114
+ r.invalid_at AS invalid_at
115
+ ORDER BY score DESC LIMIT $limit
116
+ """)
117
+ elif source_node_uuid is None:
118
+ cypher_query = Query("""
119
+ CALL db.index.fulltext.queryRelationships("name_and_fact", $query)
120
+ YIELD relationship AS rel, score
121
+ MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid})
122
+ WHERE CASE
123
+ WHEN $group_ids IS NULL THEN r.group_id IS NULL
124
+ ELSE r.group_id IN $group_ids
125
+ END
126
+ RETURN
127
+ r.uuid AS uuid,
128
+ r.group_id AS group_id,
129
+ n.uuid AS source_node_uuid,
130
+ m.uuid AS target_node_uuid,
131
+ r.created_at AS created_at,
132
+ r.name AS name,
133
+ r.fact AS fact,
134
+ r.fact_embedding AS fact_embedding,
135
+ r.episodes AS episodes,
136
+ r.expired_at AS expired_at,
137
+ r.valid_at AS valid_at,
138
+ r.invalid_at AS invalid_at
139
+ ORDER BY score DESC LIMIT $limit
140
+ """)
141
+ elif target_node_uuid is None:
142
+ cypher_query = Query("""
143
+ CALL db.index.fulltext.queryRelationships("name_and_fact", $query)
144
+ YIELD relationship AS rel, score
145
+ MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity)
146
+ WHERE CASE
147
+ WHEN $group_ids IS NULL THEN r.group_id IS NULL
148
+ ELSE r.group_id IN $group_ids
149
+ END
150
+ RETURN
151
+ r.uuid AS uuid,
152
+ r.group_id AS group_id,
153
+ n.uuid AS source_node_uuid,
154
+ m.uuid AS target_node_uuid,
155
+ r.created_at AS created_at,
156
+ r.name AS name,
157
+ r.fact AS fact,
158
+ r.fact_embedding AS fact_embedding,
159
+ r.episodes AS episodes,
160
+ r.expired_at AS expired_at,
161
+ r.valid_at AS valid_at,
162
+ r.invalid_at AS invalid_at
163
+ ORDER BY score DESC LIMIT $limit
164
+ """)
165
+
166
+ fuzzy_query = re.sub(r'[^\w\s]', '', query) + '~'
167
+
168
+ records, _, _ = await driver.execute_query(
169
+ cypher_query,
170
+ query=fuzzy_query,
171
+ source_uuid=source_node_uuid,
172
+ target_uuid=target_node_uuid,
173
+ group_ids=group_ids,
174
+ limit=limit,
175
+ )
176
+
177
+ edges = [get_entity_edge_from_record(record) for record in records]
178
+
179
+ return edges
180
+
181
+
38
182
  async def edge_similarity_search(
39
183
  driver: AsyncDriver,
40
184
  search_vector: list[float],
@@ -43,13 +187,15 @@ async def edge_similarity_search(
43
187
  group_ids: list[str | None] | None = None,
44
188
  limit: int = RELEVANT_SCHEMA_LIMIT,
45
189
  ) -> list[EntityEdge]:
46
- group_ids = group_ids if group_ids is not None else [None]
47
190
  # vector similarity search over embedded facts
48
191
  query = Query("""
49
192
  CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector)
50
193
  YIELD relationship AS rel, score
51
194
  MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid})
52
- WHERE r.group_id IN $group_ids
195
+ WHERE CASE
196
+ WHEN $group_ids IS NULL THEN r.group_id IS NULL
197
+ ELSE r.group_id IN $group_ids
198
+ END
53
199
  RETURN
54
200
  r.uuid AS uuid,
55
201
  r.group_id AS group_id,
@@ -71,7 +217,10 @@ async def edge_similarity_search(
71
217
  CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector)
72
218
  YIELD relationship AS rel, score
73
219
  MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity)
74
- WHERE r.group_id IN $group_ids
220
+ WHERE CASE
221
+ WHEN $group_ids IS NULL THEN r.group_id IS NULL
222
+ ELSE r.group_id IN $group_ids
223
+ END
75
224
  RETURN
76
225
  r.uuid AS uuid,
77
226
  r.group_id AS group_id,
@@ -92,7 +241,10 @@ async def edge_similarity_search(
92
241
  CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector)
93
242
  YIELD relationship AS rel, score
94
243
  MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid})
95
- WHERE r.group_id IN $group_ids
244
+ WHERE CASE
245
+ WHEN $group_ids IS NULL THEN r.group_id IS NULL
246
+ ELSE r.group_id IN $group_ids
247
+ END
96
248
  RETURN
97
249
  r.uuid AS uuid,
98
250
  r.group_id AS group_id,
@@ -113,7 +265,10 @@ async def edge_similarity_search(
113
265
  CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector)
114
266
  YIELD relationship AS rel, score
115
267
  MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity)
116
- WHERE r.group_id IN $group_ids
268
+ WHERE CASE
269
+ WHEN $group_ids IS NULL THEN r.group_id IS NULL
270
+ ELSE r.group_id IN $group_ids
271
+ END
117
272
  RETURN
118
273
  r.uuid AS uuid,
119
274
  r.group_id AS group_id,
@@ -144,9 +299,44 @@ async def edge_similarity_search(
144
299
  return edges
145
300
 
146
301
 
147
- async def entity_similarity_search(
148
- search_vector: list[float],
302
+ async def node_fulltext_search(
149
303
  driver: AsyncDriver,
304
+ query: str,
305
+ group_ids: list[str | None] | None = None,
306
+ limit=RELEVANT_SCHEMA_LIMIT,
307
+ ) -> list[EntityNode]:
308
+ # BM25 search to get top nodes
309
+ fuzzy_query = re.sub(r'[^\w\s]', '', query) + '~'
310
+ records, _, _ = await driver.execute_query(
311
+ """
312
+ CALL db.index.fulltext.queryNodes("name_and_summary", $query)
313
+ YIELD node AS n, score
314
+ WHERE CASE
315
+ WHEN $group_ids IS NULL THEN n.group_id IS NULL
316
+ ELSE n.group_id IN $group_ids
317
+ END
318
+ RETURN
319
+ n.uuid AS uuid,
320
+ n.group_id AS group_id,
321
+ n.name AS name,
322
+ n.name_embedding AS name_embedding,
323
+ n.created_at AS created_at,
324
+ n.summary AS summary
325
+ ORDER BY score DESC
326
+ LIMIT $limit
327
+ """,
328
+ query=fuzzy_query,
329
+ group_ids=group_ids,
330
+ limit=limit,
331
+ )
332
+ nodes = [get_entity_node_from_record(record) for record in records]
333
+
334
+ return nodes
335
+
336
+
337
+ async def node_similarity_search(
338
+ driver: AsyncDriver,
339
+ search_vector: list[float],
150
340
  group_ids: list[str | None] | None = None,
151
341
  limit=RELEVANT_SCHEMA_LIMIT,
152
342
  ) -> list[EntityNode]:
@@ -176,28 +366,28 @@ async def entity_similarity_search(
176
366
  return nodes
177
367
 
178
368
 
179
- async def entity_fulltext_search(
180
- query: str,
369
+ async def community_fulltext_search(
181
370
  driver: AsyncDriver,
371
+ query: str,
182
372
  group_ids: list[str | None] | None = None,
183
373
  limit=RELEVANT_SCHEMA_LIMIT,
184
- ) -> list[EntityNode]:
374
+ ) -> list[CommunityNode]:
185
375
  group_ids = group_ids if group_ids is not None else [None]
186
376
 
187
- # BM25 search to get top nodes
377
+ # BM25 search to get top communities
188
378
  fuzzy_query = re.sub(r'[^\w\s]', '', query) + '~'
189
379
  records, _, _ = await driver.execute_query(
190
380
  """
191
- CALL db.index.fulltext.queryNodes("name_and_summary", $query)
192
- YIELD node AS n, score
193
- MATCH (n WHERE n.group_id in $group_ids)
381
+ CALL db.index.fulltext.queryNodes("community_name", $query)
382
+ YIELD node AS comm, score
383
+ MATCH (comm WHERE comm.group_id in $group_ids)
194
384
  RETURN
195
- n.uuid AS uuid,
196
- n.group_id AS group_id,
197
- n.name AS name,
198
- n.name_embedding AS name_embedding,
199
- n.created_at AS created_at,
200
- n.summary AS summary
385
+ comm.uuid AS uuid,
386
+ comm.group_id AS group_id,
387
+ comm.name AS name,
388
+ comm.name_embedding AS name_embedding,
389
+ comm.created_at AS created_at,
390
+ comm.summary AS summary
201
391
  ORDER BY score DESC
202
392
  LIMIT $limit
203
393
  """,
@@ -205,121 +395,41 @@ async def entity_fulltext_search(
205
395
  group_ids=group_ids,
206
396
  limit=limit,
207
397
  )
208
- nodes = [get_entity_node_from_record(record) for record in records]
398
+ communities = [get_community_node_from_record(record) for record in records]
209
399
 
210
- return nodes
400
+ return communities
211
401
 
212
402
 
213
- async def edge_fulltext_search(
403
+ async def community_similarity_search(
214
404
  driver: AsyncDriver,
215
- query: str,
216
- source_node_uuid: str | None,
217
- target_node_uuid: str | None,
405
+ search_vector: list[float],
218
406
  group_ids: list[str | None] | None = None,
219
407
  limit=RELEVANT_SCHEMA_LIMIT,
220
- ) -> list[EntityEdge]:
408
+ ) -> list[CommunityNode]:
221
409
  group_ids = group_ids if group_ids is not None else [None]
222
410
 
223
- # fulltext search over facts
224
- cypher_query = Query("""
225
- CALL db.index.fulltext.queryRelationships("name_and_fact", $query)
226
- YIELD relationship AS rel, score
227
- MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid})
228
- WHERE r.group_id IN $group_ids
229
- RETURN
230
- r.uuid AS uuid,
231
- r.group_id AS group_id,
232
- n.uuid AS source_node_uuid,
233
- m.uuid AS target_node_uuid,
234
- r.created_at AS created_at,
235
- r.name AS name,
236
- r.fact AS fact,
237
- r.fact_embedding AS fact_embedding,
238
- r.episodes AS episodes,
239
- r.expired_at AS expired_at,
240
- r.valid_at AS valid_at,
241
- r.invalid_at AS invalid_at
242
- ORDER BY score DESC LIMIT $limit
243
- """)
244
-
245
- if source_node_uuid is None and target_node_uuid is None:
246
- cypher_query = Query("""
247
- CALL db.index.fulltext.queryRelationships("name_and_fact", $query)
248
- YIELD relationship AS rel, score
249
- MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity)
250
- WHERE r.group_id IN $group_ids
251
- RETURN
252
- r.uuid AS uuid,
253
- r.group_id AS group_id,
254
- n.uuid AS source_node_uuid,
255
- m.uuid AS target_node_uuid,
256
- r.created_at AS created_at,
257
- r.name AS name,
258
- r.fact AS fact,
259
- r.fact_embedding AS fact_embedding,
260
- r.episodes AS episodes,
261
- r.expired_at AS expired_at,
262
- r.valid_at AS valid_at,
263
- r.invalid_at AS invalid_at
264
- ORDER BY score DESC LIMIT $limit
265
- """)
266
- elif source_node_uuid is None:
267
- cypher_query = Query("""
268
- CALL db.index.fulltext.queryRelationships("name_and_fact", $query)
269
- YIELD relationship AS rel, score
270
- MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid})
271
- WHERE r.group_id IN $group_ids
272
- RETURN
273
- r.uuid AS uuid,
274
- r.group_id AS group_id,
275
- n.uuid AS source_node_uuid,
276
- m.uuid AS target_node_uuid,
277
- r.created_at AS created_at,
278
- r.name AS name,
279
- r.fact AS fact,
280
- r.fact_embedding AS fact_embedding,
281
- r.episodes AS episodes,
282
- r.expired_at AS expired_at,
283
- r.valid_at AS valid_at,
284
- r.invalid_at AS invalid_at
285
- ORDER BY score DESC LIMIT $limit
286
- """)
287
- elif target_node_uuid is None:
288
- cypher_query = Query("""
289
- CALL db.index.fulltext.queryRelationships("name_and_fact", $query)
290
- YIELD relationship AS rel, score
291
- MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity)
292
- WHERE r.group_id IN $group_ids
293
- RETURN
294
- r.uuid AS uuid,
295
- r.group_id AS group_id,
296
- n.uuid AS source_node_uuid,
297
- m.uuid AS target_node_uuid,
298
- r.created_at AS created_at,
299
- r.name AS name,
300
- r.fact AS fact,
301
- r.fact_embedding AS fact_embedding,
302
- r.episodes AS episodes,
303
- r.expired_at AS expired_at,
304
- r.valid_at AS valid_at,
305
- r.invalid_at AS invalid_at
306
- ORDER BY score DESC LIMIT $limit
307
- """)
308
-
309
- fuzzy_query = re.sub(r'[^\w\s]', '', query) + '~'
310
-
411
+ # vector similarity search over entity names
311
412
  records, _, _ = await driver.execute_query(
312
- cypher_query,
313
- query=fuzzy_query,
314
- source_uuid=source_node_uuid,
315
- target_uuid=target_node_uuid,
413
+ """
414
+ CALL db.index.vector.queryNodes("community_name_embedding", $limit, $search_vector)
415
+ YIELD node AS comm, score
416
+ MATCH (comm WHERE comm.group_id IN $group_ids)
417
+ RETURN
418
+ comm.uuid As uuid,
419
+ comm.group_id AS group_id,
420
+ comm.name AS name,
421
+ comm.name_embedding AS name_embedding,
422
+ comm.created_at AS created_at,
423
+ comm.summary AS summary
424
+ ORDER BY score DESC
425
+ """,
426
+ search_vector=search_vector,
316
427
  group_ids=group_ids,
317
428
  limit=limit,
318
429
  )
430
+ communities = [get_community_node_from_record(record) for record in records]
319
431
 
320
- edges = [get_entity_edge_from_record(record) for record in records]
321
-
322
- return edges
432
+ return communities
323
433
 
324
434
 
325
435
  async def hybrid_node_search(
@@ -371,8 +481,8 @@ async def hybrid_node_search(
371
481
 
372
482
  results: list[list[EntityNode]] = list(
373
483
  await asyncio.gather(
374
- *[entity_fulltext_search(q, driver, group_ids, 2 * limit) for q in queries],
375
- *[entity_similarity_search(e, driver, group_ids, 2 * limit) for e in embeddings],
484
+ *[node_fulltext_search(driver, q, group_ids, 2 * limit) for q in queries],
485
+ *[node_similarity_search(driver, e, group_ids, 2 * limit) for e in embeddings],
376
486
  )
377
487
  )
378
488
 
@@ -490,40 +600,37 @@ def rrf(results: list[list[str]], rank_const=1) -> list[str]:
490
600
 
491
601
 
492
602
  async def node_distance_reranker(
493
- driver: AsyncDriver, results: list[list[str]], center_node_uuid: str
603
+ driver: AsyncDriver, node_uuids: list[list[str]], center_node_uuid: str
494
604
  ) -> list[str]:
495
605
  # use rrf as a preliminary ranker
496
- sorted_uuids = rrf(results)
606
+ sorted_uuids = rrf(node_uuids)
497
607
  scores: dict[str, float] = {}
498
608
 
499
- for uuid in sorted_uuids:
500
- # Find the shortest path to center node
501
- records, _, _ = await driver.execute_query(
502
- """
503
- MATCH (source:Entity)-[r:RELATES_TO {uuid: $edge_uuid}]->(target:Entity)
504
- MATCH p = SHORTEST 1 (center:Entity)-[:RELATES_TO*1..10]->(n:Entity)
505
- WHERE center.uuid = $center_uuid AND n.uuid IN [source.uuid, target.uuid]
506
- RETURN min(length(p)) AS score, source.uuid AS source_uuid, target.uuid AS target_uuid
507
- """,
508
- edge_uuid=uuid,
509
- center_uuid=center_node_uuid,
510
- )
511
- distance = 0.01
609
+ # Find the shortest path to center node
610
+ query = Query("""
611
+ MATCH p = SHORTEST 1 (center:Entity {uuid: $center_uuid})-[:RELATES_TO]-+(n:Entity {uuid: $node_uuid})
612
+ RETURN length(p) AS score
613
+ """)
512
614
 
513
- for record in records:
514
- if (
515
- record['source_uuid'] == center_node_uuid
516
- or record['target_uuid'] == center_node_uuid
517
- ):
518
- continue
519
- distance = record['score']
615
+ path_results = await asyncio.gather(
616
+ *[
617
+ driver.execute_query(
618
+ query,
619
+ node_uuid=uuid,
620
+ center_uuid=center_node_uuid,
621
+ )
622
+ for uuid in sorted_uuids
623
+ ]
624
+ )
520
625
 
521
- if uuid in scores:
522
- scores[uuid] = min(1 / distance, scores[uuid])
523
- else:
524
- scores[uuid] = 1 / distance
626
+ for uuid, result in zip(sorted_uuids, path_results):
627
+ records = result[0]
628
+ record = records[0] if len(records) > 0 else None
629
+ distance: float = record['score'] if record is not None else float('inf')
630
+ distance = 0 if uuid == center_node_uuid else distance
631
+ scores[uuid] = distance
525
632
 
526
633
  # rerank on shortest distance
527
- sorted_uuids.sort(reverse=True, key=lambda cur_uuid: scores[cur_uuid])
634
+ sorted_uuids.sort(key=lambda cur_uuid: scores[cur_uuid])
528
635
 
529
636
  return sorted_uuids
@@ -0,0 +1,155 @@
1
+ import asyncio
2
+ import logging
3
+ from collections import defaultdict
4
+ from datetime import datetime
5
+
6
+ from neo4j import AsyncDriver
7
+
8
+ from graphiti_core.edges import CommunityEdge
9
+ from graphiti_core.llm_client import LLMClient
10
+ from graphiti_core.nodes import CommunityNode, EntityNode
11
+ from graphiti_core.prompts import prompt_library
12
+ from graphiti_core.utils.maintenance.edge_operations import build_community_edges
13
+
14
+ logger = logging.getLogger(__name__)
15
+
16
+
17
+ async def build_community_projection(driver: AsyncDriver) -> str:
18
+ records, _, _ = await driver.execute_query("""
19
+ CALL gds.graph.project("communities", "Entity",
20
+ {RELATES_TO: {
21
+ type: "RELATES_TO",
22
+ orientation: "UNDIRECTED",
23
+ properties: {weight: {property: "*", aggregation: "COUNT"}}
24
+ }}
25
+ )
26
+ YIELD graphName AS graph, nodeProjection AS nodes, relationshipProjection AS edges
27
+ """)
28
+
29
+ return records[0]['graph']
30
+
31
+
32
+ async def destroy_projection(driver: AsyncDriver, projection_name: str):
33
+ await driver.execute_query(
34
+ """
35
+ CALL gds.graph.drop($projection_name)
36
+ """,
37
+ projection_name=projection_name,
38
+ )
39
+
40
+
41
+ async def get_community_clusters(
42
+ driver: AsyncDriver, projection_name: str
43
+ ) -> list[list[EntityNode]]:
44
+ records, _, _ = await driver.execute_query("""
45
+ CALL gds.leiden.stream("communities")
46
+ YIELD nodeId, communityId
47
+ RETURN gds.util.asNode(nodeId).uuid AS entity_uuid, communityId
48
+ """)
49
+ community_map: dict[int, list[str]] = defaultdict(list)
50
+ for record in records:
51
+ community_map[record['communityId']].append(record['entity_uuid'])
52
+
53
+ community_clusters: list[list[EntityNode]] = list(
54
+ await asyncio.gather(
55
+ *[EntityNode.get_by_uuids(driver, cluster) for cluster in community_map.values()]
56
+ )
57
+ )
58
+
59
+ return community_clusters
60
+
61
+
62
+ async def summarize_pair(llm_client: LLMClient, summary_pair: tuple[str, str]) -> str:
63
+ # Prepare context for LLM
64
+ context = {'node_summaries': [{'summary': summary} for summary in summary_pair]}
65
+
66
+ llm_response = await llm_client.generate_response(
67
+ prompt_library.summarize_nodes.summarize_pair(context)
68
+ )
69
+
70
+ pair_summary = llm_response.get('summary', '')
71
+
72
+ return pair_summary
73
+
74
+
75
+ async def generate_summary_description(llm_client: LLMClient, summary: str) -> str:
76
+ context = {'summary': summary}
77
+
78
+ llm_response = await llm_client.generate_response(
79
+ prompt_library.summarize_nodes.summary_description(context)
80
+ )
81
+
82
+ description = llm_response.get('description', '')
83
+
84
+ return description
85
+
86
+
87
+ async def build_community(
88
+ llm_client: LLMClient, community_cluster: list[EntityNode]
89
+ ) -> tuple[CommunityNode, list[CommunityEdge]]:
90
+ summaries = [entity.summary for entity in community_cluster]
91
+ length = len(summaries)
92
+ while length > 1:
93
+ odd_one_out: str | None = None
94
+ if length % 2 == 1:
95
+ odd_one_out = summaries.pop()
96
+ length -= 1
97
+ new_summaries: list[str] = list(
98
+ await asyncio.gather(
99
+ *[
100
+ summarize_pair(llm_client, (str(left_summary), str(right_summary)))
101
+ for left_summary, right_summary in zip(
102
+ summaries[: int(length / 2)], summaries[int(length / 2) :]
103
+ )
104
+ ]
105
+ )
106
+ )
107
+ if odd_one_out is not None:
108
+ new_summaries.append(odd_one_out)
109
+ summaries = new_summaries
110
+ length = len(summaries)
111
+
112
+ summary = summaries[0]
113
+ name = await generate_summary_description(llm_client, summary)
114
+ now = datetime.now()
115
+ community_node = CommunityNode(
116
+ name=name,
117
+ group_id=community_cluster[0].group_id,
118
+ labels=['Community'],
119
+ created_at=now,
120
+ summary=summary,
121
+ )
122
+ community_edges = build_community_edges(community_cluster, community_node, now)
123
+
124
+ logger.info((community_node, community_edges))
125
+
126
+ return community_node, community_edges
127
+
128
+
129
+ async def build_communities(
130
+ driver: AsyncDriver, llm_client: LLMClient
131
+ ) -> tuple[list[CommunityNode], list[CommunityEdge]]:
132
+ projection = await build_community_projection(driver)
133
+ community_clusters = await get_community_clusters(driver, projection)
134
+
135
+ communities: list[tuple[CommunityNode, list[CommunityEdge]]] = list(
136
+ await asyncio.gather(
137
+ *[build_community(llm_client, cluster) for cluster in community_clusters]
138
+ )
139
+ )
140
+
141
+ community_nodes: list[CommunityNode] = []
142
+ community_edges: list[CommunityEdge] = []
143
+ for community in communities:
144
+ community_nodes.append(community[0])
145
+ community_edges.extend(community[1])
146
+
147
+ await destroy_projection(driver, projection)
148
+ return community_nodes, community_edges
149
+
150
+
151
+ async def remove_communities(driver: AsyncDriver):
152
+ await driver.execute_query("""
153
+ MATCH (c:Community)
154
+ DETACH DELETE c
155
+ """)