cognee 0.5.1__py3-none-any.whl → 0.5.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cognee/__init__.py +2 -0
- cognee/alembic/README +1 -0
- cognee/alembic/env.py +107 -0
- cognee/alembic/script.py.mako +26 -0
- cognee/alembic/versions/1a58b986e6e1_enable_delete_for_old_tutorial_notebooks.py +52 -0
- cognee/alembic/versions/1d0bb7fede17_add_pipeline_run_status.py +33 -0
- cognee/alembic/versions/1daae0df1866_incremental_loading.py +48 -0
- cognee/alembic/versions/211ab850ef3d_add_sync_operations_table.py +118 -0
- cognee/alembic/versions/45957f0a9849_add_notebook_table.py +46 -0
- cognee/alembic/versions/46a6ce2bd2b2_expand_dataset_database_with_json_.py +333 -0
- cognee/alembic/versions/482cd6517ce4_add_default_user.py +30 -0
- cognee/alembic/versions/76625596c5c3_expand_dataset_database_for_multi_user.py +98 -0
- cognee/alembic/versions/8057ae7329c2_initial_migration.py +25 -0
- cognee/alembic/versions/9e7a3cb85175_loader_separation.py +104 -0
- cognee/alembic/versions/a1b2c3d4e5f6_add_label_column_to_data.py +38 -0
- cognee/alembic/versions/ab7e313804ae_permission_system_rework.py +236 -0
- cognee/alembic/versions/b9274c27a25a_kuzu_11_migration.py +75 -0
- cognee/alembic/versions/c946955da633_multi_tenant_support.py +137 -0
- cognee/alembic/versions/e1ec1dcb50b6_add_last_accessed_to_data.py +51 -0
- cognee/alembic/versions/e4ebee1091e7_expand_data_model_info.py +140 -0
- cognee/alembic.ini +117 -0
- cognee/api/v1/add/add.py +2 -1
- cognee/api/v1/add/routers/get_add_router.py +2 -0
- cognee/api/v1/cognify/cognify.py +11 -6
- cognee/api/v1/cognify/routers/get_cognify_router.py +8 -0
- cognee/api/v1/config/config.py +60 -0
- cognee/api/v1/datasets/routers/get_datasets_router.py +46 -3
- cognee/api/v1/memify/routers/get_memify_router.py +3 -0
- cognee/api/v1/search/routers/get_search_router.py +21 -6
- cognee/api/v1/search/search.py +21 -5
- cognee/api/v1/sync/routers/get_sync_router.py +3 -3
- cognee/cli/commands/add_command.py +1 -1
- cognee/cli/commands/cognify_command.py +6 -0
- cognee/cli/commands/config_command.py +1 -1
- cognee/context_global_variables.py +5 -1
- cognee/eval_framework/answer_generation/answer_generation_executor.py +7 -8
- cognee/infrastructure/databases/cache/cache_db_interface.py +38 -1
- cognee/infrastructure/databases/cache/config.py +6 -0
- cognee/infrastructure/databases/cache/fscache/FsCacheAdapter.py +21 -0
- cognee/infrastructure/databases/cache/get_cache_engine.py +9 -3
- cognee/infrastructure/databases/cache/redis/RedisAdapter.py +60 -1
- cognee/infrastructure/databases/dataset_database_handler/supported_dataset_database_handlers.py +7 -0
- cognee/infrastructure/databases/graph/get_graph_engine.py +29 -1
- cognee/infrastructure/databases/graph/neo4j_driver/Neo4jAuraDevDatasetDatabaseHandler.py +62 -27
- cognee/infrastructure/databases/hybrid/neptune_analytics/NeptuneAnalyticsAdapter.py +17 -4
- cognee/infrastructure/databases/relational/config.py +16 -1
- cognee/infrastructure/databases/relational/create_relational_engine.py +13 -3
- cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py +26 -3
- cognee/infrastructure/databases/vector/chromadb/ChromaDBAdapter.py +2 -0
- cognee/infrastructure/databases/vector/config.py +6 -0
- cognee/infrastructure/databases/vector/create_vector_engine.py +70 -16
- cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py +64 -9
- cognee/infrastructure/databases/vector/embeddings/OllamaEmbeddingEngine.py +13 -2
- cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py +16 -3
- cognee/infrastructure/databases/vector/models/ScoredResult.py +3 -3
- cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py +16 -3
- cognee/infrastructure/databases/vector/pgvector/PGVectorDatasetDatabaseHandler.py +86 -0
- cognee/infrastructure/databases/vector/pgvector/create_db_and_tables.py +81 -2
- cognee/infrastructure/databases/vector/vector_db_interface.py +8 -0
- cognee/infrastructure/files/utils/get_data_file_path.py +33 -27
- cognee/infrastructure/llm/LLMGateway.py +0 -13
- cognee/infrastructure/llm/prompts/extract_query_time.txt +1 -1
- cognee/infrastructure/llm/prompts/generate_event_entity_prompt.txt +1 -1
- cognee/infrastructure/llm/prompts/generate_event_graph_prompt.txt +1 -1
- cognee/infrastructure/llm/prompts/generate_graph_prompt.txt +2 -2
- cognee/infrastructure/llm/prompts/generate_graph_prompt_guided.txt +1 -1
- cognee/infrastructure/llm/prompts/generate_graph_prompt_oneshot.txt +2 -2
- cognee/infrastructure/llm/prompts/generate_graph_prompt_simple.txt +1 -1
- cognee/infrastructure/llm/prompts/generate_graph_prompt_strict.txt +1 -1
- cognee/infrastructure/llm/prompts/search_type_selector_prompt.txt +6 -6
- cognee/infrastructure/llm/prompts/test.txt +1 -1
- cognee/infrastructure/llm/prompts/translate_content.txt +19 -0
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py +17 -12
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py +31 -25
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py +132 -7
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py +29 -5
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/llama_cpp/adapter.py +191 -0
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/llm_interface.py +2 -6
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py +58 -13
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/adapter.py +0 -1
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py +25 -131
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/types.py +10 -0
- cognee/modules/chunking/models/DocumentChunk.py +0 -1
- cognee/modules/cognify/config.py +2 -0
- cognee/modules/data/models/Data.py +3 -1
- cognee/modules/engine/models/Entity.py +0 -1
- cognee/modules/engine/operations/setup.py +6 -0
- cognee/modules/graph/cognee_graph/CogneeGraph.py +150 -37
- cognee/modules/graph/cognee_graph/CogneeGraphElements.py +48 -2
- cognee/modules/graph/utils/__init__.py +1 -0
- cognee/modules/graph/utils/get_entity_nodes_from_triplets.py +12 -0
- cognee/modules/notebooks/methods/__init__.py +1 -0
- cognee/modules/notebooks/methods/create_notebook.py +0 -34
- cognee/modules/notebooks/methods/create_tutorial_notebooks.py +191 -0
- cognee/modules/notebooks/methods/get_notebooks.py +12 -8
- cognee/modules/notebooks/tutorials/cognee-basics/cell-1.md +3 -0
- cognee/modules/notebooks/tutorials/cognee-basics/cell-2.md +10 -0
- cognee/modules/notebooks/tutorials/cognee-basics/cell-3.md +7 -0
- cognee/modules/notebooks/tutorials/cognee-basics/cell-4.py +28 -0
- cognee/modules/notebooks/tutorials/cognee-basics/cell-5.py +3 -0
- cognee/modules/notebooks/tutorials/cognee-basics/cell-6.py +9 -0
- cognee/modules/notebooks/tutorials/cognee-basics/cell-7.py +17 -0
- cognee/modules/notebooks/tutorials/cognee-basics/config.json +4 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-1.md +3 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-10.md +3 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-11.md +3 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-12.py +3 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-13.md +7 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-14.py +6 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-15.md +3 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-16.py +7 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-2.md +9 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-3.md +7 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-4.md +9 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-5.md +5 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-6.py +13 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-7.md +3 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-8.md +3 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-9.py +31 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/config.json +4 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/data/copilot_conversations.json +107 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/data/guido_contributions.json +976 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/data/my_developer_rules.md +79 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/data/pep_style_guide.md +74 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/data/zen_principles.md +74 -0
- cognee/modules/retrieval/EntityCompletionRetriever.py +51 -38
- cognee/modules/retrieval/__init__.py +0 -1
- cognee/modules/retrieval/base_retriever.py +66 -10
- cognee/modules/retrieval/chunks_retriever.py +57 -49
- cognee/modules/retrieval/coding_rules_retriever.py +12 -5
- cognee/modules/retrieval/completion_retriever.py +29 -28
- cognee/modules/retrieval/cypher_search_retriever.py +25 -20
- cognee/modules/retrieval/graph_completion_context_extension_retriever.py +42 -46
- cognee/modules/retrieval/graph_completion_cot_retriever.py +68 -51
- cognee/modules/retrieval/graph_completion_retriever.py +78 -63
- cognee/modules/retrieval/graph_summary_completion_retriever.py +2 -0
- cognee/modules/retrieval/lexical_retriever.py +34 -12
- cognee/modules/retrieval/natural_language_retriever.py +18 -15
- cognee/modules/retrieval/summaries_retriever.py +51 -34
- cognee/modules/retrieval/temporal_retriever.py +59 -49
- cognee/modules/retrieval/triplet_retriever.py +32 -33
- cognee/modules/retrieval/utils/access_tracking.py +88 -0
- cognee/modules/retrieval/utils/brute_force_triplet_search.py +99 -103
- cognee/modules/retrieval/utils/node_edge_vector_search.py +174 -0
- cognee/modules/search/methods/__init__.py +1 -0
- cognee/modules/search/methods/get_retriever_output.py +53 -0
- cognee/modules/search/methods/get_search_type_retriever_instance.py +252 -0
- cognee/modules/search/methods/search.py +90 -222
- cognee/modules/search/models/SearchResultPayload.py +67 -0
- cognee/modules/search/types/SearchResult.py +1 -8
- cognee/modules/search/types/SearchType.py +1 -2
- cognee/modules/search/types/__init__.py +1 -1
- cognee/modules/search/utils/__init__.py +1 -2
- cognee/modules/search/utils/transform_insights_to_graph.py +2 -2
- cognee/modules/search/utils/{transform_context_to_graph.py → transform_triplets_to_graph.py} +2 -2
- cognee/modules/users/authentication/default/default_transport.py +11 -1
- cognee/modules/users/authentication/get_api_auth_backend.py +2 -1
- cognee/modules/users/authentication/get_client_auth_backend.py +2 -1
- cognee/modules/users/methods/create_user.py +0 -9
- cognee/modules/users/permissions/methods/has_user_management_permission.py +29 -0
- cognee/modules/visualization/cognee_network_visualization.py +1 -1
- cognee/run_migrations.py +48 -0
- cognee/shared/exceptions/__init__.py +1 -3
- cognee/shared/exceptions/exceptions.py +11 -1
- cognee/shared/usage_logger.py +332 -0
- cognee/shared/utils.py +12 -5
- cognee/tasks/chunks/__init__.py +9 -0
- cognee/tasks/cleanup/cleanup_unused_data.py +172 -0
- cognee/tasks/graph/__init__.py +7 -0
- cognee/tasks/ingestion/data_item.py +8 -0
- cognee/tasks/ingestion/ingest_data.py +12 -1
- cognee/tasks/ingestion/save_data_item_to_storage.py +5 -0
- cognee/tasks/memify/__init__.py +8 -0
- cognee/tasks/memify/extract_usage_frequency.py +613 -0
- cognee/tasks/summarization/models.py +0 -2
- cognee/tasks/temporal_graph/__init__.py +0 -1
- cognee/tasks/translation/__init__.py +96 -0
- cognee/tasks/translation/config.py +110 -0
- cognee/tasks/translation/detect_language.py +190 -0
- cognee/tasks/translation/exceptions.py +62 -0
- cognee/tasks/translation/models.py +72 -0
- cognee/tasks/translation/providers/__init__.py +44 -0
- cognee/tasks/translation/providers/azure_provider.py +192 -0
- cognee/tasks/translation/providers/base.py +85 -0
- cognee/tasks/translation/providers/google_provider.py +158 -0
- cognee/tasks/translation/providers/llm_provider.py +143 -0
- cognee/tasks/translation/translate_content.py +282 -0
- cognee/tasks/web_scraper/default_url_crawler.py +6 -2
- cognee/tests/cli_tests/cli_unit_tests/test_cli_commands.py +1 -0
- cognee/tests/cli_tests/cli_unit_tests/test_cli_edge_cases.py +3 -0
- cognee/tests/integration/retrieval/test_brute_force_triplet_search_with_cognify.py +62 -0
- cognee/tests/integration/retrieval/test_chunks_retriever.py +351 -0
- cognee/tests/integration/retrieval/test_graph_completion_retriever.py +276 -0
- cognee/tests/integration/retrieval/test_graph_completion_retriever_context_extension.py +228 -0
- cognee/tests/integration/retrieval/test_graph_completion_retriever_cot.py +217 -0
- cognee/tests/integration/retrieval/test_rag_completion_retriever.py +319 -0
- cognee/tests/integration/retrieval/test_structured_output.py +258 -0
- cognee/tests/integration/retrieval/test_summaries_retriever.py +195 -0
- cognee/tests/integration/retrieval/test_temporal_retriever.py +336 -0
- cognee/tests/integration/retrieval/test_triplet_retriever.py +45 -1
- cognee/tests/integration/shared/test_usage_logger_integration.py +255 -0
- cognee/tests/tasks/translation/README.md +147 -0
- cognee/tests/tasks/translation/__init__.py +1 -0
- cognee/tests/tasks/translation/config_test.py +93 -0
- cognee/tests/tasks/translation/detect_language_test.py +118 -0
- cognee/tests/tasks/translation/providers_test.py +151 -0
- cognee/tests/tasks/translation/translate_content_test.py +213 -0
- cognee/tests/test_chromadb.py +1 -1
- cognee/tests/test_cleanup_unused_data.py +165 -0
- cognee/tests/test_custom_data_label.py +68 -0
- cognee/tests/test_delete_by_id.py +6 -6
- cognee/tests/test_extract_usage_frequency.py +308 -0
- cognee/tests/test_kuzu.py +17 -7
- cognee/tests/test_lancedb.py +3 -1
- cognee/tests/test_library.py +1 -1
- cognee/tests/test_neo4j.py +17 -7
- cognee/tests/test_neptune_analytics_vector.py +3 -1
- cognee/tests/test_permissions.py +172 -187
- cognee/tests/test_pgvector.py +3 -1
- cognee/tests/test_relational_db_migration.py +15 -1
- cognee/tests/test_remote_kuzu.py +3 -1
- cognee/tests/test_s3_file_storage.py +1 -1
- cognee/tests/test_search_db.py +345 -205
- cognee/tests/test_usage_logger_e2e.py +268 -0
- cognee/tests/unit/api/test_get_raw_data_endpoint.py +206 -0
- cognee/tests/unit/eval_framework/answer_generation_test.py +4 -3
- cognee/tests/unit/eval_framework/benchmark_adapters_test.py +25 -0
- cognee/tests/unit/eval_framework/corpus_builder_test.py +33 -4
- cognee/tests/unit/infrastructure/databases/cache/test_cache_config.py +2 -0
- cognee/tests/unit/infrastructure/databases/relational/test_RelationalConfig.py +69 -0
- cognee/tests/unit/modules/graph/cognee_graph_elements_test.py +42 -2
- cognee/tests/unit/modules/graph/cognee_graph_test.py +329 -31
- cognee/tests/unit/modules/retrieval/chunks_retriever_test.py +122 -168
- cognee/tests/unit/modules/retrieval/conversation_history_test.py +338 -0
- cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py +486 -157
- cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py +693 -155
- cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py +619 -200
- cognee/tests/unit/modules/retrieval/rag_completion_retriever_test.py +300 -171
- cognee/tests/unit/modules/retrieval/summaries_retriever_test.py +184 -155
- cognee/tests/unit/modules/retrieval/temporal_retriever_test.py +544 -79
- cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py +476 -28
- cognee/tests/unit/modules/retrieval/test_completion.py +343 -0
- cognee/tests/unit/modules/retrieval/test_graph_summary_completion_retriever.py +157 -0
- cognee/tests/unit/modules/retrieval/test_node_edge_vector_search.py +273 -0
- cognee/tests/unit/modules/retrieval/test_user_qa_feedback.py +312 -0
- cognee/tests/unit/modules/retrieval/triplet_retriever_test.py +267 -7
- cognee/tests/unit/modules/search/test_get_search_type_retriever_instance.py +125 -0
- cognee/tests/unit/modules/search/test_search.py +96 -20
- cognee/tests/unit/modules/search/test_search_prepare_search_result_contract.py +190 -0
- cognee/tests/unit/modules/users/test_tutorial_notebook_creation.py +511 -297
- cognee/tests/unit/shared/test_usage_logger.py +241 -0
- cognee/tests/unit/users/permissions/test_has_user_management_permission.py +46 -0
- {cognee-0.5.1.dist-info → cognee-0.5.2.dist-info}/METADATA +22 -17
- {cognee-0.5.1.dist-info → cognee-0.5.2.dist-info}/RECORD +258 -157
- cognee/api/.env.example +0 -5
- cognee/modules/retrieval/base_graph_retriever.py +0 -24
- cognee/modules/search/methods/get_search_type_tools.py +0 -223
- cognee/modules/search/methods/no_access_control_search.py +0 -62
- cognee/modules/search/utils/prepare_search_result.py +0 -63
- cognee/tests/test_feedback_enrichment.py +0 -174
- cognee/tests/unit/modules/retrieval/structured_output_test.py +0 -204
- {cognee-0.5.1.dist-info → cognee-0.5.2.dist-info}/WHEEL +0 -0
- {cognee-0.5.1.dist-info → cognee-0.5.2.dist-info}/entry_points.txt +0 -0
- {cognee-0.5.1.dist-info → cognee-0.5.2.dist-info}/licenses/LICENSE +0 -0
- {cognee-0.5.1.dist-info → cognee-0.5.2.dist-info}/licenses/NOTICE.md +0 -0
|
@@ -0,0 +1,273 @@
|
|
|
1
|
+
import pytest
|
|
2
|
+
from unittest.mock import AsyncMock
|
|
3
|
+
|
|
4
|
+
from cognee.modules.retrieval.utils.node_edge_vector_search import NodeEdgeVectorSearch
|
|
5
|
+
from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class MockScoredResult:
    """Lightweight stand-in for a vector-engine scored search result.

    Carries the ``id``, ``score`` and ``payload`` attributes; the tests in
    this module build lists of these and feed them through mocked
    ``search`` / ``batch_search`` calls.
    """

    def __init__(self, id, score, payload=None):
        # Any falsy payload (None, {}, ...) is normalised to a fresh empty dict.
        self.payload = payload if payload else {}
        self.id = id
        self.score = score
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
@pytest.mark.asyncio
async def test_node_edge_vector_search_single_query_shape():
    """Single-query mode should store flat result lists rather than per-query nested lists."""
    engine = AsyncMock()
    engine.embedding_engine = AsyncMock()
    engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])

    expected_nodes = [MockScoredResult("node1", 0.95), MockScoredResult("node2", 0.87)]
    expected_edges = [MockScoredResult("edge1", 0.92)]
    hits_by_collection = {"EdgeType_relationship_name": expected_edges}

    def fake_search(*_args, **kwargs):
        # The edge collection yields edge hits; every other collection yields node hits.
        return hits_by_collection.get(kwargs.get("collection_name"), expected_nodes)

    engine.search = AsyncMock(side_effect=fake_search)

    searcher = NodeEdgeVectorSearch(vector_engine=engine)
    await searcher.embed_and_retrieve_distances(
        query="test query",
        query_batch=None,
        collections=["Entity_name", "EdgeType_relationship_name"],
        wide_search_limit=10,
    )

    assert searcher.query_list_length is None
    assert searcher.edge_distances == expected_edges
    assert searcher.node_distances["Entity_name"] == expected_nodes
    # The single query is embedded once, wrapped in a one-element list.
    engine.embedding_engine.embed_text.assert_called_once_with(["test query"])
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
@pytest.mark.asyncio
async def test_node_edge_vector_search_batch_query_shape_and_empties():
    """Batch mode should keep one result list per query and pad missing/empty collections."""
    engine = AsyncMock()
    engine.embedding_engine = AsyncMock()

    queries = ["query a", "query b"]
    nodes_a = [MockScoredResult("node1", 0.95)]
    nodes_b = [MockScoredResult("node2", 0.87)]
    edges_a = [MockScoredResult("edge1", 0.92)]
    edges_b = []

    # Factories so every call hands back a freshly built outer list.
    builders = {
        "EdgeType_relationship_name": lambda: [edges_a, edges_b],
        "Entity_name": lambda: [nodes_a, nodes_b],
    }

    def fake_batch_search(*_args, **kwargs):
        name = kwargs.get("collection_name")
        if name == "MissingCollection":
            raise CollectionNotFoundError("Collection not found")
        # Any other collection behaves as existing but empty: one empty list per query.
        return builders.get(name, lambda: [[], []])()

    engine.batch_search = AsyncMock(side_effect=fake_batch_search)

    searcher = NodeEdgeVectorSearch(vector_engine=engine)
    await searcher.embed_and_retrieve_distances(
        query=None,
        query_batch=queries,
        collections=[
            "Entity_name",
            "EdgeType_relationship_name",
            "MissingCollection",
            "EmptyCollection",
        ],
        wide_search_limit=None,
    )

    assert searcher.query_list_length == 2

    edges = searcher.edge_distances
    assert len(edges) == 2
    assert edges[0] == edges_a
    assert edges[1] == edges_b

    nodes = searcher.node_distances
    assert len(nodes["Entity_name"]) == 2
    assert nodes["Entity_name"][0] == nodes_a
    assert nodes["Entity_name"][1] == nodes_b
    for padded in ("MissingCollection", "EmptyCollection"):
        assert len(nodes[padded]) == 2
        assert nodes[padded] == [[], []]

    # Batch mode must not go through the single-text embedding path.
    engine.embedding_engine.embed_text.assert_not_called()
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
@pytest.mark.asyncio
async def test_node_edge_vector_search_input_validation_both_provided():
    """Passing ``query`` together with ``query_batch`` must be rejected with ValueError."""
    searcher = NodeEdgeVectorSearch(vector_engine=AsyncMock())
    expected_message = "Cannot provide both 'query' and 'query_batch'"

    with pytest.raises(ValueError, match=expected_message):
        await searcher.embed_and_retrieve_distances(
            query="test",
            query_batch=["test1", "test2"],
            collections=["Entity_name"],
        )
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
@pytest.mark.asyncio
async def test_node_edge_vector_search_input_validation_neither_provided():
    """Omitting both ``query`` and ``query_batch`` must be rejected with ValueError."""
    searcher = NodeEdgeVectorSearch(vector_engine=AsyncMock())
    expected_message = "Must provide either 'query' or 'query_batch'"

    with pytest.raises(ValueError, match=expected_message):
        await searcher.embed_and_retrieve_distances(
            query=None,
            query_batch=None,
            collections=["Entity_name"],
        )
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
def test_node_edge_vector_search_extract_relevant_node_ids_single_query():
    """Test that extract_relevant_node_ids returns IDs for single query mode.

    Nothing in this test is awaited, so it is a plain synchronous test and
    does not depend on pytest-asyncio or an event loop.
    """
    mock_vector_engine = AsyncMock()
    vector_search = NodeEdgeVectorSearch(vector_engine=mock_vector_engine)
    # None marks single-query mode: each collection maps to a flat result list.
    vector_search.query_list_length = None
    vector_search.node_distances = {
        "Entity_name": [MockScoredResult("node1", 0.95), MockScoredResult("node2", 0.87)],
        "TextSummary_text": [MockScoredResult("node1", 0.90), MockScoredResult("node3", 0.92)],
    }

    node_ids = vector_search.extract_relevant_node_ids()
    # node1 appears in both collections; compare as a set so this test pins
    # neither the ordering nor the de-duplication strategy.
    assert set(node_ids) == {"node1", "node2", "node3"}
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
def test_node_edge_vector_search_extract_relevant_node_ids_batch():
    """Test that extract_relevant_node_ids returns empty list for batch mode.

    Nothing in this test is awaited, so it is a plain synchronous test and
    does not depend on pytest-asyncio or an event loop.
    """
    mock_vector_engine = AsyncMock()
    vector_search = NodeEdgeVectorSearch(vector_engine=mock_vector_engine)
    # Batch mode with two queries: each collection holds one result list per query.
    vector_search.query_list_length = 2
    vector_search.node_distances = {
        "Entity_name": [
            [MockScoredResult("node1", 0.95)],
            [MockScoredResult("node2", 0.87)],
        ],
    }

    node_ids = vector_search.extract_relevant_node_ids()
    assert node_ids == []
|
|
154
|
+
|
|
155
|
+
|
|
156
|
+
def test_node_edge_vector_search_has_results_single_query():
    """Test has_results returns True when results exist and False when only empties.

    Covers three single-query states in turn: an edge hit only, a node hit
    only, and no hits at all. Nothing is awaited, so this is a plain
    synchronous test.
    """
    mock_vector_engine = AsyncMock()
    vector_search = NodeEdgeVectorSearch(vector_engine=mock_vector_engine)
    # Pin single-query mode explicitly (as the other single-query tests do)
    # instead of relying on whatever the constructor defaults to.
    vector_search.query_list_length = None

    # Only an edge hit.
    vector_search.edge_distances = [MockScoredResult("edge1", 0.92)]
    vector_search.node_distances = {}
    assert vector_search.has_results() is True

    # Only a node hit.
    vector_search.edge_distances = []
    vector_search.node_distances = {"Entity_name": [MockScoredResult("node1", 0.95)]}
    assert vector_search.has_results() is True

    # Nothing at all.
    vector_search.edge_distances = []
    vector_search.node_distances = {}
    assert vector_search.has_results() is False
|
|
173
|
+
|
|
174
|
+
|
|
175
|
+
def test_node_edge_vector_search_has_results_batch():
    """Test has_results works correctly for batch mode with list-of-lists.

    Each distance container holds one inner list per query (two queries
    here). Nothing is awaited, so this is a plain synchronous test.
    """
    mock_vector_engine = AsyncMock()
    vector_search = NodeEdgeVectorSearch(vector_engine=mock_vector_engine)
    vector_search.query_list_length = 2

    # An edge hit for the first query only.
    vector_search.edge_distances = [[MockScoredResult("edge1", 0.92)], []]
    vector_search.node_distances = {}
    assert vector_search.has_results() is True

    # A node hit for the first query only.
    vector_search.edge_distances = [[], []]
    vector_search.node_distances = {
        "Entity_name": [[MockScoredResult("node1", 0.95)], []],
    }
    assert vector_search.has_results() is True

    # Every per-query list is empty.
    vector_search.edge_distances = [[], []]
    vector_search.node_distances = {"Entity_name": [[], []]}
    assert vector_search.has_results() is False
|
|
195
|
+
|
|
196
|
+
|
|
197
|
+
@pytest.mark.asyncio
async def test_node_edge_vector_search_single_query_collection_not_found():
    """A collection whose search raises CollectionNotFoundError should map to an empty list."""
    engine = AsyncMock()
    engine.embedding_engine = AsyncMock()
    engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])

    missing = CollectionNotFoundError("Collection not found")
    engine.search = AsyncMock(side_effect=missing)

    searcher = NodeEdgeVectorSearch(vector_engine=engine)
    await searcher.embed_and_retrieve_distances(
        query="test query",
        query_batch=None,
        collections=["MissingCollection"],
        wide_search_limit=10,
    )

    assert searcher.node_distances["MissingCollection"] == []
|
|
215
|
+
|
|
216
|
+
|
|
217
|
+
@pytest.mark.asyncio
async def test_node_edge_vector_search_missing_collections_single_query():
    """Single-query mode tolerates missing and empty collections, storing empty lists for them."""
    engine = AsyncMock()
    engine.embedding_engine = AsyncMock()
    engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])

    hit = MockScoredResult("node1", 0.95)

    def fake_search(*_args, **kwargs):
        name = kwargs.get("collection_name")
        if name == "MissingCollection":
            raise CollectionNotFoundError("Collection not found")
        # Only Entity_name has data; any other collection exists but is empty.
        return [hit] if name == "Entity_name" else []

    engine.search = AsyncMock(side_effect=fake_search)

    searcher = NodeEdgeVectorSearch(vector_engine=engine)
    await searcher.embed_and_retrieve_distances(
        query="test query",
        query_batch=None,
        collections=["Entity_name", "MissingCollection", "EmptyCollection"],
        wide_search_limit=10,
    )

    entity_hits = searcher.node_distances["Entity_name"]
    assert len(entity_hits) == 1
    assert entity_hits[0].id == "node1"
    assert entity_hits[0].score == 0.95
    assert searcher.node_distances["MissingCollection"] == []
    assert searcher.node_distances["EmptyCollection"] == []
|
|
248
|
+
|
|
249
|
+
|
|
250
|
+
def test_node_edge_vector_search_has_results_batch_nodes_only():
    """Test has_results returns True when only node distances are populated in batch mode.

    Nothing is awaited, so this is a plain synchronous test.
    """
    mock_vector_engine = AsyncMock()
    vector_search = NodeEdgeVectorSearch(vector_engine=mock_vector_engine)
    vector_search.query_list_length = 2
    # No edge hits for either query; one node hit for the first query.
    vector_search.edge_distances = [[], []]
    vector_search.node_distances = {
        "Entity_name": [[MockScoredResult("node1", 0.95)], []],
    }

    assert vector_search.has_results() is True
|
|
262
|
+
|
|
263
|
+
|
|
264
|
+
def test_node_edge_vector_search_has_results_batch_edges_only():
    """Test has_results returns True when only edge distances are populated in batch mode.

    Nothing is awaited, so this is a plain synchronous test.
    """
    mock_vector_engine = AsyncMock()
    vector_search = NodeEdgeVectorSearch(vector_engine=mock_vector_engine)
    vector_search.query_list_length = 2
    # One edge hit for the first query; no node collections at all.
    vector_search.edge_distances = [[MockScoredResult("edge1", 0.92)], []]
    vector_search.node_distances = {}

    assert vector_search.has_results() is True
|
|
@@ -0,0 +1,312 @@
|
|
|
1
|
+
import pytest
|
|
2
|
+
from unittest.mock import AsyncMock, patch, MagicMock
|
|
3
|
+
from uuid import UUID, NAMESPACE_OID, uuid5
|
|
4
|
+
|
|
5
|
+
from cognee.modules.retrieval.user_qa_feedback import UserQAFeedback
|
|
6
|
+
from cognee.modules.retrieval.utils.models import UserFeedbackEvaluation, UserFeedbackSentiment
|
|
7
|
+
from cognee.modules.engine.models import NodeSet
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
@pytest.fixture
def mock_feedback_evaluation():
    """Provide a spec'd stand-in for a ``UserFeedbackEvaluation`` result.

    Defaults to a "positive" sentiment value with a score of 4.5; individual
    tests overwrite these attributes to model other outcomes.
    """
    sentiment = MagicMock()
    sentiment.value = "positive"

    fake_evaluation = MagicMock(spec=UserFeedbackEvaluation)
    fake_evaluation.evaluation = sentiment
    fake_evaluation.score = 4.5
    return fake_evaluation
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
@pytest.fixture
def mock_graph_engine():
    """Provide an async graph-engine double for the UserQAFeedback tests.

    By default it reports no recent user interactions; ``add_edges`` and
    ``apply_feedback_weight`` are standalone AsyncMocks so awaits can be asserted.
    """
    graph = AsyncMock()
    graph.get_last_user_interaction_ids = AsyncMock(return_value=[])
    for method_name in ("add_edges", "apply_feedback_weight"):
        setattr(graph, method_name, AsyncMock())
    return graph
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class TestUserQAFeedback:
    """Unit tests for ``UserQAFeedback``.

    The LLM gateway, graph-engine lookup, data-point persistence and edge
    indexing are patched in the ``user_qa_feedback`` module namespace by the
    ``patched_dependencies`` fixture.
    """

    @pytest.fixture
    def patched_dependencies(self, mock_feedback_evaluation, mock_graph_engine):
        """Patch every collaborator ``add_feedback`` uses and yield the mocks.

        The LLM call returns ``mock_feedback_evaluation`` and ``get_graph_engine``
        returns ``mock_graph_engine``. These are the same fixture instances the
        test receives, so a test may still adjust them before calling
        ``add_feedback``. ``index_graph_edges`` is patched for every test (not
        only those expecting relationships) so an unexpected call can never
        reach the real indexer.

        Yields:
            dict with the ``"llm"``, ``"add_data_points"`` and
            ``"index_graph_edges"`` mocks.
        """
        module = "cognee.modules.retrieval.user_qa_feedback"
        with (
            patch(
                f"{module}.LLMGateway.acreate_structured_output",
                new_callable=AsyncMock,
                return_value=mock_feedback_evaluation,
            ) as mock_llm,
            patch(f"{module}.get_graph_engine", return_value=mock_graph_engine),
            patch(f"{module}.add_data_points", new_callable=AsyncMock) as mock_add_data,
            patch(f"{module}.index_graph_edges", new_callable=AsyncMock) as mock_index_edges,
        ):
            yield {
                "llm": mock_llm,
                "add_data_points": mock_add_data,
                "index_graph_edges": mock_index_edges,
            }

    @pytest.mark.asyncio
    async def test_init_default(self):
        """Test UserQAFeedback initialization with default last_k."""
        retriever = UserQAFeedback()
        assert retriever.last_k == 1

    @pytest.mark.asyncio
    async def test_init_custom_last_k(self):
        """Test UserQAFeedback initialization with custom last_k."""
        retriever = UserQAFeedback(last_k=5)
        assert retriever.last_k == 5

    @pytest.mark.asyncio
    async def test_add_feedback_success_with_relationships(
        self, mock_graph_engine, patched_dependencies
    ):
        """Test add_feedback successfully creates feedback with relationships."""
        interaction_id_1 = str(UUID("550e8400-e29b-41d4-a716-446655440000"))
        interaction_id_2 = str(UUID("550e8400-e29b-41d4-a716-446655440001"))
        mock_graph_engine.get_last_user_interaction_ids = AsyncMock(
            return_value=[interaction_id_1, interaction_id_2]
        )
        feedback_text = "This answer was helpful"

        retriever = UserQAFeedback(last_k=2)
        result = await retriever.add_feedback(feedback_text)

        assert result == [feedback_text]
        patched_dependencies["add_data_points"].assert_awaited_once()
        mock_graph_engine.add_edges.assert_awaited_once()
        patched_dependencies["index_graph_edges"].assert_awaited_once()
        mock_graph_engine.apply_feedback_weight.assert_awaited_once()

        # Verify add_edges was called with correct relationships
        feedback_id = uuid5(NAMESPACE_OID, name=feedback_text)
        relationships = mock_graph_engine.add_edges.call_args[0][0]
        assert len(relationships) == 2
        assert relationships[0][0] == feedback_id
        assert relationships[0][1] == UUID(interaction_id_1)
        assert relationships[0][2] == "gives_feedback_to"
        assert relationships[0][3]["relationship_name"] == "gives_feedback_to"
        assert relationships[0][3]["ontology_valid"] is False
        # Check every edge, not just the first: each must start at the feedback
        # node and, together, the edges must target both interactions.
        assert all(rel[0] == feedback_id for rel in relationships)
        assert {rel[1] for rel in relationships} == {
            UUID(interaction_id_1),
            UUID(interaction_id_2),
        }

        # Verify apply_feedback_weight was called with correct node IDs
        weighted_ids = mock_graph_engine.apply_feedback_weight.call_args[1]["node_ids"]
        assert len(weighted_ids) == 2
        assert interaction_id_1 in weighted_ids
        assert interaction_id_2 in weighted_ids

    @pytest.mark.asyncio
    async def test_add_feedback_success_no_relationships(
        self, mock_graph_engine, patched_dependencies
    ):
        """Test add_feedback successfully creates feedback without relationships."""
        mock_graph_engine.get_last_user_interaction_ids = AsyncMock(return_value=[])
        feedback_text = "This answer was helpful"

        retriever = UserQAFeedback(last_k=1)
        result = await retriever.add_feedback(feedback_text)

        assert result == [feedback_text]
        patched_dependencies["add_data_points"].assert_awaited_once()
        # Should not call add_edges or index_graph_edges when no relationships
        mock_graph_engine.add_edges.assert_not_awaited()
        patched_dependencies["index_graph_edges"].assert_not_awaited()
        mock_graph_engine.apply_feedback_weight.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_add_feedback_creates_correct_feedback_node(
        self, mock_feedback_evaluation, mock_graph_engine, patched_dependencies
    ):
        """Test add_feedback creates CogneeUserFeedback with correct attributes."""
        mock_graph_engine.get_last_user_interaction_ids = AsyncMock(return_value=[])
        feedback_text = "This was a negative experience"
        mock_feedback_evaluation.evaluation.value = "negative"
        mock_feedback_evaluation.score = -3.0

        retriever = UserQAFeedback()
        await retriever.add_feedback(feedback_text)

        # Verify add_data_points was called with correct CogneeUserFeedback
        data_points = patched_dependencies["add_data_points"].call_args[1]["data_points"]
        assert len(data_points) == 1
        feedback_node = data_points[0]
        assert feedback_node.id == uuid5(NAMESPACE_OID, name=feedback_text)
        assert feedback_node.feedback == feedback_text
        assert feedback_node.sentiment == "negative"
        assert feedback_node.score == -3.0
        assert isinstance(feedback_node.belongs_to_set, NodeSet)
        assert feedback_node.belongs_to_set.name == "UserQAFeedbacks"

    @pytest.mark.asyncio
    async def test_add_feedback_calls_llm_with_correct_prompt(
        self, mock_graph_engine, patched_dependencies
    ):
        """Test add_feedback calls LLM with correct sentiment analysis prompt."""
        mock_graph_engine.get_last_user_interaction_ids = AsyncMock(return_value=[])
        feedback_text = "Great answer!"

        retriever = UserQAFeedback()
        await retriever.add_feedback(feedback_text)

        mock_llm = patched_dependencies["llm"]
        mock_llm.assert_awaited_once()
        call_kwargs = mock_llm.call_args[1]
        assert call_kwargs["text_input"] == feedback_text
        assert "sentiment analysis assistant" in call_kwargs["system_prompt"]
        assert call_kwargs["response_model"] == UserFeedbackEvaluation

    @pytest.mark.asyncio
    async def test_add_feedback_uses_last_k_parameter(
        self, mock_graph_engine, patched_dependencies
    ):
        """Test add_feedback uses last_k parameter when getting interaction IDs."""
        mock_graph_engine.get_last_user_interaction_ids = AsyncMock(return_value=[])
        feedback_text = "Test feedback"

        retriever = UserQAFeedback(last_k=5)
        await retriever.add_feedback(feedback_text)

        mock_graph_engine.get_last_user_interaction_ids.assert_awaited_once_with(limit=5)

    @pytest.mark.asyncio
    async def test_add_feedback_with_single_interaction(
        self, mock_graph_engine, patched_dependencies
    ):
        """Test add_feedback with single interaction ID."""
        interaction_id = str(UUID("550e8400-e29b-41d4-a716-446655440000"))
        mock_graph_engine.get_last_user_interaction_ids = AsyncMock(return_value=[interaction_id])
        feedback_text = "Test feedback"

        retriever = UserQAFeedback()
        result = await retriever.add_feedback(feedback_text)

        assert result == [feedback_text]
        # Should create relationship for the interaction
        relationships = mock_graph_engine.add_edges.call_args[0][0]
        assert len(relationships) == 1
        assert relationships[0][1] == UUID(interaction_id)

    @pytest.mark.asyncio
    async def test_add_feedback_applies_weight_correctly(
        self, mock_feedback_evaluation, mock_graph_engine, patched_dependencies
    ):
        """Test add_feedback applies feedback weight with correct score."""
        interaction_id = str(UUID("550e8400-e29b-41d4-a716-446655440000"))
        mock_graph_engine.get_last_user_interaction_ids = AsyncMock(return_value=[interaction_id])
        mock_feedback_evaluation.score = 4.5
        feedback_text = "Positive feedback"

        retriever = UserQAFeedback()
        await retriever.add_feedback(feedback_text)

        mock_graph_engine.apply_feedback_weight.assert_awaited_once_with(
            node_ids=[interaction_id], weight=4.5
        )
|