cognee 0.5.1.dev0__py3-none-any.whl → 0.5.2.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cognee/__init__.py +2 -0
- cognee/alembic/README +1 -0
- cognee/alembic/env.py +107 -0
- cognee/alembic/script.py.mako +26 -0
- cognee/alembic/versions/1a58b986e6e1_enable_delete_for_old_tutorial_notebooks.py +52 -0
- cognee/alembic/versions/1d0bb7fede17_add_pipeline_run_status.py +33 -0
- cognee/alembic/versions/1daae0df1866_incremental_loading.py +48 -0
- cognee/alembic/versions/211ab850ef3d_add_sync_operations_table.py +118 -0
- cognee/alembic/versions/45957f0a9849_add_notebook_table.py +46 -0
- cognee/alembic/versions/46a6ce2bd2b2_expand_dataset_database_with_json_.py +333 -0
- cognee/alembic/versions/482cd6517ce4_add_default_user.py +30 -0
- cognee/alembic/versions/76625596c5c3_expand_dataset_database_for_multi_user.py +98 -0
- cognee/alembic/versions/8057ae7329c2_initial_migration.py +25 -0
- cognee/alembic/versions/9e7a3cb85175_loader_separation.py +104 -0
- cognee/alembic/versions/a1b2c3d4e5f6_add_label_column_to_data.py +38 -0
- cognee/alembic/versions/ab7e313804ae_permission_system_rework.py +236 -0
- cognee/alembic/versions/b9274c27a25a_kuzu_11_migration.py +75 -0
- cognee/alembic/versions/c946955da633_multi_tenant_support.py +137 -0
- cognee/alembic/versions/e1ec1dcb50b6_add_last_accessed_to_data.py +51 -0
- cognee/alembic/versions/e4ebee1091e7_expand_data_model_info.py +140 -0
- cognee/alembic.ini +117 -0
- cognee/api/v1/add/routers/get_add_router.py +2 -0
- cognee/api/v1/cognify/cognify.py +11 -6
- cognee/api/v1/cognify/routers/get_cognify_router.py +8 -0
- cognee/api/v1/config/config.py +60 -0
- cognee/api/v1/datasets/routers/get_datasets_router.py +45 -3
- cognee/api/v1/memify/routers/get_memify_router.py +2 -0
- cognee/api/v1/search/routers/get_search_router.py +21 -6
- cognee/api/v1/search/search.py +25 -5
- cognee/api/v1/sync/routers/get_sync_router.py +3 -3
- cognee/cli/commands/add_command.py +1 -1
- cognee/cli/commands/cognify_command.py +6 -0
- cognee/cli/commands/config_command.py +1 -1
- cognee/context_global_variables.py +5 -1
- cognee/eval_framework/answer_generation/answer_generation_executor.py +7 -8
- cognee/infrastructure/databases/cache/cache_db_interface.py +38 -1
- cognee/infrastructure/databases/cache/config.py +6 -0
- cognee/infrastructure/databases/cache/fscache/FsCacheAdapter.py +21 -0
- cognee/infrastructure/databases/cache/get_cache_engine.py +9 -3
- cognee/infrastructure/databases/cache/redis/RedisAdapter.py +60 -1
- cognee/infrastructure/databases/dataset_database_handler/supported_dataset_database_handlers.py +7 -0
- cognee/infrastructure/databases/graph/get_graph_engine.py +29 -1
- cognee/infrastructure/databases/graph/neo4j_driver/Neo4jAuraDevDatasetDatabaseHandler.py +62 -27
- cognee/infrastructure/databases/hybrid/neptune_analytics/NeptuneAnalyticsAdapter.py +17 -4
- cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py +2 -1
- cognee/infrastructure/databases/vector/chromadb/ChromaDBAdapter.py +2 -0
- cognee/infrastructure/databases/vector/config.py +6 -0
- cognee/infrastructure/databases/vector/create_vector_engine.py +69 -22
- cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py +64 -9
- cognee/infrastructure/databases/vector/embeddings/OllamaEmbeddingEngine.py +13 -2
- cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py +16 -3
- cognee/infrastructure/databases/vector/models/ScoredResult.py +3 -3
- cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py +16 -3
- cognee/infrastructure/databases/vector/pgvector/PGVectorDatasetDatabaseHandler.py +86 -0
- cognee/infrastructure/databases/vector/pgvector/create_db_and_tables.py +81 -2
- cognee/infrastructure/databases/vector/vector_db_interface.py +8 -0
- cognee/infrastructure/files/utils/get_data_file_path.py +33 -27
- cognee/infrastructure/llm/prompts/extract_query_time.txt +1 -1
- cognee/infrastructure/llm/prompts/generate_event_entity_prompt.txt +1 -1
- cognee/infrastructure/llm/prompts/generate_event_graph_prompt.txt +1 -1
- cognee/infrastructure/llm/prompts/generate_graph_prompt.txt +2 -2
- cognee/infrastructure/llm/prompts/generate_graph_prompt_guided.txt +1 -1
- cognee/infrastructure/llm/prompts/generate_graph_prompt_oneshot.txt +2 -2
- cognee/infrastructure/llm/prompts/generate_graph_prompt_simple.txt +1 -1
- cognee/infrastructure/llm/prompts/generate_graph_prompt_strict.txt +1 -1
- cognee/infrastructure/llm/prompts/search_type_selector_prompt.txt +6 -6
- cognee/infrastructure/llm/prompts/test.txt +1 -1
- cognee/infrastructure/llm/prompts/translate_content.txt +19 -0
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py +24 -0
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/llama_cpp/adapter.py +191 -0
- cognee/modules/chunking/models/DocumentChunk.py +0 -1
- cognee/modules/cognify/config.py +2 -0
- cognee/modules/data/models/Data.py +1 -0
- cognee/modules/engine/models/Entity.py +0 -1
- cognee/modules/engine/operations/setup.py +6 -0
- cognee/modules/graph/cognee_graph/CogneeGraph.py +150 -37
- cognee/modules/graph/cognee_graph/CogneeGraphElements.py +48 -2
- cognee/modules/graph/utils/__init__.py +1 -0
- cognee/modules/graph/utils/get_entity_nodes_from_triplets.py +12 -0
- cognee/modules/notebooks/methods/__init__.py +1 -0
- cognee/modules/notebooks/methods/create_notebook.py +0 -34
- cognee/modules/notebooks/methods/create_tutorial_notebooks.py +191 -0
- cognee/modules/notebooks/methods/get_notebooks.py +12 -8
- cognee/modules/notebooks/tutorials/cognee-basics/cell-1.md +3 -0
- cognee/modules/notebooks/tutorials/cognee-basics/cell-2.md +10 -0
- cognee/modules/notebooks/tutorials/cognee-basics/cell-3.md +7 -0
- cognee/modules/notebooks/tutorials/cognee-basics/cell-4.py +28 -0
- cognee/modules/notebooks/tutorials/cognee-basics/cell-5.py +3 -0
- cognee/modules/notebooks/tutorials/cognee-basics/cell-6.py +9 -0
- cognee/modules/notebooks/tutorials/cognee-basics/cell-7.py +17 -0
- cognee/modules/notebooks/tutorials/cognee-basics/config.json +4 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-1.md +3 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-10.md +3 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-11.md +3 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-12.py +3 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-13.md +7 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-14.py +6 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-15.md +3 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-16.py +7 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-2.md +9 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-3.md +7 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-4.md +9 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-5.md +5 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-6.py +13 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-7.md +3 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-8.md +3 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-9.py +31 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/config.json +4 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/data/copilot_conversations.json +107 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/data/guido_contributions.json +976 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/data/my_developer_rules.md +79 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/data/pep_style_guide.md +74 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/data/zen_principles.md +74 -0
- cognee/modules/retrieval/EntityCompletionRetriever.py +51 -38
- cognee/modules/retrieval/__init__.py +0 -1
- cognee/modules/retrieval/base_retriever.py +66 -10
- cognee/modules/retrieval/chunks_retriever.py +57 -49
- cognee/modules/retrieval/coding_rules_retriever.py +12 -5
- cognee/modules/retrieval/completion_retriever.py +29 -28
- cognee/modules/retrieval/cypher_search_retriever.py +25 -20
- cognee/modules/retrieval/graph_completion_context_extension_retriever.py +42 -46
- cognee/modules/retrieval/graph_completion_cot_retriever.py +68 -51
- cognee/modules/retrieval/graph_completion_retriever.py +78 -63
- cognee/modules/retrieval/graph_summary_completion_retriever.py +2 -0
- cognee/modules/retrieval/lexical_retriever.py +34 -12
- cognee/modules/retrieval/natural_language_retriever.py +18 -15
- cognee/modules/retrieval/summaries_retriever.py +51 -34
- cognee/modules/retrieval/temporal_retriever.py +59 -49
- cognee/modules/retrieval/triplet_retriever.py +31 -32
- cognee/modules/retrieval/utils/access_tracking.py +88 -0
- cognee/modules/retrieval/utils/brute_force_triplet_search.py +99 -85
- cognee/modules/retrieval/utils/node_edge_vector_search.py +174 -0
- cognee/modules/search/methods/__init__.py +1 -0
- cognee/modules/search/methods/get_retriever_output.py +53 -0
- cognee/modules/search/methods/get_search_type_retriever_instance.py +252 -0
- cognee/modules/search/methods/search.py +90 -215
- cognee/modules/search/models/SearchResultPayload.py +67 -0
- cognee/modules/search/types/SearchResult.py +1 -8
- cognee/modules/search/types/SearchType.py +1 -2
- cognee/modules/search/types/__init__.py +1 -1
- cognee/modules/search/utils/__init__.py +1 -2
- cognee/modules/search/utils/transform_insights_to_graph.py +2 -2
- cognee/modules/search/utils/{transform_context_to_graph.py → transform_triplets_to_graph.py} +2 -2
- cognee/modules/users/authentication/default/default_transport.py +11 -1
- cognee/modules/users/authentication/get_api_auth_backend.py +2 -1
- cognee/modules/users/authentication/get_client_auth_backend.py +2 -1
- cognee/modules/users/methods/create_user.py +0 -9
- cognee/modules/users/permissions/methods/has_user_management_permission.py +29 -0
- cognee/modules/visualization/cognee_network_visualization.py +1 -1
- cognee/run_migrations.py +48 -0
- cognee/shared/exceptions/__init__.py +1 -3
- cognee/shared/exceptions/exceptions.py +11 -1
- cognee/shared/usage_logger.py +332 -0
- cognee/shared/utils.py +12 -5
- cognee/tasks/cleanup/cleanup_unused_data.py +172 -0
- cognee/tasks/memify/extract_usage_frequency.py +613 -0
- cognee/tasks/summarization/models.py +0 -2
- cognee/tasks/temporal_graph/__init__.py +0 -1
- cognee/tasks/translation/__init__.py +96 -0
- cognee/tasks/translation/config.py +110 -0
- cognee/tasks/translation/detect_language.py +190 -0
- cognee/tasks/translation/exceptions.py +62 -0
- cognee/tasks/translation/models.py +72 -0
- cognee/tasks/translation/providers/__init__.py +44 -0
- cognee/tasks/translation/providers/azure_provider.py +192 -0
- cognee/tasks/translation/providers/base.py +85 -0
- cognee/tasks/translation/providers/google_provider.py +158 -0
- cognee/tasks/translation/providers/llm_provider.py +143 -0
- cognee/tasks/translation/translate_content.py +282 -0
- cognee/tasks/web_scraper/default_url_crawler.py +6 -2
- cognee/tests/cli_tests/cli_unit_tests/test_cli_commands.py +1 -0
- cognee/tests/cli_tests/cli_unit_tests/test_cli_edge_cases.py +3 -0
- cognee/tests/integration/retrieval/test_brute_force_triplet_search_with_cognify.py +62 -0
- cognee/tests/integration/retrieval/test_chunks_retriever.py +115 -16
- cognee/tests/integration/retrieval/test_graph_completion_retriever.py +13 -5
- cognee/tests/integration/retrieval/test_graph_completion_retriever_context_extension.py +22 -20
- cognee/tests/integration/retrieval/test_graph_completion_retriever_cot.py +23 -24
- cognee/tests/integration/retrieval/test_rag_completion_retriever.py +70 -5
- cognee/tests/integration/retrieval/test_structured_output.py +62 -18
- cognee/tests/integration/retrieval/test_summaries_retriever.py +20 -9
- cognee/tests/integration/retrieval/test_temporal_retriever.py +38 -8
- cognee/tests/integration/retrieval/test_triplet_retriever.py +13 -4
- cognee/tests/integration/shared/test_usage_logger_integration.py +255 -0
- cognee/tests/tasks/translation/README.md +147 -0
- cognee/tests/tasks/translation/__init__.py +1 -0
- cognee/tests/tasks/translation/config_test.py +93 -0
- cognee/tests/tasks/translation/detect_language_test.py +118 -0
- cognee/tests/tasks/translation/providers_test.py +151 -0
- cognee/tests/tasks/translation/translate_content_test.py +213 -0
- cognee/tests/test_chromadb.py +1 -1
- cognee/tests/test_cleanup_unused_data.py +165 -0
- cognee/tests/test_delete_by_id.py +6 -6
- cognee/tests/test_extract_usage_frequency.py +308 -0
- cognee/tests/test_kuzu.py +17 -7
- cognee/tests/test_lancedb.py +3 -1
- cognee/tests/test_library.py +1 -1
- cognee/tests/test_neo4j.py +17 -7
- cognee/tests/test_neptune_analytics_vector.py +3 -1
- cognee/tests/test_permissions.py +172 -187
- cognee/tests/test_pgvector.py +3 -1
- cognee/tests/test_relational_db_migration.py +15 -1
- cognee/tests/test_remote_kuzu.py +3 -1
- cognee/tests/test_s3_file_storage.py +1 -1
- cognee/tests/test_search_db.py +97 -110
- cognee/tests/test_usage_logger_e2e.py +268 -0
- cognee/tests/unit/api/test_get_raw_data_endpoint.py +206 -0
- cognee/tests/unit/eval_framework/answer_generation_test.py +4 -3
- cognee/tests/unit/infrastructure/databases/cache/test_cache_config.py +2 -0
- cognee/tests/unit/modules/graph/cognee_graph_elements_test.py +42 -2
- cognee/tests/unit/modules/graph/cognee_graph_test.py +329 -31
- cognee/tests/unit/modules/retrieval/chunks_retriever_test.py +31 -59
- cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py +70 -33
- cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py +72 -52
- cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py +27 -33
- cognee/tests/unit/modules/retrieval/rag_completion_retriever_test.py +28 -15
- cognee/tests/unit/modules/retrieval/summaries_retriever_test.py +37 -42
- cognee/tests/unit/modules/retrieval/temporal_retriever_test.py +48 -64
- cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py +263 -24
- cognee/tests/unit/modules/retrieval/test_node_edge_vector_search.py +273 -0
- cognee/tests/unit/modules/retrieval/triplet_retriever_test.py +30 -16
- cognee/tests/unit/modules/search/test_get_search_type_retriever_instance.py +125 -0
- cognee/tests/unit/modules/search/test_search.py +176 -0
- cognee/tests/unit/modules/search/test_search_prepare_search_result_contract.py +190 -0
- cognee/tests/unit/modules/users/test_tutorial_notebook_creation.py +511 -297
- cognee/tests/unit/shared/test_usage_logger.py +241 -0
- cognee/tests/unit/users/permissions/test_has_user_management_permission.py +46 -0
- {cognee-0.5.1.dev0.dist-info → cognee-0.5.2.dev0.dist-info}/METADATA +17 -10
- {cognee-0.5.1.dev0.dist-info → cognee-0.5.2.dev0.dist-info}/RECORD +232 -144
- cognee/api/.env.example +0 -5
- cognee/modules/retrieval/base_graph_retriever.py +0 -24
- cognee/modules/search/methods/get_search_type_tools.py +0 -223
- cognee/modules/search/methods/no_access_control_search.py +0 -62
- cognee/modules/search/utils/prepare_search_result.py +0 -63
- cognee/tests/test_feedback_enrichment.py +0 -174
- {cognee-0.5.1.dev0.dist-info → cognee-0.5.2.dev0.dist-info}/WHEEL +0 -0
- {cognee-0.5.1.dev0.dist-info → cognee-0.5.2.dev0.dist-info}/entry_points.txt +0 -0
- {cognee-0.5.1.dev0.dist-info → cognee-0.5.2.dev0.dist-info}/licenses/LICENSE +0 -0
- {cognee-0.5.1.dev0.dist-info → cognee-0.5.2.dev0.dist-info}/licenses/NOTICE.md +0 -0
|
@@ -7,7 +7,6 @@ from cognee.infrastructure.databases.graph.config import get_graph_context_confi
|
|
|
7
7
|
from functools import lru_cache
|
|
8
8
|
|
|
9
9
|
|
|
10
|
-
@lru_cache
|
|
11
10
|
def create_vector_engine(
|
|
12
11
|
vector_db_provider: str,
|
|
13
12
|
vector_db_url: str,
|
|
@@ -15,6 +14,38 @@ def create_vector_engine(
|
|
|
15
14
|
vector_db_port: str = "",
|
|
16
15
|
vector_db_key: str = "",
|
|
17
16
|
vector_dataset_database_handler: str = "",
|
|
17
|
+
vector_db_username: str = "",
|
|
18
|
+
vector_db_password: str = "",
|
|
19
|
+
vector_db_host: str = "",
|
|
20
|
+
):
|
|
21
|
+
"""
|
|
22
|
+
Wrapper function to call create vector engine with caching.
|
|
23
|
+
For a detailed description, see _create_vector_engine.
|
|
24
|
+
"""
|
|
25
|
+
return _create_vector_engine(
|
|
26
|
+
vector_db_provider,
|
|
27
|
+
vector_db_url,
|
|
28
|
+
vector_db_name,
|
|
29
|
+
vector_db_port,
|
|
30
|
+
vector_db_key,
|
|
31
|
+
vector_dataset_database_handler,
|
|
32
|
+
vector_db_username,
|
|
33
|
+
vector_db_password,
|
|
34
|
+
vector_db_host,
|
|
35
|
+
)
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
@lru_cache
|
|
39
|
+
def _create_vector_engine(
|
|
40
|
+
vector_db_provider: str,
|
|
41
|
+
vector_db_url: str,
|
|
42
|
+
vector_db_name: str,
|
|
43
|
+
vector_db_port: str,
|
|
44
|
+
vector_db_key: str,
|
|
45
|
+
vector_dataset_database_handler: str,
|
|
46
|
+
vector_db_username: str,
|
|
47
|
+
vector_db_password: str,
|
|
48
|
+
vector_db_host: str,
|
|
18
49
|
):
|
|
19
50
|
"""
|
|
20
51
|
Create a vector database engine based on the specified provider.
|
|
@@ -55,27 +86,43 @@ def create_vector_engine(
|
|
|
55
86
|
)
|
|
56
87
|
|
|
57
88
|
if vector_db_provider.lower() == "pgvector":
|
|
58
|
-
from cognee.
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
89
|
+
from cognee.context_global_variables import backend_access_control_enabled
|
|
90
|
+
|
|
91
|
+
if backend_access_control_enabled():
|
|
92
|
+
connection_string: str = (
|
|
93
|
+
f"postgresql+asyncpg://{vector_db_username}:{vector_db_password}"
|
|
94
|
+
f"@{vector_db_host}:{vector_db_port}/{vector_db_name}"
|
|
95
|
+
)
|
|
96
|
+
else:
|
|
97
|
+
if (
|
|
98
|
+
vector_db_port
|
|
99
|
+
and vector_db_username
|
|
100
|
+
and vector_db_password
|
|
101
|
+
and vector_db_host
|
|
102
|
+
and vector_db_name
|
|
103
|
+
):
|
|
104
|
+
connection_string: str = (
|
|
105
|
+
f"postgresql+asyncpg://{vector_db_username}:{vector_db_password}"
|
|
106
|
+
f"@{vector_db_host}:{vector_db_port}/{vector_db_name}"
|
|
107
|
+
)
|
|
108
|
+
else:
|
|
109
|
+
from cognee.infrastructure.databases.relational import get_relational_config
|
|
110
|
+
|
|
111
|
+
# Get configuration for postgres database
|
|
112
|
+
relational_config = get_relational_config()
|
|
113
|
+
db_username = relational_config.db_username
|
|
114
|
+
db_password = relational_config.db_password
|
|
115
|
+
db_host = relational_config.db_host
|
|
116
|
+
db_port = relational_config.db_port
|
|
117
|
+
db_name = relational_config.db_name
|
|
118
|
+
|
|
119
|
+
if not (db_host and db_port and db_name and db_username and db_password):
|
|
120
|
+
raise EnvironmentError("Missing required pgvector credentials!")
|
|
121
|
+
|
|
122
|
+
connection_string: str = (
|
|
123
|
+
f"postgresql+asyncpg://{db_username}:{db_password}"
|
|
124
|
+
f"@{db_host}:{db_port}/{db_name}"
|
|
125
|
+
)
|
|
79
126
|
|
|
80
127
|
try:
|
|
81
128
|
from .pgvector.PGVectorAdapter import PGVectorAdapter
|
|
@@ -14,6 +14,8 @@ from tenacity import (
|
|
|
14
14
|
)
|
|
15
15
|
import litellm
|
|
16
16
|
import os
|
|
17
|
+
from urllib.parse import urlparse
|
|
18
|
+
import httpx
|
|
17
19
|
from cognee.infrastructure.databases.vector.embeddings.EmbeddingEngine import EmbeddingEngine
|
|
18
20
|
from cognee.infrastructure.databases.exceptions import EmbeddingException
|
|
19
21
|
from cognee.infrastructure.llm.tokenizer.HuggingFace import (
|
|
@@ -79,10 +81,26 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
|
|
|
79
81
|
enable_mocking = str(enable_mocking).lower()
|
|
80
82
|
self.mock = enable_mocking in ("true", "1", "yes")
|
|
81
83
|
|
|
84
|
+
# Validate provided custom embedding endpoint early to avoid long hangs later
|
|
85
|
+
if self.endpoint:
|
|
86
|
+
try:
|
|
87
|
+
parsed = urlparse(self.endpoint)
|
|
88
|
+
except Exception:
|
|
89
|
+
parsed = None
|
|
90
|
+
if not parsed or parsed.scheme not in ("http", "https") or not parsed.netloc:
|
|
91
|
+
logger.error(
|
|
92
|
+
"Invalid EMBEDDING_ENDPOINT configured: '%s'. Expected a URL starting with http:// or https://",
|
|
93
|
+
str(self.endpoint),
|
|
94
|
+
)
|
|
95
|
+
raise EmbeddingException(
|
|
96
|
+
"Invalid EMBEDDING_ENDPOINT. Please set a valid URL (e.g., https://host:port) "
|
|
97
|
+
"via environment variable EMBEDDING_ENDPOINT."
|
|
98
|
+
)
|
|
99
|
+
|
|
82
100
|
@retry(
|
|
83
101
|
stop=stop_after_delay(128),
|
|
84
102
|
wait=wait_exponential_jitter(2, 128),
|
|
85
|
-
retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
|
|
103
|
+
retry=retry_if_not_exception_type((litellm.exceptions.NotFoundError)),
|
|
86
104
|
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
|
87
105
|
reraise=True,
|
|
88
106
|
)
|
|
@@ -111,12 +129,21 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
|
|
|
111
129
|
return [data["embedding"] for data in response["data"]]
|
|
112
130
|
else:
|
|
113
131
|
async with embedding_rate_limiter_context_manager():
|
|
114
|
-
|
|
115
|
-
model
|
|
116
|
-
input
|
|
117
|
-
api_key
|
|
118
|
-
api_base
|
|
119
|
-
api_version
|
|
132
|
+
embedding_kwargs = {
|
|
133
|
+
"model": self.model,
|
|
134
|
+
"input": text,
|
|
135
|
+
"api_key": self.api_key,
|
|
136
|
+
"api_base": self.endpoint,
|
|
137
|
+
"api_version": self.api_version,
|
|
138
|
+
}
|
|
139
|
+
# Pass through target embedding dimensions when supported
|
|
140
|
+
if self.dimensions is not None:
|
|
141
|
+
embedding_kwargs["dimensions"] = self.dimensions
|
|
142
|
+
|
|
143
|
+
# Ensure each attempt does not hang indefinitely
|
|
144
|
+
response = await asyncio.wait_for(
|
|
145
|
+
litellm.aembedding(**embedding_kwargs),
|
|
146
|
+
timeout=30.0,
|
|
120
147
|
)
|
|
121
148
|
|
|
122
149
|
return [data["embedding"] for data in response.data]
|
|
@@ -154,6 +181,27 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
|
|
|
154
181
|
logger.error("Context window exceeded for embedding text: %s", str(error))
|
|
155
182
|
raise error
|
|
156
183
|
|
|
184
|
+
except asyncio.TimeoutError as e:
|
|
185
|
+
# Per-attempt timeout – likely an unreachable endpoint
|
|
186
|
+
logger.error(
|
|
187
|
+
"Embedding endpoint timed out. EMBEDDING_ENDPOINT='%s'. "
|
|
188
|
+
"Verify that the endpoint is reachable and correct.",
|
|
189
|
+
str(self.endpoint),
|
|
190
|
+
)
|
|
191
|
+
raise EmbeddingException(
|
|
192
|
+
"Embedding request timed out. Check EMBEDDING_ENDPOINT connectivity."
|
|
193
|
+
) from e
|
|
194
|
+
|
|
195
|
+
except (httpx.ConnectError, httpx.ReadTimeout) as e:
|
|
196
|
+
logger.error(
|
|
197
|
+
"Failed to connect to embedding endpoint. EMBEDDING_ENDPOINT='%s'. "
|
|
198
|
+
"Ensure the URL is correct and the server is running.",
|
|
199
|
+
str(self.endpoint),
|
|
200
|
+
)
|
|
201
|
+
raise EmbeddingException(
|
|
202
|
+
"Cannot connect to embedding endpoint. Check EMBEDDING_ENDPOINT."
|
|
203
|
+
) from e
|
|
204
|
+
|
|
157
205
|
except (
|
|
158
206
|
litellm.exceptions.BadRequestError,
|
|
159
207
|
litellm.exceptions.NotFoundError,
|
|
@@ -162,8 +210,15 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
|
|
|
162
210
|
raise EmbeddingException(f"Failed to index data points using model {self.model}") from e
|
|
163
211
|
|
|
164
212
|
except Exception as error:
|
|
165
|
-
|
|
166
|
-
|
|
213
|
+
# Fall back to a clear, actionable message for connectivity/misconfiguration issues
|
|
214
|
+
logger.error(
|
|
215
|
+
"Error embedding text: %s. EMBEDDING_ENDPOINT='%s'.",
|
|
216
|
+
str(error),
|
|
217
|
+
str(self.endpoint),
|
|
218
|
+
)
|
|
219
|
+
raise EmbeddingException(
|
|
220
|
+
"Embedding failed due to an unexpected error. Verify EMBEDDING_ENDPOINT and provider settings."
|
|
221
|
+
) from error
|
|
167
222
|
|
|
168
223
|
def get_vector_size(self) -> int:
|
|
169
224
|
"""
|
|
@@ -57,7 +57,7 @@ class OllamaEmbeddingEngine(EmbeddingEngine):
|
|
|
57
57
|
model: Optional[str] = "avr/sfr-embedding-mistral:latest",
|
|
58
58
|
dimensions: Optional[int] = 1024,
|
|
59
59
|
max_completion_tokens: int = 512,
|
|
60
|
-
endpoint: Optional[str] = "http://localhost:11434/api/
|
|
60
|
+
endpoint: Optional[str] = "http://localhost:11434/api/embed",
|
|
61
61
|
huggingface_tokenizer: str = "Salesforce/SFR-Embedding-Mistral",
|
|
62
62
|
batch_size: int = 100,
|
|
63
63
|
):
|
|
@@ -93,6 +93,10 @@ class OllamaEmbeddingEngine(EmbeddingEngine):
|
|
|
93
93
|
if self.mock:
|
|
94
94
|
return [[0.0] * self.dimensions for _ in text]
|
|
95
95
|
|
|
96
|
+
# Handle case when a single string is passed instead of a list
|
|
97
|
+
if not isinstance(text, list):
|
|
98
|
+
text = [text]
|
|
99
|
+
|
|
96
100
|
embeddings = await asyncio.gather(*[self._get_embedding(prompt) for prompt in text])
|
|
97
101
|
return embeddings
|
|
98
102
|
|
|
@@ -107,7 +111,12 @@ class OllamaEmbeddingEngine(EmbeddingEngine):
|
|
|
107
111
|
"""
|
|
108
112
|
Internal method to call the Ollama embeddings endpoint for a single prompt.
|
|
109
113
|
"""
|
|
110
|
-
payload = {
|
|
114
|
+
payload = {
|
|
115
|
+
"model": self.model,
|
|
116
|
+
"prompt": prompt,
|
|
117
|
+
"input": prompt,
|
|
118
|
+
"dimensions": self.dimensions,
|
|
119
|
+
}
|
|
111
120
|
|
|
112
121
|
headers = {}
|
|
113
122
|
api_key = os.getenv("LLM_API_KEY")
|
|
@@ -124,6 +133,8 @@ class OllamaEmbeddingEngine(EmbeddingEngine):
|
|
|
124
133
|
data = await response.json()
|
|
125
134
|
if "embeddings" in data:
|
|
126
135
|
return data["embeddings"][0]
|
|
136
|
+
if "embedding" in data:
|
|
137
|
+
return data["embedding"]
|
|
127
138
|
else:
|
|
128
139
|
return data["data"][0]["embedding"]
|
|
129
140
|
|
|
@@ -231,6 +231,7 @@ class LanceDBAdapter(VectorDBInterface):
|
|
|
231
231
|
limit: Optional[int] = 15,
|
|
232
232
|
with_vector: bool = False,
|
|
233
233
|
normalized: bool = True,
|
|
234
|
+
include_payload: bool = False,
|
|
234
235
|
):
|
|
235
236
|
if query_text is None and query_vector is None:
|
|
236
237
|
raise MissingQueryParameterError()
|
|
@@ -247,17 +248,27 @@ class LanceDBAdapter(VectorDBInterface):
|
|
|
247
248
|
if limit <= 0:
|
|
248
249
|
return []
|
|
249
250
|
|
|
250
|
-
|
|
251
|
+
# Note: Exclude payload if not needed to optimize performance
|
|
252
|
+
select_columns = (
|
|
253
|
+
["id", "vector", "payload", "_distance"]
|
|
254
|
+
if include_payload
|
|
255
|
+
else ["id", "vector", "_distance"]
|
|
256
|
+
)
|
|
257
|
+
result_values = (
|
|
258
|
+
await collection.vector_search(query_vector)
|
|
259
|
+
.select(select_columns)
|
|
260
|
+
.limit(limit)
|
|
261
|
+
.to_list()
|
|
262
|
+
)
|
|
251
263
|
|
|
252
264
|
if not result_values:
|
|
253
265
|
return []
|
|
254
|
-
|
|
255
266
|
normalized_values = normalize_distances(result_values)
|
|
256
267
|
|
|
257
268
|
return [
|
|
258
269
|
ScoredResult(
|
|
259
270
|
id=parse_id(result["id"]),
|
|
260
|
-
payload=result["payload"],
|
|
271
|
+
payload=result["payload"] if include_payload else None,
|
|
261
272
|
score=normalized_values[value_index],
|
|
262
273
|
)
|
|
263
274
|
for value_index, result in enumerate(result_values)
|
|
@@ -269,6 +280,7 @@ class LanceDBAdapter(VectorDBInterface):
|
|
|
269
280
|
query_texts: List[str],
|
|
270
281
|
limit: Optional[int] = None,
|
|
271
282
|
with_vectors: bool = False,
|
|
283
|
+
include_payload: bool = False,
|
|
272
284
|
):
|
|
273
285
|
query_vectors = await self.embedding_engine.embed_text(query_texts)
|
|
274
286
|
|
|
@@ -279,6 +291,7 @@ class LanceDBAdapter(VectorDBInterface):
|
|
|
279
291
|
query_vector=query_vector,
|
|
280
292
|
limit=limit,
|
|
281
293
|
with_vector=with_vectors,
|
|
294
|
+
include_payload=include_payload,
|
|
282
295
|
)
|
|
283
296
|
for query_vector in query_vectors
|
|
284
297
|
]
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
from typing import Any, Dict
|
|
1
|
+
from typing import Any, Dict, Optional
|
|
2
2
|
from uuid import UUID
|
|
3
3
|
from pydantic import BaseModel
|
|
4
4
|
|
|
@@ -12,10 +12,10 @@ class ScoredResult(BaseModel):
|
|
|
12
12
|
- id (UUID): Unique identifier for the scored result.
|
|
13
13
|
- score (float): The score associated with the result, where a lower score indicates a
|
|
14
14
|
better outcome.
|
|
15
|
-
- payload (Dict[str, Any]): Additional information related to the score, stored as
|
|
15
|
+
- payload (Optional[Dict[str, Any]]): Additional information related to the score, stored as
|
|
16
16
|
key-value pairs in a dictionary.
|
|
17
17
|
"""
|
|
18
18
|
|
|
19
19
|
id: UUID
|
|
20
20
|
score: float # Lower score is better
|
|
21
|
-
payload: Dict[str, Any]
|
|
21
|
+
payload: Optional[Dict[str, Any]] = None
|
|
@@ -301,6 +301,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
|
|
301
301
|
query_vector: Optional[List[float]] = None,
|
|
302
302
|
limit: Optional[int] = 15,
|
|
303
303
|
with_vector: bool = False,
|
|
304
|
+
include_payload: bool = False,
|
|
304
305
|
) -> List[ScoredResult]:
|
|
305
306
|
if query_text is None and query_vector is None:
|
|
306
307
|
raise MissingQueryParameterError()
|
|
@@ -324,10 +325,16 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
|
|
324
325
|
# NOTE: This needs to be initialized in case search doesn't return a value
|
|
325
326
|
closest_items = []
|
|
326
327
|
|
|
328
|
+
# Note: Exclude payload from returned columns if not needed to optimize performance
|
|
329
|
+
select_columns = (
|
|
330
|
+
[PGVectorDataPoint]
|
|
331
|
+
if include_payload
|
|
332
|
+
else [PGVectorDataPoint.c.id, PGVectorDataPoint.c.vector]
|
|
333
|
+
)
|
|
327
334
|
# Use async session to connect to the database
|
|
328
335
|
async with self.get_async_session() as session:
|
|
329
336
|
query = select(
|
|
330
|
-
|
|
337
|
+
*select_columns,
|
|
331
338
|
PGVectorDataPoint.c.vector.cosine_distance(query_vector).label("similarity"),
|
|
332
339
|
).order_by("similarity")
|
|
333
340
|
|
|
@@ -344,7 +351,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
|
|
344
351
|
vector_list.append(
|
|
345
352
|
{
|
|
346
353
|
"id": parse_id(str(vector.id)),
|
|
347
|
-
"payload": vector.payload,
|
|
354
|
+
"payload": vector.payload if include_payload else None,
|
|
348
355
|
"_distance": vector.similarity,
|
|
349
356
|
}
|
|
350
357
|
)
|
|
@@ -359,7 +366,11 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
|
|
359
366
|
|
|
360
367
|
# Create and return ScoredResult objects
|
|
361
368
|
return [
|
|
362
|
-
ScoredResult(
|
|
369
|
+
ScoredResult(
|
|
370
|
+
id=row.get("id"),
|
|
371
|
+
payload=row.get("payload") if include_payload else None,
|
|
372
|
+
score=row.get("score"),
|
|
373
|
+
)
|
|
363
374
|
for row in vector_list
|
|
364
375
|
]
|
|
365
376
|
|
|
@@ -369,6 +380,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
|
|
369
380
|
query_texts: List[str],
|
|
370
381
|
limit: int = None,
|
|
371
382
|
with_vectors: bool = False,
|
|
383
|
+
include_payload: bool = False,
|
|
372
384
|
):
|
|
373
385
|
query_vectors = await self.embedding_engine.embed_text(query_texts)
|
|
374
386
|
|
|
@@ -379,6 +391,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
|
|
379
391
|
query_vector=query_vector,
|
|
380
392
|
limit=limit,
|
|
381
393
|
with_vector=with_vectors,
|
|
394
|
+
include_payload=include_payload,
|
|
382
395
|
)
|
|
383
396
|
for query_vector in query_vectors
|
|
384
397
|
]
|
|
@@ -0,0 +1,86 @@
|
|
|
1
|
+
from uuid import UUID
|
|
2
|
+
from typing import Optional
|
|
3
|
+
|
|
4
|
+
from cognee.infrastructure.databases.vector.create_vector_engine import create_vector_engine
|
|
5
|
+
from cognee.infrastructure.databases.vector.pgvector.create_db_and_tables import delete_pg_database
|
|
6
|
+
from cognee.modules.users.models import User
|
|
7
|
+
from cognee.modules.users.models import DatasetDatabase
|
|
8
|
+
from cognee.infrastructure.databases.vector import get_vectordb_config
|
|
9
|
+
from cognee.infrastructure.databases.dataset_database_handler import DatasetDatabaseHandlerInterface
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class PGVectorDatasetDatabaseHandler(DatasetDatabaseHandlerInterface):
    """
    Handler for interacting with PGVector Dataset databases.

    Every dataset is given its own Postgres database, named after the dataset id.
    Credentials are always taken from the vector DB configuration at call time and
    are not part of the settings returned by `create_dataset`.
    """

    @classmethod
    async def create_dataset(cls, dataset_id: Optional[UUID], user: Optional[User]) -> dict:
        """
        Create a Postgres database for the dataset and return its connection settings.

        Parameters:
        -----------

            - dataset_id (Optional[UUID]): Identifier of the dataset; its string form is
              used as the database name.
            - user (Optional[User]): Not read by this implementation.

        Returns:
        --------

            - dict: Provider, url, database name, connection info (port and host) and
              the handler name. Username and password are intentionally left out.

        Raises:
        -------

            - ValueError: If the configured vector database provider is not "pgvector".
        """
        vector_config = get_vectordb_config()

        if vector_config.vector_db_provider != "pgvector":
            raise ValueError(
                "PGVectorDatasetDatabaseHandler can only be used with PGVector vector database provider."
            )

        vector_db_name = f"{dataset_id}"

        new_vector_config = {
            "vector_database_provider": vector_config.vector_db_provider,
            "vector_database_url": vector_config.vector_db_url,
            "vector_database_name": vector_db_name,
            "vector_database_connection_info": {
                "port": vector_config.vector_db_port,
                "host": vector_config.vector_db_host,
            },
            "vector_dataset_database_handler": "pgvector",
        }

        from .create_db_and_tables import create_pg_database

        # create_pg_database expects the "vector_db_*" key naming and needs the
        # credentials, which are deliberately absent from new_vector_config.
        await create_pg_database(
            {
                "vector_db_provider": new_vector_config["vector_database_provider"],
                "vector_db_url": new_vector_config["vector_database_url"],
                "vector_db_name": new_vector_config["vector_database_name"],
                "vector_db_port": new_vector_config["vector_database_connection_info"]["port"],
                "vector_db_key": "",
                "vector_db_username": vector_config.vector_db_username,
                "vector_db_password": vector_config.vector_db_password,
                "vector_db_host": new_vector_config["vector_database_connection_info"]["host"],
                "vector_dataset_database_handler": "pgvector",
            }
        )

        return new_vector_config

    @classmethod
    async def resolve_dataset_connection_info(
        cls, dataset_database: DatasetDatabase
    ) -> DatasetDatabase:
        """
        Fill in the username and password on the dataset's connection info.

        The passed object's `vector_database_connection_info` mapping is updated in
        place and the same object is returned.

        Parameters:
        -----------

            - dataset_database (DatasetDatabase): Record whose connection info should
              receive the credentials.

        Returns:
        --------

            - DatasetDatabase: The same record, now carrying "username" and "password".
        """
        vector_config = get_vectordb_config()
        # Note: For PGVector, we use the vector DB username/password from configuration so it's never stored in the DB
        dataset_database.vector_database_connection_info["username"] = (
            vector_config.vector_db_username
        )
        dataset_database.vector_database_connection_info["password"] = (
            vector_config.vector_db_password
        )
        return dataset_database

    @classmethod
    async def delete_dataset(cls, dataset_database: DatasetDatabase):
        """
        Prune the dataset's vector data and then drop its Postgres database.

        Parameters:
        -----------

            - dataset_database (DatasetDatabase): Record describing the database to
              remove. Its connection info is completed with credentials first.
        """
        dataset_database = await cls.resolve_dataset_connection_info(dataset_database)
        connection_info = dataset_database.vector_database_connection_info
        vector_engine = create_vector_engine(
            vector_db_provider=dataset_database.vector_database_provider,
            vector_db_url=dataset_database.vector_database_url,
            vector_db_name=dataset_database.vector_database_name,
            vector_db_port=connection_info["port"],
            # Pass the stored host as well, matching what create_dataset hands to
            # create_pg_database, so pruning targets the server the database lives on.
            vector_db_host=connection_info["host"],
            vector_db_username=connection_info["username"],
            vector_db_password=connection_info["password"],
        )
        # Prune data
        await vector_engine.prune()

        # Drop entire database
        await delete_pg_database(dataset_database)
|
|
@@ -1,5 +1,10 @@
|
|
|
1
|
-
from sqlalchemy import text
|
|
1
|
+
from sqlalchemy import text, URL
|
|
2
|
+
from sqlalchemy.ext.asyncio import create_async_engine
|
|
3
|
+
|
|
4
|
+
from cognee.modules.users.models import DatasetDatabase
|
|
2
5
|
from ..get_vector_engine import get_vector_engine, get_vectordb_context_config
|
|
6
|
+
from ...vector import get_vectordb_config
|
|
7
|
+
from cognee.context_global_variables import backend_access_control_enabled
|
|
3
8
|
|
|
4
9
|
|
|
5
10
|
async def create_db_and_tables():
|
|
@@ -7,6 +12,80 @@ async def create_db_and_tables():
|
|
|
7
12
|
vector_config = get_vectordb_context_config()
|
|
8
13
|
vector_engine = get_vector_engine()
|
|
9
14
|
|
|
10
|
-
if vector_config["vector_db_provider"] == "pgvector":
|
|
15
|
+
if vector_config["vector_db_provider"] == "pgvector" and not backend_access_control_enabled():
|
|
11
16
|
async with vector_engine.engine.begin() as connection:
|
|
12
17
|
await connection.execute(text("CREATE EXTENSION IF NOT EXISTS vector;"))
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
async def create_pg_database(vector_config):
    """
    Create the necessary Postgres database, and the PGVector extension on it.
    This is defined separately because the creation needs the latest vector config,
    which is not yet saved in the vector config context variable.

    Parameters:
    -----------

        - vector_config (dict): Mapping with the "vector_db_*" keys. This function reads
          "vector_db_username", "vector_db_password", "vector_db_host",
          "vector_db_port" and "vector_db_name"; the whole mapping is also forwarded
          as keyword arguments to `create_vector_engine`.
    """

    from cognee.infrastructure.databases.vector.create_vector_engine import create_vector_engine

    # Create a maintenance engine, used when creating new postgres databases.
    # Database named "postgres" should always exist. We need this since the SQLAlchemy
    # engine cannot directly execute queries without first connecting to a database.
    maintenance_db_name = "postgres"
    maintenance_db_url = URL.create(
        "postgresql+asyncpg",
        username=vector_config["vector_db_username"],
        password=vector_config["vector_db_password"],
        host=vector_config["vector_db_host"],
        port=int(vector_config["vector_db_port"]),
        database=maintenance_db_name,
    )
    maintenance_engine = create_async_engine(maintenance_db_url)

    # CREATE DATABASE cannot take the name as a bind parameter, so it is embedded as
    # a quoted identifier; double any embedded quote so it cannot end the identifier.
    quoted_db_name = str(vector_config["vector_db_name"]).replace('"', '""')

    try:
        # Connect to maintenance db in order to create new database
        # Make sure to execute CREATE DATABASE outside of transaction block, and set AUTOCOMMIT isolation level
        async with maintenance_engine.connect() as connection:
            autocommit_connection = await connection.execution_options(
                isolation_level="AUTOCOMMIT"
            )
            await autocommit_connection.execute(text(f'CREATE DATABASE "{quoted_db_name}";'))
    finally:
        # Clean up resources: the maintenance engine is single-use, so release its
        # connection pool whether or not the statement succeeded.
        await maintenance_engine.dispose()

    vector_engine = create_vector_engine(**vector_config)
    async with vector_engine.engine.begin() as connection:
        await connection.execute(text("CREATE EXTENSION IF NOT EXISTS vector;"))
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
async def delete_pg_database(dataset_database: DatasetDatabase):
    """
    Delete the Postgres database that was created for the PGVector extension,
    in case of using PGVector with multi-user mode.

    Parameters:
    -----------

        - dataset_database (DatasetDatabase): Record whose `vector_database_name` is
          the database to drop. Host, port and credentials come from the vector DB
          configuration, not from this record.
    """

    vector_config = get_vectordb_config()
    # Create a maintenance engine, used when creating new postgres databases.
    # Database named "postgres" should always exist. We need this since the SQLAlchemy
    # engine cannot drop a database to which it is connected.
    maintenance_db_name = "postgres"
    maintenance_db_url = URL.create(
        "postgresql+asyncpg",
        username=vector_config.vector_db_username,
        password=vector_config.vector_db_password,
        host=vector_config.vector_db_host,
        port=int(vector_config.vector_db_port),
        database=maintenance_db_name,
    )
    maintenance_engine = create_async_engine(maintenance_db_url)

    target_db_name = dataset_database.vector_database_name
    # DROP DATABASE cannot take the name as a bind parameter, so it is embedded as a
    # quoted identifier; double any embedded quote so it cannot end the identifier.
    quoted_db_name = str(target_db_name).replace('"', '""')

    try:
        async with maintenance_engine.connect() as connection:
            # DROP DATABASE must run outside of a transaction block.
            autocommit_connection = await connection.execution_options(
                isolation_level="AUTOCOMMIT"
            )
            # We first have to kill all active sessions on the database, then delete it
            await autocommit_connection.execute(
                text(
                    "SELECT pg_terminate_backend(pid) "
                    "FROM pg_stat_activity "
                    "WHERE datname = :db AND pid <> pg_backend_pid()"
                ),
                {"db": target_db_name},
            )
            await autocommit_connection.execute(text(f'DROP DATABASE "{quoted_db_name}";'))
    finally:
        # The maintenance engine is single-use, so release its connection pool
        # whether or not the statements succeeded.
        await maintenance_engine.dispose()
|
|
@@ -87,6 +87,7 @@ class VectorDBInterface(Protocol):
|
|
|
87
87
|
query_vector: Optional[List[float]],
|
|
88
88
|
limit: Optional[int],
|
|
89
89
|
with_vector: bool = False,
|
|
90
|
+
include_payload: bool = False,
|
|
90
91
|
):
|
|
91
92
|
"""
|
|
92
93
|
Perform a search in the specified collection using either a text query or a vector
|
|
@@ -103,6 +104,9 @@ class VectorDBInterface(Protocol):
|
|
|
103
104
|
- limit (Optional[int]): The maximum number of results to return from the search.
|
|
104
105
|
- with_vector (bool): Whether to return the vector representations with search
|
|
105
106
|
results. (default False)
|
|
107
|
+
- include_payload (bool): Whether to include the payload data with search. Search is faster when set to False.
|
|
108
|
+
Payload contains metadata about the data point, useful for searches that are only based on embedding distances
|
|
109
|
+
like the RAG_COMPLETION search type, but not needed when search also contains graph data.
|
|
106
110
|
"""
|
|
107
111
|
raise NotImplementedError
|
|
108
112
|
|
|
@@ -113,6 +117,7 @@ class VectorDBInterface(Protocol):
|
|
|
113
117
|
query_texts: List[str],
|
|
114
118
|
limit: Optional[int],
|
|
115
119
|
with_vectors: bool = False,
|
|
120
|
+
include_payload: bool = False,
|
|
116
121
|
):
|
|
117
122
|
"""
|
|
118
123
|
Perform a batch search using multiple text queries against a collection.
|
|
@@ -125,6 +130,9 @@ class VectorDBInterface(Protocol):
|
|
|
125
130
|
- limit (Optional[int]): The maximum number of results to return for each query.
|
|
126
131
|
- with_vectors (bool): Whether to include vector representations with search
|
|
127
132
|
results. (default False)
|
|
133
|
+
- include_payload (bool): Whether to include the payload data with search. Search is faster when set to False.
|
|
134
|
+
Payload contains metadata about the data point, useful for searches that are only based on embedding distances
|
|
135
|
+
like the RAG_COMPLETION search type, but not needed when search also contains graph data.
|
|
128
136
|
"""
|
|
129
137
|
raise NotImplementedError
|
|
130
138
|
|