graphiti-core 0.21.0rc12__py3-none-any.whl → 0.22.0__py3-none-any.whl
This diff shows the contents of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of graphiti-core might be problematic. Click here for more details.
- graphiti_core/driver/driver.py +4 -211
- graphiti_core/driver/falkordb_driver.py +31 -3
- graphiti_core/driver/graph_operations/graph_operations.py +195 -0
- graphiti_core/driver/neo4j_driver.py +0 -49
- graphiti_core/driver/neptune_driver.py +43 -26
- graphiti_core/driver/search_interface/__init__.py +0 -0
- graphiti_core/driver/search_interface/search_interface.py +89 -0
- graphiti_core/edges.py +11 -34
- graphiti_core/graphiti.py +459 -326
- graphiti_core/graphiti_types.py +2 -0
- graphiti_core/llm_client/anthropic_client.py +64 -45
- graphiti_core/llm_client/client.py +67 -19
- graphiti_core/llm_client/gemini_client.py +73 -54
- graphiti_core/llm_client/openai_base_client.py +65 -43
- graphiti_core/llm_client/openai_generic_client.py +65 -43
- graphiti_core/models/edges/edge_db_queries.py +1 -0
- graphiti_core/models/nodes/node_db_queries.py +1 -0
- graphiti_core/nodes.py +26 -99
- graphiti_core/prompts/dedupe_edges.py +4 -4
- graphiti_core/prompts/dedupe_nodes.py +10 -10
- graphiti_core/prompts/extract_edges.py +4 -4
- graphiti_core/prompts/extract_nodes.py +26 -28
- graphiti_core/prompts/prompt_helpers.py +18 -2
- graphiti_core/prompts/snippets.py +29 -0
- graphiti_core/prompts/summarize_nodes.py +22 -24
- graphiti_core/search/search_filters.py +0 -38
- graphiti_core/search/search_helpers.py +4 -4
- graphiti_core/search/search_utils.py +84 -220
- graphiti_core/tracer.py +193 -0
- graphiti_core/utils/bulk_utils.py +16 -28
- graphiti_core/utils/maintenance/community_operations.py +4 -1
- graphiti_core/utils/maintenance/edge_operations.py +30 -15
- graphiti_core/utils/maintenance/graph_data_operations.py +6 -25
- graphiti_core/utils/maintenance/node_operations.py +99 -51
- graphiti_core/utils/maintenance/temporal_operations.py +4 -1
- graphiti_core/utils/text_utils.py +53 -0
- {graphiti_core-0.21.0rc12.dist-info → graphiti_core-0.22.0.dist-info}/METADATA +7 -3
- {graphiti_core-0.21.0rc12.dist-info → graphiti_core-0.22.0.dist-info}/RECORD +41 -35
- /graphiti_core/{utils/maintenance/utils.py → driver/graph_operations/__init__.py} +0 -0
- {graphiti_core-0.21.0rc12.dist-info → graphiti_core-0.22.0.dist-info}/WHEEL +0 -0
- {graphiti_core-0.21.0rc12.dist-info → graphiti_core-0.22.0.dist-info}/licenses/LICENSE +0 -0
graphiti_core/graphiti.py
CHANGED
|
@@ -63,6 +63,7 @@ from graphiti_core.search.search_utils import (
|
|
|
63
63
|
get_mentioned_nodes,
|
|
64
64
|
)
|
|
65
65
|
from graphiti_core.telemetry import capture_event
|
|
66
|
+
from graphiti_core.tracer import Tracer, create_tracer
|
|
66
67
|
from graphiti_core.utils.bulk_utils import (
|
|
67
68
|
RawEpisode,
|
|
68
69
|
add_nodes_and_edges_bulk,
|
|
@@ -136,6 +137,8 @@ class Graphiti:
|
|
|
136
137
|
store_raw_episode_content: bool = True,
|
|
137
138
|
graph_driver: GraphDriver | None = None,
|
|
138
139
|
max_coroutines: int | None = None,
|
|
140
|
+
tracer: Tracer | None = None,
|
|
141
|
+
trace_span_prefix: str = 'graphiti',
|
|
139
142
|
):
|
|
140
143
|
"""
|
|
141
144
|
Initialize a Graphiti instance.
|
|
@@ -168,6 +171,10 @@ class Graphiti:
|
|
|
168
171
|
max_coroutines : int | None, optional
|
|
169
172
|
The maximum number of concurrent operations allowed. Overrides SEMAPHORE_LIMIT set in the environment.
|
|
170
173
|
If not set, the Graphiti default is used.
|
|
174
|
+
tracer : Tracer | None, optional
|
|
175
|
+
An OpenTelemetry tracer instance for distributed tracing. If not provided, tracing is disabled (no-op).
|
|
176
|
+
trace_span_prefix : str, optional
|
|
177
|
+
Prefix to prepend to all span names. Defaults to 'graphiti'.
|
|
171
178
|
|
|
172
179
|
Returns
|
|
173
180
|
-------
|
|
@@ -210,11 +217,18 @@ class Graphiti:
|
|
|
210
217
|
else:
|
|
211
218
|
self.cross_encoder = OpenAIRerankerClient()
|
|
212
219
|
|
|
220
|
+
# Initialize tracer
|
|
221
|
+
self.tracer = create_tracer(tracer, trace_span_prefix)
|
|
222
|
+
|
|
223
|
+
# Set tracer on clients
|
|
224
|
+
self.llm_client.set_tracer(self.tracer)
|
|
225
|
+
|
|
213
226
|
self.clients = GraphitiClients(
|
|
214
227
|
driver=self.driver,
|
|
215
228
|
llm_client=self.llm_client,
|
|
216
229
|
embedder=self.embedder,
|
|
217
230
|
cross_encoder=self.cross_encoder,
|
|
231
|
+
tracer=self.tracer,
|
|
218
232
|
)
|
|
219
233
|
|
|
220
234
|
# Capture telemetry event
|
|
@@ -339,6 +353,227 @@ class Graphiti:
|
|
|
339
353
|
"""
|
|
340
354
|
await build_indices_and_constraints(self.driver, delete_existing)
|
|
341
355
|
|
|
356
|
+
async def _extract_and_resolve_nodes(
|
|
357
|
+
self,
|
|
358
|
+
episode: EpisodicNode,
|
|
359
|
+
previous_episodes: list[EpisodicNode],
|
|
360
|
+
entity_types: dict[str, type[BaseModel]] | None,
|
|
361
|
+
excluded_entity_types: list[str] | None,
|
|
362
|
+
) -> tuple[list[EntityNode], dict[str, str], list[tuple[EntityNode, EntityNode]]]:
|
|
363
|
+
"""Extract nodes from episode and resolve against existing graph."""
|
|
364
|
+
extracted_nodes = await extract_nodes(
|
|
365
|
+
self.clients, episode, previous_episodes, entity_types, excluded_entity_types
|
|
366
|
+
)
|
|
367
|
+
|
|
368
|
+
nodes, uuid_map, duplicates = await resolve_extracted_nodes(
|
|
369
|
+
self.clients,
|
|
370
|
+
extracted_nodes,
|
|
371
|
+
episode,
|
|
372
|
+
previous_episodes,
|
|
373
|
+
entity_types,
|
|
374
|
+
)
|
|
375
|
+
|
|
376
|
+
return nodes, uuid_map, duplicates
|
|
377
|
+
|
|
378
|
+
async def _extract_and_resolve_edges(
|
|
379
|
+
self,
|
|
380
|
+
episode: EpisodicNode,
|
|
381
|
+
extracted_nodes: list[EntityNode],
|
|
382
|
+
previous_episodes: list[EpisodicNode],
|
|
383
|
+
edge_type_map: dict[tuple[str, str], list[str]],
|
|
384
|
+
group_id: str,
|
|
385
|
+
edge_types: dict[str, type[BaseModel]] | None,
|
|
386
|
+
nodes: list[EntityNode],
|
|
387
|
+
uuid_map: dict[str, str],
|
|
388
|
+
) -> tuple[list[EntityEdge], list[EntityEdge]]:
|
|
389
|
+
"""Extract edges from episode and resolve against existing graph."""
|
|
390
|
+
extracted_edges = await extract_edges(
|
|
391
|
+
self.clients,
|
|
392
|
+
episode,
|
|
393
|
+
extracted_nodes,
|
|
394
|
+
previous_episodes,
|
|
395
|
+
edge_type_map,
|
|
396
|
+
group_id,
|
|
397
|
+
edge_types,
|
|
398
|
+
)
|
|
399
|
+
|
|
400
|
+
edges = resolve_edge_pointers(extracted_edges, uuid_map)
|
|
401
|
+
|
|
402
|
+
resolved_edges, invalidated_edges = await resolve_extracted_edges(
|
|
403
|
+
self.clients,
|
|
404
|
+
edges,
|
|
405
|
+
episode,
|
|
406
|
+
nodes,
|
|
407
|
+
edge_types or {},
|
|
408
|
+
edge_type_map,
|
|
409
|
+
)
|
|
410
|
+
|
|
411
|
+
return resolved_edges, invalidated_edges
|
|
412
|
+
|
|
413
|
+
async def _process_episode_data(
|
|
414
|
+
self,
|
|
415
|
+
episode: EpisodicNode,
|
|
416
|
+
nodes: list[EntityNode],
|
|
417
|
+
entity_edges: list[EntityEdge],
|
|
418
|
+
now: datetime,
|
|
419
|
+
) -> tuple[list[EpisodicEdge], EpisodicNode]:
|
|
420
|
+
"""Process and save episode data to the graph."""
|
|
421
|
+
episodic_edges = build_episodic_edges(nodes, episode.uuid, now)
|
|
422
|
+
episode.entity_edges = [edge.uuid for edge in entity_edges]
|
|
423
|
+
|
|
424
|
+
if not self.store_raw_episode_content:
|
|
425
|
+
episode.content = ''
|
|
426
|
+
|
|
427
|
+
await add_nodes_and_edges_bulk(
|
|
428
|
+
self.driver,
|
|
429
|
+
[episode],
|
|
430
|
+
episodic_edges,
|
|
431
|
+
nodes,
|
|
432
|
+
entity_edges,
|
|
433
|
+
self.embedder,
|
|
434
|
+
)
|
|
435
|
+
|
|
436
|
+
return episodic_edges, episode
|
|
437
|
+
|
|
438
|
+
async def _extract_and_dedupe_nodes_bulk(
|
|
439
|
+
self,
|
|
440
|
+
episode_context: list[tuple[EpisodicNode, list[EpisodicNode]]],
|
|
441
|
+
edge_type_map: dict[tuple[str, str], list[str]],
|
|
442
|
+
edge_types: dict[str, type[BaseModel]] | None,
|
|
443
|
+
entity_types: dict[str, type[BaseModel]] | None,
|
|
444
|
+
excluded_entity_types: list[str] | None,
|
|
445
|
+
) -> tuple[
|
|
446
|
+
dict[str, list[EntityNode]],
|
|
447
|
+
dict[str, str],
|
|
448
|
+
list[list[EntityEdge]],
|
|
449
|
+
]:
|
|
450
|
+
"""Extract nodes and edges from all episodes and deduplicate."""
|
|
451
|
+
# Extract all nodes and edges for each episode
|
|
452
|
+
extracted_nodes_bulk, extracted_edges_bulk = await extract_nodes_and_edges_bulk(
|
|
453
|
+
self.clients,
|
|
454
|
+
episode_context,
|
|
455
|
+
edge_type_map=edge_type_map,
|
|
456
|
+
edge_types=edge_types,
|
|
457
|
+
entity_types=entity_types,
|
|
458
|
+
excluded_entity_types=excluded_entity_types,
|
|
459
|
+
)
|
|
460
|
+
|
|
461
|
+
# Dedupe extracted nodes in memory
|
|
462
|
+
nodes_by_episode, uuid_map = await dedupe_nodes_bulk(
|
|
463
|
+
self.clients, extracted_nodes_bulk, episode_context, entity_types
|
|
464
|
+
)
|
|
465
|
+
|
|
466
|
+
return nodes_by_episode, uuid_map, extracted_edges_bulk
|
|
467
|
+
|
|
468
|
+
async def _resolve_nodes_and_edges_bulk(
|
|
469
|
+
self,
|
|
470
|
+
nodes_by_episode: dict[str, list[EntityNode]],
|
|
471
|
+
edges_by_episode: dict[str, list[EntityEdge]],
|
|
472
|
+
episode_context: list[tuple[EpisodicNode, list[EpisodicNode]]],
|
|
473
|
+
entity_types: dict[str, type[BaseModel]] | None,
|
|
474
|
+
edge_types: dict[str, type[BaseModel]] | None,
|
|
475
|
+
edge_type_map: dict[tuple[str, str], list[str]],
|
|
476
|
+
episodes: list[EpisodicNode],
|
|
477
|
+
) -> tuple[list[EntityNode], list[EntityEdge], list[EntityEdge], dict[str, str]]:
|
|
478
|
+
"""Resolve nodes and edges against the existing graph."""
|
|
479
|
+
nodes_by_uuid: dict[str, EntityNode] = {
|
|
480
|
+
node.uuid: node for nodes in nodes_by_episode.values() for node in nodes
|
|
481
|
+
}
|
|
482
|
+
|
|
483
|
+
# Get unique nodes per episode
|
|
484
|
+
nodes_by_episode_unique: dict[str, list[EntityNode]] = {}
|
|
485
|
+
nodes_uuid_set: set[str] = set()
|
|
486
|
+
for episode, _ in episode_context:
|
|
487
|
+
nodes_by_episode_unique[episode.uuid] = []
|
|
488
|
+
nodes = [nodes_by_uuid[node.uuid] for node in nodes_by_episode[episode.uuid]]
|
|
489
|
+
for node in nodes:
|
|
490
|
+
if node.uuid not in nodes_uuid_set:
|
|
491
|
+
nodes_by_episode_unique[episode.uuid].append(node)
|
|
492
|
+
nodes_uuid_set.add(node.uuid)
|
|
493
|
+
|
|
494
|
+
# Resolve nodes
|
|
495
|
+
node_results = await semaphore_gather(
|
|
496
|
+
*[
|
|
497
|
+
resolve_extracted_nodes(
|
|
498
|
+
self.clients,
|
|
499
|
+
nodes_by_episode_unique[episode.uuid],
|
|
500
|
+
episode,
|
|
501
|
+
previous_episodes,
|
|
502
|
+
entity_types,
|
|
503
|
+
)
|
|
504
|
+
for episode, previous_episodes in episode_context
|
|
505
|
+
]
|
|
506
|
+
)
|
|
507
|
+
|
|
508
|
+
resolved_nodes: list[EntityNode] = []
|
|
509
|
+
uuid_map: dict[str, str] = {}
|
|
510
|
+
for result in node_results:
|
|
511
|
+
resolved_nodes.extend(result[0])
|
|
512
|
+
uuid_map.update(result[1])
|
|
513
|
+
|
|
514
|
+
# Update nodes_by_uuid with resolved nodes
|
|
515
|
+
for resolved_node in resolved_nodes:
|
|
516
|
+
nodes_by_uuid[resolved_node.uuid] = resolved_node
|
|
517
|
+
|
|
518
|
+
# Update nodes_by_episode_unique with resolved pointers
|
|
519
|
+
for episode_uuid, nodes in nodes_by_episode_unique.items():
|
|
520
|
+
updated_nodes: list[EntityNode] = []
|
|
521
|
+
for node in nodes:
|
|
522
|
+
updated_node_uuid = uuid_map.get(node.uuid, node.uuid)
|
|
523
|
+
updated_node = nodes_by_uuid[updated_node_uuid]
|
|
524
|
+
updated_nodes.append(updated_node)
|
|
525
|
+
nodes_by_episode_unique[episode_uuid] = updated_nodes
|
|
526
|
+
|
|
527
|
+
# Extract attributes for resolved nodes
|
|
528
|
+
hydrated_nodes_results: list[list[EntityNode]] = await semaphore_gather(
|
|
529
|
+
*[
|
|
530
|
+
extract_attributes_from_nodes(
|
|
531
|
+
self.clients,
|
|
532
|
+
nodes_by_episode_unique[episode.uuid],
|
|
533
|
+
episode,
|
|
534
|
+
previous_episodes,
|
|
535
|
+
entity_types,
|
|
536
|
+
)
|
|
537
|
+
for episode, previous_episodes in episode_context
|
|
538
|
+
]
|
|
539
|
+
)
|
|
540
|
+
|
|
541
|
+
final_hydrated_nodes = [node for nodes in hydrated_nodes_results for node in nodes]
|
|
542
|
+
|
|
543
|
+
# Resolve edges with updated pointers
|
|
544
|
+
edges_by_episode_unique: dict[str, list[EntityEdge]] = {}
|
|
545
|
+
edges_uuid_set: set[str] = set()
|
|
546
|
+
for episode_uuid, edges in edges_by_episode.items():
|
|
547
|
+
edges_with_updated_pointers = resolve_edge_pointers(edges, uuid_map)
|
|
548
|
+
edges_by_episode_unique[episode_uuid] = []
|
|
549
|
+
|
|
550
|
+
for edge in edges_with_updated_pointers:
|
|
551
|
+
if edge.uuid not in edges_uuid_set:
|
|
552
|
+
edges_by_episode_unique[episode_uuid].append(edge)
|
|
553
|
+
edges_uuid_set.add(edge.uuid)
|
|
554
|
+
|
|
555
|
+
edge_results = await semaphore_gather(
|
|
556
|
+
*[
|
|
557
|
+
resolve_extracted_edges(
|
|
558
|
+
self.clients,
|
|
559
|
+
edges_by_episode_unique[episode.uuid],
|
|
560
|
+
episode,
|
|
561
|
+
final_hydrated_nodes,
|
|
562
|
+
edge_types or {},
|
|
563
|
+
edge_type_map,
|
|
564
|
+
)
|
|
565
|
+
for episode in episodes
|
|
566
|
+
]
|
|
567
|
+
)
|
|
568
|
+
|
|
569
|
+
resolved_edges: list[EntityEdge] = []
|
|
570
|
+
invalidated_edges: list[EntityEdge] = []
|
|
571
|
+
for result in edge_results:
|
|
572
|
+
resolved_edges.extend(result[0])
|
|
573
|
+
invalidated_edges.extend(result[1])
|
|
574
|
+
|
|
575
|
+
return final_hydrated_nodes, resolved_edges, invalidated_edges, uuid_map
|
|
576
|
+
|
|
342
577
|
async def retrieve_episodes(
|
|
343
578
|
self,
|
|
344
579
|
reference_time: datetime,
|
|
@@ -444,133 +679,138 @@ class Graphiti:
|
|
|
444
679
|
background_tasks.add_task(graphiti.add_episode, **episode_data.dict())
|
|
445
680
|
return {"message": "Episode processing started"}
|
|
446
681
|
"""
|
|
447
|
-
|
|
448
|
-
|
|
449
|
-
|
|
450
|
-
|
|
451
|
-
validate_entity_types(entity_types)
|
|
452
|
-
|
|
453
|
-
validate_excluded_entity_types(excluded_entity_types, entity_types)
|
|
454
|
-
validate_group_id(group_id)
|
|
455
|
-
# if group_id is None, use the default group id by the provider
|
|
456
|
-
group_id = group_id or get_default_group_id(self.driver.provider)
|
|
457
|
-
|
|
458
|
-
previous_episodes = (
|
|
459
|
-
await self.retrieve_episodes(
|
|
460
|
-
reference_time,
|
|
461
|
-
last_n=RELEVANT_SCHEMA_LIMIT,
|
|
462
|
-
group_ids=[group_id],
|
|
463
|
-
source=source,
|
|
464
|
-
)
|
|
465
|
-
if previous_episode_uuids is None
|
|
466
|
-
else await EpisodicNode.get_by_uuids(self.driver, previous_episode_uuids)
|
|
467
|
-
)
|
|
682
|
+
start = time()
|
|
683
|
+
now = utc_now()
|
|
684
|
+
|
|
685
|
+
validate_entity_types(entity_types)
|
|
468
686
|
|
|
469
|
-
|
|
470
|
-
|
|
471
|
-
|
|
472
|
-
|
|
473
|
-
|
|
474
|
-
|
|
475
|
-
|
|
476
|
-
|
|
477
|
-
|
|
478
|
-
|
|
479
|
-
|
|
480
|
-
|
|
687
|
+
validate_excluded_entity_types(excluded_entity_types, entity_types)
|
|
688
|
+
validate_group_id(group_id)
|
|
689
|
+
# if group_id is None, use the default group id by the provider
|
|
690
|
+
group_id = group_id or get_default_group_id(self.driver.provider)
|
|
691
|
+
|
|
692
|
+
with self.tracer.start_span('add_episode') as span:
|
|
693
|
+
try:
|
|
694
|
+
# Retrieve previous episodes for context
|
|
695
|
+
previous_episodes = (
|
|
696
|
+
await self.retrieve_episodes(
|
|
697
|
+
reference_time,
|
|
698
|
+
last_n=RELEVANT_SCHEMA_LIMIT,
|
|
699
|
+
group_ids=[group_id],
|
|
700
|
+
source=source,
|
|
701
|
+
)
|
|
702
|
+
if previous_episode_uuids is None
|
|
703
|
+
else await EpisodicNode.get_by_uuids(self.driver, previous_episode_uuids)
|
|
481
704
|
)
|
|
482
|
-
)
|
|
483
705
|
|
|
484
|
-
|
|
485
|
-
|
|
486
|
-
|
|
487
|
-
|
|
488
|
-
|
|
489
|
-
|
|
706
|
+
# Get or create episode
|
|
707
|
+
episode = (
|
|
708
|
+
await EpisodicNode.get_by_uuid(self.driver, uuid)
|
|
709
|
+
if uuid is not None
|
|
710
|
+
else EpisodicNode(
|
|
711
|
+
name=name,
|
|
712
|
+
group_id=group_id,
|
|
713
|
+
labels=[],
|
|
714
|
+
source=source,
|
|
715
|
+
content=episode_body,
|
|
716
|
+
source_description=source_description,
|
|
717
|
+
created_at=now,
|
|
718
|
+
valid_at=reference_time,
|
|
719
|
+
)
|
|
720
|
+
)
|
|
490
721
|
|
|
491
|
-
|
|
722
|
+
# Create default edge type map
|
|
723
|
+
edge_type_map_default = (
|
|
724
|
+
{('Entity', 'Entity'): list(edge_types.keys())}
|
|
725
|
+
if edge_types is not None
|
|
726
|
+
else {('Entity', 'Entity'): []}
|
|
727
|
+
)
|
|
492
728
|
|
|
493
|
-
|
|
494
|
-
|
|
495
|
-
|
|
729
|
+
# Extract and resolve nodes
|
|
730
|
+
extracted_nodes = await extract_nodes(
|
|
731
|
+
self.clients, episode, previous_episodes, entity_types, excluded_entity_types
|
|
732
|
+
)
|
|
496
733
|
|
|
497
|
-
|
|
498
|
-
(nodes, uuid_map, _), extracted_edges = await semaphore_gather(
|
|
499
|
-
resolve_extracted_nodes(
|
|
734
|
+
nodes, uuid_map, _ = await resolve_extracted_nodes(
|
|
500
735
|
self.clients,
|
|
501
736
|
extracted_nodes,
|
|
502
737
|
episode,
|
|
503
738
|
previous_episodes,
|
|
504
739
|
entity_types,
|
|
505
|
-
)
|
|
506
|
-
|
|
507
|
-
|
|
740
|
+
)
|
|
741
|
+
|
|
742
|
+
# Extract and resolve edges in parallel with attribute extraction
|
|
743
|
+
resolved_edges, invalidated_edges = await self._extract_and_resolve_edges(
|
|
508
744
|
episode,
|
|
509
745
|
extracted_nodes,
|
|
510
746
|
previous_episodes,
|
|
511
747
|
edge_type_map or edge_type_map_default,
|
|
512
748
|
group_id,
|
|
513
749
|
edge_types,
|
|
514
|
-
),
|
|
515
|
-
max_coroutines=self.max_coroutines,
|
|
516
|
-
)
|
|
517
|
-
|
|
518
|
-
edges = resolve_edge_pointers(extracted_edges, uuid_map)
|
|
519
|
-
|
|
520
|
-
(resolved_edges, invalidated_edges), hydrated_nodes = await semaphore_gather(
|
|
521
|
-
resolve_extracted_edges(
|
|
522
|
-
self.clients,
|
|
523
|
-
edges,
|
|
524
|
-
episode,
|
|
525
750
|
nodes,
|
|
526
|
-
|
|
527
|
-
|
|
528
|
-
|
|
529
|
-
|
|
751
|
+
uuid_map,
|
|
752
|
+
)
|
|
753
|
+
|
|
754
|
+
# Extract node attributes
|
|
755
|
+
hydrated_nodes = await extract_attributes_from_nodes(
|
|
530
756
|
self.clients, nodes, episode, previous_episodes, entity_types
|
|
531
|
-
)
|
|
532
|
-
max_coroutines=self.max_coroutines,
|
|
533
|
-
)
|
|
757
|
+
)
|
|
534
758
|
|
|
535
|
-
|
|
759
|
+
entity_edges = resolved_edges + invalidated_edges
|
|
536
760
|
|
|
537
|
-
|
|
761
|
+
# Process and save episode data
|
|
762
|
+
episodic_edges, episode = await self._process_episode_data(
|
|
763
|
+
episode, hydrated_nodes, entity_edges, now
|
|
764
|
+
)
|
|
538
765
|
|
|
539
|
-
|
|
766
|
+
# Update communities if requested
|
|
767
|
+
communities = []
|
|
768
|
+
community_edges = []
|
|
769
|
+
if update_communities:
|
|
770
|
+
communities, community_edges = await semaphore_gather(
|
|
771
|
+
*[
|
|
772
|
+
update_community(self.driver, self.llm_client, self.embedder, node)
|
|
773
|
+
for node in nodes
|
|
774
|
+
],
|
|
775
|
+
max_coroutines=self.max_coroutines,
|
|
776
|
+
)
|
|
540
777
|
|
|
541
|
-
|
|
542
|
-
|
|
778
|
+
end = time()
|
|
779
|
+
|
|
780
|
+
# Add span attributes
|
|
781
|
+
span.add_attributes(
|
|
782
|
+
{
|
|
783
|
+
'episode.uuid': episode.uuid,
|
|
784
|
+
'episode.source': source.value,
|
|
785
|
+
'episode.reference_time': reference_time.isoformat(),
|
|
786
|
+
'group_id': group_id,
|
|
787
|
+
'node.count': len(hydrated_nodes),
|
|
788
|
+
'edge.count': len(entity_edges),
|
|
789
|
+
'edge.invalidated_count': len(invalidated_edges),
|
|
790
|
+
'previous_episodes.count': len(previous_episodes),
|
|
791
|
+
'entity_types.count': len(entity_types) if entity_types else 0,
|
|
792
|
+
'edge_types.count': len(edge_types) if edge_types else 0,
|
|
793
|
+
'update_communities': update_communities,
|
|
794
|
+
'communities.count': len(communities) if update_communities else 0,
|
|
795
|
+
'duration_ms': (end - start) * 1000,
|
|
796
|
+
}
|
|
797
|
+
)
|
|
543
798
|
|
|
544
|
-
|
|
545
|
-
self.driver, [episode], episodic_edges, hydrated_nodes, entity_edges, self.embedder
|
|
546
|
-
)
|
|
799
|
+
logger.info(f'Completed add_episode in {(end - start) * 1000} ms')
|
|
547
800
|
|
|
548
|
-
|
|
549
|
-
|
|
550
|
-
|
|
551
|
-
|
|
552
|
-
|
|
553
|
-
|
|
554
|
-
|
|
555
|
-
update_community(self.driver, self.llm_client, self.embedder, node)
|
|
556
|
-
for node in nodes
|
|
557
|
-
],
|
|
558
|
-
max_coroutines=self.max_coroutines,
|
|
801
|
+
return AddEpisodeResults(
|
|
802
|
+
episode=episode,
|
|
803
|
+
episodic_edges=episodic_edges,
|
|
804
|
+
nodes=hydrated_nodes,
|
|
805
|
+
edges=entity_edges,
|
|
806
|
+
communities=communities,
|
|
807
|
+
community_edges=community_edges,
|
|
559
808
|
)
|
|
560
|
-
end = time()
|
|
561
|
-
logger.info(f'Completed add_episode in {(end - start) * 1000} ms')
|
|
562
|
-
|
|
563
|
-
return AddEpisodeResults(
|
|
564
|
-
episode=episode,
|
|
565
|
-
episodic_edges=episodic_edges,
|
|
566
|
-
nodes=hydrated_nodes,
|
|
567
|
-
edges=entity_edges,
|
|
568
|
-
communities=communities,
|
|
569
|
-
community_edges=community_edges,
|
|
570
|
-
)
|
|
571
809
|
|
|
572
|
-
|
|
573
|
-
|
|
810
|
+
except Exception as e:
|
|
811
|
+
span.set_status('error', str(e))
|
|
812
|
+
span.record_exception(e)
|
|
813
|
+
raise e
|
|
574
814
|
|
|
575
815
|
async def add_episode_bulk(
|
|
576
816
|
self,
|
|
@@ -617,248 +857,141 @@ class Graphiti:
|
|
|
617
857
|
If these operations are required, use the `add_episode` method instead for each
|
|
618
858
|
individual episode.
|
|
619
859
|
"""
|
|
620
|
-
|
|
621
|
-
|
|
622
|
-
now = utc_now()
|
|
623
|
-
|
|
624
|
-
# if group_id is None, use the default group id by the provider
|
|
625
|
-
group_id = group_id or get_default_group_id(self.driver.provider)
|
|
626
|
-
validate_group_id(group_id)
|
|
627
|
-
|
|
628
|
-
# Create default edge type map
|
|
629
|
-
edge_type_map_default = (
|
|
630
|
-
{('Entity', 'Entity'): list(edge_types.keys())}
|
|
631
|
-
if edge_types is not None
|
|
632
|
-
else {('Entity', 'Entity'): []}
|
|
633
|
-
)
|
|
860
|
+
with self.tracer.start_span('add_episode_bulk') as bulk_span:
|
|
861
|
+
bulk_span.add_attributes({'episode.count': len(bulk_episodes)})
|
|
634
862
|
|
|
635
|
-
|
|
636
|
-
|
|
637
|
-
|
|
638
|
-
|
|
639
|
-
|
|
640
|
-
|
|
641
|
-
|
|
642
|
-
|
|
643
|
-
|
|
644
|
-
|
|
645
|
-
|
|
646
|
-
|
|
863
|
+
try:
|
|
864
|
+
start = time()
|
|
865
|
+
now = utc_now()
|
|
866
|
+
|
|
867
|
+
# if group_id is None, use the default group id by the provider
|
|
868
|
+
group_id = group_id or get_default_group_id(self.driver.provider)
|
|
869
|
+
validate_group_id(group_id)
|
|
870
|
+
|
|
871
|
+
# Create default edge type map
|
|
872
|
+
edge_type_map_default = (
|
|
873
|
+
{('Entity', 'Entity'): list(edge_types.keys())}
|
|
874
|
+
if edge_types is not None
|
|
875
|
+
else {('Entity', 'Entity'): []}
|
|
647
876
|
)
|
|
648
|
-
for episode in bulk_episodes
|
|
649
|
-
]
|
|
650
|
-
|
|
651
|
-
episodes_by_uuid: dict[str, EpisodicNode] = {
|
|
652
|
-
episode.uuid: episode for episode in episodes
|
|
653
|
-
}
|
|
654
|
-
|
|
655
|
-
# Save all episodes
|
|
656
|
-
await add_nodes_and_edges_bulk(
|
|
657
|
-
driver=self.driver,
|
|
658
|
-
episodic_nodes=episodes,
|
|
659
|
-
episodic_edges=[],
|
|
660
|
-
entity_nodes=[],
|
|
661
|
-
entity_edges=[],
|
|
662
|
-
embedder=self.embedder,
|
|
663
|
-
)
|
|
664
|
-
|
|
665
|
-
# Get previous episode context for each episode
|
|
666
|
-
episode_context = await retrieve_previous_episodes_bulk(self.driver, episodes)
|
|
667
877
|
|
|
668
|
-
|
|
669
|
-
|
|
670
|
-
|
|
671
|
-
|
|
672
|
-
|
|
673
|
-
|
|
674
|
-
|
|
675
|
-
|
|
676
|
-
|
|
677
|
-
|
|
678
|
-
|
|
679
|
-
|
|
680
|
-
|
|
681
|
-
|
|
878
|
+
episodes = [
|
|
879
|
+
await EpisodicNode.get_by_uuid(self.driver, episode.uuid)
|
|
880
|
+
if episode.uuid is not None
|
|
881
|
+
else EpisodicNode(
|
|
882
|
+
name=episode.name,
|
|
883
|
+
labels=[],
|
|
884
|
+
source=episode.source,
|
|
885
|
+
content=episode.content,
|
|
886
|
+
source_description=episode.source_description,
|
|
887
|
+
group_id=group_id,
|
|
888
|
+
created_at=now,
|
|
889
|
+
valid_at=episode.reference_time,
|
|
890
|
+
)
|
|
891
|
+
for episode in bulk_episodes
|
|
892
|
+
]
|
|
682
893
|
|
|
683
|
-
|
|
684
|
-
|
|
685
|
-
|
|
686
|
-
|
|
894
|
+
# Save all episodes
|
|
895
|
+
await add_nodes_and_edges_bulk(
|
|
896
|
+
driver=self.driver,
|
|
897
|
+
episodic_nodes=episodes,
|
|
898
|
+
episodic_edges=[],
|
|
899
|
+
entity_nodes=[],
|
|
900
|
+
entity_edges=[],
|
|
901
|
+
embedder=self.embedder,
|
|
902
|
+
)
|
|
687
903
|
|
|
688
|
-
|
|
689
|
-
|
|
690
|
-
resolve_edge_pointers(edges, uuid_map) for edges in extracted_edges_bulk
|
|
691
|
-
]
|
|
904
|
+
# Get previous episode context for each episode
|
|
905
|
+
episode_context = await retrieve_previous_episodes_bulk(self.driver, episodes)
|
|
692
906
|
|
|
693
|
-
|
|
694
|
-
|
|
695
|
-
|
|
696
|
-
|
|
697
|
-
|
|
698
|
-
|
|
699
|
-
|
|
700
|
-
|
|
701
|
-
|
|
907
|
+
# Extract and dedupe nodes and edges
|
|
908
|
+
(
|
|
909
|
+
nodes_by_episode,
|
|
910
|
+
uuid_map,
|
|
911
|
+
extracted_edges_bulk,
|
|
912
|
+
) = await self._extract_and_dedupe_nodes_bulk(
|
|
913
|
+
episode_context,
|
|
914
|
+
edge_type_map or edge_type_map_default,
|
|
915
|
+
edge_types,
|
|
916
|
+
entity_types,
|
|
917
|
+
excluded_entity_types,
|
|
918
|
+
)
|
|
702
919
|
|
|
703
|
-
|
|
704
|
-
|
|
705
|
-
|
|
706
|
-
|
|
920
|
+
# Create Episodic Edges
|
|
921
|
+
episodic_edges: list[EpisodicEdge] = []
|
|
922
|
+
for episode_uuid, nodes in nodes_by_episode.items():
|
|
923
|
+
episodic_edges.extend(build_episodic_edges(nodes, episode_uuid, now))
|
|
707
924
|
|
|
708
|
-
|
|
709
|
-
|
|
710
|
-
|
|
711
|
-
for episode_uuid, mentioned_nodes in nodes_by_episode.items():
|
|
712
|
-
for mentioned_node in mentioned_nodes:
|
|
713
|
-
if node.uuid == mentioned_node.uuid:
|
|
714
|
-
episode_uuids.append(episode_uuid)
|
|
715
|
-
break
|
|
716
|
-
|
|
717
|
-
episode_mentions: list[EpisodicNode] = [
|
|
718
|
-
episodes_by_uuid[episode_uuid] for episode_uuid in episode_uuids
|
|
719
|
-
]
|
|
720
|
-
episode_mentions.sort(key=lambda x: x.valid_at, reverse=True)
|
|
721
|
-
|
|
722
|
-
extract_attributes_params.append((node, episode_mentions))
|
|
723
|
-
|
|
724
|
-
new_hydrated_nodes: list[list[EntityNode]] = await semaphore_gather(
|
|
725
|
-
*[
|
|
726
|
-
extract_attributes_from_nodes(
|
|
727
|
-
self.clients,
|
|
728
|
-
[params[0]],
|
|
729
|
-
params[1][0],
|
|
730
|
-
params[1][0:],
|
|
731
|
-
entity_types,
|
|
732
|
-
)
|
|
733
|
-
for params in extract_attributes_params
|
|
925
|
+
# Re-map edge pointers and dedupe edges
|
|
926
|
+
extracted_edges_bulk_updated: list[list[EntityEdge]] = [
|
|
927
|
+
resolve_edge_pointers(edges, uuid_map) for edges in extracted_edges_bulk
|
|
734
928
|
]
|
|
735
|
-
)
|
|
736
929
|
|
|
737
|
-
|
|
738
|
-
|
|
739
|
-
|
|
740
|
-
|
|
741
|
-
|
|
742
|
-
|
|
743
|
-
|
|
744
|
-
|
|
745
|
-
nodes_uuid_set: set[str] = set()
|
|
746
|
-
for episode, _ in episode_context:
|
|
747
|
-
nodes_by_episode_unique[episode.uuid] = []
|
|
748
|
-
nodes = [nodes_by_uuid[node.uuid] for node in nodes_by_episode[episode.uuid]]
|
|
749
|
-
for node in nodes:
|
|
750
|
-
if node.uuid not in nodes_uuid_set:
|
|
751
|
-
nodes_by_episode_unique[episode.uuid].append(node)
|
|
752
|
-
nodes_uuid_set.add(node.uuid)
|
|
753
|
-
|
|
754
|
-
node_results = await semaphore_gather(
|
|
755
|
-
*[
|
|
756
|
-
resolve_extracted_nodes(
|
|
757
|
-
self.clients,
|
|
758
|
-
nodes_by_episode_unique[episode.uuid],
|
|
759
|
-
episode,
|
|
760
|
-
previous_episodes,
|
|
761
|
-
entity_types,
|
|
762
|
-
)
|
|
763
|
-
for episode, previous_episodes in episode_context
|
|
764
|
-
]
|
|
765
|
-
)
|
|
930
|
+
edges_by_episode = await dedupe_edges_bulk(
|
|
931
|
+
self.clients,
|
|
932
|
+
extracted_edges_bulk_updated,
|
|
933
|
+
episode_context,
|
|
934
|
+
[],
|
|
935
|
+
edge_types or {},
|
|
936
|
+
edge_type_map or edge_type_map_default,
|
|
937
|
+
)
|
|
766
938
|
|
|
767
|
-
|
|
768
|
-
|
|
769
|
-
|
|
770
|
-
|
|
771
|
-
|
|
772
|
-
|
|
773
|
-
|
|
774
|
-
|
|
775
|
-
|
|
776
|
-
|
|
777
|
-
|
|
778
|
-
|
|
779
|
-
|
|
780
|
-
|
|
781
|
-
|
|
782
|
-
for node in nodes:
|
|
783
|
-
updated_node_uuid = uuid_map.get(node.uuid, node.uuid)
|
|
784
|
-
updated_node = nodes_by_uuid[updated_node_uuid]
|
|
785
|
-
updated_nodes.append(updated_node)
|
|
786
|
-
|
|
787
|
-
nodes_by_episode_unique[episode_uuid] = updated_nodes
|
|
788
|
-
|
|
789
|
-
hydrated_nodes_results: list[list[EntityNode]] = await semaphore_gather(
|
|
790
|
-
*[
|
|
791
|
-
extract_attributes_from_nodes(
|
|
792
|
-
self.clients,
|
|
793
|
-
nodes_by_episode_unique[episode.uuid],
|
|
794
|
-
episode,
|
|
795
|
-
previous_episodes,
|
|
796
|
-
entity_types,
|
|
797
|
-
)
|
|
798
|
-
for episode, previous_episodes in episode_context
|
|
799
|
-
]
|
|
800
|
-
)
|
|
939
|
+
# Resolve nodes and edges against the existing graph
|
|
940
|
+
(
|
|
941
|
+
final_hydrated_nodes,
|
|
942
|
+
resolved_edges,
|
|
943
|
+
invalidated_edges,
|
|
944
|
+
final_uuid_map,
|
|
945
|
+
) = await self._resolve_nodes_and_edges_bulk(
|
|
946
|
+
nodes_by_episode,
|
|
947
|
+
edges_by_episode,
|
|
948
|
+
episode_context,
|
|
949
|
+
entity_types,
|
|
950
|
+
edge_types,
|
|
951
|
+
edge_type_map or edge_type_map_default,
|
|
952
|
+
episodes,
|
|
953
|
+
)
|
|
801
954
|
|
|
802
|
-
|
|
803
|
-
|
|
804
|
-
|
|
805
|
-
|
|
806
|
-
|
|
807
|
-
|
|
808
|
-
|
|
809
|
-
|
|
810
|
-
|
|
811
|
-
|
|
812
|
-
|
|
813
|
-
|
|
814
|
-
|
|
815
|
-
edge_results = await semaphore_gather(
|
|
816
|
-
*[
|
|
817
|
-
resolve_extracted_edges(
|
|
818
|
-
self.clients,
|
|
819
|
-
edges_by_episode_unique[episode.uuid],
|
|
820
|
-
episode,
|
|
821
|
-
hydrated_nodes,
|
|
822
|
-
edge_types or {},
|
|
823
|
-
edge_type_map or edge_type_map_default,
|
|
824
|
-
)
|
|
825
|
-
for episode in episodes
|
|
826
|
-
]
|
|
827
|
-
)
|
|
955
|
+
# Resolved pointers for episodic edges
|
|
956
|
+
resolved_episodic_edges = resolve_edge_pointers(episodic_edges, final_uuid_map)
|
|
957
|
+
|
|
958
|
+
# save data to KG
|
|
959
|
+
await add_nodes_and_edges_bulk(
|
|
960
|
+
self.driver,
|
|
961
|
+
episodes,
|
|
962
|
+
resolved_episodic_edges,
|
|
963
|
+
final_hydrated_nodes,
|
|
964
|
+
resolved_edges + invalidated_edges,
|
|
965
|
+
self.embedder,
|
|
966
|
+
)
|
|
828
967
|
|
|
829
|
-
|
|
830
|
-
invalidated_edges: list[EntityEdge] = []
|
|
831
|
-
for result in edge_results:
|
|
832
|
-
resolved_edges.extend(result[0])
|
|
833
|
-
invalidated_edges.extend(result[1])
|
|
834
|
-
|
|
835
|
-
# Resolved pointers for episodic edges
|
|
836
|
-
resolved_episodic_edges = resolve_edge_pointers(episodic_edges, uuid_map)
|
|
837
|
-
|
|
838
|
-
# save data to KG
|
|
839
|
-
await add_nodes_and_edges_bulk(
|
|
840
|
-
self.driver,
|
|
841
|
-
episodes,
|
|
842
|
-
resolved_episodic_edges,
|
|
843
|
-
final_hydrated_nodes,
|
|
844
|
-
resolved_edges + invalidated_edges,
|
|
845
|
-
self.embedder,
|
|
846
|
-
)
|
|
968
|
+
end = time()
|
|
847
969
|
|
|
848
|
-
|
|
849
|
-
|
|
970
|
+
# Add span attributes
|
|
971
|
+
bulk_span.add_attributes(
|
|
972
|
+
{
|
|
973
|
+
'group_id': group_id,
|
|
974
|
+
'node.count': len(final_hydrated_nodes),
|
|
975
|
+
'edge.count': len(resolved_edges + invalidated_edges),
|
|
976
|
+
'duration_ms': (end - start) * 1000,
|
|
977
|
+
}
|
|
978
|
+
)
|
|
850
979
|
|
|
851
|
-
|
|
852
|
-
|
|
853
|
-
|
|
854
|
-
|
|
855
|
-
|
|
856
|
-
|
|
857
|
-
|
|
858
|
-
|
|
980
|
+
logger.info(f'Completed add_episode_bulk in {(end - start) * 1000} ms')
|
|
981
|
+
|
|
982
|
+
return AddBulkEpisodeResults(
|
|
983
|
+
episodes=episodes,
|
|
984
|
+
episodic_edges=resolved_episodic_edges,
|
|
985
|
+
nodes=final_hydrated_nodes,
|
|
986
|
+
edges=resolved_edges + invalidated_edges,
|
|
987
|
+
communities=[],
|
|
988
|
+
community_edges=[],
|
|
989
|
+
)
|
|
859
990
|
|
|
860
|
-
|
|
861
|
-
|
|
991
|
+
except Exception as e:
|
|
992
|
+
bulk_span.set_status('error', str(e))
|
|
993
|
+
bulk_span.record_exception(e)
|
|
994
|
+
raise e
|
|
862
995
|
|
|
863
996
|
async def build_communities(
|
|
864
997
|
self, group_ids: list[str] | None = None
|