cognee 0.5.1.dev0__py3-none-any.whl → 0.5.2.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cognee/__init__.py +2 -0
- cognee/alembic/README +1 -0
- cognee/alembic/env.py +107 -0
- cognee/alembic/script.py.mako +26 -0
- cognee/alembic/versions/1a58b986e6e1_enable_delete_for_old_tutorial_notebooks.py +52 -0
- cognee/alembic/versions/1d0bb7fede17_add_pipeline_run_status.py +33 -0
- cognee/alembic/versions/1daae0df1866_incremental_loading.py +48 -0
- cognee/alembic/versions/211ab850ef3d_add_sync_operations_table.py +118 -0
- cognee/alembic/versions/45957f0a9849_add_notebook_table.py +46 -0
- cognee/alembic/versions/46a6ce2bd2b2_expand_dataset_database_with_json_.py +333 -0
- cognee/alembic/versions/482cd6517ce4_add_default_user.py +30 -0
- cognee/alembic/versions/76625596c5c3_expand_dataset_database_for_multi_user.py +98 -0
- cognee/alembic/versions/8057ae7329c2_initial_migration.py +25 -0
- cognee/alembic/versions/9e7a3cb85175_loader_separation.py +104 -0
- cognee/alembic/versions/a1b2c3d4e5f6_add_label_column_to_data.py +38 -0
- cognee/alembic/versions/ab7e313804ae_permission_system_rework.py +236 -0
- cognee/alembic/versions/b9274c27a25a_kuzu_11_migration.py +75 -0
- cognee/alembic/versions/c946955da633_multi_tenant_support.py +137 -0
- cognee/alembic/versions/e1ec1dcb50b6_add_last_accessed_to_data.py +51 -0
- cognee/alembic/versions/e4ebee1091e7_expand_data_model_info.py +140 -0
- cognee/alembic.ini +117 -0
- cognee/api/v1/add/routers/get_add_router.py +2 -0
- cognee/api/v1/cognify/cognify.py +11 -6
- cognee/api/v1/cognify/routers/get_cognify_router.py +8 -0
- cognee/api/v1/config/config.py +60 -0
- cognee/api/v1/datasets/routers/get_datasets_router.py +45 -3
- cognee/api/v1/memify/routers/get_memify_router.py +2 -0
- cognee/api/v1/search/routers/get_search_router.py +21 -6
- cognee/api/v1/search/search.py +25 -5
- cognee/api/v1/sync/routers/get_sync_router.py +3 -3
- cognee/cli/commands/add_command.py +1 -1
- cognee/cli/commands/cognify_command.py +6 -0
- cognee/cli/commands/config_command.py +1 -1
- cognee/context_global_variables.py +5 -1
- cognee/eval_framework/answer_generation/answer_generation_executor.py +7 -8
- cognee/infrastructure/databases/cache/cache_db_interface.py +38 -1
- cognee/infrastructure/databases/cache/config.py +6 -0
- cognee/infrastructure/databases/cache/fscache/FsCacheAdapter.py +21 -0
- cognee/infrastructure/databases/cache/get_cache_engine.py +9 -3
- cognee/infrastructure/databases/cache/redis/RedisAdapter.py +60 -1
- cognee/infrastructure/databases/dataset_database_handler/supported_dataset_database_handlers.py +7 -0
- cognee/infrastructure/databases/graph/get_graph_engine.py +29 -1
- cognee/infrastructure/databases/graph/neo4j_driver/Neo4jAuraDevDatasetDatabaseHandler.py +62 -27
- cognee/infrastructure/databases/hybrid/neptune_analytics/NeptuneAnalyticsAdapter.py +17 -4
- cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py +2 -1
- cognee/infrastructure/databases/vector/chromadb/ChromaDBAdapter.py +2 -0
- cognee/infrastructure/databases/vector/config.py +6 -0
- cognee/infrastructure/databases/vector/create_vector_engine.py +69 -22
- cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py +64 -9
- cognee/infrastructure/databases/vector/embeddings/OllamaEmbeddingEngine.py +13 -2
- cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py +16 -3
- cognee/infrastructure/databases/vector/models/ScoredResult.py +3 -3
- cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py +16 -3
- cognee/infrastructure/databases/vector/pgvector/PGVectorDatasetDatabaseHandler.py +86 -0
- cognee/infrastructure/databases/vector/pgvector/create_db_and_tables.py +81 -2
- cognee/infrastructure/databases/vector/vector_db_interface.py +8 -0
- cognee/infrastructure/files/utils/get_data_file_path.py +33 -27
- cognee/infrastructure/llm/prompts/extract_query_time.txt +1 -1
- cognee/infrastructure/llm/prompts/generate_event_entity_prompt.txt +1 -1
- cognee/infrastructure/llm/prompts/generate_event_graph_prompt.txt +1 -1
- cognee/infrastructure/llm/prompts/generate_graph_prompt.txt +2 -2
- cognee/infrastructure/llm/prompts/generate_graph_prompt_guided.txt +1 -1
- cognee/infrastructure/llm/prompts/generate_graph_prompt_oneshot.txt +2 -2
- cognee/infrastructure/llm/prompts/generate_graph_prompt_simple.txt +1 -1
- cognee/infrastructure/llm/prompts/generate_graph_prompt_strict.txt +1 -1
- cognee/infrastructure/llm/prompts/search_type_selector_prompt.txt +6 -6
- cognee/infrastructure/llm/prompts/test.txt +1 -1
- cognee/infrastructure/llm/prompts/translate_content.txt +19 -0
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py +24 -0
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/llama_cpp/adapter.py +191 -0
- cognee/modules/chunking/models/DocumentChunk.py +0 -1
- cognee/modules/cognify/config.py +2 -0
- cognee/modules/data/models/Data.py +1 -0
- cognee/modules/engine/models/Entity.py +0 -1
- cognee/modules/engine/operations/setup.py +6 -0
- cognee/modules/graph/cognee_graph/CogneeGraph.py +150 -37
- cognee/modules/graph/cognee_graph/CogneeGraphElements.py +48 -2
- cognee/modules/graph/utils/__init__.py +1 -0
- cognee/modules/graph/utils/get_entity_nodes_from_triplets.py +12 -0
- cognee/modules/notebooks/methods/__init__.py +1 -0
- cognee/modules/notebooks/methods/create_notebook.py +0 -34
- cognee/modules/notebooks/methods/create_tutorial_notebooks.py +191 -0
- cognee/modules/notebooks/methods/get_notebooks.py +12 -8
- cognee/modules/notebooks/tutorials/cognee-basics/cell-1.md +3 -0
- cognee/modules/notebooks/tutorials/cognee-basics/cell-2.md +10 -0
- cognee/modules/notebooks/tutorials/cognee-basics/cell-3.md +7 -0
- cognee/modules/notebooks/tutorials/cognee-basics/cell-4.py +28 -0
- cognee/modules/notebooks/tutorials/cognee-basics/cell-5.py +3 -0
- cognee/modules/notebooks/tutorials/cognee-basics/cell-6.py +9 -0
- cognee/modules/notebooks/tutorials/cognee-basics/cell-7.py +17 -0
- cognee/modules/notebooks/tutorials/cognee-basics/config.json +4 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-1.md +3 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-10.md +3 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-11.md +3 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-12.py +3 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-13.md +7 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-14.py +6 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-15.md +3 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-16.py +7 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-2.md +9 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-3.md +7 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-4.md +9 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-5.md +5 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-6.py +13 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-7.md +3 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-8.md +3 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-9.py +31 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/config.json +4 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/data/copilot_conversations.json +107 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/data/guido_contributions.json +976 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/data/my_developer_rules.md +79 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/data/pep_style_guide.md +74 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/data/zen_principles.md +74 -0
- cognee/modules/retrieval/EntityCompletionRetriever.py +51 -38
- cognee/modules/retrieval/__init__.py +0 -1
- cognee/modules/retrieval/base_retriever.py +66 -10
- cognee/modules/retrieval/chunks_retriever.py +57 -49
- cognee/modules/retrieval/coding_rules_retriever.py +12 -5
- cognee/modules/retrieval/completion_retriever.py +29 -28
- cognee/modules/retrieval/cypher_search_retriever.py +25 -20
- cognee/modules/retrieval/graph_completion_context_extension_retriever.py +42 -46
- cognee/modules/retrieval/graph_completion_cot_retriever.py +68 -51
- cognee/modules/retrieval/graph_completion_retriever.py +78 -63
- cognee/modules/retrieval/graph_summary_completion_retriever.py +2 -0
- cognee/modules/retrieval/lexical_retriever.py +34 -12
- cognee/modules/retrieval/natural_language_retriever.py +18 -15
- cognee/modules/retrieval/summaries_retriever.py +51 -34
- cognee/modules/retrieval/temporal_retriever.py +59 -49
- cognee/modules/retrieval/triplet_retriever.py +31 -32
- cognee/modules/retrieval/utils/access_tracking.py +88 -0
- cognee/modules/retrieval/utils/brute_force_triplet_search.py +99 -85
- cognee/modules/retrieval/utils/node_edge_vector_search.py +174 -0
- cognee/modules/search/methods/__init__.py +1 -0
- cognee/modules/search/methods/get_retriever_output.py +53 -0
- cognee/modules/search/methods/get_search_type_retriever_instance.py +252 -0
- cognee/modules/search/methods/search.py +90 -215
- cognee/modules/search/models/SearchResultPayload.py +67 -0
- cognee/modules/search/types/SearchResult.py +1 -8
- cognee/modules/search/types/SearchType.py +1 -2
- cognee/modules/search/types/__init__.py +1 -1
- cognee/modules/search/utils/__init__.py +1 -2
- cognee/modules/search/utils/transform_insights_to_graph.py +2 -2
- cognee/modules/search/utils/{transform_context_to_graph.py → transform_triplets_to_graph.py} +2 -2
- cognee/modules/users/authentication/default/default_transport.py +11 -1
- cognee/modules/users/authentication/get_api_auth_backend.py +2 -1
- cognee/modules/users/authentication/get_client_auth_backend.py +2 -1
- cognee/modules/users/methods/create_user.py +0 -9
- cognee/modules/users/permissions/methods/has_user_management_permission.py +29 -0
- cognee/modules/visualization/cognee_network_visualization.py +1 -1
- cognee/run_migrations.py +48 -0
- cognee/shared/exceptions/__init__.py +1 -3
- cognee/shared/exceptions/exceptions.py +11 -1
- cognee/shared/usage_logger.py +332 -0
- cognee/shared/utils.py +12 -5
- cognee/tasks/cleanup/cleanup_unused_data.py +172 -0
- cognee/tasks/memify/extract_usage_frequency.py +613 -0
- cognee/tasks/summarization/models.py +0 -2
- cognee/tasks/temporal_graph/__init__.py +0 -1
- cognee/tasks/translation/__init__.py +96 -0
- cognee/tasks/translation/config.py +110 -0
- cognee/tasks/translation/detect_language.py +190 -0
- cognee/tasks/translation/exceptions.py +62 -0
- cognee/tasks/translation/models.py +72 -0
- cognee/tasks/translation/providers/__init__.py +44 -0
- cognee/tasks/translation/providers/azure_provider.py +192 -0
- cognee/tasks/translation/providers/base.py +85 -0
- cognee/tasks/translation/providers/google_provider.py +158 -0
- cognee/tasks/translation/providers/llm_provider.py +143 -0
- cognee/tasks/translation/translate_content.py +282 -0
- cognee/tasks/web_scraper/default_url_crawler.py +6 -2
- cognee/tests/cli_tests/cli_unit_tests/test_cli_commands.py +1 -0
- cognee/tests/cli_tests/cli_unit_tests/test_cli_edge_cases.py +3 -0
- cognee/tests/integration/retrieval/test_brute_force_triplet_search_with_cognify.py +62 -0
- cognee/tests/integration/retrieval/test_chunks_retriever.py +115 -16
- cognee/tests/integration/retrieval/test_graph_completion_retriever.py +13 -5
- cognee/tests/integration/retrieval/test_graph_completion_retriever_context_extension.py +22 -20
- cognee/tests/integration/retrieval/test_graph_completion_retriever_cot.py +23 -24
- cognee/tests/integration/retrieval/test_rag_completion_retriever.py +70 -5
- cognee/tests/integration/retrieval/test_structured_output.py +62 -18
- cognee/tests/integration/retrieval/test_summaries_retriever.py +20 -9
- cognee/tests/integration/retrieval/test_temporal_retriever.py +38 -8
- cognee/tests/integration/retrieval/test_triplet_retriever.py +13 -4
- cognee/tests/integration/shared/test_usage_logger_integration.py +255 -0
- cognee/tests/tasks/translation/README.md +147 -0
- cognee/tests/tasks/translation/__init__.py +1 -0
- cognee/tests/tasks/translation/config_test.py +93 -0
- cognee/tests/tasks/translation/detect_language_test.py +118 -0
- cognee/tests/tasks/translation/providers_test.py +151 -0
- cognee/tests/tasks/translation/translate_content_test.py +213 -0
- cognee/tests/test_chromadb.py +1 -1
- cognee/tests/test_cleanup_unused_data.py +165 -0
- cognee/tests/test_delete_by_id.py +6 -6
- cognee/tests/test_extract_usage_frequency.py +308 -0
- cognee/tests/test_kuzu.py +17 -7
- cognee/tests/test_lancedb.py +3 -1
- cognee/tests/test_library.py +1 -1
- cognee/tests/test_neo4j.py +17 -7
- cognee/tests/test_neptune_analytics_vector.py +3 -1
- cognee/tests/test_permissions.py +172 -187
- cognee/tests/test_pgvector.py +3 -1
- cognee/tests/test_relational_db_migration.py +15 -1
- cognee/tests/test_remote_kuzu.py +3 -1
- cognee/tests/test_s3_file_storage.py +1 -1
- cognee/tests/test_search_db.py +97 -110
- cognee/tests/test_usage_logger_e2e.py +268 -0
- cognee/tests/unit/api/test_get_raw_data_endpoint.py +206 -0
- cognee/tests/unit/eval_framework/answer_generation_test.py +4 -3
- cognee/tests/unit/infrastructure/databases/cache/test_cache_config.py +2 -0
- cognee/tests/unit/modules/graph/cognee_graph_elements_test.py +42 -2
- cognee/tests/unit/modules/graph/cognee_graph_test.py +329 -31
- cognee/tests/unit/modules/retrieval/chunks_retriever_test.py +31 -59
- cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py +70 -33
- cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py +72 -52
- cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py +27 -33
- cognee/tests/unit/modules/retrieval/rag_completion_retriever_test.py +28 -15
- cognee/tests/unit/modules/retrieval/summaries_retriever_test.py +37 -42
- cognee/tests/unit/modules/retrieval/temporal_retriever_test.py +48 -64
- cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py +263 -24
- cognee/tests/unit/modules/retrieval/test_node_edge_vector_search.py +273 -0
- cognee/tests/unit/modules/retrieval/triplet_retriever_test.py +30 -16
- cognee/tests/unit/modules/search/test_get_search_type_retriever_instance.py +125 -0
- cognee/tests/unit/modules/search/test_search.py +176 -0
- cognee/tests/unit/modules/search/test_search_prepare_search_result_contract.py +190 -0
- cognee/tests/unit/modules/users/test_tutorial_notebook_creation.py +511 -297
- cognee/tests/unit/shared/test_usage_logger.py +241 -0
- cognee/tests/unit/users/permissions/test_has_user_management_permission.py +46 -0
- {cognee-0.5.1.dev0.dist-info → cognee-0.5.2.dev0.dist-info}/METADATA +17 -10
- {cognee-0.5.1.dev0.dist-info → cognee-0.5.2.dev0.dist-info}/RECORD +232 -144
- cognee/api/.env.example +0 -5
- cognee/modules/retrieval/base_graph_retriever.py +0 -24
- cognee/modules/search/methods/get_search_type_tools.py +0 -223
- cognee/modules/search/methods/no_access_control_search.py +0 -62
- cognee/modules/search/utils/prepare_search_result.py +0 -63
- cognee/tests/test_feedback_enrichment.py +0 -174
- {cognee-0.5.1.dev0.dist-info → cognee-0.5.2.dev0.dist-info}/WHEEL +0 -0
- {cognee-0.5.1.dev0.dist-info → cognee-0.5.2.dev0.dist-info}/entry_points.txt +0 -0
- {cognee-0.5.1.dev0.dist-info → cognee-0.5.2.dev0.dist-info}/licenses/LICENSE +0 -0
- {cognee-0.5.1.dev0.dist-info → cognee-0.5.2.dev0.dist-info}/licenses/NOTICE.md +0 -0
|
@@ -0,0 +1,273 @@
|
|
|
1
|
+
import pytest
|
|
2
|
+
from unittest.mock import AsyncMock
|
|
3
|
+
|
|
4
|
+
from cognee.modules.retrieval.utils.node_edge_vector_search import NodeEdgeVectorSearch
|
|
5
|
+
from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class MockScoredResult:
|
|
9
|
+
"""Mock class for vector search results."""
|
|
10
|
+
|
|
11
|
+
def __init__(self, id, score, payload=None):
|
|
12
|
+
self.id = id
|
|
13
|
+
self.score = score
|
|
14
|
+
self.payload = payload or {}
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
@pytest.mark.asyncio
|
|
18
|
+
async def test_node_edge_vector_search_single_query_shape():
|
|
19
|
+
"""Test that single query mode produces flat lists (not list-of-lists)."""
|
|
20
|
+
mock_vector_engine = AsyncMock()
|
|
21
|
+
mock_vector_engine.embedding_engine = AsyncMock()
|
|
22
|
+
mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
|
|
23
|
+
|
|
24
|
+
node_results = [MockScoredResult("node1", 0.95), MockScoredResult("node2", 0.87)]
|
|
25
|
+
edge_results = [MockScoredResult("edge1", 0.92)]
|
|
26
|
+
|
|
27
|
+
def search_side_effect(*args, **kwargs):
|
|
28
|
+
collection_name = kwargs.get("collection_name")
|
|
29
|
+
if collection_name == "EdgeType_relationship_name":
|
|
30
|
+
return edge_results
|
|
31
|
+
return node_results
|
|
32
|
+
|
|
33
|
+
mock_vector_engine.search = AsyncMock(side_effect=search_side_effect)
|
|
34
|
+
|
|
35
|
+
vector_search = NodeEdgeVectorSearch(vector_engine=mock_vector_engine)
|
|
36
|
+
collections = ["Entity_name", "EdgeType_relationship_name"]
|
|
37
|
+
|
|
38
|
+
await vector_search.embed_and_retrieve_distances(
|
|
39
|
+
query="test query", query_batch=None, collections=collections, wide_search_limit=10
|
|
40
|
+
)
|
|
41
|
+
|
|
42
|
+
assert vector_search.query_list_length is None
|
|
43
|
+
assert vector_search.edge_distances == edge_results
|
|
44
|
+
assert vector_search.node_distances["Entity_name"] == node_results
|
|
45
|
+
mock_vector_engine.embedding_engine.embed_text.assert_called_once_with(["test query"])
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
@pytest.mark.asyncio
|
|
49
|
+
async def test_node_edge_vector_search_batch_query_shape_and_empties():
|
|
50
|
+
"""Test that batch query mode produces list-of-lists with correct length and handles empty collections."""
|
|
51
|
+
mock_vector_engine = AsyncMock()
|
|
52
|
+
mock_vector_engine.embedding_engine = AsyncMock()
|
|
53
|
+
|
|
54
|
+
query_batch = ["query a", "query b"]
|
|
55
|
+
node_results_query_a = [MockScoredResult("node1", 0.95)]
|
|
56
|
+
node_results_query_b = [MockScoredResult("node2", 0.87)]
|
|
57
|
+
edge_results_query_a = [MockScoredResult("edge1", 0.92)]
|
|
58
|
+
edge_results_query_b = []
|
|
59
|
+
|
|
60
|
+
def batch_search_side_effect(*args, **kwargs):
|
|
61
|
+
collection_name = kwargs.get("collection_name")
|
|
62
|
+
if collection_name == "EdgeType_relationship_name":
|
|
63
|
+
return [edge_results_query_a, edge_results_query_b]
|
|
64
|
+
elif collection_name == "Entity_name":
|
|
65
|
+
return [node_results_query_a, node_results_query_b]
|
|
66
|
+
elif collection_name == "MissingCollection":
|
|
67
|
+
raise CollectionNotFoundError("Collection not found")
|
|
68
|
+
return [[], []]
|
|
69
|
+
|
|
70
|
+
mock_vector_engine.batch_search = AsyncMock(side_effect=batch_search_side_effect)
|
|
71
|
+
|
|
72
|
+
vector_search = NodeEdgeVectorSearch(vector_engine=mock_vector_engine)
|
|
73
|
+
collections = [
|
|
74
|
+
"Entity_name",
|
|
75
|
+
"EdgeType_relationship_name",
|
|
76
|
+
"MissingCollection",
|
|
77
|
+
"EmptyCollection",
|
|
78
|
+
]
|
|
79
|
+
|
|
80
|
+
await vector_search.embed_and_retrieve_distances(
|
|
81
|
+
query=None, query_batch=query_batch, collections=collections, wide_search_limit=None
|
|
82
|
+
)
|
|
83
|
+
|
|
84
|
+
assert vector_search.query_list_length == 2
|
|
85
|
+
assert len(vector_search.edge_distances) == 2
|
|
86
|
+
assert vector_search.edge_distances[0] == edge_results_query_a
|
|
87
|
+
assert vector_search.edge_distances[1] == edge_results_query_b
|
|
88
|
+
assert len(vector_search.node_distances["Entity_name"]) == 2
|
|
89
|
+
assert vector_search.node_distances["Entity_name"][0] == node_results_query_a
|
|
90
|
+
assert vector_search.node_distances["Entity_name"][1] == node_results_query_b
|
|
91
|
+
assert len(vector_search.node_distances["MissingCollection"]) == 2
|
|
92
|
+
assert vector_search.node_distances["MissingCollection"] == [[], []]
|
|
93
|
+
assert len(vector_search.node_distances["EmptyCollection"]) == 2
|
|
94
|
+
assert vector_search.node_distances["EmptyCollection"] == [[], []]
|
|
95
|
+
mock_vector_engine.embedding_engine.embed_text.assert_not_called()
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
@pytest.mark.asyncio
|
|
99
|
+
async def test_node_edge_vector_search_input_validation_both_provided():
|
|
100
|
+
"""Test that providing both query and query_batch raises ValueError."""
|
|
101
|
+
mock_vector_engine = AsyncMock()
|
|
102
|
+
vector_search = NodeEdgeVectorSearch(vector_engine=mock_vector_engine)
|
|
103
|
+
collections = ["Entity_name"]
|
|
104
|
+
|
|
105
|
+
with pytest.raises(ValueError, match="Cannot provide both 'query' and 'query_batch'"):
|
|
106
|
+
await vector_search.embed_and_retrieve_distances(
|
|
107
|
+
query="test", query_batch=["test1", "test2"], collections=collections
|
|
108
|
+
)
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
@pytest.mark.asyncio
|
|
112
|
+
async def test_node_edge_vector_search_input_validation_neither_provided():
|
|
113
|
+
"""Test that providing neither query nor query_batch raises ValueError."""
|
|
114
|
+
mock_vector_engine = AsyncMock()
|
|
115
|
+
vector_search = NodeEdgeVectorSearch(vector_engine=mock_vector_engine)
|
|
116
|
+
collections = ["Entity_name"]
|
|
117
|
+
|
|
118
|
+
with pytest.raises(ValueError, match="Must provide either 'query' or 'query_batch'"):
|
|
119
|
+
await vector_search.embed_and_retrieve_distances(
|
|
120
|
+
query=None, query_batch=None, collections=collections
|
|
121
|
+
)
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
@pytest.mark.asyncio
|
|
125
|
+
async def test_node_edge_vector_search_extract_relevant_node_ids_single_query():
|
|
126
|
+
"""Test that extract_relevant_node_ids returns IDs for single query mode."""
|
|
127
|
+
mock_vector_engine = AsyncMock()
|
|
128
|
+
vector_search = NodeEdgeVectorSearch(vector_engine=mock_vector_engine)
|
|
129
|
+
vector_search.query_list_length = None
|
|
130
|
+
vector_search.node_distances = {
|
|
131
|
+
"Entity_name": [MockScoredResult("node1", 0.95), MockScoredResult("node2", 0.87)],
|
|
132
|
+
"TextSummary_text": [MockScoredResult("node1", 0.90), MockScoredResult("node3", 0.92)],
|
|
133
|
+
}
|
|
134
|
+
|
|
135
|
+
node_ids = vector_search.extract_relevant_node_ids()
|
|
136
|
+
assert set(node_ids) == {"node1", "node2", "node3"}
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
@pytest.mark.asyncio
|
|
140
|
+
async def test_node_edge_vector_search_extract_relevant_node_ids_batch():
|
|
141
|
+
"""Test that extract_relevant_node_ids returns empty list for batch mode."""
|
|
142
|
+
mock_vector_engine = AsyncMock()
|
|
143
|
+
vector_search = NodeEdgeVectorSearch(vector_engine=mock_vector_engine)
|
|
144
|
+
vector_search.query_list_length = 2
|
|
145
|
+
vector_search.node_distances = {
|
|
146
|
+
"Entity_name": [
|
|
147
|
+
[MockScoredResult("node1", 0.95)],
|
|
148
|
+
[MockScoredResult("node2", 0.87)],
|
|
149
|
+
],
|
|
150
|
+
}
|
|
151
|
+
|
|
152
|
+
node_ids = vector_search.extract_relevant_node_ids()
|
|
153
|
+
assert node_ids == []
|
|
154
|
+
|
|
155
|
+
|
|
156
|
+
@pytest.mark.asyncio
|
|
157
|
+
async def test_node_edge_vector_search_has_results_single_query():
|
|
158
|
+
"""Test has_results returns True when results exist and False when only empties."""
|
|
159
|
+
mock_vector_engine = AsyncMock()
|
|
160
|
+
vector_search = NodeEdgeVectorSearch(vector_engine=mock_vector_engine)
|
|
161
|
+
|
|
162
|
+
vector_search.edge_distances = [MockScoredResult("edge1", 0.92)]
|
|
163
|
+
vector_search.node_distances = {}
|
|
164
|
+
assert vector_search.has_results() is True
|
|
165
|
+
|
|
166
|
+
vector_search.edge_distances = []
|
|
167
|
+
vector_search.node_distances = {"Entity_name": [MockScoredResult("node1", 0.95)]}
|
|
168
|
+
assert vector_search.has_results() is True
|
|
169
|
+
|
|
170
|
+
vector_search.edge_distances = []
|
|
171
|
+
vector_search.node_distances = {}
|
|
172
|
+
assert vector_search.has_results() is False
|
|
173
|
+
|
|
174
|
+
|
|
175
|
+
@pytest.mark.asyncio
|
|
176
|
+
async def test_node_edge_vector_search_has_results_batch():
|
|
177
|
+
"""Test has_results works correctly for batch mode with list-of-lists."""
|
|
178
|
+
mock_vector_engine = AsyncMock()
|
|
179
|
+
vector_search = NodeEdgeVectorSearch(vector_engine=mock_vector_engine)
|
|
180
|
+
vector_search.query_list_length = 2
|
|
181
|
+
|
|
182
|
+
vector_search.edge_distances = [[MockScoredResult("edge1", 0.92)], []]
|
|
183
|
+
vector_search.node_distances = {}
|
|
184
|
+
assert vector_search.has_results() is True
|
|
185
|
+
|
|
186
|
+
vector_search.edge_distances = [[], []]
|
|
187
|
+
vector_search.node_distances = {
|
|
188
|
+
"Entity_name": [[MockScoredResult("node1", 0.95)], []],
|
|
189
|
+
}
|
|
190
|
+
assert vector_search.has_results() is True
|
|
191
|
+
|
|
192
|
+
vector_search.edge_distances = [[], []]
|
|
193
|
+
vector_search.node_distances = {"Entity_name": [[], []]}
|
|
194
|
+
assert vector_search.has_results() is False
|
|
195
|
+
|
|
196
|
+
|
|
197
|
+
@pytest.mark.asyncio
|
|
198
|
+
async def test_node_edge_vector_search_single_query_collection_not_found():
|
|
199
|
+
"""Test that CollectionNotFoundError in single query mode returns empty list."""
|
|
200
|
+
mock_vector_engine = AsyncMock()
|
|
201
|
+
mock_vector_engine.embedding_engine = AsyncMock()
|
|
202
|
+
mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
|
|
203
|
+
mock_vector_engine.search = AsyncMock(
|
|
204
|
+
side_effect=CollectionNotFoundError("Collection not found")
|
|
205
|
+
)
|
|
206
|
+
|
|
207
|
+
vector_search = NodeEdgeVectorSearch(vector_engine=mock_vector_engine)
|
|
208
|
+
collections = ["MissingCollection"]
|
|
209
|
+
|
|
210
|
+
await vector_search.embed_and_retrieve_distances(
|
|
211
|
+
query="test query", query_batch=None, collections=collections, wide_search_limit=10
|
|
212
|
+
)
|
|
213
|
+
|
|
214
|
+
assert vector_search.node_distances["MissingCollection"] == []
|
|
215
|
+
|
|
216
|
+
|
|
217
|
+
@pytest.mark.asyncio
|
|
218
|
+
async def test_node_edge_vector_search_missing_collections_single_query():
|
|
219
|
+
"""Test that missing collections in single-query mode are handled gracefully with empty lists."""
|
|
220
|
+
mock_vector_engine = AsyncMock()
|
|
221
|
+
mock_vector_engine.embedding_engine = AsyncMock()
|
|
222
|
+
mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
|
|
223
|
+
|
|
224
|
+
node_result = MockScoredResult("node1", 0.95)
|
|
225
|
+
|
|
226
|
+
def search_side_effect(*args, **kwargs):
|
|
227
|
+
collection_name = kwargs.get("collection_name")
|
|
228
|
+
if collection_name == "Entity_name":
|
|
229
|
+
return [node_result]
|
|
230
|
+
elif collection_name == "MissingCollection":
|
|
231
|
+
raise CollectionNotFoundError("Collection not found")
|
|
232
|
+
return []
|
|
233
|
+
|
|
234
|
+
mock_vector_engine.search = AsyncMock(side_effect=search_side_effect)
|
|
235
|
+
|
|
236
|
+
vector_search = NodeEdgeVectorSearch(vector_engine=mock_vector_engine)
|
|
237
|
+
collections = ["Entity_name", "MissingCollection", "EmptyCollection"]
|
|
238
|
+
|
|
239
|
+
await vector_search.embed_and_retrieve_distances(
|
|
240
|
+
query="test query", query_batch=None, collections=collections, wide_search_limit=10
|
|
241
|
+
)
|
|
242
|
+
|
|
243
|
+
assert len(vector_search.node_distances["Entity_name"]) == 1
|
|
244
|
+
assert vector_search.node_distances["Entity_name"][0].id == "node1"
|
|
245
|
+
assert vector_search.node_distances["Entity_name"][0].score == 0.95
|
|
246
|
+
assert vector_search.node_distances["MissingCollection"] == []
|
|
247
|
+
assert vector_search.node_distances["EmptyCollection"] == []
|
|
248
|
+
|
|
249
|
+
|
|
250
|
+
@pytest.mark.asyncio
|
|
251
|
+
async def test_node_edge_vector_search_has_results_batch_nodes_only():
|
|
252
|
+
"""Test has_results returns True when only node distances are populated in batch mode."""
|
|
253
|
+
mock_vector_engine = AsyncMock()
|
|
254
|
+
vector_search = NodeEdgeVectorSearch(vector_engine=mock_vector_engine)
|
|
255
|
+
vector_search.query_list_length = 2
|
|
256
|
+
vector_search.edge_distances = [[], []]
|
|
257
|
+
vector_search.node_distances = {
|
|
258
|
+
"Entity_name": [[MockScoredResult("node1", 0.95)], []],
|
|
259
|
+
}
|
|
260
|
+
|
|
261
|
+
assert vector_search.has_results() is True
|
|
262
|
+
|
|
263
|
+
|
|
264
|
+
@pytest.mark.asyncio
|
|
265
|
+
async def test_node_edge_vector_search_has_results_batch_edges_only():
|
|
266
|
+
"""Test has_results returns True when only edge distances are populated in batch mode."""
|
|
267
|
+
mock_vector_engine = AsyncMock()
|
|
268
|
+
vector_search = NodeEdgeVectorSearch(vector_engine=mock_vector_engine)
|
|
269
|
+
vector_search.query_list_length = 2
|
|
270
|
+
vector_search.edge_distances = [[MockScoredResult("edge1", 0.92)], []]
|
|
271
|
+
vector_search.node_distances = {}
|
|
272
|
+
|
|
273
|
+
assert vector_search.has_results() is True
|
|
@@ -31,14 +31,18 @@ async def test_get_context_success(mock_vector_engine):
|
|
|
31
31
|
"cognee.modules.retrieval.triplet_retriever.get_vector_engine",
|
|
32
32
|
return_value=mock_vector_engine,
|
|
33
33
|
):
|
|
34
|
-
|
|
34
|
+
objects = await retriever.get_retrieved_objects("test query")
|
|
35
|
+
context = await retriever.get_context_from_objects("test query", objects)
|
|
36
|
+
await retriever.get_completion_from_context("test query", objects, context)
|
|
35
37
|
|
|
36
38
|
assert context == "Alice knows Bob\nBob works at Tech Corp"
|
|
37
|
-
mock_vector_engine.search.assert_awaited_once_with(
|
|
39
|
+
mock_vector_engine.search.assert_awaited_once_with(
|
|
40
|
+
"Triplet_text", "test query", limit=5, include_payload=True
|
|
41
|
+
)
|
|
38
42
|
|
|
39
43
|
|
|
40
44
|
@pytest.mark.asyncio
|
|
41
|
-
async def
|
|
45
|
+
async def test_get_objects_no_collection(mock_vector_engine):
|
|
42
46
|
"""Test that NoDataError is raised when Triplet_text collection doesn't exist."""
|
|
43
47
|
mock_vector_engine.has_collection.return_value = False
|
|
44
48
|
|
|
@@ -49,7 +53,7 @@ async def test_get_context_no_collection(mock_vector_engine):
|
|
|
49
53
|
return_value=mock_vector_engine,
|
|
50
54
|
):
|
|
51
55
|
with pytest.raises(NoDataError, match="create_triplet_embeddings"):
|
|
52
|
-
await retriever.
|
|
56
|
+
await retriever.get_retrieved_objects("test query")
|
|
53
57
|
|
|
54
58
|
|
|
55
59
|
@pytest.mark.asyncio
|
|
@@ -63,13 +67,13 @@ async def test_get_context_empty_results(mock_vector_engine):
|
|
|
63
67
|
"cognee.modules.retrieval.triplet_retriever.get_vector_engine",
|
|
64
68
|
return_value=mock_vector_engine,
|
|
65
69
|
):
|
|
66
|
-
context = await retriever.
|
|
70
|
+
context = await retriever.get_context_from_objects("test query", [])
|
|
67
71
|
|
|
68
72
|
assert context == ""
|
|
69
73
|
|
|
70
74
|
|
|
71
75
|
@pytest.mark.asyncio
|
|
72
|
-
async def
|
|
76
|
+
async def test_get_objects_collection_not_found_error(mock_vector_engine):
|
|
73
77
|
"""Test that CollectionNotFoundError is converted to NoDataError."""
|
|
74
78
|
mock_vector_engine.has_collection.side_effect = CollectionNotFoundError("Collection not found")
|
|
75
79
|
|
|
@@ -80,7 +84,7 @@ async def test_get_context_collection_not_found_error(mock_vector_engine):
|
|
|
80
84
|
return_value=mock_vector_engine,
|
|
81
85
|
):
|
|
82
86
|
with pytest.raises(NoDataError, match="No data found"):
|
|
83
|
-
await retriever.
|
|
87
|
+
await retriever.get_retrieved_objects("test query")
|
|
84
88
|
|
|
85
89
|
|
|
86
90
|
@pytest.mark.asyncio
|
|
@@ -98,7 +102,8 @@ async def test_get_context_empty_payload_text(mock_vector_engine):
|
|
|
98
102
|
return_value=mock_vector_engine,
|
|
99
103
|
):
|
|
100
104
|
with pytest.raises(KeyError):
|
|
101
|
-
await retriever.
|
|
105
|
+
objects = await retriever.get_retrieved_objects("test query")
|
|
106
|
+
await retriever.get_context_from_objects("test query", retrieved_objects=objects)
|
|
102
107
|
|
|
103
108
|
|
|
104
109
|
@pytest.mark.asyncio
|
|
@@ -115,7 +120,8 @@ async def test_get_context_single_triplet(mock_vector_engine):
|
|
|
115
120
|
"cognee.modules.retrieval.triplet_retriever.get_vector_engine",
|
|
116
121
|
return_value=mock_vector_engine,
|
|
117
122
|
):
|
|
118
|
-
|
|
123
|
+
objects = await retriever.get_retrieved_objects("test query")
|
|
124
|
+
context = await retriever.get_context_from_objects("test query", retrieved_objects=objects)
|
|
119
125
|
|
|
120
126
|
assert context == "Single triplet"
|
|
121
127
|
|
|
@@ -172,7 +178,7 @@ async def test_get_completion_without_context(mock_vector_engine):
|
|
|
172
178
|
mock_config.caching = False
|
|
173
179
|
mock_cache_config.return_value = mock_config
|
|
174
180
|
|
|
175
|
-
completion = await retriever.
|
|
181
|
+
completion = await retriever.get_completion_from_context("test query", None, None)
|
|
176
182
|
|
|
177
183
|
assert isinstance(completion, list)
|
|
178
184
|
assert len(completion) == 1
|
|
@@ -195,7 +201,9 @@ async def test_get_completion_with_provided_context(mock_vector_engine):
|
|
|
195
201
|
mock_config.caching = False
|
|
196
202
|
mock_cache_config.return_value = mock_config
|
|
197
203
|
|
|
198
|
-
completion = await retriever.
|
|
204
|
+
completion = await retriever.get_completion_from_context(
|
|
205
|
+
"test query", None, context="Provided context"
|
|
206
|
+
)
|
|
199
207
|
|
|
200
208
|
assert isinstance(completion, list)
|
|
201
209
|
assert len(completion) == 1
|
|
@@ -210,7 +218,7 @@ async def test_get_completion_with_session(mock_vector_engine):
|
|
|
210
218
|
mock_vector_engine.has_collection.return_value = True
|
|
211
219
|
mock_vector_engine.search.return_value = [mock_result]
|
|
212
220
|
|
|
213
|
-
retriever = TripletRetriever()
|
|
221
|
+
retriever = TripletRetriever(session_id="test_session")
|
|
214
222
|
|
|
215
223
|
mock_user = MagicMock()
|
|
216
224
|
mock_user.id = "test-user-id"
|
|
@@ -243,7 +251,9 @@ async def test_get_completion_with_session(mock_vector_engine):
|
|
|
243
251
|
mock_cache_config.return_value = mock_config
|
|
244
252
|
mock_session_user.get.return_value = mock_user
|
|
245
253
|
|
|
246
|
-
|
|
254
|
+
objects = await retriever.get_retrieved_objects("test query")
|
|
255
|
+
context = await retriever.get_context_from_objects("test query", retrieved_objects=objects)
|
|
256
|
+
completion = await retriever.get_completion_from_context("test query", objects, context)
|
|
247
257
|
|
|
248
258
|
assert isinstance(completion, list)
|
|
249
259
|
assert len(completion) == 1
|
|
@@ -278,7 +288,9 @@ async def test_get_completion_with_session_no_user_id(mock_vector_engine):
|
|
|
278
288
|
mock_cache_config.return_value = mock_config
|
|
279
289
|
mock_session_user.get.return_value = None # No user
|
|
280
290
|
|
|
281
|
-
|
|
291
|
+
objects = await retriever.get_retrieved_objects("test query")
|
|
292
|
+
context = await retriever.get_context_from_objects("test query", retrieved_objects=objects)
|
|
293
|
+
completion = await retriever.get_completion_from_context("test query", objects, context)
|
|
282
294
|
|
|
283
295
|
assert isinstance(completion, list)
|
|
284
296
|
assert len(completion) == 1
|
|
@@ -297,7 +309,7 @@ async def test_get_completion_with_response_model(mock_vector_engine):
|
|
|
297
309
|
mock_vector_engine.has_collection.return_value = True
|
|
298
310
|
mock_vector_engine.search.return_value = [mock_result]
|
|
299
311
|
|
|
300
|
-
retriever = TripletRetriever()
|
|
312
|
+
retriever = TripletRetriever(response_model=TestModel)
|
|
301
313
|
|
|
302
314
|
with (
|
|
303
315
|
patch(
|
|
@@ -314,7 +326,9 @@ async def test_get_completion_with_response_model(mock_vector_engine):
|
|
|
314
326
|
mock_config.caching = False
|
|
315
327
|
mock_cache_config.return_value = mock_config
|
|
316
328
|
|
|
317
|
-
|
|
329
|
+
objects = await retriever.get_retrieved_objects("test query")
|
|
330
|
+
context = await retriever.get_context_from_objects("test query", retrieved_objects=objects)
|
|
331
|
+
completion = await retriever.get_completion_from_context("test query", objects, context)
|
|
318
332
|
|
|
319
333
|
assert isinstance(completion, list)
|
|
320
334
|
assert len(completion) == 1
|
|
@@ -0,0 +1,125 @@
|
|
|
1
|
+
import pytest
|
|
2
|
+
|
|
3
|
+
from cognee.modules.search.exceptions import UnsupportedSearchTypeError
|
|
4
|
+
from cognee.modules.search.types import SearchType
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class _DummyCommunityRetriever:
|
|
8
|
+
def __init__(self, *args, **kwargs):
|
|
9
|
+
self.kwargs = kwargs
|
|
10
|
+
|
|
11
|
+
def get_completion(self, *args, **kwargs):
|
|
12
|
+
return {"kind": "completion", "init": self.kwargs, "args": args, "kwargs": kwargs}
|
|
13
|
+
|
|
14
|
+
def get_context(self, *args, **kwargs):
|
|
15
|
+
return {"kind": "context", "init": self.kwargs, "args": args, "kwargs": kwargs}
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
@pytest.mark.asyncio
|
|
19
|
+
async def test_feeling_lucky_delegates_to_select_search_type(monkeypatch):
|
|
20
|
+
import cognee.modules.search.methods.get_search_type_retriever_instance as mod
|
|
21
|
+
from cognee.modules.retrieval.chunks_retriever import ChunksRetriever
|
|
22
|
+
|
|
23
|
+
async def _fake_select_search_type(query_text: str):
|
|
24
|
+
assert query_text == "hello"
|
|
25
|
+
return SearchType.CHUNKS
|
|
26
|
+
|
|
27
|
+
monkeypatch.setattr(mod, "select_search_type", _fake_select_search_type)
|
|
28
|
+
|
|
29
|
+
retriever_instance = await mod.get_search_type_retriever_instance(
|
|
30
|
+
SearchType.FEELING_LUCKY, query_text="hello"
|
|
31
|
+
)
|
|
32
|
+
assert isinstance(retriever_instance, ChunksRetriever)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
@pytest.mark.asyncio
|
|
36
|
+
async def test_disallowed_cypher_search_types_raise(monkeypatch):
|
|
37
|
+
import cognee.modules.search.methods.get_search_type_retriever_instance as mod
|
|
38
|
+
|
|
39
|
+
monkeypatch.setenv("ALLOW_CYPHER_QUERY", "false")
|
|
40
|
+
|
|
41
|
+
with pytest.raises(UnsupportedSearchTypeError, match="disabled"):
|
|
42
|
+
await mod.get_search_type_retriever_instance(
|
|
43
|
+
SearchType.CYPHER, query_text="MATCH (n) RETURN n"
|
|
44
|
+
)
|
|
45
|
+
|
|
46
|
+
with pytest.raises(UnsupportedSearchTypeError, match="disabled"):
|
|
47
|
+
await mod.get_search_type_retriever_instance(
|
|
48
|
+
SearchType.NATURAL_LANGUAGE, query_text="Find nodes"
|
|
49
|
+
)
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
@pytest.mark.asyncio
|
|
53
|
+
async def test_allowed_cypher_search_types_return_tools(monkeypatch):
|
|
54
|
+
import cognee.modules.search.methods.get_search_type_retriever_instance as mod
|
|
55
|
+
from cognee.modules.retrieval.cypher_search_retriever import CypherSearchRetriever
|
|
56
|
+
|
|
57
|
+
monkeypatch.setenv("ALLOW_CYPHER_QUERY", "true")
|
|
58
|
+
|
|
59
|
+
retriever_instance = await mod.get_search_type_retriever_instance(
|
|
60
|
+
SearchType.CYPHER, query_text="q"
|
|
61
|
+
)
|
|
62
|
+
assert isinstance(retriever_instance, CypherSearchRetriever)
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
@pytest.mark.asyncio
|
|
66
|
+
async def test_registered_community_retriever_is_used(monkeypatch):
|
|
67
|
+
"""
|
|
68
|
+
Integration point: community retrievers are loaded from the registry module and should
|
|
69
|
+
override the default mapping when present.
|
|
70
|
+
"""
|
|
71
|
+
import cognee.modules.search.methods.get_search_type_retriever_instance as mod
|
|
72
|
+
from cognee.modules.retrieval import registered_community_retrievers as registry
|
|
73
|
+
|
|
74
|
+
monkeypatch.setattr(
|
|
75
|
+
registry,
|
|
76
|
+
"registered_community_retrievers",
|
|
77
|
+
{SearchType.SUMMARIES: _DummyCommunityRetriever},
|
|
78
|
+
)
|
|
79
|
+
|
|
80
|
+
retriever_instance = await mod.get_search_type_retriever_instance(
|
|
81
|
+
SearchType.SUMMARIES, query_text="q", top_k=7
|
|
82
|
+
)
|
|
83
|
+
|
|
84
|
+
assert isinstance(retriever_instance, _DummyCommunityRetriever)
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
@pytest.mark.asyncio
|
|
88
|
+
async def test_unknown_query_type_raises_unsupported():
|
|
89
|
+
import cognee.modules.search.methods.get_search_type_retriever_instance as mod
|
|
90
|
+
|
|
91
|
+
with pytest.raises(UnsupportedSearchTypeError, match="UNKNOWN_TYPE"):
|
|
92
|
+
await mod.get_search_type_retriever_instance("UNKNOWN_TYPE", query_text="q")
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
@pytest.mark.asyncio
|
|
96
|
+
async def test_default_mapping_passes_top_k_to_retrievers():
|
|
97
|
+
import cognee.modules.search.methods.get_search_type_retriever_instance as mod
|
|
98
|
+
from cognee.modules.retrieval.summaries_retriever import SummariesRetriever
|
|
99
|
+
|
|
100
|
+
retriever_instance = await mod.get_search_type_retriever_instance(
|
|
101
|
+
SearchType.SUMMARIES, query_text="q", top_k=4
|
|
102
|
+
)
|
|
103
|
+
assert isinstance(retriever_instance, SummariesRetriever)
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
@pytest.mark.asyncio
|
|
107
|
+
async def test_chunks_lexical_returns_jaccard_tools():
|
|
108
|
+
import cognee.modules.search.methods.get_search_type_retriever_instance as mod
|
|
109
|
+
from cognee.modules.retrieval.jaccard_retrival import JaccardChunksRetriever
|
|
110
|
+
|
|
111
|
+
retriever_instance = await mod.get_search_type_retriever_instance(
|
|
112
|
+
SearchType.CHUNKS_LEXICAL, query_text="q", top_k=3
|
|
113
|
+
)
|
|
114
|
+
assert isinstance(retriever_instance, JaccardChunksRetriever)
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
@pytest.mark.asyncio
|
|
118
|
+
async def test_coding_rules_uses_node_name_as_rules_nodeset_name():
|
|
119
|
+
import cognee.modules.search.methods.get_search_type_retriever_instance as mod
|
|
120
|
+
from cognee.modules.retrieval.coding_rules_retriever import CodingRulesRetriever
|
|
121
|
+
|
|
122
|
+
retriever_instance = await mod.get_search_type_retriever_instance(
|
|
123
|
+
SearchType.CODING_RULES, query_text="q", node_name=[]
|
|
124
|
+
)
|
|
125
|
+
assert isinstance(retriever_instance, CodingRulesRetriever)
|