cognee-0.5.1.dev0-py3-none-any.whl → cognee-0.5.2.dev0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (238)
  1. cognee/__init__.py +2 -0
  2. cognee/alembic/README +1 -0
  3. cognee/alembic/env.py +107 -0
  4. cognee/alembic/script.py.mako +26 -0
  5. cognee/alembic/versions/1a58b986e6e1_enable_delete_for_old_tutorial_notebooks.py +52 -0
  6. cognee/alembic/versions/1d0bb7fede17_add_pipeline_run_status.py +33 -0
  7. cognee/alembic/versions/1daae0df1866_incremental_loading.py +48 -0
  8. cognee/alembic/versions/211ab850ef3d_add_sync_operations_table.py +118 -0
  9. cognee/alembic/versions/45957f0a9849_add_notebook_table.py +46 -0
  10. cognee/alembic/versions/46a6ce2bd2b2_expand_dataset_database_with_json_.py +333 -0
  11. cognee/alembic/versions/482cd6517ce4_add_default_user.py +30 -0
  12. cognee/alembic/versions/76625596c5c3_expand_dataset_database_for_multi_user.py +98 -0
  13. cognee/alembic/versions/8057ae7329c2_initial_migration.py +25 -0
  14. cognee/alembic/versions/9e7a3cb85175_loader_separation.py +104 -0
  15. cognee/alembic/versions/a1b2c3d4e5f6_add_label_column_to_data.py +38 -0
  16. cognee/alembic/versions/ab7e313804ae_permission_system_rework.py +236 -0
  17. cognee/alembic/versions/b9274c27a25a_kuzu_11_migration.py +75 -0
  18. cognee/alembic/versions/c946955da633_multi_tenant_support.py +137 -0
  19. cognee/alembic/versions/e1ec1dcb50b6_add_last_accessed_to_data.py +51 -0
  20. cognee/alembic/versions/e4ebee1091e7_expand_data_model_info.py +140 -0
  21. cognee/alembic.ini +117 -0
  22. cognee/api/v1/add/routers/get_add_router.py +2 -0
  23. cognee/api/v1/cognify/cognify.py +11 -6
  24. cognee/api/v1/cognify/routers/get_cognify_router.py +8 -0
  25. cognee/api/v1/config/config.py +60 -0
  26. cognee/api/v1/datasets/routers/get_datasets_router.py +45 -3
  27. cognee/api/v1/memify/routers/get_memify_router.py +2 -0
  28. cognee/api/v1/search/routers/get_search_router.py +21 -6
  29. cognee/api/v1/search/search.py +25 -5
  30. cognee/api/v1/sync/routers/get_sync_router.py +3 -3
  31. cognee/cli/commands/add_command.py +1 -1
  32. cognee/cli/commands/cognify_command.py +6 -0
  33. cognee/cli/commands/config_command.py +1 -1
  34. cognee/context_global_variables.py +5 -1
  35. cognee/eval_framework/answer_generation/answer_generation_executor.py +7 -8
  36. cognee/infrastructure/databases/cache/cache_db_interface.py +38 -1
  37. cognee/infrastructure/databases/cache/config.py +6 -0
  38. cognee/infrastructure/databases/cache/fscache/FsCacheAdapter.py +21 -0
  39. cognee/infrastructure/databases/cache/get_cache_engine.py +9 -3
  40. cognee/infrastructure/databases/cache/redis/RedisAdapter.py +60 -1
  41. cognee/infrastructure/databases/dataset_database_handler/supported_dataset_database_handlers.py +7 -0
  42. cognee/infrastructure/databases/graph/get_graph_engine.py +29 -1
  43. cognee/infrastructure/databases/graph/neo4j_driver/Neo4jAuraDevDatasetDatabaseHandler.py +62 -27
  44. cognee/infrastructure/databases/hybrid/neptune_analytics/NeptuneAnalyticsAdapter.py +17 -4
  45. cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py +2 -1
  46. cognee/infrastructure/databases/vector/chromadb/ChromaDBAdapter.py +2 -0
  47. cognee/infrastructure/databases/vector/config.py +6 -0
  48. cognee/infrastructure/databases/vector/create_vector_engine.py +69 -22
  49. cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py +64 -9
  50. cognee/infrastructure/databases/vector/embeddings/OllamaEmbeddingEngine.py +13 -2
  51. cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py +16 -3
  52. cognee/infrastructure/databases/vector/models/ScoredResult.py +3 -3
  53. cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py +16 -3
  54. cognee/infrastructure/databases/vector/pgvector/PGVectorDatasetDatabaseHandler.py +86 -0
  55. cognee/infrastructure/databases/vector/pgvector/create_db_and_tables.py +81 -2
  56. cognee/infrastructure/databases/vector/vector_db_interface.py +8 -0
  57. cognee/infrastructure/files/utils/get_data_file_path.py +33 -27
  58. cognee/infrastructure/llm/prompts/extract_query_time.txt +1 -1
  59. cognee/infrastructure/llm/prompts/generate_event_entity_prompt.txt +1 -1
  60. cognee/infrastructure/llm/prompts/generate_event_graph_prompt.txt +1 -1
  61. cognee/infrastructure/llm/prompts/generate_graph_prompt.txt +2 -2
  62. cognee/infrastructure/llm/prompts/generate_graph_prompt_guided.txt +1 -1
  63. cognee/infrastructure/llm/prompts/generate_graph_prompt_oneshot.txt +2 -2
  64. cognee/infrastructure/llm/prompts/generate_graph_prompt_simple.txt +1 -1
  65. cognee/infrastructure/llm/prompts/generate_graph_prompt_strict.txt +1 -1
  66. cognee/infrastructure/llm/prompts/search_type_selector_prompt.txt +6 -6
  67. cognee/infrastructure/llm/prompts/test.txt +1 -1
  68. cognee/infrastructure/llm/prompts/translate_content.txt +19 -0
  69. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py +24 -0
  70. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/llama_cpp/adapter.py +191 -0
  71. cognee/modules/chunking/models/DocumentChunk.py +0 -1
  72. cognee/modules/cognify/config.py +2 -0
  73. cognee/modules/data/models/Data.py +1 -0
  74. cognee/modules/engine/models/Entity.py +0 -1
  75. cognee/modules/engine/operations/setup.py +6 -0
  76. cognee/modules/graph/cognee_graph/CogneeGraph.py +150 -37
  77. cognee/modules/graph/cognee_graph/CogneeGraphElements.py +48 -2
  78. cognee/modules/graph/utils/__init__.py +1 -0
  79. cognee/modules/graph/utils/get_entity_nodes_from_triplets.py +12 -0
  80. cognee/modules/notebooks/methods/__init__.py +1 -0
  81. cognee/modules/notebooks/methods/create_notebook.py +0 -34
  82. cognee/modules/notebooks/methods/create_tutorial_notebooks.py +191 -0
  83. cognee/modules/notebooks/methods/get_notebooks.py +12 -8
  84. cognee/modules/notebooks/tutorials/cognee-basics/cell-1.md +3 -0
  85. cognee/modules/notebooks/tutorials/cognee-basics/cell-2.md +10 -0
  86. cognee/modules/notebooks/tutorials/cognee-basics/cell-3.md +7 -0
  87. cognee/modules/notebooks/tutorials/cognee-basics/cell-4.py +28 -0
  88. cognee/modules/notebooks/tutorials/cognee-basics/cell-5.py +3 -0
  89. cognee/modules/notebooks/tutorials/cognee-basics/cell-6.py +9 -0
  90. cognee/modules/notebooks/tutorials/cognee-basics/cell-7.py +17 -0
  91. cognee/modules/notebooks/tutorials/cognee-basics/config.json +4 -0
  92. cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-1.md +3 -0
  93. cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-10.md +3 -0
  94. cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-11.md +3 -0
  95. cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-12.py +3 -0
  96. cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-13.md +7 -0
  97. cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-14.py +6 -0
  98. cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-15.md +3 -0
  99. cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-16.py +7 -0
  100. cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-2.md +9 -0
  101. cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-3.md +7 -0
  102. cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-4.md +9 -0
  103. cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-5.md +5 -0
  104. cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-6.py +13 -0
  105. cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-7.md +3 -0
  106. cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-8.md +3 -0
  107. cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-9.py +31 -0
  108. cognee/modules/notebooks/tutorials/python-development-with-cognee/config.json +4 -0
  109. cognee/modules/notebooks/tutorials/python-development-with-cognee/data/copilot_conversations.json +107 -0
  110. cognee/modules/notebooks/tutorials/python-development-with-cognee/data/guido_contributions.json +976 -0
  111. cognee/modules/notebooks/tutorials/python-development-with-cognee/data/my_developer_rules.md +79 -0
  112. cognee/modules/notebooks/tutorials/python-development-with-cognee/data/pep_style_guide.md +74 -0
  113. cognee/modules/notebooks/tutorials/python-development-with-cognee/data/zen_principles.md +74 -0
  114. cognee/modules/retrieval/EntityCompletionRetriever.py +51 -38
  115. cognee/modules/retrieval/__init__.py +0 -1
  116. cognee/modules/retrieval/base_retriever.py +66 -10
  117. cognee/modules/retrieval/chunks_retriever.py +57 -49
  118. cognee/modules/retrieval/coding_rules_retriever.py +12 -5
  119. cognee/modules/retrieval/completion_retriever.py +29 -28
  120. cognee/modules/retrieval/cypher_search_retriever.py +25 -20
  121. cognee/modules/retrieval/graph_completion_context_extension_retriever.py +42 -46
  122. cognee/modules/retrieval/graph_completion_cot_retriever.py +68 -51
  123. cognee/modules/retrieval/graph_completion_retriever.py +78 -63
  124. cognee/modules/retrieval/graph_summary_completion_retriever.py +2 -0
  125. cognee/modules/retrieval/lexical_retriever.py +34 -12
  126. cognee/modules/retrieval/natural_language_retriever.py +18 -15
  127. cognee/modules/retrieval/summaries_retriever.py +51 -34
  128. cognee/modules/retrieval/temporal_retriever.py +59 -49
  129. cognee/modules/retrieval/triplet_retriever.py +31 -32
  130. cognee/modules/retrieval/utils/access_tracking.py +88 -0
  131. cognee/modules/retrieval/utils/brute_force_triplet_search.py +99 -85
  132. cognee/modules/retrieval/utils/node_edge_vector_search.py +174 -0
  133. cognee/modules/search/methods/__init__.py +1 -0
  134. cognee/modules/search/methods/get_retriever_output.py +53 -0
  135. cognee/modules/search/methods/get_search_type_retriever_instance.py +252 -0
  136. cognee/modules/search/methods/search.py +90 -215
  137. cognee/modules/search/models/SearchResultPayload.py +67 -0
  138. cognee/modules/search/types/SearchResult.py +1 -8
  139. cognee/modules/search/types/SearchType.py +1 -2
  140. cognee/modules/search/types/__init__.py +1 -1
  141. cognee/modules/search/utils/__init__.py +1 -2
  142. cognee/modules/search/utils/transform_insights_to_graph.py +2 -2
  143. cognee/modules/search/utils/{transform_context_to_graph.py → transform_triplets_to_graph.py} +2 -2
  144. cognee/modules/users/authentication/default/default_transport.py +11 -1
  145. cognee/modules/users/authentication/get_api_auth_backend.py +2 -1
  146. cognee/modules/users/authentication/get_client_auth_backend.py +2 -1
  147. cognee/modules/users/methods/create_user.py +0 -9
  148. cognee/modules/users/permissions/methods/has_user_management_permission.py +29 -0
  149. cognee/modules/visualization/cognee_network_visualization.py +1 -1
  150. cognee/run_migrations.py +48 -0
  151. cognee/shared/exceptions/__init__.py +1 -3
  152. cognee/shared/exceptions/exceptions.py +11 -1
  153. cognee/shared/usage_logger.py +332 -0
  154. cognee/shared/utils.py +12 -5
  155. cognee/tasks/cleanup/cleanup_unused_data.py +172 -0
  156. cognee/tasks/memify/extract_usage_frequency.py +613 -0
  157. cognee/tasks/summarization/models.py +0 -2
  158. cognee/tasks/temporal_graph/__init__.py +0 -1
  159. cognee/tasks/translation/__init__.py +96 -0
  160. cognee/tasks/translation/config.py +110 -0
  161. cognee/tasks/translation/detect_language.py +190 -0
  162. cognee/tasks/translation/exceptions.py +62 -0
  163. cognee/tasks/translation/models.py +72 -0
  164. cognee/tasks/translation/providers/__init__.py +44 -0
  165. cognee/tasks/translation/providers/azure_provider.py +192 -0
  166. cognee/tasks/translation/providers/base.py +85 -0
  167. cognee/tasks/translation/providers/google_provider.py +158 -0
  168. cognee/tasks/translation/providers/llm_provider.py +143 -0
  169. cognee/tasks/translation/translate_content.py +282 -0
  170. cognee/tasks/web_scraper/default_url_crawler.py +6 -2
  171. cognee/tests/cli_tests/cli_unit_tests/test_cli_commands.py +1 -0
  172. cognee/tests/cli_tests/cli_unit_tests/test_cli_edge_cases.py +3 -0
  173. cognee/tests/integration/retrieval/test_brute_force_triplet_search_with_cognify.py +62 -0
  174. cognee/tests/integration/retrieval/test_chunks_retriever.py +115 -16
  175. cognee/tests/integration/retrieval/test_graph_completion_retriever.py +13 -5
  176. cognee/tests/integration/retrieval/test_graph_completion_retriever_context_extension.py +22 -20
  177. cognee/tests/integration/retrieval/test_graph_completion_retriever_cot.py +23 -24
  178. cognee/tests/integration/retrieval/test_rag_completion_retriever.py +70 -5
  179. cognee/tests/integration/retrieval/test_structured_output.py +62 -18
  180. cognee/tests/integration/retrieval/test_summaries_retriever.py +20 -9
  181. cognee/tests/integration/retrieval/test_temporal_retriever.py +38 -8
  182. cognee/tests/integration/retrieval/test_triplet_retriever.py +13 -4
  183. cognee/tests/integration/shared/test_usage_logger_integration.py +255 -0
  184. cognee/tests/tasks/translation/README.md +147 -0
  185. cognee/tests/tasks/translation/__init__.py +1 -0
  186. cognee/tests/tasks/translation/config_test.py +93 -0
  187. cognee/tests/tasks/translation/detect_language_test.py +118 -0
  188. cognee/tests/tasks/translation/providers_test.py +151 -0
  189. cognee/tests/tasks/translation/translate_content_test.py +213 -0
  190. cognee/tests/test_chromadb.py +1 -1
  191. cognee/tests/test_cleanup_unused_data.py +165 -0
  192. cognee/tests/test_delete_by_id.py +6 -6
  193. cognee/tests/test_extract_usage_frequency.py +308 -0
  194. cognee/tests/test_kuzu.py +17 -7
  195. cognee/tests/test_lancedb.py +3 -1
  196. cognee/tests/test_library.py +1 -1
  197. cognee/tests/test_neo4j.py +17 -7
  198. cognee/tests/test_neptune_analytics_vector.py +3 -1
  199. cognee/tests/test_permissions.py +172 -187
  200. cognee/tests/test_pgvector.py +3 -1
  201. cognee/tests/test_relational_db_migration.py +15 -1
  202. cognee/tests/test_remote_kuzu.py +3 -1
  203. cognee/tests/test_s3_file_storage.py +1 -1
  204. cognee/tests/test_search_db.py +97 -110
  205. cognee/tests/test_usage_logger_e2e.py +268 -0
  206. cognee/tests/unit/api/test_get_raw_data_endpoint.py +206 -0
  207. cognee/tests/unit/eval_framework/answer_generation_test.py +4 -3
  208. cognee/tests/unit/infrastructure/databases/cache/test_cache_config.py +2 -0
  209. cognee/tests/unit/modules/graph/cognee_graph_elements_test.py +42 -2
  210. cognee/tests/unit/modules/graph/cognee_graph_test.py +329 -31
  211. cognee/tests/unit/modules/retrieval/chunks_retriever_test.py +31 -59
  212. cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py +70 -33
  213. cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py +72 -52
  214. cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py +27 -33
  215. cognee/tests/unit/modules/retrieval/rag_completion_retriever_test.py +28 -15
  216. cognee/tests/unit/modules/retrieval/summaries_retriever_test.py +37 -42
  217. cognee/tests/unit/modules/retrieval/temporal_retriever_test.py +48 -64
  218. cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py +263 -24
  219. cognee/tests/unit/modules/retrieval/test_node_edge_vector_search.py +273 -0
  220. cognee/tests/unit/modules/retrieval/triplet_retriever_test.py +30 -16
  221. cognee/tests/unit/modules/search/test_get_search_type_retriever_instance.py +125 -0
  222. cognee/tests/unit/modules/search/test_search.py +176 -0
  223. cognee/tests/unit/modules/search/test_search_prepare_search_result_contract.py +190 -0
  224. cognee/tests/unit/modules/users/test_tutorial_notebook_creation.py +511 -297
  225. cognee/tests/unit/shared/test_usage_logger.py +241 -0
  226. cognee/tests/unit/users/permissions/test_has_user_management_permission.py +46 -0
  227. {cognee-0.5.1.dev0.dist-info → cognee-0.5.2.dev0.dist-info}/METADATA +17 -10
  228. {cognee-0.5.1.dev0.dist-info → cognee-0.5.2.dev0.dist-info}/RECORD +232 -144
  229. cognee/api/.env.example +0 -5
  230. cognee/modules/retrieval/base_graph_retriever.py +0 -24
  231. cognee/modules/search/methods/get_search_type_tools.py +0 -223
  232. cognee/modules/search/methods/no_access_control_search.py +0 -62
  233. cognee/modules/search/utils/prepare_search_result.py +0 -63
  234. cognee/tests/test_feedback_enrichment.py +0 -174
  235. {cognee-0.5.1.dev0.dist-info → cognee-0.5.2.dev0.dist-info}/WHEEL +0 -0
  236. {cognee-0.5.1.dev0.dist-info → cognee-0.5.2.dev0.dist-info}/entry_points.txt +0 -0
  237. {cognee-0.5.1.dev0.dist-info → cognee-0.5.2.dev0.dist-info}/licenses/LICENSE +0 -0
  238. {cognee-0.5.1.dev0.dist-info → cognee-0.5.2.dev0.dist-info}/licenses/NOTICE.md +0 -0
@@ -8,6 +8,7 @@ from cognee.modules.retrieval.utils.session_cache import (
8
8
  save_conversation_history,
9
9
  get_conversation_history,
10
10
  )
11
+ from cognee.modules.retrieval.utils.access_tracking import update_node_access_timestamps
11
12
  from cognee.modules.retrieval.base_retriever import BaseRetriever
12
13
  from cognee.modules.retrieval.exceptions.exceptions import NoDataError
13
14
  from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError
@@ -20,10 +21,6 @@ logger = get_logger("CompletionRetriever")
20
21
  class CompletionRetriever(BaseRetriever):
21
22
  """
22
23
  Retriever for handling LLM-based completion searches.
23
-
24
- Public methods:
25
- - get_context(query: str) -> str
26
- - get_completion(query: str, context: Optional[Any] = None) -> Any
27
24
  """
28
25
 
29
26
  def __init__(
@@ -32,14 +29,31 @@ class CompletionRetriever(BaseRetriever):
32
29
  system_prompt_path: str = "answer_simple_question.txt",
33
30
  system_prompt: Optional[str] = None,
34
31
  top_k: Optional[int] = 1,
32
+ session_id: Optional[str] = None,
33
+ response_model: Type = str,
35
34
  ):
36
35
  """Initialize retriever with optional custom prompt paths."""
37
36
  self.user_prompt_path = user_prompt_path
38
37
  self.system_prompt_path = system_prompt_path
39
38
  self.top_k = top_k if top_k is not None else 1
40
39
  self.system_prompt = system_prompt
40
+ self.session_id = session_id
41
+ self.response_model = response_model
42
+
43
+ async def get_retrieved_objects(self, query: str) -> Any:
44
+ vector_engine = get_vector_engine()
45
+
46
+ try:
47
+ found_chunks = await vector_engine.search(
48
+ "DocumentChunk_text", query, limit=self.top_k, include_payload=True
49
+ )
50
+
51
+ return found_chunks
52
+ except CollectionNotFoundError as error:
53
+ logger.error("DocumentChunk_text collection not found")
54
+ raise NoDataError("No data found in the system, please add data first.") from error
41
55
 
42
- async def get_context(self, query: str) -> str:
56
+ async def get_context_from_objects(self, query: str, retrieved_objects: Any) -> str:
43
57
  """
44
58
  Retrieves relevant document chunks as context.
45
59
 
@@ -58,28 +72,18 @@ class CompletionRetriever(BaseRetriever):
58
72
  - str: A string containing the combined text of the retrieved document chunks, or an
59
73
  empty string if none are found.
60
74
  """
61
- vector_engine = get_vector_engine()
62
-
63
- try:
64
- found_chunks = await vector_engine.search("DocumentChunk_text", query, limit=self.top_k)
65
-
66
- if len(found_chunks) == 0:
67
- return ""
68
-
69
- # Combine all chunks text returned from vector search (number of chunks is determined by top_k
70
- chunks_payload = [found_chunk.payload["text"] for found_chunk in found_chunks]
75
+ if retrieved_objects:
76
+ # Combine all chunks text returned from vector search (number of chunks is determined by top_k)
77
+ chunks_payload = [found_chunk.payload["text"] for found_chunk in retrieved_objects]
71
78
  combined_context = "\n".join(chunks_payload)
72
79
  return combined_context
73
- except CollectionNotFoundError as error:
74
- logger.error("DocumentChunk_text collection not found")
75
- raise NoDataError("No data found in the system, please add data first.") from error
80
+ return ""
76
81
 
77
- async def get_completion(
82
+ async def get_completion_from_context(
78
83
  self,
79
84
  query: str,
85
+ retrieved_objects: Any,
80
86
  context: Optional[Any] = None,
81
- session_id: Optional[str] = None,
82
- response_model: Type = str,
83
87
  ) -> List[Any]:
84
88
  """
85
89
  Generates an LLM completion using the context.
@@ -102,9 +106,6 @@ class CompletionRetriever(BaseRetriever):
102
106
 
103
107
  - Any: The generated completion based on the provided query and context.
104
108
  """
105
- if context is None:
106
- context = await self.get_context(query)
107
-
108
109
  # Check if we need to generate context summary for caching
109
110
  cache_config = CacheConfig()
110
111
  user = session_user.get()
@@ -112,7 +113,7 @@ class CompletionRetriever(BaseRetriever):
112
113
  session_save = user_id and cache_config.caching
113
114
 
114
115
  if session_save:
115
- conversation_history = await get_conversation_history(session_id=session_id)
116
+ conversation_history = await get_conversation_history(session_id=self.session_id)
116
117
 
117
118
  context_summary, completion = await asyncio.gather(
118
119
  summarize_text(context),
@@ -123,7 +124,7 @@ class CompletionRetriever(BaseRetriever):
123
124
  system_prompt_path=self.system_prompt_path,
124
125
  system_prompt=self.system_prompt,
125
126
  conversation_history=conversation_history,
126
- response_model=response_model,
127
+ response_model=self.response_model,
127
128
  ),
128
129
  )
129
130
  else:
@@ -133,7 +134,7 @@ class CompletionRetriever(BaseRetriever):
133
134
  user_prompt_path=self.user_prompt_path,
134
135
  system_prompt_path=self.system_prompt_path,
135
136
  system_prompt=self.system_prompt,
136
- response_model=response_model,
137
+ response_model=self.response_model,
137
138
  )
138
139
 
139
140
  if session_save:
@@ -141,7 +142,7 @@ class CompletionRetriever(BaseRetriever):
141
142
  query=query,
142
143
  context_summary=context_summary,
143
144
  answer=completion,
144
- session_id=session_id,
145
+ session_id=self.session_id,
145
146
  )
146
147
 
147
148
  return [completion]
@@ -23,12 +23,29 @@ class CypherSearchRetriever(BaseRetriever):
23
23
  self,
24
24
  user_prompt_path: str = "context_for_question.txt",
25
25
  system_prompt_path: str = "answer_simple_question.txt",
26
+ session_id: Optional[str] = None,
26
27
  ):
27
28
  """Initialize retriever with optional custom prompt paths."""
28
29
  self.user_prompt_path = user_prompt_path
29
30
  self.system_prompt_path = system_prompt_path
31
+ self.session_id = session_id
30
32
 
31
- async def get_context(self, query: str) -> Any:
33
+ async def get_retrieved_objects(self, query: str) -> Any:
34
+ try:
35
+ graph_engine = await get_graph_engine()
36
+ is_empty = await graph_engine.is_empty()
37
+
38
+ if is_empty:
39
+ logger.warning("Search attempt on an empty knowledge graph")
40
+ return []
41
+
42
+ result = await graph_engine.query(query)
43
+ except Exception as e:
44
+ logger.error("Failed to execture cypher search retrieval: %s", str(e))
45
+ raise CypherSearchError() from e
46
+ return result
47
+
48
+ async def get_context_from_objects(self, query: str, retrieved_objects: Any) -> Any:
32
49
  """
33
50
  Retrieves relevant context using a cypher query.
34
51
 
@@ -44,22 +61,12 @@ class CypherSearchRetriever(BaseRetriever):
44
61
 
45
62
  - Any: The result of the cypher query execution.
46
63
  """
47
- try:
48
- graph_engine = await get_graph_engine()
49
- is_empty = await graph_engine.is_empty()
50
-
51
- if is_empty:
52
- logger.warning("Search attempt on an empty knowledge graph")
53
- return []
54
-
55
- result = jsonable_encoder(await graph_engine.query(query))
56
- except Exception as e:
57
- logger.error("Failed to execture cypher search retrieval: %s", str(e))
58
- raise CypherSearchError() from e
59
- return result
64
+ # TODO: Do we want to return a string response here?
65
+ # return jsonable_encoder(retrieved_objects)
66
+ return None
60
67
 
61
- async def get_completion(
62
- self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None
68
+ async def get_completion_from_context(
69
+ self, query: str, retrieved_objects: Any, context: Optional[Any] = None
63
70
  ) -> Any:
64
71
  """
65
72
  Returns the graph connections context.
@@ -72,7 +79,6 @@ class CypherSearchRetriever(BaseRetriever):
72
79
  - query (str): The query to retrieve context.
73
80
  - context (Optional[Any]): Optional context to use, otherwise fetched using the
74
81
  query. (default None)
75
- - session_id (Optional[str]): Optional session identifier for caching. If None,
76
82
  defaults to 'default_session'. (default None)
77
83
 
78
84
  Returns:
@@ -80,6 +86,5 @@ class CypherSearchRetriever(BaseRetriever):
80
86
 
81
87
  - Any: The context, either provided or retrieved.
82
88
  """
83
- if context is None:
84
- context = await self.get_context(query)
85
- return context
89
+ # TODO: Do we want to generate a completion using LLM here?
90
+ return None
@@ -18,16 +18,6 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
18
18
  """
19
19
  Handles graph context completion for question answering tasks, extending context based
20
20
  on retrieved triplets.
21
-
22
- Public methods:
23
- - get_completion
24
-
25
- Instance variables:
26
- - user_prompt_path
27
- - system_prompt_path
28
- - top_k
29
- - node_type
30
- - node_name
31
21
  """
32
22
 
33
23
  def __init__(
@@ -41,6 +31,9 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
41
31
  save_interaction: bool = False,
42
32
  wide_search_top_k: Optional[int] = 100,
43
33
  triplet_distance_penalty: Optional[float] = 3.5,
34
+ context_extension_rounds: int = 4,
35
+ session_id: Optional[str] = None,
36
+ response_model: Type = str,
44
37
  ):
45
38
  super().__init__(
46
39
  user_prompt_path=user_prompt_path,
@@ -52,53 +45,38 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
52
45
  system_prompt=system_prompt,
53
46
  wide_search_top_k=wide_search_top_k,
54
47
  triplet_distance_penalty=triplet_distance_penalty,
48
+ session_id=session_id,
49
+ response_model=response_model,
55
50
  )
56
51
 
57
- async def get_completion(
58
- self,
59
- query: str,
60
- context: Optional[List[Edge]] = None,
61
- session_id: Optional[str] = None,
62
- context_extension_rounds=4,
63
- response_model: Type = str,
64
- ) -> List[Any]:
52
+ # context_extension_rounds: The maximum number of rounds to extend the context with
53
+ # new triplets before halting. (default 4)
54
+ self.context_extension_rounds = context_extension_rounds
55
+
56
+ async def get_retrieved_objects(self, query: str) -> List[Edge]:
65
57
  """
66
58
  Extends the context for a given query by retrieving related triplets and generating new
67
59
  completions based on them.
68
60
 
69
- The method runs for a specified number of rounds to enhance context until no new
61
+ The method runs for a specified number of rounds to enhance results until no new
70
62
  triplets are found or the maximum rounds are reached. It retrieves triplet suggestions
71
63
  based on a generated completion from previous iterations, logging the process of context
72
64
  extension.
73
65
 
74
66
  Parameters:
75
67
  -----------
76
-
77
68
  - query (str): The input query for which the completion is generated.
78
- - context (Optional[Any]): The existing context to use for enhancing the query; if
79
- None, it will be initialized from triplets generated for the query. (default None)
80
- - session_id (Optional[str]): Optional session identifier for caching. If None,
81
- defaults to 'default_session'. (default None)
82
- - context_extension_rounds: The maximum number of rounds to extend the context with
83
- new triplets before halting. (default 4)
84
- - response_model (Type): The Pydantic model type for structured output. (default str)
85
69
 
86
70
  Returns:
87
71
  --------
88
-
89
- - List[str]: A list containing the generated answer based on the query and the
90
- extended context.
72
+ - List[Edge]: A list of retrieved triplet edges relevant to the query.
91
73
  """
92
- triplets = context
93
-
94
- if triplets is None:
95
- triplets = await self.get_context(query)
96
74
 
75
+ triplets = await self.get_triplets(query)
97
76
  context_text = await self.resolve_edges_to_text(triplets)
98
-
99
77
  round_idx = 1
100
78
 
101
- while round_idx <= context_extension_rounds:
79
+ while round_idx <= self.context_extension_rounds:
102
80
  prev_size = len(triplets)
103
81
 
104
82
  logger.info(
@@ -112,7 +90,7 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
112
90
  system_prompt=self.system_prompt,
113
91
  )
114
92
 
115
- triplets += await self.get_context(completion)
93
+ triplets += await self.get_triplets(completion)
116
94
  triplets = list(set(triplets))
117
95
  context_text = await self.resolve_edges_to_text(triplets)
118
96
 
@@ -131,6 +109,24 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
131
109
 
132
110
  round_idx += 1
133
111
 
112
+ return triplets
113
+
114
+ async def get_completion_from_context(
115
+ self,
116
+ query: str,
117
+ retrieved_objects: List[Edge],
118
+ context: str,
119
+ ) -> List[Any]:
120
+ """
121
+ Returns a human readable answer based on the provided query and extended context derived from the retrieved objects.
122
+
123
+ Returns:
124
+ --------
125
+
126
+ - List[str]: A list containing the generated answer based on the query and the
127
+ extended context.
128
+ """
129
+
134
130
  # Check if we need to generate context summary for caching
135
131
  cache_config = CacheConfig()
136
132
  user = session_user.get()
@@ -138,33 +134,33 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
138
134
  session_save = user_id and cache_config.caching
139
135
 
140
136
  if session_save:
141
- conversation_history = await get_conversation_history(session_id=session_id)
137
+ conversation_history = await get_conversation_history(session_id=self.session_id)
142
138
 
143
139
  context_summary, completion = await asyncio.gather(
144
- summarize_text(context_text),
140
+ summarize_text(context),
145
141
  generate_completion(
146
142
  query=query,
147
- context=context_text,
143
+ context=context,
148
144
  user_prompt_path=self.user_prompt_path,
149
145
  system_prompt_path=self.system_prompt_path,
150
146
  system_prompt=self.system_prompt,
151
147
  conversation_history=conversation_history,
152
- response_model=response_model,
148
+ response_model=self.response_model,
153
149
  ),
154
150
  )
155
151
  else:
156
152
  completion = await generate_completion(
157
153
  query=query,
158
- context=context_text,
154
+ context=context,
159
155
  user_prompt_path=self.user_prompt_path,
160
156
  system_prompt_path=self.system_prompt_path,
161
157
  system_prompt=self.system_prompt,
162
- response_model=response_model,
158
+ response_model=self.response_model,
163
159
  )
164
160
 
165
- if self.save_interaction and context_text and triplets and completion:
161
+ if self.save_interaction and context and retrieved_objects and completion:
166
162
  await self.save_qa(
167
- question=query, answer=completion, context=context_text, triplets=triplets
163
+ question=query, answer=completion, context=context, triplets=retrieved_objects
168
164
  )
169
165
 
170
166
  if session_save:
@@ -172,7 +168,7 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
172
168
  query=query,
173
169
  context_summary=context_summary,
174
170
  answer=completion,
175
- session_id=session_id,
171
+ session_id=self.session_id,
176
172
  )
177
173
 
178
174
  return [completion]
@@ -18,6 +18,7 @@ from cognee.infrastructure.llm.LLMGateway import LLMGateway
18
18
  from cognee.infrastructure.llm.prompts import render_prompt, read_query_prompt
19
19
  from cognee.context_global_variables import session_user
20
20
  from cognee.infrastructure.databases.cache.config import CacheConfig
21
+ from cognee.exceptions.exceptions import CogneeValidationError
21
22
 
22
23
  logger = get_logger()
23
24
 
@@ -67,6 +68,9 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
67
68
  save_interaction: bool = False,
68
69
  wide_search_top_k: Optional[int] = 100,
69
70
  triplet_distance_penalty: Optional[float] = 3.5,
71
+ max_iter: int = 4,
72
+ session_id: Optional[str] = None,
73
+ response_model: Type = str,
70
74
  ):
71
75
  super().__init__(
72
76
  user_prompt_path=user_prompt_path,
@@ -78,19 +82,68 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
78
82
  save_interaction=save_interaction,
79
83
  wide_search_top_k=wide_search_top_k,
80
84
  triplet_distance_penalty=triplet_distance_penalty,
85
+ session_id=session_id,
86
+ response_model=response_model,
81
87
  )
82
88
  self.validation_system_prompt_path = validation_system_prompt_path
83
89
  self.validation_user_prompt_path = validation_user_prompt_path
84
90
  self.followup_system_prompt_path = followup_system_prompt_path
85
91
  self.followup_user_prompt_path = followup_user_prompt_path
92
+ self.completion = []
93
+ self.max_iter = max_iter
94
+
95
+ async def get_retrieved_objects(self, query: str) -> List[Edge]:
96
+ """
97
+ Run chain-of-thought completion with optional structured output.
98
+
99
+ Parameters:
100
+ -----------
101
+ - query: User query
102
+
103
+ Returns:
104
+ --------
105
+ - List of retrieved edges
106
+ """
107
+ # Check if session saving is enabled
108
+ cache_config = CacheConfig()
109
+ user = session_user.get()
110
+ user_id = getattr(user, "id", None)
111
+ session_save = user_id and cache_config.caching
112
+
113
+ # Load conversation history if enabled
114
+ conversation_history = ""
115
+ if session_save:
116
+ conversation_history = await get_conversation_history(session_id=self.session_id)
117
+
118
+ completion, context_text, triplets = await self._run_cot_completion(
119
+ query=query,
120
+ conversation_history=conversation_history,
121
+ )
122
+
123
+ # Note: completion info is stored to reduce the need to call LLM again in get_completion_from_context
124
+ self.completion = completion
125
+
126
+ if self.save_interaction and context_text and triplets and completion:
127
+ await self.save_qa(
128
+ question=query, answer=str(completion), context=context_text, triplets=triplets
129
+ )
130
+
131
+ # Save to session cache if enabled
132
+ if session_save:
133
+ context_summary = await summarize_text(context_text)
134
+ await save_conversation_history(
135
+ query=query,
136
+ context_summary=context_summary,
137
+ answer=str(completion),
138
+ session_id=self.session_id,
139
+ )
140
+
141
+ return triplets
86
142
 
87
143
  async def _run_cot_completion(
88
144
  self,
89
145
  query: str,
90
- context: Optional[List[Edge]] = None,
91
146
  conversation_history: str = "",
92
- max_iter: int = 4,
93
- response_model: Type = str,
94
147
  ) -> tuple[Any, str, List[Edge]]:
95
148
  """
96
149
  Run chain-of-thought completion with optional structured output.
@@ -113,15 +166,12 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
113
166
  triplets = []
114
167
  completion = ""
115
168
 
116
- for round_idx in range(max_iter + 1):
169
+ for round_idx in range(self.max_iter + 1):
117
170
  if round_idx == 0:
118
- if context is None:
119
- triplets = await self.get_context(query)
120
- context_text = await self.resolve_edges_to_text(triplets)
121
- else:
122
- context_text = await self.resolve_edges_to_text(context)
171
+ triplets = await self.get_triplets(query)
172
+ context_text = await self.resolve_edges_to_text(triplets)
123
173
  else:
124
- triplets += await self.get_context(followup_question)
174
+ triplets += await self.get_triplets(followup_question)
125
175
  context_text = await self.resolve_edges_to_text(list(set(triplets)))
126
176
 
127
177
  completion = await generate_completion(
@@ -131,12 +181,12 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
131
181
  system_prompt_path=self.system_prompt_path,
132
182
  system_prompt=self.system_prompt,
133
183
  conversation_history=conversation_history if conversation_history else None,
134
- response_model=response_model,
184
+ response_model=self.response_model,
135
185
  )
136
186
 
137
187
  logger.info(f"Chain-of-thought: round {round_idx} - answer: {completion}")
138
188
 
139
- if round_idx < max_iter:
189
+ if round_idx < self.max_iter:
140
190
  answer_text = _as_answer_text(completion)
141
191
  valid_args = {"query": query, "answer": answer_text, "context": context_text}
142
192
  valid_user_prompt = render_prompt(
@@ -168,13 +218,11 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
168
218
 
169
219
  return completion, context_text, triplets
170
220
 
171
- async def get_completion(
221
+ async def get_completion_from_context(
172
222
  self,
173
223
  query: str,
174
- context: Optional[List[Edge]] = None,
175
- session_id: Optional[str] = None,
176
- max_iter=4,
177
- response_model: Type = str,
224
+ retrieved_objects: List[Edge],
225
+ context: str,
178
226
  ) -> List[Any]:
179
227
  """
180
228
  Generate completion responses based on a user query and contextual information.
@@ -202,38 +250,7 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
202
250
 
203
251
  - List[str]: A list containing the generated answer to the user's query.
204
252
  """
205
- # Check if session saving is enabled
206
- cache_config = CacheConfig()
207
- user = session_user.get()
208
- user_id = getattr(user, "id", None)
209
- session_save = user_id and cache_config.caching
210
-
211
- # Load conversation history if enabled
212
- conversation_history = ""
213
- if session_save:
214
- conversation_history = await get_conversation_history(session_id=session_id)
215
-
216
- completion, context_text, triplets = await self._run_cot_completion(
217
- query=query,
218
- context=context,
219
- conversation_history=conversation_history,
220
- max_iter=max_iter,
221
- response_model=response_model,
222
- )
223
-
224
- if self.save_interaction and context and triplets and completion:
225
- await self.save_qa(
226
- question=query, answer=str(completion), context=context_text, triplets=triplets
227
- )
228
-
229
- # Save to session cache if enabled
230
- if session_save:
231
- context_summary = await summarize_text(context_text)
232
- await save_conversation_history(
233
- query=query,
234
- context_summary=context_summary,
235
- answer=str(completion),
236
- session_id=session_id,
237
- )
238
-
253
+ if not retrieved_objects:
254
+ raise CogneeValidationError("No context retrieved to generate completion.")
255
+ completion = self.completion
239
256
  return [completion]