graphiti-core 0.12.0rc1__py3-none-any.whl → 0.24.3__py3-none-any.whl

This diff compares the contents of two package versions as published to a supported public registry. It is provided for informational purposes only and reflects the packages exactly as they appear in that registry.
Files changed (68)
  1. graphiti_core/cross_encoder/bge_reranker_client.py +12 -2
  2. graphiti_core/cross_encoder/gemini_reranker_client.py +161 -0
  3. graphiti_core/cross_encoder/openai_reranker_client.py +7 -5
  4. graphiti_core/decorators.py +110 -0
  5. graphiti_core/driver/__init__.py +19 -0
  6. graphiti_core/driver/driver.py +124 -0
  7. graphiti_core/driver/falkordb_driver.py +362 -0
  8. graphiti_core/driver/graph_operations/graph_operations.py +191 -0
  9. graphiti_core/driver/kuzu_driver.py +182 -0
  10. graphiti_core/driver/neo4j_driver.py +117 -0
  11. graphiti_core/driver/neptune_driver.py +305 -0
  12. graphiti_core/driver/search_interface/search_interface.py +89 -0
  13. graphiti_core/edges.py +287 -172
  14. graphiti_core/embedder/azure_openai.py +71 -0
  15. graphiti_core/embedder/client.py +2 -1
  16. graphiti_core/embedder/gemini.py +116 -22
  17. graphiti_core/embedder/voyage.py +13 -2
  18. graphiti_core/errors.py +8 -0
  19. graphiti_core/graph_queries.py +162 -0
  20. graphiti_core/graphiti.py +705 -193
  21. graphiti_core/graphiti_types.py +4 -2
  22. graphiti_core/helpers.py +87 -10
  23. graphiti_core/llm_client/__init__.py +16 -0
  24. graphiti_core/llm_client/anthropic_client.py +159 -56
  25. graphiti_core/llm_client/azure_openai_client.py +115 -0
  26. graphiti_core/llm_client/client.py +98 -21
  27. graphiti_core/llm_client/config.py +1 -1
  28. graphiti_core/llm_client/gemini_client.py +290 -41
  29. graphiti_core/llm_client/groq_client.py +14 -3
  30. graphiti_core/llm_client/openai_base_client.py +261 -0
  31. graphiti_core/llm_client/openai_client.py +56 -132
  32. graphiti_core/llm_client/openai_generic_client.py +91 -56
  33. graphiti_core/models/edges/edge_db_queries.py +259 -35
  34. graphiti_core/models/nodes/node_db_queries.py +311 -32
  35. graphiti_core/nodes.py +420 -205
  36. graphiti_core/prompts/dedupe_edges.py +46 -32
  37. graphiti_core/prompts/dedupe_nodes.py +67 -42
  38. graphiti_core/prompts/eval.py +4 -4
  39. graphiti_core/prompts/extract_edges.py +27 -16
  40. graphiti_core/prompts/extract_nodes.py +74 -31
  41. graphiti_core/prompts/prompt_helpers.py +39 -0
  42. graphiti_core/prompts/snippets.py +29 -0
  43. graphiti_core/prompts/summarize_nodes.py +23 -25
  44. graphiti_core/search/search.py +158 -82
  45. graphiti_core/search/search_config.py +39 -4
  46. graphiti_core/search/search_filters.py +126 -35
  47. graphiti_core/search/search_helpers.py +5 -6
  48. graphiti_core/search/search_utils.py +1405 -485
  49. graphiti_core/telemetry/__init__.py +9 -0
  50. graphiti_core/telemetry/telemetry.py +117 -0
  51. graphiti_core/tracer.py +193 -0
  52. graphiti_core/utils/bulk_utils.py +364 -285
  53. graphiti_core/utils/datetime_utils.py +13 -0
  54. graphiti_core/utils/maintenance/community_operations.py +67 -49
  55. graphiti_core/utils/maintenance/dedup_helpers.py +262 -0
  56. graphiti_core/utils/maintenance/edge_operations.py +339 -197
  57. graphiti_core/utils/maintenance/graph_data_operations.py +50 -114
  58. graphiti_core/utils/maintenance/node_operations.py +319 -238
  59. graphiti_core/utils/maintenance/temporal_operations.py +11 -3
  60. graphiti_core/utils/ontology_utils/entity_types_utils.py +1 -1
  61. graphiti_core/utils/text_utils.py +53 -0
  62. graphiti_core-0.24.3.dist-info/METADATA +726 -0
  63. graphiti_core-0.24.3.dist-info/RECORD +86 -0
  64. {graphiti_core-0.12.0rc1.dist-info → graphiti_core-0.24.3.dist-info}/WHEEL +1 -1
  65. graphiti_core-0.12.0rc1.dist-info/METADATA +0 -350
  66. graphiti_core-0.12.0rc1.dist-info/RECORD +0 -66
  67. /graphiti_core/{utils/maintenance/utils.py → migrations/__init__.py} +0 -0
  68. {graphiti_core-0.12.0rc1.dist-info → graphiti_core-0.24.3.dist-info/licenses}/LICENSE +0 -0
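The largest change in this release range is in graphiti_core/graphiti.py: the Graphiti constructor no longer builds a neo4j AsyncGraphDatabase driver directly, but accepts any GraphDriver implementation (Neo4jDriver by default, with new FalkorDB, Kuzu, and Neptune driver modules added alongside it). The sketch below illustrates the constructor change using only the signatures visible in the diff that follows; the URI and credentials are placeholders.

# Hedged sketch of the 0.12.0rc1 -> 0.24.3 constructor change; connection details are placeholders.
from graphiti_core.graphiti import Graphiti
from graphiti_core.driver.neo4j_driver import Neo4jDriver

# 0.12.0rc1: uri, user, and password were required positional arguments.
client = Graphiti('bolt://localhost:7687', 'neo4j', 'password')

# 0.24.3: a GraphDriver can be injected instead, and uri/user/password become optional.
driver = Neo4jDriver('bolt://localhost:7687', 'neo4j', 'password')
client = Graphiti(graph_driver=driver)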
graphiti_core/graphiti.py CHANGED
@@ -19,18 +19,38 @@ from datetime import datetime
 from time import time
 
 from dotenv import load_dotenv
-from neo4j import AsyncGraphDatabase
 from pydantic import BaseModel
 from typing_extensions import LiteralString
 
 from graphiti_core.cross_encoder.client import CrossEncoderClient
 from graphiti_core.cross_encoder.openai_reranker_client import OpenAIRerankerClient
-from graphiti_core.edges import EntityEdge, EpisodicEdge
+from graphiti_core.decorators import handle_multiple_group_ids
+from graphiti_core.driver.driver import GraphDriver
+from graphiti_core.driver.neo4j_driver import Neo4jDriver
+from graphiti_core.edges import (
+    CommunityEdge,
+    Edge,
+    EntityEdge,
+    EpisodicEdge,
+    create_entity_edge_embeddings,
+)
 from graphiti_core.embedder import EmbedderClient, OpenAIEmbedder
 from graphiti_core.graphiti_types import GraphitiClients
-from graphiti_core.helpers import DEFAULT_DATABASE, semaphore_gather
+from graphiti_core.helpers import (
+    get_default_group_id,
+    semaphore_gather,
+    validate_excluded_entity_types,
+    validate_group_id,
+)
 from graphiti_core.llm_client import LLMClient, OpenAIClient
-from graphiti_core.nodes import CommunityNode, EntityNode, EpisodeType, EpisodicNode
+from graphiti_core.nodes import (
+    CommunityNode,
+    EntityNode,
+    EpisodeType,
+    EpisodicNode,
+    Node,
+    create_entity_node_embeddings,
+)
 from graphiti_core.search.search import SearchConfig, search
 from graphiti_core.search.search_config import DEFAULT_SEARCH_LIMIT, SearchResults
 from graphiti_core.search.search_config_recipes import (
@@ -41,16 +61,15 @@ from graphiti_core.search.search_config_recipes import (
 from graphiti_core.search.search_filters import SearchFilters
 from graphiti_core.search.search_utils import (
     RELEVANT_SCHEMA_LIMIT,
-    get_edge_invalidation_candidates,
     get_mentioned_nodes,
-    get_relevant_edges,
 )
+from graphiti_core.telemetry import capture_event
+from graphiti_core.tracer import Tracer, create_tracer
 from graphiti_core.utils.bulk_utils import (
     RawEpisode,
     add_nodes_and_edges_bulk,
     dedupe_edges_bulk,
     dedupe_nodes_bulk,
-    extract_edge_dates_bulk,
     extract_nodes_and_edges_bulk,
     resolve_edge_pointers,
     retrieve_previous_episodes_bulk,
@@ -69,7 +88,6 @@ from graphiti_core.utils.maintenance.edge_operations import (
 )
 from graphiti_core.utils.maintenance.graph_data_operations import (
     EPISODE_WINDOW_LEN,
-    build_indices_and_constraints,
     retrieve_episodes,
 )
 from graphiti_core.utils.maintenance.node_operations import (
@@ -86,6 +104,23 @@ load_dotenv()
 
 class AddEpisodeResults(BaseModel):
     episode: EpisodicNode
+    episodic_edges: list[EpisodicEdge]
+    nodes: list[EntityNode]
+    edges: list[EntityEdge]
+    communities: list[CommunityNode]
+    community_edges: list[CommunityEdge]
+
+
+class AddBulkEpisodeResults(BaseModel):
+    episodes: list[EpisodicNode]
+    episodic_edges: list[EpisodicEdge]
+    nodes: list[EntityNode]
+    edges: list[EntityEdge]
+    communities: list[CommunityNode]
+    community_edges: list[CommunityEdge]
+
+
+class AddTripletResults(BaseModel):
     nodes: list[EntityNode]
     edges: list[EntityEdge]
 
@@ -93,18 +128,22 @@ class AddEpisodeResults(BaseModel):
 class Graphiti:
     def __init__(
         self,
-        uri: str,
-        user: str,
-        password: str,
+        uri: str | None = None,
+        user: str | None = None,
+        password: str | None = None,
         llm_client: LLMClient | None = None,
         embedder: EmbedderClient | None = None,
         cross_encoder: CrossEncoderClient | None = None,
         store_raw_episode_content: bool = True,
+        graph_driver: GraphDriver | None = None,
+        max_coroutines: int | None = None,
+        tracer: Tracer | None = None,
+        trace_span_prefix: str = 'graphiti',
     ):
         """
         Initialize a Graphiti instance.
 
-        This constructor sets up a connection to the Neo4j database and initializes
+        This constructor sets up a connection to a graph database and initializes
         the LLM client for natural language processing tasks.
 
         Parameters
@@ -118,6 +157,24 @@ class Graphiti:
         llm_client : LLMClient | None, optional
             An instance of LLMClient for natural language processing tasks.
             If not provided, a default OpenAIClient will be initialized.
+        embedder : EmbedderClient | None, optional
+            An instance of EmbedderClient for embedding tasks.
+            If not provided, a default OpenAIEmbedder will be initialized.
+        cross_encoder : CrossEncoderClient | None, optional
+            An instance of CrossEncoderClient for reranking tasks.
+            If not provided, a default OpenAIRerankerClient will be initialized.
+        store_raw_episode_content : bool, optional
+            Whether to store the raw content of episodes. Defaults to True.
+        graph_driver : GraphDriver | None, optional
+            An instance of GraphDriver for database operations.
+            If not provided, a default Neo4jDriver will be initialized.
+        max_coroutines : int | None, optional
+            The maximum number of concurrent operations allowed. Overrides SEMAPHORE_LIMIT set in the environment.
+            If not set, the Graphiti default is used.
+        tracer : Tracer | None, optional
+            An OpenTelemetry tracer instance for distributed tracing. If not provided, tracing is disabled (no-op).
+        trace_span_prefix : str, optional
+            Prefix to prepend to all span names. Defaults to 'graphiti'.
 
         Returns
         -------
@@ -125,11 +182,11 @@ class Graphiti:
 
         Notes
         -----
-        This method establishes a connection to the Neo4j database using the provided
+        This method establishes a connection to a graph database (Neo4j by default) using the provided
         credentials. It also sets up the LLM client, either using the provided client
         or by creating a default OpenAIClient.
 
-        The default database name is set to 'neo4j'. If a different database name
+        The default database name is defined during the driver's construction. If a different database name
         is required, it should be specified in the URI or set separately after
         initialization.
 
@@ -137,9 +194,16 @@ class Graphiti:
         Make sure to set the OPENAI_API_KEY environment variable before initializing
         Graphiti if you're using the default OpenAIClient.
         """
-        self.driver = AsyncGraphDatabase.driver(uri, auth=(user, password))
-        self.database = DEFAULT_DATABASE
+
+        if graph_driver:
+            self.driver = graph_driver
+        else:
+            if uri is None:
+                raise ValueError('uri must be provided when graph_driver is None')
+            self.driver = Neo4jDriver(uri, user, password)
+
         self.store_raw_episode_content = store_raw_episode_content
+        self.max_coroutines = max_coroutines
         if llm_client:
             self.llm_client = llm_client
         else:
@@ -153,13 +217,75 @@ class Graphiti:
         else:
             self.cross_encoder = OpenAIRerankerClient()
 
+        # Initialize tracer
+        self.tracer = create_tracer(tracer, trace_span_prefix)
+
+        # Set tracer on clients
+        self.llm_client.set_tracer(self.tracer)
+
         self.clients = GraphitiClients(
             driver=self.driver,
             llm_client=self.llm_client,
             embedder=self.embedder,
             cross_encoder=self.cross_encoder,
+            tracer=self.tracer,
         )
 
+        # Capture telemetry event
+        self._capture_initialization_telemetry()
+
+    def _capture_initialization_telemetry(self):
+        """Capture telemetry event for Graphiti initialization."""
+        try:
+            # Detect provider types from class names
+            llm_provider = self._get_provider_type(self.llm_client)
+            embedder_provider = self._get_provider_type(self.embedder)
+            reranker_provider = self._get_provider_type(self.cross_encoder)
+            database_provider = self._get_provider_type(self.driver)
+
+            properties = {
+                'llm_provider': llm_provider,
+                'embedder_provider': embedder_provider,
+                'reranker_provider': reranker_provider,
+                'database_provider': database_provider,
+            }
+
+            capture_event('graphiti_initialized', properties)
+        except Exception:
+            # Silently handle telemetry errors
+            pass
+
+    def _get_provider_type(self, client) -> str:
+        """Get provider type from client class name."""
+        if client is None:
+            return 'none'
+
+        class_name = client.__class__.__name__.lower()
+
+        # LLM providers
+        if 'openai' in class_name:
+            return 'openai'
+        elif 'azure' in class_name:
+            return 'azure'
+        elif 'anthropic' in class_name:
+            return 'anthropic'
+        elif 'crossencoder' in class_name:
+            return 'crossencoder'
+        elif 'gemini' in class_name:
+            return 'gemini'
+        elif 'groq' in class_name:
+            return 'groq'
+        # Database providers
+        elif 'neo4j' in class_name:
+            return 'neo4j'
+        elif 'falkor' in class_name:
+            return 'falkordb'
+        # Embedder providers
+        elif 'voyage' in class_name:
+            return 'voyage'
+        else:
+            return 'unknown'
+
     async def close(self):
         """
         Close the connection to the Neo4j database.
@@ -214,25 +340,247 @@ class Graphiti:
         -----
         This method should typically be called once during the initial setup of the
         knowledge graph or when updating the database schema. It uses the
-        `build_indices_and_constraints` function from the
-        `graphiti_core.utils.maintenance.graph_data_operations` module to perform
+        driver's `build_indices_and_constraints` method to perform
         the actual database operations.
 
         The specific indices and constraints created depend on the implementation
-        of the `build_indices_and_constraints` function. Refer to that function's
-        documentation for details on the exact database schema modifications.
+        of the driver's `build_indices_and_constraints` method. Refer to the specific
+        driver documentation for details on the exact database schema modifications.
 
         Caution: Running this method on a large existing database may take some time
         and could impact database performance during execution.
         """
-        await build_indices_and_constraints(self.driver, delete_existing)
+        await self.driver.build_indices_and_constraints(delete_existing)
+
+    async def _extract_and_resolve_nodes(
+        self,
+        episode: EpisodicNode,
+        previous_episodes: list[EpisodicNode],
+        entity_types: dict[str, type[BaseModel]] | None,
+        excluded_entity_types: list[str] | None,
+    ) -> tuple[list[EntityNode], dict[str, str], list[tuple[EntityNode, EntityNode]]]:
+        """Extract nodes from episode and resolve against existing graph."""
+        extracted_nodes = await extract_nodes(
+            self.clients, episode, previous_episodes, entity_types, excluded_entity_types
+        )
+
+        nodes, uuid_map, duplicates = await resolve_extracted_nodes(
+            self.clients,
+            extracted_nodes,
+            episode,
+            previous_episodes,
+            entity_types,
+        )
+
+        return nodes, uuid_map, duplicates
+
+    async def _extract_and_resolve_edges(
+        self,
+        episode: EpisodicNode,
+        extracted_nodes: list[EntityNode],
+        previous_episodes: list[EpisodicNode],
+        edge_type_map: dict[tuple[str, str], list[str]],
+        group_id: str,
+        edge_types: dict[str, type[BaseModel]] | None,
+        nodes: list[EntityNode],
+        uuid_map: dict[str, str],
+    ) -> tuple[list[EntityEdge], list[EntityEdge]]:
+        """Extract edges from episode and resolve against existing graph."""
+        extracted_edges = await extract_edges(
+            self.clients,
+            episode,
+            extracted_nodes,
+            previous_episodes,
+            edge_type_map,
+            group_id,
+            edge_types,
+        )
+
+        edges = resolve_edge_pointers(extracted_edges, uuid_map)
+
+        resolved_edges, invalidated_edges = await resolve_extracted_edges(
+            self.clients,
+            edges,
+            episode,
+            nodes,
+            edge_types or {},
+            edge_type_map,
+        )
+
+        return resolved_edges, invalidated_edges
+
+    async def _process_episode_data(
+        self,
+        episode: EpisodicNode,
+        nodes: list[EntityNode],
+        entity_edges: list[EntityEdge],
+        now: datetime,
+    ) -> tuple[list[EpisodicEdge], EpisodicNode]:
+        """Process and save episode data to the graph."""
+        episodic_edges = build_episodic_edges(nodes, episode.uuid, now)
+        episode.entity_edges = [edge.uuid for edge in entity_edges]
+
+        if not self.store_raw_episode_content:
+            episode.content = ''
+
+        await add_nodes_and_edges_bulk(
+            self.driver,
+            [episode],
+            episodic_edges,
+            nodes,
+            entity_edges,
+            self.embedder,
+        )
+
+        return episodic_edges, episode
+
+    async def _extract_and_dedupe_nodes_bulk(
+        self,
+        episode_context: list[tuple[EpisodicNode, list[EpisodicNode]]],
+        edge_type_map: dict[tuple[str, str], list[str]],
+        edge_types: dict[str, type[BaseModel]] | None,
+        entity_types: dict[str, type[BaseModel]] | None,
+        excluded_entity_types: list[str] | None,
+    ) -> tuple[
+        dict[str, list[EntityNode]],
+        dict[str, str],
+        list[list[EntityEdge]],
+    ]:
+        """Extract nodes and edges from all episodes and deduplicate."""
+        # Extract all nodes and edges for each episode
+        extracted_nodes_bulk, extracted_edges_bulk = await extract_nodes_and_edges_bulk(
+            self.clients,
+            episode_context,
+            edge_type_map=edge_type_map,
+            edge_types=edge_types,
+            entity_types=entity_types,
+            excluded_entity_types=excluded_entity_types,
+        )
+
+        # Dedupe extracted nodes in memory
+        nodes_by_episode, uuid_map = await dedupe_nodes_bulk(
+            self.clients, extracted_nodes_bulk, episode_context, entity_types
+        )
+
+        return nodes_by_episode, uuid_map, extracted_edges_bulk
+
+    async def _resolve_nodes_and_edges_bulk(
+        self,
+        nodes_by_episode: dict[str, list[EntityNode]],
+        edges_by_episode: dict[str, list[EntityEdge]],
+        episode_context: list[tuple[EpisodicNode, list[EpisodicNode]]],
+        entity_types: dict[str, type[BaseModel]] | None,
+        edge_types: dict[str, type[BaseModel]] | None,
+        edge_type_map: dict[tuple[str, str], list[str]],
+        episodes: list[EpisodicNode],
+    ) -> tuple[list[EntityNode], list[EntityEdge], list[EntityEdge], dict[str, str]]:
+        """Resolve nodes and edges against the existing graph."""
+        nodes_by_uuid: dict[str, EntityNode] = {
+            node.uuid: node for nodes in nodes_by_episode.values() for node in nodes
+        }
+
+        # Get unique nodes per episode
+        nodes_by_episode_unique: dict[str, list[EntityNode]] = {}
+        nodes_uuid_set: set[str] = set()
+        for episode, _ in episode_context:
+            nodes_by_episode_unique[episode.uuid] = []
+            nodes = [nodes_by_uuid[node.uuid] for node in nodes_by_episode[episode.uuid]]
+            for node in nodes:
+                if node.uuid not in nodes_uuid_set:
+                    nodes_by_episode_unique[episode.uuid].append(node)
+                    nodes_uuid_set.add(node.uuid)
+
+        # Resolve nodes
+        node_results = await semaphore_gather(
+            *[
+                resolve_extracted_nodes(
+                    self.clients,
+                    nodes_by_episode_unique[episode.uuid],
+                    episode,
+                    previous_episodes,
+                    entity_types,
+                )
+                for episode, previous_episodes in episode_context
+            ]
+        )
+
+        resolved_nodes: list[EntityNode] = []
+        uuid_map: dict[str, str] = {}
+        for result in node_results:
+            resolved_nodes.extend(result[0])
+            uuid_map.update(result[1])
+
+        # Update nodes_by_uuid with resolved nodes
+        for resolved_node in resolved_nodes:
+            nodes_by_uuid[resolved_node.uuid] = resolved_node
+
+        # Update nodes_by_episode_unique with resolved pointers
+        for episode_uuid, nodes in nodes_by_episode_unique.items():
+            updated_nodes: list[EntityNode] = []
+            for node in nodes:
+                updated_node_uuid = uuid_map.get(node.uuid, node.uuid)
+                updated_node = nodes_by_uuid[updated_node_uuid]
+                updated_nodes.append(updated_node)
+            nodes_by_episode_unique[episode_uuid] = updated_nodes
+
+        # Extract attributes for resolved nodes
+        hydrated_nodes_results: list[list[EntityNode]] = await semaphore_gather(
+            *[
+                extract_attributes_from_nodes(
+                    self.clients,
+                    nodes_by_episode_unique[episode.uuid],
+                    episode,
+                    previous_episodes,
+                    entity_types,
+                )
+                for episode, previous_episodes in episode_context
+            ]
+        )
+
+        final_hydrated_nodes = [node for nodes in hydrated_nodes_results for node in nodes]
+
+        # Resolve edges with updated pointers
+        edges_by_episode_unique: dict[str, list[EntityEdge]] = {}
+        edges_uuid_set: set[str] = set()
+        for episode_uuid, edges in edges_by_episode.items():
+            edges_with_updated_pointers = resolve_edge_pointers(edges, uuid_map)
+            edges_by_episode_unique[episode_uuid] = []
+
+            for edge in edges_with_updated_pointers:
+                if edge.uuid not in edges_uuid_set:
+                    edges_by_episode_unique[episode_uuid].append(edge)
+                    edges_uuid_set.add(edge.uuid)
+
+        edge_results = await semaphore_gather(
+            *[
+                resolve_extracted_edges(
+                    self.clients,
+                    edges_by_episode_unique[episode.uuid],
+                    episode,
+                    final_hydrated_nodes,
+                    edge_types or {},
+                    edge_type_map,
+                )
+                for episode in episodes
+            ]
+        )
+
+        resolved_edges: list[EntityEdge] = []
+        invalidated_edges: list[EntityEdge] = []
+        for result in edge_results:
+            resolved_edges.extend(result[0])
+            invalidated_edges.extend(result[1])
+
+        return final_hydrated_nodes, resolved_edges, invalidated_edges, uuid_map
 
+    @handle_multiple_group_ids
     async def retrieve_episodes(
         self,
         reference_time: datetime,
         last_n: int = EPISODE_WINDOW_LEN,
         group_ids: list[str] | None = None,
         source: EpisodeType | None = None,
+        driver: GraphDriver | None = None,
     ) -> list[EpisodicNode]:
         """
         Retrieve the last n episodic nodes from the graph.
@@ -259,7 +607,10 @@ class Graphiti:
         The actual retrieval is performed by the `retrieve_episodes` function
         from the `graphiti_core.utils` module.
         """
-        return await retrieve_episodes(self.driver, reference_time, last_n, group_ids, source)
+        if driver is None:
+            driver = self.clients.driver
+
+        return await retrieve_episodes(driver, reference_time, last_n, group_ids, source)
 
     async def add_episode(
         self,
@@ -268,12 +619,13 @@ class Graphiti:
         source_description: str,
         reference_time: datetime,
         source: EpisodeType = EpisodeType.message,
-        group_id: str = '',
+        group_id: str | None = None,
         uuid: str | None = None,
         update_communities: bool = False,
-        entity_types: dict[str, BaseModel] | None = None,
+        entity_types: dict[str, type[BaseModel]] | None = None,
+        excluded_entity_types: list[str] | None = None,
         previous_episode_uuids: list[str] | None = None,
-        edge_types: dict[str, BaseModel] | None = None,
+        edge_types: dict[str, type[BaseModel]] | None = None,
         edge_type_map: dict[tuple[str, str], list[str]] | None = None,
     ) -> AddEpisodeResults:
         """
@@ -300,6 +652,12 @@ class Graphiti:
             Optional uuid of the episode.
         update_communities : bool
             Optional. Whether to update communities with new node information
+        entity_types : dict[str, BaseModel] | None
+            Optional. Dictionary mapping entity type names to their Pydantic model definitions.
+        excluded_entity_types : list[str] | None
+            Optional. List of entity type names to exclude from the graph. Entities classified
+            into these types will not be added to the graph. Can include 'Entity' to exclude
+            the default entity type.
         previous_episode_uuids : list[str] | None
             Optional. list of episode uuids to use as the previous episodes. If this is not provided,
             the most recent episodes by created_at date will be used.
@@ -325,112 +683,155 @@ class Graphiti:
             background_tasks.add_task(graphiti.add_episode, **episode_data.dict())
             return {"message": "Episode processing started"}
         """
-        try:
-            start = time()
-            now = utc_now()
+        start = time()
+        now = utc_now()
 
-            validate_entity_types(entity_types)
+        validate_entity_types(entity_types)
+        validate_excluded_entity_types(excluded_entity_types, entity_types)
 
-            previous_episodes = (
-                await self.retrieve_episodes(
-                    reference_time,
-                    last_n=RELEVANT_SCHEMA_LIMIT,
-                    group_ids=[group_id],
-                    source=source,
-                )
-                if previous_episode_uuids is None
-                else await EpisodicNode.get_by_uuids(self.driver, previous_episode_uuids)
-            )
+        if group_id is None:
+            # if group_id is None, use the default group id by the provider
+            # and the preset database name will be used
+            group_id = get_default_group_id(self.driver.provider)
+        else:
+            validate_group_id(group_id)
+            if group_id != self.driver._database:
+                # if group_id is provided, use it as the database name
+                self.driver = self.driver.clone(database=group_id)
+                self.clients.driver = self.driver
 
-            episode = (
-                await EpisodicNode.get_by_uuid(self.driver, uuid)
-                if uuid is not None
-                else EpisodicNode(
-                    name=name,
-                    group_id=group_id,
-                    labels=[],
-                    source=source,
-                    content=episode_body,
-                    source_description=source_description,
-                    created_at=now,
-                    valid_at=reference_time,
+        with self.tracer.start_span('add_episode') as span:
+            try:
+                # Retrieve previous episodes for context
+                previous_episodes = (
+                    await self.retrieve_episodes(
+                        reference_time,
+                        last_n=RELEVANT_SCHEMA_LIMIT,
+                        group_ids=[group_id],
+                        source=source,
+                    )
+                    if previous_episode_uuids is None
+                    else await EpisodicNode.get_by_uuids(self.driver, previous_episode_uuids)
                 )
-            )
 
-            # Create default edge type map
-            edge_type_map_default = (
-                {('Entity', 'Entity'): list(edge_types.keys())}
-                if edge_types is not None
-                else {('Entity', 'Entity'): []}
-            )
+                # Get or create episode
+                episode = (
+                    await EpisodicNode.get_by_uuid(self.driver, uuid)
+                    if uuid is not None
+                    else EpisodicNode(
+                        name=name,
+                        group_id=group_id,
+                        labels=[],
+                        source=source,
+                        content=episode_body,
+                        source_description=source_description,
+                        created_at=now,
+                        valid_at=reference_time,
+                    )
+                )
 
-            # Extract entities as nodes
+                # Create default edge type map
+                edge_type_map_default = (
+                    {('Entity', 'Entity'): list(edge_types.keys())}
+                    if edge_types is not None
+                    else {('Entity', 'Entity'): []}
+                )
 
-            extracted_nodes = await extract_nodes(
-                self.clients, episode, previous_episodes, entity_types
-            )
+                # Extract and resolve nodes
+                extracted_nodes = await extract_nodes(
+                    self.clients, episode, previous_episodes, entity_types, excluded_entity_types
+                )
 
-            # Extract edges and resolve nodes
-            (nodes, uuid_map), extracted_edges = await semaphore_gather(
-                resolve_extracted_nodes(
+                nodes, uuid_map, _ = await resolve_extracted_nodes(
                     self.clients,
                     extracted_nodes,
                     episode,
                     previous_episodes,
                     entity_types,
-                ),
-                extract_edges(
-                    self.clients, episode, extracted_nodes, previous_episodes, group_id, edge_types
-                ),
-            )
-
-            edges = resolve_edge_pointers(extracted_edges, uuid_map)
+                )
 
-            (resolved_edges, invalidated_edges), hydrated_nodes = await semaphore_gather(
-                resolve_extracted_edges(
-                    self.clients,
-                    edges,
+                # Extract and resolve edges in parallel with attribute extraction
+                resolved_edges, invalidated_edges = await self._extract_and_resolve_edges(
                     episode,
-                    nodes,
-                    edge_types or {},
+                    extracted_nodes,
+                    previous_episodes,
                     edge_type_map or edge_type_map_default,
-                ),
-                extract_attributes_from_nodes(
-                    self.clients, nodes, episode, previous_episodes, entity_types
-                ),
-            )
+                    group_id,
+                    edge_types,
+                    nodes,
+                    uuid_map,
+                )
 
-            entity_edges = resolved_edges + invalidated_edges
+                # Extract node attributes
+                hydrated_nodes = await extract_attributes_from_nodes(
+                    self.clients, nodes, episode, previous_episodes, entity_types
+                )
 
-            episodic_edges = build_episodic_edges(nodes, episode, now)
+                entity_edges = resolved_edges + invalidated_edges
 
-            episode.entity_edges = [edge.uuid for edge in entity_edges]
+                # Process and save episode data
+                episodic_edges, episode = await self._process_episode_data(
+                    episode, hydrated_nodes, entity_edges, now
+                )
 
-            if not self.store_raw_episode_content:
-                episode.content = ''
+                # Update communities if requested
+                communities = []
+                community_edges = []
+                if update_communities:
+                    communities, community_edges = await semaphore_gather(
+                        *[
+                            update_community(self.driver, self.llm_client, self.embedder, node)
+                            for node in nodes
+                        ],
+                        max_coroutines=self.max_coroutines,
+                    )
+
+                end = time()
+
+                # Add span attributes
+                span.add_attributes(
+                    {
+                        'episode.uuid': episode.uuid,
+                        'episode.source': source.value,
+                        'episode.reference_time': reference_time.isoformat(),
+                        'group_id': group_id,
+                        'node.count': len(hydrated_nodes),
+                        'edge.count': len(entity_edges),
+                        'edge.invalidated_count': len(invalidated_edges),
+                        'previous_episodes.count': len(previous_episodes),
+                        'entity_types.count': len(entity_types) if entity_types else 0,
+                        'edge_types.count': len(edge_types) if edge_types else 0,
+                        'update_communities': update_communities,
+                        'communities.count': len(communities) if update_communities else 0,
+                        'duration_ms': (end - start) * 1000,
+                    }
+                )
 
-            await add_nodes_and_edges_bulk(
-                self.driver, [episode], episodic_edges, hydrated_nodes, entity_edges, self.embedder
-            )
+                logger.info(f'Completed add_episode in {(end - start) * 1000} ms')
 
-            # Update any communities
-            if update_communities:
-                await semaphore_gather(
-                    *[
-                        update_community(self.driver, self.llm_client, self.embedder, node)
-                        for node in nodes
-                    ]
+                return AddEpisodeResults(
+                    episode=episode,
+                    episodic_edges=episodic_edges,
+                    nodes=hydrated_nodes,
+                    edges=entity_edges,
+                    communities=communities,
+                    community_edges=community_edges,
                 )
-            end = time()
-            logger.info(f'Completed add_episode in {(end - start) * 1000} ms')
-
-            return AddEpisodeResults(episode=episode, nodes=nodes, edges=entity_edges)
 
-        except Exception as e:
-            raise e
+            except Exception as e:
+                span.set_status('error', str(e))
+                span.record_exception(e)
+                raise e
 
-    #### WIP: USE AT YOUR OWN RISK ####
-    async def add_episode_bulk(self, bulk_episodes: list[RawEpisode], group_id: str = ''):
+    async def add_episode_bulk(
+        self,
+        bulk_episodes: list[RawEpisode],
+        group_id: str | None = None,
+        entity_types: dict[str, type[BaseModel]] | None = None,
+        excluded_entity_types: list[str] | None = None,
+        edge_types: dict[str, type[BaseModel]] | None = None,
+        edge_type_map: dict[tuple[str, str], list[str]] | None = None,
+    ) -> AddBulkEpisodeResults:
         """
         Process multiple episodes in bulk and update the graph.
 
@@ -446,7 +847,7 @@ class Graphiti:
 
         Returns
         -------
-        None
+        AddBulkEpisodeResults
 
         Notes
         -----
@@ -467,106 +868,186 @@ class Graphiti:
         If these operations are required, use the `add_episode` method instead for each
         individual episode.
         """
-        try:
-            start = time()
-            now = utc_now()
-
-            episodes = [
-                EpisodicNode(
-                    name=episode.name,
-                    labels=[],
-                    source=episode.source,
-                    content=episode.content,
-                    source_description=episode.source_description,
-                    group_id=group_id,
-                    created_at=now,
-                    valid_at=episode.reference_time,
+        with self.tracer.start_span('add_episode_bulk') as bulk_span:
+            bulk_span.add_attributes({'episode.count': len(bulk_episodes)})
+
+            try:
+                start = time()
+                now = utc_now()
+
+                # if group_id is None, use the default group id by the provider
+                if group_id is None:
+                    group_id = get_default_group_id(self.driver.provider)
+                else:
+                    validate_group_id(group_id)
+                    if group_id != self.driver._database:
+                        # if group_id is provided, use it as the database name
+                        self.driver = self.driver.clone(database=group_id)
+                        self.clients.driver = self.driver
+
+                # Create default edge type map
+                edge_type_map_default = (
+                    {('Entity', 'Entity'): list(edge_types.keys())}
+                    if edge_types is not None
+                    else {('Entity', 'Entity'): []}
                 )
-                for episode in bulk_episodes
-            ]
 
-            # Save all the episodes
-            await semaphore_gather(*[episode.save(self.driver) for episode in episodes])
+                episodes = [
+                    await EpisodicNode.get_by_uuid(self.driver, episode.uuid)
+                    if episode.uuid is not None
+                    else EpisodicNode(
+                        name=episode.name,
+                        labels=[],
+                        source=episode.source,
+                        content=episode.content,
+                        source_description=episode.source_description,
+                        group_id=group_id,
+                        created_at=now,
+                        valid_at=episode.reference_time,
+                    )
+                    for episode in bulk_episodes
+                ]
+
+                # Save all episodes
+                await add_nodes_and_edges_bulk(
+                    driver=self.driver,
+                    episodic_nodes=episodes,
+                    episodic_edges=[],
+                    entity_nodes=[],
+                    entity_edges=[],
+                    embedder=self.embedder,
+                )
 
-            # Get previous episode context for each episode
-            episode_pairs = await retrieve_previous_episodes_bulk(self.driver, episodes)
+                # Get previous episode context for each episode
+                episode_context = await retrieve_previous_episodes_bulk(self.driver, episodes)
 
-            # Extract all nodes and edges
-            (
-                extracted_nodes,
-                extracted_edges,
-                episodic_edges,
-            ) = await extract_nodes_and_edges_bulk(self.clients, episode_pairs)
+                # Extract and dedupe nodes and edges
+                (
+                    nodes_by_episode,
+                    uuid_map,
+                    extracted_edges_bulk,
+                ) = await self._extract_and_dedupe_nodes_bulk(
+                    episode_context,
+                    edge_type_map or edge_type_map_default,
+                    edge_types,
+                    entity_types,
+                    excluded_entity_types,
+                )
 
-            # Generate embeddings
-            await semaphore_gather(
-                *[node.generate_name_embedding(self.embedder) for node in extracted_nodes],
-                *[edge.generate_embedding(self.embedder) for edge in extracted_edges],
-            )
+                # Create Episodic Edges
+                episodic_edges: list[EpisodicEdge] = []
+                for episode_uuid, nodes in nodes_by_episode.items():
+                    episodic_edges.extend(build_episodic_edges(nodes, episode_uuid, now))
 
-            # Dedupe extracted nodes, compress extracted edges
-            (nodes, uuid_map), extracted_edges_timestamped = await semaphore_gather(
-                dedupe_nodes_bulk(self.driver, self.llm_client, extracted_nodes),
-                extract_edge_dates_bulk(self.llm_client, extracted_edges, episode_pairs),
-            )
+                # Re-map edge pointers and dedupe edges
+                extracted_edges_bulk_updated: list[list[EntityEdge]] = [
+                    resolve_edge_pointers(edges, uuid_map) for edges in extracted_edges_bulk
+                ]
 
-            # save nodes to KG
-            await semaphore_gather(*[node.save(self.driver) for node in nodes])
+                edges_by_episode = await dedupe_edges_bulk(
+                    self.clients,
+                    extracted_edges_bulk_updated,
+                    episode_context,
+                    [],
+                    edge_types or {},
+                    edge_type_map or edge_type_map_default,
+                )
 
-            # re-map edge pointers so that they don't point to discard dupe nodes
-            extracted_edges_with_resolved_pointers: list[EntityEdge] = resolve_edge_pointers(
-                extracted_edges_timestamped, uuid_map
-            )
-            episodic_edges_with_resolved_pointers: list[EpisodicEdge] = resolve_edge_pointers(
-                episodic_edges, uuid_map
-            )
+                # Resolve nodes and edges against the existing graph
+                (
+                    final_hydrated_nodes,
+                    resolved_edges,
+                    invalidated_edges,
+                    final_uuid_map,
+                ) = await self._resolve_nodes_and_edges_bulk(
+                    nodes_by_episode,
+                    edges_by_episode,
+                    episode_context,
+                    entity_types,
+                    edge_types,
+                    edge_type_map or edge_type_map_default,
+                    episodes,
+                )
 
-            # save episodic edges to KG
-            await semaphore_gather(
-                *[edge.save(self.driver) for edge in episodic_edges_with_resolved_pointers]
-            )
+                # Resolved pointers for episodic edges
+                resolved_episodic_edges = resolve_edge_pointers(episodic_edges, final_uuid_map)
+
+                # save data to KG
+                await add_nodes_and_edges_bulk(
+                    self.driver,
+                    episodes,
+                    resolved_episodic_edges,
+                    final_hydrated_nodes,
+                    resolved_edges + invalidated_edges,
+                    self.embedder,
+                )
 
-            # Dedupe extracted edges
-            edges = await dedupe_edges_bulk(
-                self.driver, self.llm_client, extracted_edges_with_resolved_pointers
-            )
-            logger.debug(f'extracted edge length: {len(edges)}')
+                end = time()
 
-            # invalidate edges
+                # Add span attributes
+                bulk_span.add_attributes(
+                    {
+                        'group_id': group_id,
+                        'node.count': len(final_hydrated_nodes),
+                        'edge.count': len(resolved_edges + invalidated_edges),
+                        'duration_ms': (end - start) * 1000,
+                    }
+                )
 
-            # save edges to KG
-            await semaphore_gather(*[edge.save(self.driver) for edge in edges])
+                logger.info(f'Completed add_episode_bulk in {(end - start) * 1000} ms')
 
-            end = time()
-            logger.info(f'Completed add_episode_bulk in {(end - start) * 1000} ms')
+                return AddBulkEpisodeResults(
+                    episodes=episodes,
+                    episodic_edges=resolved_episodic_edges,
+                    nodes=final_hydrated_nodes,
+                    edges=resolved_edges + invalidated_edges,
+                    communities=[],
+                    community_edges=[],
+                )
 
-        except Exception as e:
-            raise e
+            except Exception as e:
+                bulk_span.set_status('error', str(e))
+                bulk_span.record_exception(e)
+                raise e
 
-    async def build_communities(self, group_ids: list[str] | None = None) -> list[CommunityNode]:
+    @handle_multiple_group_ids
+    async def build_communities(
+        self, group_ids: list[str] | None = None, driver: GraphDriver | None = None
+    ) -> tuple[list[CommunityNode], list[CommunityEdge]]:
         """
         Use a community clustering algorithm to find communities of nodes. Create community nodes summarising
         the content of these communities.
         ----------
-        query : list[str] | None
+        group_ids : list[str] | None
             Optional. Create communities only for the listed group_ids. If blank the entire graph will be used.
         """
+        if driver is None:
+            driver = self.clients.driver
+
         # Clear existing communities
-        await remove_communities(self.driver)
+        await remove_communities(driver)
 
         community_nodes, community_edges = await build_communities(
-            self.driver, self.llm_client, group_ids
+            driver, self.llm_client, group_ids
         )
 
         await semaphore_gather(
-            *[node.generate_name_embedding(self.embedder) for node in community_nodes]
+            *[node.generate_name_embedding(self.embedder) for node in community_nodes],
+            max_coroutines=self.max_coroutines,
         )
 
-        await semaphore_gather(*[node.save(self.driver) for node in community_nodes])
-        await semaphore_gather(*[edge.save(self.driver) for edge in community_edges])
+        await semaphore_gather(
+            *[node.save(driver) for node in community_nodes],
+            max_coroutines=self.max_coroutines,
+        )
+        await semaphore_gather(
+            *[edge.save(driver) for edge in community_edges],
+            max_coroutines=self.max_coroutines,
+        )
 
-        return community_nodes
+        return community_nodes, community_edges
 
+    @handle_multiple_group_ids
     async def search(
         self,
         query: str,
@@ -574,6 +1055,7 @@ class Graphiti:
         group_ids: list[str] | None = None,
         num_results=DEFAULT_SEARCH_LIMIT,
         search_filter: SearchFilters | None = None,
+        driver: GraphDriver | None = None,
     ) -> list[EntityEdge]:
         """
         Perform a hybrid search on the knowledge graph.
@@ -620,7 +1102,8 @@ class Graphiti:
                 group_ids,
                 search_config,
                 search_filter if search_filter is not None else SearchFilters(),
-                center_node_uuid,
+                driver=driver,
+                center_node_uuid=center_node_uuid,
             )
         ).edges
 
@@ -640,6 +1123,7 @@ class Graphiti:
             query, config, group_ids, center_node_uuid, bfs_origin_node_uuids, search_filter
         )
 
+    @handle_multiple_group_ids
     async def search_(
         self,
         query: str,
@@ -648,6 +1132,7 @@ class Graphiti:
         center_node_uuid: str | None = None,
         bfs_origin_node_uuids: list[str] | None = None,
         search_filter: SearchFilters | None = None,
+        driver: GraphDriver | None = None,
     ) -> SearchResults:
         """search_ (replaces _search) is our advanced search method that returns Graph objects (nodes and edges) rather
         than a list of facts. This endpoint allows the end user to utilize more advanced features such as filters and
@@ -664,22 +1149,26 @@ class Graphiti:
             search_filter if search_filter is not None else SearchFilters(),
             center_node_uuid,
             bfs_origin_node_uuids,
+            driver=driver,
         )
 
     async def get_nodes_and_edges_by_episode(self, episode_uuids: list[str]) -> SearchResults:
         episodes = await EpisodicNode.get_by_uuids(self.driver, episode_uuids)
 
         edges_list = await semaphore_gather(
-            *[EntityEdge.get_by_uuids(self.driver, episode.entity_edges) for episode in episodes]
+            *[EntityEdge.get_by_uuids(self.driver, episode.entity_edges) for episode in episodes],
+            max_coroutines=self.max_coroutines,
         )
 
         edges: list[EntityEdge] = [edge for lst in edges_list for edge in lst]
 
         nodes = await get_mentioned_nodes(self.driver, episodes)
 
-        return SearchResults(edges=edges, nodes=nodes, episodes=[], communities=[])
+        return SearchResults(edges=edges, nodes=nodes)
 
-    async def add_triplet(self, source_node: EntityNode, edge: EntityEdge, target_node: EntityNode):
+    async def add_triplet(
+        self, source_node: EntityNode, edge: EntityEdge, target_node: EntityNode
+    ) -> AddTripletResults:
         if source_node.name_embedding is None:
             await source_node.generate_name_embedding(self.embedder)
         if target_node.name_embedding is None:
@@ -687,19 +1176,37 @@ class Graphiti:
         if edge.fact_embedding is None:
             await edge.generate_embedding(self.embedder)
 
-        resolved_nodes, uuid_map = await resolve_extracted_nodes(
+        nodes, uuid_map, _ = await resolve_extracted_nodes(
             self.clients,
             [source_node, target_node],
         )
 
         updated_edge = resolve_edge_pointers([edge], uuid_map)[0]
 
-        related_edges = (await get_relevant_edges(self.driver, [updated_edge], SearchFilters()))[0]
+        valid_edges = await EntityEdge.get_between_nodes(
+            self.driver, edge.source_node_uuid, edge.target_node_uuid
+        )
+
+        related_edges = (
+            await search(
+                self.clients,
+                updated_edge.fact,
+                group_ids=[updated_edge.group_id],
+                config=EDGE_HYBRID_SEARCH_RRF,
+                search_filter=SearchFilters(edge_uuids=[edge.uuid for edge in valid_edges]),
+            )
+        ).edges
         existing_edges = (
-            await get_edge_invalidation_candidates(self.driver, [updated_edge], SearchFilters())
-        )[0]
+            await search(
+                self.clients,
+                updated_edge.fact,
+                group_ids=[updated_edge.group_id],
+                config=EDGE_HYBRID_SEARCH_RRF,
+                search_filter=SearchFilters(),
+            )
+        ).edges
 
-        resolved_edge, invalidated_edges = await resolve_extracted_edge(
+        resolved_edge, invalidated_edges, _ = await resolve_extracted_edge(
             self.llm_client,
             updated_edge,
             related_edges,
713
1220
  entity_edges=[],
714
1221
  group_id=edge.group_id,
715
1222
  ),
1223
+ None,
1224
+ None,
716
1225
  )
717
1226
 
718
- await add_nodes_and_edges_bulk(
719
- self.driver, [], [], resolved_nodes, [resolved_edge] + invalidated_edges, self.embedder
720
- )
1227
+ edges: list[EntityEdge] = [resolved_edge] + invalidated_edges
1228
+
1229
+ await create_entity_edge_embeddings(self.embedder, edges)
1230
+ await create_entity_node_embeddings(self.embedder, nodes)
1231
+
1232
+ await add_nodes_and_edges_bulk(self.driver, [], [], nodes, edges, self.embedder)
1233
+ return AddTripletResults(edges=edges, nodes=nodes)
721
1234
 
722
1235
  async def remove_episode(self, episode_uuid: str):
723
1236
  # Find the episode to be deleted
@@ -738,14 +1251,13 @@ class Graphiti:
         nodes_to_delete: list[EntityNode] = []
         for node in nodes:
             query: LiteralString = 'MATCH (e:Episodic)-[:MENTIONS]->(n:Entity {uuid: $uuid}) RETURN count(*) AS episode_count'
-            records, _, _ = await self.driver.execute_query(
-                query, uuid=node.uuid, database_=DEFAULT_DATABASE, routing_='r'
-            )
+            records, _, _ = await self.driver.execute_query(query, uuid=node.uuid, routing_='r')
 
             for record in records:
                 if record['episode_count'] == 1:
                     nodes_to_delete.append(node)
 
-        await semaphore_gather(*[node.delete(self.driver) for node in nodes_to_delete])
-        await semaphore_gather(*[edge.delete(self.driver) for edge in edges_to_delete])
+        await Edge.delete_by_uuids(self.driver, [edge.uuid for edge in edges_to_delete])
+        await Node.delete_by_uuids(self.driver, [node.uuid for node in nodes_to_delete])
+
         await episode.delete(self.driver)
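
Taken together, the add_episode changes above widen both the inputs (optional group_id, excluded_entity_types, typed entity_types and edge_types) and the outputs (AddEpisodeResults now carries episodic edges, communities, and community edges). A hedged usage sketch, assuming OPENAI_API_KEY is set and a Neo4j instance is reachable at the placeholder URI; the episode names and text are illustrative only.

import asyncio
from datetime import datetime, timezone

from graphiti_core.graphiti import Graphiti
from graphiti_core.nodes import EpisodeType


async def main() -> None:
    client = Graphiti('bolt://localhost:7687', 'neo4j', 'password')
    await client.build_indices_and_constraints()  # now delegated to the driver

    results = await client.add_episode(
        name='support-chat-42',
        episode_body='Alice reported that the new billing page loads slowly.',
        source_description='customer support transcript',
        reference_time=datetime.now(timezone.utc),
        source=EpisodeType.message,
        excluded_entity_types=None,  # new in 0.24.3: drop selected entity types entirely
    )

    # AddEpisodeResults in 0.24.3 also exposes episodic_edges, communities, and community_edges.
    print(len(results.nodes), len(results.edges), len(results.episodic_edges))

    await client.close()


asyncio.run(main())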