cognee 0.4.0__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (227)
  1. cognee/__init__.py +1 -0
  2. cognee/api/client.py +9 -5
  3. cognee/api/v1/add/add.py +2 -1
  4. cognee/api/v1/add/routers/get_add_router.py +3 -1
  5. cognee/api/v1/cognify/cognify.py +24 -16
  6. cognee/api/v1/cognify/routers/__init__.py +0 -1
  7. cognee/api/v1/cognify/routers/get_cognify_router.py +30 -1
  8. cognee/api/v1/datasets/routers/get_datasets_router.py +3 -3
  9. cognee/api/v1/ontologies/__init__.py +4 -0
  10. cognee/api/v1/ontologies/ontologies.py +158 -0
  11. cognee/api/v1/ontologies/routers/__init__.py +0 -0
  12. cognee/api/v1/ontologies/routers/get_ontology_router.py +109 -0
  13. cognee/api/v1/permissions/routers/get_permissions_router.py +41 -1
  14. cognee/api/v1/search/search.py +4 -0
  15. cognee/api/v1/ui/node_setup.py +360 -0
  16. cognee/api/v1/ui/npm_utils.py +50 -0
  17. cognee/api/v1/ui/ui.py +38 -68
  18. cognee/cli/commands/cognify_command.py +8 -1
  19. cognee/cli/config.py +1 -1
  20. cognee/context_global_variables.py +86 -9
  21. cognee/eval_framework/Dockerfile +29 -0
  22. cognee/eval_framework/answer_generation/answer_generation_executor.py +10 -0
  23. cognee/eval_framework/answer_generation/run_question_answering_module.py +1 -1
  24. cognee/eval_framework/corpus_builder/task_getters/get_cascade_graph_tasks.py +0 -2
  25. cognee/eval_framework/corpus_builder/task_getters/get_default_tasks_by_indices.py +4 -4
  26. cognee/eval_framework/eval_config.py +2 -2
  27. cognee/eval_framework/modal_run_eval.py +16 -28
  28. cognee/infrastructure/databases/cache/config.py +3 -1
  29. cognee/infrastructure/databases/cache/fscache/FsCacheAdapter.py +151 -0
  30. cognee/infrastructure/databases/cache/get_cache_engine.py +20 -10
  31. cognee/infrastructure/databases/dataset_database_handler/__init__.py +3 -0
  32. cognee/infrastructure/databases/dataset_database_handler/dataset_database_handler_interface.py +80 -0
  33. cognee/infrastructure/databases/dataset_database_handler/supported_dataset_database_handlers.py +18 -0
  34. cognee/infrastructure/databases/dataset_database_handler/use_dataset_database_handler.py +10 -0
  35. cognee/infrastructure/databases/exceptions/exceptions.py +16 -0
  36. cognee/infrastructure/databases/graph/config.py +7 -0
  37. cognee/infrastructure/databases/graph/get_graph_engine.py +3 -0
  38. cognee/infrastructure/databases/graph/graph_db_interface.py +15 -0
  39. cognee/infrastructure/databases/graph/kuzu/KuzuDatasetDatabaseHandler.py +81 -0
  40. cognee/infrastructure/databases/graph/kuzu/adapter.py +228 -0
  41. cognee/infrastructure/databases/graph/neo4j_driver/Neo4jAuraDevDatasetDatabaseHandler.py +168 -0
  42. cognee/infrastructure/databases/graph/neo4j_driver/adapter.py +80 -1
  43. cognee/infrastructure/databases/hybrid/neptune_analytics/NeptuneAnalyticsAdapter.py +9 -0
  44. cognee/infrastructure/databases/utils/__init__.py +3 -0
  45. cognee/infrastructure/databases/utils/get_graph_dataset_database_handler.py +10 -0
  46. cognee/infrastructure/databases/utils/get_or_create_dataset_database.py +66 -18
  47. cognee/infrastructure/databases/utils/get_vector_dataset_database_handler.py +10 -0
  48. cognee/infrastructure/databases/utils/resolve_dataset_database_connection_info.py +30 -0
  49. cognee/infrastructure/databases/vector/config.py +5 -0
  50. cognee/infrastructure/databases/vector/create_vector_engine.py +6 -1
  51. cognee/infrastructure/databases/vector/embeddings/FastembedEmbeddingEngine.py +8 -6
  52. cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py +9 -7
  53. cognee/infrastructure/databases/vector/embeddings/OllamaEmbeddingEngine.py +11 -10
  54. cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py +2 -0
  55. cognee/infrastructure/databases/vector/lancedb/LanceDBDatasetDatabaseHandler.py +50 -0
  56. cognee/infrastructure/databases/vector/vector_db_interface.py +35 -0
  57. cognee/infrastructure/engine/models/Edge.py +13 -1
  58. cognee/infrastructure/files/storage/s3_config.py +2 -0
  59. cognee/infrastructure/files/utils/guess_file_type.py +4 -0
  60. cognee/infrastructure/llm/LLMGateway.py +5 -2
  61. cognee/infrastructure/llm/config.py +37 -0
  62. cognee/infrastructure/llm/extraction/knowledge_graph/extract_content_graph.py +2 -2
  63. cognee/infrastructure/llm/structured_output_framework/baml/baml_src/extraction/acreate_structured_output.py +23 -8
  64. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py +22 -18
  65. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/bedrock/__init__.py +5 -0
  66. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/bedrock/adapter.py +153 -0
  67. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py +47 -38
  68. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py +46 -37
  69. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py +20 -10
  70. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py +23 -11
  71. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/adapter.py +36 -23
  72. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py +47 -36
  73. cognee/infrastructure/loaders/LoaderEngine.py +1 -0
  74. cognee/infrastructure/loaders/core/__init__.py +2 -1
  75. cognee/infrastructure/loaders/core/csv_loader.py +93 -0
  76. cognee/infrastructure/loaders/core/text_loader.py +1 -2
  77. cognee/infrastructure/loaders/external/advanced_pdf_loader.py +0 -9
  78. cognee/infrastructure/loaders/supported_loaders.py +2 -1
  79. cognee/memify_pipelines/create_triplet_embeddings.py +53 -0
  80. cognee/memify_pipelines/persist_sessions_in_knowledge_graph.py +55 -0
  81. cognee/modules/chunking/CsvChunker.py +35 -0
  82. cognee/modules/chunking/models/DocumentChunk.py +2 -1
  83. cognee/modules/chunking/text_chunker_with_overlap.py +124 -0
  84. cognee/modules/cognify/config.py +2 -0
  85. cognee/modules/data/deletion/prune_system.py +52 -2
  86. cognee/modules/data/methods/__init__.py +1 -0
  87. cognee/modules/data/methods/create_dataset.py +4 -2
  88. cognee/modules/data/methods/delete_dataset.py +26 -0
  89. cognee/modules/data/methods/get_dataset_ids.py +5 -1
  90. cognee/modules/data/methods/get_unique_data_id.py +68 -0
  91. cognee/modules/data/methods/get_unique_dataset_id.py +66 -4
  92. cognee/modules/data/models/Dataset.py +2 -0
  93. cognee/modules/data/processing/document_types/CsvDocument.py +33 -0
  94. cognee/modules/data/processing/document_types/__init__.py +1 -0
  95. cognee/modules/engine/models/Triplet.py +9 -0
  96. cognee/modules/engine/models/__init__.py +1 -0
  97. cognee/modules/graph/cognee_graph/CogneeGraph.py +89 -39
  98. cognee/modules/graph/cognee_graph/CogneeGraphElements.py +8 -3
  99. cognee/modules/graph/utils/expand_with_nodes_and_edges.py +19 -2
  100. cognee/modules/graph/utils/resolve_edges_to_text.py +48 -49
  101. cognee/modules/ingestion/identify.py +4 -4
  102. cognee/modules/memify/memify.py +1 -7
  103. cognee/modules/notebooks/operations/run_in_local_sandbox.py +3 -0
  104. cognee/modules/ontology/rdf_xml/RDFLibOntologyResolver.py +55 -23
  105. cognee/modules/pipelines/operations/pipeline.py +18 -2
  106. cognee/modules/pipelines/operations/run_tasks_data_item.py +1 -1
  107. cognee/modules/retrieval/EntityCompletionRetriever.py +10 -3
  108. cognee/modules/retrieval/__init__.py +1 -1
  109. cognee/modules/retrieval/base_graph_retriever.py +7 -3
  110. cognee/modules/retrieval/base_retriever.py +7 -3
  111. cognee/modules/retrieval/completion_retriever.py +11 -4
  112. cognee/modules/retrieval/graph_completion_context_extension_retriever.py +10 -2
  113. cognee/modules/retrieval/graph_completion_cot_retriever.py +18 -51
  114. cognee/modules/retrieval/graph_completion_retriever.py +14 -1
  115. cognee/modules/retrieval/graph_summary_completion_retriever.py +4 -0
  116. cognee/modules/retrieval/register_retriever.py +10 -0
  117. cognee/modules/retrieval/registered_community_retrievers.py +1 -0
  118. cognee/modules/retrieval/temporal_retriever.py +13 -2
  119. cognee/modules/retrieval/triplet_retriever.py +182 -0
  120. cognee/modules/retrieval/utils/brute_force_triplet_search.py +43 -11
  121. cognee/modules/retrieval/utils/completion.py +2 -22
  122. cognee/modules/run_custom_pipeline/__init__.py +1 -0
  123. cognee/modules/run_custom_pipeline/run_custom_pipeline.py +76 -0
  124. cognee/modules/search/methods/get_search_type_tools.py +54 -8
  125. cognee/modules/search/methods/no_access_control_search.py +4 -0
  126. cognee/modules/search/methods/search.py +26 -3
  127. cognee/modules/search/types/SearchType.py +1 -1
  128. cognee/modules/settings/get_settings.py +19 -0
  129. cognee/modules/users/methods/create_user.py +12 -27
  130. cognee/modules/users/methods/get_authenticated_user.py +3 -2
  131. cognee/modules/users/methods/get_default_user.py +4 -2
  132. cognee/modules/users/methods/get_user.py +1 -1
  133. cognee/modules/users/methods/get_user_by_email.py +1 -1
  134. cognee/modules/users/models/DatasetDatabase.py +24 -3
  135. cognee/modules/users/models/Tenant.py +6 -7
  136. cognee/modules/users/models/User.py +6 -5
  137. cognee/modules/users/models/UserTenant.py +12 -0
  138. cognee/modules/users/models/__init__.py +1 -0
  139. cognee/modules/users/permissions/methods/get_all_user_permission_datasets.py +13 -13
  140. cognee/modules/users/roles/methods/add_user_to_role.py +3 -1
  141. cognee/modules/users/tenants/methods/__init__.py +1 -0
  142. cognee/modules/users/tenants/methods/add_user_to_tenant.py +21 -12
  143. cognee/modules/users/tenants/methods/create_tenant.py +22 -8
  144. cognee/modules/users/tenants/methods/select_tenant.py +62 -0
  145. cognee/shared/logging_utils.py +6 -0
  146. cognee/shared/rate_limiting.py +30 -0
  147. cognee/tasks/chunks/__init__.py +1 -0
  148. cognee/tasks/chunks/chunk_by_row.py +94 -0
  149. cognee/tasks/documents/__init__.py +0 -1
  150. cognee/tasks/documents/classify_documents.py +2 -0
  151. cognee/tasks/feedback/generate_improved_answers.py +3 -3
  152. cognee/tasks/graph/extract_graph_from_data.py +9 -10
  153. cognee/tasks/ingestion/ingest_data.py +1 -1
  154. cognee/tasks/memify/__init__.py +2 -0
  155. cognee/tasks/memify/cognify_session.py +41 -0
  156. cognee/tasks/memify/extract_user_sessions.py +73 -0
  157. cognee/tasks/memify/get_triplet_datapoints.py +289 -0
  158. cognee/tasks/storage/add_data_points.py +142 -2
  159. cognee/tasks/storage/index_data_points.py +33 -22
  160. cognee/tasks/storage/index_graph_edges.py +37 -57
  161. cognee/tests/integration/documents/CsvDocument_test.py +70 -0
  162. cognee/tests/integration/retrieval/test_triplet_retriever.py +84 -0
  163. cognee/tests/integration/tasks/test_add_data_points.py +139 -0
  164. cognee/tests/integration/tasks/test_get_triplet_datapoints.py +69 -0
  165. cognee/tests/integration/web_url_crawler/test_default_url_crawler.py +1 -1
  166. cognee/tests/integration/web_url_crawler/test_tavily_crawler.py +1 -1
  167. cognee/tests/integration/web_url_crawler/test_url_adding_e2e.py +13 -27
  168. cognee/tests/tasks/entity_extraction/entity_extraction_test.py +1 -1
  169. cognee/tests/test_add_docling_document.py +2 -2
  170. cognee/tests/test_cognee_server_start.py +84 -3
  171. cognee/tests/test_conversation_history.py +68 -5
  172. cognee/tests/test_data/example_with_header.csv +3 -0
  173. cognee/tests/test_dataset_database_handler.py +137 -0
  174. cognee/tests/test_dataset_delete.py +76 -0
  175. cognee/tests/test_edge_centered_payload.py +170 -0
  176. cognee/tests/test_edge_ingestion.py +27 -0
  177. cognee/tests/test_feedback_enrichment.py +1 -1
  178. cognee/tests/test_library.py +6 -4
  179. cognee/tests/test_load.py +62 -0
  180. cognee/tests/test_multi_tenancy.py +165 -0
  181. cognee/tests/test_parallel_databases.py +2 -0
  182. cognee/tests/test_pipeline_cache.py +164 -0
  183. cognee/tests/test_relational_db_migration.py +54 -2
  184. cognee/tests/test_search_db.py +44 -2
  185. cognee/tests/unit/api/test_conditional_authentication_endpoints.py +12 -3
  186. cognee/tests/unit/api/test_ontology_endpoint.py +252 -0
  187. cognee/tests/unit/infrastructure/databases/cache/test_cache_config.py +5 -0
  188. cognee/tests/unit/infrastructure/databases/test_index_data_points.py +27 -0
  189. cognee/tests/unit/infrastructure/databases/test_index_graph_edges.py +14 -16
  190. cognee/tests/unit/infrastructure/llm/test_llm_config.py +46 -0
  191. cognee/tests/unit/infrastructure/mock_embedding_engine.py +3 -7
  192. cognee/tests/unit/infrastructure/test_embedding_rate_limiting_realistic.py +0 -5
  193. cognee/tests/unit/modules/chunking/test_text_chunker.py +248 -0
  194. cognee/tests/unit/modules/chunking/test_text_chunker_with_overlap.py +324 -0
  195. cognee/tests/unit/modules/graph/cognee_graph_elements_test.py +2 -2
  196. cognee/tests/unit/modules/graph/cognee_graph_test.py +406 -0
  197. cognee/tests/unit/modules/memify_tasks/test_cognify_session.py +111 -0
  198. cognee/tests/unit/modules/memify_tasks/test_extract_user_sessions.py +175 -0
  199. cognee/tests/unit/modules/memify_tasks/test_get_triplet_datapoints.py +214 -0
  200. cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py +0 -51
  201. cognee/tests/unit/modules/retrieval/rag_completion_retriever_test.py +1 -0
  202. cognee/tests/unit/modules/retrieval/structured_output_test.py +204 -0
  203. cognee/tests/unit/modules/retrieval/summaries_retriever_test.py +1 -1
  204. cognee/tests/unit/modules/retrieval/temporal_retriever_test.py +0 -1
  205. cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py +608 -0
  206. cognee/tests/unit/modules/retrieval/triplet_retriever_test.py +83 -0
  207. cognee/tests/unit/modules/users/test_conditional_authentication.py +0 -63
  208. cognee/tests/unit/processing/chunks/chunk_by_row_test.py +52 -0
  209. cognee/tests/unit/tasks/storage/test_add_data_points.py +288 -0
  210. {cognee-0.4.0.dist-info → cognee-0.5.0.dist-info}/METADATA +11 -6
  211. {cognee-0.4.0.dist-info → cognee-0.5.0.dist-info}/RECORD +215 -163
  212. {cognee-0.4.0.dist-info → cognee-0.5.0.dist-info}/WHEEL +1 -1
  213. {cognee-0.4.0.dist-info → cognee-0.5.0.dist-info}/entry_points.txt +0 -1
  214. cognee/api/v1/cognify/code_graph_pipeline.py +0 -119
  215. cognee/api/v1/cognify/routers/get_code_pipeline_router.py +0 -90
  216. cognee/infrastructure/databases/vector/embeddings/embedding_rate_limiter.py +0 -544
  217. cognee/modules/retrieval/code_retriever.py +0 -232
  218. cognee/tasks/code/enrich_dependency_graph_checker.py +0 -35
  219. cognee/tasks/code/get_local_dependencies_checker.py +0 -20
  220. cognee/tasks/code/get_repo_dependency_graph_checker.py +0 -35
  221. cognee/tasks/documents/check_permissions_on_dataset.py +0 -26
  222. cognee/tasks/repo_processor/__init__.py +0 -2
  223. cognee/tasks/repo_processor/get_local_dependencies.py +0 -335
  224. cognee/tasks/repo_processor/get_non_code_files.py +0 -158
  225. cognee/tasks/repo_processor/get_repo_file_dependencies.py +0 -243
  226. {cognee-0.4.0.dist-info → cognee-0.5.0.dist-info}/licenses/LICENSE +0 -0
  227. {cognee-0.4.0.dist-info → cognee-0.5.0.dist-info}/licenses/NOTICE.md +0 -0
@@ -1,4 +1,4 @@
1
- from typing import List, Optional
1
+ from typing import Any, List, Optional, Type
2
2
  from abc import ABC, abstractmethod
3
3
 
4
4
  from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
@@ -14,7 +14,11 @@ class BaseGraphRetriever(ABC):
14
14
 
15
15
  @abstractmethod
16
16
  async def get_completion(
17
- self, query: str, context: Optional[List[Edge]] = None, session_id: Optional[str] = None
18
- ) -> str:
17
+ self,
18
+ query: str,
19
+ context: Optional[List[Edge]] = None,
20
+ session_id: Optional[str] = None,
21
+ response_model: Type = str,
22
+ ) -> List[Any]:
19
23
  """Generates a response using the query and optional context (triplets)."""
20
24
  pass
@@ -1,5 +1,5 @@
1
1
  from abc import ABC, abstractmethod
2
- from typing import Any, Optional
2
+ from typing import Any, Optional, Type, List
3
3
 
4
4
 
5
5
  class BaseRetriever(ABC):
@@ -12,7 +12,11 @@ class BaseRetriever(ABC):
12
12
 
13
13
  @abstractmethod
14
14
  async def get_completion(
15
- self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None
16
- ) -> Any:
15
+ self,
16
+ query: str,
17
+ context: Optional[Any] = None,
18
+ session_id: Optional[str] = None,
19
+ response_model: Type = str,
20
+ ) -> List[Any]:
17
21
  """Generates a response using the query and optional context."""
18
22
  pass
@@ -1,5 +1,5 @@
1
1
  import asyncio
2
- from typing import Any, Optional
2
+ from typing import Any, Optional, Type, List
3
3
 
4
4
  from cognee.shared.logging_utils import get_logger
5
5
  from cognee.infrastructure.databases.vector import get_vector_engine
@@ -75,8 +75,12 @@ class CompletionRetriever(BaseRetriever):
75
75
  raise NoDataError("No data found in the system, please add data first.") from error
76
76
 
77
77
  async def get_completion(
78
- self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None
79
- ) -> str:
78
+ self,
79
+ query: str,
80
+ context: Optional[Any] = None,
81
+ session_id: Optional[str] = None,
82
+ response_model: Type = str,
83
+ ) -> List[Any]:
80
84
  """
81
85
  Generates an LLM completion using the context.
82
86
 
@@ -91,6 +95,7 @@ class CompletionRetriever(BaseRetriever):
91
95
  completion; if None, it retrieves the context for the query. (default None)
92
96
  - session_id (Optional[str]): Optional session identifier for caching. If None,
93
97
  defaults to 'default_session'. (default None)
98
+ - response_model (Type): The Pydantic model type for structured output. (default str)
94
99
 
95
100
  Returns:
96
101
  --------
@@ -118,6 +123,7 @@ class CompletionRetriever(BaseRetriever):
118
123
  system_prompt_path=self.system_prompt_path,
119
124
  system_prompt=self.system_prompt,
120
125
  conversation_history=conversation_history,
126
+ response_model=response_model,
121
127
  ),
122
128
  )
123
129
  else:
@@ -127,6 +133,7 @@ class CompletionRetriever(BaseRetriever):
127
133
  user_prompt_path=self.user_prompt_path,
128
134
  system_prompt_path=self.system_prompt_path,
129
135
  system_prompt=self.system_prompt,
136
+ response_model=response_model,
130
137
  )
131
138
 
132
139
  if session_save:
@@ -137,4 +144,4 @@ class CompletionRetriever(BaseRetriever):
137
144
  session_id=session_id,
138
145
  )
139
146
 
140
- return completion
147
+ return [completion]
@@ -1,5 +1,5 @@
1
1
  import asyncio
2
- from typing import Optional, List, Type
2
+ from typing import Optional, List, Type, Any
3
3
  from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
4
4
  from cognee.shared.logging_utils import get_logger
5
5
  from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
@@ -39,6 +39,8 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
39
39
  node_type: Optional[Type] = None,
40
40
  node_name: Optional[List[str]] = None,
41
41
  save_interaction: bool = False,
42
+ wide_search_top_k: Optional[int] = 100,
43
+ triplet_distance_penalty: Optional[float] = 3.5,
42
44
  ):
43
45
  super().__init__(
44
46
  user_prompt_path=user_prompt_path,
@@ -48,6 +50,8 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
48
50
  node_name=node_name,
49
51
  save_interaction=save_interaction,
50
52
  system_prompt=system_prompt,
53
+ wide_search_top_k=wide_search_top_k,
54
+ triplet_distance_penalty=triplet_distance_penalty,
51
55
  )
52
56
 
53
57
  async def get_completion(
@@ -56,7 +60,8 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
56
60
  context: Optional[List[Edge]] = None,
57
61
  session_id: Optional[str] = None,
58
62
  context_extension_rounds=4,
59
- ) -> List[str]:
63
+ response_model: Type = str,
64
+ ) -> List[Any]:
60
65
  """
61
66
  Extends the context for a given query by retrieving related triplets and generating new
62
67
  completions based on them.
@@ -76,6 +81,7 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
76
81
  defaults to 'default_session'. (default None)
77
82
  - context_extension_rounds: The maximum number of rounds to extend the context with
78
83
  new triplets before halting. (default 4)
84
+ - response_model (Type): The Pydantic model type for structured output. (default str)
79
85
 
80
86
  Returns:
81
87
  --------
@@ -143,6 +149,7 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
143
149
  system_prompt_path=self.system_prompt_path,
144
150
  system_prompt=self.system_prompt,
145
151
  conversation_history=conversation_history,
152
+ response_model=response_model,
146
153
  ),
147
154
  )
148
155
  else:
@@ -152,6 +159,7 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
152
159
  user_prompt_path=self.user_prompt_path,
153
160
  system_prompt_path=self.system_prompt_path,
154
161
  system_prompt=self.system_prompt,
162
+ response_model=response_model,
155
163
  )
156
164
 
157
165
  if self.save_interaction and context_text and triplets and completion:
@@ -7,7 +7,7 @@ from cognee.shared.logging_utils import get_logger
7
7
 
8
8
  from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
9
9
  from cognee.modules.retrieval.utils.completion import (
10
- generate_structured_completion,
10
+ generate_completion,
11
11
  summarize_text,
12
12
  )
13
13
  from cognee.modules.retrieval.utils.session_cache import (
@@ -44,7 +44,6 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
44
44
  questions based on reasoning. The public methods are:
45
45
 
46
46
  - get_completion
47
- - get_structured_completion
48
47
 
49
48
  Instance variables include:
50
49
  - validation_system_prompt_path
@@ -66,6 +65,8 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
66
65
  node_type: Optional[Type] = None,
67
66
  node_name: Optional[List[str]] = None,
68
67
  save_interaction: bool = False,
68
+ wide_search_top_k: Optional[int] = 100,
69
+ triplet_distance_penalty: Optional[float] = 3.5,
69
70
  ):
70
71
  super().__init__(
71
72
  user_prompt_path=user_prompt_path,
@@ -75,6 +76,8 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
75
76
  node_type=node_type,
76
77
  node_name=node_name,
77
78
  save_interaction=save_interaction,
79
+ wide_search_top_k=wide_search_top_k,
80
+ triplet_distance_penalty=triplet_distance_penalty,
78
81
  )
79
82
  self.validation_system_prompt_path = validation_system_prompt_path
80
83
  self.validation_user_prompt_path = validation_user_prompt_path
@@ -121,7 +124,7 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
121
124
  triplets += await self.get_context(followup_question)
122
125
  context_text = await self.resolve_edges_to_text(list(set(triplets)))
123
126
 
124
- completion = await generate_structured_completion(
127
+ completion = await generate_completion(
125
128
  query=query,
126
129
  context=context_text,
127
130
  user_prompt_path=self.user_prompt_path,
@@ -165,24 +168,28 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
165
168
 
166
169
  return completion, context_text, triplets
167
170
 
168
- async def get_structured_completion(
171
+ async def get_completion(
169
172
  self,
170
173
  query: str,
171
174
  context: Optional[List[Edge]] = None,
172
175
  session_id: Optional[str] = None,
173
- max_iter: int = 4,
176
+ max_iter=4,
174
177
  response_model: Type = str,
175
- ) -> Any:
178
+ ) -> List[Any]:
176
179
  """
177
- Generate structured completion responses based on a user query and contextual information.
180
+ Generate completion responses based on a user query and contextual information.
178
181
 
179
- This method applies the same chain-of-thought logic as get_completion but returns
182
+ This method interacts with a language model client to retrieve a structured response,
183
+ using a series of iterations to refine the answers and generate follow-up questions
184
+ based on reasoning derived from previous outputs. It raises exceptions if the context
185
+ retrieval fails or if the model encounters issues in generating outputs. It returns
180
186
  structured output using the provided response model.
181
187
 
182
188
  Parameters:
183
189
  -----------
190
+
184
191
  - query (str): The user's query to be processed and answered.
185
- - context (Optional[List[Edge]]): Optional context that may assist in answering the query.
192
+ - context (Optional[Any]): Optional context that may assist in answering the query.
186
193
  If not provided, it will be fetched based on the query. (default None)
187
194
  - session_id (Optional[str]): Optional session identifier for caching. If None,
188
195
  defaults to 'default_session'. (default None)
@@ -192,7 +199,8 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
192
199
 
193
200
  Returns:
194
201
  --------
195
- - Any: The generated structured completion based on the response model.
202
+
203
+ - List[str]: A list containing the generated answer to the user's query.
196
204
  """
197
205
  # Check if session saving is enabled
198
206
  cache_config = CacheConfig()
@@ -228,45 +236,4 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
228
236
  session_id=session_id,
229
237
  )
230
238
 
231
- return completion
232
-
233
- async def get_completion(
234
- self,
235
- query: str,
236
- context: Optional[List[Edge]] = None,
237
- session_id: Optional[str] = None,
238
- max_iter=4,
239
- ) -> List[str]:
240
- """
241
- Generate completion responses based on a user query and contextual information.
242
-
243
- This method interacts with a language model client to retrieve a structured response,
244
- using a series of iterations to refine the answers and generate follow-up questions
245
- based on reasoning derived from previous outputs. It raises exceptions if the context
246
- retrieval fails or if the model encounters issues in generating outputs.
247
-
248
- Parameters:
249
- -----------
250
-
251
- - query (str): The user's query to be processed and answered.
252
- - context (Optional[Any]): Optional context that may assist in answering the query.
253
- If not provided, it will be fetched based on the query. (default None)
254
- - session_id (Optional[str]): Optional session identifier for caching. If None,
255
- defaults to 'default_session'. (default None)
256
- - max_iter: The maximum number of iterations to refine the answer and generate
257
- follow-up questions. (default 4)
258
-
259
- Returns:
260
- --------
261
-
262
- - List[str]: A list containing the generated answer to the user's query.
263
- """
264
- completion = await self.get_structured_completion(
265
- query=query,
266
- context=context,
267
- session_id=session_id,
268
- max_iter=max_iter,
269
- response_model=str,
270
- )
271
-
272
239
  return [completion]
@@ -47,6 +47,8 @@ class GraphCompletionRetriever(BaseGraphRetriever):
47
47
  node_type: Optional[Type] = None,
48
48
  node_name: Optional[List[str]] = None,
49
49
  save_interaction: bool = False,
50
+ wide_search_top_k: Optional[int] = 100,
51
+ triplet_distance_penalty: Optional[float] = 3.5,
50
52
  ):
51
53
  """Initialize retriever with prompt paths and search parameters."""
52
54
  self.save_interaction = save_interaction
@@ -54,8 +56,10 @@ class GraphCompletionRetriever(BaseGraphRetriever):
54
56
  self.system_prompt_path = system_prompt_path
55
57
  self.system_prompt = system_prompt
56
58
  self.top_k = top_k if top_k is not None else 5
59
+ self.wide_search_top_k = wide_search_top_k
57
60
  self.node_type = node_type
58
61
  self.node_name = node_name
62
+ self.triplet_distance_penalty = triplet_distance_penalty
59
63
 
60
64
  async def resolve_edges_to_text(self, retrieved_edges: list) -> str:
61
65
  """
@@ -105,6 +109,8 @@ class GraphCompletionRetriever(BaseGraphRetriever):
105
109
  collections=vector_index_collections or None,
106
110
  node_type=self.node_type,
107
111
  node_name=self.node_name,
112
+ wide_search_top_k=self.wide_search_top_k,
113
+ triplet_distance_penalty=self.triplet_distance_penalty,
108
114
  )
109
115
 
110
116
  return found_triplets
@@ -141,12 +147,17 @@ class GraphCompletionRetriever(BaseGraphRetriever):
141
147
 
142
148
  return triplets
143
149
 
150
+ async def convert_retrieved_objects_to_context(self, triplets: List[Edge]):
151
+ context = await self.resolve_edges_to_text(triplets)
152
+ return context
153
+
144
154
  async def get_completion(
145
155
  self,
146
156
  query: str,
147
157
  context: Optional[List[Edge]] = None,
148
158
  session_id: Optional[str] = None,
149
- ) -> List[str]:
159
+ response_model: Type = str,
160
+ ) -> List[Any]:
150
161
  """
151
162
  Generates a completion using graph connections context based on a query.
152
163
 
@@ -188,6 +199,7 @@ class GraphCompletionRetriever(BaseGraphRetriever):
188
199
  system_prompt_path=self.system_prompt_path,
189
200
  system_prompt=self.system_prompt,
190
201
  conversation_history=conversation_history,
202
+ response_model=response_model,
191
203
  ),
192
204
  )
193
205
  else:
@@ -197,6 +209,7 @@ class GraphCompletionRetriever(BaseGraphRetriever):
197
209
  user_prompt_path=self.user_prompt_path,
198
210
  system_prompt_path=self.system_prompt_path,
199
211
  system_prompt=self.system_prompt,
212
+ response_model=response_model,
200
213
  )
201
214
 
202
215
  if self.save_interaction and context and triplets and completion:
@@ -26,6 +26,8 @@ class GraphSummaryCompletionRetriever(GraphCompletionRetriever):
26
26
  node_type: Optional[Type] = None,
27
27
  node_name: Optional[List[str]] = None,
28
28
  save_interaction: bool = False,
29
+ wide_search_top_k: Optional[int] = 100,
30
+ triplet_distance_penalty: Optional[float] = 3.5,
29
31
  ):
30
32
  """Initialize retriever with default prompt paths and search parameters."""
31
33
  super().__init__(
@@ -36,6 +38,8 @@ class GraphSummaryCompletionRetriever(GraphCompletionRetriever):
36
38
  node_name=node_name,
37
39
  save_interaction=save_interaction,
38
40
  system_prompt=system_prompt,
41
+ wide_search_top_k=wide_search_top_k,
42
+ triplet_distance_penalty=triplet_distance_penalty,
39
43
  )
40
44
  self.summarize_prompt_path = summarize_prompt_path
41
45
 
@@ -0,0 +1,10 @@
1
+ from typing import Type
2
+
3
+ from .base_retriever import BaseRetriever
4
+ from .registered_community_retrievers import registered_community_retrievers
5
+ from ..search.types import SearchType
6
+
7
+
8
+ def use_retriever(search_type: SearchType, retriever: Type[BaseRetriever]):
9
+ """Register a retriever class for a given search type."""
10
+ registered_community_retrievers[search_type] = retriever
@@ -0,0 +1 @@
1
+ registered_community_retrievers = {}
@@ -47,6 +47,8 @@ class TemporalRetriever(GraphCompletionRetriever):
47
47
  top_k: Optional[int] = 5,
48
48
  node_type: Optional[Type] = None,
49
49
  node_name: Optional[List[str]] = None,
50
+ wide_search_top_k: Optional[int] = 100,
51
+ triplet_distance_penalty: Optional[float] = 3.5,
50
52
  ):
51
53
  super().__init__(
52
54
  user_prompt_path=user_prompt_path,
@@ -54,6 +56,8 @@ class TemporalRetriever(GraphCompletionRetriever):
54
56
  top_k=top_k,
55
57
  node_type=node_type,
56
58
  node_name=node_name,
59
+ wide_search_top_k=wide_search_top_k,
60
+ triplet_distance_penalty=triplet_distance_penalty,
57
61
  )
58
62
  self.user_prompt_path = user_prompt_path
59
63
  self.system_prompt_path = system_prompt_path
@@ -146,8 +150,12 @@ class TemporalRetriever(GraphCompletionRetriever):
146
150
  return self.descriptions_to_string(top_k_events)
147
151
 
148
152
  async def get_completion(
149
- self, query: str, context: Optional[str] = None, session_id: Optional[str] = None
150
- ) -> List[str]:
153
+ self,
154
+ query: str,
155
+ context: Optional[str] = None,
156
+ session_id: Optional[str] = None,
157
+ response_model: Type = str,
158
+ ) -> List[Any]:
151
159
  """
152
160
  Generates a response using the query and optional context.
153
161
 
@@ -159,6 +167,7 @@ class TemporalRetriever(GraphCompletionRetriever):
159
167
  retrieved based on the query. (default None)
160
168
  - session_id (Optional[str]): Optional session identifier for caching. If None,
161
169
  defaults to 'default_session'. (default None)
170
+ - response_model (Type): The Pydantic model type for structured output. (default str)
162
171
 
163
172
  Returns:
164
173
  --------
@@ -186,6 +195,7 @@ class TemporalRetriever(GraphCompletionRetriever):
186
195
  user_prompt_path=self.user_prompt_path,
187
196
  system_prompt_path=self.system_prompt_path,
188
197
  conversation_history=conversation_history,
198
+ response_model=response_model,
189
199
  ),
190
200
  )
191
201
  else:
@@ -194,6 +204,7 @@ class TemporalRetriever(GraphCompletionRetriever):
194
204
  context=context,
195
205
  user_prompt_path=self.user_prompt_path,
196
206
  system_prompt_path=self.system_prompt_path,
207
+ response_model=response_model,
197
208
  )
198
209
 
199
210
  if session_save:
@@ -0,0 +1,182 @@
1
+ import asyncio
2
+ from typing import Any, Optional, Type, List
3
+
4
+ from cognee.shared.logging_utils import get_logger
5
+ from cognee.infrastructure.databases.vector import get_vector_engine
6
+ from cognee.modules.retrieval.utils.completion import generate_completion, summarize_text
7
+ from cognee.modules.retrieval.utils.session_cache import (
8
+ save_conversation_history,
9
+ get_conversation_history,
10
+ )
11
+ from cognee.modules.retrieval.base_retriever import BaseRetriever
12
+ from cognee.modules.retrieval.exceptions.exceptions import NoDataError
13
+ from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError
14
+ from cognee.context_global_variables import session_user
15
+ from cognee.infrastructure.databases.cache.config import CacheConfig
16
+
17
logger = get_logger("TripletRetriever")


class TripletRetriever(BaseRetriever):
    """
    Retriever for handling LLM-based completion searches using triplets.

    Context is read from the ``Triplet_text`` vector collection, which must be created
    beforehand by the create_triplet_embeddings memify pipeline.

    Public methods:
    - get_context(query: str) -> str
    - get_completion(query: str, context: Optional[Any] = None,
      session_id: Optional[str] = None, response_model: Type = str) -> List[Any]
    """

    def __init__(
        self,
        user_prompt_path: str = "context_for_question.txt",
        system_prompt_path: str = "answer_simple_question.txt",
        system_prompt: Optional[str] = None,
        top_k: Optional[int] = 5,
    ):
        """
        Initialize retriever with optional custom prompt paths.

        Parameters:
        -----------

            - user_prompt_path (str): Prompt template file passed to generate_completion
              as the user prompt path.
            - system_prompt_path (str): Prompt template file passed to generate_completion
              as the system prompt path.
            - system_prompt (Optional[str]): Explicit system prompt, passed through to
              generate_completion. (default None)
            - top_k (Optional[int]): Maximum number of triplets fetched per query. An explicit
              None falls back to 1, not to the signature default. (default 5)
        """
        self.user_prompt_path = user_prompt_path
        self.system_prompt_path = system_prompt_path
        # NOTE: None maps to 1 here, which differs from the default value of 5 above.
        self.top_k = top_k if top_k is not None else 1
        self.system_prompt = system_prompt

    async def get_context(self, query: str) -> str:
        """
        Retrieves relevant triplets as context.

        Fetches up to ``self.top_k`` triplets for the query from the ``Triplet_text``
        vector collection and joins their ``text`` payloads with newlines.

        Parameters:
        -----------

            - query (str): The query string used to search for relevant triplets.

        Returns:
        --------

            - str: The newline-joined text of the retrieved triplets, or an empty string
              if none are found.

        Raises:
        -------

            - NoDataError: If the ``Triplet_text`` collection does not exist, or if the
              vector engine raises CollectionNotFoundError.
        """
        vector_engine = get_vector_engine()

        # Only the vector-engine calls can fail with a missing collection; keep the try narrow.
        try:
            if not await vector_engine.has_collection(collection_name="Triplet_text"):
                logger.error("Triplet_text collection not found")
                raise NoDataError(
                    "In order to use TRIPLET_COMPLETION first use the create_triplet_embeddings memify pipeline. "
                )

            found_triplets = await vector_engine.search("Triplet_text", query, limit=self.top_k)
        except CollectionNotFoundError as error:
            logger.error("Triplet_text collection not found")
            raise NoDataError("No data found in the system, please add data first.") from error

        if not found_triplets:
            return ""

        return "\n".join(found_triplet.payload["text"] for found_triplet in found_triplets)

    async def get_completion(
        self,
        query: str,
        context: Optional[Any] = None,
        session_id: Optional[str] = None,
        response_model: Type = str,
    ) -> List[Any]:
        """
        Generates an LLM completion using the context.

        Retrieves context if not provided and generates a completion based on the query and
        context using an external completion generator. When the current session user has an
        ``id`` and ``CacheConfig().caching`` is enabled, the conversation history is loaded
        first, and the exchange is saved back to the session cache afterwards.

        Parameters:
        -----------

            - query (str): The query string to be used for generating a completion.
            - context (Optional[Any]): Optional pre-fetched context to use for generating the
              completion; if None, it retrieves the context for the query. (default None)
            - session_id (Optional[str]): Optional session identifier forwarded to the
              session-cache helpers; only used when session caching is active. (default None)
            - response_model (Type): The Pydantic model type for structured output. (default str)

        Returns:
        --------

            - List[Any]: A single-element list containing the generated completion.
        """
        if context is None:
            context = await self.get_context(query)

        cache_config = CacheConfig()
        user = session_user.get()
        user_id = getattr(user, "id", None)
        # Session history is used only for an identified user with caching enabled.
        session_save = user_id and cache_config.caching

        if session_save:
            completion = await self._get_completion_with_session(
                query=query,
                context=context,
                session_id=session_id,
                response_model=response_model,
            )
        else:
            completion = await self._get_completion_without_session(
                query=query,
                context=context,
                response_model=response_model,
            )

        return [completion]

    async def _get_completion_with_session(
        self,
        query: str,
        context: Any,
        session_id: Optional[str],
        response_model: Type,
    ) -> Any:
        """
        Generate completion with session history and caching.

        Loads the conversation history for ``session_id``. It then summarizes the context and
        generates the completion concurrently, and finally saves the query, the context
        summary and the answer to the session history.
        """
        conversation_history = await get_conversation_history(session_id=session_id)

        # The summary is only needed for the history entry, so it runs alongside the LLM call.
        context_summary, completion = await asyncio.gather(
            summarize_text(context),
            generate_completion(
                query=query,
                context=context,
                user_prompt_path=self.user_prompt_path,
                system_prompt_path=self.system_prompt_path,
                system_prompt=self.system_prompt,
                conversation_history=conversation_history,
                response_model=response_model,
            ),
        )

        await save_conversation_history(
            query=query,
            context_summary=context_summary,
            answer=completion,
            session_id=session_id,
        )

        return completion

    async def _get_completion_without_session(
        self,
        query: str,
        context: Any,
        response_model: Type,
    ) -> Any:
        """Generate completion without session history."""
        completion = await generate_completion(
            query=query,
            context=context,
            user_prompt_path=self.user_prompt_path,
            system_prompt_path=self.system_prompt_path,
            system_prompt=self.system_prompt,
            response_model=response_model,
        )

        return completion