cognee 0.2.4__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (166) hide show
  1. cognee/__init__.py +2 -0
  2. cognee/api/client.py +28 -3
  3. cognee/api/health.py +10 -13
  4. cognee/api/v1/add/add.py +3 -1
  5. cognee/api/v1/add/routers/get_add_router.py +12 -37
  6. cognee/api/v1/cloud/routers/__init__.py +1 -0
  7. cognee/api/v1/cloud/routers/get_checks_router.py +23 -0
  8. cognee/api/v1/cognify/code_graph_pipeline.py +9 -4
  9. cognee/api/v1/cognify/cognify.py +50 -3
  10. cognee/api/v1/cognify/routers/get_cognify_router.py +1 -1
  11. cognee/api/v1/datasets/routers/get_datasets_router.py +15 -4
  12. cognee/api/v1/memify/__init__.py +0 -0
  13. cognee/api/v1/memify/routers/__init__.py +1 -0
  14. cognee/api/v1/memify/routers/get_memify_router.py +100 -0
  15. cognee/api/v1/notebooks/routers/__init__.py +1 -0
  16. cognee/api/v1/notebooks/routers/get_notebooks_router.py +96 -0
  17. cognee/api/v1/search/routers/get_search_router.py +20 -1
  18. cognee/api/v1/search/search.py +11 -4
  19. cognee/api/v1/sync/__init__.py +17 -0
  20. cognee/api/v1/sync/routers/__init__.py +3 -0
  21. cognee/api/v1/sync/routers/get_sync_router.py +241 -0
  22. cognee/api/v1/sync/sync.py +877 -0
  23. cognee/api/v1/ui/__init__.py +1 -0
  24. cognee/api/v1/ui/ui.py +529 -0
  25. cognee/api/v1/users/routers/get_auth_router.py +13 -1
  26. cognee/base_config.py +10 -1
  27. cognee/cli/_cognee.py +93 -0
  28. cognee/infrastructure/databases/graph/config.py +10 -4
  29. cognee/infrastructure/databases/graph/kuzu/adapter.py +135 -0
  30. cognee/infrastructure/databases/graph/neo4j_driver/adapter.py +89 -0
  31. cognee/infrastructure/databases/relational/__init__.py +2 -0
  32. cognee/infrastructure/databases/relational/get_async_session.py +15 -0
  33. cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py +6 -1
  34. cognee/infrastructure/databases/relational/with_async_session.py +25 -0
  35. cognee/infrastructure/databases/vector/chromadb/ChromaDBAdapter.py +1 -1
  36. cognee/infrastructure/databases/vector/config.py +13 -6
  37. cognee/infrastructure/databases/vector/embeddings/FastembedEmbeddingEngine.py +1 -1
  38. cognee/infrastructure/databases/vector/embeddings/embedding_rate_limiter.py +2 -6
  39. cognee/infrastructure/databases/vector/embeddings/get_embedding_engine.py +4 -1
  40. cognee/infrastructure/files/storage/LocalFileStorage.py +9 -0
  41. cognee/infrastructure/files/storage/S3FileStorage.py +5 -0
  42. cognee/infrastructure/files/storage/StorageManager.py +7 -1
  43. cognee/infrastructure/files/storage/storage.py +16 -0
  44. cognee/infrastructure/llm/LLMGateway.py +18 -0
  45. cognee/infrastructure/llm/config.py +4 -2
  46. cognee/infrastructure/llm/prompts/extract_query_time.txt +15 -0
  47. cognee/infrastructure/llm/prompts/generate_event_entity_prompt.txt +25 -0
  48. cognee/infrastructure/llm/prompts/generate_event_graph_prompt.txt +30 -0
  49. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/extraction/__init__.py +2 -0
  50. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/extraction/extract_event_entities.py +44 -0
  51. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/extraction/knowledge_graph/__init__.py +1 -0
  52. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/extraction/knowledge_graph/extract_event_graph.py +46 -0
  53. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py +25 -1
  54. cognee/infrastructure/utils/run_sync.py +8 -1
  55. cognee/modules/chunking/models/DocumentChunk.py +4 -3
  56. cognee/modules/cloud/exceptions/CloudApiKeyMissingError.py +15 -0
  57. cognee/modules/cloud/exceptions/CloudConnectionError.py +15 -0
  58. cognee/modules/cloud/exceptions/__init__.py +2 -0
  59. cognee/modules/cloud/operations/__init__.py +1 -0
  60. cognee/modules/cloud/operations/check_api_key.py +25 -0
  61. cognee/modules/data/deletion/prune_system.py +1 -1
  62. cognee/modules/data/methods/check_dataset_name.py +1 -1
  63. cognee/modules/data/methods/get_dataset_data.py +1 -1
  64. cognee/modules/data/methods/load_or_create_datasets.py +1 -1
  65. cognee/modules/engine/models/Event.py +16 -0
  66. cognee/modules/engine/models/Interval.py +8 -0
  67. cognee/modules/engine/models/Timestamp.py +13 -0
  68. cognee/modules/engine/models/__init__.py +3 -0
  69. cognee/modules/engine/utils/__init__.py +2 -0
  70. cognee/modules/engine/utils/generate_event_datapoint.py +46 -0
  71. cognee/modules/engine/utils/generate_timestamp_datapoint.py +51 -0
  72. cognee/modules/graph/cognee_graph/CogneeGraph.py +2 -2
  73. cognee/modules/graph/utils/__init__.py +1 -0
  74. cognee/modules/graph/utils/resolve_edges_to_text.py +71 -0
  75. cognee/modules/memify/__init__.py +1 -0
  76. cognee/modules/memify/memify.py +118 -0
  77. cognee/modules/notebooks/methods/__init__.py +5 -0
  78. cognee/modules/notebooks/methods/create_notebook.py +26 -0
  79. cognee/modules/notebooks/methods/delete_notebook.py +13 -0
  80. cognee/modules/notebooks/methods/get_notebook.py +21 -0
  81. cognee/modules/notebooks/methods/get_notebooks.py +18 -0
  82. cognee/modules/notebooks/methods/update_notebook.py +17 -0
  83. cognee/modules/notebooks/models/Notebook.py +53 -0
  84. cognee/modules/notebooks/models/__init__.py +1 -0
  85. cognee/modules/notebooks/operations/__init__.py +1 -0
  86. cognee/modules/notebooks/operations/run_in_local_sandbox.py +55 -0
  87. cognee/modules/pipelines/layers/reset_dataset_pipeline_run_status.py +19 -3
  88. cognee/modules/pipelines/operations/pipeline.py +1 -0
  89. cognee/modules/pipelines/operations/run_tasks.py +17 -41
  90. cognee/modules/retrieval/base_graph_retriever.py +18 -0
  91. cognee/modules/retrieval/base_retriever.py +1 -1
  92. cognee/modules/retrieval/code_retriever.py +8 -0
  93. cognee/modules/retrieval/coding_rules_retriever.py +31 -0
  94. cognee/modules/retrieval/completion_retriever.py +9 -3
  95. cognee/modules/retrieval/context_providers/TripletSearchContextProvider.py +1 -0
  96. cognee/modules/retrieval/graph_completion_context_extension_retriever.py +23 -14
  97. cognee/modules/retrieval/graph_completion_cot_retriever.py +21 -11
  98. cognee/modules/retrieval/graph_completion_retriever.py +32 -65
  99. cognee/modules/retrieval/graph_summary_completion_retriever.py +3 -1
  100. cognee/modules/retrieval/insights_retriever.py +14 -3
  101. cognee/modules/retrieval/summaries_retriever.py +1 -1
  102. cognee/modules/retrieval/temporal_retriever.py +152 -0
  103. cognee/modules/retrieval/utils/brute_force_triplet_search.py +7 -32
  104. cognee/modules/retrieval/utils/completion.py +10 -3
  105. cognee/modules/search/methods/get_search_type_tools.py +168 -0
  106. cognee/modules/search/methods/no_access_control_search.py +47 -0
  107. cognee/modules/search/methods/search.py +219 -139
  108. cognee/modules/search/types/SearchResult.py +21 -0
  109. cognee/modules/search/types/SearchType.py +2 -0
  110. cognee/modules/search/types/__init__.py +1 -0
  111. cognee/modules/search/utils/__init__.py +2 -0
  112. cognee/modules/search/utils/prepare_search_result.py +41 -0
  113. cognee/modules/search/utils/transform_context_to_graph.py +38 -0
  114. cognee/modules/sync/__init__.py +1 -0
  115. cognee/modules/sync/methods/__init__.py +23 -0
  116. cognee/modules/sync/methods/create_sync_operation.py +53 -0
  117. cognee/modules/sync/methods/get_sync_operation.py +107 -0
  118. cognee/modules/sync/methods/update_sync_operation.py +248 -0
  119. cognee/modules/sync/models/SyncOperation.py +142 -0
  120. cognee/modules/sync/models/__init__.py +3 -0
  121. cognee/modules/users/__init__.py +0 -1
  122. cognee/modules/users/methods/__init__.py +4 -1
  123. cognee/modules/users/methods/create_user.py +26 -1
  124. cognee/modules/users/methods/get_authenticated_user.py +36 -42
  125. cognee/modules/users/methods/get_default_user.py +3 -1
  126. cognee/modules/users/permissions/methods/get_specific_user_permission_datasets.py +2 -1
  127. cognee/root_dir.py +19 -0
  128. cognee/shared/logging_utils.py +1 -1
  129. cognee/tasks/codingagents/__init__.py +0 -0
  130. cognee/tasks/codingagents/coding_rule_associations.py +127 -0
  131. cognee/tasks/ingestion/save_data_item_to_storage.py +23 -0
  132. cognee/tasks/memify/__init__.py +2 -0
  133. cognee/tasks/memify/extract_subgraph.py +7 -0
  134. cognee/tasks/memify/extract_subgraph_chunks.py +11 -0
  135. cognee/tasks/repo_processor/get_repo_file_dependencies.py +52 -27
  136. cognee/tasks/temporal_graph/__init__.py +1 -0
  137. cognee/tasks/temporal_graph/add_entities_to_event.py +85 -0
  138. cognee/tasks/temporal_graph/enrich_events.py +34 -0
  139. cognee/tasks/temporal_graph/extract_events_and_entities.py +32 -0
  140. cognee/tasks/temporal_graph/extract_knowledge_graph_from_events.py +41 -0
  141. cognee/tasks/temporal_graph/models.py +49 -0
  142. cognee/tests/test_kuzu.py +4 -4
  143. cognee/tests/test_neo4j.py +4 -4
  144. cognee/tests/test_permissions.py +3 -3
  145. cognee/tests/test_relational_db_migration.py +7 -5
  146. cognee/tests/test_search_db.py +18 -24
  147. cognee/tests/test_temporal_graph.py +167 -0
  148. cognee/tests/unit/api/__init__.py +1 -0
  149. cognee/tests/unit/api/test_conditional_authentication_endpoints.py +246 -0
  150. cognee/tests/unit/modules/retrieval/chunks_retriever_test.py +18 -2
  151. cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py +13 -16
  152. cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py +11 -16
  153. cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py +5 -4
  154. cognee/tests/unit/modules/retrieval/insights_retriever_test.py +4 -2
  155. cognee/tests/unit/modules/retrieval/rag_completion_retriever_test.py +18 -2
  156. cognee/tests/unit/modules/retrieval/temporal_retriever_test.py +225 -0
  157. cognee/tests/unit/modules/users/__init__.py +1 -0
  158. cognee/tests/unit/modules/users/test_conditional_authentication.py +277 -0
  159. cognee/tests/unit/processing/utils/utils_test.py +20 -1
  160. {cognee-0.2.4.dist-info → cognee-0.3.0.dist-info}/METADATA +8 -6
  161. {cognee-0.2.4.dist-info → cognee-0.3.0.dist-info}/RECORD +165 -90
  162. cognee/tests/unit/modules/search/search_methods_test.py +0 -225
  163. {cognee-0.2.4.dist-info → cognee-0.3.0.dist-info}/WHEEL +0 -0
  164. {cognee-0.2.4.dist-info → cognee-0.3.0.dist-info}/entry_points.txt +0 -0
  165. {cognee-0.2.4.dist-info → cognee-0.3.0.dist-info}/licenses/LICENSE +0 -0
  166. {cognee-0.2.4.dist-info → cognee-0.3.0.dist-info}/licenses/NOTICE.md +0 -0
@@ -0,0 +1,168 @@
1
+ from typing import Callable, List, Optional, Type
2
+
3
+ from cognee.modules.engine.models.node_set import NodeSet
4
+ from cognee.modules.search.types import SearchType
5
+ from cognee.modules.search.operations import select_search_type
6
+ from cognee.modules.search.exceptions import UnsupportedSearchTypeError
7
+
8
+ # Retrievers
9
+ from cognee.modules.retrieval.user_qa_feedback import UserQAFeedback
10
+ from cognee.modules.retrieval.chunks_retriever import ChunksRetriever
11
+ from cognee.modules.retrieval.insights_retriever import InsightsRetriever
12
+ from cognee.modules.retrieval.summaries_retriever import SummariesRetriever
13
+ from cognee.modules.retrieval.completion_retriever import CompletionRetriever
14
+ from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
15
+ from cognee.modules.retrieval.temporal_retriever import TemporalRetriever
16
+ from cognee.modules.retrieval.coding_rules_retriever import CodingRulesRetriever
17
+ from cognee.modules.retrieval.graph_summary_completion_retriever import (
18
+ GraphSummaryCompletionRetriever,
19
+ )
20
+ from cognee.modules.retrieval.graph_completion_cot_retriever import GraphCompletionCotRetriever
21
+ from cognee.modules.retrieval.graph_completion_context_extension_retriever import (
22
+ GraphCompletionContextExtensionRetriever,
23
+ )
24
+ from cognee.modules.retrieval.code_retriever import CodeRetriever
25
+ from cognee.modules.retrieval.cypher_search_retriever import CypherSearchRetriever
26
+ from cognee.modules.retrieval.natural_language_retriever import NaturalLanguageRetriever
27
+
28
+
29
async def get_search_type_tools(
    query_type: SearchType,
    query_text: str,
    system_prompt_path: str = "answer_simple_question.txt",
    system_prompt: Optional[str] = None,
    top_k: int = 10,
    node_type: Optional[Type] = NodeSet,
    node_name: Optional[List[str]] = None,
    save_interaction: bool = False,
    last_k: Optional[int] = None,
) -> list:
    """
    Return the retriever callables registered for the requested search type.

    Most search types map to a two-element list ``[get_completion, get_context]``.
    FEEDBACK and CODING_RULES map to a single-element list (``add_feedback`` and
    ``get_existing_rules`` respectively), so callers must check the list length.

    Args:
        query_type: Search type to look up. FEELING_LUCKY is first replaced by the
            type returned from ``select_search_type(query_text)``.
        query_text: Only used here to resolve FEELING_LUCKY.
        system_prompt_path: Forwarded to the RAG and graph completion retrievers.
        system_prompt: Forwarded to the RAG and graph completion retrievers.
        top_k: Forwarded to every retriever that accepts ``top_k``.
        node_type: Forwarded to the graph completion retrievers.
        node_name: Forwarded to the graph completion retrievers and, as
            ``rules_nodeset_name``, to ``CodingRulesRetriever``.
        save_interaction: Forwarded to the graph completion retrievers.
        last_k: Forwarded to ``UserQAFeedback``.

    Raises:
        UnsupportedSearchTypeError: When no tools are registered for the
            (possibly resolved) search type.

    Note:
        Retrievers for every search type are instantiated on each call, not only
        the ones for the requested type.
    """

    def completion_and_context(retriever_class, **init_kwargs) -> List[Callable]:
        # Each bound method gets its own retriever instance, so the two
        # callables never share an object.
        return [
            retriever_class(**init_kwargs).get_completion,
            retriever_class(**init_kwargs).get_context,
        ]

    top_k_options = {"top_k": top_k}
    rag_options = {
        "system_prompt_path": system_prompt_path,
        "top_k": top_k,
        "system_prompt": system_prompt,
    }
    graph_options = {
        "system_prompt_path": system_prompt_path,
        "top_k": top_k,
        "node_type": node_type,
        "node_name": node_name,
        "save_interaction": save_interaction,
        "system_prompt": system_prompt,
    }

    tools_by_type: dict[SearchType, List[Callable]] = {
        SearchType.SUMMARIES: completion_and_context(SummariesRetriever, **top_k_options),
        SearchType.INSIGHTS: completion_and_context(InsightsRetriever, **top_k_options),
        SearchType.CHUNKS: completion_and_context(ChunksRetriever, **top_k_options),
        SearchType.RAG_COMPLETION: completion_and_context(CompletionRetriever, **rag_options),
        SearchType.GRAPH_COMPLETION: completion_and_context(
            GraphCompletionRetriever, **graph_options
        ),
        SearchType.GRAPH_COMPLETION_COT: completion_and_context(
            GraphCompletionCotRetriever, **graph_options
        ),
        SearchType.GRAPH_COMPLETION_CONTEXT_EXTENSION: completion_and_context(
            GraphCompletionContextExtensionRetriever, **graph_options
        ),
        SearchType.GRAPH_SUMMARY_COMPLETION: completion_and_context(
            GraphSummaryCompletionRetriever, **graph_options
        ),
        SearchType.CODE: completion_and_context(CodeRetriever, **top_k_options),
        SearchType.CYPHER: completion_and_context(CypherSearchRetriever),
        SearchType.NATURAL_LANGUAGE: completion_and_context(NaturalLanguageRetriever),
        # Single-callable entries: these types expose no completion/context pair.
        SearchType.FEEDBACK: [UserQAFeedback(last_k=last_k).add_feedback],
        SearchType.TEMPORAL: completion_and_context(TemporalRetriever, **top_k_options),
        SearchType.CODING_RULES: [
            CodingRulesRetriever(rules_nodeset_name=node_name).get_existing_rules,
        ],
    }

    # FEELING_LUCKY is a placeholder: the concrete type is chosen from the query text.
    if query_type is SearchType.FEELING_LUCKY:
        query_type = await select_search_type(query_text)

    selected_tools = tools_by_type.get(query_type)
    if selected_tools:
        return selected_tools

    raise UnsupportedSearchTypeError(str(query_type))
@@ -0,0 +1,47 @@
1
+ from typing import Any, List, Optional, Tuple, Type, Union
2
+
3
+ from cognee.modules.data.models.Dataset import Dataset
4
+ from cognee.modules.engine.models.node_set import NodeSet
5
+ from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
6
+ from cognee.modules.search.types import SearchType
7
+
8
+ from .get_search_type_tools import get_search_type_tools
9
+
10
+
11
async def no_access_control_search(
    query_type: SearchType,
    query_text: str,
    system_prompt_path: str = "answer_simple_question.txt",
    system_prompt: Optional[str] = None,
    top_k: int = 10,
    node_type: Optional[Type] = NodeSet,
    node_name: Optional[List[str]] = None,
    save_interaction: bool = False,
    last_k: Optional[int] = None,
    only_context: bool = False,
) -> Tuple[Any, Union[str, List[Edge]], List[Dataset]]:
    """
    Run a search without per-dataset permission handling.

    Looks up the tools for ``query_type`` via ``get_search_type_tools`` and runs
    them against ``query_text``.

    Returns:
        A ``(result, context, datasets)`` tuple. ``datasets`` is always an empty
        list here.

        * Two tools (completion + context), ``only_context=False``:
          ``(completion, context, [])``.
        * Two tools, ``only_context=True``: ``(None, context, [])``; the
          completion step is skipped.
        * A single tool: ``(tool_result, "", [])``; ``only_context`` is not
          consulted.
    """
    search_tools = await get_search_type_tools(
        query_type=query_type,
        query_text=query_text,
        system_prompt_path=system_prompt_path,
        system_prompt=system_prompt,
        top_k=top_k,
        node_type=node_type,
        node_name=node_name,
        save_interaction=save_interaction,
        last_k=last_k,
    )

    if len(search_tools) != 2:
        # Single-callable search types have no separate context step.
        single_tool = search_tools[0]
        return await single_tool(query_text), "", []

    get_completion, get_context = search_tools
    context = await get_context(query_text)

    if only_context:
        # Keep the declared 3-tuple shape so callers can unpack every result the
        # same way; returning the bare context broke tuple unpacking downstream.
        return None, context, []

    result = await get_completion(query_text, context)
    return result, context, []
@@ -2,33 +2,28 @@ import os
2
2
  import json
3
3
  import asyncio
4
4
  from uuid import UUID
5
- from typing import Callable, List, Optional, Type, Union
5
+ from fastapi.encoders import jsonable_encoder
6
+ from typing import Any, List, Optional, Tuple, Type, Union
6
7
 
7
- from cognee.modules.retrieval.user_qa_feedback import UserQAFeedback
8
- from cognee.modules.search.exceptions import UnsupportedSearchTypeError
8
+ from cognee.shared.utils import send_telemetry
9
9
  from cognee.context_global_variables import set_database_global_context_variables
10
- from cognee.modules.retrieval.chunks_retriever import ChunksRetriever
11
- from cognee.modules.retrieval.insights_retriever import InsightsRetriever
12
- from cognee.modules.retrieval.summaries_retriever import SummariesRetriever
13
- from cognee.modules.retrieval.completion_retriever import CompletionRetriever
14
- from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
15
- from cognee.modules.retrieval.graph_summary_completion_retriever import (
16
- GraphSummaryCompletionRetriever,
17
- )
18
- from cognee.modules.retrieval.graph_completion_cot_retriever import GraphCompletionCotRetriever
19
- from cognee.modules.retrieval.graph_completion_context_extension_retriever import (
20
- GraphCompletionContextExtensionRetriever,
10
+
11
+ from cognee.modules.engine.models.node_set import NodeSet
12
+ from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
13
+ from cognee.modules.search.types import (
14
+ SearchResult,
15
+ CombinedSearchResult,
16
+ SearchResultDataset,
17
+ SearchType,
21
18
  )
22
- from cognee.modules.retrieval.code_retriever import CodeRetriever
23
- from cognee.modules.retrieval.cypher_search_retriever import CypherSearchRetriever
24
- from cognee.modules.retrieval.natural_language_retriever import NaturalLanguageRetriever
25
- from cognee.modules.search.types import SearchType
26
- from cognee.modules.storage.utils import JSONEncoder
19
+ from cognee.modules.search.operations import log_query, log_result
27
20
  from cognee.modules.users.models import User
28
21
  from cognee.modules.data.models import Dataset
29
- from cognee.shared.utils import send_telemetry
30
22
  from cognee.modules.users.permissions.methods import get_specific_user_permission_datasets
31
- from cognee.modules.search.operations import log_query, log_result, select_search_type
23
+
24
+ from .get_search_type_tools import get_search_type_tools
25
+ from .no_access_control_search import no_access_control_search
26
+ from ..utils.prepare_search_result import prepare_search_result
32
27
 
33
28
 
34
29
  async def search(
@@ -37,12 +32,15 @@ async def search(
37
32
  dataset_ids: Union[list[UUID], None],
38
33
  user: User,
39
34
  system_prompt_path="answer_simple_question.txt",
35
+ system_prompt: Optional[str] = None,
40
36
  top_k: int = 10,
41
- node_type: Optional[Type] = None,
37
+ node_type: Optional[Type] = NodeSet,
42
38
  node_name: Optional[List[str]] = None,
43
- save_interaction: Optional[bool] = False,
39
+ save_interaction: bool = False,
44
40
  last_k: Optional[int] = None,
45
- ):
41
+ only_context: bool = False,
42
+ use_combined_context: bool = False,
43
+ ) -> Union[CombinedSearchResult, List[SearchResult]]:
46
44
  """
47
45
 
48
46
  Args:
@@ -58,192 +56,274 @@ async def search(
58
56
  Notes:
59
57
  Searching by dataset is only available in ENABLE_BACKEND_ACCESS_CONTROL mode
60
58
  """
59
+ query = await log_query(query_text, query_type.value, user.id)
60
+ send_telemetry("cognee.search EXECUTION STARTED", user.id)
61
+
61
62
  # Use search function filtered by permissions if access control is enabled
62
63
  if os.getenv("ENABLE_BACKEND_ACCESS_CONTROL", "false").lower() == "true":
63
- return await authorized_search(
64
- query_text=query_text,
64
+ search_results = await authorized_search(
65
65
  query_type=query_type,
66
+ query_text=query_text,
66
67
  user=user,
67
68
  dataset_ids=dataset_ids,
68
69
  system_prompt_path=system_prompt_path,
70
+ system_prompt=system_prompt,
69
71
  top_k=top_k,
72
+ node_type=node_type,
73
+ node_name=node_name,
70
74
  save_interaction=save_interaction,
71
75
  last_k=last_k,
76
+ only_context=only_context,
77
+ use_combined_context=use_combined_context,
72
78
  )
79
+ else:
80
+ search_results = [
81
+ await no_access_control_search(
82
+ query_type=query_type,
83
+ query_text=query_text,
84
+ system_prompt_path=system_prompt_path,
85
+ system_prompt=system_prompt,
86
+ top_k=top_k,
87
+ node_type=node_type,
88
+ node_name=node_name,
89
+ save_interaction=save_interaction,
90
+ last_k=last_k,
91
+ only_context=only_context,
92
+ )
93
+ ]
73
94
 
74
- query = await log_query(query_text, query_type.value, user.id)
75
-
76
- search_results = await specific_search(
77
- query_type,
78
- query_text,
79
- user,
80
- system_prompt_path=system_prompt_path,
81
- top_k=top_k,
82
- node_type=node_type,
83
- node_name=node_name,
84
- save_interaction=save_interaction,
85
- last_k=last_k,
86
- )
95
+ send_telemetry("cognee.search EXECUTION COMPLETED", user.id)
87
96
 
88
97
  await log_result(
89
98
  query.id,
90
99
  json.dumps(
91
- search_results if len(search_results) > 1 else search_results[0], cls=JSONEncoder
100
+ jsonable_encoder(
101
+ await prepare_search_result(
102
+ search_results[0] if isinstance(search_results, list) else search_results
103
+ )
104
+ if use_combined_context
105
+ else [
106
+ await prepare_search_result(search_result) for search_result in search_results
107
+ ]
108
+ )
92
109
  ),
93
110
  user.id,
94
111
  )
95
112
 
96
- return search_results
113
+ if use_combined_context:
114
+ prepared_search_results = await prepare_search_result(
115
+ search_results[0] if isinstance(search_results, list) else search_results
116
+ )
117
+ result = prepared_search_results["result"]
118
+ graphs = prepared_search_results["graphs"]
119
+ context = prepared_search_results["context"]
120
+ datasets = prepared_search_results["datasets"]
97
121
 
122
+ return CombinedSearchResult(
123
+ result=result,
124
+ graphs=graphs,
125
+ context=context,
126
+ datasets=[
127
+ SearchResultDataset(
128
+ id=dataset.id,
129
+ name=dataset.name,
130
+ )
131
+ for dataset in datasets
132
+ ],
133
+ )
134
+ else:
135
+ return [
136
+ SearchResult(
137
+ search_result=result,
138
+ dataset_id=datasets[min(index, len(datasets) - 1)].id if datasets else None,
139
+ dataset_name=datasets[min(index, len(datasets) - 1)].name if datasets else None,
140
+ )
141
+ for index, (result, _, datasets) in enumerate(search_results)
142
+ ]
98
143
 
99
- async def specific_search(
144
+
145
async def authorized_search(
    query_type: SearchType,
    query_text: str,
    user: User,
    dataset_ids: Optional[list[UUID]] = None,
    system_prompt_path: str = "answer_simple_question.txt",
    system_prompt: Optional[str] = None,
    top_k: int = 10,
    node_type: Optional[Type] = NodeSet,
    node_name: Optional[List[str]] = None,
    save_interaction: bool = False,
    last_k: Optional[int] = None,
    only_context: bool = False,
    use_combined_context: bool = False,
) -> Union[
    Tuple[Any, Union[List[Edge], str], List[Dataset]],
    List[Tuple[Any, Union[List[Edge], str], List[Dataset]]],
]:
    """
    Verifies access for provided datasets or uses all datasets user has read access for and performs search per dataset.
    Not to be used outside of active access control mode.

    Returns:
        * ``use_combined_context=False``: the list of per-dataset
          ``(result, context, datasets)`` tuples produced by
          ``search_in_datasets_context``.
        * ``use_combined_context=True``: a single
          ``(completion, combined_context, datasets)`` tuple. Context is collected
          from every readable dataset, merged, and passed to one completion call.
          ``datasets`` lists every dataset that contributed context.
    """
    # Find datasets user has read access for (if datasets are provided only return them. Provided user has read access)
    search_datasets = await get_specific_user_permission_datasets(user.id, "read", dataset_ids)

    if use_combined_context:
        # First pass: collect context only, one entry per dataset.
        # NOTE(review): the caller's `only_context` flag is not consulted in this
        # branch; a completion is always produced. Confirm this is intended.
        search_responses = await search_in_datasets_context(
            search_datasets=search_datasets,
            query_type=query_type,
            query_text=query_text,
            system_prompt_path=system_prompt_path,
            system_prompt=system_prompt,
            top_k=top_k,
            node_type=node_type,
            node_name=node_name,
            save_interaction=save_interaction,
            last_k=last_k,
            only_context=True,
        )

        context = {}
        datasets: List[Dataset] = []

        # The loop target must not be called `datasets`: rebinding it would discard
        # the accumulator and leave only the last response's datasets (duplicated).
        for _, search_context, response_datasets in search_responses:
            for dataset in response_datasets:
                context[str(dataset.id)] = search_context

            datasets.extend(response_datasets)

        search_tools = await get_search_type_tools(
            query_type=query_type,
            query_text=query_text,
            system_prompt_path=system_prompt_path,
            system_prompt=system_prompt,
            top_k=top_k,
            node_type=node_type,
            node_name=node_name,
            save_interaction=save_interaction,
            last_k=last_k,
        )
        # The completion callable is the first tool whether the search type
        # registers a (completion, context) pair or a single callable.
        get_completion = search_tools[0]

        def prepare_combined_context(
            context,
        ) -> Union[List[Edge], str]:
            """Merge the per-dataset contexts into one context value."""
            combined_context = []

            for dataset_context in context.values():
                if isinstance(dataset_context, str):
                    # `list += str` would extend the list with single characters,
                    # so a textual context is appended as one item instead.
                    combined_context.append(dataset_context)
                else:
                    combined_context += dataset_context

            if combined_context and isinstance(combined_context[0], str):
                return "\n".join(combined_context)

            return combined_context

        combined_context = prepare_combined_context(context)
        completion = await get_completion(query_text, combined_context)

        return completion, combined_context, datasets

    # Searches all provided datasets and handles setting up of appropriate database context based on permissions
    search_results = await search_in_datasets_context(
        search_datasets=search_datasets,
        query_type=query_type,
        query_text=query_text,
        system_prompt_path=system_prompt_path,
        system_prompt=system_prompt,
        top_k=top_k,
        node_type=node_type,
        node_name=node_name,
        save_interaction=save_interaction,
        last_k=last_k,
        only_context=only_context,
    )

    return search_results
204
245
 
205
246
 
206
- async def specific_search_by_context(
247
+ async def search_in_datasets_context(
207
248
  search_datasets: list[Dataset],
208
- query_text: str,
209
249
  query_type: SearchType,
210
- user: User,
211
- system_prompt_path: str,
212
- top_k: int,
250
+ query_text: str,
251
+ system_prompt_path: str = "answer_simple_question.txt",
252
+ system_prompt: Optional[str] = None,
253
+ top_k: int = 10,
254
+ node_type: Optional[Type] = NodeSet,
255
+ node_name: Optional[List[str]] = None,
213
256
  save_interaction: bool = False,
214
257
  last_k: Optional[int] = None,
215
- ):
258
+ only_context: bool = False,
259
+ context: Optional[Any] = None,
260
+ ) -> List[Tuple[Any, Union[str, List[Edge]], List[Dataset]]]:
216
261
  """
217
262
  Searches all provided datasets and handles setting up of appropriate database context based on permissions.
218
263
  Not to be used outside of active access control mode.
219
264
  """
220
265
 
221
- async def _search_by_context(
222
- dataset, user, query_type, query_text, system_prompt_path, top_k, last_k
223
- ):
266
+ async def _search_in_dataset_context(
267
+ dataset: Dataset,
268
+ query_type: SearchType,
269
+ query_text: str,
270
+ system_prompt_path: str = "answer_simple_question.txt",
271
+ system_prompt: Optional[str] = None,
272
+ top_k: int = 10,
273
+ node_type: Optional[Type] = NodeSet,
274
+ node_name: Optional[List[str]] = None,
275
+ save_interaction: bool = False,
276
+ last_k: Optional[int] = None,
277
+ only_context: bool = False,
278
+ context: Optional[Any] = None,
279
+ ) -> Tuple[Any, Union[str, List[Edge]], List[Dataset]]:
224
280
  # Set database configuration in async context for each dataset user has access for
225
281
  await set_database_global_context_variables(dataset.id, dataset.owner_id)
226
- search_results = await specific_search(
227
- query_type,
228
- query_text,
229
- user,
282
+
283
+ specific_search_tools = await get_search_type_tools(
284
+ query_type=query_type,
285
+ query_text=query_text,
230
286
  system_prompt_path=system_prompt_path,
287
+ system_prompt=system_prompt,
231
288
  top_k=top_k,
289
+ node_type=node_type,
290
+ node_name=node_name,
232
291
  save_interaction=save_interaction,
233
292
  last_k=last_k,
234
293
  )
235
- return {
236
- "search_result": search_results,
237
- "dataset_id": dataset.id,
238
- "dataset_name": dataset.name,
239
- }
294
+ search_tools = specific_search_tools
295
+ if len(search_tools) == 2:
296
+ [get_completion, get_context] = search_tools
297
+
298
+ if only_context:
299
+ return None, await get_context(query_text), [dataset]
300
+
301
+ search_context = context or await get_context(query_text)
302
+ search_result = await get_completion(query_text, search_context)
303
+
304
+ return search_result, search_context, [dataset]
305
+ else:
306
+ unknown_tool = search_tools[0]
307
+
308
+ return await unknown_tool(query_text), "", [dataset]
240
309
 
241
310
  # Search every dataset async based on query and appropriate database configuration
242
311
  tasks = []
243
312
  for dataset in search_datasets:
244
313
  tasks.append(
245
- _search_by_context(
246
- dataset, user, query_type, query_text, system_prompt_path, top_k, last_k
314
+ _search_in_dataset_context(
315
+ dataset=dataset,
316
+ query_type=query_type,
317
+ query_text=query_text,
318
+ system_prompt_path=system_prompt_path,
319
+ system_prompt=system_prompt,
320
+ top_k=top_k,
321
+ node_type=node_type,
322
+ node_name=node_name,
323
+ save_interaction=save_interaction,
324
+ last_k=last_k,
325
+ only_context=only_context,
326
+ context=context,
247
327
  )
248
328
  )
249
329