cognee 0.5.1.dev0__py3-none-any.whl → 0.5.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cognee/__init__.py +2 -0
- cognee/alembic/README +1 -0
- cognee/alembic/env.py +107 -0
- cognee/alembic/script.py.mako +26 -0
- cognee/alembic/versions/1a58b986e6e1_enable_delete_for_old_tutorial_notebooks.py +52 -0
- cognee/alembic/versions/1d0bb7fede17_add_pipeline_run_status.py +33 -0
- cognee/alembic/versions/1daae0df1866_incremental_loading.py +48 -0
- cognee/alembic/versions/211ab850ef3d_add_sync_operations_table.py +118 -0
- cognee/alembic/versions/45957f0a9849_add_notebook_table.py +46 -0
- cognee/alembic/versions/46a6ce2bd2b2_expand_dataset_database_with_json_.py +333 -0
- cognee/alembic/versions/482cd6517ce4_add_default_user.py +30 -0
- cognee/alembic/versions/76625596c5c3_expand_dataset_database_for_multi_user.py +98 -0
- cognee/alembic/versions/8057ae7329c2_initial_migration.py +25 -0
- cognee/alembic/versions/9e7a3cb85175_loader_separation.py +104 -0
- cognee/alembic/versions/a1b2c3d4e5f6_add_label_column_to_data.py +38 -0
- cognee/alembic/versions/ab7e313804ae_permission_system_rework.py +236 -0
- cognee/alembic/versions/b9274c27a25a_kuzu_11_migration.py +75 -0
- cognee/alembic/versions/c946955da633_multi_tenant_support.py +137 -0
- cognee/alembic/versions/e1ec1dcb50b6_add_last_accessed_to_data.py +51 -0
- cognee/alembic/versions/e4ebee1091e7_expand_data_model_info.py +140 -0
- cognee/alembic.ini +117 -0
- cognee/api/v1/add/routers/get_add_router.py +2 -0
- cognee/api/v1/cognify/cognify.py +11 -6
- cognee/api/v1/cognify/routers/get_cognify_router.py +8 -0
- cognee/api/v1/config/config.py +60 -0
- cognee/api/v1/datasets/routers/get_datasets_router.py +45 -3
- cognee/api/v1/memify/routers/get_memify_router.py +2 -0
- cognee/api/v1/search/routers/get_search_router.py +21 -6
- cognee/api/v1/search/search.py +25 -5
- cognee/api/v1/sync/routers/get_sync_router.py +3 -3
- cognee/cli/commands/add_command.py +1 -1
- cognee/cli/commands/cognify_command.py +6 -0
- cognee/cli/commands/config_command.py +1 -1
- cognee/context_global_variables.py +5 -1
- cognee/eval_framework/answer_generation/answer_generation_executor.py +7 -8
- cognee/infrastructure/databases/cache/cache_db_interface.py +38 -1
- cognee/infrastructure/databases/cache/config.py +6 -0
- cognee/infrastructure/databases/cache/fscache/FsCacheAdapter.py +21 -0
- cognee/infrastructure/databases/cache/get_cache_engine.py +9 -3
- cognee/infrastructure/databases/cache/redis/RedisAdapter.py +60 -1
- cognee/infrastructure/databases/dataset_database_handler/supported_dataset_database_handlers.py +7 -0
- cognee/infrastructure/databases/graph/get_graph_engine.py +29 -1
- cognee/infrastructure/databases/graph/neo4j_driver/Neo4jAuraDevDatasetDatabaseHandler.py +62 -27
- cognee/infrastructure/databases/hybrid/neptune_analytics/NeptuneAnalyticsAdapter.py +17 -4
- cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py +2 -1
- cognee/infrastructure/databases/vector/chromadb/ChromaDBAdapter.py +2 -0
- cognee/infrastructure/databases/vector/config.py +6 -0
- cognee/infrastructure/databases/vector/create_vector_engine.py +69 -22
- cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py +64 -9
- cognee/infrastructure/databases/vector/embeddings/OllamaEmbeddingEngine.py +13 -2
- cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py +16 -3
- cognee/infrastructure/databases/vector/models/ScoredResult.py +3 -3
- cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py +16 -3
- cognee/infrastructure/databases/vector/pgvector/PGVectorDatasetDatabaseHandler.py +86 -0
- cognee/infrastructure/databases/vector/pgvector/create_db_and_tables.py +81 -2
- cognee/infrastructure/databases/vector/vector_db_interface.py +8 -0
- cognee/infrastructure/files/utils/get_data_file_path.py +33 -27
- cognee/infrastructure/llm/prompts/extract_query_time.txt +1 -1
- cognee/infrastructure/llm/prompts/generate_event_entity_prompt.txt +1 -1
- cognee/infrastructure/llm/prompts/generate_event_graph_prompt.txt +1 -1
- cognee/infrastructure/llm/prompts/generate_graph_prompt.txt +2 -2
- cognee/infrastructure/llm/prompts/generate_graph_prompt_guided.txt +1 -1
- cognee/infrastructure/llm/prompts/generate_graph_prompt_oneshot.txt +2 -2
- cognee/infrastructure/llm/prompts/generate_graph_prompt_simple.txt +1 -1
- cognee/infrastructure/llm/prompts/generate_graph_prompt_strict.txt +1 -1
- cognee/infrastructure/llm/prompts/search_type_selector_prompt.txt +6 -6
- cognee/infrastructure/llm/prompts/test.txt +1 -1
- cognee/infrastructure/llm/prompts/translate_content.txt +19 -0
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py +24 -0
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/llama_cpp/adapter.py +191 -0
- cognee/modules/chunking/models/DocumentChunk.py +0 -1
- cognee/modules/cognify/config.py +2 -0
- cognee/modules/data/models/Data.py +1 -0
- cognee/modules/engine/models/Entity.py +0 -1
- cognee/modules/engine/operations/setup.py +6 -0
- cognee/modules/graph/cognee_graph/CogneeGraph.py +150 -37
- cognee/modules/graph/cognee_graph/CogneeGraphElements.py +48 -2
- cognee/modules/graph/utils/__init__.py +1 -0
- cognee/modules/graph/utils/get_entity_nodes_from_triplets.py +12 -0
- cognee/modules/notebooks/methods/__init__.py +1 -0
- cognee/modules/notebooks/methods/create_notebook.py +0 -34
- cognee/modules/notebooks/methods/create_tutorial_notebooks.py +191 -0
- cognee/modules/notebooks/methods/get_notebooks.py +12 -8
- cognee/modules/notebooks/tutorials/cognee-basics/cell-1.md +3 -0
- cognee/modules/notebooks/tutorials/cognee-basics/cell-2.md +10 -0
- cognee/modules/notebooks/tutorials/cognee-basics/cell-3.md +7 -0
- cognee/modules/notebooks/tutorials/cognee-basics/cell-4.py +28 -0
- cognee/modules/notebooks/tutorials/cognee-basics/cell-5.py +3 -0
- cognee/modules/notebooks/tutorials/cognee-basics/cell-6.py +9 -0
- cognee/modules/notebooks/tutorials/cognee-basics/cell-7.py +17 -0
- cognee/modules/notebooks/tutorials/cognee-basics/config.json +4 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-1.md +3 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-10.md +3 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-11.md +3 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-12.py +3 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-13.md +7 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-14.py +6 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-15.md +3 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-16.py +7 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-2.md +9 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-3.md +7 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-4.md +9 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-5.md +5 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-6.py +13 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-7.md +3 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-8.md +3 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-9.py +31 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/config.json +4 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/data/copilot_conversations.json +107 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/data/guido_contributions.json +976 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/data/my_developer_rules.md +79 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/data/pep_style_guide.md +74 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/data/zen_principles.md +74 -0
- cognee/modules/retrieval/EntityCompletionRetriever.py +51 -38
- cognee/modules/retrieval/__init__.py +0 -1
- cognee/modules/retrieval/base_retriever.py +66 -10
- cognee/modules/retrieval/chunks_retriever.py +57 -49
- cognee/modules/retrieval/coding_rules_retriever.py +12 -5
- cognee/modules/retrieval/completion_retriever.py +29 -28
- cognee/modules/retrieval/cypher_search_retriever.py +25 -20
- cognee/modules/retrieval/graph_completion_context_extension_retriever.py +42 -46
- cognee/modules/retrieval/graph_completion_cot_retriever.py +68 -51
- cognee/modules/retrieval/graph_completion_retriever.py +78 -63
- cognee/modules/retrieval/graph_summary_completion_retriever.py +2 -0
- cognee/modules/retrieval/lexical_retriever.py +34 -12
- cognee/modules/retrieval/natural_language_retriever.py +18 -15
- cognee/modules/retrieval/summaries_retriever.py +51 -34
- cognee/modules/retrieval/temporal_retriever.py +59 -49
- cognee/modules/retrieval/triplet_retriever.py +31 -32
- cognee/modules/retrieval/utils/access_tracking.py +88 -0
- cognee/modules/retrieval/utils/brute_force_triplet_search.py +99 -85
- cognee/modules/retrieval/utils/node_edge_vector_search.py +174 -0
- cognee/modules/search/methods/__init__.py +1 -0
- cognee/modules/search/methods/get_retriever_output.py +53 -0
- cognee/modules/search/methods/get_search_type_retriever_instance.py +252 -0
- cognee/modules/search/methods/search.py +90 -215
- cognee/modules/search/models/SearchResultPayload.py +67 -0
- cognee/modules/search/types/SearchResult.py +1 -8
- cognee/modules/search/types/SearchType.py +1 -2
- cognee/modules/search/types/__init__.py +1 -1
- cognee/modules/search/utils/__init__.py +1 -2
- cognee/modules/search/utils/transform_insights_to_graph.py +2 -2
- cognee/modules/search/utils/{transform_context_to_graph.py → transform_triplets_to_graph.py} +2 -2
- cognee/modules/users/authentication/default/default_transport.py +11 -1
- cognee/modules/users/authentication/get_api_auth_backend.py +2 -1
- cognee/modules/users/authentication/get_client_auth_backend.py +2 -1
- cognee/modules/users/methods/create_user.py +0 -9
- cognee/modules/users/permissions/methods/has_user_management_permission.py +29 -0
- cognee/modules/visualization/cognee_network_visualization.py +1 -1
- cognee/run_migrations.py +48 -0
- cognee/shared/exceptions/__init__.py +1 -3
- cognee/shared/exceptions/exceptions.py +11 -1
- cognee/shared/usage_logger.py +332 -0
- cognee/shared/utils.py +12 -5
- cognee/tasks/chunks/__init__.py +9 -0
- cognee/tasks/cleanup/cleanup_unused_data.py +172 -0
- cognee/tasks/graph/__init__.py +7 -0
- cognee/tasks/memify/__init__.py +8 -0
- cognee/tasks/memify/extract_usage_frequency.py +613 -0
- cognee/tasks/summarization/models.py +0 -2
- cognee/tasks/temporal_graph/__init__.py +0 -1
- cognee/tasks/translation/__init__.py +96 -0
- cognee/tasks/translation/config.py +110 -0
- cognee/tasks/translation/detect_language.py +190 -0
- cognee/tasks/translation/exceptions.py +62 -0
- cognee/tasks/translation/models.py +72 -0
- cognee/tasks/translation/providers/__init__.py +44 -0
- cognee/tasks/translation/providers/azure_provider.py +192 -0
- cognee/tasks/translation/providers/base.py +85 -0
- cognee/tasks/translation/providers/google_provider.py +158 -0
- cognee/tasks/translation/providers/llm_provider.py +143 -0
- cognee/tasks/translation/translate_content.py +282 -0
- cognee/tasks/web_scraper/default_url_crawler.py +6 -2
- cognee/tests/cli_tests/cli_unit_tests/test_cli_commands.py +1 -0
- cognee/tests/cli_tests/cli_unit_tests/test_cli_edge_cases.py +3 -0
- cognee/tests/integration/retrieval/test_brute_force_triplet_search_with_cognify.py +62 -0
- cognee/tests/integration/retrieval/test_chunks_retriever.py +115 -16
- cognee/tests/integration/retrieval/test_graph_completion_retriever.py +13 -5
- cognee/tests/integration/retrieval/test_graph_completion_retriever_context_extension.py +22 -20
- cognee/tests/integration/retrieval/test_graph_completion_retriever_cot.py +23 -24
- cognee/tests/integration/retrieval/test_rag_completion_retriever.py +70 -5
- cognee/tests/integration/retrieval/test_structured_output.py +62 -18
- cognee/tests/integration/retrieval/test_summaries_retriever.py +20 -9
- cognee/tests/integration/retrieval/test_temporal_retriever.py +38 -8
- cognee/tests/integration/retrieval/test_triplet_retriever.py +13 -4
- cognee/tests/integration/shared/test_usage_logger_integration.py +255 -0
- cognee/tests/tasks/translation/README.md +147 -0
- cognee/tests/tasks/translation/__init__.py +1 -0
- cognee/tests/tasks/translation/config_test.py +93 -0
- cognee/tests/tasks/translation/detect_language_test.py +118 -0
- cognee/tests/tasks/translation/providers_test.py +151 -0
- cognee/tests/tasks/translation/translate_content_test.py +213 -0
- cognee/tests/test_chromadb.py +1 -1
- cognee/tests/test_cleanup_unused_data.py +165 -0
- cognee/tests/test_delete_by_id.py +6 -6
- cognee/tests/test_extract_usage_frequency.py +308 -0
- cognee/tests/test_kuzu.py +17 -7
- cognee/tests/test_lancedb.py +3 -1
- cognee/tests/test_library.py +1 -1
- cognee/tests/test_neo4j.py +17 -7
- cognee/tests/test_neptune_analytics_vector.py +3 -1
- cognee/tests/test_permissions.py +172 -187
- cognee/tests/test_pgvector.py +3 -1
- cognee/tests/test_relational_db_migration.py +15 -1
- cognee/tests/test_remote_kuzu.py +3 -1
- cognee/tests/test_s3_file_storage.py +1 -1
- cognee/tests/test_search_db.py +97 -110
- cognee/tests/test_usage_logger_e2e.py +268 -0
- cognee/tests/unit/api/test_get_raw_data_endpoint.py +206 -0
- cognee/tests/unit/eval_framework/answer_generation_test.py +4 -3
- cognee/tests/unit/infrastructure/databases/cache/test_cache_config.py +2 -0
- cognee/tests/unit/modules/graph/cognee_graph_elements_test.py +42 -2
- cognee/tests/unit/modules/graph/cognee_graph_test.py +329 -31
- cognee/tests/unit/modules/retrieval/chunks_retriever_test.py +31 -59
- cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py +70 -33
- cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py +72 -52
- cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py +27 -33
- cognee/tests/unit/modules/retrieval/rag_completion_retriever_test.py +28 -15
- cognee/tests/unit/modules/retrieval/summaries_retriever_test.py +37 -42
- cognee/tests/unit/modules/retrieval/temporal_retriever_test.py +48 -64
- cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py +263 -24
- cognee/tests/unit/modules/retrieval/test_node_edge_vector_search.py +273 -0
- cognee/tests/unit/modules/retrieval/triplet_retriever_test.py +30 -16
- cognee/tests/unit/modules/search/test_get_search_type_retriever_instance.py +125 -0
- cognee/tests/unit/modules/search/test_search.py +176 -0
- cognee/tests/unit/modules/search/test_search_prepare_search_result_contract.py +190 -0
- cognee/tests/unit/modules/users/test_tutorial_notebook_creation.py +511 -297
- cognee/tests/unit/shared/test_usage_logger.py +241 -0
- cognee/tests/unit/users/permissions/test_has_user_management_permission.py +46 -0
- {cognee-0.5.1.dev0.dist-info → cognee-0.5.2.dist-info}/METADATA +22 -17
- {cognee-0.5.1.dev0.dist-info → cognee-0.5.2.dist-info}/RECORD +235 -147
- cognee/api/.env.example +0 -5
- cognee/modules/retrieval/base_graph_retriever.py +0 -24
- cognee/modules/search/methods/get_search_type_tools.py +0 -223
- cognee/modules/search/methods/no_access_control_search.py +0 -62
- cognee/modules/search/utils/prepare_search_result.py +0 -63
- cognee/tests/test_feedback_enrichment.py +0 -174
- {cognee-0.5.1.dev0.dist-info → cognee-0.5.2.dist-info}/WHEEL +0 -0
- {cognee-0.5.1.dev0.dist-info → cognee-0.5.2.dist-info}/entry_points.txt +0 -0
- {cognee-0.5.1.dev0.dist-info → cognee-0.5.2.dist-info}/licenses/LICENSE +0 -0
- {cognee-0.5.1.dev0.dist-info → cognee-0.5.2.dist-info}/licenses/NOTICE.md +0 -0
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
import pytest
|
|
2
2
|
from unittest.mock import AsyncMock
|
|
3
3
|
|
|
4
|
+
from cognee.modules.engine.utils.generate_edge_id import generate_edge_id
|
|
4
5
|
from cognee.modules.graph.exceptions import EntityNotFoundError, EntityAlreadyExistsError
|
|
5
6
|
from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph
|
|
6
7
|
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge, Node
|
|
@@ -200,6 +201,37 @@ async def test_project_graph_from_db_empty_graph(setup_graph, mock_adapter):
|
|
|
200
201
|
)
|
|
201
202
|
|
|
202
203
|
|
|
204
|
+
@pytest.mark.asyncio
|
|
205
|
+
async def test_project_graph_from_db_stores_triplet_penalty_on_graph(mock_adapter):
|
|
206
|
+
"""Test that project_graph_from_db stores triplet_distance_penalty on the graph."""
|
|
207
|
+
from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph
|
|
208
|
+
|
|
209
|
+
nodes_data = [("1", {"name": "Node1"})]
|
|
210
|
+
edges_data = [("1", "1", "SELF", {})]
|
|
211
|
+
|
|
212
|
+
mock_adapter.get_graph_data = AsyncMock(return_value=(nodes_data, edges_data))
|
|
213
|
+
|
|
214
|
+
graph = CogneeGraph()
|
|
215
|
+
custom_penalty = 5.0
|
|
216
|
+
await graph.project_graph_from_db(
|
|
217
|
+
adapter=mock_adapter,
|
|
218
|
+
node_properties_to_project=["name"],
|
|
219
|
+
edge_properties_to_project=[],
|
|
220
|
+
triplet_distance_penalty=custom_penalty,
|
|
221
|
+
)
|
|
222
|
+
|
|
223
|
+
assert graph.triplet_distance_penalty == custom_penalty
|
|
224
|
+
|
|
225
|
+
graph2 = CogneeGraph()
|
|
226
|
+
await graph2.project_graph_from_db(
|
|
227
|
+
adapter=mock_adapter,
|
|
228
|
+
node_properties_to_project=["name"],
|
|
229
|
+
edge_properties_to_project=[],
|
|
230
|
+
)
|
|
231
|
+
|
|
232
|
+
assert graph2.triplet_distance_penalty == 3.5
|
|
233
|
+
|
|
234
|
+
|
|
203
235
|
@pytest.mark.asyncio
|
|
204
236
|
async def test_project_graph_from_db_missing_nodes(setup_graph, mock_adapter):
|
|
205
237
|
"""Test that edges referencing missing nodes raise error."""
|
|
@@ -241,8 +273,8 @@ async def test_map_vector_distances_to_graph_nodes(setup_graph):
|
|
|
241
273
|
|
|
242
274
|
await graph.map_vector_distances_to_graph_nodes(node_distances)
|
|
243
275
|
|
|
244
|
-
assert graph.get_node("1").attributes.get("vector_distance") == 0.95
|
|
245
|
-
assert graph.get_node("2").attributes.get("vector_distance") == 0.87
|
|
276
|
+
assert graph.get_node("1").attributes.get("vector_distance") == [0.95]
|
|
277
|
+
assert graph.get_node("2").attributes.get("vector_distance") == [0.87]
|
|
246
278
|
|
|
247
279
|
|
|
248
280
|
@pytest.mark.asyncio
|
|
@@ -266,9 +298,9 @@ async def test_map_vector_distances_partial_node_coverage(setup_graph):
|
|
|
266
298
|
|
|
267
299
|
await graph.map_vector_distances_to_graph_nodes(node_distances)
|
|
268
300
|
|
|
269
|
-
assert graph.get_node("1").attributes.get("vector_distance") == 0.95
|
|
270
|
-
assert graph.get_node("2").attributes.get("vector_distance") == 0.87
|
|
271
|
-
assert graph.get_node("3").attributes.get("vector_distance") == 3.5
|
|
301
|
+
assert graph.get_node("1").attributes.get("vector_distance") == [0.95]
|
|
302
|
+
assert graph.get_node("2").attributes.get("vector_distance") == [0.87]
|
|
303
|
+
assert graph.get_node("3").attributes.get("vector_distance") == [3.5]
|
|
272
304
|
|
|
273
305
|
|
|
274
306
|
@pytest.mark.asyncio
|
|
@@ -298,10 +330,36 @@ async def test_map_vector_distances_multiple_categories(setup_graph):
|
|
|
298
330
|
|
|
299
331
|
await graph.map_vector_distances_to_graph_nodes(node_distances)
|
|
300
332
|
|
|
301
|
-
assert graph.get_node("1").attributes.get("vector_distance") == 0.95
|
|
302
|
-
assert graph.get_node("2").attributes.get("vector_distance") == 0.87
|
|
303
|
-
assert graph.get_node("3").attributes.get("vector_distance") == 0.92
|
|
304
|
-
assert graph.get_node("4").attributes.get("vector_distance") == 3.5
|
|
333
|
+
assert graph.get_node("1").attributes.get("vector_distance") == [0.95]
|
|
334
|
+
assert graph.get_node("2").attributes.get("vector_distance") == [0.87]
|
|
335
|
+
assert graph.get_node("3").attributes.get("vector_distance") == [0.92]
|
|
336
|
+
assert graph.get_node("4").attributes.get("vector_distance") == [3.5]
|
|
337
|
+
|
|
338
|
+
|
|
339
|
+
@pytest.mark.asyncio
|
|
340
|
+
async def test_map_vector_distances_to_graph_nodes_multi_query(setup_graph):
|
|
341
|
+
"""Test mapping vector distances with multiple queries."""
|
|
342
|
+
graph = setup_graph
|
|
343
|
+
|
|
344
|
+
node1 = Node("1")
|
|
345
|
+
node2 = Node("2")
|
|
346
|
+
node3 = Node("3")
|
|
347
|
+
graph.add_node(node1)
|
|
348
|
+
graph.add_node(node2)
|
|
349
|
+
graph.add_node(node3)
|
|
350
|
+
|
|
351
|
+
node_distances = {
|
|
352
|
+
"Entity_name": [
|
|
353
|
+
[MockScoredResult("1", 0.95)], # query 0
|
|
354
|
+
[MockScoredResult("2", 0.87)], # query 1
|
|
355
|
+
]
|
|
356
|
+
}
|
|
357
|
+
|
|
358
|
+
await graph.map_vector_distances_to_graph_nodes(node_distances, query_list_length=2)
|
|
359
|
+
|
|
360
|
+
assert graph.get_node("1").attributes.get("vector_distance") == [0.95, 3.5]
|
|
361
|
+
assert graph.get_node("2").attributes.get("vector_distance") == [3.5, 0.87]
|
|
362
|
+
assert graph.get_node("3").attributes.get("vector_distance") == [3.5, 3.5]
|
|
305
363
|
|
|
306
364
|
|
|
307
365
|
@pytest.mark.asyncio
|
|
@@ -322,12 +380,12 @@ async def test_map_vector_distances_to_graph_edges_with_payload(setup_graph):
|
|
|
322
380
|
graph.add_edge(edge)
|
|
323
381
|
|
|
324
382
|
edge_distances = [
|
|
325
|
-
MockScoredResult("
|
|
383
|
+
MockScoredResult(generate_edge_id("CONNECTS_TO"), 0.92, payload={"text": "CONNECTS_TO"}),
|
|
326
384
|
]
|
|
327
385
|
|
|
328
386
|
await graph.map_vector_distances_to_graph_edges(edge_distances=edge_distances)
|
|
329
387
|
|
|
330
|
-
assert graph.edges[0].attributes.get("vector_distance") == 0.92
|
|
388
|
+
assert graph.edges[0].attributes.get("vector_distance") == [0.92]
|
|
331
389
|
|
|
332
390
|
|
|
333
391
|
@pytest.mark.asyncio
|
|
@@ -347,14 +405,15 @@ async def test_map_vector_distances_partial_edge_coverage(setup_graph):
|
|
|
347
405
|
graph.add_edge(edge1)
|
|
348
406
|
graph.add_edge(edge2)
|
|
349
407
|
|
|
408
|
+
edge_1_text = "CONNECTS_TO"
|
|
350
409
|
edge_distances = [
|
|
351
|
-
MockScoredResult(
|
|
410
|
+
MockScoredResult(generate_edge_id(edge_1_text), 0.92, payload={"text": edge_1_text}),
|
|
352
411
|
]
|
|
353
412
|
|
|
354
413
|
await graph.map_vector_distances_to_graph_edges(edge_distances=edge_distances)
|
|
355
414
|
|
|
356
|
-
assert graph.edges[0].attributes.get("vector_distance") == 0.92
|
|
357
|
-
assert graph.edges[1].attributes.get("vector_distance") == 3.5
|
|
415
|
+
assert graph.edges[0].attributes.get("vector_distance") == [0.92]
|
|
416
|
+
assert graph.edges[1].attributes.get("vector_distance") == [3.5]
|
|
358
417
|
|
|
359
418
|
|
|
360
419
|
@pytest.mark.asyncio
|
|
@@ -374,13 +433,14 @@ async def test_map_vector_distances_edges_fallback_to_relationship_type(setup_gr
|
|
|
374
433
|
)
|
|
375
434
|
graph.add_edge(edge)
|
|
376
435
|
|
|
436
|
+
edge_text = "KNOWS"
|
|
377
437
|
edge_distances = [
|
|
378
|
-
MockScoredResult(
|
|
438
|
+
MockScoredResult(generate_edge_id(edge_text), 0.85, payload={"text": edge_text}),
|
|
379
439
|
]
|
|
380
440
|
|
|
381
441
|
await graph.map_vector_distances_to_graph_edges(edge_distances=edge_distances)
|
|
382
442
|
|
|
383
|
-
assert graph.edges[0].attributes.get("vector_distance") == 0.85
|
|
443
|
+
assert graph.edges[0].attributes.get("vector_distance") == [0.85]
|
|
384
444
|
|
|
385
445
|
|
|
386
446
|
@pytest.mark.asyncio
|
|
@@ -400,18 +460,19 @@ async def test_map_vector_distances_no_edge_matches(setup_graph):
|
|
|
400
460
|
)
|
|
401
461
|
graph.add_edge(edge)
|
|
402
462
|
|
|
463
|
+
edge_text = "SOME_OTHER_EDGE"
|
|
403
464
|
edge_distances = [
|
|
404
|
-
MockScoredResult(
|
|
465
|
+
MockScoredResult(generate_edge_id(edge_text), 0.92, payload={"text": edge_text}),
|
|
405
466
|
]
|
|
406
467
|
|
|
407
468
|
await graph.map_vector_distances_to_graph_edges(edge_distances=edge_distances)
|
|
408
469
|
|
|
409
|
-
assert graph.edges[0].attributes.get("vector_distance") == 3.5
|
|
470
|
+
assert graph.edges[0].attributes.get("vector_distance") == [3.5]
|
|
410
471
|
|
|
411
472
|
|
|
412
473
|
@pytest.mark.asyncio
|
|
413
474
|
async def test_map_vector_distances_none_returns_early(setup_graph):
|
|
414
|
-
"""Test that edge_distances=None returns early without error."""
|
|
475
|
+
"""Test that edge_distances=None returns early without error and vector_distance is set to default penalty."""
|
|
415
476
|
graph = setup_graph
|
|
416
477
|
graph.add_node(Node("1"))
|
|
417
478
|
graph.add_node(Node("2"))
|
|
@@ -419,7 +480,91 @@ async def test_map_vector_distances_none_returns_early(setup_graph):
|
|
|
419
480
|
|
|
420
481
|
await graph.map_vector_distances_to_graph_edges(edge_distances=None)
|
|
421
482
|
|
|
422
|
-
assert graph.edges[0].attributes.get("vector_distance") == 3.5
|
|
483
|
+
assert graph.edges[0].attributes.get("vector_distance") == [3.5]
|
|
484
|
+
|
|
485
|
+
|
|
486
|
+
@pytest.mark.asyncio
|
|
487
|
+
async def test_map_vector_distances_empty_nodes_returns_early(setup_graph):
|
|
488
|
+
"""Test that node_distances={} returns early without error and vector_distance is set to default penalty."""
|
|
489
|
+
graph = setup_graph
|
|
490
|
+
node1 = Node("1")
|
|
491
|
+
node2 = Node("2")
|
|
492
|
+
graph.add_node(node1)
|
|
493
|
+
graph.add_node(node2)
|
|
494
|
+
|
|
495
|
+
await graph.map_vector_distances_to_graph_nodes({})
|
|
496
|
+
|
|
497
|
+
assert node1.attributes.get("vector_distance") == [3.5]
|
|
498
|
+
assert node2.attributes.get("vector_distance") == [3.5]
|
|
499
|
+
|
|
500
|
+
|
|
501
|
+
@pytest.mark.asyncio
|
|
502
|
+
async def test_map_vector_distances_to_graph_edges_multi_query(setup_graph):
|
|
503
|
+
"""Test mapping edge distances with multiple queries."""
|
|
504
|
+
graph = setup_graph
|
|
505
|
+
|
|
506
|
+
node1 = Node("1")
|
|
507
|
+
node2 = Node("2")
|
|
508
|
+
node3 = Node("3")
|
|
509
|
+
graph.add_node(node1)
|
|
510
|
+
graph.add_node(node2)
|
|
511
|
+
graph.add_node(node3)
|
|
512
|
+
|
|
513
|
+
edge1 = Edge(node1, node2, attributes={"edge_text": "A"})
|
|
514
|
+
edge2 = Edge(node2, node3, attributes={"edge_text": "B"})
|
|
515
|
+
graph.add_edge(edge1)
|
|
516
|
+
graph.add_edge(edge2)
|
|
517
|
+
|
|
518
|
+
edge_1_text = "A"
|
|
519
|
+
edge_2_text = "B"
|
|
520
|
+
edge_distances = [
|
|
521
|
+
[
|
|
522
|
+
MockScoredResult(generate_edge_id(edge_1_text), 0.1, payload={"text": edge_1_text})
|
|
523
|
+
], # query 0
|
|
524
|
+
[
|
|
525
|
+
MockScoredResult(generate_edge_id(edge_2_text), 0.2, payload={"text": edge_2_text})
|
|
526
|
+
], # query 1
|
|
527
|
+
]
|
|
528
|
+
|
|
529
|
+
await graph.map_vector_distances_to_graph_edges(
|
|
530
|
+
edge_distances=edge_distances, query_list_length=2
|
|
531
|
+
)
|
|
532
|
+
|
|
533
|
+
assert graph.edges[0].attributes.get("vector_distance") == [0.1, 3.5]
|
|
534
|
+
assert graph.edges[1].attributes.get("vector_distance") == [3.5, 0.2]
|
|
535
|
+
|
|
536
|
+
|
|
537
|
+
@pytest.mark.asyncio
|
|
538
|
+
async def test_map_vector_distances_to_graph_edges_preserves_unmapped_indices(setup_graph):
|
|
539
|
+
"""Test that unmapped indices in multi-query mode stay at default penalty."""
|
|
540
|
+
graph = setup_graph
|
|
541
|
+
|
|
542
|
+
node1 = Node("1")
|
|
543
|
+
node2 = Node("2")
|
|
544
|
+
node3 = Node("3")
|
|
545
|
+
graph.add_node(node1)
|
|
546
|
+
graph.add_node(node2)
|
|
547
|
+
graph.add_node(node3)
|
|
548
|
+
|
|
549
|
+
edge1 = Edge(node1, node2, attributes={"edge_text": "A"})
|
|
550
|
+
edge2 = Edge(node2, node3, attributes={"edge_text": "B"})
|
|
551
|
+
graph.add_edge(edge1)
|
|
552
|
+
graph.add_edge(edge2)
|
|
553
|
+
|
|
554
|
+
edge_1_text = "A"
|
|
555
|
+
edge_distances = [
|
|
556
|
+
[
|
|
557
|
+
MockScoredResult(generate_edge_id(edge_1_text), 0.1, payload={"text": edge_1_text})
|
|
558
|
+
], # query 0: only edge1 mapped
|
|
559
|
+
[], # query 1: no edges mapped
|
|
560
|
+
]
|
|
561
|
+
|
|
562
|
+
await graph.map_vector_distances_to_graph_edges(
|
|
563
|
+
edge_distances=edge_distances, query_list_length=2
|
|
564
|
+
)
|
|
565
|
+
|
|
566
|
+
assert graph.edges[0].attributes.get("vector_distance") == [0.1, 3.5]
|
|
567
|
+
assert graph.edges[1].attributes.get("vector_distance") == [3.5, 3.5]
|
|
423
568
|
|
|
424
569
|
|
|
425
570
|
@pytest.mark.asyncio
|
|
@@ -432,10 +577,10 @@ async def test_calculate_top_triplet_importances(setup_graph):
|
|
|
432
577
|
node3 = Node("3")
|
|
433
578
|
node4 = Node("4")
|
|
434
579
|
|
|
435
|
-
node1.add_attribute("vector_distance", 0.9)
|
|
436
|
-
node2.add_attribute("vector_distance", 0.8)
|
|
437
|
-
node3.add_attribute("vector_distance", 0.7)
|
|
438
|
-
node4.add_attribute("vector_distance", 0.6)
|
|
580
|
+
node1.add_attribute("vector_distance", [0.9])
|
|
581
|
+
node2.add_attribute("vector_distance", [0.8])
|
|
582
|
+
node3.add_attribute("vector_distance", [0.7])
|
|
583
|
+
node4.add_attribute("vector_distance", [0.6])
|
|
439
584
|
|
|
440
585
|
graph.add_node(node1)
|
|
441
586
|
graph.add_node(node2)
|
|
@@ -446,9 +591,9 @@ async def test_calculate_top_triplet_importances(setup_graph):
|
|
|
446
591
|
edge2 = Edge(node2, node3)
|
|
447
592
|
edge3 = Edge(node3, node4)
|
|
448
593
|
|
|
449
|
-
edge1.add_attribute("vector_distance", 0.85)
|
|
450
|
-
edge2.add_attribute("vector_distance", 0.75)
|
|
451
|
-
edge3.add_attribute("vector_distance", 0.65)
|
|
594
|
+
edge1.add_attribute("vector_distance", [0.85])
|
|
595
|
+
edge2.add_attribute("vector_distance", [0.75])
|
|
596
|
+
edge3.add_attribute("vector_distance", [0.65])
|
|
452
597
|
|
|
453
598
|
graph.add_edge(edge1)
|
|
454
599
|
graph.add_edge(edge2)
|
|
@@ -464,7 +609,112 @@ async def test_calculate_top_triplet_importances(setup_graph):
|
|
|
464
609
|
|
|
465
610
|
@pytest.mark.asyncio
|
|
466
611
|
async def test_calculate_top_triplet_importances_default_distances(setup_graph):
|
|
467
|
-
"""Test
|
|
612
|
+
"""Test that vector_distance stays None when no distances are passed and calculate_top_triplet_importances handles it."""
|
|
613
|
+
graph = setup_graph
|
|
614
|
+
|
|
615
|
+
node1 = Node("1")
|
|
616
|
+
node2 = Node("2")
|
|
617
|
+
graph.add_node(node1)
|
|
618
|
+
graph.add_node(node2)
|
|
619
|
+
|
|
620
|
+
edge = Edge(node1, node2)
|
|
621
|
+
graph.add_edge(edge)
|
|
622
|
+
|
|
623
|
+
# Verify vector_distance is None when no distances are passed
|
|
624
|
+
assert node1.attributes.get("vector_distance") is None
|
|
625
|
+
assert node2.attributes.get("vector_distance") is None
|
|
626
|
+
assert edge.attributes.get("vector_distance") is None
|
|
627
|
+
|
|
628
|
+
# When no distances are set, calculate_top_triplet_importances should handle None
|
|
629
|
+
# by either raising an error or skipping edges with None distances
|
|
630
|
+
with pytest.raises(ValueError):
|
|
631
|
+
await graph.calculate_top_triplet_importances(k=1)
|
|
632
|
+
|
|
633
|
+
|
|
634
|
+
@pytest.mark.asyncio
|
|
635
|
+
async def test_calculate_top_triplet_importances_single_query_via_helper(setup_graph):
|
|
636
|
+
"""Test calculating top triplet importances for a single query index."""
|
|
637
|
+
graph = setup_graph
|
|
638
|
+
|
|
639
|
+
node1 = Node("1")
|
|
640
|
+
node2 = Node("2")
|
|
641
|
+
node3 = Node("3")
|
|
642
|
+
graph.add_node(node1)
|
|
643
|
+
graph.add_node(node2)
|
|
644
|
+
graph.add_node(node3)
|
|
645
|
+
|
|
646
|
+
node1.add_attribute("vector_distance", [0.1])
|
|
647
|
+
node2.add_attribute("vector_distance", [0.2])
|
|
648
|
+
node3.add_attribute("vector_distance", [0.3])
|
|
649
|
+
|
|
650
|
+
edge1 = Edge(node1, node2)
|
|
651
|
+
edge2 = Edge(node2, node3)
|
|
652
|
+
graph.add_edge(edge1)
|
|
653
|
+
graph.add_edge(edge2)
|
|
654
|
+
|
|
655
|
+
edge1.add_attribute("vector_distance", [0.3])
|
|
656
|
+
edge2.add_attribute("vector_distance", [0.4])
|
|
657
|
+
|
|
658
|
+
results = await graph.calculate_top_triplet_importances(k=1, query_list_length=1)
|
|
659
|
+
assert len(results) == 1
|
|
660
|
+
assert len(results[0]) == 1
|
|
661
|
+
assert results[0][0] == edge1
|
|
662
|
+
|
|
663
|
+
|
|
664
|
+
@pytest.mark.asyncio
|
|
665
|
+
async def test_calculate_top_triplet_importances_multi_query(setup_graph):
|
|
666
|
+
"""Test calculating top triplet importances with multiple queries."""
|
|
667
|
+
graph = setup_graph
|
|
668
|
+
|
|
669
|
+
node1 = Node("1")
|
|
670
|
+
node2 = Node("2")
|
|
671
|
+
node3 = Node("3")
|
|
672
|
+
graph.add_node(node1)
|
|
673
|
+
graph.add_node(node2)
|
|
674
|
+
graph.add_node(node3)
|
|
675
|
+
|
|
676
|
+
edge_a = Edge(node1, node2)
|
|
677
|
+
edge_b = Edge(node2, node3)
|
|
678
|
+
graph.add_edge(edge_a)
|
|
679
|
+
graph.add_edge(edge_b)
|
|
680
|
+
|
|
681
|
+
node1.add_attribute("vector_distance", [0.1, 0.9])
|
|
682
|
+
node2.add_attribute("vector_distance", [0.1, 0.9])
|
|
683
|
+
node3.add_attribute("vector_distance", [0.9, 0.1])
|
|
684
|
+
edge_a.add_attribute("vector_distance", [0.1, 0.9])
|
|
685
|
+
edge_b.add_attribute("vector_distance", [0.9, 0.1])
|
|
686
|
+
|
|
687
|
+
results = await graph.calculate_top_triplet_importances(k=1, query_list_length=2)
|
|
688
|
+
|
|
689
|
+
assert len(results) == 2
|
|
690
|
+
assert results[0][0] == edge_a
|
|
691
|
+
assert results[1][0] == edge_b
|
|
692
|
+
|
|
693
|
+
|
|
694
|
+
@pytest.mark.asyncio
|
|
695
|
+
async def test_calculate_top_triplet_importances_raises_on_short_list(setup_graph):
|
|
696
|
+
"""Test that scoring raises ValueError when list is too short for query_index."""
|
|
697
|
+
graph = setup_graph
|
|
698
|
+
|
|
699
|
+
node1 = Node("1")
|
|
700
|
+
node2 = Node("2")
|
|
701
|
+
graph.add_node(node1)
|
|
702
|
+
graph.add_node(node2)
|
|
703
|
+
|
|
704
|
+
node1.add_attribute("vector_distance", [0.1])
|
|
705
|
+
node2.add_attribute("vector_distance", [0.2])
|
|
706
|
+
|
|
707
|
+
edge = Edge(node1, node2)
|
|
708
|
+
edge.add_attribute("vector_distance", [0.3])
|
|
709
|
+
graph.add_edge(edge)
|
|
710
|
+
|
|
711
|
+
with pytest.raises(ValueError):
|
|
712
|
+
await graph.calculate_top_triplet_importances(k=1, query_list_length=2)
|
|
713
|
+
|
|
714
|
+
|
|
715
|
+
@pytest.mark.asyncio
|
|
716
|
+
async def test_calculate_top_triplet_importances_raises_on_missing_attribute(setup_graph):
|
|
717
|
+
"""Test that scoring raises error when vector_distance is missing."""
|
|
468
718
|
graph = setup_graph
|
|
469
719
|
|
|
470
720
|
node1 = Node("1")
|
|
@@ -472,10 +722,58 @@ async def test_calculate_top_triplet_importances_default_distances(setup_graph):
|
|
|
472
722
|
graph.add_node(node1)
|
|
473
723
|
graph.add_node(node2)
|
|
474
724
|
|
|
725
|
+
del node1.attributes["vector_distance"]
|
|
726
|
+
del node2.attributes["vector_distance"]
|
|
727
|
+
|
|
475
728
|
edge = Edge(node1, node2)
|
|
729
|
+
del edge.attributes["vector_distance"]
|
|
476
730
|
graph.add_edge(edge)
|
|
477
731
|
|
|
478
|
-
|
|
732
|
+
with pytest.raises(ValueError):
|
|
733
|
+
await graph.calculate_top_triplet_importances(k=1, query_list_length=1)
|
|
734
|
+
|
|
735
|
+
|
|
736
|
+
def test_normalize_query_distance_lists_flat_list_single_query(setup_graph):
|
|
737
|
+
"""Test that flat list is normalized to list-of-lists with length 1 for single-query mode."""
|
|
738
|
+
graph = setup_graph
|
|
739
|
+
flat_list = [MockScoredResult("node1", 0.95), MockScoredResult("node2", 0.87)]
|
|
740
|
+
|
|
741
|
+
result = graph._normalize_query_distance_lists(flat_list, query_list_length=None, name="test")
|
|
742
|
+
|
|
743
|
+
assert len(result) == 1
|
|
744
|
+
assert result[0] == flat_list
|
|
745
|
+
|
|
746
|
+
|
|
747
|
+
def test_normalize_query_distance_lists_nested_list_batch_mode(setup_graph):
|
|
748
|
+
"""Test that nested list is used as-is when query_list_length matches."""
|
|
749
|
+
graph = setup_graph
|
|
750
|
+
nested_list = [
|
|
751
|
+
[MockScoredResult("node1", 0.95)],
|
|
752
|
+
[MockScoredResult("node2", 0.87)],
|
|
753
|
+
]
|
|
754
|
+
|
|
755
|
+
result = graph._normalize_query_distance_lists(nested_list, query_list_length=2, name="test")
|
|
756
|
+
|
|
757
|
+
assert len(result) == 2
|
|
758
|
+
assert result == nested_list
|
|
759
|
+
|
|
760
|
+
|
|
761
|
+
def test_normalize_query_distance_lists_raises_on_length_mismatch(setup_graph):
|
|
762
|
+
"""Test that ValueError is raised when nested list length doesn't match query_list_length."""
|
|
763
|
+
graph = setup_graph
|
|
764
|
+
nested_list = [
|
|
765
|
+
[MockScoredResult("node1", 0.95)],
|
|
766
|
+
[MockScoredResult("node2", 0.87)],
|
|
767
|
+
]
|
|
768
|
+
|
|
769
|
+
with pytest.raises(ValueError, match="test has 2 query lists, but query_list_length is 3"):
|
|
770
|
+
graph._normalize_query_distance_lists(nested_list, query_list_length=3, name="test")
|
|
771
|
+
|
|
772
|
+
|
|
773
|
+
def test_normalize_query_distance_lists_empty_list(setup_graph):
|
|
774
|
+
"""Test that empty list returns empty list."""
|
|
775
|
+
graph = setup_graph
|
|
776
|
+
|
|
777
|
+
result = graph._normalize_query_distance_lists([], query_list_length=None, name="test")
|
|
479
778
|
|
|
480
|
-
assert
|
|
481
|
-
assert top_triplets[0] == edge
|
|
779
|
+
assert result == []
|
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
import pytest
|
|
2
|
+
from types import SimpleNamespace
|
|
2
3
|
from unittest.mock import AsyncMock, patch, MagicMock
|
|
3
4
|
|
|
4
5
|
from cognee.modules.retrieval.chunks_retriever import ChunksRetriever
|
|
@@ -30,12 +31,14 @@ async def test_get_context_success(mock_vector_engine):
|
|
|
30
31
|
"cognee.modules.retrieval.chunks_retriever.get_vector_engine",
|
|
31
32
|
return_value=mock_vector_engine,
|
|
32
33
|
):
|
|
33
|
-
|
|
34
|
+
objects = await retriever.get_retrieved_objects("test query")
|
|
34
35
|
|
|
35
|
-
assert len(
|
|
36
|
-
assert
|
|
37
|
-
assert
|
|
38
|
-
mock_vector_engine.search.assert_awaited_once_with(
|
|
36
|
+
assert len(objects) == 2
|
|
37
|
+
assert objects[0].payload["text"] == "Steve Rodger"
|
|
38
|
+
assert objects[1].payload["text"] == "Mike Broski"
|
|
39
|
+
mock_vector_engine.search.assert_awaited_once_with(
|
|
40
|
+
"DocumentChunk_text", "test query", limit=5, include_payload=True
|
|
41
|
+
)
|
|
39
42
|
|
|
40
43
|
|
|
41
44
|
@pytest.mark.asyncio
|
|
@@ -50,7 +53,7 @@ async def test_get_context_collection_not_found_error(mock_vector_engine):
|
|
|
50
53
|
return_value=mock_vector_engine,
|
|
51
54
|
):
|
|
52
55
|
with pytest.raises(NoDataError, match="No data found"):
|
|
53
|
-
await retriever.
|
|
56
|
+
await retriever.get_retrieved_objects("test query")
|
|
54
57
|
|
|
55
58
|
|
|
56
59
|
@pytest.mark.asyncio
|
|
@@ -64,9 +67,9 @@ async def test_get_context_empty_results(mock_vector_engine):
|
|
|
64
67
|
"cognee.modules.retrieval.chunks_retriever.get_vector_engine",
|
|
65
68
|
return_value=mock_vector_engine,
|
|
66
69
|
):
|
|
67
|
-
|
|
70
|
+
objects = await retriever.get_retrieved_objects("test query")
|
|
68
71
|
|
|
69
|
-
assert
|
|
72
|
+
assert objects == []
|
|
70
73
|
|
|
71
74
|
|
|
72
75
|
@pytest.mark.asyncio
|
|
@@ -84,40 +87,29 @@ async def test_get_context_top_k_limit(mock_vector_engine):
|
|
|
84
87
|
"cognee.modules.retrieval.chunks_retriever.get_vector_engine",
|
|
85
88
|
return_value=mock_vector_engine,
|
|
86
89
|
):
|
|
87
|
-
|
|
90
|
+
objects = await retriever.get_retrieved_objects("test query")
|
|
88
91
|
|
|
89
|
-
assert len(
|
|
90
|
-
mock_vector_engine.search.assert_awaited_once_with(
|
|
92
|
+
assert len(objects) == 3
|
|
93
|
+
mock_vector_engine.search.assert_awaited_once_with(
|
|
94
|
+
"DocumentChunk_text", "test query", limit=3, include_payload=True
|
|
95
|
+
)
|
|
91
96
|
|
|
92
97
|
|
|
93
98
|
@pytest.mark.asyncio
|
|
94
|
-
async def
|
|
99
|
+
async def test_get_context(mock_vector_engine):
|
|
95
100
|
"""Test get_completion returns provided context."""
|
|
96
101
|
retriever = ChunksRetriever()
|
|
97
102
|
|
|
98
|
-
|
|
99
|
-
|
|
103
|
+
retrieved_objects = [
|
|
104
|
+
{"payload": {"text": "Steve Rodger"}},
|
|
105
|
+
{"payload": {"text": "Mike Broski"}},
|
|
106
|
+
]
|
|
107
|
+
# Wrap the outer dictionary so payload is an attribute
|
|
108
|
+
mock_objects = [SimpleNamespace(**obj) for obj in retrieved_objects]
|
|
100
109
|
|
|
101
|
-
|
|
110
|
+
context = await retriever.get_context_from_objects("test query", retrieved_objects=mock_objects)
|
|
102
111
|
|
|
103
|
-
|
|
104
|
-
@pytest.mark.asyncio
|
|
105
|
-
async def test_get_completion_without_context(mock_vector_engine):
|
|
106
|
-
"""Test get_completion retrieves context when not provided."""
|
|
107
|
-
mock_result = MagicMock()
|
|
108
|
-
mock_result.payload = {"text": "Steve Rodger"}
|
|
109
|
-
mock_vector_engine.search.return_value = [mock_result]
|
|
110
|
-
|
|
111
|
-
retriever = ChunksRetriever()
|
|
112
|
-
|
|
113
|
-
with patch(
|
|
114
|
-
"cognee.modules.retrieval.chunks_retriever.get_vector_engine",
|
|
115
|
-
return_value=mock_vector_engine,
|
|
116
|
-
):
|
|
117
|
-
completion = await retriever.get_completion("test query")
|
|
118
|
-
|
|
119
|
-
assert len(completion) == 1
|
|
120
|
-
assert completion[0]["text"] == "Steve Rodger"
|
|
112
|
+
assert context == "Steve Rodger\nMike Broski"
|
|
121
113
|
|
|
122
114
|
|
|
123
115
|
@pytest.mark.asyncio
|
|
@@ -147,29 +139,7 @@ async def test_init_none_top_k():
|
|
|
147
139
|
@pytest.mark.asyncio
|
|
148
140
|
async def test_get_context_empty_payload(mock_vector_engine):
|
|
149
141
|
"""Test get_context handles empty payload."""
|
|
150
|
-
|
|
151
|
-
mock_result.payload = {}
|
|
152
|
-
|
|
153
|
-
mock_vector_engine.search.return_value = [mock_result]
|
|
154
|
-
|
|
155
|
-
retriever = ChunksRetriever()
|
|
156
|
-
|
|
157
|
-
with patch(
|
|
158
|
-
"cognee.modules.retrieval.chunks_retriever.get_vector_engine",
|
|
159
|
-
return_value=mock_vector_engine,
|
|
160
|
-
):
|
|
161
|
-
context = await retriever.get_context("test query")
|
|
162
|
-
|
|
163
|
-
assert len(context) == 1
|
|
164
|
-
assert context[0] == {}
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
@pytest.mark.asyncio
|
|
168
|
-
async def test_get_completion_with_session_id(mock_vector_engine):
|
|
169
|
-
"""Test get_completion with session_id parameter."""
|
|
170
|
-
mock_result = MagicMock()
|
|
171
|
-
mock_result.payload = {"text": "Steve Rodger"}
|
|
172
|
-
mock_vector_engine.search.return_value = [mock_result]
|
|
142
|
+
mock_vector_engine.search.return_value = []
|
|
173
143
|
|
|
174
144
|
retriever = ChunksRetriever()
|
|
175
145
|
|
|
@@ -177,7 +147,9 @@ async def test_get_completion_with_session_id(mock_vector_engine):
|
|
|
177
147
|
"cognee.modules.retrieval.chunks_retriever.get_vector_engine",
|
|
178
148
|
return_value=mock_vector_engine,
|
|
179
149
|
):
|
|
180
|
-
|
|
150
|
+
retrieved_objects = await retriever.get_retrieved_objects("test query")
|
|
151
|
+
context = await retriever.get_context_from_objects(
|
|
152
|
+
"test query", retrieved_objects=retrieved_objects
|
|
153
|
+
)
|
|
181
154
|
|
|
182
|
-
assert
|
|
183
|
-
assert completion[0]["text"] == "Steve Rodger"
|
|
155
|
+
assert context == ""
|