cognee 0.5.0.dev0__py3-none-any.whl → 0.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cognee/api/client.py +1 -5
- cognee/api/v1/add/add.py +2 -1
- cognee/api/v1/cognify/cognify.py +24 -16
- cognee/api/v1/cognify/routers/__init__.py +0 -1
- cognee/api/v1/cognify/routers/get_cognify_router.py +3 -1
- cognee/api/v1/datasets/routers/get_datasets_router.py +3 -3
- cognee/api/v1/ontologies/ontologies.py +12 -37
- cognee/api/v1/ontologies/routers/get_ontology_router.py +27 -25
- cognee/api/v1/search/search.py +8 -0
- cognee/api/v1/ui/node_setup.py +360 -0
- cognee/api/v1/ui/npm_utils.py +50 -0
- cognee/api/v1/ui/ui.py +38 -68
- cognee/context_global_variables.py +61 -16
- cognee/eval_framework/Dockerfile +29 -0
- cognee/eval_framework/answer_generation/answer_generation_executor.py +10 -0
- cognee/eval_framework/answer_generation/run_question_answering_module.py +1 -1
- cognee/eval_framework/corpus_builder/task_getters/get_cascade_graph_tasks.py +0 -2
- cognee/eval_framework/corpus_builder/task_getters/get_default_tasks_by_indices.py +4 -4
- cognee/eval_framework/eval_config.py +2 -2
- cognee/eval_framework/modal_run_eval.py +16 -28
- cognee/infrastructure/databases/dataset_database_handler/__init__.py +3 -0
- cognee/infrastructure/databases/dataset_database_handler/dataset_database_handler_interface.py +80 -0
- cognee/infrastructure/databases/dataset_database_handler/supported_dataset_database_handlers.py +18 -0
- cognee/infrastructure/databases/dataset_database_handler/use_dataset_database_handler.py +10 -0
- cognee/infrastructure/databases/graph/config.py +3 -0
- cognee/infrastructure/databases/graph/get_graph_engine.py +1 -0
- cognee/infrastructure/databases/graph/graph_db_interface.py +15 -0
- cognee/infrastructure/databases/graph/kuzu/KuzuDatasetDatabaseHandler.py +81 -0
- cognee/infrastructure/databases/graph/kuzu/adapter.py +228 -0
- cognee/infrastructure/databases/graph/neo4j_driver/Neo4jAuraDevDatasetDatabaseHandler.py +168 -0
- cognee/infrastructure/databases/graph/neo4j_driver/adapter.py +80 -1
- cognee/infrastructure/databases/utils/__init__.py +3 -0
- cognee/infrastructure/databases/utils/get_graph_dataset_database_handler.py +10 -0
- cognee/infrastructure/databases/utils/get_or_create_dataset_database.py +62 -48
- cognee/infrastructure/databases/utils/get_vector_dataset_database_handler.py +10 -0
- cognee/infrastructure/databases/utils/resolve_dataset_database_connection_info.py +30 -0
- cognee/infrastructure/databases/vector/config.py +2 -0
- cognee/infrastructure/databases/vector/create_vector_engine.py +1 -0
- cognee/infrastructure/databases/vector/embeddings/FastembedEmbeddingEngine.py +8 -6
- cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py +9 -7
- cognee/infrastructure/databases/vector/embeddings/OllamaEmbeddingEngine.py +11 -10
- cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py +2 -0
- cognee/infrastructure/databases/vector/lancedb/LanceDBDatasetDatabaseHandler.py +50 -0
- cognee/infrastructure/databases/vector/vector_db_interface.py +35 -0
- cognee/infrastructure/files/storage/s3_config.py +2 -0
- cognee/infrastructure/llm/LLMGateway.py +5 -2
- cognee/infrastructure/llm/config.py +35 -0
- cognee/infrastructure/llm/extraction/knowledge_graph/extract_content_graph.py +2 -2
- cognee/infrastructure/llm/structured_output_framework/baml/baml_src/extraction/acreate_structured_output.py +23 -8
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py +17 -16
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/bedrock/__init__.py +5 -0
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/bedrock/adapter.py +153 -0
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py +40 -37
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py +39 -36
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py +19 -1
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py +11 -9
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/adapter.py +23 -21
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py +42 -34
- cognee/memify_pipelines/create_triplet_embeddings.py +53 -0
- cognee/modules/cognify/config.py +2 -0
- cognee/modules/data/deletion/prune_system.py +52 -2
- cognee/modules/data/methods/delete_dataset.py +26 -0
- cognee/modules/engine/models/Triplet.py +9 -0
- cognee/modules/engine/models/__init__.py +1 -0
- cognee/modules/graph/cognee_graph/CogneeGraph.py +85 -37
- cognee/modules/graph/cognee_graph/CogneeGraphElements.py +8 -3
- cognee/modules/memify/memify.py +1 -7
- cognee/modules/pipelines/operations/pipeline.py +18 -2
- cognee/modules/retrieval/__init__.py +1 -1
- cognee/modules/retrieval/graph_completion_context_extension_retriever.py +4 -0
- cognee/modules/retrieval/graph_completion_cot_retriever.py +4 -0
- cognee/modules/retrieval/graph_completion_retriever.py +10 -0
- cognee/modules/retrieval/graph_summary_completion_retriever.py +4 -0
- cognee/modules/retrieval/register_retriever.py +10 -0
- cognee/modules/retrieval/registered_community_retrievers.py +1 -0
- cognee/modules/retrieval/temporal_retriever.py +4 -0
- cognee/modules/retrieval/triplet_retriever.py +182 -0
- cognee/modules/retrieval/utils/brute_force_triplet_search.py +42 -10
- cognee/modules/run_custom_pipeline/run_custom_pipeline.py +8 -1
- cognee/modules/search/methods/get_search_type_tools.py +54 -8
- cognee/modules/search/methods/no_access_control_search.py +4 -0
- cognee/modules/search/methods/search.py +46 -18
- cognee/modules/search/types/SearchType.py +1 -1
- cognee/modules/settings/get_settings.py +19 -0
- cognee/modules/users/methods/get_authenticated_user.py +2 -2
- cognee/modules/users/models/DatasetDatabase.py +15 -3
- cognee/shared/logging_utils.py +4 -0
- cognee/shared/rate_limiting.py +30 -0
- cognee/tasks/documents/__init__.py +0 -1
- cognee/tasks/graph/extract_graph_from_data.py +9 -10
- cognee/tasks/memify/get_triplet_datapoints.py +289 -0
- cognee/tasks/storage/add_data_points.py +142 -2
- cognee/tests/integration/retrieval/test_triplet_retriever.py +84 -0
- cognee/tests/integration/tasks/test_add_data_points.py +139 -0
- cognee/tests/integration/tasks/test_get_triplet_datapoints.py +69 -0
- cognee/tests/test_cognee_server_start.py +2 -4
- cognee/tests/test_conversation_history.py +23 -1
- cognee/tests/test_dataset_database_handler.py +137 -0
- cognee/tests/test_dataset_delete.py +76 -0
- cognee/tests/test_edge_centered_payload.py +170 -0
- cognee/tests/test_pipeline_cache.py +164 -0
- cognee/tests/test_search_db.py +37 -1
- cognee/tests/unit/api/test_ontology_endpoint.py +77 -89
- cognee/tests/unit/infrastructure/llm/test_llm_config.py +46 -0
- cognee/tests/unit/infrastructure/mock_embedding_engine.py +3 -7
- cognee/tests/unit/infrastructure/test_embedding_rate_limiting_realistic.py +0 -5
- cognee/tests/unit/modules/graph/cognee_graph_elements_test.py +2 -2
- cognee/tests/unit/modules/graph/cognee_graph_test.py +406 -0
- cognee/tests/unit/modules/memify_tasks/test_get_triplet_datapoints.py +214 -0
- cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py +608 -0
- cognee/tests/unit/modules/retrieval/triplet_retriever_test.py +83 -0
- cognee/tests/unit/modules/search/test_search.py +100 -0
- cognee/tests/unit/tasks/storage/test_add_data_points.py +288 -0
- {cognee-0.5.0.dev0.dist-info → cognee-0.5.1.dist-info}/METADATA +76 -89
- {cognee-0.5.0.dev0.dist-info → cognee-0.5.1.dist-info}/RECORD +119 -97
- {cognee-0.5.0.dev0.dist-info → cognee-0.5.1.dist-info}/WHEEL +1 -1
- cognee/api/v1/cognify/code_graph_pipeline.py +0 -119
- cognee/api/v1/cognify/routers/get_code_pipeline_router.py +0 -90
- cognee/infrastructure/databases/vector/embeddings/embedding_rate_limiter.py +0 -544
- cognee/modules/retrieval/code_retriever.py +0 -232
- cognee/tasks/code/enrich_dependency_graph_checker.py +0 -35
- cognee/tasks/code/get_local_dependencies_checker.py +0 -20
- cognee/tasks/code/get_repo_dependency_graph_checker.py +0 -35
- cognee/tasks/documents/check_permissions_on_dataset.py +0 -26
- cognee/tasks/repo_processor/__init__.py +0 -2
- cognee/tasks/repo_processor/get_local_dependencies.py +0 -335
- cognee/tasks/repo_processor/get_non_code_files.py +0 -158
- cognee/tasks/repo_processor/get_repo_file_dependencies.py +0 -243
- cognee/tests/test_delete_bmw_example.py +0 -60
- {cognee-0.5.0.dev0.dist-info → cognee-0.5.1.dist-info}/entry_points.txt +0 -0
- {cognee-0.5.0.dev0.dist-info → cognee-0.5.1.dist-info}/licenses/LICENSE +0 -0
- {cognee-0.5.0.dev0.dist-info → cognee-0.5.1.dist-info}/licenses/NOTICE.md +0 -0

cognee/infrastructure/databases/utils/get_vector_dataset_database_handler.py
ADDED

```diff
@@ -0,0 +1,10 @@
+from cognee.modules.users.models.DatasetDatabase import DatasetDatabase
+
+
+def get_vector_dataset_database_handler(dataset_database: DatasetDatabase) -> dict:
+    from cognee.infrastructure.databases.dataset_database_handler.supported_dataset_database_handlers import (
+        supported_dataset_database_handlers,
+    )
+
+    handler = supported_dataset_database_handlers[dataset_database.vector_dataset_database_handler]
+    return handler
```

cognee/infrastructure/databases/utils/resolve_dataset_database_connection_info.py
ADDED

```diff
@@ -0,0 +1,30 @@
+from cognee.infrastructure.databases.utils.get_graph_dataset_database_handler import (
+    get_graph_dataset_database_handler,
+)
+from cognee.infrastructure.databases.utils.get_vector_dataset_database_handler import (
+    get_vector_dataset_database_handler,
+)
+from cognee.modules.users.models.DatasetDatabase import DatasetDatabase
+
+
+async def resolve_dataset_database_connection_info(
+    dataset_database: DatasetDatabase,
+) -> DatasetDatabase:
+    """
+    Resolve the connection info for the given DatasetDatabase instance.
+    Resolve both vector and graph database connection info and return the updated DatasetDatabase instance.
+
+    Args:
+        dataset_database: DatasetDatabase instance
+    Returns:
+        DatasetDatabase instance with resolved connection info
+    """
+    vector_dataset_database_handler = get_vector_dataset_database_handler(dataset_database)
+    graph_dataset_database_handler = get_graph_dataset_database_handler(dataset_database)
+    dataset_database = await vector_dataset_database_handler[
+        "handler_instance"
+    ].resolve_dataset_connection_info(dataset_database)
+    dataset_database = await graph_dataset_database_handler[
+        "handler_instance"
+    ].resolve_dataset_connection_info(dataset_database)
+    return dataset_database
```

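Both helpers index into the handler registry added in supported_dataset_database_handlers.py (+18 lines, not expanded in this view). A minimal sketch of the registry shape the lookups above imply; only the `"handler_instance"` key and the `"lancedb"` name are confirmed by this diff, the rest is assumed:

```python
# Hypothetical sketch of supported_dataset_database_handlers; the real module
# is not expanded in this diff.
from cognee.infrastructure.databases.vector.lancedb.LanceDBDatasetDatabaseHandler import (
    LanceDBDatasetDatabaseHandler,
)

supported_dataset_database_handlers = {
    "lancedb": {
        "handler_instance": LanceDBDatasetDatabaseHandler,
        # Kuzu / Neo4j Aura entries would presumably follow the same shape
    },
}
```
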
cognee/infrastructure/databases/vector/config.py
CHANGED

```diff
@@ -28,6 +28,7 @@ class VectorConfig(BaseSettings):
     vector_db_name: str = ""
     vector_db_key: str = ""
     vector_db_provider: str = "lancedb"
+    vector_dataset_database_handler: str = "lancedb"
 
     model_config = SettingsConfigDict(env_file=".env", extra="allow")
 
@@ -63,6 +64,7 @@ class VectorConfig(BaseSettings):
             "vector_db_name": self.vector_db_name,
             "vector_db_key": self.vector_db_key,
             "vector_db_provider": self.vector_db_provider,
+            "vector_dataset_database_handler": self.vector_dataset_database_handler,
         }
 
 
```

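Since VectorConfig is a pydantic BaseSettings model, the new field should be settable from the environment (or `.env`) under its own name; a quick sketch, assuming standard pydantic-settings env mapping:

```python
import os

# Assumed usage: pydantic-settings maps the field name to an env var
# (case-insensitively), with the environment taking priority over .env.
os.environ["VECTOR_DATASET_DATABASE_HANDLER"] = "lancedb"

from cognee.infrastructure.databases.vector.config import VectorConfig

config = VectorConfig()
assert config.vector_dataset_database_handler == "lancedb"
```
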
cognee/infrastructure/databases/vector/embeddings/FastembedEmbeddingEngine.py
CHANGED

```diff
@@ -17,6 +17,7 @@ from cognee.infrastructure.databases.exceptions import EmbeddingException
 from cognee.infrastructure.llm.tokenizer.TikToken import (
     TikTokenTokenizer,
 )
+from cognee.shared.rate_limiting import embedding_rate_limiter_context_manager
 
 litellm.set_verbose = False
 logger = get_logger("FastembedEmbeddingEngine")
@@ -68,7 +69,7 @@ class FastembedEmbeddingEngine(EmbeddingEngine):
 
     @retry(
         stop=stop_after_delay(128),
-        wait=wait_exponential_jitter(
+        wait=wait_exponential_jitter(8, 128),
         retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
         before_sleep=before_sleep_log(logger, logging.DEBUG),
         reraise=True,
@@ -96,11 +97,12 @@ class FastembedEmbeddingEngine(EmbeddingEngine):
         if self.mock:
            return [[0.0] * self.dimensions for _ in text]
         else:
-
-
-
-
-
+            async with embedding_rate_limiter_context_manager():
+                embeddings = self.embedding_model.embed(
+                    text,
+                    batch_size=len(text),
+                    parallel=None,
+                )
 
         return list(embeddings)
 
```

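All three embedding engines now wrap the provider call in `embedding_rate_limiter_context_manager` from the new `cognee/shared/rate_limiting.py` (+30 lines, not expanded in this view), replacing the deleted `embedding_rate_limiter.py` decorators. Its internals are not shown here; a minimal sketch of the shape such an async context manager could take, assuming a simple semaphore-based concurrency cap:

```python
import asyncio
from contextlib import asynccontextmanager

# Hypothetical stand-in for cognee.shared.rate_limiting (internals not shown
# in this diff); assumes a fixed cap on in-flight embedding calls.
_embedding_semaphore = asyncio.Semaphore(8)

@asynccontextmanager
async def embedding_rate_limiter_context_manager():
    async with _embedding_semaphore:
        yield
```
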
cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py
CHANGED

```diff
@@ -25,6 +25,7 @@ from cognee.infrastructure.llm.tokenizer.Mistral import (
 from cognee.infrastructure.llm.tokenizer.TikToken import (
     TikTokenTokenizer,
 )
+from cognee.shared.rate_limiting import embedding_rate_limiter_context_manager
 
 litellm.set_verbose = False
 logger = get_logger("LiteLLMEmbeddingEngine")
@@ -109,13 +110,14 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
             response = {"data": [{"embedding": [0.0] * self.dimensions} for _ in text]}
             return [data["embedding"] for data in response["data"]]
         else:
-
-
-
-
-
-
-
+            async with embedding_rate_limiter_context_manager():
+                response = await litellm.aembedding(
+                    model=self.model,
+                    input=text,
+                    api_key=self.api_key,
+                    api_base=self.endpoint,
+                    api_version=self.api_version,
+                )
 
         return [data["embedding"] for data in response.data]
 
```

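For reference, `litellm.aembedding` is litellm's async embedding entry point, and `api_key`/`api_base`/`api_version` are standard per-call overrides (the latter two matter mostly for Azure-style endpoints). A self-contained usage sketch with placeholder model and credentials:

```python
import asyncio
import litellm

async def embed(texts: list[str]) -> list[list[float]]:
    # Placeholder model id and key; litellm routes by the model prefix.
    response = await litellm.aembedding(
        model="openai/text-embedding-3-small",
        input=texts,
        api_key="sk-example",
    )
    return [item["embedding"] for item in response.data]

vectors = asyncio.run(embed(["hello", "world"]))
```
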
cognee/infrastructure/databases/vector/embeddings/OllamaEmbeddingEngine.py
CHANGED

```diff
@@ -18,10 +18,7 @@ from cognee.infrastructure.databases.vector.embeddings.EmbeddingEngine import Em
 from cognee.infrastructure.llm.tokenizer.HuggingFace import (
     HuggingFaceTokenizer,
 )
-from cognee.
-    embedding_rate_limit_async,
-    embedding_sleep_and_retry_async,
-)
+from cognee.shared.rate_limiting import embedding_rate_limiter_context_manager
 from cognee.shared.utils import create_secure_ssl_context
 
 logger = get_logger("OllamaEmbeddingEngine")
@@ -101,7 +98,7 @@ class OllamaEmbeddingEngine(EmbeddingEngine):
 
     @retry(
         stop=stop_after_delay(128),
-        wait=wait_exponential_jitter(
+        wait=wait_exponential_jitter(8, 128),
         retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
         before_sleep=before_sleep_log(logger, logging.DEBUG),
         reraise=True,
@@ -120,11 +117,15 @@ class OllamaEmbeddingEngine(EmbeddingEngine):
         ssl_context = create_secure_ssl_context()
         connector = aiohttp.TCPConnector(ssl=ssl_context)
         async with aiohttp.ClientSession(connector=connector) as session:
-            async with
-
-
-
-
+            async with embedding_rate_limiter_context_manager():
+                async with session.post(
+                    self.endpoint, json=payload, headers=headers, timeout=60.0
+                ) as response:
+                    data = await response.json()
+                    if "embeddings" in data:
+                        return data["embeddings"][0]
+                    else:
+                        return data["data"][0]["embedding"]
 
     def get_vector_size(self) -> int:
         """
```

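The recurring `wait=wait_exponential_jitter(8, 128)` change (here, in FastembedEmbeddingEngine, and in the Anthropic adapter below) fills in tenacity's positional `initial` and `max` arguments: jittered exponential backoff starting around 8 s and capped at 128 s per wait, with `stop_after_delay(128)` bounding the total retry window. A standalone illustration of the pattern:

```python
import logging
from tenacity import (
    retry,
    stop_after_delay,
    wait_exponential_jitter,
    before_sleep_log,
)

logger = logging.getLogger(__name__)

# wait_exponential_jitter(8, 128): initial ~8s, exponential growth plus
# jitter, capped at 128s per wait; stop_after_delay(128) bounds elapsed time.
@retry(
    stop=stop_after_delay(128),
    wait=wait_exponential_jitter(8, 128),
    before_sleep=before_sleep_log(logger, logging.DEBUG),
    reraise=True,
)
async def flaky_call():
    ...
```
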
cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py
CHANGED

```diff
@@ -193,6 +193,8 @@ class LanceDBAdapter(VectorDBInterface):
             for (data_point_index, data_point) in enumerate(data_points)
         ]
 
+        lance_data_points = list({dp.id: dp for dp in lance_data_points}.values())
+
         async with self.VECTOR_DB_LOCK:
             await (
                 collection.merge_insert("id")
```

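The new line de-duplicates data points by `id` before `merge_insert("id")`, so the upsert never sees two rows with the same key. Dict semantics keep the first-seen key order but the last-seen value; in isolation:

```python
from dataclasses import dataclass

@dataclass
class Point:
    id: int
    payload: str

points = [Point(1, "a"), Point(2, "b"), Point(1, "c")]
unique = list({p.id: p for p in points}.values())
# Later duplicates win; key order is first-insertion order.
assert [(p.id, p.payload) for p in unique] == [(1, "c"), (2, "b")]
```
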
cognee/infrastructure/databases/vector/lancedb/LanceDBDatasetDatabaseHandler.py
ADDED

```diff
@@ -0,0 +1,50 @@
+import os
+from uuid import UUID
+from typing import Optional
+
+from cognee.infrastructure.databases.vector.create_vector_engine import create_vector_engine
+from cognee.modules.users.models import User
+from cognee.modules.users.models import DatasetDatabase
+from cognee.base_config import get_base_config
+from cognee.infrastructure.databases.vector import get_vectordb_config
+from cognee.infrastructure.databases.dataset_database_handler import DatasetDatabaseHandlerInterface
+
+
+class LanceDBDatasetDatabaseHandler(DatasetDatabaseHandlerInterface):
+    """
+    Handler for interacting with LanceDB Dataset databases.
+    """
+
+    @classmethod
+    async def create_dataset(cls, dataset_id: Optional[UUID], user: Optional[User]) -> dict:
+        vector_config = get_vectordb_config()
+        base_config = get_base_config()
+
+        if vector_config.vector_db_provider != "lancedb":
+            raise ValueError(
+                "LanceDBDatasetDatabaseHandler can only be used with LanceDB vector database provider."
+            )
+
+        databases_directory_path = os.path.join(
+            base_config.system_root_directory, "databases", str(user.id)
+        )
+
+        vector_db_name = f"{dataset_id}.lance.db"
+
+        return {
+            "vector_database_provider": vector_config.vector_db_provider,
+            "vector_database_url": os.path.join(databases_directory_path, vector_db_name),
+            "vector_database_key": vector_config.vector_db_key,
+            "vector_database_name": vector_db_name,
+            "vector_dataset_database_handler": "lancedb",
+        }
+
+    @classmethod
+    async def delete_dataset(cls, dataset_database: DatasetDatabase):
+        vector_engine = create_vector_engine(
+            vector_db_provider=dataset_database.vector_database_provider,
+            vector_db_url=dataset_database.vector_database_url,
+            vector_db_key=dataset_database.vector_database_key,
+            vector_db_name=dataset_database.vector_database_name,
+        )
+        await vector_engine.prune()
```

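For a concrete sense of the layout this handler produces: each dataset gets its own LanceDB database file under a per-user directory. A sketch with illustrative values (the real `system_root_directory` comes from cognee's base config):

```python
import os
from uuid import uuid4

# Illustrative values only.
system_root_directory = "/home/alice/.cognee_system"
user_id = uuid4()
dataset_id = uuid4()

vector_db_name = f"{dataset_id}.lance.db"
vector_database_url = os.path.join(
    system_root_directory, "databases", str(user_id), vector_db_name
)
# e.g. /home/alice/.cognee_system/databases/<user_id>/<dataset_id>.lance.db
```
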
cognee/infrastructure/databases/vector/vector_db_interface.py
CHANGED

```diff
@@ -2,6 +2,8 @@ from typing import List, Protocol, Optional, Union, Any
 from abc import abstractmethod
 from cognee.infrastructure.engine import DataPoint
 from .models.PayloadSchema import PayloadSchema
+from uuid import UUID
+from cognee.modules.users.models import User
 
 
 class VectorDBInterface(Protocol):
@@ -217,3 +219,36 @@ class VectorDBInterface(Protocol):
         - Any: The schema object suitable for this vector database
         """
         return model_type
+
+    @classmethod
+    async def create_dataset(cls, dataset_id: Optional[UUID], user: Optional[User]) -> dict:
+        """
+        Return a dictionary with connection info for a vector database for the given dataset.
+        Function can auto handle deploying of the actual database if needed, but is not necessary.
+        Only providing connection info is sufficient, this info will be mapped when trying to connect to the provided dataset in the future.
+        Needed for Cognee multi-tenant/multi-user and backend access control support.
+
+        Dictionary returned from this function will be used to create a DatasetDatabase row in the relational database.
+        From which internal mapping of dataset -> database connection info will be done.
+
+        Each dataset needs to map to a unique vector database when backend access control is enabled to facilitate a separation of concern for data.
+
+        Args:
+            dataset_id: UUID of the dataset if needed by the database creation logic
+            user: User object if needed by the database creation logic
+        Returns:
+            dict: Connection info for the created vector database instance.
+        """
+        pass
+
+    async def delete_dataset(self, dataset_id: UUID, user: User) -> None:
+        """
+        Delete the vector database for the given dataset.
+        Function should auto handle deleting of the actual database or send a request to the proper service to delete the database.
+        Needed for maintaining a database for Cognee multi-tenant/multi-user and backend access control.
+
+        Args:
+            dataset_id: UUID of the dataset
+            user: User object
+        """
+        pass
```

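An adapter opting into per-dataset isolation would implement both hooks. A skeletal sketch against the protocol above; the class, URL, and returned values are illustrative, and the dict keys mirror those the LanceDB handler returns:

```python
from typing import Optional
from uuid import UUID

from cognee.modules.users.models import User

# Illustrative skeleton; a real adapter also implements the rest of
# VectorDBInterface (collections, search, ...).
class MyVectorAdapter:
    @classmethod
    async def create_dataset(cls, dataset_id: Optional[UUID], user: Optional[User]) -> dict:
        # Provision (or merely name) a database for this dataset; the returned
        # connection info becomes a DatasetDatabase row in the relational DB.
        return {
            "vector_database_provider": "my_provider",
            "vector_database_url": f"https://vectors.example.com/{dataset_id}",
            "vector_database_key": "example-key",
            "vector_database_name": f"dataset_{dataset_id}",
        }

    async def delete_dataset(self, dataset_id: UUID, user: User) -> None:
        ...  # tear down the backing database for this dataset
```
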
cognee/infrastructure/files/storage/s3_config.py
CHANGED

```diff
@@ -9,6 +9,8 @@ class S3Config(BaseSettings):
     aws_access_key_id: Optional[str] = None
     aws_secret_access_key: Optional[str] = None
     aws_session_token: Optional[str] = None
+    aws_profile_name: Optional[str] = None
+    aws_bedrock_runtime_endpoint: Optional[str] = None
     model_config = SettingsConfigDict(env_file=".env", extra="allow")
 
 
```

cognee/infrastructure/llm/LLMGateway.py
CHANGED

```diff
@@ -11,7 +11,7 @@ class LLMGateway:
 
     @staticmethod
     def acreate_structured_output(
-        text_input: str, system_prompt: str, response_model: Type[BaseModel]
+        text_input: str, system_prompt: str, response_model: Type[BaseModel], **kwargs
     ) -> Coroutine:
         llm_config = get_llm_config()
         if llm_config.structured_output_framework.upper() == "BAML":
@@ -31,7 +31,10 @@ class LLMGateway:
 
         llm_client = get_llm_client()
         return llm_client.acreate_structured_output(
-            text_input=text_input,
+            text_input=text_input,
+            system_prompt=system_prompt,
+            response_model=response_model,
+            **kwargs,
         )
 
     @staticmethod
```

cognee/infrastructure/llm/config.py
CHANGED

```diff
@@ -74,6 +74,41 @@ class LLMConfig(BaseSettings):
 
     model_config = SettingsConfigDict(env_file=".env", extra="allow")
 
+    @model_validator(mode="after")
+    def strip_quotes_from_strings(self) -> "LLMConfig":
+        """
+        Strip surrounding quotes from specific string fields that often come from
+        environment variables with extra quotes (e.g., via Docker's --env-file).
+
+        Only applies to known config keys where quotes are invalid or cause issues.
+        """
+        string_fields_to_strip = [
+            "llm_api_key",
+            "llm_endpoint",
+            "llm_api_version",
+            "baml_llm_api_key",
+            "baml_llm_endpoint",
+            "baml_llm_api_version",
+            "fallback_api_key",
+            "fallback_endpoint",
+            "fallback_model",
+            "llm_provider",
+            "llm_model",
+            "baml_llm_provider",
+            "baml_llm_model",
+        ]
+
+        cls = self.__class__
+        for field_name in string_fields_to_strip:
+            if field_name not in cls.model_fields:
+                continue
+            value = getattr(self, field_name, None)
+            if isinstance(value, str) and len(value) >= 2:
+                if value[0] == value[-1] and value[0] in ("'", '"'):
+                    setattr(self, field_name, value[1:-1])
+
+        return self
+
     def model_post_init(self, __context) -> None:
         """Initialize the BAML registry after the model is created."""
         # Check if BAML is selected as structured output framework but not available
```

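The validator's effect in isolation: matching surrounding quotes are stripped, anything else passes through unchanged. A sketch mirroring the core check on a single value:

```python
def strip_matching_quotes(value: str) -> str:
    # Mirrors the validator's per-field check.
    if len(value) >= 2 and value[0] == value[-1] and value[0] in ("'", '"'):
        return value[1:-1]
    return value

assert strip_matching_quotes('"sk-test"') == "sk-test"
assert strip_matching_quotes("'gpt-4o'") == "gpt-4o"
assert strip_matching_quotes("sk-test") == "sk-test"        # unquoted, untouched
assert strip_matching_quotes("\"sk-test'") == "\"sk-test'"  # mismatched, untouched
```
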
cognee/infrastructure/llm/extraction/knowledge_graph/extract_content_graph.py
CHANGED

```diff
@@ -10,7 +10,7 @@ from cognee.infrastructure.llm.config import (
 
 
 async def extract_content_graph(
-    content: str, response_model: Type[BaseModel], custom_prompt: Optional[str] = None
+    content: str, response_model: Type[BaseModel], custom_prompt: Optional[str] = None, **kwargs
 ):
     if custom_prompt:
         system_prompt = custom_prompt
@@ -30,7 +30,7 @@ async def extract_content_graph(
     system_prompt = render_prompt(prompt_path, {}, base_directory=base_directory)
 
     content_graph = await LLMGateway.acreate_structured_output(
-        content, system_prompt, response_model
+        content, system_prompt, response_model, **kwargs
     )
 
     return content_graph
```

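With `**kwargs` threaded through here and in LLMGateway, callers can now pass extra options down to the selected adapter's `acreate_structured_output`. A hedged sketch; the `temperature` keyword is hypothetical, standing in for whatever option the configured adapter accepts (adapters that only gained `**kwargs` may simply ignore it):

```python
import asyncio
from pydantic import BaseModel

from cognee.infrastructure.llm.extraction.knowledge_graph.extract_content_graph import (
    extract_content_graph,
)

class MiniGraph(BaseModel):
    nodes: list[str]

async def main():
    # The extra keyword is forwarded verbatim through LLMGateway to the
    # selected adapter (hypothetical option).
    graph = await extract_content_graph(
        "Alice works at Acme.",
        MiniGraph,
        temperature=0.0,
    )
    print(graph)

asyncio.run(main())
```
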
cognee/infrastructure/llm/structured_output_framework/baml/baml_src/extraction/acreate_structured_output.py
CHANGED

```diff
@@ -1,7 +1,15 @@
 import asyncio
 from typing import Type
-from
+from pydantic import BaseModel
+from tenacity import (
+    retry,
+    stop_after_delay,
+    wait_exponential_jitter,
+    retry_if_not_exception_type,
+    before_sleep_log,
+)
 
+from cognee.shared.logging_utils import get_logger
 from cognee.infrastructure.llm.config import get_llm_config
 from cognee.infrastructure.llm.structured_output_framework.baml.baml_src.extraction.create_dynamic_baml_type import (
     create_dynamic_baml_type,
@@ -10,12 +18,18 @@ from cognee.infrastructure.llm.structured_output_framework.baml.baml_client.type
     TypeBuilder,
 )
 from cognee.infrastructure.llm.structured_output_framework.baml.baml_client import b
-from
-
+from cognee.shared.rate_limiting import llm_rate_limiter_context_manager
+import logging
 
 logger = get_logger()
 
 
+@retry(
+    stop=stop_after_delay(128),
+    wait=wait_exponential_jitter(8, 128),
+    before_sleep=before_sleep_log(logger, logging.DEBUG),
+    reraise=True,
+)
 async def acreate_structured_output(
     text_input: str, system_prompt: str, response_model: Type[BaseModel]
 ):
@@ -45,11 +59,12 @@ async def acreate_structured_output(
     tb = TypeBuilder()
     type_builder = create_dynamic_baml_type(tb, tb.ResponseModel, response_model)
 
-
-
-
-
-
+    async with llm_rate_limiter_context_manager():
+        result = await b.AcreateStructuredOutput(
+            text_input=text_input,
+            system_prompt=system_prompt,
+            baml_options={"client_registry": config.baml_registry, "tb": type_builder},
+        )
 
     # Transform BAML response to proper pydantic reponse model
     if response_model is str:
```

cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py
CHANGED

```diff
@@ -15,6 +15,7 @@ from tenacity import (
 from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import (
     LLMInterface,
 )
+from cognee.shared.rate_limiting import llm_rate_limiter_context_manager
 from cognee.infrastructure.llm.config import get_llm_config
 
 logger = get_logger()
@@ -45,13 +46,13 @@ class AnthropicAdapter(LLMInterface):
 
     @retry(
         stop=stop_after_delay(128),
-        wait=wait_exponential_jitter(
+        wait=wait_exponential_jitter(8, 128),
         retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
         before_sleep=before_sleep_log(logger, logging.DEBUG),
         reraise=True,
     )
     async def acreate_structured_output(
-        self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
+        self, text_input: str, system_prompt: str, response_model: Type[BaseModel], **kwargs
     ) -> BaseModel:
         """
         Generate a response from a user query.
@@ -69,17 +70,17 @@ class AnthropicAdapter(LLMInterface):
 
         - BaseModel: An instance of BaseModel containing the structured response.
         """
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        async with llm_rate_limiter_context_manager():
+            return await self.aclient(
+                model=self.model,
+                max_tokens=4096,
+                max_retries=2,
+                messages=[
+                    {
+                        "role": "user",
+                        "content": f"""Use the given format to extract information
+from the following input: {text_input}. {system_prompt}""",
+                    }
+                ],
+                response_model=response_model,
+            )
```

cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/bedrock/adapter.py
ADDED

```diff
@@ -0,0 +1,153 @@
+import litellm
+import instructor
+from typing import Type
+from pydantic import BaseModel
+from litellm.exceptions import ContentPolicyViolationError
+from instructor.exceptions import InstructorRetryException
+
+from cognee.infrastructure.llm.LLMGateway import LLMGateway
+from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import (
+    LLMInterface,
+)
+from cognee.infrastructure.llm.exceptions import (
+    ContentPolicyFilterError,
+    MissingSystemPromptPathError,
+)
+from cognee.infrastructure.files.storage.s3_config import get_s3_config
+from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.rate_limiter import (
+    rate_limit_async,
+    rate_limit_sync,
+    sleep_and_retry_async,
+    sleep_and_retry_sync,
+)
+from cognee.modules.observability.get_observe import get_observe
+
+observe = get_observe()
+
+
+class BedrockAdapter(LLMInterface):
+    """
+    Adapter for AWS Bedrock API with support for three authentication methods:
+    1. API Key (Bearer Token)
+    2. AWS Credentials (access key + secret key)
+    3. AWS Profile (boto3 credential chain)
+    """
+
+    name = "Bedrock"
+    model: str
+    api_key: str
+    default_instructor_mode = "json_schema_mode"
+
+    MAX_RETRIES = 5
+
+    def __init__(
+        self,
+        model: str,
+        api_key: str = None,
+        max_completion_tokens: int = 16384,
+        streaming: bool = False,
+        instructor_mode: str = None,
+    ):
+        self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode
+
+        self.aclient = instructor.from_litellm(
+            litellm.acompletion, mode=instructor.Mode(self.instructor_mode)
+        )
+        self.client = instructor.from_litellm(litellm.completion)
+        self.model = model
+        self.api_key = api_key
+        self.max_completion_tokens = max_completion_tokens
+        self.streaming = streaming
+
+    def _create_bedrock_request(
+        self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
+    ) -> dict:
+        """Create Bedrock request with authentication."""
+
+        request_params = {
+            "model": self.model,
+            "custom_llm_provider": "bedrock",
+            "drop_params": True,
+            "messages": [
+                {"role": "user", "content": text_input},
+                {"role": "system", "content": system_prompt},
+            ],
+            "response_model": response_model,
+            "max_retries": self.MAX_RETRIES,
+            "max_completion_tokens": self.max_completion_tokens,
+            "stream": self.streaming,
+        }
+
+        s3_config = get_s3_config()
+
+        # Add authentication parameters
+        if self.api_key:
+            request_params["api_key"] = self.api_key
+        elif s3_config.aws_access_key_id and s3_config.aws_secret_access_key:
+            request_params["aws_access_key_id"] = s3_config.aws_access_key_id
+            request_params["aws_secret_access_key"] = s3_config.aws_secret_access_key
+            if s3_config.aws_session_token:
+                request_params["aws_session_token"] = s3_config.aws_session_token
+        elif s3_config.aws_profile_name:
+            request_params["aws_profile_name"] = s3_config.aws_profile_name
+
+        if s3_config.aws_region:
+            request_params["aws_region_name"] = s3_config.aws_region
+
+        # Add optional parameters
+        if s3_config.aws_bedrock_runtime_endpoint:
+            request_params["aws_bedrock_runtime_endpoint"] = s3_config.aws_bedrock_runtime_endpoint
+
+        return request_params
+
+    @observe(as_type="generation")
+    @sleep_and_retry_async()
+    @rate_limit_async
+    async def acreate_structured_output(
+        self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
+    ) -> BaseModel:
+        """Generate structured output from AWS Bedrock API."""
+
+        try:
+            request_params = self._create_bedrock_request(text_input, system_prompt, response_model)
+            return await self.aclient.chat.completions.create(**request_params)
+
+        except (
+            ContentPolicyViolationError,
+            InstructorRetryException,
+        ) as error:
+            if (
+                isinstance(error, InstructorRetryException)
+                and "content management policy" not in str(error).lower()
+            ):
+                raise error
+
+            raise ContentPolicyFilterError(
+                f"The provided input contains content that is not aligned with our content policy: {text_input}"
+            )
+
+    @observe
+    @sleep_and_retry_sync()
+    @rate_limit_sync
+    def create_structured_output(
+        self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
+    ) -> BaseModel:
+        """Generate structured output from AWS Bedrock API (synchronous)."""
+
+        request_params = self._create_bedrock_request(text_input, system_prompt, response_model)
+        return self.client.chat.completions.create(**request_params)
+
+    def show_prompt(self, text_input: str, system_prompt: str) -> str:
+        """Format and display the prompt for a user query."""
+        if not text_input:
+            text_input = "No user input provided."
+        if not system_prompt:
+            raise MissingSystemPromptPathError()
+        system_prompt = LLMGateway.read_query_prompt(system_prompt)
+
+        formatted_prompt = (
+            f"""System Prompt:\n{system_prompt}\n\nUser Input:\n{text_input}\n"""
+            if system_prompt
+            else None
+        )
+        return formatted_prompt
```

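Putting the new adapter together: authentication resolves in precedence order, explicit `api_key` first, then AWS keys from S3Config, then `aws_profile_name` via the boto3 credential chain. A hedged usage sketch; the model id and input are placeholders, and selection through the +19-line change to get_llm_client.py is not expanded in this view:

```python
import asyncio
from pydantic import BaseModel

from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.bedrock.adapter import (
    BedrockAdapter,
)

class Entities(BaseModel):
    names: list[str]

async def main():
    # Placeholder Bedrock model id; with no api_key, the adapter falls back
    # to AWS credentials or a profile from S3Config (_create_bedrock_request).
    adapter = BedrockAdapter(model="anthropic.claude-3-5-sonnet-20240620-v1:0")
    result = await adapter.acreate_structured_output(
        text_input="Alice works at Acme.",
        system_prompt="Extract the entity names.",
        response_model=Entities,
    )
    print(result.names)

asyncio.run(main())
```
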