cognee 0.5.1__py3-none-any.whl → 0.5.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (265)
  1. cognee/__init__.py +2 -0
  2. cognee/alembic/README +1 -0
  3. cognee/alembic/env.py +107 -0
  4. cognee/alembic/script.py.mako +26 -0
  5. cognee/alembic/versions/1a58b986e6e1_enable_delete_for_old_tutorial_notebooks.py +52 -0
  6. cognee/alembic/versions/1d0bb7fede17_add_pipeline_run_status.py +33 -0
  7. cognee/alembic/versions/1daae0df1866_incremental_loading.py +48 -0
  8. cognee/alembic/versions/211ab850ef3d_add_sync_operations_table.py +118 -0
  9. cognee/alembic/versions/45957f0a9849_add_notebook_table.py +46 -0
  10. cognee/alembic/versions/46a6ce2bd2b2_expand_dataset_database_with_json_.py +333 -0
  11. cognee/alembic/versions/482cd6517ce4_add_default_user.py +30 -0
  12. cognee/alembic/versions/76625596c5c3_expand_dataset_database_for_multi_user.py +98 -0
  13. cognee/alembic/versions/8057ae7329c2_initial_migration.py +25 -0
  14. cognee/alembic/versions/9e7a3cb85175_loader_separation.py +104 -0
  15. cognee/alembic/versions/a1b2c3d4e5f6_add_label_column_to_data.py +38 -0
  16. cognee/alembic/versions/ab7e313804ae_permission_system_rework.py +236 -0
  17. cognee/alembic/versions/b9274c27a25a_kuzu_11_migration.py +75 -0
  18. cognee/alembic/versions/c946955da633_multi_tenant_support.py +137 -0
  19. cognee/alembic/versions/e1ec1dcb50b6_add_last_accessed_to_data.py +51 -0
  20. cognee/alembic/versions/e4ebee1091e7_expand_data_model_info.py +140 -0
  21. cognee/alembic.ini +117 -0
  22. cognee/api/v1/add/add.py +2 -1
  23. cognee/api/v1/add/routers/get_add_router.py +2 -0
  24. cognee/api/v1/cognify/cognify.py +11 -6
  25. cognee/api/v1/cognify/routers/get_cognify_router.py +8 -0
  26. cognee/api/v1/config/config.py +60 -0
  27. cognee/api/v1/datasets/routers/get_datasets_router.py +46 -3
  28. cognee/api/v1/memify/routers/get_memify_router.py +3 -0
  29. cognee/api/v1/search/routers/get_search_router.py +21 -6
  30. cognee/api/v1/search/search.py +21 -5
  31. cognee/api/v1/sync/routers/get_sync_router.py +3 -3
  32. cognee/cli/commands/add_command.py +1 -1
  33. cognee/cli/commands/cognify_command.py +6 -0
  34. cognee/cli/commands/config_command.py +1 -1
  35. cognee/context_global_variables.py +5 -1
  36. cognee/eval_framework/answer_generation/answer_generation_executor.py +7 -8
  37. cognee/infrastructure/databases/cache/cache_db_interface.py +38 -1
  38. cognee/infrastructure/databases/cache/config.py +6 -0
  39. cognee/infrastructure/databases/cache/fscache/FsCacheAdapter.py +21 -0
  40. cognee/infrastructure/databases/cache/get_cache_engine.py +9 -3
  41. cognee/infrastructure/databases/cache/redis/RedisAdapter.py +60 -1
  42. cognee/infrastructure/databases/dataset_database_handler/supported_dataset_database_handlers.py +7 -0
  43. cognee/infrastructure/databases/graph/get_graph_engine.py +29 -1
  44. cognee/infrastructure/databases/graph/neo4j_driver/Neo4jAuraDevDatasetDatabaseHandler.py +62 -27
  45. cognee/infrastructure/databases/hybrid/neptune_analytics/NeptuneAnalyticsAdapter.py +17 -4
  46. cognee/infrastructure/databases/relational/config.py +16 -1
  47. cognee/infrastructure/databases/relational/create_relational_engine.py +13 -3
  48. cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py +26 -3
  49. cognee/infrastructure/databases/vector/chromadb/ChromaDBAdapter.py +2 -0
  50. cognee/infrastructure/databases/vector/config.py +6 -0
  51. cognee/infrastructure/databases/vector/create_vector_engine.py +70 -16
  52. cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py +64 -9
  53. cognee/infrastructure/databases/vector/embeddings/OllamaEmbeddingEngine.py +13 -2
  54. cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py +16 -3
  55. cognee/infrastructure/databases/vector/models/ScoredResult.py +3 -3
  56. cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py +16 -3
  57. cognee/infrastructure/databases/vector/pgvector/PGVectorDatasetDatabaseHandler.py +86 -0
  58. cognee/infrastructure/databases/vector/pgvector/create_db_and_tables.py +81 -2
  59. cognee/infrastructure/databases/vector/vector_db_interface.py +8 -0
  60. cognee/infrastructure/files/utils/get_data_file_path.py +33 -27
  61. cognee/infrastructure/llm/LLMGateway.py +0 -13
  62. cognee/infrastructure/llm/prompts/extract_query_time.txt +1 -1
  63. cognee/infrastructure/llm/prompts/generate_event_entity_prompt.txt +1 -1
  64. cognee/infrastructure/llm/prompts/generate_event_graph_prompt.txt +1 -1
  65. cognee/infrastructure/llm/prompts/generate_graph_prompt.txt +2 -2
  66. cognee/infrastructure/llm/prompts/generate_graph_prompt_guided.txt +1 -1
  67. cognee/infrastructure/llm/prompts/generate_graph_prompt_oneshot.txt +2 -2
  68. cognee/infrastructure/llm/prompts/generate_graph_prompt_simple.txt +1 -1
  69. cognee/infrastructure/llm/prompts/generate_graph_prompt_strict.txt +1 -1
  70. cognee/infrastructure/llm/prompts/search_type_selector_prompt.txt +6 -6
  71. cognee/infrastructure/llm/prompts/test.txt +1 -1
  72. cognee/infrastructure/llm/prompts/translate_content.txt +19 -0
  73. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py +17 -12
  74. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py +31 -25
  75. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py +132 -7
  76. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py +29 -5
  77. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/llama_cpp/adapter.py +191 -0
  78. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/llm_interface.py +2 -6
  79. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py +58 -13
  80. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/adapter.py +0 -1
  81. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py +25 -131
  82. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/types.py +10 -0
  83. cognee/modules/chunking/models/DocumentChunk.py +0 -1
  84. cognee/modules/cognify/config.py +2 -0
  85. cognee/modules/data/models/Data.py +3 -1
  86. cognee/modules/engine/models/Entity.py +0 -1
  87. cognee/modules/engine/operations/setup.py +6 -0
  88. cognee/modules/graph/cognee_graph/CogneeGraph.py +150 -37
  89. cognee/modules/graph/cognee_graph/CogneeGraphElements.py +48 -2
  90. cognee/modules/graph/utils/__init__.py +1 -0
  91. cognee/modules/graph/utils/get_entity_nodes_from_triplets.py +12 -0
  92. cognee/modules/notebooks/methods/__init__.py +1 -0
  93. cognee/modules/notebooks/methods/create_notebook.py +0 -34
  94. cognee/modules/notebooks/methods/create_tutorial_notebooks.py +191 -0
  95. cognee/modules/notebooks/methods/get_notebooks.py +12 -8
  96. cognee/modules/notebooks/tutorials/cognee-basics/cell-1.md +3 -0
  97. cognee/modules/notebooks/tutorials/cognee-basics/cell-2.md +10 -0
  98. cognee/modules/notebooks/tutorials/cognee-basics/cell-3.md +7 -0
  99. cognee/modules/notebooks/tutorials/cognee-basics/cell-4.py +28 -0
  100. cognee/modules/notebooks/tutorials/cognee-basics/cell-5.py +3 -0
  101. cognee/modules/notebooks/tutorials/cognee-basics/cell-6.py +9 -0
  102. cognee/modules/notebooks/tutorials/cognee-basics/cell-7.py +17 -0
  103. cognee/modules/notebooks/tutorials/cognee-basics/config.json +4 -0
  104. cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-1.md +3 -0
  105. cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-10.md +3 -0
  106. cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-11.md +3 -0
  107. cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-12.py +3 -0
  108. cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-13.md +7 -0
  109. cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-14.py +6 -0
  110. cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-15.md +3 -0
  111. cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-16.py +7 -0
  112. cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-2.md +9 -0
  113. cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-3.md +7 -0
  114. cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-4.md +9 -0
  115. cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-5.md +5 -0
  116. cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-6.py +13 -0
  117. cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-7.md +3 -0
  118. cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-8.md +3 -0
  119. cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-9.py +31 -0
  120. cognee/modules/notebooks/tutorials/python-development-with-cognee/config.json +4 -0
  121. cognee/modules/notebooks/tutorials/python-development-with-cognee/data/copilot_conversations.json +107 -0
  122. cognee/modules/notebooks/tutorials/python-development-with-cognee/data/guido_contributions.json +976 -0
  123. cognee/modules/notebooks/tutorials/python-development-with-cognee/data/my_developer_rules.md +79 -0
  124. cognee/modules/notebooks/tutorials/python-development-with-cognee/data/pep_style_guide.md +74 -0
  125. cognee/modules/notebooks/tutorials/python-development-with-cognee/data/zen_principles.md +74 -0
  126. cognee/modules/retrieval/EntityCompletionRetriever.py +51 -38
  127. cognee/modules/retrieval/__init__.py +0 -1
  128. cognee/modules/retrieval/base_retriever.py +66 -10
  129. cognee/modules/retrieval/chunks_retriever.py +57 -49
  130. cognee/modules/retrieval/coding_rules_retriever.py +12 -5
  131. cognee/modules/retrieval/completion_retriever.py +29 -28
  132. cognee/modules/retrieval/cypher_search_retriever.py +25 -20
  133. cognee/modules/retrieval/graph_completion_context_extension_retriever.py +42 -46
  134. cognee/modules/retrieval/graph_completion_cot_retriever.py +68 -51
  135. cognee/modules/retrieval/graph_completion_retriever.py +78 -63
  136. cognee/modules/retrieval/graph_summary_completion_retriever.py +2 -0
  137. cognee/modules/retrieval/lexical_retriever.py +34 -12
  138. cognee/modules/retrieval/natural_language_retriever.py +18 -15
  139. cognee/modules/retrieval/summaries_retriever.py +51 -34
  140. cognee/modules/retrieval/temporal_retriever.py +59 -49
  141. cognee/modules/retrieval/triplet_retriever.py +32 -33
  142. cognee/modules/retrieval/utils/access_tracking.py +88 -0
  143. cognee/modules/retrieval/utils/brute_force_triplet_search.py +99 -103
  144. cognee/modules/retrieval/utils/node_edge_vector_search.py +174 -0
  145. cognee/modules/search/methods/__init__.py +1 -0
  146. cognee/modules/search/methods/get_retriever_output.py +53 -0
  147. cognee/modules/search/methods/get_search_type_retriever_instance.py +252 -0
  148. cognee/modules/search/methods/search.py +90 -222
  149. cognee/modules/search/models/SearchResultPayload.py +67 -0
  150. cognee/modules/search/types/SearchResult.py +1 -8
  151. cognee/modules/search/types/SearchType.py +1 -2
  152. cognee/modules/search/types/__init__.py +1 -1
  153. cognee/modules/search/utils/__init__.py +1 -2
  154. cognee/modules/search/utils/transform_insights_to_graph.py +2 -2
  155. cognee/modules/search/utils/{transform_context_to_graph.py → transform_triplets_to_graph.py} +2 -2
  156. cognee/modules/users/authentication/default/default_transport.py +11 -1
  157. cognee/modules/users/authentication/get_api_auth_backend.py +2 -1
  158. cognee/modules/users/authentication/get_client_auth_backend.py +2 -1
  159. cognee/modules/users/methods/create_user.py +0 -9
  160. cognee/modules/users/permissions/methods/has_user_management_permission.py +29 -0
  161. cognee/modules/visualization/cognee_network_visualization.py +1 -1
  162. cognee/run_migrations.py +48 -0
  163. cognee/shared/exceptions/__init__.py +1 -3
  164. cognee/shared/exceptions/exceptions.py +11 -1
  165. cognee/shared/usage_logger.py +332 -0
  166. cognee/shared/utils.py +12 -5
  167. cognee/tasks/chunks/__init__.py +9 -0
  168. cognee/tasks/cleanup/cleanup_unused_data.py +172 -0
  169. cognee/tasks/graph/__init__.py +7 -0
  170. cognee/tasks/ingestion/data_item.py +8 -0
  171. cognee/tasks/ingestion/ingest_data.py +12 -1
  172. cognee/tasks/ingestion/save_data_item_to_storage.py +5 -0
  173. cognee/tasks/memify/__init__.py +8 -0
  174. cognee/tasks/memify/extract_usage_frequency.py +613 -0
  175. cognee/tasks/summarization/models.py +0 -2
  176. cognee/tasks/temporal_graph/__init__.py +0 -1
  177. cognee/tasks/translation/__init__.py +96 -0
  178. cognee/tasks/translation/config.py +110 -0
  179. cognee/tasks/translation/detect_language.py +190 -0
  180. cognee/tasks/translation/exceptions.py +62 -0
  181. cognee/tasks/translation/models.py +72 -0
  182. cognee/tasks/translation/providers/__init__.py +44 -0
  183. cognee/tasks/translation/providers/azure_provider.py +192 -0
  184. cognee/tasks/translation/providers/base.py +85 -0
  185. cognee/tasks/translation/providers/google_provider.py +158 -0
  186. cognee/tasks/translation/providers/llm_provider.py +143 -0
  187. cognee/tasks/translation/translate_content.py +282 -0
  188. cognee/tasks/web_scraper/default_url_crawler.py +6 -2
  189. cognee/tests/cli_tests/cli_unit_tests/test_cli_commands.py +1 -0
  190. cognee/tests/cli_tests/cli_unit_tests/test_cli_edge_cases.py +3 -0
  191. cognee/tests/integration/retrieval/test_brute_force_triplet_search_with_cognify.py +62 -0
  192. cognee/tests/integration/retrieval/test_chunks_retriever.py +351 -0
  193. cognee/tests/integration/retrieval/test_graph_completion_retriever.py +276 -0
  194. cognee/tests/integration/retrieval/test_graph_completion_retriever_context_extension.py +228 -0
  195. cognee/tests/integration/retrieval/test_graph_completion_retriever_cot.py +217 -0
  196. cognee/tests/integration/retrieval/test_rag_completion_retriever.py +319 -0
  197. cognee/tests/integration/retrieval/test_structured_output.py +258 -0
  198. cognee/tests/integration/retrieval/test_summaries_retriever.py +195 -0
  199. cognee/tests/integration/retrieval/test_temporal_retriever.py +336 -0
  200. cognee/tests/integration/retrieval/test_triplet_retriever.py +45 -1
  201. cognee/tests/integration/shared/test_usage_logger_integration.py +255 -0
  202. cognee/tests/tasks/translation/README.md +147 -0
  203. cognee/tests/tasks/translation/__init__.py +1 -0
  204. cognee/tests/tasks/translation/config_test.py +93 -0
  205. cognee/tests/tasks/translation/detect_language_test.py +118 -0
  206. cognee/tests/tasks/translation/providers_test.py +151 -0
  207. cognee/tests/tasks/translation/translate_content_test.py +213 -0
  208. cognee/tests/test_chromadb.py +1 -1
  209. cognee/tests/test_cleanup_unused_data.py +165 -0
  210. cognee/tests/test_custom_data_label.py +68 -0
  211. cognee/tests/test_delete_by_id.py +6 -6
  212. cognee/tests/test_extract_usage_frequency.py +308 -0
  213. cognee/tests/test_kuzu.py +17 -7
  214. cognee/tests/test_lancedb.py +3 -1
  215. cognee/tests/test_library.py +1 -1
  216. cognee/tests/test_neo4j.py +17 -7
  217. cognee/tests/test_neptune_analytics_vector.py +3 -1
  218. cognee/tests/test_permissions.py +172 -187
  219. cognee/tests/test_pgvector.py +3 -1
  220. cognee/tests/test_relational_db_migration.py +15 -1
  221. cognee/tests/test_remote_kuzu.py +3 -1
  222. cognee/tests/test_s3_file_storage.py +1 -1
  223. cognee/tests/test_search_db.py +345 -205
  224. cognee/tests/test_usage_logger_e2e.py +268 -0
  225. cognee/tests/unit/api/test_get_raw_data_endpoint.py +206 -0
  226. cognee/tests/unit/eval_framework/answer_generation_test.py +4 -3
  227. cognee/tests/unit/eval_framework/benchmark_adapters_test.py +25 -0
  228. cognee/tests/unit/eval_framework/corpus_builder_test.py +33 -4
  229. cognee/tests/unit/infrastructure/databases/cache/test_cache_config.py +2 -0
  230. cognee/tests/unit/infrastructure/databases/relational/test_RelationalConfig.py +69 -0
  231. cognee/tests/unit/modules/graph/cognee_graph_elements_test.py +42 -2
  232. cognee/tests/unit/modules/graph/cognee_graph_test.py +329 -31
  233. cognee/tests/unit/modules/retrieval/chunks_retriever_test.py +122 -168
  234. cognee/tests/unit/modules/retrieval/conversation_history_test.py +338 -0
  235. cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py +486 -157
  236. cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py +693 -155
  237. cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py +619 -200
  238. cognee/tests/unit/modules/retrieval/rag_completion_retriever_test.py +300 -171
  239. cognee/tests/unit/modules/retrieval/summaries_retriever_test.py +184 -155
  240. cognee/tests/unit/modules/retrieval/temporal_retriever_test.py +544 -79
  241. cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py +476 -28
  242. cognee/tests/unit/modules/retrieval/test_completion.py +343 -0
  243. cognee/tests/unit/modules/retrieval/test_graph_summary_completion_retriever.py +157 -0
  244. cognee/tests/unit/modules/retrieval/test_node_edge_vector_search.py +273 -0
  245. cognee/tests/unit/modules/retrieval/test_user_qa_feedback.py +312 -0
  246. cognee/tests/unit/modules/retrieval/triplet_retriever_test.py +267 -7
  247. cognee/tests/unit/modules/search/test_get_search_type_retriever_instance.py +125 -0
  248. cognee/tests/unit/modules/search/test_search.py +96 -20
  249. cognee/tests/unit/modules/search/test_search_prepare_search_result_contract.py +190 -0
  250. cognee/tests/unit/modules/users/test_tutorial_notebook_creation.py +511 -297
  251. cognee/tests/unit/shared/test_usage_logger.py +241 -0
  252. cognee/tests/unit/users/permissions/test_has_user_management_permission.py +46 -0
  253. {cognee-0.5.1.dist-info → cognee-0.5.2.dist-info}/METADATA +22 -17
  254. {cognee-0.5.1.dist-info → cognee-0.5.2.dist-info}/RECORD +258 -157
  255. cognee/api/.env.example +0 -5
  256. cognee/modules/retrieval/base_graph_retriever.py +0 -24
  257. cognee/modules/search/methods/get_search_type_tools.py +0 -223
  258. cognee/modules/search/methods/no_access_control_search.py +0 -62
  259. cognee/modules/search/utils/prepare_search_result.py +0 -63
  260. cognee/tests/test_feedback_enrichment.py +0 -174
  261. cognee/tests/unit/modules/retrieval/structured_output_test.py +0 -204
  262. {cognee-0.5.1.dist-info → cognee-0.5.2.dist-info}/WHEEL +0 -0
  263. {cognee-0.5.1.dist-info → cognee-0.5.2.dist-info}/entry_points.txt +0 -0
  264. {cognee-0.5.1.dist-info → cognee-0.5.2.dist-info}/licenses/LICENSE +0 -0
  265. {cognee-0.5.1.dist-info → cognee-0.5.2.dist-info}/licenses/NOTICE.md +0 -0
@@ -1,170 +1,708 @@
1
- import os
2
1
  import pytest
3
- import pathlib
4
- from typing import Optional, Union
5
-
6
- import cognee
7
- from cognee.low_level import setup, DataPoint
8
- from cognee.modules.graph.utils import resolve_edges_to_text
9
- from cognee.tasks.storage import add_data_points
10
- from cognee.modules.retrieval.graph_completion_cot_retriever import GraphCompletionCotRetriever
11
-
12
-
13
- class TestGraphCompletionCoTRetriever:
14
- @pytest.mark.asyncio
15
- async def test_graph_completion_cot_context_simple(self):
16
- system_directory_path = os.path.join(
17
- pathlib.Path(__file__).parent, ".cognee_system/test_graph_completion_cot_context_simple"
2
+ from unittest.mock import AsyncMock, patch, MagicMock
3
+ from uuid import UUID
4
+
5
+ from cognee.exceptions import CogneeValidationError
6
+ from cognee.modules.retrieval.graph_completion_cot_retriever import (
7
+ GraphCompletionCotRetriever,
8
+ _as_answer_text,
9
+ )
10
+ from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
11
+ from cognee.infrastructure.llm.LLMGateway import LLMGateway
12
+
13
+
14
+ @pytest.fixture
15
+ def mock_edge():
16
+ """Create a mock edge."""
17
+ edge = MagicMock(spec=Edge)
18
+ return edge
19
+
20
+
21
+ @pytest.mark.asyncio
22
+ async def test_get_triplets_inherited(mock_edge):
23
+ """Test that get_triplets is inherited from parent class."""
24
+ retriever = GraphCompletionCotRetriever()
25
+
26
+ with patch(
27
+ "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
28
+ return_value=[mock_edge],
29
+ ):
30
+ triplets = await retriever.get_triplets("test query")
31
+
32
+ assert len(triplets) == 1
33
+ assert triplets[0] == mock_edge
34
+
35
+
36
+ @pytest.mark.asyncio
37
+ async def test_init_custom_params():
38
+ """Test GraphCompletionCotRetriever initialization with custom parameters."""
39
+ retriever = GraphCompletionCotRetriever(
40
+ top_k=10,
41
+ user_prompt_path="custom_user.txt",
42
+ system_prompt_path="custom_system.txt",
43
+ validation_user_prompt_path="custom_validation_user.txt",
44
+ validation_system_prompt_path="custom_validation_system.txt",
45
+ followup_system_prompt_path="custom_followup_system.txt",
46
+ followup_user_prompt_path="custom_followup_user.txt",
47
+ )
48
+
49
+ assert retriever.top_k == 10
50
+ assert retriever.user_prompt_path == "custom_user.txt"
51
+ assert retriever.system_prompt_path == "custom_system.txt"
52
+ assert retriever.validation_user_prompt_path == "custom_validation_user.txt"
53
+ assert retriever.validation_system_prompt_path == "custom_validation_system.txt"
54
+ assert retriever.followup_system_prompt_path == "custom_followup_system.txt"
55
+ assert retriever.followup_user_prompt_path == "custom_followup_user.txt"
56
+
57
+
58
+ @pytest.mark.asyncio
59
+ async def test_init_defaults():
60
+ """Test GraphCompletionCotRetriever initialization with defaults."""
61
+ retriever = GraphCompletionCotRetriever()
62
+
63
+ assert retriever.validation_user_prompt_path == "cot_validation_user_prompt.txt"
64
+ assert retriever.validation_system_prompt_path == "cot_validation_system_prompt.txt"
65
+ assert retriever.followup_system_prompt_path == "cot_followup_system_prompt.txt"
66
+ assert retriever.followup_user_prompt_path == "cot_followup_user_prompt.txt"
67
+
68
+
69
+ @pytest.mark.asyncio
70
+ async def test_run_cot_completion_round_zero_with_context(mock_edge):
71
+ """Test _run_cot_completion round 0 with provided context."""
72
+ retriever = GraphCompletionCotRetriever(max_iter=1)
73
+
74
+ with (
75
+ patch(
76
+ "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
77
+ return_value="Resolved context",
78
+ ),
79
+ patch(
80
+ "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
81
+ return_value="Generated answer",
82
+ ),
83
+ patch.object(
84
+ retriever, "get_context_from_objects", new_callable=AsyncMock, return_value="mock_edge"
85
+ ),
86
+ patch(
87
+ "cognee.modules.retrieval.graph_completion_cot_retriever._as_answer_text",
88
+ return_value="Generated answer",
89
+ ),
90
+ patch(
91
+ "cognee.modules.retrieval.graph_completion_cot_retriever.render_prompt",
92
+ return_value="Rendered prompt",
93
+ ),
94
+ patch(
95
+ "cognee.modules.retrieval.graph_completion_cot_retriever.read_query_prompt",
96
+ return_value="System prompt",
97
+ ),
98
+ patch.object(retriever, "get_triplets", new_callable=AsyncMock, return_value=[mock_edge]),
99
+ patch.object(
100
+ LLMGateway,
101
+ "acreate_structured_output",
102
+ new_callable=AsyncMock,
103
+ side_effect=["validation_result", "followup_question"],
104
+ ),
105
+ ):
106
+ completion, context_text, triplets = await retriever._run_cot_completion(
107
+ query="test query",
18
108
  )
19
- cognee.config.system_root_directory(system_directory_path)
20
- data_directory_path = os.path.join(
21
- pathlib.Path(__file__).parent, ".data_storage/test_graph_completion_cot_context_simple"
22
- )
23
- cognee.config.data_root_directory(data_directory_path)
24
-
25
- await cognee.prune.prune_data()
26
- await cognee.prune.prune_system(metadata=True)
27
- await setup()
28
-
29
- class Company(DataPoint):
30
- name: str
31
-
32
- class Person(DataPoint):
33
- name: str
34
- works_for: Company
35
-
36
- company1 = Company(name="Figma")
37
- company2 = Company(name="Canva")
38
- person1 = Person(name="Steve Rodger", works_for=company1)
39
- person2 = Person(name="Ike Loma", works_for=company1)
40
- person3 = Person(name="Jason Statham", works_for=company1)
41
- person4 = Person(name="Mike Broski", works_for=company2)
42
- person5 = Person(name="Christina Mayer", works_for=company2)
43
-
44
- entities = [company1, company2, person1, person2, person3, person4, person5]
45
-
46
- await add_data_points(entities)
47
-
48
- retriever = GraphCompletionCotRetriever()
49
109
 
50
- context = await resolve_edges_to_text(await retriever.get_context("Who works at Canva?"))
51
-
52
- assert "Mike Broski --[works_for]--> Canva" in context, "Failed to get Mike Broski"
53
- assert "Christina Mayer --[works_for]--> Canva" in context, "Failed to get Christina Mayer"
54
-
55
- answer = await retriever.get_completion("Who works at Canva?")
56
-
57
- assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
58
- assert all(isinstance(item, str) and item.strip() for item in answer), (
59
- "Answer must contain only non-empty strings"
110
+ assert completion == "Generated answer"
111
+ assert context_text == "Resolved context"
112
+ assert len(triplets) >= 1
113
+
114
+
115
+ @pytest.mark.asyncio
116
+ async def test_run_cot_completion_round_zero_without_context(mock_edge):
117
+ """Test _run_cot_completion round 0 without provided context."""
118
+ mock_graph_engine = AsyncMock()
119
+ mock_graph_engine.is_empty = AsyncMock(return_value=False)
120
+
121
+ retriever = GraphCompletionCotRetriever(max_iter=1)
122
+
123
+ with (
124
+ patch(
125
+ "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
126
+ return_value=mock_graph_engine,
127
+ ),
128
+ patch(
129
+ "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
130
+ return_value=[mock_edge],
131
+ ),
132
+ patch(
133
+ "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
134
+ return_value="Resolved context",
135
+ ),
136
+ patch.object(retriever, "get_triplets", new_callable=AsyncMock, return_value=[mock_edge]),
137
+ patch(
138
+ "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
139
+ return_value="Generated answer",
140
+ ),
141
+ ):
142
+ completion, context_text, triplets = await retriever._run_cot_completion(query="test query")
143
+
144
+ assert completion == "Generated answer"
145
+ assert context_text == "Resolved context"
146
+ assert len(triplets) >= 1
147
+
148
+
149
+ @pytest.mark.asyncio
150
+ async def test_run_cot_completion_multiple_rounds(mock_edge):
151
+ """Test _run_cot_completion with multiple rounds."""
152
+ retriever = GraphCompletionCotRetriever(max_iter=2)
153
+
154
+ mock_edge2 = MagicMock(spec=Edge)
155
+
156
+ with (
157
+ patch(
158
+ "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
159
+ return_value="Resolved context",
160
+ ),
161
+ patch(
162
+ "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
163
+ return_value="Generated answer",
164
+ ),
165
+ patch.object(
166
+ retriever,
167
+ "get_retrieved_objects",
168
+ new_callable=AsyncMock,
169
+ side_effect=[[mock_edge], [mock_edge2]],
170
+ ),
171
+ patch(
172
+ "cognee.modules.retrieval.graph_completion_cot_retriever.render_prompt",
173
+ return_value="Rendered prompt",
174
+ ),
175
+ patch(
176
+ "cognee.modules.retrieval.graph_completion_cot_retriever.read_query_prompt",
177
+ return_value="System prompt",
178
+ ),
179
+ patch.object(
180
+ LLMGateway,
181
+ "acreate_structured_output",
182
+ new_callable=AsyncMock,
183
+ side_effect=[
184
+ "validation_result",
185
+ "followup_question",
186
+ "validation_result2",
187
+ "followup_question2",
188
+ ],
189
+ ),
190
+ patch(
191
+ "cognee.modules.retrieval.graph_completion_cot_retriever._as_answer_text",
192
+ return_value="Generated answer",
193
+ ),
194
+ patch.object(retriever, "get_triplets", new_callable=AsyncMock, return_value=[mock_edge]),
195
+ ):
196
+ completion, context_text, triplets = await retriever._run_cot_completion(query="test query")
197
+
198
+ assert completion == "Generated answer"
199
+ assert context_text == "Resolved context"
200
+ assert len(triplets) >= 1
201
+
202
+
203
+ @pytest.mark.asyncio
204
+ async def test_run_cot_completion_with_conversation_history(mock_edge):
205
+ """Test _run_cot_completion with conversation history."""
206
+ retriever = GraphCompletionCotRetriever(max_iter=1)
207
+
208
+ with (
209
+ patch(
210
+ "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
211
+ return_value="Resolved context",
212
+ ),
213
+ patch(
214
+ "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
215
+ return_value="Generated answer",
216
+ ) as mock_generate,
217
+ patch.object(retriever, "get_triplets", new_callable=AsyncMock, return_value=[mock_edge]),
218
+ ):
219
+ completion, context_text, triplets = await retriever._run_cot_completion(
220
+ query="test query",
221
+ conversation_history="Previous conversation",
60
222
  )
61
223
 
62
- @pytest.mark.asyncio
63
- async def test_graph_completion_cot_context_complex(self):
64
- system_directory_path = os.path.join(
65
- pathlib.Path(__file__).parent,
66
- ".cognee_system/test_graph_completion_cot_context_complex",
224
+ assert completion == "Generated answer"
225
+ call_kwargs = mock_generate.call_args[1]
226
+ assert call_kwargs.get("conversation_history") == "Previous conversation"
227
+
228
+
229
+ @pytest.mark.asyncio
230
+ async def test_run_cot_completion_with_response_model(mock_edge):
231
+ """Test _run_cot_completion with custom response model."""
232
+ from pydantic import BaseModel
233
+
234
+ class TestModel(BaseModel):
235
+ answer: str
236
+
237
+ retriever = GraphCompletionCotRetriever(response_model=TestModel, max_iter=1)
238
+
239
+ with (
240
+ patch(
241
+ "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
242
+ return_value="Resolved context",
243
+ ),
244
+ patch(
245
+ "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
246
+ return_value=TestModel(answer="Test answer"),
247
+ ),
248
+ patch.object(retriever, "get_triplets", new_callable=AsyncMock, return_value=[mock_edge]),
249
+ ):
250
+ completion, context_text, triplets = await retriever._run_cot_completion(query="test query")
251
+
252
+ assert isinstance(completion, TestModel)
253
+ assert completion.answer == "Test answer"
254
+
255
+
256
+ @pytest.mark.asyncio
257
+ async def test_run_cot_completion_empty_conversation_history(mock_edge):
258
+ """Test _run_cot_completion with empty conversation history."""
259
+ retriever = GraphCompletionCotRetriever(max_iter=1)
260
+
261
+ with (
262
+ patch(
263
+ "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
264
+ return_value="Resolved context",
265
+ ),
266
+ patch(
267
+ "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
268
+ return_value="Generated answer",
269
+ ) as mock_generate,
270
+ patch.object(retriever, "get_triplets", new_callable=AsyncMock, return_value=[mock_edge]),
271
+ ):
272
+ completion, context_text, triplets = await retriever._run_cot_completion(
273
+ query="test query",
274
+ conversation_history="",
67
275
  )
68
- cognee.config.system_root_directory(system_directory_path)
69
- data_directory_path = os.path.join(
70
- pathlib.Path(__file__).parent, ".data_storage/test_graph_completion_cot_context_complex"
71
- )
72
- cognee.config.data_root_directory(data_directory_path)
73
-
74
- await cognee.prune.prune_data()
75
- await cognee.prune.prune_system(metadata=True)
76
- await setup()
77
-
78
- class Company(DataPoint):
79
- name: str
80
- metadata: dict = {"index_fields": ["name"]}
81
-
82
- class Car(DataPoint):
83
- brand: str
84
- model: str
85
- year: int
86
-
87
- class Location(DataPoint):
88
- country: str
89
- city: str
90
-
91
- class Home(DataPoint):
92
- location: Location
93
- rooms: int
94
- sqm: int
95
-
96
- class Person(DataPoint):
97
- name: str
98
- works_for: Company
99
- owns: Optional[list[Union[Car, Home]]] = None
100
276
 
101
- company1 = Company(name="Figma")
102
- company2 = Company(name="Canva")
103
-
104
- person1 = Person(name="Mike Rodger", works_for=company1)
105
- person1.owns = [Car(brand="Toyota", model="Camry", year=2020)]
106
-
107
- person2 = Person(name="Ike Loma", works_for=company1)
108
- person2.owns = [
109
- Car(brand="Tesla", model="Model S", year=2021),
110
- Home(location=Location(country="USA", city="New York"), sqm=80, rooms=4),
111
- ]
112
-
113
- person3 = Person(name="Jason Statham", works_for=company1)
114
-
115
- person4 = Person(name="Mike Broski", works_for=company2)
116
- person4.owns = [Car(brand="Ford", model="Mustang", year=1978)]
117
-
118
- person5 = Person(name="Christina Mayer", works_for=company2)
119
- person5.owns = [Car(brand="Honda", model="Civic", year=2023)]
120
-
121
- entities = [company1, company2, person1, person2, person3, person4, person5]
122
-
123
- await add_data_points(entities)
124
-
125
- retriever = GraphCompletionCotRetriever(top_k=20)
126
-
127
- context = await resolve_edges_to_text(await retriever.get_context("Who works at Figma?"))
128
-
129
- print(context)
130
-
131
- assert "Mike Rodger --[works_for]--> Figma" in context, "Failed to get Mike Rodger"
132
- assert "Ike Loma --[works_for]--> Figma" in context, "Failed to get Ike Loma"
133
- assert "Jason Statham --[works_for]--> Figma" in context, "Failed to get Jason Statham"
134
-
135
- answer = await retriever.get_completion("Who works at Figma?")
136
-
137
- assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
138
- assert all(isinstance(item, str) and item.strip() for item in answer), (
139
- "Answer must contain only non-empty strings"
277
+ assert completion == "Generated answer"
278
+ # Verify conversation_history was passed as None when empty
279
+ call_kwargs = mock_generate.call_args[1]
280
+ assert call_kwargs.get("conversation_history") is None
281
+
282
+
283
+ @pytest.mark.asyncio
284
+ async def test_get_completion_without_context(mock_edge):
285
+ """Test get_completion retrieves context when not provided."""
286
+ mock_graph_engine = AsyncMock()
287
+ mock_graph_engine.is_empty = AsyncMock(return_value=False)
288
+
289
+ retriever = GraphCompletionCotRetriever(max_iter=1)
290
+
291
+ with (
292
+ patch(
293
+ "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
294
+ return_value=mock_graph_engine,
295
+ ),
296
+ patch(
297
+ "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
298
+ return_value=[mock_edge],
299
+ ),
300
+ patch(
301
+ "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
302
+ return_value="Resolved context",
303
+ ),
304
+ patch(
305
+ "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
306
+ return_value="Generated answer",
307
+ ),
308
+ patch.object(
309
+ retriever, "get_context_from_objects", new_callable=AsyncMock, return_value=[mock_edge]
310
+ ),
311
+ patch(
312
+ "cognee.modules.retrieval.graph_completion_cot_retriever._as_answer_text",
313
+ return_value="Generated answer",
314
+ ),
315
+ patch(
316
+ "cognee.modules.retrieval.graph_completion_cot_retriever.render_prompt",
317
+ return_value="Rendered prompt",
318
+ ),
319
+ patch(
320
+ "cognee.modules.retrieval.graph_completion_cot_retriever.read_query_prompt",
321
+ return_value="System prompt",
322
+ ),
323
+ patch.object(
324
+ LLMGateway,
325
+ "acreate_structured_output",
326
+ new_callable=AsyncMock,
327
+ side_effect=["validation_result", "followup_question"],
328
+ ),
329
+ patch(
330
+ "cognee.modules.retrieval.graph_completion_cot_retriever.CacheConfig"
331
+ ) as mock_cache_config,
332
+ ):
333
+ mock_config = MagicMock()
334
+ mock_config.caching = False
335
+ mock_cache_config.return_value = mock_config
336
+
337
+ mock_edge = MagicMock()
338
+
339
+ objects = await retriever.get_retrieved_objects("test query")
340
+ context = await retriever.get_context_from_objects("test query", objects)
341
+ completion = await retriever.get_completion_from_context(
342
+ "test query", [mock_edge], context=context
140
343
  )
141
344
 
142
- @pytest.mark.asyncio
143
- async def test_get_graph_completion_cot_context_on_empty_graph(self):
144
- system_directory_path = os.path.join(
145
- pathlib.Path(__file__).parent,
146
- ".cognee_system/test_get_graph_completion_cot_context_on_empty_graph",
345
+ assert isinstance(completion, list)
346
+ assert len(completion) == 1
347
+ assert completion[0] == "Generated answer"
348
+
349
+
350
+ @pytest.mark.asyncio
351
+ async def test_get_completion_with_provided_context(mock_edge):
352
+ """Test get_completion uses provided context."""
353
+ retriever = GraphCompletionCotRetriever(max_iter=1)
354
+
355
+ with (
356
+ patch(
357
+ "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
358
+ return_value="Resolved context",
359
+ ),
360
+ patch(
361
+ "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
362
+ return_value="Generated answer",
363
+ ),
364
+ patch(
365
+ "cognee.modules.retrieval.graph_completion_cot_retriever.CacheConfig"
366
+ ) as mock_cache_config,
367
+ ):
368
+ mock_config = MagicMock()
369
+ mock_config.caching = False
370
+ mock_cache_config.return_value = mock_config
371
+
372
+ mock_edge = MagicMock()
373
+
374
+ objects = await retriever.get_retrieved_objects("test query")
375
+ await retriever.get_context_from_objects("test query", objects)
376
+ completion = await retriever.get_completion_from_context(
377
+ "test query", [mock_edge], context="test"
147
378
  )
148
- cognee.config.system_root_directory(system_directory_path)
149
- data_directory_path = os.path.join(
150
- pathlib.Path(__file__).parent,
151
- ".data_storage/test_get_graph_completion_cot_context_on_empty_graph",
152
- )
153
- cognee.config.data_root_directory(data_directory_path)
154
-
155
- await cognee.prune.prune_data()
156
- await cognee.prune.prune_system(metadata=True)
157
379
 
158
- retriever = GraphCompletionCotRetriever()
159
-
160
- await setup()
380
+ assert isinstance(completion, list)
381
+ assert len(completion) == 1
382
+ assert completion[0] == "Generated answer"
383
+
384
+
385
+ @pytest.mark.asyncio
386
+ async def test_get_completion_with_session(mock_edge):
387
+ """Test get_completion with session caching enabled."""
388
+ mock_graph_engine = AsyncMock()
389
+ mock_graph_engine.is_empty = AsyncMock(return_value=False)
390
+
391
+ retriever = GraphCompletionCotRetriever(session_id="test_session", max_iter=1)
392
+
393
+ mock_user = MagicMock()
394
+ mock_user.id = "test-user-id"
395
+
396
+ with (
397
+ patch(
398
+ "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
399
+ return_value=mock_graph_engine,
400
+ ),
401
+ patch(
402
+ "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
403
+ return_value=[mock_edge],
404
+ ),
405
+ patch(
406
+ "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
407
+ return_value="Resolved context",
408
+ ),
409
+ patch(
410
+ "cognee.modules.retrieval.graph_completion_cot_retriever.get_conversation_history",
411
+ return_value="Previous conversation",
412
+ ),
413
+ patch(
414
+ "cognee.modules.retrieval.graph_completion_cot_retriever.summarize_text",
415
+ return_value="Context summary",
416
+ ),
417
+ patch(
418
+ "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
419
+ return_value="Generated answer",
420
+ ),
421
+ patch(
422
+ "cognee.modules.retrieval.graph_completion_cot_retriever.save_conversation_history",
423
+ ) as mock_save,
424
+ patch(
425
+ "cognee.modules.retrieval.graph_completion_cot_retriever.CacheConfig"
426
+ ) as mock_cache_config,
427
+ patch(
428
+ "cognee.modules.retrieval.graph_completion_cot_retriever.session_user"
429
+ ) as mock_session_user,
430
+ ):
431
+ mock_config = MagicMock()
432
+ mock_config.caching = True
433
+ mock_cache_config.return_value = mock_config
434
+ mock_session_user.get.return_value = mock_user
435
+
436
+ retrieved_objects = await retriever.get_retrieved_objects("test query")
437
+ completion = await retriever.get_completion_from_context(
438
+ "test query", retrieved_objects, context="mock_edge"
439
+ )
161
440
 
162
- context = await retriever.get_context("Who works at Figma?")
163
- assert context == [], "Context should be empty on an empty graph"
441
+ assert isinstance(completion, list)
442
+ assert len(completion) == 1
443
+ assert completion[0] == "Generated answer"
444
+ mock_save.assert_awaited_once()
445
+
446
+
447
+ @pytest.mark.asyncio
448
+ async def test_get_completion_with_save_interaction(mock_edge):
449
+ """Test get_completion with save_interaction enabled."""
450
+ mock_graph_engine = AsyncMock()
451
+ mock_graph_engine.is_empty = AsyncMock(return_value=False)
452
+ mock_graph_engine.add_edges = AsyncMock()
453
+
454
+ retriever = GraphCompletionCotRetriever(save_interaction=True, max_iter=1)
455
+
456
+ mock_node1 = MagicMock()
457
+ mock_node2 = MagicMock()
458
+ mock_edge.node1 = mock_node1
459
+ mock_edge.node2 = mock_node2
460
+
461
+ with (
462
+ patch(
463
+ "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
464
+ return_value="Resolved context",
465
+ ),
466
+ patch(
467
+ "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
468
+ return_value="Generated answer",
469
+ ),
470
+ patch.object(retriever, "get_triplets", new_callable=AsyncMock, return_value=[mock_edge]),
471
+ patch.object(
472
+ retriever, "get_context_from_objects", new_callable=AsyncMock, return_value="mock_edge"
473
+ ),
474
+ patch(
475
+ "cognee.modules.retrieval.graph_completion_cot_retriever._as_answer_text",
476
+ return_value="Generated answer",
477
+ ),
478
+ patch(
479
+ "cognee.modules.retrieval.graph_completion_cot_retriever.render_prompt",
480
+ return_value="Rendered prompt",
481
+ ),
482
+ patch(
483
+ "cognee.modules.retrieval.graph_completion_cot_retriever.read_query_prompt",
484
+ return_value="System prompt",
485
+ ),
486
+ patch.object(
487
+ LLMGateway,
488
+ "acreate_structured_output",
489
+ new_callable=AsyncMock,
490
+ side_effect=["validation_result", "followup_question"],
491
+ ),
492
+ patch(
493
+ "cognee.modules.retrieval.graph_completion_retriever.extract_uuid_from_node",
494
+ side_effect=[
495
+ UUID("550e8400-e29b-41d4-a716-446655440000"),
496
+ UUID("550e8400-e29b-41d4-a716-446655440001"),
497
+ UUID("550e8400-e29b-41d4-a716-446655440002"),
498
+ UUID("550e8400-e29b-41d4-a716-446655440003"),
499
+ ],
500
+ ),
501
+ patch(
502
+ "cognee.modules.retrieval.graph_completion_retriever.add_data_points",
503
+ ) as mock_add_data,
504
+ patch(
505
+ "cognee.modules.retrieval.graph_completion_cot_retriever.CacheConfig"
506
+ ) as mock_cache_config,
507
+ ):
508
+ mock_config = MagicMock()
509
+ mock_config.caching = False
510
+ mock_cache_config.return_value = mock_config
511
+
512
+ # Pass context so save_interaction condition is met
513
+ retrieved_objects = await retriever.get_retrieved_objects("test query")
514
+ context = await retriever.get_context_from_objects("test query", retrieved_objects)
515
+ completion = await retriever.get_completion_from_context(
516
+ "test query", [mock_edge], context=context
517
+ )
164
518
 
165
- answer = await retriever.get_completion("Who works at Figma?")
519
+ assert isinstance(completion, list)
520
+ assert len(completion) == 1
521
+ mock_add_data.assert_awaited_once()
522
+
523
+
524
+ @pytest.mark.asyncio
525
+ async def test_get_completion_with_response_model(mock_edge):
526
+ """Test get_completion with custom response model."""
527
+ from pydantic import BaseModel
528
+
529
+ class TestModel(BaseModel):
530
+ answer: str
531
+
532
+ mock_graph_engine = AsyncMock()
533
+ mock_graph_engine.is_empty = AsyncMock(return_value=False)
534
+
535
+ retriever = GraphCompletionCotRetriever(response_model=TestModel, max_iter=1)
536
+
537
+ with (
538
+ patch(
539
+ "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
540
+ return_value=mock_graph_engine,
541
+ ),
542
+ patch(
543
+ "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
544
+ return_value=[mock_edge],
545
+ ),
546
+ patch(
547
+ "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
548
+ return_value="Resolved context",
549
+ ),
550
+ patch(
551
+ "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
552
+ return_value=TestModel(answer="Test answer"),
553
+ ),
554
+ patch(
555
+ "cognee.modules.retrieval.graph_completion_cot_retriever.CacheConfig"
556
+ ) as mock_cache_config,
557
+ ):
558
+ mock_config = MagicMock()
559
+ mock_config.caching = False
560
+ mock_cache_config.return_value = mock_config
561
+
562
+ mock_edge = MagicMock()
563
+
564
+ objects = await retriever.get_retrieved_objects("test query")
565
+ await retriever.get_context_from_objects("test query", objects)
566
+ completion = await retriever.get_completion_from_context(
567
+ "test query", [mock_edge], "mock_edge"
568
+ )
166
569
 
167
- assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
168
- assert all(isinstance(item, str) and item.strip() for item in answer), (
169
- "Answer must contain only non-empty strings"
570
+ assert isinstance(completion, list)
571
+ assert len(completion) == 1
572
+ assert isinstance(completion[0], TestModel)
573
+
574
+
575
+ @pytest.mark.asyncio
576
+ async def test_get_completion_with_session_no_user_id(mock_edge):
577
+ """Test get_completion with session config but no user ID."""
578
+ mock_graph_engine = AsyncMock()
579
+ mock_graph_engine.is_empty = AsyncMock(return_value=False)
580
+
581
+ retriever = GraphCompletionCotRetriever(max_iter=1)
582
+
583
+ with (
584
+ patch(
585
+ "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
586
+ return_value=mock_graph_engine,
587
+ ),
588
+ patch(
589
+ "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
590
+ return_value=[mock_edge],
591
+ ),
592
+ patch(
593
+ "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
594
+ return_value="Resolved context",
595
+ ),
596
+ patch(
597
+ "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
598
+ return_value="Generated answer",
599
+ ),
600
+ patch(
601
+ "cognee.modules.retrieval.graph_completion_cot_retriever.CacheConfig"
602
+ ) as mock_cache_config,
603
+ patch(
604
+ "cognee.modules.retrieval.graph_completion_cot_retriever.session_user"
605
+ ) as mock_session_user,
606
+ ):
607
+ mock_config = MagicMock()
608
+ mock_config.caching = True
609
+ mock_cache_config.return_value = mock_config
610
+ mock_session_user.get.return_value = None # No user
611
+
612
+ completion = await retriever.get_completion_from_context(
613
+ "test query", [mock_edge], context="mock_edge"
170
614
  )
615
+
616
+ assert isinstance(completion, list)
617
+ assert len(completion) == 1
618
+
619
+
620
+ @pytest.mark.asyncio
621
+ async def test_get_completion_with_save_interaction_no_context(mock_edge):
622
+ """Test get_completion with save_interaction but no context provided."""
623
+ retriever = GraphCompletionCotRetriever(save_interaction=True, max_iter=1)
624
+
625
+ with (
626
+ patch(
627
+ "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
628
+ return_value="Resolved context",
629
+ ),
630
+ patch(
631
+ "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
632
+ return_value="Generated answer",
633
+ ),
634
+ patch.object(
635
+ retriever, "get_retrieved_objects", new_callable=AsyncMock, return_value=[mock_edge]
636
+ ),
637
+ patch(
638
+ "cognee.modules.retrieval.graph_completion_cot_retriever._as_answer_text",
639
+ return_value="Generated answer",
640
+ ),
641
+ patch(
642
+ "cognee.modules.retrieval.graph_completion_cot_retriever.render_prompt",
643
+ return_value="Rendered prompt",
644
+ ),
645
+ patch(
646
+ "cognee.modules.retrieval.graph_completion_cot_retriever.read_query_prompt",
647
+ return_value="System prompt",
648
+ ),
649
+ patch.object(
650
+ LLMGateway,
651
+ "acreate_structured_output",
652
+ new_callable=AsyncMock,
653
+ side_effect=["validation_result", "followup_question"],
654
+ ),
655
+ patch(
656
+ "cognee.modules.retrieval.graph_completion_cot_retriever.CacheConfig"
657
+ ) as mock_cache_config,
658
+ ):
659
+ mock_config = MagicMock()
660
+ mock_config.caching = False
661
+ mock_cache_config.return_value = mock_config
662
+
663
+ with pytest.raises(CogneeValidationError):
664
+ await retriever.get_completion_from_context("test query", None, context=None)
665
+
666
+
667
+ @pytest.mark.asyncio
668
+ async def test_as_answer_text_with_typeerror():
669
+ """Test _as_answer_text handles TypeError when json.dumps fails."""
670
+ non_serializable = {1, 2, 3}
671
+
672
+ result = _as_answer_text(non_serializable)
673
+
674
+ assert isinstance(result, str)
675
+ assert result == str(non_serializable)
676
+
677
+
678
+ @pytest.mark.asyncio
679
+ async def test_as_answer_text_with_string():
680
+ """Test _as_answer_text with string input."""
681
+ result = _as_answer_text("test string")
682
+ assert result == "test string"
683
+
684
+
685
+ @pytest.mark.asyncio
686
+ async def test_as_answer_text_with_dict():
687
+ """Test _as_answer_text with dictionary input."""
688
+ test_dict = {"key": "value", "number": 42}
689
+ result = _as_answer_text(test_dict)
690
+ assert isinstance(result, str)
691
+ assert "key" in result
692
+ assert "value" in result
693
+
694
+
695
+ @pytest.mark.asyncio
696
+ async def test_as_answer_text_with_basemodel():
697
+ """Test _as_answer_text with Pydantic BaseModel input."""
698
+ from pydantic import BaseModel
699
+
700
+ class TestModel(BaseModel):
701
+ answer: str
702
+
703
+ test_model = TestModel(answer="test answer")
704
+ result = _as_answer_text(test_model)
705
+
706
+ assert isinstance(result, str)
707
+ assert "[Structured Response]" in result
708
+ assert "test answer" in result