cognee 0.5.1.dev0__py3-none-any.whl → 0.5.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cognee/__init__.py +2 -0
- cognee/alembic/README +1 -0
- cognee/alembic/env.py +107 -0
- cognee/alembic/script.py.mako +26 -0
- cognee/alembic/versions/1a58b986e6e1_enable_delete_for_old_tutorial_notebooks.py +52 -0
- cognee/alembic/versions/1d0bb7fede17_add_pipeline_run_status.py +33 -0
- cognee/alembic/versions/1daae0df1866_incremental_loading.py +48 -0
- cognee/alembic/versions/211ab850ef3d_add_sync_operations_table.py +118 -0
- cognee/alembic/versions/45957f0a9849_add_notebook_table.py +46 -0
- cognee/alembic/versions/46a6ce2bd2b2_expand_dataset_database_with_json_.py +333 -0
- cognee/alembic/versions/482cd6517ce4_add_default_user.py +30 -0
- cognee/alembic/versions/76625596c5c3_expand_dataset_database_for_multi_user.py +98 -0
- cognee/alembic/versions/8057ae7329c2_initial_migration.py +25 -0
- cognee/alembic/versions/9e7a3cb85175_loader_separation.py +104 -0
- cognee/alembic/versions/a1b2c3d4e5f6_add_label_column_to_data.py +38 -0
- cognee/alembic/versions/ab7e313804ae_permission_system_rework.py +236 -0
- cognee/alembic/versions/b9274c27a25a_kuzu_11_migration.py +75 -0
- cognee/alembic/versions/c946955da633_multi_tenant_support.py +137 -0
- cognee/alembic/versions/e1ec1dcb50b6_add_last_accessed_to_data.py +51 -0
- cognee/alembic/versions/e4ebee1091e7_expand_data_model_info.py +140 -0
- cognee/alembic.ini +117 -0
- cognee/api/v1/add/routers/get_add_router.py +2 -0
- cognee/api/v1/cognify/cognify.py +11 -6
- cognee/api/v1/cognify/routers/get_cognify_router.py +8 -0
- cognee/api/v1/config/config.py +60 -0
- cognee/api/v1/datasets/routers/get_datasets_router.py +45 -3
- cognee/api/v1/memify/routers/get_memify_router.py +2 -0
- cognee/api/v1/search/routers/get_search_router.py +21 -6
- cognee/api/v1/search/search.py +25 -5
- cognee/api/v1/sync/routers/get_sync_router.py +3 -3
- cognee/cli/commands/add_command.py +1 -1
- cognee/cli/commands/cognify_command.py +6 -0
- cognee/cli/commands/config_command.py +1 -1
- cognee/context_global_variables.py +5 -1
- cognee/eval_framework/answer_generation/answer_generation_executor.py +7 -8
- cognee/infrastructure/databases/cache/cache_db_interface.py +38 -1
- cognee/infrastructure/databases/cache/config.py +6 -0
- cognee/infrastructure/databases/cache/fscache/FsCacheAdapter.py +21 -0
- cognee/infrastructure/databases/cache/get_cache_engine.py +9 -3
- cognee/infrastructure/databases/cache/redis/RedisAdapter.py +60 -1
- cognee/infrastructure/databases/dataset_database_handler/supported_dataset_database_handlers.py +7 -0
- cognee/infrastructure/databases/graph/get_graph_engine.py +29 -1
- cognee/infrastructure/databases/graph/neo4j_driver/Neo4jAuraDevDatasetDatabaseHandler.py +62 -27
- cognee/infrastructure/databases/hybrid/neptune_analytics/NeptuneAnalyticsAdapter.py +17 -4
- cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py +2 -1
- cognee/infrastructure/databases/vector/chromadb/ChromaDBAdapter.py +2 -0
- cognee/infrastructure/databases/vector/config.py +6 -0
- cognee/infrastructure/databases/vector/create_vector_engine.py +69 -22
- cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py +64 -9
- cognee/infrastructure/databases/vector/embeddings/OllamaEmbeddingEngine.py +13 -2
- cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py +16 -3
- cognee/infrastructure/databases/vector/models/ScoredResult.py +3 -3
- cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py +16 -3
- cognee/infrastructure/databases/vector/pgvector/PGVectorDatasetDatabaseHandler.py +86 -0
- cognee/infrastructure/databases/vector/pgvector/create_db_and_tables.py +81 -2
- cognee/infrastructure/databases/vector/vector_db_interface.py +8 -0
- cognee/infrastructure/files/utils/get_data_file_path.py +33 -27
- cognee/infrastructure/llm/prompts/extract_query_time.txt +1 -1
- cognee/infrastructure/llm/prompts/generate_event_entity_prompt.txt +1 -1
- cognee/infrastructure/llm/prompts/generate_event_graph_prompt.txt +1 -1
- cognee/infrastructure/llm/prompts/generate_graph_prompt.txt +2 -2
- cognee/infrastructure/llm/prompts/generate_graph_prompt_guided.txt +1 -1
- cognee/infrastructure/llm/prompts/generate_graph_prompt_oneshot.txt +2 -2
- cognee/infrastructure/llm/prompts/generate_graph_prompt_simple.txt +1 -1
- cognee/infrastructure/llm/prompts/generate_graph_prompt_strict.txt +1 -1
- cognee/infrastructure/llm/prompts/search_type_selector_prompt.txt +6 -6
- cognee/infrastructure/llm/prompts/test.txt +1 -1
- cognee/infrastructure/llm/prompts/translate_content.txt +19 -0
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py +24 -0
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/llama_cpp/adapter.py +191 -0
- cognee/modules/chunking/models/DocumentChunk.py +0 -1
- cognee/modules/cognify/config.py +2 -0
- cognee/modules/data/models/Data.py +1 -0
- cognee/modules/engine/models/Entity.py +0 -1
- cognee/modules/engine/operations/setup.py +6 -0
- cognee/modules/graph/cognee_graph/CogneeGraph.py +150 -37
- cognee/modules/graph/cognee_graph/CogneeGraphElements.py +48 -2
- cognee/modules/graph/utils/__init__.py +1 -0
- cognee/modules/graph/utils/get_entity_nodes_from_triplets.py +12 -0
- cognee/modules/notebooks/methods/__init__.py +1 -0
- cognee/modules/notebooks/methods/create_notebook.py +0 -34
- cognee/modules/notebooks/methods/create_tutorial_notebooks.py +191 -0
- cognee/modules/notebooks/methods/get_notebooks.py +12 -8
- cognee/modules/notebooks/tutorials/cognee-basics/cell-1.md +3 -0
- cognee/modules/notebooks/tutorials/cognee-basics/cell-2.md +10 -0
- cognee/modules/notebooks/tutorials/cognee-basics/cell-3.md +7 -0
- cognee/modules/notebooks/tutorials/cognee-basics/cell-4.py +28 -0
- cognee/modules/notebooks/tutorials/cognee-basics/cell-5.py +3 -0
- cognee/modules/notebooks/tutorials/cognee-basics/cell-6.py +9 -0
- cognee/modules/notebooks/tutorials/cognee-basics/cell-7.py +17 -0
- cognee/modules/notebooks/tutorials/cognee-basics/config.json +4 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-1.md +3 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-10.md +3 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-11.md +3 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-12.py +3 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-13.md +7 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-14.py +6 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-15.md +3 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-16.py +7 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-2.md +9 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-3.md +7 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-4.md +9 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-5.md +5 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-6.py +13 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-7.md +3 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-8.md +3 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-9.py +31 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/config.json +4 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/data/copilot_conversations.json +107 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/data/guido_contributions.json +976 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/data/my_developer_rules.md +79 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/data/pep_style_guide.md +74 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/data/zen_principles.md +74 -0
- cognee/modules/retrieval/EntityCompletionRetriever.py +51 -38
- cognee/modules/retrieval/__init__.py +0 -1
- cognee/modules/retrieval/base_retriever.py +66 -10
- cognee/modules/retrieval/chunks_retriever.py +57 -49
- cognee/modules/retrieval/coding_rules_retriever.py +12 -5
- cognee/modules/retrieval/completion_retriever.py +29 -28
- cognee/modules/retrieval/cypher_search_retriever.py +25 -20
- cognee/modules/retrieval/graph_completion_context_extension_retriever.py +42 -46
- cognee/modules/retrieval/graph_completion_cot_retriever.py +68 -51
- cognee/modules/retrieval/graph_completion_retriever.py +78 -63
- cognee/modules/retrieval/graph_summary_completion_retriever.py +2 -0
- cognee/modules/retrieval/lexical_retriever.py +34 -12
- cognee/modules/retrieval/natural_language_retriever.py +18 -15
- cognee/modules/retrieval/summaries_retriever.py +51 -34
- cognee/modules/retrieval/temporal_retriever.py +59 -49
- cognee/modules/retrieval/triplet_retriever.py +31 -32
- cognee/modules/retrieval/utils/access_tracking.py +88 -0
- cognee/modules/retrieval/utils/brute_force_triplet_search.py +99 -85
- cognee/modules/retrieval/utils/node_edge_vector_search.py +174 -0
- cognee/modules/search/methods/__init__.py +1 -0
- cognee/modules/search/methods/get_retriever_output.py +53 -0
- cognee/modules/search/methods/get_search_type_retriever_instance.py +252 -0
- cognee/modules/search/methods/search.py +90 -215
- cognee/modules/search/models/SearchResultPayload.py +67 -0
- cognee/modules/search/types/SearchResult.py +1 -8
- cognee/modules/search/types/SearchType.py +1 -2
- cognee/modules/search/types/__init__.py +1 -1
- cognee/modules/search/utils/__init__.py +1 -2
- cognee/modules/search/utils/transform_insights_to_graph.py +2 -2
- cognee/modules/search/utils/{transform_context_to_graph.py → transform_triplets_to_graph.py} +2 -2
- cognee/modules/users/authentication/default/default_transport.py +11 -1
- cognee/modules/users/authentication/get_api_auth_backend.py +2 -1
- cognee/modules/users/authentication/get_client_auth_backend.py +2 -1
- cognee/modules/users/methods/create_user.py +0 -9
- cognee/modules/users/permissions/methods/has_user_management_permission.py +29 -0
- cognee/modules/visualization/cognee_network_visualization.py +1 -1
- cognee/run_migrations.py +48 -0
- cognee/shared/exceptions/__init__.py +1 -3
- cognee/shared/exceptions/exceptions.py +11 -1
- cognee/shared/usage_logger.py +332 -0
- cognee/shared/utils.py +12 -5
- cognee/tasks/chunks/__init__.py +9 -0
- cognee/tasks/cleanup/cleanup_unused_data.py +172 -0
- cognee/tasks/graph/__init__.py +7 -0
- cognee/tasks/memify/__init__.py +8 -0
- cognee/tasks/memify/extract_usage_frequency.py +613 -0
- cognee/tasks/summarization/models.py +0 -2
- cognee/tasks/temporal_graph/__init__.py +0 -1
- cognee/tasks/translation/__init__.py +96 -0
- cognee/tasks/translation/config.py +110 -0
- cognee/tasks/translation/detect_language.py +190 -0
- cognee/tasks/translation/exceptions.py +62 -0
- cognee/tasks/translation/models.py +72 -0
- cognee/tasks/translation/providers/__init__.py +44 -0
- cognee/tasks/translation/providers/azure_provider.py +192 -0
- cognee/tasks/translation/providers/base.py +85 -0
- cognee/tasks/translation/providers/google_provider.py +158 -0
- cognee/tasks/translation/providers/llm_provider.py +143 -0
- cognee/tasks/translation/translate_content.py +282 -0
- cognee/tasks/web_scraper/default_url_crawler.py +6 -2
- cognee/tests/cli_tests/cli_unit_tests/test_cli_commands.py +1 -0
- cognee/tests/cli_tests/cli_unit_tests/test_cli_edge_cases.py +3 -0
- cognee/tests/integration/retrieval/test_brute_force_triplet_search_with_cognify.py +62 -0
- cognee/tests/integration/retrieval/test_chunks_retriever.py +115 -16
- cognee/tests/integration/retrieval/test_graph_completion_retriever.py +13 -5
- cognee/tests/integration/retrieval/test_graph_completion_retriever_context_extension.py +22 -20
- cognee/tests/integration/retrieval/test_graph_completion_retriever_cot.py +23 -24
- cognee/tests/integration/retrieval/test_rag_completion_retriever.py +70 -5
- cognee/tests/integration/retrieval/test_structured_output.py +62 -18
- cognee/tests/integration/retrieval/test_summaries_retriever.py +20 -9
- cognee/tests/integration/retrieval/test_temporal_retriever.py +38 -8
- cognee/tests/integration/retrieval/test_triplet_retriever.py +13 -4
- cognee/tests/integration/shared/test_usage_logger_integration.py +255 -0
- cognee/tests/tasks/translation/README.md +147 -0
- cognee/tests/tasks/translation/__init__.py +1 -0
- cognee/tests/tasks/translation/config_test.py +93 -0
- cognee/tests/tasks/translation/detect_language_test.py +118 -0
- cognee/tests/tasks/translation/providers_test.py +151 -0
- cognee/tests/tasks/translation/translate_content_test.py +213 -0
- cognee/tests/test_chromadb.py +1 -1
- cognee/tests/test_cleanup_unused_data.py +165 -0
- cognee/tests/test_delete_by_id.py +6 -6
- cognee/tests/test_extract_usage_frequency.py +308 -0
- cognee/tests/test_kuzu.py +17 -7
- cognee/tests/test_lancedb.py +3 -1
- cognee/tests/test_library.py +1 -1
- cognee/tests/test_neo4j.py +17 -7
- cognee/tests/test_neptune_analytics_vector.py +3 -1
- cognee/tests/test_permissions.py +172 -187
- cognee/tests/test_pgvector.py +3 -1
- cognee/tests/test_relational_db_migration.py +15 -1
- cognee/tests/test_remote_kuzu.py +3 -1
- cognee/tests/test_s3_file_storage.py +1 -1
- cognee/tests/test_search_db.py +97 -110
- cognee/tests/test_usage_logger_e2e.py +268 -0
- cognee/tests/unit/api/test_get_raw_data_endpoint.py +206 -0
- cognee/tests/unit/eval_framework/answer_generation_test.py +4 -3
- cognee/tests/unit/infrastructure/databases/cache/test_cache_config.py +2 -0
- cognee/tests/unit/modules/graph/cognee_graph_elements_test.py +42 -2
- cognee/tests/unit/modules/graph/cognee_graph_test.py +329 -31
- cognee/tests/unit/modules/retrieval/chunks_retriever_test.py +31 -59
- cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py +70 -33
- cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py +72 -52
- cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py +27 -33
- cognee/tests/unit/modules/retrieval/rag_completion_retriever_test.py +28 -15
- cognee/tests/unit/modules/retrieval/summaries_retriever_test.py +37 -42
- cognee/tests/unit/modules/retrieval/temporal_retriever_test.py +48 -64
- cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py +263 -24
- cognee/tests/unit/modules/retrieval/test_node_edge_vector_search.py +273 -0
- cognee/tests/unit/modules/retrieval/triplet_retriever_test.py +30 -16
- cognee/tests/unit/modules/search/test_get_search_type_retriever_instance.py +125 -0
- cognee/tests/unit/modules/search/test_search.py +176 -0
- cognee/tests/unit/modules/search/test_search_prepare_search_result_contract.py +190 -0
- cognee/tests/unit/modules/users/test_tutorial_notebook_creation.py +511 -297
- cognee/tests/unit/shared/test_usage_logger.py +241 -0
- cognee/tests/unit/users/permissions/test_has_user_management_permission.py +46 -0
- {cognee-0.5.1.dev0.dist-info → cognee-0.5.2.dist-info}/METADATA +22 -17
- {cognee-0.5.1.dev0.dist-info → cognee-0.5.2.dist-info}/RECORD +235 -147
- cognee/api/.env.example +0 -5
- cognee/modules/retrieval/base_graph_retriever.py +0 -24
- cognee/modules/search/methods/get_search_type_tools.py +0 -223
- cognee/modules/search/methods/no_access_control_search.py +0 -62
- cognee/modules/search/utils/prepare_search_result.py +0 -63
- cognee/tests/test_feedback_enrichment.py +0 -174
- {cognee-0.5.1.dev0.dist-info → cognee-0.5.2.dist-info}/WHEEL +0 -0
- {cognee-0.5.1.dev0.dist-info → cognee-0.5.2.dist-info}/entry_points.txt +0 -0
- {cognee-0.5.1.dev0.dist-info → cognee-0.5.2.dist-info}/licenses/LICENSE +0 -0
- {cognee-0.5.1.dev0.dist-info → cognee-0.5.2.dist-info}/licenses/NOTICE.md +0 -0
|
@@ -72,7 +72,7 @@ async def test_get_completion_without_context(mock_edge):
|
|
|
72
72
|
mock_graph_engine = AsyncMock()
|
|
73
73
|
mock_graph_engine.is_empty = AsyncMock(return_value=False)
|
|
74
74
|
|
|
75
|
-
retriever = GraphCompletionContextExtensionRetriever()
|
|
75
|
+
retriever = GraphCompletionContextExtensionRetriever(context_extension_rounds=1)
|
|
76
76
|
|
|
77
77
|
with (
|
|
78
78
|
patch(
|
|
@@ -99,7 +99,11 @@ async def test_get_completion_without_context(mock_edge):
|
|
|
99
99
|
mock_config.caching = False
|
|
100
100
|
mock_cache_config.return_value = mock_config
|
|
101
101
|
|
|
102
|
-
|
|
102
|
+
retrieved_objects = await retriever.get_retrieved_objects("test_query")
|
|
103
|
+
context = await retriever.get_context_from_objects("test query", retrieved_objects)
|
|
104
|
+
completion = await retriever.get_completion_from_context(
|
|
105
|
+
"test query", retrieved_objects, context
|
|
106
|
+
)
|
|
103
107
|
|
|
104
108
|
assert isinstance(completion, list)
|
|
105
109
|
assert len(completion) == 1
|
|
@@ -109,7 +113,7 @@ async def test_get_completion_without_context(mock_edge):
|
|
|
109
113
|
@pytest.mark.asyncio
|
|
110
114
|
async def test_get_completion_with_provided_context(mock_edge):
|
|
111
115
|
"""Test get_completion uses provided context."""
|
|
112
|
-
retriever = GraphCompletionContextExtensionRetriever()
|
|
116
|
+
retriever = GraphCompletionContextExtensionRetriever(context_extension_rounds=1)
|
|
113
117
|
|
|
114
118
|
with (
|
|
115
119
|
patch(
|
|
@@ -128,8 +132,11 @@ async def test_get_completion_with_provided_context(mock_edge):
|
|
|
128
132
|
mock_config.caching = False
|
|
129
133
|
mock_cache_config.return_value = mock_config
|
|
130
134
|
|
|
131
|
-
|
|
132
|
-
"test query",
|
|
135
|
+
context = await retriever.get_context_from_objects(
|
|
136
|
+
"test query", retrieved_objects=[mock_edge]
|
|
137
|
+
)
|
|
138
|
+
completion = await retriever.get_completion_from_context(
|
|
139
|
+
"test query", retrieved_objects=[mock_edge], context=context
|
|
133
140
|
)
|
|
134
141
|
|
|
135
142
|
assert isinstance(completion, list)
|
|
@@ -143,7 +150,7 @@ async def test_get_completion_context_extension_rounds(mock_edge):
|
|
|
143
150
|
mock_graph_engine = AsyncMock()
|
|
144
151
|
mock_graph_engine.is_empty = AsyncMock(return_value=False)
|
|
145
152
|
|
|
146
|
-
retriever = GraphCompletionContextExtensionRetriever()
|
|
153
|
+
retriever = GraphCompletionContextExtensionRetriever(context_extension_rounds=1)
|
|
147
154
|
|
|
148
155
|
# Create a second edge for extension rounds
|
|
149
156
|
mock_edge2 = MagicMock(spec=Edge)
|
|
@@ -155,7 +162,7 @@ async def test_get_completion_context_extension_rounds(mock_edge):
|
|
|
155
162
|
),
|
|
156
163
|
patch.object(
|
|
157
164
|
retriever,
|
|
158
|
-
"
|
|
165
|
+
"get_context_from_objects",
|
|
159
166
|
new_callable=AsyncMock,
|
|
160
167
|
side_effect=[[mock_edge], [mock_edge2]],
|
|
161
168
|
),
|
|
@@ -178,7 +185,11 @@ async def test_get_completion_context_extension_rounds(mock_edge):
|
|
|
178
185
|
mock_config.caching = False
|
|
179
186
|
mock_cache_config.return_value = mock_config
|
|
180
187
|
|
|
181
|
-
|
|
188
|
+
objects = await retriever.get_retrieved_objects("test_query")
|
|
189
|
+
context = await retriever.get_context_from_objects("test query", objects)
|
|
190
|
+
completion = await retriever.get_completion_from_context(
|
|
191
|
+
"test query", objects, context=context
|
|
192
|
+
)
|
|
182
193
|
|
|
183
194
|
assert isinstance(completion, list)
|
|
184
195
|
assert len(completion) == 1
|
|
@@ -191,10 +202,12 @@ async def test_get_completion_context_extension_stops_early(mock_edge):
|
|
|
191
202
|
mock_graph_engine = AsyncMock()
|
|
192
203
|
mock_graph_engine.is_empty = AsyncMock(return_value=False)
|
|
193
204
|
|
|
194
|
-
retriever = GraphCompletionContextExtensionRetriever()
|
|
205
|
+
retriever = GraphCompletionContextExtensionRetriever(context_extension_rounds=4)
|
|
195
206
|
|
|
196
207
|
with (
|
|
197
|
-
patch.object(
|
|
208
|
+
patch.object(
|
|
209
|
+
retriever, "get_context_from_objects", new_callable=AsyncMock, return_value=[mock_edge]
|
|
210
|
+
),
|
|
198
211
|
patch(
|
|
199
212
|
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
|
|
200
213
|
return_value="Resolved context",
|
|
@@ -215,8 +228,10 @@ async def test_get_completion_context_extension_stops_early(mock_edge):
|
|
|
215
228
|
mock_cache_config.return_value = mock_config
|
|
216
229
|
|
|
217
230
|
# When get_context returns same triplets, the loop should stop early
|
|
218
|
-
|
|
219
|
-
|
|
231
|
+
objects = await retriever.get_retrieved_objects("test_query")
|
|
232
|
+
context = await retriever.get_context_from_objects("test query", objects)
|
|
233
|
+
completion = await retriever.get_completion_from_context(
|
|
234
|
+
"test query", objects, context=context
|
|
220
235
|
)
|
|
221
236
|
|
|
222
237
|
assert isinstance(completion, list)
|
|
@@ -230,7 +245,9 @@ async def test_get_completion_with_session(mock_edge):
|
|
|
230
245
|
mock_graph_engine = AsyncMock()
|
|
231
246
|
mock_graph_engine.is_empty = AsyncMock(return_value=False)
|
|
232
247
|
|
|
233
|
-
retriever = GraphCompletionContextExtensionRetriever(
|
|
248
|
+
retriever = GraphCompletionContextExtensionRetriever(
|
|
249
|
+
session_id="test_session", context_extension_rounds=1
|
|
250
|
+
)
|
|
234
251
|
|
|
235
252
|
mock_user = MagicMock()
|
|
236
253
|
mock_user.id = "test-user-id"
|
|
@@ -240,7 +257,9 @@ async def test_get_completion_with_session(mock_edge):
|
|
|
240
257
|
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
|
|
241
258
|
return_value=mock_graph_engine,
|
|
242
259
|
),
|
|
243
|
-
patch.object(
|
|
260
|
+
patch.object(
|
|
261
|
+
retriever, "get_context_from_objects", new_callable=AsyncMock, return_value=[mock_edge]
|
|
262
|
+
),
|
|
244
263
|
patch(
|
|
245
264
|
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
|
|
246
265
|
return_value="Resolved context",
|
|
@@ -275,8 +294,10 @@ async def test_get_completion_with_session(mock_edge):
|
|
|
275
294
|
mock_cache_config.return_value = mock_config
|
|
276
295
|
mock_session_user.get.return_value = mock_user
|
|
277
296
|
|
|
278
|
-
|
|
279
|
-
|
|
297
|
+
objects = await retriever.get_retrieved_objects("test_query")
|
|
298
|
+
context = await retriever.get_context_from_objects("test query", objects)
|
|
299
|
+
completion = await retriever.get_completion_from_context(
|
|
300
|
+
"test query", objects, context=context
|
|
280
301
|
)
|
|
281
302
|
|
|
282
303
|
assert isinstance(completion, list)
|
|
@@ -292,7 +313,9 @@ async def test_get_completion_with_save_interaction(mock_edge):
|
|
|
292
313
|
mock_graph_engine.is_empty = AsyncMock(return_value=False)
|
|
293
314
|
mock_graph_engine.add_edges = AsyncMock()
|
|
294
315
|
|
|
295
|
-
retriever = GraphCompletionContextExtensionRetriever(
|
|
316
|
+
retriever = GraphCompletionContextExtensionRetriever(
|
|
317
|
+
context_extension_rounds=1, save_interaction=True
|
|
318
|
+
)
|
|
296
319
|
|
|
297
320
|
mock_node1 = MagicMock()
|
|
298
321
|
mock_node2 = MagicMock()
|
|
@@ -304,7 +327,9 @@ async def test_get_completion_with_save_interaction(mock_edge):
|
|
|
304
327
|
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
|
|
305
328
|
return_value=mock_graph_engine,
|
|
306
329
|
),
|
|
307
|
-
patch.object(
|
|
330
|
+
patch.object(
|
|
331
|
+
retriever, "get_context_from_objects", new_callable=AsyncMock, return_value="mock_edge"
|
|
332
|
+
),
|
|
308
333
|
patch(
|
|
309
334
|
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
|
|
310
335
|
return_value="Resolved context",
|
|
@@ -334,8 +359,9 @@ async def test_get_completion_with_save_interaction(mock_edge):
|
|
|
334
359
|
mock_config.caching = False
|
|
335
360
|
mock_cache_config.return_value = mock_config
|
|
336
361
|
|
|
337
|
-
|
|
338
|
-
|
|
362
|
+
context = await retriever.get_context_from_objects("test query", [mock_edge])
|
|
363
|
+
completion = await retriever.get_completion_from_context(
|
|
364
|
+
"test query", [mock_edge], context=context
|
|
339
365
|
)
|
|
340
366
|
|
|
341
367
|
assert isinstance(completion, list)
|
|
@@ -354,14 +380,16 @@ async def test_get_completion_with_response_model(mock_edge):
|
|
|
354
380
|
mock_graph_engine = AsyncMock()
|
|
355
381
|
mock_graph_engine.is_empty = AsyncMock(return_value=False)
|
|
356
382
|
|
|
357
|
-
retriever = GraphCompletionContextExtensionRetriever()
|
|
383
|
+
retriever = GraphCompletionContextExtensionRetriever(context_extension_rounds=1)
|
|
358
384
|
|
|
359
385
|
with (
|
|
360
386
|
patch(
|
|
361
387
|
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
|
|
362
388
|
return_value=mock_graph_engine,
|
|
363
389
|
),
|
|
364
|
-
patch.object(
|
|
390
|
+
patch.object(
|
|
391
|
+
retriever, "get_context_from_objects", new_callable=AsyncMock, return_value=[mock_edge]
|
|
392
|
+
),
|
|
365
393
|
patch(
|
|
366
394
|
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
|
|
367
395
|
return_value="Resolved context",
|
|
@@ -381,8 +409,10 @@ async def test_get_completion_with_response_model(mock_edge):
|
|
|
381
409
|
mock_config.caching = False
|
|
382
410
|
mock_cache_config.return_value = mock_config
|
|
383
411
|
|
|
384
|
-
|
|
385
|
-
|
|
412
|
+
objects = await retriever.get_retrieved_objects("test_query")
|
|
413
|
+
context = await retriever.get_context_from_objects("test query", objects)
|
|
414
|
+
completion = await retriever.get_completion_from_context(
|
|
415
|
+
"test query", objects, context=context
|
|
386
416
|
)
|
|
387
417
|
|
|
388
418
|
assert isinstance(completion, list)
|
|
@@ -396,14 +426,16 @@ async def test_get_completion_with_session_no_user_id(mock_edge):
|
|
|
396
426
|
mock_graph_engine = AsyncMock()
|
|
397
427
|
mock_graph_engine.is_empty = AsyncMock(return_value=False)
|
|
398
428
|
|
|
399
|
-
retriever = GraphCompletionContextExtensionRetriever()
|
|
429
|
+
retriever = GraphCompletionContextExtensionRetriever(context_extension_rounds=1)
|
|
400
430
|
|
|
401
431
|
with (
|
|
402
432
|
patch(
|
|
403
433
|
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
|
|
404
434
|
return_value=mock_graph_engine,
|
|
405
435
|
),
|
|
406
|
-
patch.object(
|
|
436
|
+
patch.object(
|
|
437
|
+
retriever, "get_context_from_objects", new_callable=AsyncMock, return_value=[mock_edge]
|
|
438
|
+
),
|
|
407
439
|
patch(
|
|
408
440
|
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
|
|
409
441
|
return_value="Resolved context",
|
|
@@ -427,7 +459,11 @@ async def test_get_completion_with_session_no_user_id(mock_edge):
|
|
|
427
459
|
mock_cache_config.return_value = mock_config
|
|
428
460
|
mock_session_user.get.return_value = None # No user
|
|
429
461
|
|
|
430
|
-
|
|
462
|
+
objects = await retriever.get_retrieved_objects("test_query")
|
|
463
|
+
context = await retriever.get_context_from_objects("test query", objects)
|
|
464
|
+
completion = await retriever.get_completion_from_context(
|
|
465
|
+
"test query", objects, context=context
|
|
466
|
+
)
|
|
431
467
|
|
|
432
468
|
assert isinstance(completion, list)
|
|
433
469
|
assert len(completion) == 1
|
|
@@ -439,14 +475,16 @@ async def test_get_completion_zero_extension_rounds(mock_edge):
|
|
|
439
475
|
mock_graph_engine = AsyncMock()
|
|
440
476
|
mock_graph_engine.is_empty = AsyncMock(return_value=False)
|
|
441
477
|
|
|
442
|
-
retriever = GraphCompletionContextExtensionRetriever()
|
|
478
|
+
retriever = GraphCompletionContextExtensionRetriever(context_extension_rounds=0)
|
|
443
479
|
|
|
444
480
|
with (
|
|
445
481
|
patch(
|
|
446
482
|
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
|
|
447
483
|
return_value=mock_graph_engine,
|
|
448
484
|
),
|
|
449
|
-
patch.object(
|
|
485
|
+
patch.object(
|
|
486
|
+
retriever, "get_context_from_objects", new_callable=AsyncMock, return_value=[mock_edge]
|
|
487
|
+
),
|
|
450
488
|
patch(
|
|
451
489
|
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
|
|
452
490
|
return_value="Resolved context",
|
|
@@ -462,8 +500,7 @@ async def test_get_completion_zero_extension_rounds(mock_edge):
|
|
|
462
500
|
mock_config = MagicMock()
|
|
463
501
|
mock_config.caching = False
|
|
464
502
|
mock_cache_config.return_value = mock_config
|
|
503
|
+
context = await retriever.get_context_from_objects("test query", None)
|
|
465
504
|
|
|
466
|
-
|
|
467
|
-
|
|
468
|
-
assert isinstance(completion, list)
|
|
469
|
-
assert len(completion) == 1
|
|
505
|
+
assert isinstance(context, list)
|
|
506
|
+
assert len(context) == 1
|
|
@@ -2,6 +2,7 @@ import pytest
|
|
|
2
2
|
from unittest.mock import AsyncMock, patch, MagicMock
|
|
3
3
|
from uuid import UUID
|
|
4
4
|
|
|
5
|
+
from cognee.exceptions import CogneeValidationError
|
|
5
6
|
from cognee.modules.retrieval.graph_completion_cot_retriever import (
|
|
6
7
|
GraphCompletionCotRetriever,
|
|
7
8
|
_as_answer_text,
|
|
@@ -68,7 +69,7 @@ async def test_init_defaults():
|
|
|
68
69
|
@pytest.mark.asyncio
|
|
69
70
|
async def test_run_cot_completion_round_zero_with_context(mock_edge):
|
|
70
71
|
"""Test _run_cot_completion round 0 with provided context."""
|
|
71
|
-
retriever = GraphCompletionCotRetriever()
|
|
72
|
+
retriever = GraphCompletionCotRetriever(max_iter=1)
|
|
72
73
|
|
|
73
74
|
with (
|
|
74
75
|
patch(
|
|
@@ -79,7 +80,9 @@ async def test_run_cot_completion_round_zero_with_context(mock_edge):
|
|
|
79
80
|
"cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
|
|
80
81
|
return_value="Generated answer",
|
|
81
82
|
),
|
|
82
|
-
patch.object(
|
|
83
|
+
patch.object(
|
|
84
|
+
retriever, "get_context_from_objects", new_callable=AsyncMock, return_value="mock_edge"
|
|
85
|
+
),
|
|
83
86
|
patch(
|
|
84
87
|
"cognee.modules.retrieval.graph_completion_cot_retriever._as_answer_text",
|
|
85
88
|
return_value="Generated answer",
|
|
@@ -92,6 +95,7 @@ async def test_run_cot_completion_round_zero_with_context(mock_edge):
|
|
|
92
95
|
"cognee.modules.retrieval.graph_completion_cot_retriever.read_query_prompt",
|
|
93
96
|
return_value="System prompt",
|
|
94
97
|
),
|
|
98
|
+
patch.object(retriever, "get_triplets", new_callable=AsyncMock, return_value=[mock_edge]),
|
|
95
99
|
patch.object(
|
|
96
100
|
LLMGateway,
|
|
97
101
|
"acreate_structured_output",
|
|
@@ -101,8 +105,6 @@ async def test_run_cot_completion_round_zero_with_context(mock_edge):
|
|
|
101
105
|
):
|
|
102
106
|
completion, context_text, triplets = await retriever._run_cot_completion(
|
|
103
107
|
query="test query",
|
|
104
|
-
context=[mock_edge],
|
|
105
|
-
max_iter=1,
|
|
106
108
|
)
|
|
107
109
|
|
|
108
110
|
assert completion == "Generated answer"
|
|
@@ -116,7 +118,7 @@ async def test_run_cot_completion_round_zero_without_context(mock_edge):
|
|
|
116
118
|
mock_graph_engine = AsyncMock()
|
|
117
119
|
mock_graph_engine.is_empty = AsyncMock(return_value=False)
|
|
118
120
|
|
|
119
|
-
retriever = GraphCompletionCotRetriever()
|
|
121
|
+
retriever = GraphCompletionCotRetriever(max_iter=1)
|
|
120
122
|
|
|
121
123
|
with (
|
|
122
124
|
patch(
|
|
@@ -131,16 +133,13 @@ async def test_run_cot_completion_round_zero_without_context(mock_edge):
|
|
|
131
133
|
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
|
|
132
134
|
return_value="Resolved context",
|
|
133
135
|
),
|
|
136
|
+
patch.object(retriever, "get_triplets", new_callable=AsyncMock, return_value=[mock_edge]),
|
|
134
137
|
patch(
|
|
135
138
|
"cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
|
|
136
139
|
return_value="Generated answer",
|
|
137
140
|
),
|
|
138
141
|
):
|
|
139
|
-
completion, context_text, triplets = await retriever._run_cot_completion(
|
|
140
|
-
query="test query",
|
|
141
|
-
context=None,
|
|
142
|
-
max_iter=1,
|
|
143
|
-
)
|
|
142
|
+
completion, context_text, triplets = await retriever._run_cot_completion(query="test query")
|
|
144
143
|
|
|
145
144
|
assert completion == "Generated answer"
|
|
146
145
|
assert context_text == "Resolved context"
|
|
@@ -150,7 +149,7 @@ async def test_run_cot_completion_round_zero_without_context(mock_edge):
|
|
|
150
149
|
@pytest.mark.asyncio
|
|
151
150
|
async def test_run_cot_completion_multiple_rounds(mock_edge):
|
|
152
151
|
"""Test _run_cot_completion with multiple rounds."""
|
|
153
|
-
retriever = GraphCompletionCotRetriever()
|
|
152
|
+
retriever = GraphCompletionCotRetriever(max_iter=2)
|
|
154
153
|
|
|
155
154
|
mock_edge2 = MagicMock(spec=Edge)
|
|
156
155
|
|
|
@@ -165,7 +164,7 @@ async def test_run_cot_completion_multiple_rounds(mock_edge):
|
|
|
165
164
|
),
|
|
166
165
|
patch.object(
|
|
167
166
|
retriever,
|
|
168
|
-
"
|
|
167
|
+
"get_retrieved_objects",
|
|
169
168
|
new_callable=AsyncMock,
|
|
170
169
|
side_effect=[[mock_edge], [mock_edge2]],
|
|
171
170
|
),
|
|
@@ -192,12 +191,9 @@ async def test_run_cot_completion_multiple_rounds(mock_edge):
|
|
|
192
191
|
"cognee.modules.retrieval.graph_completion_cot_retriever._as_answer_text",
|
|
193
192
|
return_value="Generated answer",
|
|
194
193
|
),
|
|
194
|
+
patch.object(retriever, "get_triplets", new_callable=AsyncMock, return_value=[mock_edge]),
|
|
195
195
|
):
|
|
196
|
-
completion, context_text, triplets = await retriever._run_cot_completion(
|
|
197
|
-
query="test query",
|
|
198
|
-
context=[mock_edge],
|
|
199
|
-
max_iter=2,
|
|
200
|
-
)
|
|
196
|
+
completion, context_text, triplets = await retriever._run_cot_completion(query="test query")
|
|
201
197
|
|
|
202
198
|
assert completion == "Generated answer"
|
|
203
199
|
assert context_text == "Resolved context"
|
|
@@ -207,7 +203,7 @@ async def test_run_cot_completion_multiple_rounds(mock_edge):
|
|
|
207
203
|
@pytest.mark.asyncio
|
|
208
204
|
async def test_run_cot_completion_with_conversation_history(mock_edge):
|
|
209
205
|
"""Test _run_cot_completion with conversation history."""
|
|
210
|
-
retriever = GraphCompletionCotRetriever()
|
|
206
|
+
retriever = GraphCompletionCotRetriever(max_iter=1)
|
|
211
207
|
|
|
212
208
|
with (
|
|
213
209
|
patch(
|
|
@@ -218,12 +214,11 @@ async def test_run_cot_completion_with_conversation_history(mock_edge):
|
|
|
218
214
|
"cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
|
|
219
215
|
return_value="Generated answer",
|
|
220
216
|
) as mock_generate,
|
|
217
|
+
patch.object(retriever, "get_triplets", new_callable=AsyncMock, return_value=[mock_edge]),
|
|
221
218
|
):
|
|
222
219
|
completion, context_text, triplets = await retriever._run_cot_completion(
|
|
223
220
|
query="test query",
|
|
224
|
-
context=[mock_edge],
|
|
225
221
|
conversation_history="Previous conversation",
|
|
226
|
-
max_iter=1,
|
|
227
222
|
)
|
|
228
223
|
|
|
229
224
|
assert completion == "Generated answer"
|
|
@@ -239,7 +234,7 @@ async def test_run_cot_completion_with_response_model(mock_edge):
|
|
|
239
234
|
class TestModel(BaseModel):
|
|
240
235
|
answer: str
|
|
241
236
|
|
|
242
|
-
retriever = GraphCompletionCotRetriever()
|
|
237
|
+
retriever = GraphCompletionCotRetriever(response_model=TestModel, max_iter=1)
|
|
243
238
|
|
|
244
239
|
with (
|
|
245
240
|
patch(
|
|
@@ -250,13 +245,9 @@ async def test_run_cot_completion_with_response_model(mock_edge):
|
|
|
250
245
|
"cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
|
|
251
246
|
return_value=TestModel(answer="Test answer"),
|
|
252
247
|
),
|
|
248
|
+
patch.object(retriever, "get_triplets", new_callable=AsyncMock, return_value=[mock_edge]),
|
|
253
249
|
):
|
|
254
|
-
completion, context_text, triplets = await retriever._run_cot_completion(
|
|
255
|
-
query="test query",
|
|
256
|
-
context=[mock_edge],
|
|
257
|
-
response_model=TestModel,
|
|
258
|
-
max_iter=1,
|
|
259
|
-
)
|
|
250
|
+
completion, context_text, triplets = await retriever._run_cot_completion(query="test query")
|
|
260
251
|
|
|
261
252
|
assert isinstance(completion, TestModel)
|
|
262
253
|
assert completion.answer == "Test answer"
|
|
@@ -265,7 +256,7 @@ async def test_run_cot_completion_with_response_model(mock_edge):
|
|
|
265
256
|
@pytest.mark.asyncio
|
|
266
257
|
async def test_run_cot_completion_empty_conversation_history(mock_edge):
|
|
267
258
|
"""Test _run_cot_completion with empty conversation history."""
|
|
268
|
-
retriever = GraphCompletionCotRetriever()
|
|
259
|
+
retriever = GraphCompletionCotRetriever(max_iter=1)
|
|
269
260
|
|
|
270
261
|
with (
|
|
271
262
|
patch(
|
|
@@ -276,12 +267,11 @@ async def test_run_cot_completion_empty_conversation_history(mock_edge):
|
|
|
276
267
|
"cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
|
|
277
268
|
return_value="Generated answer",
|
|
278
269
|
) as mock_generate,
|
|
270
|
+
patch.object(retriever, "get_triplets", new_callable=AsyncMock, return_value=[mock_edge]),
|
|
279
271
|
):
|
|
280
272
|
completion, context_text, triplets = await retriever._run_cot_completion(
|
|
281
273
|
query="test query",
|
|
282
|
-
context=[mock_edge],
|
|
283
274
|
conversation_history="",
|
|
284
|
-
max_iter=1,
|
|
285
275
|
)
|
|
286
276
|
|
|
287
277
|
assert completion == "Generated answer"
|
|
@@ -296,7 +286,7 @@ async def test_get_completion_without_context(mock_edge):
|
|
|
296
286
|
mock_graph_engine = AsyncMock()
|
|
297
287
|
mock_graph_engine.is_empty = AsyncMock(return_value=False)
|
|
298
288
|
|
|
299
|
-
retriever = GraphCompletionCotRetriever()
|
|
289
|
+
retriever = GraphCompletionCotRetriever(max_iter=1)
|
|
300
290
|
|
|
301
291
|
with (
|
|
302
292
|
patch(
|
|
@@ -315,7 +305,9 @@ async def test_get_completion_without_context(mock_edge):
|
|
|
315
305
|
"cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
|
|
316
306
|
return_value="Generated answer",
|
|
317
307
|
),
|
|
318
|
-
patch.object(
|
|
308
|
+
patch.object(
|
|
309
|
+
retriever, "get_context_from_objects", new_callable=AsyncMock, return_value=[mock_edge]
|
|
310
|
+
),
|
|
319
311
|
patch(
|
|
320
312
|
"cognee.modules.retrieval.graph_completion_cot_retriever._as_answer_text",
|
|
321
313
|
return_value="Generated answer",
|
|
@@ -342,7 +334,13 @@ async def test_get_completion_without_context(mock_edge):
|
|
|
342
334
|
mock_config.caching = False
|
|
343
335
|
mock_cache_config.return_value = mock_config
|
|
344
336
|
|
|
345
|
-
|
|
337
|
+
mock_edge = MagicMock()
|
|
338
|
+
|
|
339
|
+
objects = await retriever.get_retrieved_objects("test query")
|
|
340
|
+
context = await retriever.get_context_from_objects("test query", objects)
|
|
341
|
+
completion = await retriever.get_completion_from_context(
|
|
342
|
+
"test query", [mock_edge], context=context
|
|
343
|
+
)
|
|
346
344
|
|
|
347
345
|
assert isinstance(completion, list)
|
|
348
346
|
assert len(completion) == 1
|
|
@@ -352,7 +350,7 @@ async def test_get_completion_without_context(mock_edge):
|
|
|
352
350
|
@pytest.mark.asyncio
|
|
353
351
|
async def test_get_completion_with_provided_context(mock_edge):
|
|
354
352
|
"""Test get_completion uses provided context."""
|
|
355
|
-
retriever = GraphCompletionCotRetriever()
|
|
353
|
+
retriever = GraphCompletionCotRetriever(max_iter=1)
|
|
356
354
|
|
|
357
355
|
with (
|
|
358
356
|
patch(
|
|
@@ -371,7 +369,13 @@ async def test_get_completion_with_provided_context(mock_edge):
|
|
|
371
369
|
mock_config.caching = False
|
|
372
370
|
mock_cache_config.return_value = mock_config
|
|
373
371
|
|
|
374
|
-
|
|
372
|
+
mock_edge = MagicMock()
|
|
373
|
+
|
|
374
|
+
objects = await retriever.get_retrieved_objects("test query")
|
|
375
|
+
await retriever.get_context_from_objects("test query", objects)
|
|
376
|
+
completion = await retriever.get_completion_from_context(
|
|
377
|
+
"test query", [mock_edge], context="test"
|
|
378
|
+
)
|
|
375
379
|
|
|
376
380
|
assert isinstance(completion, list)
|
|
377
381
|
assert len(completion) == 1
|
|
@@ -384,7 +388,7 @@ async def test_get_completion_with_session(mock_edge):
|
|
|
384
388
|
mock_graph_engine = AsyncMock()
|
|
385
389
|
mock_graph_engine.is_empty = AsyncMock(return_value=False)
|
|
386
390
|
|
|
387
|
-
retriever = GraphCompletionCotRetriever()
|
|
391
|
+
retriever = GraphCompletionCotRetriever(session_id="test_session", max_iter=1)
|
|
388
392
|
|
|
389
393
|
mock_user = MagicMock()
|
|
390
394
|
mock_user.id = "test-user-id"
|
|
@@ -429,8 +433,9 @@ async def test_get_completion_with_session(mock_edge):
|
|
|
429
433
|
mock_cache_config.return_value = mock_config
|
|
430
434
|
mock_session_user.get.return_value = mock_user
|
|
431
435
|
|
|
432
|
-
|
|
433
|
-
|
|
436
|
+
retrieved_objects = await retriever.get_retrieved_objects("test query")
|
|
437
|
+
completion = await retriever.get_completion_from_context(
|
|
438
|
+
"test query", retrieved_objects, context="mock_edge"
|
|
434
439
|
)
|
|
435
440
|
|
|
436
441
|
assert isinstance(completion, list)
|
|
@@ -446,7 +451,7 @@ async def test_get_completion_with_save_interaction(mock_edge):
|
|
|
446
451
|
mock_graph_engine.is_empty = AsyncMock(return_value=False)
|
|
447
452
|
mock_graph_engine.add_edges = AsyncMock()
|
|
448
453
|
|
|
449
|
-
retriever = GraphCompletionCotRetriever(save_interaction=True)
|
|
454
|
+
retriever = GraphCompletionCotRetriever(save_interaction=True, max_iter=1)
|
|
450
455
|
|
|
451
456
|
mock_node1 = MagicMock()
|
|
452
457
|
mock_node2 = MagicMock()
|
|
@@ -462,7 +467,10 @@ async def test_get_completion_with_save_interaction(mock_edge):
|
|
|
462
467
|
"cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
|
|
463
468
|
return_value="Generated answer",
|
|
464
469
|
),
|
|
465
|
-
patch.object(retriever, "
|
|
470
|
+
patch.object(retriever, "get_triplets", new_callable=AsyncMock, return_value=[mock_edge]),
|
|
471
|
+
patch.object(
|
|
472
|
+
retriever, "get_context_from_objects", new_callable=AsyncMock, return_value="mock_edge"
|
|
473
|
+
),
|
|
466
474
|
patch(
|
|
467
475
|
"cognee.modules.retrieval.graph_completion_cot_retriever._as_answer_text",
|
|
468
476
|
return_value="Generated answer",
|
|
@@ -486,6 +494,8 @@ async def test_get_completion_with_save_interaction(mock_edge):
|
|
|
486
494
|
side_effect=[
|
|
487
495
|
UUID("550e8400-e29b-41d4-a716-446655440000"),
|
|
488
496
|
UUID("550e8400-e29b-41d4-a716-446655440001"),
|
|
497
|
+
UUID("550e8400-e29b-41d4-a716-446655440002"),
|
|
498
|
+
UUID("550e8400-e29b-41d4-a716-446655440003"),
|
|
489
499
|
],
|
|
490
500
|
),
|
|
491
501
|
patch(
|
|
@@ -500,7 +510,11 @@ async def test_get_completion_with_save_interaction(mock_edge):
|
|
|
500
510
|
mock_cache_config.return_value = mock_config
|
|
501
511
|
|
|
502
512
|
# Pass context so save_interaction condition is met
|
|
503
|
-
|
|
513
|
+
retrieved_objects = await retriever.get_retrieved_objects("test query")
|
|
514
|
+
context = await retriever.get_context_from_objects("test query", retrieved_objects)
|
|
515
|
+
completion = await retriever.get_completion_from_context(
|
|
516
|
+
"test query", [mock_edge], context=context
|
|
517
|
+
)
|
|
504
518
|
|
|
505
519
|
assert isinstance(completion, list)
|
|
506
520
|
assert len(completion) == 1
|
|
@@ -518,7 +532,7 @@ async def test_get_completion_with_response_model(mock_edge):
|
|
|
518
532
|
mock_graph_engine = AsyncMock()
|
|
519
533
|
mock_graph_engine.is_empty = AsyncMock(return_value=False)
|
|
520
534
|
|
|
521
|
-
retriever = GraphCompletionCotRetriever()
|
|
535
|
+
retriever = GraphCompletionCotRetriever(response_model=TestModel, max_iter=1)
|
|
522
536
|
|
|
523
537
|
with (
|
|
524
538
|
patch(
|
|
@@ -545,8 +559,12 @@ async def test_get_completion_with_response_model(mock_edge):
|
|
|
545
559
|
mock_config.caching = False
|
|
546
560
|
mock_cache_config.return_value = mock_config
|
|
547
561
|
|
|
548
|
-
|
|
549
|
-
|
|
562
|
+
mock_edge = MagicMock()
|
|
563
|
+
|
|
564
|
+
objects = await retriever.get_retrieved_objects("test query")
|
|
565
|
+
await retriever.get_context_from_objects("test query", objects)
|
|
566
|
+
completion = await retriever.get_completion_from_context(
|
|
567
|
+
"test query", [mock_edge], "mock_edge"
|
|
550
568
|
)
|
|
551
569
|
|
|
552
570
|
assert isinstance(completion, list)
|
|
@@ -560,7 +578,7 @@ async def test_get_completion_with_session_no_user_id(mock_edge):
|
|
|
560
578
|
mock_graph_engine = AsyncMock()
|
|
561
579
|
mock_graph_engine.is_empty = AsyncMock(return_value=False)
|
|
562
580
|
|
|
563
|
-
retriever = GraphCompletionCotRetriever()
|
|
581
|
+
retriever = GraphCompletionCotRetriever(max_iter=1)
|
|
564
582
|
|
|
565
583
|
with (
|
|
566
584
|
patch(
|
|
@@ -591,7 +609,9 @@ async def test_get_completion_with_session_no_user_id(mock_edge):
|
|
|
591
609
|
mock_cache_config.return_value = mock_config
|
|
592
610
|
mock_session_user.get.return_value = None # No user
|
|
593
611
|
|
|
594
|
-
completion = await retriever.
|
|
612
|
+
completion = await retriever.get_completion_from_context(
|
|
613
|
+
"test query", [mock_edge], context="mock_edge"
|
|
614
|
+
)
|
|
595
615
|
|
|
596
616
|
assert isinstance(completion, list)
|
|
597
617
|
assert len(completion) == 1
|
|
@@ -600,7 +620,7 @@ async def test_get_completion_with_session_no_user_id(mock_edge):
|
|
|
600
620
|
@pytest.mark.asyncio
|
|
601
621
|
async def test_get_completion_with_save_interaction_no_context(mock_edge):
|
|
602
622
|
"""Test get_completion with save_interaction but no context provided."""
|
|
603
|
-
retriever = GraphCompletionCotRetriever(save_interaction=True)
|
|
623
|
+
retriever = GraphCompletionCotRetriever(save_interaction=True, max_iter=1)
|
|
604
624
|
|
|
605
625
|
with (
|
|
606
626
|
patch(
|
|
@@ -611,7 +631,9 @@ async def test_get_completion_with_save_interaction_no_context(mock_edge):
|
|
|
611
631
|
"cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
|
|
612
632
|
return_value="Generated answer",
|
|
613
633
|
),
|
|
614
|
-
patch.object(
|
|
634
|
+
patch.object(
|
|
635
|
+
retriever, "get_retrieved_objects", new_callable=AsyncMock, return_value=[mock_edge]
|
|
636
|
+
),
|
|
615
637
|
patch(
|
|
616
638
|
"cognee.modules.retrieval.graph_completion_cot_retriever._as_answer_text",
|
|
617
639
|
return_value="Generated answer",
|
|
@@ -638,10 +660,8 @@ async def test_get_completion_with_save_interaction_no_context(mock_edge):
|
|
|
638
660
|
mock_config.caching = False
|
|
639
661
|
mock_cache_config.return_value = mock_config
|
|
640
662
|
|
|
641
|
-
|
|
642
|
-
|
|
643
|
-
assert isinstance(completion, list)
|
|
644
|
-
assert len(completion) == 1
|
|
663
|
+
with pytest.raises(CogneeValidationError):
|
|
664
|
+
await retriever.get_completion_from_context("test query", None, context=None)
|
|
645
665
|
|
|
646
666
|
|
|
647
667
|
@pytest.mark.asyncio
|