graphiti-core 0.17.4__py3-none-any.whl → 0.24.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- graphiti_core/cross_encoder/gemini_reranker_client.py +1 -1
- graphiti_core/cross_encoder/openai_reranker_client.py +1 -1
- graphiti_core/decorators.py +110 -0
- graphiti_core/driver/driver.py +62 -2
- graphiti_core/driver/falkordb_driver.py +215 -23
- graphiti_core/driver/graph_operations/graph_operations.py +191 -0
- graphiti_core/driver/kuzu_driver.py +182 -0
- graphiti_core/driver/neo4j_driver.py +61 -8
- graphiti_core/driver/neptune_driver.py +305 -0
- graphiti_core/driver/search_interface/search_interface.py +89 -0
- graphiti_core/edges.py +264 -132
- graphiti_core/embedder/azure_openai.py +10 -3
- graphiti_core/embedder/client.py +2 -1
- graphiti_core/graph_queries.py +114 -101
- graphiti_core/graphiti.py +582 -255
- graphiti_core/graphiti_types.py +2 -0
- graphiti_core/helpers.py +21 -14
- graphiti_core/llm_client/anthropic_client.py +142 -52
- graphiti_core/llm_client/azure_openai_client.py +57 -19
- graphiti_core/llm_client/client.py +83 -21
- graphiti_core/llm_client/config.py +1 -1
- graphiti_core/llm_client/gemini_client.py +75 -57
- graphiti_core/llm_client/openai_base_client.py +94 -50
- graphiti_core/llm_client/openai_client.py +28 -8
- graphiti_core/llm_client/openai_generic_client.py +91 -56
- graphiti_core/models/edges/edge_db_queries.py +259 -35
- graphiti_core/models/nodes/node_db_queries.py +311 -32
- graphiti_core/nodes.py +388 -164
- graphiti_core/prompts/dedupe_edges.py +42 -31
- graphiti_core/prompts/dedupe_nodes.py +56 -39
- graphiti_core/prompts/eval.py +4 -4
- graphiti_core/prompts/extract_edges.py +23 -14
- graphiti_core/prompts/extract_nodes.py +73 -32
- graphiti_core/prompts/prompt_helpers.py +39 -0
- graphiti_core/prompts/snippets.py +29 -0
- graphiti_core/prompts/summarize_nodes.py +23 -25
- graphiti_core/search/search.py +154 -74
- graphiti_core/search/search_config.py +39 -4
- graphiti_core/search/search_filters.py +109 -31
- graphiti_core/search/search_helpers.py +5 -6
- graphiti_core/search/search_utils.py +1360 -473
- graphiti_core/tracer.py +193 -0
- graphiti_core/utils/bulk_utils.py +216 -90
- graphiti_core/utils/datetime_utils.py +13 -0
- graphiti_core/utils/maintenance/community_operations.py +62 -38
- graphiti_core/utils/maintenance/dedup_helpers.py +262 -0
- graphiti_core/utils/maintenance/edge_operations.py +286 -126
- graphiti_core/utils/maintenance/graph_data_operations.py +44 -74
- graphiti_core/utils/maintenance/node_operations.py +320 -158
- graphiti_core/utils/maintenance/temporal_operations.py +11 -3
- graphiti_core/utils/ontology_utils/entity_types_utils.py +1 -1
- graphiti_core/utils/text_utils.py +53 -0
- {graphiti_core-0.17.4.dist-info → graphiti_core-0.24.3.dist-info}/METADATA +221 -87
- graphiti_core-0.24.3.dist-info/RECORD +86 -0
- {graphiti_core-0.17.4.dist-info → graphiti_core-0.24.3.dist-info}/WHEEL +1 -1
- graphiti_core-0.17.4.dist-info/RECORD +0 -77
- /graphiti_core/{utils/maintenance/utils.py → migrations/__init__.py} +0 -0
- {graphiti_core-0.17.4.dist-info → graphiti_core-0.24.3.dist-info}/licenses/LICENSE +0 -0
graphiti_core/graphiti.py
CHANGED
|
@@ -24,18 +24,33 @@ from typing_extensions import LiteralString
|
|
|
24
24
|
|
|
25
25
|
from graphiti_core.cross_encoder.client import CrossEncoderClient
|
|
26
26
|
from graphiti_core.cross_encoder.openai_reranker_client import OpenAIRerankerClient
|
|
27
|
+
from graphiti_core.decorators import handle_multiple_group_ids
|
|
27
28
|
from graphiti_core.driver.driver import GraphDriver
|
|
28
29
|
from graphiti_core.driver.neo4j_driver import Neo4jDriver
|
|
29
|
-
from graphiti_core.edges import
|
|
30
|
+
from graphiti_core.edges import (
|
|
31
|
+
CommunityEdge,
|
|
32
|
+
Edge,
|
|
33
|
+
EntityEdge,
|
|
34
|
+
EpisodicEdge,
|
|
35
|
+
create_entity_edge_embeddings,
|
|
36
|
+
)
|
|
30
37
|
from graphiti_core.embedder import EmbedderClient, OpenAIEmbedder
|
|
31
38
|
from graphiti_core.graphiti_types import GraphitiClients
|
|
32
39
|
from graphiti_core.helpers import (
|
|
40
|
+
get_default_group_id,
|
|
33
41
|
semaphore_gather,
|
|
34
42
|
validate_excluded_entity_types,
|
|
35
43
|
validate_group_id,
|
|
36
44
|
)
|
|
37
45
|
from graphiti_core.llm_client import LLMClient, OpenAIClient
|
|
38
|
-
from graphiti_core.nodes import
|
|
46
|
+
from graphiti_core.nodes import (
|
|
47
|
+
CommunityNode,
|
|
48
|
+
EntityNode,
|
|
49
|
+
EpisodeType,
|
|
50
|
+
EpisodicNode,
|
|
51
|
+
Node,
|
|
52
|
+
create_entity_node_embeddings,
|
|
53
|
+
)
|
|
39
54
|
from graphiti_core.search.search import SearchConfig, search
|
|
40
55
|
from graphiti_core.search.search_config import DEFAULT_SEARCH_LIMIT, SearchResults
|
|
41
56
|
from graphiti_core.search.search_config_recipes import (
|
|
@@ -46,11 +61,10 @@ from graphiti_core.search.search_config_recipes import (
|
|
|
46
61
|
from graphiti_core.search.search_filters import SearchFilters
|
|
47
62
|
from graphiti_core.search.search_utils import (
|
|
48
63
|
RELEVANT_SCHEMA_LIMIT,
|
|
49
|
-
get_edge_invalidation_candidates,
|
|
50
64
|
get_mentioned_nodes,
|
|
51
|
-
get_relevant_edges,
|
|
52
65
|
)
|
|
53
66
|
from graphiti_core.telemetry import capture_event
|
|
67
|
+
from graphiti_core.tracer import Tracer, create_tracer
|
|
54
68
|
from graphiti_core.utils.bulk_utils import (
|
|
55
69
|
RawEpisode,
|
|
56
70
|
add_nodes_and_edges_bulk,
|
|
@@ -67,7 +81,6 @@ from graphiti_core.utils.maintenance.community_operations import (
|
|
|
67
81
|
update_community,
|
|
68
82
|
)
|
|
69
83
|
from graphiti_core.utils.maintenance.edge_operations import (
|
|
70
|
-
build_duplicate_of_edges,
|
|
71
84
|
build_episodic_edges,
|
|
72
85
|
extract_edges,
|
|
73
86
|
resolve_extracted_edge,
|
|
@@ -75,7 +88,6 @@ from graphiti_core.utils.maintenance.edge_operations import (
|
|
|
75
88
|
)
|
|
76
89
|
from graphiti_core.utils.maintenance.graph_data_operations import (
|
|
77
90
|
EPISODE_WINDOW_LEN,
|
|
78
|
-
build_indices_and_constraints,
|
|
79
91
|
retrieve_episodes,
|
|
80
92
|
)
|
|
81
93
|
from graphiti_core.utils.maintenance.node_operations import (
|
|
@@ -92,6 +104,23 @@ load_dotenv()
|
|
|
92
104
|
|
|
93
105
|
class AddEpisodeResults(BaseModel):
|
|
94
106
|
episode: EpisodicNode
|
|
107
|
+
episodic_edges: list[EpisodicEdge]
|
|
108
|
+
nodes: list[EntityNode]
|
|
109
|
+
edges: list[EntityEdge]
|
|
110
|
+
communities: list[CommunityNode]
|
|
111
|
+
community_edges: list[CommunityEdge]
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
class AddBulkEpisodeResults(BaseModel):
|
|
115
|
+
episodes: list[EpisodicNode]
|
|
116
|
+
episodic_edges: list[EpisodicEdge]
|
|
117
|
+
nodes: list[EntityNode]
|
|
118
|
+
edges: list[EntityEdge]
|
|
119
|
+
communities: list[CommunityNode]
|
|
120
|
+
community_edges: list[CommunityEdge]
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
class AddTripletResults(BaseModel):
|
|
95
124
|
nodes: list[EntityNode]
|
|
96
125
|
edges: list[EntityEdge]
|
|
97
126
|
|
|
@@ -108,11 +137,13 @@ class Graphiti:
|
|
|
108
137
|
store_raw_episode_content: bool = True,
|
|
109
138
|
graph_driver: GraphDriver | None = None,
|
|
110
139
|
max_coroutines: int | None = None,
|
|
140
|
+
tracer: Tracer | None = None,
|
|
141
|
+
trace_span_prefix: str = 'graphiti',
|
|
111
142
|
):
|
|
112
143
|
"""
|
|
113
144
|
Initialize a Graphiti instance.
|
|
114
145
|
|
|
115
|
-
This constructor sets up a connection to
|
|
146
|
+
This constructor sets up a connection to a graph database and initializes
|
|
116
147
|
the LLM client for natural language processing tasks.
|
|
117
148
|
|
|
118
149
|
Parameters
|
|
@@ -140,6 +171,10 @@ class Graphiti:
|
|
|
140
171
|
max_coroutines : int | None, optional
|
|
141
172
|
The maximum number of concurrent operations allowed. Overrides SEMAPHORE_LIMIT set in the environment.
|
|
142
173
|
If not set, the Graphiti default is used.
|
|
174
|
+
tracer : Tracer | None, optional
|
|
175
|
+
An OpenTelemetry tracer instance for distributed tracing. If not provided, tracing is disabled (no-op).
|
|
176
|
+
trace_span_prefix : str, optional
|
|
177
|
+
Prefix to prepend to all span names. Defaults to 'graphiti'.
|
|
143
178
|
|
|
144
179
|
Returns
|
|
145
180
|
-------
|
|
@@ -147,11 +182,11 @@ class Graphiti:
|
|
|
147
182
|
|
|
148
183
|
Notes
|
|
149
184
|
-----
|
|
150
|
-
This method establishes a connection to
|
|
185
|
+
This method establishes a connection to a graph database (Neo4j by default) using the provided
|
|
151
186
|
credentials. It also sets up the LLM client, either using the provided client
|
|
152
187
|
or by creating a default OpenAIClient.
|
|
153
188
|
|
|
154
|
-
The default database name is
|
|
189
|
+
The default database name is defined during the driver’s construction. If a different database name
|
|
155
190
|
is required, it should be specified in the URI or set separately after
|
|
156
191
|
initialization.
|
|
157
192
|
|
|
@@ -182,11 +217,18 @@ class Graphiti:
|
|
|
182
217
|
else:
|
|
183
218
|
self.cross_encoder = OpenAIRerankerClient()
|
|
184
219
|
|
|
220
|
+
# Initialize tracer
|
|
221
|
+
self.tracer = create_tracer(tracer, trace_span_prefix)
|
|
222
|
+
|
|
223
|
+
# Set tracer on clients
|
|
224
|
+
self.llm_client.set_tracer(self.tracer)
|
|
225
|
+
|
|
185
226
|
self.clients = GraphitiClients(
|
|
186
227
|
driver=self.driver,
|
|
187
228
|
llm_client=self.llm_client,
|
|
188
229
|
embedder=self.embedder,
|
|
189
230
|
cross_encoder=self.cross_encoder,
|
|
231
|
+
tracer=self.tracer,
|
|
190
232
|
)
|
|
191
233
|
|
|
192
234
|
# Capture telemetry event
|
|
@@ -298,25 +340,247 @@ class Graphiti:
|
|
|
298
340
|
-----
|
|
299
341
|
This method should typically be called once during the initial setup of the
|
|
300
342
|
knowledge graph or when updating the database schema. It uses the
|
|
301
|
-
`build_indices_and_constraints`
|
|
302
|
-
`graphiti_core.utils.maintenance.graph_data_operations` module to perform
|
|
343
|
+
driver's `build_indices_and_constraints` method to perform
|
|
303
344
|
the actual database operations.
|
|
304
345
|
|
|
305
346
|
The specific indices and constraints created depend on the implementation
|
|
306
|
-
of the `build_indices_and_constraints`
|
|
307
|
-
documentation for details on the exact database schema modifications.
|
|
347
|
+
of the driver's `build_indices_and_constraints` method. Refer to the specific
|
|
348
|
+
driver documentation for details on the exact database schema modifications.
|
|
308
349
|
|
|
309
350
|
Caution: Running this method on a large existing database may take some time
|
|
310
351
|
and could impact database performance during execution.
|
|
311
352
|
"""
|
|
312
|
-
await
|
|
353
|
+
await self.driver.build_indices_and_constraints(delete_existing)
|
|
354
|
+
|
|
355
|
+
async def _extract_and_resolve_nodes(
|
|
356
|
+
self,
|
|
357
|
+
episode: EpisodicNode,
|
|
358
|
+
previous_episodes: list[EpisodicNode],
|
|
359
|
+
entity_types: dict[str, type[BaseModel]] | None,
|
|
360
|
+
excluded_entity_types: list[str] | None,
|
|
361
|
+
) -> tuple[list[EntityNode], dict[str, str], list[tuple[EntityNode, EntityNode]]]:
|
|
362
|
+
"""Extract nodes from episode and resolve against existing graph."""
|
|
363
|
+
extracted_nodes = await extract_nodes(
|
|
364
|
+
self.clients, episode, previous_episodes, entity_types, excluded_entity_types
|
|
365
|
+
)
|
|
366
|
+
|
|
367
|
+
nodes, uuid_map, duplicates = await resolve_extracted_nodes(
|
|
368
|
+
self.clients,
|
|
369
|
+
extracted_nodes,
|
|
370
|
+
episode,
|
|
371
|
+
previous_episodes,
|
|
372
|
+
entity_types,
|
|
373
|
+
)
|
|
374
|
+
|
|
375
|
+
return nodes, uuid_map, duplicates
|
|
376
|
+
|
|
377
|
+
async def _extract_and_resolve_edges(
|
|
378
|
+
self,
|
|
379
|
+
episode: EpisodicNode,
|
|
380
|
+
extracted_nodes: list[EntityNode],
|
|
381
|
+
previous_episodes: list[EpisodicNode],
|
|
382
|
+
edge_type_map: dict[tuple[str, str], list[str]],
|
|
383
|
+
group_id: str,
|
|
384
|
+
edge_types: dict[str, type[BaseModel]] | None,
|
|
385
|
+
nodes: list[EntityNode],
|
|
386
|
+
uuid_map: dict[str, str],
|
|
387
|
+
) -> tuple[list[EntityEdge], list[EntityEdge]]:
|
|
388
|
+
"""Extract edges from episode and resolve against existing graph."""
|
|
389
|
+
extracted_edges = await extract_edges(
|
|
390
|
+
self.clients,
|
|
391
|
+
episode,
|
|
392
|
+
extracted_nodes,
|
|
393
|
+
previous_episodes,
|
|
394
|
+
edge_type_map,
|
|
395
|
+
group_id,
|
|
396
|
+
edge_types,
|
|
397
|
+
)
|
|
313
398
|
|
|
399
|
+
edges = resolve_edge_pointers(extracted_edges, uuid_map)
|
|
400
|
+
|
|
401
|
+
resolved_edges, invalidated_edges = await resolve_extracted_edges(
|
|
402
|
+
self.clients,
|
|
403
|
+
edges,
|
|
404
|
+
episode,
|
|
405
|
+
nodes,
|
|
406
|
+
edge_types or {},
|
|
407
|
+
edge_type_map,
|
|
408
|
+
)
|
|
409
|
+
|
|
410
|
+
return resolved_edges, invalidated_edges
|
|
411
|
+
|
|
412
|
+
async def _process_episode_data(
|
|
413
|
+
self,
|
|
414
|
+
episode: EpisodicNode,
|
|
415
|
+
nodes: list[EntityNode],
|
|
416
|
+
entity_edges: list[EntityEdge],
|
|
417
|
+
now: datetime,
|
|
418
|
+
) -> tuple[list[EpisodicEdge], EpisodicNode]:
|
|
419
|
+
"""Process and save episode data to the graph."""
|
|
420
|
+
episodic_edges = build_episodic_edges(nodes, episode.uuid, now)
|
|
421
|
+
episode.entity_edges = [edge.uuid for edge in entity_edges]
|
|
422
|
+
|
|
423
|
+
if not self.store_raw_episode_content:
|
|
424
|
+
episode.content = ''
|
|
425
|
+
|
|
426
|
+
await add_nodes_and_edges_bulk(
|
|
427
|
+
self.driver,
|
|
428
|
+
[episode],
|
|
429
|
+
episodic_edges,
|
|
430
|
+
nodes,
|
|
431
|
+
entity_edges,
|
|
432
|
+
self.embedder,
|
|
433
|
+
)
|
|
434
|
+
|
|
435
|
+
return episodic_edges, episode
|
|
436
|
+
|
|
437
|
+
async def _extract_and_dedupe_nodes_bulk(
|
|
438
|
+
self,
|
|
439
|
+
episode_context: list[tuple[EpisodicNode, list[EpisodicNode]]],
|
|
440
|
+
edge_type_map: dict[tuple[str, str], list[str]],
|
|
441
|
+
edge_types: dict[str, type[BaseModel]] | None,
|
|
442
|
+
entity_types: dict[str, type[BaseModel]] | None,
|
|
443
|
+
excluded_entity_types: list[str] | None,
|
|
444
|
+
) -> tuple[
|
|
445
|
+
dict[str, list[EntityNode]],
|
|
446
|
+
dict[str, str],
|
|
447
|
+
list[list[EntityEdge]],
|
|
448
|
+
]:
|
|
449
|
+
"""Extract nodes and edges from all episodes and deduplicate."""
|
|
450
|
+
# Extract all nodes and edges for each episode
|
|
451
|
+
extracted_nodes_bulk, extracted_edges_bulk = await extract_nodes_and_edges_bulk(
|
|
452
|
+
self.clients,
|
|
453
|
+
episode_context,
|
|
454
|
+
edge_type_map=edge_type_map,
|
|
455
|
+
edge_types=edge_types,
|
|
456
|
+
entity_types=entity_types,
|
|
457
|
+
excluded_entity_types=excluded_entity_types,
|
|
458
|
+
)
|
|
459
|
+
|
|
460
|
+
# Dedupe extracted nodes in memory
|
|
461
|
+
nodes_by_episode, uuid_map = await dedupe_nodes_bulk(
|
|
462
|
+
self.clients, extracted_nodes_bulk, episode_context, entity_types
|
|
463
|
+
)
|
|
464
|
+
|
|
465
|
+
return nodes_by_episode, uuid_map, extracted_edges_bulk
|
|
466
|
+
|
|
467
|
+
async def _resolve_nodes_and_edges_bulk(
|
|
468
|
+
self,
|
|
469
|
+
nodes_by_episode: dict[str, list[EntityNode]],
|
|
470
|
+
edges_by_episode: dict[str, list[EntityEdge]],
|
|
471
|
+
episode_context: list[tuple[EpisodicNode, list[EpisodicNode]]],
|
|
472
|
+
entity_types: dict[str, type[BaseModel]] | None,
|
|
473
|
+
edge_types: dict[str, type[BaseModel]] | None,
|
|
474
|
+
edge_type_map: dict[tuple[str, str], list[str]],
|
|
475
|
+
episodes: list[EpisodicNode],
|
|
476
|
+
) -> tuple[list[EntityNode], list[EntityEdge], list[EntityEdge], dict[str, str]]:
|
|
477
|
+
"""Resolve nodes and edges against the existing graph."""
|
|
478
|
+
nodes_by_uuid: dict[str, EntityNode] = {
|
|
479
|
+
node.uuid: node for nodes in nodes_by_episode.values() for node in nodes
|
|
480
|
+
}
|
|
481
|
+
|
|
482
|
+
# Get unique nodes per episode
|
|
483
|
+
nodes_by_episode_unique: dict[str, list[EntityNode]] = {}
|
|
484
|
+
nodes_uuid_set: set[str] = set()
|
|
485
|
+
for episode, _ in episode_context:
|
|
486
|
+
nodes_by_episode_unique[episode.uuid] = []
|
|
487
|
+
nodes = [nodes_by_uuid[node.uuid] for node in nodes_by_episode[episode.uuid]]
|
|
488
|
+
for node in nodes:
|
|
489
|
+
if node.uuid not in nodes_uuid_set:
|
|
490
|
+
nodes_by_episode_unique[episode.uuid].append(node)
|
|
491
|
+
nodes_uuid_set.add(node.uuid)
|
|
492
|
+
|
|
493
|
+
# Resolve nodes
|
|
494
|
+
node_results = await semaphore_gather(
|
|
495
|
+
*[
|
|
496
|
+
resolve_extracted_nodes(
|
|
497
|
+
self.clients,
|
|
498
|
+
nodes_by_episode_unique[episode.uuid],
|
|
499
|
+
episode,
|
|
500
|
+
previous_episodes,
|
|
501
|
+
entity_types,
|
|
502
|
+
)
|
|
503
|
+
for episode, previous_episodes in episode_context
|
|
504
|
+
]
|
|
505
|
+
)
|
|
506
|
+
|
|
507
|
+
resolved_nodes: list[EntityNode] = []
|
|
508
|
+
uuid_map: dict[str, str] = {}
|
|
509
|
+
for result in node_results:
|
|
510
|
+
resolved_nodes.extend(result[0])
|
|
511
|
+
uuid_map.update(result[1])
|
|
512
|
+
|
|
513
|
+
# Update nodes_by_uuid with resolved nodes
|
|
514
|
+
for resolved_node in resolved_nodes:
|
|
515
|
+
nodes_by_uuid[resolved_node.uuid] = resolved_node
|
|
516
|
+
|
|
517
|
+
# Update nodes_by_episode_unique with resolved pointers
|
|
518
|
+
for episode_uuid, nodes in nodes_by_episode_unique.items():
|
|
519
|
+
updated_nodes: list[EntityNode] = []
|
|
520
|
+
for node in nodes:
|
|
521
|
+
updated_node_uuid = uuid_map.get(node.uuid, node.uuid)
|
|
522
|
+
updated_node = nodes_by_uuid[updated_node_uuid]
|
|
523
|
+
updated_nodes.append(updated_node)
|
|
524
|
+
nodes_by_episode_unique[episode_uuid] = updated_nodes
|
|
525
|
+
|
|
526
|
+
# Extract attributes for resolved nodes
|
|
527
|
+
hydrated_nodes_results: list[list[EntityNode]] = await semaphore_gather(
|
|
528
|
+
*[
|
|
529
|
+
extract_attributes_from_nodes(
|
|
530
|
+
self.clients,
|
|
531
|
+
nodes_by_episode_unique[episode.uuid],
|
|
532
|
+
episode,
|
|
533
|
+
previous_episodes,
|
|
534
|
+
entity_types,
|
|
535
|
+
)
|
|
536
|
+
for episode, previous_episodes in episode_context
|
|
537
|
+
]
|
|
538
|
+
)
|
|
539
|
+
|
|
540
|
+
final_hydrated_nodes = [node for nodes in hydrated_nodes_results for node in nodes]
|
|
541
|
+
|
|
542
|
+
# Resolve edges with updated pointers
|
|
543
|
+
edges_by_episode_unique: dict[str, list[EntityEdge]] = {}
|
|
544
|
+
edges_uuid_set: set[str] = set()
|
|
545
|
+
for episode_uuid, edges in edges_by_episode.items():
|
|
546
|
+
edges_with_updated_pointers = resolve_edge_pointers(edges, uuid_map)
|
|
547
|
+
edges_by_episode_unique[episode_uuid] = []
|
|
548
|
+
|
|
549
|
+
for edge in edges_with_updated_pointers:
|
|
550
|
+
if edge.uuid not in edges_uuid_set:
|
|
551
|
+
edges_by_episode_unique[episode_uuid].append(edge)
|
|
552
|
+
edges_uuid_set.add(edge.uuid)
|
|
553
|
+
|
|
554
|
+
edge_results = await semaphore_gather(
|
|
555
|
+
*[
|
|
556
|
+
resolve_extracted_edges(
|
|
557
|
+
self.clients,
|
|
558
|
+
edges_by_episode_unique[episode.uuid],
|
|
559
|
+
episode,
|
|
560
|
+
final_hydrated_nodes,
|
|
561
|
+
edge_types or {},
|
|
562
|
+
edge_type_map,
|
|
563
|
+
)
|
|
564
|
+
for episode in episodes
|
|
565
|
+
]
|
|
566
|
+
)
|
|
567
|
+
|
|
568
|
+
resolved_edges: list[EntityEdge] = []
|
|
569
|
+
invalidated_edges: list[EntityEdge] = []
|
|
570
|
+
for result in edge_results:
|
|
571
|
+
resolved_edges.extend(result[0])
|
|
572
|
+
invalidated_edges.extend(result[1])
|
|
573
|
+
|
|
574
|
+
return final_hydrated_nodes, resolved_edges, invalidated_edges, uuid_map
|
|
575
|
+
|
|
576
|
+
@handle_multiple_group_ids
|
|
314
577
|
async def retrieve_episodes(
|
|
315
578
|
self,
|
|
316
579
|
reference_time: datetime,
|
|
317
580
|
last_n: int = EPISODE_WINDOW_LEN,
|
|
318
581
|
group_ids: list[str] | None = None,
|
|
319
582
|
source: EpisodeType | None = None,
|
|
583
|
+
driver: GraphDriver | None = None,
|
|
320
584
|
) -> list[EpisodicNode]:
|
|
321
585
|
"""
|
|
322
586
|
Retrieve the last n episodic nodes from the graph.
|
|
@@ -343,7 +607,10 @@ class Graphiti:
|
|
|
343
607
|
The actual retrieval is performed by the `retrieve_episodes` function
|
|
344
608
|
from the `graphiti_core.utils` module.
|
|
345
609
|
"""
|
|
346
|
-
|
|
610
|
+
if driver is None:
|
|
611
|
+
driver = self.clients.driver
|
|
612
|
+
|
|
613
|
+
return await retrieve_episodes(driver, reference_time, last_n, group_ids, source)
|
|
347
614
|
|
|
348
615
|
async def add_episode(
|
|
349
616
|
self,
|
|
@@ -352,13 +619,13 @@ class Graphiti:
|
|
|
352
619
|
source_description: str,
|
|
353
620
|
reference_time: datetime,
|
|
354
621
|
source: EpisodeType = EpisodeType.message,
|
|
355
|
-
group_id: str =
|
|
622
|
+
group_id: str | None = None,
|
|
356
623
|
uuid: str | None = None,
|
|
357
624
|
update_communities: bool = False,
|
|
358
|
-
entity_types: dict[str, BaseModel] | None = None,
|
|
625
|
+
entity_types: dict[str, type[BaseModel]] | None = None,
|
|
359
626
|
excluded_entity_types: list[str] | None = None,
|
|
360
627
|
previous_episode_uuids: list[str] | None = None,
|
|
361
|
-
edge_types: dict[str, BaseModel] | None = None,
|
|
628
|
+
edge_types: dict[str, type[BaseModel]] | None = None,
|
|
362
629
|
edge_type_map: dict[tuple[str, str], list[str]] | None = None,
|
|
363
630
|
) -> AddEpisodeResults:
|
|
364
631
|
"""
|
|
@@ -416,133 +683,155 @@ class Graphiti:
|
|
|
416
683
|
background_tasks.add_task(graphiti.add_episode, **episode_data.dict())
|
|
417
684
|
return {"message": "Episode processing started"}
|
|
418
685
|
"""
|
|
419
|
-
|
|
420
|
-
|
|
421
|
-
|
|
686
|
+
start = time()
|
|
687
|
+
now = utc_now()
|
|
688
|
+
|
|
689
|
+
validate_entity_types(entity_types)
|
|
690
|
+
validate_excluded_entity_types(excluded_entity_types, entity_types)
|
|
422
691
|
|
|
423
|
-
|
|
424
|
-
|
|
692
|
+
if group_id is None:
|
|
693
|
+
# if group_id is None, use the default group id by the provider
|
|
694
|
+
# and the preset database name will be used
|
|
695
|
+
group_id = get_default_group_id(self.driver.provider)
|
|
696
|
+
else:
|
|
425
697
|
validate_group_id(group_id)
|
|
698
|
+
if group_id != self.driver._database:
|
|
699
|
+
# if group_id is provided, use it as the database name
|
|
700
|
+
self.driver = self.driver.clone(database=group_id)
|
|
701
|
+
self.clients.driver = self.driver
|
|
426
702
|
|
|
427
|
-
|
|
428
|
-
|
|
429
|
-
|
|
430
|
-
|
|
431
|
-
|
|
432
|
-
|
|
703
|
+
with self.tracer.start_span('add_episode') as span:
|
|
704
|
+
try:
|
|
705
|
+
# Retrieve previous episodes for context
|
|
706
|
+
previous_episodes = (
|
|
707
|
+
await self.retrieve_episodes(
|
|
708
|
+
reference_time,
|
|
709
|
+
last_n=RELEVANT_SCHEMA_LIMIT,
|
|
710
|
+
group_ids=[group_id],
|
|
711
|
+
source=source,
|
|
712
|
+
)
|
|
713
|
+
if previous_episode_uuids is None
|
|
714
|
+
else await EpisodicNode.get_by_uuids(self.driver, previous_episode_uuids)
|
|
433
715
|
)
|
|
434
|
-
if previous_episode_uuids is None
|
|
435
|
-
else await EpisodicNode.get_by_uuids(self.driver, previous_episode_uuids)
|
|
436
|
-
)
|
|
437
716
|
|
|
438
|
-
|
|
439
|
-
|
|
440
|
-
|
|
441
|
-
|
|
442
|
-
|
|
443
|
-
|
|
444
|
-
|
|
445
|
-
|
|
446
|
-
|
|
447
|
-
|
|
448
|
-
|
|
449
|
-
|
|
717
|
+
# Get or create episode
|
|
718
|
+
episode = (
|
|
719
|
+
await EpisodicNode.get_by_uuid(self.driver, uuid)
|
|
720
|
+
if uuid is not None
|
|
721
|
+
else EpisodicNode(
|
|
722
|
+
name=name,
|
|
723
|
+
group_id=group_id,
|
|
724
|
+
labels=[],
|
|
725
|
+
source=source,
|
|
726
|
+
content=episode_body,
|
|
727
|
+
source_description=source_description,
|
|
728
|
+
created_at=now,
|
|
729
|
+
valid_at=reference_time,
|
|
730
|
+
)
|
|
450
731
|
)
|
|
451
|
-
)
|
|
452
|
-
|
|
453
|
-
# Create default edge type map
|
|
454
|
-
edge_type_map_default = (
|
|
455
|
-
{('Entity', 'Entity'): list(edge_types.keys())}
|
|
456
|
-
if edge_types is not None
|
|
457
|
-
else {('Entity', 'Entity'): []}
|
|
458
|
-
)
|
|
459
732
|
|
|
460
|
-
|
|
733
|
+
# Create default edge type map
|
|
734
|
+
edge_type_map_default = (
|
|
735
|
+
{('Entity', 'Entity'): list(edge_types.keys())}
|
|
736
|
+
if edge_types is not None
|
|
737
|
+
else {('Entity', 'Entity'): []}
|
|
738
|
+
)
|
|
461
739
|
|
|
462
|
-
|
|
463
|
-
|
|
464
|
-
|
|
740
|
+
# Extract and resolve nodes
|
|
741
|
+
extracted_nodes = await extract_nodes(
|
|
742
|
+
self.clients, episode, previous_episodes, entity_types, excluded_entity_types
|
|
743
|
+
)
|
|
465
744
|
|
|
466
|
-
|
|
467
|
-
(nodes, uuid_map, node_duplicates), extracted_edges = await semaphore_gather(
|
|
468
|
-
resolve_extracted_nodes(
|
|
745
|
+
nodes, uuid_map, _ = await resolve_extracted_nodes(
|
|
469
746
|
self.clients,
|
|
470
747
|
extracted_nodes,
|
|
471
748
|
episode,
|
|
472
749
|
previous_episodes,
|
|
473
750
|
entity_types,
|
|
474
|
-
)
|
|
475
|
-
|
|
476
|
-
|
|
751
|
+
)
|
|
752
|
+
|
|
753
|
+
# Extract and resolve edges in parallel with attribute extraction
|
|
754
|
+
resolved_edges, invalidated_edges = await self._extract_and_resolve_edges(
|
|
477
755
|
episode,
|
|
478
756
|
extracted_nodes,
|
|
479
757
|
previous_episodes,
|
|
480
758
|
edge_type_map or edge_type_map_default,
|
|
481
759
|
group_id,
|
|
482
760
|
edge_types,
|
|
483
|
-
),
|
|
484
|
-
max_coroutines=self.max_coroutines,
|
|
485
|
-
)
|
|
486
|
-
|
|
487
|
-
edges = resolve_edge_pointers(extracted_edges, uuid_map)
|
|
488
|
-
|
|
489
|
-
(resolved_edges, invalidated_edges), hydrated_nodes = await semaphore_gather(
|
|
490
|
-
resolve_extracted_edges(
|
|
491
|
-
self.clients,
|
|
492
|
-
edges,
|
|
493
|
-
episode,
|
|
494
761
|
nodes,
|
|
495
|
-
|
|
496
|
-
|
|
497
|
-
),
|
|
498
|
-
extract_attributes_from_nodes(
|
|
499
|
-
self.clients, nodes, episode, previous_episodes, entity_types
|
|
500
|
-
),
|
|
501
|
-
max_coroutines=self.max_coroutines,
|
|
502
|
-
)
|
|
762
|
+
uuid_map,
|
|
763
|
+
)
|
|
503
764
|
|
|
504
|
-
|
|
765
|
+
# Extract node attributes
|
|
766
|
+
hydrated_nodes = await extract_attributes_from_nodes(
|
|
767
|
+
self.clients, nodes, episode, previous_episodes, entity_types
|
|
768
|
+
)
|
|
505
769
|
|
|
506
|
-
|
|
770
|
+
entity_edges = resolved_edges + invalidated_edges
|
|
507
771
|
|
|
508
|
-
|
|
772
|
+
# Process and save episode data
|
|
773
|
+
episodic_edges, episode = await self._process_episode_data(
|
|
774
|
+
episode, hydrated_nodes, entity_edges, now
|
|
775
|
+
)
|
|
509
776
|
|
|
510
|
-
|
|
777
|
+
# Update communities if requested
|
|
778
|
+
communities = []
|
|
779
|
+
community_edges = []
|
|
780
|
+
if update_communities:
|
|
781
|
+
communities, community_edges = await semaphore_gather(
|
|
782
|
+
*[
|
|
783
|
+
update_community(self.driver, self.llm_client, self.embedder, node)
|
|
784
|
+
for node in nodes
|
|
785
|
+
],
|
|
786
|
+
max_coroutines=self.max_coroutines,
|
|
787
|
+
)
|
|
511
788
|
|
|
512
|
-
|
|
513
|
-
|
|
789
|
+
end = time()
|
|
790
|
+
|
|
791
|
+
# Add span attributes
|
|
792
|
+
span.add_attributes(
|
|
793
|
+
{
|
|
794
|
+
'episode.uuid': episode.uuid,
|
|
795
|
+
'episode.source': source.value,
|
|
796
|
+
'episode.reference_time': reference_time.isoformat(),
|
|
797
|
+
'group_id': group_id,
|
|
798
|
+
'node.count': len(hydrated_nodes),
|
|
799
|
+
'edge.count': len(entity_edges),
|
|
800
|
+
'edge.invalidated_count': len(invalidated_edges),
|
|
801
|
+
'previous_episodes.count': len(previous_episodes),
|
|
802
|
+
'entity_types.count': len(entity_types) if entity_types else 0,
|
|
803
|
+
'edge_types.count': len(edge_types) if edge_types else 0,
|
|
804
|
+
'update_communities': update_communities,
|
|
805
|
+
'communities.count': len(communities) if update_communities else 0,
|
|
806
|
+
'duration_ms': (end - start) * 1000,
|
|
807
|
+
}
|
|
808
|
+
)
|
|
514
809
|
|
|
515
|
-
|
|
516
|
-
self.driver, [episode], episodic_edges, hydrated_nodes, entity_edges, self.embedder
|
|
517
|
-
)
|
|
810
|
+
logger.info(f'Completed add_episode in {(end - start) * 1000} ms')
|
|
518
811
|
|
|
519
|
-
|
|
520
|
-
|
|
521
|
-
|
|
522
|
-
|
|
523
|
-
|
|
524
|
-
|
|
525
|
-
|
|
526
|
-
max_coroutines=self.max_coroutines,
|
|
812
|
+
return AddEpisodeResults(
|
|
813
|
+
episode=episode,
|
|
814
|
+
episodic_edges=episodic_edges,
|
|
815
|
+
nodes=hydrated_nodes,
|
|
816
|
+
edges=entity_edges,
|
|
817
|
+
communities=communities,
|
|
818
|
+
community_edges=community_edges,
|
|
527
819
|
)
|
|
528
|
-
end = time()
|
|
529
|
-
logger.info(f'Completed add_episode in {(end - start) * 1000} ms')
|
|
530
820
|
|
|
531
|
-
|
|
821
|
+
except Exception as e:
|
|
822
|
+
span.set_status('error', str(e))
|
|
823
|
+
span.record_exception(e)
|
|
824
|
+
raise e
|
|
532
825
|
|
|
533
|
-
except Exception as e:
|
|
534
|
-
raise e
|
|
535
|
-
|
|
536
|
-
##### EXPERIMENTAL #####
|
|
537
826
|
async def add_episode_bulk(
|
|
538
827
|
self,
|
|
539
828
|
bulk_episodes: list[RawEpisode],
|
|
540
|
-
group_id: str =
|
|
541
|
-
entity_types: dict[str, BaseModel] | None = None,
|
|
829
|
+
group_id: str | None = None,
|
|
830
|
+
entity_types: dict[str, type[BaseModel]] | None = None,
|
|
542
831
|
excluded_entity_types: list[str] | None = None,
|
|
543
|
-
edge_types: dict[str, BaseModel] | None = None,
|
|
832
|
+
edge_types: dict[str, type[BaseModel]] | None = None,
|
|
544
833
|
edge_type_map: dict[tuple[str, str], list[str]] | None = None,
|
|
545
|
-
):
|
|
834
|
+
) -> AddBulkEpisodeResults:
|
|
546
835
|
"""
|
|
547
836
|
Process multiple episodes in bulk and update the graph.
|
|
548
837
|
|
|
@@ -558,7 +847,7 @@ class Graphiti:
|
|
|
558
847
|
|
|
559
848
|
Returns
|
|
560
849
|
-------
|
|
561
|
-
|
|
850
|
+
AddBulkEpisodeResults
|
|
562
851
|
|
|
563
852
|
Notes
|
|
564
853
|
-----
|
|
@@ -579,156 +868,167 @@ class Graphiti:
|
|
|
579
868
|
If these operations are required, use the `add_episode` method instead for each
|
|
580
869
|
individual episode.
|
|
581
870
|
"""
|
|
582
|
-
|
|
583
|
-
|
|
584
|
-
now = utc_now()
|
|
585
|
-
|
|
586
|
-
validate_group_id(group_id)
|
|
587
|
-
|
|
588
|
-
# Create default edge type map
|
|
589
|
-
edge_type_map_default = (
|
|
590
|
-
{('Entity', 'Entity'): list(edge_types.keys())}
|
|
591
|
-
if edge_types is not None
|
|
592
|
-
else {('Entity', 'Entity'): []}
|
|
593
|
-
)
|
|
871
|
+
with self.tracer.start_span('add_episode_bulk') as bulk_span:
|
|
872
|
+
bulk_span.add_attributes({'episode.count': len(bulk_episodes)})
|
|
594
873
|
|
|
595
|
-
|
|
596
|
-
|
|
597
|
-
|
|
598
|
-
|
|
599
|
-
|
|
600
|
-
|
|
601
|
-
|
|
602
|
-
|
|
603
|
-
|
|
604
|
-
group_id
|
|
605
|
-
|
|
606
|
-
|
|
874
|
+
try:
|
|
875
|
+
start = time()
|
|
876
|
+
now = utc_now()
|
|
877
|
+
|
|
878
|
+
# if group_id is None, use the default group id by the provider
|
|
879
|
+
if group_id is None:
|
|
880
|
+
group_id = get_default_group_id(self.driver.provider)
|
|
881
|
+
else:
|
|
882
|
+
validate_group_id(group_id)
|
|
883
|
+
if group_id != self.driver._database:
|
|
884
|
+
# if group_id is provided, use it as the database name
|
|
885
|
+
self.driver = self.driver.clone(database=group_id)
|
|
886
|
+
self.clients.driver = self.driver
|
|
887
|
+
|
|
888
|
+
# Create default edge type map
|
|
889
|
+
edge_type_map_default = (
|
|
890
|
+
{('Entity', 'Entity'): list(edge_types.keys())}
|
|
891
|
+
if edge_types is not None
|
|
892
|
+
else {('Entity', 'Entity'): []}
|
|
607
893
|
)
|
|
608
|
-
for episode in bulk_episodes
|
|
609
|
-
]
|
|
610
894
|
|
|
611
|
-
|
|
612
|
-
|
|
613
|
-
|
|
614
|
-
|
|
615
|
-
|
|
616
|
-
|
|
617
|
-
|
|
618
|
-
|
|
619
|
-
|
|
620
|
-
|
|
621
|
-
|
|
622
|
-
|
|
623
|
-
|
|
895
|
+
episodes = [
|
|
896
|
+
await EpisodicNode.get_by_uuid(self.driver, episode.uuid)
|
|
897
|
+
if episode.uuid is not None
|
|
898
|
+
else EpisodicNode(
|
|
899
|
+
name=episode.name,
|
|
900
|
+
labels=[],
|
|
901
|
+
source=episode.source,
|
|
902
|
+
content=episode.content,
|
|
903
|
+
source_description=episode.source_description,
|
|
904
|
+
group_id=group_id,
|
|
905
|
+
created_at=now,
|
|
906
|
+
valid_at=episode.reference_time,
|
|
907
|
+
)
|
|
908
|
+
for episode in bulk_episodes
|
|
909
|
+
]
|
|
624
910
|
|
|
625
|
-
|
|
626
|
-
|
|
911
|
+
# Save all episodes
|
|
912
|
+
await add_nodes_and_edges_bulk(
|
|
913
|
+
driver=self.driver,
|
|
914
|
+
episodic_nodes=episodes,
|
|
915
|
+
episodic_edges=[],
|
|
916
|
+
entity_nodes=[],
|
|
917
|
+
entity_edges=[],
|
|
918
|
+
embedder=self.embedder,
|
|
919
|
+
)
|
|
627
920
|
|
|
628
|
-
|
|
629
|
-
|
|
630
|
-
self.clients,
|
|
631
|
-
episode_context,
|
|
632
|
-
edge_type_map=edge_type_map or edge_type_map_default,
|
|
633
|
-
edge_types=edge_types,
|
|
634
|
-
entity_types=entity_types,
|
|
635
|
-
excluded_entity_types=excluded_entity_types,
|
|
636
|
-
)
|
|
921
|
+
# Get previous episode context for each episode
|
|
922
|
+
episode_context = await retrieve_previous_episodes_bulk(self.driver, episodes)
|
|
637
923
|
|
|
638
|
-
|
|
639
|
-
|
|
640
|
-
|
|
641
|
-
|
|
924
|
+
# Extract and dedupe nodes and edges
|
|
925
|
+
(
|
|
926
|
+
nodes_by_episode,
|
|
927
|
+
uuid_map,
|
|
928
|
+
extracted_edges_bulk,
|
|
929
|
+
) = await self._extract_and_dedupe_nodes_bulk(
|
|
930
|
+
episode_context,
|
|
931
|
+
edge_type_map or edge_type_map_default,
|
|
932
|
+
edge_types,
|
|
933
|
+
entity_types,
|
|
934
|
+
excluded_entity_types,
|
|
935
|
+
)
|
|
642
936
|
|
|
643
|
-
|
|
644
|
-
|
|
645
|
-
|
|
937
|
+
# Create Episodic Edges
|
|
938
|
+
episodic_edges: list[EpisodicEdge] = []
|
|
939
|
+
for episode_uuid, nodes in nodes_by_episode.items():
|
|
940
|
+
episodic_edges.extend(build_episodic_edges(nodes, episode_uuid, now))
|
|
646
941
|
|
|
647
|
-
|
|
648
|
-
|
|
649
|
-
|
|
650
|
-
|
|
942
|
+
# Re-map edge pointers and dedupe edges
|
|
943
|
+
extracted_edges_bulk_updated: list[list[EntityEdge]] = [
|
|
944
|
+
resolve_edge_pointers(edges, uuid_map) for edges in extracted_edges_bulk
|
|
945
|
+
]
|
|
651
946
|
|
|
652
|
-
|
|
653
|
-
|
|
654
|
-
|
|
655
|
-
|
|
656
|
-
|
|
657
|
-
|
|
658
|
-
|
|
659
|
-
|
|
660
|
-
)
|
|
947
|
+
edges_by_episode = await dedupe_edges_bulk(
|
|
948
|
+
self.clients,
|
|
949
|
+
extracted_edges_bulk_updated,
|
|
950
|
+
episode_context,
|
|
951
|
+
[],
|
|
952
|
+
edge_types or {},
|
|
953
|
+
edge_type_map or edge_type_map_default,
|
|
954
|
+
)
|
|
661
955
|
|
|
662
|
-
|
|
663
|
-
|
|
664
|
-
|
|
665
|
-
|
|
956
|
+
# Resolve nodes and edges against the existing graph
|
|
957
|
+
(
|
|
958
|
+
final_hydrated_nodes,
|
|
959
|
+
resolved_edges,
|
|
960
|
+
invalidated_edges,
|
|
961
|
+
final_uuid_map,
|
|
962
|
+
) = await self._resolve_nodes_and_edges_bulk(
|
|
963
|
+
nodes_by_episode,
|
|
964
|
+
edges_by_episode,
|
|
965
|
+
episode_context,
|
|
966
|
+
entity_types,
|
|
967
|
+
edge_types,
|
|
968
|
+
edge_type_map or edge_type_map_default,
|
|
969
|
+
episodes,
|
|
970
|
+
)
|
|
666
971
|
|
|
667
|
-
|
|
668
|
-
|
|
669
|
-
|
|
670
|
-
|
|
671
|
-
|
|
672
|
-
|
|
673
|
-
|
|
674
|
-
|
|
675
|
-
|
|
676
|
-
|
|
677
|
-
|
|
678
|
-
|
|
679
|
-
episode_mentions.sort(key=lambda x: x.valid_at, reverse=True)
|
|
680
|
-
|
|
681
|
-
extract_attributes_params.append((node, episode_mentions))
|
|
682
|
-
|
|
683
|
-
new_hydrated_nodes: list[list[EntityNode]] = await semaphore_gather(
|
|
684
|
-
*[
|
|
685
|
-
extract_attributes_from_nodes(
|
|
686
|
-
self.clients,
|
|
687
|
-
[params[0]],
|
|
688
|
-
params[1][0],
|
|
689
|
-
params[1][0:],
|
|
690
|
-
entity_types,
|
|
691
|
-
)
|
|
692
|
-
for params in extract_attributes_params
|
|
693
|
-
]
|
|
694
|
-
)
|
|
972
|
+
# Resolved pointers for episodic edges
|
|
973
|
+
resolved_episodic_edges = resolve_edge_pointers(episodic_edges, final_uuid_map)
|
|
974
|
+
|
|
975
|
+
# save data to KG
|
|
976
|
+
await add_nodes_and_edges_bulk(
|
|
977
|
+
self.driver,
|
|
978
|
+
episodes,
|
|
979
|
+
resolved_episodic_edges,
|
|
980
|
+
final_hydrated_nodes,
|
|
981
|
+
resolved_edges + invalidated_edges,
|
|
982
|
+
self.embedder,
|
|
983
|
+
)
|
|
695
984
|
|
|
696
|
-
|
|
985
|
+
end = time()
|
|
697
986
|
|
|
698
|
-
|
|
699
|
-
|
|
700
|
-
|
|
701
|
-
|
|
987
|
+
# Add span attributes
|
|
988
|
+
bulk_span.add_attributes(
|
|
989
|
+
{
|
|
990
|
+
'group_id': group_id,
|
|
991
|
+
'node.count': len(final_hydrated_nodes),
|
|
992
|
+
'edge.count': len(resolved_edges + invalidated_edges),
|
|
993
|
+
'duration_ms': (end - start) * 1000,
|
|
994
|
+
}
|
|
995
|
+
)
|
|
702
996
|
|
|
703
|
-
|
|
704
|
-
await add_nodes_and_edges_bulk(
|
|
705
|
-
self.driver,
|
|
706
|
-
episodes,
|
|
707
|
-
episodic_edges,
|
|
708
|
-
hydrated_nodes,
|
|
709
|
-
list(edges_by_uuid.values()),
|
|
710
|
-
self.embedder,
|
|
711
|
-
)
|
|
997
|
+
logger.info(f'Completed add_episode_bulk in {(end - start) * 1000} ms')
|
|
712
998
|
|
|
713
|
-
|
|
714
|
-
|
|
999
|
+
return AddBulkEpisodeResults(
|
|
1000
|
+
episodes=episodes,
|
|
1001
|
+
episodic_edges=resolved_episodic_edges,
|
|
1002
|
+
nodes=final_hydrated_nodes,
|
|
1003
|
+
edges=resolved_edges + invalidated_edges,
|
|
1004
|
+
communities=[],
|
|
1005
|
+
community_edges=[],
|
|
1006
|
+
)
|
|
715
1007
|
|
|
716
|
-
|
|
717
|
-
|
|
1008
|
+
except Exception as e:
|
|
1009
|
+
bulk_span.set_status('error', str(e))
|
|
1010
|
+
bulk_span.record_exception(e)
|
|
1011
|
+
raise e
|
|
718
1012
|
|
|
719
|
-
|
|
1013
|
+
@handle_multiple_group_ids
|
|
1014
|
+
async def build_communities(
|
|
1015
|
+
self, group_ids: list[str] | None = None, driver: GraphDriver | None = None
|
|
1016
|
+
) -> tuple[list[CommunityNode], list[CommunityEdge]]:
|
|
720
1017
|
"""
|
|
721
1018
|
Use a community clustering algorithm to find communities of nodes. Create community nodes summarising
|
|
722
1019
|
the content of these communities.
|
|
723
1020
|
----------
|
|
724
|
-
|
|
1021
|
+
group_ids : list[str] | None
|
|
725
1022
|
Optional. Create communities only for the listed group_ids. If blank the entire graph will be used.
|
|
726
1023
|
"""
|
|
1024
|
+
if driver is None:
|
|
1025
|
+
driver = self.clients.driver
|
|
1026
|
+
|
|
727
1027
|
# Clear existing communities
|
|
728
|
-
await remove_communities(
|
|
1028
|
+
await remove_communities(driver)
|
|
729
1029
|
|
|
730
1030
|
community_nodes, community_edges = await build_communities(
|
|
731
|
-
|
|
1031
|
+
driver, self.llm_client, group_ids
|
|
732
1032
|
)
|
|
733
1033
|
|
|
734
1034
|
await semaphore_gather(
|
|
@@ -737,16 +1037,17 @@ class Graphiti:
|
|
|
737
1037
|
)
|
|
738
1038
|
|
|
739
1039
|
await semaphore_gather(
|
|
740
|
-
*[node.save(
|
|
1040
|
+
*[node.save(driver) for node in community_nodes],
|
|
741
1041
|
max_coroutines=self.max_coroutines,
|
|
742
1042
|
)
|
|
743
1043
|
await semaphore_gather(
|
|
744
|
-
*[edge.save(
|
|
1044
|
+
*[edge.save(driver) for edge in community_edges],
|
|
745
1045
|
max_coroutines=self.max_coroutines,
|
|
746
1046
|
)
|
|
747
1047
|
|
|
748
|
-
return community_nodes
|
|
1048
|
+
return community_nodes, community_edges
|
|
749
1049
|
|
|
1050
|
+
@handle_multiple_group_ids
|
|
750
1051
|
async def search(
|
|
751
1052
|
self,
|
|
752
1053
|
query: str,
|
|
@@ -754,6 +1055,7 @@ class Graphiti:
|
|
|
754
1055
|
group_ids: list[str] | None = None,
|
|
755
1056
|
num_results=DEFAULT_SEARCH_LIMIT,
|
|
756
1057
|
search_filter: SearchFilters | None = None,
|
|
1058
|
+
driver: GraphDriver | None = None,
|
|
757
1059
|
) -> list[EntityEdge]:
|
|
758
1060
|
"""
|
|
759
1061
|
Perform a hybrid search on the knowledge graph.
|
|
@@ -800,7 +1102,8 @@ class Graphiti:
|
|
|
800
1102
|
group_ids,
|
|
801
1103
|
search_config,
|
|
802
1104
|
search_filter if search_filter is not None else SearchFilters(),
|
|
803
|
-
|
|
1105
|
+
driver=driver,
|
|
1106
|
+
center_node_uuid=center_node_uuid,
|
|
804
1107
|
)
|
|
805
1108
|
).edges
|
|
806
1109
|
|
|
@@ -820,6 +1123,7 @@ class Graphiti:
|
|
|
820
1123
|
query, config, group_ids, center_node_uuid, bfs_origin_node_uuids, search_filter
|
|
821
1124
|
)
|
|
822
1125
|
|
|
1126
|
+
@handle_multiple_group_ids
|
|
823
1127
|
async def search_(
|
|
824
1128
|
self,
|
|
825
1129
|
query: str,
|
|
@@ -828,6 +1132,7 @@ class Graphiti:
|
|
|
828
1132
|
center_node_uuid: str | None = None,
|
|
829
1133
|
bfs_origin_node_uuids: list[str] | None = None,
|
|
830
1134
|
search_filter: SearchFilters | None = None,
|
|
1135
|
+
driver: GraphDriver | None = None,
|
|
831
1136
|
) -> SearchResults:
|
|
832
1137
|
"""search_ (replaces _search) is our advanced search method that returns Graph objects (nodes and edges) rather
|
|
833
1138
|
than a list of facts. This endpoint allows the end user to utilize more advanced features such as filters and
|
|
@@ -844,6 +1149,7 @@ class Graphiti:
|
|
|
844
1149
|
search_filter if search_filter is not None else SearchFilters(),
|
|
845
1150
|
center_node_uuid,
|
|
846
1151
|
bfs_origin_node_uuids,
|
|
1152
|
+
driver=driver,
|
|
847
1153
|
)
|
|
848
1154
|
|
|
849
1155
|
async def get_nodes_and_edges_by_episode(self, episode_uuids: list[str]) -> SearchResults:
|
|
@@ -858,9 +1164,11 @@ class Graphiti:
|
|
|
858
1164
|
|
|
859
1165
|
nodes = await get_mentioned_nodes(self.driver, episodes)
|
|
860
1166
|
|
|
861
|
-
return SearchResults(edges=edges, nodes=nodes
|
|
1167
|
+
return SearchResults(edges=edges, nodes=nodes)
|
|
862
1168
|
|
|
863
|
-
async def add_triplet(
|
|
1169
|
+
async def add_triplet(
|
|
1170
|
+
self, source_node: EntityNode, edge: EntityEdge, target_node: EntityNode
|
|
1171
|
+
) -> AddTripletResults:
|
|
864
1172
|
if source_node.name_embedding is None:
|
|
865
1173
|
await source_node.generate_name_embedding(self.embedder)
|
|
866
1174
|
if target_node.name_embedding is None:
|
|
@@ -868,17 +1176,35 @@ class Graphiti:
|
|
|
868
1176
|
if edge.fact_embedding is None:
|
|
869
1177
|
await edge.generate_embedding(self.embedder)
|
|
870
1178
|
|
|
871
|
-
|
|
1179
|
+
nodes, uuid_map, _ = await resolve_extracted_nodes(
|
|
872
1180
|
self.clients,
|
|
873
1181
|
[source_node, target_node],
|
|
874
1182
|
)
|
|
875
1183
|
|
|
876
1184
|
updated_edge = resolve_edge_pointers([edge], uuid_map)[0]
|
|
877
1185
|
|
|
878
|
-
|
|
1186
|
+
valid_edges = await EntityEdge.get_between_nodes(
|
|
1187
|
+
self.driver, edge.source_node_uuid, edge.target_node_uuid
|
|
1188
|
+
)
|
|
1189
|
+
|
|
1190
|
+
related_edges = (
|
|
1191
|
+
await search(
|
|
1192
|
+
self.clients,
|
|
1193
|
+
updated_edge.fact,
|
|
1194
|
+
group_ids=[updated_edge.group_id],
|
|
1195
|
+
config=EDGE_HYBRID_SEARCH_RRF,
|
|
1196
|
+
search_filter=SearchFilters(edge_uuids=[edge.uuid for edge in valid_edges]),
|
|
1197
|
+
)
|
|
1198
|
+
).edges
|
|
879
1199
|
existing_edges = (
|
|
880
|
-
await
|
|
881
|
-
|
|
1200
|
+
await search(
|
|
1201
|
+
self.clients,
|
|
1202
|
+
updated_edge.fact,
|
|
1203
|
+
group_ids=[updated_edge.group_id],
|
|
1204
|
+
config=EDGE_HYBRID_SEARCH_RRF,
|
|
1205
|
+
search_filter=SearchFilters(),
|
|
1206
|
+
)
|
|
1207
|
+
).edges
|
|
882
1208
|
|
|
883
1209
|
resolved_edge, invalidated_edges, _ = await resolve_extracted_edge(
|
|
884
1210
|
self.llm_client,
|
|
@@ -894,11 +1220,17 @@ class Graphiti:
|
|
|
894
1220
|
entity_edges=[],
|
|
895
1221
|
group_id=edge.group_id,
|
|
896
1222
|
),
|
|
1223
|
+
None,
|
|
1224
|
+
None,
|
|
897
1225
|
)
|
|
898
1226
|
|
|
899
|
-
|
|
900
|
-
|
|
901
|
-
)
|
|
1227
|
+
edges: list[EntityEdge] = [resolved_edge] + invalidated_edges
|
|
1228
|
+
|
|
1229
|
+
await create_entity_edge_embeddings(self.embedder, edges)
|
|
1230
|
+
await create_entity_node_embeddings(self.embedder, nodes)
|
|
1231
|
+
|
|
1232
|
+
await add_nodes_and_edges_bulk(self.driver, [], [], nodes, edges, self.embedder)
|
|
1233
|
+
return AddTripletResults(edges=edges, nodes=nodes)
|
|
902
1234
|
|
|
903
1235
|
async def remove_episode(self, episode_uuid: str):
|
|
904
1236
|
# Find the episode to be deleted
|
|
@@ -925,12 +1257,7 @@ class Graphiti:
|
|
|
925
1257
|
if record['episode_count'] == 1:
|
|
926
1258
|
nodes_to_delete.append(node)
|
|
927
1259
|
|
|
928
|
-
await
|
|
929
|
-
|
|
930
|
-
|
|
931
|
-
)
|
|
932
|
-
await semaphore_gather(
|
|
933
|
-
*[edge.delete(self.driver) for edge in edges_to_delete],
|
|
934
|
-
max_coroutines=self.max_coroutines,
|
|
935
|
-
)
|
|
1260
|
+
await Edge.delete_by_uuids(self.driver, [edge.uuid for edge in edges_to_delete])
|
|
1261
|
+
await Node.delete_by_uuids(self.driver, [node.uuid for node in nodes_to_delete])
|
|
1262
|
+
|
|
936
1263
|
await episode.delete(self.driver)
|