graphiti-core 0.12.0rc1__py3-none-any.whl → 0.24.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- graphiti_core/cross_encoder/bge_reranker_client.py +12 -2
- graphiti_core/cross_encoder/gemini_reranker_client.py +161 -0
- graphiti_core/cross_encoder/openai_reranker_client.py +7 -5
- graphiti_core/decorators.py +110 -0
- graphiti_core/driver/__init__.py +19 -0
- graphiti_core/driver/driver.py +124 -0
- graphiti_core/driver/falkordb_driver.py +362 -0
- graphiti_core/driver/graph_operations/graph_operations.py +191 -0
- graphiti_core/driver/kuzu_driver.py +182 -0
- graphiti_core/driver/neo4j_driver.py +117 -0
- graphiti_core/driver/neptune_driver.py +305 -0
- graphiti_core/driver/search_interface/search_interface.py +89 -0
- graphiti_core/edges.py +287 -172
- graphiti_core/embedder/azure_openai.py +71 -0
- graphiti_core/embedder/client.py +2 -1
- graphiti_core/embedder/gemini.py +116 -22
- graphiti_core/embedder/voyage.py +13 -2
- graphiti_core/errors.py +8 -0
- graphiti_core/graph_queries.py +162 -0
- graphiti_core/graphiti.py +705 -193
- graphiti_core/graphiti_types.py +4 -2
- graphiti_core/helpers.py +87 -10
- graphiti_core/llm_client/__init__.py +16 -0
- graphiti_core/llm_client/anthropic_client.py +159 -56
- graphiti_core/llm_client/azure_openai_client.py +115 -0
- graphiti_core/llm_client/client.py +98 -21
- graphiti_core/llm_client/config.py +1 -1
- graphiti_core/llm_client/gemini_client.py +290 -41
- graphiti_core/llm_client/groq_client.py +14 -3
- graphiti_core/llm_client/openai_base_client.py +261 -0
- graphiti_core/llm_client/openai_client.py +56 -132
- graphiti_core/llm_client/openai_generic_client.py +91 -56
- graphiti_core/models/edges/edge_db_queries.py +259 -35
- graphiti_core/models/nodes/node_db_queries.py +311 -32
- graphiti_core/nodes.py +420 -205
- graphiti_core/prompts/dedupe_edges.py +46 -32
- graphiti_core/prompts/dedupe_nodes.py +67 -42
- graphiti_core/prompts/eval.py +4 -4
- graphiti_core/prompts/extract_edges.py +27 -16
- graphiti_core/prompts/extract_nodes.py +74 -31
- graphiti_core/prompts/prompt_helpers.py +39 -0
- graphiti_core/prompts/snippets.py +29 -0
- graphiti_core/prompts/summarize_nodes.py +23 -25
- graphiti_core/search/search.py +158 -82
- graphiti_core/search/search_config.py +39 -4
- graphiti_core/search/search_filters.py +126 -35
- graphiti_core/search/search_helpers.py +5 -6
- graphiti_core/search/search_utils.py +1405 -485
- graphiti_core/telemetry/__init__.py +9 -0
- graphiti_core/telemetry/telemetry.py +117 -0
- graphiti_core/tracer.py +193 -0
- graphiti_core/utils/bulk_utils.py +364 -285
- graphiti_core/utils/datetime_utils.py +13 -0
- graphiti_core/utils/maintenance/community_operations.py +67 -49
- graphiti_core/utils/maintenance/dedup_helpers.py +262 -0
- graphiti_core/utils/maintenance/edge_operations.py +339 -197
- graphiti_core/utils/maintenance/graph_data_operations.py +50 -114
- graphiti_core/utils/maintenance/node_operations.py +319 -238
- graphiti_core/utils/maintenance/temporal_operations.py +11 -3
- graphiti_core/utils/ontology_utils/entity_types_utils.py +1 -1
- graphiti_core/utils/text_utils.py +53 -0
- graphiti_core-0.24.3.dist-info/METADATA +726 -0
- graphiti_core-0.24.3.dist-info/RECORD +86 -0
- {graphiti_core-0.12.0rc1.dist-info → graphiti_core-0.24.3.dist-info}/WHEEL +1 -1
- graphiti_core-0.12.0rc1.dist-info/METADATA +0 -350
- graphiti_core-0.12.0rc1.dist-info/RECORD +0 -66
- /graphiti_core/{utils/maintenance/utils.py → migrations/__init__.py} +0 -0
- {graphiti_core-0.12.0rc1.dist-info → graphiti_core-0.24.3.dist-info/licenses}/LICENSE +0 -0
graphiti_core/graphiti.py
CHANGED
|
@@ -19,18 +19,38 @@ from datetime import datetime
|
|
|
19
19
|
from time import time
|
|
20
20
|
|
|
21
21
|
from dotenv import load_dotenv
|
|
22
|
-
from neo4j import AsyncGraphDatabase
|
|
23
22
|
from pydantic import BaseModel
|
|
24
23
|
from typing_extensions import LiteralString
|
|
25
24
|
|
|
26
25
|
from graphiti_core.cross_encoder.client import CrossEncoderClient
|
|
27
26
|
from graphiti_core.cross_encoder.openai_reranker_client import OpenAIRerankerClient
|
|
28
|
-
from graphiti_core.
|
|
27
|
+
from graphiti_core.decorators import handle_multiple_group_ids
|
|
28
|
+
from graphiti_core.driver.driver import GraphDriver
|
|
29
|
+
from graphiti_core.driver.neo4j_driver import Neo4jDriver
|
|
30
|
+
from graphiti_core.edges import (
|
|
31
|
+
CommunityEdge,
|
|
32
|
+
Edge,
|
|
33
|
+
EntityEdge,
|
|
34
|
+
EpisodicEdge,
|
|
35
|
+
create_entity_edge_embeddings,
|
|
36
|
+
)
|
|
29
37
|
from graphiti_core.embedder import EmbedderClient, OpenAIEmbedder
|
|
30
38
|
from graphiti_core.graphiti_types import GraphitiClients
|
|
31
|
-
from graphiti_core.helpers import
|
|
39
|
+
from graphiti_core.helpers import (
|
|
40
|
+
get_default_group_id,
|
|
41
|
+
semaphore_gather,
|
|
42
|
+
validate_excluded_entity_types,
|
|
43
|
+
validate_group_id,
|
|
44
|
+
)
|
|
32
45
|
from graphiti_core.llm_client import LLMClient, OpenAIClient
|
|
33
|
-
from graphiti_core.nodes import
|
|
46
|
+
from graphiti_core.nodes import (
|
|
47
|
+
CommunityNode,
|
|
48
|
+
EntityNode,
|
|
49
|
+
EpisodeType,
|
|
50
|
+
EpisodicNode,
|
|
51
|
+
Node,
|
|
52
|
+
create_entity_node_embeddings,
|
|
53
|
+
)
|
|
34
54
|
from graphiti_core.search.search import SearchConfig, search
|
|
35
55
|
from graphiti_core.search.search_config import DEFAULT_SEARCH_LIMIT, SearchResults
|
|
36
56
|
from graphiti_core.search.search_config_recipes import (
|
|
@@ -41,16 +61,15 @@ from graphiti_core.search.search_config_recipes import (
|
|
|
41
61
|
from graphiti_core.search.search_filters import SearchFilters
|
|
42
62
|
from graphiti_core.search.search_utils import (
|
|
43
63
|
RELEVANT_SCHEMA_LIMIT,
|
|
44
|
-
get_edge_invalidation_candidates,
|
|
45
64
|
get_mentioned_nodes,
|
|
46
|
-
get_relevant_edges,
|
|
47
65
|
)
|
|
66
|
+
from graphiti_core.telemetry import capture_event
|
|
67
|
+
from graphiti_core.tracer import Tracer, create_tracer
|
|
48
68
|
from graphiti_core.utils.bulk_utils import (
|
|
49
69
|
RawEpisode,
|
|
50
70
|
add_nodes_and_edges_bulk,
|
|
51
71
|
dedupe_edges_bulk,
|
|
52
72
|
dedupe_nodes_bulk,
|
|
53
|
-
extract_edge_dates_bulk,
|
|
54
73
|
extract_nodes_and_edges_bulk,
|
|
55
74
|
resolve_edge_pointers,
|
|
56
75
|
retrieve_previous_episodes_bulk,
|
|
@@ -69,7 +88,6 @@ from graphiti_core.utils.maintenance.edge_operations import (
|
|
|
69
88
|
)
|
|
70
89
|
from graphiti_core.utils.maintenance.graph_data_operations import (
|
|
71
90
|
EPISODE_WINDOW_LEN,
|
|
72
|
-
build_indices_and_constraints,
|
|
73
91
|
retrieve_episodes,
|
|
74
92
|
)
|
|
75
93
|
from graphiti_core.utils.maintenance.node_operations import (
|
|
@@ -86,6 +104,23 @@ load_dotenv()
|
|
|
86
104
|
|
|
87
105
|
class AddEpisodeResults(BaseModel):
|
|
88
106
|
episode: EpisodicNode
|
|
107
|
+
episodic_edges: list[EpisodicEdge]
|
|
108
|
+
nodes: list[EntityNode]
|
|
109
|
+
edges: list[EntityEdge]
|
|
110
|
+
communities: list[CommunityNode]
|
|
111
|
+
community_edges: list[CommunityEdge]
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
class AddBulkEpisodeResults(BaseModel):
|
|
115
|
+
episodes: list[EpisodicNode]
|
|
116
|
+
episodic_edges: list[EpisodicEdge]
|
|
117
|
+
nodes: list[EntityNode]
|
|
118
|
+
edges: list[EntityEdge]
|
|
119
|
+
communities: list[CommunityNode]
|
|
120
|
+
community_edges: list[CommunityEdge]
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
class AddTripletResults(BaseModel):
|
|
89
124
|
nodes: list[EntityNode]
|
|
90
125
|
edges: list[EntityEdge]
|
|
91
126
|
|
|
@@ -93,18 +128,22 @@ class AddEpisodeResults(BaseModel):
|
|
|
93
128
|
class Graphiti:
|
|
94
129
|
def __init__(
|
|
95
130
|
self,
|
|
96
|
-
uri: str,
|
|
97
|
-
user: str,
|
|
98
|
-
password: str,
|
|
131
|
+
uri: str | None = None,
|
|
132
|
+
user: str | None = None,
|
|
133
|
+
password: str | None = None,
|
|
99
134
|
llm_client: LLMClient | None = None,
|
|
100
135
|
embedder: EmbedderClient | None = None,
|
|
101
136
|
cross_encoder: CrossEncoderClient | None = None,
|
|
102
137
|
store_raw_episode_content: bool = True,
|
|
138
|
+
graph_driver: GraphDriver | None = None,
|
|
139
|
+
max_coroutines: int | None = None,
|
|
140
|
+
tracer: Tracer | None = None,
|
|
141
|
+
trace_span_prefix: str = 'graphiti',
|
|
103
142
|
):
|
|
104
143
|
"""
|
|
105
144
|
Initialize a Graphiti instance.
|
|
106
145
|
|
|
107
|
-
This constructor sets up a connection to
|
|
146
|
+
This constructor sets up a connection to a graph database and initializes
|
|
108
147
|
the LLM client for natural language processing tasks.
|
|
109
148
|
|
|
110
149
|
Parameters
|
|
@@ -118,6 +157,24 @@ class Graphiti:
|
|
|
118
157
|
llm_client : LLMClient | None, optional
|
|
119
158
|
An instance of LLMClient for natural language processing tasks.
|
|
120
159
|
If not provided, a default OpenAIClient will be initialized.
|
|
160
|
+
embedder : EmbedderClient | None, optional
|
|
161
|
+
An instance of EmbedderClient for embedding tasks.
|
|
162
|
+
If not provided, a default OpenAIEmbedder will be initialized.
|
|
163
|
+
cross_encoder : CrossEncoderClient | None, optional
|
|
164
|
+
An instance of CrossEncoderClient for reranking tasks.
|
|
165
|
+
If not provided, a default OpenAIRerankerClient will be initialized.
|
|
166
|
+
store_raw_episode_content : bool, optional
|
|
167
|
+
Whether to store the raw content of episodes. Defaults to True.
|
|
168
|
+
graph_driver : GraphDriver | None, optional
|
|
169
|
+
An instance of GraphDriver for database operations.
|
|
170
|
+
If not provided, a default Neo4jDriver will be initialized.
|
|
171
|
+
max_coroutines : int | None, optional
|
|
172
|
+
The maximum number of concurrent operations allowed. Overrides SEMAPHORE_LIMIT set in the environment.
|
|
173
|
+
If not set, the Graphiti default is used.
|
|
174
|
+
tracer : Tracer | None, optional
|
|
175
|
+
An OpenTelemetry tracer instance for distributed tracing. If not provided, tracing is disabled (no-op).
|
|
176
|
+
trace_span_prefix : str, optional
|
|
177
|
+
Prefix to prepend to all span names. Defaults to 'graphiti'.
|
|
121
178
|
|
|
122
179
|
Returns
|
|
123
180
|
-------
|
|
@@ -125,11 +182,11 @@ class Graphiti:
|
|
|
125
182
|
|
|
126
183
|
Notes
|
|
127
184
|
-----
|
|
128
|
-
This method establishes a connection to
|
|
185
|
+
This method establishes a connection to a graph database (Neo4j by default) using the provided
|
|
129
186
|
credentials. It also sets up the LLM client, either using the provided client
|
|
130
187
|
or by creating a default OpenAIClient.
|
|
131
188
|
|
|
132
|
-
The default database name is
|
|
189
|
+
The default database name is defined during the driver’s construction. If a different database name
|
|
133
190
|
is required, it should be specified in the URI or set separately after
|
|
134
191
|
initialization.
|
|
135
192
|
|
|
@@ -137,9 +194,16 @@ class Graphiti:
|
|
|
137
194
|
Make sure to set the OPENAI_API_KEY environment variable before initializing
|
|
138
195
|
Graphiti if you're using the default OpenAIClient.
|
|
139
196
|
"""
|
|
140
|
-
|
|
141
|
-
|
|
197
|
+
|
|
198
|
+
if graph_driver:
|
|
199
|
+
self.driver = graph_driver
|
|
200
|
+
else:
|
|
201
|
+
if uri is None:
|
|
202
|
+
raise ValueError('uri must be provided when graph_driver is None')
|
|
203
|
+
self.driver = Neo4jDriver(uri, user, password)
|
|
204
|
+
|
|
142
205
|
self.store_raw_episode_content = store_raw_episode_content
|
|
206
|
+
self.max_coroutines = max_coroutines
|
|
143
207
|
if llm_client:
|
|
144
208
|
self.llm_client = llm_client
|
|
145
209
|
else:
|
|
@@ -153,13 +217,75 @@ class Graphiti:
|
|
|
153
217
|
else:
|
|
154
218
|
self.cross_encoder = OpenAIRerankerClient()
|
|
155
219
|
|
|
220
|
+
# Initialize tracer
|
|
221
|
+
self.tracer = create_tracer(tracer, trace_span_prefix)
|
|
222
|
+
|
|
223
|
+
# Set tracer on clients
|
|
224
|
+
self.llm_client.set_tracer(self.tracer)
|
|
225
|
+
|
|
156
226
|
self.clients = GraphitiClients(
|
|
157
227
|
driver=self.driver,
|
|
158
228
|
llm_client=self.llm_client,
|
|
159
229
|
embedder=self.embedder,
|
|
160
230
|
cross_encoder=self.cross_encoder,
|
|
231
|
+
tracer=self.tracer,
|
|
161
232
|
)
|
|
162
233
|
|
|
234
|
+
# Capture telemetry event
|
|
235
|
+
self._capture_initialization_telemetry()
|
|
236
|
+
|
|
237
|
+
def _capture_initialization_telemetry(self):
|
|
238
|
+
"""Capture telemetry event for Graphiti initialization."""
|
|
239
|
+
try:
|
|
240
|
+
# Detect provider types from class names
|
|
241
|
+
llm_provider = self._get_provider_type(self.llm_client)
|
|
242
|
+
embedder_provider = self._get_provider_type(self.embedder)
|
|
243
|
+
reranker_provider = self._get_provider_type(self.cross_encoder)
|
|
244
|
+
database_provider = self._get_provider_type(self.driver)
|
|
245
|
+
|
|
246
|
+
properties = {
|
|
247
|
+
'llm_provider': llm_provider,
|
|
248
|
+
'embedder_provider': embedder_provider,
|
|
249
|
+
'reranker_provider': reranker_provider,
|
|
250
|
+
'database_provider': database_provider,
|
|
251
|
+
}
|
|
252
|
+
|
|
253
|
+
capture_event('graphiti_initialized', properties)
|
|
254
|
+
except Exception:
|
|
255
|
+
# Silently handle telemetry errors
|
|
256
|
+
pass
|
|
257
|
+
|
|
258
|
+
def _get_provider_type(self, client) -> str:
|
|
259
|
+
"""Get provider type from client class name."""
|
|
260
|
+
if client is None:
|
|
261
|
+
return 'none'
|
|
262
|
+
|
|
263
|
+
class_name = client.__class__.__name__.lower()
|
|
264
|
+
|
|
265
|
+
# LLM providers
|
|
266
|
+
if 'openai' in class_name:
|
|
267
|
+
return 'openai'
|
|
268
|
+
elif 'azure' in class_name:
|
|
269
|
+
return 'azure'
|
|
270
|
+
elif 'anthropic' in class_name:
|
|
271
|
+
return 'anthropic'
|
|
272
|
+
elif 'crossencoder' in class_name:
|
|
273
|
+
return 'crossencoder'
|
|
274
|
+
elif 'gemini' in class_name:
|
|
275
|
+
return 'gemini'
|
|
276
|
+
elif 'groq' in class_name:
|
|
277
|
+
return 'groq'
|
|
278
|
+
# Database providers
|
|
279
|
+
elif 'neo4j' in class_name:
|
|
280
|
+
return 'neo4j'
|
|
281
|
+
elif 'falkor' in class_name:
|
|
282
|
+
return 'falkordb'
|
|
283
|
+
# Embedder providers
|
|
284
|
+
elif 'voyage' in class_name:
|
|
285
|
+
return 'voyage'
|
|
286
|
+
else:
|
|
287
|
+
return 'unknown'
|
|
288
|
+
|
|
163
289
|
async def close(self):
|
|
164
290
|
"""
|
|
165
291
|
Close the connection to the Neo4j database.
|
|
@@ -214,25 +340,247 @@ class Graphiti:
|
|
|
214
340
|
-----
|
|
215
341
|
This method should typically be called once during the initial setup of the
|
|
216
342
|
knowledge graph or when updating the database schema. It uses the
|
|
217
|
-
`build_indices_and_constraints`
|
|
218
|
-
`graphiti_core.utils.maintenance.graph_data_operations` module to perform
|
|
343
|
+
driver's `build_indices_and_constraints` method to perform
|
|
219
344
|
the actual database operations.
|
|
220
345
|
|
|
221
346
|
The specific indices and constraints created depend on the implementation
|
|
222
|
-
of the `build_indices_and_constraints`
|
|
223
|
-
documentation for details on the exact database schema modifications.
|
|
347
|
+
of the driver's `build_indices_and_constraints` method. Refer to the specific
|
|
348
|
+
driver documentation for details on the exact database schema modifications.
|
|
224
349
|
|
|
225
350
|
Caution: Running this method on a large existing database may take some time
|
|
226
351
|
and could impact database performance during execution.
|
|
227
352
|
"""
|
|
228
|
-
await
|
|
353
|
+
await self.driver.build_indices_and_constraints(delete_existing)
|
|
354
|
+
|
|
355
|
+
async def _extract_and_resolve_nodes(
|
|
356
|
+
self,
|
|
357
|
+
episode: EpisodicNode,
|
|
358
|
+
previous_episodes: list[EpisodicNode],
|
|
359
|
+
entity_types: dict[str, type[BaseModel]] | None,
|
|
360
|
+
excluded_entity_types: list[str] | None,
|
|
361
|
+
) -> tuple[list[EntityNode], dict[str, str], list[tuple[EntityNode, EntityNode]]]:
|
|
362
|
+
"""Extract nodes from episode and resolve against existing graph."""
|
|
363
|
+
extracted_nodes = await extract_nodes(
|
|
364
|
+
self.clients, episode, previous_episodes, entity_types, excluded_entity_types
|
|
365
|
+
)
|
|
366
|
+
|
|
367
|
+
nodes, uuid_map, duplicates = await resolve_extracted_nodes(
|
|
368
|
+
self.clients,
|
|
369
|
+
extracted_nodes,
|
|
370
|
+
episode,
|
|
371
|
+
previous_episodes,
|
|
372
|
+
entity_types,
|
|
373
|
+
)
|
|
374
|
+
|
|
375
|
+
return nodes, uuid_map, duplicates
|
|
376
|
+
|
|
377
|
+
async def _extract_and_resolve_edges(
|
|
378
|
+
self,
|
|
379
|
+
episode: EpisodicNode,
|
|
380
|
+
extracted_nodes: list[EntityNode],
|
|
381
|
+
previous_episodes: list[EpisodicNode],
|
|
382
|
+
edge_type_map: dict[tuple[str, str], list[str]],
|
|
383
|
+
group_id: str,
|
|
384
|
+
edge_types: dict[str, type[BaseModel]] | None,
|
|
385
|
+
nodes: list[EntityNode],
|
|
386
|
+
uuid_map: dict[str, str],
|
|
387
|
+
) -> tuple[list[EntityEdge], list[EntityEdge]]:
|
|
388
|
+
"""Extract edges from episode and resolve against existing graph."""
|
|
389
|
+
extracted_edges = await extract_edges(
|
|
390
|
+
self.clients,
|
|
391
|
+
episode,
|
|
392
|
+
extracted_nodes,
|
|
393
|
+
previous_episodes,
|
|
394
|
+
edge_type_map,
|
|
395
|
+
group_id,
|
|
396
|
+
edge_types,
|
|
397
|
+
)
|
|
398
|
+
|
|
399
|
+
edges = resolve_edge_pointers(extracted_edges, uuid_map)
|
|
400
|
+
|
|
401
|
+
resolved_edges, invalidated_edges = await resolve_extracted_edges(
|
|
402
|
+
self.clients,
|
|
403
|
+
edges,
|
|
404
|
+
episode,
|
|
405
|
+
nodes,
|
|
406
|
+
edge_types or {},
|
|
407
|
+
edge_type_map,
|
|
408
|
+
)
|
|
409
|
+
|
|
410
|
+
return resolved_edges, invalidated_edges
|
|
411
|
+
|
|
412
|
+
async def _process_episode_data(
|
|
413
|
+
self,
|
|
414
|
+
episode: EpisodicNode,
|
|
415
|
+
nodes: list[EntityNode],
|
|
416
|
+
entity_edges: list[EntityEdge],
|
|
417
|
+
now: datetime,
|
|
418
|
+
) -> tuple[list[EpisodicEdge], EpisodicNode]:
|
|
419
|
+
"""Process and save episode data to the graph."""
|
|
420
|
+
episodic_edges = build_episodic_edges(nodes, episode.uuid, now)
|
|
421
|
+
episode.entity_edges = [edge.uuid for edge in entity_edges]
|
|
422
|
+
|
|
423
|
+
if not self.store_raw_episode_content:
|
|
424
|
+
episode.content = ''
|
|
425
|
+
|
|
426
|
+
await add_nodes_and_edges_bulk(
|
|
427
|
+
self.driver,
|
|
428
|
+
[episode],
|
|
429
|
+
episodic_edges,
|
|
430
|
+
nodes,
|
|
431
|
+
entity_edges,
|
|
432
|
+
self.embedder,
|
|
433
|
+
)
|
|
434
|
+
|
|
435
|
+
return episodic_edges, episode
|
|
436
|
+
|
|
437
|
+
async def _extract_and_dedupe_nodes_bulk(
|
|
438
|
+
self,
|
|
439
|
+
episode_context: list[tuple[EpisodicNode, list[EpisodicNode]]],
|
|
440
|
+
edge_type_map: dict[tuple[str, str], list[str]],
|
|
441
|
+
edge_types: dict[str, type[BaseModel]] | None,
|
|
442
|
+
entity_types: dict[str, type[BaseModel]] | None,
|
|
443
|
+
excluded_entity_types: list[str] | None,
|
|
444
|
+
) -> tuple[
|
|
445
|
+
dict[str, list[EntityNode]],
|
|
446
|
+
dict[str, str],
|
|
447
|
+
list[list[EntityEdge]],
|
|
448
|
+
]:
|
|
449
|
+
"""Extract nodes and edges from all episodes and deduplicate."""
|
|
450
|
+
# Extract all nodes and edges for each episode
|
|
451
|
+
extracted_nodes_bulk, extracted_edges_bulk = await extract_nodes_and_edges_bulk(
|
|
452
|
+
self.clients,
|
|
453
|
+
episode_context,
|
|
454
|
+
edge_type_map=edge_type_map,
|
|
455
|
+
edge_types=edge_types,
|
|
456
|
+
entity_types=entity_types,
|
|
457
|
+
excluded_entity_types=excluded_entity_types,
|
|
458
|
+
)
|
|
459
|
+
|
|
460
|
+
# Dedupe extracted nodes in memory
|
|
461
|
+
nodes_by_episode, uuid_map = await dedupe_nodes_bulk(
|
|
462
|
+
self.clients, extracted_nodes_bulk, episode_context, entity_types
|
|
463
|
+
)
|
|
464
|
+
|
|
465
|
+
return nodes_by_episode, uuid_map, extracted_edges_bulk
|
|
466
|
+
|
|
467
|
+
async def _resolve_nodes_and_edges_bulk(
|
|
468
|
+
self,
|
|
469
|
+
nodes_by_episode: dict[str, list[EntityNode]],
|
|
470
|
+
edges_by_episode: dict[str, list[EntityEdge]],
|
|
471
|
+
episode_context: list[tuple[EpisodicNode, list[EpisodicNode]]],
|
|
472
|
+
entity_types: dict[str, type[BaseModel]] | None,
|
|
473
|
+
edge_types: dict[str, type[BaseModel]] | None,
|
|
474
|
+
edge_type_map: dict[tuple[str, str], list[str]],
|
|
475
|
+
episodes: list[EpisodicNode],
|
|
476
|
+
) -> tuple[list[EntityNode], list[EntityEdge], list[EntityEdge], dict[str, str]]:
|
|
477
|
+
"""Resolve nodes and edges against the existing graph."""
|
|
478
|
+
nodes_by_uuid: dict[str, EntityNode] = {
|
|
479
|
+
node.uuid: node for nodes in nodes_by_episode.values() for node in nodes
|
|
480
|
+
}
|
|
481
|
+
|
|
482
|
+
# Get unique nodes per episode
|
|
483
|
+
nodes_by_episode_unique: dict[str, list[EntityNode]] = {}
|
|
484
|
+
nodes_uuid_set: set[str] = set()
|
|
485
|
+
for episode, _ in episode_context:
|
|
486
|
+
nodes_by_episode_unique[episode.uuid] = []
|
|
487
|
+
nodes = [nodes_by_uuid[node.uuid] for node in nodes_by_episode[episode.uuid]]
|
|
488
|
+
for node in nodes:
|
|
489
|
+
if node.uuid not in nodes_uuid_set:
|
|
490
|
+
nodes_by_episode_unique[episode.uuid].append(node)
|
|
491
|
+
nodes_uuid_set.add(node.uuid)
|
|
492
|
+
|
|
493
|
+
# Resolve nodes
|
|
494
|
+
node_results = await semaphore_gather(
|
|
495
|
+
*[
|
|
496
|
+
resolve_extracted_nodes(
|
|
497
|
+
self.clients,
|
|
498
|
+
nodes_by_episode_unique[episode.uuid],
|
|
499
|
+
episode,
|
|
500
|
+
previous_episodes,
|
|
501
|
+
entity_types,
|
|
502
|
+
)
|
|
503
|
+
for episode, previous_episodes in episode_context
|
|
504
|
+
]
|
|
505
|
+
)
|
|
506
|
+
|
|
507
|
+
resolved_nodes: list[EntityNode] = []
|
|
508
|
+
uuid_map: dict[str, str] = {}
|
|
509
|
+
for result in node_results:
|
|
510
|
+
resolved_nodes.extend(result[0])
|
|
511
|
+
uuid_map.update(result[1])
|
|
512
|
+
|
|
513
|
+
# Update nodes_by_uuid with resolved nodes
|
|
514
|
+
for resolved_node in resolved_nodes:
|
|
515
|
+
nodes_by_uuid[resolved_node.uuid] = resolved_node
|
|
516
|
+
|
|
517
|
+
# Update nodes_by_episode_unique with resolved pointers
|
|
518
|
+
for episode_uuid, nodes in nodes_by_episode_unique.items():
|
|
519
|
+
updated_nodes: list[EntityNode] = []
|
|
520
|
+
for node in nodes:
|
|
521
|
+
updated_node_uuid = uuid_map.get(node.uuid, node.uuid)
|
|
522
|
+
updated_node = nodes_by_uuid[updated_node_uuid]
|
|
523
|
+
updated_nodes.append(updated_node)
|
|
524
|
+
nodes_by_episode_unique[episode_uuid] = updated_nodes
|
|
525
|
+
|
|
526
|
+
# Extract attributes for resolved nodes
|
|
527
|
+
hydrated_nodes_results: list[list[EntityNode]] = await semaphore_gather(
|
|
528
|
+
*[
|
|
529
|
+
extract_attributes_from_nodes(
|
|
530
|
+
self.clients,
|
|
531
|
+
nodes_by_episode_unique[episode.uuid],
|
|
532
|
+
episode,
|
|
533
|
+
previous_episodes,
|
|
534
|
+
entity_types,
|
|
535
|
+
)
|
|
536
|
+
for episode, previous_episodes in episode_context
|
|
537
|
+
]
|
|
538
|
+
)
|
|
539
|
+
|
|
540
|
+
final_hydrated_nodes = [node for nodes in hydrated_nodes_results for node in nodes]
|
|
541
|
+
|
|
542
|
+
# Resolve edges with updated pointers
|
|
543
|
+
edges_by_episode_unique: dict[str, list[EntityEdge]] = {}
|
|
544
|
+
edges_uuid_set: set[str] = set()
|
|
545
|
+
for episode_uuid, edges in edges_by_episode.items():
|
|
546
|
+
edges_with_updated_pointers = resolve_edge_pointers(edges, uuid_map)
|
|
547
|
+
edges_by_episode_unique[episode_uuid] = []
|
|
548
|
+
|
|
549
|
+
for edge in edges_with_updated_pointers:
|
|
550
|
+
if edge.uuid not in edges_uuid_set:
|
|
551
|
+
edges_by_episode_unique[episode_uuid].append(edge)
|
|
552
|
+
edges_uuid_set.add(edge.uuid)
|
|
553
|
+
|
|
554
|
+
edge_results = await semaphore_gather(
|
|
555
|
+
*[
|
|
556
|
+
resolve_extracted_edges(
|
|
557
|
+
self.clients,
|
|
558
|
+
edges_by_episode_unique[episode.uuid],
|
|
559
|
+
episode,
|
|
560
|
+
final_hydrated_nodes,
|
|
561
|
+
edge_types or {},
|
|
562
|
+
edge_type_map,
|
|
563
|
+
)
|
|
564
|
+
for episode in episodes
|
|
565
|
+
]
|
|
566
|
+
)
|
|
567
|
+
|
|
568
|
+
resolved_edges: list[EntityEdge] = []
|
|
569
|
+
invalidated_edges: list[EntityEdge] = []
|
|
570
|
+
for result in edge_results:
|
|
571
|
+
resolved_edges.extend(result[0])
|
|
572
|
+
invalidated_edges.extend(result[1])
|
|
573
|
+
|
|
574
|
+
return final_hydrated_nodes, resolved_edges, invalidated_edges, uuid_map
|
|
229
575
|
|
|
576
|
+
@handle_multiple_group_ids
|
|
230
577
|
async def retrieve_episodes(
|
|
231
578
|
self,
|
|
232
579
|
reference_time: datetime,
|
|
233
580
|
last_n: int = EPISODE_WINDOW_LEN,
|
|
234
581
|
group_ids: list[str] | None = None,
|
|
235
582
|
source: EpisodeType | None = None,
|
|
583
|
+
driver: GraphDriver | None = None,
|
|
236
584
|
) -> list[EpisodicNode]:
|
|
237
585
|
"""
|
|
238
586
|
Retrieve the last n episodic nodes from the graph.
|
|
@@ -259,7 +607,10 @@ class Graphiti:
|
|
|
259
607
|
The actual retrieval is performed by the `retrieve_episodes` function
|
|
260
608
|
from the `graphiti_core.utils` module.
|
|
261
609
|
"""
|
|
262
|
-
|
|
610
|
+
if driver is None:
|
|
611
|
+
driver = self.clients.driver
|
|
612
|
+
|
|
613
|
+
return await retrieve_episodes(driver, reference_time, last_n, group_ids, source)
|
|
263
614
|
|
|
264
615
|
async def add_episode(
|
|
265
616
|
self,
|
|
@@ -268,12 +619,13 @@ class Graphiti:
|
|
|
268
619
|
source_description: str,
|
|
269
620
|
reference_time: datetime,
|
|
270
621
|
source: EpisodeType = EpisodeType.message,
|
|
271
|
-
group_id: str =
|
|
622
|
+
group_id: str | None = None,
|
|
272
623
|
uuid: str | None = None,
|
|
273
624
|
update_communities: bool = False,
|
|
274
|
-
entity_types: dict[str, BaseModel] | None = None,
|
|
625
|
+
entity_types: dict[str, type[BaseModel]] | None = None,
|
|
626
|
+
excluded_entity_types: list[str] | None = None,
|
|
275
627
|
previous_episode_uuids: list[str] | None = None,
|
|
276
|
-
edge_types: dict[str, BaseModel] | None = None,
|
|
628
|
+
edge_types: dict[str, type[BaseModel]] | None = None,
|
|
277
629
|
edge_type_map: dict[tuple[str, str], list[str]] | None = None,
|
|
278
630
|
) -> AddEpisodeResults:
|
|
279
631
|
"""
|
|
@@ -300,6 +652,12 @@ class Graphiti:
|
|
|
300
652
|
Optional uuid of the episode.
|
|
301
653
|
update_communities : bool
|
|
302
654
|
Optional. Whether to update communities with new node information
|
|
655
|
+
entity_types : dict[str, BaseModel] | None
|
|
656
|
+
Optional. Dictionary mapping entity type names to their Pydantic model definitions.
|
|
657
|
+
excluded_entity_types : list[str] | None
|
|
658
|
+
Optional. List of entity type names to exclude from the graph. Entities classified
|
|
659
|
+
into these types will not be added to the graph. Can include 'Entity' to exclude
|
|
660
|
+
the default entity type.
|
|
303
661
|
previous_episode_uuids : list[str] | None
|
|
304
662
|
Optional. list of episode uuids to use as the previous episodes. If this is not provided,
|
|
305
663
|
the most recent episodes by created_at date will be used.
|
|
@@ -325,112 +683,155 @@ class Graphiti:
|
|
|
325
683
|
background_tasks.add_task(graphiti.add_episode, **episode_data.dict())
|
|
326
684
|
return {"message": "Episode processing started"}
|
|
327
685
|
"""
|
|
328
|
-
|
|
329
|
-
|
|
330
|
-
now = utc_now()
|
|
686
|
+
start = time()
|
|
687
|
+
now = utc_now()
|
|
331
688
|
|
|
332
|
-
|
|
689
|
+
validate_entity_types(entity_types)
|
|
690
|
+
validate_excluded_entity_types(excluded_entity_types, entity_types)
|
|
333
691
|
|
|
334
|
-
|
|
335
|
-
|
|
336
|
-
|
|
337
|
-
|
|
338
|
-
|
|
339
|
-
|
|
340
|
-
|
|
341
|
-
if
|
|
342
|
-
|
|
343
|
-
|
|
692
|
+
if group_id is None:
|
|
693
|
+
# if group_id is None, use the default group id by the provider
|
|
694
|
+
# and the preset database name will be used
|
|
695
|
+
group_id = get_default_group_id(self.driver.provider)
|
|
696
|
+
else:
|
|
697
|
+
validate_group_id(group_id)
|
|
698
|
+
if group_id != self.driver._database:
|
|
699
|
+
# if group_id is provided, use it as the database name
|
|
700
|
+
self.driver = self.driver.clone(database=group_id)
|
|
701
|
+
self.clients.driver = self.driver
|
|
344
702
|
|
|
345
|
-
|
|
346
|
-
|
|
347
|
-
|
|
348
|
-
|
|
349
|
-
|
|
350
|
-
|
|
351
|
-
|
|
352
|
-
|
|
353
|
-
|
|
354
|
-
|
|
355
|
-
|
|
356
|
-
|
|
703
|
+
with self.tracer.start_span('add_episode') as span:
|
|
704
|
+
try:
|
|
705
|
+
# Retrieve previous episodes for context
|
|
706
|
+
previous_episodes = (
|
|
707
|
+
await self.retrieve_episodes(
|
|
708
|
+
reference_time,
|
|
709
|
+
last_n=RELEVANT_SCHEMA_LIMIT,
|
|
710
|
+
group_ids=[group_id],
|
|
711
|
+
source=source,
|
|
712
|
+
)
|
|
713
|
+
if previous_episode_uuids is None
|
|
714
|
+
else await EpisodicNode.get_by_uuids(self.driver, previous_episode_uuids)
|
|
357
715
|
)
|
|
358
|
-
)
|
|
359
716
|
|
|
360
|
-
|
|
361
|
-
|
|
362
|
-
|
|
363
|
-
|
|
364
|
-
|
|
365
|
-
|
|
717
|
+
# Get or create episode
|
|
718
|
+
episode = (
|
|
719
|
+
await EpisodicNode.get_by_uuid(self.driver, uuid)
|
|
720
|
+
if uuid is not None
|
|
721
|
+
else EpisodicNode(
|
|
722
|
+
name=name,
|
|
723
|
+
group_id=group_id,
|
|
724
|
+
labels=[],
|
|
725
|
+
source=source,
|
|
726
|
+
content=episode_body,
|
|
727
|
+
source_description=source_description,
|
|
728
|
+
created_at=now,
|
|
729
|
+
valid_at=reference_time,
|
|
730
|
+
)
|
|
731
|
+
)
|
|
366
732
|
|
|
367
|
-
|
|
733
|
+
# Create default edge type map
|
|
734
|
+
edge_type_map_default = (
|
|
735
|
+
{('Entity', 'Entity'): list(edge_types.keys())}
|
|
736
|
+
if edge_types is not None
|
|
737
|
+
else {('Entity', 'Entity'): []}
|
|
738
|
+
)
|
|
368
739
|
|
|
369
|
-
|
|
370
|
-
|
|
371
|
-
|
|
740
|
+
# Extract and resolve nodes
|
|
741
|
+
extracted_nodes = await extract_nodes(
|
|
742
|
+
self.clients, episode, previous_episodes, entity_types, excluded_entity_types
|
|
743
|
+
)
|
|
372
744
|
|
|
373
|
-
|
|
374
|
-
(nodes, uuid_map), extracted_edges = await semaphore_gather(
|
|
375
|
-
resolve_extracted_nodes(
|
|
745
|
+
nodes, uuid_map, _ = await resolve_extracted_nodes(
|
|
376
746
|
self.clients,
|
|
377
747
|
extracted_nodes,
|
|
378
748
|
episode,
|
|
379
749
|
previous_episodes,
|
|
380
750
|
entity_types,
|
|
381
|
-
)
|
|
382
|
-
extract_edges(
|
|
383
|
-
self.clients, episode, extracted_nodes, previous_episodes, group_id, edge_types
|
|
384
|
-
),
|
|
385
|
-
)
|
|
386
|
-
|
|
387
|
-
edges = resolve_edge_pointers(extracted_edges, uuid_map)
|
|
751
|
+
)
|
|
388
752
|
|
|
389
|
-
|
|
390
|
-
|
|
391
|
-
self.clients,
|
|
392
|
-
edges,
|
|
753
|
+
# Extract and resolve edges in parallel with attribute extraction
|
|
754
|
+
resolved_edges, invalidated_edges = await self._extract_and_resolve_edges(
|
|
393
755
|
episode,
|
|
394
|
-
|
|
395
|
-
|
|
756
|
+
extracted_nodes,
|
|
757
|
+
previous_episodes,
|
|
396
758
|
edge_type_map or edge_type_map_default,
|
|
397
|
-
|
|
398
|
-
|
|
399
|
-
|
|
400
|
-
|
|
401
|
-
|
|
759
|
+
group_id,
|
|
760
|
+
edge_types,
|
|
761
|
+
nodes,
|
|
762
|
+
uuid_map,
|
|
763
|
+
)
|
|
402
764
|
|
|
403
|
-
|
|
765
|
+
# Extract node attributes
|
|
766
|
+
hydrated_nodes = await extract_attributes_from_nodes(
|
|
767
|
+
self.clients, nodes, episode, previous_episodes, entity_types
|
|
768
|
+
)
|
|
404
769
|
|
|
405
|
-
|
|
770
|
+
entity_edges = resolved_edges + invalidated_edges
|
|
406
771
|
|
|
407
|
-
|
|
772
|
+
# Process and save episode data
|
|
773
|
+
episodic_edges, episode = await self._process_episode_data(
|
|
774
|
+
episode, hydrated_nodes, entity_edges, now
|
|
775
|
+
)
|
|
408
776
|
|
|
409
|
-
|
|
410
|
-
|
|
777
|
+
# Update communities if requested
|
|
778
|
+
communities = []
|
|
779
|
+
community_edges = []
|
|
780
|
+
if update_communities:
|
|
781
|
+
communities, community_edges = await semaphore_gather(
|
|
782
|
+
*[
|
|
783
|
+
update_community(self.driver, self.llm_client, self.embedder, node)
|
|
784
|
+
for node in nodes
|
|
785
|
+
],
|
|
786
|
+
max_coroutines=self.max_coroutines,
|
|
787
|
+
)
|
|
788
|
+
|
|
789
|
+
end = time()
|
|
790
|
+
|
|
791
|
+
# Add span attributes
|
|
792
|
+
span.add_attributes(
|
|
793
|
+
{
|
|
794
|
+
'episode.uuid': episode.uuid,
|
|
795
|
+
'episode.source': source.value,
|
|
796
|
+
'episode.reference_time': reference_time.isoformat(),
|
|
797
|
+
'group_id': group_id,
|
|
798
|
+
'node.count': len(hydrated_nodes),
|
|
799
|
+
'edge.count': len(entity_edges),
|
|
800
|
+
'edge.invalidated_count': len(invalidated_edges),
|
|
801
|
+
'previous_episodes.count': len(previous_episodes),
|
|
802
|
+
'entity_types.count': len(entity_types) if entity_types else 0,
|
|
803
|
+
'edge_types.count': len(edge_types) if edge_types else 0,
|
|
804
|
+
'update_communities': update_communities,
|
|
805
|
+
'communities.count': len(communities) if update_communities else 0,
|
|
806
|
+
'duration_ms': (end - start) * 1000,
|
|
807
|
+
}
|
|
808
|
+
)
|
|
411
809
|
|
|
412
|
-
|
|
413
|
-
self.driver, [episode], episodic_edges, hydrated_nodes, entity_edges, self.embedder
|
|
414
|
-
)
|
|
810
|
+
logger.info(f'Completed add_episode in {(end - start) * 1000} ms')
|
|
415
811
|
|
|
416
|
-
|
|
417
|
-
|
|
418
|
-
|
|
419
|
-
|
|
420
|
-
|
|
421
|
-
|
|
422
|
-
|
|
812
|
+
return AddEpisodeResults(
|
|
813
|
+
episode=episode,
|
|
814
|
+
episodic_edges=episodic_edges,
|
|
815
|
+
nodes=hydrated_nodes,
|
|
816
|
+
edges=entity_edges,
|
|
817
|
+
communities=communities,
|
|
818
|
+
community_edges=community_edges,
|
|
423
819
|
)
|
|
424
|
-
end = time()
|
|
425
|
-
logger.info(f'Completed add_episode in {(end - start) * 1000} ms')
|
|
426
|
-
|
|
427
|
-
return AddEpisodeResults(episode=episode, nodes=nodes, edges=entity_edges)
|
|
428
820
|
|
|
429
|
-
|
|
430
|
-
|
|
821
|
+
except Exception as e:
|
|
822
|
+
span.set_status('error', str(e))
|
|
823
|
+
span.record_exception(e)
|
|
824
|
+
raise e
|
|
431
825
|
|
|
432
|
-
|
|
433
|
-
|
|
826
|
+
async def add_episode_bulk(
|
|
827
|
+
self,
|
|
828
|
+
bulk_episodes: list[RawEpisode],
|
|
829
|
+
group_id: str | None = None,
|
|
830
|
+
entity_types: dict[str, type[BaseModel]] | None = None,
|
|
831
|
+
excluded_entity_types: list[str] | None = None,
|
|
832
|
+
edge_types: dict[str, type[BaseModel]] | None = None,
|
|
833
|
+
edge_type_map: dict[tuple[str, str], list[str]] | None = None,
|
|
834
|
+
) -> AddBulkEpisodeResults:
|
|
434
835
|
"""
|
|
435
836
|
Process multiple episodes in bulk and update the graph.
|
|
436
837
|
|
|
@@ -446,7 +847,7 @@ class Graphiti:
|
|
|
446
847
|
|
|
447
848
|
Returns
|
|
448
849
|
-------
|
|
449
|
-
|
|
850
|
+
AddBulkEpisodeResults
|
|
450
851
|
|
|
451
852
|
Notes
|
|
452
853
|
-----
|
|
@@ -467,106 +868,186 @@ class Graphiti:
|
|
|
467
868
|
If these operations are required, use the `add_episode` method instead for each
|
|
468
869
|
individual episode.
|
|
469
870
|
"""
|
|
470
|
-
|
|
471
|
-
|
|
472
|
-
|
|
473
|
-
|
|
474
|
-
|
|
475
|
-
|
|
476
|
-
|
|
477
|
-
|
|
478
|
-
|
|
479
|
-
|
|
480
|
-
|
|
481
|
-
group_id
|
|
482
|
-
|
|
483
|
-
|
|
871
|
+
with self.tracer.start_span('add_episode_bulk') as bulk_span:
|
|
872
|
+
bulk_span.add_attributes({'episode.count': len(bulk_episodes)})
|
|
873
|
+
|
|
874
|
+
try:
|
|
875
|
+
start = time()
|
|
876
|
+
now = utc_now()
|
|
877
|
+
|
|
878
|
+
# if group_id is None, use the default group id by the provider
|
|
879
|
+
if group_id is None:
|
|
880
|
+
group_id = get_default_group_id(self.driver.provider)
|
|
881
|
+
else:
|
|
882
|
+
validate_group_id(group_id)
|
|
883
|
+
if group_id != self.driver._database:
|
|
884
|
+
# if group_id is provided, use it as the database name
|
|
885
|
+
self.driver = self.driver.clone(database=group_id)
|
|
886
|
+
self.clients.driver = self.driver
|
|
887
|
+
|
|
888
|
+
# Create default edge type map
|
|
889
|
+
edge_type_map_default = (
|
|
890
|
+
{('Entity', 'Entity'): list(edge_types.keys())}
|
|
891
|
+
if edge_types is not None
|
|
892
|
+
else {('Entity', 'Entity'): []}
|
|
484
893
|
)
|
|
485
|
-
for episode in bulk_episodes
|
|
486
|
-
]
|
|
487
894
|
|
|
488
|
-
|
|
489
|
-
|
|
895
|
+
episodes = [
|
|
896
|
+
await EpisodicNode.get_by_uuid(self.driver, episode.uuid)
|
|
897
|
+
if episode.uuid is not None
|
|
898
|
+
else EpisodicNode(
|
|
899
|
+
name=episode.name,
|
|
900
|
+
labels=[],
|
|
901
|
+
source=episode.source,
|
|
902
|
+
content=episode.content,
|
|
903
|
+
source_description=episode.source_description,
|
|
904
|
+
group_id=group_id,
|
|
905
|
+
created_at=now,
|
|
906
|
+
valid_at=episode.reference_time,
|
|
907
|
+
)
|
|
908
|
+
for episode in bulk_episodes
|
|
909
|
+
]
|
|
910
|
+
|
|
911
|
+
# Save all episodes
|
|
912
|
+
await add_nodes_and_edges_bulk(
|
|
913
|
+
driver=self.driver,
|
|
914
|
+
episodic_nodes=episodes,
|
|
915
|
+
episodic_edges=[],
|
|
916
|
+
entity_nodes=[],
|
|
917
|
+
entity_edges=[],
|
|
918
|
+
embedder=self.embedder,
|
|
919
|
+
)
|
|
490
920
|
|
|
491
|
-
|
|
492
|
-
|
|
921
|
+
# Get previous episode context for each episode
|
|
922
|
+
episode_context = await retrieve_previous_episodes_bulk(self.driver, episodes)
|
|
493
923
|
|
|
494
|
-
|
|
495
|
-
|
|
496
|
-
|
|
497
|
-
|
|
498
|
-
|
|
499
|
-
|
|
924
|
+
# Extract and dedupe nodes and edges
|
|
925
|
+
(
|
|
926
|
+
nodes_by_episode,
|
|
927
|
+
uuid_map,
|
|
928
|
+
extracted_edges_bulk,
|
|
929
|
+
) = await self._extract_and_dedupe_nodes_bulk(
|
|
930
|
+
episode_context,
|
|
931
|
+
edge_type_map or edge_type_map_default,
|
|
932
|
+
edge_types,
|
|
933
|
+
entity_types,
|
|
934
|
+
excluded_entity_types,
|
|
935
|
+
)
|
|
500
936
|
|
|
501
|
-
|
|
502
|
-
|
|
503
|
-
|
|
504
|
-
|
|
505
|
-
)
|
|
937
|
+
# Create Episodic Edges
|
|
938
|
+
episodic_edges: list[EpisodicEdge] = []
|
|
939
|
+
for episode_uuid, nodes in nodes_by_episode.items():
|
|
940
|
+
episodic_edges.extend(build_episodic_edges(nodes, episode_uuid, now))
|
|
506
941
|
|
|
507
|
-
|
|
508
|
-
|
|
509
|
-
|
|
510
|
-
|
|
511
|
-
)
|
|
942
|
+
# Re-map edge pointers and dedupe edges
|
|
943
|
+
extracted_edges_bulk_updated: list[list[EntityEdge]] = [
|
|
944
|
+
resolve_edge_pointers(edges, uuid_map) for edges in extracted_edges_bulk
|
|
945
|
+
]
|
|
512
946
|
|
|
513
|
-
|
|
514
|
-
|
|
947
|
+
edges_by_episode = await dedupe_edges_bulk(
|
|
948
|
+
self.clients,
|
|
949
|
+
extracted_edges_bulk_updated,
|
|
950
|
+
episode_context,
|
|
951
|
+
[],
|
|
952
|
+
edge_types or {},
|
|
953
|
+
edge_type_map or edge_type_map_default,
|
|
954
|
+
)
|
|
515
955
|
|
|
516
|
-
|
|
517
|
-
|
|
518
|
-
|
|
519
|
-
|
|
520
|
-
|
|
521
|
-
|
|
522
|
-
|
|
956
|
+
# Resolve nodes and edges against the existing graph
|
|
957
|
+
(
|
|
958
|
+
final_hydrated_nodes,
|
|
959
|
+
resolved_edges,
|
|
960
|
+
invalidated_edges,
|
|
961
|
+
final_uuid_map,
|
|
962
|
+
) = await self._resolve_nodes_and_edges_bulk(
|
|
963
|
+
nodes_by_episode,
|
|
964
|
+
edges_by_episode,
|
|
965
|
+
episode_context,
|
|
966
|
+
entity_types,
|
|
967
|
+
edge_types,
|
|
968
|
+
edge_type_map or edge_type_map_default,
|
|
969
|
+
episodes,
|
|
970
|
+
)
|
|
523
971
|
|
|
524
|
-
|
|
525
|
-
|
|
526
|
-
|
|
527
|
-
|
|
972
|
+
# Resolved pointers for episodic edges
|
|
973
|
+
resolved_episodic_edges = resolve_edge_pointers(episodic_edges, final_uuid_map)
|
|
974
|
+
|
|
975
|
+
# save data to KG
|
|
976
|
+
await add_nodes_and_edges_bulk(
|
|
977
|
+
self.driver,
|
|
978
|
+
episodes,
|
|
979
|
+
resolved_episodic_edges,
|
|
980
|
+
final_hydrated_nodes,
|
|
981
|
+
resolved_edges + invalidated_edges,
|
|
982
|
+
self.embedder,
|
|
983
|
+
)
|
|
528
984
|
|
|
529
|
-
|
|
530
|
-
edges = await dedupe_edges_bulk(
|
|
531
|
-
self.driver, self.llm_client, extracted_edges_with_resolved_pointers
|
|
532
|
-
)
|
|
533
|
-
logger.debug(f'extracted edge length: {len(edges)}')
|
|
985
|
+
end = time()
|
|
534
986
|
|
|
535
|
-
|
|
987
|
+
# Add span attributes
|
|
988
|
+
bulk_span.add_attributes(
|
|
989
|
+
{
|
|
990
|
+
'group_id': group_id,
|
|
991
|
+
'node.count': len(final_hydrated_nodes),
|
|
992
|
+
'edge.count': len(resolved_edges + invalidated_edges),
|
|
993
|
+
'duration_ms': (end - start) * 1000,
|
|
994
|
+
}
|
|
995
|
+
)
|
|
536
996
|
|
|
537
|
-
|
|
538
|
-
await semaphore_gather(*[edge.save(self.driver) for edge in edges])
|
|
997
|
+
logger.info(f'Completed add_episode_bulk in {(end - start) * 1000} ms')
|
|
539
998
|
|
|
540
|
-
|
|
541
|
-
|
|
999
|
+
return AddBulkEpisodeResults(
|
|
1000
|
+
episodes=episodes,
|
|
1001
|
+
episodic_edges=resolved_episodic_edges,
|
|
1002
|
+
nodes=final_hydrated_nodes,
|
|
1003
|
+
edges=resolved_edges + invalidated_edges,
|
|
1004
|
+
communities=[],
|
|
1005
|
+
community_edges=[],
|
|
1006
|
+
)
|
|
542
1007
|
|
|
543
|
-
|
|
544
|
-
|
|
1008
|
+
except Exception as e:
|
|
1009
|
+
bulk_span.set_status('error', str(e))
|
|
1010
|
+
bulk_span.record_exception(e)
|
|
1011
|
+
raise e
|
|
545
1012
|
|
|
546
|
-
|
|
1013
|
+
@handle_multiple_group_ids
|
|
1014
|
+
async def build_communities(
|
|
1015
|
+
self, group_ids: list[str] | None = None, driver: GraphDriver | None = None
|
|
1016
|
+
) -> tuple[list[CommunityNode], list[CommunityEdge]]:
|
|
547
1017
|
"""
|
|
548
1018
|
Use a community clustering algorithm to find communities of nodes. Create community nodes summarising
|
|
549
1019
|
the content of these communities.
|
|
550
1020
|
----------
|
|
551
|
-
|
|
1021
|
+
group_ids : list[str] | None
|
|
552
1022
|
Optional. Create communities only for the listed group_ids. If blank the entire graph will be used.
|
|
553
1023
|
"""
|
|
1024
|
+
if driver is None:
|
|
1025
|
+
driver = self.clients.driver
|
|
1026
|
+
|
|
554
1027
|
# Clear existing communities
|
|
555
|
-
await remove_communities(
|
|
1028
|
+
await remove_communities(driver)
|
|
556
1029
|
|
|
557
1030
|
community_nodes, community_edges = await build_communities(
|
|
558
|
-
|
|
1031
|
+
driver, self.llm_client, group_ids
|
|
559
1032
|
)
|
|
560
1033
|
|
|
561
1034
|
await semaphore_gather(
|
|
562
|
-
*[node.generate_name_embedding(self.embedder) for node in community_nodes]
|
|
1035
|
+
*[node.generate_name_embedding(self.embedder) for node in community_nodes],
|
|
1036
|
+
max_coroutines=self.max_coroutines,
|
|
563
1037
|
)
|
|
564
1038
|
|
|
565
|
-
await semaphore_gather(
|
|
566
|
-
|
|
1039
|
+
await semaphore_gather(
|
|
1040
|
+
*[node.save(driver) for node in community_nodes],
|
|
1041
|
+
max_coroutines=self.max_coroutines,
|
|
1042
|
+
)
|
|
1043
|
+
await semaphore_gather(
|
|
1044
|
+
*[edge.save(driver) for edge in community_edges],
|
|
1045
|
+
max_coroutines=self.max_coroutines,
|
|
1046
|
+
)
|
|
567
1047
|
|
|
568
|
-
return community_nodes
|
|
1048
|
+
return community_nodes, community_edges
|
|
569
1049
|
|
|
1050
|
+
@handle_multiple_group_ids
|
|
570
1051
|
async def search(
|
|
571
1052
|
self,
|
|
572
1053
|
query: str,
|
|
@@ -574,6 +1055,7 @@ class Graphiti:
|
|
|
574
1055
|
group_ids: list[str] | None = None,
|
|
575
1056
|
num_results=DEFAULT_SEARCH_LIMIT,
|
|
576
1057
|
search_filter: SearchFilters | None = None,
|
|
1058
|
+
driver: GraphDriver | None = None,
|
|
577
1059
|
) -> list[EntityEdge]:
|
|
578
1060
|
"""
|
|
579
1061
|
Perform a hybrid search on the knowledge graph.
|
|
@@ -620,7 +1102,8 @@ class Graphiti:
|
|
|
620
1102
|
group_ids,
|
|
621
1103
|
search_config,
|
|
622
1104
|
search_filter if search_filter is not None else SearchFilters(),
|
|
623
|
-
|
|
1105
|
+
driver=driver,
|
|
1106
|
+
center_node_uuid=center_node_uuid,
|
|
624
1107
|
)
|
|
625
1108
|
).edges
|
|
626
1109
|
|
|
@@ -640,6 +1123,7 @@ class Graphiti:
|
|
|
640
1123
|
query, config, group_ids, center_node_uuid, bfs_origin_node_uuids, search_filter
|
|
641
1124
|
)
|
|
642
1125
|
|
|
1126
|
+
@handle_multiple_group_ids
|
|
643
1127
|
async def search_(
|
|
644
1128
|
self,
|
|
645
1129
|
query: str,
|
|
@@ -648,6 +1132,7 @@ class Graphiti:
|
|
|
648
1132
|
center_node_uuid: str | None = None,
|
|
649
1133
|
bfs_origin_node_uuids: list[str] | None = None,
|
|
650
1134
|
search_filter: SearchFilters | None = None,
|
|
1135
|
+
driver: GraphDriver | None = None,
|
|
651
1136
|
) -> SearchResults:
|
|
652
1137
|
"""search_ (replaces _search) is our advanced search method that returns Graph objects (nodes and edges) rather
|
|
653
1138
|
than a list of facts. This endpoint allows the end user to utilize more advanced features such as filters and
|
|
@@ -664,22 +1149,26 @@ class Graphiti:
|
|
|
664
1149
|
search_filter if search_filter is not None else SearchFilters(),
|
|
665
1150
|
center_node_uuid,
|
|
666
1151
|
bfs_origin_node_uuids,
|
|
1152
|
+
driver=driver,
|
|
667
1153
|
)
|
|
668
1154
|
|
|
669
1155
|
async def get_nodes_and_edges_by_episode(self, episode_uuids: list[str]) -> SearchResults:
|
|
670
1156
|
episodes = await EpisodicNode.get_by_uuids(self.driver, episode_uuids)
|
|
671
1157
|
|
|
672
1158
|
edges_list = await semaphore_gather(
|
|
673
|
-
*[EntityEdge.get_by_uuids(self.driver, episode.entity_edges) for episode in episodes]
|
|
1159
|
+
*[EntityEdge.get_by_uuids(self.driver, episode.entity_edges) for episode in episodes],
|
|
1160
|
+
max_coroutines=self.max_coroutines,
|
|
674
1161
|
)
|
|
675
1162
|
|
|
676
1163
|
edges: list[EntityEdge] = [edge for lst in edges_list for edge in lst]
|
|
677
1164
|
|
|
678
1165
|
nodes = await get_mentioned_nodes(self.driver, episodes)
|
|
679
1166
|
|
|
680
|
-
return SearchResults(edges=edges, nodes=nodes
|
|
1167
|
+
return SearchResults(edges=edges, nodes=nodes)
|
|
681
1168
|
|
|
682
|
-
async def add_triplet(
|
|
1169
|
+
async def add_triplet(
|
|
1170
|
+
self, source_node: EntityNode, edge: EntityEdge, target_node: EntityNode
|
|
1171
|
+
) -> AddTripletResults:
|
|
683
1172
|
if source_node.name_embedding is None:
|
|
684
1173
|
await source_node.generate_name_embedding(self.embedder)
|
|
685
1174
|
if target_node.name_embedding is None:
|
|
@@ -687,19 +1176,37 @@ class Graphiti:
|
|
|
687
1176
|
if edge.fact_embedding is None:
|
|
688
1177
|
await edge.generate_embedding(self.embedder)
|
|
689
1178
|
|
|
690
|
-
|
|
1179
|
+
nodes, uuid_map, _ = await resolve_extracted_nodes(
|
|
691
1180
|
self.clients,
|
|
692
1181
|
[source_node, target_node],
|
|
693
1182
|
)
|
|
694
1183
|
|
|
695
1184
|
updated_edge = resolve_edge_pointers([edge], uuid_map)[0]
|
|
696
1185
|
|
|
697
|
-
|
|
1186
|
+
valid_edges = await EntityEdge.get_between_nodes(
|
|
1187
|
+
self.driver, edge.source_node_uuid, edge.target_node_uuid
|
|
1188
|
+
)
|
|
1189
|
+
|
|
1190
|
+
related_edges = (
|
|
1191
|
+
await search(
|
|
1192
|
+
self.clients,
|
|
1193
|
+
updated_edge.fact,
|
|
1194
|
+
group_ids=[updated_edge.group_id],
|
|
1195
|
+
config=EDGE_HYBRID_SEARCH_RRF,
|
|
1196
|
+
search_filter=SearchFilters(edge_uuids=[edge.uuid for edge in valid_edges]),
|
|
1197
|
+
)
|
|
1198
|
+
).edges
|
|
698
1199
|
existing_edges = (
|
|
699
|
-
await
|
|
700
|
-
|
|
1200
|
+
await search(
|
|
1201
|
+
self.clients,
|
|
1202
|
+
updated_edge.fact,
|
|
1203
|
+
group_ids=[updated_edge.group_id],
|
|
1204
|
+
config=EDGE_HYBRID_SEARCH_RRF,
|
|
1205
|
+
search_filter=SearchFilters(),
|
|
1206
|
+
)
|
|
1207
|
+
).edges
|
|
701
1208
|
|
|
702
|
-
resolved_edge, invalidated_edges = await resolve_extracted_edge(
|
|
1209
|
+
resolved_edge, invalidated_edges, _ = await resolve_extracted_edge(
|
|
703
1210
|
self.llm_client,
|
|
704
1211
|
updated_edge,
|
|
705
1212
|
related_edges,
|
|
@@ -713,11 +1220,17 @@ class Graphiti:
|
|
|
713
1220
|
entity_edges=[],
|
|
714
1221
|
group_id=edge.group_id,
|
|
715
1222
|
),
|
|
1223
|
+
None,
|
|
1224
|
+
None,
|
|
716
1225
|
)
|
|
717
1226
|
|
|
718
|
-
|
|
719
|
-
|
|
720
|
-
)
|
|
1227
|
+
edges: list[EntityEdge] = [resolved_edge] + invalidated_edges
|
|
1228
|
+
|
|
1229
|
+
await create_entity_edge_embeddings(self.embedder, edges)
|
|
1230
|
+
await create_entity_node_embeddings(self.embedder, nodes)
|
|
1231
|
+
|
|
1232
|
+
await add_nodes_and_edges_bulk(self.driver, [], [], nodes, edges, self.embedder)
|
|
1233
|
+
return AddTripletResults(edges=edges, nodes=nodes)
|
|
721
1234
|
|
|
722
1235
|
async def remove_episode(self, episode_uuid: str):
|
|
723
1236
|
# Find the episode to be deleted
|
|
@@ -738,14 +1251,13 @@ class Graphiti:
|
|
|
738
1251
|
nodes_to_delete: list[EntityNode] = []
|
|
739
1252
|
for node in nodes:
|
|
740
1253
|
query: LiteralString = 'MATCH (e:Episodic)-[:MENTIONS]->(n:Entity {uuid: $uuid}) RETURN count(*) AS episode_count'
|
|
741
|
-
records, _, _ = await self.driver.execute_query(
|
|
742
|
-
query, uuid=node.uuid, database_=DEFAULT_DATABASE, routing_='r'
|
|
743
|
-
)
|
|
1254
|
+
records, _, _ = await self.driver.execute_query(query, uuid=node.uuid, routing_='r')
|
|
744
1255
|
|
|
745
1256
|
for record in records:
|
|
746
1257
|
if record['episode_count'] == 1:
|
|
747
1258
|
nodes_to_delete.append(node)
|
|
748
1259
|
|
|
749
|
-
await
|
|
750
|
-
await
|
|
1260
|
+
await Edge.delete_by_uuids(self.driver, [edge.uuid for edge in edges_to_delete])
|
|
1261
|
+
await Node.delete_by_uuids(self.driver, [node.uuid for node in nodes_to_delete])
|
|
1262
|
+
|
|
751
1263
|
await episode.delete(self.driver)
|