cognee 0.5.0__py3-none-any.whl → 0.5.0.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cognee/api/client.py +5 -1
- cognee/api/v1/add/add.py +1 -2
- cognee/api/v1/cognify/code_graph_pipeline.py +119 -0
- cognee/api/v1/cognify/cognify.py +16 -24
- cognee/api/v1/cognify/routers/__init__.py +1 -0
- cognee/api/v1/cognify/routers/get_code_pipeline_router.py +90 -0
- cognee/api/v1/cognify/routers/get_cognify_router.py +1 -3
- cognee/api/v1/datasets/routers/get_datasets_router.py +3 -3
- cognee/api/v1/ontologies/ontologies.py +37 -12
- cognee/api/v1/ontologies/routers/get_ontology_router.py +25 -27
- cognee/api/v1/search/search.py +0 -4
- cognee/api/v1/ui/ui.py +68 -38
- cognee/context_global_variables.py +16 -61
- cognee/eval_framework/answer_generation/answer_generation_executor.py +0 -10
- cognee/eval_framework/answer_generation/run_question_answering_module.py +1 -1
- cognee/eval_framework/corpus_builder/task_getters/get_cascade_graph_tasks.py +2 -0
- cognee/eval_framework/corpus_builder/task_getters/get_default_tasks_by_indices.py +4 -4
- cognee/eval_framework/eval_config.py +2 -2
- cognee/eval_framework/modal_run_eval.py +28 -16
- cognee/infrastructure/databases/graph/config.py +0 -3
- cognee/infrastructure/databases/graph/get_graph_engine.py +0 -1
- cognee/infrastructure/databases/graph/graph_db_interface.py +0 -15
- cognee/infrastructure/databases/graph/kuzu/adapter.py +0 -228
- cognee/infrastructure/databases/graph/neo4j_driver/adapter.py +1 -80
- cognee/infrastructure/databases/utils/__init__.py +0 -3
- cognee/infrastructure/databases/utils/get_or_create_dataset_database.py +48 -62
- cognee/infrastructure/databases/vector/config.py +0 -2
- cognee/infrastructure/databases/vector/create_vector_engine.py +0 -1
- cognee/infrastructure/databases/vector/embeddings/FastembedEmbeddingEngine.py +6 -8
- cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py +7 -9
- cognee/infrastructure/databases/vector/embeddings/OllamaEmbeddingEngine.py +10 -11
- cognee/infrastructure/databases/vector/embeddings/embedding_rate_limiter.py +544 -0
- cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py +0 -2
- cognee/infrastructure/databases/vector/vector_db_interface.py +0 -35
- cognee/infrastructure/files/storage/s3_config.py +0 -2
- cognee/infrastructure/llm/LLMGateway.py +2 -5
- cognee/infrastructure/llm/config.py +0 -35
- cognee/infrastructure/llm/extraction/knowledge_graph/extract_content_graph.py +2 -2
- cognee/infrastructure/llm/structured_output_framework/baml/baml_src/extraction/acreate_structured_output.py +8 -23
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py +16 -17
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py +37 -40
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py +36 -39
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py +1 -19
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py +9 -11
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/adapter.py +21 -23
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py +34 -42
- cognee/modules/cognify/config.py +0 -2
- cognee/modules/data/deletion/prune_system.py +2 -52
- cognee/modules/data/methods/delete_dataset.py +0 -26
- cognee/modules/engine/models/__init__.py +0 -1
- cognee/modules/graph/cognee_graph/CogneeGraph.py +37 -85
- cognee/modules/graph/cognee_graph/CogneeGraphElements.py +3 -8
- cognee/modules/memify/memify.py +7 -1
- cognee/modules/pipelines/operations/pipeline.py +2 -18
- cognee/modules/retrieval/__init__.py +1 -1
- cognee/modules/retrieval/code_retriever.py +232 -0
- cognee/modules/retrieval/graph_completion_context_extension_retriever.py +0 -4
- cognee/modules/retrieval/graph_completion_cot_retriever.py +0 -4
- cognee/modules/retrieval/graph_completion_retriever.py +0 -10
- cognee/modules/retrieval/graph_summary_completion_retriever.py +0 -4
- cognee/modules/retrieval/temporal_retriever.py +0 -4
- cognee/modules/retrieval/utils/brute_force_triplet_search.py +10 -42
- cognee/modules/run_custom_pipeline/run_custom_pipeline.py +1 -8
- cognee/modules/search/methods/get_search_type_tools.py +8 -54
- cognee/modules/search/methods/no_access_control_search.py +0 -4
- cognee/modules/search/methods/search.py +0 -21
- cognee/modules/search/types/SearchType.py +1 -1
- cognee/modules/settings/get_settings.py +0 -19
- cognee/modules/users/methods/get_authenticated_user.py +2 -2
- cognee/modules/users/models/DatasetDatabase.py +3 -15
- cognee/shared/logging_utils.py +0 -4
- cognee/tasks/code/enrich_dependency_graph_checker.py +35 -0
- cognee/tasks/code/get_local_dependencies_checker.py +20 -0
- cognee/tasks/code/get_repo_dependency_graph_checker.py +35 -0
- cognee/tasks/documents/__init__.py +1 -0
- cognee/tasks/documents/check_permissions_on_dataset.py +26 -0
- cognee/tasks/graph/extract_graph_from_data.py +10 -9
- cognee/tasks/repo_processor/__init__.py +2 -0
- cognee/tasks/repo_processor/get_local_dependencies.py +335 -0
- cognee/tasks/repo_processor/get_non_code_files.py +158 -0
- cognee/tasks/repo_processor/get_repo_file_dependencies.py +243 -0
- cognee/tasks/storage/add_data_points.py +2 -142
- cognee/tests/test_cognee_server_start.py +4 -2
- cognee/tests/test_conversation_history.py +1 -23
- cognee/tests/test_delete_bmw_example.py +60 -0
- cognee/tests/test_search_db.py +1 -37
- cognee/tests/unit/api/test_ontology_endpoint.py +89 -77
- cognee/tests/unit/infrastructure/mock_embedding_engine.py +7 -3
- cognee/tests/unit/infrastructure/test_embedding_rate_limiting_realistic.py +5 -0
- cognee/tests/unit/modules/graph/cognee_graph_elements_test.py +2 -2
- cognee/tests/unit/modules/graph/cognee_graph_test.py +0 -406
- {cognee-0.5.0.dist-info → cognee-0.5.0.dev0.dist-info}/METADATA +89 -76
- {cognee-0.5.0.dist-info → cognee-0.5.0.dev0.dist-info}/RECORD +97 -118
- {cognee-0.5.0.dist-info → cognee-0.5.0.dev0.dist-info}/WHEEL +1 -1
- cognee/api/v1/ui/node_setup.py +0 -360
- cognee/api/v1/ui/npm_utils.py +0 -50
- cognee/eval_framework/Dockerfile +0 -29
- cognee/infrastructure/databases/dataset_database_handler/__init__.py +0 -3
- cognee/infrastructure/databases/dataset_database_handler/dataset_database_handler_interface.py +0 -80
- cognee/infrastructure/databases/dataset_database_handler/supported_dataset_database_handlers.py +0 -18
- cognee/infrastructure/databases/dataset_database_handler/use_dataset_database_handler.py +0 -10
- cognee/infrastructure/databases/graph/kuzu/KuzuDatasetDatabaseHandler.py +0 -81
- cognee/infrastructure/databases/graph/neo4j_driver/Neo4jAuraDevDatasetDatabaseHandler.py +0 -168
- cognee/infrastructure/databases/utils/get_graph_dataset_database_handler.py +0 -10
- cognee/infrastructure/databases/utils/get_vector_dataset_database_handler.py +0 -10
- cognee/infrastructure/databases/utils/resolve_dataset_database_connection_info.py +0 -30
- cognee/infrastructure/databases/vector/lancedb/LanceDBDatasetDatabaseHandler.py +0 -50
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/bedrock/__init__.py +0 -5
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/bedrock/adapter.py +0 -153
- cognee/memify_pipelines/create_triplet_embeddings.py +0 -53
- cognee/modules/engine/models/Triplet.py +0 -9
- cognee/modules/retrieval/register_retriever.py +0 -10
- cognee/modules/retrieval/registered_community_retrievers.py +0 -1
- cognee/modules/retrieval/triplet_retriever.py +0 -182
- cognee/shared/rate_limiting.py +0 -30
- cognee/tasks/memify/get_triplet_datapoints.py +0 -289
- cognee/tests/integration/retrieval/test_triplet_retriever.py +0 -84
- cognee/tests/integration/tasks/test_add_data_points.py +0 -139
- cognee/tests/integration/tasks/test_get_triplet_datapoints.py +0 -69
- cognee/tests/test_dataset_database_handler.py +0 -137
- cognee/tests/test_dataset_delete.py +0 -76
- cognee/tests/test_edge_centered_payload.py +0 -170
- cognee/tests/test_pipeline_cache.py +0 -164
- cognee/tests/unit/infrastructure/llm/test_llm_config.py +0 -46
- cognee/tests/unit/modules/memify_tasks/test_get_triplet_datapoints.py +0 -214
- cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py +0 -608
- cognee/tests/unit/modules/retrieval/triplet_retriever_test.py +0 -83
- cognee/tests/unit/tasks/storage/test_add_data_points.py +0 -288
- {cognee-0.5.0.dist-info → cognee-0.5.0.dev0.dist-info}/entry_points.txt +0 -0
- {cognee-0.5.0.dist-info → cognee-0.5.0.dev0.dist-info}/licenses/LICENSE +0 -0
- {cognee-0.5.0.dist-info → cognee-0.5.0.dev0.dist-info}/licenses/NOTICE.md +0 -0
|
@@ -0,0 +1,544 @@
|
|
|
1
|
+
import threading
|
|
2
|
+
import logging
|
|
3
|
+
import functools
|
|
4
|
+
import os
|
|
5
|
+
import time
|
|
6
|
+
import asyncio
|
|
7
|
+
import random
|
|
8
|
+
from cognee.shared.logging_utils import get_logger
|
|
9
|
+
from cognee.infrastructure.llm.config import get_llm_config
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
# Module-level logger from cognee's logging utilities; used by the rate limiter
# and the decorators defined below for their warning/info messages.
logger = get_logger()
|
|
13
|
+
|
|
14
|
+
# Substrings that mark an exception message as a rate-limit / throttling error.
# The retry decorators in this module test them against ``str(exc).lower()``,
# so every entry must itself be lower-case.
RATE_LIMIT_ERROR_PATTERNS = [
    # Generic spellings of "rate limit"
    "rate limit",
    "rate_limit",
    "ratelimit",
    # HTTP 429-style wording
    "too many requests",
    "retry after",
    # Capacity / quota exhaustion
    "capacity",
    "quota",
    "limit exceeded",
    "tps limit exceeded",
    "request limit exceeded",
    "maximum requests",
    "exceeded your current quota",
    # Provider-side throttling
    "throttled",
    "throttling",
]

# Default retry settings. Note: the retry decorators in this module declare
# their own keyword defaults and do not read these constants.
DEFAULT_MAX_RETRIES: int = 5
DEFAULT_INITIAL_BACKOFF: float = 1.0  # seconds
DEFAULT_BACKOFF_FACTOR: float = 2.0  # exponential backoff multiplier
DEFAULT_JITTER: float = 0.1  # 10% jitter to avoid thundering herd
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
class EmbeddingRateLimiter:
    """
    Sliding-window rate limiter for embedding API calls.

    The class is used as a process-wide singleton (see ``get_instance``) so that
    every embedding request is counted against the same window. It keeps the
    ``time.time()`` timestamps of the requests admitted during the last
    ``interval_seconds`` seconds and admits at most ``requests_limit`` of them
    per window. Settings are read from the LLM config fields
    ``embedding_rate_limit_enabled``, ``embedding_rate_limit_requests`` and
    ``embedding_rate_limit_interval``.

    Public Methods:

        - get_instance
        - reset_instance
        - hit_limit
        - wait_if_needed
        - async_wait_if_needed

    Instance Variables:

        - enabled: when False every check passes and nothing is recorded
        - requests_limit: maximum number of requests admitted per window
        - interval_seconds: length of the sliding window in seconds
        - request_times: timestamps of requests admitted within the window
        - lock: guards ``request_times``
    """

    _instance = None
    # Guards creation/reset of the singleton. Instances assign their own
    # ``self.lock`` in ``__init__``, which shadows this attribute on the instance.
    lock = threading.Lock()

    # Seconds to sleep between re-checks while waiting for a free slot.
    _POLL_INTERVAL = 0.5

    @classmethod
    def get_instance(cls):
        """
        Return the shared EmbeddingRateLimiter, creating it on first use.

        The instance is created under the class-level lock with a second
        ``None`` check inside the lock, so concurrent first calls still end up
        sharing a single instance.

        Returns:
        --------

            The singleton instance of the EmbeddingRateLimiter class.
        """
        if cls._instance is None:
            with cls.lock:
                if cls._instance is None:
                    cls._instance = cls()
        return cls._instance

    @classmethod
    def reset_instance(cls):
        """
        Discard the shared instance so the next ``get_instance`` builds a new one.

        The reset happens under the class-level lock.
        """
        with cls.lock:
            cls._instance = None

    def __init__(self):
        config = get_llm_config()
        self.enabled = config.embedding_rate_limit_enabled
        self.requests_limit = config.embedding_rate_limit_requests
        self.interval_seconds = config.embedding_rate_limit_interval
        self.request_times = []
        self.lock = threading.Lock()

        # Log through the module-level logger, like the rest of this class,
        # instead of the root ``logging`` module.
        logger.info(
            f"EmbeddingRateLimiter initialized: enabled={self.enabled}, "
            f"requests_limit={self.requests_limit}, interval_seconds={self.interval_seconds}"
        )

    def _prune_expired(self, now: float) -> None:
        """
        Drop timestamps that fell out of the window ending at ``now``.

        The caller must hold ``self.lock``.
        """
        cutoff_time = now - self.interval_seconds
        self.request_times = [t for t in self.request_times if t > cutoff_time]

    def _try_acquire(self) -> bool:
        """
        Atomically admit and record one request if the window has room.

        Pruning, the limit check and recording the timestamp all happen within
        a single acquisition of ``self.lock``. Checking and recording in
        separate critical sections would let concurrent callers all pass the
        check and overshoot ``requests_limit``.

        Returns:
        --------

            - bool: True if the request was admitted and recorded, False if the
              window is currently full.
        """
        with self.lock:
            now = time.time()
            self._prune_expired(now)
            if len(self.request_times) >= self.requests_limit:
                return False
            self.request_times.append(now)
            return True

    def hit_limit(self) -> bool:
        """
        Report whether one more request would exceed the limit right now.

        Expired timestamps are pruned as a side effect; no request is recorded.

        Returns:
        --------

            - bool: True if the window is full, otherwise False (always False
              when the limiter is disabled).
        """
        if not self.enabled:
            return False

        with self.lock:
            self._prune_expired(time.time())
            if len(self.request_times) >= self.requests_limit:
                logger.info(
                    f"Rate limit hit: {len(self.request_times)} requests in the last {self.interval_seconds} seconds"
                )
                return True

        return False

    def wait_if_needed(self) -> float:
        """
        Block the calling thread until a request can be admitted, then record it.

        NOTE: when the limiter is enabled with ``requests_limit <= 0`` no request
        can ever be admitted, and this call never returns.

        Returns:
        --------

            - float: Seconds spent waiting (0 when admitted immediately or when
              the limiter is disabled).
        """
        if not self.enabled:
            return 0.0

        wait_time = 0.0
        start_time = time.time()
        while not self._try_acquire():
            time.sleep(self._POLL_INTERVAL)
            wait_time = time.time() - start_time

        return wait_time

    async def async_wait_if_needed(self) -> float:
        """
        Asynchronously wait until a request can be admitted, then record it.

        Same semantics as ``wait_if_needed`` but yields to the event loop while
        waiting. The same ``requests_limit <= 0`` caveat applies.

        Returns:
        --------

            - float: Seconds spent waiting (0 when admitted immediately or when
              the limiter is disabled).
        """
        if not self.enabled:
            return 0.0

        wait_time = 0.0
        start_time = time.time()
        while not self._try_acquire():
            await asyncio.sleep(self._POLL_INTERVAL)
            wait_time = time.time() - start_time

        return wait_time
|
|
208
|
+
|
|
209
|
+
|
|
210
|
+
def embedding_rate_limit_sync(func):
    """
    Guard a synchronous embedding function with the shared embedding rate limiter.

    Parameters:
    -----------

        - func: The synchronous callable to guard.

    Returns:
    --------

        A wrapper carrying ``func``'s metadata. When the shared limiter reports
        that its window is full, the wrapper logs a warning and raises
        ``EmbeddingException`` without calling ``func``. Otherwise it records the
        request through the limiter and returns ``func``'s result.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        """
        Run one call of the wrapped function through the shared rate limiter.

        Parameters:
        -----------

            - *args: Positional arguments forwarded to the wrapped function.
            - **kwargs: Keyword arguments forwarded to the wrapped function.

        Returns:
        --------

            Whatever the wrapped function returns.
        """
        rate_limiter = EmbeddingRateLimiter.get_instance()

        if not rate_limiter.hit_limit():
            # Under the limit: register this request. ``wait_if_needed`` may
            # still block if other callers filled the window in the meantime.
            rate_limiter.wait_if_needed()
            return func(*args, **kwargs)

        message = "Embedding API rate limit exceeded"
        logger.warning(message)

        # Imported at the point of use (presumably to avoid an import cycle).
        from cognee.infrastructure.databases.exceptions import EmbeddingException

        raise EmbeddingException(message)

    return wrapper
|
|
262
|
+
|
|
263
|
+
|
|
264
|
+
def embedding_rate_limit_async(func):
    """
    Guard an asynchronous embedding function with the shared embedding rate limiter.

    Parameters:
    -----------

        - func: The coroutine function to guard.

    Returns:
    --------

        An async wrapper carrying ``func``'s metadata. When the shared limiter
        reports that its window is full, the wrapper logs a warning and raises
        ``EmbeddingException`` without awaiting ``func``. Otherwise it records the
        request through the limiter and returns the awaited result of ``func``.
    """

    @functools.wraps(func)
    async def wrapper(*args, **kwargs):
        """
        Run one call of the wrapped coroutine function through the shared rate limiter.

        Parameters:
        -----------

            - *args: Positional arguments forwarded to the wrapped function.
            - **kwargs: Keyword arguments forwarded to the wrapped function.

        Returns:
        --------

            Whatever the awaited wrapped function returns.
        """
        rate_limiter = EmbeddingRateLimiter.get_instance()

        if not rate_limiter.hit_limit():
            # Under the limit: register this request. ``async_wait_if_needed``
            # may still wait if other callers filled the window in the meantime.
            await rate_limiter.async_wait_if_needed()
            return await func(*args, **kwargs)

        message = "Embedding API rate limit exceeded"
        logger.warning(message)

        # Imported at the point of use (presumably to avoid an import cycle).
        from cognee.infrastructure.databases.exceptions import EmbeddingException

        raise EmbeddingException(message)

    return wrapper
|
|
317
|
+
|
|
318
|
+
|
|
319
|
+
def embedding_sleep_and_retry_sync(max_retries=5, base_backoff=1.0, jitter=0.5):
    """
    Build a decorator that retries a synchronous embedding call on rate-limit errors.

    An exception counts as a rate-limit error when its lower-cased message
    contains any entry of ``RATE_LIMIT_ERROR_PATTERNS``. Such errors are retried
    up to ``max_retries`` times. Between attempts the wrapper sleeps
    ``base_backoff * 2**attempt`` seconds, scaled by a random factor in
    ``[1 - jitter, 1 + jitter]`` and clamped at zero. Any other exception, or a
    rate-limit error once retries are exhausted, is re-raised unchanged. Setting
    the ``DISABLE_RETRIES`` environment variable to ``true``/``1``/``yes``
    disables retrying entirely (intended for tests).

    The wrapped function is always attempted at least once, even when
    ``max_retries`` is negative.

    Parameters:
    -----------

        - max_retries: Maximum number of retries after the first attempt. (default 5)
        - base_backoff: Base backoff time in seconds for retry intervals. (default 1.0)
        - jitter: Relative jitter applied to each backoff to avoid synchronized
          retries. (default 0.5)

    Returns:
    --------

        A decorator that retries the wrapped function on rate limit errors, applying
        exponential backoff with jitter.
    """

    def decorator(func):
        """
        Wrap ``func`` with the retry-on-rate-limit logic.

        Parameters:
        -----------

            - func: The function to be wrapped with retry logic.

        Returns:
        --------

            The wrapped function, carrying ``func``'s metadata.
        """

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            """
            Call the wrapped function, retrying with backoff on rate-limit errors.

            Returns:
            --------

                The wrapped function's result. Otherwise it re-raises the first
                non-rate-limit error, or the last rate-limit error once
                ``max_retries`` retries are used up.
            """
            # Read on every call so tests can toggle it at runtime.
            disable_retries = os.environ.get("DISABLE_RETRIES", "false").lower() in (
                "true",
                "1",
                "yes",
            )

            retries = 0
            # Every iteration either returns, re-raises, or sleeps and retries,
            # so the function is attempted at least once for any max_retries.
            while True:
                try:
                    return func(*args, **kwargs)
                except Exception as e:
                    if disable_retries:
                        # For testing, propagate the exception immediately
                        raise

                    error_str = str(e).lower()
                    error_type = type(e).__name__
                    is_rate_limit = any(
                        pattern in error_str for pattern in RATE_LIMIT_ERROR_PATTERNS
                    )

                    if not is_rate_limit or retries >= max_retries:
                        # Not a rate limit error or max retries reached, raise
                        raise

                    # Clamp at zero: with jitter > 1 the factor can go negative,
                    # and time.sleep rejects negative durations.
                    backoff = max(
                        0.0,
                        base_backoff * (2**retries) * (1 + random.uniform(-jitter, jitter)),
                    )

                    logger.warning(
                        f"Embedding rate limit hit, retrying in {backoff:.2f}s "
                        f"(attempt {retries + 1}/{max_retries}): "
                        f"({error_str!r}, {error_type!r})"
                    )

                    time.sleep(backoff)
                    retries += 1

        return wrapper

    return decorator
|
|
428
|
+
|
|
429
|
+
|
|
430
|
+
def embedding_sleep_and_retry_async(max_retries=5, base_backoff=1.0, jitter=0.5):
    """
    Build a decorator that retries an async embedding call on rate-limit errors.

    An exception counts as a rate-limit error when its lower-cased message
    contains any entry of ``RATE_LIMIT_ERROR_PATTERNS``. Such errors are retried
    up to ``max_retries`` times. Between attempts the wrapper awaits
    ``asyncio.sleep`` for ``base_backoff * 2**attempt`` seconds, scaled by a
    random factor in ``[1 - jitter, 1 + jitter]`` and clamped at zero. Any other
    exception, or a rate-limit error once retries are exhausted, is re-raised
    unchanged. Setting the ``DISABLE_RETRIES`` environment variable to
    ``true``/``1``/``yes`` disables retrying entirely (intended for tests).

    The wrapped coroutine function is always attempted at least once, even when
    ``max_retries`` is negative.

    Parameters:
    -----------

        - max_retries: Maximum number of retries after the first attempt. (default 5)
        - base_backoff: Base amount of time in seconds to wait before retrying after a rate
          limit error. (default 1.0)
        - jitter: Relative randomness applied to each backoff to avoid synchronized
          retries. (default 0.5)

    Returns:
    --------

        A decorator producing an async wrapper that implements the retry logic on
        rate limit errors.
    """

    def decorator(func):
        """
        Wrap the coroutine function ``func`` with the retry-on-rate-limit logic.

        Parameters:
        -----------

            - func: An asynchronous function to be wrapped with retry logic.

        Returns:
        --------

            The async wrapper, carrying ``func``'s metadata.
        """

        @functools.wraps(func)
        async def wrapper(*args, **kwargs):
            """
            Await the wrapped function, retrying with backoff on rate-limit errors.

            Returns:
            --------

                The awaited result of the wrapped function. Otherwise it re-raises
                the first non-rate-limit error, or the last rate-limit error once
                ``max_retries`` retries are used up.
            """
            # Read on every call so tests can toggle it at runtime.
            disable_retries = os.environ.get("DISABLE_RETRIES", "false").lower() in (
                "true",
                "1",
                "yes",
            )

            retries = 0
            # Every iteration either returns, re-raises, or sleeps and retries,
            # so the function is attempted at least once for any max_retries.
            while True:
                try:
                    return await func(*args, **kwargs)
                except Exception as e:
                    if disable_retries:
                        # For testing, propagate the exception immediately
                        raise

                    error_str = str(e).lower()
                    error_type = type(e).__name__
                    is_rate_limit = any(
                        pattern in error_str for pattern in RATE_LIMIT_ERROR_PATTERNS
                    )

                    if not is_rate_limit or retries >= max_retries:
                        # Not a rate limit error or max retries reached, raise
                        raise

                    # Clamp at zero: with jitter > 1 the factor can go negative.
                    backoff = max(
                        0.0,
                        base_backoff * (2**retries) * (1 + random.uniform(-jitter, jitter)),
                    )

                    logger.warning(
                        f"Embedding rate limit hit, retrying in {backoff:.2f}s "
                        f"(attempt {retries + 1}/{max_retries}): "
                        f"({error_str!r}, {error_type!r})"
                    )

                    await asyncio.sleep(backoff)
                    retries += 1

        return wrapper

    return decorator
|
|
@@ -193,8 +193,6 @@ class LanceDBAdapter(VectorDBInterface):
|
|
|
193
193
|
for (data_point_index, data_point) in enumerate(data_points)
|
|
194
194
|
]
|
|
195
195
|
|
|
196
|
-
lance_data_points = list({dp.id: dp for dp in lance_data_points}.values())
|
|
197
|
-
|
|
198
196
|
async with self.VECTOR_DB_LOCK:
|
|
199
197
|
await (
|
|
200
198
|
collection.merge_insert("id")
|
|
@@ -2,8 +2,6 @@ from typing import List, Protocol, Optional, Union, Any
|
|
|
2
2
|
from abc import abstractmethod
|
|
3
3
|
from cognee.infrastructure.engine import DataPoint
|
|
4
4
|
from .models.PayloadSchema import PayloadSchema
|
|
5
|
-
from uuid import UUID
|
|
6
|
-
from cognee.modules.users.models import User
|
|
7
5
|
|
|
8
6
|
|
|
9
7
|
class VectorDBInterface(Protocol):
|
|
@@ -219,36 +217,3 @@ class VectorDBInterface(Protocol):
|
|
|
219
217
|
- Any: The schema object suitable for this vector database
|
|
220
218
|
"""
|
|
221
219
|
return model_type
|
|
222
|
-
|
|
223
|
-
@classmethod
|
|
224
|
-
async def create_dataset(cls, dataset_id: Optional[UUID], user: Optional[User]) -> dict:
|
|
225
|
-
"""
|
|
226
|
-
Return a dictionary with connection info for a vector database for the given dataset.
|
|
227
|
-
Function can auto handle deploying of the actual database if needed, but is not necessary.
|
|
228
|
-
Only providing connection info is sufficient, this info will be mapped when trying to connect to the provided dataset in the future.
|
|
229
|
-
Needed for Cognee multi-tenant/multi-user and backend access control support.
|
|
230
|
-
|
|
231
|
-
Dictionary returned from this function will be used to create a DatasetDatabase row in the relational database.
|
|
232
|
-
From which internal mapping of dataset -> database connection info will be done.
|
|
233
|
-
|
|
234
|
-
Each dataset needs to map to a unique vector database when backend access control is enabled to facilitate a separation of concern for data.
|
|
235
|
-
|
|
236
|
-
Args:
|
|
237
|
-
dataset_id: UUID of the dataset if needed by the database creation logic
|
|
238
|
-
user: User object if needed by the database creation logic
|
|
239
|
-
Returns:
|
|
240
|
-
dict: Connection info for the created vector database instance.
|
|
241
|
-
"""
|
|
242
|
-
pass
|
|
243
|
-
|
|
244
|
-
async def delete_dataset(self, dataset_id: UUID, user: User) -> None:
|
|
245
|
-
"""
|
|
246
|
-
Delete the vector database for the given dataset.
|
|
247
|
-
Function should auto handle deleting of the actual database or send a request to the proper service to delete the database.
|
|
248
|
-
Needed for maintaining a database for Cognee multi-tenant/multi-user and backend access control.
|
|
249
|
-
|
|
250
|
-
Args:
|
|
251
|
-
dataset_id: UUID of the dataset
|
|
252
|
-
user: User object
|
|
253
|
-
"""
|
|
254
|
-
pass
|
|
@@ -9,8 +9,6 @@ class S3Config(BaseSettings):
|
|
|
9
9
|
aws_access_key_id: Optional[str] = None
|
|
10
10
|
aws_secret_access_key: Optional[str] = None
|
|
11
11
|
aws_session_token: Optional[str] = None
|
|
12
|
-
aws_profile_name: Optional[str] = None
|
|
13
|
-
aws_bedrock_runtime_endpoint: Optional[str] = None
|
|
14
12
|
model_config = SettingsConfigDict(env_file=".env", extra="allow")
|
|
15
13
|
|
|
16
14
|
|
|
@@ -11,7 +11,7 @@ class LLMGateway:
|
|
|
11
11
|
|
|
12
12
|
@staticmethod
|
|
13
13
|
def acreate_structured_output(
|
|
14
|
-
text_input: str, system_prompt: str, response_model: Type[BaseModel]
|
|
14
|
+
text_input: str, system_prompt: str, response_model: Type[BaseModel]
|
|
15
15
|
) -> Coroutine:
|
|
16
16
|
llm_config = get_llm_config()
|
|
17
17
|
if llm_config.structured_output_framework.upper() == "BAML":
|
|
@@ -31,10 +31,7 @@ class LLMGateway:
|
|
|
31
31
|
|
|
32
32
|
llm_client = get_llm_client()
|
|
33
33
|
return llm_client.acreate_structured_output(
|
|
34
|
-
text_input=text_input,
|
|
35
|
-
system_prompt=system_prompt,
|
|
36
|
-
response_model=response_model,
|
|
37
|
-
**kwargs,
|
|
34
|
+
text_input=text_input, system_prompt=system_prompt, response_model=response_model
|
|
38
35
|
)
|
|
39
36
|
|
|
40
37
|
@staticmethod
|
|
@@ -74,41 +74,6 @@ class LLMConfig(BaseSettings):
|
|
|
74
74
|
|
|
75
75
|
model_config = SettingsConfigDict(env_file=".env", extra="allow")
|
|
76
76
|
|
|
77
|
-
@model_validator(mode="after")
|
|
78
|
-
def strip_quotes_from_strings(self) -> "LLMConfig":
|
|
79
|
-
"""
|
|
80
|
-
Strip surrounding quotes from specific string fields that often come from
|
|
81
|
-
environment variables with extra quotes (e.g., via Docker's --env-file).
|
|
82
|
-
|
|
83
|
-
Only applies to known config keys where quotes are invalid or cause issues.
|
|
84
|
-
"""
|
|
85
|
-
string_fields_to_strip = [
|
|
86
|
-
"llm_api_key",
|
|
87
|
-
"llm_endpoint",
|
|
88
|
-
"llm_api_version",
|
|
89
|
-
"baml_llm_api_key",
|
|
90
|
-
"baml_llm_endpoint",
|
|
91
|
-
"baml_llm_api_version",
|
|
92
|
-
"fallback_api_key",
|
|
93
|
-
"fallback_endpoint",
|
|
94
|
-
"fallback_model",
|
|
95
|
-
"llm_provider",
|
|
96
|
-
"llm_model",
|
|
97
|
-
"baml_llm_provider",
|
|
98
|
-
"baml_llm_model",
|
|
99
|
-
]
|
|
100
|
-
|
|
101
|
-
cls = self.__class__
|
|
102
|
-
for field_name in string_fields_to_strip:
|
|
103
|
-
if field_name not in cls.model_fields:
|
|
104
|
-
continue
|
|
105
|
-
value = getattr(self, field_name, None)
|
|
106
|
-
if isinstance(value, str) and len(value) >= 2:
|
|
107
|
-
if value[0] == value[-1] and value[0] in ("'", '"'):
|
|
108
|
-
setattr(self, field_name, value[1:-1])
|
|
109
|
-
|
|
110
|
-
return self
|
|
111
|
-
|
|
112
77
|
def model_post_init(self, __context) -> None:
|
|
113
78
|
"""Initialize the BAML registry after the model is created."""
|
|
114
79
|
# Check if BAML is selected as structured output framework but not available
|
|
@@ -10,7 +10,7 @@ from cognee.infrastructure.llm.config import (
|
|
|
10
10
|
|
|
11
11
|
|
|
12
12
|
async def extract_content_graph(
|
|
13
|
-
content: str, response_model: Type[BaseModel], custom_prompt: Optional[str] = None
|
|
13
|
+
content: str, response_model: Type[BaseModel], custom_prompt: Optional[str] = None
|
|
14
14
|
):
|
|
15
15
|
if custom_prompt:
|
|
16
16
|
system_prompt = custom_prompt
|
|
@@ -30,7 +30,7 @@ async def extract_content_graph(
|
|
|
30
30
|
system_prompt = render_prompt(prompt_path, {}, base_directory=base_directory)
|
|
31
31
|
|
|
32
32
|
content_graph = await LLMGateway.acreate_structured_output(
|
|
33
|
-
content, system_prompt, response_model
|
|
33
|
+
content, system_prompt, response_model
|
|
34
34
|
)
|
|
35
35
|
|
|
36
36
|
return content_graph
|