graphiti-core 0.4.3__py3-none-any.whl → 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of graphiti-core might be problematic; see the package registry's page for this release for more details.
- graphiti_core/cross_encoder/client.py +1 -1
- graphiti_core/cross_encoder/openai_reranker_client.py +2 -2
- graphiti_core/edges.py +13 -10
- graphiti_core/graphiti.py +25 -27
- graphiti_core/helpers.py +25 -0
- graphiti_core/llm_client/anthropic_client.py +4 -1
- graphiti_core/llm_client/client.py +45 -5
- graphiti_core/llm_client/errors.py +8 -0
- graphiti_core/llm_client/groq_client.py +4 -1
- graphiti_core/llm_client/openai_client.py +71 -7
- graphiti_core/llm_client/openai_generic_client.py +163 -0
- graphiti_core/nodes.py +16 -12
- graphiti_core/prompts/dedupe_edges.py +20 -17
- graphiti_core/prompts/dedupe_nodes.py +15 -1
- graphiti_core/prompts/eval.py +17 -14
- graphiti_core/prompts/extract_edge_dates.py +15 -7
- graphiti_core/prompts/extract_edges.py +18 -19
- graphiti_core/prompts/extract_nodes.py +11 -21
- graphiti_core/prompts/invalidate_edges.py +13 -25
- graphiti_core/prompts/summarize_nodes.py +17 -16
- graphiti_core/search/search.py +5 -5
- graphiti_core/search/search_utils.py +54 -13
- graphiti_core/utils/__init__.py +0 -15
- graphiti_core/utils/bulk_utils.py +22 -15
- graphiti_core/utils/datetime_utils.py +42 -0
- graphiti_core/utils/maintenance/community_operations.py +13 -9
- graphiti_core/utils/maintenance/edge_operations.py +26 -19
- graphiti_core/utils/maintenance/graph_data_operations.py +3 -4
- graphiti_core/utils/maintenance/node_operations.py +19 -13
- graphiti_core/utils/maintenance/temporal_operations.py +16 -7
- {graphiti_core-0.4.3.dist-info → graphiti_core-0.5.0.dist-info}/METADATA +1 -1
- graphiti_core-0.5.0.dist-info/RECORD +60 -0
- graphiti_core-0.4.3.dist-info/RECORD +0 -58
- {graphiti_core-0.4.3.dist-info → graphiti_core-0.5.0.dist-info}/LICENSE +0 -0
- {graphiti_core-0.4.3.dist-info → graphiti_core-0.5.0.dist-info}/WHEEL +0 -0
|
@@ -17,9 +17,21 @@ limitations under the License.
|
|
|
17
17
|
import json
|
|
18
18
|
from typing import Any, Protocol, TypedDict
|
|
19
19
|
|
|
20
|
+
from pydantic import BaseModel, Field
|
|
21
|
+
|
|
20
22
|
from .models import Message, PromptFunction, PromptVersion
|
|
21
23
|
|
|
22
24
|
|
|
25
|
+
class Summary(BaseModel):
    # Structured-output model for the summarize_pair prompt: the LLM merges two
    # node summaries and returns the result in a single `summary` field.
    # NOTE: intentionally documented with comments rather than a class docstring,
    # because pydantic copies a model docstring into its JSON schema description,
    # which would alter the schema handed to the LLM client as response_model.
    summary: str = Field(
        default=...,  # Ellipsis default => the field is required
        description='Summary containing the important information from both summaries',
    )
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class SummaryDescription(BaseModel):
    # Structured-output model for the summary_description prompt: a single,
    # one-sentence characterisation of what a summary covers.
    # (Comments instead of a docstring so the generated JSON schema is unchanged.)
    description: str = Field(
        default=...,  # Ellipsis default => the field is required
        description='One sentence description of the provided summary',
    )
|
|
33
|
+
|
|
34
|
+
|
|
23
35
|
class Prompt(Protocol):
|
|
24
36
|
summarize_pair: PromptVersion
|
|
25
37
|
summarize_context: PromptVersion
|
|
@@ -42,14 +54,11 @@ def summarize_pair(context: dict[str, Any]) -> list[Message]:
|
|
|
42
54
|
role='user',
|
|
43
55
|
content=f"""
|
|
44
56
|
Synthesize the information from the following two summaries into a single succinct summary.
|
|
57
|
+
|
|
58
|
+
Summaries must be under 500 words.
|
|
45
59
|
|
|
46
60
|
Summaries:
|
|
47
61
|
{json.dumps(context['node_summaries'], indent=2)}
|
|
48
|
-
|
|
49
|
-
Respond with a JSON object in the following format:
|
|
50
|
-
{{
|
|
51
|
-
"summary": "Summary containing the important information from both summaries"
|
|
52
|
-
}}
|
|
53
62
|
""",
|
|
54
63
|
),
|
|
55
64
|
]
|
|
@@ -74,15 +83,11 @@ def summarize_context(context: dict[str, Any]) -> list[Message]:
|
|
|
74
83
|
information from the provided MESSAGES. Your summary should also only contain information relevant to the
|
|
75
84
|
provided ENTITY.
|
|
76
85
|
|
|
86
|
+
Summaries must be under 500 words.
|
|
87
|
+
|
|
77
88
|
<ENTITY>
|
|
78
89
|
{context['node_name']}
|
|
79
90
|
</ENTITY>
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
Respond with a JSON object in the following format:
|
|
83
|
-
{{
|
|
84
|
-
"summary": "Entity summary"
|
|
85
|
-
}}
|
|
86
91
|
""",
|
|
87
92
|
),
|
|
88
93
|
]
|
|
@@ -98,14 +103,10 @@ def summary_description(context: dict[str, Any]) -> list[Message]:
|
|
|
98
103
|
role='user',
|
|
99
104
|
content=f"""
|
|
100
105
|
Create a short one sentence description of the summary that explains what kind of information is summarized.
|
|
106
|
+
Summaries must be under 500 words.
|
|
101
107
|
|
|
102
108
|
Summary:
|
|
103
109
|
{json.dumps(context['summary'], indent=2)}
|
|
104
|
-
|
|
105
|
-
Respond with a JSON object in the following format:
|
|
106
|
-
{{
|
|
107
|
-
"description": "One sentence description of the provided summary"
|
|
108
|
-
}}
|
|
109
110
|
""",
|
|
110
111
|
),
|
|
111
112
|
]
|
graphiti_core/search/search.py
CHANGED
|
@@ -14,7 +14,6 @@ See the License for the specific language governing permissions and
|
|
|
14
14
|
limitations under the License.
|
|
15
15
|
"""
|
|
16
16
|
|
|
17
|
-
import asyncio
|
|
18
17
|
import logging
|
|
19
18
|
from collections import defaultdict
|
|
20
19
|
from time import time
|
|
@@ -25,6 +24,7 @@ from graphiti_core.cross_encoder.client import CrossEncoderClient
|
|
|
25
24
|
from graphiti_core.edges import EntityEdge
|
|
26
25
|
from graphiti_core.embedder import EmbedderClient
|
|
27
26
|
from graphiti_core.errors import SearchRerankerError
|
|
27
|
+
from graphiti_core.helpers import semaphore_gather
|
|
28
28
|
from graphiti_core.nodes import CommunityNode, EntityNode
|
|
29
29
|
from graphiti_core.search.search_config import (
|
|
30
30
|
DEFAULT_SEARCH_LIMIT,
|
|
@@ -78,7 +78,7 @@ async def search(
|
|
|
78
78
|
|
|
79
79
|
# if group_ids is empty, set it to None
|
|
80
80
|
group_ids = group_ids if group_ids else None
|
|
81
|
-
edges, nodes, communities = await
|
|
81
|
+
edges, nodes, communities = await semaphore_gather(
|
|
82
82
|
edge_search(
|
|
83
83
|
driver,
|
|
84
84
|
cross_encoder,
|
|
@@ -141,7 +141,7 @@ async def edge_search(
|
|
|
141
141
|
return []
|
|
142
142
|
|
|
143
143
|
search_results: list[list[EntityEdge]] = list(
|
|
144
|
-
await
|
|
144
|
+
await semaphore_gather(
|
|
145
145
|
*[
|
|
146
146
|
edge_fulltext_search(driver, query, group_ids, 2 * limit),
|
|
147
147
|
edge_similarity_search(
|
|
@@ -226,7 +226,7 @@ async def node_search(
|
|
|
226
226
|
return []
|
|
227
227
|
|
|
228
228
|
search_results: list[list[EntityNode]] = list(
|
|
229
|
-
await
|
|
229
|
+
await semaphore_gather(
|
|
230
230
|
*[
|
|
231
231
|
node_fulltext_search(driver, query, group_ids, 2 * limit),
|
|
232
232
|
node_similarity_search(
|
|
@@ -295,7 +295,7 @@ async def community_search(
|
|
|
295
295
|
return []
|
|
296
296
|
|
|
297
297
|
search_results: list[list[CommunityNode]] = list(
|
|
298
|
-
await
|
|
298
|
+
await semaphore_gather(
|
|
299
299
|
*[
|
|
300
300
|
community_fulltext_search(driver, query, group_ids, 2 * limit),
|
|
301
301
|
community_similarity_search(
|
|
@@ -14,10 +14,10 @@ See the License for the specific language governing permissions and
|
|
|
14
14
|
limitations under the License.
|
|
15
15
|
"""
|
|
16
16
|
|
|
17
|
-
import asyncio
|
|
18
17
|
import logging
|
|
19
18
|
from collections import defaultdict
|
|
20
19
|
from time import time
|
|
20
|
+
from typing import Any
|
|
21
21
|
|
|
22
22
|
import numpy as np
|
|
23
23
|
from neo4j import AsyncDriver, Query
|
|
@@ -29,6 +29,7 @@ from graphiti_core.helpers import (
|
|
|
29
29
|
USE_PARALLEL_RUNTIME,
|
|
30
30
|
lucene_sanitize,
|
|
31
31
|
normalize_l2,
|
|
32
|
+
semaphore_gather,
|
|
32
33
|
)
|
|
33
34
|
from graphiti_core.nodes import (
|
|
34
35
|
CommunityNode,
|
|
@@ -191,12 +192,27 @@ async def edge_similarity_search(
|
|
|
191
192
|
'CYPHER runtime = parallel parallelRuntimeSupport=all\n' if USE_PARALLEL_RUNTIME else ''
|
|
192
193
|
)
|
|
193
194
|
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
195
|
+
query_params: dict[str, Any] = {}
|
|
196
|
+
|
|
197
|
+
group_filter_query: LiteralString = ''
|
|
198
|
+
if group_ids is not None:
|
|
199
|
+
group_filter_query += 'WHERE r.group_id IN $group_ids'
|
|
200
|
+
query_params['group_ids'] = group_ids
|
|
201
|
+
query_params['source_node_uuid'] = source_node_uuid
|
|
202
|
+
query_params['target_node_uuid'] = target_node_uuid
|
|
203
|
+
|
|
204
|
+
if source_node_uuid is not None:
|
|
205
|
+
group_filter_query += '\nAND (n.uuid IN [$source_uuid, $target_uuid])'
|
|
206
|
+
|
|
207
|
+
if target_node_uuid is not None:
|
|
208
|
+
group_filter_query += '\nAND (m.uuid IN [$source_uuid, $target_uuid])'
|
|
209
|
+
|
|
210
|
+
query: LiteralString = (
|
|
211
|
+
"""
|
|
212
|
+
MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity)
|
|
213
|
+
"""
|
|
214
|
+
+ group_filter_query
|
|
215
|
+
+ """\nWITH DISTINCT r, vector.similarity.cosine(r.fact_embedding, $search_vector) AS score
|
|
200
216
|
WHERE score > $min_score
|
|
201
217
|
RETURN
|
|
202
218
|
r.uuid AS uuid,
|
|
@@ -214,9 +230,11 @@ async def edge_similarity_search(
|
|
|
214
230
|
ORDER BY score DESC
|
|
215
231
|
LIMIT $limit
|
|
216
232
|
"""
|
|
233
|
+
)
|
|
217
234
|
|
|
218
235
|
records, _, _ = await driver.execute_query(
|
|
219
236
|
runtime_query + query,
|
|
237
|
+
query_params,
|
|
220
238
|
search_vector=search_vector,
|
|
221
239
|
source_uuid=source_node_uuid,
|
|
222
240
|
target_uuid=target_node_uuid,
|
|
@@ -325,11 +343,20 @@ async def node_similarity_search(
|
|
|
325
343
|
'CYPHER runtime = parallel parallelRuntimeSupport=all\n' if USE_PARALLEL_RUNTIME else ''
|
|
326
344
|
)
|
|
327
345
|
|
|
346
|
+
query_params: dict[str, Any] = {}
|
|
347
|
+
|
|
348
|
+
group_filter_query: LiteralString = ''
|
|
349
|
+
if group_ids is not None:
|
|
350
|
+
group_filter_query += 'WHERE n.group_id IN $group_ids'
|
|
351
|
+
query_params['group_ids'] = group_ids
|
|
352
|
+
|
|
328
353
|
records, _, _ = await driver.execute_query(
|
|
329
354
|
runtime_query
|
|
330
355
|
+ """
|
|
331
356
|
MATCH (n:Entity)
|
|
332
|
-
|
|
357
|
+
"""
|
|
358
|
+
+ group_filter_query
|
|
359
|
+
+ """
|
|
333
360
|
WITH n, vector.similarity.cosine(n.name_embedding, $search_vector) AS score
|
|
334
361
|
WHERE score > $min_score
|
|
335
362
|
RETURN
|
|
@@ -342,6 +369,7 @@ async def node_similarity_search(
|
|
|
342
369
|
ORDER BY score DESC
|
|
343
370
|
LIMIT $limit
|
|
344
371
|
""",
|
|
372
|
+
query_params,
|
|
345
373
|
search_vector=search_vector,
|
|
346
374
|
group_ids=group_ids,
|
|
347
375
|
limit=limit,
|
|
@@ -436,11 +464,20 @@ async def community_similarity_search(
|
|
|
436
464
|
'CYPHER runtime = parallel parallelRuntimeSupport=all\n' if USE_PARALLEL_RUNTIME else ''
|
|
437
465
|
)
|
|
438
466
|
|
|
467
|
+
query_params: dict[str, Any] = {}
|
|
468
|
+
|
|
469
|
+
group_filter_query: LiteralString = ''
|
|
470
|
+
if group_ids is not None:
|
|
471
|
+
group_filter_query += 'WHERE comm.group_id IN $group_ids'
|
|
472
|
+
query_params['group_ids'] = group_ids
|
|
473
|
+
|
|
439
474
|
records, _, _ = await driver.execute_query(
|
|
440
475
|
runtime_query
|
|
441
476
|
+ """
|
|
442
477
|
MATCH (comm:Community)
|
|
443
|
-
|
|
478
|
+
"""
|
|
479
|
+
+ group_filter_query
|
|
480
|
+
+ """
|
|
444
481
|
WITH comm, vector.similarity.cosine(comm.name_embedding, $search_vector) AS score
|
|
445
482
|
WHERE score > $min_score
|
|
446
483
|
RETURN
|
|
@@ -512,7 +549,7 @@ async def hybrid_node_search(
|
|
|
512
549
|
|
|
513
550
|
start = time()
|
|
514
551
|
results: list[list[EntityNode]] = list(
|
|
515
|
-
await
|
|
552
|
+
await semaphore_gather(
|
|
516
553
|
*[node_fulltext_search(driver, q, group_ids, 2 * limit) for q in queries],
|
|
517
554
|
*[node_similarity_search(driver, e, group_ids, 2 * limit) for e in embeddings],
|
|
518
555
|
)
|
|
@@ -582,7 +619,7 @@ async def get_relevant_edges(
|
|
|
582
619
|
relevant_edges: list[EntityEdge] = []
|
|
583
620
|
relevant_edge_uuids = set()
|
|
584
621
|
|
|
585
|
-
results = await
|
|
622
|
+
results = await semaphore_gather(
|
|
586
623
|
*[
|
|
587
624
|
edge_similarity_search(
|
|
588
625
|
driver,
|
|
@@ -631,7 +668,7 @@ async def node_distance_reranker(
|
|
|
631
668
|
) -> list[str]:
|
|
632
669
|
# filter out node_uuid center node node uuid
|
|
633
670
|
filtered_uuids = list(filter(lambda node_uuid: node_uuid != center_node_uuid, node_uuids))
|
|
634
|
-
scores: dict[str, float] = {}
|
|
671
|
+
scores: dict[str, float] = {center_node_uuid: 0.0}
|
|
635
672
|
|
|
636
673
|
# Find the shortest path to center node
|
|
637
674
|
query = Query("""
|
|
@@ -649,9 +686,13 @@ async def node_distance_reranker(
|
|
|
649
686
|
|
|
650
687
|
for result in path_results:
|
|
651
688
|
uuid = result['uuid']
|
|
652
|
-
score = result['score']
|
|
689
|
+
score = result['score']
|
|
653
690
|
scores[uuid] = score
|
|
654
691
|
|
|
692
|
+
for uuid in filtered_uuids:
|
|
693
|
+
if uuid not in scores:
|
|
694
|
+
scores[uuid] = float('inf')
|
|
695
|
+
|
|
655
696
|
# rerank on shortest distance
|
|
656
697
|
filtered_uuids.sort(key=lambda cur_uuid: scores[cur_uuid])
|
|
657
698
|
|
graphiti_core/utils/__init__.py
CHANGED
|
@@ -1,15 +0,0 @@
|
|
|
1
|
-
from .maintenance import (
|
|
2
|
-
build_episodic_edges,
|
|
3
|
-
clear_data,
|
|
4
|
-
extract_edges,
|
|
5
|
-
extract_nodes,
|
|
6
|
-
retrieve_episodes,
|
|
7
|
-
)
|
|
8
|
-
|
|
9
|
-
__all__ = [
|
|
10
|
-
'extract_edges',
|
|
11
|
-
'build_episodic_edges',
|
|
12
|
-
'extract_nodes',
|
|
13
|
-
'clear_data',
|
|
14
|
-
'retrieve_episodes',
|
|
15
|
-
]
|
|
@@ -14,11 +14,10 @@ See the License for the specific language governing permissions and
|
|
|
14
14
|
limitations under the License.
|
|
15
15
|
"""
|
|
16
16
|
|
|
17
|
-
import asyncio
|
|
18
17
|
import logging
|
|
19
18
|
import typing
|
|
20
19
|
from collections import defaultdict
|
|
21
|
-
from datetime import datetime
|
|
20
|
+
from datetime import datetime
|
|
22
21
|
from math import ceil
|
|
23
22
|
|
|
24
23
|
from neo4j import AsyncDriver, AsyncManagedTransaction
|
|
@@ -26,6 +25,7 @@ from numpy import dot, sqrt
|
|
|
26
25
|
from pydantic import BaseModel
|
|
27
26
|
|
|
28
27
|
from graphiti_core.edges import Edge, EntityEdge, EpisodicEdge
|
|
28
|
+
from graphiti_core.helpers import semaphore_gather
|
|
29
29
|
from graphiti_core.llm_client import LLMClient
|
|
30
30
|
from graphiti_core.models.edges.edge_db_queries import (
|
|
31
31
|
ENTITY_EDGE_SAVE_BULK,
|
|
@@ -37,14 +37,17 @@ from graphiti_core.models.nodes.node_db_queries import (
|
|
|
37
37
|
)
|
|
38
38
|
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
|
|
39
39
|
from graphiti_core.search.search_utils import get_relevant_edges, get_relevant_nodes
|
|
40
|
-
from graphiti_core.utils import
|
|
40
|
+
from graphiti_core.utils.datetime_utils import utc_now
|
|
41
41
|
from graphiti_core.utils.maintenance.edge_operations import (
|
|
42
42
|
build_episodic_edges,
|
|
43
43
|
dedupe_edge_list,
|
|
44
44
|
dedupe_extracted_edges,
|
|
45
45
|
extract_edges,
|
|
46
46
|
)
|
|
47
|
-
from graphiti_core.utils.maintenance.graph_data_operations import
|
|
47
|
+
from graphiti_core.utils.maintenance.graph_data_operations import (
|
|
48
|
+
EPISODE_WINDOW_LEN,
|
|
49
|
+
retrieve_episodes,
|
|
50
|
+
)
|
|
48
51
|
from graphiti_core.utils.maintenance.node_operations import (
|
|
49
52
|
dedupe_extracted_nodes,
|
|
50
53
|
dedupe_node_list,
|
|
@@ -68,7 +71,7 @@ class RawEpisode(BaseModel):
|
|
|
68
71
|
async def retrieve_previous_episodes_bulk(
|
|
69
72
|
driver: AsyncDriver, episodes: list[EpisodicNode]
|
|
70
73
|
) -> list[tuple[EpisodicNode, list[EpisodicNode]]]:
|
|
71
|
-
previous_episodes_list = await
|
|
74
|
+
previous_episodes_list = await semaphore_gather(
|
|
72
75
|
*[
|
|
73
76
|
retrieve_episodes(
|
|
74
77
|
driver, episode.valid_at, last_n=EPISODE_WINDOW_LEN, group_ids=[episode.group_id]
|
|
@@ -115,7 +118,7 @@ async def add_nodes_and_edges_bulk_tx(
|
|
|
115
118
|
async def extract_nodes_and_edges_bulk(
|
|
116
119
|
llm_client: LLMClient, episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]]
|
|
117
120
|
) -> tuple[list[EntityNode], list[EntityEdge], list[EpisodicEdge]]:
|
|
118
|
-
extracted_nodes_bulk = await
|
|
121
|
+
extracted_nodes_bulk = await semaphore_gather(
|
|
119
122
|
*[
|
|
120
123
|
extract_nodes(llm_client, episode, previous_episodes)
|
|
121
124
|
for episode, previous_episodes in episode_tuples
|
|
@@ -127,7 +130,7 @@ async def extract_nodes_and_edges_bulk(
|
|
|
127
130
|
[episode[1] for episode in episode_tuples],
|
|
128
131
|
)
|
|
129
132
|
|
|
130
|
-
extracted_edges_bulk = await
|
|
133
|
+
extracted_edges_bulk = await semaphore_gather(
|
|
131
134
|
*[
|
|
132
135
|
extract_edges(
|
|
133
136
|
llm_client,
|
|
@@ -168,13 +171,13 @@ async def dedupe_nodes_bulk(
|
|
|
168
171
|
node_chunks = [nodes[i : i + CHUNK_SIZE] for i in range(0, len(nodes), CHUNK_SIZE)]
|
|
169
172
|
|
|
170
173
|
existing_nodes_chunks: list[list[EntityNode]] = list(
|
|
171
|
-
await
|
|
174
|
+
await semaphore_gather(
|
|
172
175
|
*[get_relevant_nodes(driver, node_chunk) for node_chunk in node_chunks]
|
|
173
176
|
)
|
|
174
177
|
)
|
|
175
178
|
|
|
176
179
|
results: list[tuple[list[EntityNode], dict[str, str]]] = list(
|
|
177
|
-
await
|
|
180
|
+
await semaphore_gather(
|
|
178
181
|
*[
|
|
179
182
|
dedupe_extracted_nodes(llm_client, node_chunk, existing_nodes_chunks[i])
|
|
180
183
|
for i, node_chunk in enumerate(node_chunks)
|
|
@@ -202,13 +205,13 @@ async def dedupe_edges_bulk(
|
|
|
202
205
|
]
|
|
203
206
|
|
|
204
207
|
relevant_edges_chunks: list[list[EntityEdge]] = list(
|
|
205
|
-
await
|
|
208
|
+
await semaphore_gather(
|
|
206
209
|
*[get_relevant_edges(driver, edge_chunk, None, None) for edge_chunk in edge_chunks]
|
|
207
210
|
)
|
|
208
211
|
)
|
|
209
212
|
|
|
210
213
|
resolved_edge_chunks: list[list[EntityEdge]] = list(
|
|
211
|
-
await
|
|
214
|
+
await semaphore_gather(
|
|
212
215
|
*[
|
|
213
216
|
dedupe_extracted_edges(llm_client, edge_chunk, relevant_edges_chunks[i])
|
|
214
217
|
for i, edge_chunk in enumerate(edge_chunks)
|
|
@@ -289,7 +292,9 @@ async def compress_nodes(
|
|
|
289
292
|
# add both nodes to the shortest chunk
|
|
290
293
|
node_chunks[-1].extend([n, m])
|
|
291
294
|
|
|
292
|
-
results = await
|
|
295
|
+
results = await semaphore_gather(
|
|
296
|
+
*[dedupe_node_list(llm_client, chunk) for chunk in node_chunks]
|
|
297
|
+
)
|
|
293
298
|
|
|
294
299
|
extended_map = dict(uuid_map)
|
|
295
300
|
compressed_nodes: list[EntityNode] = []
|
|
@@ -312,7 +317,9 @@ async def compress_edges(llm_client: LLMClient, edges: list[EntityEdge]) -> list
|
|
|
312
317
|
# We build a map of the edges based on their source and target nodes.
|
|
313
318
|
edge_chunks = chunk_edges_by_nodes(edges)
|
|
314
319
|
|
|
315
|
-
results = await
|
|
320
|
+
results = await semaphore_gather(
|
|
321
|
+
*[dedupe_edge_list(llm_client, chunk) for chunk in edge_chunks]
|
|
322
|
+
)
|
|
316
323
|
|
|
317
324
|
compressed_edges: list[EntityEdge] = []
|
|
318
325
|
for edge_chunk in results:
|
|
@@ -365,7 +372,7 @@ async def extract_edge_dates_bulk(
|
|
|
365
372
|
episode.uuid: (episode, previous_episodes) for episode, previous_episodes in episode_pairs
|
|
366
373
|
}
|
|
367
374
|
|
|
368
|
-
results = await
|
|
375
|
+
results = await semaphore_gather(
|
|
369
376
|
*[
|
|
370
377
|
extract_edge_dates(
|
|
371
378
|
llm_client,
|
|
@@ -385,7 +392,7 @@ async def extract_edge_dates_bulk(
|
|
|
385
392
|
edge.valid_at = valid_at
|
|
386
393
|
edge.invalid_at = invalid_at
|
|
387
394
|
if edge.invalid_at:
|
|
388
|
-
edge.expired_at =
|
|
395
|
+
edge.expired_at = utc_now()
|
|
389
396
|
|
|
390
397
|
return edges
|
|
391
398
|
|
|
@@ -0,0 +1,42 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Copyright 2024, Zep Software, Inc.
|
|
3
|
+
|
|
4
|
+
Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
you may not use this file except in compliance with the License.
|
|
6
|
+
You may obtain a copy of the License at
|
|
7
|
+
|
|
8
|
+
http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
|
|
10
|
+
Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
See the License for the specific language governing permissions and
|
|
14
|
+
limitations under the License.
|
|
15
|
+
"""
|
|
16
|
+
|
|
17
|
+
from datetime import datetime, timezone
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def utc_now() -> datetime:
    """Return the current moment as a timezone-aware datetime whose tzinfo is UTC."""
    current = datetime.now(tz=timezone.utc)
    return current
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def ensure_utc(dt: datetime | None) -> datetime | None:
|
|
26
|
+
"""
|
|
27
|
+
Ensures a datetime is timezone-aware and in UTC.
|
|
28
|
+
If the datetime is naive (no timezone), assumes it's in UTC.
|
|
29
|
+
If the datetime has a different timezone, converts it to UTC.
|
|
30
|
+
Returns None if input is None.
|
|
31
|
+
"""
|
|
32
|
+
if dt is None:
|
|
33
|
+
return None
|
|
34
|
+
|
|
35
|
+
if dt.tzinfo is None:
|
|
36
|
+
# If datetime is naive, assume it's UTC
|
|
37
|
+
return dt.replace(tzinfo=timezone.utc)
|
|
38
|
+
elif dt.tzinfo != timezone.utc:
|
|
39
|
+
# If datetime has a different timezone, convert to UTC
|
|
40
|
+
return dt.astimezone(timezone.utc)
|
|
41
|
+
|
|
42
|
+
return dt
|
|
@@ -1,14 +1,13 @@
|
|
|
1
1
|
import asyncio
|
|
2
2
|
import logging
|
|
3
3
|
from collections import defaultdict
|
|
4
|
-
from datetime import datetime, timezone
|
|
5
4
|
|
|
6
5
|
from neo4j import AsyncDriver
|
|
7
6
|
from pydantic import BaseModel
|
|
8
7
|
|
|
9
8
|
from graphiti_core.edges import CommunityEdge
|
|
10
9
|
from graphiti_core.embedder import EmbedderClient
|
|
11
|
-
from graphiti_core.helpers import DEFAULT_DATABASE
|
|
10
|
+
from graphiti_core.helpers import DEFAULT_DATABASE, semaphore_gather
|
|
12
11
|
from graphiti_core.llm_client import LLMClient
|
|
13
12
|
from graphiti_core.nodes import (
|
|
14
13
|
CommunityNode,
|
|
@@ -16,6 +15,8 @@ from graphiti_core.nodes import (
|
|
|
16
15
|
get_community_node_from_record,
|
|
17
16
|
)
|
|
18
17
|
from graphiti_core.prompts import prompt_library
|
|
18
|
+
from graphiti_core.prompts.summarize_nodes import Summary, SummaryDescription
|
|
19
|
+
from graphiti_core.utils.datetime_utils import utc_now
|
|
19
20
|
from graphiti_core.utils.maintenance.edge_operations import build_community_edges
|
|
20
21
|
|
|
21
22
|
MAX_COMMUNITY_BUILD_CONCURRENCY = 10
|
|
@@ -70,7 +71,7 @@ async def get_community_clusters(
|
|
|
70
71
|
|
|
71
72
|
community_clusters.extend(
|
|
72
73
|
list(
|
|
73
|
-
await
|
|
74
|
+
await semaphore_gather(
|
|
74
75
|
*[EntityNode.get_by_uuids(driver, cluster) for cluster in cluster_uuids]
|
|
75
76
|
)
|
|
76
77
|
)
|
|
@@ -131,7 +132,7 @@ async def summarize_pair(llm_client: LLMClient, summary_pair: tuple[str, str]) -
|
|
|
131
132
|
context = {'node_summaries': [{'summary': summary} for summary in summary_pair]}
|
|
132
133
|
|
|
133
134
|
llm_response = await llm_client.generate_response(
|
|
134
|
-
prompt_library.summarize_nodes.summarize_pair(context)
|
|
135
|
+
prompt_library.summarize_nodes.summarize_pair(context), response_model=Summary
|
|
135
136
|
)
|
|
136
137
|
|
|
137
138
|
pair_summary = llm_response.get('summary', '')
|
|
@@ -143,7 +144,8 @@ async def generate_summary_description(llm_client: LLMClient, summary: str) -> s
|
|
|
143
144
|
context = {'summary': summary}
|
|
144
145
|
|
|
145
146
|
llm_response = await llm_client.generate_response(
|
|
146
|
-
prompt_library.summarize_nodes.summary_description(context)
|
|
147
|
+
prompt_library.summarize_nodes.summary_description(context),
|
|
148
|
+
response_model=SummaryDescription,
|
|
147
149
|
)
|
|
148
150
|
|
|
149
151
|
description = llm_response.get('description', '')
|
|
@@ -162,7 +164,7 @@ async def build_community(
|
|
|
162
164
|
odd_one_out = summaries.pop()
|
|
163
165
|
length -= 1
|
|
164
166
|
new_summaries: list[str] = list(
|
|
165
|
-
await
|
|
167
|
+
await semaphore_gather(
|
|
166
168
|
*[
|
|
167
169
|
summarize_pair(llm_client, (str(left_summary), str(right_summary)))
|
|
168
170
|
for left_summary, right_summary in zip(
|
|
@@ -178,7 +180,7 @@ async def build_community(
|
|
|
178
180
|
|
|
179
181
|
summary = summaries[0]
|
|
180
182
|
name = await generate_summary_description(llm_client, summary)
|
|
181
|
-
now =
|
|
183
|
+
now = utc_now()
|
|
182
184
|
community_node = CommunityNode(
|
|
183
185
|
name=name,
|
|
184
186
|
group_id=community_cluster[0].group_id,
|
|
@@ -205,7 +207,9 @@ async def build_communities(
|
|
|
205
207
|
return await build_community(llm_client, cluster)
|
|
206
208
|
|
|
207
209
|
communities: list[tuple[CommunityNode, list[CommunityEdge]]] = list(
|
|
208
|
-
await
|
|
210
|
+
await semaphore_gather(
|
|
211
|
+
*[limited_build_community(cluster) for cluster in community_clusters]
|
|
212
|
+
)
|
|
209
213
|
)
|
|
210
214
|
|
|
211
215
|
community_nodes: list[CommunityNode] = []
|
|
@@ -305,7 +309,7 @@ async def update_community(
|
|
|
305
309
|
community.name = new_name
|
|
306
310
|
|
|
307
311
|
if is_new:
|
|
308
|
-
community_edge = (build_community_edges([entity], community,
|
|
312
|
+
community_edge = (build_community_edges([entity], community, utc_now()))[0]
|
|
309
313
|
await community_edge.save(driver)
|
|
310
314
|
|
|
311
315
|
await community.generate_name_embedding(embedder)
|