cognee 0.5.0.dev0__py3-none-any.whl → 0.5.0.dev1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (131)
  1. cognee/api/client.py +1 -5
  2. cognee/api/v1/add/add.py +2 -1
  3. cognee/api/v1/cognify/cognify.py +24 -16
  4. cognee/api/v1/cognify/routers/__init__.py +0 -1
  5. cognee/api/v1/cognify/routers/get_cognify_router.py +3 -1
  6. cognee/api/v1/datasets/routers/get_datasets_router.py +3 -3
  7. cognee/api/v1/ontologies/ontologies.py +12 -37
  8. cognee/api/v1/ontologies/routers/get_ontology_router.py +27 -25
  9. cognee/api/v1/search/search.py +4 -0
  10. cognee/api/v1/ui/node_setup.py +360 -0
  11. cognee/api/v1/ui/npm_utils.py +50 -0
  12. cognee/api/v1/ui/ui.py +38 -68
  13. cognee/context_global_variables.py +61 -16
  14. cognee/eval_framework/Dockerfile +29 -0
  15. cognee/eval_framework/answer_generation/answer_generation_executor.py +10 -0
  16. cognee/eval_framework/answer_generation/run_question_answering_module.py +1 -1
  17. cognee/eval_framework/corpus_builder/task_getters/get_cascade_graph_tasks.py +0 -2
  18. cognee/eval_framework/corpus_builder/task_getters/get_default_tasks_by_indices.py +4 -4
  19. cognee/eval_framework/eval_config.py +2 -2
  20. cognee/eval_framework/modal_run_eval.py +16 -28
  21. cognee/infrastructure/databases/dataset_database_handler/__init__.py +3 -0
  22. cognee/infrastructure/databases/dataset_database_handler/dataset_database_handler_interface.py +80 -0
  23. cognee/infrastructure/databases/dataset_database_handler/supported_dataset_database_handlers.py +18 -0
  24. cognee/infrastructure/databases/dataset_database_handler/use_dataset_database_handler.py +10 -0
  25. cognee/infrastructure/databases/graph/config.py +3 -0
  26. cognee/infrastructure/databases/graph/get_graph_engine.py +1 -0
  27. cognee/infrastructure/databases/graph/graph_db_interface.py +15 -0
  28. cognee/infrastructure/databases/graph/kuzu/KuzuDatasetDatabaseHandler.py +81 -0
  29. cognee/infrastructure/databases/graph/kuzu/adapter.py +228 -0
  30. cognee/infrastructure/databases/graph/neo4j_driver/Neo4jAuraDevDatasetDatabaseHandler.py +168 -0
  31. cognee/infrastructure/databases/graph/neo4j_driver/adapter.py +80 -1
  32. cognee/infrastructure/databases/utils/__init__.py +3 -0
  33. cognee/infrastructure/databases/utils/get_graph_dataset_database_handler.py +10 -0
  34. cognee/infrastructure/databases/utils/get_or_create_dataset_database.py +62 -48
  35. cognee/infrastructure/databases/utils/get_vector_dataset_database_handler.py +10 -0
  36. cognee/infrastructure/databases/utils/resolve_dataset_database_connection_info.py +30 -0
  37. cognee/infrastructure/databases/vector/config.py +2 -0
  38. cognee/infrastructure/databases/vector/create_vector_engine.py +1 -0
  39. cognee/infrastructure/databases/vector/embeddings/FastembedEmbeddingEngine.py +8 -6
  40. cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py +9 -7
  41. cognee/infrastructure/databases/vector/embeddings/OllamaEmbeddingEngine.py +11 -10
  42. cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py +2 -0
  43. cognee/infrastructure/databases/vector/lancedb/LanceDBDatasetDatabaseHandler.py +50 -0
  44. cognee/infrastructure/databases/vector/vector_db_interface.py +35 -0
  45. cognee/infrastructure/files/storage/s3_config.py +2 -0
  46. cognee/infrastructure/llm/LLMGateway.py +5 -2
  47. cognee/infrastructure/llm/config.py +35 -0
  48. cognee/infrastructure/llm/extraction/knowledge_graph/extract_content_graph.py +2 -2
  49. cognee/infrastructure/llm/structured_output_framework/baml/baml_src/extraction/acreate_structured_output.py +23 -8
  50. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py +17 -16
  51. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/bedrock/__init__.py +5 -0
  52. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/bedrock/adapter.py +153 -0
  53. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py +40 -37
  54. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py +39 -36
  55. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py +19 -1
  56. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py +11 -9
  57. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/adapter.py +23 -21
  58. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py +42 -34
  59. cognee/memify_pipelines/create_triplet_embeddings.py +53 -0
  60. cognee/modules/cognify/config.py +2 -0
  61. cognee/modules/data/deletion/prune_system.py +52 -2
  62. cognee/modules/data/methods/delete_dataset.py +26 -0
  63. cognee/modules/engine/models/Triplet.py +9 -0
  64. cognee/modules/engine/models/__init__.py +1 -0
  65. cognee/modules/graph/cognee_graph/CogneeGraph.py +85 -37
  66. cognee/modules/graph/cognee_graph/CogneeGraphElements.py +8 -3
  67. cognee/modules/memify/memify.py +1 -7
  68. cognee/modules/pipelines/operations/pipeline.py +18 -2
  69. cognee/modules/retrieval/__init__.py +1 -1
  70. cognee/modules/retrieval/graph_completion_context_extension_retriever.py +4 -0
  71. cognee/modules/retrieval/graph_completion_cot_retriever.py +4 -0
  72. cognee/modules/retrieval/graph_completion_retriever.py +10 -0
  73. cognee/modules/retrieval/graph_summary_completion_retriever.py +4 -0
  74. cognee/modules/retrieval/register_retriever.py +10 -0
  75. cognee/modules/retrieval/registered_community_retrievers.py +1 -0
  76. cognee/modules/retrieval/temporal_retriever.py +4 -0
  77. cognee/modules/retrieval/triplet_retriever.py +182 -0
  78. cognee/modules/retrieval/utils/brute_force_triplet_search.py +42 -10
  79. cognee/modules/run_custom_pipeline/run_custom_pipeline.py +8 -1
  80. cognee/modules/search/methods/get_search_type_tools.py +54 -8
  81. cognee/modules/search/methods/no_access_control_search.py +4 -0
  82. cognee/modules/search/methods/search.py +21 -0
  83. cognee/modules/search/types/SearchType.py +1 -1
  84. cognee/modules/settings/get_settings.py +19 -0
  85. cognee/modules/users/methods/get_authenticated_user.py +2 -2
  86. cognee/modules/users/models/DatasetDatabase.py +15 -3
  87. cognee/shared/logging_utils.py +4 -0
  88. cognee/shared/rate_limiting.py +30 -0
  89. cognee/tasks/documents/__init__.py +0 -1
  90. cognee/tasks/graph/extract_graph_from_data.py +9 -10
  91. cognee/tasks/memify/get_triplet_datapoints.py +289 -0
  92. cognee/tasks/storage/add_data_points.py +142 -2
  93. cognee/tests/integration/retrieval/test_triplet_retriever.py +84 -0
  94. cognee/tests/integration/tasks/test_add_data_points.py +139 -0
  95. cognee/tests/integration/tasks/test_get_triplet_datapoints.py +69 -0
  96. cognee/tests/test_cognee_server_start.py +2 -4
  97. cognee/tests/test_conversation_history.py +23 -1
  98. cognee/tests/test_dataset_database_handler.py +137 -0
  99. cognee/tests/test_dataset_delete.py +76 -0
  100. cognee/tests/test_edge_centered_payload.py +170 -0
  101. cognee/tests/test_pipeline_cache.py +164 -0
  102. cognee/tests/test_search_db.py +37 -1
  103. cognee/tests/unit/api/test_ontology_endpoint.py +77 -89
  104. cognee/tests/unit/infrastructure/llm/test_llm_config.py +46 -0
  105. cognee/tests/unit/infrastructure/mock_embedding_engine.py +3 -7
  106. cognee/tests/unit/infrastructure/test_embedding_rate_limiting_realistic.py +0 -5
  107. cognee/tests/unit/modules/graph/cognee_graph_elements_test.py +2 -2
  108. cognee/tests/unit/modules/graph/cognee_graph_test.py +406 -0
  109. cognee/tests/unit/modules/memify_tasks/test_get_triplet_datapoints.py +214 -0
  110. cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py +608 -0
  111. cognee/tests/unit/modules/retrieval/triplet_retriever_test.py +83 -0
  112. cognee/tests/unit/tasks/storage/test_add_data_points.py +288 -0
  113. {cognee-0.5.0.dev0.dist-info → cognee-0.5.0.dev1.dist-info}/METADATA +76 -89
  114. {cognee-0.5.0.dev0.dist-info → cognee-0.5.0.dev1.dist-info}/RECORD +118 -97
  115. {cognee-0.5.0.dev0.dist-info → cognee-0.5.0.dev1.dist-info}/WHEEL +1 -1
  116. cognee/api/v1/cognify/code_graph_pipeline.py +0 -119
  117. cognee/api/v1/cognify/routers/get_code_pipeline_router.py +0 -90
  118. cognee/infrastructure/databases/vector/embeddings/embedding_rate_limiter.py +0 -544
  119. cognee/modules/retrieval/code_retriever.py +0 -232
  120. cognee/tasks/code/enrich_dependency_graph_checker.py +0 -35
  121. cognee/tasks/code/get_local_dependencies_checker.py +0 -20
  122. cognee/tasks/code/get_repo_dependency_graph_checker.py +0 -35
  123. cognee/tasks/documents/check_permissions_on_dataset.py +0 -26
  124. cognee/tasks/repo_processor/__init__.py +0 -2
  125. cognee/tasks/repo_processor/get_local_dependencies.py +0 -335
  126. cognee/tasks/repo_processor/get_non_code_files.py +0 -158
  127. cognee/tasks/repo_processor/get_repo_file_dependencies.py +0 -243
  128. cognee/tests/test_delete_bmw_example.py +0 -60
  129. {cognee-0.5.0.dev0.dist-info → cognee-0.5.0.dev1.dist-info}/entry_points.txt +0 -0
  130. {cognee-0.5.0.dev0.dist-info → cognee-0.5.0.dev1.dist-info}/licenses/LICENSE +0 -0
  131. {cognee-0.5.0.dev0.dist-info → cognee-0.5.0.dev1.dist-info}/licenses/NOTICE.md +0 -0
cognee/modules/retrieval/triplet_retriever.py

@@ -0,0 +1,182 @@
+ import asyncio
+ from typing import Any, Optional, Type, List
+
+ from cognee.shared.logging_utils import get_logger
+ from cognee.infrastructure.databases.vector import get_vector_engine
+ from cognee.modules.retrieval.utils.completion import generate_completion, summarize_text
+ from cognee.modules.retrieval.utils.session_cache import (
+     save_conversation_history,
+     get_conversation_history,
+ )
+ from cognee.modules.retrieval.base_retriever import BaseRetriever
+ from cognee.modules.retrieval.exceptions.exceptions import NoDataError
+ from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError
+ from cognee.context_global_variables import session_user
+ from cognee.infrastructure.databases.cache.config import CacheConfig
+
+ logger = get_logger("TripletRetriever")
+
+
+ class TripletRetriever(BaseRetriever):
+     """
+     Retriever for handling LLM-based completion searches using triplets.
+
+     Public methods:
+     - get_context(query: str) -> str
+     - get_completion(query: str, context: Optional[Any] = None) -> Any
+     """
+
+     def __init__(
+         self,
+         user_prompt_path: str = "context_for_question.txt",
+         system_prompt_path: str = "answer_simple_question.txt",
+         system_prompt: Optional[str] = None,
+         top_k: Optional[int] = 5,
+     ):
+         """Initialize retriever with optional custom prompt paths."""
+         self.user_prompt_path = user_prompt_path
+         self.system_prompt_path = system_prompt_path
+         self.top_k = top_k if top_k is not None else 1
+         self.system_prompt = system_prompt
+
+     async def get_context(self, query: str) -> str:
+         """
+         Retrieves relevant triplets as context.
+
+         Fetches triplets based on a query from a vector engine and combines their text.
+         Returns empty string if no triplets are found. Raises NoDataError if the collection is not
+         found.
+
+         Parameters:
+         -----------
+
+         - query (str): The query string used to search for relevant triplets.
+
+         Returns:
+         --------
+
+         - str: A string containing the combined text of the retrieved triplets, or an
+           empty string if none are found.
+         """
+         vector_engine = get_vector_engine()
+
+         try:
+             if not await vector_engine.has_collection(collection_name="Triplet_text"):
+                 logger.error("Triplet_text collection not found")
+                 raise NoDataError(
+                     "In order to use TRIPLET_COMPLETION first use the create_triplet_embeddings memify pipeline. "
+                 )
+
+             found_triplets = await vector_engine.search("Triplet_text", query, limit=self.top_k)
+
+             if len(found_triplets) == 0:
+                 return ""
+
+             triplets_payload = [found_triplet.payload["text"] for found_triplet in found_triplets]
+             combined_context = "\n".join(triplets_payload)
+             return combined_context
+         except CollectionNotFoundError as error:
+             logger.error("Triplet_text collection not found")
+             raise NoDataError("No data found in the system, please add data first.") from error
+
+     async def get_completion(
+         self,
+         query: str,
+         context: Optional[Any] = None,
+         session_id: Optional[str] = None,
+         response_model: Type = str,
+     ) -> List[Any]:
+         """
+         Generates an LLM completion using the context.
+
+         Retrieves context if not provided and generates a completion based on the query and
+         context using an external completion generator.
+
+         Parameters:
+         -----------
+
+         - query (str): The query string to be used for generating a completion.
+         - context (Optional[Any]): Optional pre-fetched context to use for generating the
+           completion; if None, it retrieves the context for the query. (default None)
+         - session_id (Optional[str]): Optional session identifier for caching. If None,
+           defaults to 'default_session'. (default None)
+         - response_model (Type): The Pydantic model type for structured output. (default str)
+
+         Returns:
+         --------
+
+         - Any: The generated completion based on the provided query and context.
+         """
+         if context is None:
+             context = await self.get_context(query)
+
+         cache_config = CacheConfig()
+         user = session_user.get()
+         user_id = getattr(user, "id", None)
+         session_save = user_id and cache_config.caching
+
+         if session_save:
+             completion = await self._get_completion_with_session(
+                 query=query,
+                 context=context,
+                 session_id=session_id,
+                 response_model=response_model,
+             )
+         else:
+             completion = await self._get_completion_without_session(
+                 query=query,
+                 context=context,
+                 response_model=response_model,
+             )
+
+         return [completion]
+
+     async def _get_completion_with_session(
+         self,
+         query: str,
+         context: str,
+         session_id: Optional[str],
+         response_model: Type,
+     ) -> Any:
+         """Generate completion with session history and caching."""
+         conversation_history = await get_conversation_history(session_id=session_id)
+
+         context_summary, completion = await asyncio.gather(
+             summarize_text(context),
+             generate_completion(
+                 query=query,
+                 context=context,
+                 user_prompt_path=self.user_prompt_path,
+                 system_prompt_path=self.system_prompt_path,
+                 system_prompt=self.system_prompt,
+                 conversation_history=conversation_history,
+                 response_model=response_model,
+             ),
+         )
+
+         await save_conversation_history(
+             query=query,
+             context_summary=context_summary,
+             answer=completion,
+             session_id=session_id,
+         )
+
+         return completion
+
+     async def _get_completion_without_session(
+         self,
+         query: str,
+         context: str,
+         response_model: Type,
+     ) -> Any:
+         """Generate completion without session history."""
+         completion = await generate_completion(
+             query=query,
+             context=context,
+             user_prompt_path=self.user_prompt_path,
+             system_prompt_path=self.system_prompt_path,
+             system_prompt=self.system_prompt,
+             response_model=response_model,
+         )
+
+         return completion
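
A minimal usage sketch derived from the class above. It assumes a configured cognee environment whose Triplet_text collection has already been populated by the new create_triplet_embeddings memify pipeline (file 59 in the list), as the NoDataError message requires; the query string is illustrative.

import asyncio

from cognee.modules.retrieval.triplet_retriever import TripletRetriever


async def main():
    retriever = TripletRetriever(top_k=5)

    # Combined text of the top-5 matching triplets; "" when nothing matches,
    # NoDataError when the Triplet_text collection does not exist yet.
    context = await retriever.get_context("Who founded the company?")

    # get_completion returns a single-element list with the LLM answer.
    answers = await retriever.get_completion("Who founded the company?", context=context)
    print(answers[0])


asyncio.run(main())
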
cognee/modules/retrieval/utils/brute_force_triplet_search.py

@@ -58,6 +58,8 @@ async def get_memory_fragment(
      properties_to_project: Optional[List[str]] = None,
      node_type: Optional[Type] = None,
      node_name: Optional[List[str]] = None,
+     relevant_ids_to_filter: Optional[List[str]] = None,
+     triplet_distance_penalty: Optional[float] = 3.5,
  ) -> CogneeGraph:
      """Creates and initializes a CogneeGraph memory fragment with optional property projections."""
      if properties_to_project is None:
@@ -74,6 +76,8 @@ async def get_memory_fragment(
              edge_properties_to_project=["relationship_name", "edge_text"],
              node_type=node_type,
              node_name=node_name,
+             relevant_ids_to_filter=relevant_ids_to_filter,
+             triplet_distance_penalty=triplet_distance_penalty,
          )

      except EntityNotFoundError:
@@ -95,6 +99,8 @@ async def brute_force_triplet_search(
      memory_fragment: Optional[CogneeGraph] = None,
      node_type: Optional[Type] = None,
      node_name: Optional[List[str]] = None,
+     wide_search_top_k: Optional[int] = 100,
+     triplet_distance_penalty: Optional[float] = 3.5,
  ) -> List[Edge]:
      """
      Performs a brute force search to retrieve the top triplets from the graph.
@@ -107,6 +113,8 @@
          memory_fragment (Optional[CogneeGraph]): Existing memory fragment to reuse.
          node_type: node type to filter
          node_name: node name to filter
+         wide_search_top_k (Optional[int]): Number of initial elements to retrieve from collections
+         triplet_distance_penalty (Optional[float]): Default distance penalty in graph projection

      Returns:
          list: The top triplet results.
@@ -116,10 +124,10 @@
      if top_k <= 0:
          raise ValueError("top_k must be a positive integer.")

-     if memory_fragment is None:
-         memory_fragment = await get_memory_fragment(
-             properties_to_project, node_type=node_type, node_name=node_name
-         )
+     # Setting wide search limit based on the parameters
+     non_global_search = node_name is None
+
+     wide_search_limit = wide_search_top_k if non_global_search else None

      if collections is None:
          collections = [
@@ -129,6 +137,9 @@
              "DocumentChunk_text",
          ]

+     if "EdgeType_relationship_name" not in collections:
+         collections.append("EdgeType_relationship_name")
+
      try:
          vector_engine = get_vector_engine()
      except Exception as e:
@@ -140,7 +151,7 @@
      async def search_in_collection(collection_name: str):
          try:
              return await vector_engine.search(
-                 collection_name=collection_name, query_vector=query_vector, limit=None
+                 collection_name=collection_name, query_vector=query_vector, limit=wide_search_limit
              )
          except CollectionNotFoundError:
              return []
@@ -156,19 +167,40 @@
          return []

      # Final statistics
-     projection_time = time.time() - start_time
+     vector_collection_search_time = time.time() - start_time
      logger.info(
-         f"Vector collection retrieval completed: Retrieved distances from {sum(1 for res in results if res)} collections in {projection_time:.2f}s"
+         f"Vector collection retrieval completed: Retrieved distances from {sum(1 for res in results if res)} collections in {vector_collection_search_time:.2f}s"
      )

      node_distances = {collection: result for collection, result in zip(collections, results)}

      edge_distances = node_distances.get("EdgeType_relationship_name", None)

+     if wide_search_limit is not None:
+         relevant_ids_to_filter = list(
+             {
+                 str(getattr(scored_node, "id"))
+                 for collection_name, score_collection in node_distances.items()
+                 if collection_name != "EdgeType_relationship_name"
+                 and isinstance(score_collection, (list, tuple))
+                 for scored_node in score_collection
+                 if getattr(scored_node, "id", None)
+             }
+         )
+     else:
+         relevant_ids_to_filter = None
+
+     if memory_fragment is None:
+         memory_fragment = await get_memory_fragment(
+             properties_to_project=properties_to_project,
+             node_type=node_type,
+             node_name=node_name,
+             relevant_ids_to_filter=relevant_ids_to_filter,
+             triplet_distance_penalty=triplet_distance_penalty,
+         )
+
      await memory_fragment.map_vector_distances_to_graph_nodes(node_distances=node_distances)
-     await memory_fragment.map_vector_distances_to_graph_edges(
-         vector_engine=vector_engine, query_vector=query_vector, edge_distances=edge_distances
-     )
+     await memory_fragment.map_vector_distances_to_graph_edges(edge_distances=edge_distances)

      results = await memory_fragment.calculate_top_triplet_importances(k=top_k)

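A sketch of a call site with the new knobs. The leading query argument and the top_k keyword are not visible in these hunks and are assumed from the existing signature; values are illustrative.

from cognee.modules.retrieval.utils.brute_force_triplet_search import brute_force_triplet_search


async def find_triplets():
    # node_name is left unset, so the search is non-global: each vector
    # collection is capped at wide_search_top_k hits, and only those node ids
    # are projected into the memory fragment via relevant_ids_to_filter.
    return await brute_force_triplet_search(
        "acquisitions announced in 2023",  # assumed positional query parameter
        top_k=10,
        wide_search_top_k=100,  # per-collection retrieval cap for the wide search
        triplet_distance_penalty=3.5,  # default distance penalty in graph projection (per the docstring above)
    )
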
cognee/modules/run_custom_pipeline/run_custom_pipeline.py

@@ -18,6 +18,8 @@ async def run_custom_pipeline(
      user: User = None,
      vector_db_config: Optional[dict] = None,
      graph_db_config: Optional[dict] = None,
+     use_pipeline_cache: bool = False,
+     incremental_loading: bool = False,
      data_per_batch: int = 20,
      run_in_background: bool = False,
      pipeline_name: str = "custom_pipeline",
@@ -40,6 +42,10 @@ async def run_custom_pipeline(
          user: User context for authentication and data access. Uses default if None.
          vector_db_config: Custom vector database configuration for embeddings storage.
          graph_db_config: Custom graph database configuration for relationship storage.
+         use_pipeline_cache: If True, pipeline runs with the same ID that are currently executing or have already completed won't process data again.
+             The pipeline ID is created by the generate_pipeline_id function; run status can be manually reset with the reset_dataset_pipeline_run_status function.
+         incremental_loading: If True, only new or modified data will be processed to avoid duplication (only works when data uses the cognee Python Data model).
+             The incremental system stores and compares content hashes of processed data in the Data model and skips data with an unchanged hash.
          data_per_batch: Number of data items to be processed in parallel.
          run_in_background: If True, starts processing asynchronously and returns immediately.
              If False, waits for completion before returning.
@@ -63,7 +69,8 @@ async def run_custom_pipeline(
          datasets=dataset,
          vector_db_config=vector_db_config,
          graph_db_config=graph_db_config,
-         incremental_loading=False,
+         use_pipeline_cache=use_pipeline_cache,
+         incremental_loading=incremental_loading,
          data_per_batch=data_per_batch,
          pipeline_name=pipeline_name,
      )
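
A hedged usage sketch of the two new flags. Only the parameters shown in these hunks are confirmed; the tasks and dataset argument names are placeholders for whatever the caller already passes.

from cognee.modules.run_custom_pipeline.run_custom_pipeline import run_custom_pipeline


async def run_cached_pipeline(my_tasks):
    # my_tasks: hypothetical list of pipeline tasks supplied by the caller.
    return await run_custom_pipeline(
        tasks=my_tasks,  # assumed parameter name, not shown in this hunk
        dataset="my_dataset",  # assumed; the inner call forwards "datasets=dataset"
        use_pipeline_cache=True,  # skip runs whose pipeline ID is running or already completed
        incremental_loading=True,  # skip items whose content hash is unchanged
        pipeline_name="custom_pipeline",
    )
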
cognee/modules/search/methods/get_search_type_tools.py

@@ -2,6 +2,7 @@ import os
  from typing import Callable, List, Optional, Type

  from cognee.modules.engine.models.node_set import NodeSet
+ from cognee.modules.retrieval.triplet_retriever import TripletRetriever
  from cognee.modules.search.types import SearchType
  from cognee.modules.search.operations import select_search_type
  from cognee.modules.search.exceptions import UnsupportedSearchTypeError
@@ -22,7 +23,6 @@ from cognee.modules.retrieval.graph_completion_cot_retriever import GraphComplet
  from cognee.modules.retrieval.graph_completion_context_extension_retriever import (
      GraphCompletionContextExtensionRetriever,
  )
- from cognee.modules.retrieval.code_retriever import CodeRetriever
  from cognee.modules.retrieval.cypher_search_retriever import CypherSearchRetriever
  from cognee.modules.retrieval.natural_language_retriever import NaturalLanguageRetriever

@@ -37,6 +37,8 @@ async def get_search_type_tools(
      node_name: Optional[List[str]] = None,
      save_interaction: bool = False,
      last_k: Optional[int] = None,
+     wide_search_top_k: Optional[int] = 100,
+     triplet_distance_penalty: Optional[float] = 3.5,
  ) -> list:
      search_tasks: dict[SearchType, List[Callable]] = {
          SearchType.SUMMARIES: [
@@ -59,6 +61,18 @@
              system_prompt=system_prompt,
          ).get_context,
      ],
+     SearchType.TRIPLET_COMPLETION: [
+         TripletRetriever(
+             system_prompt_path=system_prompt_path,
+             top_k=top_k,
+             system_prompt=system_prompt,
+         ).get_completion,
+         TripletRetriever(
+             system_prompt_path=system_prompt_path,
+             top_k=top_k,
+             system_prompt=system_prompt,
+         ).get_context,
+     ],
      SearchType.GRAPH_COMPLETION: [
          GraphCompletionRetriever(
              system_prompt_path=system_prompt_path,
@@ -67,6 +81,8 @@
              node_name=node_name,
              save_interaction=save_interaction,
              system_prompt=system_prompt,
+             wide_search_top_k=wide_search_top_k,
+             triplet_distance_penalty=triplet_distance_penalty,
          ).get_completion,
          GraphCompletionRetriever(
              system_prompt_path=system_prompt_path,
@@ -75,6 +91,8 @@
              node_name=node_name,
              save_interaction=save_interaction,
              system_prompt=system_prompt,
+             wide_search_top_k=wide_search_top_k,
+             triplet_distance_penalty=triplet_distance_penalty,
          ).get_context,
      ],
      SearchType.GRAPH_COMPLETION_COT: [
@@ -85,6 +103,8 @@
              node_name=node_name,
              save_interaction=save_interaction,
              system_prompt=system_prompt,
+             wide_search_top_k=wide_search_top_k,
+             triplet_distance_penalty=triplet_distance_penalty,
          ).get_completion,
          GraphCompletionCotRetriever(
              system_prompt_path=system_prompt_path,
@@ -93,6 +113,8 @@
              node_name=node_name,
              save_interaction=save_interaction,
              system_prompt=system_prompt,
+             wide_search_top_k=wide_search_top_k,
+             triplet_distance_penalty=triplet_distance_penalty,
          ).get_context,
      ],
      SearchType.GRAPH_COMPLETION_CONTEXT_EXTENSION: [
@@ -103,6 +125,8 @@
              node_name=node_name,
              save_interaction=save_interaction,
              system_prompt=system_prompt,
+             wide_search_top_k=wide_search_top_k,
+             triplet_distance_penalty=triplet_distance_penalty,
          ).get_completion,
          GraphCompletionContextExtensionRetriever(
              system_prompt_path=system_prompt_path,
@@ -111,6 +135,8 @@
              node_name=node_name,
              save_interaction=save_interaction,
              system_prompt=system_prompt,
+             wide_search_top_k=wide_search_top_k,
+             triplet_distance_penalty=triplet_distance_penalty,
          ).get_context,
      ],
      SearchType.GRAPH_SUMMARY_COMPLETION: [
@@ -121,6 +147,8 @@
              node_name=node_name,
              save_interaction=save_interaction,
              system_prompt=system_prompt,
+             wide_search_top_k=wide_search_top_k,
+             triplet_distance_penalty=triplet_distance_penalty,
          ).get_completion,
          GraphSummaryCompletionRetriever(
              system_prompt_path=system_prompt_path,
@@ -129,12 +157,10 @@
              node_name=node_name,
              save_interaction=save_interaction,
              system_prompt=system_prompt,
+             wide_search_top_k=wide_search_top_k,
+             triplet_distance_penalty=triplet_distance_penalty,
          ).get_context,
      ],
-     SearchType.CODE: [
-         CodeRetriever(top_k=top_k).get_completion,
-         CodeRetriever(top_k=top_k).get_context,
-     ],
      SearchType.CYPHER: [
          CypherSearchRetriever().get_completion,
          CypherSearchRetriever().get_context,
@@ -145,8 +171,16 @@
      ],
      SearchType.FEEDBACK: [UserQAFeedback(last_k=last_k).add_feedback],
      SearchType.TEMPORAL: [
-         TemporalRetriever(top_k=top_k).get_completion,
-         TemporalRetriever(top_k=top_k).get_context,
+         TemporalRetriever(
+             top_k=top_k,
+             wide_search_top_k=wide_search_top_k,
+             triplet_distance_penalty=triplet_distance_penalty,
+         ).get_completion,
+         TemporalRetriever(
+             top_k=top_k,
+             wide_search_top_k=wide_search_top_k,
+             triplet_distance_penalty=triplet_distance_penalty,
+         ).get_context,
      ],
      SearchType.CHUNKS_LEXICAL: (
          lambda _r=JaccardChunksRetriever(top_k=top_k): [
@@ -169,7 +203,19 @@
      ):
          raise UnsupportedSearchTypeError("Cypher query search types are disabled.")

-     search_type_tools = search_tasks.get(query_type)
+     from cognee.modules.retrieval.registered_community_retrievers import (
+         registered_community_retrievers,
+     )
+
+     if query_type in registered_community_retrievers:
+         retriever = registered_community_retrievers[query_type]
+         retriever_instance = retriever(top_k=top_k)
+         search_type_tools = [
+             retriever_instance.get_completion,
+             retriever_instance.get_context,
+         ]
+     else:
+         search_type_tools = search_tasks.get(query_type)

      if not search_type_tools:
          raise UnsupportedSearchTypeError(str(query_type))
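
The rewritten dispatch above consults registered_community_retrievers before the built-in search_tasks table, instantiates the registered class with top_k, and uses its get_completion/get_context pair. A sketch under those assumptions; whether registry keys are SearchType members or plain strings is not visible in this diff, and the release also adds a register_retriever helper (file 74 in the list) that is presumably the intended registration entry point.

from cognee.modules.retrieval.base_retriever import BaseRetriever
from cognee.modules.retrieval.registered_community_retrievers import (
    registered_community_retrievers,
)


class MyCommunityRetriever(BaseRetriever):
    def __init__(self, top_k: int = 5):
        self.top_k = top_k

    async def get_context(self, query: str):
        return f"context for {query}"  # illustrative only

    async def get_completion(self, query: str, context=None):
        return [context or await self.get_context(query)]


# Keyed by the query_type value that get_search_type_tools receives
# ("MY_COMMUNITY_SEARCH" is a hypothetical key for illustration).
registered_community_retrievers["MY_COMMUNITY_SEARCH"] = MyCommunityRetriever
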
cognee/modules/search/methods/no_access_control_search.py

@@ -24,6 +24,8 @@ async def no_access_control_search(
      last_k: Optional[int] = None,
      only_context: bool = False,
      session_id: Optional[str] = None,
+     wide_search_top_k: Optional[int] = 100,
+     triplet_distance_penalty: Optional[float] = 3.5,
  ) -> Tuple[Any, Union[str, List[Edge]], List[Dataset]]:
      search_tools = await get_search_type_tools(
          query_type=query_type,
@@ -35,6 +37,8 @@ async def no_access_control_search(
          node_name=node_name,
          save_interaction=save_interaction,
          last_k=last_k,
+         wide_search_top_k=wide_search_top_k,
+         triplet_distance_penalty=triplet_distance_penalty,
      )
      graph_engine = await get_graph_engine()
      is_empty = await graph_engine.is_empty()
cognee/modules/search/methods/search.py

@@ -47,6 +47,8 @@ async def search(
      only_context: bool = False,
      use_combined_context: bool = False,
      session_id: Optional[str] = None,
+     wide_search_top_k: Optional[int] = 100,
+     triplet_distance_penalty: Optional[float] = 3.5,
  ) -> Union[CombinedSearchResult, List[SearchResult]]:
      """

@@ -90,6 +92,8 @@
              only_context=only_context,
              use_combined_context=use_combined_context,
              session_id=session_id,
+             wide_search_top_k=wide_search_top_k,
+             triplet_distance_penalty=triplet_distance_penalty,
          )
      else:
          search_results = [
@@ -105,6 +109,8 @@
                  last_k=last_k,
                  only_context=only_context,
                  session_id=session_id,
+                 wide_search_top_k=wide_search_top_k,
+                 triplet_distance_penalty=triplet_distance_penalty,
              )
          ]

@@ -219,6 +225,8 @@ async def authorized_search(
      only_context: bool = False,
      use_combined_context: bool = False,
      session_id: Optional[str] = None,
+     wide_search_top_k: Optional[int] = 100,
+     triplet_distance_penalty: Optional[float] = 3.5,
  ) -> Union[
      Tuple[Any, Union[List[Edge], str], List[Dataset]],
      List[Tuple[Any, Union[List[Edge], str], List[Dataset]]],
@@ -246,6 +254,8 @@
          last_k=last_k,
          only_context=True,
          session_id=session_id,
+         wide_search_top_k=wide_search_top_k,
+         triplet_distance_penalty=triplet_distance_penalty,
      )

      context = {}
@@ -267,6 +277,8 @@
              node_name=node_name,
              save_interaction=save_interaction,
              last_k=last_k,
+             wide_search_top_k=wide_search_top_k,
+             triplet_distance_penalty=triplet_distance_penalty,
          )
          search_tools = specific_search_tools
          if len(search_tools) == 2:
@@ -306,6 +318,7 @@
          last_k=last_k,
          only_context=only_context,
          session_id=session_id,
+         wide_search_top_k=wide_search_top_k,
      )

      return search_results
@@ -325,6 +338,8 @@ async def search_in_datasets_context(
      only_context: bool = False,
      context: Optional[Any] = None,
      session_id: Optional[str] = None,
+     wide_search_top_k: Optional[int] = 100,
+     triplet_distance_penalty: Optional[float] = 3.5,
  ) -> List[Tuple[Any, Union[str, List[Edge]], List[Dataset]]]:
      """
      Searches all provided datasets and handles setting up of appropriate database context based on permissions.
@@ -345,6 +360,8 @@
          only_context: bool = False,
          context: Optional[Any] = None,
          session_id: Optional[str] = None,
+         wide_search_top_k: Optional[int] = 100,
+         triplet_distance_penalty: Optional[float] = 3.5,
      ) -> Tuple[Any, Union[str, List[Edge]], List[Dataset]]:
          # Set database configuration in async context for each dataset user has access for
          await set_database_global_context_variables(dataset.id, dataset.owner_id)
@@ -378,6 +395,8 @@
              node_name=node_name,
              save_interaction=save_interaction,
              last_k=last_k,
+             wide_search_top_k=wide_search_top_k,
+             triplet_distance_penalty=triplet_distance_penalty,
          )
          search_tools = specific_search_tools
          if len(search_tools) == 2:
@@ -413,6 +432,8 @@
              only_context=only_context,
              context=context,
              session_id=session_id,
+             wide_search_top_k=wide_search_top_k,
+             triplet_distance_penalty=triplet_distance_penalty,
          )
      )

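The same two knobs thread through search, authorized_search, and search_in_datasets_context above, and the parallel +4 change to cognee/api/v1/search/search.py (file 9 in the list) suggests they surface at the public entry point too. A hedged sketch under that assumption:

import cognee
from cognee.modules.search.types import SearchType


async def ask():
    return await cognee.search(
        query_text="What did the report conclude?",
        query_type=SearchType.GRAPH_COMPLETION,
        wide_search_top_k=100,  # cap on per-collection vector hits during the wide search
        triplet_distance_penalty=3.5,  # default distance penalty used in graph projection
    )
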
cognee/modules/search/types/SearchType.py

@@ -5,9 +5,9 @@ class SearchType(Enum):
      SUMMARIES = "SUMMARIES"
      CHUNKS = "CHUNKS"
      RAG_COMPLETION = "RAG_COMPLETION"
+     TRIPLET_COMPLETION = "TRIPLET_COMPLETION"
      GRAPH_COMPLETION = "GRAPH_COMPLETION"
      GRAPH_SUMMARY_COMPLETION = "GRAPH_SUMMARY_COMPLETION"
-     CODE = "CODE"
      CYPHER = "CYPHER"
      NATURAL_LANGUAGE = "NATURAL_LANGUAGE"
      GRAPH_COMPLETION_COT = "GRAPH_COMPLETION_COT"
cognee/modules/settings/get_settings.py

@@ -16,6 +16,7 @@ class ModelName(Enum):
      anthropic = "anthropic"
      gemini = "gemini"
      mistral = "mistral"
+     bedrock = "bedrock"


  class LLMConfig(BaseModel):
@@ -77,6 +78,10 @@ def get_settings() -> SettingsDict:
              "value": "mistral",
              "label": "Mistral",
          },
+         {
+             "value": "bedrock",
+             "label": "Bedrock",
+         },
      ]

      return SettingsDict.model_validate(
@@ -157,6 +162,20 @@ def get_settings() -> SettingsDict:
                      "label": "Mistral Large 2.1",
                  },
              ],
+             "bedrock": [
+                 {
+                     "value": "eu.anthropic.claude-sonnet-4-5-20250929-v1:0",
+                     "label": "Claude 4.5 Sonnet",
+                 },
+                 {
+                     "value": "eu.anthropic.claude-haiku-4-5-20251001-v1:0",
+                     "label": "Claude 4.5 Haiku",
+                 },
+                 {
+                     "value": "eu.amazon.nova-lite-v1:0",
+                     "label": "Amazon Nova Lite",
+                 },
+             ],
          },
      },
      vector_db={
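
A configuration sketch for the new provider. LLM_PROVIDER and LLM_MODEL are cognee's usual provider environment variables; whether the new Bedrock adapter reads exactly these, plus standard AWS credentials, is an assumption (see cognee/infrastructure/llm/config.py, +35 in this release).

import os

os.environ["LLM_PROVIDER"] = "bedrock"  # assumed provider switch
os.environ["LLM_MODEL"] = "eu.anthropic.claude-sonnet-4-5-20250929-v1:0"
# Standard AWS credentials for the Bedrock runtime (assumed, not shown in this diff).
os.environ["AWS_ACCESS_KEY_ID"] = "..."
os.environ["AWS_SECRET_ACCESS_KEY"] = "..."
os.environ["AWS_REGION_NAME"] = "eu-central-1"
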
@@ -12,8 +12,8 @@ logger = get_logger("get_authenticated_user")
12
12
 
13
13
  # Check environment variable to determine authentication requirement
14
14
  REQUIRE_AUTHENTICATION = (
15
- os.getenv("REQUIRE_AUTHENTICATION", "false").lower() == "true"
16
- or backend_access_control_enabled()
15
+ os.getenv("REQUIRE_AUTHENTICATION", "true").lower() == "true"
16
+ or os.environ.get("ENABLE_BACKEND_ACCESS_CONTROL", "true").lower() == "true"
17
17
  )
18
18
 
19
19
  fastapi_users = get_fastapi_users()
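
Note the behavioral flip above: authentication now defaults to required, and backend access control is read directly from the environment (also defaulting to enabled) instead of going through backend_access_control_enabled(). Since REQUIRE_AUTHENTICATION is evaluated at import time, restoring the previous open behavior means setting both variables before this module is imported; a sketch:

import os

# Must run before cognee.modules.users.methods.get_authenticated_user is imported,
# because REQUIRE_AUTHENTICATION is computed at module import time.
os.environ["REQUIRE_AUTHENTICATION"] = "false"
os.environ["ENABLE_BACKEND_ACCESS_CONTROL"] = "false"

import cognee  # noqa: E402  (import after env setup is deliberate)
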