graphiti-core 0.3.8__py3-none-any.whl → 0.3.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of graphiti-core might be problematic.
- graphiti_core/edges.py +8 -8
- graphiti_core/errors.py +8 -0
- graphiti_core/graphiti.py +44 -24
- graphiti_core/helpers.py +15 -1
- graphiti_core/nodes.py +16 -8
- graphiti_core/prompts/eval.py +28 -2
- graphiti_core/prompts/extract_edge_dates.py +8 -9
- graphiti_core/prompts/extract_edges.py +3 -2
- graphiti_core/prompts/invalidate_edges.py +1 -1
- graphiti_core/search/search.py +62 -46
- graphiti_core/search/search_config.py +13 -3
- graphiti_core/search/search_config_recipes.py +42 -1
- graphiti_core/search/search_utils.py +53 -13
- graphiti_core/utils/maintenance/__init__.py +0 -2
- graphiti_core/utils/maintenance/community_operations.py +14 -26
- graphiti_core/utils/maintenance/edge_operations.py +7 -13
- graphiti_core/utils/maintenance/node_operations.py +5 -5
- graphiti_core/utils/maintenance/temporal_operations.py +4 -126
- {graphiti_core-0.3.8.dist-info → graphiti_core-0.3.11.dist-info}/METADATA +2 -1
- {graphiti_core-0.3.8.dist-info → graphiti_core-0.3.11.dist-info}/RECORD +22 -22
- {graphiti_core-0.3.8.dist-info → graphiti_core-0.3.11.dist-info}/LICENSE +0 -0
- {graphiti_core-0.3.8.dist-info → graphiti_core-0.3.11.dist-info}/WHEEL +0 -0
graphiti_core/search/search_config.py

@@ -20,6 +20,7 @@ from pydantic import BaseModel, Field
 
 from graphiti_core.edges import EntityEdge
 from graphiti_core.nodes import CommunityNode, EntityNode
+from graphiti_core.search.search_utils import DEFAULT_MIN_SCORE, DEFAULT_MMR_LAMBDA
 
 DEFAULT_SEARCH_LIMIT = 10
 
@@ -43,31 +44,40 @@ class EdgeReranker(Enum):
     rrf = 'reciprocal_rank_fusion'
     node_distance = 'node_distance'
     episode_mentions = 'episode_mentions'
+    mmr = 'mmr'
 
 
 class NodeReranker(Enum):
     rrf = 'reciprocal_rank_fusion'
     node_distance = 'node_distance'
     episode_mentions = 'episode_mentions'
+    mmr = 'mmr'
 
 
 class CommunityReranker(Enum):
     rrf = 'reciprocal_rank_fusion'
+    mmr = 'mmr'
 
 
 class EdgeSearchConfig(BaseModel):
     search_methods: list[EdgeSearchMethod]
-    reranker: EdgeReranker
+    reranker: EdgeReranker = Field(default=EdgeReranker.rrf)
+    sim_min_score: float = Field(default=DEFAULT_MIN_SCORE)
+    mmr_lambda: float = Field(default=DEFAULT_MMR_LAMBDA)
 
 
 class NodeSearchConfig(BaseModel):
     search_methods: list[NodeSearchMethod]
-    reranker: NodeReranker
+    reranker: NodeReranker = Field(default=NodeReranker.rrf)
+    sim_min_score: float = Field(default=DEFAULT_MIN_SCORE)
+    mmr_lambda: float = Field(default=DEFAULT_MMR_LAMBDA)
 
 
 class CommunitySearchConfig(BaseModel):
     search_methods: list[CommunitySearchMethod]
-    reranker: CommunityReranker
+    reranker: CommunityReranker = Field(default=CommunityReranker.rrf)
+    sim_min_score: float = Field(default=DEFAULT_MIN_SCORE)
+    mmr_lambda: float = Field(default=DEFAULT_MMR_LAMBDA)
 
 
 class SearchConfig(BaseModel):
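With Field defaults on every config class, callers only have to name the search methods; the reranker and the new knobs fall back to rrf, DEFAULT_MIN_SCORE (0.6), and DEFAULT_MMR_LAMBDA (0.5). A minimal sketch, assuming EdgeSearchMethod lives alongside these classes as the recipes' usage suggests:

```python
from graphiti_core.search.search_config import (
    EdgeReranker,
    EdgeSearchConfig,
    EdgeSearchMethod,
)

# only search_methods is required; everything else now has a default
config = EdgeSearchConfig(search_methods=[EdgeSearchMethod.cosine_similarity])

assert config.reranker == EdgeReranker.rrf  # default reranker
assert config.sim_min_score == 0.6          # DEFAULT_MIN_SCORE
assert config.mmr_lambda == 0.5             # DEFAULT_MMR_LAMBDA
```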
graphiti_core/search/search_config_recipes.py

@@ -43,6 +43,22 @@ COMBINED_HYBRID_SEARCH_RRF = SearchConfig(
     ),
 )
 
+# Performs a hybrid search with mmr reranking over edges, nodes, and communities
+COMBINED_HYBRID_SEARCH_MMR = SearchConfig(
+    edge_config=EdgeSearchConfig(
+        search_methods=[EdgeSearchMethod.bm25, EdgeSearchMethod.cosine_similarity],
+        reranker=EdgeReranker.mmr,
+    ),
+    node_config=NodeSearchConfig(
+        search_methods=[NodeSearchMethod.bm25, NodeSearchMethod.cosine_similarity],
+        reranker=NodeReranker.mmr,
+    ),
+    community_config=CommunitySearchConfig(
+        search_methods=[CommunitySearchMethod.bm25, CommunitySearchMethod.cosine_similarity],
+        reranker=CommunityReranker.mmr,
+    ),
+)
+
 # performs a hybrid search over edges with rrf reranking
 EDGE_HYBRID_SEARCH_RRF = SearchConfig(
     edge_config=EdgeSearchConfig(
@@ -51,12 +67,21 @@ EDGE_HYBRID_SEARCH_RRF = SearchConfig(
     )
 )
 
+# performs a hybrid search over edges with mmr reranking
+EDGE_HYBRID_SEARCH_MMR = SearchConfig(
+    edge_config=EdgeSearchConfig(
+        search_methods=[EdgeSearchMethod.bm25, EdgeSearchMethod.cosine_similarity],
+        reranker=EdgeReranker.mmr,
+    )
+)
+
 # performs a hybrid search over edges with node distance reranking
 EDGE_HYBRID_SEARCH_NODE_DISTANCE = SearchConfig(
     edge_config=EdgeSearchConfig(
         search_methods=[EdgeSearchMethod.bm25, EdgeSearchMethod.cosine_similarity],
         reranker=EdgeReranker.node_distance,
-    )
+    ),
+    limit=30,
 )
 
 # performs a hybrid search over edges with episode mention reranking
@@ -75,6 +100,14 @@ NODE_HYBRID_SEARCH_RRF = SearchConfig(
     )
 )
 
+# performs a hybrid search over nodes with mmr reranking
+NODE_HYBRID_SEARCH_MMR = SearchConfig(
+    node_config=NodeSearchConfig(
+        search_methods=[NodeSearchMethod.bm25, NodeSearchMethod.cosine_similarity],
+        reranker=NodeReranker.mmr,
+    )
+)
+
 # performs a hybrid search over nodes with node distance reranking
 NODE_HYBRID_SEARCH_NODE_DISTANCE = SearchConfig(
     node_config=NodeSearchConfig(
@@ -98,3 +131,11 @@ COMMUNITY_HYBRID_SEARCH_RRF = SearchConfig(
         reranker=CommunityReranker.rrf,
     )
 )
+
+# performs a hybrid search over communities with mmr reranking
+COMMUNITY_HYBRID_SEARCH_MMR = SearchConfig(
+    community_config=CommunitySearchConfig(
+        search_methods=[CommunitySearchMethod.bm25, CommunitySearchMethod.cosine_similarity],
+        reranker=CommunityReranker.mmr,
+    )
+)
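The new *_MMR recipes are ordinary SearchConfig pydantic models, so a caller can deep-copy one and tune the new fields per query. A sketch (model_copy is pydantic v2; adjust if your environment pins v1):

```python
from graphiti_core.search.search_config_recipes import EDGE_HYBRID_SEARCH_MMR

# copy the recipe rather than mutating the shared module-level constant
config = EDGE_HYBRID_SEARCH_MMR.model_copy(deep=True)
config.edge_config.mmr_lambda = 0.8      # lean toward relevance over diversity
config.edge_config.sim_min_score = 0.7   # drop weaker cosine matches
```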
graphiti_core/search/search_utils.py

@@ -19,10 +19,11 @@ import logging
 from collections import defaultdict
 from time import time
 
+import numpy as np
 from neo4j import AsyncDriver, Query
 
 from graphiti_core.edges import EntityEdge, get_entity_edge_from_record
-from graphiti_core.helpers import lucene_sanitize
+from graphiti_core.helpers import lucene_sanitize, normalize_l2
 from graphiti_core.nodes import (
     CommunityNode,
     EntityNode,
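The new import pulls normalize_l2 from graphiti_core/helpers.py, whose hunk (+15 -1 in the file list above) isn't reproduced in this section. A plausible implementation, consistent with how maximal_marginal_relevance uses it further down — a sketch, not necessarily the shipped code:

```python
import numpy as np


def normalize_l2(embedding: list[float]) -> np.ndarray:
    # Scale a vector to unit L2 norm so dot products act as cosine similarity;
    # leave the vector untouched if its norm is zero.
    array = np.array(embedding)
    norm = np.linalg.norm(array, 2)
    return array if norm == 0 else array / norm
```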
@@ -34,6 +35,8 @@ from graphiti_core.nodes import (
 logger = logging.getLogger(__name__)
 
 RELEVANT_SCHEMA_LIMIT = 3
+DEFAULT_MIN_SCORE = 0.6
+DEFAULT_MMR_LAMBDA = 0.5
 
 
 def fulltext_query(query: str, group_ids: list[str] | None = None):
@@ -52,6 +55,21 @@ def fulltext_query(query: str, group_ids: list[str] | None = None):
     return full_query
 
 
+async def get_episodes_by_mentions(
+    driver: AsyncDriver,
+    nodes: list[EntityNode],
+    edges: list[EntityEdge],
+    limit: int = RELEVANT_SCHEMA_LIMIT,
+) -> list[EpisodicNode]:
+    episode_uuids: list[str] = []
+    for edge in edges:
+        episode_uuids.extend(edge.episodes)
+
+    episodes = await EpisodicNode.get_by_uuids(driver, episode_uuids[:limit])
+
+    return episodes
+
+
 async def get_mentioned_nodes(
     driver: AsyncDriver, episodes: list[EpisodicNode]
 ) -> list[EntityNode]:
@@ -113,9 +131,6 @@ async def edge_fulltext_search(
     CALL db.index.fulltext.queryRelationships("edge_name_and_fact", $query)
     YIELD relationship AS rel, score
     MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity)
-    WHERE ($source_uuid IS NULL OR n.uuid = $source_uuid)
-    AND ($target_uuid IS NULL OR m.uuid = $target_uuid)
-    AND ($group_ids IS NULL OR n.group_id IN $group_ids)
     RETURN
         r.uuid AS uuid,
         r.group_id AS group_id,
@@ -153,15 +168,18 @@ async def edge_similarity_search(
     target_node_uuid: str | None,
     group_ids: list[str] | None = None,
     limit: int = RELEVANT_SCHEMA_LIMIT,
+    min_score: float = DEFAULT_MIN_SCORE,
 ) -> list[EntityEdge]:
     # vector similarity search over embedded facts
     query = Query("""
+        CYPHER runtime = parallel parallelRuntimeSupport=all
         MATCH (n:Entity)-[r:RELATES_TO]-(m:Entity)
         WHERE ($group_ids IS NULL OR r.group_id IN $group_ids)
         AND ($source_uuid IS NULL OR n.uuid = $source_uuid)
         AND ($target_uuid IS NULL OR m.uuid = $target_uuid)
+        WITH n, r, m, vector.similarity.cosine(r.fact_embedding, $search_vector) AS score
+        WHERE score > $min_score
         RETURN
-            vector.similarity.cosine(r.fact_embedding, $search_vector) AS score,
             r.uuid AS uuid,
             r.group_id AS group_id,
             n.uuid AS source_node_uuid,
@@ -185,6 +203,7 @@ async def edge_similarity_search(
         target_uuid=target_node_uuid,
         group_ids=group_ids,
         limit=limit,
+        min_score=min_score,
     )
 
     edges = [get_entity_edge_from_record(record) for record in records]
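Each similarity search now computes the cosine score in a WITH clause and filters on $min_score before RETURN, so low-scoring rows never leave the database; the CYPHER runtime directive additionally opts into Neo4j's parallel runtime. Callers can tighten the threshold per query. A sketch (the leading parameters of edge_similarity_search are assumed from context, and query_embedding is produced elsewhere):

```python
# inside an async context
edges = await edge_similarity_search(
    driver,
    query_embedding,        # assumed: the embedded search text
    source_node_uuid=None,  # no source/target pinning
    target_node_uuid=None,
    group_ids=['group-1'],
    limit=10,
    min_score=0.75,         # stricter than DEFAULT_MIN_SCORE (0.6)
)
```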
@@ -205,7 +224,6 @@ async def node_fulltext_search(
     """
     CALL db.index.fulltext.queryNodes("node_name_and_summary", $query)
     YIELD node AS n, score
-    WHERE $group_ids IS NULL OR n.group_id IN $group_ids
     RETURN
         n.uuid AS uuid,
         n.group_id AS group_id,
@@ -230,14 +248,17 @@ async def node_similarity_search(
     search_vector: list[float],
     group_ids: list[str] | None = None,
     limit=RELEVANT_SCHEMA_LIMIT,
+    min_score: float = DEFAULT_MIN_SCORE,
 ) -> list[EntityNode]:
     # vector similarity search over entity names
     records, _, _ = await driver.execute_query(
         """
+        CYPHER runtime = parallel parallelRuntimeSupport=all
         MATCH (n:Entity)
         WHERE $group_ids IS NULL OR n.group_id IN $group_ids
+        WITH n, vector.similarity.cosine(n.name_embedding, $search_vector) AS score
+        WHERE score > $min_score
         RETURN
-            vector.similarity.cosine(n.name_embedding, $search_vector) AS score,
             n.uuid As uuid,
             n.group_id AS group_id,
             n.name AS name,
@@ -250,6 +271,7 @@ async def node_similarity_search(
         search_vector=search_vector,
         group_ids=group_ids,
         limit=limit,
+        min_score=min_score,
     )
     nodes = [get_entity_node_from_record(record) for record in records]
 
@@ -269,8 +291,6 @@ async def community_fulltext_search(
     """
     CALL db.index.fulltext.queryNodes("community_name", $query)
     YIELD node AS comm, score
-    MATCH (comm:Community)
-    WHERE $group_ids IS NULL OR comm.group_id in $group_ids
     RETURN
         comm.uuid AS uuid,
         comm.group_id AS group_id,
@@ -295,14 +315,17 @@ async def community_similarity_search(
     search_vector: list[float],
     group_ids: list[str] | None = None,
     limit=RELEVANT_SCHEMA_LIMIT,
+    min_score=DEFAULT_MIN_SCORE,
 ) -> list[CommunityNode]:
     # vector similarity search over entity names
     records, _, _ = await driver.execute_query(
         """
+        CYPHER runtime = parallel parallelRuntimeSupport=all
         MATCH (comm:Community)
         WHERE ($group_ids IS NULL OR comm.group_id IN $group_ids)
+        WITH comm, vector.similarity.cosine(comm.name_embedding, $search_vector) AS score
+        WHERE score > $min_score
         RETURN
-            vector.similarity.cosine(comm.name_embedding, $search_vector) AS score,
             comm.uuid As uuid,
             comm.group_id AS group_id,
             comm.name AS name,
@@ -315,6 +338,7 @@ async def community_similarity_search(
         search_vector=search_vector,
         group_ids=group_ids,
         limit=limit,
+        min_score=min_score,
     )
     communities = [get_community_node_from_record(record) for record in records]
 
@@ -384,7 +408,7 @@ async def hybrid_node_search(
     relevant_nodes: list[EntityNode] = [node_uuid_map[uuid] for uuid in ranked_uuids]
 
     end = time()
-    logger.info(f'Found relevant nodes: {ranked_uuids} in {(end - start) * 1000} ms')
+    logger.debug(f'Found relevant nodes: {ranked_uuids} in {(end - start) * 1000} ms')
     return relevant_nodes
 
 
@@ -467,7 +491,7 @@ async def get_relevant_edges(
             relevant_edges.append(edge)
 
     end = time()
-    logger.info(f'Found relevant edges: {relevant_edge_uuids} in {(end - start) * 1000} ms')
+    logger.debug(f'Found relevant edges: {relevant_edge_uuids} in {(end - start) * 1000} ms')
 
     return relevant_edges
 
@@ -520,7 +544,7 @@ async def node_distance_reranker(
     # rerank on shortest distance
     filtered_uuids.sort(key=lambda cur_uuid: scores[cur_uuid])
 
-    # add back in filtered center
+    # add back in filtered center uuid
     filtered_uuids = [center_node_uuid] + filtered_uuids
 
     return filtered_uuids
@@ -555,3 +579,19 @@ async def episode_mentions_reranker(driver: AsyncDriver, node_uuids: list[list[s
     sorted_uuids.sort(key=lambda cur_uuid: scores[cur_uuid])
 
     return sorted_uuids
+
+
+def maximal_marginal_relevance(
+    query_vector: list[float],
+    candidates: list[tuple[str, list[float]]],
+    mmr_lambda: float = DEFAULT_MMR_LAMBDA,
+):
+    candidates_with_mmr: list[tuple[str, float]] = []
+    for candidate in candidates:
+        max_sim = max([np.dot(normalize_l2(candidate[1]), normalize_l2(c[1])) for c in candidates])
+        mmr = mmr_lambda * np.dot(candidate[1], query_vector) - (1 - mmr_lambda) * max_sim
+        candidates_with_mmr.append((candidate[0], mmr))
+
+    candidates_with_mmr.sort(reverse=True, key=lambda c: c[1])
+
+    return [candidate[0] for candidate in candidates_with_mmr]
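The new maximal_marginal_relevance scores every candidate in a single pass: relevance to the query minus the candidate's highest similarity to any other candidate, weighted by mmr_lambda. For comparison, the classic greedy MMR loop penalizes similarity only to the items already selected; this standalone sketch is illustrative and is not the library's implementation:

```python
import numpy as np


def normalize(v: np.ndarray) -> np.ndarray:
    n = np.linalg.norm(v)
    return v / n if n > 0 else v


def mmr_rank(
    query: np.ndarray, candidates: dict[str, np.ndarray], mmr_lambda: float = 0.5
) -> list[str]:
    """Greedy MMR: repeatedly pick the candidate maximizing
    mmr_lambda * sim(c, query) - (1 - mmr_lambda) * max(sim(c, s) for selected s)."""
    q = normalize(query)
    unit = {uuid: normalize(v) for uuid, v in candidates.items()}
    selected: list[str] = []
    while len(selected) < len(candidates):
        remaining = [u for u in unit if u not in selected]

        def mmr(u: str) -> float:
            relevance = float(unit[u] @ q)
            redundancy = max((float(unit[u] @ unit[s]) for s in selected), default=0.0)
            return mmr_lambda * relevance - (1 - mmr_lambda) * redundancy

        selected.append(max(remaining, key=mmr))
    return selected


# two near-duplicates ('a', 'b') and one distinct vector ('c'): with a
# diversity-leaning lambda, the duplicate is demoted below the distinct item
print(mmr_rank(
    np.array([1.0, 0.0]),
    {
        'a': np.array([0.99, 0.1]),
        'b': np.array([0.98, 0.12]),
        'c': np.array([0.5, 0.8]),
    },
    mmr_lambda=0.3,
))  # ['a', 'c', 'b']
```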
graphiti_core/utils/maintenance/__init__.py

@@ -4,7 +4,6 @@ from .graph_data_operations import (
     retrieve_episodes,
 )
 from .node_operations import extract_nodes
-from .temporal_operations import invalidate_edges
 
 __all__ = [
     'extract_edges',
@@ -12,5 +11,4 @@ __all__ = [
     'extract_nodes',
     'clear_data',
     'retrieve_episodes',
-    'invalidate_edges',
 ]
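With invalidate_edges dropped from both the re-export and __all__, code that imported it from the maintenance package breaks on upgrade; the function itself is deleted from temporal_operations.py in the last diff below, where get_edge_contradictions appears to remain the surviving contradiction path:

```python
# 0.3.8: worked; 0.3.11: raises ImportError
from graphiti_core.utils.maintenance import invalidate_edges
```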
graphiti_core/utils/maintenance/community_operations.py

@@ -15,7 +15,6 @@ from graphiti_core.utils.maintenance.edge_operations import build_community_edge
 
 MAX_COMMUNITY_BUILD_CONCURRENCY = 10
 
-
 logger = logging.getLogger(__name__)
 
 
@@ -24,31 +23,20 @@ class Neighbor(BaseModel):
     edge_count: int
 
 
-async def build_community_projection(driver: AsyncDriver) -> str:
-    records, _, _ = await driver.execute_query("""
-    CALL gds.graph.project("communities", "Entity",
-        {RELATES_TO: {
-            type: "RELATES_TO",
-            orientation: "UNDIRECTED",
-            properties: {weight: {property: "*", aggregation: "COUNT"}}
-        }}
-    )
-    YIELD graphName AS graph, nodeProjection AS nodes, relationshipProjection AS edges
-    """)
-
-    return records[0]['graph']
-
-
-async def get_community_clusters(driver: AsyncDriver) -> list[list[EntityNode]]:
+async def get_community_clusters(
+    driver: AsyncDriver, group_ids: list[str] | None
+) -> list[list[EntityNode]]:
     community_clusters: list[list[EntityNode]] = []
 
-    group_id_values, _, _ = await driver.execute_query("""
-    MATCH (n:Entity WHERE n.group_id IS NOT NULL)
-    RETURN
-        collect(DISTINCT n.group_id) AS group_ids
-    """)
+    if group_ids is None:
+        group_id_values, _, _ = await driver.execute_query("""
+        MATCH (n:Entity WHERE n.group_id IS NOT NULL)
+        RETURN
+            collect(DISTINCT n.group_id) AS group_ids
+        """)
+
+        group_ids = group_id_values[0]['group_ids']
 
-    group_ids = group_id_values[0]['group_ids']
     for group_id in group_ids:
         projection: dict[str, list[Neighbor]] = {}
         nodes = await EntityNode.get_by_group_ids(driver, [group_id])
@@ -191,15 +179,15 @@ async def build_community(
     )
     community_edges = build_community_edges(community_cluster, community_node, now)
 
-    logger.info((community_node, community_edges))
+    logger.debug((community_node, community_edges))
 
     return community_node, community_edges
 
 
 async def build_communities(
-    driver: AsyncDriver, llm_client: LLMClient
+    driver: AsyncDriver, llm_client: LLMClient, group_ids: list[str] | None
 ) -> tuple[list[CommunityNode], list[CommunityEdge]]:
-    community_clusters = await get_community_clusters(driver)
+    community_clusters = await get_community_clusters(driver, group_ids)
 
     semaphore = asyncio.Semaphore(MAX_COMMUNITY_BUILD_CONCURRENCY)
 
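build_communities is now group-scoped: pass a list of group_ids to cluster only those partitions, or None to fall back to every group_id present on Entity nodes. A sketch of the call (driver and llm_client setup omitted):

```python
# inside an async context, with a neo4j AsyncDriver and an LLMClient in hand
community_nodes, community_edges = await build_communities(
    driver,
    llm_client,
    group_ids=['tenant-a'],  # or None to cover all groups
)
```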
graphiti_core/utils/maintenance/edge_operations.py

@@ -97,7 +97,7 @@ async def extract_edges(
     edges_data = llm_response.get('edges', [])
 
     end = time()
-    logger.info(f'Extracted new edges: {edges_data} in {(end - start) * 1000} ms')
+    logger.debug(f'Extracted new edges: {edges_data} in {(end - start) * 1000} ms')
 
     # Convert the extracted data into EntityEdge objects
     edges = []
@@ -115,19 +115,13 @@ async def extract_edges(
             invalid_at=None,
         )
         edges.append(edge)
-        logger.info(
+        logger.debug(
             f'Created new edge: {edge.name} from (UUID: {edge.source_node_uuid}) to (UUID: {edge.target_node_uuid})'
         )
 
     return edges
 
 
-def create_edge_identifier(
-    source_node: EntityNode, edge: EntityEdge, target_node: EntityNode
-) -> str:
-    return f'{source_node.name}-{edge.name}-{target_node.name}'
-
-
 async def dedupe_extracted_edges(
     llm_client: LLMClient,
     extracted_edges: list[EntityEdge],
@@ -150,7 +144,7 @@ async def dedupe_extracted_edges(
 
     llm_response = await llm_client.generate_response(prompt_library.dedupe_edges.v1(context))
     duplicate_data = llm_response.get('duplicates', [])
-    logger.info(f'Extracted unique edges: {duplicate_data}')
+    logger.debug(f'Extracted unique edges: {duplicate_data}')
 
     duplicate_uuid_map: dict[str, str] = {}
     for duplicate in duplicate_data:
@@ -251,11 +245,11 @@ async def resolve_extracted_edge(
         if (
             edge.invalid_at is not None
             and resolved_edge.valid_at is not None
-            and edge.invalid_at < resolved_edge.valid_at
+            and edge.invalid_at <= resolved_edge.valid_at
         ) or (
             edge.valid_at is not None
             and resolved_edge.invalid_at is not None
-            and resolved_edge.invalid_at < edge.valid_at
+            and resolved_edge.invalid_at <= edge.valid_at
         ):
             continue
         # New edge invalidates edge
@@ -305,7 +299,7 @@ async def dedupe_extracted_edge(
         edge = existing_edge
 
     end = time()
-    logger.info(
+    logger.debug(
         f'Resolved Edge: {extracted_edge.name} is {edge.name}, in {(end - start) * 1000} ms'
     )
 
@@ -332,7 +326,7 @@ async def dedupe_edge_list(
     unique_edges_data = llm_response.get('unique_facts', [])
 
     end = time()
-    logger.info(f'Extracted edge duplicates: {unique_edges_data} in {(end - start) * 1000} ms ')
+    logger.debug(f'Extracted edge duplicates: {unique_edges_data} in {(end - start) * 1000} ms ')
 
     # Get full edge data
     unique_edges = []
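Assuming the pre-change comparison used a strict '<', as the truncated diff suggests, the tightened guard in resolve_extracted_edge now also skips the invalidation check when the timestamps are exactly equal: an existing fact that stopped being valid at or before the moment the new fact became valid cannot contradict it. In concrete terms:

```python
from datetime import datetime, timezone

# hypothetical timestamps for an existing edge and a newly resolved edge
edge_invalid_at = datetime(2023, 1, 1, tzinfo=timezone.utc)
resolved_valid_at = datetime(2023, 1, 1, tzinfo=timezone.utc)

# with '<' these equal timestamps fell through to the invalidation logic;
# with '<=' the pair is skipped (the loop continues)
if edge_invalid_at <= resolved_valid_at:
    print('skip: the old fact ended before the new one began')
```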
graphiti_core/utils/maintenance/node_operations.py

@@ -104,7 +104,7 @@ async def extract_nodes(
         extracted_node_data = await extract_json_nodes(llm_client, episode)
 
     end = time()
-    logger.info(f'Extracted new nodes: {extracted_node_data} in {(end - start) * 1000} ms')
+    logger.debug(f'Extracted new nodes: {extracted_node_data} in {(end - start) * 1000} ms')
     # Convert the extracted data into EntityNode objects
     new_nodes = []
     for node_data in extracted_node_data:
@@ -116,7 +116,7 @@ async def extract_nodes(
             created_at=datetime.now(),
         )
         new_nodes.append(new_node)
-        logger.info(f'Created new node: {new_node.name} (UUID: {new_node.uuid})')
+        logger.debug(f'Created new node: {new_node.name} (UUID: {new_node.uuid})')
 
     return new_nodes
 
@@ -152,7 +152,7 @@ async def dedupe_extracted_nodes(
     duplicate_data = llm_response.get('duplicates', [])
 
     end = time()
-    logger.info(f'Deduplicated nodes: {duplicate_data} in {(end - start) * 1000} ms')
+    logger.debug(f'Deduplicated nodes: {duplicate_data} in {(end - start) * 1000} ms')
 
     uuid_map: dict[str, str] = {}
     for duplicate in duplicate_data:
@@ -232,7 +232,7 @@ async def resolve_extracted_node(
         uuid_map[extracted_node.uuid] = existing_node.uuid
 
     end = time()
-    logger.info(
+    logger.debug(
         f'Resolved node: {extracted_node.name} is {node.name}, in {(end - start) * 1000} ms'
     )
 
@@ -266,7 +266,7 @@ async def dedupe_node_list(
     nodes_data = llm_response.get('nodes', [])
 
    end = time()
-    logger.info(f'Deduplicated nodes: {nodes_data} in {(end - start) * 1000} ms')
+    logger.debug(f'Deduplicated nodes: {nodes_data} in {(end - start) * 1000} ms')
 
     # Get full node data
     unique_nodes = []
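Most of the changes in node_operations.py (and edge_operations.py above) demote per-item logging from info to debug. To keep seeing these messages after upgrading, raise the verbosity of the package's parent logger:

```python
import logging

# opt back in to the per-node/per-edge messages that moved to DEBUG
logging.getLogger('graphiti_core').setLevel(logging.DEBUG)
```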
graphiti_core/utils/maintenance/temporal_operations.py

@@ -21,129 +21,11 @@ from typing import List
 
 from graphiti_core.edges import EntityEdge
 from graphiti_core.llm_client import LLMClient
-from graphiti_core.nodes import EntityNode, EpisodicNode
+from graphiti_core.nodes import EpisodicNode
 from graphiti_core.prompts import prompt_library
 
 logger = logging.getLogger(__name__)
 
-NodeEdgeNodeTriplet = tuple[EntityNode, EntityEdge, EntityNode]
-
-
-def extract_node_and_edge_triplets(
-    edges: list[EntityEdge], nodes: list[EntityNode]
-) -> list[NodeEdgeNodeTriplet]:
-    return [extract_node_edge_node_triplet(edge, nodes) for edge in edges]
-
-
-def extract_node_edge_node_triplet(
-    edge: EntityEdge, nodes: list[EntityNode]
-) -> NodeEdgeNodeTriplet:
-    source_node = next((node for node in nodes if node.uuid == edge.source_node_uuid), None)
-    target_node = next((node for node in nodes if node.uuid == edge.target_node_uuid), None)
-    if not source_node or not target_node:
-        raise ValueError(f'Source or target node not found for edge {edge.uuid}')
-    return (source_node, edge, target_node)
-
-
-def prepare_edges_for_invalidation(
-    existing_edges: list[EntityEdge],
-    new_edges: list[EntityEdge],
-    nodes: list[EntityNode],
-) -> tuple[list[NodeEdgeNodeTriplet], list[NodeEdgeNodeTriplet]]:
-    existing_edges_pending_invalidation: list[NodeEdgeNodeTriplet] = []
-    new_edges_with_nodes: list[NodeEdgeNodeTriplet] = []
-
-    for edge_list, result_list in [
-        (existing_edges, existing_edges_pending_invalidation),
-        (new_edges, new_edges_with_nodes),
-    ]:
-        for edge in edge_list:
-            source_node = next((node for node in nodes if node.uuid == edge.source_node_uuid), None)
-            target_node = next((node for node in nodes if node.uuid == edge.target_node_uuid), None)
-
-            if source_node and target_node:
-                result_list.append((source_node, edge, target_node))
-
-    return existing_edges_pending_invalidation, new_edges_with_nodes
-
-
-async def invalidate_edges(
-    llm_client: LLMClient,
-    existing_edges_pending_invalidation: list[NodeEdgeNodeTriplet],
-    new_edges: list[NodeEdgeNodeTriplet],
-    current_episode: EpisodicNode,
-    previous_episodes: list[EpisodicNode],
-) -> list[EntityEdge]:
-    invalidated_edges = []  # TODO: this is not yet used?
-
-    context = prepare_invalidation_context(
-        existing_edges_pending_invalidation,
-        new_edges,
-        current_episode,
-        previous_episodes,
-    )
-    llm_response = await llm_client.generate_response(prompt_library.invalidate_edges.v1(context))
-
-    edges_to_invalidate = llm_response.get('invalidated_edges', [])
-    invalidated_edges = process_edge_invalidation_llm_response(
-        edges_to_invalidate, existing_edges_pending_invalidation
-    )
-
-    return invalidated_edges
-
-
-def extract_date_strings_from_edge(edge: EntityEdge) -> str:
-    start = edge.valid_at
-    end = edge.invalid_at
-    date_string = f'Start Date: {start.isoformat()}' if start else ''
-    if end:
-        date_string += f' (End Date: {end.isoformat()})'
-    return date_string
-
-
-def prepare_invalidation_context(
-    existing_edges: list[NodeEdgeNodeTriplet],
-    new_edges: list[NodeEdgeNodeTriplet],
-    current_episode: EpisodicNode,
-    previous_episodes: list[EpisodicNode],
-) -> dict:
-    return {
-        'existing_edges': [
-            f'{edge.uuid} | {source_node.name} - {edge.name} - {target_node.name} (Fact: {edge.fact}) {extract_date_strings_from_edge(edge)}'
-            for source_node, edge, target_node in sorted(
-                existing_edges, key=lambda x: (x[1].created_at), reverse=True
-            )
-        ],
-        'new_edges': [
-            f'{edge.uuid} | {source_node.name} - {edge.name} - {target_node.name} (Fact: {edge.fact}) {extract_date_strings_from_edge(edge)}'
-            for source_node, edge, target_node in sorted(
-                new_edges, key=lambda x: (x[1].created_at), reverse=True
-            )
-        ],
-        'current_episode': current_episode.content,
-        'previous_episodes': [episode.content for episode in previous_episodes],
-    }
-
-
-def process_edge_invalidation_llm_response(
-    edges_to_invalidate: List[dict], existing_edges: List[NodeEdgeNodeTriplet]
-) -> List[EntityEdge]:
-    invalidated_edges = []
-    for edge_to_invalidate in edges_to_invalidate:
-        edge_uuid = edge_to_invalidate['edge_uuid']
-        edge_to_update = next(
-            (edge for _, edge, _ in existing_edges if edge.uuid == edge_uuid),
-            None,
-        )
-        if edge_to_update:
-            edge_to_update.expired_at = datetime.now()
-            edge_to_update.fact = edge_to_invalidate['fact']
-            invalidated_edges.append(edge_to_update)
-            logger.info(
-                f"Invalidated edge: {edge_to_update.name} (UUID: {edge_to_update.uuid}). Updated Fact: {edge_to_invalidate['fact']}"
-            )
-    return invalidated_edges
-
 
 async def extract_edge_dates(
     llm_client: LLMClient,
@@ -152,7 +34,6 @@ async def extract_edge_dates(
     previous_episodes: List[EpisodicNode],
 ) -> tuple[datetime | None, datetime | None]:
     context = {
-        'edge_name': edge.name,
         'edge_fact': edge.fact,
         'current_episode': current_episode.content,
         'previous_episodes': [ep.content for ep in previous_episodes],
@@ -162,25 +43,22 @@ async def extract_edge_dates(
 
     valid_at = llm_response.get('valid_at')
     invalid_at = llm_response.get('invalid_at')
-    explanation = llm_response.get('explanation', '')
 
     valid_at_datetime = None
     invalid_at_datetime = None
 
-    if valid_at and valid_at != '':
+    if valid_at:
         try:
             valid_at_datetime = datetime.fromisoformat(valid_at.replace('Z', '+00:00'))
         except ValueError as e:
             logger.error(f'Error parsing valid_at date: {e}. Input: {valid_at}')
 
-    if invalid_at and invalid_at != '':
+    if invalid_at:
         try:
             invalid_at_datetime = datetime.fromisoformat(invalid_at.replace('Z', '+00:00'))
         except ValueError as e:
             logger.error(f'Error parsing invalid_at date: {e}. Input: {invalid_at}')
 
-    logger.info(f'Edge date extraction explanation: {explanation}')
-
     return valid_at_datetime, invalid_at_datetime
 
 
@@ -210,7 +88,7 @@ async def get_edge_contradictions(
             contradicted_edges.append(contradicted_edge)
 
     end = time()
-    logger.info(
+    logger.debug(
         f'Found invalidated edge candidates from {new_edge.fact}, in {(end - start) * 1000} ms'
    )
 
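extract_edge_dates keeps the Z-suffix workaround when parsing the LLM's ISO-8601 timestamps: datetime.fromisoformat rejects a trailing 'Z' on Python versions before 3.11, hence the replace. A quick illustration:

```python
from datetime import datetime

raw = '2024-03-05T12:00:00Z'  # hypothetical LLM output
parsed = datetime.fromisoformat(raw.replace('Z', '+00:00'))
print(parsed.isoformat())  # 2024-03-05T12:00:00+00:00
```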