cognee 0.5.0.dev0__py3-none-any.whl → 0.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (132)
  1. cognee/api/client.py +1 -5
  2. cognee/api/v1/add/add.py +2 -1
  3. cognee/api/v1/cognify/cognify.py +24 -16
  4. cognee/api/v1/cognify/routers/__init__.py +0 -1
  5. cognee/api/v1/cognify/routers/get_cognify_router.py +3 -1
  6. cognee/api/v1/datasets/routers/get_datasets_router.py +3 -3
  7. cognee/api/v1/ontologies/ontologies.py +12 -37
  8. cognee/api/v1/ontologies/routers/get_ontology_router.py +27 -25
  9. cognee/api/v1/search/search.py +8 -0
  10. cognee/api/v1/ui/node_setup.py +360 -0
  11. cognee/api/v1/ui/npm_utils.py +50 -0
  12. cognee/api/v1/ui/ui.py +38 -68
  13. cognee/context_global_variables.py +61 -16
  14. cognee/eval_framework/Dockerfile +29 -0
  15. cognee/eval_framework/answer_generation/answer_generation_executor.py +10 -0
  16. cognee/eval_framework/answer_generation/run_question_answering_module.py +1 -1
  17. cognee/eval_framework/corpus_builder/task_getters/get_cascade_graph_tasks.py +0 -2
  18. cognee/eval_framework/corpus_builder/task_getters/get_default_tasks_by_indices.py +4 -4
  19. cognee/eval_framework/eval_config.py +2 -2
  20. cognee/eval_framework/modal_run_eval.py +16 -28
  21. cognee/infrastructure/databases/dataset_database_handler/__init__.py +3 -0
  22. cognee/infrastructure/databases/dataset_database_handler/dataset_database_handler_interface.py +80 -0
  23. cognee/infrastructure/databases/dataset_database_handler/supported_dataset_database_handlers.py +18 -0
  24. cognee/infrastructure/databases/dataset_database_handler/use_dataset_database_handler.py +10 -0
  25. cognee/infrastructure/databases/graph/config.py +3 -0
  26. cognee/infrastructure/databases/graph/get_graph_engine.py +1 -0
  27. cognee/infrastructure/databases/graph/graph_db_interface.py +15 -0
  28. cognee/infrastructure/databases/graph/kuzu/KuzuDatasetDatabaseHandler.py +81 -0
  29. cognee/infrastructure/databases/graph/kuzu/adapter.py +228 -0
  30. cognee/infrastructure/databases/graph/neo4j_driver/Neo4jAuraDevDatasetDatabaseHandler.py +168 -0
  31. cognee/infrastructure/databases/graph/neo4j_driver/adapter.py +80 -1
  32. cognee/infrastructure/databases/utils/__init__.py +3 -0
  33. cognee/infrastructure/databases/utils/get_graph_dataset_database_handler.py +10 -0
  34. cognee/infrastructure/databases/utils/get_or_create_dataset_database.py +62 -48
  35. cognee/infrastructure/databases/utils/get_vector_dataset_database_handler.py +10 -0
  36. cognee/infrastructure/databases/utils/resolve_dataset_database_connection_info.py +30 -0
  37. cognee/infrastructure/databases/vector/config.py +2 -0
  38. cognee/infrastructure/databases/vector/create_vector_engine.py +1 -0
  39. cognee/infrastructure/databases/vector/embeddings/FastembedEmbeddingEngine.py +8 -6
  40. cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py +9 -7
  41. cognee/infrastructure/databases/vector/embeddings/OllamaEmbeddingEngine.py +11 -10
  42. cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py +2 -0
  43. cognee/infrastructure/databases/vector/lancedb/LanceDBDatasetDatabaseHandler.py +50 -0
  44. cognee/infrastructure/databases/vector/vector_db_interface.py +35 -0
  45. cognee/infrastructure/files/storage/s3_config.py +2 -0
  46. cognee/infrastructure/llm/LLMGateway.py +5 -2
  47. cognee/infrastructure/llm/config.py +35 -0
  48. cognee/infrastructure/llm/extraction/knowledge_graph/extract_content_graph.py +2 -2
  49. cognee/infrastructure/llm/structured_output_framework/baml/baml_src/extraction/acreate_structured_output.py +23 -8
  50. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py +17 -16
  51. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/bedrock/__init__.py +5 -0
  52. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/bedrock/adapter.py +153 -0
  53. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py +40 -37
  54. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py +39 -36
  55. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py +19 -1
  56. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py +11 -9
  57. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/adapter.py +23 -21
  58. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py +42 -34
  59. cognee/memify_pipelines/create_triplet_embeddings.py +53 -0
  60. cognee/modules/cognify/config.py +2 -0
  61. cognee/modules/data/deletion/prune_system.py +52 -2
  62. cognee/modules/data/methods/delete_dataset.py +26 -0
  63. cognee/modules/engine/models/Triplet.py +9 -0
  64. cognee/modules/engine/models/__init__.py +1 -0
  65. cognee/modules/graph/cognee_graph/CogneeGraph.py +85 -37
  66. cognee/modules/graph/cognee_graph/CogneeGraphElements.py +8 -3
  67. cognee/modules/memify/memify.py +1 -7
  68. cognee/modules/pipelines/operations/pipeline.py +18 -2
  69. cognee/modules/retrieval/__init__.py +1 -1
  70. cognee/modules/retrieval/graph_completion_context_extension_retriever.py +4 -0
  71. cognee/modules/retrieval/graph_completion_cot_retriever.py +4 -0
  72. cognee/modules/retrieval/graph_completion_retriever.py +10 -0
  73. cognee/modules/retrieval/graph_summary_completion_retriever.py +4 -0
  74. cognee/modules/retrieval/register_retriever.py +10 -0
  75. cognee/modules/retrieval/registered_community_retrievers.py +1 -0
  76. cognee/modules/retrieval/temporal_retriever.py +4 -0
  77. cognee/modules/retrieval/triplet_retriever.py +182 -0
  78. cognee/modules/retrieval/utils/brute_force_triplet_search.py +42 -10
  79. cognee/modules/run_custom_pipeline/run_custom_pipeline.py +8 -1
  80. cognee/modules/search/methods/get_search_type_tools.py +54 -8
  81. cognee/modules/search/methods/no_access_control_search.py +4 -0
  82. cognee/modules/search/methods/search.py +46 -18
  83. cognee/modules/search/types/SearchType.py +1 -1
  84. cognee/modules/settings/get_settings.py +19 -0
  85. cognee/modules/users/methods/get_authenticated_user.py +2 -2
  86. cognee/modules/users/models/DatasetDatabase.py +15 -3
  87. cognee/shared/logging_utils.py +4 -0
  88. cognee/shared/rate_limiting.py +30 -0
  89. cognee/tasks/documents/__init__.py +0 -1
  90. cognee/tasks/graph/extract_graph_from_data.py +9 -10
  91. cognee/tasks/memify/get_triplet_datapoints.py +289 -0
  92. cognee/tasks/storage/add_data_points.py +142 -2
  93. cognee/tests/integration/retrieval/test_triplet_retriever.py +84 -0
  94. cognee/tests/integration/tasks/test_add_data_points.py +139 -0
  95. cognee/tests/integration/tasks/test_get_triplet_datapoints.py +69 -0
  96. cognee/tests/test_cognee_server_start.py +2 -4
  97. cognee/tests/test_conversation_history.py +23 -1
  98. cognee/tests/test_dataset_database_handler.py +137 -0
  99. cognee/tests/test_dataset_delete.py +76 -0
  100. cognee/tests/test_edge_centered_payload.py +170 -0
  101. cognee/tests/test_pipeline_cache.py +164 -0
  102. cognee/tests/test_search_db.py +37 -1
  103. cognee/tests/unit/api/test_ontology_endpoint.py +77 -89
  104. cognee/tests/unit/infrastructure/llm/test_llm_config.py +46 -0
  105. cognee/tests/unit/infrastructure/mock_embedding_engine.py +3 -7
  106. cognee/tests/unit/infrastructure/test_embedding_rate_limiting_realistic.py +0 -5
  107. cognee/tests/unit/modules/graph/cognee_graph_elements_test.py +2 -2
  108. cognee/tests/unit/modules/graph/cognee_graph_test.py +406 -0
  109. cognee/tests/unit/modules/memify_tasks/test_get_triplet_datapoints.py +214 -0
  110. cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py +608 -0
  111. cognee/tests/unit/modules/retrieval/triplet_retriever_test.py +83 -0
  112. cognee/tests/unit/modules/search/test_search.py +100 -0
  113. cognee/tests/unit/tasks/storage/test_add_data_points.py +288 -0
  114. {cognee-0.5.0.dev0.dist-info → cognee-0.5.1.dist-info}/METADATA +76 -89
  115. {cognee-0.5.0.dev0.dist-info → cognee-0.5.1.dist-info}/RECORD +119 -97
  116. {cognee-0.5.0.dev0.dist-info → cognee-0.5.1.dist-info}/WHEEL +1 -1
  117. cognee/api/v1/cognify/code_graph_pipeline.py +0 -119
  118. cognee/api/v1/cognify/routers/get_code_pipeline_router.py +0 -90
  119. cognee/infrastructure/databases/vector/embeddings/embedding_rate_limiter.py +0 -544
  120. cognee/modules/retrieval/code_retriever.py +0 -232
  121. cognee/tasks/code/enrich_dependency_graph_checker.py +0 -35
  122. cognee/tasks/code/get_local_dependencies_checker.py +0 -20
  123. cognee/tasks/code/get_repo_dependency_graph_checker.py +0 -35
  124. cognee/tasks/documents/check_permissions_on_dataset.py +0 -26
  125. cognee/tasks/repo_processor/__init__.py +0 -2
  126. cognee/tasks/repo_processor/get_local_dependencies.py +0 -335
  127. cognee/tasks/repo_processor/get_non_code_files.py +0 -158
  128. cognee/tasks/repo_processor/get_repo_file_dependencies.py +0 -243
  129. cognee/tests/test_delete_bmw_example.py +0 -60
  130. {cognee-0.5.0.dev0.dist-info → cognee-0.5.1.dist-info}/entry_points.txt +0 -0
  131. {cognee-0.5.0.dev0.dist-info → cognee-0.5.1.dist-info}/licenses/LICENSE +0 -0
  132. {cognee-0.5.0.dev0.dist-info → cognee-0.5.1.dist-info}/licenses/NOTICE.md +0 -0
cognee/modules/retrieval/triplet_retriever.py
@@ -0,0 +1,182 @@
+import asyncio
+from typing import Any, Optional, Type, List
+
+from cognee.shared.logging_utils import get_logger
+from cognee.infrastructure.databases.vector import get_vector_engine
+from cognee.modules.retrieval.utils.completion import generate_completion, summarize_text
+from cognee.modules.retrieval.utils.session_cache import (
+    save_conversation_history,
+    get_conversation_history,
+)
+from cognee.modules.retrieval.base_retriever import BaseRetriever
+from cognee.modules.retrieval.exceptions.exceptions import NoDataError
+from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError
+from cognee.context_global_variables import session_user
+from cognee.infrastructure.databases.cache.config import CacheConfig
+
+logger = get_logger("TripletRetriever")
+
+
+class TripletRetriever(BaseRetriever):
+    """
+    Retriever for handling LLM-based completion searches using triplets.
+
+    Public methods:
+    - get_context(query: str) -> str
+    - get_completion(query: str, context: Optional[Any] = None) -> Any
+    """
+
+    def __init__(
+        self,
+        user_prompt_path: str = "context_for_question.txt",
+        system_prompt_path: str = "answer_simple_question.txt",
+        system_prompt: Optional[str] = None,
+        top_k: Optional[int] = 5,
+    ):
+        """Initialize retriever with optional custom prompt paths."""
+        self.user_prompt_path = user_prompt_path
+        self.system_prompt_path = system_prompt_path
+        self.top_k = top_k if top_k is not None else 1
+        self.system_prompt = system_prompt
+
+    async def get_context(self, query: str) -> str:
+        """
+        Retrieves relevant triplets as context.
+
+        Fetches triplets based on a query from a vector engine and combines their text.
+        Returns empty string if no triplets are found. Raises NoDataError if the collection is not
+        found.
+
+        Parameters:
+        -----------
+
+        - query (str): The query string used to search for relevant triplets.
+
+        Returns:
+        --------
+
+        - str: A string containing the combined text of the retrieved triplets, or an
+          empty string if none are found.
+        """
+        vector_engine = get_vector_engine()
+
+        try:
+            if not await vector_engine.has_collection(collection_name="Triplet_text"):
+                logger.error("Triplet_text collection not found")
+                raise NoDataError(
+                    "In order to use TRIPLET_COMPLETION first use the create_triplet_embeddings memify pipeline. "
+                )
+
+            found_triplets = await vector_engine.search("Triplet_text", query, limit=self.top_k)
+
+            if len(found_triplets) == 0:
+                return ""
+
+            triplets_payload = [found_triplet.payload["text"] for found_triplet in found_triplets]
+            combined_context = "\n".join(triplets_payload)
+            return combined_context
+        except CollectionNotFoundError as error:
+            logger.error("Triplet_text collection not found")
+            raise NoDataError("No data found in the system, please add data first.") from error
+
+    async def get_completion(
+        self,
+        query: str,
+        context: Optional[Any] = None,
+        session_id: Optional[str] = None,
+        response_model: Type = str,
+    ) -> List[Any]:
+        """
+        Generates an LLM completion using the context.
+
+        Retrieves context if not provided and generates a completion based on the query and
+        context using an external completion generator.
+
+        Parameters:
+        -----------
+
+        - query (str): The query string to be used for generating a completion.
+        - context (Optional[Any]): Optional pre-fetched context to use for generating the
+          completion; if None, it retrieves the context for the query. (default None)
+        - session_id (Optional[str]): Optional session identifier for caching. If None,
+          defaults to 'default_session'. (default None)
+        - response_model (Type): The Pydantic model type for structured output. (default str)
+
+        Returns:
+        --------
+
+        - Any: The generated completion based on the provided query and context.
+        """
+        if context is None:
+            context = await self.get_context(query)
+
+        cache_config = CacheConfig()
+        user = session_user.get()
+        user_id = getattr(user, "id", None)
+        session_save = user_id and cache_config.caching
+
+        if session_save:
+            completion = await self._get_completion_with_session(
+                query=query,
+                context=context,
+                session_id=session_id,
+                response_model=response_model,
+            )
+        else:
+            completion = await self._get_completion_without_session(
+                query=query,
+                context=context,
+                response_model=response_model,
+            )
+
+        return [completion]
+
+    async def _get_completion_with_session(
+        self,
+        query: str,
+        context: str,
+        session_id: Optional[str],
+        response_model: Type,
+    ) -> Any:
+        """Generate completion with session history and caching."""
+        conversation_history = await get_conversation_history(session_id=session_id)
+
+        context_summary, completion = await asyncio.gather(
+            summarize_text(context),
+            generate_completion(
+                query=query,
+                context=context,
+                user_prompt_path=self.user_prompt_path,
+                system_prompt_path=self.system_prompt_path,
+                system_prompt=self.system_prompt,
+                conversation_history=conversation_history,
+                response_model=response_model,
+            ),
+        )
+
+        await save_conversation_history(
+            query=query,
+            context_summary=context_summary,
+            answer=completion,
+            session_id=session_id,
+        )
+
+        return completion
+
+    async def _get_completion_without_session(
+        self,
+        query: str,
+        context: str,
+        response_model: Type,
+    ) -> Any:
+        """Generate completion without session history."""
+        completion = await generate_completion(
+            query=query,
+            context=context,
+            user_prompt_path=self.user_prompt_path,
+            system_prompt_path=self.system_prompt_path,
+            system_prompt=self.system_prompt,
+            response_model=response_model,
+        )
+
+        return completion
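
For orientation, a minimal usage sketch of the new class (the query string is illustrative; it assumes the Triplet_text collection was already populated by the create_triplet_embeddings memify pipeline, file 59 in the list above):

    import asyncio

    from cognee.modules.retrieval.triplet_retriever import TripletRetriever

    async def main():
        retriever = TripletRetriever(top_k=5)
        # get_context joins the text payloads of the top-k matching triplets
        context = await retriever.get_context("What does cognee do?")
        # get_completion returns the LLM answer wrapped in a single-element list
        answers = await retriever.get_completion("What does cognee do?", context=context)
        print(answers[0])

    asyncio.run(main())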
cognee/modules/retrieval/utils/brute_force_triplet_search.py
@@ -58,6 +58,8 @@ async def get_memory_fragment(
     properties_to_project: Optional[List[str]] = None,
     node_type: Optional[Type] = None,
     node_name: Optional[List[str]] = None,
+    relevant_ids_to_filter: Optional[List[str]] = None,
+    triplet_distance_penalty: Optional[float] = 3.5,
 ) -> CogneeGraph:
     """Creates and initializes a CogneeGraph memory fragment with optional property projections."""
     if properties_to_project is None:
@@ -74,6 +76,8 @@ async def get_memory_fragment(
             edge_properties_to_project=["relationship_name", "edge_text"],
             node_type=node_type,
             node_name=node_name,
+            relevant_ids_to_filter=relevant_ids_to_filter,
+            triplet_distance_penalty=triplet_distance_penalty,
         )
 
     except EntityNotFoundError:
@@ -95,6 +99,8 @@ async def brute_force_triplet_search(
     memory_fragment: Optional[CogneeGraph] = None,
     node_type: Optional[Type] = None,
     node_name: Optional[List[str]] = None,
+    wide_search_top_k: Optional[int] = 100,
+    triplet_distance_penalty: Optional[float] = 3.5,
 ) -> List[Edge]:
     """
     Performs a brute force search to retrieve the top triplets from the graph.
@@ -107,6 +113,8 @@ async def brute_force_triplet_search(
         memory_fragment (Optional[CogneeGraph]): Existing memory fragment to reuse.
         node_type: node type to filter
         node_name: node name to filter
+        wide_search_top_k (Optional[int]): Number of initial elements to retrieve from collections
+        triplet_distance_penalty (Optional[float]): Default distance penalty in graph projection
 
     Returns:
         list: The top triplet results.
@@ -116,10 +124,10 @@ async def brute_force_triplet_search(
     if top_k <= 0:
         raise ValueError("top_k must be a positive integer.")
 
-    if memory_fragment is None:
-        memory_fragment = await get_memory_fragment(
-            properties_to_project, node_type=node_type, node_name=node_name
-        )
+    # Setting wide search limit based on the parameters
+    non_global_search = node_name is None
+
+    wide_search_limit = wide_search_top_k if non_global_search else None
 
     if collections is None:
         collections = [
@@ -129,6 +137,9 @@ async def brute_force_triplet_search(
             "DocumentChunk_text",
         ]
 
+    if "EdgeType_relationship_name" not in collections:
+        collections.append("EdgeType_relationship_name")
+
     try:
         vector_engine = get_vector_engine()
     except Exception as e:
@@ -140,7 +151,7 @@ async def brute_force_triplet_search(
     async def search_in_collection(collection_name: str):
         try:
             return await vector_engine.search(
-                collection_name=collection_name, query_vector=query_vector, limit=None
+                collection_name=collection_name, query_vector=query_vector, limit=wide_search_limit
             )
         except CollectionNotFoundError:
             return []
@@ -156,19 +167,40 @@ async def brute_force_triplet_search(
         return []
 
     # Final statistics
-    projection_time = time.time() - start_time
+    vector_collection_search_time = time.time() - start_time
     logger.info(
-        f"Vector collection retrieval completed: Retrieved distances from {sum(1 for res in results if res)} collections in {projection_time:.2f}s"
+        f"Vector collection retrieval completed: Retrieved distances from {sum(1 for res in results if res)} collections in {vector_collection_search_time:.2f}s"
     )
 
     node_distances = {collection: result for collection, result in zip(collections, results)}
 
     edge_distances = node_distances.get("EdgeType_relationship_name", None)
 
+    if wide_search_limit is not None:
+        relevant_ids_to_filter = list(
+            {
+                str(getattr(scored_node, "id"))
+                for collection_name, score_collection in node_distances.items()
+                if collection_name != "EdgeType_relationship_name"
+                and isinstance(score_collection, (list, tuple))
+                for scored_node in score_collection
+                if getattr(scored_node, "id", None)
+            }
+        )
+    else:
+        relevant_ids_to_filter = None
+
+    if memory_fragment is None:
+        memory_fragment = await get_memory_fragment(
+            properties_to_project=properties_to_project,
+            node_type=node_type,
+            node_name=node_name,
+            relevant_ids_to_filter=relevant_ids_to_filter,
+            triplet_distance_penalty=triplet_distance_penalty,
+        )
+
     await memory_fragment.map_vector_distances_to_graph_nodes(node_distances=node_distances)
-    await memory_fragment.map_vector_distances_to_graph_edges(
-        vector_engine=vector_engine, query_vector=query_vector, edge_distances=edge_distances
-    )
+    await memory_fragment.map_vector_distances_to_graph_edges(edge_distances=edge_distances)
 
     results = await memory_fragment.calculate_top_triplet_importances(k=top_k)
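
Taken together, these hunks replace the old project-everything flow with a two-phase search: the vector collections are queried first with a per-collection cap (wide_search_top_k, disabled when node_name forces a global search), and the memory fragment is projected afterwards, restricted to the candidate node IDs and weighted by triplet_distance_penalty. A hedged call sketch; the leading query argument and top_k come from the unchanged part of the signature, which this diff does not show:

    from cognee.modules.retrieval.utils.brute_force_triplet_search import (
        brute_force_triplet_search,
    )

    async def top_triplets():
        return await brute_force_triplet_search(
            "Who founded cognee?",         # assumed positional query parameter
            top_k=5,                       # final number of triplets to keep
            wide_search_top_k=100,         # per-collection cap before the graph projection
            triplet_distance_penalty=3.5,  # penalty applied during graph projection
        )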
cognee/modules/run_custom_pipeline/run_custom_pipeline.py
@@ -18,6 +18,8 @@ async def run_custom_pipeline(
     user: User = None,
     vector_db_config: Optional[dict] = None,
     graph_db_config: Optional[dict] = None,
+    use_pipeline_cache: bool = False,
+    incremental_loading: bool = False,
     data_per_batch: int = 20,
     run_in_background: bool = False,
     pipeline_name: str = "custom_pipeline",
@@ -40,6 +42,10 @@ async def run_custom_pipeline(
         user: User context for authentication and data access. Uses default if None.
         vector_db_config: Custom vector database configuration for embeddings storage.
         graph_db_config: Custom graph database configuration for relationship storage.
+        use_pipeline_cache: If True, pipeline runs whose ID is currently executing or was already completed won't process data again.
+            Pipeline IDs are created by the generate_pipeline_id function. Pipeline status can be manually reset with the reset_dataset_pipeline_run_status function.
+        incremental_loading: If True, only new or modified data will be processed to avoid duplication (only works if data is used with the Cognee python Data model).
+            The incremental system stores and compares hashes of processed data in the Data model and skips data with the same content hash.
         data_per_batch: Number of data items to be processed in parallel.
         run_in_background: If True, starts processing asynchronously and returns immediately.
             If False, waits for completion before returning.
@@ -63,7 +69,8 @@ async def run_custom_pipeline(
         datasets=dataset,
         vector_db_config=vector_db_config,
         graph_db_config=graph_db_config,
-        incremental_loading=False,
+        use_pipeline_cache=use_pipeline_cache,
+        incremental_loading=incremental_loading,
         data_per_batch=data_per_batch,
         pipeline_name=pipeline_name,
     )
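
A sketch of how the two new flags might be passed; everything other than use_pipeline_cache, incremental_loading, and pipeline_name (the tasks, data, and dataset arguments and their names) is a placeholder, since the rest of the signature is not shown in this diff:

    async def run():
        return await run_custom_pipeline(
            tasks=my_tasks,            # hypothetical list of pipeline tasks
            data=my_documents,         # hypothetical input data
            dataset="my_dataset",      # hypothetical dataset name
            use_pipeline_cache=True,   # skip runs whose pipeline ID is running or completed
            incremental_loading=True,  # skip items whose content hash was already processed
            pipeline_name="custom_pipeline",
        )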
cognee/modules/search/methods/get_search_type_tools.py
@@ -2,6 +2,7 @@ import os
 from typing import Callable, List, Optional, Type
 
 from cognee.modules.engine.models.node_set import NodeSet
+from cognee.modules.retrieval.triplet_retriever import TripletRetriever
 from cognee.modules.search.types import SearchType
 from cognee.modules.search.operations import select_search_type
 from cognee.modules.search.exceptions import UnsupportedSearchTypeError
@@ -22,7 +23,6 @@ from cognee.modules.retrieval.graph_completion_cot_retriever import GraphComplet
 from cognee.modules.retrieval.graph_completion_context_extension_retriever import (
     GraphCompletionContextExtensionRetriever,
 )
-from cognee.modules.retrieval.code_retriever import CodeRetriever
 from cognee.modules.retrieval.cypher_search_retriever import CypherSearchRetriever
 from cognee.modules.retrieval.natural_language_retriever import NaturalLanguageRetriever
 
@@ -37,6 +37,8 @@ async def get_search_type_tools(
     node_name: Optional[List[str]] = None,
     save_interaction: bool = False,
     last_k: Optional[int] = None,
+    wide_search_top_k: Optional[int] = 100,
+    triplet_distance_penalty: Optional[float] = 3.5,
 ) -> list:
     search_tasks: dict[SearchType, List[Callable]] = {
         SearchType.SUMMARIES: [
@@ -59,6 +61,18 @@ async def get_search_type_tools(
                 system_prompt=system_prompt,
             ).get_context,
         ],
+        SearchType.TRIPLET_COMPLETION: [
+            TripletRetriever(
+                system_prompt_path=system_prompt_path,
+                top_k=top_k,
+                system_prompt=system_prompt,
+            ).get_completion,
+            TripletRetriever(
+                system_prompt_path=system_prompt_path,
+                top_k=top_k,
+                system_prompt=system_prompt,
+            ).get_context,
+        ],
         SearchType.GRAPH_COMPLETION: [
             GraphCompletionRetriever(
                 system_prompt_path=system_prompt_path,
@@ -67,6 +81,8 @@ async def get_search_type_tools(
                 node_name=node_name,
                 save_interaction=save_interaction,
                 system_prompt=system_prompt,
+                wide_search_top_k=wide_search_top_k,
+                triplet_distance_penalty=triplet_distance_penalty,
             ).get_completion,
             GraphCompletionRetriever(
                 system_prompt_path=system_prompt_path,
@@ -75,6 +91,8 @@ async def get_search_type_tools(
                 node_name=node_name,
                 save_interaction=save_interaction,
                 system_prompt=system_prompt,
+                wide_search_top_k=wide_search_top_k,
+                triplet_distance_penalty=triplet_distance_penalty,
             ).get_context,
         ],
         SearchType.GRAPH_COMPLETION_COT: [
@@ -85,6 +103,8 @@ async def get_search_type_tools(
                 node_name=node_name,
                 save_interaction=save_interaction,
                 system_prompt=system_prompt,
+                wide_search_top_k=wide_search_top_k,
+                triplet_distance_penalty=triplet_distance_penalty,
             ).get_completion,
             GraphCompletionCotRetriever(
                 system_prompt_path=system_prompt_path,
@@ -93,6 +113,8 @@ async def get_search_type_tools(
                 node_name=node_name,
                 save_interaction=save_interaction,
                 system_prompt=system_prompt,
+                wide_search_top_k=wide_search_top_k,
+                triplet_distance_penalty=triplet_distance_penalty,
             ).get_context,
         ],
         SearchType.GRAPH_COMPLETION_CONTEXT_EXTENSION: [
@@ -103,6 +125,8 @@ async def get_search_type_tools(
                 node_name=node_name,
                 save_interaction=save_interaction,
                 system_prompt=system_prompt,
+                wide_search_top_k=wide_search_top_k,
+                triplet_distance_penalty=triplet_distance_penalty,
             ).get_completion,
             GraphCompletionContextExtensionRetriever(
                 system_prompt_path=system_prompt_path,
@@ -111,6 +135,8 @@ async def get_search_type_tools(
                 node_name=node_name,
                 save_interaction=save_interaction,
                 system_prompt=system_prompt,
+                wide_search_top_k=wide_search_top_k,
+                triplet_distance_penalty=triplet_distance_penalty,
             ).get_context,
         ],
         SearchType.GRAPH_SUMMARY_COMPLETION: [
@@ -121,6 +147,8 @@ async def get_search_type_tools(
                 node_name=node_name,
                 save_interaction=save_interaction,
                 system_prompt=system_prompt,
+                wide_search_top_k=wide_search_top_k,
+                triplet_distance_penalty=triplet_distance_penalty,
             ).get_completion,
             GraphSummaryCompletionRetriever(
                 system_prompt_path=system_prompt_path,
@@ -129,12 +157,10 @@ async def get_search_type_tools(
                 node_name=node_name,
                 save_interaction=save_interaction,
                 system_prompt=system_prompt,
+                wide_search_top_k=wide_search_top_k,
+                triplet_distance_penalty=triplet_distance_penalty,
             ).get_context,
         ],
-        SearchType.CODE: [
-            CodeRetriever(top_k=top_k).get_completion,
-            CodeRetriever(top_k=top_k).get_context,
-        ],
         SearchType.CYPHER: [
             CypherSearchRetriever().get_completion,
             CypherSearchRetriever().get_context,
@@ -145,8 +171,16 @@ async def get_search_type_tools(
         ],
         SearchType.FEEDBACK: [UserQAFeedback(last_k=last_k).add_feedback],
         SearchType.TEMPORAL: [
-            TemporalRetriever(top_k=top_k).get_completion,
-            TemporalRetriever(top_k=top_k).get_context,
+            TemporalRetriever(
+                top_k=top_k,
+                wide_search_top_k=wide_search_top_k,
+                triplet_distance_penalty=triplet_distance_penalty,
+            ).get_completion,
+            TemporalRetriever(
+                top_k=top_k,
+                wide_search_top_k=wide_search_top_k,
+                triplet_distance_penalty=triplet_distance_penalty,
+            ).get_context,
         ],
         SearchType.CHUNKS_LEXICAL: (
             lambda _r=JaccardChunksRetriever(top_k=top_k): [
@@ -169,7 +203,19 @@ async def get_search_type_tools(
     ):
         raise UnsupportedSearchTypeError("Cypher query search types are disabled.")
 
-    search_type_tools = search_tasks.get(query_type)
+    from cognee.modules.retrieval.registered_community_retrievers import (
+        registered_community_retrievers,
+    )
+
+    if query_type in registered_community_retrievers:
+        retriever = registered_community_retrievers[query_type]
+        retriever_instance = retriever(top_k=top_k)
+        search_type_tools = [
+            retriever_instance.get_completion,
+            retriever_instance.get_context,
+        ]
+    else:
+        search_type_tools = search_tasks.get(query_type)
 
     if not search_type_tools:
         raise UnsupportedSearchTypeError(str(query_type))
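
The new branch above resolves community retrievers before the built-in table, instantiating them as retriever(top_k=top_k) and exposing their get_completion/get_context coroutines. A minimal sketch of a class that would satisfy that contract (the class itself is hypothetical; registration goes through register_retriever, file 74 in the list, whose signature is not shown here):

    from cognee.modules.retrieval.base_retriever import BaseRetriever

    class MyCommunityRetriever(BaseRetriever):
        """Hypothetical community retriever matching the dispatch contract."""

        def __init__(self, top_k: int = 5):
            self.top_k = top_k

        async def get_context(self, query: str) -> str:
            # Fetch and combine context for the query here.
            return ""

        async def get_completion(self, query: str, context=None):
            # Produce an answer from the context here.
            return []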
cognee/modules/search/methods/no_access_control_search.py
@@ -24,6 +24,8 @@ async def no_access_control_search(
     last_k: Optional[int] = None,
     only_context: bool = False,
     session_id: Optional[str] = None,
+    wide_search_top_k: Optional[int] = 100,
+    triplet_distance_penalty: Optional[float] = 3.5,
 ) -> Tuple[Any, Union[str, List[Edge]], List[Dataset]]:
     search_tools = await get_search_type_tools(
         query_type=query_type,
@@ -35,6 +37,8 @@ async def no_access_control_search(
         node_name=node_name,
         save_interaction=save_interaction,
         last_k=last_k,
+        wide_search_top_k=wide_search_top_k,
+        triplet_distance_penalty=triplet_distance_penalty,
     )
     graph_engine = await get_graph_engine()
     is_empty = await graph_engine.is_empty()

cognee/modules/search/methods/search.py
@@ -47,6 +47,9 @@ async def search(
     only_context: bool = False,
     use_combined_context: bool = False,
     session_id: Optional[str] = None,
+    wide_search_top_k: Optional[int] = 100,
+    triplet_distance_penalty: Optional[float] = 3.5,
+    verbose: bool = False,
 ) -> Union[CombinedSearchResult, List[SearchResult]]:
     """
 
@@ -90,6 +93,8 @@ async def search(
             only_context=only_context,
             use_combined_context=use_combined_context,
             session_id=session_id,
+            wide_search_top_k=wide_search_top_k,
+            triplet_distance_penalty=triplet_distance_penalty,
         )
     else:
         search_results = [
@@ -105,6 +110,8 @@ async def search(
                 last_k=last_k,
                 only_context=only_context,
                 session_id=session_id,
+                wide_search_top_k=wide_search_top_k,
+                triplet_distance_penalty=triplet_distance_penalty,
             )
         ]
 
@@ -134,6 +141,7 @@ async def search(
     )
 
     if use_combined_context:
+        # Note: combined context search must always be verbose and return a CombinedSearchResult with graphs info
         prepared_search_results = await prepare_search_result(
             search_results[0] if isinstance(search_results, list) else search_results
         )
@@ -167,25 +175,30 @@ async def search(
         datasets = prepared_search_results["datasets"]
 
         if only_context:
-            return_value.append(
-                {
-                    "search_result": [context] if context else None,
-                    "dataset_id": datasets[0].id,
-                    "dataset_name": datasets[0].name,
-                    "dataset_tenant_id": datasets[0].tenant_id,
-                    "graphs": graphs,
-                }
-            )
+            search_result_dict = {
+                "search_result": [context] if context else None,
+                "dataset_id": datasets[0].id,
+                "dataset_name": datasets[0].name,
+                "dataset_tenant_id": datasets[0].tenant_id,
+            }
+            if verbose:
+                # Include graphs only in verbose mode
+                search_result_dict["graphs"] = graphs
+
+            return_value.append(search_result_dict)
         else:
-            return_value.append(
-                {
-                    "search_result": [result] if result else None,
-                    "dataset_id": datasets[0].id,
-                    "dataset_name": datasets[0].name,
-                    "dataset_tenant_id": datasets[0].tenant_id,
-                    "graphs": graphs,
-                }
-            )
+            search_result_dict = {
+                "search_result": [result] if result else None,
+                "dataset_id": datasets[0].id,
+                "dataset_name": datasets[0].name,
+                "dataset_tenant_id": datasets[0].tenant_id,
+            }
+            if verbose:
+                # Include graphs only in verbose mode
+                search_result_dict["graphs"] = graphs
+
+            return_value.append(search_result_dict)
+
         return return_value
     else:
         return_value = []
@@ -219,6 +232,8 @@ async def authorized_search(
     only_context: bool = False,
     use_combined_context: bool = False,
     session_id: Optional[str] = None,
+    wide_search_top_k: Optional[int] = 100,
+    triplet_distance_penalty: Optional[float] = 3.5,
 ) -> Union[
     Tuple[Any, Union[List[Edge], str], List[Dataset]],
     List[Tuple[Any, Union[List[Edge], str], List[Dataset]]],
@@ -246,6 +261,8 @@ async def authorized_search(
             last_k=last_k,
             only_context=True,
             session_id=session_id,
+            wide_search_top_k=wide_search_top_k,
+            triplet_distance_penalty=triplet_distance_penalty,
         )
 
         context = {}
@@ -267,6 +284,8 @@ async def authorized_search(
             node_name=node_name,
             save_interaction=save_interaction,
             last_k=last_k,
+            wide_search_top_k=wide_search_top_k,
+            triplet_distance_penalty=triplet_distance_penalty,
         )
         search_tools = specific_search_tools
         if len(search_tools) == 2:
@@ -306,6 +325,7 @@ async def authorized_search(
         last_k=last_k,
         only_context=only_context,
         session_id=session_id,
+        wide_search_top_k=wide_search_top_k,
    )
 
     return search_results
@@ -325,6 +345,8 @@ async def search_in_datasets_context(
     only_context: bool = False,
     context: Optional[Any] = None,
     session_id: Optional[str] = None,
+    wide_search_top_k: Optional[int] = 100,
+    triplet_distance_penalty: Optional[float] = 3.5,
 ) -> List[Tuple[Any, Union[str, List[Edge]], List[Dataset]]]:
     """
     Searches all provided datasets and handles setting up of appropriate database context based on permissions.
@@ -345,6 +367,8 @@ async def search_in_datasets_context(
         only_context: bool = False,
         context: Optional[Any] = None,
         session_id: Optional[str] = None,
+        wide_search_top_k: Optional[int] = 100,
+        triplet_distance_penalty: Optional[float] = 3.5,
     ) -> Tuple[Any, Union[str, List[Edge]], List[Dataset]]:
         # Set database configuration in async context for each dataset user has access for
         await set_database_global_context_variables(dataset.id, dataset.owner_id)
@@ -378,6 +402,8 @@ async def search_in_datasets_context(
                 node_name=node_name,
                 save_interaction=save_interaction,
                 last_k=last_k,
+                wide_search_top_k=wide_search_top_k,
+                triplet_distance_penalty=triplet_distance_penalty,
             )
             search_tools = specific_search_tools
             if len(search_tools) == 2:
@@ -413,6 +439,8 @@ async def search_in_datasets_context(
                 only_context=only_context,
                 context=context,
                 session_id=session_id,
+                wide_search_top_k=wide_search_top_k,
+                triplet_distance_penalty=triplet_distance_penalty,
             )
         )
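
The net effect of the search() changes: each per-dataset result dict now carries the graphs key only when verbose=True, while combined-context searches remain implicitly verbose. Shape of one entry, read directly off the hunks above:

    {
        "search_result": ["..."],
        "dataset_id": datasets[0].id,
        "dataset_name": datasets[0].name,
        "dataset_tenant_id": datasets[0].tenant_id,
        "graphs": graphs,  # present only when verbose=True
    }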
cognee/modules/search/types/SearchType.py
@@ -5,9 +5,9 @@ class SearchType(Enum):
     SUMMARIES = "SUMMARIES"
     CHUNKS = "CHUNKS"
     RAG_COMPLETION = "RAG_COMPLETION"
+    TRIPLET_COMPLETION = "TRIPLET_COMPLETION"
     GRAPH_COMPLETION = "GRAPH_COMPLETION"
     GRAPH_SUMMARY_COMPLETION = "GRAPH_SUMMARY_COMPLETION"
-    CODE = "CODE"
     CYPHER = "CYPHER"
     NATURAL_LANGUAGE = "NATURAL_LANGUAGE"
     GRAPH_COMPLETION_COT = "GRAPH_COMPLETION_COT"
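
With the enum extended (and CODE removed along with the code-graph pipeline files above), the new search type is reachable through the public API; a sketch, assuming cognee.search keeps its long-standing query_type/query_text parameters, which this diff does not show:

    import asyncio

    import cognee
    from cognee.modules.search.types import SearchType

    async def main():
        results = await cognee.search(
            query_type=SearchType.TRIPLET_COMPLETION,
            query_text="What is in my knowledge graph?",
        )
        print(results)

    asyncio.run(main())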