cognee 0.5.1__py3-none-any.whl → 0.5.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (265)
  1. cognee/__init__.py +2 -0
  2. cognee/alembic/README +1 -0
  3. cognee/alembic/env.py +107 -0
  4. cognee/alembic/script.py.mako +26 -0
  5. cognee/alembic/versions/1a58b986e6e1_enable_delete_for_old_tutorial_notebooks.py +52 -0
  6. cognee/alembic/versions/1d0bb7fede17_add_pipeline_run_status.py +33 -0
  7. cognee/alembic/versions/1daae0df1866_incremental_loading.py +48 -0
  8. cognee/alembic/versions/211ab850ef3d_add_sync_operations_table.py +118 -0
  9. cognee/alembic/versions/45957f0a9849_add_notebook_table.py +46 -0
  10. cognee/alembic/versions/46a6ce2bd2b2_expand_dataset_database_with_json_.py +333 -0
  11. cognee/alembic/versions/482cd6517ce4_add_default_user.py +30 -0
  12. cognee/alembic/versions/76625596c5c3_expand_dataset_database_for_multi_user.py +98 -0
  13. cognee/alembic/versions/8057ae7329c2_initial_migration.py +25 -0
  14. cognee/alembic/versions/9e7a3cb85175_loader_separation.py +104 -0
  15. cognee/alembic/versions/a1b2c3d4e5f6_add_label_column_to_data.py +38 -0
  16. cognee/alembic/versions/ab7e313804ae_permission_system_rework.py +236 -0
  17. cognee/alembic/versions/b9274c27a25a_kuzu_11_migration.py +75 -0
  18. cognee/alembic/versions/c946955da633_multi_tenant_support.py +137 -0
  19. cognee/alembic/versions/e1ec1dcb50b6_add_last_accessed_to_data.py +51 -0
  20. cognee/alembic/versions/e4ebee1091e7_expand_data_model_info.py +140 -0
  21. cognee/alembic.ini +117 -0
  22. cognee/api/v1/add/add.py +2 -1
  23. cognee/api/v1/add/routers/get_add_router.py +2 -0
  24. cognee/api/v1/cognify/cognify.py +11 -6
  25. cognee/api/v1/cognify/routers/get_cognify_router.py +8 -0
  26. cognee/api/v1/config/config.py +60 -0
  27. cognee/api/v1/datasets/routers/get_datasets_router.py +46 -3
  28. cognee/api/v1/memify/routers/get_memify_router.py +3 -0
  29. cognee/api/v1/search/routers/get_search_router.py +21 -6
  30. cognee/api/v1/search/search.py +21 -5
  31. cognee/api/v1/sync/routers/get_sync_router.py +3 -3
  32. cognee/cli/commands/add_command.py +1 -1
  33. cognee/cli/commands/cognify_command.py +6 -0
  34. cognee/cli/commands/config_command.py +1 -1
  35. cognee/context_global_variables.py +5 -1
  36. cognee/eval_framework/answer_generation/answer_generation_executor.py +7 -8
  37. cognee/infrastructure/databases/cache/cache_db_interface.py +38 -1
  38. cognee/infrastructure/databases/cache/config.py +6 -0
  39. cognee/infrastructure/databases/cache/fscache/FsCacheAdapter.py +21 -0
  40. cognee/infrastructure/databases/cache/get_cache_engine.py +9 -3
  41. cognee/infrastructure/databases/cache/redis/RedisAdapter.py +60 -1
  42. cognee/infrastructure/databases/dataset_database_handler/supported_dataset_database_handlers.py +7 -0
  43. cognee/infrastructure/databases/graph/get_graph_engine.py +29 -1
  44. cognee/infrastructure/databases/graph/neo4j_driver/Neo4jAuraDevDatasetDatabaseHandler.py +62 -27
  45. cognee/infrastructure/databases/hybrid/neptune_analytics/NeptuneAnalyticsAdapter.py +17 -4
  46. cognee/infrastructure/databases/relational/config.py +16 -1
  47. cognee/infrastructure/databases/relational/create_relational_engine.py +13 -3
  48. cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py +26 -3
  49. cognee/infrastructure/databases/vector/chromadb/ChromaDBAdapter.py +2 -0
  50. cognee/infrastructure/databases/vector/config.py +6 -0
  51. cognee/infrastructure/databases/vector/create_vector_engine.py +70 -16
  52. cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py +64 -9
  53. cognee/infrastructure/databases/vector/embeddings/OllamaEmbeddingEngine.py +13 -2
  54. cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py +16 -3
  55. cognee/infrastructure/databases/vector/models/ScoredResult.py +3 -3
  56. cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py +16 -3
  57. cognee/infrastructure/databases/vector/pgvector/PGVectorDatasetDatabaseHandler.py +86 -0
  58. cognee/infrastructure/databases/vector/pgvector/create_db_and_tables.py +81 -2
  59. cognee/infrastructure/databases/vector/vector_db_interface.py +8 -0
  60. cognee/infrastructure/files/utils/get_data_file_path.py +33 -27
  61. cognee/infrastructure/llm/LLMGateway.py +0 -13
  62. cognee/infrastructure/llm/prompts/extract_query_time.txt +1 -1
  63. cognee/infrastructure/llm/prompts/generate_event_entity_prompt.txt +1 -1
  64. cognee/infrastructure/llm/prompts/generate_event_graph_prompt.txt +1 -1
  65. cognee/infrastructure/llm/prompts/generate_graph_prompt.txt +2 -2
  66. cognee/infrastructure/llm/prompts/generate_graph_prompt_guided.txt +1 -1
  67. cognee/infrastructure/llm/prompts/generate_graph_prompt_oneshot.txt +2 -2
  68. cognee/infrastructure/llm/prompts/generate_graph_prompt_simple.txt +1 -1
  69. cognee/infrastructure/llm/prompts/generate_graph_prompt_strict.txt +1 -1
  70. cognee/infrastructure/llm/prompts/search_type_selector_prompt.txt +6 -6
  71. cognee/infrastructure/llm/prompts/test.txt +1 -1
  72. cognee/infrastructure/llm/prompts/translate_content.txt +19 -0
  73. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py +17 -12
  74. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py +31 -25
  75. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py +132 -7
  76. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py +29 -5
  77. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/llama_cpp/adapter.py +191 -0
  78. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/llm_interface.py +2 -6
  79. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py +58 -13
  80. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/adapter.py +0 -1
  81. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py +25 -131
  82. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/types.py +10 -0
  83. cognee/modules/chunking/models/DocumentChunk.py +0 -1
  84. cognee/modules/cognify/config.py +2 -0
  85. cognee/modules/data/models/Data.py +3 -1
  86. cognee/modules/engine/models/Entity.py +0 -1
  87. cognee/modules/engine/operations/setup.py +6 -0
  88. cognee/modules/graph/cognee_graph/CogneeGraph.py +150 -37
  89. cognee/modules/graph/cognee_graph/CogneeGraphElements.py +48 -2
  90. cognee/modules/graph/utils/__init__.py +1 -0
  91. cognee/modules/graph/utils/get_entity_nodes_from_triplets.py +12 -0
  92. cognee/modules/notebooks/methods/__init__.py +1 -0
  93. cognee/modules/notebooks/methods/create_notebook.py +0 -34
  94. cognee/modules/notebooks/methods/create_tutorial_notebooks.py +191 -0
  95. cognee/modules/notebooks/methods/get_notebooks.py +12 -8
  96. cognee/modules/notebooks/tutorials/cognee-basics/cell-1.md +3 -0
  97. cognee/modules/notebooks/tutorials/cognee-basics/cell-2.md +10 -0
  98. cognee/modules/notebooks/tutorials/cognee-basics/cell-3.md +7 -0
  99. cognee/modules/notebooks/tutorials/cognee-basics/cell-4.py +28 -0
  100. cognee/modules/notebooks/tutorials/cognee-basics/cell-5.py +3 -0
  101. cognee/modules/notebooks/tutorials/cognee-basics/cell-6.py +9 -0
  102. cognee/modules/notebooks/tutorials/cognee-basics/cell-7.py +17 -0
  103. cognee/modules/notebooks/tutorials/cognee-basics/config.json +4 -0
  104. cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-1.md +3 -0
  105. cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-10.md +3 -0
  106. cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-11.md +3 -0
  107. cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-12.py +3 -0
  108. cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-13.md +7 -0
  109. cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-14.py +6 -0
  110. cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-15.md +3 -0
  111. cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-16.py +7 -0
  112. cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-2.md +9 -0
  113. cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-3.md +7 -0
  114. cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-4.md +9 -0
  115. cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-5.md +5 -0
  116. cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-6.py +13 -0
  117. cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-7.md +3 -0
  118. cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-8.md +3 -0
  119. cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-9.py +31 -0
  120. cognee/modules/notebooks/tutorials/python-development-with-cognee/config.json +4 -0
  121. cognee/modules/notebooks/tutorials/python-development-with-cognee/data/copilot_conversations.json +107 -0
  122. cognee/modules/notebooks/tutorials/python-development-with-cognee/data/guido_contributions.json +976 -0
  123. cognee/modules/notebooks/tutorials/python-development-with-cognee/data/my_developer_rules.md +79 -0
  124. cognee/modules/notebooks/tutorials/python-development-with-cognee/data/pep_style_guide.md +74 -0
  125. cognee/modules/notebooks/tutorials/python-development-with-cognee/data/zen_principles.md +74 -0
  126. cognee/modules/retrieval/EntityCompletionRetriever.py +51 -38
  127. cognee/modules/retrieval/__init__.py +0 -1
  128. cognee/modules/retrieval/base_retriever.py +66 -10
  129. cognee/modules/retrieval/chunks_retriever.py +57 -49
  130. cognee/modules/retrieval/coding_rules_retriever.py +12 -5
  131. cognee/modules/retrieval/completion_retriever.py +29 -28
  132. cognee/modules/retrieval/cypher_search_retriever.py +25 -20
  133. cognee/modules/retrieval/graph_completion_context_extension_retriever.py +42 -46
  134. cognee/modules/retrieval/graph_completion_cot_retriever.py +68 -51
  135. cognee/modules/retrieval/graph_completion_retriever.py +78 -63
  136. cognee/modules/retrieval/graph_summary_completion_retriever.py +2 -0
  137. cognee/modules/retrieval/lexical_retriever.py +34 -12
  138. cognee/modules/retrieval/natural_language_retriever.py +18 -15
  139. cognee/modules/retrieval/summaries_retriever.py +51 -34
  140. cognee/modules/retrieval/temporal_retriever.py +59 -49
  141. cognee/modules/retrieval/triplet_retriever.py +32 -33
  142. cognee/modules/retrieval/utils/access_tracking.py +88 -0
  143. cognee/modules/retrieval/utils/brute_force_triplet_search.py +99 -103
  144. cognee/modules/retrieval/utils/node_edge_vector_search.py +174 -0
  145. cognee/modules/search/methods/__init__.py +1 -0
  146. cognee/modules/search/methods/get_retriever_output.py +53 -0
  147. cognee/modules/search/methods/get_search_type_retriever_instance.py +252 -0
  148. cognee/modules/search/methods/search.py +90 -222
  149. cognee/modules/search/models/SearchResultPayload.py +67 -0
  150. cognee/modules/search/types/SearchResult.py +1 -8
  151. cognee/modules/search/types/SearchType.py +1 -2
  152. cognee/modules/search/types/__init__.py +1 -1
  153. cognee/modules/search/utils/__init__.py +1 -2
  154. cognee/modules/search/utils/transform_insights_to_graph.py +2 -2
  155. cognee/modules/search/utils/{transform_context_to_graph.py → transform_triplets_to_graph.py} +2 -2
  156. cognee/modules/users/authentication/default/default_transport.py +11 -1
  157. cognee/modules/users/authentication/get_api_auth_backend.py +2 -1
  158. cognee/modules/users/authentication/get_client_auth_backend.py +2 -1
  159. cognee/modules/users/methods/create_user.py +0 -9
  160. cognee/modules/users/permissions/methods/has_user_management_permission.py +29 -0
  161. cognee/modules/visualization/cognee_network_visualization.py +1 -1
  162. cognee/run_migrations.py +48 -0
  163. cognee/shared/exceptions/__init__.py +1 -3
  164. cognee/shared/exceptions/exceptions.py +11 -1
  165. cognee/shared/usage_logger.py +332 -0
  166. cognee/shared/utils.py +12 -5
  167. cognee/tasks/chunks/__init__.py +9 -0
  168. cognee/tasks/cleanup/cleanup_unused_data.py +172 -0
  169. cognee/tasks/graph/__init__.py +7 -0
  170. cognee/tasks/ingestion/data_item.py +8 -0
  171. cognee/tasks/ingestion/ingest_data.py +12 -1
  172. cognee/tasks/ingestion/save_data_item_to_storage.py +5 -0
  173. cognee/tasks/memify/__init__.py +8 -0
  174. cognee/tasks/memify/extract_usage_frequency.py +613 -0
  175. cognee/tasks/summarization/models.py +0 -2
  176. cognee/tasks/temporal_graph/__init__.py +0 -1
  177. cognee/tasks/translation/__init__.py +96 -0
  178. cognee/tasks/translation/config.py +110 -0
  179. cognee/tasks/translation/detect_language.py +190 -0
  180. cognee/tasks/translation/exceptions.py +62 -0
  181. cognee/tasks/translation/models.py +72 -0
  182. cognee/tasks/translation/providers/__init__.py +44 -0
  183. cognee/tasks/translation/providers/azure_provider.py +192 -0
  184. cognee/tasks/translation/providers/base.py +85 -0
  185. cognee/tasks/translation/providers/google_provider.py +158 -0
  186. cognee/tasks/translation/providers/llm_provider.py +143 -0
  187. cognee/tasks/translation/translate_content.py +282 -0
  188. cognee/tasks/web_scraper/default_url_crawler.py +6 -2
  189. cognee/tests/cli_tests/cli_unit_tests/test_cli_commands.py +1 -0
  190. cognee/tests/cli_tests/cli_unit_tests/test_cli_edge_cases.py +3 -0
  191. cognee/tests/integration/retrieval/test_brute_force_triplet_search_with_cognify.py +62 -0
  192. cognee/tests/integration/retrieval/test_chunks_retriever.py +351 -0
  193. cognee/tests/integration/retrieval/test_graph_completion_retriever.py +276 -0
  194. cognee/tests/integration/retrieval/test_graph_completion_retriever_context_extension.py +228 -0
  195. cognee/tests/integration/retrieval/test_graph_completion_retriever_cot.py +217 -0
  196. cognee/tests/integration/retrieval/test_rag_completion_retriever.py +319 -0
  197. cognee/tests/integration/retrieval/test_structured_output.py +258 -0
  198. cognee/tests/integration/retrieval/test_summaries_retriever.py +195 -0
  199. cognee/tests/integration/retrieval/test_temporal_retriever.py +336 -0
  200. cognee/tests/integration/retrieval/test_triplet_retriever.py +45 -1
  201. cognee/tests/integration/shared/test_usage_logger_integration.py +255 -0
  202. cognee/tests/tasks/translation/README.md +147 -0
  203. cognee/tests/tasks/translation/__init__.py +1 -0
  204. cognee/tests/tasks/translation/config_test.py +93 -0
  205. cognee/tests/tasks/translation/detect_language_test.py +118 -0
  206. cognee/tests/tasks/translation/providers_test.py +151 -0
  207. cognee/tests/tasks/translation/translate_content_test.py +213 -0
  208. cognee/tests/test_chromadb.py +1 -1
  209. cognee/tests/test_cleanup_unused_data.py +165 -0
  210. cognee/tests/test_custom_data_label.py +68 -0
  211. cognee/tests/test_delete_by_id.py +6 -6
  212. cognee/tests/test_extract_usage_frequency.py +308 -0
  213. cognee/tests/test_kuzu.py +17 -7
  214. cognee/tests/test_lancedb.py +3 -1
  215. cognee/tests/test_library.py +1 -1
  216. cognee/tests/test_neo4j.py +17 -7
  217. cognee/tests/test_neptune_analytics_vector.py +3 -1
  218. cognee/tests/test_permissions.py +172 -187
  219. cognee/tests/test_pgvector.py +3 -1
  220. cognee/tests/test_relational_db_migration.py +15 -1
  221. cognee/tests/test_remote_kuzu.py +3 -1
  222. cognee/tests/test_s3_file_storage.py +1 -1
  223. cognee/tests/test_search_db.py +345 -205
  224. cognee/tests/test_usage_logger_e2e.py +268 -0
  225. cognee/tests/unit/api/test_get_raw_data_endpoint.py +206 -0
  226. cognee/tests/unit/eval_framework/answer_generation_test.py +4 -3
  227. cognee/tests/unit/eval_framework/benchmark_adapters_test.py +25 -0
  228. cognee/tests/unit/eval_framework/corpus_builder_test.py +33 -4
  229. cognee/tests/unit/infrastructure/databases/cache/test_cache_config.py +2 -0
  230. cognee/tests/unit/infrastructure/databases/relational/test_RelationalConfig.py +69 -0
  231. cognee/tests/unit/modules/graph/cognee_graph_elements_test.py +42 -2
  232. cognee/tests/unit/modules/graph/cognee_graph_test.py +329 -31
  233. cognee/tests/unit/modules/retrieval/chunks_retriever_test.py +122 -168
  234. cognee/tests/unit/modules/retrieval/conversation_history_test.py +338 -0
  235. cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py +486 -157
  236. cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py +693 -155
  237. cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py +619 -200
  238. cognee/tests/unit/modules/retrieval/rag_completion_retriever_test.py +300 -171
  239. cognee/tests/unit/modules/retrieval/summaries_retriever_test.py +184 -155
  240. cognee/tests/unit/modules/retrieval/temporal_retriever_test.py +544 -79
  241. cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py +476 -28
  242. cognee/tests/unit/modules/retrieval/test_completion.py +343 -0
  243. cognee/tests/unit/modules/retrieval/test_graph_summary_completion_retriever.py +157 -0
  244. cognee/tests/unit/modules/retrieval/test_node_edge_vector_search.py +273 -0
  245. cognee/tests/unit/modules/retrieval/test_user_qa_feedback.py +312 -0
  246. cognee/tests/unit/modules/retrieval/triplet_retriever_test.py +267 -7
  247. cognee/tests/unit/modules/search/test_get_search_type_retriever_instance.py +125 -0
  248. cognee/tests/unit/modules/search/test_search.py +96 -20
  249. cognee/tests/unit/modules/search/test_search_prepare_search_result_contract.py +190 -0
  250. cognee/tests/unit/modules/users/test_tutorial_notebook_creation.py +511 -297
  251. cognee/tests/unit/shared/test_usage_logger.py +241 -0
  252. cognee/tests/unit/users/permissions/test_has_user_management_permission.py +46 -0
  253. {cognee-0.5.1.dist-info → cognee-0.5.2.dist-info}/METADATA +22 -17
  254. {cognee-0.5.1.dist-info → cognee-0.5.2.dist-info}/RECORD +258 -157
  255. cognee/api/.env.example +0 -5
  256. cognee/modules/retrieval/base_graph_retriever.py +0 -24
  257. cognee/modules/search/methods/get_search_type_tools.py +0 -223
  258. cognee/modules/search/methods/no_access_control_search.py +0 -62
  259. cognee/modules/search/utils/prepare_search_result.py +0 -63
  260. cognee/tests/test_feedback_enrichment.py +0 -174
  261. cognee/tests/unit/modules/retrieval/structured_output_test.py +0 -204
  262. {cognee-0.5.1.dist-info → cognee-0.5.2.dist-info}/WHEEL +0 -0
  263. {cognee-0.5.1.dist-info → cognee-0.5.2.dist-info}/entry_points.txt +0 -0
  264. {cognee-0.5.1.dist-info → cognee-0.5.2.dist-info}/licenses/LICENSE +0 -0
  265. {cognee-0.5.1.dist-info → cognee-0.5.2.dist-info}/licenses/NOTICE.md +0 -0
@@ -1,39 +1,18 @@
1
- import asyncio
2
- import time
3
- from typing import List, Optional, Type
1
+ from typing import List, Optional, Type, Union
4
2
 
5
3
  from cognee.shared.logging_utils import get_logger, ERROR
6
4
  from cognee.modules.graph.exceptions.exceptions import EntityNotFoundError
7
- from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError
8
5
  from cognee.infrastructure.databases.graph import get_graph_engine
9
- from cognee.infrastructure.databases.vector import get_vector_engine
6
+ from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError
10
7
  from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph
11
8
  from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
12
- from cognee.modules.users.models import User
13
- from cognee.shared.utils import send_telemetry
9
+ from cognee.modules.retrieval.utils.node_edge_vector_search import NodeEdgeVectorSearch
14
10
 
15
11
  logger = get_logger(level=ERROR)
16
12
 
17
13
 
18
14
  def format_triplets(edges):
19
- print("\n\n\n")
20
-
21
- def filter_attributes(obj, attributes):
22
- """Helper function to filter out non-None properties, including nested dicts."""
23
- result = {}
24
- for attr in attributes:
25
- value = getattr(obj, attr, None)
26
- if value is not None:
27
- # If the value is a dict, extract relevant keys from it
28
- if isinstance(value, dict):
29
- nested_values = {
30
- k: v for k, v in value.items() if k in attributes and v is not None
31
- }
32
- result[attr] = nested_values
33
- else:
34
- result[attr] = value
35
- return result
36
-
15
+ """Formats edges into human-readable triplet strings."""
37
16
  triplets = []
38
17
  for edge in edges:
39
18
  node1 = edge.node1
@@ -42,12 +21,10 @@ def format_triplets(edges):
42
21
  node1_attributes = node1.attributes
43
22
  node2_attributes = node2.attributes
44
23
 
45
- # Filter only non-None properties
46
24
  node1_info = {key: value for key, value in node1_attributes.items() if value is not None}
47
25
  node2_info = {key: value for key, value in node2_attributes.items() if value is not None}
48
26
  edge_info = {key: value for key, value in edge_attributes.items() if value is not None}
49
27
 
50
- # Create the formatted triplet
51
28
  triplet = f"Node1: {node1_info}\nEdge: {edge_info}\nNode2: {node2_info}\n\n\n"
52
29
  triplets.append(triplet)
53
30
 
@@ -69,7 +46,6 @@ async def get_memory_fragment(
69
46
 
70
47
  try:
71
48
  graph_engine = await get_graph_engine()
72
-
73
49
  await memory_fragment.project_graph_from_db(
74
50
  graph_engine,
75
51
  node_properties_to_project=properties_to_project,
@@ -79,20 +55,64 @@ async def get_memory_fragment(
79
55
  relevant_ids_to_filter=relevant_ids_to_filter,
80
56
  triplet_distance_penalty=triplet_distance_penalty,
81
57
  )
82
-
83
58
  except EntityNotFoundError:
84
- # This is expected behavior - continue with empty fragment
85
59
  pass
86
60
  except Exception as e:
87
61
  logger.error(f"Error during memory fragment creation: {str(e)}")
88
- # Still return the fragment even if projection failed
89
- pass
90
62
 
91
63
  return memory_fragment
92
64
 
93
65
 
66
+ async def _get_top_triplet_importances(
67
+ memory_fragment: Optional[CogneeGraph],
68
+ vector_search: NodeEdgeVectorSearch,
69
+ properties_to_project: Optional[List[str]],
70
+ node_type: Optional[Type],
71
+ node_name: Optional[List[str]],
72
+ triplet_distance_penalty: float,
73
+ wide_search_limit: Optional[int],
74
+ top_k: int,
75
+ query_list_length: Optional[int] = None,
76
+ ) -> Union[List[Edge], List[List[Edge]]]:
77
+ """Creates memory fragment (if needed), maps distances, and calculates top triplet importances.
78
+
79
+ Args:
80
+ query_list_length: Number of queries in batch mode (None for single-query mode).
81
+ When None, node_distances/edge_distances are flat lists; when set, they are list-of-lists.
82
+
83
+ Returns:
84
+ List[Edge]: For single-query mode (query_list_length is None).
85
+ List[List[Edge]]: For batch mode (query_list_length is set), one list per query.
86
+ """
87
+ if memory_fragment is None:
88
+ if wide_search_limit is None:
89
+ relevant_node_ids = None
90
+ else:
91
+ relevant_node_ids = vector_search.extract_relevant_node_ids()
92
+
93
+ memory_fragment = await get_memory_fragment(
94
+ properties_to_project=properties_to_project,
95
+ node_type=node_type,
96
+ node_name=node_name,
97
+ relevant_ids_to_filter=relevant_node_ids,
98
+ triplet_distance_penalty=triplet_distance_penalty,
99
+ )
100
+
101
+ await memory_fragment.map_vector_distances_to_graph_nodes(
102
+ node_distances=vector_search.node_distances, query_list_length=query_list_length
103
+ )
104
+ await memory_fragment.map_vector_distances_to_graph_edges(
105
+ edge_distances=vector_search.edge_distances, query_list_length=query_list_length
106
+ )
107
+
108
+ return await memory_fragment.calculate_top_triplet_importances(
109
+ k=top_k, query_list_length=query_list_length
110
+ )
111
+
112
+
94
113
  async def brute_force_triplet_search(
95
- query: str,
114
+ query: Optional[str] = None,
115
+ query_batch: Optional[List[str]] = None,
96
116
  top_k: int = 5,
97
117
  collections: Optional[List[str]] = None,
98
118
  properties_to_project: Optional[List[str]] = None,
@@ -101,33 +121,49 @@ async def brute_force_triplet_search(
101
121
  node_name: Optional[List[str]] = None,
102
122
  wide_search_top_k: Optional[int] = 100,
103
123
  triplet_distance_penalty: Optional[float] = 3.5,
104
- ) -> List[Edge]:
124
+ ) -> Union[List[Edge], List[List[Edge]]]:
105
125
  """
106
126
  Performs a brute force search to retrieve the top triplets from the graph.
107
127
 
108
128
  Args:
109
- query (str): The search query.
129
+ query (Optional[str]): The search query (single query mode). Exactly one of query or query_batch must be provided.
130
+ query_batch (Optional[List[str]]): List of search queries (batch mode). Exactly one of query or query_batch must be provided.
110
131
  top_k (int): The number of top results to retrieve.
111
132
  collections (Optional[List[str]]): List of collections to query.
112
133
  properties_to_project (Optional[List[str]]): List of properties to project.
113
134
  memory_fragment (Optional[CogneeGraph]): Existing memory fragment to reuse.
114
135
  node_type: node type to filter
115
136
  node_name: node name to filter
116
- wide_search_top_k (Optional[int]): Number of initial elements to retrieve from collections
137
+ wide_search_top_k (Optional[int]): Number of initial elements to retrieve from collections.
138
+ Ignored in batch mode (always None to project full graph).
117
139
  triplet_distance_penalty (Optional[float]): Default distance penalty in graph projection
118
140
 
119
141
  Returns:
120
- list: The top triplet results.
142
+ List[Edge]: The top triplet results for single query mode (flat list).
143
+ List[List[Edge]]: List of top triplet results (one per query) for batch mode (list-of-lists).
144
+
145
+ Note:
146
+ In single-query mode, node_distances and edge_distances are stored as flat lists.
147
+ In batch mode, they are stored as list-of-lists (one list per query).
121
148
  """
122
- if not query or not isinstance(query, str):
149
+ if query is not None and query_batch is not None:
150
+ raise ValueError("Cannot provide both 'query' and 'query_batch'; use exactly one.")
151
+ if query is None and query_batch is None:
152
+ raise ValueError("Must provide either 'query' or 'query_batch'.")
153
+ if query is not None and (not query or not isinstance(query, str)):
123
154
  raise ValueError("The query must be a non-empty string.")
155
+ if query_batch is not None:
156
+ if not isinstance(query_batch, list) or not query_batch:
157
+ raise ValueError("query_batch must be a non-empty list of strings.")
158
+ if not all(isinstance(q, str) and q for q in query_batch):
159
+ raise ValueError("All items in query_batch must be non-empty strings.")
124
160
  if top_k <= 0:
125
161
  raise ValueError("top_k must be a positive integer.")
126
162
 
127
- # Setting wide search limit based on the parameters
128
- non_global_search = node_name is None
129
-
130
- wide_search_limit = wide_search_top_k if non_global_search else None
163
+ query_list_length = len(query_batch) if query_batch is not None else None
164
+ wide_search_limit = (
165
+ None if query_list_length else (wide_search_top_k if node_name is None else None)
166
+ )
131
167
 
132
168
  if collections is None:
133
169
  collections = [
@@ -141,77 +177,37 @@ async def brute_force_triplet_search(
141
177
  collections.append("EdgeType_relationship_name")
142
178
 
143
179
  try:
144
- vector_engine = get_vector_engine()
145
- except Exception as e:
146
- logger.error("Failed to initialize vector engine: %s", e)
147
- raise RuntimeError("Initialization error") from e
148
-
149
- query_vector = (await vector_engine.embedding_engine.embed_text([query]))[0]
180
+ vector_search = NodeEdgeVectorSearch()
150
181
 
151
- async def search_in_collection(collection_name: str):
152
- try:
153
- return await vector_engine.search(
154
- collection_name=collection_name, query_vector=query_vector, limit=wide_search_limit
155
- )
156
- except CollectionNotFoundError:
157
- return []
158
-
159
- try:
160
- start_time = time.time()
161
-
162
- results = await asyncio.gather(
163
- *[search_in_collection(collection_name) for collection_name in collections]
182
+ await vector_search.embed_and_retrieve_distances(
183
+ query=None if query_list_length else query,
184
+ query_batch=query_batch if query_list_length else None,
185
+ collections=collections,
186
+ wide_search_limit=wide_search_limit,
164
187
  )
165
188
 
166
- if all(not item for item in results):
167
- return []
168
-
169
- # Final statistics
170
- vector_collection_search_time = time.time() - start_time
171
- logger.info(
172
- f"Vector collection retrieval completed: Retrieved distances from {sum(1 for res in results if res)} collections in {vector_collection_search_time:.2f}s"
189
+ if not vector_search.has_results():
190
+ return [[] for _ in range(query_list_length)] if query_list_length else []
191
+
192
+ results = await _get_top_triplet_importances(
193
+ memory_fragment,
194
+ vector_search,
195
+ properties_to_project,
196
+ node_type,
197
+ node_name,
198
+ triplet_distance_penalty,
199
+ wide_search_limit,
200
+ top_k,
201
+ query_list_length=query_list_length,
173
202
  )
174
203
 
175
- node_distances = {collection: result for collection, result in zip(collections, results)}
176
-
177
- edge_distances = node_distances.get("EdgeType_relationship_name", None)
178
-
179
- if wide_search_limit is not None:
180
- relevant_ids_to_filter = list(
181
- {
182
- str(getattr(scored_node, "id"))
183
- for collection_name, score_collection in node_distances.items()
184
- if collection_name != "EdgeType_relationship_name"
185
- and isinstance(score_collection, (list, tuple))
186
- for scored_node in score_collection
187
- if getattr(scored_node, "id", None)
188
- }
189
- )
190
- else:
191
- relevant_ids_to_filter = None
192
-
193
- if memory_fragment is None:
194
- memory_fragment = await get_memory_fragment(
195
- properties_to_project=properties_to_project,
196
- node_type=node_type,
197
- node_name=node_name,
198
- relevant_ids_to_filter=relevant_ids_to_filter,
199
- triplet_distance_penalty=triplet_distance_penalty,
200
- )
201
-
202
- await memory_fragment.map_vector_distances_to_graph_nodes(node_distances=node_distances)
203
- await memory_fragment.map_vector_distances_to_graph_edges(edge_distances=edge_distances)
204
-
205
- results = await memory_fragment.calculate_top_triplet_importances(k=top_k)
206
-
207
204
  return results
208
-
209
205
  except CollectionNotFoundError:
210
- return []
206
+ return [[] for _ in range(query_list_length)] if query_list_length else []
211
207
  except Exception as error:
212
208
  logger.error(
213
209
  "Error during brute force search for query: %s. Error: %s",
214
- query,
210
+ query_batch if query_list_length else [query],
215
211
  error,
216
212
  )
217
213
  raise error
@@ -0,0 +1,174 @@
1
+ import asyncio
2
+ import time
3
+ from typing import Any, List, Optional
4
+
5
+ from cognee.shared.logging_utils import get_logger, ERROR
6
+ from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError
7
+ from cognee.infrastructure.databases.vector import get_vector_engine
8
+
9
+ logger = get_logger(level=ERROR)
10
+
11
+
12
class NodeEdgeVectorSearch:
    """Manages vector search and distance retrieval for graph nodes and edges.

    Results from ``edge_collection`` are stored in ``edge_distances``; results
    from every other collection are stored in ``node_distances`` keyed by
    collection name.

    Two shapes are used, selected by ``embed_and_retrieve_distances``:

    * single-query mode (``query_list_length is None``): each collection's
      results are a flat list.
    * batch mode (``query_list_length`` is an int): each collection's results
      are a list holding one inner list per query.
    """

    def __init__(self, edge_collection: str = "EdgeType_relationship_name", vector_engine=None):
        self.edge_collection = edge_collection
        # Fall back to the globally configured engine when none is injected.
        self.vector_engine = vector_engine or self._init_vector_engine()
        self.query_vector: Optional[Any] = None
        self.node_distances: dict[str, list[Any]] = {}
        self.edge_distances: list[Any] = []
        self.query_list_length: Optional[int] = None

    def _init_vector_engine(self):
        """Return the configured vector engine, wrapping any failure in RuntimeError."""
        try:
            return get_vector_engine()
        except Exception as e:
            logger.error("Failed to initialize vector engine: %s", e)
            raise RuntimeError("Initialization error") from e

    @staticmethod
    def _empty_result(query_list_length: Optional[int]) -> list:
        """Return a fresh empty result shaped for single (None) or batch (int) mode."""
        if query_list_length is None:
            return []
        # Distinct inner lists so no two queries share the same list object.
        return [[] for _ in range(query_list_length)]

    async def embed_and_retrieve_distances(
        self,
        query: Optional[str] = None,
        query_batch: Optional[List[str]] = None,
        collections: Optional[List[str]] = None,
        wide_search_limit: Optional[int] = None,
    ):
        """Embed query/queries and retrieve vector distances from all collections.

        Exactly one of ``query`` or ``query_batch`` must be given. Results are
        stored on the instance via ``set_distances_from_results``.

        Args:
            query: Single query text (single-query mode).
            query_batch: List of query texts (batch mode).
            collections: Names of the vector collections to search; must be non-empty.
            wide_search_limit: Per-collection result limit, used in single-query mode only.

        Raises:
            ValueError: If both or neither of ``query``/``query_batch`` are given,
                or ``collections`` is empty.
        """
        if query is not None and query_batch is not None:
            raise ValueError("Cannot provide both 'query' and 'query_batch'; use exactly one.")
        if query is None and query_batch is None:
            raise ValueError("Must provide either 'query' or 'query_batch'.")
        if not collections:
            raise ValueError("'collections' must be a non-empty list.")

        # perf_counter is monotonic, so it is the appropriate clock for durations.
        start_time = time.perf_counter()

        if query_batch is not None:
            self.query_list_length = len(query_batch)
            search_results = await self._run_batch_search(collections, query_batch)
        else:
            self.query_list_length = None
            search_results = await self._run_single_search(collections, query, wide_search_limit)

        elapsed_time = time.perf_counter() - start_time
        # Guard with `result and` so a falsy (e.g. None) result is counted as
        # empty instead of crashing any(); set_distances_from_results treats it
        # the same way.
        collections_with_results = sum(
            1 for result in search_results if result and any(result)
        )
        logger.info(
            "Vector collection retrieval completed: Retrieved distances from "
            "%d collections in %.2fs",
            collections_with_results,
            elapsed_time,
        )

        self.set_distances_from_results(collections, search_results, self.query_list_length)

    def has_results(self) -> bool:
        """Return True if the edge collection or any node collection has results."""
        if self.query_list_length is None:
            if self.edge_distances and any(self.edge_distances):
                return True
            return any(
                bool(collection_results) for collection_results in self.node_distances.values()
            )

        # Batch mode: a collection counts only if some per-query inner list is non-empty.
        if self.edge_distances and any(inner_list for inner_list in self.edge_distances):
            return True
        return any(
            any(results_per_query for results_per_query in collection_results)
            for collection_results in self.node_distances.values()
        )

    def extract_relevant_node_ids(self) -> List[str]:
        """Return unique, stringified node IDs from single-query node results.

        Batch mode is not supported here and always yields an empty list.
        """
        if self.query_list_length is not None:
            return []
        relevant_node_ids = set()
        for scored_results in self.node_distances.values():
            for scored_node in scored_results:
                node_id = getattr(scored_node, "id", None)
                if node_id:
                    relevant_node_ids.add(str(node_id))
        return list(relevant_node_ids)

    def set_distances_from_results(
        self,
        collections: List[str],
        search_results: List[List[Any]],
        query_list_length: Optional[int] = None,
    ):
        """Separate search results into node and edge distances with stable shapes.

        Ensures all collections are present in the output, even if empty:
        - Batch mode: missing/empty collections become one empty list per query
        - Single mode: missing/empty collections become []

        ``search_results`` is matched to ``collections`` by position; extra
        results beyond the number of collections are ignored.
        """
        self.node_distances = {}
        self.edge_distances = self._empty_result(query_list_length)
        result_count = len(search_results)
        # Iterate over collections rather than zip() so a collection without a
        # matching entry in search_results still receives an empty placeholder.
        for index, collection in enumerate(collections):
            result = search_results[index] if index < result_count else None
            if not result:
                result = self._empty_result(query_list_length)
            if collection == self.edge_collection:
                self.edge_distances = result
            else:
                self.node_distances[collection] = result

    async def _run_batch_search(
        self, collections: List[str], query_batch: List[str]
    ) -> List[List[Any]]:
        """Run batch search across all collections concurrently; one list-of-lists per collection."""
        search_tasks = [
            self._search_batch_collection(collection, query_batch) for collection in collections
        ]
        return await asyncio.gather(*search_tasks)

    async def _search_batch_collection(
        self, collection_name: str, query_batch: List[str]
    ) -> List[List[Any]]:
        """Search one collection with batch queries and return one result list per query."""
        # NOTE(review): wide_search_limit is not applied in batch mode (limit=None);
        # confirm this is intended.
        try:
            return await self.vector_engine.batch_search(
                collection_name=collection_name, query_texts=query_batch, limit=None
            )
        except CollectionNotFoundError:
            # Build distinct inner lists; `[[]] * n` would alias one list n times.
            return [[] for _ in query_batch]

    async def _run_single_search(
        self, collections: List[str], query: str, wide_search_limit: Optional[int]
    ) -> List[List[Any]]:
        """Embed the query once, then search every collection concurrently.

        Returns a list where each element is a collection's results (flat list).
        These are stored as flat lists in node_distances/edge_distances for single-query mode.
        """
        await self._embed_query(query)
        search_tasks = [
            self._search_single_collection(self.vector_engine, wide_search_limit, collection)
            for collection in collections
        ]
        return await asyncio.gather(*search_tasks)

    async def _embed_query(self, query: str):
        """Embed the query text and store the resulting vector on the instance."""
        query_embeddings = await self.vector_engine.embedding_engine.embed_text([query])
        self.query_vector = query_embeddings[0]

    async def _search_single_collection(
        self, vector_engine: Any, wide_search_limit: Optional[int], collection_name: str
    ):
        """Search one collection with the stored query vector; [] if the collection is missing."""
        try:
            return await vector_engine.search(
                collection_name=collection_name,
                query_vector=self.query_vector,
                limit=wide_search_limit,
            )
        except CollectionNotFoundError:
            return []
@@ -1 +1,2 @@
1
1
  from .search import search
2
+ from .get_retriever_output import get_retriever_output
@@ -0,0 +1,53 @@
1
+ from cognee.infrastructure.databases.graph import get_graph_engine
2
+ from cognee.modules.search.models.SearchResultPayload import SearchResultPayload
3
+ from cognee.modules.search.methods.get_search_type_retriever_instance import (
4
+ get_search_type_retriever_instance,
5
+ )
6
+ from cognee.modules.search.types import SearchType
7
+ from cognee.shared.logging_utils import get_logger
8
+
9
+ logger = get_logger()
10
+
11
+
12
async def get_retriever_output(query_type: SearchType, query_text: str, **kwargs):
    """Run the retriever for ``query_type`` and wrap its output in a SearchResultPayload.

    The retrieved objects are fetched once and reused for both context building
    and (unless ``only_context`` is set in ``kwargs``) completion generation.

    Args:
        query_type: The search type used to select the retriever.
        query_text: The user query.
        **kwargs: Forwarded to ``get_search_type_retriever_instance``. The keys
            ``only_context`` (bool) and ``dataset`` (object with ``name``,
            ``id`` and ``tenant_id``) are also read here.

    Returns:
        A SearchResultPayload holding the retrieved objects, context, the
        completion (None when only_context is truthy) and dataset metadata.
    """
    graph_engine = await get_graph_engine()
    # An empty graph is only logged; the search still proceeds.
    if await graph_engine.is_empty():
        logger.warning("Search attempt on an empty knowledge graph")

    retriever = await get_search_type_retriever_instance(
        query_type=query_type, query_text=query_text, **kwargs
    )

    # Fetch raw objects once so context and completion do not retrieve again.
    objects = await retriever.get_retrieved_objects(query=query_text)
    context = await retriever.get_context_from_objects(
        query=query_text, retrieved_objects=objects
    )

    only_context = kwargs.get("only_context", False)
    # Skipping the completion when only context is requested saves an LLM call.
    completion = (
        None
        if only_context
        else await retriever.get_completion_from_context(
            query=query_text,
            retrieved_objects=objects,
            context=context,
        )
    )

    dataset = kwargs.get("dataset")
    return SearchResultPayload(
        result_object=objects,
        context=context,
        completion=completion,
        search_type=query_type,
        only_context=only_context,
        dataset_name=dataset.name if dataset else None,
        dataset_id=dataset.id if dataset else None,
        dataset_tenant_id=dataset.tenant_id if dataset else None,
    )