cognee 0.5.1__py3-none-any.whl → 0.5.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cognee/__init__.py +2 -0
- cognee/alembic/README +1 -0
- cognee/alembic/env.py +107 -0
- cognee/alembic/script.py.mako +26 -0
- cognee/alembic/versions/1a58b986e6e1_enable_delete_for_old_tutorial_notebooks.py +52 -0
- cognee/alembic/versions/1d0bb7fede17_add_pipeline_run_status.py +33 -0
- cognee/alembic/versions/1daae0df1866_incremental_loading.py +48 -0
- cognee/alembic/versions/211ab850ef3d_add_sync_operations_table.py +118 -0
- cognee/alembic/versions/45957f0a9849_add_notebook_table.py +46 -0
- cognee/alembic/versions/46a6ce2bd2b2_expand_dataset_database_with_json_.py +333 -0
- cognee/alembic/versions/482cd6517ce4_add_default_user.py +30 -0
- cognee/alembic/versions/76625596c5c3_expand_dataset_database_for_multi_user.py +98 -0
- cognee/alembic/versions/8057ae7329c2_initial_migration.py +25 -0
- cognee/alembic/versions/9e7a3cb85175_loader_separation.py +104 -0
- cognee/alembic/versions/a1b2c3d4e5f6_add_label_column_to_data.py +38 -0
- cognee/alembic/versions/ab7e313804ae_permission_system_rework.py +236 -0
- cognee/alembic/versions/b9274c27a25a_kuzu_11_migration.py +75 -0
- cognee/alembic/versions/c946955da633_multi_tenant_support.py +137 -0
- cognee/alembic/versions/e1ec1dcb50b6_add_last_accessed_to_data.py +51 -0
- cognee/alembic/versions/e4ebee1091e7_expand_data_model_info.py +140 -0
- cognee/alembic.ini +117 -0
- cognee/api/v1/add/add.py +2 -1
- cognee/api/v1/add/routers/get_add_router.py +2 -0
- cognee/api/v1/cognify/cognify.py +11 -6
- cognee/api/v1/cognify/routers/get_cognify_router.py +8 -0
- cognee/api/v1/config/config.py +60 -0
- cognee/api/v1/datasets/routers/get_datasets_router.py +46 -3
- cognee/api/v1/memify/routers/get_memify_router.py +3 -0
- cognee/api/v1/search/routers/get_search_router.py +21 -6
- cognee/api/v1/search/search.py +21 -5
- cognee/api/v1/sync/routers/get_sync_router.py +3 -3
- cognee/cli/commands/add_command.py +1 -1
- cognee/cli/commands/cognify_command.py +6 -0
- cognee/cli/commands/config_command.py +1 -1
- cognee/context_global_variables.py +5 -1
- cognee/eval_framework/answer_generation/answer_generation_executor.py +7 -8
- cognee/infrastructure/databases/cache/cache_db_interface.py +38 -1
- cognee/infrastructure/databases/cache/config.py +6 -0
- cognee/infrastructure/databases/cache/fscache/FsCacheAdapter.py +21 -0
- cognee/infrastructure/databases/cache/get_cache_engine.py +9 -3
- cognee/infrastructure/databases/cache/redis/RedisAdapter.py +60 -1
- cognee/infrastructure/databases/dataset_database_handler/supported_dataset_database_handlers.py +7 -0
- cognee/infrastructure/databases/graph/get_graph_engine.py +29 -1
- cognee/infrastructure/databases/graph/neo4j_driver/Neo4jAuraDevDatasetDatabaseHandler.py +62 -27
- cognee/infrastructure/databases/hybrid/neptune_analytics/NeptuneAnalyticsAdapter.py +17 -4
- cognee/infrastructure/databases/relational/config.py +16 -1
- cognee/infrastructure/databases/relational/create_relational_engine.py +13 -3
- cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py +26 -3
- cognee/infrastructure/databases/vector/chromadb/ChromaDBAdapter.py +2 -0
- cognee/infrastructure/databases/vector/config.py +6 -0
- cognee/infrastructure/databases/vector/create_vector_engine.py +70 -16
- cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py +64 -9
- cognee/infrastructure/databases/vector/embeddings/OllamaEmbeddingEngine.py +13 -2
- cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py +16 -3
- cognee/infrastructure/databases/vector/models/ScoredResult.py +3 -3
- cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py +16 -3
- cognee/infrastructure/databases/vector/pgvector/PGVectorDatasetDatabaseHandler.py +86 -0
- cognee/infrastructure/databases/vector/pgvector/create_db_and_tables.py +81 -2
- cognee/infrastructure/databases/vector/vector_db_interface.py +8 -0
- cognee/infrastructure/files/utils/get_data_file_path.py +33 -27
- cognee/infrastructure/llm/LLMGateway.py +0 -13
- cognee/infrastructure/llm/prompts/extract_query_time.txt +1 -1
- cognee/infrastructure/llm/prompts/generate_event_entity_prompt.txt +1 -1
- cognee/infrastructure/llm/prompts/generate_event_graph_prompt.txt +1 -1
- cognee/infrastructure/llm/prompts/generate_graph_prompt.txt +2 -2
- cognee/infrastructure/llm/prompts/generate_graph_prompt_guided.txt +1 -1
- cognee/infrastructure/llm/prompts/generate_graph_prompt_oneshot.txt +2 -2
- cognee/infrastructure/llm/prompts/generate_graph_prompt_simple.txt +1 -1
- cognee/infrastructure/llm/prompts/generate_graph_prompt_strict.txt +1 -1
- cognee/infrastructure/llm/prompts/search_type_selector_prompt.txt +6 -6
- cognee/infrastructure/llm/prompts/test.txt +1 -1
- cognee/infrastructure/llm/prompts/translate_content.txt +19 -0
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py +17 -12
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py +31 -25
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py +132 -7
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py +29 -5
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/llama_cpp/adapter.py +191 -0
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/llm_interface.py +2 -6
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py +58 -13
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/adapter.py +0 -1
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py +25 -131
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/types.py +10 -0
- cognee/modules/chunking/models/DocumentChunk.py +0 -1
- cognee/modules/cognify/config.py +2 -0
- cognee/modules/data/models/Data.py +3 -1
- cognee/modules/engine/models/Entity.py +0 -1
- cognee/modules/engine/operations/setup.py +6 -0
- cognee/modules/graph/cognee_graph/CogneeGraph.py +150 -37
- cognee/modules/graph/cognee_graph/CogneeGraphElements.py +48 -2
- cognee/modules/graph/utils/__init__.py +1 -0
- cognee/modules/graph/utils/get_entity_nodes_from_triplets.py +12 -0
- cognee/modules/notebooks/methods/__init__.py +1 -0
- cognee/modules/notebooks/methods/create_notebook.py +0 -34
- cognee/modules/notebooks/methods/create_tutorial_notebooks.py +191 -0
- cognee/modules/notebooks/methods/get_notebooks.py +12 -8
- cognee/modules/notebooks/tutorials/cognee-basics/cell-1.md +3 -0
- cognee/modules/notebooks/tutorials/cognee-basics/cell-2.md +10 -0
- cognee/modules/notebooks/tutorials/cognee-basics/cell-3.md +7 -0
- cognee/modules/notebooks/tutorials/cognee-basics/cell-4.py +28 -0
- cognee/modules/notebooks/tutorials/cognee-basics/cell-5.py +3 -0
- cognee/modules/notebooks/tutorials/cognee-basics/cell-6.py +9 -0
- cognee/modules/notebooks/tutorials/cognee-basics/cell-7.py +17 -0
- cognee/modules/notebooks/tutorials/cognee-basics/config.json +4 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-1.md +3 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-10.md +3 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-11.md +3 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-12.py +3 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-13.md +7 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-14.py +6 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-15.md +3 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-16.py +7 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-2.md +9 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-3.md +7 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-4.md +9 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-5.md +5 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-6.py +13 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-7.md +3 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-8.md +3 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-9.py +31 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/config.json +4 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/data/copilot_conversations.json +107 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/data/guido_contributions.json +976 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/data/my_developer_rules.md +79 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/data/pep_style_guide.md +74 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/data/zen_principles.md +74 -0
- cognee/modules/retrieval/EntityCompletionRetriever.py +51 -38
- cognee/modules/retrieval/__init__.py +0 -1
- cognee/modules/retrieval/base_retriever.py +66 -10
- cognee/modules/retrieval/chunks_retriever.py +57 -49
- cognee/modules/retrieval/coding_rules_retriever.py +12 -5
- cognee/modules/retrieval/completion_retriever.py +29 -28
- cognee/modules/retrieval/cypher_search_retriever.py +25 -20
- cognee/modules/retrieval/graph_completion_context_extension_retriever.py +42 -46
- cognee/modules/retrieval/graph_completion_cot_retriever.py +68 -51
- cognee/modules/retrieval/graph_completion_retriever.py +78 -63
- cognee/modules/retrieval/graph_summary_completion_retriever.py +2 -0
- cognee/modules/retrieval/lexical_retriever.py +34 -12
- cognee/modules/retrieval/natural_language_retriever.py +18 -15
- cognee/modules/retrieval/summaries_retriever.py +51 -34
- cognee/modules/retrieval/temporal_retriever.py +59 -49
- cognee/modules/retrieval/triplet_retriever.py +32 -33
- cognee/modules/retrieval/utils/access_tracking.py +88 -0
- cognee/modules/retrieval/utils/brute_force_triplet_search.py +99 -103
- cognee/modules/retrieval/utils/node_edge_vector_search.py +174 -0
- cognee/modules/search/methods/__init__.py +1 -0
- cognee/modules/search/methods/get_retriever_output.py +53 -0
- cognee/modules/search/methods/get_search_type_retriever_instance.py +252 -0
- cognee/modules/search/methods/search.py +90 -222
- cognee/modules/search/models/SearchResultPayload.py +67 -0
- cognee/modules/search/types/SearchResult.py +1 -8
- cognee/modules/search/types/SearchType.py +1 -2
- cognee/modules/search/types/__init__.py +1 -1
- cognee/modules/search/utils/__init__.py +1 -2
- cognee/modules/search/utils/transform_insights_to_graph.py +2 -2
- cognee/modules/search/utils/{transform_context_to_graph.py → transform_triplets_to_graph.py} +2 -2
- cognee/modules/users/authentication/default/default_transport.py +11 -1
- cognee/modules/users/authentication/get_api_auth_backend.py +2 -1
- cognee/modules/users/authentication/get_client_auth_backend.py +2 -1
- cognee/modules/users/methods/create_user.py +0 -9
- cognee/modules/users/permissions/methods/has_user_management_permission.py +29 -0
- cognee/modules/visualization/cognee_network_visualization.py +1 -1
- cognee/run_migrations.py +48 -0
- cognee/shared/exceptions/__init__.py +1 -3
- cognee/shared/exceptions/exceptions.py +11 -1
- cognee/shared/usage_logger.py +332 -0
- cognee/shared/utils.py +12 -5
- cognee/tasks/chunks/__init__.py +9 -0
- cognee/tasks/cleanup/cleanup_unused_data.py +172 -0
- cognee/tasks/graph/__init__.py +7 -0
- cognee/tasks/ingestion/data_item.py +8 -0
- cognee/tasks/ingestion/ingest_data.py +12 -1
- cognee/tasks/ingestion/save_data_item_to_storage.py +5 -0
- cognee/tasks/memify/__init__.py +8 -0
- cognee/tasks/memify/extract_usage_frequency.py +613 -0
- cognee/tasks/summarization/models.py +0 -2
- cognee/tasks/temporal_graph/__init__.py +0 -1
- cognee/tasks/translation/__init__.py +96 -0
- cognee/tasks/translation/config.py +110 -0
- cognee/tasks/translation/detect_language.py +190 -0
- cognee/tasks/translation/exceptions.py +62 -0
- cognee/tasks/translation/models.py +72 -0
- cognee/tasks/translation/providers/__init__.py +44 -0
- cognee/tasks/translation/providers/azure_provider.py +192 -0
- cognee/tasks/translation/providers/base.py +85 -0
- cognee/tasks/translation/providers/google_provider.py +158 -0
- cognee/tasks/translation/providers/llm_provider.py +143 -0
- cognee/tasks/translation/translate_content.py +282 -0
- cognee/tasks/web_scraper/default_url_crawler.py +6 -2
- cognee/tests/cli_tests/cli_unit_tests/test_cli_commands.py +1 -0
- cognee/tests/cli_tests/cli_unit_tests/test_cli_edge_cases.py +3 -0
- cognee/tests/integration/retrieval/test_brute_force_triplet_search_with_cognify.py +62 -0
- cognee/tests/integration/retrieval/test_chunks_retriever.py +351 -0
- cognee/tests/integration/retrieval/test_graph_completion_retriever.py +276 -0
- cognee/tests/integration/retrieval/test_graph_completion_retriever_context_extension.py +228 -0
- cognee/tests/integration/retrieval/test_graph_completion_retriever_cot.py +217 -0
- cognee/tests/integration/retrieval/test_rag_completion_retriever.py +319 -0
- cognee/tests/integration/retrieval/test_structured_output.py +258 -0
- cognee/tests/integration/retrieval/test_summaries_retriever.py +195 -0
- cognee/tests/integration/retrieval/test_temporal_retriever.py +336 -0
- cognee/tests/integration/retrieval/test_triplet_retriever.py +45 -1
- cognee/tests/integration/shared/test_usage_logger_integration.py +255 -0
- cognee/tests/tasks/translation/README.md +147 -0
- cognee/tests/tasks/translation/__init__.py +1 -0
- cognee/tests/tasks/translation/config_test.py +93 -0
- cognee/tests/tasks/translation/detect_language_test.py +118 -0
- cognee/tests/tasks/translation/providers_test.py +151 -0
- cognee/tests/tasks/translation/translate_content_test.py +213 -0
- cognee/tests/test_chromadb.py +1 -1
- cognee/tests/test_cleanup_unused_data.py +165 -0
- cognee/tests/test_custom_data_label.py +68 -0
- cognee/tests/test_delete_by_id.py +6 -6
- cognee/tests/test_extract_usage_frequency.py +308 -0
- cognee/tests/test_kuzu.py +17 -7
- cognee/tests/test_lancedb.py +3 -1
- cognee/tests/test_library.py +1 -1
- cognee/tests/test_neo4j.py +17 -7
- cognee/tests/test_neptune_analytics_vector.py +3 -1
- cognee/tests/test_permissions.py +172 -187
- cognee/tests/test_pgvector.py +3 -1
- cognee/tests/test_relational_db_migration.py +15 -1
- cognee/tests/test_remote_kuzu.py +3 -1
- cognee/tests/test_s3_file_storage.py +1 -1
- cognee/tests/test_search_db.py +345 -205
- cognee/tests/test_usage_logger_e2e.py +268 -0
- cognee/tests/unit/api/test_get_raw_data_endpoint.py +206 -0
- cognee/tests/unit/eval_framework/answer_generation_test.py +4 -3
- cognee/tests/unit/eval_framework/benchmark_adapters_test.py +25 -0
- cognee/tests/unit/eval_framework/corpus_builder_test.py +33 -4
- cognee/tests/unit/infrastructure/databases/cache/test_cache_config.py +2 -0
- cognee/tests/unit/infrastructure/databases/relational/test_RelationalConfig.py +69 -0
- cognee/tests/unit/modules/graph/cognee_graph_elements_test.py +42 -2
- cognee/tests/unit/modules/graph/cognee_graph_test.py +329 -31
- cognee/tests/unit/modules/retrieval/chunks_retriever_test.py +122 -168
- cognee/tests/unit/modules/retrieval/conversation_history_test.py +338 -0
- cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py +486 -157
- cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py +693 -155
- cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py +619 -200
- cognee/tests/unit/modules/retrieval/rag_completion_retriever_test.py +300 -171
- cognee/tests/unit/modules/retrieval/summaries_retriever_test.py +184 -155
- cognee/tests/unit/modules/retrieval/temporal_retriever_test.py +544 -79
- cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py +476 -28
- cognee/tests/unit/modules/retrieval/test_completion.py +343 -0
- cognee/tests/unit/modules/retrieval/test_graph_summary_completion_retriever.py +157 -0
- cognee/tests/unit/modules/retrieval/test_node_edge_vector_search.py +273 -0
- cognee/tests/unit/modules/retrieval/test_user_qa_feedback.py +312 -0
- cognee/tests/unit/modules/retrieval/triplet_retriever_test.py +267 -7
- cognee/tests/unit/modules/search/test_get_search_type_retriever_instance.py +125 -0
- cognee/tests/unit/modules/search/test_search.py +96 -20
- cognee/tests/unit/modules/search/test_search_prepare_search_result_contract.py +190 -0
- cognee/tests/unit/modules/users/test_tutorial_notebook_creation.py +511 -297
- cognee/tests/unit/shared/test_usage_logger.py +241 -0
- cognee/tests/unit/users/permissions/test_has_user_management_permission.py +46 -0
- {cognee-0.5.1.dist-info → cognee-0.5.2.dist-info}/METADATA +22 -17
- {cognee-0.5.1.dist-info → cognee-0.5.2.dist-info}/RECORD +258 -157
- cognee/api/.env.example +0 -5
- cognee/modules/retrieval/base_graph_retriever.py +0 -24
- cognee/modules/search/methods/get_search_type_tools.py +0 -223
- cognee/modules/search/methods/no_access_control_search.py +0 -62
- cognee/modules/search/utils/prepare_search_result.py +0 -63
- cognee/tests/test_feedback_enrichment.py +0 -174
- cognee/tests/unit/modules/retrieval/structured_output_test.py +0 -204
- {cognee-0.5.1.dist-info → cognee-0.5.2.dist-info}/WHEEL +0 -0
- {cognee-0.5.1.dist-info → cognee-0.5.2.dist-info}/entry_points.txt +0 -0
- {cognee-0.5.1.dist-info → cognee-0.5.2.dist-info}/licenses/LICENSE +0 -0
- {cognee-0.5.1.dist-info → cognee-0.5.2.dist-info}/licenses/NOTICE.md +0 -0
|
@@ -1,177 +1,506 @@
|
|
|
1
|
-
import os
|
|
2
1
|
import pytest
|
|
3
|
-
import
|
|
4
|
-
from
|
|
2
|
+
from unittest.mock import AsyncMock, patch, MagicMock
|
|
3
|
+
from uuid import UUID
|
|
5
4
|
|
|
6
|
-
import cognee
|
|
7
|
-
from cognee.low_level import setup, DataPoint
|
|
8
|
-
from cognee.tasks.storage import add_data_points
|
|
9
|
-
from cognee.modules.graph.utils import resolve_edges_to_text
|
|
10
5
|
from cognee.modules.retrieval.graph_completion_context_extension_retriever import (
|
|
11
6
|
GraphCompletionContextExtensionRetriever,
|
|
12
7
|
)
|
|
8
|
+
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
@pytest.fixture
|
|
12
|
+
def mock_edge():
|
|
13
|
+
"""Create a mock edge."""
|
|
14
|
+
edge = MagicMock(spec=Edge)
|
|
15
|
+
return edge
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
@pytest.mark.asyncio
|
|
19
|
+
async def test_get_triplets_inherited(mock_edge):
|
|
20
|
+
"""Test that get_triplets is inherited from parent class."""
|
|
21
|
+
retriever = GraphCompletionContextExtensionRetriever()
|
|
22
|
+
|
|
23
|
+
with patch(
|
|
24
|
+
"cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
|
|
25
|
+
return_value=[mock_edge],
|
|
26
|
+
):
|
|
27
|
+
triplets = await retriever.get_triplets("test query")
|
|
28
|
+
|
|
29
|
+
assert len(triplets) == 1
|
|
30
|
+
assert triplets[0] == mock_edge
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
@pytest.mark.asyncio
|
|
34
|
+
async def test_init_defaults():
|
|
35
|
+
"""Test GraphCompletionContextExtensionRetriever initialization with defaults."""
|
|
36
|
+
retriever = GraphCompletionContextExtensionRetriever()
|
|
37
|
+
|
|
38
|
+
assert retriever.top_k == 5
|
|
39
|
+
assert retriever.user_prompt_path == "graph_context_for_question.txt"
|
|
40
|
+
assert retriever.system_prompt_path == "answer_simple_question.txt"
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
@pytest.mark.asyncio
|
|
44
|
+
async def test_init_custom_params():
|
|
45
|
+
"""Test GraphCompletionContextExtensionRetriever initialization with custom parameters."""
|
|
46
|
+
retriever = GraphCompletionContextExtensionRetriever(
|
|
47
|
+
top_k=10,
|
|
48
|
+
user_prompt_path="custom_user.txt",
|
|
49
|
+
system_prompt_path="custom_system.txt",
|
|
50
|
+
system_prompt="Custom prompt",
|
|
51
|
+
node_type=str,
|
|
52
|
+
node_name=["node1"],
|
|
53
|
+
save_interaction=True,
|
|
54
|
+
wide_search_top_k=200,
|
|
55
|
+
triplet_distance_penalty=5.0,
|
|
56
|
+
)
|
|
57
|
+
|
|
58
|
+
assert retriever.top_k == 10
|
|
59
|
+
assert retriever.user_prompt_path == "custom_user.txt"
|
|
60
|
+
assert retriever.system_prompt_path == "custom_system.txt"
|
|
61
|
+
assert retriever.system_prompt == "Custom prompt"
|
|
62
|
+
assert retriever.node_type is str
|
|
63
|
+
assert retriever.node_name == ["node1"]
|
|
64
|
+
assert retriever.save_interaction is True
|
|
65
|
+
assert retriever.wide_search_top_k == 200
|
|
66
|
+
assert retriever.triplet_distance_penalty == 5.0
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
@pytest.mark.asyncio
|
|
70
|
+
async def test_get_completion_without_context(mock_edge):
|
|
71
|
+
"""Test get_completion retrieves context when not provided."""
|
|
72
|
+
mock_graph_engine = AsyncMock()
|
|
73
|
+
mock_graph_engine.is_empty = AsyncMock(return_value=False)
|
|
74
|
+
|
|
75
|
+
retriever = GraphCompletionContextExtensionRetriever(context_extension_rounds=1)
|
|
76
|
+
|
|
77
|
+
with (
|
|
78
|
+
patch(
|
|
79
|
+
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
|
|
80
|
+
return_value=mock_graph_engine,
|
|
81
|
+
),
|
|
82
|
+
patch(
|
|
83
|
+
"cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
|
|
84
|
+
return_value=[mock_edge],
|
|
85
|
+
),
|
|
86
|
+
patch(
|
|
87
|
+
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
|
|
88
|
+
return_value="Resolved context",
|
|
89
|
+
),
|
|
90
|
+
patch(
|
|
91
|
+
"cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion",
|
|
92
|
+
return_value="Generated answer",
|
|
93
|
+
),
|
|
94
|
+
patch(
|
|
95
|
+
"cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig"
|
|
96
|
+
) as mock_cache_config,
|
|
97
|
+
):
|
|
98
|
+
mock_config = MagicMock()
|
|
99
|
+
mock_config.caching = False
|
|
100
|
+
mock_cache_config.return_value = mock_config
|
|
101
|
+
|
|
102
|
+
retrieved_objects = await retriever.get_retrieved_objects("test_query")
|
|
103
|
+
context = await retriever.get_context_from_objects("test query", retrieved_objects)
|
|
104
|
+
completion = await retriever.get_completion_from_context(
|
|
105
|
+
"test query", retrieved_objects, context
|
|
106
|
+
)
|
|
13
107
|
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
108
|
+
assert isinstance(completion, list)
|
|
109
|
+
assert len(completion) == 1
|
|
110
|
+
assert completion[0] == "Generated answer"
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
@pytest.mark.asyncio
|
|
114
|
+
async def test_get_completion_with_provided_context(mock_edge):
|
|
115
|
+
"""Test get_completion uses provided context."""
|
|
116
|
+
retriever = GraphCompletionContextExtensionRetriever(context_extension_rounds=1)
|
|
117
|
+
|
|
118
|
+
with (
|
|
119
|
+
patch(
|
|
120
|
+
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
|
|
121
|
+
return_value="Resolved context",
|
|
122
|
+
),
|
|
123
|
+
patch(
|
|
124
|
+
"cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion",
|
|
125
|
+
return_value="Generated answer",
|
|
126
|
+
),
|
|
127
|
+
patch(
|
|
128
|
+
"cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig"
|
|
129
|
+
) as mock_cache_config,
|
|
130
|
+
):
|
|
131
|
+
mock_config = MagicMock()
|
|
132
|
+
mock_config.caching = False
|
|
133
|
+
mock_cache_config.return_value = mock_config
|
|
134
|
+
|
|
135
|
+
context = await retriever.get_context_from_objects(
|
|
136
|
+
"test query", retrieved_objects=[mock_edge]
|
|
21
137
|
)
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
pathlib.Path(__file__).parent,
|
|
25
|
-
".data_storage/test_graph_completion_extension_context_simple",
|
|
138
|
+
completion = await retriever.get_completion_from_context(
|
|
139
|
+
"test query", retrieved_objects=[mock_edge], context=context
|
|
26
140
|
)
|
|
27
|
-
cognee.config.data_root_directory(data_directory_path)
|
|
28
|
-
|
|
29
|
-
await cognee.prune.prune_data()
|
|
30
|
-
await cognee.prune.prune_system(metadata=True)
|
|
31
|
-
await setup()
|
|
32
|
-
|
|
33
|
-
class Company(DataPoint):
|
|
34
|
-
name: str
|
|
35
|
-
|
|
36
|
-
class Person(DataPoint):
|
|
37
|
-
name: str
|
|
38
|
-
works_for: Company
|
|
39
|
-
|
|
40
|
-
company1 = Company(name="Figma")
|
|
41
|
-
company2 = Company(name="Canva")
|
|
42
|
-
person1 = Person(name="Steve Rodger", works_for=company1)
|
|
43
|
-
person2 = Person(name="Ike Loma", works_for=company1)
|
|
44
|
-
person3 = Person(name="Jason Statham", works_for=company1)
|
|
45
|
-
person4 = Person(name="Mike Broski", works_for=company2)
|
|
46
|
-
person5 = Person(name="Christina Mayer", works_for=company2)
|
|
47
|
-
|
|
48
|
-
entities = [company1, company2, person1, person2, person3, person4, person5]
|
|
49
|
-
|
|
50
|
-
await add_data_points(entities)
|
|
51
|
-
|
|
52
|
-
retriever = GraphCompletionContextExtensionRetriever()
|
|
53
|
-
|
|
54
|
-
context = await resolve_edges_to_text(await retriever.get_context("Who works at Canva?"))
|
|
55
141
|
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
142
|
+
assert isinstance(completion, list)
|
|
143
|
+
assert len(completion) == 1
|
|
144
|
+
assert completion[0] == "Generated answer"
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
@pytest.mark.asyncio
|
|
148
|
+
async def test_get_completion_context_extension_rounds(mock_edge):
|
|
149
|
+
"""Test get_completion with multiple context extension rounds."""
|
|
150
|
+
mock_graph_engine = AsyncMock()
|
|
151
|
+
mock_graph_engine.is_empty = AsyncMock(return_value=False)
|
|
152
|
+
|
|
153
|
+
retriever = GraphCompletionContextExtensionRetriever(context_extension_rounds=1)
|
|
154
|
+
|
|
155
|
+
# Create a second edge for extension rounds
|
|
156
|
+
mock_edge2 = MagicMock(spec=Edge)
|
|
157
|
+
|
|
158
|
+
with (
|
|
159
|
+
patch(
|
|
160
|
+
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
|
|
161
|
+
return_value=mock_graph_engine,
|
|
162
|
+
),
|
|
163
|
+
patch.object(
|
|
164
|
+
retriever,
|
|
165
|
+
"get_context_from_objects",
|
|
166
|
+
new_callable=AsyncMock,
|
|
167
|
+
side_effect=[[mock_edge], [mock_edge2]],
|
|
168
|
+
),
|
|
169
|
+
patch(
|
|
170
|
+
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
|
|
171
|
+
side_effect=["Resolved context", "Extended context"], # Different contexts
|
|
172
|
+
),
|
|
173
|
+
patch(
|
|
174
|
+
"cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion",
|
|
175
|
+
side_effect=[
|
|
176
|
+
"Extension query",
|
|
177
|
+
"Generated answer",
|
|
178
|
+
], # Query for extension, then final answer
|
|
179
|
+
),
|
|
180
|
+
patch(
|
|
181
|
+
"cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig"
|
|
182
|
+
) as mock_cache_config,
|
|
183
|
+
):
|
|
184
|
+
mock_config = MagicMock()
|
|
185
|
+
mock_config.caching = False
|
|
186
|
+
mock_cache_config.return_value = mock_config
|
|
187
|
+
|
|
188
|
+
objects = await retriever.get_retrieved_objects("test_query")
|
|
189
|
+
context = await retriever.get_context_from_objects("test query", objects)
|
|
190
|
+
completion = await retriever.get_completion_from_context(
|
|
191
|
+
"test query", objects, context=context
|
|
64
192
|
)
|
|
65
193
|
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
194
|
+
assert isinstance(completion, list)
|
|
195
|
+
assert len(completion) == 1
|
|
196
|
+
assert completion[0] == "Generated answer"
|
|
197
|
+
|
|
198
|
+
|
|
199
|
+
@pytest.mark.asyncio
|
|
200
|
+
async def test_get_completion_context_extension_stops_early(mock_edge):
|
|
201
|
+
"""Test get_completion stops early when no new triplets found."""
|
|
202
|
+
mock_graph_engine = AsyncMock()
|
|
203
|
+
mock_graph_engine.is_empty = AsyncMock(return_value=False)
|
|
204
|
+
|
|
205
|
+
retriever = GraphCompletionContextExtensionRetriever(context_extension_rounds=4)
|
|
206
|
+
|
|
207
|
+
with (
|
|
208
|
+
patch.object(
|
|
209
|
+
retriever, "get_context_from_objects", new_callable=AsyncMock, return_value=[mock_edge]
|
|
210
|
+
),
|
|
211
|
+
patch(
|
|
212
|
+
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
|
|
213
|
+
return_value="Resolved context",
|
|
214
|
+
),
|
|
215
|
+
patch(
|
|
216
|
+
"cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion",
|
|
217
|
+
side_effect=[
|
|
218
|
+
"Extension query",
|
|
219
|
+
"Generated answer",
|
|
220
|
+
],
|
|
221
|
+
),
|
|
222
|
+
patch(
|
|
223
|
+
"cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig"
|
|
224
|
+
) as mock_cache_config,
|
|
225
|
+
):
|
|
226
|
+
mock_config = MagicMock()
|
|
227
|
+
mock_config.caching = False
|
|
228
|
+
mock_cache_config.return_value = mock_config
|
|
229
|
+
|
|
230
|
+
# When get_context returns same triplets, the loop should stop early
|
|
231
|
+
objects = await retriever.get_retrieved_objects("test_query")
|
|
232
|
+
context = await retriever.get_context_from_objects("test query", objects)
|
|
233
|
+
completion = await retriever.get_completion_from_context(
|
|
234
|
+
"test query", objects, context=context
|
|
76
235
|
)
|
|
77
|
-
cognee.config.data_root_directory(data_directory_path)
|
|
78
|
-
|
|
79
|
-
await cognee.prune.prune_data()
|
|
80
|
-
await cognee.prune.prune_system(metadata=True)
|
|
81
|
-
await setup()
|
|
82
|
-
|
|
83
|
-
class Company(DataPoint):
|
|
84
|
-
name: str
|
|
85
|
-
metadata: dict = {"index_fields": ["name"]}
|
|
86
|
-
|
|
87
|
-
class Car(DataPoint):
|
|
88
|
-
brand: str
|
|
89
|
-
model: str
|
|
90
|
-
year: int
|
|
91
|
-
|
|
92
|
-
class Location(DataPoint):
|
|
93
|
-
country: str
|
|
94
|
-
city: str
|
|
95
|
-
|
|
96
|
-
class Home(DataPoint):
|
|
97
|
-
location: Location
|
|
98
|
-
rooms: int
|
|
99
|
-
sqm: int
|
|
100
|
-
|
|
101
|
-
class Person(DataPoint):
|
|
102
|
-
name: str
|
|
103
|
-
works_for: Company
|
|
104
|
-
owns: Optional[list[Union[Car, Home]]] = None
|
|
105
236
|
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
237
|
+
assert isinstance(completion, list)
|
|
238
|
+
assert len(completion) == 1
|
|
239
|
+
assert completion[0] == "Generated answer"
|
|
240
|
+
|
|
241
|
+
|
|
242
|
+
@pytest.mark.asyncio
|
|
243
|
+
async def test_get_completion_with_session(mock_edge):
|
|
244
|
+
"""Test get_completion with session caching enabled."""
|
|
245
|
+
mock_graph_engine = AsyncMock()
|
|
246
|
+
mock_graph_engine.is_empty = AsyncMock(return_value=False)
|
|
247
|
+
|
|
248
|
+
retriever = GraphCompletionContextExtensionRetriever(
|
|
249
|
+
session_id="test_session", context_extension_rounds=1
|
|
250
|
+
)
|
|
251
|
+
|
|
252
|
+
mock_user = MagicMock()
|
|
253
|
+
mock_user.id = "test-user-id"
|
|
254
|
+
|
|
255
|
+
with (
|
|
256
|
+
patch(
|
|
257
|
+
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
|
|
258
|
+
return_value=mock_graph_engine,
|
|
259
|
+
),
|
|
260
|
+
patch.object(
|
|
261
|
+
retriever, "get_context_from_objects", new_callable=AsyncMock, return_value=[mock_edge]
|
|
262
|
+
),
|
|
263
|
+
patch(
|
|
264
|
+
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
|
|
265
|
+
return_value="Resolved context",
|
|
266
|
+
),
|
|
267
|
+
patch(
|
|
268
|
+
"cognee.modules.retrieval.graph_completion_context_extension_retriever.get_conversation_history",
|
|
269
|
+
return_value="Previous conversation",
|
|
270
|
+
),
|
|
271
|
+
patch(
|
|
272
|
+
"cognee.modules.retrieval.graph_completion_context_extension_retriever.summarize_text",
|
|
273
|
+
return_value="Context summary",
|
|
274
|
+
),
|
|
275
|
+
patch(
|
|
276
|
+
"cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion",
|
|
277
|
+
side_effect=[
|
|
278
|
+
"Extension query",
|
|
279
|
+
"Generated answer",
|
|
280
|
+
], # Extension query, then final answer
|
|
281
|
+
),
|
|
282
|
+
patch(
|
|
283
|
+
"cognee.modules.retrieval.graph_completion_context_extension_retriever.save_conversation_history",
|
|
284
|
+
) as mock_save,
|
|
285
|
+
patch(
|
|
286
|
+
"cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig"
|
|
287
|
+
) as mock_cache_config,
|
|
288
|
+
patch(
|
|
289
|
+
"cognee.modules.retrieval.graph_completion_context_extension_retriever.session_user"
|
|
290
|
+
) as mock_session_user,
|
|
291
|
+
):
|
|
292
|
+
mock_config = MagicMock()
|
|
293
|
+
mock_config.caching = True
|
|
294
|
+
mock_cache_config.return_value = mock_config
|
|
295
|
+
mock_session_user.get.return_value = mock_user
|
|
296
|
+
|
|
297
|
+
objects = await retriever.get_retrieved_objects("test_query")
|
|
298
|
+
context = await retriever.get_context_from_objects("test query", objects)
|
|
299
|
+
completion = await retriever.get_completion_from_context(
|
|
300
|
+
"test query", objects, context=context
|
|
134
301
|
)
|
|
135
302
|
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
303
|
+
assert isinstance(completion, list)
|
|
304
|
+
assert len(completion) == 1
|
|
305
|
+
assert completion[0] == "Generated answer"
|
|
306
|
+
mock_save.assert_awaited_once()
|
|
307
|
+
|
|
308
|
+
|
|
309
|
+
@pytest.mark.asyncio
|
|
310
|
+
async def test_get_completion_with_save_interaction(mock_edge):
|
|
311
|
+
"""Test get_completion with save_interaction enabled."""
|
|
312
|
+
mock_graph_engine = AsyncMock()
|
|
313
|
+
mock_graph_engine.is_empty = AsyncMock(return_value=False)
|
|
314
|
+
mock_graph_engine.add_edges = AsyncMock()
|
|
315
|
+
|
|
316
|
+
retriever = GraphCompletionContextExtensionRetriever(
|
|
317
|
+
context_extension_rounds=1, save_interaction=True
|
|
318
|
+
)
|
|
319
|
+
|
|
320
|
+
mock_node1 = MagicMock()
|
|
321
|
+
mock_node2 = MagicMock()
|
|
322
|
+
mock_edge.node1 = mock_node1
|
|
323
|
+
mock_edge.node2 = mock_node2
|
|
324
|
+
|
|
325
|
+
with (
|
|
326
|
+
patch(
|
|
327
|
+
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
|
|
328
|
+
return_value=mock_graph_engine,
|
|
329
|
+
),
|
|
330
|
+
patch.object(
|
|
331
|
+
retriever, "get_context_from_objects", new_callable=AsyncMock, return_value="mock_edge"
|
|
332
|
+
),
|
|
333
|
+
patch(
|
|
334
|
+
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
|
|
335
|
+
return_value="Resolved context",
|
|
336
|
+
),
|
|
337
|
+
patch(
|
|
338
|
+
"cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion",
|
|
339
|
+
side_effect=[
|
|
340
|
+
"Extension query",
|
|
341
|
+
"Generated answer",
|
|
342
|
+
], # Extension query, then final answer
|
|
343
|
+
),
|
|
344
|
+
patch(
|
|
345
|
+
"cognee.modules.retrieval.graph_completion_retriever.extract_uuid_from_node",
|
|
346
|
+
side_effect=[
|
|
347
|
+
UUID("550e8400-e29b-41d4-a716-446655440000"),
|
|
348
|
+
UUID("550e8400-e29b-41d4-a716-446655440001"),
|
|
349
|
+
],
|
|
350
|
+
),
|
|
351
|
+
patch(
|
|
352
|
+
"cognee.modules.retrieval.graph_completion_retriever.add_data_points",
|
|
353
|
+
) as mock_add_data,
|
|
354
|
+
patch(
|
|
355
|
+
"cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig"
|
|
356
|
+
) as mock_cache_config,
|
|
357
|
+
):
|
|
358
|
+
mock_config = MagicMock()
|
|
359
|
+
mock_config.caching = False
|
|
360
|
+
mock_cache_config.return_value = mock_config
|
|
361
|
+
|
|
362
|
+
context = await retriever.get_context_from_objects("test query", [mock_edge])
|
|
363
|
+
completion = await retriever.get_completion_from_context(
|
|
364
|
+
"test query", [mock_edge], context=context
|
|
147
365
|
)
|
|
148
366
|
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
367
|
+
assert isinstance(completion, list)
|
|
368
|
+
assert len(completion) == 1
|
|
369
|
+
mock_add_data.assert_awaited_once()
|
|
370
|
+
|
|
371
|
+
|
|
372
|
+
@pytest.mark.asyncio
|
|
373
|
+
async def test_get_completion_with_response_model(mock_edge):
|
|
374
|
+
"""Test get_completion with custom response model."""
|
|
375
|
+
from pydantic import BaseModel
|
|
376
|
+
|
|
377
|
+
class TestModel(BaseModel):
|
|
378
|
+
answer: str
|
|
379
|
+
|
|
380
|
+
mock_graph_engine = AsyncMock()
|
|
381
|
+
mock_graph_engine.is_empty = AsyncMock(return_value=False)
|
|
382
|
+
|
|
383
|
+
retriever = GraphCompletionContextExtensionRetriever(context_extension_rounds=1)
|
|
384
|
+
|
|
385
|
+
with (
|
|
386
|
+
patch(
|
|
387
|
+
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
|
|
388
|
+
return_value=mock_graph_engine,
|
|
389
|
+
),
|
|
390
|
+
patch.object(
|
|
391
|
+
retriever, "get_context_from_objects", new_callable=AsyncMock, return_value=[mock_edge]
|
|
392
|
+
),
|
|
393
|
+
patch(
|
|
394
|
+
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
|
|
395
|
+
return_value="Resolved context",
|
|
396
|
+
),
|
|
397
|
+
patch(
|
|
398
|
+
"cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion",
|
|
399
|
+
side_effect=[
|
|
400
|
+
"Extension query",
|
|
401
|
+
TestModel(answer="Test answer"),
|
|
402
|
+
], # Extension query, then final answer
|
|
403
|
+
),
|
|
404
|
+
patch(
|
|
405
|
+
"cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig"
|
|
406
|
+
) as mock_cache_config,
|
|
407
|
+
):
|
|
408
|
+
mock_config = MagicMock()
|
|
409
|
+
mock_config.caching = False
|
|
410
|
+
mock_cache_config.return_value = mock_config
|
|
411
|
+
|
|
412
|
+
objects = await retriever.get_retrieved_objects("test_query")
|
|
413
|
+
context = await retriever.get_context_from_objects("test query", objects)
|
|
414
|
+
completion = await retriever.get_completion_from_context(
|
|
415
|
+
"test query", objects, context=context
|
|
159
416
|
)
|
|
160
|
-
cognee.config.data_root_directory(data_directory_path)
|
|
161
|
-
|
|
162
|
-
await cognee.prune.prune_data()
|
|
163
|
-
await cognee.prune.prune_system(metadata=True)
|
|
164
|
-
|
|
165
|
-
retriever = GraphCompletionContextExtensionRetriever()
|
|
166
417
|
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
418
|
+
assert isinstance(completion, list)
|
|
419
|
+
assert len(completion) == 1
|
|
420
|
+
assert isinstance(completion[0], TestModel)
|
|
421
|
+
|
|
422
|
+
|
|
423
|
+
@pytest.mark.asyncio
|
|
424
|
+
async def test_get_completion_with_session_no_user_id(mock_edge):
|
|
425
|
+
"""Test get_completion with session config but no user ID."""
|
|
426
|
+
mock_graph_engine = AsyncMock()
|
|
427
|
+
mock_graph_engine.is_empty = AsyncMock(return_value=False)
|
|
428
|
+
|
|
429
|
+
retriever = GraphCompletionContextExtensionRetriever(context_extension_rounds=1)
|
|
430
|
+
|
|
431
|
+
with (
|
|
432
|
+
patch(
|
|
433
|
+
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
|
|
434
|
+
return_value=mock_graph_engine,
|
|
435
|
+
),
|
|
436
|
+
patch.object(
|
|
437
|
+
retriever, "get_context_from_objects", new_callable=AsyncMock, return_value=[mock_edge]
|
|
438
|
+
),
|
|
439
|
+
patch(
|
|
440
|
+
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
|
|
441
|
+
return_value="Resolved context",
|
|
442
|
+
),
|
|
443
|
+
patch(
|
|
444
|
+
"cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion",
|
|
445
|
+
side_effect=[
|
|
446
|
+
"Extension query",
|
|
447
|
+
"Generated answer",
|
|
448
|
+
], # Extension query, then final answer
|
|
449
|
+
),
|
|
450
|
+
patch(
|
|
451
|
+
"cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig"
|
|
452
|
+
) as mock_cache_config,
|
|
453
|
+
patch(
|
|
454
|
+
"cognee.modules.retrieval.graph_completion_context_extension_retriever.session_user"
|
|
455
|
+
) as mock_session_user,
|
|
456
|
+
):
|
|
457
|
+
mock_config = MagicMock()
|
|
458
|
+
mock_config.caching = True
|
|
459
|
+
mock_cache_config.return_value = mock_config
|
|
460
|
+
mock_session_user.get.return_value = None # No user
|
|
461
|
+
|
|
462
|
+
objects = await retriever.get_retrieved_objects("test_query")
|
|
463
|
+
context = await retriever.get_context_from_objects("test query", objects)
|
|
464
|
+
completion = await retriever.get_completion_from_context(
|
|
465
|
+
"test query", objects, context=context
|
|
177
466
|
)
|
|
467
|
+
|
|
468
|
+
assert isinstance(completion, list)
|
|
469
|
+
assert len(completion) == 1
|
|
470
|
+
|
|
471
|
+
|
|
472
|
+
@pytest.mark.asyncio
|
|
473
|
+
async def test_get_completion_zero_extension_rounds(mock_edge):
|
|
474
|
+
"""Test get_completion with zero context extension rounds."""
|
|
475
|
+
mock_graph_engine = AsyncMock()
|
|
476
|
+
mock_graph_engine.is_empty = AsyncMock(return_value=False)
|
|
477
|
+
|
|
478
|
+
retriever = GraphCompletionContextExtensionRetriever(context_extension_rounds=0)
|
|
479
|
+
|
|
480
|
+
with (
|
|
481
|
+
patch(
|
|
482
|
+
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
|
|
483
|
+
return_value=mock_graph_engine,
|
|
484
|
+
),
|
|
485
|
+
patch.object(
|
|
486
|
+
retriever, "get_context_from_objects", new_callable=AsyncMock, return_value=[mock_edge]
|
|
487
|
+
),
|
|
488
|
+
patch(
|
|
489
|
+
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
|
|
490
|
+
return_value="Resolved context",
|
|
491
|
+
),
|
|
492
|
+
patch(
|
|
493
|
+
"cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion",
|
|
494
|
+
return_value="Generated answer",
|
|
495
|
+
),
|
|
496
|
+
patch(
|
|
497
|
+
"cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig"
|
|
498
|
+
) as mock_cache_config,
|
|
499
|
+
):
|
|
500
|
+
mock_config = MagicMock()
|
|
501
|
+
mock_config.caching = False
|
|
502
|
+
mock_cache_config.return_value = mock_config
|
|
503
|
+
context = await retriever.get_context_from_objects("test query", None)
|
|
504
|
+
|
|
505
|
+
assert isinstance(context, list)
|
|
506
|
+
assert len(context) == 1
|