graphiti-core 0.2.2__py3-none-any.whl → 0.2.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


graphiti_core/edges.py CHANGED
@@ -18,6 +18,7 @@ import logging
  from abc import ABC, abstractmethod
  from datetime import datetime
  from time import time
+ from typing import Any
  from uuid import uuid4

  from neo4j import AsyncDriver
@@ -32,6 +33,7 @@ logger = logging.getLogger(__name__)

  class Edge(BaseModel, ABC):
  uuid: str = Field(default_factory=lambda: uuid4().hex)
+ group_id: str | None = Field(description='partition of the graph')
  source_node_uuid: str
  target_node_uuid: str
  created_at: datetime
@@ -61,11 +63,12 @@ class EpisodicEdge(Edge):
  MATCH (episode:Episodic {uuid: $episode_uuid})
  MATCH (node:Entity {uuid: $entity_uuid})
  MERGE (episode)-[r:MENTIONS {uuid: $uuid}]->(node)
- SET r = {uuid: $uuid, created_at: $created_at}
+ SET r = {uuid: $uuid, group_id: $group_id, created_at: $created_at}
  RETURN r.uuid AS uuid""",
  episode_uuid=self.source_node_uuid,
  entity_uuid=self.target_node_uuid,
  uuid=self.uuid,
+ group_id=self.group_id,
  created_at=self.created_at,
  )

@@ -92,7 +95,8 @@ class EpisodicEdge(Edge):
  """
  MATCH (n:Episodic)-[e:MENTIONS {uuid: $uuid}]->(m:Entity)
  RETURN
- e.uuid As uuid,
+ e.uuid As uuid,
+ e.group_id AS group_id,
  n.uuid AS source_node_uuid,
  m.uuid AS target_node_uuid,
  e.created_at AS created_at
@@ -100,17 +104,7 @@ class EpisodicEdge(Edge):
  uuid=uuid,
  )

- edges: list[EpisodicEdge] = []
-
- for record in records:
- edges.append(
- EpisodicEdge(
- uuid=record['uuid'],
- source_node_uuid=record['source_node_uuid'],
- target_node_uuid=record['target_node_uuid'],
- created_at=record['created_at'].to_native(),
- )
- )
+ edges = [get_episodic_edge_from_record(record) for record in records]

  logger.info(f'Found Edge: {uuid}')

@@ -153,7 +147,7 @@ class EntityEdge(Edge):
  MATCH (source:Entity {uuid: $source_uuid})
  MATCH (target:Entity {uuid: $target_uuid})
  MERGE (source)-[r:RELATES_TO {uuid: $uuid}]->(target)
- SET r = {uuid: $uuid, name: $name, fact: $fact, fact_embedding: $fact_embedding,
+ SET r = {uuid: $uuid, name: $name, group_id: $group_id, fact: $fact, fact_embedding: $fact_embedding,
  episodes: $episodes, created_at: $created_at, expired_at: $expired_at,
  valid_at: $valid_at, invalid_at: $invalid_at}
  RETURN r.uuid AS uuid""",
@@ -161,6 +155,7 @@ class EntityEdge(Edge):
  target_uuid=self.target_node_uuid,
  uuid=self.uuid,
  name=self.name,
+ group_id=self.group_id,
  fact=self.fact,
  fact_embedding=self.fact_embedding,
  episodes=self.episodes,
@@ -198,6 +193,7 @@ class EntityEdge(Edge):
  m.uuid AS target_node_uuid,
  e.created_at AS created_at,
  e.name AS name,
+ e.group_id AS group_id,
  e.fact AS fact,
  e.fact_embedding AS fact_embedding,
  e.episodes AS episodes,
@@ -208,25 +204,36 @@ class EntityEdge(Edge):
  uuid=uuid,
  )

- edges: list[EntityEdge] = []
-
- for record in records:
- edges.append(
- EntityEdge(
- uuid=record['uuid'],
- source_node_uuid=record['source_node_uuid'],
- target_node_uuid=record['target_node_uuid'],
- fact=record['fact'],
- name=record['name'],
- episodes=record['episodes'],
- fact_embedding=record['fact_embedding'],
- created_at=record['created_at'].to_native(),
- expired_at=parse_db_date(record['expired_at']),
- valid_at=parse_db_date(record['valid_at']),
- invalid_at=parse_db_date(record['invalid_at']),
- )
- )
+ edges = [get_entity_edge_from_record(record) for record in records]

  logger.info(f'Found Edge: {uuid}')

  return edges[0]
+
+
+ # Edge helpers
+ def get_episodic_edge_from_record(record: Any) -> EpisodicEdge:
+ return EpisodicEdge(
+ uuid=record['uuid'],
+ group_id=record['group_id'],
+ source_node_uuid=record['source_node_uuid'],
+ target_node_uuid=record['target_node_uuid'],
+ created_at=record['created_at'].to_native(),
+ )
+
+
+ def get_entity_edge_from_record(record: Any) -> EntityEdge:
+ return EntityEdge(
+ uuid=record['uuid'],
+ source_node_uuid=record['source_node_uuid'],
+ target_node_uuid=record['target_node_uuid'],
+ fact=record['fact'],
+ name=record['name'],
+ group_id=record['group_id'],
+ episodes=record['episodes'],
+ fact_embedding=record['fact_embedding'],
+ created_at=record['created_at'].to_native(),
+ expired_at=parse_db_date(record['expired_at']),
+ valid_at=parse_db_date(record['valid_at']),
+ invalid_at=parse_db_date(record['invalid_at']),
+ )
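
Because group_id is declared on the Edge base model without a default, edges constructed by hand now need it passed explicitly (None opts the edge out of any partition). A minimal sketch using placeholder uuids, not part of the diff:

from datetime import datetime, timezone
from uuid import uuid4

from graphiti_core.edges import EpisodicEdge

# group_id must now be supplied explicitly; None keeps pre-0.2.3 behaviour.
edge = EpisodicEdge(
    group_id=None,
    source_node_uuid=uuid4().hex,  # placeholder episode uuid
    target_node_uuid=uuid4().hex,  # placeholder entity uuid
    created_at=datetime.now(timezone.utc),
)
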
graphiti_core/graphiti.py CHANGED
@@ -18,7 +18,6 @@ import asyncio
  import logging
  from datetime import datetime
  from time import time
- from typing import Callable

  from dotenv import load_dotenv
  from neo4j import AsyncGraphDatabase
@@ -120,7 +119,7 @@ class Graphiti:

  Parameters
  ----------
- None
+ self

  Returns
  -------
@@ -151,7 +150,7 @@ class Graphiti:

  Parameters
  ----------
- None
+ self

  Returns
  -------
@@ -178,6 +177,7 @@ class Graphiti:
  self,
  reference_time: datetime,
  last_n: int = EPISODE_WINDOW_LEN,
+ group_ids: list[str | None] | None = None,
  ) -> list[EpisodicNode]:
  """
  Retrieve the last n episodic nodes from the graph.
@@ -191,6 +191,8 @@
  The reference time to retrieve episodes before.
  last_n : int, optional
  The number of episodes to retrieve. Defaults to EPISODE_WINDOW_LEN.
+ group_ids : list[str | None], optional
+ The group ids to return data from.

  Returns
  -------
@@ -202,7 +204,7 @@
  The actual retrieval is performed by the `retrieve_episodes` function
  from the `graphiti_core.utils` module.
  """
- return await retrieve_episodes(self.driver, reference_time, last_n)
+ return await retrieve_episodes(self.driver, reference_time, last_n, group_ids)

  async def add_episode(
  self,
@@ -211,8 +213,8 @@
  source_description: str,
  reference_time: datetime,
  source: EpisodeType = EpisodeType.message,
- success_callback: Callable | None = None,
- error_callback: Callable | None = None,
+ group_id: str | None = None,
+ uuid: str | None = None,
  ):
  """
  Process an episode and update the graph.
@@ -232,10 +234,10 @@
  The reference time for the episode.
  source : EpisodeType, optional
  The type of the episode. Defaults to EpisodeType.message.
- success_callback : Callable | None, optional
- A callback function to be called upon successful processing.
- error_callback : Callable | None, optional
- A callback function to be called if an error occurs during processing.
+ group_id : str | None
+ An id for the graph partition the episode is a part of.
+ uuid : str | None
+ Optional uuid of the episode.

  Returns
  -------
@@ -266,9 +268,12 @@
  embedder = self.llm_client.get_embedder()
  now = datetime.now()

- previous_episodes = await self.retrieve_episodes(reference_time, last_n=3)
+ previous_episodes = await self.retrieve_episodes(
+ reference_time, last_n=3, group_ids=[group_id]
+ )
  episode = EpisodicNode(
  name=name,
+ group_id=group_id,
  labels=[],
  source=source,
  content=episode_body,
@@ -276,6 +281,7 @@
  created_at=now,
  valid_at=reference_time,
  )
+ episode.uuid = uuid if uuid is not None else episode.uuid

  # Extract entities as nodes

@@ -299,7 +305,9 @@

  (mentioned_nodes, uuid_map), extracted_edges = await asyncio.gather(
  resolve_extracted_nodes(self.llm_client, extracted_nodes, existing_nodes_lists),
- extract_edges(self.llm_client, episode, extracted_nodes, previous_episodes),
+ extract_edges(
+ self.llm_client, episode, extracted_nodes, previous_episodes, group_id
+ ),
  )
  logger.info(f'Adjusted mentioned nodes: {[(n.name, n.uuid) for n in mentioned_nodes]}')
  nodes.extend(mentioned_nodes)
@@ -388,11 +396,7 @@

  logger.info(f'Resolved edges: {[(e.name, e.uuid) for e in resolved_edges]}')

- episodic_edges: list[EpisodicEdge] = build_episodic_edges(
- mentioned_nodes,
- episode,
- now,
- )
+ episodic_edges: list[EpisodicEdge] = build_episodic_edges(mentioned_nodes, episode, now)

  logger.info(f'Built episodic edges: {episodic_edges}')

@@ -405,18 +409,10 @@
  end = time()
  logger.info(f'Completed add_episode in {(end - start) * 1000} ms')

- if success_callback:
- await success_callback(episode)
  except Exception as e:
- if error_callback:
- await error_callback(episode, e)
- else:
- raise e
+ raise e

- async def add_episode_bulk(
- self,
- bulk_episodes: list[RawEpisode],
- ):
+ async def add_episode_bulk(self, bulk_episodes: list[RawEpisode], group_id: str | None):
  """
  Process multiple episodes in bulk and update the graph.

@@ -427,6 +423,8 @@
  ----------
  bulk_episodes : list[RawEpisode]
  A list of RawEpisode objects to be processed and added to the graph.
+ group_id : str | None
+ An id for the graph partition the episode is a part of.

  Returns
  -------
@@ -463,6 +461,7 @@
  source=episode.source,
  content=episode.content,
  source_description=episode.source_description,
+ group_id=group_id,
  created_at=now,
  valid_at=episode.reference_time,
  )
@@ -527,7 +526,13 @@
  except Exception as e:
  raise e

- async def search(self, query: str, center_node_uuid: str | None = None, num_results=10):
+ async def search(
+ self,
+ query: str,
+ center_node_uuid: str | None = None,
+ group_ids: list[str | None] | None = None,
+ num_results=10,
+ ):
  """
  Perform a hybrid search on the knowledge graph.

@@ -540,6 +545,8 @@
  The search query string.
  center_node_uuid: str, optional
  Facts will be reranked based on proximity to this node
+ group_ids : list[str | None] | None, optional
+ The graph partitions to return data from.
  num_results : int, optional
  The maximum number of results to return. Defaults to 10.

@@ -562,6 +569,7 @@
  num_episodes=0,
  num_edges=num_results,
  num_nodes=0,
+ group_ids=group_ids,
  search_methods=[SearchMethod.bm25, SearchMethod.cosine_similarity],
  reranker=reranker,
  )
@@ -590,7 +598,10 @@
  )

  async def get_nodes_by_query(
- self, query: str, limit: int = RELEVANT_SCHEMA_LIMIT
+ self,
+ query: str,
+ group_ids: list[str | None] | None = None,
+ limit: int = RELEVANT_SCHEMA_LIMIT,
  ) -> list[EntityNode]:
  """
  Retrieve nodes from the graph database based on a text query.
@@ -602,6 +613,8 @@
  ----------
  query : str
  The text query to search for in the graph.
+ group_ids : list[str | None] | None, optional
+ The graph partitions to return data from.
  limit : int | None, optional
  The maximum number of results to return per search method.
  If None, a default limit will be applied.
@@ -626,5 +639,7 @@
  """
  embedder = self.llm_client.get_embedder()
  query_embedding = await generate_embedding(embedder, query)
- relevant_nodes = await hybrid_node_search([query], [query_embedding], self.driver, limit)
+ relevant_nodes = await hybrid_node_search(
+ [query], [query_embedding], self.driver, group_ids, limit
+ )
  return relevant_nodes
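
The removed success/error callbacks are replaced by the group_id and uuid parameters, and search/get_nodes_by_query gain a group_ids filter. A minimal end-to-end sketch, not part of the diff, assuming the top-level Graphiti export used in the project README, a reachable Neo4j instance, and the OpenAI credentials the library already requires; the connection values and group name are illustrative:

import asyncio
from datetime import datetime, timezone

from graphiti_core import Graphiti
from graphiti_core.nodes import EpisodeType


async def main():
    client = Graphiti('bolt://localhost:7687', 'neo4j', 'password')
    await client.build_indices_and_constraints()  # one-time setup, now including group_id indexes

    # Episodes are tagged with a graph partition instead of firing callbacks.
    await client.add_episode(
        name='tenant-42-note',
        episode_body='Alice moved to the Berlin office.',
        source_description='CRM note',
        reference_time=datetime.now(timezone.utc),
        source=EpisodeType.message,
        group_id='tenant-42',
    )

    # Hybrid search can now be scoped to one or more partitions.
    edges = await client.search('Where does Alice work?', group_ids=['tenant-42'])
    print([edge.fact for edge in edges])


asyncio.run(main())
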
graphiti_core/nodes.py CHANGED
@@ -19,10 +19,10 @@ from abc import ABC, abstractmethod
  from datetime import datetime
  from enum import Enum
  from time import time
+ from typing import Any
  from uuid import uuid4

  from neo4j import AsyncDriver
- from openai import OpenAI
  from pydantic import BaseModel, Field

  from graphiti_core.llm_client.config import EMBEDDING_DIM
@@ -69,6 +69,7 @@ class EpisodeType(Enum):
  class Node(BaseModel, ABC):
  uuid: str = Field(default_factory=lambda: uuid4().hex)
  name: str = Field(description='name of the node')
+ group_id: str | None = Field(description='partition of the graph')
  labels: list[str] = Field(default_factory=list)
  created_at: datetime = Field(default_factory=lambda: datetime.now())

@@ -106,11 +107,12 @@ class EpisodicNode(Node):
  result = await driver.execute_query(
  """
  MERGE (n:Episodic {uuid: $uuid})
- SET n = {uuid: $uuid, name: $name, source_description: $source_description, source: $source, content: $content,
+ SET n = {uuid: $uuid, name: $name, group_id: $group_id, source_description: $source_description, source: $source, content: $content,
  entity_edges: $entity_edges, created_at: $created_at, valid_at: $valid_at}
  RETURN n.uuid AS uuid""",
  uuid=self.uuid,
  name=self.name,
+ group_id=self.group_id,
  source_description=self.source_description,
  content=self.content,
  entity_edges=self.entity_edges,
@@ -141,29 +143,19 @@ class EpisodicNode(Node):
  records, _, _ = await driver.execute_query(
  """
  MATCH (e:Episodic {uuid: $uuid})
- RETURN e.content as content,
- e.created_at as created_at,
- e.valid_at as valid_at,
- e.uuid as uuid,
- e.name as name,
- e.source_description as source_description,
- e.source as source
+ RETURN e.content AS content,
+ e.created_at AS created_at,
+ e.valid_at AS valid_at,
+ e.uuid AS uuid,
+ e.name AS name,
+ e.group_id AS group_id
+ e.source_description AS source_description,
+ e.source AS source
  """,
  uuid=uuid,
  )

- episodes = [
- EpisodicNode(
- content=record['content'],
- created_at=record['created_at'].to_native().timestamp(),
- valid_at=(record['valid_at'].to_native()),
- uuid=record['uuid'],
- source=EpisodeType.from_str(record['source']),
- name=record['name'],
- source_description=record['source_description'],
- )
- for record in records
- ]
+ episodes = [get_episodic_node_from_record(record) for record in records]

  logger.info(f'Found Node: {uuid}')

@@ -174,10 +166,6 @@ class EntityNode(Node):
  name_embedding: list[float] | None = Field(default=None, description='embedding of the name')
  summary: str = Field(description='regional summary of surrounding edges', default_factory=str)

- async def update_summary(self, driver: AsyncDriver): ...
-
- async def refresh_summary(self, driver: AsyncDriver, llm_client: OpenAI): ...
-
  async def generate_name_embedding(self, embedder, model='text-embedding-3-small'):
  start = time()
  text = self.name.replace('\n', ' ')
@@ -192,10 +180,11 @@
  result = await driver.execute_query(
  """
  MERGE (n:Entity {uuid: $uuid})
- SET n = {uuid: $uuid, name: $name, name_embedding: $name_embedding, summary: $summary, created_at: $created_at}
+ SET n = {uuid: $uuid, name: $name, name_embedding: $name_embedding, group_id: $group_id, summary: $summary, created_at: $created_at}
  RETURN n.uuid AS uuid""",
  uuid=self.uuid,
  name=self.name,
+ group_id=self.group_id,
  summary=self.summary,
  name_embedding=self.name_embedding,
  created_at=self.created_at,
@@ -227,25 +216,14 @@
  n.uuid As uuid,
  n.name AS name,
  n.name_embedding AS name_embedding,
+ n.group_id AS group_id
  n.created_at AS created_at,
  n.summary AS summary
  """,
  uuid=uuid,
  )

- nodes: list[EntityNode] = []
-
- for record in records:
- nodes.append(
- EntityNode(
- uuid=record['uuid'],
- name=record['name'],
- name_embedding=record['name_embedding'],
- labels=['Entity'],
- created_at=record['created_at'].to_native(),
- summary=record['summary'],
- )
- )
+ nodes = [get_entity_node_from_record(record) for record in records]

  logger.info(f'Found Node: {uuid}')

@@ -253,3 +231,26 @@


  # Node helpers
+ def get_episodic_node_from_record(record: Any) -> EpisodicNode:
+ return EpisodicNode(
+ content=record['content'],
+ created_at=record['created_at'].to_native().timestamp(),
+ valid_at=(record['valid_at'].to_native()),
+ uuid=record['uuid'],
+ group_id=record['group_id'],
+ source=EpisodeType.from_str(record['source']),
+ name=record['name'],
+ source_description=record['source_description'],
+ )
+
+
+ def get_entity_node_from_record(record: Any) -> EntityNode:
+ return EntityNode(
+ uuid=record['uuid'],
+ name=record['name'],
+ group_id=record['group_id'],
+ name_embedding=record['name_embedding'],
+ labels=['Entity'],
+ created_at=record['created_at'].to_native(),
+ summary=record['summary'],
+ )
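
As with edges, group_id on the Node base model has no default, so nodes constructed directly must pass it, and records loaded from Neo4j now go through the new get_*_node_from_record helpers instead of inline constructors. A minimal sketch, not part of the diff:

from graphiti_core.nodes import EntityNode

# None keeps the node outside any partition; a string scopes it to that group.
alice = EntityNode(name='Alice', group_id='tenant-42', labels=['Entity'])
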
graphiti_core/search/search.py CHANGED
@@ -52,6 +52,7 @@ class SearchConfig(BaseModel):
  num_edges: int = Field(default=10)
  num_nodes: int = Field(default=10)
  num_episodes: int = EPISODE_WINDOW_LEN
+ group_ids: list[str | None] | None
  search_methods: list[SearchMethod]
  reranker: Reranker | None

@@ -83,7 +84,9 @@ async def hybrid_search(
  nodes.extend(await get_mentioned_nodes(driver, episodes))

  if SearchMethod.bm25 in config.search_methods:
- text_search = await edge_fulltext_search(driver, query, None, None, 2 * config.num_edges)
+ text_search = await edge_fulltext_search(
+ driver, query, None, None, config.group_ids, 2 * config.num_edges
+ )
  search_results.append(text_search)

  if SearchMethod.cosine_similarity in config.search_methods:
@@ -95,7 +98,7 @@
  )

  similarity_search = await edge_similarity_search(
- driver, search_vector, None, None, 2 * config.num_edges
+ driver, search_vector, None, None, config.group_ids, 2 * config.num_edges
  )
  search_results.append(similarity_search)

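
The new group_ids field on SearchConfig is declared without a default, so callers building a config by hand must pass it (Graphiti.search fills it in from its own group_ids argument). A hedged sketch, assuming SearchConfig and SearchMethod are importable from graphiti_core.search.search as the hunks suggest; values are illustrative:

from graphiti_core.search.search import SearchConfig, SearchMethod

config = SearchConfig(
    num_episodes=0,
    num_edges=10,
    num_nodes=0,
    group_ids=['tenant-42'],  # required field; None is also accepted
    search_methods=[SearchMethod.bm25, SearchMethod.cosine_similarity],
    reranker=None,
)
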
graphiti_core/search/search_utils.py CHANGED
@@ -3,13 +3,11 @@ import logging
  import re
  from collections import defaultdict
  from time import time
- from typing import Any

  from neo4j import AsyncDriver, Query

- from graphiti_core.edges import EntityEdge
- from graphiti_core.helpers import parse_db_date
- from graphiti_core.nodes import EntityNode, EpisodicNode
+ from graphiti_core.edges import EntityEdge, get_entity_edge_from_record
+ from graphiti_core.nodes import EntityNode, EpisodicNode, get_entity_node_from_record

  logger = logging.getLogger(__name__)

@@ -23,6 +21,7 @@ async def get_mentioned_nodes(driver: AsyncDriver, episodes: list[EpisodicNode])
  MATCH (episode:Episodic)-[:MENTIONS]->(n:Entity) WHERE episode.uuid IN $uuids
  RETURN DISTINCT
  n.uuid As uuid,
+ n.group_id AS group_id,
  n.name AS name,
  n.name_embedding AS name_embedding
  n.created_at AS created_at,
@@ -31,86 +30,29 @@ async def get_mentioned_nodes(driver: AsyncDriver, episodes: list[EpisodicNode])
  uuids=episode_uuids,
  )

- nodes: list[EntityNode] = []
-
- for record in records:
- nodes.append(
- EntityNode(
- uuid=record['uuid'],
- name=record['name'],
- name_embedding=record['name_embedding'],
- labels=['Entity'],
- created_at=record['created_at'].to_native(),
- summary=record['summary'],
- )
- )
+ nodes = [get_entity_node_from_record(record) for record in records]

  return nodes


- async def bfs(node_ids: list[str], driver: AsyncDriver):
- records, _, _ = await driver.execute_query(
- """
- MATCH (n WHERE n.uuid in $node_ids)-[r]->(m)
- RETURN DISTINCT
- n.uuid AS source_node_uuid,
- n.name AS source_name,
- n.summary AS source_summary,
- m.uuid AS target_node_uuid,
- m.name AS target_name,
- m.summary AS target_summary,
- r.uuid AS uuid,
- r.created_at AS created_at,
- r.name AS name,
- r.fact AS fact,
- r.fact_embedding AS fact_embedding,
- r.episodes AS episodes,
- r.expired_at AS expired_at,
- r.valid_at AS valid_at,
- r.invalid_at AS invalid_at
-
- """,
- node_ids=node_ids,
- )
-
- context: dict[str, Any] = {}
-
- for record in records:
- n_uuid = record['source_node_uuid']
- if n_uuid in context:
- context[n_uuid]['facts'].append(record['fact'])
- else:
- context[n_uuid] = {
- 'name': record['source_name'],
- 'summary': record['source_summary'],
- 'facts': [record['fact']],
- }
-
- m_uuid = record['target_node_uuid']
- if m_uuid not in context:
- context[m_uuid] = {
- 'name': record['target_name'],
- 'summary': record['target_summary'],
- 'facts': [],
- }
- logger.info(f'bfs search returned context: {context}')
- return context
-
-
  async def edge_similarity_search(
  driver: AsyncDriver,
  search_vector: list[float],
  source_node_uuid: str | None,
  target_node_uuid: str | None,
+ group_ids: list[str | None] | None = None,
  limit: int = RELEVANT_SCHEMA_LIMIT,
  ) -> list[EntityEdge]:
+ group_ids = group_ids if group_ids is not None else [None]
  # vector similarity search over embedded facts
  query = Query("""
  CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector)
  YIELD relationship AS rel, score
  MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid})
+ WHERE r.group_id IN $group_ids
  RETURN
  r.uuid AS uuid,
+ r.group_id AS group_id,
  n.uuid AS source_node_uuid,
  m.uuid AS target_node_uuid,
  r.created_at AS created_at,
@@ -129,8 +71,10 @@ async def edge_similarity_search(
  CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector)
  YIELD relationship AS rel, score
  MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity)
+ WHERE r.group_id IN $group_ids
  RETURN
  r.uuid AS uuid,
+ r.group_id AS group_id,
  n.uuid AS source_node_uuid,
  m.uuid AS target_node_uuid,
  r.created_at AS created_at,
@@ -148,8 +92,10 @@
  CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector)
  YIELD relationship AS rel, score
  MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid})
+ WHERE r.group_id IN $group_ids
  RETURN
  r.uuid AS uuid,
+ r.group_id AS group_id,
  n.uuid AS source_node_uuid,
  m.uuid AS target_node_uuid,
  r.created_at AS created_at,
@@ -167,8 +113,10 @@
  CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector)
  YIELD relationship AS rel, score
  MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity)
+ WHERE r.group_id IN $group_ids
  RETURN
  r.uuid AS uuid,
+ r.group_id AS group_id,
  n.uuid AS source_node_uuid,
  m.uuid AS target_node_uuid,
  r.created_at AS created_at,
@@ -187,41 +135,32 @@ async def edge_similarity_search(
  search_vector=search_vector,
  source_uuid=source_node_uuid,
  target_uuid=target_node_uuid,
+ group_ids=group_ids,
  limit=limit,
  )

- edges: list[EntityEdge] = []
-
- for record in records:
- edge = EntityEdge(
- uuid=record['uuid'],
- source_node_uuid=record['source_node_uuid'],
- target_node_uuid=record['target_node_uuid'],
- fact=record['fact'],
- name=record['name'],
- episodes=record['episodes'],
- fact_embedding=record['fact_embedding'],
- created_at=record['created_at'].to_native(),
- expired_at=parse_db_date(record['expired_at']),
- valid_at=parse_db_date(record['valid_at']),
- invalid_at=parse_db_date(record['invalid_at']),
- )
-
- edges.append(edge)
+ edges = [get_entity_edge_from_record(record) for record in records]

  return edges


  async def entity_similarity_search(
- search_vector: list[float], driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT
+ search_vector: list[float],
+ driver: AsyncDriver,
+ group_ids: list[str | None] | None = None,
+ limit=RELEVANT_SCHEMA_LIMIT,
  ) -> list[EntityNode]:
+ group_ids = group_ids if group_ids is not None else [None]
+
  # vector similarity search over entity names
  records, _, _ = await driver.execute_query(
  """
  CALL db.index.vector.queryNodes("name_embedding", $limit, $search_vector)
  YIELD node AS n, score
+ MATCH (n WHERE n.group_id IN $group_ids)
  RETURN
- n.uuid As uuid,
+ n.uuid As uuid,
+ n.group_id AS group_id,
  n.name AS name,
  n.name_embedding AS name_embedding,
  n.created_at AS created_at,
@@ -229,58 +168,44 @@ async def entity_similarity_search(
  ORDER BY score DESC
  """,
  search_vector=search_vector,
+ group_ids=group_ids,
  limit=limit,
  )
- nodes: list[EntityNode] = []
-
- for record in records:
- nodes.append(
- EntityNode(
- uuid=record['uuid'],
- name=record['name'],
- name_embedding=record['name_embedding'],
- labels=['Entity'],
- created_at=record['created_at'].to_native(),
- summary=record['summary'],
- )
- )
+ nodes = [get_entity_node_from_record(record) for record in records]

  return nodes


  async def entity_fulltext_search(
- query: str, driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT
+ query: str,
+ driver: AsyncDriver,
+ group_ids: list[str | None] | None = None,
+ limit=RELEVANT_SCHEMA_LIMIT,
  ) -> list[EntityNode]:
+ group_ids = group_ids if group_ids is not None else [None]
+
  # BM25 search to get top nodes
  fuzzy_query = re.sub(r'[^\w\s]', '', query) + '~'
  records, _, _ = await driver.execute_query(
  """
- CALL db.index.fulltext.queryNodes("name_and_summary", $query) YIELD node, score
+ CALL db.index.fulltext.queryNodes("name_and_summary", $query)
+ YIELD node AS n, score
+ MATCH (n WHERE n.group_id in $group_ids)
  RETURN
- node.uuid AS uuid,
- node.name AS name,
- node.name_embedding AS name_embedding,
- node.created_at AS created_at,
- node.summary AS summary
+ n.uuid AS uuid,
+ n.group_id AS group_id,
+ n.name AS name,
+ n.name_embedding AS name_embedding,
+ n.created_at AS created_at,
+ n.summary AS summary
  ORDER BY score DESC
  LIMIT $limit
  """,
  query=fuzzy_query,
+ group_ids=group_ids,
  limit=limit,
  )
- nodes: list[EntityNode] = []
-
- for record in records:
- nodes.append(
- EntityNode(
- uuid=record['uuid'],
- name=record['name'],
- name_embedding=record['name_embedding'],
- labels=['Entity'],
- created_at=record['created_at'].to_native(),
- summary=record['summary'],
- )
- )
+ nodes = [get_entity_node_from_record(record) for record in records]

  return nodes

@@ -290,15 +215,20 @@ async def edge_fulltext_search(
  query: str,
  source_node_uuid: str | None,
  target_node_uuid: str | None,
+ group_ids: list[str | None] | None = None,
  limit=RELEVANT_SCHEMA_LIMIT,
  ) -> list[EntityEdge]:
+ group_ids = group_ids if group_ids is not None else [None]
+
  # fulltext search over facts
  cypher_query = Query("""
  CALL db.index.fulltext.queryRelationships("name_and_fact", $query)
  YIELD relationship AS rel, score
  MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid})
+ WHERE r.group_id IN $group_ids
  RETURN
  r.uuid AS uuid,
+ r.group_id AS group_id,
  n.uuid AS source_node_uuid,
  m.uuid AS target_node_uuid,
  r.created_at AS created_at,
@@ -317,8 +247,10 @@
  CALL db.index.fulltext.queryRelationships("name_and_fact", $query)
  YIELD relationship AS rel, score
  MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity)
+ WHERE r.group_id IN $group_ids
  RETURN
  r.uuid AS uuid,
+ r.group_id AS group_id,
  n.uuid AS source_node_uuid,
  m.uuid AS target_node_uuid,
  r.created_at AS created_at,
@@ -335,9 +267,11 @@
  cypher_query = Query("""
  CALL db.index.fulltext.queryRelationships("name_and_fact", $query)
  YIELD relationship AS rel, score
- MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid})
+ MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid})
+ WHERE r.group_id IN $group_ids
  RETURN
  r.uuid AS uuid,
+ r.group_id AS group_id,
  n.uuid AS source_node_uuid,
  m.uuid AS target_node_uuid,
  r.created_at AS created_at,
@@ -354,9 +288,11 @@
  cypher_query = Query("""
  CALL db.index.fulltext.queryRelationships("name_and_fact", $query)
  YIELD relationship AS rel, score
- MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity)
+ MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity)
+ WHERE r.group_id IN $group_ids
  RETURN
  r.uuid AS uuid,
+ r.group_id AS group_id,
  n.uuid AS source_node_uuid,
  m.uuid AS target_node_uuid,
  r.created_at AS created_at,
@@ -377,27 +313,11 @@
  query=fuzzy_query,
  source_uuid=source_node_uuid,
  target_uuid=target_node_uuid,
+ group_ids=group_ids,
  limit=limit,
  )

- edges: list[EntityEdge] = []
-
- for record in records:
- edge = EntityEdge(
- uuid=record['uuid'],
- source_node_uuid=record['source_node_uuid'],
- target_node_uuid=record['target_node_uuid'],
- fact=record['fact'],
- name=record['name'],
- episodes=record['episodes'],
- fact_embedding=record['fact_embedding'],
- created_at=record['created_at'].to_native(),
- expired_at=parse_db_date(record['expired_at']),
- valid_at=parse_db_date(record['valid_at']),
- invalid_at=parse_db_date(record['invalid_at']),
- )
-
- edges.append(edge)
+ edges = [get_entity_edge_from_record(record) for record in records]

  return edges

@@ -406,6 +326,7 @@ async def hybrid_node_search(
  queries: list[str],
  embeddings: list[list[float]],
  driver: AsyncDriver,
+ group_ids: list[str | None] | None = None,
  limit: int = RELEVANT_SCHEMA_LIMIT,
  ) -> list[EntityNode]:
  """
@@ -422,6 +343,8 @@
  A list of embedding vectors corresponding to the queries. If empty only fulltext search is performed.
  driver : AsyncDriver
  The Neo4j driver instance for database operations.
+ group_ids : list[str] | None, optional
+ The list of group ids to retrieve nodes from.
  limit : int | None, optional
  The maximum number of results to return per search method. If None, a default limit will be applied.

@@ -448,8 +371,8 @@

  results: list[list[EntityNode]] = list(
  await asyncio.gather(
- *[entity_fulltext_search(q, driver, 2 * limit) for q in queries],
- *[entity_similarity_search(e, driver, 2 * limit) for e in embeddings],
+ *[entity_fulltext_search(q, driver, group_ids, 2 * limit) for q in queries],
+ *[entity_similarity_search(e, driver, group_ids, 2 * limit) for e in embeddings],
  )
  )

@@ -500,6 +423,7 @@ async def get_relevant_nodes(
  [node.name for node in nodes],
  [node.name_embedding for node in nodes if node.name_embedding is not None],
  driver,
+ [node.group_id for node in nodes],
  )
  return relevant_nodes

@@ -518,13 +442,20 @@ async def get_relevant_edges(
  results = await asyncio.gather(
  *[
  edge_similarity_search(
- driver, edge.fact_embedding, source_node_uuid, target_node_uuid, limit
+ driver,
+ edge.fact_embedding,
+ source_node_uuid,
+ target_node_uuid,
+ [edge.group_id],
+ limit,
  )
  for edge in edges
  if edge.fact_embedding is not None
  ],
  *[
- edge_fulltext_search(driver, edge.fact, source_node_uuid, target_node_uuid, limit)
+ edge_fulltext_search(
+ driver, edge.fact, source_node_uuid, target_node_uuid, [edge.group_id], limit
+ )
  for edge in edges
  ],
  )
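
The lower-level helpers take the same filter; for example, hybrid_node_search can be called with only fulltext queries (per its docstring, an empty embeddings list skips similarity search) plus a list of group ids. A minimal sketch, not part of the diff, with an illustrative driver and group name:

from neo4j import AsyncGraphDatabase

from graphiti_core.search.search_utils import hybrid_node_search


async def find_alice():
    driver = AsyncGraphDatabase.driver('bolt://localhost:7687', auth=('neo4j', 'password'))
    try:
        # Fulltext-only lookup restricted to a single partition.
        return await hybrid_node_search(['Alice'], [], driver, group_ids=['tenant-42'])
    finally:
        await driver.close()
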
graphiti_core/utils/bulk_utils.py CHANGED
@@ -17,6 +17,7 @@ limitations under the License.
  import asyncio
  import logging
  import typing
+ from collections import defaultdict
  from datetime import datetime
  from math import ceil

@@ -42,7 +43,6 @@ from graphiti_core.utils.maintenance.node_operations import (
  extract_nodes,
  )
  from graphiti_core.utils.maintenance.temporal_operations import extract_edge_dates
- from graphiti_core.utils.utils import chunk_edges_by_nodes

  logger = logging.getLogger(__name__)

@@ -62,7 +62,9 @@ async def retrieve_previous_episodes_bulk(
  ) -> list[tuple[EpisodicNode, list[EpisodicNode]]]:
  previous_episodes_list = await asyncio.gather(
  *[
- retrieve_episodes(driver, episode.valid_at, last_n=EPISODE_WINDOW_LEN)
+ retrieve_episodes(
+ driver, episode.valid_at, last_n=EPISODE_WINDOW_LEN, group_ids=[episode.group_id]
+ )
  for episode in episodes
  ]
  )
@@ -90,7 +92,13 @@ async def extract_nodes_and_edges_bulk(

  extracted_edges_bulk = await asyncio.gather(
  *[
- extract_edges(llm_client, episode, extracted_nodes_bulk[i], previous_episodes_list[i])
+ extract_edges(
+ llm_client,
+ episode,
+ extracted_nodes_bulk[i],
+ previous_episodes_list[i],
+ episode.group_id,
+ )
  for i, episode in enumerate(episodes)
  ]
  )
@@ -343,3 +351,23 @@ async def extract_edge_dates_bulk(
  edge.expired_at = datetime.now()

  return edges
+
+
+ def chunk_edges_by_nodes(edges: list[EntityEdge]) -> list[list[EntityEdge]]:
+ # We only want to dedupe edges that are between the same pair of nodes
+ # We build a map of the edges based on their source and target nodes.
+ edge_chunk_map: dict[str, list[EntityEdge]] = defaultdict(list)
+ for edge in edges:
+ # We drop loop edges
+ if edge.source_node_uuid == edge.target_node_uuid:
+ continue
+
+ # Keep the order of the two nodes consistent, we want to be direction agnostic during edge resolution
+ pointers = [edge.source_node_uuid, edge.target_node_uuid]
+ pointers.sort()
+
+ edge_chunk_map[pointers[0] + pointers[1]].append(edge)
+
+ edge_chunks = [chunk for chunk in edge_chunk_map.values()]
+
+ return edge_chunks
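
chunk_edges_by_nodes (moved here from the deleted utils/utils.py) groups candidate edges by the unordered pair of nodes they connect and drops self-loops, so deduplication only ever compares edges between the same two nodes. A small behavioural sketch, not part of the diff, assuming the optional EntityEdge fields keep their defaults:

from datetime import datetime, timezone

from graphiti_core.edges import EntityEdge
from graphiti_core.utils.bulk_utils import chunk_edges_by_nodes


def make_edge(source: str, target: str, fact: str) -> EntityEdge:
    # Illustration-only helper; optional fields (embeddings, episodes, dates) are assumed to default.
    return EntityEdge(
        source_node_uuid=source,
        target_node_uuid=target,
        name='RELATES_TO',
        fact=fact,
        group_id=None,
        created_at=datetime.now(timezone.utc),
    )


edges = [
    make_edge('a', 'b', 'a knows b'),
    make_edge('b', 'a', 'b knows a'),  # same pair, opposite direction: lands in the same chunk
    make_edge('a', 'a', 'self loop'),  # dropped
    make_edge('a', 'c', 'a knows c'),
]
print([len(chunk) for chunk in chunk_edges_by_nodes(edges)])  # -> [2, 1]
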
graphiti_core/utils/maintenance/edge_operations.py CHANGED
@@ -37,15 +37,15 @@ def build_episodic_edges(
  episode: EpisodicNode,
  created_at: datetime,
  ) -> List[EpisodicEdge]:
- edges: List[EpisodicEdge] = []
-
- for node in entity_nodes:
- edge = EpisodicEdge(
+ edges: List[EpisodicEdge] = [
+ EpisodicEdge(
  source_node_uuid=episode.uuid,
  target_node_uuid=node.uuid,
  created_at=created_at,
+ group_id=episode.group_id,
  )
- edges.append(edge)
+ for node in entity_nodes
+ ]

  return edges

@@ -55,6 +55,7 @@ async def extract_edges(
  episode: EpisodicNode,
  nodes: list[EntityNode],
  previous_episodes: list[EpisodicNode],
+ group_id: str | None,
  ) -> list[EntityEdge]:
  start = time()

@@ -88,6 +89,7 @@
  source_node_uuid=edge_data['source_node_uuid'],
  target_node_uuid=edge_data['target_node_uuid'],
  name=edge_data['relation_type'],
+ group_id=group_id,
  fact=edge_data['fact'],
  episodes=[episode.uuid],
  created_at=datetime.now(),
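
build_episodic_edges now copies the episode's group_id onto every MENTIONS edge it creates, which is how Graphiti.add_episode keeps episodic edges in the same partition as their episode. A minimal sketch, not part of the diff, with illustrative node and episode objects:

from datetime import datetime, timezone

from graphiti_core.nodes import EntityNode, EpisodicNode, EpisodeType
from graphiti_core.utils.maintenance.edge_operations import build_episodic_edges

now = datetime.now(timezone.utc)
episode = EpisodicNode(
    name='tenant-42-note',
    group_id='tenant-42',
    labels=[],
    source=EpisodeType.message,
    content='Alice moved to the Berlin office.',
    source_description='CRM note',
    created_at=now,
    valid_at=now,
)
alice = EntityNode(name='Alice', group_id='tenant-42', labels=['Entity'])

edges = build_episodic_edges([alice], episode, now)
print(edges[0].group_id)  # -> 'tenant-42'
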
graphiti_core/utils/maintenance/graph_data_operations.py CHANGED
@@ -34,6 +34,10 @@ async def build_indices_and_constraints(driver: AsyncDriver):
  'CREATE INDEX episode_uuid IF NOT EXISTS FOR (n:Episodic) ON (n.uuid)',
  'CREATE INDEX relation_uuid IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.uuid)',
  'CREATE INDEX mention_uuid IF NOT EXISTS FOR ()-[e:MENTIONS]-() ON (e.uuid)',
+ 'CREATE INDEX entity_group_id IF NOT EXISTS FOR (n:Entity) ON (n.group_id)',
+ 'CREATE INDEX episode_group_id IF NOT EXISTS FOR (n:Episodic) ON (n.group_id)',
+ 'CREATE INDEX relation_group_id IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.group_id)',
+ 'CREATE INDEX mention_group_id IF NOT EXISTS FOR ()-[e:MENTIONS]-() ON (e.group_id)',
  'CREATE INDEX name_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.name)',
  'CREATE INDEX created_at_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.created_at)',
  'CREATE INDEX created_at_episodic_index IF NOT EXISTS FOR (n:Episodic) ON (n.created_at)',
@@ -86,6 +90,7 @@ async def retrieve_episodes(
  driver: AsyncDriver,
  reference_time: datetime,
  last_n: int = EPISODE_WINDOW_LEN,
+ group_ids: list[str | None] | None = None,
  ) -> list[EpisodicNode]:
  """
  Retrieve the last n episodic nodes from the graph.
@@ -96,25 +101,28 @@
  less than or equal to this reference_time will be retrieved. This allows for
  querying the graph's state at a specific point in time.
  last_n (int, optional): The number of most recent episodes to retrieve, relative to the reference_time.
+ group_ids (list[str], optional): The list of group ids to return data from.

  Returns:
  list[EpisodicNode]: A list of EpisodicNode objects representing the retrieved episodes.
  """
  result = await driver.execute_query(
  """
- MATCH (e:Episodic) WHERE e.valid_at <= $reference_time
- RETURN e.content as content,
- e.created_at as created_at,
- e.valid_at as valid_at,
- e.uuid as uuid,
- e.name as name,
- e.source_description as source_description,
- e.source as source
+ MATCH (e:Episodic) WHERE e.valid_at <= $reference_time AND e.group_id in $group_ids
+ RETURN e.content AS content,
+ e.created_at AS created_at,
+ e.valid_at AS valid_at,
+ e.uuid AS uuid,
+ e.group_id AS group_id,
+ e.name AS name,
+ e.source_description AS source_description,
+ e.source AS source
  ORDER BY e.created_at DESC
  LIMIT $num_episodes
  """,
  reference_time=reference_time,
  num_episodes=last_n,
+ group_ids=group_ids,
  )
  episodes = [
  EpisodicNode(
@@ -124,6 +132,7 @@
  ),
  valid_at=(record['valid_at'].to_native()),
  uuid=record['uuid'],
+ group_id=record['group_id'],
  source=EpisodeType.from_str(record['source']),
  name=record['name'],
  source_description=record['source_description'],
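
retrieve_episodes now filters on e.group_id with an IN clause, so callers pass the partitions they want back. A minimal sketch, not part of the diff, with an illustrative driver:

from datetime import datetime, timezone

from neo4j import AsyncGraphDatabase

from graphiti_core.utils.maintenance.graph_data_operations import retrieve_episodes


async def recent_tenant_episodes():
    driver = AsyncGraphDatabase.driver('bolt://localhost:7687', auth=('neo4j', 'password'))
    try:
        # Last three episodes for one partition, as of "now".
        return await retrieve_episodes(
            driver, datetime.now(timezone.utc), last_n=3, group_ids=['tenant-42']
        )
    finally:
        await driver.close()
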
graphiti_core/utils/maintenance/node_operations.py CHANGED
@@ -85,6 +85,7 @@ async def extract_nodes(
  for node_data in extracted_node_data:
  new_node = EntityNode(
  name=node_data['name'],
+ group_id=episode.group_id,
  labels=node_data['labels'],
  summary=node_data['summary'],
  created_at=datetime.now(),
graphiti_core-0.2.2.dist-info/METADATA → graphiti_core-0.2.3.dist-info/METADATA RENAMED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: graphiti-core
- Version: 0.2.2
+ Version: 0.2.3
  Summary: A temporal graph building library
  License: Apache-2.0
  Author: Paul Paliychuk
@@ -12,11 +12,10 @@ Classifier: Programming Language :: Python :: 3.10
  Classifier: Programming Language :: Python :: 3.11
  Classifier: Programming Language :: Python :: 3.12
  Requires-Dist: diskcache (>=5.6.3,<6.0.0)
- Requires-Dist: fastapi (>=0.112.0,<0.113.0)
  Requires-Dist: neo4j (>=5.23.0,<6.0.0)
+ Requires-Dist: numpy (>=2.1.1,<3.0.0)
  Requires-Dist: openai (>=1.38.0,<2.0.0)
  Requires-Dist: pydantic (>=2.8.2,<3.0.0)
- Requires-Dist: sentence-transformers (>=3.0.1,<4.0.0)
  Requires-Dist: tenacity (<9.0.0)
  Description-Content-Type: text/markdown

graphiti_core-0.2.2.dist-info/RECORD → graphiti_core-0.2.3.dist-info/RECORD RENAMED
@@ -1,6 +1,6 @@
  graphiti_core/__init__.py,sha256=e5SWFkRiaUwfprYIeIgVIh7JDedNiloZvd3roU-0aDY,55
- graphiti_core/edges.py,sha256=Sxsqw7WZAC6YJKftMaF9t69o7HV_GM6m6ULjtLhZg0M,7484
- graphiti_core/graphiti.py,sha256=hLIDjvbdvgQPPi1-HVyiQ1gw67jUdiaKqWRBZhtxqFc,23106
+ graphiti_core/edges.py,sha256=oy_tK9YWE7_g4aQMGutymVdreiC-SsWP6ZtayEYGCFQ,7700
+ graphiti_core/graphiti.py,sha256=tUEtyBb8hQXTn_eMmVSsFVBV7AKWE22SPQihCMZtcZU,23647
  graphiti_core/helpers.py,sha256=EAeC3RrcecjiTGN2vxergN5RHTy2_jhFXA5PQVT3toU,200
  graphiti_core/llm_client/__init__.py,sha256=f4OSk82jJ70wZ2HOuQu6-RQWkkf7HIB0FCT6xOuxZkQ,154
  graphiti_core/llm_client/anthropic_client.py,sha256=C8lOLm7in_eNfOP7s8gjMM0Y99-TzKWlGaPuVGceX68,2180
@@ -9,7 +9,7 @@ graphiti_core/llm_client/config.py,sha256=d1oZ9tt7QBQlbph7v-0HjItb6otK9_-IwF8kkR
  graphiti_core/llm_client/groq_client.py,sha256=qscr5-190wBTUCBL31EAjQTLytK9AF75-y9GsVRvGJU,2206
  graphiti_core/llm_client/openai_client.py,sha256=Bkrp_mKzAxK6kgPzv1UtVUgr1ZvvJhE2H39hgAwWrsI,2211
  graphiti_core/llm_client/utils.py,sha256=H8-Kwa5SyvIYDNIas8O4bHJ6jsOL49li44VoDEMyauY,555
- graphiti_core/nodes.py,sha256=gB2HxaLHeLAo_wthSI8kRonTdz-BR_GJ4f6JMrxXd0c,8004
+ graphiti_core/nodes.py,sha256=RZnIKyu9ZzWVlbodae3Rkzlg00fQIqp5o3iGB4Ffm-M,8140
  graphiti_core/prompts/__init__.py,sha256=EA-x9xUki9l8wnu2l8ek_oNf75-do5tq5hVq7Zbv8Kw,101
  graphiti_core/prompts/dedupe_edges.py,sha256=DUNHdIudj50FAjkla4nc68tSFSD2yjmYHBw-Bb7ph20,6529
  graphiti_core/prompts/dedupe_nodes.py,sha256=BZ9S-PB9SSGjc5Oo8ivdgA6rZx3OGOFhKtwrBlQ0bm0,7269
@@ -20,18 +20,17 @@ graphiti_core/prompts/invalidate_edges.py,sha256=8SHt3iPTdmqk8A52LxgdMtI39w4USKq
  graphiti_core/prompts/lib.py,sha256=RR8f8DQfioUK5bJonMzn02pKLxJlaENv1VocpvRJ488,3532
  graphiti_core/prompts/models.py,sha256=cvx_Bv5RMFUD_5IUawYrbpOKLPHogai7_bm7YXrSz84,867
  graphiti_core/search/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- graphiti_core/search/search.py,sha256=IUqAdWub2mg-j9Mz-NacJVLequsG5fxqx2SySKnQtXA,4438
- graphiti_core/search/search_utils.py,sha256=MPzYTp0ybEZjDH92_1Bxwm7dz8CKHkTBcgPWDIXapg0,21135
+ graphiti_core/search/search.py,sha256=cr1-syRlRdijnLtbuQYWy_2G1CtAeIaz6BQ2kl_6FrY,4535
+ graphiti_core/search/search_utils.py,sha256=YeJ-M67HXPQySruwZmau3jvilFlcwf8OwfuflnSdf1Q,19355
  graphiti_core/utils/__init__.py,sha256=cJAcMnBZdHBQmWrZdU1PQ1YmaL75bhVUkyVpIPuOyns,260
- graphiti_core/utils/bulk_utils.py,sha256=xwKgHDNiGDt3-jG_YfN4vrHfG-SUxfuBnsFnBANal98,11683
+ graphiti_core/utils/bulk_utils.py,sha256=JtoYTZPCigPa3n2E43Oe7QhFZRTA_QKNGy1jVgklHag,12614
  graphiti_core/utils/maintenance/__init__.py,sha256=4b9sfxqyFZMLwxxS2lnQ6_wBr3xrJRIqfAWOidK8EK0,388
- graphiti_core/utils/maintenance/edge_operations.py,sha256=JMrMAinkGaGTzaiiCFG-HACOTnoGfJa2hhTQKhujqgM,10782
- graphiti_core/utils/maintenance/graph_data_operations.py,sha256=ggzCWezFyLC29VZBiYHvanOpSRLaPtcmbgHgcl-qHy8,5321
- graphiti_core/utils/maintenance/node_operations.py,sha256=1Iswwoqy7HDH_CQACQUq3oQKrX7cNZb1kdkSQOawj84,7956
+ graphiti_core/utils/maintenance/edge_operations.py,sha256=Xq60YlOGQKzD5qN6eahUMOiLQJiBaDNOeIiGkS8EdB0,10855
+ graphiti_core/utils/maintenance/graph_data_operations.py,sha256=-A4fPYtXIjoBBX6IDPoaU9pDcSjZGeRbRPj23W1C-l4,5951
+ graphiti_core/utils/maintenance/node_operations.py,sha256=ecBOp_reQynENFN0M69IzRPgEuBYOuPpDBwFZq5e-I4,7995
  graphiti_core/utils/maintenance/temporal_operations.py,sha256=BzfGDm96w4HcUEsaWTHUBt5S8dNmDQL1eX6AuBL-XFM,8135
  graphiti_core/utils/maintenance/utils.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- graphiti_core/utils/utils.py,sha256=LguHvEDi9JruXKWXXHaz2f4vpezdfgY-rpxjPq0dao8,1959
- graphiti_core-0.2.2.dist-info/LICENSE,sha256=KCUwCyDXuVEgmDWkozHyniRyWjnWUWjkuDHfU6o3JlA,11325
- graphiti_core-0.2.2.dist-info/METADATA,sha256=HOn2oMZZFhh5Tz4v0fNPO45AbEp4muF4QXXFhZOb45o,9184
- graphiti_core-0.2.2.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
- graphiti_core-0.2.2.dist-info/RECORD,,
+ graphiti_core-0.2.3.dist-info/LICENSE,sha256=KCUwCyDXuVEgmDWkozHyniRyWjnWUWjkuDHfU6o3JlA,11325
+ graphiti_core-0.2.3.dist-info/METADATA,sha256=o81BUoLGtzm0AhnO9MBW8yeG1_UPFN6m-0PmnFOJKis,9124
+ graphiti_core-0.2.3.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
+ graphiti_core-0.2.3.dist-info/RECORD,,
graphiti_core/utils/utils.py DELETED
@@ -1,60 +0,0 @@
- """
- Copyright 2024, Zep Software, Inc.
-
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- """
-
- import logging
- from collections import defaultdict
-
- from graphiti_core.edges import EntityEdge, EpisodicEdge
- from graphiti_core.nodes import EntityNode, EpisodicNode
-
- logger = logging.getLogger(__name__)
-
-
- def build_episodic_edges(
- entity_nodes: list[EntityNode], episode: EpisodicNode
- ) -> list[EpisodicEdge]:
- edges: list[EpisodicEdge] = []
-
- for node in entity_nodes:
- edges.append(
- EpisodicEdge(
- source_node_uuid=episode.uuid,
- target_node_uuid=node.uuid,
- created_at=episode.created_at,
- )
- )
-
- return edges
-
-
- def chunk_edges_by_nodes(edges: list[EntityEdge]) -> list[list[EntityEdge]]:
- # We only want to dedupe edges that are between the same pair of nodes
- # We build a map of the edges based on their source and target nodes.
- edge_chunk_map: dict[str, list[EntityEdge]] = defaultdict(list)
- for edge in edges:
- # We drop loop edges
- if edge.source_node_uuid == edge.target_node_uuid:
- continue
-
- # Keep the order of the two nodes consistent, we want to be direction agnostic during edge resolution
- pointers = [edge.source_node_uuid, edge.target_node_uuid]
- pointers.sort()
-
- edge_chunk_map[pointers[0] + pointers[1]].append(edge)
-
- edge_chunks = [chunk for chunk in edge_chunk_map.values()]
-
- return edge_chunks