cognee 0.5.1__py3-none-any.whl → 0.5.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cognee/__init__.py +2 -0
- cognee/alembic/README +1 -0
- cognee/alembic/env.py +107 -0
- cognee/alembic/script.py.mako +26 -0
- cognee/alembic/versions/1a58b986e6e1_enable_delete_for_old_tutorial_notebooks.py +52 -0
- cognee/alembic/versions/1d0bb7fede17_add_pipeline_run_status.py +33 -0
- cognee/alembic/versions/1daae0df1866_incremental_loading.py +48 -0
- cognee/alembic/versions/211ab850ef3d_add_sync_operations_table.py +118 -0
- cognee/alembic/versions/45957f0a9849_add_notebook_table.py +46 -0
- cognee/alembic/versions/46a6ce2bd2b2_expand_dataset_database_with_json_.py +333 -0
- cognee/alembic/versions/482cd6517ce4_add_default_user.py +30 -0
- cognee/alembic/versions/76625596c5c3_expand_dataset_database_for_multi_user.py +98 -0
- cognee/alembic/versions/8057ae7329c2_initial_migration.py +25 -0
- cognee/alembic/versions/9e7a3cb85175_loader_separation.py +104 -0
- cognee/alembic/versions/a1b2c3d4e5f6_add_label_column_to_data.py +38 -0
- cognee/alembic/versions/ab7e313804ae_permission_system_rework.py +236 -0
- cognee/alembic/versions/b9274c27a25a_kuzu_11_migration.py +75 -0
- cognee/alembic/versions/c946955da633_multi_tenant_support.py +137 -0
- cognee/alembic/versions/e1ec1dcb50b6_add_last_accessed_to_data.py +51 -0
- cognee/alembic/versions/e4ebee1091e7_expand_data_model_info.py +140 -0
- cognee/alembic.ini +117 -0
- cognee/api/v1/add/add.py +2 -1
- cognee/api/v1/add/routers/get_add_router.py +2 -0
- cognee/api/v1/cognify/cognify.py +11 -6
- cognee/api/v1/cognify/routers/get_cognify_router.py +8 -0
- cognee/api/v1/config/config.py +60 -0
- cognee/api/v1/datasets/routers/get_datasets_router.py +46 -3
- cognee/api/v1/memify/routers/get_memify_router.py +3 -0
- cognee/api/v1/search/routers/get_search_router.py +21 -6
- cognee/api/v1/search/search.py +21 -5
- cognee/api/v1/sync/routers/get_sync_router.py +3 -3
- cognee/cli/commands/add_command.py +1 -1
- cognee/cli/commands/cognify_command.py +6 -0
- cognee/cli/commands/config_command.py +1 -1
- cognee/context_global_variables.py +5 -1
- cognee/eval_framework/answer_generation/answer_generation_executor.py +7 -8
- cognee/infrastructure/databases/cache/cache_db_interface.py +38 -1
- cognee/infrastructure/databases/cache/config.py +6 -0
- cognee/infrastructure/databases/cache/fscache/FsCacheAdapter.py +21 -0
- cognee/infrastructure/databases/cache/get_cache_engine.py +9 -3
- cognee/infrastructure/databases/cache/redis/RedisAdapter.py +60 -1
- cognee/infrastructure/databases/dataset_database_handler/supported_dataset_database_handlers.py +7 -0
- cognee/infrastructure/databases/graph/get_graph_engine.py +29 -1
- cognee/infrastructure/databases/graph/neo4j_driver/Neo4jAuraDevDatasetDatabaseHandler.py +62 -27
- cognee/infrastructure/databases/hybrid/neptune_analytics/NeptuneAnalyticsAdapter.py +17 -4
- cognee/infrastructure/databases/relational/config.py +16 -1
- cognee/infrastructure/databases/relational/create_relational_engine.py +13 -3
- cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py +26 -3
- cognee/infrastructure/databases/vector/chromadb/ChromaDBAdapter.py +2 -0
- cognee/infrastructure/databases/vector/config.py +6 -0
- cognee/infrastructure/databases/vector/create_vector_engine.py +70 -16
- cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py +64 -9
- cognee/infrastructure/databases/vector/embeddings/OllamaEmbeddingEngine.py +13 -2
- cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py +16 -3
- cognee/infrastructure/databases/vector/models/ScoredResult.py +3 -3
- cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py +16 -3
- cognee/infrastructure/databases/vector/pgvector/PGVectorDatasetDatabaseHandler.py +86 -0
- cognee/infrastructure/databases/vector/pgvector/create_db_and_tables.py +81 -2
- cognee/infrastructure/databases/vector/vector_db_interface.py +8 -0
- cognee/infrastructure/files/utils/get_data_file_path.py +33 -27
- cognee/infrastructure/llm/LLMGateway.py +0 -13
- cognee/infrastructure/llm/prompts/extract_query_time.txt +1 -1
- cognee/infrastructure/llm/prompts/generate_event_entity_prompt.txt +1 -1
- cognee/infrastructure/llm/prompts/generate_event_graph_prompt.txt +1 -1
- cognee/infrastructure/llm/prompts/generate_graph_prompt.txt +2 -2
- cognee/infrastructure/llm/prompts/generate_graph_prompt_guided.txt +1 -1
- cognee/infrastructure/llm/prompts/generate_graph_prompt_oneshot.txt +2 -2
- cognee/infrastructure/llm/prompts/generate_graph_prompt_simple.txt +1 -1
- cognee/infrastructure/llm/prompts/generate_graph_prompt_strict.txt +1 -1
- cognee/infrastructure/llm/prompts/search_type_selector_prompt.txt +6 -6
- cognee/infrastructure/llm/prompts/test.txt +1 -1
- cognee/infrastructure/llm/prompts/translate_content.txt +19 -0
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py +17 -12
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py +31 -25
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py +132 -7
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py +29 -5
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/llama_cpp/adapter.py +191 -0
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/llm_interface.py +2 -6
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py +58 -13
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/adapter.py +0 -1
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py +25 -131
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/types.py +10 -0
- cognee/modules/chunking/models/DocumentChunk.py +0 -1
- cognee/modules/cognify/config.py +2 -0
- cognee/modules/data/models/Data.py +3 -1
- cognee/modules/engine/models/Entity.py +0 -1
- cognee/modules/engine/operations/setup.py +6 -0
- cognee/modules/graph/cognee_graph/CogneeGraph.py +150 -37
- cognee/modules/graph/cognee_graph/CogneeGraphElements.py +48 -2
- cognee/modules/graph/utils/__init__.py +1 -0
- cognee/modules/graph/utils/get_entity_nodes_from_triplets.py +12 -0
- cognee/modules/notebooks/methods/__init__.py +1 -0
- cognee/modules/notebooks/methods/create_notebook.py +0 -34
- cognee/modules/notebooks/methods/create_tutorial_notebooks.py +191 -0
- cognee/modules/notebooks/methods/get_notebooks.py +12 -8
- cognee/modules/notebooks/tutorials/cognee-basics/cell-1.md +3 -0
- cognee/modules/notebooks/tutorials/cognee-basics/cell-2.md +10 -0
- cognee/modules/notebooks/tutorials/cognee-basics/cell-3.md +7 -0
- cognee/modules/notebooks/tutorials/cognee-basics/cell-4.py +28 -0
- cognee/modules/notebooks/tutorials/cognee-basics/cell-5.py +3 -0
- cognee/modules/notebooks/tutorials/cognee-basics/cell-6.py +9 -0
- cognee/modules/notebooks/tutorials/cognee-basics/cell-7.py +17 -0
- cognee/modules/notebooks/tutorials/cognee-basics/config.json +4 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-1.md +3 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-10.md +3 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-11.md +3 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-12.py +3 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-13.md +7 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-14.py +6 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-15.md +3 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-16.py +7 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-2.md +9 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-3.md +7 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-4.md +9 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-5.md +5 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-6.py +13 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-7.md +3 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-8.md +3 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-9.py +31 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/config.json +4 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/data/copilot_conversations.json +107 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/data/guido_contributions.json +976 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/data/my_developer_rules.md +79 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/data/pep_style_guide.md +74 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/data/zen_principles.md +74 -0
- cognee/modules/retrieval/EntityCompletionRetriever.py +51 -38
- cognee/modules/retrieval/__init__.py +0 -1
- cognee/modules/retrieval/base_retriever.py +66 -10
- cognee/modules/retrieval/chunks_retriever.py +57 -49
- cognee/modules/retrieval/coding_rules_retriever.py +12 -5
- cognee/modules/retrieval/completion_retriever.py +29 -28
- cognee/modules/retrieval/cypher_search_retriever.py +25 -20
- cognee/modules/retrieval/graph_completion_context_extension_retriever.py +42 -46
- cognee/modules/retrieval/graph_completion_cot_retriever.py +68 -51
- cognee/modules/retrieval/graph_completion_retriever.py +78 -63
- cognee/modules/retrieval/graph_summary_completion_retriever.py +2 -0
- cognee/modules/retrieval/lexical_retriever.py +34 -12
- cognee/modules/retrieval/natural_language_retriever.py +18 -15
- cognee/modules/retrieval/summaries_retriever.py +51 -34
- cognee/modules/retrieval/temporal_retriever.py +59 -49
- cognee/modules/retrieval/triplet_retriever.py +32 -33
- cognee/modules/retrieval/utils/access_tracking.py +88 -0
- cognee/modules/retrieval/utils/brute_force_triplet_search.py +99 -103
- cognee/modules/retrieval/utils/node_edge_vector_search.py +174 -0
- cognee/modules/search/methods/__init__.py +1 -0
- cognee/modules/search/methods/get_retriever_output.py +53 -0
- cognee/modules/search/methods/get_search_type_retriever_instance.py +252 -0
- cognee/modules/search/methods/search.py +90 -222
- cognee/modules/search/models/SearchResultPayload.py +67 -0
- cognee/modules/search/types/SearchResult.py +1 -8
- cognee/modules/search/types/SearchType.py +1 -2
- cognee/modules/search/types/__init__.py +1 -1
- cognee/modules/search/utils/__init__.py +1 -2
- cognee/modules/search/utils/transform_insights_to_graph.py +2 -2
- cognee/modules/search/utils/{transform_context_to_graph.py → transform_triplets_to_graph.py} +2 -2
- cognee/modules/users/authentication/default/default_transport.py +11 -1
- cognee/modules/users/authentication/get_api_auth_backend.py +2 -1
- cognee/modules/users/authentication/get_client_auth_backend.py +2 -1
- cognee/modules/users/methods/create_user.py +0 -9
- cognee/modules/users/permissions/methods/has_user_management_permission.py +29 -0
- cognee/modules/visualization/cognee_network_visualization.py +1 -1
- cognee/run_migrations.py +48 -0
- cognee/shared/exceptions/__init__.py +1 -3
- cognee/shared/exceptions/exceptions.py +11 -1
- cognee/shared/usage_logger.py +332 -0
- cognee/shared/utils.py +12 -5
- cognee/tasks/chunks/__init__.py +9 -0
- cognee/tasks/cleanup/cleanup_unused_data.py +172 -0
- cognee/tasks/graph/__init__.py +7 -0
- cognee/tasks/ingestion/data_item.py +8 -0
- cognee/tasks/ingestion/ingest_data.py +12 -1
- cognee/tasks/ingestion/save_data_item_to_storage.py +5 -0
- cognee/tasks/memify/__init__.py +8 -0
- cognee/tasks/memify/extract_usage_frequency.py +613 -0
- cognee/tasks/summarization/models.py +0 -2
- cognee/tasks/temporal_graph/__init__.py +0 -1
- cognee/tasks/translation/__init__.py +96 -0
- cognee/tasks/translation/config.py +110 -0
- cognee/tasks/translation/detect_language.py +190 -0
- cognee/tasks/translation/exceptions.py +62 -0
- cognee/tasks/translation/models.py +72 -0
- cognee/tasks/translation/providers/__init__.py +44 -0
- cognee/tasks/translation/providers/azure_provider.py +192 -0
- cognee/tasks/translation/providers/base.py +85 -0
- cognee/tasks/translation/providers/google_provider.py +158 -0
- cognee/tasks/translation/providers/llm_provider.py +143 -0
- cognee/tasks/translation/translate_content.py +282 -0
- cognee/tasks/web_scraper/default_url_crawler.py +6 -2
- cognee/tests/cli_tests/cli_unit_tests/test_cli_commands.py +1 -0
- cognee/tests/cli_tests/cli_unit_tests/test_cli_edge_cases.py +3 -0
- cognee/tests/integration/retrieval/test_brute_force_triplet_search_with_cognify.py +62 -0
- cognee/tests/integration/retrieval/test_chunks_retriever.py +351 -0
- cognee/tests/integration/retrieval/test_graph_completion_retriever.py +276 -0
- cognee/tests/integration/retrieval/test_graph_completion_retriever_context_extension.py +228 -0
- cognee/tests/integration/retrieval/test_graph_completion_retriever_cot.py +217 -0
- cognee/tests/integration/retrieval/test_rag_completion_retriever.py +319 -0
- cognee/tests/integration/retrieval/test_structured_output.py +258 -0
- cognee/tests/integration/retrieval/test_summaries_retriever.py +195 -0
- cognee/tests/integration/retrieval/test_temporal_retriever.py +336 -0
- cognee/tests/integration/retrieval/test_triplet_retriever.py +45 -1
- cognee/tests/integration/shared/test_usage_logger_integration.py +255 -0
- cognee/tests/tasks/translation/README.md +147 -0
- cognee/tests/tasks/translation/__init__.py +1 -0
- cognee/tests/tasks/translation/config_test.py +93 -0
- cognee/tests/tasks/translation/detect_language_test.py +118 -0
- cognee/tests/tasks/translation/providers_test.py +151 -0
- cognee/tests/tasks/translation/translate_content_test.py +213 -0
- cognee/tests/test_chromadb.py +1 -1
- cognee/tests/test_cleanup_unused_data.py +165 -0
- cognee/tests/test_custom_data_label.py +68 -0
- cognee/tests/test_delete_by_id.py +6 -6
- cognee/tests/test_extract_usage_frequency.py +308 -0
- cognee/tests/test_kuzu.py +17 -7
- cognee/tests/test_lancedb.py +3 -1
- cognee/tests/test_library.py +1 -1
- cognee/tests/test_neo4j.py +17 -7
- cognee/tests/test_neptune_analytics_vector.py +3 -1
- cognee/tests/test_permissions.py +172 -187
- cognee/tests/test_pgvector.py +3 -1
- cognee/tests/test_relational_db_migration.py +15 -1
- cognee/tests/test_remote_kuzu.py +3 -1
- cognee/tests/test_s3_file_storage.py +1 -1
- cognee/tests/test_search_db.py +345 -205
- cognee/tests/test_usage_logger_e2e.py +268 -0
- cognee/tests/unit/api/test_get_raw_data_endpoint.py +206 -0
- cognee/tests/unit/eval_framework/answer_generation_test.py +4 -3
- cognee/tests/unit/eval_framework/benchmark_adapters_test.py +25 -0
- cognee/tests/unit/eval_framework/corpus_builder_test.py +33 -4
- cognee/tests/unit/infrastructure/databases/cache/test_cache_config.py +2 -0
- cognee/tests/unit/infrastructure/databases/relational/test_RelationalConfig.py +69 -0
- cognee/tests/unit/modules/graph/cognee_graph_elements_test.py +42 -2
- cognee/tests/unit/modules/graph/cognee_graph_test.py +329 -31
- cognee/tests/unit/modules/retrieval/chunks_retriever_test.py +122 -168
- cognee/tests/unit/modules/retrieval/conversation_history_test.py +338 -0
- cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py +486 -157
- cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py +693 -155
- cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py +619 -200
- cognee/tests/unit/modules/retrieval/rag_completion_retriever_test.py +300 -171
- cognee/tests/unit/modules/retrieval/summaries_retriever_test.py +184 -155
- cognee/tests/unit/modules/retrieval/temporal_retriever_test.py +544 -79
- cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py +476 -28
- cognee/tests/unit/modules/retrieval/test_completion.py +343 -0
- cognee/tests/unit/modules/retrieval/test_graph_summary_completion_retriever.py +157 -0
- cognee/tests/unit/modules/retrieval/test_node_edge_vector_search.py +273 -0
- cognee/tests/unit/modules/retrieval/test_user_qa_feedback.py +312 -0
- cognee/tests/unit/modules/retrieval/triplet_retriever_test.py +267 -7
- cognee/tests/unit/modules/search/test_get_search_type_retriever_instance.py +125 -0
- cognee/tests/unit/modules/search/test_search.py +96 -20
- cognee/tests/unit/modules/search/test_search_prepare_search_result_contract.py +190 -0
- cognee/tests/unit/modules/users/test_tutorial_notebook_creation.py +511 -297
- cognee/tests/unit/shared/test_usage_logger.py +241 -0
- cognee/tests/unit/users/permissions/test_has_user_management_permission.py +46 -0
- {cognee-0.5.1.dist-info → cognee-0.5.2.dist-info}/METADATA +22 -17
- {cognee-0.5.1.dist-info → cognee-0.5.2.dist-info}/RECORD +258 -157
- cognee/api/.env.example +0 -5
- cognee/modules/retrieval/base_graph_retriever.py +0 -24
- cognee/modules/search/methods/get_search_type_tools.py +0 -223
- cognee/modules/search/methods/no_access_control_search.py +0 -62
- cognee/modules/search/utils/prepare_search_result.py +0 -63
- cognee/tests/test_feedback_enrichment.py +0 -174
- cognee/tests/unit/modules/retrieval/structured_output_test.py +0 -204
- {cognee-0.5.1.dist-info → cognee-0.5.2.dist-info}/WHEEL +0 -0
- {cognee-0.5.1.dist-info → cognee-0.5.2.dist-info}/entry_points.txt +0 -0
- {cognee-0.5.1.dist-info → cognee-0.5.2.dist-info}/licenses/LICENSE +0 -0
- {cognee-0.5.1.dist-info → cognee-0.5.2.dist-info}/licenses/NOTICE.md +0 -0
|
@@ -1,7 +1,12 @@
|
|
|
1
1
|
from types import SimpleNamespace
|
|
2
2
|
import pytest
|
|
3
|
+
import os
|
|
4
|
+
from unittest.mock import AsyncMock, patch, MagicMock
|
|
5
|
+
from datetime import datetime
|
|
3
6
|
|
|
4
7
|
from cognee.modules.retrieval.temporal_retriever import TemporalRetriever
|
|
8
|
+
from cognee.tasks.temporal_graph.models import QueryInterval, Timestamp
|
|
9
|
+
from cognee.infrastructure.llm import LLMGateway
|
|
5
10
|
|
|
6
11
|
|
|
7
12
|
# Test TemporalRetriever initialization defaults and overrides
|
|
@@ -58,8 +63,8 @@ async def test_filter_top_k_events_sorts_and_limits():
|
|
|
58
63
|
]
|
|
59
64
|
|
|
60
65
|
scored_results = [
|
|
61
|
-
SimpleNamespace(payload={"id": "e2"}, score=0.10),
|
|
62
|
-
SimpleNamespace(payload={"id": "e1"}, score=0.20),
|
|
66
|
+
SimpleNamespace(id="e2", payload={"id": "e2"}, score=0.10),
|
|
67
|
+
SimpleNamespace(id="e1", payload={"id": "e1"}, score=0.20),
|
|
63
68
|
]
|
|
64
69
|
|
|
65
70
|
top = await tr.filter_top_k_events(relevant_events, scored_results)
|
|
@@ -86,8 +91,8 @@ async def test_filter_top_k_events_includes_unknown_as_infinite_but_not_in_top_k
|
|
|
86
91
|
]
|
|
87
92
|
|
|
88
93
|
scored_results = [
|
|
89
|
-
SimpleNamespace(payload={"id": "known2"}, score=0.05),
|
|
90
|
-
SimpleNamespace(payload={"id": "known1"}, score=0.50),
|
|
94
|
+
SimpleNamespace(id="known2", payload={"id": "known2"}, score=0.05),
|
|
95
|
+
SimpleNamespace(id="known1", payload={"id": "known1"}, score=0.50),
|
|
91
96
|
]
|
|
92
97
|
|
|
93
98
|
top = await tr.filter_top_k_events(relevant_events, scored_results)
|
|
@@ -114,8 +119,8 @@ async def test_filter_top_k_events_limits_when_top_k_exceeds_events():
|
|
|
114
119
|
tr = TemporalRetriever(top_k=10)
|
|
115
120
|
relevant_events = [{"events": [{"id": "a"}, {"id": "b"}]}]
|
|
116
121
|
scored_results = [
|
|
117
|
-
SimpleNamespace(payload={"id": "a"}, score=0.1),
|
|
118
|
-
SimpleNamespace(payload={"id": "b"}, score=0.2),
|
|
122
|
+
SimpleNamespace(id="a", payload={"id": "a"}, score=0.1),
|
|
123
|
+
SimpleNamespace(id="b", payload={"id": "b"}, score=0.2),
|
|
119
124
|
]
|
|
120
125
|
out = await tr.filter_top_k_events(relevant_events, scored_results)
|
|
121
126
|
assert [e["id"] for e in out] == ["a", "b"]
|
|
@@ -140,85 +145,545 @@ async def test_filter_top_k_events_error_handling():
|
|
|
140
145
|
await tr.filter_top_k_events([{}], [])
|
|
141
146
|
|
|
142
147
|
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
148
|
+
@pytest.fixture
|
|
149
|
+
def mock_graph_engine():
|
|
150
|
+
"""Create a mock graph engine."""
|
|
151
|
+
engine = AsyncMock()
|
|
152
|
+
engine.collect_time_ids = AsyncMock()
|
|
153
|
+
engine.collect_events = AsyncMock()
|
|
154
|
+
return engine
|
|
147
155
|
|
|
148
|
-
|
|
149
|
-
|
|
156
|
+
|
|
157
|
+
@pytest.fixture
|
|
158
|
+
def mock_vector_engine():
|
|
159
|
+
"""Create a mock vector engine."""
|
|
160
|
+
engine = AsyncMock()
|
|
161
|
+
engine.embedding_engine = AsyncMock()
|
|
162
|
+
engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
|
|
163
|
+
engine.search = AsyncMock()
|
|
164
|
+
return engine
|
|
165
|
+
|
|
166
|
+
|
|
167
|
+
@pytest.mark.asyncio
|
|
168
|
+
async def test_get_context_with_time_range(mock_graph_engine, mock_vector_engine):
|
|
169
|
+
"""Test get_context when time range is extracted from query."""
|
|
170
|
+
retriever = TemporalRetriever(top_k=5)
|
|
171
|
+
|
|
172
|
+
mock_graph_engine.collect_time_ids.return_value = ["e1", "e2"]
|
|
173
|
+
mock_graph_engine.collect_events.return_value = [
|
|
174
|
+
{
|
|
175
|
+
"events": [
|
|
176
|
+
{"id": "e1", "description": "Event 1"},
|
|
177
|
+
{"id": "e2", "description": "Event 2"},
|
|
178
|
+
]
|
|
179
|
+
}
|
|
180
|
+
]
|
|
181
|
+
|
|
182
|
+
mock_result1 = SimpleNamespace(id="e2", payload={"id": "e2"}, score=0.05)
|
|
183
|
+
mock_result2 = SimpleNamespace(id="e1", payload={"id": "e1"}, score=0.10)
|
|
184
|
+
mock_vector_engine.search.return_value = [mock_result1, mock_result2]
|
|
185
|
+
|
|
186
|
+
with (
|
|
187
|
+
patch.object(
|
|
188
|
+
retriever, "extract_time_from_query", return_value=("2024-01-01", "2024-12-31")
|
|
189
|
+
),
|
|
190
|
+
patch(
|
|
191
|
+
"cognee.modules.retrieval.temporal_retriever.get_graph_engine",
|
|
192
|
+
return_value=mock_graph_engine,
|
|
193
|
+
),
|
|
194
|
+
patch(
|
|
195
|
+
"cognee.modules.retrieval.temporal_retriever.get_vector_engine",
|
|
196
|
+
return_value=mock_vector_engine,
|
|
197
|
+
),
|
|
198
|
+
):
|
|
199
|
+
objects = await retriever.get_retrieved_objects("What happened in 2024?")
|
|
200
|
+
context = await retriever.get_context_from_objects("What happened in 2024?", objects)
|
|
201
|
+
|
|
202
|
+
assert isinstance(context, str)
|
|
203
|
+
assert len(context) > 0
|
|
204
|
+
assert "Event" in context
|
|
205
|
+
|
|
206
|
+
|
|
207
|
+
@pytest.mark.asyncio
|
|
208
|
+
async def test_get_context_fallback_to_triplets_no_time(mock_graph_engine):
|
|
209
|
+
"""Test get_context falls back to triplets when no time is extracted."""
|
|
210
|
+
retriever = TemporalRetriever()
|
|
211
|
+
|
|
212
|
+
with (
|
|
213
|
+
patch(
|
|
214
|
+
"cognee.modules.retrieval.temporal_retriever.get_graph_engine",
|
|
215
|
+
return_value=mock_graph_engine,
|
|
216
|
+
),
|
|
217
|
+
patch.object(
|
|
218
|
+
retriever, "get_triplets", return_value=[{"s": "a", "p": "b", "o": "c"}]
|
|
219
|
+
) as mock_get_triplets,
|
|
220
|
+
patch.object(
|
|
221
|
+
retriever, "resolve_edges_to_text", return_value="triplet text"
|
|
222
|
+
) as mock_resolve,
|
|
223
|
+
):
|
|
224
|
+
|
|
225
|
+
async def mock_extract_time(query):
|
|
226
|
+
return None, None
|
|
227
|
+
|
|
228
|
+
retriever.extract_time_from_query = mock_extract_time
|
|
229
|
+
|
|
230
|
+
objects = await retriever.get_retrieved_objects("test query")
|
|
231
|
+
context = await retriever.get_context_from_objects("test query", objects)
|
|
232
|
+
|
|
233
|
+
assert context == "triplet text"
|
|
234
|
+
mock_get_triplets.assert_awaited_once_with("test query")
|
|
235
|
+
mock_resolve.assert_awaited_once()
|
|
236
|
+
|
|
237
|
+
|
|
238
|
+
@pytest.mark.asyncio
|
|
239
|
+
async def test_get_context_no_events_found(mock_graph_engine):
|
|
240
|
+
"""Test get_context falls back to triplets when no events are found."""
|
|
241
|
+
retriever = TemporalRetriever()
|
|
242
|
+
|
|
243
|
+
mock_graph_engine.collect_time_ids.return_value = []
|
|
244
|
+
|
|
245
|
+
with (
|
|
246
|
+
patch(
|
|
247
|
+
"cognee.modules.retrieval.temporal_retriever.get_graph_engine",
|
|
248
|
+
return_value=mock_graph_engine,
|
|
249
|
+
),
|
|
250
|
+
patch.object(
|
|
251
|
+
retriever, "get_triplets", return_value=[{"s": "a", "p": "b", "o": "c"}]
|
|
252
|
+
) as mock_get_triplets,
|
|
253
|
+
patch.object(
|
|
254
|
+
retriever, "resolve_edges_to_text", return_value="triplet text"
|
|
255
|
+
) as mock_resolve,
|
|
256
|
+
):
|
|
257
|
+
|
|
258
|
+
async def mock_extract_time(query):
|
|
150
259
|
return "2024-01-01", "2024-12-31"
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
"
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
|
|
260
|
+
|
|
261
|
+
retriever.extract_time_from_query = mock_extract_time
|
|
262
|
+
|
|
263
|
+
objects = await retriever.get_retrieved_objects("test query")
|
|
264
|
+
context = await retriever.get_context_from_objects("test query", objects)
|
|
265
|
+
|
|
266
|
+
assert context == "triplet text"
|
|
267
|
+
mock_get_triplets.assert_awaited_once_with("test query")
|
|
268
|
+
mock_resolve.assert_awaited_once()
|
|
269
|
+
|
|
270
|
+
|
|
271
|
+
@pytest.mark.asyncio
|
|
272
|
+
async def test_get_context_time_from_only(mock_graph_engine, mock_vector_engine):
|
|
273
|
+
"""Test get_context with only time_from."""
|
|
274
|
+
retriever = TemporalRetriever(top_k=5)
|
|
275
|
+
|
|
276
|
+
mock_graph_engine.collect_time_ids.return_value = ["e1"]
|
|
277
|
+
mock_graph_engine.collect_events.return_value = [
|
|
278
|
+
{
|
|
279
|
+
"events": [
|
|
280
|
+
{"id": "e1", "description": "Event 1"},
|
|
281
|
+
]
|
|
282
|
+
}
|
|
283
|
+
]
|
|
284
|
+
|
|
285
|
+
mock_result = SimpleNamespace(id="e1", payload={"id": "e1"}, score=0.05)
|
|
286
|
+
mock_vector_engine.search.return_value = [mock_result]
|
|
287
|
+
|
|
288
|
+
with (
|
|
289
|
+
patch.object(retriever, "extract_time_from_query", return_value=("2024-01-01", None)),
|
|
290
|
+
patch(
|
|
291
|
+
"cognee.modules.retrieval.temporal_retriever.get_graph_engine",
|
|
292
|
+
return_value=mock_graph_engine,
|
|
293
|
+
),
|
|
294
|
+
patch(
|
|
295
|
+
"cognee.modules.retrieval.temporal_retriever.get_vector_engine",
|
|
296
|
+
return_value=mock_vector_engine,
|
|
297
|
+
),
|
|
298
|
+
):
|
|
299
|
+
objects = await retriever.get_retrieved_objects("What happened in 2024?")
|
|
300
|
+
context = await retriever.get_context_from_objects("What happened in 2024?", objects)
|
|
301
|
+
|
|
302
|
+
assert isinstance(context, str)
|
|
303
|
+
assert "Event 1" in context
|
|
304
|
+
|
|
305
|
+
|
|
306
|
+
@pytest.mark.asyncio
|
|
307
|
+
async def test_get_context_time_to_only(mock_graph_engine, mock_vector_engine):
|
|
308
|
+
"""Test get_context with only time_to."""
|
|
309
|
+
retriever = TemporalRetriever(top_k=5)
|
|
310
|
+
|
|
311
|
+
mock_graph_engine.collect_time_ids.return_value = ["e1"]
|
|
312
|
+
mock_graph_engine.collect_events.return_value = [
|
|
313
|
+
{
|
|
314
|
+
"events": [
|
|
315
|
+
{"id": "e1", "description": "Event 1"},
|
|
316
|
+
]
|
|
317
|
+
}
|
|
318
|
+
]
|
|
319
|
+
|
|
320
|
+
mock_result = SimpleNamespace(id="e1", payload={"id": "e1"}, score=0.05)
|
|
321
|
+
mock_vector_engine.search.return_value = [mock_result]
|
|
322
|
+
|
|
323
|
+
with (
|
|
324
|
+
patch.object(retriever, "extract_time_from_query", return_value=(None, "2024-12-31")),
|
|
325
|
+
patch(
|
|
326
|
+
"cognee.modules.retrieval.temporal_retriever.get_graph_engine",
|
|
327
|
+
return_value=mock_graph_engine,
|
|
328
|
+
),
|
|
329
|
+
patch(
|
|
330
|
+
"cognee.modules.retrieval.temporal_retriever.get_vector_engine",
|
|
331
|
+
return_value=mock_vector_engine,
|
|
332
|
+
),
|
|
333
|
+
):
|
|
334
|
+
objects = await retriever.get_retrieved_objects("What happened in 2024?")
|
|
335
|
+
context = await retriever.get_context_from_objects("What happened in 2024?", objects)
|
|
336
|
+
|
|
337
|
+
assert isinstance(context, str)
|
|
338
|
+
assert "Event 1" in context
|
|
339
|
+
|
|
340
|
+
|
|
341
|
+
@pytest.mark.asyncio
|
|
342
|
+
async def test_get_completion_without_context(mock_graph_engine, mock_vector_engine):
|
|
343
|
+
"""Test get_completion retrieves context when not provided."""
|
|
344
|
+
retriever = TemporalRetriever()
|
|
345
|
+
|
|
346
|
+
mock_graph_engine.collect_time_ids.return_value = ["e1"]
|
|
347
|
+
mock_graph_engine.collect_events.return_value = [
|
|
348
|
+
{
|
|
349
|
+
"events": [
|
|
350
|
+
{"id": "e1", "description": "Event 1"},
|
|
351
|
+
]
|
|
352
|
+
}
|
|
353
|
+
]
|
|
354
|
+
|
|
355
|
+
mock_result = SimpleNamespace(id="e1", payload={"id": "e1"}, score=0.05)
|
|
356
|
+
mock_vector_engine.search.return_value = [mock_result]
|
|
357
|
+
|
|
358
|
+
with (
|
|
359
|
+
patch.object(
|
|
360
|
+
retriever, "extract_time_from_query", return_value=("2024-01-01", "2024-12-31")
|
|
361
|
+
),
|
|
362
|
+
patch(
|
|
363
|
+
"cognee.modules.retrieval.temporal_retriever.get_graph_engine",
|
|
364
|
+
return_value=mock_graph_engine,
|
|
365
|
+
),
|
|
366
|
+
patch(
|
|
367
|
+
"cognee.modules.retrieval.temporal_retriever.get_vector_engine",
|
|
368
|
+
return_value=mock_vector_engine,
|
|
369
|
+
),
|
|
370
|
+
patch(
|
|
371
|
+
"cognee.modules.retrieval.temporal_retriever.generate_completion",
|
|
372
|
+
return_value="Generated answer",
|
|
373
|
+
),
|
|
374
|
+
patch("cognee.modules.retrieval.temporal_retriever.CacheConfig") as mock_cache_config,
|
|
375
|
+
):
|
|
376
|
+
mock_config = MagicMock()
|
|
377
|
+
mock_config.caching = False
|
|
378
|
+
mock_cache_config.return_value = mock_config
|
|
379
|
+
|
|
380
|
+
objects = await retriever.get_retrieved_objects("What happened in 2024?")
|
|
381
|
+
context = await retriever.get_context_from_objects("What happened in 2024?", objects)
|
|
382
|
+
completion = await retriever.get_completion_from_context(
|
|
383
|
+
"What happened in 2024?", objects, context=context
|
|
384
|
+
)
|
|
385
|
+
|
|
386
|
+
assert isinstance(completion, list)
|
|
387
|
+
assert len(completion) == 1
|
|
388
|
+
assert completion[0] == "Generated answer"
|
|
389
|
+
|
|
390
|
+
|
|
391
|
+
@pytest.mark.asyncio
|
|
392
|
+
async def test_get_completion_with_provided_context():
|
|
393
|
+
"""Test get_completion uses provided context."""
|
|
394
|
+
retriever = TemporalRetriever()
|
|
395
|
+
|
|
396
|
+
with (
|
|
397
|
+
patch(
|
|
398
|
+
"cognee.modules.retrieval.temporal_retriever.generate_completion",
|
|
399
|
+
return_value="Generated answer",
|
|
400
|
+
),
|
|
401
|
+
patch("cognee.modules.retrieval.temporal_retriever.CacheConfig") as mock_cache_config,
|
|
402
|
+
):
|
|
403
|
+
mock_config = MagicMock()
|
|
404
|
+
mock_config.caching = False
|
|
405
|
+
mock_cache_config.return_value = mock_config
|
|
406
|
+
|
|
407
|
+
objects = await retriever.get_retrieved_objects("What happened in 2024?")
|
|
408
|
+
await retriever.get_context_from_objects("What happened in 2024?", objects)
|
|
409
|
+
completion = await retriever.get_completion_from_context(
|
|
410
|
+
"test query", objects, context="Provided context"
|
|
411
|
+
)
|
|
412
|
+
|
|
413
|
+
assert isinstance(completion, list)
|
|
414
|
+
assert len(completion) == 1
|
|
415
|
+
assert completion[0] == "Generated answer"
|
|
416
|
+
|
|
417
|
+
|
|
418
|
+
@pytest.mark.asyncio
async def test_get_completion_with_session(mock_graph_engine, mock_vector_engine):
    """Completion with session caching enabled and a known user saves the exchange.

    Every collaborator of the temporal retriever is patched: time extraction, the
    graph and vector engines, conversation-history lookup, context summarization,
    completion generation, history saving, the cache config, and the session user.
    The test checks that exactly one generated answer comes back and that the
    conversation history is saved exactly once.
    """
    question = "What happened in 2024?"
    retriever = TemporalRetriever(session_id="test_session")

    # Graph side: a single time id that resolves to a single event.
    single_event = {"id": "e1", "description": "Event 1"}
    mock_graph_engine.collect_time_ids.return_value = ["e1"]
    mock_graph_engine.collect_events.return_value = [{"events": [single_event]}]

    # Vector side: one close hit for that same event.
    vector_hit = SimpleNamespace(id="e1", payload={"id": "e1"}, score=0.05)
    mock_vector_engine.search.return_value = [vector_hit]

    fake_user = MagicMock()
    fake_user.id = "test-user-id"

    extracted_window = ("2024-01-01", "2024-12-31")

    with (
        patch.object(retriever, "extract_time_from_query", return_value=extracted_window),
        patch(
            "cognee.modules.retrieval.temporal_retriever.get_graph_engine",
            return_value=mock_graph_engine,
        ),
        patch(
            "cognee.modules.retrieval.temporal_retriever.get_vector_engine",
            return_value=mock_vector_engine,
        ),
        patch(
            "cognee.modules.retrieval.temporal_retriever.get_conversation_history",
            return_value="Previous conversation",
        ),
        patch(
            "cognee.modules.retrieval.temporal_retriever.summarize_text",
            return_value="Context summary",
        ),
        patch(
            "cognee.modules.retrieval.temporal_retriever.generate_completion",
            return_value="Generated answer",
        ),
        patch(
            "cognee.modules.retrieval.temporal_retriever.save_conversation_history",
        ) as history_saver,
        patch("cognee.modules.retrieval.temporal_retriever.CacheConfig") as cache_config_cls,
        patch("cognee.modules.retrieval.temporal_retriever.session_user") as user_var,
    ):
        cache_settings = MagicMock()
        cache_settings.caching = True
        cache_config_cls.return_value = cache_settings
        user_var.get.return_value = fake_user

        retrieved = await retriever.get_retrieved_objects(question)
        built_context = await retriever.get_context_from_objects(question, retrieved)
        answers = await retriever.get_completion_from_context(question, retrieved, built_context)

        assert isinstance(answers, list)
        assert len(answers) == 1
        assert answers[0] == "Generated answer"
        history_saver.assert_awaited_once()
|
|
483
|
+
|
|
484
|
+
|
|
485
|
+
@pytest.mark.asyncio
async def test_get_completion_with_session_no_user_id(mock_graph_engine, mock_vector_engine):
    """Test get_completion with session caching enabled but no session user.

    Caching is switched on in the (mocked) cache config while `session_user.get()`
    returns None. The completion must still succeed and return the generated
    answer, exactly as the other completion tests in this module require.
    """
    retriever = TemporalRetriever()

    mock_graph_engine.collect_time_ids.return_value = ["e1"]
    mock_graph_engine.collect_events.return_value = [
        {
            "events": [
                {"id": "e1", "description": "Event 1"},
            ]
        }
    ]

    mock_result = SimpleNamespace(id="e1", payload={"id": "e1"}, score=0.05)
    mock_vector_engine.search.return_value = [mock_result]

    with (
        patch.object(
            retriever, "extract_time_from_query", return_value=("2024-01-01", "2024-12-31")
        ),
        patch(
            "cognee.modules.retrieval.temporal_retriever.get_graph_engine",
            return_value=mock_graph_engine,
        ),
        patch(
            "cognee.modules.retrieval.temporal_retriever.get_vector_engine",
            return_value=mock_vector_engine,
        ),
        patch(
            "cognee.modules.retrieval.temporal_retriever.generate_completion",
            return_value="Generated answer",
        ),
        patch("cognee.modules.retrieval.temporal_retriever.CacheConfig") as mock_cache_config,
        patch("cognee.modules.retrieval.temporal_retriever.session_user") as mock_session_user,
    ):
        mock_config = MagicMock()
        mock_config.caching = True
        mock_cache_config.return_value = mock_config
        mock_session_user.get.return_value = None  # No user

        objects = await retriever.get_retrieved_objects("What happened in 2024?")
        context = await retriever.get_context_from_objects("What happened in 2024?", objects)
        completion = await retriever.get_completion_from_context(
            "What happened in 2024?", objects, context
        )

        assert isinstance(completion, list)
        assert len(completion) == 1
        # Previously only the shape was checked; pin the actual answer as well so a
        # wrong (but single-element) result no longer passes.
        assert completion[0] == "Generated answer"
|
|
205
534
|
|
|
206
535
|
|
|
207
|
-
# Test get_context fallback to triplets when no time is extracted
|
|
208
536
|
@pytest.mark.asyncio
async def test_get_completion_with_response_model(mock_graph_engine, mock_vector_engine):
    """A structured (pydantic) completion result is returned as the single answer.

    `generate_completion` is patched to produce a `TestModel` instance, and the test
    checks that the retriever hands that object back inside a one-element list.
    """
    from pydantic import BaseModel

    class TestModel(BaseModel):
        answer: str

    question = "What happened in 2024?"
    # NOTE(review): the retriever is constructed without any reference to TestModel,
    # so this only verifies that a structured object returned by generate_completion
    # is passed through. It does not check how a response model is configured or
    # forwarded; confirm whether that was intended.
    retriever = TemporalRetriever()

    single_event = {"id": "e1", "description": "Event 1"}
    mock_graph_engine.collect_time_ids.return_value = ["e1"]
    mock_graph_engine.collect_events.return_value = [{"events": [single_event]}]

    vector_hit = SimpleNamespace(id="e1", payload={"id": "e1"}, score=0.05)
    mock_vector_engine.search.return_value = [vector_hit]

    structured_reply = TestModel(answer="Test answer")

    with (
        patch.object(
            retriever, "extract_time_from_query", return_value=("2024-01-01", "2024-12-31")
        ),
        patch(
            "cognee.modules.retrieval.temporal_retriever.get_graph_engine",
            return_value=mock_graph_engine,
        ),
        patch(
            "cognee.modules.retrieval.temporal_retriever.get_vector_engine",
            return_value=mock_vector_engine,
        ),
        patch(
            "cognee.modules.retrieval.temporal_retriever.generate_completion",
            return_value=structured_reply,
        ),
        patch("cognee.modules.retrieval.temporal_retriever.CacheConfig") as cache_config_cls,
    ):
        cache_settings = MagicMock()
        cache_settings.caching = False
        cache_config_cls.return_value = cache_settings

        retrieved = await retriever.get_retrieved_objects(question)
        built_context = await retriever.get_context_from_objects(question, retrieved)
        answers = await retriever.get_completion_from_context(question, retrieved, built_context)

        assert isinstance(answers, list)
        assert len(answers) == 1
        assert isinstance(answers[0], TestModel)
|
|
589
|
+
|
|
590
|
+
|
|
591
|
+
@pytest.mark.asyncio
async def test_extract_time_from_query_relative_path():
    """A relative prompt path yields the interval bounds produced by the LLM.

    `os.path.isabs` is forced to False so the relative-path branch is taken. Prompt
    rendering, the current date and the structured LLM call are all patched, and the
    returned (from, to) pair must equal the mocked interval's bounds.
    """
    retriever = TemporalRetriever(time_extraction_prompt_path="extract_query_time.txt")

    expected_start = Timestamp(year=2024, month=1, day=1)
    expected_end = Timestamp(year=2024, month=12, day=31)
    llm_interval = QueryInterval(starts_at=expected_start, ends_at=expected_end)

    with (
        patch("cognee.modules.retrieval.temporal_retriever.os.path.isabs", return_value=False),
        patch("cognee.modules.retrieval.temporal_retriever.datetime") as fake_datetime,
        patch(
            "cognee.modules.retrieval.temporal_retriever.render_prompt",
            return_value="System prompt",
        ),
        patch.object(
            LLMGateway,
            "acreate_structured_output",
            new_callable=AsyncMock,
            return_value=llm_interval,
        ),
    ):
        fake_datetime.now.return_value.strftime.return_value = "11-12-2024"

        extracted = await retriever.extract_time_from_query("What happened in 2024?")
        start, end = extracted

        assert start == expected_start
        assert end == expected_end
|
|
620
|
+
|
|
621
|
+
|
|
622
|
+
@pytest.mark.asyncio
async def test_extract_time_from_query_absolute_path():
    """An absolute prompt path yields the interval bounds produced by the LLM.

    `os.path.isabs` is forced to True, and `dirname`/`basename` are patched to split
    the path into its directory and file name. Prompt rendering, the current date and
    the structured LLM call are also patched, and the returned (from, to) pair must
    equal the mocked interval's bounds.
    """
    prompt_location = "/absolute/path/to/extract_query_time.txt"
    retriever = TemporalRetriever(time_extraction_prompt_path=prompt_location)

    expected_start = Timestamp(year=2024, month=1, day=1)
    expected_end = Timestamp(year=2024, month=12, day=31)
    llm_interval = QueryInterval(starts_at=expected_start, ends_at=expected_end)

    with (
        patch("cognee.modules.retrieval.temporal_retriever.os.path.isabs", return_value=True),
        patch(
            "cognee.modules.retrieval.temporal_retriever.os.path.dirname",
            return_value="/absolute/path/to",
        ),
        patch(
            "cognee.modules.retrieval.temporal_retriever.os.path.basename",
            return_value="extract_query_time.txt",
        ),
        patch("cognee.modules.retrieval.temporal_retriever.datetime") as fake_datetime,
        patch(
            "cognee.modules.retrieval.temporal_retriever.render_prompt",
            return_value="System prompt",
        ),
        patch.object(
            LLMGateway,
            "acreate_structured_output",
            new_callable=AsyncMock,
            return_value=llm_interval,
        ),
    ):
        fake_datetime.now.return_value.strftime.return_value = "11-12-2024"

        extracted = await retriever.extract_time_from_query("What happened in 2024?")
        start, end = extracted

        assert start == expected_start
        assert end == expected_end
|
|
215
661
|
|
|
216
662
|
|
|
217
|
-
# Test get_context when time is extracted and vector ranking is applied
|
|
218
663
|
@pytest.mark.asyncio
async def test_extract_time_from_query_with_none_values():
    """An interval with no start and no end is returned as (None, None).

    The structured LLM call is patched to return a `QueryInterval` whose bounds are
    both None. Both values must come back unchanged from `extract_time_from_query`.
    """
    retriever = TemporalRetriever(time_extraction_prompt_path="extract_query_time.txt")

    open_interval = QueryInterval(starts_at=None, ends_at=None)

    with (
        patch("cognee.modules.retrieval.temporal_retriever.os.path.isabs", return_value=False),
        patch("cognee.modules.retrieval.temporal_retriever.datetime") as fake_datetime,
        patch(
            "cognee.modules.retrieval.temporal_retriever.render_prompt",
            return_value="System prompt",
        ),
        patch.object(
            LLMGateway,
            "acreate_structured_output",
            new_callable=AsyncMock,
            return_value=open_interval,
        ),
    ):
        fake_datetime.now.return_value.strftime.return_value = "11-12-2024"

        extracted = await retriever.extract_time_from_query("What happened?")
        start, end = extracted

        assert start is None
        assert end is None
|