cognee 0.5.1__py3-none-any.whl → 0.5.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cognee/__init__.py +2 -0
- cognee/alembic/README +1 -0
- cognee/alembic/env.py +107 -0
- cognee/alembic/script.py.mako +26 -0
- cognee/alembic/versions/1a58b986e6e1_enable_delete_for_old_tutorial_notebooks.py +52 -0
- cognee/alembic/versions/1d0bb7fede17_add_pipeline_run_status.py +33 -0
- cognee/alembic/versions/1daae0df1866_incremental_loading.py +48 -0
- cognee/alembic/versions/211ab850ef3d_add_sync_operations_table.py +118 -0
- cognee/alembic/versions/45957f0a9849_add_notebook_table.py +46 -0
- cognee/alembic/versions/46a6ce2bd2b2_expand_dataset_database_with_json_.py +333 -0
- cognee/alembic/versions/482cd6517ce4_add_default_user.py +30 -0
- cognee/alembic/versions/76625596c5c3_expand_dataset_database_for_multi_user.py +98 -0
- cognee/alembic/versions/8057ae7329c2_initial_migration.py +25 -0
- cognee/alembic/versions/9e7a3cb85175_loader_separation.py +104 -0
- cognee/alembic/versions/a1b2c3d4e5f6_add_label_column_to_data.py +38 -0
- cognee/alembic/versions/ab7e313804ae_permission_system_rework.py +236 -0
- cognee/alembic/versions/b9274c27a25a_kuzu_11_migration.py +75 -0
- cognee/alembic/versions/c946955da633_multi_tenant_support.py +137 -0
- cognee/alembic/versions/e1ec1dcb50b6_add_last_accessed_to_data.py +51 -0
- cognee/alembic/versions/e4ebee1091e7_expand_data_model_info.py +140 -0
- cognee/alembic.ini +117 -0
- cognee/api/v1/add/add.py +2 -1
- cognee/api/v1/add/routers/get_add_router.py +2 -0
- cognee/api/v1/cognify/cognify.py +11 -6
- cognee/api/v1/cognify/routers/get_cognify_router.py +8 -0
- cognee/api/v1/config/config.py +60 -0
- cognee/api/v1/datasets/routers/get_datasets_router.py +46 -3
- cognee/api/v1/memify/routers/get_memify_router.py +3 -0
- cognee/api/v1/search/routers/get_search_router.py +21 -6
- cognee/api/v1/search/search.py +21 -5
- cognee/api/v1/sync/routers/get_sync_router.py +3 -3
- cognee/cli/commands/add_command.py +1 -1
- cognee/cli/commands/cognify_command.py +6 -0
- cognee/cli/commands/config_command.py +1 -1
- cognee/context_global_variables.py +5 -1
- cognee/eval_framework/answer_generation/answer_generation_executor.py +7 -8
- cognee/infrastructure/databases/cache/cache_db_interface.py +38 -1
- cognee/infrastructure/databases/cache/config.py +6 -0
- cognee/infrastructure/databases/cache/fscache/FsCacheAdapter.py +21 -0
- cognee/infrastructure/databases/cache/get_cache_engine.py +9 -3
- cognee/infrastructure/databases/cache/redis/RedisAdapter.py +60 -1
- cognee/infrastructure/databases/dataset_database_handler/supported_dataset_database_handlers.py +7 -0
- cognee/infrastructure/databases/graph/get_graph_engine.py +29 -1
- cognee/infrastructure/databases/graph/neo4j_driver/Neo4jAuraDevDatasetDatabaseHandler.py +62 -27
- cognee/infrastructure/databases/hybrid/neptune_analytics/NeptuneAnalyticsAdapter.py +17 -4
- cognee/infrastructure/databases/relational/config.py +16 -1
- cognee/infrastructure/databases/relational/create_relational_engine.py +13 -3
- cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py +26 -3
- cognee/infrastructure/databases/vector/chromadb/ChromaDBAdapter.py +2 -0
- cognee/infrastructure/databases/vector/config.py +6 -0
- cognee/infrastructure/databases/vector/create_vector_engine.py +70 -16
- cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py +64 -9
- cognee/infrastructure/databases/vector/embeddings/OllamaEmbeddingEngine.py +13 -2
- cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py +16 -3
- cognee/infrastructure/databases/vector/models/ScoredResult.py +3 -3
- cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py +16 -3
- cognee/infrastructure/databases/vector/pgvector/PGVectorDatasetDatabaseHandler.py +86 -0
- cognee/infrastructure/databases/vector/pgvector/create_db_and_tables.py +81 -2
- cognee/infrastructure/databases/vector/vector_db_interface.py +8 -0
- cognee/infrastructure/files/utils/get_data_file_path.py +33 -27
- cognee/infrastructure/llm/LLMGateway.py +0 -13
- cognee/infrastructure/llm/prompts/extract_query_time.txt +1 -1
- cognee/infrastructure/llm/prompts/generate_event_entity_prompt.txt +1 -1
- cognee/infrastructure/llm/prompts/generate_event_graph_prompt.txt +1 -1
- cognee/infrastructure/llm/prompts/generate_graph_prompt.txt +2 -2
- cognee/infrastructure/llm/prompts/generate_graph_prompt_guided.txt +1 -1
- cognee/infrastructure/llm/prompts/generate_graph_prompt_oneshot.txt +2 -2
- cognee/infrastructure/llm/prompts/generate_graph_prompt_simple.txt +1 -1
- cognee/infrastructure/llm/prompts/generate_graph_prompt_strict.txt +1 -1
- cognee/infrastructure/llm/prompts/search_type_selector_prompt.txt +6 -6
- cognee/infrastructure/llm/prompts/test.txt +1 -1
- cognee/infrastructure/llm/prompts/translate_content.txt +19 -0
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py +17 -12
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py +31 -25
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py +132 -7
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py +29 -5
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/llama_cpp/adapter.py +191 -0
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/llm_interface.py +2 -6
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py +58 -13
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/adapter.py +0 -1
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py +25 -131
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/types.py +10 -0
- cognee/modules/chunking/models/DocumentChunk.py +0 -1
- cognee/modules/cognify/config.py +2 -0
- cognee/modules/data/models/Data.py +3 -1
- cognee/modules/engine/models/Entity.py +0 -1
- cognee/modules/engine/operations/setup.py +6 -0
- cognee/modules/graph/cognee_graph/CogneeGraph.py +150 -37
- cognee/modules/graph/cognee_graph/CogneeGraphElements.py +48 -2
- cognee/modules/graph/utils/__init__.py +1 -0
- cognee/modules/graph/utils/get_entity_nodes_from_triplets.py +12 -0
- cognee/modules/notebooks/methods/__init__.py +1 -0
- cognee/modules/notebooks/methods/create_notebook.py +0 -34
- cognee/modules/notebooks/methods/create_tutorial_notebooks.py +191 -0
- cognee/modules/notebooks/methods/get_notebooks.py +12 -8
- cognee/modules/notebooks/tutorials/cognee-basics/cell-1.md +3 -0
- cognee/modules/notebooks/tutorials/cognee-basics/cell-2.md +10 -0
- cognee/modules/notebooks/tutorials/cognee-basics/cell-3.md +7 -0
- cognee/modules/notebooks/tutorials/cognee-basics/cell-4.py +28 -0
- cognee/modules/notebooks/tutorials/cognee-basics/cell-5.py +3 -0
- cognee/modules/notebooks/tutorials/cognee-basics/cell-6.py +9 -0
- cognee/modules/notebooks/tutorials/cognee-basics/cell-7.py +17 -0
- cognee/modules/notebooks/tutorials/cognee-basics/config.json +4 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-1.md +3 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-10.md +3 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-11.md +3 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-12.py +3 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-13.md +7 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-14.py +6 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-15.md +3 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-16.py +7 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-2.md +9 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-3.md +7 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-4.md +9 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-5.md +5 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-6.py +13 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-7.md +3 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-8.md +3 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-9.py +31 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/config.json +4 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/data/copilot_conversations.json +107 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/data/guido_contributions.json +976 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/data/my_developer_rules.md +79 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/data/pep_style_guide.md +74 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/data/zen_principles.md +74 -0
- cognee/modules/retrieval/EntityCompletionRetriever.py +51 -38
- cognee/modules/retrieval/__init__.py +0 -1
- cognee/modules/retrieval/base_retriever.py +66 -10
- cognee/modules/retrieval/chunks_retriever.py +57 -49
- cognee/modules/retrieval/coding_rules_retriever.py +12 -5
- cognee/modules/retrieval/completion_retriever.py +29 -28
- cognee/modules/retrieval/cypher_search_retriever.py +25 -20
- cognee/modules/retrieval/graph_completion_context_extension_retriever.py +42 -46
- cognee/modules/retrieval/graph_completion_cot_retriever.py +68 -51
- cognee/modules/retrieval/graph_completion_retriever.py +78 -63
- cognee/modules/retrieval/graph_summary_completion_retriever.py +2 -0
- cognee/modules/retrieval/lexical_retriever.py +34 -12
- cognee/modules/retrieval/natural_language_retriever.py +18 -15
- cognee/modules/retrieval/summaries_retriever.py +51 -34
- cognee/modules/retrieval/temporal_retriever.py +59 -49
- cognee/modules/retrieval/triplet_retriever.py +32 -33
- cognee/modules/retrieval/utils/access_tracking.py +88 -0
- cognee/modules/retrieval/utils/brute_force_triplet_search.py +99 -103
- cognee/modules/retrieval/utils/node_edge_vector_search.py +174 -0
- cognee/modules/search/methods/__init__.py +1 -0
- cognee/modules/search/methods/get_retriever_output.py +53 -0
- cognee/modules/search/methods/get_search_type_retriever_instance.py +252 -0
- cognee/modules/search/methods/search.py +90 -222
- cognee/modules/search/models/SearchResultPayload.py +67 -0
- cognee/modules/search/types/SearchResult.py +1 -8
- cognee/modules/search/types/SearchType.py +1 -2
- cognee/modules/search/types/__init__.py +1 -1
- cognee/modules/search/utils/__init__.py +1 -2
- cognee/modules/search/utils/transform_insights_to_graph.py +2 -2
- cognee/modules/search/utils/{transform_context_to_graph.py → transform_triplets_to_graph.py} +2 -2
- cognee/modules/users/authentication/default/default_transport.py +11 -1
- cognee/modules/users/authentication/get_api_auth_backend.py +2 -1
- cognee/modules/users/authentication/get_client_auth_backend.py +2 -1
- cognee/modules/users/methods/create_user.py +0 -9
- cognee/modules/users/permissions/methods/has_user_management_permission.py +29 -0
- cognee/modules/visualization/cognee_network_visualization.py +1 -1
- cognee/run_migrations.py +48 -0
- cognee/shared/exceptions/__init__.py +1 -3
- cognee/shared/exceptions/exceptions.py +11 -1
- cognee/shared/usage_logger.py +332 -0
- cognee/shared/utils.py +12 -5
- cognee/tasks/chunks/__init__.py +9 -0
- cognee/tasks/cleanup/cleanup_unused_data.py +172 -0
- cognee/tasks/graph/__init__.py +7 -0
- cognee/tasks/ingestion/data_item.py +8 -0
- cognee/tasks/ingestion/ingest_data.py +12 -1
- cognee/tasks/ingestion/save_data_item_to_storage.py +5 -0
- cognee/tasks/memify/__init__.py +8 -0
- cognee/tasks/memify/extract_usage_frequency.py +613 -0
- cognee/tasks/summarization/models.py +0 -2
- cognee/tasks/temporal_graph/__init__.py +0 -1
- cognee/tasks/translation/__init__.py +96 -0
- cognee/tasks/translation/config.py +110 -0
- cognee/tasks/translation/detect_language.py +190 -0
- cognee/tasks/translation/exceptions.py +62 -0
- cognee/tasks/translation/models.py +72 -0
- cognee/tasks/translation/providers/__init__.py +44 -0
- cognee/tasks/translation/providers/azure_provider.py +192 -0
- cognee/tasks/translation/providers/base.py +85 -0
- cognee/tasks/translation/providers/google_provider.py +158 -0
- cognee/tasks/translation/providers/llm_provider.py +143 -0
- cognee/tasks/translation/translate_content.py +282 -0
- cognee/tasks/web_scraper/default_url_crawler.py +6 -2
- cognee/tests/cli_tests/cli_unit_tests/test_cli_commands.py +1 -0
- cognee/tests/cli_tests/cli_unit_tests/test_cli_edge_cases.py +3 -0
- cognee/tests/integration/retrieval/test_brute_force_triplet_search_with_cognify.py +62 -0
- cognee/tests/integration/retrieval/test_chunks_retriever.py +351 -0
- cognee/tests/integration/retrieval/test_graph_completion_retriever.py +276 -0
- cognee/tests/integration/retrieval/test_graph_completion_retriever_context_extension.py +228 -0
- cognee/tests/integration/retrieval/test_graph_completion_retriever_cot.py +217 -0
- cognee/tests/integration/retrieval/test_rag_completion_retriever.py +319 -0
- cognee/tests/integration/retrieval/test_structured_output.py +258 -0
- cognee/tests/integration/retrieval/test_summaries_retriever.py +195 -0
- cognee/tests/integration/retrieval/test_temporal_retriever.py +336 -0
- cognee/tests/integration/retrieval/test_triplet_retriever.py +45 -1
- cognee/tests/integration/shared/test_usage_logger_integration.py +255 -0
- cognee/tests/tasks/translation/README.md +147 -0
- cognee/tests/tasks/translation/__init__.py +1 -0
- cognee/tests/tasks/translation/config_test.py +93 -0
- cognee/tests/tasks/translation/detect_language_test.py +118 -0
- cognee/tests/tasks/translation/providers_test.py +151 -0
- cognee/tests/tasks/translation/translate_content_test.py +213 -0
- cognee/tests/test_chromadb.py +1 -1
- cognee/tests/test_cleanup_unused_data.py +165 -0
- cognee/tests/test_custom_data_label.py +68 -0
- cognee/tests/test_delete_by_id.py +6 -6
- cognee/tests/test_extract_usage_frequency.py +308 -0
- cognee/tests/test_kuzu.py +17 -7
- cognee/tests/test_lancedb.py +3 -1
- cognee/tests/test_library.py +1 -1
- cognee/tests/test_neo4j.py +17 -7
- cognee/tests/test_neptune_analytics_vector.py +3 -1
- cognee/tests/test_permissions.py +172 -187
- cognee/tests/test_pgvector.py +3 -1
- cognee/tests/test_relational_db_migration.py +15 -1
- cognee/tests/test_remote_kuzu.py +3 -1
- cognee/tests/test_s3_file_storage.py +1 -1
- cognee/tests/test_search_db.py +345 -205
- cognee/tests/test_usage_logger_e2e.py +268 -0
- cognee/tests/unit/api/test_get_raw_data_endpoint.py +206 -0
- cognee/tests/unit/eval_framework/answer_generation_test.py +4 -3
- cognee/tests/unit/eval_framework/benchmark_adapters_test.py +25 -0
- cognee/tests/unit/eval_framework/corpus_builder_test.py +33 -4
- cognee/tests/unit/infrastructure/databases/cache/test_cache_config.py +2 -0
- cognee/tests/unit/infrastructure/databases/relational/test_RelationalConfig.py +69 -0
- cognee/tests/unit/modules/graph/cognee_graph_elements_test.py +42 -2
- cognee/tests/unit/modules/graph/cognee_graph_test.py +329 -31
- cognee/tests/unit/modules/retrieval/chunks_retriever_test.py +122 -168
- cognee/tests/unit/modules/retrieval/conversation_history_test.py +338 -0
- cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py +486 -157
- cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py +693 -155
- cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py +619 -200
- cognee/tests/unit/modules/retrieval/rag_completion_retriever_test.py +300 -171
- cognee/tests/unit/modules/retrieval/summaries_retriever_test.py +184 -155
- cognee/tests/unit/modules/retrieval/temporal_retriever_test.py +544 -79
- cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py +476 -28
- cognee/tests/unit/modules/retrieval/test_completion.py +343 -0
- cognee/tests/unit/modules/retrieval/test_graph_summary_completion_retriever.py +157 -0
- cognee/tests/unit/modules/retrieval/test_node_edge_vector_search.py +273 -0
- cognee/tests/unit/modules/retrieval/test_user_qa_feedback.py +312 -0
- cognee/tests/unit/modules/retrieval/triplet_retriever_test.py +267 -7
- cognee/tests/unit/modules/search/test_get_search_type_retriever_instance.py +125 -0
- cognee/tests/unit/modules/search/test_search.py +96 -20
- cognee/tests/unit/modules/search/test_search_prepare_search_result_contract.py +190 -0
- cognee/tests/unit/modules/users/test_tutorial_notebook_creation.py +511 -297
- cognee/tests/unit/shared/test_usage_logger.py +241 -0
- cognee/tests/unit/users/permissions/test_has_user_management_permission.py +46 -0
- {cognee-0.5.1.dist-info → cognee-0.5.2.dist-info}/METADATA +22 -17
- {cognee-0.5.1.dist-info → cognee-0.5.2.dist-info}/RECORD +258 -157
- cognee/api/.env.example +0 -5
- cognee/modules/retrieval/base_graph_retriever.py +0 -24
- cognee/modules/search/methods/get_search_type_tools.py +0 -223
- cognee/modules/search/methods/no_access_control_search.py +0 -62
- cognee/modules/search/utils/prepare_search_result.py +0 -63
- cognee/tests/test_feedback_enrichment.py +0 -174
- cognee/tests/unit/modules/retrieval/structured_output_test.py +0 -204
- {cognee-0.5.1.dist-info → cognee-0.5.2.dist-info}/WHEEL +0 -0
- {cognee-0.5.1.dist-info → cognee-0.5.2.dist-info}/entry_points.txt +0 -0
- {cognee-0.5.1.dist-info → cognee-0.5.2.dist-info}/licenses/LICENSE +0 -0
- {cognee-0.5.1.dist-info → cognee-0.5.2.dist-info}/licenses/NOTICE.md +0 -0
|
@@ -0,0 +1,252 @@
|
|
|
1
|
+
import os
|
|
2
|
+
from typing import Callable, List, Optional, Type, Tuple
|
|
3
|
+
|
|
4
|
+
from cognee.modules.retrieval.base_retriever import BaseRetriever
|
|
5
|
+
|
|
6
|
+
from cognee.modules.engine.models.node_set import NodeSet
|
|
7
|
+
from cognee.modules.retrieval.triplet_retriever import TripletRetriever
|
|
8
|
+
from cognee.modules.search.types import SearchType
|
|
9
|
+
from cognee.modules.search.operations import select_search_type
|
|
10
|
+
from cognee.modules.search.exceptions import UnsupportedSearchTypeError
|
|
11
|
+
|
|
12
|
+
# Retrievers
|
|
13
|
+
from cognee.modules.retrieval.user_qa_feedback import UserQAFeedback
|
|
14
|
+
from cognee.modules.retrieval.chunks_retriever import ChunksRetriever
|
|
15
|
+
from cognee.modules.retrieval.summaries_retriever import SummariesRetriever
|
|
16
|
+
from cognee.modules.retrieval.completion_retriever import CompletionRetriever
|
|
17
|
+
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
|
|
18
|
+
from cognee.modules.retrieval.temporal_retriever import TemporalRetriever
|
|
19
|
+
from cognee.modules.retrieval.coding_rules_retriever import CodingRulesRetriever
|
|
20
|
+
from cognee.modules.retrieval.jaccard_retrival import JaccardChunksRetriever
|
|
21
|
+
from cognee.modules.retrieval.graph_summary_completion_retriever import (
|
|
22
|
+
GraphSummaryCompletionRetriever,
|
|
23
|
+
)
|
|
24
|
+
from cognee.modules.retrieval.graph_completion_cot_retriever import GraphCompletionCotRetriever
|
|
25
|
+
from cognee.modules.retrieval.graph_completion_context_extension_retriever import (
|
|
26
|
+
GraphCompletionContextExtensionRetriever,
|
|
27
|
+
)
|
|
28
|
+
from cognee.modules.retrieval.cypher_search_retriever import CypherSearchRetriever
|
|
29
|
+
from cognee.modules.retrieval.natural_language_retriever import NaturalLanguageRetriever
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
async def get_search_type_retriever_instance(
    query_type: SearchType,
    query_text: str,
    **kwargs,
) -> BaseRetriever:
    """
    Factory method to get the appropriate retriever instance based on the search type.

    Args:
        query_type: SearchType enum indicating the type of search. When it is
            SearchType.FEELING_LUCKY the concrete type is obtained by awaiting
            select_search_type(query_text).
        query_text: query string. Only used here to resolve FEELING_LUCKY.
        **kwargs: General keyword arguments for retriever initialization. The keys read
            here, with the fallback used when the key is absent, are:
            top_k (10), system_prompt_path ("answer_simple_question.txt"),
            system_prompt (None), node_type (NodeSet), node_name (None),
            save_interaction (False), wide_search_top_k (100),
            triplet_distance_penalty (3.5), session_id (None) and
            retriever_specific_config (a dict of retriever specific options; a missing
            or None value is treated as an empty dict).
            A key that is present with an explicit None value is passed through as None;
            the fallbacks only apply to absent keys.

    Returns:
        The retriever built for the resolved search type. Core retrievers are constructed
        from the arguments listed in the registry below; community retrievers are
        constructed with the raw **kwargs.

    Raises:
        UnsupportedSearchTypeError: If the resolved type is CYPHER or NATURAL_LANGUAGE while
            the ALLOW_CYPHER_QUERY environment variable is set to "false"
            (case-insensitive), or if no retriever is registered for the resolved type.
    """
    # Transform retriever specific config if empty to avoid None checks later
    retriever_specific_config = kwargs.get("retriever_specific_config")
    if retriever_specific_config is None:
        retriever_specific_config = {}

    # Extract common defaults with fallback values from kwargs
    top_k = kwargs.get("top_k", 10)
    system_prompt_path = kwargs.get("system_prompt_path", "answer_simple_question.txt")
    system_prompt = kwargs.get("system_prompt")
    node_type = kwargs.get("node_type", NodeSet)
    node_name = kwargs.get("node_name")
    save_interaction = kwargs.get("save_interaction", False)
    wide_search_top_k = kwargs.get("wide_search_top_k", 100)
    triplet_distance_penalty = kwargs.get("triplet_distance_penalty", 3.5)
    session_id = kwargs.get("session_id")

    # Looked up once; every completion-style retriever receives the same value
    response_model = retriever_specific_config.get("response_model", str)

    # Arguments shared by the plain (non-graph) completion retrievers
    completion_args = {
        "system_prompt_path": system_prompt_path,
        "top_k": top_k,
        "system_prompt": system_prompt,
        "session_id": session_id,
        "response_model": response_model,
    }

    # Arguments shared by every graph completion retriever variant.
    # "response_model" is added per entry because GRAPH_SUMMARY_COMPLETION does not take it.
    graph_args = {
        "system_prompt_path": system_prompt_path,
        "top_k": top_k,
        "node_type": node_type,
        "node_name": node_name,
        "save_interaction": save_interaction,
        "system_prompt": system_prompt,
        "wide_search_top_k": wide_search_top_k,
        "triplet_distance_penalty": triplet_distance_penalty,
        "session_id": session_id,
    }

    # Registry mapping search types to their corresponding retriever classes and input parameters.
    # The first tuple element is the class itself (not an instance), hence Type[BaseRetriever].
    search_core_registry: dict[SearchType, Tuple[Type[BaseRetriever], dict]] = {
        SearchType.SUMMARIES: (SummariesRetriever, {"top_k": top_k, "session_id": session_id}),
        SearchType.CHUNKS: (ChunksRetriever, {"top_k": top_k}),
        SearchType.RAG_COMPLETION: (CompletionRetriever, dict(completion_args)),
        SearchType.TRIPLET_COMPLETION: (TripletRetriever, dict(completion_args)),
        SearchType.GRAPH_COMPLETION: (
            GraphCompletionRetriever,
            {**graph_args, "response_model": response_model},
        ),
        SearchType.GRAPH_COMPLETION_COT: (
            GraphCompletionCotRetriever,
            {
                **graph_args,
                "max_iter": retriever_specific_config.get("max_iter", 4),
                "validation_system_prompt_path": retriever_specific_config.get(
                    "validation_system_prompt_path", "cot_validation_system_prompt.txt"
                ),
                "validation_user_prompt_path": retriever_specific_config.get(
                    "validation_user_prompt_path", "cot_validation_user_prompt.txt"
                ),
                "followup_system_prompt_path": retriever_specific_config.get(
                    "followup_system_prompt_path", "cot_followup_system_prompt.txt"
                ),
                "followup_user_prompt_path": retriever_specific_config.get(
                    "followup_user_prompt_path", "cot_followup_user_prompt.txt"
                ),
                "response_model": response_model,
            },
        ),
        SearchType.GRAPH_COMPLETION_CONTEXT_EXTENSION: (
            GraphCompletionContextExtensionRetriever,
            {
                **graph_args,
                "context_extension_rounds": retriever_specific_config.get(
                    "context_extension_rounds", 4
                ),
                "response_model": response_model,
            },
        ),
        SearchType.GRAPH_SUMMARY_COMPLETION: (
            GraphSummaryCompletionRetriever,
            {
                **graph_args,
                "summarize_prompt_path": retriever_specific_config.get(
                    "summarize_prompt_path", "summarize_search_results.txt"
                ),
            },
        ),
        # NOTE: the three entries below read their prompt paths from
        # retriever_specific_config, not from the general system_prompt_path kwarg.
        SearchType.CYPHER: (
            CypherSearchRetriever,
            {
                "user_prompt_path": retriever_specific_config.get(
                    "user_prompt_path", "context_for_question.txt"
                ),
                "system_prompt_path": retriever_specific_config.get(
                    "system_prompt_path", "answer_simple_question.txt"
                ),
                "session_id": session_id,
            },
        ),
        SearchType.NATURAL_LANGUAGE: (
            NaturalLanguageRetriever,
            {
                "session_id": session_id,
                "system_prompt_path": retriever_specific_config.get(
                    "system_prompt_path", "natural_language_retriever_system.txt"
                ),
                "max_attempts": retriever_specific_config.get("max_attempts", 3),
            },
        ),
        SearchType.TEMPORAL: (
            TemporalRetriever,
            {
                "top_k": top_k,
                "wide_search_top_k": wide_search_top_k,
                "triplet_distance_penalty": triplet_distance_penalty,
                "session_id": session_id,
                "response_model": response_model,
                "user_prompt_path": retriever_specific_config.get(
                    "user_prompt_path", "graph_context_for_question.txt"
                ),
                "system_prompt_path": retriever_specific_config.get(
                    "system_prompt_path", "answer_simple_question.txt"
                ),
                "time_extraction_prompt_path": retriever_specific_config.get(
                    "time_extraction_prompt_path", "extract_query_time.txt"
                ),
                "node_type": node_type,
                "node_name": node_name,
            },
        ),
        SearchType.CHUNKS_LEXICAL: (JaccardChunksRetriever, {"top_k": top_k}),
        SearchType.CODING_RULES: (
            CodingRulesRetriever,
            {"rules_nodeset_name": node_name},
        ),
    }

    # If the query type is FEELING_LUCKY, select the search type intelligently
    if query_type is SearchType.FEELING_LUCKY:
        query_type = await select_search_type(query_text)

    # Checked after FEELING_LUCKY resolution so an automatically selected Cypher-based
    # type is rejected as well when Cypher queries are disabled.
    if (
        query_type in (SearchType.CYPHER, SearchType.NATURAL_LANGUAGE)
        and os.getenv("ALLOW_CYPHER_QUERY", "true").lower() == "false"
    ):
        raise UnsupportedSearchTypeError("Cypher query search types are disabled.")

    # Imported at call time rather than at module level
    # (presumably to avoid an import cycle - TODO confirm).
    from cognee.modules.retrieval.registered_community_retrievers import (
        registered_community_retrievers,
    )

    # Community retrievers take precedence over the core registry for the same search type
    if query_type in registered_community_retrievers:
        retriever = registered_community_retrievers.get(query_type)

        if not retriever:
            raise UnsupportedSearchTypeError(str(query_type))
        # TODO: Fix community retrievers on the community side so they get all input parameters properly
        return retriever(**kwargs)

    retriever_info = search_core_registry.get(query_type)
    # Check if retriever info is found for the given query type
    if not retriever_info:
        raise UnsupportedSearchTypeError(str(query_type))

    # If it exists unpack the retriever class and its initialization arguments
    retriever_cls, retriever_args = retriever_info
    return retriever_cls(**retriever_args)
|
|
@@ -14,8 +14,6 @@ from cognee.modules.engine.models.node_set import NodeSet
|
|
|
14
14
|
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
|
|
15
15
|
from cognee.modules.search.types import (
|
|
16
16
|
SearchResult,
|
|
17
|
-
CombinedSearchResult,
|
|
18
|
-
SearchResultDataset,
|
|
19
17
|
SearchType,
|
|
20
18
|
)
|
|
21
19
|
from cognee.modules.search.operations import log_query, log_result
|
|
@@ -25,9 +23,7 @@ from cognee.modules.data.methods.get_authorized_existing_datasets import (
|
|
|
25
23
|
get_authorized_existing_datasets,
|
|
26
24
|
)
|
|
27
25
|
from cognee import __version__ as cognee_version
|
|
28
|
-
from .
|
|
29
|
-
from .no_access_control_search import no_access_control_search
|
|
30
|
-
from ..utils.prepare_search_result import prepare_search_result
|
|
26
|
+
from cognee.modules.search.methods.get_retriever_output import get_retriever_output
|
|
31
27
|
|
|
32
28
|
logger = get_logger()
|
|
33
29
|
|
|
@@ -45,12 +41,12 @@ async def search(
|
|
|
45
41
|
save_interaction: bool = False,
|
|
46
42
|
last_k: Optional[int] = None,
|
|
47
43
|
only_context: bool = False,
|
|
48
|
-
use_combined_context: bool = False,
|
|
49
44
|
session_id: Optional[str] = None,
|
|
50
45
|
wide_search_top_k: Optional[int] = 100,
|
|
51
46
|
triplet_distance_penalty: Optional[float] = 3.5,
|
|
52
|
-
verbose
|
|
53
|
-
|
|
47
|
+
verbose=False,
|
|
48
|
+
retriever_specific_config: Optional[dict] = None,
|
|
49
|
+
) -> List[SearchResult]:
|
|
54
50
|
"""
|
|
55
51
|
|
|
56
52
|
Args:
|
|
@@ -76,44 +72,24 @@ async def search(
|
|
|
76
72
|
},
|
|
77
73
|
)
|
|
78
74
|
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
triplet_distance_penalty=triplet_distance_penalty,
|
|
98
|
-
)
|
|
99
|
-
else:
|
|
100
|
-
search_results = [
|
|
101
|
-
await no_access_control_search(
|
|
102
|
-
query_type=query_type,
|
|
103
|
-
query_text=query_text,
|
|
104
|
-
system_prompt_path=system_prompt_path,
|
|
105
|
-
system_prompt=system_prompt,
|
|
106
|
-
top_k=top_k,
|
|
107
|
-
node_type=node_type,
|
|
108
|
-
node_name=node_name,
|
|
109
|
-
save_interaction=save_interaction,
|
|
110
|
-
last_k=last_k,
|
|
111
|
-
only_context=only_context,
|
|
112
|
-
session_id=session_id,
|
|
113
|
-
wide_search_top_k=wide_search_top_k,
|
|
114
|
-
triplet_distance_penalty=triplet_distance_penalty,
|
|
115
|
-
)
|
|
116
|
-
]
|
|
75
|
+
search_results = await authorized_search(
|
|
76
|
+
query_type=query_type,
|
|
77
|
+
query_text=query_text,
|
|
78
|
+
user=user,
|
|
79
|
+
dataset_ids=dataset_ids,
|
|
80
|
+
system_prompt_path=system_prompt_path,
|
|
81
|
+
system_prompt=system_prompt,
|
|
82
|
+
top_k=top_k,
|
|
83
|
+
node_type=node_type,
|
|
84
|
+
node_name=node_name,
|
|
85
|
+
save_interaction=save_interaction,
|
|
86
|
+
last_k=last_k,
|
|
87
|
+
only_context=only_context,
|
|
88
|
+
session_id=session_id,
|
|
89
|
+
wide_search_top_k=wide_search_top_k,
|
|
90
|
+
triplet_distance_penalty=triplet_distance_penalty,
|
|
91
|
+
retriever_specific_config=retriever_specific_config,
|
|
92
|
+
)
|
|
117
93
|
|
|
118
94
|
send_telemetry(
|
|
119
95
|
"cognee.search EXECUTION COMPLETED",
|
|
@@ -126,95 +102,11 @@ async def search(
|
|
|
126
102
|
|
|
127
103
|
await log_result(
|
|
128
104
|
query.id,
|
|
129
|
-
json.dumps(
|
|
130
|
-
jsonable_encoder(
|
|
131
|
-
await prepare_search_result(
|
|
132
|
-
search_results[0] if isinstance(search_results, list) else search_results
|
|
133
|
-
)
|
|
134
|
-
if use_combined_context
|
|
135
|
-
else [
|
|
136
|
-
await prepare_search_result(search_result) for search_result in search_results
|
|
137
|
-
]
|
|
138
|
-
)
|
|
139
|
-
),
|
|
105
|
+
json.dumps(jsonable_encoder(search_results)),
|
|
140
106
|
user.id,
|
|
141
107
|
)
|
|
142
108
|
|
|
143
|
-
|
|
144
|
-
# Note: combined context search must always be verbose and return a CombinedSearchResult with graphs info
|
|
145
|
-
prepared_search_results = await prepare_search_result(
|
|
146
|
-
search_results[0] if isinstance(search_results, list) else search_results
|
|
147
|
-
)
|
|
148
|
-
result = prepared_search_results["result"]
|
|
149
|
-
graphs = prepared_search_results["graphs"]
|
|
150
|
-
context = prepared_search_results["context"]
|
|
151
|
-
datasets = prepared_search_results["datasets"]
|
|
152
|
-
|
|
153
|
-
return CombinedSearchResult(
|
|
154
|
-
result=result,
|
|
155
|
-
graphs=graphs,
|
|
156
|
-
context=context,
|
|
157
|
-
datasets=[
|
|
158
|
-
SearchResultDataset(
|
|
159
|
-
id=dataset.id,
|
|
160
|
-
name=dataset.name,
|
|
161
|
-
)
|
|
162
|
-
for dataset in datasets
|
|
163
|
-
],
|
|
164
|
-
)
|
|
165
|
-
else:
|
|
166
|
-
# This is for maintaining backwards compatibility
|
|
167
|
-
if backend_access_control_enabled():
|
|
168
|
-
return_value = []
|
|
169
|
-
for search_result in search_results:
|
|
170
|
-
prepared_search_results = await prepare_search_result(search_result)
|
|
171
|
-
|
|
172
|
-
result = prepared_search_results["result"]
|
|
173
|
-
graphs = prepared_search_results["graphs"]
|
|
174
|
-
context = prepared_search_results["context"]
|
|
175
|
-
datasets = prepared_search_results["datasets"]
|
|
176
|
-
|
|
177
|
-
if only_context:
|
|
178
|
-
search_result_dict = {
|
|
179
|
-
"search_result": [context] if context else None,
|
|
180
|
-
"dataset_id": datasets[0].id,
|
|
181
|
-
"dataset_name": datasets[0].name,
|
|
182
|
-
"dataset_tenant_id": datasets[0].tenant_id,
|
|
183
|
-
}
|
|
184
|
-
if verbose:
|
|
185
|
-
# Include graphs only in verbose mode
|
|
186
|
-
search_result_dict["graphs"] = graphs
|
|
187
|
-
|
|
188
|
-
return_value.append(search_result_dict)
|
|
189
|
-
else:
|
|
190
|
-
search_result_dict = {
|
|
191
|
-
"search_result": [result] if result else None,
|
|
192
|
-
"dataset_id": datasets[0].id,
|
|
193
|
-
"dataset_name": datasets[0].name,
|
|
194
|
-
"dataset_tenant_id": datasets[0].tenant_id,
|
|
195
|
-
}
|
|
196
|
-
if verbose:
|
|
197
|
-
# Include graphs only in verbose mode
|
|
198
|
-
search_result_dict["graphs"] = graphs
|
|
199
|
-
|
|
200
|
-
return_value.append(search_result_dict)
|
|
201
|
-
|
|
202
|
-
return return_value
|
|
203
|
-
else:
|
|
204
|
-
return_value = []
|
|
205
|
-
if only_context:
|
|
206
|
-
for search_result in search_results:
|
|
207
|
-
prepared_search_results = await prepare_search_result(search_result)
|
|
208
|
-
return_value.append(prepared_search_results["context"])
|
|
209
|
-
else:
|
|
210
|
-
for search_result in search_results:
|
|
211
|
-
result, context, datasets = search_result
|
|
212
|
-
return_value.append(result)
|
|
213
|
-
# For maintaining backwards compatibility
|
|
214
|
-
if len(return_value) == 1 and isinstance(return_value[0], list):
|
|
215
|
-
return return_value[0]
|
|
216
|
-
else:
|
|
217
|
-
return return_value
|
|
109
|
+
return _backwards_compatible_search_results(search_results, verbose)
|
|
218
110
|
|
|
219
111
|
|
|
220
112
|
async def authorized_search(
|
|
@@ -230,14 +122,11 @@ async def authorized_search(
|
|
|
230
122
|
save_interaction: bool = False,
|
|
231
123
|
last_k: Optional[int] = None,
|
|
232
124
|
only_context: bool = False,
|
|
233
|
-
use_combined_context: bool = False,
|
|
234
125
|
session_id: Optional[str] = None,
|
|
235
126
|
wide_search_top_k: Optional[int] = 100,
|
|
236
127
|
triplet_distance_penalty: Optional[float] = 3.5,
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
List[Tuple[Any, Union[List[Edge], str], List[Dataset]]],
|
|
240
|
-
]:
|
|
128
|
+
retriever_specific_config: Optional[dict] = None,
|
|
129
|
+
) -> List[Tuple[Any, Union[List[Edge], str], List[Dataset]]]:
|
|
241
130
|
"""
|
|
242
131
|
Verifies access for provided datasets or uses all datasets user has read access for and performs search per dataset.
|
|
243
132
|
Not to be used outside of active access control mode.
|
|
@@ -247,70 +136,6 @@ async def authorized_search(
|
|
|
247
136
|
datasets=dataset_ids, permission_type="read", user=user
|
|
248
137
|
)
|
|
249
138
|
|
|
250
|
-
if use_combined_context:
|
|
251
|
-
search_responses = await search_in_datasets_context(
|
|
252
|
-
search_datasets=search_datasets,
|
|
253
|
-
query_type=query_type,
|
|
254
|
-
query_text=query_text,
|
|
255
|
-
system_prompt_path=system_prompt_path,
|
|
256
|
-
system_prompt=system_prompt,
|
|
257
|
-
top_k=top_k,
|
|
258
|
-
node_type=node_type,
|
|
259
|
-
node_name=node_name,
|
|
260
|
-
save_interaction=save_interaction,
|
|
261
|
-
last_k=last_k,
|
|
262
|
-
only_context=True,
|
|
263
|
-
session_id=session_id,
|
|
264
|
-
wide_search_top_k=wide_search_top_k,
|
|
265
|
-
triplet_distance_penalty=triplet_distance_penalty,
|
|
266
|
-
)
|
|
267
|
-
|
|
268
|
-
context = {}
|
|
269
|
-
datasets: List[Dataset] = []
|
|
270
|
-
|
|
271
|
-
for _, search_context, search_datasets in search_responses:
|
|
272
|
-
for dataset in search_datasets:
|
|
273
|
-
context[str(dataset.id)] = search_context
|
|
274
|
-
|
|
275
|
-
datasets.extend(search_datasets)
|
|
276
|
-
|
|
277
|
-
specific_search_tools = await get_search_type_tools(
|
|
278
|
-
query_type=query_type,
|
|
279
|
-
query_text=query_text,
|
|
280
|
-
system_prompt_path=system_prompt_path,
|
|
281
|
-
system_prompt=system_prompt,
|
|
282
|
-
top_k=top_k,
|
|
283
|
-
node_type=node_type,
|
|
284
|
-
node_name=node_name,
|
|
285
|
-
save_interaction=save_interaction,
|
|
286
|
-
last_k=last_k,
|
|
287
|
-
wide_search_top_k=wide_search_top_k,
|
|
288
|
-
triplet_distance_penalty=triplet_distance_penalty,
|
|
289
|
-
)
|
|
290
|
-
search_tools = specific_search_tools
|
|
291
|
-
if len(search_tools) == 2:
|
|
292
|
-
[get_completion, _] = search_tools
|
|
293
|
-
else:
|
|
294
|
-
get_completion = search_tools[0]
|
|
295
|
-
|
|
296
|
-
def prepare_combined_context(
|
|
297
|
-
context,
|
|
298
|
-
) -> Union[List[Edge], str]:
|
|
299
|
-
combined_context = []
|
|
300
|
-
|
|
301
|
-
for dataset_context in context.values():
|
|
302
|
-
combined_context += dataset_context
|
|
303
|
-
|
|
304
|
-
if combined_context and isinstance(combined_context[0], str):
|
|
305
|
-
return "\n".join(combined_context)
|
|
306
|
-
|
|
307
|
-
return combined_context
|
|
308
|
-
|
|
309
|
-
combined_context = prepare_combined_context(context)
|
|
310
|
-
completion = await get_completion(query_text, combined_context, session_id=session_id)
|
|
311
|
-
|
|
312
|
-
return completion, combined_context, datasets
|
|
313
|
-
|
|
314
139
|
# Searches all provided datasets and handles setting up of appropriate database context based on permissions
|
|
315
140
|
search_results = await search_in_datasets_context(
|
|
316
141
|
search_datasets=search_datasets,
|
|
@@ -326,6 +151,8 @@ async def authorized_search(
|
|
|
326
151
|
only_context=only_context,
|
|
327
152
|
session_id=session_id,
|
|
328
153
|
wide_search_top_k=wide_search_top_k,
|
|
154
|
+
triplet_distance_penalty=triplet_distance_penalty,
|
|
155
|
+
retriever_specific_config=retriever_specific_config,
|
|
329
156
|
)
|
|
330
157
|
|
|
331
158
|
return search_results
|
|
@@ -343,10 +170,10 @@ async def search_in_datasets_context(
|
|
|
343
170
|
save_interaction: bool = False,
|
|
344
171
|
last_k: Optional[int] = None,
|
|
345
172
|
only_context: bool = False,
|
|
346
|
-
context: Optional[Any] = None,
|
|
347
173
|
session_id: Optional[str] = None,
|
|
348
174
|
wide_search_top_k: Optional[int] = 100,
|
|
349
175
|
triplet_distance_penalty: Optional[float] = 3.5,
|
|
176
|
+
retriever_specific_config: Optional[dict] = None,
|
|
350
177
|
) -> List[Tuple[Any, Union[str, List[Edge]], List[Dataset]]]:
|
|
351
178
|
"""
|
|
352
179
|
Searches all provided datasets and handles setting up of appropriate database context based on permissions.
|
|
@@ -365,17 +192,17 @@ async def search_in_datasets_context(
|
|
|
365
192
|
save_interaction: bool = False,
|
|
366
193
|
last_k: Optional[int] = None,
|
|
367
194
|
only_context: bool = False,
|
|
368
|
-
context: Optional[Any] = None,
|
|
369
195
|
session_id: Optional[str] = None,
|
|
370
196
|
wide_search_top_k: Optional[int] = 100,
|
|
371
197
|
triplet_distance_penalty: Optional[float] = 3.5,
|
|
198
|
+
retriever_specific_config: Optional[dict] = None,
|
|
372
199
|
) -> Tuple[Any, Union[str, List[Edge]], List[Dataset]]:
|
|
373
200
|
# Set database configuration in async context for each dataset user has access for
|
|
374
201
|
await set_database_global_context_variables(dataset.id, dataset.owner_id)
|
|
375
202
|
|
|
203
|
+
# Check if graph for dataset is empty and log warnings if necessary
|
|
376
204
|
graph_engine = await get_graph_engine()
|
|
377
205
|
is_empty = await graph_engine.is_empty()
|
|
378
|
-
|
|
379
206
|
if is_empty:
|
|
380
207
|
# TODO: we can log here, but not all search types use graph. Still keeping this here for reviewer input
|
|
381
208
|
from cognee.modules.data.methods import get_dataset_data
|
|
@@ -389,12 +216,14 @@ async def search_in_datasets_context(
|
|
|
389
216
|
)
|
|
390
217
|
else:
|
|
391
218
|
logger.warning(
|
|
392
|
-
"Search attempt on an empty knowledge graph - no data has been added to this dataset"
|
|
219
|
+
f"Search attempt on an empty knowledge graph - no data has been added to this dataset: {dataset.name}"
|
|
393
220
|
)
|
|
394
221
|
|
|
395
|
-
|
|
222
|
+
# Get retriever output in the context of the current dataset
|
|
223
|
+
return await get_retriever_output(
|
|
396
224
|
query_type=query_type,
|
|
397
225
|
query_text=query_text,
|
|
226
|
+
dataset=dataset,
|
|
398
227
|
system_prompt_path=system_prompt_path,
|
|
399
228
|
system_prompt=system_prompt,
|
|
400
229
|
top_k=top_k,
|
|
@@ -402,24 +231,12 @@ async def search_in_datasets_context(
|
|
|
402
231
|
node_name=node_name,
|
|
403
232
|
save_interaction=save_interaction,
|
|
404
233
|
last_k=last_k,
|
|
234
|
+
only_context=only_context,
|
|
235
|
+
session_id=session_id,
|
|
405
236
|
wide_search_top_k=wide_search_top_k,
|
|
406
237
|
triplet_distance_penalty=triplet_distance_penalty,
|
|
238
|
+
retriever_specific_config=retriever_specific_config,
|
|
407
239
|
)
|
|
408
|
-
search_tools = specific_search_tools
|
|
409
|
-
if len(search_tools) == 2:
|
|
410
|
-
[get_completion, get_context] = search_tools
|
|
411
|
-
|
|
412
|
-
if only_context:
|
|
413
|
-
return None, await get_context(query_text), [dataset]
|
|
414
|
-
|
|
415
|
-
search_context = context or await get_context(query_text)
|
|
416
|
-
search_result = await get_completion(query_text, search_context, session_id=session_id)
|
|
417
|
-
|
|
418
|
-
return search_result, search_context, [dataset]
|
|
419
|
-
else:
|
|
420
|
-
unknown_tool = search_tools[0]
|
|
421
|
-
|
|
422
|
-
return await unknown_tool(query_text), "", [dataset]
|
|
423
240
|
|
|
424
241
|
# Search every dataset async based on query and appropriate database configuration
|
|
425
242
|
tasks = []
|
|
@@ -437,11 +254,62 @@ async def search_in_datasets_context(
|
|
|
437
254
|
save_interaction=save_interaction,
|
|
438
255
|
last_k=last_k,
|
|
439
256
|
only_context=only_context,
|
|
440
|
-
context=context,
|
|
441
257
|
session_id=session_id,
|
|
442
258
|
wide_search_top_k=wide_search_top_k,
|
|
443
259
|
triplet_distance_penalty=triplet_distance_penalty,
|
|
260
|
+
retriever_specific_config=retriever_specific_config,
|
|
444
261
|
)
|
|
445
262
|
)
|
|
446
263
|
|
|
447
264
|
return await asyncio.gather(*tasks)
|
|
265
|
+
|
|
266
|
+
|
|
267
|
+
def _backwards_compatible_search_results(search_results, verbose: bool):
|
|
268
|
+
"""
|
|
269
|
+
Prepares search results in a format compatible with previous versions of the API.
|
|
270
|
+
"""
|
|
271
|
+
# This is for maintaining backwards compatibility
|
|
272
|
+
if backend_access_control_enabled():
|
|
273
|
+
return_value = []
|
|
274
|
+
for search_result in search_results:
|
|
275
|
+
# Dataset info needs to be always included
|
|
276
|
+
search_result_dict = {
|
|
277
|
+
"dataset_id": search_result.dataset_id,
|
|
278
|
+
"dataset_name": search_result.dataset_name,
|
|
279
|
+
"dataset_tenant_id": search_result.dataset_tenant_id,
|
|
280
|
+
}
|
|
281
|
+
if verbose:
|
|
282
|
+
# Include all different types of results only in verbose mode
|
|
283
|
+
search_result_dict["text_result"] = search_result.completion
|
|
284
|
+
search_result_dict["context_result"] = search_result.context
|
|
285
|
+
search_result_dict["objects_result"] = search_result.result_object
|
|
286
|
+
else:
|
|
287
|
+
# Result attribute handles returning appropriate result based on set flags and outputs
|
|
288
|
+
search_result_dict["search_result"] = search_result.result
|
|
289
|
+
|
|
290
|
+
return_value.append(search_result_dict)
|
|
291
|
+
return return_value
|
|
292
|
+
else:
|
|
293
|
+
return_value = []
|
|
294
|
+
if verbose:
|
|
295
|
+
for search_result in search_results:
|
|
296
|
+
# Include all different types of results only in verbose mode
|
|
297
|
+
search_result_dict = {
|
|
298
|
+
"text_result": search_result.completion,
|
|
299
|
+
"context_result": search_result.context,
|
|
300
|
+
"objects_result": search_result.result_object,
|
|
301
|
+
}
|
|
302
|
+
return_value.append(search_result_dict)
|
|
303
|
+
return return_value
|
|
304
|
+
else:
|
|
305
|
+
for search_result in search_results:
|
|
306
|
+
# Result attribute handles returning appropriate result based on set flags and outputs
|
|
307
|
+
return_value.append(search_result.result)
|
|
308
|
+
|
|
309
|
+
# For maintaining backwards compatibility
|
|
310
|
+
if len(return_value) == 1 and isinstance(return_value[0], list):
|
|
311
|
+
# If a single element list return the element directly
|
|
312
|
+
return return_value[0]
|
|
313
|
+
else:
|
|
314
|
+
# Otherwise return the list of results
|
|
315
|
+
return return_value
|