graphiti-core 0.11.6rc7__py3-none-any.whl → 0.12.0__py3-none-any.whl

This diff compares the contents of two publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects the changes between these package versions as they appear in their respective public registries.

Potentially problematic release.


This version of graphiti-core might be problematic; the original page linked to further details about this warning.

Files changed (33):
  1. graphiti_core/cross_encoder/openai_reranker_client.py +1 -1
  2. graphiti_core/driver/__init__.py +17 -0
  3. graphiti_core/driver/driver.py +66 -0
  4. graphiti_core/driver/falkordb_driver.py +132 -0
  5. graphiti_core/driver/neo4j_driver.py +61 -0
  6. graphiti_core/edges.py +66 -40
  7. graphiti_core/embedder/azure_openai.py +64 -0
  8. graphiti_core/embedder/gemini.py +14 -3
  9. graphiti_core/graph_queries.py +149 -0
  10. graphiti_core/graphiti.py +41 -14
  11. graphiti_core/graphiti_types.py +2 -2
  12. graphiti_core/helpers.py +17 -30
  13. graphiti_core/llm_client/__init__.py +16 -0
  14. graphiti_core/llm_client/azure_openai_client.py +73 -0
  15. graphiti_core/llm_client/gemini_client.py +4 -1
  16. graphiti_core/models/edges/edge_db_queries.py +2 -4
  17. graphiti_core/nodes.py +31 -31
  18. graphiti_core/prompts/dedupe_edges.py +52 -1
  19. graphiti_core/prompts/dedupe_nodes.py +79 -4
  20. graphiti_core/prompts/extract_edges.py +50 -5
  21. graphiti_core/prompts/invalidate_edges.py +1 -1
  22. graphiti_core/search/search.py +25 -55
  23. graphiti_core/search/search_filters.py +23 -9
  24. graphiti_core/search/search_utils.py +360 -195
  25. graphiti_core/utils/bulk_utils.py +38 -11
  26. graphiti_core/utils/maintenance/community_operations.py +6 -7
  27. graphiti_core/utils/maintenance/edge_operations.py +149 -19
  28. graphiti_core/utils/maintenance/graph_data_operations.py +13 -42
  29. graphiti_core/utils/maintenance/node_operations.py +52 -71
  30. {graphiti_core-0.11.6rc7.dist-info → graphiti_core-0.12.0.dist-info}/METADATA +14 -5
  31. {graphiti_core-0.11.6rc7.dist-info → graphiti_core-0.12.0.dist-info}/RECORD +33 -26
  32. {graphiti_core-0.11.6rc7.dist-info → graphiti_core-0.12.0.dist-info}/LICENSE +0 -0
  33. {graphiti_core-0.11.6rc7.dist-info → graphiti_core-0.12.0.dist-info}/WHEEL +0 -0
@@ -18,9 +18,8 @@ import logging
18
18
  from collections import defaultdict
19
19
  from time import time
20
20
 
21
- from neo4j import AsyncDriver
22
-
23
21
  from graphiti_core.cross_encoder.client import CrossEncoderClient
22
+ from graphiti_core.driver.driver import GraphDriver
24
23
  from graphiti_core.edges import EntityEdge
25
24
  from graphiti_core.errors import SearchRerankerError
26
25
  from graphiti_core.graphiti_types import GraphitiClients
@@ -50,6 +49,9 @@ from graphiti_core.search.search_utils import (
50
49
  edge_similarity_search,
51
50
  episode_fulltext_search,
52
51
  episode_mentions_reranker,
52
+ get_embeddings_for_communities,
53
+ get_embeddings_for_edges,
54
+ get_embeddings_for_nodes,
53
55
  maximal_marginal_relevance,
54
56
  node_bfs_search,
55
57
  node_distance_reranker,
@@ -91,7 +93,7 @@ async def search(
91
93
  )
92
94
 
93
95
  # if group_ids is empty, set it to None
94
- group_ids = group_ids if group_ids else None
96
+ group_ids = group_ids if group_ids and group_ids != [''] else None
95
97
  edges, nodes, episodes, communities = await semaphore_gather(
96
98
  edge_search(
97
99
  driver,
@@ -157,7 +159,7 @@ async def search(
157
159
 
158
160
 
159
161
  async def edge_search(
160
- driver: AsyncDriver,
162
+ driver: GraphDriver,
161
163
  cross_encoder: CrossEncoderClient,
162
164
  query: str,
163
165
  query_vector: list[float],
@@ -171,7 +173,6 @@ async def edge_search(
171
173
  ) -> list[EntityEdge]:
172
174
  if config is None:
173
175
  return []
174
-
175
176
  search_results: list[list[EntityEdge]] = list(
176
177
  await semaphore_gather(
177
178
  *[
@@ -209,26 +210,17 @@ async def edge_search(
209
210
 
210
211
  reranked_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
211
212
  elif config.reranker == EdgeReranker.mmr:
212
- await semaphore_gather(
213
- *[edge.load_fact_embedding(driver) for result in search_results for edge in result]
213
+ search_result_uuids_and_vectors = await get_embeddings_for_edges(
214
+ driver, list(edge_uuid_map.values())
214
215
  )
215
- search_result_uuids_and_vectors = [
216
- (edge.uuid, edge.fact_embedding if edge.fact_embedding is not None else [0.0] * 1024)
217
- for result in search_results
218
- for edge in result
219
- ]
220
216
  reranked_uuids = maximal_marginal_relevance(
221
217
  query_vector,
222
218
  search_result_uuids_and_vectors,
223
219
  config.mmr_lambda,
220
+ reranker_min_score,
224
221
  )
225
222
  elif config.reranker == EdgeReranker.cross_encoder:
226
- search_result_uuids = [[edge.uuid for edge in result] for result in search_results]
227
-
228
- rrf_result_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
229
- rrf_edges = [edge_uuid_map[uuid] for uuid in rrf_result_uuids][:limit]
230
-
231
- fact_to_uuid_map = {edge.fact: edge.uuid for edge in rrf_edges}
223
+ fact_to_uuid_map = {edge.fact: edge.uuid for edge in list(edge_uuid_map.values())[:limit]}
232
224
  reranked_facts = await cross_encoder.rank(query, list(fact_to_uuid_map.keys()))
233
225
  reranked_uuids = [
234
226
  fact_to_uuid_map[fact] for fact, score in reranked_facts if score >= reranker_min_score
@@ -267,7 +259,7 @@ async def edge_search(
267
259
 
268
260
 
269
261
  async def node_search(
270
- driver: AsyncDriver,
262
+ driver: GraphDriver,
271
263
  cross_encoder: CrossEncoderClient,
272
264
  query: str,
273
265
  query_vector: list[float],
@@ -281,7 +273,6 @@ async def node_search(
281
273
  ) -> list[EntityNode]:
282
274
  if config is None:
283
275
  return []
284
-
285
276
  search_results: list[list[EntityNode]] = list(
286
277
  await semaphore_gather(
287
278
  *[
@@ -311,30 +302,23 @@ async def node_search(
311
302
  if config.reranker == NodeReranker.rrf:
312
303
  reranked_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
313
304
  elif config.reranker == NodeReranker.mmr:
314
- await semaphore_gather(
315
- *[node.load_name_embedding(driver) for result in search_results for node in result]
305
+ search_result_uuids_and_vectors = await get_embeddings_for_nodes(
306
+ driver, list(node_uuid_map.values())
316
307
  )
317
- search_result_uuids_and_vectors = [
318
- (node.uuid, node.name_embedding if node.name_embedding is not None else [0.0] * 1024)
319
- for result in search_results
320
- for node in result
321
- ]
308
+
322
309
  reranked_uuids = maximal_marginal_relevance(
323
310
  query_vector,
324
311
  search_result_uuids_and_vectors,
325
312
  config.mmr_lambda,
313
+ reranker_min_score,
326
314
  )
327
315
  elif config.reranker == NodeReranker.cross_encoder:
328
- # use rrf as a preliminary reranker
329
- rrf_result_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
330
- rrf_results = [node_uuid_map[uuid] for uuid in rrf_result_uuids][:limit]
331
-
332
- summary_to_uuid_map = {node.summary: node.uuid for node in rrf_results}
316
+ name_to_uuid_map = {node.name: node.uuid for node in list(node_uuid_map.values())}
333
317
 
334
- reranked_summaries = await cross_encoder.rank(query, list(summary_to_uuid_map.keys()))
318
+ reranked_node_names = await cross_encoder.rank(query, list(name_to_uuid_map.keys()))
335
319
  reranked_uuids = [
336
- summary_to_uuid_map[fact]
337
- for fact, score in reranked_summaries
320
+ name_to_uuid_map[name]
321
+ for name, score in reranked_node_names
338
322
  if score >= reranker_min_score
339
323
  ]
340
324
  elif config.reranker == NodeReranker.episode_mentions:
@@ -357,7 +341,7 @@ async def node_search(
357
341
 
358
342
 
359
343
  async def episode_search(
360
- driver: AsyncDriver,
344
+ driver: GraphDriver,
361
345
  cross_encoder: CrossEncoderClient,
362
346
  query: str,
363
347
  _query_vector: list[float],
@@ -369,7 +353,6 @@ async def episode_search(
369
353
  ) -> list[EpisodicNode]:
370
354
  if config is None:
371
355
  return []
372
-
373
356
  search_results: list[list[EpisodicNode]] = list(
374
357
  await semaphore_gather(
375
358
  *[
@@ -405,7 +388,7 @@ async def episode_search(
405
388
 
406
389
 
407
390
  async def community_search(
408
- driver: AsyncDriver,
391
+ driver: GraphDriver,
409
392
  cross_encoder: CrossEncoderClient,
410
393
  query: str,
411
394
  query_vector: list[float],
@@ -437,25 +420,12 @@ async def community_search(
437
420
  if config.reranker == CommunityReranker.rrf:
438
421
  reranked_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
439
422
  elif config.reranker == CommunityReranker.mmr:
440
- await semaphore_gather(
441
- *[
442
- community.load_name_embedding(driver)
443
- for result in search_results
444
- for community in result
445
- ]
423
+ search_result_uuids_and_vectors = await get_embeddings_for_communities(
424
+ driver, list(community_uuid_map.values())
446
425
  )
447
- search_result_uuids_and_vectors = [
448
- (
449
- community.uuid,
450
- community.name_embedding if community.name_embedding is not None else [0.0] * 1024,
451
- )
452
- for result in search_results
453
- for community in result
454
- ]
426
+
455
427
  reranked_uuids = maximal_marginal_relevance(
456
- query_vector,
457
- search_result_uuids_and_vectors,
458
- config.mmr_lambda,
428
+ query_vector, search_result_uuids_and_vectors, config.mmr_lambda, reranker_min_score
459
429
  )
460
430
  elif config.reranker == CommunityReranker.cross_encoder:
461
431
  name_to_uuid_map = {node.name: node.uuid for result in search_results for node in result}
@@ -42,6 +42,9 @@ class SearchFilters(BaseModel):
42
42
  node_labels: list[str] | None = Field(
43
43
  default=None, description='List of node labels to filter on'
44
44
  )
45
+ edge_types: list[str] | None = Field(
46
+ default=None, description='List of edge types to filter on'
47
+ )
45
48
  valid_at: list[list[DateFilter]] | None = Field(default=None)
46
49
  invalid_at: list[list[DateFilter]] | None = Field(default=None)
47
50
  created_at: list[list[DateFilter]] | None = Field(default=None)
@@ -68,8 +71,19 @@ def edge_search_filter_query_constructor(
68
71
  filter_query: LiteralString = ''
69
72
  filter_params: dict[str, Any] = {}
70
73
 
74
+ if filters.edge_types is not None:
75
+ edge_types = filters.edge_types
76
+ edge_types_filter = '\nAND r.name in $edge_types'
77
+ filter_query += edge_types_filter
78
+ filter_params['edge_types'] = edge_types
79
+
80
+ if filters.node_labels is not None:
81
+ node_labels = '|'.join(filters.node_labels)
82
+ node_label_filter = '\nAND n:' + node_labels + ' AND m:' + node_labels
83
+ filter_query += node_label_filter
84
+
71
85
  if filters.valid_at is not None:
72
- valid_at_filter = ' AND ('
86
+ valid_at_filter = '\nAND ('
73
87
  for i, or_list in enumerate(filters.valid_at):
74
88
  for j, date_filter in enumerate(or_list):
75
89
  filter_params['valid_at_' + str(j)] = date_filter.date
@@ -81,12 +95,12 @@ def edge_search_filter_query_constructor(
81
95
  and_filter_query = ''
82
96
  for j, and_filter in enumerate(and_filters):
83
97
  and_filter_query += and_filter
84
- if j != len(and_filter_query) - 1:
98
+ if j != len(and_filters) - 1:
85
99
  and_filter_query += ' AND '
86
100
 
87
101
  valid_at_filter += and_filter_query
88
102
 
89
- if i == len(or_list) - 1:
103
+ if i == len(filters.valid_at) - 1:
90
104
  valid_at_filter += ')'
91
105
  else:
92
106
  valid_at_filter += ' OR '
@@ -106,12 +120,12 @@ def edge_search_filter_query_constructor(
106
120
  and_filter_query = ''
107
121
  for j, and_filter in enumerate(and_filters):
108
122
  and_filter_query += and_filter
109
- if j != len(and_filter_query) - 1:
123
+ if j != len(and_filters) - 1:
110
124
  and_filter_query += ' AND '
111
125
 
112
126
  invalid_at_filter += and_filter_query
113
127
 
114
- if i == len(or_list) - 1:
128
+ if i == len(filters.invalid_at) - 1:
115
129
  invalid_at_filter += ')'
116
130
  else:
117
131
  invalid_at_filter += ' OR '
@@ -131,12 +145,12 @@ def edge_search_filter_query_constructor(
131
145
  and_filter_query = ''
132
146
  for j, and_filter in enumerate(and_filters):
133
147
  and_filter_query += and_filter
134
- if j != len(and_filter_query) - 1:
148
+ if j != len(and_filters) - 1:
135
149
  and_filter_query += ' AND '
136
150
 
137
151
  created_at_filter += and_filter_query
138
152
 
139
- if i == len(or_list) - 1:
153
+ if i == len(filters.created_at) - 1:
140
154
  created_at_filter += ')'
141
155
  else:
142
156
  created_at_filter += ' OR '
@@ -156,12 +170,12 @@ def edge_search_filter_query_constructor(
156
170
  and_filter_query = ''
157
171
  for j, and_filter in enumerate(and_filters):
158
172
  and_filter_query += and_filter
159
- if j != len(and_filter_query) - 1:
173
+ if j != len(and_filters) - 1:
160
174
  and_filter_query += ' AND '
161
175
 
162
176
  expired_at_filter += and_filter_query
163
177
 
164
- if i == len(or_list) - 1:
178
+ if i == len(filters.expired_at) - 1:
165
179
  expired_at_filter += ')'
166
180
  else:
167
181
  expired_at_filter += ' OR '