cognee 0.5.1__py3-none-any.whl → 0.5.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cognee/__init__.py +2 -0
- cognee/alembic/README +1 -0
- cognee/alembic/env.py +107 -0
- cognee/alembic/script.py.mako +26 -0
- cognee/alembic/versions/1a58b986e6e1_enable_delete_for_old_tutorial_notebooks.py +52 -0
- cognee/alembic/versions/1d0bb7fede17_add_pipeline_run_status.py +33 -0
- cognee/alembic/versions/1daae0df1866_incremental_loading.py +48 -0
- cognee/alembic/versions/211ab850ef3d_add_sync_operations_table.py +118 -0
- cognee/alembic/versions/45957f0a9849_add_notebook_table.py +46 -0
- cognee/alembic/versions/46a6ce2bd2b2_expand_dataset_database_with_json_.py +333 -0
- cognee/alembic/versions/482cd6517ce4_add_default_user.py +30 -0
- cognee/alembic/versions/76625596c5c3_expand_dataset_database_for_multi_user.py +98 -0
- cognee/alembic/versions/8057ae7329c2_initial_migration.py +25 -0
- cognee/alembic/versions/9e7a3cb85175_loader_separation.py +104 -0
- cognee/alembic/versions/a1b2c3d4e5f6_add_label_column_to_data.py +38 -0
- cognee/alembic/versions/ab7e313804ae_permission_system_rework.py +236 -0
- cognee/alembic/versions/b9274c27a25a_kuzu_11_migration.py +75 -0
- cognee/alembic/versions/c946955da633_multi_tenant_support.py +137 -0
- cognee/alembic/versions/e1ec1dcb50b6_add_last_accessed_to_data.py +51 -0
- cognee/alembic/versions/e4ebee1091e7_expand_data_model_info.py +140 -0
- cognee/alembic.ini +117 -0
- cognee/api/v1/add/add.py +2 -1
- cognee/api/v1/add/routers/get_add_router.py +2 -0
- cognee/api/v1/cognify/cognify.py +11 -6
- cognee/api/v1/cognify/routers/get_cognify_router.py +8 -0
- cognee/api/v1/config/config.py +60 -0
- cognee/api/v1/datasets/routers/get_datasets_router.py +46 -3
- cognee/api/v1/memify/routers/get_memify_router.py +3 -0
- cognee/api/v1/search/routers/get_search_router.py +21 -6
- cognee/api/v1/search/search.py +21 -5
- cognee/api/v1/sync/routers/get_sync_router.py +3 -3
- cognee/cli/commands/add_command.py +1 -1
- cognee/cli/commands/cognify_command.py +6 -0
- cognee/cli/commands/config_command.py +1 -1
- cognee/context_global_variables.py +5 -1
- cognee/eval_framework/answer_generation/answer_generation_executor.py +7 -8
- cognee/infrastructure/databases/cache/cache_db_interface.py +38 -1
- cognee/infrastructure/databases/cache/config.py +6 -0
- cognee/infrastructure/databases/cache/fscache/FsCacheAdapter.py +21 -0
- cognee/infrastructure/databases/cache/get_cache_engine.py +9 -3
- cognee/infrastructure/databases/cache/redis/RedisAdapter.py +60 -1
- cognee/infrastructure/databases/dataset_database_handler/supported_dataset_database_handlers.py +7 -0
- cognee/infrastructure/databases/graph/get_graph_engine.py +29 -1
- cognee/infrastructure/databases/graph/neo4j_driver/Neo4jAuraDevDatasetDatabaseHandler.py +62 -27
- cognee/infrastructure/databases/hybrid/neptune_analytics/NeptuneAnalyticsAdapter.py +17 -4
- cognee/infrastructure/databases/relational/config.py +16 -1
- cognee/infrastructure/databases/relational/create_relational_engine.py +13 -3
- cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py +26 -3
- cognee/infrastructure/databases/vector/chromadb/ChromaDBAdapter.py +2 -0
- cognee/infrastructure/databases/vector/config.py +6 -0
- cognee/infrastructure/databases/vector/create_vector_engine.py +70 -16
- cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py +64 -9
- cognee/infrastructure/databases/vector/embeddings/OllamaEmbeddingEngine.py +13 -2
- cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py +16 -3
- cognee/infrastructure/databases/vector/models/ScoredResult.py +3 -3
- cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py +16 -3
- cognee/infrastructure/databases/vector/pgvector/PGVectorDatasetDatabaseHandler.py +86 -0
- cognee/infrastructure/databases/vector/pgvector/create_db_and_tables.py +81 -2
- cognee/infrastructure/databases/vector/vector_db_interface.py +8 -0
- cognee/infrastructure/files/utils/get_data_file_path.py +33 -27
- cognee/infrastructure/llm/LLMGateway.py +0 -13
- cognee/infrastructure/llm/prompts/extract_query_time.txt +1 -1
- cognee/infrastructure/llm/prompts/generate_event_entity_prompt.txt +1 -1
- cognee/infrastructure/llm/prompts/generate_event_graph_prompt.txt +1 -1
- cognee/infrastructure/llm/prompts/generate_graph_prompt.txt +2 -2
- cognee/infrastructure/llm/prompts/generate_graph_prompt_guided.txt +1 -1
- cognee/infrastructure/llm/prompts/generate_graph_prompt_oneshot.txt +2 -2
- cognee/infrastructure/llm/prompts/generate_graph_prompt_simple.txt +1 -1
- cognee/infrastructure/llm/prompts/generate_graph_prompt_strict.txt +1 -1
- cognee/infrastructure/llm/prompts/search_type_selector_prompt.txt +6 -6
- cognee/infrastructure/llm/prompts/test.txt +1 -1
- cognee/infrastructure/llm/prompts/translate_content.txt +19 -0
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py +17 -12
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py +31 -25
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py +132 -7
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py +29 -5
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/llama_cpp/adapter.py +191 -0
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/llm_interface.py +2 -6
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py +58 -13
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/adapter.py +0 -1
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py +25 -131
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/types.py +10 -0
- cognee/modules/chunking/models/DocumentChunk.py +0 -1
- cognee/modules/cognify/config.py +2 -0
- cognee/modules/data/models/Data.py +3 -1
- cognee/modules/engine/models/Entity.py +0 -1
- cognee/modules/engine/operations/setup.py +6 -0
- cognee/modules/graph/cognee_graph/CogneeGraph.py +150 -37
- cognee/modules/graph/cognee_graph/CogneeGraphElements.py +48 -2
- cognee/modules/graph/utils/__init__.py +1 -0
- cognee/modules/graph/utils/get_entity_nodes_from_triplets.py +12 -0
- cognee/modules/notebooks/methods/__init__.py +1 -0
- cognee/modules/notebooks/methods/create_notebook.py +0 -34
- cognee/modules/notebooks/methods/create_tutorial_notebooks.py +191 -0
- cognee/modules/notebooks/methods/get_notebooks.py +12 -8
- cognee/modules/notebooks/tutorials/cognee-basics/cell-1.md +3 -0
- cognee/modules/notebooks/tutorials/cognee-basics/cell-2.md +10 -0
- cognee/modules/notebooks/tutorials/cognee-basics/cell-3.md +7 -0
- cognee/modules/notebooks/tutorials/cognee-basics/cell-4.py +28 -0
- cognee/modules/notebooks/tutorials/cognee-basics/cell-5.py +3 -0
- cognee/modules/notebooks/tutorials/cognee-basics/cell-6.py +9 -0
- cognee/modules/notebooks/tutorials/cognee-basics/cell-7.py +17 -0
- cognee/modules/notebooks/tutorials/cognee-basics/config.json +4 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-1.md +3 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-10.md +3 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-11.md +3 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-12.py +3 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-13.md +7 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-14.py +6 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-15.md +3 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-16.py +7 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-2.md +9 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-3.md +7 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-4.md +9 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-5.md +5 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-6.py +13 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-7.md +3 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-8.md +3 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-9.py +31 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/config.json +4 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/data/copilot_conversations.json +107 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/data/guido_contributions.json +976 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/data/my_developer_rules.md +79 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/data/pep_style_guide.md +74 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/data/zen_principles.md +74 -0
- cognee/modules/retrieval/EntityCompletionRetriever.py +51 -38
- cognee/modules/retrieval/__init__.py +0 -1
- cognee/modules/retrieval/base_retriever.py +66 -10
- cognee/modules/retrieval/chunks_retriever.py +57 -49
- cognee/modules/retrieval/coding_rules_retriever.py +12 -5
- cognee/modules/retrieval/completion_retriever.py +29 -28
- cognee/modules/retrieval/cypher_search_retriever.py +25 -20
- cognee/modules/retrieval/graph_completion_context_extension_retriever.py +42 -46
- cognee/modules/retrieval/graph_completion_cot_retriever.py +68 -51
- cognee/modules/retrieval/graph_completion_retriever.py +78 -63
- cognee/modules/retrieval/graph_summary_completion_retriever.py +2 -0
- cognee/modules/retrieval/lexical_retriever.py +34 -12
- cognee/modules/retrieval/natural_language_retriever.py +18 -15
- cognee/modules/retrieval/summaries_retriever.py +51 -34
- cognee/modules/retrieval/temporal_retriever.py +59 -49
- cognee/modules/retrieval/triplet_retriever.py +32 -33
- cognee/modules/retrieval/utils/access_tracking.py +88 -0
- cognee/modules/retrieval/utils/brute_force_triplet_search.py +99 -103
- cognee/modules/retrieval/utils/node_edge_vector_search.py +174 -0
- cognee/modules/search/methods/__init__.py +1 -0
- cognee/modules/search/methods/get_retriever_output.py +53 -0
- cognee/modules/search/methods/get_search_type_retriever_instance.py +252 -0
- cognee/modules/search/methods/search.py +90 -222
- cognee/modules/search/models/SearchResultPayload.py +67 -0
- cognee/modules/search/types/SearchResult.py +1 -8
- cognee/modules/search/types/SearchType.py +1 -2
- cognee/modules/search/types/__init__.py +1 -1
- cognee/modules/search/utils/__init__.py +1 -2
- cognee/modules/search/utils/transform_insights_to_graph.py +2 -2
- cognee/modules/search/utils/{transform_context_to_graph.py → transform_triplets_to_graph.py} +2 -2
- cognee/modules/users/authentication/default/default_transport.py +11 -1
- cognee/modules/users/authentication/get_api_auth_backend.py +2 -1
- cognee/modules/users/authentication/get_client_auth_backend.py +2 -1
- cognee/modules/users/methods/create_user.py +0 -9
- cognee/modules/users/permissions/methods/has_user_management_permission.py +29 -0
- cognee/modules/visualization/cognee_network_visualization.py +1 -1
- cognee/run_migrations.py +48 -0
- cognee/shared/exceptions/__init__.py +1 -3
- cognee/shared/exceptions/exceptions.py +11 -1
- cognee/shared/usage_logger.py +332 -0
- cognee/shared/utils.py +12 -5
- cognee/tasks/chunks/__init__.py +9 -0
- cognee/tasks/cleanup/cleanup_unused_data.py +172 -0
- cognee/tasks/graph/__init__.py +7 -0
- cognee/tasks/ingestion/data_item.py +8 -0
- cognee/tasks/ingestion/ingest_data.py +12 -1
- cognee/tasks/ingestion/save_data_item_to_storage.py +5 -0
- cognee/tasks/memify/__init__.py +8 -0
- cognee/tasks/memify/extract_usage_frequency.py +613 -0
- cognee/tasks/summarization/models.py +0 -2
- cognee/tasks/temporal_graph/__init__.py +0 -1
- cognee/tasks/translation/__init__.py +96 -0
- cognee/tasks/translation/config.py +110 -0
- cognee/tasks/translation/detect_language.py +190 -0
- cognee/tasks/translation/exceptions.py +62 -0
- cognee/tasks/translation/models.py +72 -0
- cognee/tasks/translation/providers/__init__.py +44 -0
- cognee/tasks/translation/providers/azure_provider.py +192 -0
- cognee/tasks/translation/providers/base.py +85 -0
- cognee/tasks/translation/providers/google_provider.py +158 -0
- cognee/tasks/translation/providers/llm_provider.py +143 -0
- cognee/tasks/translation/translate_content.py +282 -0
- cognee/tasks/web_scraper/default_url_crawler.py +6 -2
- cognee/tests/cli_tests/cli_unit_tests/test_cli_commands.py +1 -0
- cognee/tests/cli_tests/cli_unit_tests/test_cli_edge_cases.py +3 -0
- cognee/tests/integration/retrieval/test_brute_force_triplet_search_with_cognify.py +62 -0
- cognee/tests/integration/retrieval/test_chunks_retriever.py +351 -0
- cognee/tests/integration/retrieval/test_graph_completion_retriever.py +276 -0
- cognee/tests/integration/retrieval/test_graph_completion_retriever_context_extension.py +228 -0
- cognee/tests/integration/retrieval/test_graph_completion_retriever_cot.py +217 -0
- cognee/tests/integration/retrieval/test_rag_completion_retriever.py +319 -0
- cognee/tests/integration/retrieval/test_structured_output.py +258 -0
- cognee/tests/integration/retrieval/test_summaries_retriever.py +195 -0
- cognee/tests/integration/retrieval/test_temporal_retriever.py +336 -0
- cognee/tests/integration/retrieval/test_triplet_retriever.py +45 -1
- cognee/tests/integration/shared/test_usage_logger_integration.py +255 -0
- cognee/tests/tasks/translation/README.md +147 -0
- cognee/tests/tasks/translation/__init__.py +1 -0
- cognee/tests/tasks/translation/config_test.py +93 -0
- cognee/tests/tasks/translation/detect_language_test.py +118 -0
- cognee/tests/tasks/translation/providers_test.py +151 -0
- cognee/tests/tasks/translation/translate_content_test.py +213 -0
- cognee/tests/test_chromadb.py +1 -1
- cognee/tests/test_cleanup_unused_data.py +165 -0
- cognee/tests/test_custom_data_label.py +68 -0
- cognee/tests/test_delete_by_id.py +6 -6
- cognee/tests/test_extract_usage_frequency.py +308 -0
- cognee/tests/test_kuzu.py +17 -7
- cognee/tests/test_lancedb.py +3 -1
- cognee/tests/test_library.py +1 -1
- cognee/tests/test_neo4j.py +17 -7
- cognee/tests/test_neptune_analytics_vector.py +3 -1
- cognee/tests/test_permissions.py +172 -187
- cognee/tests/test_pgvector.py +3 -1
- cognee/tests/test_relational_db_migration.py +15 -1
- cognee/tests/test_remote_kuzu.py +3 -1
- cognee/tests/test_s3_file_storage.py +1 -1
- cognee/tests/test_search_db.py +345 -205
- cognee/tests/test_usage_logger_e2e.py +268 -0
- cognee/tests/unit/api/test_get_raw_data_endpoint.py +206 -0
- cognee/tests/unit/eval_framework/answer_generation_test.py +4 -3
- cognee/tests/unit/eval_framework/benchmark_adapters_test.py +25 -0
- cognee/tests/unit/eval_framework/corpus_builder_test.py +33 -4
- cognee/tests/unit/infrastructure/databases/cache/test_cache_config.py +2 -0
- cognee/tests/unit/infrastructure/databases/relational/test_RelationalConfig.py +69 -0
- cognee/tests/unit/modules/graph/cognee_graph_elements_test.py +42 -2
- cognee/tests/unit/modules/graph/cognee_graph_test.py +329 -31
- cognee/tests/unit/modules/retrieval/chunks_retriever_test.py +122 -168
- cognee/tests/unit/modules/retrieval/conversation_history_test.py +338 -0
- cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py +486 -157
- cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py +693 -155
- cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py +619 -200
- cognee/tests/unit/modules/retrieval/rag_completion_retriever_test.py +300 -171
- cognee/tests/unit/modules/retrieval/summaries_retriever_test.py +184 -155
- cognee/tests/unit/modules/retrieval/temporal_retriever_test.py +544 -79
- cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py +476 -28
- cognee/tests/unit/modules/retrieval/test_completion.py +343 -0
- cognee/tests/unit/modules/retrieval/test_graph_summary_completion_retriever.py +157 -0
- cognee/tests/unit/modules/retrieval/test_node_edge_vector_search.py +273 -0
- cognee/tests/unit/modules/retrieval/test_user_qa_feedback.py +312 -0
- cognee/tests/unit/modules/retrieval/triplet_retriever_test.py +267 -7
- cognee/tests/unit/modules/search/test_get_search_type_retriever_instance.py +125 -0
- cognee/tests/unit/modules/search/test_search.py +96 -20
- cognee/tests/unit/modules/search/test_search_prepare_search_result_contract.py +190 -0
- cognee/tests/unit/modules/users/test_tutorial_notebook_creation.py +511 -297
- cognee/tests/unit/shared/test_usage_logger.py +241 -0
- cognee/tests/unit/users/permissions/test_has_user_management_permission.py +46 -0
- {cognee-0.5.1.dist-info → cognee-0.5.2.dist-info}/METADATA +22 -17
- {cognee-0.5.1.dist-info → cognee-0.5.2.dist-info}/RECORD +258 -157
- cognee/api/.env.example +0 -5
- cognee/modules/retrieval/base_graph_retriever.py +0 -24
- cognee/modules/search/methods/get_search_type_tools.py +0 -223
- cognee/modules/search/methods/no_access_control_search.py +0 -62
- cognee/modules/search/utils/prepare_search_result.py +0 -63
- cognee/tests/test_feedback_enrichment.py +0 -174
- cognee/tests/unit/modules/retrieval/structured_output_test.py +0 -204
- {cognee-0.5.1.dist-info → cognee-0.5.2.dist-info}/WHEEL +0 -0
- {cognee-0.5.1.dist-info → cognee-0.5.2.dist-info}/entry_points.txt +0 -0
- {cognee-0.5.1.dist-info → cognee-0.5.2.dist-info}/licenses/LICENSE +0 -0
- {cognee-0.5.1.dist-info → cognee-0.5.2.dist-info}/licenses/NOTICE.md +0 -0
|
@@ -31,14 +31,18 @@ async def test_get_context_success(mock_vector_engine):
|
|
|
31
31
|
"cognee.modules.retrieval.triplet_retriever.get_vector_engine",
|
|
32
32
|
return_value=mock_vector_engine,
|
|
33
33
|
):
|
|
34
|
-
|
|
34
|
+
objects = await retriever.get_retrieved_objects("test query")
|
|
35
|
+
context = await retriever.get_context_from_objects("test query", objects)
|
|
36
|
+
await retriever.get_completion_from_context("test query", objects, context)
|
|
35
37
|
|
|
36
38
|
assert context == "Alice knows Bob\nBob works at Tech Corp"
|
|
37
|
-
mock_vector_engine.search.assert_awaited_once_with(
|
|
39
|
+
mock_vector_engine.search.assert_awaited_once_with(
|
|
40
|
+
"Triplet_text", "test query", limit=5, include_payload=True
|
|
41
|
+
)
|
|
38
42
|
|
|
39
43
|
|
|
40
44
|
@pytest.mark.asyncio
|
|
41
|
-
async def
|
|
45
|
+
async def test_get_objects_no_collection(mock_vector_engine):
|
|
42
46
|
"""Test that NoDataError is raised when Triplet_text collection doesn't exist."""
|
|
43
47
|
mock_vector_engine.has_collection.return_value = False
|
|
44
48
|
|
|
@@ -49,7 +53,7 @@ async def test_get_context_no_collection(mock_vector_engine):
|
|
|
49
53
|
return_value=mock_vector_engine,
|
|
50
54
|
):
|
|
51
55
|
with pytest.raises(NoDataError, match="create_triplet_embeddings"):
|
|
52
|
-
await retriever.
|
|
56
|
+
await retriever.get_retrieved_objects("test query")
|
|
53
57
|
|
|
54
58
|
|
|
55
59
|
@pytest.mark.asyncio
|
|
@@ -63,13 +67,13 @@ async def test_get_context_empty_results(mock_vector_engine):
|
|
|
63
67
|
"cognee.modules.retrieval.triplet_retriever.get_vector_engine",
|
|
64
68
|
return_value=mock_vector_engine,
|
|
65
69
|
):
|
|
66
|
-
context = await retriever.
|
|
70
|
+
context = await retriever.get_context_from_objects("test query", [])
|
|
67
71
|
|
|
68
72
|
assert context == ""
|
|
69
73
|
|
|
70
74
|
|
|
71
75
|
@pytest.mark.asyncio
|
|
72
|
-
async def
|
|
76
|
+
async def test_get_objects_collection_not_found_error(mock_vector_engine):
|
|
73
77
|
"""Test that CollectionNotFoundError is converted to NoDataError."""
|
|
74
78
|
mock_vector_engine.has_collection.side_effect = CollectionNotFoundError("Collection not found")
|
|
75
79
|
|
|
@@ -80,4 +84,260 @@ async def test_get_context_collection_not_found_error(mock_vector_engine):
|
|
|
80
84
|
return_value=mock_vector_engine,
|
|
81
85
|
):
|
|
82
86
|
with pytest.raises(NoDataError, match="No data found"):
|
|
83
|
-
await retriever.
|
|
87
|
+
await retriever.get_retrieved_objects("test query")
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
@pytest.mark.asyncio
|
|
91
|
+
async def test_get_context_empty_payload_text(mock_vector_engine):
|
|
92
|
+
"""Test get_context handles missing text in payload."""
|
|
93
|
+
mock_result = MagicMock()
|
|
94
|
+
mock_result.payload = {}
|
|
95
|
+
|
|
96
|
+
mock_vector_engine.search.return_value = [mock_result]
|
|
97
|
+
|
|
98
|
+
retriever = TripletRetriever()
|
|
99
|
+
|
|
100
|
+
with patch(
|
|
101
|
+
"cognee.modules.retrieval.triplet_retriever.get_vector_engine",
|
|
102
|
+
return_value=mock_vector_engine,
|
|
103
|
+
):
|
|
104
|
+
with pytest.raises(KeyError):
|
|
105
|
+
objects = await retriever.get_retrieved_objects("test query")
|
|
106
|
+
await retriever.get_context_from_objects("test query", retrieved_objects=objects)
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
@pytest.mark.asyncio
|
|
110
|
+
async def test_get_context_single_triplet(mock_vector_engine):
|
|
111
|
+
"""Test get_context with single triplet result."""
|
|
112
|
+
mock_result = MagicMock()
|
|
113
|
+
mock_result.payload = {"text": "Single triplet"}
|
|
114
|
+
|
|
115
|
+
mock_vector_engine.search.return_value = [mock_result]
|
|
116
|
+
|
|
117
|
+
retriever = TripletRetriever()
|
|
118
|
+
|
|
119
|
+
with patch(
|
|
120
|
+
"cognee.modules.retrieval.triplet_retriever.get_vector_engine",
|
|
121
|
+
return_value=mock_vector_engine,
|
|
122
|
+
):
|
|
123
|
+
objects = await retriever.get_retrieved_objects("test query")
|
|
124
|
+
context = await retriever.get_context_from_objects("test query", retrieved_objects=objects)
|
|
125
|
+
|
|
126
|
+
assert context == "Single triplet"
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
@pytest.mark.asyncio
|
|
130
|
+
async def test_init_defaults():
|
|
131
|
+
"""Test TripletRetriever initialization with defaults."""
|
|
132
|
+
retriever = TripletRetriever()
|
|
133
|
+
|
|
134
|
+
assert retriever.user_prompt_path == "context_for_question.txt"
|
|
135
|
+
assert retriever.system_prompt_path == "answer_simple_question.txt"
|
|
136
|
+
assert retriever.top_k == 5 # Default is 5
|
|
137
|
+
assert retriever.system_prompt is None
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
@pytest.mark.asyncio
|
|
141
|
+
async def test_init_custom_params():
|
|
142
|
+
"""Test TripletRetriever initialization with custom parameters."""
|
|
143
|
+
retriever = TripletRetriever(
|
|
144
|
+
user_prompt_path="custom_user.txt",
|
|
145
|
+
system_prompt_path="custom_system.txt",
|
|
146
|
+
system_prompt="Custom prompt",
|
|
147
|
+
top_k=10,
|
|
148
|
+
)
|
|
149
|
+
|
|
150
|
+
assert retriever.user_prompt_path == "custom_user.txt"
|
|
151
|
+
assert retriever.system_prompt_path == "custom_system.txt"
|
|
152
|
+
assert retriever.system_prompt == "Custom prompt"
|
|
153
|
+
assert retriever.top_k == 10
|
|
154
|
+
|
|
155
|
+
|
|
156
|
+
@pytest.mark.asyncio
|
|
157
|
+
async def test_get_completion_without_context(mock_vector_engine):
|
|
158
|
+
"""Test get_completion retrieves context when not provided."""
|
|
159
|
+
mock_result = MagicMock()
|
|
160
|
+
mock_result.payload = {"text": "Test triplet"}
|
|
161
|
+
mock_vector_engine.has_collection.return_value = True
|
|
162
|
+
mock_vector_engine.search.return_value = [mock_result]
|
|
163
|
+
|
|
164
|
+
retriever = TripletRetriever()
|
|
165
|
+
|
|
166
|
+
with (
|
|
167
|
+
patch(
|
|
168
|
+
"cognee.modules.retrieval.triplet_retriever.get_vector_engine",
|
|
169
|
+
return_value=mock_vector_engine,
|
|
170
|
+
),
|
|
171
|
+
patch(
|
|
172
|
+
"cognee.modules.retrieval.triplet_retriever.generate_completion",
|
|
173
|
+
return_value="Generated answer",
|
|
174
|
+
),
|
|
175
|
+
patch("cognee.modules.retrieval.triplet_retriever.CacheConfig") as mock_cache_config,
|
|
176
|
+
):
|
|
177
|
+
mock_config = MagicMock()
|
|
178
|
+
mock_config.caching = False
|
|
179
|
+
mock_cache_config.return_value = mock_config
|
|
180
|
+
|
|
181
|
+
completion = await retriever.get_completion_from_context("test query", None, None)
|
|
182
|
+
|
|
183
|
+
assert isinstance(completion, list)
|
|
184
|
+
assert len(completion) == 1
|
|
185
|
+
assert completion[0] == "Generated answer"
|
|
186
|
+
|
|
187
|
+
|
|
188
|
+
@pytest.mark.asyncio
|
|
189
|
+
async def test_get_completion_with_provided_context(mock_vector_engine):
|
|
190
|
+
"""Test get_completion uses provided context."""
|
|
191
|
+
retriever = TripletRetriever()
|
|
192
|
+
|
|
193
|
+
with (
|
|
194
|
+
patch(
|
|
195
|
+
"cognee.modules.retrieval.triplet_retriever.generate_completion",
|
|
196
|
+
return_value="Generated answer",
|
|
197
|
+
),
|
|
198
|
+
patch("cognee.modules.retrieval.triplet_retriever.CacheConfig") as mock_cache_config,
|
|
199
|
+
):
|
|
200
|
+
mock_config = MagicMock()
|
|
201
|
+
mock_config.caching = False
|
|
202
|
+
mock_cache_config.return_value = mock_config
|
|
203
|
+
|
|
204
|
+
completion = await retriever.get_completion_from_context(
|
|
205
|
+
"test query", None, context="Provided context"
|
|
206
|
+
)
|
|
207
|
+
|
|
208
|
+
assert isinstance(completion, list)
|
|
209
|
+
assert len(completion) == 1
|
|
210
|
+
assert completion[0] == "Generated answer"
|
|
211
|
+
|
|
212
|
+
|
|
213
|
+
@pytest.mark.asyncio
|
|
214
|
+
async def test_get_completion_with_session(mock_vector_engine):
|
|
215
|
+
"""Test get_completion with session caching enabled."""
|
|
216
|
+
mock_result = MagicMock()
|
|
217
|
+
mock_result.payload = {"text": "Test triplet"}
|
|
218
|
+
mock_vector_engine.has_collection.return_value = True
|
|
219
|
+
mock_vector_engine.search.return_value = [mock_result]
|
|
220
|
+
|
|
221
|
+
retriever = TripletRetriever(session_id="test_session")
|
|
222
|
+
|
|
223
|
+
mock_user = MagicMock()
|
|
224
|
+
mock_user.id = "test-user-id"
|
|
225
|
+
|
|
226
|
+
with (
|
|
227
|
+
patch(
|
|
228
|
+
"cognee.modules.retrieval.triplet_retriever.get_vector_engine",
|
|
229
|
+
return_value=mock_vector_engine,
|
|
230
|
+
),
|
|
231
|
+
patch(
|
|
232
|
+
"cognee.modules.retrieval.triplet_retriever.get_conversation_history",
|
|
233
|
+
return_value="Previous conversation",
|
|
234
|
+
),
|
|
235
|
+
patch(
|
|
236
|
+
"cognee.modules.retrieval.triplet_retriever.summarize_text",
|
|
237
|
+
return_value="Context summary",
|
|
238
|
+
),
|
|
239
|
+
patch(
|
|
240
|
+
"cognee.modules.retrieval.triplet_retriever.generate_completion",
|
|
241
|
+
return_value="Generated answer",
|
|
242
|
+
),
|
|
243
|
+
patch(
|
|
244
|
+
"cognee.modules.retrieval.triplet_retriever.save_conversation_history",
|
|
245
|
+
) as mock_save,
|
|
246
|
+
patch("cognee.modules.retrieval.triplet_retriever.CacheConfig") as mock_cache_config,
|
|
247
|
+
patch("cognee.modules.retrieval.triplet_retriever.session_user") as mock_session_user,
|
|
248
|
+
):
|
|
249
|
+
mock_config = MagicMock()
|
|
250
|
+
mock_config.caching = True
|
|
251
|
+
mock_cache_config.return_value = mock_config
|
|
252
|
+
mock_session_user.get.return_value = mock_user
|
|
253
|
+
|
|
254
|
+
objects = await retriever.get_retrieved_objects("test query")
|
|
255
|
+
context = await retriever.get_context_from_objects("test query", retrieved_objects=objects)
|
|
256
|
+
completion = await retriever.get_completion_from_context("test query", objects, context)
|
|
257
|
+
|
|
258
|
+
assert isinstance(completion, list)
|
|
259
|
+
assert len(completion) == 1
|
|
260
|
+
assert completion[0] == "Generated answer"
|
|
261
|
+
mock_save.assert_awaited_once()
|
|
262
|
+
|
|
263
|
+
|
|
264
|
+
@pytest.mark.asyncio
|
|
265
|
+
async def test_get_completion_with_session_no_user_id(mock_vector_engine):
|
|
266
|
+
"""Test get_completion with session config but no user ID."""
|
|
267
|
+
mock_result = MagicMock()
|
|
268
|
+
mock_result.payload = {"text": "Test triplet"}
|
|
269
|
+
mock_vector_engine.has_collection.return_value = True
|
|
270
|
+
mock_vector_engine.search.return_value = [mock_result]
|
|
271
|
+
|
|
272
|
+
retriever = TripletRetriever()
|
|
273
|
+
|
|
274
|
+
with (
|
|
275
|
+
patch(
|
|
276
|
+
"cognee.modules.retrieval.triplet_retriever.get_vector_engine",
|
|
277
|
+
return_value=mock_vector_engine,
|
|
278
|
+
),
|
|
279
|
+
patch(
|
|
280
|
+
"cognee.modules.retrieval.triplet_retriever.generate_completion",
|
|
281
|
+
return_value="Generated answer",
|
|
282
|
+
),
|
|
283
|
+
patch("cognee.modules.retrieval.triplet_retriever.CacheConfig") as mock_cache_config,
|
|
284
|
+
patch("cognee.modules.retrieval.triplet_retriever.session_user") as mock_session_user,
|
|
285
|
+
):
|
|
286
|
+
mock_config = MagicMock()
|
|
287
|
+
mock_config.caching = True
|
|
288
|
+
mock_cache_config.return_value = mock_config
|
|
289
|
+
mock_session_user.get.return_value = None # No user
|
|
290
|
+
|
|
291
|
+
objects = await retriever.get_retrieved_objects("test query")
|
|
292
|
+
context = await retriever.get_context_from_objects("test query", retrieved_objects=objects)
|
|
293
|
+
completion = await retriever.get_completion_from_context("test query", objects, context)
|
|
294
|
+
|
|
295
|
+
assert isinstance(completion, list)
|
|
296
|
+
assert len(completion) == 1
|
|
297
|
+
|
|
298
|
+
|
|
299
|
+
@pytest.mark.asyncio
|
|
300
|
+
async def test_get_completion_with_response_model(mock_vector_engine):
|
|
301
|
+
"""Test get_completion with custom response model."""
|
|
302
|
+
from pydantic import BaseModel
|
|
303
|
+
|
|
304
|
+
class TestModel(BaseModel):
|
|
305
|
+
answer: str
|
|
306
|
+
|
|
307
|
+
mock_result = MagicMock()
|
|
308
|
+
mock_result.payload = {"text": "Test triplet"}
|
|
309
|
+
mock_vector_engine.has_collection.return_value = True
|
|
310
|
+
mock_vector_engine.search.return_value = [mock_result]
|
|
311
|
+
|
|
312
|
+
retriever = TripletRetriever(response_model=TestModel)
|
|
313
|
+
|
|
314
|
+
with (
|
|
315
|
+
patch(
|
|
316
|
+
"cognee.modules.retrieval.triplet_retriever.get_vector_engine",
|
|
317
|
+
return_value=mock_vector_engine,
|
|
318
|
+
),
|
|
319
|
+
patch(
|
|
320
|
+
"cognee.modules.retrieval.triplet_retriever.generate_completion",
|
|
321
|
+
return_value=TestModel(answer="Test answer"),
|
|
322
|
+
),
|
|
323
|
+
patch("cognee.modules.retrieval.triplet_retriever.CacheConfig") as mock_cache_config,
|
|
324
|
+
):
|
|
325
|
+
mock_config = MagicMock()
|
|
326
|
+
mock_config.caching = False
|
|
327
|
+
mock_cache_config.return_value = mock_config
|
|
328
|
+
|
|
329
|
+
objects = await retriever.get_retrieved_objects("test query")
|
|
330
|
+
context = await retriever.get_context_from_objects("test query", retrieved_objects=objects)
|
|
331
|
+
completion = await retriever.get_completion_from_context("test query", objects, context)
|
|
332
|
+
|
|
333
|
+
assert isinstance(completion, list)
|
|
334
|
+
assert len(completion) == 1
|
|
335
|
+
assert isinstance(completion[0], TestModel)
|
|
336
|
+
|
|
337
|
+
|
|
338
|
+
@pytest.mark.asyncio
|
|
339
|
+
async def test_init_none_top_k():
|
|
340
|
+
"""Test TripletRetriever initialization with None top_k."""
|
|
341
|
+
retriever = TripletRetriever(top_k=None)
|
|
342
|
+
|
|
343
|
+
assert retriever.top_k == 5
|
|
@@ -0,0 +1,125 @@
|
|
|
1
|
+
import pytest
|
|
2
|
+
|
|
3
|
+
from cognee.modules.search.exceptions import UnsupportedSearchTypeError
|
|
4
|
+
from cognee.modules.search.types import SearchType
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class _DummyCommunityRetriever:
    """Stand-in community retriever used to observe how it gets constructed and called.

    The constructor keeps its keyword arguments, and both query methods return a
    plain dict echoing the call instead of doing any retrieval.
    """

    def __init__(self, *args, **kwargs):
        # Positional arguments are accepted but not stored; only kwargs are kept.
        self.kwargs = kwargs

    def _echo(self, kind, call_args, call_kwargs):
        """Return the dict describing a call of the given kind."""
        return {
            "kind": kind,
            "init": self.kwargs,
            "args": call_args,
            "kwargs": call_kwargs,
        }

    def get_completion(self, *args, **kwargs):
        """Echo the call, tagged as a ``completion`` request."""
        return self._echo("completion", args, kwargs)

    def get_context(self, *args, **kwargs):
        """Echo the call, tagged as a ``context`` request."""
        return self._echo("context", args, kwargs)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
@pytest.mark.asyncio
async def test_feeling_lucky_delegates_to_select_search_type(monkeypatch):
    """FEELING_LUCKY should be resolved through ``select_search_type``.

    The selector is replaced with a stub that answers CHUNKS, so the instance
    that comes back is expected to be a ChunksRetriever.
    """
    import cognee.modules.search.methods.get_search_type_retriever_instance as mod
    from cognee.modules.retrieval.chunks_retriever import ChunksRetriever

    async def _stub_selector(query_text: str):
        # The stub also verifies that the caller's query text reaches the selector.
        assert query_text == "hello"
        return SearchType.CHUNKS

    monkeypatch.setattr(mod, "select_search_type", _stub_selector)

    resolved = await mod.get_search_type_retriever_instance(
        SearchType.FEELING_LUCKY,
        query_text="hello",
    )

    assert isinstance(resolved, ChunksRetriever)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
@pytest.mark.asyncio
async def test_disallowed_cypher_search_types_raise(monkeypatch):
    """With ALLOW_CYPHER_QUERY set to "false", CYPHER and NATURAL_LANGUAGE must be rejected."""
    import cognee.modules.search.methods.get_search_type_retriever_instance as mod

    monkeypatch.setenv("ALLOW_CYPHER_QUERY", "false")

    # Checked in this order: the raw Cypher request first, then the natural-language one.
    blocked_requests = (
        (SearchType.CYPHER, "MATCH (n) RETURN n"),
        (SearchType.NATURAL_LANGUAGE, "Find nodes"),
    )

    for blocked_type, blocked_query in blocked_requests:
        with pytest.raises(UnsupportedSearchTypeError, match="disabled"):
            await mod.get_search_type_retriever_instance(blocked_type, query_text=blocked_query)
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
@pytest.mark.asyncio
async def test_allowed_cypher_search_types_return_tools(monkeypatch):
    """With ALLOW_CYPHER_QUERY set to "true", CYPHER should yield a CypherSearchRetriever."""
    import cognee.modules.search.methods.get_search_type_retriever_instance as mod
    from cognee.modules.retrieval.cypher_search_retriever import CypherSearchRetriever

    monkeypatch.setenv("ALLOW_CYPHER_QUERY", "true")

    cypher_retriever = await mod.get_search_type_retriever_instance(
        SearchType.CYPHER,
        query_text="q",
    )

    assert isinstance(cypher_retriever, CypherSearchRetriever)
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
@pytest.mark.asyncio
async def test_registered_community_retriever_is_used(monkeypatch):
    """
    A retriever registered in the community registry for a search type should be the
    one returned for that type, taking precedence over the default mapping.
    """
    import cognee.modules.search.methods.get_search_type_retriever_instance as mod
    from cognee.modules.retrieval import registered_community_retrievers as registry

    # Swap the registry contents so SUMMARIES points at the dummy retriever.
    community_overrides = {SearchType.SUMMARIES: _DummyCommunityRetriever}
    monkeypatch.setattr(registry, "registered_community_retrievers", community_overrides)

    resolved = await mod.get_search_type_retriever_instance(
        SearchType.SUMMARIES,
        query_text="q",
        top_k=7,
    )

    assert isinstance(resolved, _DummyCommunityRetriever)
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
@pytest.mark.asyncio
async def test_unknown_query_type_raises_unsupported():
    """An unrecognised query type should raise UnsupportedSearchTypeError naming that type."""
    import cognee.modules.search.methods.get_search_type_retriever_instance as mod

    bogus_type = "UNKNOWN_TYPE"

    # The error message is expected to mention the offending type.
    with pytest.raises(UnsupportedSearchTypeError, match=bogus_type):
        await mod.get_search_type_retriever_instance(bogus_type, query_text="q")
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
@pytest.mark.asyncio
async def test_default_mapping_passes_top_k_to_retrievers():
    """Requesting SUMMARIES with ``top_k=4`` should produce a SummariesRetriever.

    Only the type of the returned instance is asserted here.
    """
    import cognee.modules.search.methods.get_search_type_retriever_instance as mod
    from cognee.modules.retrieval.summaries_retriever import SummariesRetriever

    summaries_retriever = await mod.get_search_type_retriever_instance(
        SearchType.SUMMARIES,
        query_text="q",
        top_k=4,
    )

    assert isinstance(summaries_retriever, SummariesRetriever)
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
@pytest.mark.asyncio
async def test_chunks_lexical_returns_jaccard_tools():
    """CHUNKS_LEXICAL should be served by a JaccardChunksRetriever."""
    import cognee.modules.search.methods.get_search_type_retriever_instance as mod
    from cognee.modules.retrieval.jaccard_retrival import JaccardChunksRetriever

    lexical_retriever = await mod.get_search_type_retriever_instance(
        SearchType.CHUNKS_LEXICAL,
        query_text="q",
        top_k=3,
    )

    assert isinstance(lexical_retriever, JaccardChunksRetriever)
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
@pytest.mark.asyncio
async def test_coding_rules_uses_node_name_as_rules_nodeset_name():
    """CODING_RULES with an empty ``node_name`` list should yield a CodingRulesRetriever.

    Only the type of the returned instance is asserted here.
    """
    import cognee.modules.search.methods.get_search_type_retriever_instance as mod
    from cognee.modules.retrieval.coding_rules_retriever import CodingRulesRetriever

    rules_retriever = await mod.get_search_type_retriever_instance(
        SearchType.CODING_RULES,
        query_text="q",
        node_name=[],
    )

    assert isinstance(rules_retriever, CodingRulesRetriever)
|
|
@@ -1,8 +1,9 @@
|
|
|
1
1
|
import types
|
|
2
|
-
from uuid import uuid4
|
|
2
|
+
from uuid import uuid4, uuid5, NAMESPACE_OID
|
|
3
3
|
|
|
4
4
|
import pytest
|
|
5
5
|
|
|
6
|
+
from cognee.modules.search.models.SearchResultPayload import SearchResultPayload
|
|
6
7
|
from cognee.modules.search.types import SearchType
|
|
7
8
|
|
|
8
9
|
|
|
@@ -12,9 +13,9 @@ def _make_user(user_id: str = "u1", tenant_id=None):
|
|
|
12
13
|
|
|
13
14
|
def _make_dataset(*, name="ds", tenant_id="t1", dataset_id=None, owner_id=None):
    """Build a lightweight dataset stand-in for the search tests.

    Args:
        name: Dataset name; also seeds the deterministic default id.
        tenant_id: Tenant label hashed into a deterministic UUID. ``None`` is
            passed through as ``None`` instead of being hashed.
        dataset_id: Explicit id to use in place of the name-derived one.
        owner_id: Owner id; a random UUID is generated when omitted.

    Returns:
        A ``types.SimpleNamespace`` exposing ``id``, ``name``, ``tenant_id``
        and ``owner_id``.
    """
    # Honour an explicit id; otherwise derive a stable one from the name so
    # tests can recompute it without holding on to the namespace object.
    resolved_id = dataset_id if dataset_id is not None else uuid5(NAMESPACE_OID, name)

    # uuid5 cannot hash None, so a tenant-less dataset keeps tenant_id=None.
    resolved_tenant = uuid5(NAMESPACE_OID, tenant_id) if tenant_id is not None else None

    return types.SimpleNamespace(
        id=resolved_id,
        name=name,
        tenant_id=resolved_tenant,
        owner_id=owner_id or uuid4(),
    )
|
|
20
21
|
|
|
@@ -38,16 +39,9 @@ def _patch_side_effect_boundaries(monkeypatch, search_mod):
|
|
|
38
39
|
async def dummy_log_result(*_args, **_kwargs):
|
|
39
40
|
return None
|
|
40
41
|
|
|
41
|
-
async def dummy_prepare_search_result(search_result):
|
|
42
|
-
if isinstance(search_result, tuple) and len(search_result) == 3:
|
|
43
|
-
result, context, datasets = search_result
|
|
44
|
-
return {"result": result, "context": context, "graphs": {}, "datasets": datasets}
|
|
45
|
-
return {"result": None, "context": None, "graphs": {}, "datasets": []}
|
|
46
|
-
|
|
47
42
|
monkeypatch.setattr(search_mod, "send_telemetry", lambda *a, **k: None)
|
|
48
43
|
monkeypatch.setattr(search_mod, "log_query", dummy_log_query)
|
|
49
44
|
monkeypatch.setattr(search_mod, "log_result", dummy_log_result)
|
|
50
|
-
monkeypatch.setattr(search_mod, "prepare_search_result", dummy_prepare_search_result)
|
|
51
45
|
|
|
52
46
|
yield
|
|
53
47
|
|
|
@@ -57,9 +51,19 @@ async def test_search_access_control_returns_dataset_shaped_dicts(monkeypatch, s
|
|
|
57
51
|
user = _make_user()
|
|
58
52
|
ds = _make_dataset(name="ds1", tenant_id="t1")
|
|
59
53
|
|
|
60
|
-
async def dummy_authorized_search(**
|
|
61
|
-
assert
|
|
62
|
-
return [
|
|
54
|
+
async def dummy_authorized_search(**_kwargs):
|
|
55
|
+
assert _kwargs["dataset_ids"] == [ds.id]
|
|
56
|
+
return [
|
|
57
|
+
SearchResultPayload(
|
|
58
|
+
result_object="object",
|
|
59
|
+
context=["ctx"],
|
|
60
|
+
completion=["r"],
|
|
61
|
+
search_type=SearchType.CHUNKS,
|
|
62
|
+
dataset_name=ds.name,
|
|
63
|
+
dataset_id=ds.id,
|
|
64
|
+
dataset_tenant_id=ds.tenant_id,
|
|
65
|
+
)
|
|
66
|
+
]
|
|
63
67
|
|
|
64
68
|
monkeypatch.setattr(search_mod, "backend_access_control_enabled", lambda: True)
|
|
65
69
|
monkeypatch.setattr(search_mod, "authorized_search", dummy_authorized_search)
|
|
@@ -77,24 +81,96 @@ async def test_search_access_control_returns_dataset_shaped_dicts(monkeypatch, s
|
|
|
77
81
|
"search_result": ["r"],
|
|
78
82
|
"dataset_id": ds.id,
|
|
79
83
|
"dataset_name": "ds1",
|
|
80
|
-
"dataset_tenant_id": "t1",
|
|
84
|
+
"dataset_tenant_id": uuid5(NAMESPACE_OID, "t1"),
|
|
81
85
|
}
|
|
82
86
|
]
|
|
83
87
|
|
|
84
|
-
|
|
88
|
+
|
|
89
|
+
@pytest.mark.asyncio
async def test_search_access_control_only_context_returns_dataset_shaped_dicts(
    monkeypatch, search_mod
):
    """With access control on and ``only_context=True``, the payload's context is what
    should appear under ``search_result`` in the dataset-shaped output."""
    user = _make_user()
    ds = _make_dataset(name="ds1", tenant_id="t1")

    # A payload carrying context but no completion.
    context_only_payload = dict(
        result_object="object",
        context=["ctx"],
        completion=None,
        search_type=SearchType.CHUNKS,
        dataset_name=ds.name,
        dataset_id=ds.id,
        dataset_tenant_id=ds.tenant_id,
    )

    async def dummy_authorized_search(**_kwargs):
        return [SearchResultPayload(**context_only_payload)]

    monkeypatch.setattr(search_mod, "backend_access_control_enabled", lambda: True)
    monkeypatch.setattr(search_mod, "authorized_search", dummy_authorized_search)

    out = await search_mod.search(
        query_text="q",
        query_type=SearchType.CHUNKS,
        dataset_ids=[ds.id],
        user=user,
        only_context=True,
    )

    expected_entry = {
        "search_result": ["ctx"],
        "dataset_id": ds.id,
        "dataset_name": "ds1",
        "dataset_tenant_id": uuid5(NAMESPACE_OID, "t1"),
    }
    assert out == [expected_entry]
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
@pytest.mark.asyncio
async def test_authorized_search_delegates_to_search_in_datasets_context(monkeypatch, search_mod):
    """``authorized_search`` should hand back whatever ``search_in_datasets_context`` produces.

    Both the dataset lookup and the per-dataset search are stubbed out, so the
    output is compared against a payload equal to the stubbed one.
    """
    user = _make_user()
    ds = _make_dataset(name="ds1")

    def _build_payload():
        # A fresh object per call keeps the expected and returned payloads distinct.
        return SearchResultPayload(
            result_object="object",
            context="text",
            completion="test",
            search_type=SearchType.CHUNKS,
            dataset_name=ds.name,
            dataset_id=ds.id,
            dataset_tenant_id=ds.tenant_id,
        )

    async def dummy_get_authorized_existing_datasets(*_args, **_kwargs):
        return [ds]

    async def dummy_search_in_datasets_context(**_kwargs):
        return [_build_payload()]

    expected = [_build_payload()]

    monkeypatch.setattr(
        search_mod, "get_authorized_existing_datasets", dummy_get_authorized_existing_datasets
    )
    monkeypatch.setattr(search_mod, "search_in_datasets_context", dummy_search_in_datasets_context)

    out = await search_mod.authorized_search(
        query_type=SearchType.CHUNKS,
        query_text="q",
        user=user,
        dataset_ids=[ds.id],
        only_context=False,
    )

    assert out == expected
|