graphiti-core 0.3.4__tar.gz → 0.3.6__tar.gz

This diff shows the changes between two publicly available versions of a package, as released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.

Potentially problematic release.


This version of graphiti-core might be problematic. Click here for more details.

Files changed (43)
  1. {graphiti_core-0.3.4 → graphiti_core-0.3.6}/PKG-INFO +1 -1
  2. {graphiti_core-0.3.4 → graphiti_core-0.3.6}/graphiti_core/edges.py +7 -9
  3. {graphiti_core-0.3.4 → graphiti_core-0.3.6}/graphiti_core/errors.py +8 -0
  4. {graphiti_core-0.3.4 → graphiti_core-0.3.6}/graphiti_core/graphiti.py +8 -8
  5. {graphiti_core-0.3.4 → graphiti_core-0.3.6}/graphiti_core/llm_client/openai_client.py +1 -1
  6. {graphiti_core-0.3.4 → graphiti_core-0.3.6}/graphiti_core/nodes.py +4 -4
  7. {graphiti_core-0.3.4 → graphiti_core-0.3.6}/graphiti_core/search/search.py +20 -10
  8. {graphiti_core-0.3.4 → graphiti_core-0.3.6}/graphiti_core/search/search_utils.py +33 -62
  9. {graphiti_core-0.3.4 → graphiti_core-0.3.6}/graphiti_core/utils/maintenance/community_operations.py +6 -6
  10. {graphiti_core-0.3.4 → graphiti_core-0.3.6}/graphiti_core/utils/maintenance/edge_operations.py +1 -1
  11. {graphiti_core-0.3.4 → graphiti_core-0.3.6}/graphiti_core/utils/maintenance/graph_data_operations.py +3 -2
  12. {graphiti_core-0.3.4 → graphiti_core-0.3.6}/pyproject.toml +2 -2
  13. {graphiti_core-0.3.4 → graphiti_core-0.3.6}/LICENSE +0 -0
  14. {graphiti_core-0.3.4 → graphiti_core-0.3.6}/README.md +0 -0
  15. {graphiti_core-0.3.4 → graphiti_core-0.3.6}/graphiti_core/__init__.py +0 -0
  16. {graphiti_core-0.3.4 → graphiti_core-0.3.6}/graphiti_core/helpers.py +0 -0
  17. {graphiti_core-0.3.4 → graphiti_core-0.3.6}/graphiti_core/llm_client/__init__.py +0 -0
  18. {graphiti_core-0.3.4 → graphiti_core-0.3.6}/graphiti_core/llm_client/anthropic_client.py +0 -0
  19. {graphiti_core-0.3.4 → graphiti_core-0.3.6}/graphiti_core/llm_client/client.py +0 -0
  20. {graphiti_core-0.3.4 → graphiti_core-0.3.6}/graphiti_core/llm_client/config.py +0 -0
  21. {graphiti_core-0.3.4 → graphiti_core-0.3.6}/graphiti_core/llm_client/errors.py +0 -0
  22. {graphiti_core-0.3.4 → graphiti_core-0.3.6}/graphiti_core/llm_client/groq_client.py +0 -0
  23. {graphiti_core-0.3.4 → graphiti_core-0.3.6}/graphiti_core/llm_client/utils.py +0 -0
  24. {graphiti_core-0.3.4 → graphiti_core-0.3.6}/graphiti_core/prompts/__init__.py +0 -0
  25. {graphiti_core-0.3.4 → graphiti_core-0.3.6}/graphiti_core/prompts/dedupe_edges.py +0 -0
  26. {graphiti_core-0.3.4 → graphiti_core-0.3.6}/graphiti_core/prompts/dedupe_nodes.py +0 -0
  27. {graphiti_core-0.3.4 → graphiti_core-0.3.6}/graphiti_core/prompts/extract_edge_dates.py +0 -0
  28. {graphiti_core-0.3.4 → graphiti_core-0.3.6}/graphiti_core/prompts/extract_edges.py +0 -0
  29. {graphiti_core-0.3.4 → graphiti_core-0.3.6}/graphiti_core/prompts/extract_nodes.py +0 -0
  30. {graphiti_core-0.3.4 → graphiti_core-0.3.6}/graphiti_core/prompts/invalidate_edges.py +0 -0
  31. {graphiti_core-0.3.4 → graphiti_core-0.3.6}/graphiti_core/prompts/lib.py +0 -0
  32. {graphiti_core-0.3.4 → graphiti_core-0.3.6}/graphiti_core/prompts/models.py +0 -0
  33. {graphiti_core-0.3.4 → graphiti_core-0.3.6}/graphiti_core/prompts/summarize_nodes.py +0 -0
  34. {graphiti_core-0.3.4 → graphiti_core-0.3.6}/graphiti_core/py.typed +0 -0
  35. {graphiti_core-0.3.4 → graphiti_core-0.3.6}/graphiti_core/search/__init__.py +0 -0
  36. {graphiti_core-0.3.4 → graphiti_core-0.3.6}/graphiti_core/search/search_config.py +0 -0
  37. {graphiti_core-0.3.4 → graphiti_core-0.3.6}/graphiti_core/search/search_config_recipes.py +0 -0
  38. {graphiti_core-0.3.4 → graphiti_core-0.3.6}/graphiti_core/utils/__init__.py +0 -0
  39. {graphiti_core-0.3.4 → graphiti_core-0.3.6}/graphiti_core/utils/bulk_utils.py +0 -0
  40. {graphiti_core-0.3.4 → graphiti_core-0.3.6}/graphiti_core/utils/maintenance/__init__.py +0 -0
  41. {graphiti_core-0.3.4 → graphiti_core-0.3.6}/graphiti_core/utils/maintenance/node_operations.py +0 -0
  42. {graphiti_core-0.3.4 → graphiti_core-0.3.6}/graphiti_core/utils/maintenance/temporal_operations.py +0 -0
  43. {graphiti_core-0.3.4 → graphiti_core-0.3.6}/graphiti_core/utils/maintenance/utils.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: graphiti-core
3
- Version: 0.3.4
3
+ Version: 0.3.6
4
4
  Summary: A temporal graph building library
5
5
  License: Apache-2.0
6
6
  Author: Paul Paliychuk
@@ -24,7 +24,7 @@ from uuid import uuid4
24
24
  from neo4j import AsyncDriver
25
25
  from pydantic import BaseModel, Field
26
26
 
27
- from graphiti_core.errors import EdgeNotFoundError
27
+ from graphiti_core.errors import EdgeNotFoundError, GroupsEdgesNotFoundError
28
28
  from graphiti_core.helpers import parse_db_date
29
29
  from graphiti_core.llm_client.config import EMBEDDING_DIM
30
30
  from graphiti_core.nodes import Node
@@ -34,7 +34,7 @@ logger = logging.getLogger(__name__)
34
34
 
35
35
  class Edge(BaseModel, ABC):
36
36
  uuid: str = Field(default_factory=lambda: str(uuid4()))
37
- group_id: str | None = Field(description='partition of the graph')
37
+ group_id: str = Field(description='partition of the graph')
38
38
  source_node_uuid: str
39
39
  target_node_uuid: str
40
40
  created_at: datetime
@@ -131,7 +131,7 @@ class EpisodicEdge(Edge):
131
131
  return edges
132
132
 
133
133
  @classmethod
134
- async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str | None]):
134
+ async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str]):
135
135
  records, _, _ = await driver.execute_query(
136
136
  """
137
137
  MATCH (n:Episodic)-[e:MENTIONS]->(m:Entity)
@@ -147,10 +147,9 @@ class EpisodicEdge(Edge):
147
147
  )
148
148
 
149
149
  edges = [get_episodic_edge_from_record(record) for record in records]
150
- uuids = [edge.uuid for edge in edges]
151
150
 
152
151
  if len(edges) == 0:
153
- raise EdgeNotFoundError(uuids[0])
152
+ raise GroupsEdgesNotFoundError(group_ids)
154
153
  return edges
155
154
 
156
155
 
@@ -270,7 +269,7 @@ class EntityEdge(Edge):
270
269
  return edges
271
270
 
272
271
  @classmethod
273
- async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str | None]):
272
+ async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str]):
274
273
  records, _, _ = await driver.execute_query(
275
274
  """
276
275
  MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
@@ -293,10 +292,9 @@ class EntityEdge(Edge):
293
292
  )
294
293
 
295
294
  edges = [get_entity_edge_from_record(record) for record in records]
296
- uuids = [edge.uuid for edge in edges]
297
295
 
298
296
  if len(edges) == 0:
299
- raise EdgeNotFoundError(uuids[0])
297
+ raise GroupsEdgesNotFoundError(group_ids)
300
298
  return edges
301
299
 
302
300
 
@@ -360,7 +358,7 @@ class CommunityEdge(Edge):
360
358
  return edges
361
359
 
362
360
  @classmethod
363
- async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str | None]):
361
+ async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str]):
364
362
  records, _, _ = await driver.execute_query(
365
363
  """
366
364
  MATCH (n:Community)-[e:HAS_MEMBER]->(m:Entity | Community)
@@ -27,6 +27,14 @@ class EdgeNotFoundError(GraphitiError):
27
27
  super().__init__(self.message)
28
28
 
29
29
 
30
+ class GroupsEdgesNotFoundError(GraphitiError):
31
+ """Raised when no edges are found for a list of group ids."""
32
+
33
+ def __init__(self, group_ids: list[str]):
34
+ self.message = f'no edges found for group ids {group_ids}'
35
+ super().__init__(self.message)
36
+
37
+
30
38
  class NodeNotFoundError(GraphitiError):
31
39
  """Raised when a node is not found."""
32
40
 
@@ -129,7 +129,7 @@ class Graphiti:
129
129
  else:
130
130
  self.llm_client = OpenAIClient()
131
131
 
132
- def close(self):
132
+ async def close(self):
133
133
  """
134
134
  Close the connection to the Neo4j database.
135
135
 
@@ -159,7 +159,7 @@ class Graphiti:
159
159
  finally:
160
160
  graphiti.close()
161
161
  """
162
- self.driver.close()
162
+ await self.driver.close()
163
163
 
164
164
  async def build_indices_and_constraints(self):
165
165
  """
@@ -197,7 +197,7 @@ class Graphiti:
197
197
  self,
198
198
  reference_time: datetime,
199
199
  last_n: int = EPISODE_WINDOW_LEN,
200
- group_ids: list[str | None] | None = None,
200
+ group_ids: list[str] | None = None,
201
201
  ) -> list[EpisodicNode]:
202
202
  """
203
203
  Retrieve the last n episodic nodes from the graph.
@@ -233,7 +233,7 @@ class Graphiti:
233
233
  source_description: str,
234
234
  reference_time: datetime,
235
235
  source: EpisodeType = EpisodeType.message,
236
- group_id: str | None = None,
236
+ group_id: str = '',
237
237
  uuid: str | None = None,
238
238
  update_communities: bool = False,
239
239
  ):
@@ -446,7 +446,7 @@ class Graphiti:
446
446
  except Exception as e:
447
447
  raise e
448
448
 
449
- async def add_episode_bulk(self, bulk_episodes: list[RawEpisode], group_id: str | None = None):
449
+ async def add_episode_bulk(self, bulk_episodes: list[RawEpisode], group_id: str = ''):
450
450
  """
451
451
  Process multiple episodes in bulk and update the graph.
452
452
 
@@ -577,7 +577,7 @@ class Graphiti:
577
577
  self,
578
578
  query: str,
579
579
  center_node_uuid: str | None = None,
580
- group_ids: list[str | None] | None = None,
580
+ group_ids: list[str] | None = None,
581
581
  num_results=DEFAULT_SEARCH_LIMIT,
582
582
  ) -> list[EntityEdge]:
583
583
  """
@@ -633,7 +633,7 @@ class Graphiti:
633
633
  self,
634
634
  query: str,
635
635
  config: SearchConfig,
636
- group_ids: list[str | None] | None = None,
636
+ group_ids: list[str] | None = None,
637
637
  center_node_uuid: str | None = None,
638
638
  ) -> SearchResults:
639
639
  return await search(
@@ -644,7 +644,7 @@ class Graphiti:
644
644
  self,
645
645
  query: str,
646
646
  center_node_uuid: str | None = None,
647
- group_ids: list[str | None] | None = None,
647
+ group_ids: list[str] | None = None,
648
648
  limit: int = DEFAULT_SEARCH_LIMIT,
649
649
  ) -> list[EntityNode]:
650
650
  """
@@ -29,7 +29,7 @@ from .errors import RateLimitError
29
29
 
30
30
  logger = logging.getLogger(__name__)
31
31
 
32
- DEFAULT_MODEL = 'gpt-4o-2024-08-06'
32
+ DEFAULT_MODEL = 'gpt-4o-mini'
33
33
 
34
34
 
35
35
  class OpenAIClient(LLMClient):
@@ -70,7 +70,7 @@ class EpisodeType(Enum):
70
70
  class Node(BaseModel, ABC):
71
71
  uuid: str = Field(default_factory=lambda: str(uuid4()))
72
72
  name: str = Field(description='name of the node')
73
- group_id: str | None = Field(description='partition of the graph')
73
+ group_id: str = Field(description='partition of the graph')
74
74
  labels: list[str] = Field(default_factory=list)
75
75
  created_at: datetime = Field(default_factory=lambda: datetime.now())
76
76
 
@@ -186,7 +186,7 @@ class EpisodicNode(Node):
186
186
  return episodes
187
187
 
188
188
  @classmethod
189
- async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str | None]):
189
+ async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str]):
190
190
  records, _, _ = await driver.execute_query(
191
191
  """
192
192
  MATCH (e:Episodic) WHERE e.group_id IN $group_ids
@@ -281,7 +281,7 @@ class EntityNode(Node):
281
281
  return nodes
282
282
 
283
283
  @classmethod
284
- async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str | None]):
284
+ async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str]):
285
285
  records, _, _ = await driver.execute_query(
286
286
  """
287
287
  MATCH (n:Entity) WHERE n.group_id IN $group_ids
@@ -374,7 +374,7 @@ class CommunityNode(Node):
374
374
  return communities
375
375
 
376
376
  @classmethod
377
- async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str | None]):
377
+ async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str]):
378
378
  records, _, _ = await driver.execute_query(
379
379
  """
380
380
  MATCH (n:Community) WHERE n.group_id IN $group_ids
@@ -15,6 +15,7 @@ limitations under the License.
15
15
  """
16
16
 
17
17
  import logging
18
+ from collections import defaultdict
18
19
  from time import time
19
20
 
20
21
  from neo4j import AsyncDriver
@@ -56,7 +57,7 @@ async def search(
56
57
  driver: AsyncDriver,
57
58
  embedder,
58
59
  query: str,
59
- group_ids: list[str | None] | None,
60
+ group_ids: list[str] | None,
60
61
  config: SearchConfig,
61
62
  center_node_uuid: str | None = None,
62
63
  ) -> SearchResults:
@@ -103,7 +104,7 @@ async def edge_search(
103
104
  driver: AsyncDriver,
104
105
  embedder,
105
106
  query: str,
106
- group_ids: list[str | None] | None,
107
+ group_ids: list[str] | None,
107
108
  config: EdgeSearchConfig,
108
109
  center_node_uuid: str | None = None,
109
110
  limit=DEFAULT_SEARCH_LIMIT,
@@ -140,14 +141,21 @@ async def edge_search(
140
141
  if center_node_uuid is None:
141
142
  raise SearchRerankerError('No center node provided for Node Distance reranker')
142
143
 
143
- source_to_edge_uuid_map = {
144
- edge.source_node_uuid: edge.uuid for result in search_results for edge in result
145
- }
146
- source_uuids = [[edge.source_node_uuid for edge in result] for result in search_results]
144
+ # use rrf as a preliminary sort
145
+ sorted_result_uuids = rrf([[edge.uuid for edge in result] for result in search_results])
146
+ sorted_results = [edge_uuid_map[uuid] for uuid in sorted_result_uuids]
147
+
148
+ # node distance reranking
149
+ source_to_edge_uuid_map = defaultdict(list)
150
+ for edge in sorted_results:
151
+ source_to_edge_uuid_map[edge.source_node_uuid].append(edge.uuid)
152
+
153
+ source_uuids = [edge.source_node_uuid for edge in sorted_results]
147
154
 
148
155
  reranked_node_uuids = await node_distance_reranker(driver, source_uuids, center_node_uuid)
149
156
 
150
- reranked_uuids = [source_to_edge_uuid_map[node_uuid] for node_uuid in reranked_node_uuids]
157
+ for node_uuid in reranked_node_uuids:
158
+ reranked_uuids.extend(source_to_edge_uuid_map[node_uuid])
151
159
 
152
160
  reranked_edges = [edge_uuid_map[uuid] for uuid in reranked_uuids]
153
161
 
@@ -161,7 +169,7 @@ async def node_search(
161
169
  driver: AsyncDriver,
162
170
  embedder,
163
171
  query: str,
164
- group_ids: list[str | None] | None,
172
+ group_ids: list[str] | None,
165
173
  config: NodeSearchConfig,
166
174
  center_node_uuid: str | None = None,
167
175
  limit=DEFAULT_SEARCH_LIMIT,
@@ -198,7 +206,9 @@ async def node_search(
198
206
  elif config.reranker == NodeReranker.node_distance:
199
207
  if center_node_uuid is None:
200
208
  raise SearchRerankerError('No center node provided for Node Distance reranker')
201
- reranked_uuids = await node_distance_reranker(driver, search_result_uuids, center_node_uuid)
209
+ reranked_uuids = await node_distance_reranker(
210
+ driver, rrf(search_result_uuids), center_node_uuid
211
+ )
202
212
 
203
213
  reranked_nodes = [node_uuid_map[uuid] for uuid in reranked_uuids]
204
214
 
@@ -209,7 +219,7 @@ async def community_search(
209
219
  driver: AsyncDriver,
210
220
  embedder,
211
221
  query: str,
212
- group_ids: list[str | None] | None,
222
+ group_ids: list[str] | None,
213
223
  config: CommunitySearchConfig,
214
224
  limit=DEFAULT_SEARCH_LIMIT,
215
225
  ) -> list[CommunityNode]:
@@ -87,7 +87,7 @@ async def edge_fulltext_search(
87
87
  query: str,
88
88
  source_node_uuid: str | None,
89
89
  target_node_uuid: str | None,
90
- group_ids: list[str | None] | None = None,
90
+ group_ids: list[str] | None = None,
91
91
  limit=RELEVANT_SCHEMA_LIMIT,
92
92
  ) -> list[EntityEdge]:
93
93
  # fulltext search over facts
@@ -95,10 +95,7 @@ async def edge_fulltext_search(
95
95
  CALL db.index.fulltext.queryRelationships("name_and_fact", $query)
96
96
  YIELD relationship AS rel, score
97
97
  MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid})
98
- WHERE CASE
99
- WHEN $group_ids IS NULL THEN n.group_id IS NULL
100
- ELSE n.group_id IN $group_ids
101
- END
98
+ WHERE $group_ids IS NULL OR n.group_id IN $group_ids
102
99
  RETURN
103
100
  r.uuid AS uuid,
104
101
  r.group_id AS group_id,
@@ -120,10 +117,7 @@ async def edge_fulltext_search(
120
117
  CALL db.index.fulltext.queryRelationships("name_and_fact", $query)
121
118
  YIELD relationship AS rel, score
122
119
  MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity)
123
- WHERE CASE
124
- WHEN $group_ids IS NULL THEN r.group_id IS NULL
125
- ELSE r.group_id IN $group_ids
126
- END
120
+ WHERE $group_ids IS NULL OR r.group_id IN $group_ids
127
121
  RETURN
128
122
  r.uuid AS uuid,
129
123
  r.group_id AS group_id,
@@ -144,10 +138,7 @@ async def edge_fulltext_search(
144
138
  CALL db.index.fulltext.queryRelationships("name_and_fact", $query)
145
139
  YIELD relationship AS rel, score
146
140
  MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid})
147
- WHERE CASE
148
- WHEN $group_ids IS NULL THEN r.group_id IS NULL
149
- ELSE r.group_id IN $group_ids
150
- END
141
+ WHERE $group_ids IS NULL OR r.group_id IN $group_ids
151
142
  RETURN
152
143
  r.uuid AS uuid,
153
144
  r.group_id AS group_id,
@@ -168,10 +159,7 @@ async def edge_fulltext_search(
168
159
  CALL db.index.fulltext.queryRelationships("name_and_fact", $query)
169
160
  YIELD relationship AS rel, score
170
161
  MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity)
171
- WHERE CASE
172
- WHEN $group_ids IS NULL THEN r.group_id IS NULL
173
- ELSE r.group_id IN $group_ids
174
- END
162
+ WHERE $group_ids IS NULL OR r.group_id IN $group_ids
175
163
  RETURN
176
164
  r.uuid AS uuid,
177
165
  r.group_id AS group_id,
@@ -209,7 +197,7 @@ async def edge_similarity_search(
209
197
  search_vector: list[float],
210
198
  source_node_uuid: str | None,
211
199
  target_node_uuid: str | None,
212
- group_ids: list[str | None] | None = None,
200
+ group_ids: list[str] | None = None,
213
201
  limit: int = RELEVANT_SCHEMA_LIMIT,
214
202
  ) -> list[EntityEdge]:
215
203
  # vector similarity search over embedded facts
@@ -217,10 +205,7 @@ async def edge_similarity_search(
217
205
  CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector)
218
206
  YIELD relationship AS rel, score
219
207
  MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid})
220
- WHERE CASE
221
- WHEN $group_ids IS NULL THEN r.group_id IS NULL
222
- ELSE r.group_id IN $group_ids
223
- END
208
+ WHERE $group_ids IS NULL OR r.group_id IN $group_ids
224
209
  RETURN
225
210
  r.uuid AS uuid,
226
211
  r.group_id AS group_id,
@@ -242,10 +227,7 @@ async def edge_similarity_search(
242
227
  CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector)
243
228
  YIELD relationship AS rel, score
244
229
  MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity)
245
- WHERE CASE
246
- WHEN $group_ids IS NULL THEN r.group_id IS NULL
247
- ELSE r.group_id IN $group_ids
248
- END
230
+ WHERE $group_ids IS NULL OR r.group_id IN $group_ids
249
231
  RETURN
250
232
  r.uuid AS uuid,
251
233
  r.group_id AS group_id,
@@ -266,10 +248,7 @@ async def edge_similarity_search(
266
248
  CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector)
267
249
  YIELD relationship AS rel, score
268
250
  MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid})
269
- WHERE CASE
270
- WHEN $group_ids IS NULL THEN r.group_id IS NULL
271
- ELSE r.group_id IN $group_ids
272
- END
251
+ WHERE $group_ids IS NULL OR r.group_id IN $group_ids
273
252
  RETURN
274
253
  r.uuid AS uuid,
275
254
  r.group_id AS group_id,
@@ -290,10 +269,7 @@ async def edge_similarity_search(
290
269
  CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector)
291
270
  YIELD relationship AS rel, score
292
271
  MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity)
293
- WHERE CASE
294
- WHEN $group_ids IS NULL THEN r.group_id IS NULL
295
- ELSE r.group_id IN $group_ids
296
- END
272
+ WHERE $group_ids IS NULL OR r.group_id IN $group_ids
297
273
  RETURN
298
274
  r.uuid AS uuid,
299
275
  r.group_id AS group_id,
@@ -327,7 +303,7 @@ async def edge_similarity_search(
327
303
  async def node_fulltext_search(
328
304
  driver: AsyncDriver,
329
305
  query: str,
330
- group_ids: list[str | None] | None = None,
306
+ group_ids: list[str] | None = None,
331
307
  limit=RELEVANT_SCHEMA_LIMIT,
332
308
  ) -> list[EntityNode]:
333
309
  # BM25 search to get top nodes
@@ -336,10 +312,7 @@ async def node_fulltext_search(
336
312
  """
337
313
  CALL db.index.fulltext.queryNodes("name_and_summary", $query)
338
314
  YIELD node AS n, score
339
- WHERE CASE
340
- WHEN $group_ids IS NULL THEN n.group_id IS NULL
341
- ELSE n.group_id IN $group_ids
342
- END
315
+ WHERE $group_ids IS NULL OR n.group_id IN $group_ids
343
316
  RETURN
344
317
  n.uuid AS uuid,
345
318
  n.group_id AS group_id,
@@ -362,17 +335,16 @@ async def node_fulltext_search(
362
335
  async def node_similarity_search(
363
336
  driver: AsyncDriver,
364
337
  search_vector: list[float],
365
- group_ids: list[str | None] | None = None,
338
+ group_ids: list[str] | None = None,
366
339
  limit=RELEVANT_SCHEMA_LIMIT,
367
340
  ) -> list[EntityNode]:
368
- group_ids = group_ids if group_ids is not None else [None]
369
-
370
341
  # vector similarity search over entity names
371
342
  records, _, _ = await driver.execute_query(
372
343
  """
373
344
  CALL db.index.vector.queryNodes("name_embedding", $limit, $search_vector)
374
345
  YIELD node AS n, score
375
- MATCH (n WHERE n.group_id IN $group_ids)
346
+ MATCH (n:Entity)
347
+ WHERE $group_ids IS NULL OR n.group_id IN $group_ids
376
348
  RETURN
377
349
  n.uuid As uuid,
378
350
  n.group_id AS group_id,
@@ -394,18 +366,17 @@ async def node_similarity_search(
394
366
  async def community_fulltext_search(
395
367
  driver: AsyncDriver,
396
368
  query: str,
397
- group_ids: list[str | None] | None = None,
369
+ group_ids: list[str] | None = None,
398
370
  limit=RELEVANT_SCHEMA_LIMIT,
399
371
  ) -> list[CommunityNode]:
400
- group_ids = group_ids if group_ids is not None else [None]
401
-
402
372
  # BM25 search to get top communities
403
373
  fuzzy_query = re.sub(r'[^\w\s]', '', query) + '~'
404
374
  records, _, _ = await driver.execute_query(
405
375
  """
406
376
  CALL db.index.fulltext.queryNodes("community_name", $query)
407
377
  YIELD node AS comm, score
408
- MATCH (comm WHERE comm.group_id in $group_ids)
378
+ MATCH (comm:Community)
379
+ WHERE $group_ids IS NULL OR comm.group_id in $group_ids
409
380
  RETURN
410
381
  comm.uuid AS uuid,
411
382
  comm.group_id AS group_id,
@@ -428,17 +399,16 @@ async def community_fulltext_search(
428
399
  async def community_similarity_search(
429
400
  driver: AsyncDriver,
430
401
  search_vector: list[float],
431
- group_ids: list[str | None] | None = None,
402
+ group_ids: list[str] | None = None,
432
403
  limit=RELEVANT_SCHEMA_LIMIT,
433
404
  ) -> list[CommunityNode]:
434
- group_ids = group_ids if group_ids is not None else [None]
435
-
436
405
  # vector similarity search over entity names
437
406
  records, _, _ = await driver.execute_query(
438
407
  """
439
408
  CALL db.index.vector.queryNodes("community_name_embedding", $limit, $search_vector)
440
409
  YIELD node AS comm, score
441
- MATCH (comm WHERE comm.group_id IN $group_ids)
410
+ MATCH (comm:Community)
411
+ WHERE $group_ids IS NULL OR comm.group_id IN $group_ids
442
412
  RETURN
443
413
  comm.uuid As uuid,
444
414
  comm.group_id AS group_id,
@@ -461,7 +431,7 @@ async def hybrid_node_search(
461
431
  queries: list[str],
462
432
  embeddings: list[list[float]],
463
433
  driver: AsyncDriver,
464
- group_ids: list[str | None] | None = None,
434
+ group_ids: list[str] | None = None,
465
435
  limit: int = RELEVANT_SCHEMA_LIMIT,
466
436
  ) -> list[EntityNode]:
467
437
  """
@@ -503,7 +473,6 @@ async def hybrid_node_search(
503
473
  """
504
474
 
505
475
  start = time()
506
-
507
476
  results: list[list[EntityNode]] = list(
508
477
  await asyncio.gather(
509
478
  *[node_fulltext_search(driver, q, group_ids, 2 * limit) for q in queries],
@@ -625,14 +594,14 @@ def rrf(results: list[list[str]], rank_const=1) -> list[str]:
625
594
 
626
595
 
627
596
  async def node_distance_reranker(
628
- driver: AsyncDriver, node_uuids: list[list[str]], center_node_uuid: str
597
+ driver: AsyncDriver, node_uuids: list[str], center_node_uuid: str
629
598
  ) -> list[str]:
630
- # use rrf as a preliminary ranker
631
- sorted_uuids = rrf(node_uuids)
599
+ # filter out node_uuid center node node uuid
600
+ filtered_uuids = list(filter(lambda uuid: uuid != center_node_uuid, node_uuids))
632
601
  scores: dict[str, float] = {}
633
602
 
634
603
  # Find the shortest path to center node
635
- query = Query("""
604
+ query = Query("""
636
605
  MATCH p = SHORTEST 1 (center:Entity {uuid: $center_uuid})-[:RELATES_TO]-+(n:Entity {uuid: $node_uuid})
637
606
  RETURN length(p) AS score
638
607
  """)
@@ -644,21 +613,23 @@ async def node_distance_reranker(
644
613
  node_uuid=uuid,
645
614
  center_uuid=center_node_uuid,
646
615
  )
647
- for uuid in sorted_uuids
616
+ for uuid in filtered_uuids
648
617
  ]
649
618
  )
650
619
 
651
- for uuid, result in zip(sorted_uuids, path_results):
620
+ for uuid, result in zip(filtered_uuids, path_results):
652
621
  records = result[0]
653
622
  record = records[0] if len(records) > 0 else None
654
623
  distance: float = record['score'] if record is not None else float('inf')
655
- distance = 0 if uuid == center_node_uuid else distance
656
624
  scores[uuid] = distance
657
625
 
658
626
  # rerank on shortest distance
659
- sorted_uuids.sort(key=lambda cur_uuid: scores[cur_uuid])
627
+ filtered_uuids.sort(key=lambda cur_uuid: scores[cur_uuid])
660
628
 
661
- return sorted_uuids
629
+ # add back in filtered center uuids
630
+ filtered_uuids = [center_node_uuid] + filtered_uuids
631
+
632
+ return filtered_uuids
662
633
 
663
634
 
664
635
  async def episode_mentions_reranker(driver: AsyncDriver, node_uuids: list[list[str]]) -> list[str]:
@@ -154,7 +154,7 @@ async def generate_summary_description(llm_client: LLMClient, summary: str) -> s
154
154
 
155
155
 
156
156
  async def build_community(
157
- llm_client: LLMClient, community_cluster: list[EntityNode]
157
+ llm_client: LLMClient, community_cluster: list[EntityNode]
158
158
  ) -> tuple[CommunityNode, list[CommunityEdge]]:
159
159
  summaries = [entity.summary for entity in community_cluster]
160
160
  length = len(summaries)
@@ -168,7 +168,7 @@ async def build_community(
168
168
  *[
169
169
  summarize_pair(llm_client, (str(left_summary), str(right_summary)))
170
170
  for left_summary, right_summary in zip(
171
- summaries[: int(length / 2)], summaries[int(length / 2):]
171
+ summaries[: int(length / 2)], summaries[int(length / 2) :]
172
172
  )
173
173
  ]
174
174
  )
@@ -196,7 +196,7 @@ async def build_community(
196
196
 
197
197
 
198
198
  async def build_communities(
199
- driver: AsyncDriver, llm_client: LLMClient
199
+ driver: AsyncDriver, llm_client: LLMClient
200
200
  ) -> tuple[list[CommunityNode], list[CommunityEdge]]:
201
201
  community_clusters = await get_community_clusters(driver)
202
202
 
@@ -227,7 +227,7 @@ async def remove_communities(driver: AsyncDriver):
227
227
 
228
228
 
229
229
  async def determine_entity_community(
230
- driver: AsyncDriver, entity: EntityNode
230
+ driver: AsyncDriver, entity: EntityNode
231
231
  ) -> tuple[CommunityNode | None, bool]:
232
232
  # Check if the node is already part of a community
233
233
  records, _, _ = await driver.execute_query(
@@ -288,7 +288,7 @@ async def determine_entity_community(
288
288
 
289
289
 
290
290
  async def update_community(
291
- driver: AsyncDriver, llm_client: LLMClient, embedder, entity: EntityNode
291
+ driver: AsyncDriver, llm_client: LLMClient, embedder, entity: EntityNode
292
292
  ):
293
293
  community, is_new = await determine_entity_community(driver, entity)
294
294
 
@@ -307,4 +307,4 @@ async def update_community(
307
307
 
308
308
  await community.generate_name_embedding(embedder)
309
309
 
310
- await community.save(driver)
310
+ await community.save(driver)
@@ -73,7 +73,7 @@ async def extract_edges(
73
73
  episode: EpisodicNode,
74
74
  nodes: list[EntityNode],
75
75
  previous_episodes: list[EpisodicNode],
76
- group_id: str | None,
76
+ group_id: str = '',
77
77
  ) -> list[EntityEdge]:
78
78
  start = time()
79
79
 
@@ -101,7 +101,7 @@ async def retrieve_episodes(
101
101
  driver: AsyncDriver,
102
102
  reference_time: datetime,
103
103
  last_n: int = EPISODE_WINDOW_LEN,
104
- group_ids: list[str | None] | None = None,
104
+ group_ids: list[str] | None = None,
105
105
  ) -> list[EpisodicNode]:
106
106
  """
107
107
  Retrieve the last n episodic nodes from the graph.
@@ -119,7 +119,8 @@ async def retrieve_episodes(
119
119
  """
120
120
  result = await driver.execute_query(
121
121
  """
122
- MATCH (e:Episodic) WHERE e.valid_at <= $reference_time AND e.group_id in $group_ids
122
+ MATCH (e:Episodic) WHERE e.valid_at <= $reference_time
123
+ AND ($group_ids IS NULL) OR e.group_id in $group_ids
123
124
  RETURN e.content AS content,
124
125
  e.created_at AS created_at,
125
126
  e.valid_at AS valid_at,
@@ -1,6 +1,6 @@
1
1
  [tool.poetry]
2
2
  name = "graphiti-core"
3
- version = "0.3.4"
3
+ version = "0.3.6"
4
4
  description = "A temporal graph building library"
5
5
  authors = [
6
6
  "Paul Paliychuk <paul@getzep.com>",
@@ -26,7 +26,7 @@ pytest = "^8.3.3"
26
26
  python-dotenv = "^1.0.1"
27
27
  pytest-asyncio = "^0.24.0"
28
28
  pytest-xdist = "^3.6.1"
29
- ruff = "^0.6.5"
29
+ ruff = "^0.6.7"
30
30
 
31
31
  [tool.poetry.group.dev.dependencies]
32
32
  pydantic = "^2.8.2"
File without changes
File without changes