cognee 0.3.4.dev4__py3-none-any.whl → 0.3.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cognee/api/client.py +16 -7
- cognee/api/health.py +5 -9
- cognee/api/v1/add/add.py +3 -1
- cognee/api/v1/cognify/cognify.py +44 -7
- cognee/api/v1/permissions/routers/get_permissions_router.py +8 -4
- cognee/api/v1/search/search.py +3 -0
- cognee/api/v1/ui/__init__.py +1 -1
- cognee/api/v1/ui/ui.py +215 -150
- cognee/api/v1/update/__init__.py +1 -0
- cognee/api/v1/update/routers/__init__.py +1 -0
- cognee/api/v1/update/routers/get_update_router.py +90 -0
- cognee/api/v1/update/update.py +100 -0
- cognee/base_config.py +5 -2
- cognee/cli/_cognee.py +28 -10
- cognee/cli/commands/delete_command.py +34 -2
- cognee/eval_framework/corpus_builder/task_getters/get_default_tasks_by_indices.py +2 -2
- cognee/eval_framework/evaluation/direct_llm_eval_adapter.py +3 -2
- cognee/eval_framework/modal_eval_dashboard.py +9 -1
- cognee/infrastructure/databases/graph/config.py +9 -9
- cognee/infrastructure/databases/graph/get_graph_engine.py +4 -21
- cognee/infrastructure/databases/graph/kuzu/adapter.py +60 -9
- cognee/infrastructure/databases/hybrid/neptune_analytics/NeptuneAnalyticsAdapter.py +3 -3
- cognee/infrastructure/databases/relational/config.py +4 -4
- cognee/infrastructure/databases/relational/create_relational_engine.py +11 -3
- cognee/infrastructure/databases/vector/chromadb/ChromaDBAdapter.py +7 -3
- cognee/infrastructure/databases/vector/config.py +7 -7
- cognee/infrastructure/databases/vector/create_vector_engine.py +7 -15
- cognee/infrastructure/databases/vector/embeddings/EmbeddingEngine.py +9 -0
- cognee/infrastructure/databases/vector/embeddings/FastembedEmbeddingEngine.py +11 -0
- cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py +19 -2
- cognee/infrastructure/databases/vector/embeddings/OllamaEmbeddingEngine.py +11 -0
- cognee/infrastructure/databases/vector/embeddings/config.py +8 -0
- cognee/infrastructure/databases/vector/embeddings/get_embedding_engine.py +5 -0
- cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py +11 -10
- cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py +48 -38
- cognee/infrastructure/databases/vector/vector_db_interface.py +8 -4
- cognee/infrastructure/files/storage/S3FileStorage.py +15 -5
- cognee/infrastructure/files/storage/s3_config.py +1 -0
- cognee/infrastructure/files/utils/open_data_file.py +7 -14
- cognee/infrastructure/llm/LLMGateway.py +19 -117
- cognee/infrastructure/llm/config.py +28 -13
- cognee/infrastructure/llm/{structured_output_framework/litellm_instructor/extraction → extraction}/extract_categories.py +2 -1
- cognee/infrastructure/llm/{structured_output_framework/litellm_instructor/extraction → extraction}/extract_event_entities.py +3 -2
- cognee/infrastructure/llm/{structured_output_framework/litellm_instructor/extraction → extraction}/extract_summary.py +3 -2
- cognee/infrastructure/llm/{structured_output_framework/litellm_instructor/extraction → extraction}/knowledge_graph/extract_content_graph.py +2 -1
- cognee/infrastructure/llm/{structured_output_framework/litellm_instructor/extraction → extraction}/knowledge_graph/extract_event_graph.py +3 -2
- cognee/infrastructure/llm/prompts/read_query_prompt.py +3 -2
- cognee/infrastructure/llm/prompts/show_prompt.py +35 -0
- cognee/infrastructure/llm/prompts/test.txt +1 -0
- cognee/infrastructure/llm/structured_output_framework/baml/baml_client/__init__.py +2 -2
- cognee/infrastructure/llm/structured_output_framework/baml/baml_client/async_client.py +50 -397
- cognee/infrastructure/llm/structured_output_framework/baml/baml_client/inlinedbaml.py +2 -3
- cognee/infrastructure/llm/structured_output_framework/baml/baml_client/parser.py +8 -88
- cognee/infrastructure/llm/structured_output_framework/baml/baml_client/runtime.py +78 -0
- cognee/infrastructure/llm/structured_output_framework/baml/baml_client/stream_types.py +2 -99
- cognee/infrastructure/llm/structured_output_framework/baml/baml_client/sync_client.py +49 -401
- cognee/infrastructure/llm/structured_output_framework/baml/baml_client/type_builder.py +19 -882
- cognee/infrastructure/llm/structured_output_framework/baml/baml_client/type_map.py +2 -34
- cognee/infrastructure/llm/structured_output_framework/baml/baml_client/types.py +2 -107
- cognee/infrastructure/llm/structured_output_framework/baml/baml_src/acreate_structured_output.baml +26 -0
- cognee/infrastructure/llm/structured_output_framework/baml/baml_src/extraction/__init__.py +1 -2
- cognee/infrastructure/llm/structured_output_framework/baml/baml_src/extraction/acreate_structured_output.py +76 -0
- cognee/infrastructure/llm/structured_output_framework/baml/baml_src/extraction/create_dynamic_baml_type.py +122 -0
- cognee/infrastructure/llm/structured_output_framework/baml/baml_src/generators.baml +3 -3
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py +0 -32
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py +107 -98
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py +5 -6
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py +5 -6
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/llm_interface.py +0 -26
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py +17 -67
- cognee/infrastructure/llm/tokenizer/Gemini/adapter.py +8 -7
- cognee/infrastructure/llm/utils.py +4 -4
- cognee/infrastructure/loaders/LoaderEngine.py +5 -2
- cognee/infrastructure/loaders/external/__init__.py +7 -0
- cognee/infrastructure/loaders/external/advanced_pdf_loader.py +244 -0
- cognee/infrastructure/loaders/supported_loaders.py +7 -0
- cognee/modules/data/methods/create_authorized_dataset.py +9 -0
- cognee/modules/data/methods/get_authorized_dataset.py +1 -1
- cognee/modules/data/methods/get_authorized_dataset_by_name.py +11 -0
- cognee/modules/data/methods/get_deletion_counts.py +92 -0
- cognee/modules/graph/cognee_graph/CogneeGraph.py +1 -1
- cognee/modules/graph/utils/expand_with_nodes_and_edges.py +22 -8
- cognee/modules/graph/utils/retrieve_existing_edges.py +0 -2
- cognee/modules/ingestion/data_types/TextData.py +0 -1
- cognee/modules/observability/get_observe.py +14 -0
- cognee/modules/observability/observers.py +1 -0
- cognee/modules/ontology/base_ontology_resolver.py +42 -0
- cognee/modules/ontology/get_default_ontology_resolver.py +41 -0
- cognee/modules/ontology/matching_strategies.py +53 -0
- cognee/modules/ontology/models.py +20 -0
- cognee/modules/ontology/ontology_config.py +24 -0
- cognee/modules/ontology/ontology_env_config.py +45 -0
- cognee/modules/ontology/rdf_xml/{OntologyResolver.py → RDFLibOntologyResolver.py} +20 -28
- cognee/modules/pipelines/layers/resolve_authorized_user_dataset.py +21 -24
- cognee/modules/pipelines/layers/resolve_authorized_user_datasets.py +3 -3
- cognee/modules/retrieval/code_retriever.py +2 -1
- cognee/modules/retrieval/context_providers/TripletSearchContextProvider.py +1 -4
- cognee/modules/retrieval/graph_completion_cot_retriever.py +6 -5
- cognee/modules/retrieval/graph_completion_retriever.py +0 -3
- cognee/modules/retrieval/insights_retriever.py +1 -1
- cognee/modules/retrieval/jaccard_retrival.py +60 -0
- cognee/modules/retrieval/lexical_retriever.py +123 -0
- cognee/modules/retrieval/natural_language_retriever.py +2 -1
- cognee/modules/retrieval/temporal_retriever.py +3 -2
- cognee/modules/retrieval/utils/brute_force_triplet_search.py +2 -12
- cognee/modules/retrieval/utils/completion.py +4 -7
- cognee/modules/search/methods/get_search_type_tools.py +7 -0
- cognee/modules/search/methods/no_access_control_search.py +1 -1
- cognee/modules/search/methods/search.py +32 -13
- cognee/modules/search/types/SearchType.py +1 -0
- cognee/modules/users/permissions/methods/authorized_give_permission_on_datasets.py +12 -0
- cognee/modules/users/permissions/methods/check_permission_on_dataset.py +11 -0
- cognee/modules/users/permissions/methods/get_all_user_permission_datasets.py +10 -0
- cognee/modules/users/permissions/methods/get_document_ids_for_user.py +10 -0
- cognee/modules/users/permissions/methods/get_principal.py +9 -0
- cognee/modules/users/permissions/methods/get_principal_datasets.py +11 -0
- cognee/modules/users/permissions/methods/get_role.py +10 -0
- cognee/modules/users/permissions/methods/get_specific_user_permission_datasets.py +3 -3
- cognee/modules/users/permissions/methods/get_tenant.py +9 -0
- cognee/modules/users/permissions/methods/give_default_permission_to_role.py +9 -0
- cognee/modules/users/permissions/methods/give_default_permission_to_tenant.py +9 -0
- cognee/modules/users/permissions/methods/give_default_permission_to_user.py +9 -0
- cognee/modules/users/permissions/methods/give_permission_on_dataset.py +10 -0
- cognee/modules/users/roles/methods/add_user_to_role.py +11 -0
- cognee/modules/users/roles/methods/create_role.py +12 -1
- cognee/modules/users/tenants/methods/add_user_to_tenant.py +12 -0
- cognee/modules/users/tenants/methods/create_tenant.py +12 -1
- cognee/modules/visualization/cognee_network_visualization.py +13 -9
- cognee/shared/data_models.py +0 -1
- cognee/shared/utils.py +0 -32
- cognee/tasks/chunk_naive_llm_classifier/chunk_naive_llm_classifier.py +2 -2
- cognee/tasks/codingagents/coding_rule_associations.py +3 -2
- cognee/tasks/entity_completion/entity_extractors/llm_entity_extractor.py +3 -2
- cognee/tasks/graph/cascade_extract/utils/extract_content_nodes_and_relationship_names.py +3 -2
- cognee/tasks/graph/cascade_extract/utils/extract_edge_triplets.py +3 -2
- cognee/tasks/graph/cascade_extract/utils/extract_nodes.py +3 -2
- cognee/tasks/graph/extract_graph_from_code.py +2 -2
- cognee/tasks/graph/extract_graph_from_data.py +55 -12
- cognee/tasks/graph/extract_graph_from_data_v2.py +16 -4
- cognee/tasks/ingestion/migrate_relational_database.py +132 -41
- cognee/tasks/ingestion/resolve_data_directories.py +4 -1
- cognee/tasks/schema/ingest_database_schema.py +134 -0
- cognee/tasks/schema/models.py +40 -0
- cognee/tasks/storage/index_data_points.py +1 -1
- cognee/tasks/storage/index_graph_edges.py +3 -1
- cognee/tasks/summarization/summarize_code.py +2 -2
- cognee/tasks/summarization/summarize_text.py +2 -2
- cognee/tasks/temporal_graph/enrich_events.py +2 -2
- cognee/tasks/temporal_graph/extract_events_and_entities.py +2 -2
- cognee/tests/cli_tests/cli_unit_tests/test_cli_commands.py +13 -4
- cognee/tests/cli_tests/cli_unit_tests/test_cli_edge_cases.py +13 -3
- cognee/tests/test_advanced_pdf_loader.py +141 -0
- cognee/tests/test_chromadb.py +40 -0
- cognee/tests/test_cognee_server_start.py +6 -1
- cognee/tests/test_data/Quantum_computers.txt +9 -0
- cognee/tests/test_lancedb.py +211 -0
- cognee/tests/test_pgvector.py +40 -0
- cognee/tests/test_relational_db_migration.py +76 -0
- cognee/tests/unit/infrastructure/databases/test_index_graph_edges.py +2 -1
- cognee/tests/unit/modules/ontology/test_ontology_adapter.py +330 -13
- cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py +0 -4
- cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py +0 -4
- cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py +0 -4
- {cognee-0.3.4.dev4.dist-info → cognee-0.3.5.dist-info}/METADATA +92 -96
- {cognee-0.3.4.dev4.dist-info → cognee-0.3.5.dist-info}/RECORD +173 -159
- distributed/pyproject.toml +0 -1
- cognee/infrastructure/data/utils/extract_keywords.py +0 -48
- cognee/infrastructure/databases/hybrid/falkordb/FalkorDBAdapter.py +0 -1227
- cognee/infrastructure/llm/structured_output_framework/baml/baml_src/extract_categories.baml +0 -109
- cognee/infrastructure/llm/structured_output_framework/baml/baml_src/extract_content_graph.baml +0 -343
- cognee/infrastructure/llm/structured_output_framework/baml/baml_src/extraction/extract_categories.py +0 -0
- cognee/infrastructure/llm/structured_output_framework/baml/baml_src/extraction/extract_summary.py +0 -89
- cognee/infrastructure/llm/structured_output_framework/baml/baml_src/extraction/knowledge_graph/__init__.py +0 -0
- cognee/infrastructure/llm/structured_output_framework/baml/baml_src/extraction/knowledge_graph/extract_content_graph.py +0 -44
- cognee/tasks/graph/infer_data_ontology.py +0 -309
- cognee/tests/test_falkordb.py +0 -174
- /cognee/infrastructure/llm/{structured_output_framework/litellm_instructor/extraction → extraction}/__init__.py +0 -0
- /cognee/infrastructure/llm/{structured_output_framework/litellm_instructor/extraction → extraction}/knowledge_graph/__init__.py +0 -0
- /cognee/infrastructure/llm/{structured_output_framework/litellm_instructor/extraction → extraction}/texts.json +0 -0
- {cognee-0.3.4.dev4.dist-info → cognee-0.3.5.dist-info}/WHEEL +0 -0
- {cognee-0.3.4.dev4.dist-info → cognee-0.3.5.dist-info}/entry_points.txt +0 -0
- {cognee-0.3.4.dev4.dist-info → cognee-0.3.5.dist-info}/licenses/LICENSE +0 -0
- {cognee-0.3.4.dev4.dist-info → cognee-0.3.5.dist-info}/licenses/NOTICE.md +0 -0
cognee/infrastructure/databases/vector/create_vector_engine.py

```diff
@@ -19,8 +19,7 @@ def create_vector_engine(
     for each provider, raising an EnvironmentError if any are missing, or ImportError if the
     ChromaDB package is not installed.

-    Supported providers include: pgvector,
-    LanceDB.
+    Supported providers include: pgvector, ChromaDB, and LanceDB.

     Parameters:
     -----------
@@ -66,7 +65,12 @@ def create_vector_engine(
             f"postgresql+asyncpg://{db_username}:{db_password}@{db_host}:{db_port}/{db_name}"
         )

-        from .pgvector.PGVectorAdapter import PGVectorAdapter
+        try:
+            from .pgvector.PGVectorAdapter import PGVectorAdapter
+        except ImportError:
+            raise ImportError(
+                "PostgreSQL dependencies are not installed. Please install with 'pip install cognee\"[postgres]\"' or 'pip install cognee\"[postgres-binary]\"' to use PGVector functionality."
+            )

         return PGVectorAdapter(
             connection_string,
@@ -74,18 +78,6 @@ def create_vector_engine(
             embedding_engine,
         )

-    elif vector_db_provider == "falkordb":
-        if not (vector_db_url and vector_db_port):
-            raise EnvironmentError("Missing requred FalkorDB credentials!")
-
-        from ..hybrid.falkordb.FalkorDBAdapter import FalkorDBAdapter
-
-        return FalkorDBAdapter(
-            database_url=vector_db_url,
-            database_port=vector_db_port,
-            embedding_engine=embedding_engine,
-        )
-
     elif vector_db_provider == "chromadb":
         try:
             import chromadb
```
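The pgvector branch now follows the same lazy-import pattern already used for ChromaDB: the adapter module is only imported once the provider is actually selected, and a missing optional extra surfaces as an ImportError with an install hint. A minimal standalone sketch of that pattern (the helper function below is illustrative, not cognee's API; only the module path and the extras names come from the hunk above):

```python
# Hypothetical helper (not part of cognee) showing the lazy-import-with-hint pattern
# used in the pgvector branch above.
def load_pgvector_adapter():
    try:
        # Deferred import keeps the Postgres dependencies out of the default install.
        from cognee.infrastructure.databases.vector.pgvector.PGVectorAdapter import PGVectorAdapter
    except ImportError as error:
        raise ImportError(
            "PostgreSQL dependencies are not installed. "
            "Install them with: pip install 'cognee[postgres]'"
        ) from error
    return PGVectorAdapter
```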
cognee/infrastructure/databases/vector/embeddings/EmbeddingEngine.py

```diff
@@ -34,3 +34,12 @@ class EmbeddingEngine(Protocol):
         - int: An integer representing the number of dimensions in the embedding vector.
         """
         raise NotImplementedError()
+
+    def get_batch_size(self) -> int:
+        """
+        Return the desired batch size for embedding calls
+
+        Returns:
+
+        """
+        raise NotImplementedError()
```
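With `get_batch_size()` added to the protocol, every engine implementation is expected to advertise how many texts it wants per embedding call. A toy conforming implementation, just to show the surface (the class and its zero vectors are made up for illustration; only the method names come from the protocol above):

```python
import asyncio
from typing import List


class DummyEmbeddingEngine:
    """Toy engine (not part of cognee): zero vectors, but honours the new protocol surface."""

    def __init__(self, dimensions: int = 8, batch_size: int = 100):
        self.dimensions = dimensions
        self.batch_size = batch_size

    async def embed_text(self, text: List[str]) -> List[List[float]]:
        # A real engine calls an embedding API here.
        return [[0.0] * self.dimensions for _ in text]

    def get_vector_size(self) -> int:
        return self.dimensions

    def get_batch_size(self) -> int:
        # Mirrors the concrete engines in this release: a configured, constant batch size.
        return self.batch_size


engine = DummyEmbeddingEngine(batch_size=16)
print(engine.get_batch_size())                             # 16
print(len(asyncio.run(engine.embed_text(["a", "b"]))[0]))  # 8
```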
cognee/infrastructure/databases/vector/embeddings/FastembedEmbeddingEngine.py

```diff
@@ -42,11 +42,13 @@ class FastembedEmbeddingEngine(EmbeddingEngine):
         model: Optional[str] = "openai/text-embedding-3-large",
         dimensions: Optional[int] = 3072,
         max_completion_tokens: int = 512,
+        batch_size: int = 100,
     ):
         self.model = model
         self.dimensions = dimensions
         self.max_completion_tokens = max_completion_tokens
         self.tokenizer = self.get_tokenizer()
+        self.batch_size = batch_size
         # self.retry_count = 0
         self.embedding_model = TextEmbedding(model_name=model)

@@ -101,6 +103,15 @@ class FastembedEmbeddingEngine(EmbeddingEngine):
         """
         return self.dimensions

+    def get_batch_size(self) -> int:
+        """
+        Return the desired batch size for embedding calls
+
+        Returns:
+
+        """
+        return self.batch_size
+
     def get_tokenizer(self):
         """
         Instantiate and return the tokenizer used for preparing text for embedding.
```
cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py

```diff
@@ -58,6 +58,7 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
         endpoint: str = None,
         api_version: str = None,
         max_completion_tokens: int = 512,
+        batch_size: int = 100,
     ):
         self.api_key = api_key
         self.endpoint = endpoint
@@ -68,6 +69,7 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
         self.max_completion_tokens = max_completion_tokens
         self.tokenizer = self.get_tokenizer()
         self.retry_count = 0
+        self.batch_size = batch_size

         enable_mocking = os.getenv("MOCK_EMBEDDING", "false")
         if isinstance(enable_mocking, bool):
@@ -165,6 +167,15 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
         """
         return self.dimensions

+    def get_batch_size(self) -> int:
+        """
+        Return the desired batch size for embedding calls
+
+        Returns:
+
+        """
+        return self.batch_size
+
     def get_tokenizer(self):
         """
         Load and return the appropriate tokenizer for the specified model based on the provider.
@@ -183,9 +194,15 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
                 model=model, max_completion_tokens=self.max_completion_tokens
             )
         elif "gemini" in self.provider.lower():
-            tokenizer = GeminiTokenizer(
-                model=model, max_completion_tokens=self.max_completion_tokens
+            # Since Gemini tokenization needs to send an API request to get the token count we will use TikToken to
+            # count tokens as we calculate tokens word by word
+            tokenizer = TikTokenTokenizer(
+                model=None, max_completion_tokens=self.max_completion_tokens
             )
+            # Note: Gemini Tokenizer expects an LLM model as input and not the embedding model
+            # tokenizer = GeminiTokenizer(
+            #     llm_model=llm_model, max_completion_tokens=self.max_completion_tokens
+            # )
         elif "mistral" in self.provider.lower():
             tokenizer = MistralTokenizer(
                 model=model, max_completion_tokens=self.max_completion_tokens
```
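The Gemini branch of `get_tokenizer` now falls back to a TikToken-based tokenizer, because Gemini's own token counting requires an API request per call (and its tokenizer expects an LLM model rather than the embedding model). As a standalone illustration of the local-counting idea, using the `tiktoken` library directly rather than cognee's `TikTokenTokenizer` wrapper:

```python
import tiktoken

# Local token counting: no network round-trip, unlike Gemini's count-tokens endpoint.
encoding = tiktoken.get_encoding("cl100k_base")


def count_tokens(text: str) -> int:
    return len(encoding.encode(text))


print(count_tokens("Quantum computers use qubits instead of bits."))
```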
cognee/infrastructure/databases/vector/embeddings/OllamaEmbeddingEngine.py

```diff
@@ -54,12 +54,14 @@ class OllamaEmbeddingEngine(EmbeddingEngine):
         max_completion_tokens: int = 512,
         endpoint: Optional[str] = "http://localhost:11434/api/embeddings",
         huggingface_tokenizer: str = "Salesforce/SFR-Embedding-Mistral",
+        batch_size: int = 100,
     ):
         self.model = model
         self.dimensions = dimensions
         self.max_completion_tokens = max_completion_tokens
         self.endpoint = endpoint
         self.huggingface_tokenizer_name = huggingface_tokenizer
+        self.batch_size = batch_size
         self.tokenizer = self.get_tokenizer()

         enable_mocking = os.getenv("MOCK_EMBEDDING", "false")
@@ -122,6 +124,15 @@ class OllamaEmbeddingEngine(EmbeddingEngine):
         """
         return self.dimensions

+    def get_batch_size(self) -> int:
+        """
+        Return the desired batch size for embedding calls
+
+        Returns:
+
+        """
+        return self.batch_size
+
     def get_tokenizer(self):
         """
         Load and return a HuggingFace tokenizer for the embedding engine.
```
cognee/infrastructure/databases/vector/embeddings/config.py

```diff
@@ -19,9 +19,17 @@ class EmbeddingConfig(BaseSettings):
     embedding_api_key: Optional[str] = None
     embedding_api_version: Optional[str] = None
     embedding_max_completion_tokens: Optional[int] = 8191
+    embedding_batch_size: Optional[int] = None
     huggingface_tokenizer: Optional[str] = None
     model_config = SettingsConfigDict(env_file=".env", extra="allow")

+    def model_post_init(self, __context) -> None:
+        # If embedding batch size is not defined use 2048 as default for OpenAI and 100 for all other embedding models
+        if not self.embedding_batch_size and self.embedding_provider.lower() == "openai":
+            self.embedding_batch_size = 2048
+        elif not self.embedding_batch_size:
+            self.embedding_batch_size = 100
+
     def to_dict(self) -> dict:
         """
         Serialize all embedding configuration settings to a dictionary.
```
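The `model_post_init` hook above gives `embedding_batch_size` a provider-dependent default when it is not set explicitly: 2048 for OpenAI, 100 for everything else. A small sketch of that resolution rule in isolation, assuming a pydantic-settings model shaped like the one in the hunk (field names are taken from the hunk; the default provider value and class name are illustrative):

```python
from typing import Optional

from pydantic_settings import BaseSettings, SettingsConfigDict


class EmbeddingConfigSketch(BaseSettings):
    embedding_provider: Optional[str] = "openai"     # illustrative default
    embedding_batch_size: Optional[int] = None

    model_config = SettingsConfigDict(env_file=".env", extra="allow")

    def model_post_init(self, __context) -> None:
        # Unset batch size: default to 2048 for OpenAI, 100 for all other providers.
        if not self.embedding_batch_size and self.embedding_provider.lower() == "openai":
            self.embedding_batch_size = 2048
        elif not self.embedding_batch_size:
            self.embedding_batch_size = 100


print(EmbeddingConfigSketch().embedding_batch_size)                             # 2048
print(EmbeddingConfigSketch(embedding_provider="ollama").embedding_batch_size)  # 100
print(EmbeddingConfigSketch(embedding_batch_size=32).embedding_batch_size)      # 32
```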
cognee/infrastructure/databases/vector/embeddings/get_embedding_engine.py

```diff
@@ -31,6 +31,7 @@ def get_embedding_engine() -> EmbeddingEngine:
         config.embedding_endpoint,
         config.embedding_api_key,
         config.embedding_api_version,
+        config.embedding_batch_size,
         config.huggingface_tokenizer,
         llm_config.llm_api_key,
         llm_config.llm_provider,
@@ -46,6 +47,7 @@ def create_embedding_engine(
    embedding_endpoint,
    embedding_api_key,
    embedding_api_version,
+    embedding_batch_size,
    huggingface_tokenizer,
    llm_api_key,
    llm_provider,
@@ -84,6 +86,7 @@ def create_embedding_engine(
            model=embedding_model,
            dimensions=embedding_dimensions,
            max_completion_tokens=embedding_max_completion_tokens,
+            batch_size=embedding_batch_size,
        )

    if embedding_provider == "ollama":
@@ -95,6 +98,7 @@ def create_embedding_engine(
            max_completion_tokens=embedding_max_completion_tokens,
            endpoint=embedding_endpoint,
            huggingface_tokenizer=huggingface_tokenizer,
+            batch_size=embedding_batch_size,
        )

    from .LiteLLMEmbeddingEngine import LiteLLMEmbeddingEngine
@@ -108,4 +112,5 @@ def create_embedding_engine(
        model=embedding_model,
        dimensions=embedding_dimensions,
        max_completion_tokens=embedding_max_completion_tokens,
+        batch_size=embedding_batch_size,
    )
```
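`create_embedding_engine` now threads `embedding_batch_size` from the config into every engine it builds, and engines expose it back through `get_batch_size()`. A caller can then chunk its embedding requests to the size the engine advertises; the loop below is an illustrative caller-side sketch, not code from the package:

```python
from typing import List


async def embed_in_batches(embedding_engine, texts: List[str]) -> List[List[float]]:
    """Chunk the inputs to the size the engine advertises before embedding."""
    batch_size = embedding_engine.get_batch_size()
    vectors: List[List[float]] = []
    for start in range(0, len(texts), batch_size):
        batch = texts[start : start + batch_size]
        # embed_text and get_batch_size are the protocol methods shown in the hunks above.
        vectors.extend(await embedding_engine.embed_text(batch))
    return vectors
```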
cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py

```diff
@@ -205,9 +205,12 @@ class LanceDBAdapter(VectorDBInterface):
         collection = await self.get_collection(collection_name)

         if len(data_point_ids) == 1:
-            results = await collection.query().where(f"id = '{data_point_ids[0]}'")
+            results = await collection.query().where(f"id = '{data_point_ids[0]}'")
         else:
-            results = await collection.query().where(f"id IN {tuple(data_point_ids)}")
+            results = await collection.query().where(f"id IN {tuple(data_point_ids)}")
+
+        # Convert query results to list format
+        results_list = results.to_list() if hasattr(results, "to_list") else list(results)

         return [
             ScoredResult(
@@ -215,7 +218,7 @@ class LanceDBAdapter(VectorDBInterface):
                 payload=result["payload"],
                 score=0,
             )
-            for result in
+            for result in results_list
         ]

     async def search(
@@ -223,7 +226,7 @@ class LanceDBAdapter(VectorDBInterface):
         collection_name: str,
         query_text: str = None,
         query_vector: List[float] = None,
-        limit: int = 15,
+        limit: Optional[int] = 15,
         with_vector: bool = False,
         normalized: bool = True,
     ):
@@ -235,16 +238,14 @@ class LanceDBAdapter(VectorDBInterface):

         collection = await self.get_collection(collection_name)

-        if limit
+        if limit is None:
             limit = await collection.count_rows()

         # LanceDB search will break if limit is 0 so we must return
-        if limit
+        if limit <= 0:
             return []

-
-
-        result_values = list(results.to_dict("index").values())
+        result_values = await collection.vector_search(query_vector).limit(limit).to_list()

         if not result_values:
             return []
@@ -264,7 +265,7 @@ class LanceDBAdapter(VectorDBInterface):
         self,
         collection_name: str,
         query_texts: List[str],
-        limit: int = None,
+        limit: Optional[int] = None,
         with_vectors: bool = False,
     ):
         query_vectors = await self.embedding_engine.embed_text(query_texts)
```
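The LanceDB changes make `limit` optional: `limit=None` now means "search the whole collection" (the adapter substitutes `count_rows()`), and a resolved limit of 0 or less short-circuits to an empty result because LanceDB search breaks on a zero limit. A hedged sketch of that argument normalisation on its own, with a stand-in collection so it runs without LanceDB (only the rules come from the hunk; the classes and function names are illustrative):

```python
import asyncio
from typing import Optional


class FakeCollection:
    """Stand-in for a LanceDB table, only to exercise the limit rules below."""

    def __init__(self, rows: int):
        self.rows = rows

    async def count_rows(self) -> int:
        return self.rows


async def resolve_limit(collection: FakeCollection, limit: Optional[int]) -> int:
    # limit=None now means "no cap": fall back to the collection's row count.
    if limit is None:
        limit = await collection.count_rows()
    return limit


async def main() -> None:
    collection = FakeCollection(rows=42)
    print(await resolve_limit(collection, None))  # 42 -- search the whole collection
    print(await resolve_limit(collection, 15))    # 15 -- explicit cap kept
    # A resolved limit <= 0 returns [] before ever calling LanceDB,
    # since LanceDB search breaks on limit == 0.


asyncio.run(main())
```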
cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py

```diff
@@ -3,13 +3,12 @@ from typing import List, Optional, get_type_hints
 from sqlalchemy.inspection import inspect
 from sqlalchemy.orm import Mapped, mapped_column
 from sqlalchemy.dialects.postgresql import insert
-from sqlalchemy import JSON, Column, Table, select, delete, MetaData
+from sqlalchemy import JSON, Column, Table, select, delete, MetaData, func
 from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
 from sqlalchemy.exc import ProgrammingError
 from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential
 from asyncpg import DeadlockDetectedError, DuplicateTableError, UniqueViolationError

-
 from cognee.shared.logging_utils import get_logger
 from cognee.infrastructure.engine import DataPoint
 from cognee.infrastructure.engine.utils import parse_id
@@ -126,41 +125,42 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
         data_point_types = get_type_hints(DataPoint)
         vector_size = self.embedding_engine.get_vector_size()

-        # [35 removed lines: the previous create_collection body is not preserved in this diff view]
+        if not await self.has_collection(collection_name):
+            async with self.VECTOR_DB_LOCK:
+                if not await self.has_collection(collection_name):
+
+                    class PGVectorDataPoint(Base):
+                        """
+                        Represent a point in a vector data space with associated data and vector representation.
+
+                        This class inherits from Base and is associated with a database table defined by
+                        __tablename__. It maintains the following public methods and instance variables:
+
+                        - __init__(self, id, payload, vector): Initializes a new PGVectorDataPoint instance.
+
+                        Instance variables:
+                        - id: Identifier for the data point, defined by data_point_types.
+                        - payload: JSON data associated with the data point.
+                        - vector: Vector representation of the data point, with size defined by vector_size.
+                        """
+
+                        __tablename__ = collection_name
+                        __table_args__ = {"extend_existing": True}
+                        # PGVector requires one column to be the primary key
+                        id: Mapped[data_point_types["id"]] = mapped_column(primary_key=True)
+                        payload = Column(JSON)
+                        vector = Column(self.Vector(vector_size))
+
+                        def __init__(self, id, payload, vector):
+                            self.id = id
+                            self.payload = payload
+                            self.vector = vector
+
+                    async with self.engine.begin() as connection:
+                        if len(Base.metadata.tables.keys()) > 0:
+                            await connection.run_sync(
+                                Base.metadata.create_all, tables=[PGVectorDataPoint.__table__]
+                            )

     @retry(
         retry=retry_if_exception_type(DeadlockDetectedError),
@@ -299,7 +299,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
         collection_name: str,
         query_text: Optional[str] = None,
         query_vector: Optional[List[float]] = None,
-        limit: int = 15,
+        limit: Optional[int] = 15,
         with_vector: bool = False,
     ) -> List[ScoredResult]:
         if query_text is None and query_vector is None:
@@ -311,6 +311,16 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
         # Get PGVectorDataPoint Table from database
         PGVectorDataPoint = await self.get_table(collection_name)

+        if limit is None:
+            async with self.get_async_session() as session:
+                query = select(func.count()).select_from(PGVectorDataPoint)
+                result = await session.execute(query)
+                limit = result.scalar_one()
+
+        # If limit is still 0, no need to do the search, just return empty results
+        if limit <= 0:
+            return []
+
         # NOTE: This needs to be initialized in case search doesn't return a value
         closest_items = []

```
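The rewritten `create_collection` wraps table creation in `VECTOR_DB_LOCK` and re-checks `has_collection` after acquiring the lock, so concurrent writers no longer race to create the same table. A minimal sketch of that double-checked pattern with an asyncio lock (the lock and `has_collection` names come from the hunk; everything else here is illustrative):

```python
import asyncio

VECTOR_DB_LOCK = asyncio.Lock()
_collections: set[str] = set()


async def has_collection(name: str) -> bool:
    return name in _collections


async def create_collection(name: str) -> None:
    # First check without the lock to keep the common path cheap.
    if not await has_collection(name):
        async with VECTOR_DB_LOCK:
            # Re-check inside the lock: another task may have created it meanwhile.
            if not await has_collection(name):
                _collections.add(name)  # stands in for Base.metadata.create_all(...)


async def main() -> None:
    await asyncio.gather(*(create_collection("documents") for _ in range(10)))
    print(_collections)  # {'documents'} -- created exactly once


asyncio.run(main())
```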
cognee/infrastructure/databases/vector/vector_db_interface.py

```diff
@@ -83,7 +83,7 @@ class VectorDBInterface(Protocol):
         collection_name: str,
         query_text: Optional[str],
         query_vector: Optional[List[float]],
-        limit: int,
+        limit: Optional[int],
         with_vector: bool = False,
     ):
         """
@@ -98,7 +98,7 @@ class VectorDBInterface(Protocol):
               collection.
             - query_vector (Optional[List[float]]): An optional vector representation for
               searching the collection.
-            - limit (int): The maximum number of results to return from the search.
+            - limit (Optional[int]): The maximum number of results to return from the search.
            - with_vector (bool): Whether to return the vector representations with search
              results. (default False)
         """
@@ -106,7 +106,11 @@ class VectorDBInterface(Protocol):

     @abstractmethod
     async def batch_search(
-        self,
+        self,
+        collection_name: str,
+        query_texts: List[str],
+        limit: Optional[int],
+        with_vectors: bool = False,
     ):
         """
         Perform a batch search using multiple text queries against a collection.
@@ -116,7 +120,7 @@ class VectorDBInterface(Protocol):

            - collection_name (str): The name of the collection to conduct the batch search in.
            - query_texts (List[str]): A list of text queries to use for the search.
-           - limit (int): The maximum number of results to return for each query.
+           - limit (Optional[int]): The maximum number of results to return for each query.
            - with_vectors (bool): Whether to include vector representations with search
              results. (default False)
         """
```
cognee/infrastructure/files/storage/S3FileStorage.py

```diff
@@ -1,6 +1,5 @@
 import os
-import s3fs
-from typing import BinaryIO, Union
+from typing import BinaryIO, Union, TYPE_CHECKING
 from contextlib import asynccontextmanager

 from cognee.infrastructure.files.storage.s3_config import get_s3_config
@@ -8,23 +7,34 @@ from cognee.infrastructure.utils.run_async import run_async
 from cognee.infrastructure.files.storage.FileBufferedReader import FileBufferedReader
 from .storage import Storage

+if TYPE_CHECKING:
+    import s3fs
+

 class S3FileStorage(Storage):
     """
-    Manage
-
+    Manage S3 file storage operations such as storing, retrieving, and managing files on
+    S3-compatible storage.
     """

     storage_path: str
-    s3: s3fs.S3FileSystem
+    s3: "s3fs.S3FileSystem"

     def __init__(self, storage_path: str):
+        try:
+            import s3fs
+        except ImportError:
+            raise ImportError(
+                's3fs is required for S3FileStorage. Install it with: pip install cognee"[aws]"'
+            )
+
         self.storage_path = storage_path
         s3_config = get_s3_config()
         if s3_config.aws_access_key_id is not None and s3_config.aws_secret_access_key is not None:
             self.s3 = s3fs.S3FileSystem(
                 key=s3_config.aws_access_key_id,
                 secret=s3_config.aws_secret_access_key,
+                token=s3_config.aws_session_token,
                 anon=False,
                 endpoint_url=s3_config.aws_endpoint_url,
                 client_kwargs={"region_name": s3_config.aws_region},
```
cognee/infrastructure/files/storage/s3_config.py

```diff
@@ -8,6 +8,7 @@ class S3Config(BaseSettings):
     aws_endpoint_url: Optional[str] = None
     aws_access_key_id: Optional[str] = None
     aws_secret_access_key: Optional[str] = None
+    aws_session_token: Optional[str] = None
     model_config = SettingsConfigDict(env_file=".env", extra="allow")


```
cognee/infrastructure/files/utils/open_data_file.py

```diff
@@ -4,7 +4,6 @@ from urllib.parse import urlparse
 from contextlib import asynccontextmanager

 from cognee.infrastructure.files.utils.get_data_file_path import get_data_file_path
-from cognee.infrastructure.files.storage.S3FileStorage import S3FileStorage
 from cognee.infrastructure.files.storage.LocalFileStorage import LocalFileStorage


@@ -23,23 +22,17 @@ async def open_data_file(file_path: str, mode: str = "rb", encoding: str = None,
             yield file

     elif file_path.startswith("s3://"):
+        try:
+            from cognee.infrastructure.files.storage.S3FileStorage import S3FileStorage
+        except ImportError:
+            raise ImportError(
+                "S3 dependencies are not installed. Please install with 'pip install cognee\"[aws]\"' to use S3 functionality."
+            )
+
         normalized_url = get_data_file_path(file_path)
         s3_dir_path = os.path.dirname(normalized_url)
         s3_filename = os.path.basename(normalized_url)

-        # if "/" in s3_path:
-        #     s3_dir = "/".join(s3_path.split("/")[:-1])
-        #     s3_filename = s3_path.split("/")[-1]
-        # else:
-        #     s3_dir = ""
-        #     s3_filename = s3_path
-
-        # Extract filesystem path from S3 URL structure
-        # file_dir_path = (
-        #     f"s3://{parsed_url.netloc}/{s3_dir}" if s3_dir else f"s3://{parsed_url.netloc}"
-        # )
-        # file_name = s3_filename
-
         file_storage = S3FileStorage(s3_dir_path)

         async with file_storage.open(s3_filename, mode=mode, **kwargs) as file:
```
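On the storage side, `s3fs` is now imported lazily (module-level usage is gated behind `TYPE_CHECKING`), and `S3Config` gained an `aws_session_token` field that is forwarded to `s3fs.S3FileSystem`. A hedged sketch of how temporary credentials would be picked up under these changes; the environment variable names assume pydantic-settings' default case-insensitive mapping onto the `S3Config` fields, and the bucket path is hypothetical:

```python
import os

# Temporary STS credentials; assumed to map onto the S3Config fields shown above
# (aws_access_key_id, aws_secret_access_key, aws_session_token).
os.environ["AWS_ACCESS_KEY_ID"] = "AKIA-placeholder"
os.environ["AWS_SECRET_ACCESS_KEY"] = "secret-placeholder"
os.environ["AWS_SESSION_TOKEN"] = "token-placeholder"

from cognee.infrastructure.files.storage.S3FileStorage import S3FileStorage

try:
    # s3fs is imported lazily inside __init__, so a missing optional
    # dependency surfaces here, at construction time.
    storage = S3FileStorage("s3://my-bucket/datasets")  # hypothetical bucket path
except ImportError as error:
    # Raised with an install hint: pip install 'cognee[aws]'
    print(error)
```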