graphiti-core 0.20.4__py3-none-any.whl → 0.21.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of graphiti-core might be problematic. Click here for more details.
- graphiti_core/driver/driver.py +28 -0
- graphiti_core/driver/falkordb_driver.py +112 -0
- graphiti_core/driver/kuzu_driver.py +1 -0
- graphiti_core/driver/neo4j_driver.py +10 -2
- graphiti_core/driver/neptune_driver.py +4 -6
- graphiti_core/edges.py +67 -7
- graphiti_core/embedder/client.py +2 -1
- graphiti_core/graph_queries.py +35 -6
- graphiti_core/graphiti.py +27 -23
- graphiti_core/graphiti_types.py +0 -1
- graphiti_core/helpers.py +2 -2
- graphiti_core/llm_client/client.py +19 -4
- graphiti_core/llm_client/gemini_client.py +4 -2
- graphiti_core/llm_client/openai_base_client.py +3 -2
- graphiti_core/llm_client/openai_generic_client.py +3 -2
- graphiti_core/models/edges/edge_db_queries.py +36 -16
- graphiti_core/models/nodes/node_db_queries.py +30 -10
- graphiti_core/nodes.py +126 -25
- graphiti_core/prompts/dedupe_edges.py +40 -29
- graphiti_core/prompts/dedupe_nodes.py +51 -34
- graphiti_core/prompts/eval.py +3 -3
- graphiti_core/prompts/extract_edges.py +17 -9
- graphiti_core/prompts/extract_nodes.py +10 -9
- graphiti_core/prompts/prompt_helpers.py +3 -3
- graphiti_core/prompts/summarize_nodes.py +5 -5
- graphiti_core/search/search_filters.py +53 -0
- graphiti_core/search/search_helpers.py +5 -7
- graphiti_core/search/search_utils.py +227 -57
- graphiti_core/utils/bulk_utils.py +168 -69
- graphiti_core/utils/maintenance/community_operations.py +8 -20
- graphiti_core/utils/maintenance/dedup_helpers.py +262 -0
- graphiti_core/utils/maintenance/edge_operations.py +187 -50
- graphiti_core/utils/maintenance/graph_data_operations.py +9 -5
- graphiti_core/utils/maintenance/node_operations.py +244 -88
- graphiti_core/utils/maintenance/temporal_operations.py +0 -4
- {graphiti_core-0.20.4.dist-info → graphiti_core-0.21.0.dist-info}/METADATA +7 -1
- {graphiti_core-0.20.4.dist-info → graphiti_core-0.21.0.dist-info}/RECORD +39 -38
- {graphiti_core-0.20.4.dist-info → graphiti_core-0.21.0.dist-info}/WHEEL +0 -0
- {graphiti_core-0.20.4.dist-info → graphiti_core-0.21.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -36,9 +36,14 @@ from graphiti_core.nodes import CommunityNode, EntityNode, EpisodicNode
|
|
|
36
36
|
from graphiti_core.prompts import prompt_library
|
|
37
37
|
from graphiti_core.prompts.dedupe_edges import EdgeDuplicate
|
|
38
38
|
from graphiti_core.prompts.extract_edges import ExtractedEdges, MissingFacts
|
|
39
|
+
from graphiti_core.search.search import search
|
|
40
|
+
from graphiti_core.search.search_config import SearchResults
|
|
41
|
+
from graphiti_core.search.search_config_recipes import EDGE_HYBRID_SEARCH_RRF
|
|
39
42
|
from graphiti_core.search.search_filters import SearchFilters
|
|
40
|
-
from graphiti_core.search.search_utils import get_edge_invalidation_candidates, get_relevant_edges
|
|
41
43
|
from graphiti_core.utils.datetime_utils import ensure_utc, utc_now
|
|
44
|
+
from graphiti_core.utils.maintenance.dedup_helpers import _normalize_string_exact
|
|
45
|
+
|
|
46
|
+
DEFAULT_EDGE_NAME = 'RELATES_TO'
|
|
42
47
|
|
|
43
48
|
logger = logging.getLogger(__name__)
|
|
44
49
|
|
|
@@ -63,32 +68,6 @@ def build_episodic_edges(
|
|
|
63
68
|
return episodic_edges
|
|
64
69
|
|
|
65
70
|
|
|
66
|
-
def build_duplicate_of_edges(
|
|
67
|
-
episode: EpisodicNode,
|
|
68
|
-
created_at: datetime,
|
|
69
|
-
duplicate_nodes: list[tuple[EntityNode, EntityNode]],
|
|
70
|
-
) -> list[EntityEdge]:
|
|
71
|
-
is_duplicate_of_edges: list[EntityEdge] = []
|
|
72
|
-
for source_node, target_node in duplicate_nodes:
|
|
73
|
-
if source_node.uuid == target_node.uuid:
|
|
74
|
-
continue
|
|
75
|
-
|
|
76
|
-
is_duplicate_of_edges.append(
|
|
77
|
-
EntityEdge(
|
|
78
|
-
source_node_uuid=source_node.uuid,
|
|
79
|
-
target_node_uuid=target_node.uuid,
|
|
80
|
-
name='IS_DUPLICATE_OF',
|
|
81
|
-
group_id=episode.group_id,
|
|
82
|
-
fact=f'{source_node.name} is a duplicate of {target_node.name}',
|
|
83
|
-
episodes=[episode.uuid],
|
|
84
|
-
created_at=created_at,
|
|
85
|
-
valid_at=created_at,
|
|
86
|
-
)
|
|
87
|
-
)
|
|
88
|
-
|
|
89
|
-
return is_duplicate_of_edges
|
|
90
|
-
|
|
91
|
-
|
|
92
71
|
def build_community_edges(
|
|
93
72
|
entity_nodes: list[EntityNode],
|
|
94
73
|
community_node: CommunityNode,
|
|
@@ -151,7 +130,6 @@ async def extract_edges(
|
|
|
151
130
|
'reference_time': episode.valid_at,
|
|
152
131
|
'edge_types': edge_types_context,
|
|
153
132
|
'custom_prompt': '',
|
|
154
|
-
'ensure_ascii': clients.ensure_ascii,
|
|
155
133
|
}
|
|
156
134
|
|
|
157
135
|
facts_missed = True
|
|
@@ -161,6 +139,7 @@ async def extract_edges(
|
|
|
161
139
|
prompt_library.extract_edges.edge(context),
|
|
162
140
|
response_model=ExtractedEdges,
|
|
163
141
|
max_tokens=extract_edges_max_tokens,
|
|
142
|
+
group_id=group_id,
|
|
164
143
|
)
|
|
165
144
|
edges_data = ExtractedEdges(**llm_response).edges
|
|
166
145
|
|
|
@@ -172,6 +151,7 @@ async def extract_edges(
|
|
|
172
151
|
prompt_library.extract_edges.reflexion(context),
|
|
173
152
|
response_model=MissingFacts,
|
|
174
153
|
max_tokens=extract_edges_max_tokens,
|
|
154
|
+
group_id=group_id,
|
|
175
155
|
)
|
|
176
156
|
|
|
177
157
|
missing_facts = reflexion_response.get('missing_facts', [])
|
|
@@ -199,15 +179,26 @@ async def extract_edges(
|
|
|
199
179
|
valid_at_datetime = None
|
|
200
180
|
invalid_at_datetime = None
|
|
201
181
|
|
|
182
|
+
# Filter out empty edges
|
|
183
|
+
if not edge_data.fact.strip():
|
|
184
|
+
continue
|
|
185
|
+
|
|
202
186
|
source_node_idx = edge_data.source_entity_id
|
|
203
187
|
target_node_idx = edge_data.target_entity_id
|
|
204
|
-
|
|
188
|
+
|
|
189
|
+
if len(nodes) == 0:
|
|
190
|
+
logger.warning('No entities provided for edge extraction')
|
|
191
|
+
continue
|
|
192
|
+
|
|
193
|
+
if not (0 <= source_node_idx < len(nodes) and 0 <= target_node_idx < len(nodes)):
|
|
205
194
|
logger.warning(
|
|
206
|
-
f'
|
|
195
|
+
f'Invalid entity IDs in edge extraction for {edge_data.relation_type}. '
|
|
196
|
+
f'source_entity_id: {source_node_idx}, target_entity_id: {target_node_idx}, '
|
|
197
|
+
f'but only {len(nodes)} entities available (valid range: 0-{len(nodes) - 1})'
|
|
207
198
|
)
|
|
208
199
|
continue
|
|
209
200
|
source_node_uuid = nodes[source_node_idx].uuid
|
|
210
|
-
target_node_uuid = nodes[
|
|
201
|
+
target_node_uuid = nodes[target_node_idx].uuid
|
|
211
202
|
|
|
212
203
|
if valid_at:
|
|
213
204
|
try:
|
|
@@ -253,17 +244,65 @@ async def resolve_extracted_edges(
|
|
|
253
244
|
edge_types: dict[str, type[BaseModel]],
|
|
254
245
|
edge_type_map: dict[tuple[str, str], list[str]],
|
|
255
246
|
) -> tuple[list[EntityEdge], list[EntityEdge]]:
|
|
247
|
+
# Fast path: deduplicate exact matches within the extracted edges before parallel processing
|
|
248
|
+
seen: dict[tuple[str, str, str], EntityEdge] = {}
|
|
249
|
+
deduplicated_edges: list[EntityEdge] = []
|
|
250
|
+
|
|
251
|
+
for edge in extracted_edges:
|
|
252
|
+
key = (
|
|
253
|
+
edge.source_node_uuid,
|
|
254
|
+
edge.target_node_uuid,
|
|
255
|
+
_normalize_string_exact(edge.fact),
|
|
256
|
+
)
|
|
257
|
+
if key not in seen:
|
|
258
|
+
seen[key] = edge
|
|
259
|
+
deduplicated_edges.append(edge)
|
|
260
|
+
|
|
261
|
+
extracted_edges = deduplicated_edges
|
|
262
|
+
|
|
256
263
|
driver = clients.driver
|
|
257
264
|
llm_client = clients.llm_client
|
|
258
265
|
embedder = clients.embedder
|
|
259
266
|
await create_entity_edge_embeddings(embedder, extracted_edges)
|
|
260
267
|
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
|
|
268
|
+
valid_edges_list: list[list[EntityEdge]] = await semaphore_gather(
|
|
269
|
+
*[
|
|
270
|
+
EntityEdge.get_between_nodes(driver, edge.source_node_uuid, edge.target_node_uuid)
|
|
271
|
+
for edge in extracted_edges
|
|
272
|
+
]
|
|
273
|
+
)
|
|
274
|
+
|
|
275
|
+
related_edges_results: list[SearchResults] = await semaphore_gather(
|
|
276
|
+
*[
|
|
277
|
+
search(
|
|
278
|
+
clients,
|
|
279
|
+
extracted_edge.fact,
|
|
280
|
+
group_ids=[extracted_edge.group_id],
|
|
281
|
+
config=EDGE_HYBRID_SEARCH_RRF,
|
|
282
|
+
search_filter=SearchFilters(edge_uuids=[edge.uuid for edge in valid_edges]),
|
|
283
|
+
)
|
|
284
|
+
for extracted_edge, valid_edges in zip(extracted_edges, valid_edges_list, strict=True)
|
|
285
|
+
]
|
|
264
286
|
)
|
|
265
287
|
|
|
266
|
-
related_edges_lists
|
|
288
|
+
related_edges_lists: list[list[EntityEdge]] = [result.edges for result in related_edges_results]
|
|
289
|
+
|
|
290
|
+
edge_invalidation_candidate_results: list[SearchResults] = await semaphore_gather(
|
|
291
|
+
*[
|
|
292
|
+
search(
|
|
293
|
+
clients,
|
|
294
|
+
extracted_edge.fact,
|
|
295
|
+
group_ids=[extracted_edge.group_id],
|
|
296
|
+
config=EDGE_HYBRID_SEARCH_RRF,
|
|
297
|
+
search_filter=SearchFilters(),
|
|
298
|
+
)
|
|
299
|
+
for extracted_edge in extracted_edges
|
|
300
|
+
]
|
|
301
|
+
)
|
|
302
|
+
|
|
303
|
+
edge_invalidation_candidates: list[list[EntityEdge]] = [
|
|
304
|
+
result.edges for result in edge_invalidation_candidate_results
|
|
305
|
+
]
|
|
267
306
|
|
|
268
307
|
logger.debug(
|
|
269
308
|
f'Related edges lists: {[(e.name, e.uuid) for edges_lst in related_edges_lists for e in edges_lst]}'
|
|
@@ -272,8 +311,12 @@ async def resolve_extracted_edges(
|
|
|
272
311
|
# Build entity hash table
|
|
273
312
|
uuid_entity_map: dict[str, EntityNode] = {entity.uuid: entity for entity in entities}
|
|
274
313
|
|
|
275
|
-
# Determine which edge types are relevant for each edge
|
|
314
|
+
# Determine which edge types are relevant for each edge.
|
|
315
|
+
# `edge_types_lst` stores the subset of custom edge definitions whose
|
|
316
|
+
# node signature matches each extracted edge. Anything outside this subset
|
|
317
|
+
# should only stay on the edge if it is a non-custom (LLM generated) label.
|
|
276
318
|
edge_types_lst: list[dict[str, type[BaseModel]]] = []
|
|
319
|
+
custom_type_names = set(edge_types or {})
|
|
277
320
|
for extracted_edge in extracted_edges:
|
|
278
321
|
source_node = uuid_entity_map.get(extracted_edge.source_node_uuid)
|
|
279
322
|
target_node = uuid_entity_map.get(extracted_edge.target_node_uuid)
|
|
@@ -301,6 +344,20 @@ async def resolve_extracted_edges(
|
|
|
301
344
|
|
|
302
345
|
edge_types_lst.append(extracted_edge_types)
|
|
303
346
|
|
|
347
|
+
for extracted_edge, extracted_edge_types in zip(extracted_edges, edge_types_lst, strict=True):
|
|
348
|
+
allowed_type_names = set(extracted_edge_types)
|
|
349
|
+
is_custom_name = extracted_edge.name in custom_type_names
|
|
350
|
+
if not allowed_type_names:
|
|
351
|
+
# No custom types are valid for this node pairing. Keep LLM generated
|
|
352
|
+
# labels, but flip disallowed custom names back to the default.
|
|
353
|
+
if is_custom_name and extracted_edge.name != DEFAULT_EDGE_NAME:
|
|
354
|
+
extracted_edge.name = DEFAULT_EDGE_NAME
|
|
355
|
+
continue
|
|
356
|
+
if is_custom_name and extracted_edge.name not in allowed_type_names:
|
|
357
|
+
# Custom name exists but it is not permitted for this source/target
|
|
358
|
+
# signature, so fall back to the default edge label.
|
|
359
|
+
extracted_edge.name = DEFAULT_EDGE_NAME
|
|
360
|
+
|
|
304
361
|
# resolve edges with related edges in the graph and find invalidation candidates
|
|
305
362
|
results: list[tuple[EntityEdge, list[EntityEdge], list[EntityEdge]]] = list(
|
|
306
363
|
await semaphore_gather(
|
|
@@ -312,7 +369,7 @@ async def resolve_extracted_edges(
|
|
|
312
369
|
existing_edges,
|
|
313
370
|
episode,
|
|
314
371
|
extracted_edge_types,
|
|
315
|
-
|
|
372
|
+
custom_type_names,
|
|
316
373
|
)
|
|
317
374
|
for extracted_edge, related_edges, existing_edges, extracted_edge_types in zip(
|
|
318
375
|
extracted_edges,
|
|
@@ -383,33 +440,69 @@ async def resolve_extracted_edge(
|
|
|
383
440
|
related_edges: list[EntityEdge],
|
|
384
441
|
existing_edges: list[EntityEdge],
|
|
385
442
|
episode: EpisodicNode,
|
|
386
|
-
|
|
387
|
-
|
|
443
|
+
edge_type_candidates: dict[str, type[BaseModel]] | None = None,
|
|
444
|
+
custom_edge_type_names: set[str] | None = None,
|
|
388
445
|
) -> tuple[EntityEdge, list[EntityEdge], list[EntityEdge]]:
|
|
446
|
+
"""Resolve an extracted edge against existing graph context.
|
|
447
|
+
|
|
448
|
+
Parameters
|
|
449
|
+
----------
|
|
450
|
+
llm_client : LLMClient
|
|
451
|
+
Client used to invoke the LLM for deduplication and attribute extraction.
|
|
452
|
+
extracted_edge : EntityEdge
|
|
453
|
+
Newly extracted edge whose canonical representation is being resolved.
|
|
454
|
+
related_edges : list[EntityEdge]
|
|
455
|
+
Candidate edges with identical endpoints used for duplicate detection.
|
|
456
|
+
existing_edges : list[EntityEdge]
|
|
457
|
+
Broader set of edges evaluated for contradiction / invalidation.
|
|
458
|
+
episode : EpisodicNode
|
|
459
|
+
Episode providing content context when extracting edge attributes.
|
|
460
|
+
edge_type_candidates : dict[str, type[BaseModel]] | None
|
|
461
|
+
Custom edge types permitted for the current source/target signature.
|
|
462
|
+
custom_edge_type_names : set[str] | None
|
|
463
|
+
Full catalog of registered custom edge names. Used to distinguish
|
|
464
|
+
between disallowed custom types (which fall back to the default label)
|
|
465
|
+
and ad-hoc labels emitted by the LLM.
|
|
466
|
+
|
|
467
|
+
Returns
|
|
468
|
+
-------
|
|
469
|
+
tuple[EntityEdge, list[EntityEdge], list[EntityEdge]]
|
|
470
|
+
The resolved edge, any duplicates, and edges to invalidate.
|
|
471
|
+
"""
|
|
389
472
|
if len(related_edges) == 0 and len(existing_edges) == 0:
|
|
390
473
|
return extracted_edge, [], []
|
|
391
474
|
|
|
475
|
+
# Fast path: if the fact text and endpoints already exist verbatim, reuse the matching edge.
|
|
476
|
+
normalized_fact = _normalize_string_exact(extracted_edge.fact)
|
|
477
|
+
for edge in related_edges:
|
|
478
|
+
if (
|
|
479
|
+
edge.source_node_uuid == extracted_edge.source_node_uuid
|
|
480
|
+
and edge.target_node_uuid == extracted_edge.target_node_uuid
|
|
481
|
+
and _normalize_string_exact(edge.fact) == normalized_fact
|
|
482
|
+
):
|
|
483
|
+
resolved = edge
|
|
484
|
+
if episode is not None and episode.uuid not in resolved.episodes:
|
|
485
|
+
resolved.episodes.append(episode.uuid)
|
|
486
|
+
return resolved, [], []
|
|
487
|
+
|
|
392
488
|
start = time()
|
|
393
489
|
|
|
394
490
|
# Prepare context for LLM
|
|
395
|
-
related_edges_context = [
|
|
396
|
-
{'id': edge.uuid, 'fact': edge.fact} for i, edge in enumerate(related_edges)
|
|
397
|
-
]
|
|
491
|
+
related_edges_context = [{'idx': i, 'fact': edge.fact} for i, edge in enumerate(related_edges)]
|
|
398
492
|
|
|
399
493
|
invalidation_edge_candidates_context = [
|
|
400
|
-
{'
|
|
494
|
+
{'idx': i, 'fact': existing_edge.fact} for i, existing_edge in enumerate(existing_edges)
|
|
401
495
|
]
|
|
402
496
|
|
|
403
497
|
edge_types_context = (
|
|
404
498
|
[
|
|
405
499
|
{
|
|
406
|
-
'fact_type_id': i,
|
|
407
500
|
'fact_type_name': type_name,
|
|
408
501
|
'fact_type_description': type_model.__doc__,
|
|
409
502
|
}
|
|
410
|
-
for
|
|
503
|
+
for type_name, type_model in edge_type_candidates.items()
|
|
411
504
|
]
|
|
412
|
-
if
|
|
505
|
+
if edge_type_candidates is not None
|
|
413
506
|
else []
|
|
414
507
|
)
|
|
415
508
|
|
|
@@ -418,9 +511,17 @@ async def resolve_extracted_edge(
|
|
|
418
511
|
'new_edge': extracted_edge.fact,
|
|
419
512
|
'edge_invalidation_candidates': invalidation_edge_candidates_context,
|
|
420
513
|
'edge_types': edge_types_context,
|
|
421
|
-
'ensure_ascii': ensure_ascii,
|
|
422
514
|
}
|
|
423
515
|
|
|
516
|
+
if related_edges or existing_edges:
|
|
517
|
+
logger.debug(
|
|
518
|
+
'Resolving edge: sent %d EXISTING FACTS%s and %d INVALIDATION CANDIDATES%s',
|
|
519
|
+
len(related_edges),
|
|
520
|
+
f' (idx 0-{len(related_edges) - 1})' if related_edges else '',
|
|
521
|
+
len(existing_edges),
|
|
522
|
+
f' (idx 0-{len(existing_edges) - 1})' if existing_edges else '',
|
|
523
|
+
)
|
|
524
|
+
|
|
424
525
|
llm_response = await llm_client.generate_response(
|
|
425
526
|
prompt_library.dedupe_edges.resolve_edge(context),
|
|
426
527
|
response_model=EdgeDuplicate,
|
|
@@ -429,6 +530,15 @@ async def resolve_extracted_edge(
|
|
|
429
530
|
response_object = EdgeDuplicate(**llm_response)
|
|
430
531
|
duplicate_facts = response_object.duplicate_facts
|
|
431
532
|
|
|
533
|
+
# Validate duplicate_facts are in valid range for EXISTING FACTS
|
|
534
|
+
invalid_duplicates = [i for i in duplicate_facts if i < 0 or i >= len(related_edges)]
|
|
535
|
+
if invalid_duplicates:
|
|
536
|
+
logger.warning(
|
|
537
|
+
'LLM returned invalid duplicate_facts idx values %s (valid range: 0-%d for EXISTING FACTS)',
|
|
538
|
+
invalid_duplicates,
|
|
539
|
+
len(related_edges) - 1,
|
|
540
|
+
)
|
|
541
|
+
|
|
432
542
|
duplicate_fact_ids: list[int] = [i for i in duplicate_facts if 0 <= i < len(related_edges)]
|
|
433
543
|
|
|
434
544
|
resolved_edge = extracted_edge
|
|
@@ -441,22 +551,39 @@ async def resolve_extracted_edge(
|
|
|
441
551
|
|
|
442
552
|
contradicted_facts: list[int] = response_object.contradicted_facts
|
|
443
553
|
|
|
554
|
+
# Validate contradicted_facts are in valid range for INVALIDATION CANDIDATES
|
|
555
|
+
invalid_contradictions = [i for i in contradicted_facts if i < 0 or i >= len(existing_edges)]
|
|
556
|
+
if invalid_contradictions:
|
|
557
|
+
logger.warning(
|
|
558
|
+
'LLM returned invalid contradicted_facts idx values %s (valid range: 0-%d for INVALIDATION CANDIDATES)',
|
|
559
|
+
invalid_contradictions,
|
|
560
|
+
len(existing_edges) - 1,
|
|
561
|
+
)
|
|
562
|
+
|
|
444
563
|
invalidation_candidates: list[EntityEdge] = [
|
|
445
564
|
existing_edges[i] for i in contradicted_facts if 0 <= i < len(existing_edges)
|
|
446
565
|
]
|
|
447
566
|
|
|
448
567
|
fact_type: str = response_object.fact_type
|
|
449
|
-
|
|
568
|
+
candidate_type_names = set(edge_type_candidates or {})
|
|
569
|
+
custom_type_names = custom_edge_type_names or set()
|
|
570
|
+
|
|
571
|
+
is_default_type = fact_type.upper() == 'DEFAULT'
|
|
572
|
+
is_custom_type = fact_type in custom_type_names
|
|
573
|
+
is_allowed_custom_type = fact_type in candidate_type_names
|
|
574
|
+
|
|
575
|
+
if is_allowed_custom_type:
|
|
576
|
+
# The LLM selected a custom type that is allowed for the node pair.
|
|
577
|
+
# Adopt the custom type and, if needed, extract its structured attributes.
|
|
450
578
|
resolved_edge.name = fact_type
|
|
451
579
|
|
|
452
580
|
edge_attributes_context = {
|
|
453
581
|
'episode_content': episode.content,
|
|
454
582
|
'reference_time': episode.valid_at,
|
|
455
583
|
'fact': resolved_edge.fact,
|
|
456
|
-
'ensure_ascii': ensure_ascii,
|
|
457
584
|
}
|
|
458
585
|
|
|
459
|
-
edge_model =
|
|
586
|
+
edge_model = edge_type_candidates.get(fact_type) if edge_type_candidates else None
|
|
460
587
|
if edge_model is not None and len(edge_model.model_fields) != 0:
|
|
461
588
|
edge_attributes_response = await llm_client.generate_response(
|
|
462
589
|
prompt_library.extract_edges.extract_attributes(edge_attributes_context),
|
|
@@ -465,6 +592,16 @@ async def resolve_extracted_edge(
|
|
|
465
592
|
)
|
|
466
593
|
|
|
467
594
|
resolved_edge.attributes = edge_attributes_response
|
|
595
|
+
elif not is_default_type and is_custom_type:
|
|
596
|
+
# The LLM picked a custom type that is not allowed for this signature.
|
|
597
|
+
# Reset to the default label and drop any structured attributes.
|
|
598
|
+
resolved_edge.name = DEFAULT_EDGE_NAME
|
|
599
|
+
resolved_edge.attributes = {}
|
|
600
|
+
elif not is_default_type:
|
|
601
|
+
# Non-custom labels are allowed to pass through so long as the LLM does
|
|
602
|
+
# not return the sentinel DEFAULT value.
|
|
603
|
+
resolved_edge.name = fact_type
|
|
604
|
+
resolved_edge.attributes = {}
|
|
468
605
|
|
|
469
606
|
end = time()
|
|
470
607
|
logger.debug(
|
|
@@ -34,7 +34,7 @@ logger = logging.getLogger(__name__)
|
|
|
34
34
|
|
|
35
35
|
|
|
36
36
|
async def build_indices_and_constraints(driver: GraphDriver, delete_existing: bool = False):
|
|
37
|
-
if driver.
|
|
37
|
+
if driver.aoss_client:
|
|
38
38
|
await driver.create_aoss_indices() # pyright: ignore[reportAttributeAccessIssue]
|
|
39
39
|
return
|
|
40
40
|
if delete_existing:
|
|
@@ -56,7 +56,9 @@ async def build_indices_and_constraints(driver: GraphDriver, delete_existing: bo
|
|
|
56
56
|
|
|
57
57
|
range_indices: list[LiteralString] = get_range_indices(driver.provider)
|
|
58
58
|
|
|
59
|
-
|
|
59
|
+
# Don't create fulltext indices if OpenSearch is being used
|
|
60
|
+
if not driver.aoss_client:
|
|
61
|
+
fulltext_indices: list[LiteralString] = get_fulltext_indices(driver.provider)
|
|
60
62
|
|
|
61
63
|
if driver.provider == GraphProvider.KUZU:
|
|
62
64
|
# Skip creating fulltext indices if they already exist. Need to do this manually
|
|
@@ -93,6 +95,8 @@ async def clear_data(driver: GraphDriver, group_ids: list[str] | None = None):
|
|
|
93
95
|
|
|
94
96
|
async def delete_all(tx):
|
|
95
97
|
await tx.run('MATCH (n) DETACH DELETE n')
|
|
98
|
+
if driver.aoss_client:
|
|
99
|
+
await driver.clear_aoss_indices()
|
|
96
100
|
|
|
97
101
|
async def delete_group_ids(tx):
|
|
98
102
|
labels = ['Entity', 'Episodic', 'Community']
|
|
@@ -149,9 +153,9 @@ async def retrieve_episodes(
|
|
|
149
153
|
|
|
150
154
|
query: LiteralString = (
|
|
151
155
|
"""
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
156
|
+
MATCH (e:Episodic)
|
|
157
|
+
WHERE e.valid_at <= $reference_time
|
|
158
|
+
"""
|
|
155
159
|
+ query_filter
|
|
156
160
|
+ """
|
|
157
161
|
RETURN
|