graphiti-core 0.2.1.tar.gz → 0.2.3.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of graphiti-core might be problematic; further details were linked from the original page.

Files changed (37)
  1. {graphiti_core-0.2.1 → graphiti_core-0.2.3}/PKG-INFO +3 -4
  2. {graphiti_core-0.2.1 → graphiti_core-0.2.3}/README.md +1 -1
  3. {graphiti_core-0.2.1 → graphiti_core-0.2.3}/graphiti_core/edges.py +39 -32
  4. {graphiti_core-0.2.1 → graphiti_core-0.2.3}/graphiti_core/graphiti.py +100 -93
  5. {graphiti_core-0.2.1 → graphiti_core-0.2.3}/graphiti_core/nodes.py +45 -39
  6. {graphiti_core-0.2.1 → graphiti_core-0.2.3}/graphiti_core/prompts/dedupe_edges.py +1 -1
  7. {graphiti_core-0.2.1 → graphiti_core-0.2.3}/graphiti_core/prompts/invalidate_edges.py +37 -1
  8. {graphiti_core-0.2.1 → graphiti_core-0.2.3}/graphiti_core/search/search.py +5 -2
  9. {graphiti_core-0.2.1 → graphiti_core-0.2.3}/graphiti_core/search/search_utils.py +101 -168
  10. {graphiti_core-0.2.1 → graphiti_core-0.2.3}/graphiti_core/utils/bulk_utils.py +31 -3
  11. {graphiti_core-0.2.1 → graphiti_core-0.2.3}/graphiti_core/utils/maintenance/edge_operations.py +104 -16
  12. {graphiti_core-0.2.1 → graphiti_core-0.2.3}/graphiti_core/utils/maintenance/graph_data_operations.py +17 -8
  13. {graphiti_core-0.2.1 → graphiti_core-0.2.3}/graphiti_core/utils/maintenance/node_operations.py +1 -0
  14. {graphiti_core-0.2.1 → graphiti_core-0.2.3}/graphiti_core/utils/maintenance/temporal_operations.py +34 -0
  15. {graphiti_core-0.2.1 → graphiti_core-0.2.3}/pyproject.toml +3 -6
  16. graphiti_core-0.2.1/graphiti_core/utils/utils.py +0 -60
  17. {graphiti_core-0.2.1 → graphiti_core-0.2.3}/LICENSE +0 -0
  18. {graphiti_core-0.2.1 → graphiti_core-0.2.3}/graphiti_core/__init__.py +0 -0
  19. {graphiti_core-0.2.1 → graphiti_core-0.2.3}/graphiti_core/helpers.py +0 -0
  20. {graphiti_core-0.2.1 → graphiti_core-0.2.3}/graphiti_core/llm_client/__init__.py +0 -0
  21. {graphiti_core-0.2.1 → graphiti_core-0.2.3}/graphiti_core/llm_client/anthropic_client.py +0 -0
  22. {graphiti_core-0.2.1 → graphiti_core-0.2.3}/graphiti_core/llm_client/client.py +0 -0
  23. {graphiti_core-0.2.1 → graphiti_core-0.2.3}/graphiti_core/llm_client/config.py +0 -0
  24. {graphiti_core-0.2.1 → graphiti_core-0.2.3}/graphiti_core/llm_client/groq_client.py +0 -0
  25. {graphiti_core-0.2.1 → graphiti_core-0.2.3}/graphiti_core/llm_client/openai_client.py +0 -0
  26. {graphiti_core-0.2.1 → graphiti_core-0.2.3}/graphiti_core/llm_client/utils.py +0 -0
  27. {graphiti_core-0.2.1 → graphiti_core-0.2.3}/graphiti_core/prompts/__init__.py +0 -0
  28. {graphiti_core-0.2.1 → graphiti_core-0.2.3}/graphiti_core/prompts/dedupe_nodes.py +0 -0
  29. {graphiti_core-0.2.1 → graphiti_core-0.2.3}/graphiti_core/prompts/extract_edge_dates.py +0 -0
  30. {graphiti_core-0.2.1 → graphiti_core-0.2.3}/graphiti_core/prompts/extract_edges.py +0 -0
  31. {graphiti_core-0.2.1 → graphiti_core-0.2.3}/graphiti_core/prompts/extract_nodes.py +0 -0
  32. {graphiti_core-0.2.1 → graphiti_core-0.2.3}/graphiti_core/prompts/lib.py +0 -0
  33. {graphiti_core-0.2.1 → graphiti_core-0.2.3}/graphiti_core/prompts/models.py +0 -0
  34. {graphiti_core-0.2.1 → graphiti_core-0.2.3}/graphiti_core/search/__init__.py +0 -0
  35. {graphiti_core-0.2.1 → graphiti_core-0.2.3}/graphiti_core/utils/__init__.py +0 -0
  36. {graphiti_core-0.2.1 → graphiti_core-0.2.3}/graphiti_core/utils/maintenance/__init__.py +0 -0
  37. {graphiti_core-0.2.1 → graphiti_core-0.2.3}/graphiti_core/utils/maintenance/utils.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: graphiti-core
3
- Version: 0.2.1
3
+ Version: 0.2.3
4
4
  Summary: A temporal graph building library
5
5
  License: Apache-2.0
6
6
  Author: Paul Paliychuk
@@ -12,11 +12,10 @@ Classifier: Programming Language :: Python :: 3.10
12
12
  Classifier: Programming Language :: Python :: 3.11
13
13
  Classifier: Programming Language :: Python :: 3.12
14
14
  Requires-Dist: diskcache (>=5.6.3,<6.0.0)
15
- Requires-Dist: fastapi (>=0.112.0,<0.113.0)
16
15
  Requires-Dist: neo4j (>=5.23.0,<6.0.0)
16
+ Requires-Dist: numpy (>=2.1.1,<3.0.0)
17
17
  Requires-Dist: openai (>=1.38.0,<2.0.0)
18
18
  Requires-Dist: pydantic (>=2.8.2,<3.0.0)
19
- Requires-Dist: sentence-transformers (>=3.0.1,<4.0.0)
20
19
  Requires-Dist: tenacity (<9.0.0)
21
20
  Description-Content-Type: text/markdown
22
21
 
@@ -173,7 +172,7 @@ graphiti.close()
173
172
 
174
173
  ## Documentation
175
174
 
176
- - [Guides and API documentation](https://help.getzep.com/Graphiti/Graphiti).
175
+ - [Guides and API documentation](https://help.getzep.com/graphiti).
177
176
  - [Quick Start](https://help.getzep.com/graphiti/graphiti/quick-start)
178
177
  - [Building an agent with LangChain's LangGraph and Graphiti](https://help.getzep.com/graphiti/graphiti/lang-graph-agent)
179
178
 
@@ -151,7 +151,7 @@ graphiti.close()
151
151
 
152
152
  ## Documentation
153
153
 
154
- - [Guides and API documentation](https://help.getzep.com/Graphiti/Graphiti).
154
+ - [Guides and API documentation](https://help.getzep.com/graphiti).
155
155
  - [Quick Start](https://help.getzep.com/graphiti/graphiti/quick-start)
156
156
  - [Building an agent with LangChain's LangGraph and Graphiti](https://help.getzep.com/graphiti/graphiti/lang-graph-agent)
157
157
 
@@ -18,6 +18,7 @@ import logging
18
18
  from abc import ABC, abstractmethod
19
19
  from datetime import datetime
20
20
  from time import time
21
+ from typing import Any
21
22
  from uuid import uuid4
22
23
 
23
24
  from neo4j import AsyncDriver
@@ -32,6 +33,7 @@ logger = logging.getLogger(__name__)
32
33
 
33
34
  class Edge(BaseModel, ABC):
34
35
  uuid: str = Field(default_factory=lambda: uuid4().hex)
36
+ group_id: str | None = Field(description='partition of the graph')
35
37
  source_node_uuid: str
36
38
  target_node_uuid: str
37
39
  created_at: datetime
@@ -61,11 +63,12 @@ class EpisodicEdge(Edge):
61
63
  MATCH (episode:Episodic {uuid: $episode_uuid})
62
64
  MATCH (node:Entity {uuid: $entity_uuid})
63
65
  MERGE (episode)-[r:MENTIONS {uuid: $uuid}]->(node)
64
- SET r = {uuid: $uuid, created_at: $created_at}
66
+ SET r = {uuid: $uuid, group_id: $group_id, created_at: $created_at}
65
67
  RETURN r.uuid AS uuid""",
66
68
  episode_uuid=self.source_node_uuid,
67
69
  entity_uuid=self.target_node_uuid,
68
70
  uuid=self.uuid,
71
+ group_id=self.group_id,
69
72
  created_at=self.created_at,
70
73
  )
71
74
 
@@ -92,7 +95,8 @@ class EpisodicEdge(Edge):
92
95
  """
93
96
  MATCH (n:Episodic)-[e:MENTIONS {uuid: $uuid}]->(m:Entity)
94
97
  RETURN
95
- e.uuid As uuid,
98
+ e.uuid As uuid,
99
+ e.group_id AS group_id,
96
100
  n.uuid AS source_node_uuid,
97
101
  m.uuid AS target_node_uuid,
98
102
  e.created_at AS created_at
@@ -100,17 +104,7 @@ class EpisodicEdge(Edge):
100
104
  uuid=uuid,
101
105
  )
102
106
 
103
- edges: list[EpisodicEdge] = []
104
-
105
- for record in records:
106
- edges.append(
107
- EpisodicEdge(
108
- uuid=record['uuid'],
109
- source_node_uuid=record['source_node_uuid'],
110
- target_node_uuid=record['target_node_uuid'],
111
- created_at=record['created_at'].to_native(),
112
- )
113
- )
107
+ edges = [get_episodic_edge_from_record(record) for record in records]
114
108
 
115
109
  logger.info(f'Found Edge: {uuid}')
116
110
 
@@ -153,7 +147,7 @@ class EntityEdge(Edge):
153
147
  MATCH (source:Entity {uuid: $source_uuid})
154
148
  MATCH (target:Entity {uuid: $target_uuid})
155
149
  MERGE (source)-[r:RELATES_TO {uuid: $uuid}]->(target)
156
- SET r = {uuid: $uuid, name: $name, fact: $fact, fact_embedding: $fact_embedding,
150
+ SET r = {uuid: $uuid, name: $name, group_id: $group_id, fact: $fact, fact_embedding: $fact_embedding,
157
151
  episodes: $episodes, created_at: $created_at, expired_at: $expired_at,
158
152
  valid_at: $valid_at, invalid_at: $invalid_at}
159
153
  RETURN r.uuid AS uuid""",
@@ -161,6 +155,7 @@ class EntityEdge(Edge):
161
155
  target_uuid=self.target_node_uuid,
162
156
  uuid=self.uuid,
163
157
  name=self.name,
158
+ group_id=self.group_id,
164
159
  fact=self.fact,
165
160
  fact_embedding=self.fact_embedding,
166
161
  episodes=self.episodes,
@@ -198,6 +193,7 @@ class EntityEdge(Edge):
198
193
  m.uuid AS target_node_uuid,
199
194
  e.created_at AS created_at,
200
195
  e.name AS name,
196
+ e.group_id AS group_id,
201
197
  e.fact AS fact,
202
198
  e.fact_embedding AS fact_embedding,
203
199
  e.episodes AS episodes,
@@ -208,25 +204,36 @@ class EntityEdge(Edge):
208
204
  uuid=uuid,
209
205
  )
210
206
 
211
- edges: list[EntityEdge] = []
212
-
213
- for record in records:
214
- edges.append(
215
- EntityEdge(
216
- uuid=record['uuid'],
217
- source_node_uuid=record['source_node_uuid'],
218
- target_node_uuid=record['target_node_uuid'],
219
- fact=record['fact'],
220
- name=record['name'],
221
- episodes=record['episodes'],
222
- fact_embedding=record['fact_embedding'],
223
- created_at=record['created_at'].to_native(),
224
- expired_at=parse_db_date(record['expired_at']),
225
- valid_at=parse_db_date(record['valid_at']),
226
- invalid_at=parse_db_date(record['invalid_at']),
227
- )
228
- )
207
+ edges = [get_entity_edge_from_record(record) for record in records]
229
208
 
230
209
  logger.info(f'Found Edge: {uuid}')
231
210
 
232
211
  return edges[0]
212
+
213
+
214
+ # Edge helpers
215
+ def get_episodic_edge_from_record(record: Any) -> EpisodicEdge:
216
+ return EpisodicEdge(
217
+ uuid=record['uuid'],
218
+ group_id=record['group_id'],
219
+ source_node_uuid=record['source_node_uuid'],
220
+ target_node_uuid=record['target_node_uuid'],
221
+ created_at=record['created_at'].to_native(),
222
+ )
223
+
224
+
225
+ def get_entity_edge_from_record(record: Any) -> EntityEdge:
226
+ return EntityEdge(
227
+ uuid=record['uuid'],
228
+ source_node_uuid=record['source_node_uuid'],
229
+ target_node_uuid=record['target_node_uuid'],
230
+ fact=record['fact'],
231
+ name=record['name'],
232
+ group_id=record['group_id'],
233
+ episodes=record['episodes'],
234
+ fact_embedding=record['fact_embedding'],
235
+ created_at=record['created_at'].to_native(),
236
+ expired_at=parse_db_date(record['expired_at']),
237
+ valid_at=parse_db_date(record['valid_at']),
238
+ invalid_at=parse_db_date(record['invalid_at']),
239
+ )
@@ -18,7 +18,6 @@ import asyncio
18
18
  import logging
19
19
  from datetime import datetime
20
20
  from time import time
21
- from typing import Callable
22
21
 
23
22
  from dotenv import load_dotenv
24
23
  from neo4j import AsyncGraphDatabase
@@ -59,11 +58,6 @@ from graphiti_core.utils.maintenance.node_operations import (
59
58
  extract_nodes,
60
59
  resolve_extracted_nodes,
61
60
  )
62
- from graphiti_core.utils.maintenance.temporal_operations import (
63
- extract_edge_dates,
64
- invalidate_edges,
65
- prepare_edges_for_invalidation,
66
- )
67
61
 
68
62
  logger = logging.getLogger(__name__)
69
63
 
@@ -125,7 +119,7 @@ class Graphiti:
125
119
 
126
120
  Parameters
127
121
  ----------
128
- None
122
+ self
129
123
 
130
124
  Returns
131
125
  -------
@@ -156,7 +150,7 @@ class Graphiti:
156
150
 
157
151
  Parameters
158
152
  ----------
159
- None
153
+ self
160
154
 
161
155
  Returns
162
156
  -------
@@ -183,6 +177,7 @@ class Graphiti:
183
177
  self,
184
178
  reference_time: datetime,
185
179
  last_n: int = EPISODE_WINDOW_LEN,
180
+ group_ids: list[str | None] | None = None,
186
181
  ) -> list[EpisodicNode]:
187
182
  """
188
183
  Retrieve the last n episodic nodes from the graph.
@@ -196,6 +191,8 @@ class Graphiti:
196
191
  The reference time to retrieve episodes before.
197
192
  last_n : int, optional
198
193
  The number of episodes to retrieve. Defaults to EPISODE_WINDOW_LEN.
194
+ group_ids : list[str | None], optional
195
+ The group ids to return data from.
199
196
 
200
197
  Returns
201
198
  -------
@@ -207,7 +204,7 @@ class Graphiti:
207
204
  The actual retrieval is performed by the `retrieve_episodes` function
208
205
  from the `graphiti_core.utils` module.
209
206
  """
210
- return await retrieve_episodes(self.driver, reference_time, last_n)
207
+ return await retrieve_episodes(self.driver, reference_time, last_n, group_ids)
211
208
 
212
209
  async def add_episode(
213
210
  self,
@@ -216,8 +213,8 @@ class Graphiti:
216
213
  source_description: str,
217
214
  reference_time: datetime,
218
215
  source: EpisodeType = EpisodeType.message,
219
- success_callback: Callable | None = None,
220
- error_callback: Callable | None = None,
216
+ group_id: str | None = None,
217
+ uuid: str | None = None,
221
218
  ):
222
219
  """
223
220
  Process an episode and update the graph.
@@ -237,10 +234,10 @@ class Graphiti:
237
234
  The reference time for the episode.
238
235
  source : EpisodeType, optional
239
236
  The type of the episode. Defaults to EpisodeType.message.
240
- success_callback : Callable | None, optional
241
- A callback function to be called upon successful processing.
242
- error_callback : Callable | None, optional
243
- A callback function to be called if an error occurs during processing.
237
+ group_id : str | None
238
+ An id for the graph partition the episode is a part of.
239
+ uuid : str | None
240
+ Optional uuid of the episode.
244
241
 
245
242
  Returns
246
243
  -------
@@ -271,9 +268,12 @@ class Graphiti:
271
268
  embedder = self.llm_client.get_embedder()
272
269
  now = datetime.now()
273
270
 
274
- previous_episodes = await self.retrieve_episodes(reference_time, last_n=3)
271
+ previous_episodes = await self.retrieve_episodes(
272
+ reference_time, last_n=3, group_ids=[group_id]
273
+ )
275
274
  episode = EpisodicNode(
276
275
  name=name,
276
+ group_id=group_id,
277
277
  labels=[],
278
278
  source=source,
279
279
  content=episode_body,
@@ -281,6 +281,7 @@ class Graphiti:
281
281
  created_at=now,
282
282
  valid_at=reference_time,
283
283
  )
284
+ episode.uuid = uuid if uuid is not None else episode.uuid
284
285
 
285
286
  # Extract entities as nodes
286
287
 
@@ -293,7 +294,7 @@ class Graphiti:
293
294
  *[node.generate_name_embedding(embedder) for node in extracted_nodes]
294
295
  )
295
296
 
296
- # Resolve extracted nodes with nodes already in the graph
297
+ # Resolve extracted nodes with nodes already in the graph and extract facts
297
298
  existing_nodes_lists: list[list[EntityNode]] = list(
298
299
  await asyncio.gather(
299
300
  *[get_relevant_nodes([node], self.driver) for node in extracted_nodes]
@@ -302,22 +303,29 @@ class Graphiti:
302
303
 
303
304
  logger.info(f'Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}')
304
305
 
305
- mentioned_nodes, _ = await resolve_extracted_nodes(
306
- self.llm_client, extracted_nodes, existing_nodes_lists
306
+ (mentioned_nodes, uuid_map), extracted_edges = await asyncio.gather(
307
+ resolve_extracted_nodes(self.llm_client, extracted_nodes, existing_nodes_lists),
308
+ extract_edges(
309
+ self.llm_client, episode, extracted_nodes, previous_episodes, group_id
310
+ ),
307
311
  )
308
312
  logger.info(f'Adjusted mentioned nodes: {[(n.name, n.uuid) for n in mentioned_nodes]}')
309
313
  nodes.extend(mentioned_nodes)
310
314
 
311
- # Extract facts as edges given entity nodes
312
- extracted_edges = await extract_edges(
313
- self.llm_client, episode, mentioned_nodes, previous_episodes
315
+ extracted_edges_with_resolved_pointers = resolve_edge_pointers(
316
+ extracted_edges, uuid_map
314
317
  )
315
318
 
316
319
  # calculate embeddings
317
- await asyncio.gather(*[edge.generate_embedding(embedder) for edge in extracted_edges])
320
+ await asyncio.gather(
321
+ *[
322
+ edge.generate_embedding(embedder)
323
+ for edge in extracted_edges_with_resolved_pointers
324
+ ]
325
+ )
318
326
 
319
- # Resolve extracted edges with edges already in the graph
320
- existing_edges_list: list[list[EntityEdge]] = list(
327
+ # Resolve extracted edges with related edges already in the graph
328
+ related_edges_list: list[list[EntityEdge]] = list(
321
329
  await asyncio.gather(
322
330
  *[
323
331
  get_relevant_edges(
@@ -327,80 +335,68 @@ class Graphiti:
327
335
  edge.target_node_uuid,
328
336
  RELEVANT_SCHEMA_LIMIT,
329
337
  )
330
- for edge in extracted_edges
338
+ for edge in extracted_edges_with_resolved_pointers
331
339
  ]
332
340
  )
333
341
  )
334
342
  logger.info(
335
- f'Existing edges lists: {[(e.name, e.uuid) for edges_lst in existing_edges_list for e in edges_lst]}'
343
+ f'Related edges lists: {[(e.name, e.uuid) for edges_lst in related_edges_list for e in edges_lst]}'
336
344
  )
337
- logger.info(f'Extracted edges: {[(e.name, e.uuid) for e in extracted_edges]}')
338
-
339
- deduped_edges: list[EntityEdge] = await resolve_extracted_edges(
340
- self.llm_client, extracted_edges, existing_edges_list
345
+ logger.info(
346
+ f'Extracted edges: {[(e.name, e.uuid) for e in extracted_edges_with_resolved_pointers]}'
341
347
  )
342
348
 
343
- # Extract dates for the newly extracted edges
344
- edge_dates = await asyncio.gather(
345
- *[
346
- extract_edge_dates(
347
- self.llm_client,
348
- edge,
349
- episode,
350
- previous_episodes,
351
- )
352
- for edge in deduped_edges
353
- ]
349
+ existing_source_edges_list: list[list[EntityEdge]] = list(
350
+ await asyncio.gather(
351
+ *[
352
+ get_relevant_edges(
353
+ self.driver,
354
+ [edge],
355
+ edge.source_node_uuid,
356
+ None,
357
+ RELEVANT_SCHEMA_LIMIT,
358
+ )
359
+ for edge in extracted_edges_with_resolved_pointers
360
+ ]
361
+ )
354
362
  )
355
363
 
356
- for i, edge in enumerate(deduped_edges):
357
- valid_at = edge_dates[i][0]
358
- invalid_at = edge_dates[i][1]
359
-
360
- edge.valid_at = valid_at
361
- edge.invalid_at = invalid_at
362
- if edge.invalid_at is not None:
363
- edge.expired_at = now
364
-
365
- entity_edges.extend(deduped_edges)
364
+ existing_target_edges_list: list[list[EntityEdge]] = list(
365
+ await asyncio.gather(
366
+ *[
367
+ get_relevant_edges(
368
+ self.driver,
369
+ [edge],
370
+ None,
371
+ edge.target_node_uuid,
372
+ RELEVANT_SCHEMA_LIMIT,
373
+ )
374
+ for edge in extracted_edges_with_resolved_pointers
375
+ ]
376
+ )
377
+ )
366
378
 
367
- existing_edges: list[EntityEdge] = [
368
- e for edge_lst in existing_edges_list for e in edge_lst
379
+ existing_edges_list: list[list[EntityEdge]] = [
380
+ source_lst + target_lst
381
+ for source_lst, target_lst in zip(
382
+ existing_source_edges_list, existing_target_edges_list
383
+ )
369
384
  ]
370
385
 
371
- (
372
- old_edges_with_nodes_pending_invalidation,
373
- new_edges_with_nodes,
374
- ) = prepare_edges_for_invalidation(
375
- existing_edges=existing_edges, new_edges=deduped_edges, nodes=nodes
376
- )
377
-
378
- invalidated_edges = await invalidate_edges(
386
+ resolved_edges, invalidated_edges = await resolve_extracted_edges(
379
387
  self.llm_client,
380
- old_edges_with_nodes_pending_invalidation,
381
- new_edges_with_nodes,
388
+ extracted_edges_with_resolved_pointers,
389
+ related_edges_list,
390
+ existing_edges_list,
382
391
  episode,
383
392
  previous_episodes,
384
393
  )
385
394
 
386
- for edge in invalidated_edges:
387
- for existing_edge in existing_edges:
388
- if existing_edge.uuid == edge.uuid:
389
- existing_edge.expired_at = edge.expired_at
390
- for deduped_edge in deduped_edges:
391
- if deduped_edge.uuid == edge.uuid:
392
- deduped_edge.expired_at = edge.expired_at
393
- logger.info(f'Invalidated edges: {[(e.name, e.uuid) for e in invalidated_edges]}')
395
+ entity_edges.extend(resolved_edges + invalidated_edges)
394
396
 
395
- entity_edges.extend(existing_edges)
397
+ logger.info(f'Resolved edges: {[(e.name, e.uuid) for e in resolved_edges]}')
396
398
 
397
- logger.info(f'Deduped edges: {[(e.name, e.uuid) for e in deduped_edges]}')
398
-
399
- episodic_edges: list[EpisodicEdge] = build_episodic_edges(
400
- mentioned_nodes,
401
- episode,
402
- now,
403
- )
399
+ episodic_edges: list[EpisodicEdge] = build_episodic_edges(mentioned_nodes, episode, now)
404
400
 
405
401
  logger.info(f'Built episodic edges: {episodic_edges}')
406
402
 
@@ -413,18 +409,10 @@ class Graphiti:
413
409
  end = time()
414
410
  logger.info(f'Completed add_episode in {(end - start) * 1000} ms')
415
411
 
416
- if success_callback:
417
- await success_callback(episode)
418
412
  except Exception as e:
419
- if error_callback:
420
- await error_callback(episode, e)
421
- else:
422
- raise e
413
+ raise e
423
414
 
424
- async def add_episode_bulk(
425
- self,
426
- bulk_episodes: list[RawEpisode],
427
- ):
415
+ async def add_episode_bulk(self, bulk_episodes: list[RawEpisode], group_id: str | None):
428
416
  """
429
417
  Process multiple episodes in bulk and update the graph.
430
418
 
@@ -435,6 +423,8 @@ class Graphiti:
435
423
  ----------
436
424
  bulk_episodes : list[RawEpisode]
437
425
  A list of RawEpisode objects to be processed and added to the graph.
426
+ group_id : str | None
427
+ An id for the graph partition the episode is a part of.
438
428
 
439
429
  Returns
440
430
  -------
@@ -471,6 +461,7 @@ class Graphiti:
471
461
  source=episode.source,
472
462
  content=episode.content,
473
463
  source_description=episode.source_description,
464
+ group_id=group_id,
474
465
  created_at=now,
475
466
  valid_at=episode.reference_time,
476
467
  )
@@ -535,7 +526,13 @@ class Graphiti:
535
526
  except Exception as e:
536
527
  raise e
537
528
 
538
- async def search(self, query: str, center_node_uuid: str | None = None, num_results=10):
529
+ async def search(
530
+ self,
531
+ query: str,
532
+ center_node_uuid: str | None = None,
533
+ group_ids: list[str | None] | None = None,
534
+ num_results=10,
535
+ ):
539
536
  """
540
537
  Perform a hybrid search on the knowledge graph.
541
538
 
@@ -548,6 +545,8 @@ class Graphiti:
548
545
  The search query string.
549
546
  center_node_uuid: str, optional
550
547
  Facts will be reranked based on proximity to this node
548
+ group_ids : list[str | None] | None, optional
549
+ The graph partitions to return data from.
551
550
  num_results : int, optional
552
551
  The maximum number of results to return. Defaults to 10.
553
552
 
@@ -570,6 +569,7 @@ class Graphiti:
570
569
  num_episodes=0,
571
570
  num_edges=num_results,
572
571
  num_nodes=0,
572
+ group_ids=group_ids,
573
573
  search_methods=[SearchMethod.bm25, SearchMethod.cosine_similarity],
574
574
  reranker=reranker,
575
575
  )
@@ -598,7 +598,10 @@ class Graphiti:
598
598
  )
599
599
 
600
600
  async def get_nodes_by_query(
601
- self, query: str, limit: int = RELEVANT_SCHEMA_LIMIT
601
+ self,
602
+ query: str,
603
+ group_ids: list[str | None] | None = None,
604
+ limit: int = RELEVANT_SCHEMA_LIMIT,
602
605
  ) -> list[EntityNode]:
603
606
  """
604
607
  Retrieve nodes from the graph database based on a text query.
@@ -610,6 +613,8 @@ class Graphiti:
610
613
  ----------
611
614
  query : str
612
615
  The text query to search for in the graph.
616
+ group_ids : list[str | None] | None, optional
617
+ The graph partitions to return data from.
613
618
  limit : int | None, optional
614
619
  The maximum number of results to return per search method.
615
620
  If None, a default limit will be applied.
@@ -634,5 +639,7 @@ class Graphiti:
634
639
  """
635
640
  embedder = self.llm_client.get_embedder()
636
641
  query_embedding = await generate_embedding(embedder, query)
637
- relevant_nodes = await hybrid_node_search([query], [query_embedding], self.driver, limit)
642
+ relevant_nodes = await hybrid_node_search(
643
+ [query], [query_embedding], self.driver, group_ids, limit
644
+ )
638
645
  return relevant_nodes