cognee 0.5.1.dev0__py3-none-any.whl → 0.5.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (241) hide show
  1. cognee/__init__.py +2 -0
  2. cognee/alembic/README +1 -0
  3. cognee/alembic/env.py +107 -0
  4. cognee/alembic/script.py.mako +26 -0
  5. cognee/alembic/versions/1a58b986e6e1_enable_delete_for_old_tutorial_notebooks.py +52 -0
  6. cognee/alembic/versions/1d0bb7fede17_add_pipeline_run_status.py +33 -0
  7. cognee/alembic/versions/1daae0df1866_incremental_loading.py +48 -0
  8. cognee/alembic/versions/211ab850ef3d_add_sync_operations_table.py +118 -0
  9. cognee/alembic/versions/45957f0a9849_add_notebook_table.py +46 -0
  10. cognee/alembic/versions/46a6ce2bd2b2_expand_dataset_database_with_json_.py +333 -0
  11. cognee/alembic/versions/482cd6517ce4_add_default_user.py +30 -0
  12. cognee/alembic/versions/76625596c5c3_expand_dataset_database_for_multi_user.py +98 -0
  13. cognee/alembic/versions/8057ae7329c2_initial_migration.py +25 -0
  14. cognee/alembic/versions/9e7a3cb85175_loader_separation.py +104 -0
  15. cognee/alembic/versions/a1b2c3d4e5f6_add_label_column_to_data.py +38 -0
  16. cognee/alembic/versions/ab7e313804ae_permission_system_rework.py +236 -0
  17. cognee/alembic/versions/b9274c27a25a_kuzu_11_migration.py +75 -0
  18. cognee/alembic/versions/c946955da633_multi_tenant_support.py +137 -0
  19. cognee/alembic/versions/e1ec1dcb50b6_add_last_accessed_to_data.py +51 -0
  20. cognee/alembic/versions/e4ebee1091e7_expand_data_model_info.py +140 -0
  21. cognee/alembic.ini +117 -0
  22. cognee/api/v1/add/routers/get_add_router.py +2 -0
  23. cognee/api/v1/cognify/cognify.py +11 -6
  24. cognee/api/v1/cognify/routers/get_cognify_router.py +8 -0
  25. cognee/api/v1/config/config.py +60 -0
  26. cognee/api/v1/datasets/routers/get_datasets_router.py +45 -3
  27. cognee/api/v1/memify/routers/get_memify_router.py +2 -0
  28. cognee/api/v1/search/routers/get_search_router.py +21 -6
  29. cognee/api/v1/search/search.py +25 -5
  30. cognee/api/v1/sync/routers/get_sync_router.py +3 -3
  31. cognee/cli/commands/add_command.py +1 -1
  32. cognee/cli/commands/cognify_command.py +6 -0
  33. cognee/cli/commands/config_command.py +1 -1
  34. cognee/context_global_variables.py +5 -1
  35. cognee/eval_framework/answer_generation/answer_generation_executor.py +7 -8
  36. cognee/infrastructure/databases/cache/cache_db_interface.py +38 -1
  37. cognee/infrastructure/databases/cache/config.py +6 -0
  38. cognee/infrastructure/databases/cache/fscache/FsCacheAdapter.py +21 -0
  39. cognee/infrastructure/databases/cache/get_cache_engine.py +9 -3
  40. cognee/infrastructure/databases/cache/redis/RedisAdapter.py +60 -1
  41. cognee/infrastructure/databases/dataset_database_handler/supported_dataset_database_handlers.py +7 -0
  42. cognee/infrastructure/databases/graph/get_graph_engine.py +29 -1
  43. cognee/infrastructure/databases/graph/neo4j_driver/Neo4jAuraDevDatasetDatabaseHandler.py +62 -27
  44. cognee/infrastructure/databases/hybrid/neptune_analytics/NeptuneAnalyticsAdapter.py +17 -4
  45. cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py +2 -1
  46. cognee/infrastructure/databases/vector/chromadb/ChromaDBAdapter.py +2 -0
  47. cognee/infrastructure/databases/vector/config.py +6 -0
  48. cognee/infrastructure/databases/vector/create_vector_engine.py +69 -22
  49. cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py +64 -9
  50. cognee/infrastructure/databases/vector/embeddings/OllamaEmbeddingEngine.py +13 -2
  51. cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py +16 -3
  52. cognee/infrastructure/databases/vector/models/ScoredResult.py +3 -3
  53. cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py +16 -3
  54. cognee/infrastructure/databases/vector/pgvector/PGVectorDatasetDatabaseHandler.py +86 -0
  55. cognee/infrastructure/databases/vector/pgvector/create_db_and_tables.py +81 -2
  56. cognee/infrastructure/databases/vector/vector_db_interface.py +8 -0
  57. cognee/infrastructure/files/utils/get_data_file_path.py +33 -27
  58. cognee/infrastructure/llm/prompts/extract_query_time.txt +1 -1
  59. cognee/infrastructure/llm/prompts/generate_event_entity_prompt.txt +1 -1
  60. cognee/infrastructure/llm/prompts/generate_event_graph_prompt.txt +1 -1
  61. cognee/infrastructure/llm/prompts/generate_graph_prompt.txt +2 -2
  62. cognee/infrastructure/llm/prompts/generate_graph_prompt_guided.txt +1 -1
  63. cognee/infrastructure/llm/prompts/generate_graph_prompt_oneshot.txt +2 -2
  64. cognee/infrastructure/llm/prompts/generate_graph_prompt_simple.txt +1 -1
  65. cognee/infrastructure/llm/prompts/generate_graph_prompt_strict.txt +1 -1
  66. cognee/infrastructure/llm/prompts/search_type_selector_prompt.txt +6 -6
  67. cognee/infrastructure/llm/prompts/test.txt +1 -1
  68. cognee/infrastructure/llm/prompts/translate_content.txt +19 -0
  69. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py +24 -0
  70. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/llama_cpp/adapter.py +191 -0
  71. cognee/modules/chunking/models/DocumentChunk.py +0 -1
  72. cognee/modules/cognify/config.py +2 -0
  73. cognee/modules/data/models/Data.py +1 -0
  74. cognee/modules/engine/models/Entity.py +0 -1
  75. cognee/modules/engine/operations/setup.py +6 -0
  76. cognee/modules/graph/cognee_graph/CogneeGraph.py +150 -37
  77. cognee/modules/graph/cognee_graph/CogneeGraphElements.py +48 -2
  78. cognee/modules/graph/utils/__init__.py +1 -0
  79. cognee/modules/graph/utils/get_entity_nodes_from_triplets.py +12 -0
  80. cognee/modules/notebooks/methods/__init__.py +1 -0
  81. cognee/modules/notebooks/methods/create_notebook.py +0 -34
  82. cognee/modules/notebooks/methods/create_tutorial_notebooks.py +191 -0
  83. cognee/modules/notebooks/methods/get_notebooks.py +12 -8
  84. cognee/modules/notebooks/tutorials/cognee-basics/cell-1.md +3 -0
  85. cognee/modules/notebooks/tutorials/cognee-basics/cell-2.md +10 -0
  86. cognee/modules/notebooks/tutorials/cognee-basics/cell-3.md +7 -0
  87. cognee/modules/notebooks/tutorials/cognee-basics/cell-4.py +28 -0
  88. cognee/modules/notebooks/tutorials/cognee-basics/cell-5.py +3 -0
  89. cognee/modules/notebooks/tutorials/cognee-basics/cell-6.py +9 -0
  90. cognee/modules/notebooks/tutorials/cognee-basics/cell-7.py +17 -0
  91. cognee/modules/notebooks/tutorials/cognee-basics/config.json +4 -0
  92. cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-1.md +3 -0
  93. cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-10.md +3 -0
  94. cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-11.md +3 -0
  95. cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-12.py +3 -0
  96. cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-13.md +7 -0
  97. cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-14.py +6 -0
  98. cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-15.md +3 -0
  99. cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-16.py +7 -0
  100. cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-2.md +9 -0
  101. cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-3.md +7 -0
  102. cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-4.md +9 -0
  103. cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-5.md +5 -0
  104. cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-6.py +13 -0
  105. cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-7.md +3 -0
  106. cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-8.md +3 -0
  107. cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-9.py +31 -0
  108. cognee/modules/notebooks/tutorials/python-development-with-cognee/config.json +4 -0
  109. cognee/modules/notebooks/tutorials/python-development-with-cognee/data/copilot_conversations.json +107 -0
  110. cognee/modules/notebooks/tutorials/python-development-with-cognee/data/guido_contributions.json +976 -0
  111. cognee/modules/notebooks/tutorials/python-development-with-cognee/data/my_developer_rules.md +79 -0
  112. cognee/modules/notebooks/tutorials/python-development-with-cognee/data/pep_style_guide.md +74 -0
  113. cognee/modules/notebooks/tutorials/python-development-with-cognee/data/zen_principles.md +74 -0
  114. cognee/modules/retrieval/EntityCompletionRetriever.py +51 -38
  115. cognee/modules/retrieval/__init__.py +0 -1
  116. cognee/modules/retrieval/base_retriever.py +66 -10
  117. cognee/modules/retrieval/chunks_retriever.py +57 -49
  118. cognee/modules/retrieval/coding_rules_retriever.py +12 -5
  119. cognee/modules/retrieval/completion_retriever.py +29 -28
  120. cognee/modules/retrieval/cypher_search_retriever.py +25 -20
  121. cognee/modules/retrieval/graph_completion_context_extension_retriever.py +42 -46
  122. cognee/modules/retrieval/graph_completion_cot_retriever.py +68 -51
  123. cognee/modules/retrieval/graph_completion_retriever.py +78 -63
  124. cognee/modules/retrieval/graph_summary_completion_retriever.py +2 -0
  125. cognee/modules/retrieval/lexical_retriever.py +34 -12
  126. cognee/modules/retrieval/natural_language_retriever.py +18 -15
  127. cognee/modules/retrieval/summaries_retriever.py +51 -34
  128. cognee/modules/retrieval/temporal_retriever.py +59 -49
  129. cognee/modules/retrieval/triplet_retriever.py +31 -32
  130. cognee/modules/retrieval/utils/access_tracking.py +88 -0
  131. cognee/modules/retrieval/utils/brute_force_triplet_search.py +99 -85
  132. cognee/modules/retrieval/utils/node_edge_vector_search.py +174 -0
  133. cognee/modules/search/methods/__init__.py +1 -0
  134. cognee/modules/search/methods/get_retriever_output.py +53 -0
  135. cognee/modules/search/methods/get_search_type_retriever_instance.py +252 -0
  136. cognee/modules/search/methods/search.py +90 -215
  137. cognee/modules/search/models/SearchResultPayload.py +67 -0
  138. cognee/modules/search/types/SearchResult.py +1 -8
  139. cognee/modules/search/types/SearchType.py +1 -2
  140. cognee/modules/search/types/__init__.py +1 -1
  141. cognee/modules/search/utils/__init__.py +1 -2
  142. cognee/modules/search/utils/transform_insights_to_graph.py +2 -2
  143. cognee/modules/search/utils/{transform_context_to_graph.py → transform_triplets_to_graph.py} +2 -2
  144. cognee/modules/users/authentication/default/default_transport.py +11 -1
  145. cognee/modules/users/authentication/get_api_auth_backend.py +2 -1
  146. cognee/modules/users/authentication/get_client_auth_backend.py +2 -1
  147. cognee/modules/users/methods/create_user.py +0 -9
  148. cognee/modules/users/permissions/methods/has_user_management_permission.py +29 -0
  149. cognee/modules/visualization/cognee_network_visualization.py +1 -1
  150. cognee/run_migrations.py +48 -0
  151. cognee/shared/exceptions/__init__.py +1 -3
  152. cognee/shared/exceptions/exceptions.py +11 -1
  153. cognee/shared/usage_logger.py +332 -0
  154. cognee/shared/utils.py +12 -5
  155. cognee/tasks/chunks/__init__.py +9 -0
  156. cognee/tasks/cleanup/cleanup_unused_data.py +172 -0
  157. cognee/tasks/graph/__init__.py +7 -0
  158. cognee/tasks/memify/__init__.py +8 -0
  159. cognee/tasks/memify/extract_usage_frequency.py +613 -0
  160. cognee/tasks/summarization/models.py +0 -2
  161. cognee/tasks/temporal_graph/__init__.py +0 -1
  162. cognee/tasks/translation/__init__.py +96 -0
  163. cognee/tasks/translation/config.py +110 -0
  164. cognee/tasks/translation/detect_language.py +190 -0
  165. cognee/tasks/translation/exceptions.py +62 -0
  166. cognee/tasks/translation/models.py +72 -0
  167. cognee/tasks/translation/providers/__init__.py +44 -0
  168. cognee/tasks/translation/providers/azure_provider.py +192 -0
  169. cognee/tasks/translation/providers/base.py +85 -0
  170. cognee/tasks/translation/providers/google_provider.py +158 -0
  171. cognee/tasks/translation/providers/llm_provider.py +143 -0
  172. cognee/tasks/translation/translate_content.py +282 -0
  173. cognee/tasks/web_scraper/default_url_crawler.py +6 -2
  174. cognee/tests/cli_tests/cli_unit_tests/test_cli_commands.py +1 -0
  175. cognee/tests/cli_tests/cli_unit_tests/test_cli_edge_cases.py +3 -0
  176. cognee/tests/integration/retrieval/test_brute_force_triplet_search_with_cognify.py +62 -0
  177. cognee/tests/integration/retrieval/test_chunks_retriever.py +115 -16
  178. cognee/tests/integration/retrieval/test_graph_completion_retriever.py +13 -5
  179. cognee/tests/integration/retrieval/test_graph_completion_retriever_context_extension.py +22 -20
  180. cognee/tests/integration/retrieval/test_graph_completion_retriever_cot.py +23 -24
  181. cognee/tests/integration/retrieval/test_rag_completion_retriever.py +70 -5
  182. cognee/tests/integration/retrieval/test_structured_output.py +62 -18
  183. cognee/tests/integration/retrieval/test_summaries_retriever.py +20 -9
  184. cognee/tests/integration/retrieval/test_temporal_retriever.py +38 -8
  185. cognee/tests/integration/retrieval/test_triplet_retriever.py +13 -4
  186. cognee/tests/integration/shared/test_usage_logger_integration.py +255 -0
  187. cognee/tests/tasks/translation/README.md +147 -0
  188. cognee/tests/tasks/translation/__init__.py +1 -0
  189. cognee/tests/tasks/translation/config_test.py +93 -0
  190. cognee/tests/tasks/translation/detect_language_test.py +118 -0
  191. cognee/tests/tasks/translation/providers_test.py +151 -0
  192. cognee/tests/tasks/translation/translate_content_test.py +213 -0
  193. cognee/tests/test_chromadb.py +1 -1
  194. cognee/tests/test_cleanup_unused_data.py +165 -0
  195. cognee/tests/test_delete_by_id.py +6 -6
  196. cognee/tests/test_extract_usage_frequency.py +308 -0
  197. cognee/tests/test_kuzu.py +17 -7
  198. cognee/tests/test_lancedb.py +3 -1
  199. cognee/tests/test_library.py +1 -1
  200. cognee/tests/test_neo4j.py +17 -7
  201. cognee/tests/test_neptune_analytics_vector.py +3 -1
  202. cognee/tests/test_permissions.py +172 -187
  203. cognee/tests/test_pgvector.py +3 -1
  204. cognee/tests/test_relational_db_migration.py +15 -1
  205. cognee/tests/test_remote_kuzu.py +3 -1
  206. cognee/tests/test_s3_file_storage.py +1 -1
  207. cognee/tests/test_search_db.py +97 -110
  208. cognee/tests/test_usage_logger_e2e.py +268 -0
  209. cognee/tests/unit/api/test_get_raw_data_endpoint.py +206 -0
  210. cognee/tests/unit/eval_framework/answer_generation_test.py +4 -3
  211. cognee/tests/unit/infrastructure/databases/cache/test_cache_config.py +2 -0
  212. cognee/tests/unit/modules/graph/cognee_graph_elements_test.py +42 -2
  213. cognee/tests/unit/modules/graph/cognee_graph_test.py +329 -31
  214. cognee/tests/unit/modules/retrieval/chunks_retriever_test.py +31 -59
  215. cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py +70 -33
  216. cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py +72 -52
  217. cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py +27 -33
  218. cognee/tests/unit/modules/retrieval/rag_completion_retriever_test.py +28 -15
  219. cognee/tests/unit/modules/retrieval/summaries_retriever_test.py +37 -42
  220. cognee/tests/unit/modules/retrieval/temporal_retriever_test.py +48 -64
  221. cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py +263 -24
  222. cognee/tests/unit/modules/retrieval/test_node_edge_vector_search.py +273 -0
  223. cognee/tests/unit/modules/retrieval/triplet_retriever_test.py +30 -16
  224. cognee/tests/unit/modules/search/test_get_search_type_retriever_instance.py +125 -0
  225. cognee/tests/unit/modules/search/test_search.py +176 -0
  226. cognee/tests/unit/modules/search/test_search_prepare_search_result_contract.py +190 -0
  227. cognee/tests/unit/modules/users/test_tutorial_notebook_creation.py +511 -297
  228. cognee/tests/unit/shared/test_usage_logger.py +241 -0
  229. cognee/tests/unit/users/permissions/test_has_user_management_permission.py +46 -0
  230. {cognee-0.5.1.dev0.dist-info → cognee-0.5.2.dist-info}/METADATA +22 -17
  231. {cognee-0.5.1.dev0.dist-info → cognee-0.5.2.dist-info}/RECORD +235 -147
  232. cognee/api/.env.example +0 -5
  233. cognee/modules/retrieval/base_graph_retriever.py +0 -24
  234. cognee/modules/search/methods/get_search_type_tools.py +0 -223
  235. cognee/modules/search/methods/no_access_control_search.py +0 -62
  236. cognee/modules/search/utils/prepare_search_result.py +0 -63
  237. cognee/tests/test_feedback_enrichment.py +0 -174
  238. {cognee-0.5.1.dev0.dist-info → cognee-0.5.2.dist-info}/WHEEL +0 -0
  239. {cognee-0.5.1.dev0.dist-info → cognee-0.5.2.dist-info}/entry_points.txt +0 -0
  240. {cognee-0.5.1.dev0.dist-info → cognee-0.5.2.dist-info}/licenses/LICENSE +0 -0
  241. {cognee-0.5.1.dev0.dist-info → cognee-0.5.2.dist-info}/licenses/NOTICE.md +0 -0
@@ -72,7 +72,7 @@ async def test_get_completion_without_context(mock_edge):
72
72
  mock_graph_engine = AsyncMock()
73
73
  mock_graph_engine.is_empty = AsyncMock(return_value=False)
74
74
 
75
- retriever = GraphCompletionContextExtensionRetriever()
75
+ retriever = GraphCompletionContextExtensionRetriever(context_extension_rounds=1)
76
76
 
77
77
  with (
78
78
  patch(
@@ -99,7 +99,11 @@ async def test_get_completion_without_context(mock_edge):
99
99
  mock_config.caching = False
100
100
  mock_cache_config.return_value = mock_config
101
101
 
102
- completion = await retriever.get_completion("test query", context_extension_rounds=1)
102
+ retrieved_objects = await retriever.get_retrieved_objects("test_query")
103
+ context = await retriever.get_context_from_objects("test query", retrieved_objects)
104
+ completion = await retriever.get_completion_from_context(
105
+ "test query", retrieved_objects, context
106
+ )
103
107
 
104
108
  assert isinstance(completion, list)
105
109
  assert len(completion) == 1
@@ -109,7 +113,7 @@ async def test_get_completion_without_context(mock_edge):
109
113
  @pytest.mark.asyncio
110
114
  async def test_get_completion_with_provided_context(mock_edge):
111
115
  """Test get_completion uses provided context."""
112
- retriever = GraphCompletionContextExtensionRetriever()
116
+ retriever = GraphCompletionContextExtensionRetriever(context_extension_rounds=1)
113
117
 
114
118
  with (
115
119
  patch(
@@ -128,8 +132,11 @@ async def test_get_completion_with_provided_context(mock_edge):
128
132
  mock_config.caching = False
129
133
  mock_cache_config.return_value = mock_config
130
134
 
131
- completion = await retriever.get_completion(
132
- "test query", context=[mock_edge], context_extension_rounds=1
135
+ context = await retriever.get_context_from_objects(
136
+ "test query", retrieved_objects=[mock_edge]
137
+ )
138
+ completion = await retriever.get_completion_from_context(
139
+ "test query", retrieved_objects=[mock_edge], context=context
133
140
  )
134
141
 
135
142
  assert isinstance(completion, list)
@@ -143,7 +150,7 @@ async def test_get_completion_context_extension_rounds(mock_edge):
143
150
  mock_graph_engine = AsyncMock()
144
151
  mock_graph_engine.is_empty = AsyncMock(return_value=False)
145
152
 
146
- retriever = GraphCompletionContextExtensionRetriever()
153
+ retriever = GraphCompletionContextExtensionRetriever(context_extension_rounds=1)
147
154
 
148
155
  # Create a second edge for extension rounds
149
156
  mock_edge2 = MagicMock(spec=Edge)
@@ -155,7 +162,7 @@ async def test_get_completion_context_extension_rounds(mock_edge):
155
162
  ),
156
163
  patch.object(
157
164
  retriever,
158
- "get_context",
165
+ "get_context_from_objects",
159
166
  new_callable=AsyncMock,
160
167
  side_effect=[[mock_edge], [mock_edge2]],
161
168
  ),
@@ -178,7 +185,11 @@ async def test_get_completion_context_extension_rounds(mock_edge):
178
185
  mock_config.caching = False
179
186
  mock_cache_config.return_value = mock_config
180
187
 
181
- completion = await retriever.get_completion("test query", context_extension_rounds=1)
188
+ objects = await retriever.get_retrieved_objects("test_query")
189
+ context = await retriever.get_context_from_objects("test query", objects)
190
+ completion = await retriever.get_completion_from_context(
191
+ "test query", objects, context=context
192
+ )
182
193
 
183
194
  assert isinstance(completion, list)
184
195
  assert len(completion) == 1
@@ -191,10 +202,12 @@ async def test_get_completion_context_extension_stops_early(mock_edge):
191
202
  mock_graph_engine = AsyncMock()
192
203
  mock_graph_engine.is_empty = AsyncMock(return_value=False)
193
204
 
194
- retriever = GraphCompletionContextExtensionRetriever()
205
+ retriever = GraphCompletionContextExtensionRetriever(context_extension_rounds=4)
195
206
 
196
207
  with (
197
- patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]),
208
+ patch.object(
209
+ retriever, "get_context_from_objects", new_callable=AsyncMock, return_value=[mock_edge]
210
+ ),
198
211
  patch(
199
212
  "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
200
213
  return_value="Resolved context",
@@ -215,8 +228,10 @@ async def test_get_completion_context_extension_stops_early(mock_edge):
215
228
  mock_cache_config.return_value = mock_config
216
229
 
217
230
  # When get_context returns same triplets, the loop should stop early
218
- completion = await retriever.get_completion(
219
- "test query", context=[mock_edge], context_extension_rounds=4
231
+ objects = await retriever.get_retrieved_objects("test_query")
232
+ context = await retriever.get_context_from_objects("test query", objects)
233
+ completion = await retriever.get_completion_from_context(
234
+ "test query", objects, context=context
220
235
  )
221
236
 
222
237
  assert isinstance(completion, list)
@@ -230,7 +245,9 @@ async def test_get_completion_with_session(mock_edge):
230
245
  mock_graph_engine = AsyncMock()
231
246
  mock_graph_engine.is_empty = AsyncMock(return_value=False)
232
247
 
233
- retriever = GraphCompletionContextExtensionRetriever()
248
+ retriever = GraphCompletionContextExtensionRetriever(
249
+ session_id="test_session", context_extension_rounds=1
250
+ )
234
251
 
235
252
  mock_user = MagicMock()
236
253
  mock_user.id = "test-user-id"
@@ -240,7 +257,9 @@ async def test_get_completion_with_session(mock_edge):
240
257
  "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
241
258
  return_value=mock_graph_engine,
242
259
  ),
243
- patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]),
260
+ patch.object(
261
+ retriever, "get_context_from_objects", new_callable=AsyncMock, return_value=[mock_edge]
262
+ ),
244
263
  patch(
245
264
  "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
246
265
  return_value="Resolved context",
@@ -275,8 +294,10 @@ async def test_get_completion_with_session(mock_edge):
275
294
  mock_cache_config.return_value = mock_config
276
295
  mock_session_user.get.return_value = mock_user
277
296
 
278
- completion = await retriever.get_completion(
279
- "test query", session_id="test_session", context_extension_rounds=1
297
+ objects = await retriever.get_retrieved_objects("test_query")
298
+ context = await retriever.get_context_from_objects("test query", objects)
299
+ completion = await retriever.get_completion_from_context(
300
+ "test query", objects, context=context
280
301
  )
281
302
 
282
303
  assert isinstance(completion, list)
@@ -292,7 +313,9 @@ async def test_get_completion_with_save_interaction(mock_edge):
292
313
  mock_graph_engine.is_empty = AsyncMock(return_value=False)
293
314
  mock_graph_engine.add_edges = AsyncMock()
294
315
 
295
- retriever = GraphCompletionContextExtensionRetriever(save_interaction=True)
316
+ retriever = GraphCompletionContextExtensionRetriever(
317
+ context_extension_rounds=1, save_interaction=True
318
+ )
296
319
 
297
320
  mock_node1 = MagicMock()
298
321
  mock_node2 = MagicMock()
@@ -304,7 +327,9 @@ async def test_get_completion_with_save_interaction(mock_edge):
304
327
  "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
305
328
  return_value=mock_graph_engine,
306
329
  ),
307
- patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]),
330
+ patch.object(
331
+ retriever, "get_context_from_objects", new_callable=AsyncMock, return_value="mock_edge"
332
+ ),
308
333
  patch(
309
334
  "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
310
335
  return_value="Resolved context",
@@ -334,8 +359,9 @@ async def test_get_completion_with_save_interaction(mock_edge):
334
359
  mock_config.caching = False
335
360
  mock_cache_config.return_value = mock_config
336
361
 
337
- completion = await retriever.get_completion(
338
- "test query", context=[mock_edge], context_extension_rounds=1
362
+ context = await retriever.get_context_from_objects("test query", [mock_edge])
363
+ completion = await retriever.get_completion_from_context(
364
+ "test query", [mock_edge], context=context
339
365
  )
340
366
 
341
367
  assert isinstance(completion, list)
@@ -354,14 +380,16 @@ async def test_get_completion_with_response_model(mock_edge):
354
380
  mock_graph_engine = AsyncMock()
355
381
  mock_graph_engine.is_empty = AsyncMock(return_value=False)
356
382
 
357
- retriever = GraphCompletionContextExtensionRetriever()
383
+ retriever = GraphCompletionContextExtensionRetriever(context_extension_rounds=1)
358
384
 
359
385
  with (
360
386
  patch(
361
387
  "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
362
388
  return_value=mock_graph_engine,
363
389
  ),
364
- patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]),
390
+ patch.object(
391
+ retriever, "get_context_from_objects", new_callable=AsyncMock, return_value=[mock_edge]
392
+ ),
365
393
  patch(
366
394
  "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
367
395
  return_value="Resolved context",
@@ -381,8 +409,10 @@ async def test_get_completion_with_response_model(mock_edge):
381
409
  mock_config.caching = False
382
410
  mock_cache_config.return_value = mock_config
383
411
 
384
- completion = await retriever.get_completion(
385
- "test query", response_model=TestModel, context_extension_rounds=1
412
+ objects = await retriever.get_retrieved_objects("test_query")
413
+ context = await retriever.get_context_from_objects("test query", objects)
414
+ completion = await retriever.get_completion_from_context(
415
+ "test query", objects, context=context
386
416
  )
387
417
 
388
418
  assert isinstance(completion, list)
@@ -396,14 +426,16 @@ async def test_get_completion_with_session_no_user_id(mock_edge):
396
426
  mock_graph_engine = AsyncMock()
397
427
  mock_graph_engine.is_empty = AsyncMock(return_value=False)
398
428
 
399
- retriever = GraphCompletionContextExtensionRetriever()
429
+ retriever = GraphCompletionContextExtensionRetriever(context_extension_rounds=1)
400
430
 
401
431
  with (
402
432
  patch(
403
433
  "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
404
434
  return_value=mock_graph_engine,
405
435
  ),
406
- patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]),
436
+ patch.object(
437
+ retriever, "get_context_from_objects", new_callable=AsyncMock, return_value=[mock_edge]
438
+ ),
407
439
  patch(
408
440
  "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
409
441
  return_value="Resolved context",
@@ -427,7 +459,11 @@ async def test_get_completion_with_session_no_user_id(mock_edge):
427
459
  mock_cache_config.return_value = mock_config
428
460
  mock_session_user.get.return_value = None # No user
429
461
 
430
- completion = await retriever.get_completion("test query", context_extension_rounds=1)
462
+ objects = await retriever.get_retrieved_objects("test_query")
463
+ context = await retriever.get_context_from_objects("test query", objects)
464
+ completion = await retriever.get_completion_from_context(
465
+ "test query", objects, context=context
466
+ )
431
467
 
432
468
  assert isinstance(completion, list)
433
469
  assert len(completion) == 1
@@ -439,14 +475,16 @@ async def test_get_completion_zero_extension_rounds(mock_edge):
439
475
  mock_graph_engine = AsyncMock()
440
476
  mock_graph_engine.is_empty = AsyncMock(return_value=False)
441
477
 
442
- retriever = GraphCompletionContextExtensionRetriever()
478
+ retriever = GraphCompletionContextExtensionRetriever(context_extension_rounds=0)
443
479
 
444
480
  with (
445
481
  patch(
446
482
  "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
447
483
  return_value=mock_graph_engine,
448
484
  ),
449
- patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]),
485
+ patch.object(
486
+ retriever, "get_context_from_objects", new_callable=AsyncMock, return_value=[mock_edge]
487
+ ),
450
488
  patch(
451
489
  "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
452
490
  return_value="Resolved context",
@@ -462,8 +500,7 @@ async def test_get_completion_zero_extension_rounds(mock_edge):
462
500
  mock_config = MagicMock()
463
501
  mock_config.caching = False
464
502
  mock_cache_config.return_value = mock_config
503
+ context = await retriever.get_context_from_objects("test query", None)
465
504
 
466
- completion = await retriever.get_completion("test query", context_extension_rounds=0)
467
-
468
- assert isinstance(completion, list)
469
- assert len(completion) == 1
505
+ assert isinstance(context, list)
506
+ assert len(context) == 1
@@ -2,6 +2,7 @@ import pytest
2
2
  from unittest.mock import AsyncMock, patch, MagicMock
3
3
  from uuid import UUID
4
4
 
5
+ from cognee.exceptions import CogneeValidationError
5
6
  from cognee.modules.retrieval.graph_completion_cot_retriever import (
6
7
  GraphCompletionCotRetriever,
7
8
  _as_answer_text,
@@ -68,7 +69,7 @@ async def test_init_defaults():
68
69
  @pytest.mark.asyncio
69
70
  async def test_run_cot_completion_round_zero_with_context(mock_edge):
70
71
  """Test _run_cot_completion round 0 with provided context."""
71
- retriever = GraphCompletionCotRetriever()
72
+ retriever = GraphCompletionCotRetriever(max_iter=1)
72
73
 
73
74
  with (
74
75
  patch(
@@ -79,7 +80,9 @@ async def test_run_cot_completion_round_zero_with_context(mock_edge):
79
80
  "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
80
81
  return_value="Generated answer",
81
82
  ),
82
- patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]),
83
+ patch.object(
84
+ retriever, "get_context_from_objects", new_callable=AsyncMock, return_value="mock_edge"
85
+ ),
83
86
  patch(
84
87
  "cognee.modules.retrieval.graph_completion_cot_retriever._as_answer_text",
85
88
  return_value="Generated answer",
@@ -92,6 +95,7 @@ async def test_run_cot_completion_round_zero_with_context(mock_edge):
92
95
  "cognee.modules.retrieval.graph_completion_cot_retriever.read_query_prompt",
93
96
  return_value="System prompt",
94
97
  ),
98
+ patch.object(retriever, "get_triplets", new_callable=AsyncMock, return_value=[mock_edge]),
95
99
  patch.object(
96
100
  LLMGateway,
97
101
  "acreate_structured_output",
@@ -101,8 +105,6 @@ async def test_run_cot_completion_round_zero_with_context(mock_edge):
101
105
  ):
102
106
  completion, context_text, triplets = await retriever._run_cot_completion(
103
107
  query="test query",
104
- context=[mock_edge],
105
- max_iter=1,
106
108
  )
107
109
 
108
110
  assert completion == "Generated answer"
@@ -116,7 +118,7 @@ async def test_run_cot_completion_round_zero_without_context(mock_edge):
116
118
  mock_graph_engine = AsyncMock()
117
119
  mock_graph_engine.is_empty = AsyncMock(return_value=False)
118
120
 
119
- retriever = GraphCompletionCotRetriever()
121
+ retriever = GraphCompletionCotRetriever(max_iter=1)
120
122
 
121
123
  with (
122
124
  patch(
@@ -131,16 +133,13 @@ async def test_run_cot_completion_round_zero_without_context(mock_edge):
131
133
  "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
132
134
  return_value="Resolved context",
133
135
  ),
136
+ patch.object(retriever, "get_triplets", new_callable=AsyncMock, return_value=[mock_edge]),
134
137
  patch(
135
138
  "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
136
139
  return_value="Generated answer",
137
140
  ),
138
141
  ):
139
- completion, context_text, triplets = await retriever._run_cot_completion(
140
- query="test query",
141
- context=None,
142
- max_iter=1,
143
- )
142
+ completion, context_text, triplets = await retriever._run_cot_completion(query="test query")
144
143
 
145
144
  assert completion == "Generated answer"
146
145
  assert context_text == "Resolved context"
@@ -150,7 +149,7 @@ async def test_run_cot_completion_round_zero_without_context(mock_edge):
150
149
  @pytest.mark.asyncio
151
150
  async def test_run_cot_completion_multiple_rounds(mock_edge):
152
151
  """Test _run_cot_completion with multiple rounds."""
153
- retriever = GraphCompletionCotRetriever()
152
+ retriever = GraphCompletionCotRetriever(max_iter=2)
154
153
 
155
154
  mock_edge2 = MagicMock(spec=Edge)
156
155
 
@@ -165,7 +164,7 @@ async def test_run_cot_completion_multiple_rounds(mock_edge):
165
164
  ),
166
165
  patch.object(
167
166
  retriever,
168
- "get_context",
167
+ "get_retrieved_objects",
169
168
  new_callable=AsyncMock,
170
169
  side_effect=[[mock_edge], [mock_edge2]],
171
170
  ),
@@ -192,12 +191,9 @@ async def test_run_cot_completion_multiple_rounds(mock_edge):
192
191
  "cognee.modules.retrieval.graph_completion_cot_retriever._as_answer_text",
193
192
  return_value="Generated answer",
194
193
  ),
194
+ patch.object(retriever, "get_triplets", new_callable=AsyncMock, return_value=[mock_edge]),
195
195
  ):
196
- completion, context_text, triplets = await retriever._run_cot_completion(
197
- query="test query",
198
- context=[mock_edge],
199
- max_iter=2,
200
- )
196
+ completion, context_text, triplets = await retriever._run_cot_completion(query="test query")
201
197
 
202
198
  assert completion == "Generated answer"
203
199
  assert context_text == "Resolved context"
@@ -207,7 +203,7 @@ async def test_run_cot_completion_multiple_rounds(mock_edge):
207
203
  @pytest.mark.asyncio
208
204
  async def test_run_cot_completion_with_conversation_history(mock_edge):
209
205
  """Test _run_cot_completion with conversation history."""
210
- retriever = GraphCompletionCotRetriever()
206
+ retriever = GraphCompletionCotRetriever(max_iter=1)
211
207
 
212
208
  with (
213
209
  patch(
@@ -218,12 +214,11 @@ async def test_run_cot_completion_with_conversation_history(mock_edge):
218
214
  "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
219
215
  return_value="Generated answer",
220
216
  ) as mock_generate,
217
+ patch.object(retriever, "get_triplets", new_callable=AsyncMock, return_value=[mock_edge]),
221
218
  ):
222
219
  completion, context_text, triplets = await retriever._run_cot_completion(
223
220
  query="test query",
224
- context=[mock_edge],
225
221
  conversation_history="Previous conversation",
226
- max_iter=1,
227
222
  )
228
223
 
229
224
  assert completion == "Generated answer"
@@ -239,7 +234,7 @@ async def test_run_cot_completion_with_response_model(mock_edge):
239
234
  class TestModel(BaseModel):
240
235
  answer: str
241
236
 
242
- retriever = GraphCompletionCotRetriever()
237
+ retriever = GraphCompletionCotRetriever(response_model=TestModel, max_iter=1)
243
238
 
244
239
  with (
245
240
  patch(
@@ -250,13 +245,9 @@ async def test_run_cot_completion_with_response_model(mock_edge):
250
245
  "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
251
246
  return_value=TestModel(answer="Test answer"),
252
247
  ),
248
+ patch.object(retriever, "get_triplets", new_callable=AsyncMock, return_value=[mock_edge]),
253
249
  ):
254
- completion, context_text, triplets = await retriever._run_cot_completion(
255
- query="test query",
256
- context=[mock_edge],
257
- response_model=TestModel,
258
- max_iter=1,
259
- )
250
+ completion, context_text, triplets = await retriever._run_cot_completion(query="test query")
260
251
 
261
252
  assert isinstance(completion, TestModel)
262
253
  assert completion.answer == "Test answer"
@@ -265,7 +256,7 @@ async def test_run_cot_completion_with_response_model(mock_edge):
265
256
  @pytest.mark.asyncio
266
257
  async def test_run_cot_completion_empty_conversation_history(mock_edge):
267
258
  """Test _run_cot_completion with empty conversation history."""
268
- retriever = GraphCompletionCotRetriever()
259
+ retriever = GraphCompletionCotRetriever(max_iter=1)
269
260
 
270
261
  with (
271
262
  patch(
@@ -276,12 +267,11 @@ async def test_run_cot_completion_empty_conversation_history(mock_edge):
276
267
  "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
277
268
  return_value="Generated answer",
278
269
  ) as mock_generate,
270
+ patch.object(retriever, "get_triplets", new_callable=AsyncMock, return_value=[mock_edge]),
279
271
  ):
280
272
  completion, context_text, triplets = await retriever._run_cot_completion(
281
273
  query="test query",
282
- context=[mock_edge],
283
274
  conversation_history="",
284
- max_iter=1,
285
275
  )
286
276
 
287
277
  assert completion == "Generated answer"
@@ -296,7 +286,7 @@ async def test_get_completion_without_context(mock_edge):
296
286
  mock_graph_engine = AsyncMock()
297
287
  mock_graph_engine.is_empty = AsyncMock(return_value=False)
298
288
 
299
- retriever = GraphCompletionCotRetriever()
289
+ retriever = GraphCompletionCotRetriever(max_iter=1)
300
290
 
301
291
  with (
302
292
  patch(
@@ -315,7 +305,9 @@ async def test_get_completion_without_context(mock_edge):
315
305
  "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
316
306
  return_value="Generated answer",
317
307
  ),
318
- patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]),
308
+ patch.object(
309
+ retriever, "get_context_from_objects", new_callable=AsyncMock, return_value=[mock_edge]
310
+ ),
319
311
  patch(
320
312
  "cognee.modules.retrieval.graph_completion_cot_retriever._as_answer_text",
321
313
  return_value="Generated answer",
@@ -342,7 +334,13 @@ async def test_get_completion_without_context(mock_edge):
342
334
  mock_config.caching = False
343
335
  mock_cache_config.return_value = mock_config
344
336
 
345
- completion = await retriever.get_completion("test query", max_iter=1)
337
+ mock_edge = MagicMock()
338
+
339
+ objects = await retriever.get_retrieved_objects("test query")
340
+ context = await retriever.get_context_from_objects("test query", objects)
341
+ completion = await retriever.get_completion_from_context(
342
+ "test query", [mock_edge], context=context
343
+ )
346
344
 
347
345
  assert isinstance(completion, list)
348
346
  assert len(completion) == 1
@@ -352,7 +350,7 @@ async def test_get_completion_without_context(mock_edge):
352
350
  @pytest.mark.asyncio
353
351
  async def test_get_completion_with_provided_context(mock_edge):
354
352
  """Test get_completion uses provided context."""
355
- retriever = GraphCompletionCotRetriever()
353
+ retriever = GraphCompletionCotRetriever(max_iter=1)
356
354
 
357
355
  with (
358
356
  patch(
@@ -371,7 +369,13 @@ async def test_get_completion_with_provided_context(mock_edge):
371
369
  mock_config.caching = False
372
370
  mock_cache_config.return_value = mock_config
373
371
 
374
- completion = await retriever.get_completion("test query", context=[mock_edge], max_iter=1)
372
+ mock_edge = MagicMock()
373
+
374
+ objects = await retriever.get_retrieved_objects("test query")
375
+ await retriever.get_context_from_objects("test query", objects)
376
+ completion = await retriever.get_completion_from_context(
377
+ "test query", [mock_edge], context="test"
378
+ )
375
379
 
376
380
  assert isinstance(completion, list)
377
381
  assert len(completion) == 1
@@ -384,7 +388,7 @@ async def test_get_completion_with_session(mock_edge):
384
388
  mock_graph_engine = AsyncMock()
385
389
  mock_graph_engine.is_empty = AsyncMock(return_value=False)
386
390
 
387
- retriever = GraphCompletionCotRetriever()
391
+ retriever = GraphCompletionCotRetriever(session_id="test_session", max_iter=1)
388
392
 
389
393
  mock_user = MagicMock()
390
394
  mock_user.id = "test-user-id"
@@ -429,8 +433,9 @@ async def test_get_completion_with_session(mock_edge):
429
433
  mock_cache_config.return_value = mock_config
430
434
  mock_session_user.get.return_value = mock_user
431
435
 
432
- completion = await retriever.get_completion(
433
- "test query", session_id="test_session", max_iter=1
436
+ retrieved_objects = await retriever.get_retrieved_objects("test query")
437
+ completion = await retriever.get_completion_from_context(
438
+ "test query", retrieved_objects, context="mock_edge"
434
439
  )
435
440
 
436
441
  assert isinstance(completion, list)
@@ -446,7 +451,7 @@ async def test_get_completion_with_save_interaction(mock_edge):
446
451
  mock_graph_engine.is_empty = AsyncMock(return_value=False)
447
452
  mock_graph_engine.add_edges = AsyncMock()
448
453
 
449
- retriever = GraphCompletionCotRetriever(save_interaction=True)
454
+ retriever = GraphCompletionCotRetriever(save_interaction=True, max_iter=1)
450
455
 
451
456
  mock_node1 = MagicMock()
452
457
  mock_node2 = MagicMock()
@@ -462,7 +467,10 @@ async def test_get_completion_with_save_interaction(mock_edge):
462
467
  "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
463
468
  return_value="Generated answer",
464
469
  ),
465
- patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]),
470
+ patch.object(retriever, "get_triplets", new_callable=AsyncMock, return_value=[mock_edge]),
471
+ patch.object(
472
+ retriever, "get_context_from_objects", new_callable=AsyncMock, return_value="mock_edge"
473
+ ),
466
474
  patch(
467
475
  "cognee.modules.retrieval.graph_completion_cot_retriever._as_answer_text",
468
476
  return_value="Generated answer",
@@ -486,6 +494,8 @@ async def test_get_completion_with_save_interaction(mock_edge):
486
494
  side_effect=[
487
495
  UUID("550e8400-e29b-41d4-a716-446655440000"),
488
496
  UUID("550e8400-e29b-41d4-a716-446655440001"),
497
+ UUID("550e8400-e29b-41d4-a716-446655440002"),
498
+ UUID("550e8400-e29b-41d4-a716-446655440003"),
489
499
  ],
490
500
  ),
491
501
  patch(
@@ -500,7 +510,11 @@ async def test_get_completion_with_save_interaction(mock_edge):
500
510
  mock_cache_config.return_value = mock_config
501
511
 
502
512
  # Pass context so save_interaction condition is met
503
- completion = await retriever.get_completion("test query", context=[mock_edge], max_iter=1)
513
+ retrieved_objects = await retriever.get_retrieved_objects("test query")
514
+ context = await retriever.get_context_from_objects("test query", retrieved_objects)
515
+ completion = await retriever.get_completion_from_context(
516
+ "test query", [mock_edge], context=context
517
+ )
504
518
 
505
519
  assert isinstance(completion, list)
506
520
  assert len(completion) == 1
@@ -518,7 +532,7 @@ async def test_get_completion_with_response_model(mock_edge):
518
532
  mock_graph_engine = AsyncMock()
519
533
  mock_graph_engine.is_empty = AsyncMock(return_value=False)
520
534
 
521
- retriever = GraphCompletionCotRetriever()
535
+ retriever = GraphCompletionCotRetriever(response_model=TestModel, max_iter=1)
522
536
 
523
537
  with (
524
538
  patch(
@@ -545,8 +559,12 @@ async def test_get_completion_with_response_model(mock_edge):
545
559
  mock_config.caching = False
546
560
  mock_cache_config.return_value = mock_config
547
561
 
548
- completion = await retriever.get_completion(
549
- "test query", response_model=TestModel, max_iter=1
562
+ mock_edge = MagicMock()
563
+
564
+ objects = await retriever.get_retrieved_objects("test query")
565
+ await retriever.get_context_from_objects("test query", objects)
566
+ completion = await retriever.get_completion_from_context(
567
+ "test query", [mock_edge], "mock_edge"
550
568
  )
551
569
 
552
570
  assert isinstance(completion, list)
@@ -560,7 +578,7 @@ async def test_get_completion_with_session_no_user_id(mock_edge):
560
578
  mock_graph_engine = AsyncMock()
561
579
  mock_graph_engine.is_empty = AsyncMock(return_value=False)
562
580
 
563
- retriever = GraphCompletionCotRetriever()
581
+ retriever = GraphCompletionCotRetriever(max_iter=1)
564
582
 
565
583
  with (
566
584
  patch(
@@ -591,7 +609,9 @@ async def test_get_completion_with_session_no_user_id(mock_edge):
591
609
  mock_cache_config.return_value = mock_config
592
610
  mock_session_user.get.return_value = None # No user
593
611
 
594
- completion = await retriever.get_completion("test query", max_iter=1)
612
+ completion = await retriever.get_completion_from_context(
613
+ "test query", [mock_edge], context="mock_edge"
614
+ )
595
615
 
596
616
  assert isinstance(completion, list)
597
617
  assert len(completion) == 1
@@ -600,7 +620,7 @@ async def test_get_completion_with_session_no_user_id(mock_edge):
600
620
  @pytest.mark.asyncio
601
621
  async def test_get_completion_with_save_interaction_no_context(mock_edge):
602
622
  """Test get_completion with save_interaction but no context provided."""
603
- retriever = GraphCompletionCotRetriever(save_interaction=True)
623
+ retriever = GraphCompletionCotRetriever(save_interaction=True, max_iter=1)
604
624
 
605
625
  with (
606
626
  patch(
@@ -611,7 +631,9 @@ async def test_get_completion_with_save_interaction_no_context(mock_edge):
611
631
  "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
612
632
  return_value="Generated answer",
613
633
  ),
614
- patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]),
634
+ patch.object(
635
+ retriever, "get_retrieved_objects", new_callable=AsyncMock, return_value=[mock_edge]
636
+ ),
615
637
  patch(
616
638
  "cognee.modules.retrieval.graph_completion_cot_retriever._as_answer_text",
617
639
  return_value="Generated answer",
@@ -638,10 +660,8 @@ async def test_get_completion_with_save_interaction_no_context(mock_edge):
638
660
  mock_config.caching = False
639
661
  mock_cache_config.return_value = mock_config
640
662
 
641
- completion = await retriever.get_completion("test query", context=None, max_iter=1)
642
-
643
- assert isinstance(completion, list)
644
- assert len(completion) == 1
663
+ with pytest.raises(CogneeValidationError):
664
+ await retriever.get_completion_from_context("test query", None, context=None)
645
665
 
646
666
 
647
667
  @pytest.mark.asyncio