graphiti-core 0.17.4__py3-none-any.whl → 0.25.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- graphiti_core/cross_encoder/gemini_reranker_client.py +1 -1
- graphiti_core/cross_encoder/openai_reranker_client.py +1 -1
- graphiti_core/decorators.py +110 -0
- graphiti_core/driver/driver.py +62 -2
- graphiti_core/driver/falkordb_driver.py +215 -23
- graphiti_core/driver/graph_operations/graph_operations.py +191 -0
- graphiti_core/driver/kuzu_driver.py +182 -0
- graphiti_core/driver/neo4j_driver.py +70 -8
- graphiti_core/driver/neptune_driver.py +305 -0
- graphiti_core/driver/search_interface/search_interface.py +89 -0
- graphiti_core/edges.py +264 -132
- graphiti_core/embedder/azure_openai.py +10 -3
- graphiti_core/embedder/client.py +2 -1
- graphiti_core/graph_queries.py +114 -101
- graphiti_core/graphiti.py +635 -260
- graphiti_core/graphiti_types.py +2 -0
- graphiti_core/helpers.py +37 -15
- graphiti_core/llm_client/anthropic_client.py +142 -52
- graphiti_core/llm_client/azure_openai_client.py +57 -19
- graphiti_core/llm_client/client.py +83 -21
- graphiti_core/llm_client/config.py +1 -1
- graphiti_core/llm_client/gemini_client.py +75 -57
- graphiti_core/llm_client/openai_base_client.py +92 -48
- graphiti_core/llm_client/openai_client.py +39 -9
- graphiti_core/llm_client/openai_generic_client.py +91 -56
- graphiti_core/models/edges/edge_db_queries.py +259 -35
- graphiti_core/models/nodes/node_db_queries.py +311 -32
- graphiti_core/nodes.py +388 -164
- graphiti_core/prompts/dedupe_edges.py +42 -31
- graphiti_core/prompts/dedupe_nodes.py +56 -39
- graphiti_core/prompts/eval.py +4 -4
- graphiti_core/prompts/extract_edges.py +24 -15
- graphiti_core/prompts/extract_nodes.py +76 -35
- graphiti_core/prompts/prompt_helpers.py +39 -0
- graphiti_core/prompts/snippets.py +29 -0
- graphiti_core/prompts/summarize_nodes.py +23 -25
- graphiti_core/search/search.py +154 -74
- graphiti_core/search/search_config.py +39 -4
- graphiti_core/search/search_filters.py +110 -31
- graphiti_core/search/search_helpers.py +5 -6
- graphiti_core/search/search_utils.py +1360 -473
- graphiti_core/tracer.py +193 -0
- graphiti_core/utils/bulk_utils.py +216 -90
- graphiti_core/utils/content_chunking.py +702 -0
- graphiti_core/utils/datetime_utils.py +13 -0
- graphiti_core/utils/maintenance/community_operations.py +62 -38
- graphiti_core/utils/maintenance/dedup_helpers.py +262 -0
- graphiti_core/utils/maintenance/edge_operations.py +306 -156
- graphiti_core/utils/maintenance/graph_data_operations.py +44 -74
- graphiti_core/utils/maintenance/node_operations.py +466 -206
- graphiti_core/utils/maintenance/temporal_operations.py +11 -3
- graphiti_core/utils/ontology_utils/entity_types_utils.py +1 -1
- graphiti_core/utils/text_utils.py +53 -0
- {graphiti_core-0.17.4.dist-info → graphiti_core-0.25.3.dist-info}/METADATA +221 -87
- graphiti_core-0.25.3.dist-info/RECORD +87 -0
- {graphiti_core-0.17.4.dist-info → graphiti_core-0.25.3.dist-info}/WHEEL +1 -1
- graphiti_core-0.17.4.dist-info/RECORD +0 -77
- /graphiti_core/{utils/maintenance/utils.py → migrations/__init__.py} +0 -0
- {graphiti_core-0.17.4.dist-info → graphiti_core-0.25.3.dist-info}/licenses/LICENSE +0 -0
graphiti_core/graphiti.py
CHANGED
|
@@ -24,18 +24,34 @@ from typing_extensions import LiteralString
|
|
|
24
24
|
|
|
25
25
|
from graphiti_core.cross_encoder.client import CrossEncoderClient
|
|
26
26
|
from graphiti_core.cross_encoder.openai_reranker_client import OpenAIRerankerClient
|
|
27
|
+
from graphiti_core.decorators import handle_multiple_group_ids
|
|
27
28
|
from graphiti_core.driver.driver import GraphDriver
|
|
28
29
|
from graphiti_core.driver.neo4j_driver import Neo4jDriver
|
|
29
|
-
from graphiti_core.edges import
|
|
30
|
+
from graphiti_core.edges import (
|
|
31
|
+
CommunityEdge,
|
|
32
|
+
Edge,
|
|
33
|
+
EntityEdge,
|
|
34
|
+
EpisodicEdge,
|
|
35
|
+
create_entity_edge_embeddings,
|
|
36
|
+
)
|
|
30
37
|
from graphiti_core.embedder import EmbedderClient, OpenAIEmbedder
|
|
38
|
+
from graphiti_core.errors import NodeNotFoundError
|
|
31
39
|
from graphiti_core.graphiti_types import GraphitiClients
|
|
32
40
|
from graphiti_core.helpers import (
|
|
41
|
+
get_default_group_id,
|
|
33
42
|
semaphore_gather,
|
|
34
43
|
validate_excluded_entity_types,
|
|
35
44
|
validate_group_id,
|
|
36
45
|
)
|
|
37
46
|
from graphiti_core.llm_client import LLMClient, OpenAIClient
|
|
38
|
-
from graphiti_core.nodes import
|
|
47
|
+
from graphiti_core.nodes import (
|
|
48
|
+
CommunityNode,
|
|
49
|
+
EntityNode,
|
|
50
|
+
EpisodeType,
|
|
51
|
+
EpisodicNode,
|
|
52
|
+
Node,
|
|
53
|
+
create_entity_node_embeddings,
|
|
54
|
+
)
|
|
39
55
|
from graphiti_core.search.search import SearchConfig, search
|
|
40
56
|
from graphiti_core.search.search_config import DEFAULT_SEARCH_LIMIT, SearchResults
|
|
41
57
|
from graphiti_core.search.search_config_recipes import (
|
|
@@ -46,11 +62,10 @@ from graphiti_core.search.search_config_recipes import (
|
|
|
46
62
|
from graphiti_core.search.search_filters import SearchFilters
|
|
47
63
|
from graphiti_core.search.search_utils import (
|
|
48
64
|
RELEVANT_SCHEMA_LIMIT,
|
|
49
|
-
get_edge_invalidation_candidates,
|
|
50
65
|
get_mentioned_nodes,
|
|
51
|
-
get_relevant_edges,
|
|
52
66
|
)
|
|
53
67
|
from graphiti_core.telemetry import capture_event
|
|
68
|
+
from graphiti_core.tracer import Tracer, create_tracer
|
|
54
69
|
from graphiti_core.utils.bulk_utils import (
|
|
55
70
|
RawEpisode,
|
|
56
71
|
add_nodes_and_edges_bulk,
|
|
@@ -67,7 +82,6 @@ from graphiti_core.utils.maintenance.community_operations import (
|
|
|
67
82
|
update_community,
|
|
68
83
|
)
|
|
69
84
|
from graphiti_core.utils.maintenance.edge_operations import (
|
|
70
|
-
build_duplicate_of_edges,
|
|
71
85
|
build_episodic_edges,
|
|
72
86
|
extract_edges,
|
|
73
87
|
resolve_extracted_edge,
|
|
@@ -75,7 +89,6 @@ from graphiti_core.utils.maintenance.edge_operations import (
|
|
|
75
89
|
)
|
|
76
90
|
from graphiti_core.utils.maintenance.graph_data_operations import (
|
|
77
91
|
EPISODE_WINDOW_LEN,
|
|
78
|
-
build_indices_and_constraints,
|
|
79
92
|
retrieve_episodes,
|
|
80
93
|
)
|
|
81
94
|
from graphiti_core.utils.maintenance.node_operations import (
|
|
@@ -92,6 +105,23 @@ load_dotenv()
|
|
|
92
105
|
|
|
93
106
|
class AddEpisodeResults(BaseModel):
|
|
94
107
|
episode: EpisodicNode
|
|
108
|
+
episodic_edges: list[EpisodicEdge]
|
|
109
|
+
nodes: list[EntityNode]
|
|
110
|
+
edges: list[EntityEdge]
|
|
111
|
+
communities: list[CommunityNode]
|
|
112
|
+
community_edges: list[CommunityEdge]
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
class AddBulkEpisodeResults(BaseModel):
|
|
116
|
+
episodes: list[EpisodicNode]
|
|
117
|
+
episodic_edges: list[EpisodicEdge]
|
|
118
|
+
nodes: list[EntityNode]
|
|
119
|
+
edges: list[EntityEdge]
|
|
120
|
+
communities: list[CommunityNode]
|
|
121
|
+
community_edges: list[CommunityEdge]
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
class AddTripletResults(BaseModel):
|
|
95
125
|
nodes: list[EntityNode]
|
|
96
126
|
edges: list[EntityEdge]
|
|
97
127
|
|
|
@@ -108,11 +138,13 @@ class Graphiti:
|
|
|
108
138
|
store_raw_episode_content: bool = True,
|
|
109
139
|
graph_driver: GraphDriver | None = None,
|
|
110
140
|
max_coroutines: int | None = None,
|
|
141
|
+
tracer: Tracer | None = None,
|
|
142
|
+
trace_span_prefix: str = 'graphiti',
|
|
111
143
|
):
|
|
112
144
|
"""
|
|
113
145
|
Initialize a Graphiti instance.
|
|
114
146
|
|
|
115
|
-
This constructor sets up a connection to
|
|
147
|
+
This constructor sets up a connection to a graph database and initializes
|
|
116
148
|
the LLM client for natural language processing tasks.
|
|
117
149
|
|
|
118
150
|
Parameters
|
|
@@ -140,6 +172,10 @@ class Graphiti:
|
|
|
140
172
|
max_coroutines : int | None, optional
|
|
141
173
|
The maximum number of concurrent operations allowed. Overrides SEMAPHORE_LIMIT set in the environment.
|
|
142
174
|
If not set, the Graphiti default is used.
|
|
175
|
+
tracer : Tracer | None, optional
|
|
176
|
+
An OpenTelemetry tracer instance for distributed tracing. If not provided, tracing is disabled (no-op).
|
|
177
|
+
trace_span_prefix : str, optional
|
|
178
|
+
Prefix to prepend to all span names. Defaults to 'graphiti'.
|
|
143
179
|
|
|
144
180
|
Returns
|
|
145
181
|
-------
|
|
@@ -147,11 +183,11 @@ class Graphiti:
|
|
|
147
183
|
|
|
148
184
|
Notes
|
|
149
185
|
-----
|
|
150
|
-
This method establishes a connection to
|
|
186
|
+
This method establishes a connection to a graph database (Neo4j by default) using the provided
|
|
151
187
|
credentials. It also sets up the LLM client, either using the provided client
|
|
152
188
|
or by creating a default OpenAIClient.
|
|
153
189
|
|
|
154
|
-
The default database name is
|
|
190
|
+
The default database name is defined during the driver’s construction. If a different database name
|
|
155
191
|
is required, it should be specified in the URI or set separately after
|
|
156
192
|
initialization.
|
|
157
193
|
|
|
@@ -182,11 +218,18 @@ class Graphiti:
|
|
|
182
218
|
else:
|
|
183
219
|
self.cross_encoder = OpenAIRerankerClient()
|
|
184
220
|
|
|
221
|
+
# Initialize tracer
|
|
222
|
+
self.tracer = create_tracer(tracer, trace_span_prefix)
|
|
223
|
+
|
|
224
|
+
# Set tracer on clients
|
|
225
|
+
self.llm_client.set_tracer(self.tracer)
|
|
226
|
+
|
|
185
227
|
self.clients = GraphitiClients(
|
|
186
228
|
driver=self.driver,
|
|
187
229
|
llm_client=self.llm_client,
|
|
188
230
|
embedder=self.embedder,
|
|
189
231
|
cross_encoder=self.cross_encoder,
|
|
232
|
+
tracer=self.tracer,
|
|
190
233
|
)
|
|
191
234
|
|
|
192
235
|
# Capture telemetry event
|
|
@@ -298,25 +341,249 @@ class Graphiti:
|
|
|
298
341
|
-----
|
|
299
342
|
This method should typically be called once during the initial setup of the
|
|
300
343
|
knowledge graph or when updating the database schema. It uses the
|
|
301
|
-
`build_indices_and_constraints`
|
|
302
|
-
`graphiti_core.utils.maintenance.graph_data_operations` module to perform
|
|
344
|
+
driver's `build_indices_and_constraints` method to perform
|
|
303
345
|
the actual database operations.
|
|
304
346
|
|
|
305
347
|
The specific indices and constraints created depend on the implementation
|
|
306
|
-
of the `build_indices_and_constraints`
|
|
307
|
-
documentation for details on the exact database schema modifications.
|
|
348
|
+
of the driver's `build_indices_and_constraints` method. Refer to the specific
|
|
349
|
+
driver documentation for details on the exact database schema modifications.
|
|
308
350
|
|
|
309
351
|
Caution: Running this method on a large existing database may take some time
|
|
310
352
|
and could impact database performance during execution.
|
|
311
353
|
"""
|
|
312
|
-
await
|
|
354
|
+
await self.driver.build_indices_and_constraints(delete_existing)
|
|
355
|
+
|
|
356
|
+
async def _extract_and_resolve_nodes(
|
|
357
|
+
self,
|
|
358
|
+
episode: EpisodicNode,
|
|
359
|
+
previous_episodes: list[EpisodicNode],
|
|
360
|
+
entity_types: dict[str, type[BaseModel]] | None,
|
|
361
|
+
excluded_entity_types: list[str] | None,
|
|
362
|
+
) -> tuple[list[EntityNode], dict[str, str], list[tuple[EntityNode, EntityNode]]]:
|
|
363
|
+
"""Extract nodes from episode and resolve against existing graph."""
|
|
364
|
+
extracted_nodes = await extract_nodes(
|
|
365
|
+
self.clients, episode, previous_episodes, entity_types, excluded_entity_types
|
|
366
|
+
)
|
|
367
|
+
|
|
368
|
+
nodes, uuid_map, duplicates = await resolve_extracted_nodes(
|
|
369
|
+
self.clients,
|
|
370
|
+
extracted_nodes,
|
|
371
|
+
episode,
|
|
372
|
+
previous_episodes,
|
|
373
|
+
entity_types,
|
|
374
|
+
)
|
|
375
|
+
|
|
376
|
+
return nodes, uuid_map, duplicates
|
|
377
|
+
|
|
378
|
+
async def _extract_and_resolve_edges(
|
|
379
|
+
self,
|
|
380
|
+
episode: EpisodicNode,
|
|
381
|
+
extracted_nodes: list[EntityNode],
|
|
382
|
+
previous_episodes: list[EpisodicNode],
|
|
383
|
+
edge_type_map: dict[tuple[str, str], list[str]],
|
|
384
|
+
group_id: str,
|
|
385
|
+
edge_types: dict[str, type[BaseModel]] | None,
|
|
386
|
+
nodes: list[EntityNode],
|
|
387
|
+
uuid_map: dict[str, str],
|
|
388
|
+
custom_extraction_instructions: str | None = None,
|
|
389
|
+
) -> tuple[list[EntityEdge], list[EntityEdge]]:
|
|
390
|
+
"""Extract edges from episode and resolve against existing graph."""
|
|
391
|
+
extracted_edges = await extract_edges(
|
|
392
|
+
self.clients,
|
|
393
|
+
episode,
|
|
394
|
+
extracted_nodes,
|
|
395
|
+
previous_episodes,
|
|
396
|
+
edge_type_map,
|
|
397
|
+
group_id,
|
|
398
|
+
edge_types,
|
|
399
|
+
custom_extraction_instructions,
|
|
400
|
+
)
|
|
401
|
+
|
|
402
|
+
edges = resolve_edge_pointers(extracted_edges, uuid_map)
|
|
403
|
+
|
|
404
|
+
resolved_edges, invalidated_edges = await resolve_extracted_edges(
|
|
405
|
+
self.clients,
|
|
406
|
+
edges,
|
|
407
|
+
episode,
|
|
408
|
+
nodes,
|
|
409
|
+
edge_types or {},
|
|
410
|
+
edge_type_map,
|
|
411
|
+
)
|
|
412
|
+
|
|
413
|
+
return resolved_edges, invalidated_edges
|
|
414
|
+
|
|
415
|
+
async def _process_episode_data(
|
|
416
|
+
self,
|
|
417
|
+
episode: EpisodicNode,
|
|
418
|
+
nodes: list[EntityNode],
|
|
419
|
+
entity_edges: list[EntityEdge],
|
|
420
|
+
now: datetime,
|
|
421
|
+
) -> tuple[list[EpisodicEdge], EpisodicNode]:
|
|
422
|
+
"""Process and save episode data to the graph."""
|
|
423
|
+
episodic_edges = build_episodic_edges(nodes, episode.uuid, now)
|
|
424
|
+
episode.entity_edges = [edge.uuid for edge in entity_edges]
|
|
425
|
+
|
|
426
|
+
if not self.store_raw_episode_content:
|
|
427
|
+
episode.content = ''
|
|
428
|
+
|
|
429
|
+
await add_nodes_and_edges_bulk(
|
|
430
|
+
self.driver,
|
|
431
|
+
[episode],
|
|
432
|
+
episodic_edges,
|
|
433
|
+
nodes,
|
|
434
|
+
entity_edges,
|
|
435
|
+
self.embedder,
|
|
436
|
+
)
|
|
437
|
+
|
|
438
|
+
return episodic_edges, episode
|
|
439
|
+
|
|
440
|
+
async def _extract_and_dedupe_nodes_bulk(
|
|
441
|
+
self,
|
|
442
|
+
episode_context: list[tuple[EpisodicNode, list[EpisodicNode]]],
|
|
443
|
+
edge_type_map: dict[tuple[str, str], list[str]],
|
|
444
|
+
edge_types: dict[str, type[BaseModel]] | None,
|
|
445
|
+
entity_types: dict[str, type[BaseModel]] | None,
|
|
446
|
+
excluded_entity_types: list[str] | None,
|
|
447
|
+
) -> tuple[
|
|
448
|
+
dict[str, list[EntityNode]],
|
|
449
|
+
dict[str, str],
|
|
450
|
+
list[list[EntityEdge]],
|
|
451
|
+
]:
|
|
452
|
+
"""Extract nodes and edges from all episodes and deduplicate."""
|
|
453
|
+
# Extract all nodes and edges for each episode
|
|
454
|
+
extracted_nodes_bulk, extracted_edges_bulk = await extract_nodes_and_edges_bulk(
|
|
455
|
+
self.clients,
|
|
456
|
+
episode_context,
|
|
457
|
+
edge_type_map=edge_type_map,
|
|
458
|
+
edge_types=edge_types,
|
|
459
|
+
entity_types=entity_types,
|
|
460
|
+
excluded_entity_types=excluded_entity_types,
|
|
461
|
+
)
|
|
462
|
+
|
|
463
|
+
# Dedupe extracted nodes in memory
|
|
464
|
+
nodes_by_episode, uuid_map = await dedupe_nodes_bulk(
|
|
465
|
+
self.clients, extracted_nodes_bulk, episode_context, entity_types
|
|
466
|
+
)
|
|
467
|
+
|
|
468
|
+
return nodes_by_episode, uuid_map, extracted_edges_bulk
|
|
469
|
+
|
|
470
|
+
async def _resolve_nodes_and_edges_bulk(
|
|
471
|
+
self,
|
|
472
|
+
nodes_by_episode: dict[str, list[EntityNode]],
|
|
473
|
+
edges_by_episode: dict[str, list[EntityEdge]],
|
|
474
|
+
episode_context: list[tuple[EpisodicNode, list[EpisodicNode]]],
|
|
475
|
+
entity_types: dict[str, type[BaseModel]] | None,
|
|
476
|
+
edge_types: dict[str, type[BaseModel]] | None,
|
|
477
|
+
edge_type_map: dict[tuple[str, str], list[str]],
|
|
478
|
+
episodes: list[EpisodicNode],
|
|
479
|
+
) -> tuple[list[EntityNode], list[EntityEdge], list[EntityEdge], dict[str, str]]:
|
|
480
|
+
"""Resolve nodes and edges against the existing graph."""
|
|
481
|
+
nodes_by_uuid: dict[str, EntityNode] = {
|
|
482
|
+
node.uuid: node for nodes in nodes_by_episode.values() for node in nodes
|
|
483
|
+
}
|
|
484
|
+
|
|
485
|
+
# Get unique nodes per episode
|
|
486
|
+
nodes_by_episode_unique: dict[str, list[EntityNode]] = {}
|
|
487
|
+
nodes_uuid_set: set[str] = set()
|
|
488
|
+
for episode, _ in episode_context:
|
|
489
|
+
nodes_by_episode_unique[episode.uuid] = []
|
|
490
|
+
nodes = [nodes_by_uuid[node.uuid] for node in nodes_by_episode[episode.uuid]]
|
|
491
|
+
for node in nodes:
|
|
492
|
+
if node.uuid not in nodes_uuid_set:
|
|
493
|
+
nodes_by_episode_unique[episode.uuid].append(node)
|
|
494
|
+
nodes_uuid_set.add(node.uuid)
|
|
495
|
+
|
|
496
|
+
# Resolve nodes
|
|
497
|
+
node_results = await semaphore_gather(
|
|
498
|
+
*[
|
|
499
|
+
resolve_extracted_nodes(
|
|
500
|
+
self.clients,
|
|
501
|
+
nodes_by_episode_unique[episode.uuid],
|
|
502
|
+
episode,
|
|
503
|
+
previous_episodes,
|
|
504
|
+
entity_types,
|
|
505
|
+
)
|
|
506
|
+
for episode, previous_episodes in episode_context
|
|
507
|
+
]
|
|
508
|
+
)
|
|
509
|
+
|
|
510
|
+
resolved_nodes: list[EntityNode] = []
|
|
511
|
+
uuid_map: dict[str, str] = {}
|
|
512
|
+
for result in node_results:
|
|
513
|
+
resolved_nodes.extend(result[0])
|
|
514
|
+
uuid_map.update(result[1])
|
|
515
|
+
|
|
516
|
+
# Update nodes_by_uuid with resolved nodes
|
|
517
|
+
for resolved_node in resolved_nodes:
|
|
518
|
+
nodes_by_uuid[resolved_node.uuid] = resolved_node
|
|
519
|
+
|
|
520
|
+
# Update nodes_by_episode_unique with resolved pointers
|
|
521
|
+
for episode_uuid, nodes in nodes_by_episode_unique.items():
|
|
522
|
+
updated_nodes: list[EntityNode] = []
|
|
523
|
+
for node in nodes:
|
|
524
|
+
updated_node_uuid = uuid_map.get(node.uuid, node.uuid)
|
|
525
|
+
updated_node = nodes_by_uuid[updated_node_uuid]
|
|
526
|
+
updated_nodes.append(updated_node)
|
|
527
|
+
nodes_by_episode_unique[episode_uuid] = updated_nodes
|
|
528
|
+
|
|
529
|
+
# Extract attributes for resolved nodes
|
|
530
|
+
hydrated_nodes_results: list[list[EntityNode]] = await semaphore_gather(
|
|
531
|
+
*[
|
|
532
|
+
extract_attributes_from_nodes(
|
|
533
|
+
self.clients,
|
|
534
|
+
nodes_by_episode_unique[episode.uuid],
|
|
535
|
+
episode,
|
|
536
|
+
previous_episodes,
|
|
537
|
+
entity_types,
|
|
538
|
+
)
|
|
539
|
+
for episode, previous_episodes in episode_context
|
|
540
|
+
]
|
|
541
|
+
)
|
|
542
|
+
|
|
543
|
+
final_hydrated_nodes = [node for nodes in hydrated_nodes_results for node in nodes]
|
|
544
|
+
|
|
545
|
+
# Resolve edges with updated pointers
|
|
546
|
+
edges_by_episode_unique: dict[str, list[EntityEdge]] = {}
|
|
547
|
+
edges_uuid_set: set[str] = set()
|
|
548
|
+
for episode_uuid, edges in edges_by_episode.items():
|
|
549
|
+
edges_with_updated_pointers = resolve_edge_pointers(edges, uuid_map)
|
|
550
|
+
edges_by_episode_unique[episode_uuid] = []
|
|
551
|
+
|
|
552
|
+
for edge in edges_with_updated_pointers:
|
|
553
|
+
if edge.uuid not in edges_uuid_set:
|
|
554
|
+
edges_by_episode_unique[episode_uuid].append(edge)
|
|
555
|
+
edges_uuid_set.add(edge.uuid)
|
|
556
|
+
|
|
557
|
+
edge_results = await semaphore_gather(
|
|
558
|
+
*[
|
|
559
|
+
resolve_extracted_edges(
|
|
560
|
+
self.clients,
|
|
561
|
+
edges_by_episode_unique[episode.uuid],
|
|
562
|
+
episode,
|
|
563
|
+
final_hydrated_nodes,
|
|
564
|
+
edge_types or {},
|
|
565
|
+
edge_type_map,
|
|
566
|
+
)
|
|
567
|
+
for episode in episodes
|
|
568
|
+
]
|
|
569
|
+
)
|
|
313
570
|
|
|
571
|
+
resolved_edges: list[EntityEdge] = []
|
|
572
|
+
invalidated_edges: list[EntityEdge] = []
|
|
573
|
+
for result in edge_results:
|
|
574
|
+
resolved_edges.extend(result[0])
|
|
575
|
+
invalidated_edges.extend(result[1])
|
|
576
|
+
|
|
577
|
+
return final_hydrated_nodes, resolved_edges, invalidated_edges, uuid_map
|
|
578
|
+
|
|
579
|
+
@handle_multiple_group_ids
|
|
314
580
|
async def retrieve_episodes(
|
|
315
581
|
self,
|
|
316
582
|
reference_time: datetime,
|
|
317
583
|
last_n: int = EPISODE_WINDOW_LEN,
|
|
318
584
|
group_ids: list[str] | None = None,
|
|
319
585
|
source: EpisodeType | None = None,
|
|
586
|
+
driver: GraphDriver | None = None,
|
|
320
587
|
) -> list[EpisodicNode]:
|
|
321
588
|
"""
|
|
322
589
|
Retrieve the last n episodic nodes from the graph.
|
|
@@ -343,7 +610,10 @@ class Graphiti:
|
|
|
343
610
|
The actual retrieval is performed by the `retrieve_episodes` function
|
|
344
611
|
from the `graphiti_core.utils` module.
|
|
345
612
|
"""
|
|
346
|
-
|
|
613
|
+
if driver is None:
|
|
614
|
+
driver = self.clients.driver
|
|
615
|
+
|
|
616
|
+
return await retrieve_episodes(driver, reference_time, last_n, group_ids, source)
|
|
347
617
|
|
|
348
618
|
async def add_episode(
|
|
349
619
|
self,
|
|
@@ -352,14 +622,15 @@ class Graphiti:
|
|
|
352
622
|
source_description: str,
|
|
353
623
|
reference_time: datetime,
|
|
354
624
|
source: EpisodeType = EpisodeType.message,
|
|
355
|
-
group_id: str =
|
|
625
|
+
group_id: str | None = None,
|
|
356
626
|
uuid: str | None = None,
|
|
357
627
|
update_communities: bool = False,
|
|
358
|
-
entity_types: dict[str, BaseModel] | None = None,
|
|
628
|
+
entity_types: dict[str, type[BaseModel]] | None = None,
|
|
359
629
|
excluded_entity_types: list[str] | None = None,
|
|
360
630
|
previous_episode_uuids: list[str] | None = None,
|
|
361
|
-
edge_types: dict[str, BaseModel] | None = None,
|
|
631
|
+
edge_types: dict[str, type[BaseModel]] | None = None,
|
|
362
632
|
edge_type_map: dict[tuple[str, str], list[str]] | None = None,
|
|
633
|
+
custom_extraction_instructions: str | None = None,
|
|
363
634
|
) -> AddEpisodeResults:
|
|
364
635
|
"""
|
|
365
636
|
Process an episode and update the graph.
|
|
@@ -394,6 +665,9 @@ class Graphiti:
|
|
|
394
665
|
previous_episode_uuids : list[str] | None
|
|
395
666
|
Optional. list of episode uuids to use as the previous episodes. If this is not provided,
|
|
396
667
|
the most recent episodes by created_at date will be used.
|
|
668
|
+
custom_extraction_instructions : str | None
|
|
669
|
+
Optional. Custom extraction instructions string to be included in the extract entities and extract edges prompts.
|
|
670
|
+
This allows for additional instructions or context to guide the extraction process.
|
|
397
671
|
|
|
398
672
|
Returns
|
|
399
673
|
-------
|
|
@@ -416,133 +690,161 @@ class Graphiti:
|
|
|
416
690
|
background_tasks.add_task(graphiti.add_episode, **episode_data.dict())
|
|
417
691
|
return {"message": "Episode processing started"}
|
|
418
692
|
"""
|
|
419
|
-
|
|
420
|
-
|
|
421
|
-
|
|
693
|
+
start = time()
|
|
694
|
+
now = utc_now()
|
|
695
|
+
|
|
696
|
+
validate_entity_types(entity_types)
|
|
697
|
+
validate_excluded_entity_types(excluded_entity_types, entity_types)
|
|
422
698
|
|
|
423
|
-
|
|
424
|
-
|
|
699
|
+
if group_id is None:
|
|
700
|
+
# if group_id is None, use the default group id by the provider
|
|
701
|
+
# and the preset database name will be used
|
|
702
|
+
group_id = get_default_group_id(self.driver.provider)
|
|
703
|
+
else:
|
|
425
704
|
validate_group_id(group_id)
|
|
705
|
+
if group_id != self.driver._database:
|
|
706
|
+
# if group_id is provided, use it as the database name
|
|
707
|
+
self.driver = self.driver.clone(database=group_id)
|
|
708
|
+
self.clients.driver = self.driver
|
|
426
709
|
|
|
427
|
-
|
|
428
|
-
|
|
429
|
-
|
|
430
|
-
|
|
431
|
-
|
|
432
|
-
|
|
710
|
+
with self.tracer.start_span('add_episode') as span:
|
|
711
|
+
try:
|
|
712
|
+
# Retrieve previous episodes for context
|
|
713
|
+
previous_episodes = (
|
|
714
|
+
await self.retrieve_episodes(
|
|
715
|
+
reference_time,
|
|
716
|
+
last_n=RELEVANT_SCHEMA_LIMIT,
|
|
717
|
+
group_ids=[group_id],
|
|
718
|
+
source=source,
|
|
719
|
+
)
|
|
720
|
+
if previous_episode_uuids is None
|
|
721
|
+
else await EpisodicNode.get_by_uuids(self.driver, previous_episode_uuids)
|
|
433
722
|
)
|
|
434
|
-
if previous_episode_uuids is None
|
|
435
|
-
else await EpisodicNode.get_by_uuids(self.driver, previous_episode_uuids)
|
|
436
|
-
)
|
|
437
723
|
|
|
438
|
-
|
|
439
|
-
|
|
440
|
-
|
|
441
|
-
|
|
442
|
-
|
|
443
|
-
|
|
444
|
-
|
|
445
|
-
|
|
446
|
-
|
|
447
|
-
|
|
448
|
-
|
|
449
|
-
|
|
724
|
+
# Get or create episode
|
|
725
|
+
episode = (
|
|
726
|
+
await EpisodicNode.get_by_uuid(self.driver, uuid)
|
|
727
|
+
if uuid is not None
|
|
728
|
+
else EpisodicNode(
|
|
729
|
+
name=name,
|
|
730
|
+
group_id=group_id,
|
|
731
|
+
labels=[],
|
|
732
|
+
source=source,
|
|
733
|
+
content=episode_body,
|
|
734
|
+
source_description=source_description,
|
|
735
|
+
created_at=now,
|
|
736
|
+
valid_at=reference_time,
|
|
737
|
+
)
|
|
450
738
|
)
|
|
451
|
-
)
|
|
452
739
|
|
|
453
|
-
|
|
454
|
-
|
|
455
|
-
|
|
456
|
-
|
|
457
|
-
|
|
458
|
-
|
|
459
|
-
|
|
460
|
-
# Extract entities as nodes
|
|
740
|
+
# Create default edge type map
|
|
741
|
+
edge_type_map_default = (
|
|
742
|
+
{('Entity', 'Entity'): list(edge_types.keys())}
|
|
743
|
+
if edge_types is not None
|
|
744
|
+
else {('Entity', 'Entity'): []}
|
|
745
|
+
)
|
|
461
746
|
|
|
462
|
-
|
|
463
|
-
|
|
464
|
-
|
|
747
|
+
# Extract and resolve nodes
|
|
748
|
+
extracted_nodes = await extract_nodes(
|
|
749
|
+
self.clients,
|
|
750
|
+
episode,
|
|
751
|
+
previous_episodes,
|
|
752
|
+
entity_types,
|
|
753
|
+
excluded_entity_types,
|
|
754
|
+
custom_extraction_instructions,
|
|
755
|
+
)
|
|
465
756
|
|
|
466
|
-
|
|
467
|
-
(nodes, uuid_map, node_duplicates), extracted_edges = await semaphore_gather(
|
|
468
|
-
resolve_extracted_nodes(
|
|
757
|
+
nodes, uuid_map, _ = await resolve_extracted_nodes(
|
|
469
758
|
self.clients,
|
|
470
759
|
extracted_nodes,
|
|
471
760
|
episode,
|
|
472
761
|
previous_episodes,
|
|
473
762
|
entity_types,
|
|
474
|
-
)
|
|
475
|
-
|
|
476
|
-
|
|
763
|
+
)
|
|
764
|
+
|
|
765
|
+
# Extract and resolve edges in parallel with attribute extraction
|
|
766
|
+
resolved_edges, invalidated_edges = await self._extract_and_resolve_edges(
|
|
477
767
|
episode,
|
|
478
768
|
extracted_nodes,
|
|
479
769
|
previous_episodes,
|
|
480
770
|
edge_type_map or edge_type_map_default,
|
|
481
771
|
group_id,
|
|
482
772
|
edge_types,
|
|
483
|
-
),
|
|
484
|
-
max_coroutines=self.max_coroutines,
|
|
485
|
-
)
|
|
486
|
-
|
|
487
|
-
edges = resolve_edge_pointers(extracted_edges, uuid_map)
|
|
488
|
-
|
|
489
|
-
(resolved_edges, invalidated_edges), hydrated_nodes = await semaphore_gather(
|
|
490
|
-
resolve_extracted_edges(
|
|
491
|
-
self.clients,
|
|
492
|
-
edges,
|
|
493
|
-
episode,
|
|
494
773
|
nodes,
|
|
495
|
-
|
|
496
|
-
|
|
497
|
-
)
|
|
498
|
-
extract_attributes_from_nodes(
|
|
499
|
-
self.clients, nodes, episode, previous_episodes, entity_types
|
|
500
|
-
),
|
|
501
|
-
max_coroutines=self.max_coroutines,
|
|
502
|
-
)
|
|
774
|
+
uuid_map,
|
|
775
|
+
custom_extraction_instructions,
|
|
776
|
+
)
|
|
503
777
|
|
|
504
|
-
|
|
778
|
+
# Extract node attributes
|
|
779
|
+
hydrated_nodes = await extract_attributes_from_nodes(
|
|
780
|
+
self.clients, nodes, episode, previous_episodes, entity_types
|
|
781
|
+
)
|
|
505
782
|
|
|
506
|
-
|
|
783
|
+
entity_edges = resolved_edges + invalidated_edges
|
|
507
784
|
|
|
508
|
-
|
|
785
|
+
# Process and save episode data
|
|
786
|
+
episodic_edges, episode = await self._process_episode_data(
|
|
787
|
+
episode, hydrated_nodes, entity_edges, now
|
|
788
|
+
)
|
|
509
789
|
|
|
510
|
-
|
|
790
|
+
# Update communities if requested
|
|
791
|
+
communities = []
|
|
792
|
+
community_edges = []
|
|
793
|
+
if update_communities:
|
|
794
|
+
communities, community_edges = await semaphore_gather(
|
|
795
|
+
*[
|
|
796
|
+
update_community(self.driver, self.llm_client, self.embedder, node)
|
|
797
|
+
for node in nodes
|
|
798
|
+
],
|
|
799
|
+
max_coroutines=self.max_coroutines,
|
|
800
|
+
)
|
|
511
801
|
|
|
512
|
-
|
|
513
|
-
|
|
802
|
+
end = time()
|
|
803
|
+
|
|
804
|
+
# Add span attributes
|
|
805
|
+
span.add_attributes(
|
|
806
|
+
{
|
|
807
|
+
'episode.uuid': episode.uuid,
|
|
808
|
+
'episode.source': source.value,
|
|
809
|
+
'episode.reference_time': reference_time.isoformat(),
|
|
810
|
+
'group_id': group_id,
|
|
811
|
+
'node.count': len(hydrated_nodes),
|
|
812
|
+
'edge.count': len(entity_edges),
|
|
813
|
+
'edge.invalidated_count': len(invalidated_edges),
|
|
814
|
+
'previous_episodes.count': len(previous_episodes),
|
|
815
|
+
'entity_types.count': len(entity_types) if entity_types else 0,
|
|
816
|
+
'edge_types.count': len(edge_types) if edge_types else 0,
|
|
817
|
+
'update_communities': update_communities,
|
|
818
|
+
'communities.count': len(communities) if update_communities else 0,
|
|
819
|
+
'duration_ms': (end - start) * 1000,
|
|
820
|
+
}
|
|
821
|
+
)
|
|
514
822
|
|
|
515
|
-
|
|
516
|
-
self.driver, [episode], episodic_edges, hydrated_nodes, entity_edges, self.embedder
|
|
517
|
-
)
|
|
823
|
+
logger.info(f'Completed add_episode in {(end - start) * 1000} ms')
|
|
518
824
|
|
|
519
|
-
|
|
520
|
-
|
|
521
|
-
|
|
522
|
-
|
|
523
|
-
|
|
524
|
-
|
|
525
|
-
|
|
526
|
-
max_coroutines=self.max_coroutines,
|
|
825
|
+
return AddEpisodeResults(
|
|
826
|
+
episode=episode,
|
|
827
|
+
episodic_edges=episodic_edges,
|
|
828
|
+
nodes=hydrated_nodes,
|
|
829
|
+
edges=entity_edges,
|
|
830
|
+
communities=communities,
|
|
831
|
+
community_edges=community_edges,
|
|
527
832
|
)
|
|
528
|
-
end = time()
|
|
529
|
-
logger.info(f'Completed add_episode in {(end - start) * 1000} ms')
|
|
530
|
-
|
|
531
|
-
return AddEpisodeResults(episode=episode, nodes=nodes, edges=entity_edges)
|
|
532
833
|
|
|
533
|
-
|
|
534
|
-
|
|
834
|
+
except Exception as e:
|
|
835
|
+
span.set_status('error', str(e))
|
|
836
|
+
span.record_exception(e)
|
|
837
|
+
raise e
|
|
535
838
|
|
|
536
|
-
##### EXPERIMENTAL #####
|
|
537
839
|
async def add_episode_bulk(
|
|
538
840
|
self,
|
|
539
841
|
bulk_episodes: list[RawEpisode],
|
|
540
|
-
group_id: str =
|
|
541
|
-
entity_types: dict[str, BaseModel] | None = None,
|
|
842
|
+
group_id: str | None = None,
|
|
843
|
+
entity_types: dict[str, type[BaseModel]] | None = None,
|
|
542
844
|
excluded_entity_types: list[str] | None = None,
|
|
543
|
-
edge_types: dict[str, BaseModel] | None = None,
|
|
845
|
+
edge_types: dict[str, type[BaseModel]] | None = None,
|
|
544
846
|
edge_type_map: dict[tuple[str, str], list[str]] | None = None,
|
|
545
|
-
):
|
|
847
|
+
) -> AddBulkEpisodeResults:
|
|
546
848
|
"""
|
|
547
849
|
Process multiple episodes in bulk and update the graph.
|
|
548
850
|
|
|
@@ -558,7 +860,7 @@ class Graphiti:
|
|
|
558
860
|
|
|
559
861
|
Returns
|
|
560
862
|
-------
|
|
561
|
-
|
|
863
|
+
AddBulkEpisodeResults
|
|
562
864
|
|
|
563
865
|
Notes
|
|
564
866
|
-----
|
|
@@ -579,156 +881,167 @@ class Graphiti:
|
|
|
579
881
|
If these operations are required, use the `add_episode` method instead for each
|
|
580
882
|
individual episode.
|
|
581
883
|
"""
|
|
582
|
-
|
|
583
|
-
|
|
584
|
-
now = utc_now()
|
|
884
|
+
with self.tracer.start_span('add_episode_bulk') as bulk_span:
|
|
885
|
+
bulk_span.add_attributes({'episode.count': len(bulk_episodes)})
|
|
585
886
|
|
|
586
|
-
|
|
587
|
-
|
|
588
|
-
|
|
589
|
-
|
|
590
|
-
|
|
591
|
-
if
|
|
592
|
-
|
|
593
|
-
|
|
594
|
-
|
|
595
|
-
|
|
596
|
-
|
|
597
|
-
|
|
598
|
-
|
|
599
|
-
|
|
600
|
-
|
|
601
|
-
|
|
602
|
-
|
|
603
|
-
|
|
604
|
-
|
|
605
|
-
created_at=now,
|
|
606
|
-
valid_at=episode.reference_time,
|
|
887
|
+
try:
|
|
888
|
+
start = time()
|
|
889
|
+
now = utc_now()
|
|
890
|
+
|
|
891
|
+
# if group_id is None, use the default group id by the provider
|
|
892
|
+
if group_id is None:
|
|
893
|
+
group_id = get_default_group_id(self.driver.provider)
|
|
894
|
+
else:
|
|
895
|
+
validate_group_id(group_id)
|
|
896
|
+
if group_id != self.driver._database:
|
|
897
|
+
# if group_id is provided, use it as the database name
|
|
898
|
+
self.driver = self.driver.clone(database=group_id)
|
|
899
|
+
self.clients.driver = self.driver
|
|
900
|
+
|
|
901
|
+
# Create default edge type map
|
|
902
|
+
edge_type_map_default = (
|
|
903
|
+
{('Entity', 'Entity'): list(edge_types.keys())}
|
|
904
|
+
if edge_types is not None
|
|
905
|
+
else {('Entity', 'Entity'): []}
|
|
607
906
|
)
|
|
608
|
-
for episode in bulk_episodes
|
|
609
|
-
]
|
|
610
907
|
|
|
611
|
-
|
|
612
|
-
|
|
613
|
-
|
|
908
|
+
episodes = [
|
|
909
|
+
await EpisodicNode.get_by_uuid(self.driver, episode.uuid)
|
|
910
|
+
if episode.uuid is not None
|
|
911
|
+
else EpisodicNode(
|
|
912
|
+
name=episode.name,
|
|
913
|
+
labels=[],
|
|
914
|
+
source=episode.source,
|
|
915
|
+
content=episode.content,
|
|
916
|
+
source_description=episode.source_description,
|
|
917
|
+
group_id=group_id,
|
|
918
|
+
created_at=now,
|
|
919
|
+
valid_at=episode.reference_time,
|
|
920
|
+
)
|
|
921
|
+
for episode in bulk_episodes
|
|
922
|
+
]
|
|
614
923
|
|
|
615
|
-
|
|
616
|
-
|
|
617
|
-
|
|
618
|
-
|
|
619
|
-
|
|
620
|
-
|
|
621
|
-
|
|
622
|
-
|
|
623
|
-
|
|
924
|
+
# Save all episodes
|
|
925
|
+
await add_nodes_and_edges_bulk(
|
|
926
|
+
driver=self.driver,
|
|
927
|
+
episodic_nodes=episodes,
|
|
928
|
+
episodic_edges=[],
|
|
929
|
+
entity_nodes=[],
|
|
930
|
+
entity_edges=[],
|
|
931
|
+
embedder=self.embedder,
|
|
932
|
+
)
|
|
624
933
|
|
|
625
|
-
|
|
626
|
-
|
|
934
|
+
# Get previous episode context for each episode
|
|
935
|
+
episode_context = await retrieve_previous_episodes_bulk(self.driver, episodes)
|
|
627
936
|
|
|
628
|
-
|
|
629
|
-
|
|
630
|
-
|
|
631
|
-
|
|
632
|
-
|
|
633
|
-
|
|
634
|
-
|
|
635
|
-
|
|
636
|
-
|
|
637
|
-
|
|
638
|
-
|
|
639
|
-
|
|
640
|
-
self.clients, extracted_nodes_bulk, episode_context, entity_types
|
|
641
|
-
)
|
|
937
|
+
# Extract and dedupe nodes and edges
|
|
938
|
+
(
|
|
939
|
+
nodes_by_episode,
|
|
940
|
+
uuid_map,
|
|
941
|
+
extracted_edges_bulk,
|
|
942
|
+
) = await self._extract_and_dedupe_nodes_bulk(
|
|
943
|
+
episode_context,
|
|
944
|
+
edge_type_map or edge_type_map_default,
|
|
945
|
+
edge_types,
|
|
946
|
+
entity_types,
|
|
947
|
+
excluded_entity_types,
|
|
948
|
+
)
|
|
642
949
|
|
|
643
|
-
|
|
644
|
-
|
|
645
|
-
|
|
950
|
+
# Create Episodic Edges
|
|
951
|
+
episodic_edges: list[EpisodicEdge] = []
|
|
952
|
+
for episode_uuid, nodes in nodes_by_episode.items():
|
|
953
|
+
episodic_edges.extend(build_episodic_edges(nodes, episode_uuid, now))
|
|
646
954
|
|
|
647
|
-
|
|
648
|
-
|
|
649
|
-
|
|
650
|
-
|
|
955
|
+
# Re-map edge pointers and dedupe edges
|
|
956
|
+
extracted_edges_bulk_updated: list[list[EntityEdge]] = [
|
|
957
|
+
resolve_edge_pointers(edges, uuid_map) for edges in extracted_edges_bulk
|
|
958
|
+
]
|
|
651
959
|
|
|
652
|
-
|
|
653
|
-
|
|
654
|
-
|
|
655
|
-
|
|
656
|
-
|
|
657
|
-
|
|
658
|
-
|
|
659
|
-
|
|
660
|
-
)
|
|
960
|
+
edges_by_episode = await dedupe_edges_bulk(
|
|
961
|
+
self.clients,
|
|
962
|
+
extracted_edges_bulk_updated,
|
|
963
|
+
episode_context,
|
|
964
|
+
[],
|
|
965
|
+
edge_types or {},
|
|
966
|
+
edge_type_map or edge_type_map_default,
|
|
967
|
+
)
|
|
661
968
|
|
|
662
|
-
|
|
663
|
-
|
|
664
|
-
|
|
665
|
-
|
|
969
|
+
# Resolve nodes and edges against the existing graph
|
|
970
|
+
(
|
|
971
|
+
final_hydrated_nodes,
|
|
972
|
+
resolved_edges,
|
|
973
|
+
invalidated_edges,
|
|
974
|
+
final_uuid_map,
|
|
975
|
+
) = await self._resolve_nodes_and_edges_bulk(
|
|
976
|
+
nodes_by_episode,
|
|
977
|
+
edges_by_episode,
|
|
978
|
+
episode_context,
|
|
979
|
+
entity_types,
|
|
980
|
+
edge_types,
|
|
981
|
+
edge_type_map or edge_type_map_default,
|
|
982
|
+
episodes,
|
|
983
|
+
)
|
|
666
984
|
|
|
667
|
-
|
|
668
|
-
|
|
669
|
-
|
|
670
|
-
|
|
671
|
-
|
|
672
|
-
|
|
673
|
-
|
|
674
|
-
|
|
675
|
-
|
|
676
|
-
|
|
677
|
-
|
|
678
|
-
|
|
679
|
-
episode_mentions.sort(key=lambda x: x.valid_at, reverse=True)
|
|
680
|
-
|
|
681
|
-
extract_attributes_params.append((node, episode_mentions))
|
|
682
|
-
|
|
683
|
-
new_hydrated_nodes: list[list[EntityNode]] = await semaphore_gather(
|
|
684
|
-
*[
|
|
685
|
-
extract_attributes_from_nodes(
|
|
686
|
-
self.clients,
|
|
687
|
-
[params[0]],
|
|
688
|
-
params[1][0],
|
|
689
|
-
params[1][0:],
|
|
690
|
-
entity_types,
|
|
691
|
-
)
|
|
692
|
-
for params in extract_attributes_params
|
|
693
|
-
]
|
|
694
|
-
)
|
|
985
|
+
# Resolved pointers for episodic edges
|
|
986
|
+
resolved_episodic_edges = resolve_edge_pointers(episodic_edges, final_uuid_map)
|
|
987
|
+
|
|
988
|
+
# save data to KG
|
|
989
|
+
await add_nodes_and_edges_bulk(
|
|
990
|
+
self.driver,
|
|
991
|
+
episodes,
|
|
992
|
+
resolved_episodic_edges,
|
|
993
|
+
final_hydrated_nodes,
|
|
994
|
+
resolved_edges + invalidated_edges,
|
|
995
|
+
self.embedder,
|
|
996
|
+
)
|
|
695
997
|
|
|
696
|
-
|
|
998
|
+
end = time()
|
|
697
999
|
|
|
698
|
-
|
|
699
|
-
|
|
700
|
-
|
|
701
|
-
|
|
1000
|
+
# Add span attributes
|
|
1001
|
+
bulk_span.add_attributes(
|
|
1002
|
+
{
|
|
1003
|
+
'group_id': group_id,
|
|
1004
|
+
'node.count': len(final_hydrated_nodes),
|
|
1005
|
+
'edge.count': len(resolved_edges + invalidated_edges),
|
|
1006
|
+
'duration_ms': (end - start) * 1000,
|
|
1007
|
+
}
|
|
1008
|
+
)
|
|
702
1009
|
|
|
703
|
-
|
|
704
|
-
await add_nodes_and_edges_bulk(
|
|
705
|
-
self.driver,
|
|
706
|
-
episodes,
|
|
707
|
-
episodic_edges,
|
|
708
|
-
hydrated_nodes,
|
|
709
|
-
list(edges_by_uuid.values()),
|
|
710
|
-
self.embedder,
|
|
711
|
-
)
|
|
1010
|
+
logger.info(f'Completed add_episode_bulk in {(end - start) * 1000} ms')
|
|
712
1011
|
|
|
713
|
-
|
|
714
|
-
|
|
1012
|
+
return AddBulkEpisodeResults(
|
|
1013
|
+
episodes=episodes,
|
|
1014
|
+
episodic_edges=resolved_episodic_edges,
|
|
1015
|
+
nodes=final_hydrated_nodes,
|
|
1016
|
+
edges=resolved_edges + invalidated_edges,
|
|
1017
|
+
communities=[],
|
|
1018
|
+
community_edges=[],
|
|
1019
|
+
)
|
|
715
1020
|
|
|
716
|
-
|
|
717
|
-
|
|
1021
|
+
except Exception as e:
|
|
1022
|
+
bulk_span.set_status('error', str(e))
|
|
1023
|
+
bulk_span.record_exception(e)
|
|
1024
|
+
raise e
|
|
718
1025
|
|
|
719
|
-
|
|
1026
|
+
@handle_multiple_group_ids
|
|
1027
|
+
async def build_communities(
|
|
1028
|
+
self, group_ids: list[str] | None = None, driver: GraphDriver | None = None
|
|
1029
|
+
) -> tuple[list[CommunityNode], list[CommunityEdge]]:
|
|
720
1030
|
"""
|
|
721
1031
|
Use a community clustering algorithm to find communities of nodes. Create community nodes summarising
|
|
722
1032
|
the content of these communities.
|
|
723
1033
|
----------
|
|
724
|
-
|
|
1034
|
+
group_ids : list[str] | None
|
|
725
1035
|
Optional. Create communities only for the listed group_ids. If blank the entire graph will be used.
|
|
726
1036
|
"""
|
|
1037
|
+
if driver is None:
|
|
1038
|
+
driver = self.clients.driver
|
|
1039
|
+
|
|
727
1040
|
# Clear existing communities
|
|
728
|
-
await remove_communities(
|
|
1041
|
+
await remove_communities(driver)
|
|
729
1042
|
|
|
730
1043
|
community_nodes, community_edges = await build_communities(
|
|
731
|
-
|
|
1044
|
+
driver, self.llm_client, group_ids
|
|
732
1045
|
)
|
|
733
1046
|
|
|
734
1047
|
await semaphore_gather(
|
|
@@ -737,16 +1050,17 @@ class Graphiti:
|
|
|
737
1050
|
)
|
|
738
1051
|
|
|
739
1052
|
await semaphore_gather(
|
|
740
|
-
*[node.save(
|
|
1053
|
+
*[node.save(driver) for node in community_nodes],
|
|
741
1054
|
max_coroutines=self.max_coroutines,
|
|
742
1055
|
)
|
|
743
1056
|
await semaphore_gather(
|
|
744
|
-
*[edge.save(
|
|
1057
|
+
*[edge.save(driver) for edge in community_edges],
|
|
745
1058
|
max_coroutines=self.max_coroutines,
|
|
746
1059
|
)
|
|
747
1060
|
|
|
748
|
-
return community_nodes
|
|
1061
|
+
return community_nodes, community_edges
|
|
749
1062
|
|
|
1063
|
+
@handle_multiple_group_ids
|
|
750
1064
|
async def search(
|
|
751
1065
|
self,
|
|
752
1066
|
query: str,
|
|
@@ -754,6 +1068,7 @@ class Graphiti:
|
|
|
754
1068
|
group_ids: list[str] | None = None,
|
|
755
1069
|
num_results=DEFAULT_SEARCH_LIMIT,
|
|
756
1070
|
search_filter: SearchFilters | None = None,
|
|
1071
|
+
driver: GraphDriver | None = None,
|
|
757
1072
|
) -> list[EntityEdge]:
|
|
758
1073
|
"""
|
|
759
1074
|
Perform a hybrid search on the knowledge graph.
|
|
@@ -800,7 +1115,8 @@ class Graphiti:
|
|
|
800
1115
|
group_ids,
|
|
801
1116
|
search_config,
|
|
802
1117
|
search_filter if search_filter is not None else SearchFilters(),
|
|
803
|
-
|
|
1118
|
+
driver=driver,
|
|
1119
|
+
center_node_uuid=center_node_uuid,
|
|
804
1120
|
)
|
|
805
1121
|
).edges
|
|
806
1122
|
|
|
@@ -820,6 +1136,7 @@ class Graphiti:
|
|
|
820
1136
|
query, config, group_ids, center_node_uuid, bfs_origin_node_uuids, search_filter
|
|
821
1137
|
)
|
|
822
1138
|
|
|
1139
|
+
@handle_multiple_group_ids
|
|
823
1140
|
async def search_(
|
|
824
1141
|
self,
|
|
825
1142
|
query: str,
|
|
@@ -828,6 +1145,7 @@ class Graphiti:
|
|
|
828
1145
|
center_node_uuid: str | None = None,
|
|
829
1146
|
bfs_origin_node_uuids: list[str] | None = None,
|
|
830
1147
|
search_filter: SearchFilters | None = None,
|
|
1148
|
+
driver: GraphDriver | None = None,
|
|
831
1149
|
) -> SearchResults:
|
|
832
1150
|
"""search_ (replaces _search) is our advanced search method that returns Graph objects (nodes and edges) rather
|
|
833
1151
|
than a list of facts. This endpoint allows the end user to utilize more advanced features such as filters and
|
|
@@ -844,6 +1162,7 @@ class Graphiti:
|
|
|
844
1162
|
search_filter if search_filter is not None else SearchFilters(),
|
|
845
1163
|
center_node_uuid,
|
|
846
1164
|
bfs_origin_node_uuids,
|
|
1165
|
+
driver=driver,
|
|
847
1166
|
)
|
|
848
1167
|
|
|
849
1168
|
async def get_nodes_and_edges_by_episode(self, episode_uuids: list[str]) -> SearchResults:
|
|
@@ -858,9 +1177,11 @@ class Graphiti:
|
|
|
858
1177
|
|
|
859
1178
|
nodes = await get_mentioned_nodes(self.driver, episodes)
|
|
860
1179
|
|
|
861
|
-
return SearchResults(edges=edges, nodes=nodes
|
|
1180
|
+
return SearchResults(edges=edges, nodes=nodes)
|
|
862
1181
|
|
|
863
|
-
async def add_triplet(
|
|
1182
|
+
async def add_triplet(
|
|
1183
|
+
self, source_node: EntityNode, edge: EntityEdge, target_node: EntityNode
|
|
1184
|
+
) -> AddTripletResults:
|
|
864
1185
|
if source_node.name_embedding is None:
|
|
865
1186
|
await source_node.generate_name_embedding(self.embedder)
|
|
866
1187
|
if target_node.name_embedding is None:
|
|
@@ -868,21 +1189,74 @@ class Graphiti:
|
|
|
868
1189
|
if edge.fact_embedding is None:
|
|
869
1190
|
await edge.generate_embedding(self.embedder)
|
|
870
1191
|
|
|
871
|
-
|
|
872
|
-
self.
|
|
873
|
-
|
|
874
|
-
|
|
1192
|
+
try:
|
|
1193
|
+
resolved_source = await EntityNode.get_by_uuid(self.driver, source_node.uuid)
|
|
1194
|
+
except NodeNotFoundError:
|
|
1195
|
+
resolved_source_nodes, _, _ = await resolve_extracted_nodes(
|
|
1196
|
+
self.clients,
|
|
1197
|
+
[source_node],
|
|
1198
|
+
)
|
|
1199
|
+
resolved_source = resolved_source_nodes[0]
|
|
875
1200
|
|
|
876
|
-
|
|
1201
|
+
try:
|
|
1202
|
+
resolved_target = await EntityNode.get_by_uuid(self.driver, target_node.uuid)
|
|
1203
|
+
except NodeNotFoundError:
|
|
1204
|
+
resolved_target_nodes, _, _ = await resolve_extracted_nodes(
|
|
1205
|
+
self.clients,
|
|
1206
|
+
[target_node],
|
|
1207
|
+
)
|
|
1208
|
+
resolved_target = resolved_target_nodes[0]
|
|
1209
|
+
|
|
1210
|
+
nodes = [resolved_source, resolved_target]
|
|
1211
|
+
|
|
1212
|
+
# Merge user-provided properties from original nodes into resolved nodes (excluding uuid)
|
|
1213
|
+
# Update attributes dictionary (merge rather than replace)
|
|
1214
|
+
if source_node.attributes:
|
|
1215
|
+
resolved_source.attributes.update(source_node.attributes)
|
|
1216
|
+
if target_node.attributes:
|
|
1217
|
+
resolved_target.attributes.update(target_node.attributes)
|
|
1218
|
+
|
|
1219
|
+
# Update summary if provided by user (non-empty string)
|
|
1220
|
+
if source_node.summary:
|
|
1221
|
+
resolved_source.summary = source_node.summary
|
|
1222
|
+
if target_node.summary:
|
|
1223
|
+
resolved_target.summary = target_node.summary
|
|
1224
|
+
|
|
1225
|
+
# Update labels (merge with existing)
|
|
1226
|
+
if source_node.labels:
|
|
1227
|
+
resolved_source.labels = list(set(resolved_source.labels) | set(source_node.labels))
|
|
1228
|
+
if target_node.labels:
|
|
1229
|
+
resolved_target.labels = list(set(resolved_target.labels) | set(target_node.labels))
|
|
1230
|
+
|
|
1231
|
+
edge.source_node_uuid = resolved_source.uuid
|
|
1232
|
+
edge.target_node_uuid = resolved_target.uuid
|
|
1233
|
+
|
|
1234
|
+
valid_edges = await EntityEdge.get_between_nodes(
|
|
1235
|
+
self.driver, edge.source_node_uuid, edge.target_node_uuid
|
|
1236
|
+
)
|
|
877
1237
|
|
|
878
|
-
related_edges = (
|
|
1238
|
+
related_edges = (
|
|
1239
|
+
await search(
|
|
1240
|
+
self.clients,
|
|
1241
|
+
edge.fact,
|
|
1242
|
+
group_ids=[edge.group_id],
|
|
1243
|
+
config=EDGE_HYBRID_SEARCH_RRF,
|
|
1244
|
+
search_filter=SearchFilters(edge_uuids=[edge.uuid for edge in valid_edges]),
|
|
1245
|
+
)
|
|
1246
|
+
).edges
|
|
879
1247
|
existing_edges = (
|
|
880
|
-
await
|
|
881
|
-
|
|
1248
|
+
await search(
|
|
1249
|
+
self.clients,
|
|
1250
|
+
edge.fact,
|
|
1251
|
+
group_ids=[edge.group_id],
|
|
1252
|
+
config=EDGE_HYBRID_SEARCH_RRF,
|
|
1253
|
+
search_filter=SearchFilters(),
|
|
1254
|
+
)
|
|
1255
|
+
).edges
|
|
882
1256
|
|
|
883
1257
|
resolved_edge, invalidated_edges, _ = await resolve_extracted_edge(
|
|
884
1258
|
self.llm_client,
|
|
885
|
-
|
|
1259
|
+
edge,
|
|
886
1260
|
related_edges,
|
|
887
1261
|
existing_edges,
|
|
888
1262
|
EpisodicNode(
|
|
@@ -894,11 +1268,17 @@ class Graphiti:
|
|
|
894
1268
|
entity_edges=[],
|
|
895
1269
|
group_id=edge.group_id,
|
|
896
1270
|
),
|
|
1271
|
+
None,
|
|
1272
|
+
None,
|
|
897
1273
|
)
|
|
898
1274
|
|
|
899
|
-
|
|
900
|
-
|
|
901
|
-
)
|
|
1275
|
+
edges: list[EntityEdge] = [resolved_edge] + invalidated_edges
|
|
1276
|
+
|
|
1277
|
+
await create_entity_edge_embeddings(self.embedder, edges)
|
|
1278
|
+
await create_entity_node_embeddings(self.embedder, nodes)
|
|
1279
|
+
|
|
1280
|
+
await add_nodes_and_edges_bulk(self.driver, [], [], nodes, edges, self.embedder)
|
|
1281
|
+
return AddTripletResults(edges=edges, nodes=nodes)
|
|
902
1282
|
|
|
903
1283
|
async def remove_episode(self, episode_uuid: str):
|
|
904
1284
|
# Find the episode to be deleted
|
|
@@ -925,12 +1305,7 @@ class Graphiti:
|
|
|
925
1305
|
if record['episode_count'] == 1:
|
|
926
1306
|
nodes_to_delete.append(node)
|
|
927
1307
|
|
|
928
|
-
await
|
|
929
|
-
|
|
930
|
-
|
|
931
|
-
)
|
|
932
|
-
await semaphore_gather(
|
|
933
|
-
*[edge.delete(self.driver) for edge in edges_to_delete],
|
|
934
|
-
max_coroutines=self.max_coroutines,
|
|
935
|
-
)
|
|
1308
|
+
await Edge.delete_by_uuids(self.driver, [edge.uuid for edge in edges_to_delete])
|
|
1309
|
+
await Node.delete_by_uuids(self.driver, [node.uuid for node in nodes_to_delete])
|
|
1310
|
+
|
|
936
1311
|
await episode.delete(self.driver)
|