cognee 0.5.1__py3-none-any.whl → 0.5.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cognee/__init__.py +2 -0
- cognee/alembic/README +1 -0
- cognee/alembic/env.py +107 -0
- cognee/alembic/script.py.mako +26 -0
- cognee/alembic/versions/1a58b986e6e1_enable_delete_for_old_tutorial_notebooks.py +52 -0
- cognee/alembic/versions/1d0bb7fede17_add_pipeline_run_status.py +33 -0
- cognee/alembic/versions/1daae0df1866_incremental_loading.py +48 -0
- cognee/alembic/versions/211ab850ef3d_add_sync_operations_table.py +118 -0
- cognee/alembic/versions/45957f0a9849_add_notebook_table.py +46 -0
- cognee/alembic/versions/46a6ce2bd2b2_expand_dataset_database_with_json_.py +333 -0
- cognee/alembic/versions/482cd6517ce4_add_default_user.py +30 -0
- cognee/alembic/versions/76625596c5c3_expand_dataset_database_for_multi_user.py +98 -0
- cognee/alembic/versions/8057ae7329c2_initial_migration.py +25 -0
- cognee/alembic/versions/9e7a3cb85175_loader_separation.py +104 -0
- cognee/alembic/versions/a1b2c3d4e5f6_add_label_column_to_data.py +38 -0
- cognee/alembic/versions/ab7e313804ae_permission_system_rework.py +236 -0
- cognee/alembic/versions/b9274c27a25a_kuzu_11_migration.py +75 -0
- cognee/alembic/versions/c946955da633_multi_tenant_support.py +137 -0
- cognee/alembic/versions/e1ec1dcb50b6_add_last_accessed_to_data.py +51 -0
- cognee/alembic/versions/e4ebee1091e7_expand_data_model_info.py +140 -0
- cognee/alembic.ini +117 -0
- cognee/api/v1/add/add.py +2 -1
- cognee/api/v1/add/routers/get_add_router.py +2 -0
- cognee/api/v1/cognify/cognify.py +11 -6
- cognee/api/v1/cognify/routers/get_cognify_router.py +8 -0
- cognee/api/v1/config/config.py +60 -0
- cognee/api/v1/datasets/routers/get_datasets_router.py +46 -3
- cognee/api/v1/memify/routers/get_memify_router.py +3 -0
- cognee/api/v1/search/routers/get_search_router.py +21 -6
- cognee/api/v1/search/search.py +21 -5
- cognee/api/v1/sync/routers/get_sync_router.py +3 -3
- cognee/cli/commands/add_command.py +1 -1
- cognee/cli/commands/cognify_command.py +6 -0
- cognee/cli/commands/config_command.py +1 -1
- cognee/context_global_variables.py +5 -1
- cognee/eval_framework/answer_generation/answer_generation_executor.py +7 -8
- cognee/infrastructure/databases/cache/cache_db_interface.py +38 -1
- cognee/infrastructure/databases/cache/config.py +6 -0
- cognee/infrastructure/databases/cache/fscache/FsCacheAdapter.py +21 -0
- cognee/infrastructure/databases/cache/get_cache_engine.py +9 -3
- cognee/infrastructure/databases/cache/redis/RedisAdapter.py +60 -1
- cognee/infrastructure/databases/dataset_database_handler/supported_dataset_database_handlers.py +7 -0
- cognee/infrastructure/databases/graph/get_graph_engine.py +29 -1
- cognee/infrastructure/databases/graph/neo4j_driver/Neo4jAuraDevDatasetDatabaseHandler.py +62 -27
- cognee/infrastructure/databases/hybrid/neptune_analytics/NeptuneAnalyticsAdapter.py +17 -4
- cognee/infrastructure/databases/relational/config.py +16 -1
- cognee/infrastructure/databases/relational/create_relational_engine.py +13 -3
- cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py +26 -3
- cognee/infrastructure/databases/vector/chromadb/ChromaDBAdapter.py +2 -0
- cognee/infrastructure/databases/vector/config.py +6 -0
- cognee/infrastructure/databases/vector/create_vector_engine.py +70 -16
- cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py +64 -9
- cognee/infrastructure/databases/vector/embeddings/OllamaEmbeddingEngine.py +13 -2
- cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py +16 -3
- cognee/infrastructure/databases/vector/models/ScoredResult.py +3 -3
- cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py +16 -3
- cognee/infrastructure/databases/vector/pgvector/PGVectorDatasetDatabaseHandler.py +86 -0
- cognee/infrastructure/databases/vector/pgvector/create_db_and_tables.py +81 -2
- cognee/infrastructure/databases/vector/vector_db_interface.py +8 -0
- cognee/infrastructure/files/utils/get_data_file_path.py +33 -27
- cognee/infrastructure/llm/LLMGateway.py +0 -13
- cognee/infrastructure/llm/prompts/extract_query_time.txt +1 -1
- cognee/infrastructure/llm/prompts/generate_event_entity_prompt.txt +1 -1
- cognee/infrastructure/llm/prompts/generate_event_graph_prompt.txt +1 -1
- cognee/infrastructure/llm/prompts/generate_graph_prompt.txt +2 -2
- cognee/infrastructure/llm/prompts/generate_graph_prompt_guided.txt +1 -1
- cognee/infrastructure/llm/prompts/generate_graph_prompt_oneshot.txt +2 -2
- cognee/infrastructure/llm/prompts/generate_graph_prompt_simple.txt +1 -1
- cognee/infrastructure/llm/prompts/generate_graph_prompt_strict.txt +1 -1
- cognee/infrastructure/llm/prompts/search_type_selector_prompt.txt +6 -6
- cognee/infrastructure/llm/prompts/test.txt +1 -1
- cognee/infrastructure/llm/prompts/translate_content.txt +19 -0
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py +17 -12
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py +31 -25
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py +132 -7
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py +29 -5
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/llama_cpp/adapter.py +191 -0
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/llm_interface.py +2 -6
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py +58 -13
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/adapter.py +0 -1
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py +25 -131
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/types.py +10 -0
- cognee/modules/chunking/models/DocumentChunk.py +0 -1
- cognee/modules/cognify/config.py +2 -0
- cognee/modules/data/models/Data.py +3 -1
- cognee/modules/engine/models/Entity.py +0 -1
- cognee/modules/engine/operations/setup.py +6 -0
- cognee/modules/graph/cognee_graph/CogneeGraph.py +150 -37
- cognee/modules/graph/cognee_graph/CogneeGraphElements.py +48 -2
- cognee/modules/graph/utils/__init__.py +1 -0
- cognee/modules/graph/utils/get_entity_nodes_from_triplets.py +12 -0
- cognee/modules/notebooks/methods/__init__.py +1 -0
- cognee/modules/notebooks/methods/create_notebook.py +0 -34
- cognee/modules/notebooks/methods/create_tutorial_notebooks.py +191 -0
- cognee/modules/notebooks/methods/get_notebooks.py +12 -8
- cognee/modules/notebooks/tutorials/cognee-basics/cell-1.md +3 -0
- cognee/modules/notebooks/tutorials/cognee-basics/cell-2.md +10 -0
- cognee/modules/notebooks/tutorials/cognee-basics/cell-3.md +7 -0
- cognee/modules/notebooks/tutorials/cognee-basics/cell-4.py +28 -0
- cognee/modules/notebooks/tutorials/cognee-basics/cell-5.py +3 -0
- cognee/modules/notebooks/tutorials/cognee-basics/cell-6.py +9 -0
- cognee/modules/notebooks/tutorials/cognee-basics/cell-7.py +17 -0
- cognee/modules/notebooks/tutorials/cognee-basics/config.json +4 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-1.md +3 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-10.md +3 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-11.md +3 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-12.py +3 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-13.md +7 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-14.py +6 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-15.md +3 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-16.py +7 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-2.md +9 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-3.md +7 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-4.md +9 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-5.md +5 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-6.py +13 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-7.md +3 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-8.md +3 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-9.py +31 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/config.json +4 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/data/copilot_conversations.json +107 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/data/guido_contributions.json +976 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/data/my_developer_rules.md +79 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/data/pep_style_guide.md +74 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/data/zen_principles.md +74 -0
- cognee/modules/retrieval/EntityCompletionRetriever.py +51 -38
- cognee/modules/retrieval/__init__.py +0 -1
- cognee/modules/retrieval/base_retriever.py +66 -10
- cognee/modules/retrieval/chunks_retriever.py +57 -49
- cognee/modules/retrieval/coding_rules_retriever.py +12 -5
- cognee/modules/retrieval/completion_retriever.py +29 -28
- cognee/modules/retrieval/cypher_search_retriever.py +25 -20
- cognee/modules/retrieval/graph_completion_context_extension_retriever.py +42 -46
- cognee/modules/retrieval/graph_completion_cot_retriever.py +68 -51
- cognee/modules/retrieval/graph_completion_retriever.py +78 -63
- cognee/modules/retrieval/graph_summary_completion_retriever.py +2 -0
- cognee/modules/retrieval/lexical_retriever.py +34 -12
- cognee/modules/retrieval/natural_language_retriever.py +18 -15
- cognee/modules/retrieval/summaries_retriever.py +51 -34
- cognee/modules/retrieval/temporal_retriever.py +59 -49
- cognee/modules/retrieval/triplet_retriever.py +32 -33
- cognee/modules/retrieval/utils/access_tracking.py +88 -0
- cognee/modules/retrieval/utils/brute_force_triplet_search.py +99 -103
- cognee/modules/retrieval/utils/node_edge_vector_search.py +174 -0
- cognee/modules/search/methods/__init__.py +1 -0
- cognee/modules/search/methods/get_retriever_output.py +53 -0
- cognee/modules/search/methods/get_search_type_retriever_instance.py +252 -0
- cognee/modules/search/methods/search.py +90 -222
- cognee/modules/search/models/SearchResultPayload.py +67 -0
- cognee/modules/search/types/SearchResult.py +1 -8
- cognee/modules/search/types/SearchType.py +1 -2
- cognee/modules/search/types/__init__.py +1 -1
- cognee/modules/search/utils/__init__.py +1 -2
- cognee/modules/search/utils/transform_insights_to_graph.py +2 -2
- cognee/modules/search/utils/{transform_context_to_graph.py → transform_triplets_to_graph.py} +2 -2
- cognee/modules/users/authentication/default/default_transport.py +11 -1
- cognee/modules/users/authentication/get_api_auth_backend.py +2 -1
- cognee/modules/users/authentication/get_client_auth_backend.py +2 -1
- cognee/modules/users/methods/create_user.py +0 -9
- cognee/modules/users/permissions/methods/has_user_management_permission.py +29 -0
- cognee/modules/visualization/cognee_network_visualization.py +1 -1
- cognee/run_migrations.py +48 -0
- cognee/shared/exceptions/__init__.py +1 -3
- cognee/shared/exceptions/exceptions.py +11 -1
- cognee/shared/usage_logger.py +332 -0
- cognee/shared/utils.py +12 -5
- cognee/tasks/chunks/__init__.py +9 -0
- cognee/tasks/cleanup/cleanup_unused_data.py +172 -0
- cognee/tasks/graph/__init__.py +7 -0
- cognee/tasks/ingestion/data_item.py +8 -0
- cognee/tasks/ingestion/ingest_data.py +12 -1
- cognee/tasks/ingestion/save_data_item_to_storage.py +5 -0
- cognee/tasks/memify/__init__.py +8 -0
- cognee/tasks/memify/extract_usage_frequency.py +613 -0
- cognee/tasks/summarization/models.py +0 -2
- cognee/tasks/temporal_graph/__init__.py +0 -1
- cognee/tasks/translation/__init__.py +96 -0
- cognee/tasks/translation/config.py +110 -0
- cognee/tasks/translation/detect_language.py +190 -0
- cognee/tasks/translation/exceptions.py +62 -0
- cognee/tasks/translation/models.py +72 -0
- cognee/tasks/translation/providers/__init__.py +44 -0
- cognee/tasks/translation/providers/azure_provider.py +192 -0
- cognee/tasks/translation/providers/base.py +85 -0
- cognee/tasks/translation/providers/google_provider.py +158 -0
- cognee/tasks/translation/providers/llm_provider.py +143 -0
- cognee/tasks/translation/translate_content.py +282 -0
- cognee/tasks/web_scraper/default_url_crawler.py +6 -2
- cognee/tests/cli_tests/cli_unit_tests/test_cli_commands.py +1 -0
- cognee/tests/cli_tests/cli_unit_tests/test_cli_edge_cases.py +3 -0
- cognee/tests/integration/retrieval/test_brute_force_triplet_search_with_cognify.py +62 -0
- cognee/tests/integration/retrieval/test_chunks_retriever.py +351 -0
- cognee/tests/integration/retrieval/test_graph_completion_retriever.py +276 -0
- cognee/tests/integration/retrieval/test_graph_completion_retriever_context_extension.py +228 -0
- cognee/tests/integration/retrieval/test_graph_completion_retriever_cot.py +217 -0
- cognee/tests/integration/retrieval/test_rag_completion_retriever.py +319 -0
- cognee/tests/integration/retrieval/test_structured_output.py +258 -0
- cognee/tests/integration/retrieval/test_summaries_retriever.py +195 -0
- cognee/tests/integration/retrieval/test_temporal_retriever.py +336 -0
- cognee/tests/integration/retrieval/test_triplet_retriever.py +45 -1
- cognee/tests/integration/shared/test_usage_logger_integration.py +255 -0
- cognee/tests/tasks/translation/README.md +147 -0
- cognee/tests/tasks/translation/__init__.py +1 -0
- cognee/tests/tasks/translation/config_test.py +93 -0
- cognee/tests/tasks/translation/detect_language_test.py +118 -0
- cognee/tests/tasks/translation/providers_test.py +151 -0
- cognee/tests/tasks/translation/translate_content_test.py +213 -0
- cognee/tests/test_chromadb.py +1 -1
- cognee/tests/test_cleanup_unused_data.py +165 -0
- cognee/tests/test_custom_data_label.py +68 -0
- cognee/tests/test_delete_by_id.py +6 -6
- cognee/tests/test_extract_usage_frequency.py +308 -0
- cognee/tests/test_kuzu.py +17 -7
- cognee/tests/test_lancedb.py +3 -1
- cognee/tests/test_library.py +1 -1
- cognee/tests/test_neo4j.py +17 -7
- cognee/tests/test_neptune_analytics_vector.py +3 -1
- cognee/tests/test_permissions.py +172 -187
- cognee/tests/test_pgvector.py +3 -1
- cognee/tests/test_relational_db_migration.py +15 -1
- cognee/tests/test_remote_kuzu.py +3 -1
- cognee/tests/test_s3_file_storage.py +1 -1
- cognee/tests/test_search_db.py +345 -205
- cognee/tests/test_usage_logger_e2e.py +268 -0
- cognee/tests/unit/api/test_get_raw_data_endpoint.py +206 -0
- cognee/tests/unit/eval_framework/answer_generation_test.py +4 -3
- cognee/tests/unit/eval_framework/benchmark_adapters_test.py +25 -0
- cognee/tests/unit/eval_framework/corpus_builder_test.py +33 -4
- cognee/tests/unit/infrastructure/databases/cache/test_cache_config.py +2 -0
- cognee/tests/unit/infrastructure/databases/relational/test_RelationalConfig.py +69 -0
- cognee/tests/unit/modules/graph/cognee_graph_elements_test.py +42 -2
- cognee/tests/unit/modules/graph/cognee_graph_test.py +329 -31
- cognee/tests/unit/modules/retrieval/chunks_retriever_test.py +122 -168
- cognee/tests/unit/modules/retrieval/conversation_history_test.py +338 -0
- cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py +486 -157
- cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py +693 -155
- cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py +619 -200
- cognee/tests/unit/modules/retrieval/rag_completion_retriever_test.py +300 -171
- cognee/tests/unit/modules/retrieval/summaries_retriever_test.py +184 -155
- cognee/tests/unit/modules/retrieval/temporal_retriever_test.py +544 -79
- cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py +476 -28
- cognee/tests/unit/modules/retrieval/test_completion.py +343 -0
- cognee/tests/unit/modules/retrieval/test_graph_summary_completion_retriever.py +157 -0
- cognee/tests/unit/modules/retrieval/test_node_edge_vector_search.py +273 -0
- cognee/tests/unit/modules/retrieval/test_user_qa_feedback.py +312 -0
- cognee/tests/unit/modules/retrieval/triplet_retriever_test.py +267 -7
- cognee/tests/unit/modules/search/test_get_search_type_retriever_instance.py +125 -0
- cognee/tests/unit/modules/search/test_search.py +96 -20
- cognee/tests/unit/modules/search/test_search_prepare_search_result_contract.py +190 -0
- cognee/tests/unit/modules/users/test_tutorial_notebook_creation.py +511 -297
- cognee/tests/unit/shared/test_usage_logger.py +241 -0
- cognee/tests/unit/users/permissions/test_has_user_management_permission.py +46 -0
- {cognee-0.5.1.dist-info → cognee-0.5.2.dist-info}/METADATA +22 -17
- {cognee-0.5.1.dist-info → cognee-0.5.2.dist-info}/RECORD +258 -157
- cognee/api/.env.example +0 -5
- cognee/modules/retrieval/base_graph_retriever.py +0 -24
- cognee/modules/search/methods/get_search_type_tools.py +0 -223
- cognee/modules/search/methods/no_access_control_search.py +0 -62
- cognee/modules/search/utils/prepare_search_result.py +0 -63
- cognee/tests/test_feedback_enrichment.py +0 -174
- cognee/tests/unit/modules/retrieval/structured_output_test.py +0 -204
- {cognee-0.5.1.dist-info → cognee-0.5.2.dist-info}/WHEEL +0 -0
- {cognee-0.5.1.dist-info → cognee-0.5.2.dist-info}/entry_points.txt +0 -0
- {cognee-0.5.1.dist-info → cognee-0.5.2.dist-info}/licenses/LICENSE +0 -0
- {cognee-0.5.1.dist-info → cognee-0.5.2.dist-info}/licenses/NOTICE.md +0 -0
|
@@ -1,12 +1,15 @@
|
|
|
1
1
|
import pytest
|
|
2
|
-
from unittest.mock import AsyncMock, patch
|
|
2
|
+
from unittest.mock import AsyncMock, patch, MagicMock
|
|
3
3
|
|
|
4
4
|
from cognee.modules.retrieval.utils.brute_force_triplet_search import (
|
|
5
5
|
brute_force_triplet_search,
|
|
6
6
|
get_memory_fragment,
|
|
7
|
+
format_triplets,
|
|
7
8
|
)
|
|
9
|
+
from cognee.modules.engine.utils.generate_edge_id import generate_edge_id
|
|
8
10
|
from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph
|
|
9
11
|
from cognee.modules.graph.exceptions.exceptions import EntityNotFoundError
|
|
12
|
+
from cognee.infrastructure.databases.vector.exceptions.exceptions import CollectionNotFoundError
|
|
10
13
|
|
|
11
14
|
|
|
12
15
|
class MockScoredResult:
|
|
@@ -28,7 +31,7 @@ async def test_brute_force_triplet_search_empty_query():
|
|
|
28
31
|
@pytest.mark.asyncio
|
|
29
32
|
async def test_brute_force_triplet_search_none_query():
|
|
30
33
|
"""Test that None query raises ValueError."""
|
|
31
|
-
with pytest.raises(ValueError, match="
|
|
34
|
+
with pytest.raises(ValueError, match="Must provide either 'query' or 'query_batch'."):
|
|
32
35
|
await brute_force_triplet_search(query=None)
|
|
33
36
|
|
|
34
37
|
|
|
@@ -55,7 +58,7 @@ async def test_brute_force_triplet_search_wide_search_limit_global_search():
|
|
|
55
58
|
mock_vector_engine.search = AsyncMock(return_value=[])
|
|
56
59
|
|
|
57
60
|
with patch(
|
|
58
|
-
"cognee.modules.retrieval.utils.
|
|
61
|
+
"cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine",
|
|
59
62
|
return_value=mock_vector_engine,
|
|
60
63
|
):
|
|
61
64
|
await brute_force_triplet_search(
|
|
@@ -77,7 +80,7 @@ async def test_brute_force_triplet_search_wide_search_limit_filtered_search():
|
|
|
77
80
|
mock_vector_engine.search = AsyncMock(return_value=[])
|
|
78
81
|
|
|
79
82
|
with patch(
|
|
80
|
-
"cognee.modules.retrieval.utils.
|
|
83
|
+
"cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine",
|
|
81
84
|
return_value=mock_vector_engine,
|
|
82
85
|
):
|
|
83
86
|
await brute_force_triplet_search(
|
|
@@ -99,7 +102,7 @@ async def test_brute_force_triplet_search_wide_search_default():
|
|
|
99
102
|
mock_vector_engine.search = AsyncMock(return_value=[])
|
|
100
103
|
|
|
101
104
|
with patch(
|
|
102
|
-
"cognee.modules.retrieval.utils.
|
|
105
|
+
"cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine",
|
|
103
106
|
return_value=mock_vector_engine,
|
|
104
107
|
):
|
|
105
108
|
await brute_force_triplet_search(query="test", node_name=None)
|
|
@@ -117,7 +120,7 @@ async def test_brute_force_triplet_search_default_collections():
|
|
|
117
120
|
mock_vector_engine.search = AsyncMock(return_value=[])
|
|
118
121
|
|
|
119
122
|
with patch(
|
|
120
|
-
"cognee.modules.retrieval.utils.
|
|
123
|
+
"cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine",
|
|
121
124
|
return_value=mock_vector_engine,
|
|
122
125
|
):
|
|
123
126
|
await brute_force_triplet_search(query="test")
|
|
@@ -147,7 +150,7 @@ async def test_brute_force_triplet_search_custom_collections():
|
|
|
147
150
|
custom_collections = ["CustomCol1", "CustomCol2"]
|
|
148
151
|
|
|
149
152
|
with patch(
|
|
150
|
-
"cognee.modules.retrieval.utils.
|
|
153
|
+
"cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine",
|
|
151
154
|
return_value=mock_vector_engine,
|
|
152
155
|
):
|
|
153
156
|
await brute_force_triplet_search(query="test", collections=custom_collections)
|
|
@@ -169,7 +172,7 @@ async def test_brute_force_triplet_search_always_includes_edge_collection():
|
|
|
169
172
|
collections_without_edge = ["Entity_name", "TextSummary_text"]
|
|
170
173
|
|
|
171
174
|
with patch(
|
|
172
|
-
"cognee.modules.retrieval.utils.
|
|
175
|
+
"cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine",
|
|
173
176
|
return_value=mock_vector_engine,
|
|
174
177
|
):
|
|
175
178
|
await brute_force_triplet_search(query="test", collections=collections_without_edge)
|
|
@@ -192,7 +195,7 @@ async def test_brute_force_triplet_search_all_collections_empty():
|
|
|
192
195
|
mock_vector_engine.search = AsyncMock(return_value=[])
|
|
193
196
|
|
|
194
197
|
with patch(
|
|
195
|
-
"cognee.modules.retrieval.utils.
|
|
198
|
+
"cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine",
|
|
196
199
|
return_value=mock_vector_engine,
|
|
197
200
|
):
|
|
198
201
|
results = await brute_force_triplet_search(query="test")
|
|
@@ -214,7 +217,7 @@ async def test_brute_force_triplet_search_embeds_query():
|
|
|
214
217
|
mock_vector_engine.search = AsyncMock(return_value=[])
|
|
215
218
|
|
|
216
219
|
with patch(
|
|
217
|
-
"cognee.modules.retrieval.utils.
|
|
220
|
+
"cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine",
|
|
218
221
|
return_value=mock_vector_engine,
|
|
219
222
|
):
|
|
220
223
|
await brute_force_triplet_search(query=query_text)
|
|
@@ -247,7 +250,7 @@ async def test_brute_force_triplet_search_extracts_node_ids_global_search():
|
|
|
247
250
|
|
|
248
251
|
with (
|
|
249
252
|
patch(
|
|
250
|
-
"cognee.modules.retrieval.utils.
|
|
253
|
+
"cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine",
|
|
251
254
|
return_value=mock_vector_engine,
|
|
252
255
|
),
|
|
253
256
|
patch(
|
|
@@ -277,7 +280,7 @@ async def test_brute_force_triplet_search_reuses_provided_fragment():
|
|
|
277
280
|
|
|
278
281
|
with (
|
|
279
282
|
patch(
|
|
280
|
-
"cognee.modules.retrieval.utils.
|
|
283
|
+
"cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine",
|
|
281
284
|
return_value=mock_vector_engine,
|
|
282
285
|
),
|
|
283
286
|
patch(
|
|
@@ -309,7 +312,7 @@ async def test_brute_force_triplet_search_creates_fragment_when_not_provided():
|
|
|
309
312
|
|
|
310
313
|
with (
|
|
311
314
|
patch(
|
|
312
|
-
"cognee.modules.retrieval.utils.
|
|
315
|
+
"cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine",
|
|
313
316
|
return_value=mock_vector_engine,
|
|
314
317
|
),
|
|
315
318
|
patch(
|
|
@@ -338,7 +341,7 @@ async def test_brute_force_triplet_search_passes_top_k_to_importance_calculation
|
|
|
338
341
|
|
|
339
342
|
with (
|
|
340
343
|
patch(
|
|
341
|
-
"cognee.modules.retrieval.utils.
|
|
344
|
+
"cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine",
|
|
342
345
|
return_value=mock_vector_engine,
|
|
343
346
|
),
|
|
344
347
|
patch(
|
|
@@ -349,25 +352,37 @@ async def test_brute_force_triplet_search_passes_top_k_to_importance_calculation
|
|
|
349
352
|
custom_top_k = 15
|
|
350
353
|
await brute_force_triplet_search(query="test", top_k=custom_top_k, node_name=["n"])
|
|
351
354
|
|
|
352
|
-
mock_fragment.calculate_top_triplet_importances.assert_called_once_with(
|
|
355
|
+
mock_fragment.calculate_top_triplet_importances.assert_called_once_with(
|
|
356
|
+
k=custom_top_k, query_list_length=None
|
|
357
|
+
)
|
|
353
358
|
|
|
354
359
|
|
|
355
360
|
@pytest.mark.asyncio
|
|
356
361
|
async def test_get_memory_fragment_returns_empty_graph_on_entity_not_found():
|
|
357
|
-
"""Test that get_memory_fragment returns empty graph when entity not found."""
|
|
362
|
+
"""Test that get_memory_fragment returns empty graph when entity not found (line 85)."""
|
|
358
363
|
mock_graph_engine = AsyncMock()
|
|
359
|
-
|
|
364
|
+
|
|
365
|
+
# Create a mock fragment that will raise EntityNotFoundError when project_graph_from_db is called
|
|
366
|
+
mock_fragment = MagicMock(spec=CogneeGraph)
|
|
367
|
+
mock_fragment.project_graph_from_db = AsyncMock(
|
|
360
368
|
side_effect=EntityNotFoundError("Entity not found")
|
|
361
369
|
)
|
|
362
370
|
|
|
363
|
-
with
|
|
364
|
-
|
|
365
|
-
|
|
371
|
+
with (
|
|
372
|
+
patch(
|
|
373
|
+
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_graph_engine",
|
|
374
|
+
return_value=mock_graph_engine,
|
|
375
|
+
),
|
|
376
|
+
patch(
|
|
377
|
+
"cognee.modules.retrieval.utils.brute_force_triplet_search.CogneeGraph",
|
|
378
|
+
return_value=mock_fragment,
|
|
379
|
+
),
|
|
366
380
|
):
|
|
367
|
-
|
|
381
|
+
result = await get_memory_fragment()
|
|
368
382
|
|
|
369
|
-
|
|
370
|
-
assert
|
|
383
|
+
# Fragment should be returned even though EntityNotFoundError was raised (pass statement on line 85)
|
|
384
|
+
assert result == mock_fragment
|
|
385
|
+
mock_fragment.project_graph_from_db.assert_awaited_once()
|
|
371
386
|
|
|
372
387
|
|
|
373
388
|
@pytest.mark.asyncio
|
|
@@ -418,7 +433,7 @@ async def test_brute_force_triplet_search_deduplicates_node_ids():
|
|
|
418
433
|
|
|
419
434
|
with (
|
|
420
435
|
patch(
|
|
421
|
-
"cognee.modules.retrieval.utils.
|
|
436
|
+
"cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine",
|
|
422
437
|
return_value=mock_vector_engine,
|
|
423
438
|
),
|
|
424
439
|
patch(
|
|
@@ -459,7 +474,7 @@ async def test_brute_force_triplet_search_excludes_edge_collection():
|
|
|
459
474
|
|
|
460
475
|
with (
|
|
461
476
|
patch(
|
|
462
|
-
"cognee.modules.retrieval.utils.
|
|
477
|
+
"cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine",
|
|
463
478
|
return_value=mock_vector_engine,
|
|
464
479
|
),
|
|
465
480
|
patch(
|
|
@@ -511,7 +526,7 @@ async def test_brute_force_triplet_search_skips_nodes_without_ids():
|
|
|
511
526
|
|
|
512
527
|
with (
|
|
513
528
|
patch(
|
|
514
|
-
"cognee.modules.retrieval.utils.
|
|
529
|
+
"cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine",
|
|
515
530
|
return_value=mock_vector_engine,
|
|
516
531
|
),
|
|
517
532
|
patch(
|
|
@@ -552,7 +567,7 @@ async def test_brute_force_triplet_search_handles_tuple_results():
|
|
|
552
567
|
|
|
553
568
|
with (
|
|
554
569
|
patch(
|
|
555
|
-
"cognee.modules.retrieval.utils.
|
|
570
|
+
"cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine",
|
|
556
571
|
return_value=mock_vector_engine,
|
|
557
572
|
),
|
|
558
573
|
patch(
|
|
@@ -594,7 +609,7 @@ async def test_brute_force_triplet_search_mixed_empty_collections():
|
|
|
594
609
|
|
|
595
610
|
with (
|
|
596
611
|
patch(
|
|
597
|
-
"cognee.modules.retrieval.utils.
|
|
612
|
+
"cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine",
|
|
598
613
|
return_value=mock_vector_engine,
|
|
599
614
|
),
|
|
600
615
|
patch(
|
|
@@ -606,3 +621,436 @@ async def test_brute_force_triplet_search_mixed_empty_collections():
|
|
|
606
621
|
|
|
607
622
|
call_kwargs = mock_get_fragment_fn.call_args[1]
|
|
608
623
|
assert set(call_kwargs["relevant_ids_to_filter"]) == {"node1", "node2"}
|
|
624
|
+
|
|
625
|
+
|
|
626
|
+
def test_format_triplets():
|
|
627
|
+
"""Test format_triplets function."""
|
|
628
|
+
mock_edge = MagicMock()
|
|
629
|
+
mock_node1 = MagicMock()
|
|
630
|
+
mock_node2 = MagicMock()
|
|
631
|
+
|
|
632
|
+
mock_node1.attributes = {"name": "Node1", "type": "Entity", "id": "n1"}
|
|
633
|
+
mock_node2.attributes = {"name": "Node2", "type": "Entity", "id": "n2"}
|
|
634
|
+
mock_edge.attributes = {"relationship_name": "relates_to", "edge_text": "connects"}
|
|
635
|
+
|
|
636
|
+
mock_edge.node1 = mock_node1
|
|
637
|
+
mock_edge.node2 = mock_node2
|
|
638
|
+
|
|
639
|
+
result = format_triplets([mock_edge])
|
|
640
|
+
|
|
641
|
+
assert isinstance(result, str)
|
|
642
|
+
assert "Node1" in result
|
|
643
|
+
assert "Node2" in result
|
|
644
|
+
assert "relates_to" in result
|
|
645
|
+
assert "connects" in result
|
|
646
|
+
|
|
647
|
+
|
|
648
|
+
def test_format_triplets_with_none_values():
|
|
649
|
+
"""Test format_triplets filters out None values."""
|
|
650
|
+
mock_edge = MagicMock()
|
|
651
|
+
mock_node1 = MagicMock()
|
|
652
|
+
mock_node2 = MagicMock()
|
|
653
|
+
|
|
654
|
+
mock_node1.attributes = {"name": "Node1", "type": None, "id": "n1"}
|
|
655
|
+
mock_node2.attributes = {"name": "Node2", "type": "Entity", "id": None}
|
|
656
|
+
mock_edge.attributes = {"relationship_name": "relates_to", "edge_text": None}
|
|
657
|
+
|
|
658
|
+
mock_edge.node1 = mock_node1
|
|
659
|
+
mock_edge.node2 = mock_node2
|
|
660
|
+
|
|
661
|
+
result = format_triplets([mock_edge])
|
|
662
|
+
|
|
663
|
+
assert "Node1" in result
|
|
664
|
+
assert "Node2" in result
|
|
665
|
+
assert "relates_to" in result
|
|
666
|
+
assert "None" not in result or result.count("None") == 0
|
|
667
|
+
|
|
668
|
+
|
|
669
|
+
def test_format_triplets_with_nested_dict():
|
|
670
|
+
"""Test format_triplets handles nested dict attributes (lines 23-35)."""
|
|
671
|
+
mock_edge = MagicMock()
|
|
672
|
+
mock_node1 = MagicMock()
|
|
673
|
+
mock_node2 = MagicMock()
|
|
674
|
+
|
|
675
|
+
mock_node1.attributes = {"name": "Node1", "metadata": {"type": "Entity", "id": "n1"}}
|
|
676
|
+
mock_node2.attributes = {"name": "Node2", "metadata": {"type": "Entity", "id": "n2"}}
|
|
677
|
+
mock_edge.attributes = {"relationship_name": "relates_to"}
|
|
678
|
+
|
|
679
|
+
mock_edge.node1 = mock_node1
|
|
680
|
+
mock_edge.node2 = mock_node2
|
|
681
|
+
|
|
682
|
+
result = format_triplets([mock_edge])
|
|
683
|
+
|
|
684
|
+
assert isinstance(result, str)
|
|
685
|
+
assert "Node1" in result
|
|
686
|
+
assert "Node2" in result
|
|
687
|
+
assert "relates_to" in result
|
|
688
|
+
|
|
689
|
+
|
|
690
|
+
@pytest.mark.asyncio
|
|
691
|
+
async def test_brute_force_triplet_search_vector_engine_init_error():
|
|
692
|
+
"""Test brute_force_triplet_search handles vector engine initialization error (lines 145-147)."""
|
|
693
|
+
with (
|
|
694
|
+
patch(
|
|
695
|
+
"cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine"
|
|
696
|
+
) as mock_get_vector_engine,
|
|
697
|
+
):
|
|
698
|
+
mock_get_vector_engine.side_effect = Exception("Initialization error")
|
|
699
|
+
|
|
700
|
+
with pytest.raises(RuntimeError, match="Initialization error"):
|
|
701
|
+
await brute_force_triplet_search(query="test query")
|
|
702
|
+
|
|
703
|
+
|
|
704
|
+
@pytest.mark.asyncio
async def test_brute_force_triplet_search_collection_not_found_error():
    """A missing collection during vector search is tolerated.

    The first ``search`` call raises ``CollectionNotFoundError`` and the later
    calls return no hits. The search must still complete and yield ``[]``.
    """
    search_outcomes = [CollectionNotFoundError("Collection not found"), [], []]
    embedder = AsyncMock(embed_text=AsyncMock(return_value=[[0.1, 0.2, 0.3]]))
    engine = AsyncMock(
        embedding_engine=embedder,
        search=AsyncMock(side_effect=search_outcomes),
    )
    empty_graph = CogneeGraph()

    engine_patch = patch(
        "cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine",
        return_value=engine,
    )
    fragment_patch = patch(
        "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
        return_value=empty_graph,
    )

    with engine_patch, fragment_patch:
        result = await brute_force_triplet_search(
            query="test query", collections=["missing_collection", "existing_collection"]
        )

    assert result == []
|
|
735
|
+
|
|
736
|
+
|
|
737
|
+
@pytest.mark.asyncio
async def test_brute_force_triplet_search_generic_exception():
    """An unexpected error from vector search reaches the caller.

    ``search`` raises a plain ``Exception("Generic error")``. The caller must
    see an exception whose message matches "Generic error".
    """
    failing_engine = AsyncMock(
        embedding_engine=AsyncMock(embed_text=AsyncMock(return_value=[[0.1, 0.2, 0.3]])),
        search=AsyncMock(side_effect=Exception("Generic error")),
    )
    engine_target = "cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine"

    # The patch is entered first (outermost), then pytest.raises, the same
    # nesting as two separate ``with`` statements.
    with patch(engine_target, return_value=failing_engine), pytest.raises(
        Exception, match="Generic error"
    ):
        await brute_force_triplet_search(query="test query")
|
|
755
|
+
|
|
756
|
+
|
|
757
|
+
@pytest.mark.asyncio
async def test_brute_force_triplet_search_with_node_name_sets_relevant_ids_to_none():
    """With ``node_name`` given, ``get_memory_fragment`` receives no ID pre-filter.

    Checks that ``relevant_ids_to_filter`` is passed explicitly as a keyword
    argument and that its value is ``None``. An omitted keyword fails the test
    rather than being read as ``None``.
    """
    mock_vector_engine = AsyncMock()
    mock_embedding_engine = AsyncMock()
    mock_vector_engine.embedding_engine = mock_embedding_engine
    mock_embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])

    mock_result = MockScoredResult(id="node1", score=0.8, payload={"id": "node1"})
    mock_vector_engine.search = AsyncMock(return_value=[mock_result])

    mock_fragment = AsyncMock()
    mock_fragment.map_vector_distances_to_graph_nodes = AsyncMock()
    mock_fragment.map_vector_distances_to_graph_edges = AsyncMock()
    mock_fragment.calculate_top_triplet_importances = AsyncMock(return_value=[])

    with (
        patch(
            "cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine",
            return_value=mock_vector_engine,
        ),
        patch(
            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
            return_value=mock_fragment,
        ) as mock_get_fragment,
    ):
        await brute_force_triplet_search(query="test query", node_name=["Node1"])

    mock_get_fragment.assert_called_once()
    call_kwargs = mock_get_fragment.call_args.kwargs
    # Index strictly, as the sibling tests do, rather than using .get(). A
    # missing keyword would make .get() return None and the test would pass
    # vacuously.
    assert "relevant_ids_to_filter" in call_kwargs
    assert call_kwargs["relevant_ids_to_filter"] is None
|
|
788
|
+
|
|
789
|
+
|
|
790
|
+
@pytest.mark.asyncio
async def test_brute_force_triplet_search_collection_not_found_at_top_level():
    """A CollectionNotFoundError raised while ranking triplets yields ``[]``.

    Vector search succeeds with one hit, but the memory fragment's
    ``calculate_top_triplet_importances`` raises ``CollectionNotFoundError``.
    The search is expected to return ``[]`` instead of raising.
    """
    hit = MockScoredResult(id="node1", score=0.8, payload={"id": "node1"})
    engine = AsyncMock(
        embedding_engine=AsyncMock(embed_text=AsyncMock(return_value=[[0.1, 0.2, 0.3]])),
        search=AsyncMock(return_value=[hit]),
    )
    fragment = AsyncMock(
        map_vector_distances_to_graph_nodes=AsyncMock(),
        map_vector_distances_to_graph_edges=AsyncMock(),
        calculate_top_triplet_importances=AsyncMock(
            side_effect=CollectionNotFoundError("Collection not found")
        ),
    )

    with patch(
        "cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine",
        return_value=engine,
    ), patch(
        "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
        return_value=fragment,
    ):
        result = await brute_force_triplet_search(query="test query")

    assert result == []
|
|
821
|
+
|
|
822
|
+
|
|
823
|
+
@pytest.mark.asyncio
async def test_brute_force_triplet_search_single_query_regression():
    """Single-query mode keeps the legacy contract.

    The result is a flat list (its first element, if any, is not itself a
    list), and ``get_memory_fragment`` receives a non-``None``
    ``relevant_ids_to_filter``.
    """
    engine = AsyncMock(
        embedding_engine=AsyncMock(embed_text=AsyncMock(return_value=[[0.1, 0.2, 0.3]])),
        search=AsyncMock(return_value=[MockScoredResult("node1", 0.95)]),
    )
    fragment = AsyncMock(
        map_vector_distances_to_graph_nodes=AsyncMock(),
        map_vector_distances_to_graph_edges=AsyncMock(),
        calculate_top_triplet_importances=AsyncMock(return_value=[]),
    )

    engine_patch = patch(
        "cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine",
        return_value=engine,
    )
    fragment_patch = patch(
        "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
        return_value=fragment,
    )

    with engine_patch, fragment_patch as fragment_factory:
        result = await brute_force_triplet_search(
            query="q1", query_batch=None, wide_search_top_k=10, node_name=None
        )

    assert isinstance(result, list)
    # Flat list: either empty, or its first element is not a nested list.
    assert not result or not isinstance(result[0], list)

    fragment_factory.assert_called_once()
    assert fragment_factory.call_args.kwargs["relevant_ids_to_filter"] is not None
|
|
856
|
+
|
|
857
|
+
|
|
858
|
+
@pytest.mark.asyncio
async def test_brute_force_triplet_search_batch_wiring_happy_path():
    """Batch mode returns one result list per query and skips ID filtering.

    With two queries, the result must be a list of two lists, and
    ``get_memory_fragment`` must receive ``relevant_ids_to_filter=None``.
    """
    per_query_hits = [
        [MockScoredResult("node1", 0.95)],
        [MockScoredResult("node2", 0.87)],
    ]
    engine = AsyncMock(
        embedding_engine=AsyncMock(),
        batch_search=AsyncMock(return_value=per_query_hits),
    )
    fragment = AsyncMock(
        map_vector_distances_to_graph_nodes=AsyncMock(),
        map_vector_distances_to_graph_edges=AsyncMock(),
        calculate_top_triplet_importances=AsyncMock(return_value=[[], []]),
    )

    with patch(
        "cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine",
        return_value=engine,
    ), patch(
        "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
        return_value=fragment,
    ) as fragment_factory:
        result = await brute_force_triplet_search(query_batch=["q1", "q2"])

    assert isinstance(result, list)
    assert len(result) == 2
    assert all(isinstance(per_query, list) for per_query in result)

    fragment_factory.assert_called_once()
    assert fragment_factory.call_args.kwargs["relevant_ids_to_filter"] is None
|
|
895
|
+
|
|
896
|
+
|
|
897
|
+
@pytest.mark.asyncio
async def test_brute_force_triplet_search_shape_propagation_to_graph():
    """The batch size reaches every graph-side method as ``query_list_length``.

    For a two-query batch, the node mapping, edge mapping and triplet-importance
    methods must each be called exactly once with ``query_list_length=2``.
    """
    engine = AsyncMock(
        embedding_engine=AsyncMock(),
        batch_search=AsyncMock(
            return_value=[
                [MockScoredResult("node1", 0.95)],
                [MockScoredResult("node2", 0.87)],
            ]
        ),
    )
    fragment = AsyncMock(
        map_vector_distances_to_graph_nodes=AsyncMock(),
        map_vector_distances_to_graph_edges=AsyncMock(),
        calculate_top_triplet_importances=AsyncMock(return_value=[[], []]),
    )

    with patch(
        "cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine",
        return_value=engine,
    ), patch(
        "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
        return_value=fragment,
    ):
        await brute_force_triplet_search(query_batch=["q1", "q2"])

    # Checked in pipeline order: nodes, then edges, then importance ranking.
    graph_methods = (
        fragment.map_vector_distances_to_graph_nodes,
        fragment.map_vector_distances_to_graph_edges,
        fragment.calculate_top_triplet_importances,
    )
    for graph_method in graph_methods:
        graph_method.assert_called_once()
        received = graph_method.call_args.kwargs
        assert "query_list_length" in received
        assert received["query_list_length"] == 2
|
|
941
|
+
|
|
942
|
+
|
|
943
|
+
@pytest.mark.asyncio
async def test_brute_force_triplet_search_batch_path_comprehensive():
    """End-to-end batch wiring over node and edge collections.

    Checks three things: the result is a list of two per-query lists,
    ``get_memory_fragment`` gets ``relevant_ids_to_filter=None``, and every
    ``batch_search`` call is made with ``limit=None``.
    """
    # Factories (not prebuilt lists) so that each batch_search call gets freshly
    # built result objects.
    hit_factories = {
        "Entity_name": lambda: [
            [MockScoredResult("node1", 0.95)],
            [MockScoredResult("node2", 0.87)],
        ],
        "EdgeType_relationship_name": lambda: [
            [MockScoredResult("edge1", 0.92)],
            [MockScoredResult("edge2", 0.88)],
        ],
    }

    def fake_batch_search(*args, **kwargs):
        make_hits = hit_factories.get(kwargs.get("collection_name"))
        return make_hits() if make_hits else [[], []]

    engine = AsyncMock(
        embedding_engine=AsyncMock(),
        batch_search=AsyncMock(side_effect=fake_batch_search),
    )
    fragment = AsyncMock(
        map_vector_distances_to_graph_nodes=AsyncMock(),
        map_vector_distances_to_graph_edges=AsyncMock(),
        calculate_top_triplet_importances=AsyncMock(return_value=[[], []]),
    )

    with patch(
        "cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine",
        return_value=engine,
    ), patch(
        "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
        return_value=fragment,
    ) as fragment_factory:
        result = await brute_force_triplet_search(
            query_batch=["q1", "q2"], collections=["Entity_name", "EdgeType_relationship_name"]
        )

    assert isinstance(result, list)
    assert len(result) == 2
    assert isinstance(result[0], list)
    assert isinstance(result[1], list)

    fragment_factory.assert_called_once()
    assert fragment_factory.call_args.kwargs["relevant_ids_to_filter"] is None

    recorded_searches = engine.batch_search.call_args_list
    assert len(recorded_searches) > 0
    for recorded in recorded_searches:
        assert recorded.kwargs["limit"] is None
|
|
998
|
+
|
|
999
|
+
|
|
1000
|
+
@pytest.mark.asyncio
async def test_brute_force_triplet_search_batch_error_fallback():
    """In batch mode, a missing collection yields one empty list per query.

    ``batch_search`` raises ``CollectionNotFoundError``. For a two-query batch,
    the result must be exactly ``[[], []]``.
    """
    missing_collection = CollectionNotFoundError("Collection not found")
    engine = AsyncMock(
        embedding_engine=AsyncMock(),
        batch_search=AsyncMock(side_effect=missing_collection),
    )
    engine_target = "cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine"

    with patch(engine_target, return_value=engine):
        result = await brute_force_triplet_search(query_batch=["q1", "q2"])

    expected_shape = [[], []]
    assert result == expected_shape
    assert len(result) == len(expected_shape)
|
|
1017
|
+
|
|
1018
|
+
|
|
1019
|
+
@pytest.mark.asyncio
async def test_cognee_graph_mapping_batch_shapes():
    """Test that CogneeGraph mapping methods accept list-of-lists with query_list_length set.

    Builds a real two-node, one-edge CogneeGraph and feeds it batched
    (per-query) node and edge distances with ``query_list_length=2``. It then
    checks that each element's ``vector_distance`` attribute becomes a list
    with one entry per query.
    """
    from cognee.modules.graph.cognee_graph.CogneeGraphElements import Node, Edge

    graph = CogneeGraph()
    node1 = Node("node1", {"name": "Node1"})
    node2 = Node("node2", {"name": "Node2"})
    graph.add_node(node1)
    graph.add_node(node2)

    edge = Edge(node1, node2, attributes={"edge_text": "relates_to"})
    graph.add_edge(edge)

    # Node distances keyed by collection name. The inner list holds one result
    # list per query: query 0 hits only node1, query 1 hits only node2.
    node_distances_batch = {
        "Entity_name": [
            [MockScoredResult("node1", 0.95)],
            [MockScoredResult("node2", 0.87)],
        ]
    }

    # Both queries match the same edge text, so both results carry
    # generate_edge_id("relates_to"). This is presumably how the edge above is
    # matched via its edge_text; TODO confirm against generate_edge_id.
    edge_1_text = "relates_to"
    edge_2_text = "relates_to"
    edge_distances_batch = [
        [MockScoredResult(generate_edge_id(edge_1_text), 0.92, payload={"text": edge_1_text})],
        [MockScoredResult(generate_edge_id(edge_2_text), 0.88, payload={"text": edge_2_text})],
    ]

    await graph.map_vector_distances_to_graph_nodes(
        node_distances=node_distances_batch, query_list_length=2
    )
    await graph.map_vector_distances_to_graph_edges(
        edge_distances=edge_distances_batch, query_list_length=2
    )

    # These expectations assume that a node missing from a query's results gets
    # 3.5 in that query's slot (presumably the default/penalty distance).
    assert node1.attributes.get("vector_distance") == [0.95, 3.5]
    assert node2.attributes.get("vector_distance") == [3.5, 0.87]
    assert edge.attributes.get("vector_distance") == [0.92, 0.88]
|