graphiti-core 0.12.0rc1__py3-none-any.whl → 0.24.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- graphiti_core/cross_encoder/bge_reranker_client.py +12 -2
- graphiti_core/cross_encoder/gemini_reranker_client.py +161 -0
- graphiti_core/cross_encoder/openai_reranker_client.py +7 -5
- graphiti_core/decorators.py +110 -0
- graphiti_core/driver/__init__.py +19 -0
- graphiti_core/driver/driver.py +124 -0
- graphiti_core/driver/falkordb_driver.py +362 -0
- graphiti_core/driver/graph_operations/graph_operations.py +191 -0
- graphiti_core/driver/kuzu_driver.py +182 -0
- graphiti_core/driver/neo4j_driver.py +117 -0
- graphiti_core/driver/neptune_driver.py +305 -0
- graphiti_core/driver/search_interface/search_interface.py +89 -0
- graphiti_core/edges.py +287 -172
- graphiti_core/embedder/azure_openai.py +71 -0
- graphiti_core/embedder/client.py +2 -1
- graphiti_core/embedder/gemini.py +116 -22
- graphiti_core/embedder/voyage.py +13 -2
- graphiti_core/errors.py +8 -0
- graphiti_core/graph_queries.py +162 -0
- graphiti_core/graphiti.py +705 -193
- graphiti_core/graphiti_types.py +4 -2
- graphiti_core/helpers.py +87 -10
- graphiti_core/llm_client/__init__.py +16 -0
- graphiti_core/llm_client/anthropic_client.py +159 -56
- graphiti_core/llm_client/azure_openai_client.py +115 -0
- graphiti_core/llm_client/client.py +98 -21
- graphiti_core/llm_client/config.py +1 -1
- graphiti_core/llm_client/gemini_client.py +290 -41
- graphiti_core/llm_client/groq_client.py +14 -3
- graphiti_core/llm_client/openai_base_client.py +261 -0
- graphiti_core/llm_client/openai_client.py +56 -132
- graphiti_core/llm_client/openai_generic_client.py +91 -56
- graphiti_core/models/edges/edge_db_queries.py +259 -35
- graphiti_core/models/nodes/node_db_queries.py +311 -32
- graphiti_core/nodes.py +420 -205
- graphiti_core/prompts/dedupe_edges.py +46 -32
- graphiti_core/prompts/dedupe_nodes.py +67 -42
- graphiti_core/prompts/eval.py +4 -4
- graphiti_core/prompts/extract_edges.py +27 -16
- graphiti_core/prompts/extract_nodes.py +74 -31
- graphiti_core/prompts/prompt_helpers.py +39 -0
- graphiti_core/prompts/snippets.py +29 -0
- graphiti_core/prompts/summarize_nodes.py +23 -25
- graphiti_core/search/search.py +158 -82
- graphiti_core/search/search_config.py +39 -4
- graphiti_core/search/search_filters.py +126 -35
- graphiti_core/search/search_helpers.py +5 -6
- graphiti_core/search/search_utils.py +1405 -485
- graphiti_core/telemetry/__init__.py +9 -0
- graphiti_core/telemetry/telemetry.py +117 -0
- graphiti_core/tracer.py +193 -0
- graphiti_core/utils/bulk_utils.py +364 -285
- graphiti_core/utils/datetime_utils.py +13 -0
- graphiti_core/utils/maintenance/community_operations.py +67 -49
- graphiti_core/utils/maintenance/dedup_helpers.py +262 -0
- graphiti_core/utils/maintenance/edge_operations.py +339 -197
- graphiti_core/utils/maintenance/graph_data_operations.py +50 -114
- graphiti_core/utils/maintenance/node_operations.py +319 -238
- graphiti_core/utils/maintenance/temporal_operations.py +11 -3
- graphiti_core/utils/ontology_utils/entity_types_utils.py +1 -1
- graphiti_core/utils/text_utils.py +53 -0
- graphiti_core-0.24.3.dist-info/METADATA +726 -0
- graphiti_core-0.24.3.dist-info/RECORD +86 -0
- {graphiti_core-0.12.0rc1.dist-info → graphiti_core-0.24.3.dist-info}/WHEEL +1 -1
- graphiti_core-0.12.0rc1.dist-info/METADATA +0 -350
- graphiti_core-0.12.0rc1.dist-info/RECORD +0 -66
- /graphiti_core/{utils/maintenance/utils.py → migrations/__init__.py} +0 -0
- {graphiti_core-0.12.0rc1.dist-info → graphiti_core-0.24.3.dist-info/licenses}/LICENSE +0 -0
|
@@ -19,7 +19,9 @@ from datetime import datetime
|
|
|
19
19
|
from time import time
|
|
20
20
|
|
|
21
21
|
from pydantic import BaseModel
|
|
22
|
+
from typing_extensions import LiteralString
|
|
22
23
|
|
|
24
|
+
from graphiti_core.driver.driver import GraphDriver, GraphProvider
|
|
23
25
|
from graphiti_core.edges import (
|
|
24
26
|
CommunityEdge,
|
|
25
27
|
EntityEdge,
|
|
@@ -32,26 +34,31 @@ from graphiti_core.llm_client import LLMClient
|
|
|
32
34
|
from graphiti_core.llm_client.config import ModelSize
|
|
33
35
|
from graphiti_core.nodes import CommunityNode, EntityNode, EpisodicNode
|
|
34
36
|
from graphiti_core.prompts import prompt_library
|
|
35
|
-
from graphiti_core.prompts.dedupe_edges import EdgeDuplicate
|
|
37
|
+
from graphiti_core.prompts.dedupe_edges import EdgeDuplicate
|
|
36
38
|
from graphiti_core.prompts.extract_edges import ExtractedEdges, MissingFacts
|
|
39
|
+
from graphiti_core.search.search import search
|
|
40
|
+
from graphiti_core.search.search_config import SearchResults
|
|
41
|
+
from graphiti_core.search.search_config_recipes import EDGE_HYBRID_SEARCH_RRF
|
|
37
42
|
from graphiti_core.search.search_filters import SearchFilters
|
|
38
|
-
from graphiti_core.search.search_utils import get_edge_invalidation_candidates, get_relevant_edges
|
|
39
43
|
from graphiti_core.utils.datetime_utils import ensure_utc, utc_now
|
|
44
|
+
from graphiti_core.utils.maintenance.dedup_helpers import _normalize_string_exact
|
|
45
|
+
|
|
46
|
+
DEFAULT_EDGE_NAME = 'RELATES_TO'
|
|
40
47
|
|
|
41
48
|
logger = logging.getLogger(__name__)
|
|
42
49
|
|
|
43
50
|
|
|
44
51
|
def build_episodic_edges(
|
|
45
52
|
entity_nodes: list[EntityNode],
|
|
46
|
-
|
|
53
|
+
episode_uuid: str,
|
|
47
54
|
created_at: datetime,
|
|
48
55
|
) -> list[EpisodicEdge]:
|
|
49
56
|
episodic_edges: list[EpisodicEdge] = [
|
|
50
57
|
EpisodicEdge(
|
|
51
|
-
source_node_uuid=
|
|
58
|
+
source_node_uuid=episode_uuid,
|
|
52
59
|
target_node_uuid=node.uuid,
|
|
53
60
|
created_at=created_at,
|
|
54
|
-
group_id=
|
|
61
|
+
group_id=node.group_id,
|
|
55
62
|
)
|
|
56
63
|
for node in entity_nodes
|
|
57
64
|
]
|
|
@@ -84,20 +91,26 @@ async def extract_edges(
|
|
|
84
91
|
episode: EpisodicNode,
|
|
85
92
|
nodes: list[EntityNode],
|
|
86
93
|
previous_episodes: list[EpisodicNode],
|
|
94
|
+
edge_type_map: dict[tuple[str, str], list[str]],
|
|
87
95
|
group_id: str = '',
|
|
88
|
-
edge_types: dict[str, BaseModel] | None = None,
|
|
96
|
+
edge_types: dict[str, type[BaseModel]] | None = None,
|
|
89
97
|
) -> list[EntityEdge]:
|
|
90
98
|
start = time()
|
|
91
99
|
|
|
92
100
|
extract_edges_max_tokens = 16384
|
|
93
101
|
llm_client = clients.llm_client
|
|
94
102
|
|
|
95
|
-
|
|
103
|
+
edge_type_signature_map: dict[str, tuple[str, str]] = {
|
|
104
|
+
edge_type: signature
|
|
105
|
+
for signature, edge_types in edge_type_map.items()
|
|
106
|
+
for edge_type in edge_types
|
|
107
|
+
}
|
|
96
108
|
|
|
97
109
|
edge_types_context = (
|
|
98
110
|
[
|
|
99
111
|
{
|
|
100
112
|
'fact_type_name': type_name,
|
|
113
|
+
'fact_type_signature': edge_type_signature_map.get(type_name, ('Entity', 'Entity')),
|
|
101
114
|
'fact_type_description': type_model.__doc__,
|
|
102
115
|
}
|
|
103
116
|
for type_name, type_model in edge_types.items()
|
|
@@ -109,7 +122,10 @@ async def extract_edges(
|
|
|
109
122
|
# Prepare context for LLM
|
|
110
123
|
context = {
|
|
111
124
|
'episode_content': episode.content,
|
|
112
|
-
'nodes': [
|
|
125
|
+
'nodes': [
|
|
126
|
+
{'id': idx, 'name': node.name, 'entity_types': node.labels}
|
|
127
|
+
for idx, node in enumerate(nodes)
|
|
128
|
+
],
|
|
113
129
|
'previous_episodes': [ep.content for ep in previous_episodes],
|
|
114
130
|
'reference_time': episode.valid_at,
|
|
115
131
|
'edge_types': edge_types_context,
|
|
@@ -123,10 +139,12 @@ async def extract_edges(
|
|
|
123
139
|
prompt_library.extract_edges.edge(context),
|
|
124
140
|
response_model=ExtractedEdges,
|
|
125
141
|
max_tokens=extract_edges_max_tokens,
|
|
142
|
+
group_id=group_id,
|
|
143
|
+
prompt_name='extract_edges.edge',
|
|
126
144
|
)
|
|
127
|
-
edges_data = llm_response.
|
|
145
|
+
edges_data = ExtractedEdges(**llm_response).edges
|
|
128
146
|
|
|
129
|
-
context['extracted_facts'] = [edge_data.
|
|
147
|
+
context['extracted_facts'] = [edge_data.fact for edge_data in edges_data]
|
|
130
148
|
|
|
131
149
|
reflexion_iterations += 1
|
|
132
150
|
if reflexion_iterations < MAX_REFLEXION_ITERATIONS:
|
|
@@ -134,6 +152,8 @@ async def extract_edges(
|
|
|
134
152
|
prompt_library.extract_edges.reflexion(context),
|
|
135
153
|
response_model=MissingFacts,
|
|
136
154
|
max_tokens=extract_edges_max_tokens,
|
|
155
|
+
group_id=group_id,
|
|
156
|
+
prompt_name='extract_edges.reflexion',
|
|
137
157
|
)
|
|
138
158
|
|
|
139
159
|
missing_facts = reflexion_response.get('missing_facts', [])
|
|
@@ -156,11 +176,32 @@ async def extract_edges(
|
|
|
156
176
|
edges = []
|
|
157
177
|
for edge_data in edges_data:
|
|
158
178
|
# Validate Edge Date information
|
|
159
|
-
valid_at = edge_data.
|
|
160
|
-
invalid_at = edge_data.
|
|
179
|
+
valid_at = edge_data.valid_at
|
|
180
|
+
invalid_at = edge_data.invalid_at
|
|
161
181
|
valid_at_datetime = None
|
|
162
182
|
invalid_at_datetime = None
|
|
163
183
|
|
|
184
|
+
# Filter out empty edges
|
|
185
|
+
if not edge_data.fact.strip():
|
|
186
|
+
continue
|
|
187
|
+
|
|
188
|
+
source_node_idx = edge_data.source_entity_id
|
|
189
|
+
target_node_idx = edge_data.target_entity_id
|
|
190
|
+
|
|
191
|
+
if len(nodes) == 0:
|
|
192
|
+
logger.warning('No entities provided for edge extraction')
|
|
193
|
+
continue
|
|
194
|
+
|
|
195
|
+
if not (0 <= source_node_idx < len(nodes) and 0 <= target_node_idx < len(nodes)):
|
|
196
|
+
logger.warning(
|
|
197
|
+
f'Invalid entity IDs in edge extraction for {edge_data.relation_type}. '
|
|
198
|
+
f'source_entity_id: {source_node_idx}, target_entity_id: {target_node_idx}, '
|
|
199
|
+
f'but only {len(nodes)} entities available (valid range: 0-{len(nodes) - 1})'
|
|
200
|
+
)
|
|
201
|
+
continue
|
|
202
|
+
source_node_uuid = nodes[source_node_idx].uuid
|
|
203
|
+
target_node_uuid = nodes[target_node_idx].uuid
|
|
204
|
+
|
|
164
205
|
if valid_at:
|
|
165
206
|
try:
|
|
166
207
|
valid_at_datetime = ensure_utc(
|
|
@@ -177,15 +218,11 @@ async def extract_edges(
|
|
|
177
218
|
except ValueError as e:
|
|
178
219
|
logger.warning(f'WARNING: Error parsing invalid_at date: {e}. Input: {invalid_at}')
|
|
179
220
|
edge = EntityEdge(
|
|
180
|
-
source_node_uuid=
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
target_node_uuid=node_uuids_by_name_map.get(
|
|
184
|
-
edge_data.get('target_entity_name', ''), ''
|
|
185
|
-
),
|
|
186
|
-
name=edge_data.get('relation_type', ''),
|
|
221
|
+
source_node_uuid=source_node_uuid,
|
|
222
|
+
target_node_uuid=target_node_uuid,
|
|
223
|
+
name=edge_data.relation_type,
|
|
187
224
|
group_id=group_id,
|
|
188
|
-
fact=edge_data.
|
|
225
|
+
fact=edge_data.fact,
|
|
189
226
|
episodes=[episode.uuid],
|
|
190
227
|
created_at=utc_now(),
|
|
191
228
|
valid_at=valid_at_datetime,
|
|
@@ -201,70 +238,73 @@ async def extract_edges(
|
|
|
201
238
|
return edges
|
|
202
239
|
|
|
203
240
|
|
|
204
|
-
async def dedupe_extracted_edges(
|
|
205
|
-
llm_client: LLMClient,
|
|
206
|
-
extracted_edges: list[EntityEdge],
|
|
207
|
-
existing_edges: list[EntityEdge],
|
|
208
|
-
) -> list[EntityEdge]:
|
|
209
|
-
# Create edge map
|
|
210
|
-
edge_map: dict[str, EntityEdge] = {}
|
|
211
|
-
for edge in existing_edges:
|
|
212
|
-
edge_map[edge.uuid] = edge
|
|
213
|
-
|
|
214
|
-
# Prepare context for LLM
|
|
215
|
-
context = {
|
|
216
|
-
'extracted_edges': [
|
|
217
|
-
{'uuid': edge.uuid, 'name': edge.name, 'fact': edge.fact} for edge in extracted_edges
|
|
218
|
-
],
|
|
219
|
-
'existing_edges': [
|
|
220
|
-
{'uuid': edge.uuid, 'name': edge.name, 'fact': edge.fact} for edge in existing_edges
|
|
221
|
-
],
|
|
222
|
-
}
|
|
223
|
-
|
|
224
|
-
llm_response = await llm_client.generate_response(prompt_library.dedupe_edges.edge(context))
|
|
225
|
-
duplicate_data = llm_response.get('duplicates', [])
|
|
226
|
-
logger.debug(f'Extracted unique edges: {duplicate_data}')
|
|
227
|
-
|
|
228
|
-
duplicate_uuid_map: dict[str, str] = {}
|
|
229
|
-
for duplicate in duplicate_data:
|
|
230
|
-
uuid_value = duplicate['duplicate_of']
|
|
231
|
-
duplicate_uuid_map[duplicate['uuid']] = uuid_value
|
|
232
|
-
|
|
233
|
-
# Get full edge data
|
|
234
|
-
edges: list[EntityEdge] = []
|
|
235
|
-
for edge in extracted_edges:
|
|
236
|
-
if edge.uuid in duplicate_uuid_map:
|
|
237
|
-
existing_uuid = duplicate_uuid_map[edge.uuid]
|
|
238
|
-
existing_edge = edge_map[existing_uuid]
|
|
239
|
-
# Add current episode to the episodes list
|
|
240
|
-
existing_edge.episodes += edge.episodes
|
|
241
|
-
edges.append(existing_edge)
|
|
242
|
-
else:
|
|
243
|
-
edges.append(edge)
|
|
244
|
-
|
|
245
|
-
return edges
|
|
246
|
-
|
|
247
|
-
|
|
248
241
|
async def resolve_extracted_edges(
|
|
249
242
|
clients: GraphitiClients,
|
|
250
243
|
extracted_edges: list[EntityEdge],
|
|
251
244
|
episode: EpisodicNode,
|
|
252
245
|
entities: list[EntityNode],
|
|
253
|
-
edge_types: dict[str, BaseModel],
|
|
246
|
+
edge_types: dict[str, type[BaseModel]],
|
|
254
247
|
edge_type_map: dict[tuple[str, str], list[str]],
|
|
255
248
|
) -> tuple[list[EntityEdge], list[EntityEdge]]:
|
|
249
|
+
# Fast path: deduplicate exact matches within the extracted edges before parallel processing
|
|
250
|
+
seen: dict[tuple[str, str, str], EntityEdge] = {}
|
|
251
|
+
deduplicated_edges: list[EntityEdge] = []
|
|
252
|
+
|
|
253
|
+
for edge in extracted_edges:
|
|
254
|
+
key = (
|
|
255
|
+
edge.source_node_uuid,
|
|
256
|
+
edge.target_node_uuid,
|
|
257
|
+
_normalize_string_exact(edge.fact),
|
|
258
|
+
)
|
|
259
|
+
if key not in seen:
|
|
260
|
+
seen[key] = edge
|
|
261
|
+
deduplicated_edges.append(edge)
|
|
262
|
+
|
|
263
|
+
extracted_edges = deduplicated_edges
|
|
264
|
+
|
|
256
265
|
driver = clients.driver
|
|
257
266
|
llm_client = clients.llm_client
|
|
258
267
|
embedder = clients.embedder
|
|
259
|
-
|
|
260
268
|
await create_entity_edge_embeddings(embedder, extracted_edges)
|
|
261
269
|
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
|
|
270
|
+
valid_edges_list: list[list[EntityEdge]] = await semaphore_gather(
|
|
271
|
+
*[
|
|
272
|
+
EntityEdge.get_between_nodes(driver, edge.source_node_uuid, edge.target_node_uuid)
|
|
273
|
+
for edge in extracted_edges
|
|
274
|
+
]
|
|
275
|
+
)
|
|
276
|
+
|
|
277
|
+
related_edges_results: list[SearchResults] = await semaphore_gather(
|
|
278
|
+
*[
|
|
279
|
+
search(
|
|
280
|
+
clients,
|
|
281
|
+
extracted_edge.fact,
|
|
282
|
+
group_ids=[extracted_edge.group_id],
|
|
283
|
+
config=EDGE_HYBRID_SEARCH_RRF,
|
|
284
|
+
search_filter=SearchFilters(edge_uuids=[edge.uuid for edge in valid_edges]),
|
|
285
|
+
)
|
|
286
|
+
for extracted_edge, valid_edges in zip(extracted_edges, valid_edges_list, strict=True)
|
|
287
|
+
]
|
|
265
288
|
)
|
|
266
289
|
|
|
267
|
-
related_edges_lists
|
|
290
|
+
related_edges_lists: list[list[EntityEdge]] = [result.edges for result in related_edges_results]
|
|
291
|
+
|
|
292
|
+
edge_invalidation_candidate_results: list[SearchResults] = await semaphore_gather(
|
|
293
|
+
*[
|
|
294
|
+
search(
|
|
295
|
+
clients,
|
|
296
|
+
extracted_edge.fact,
|
|
297
|
+
group_ids=[extracted_edge.group_id],
|
|
298
|
+
config=EDGE_HYBRID_SEARCH_RRF,
|
|
299
|
+
search_filter=SearchFilters(),
|
|
300
|
+
)
|
|
301
|
+
for extracted_edge in extracted_edges
|
|
302
|
+
]
|
|
303
|
+
)
|
|
304
|
+
|
|
305
|
+
edge_invalidation_candidates: list[list[EntityEdge]] = [
|
|
306
|
+
result.edges for result in edge_invalidation_candidate_results
|
|
307
|
+
]
|
|
268
308
|
|
|
269
309
|
logger.debug(
|
|
270
310
|
f'Related edges lists: {[(e.name, e.uuid) for edges_lst in related_edges_lists for e in edges_lst]}'
|
|
@@ -273,11 +313,21 @@ async def resolve_extracted_edges(
|
|
|
273
313
|
# Build entity hash table
|
|
274
314
|
uuid_entity_map: dict[str, EntityNode] = {entity.uuid: entity for entity in entities}
|
|
275
315
|
|
|
276
|
-
# Determine which edge types are relevant for each edge
|
|
277
|
-
edge_types_lst
|
|
316
|
+
# Determine which edge types are relevant for each edge.
|
|
317
|
+
# `edge_types_lst` stores the subset of custom edge definitions whose
|
|
318
|
+
# node signature matches each extracted edge. Anything outside this subset
|
|
319
|
+
# should only stay on the edge if it is a non-custom (LLM generated) label.
|
|
320
|
+
edge_types_lst: list[dict[str, type[BaseModel]]] = []
|
|
321
|
+
custom_type_names = set(edge_types or {})
|
|
278
322
|
for extracted_edge in extracted_edges:
|
|
279
|
-
|
|
280
|
-
|
|
323
|
+
source_node = uuid_entity_map.get(extracted_edge.source_node_uuid)
|
|
324
|
+
target_node = uuid_entity_map.get(extracted_edge.target_node_uuid)
|
|
325
|
+
source_node_labels = (
|
|
326
|
+
source_node.labels + ['Entity'] if source_node is not None else ['Entity']
|
|
327
|
+
)
|
|
328
|
+
target_node_labels = (
|
|
329
|
+
target_node.labels + ['Entity'] if target_node is not None else ['Entity']
|
|
330
|
+
)
|
|
281
331
|
label_tuples = [
|
|
282
332
|
(source_label, target_label)
|
|
283
333
|
for source_label in source_node_labels
|
|
@@ -296,8 +346,22 @@ async def resolve_extracted_edges(
|
|
|
296
346
|
|
|
297
347
|
edge_types_lst.append(extracted_edge_types)
|
|
298
348
|
|
|
349
|
+
for extracted_edge, extracted_edge_types in zip(extracted_edges, edge_types_lst, strict=True):
|
|
350
|
+
allowed_type_names = set(extracted_edge_types)
|
|
351
|
+
is_custom_name = extracted_edge.name in custom_type_names
|
|
352
|
+
if not allowed_type_names:
|
|
353
|
+
# No custom types are valid for this node pairing. Keep LLM generated
|
|
354
|
+
# labels, but flip disallowed custom names back to the default.
|
|
355
|
+
if is_custom_name and extracted_edge.name != DEFAULT_EDGE_NAME:
|
|
356
|
+
extracted_edge.name = DEFAULT_EDGE_NAME
|
|
357
|
+
continue
|
|
358
|
+
if is_custom_name and extracted_edge.name not in allowed_type_names:
|
|
359
|
+
# Custom name exists but it is not permitted for this source/target
|
|
360
|
+
# signature, so fall back to the default edge label.
|
|
361
|
+
extracted_edge.name = DEFAULT_EDGE_NAME
|
|
362
|
+
|
|
299
363
|
# resolve edges with related edges in the graph and find invalidation candidates
|
|
300
|
-
results: list[tuple[EntityEdge, list[EntityEdge]]] = list(
|
|
364
|
+
results: list[tuple[EntityEdge, list[EntityEdge], list[EntityEdge]]] = list(
|
|
301
365
|
await semaphore_gather(
|
|
302
366
|
*[
|
|
303
367
|
resolve_extracted_edge(
|
|
@@ -307,6 +371,7 @@ async def resolve_extracted_edges(
|
|
|
307
371
|
existing_edges,
|
|
308
372
|
episode,
|
|
309
373
|
extracted_edge_types,
|
|
374
|
+
custom_type_names,
|
|
310
375
|
)
|
|
311
376
|
for extracted_edge, related_edges, existing_edges, extracted_edge_types in zip(
|
|
312
377
|
extracted_edges,
|
|
@@ -348,21 +413,26 @@ def resolve_edge_contradictions(
|
|
|
348
413
|
invalidated_edges: list[EntityEdge] = []
|
|
349
414
|
for edge in invalidation_candidates:
|
|
350
415
|
# (Edge invalid before new edge becomes valid) or (new edge invalid before edge becomes valid)
|
|
416
|
+
edge_invalid_at_utc = ensure_utc(edge.invalid_at)
|
|
417
|
+
resolved_edge_valid_at_utc = ensure_utc(resolved_edge.valid_at)
|
|
418
|
+
edge_valid_at_utc = ensure_utc(edge.valid_at)
|
|
419
|
+
resolved_edge_invalid_at_utc = ensure_utc(resolved_edge.invalid_at)
|
|
420
|
+
|
|
351
421
|
if (
|
|
352
|
-
|
|
353
|
-
and
|
|
354
|
-
and
|
|
422
|
+
edge_invalid_at_utc is not None
|
|
423
|
+
and resolved_edge_valid_at_utc is not None
|
|
424
|
+
and edge_invalid_at_utc <= resolved_edge_valid_at_utc
|
|
355
425
|
) or (
|
|
356
|
-
|
|
357
|
-
and
|
|
358
|
-
and
|
|
426
|
+
edge_valid_at_utc is not None
|
|
427
|
+
and resolved_edge_invalid_at_utc is not None
|
|
428
|
+
and resolved_edge_invalid_at_utc <= edge_valid_at_utc
|
|
359
429
|
):
|
|
360
430
|
continue
|
|
361
431
|
# New edge invalidates edge
|
|
362
432
|
elif (
|
|
363
|
-
|
|
364
|
-
and
|
|
365
|
-
and
|
|
433
|
+
edge_valid_at_utc is not None
|
|
434
|
+
and resolved_edge_valid_at_utc is not None
|
|
435
|
+
and edge_valid_at_utc < resolved_edge_valid_at_utc
|
|
366
436
|
):
|
|
367
437
|
edge.invalid_at = resolved_edge.valid_at
|
|
368
438
|
edge.expired_at = edge.expired_at if edge.expired_at is not None else utc_now()
|
|
@@ -377,32 +447,69 @@ async def resolve_extracted_edge(
|
|
|
377
447
|
related_edges: list[EntityEdge],
|
|
378
448
|
existing_edges: list[EntityEdge],
|
|
379
449
|
episode: EpisodicNode,
|
|
380
|
-
|
|
381
|
-
|
|
450
|
+
edge_type_candidates: dict[str, type[BaseModel]] | None = None,
|
|
451
|
+
custom_edge_type_names: set[str] | None = None,
|
|
452
|
+
) -> tuple[EntityEdge, list[EntityEdge], list[EntityEdge]]:
|
|
453
|
+
"""Resolve an extracted edge against existing graph context.
|
|
454
|
+
|
|
455
|
+
Parameters
|
|
456
|
+
----------
|
|
457
|
+
llm_client : LLMClient
|
|
458
|
+
Client used to invoke the LLM for deduplication and attribute extraction.
|
|
459
|
+
extracted_edge : EntityEdge
|
|
460
|
+
Newly extracted edge whose canonical representation is being resolved.
|
|
461
|
+
related_edges : list[EntityEdge]
|
|
462
|
+
Candidate edges with identical endpoints used for duplicate detection.
|
|
463
|
+
existing_edges : list[EntityEdge]
|
|
464
|
+
Broader set of edges evaluated for contradiction / invalidation.
|
|
465
|
+
episode : EpisodicNode
|
|
466
|
+
Episode providing content context when extracting edge attributes.
|
|
467
|
+
edge_type_candidates : dict[str, type[BaseModel]] | None
|
|
468
|
+
Custom edge types permitted for the current source/target signature.
|
|
469
|
+
custom_edge_type_names : set[str] | None
|
|
470
|
+
Full catalog of registered custom edge names. Used to distinguish
|
|
471
|
+
between disallowed custom types (which fall back to the default label)
|
|
472
|
+
and ad-hoc labels emitted by the LLM.
|
|
473
|
+
|
|
474
|
+
Returns
|
|
475
|
+
-------
|
|
476
|
+
tuple[EntityEdge, list[EntityEdge], list[EntityEdge]]
|
|
477
|
+
The resolved edge, any duplicates, and edges to invalidate.
|
|
478
|
+
"""
|
|
382
479
|
if len(related_edges) == 0 and len(existing_edges) == 0:
|
|
383
|
-
return extracted_edge, []
|
|
480
|
+
return extracted_edge, [], []
|
|
481
|
+
|
|
482
|
+
# Fast path: if the fact text and endpoints already exist verbatim, reuse the matching edge.
|
|
483
|
+
normalized_fact = _normalize_string_exact(extracted_edge.fact)
|
|
484
|
+
for edge in related_edges:
|
|
485
|
+
if (
|
|
486
|
+
edge.source_node_uuid == extracted_edge.source_node_uuid
|
|
487
|
+
and edge.target_node_uuid == extracted_edge.target_node_uuid
|
|
488
|
+
and _normalize_string_exact(edge.fact) == normalized_fact
|
|
489
|
+
):
|
|
490
|
+
resolved = edge
|
|
491
|
+
if episode is not None and episode.uuid not in resolved.episodes:
|
|
492
|
+
resolved.episodes.append(episode.uuid)
|
|
493
|
+
return resolved, [], []
|
|
384
494
|
|
|
385
495
|
start = time()
|
|
386
496
|
|
|
387
497
|
# Prepare context for LLM
|
|
388
|
-
related_edges_context = [
|
|
389
|
-
{'id': edge.uuid, 'fact': edge.fact} for i, edge in enumerate(related_edges)
|
|
390
|
-
]
|
|
498
|
+
related_edges_context = [{'idx': i, 'fact': edge.fact} for i, edge in enumerate(related_edges)]
|
|
391
499
|
|
|
392
500
|
invalidation_edge_candidates_context = [
|
|
393
|
-
{'
|
|
501
|
+
{'idx': i, 'fact': existing_edge.fact} for i, existing_edge in enumerate(existing_edges)
|
|
394
502
|
]
|
|
395
503
|
|
|
396
504
|
edge_types_context = (
|
|
397
505
|
[
|
|
398
506
|
{
|
|
399
|
-
'fact_type_id': i,
|
|
400
507
|
'fact_type_name': type_name,
|
|
401
508
|
'fact_type_description': type_model.__doc__,
|
|
402
509
|
}
|
|
403
|
-
for
|
|
510
|
+
for type_name, type_model in edge_type_candidates.items()
|
|
404
511
|
]
|
|
405
|
-
if
|
|
512
|
+
if edge_type_candidates is not None
|
|
406
513
|
else []
|
|
407
514
|
)
|
|
408
515
|
|
|
@@ -413,46 +520,97 @@ async def resolve_extracted_edge(
|
|
|
413
520
|
'edge_types': edge_types_context,
|
|
414
521
|
}
|
|
415
522
|
|
|
523
|
+
if related_edges or existing_edges:
|
|
524
|
+
logger.debug(
|
|
525
|
+
'Resolving edge: sent %d EXISTING FACTS%s and %d INVALIDATION CANDIDATES%s',
|
|
526
|
+
len(related_edges),
|
|
527
|
+
f' (idx 0-{len(related_edges) - 1})' if related_edges else '',
|
|
528
|
+
len(existing_edges),
|
|
529
|
+
f' (idx 0-{len(existing_edges) - 1})' if existing_edges else '',
|
|
530
|
+
)
|
|
531
|
+
|
|
416
532
|
llm_response = await llm_client.generate_response(
|
|
417
533
|
prompt_library.dedupe_edges.resolve_edge(context),
|
|
418
534
|
response_model=EdgeDuplicate,
|
|
419
535
|
model_size=ModelSize.small,
|
|
536
|
+
prompt_name='dedupe_edges.resolve_edge',
|
|
420
537
|
)
|
|
538
|
+
response_object = EdgeDuplicate(**llm_response)
|
|
539
|
+
duplicate_facts = response_object.duplicate_facts
|
|
540
|
+
|
|
541
|
+
# Validate duplicate_facts are in valid range for EXISTING FACTS
|
|
542
|
+
invalid_duplicates = [i for i in duplicate_facts if i < 0 or i >= len(related_edges)]
|
|
543
|
+
if invalid_duplicates:
|
|
544
|
+
logger.warning(
|
|
545
|
+
'LLM returned invalid duplicate_facts idx values %s (valid range: 0-%d for EXISTING FACTS)',
|
|
546
|
+
invalid_duplicates,
|
|
547
|
+
len(related_edges) - 1,
|
|
548
|
+
)
|
|
421
549
|
|
|
422
|
-
|
|
550
|
+
duplicate_fact_ids: list[int] = [i for i in duplicate_facts if 0 <= i < len(related_edges)]
|
|
423
551
|
|
|
424
|
-
resolved_edge =
|
|
425
|
-
|
|
426
|
-
|
|
427
|
-
|
|
428
|
-
)
|
|
552
|
+
resolved_edge = extracted_edge
|
|
553
|
+
for duplicate_fact_id in duplicate_fact_ids:
|
|
554
|
+
resolved_edge = related_edges[duplicate_fact_id]
|
|
555
|
+
break
|
|
429
556
|
|
|
430
|
-
if
|
|
557
|
+
if duplicate_fact_ids and episode is not None:
|
|
431
558
|
resolved_edge.episodes.append(episode.uuid)
|
|
432
559
|
|
|
433
|
-
contradicted_facts: list[int] =
|
|
560
|
+
contradicted_facts: list[int] = response_object.contradicted_facts
|
|
434
561
|
|
|
435
|
-
|
|
562
|
+
# Validate contradicted_facts are in valid range for INVALIDATION CANDIDATES
|
|
563
|
+
invalid_contradictions = [i for i in contradicted_facts if i < 0 or i >= len(existing_edges)]
|
|
564
|
+
if invalid_contradictions:
|
|
565
|
+
logger.warning(
|
|
566
|
+
'LLM returned invalid contradicted_facts idx values %s (valid range: 0-%d for INVALIDATION CANDIDATES)',
|
|
567
|
+
invalid_contradictions,
|
|
568
|
+
len(existing_edges) - 1,
|
|
569
|
+
)
|
|
436
570
|
|
|
437
|
-
|
|
438
|
-
|
|
571
|
+
invalidation_candidates: list[EntityEdge] = [
|
|
572
|
+
existing_edges[i] for i in contradicted_facts if 0 <= i < len(existing_edges)
|
|
573
|
+
]
|
|
574
|
+
|
|
575
|
+
fact_type: str = response_object.fact_type
|
|
576
|
+
candidate_type_names = set(edge_type_candidates or {})
|
|
577
|
+
custom_type_names = custom_edge_type_names or set()
|
|
578
|
+
|
|
579
|
+
is_default_type = fact_type.upper() == 'DEFAULT'
|
|
580
|
+
is_custom_type = fact_type in custom_type_names
|
|
581
|
+
is_allowed_custom_type = fact_type in candidate_type_names
|
|
582
|
+
|
|
583
|
+
if is_allowed_custom_type:
|
|
584
|
+
# The LLM selected a custom type that is allowed for the node pair.
|
|
585
|
+
# Adopt the custom type and, if needed, extract its structured attributes.
|
|
439
586
|
resolved_edge.name = fact_type
|
|
440
587
|
|
|
441
588
|
edge_attributes_context = {
|
|
442
|
-
'
|
|
589
|
+
'episode_content': episode.content,
|
|
443
590
|
'reference_time': episode.valid_at,
|
|
444
591
|
'fact': resolved_edge.fact,
|
|
445
592
|
}
|
|
446
593
|
|
|
447
|
-
edge_model =
|
|
448
|
-
|
|
449
|
-
|
|
450
|
-
|
|
451
|
-
|
|
452
|
-
|
|
453
|
-
|
|
594
|
+
edge_model = edge_type_candidates.get(fact_type) if edge_type_candidates else None
|
|
595
|
+
if edge_model is not None and len(edge_model.model_fields) != 0:
|
|
596
|
+
edge_attributes_response = await llm_client.generate_response(
|
|
597
|
+
prompt_library.extract_edges.extract_attributes(edge_attributes_context),
|
|
598
|
+
response_model=edge_model, # type: ignore
|
|
599
|
+
model_size=ModelSize.small,
|
|
600
|
+
prompt_name='extract_edges.extract_attributes',
|
|
601
|
+
)
|
|
454
602
|
|
|
455
|
-
|
|
603
|
+
resolved_edge.attributes = edge_attributes_response
|
|
604
|
+
elif not is_default_type and is_custom_type:
|
|
605
|
+
# The LLM picked a custom type that is not allowed for this signature.
|
|
606
|
+
# Reset to the default label and drop any structured attributes.
|
|
607
|
+
resolved_edge.name = DEFAULT_EDGE_NAME
|
|
608
|
+
resolved_edge.attributes = {}
|
|
609
|
+
elif not is_default_type:
|
|
610
|
+
# Non-custom labels are allowed to pass through so long as the LLM does
|
|
611
|
+
# not return the sentinel DEFAULT value.
|
|
612
|
+
resolved_edge.name = fact_type
|
|
613
|
+
resolved_edge.attributes = {}
|
|
456
614
|
|
|
457
615
|
end = time()
|
|
458
616
|
logger.debug(
|
|
@@ -466,14 +624,14 @@ async def resolve_extracted_edge(
|
|
|
466
624
|
|
|
467
625
|
# Determine if the new_edge needs to be expired
|
|
468
626
|
if resolved_edge.expired_at is None:
|
|
469
|
-
invalidation_candidates.sort(key=lambda c: (c.valid_at is None, c.valid_at))
|
|
627
|
+
invalidation_candidates.sort(key=lambda c: (c.valid_at is None, ensure_utc(c.valid_at)))
|
|
470
628
|
for candidate in invalidation_candidates:
|
|
629
|
+
candidate_valid_at_utc = ensure_utc(candidate.valid_at)
|
|
630
|
+
resolved_edge_valid_at_utc = ensure_utc(resolved_edge.valid_at)
|
|
471
631
|
if (
|
|
472
|
-
|
|
473
|
-
and
|
|
474
|
-
and
|
|
475
|
-
and resolved_edge.valid_at.tzinfo
|
|
476
|
-
and candidate.valid_at > resolved_edge.valid_at
|
|
632
|
+
candidate_valid_at_utc is not None
|
|
633
|
+
and resolved_edge_valid_at_utc is not None
|
|
634
|
+
and candidate_valid_at_utc > resolved_edge_valid_at_utc
|
|
477
635
|
):
|
|
478
636
|
# Expire new edge since we have information about more recent events
|
|
479
637
|
resolved_edge.invalid_at = candidate.valid_at
|
|
@@ -481,89 +639,73 @@ async def resolve_extracted_edge(
|
|
|
481
639
|
break
|
|
482
640
|
|
|
483
641
|
# Determine which contradictory edges need to be expired
|
|
484
|
-
invalidated_edges = resolve_edge_contradictions(
|
|
485
|
-
|
|
486
|
-
return resolved_edge, invalidated_edges
|
|
487
|
-
|
|
488
|
-
|
|
489
|
-
async def dedupe_extracted_edge(
|
|
490
|
-
llm_client: LLMClient,
|
|
491
|
-
extracted_edge: EntityEdge,
|
|
492
|
-
related_edges: list[EntityEdge],
|
|
493
|
-
episode: EpisodicNode | None = None,
|
|
494
|
-
) -> EntityEdge:
|
|
495
|
-
if len(related_edges) == 0:
|
|
496
|
-
return extracted_edge
|
|
497
|
-
|
|
498
|
-
start = time()
|
|
499
|
-
|
|
500
|
-
# Prepare context for LLM
|
|
501
|
-
related_edges_context = [
|
|
502
|
-
{'id': edge.uuid, 'fact': edge.fact} for i, edge in enumerate(related_edges)
|
|
503
|
-
]
|
|
504
|
-
|
|
505
|
-
extracted_edge_context = {
|
|
506
|
-
'fact': extracted_edge.fact,
|
|
507
|
-
}
|
|
508
|
-
|
|
509
|
-
context = {
|
|
510
|
-
'related_edges': related_edges_context,
|
|
511
|
-
'extracted_edges': extracted_edge_context,
|
|
512
|
-
}
|
|
513
|
-
|
|
514
|
-
llm_response = await llm_client.generate_response(
|
|
515
|
-
prompt_library.dedupe_edges.edge(context),
|
|
516
|
-
response_model=EdgeDuplicate,
|
|
517
|
-
model_size=ModelSize.small,
|
|
518
|
-
)
|
|
519
|
-
|
|
520
|
-
duplicate_fact_id: int = llm_response.get('duplicate_fact_id', -1)
|
|
521
|
-
|
|
522
|
-
edge = (
|
|
523
|
-
related_edges[duplicate_fact_id]
|
|
524
|
-
if 0 <= duplicate_fact_id < len(related_edges)
|
|
525
|
-
else extracted_edge
|
|
526
|
-
)
|
|
527
|
-
|
|
528
|
-
if duplicate_fact_id >= 0 and episode is not None:
|
|
529
|
-
edge.episodes.append(episode.uuid)
|
|
530
|
-
|
|
531
|
-
end = time()
|
|
532
|
-
logger.debug(
|
|
533
|
-
f'Resolved Edge: {extracted_edge.name} is {edge.name}, in {(end - start) * 1000} ms'
|
|
642
|
+
invalidated_edges: list[EntityEdge] = resolve_edge_contradictions(
|
|
643
|
+
resolved_edge, invalidation_candidates
|
|
534
644
|
)
|
|
645
|
+
duplicate_edges: list[EntityEdge] = [related_edges[idx] for idx in duplicate_fact_ids]
|
|
535
646
|
|
|
536
|
-
return
|
|
647
|
+
return resolved_edge, invalidated_edges, duplicate_edges
|
|
537
648
|
|
|
538
649
|
|
|
539
|
-
async def
|
|
540
|
-
|
|
541
|
-
|
|
542
|
-
|
|
543
|
-
|
|
544
|
-
|
|
545
|
-
# Create edge map
|
|
546
|
-
edge_map = {}
|
|
547
|
-
for edge in edges:
|
|
548
|
-
edge_map[edge.uuid] = edge
|
|
650
|
+
async def filter_existing_duplicate_of_edges(
|
|
651
|
+
driver: GraphDriver, duplicates_node_tuples: list[tuple[EntityNode, EntityNode]]
|
|
652
|
+
) -> list[tuple[EntityNode, EntityNode]]:
|
|
653
|
+
if not duplicates_node_tuples:
|
|
654
|
+
return []
|
|
549
655
|
|
|
550
|
-
|
|
551
|
-
|
|
656
|
+
duplicate_nodes_map = {
|
|
657
|
+
(source.uuid, target.uuid): (source, target) for source, target in duplicates_node_tuples
|
|
658
|
+
}
|
|
552
659
|
|
|
553
|
-
|
|
554
|
-
|
|
555
|
-
|
|
556
|
-
|
|
660
|
+
if driver.provider == GraphProvider.NEPTUNE:
|
|
661
|
+
query: LiteralString = """
|
|
662
|
+
UNWIND $duplicate_node_uuids AS duplicate_tuple
|
|
663
|
+
MATCH (n:Entity {uuid: duplicate_tuple.source})-[r:RELATES_TO {name: 'IS_DUPLICATE_OF'}]->(m:Entity {uuid: duplicate_tuple.target})
|
|
664
|
+
RETURN DISTINCT
|
|
665
|
+
n.uuid AS source_uuid,
|
|
666
|
+
m.uuid AS target_uuid
|
|
667
|
+
"""
|
|
668
|
+
|
|
669
|
+
duplicate_nodes = [
|
|
670
|
+
{'source': source.uuid, 'target': target.uuid}
|
|
671
|
+
for source, target in duplicates_node_tuples
|
|
672
|
+
]
|
|
557
673
|
|
|
558
|
-
|
|
559
|
-
|
|
674
|
+
records, _, _ = await driver.execute_query(
|
|
675
|
+
query,
|
|
676
|
+
duplicate_node_uuids=duplicate_nodes,
|
|
677
|
+
routing_='r',
|
|
678
|
+
)
|
|
679
|
+
else:
|
|
680
|
+
if driver.provider == GraphProvider.KUZU:
|
|
681
|
+
query = """
|
|
682
|
+
UNWIND $duplicate_node_uuids AS duplicate
|
|
683
|
+
MATCH (n:Entity {uuid: duplicate.src})-[:RELATES_TO]->(e:RelatesToNode_ {name: 'IS_DUPLICATE_OF'})-[:RELATES_TO]->(m:Entity {uuid: duplicate.dst})
|
|
684
|
+
RETURN DISTINCT
|
|
685
|
+
n.uuid AS source_uuid,
|
|
686
|
+
m.uuid AS target_uuid
|
|
687
|
+
"""
|
|
688
|
+
duplicate_node_uuids = [{'src': src, 'dst': dst} for src, dst in duplicate_nodes_map]
|
|
689
|
+
else:
|
|
690
|
+
query: LiteralString = """
|
|
691
|
+
UNWIND $duplicate_node_uuids AS duplicate_tuple
|
|
692
|
+
MATCH (n:Entity {uuid: duplicate_tuple[0]})-[r:RELATES_TO {name: 'IS_DUPLICATE_OF'}]->(m:Entity {uuid: duplicate_tuple[1]})
|
|
693
|
+
RETURN DISTINCT
|
|
694
|
+
n.uuid AS source_uuid,
|
|
695
|
+
m.uuid AS target_uuid
|
|
696
|
+
"""
|
|
697
|
+
duplicate_node_uuids = list(duplicate_nodes_map.keys())
|
|
698
|
+
|
|
699
|
+
records, _, _ = await driver.execute_query(
|
|
700
|
+
query,
|
|
701
|
+
duplicate_node_uuids=duplicate_node_uuids,
|
|
702
|
+
routing_='r',
|
|
703
|
+
)
|
|
560
704
|
|
|
561
|
-
#
|
|
562
|
-
|
|
563
|
-
|
|
564
|
-
|
|
565
|
-
|
|
566
|
-
edge.fact = edge_data['fact']
|
|
567
|
-
unique_edges.append(edge)
|
|
705
|
+
# Remove duplicates that already have the IS_DUPLICATE_OF edge
|
|
706
|
+
for record in records:
|
|
707
|
+
duplicate_tuple = (record.get('source_uuid'), record.get('target_uuid'))
|
|
708
|
+
if duplicate_nodes_map.get(duplicate_tuple):
|
|
709
|
+
duplicate_nodes_map.pop(duplicate_tuple)
|
|
568
710
|
|
|
569
|
-
return
|
|
711
|
+
return list(duplicate_nodes_map.values())
|