graphiti-core 0.18.8__py3-none-any.whl → 0.19.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of graphiti-core might be problematic; see the package's registry page for details.

Files changed (32):
  1. graphiti_core/driver/driver.py +4 -0
  2. graphiti_core/driver/falkordb_driver.py +3 -14
  3. graphiti_core/driver/kuzu_driver.py +175 -0
  4. graphiti_core/driver/neptune_driver.py +301 -0
  5. graphiti_core/edges.py +155 -62
  6. graphiti_core/graph_queries.py +31 -2
  7. graphiti_core/graphiti.py +6 -1
  8. graphiti_core/helpers.py +8 -8
  9. graphiti_core/llm_client/config.py +1 -1
  10. graphiti_core/llm_client/openai_base_client.py +15 -5
  11. graphiti_core/llm_client/openai_client.py +16 -6
  12. graphiti_core/migrations/__init__.py +0 -0
  13. graphiti_core/migrations/neo4j_node_group_labels.py +114 -0
  14. graphiti_core/models/edges/edge_db_queries.py +205 -76
  15. graphiti_core/models/nodes/node_db_queries.py +253 -74
  16. graphiti_core/nodes.py +271 -98
  17. graphiti_core/prompts/extract_edges.py +1 -0
  18. graphiti_core/prompts/extract_nodes.py +1 -1
  19. graphiti_core/search/search.py +42 -12
  20. graphiti_core/search/search_config.py +4 -0
  21. graphiti_core/search/search_filters.py +35 -22
  22. graphiti_core/search/search_utils.py +1329 -392
  23. graphiti_core/utils/bulk_utils.py +50 -15
  24. graphiti_core/utils/datetime_utils.py +13 -0
  25. graphiti_core/utils/maintenance/community_operations.py +39 -32
  26. graphiti_core/utils/maintenance/edge_operations.py +47 -13
  27. graphiti_core/utils/maintenance/graph_data_operations.py +100 -15
  28. graphiti_core/utils/maintenance/node_operations.py +7 -3
  29. {graphiti_core-0.18.8.dist-info → graphiti_core-0.19.0.dist-info}/METADATA +87 -13
  30. {graphiti_core-0.18.8.dist-info → graphiti_core-0.19.0.dist-info}/RECORD +32 -28
  31. {graphiti_core-0.18.8.dist-info → graphiti_core-0.19.0.dist-info}/WHEEL +0 -0
  32. {graphiti_core-0.18.8.dist-info → graphiti_core-0.19.0.dist-info}/licenses/LICENSE +0 -0
@@ -21,6 +21,7 @@ from time import time
21
21
  from graphiti_core.cross_encoder.client import CrossEncoderClient
22
22
  from graphiti_core.driver.driver import GraphDriver
23
23
  from graphiti_core.edges import EntityEdge
24
+ from graphiti_core.embedder.client import EMBEDDING_DIM
24
25
  from graphiti_core.errors import SearchRerankerError
25
26
  from graphiti_core.graphiti_types import GraphitiClients
26
27
  from graphiti_core.helpers import semaphore_gather
@@ -29,6 +30,7 @@ from graphiti_core.search.search_config import (
29
30
  DEFAULT_SEARCH_LIMIT,
30
31
  CommunityReranker,
31
32
  CommunitySearchConfig,
33
+ CommunitySearchMethod,
32
34
  EdgeReranker,
33
35
  EdgeSearchConfig,
34
36
  EdgeSearchMethod,
@@ -81,11 +83,29 @@ async def search(
81
83
 
82
84
  if query.strip() == '':
83
85
  return SearchResults()
84
- query_vector = (
85
- query_vector
86
- if query_vector is not None
87
- else await embedder.create(input_data=[query.replace('\n', ' ')])
88
- )
86
+
87
+ if (
88
+ config.edge_config
89
+ and EdgeSearchMethod.cosine_similarity in config.edge_config.search_methods
90
+ or config.edge_config
91
+ and EdgeReranker.mmr == config.edge_config.reranker
92
+ or config.node_config
93
+ and NodeSearchMethod.cosine_similarity in config.node_config.search_methods
94
+ or config.node_config
95
+ and NodeReranker.mmr == config.node_config.reranker
96
+ or (
97
+ config.community_config
98
+ and CommunitySearchMethod.cosine_similarity in config.community_config.search_methods
99
+ )
100
+ or (config.community_config and CommunityReranker.mmr == config.community_config.reranker)
101
+ ):
102
+ search_vector = (
103
+ query_vector
104
+ if query_vector is not None
105
+ else await embedder.create(input_data=[query.replace('\n', ' ')])
106
+ )
107
+ else:
108
+ search_vector = [0.0] * EMBEDDING_DIM
89
109
 
90
110
  # if group_ids is empty, set it to None
91
111
  group_ids = group_ids if group_ids and group_ids != [''] else None
@@ -99,7 +119,7 @@ async def search(
99
119
  driver,
100
120
  cross_encoder,
101
121
  query,
102
- query_vector,
122
+ search_vector,
103
123
  group_ids,
104
124
  config.edge_config,
105
125
  search_filter,
@@ -112,7 +132,7 @@ async def search(
112
132
  driver,
113
133
  cross_encoder,
114
134
  query,
115
- query_vector,
135
+ search_vector,
116
136
  group_ids,
117
137
  config.node_config,
118
138
  search_filter,
@@ -125,7 +145,7 @@ async def search(
125
145
  driver,
126
146
  cross_encoder,
127
147
  query,
128
- query_vector,
148
+ search_vector,
129
149
  group_ids,
130
150
  config.episode_config,
131
151
  search_filter,
@@ -136,7 +156,7 @@ async def search(
136
156
  driver,
137
157
  cross_encoder,
138
158
  query,
139
- query_vector,
159
+ search_vector,
140
160
  group_ids,
141
161
  config.community_config,
142
162
  config.limit,
@@ -305,12 +325,20 @@ async def node_search(
305
325
  search_tasks = []
306
326
  if NodeSearchMethod.bm25 in config.search_methods:
307
327
  search_tasks.append(
308
- node_fulltext_search(driver, query, search_filter, group_ids, 2 * limit)
328
+ node_fulltext_search(
329
+ driver, query, search_filter, group_ids, 2 * limit, config.use_local_indexes
330
+ )
309
331
  )
310
332
  if NodeSearchMethod.cosine_similarity in config.search_methods:
311
333
  search_tasks.append(
312
334
  node_similarity_search(
313
- driver, query_vector, search_filter, group_ids, 2 * limit, config.sim_min_score
335
+ driver,
336
+ query_vector,
337
+ search_filter,
338
+ group_ids,
339
+ 2 * limit,
340
+ config.sim_min_score,
341
+ config.use_local_indexes,
314
342
  )
315
343
  )
316
344
  if NodeSearchMethod.bfs in config.search_methods:
@@ -406,7 +434,9 @@ async def episode_search(
406
434
  search_results: list[list[EpisodicNode]] = list(
407
435
  await semaphore_gather(
408
436
  *[
409
- episode_fulltext_search(driver, query, search_filter, group_ids, 2 * limit),
437
+ episode_fulltext_search(
438
+ driver, query, search_filter, group_ids, 2 * limit, config.use_local_indexes
439
+ ),
410
440
  ]
411
441
  )
412
442
  )
@@ -24,6 +24,7 @@ from graphiti_core.search.search_utils import (
24
24
  DEFAULT_MIN_SCORE,
25
25
  DEFAULT_MMR_LAMBDA,
26
26
  MAX_SEARCH_DEPTH,
27
+ USE_HNSW,
27
28
  )
28
29
 
29
30
  DEFAULT_SEARCH_LIMIT = 10
@@ -91,6 +92,7 @@ class NodeSearchConfig(BaseModel):
91
92
  sim_min_score: float = Field(default=DEFAULT_MIN_SCORE)
92
93
  mmr_lambda: float = Field(default=DEFAULT_MMR_LAMBDA)
93
94
  bfs_max_depth: int = Field(default=MAX_SEARCH_DEPTH)
95
+ use_local_indexes: bool = Field(default=USE_HNSW)
94
96
 
95
97
 
96
98
  class EpisodeSearchConfig(BaseModel):
@@ -99,6 +101,7 @@ class EpisodeSearchConfig(BaseModel):
99
101
  sim_min_score: float = Field(default=DEFAULT_MIN_SCORE)
100
102
  mmr_lambda: float = Field(default=DEFAULT_MMR_LAMBDA)
101
103
  bfs_max_depth: int = Field(default=MAX_SEARCH_DEPTH)
104
+ use_local_indexes: bool = Field(default=USE_HNSW)
102
105
 
103
106
 
104
107
  class CommunitySearchConfig(BaseModel):
@@ -107,6 +110,7 @@ class CommunitySearchConfig(BaseModel):
107
110
  sim_min_score: float = Field(default=DEFAULT_MIN_SCORE)
108
111
  mmr_lambda: float = Field(default=DEFAULT_MMR_LAMBDA)
109
112
  bfs_max_depth: int = Field(default=MAX_SEARCH_DEPTH)
113
+ use_local_indexes: bool = Field(default=USE_HNSW)
110
114
 
111
115
 
112
116
  class SearchConfig(BaseModel):
@@ -20,6 +20,8 @@ from typing import Any
20
20
 
21
21
  from pydantic import BaseModel, Field
22
22
 
23
+ from graphiti_core.driver.driver import GraphProvider
24
+
23
25
 
24
26
  class ComparisonOperator(Enum):
25
27
  equals = '='
@@ -54,16 +56,21 @@ class SearchFilters(BaseModel):
54
56
 
55
57
  def node_search_filter_query_constructor(
56
58
  filters: SearchFilters,
57
- ) -> tuple[str, dict[str, Any]]:
58
- filter_query: str = ''
59
+ provider: GraphProvider,
60
+ ) -> tuple[list[str], dict[str, Any]]:
61
+ filter_queries: list[str] = []
59
62
  filter_params: dict[str, Any] = {}
60
63
 
61
64
  if filters.node_labels is not None:
62
- node_labels = '|'.join(filters.node_labels)
63
- node_label_filter = ' AND n:' + node_labels
64
- filter_query += node_label_filter
65
+ if provider == GraphProvider.KUZU:
66
+ node_label_filter = 'list_has_all(n.labels, $labels)'
67
+ filter_params['labels'] = filters.node_labels
68
+ else:
69
+ node_labels = '|'.join(filters.node_labels)
70
+ node_label_filter = 'n:' + node_labels
71
+ filter_queries.append(node_label_filter)
65
72
 
66
- return filter_query, filter_params
73
+ return filter_queries, filter_params
67
74
 
68
75
 
69
76
  def date_filter_query_constructor(
@@ -81,23 +88,29 @@ def date_filter_query_constructor(
81
88
 
82
89
  def edge_search_filter_query_constructor(
83
90
  filters: SearchFilters,
84
- ) -> tuple[str, dict[str, Any]]:
85
- filter_query: str = ''
91
+ provider: GraphProvider,
92
+ ) -> tuple[list[str], dict[str, Any]]:
93
+ filter_queries: list[str] = []
86
94
  filter_params: dict[str, Any] = {}
87
95
 
88
96
  if filters.edge_types is not None:
89
97
  edge_types = filters.edge_types
90
- edge_types_filter = '\nAND e.name in $edge_types'
91
- filter_query += edge_types_filter
98
+ filter_queries.append('e.name in $edge_types')
92
99
  filter_params['edge_types'] = edge_types
93
100
 
94
101
  if filters.node_labels is not None:
95
- node_labels = '|'.join(filters.node_labels)
96
- node_label_filter = '\nAND n:' + node_labels + ' AND m:' + node_labels
97
- filter_query += node_label_filter
102
+ if provider == GraphProvider.KUZU:
103
+ node_label_filter = (
104
+ 'list_has_all(n.labels, $labels) AND list_has_all(m.labels, $labels)'
105
+ )
106
+ filter_params['labels'] = filters.node_labels
107
+ else:
108
+ node_labels = '|'.join(filters.node_labels)
109
+ node_label_filter = 'n:' + node_labels + ' AND m:' + node_labels
110
+ filter_queries.append(node_label_filter)
98
111
 
99
112
  if filters.valid_at is not None:
100
- valid_at_filter = '\nAND ('
113
+ valid_at_filter = '('
101
114
  for i, or_list in enumerate(filters.valid_at):
102
115
  for j, date_filter in enumerate(or_list):
103
116
  if date_filter.comparison_operator not in [
@@ -125,10 +138,10 @@ def edge_search_filter_query_constructor(
125
138
  else:
126
139
  valid_at_filter += ' OR '
127
140
 
128
- filter_query += valid_at_filter
141
+ filter_queries.append(valid_at_filter)
129
142
 
130
143
  if filters.invalid_at is not None:
131
- invalid_at_filter = ' AND ('
144
+ invalid_at_filter = '('
132
145
  for i, or_list in enumerate(filters.invalid_at):
133
146
  for j, date_filter in enumerate(or_list):
134
147
  if date_filter.comparison_operator not in [
@@ -156,10 +169,10 @@ def edge_search_filter_query_constructor(
156
169
  else:
157
170
  invalid_at_filter += ' OR '
158
171
 
159
- filter_query += invalid_at_filter
172
+ filter_queries.append(invalid_at_filter)
160
173
 
161
174
  if filters.created_at is not None:
162
- created_at_filter = ' AND ('
175
+ created_at_filter = '('
163
176
  for i, or_list in enumerate(filters.created_at):
164
177
  for j, date_filter in enumerate(or_list):
165
178
  if date_filter.comparison_operator not in [
@@ -187,10 +200,10 @@ def edge_search_filter_query_constructor(
187
200
  else:
188
201
  created_at_filter += ' OR '
189
202
 
190
- filter_query += created_at_filter
203
+ filter_queries.append(created_at_filter)
191
204
 
192
205
  if filters.expired_at is not None:
193
- expired_at_filter = ' AND ('
206
+ expired_at_filter = '('
194
207
  for i, or_list in enumerate(filters.expired_at):
195
208
  for j, date_filter in enumerate(or_list):
196
209
  if date_filter.comparison_operator not in [
@@ -218,6 +231,6 @@ def edge_search_filter_query_constructor(
218
231
  else:
219
232
  expired_at_filter += ' OR '
220
233
 
221
- filter_query += expired_at_filter
234
+ filter_queries.append(expired_at_filter)
222
235
 
223
- return filter_query, filter_params
236
+ return filter_queries, filter_params