cognee 0.5.0.dev0__py3-none-any.whl → 0.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cognee/api/client.py +1 -5
- cognee/api/v1/add/add.py +2 -1
- cognee/api/v1/cognify/cognify.py +24 -16
- cognee/api/v1/cognify/routers/__init__.py +0 -1
- cognee/api/v1/cognify/routers/get_cognify_router.py +3 -1
- cognee/api/v1/datasets/routers/get_datasets_router.py +3 -3
- cognee/api/v1/ontologies/ontologies.py +12 -37
- cognee/api/v1/ontologies/routers/get_ontology_router.py +27 -25
- cognee/api/v1/search/search.py +8 -0
- cognee/api/v1/ui/node_setup.py +360 -0
- cognee/api/v1/ui/npm_utils.py +50 -0
- cognee/api/v1/ui/ui.py +38 -68
- cognee/context_global_variables.py +61 -16
- cognee/eval_framework/Dockerfile +29 -0
- cognee/eval_framework/answer_generation/answer_generation_executor.py +10 -0
- cognee/eval_framework/answer_generation/run_question_answering_module.py +1 -1
- cognee/eval_framework/corpus_builder/task_getters/get_cascade_graph_tasks.py +0 -2
- cognee/eval_framework/corpus_builder/task_getters/get_default_tasks_by_indices.py +4 -4
- cognee/eval_framework/eval_config.py +2 -2
- cognee/eval_framework/modal_run_eval.py +16 -28
- cognee/infrastructure/databases/dataset_database_handler/__init__.py +3 -0
- cognee/infrastructure/databases/dataset_database_handler/dataset_database_handler_interface.py +80 -0
- cognee/infrastructure/databases/dataset_database_handler/supported_dataset_database_handlers.py +18 -0
- cognee/infrastructure/databases/dataset_database_handler/use_dataset_database_handler.py +10 -0
- cognee/infrastructure/databases/graph/config.py +3 -0
- cognee/infrastructure/databases/graph/get_graph_engine.py +1 -0
- cognee/infrastructure/databases/graph/graph_db_interface.py +15 -0
- cognee/infrastructure/databases/graph/kuzu/KuzuDatasetDatabaseHandler.py +81 -0
- cognee/infrastructure/databases/graph/kuzu/adapter.py +228 -0
- cognee/infrastructure/databases/graph/neo4j_driver/Neo4jAuraDevDatasetDatabaseHandler.py +168 -0
- cognee/infrastructure/databases/graph/neo4j_driver/adapter.py +80 -1
- cognee/infrastructure/databases/utils/__init__.py +3 -0
- cognee/infrastructure/databases/utils/get_graph_dataset_database_handler.py +10 -0
- cognee/infrastructure/databases/utils/get_or_create_dataset_database.py +62 -48
- cognee/infrastructure/databases/utils/get_vector_dataset_database_handler.py +10 -0
- cognee/infrastructure/databases/utils/resolve_dataset_database_connection_info.py +30 -0
- cognee/infrastructure/databases/vector/config.py +2 -0
- cognee/infrastructure/databases/vector/create_vector_engine.py +1 -0
- cognee/infrastructure/databases/vector/embeddings/FastembedEmbeddingEngine.py +8 -6
- cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py +9 -7
- cognee/infrastructure/databases/vector/embeddings/OllamaEmbeddingEngine.py +11 -10
- cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py +2 -0
- cognee/infrastructure/databases/vector/lancedb/LanceDBDatasetDatabaseHandler.py +50 -0
- cognee/infrastructure/databases/vector/vector_db_interface.py +35 -0
- cognee/infrastructure/files/storage/s3_config.py +2 -0
- cognee/infrastructure/llm/LLMGateway.py +5 -2
- cognee/infrastructure/llm/config.py +35 -0
- cognee/infrastructure/llm/extraction/knowledge_graph/extract_content_graph.py +2 -2
- cognee/infrastructure/llm/structured_output_framework/baml/baml_src/extraction/acreate_structured_output.py +23 -8
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py +17 -16
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/bedrock/__init__.py +5 -0
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/bedrock/adapter.py +153 -0
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py +40 -37
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py +39 -36
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py +19 -1
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py +11 -9
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/adapter.py +23 -21
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py +42 -34
- cognee/memify_pipelines/create_triplet_embeddings.py +53 -0
- cognee/modules/cognify/config.py +2 -0
- cognee/modules/data/deletion/prune_system.py +52 -2
- cognee/modules/data/methods/delete_dataset.py +26 -0
- cognee/modules/engine/models/Triplet.py +9 -0
- cognee/modules/engine/models/__init__.py +1 -0
- cognee/modules/graph/cognee_graph/CogneeGraph.py +85 -37
- cognee/modules/graph/cognee_graph/CogneeGraphElements.py +8 -3
- cognee/modules/memify/memify.py +1 -7
- cognee/modules/pipelines/operations/pipeline.py +18 -2
- cognee/modules/retrieval/__init__.py +1 -1
- cognee/modules/retrieval/graph_completion_context_extension_retriever.py +4 -0
- cognee/modules/retrieval/graph_completion_cot_retriever.py +4 -0
- cognee/modules/retrieval/graph_completion_retriever.py +10 -0
- cognee/modules/retrieval/graph_summary_completion_retriever.py +4 -0
- cognee/modules/retrieval/register_retriever.py +10 -0
- cognee/modules/retrieval/registered_community_retrievers.py +1 -0
- cognee/modules/retrieval/temporal_retriever.py +4 -0
- cognee/modules/retrieval/triplet_retriever.py +182 -0
- cognee/modules/retrieval/utils/brute_force_triplet_search.py +42 -10
- cognee/modules/run_custom_pipeline/run_custom_pipeline.py +8 -1
- cognee/modules/search/methods/get_search_type_tools.py +54 -8
- cognee/modules/search/methods/no_access_control_search.py +4 -0
- cognee/modules/search/methods/search.py +46 -18
- cognee/modules/search/types/SearchType.py +1 -1
- cognee/modules/settings/get_settings.py +19 -0
- cognee/modules/users/methods/get_authenticated_user.py +2 -2
- cognee/modules/users/models/DatasetDatabase.py +15 -3
- cognee/shared/logging_utils.py +4 -0
- cognee/shared/rate_limiting.py +30 -0
- cognee/tasks/documents/__init__.py +0 -1
- cognee/tasks/graph/extract_graph_from_data.py +9 -10
- cognee/tasks/memify/get_triplet_datapoints.py +289 -0
- cognee/tasks/storage/add_data_points.py +142 -2
- cognee/tests/integration/retrieval/test_triplet_retriever.py +84 -0
- cognee/tests/integration/tasks/test_add_data_points.py +139 -0
- cognee/tests/integration/tasks/test_get_triplet_datapoints.py +69 -0
- cognee/tests/test_cognee_server_start.py +2 -4
- cognee/tests/test_conversation_history.py +23 -1
- cognee/tests/test_dataset_database_handler.py +137 -0
- cognee/tests/test_dataset_delete.py +76 -0
- cognee/tests/test_edge_centered_payload.py +170 -0
- cognee/tests/test_pipeline_cache.py +164 -0
- cognee/tests/test_search_db.py +37 -1
- cognee/tests/unit/api/test_ontology_endpoint.py +77 -89
- cognee/tests/unit/infrastructure/llm/test_llm_config.py +46 -0
- cognee/tests/unit/infrastructure/mock_embedding_engine.py +3 -7
- cognee/tests/unit/infrastructure/test_embedding_rate_limiting_realistic.py +0 -5
- cognee/tests/unit/modules/graph/cognee_graph_elements_test.py +2 -2
- cognee/tests/unit/modules/graph/cognee_graph_test.py +406 -0
- cognee/tests/unit/modules/memify_tasks/test_get_triplet_datapoints.py +214 -0
- cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py +608 -0
- cognee/tests/unit/modules/retrieval/triplet_retriever_test.py +83 -0
- cognee/tests/unit/modules/search/test_search.py +100 -0
- cognee/tests/unit/tasks/storage/test_add_data_points.py +288 -0
- {cognee-0.5.0.dev0.dist-info → cognee-0.5.1.dist-info}/METADATA +76 -89
- {cognee-0.5.0.dev0.dist-info → cognee-0.5.1.dist-info}/RECORD +119 -97
- {cognee-0.5.0.dev0.dist-info → cognee-0.5.1.dist-info}/WHEEL +1 -1
- cognee/api/v1/cognify/code_graph_pipeline.py +0 -119
- cognee/api/v1/cognify/routers/get_code_pipeline_router.py +0 -90
- cognee/infrastructure/databases/vector/embeddings/embedding_rate_limiter.py +0 -544
- cognee/modules/retrieval/code_retriever.py +0 -232
- cognee/tasks/code/enrich_dependency_graph_checker.py +0 -35
- cognee/tasks/code/get_local_dependencies_checker.py +0 -20
- cognee/tasks/code/get_repo_dependency_graph_checker.py +0 -35
- cognee/tasks/documents/check_permissions_on_dataset.py +0 -26
- cognee/tasks/repo_processor/__init__.py +0 -2
- cognee/tasks/repo_processor/get_local_dependencies.py +0 -335
- cognee/tasks/repo_processor/get_non_code_files.py +0 -158
- cognee/tasks/repo_processor/get_repo_file_dependencies.py +0 -243
- cognee/tests/test_delete_bmw_example.py +0 -60
- {cognee-0.5.0.dev0.dist-info → cognee-0.5.1.dist-info}/entry_points.txt +0 -0
- {cognee-0.5.0.dev0.dist-info → cognee-0.5.1.dist-info}/licenses/LICENSE +0 -0
- {cognee-0.5.0.dev0.dist-info → cognee-0.5.1.dist-info}/licenses/NOTICE.md +0 -0
|
@@ -0,0 +1,182 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
from typing import Any, Optional, Type, List
|
|
3
|
+
|
|
4
|
+
from cognee.shared.logging_utils import get_logger
|
|
5
|
+
from cognee.infrastructure.databases.vector import get_vector_engine
|
|
6
|
+
from cognee.modules.retrieval.utils.completion import generate_completion, summarize_text
|
|
7
|
+
from cognee.modules.retrieval.utils.session_cache import (
|
|
8
|
+
save_conversation_history,
|
|
9
|
+
get_conversation_history,
|
|
10
|
+
)
|
|
11
|
+
from cognee.modules.retrieval.base_retriever import BaseRetriever
|
|
12
|
+
from cognee.modules.retrieval.exceptions.exceptions import NoDataError
|
|
13
|
+
from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError
|
|
14
|
+
from cognee.context_global_variables import session_user
|
|
15
|
+
from cognee.infrastructure.databases.cache.config import CacheConfig
|
|
16
|
+
|
|
17
|
+
logger = get_logger("TripletRetriever")
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class TripletRetriever(BaseRetriever):
    """
    Retriever for handling LLM-based completion searches using triplets.

    Public methods:
    - get_context(query: str) -> str
    - get_completion(query: str, context: Optional[Any] = None) -> Any
    """

    def __init__(
        self,
        user_prompt_path: str = "context_for_question.txt",
        system_prompt_path: str = "answer_simple_question.txt",
        system_prompt: Optional[str] = None,
        top_k: Optional[int] = 5,
    ):
        """Initialize retriever with optional custom prompt paths."""
        self.user_prompt_path = user_prompt_path
        self.system_prompt_path = system_prompt_path
        # Fall back to a single result when the caller explicitly passes None.
        self.top_k = 1 if top_k is None else top_k
        self.system_prompt = system_prompt

    async def get_context(self, query: str) -> str:
        """
        Retrieve relevant triplets as a single newline-joined context string.

        Searches the "Triplet_text" vector collection for the query and joins
        the matched triplet texts. Returns an empty string when nothing
        matches. Raises NoDataError when the collection does not exist.

        Parameters:
        -----------

        - query (str): The query string used to search for relevant triplets.

        Returns:
        --------

        - str: A string containing the combined text of the retrieved triplets, or an
          empty string if none are found.
        """
        vector_engine = get_vector_engine()

        try:
            collection_exists = await vector_engine.has_collection(
                collection_name="Triplet_text"
            )
            if not collection_exists:
                logger.error("Triplet_text collection not found")
                raise NoDataError(
                    "In order to use TRIPLET_COMPLETION first use the create_triplet_embeddings memify pipeline. "
                )

            matches = await vector_engine.search("Triplet_text", query, limit=self.top_k)

            if not matches:
                return ""

            return "\n".join(match.payload["text"] for match in matches)
        except CollectionNotFoundError as error:
            logger.error("Triplet_text collection not found")
            raise NoDataError("No data found in the system, please add data first.") from error

    async def get_completion(
        self,
        query: str,
        context: Optional[Any] = None,
        session_id: Optional[str] = None,
        response_model: Type = str,
    ) -> List[Any]:
        """
        Generate an LLM completion for the query using triplet context.

        Fetches the context via get_context when none is supplied, then
        delegates to the session-aware or session-less completion path
        depending on whether caching is enabled and a user id is available.

        Parameters:
        -----------

        - query (str): The query string to be used for generating a completion.
        - context (Optional[Any]): Optional pre-fetched context to use for generating the
          completion; if None, it retrieves the context for the query. (default None)
        - session_id (Optional[str]): Optional session identifier for caching. If None,
          defaults to 'default_session'. (default None)
        - response_model (Type): The Pydantic model type for structured output. (default str)

        Returns:
        --------

        - List[Any]: A single-element list containing the generated completion.
        """
        if context is None:
            context = await self.get_context(query)

        cache_config = CacheConfig()
        current_user = session_user.get()
        # Only persist conversation history when we can attribute it to a user
        # and caching is switched on.
        use_session = getattr(current_user, "id", None) and cache_config.caching

        if use_session:
            result = await self._get_completion_with_session(
                query=query,
                context=context,
                session_id=session_id,
                response_model=response_model,
            )
        else:
            result = await self._get_completion_without_session(
                query=query,
                context=context,
                response_model=response_model,
            )

        return [result]

    async def _get_completion_with_session(
        self,
        query: str,
        context: str,
        session_id: Optional[str],
        response_model: Type,
    ) -> Any:
        """Generate completion with session history and caching."""
        history = await get_conversation_history(session_id=session_id)

        # Summarize the context and produce the answer concurrently; both are
        # needed before the history entry can be saved.
        summary, answer = await asyncio.gather(
            summarize_text(context),
            generate_completion(
                query=query,
                context=context,
                user_prompt_path=self.user_prompt_path,
                system_prompt_path=self.system_prompt_path,
                system_prompt=self.system_prompt,
                conversation_history=history,
                response_model=response_model,
            ),
        )

        await save_conversation_history(
            query=query,
            context_summary=summary,
            answer=answer,
            session_id=session_id,
        )

        return answer

    async def _get_completion_without_session(
        self,
        query: str,
        context: str,
        response_model: Type,
    ) -> Any:
        """Generate completion without session history."""
        return await generate_completion(
            query=query,
            context=context,
            user_prompt_path=self.user_prompt_path,
            system_prompt_path=self.system_prompt_path,
            system_prompt=self.system_prompt,
            response_model=response_model,
        )
|
@@ -58,6 +58,8 @@ async def get_memory_fragment(
|
|
|
58
58
|
properties_to_project: Optional[List[str]] = None,
|
|
59
59
|
node_type: Optional[Type] = None,
|
|
60
60
|
node_name: Optional[List[str]] = None,
|
|
61
|
+
relevant_ids_to_filter: Optional[List[str]] = None,
|
|
62
|
+
triplet_distance_penalty: Optional[float] = 3.5,
|
|
61
63
|
) -> CogneeGraph:
|
|
62
64
|
"""Creates and initializes a CogneeGraph memory fragment with optional property projections."""
|
|
63
65
|
if properties_to_project is None:
|
|
@@ -74,6 +76,8 @@ async def get_memory_fragment(
|
|
|
74
76
|
edge_properties_to_project=["relationship_name", "edge_text"],
|
|
75
77
|
node_type=node_type,
|
|
76
78
|
node_name=node_name,
|
|
79
|
+
relevant_ids_to_filter=relevant_ids_to_filter,
|
|
80
|
+
triplet_distance_penalty=triplet_distance_penalty,
|
|
77
81
|
)
|
|
78
82
|
|
|
79
83
|
except EntityNotFoundError:
|
|
@@ -95,6 +99,8 @@ async def brute_force_triplet_search(
|
|
|
95
99
|
memory_fragment: Optional[CogneeGraph] = None,
|
|
96
100
|
node_type: Optional[Type] = None,
|
|
97
101
|
node_name: Optional[List[str]] = None,
|
|
102
|
+
wide_search_top_k: Optional[int] = 100,
|
|
103
|
+
triplet_distance_penalty: Optional[float] = 3.5,
|
|
98
104
|
) -> List[Edge]:
|
|
99
105
|
"""
|
|
100
106
|
Performs a brute force search to retrieve the top triplets from the graph.
|
|
@@ -107,6 +113,8 @@ async def brute_force_triplet_search(
|
|
|
107
113
|
memory_fragment (Optional[CogneeGraph]): Existing memory fragment to reuse.
|
|
108
114
|
node_type: node type to filter
|
|
109
115
|
node_name: node name to filter
|
|
116
|
+
wide_search_top_k (Optional[int]): Number of initial elements to retrieve from collections
|
|
117
|
+
triplet_distance_penalty (Optional[float]): Default distance penalty in graph projection
|
|
110
118
|
|
|
111
119
|
Returns:
|
|
112
120
|
list: The top triplet results.
|
|
@@ -116,10 +124,10 @@ async def brute_force_triplet_search(
|
|
|
116
124
|
if top_k <= 0:
|
|
117
125
|
raise ValueError("top_k must be a positive integer.")
|
|
118
126
|
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
127
|
+
# Setting wide search limit based on the parameters
|
|
128
|
+
non_global_search = node_name is None
|
|
129
|
+
|
|
130
|
+
wide_search_limit = wide_search_top_k if non_global_search else None
|
|
123
131
|
|
|
124
132
|
if collections is None:
|
|
125
133
|
collections = [
|
|
@@ -129,6 +137,9 @@ async def brute_force_triplet_search(
|
|
|
129
137
|
"DocumentChunk_text",
|
|
130
138
|
]
|
|
131
139
|
|
|
140
|
+
if "EdgeType_relationship_name" not in collections:
|
|
141
|
+
collections.append("EdgeType_relationship_name")
|
|
142
|
+
|
|
132
143
|
try:
|
|
133
144
|
vector_engine = get_vector_engine()
|
|
134
145
|
except Exception as e:
|
|
@@ -140,7 +151,7 @@ async def brute_force_triplet_search(
|
|
|
140
151
|
async def search_in_collection(collection_name: str):
|
|
141
152
|
try:
|
|
142
153
|
return await vector_engine.search(
|
|
143
|
-
collection_name=collection_name, query_vector=query_vector, limit=
|
|
154
|
+
collection_name=collection_name, query_vector=query_vector, limit=wide_search_limit
|
|
144
155
|
)
|
|
145
156
|
except CollectionNotFoundError:
|
|
146
157
|
return []
|
|
@@ -156,19 +167,40 @@ async def brute_force_triplet_search(
|
|
|
156
167
|
return []
|
|
157
168
|
|
|
158
169
|
# Final statistics
|
|
159
|
-
|
|
170
|
+
vector_collection_search_time = time.time() - start_time
|
|
160
171
|
logger.info(
|
|
161
|
-
f"Vector collection retrieval completed: Retrieved distances from {sum(1 for res in results if res)} collections in {
|
|
172
|
+
f"Vector collection retrieval completed: Retrieved distances from {sum(1 for res in results if res)} collections in {vector_collection_search_time:.2f}s"
|
|
162
173
|
)
|
|
163
174
|
|
|
164
175
|
node_distances = {collection: result for collection, result in zip(collections, results)}
|
|
165
176
|
|
|
166
177
|
edge_distances = node_distances.get("EdgeType_relationship_name", None)
|
|
167
178
|
|
|
179
|
+
if wide_search_limit is not None:
|
|
180
|
+
relevant_ids_to_filter = list(
|
|
181
|
+
{
|
|
182
|
+
str(getattr(scored_node, "id"))
|
|
183
|
+
for collection_name, score_collection in node_distances.items()
|
|
184
|
+
if collection_name != "EdgeType_relationship_name"
|
|
185
|
+
and isinstance(score_collection, (list, tuple))
|
|
186
|
+
for scored_node in score_collection
|
|
187
|
+
if getattr(scored_node, "id", None)
|
|
188
|
+
}
|
|
189
|
+
)
|
|
190
|
+
else:
|
|
191
|
+
relevant_ids_to_filter = None
|
|
192
|
+
|
|
193
|
+
if memory_fragment is None:
|
|
194
|
+
memory_fragment = await get_memory_fragment(
|
|
195
|
+
properties_to_project=properties_to_project,
|
|
196
|
+
node_type=node_type,
|
|
197
|
+
node_name=node_name,
|
|
198
|
+
relevant_ids_to_filter=relevant_ids_to_filter,
|
|
199
|
+
triplet_distance_penalty=triplet_distance_penalty,
|
|
200
|
+
)
|
|
201
|
+
|
|
168
202
|
await memory_fragment.map_vector_distances_to_graph_nodes(node_distances=node_distances)
|
|
169
|
-
await memory_fragment.map_vector_distances_to_graph_edges(
|
|
170
|
-
vector_engine=vector_engine, query_vector=query_vector, edge_distances=edge_distances
|
|
171
|
-
)
|
|
203
|
+
await memory_fragment.map_vector_distances_to_graph_edges(edge_distances=edge_distances)
|
|
172
204
|
|
|
173
205
|
results = await memory_fragment.calculate_top_triplet_importances(k=top_k)
|
|
174
206
|
|
|
@@ -18,6 +18,8 @@ async def run_custom_pipeline(
|
|
|
18
18
|
user: User = None,
|
|
19
19
|
vector_db_config: Optional[dict] = None,
|
|
20
20
|
graph_db_config: Optional[dict] = None,
|
|
21
|
+
use_pipeline_cache: bool = False,
|
|
22
|
+
incremental_loading: bool = False,
|
|
21
23
|
data_per_batch: int = 20,
|
|
22
24
|
run_in_background: bool = False,
|
|
23
25
|
pipeline_name: str = "custom_pipeline",
|
|
@@ -40,6 +42,10 @@ async def run_custom_pipeline(
|
|
|
40
42
|
user: User context for authentication and data access. Uses default if None.
|
|
41
43
|
vector_db_config: Custom vector database configuration for embeddings storage.
|
|
42
44
|
graph_db_config: Custom graph database configuration for relationship storage.
|
|
45
|
+
use_pipeline_cache: If True, pipelines with the same ID that are currently executing and pipelines with the same ID that were completed won't process data again.
|
|
46
|
+
Pipelines ID is created based on the generate_pipeline_id function. Pipeline status can be manually reset with the reset_dataset_pipeline_run_status function.
|
|
47
|
+
incremental_loading: If True, only new or modified data will be processed to avoid duplication. (Only works if data is used with the Cognee python Data model).
|
|
48
|
+
The incremental system stores and compares hashes of processed data in the Data model and skips data with the same content hash.
|
|
43
49
|
data_per_batch: Number of data items to be processed in parallel.
|
|
44
50
|
run_in_background: If True, starts processing asynchronously and returns immediately.
|
|
45
51
|
If False, waits for completion before returning.
|
|
@@ -63,7 +69,8 @@ async def run_custom_pipeline(
|
|
|
63
69
|
datasets=dataset,
|
|
64
70
|
vector_db_config=vector_db_config,
|
|
65
71
|
graph_db_config=graph_db_config,
|
|
66
|
-
|
|
72
|
+
use_pipeline_cache=use_pipeline_cache,
|
|
73
|
+
incremental_loading=incremental_loading,
|
|
67
74
|
data_per_batch=data_per_batch,
|
|
68
75
|
pipeline_name=pipeline_name,
|
|
69
76
|
)
|
|
@@ -2,6 +2,7 @@ import os
|
|
|
2
2
|
from typing import Callable, List, Optional, Type
|
|
3
3
|
|
|
4
4
|
from cognee.modules.engine.models.node_set import NodeSet
|
|
5
|
+
from cognee.modules.retrieval.triplet_retriever import TripletRetriever
|
|
5
6
|
from cognee.modules.search.types import SearchType
|
|
6
7
|
from cognee.modules.search.operations import select_search_type
|
|
7
8
|
from cognee.modules.search.exceptions import UnsupportedSearchTypeError
|
|
@@ -22,7 +23,6 @@ from cognee.modules.retrieval.graph_completion_cot_retriever import GraphComplet
|
|
|
22
23
|
from cognee.modules.retrieval.graph_completion_context_extension_retriever import (
|
|
23
24
|
GraphCompletionContextExtensionRetriever,
|
|
24
25
|
)
|
|
25
|
-
from cognee.modules.retrieval.code_retriever import CodeRetriever
|
|
26
26
|
from cognee.modules.retrieval.cypher_search_retriever import CypherSearchRetriever
|
|
27
27
|
from cognee.modules.retrieval.natural_language_retriever import NaturalLanguageRetriever
|
|
28
28
|
|
|
@@ -37,6 +37,8 @@ async def get_search_type_tools(
|
|
|
37
37
|
node_name: Optional[List[str]] = None,
|
|
38
38
|
save_interaction: bool = False,
|
|
39
39
|
last_k: Optional[int] = None,
|
|
40
|
+
wide_search_top_k: Optional[int] = 100,
|
|
41
|
+
triplet_distance_penalty: Optional[float] = 3.5,
|
|
40
42
|
) -> list:
|
|
41
43
|
search_tasks: dict[SearchType, List[Callable]] = {
|
|
42
44
|
SearchType.SUMMARIES: [
|
|
@@ -59,6 +61,18 @@ async def get_search_type_tools(
|
|
|
59
61
|
system_prompt=system_prompt,
|
|
60
62
|
).get_context,
|
|
61
63
|
],
|
|
64
|
+
SearchType.TRIPLET_COMPLETION: [
|
|
65
|
+
TripletRetriever(
|
|
66
|
+
system_prompt_path=system_prompt_path,
|
|
67
|
+
top_k=top_k,
|
|
68
|
+
system_prompt=system_prompt,
|
|
69
|
+
).get_completion,
|
|
70
|
+
TripletRetriever(
|
|
71
|
+
system_prompt_path=system_prompt_path,
|
|
72
|
+
top_k=top_k,
|
|
73
|
+
system_prompt=system_prompt,
|
|
74
|
+
).get_context,
|
|
75
|
+
],
|
|
62
76
|
SearchType.GRAPH_COMPLETION: [
|
|
63
77
|
GraphCompletionRetriever(
|
|
64
78
|
system_prompt_path=system_prompt_path,
|
|
@@ -67,6 +81,8 @@ async def get_search_type_tools(
|
|
|
67
81
|
node_name=node_name,
|
|
68
82
|
save_interaction=save_interaction,
|
|
69
83
|
system_prompt=system_prompt,
|
|
84
|
+
wide_search_top_k=wide_search_top_k,
|
|
85
|
+
triplet_distance_penalty=triplet_distance_penalty,
|
|
70
86
|
).get_completion,
|
|
71
87
|
GraphCompletionRetriever(
|
|
72
88
|
system_prompt_path=system_prompt_path,
|
|
@@ -75,6 +91,8 @@ async def get_search_type_tools(
|
|
|
75
91
|
node_name=node_name,
|
|
76
92
|
save_interaction=save_interaction,
|
|
77
93
|
system_prompt=system_prompt,
|
|
94
|
+
wide_search_top_k=wide_search_top_k,
|
|
95
|
+
triplet_distance_penalty=triplet_distance_penalty,
|
|
78
96
|
).get_context,
|
|
79
97
|
],
|
|
80
98
|
SearchType.GRAPH_COMPLETION_COT: [
|
|
@@ -85,6 +103,8 @@ async def get_search_type_tools(
|
|
|
85
103
|
node_name=node_name,
|
|
86
104
|
save_interaction=save_interaction,
|
|
87
105
|
system_prompt=system_prompt,
|
|
106
|
+
wide_search_top_k=wide_search_top_k,
|
|
107
|
+
triplet_distance_penalty=triplet_distance_penalty,
|
|
88
108
|
).get_completion,
|
|
89
109
|
GraphCompletionCotRetriever(
|
|
90
110
|
system_prompt_path=system_prompt_path,
|
|
@@ -93,6 +113,8 @@ async def get_search_type_tools(
|
|
|
93
113
|
node_name=node_name,
|
|
94
114
|
save_interaction=save_interaction,
|
|
95
115
|
system_prompt=system_prompt,
|
|
116
|
+
wide_search_top_k=wide_search_top_k,
|
|
117
|
+
triplet_distance_penalty=triplet_distance_penalty,
|
|
96
118
|
).get_context,
|
|
97
119
|
],
|
|
98
120
|
SearchType.GRAPH_COMPLETION_CONTEXT_EXTENSION: [
|
|
@@ -103,6 +125,8 @@ async def get_search_type_tools(
|
|
|
103
125
|
node_name=node_name,
|
|
104
126
|
save_interaction=save_interaction,
|
|
105
127
|
system_prompt=system_prompt,
|
|
128
|
+
wide_search_top_k=wide_search_top_k,
|
|
129
|
+
triplet_distance_penalty=triplet_distance_penalty,
|
|
106
130
|
).get_completion,
|
|
107
131
|
GraphCompletionContextExtensionRetriever(
|
|
108
132
|
system_prompt_path=system_prompt_path,
|
|
@@ -111,6 +135,8 @@ async def get_search_type_tools(
|
|
|
111
135
|
node_name=node_name,
|
|
112
136
|
save_interaction=save_interaction,
|
|
113
137
|
system_prompt=system_prompt,
|
|
138
|
+
wide_search_top_k=wide_search_top_k,
|
|
139
|
+
triplet_distance_penalty=triplet_distance_penalty,
|
|
114
140
|
).get_context,
|
|
115
141
|
],
|
|
116
142
|
SearchType.GRAPH_SUMMARY_COMPLETION: [
|
|
@@ -121,6 +147,8 @@ async def get_search_type_tools(
|
|
|
121
147
|
node_name=node_name,
|
|
122
148
|
save_interaction=save_interaction,
|
|
123
149
|
system_prompt=system_prompt,
|
|
150
|
+
wide_search_top_k=wide_search_top_k,
|
|
151
|
+
triplet_distance_penalty=triplet_distance_penalty,
|
|
124
152
|
).get_completion,
|
|
125
153
|
GraphSummaryCompletionRetriever(
|
|
126
154
|
system_prompt_path=system_prompt_path,
|
|
@@ -129,12 +157,10 @@ async def get_search_type_tools(
|
|
|
129
157
|
node_name=node_name,
|
|
130
158
|
save_interaction=save_interaction,
|
|
131
159
|
system_prompt=system_prompt,
|
|
160
|
+
wide_search_top_k=wide_search_top_k,
|
|
161
|
+
triplet_distance_penalty=triplet_distance_penalty,
|
|
132
162
|
).get_context,
|
|
133
163
|
],
|
|
134
|
-
SearchType.CODE: [
|
|
135
|
-
CodeRetriever(top_k=top_k).get_completion,
|
|
136
|
-
CodeRetriever(top_k=top_k).get_context,
|
|
137
|
-
],
|
|
138
164
|
SearchType.CYPHER: [
|
|
139
165
|
CypherSearchRetriever().get_completion,
|
|
140
166
|
CypherSearchRetriever().get_context,
|
|
@@ -145,8 +171,16 @@ async def get_search_type_tools(
|
|
|
145
171
|
],
|
|
146
172
|
SearchType.FEEDBACK: [UserQAFeedback(last_k=last_k).add_feedback],
|
|
147
173
|
SearchType.TEMPORAL: [
|
|
148
|
-
TemporalRetriever(
|
|
149
|
-
|
|
174
|
+
TemporalRetriever(
|
|
175
|
+
top_k=top_k,
|
|
176
|
+
wide_search_top_k=wide_search_top_k,
|
|
177
|
+
triplet_distance_penalty=triplet_distance_penalty,
|
|
178
|
+
).get_completion,
|
|
179
|
+
TemporalRetriever(
|
|
180
|
+
top_k=top_k,
|
|
181
|
+
wide_search_top_k=wide_search_top_k,
|
|
182
|
+
triplet_distance_penalty=triplet_distance_penalty,
|
|
183
|
+
).get_context,
|
|
150
184
|
],
|
|
151
185
|
SearchType.CHUNKS_LEXICAL: (
|
|
152
186
|
lambda _r=JaccardChunksRetriever(top_k=top_k): [
|
|
@@ -169,7 +203,19 @@ async def get_search_type_tools(
|
|
|
169
203
|
):
|
|
170
204
|
raise UnsupportedSearchTypeError("Cypher query search types are disabled.")
|
|
171
205
|
|
|
172
|
-
|
|
206
|
+
from cognee.modules.retrieval.registered_community_retrievers import (
|
|
207
|
+
registered_community_retrievers,
|
|
208
|
+
)
|
|
209
|
+
|
|
210
|
+
if query_type in registered_community_retrievers:
|
|
211
|
+
retriever = registered_community_retrievers[query_type]
|
|
212
|
+
retriever_instance = retriever(top_k=top_k)
|
|
213
|
+
search_type_tools = [
|
|
214
|
+
retriever_instance.get_completion,
|
|
215
|
+
retriever_instance.get_context,
|
|
216
|
+
]
|
|
217
|
+
else:
|
|
218
|
+
search_type_tools = search_tasks.get(query_type)
|
|
173
219
|
|
|
174
220
|
if not search_type_tools:
|
|
175
221
|
raise UnsupportedSearchTypeError(str(query_type))
|
|
@@ -24,6 +24,8 @@ async def no_access_control_search(
|
|
|
24
24
|
last_k: Optional[int] = None,
|
|
25
25
|
only_context: bool = False,
|
|
26
26
|
session_id: Optional[str] = None,
|
|
27
|
+
wide_search_top_k: Optional[int] = 100,
|
|
28
|
+
triplet_distance_penalty: Optional[float] = 3.5,
|
|
27
29
|
) -> Tuple[Any, Union[str, List[Edge]], List[Dataset]]:
|
|
28
30
|
search_tools = await get_search_type_tools(
|
|
29
31
|
query_type=query_type,
|
|
@@ -35,6 +37,8 @@ async def no_access_control_search(
|
|
|
35
37
|
node_name=node_name,
|
|
36
38
|
save_interaction=save_interaction,
|
|
37
39
|
last_k=last_k,
|
|
40
|
+
wide_search_top_k=wide_search_top_k,
|
|
41
|
+
triplet_distance_penalty=triplet_distance_penalty,
|
|
38
42
|
)
|
|
39
43
|
graph_engine = await get_graph_engine()
|
|
40
44
|
is_empty = await graph_engine.is_empty()
|
|
@@ -47,6 +47,9 @@ async def search(
|
|
|
47
47
|
only_context: bool = False,
|
|
48
48
|
use_combined_context: bool = False,
|
|
49
49
|
session_id: Optional[str] = None,
|
|
50
|
+
wide_search_top_k: Optional[int] = 100,
|
|
51
|
+
triplet_distance_penalty: Optional[float] = 3.5,
|
|
52
|
+
verbose: bool = False,
|
|
50
53
|
) -> Union[CombinedSearchResult, List[SearchResult]]:
|
|
51
54
|
"""
|
|
52
55
|
|
|
@@ -90,6 +93,8 @@ async def search(
|
|
|
90
93
|
only_context=only_context,
|
|
91
94
|
use_combined_context=use_combined_context,
|
|
92
95
|
session_id=session_id,
|
|
96
|
+
wide_search_top_k=wide_search_top_k,
|
|
97
|
+
triplet_distance_penalty=triplet_distance_penalty,
|
|
93
98
|
)
|
|
94
99
|
else:
|
|
95
100
|
search_results = [
|
|
@@ -105,6 +110,8 @@ async def search(
|
|
|
105
110
|
last_k=last_k,
|
|
106
111
|
only_context=only_context,
|
|
107
112
|
session_id=session_id,
|
|
113
|
+
wide_search_top_k=wide_search_top_k,
|
|
114
|
+
triplet_distance_penalty=triplet_distance_penalty,
|
|
108
115
|
)
|
|
109
116
|
]
|
|
110
117
|
|
|
@@ -134,6 +141,7 @@ async def search(
|
|
|
134
141
|
)
|
|
135
142
|
|
|
136
143
|
if use_combined_context:
|
|
144
|
+
# Note: combined context search must always be verbose and return a CombinedSearchResult with graphs info
|
|
137
145
|
prepared_search_results = await prepare_search_result(
|
|
138
146
|
search_results[0] if isinstance(search_results, list) else search_results
|
|
139
147
|
)
|
|
@@ -167,25 +175,30 @@ async def search(
|
|
|
167
175
|
datasets = prepared_search_results["datasets"]
|
|
168
176
|
|
|
169
177
|
if only_context:
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
178
|
+
search_result_dict = {
|
|
179
|
+
"search_result": [context] if context else None,
|
|
180
|
+
"dataset_id": datasets[0].id,
|
|
181
|
+
"dataset_name": datasets[0].name,
|
|
182
|
+
"dataset_tenant_id": datasets[0].tenant_id,
|
|
183
|
+
}
|
|
184
|
+
if verbose:
|
|
185
|
+
# Include graphs only in verbose mode
|
|
186
|
+
search_result_dict["graphs"] = graphs
|
|
187
|
+
|
|
188
|
+
return_value.append(search_result_dict)
|
|
179
189
|
else:
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
190
|
+
search_result_dict = {
|
|
191
|
+
"search_result": [result] if result else None,
|
|
192
|
+
"dataset_id": datasets[0].id,
|
|
193
|
+
"dataset_name": datasets[0].name,
|
|
194
|
+
"dataset_tenant_id": datasets[0].tenant_id,
|
|
195
|
+
}
|
|
196
|
+
if verbose:
|
|
197
|
+
# Include graphs only in verbose mode
|
|
198
|
+
search_result_dict["graphs"] = graphs
|
|
199
|
+
|
|
200
|
+
return_value.append(search_result_dict)
|
|
201
|
+
|
|
189
202
|
return return_value
|
|
190
203
|
else:
|
|
191
204
|
return_value = []
|
|
@@ -219,6 +232,8 @@ async def authorized_search(
|
|
|
219
232
|
only_context: bool = False,
|
|
220
233
|
use_combined_context: bool = False,
|
|
221
234
|
session_id: Optional[str] = None,
|
|
235
|
+
wide_search_top_k: Optional[int] = 100,
|
|
236
|
+
triplet_distance_penalty: Optional[float] = 3.5,
|
|
222
237
|
) -> Union[
|
|
223
238
|
Tuple[Any, Union[List[Edge], str], List[Dataset]],
|
|
224
239
|
List[Tuple[Any, Union[List[Edge], str], List[Dataset]]],
|
|
@@ -246,6 +261,8 @@ async def authorized_search(
|
|
|
246
261
|
last_k=last_k,
|
|
247
262
|
only_context=True,
|
|
248
263
|
session_id=session_id,
|
|
264
|
+
wide_search_top_k=wide_search_top_k,
|
|
265
|
+
triplet_distance_penalty=triplet_distance_penalty,
|
|
249
266
|
)
|
|
250
267
|
|
|
251
268
|
context = {}
|
|
@@ -267,6 +284,8 @@ async def authorized_search(
|
|
|
267
284
|
node_name=node_name,
|
|
268
285
|
save_interaction=save_interaction,
|
|
269
286
|
last_k=last_k,
|
|
287
|
+
wide_search_top_k=wide_search_top_k,
|
|
288
|
+
triplet_distance_penalty=triplet_distance_penalty,
|
|
270
289
|
)
|
|
271
290
|
search_tools = specific_search_tools
|
|
272
291
|
if len(search_tools) == 2:
|
|
@@ -306,6 +325,7 @@ async def authorized_search(
|
|
|
306
325
|
last_k=last_k,
|
|
307
326
|
only_context=only_context,
|
|
308
327
|
session_id=session_id,
|
|
328
|
+
wide_search_top_k=wide_search_top_k,
|
|
309
329
|
)
|
|
310
330
|
|
|
311
331
|
return search_results
|
|
@@ -325,6 +345,8 @@ async def search_in_datasets_context(
|
|
|
325
345
|
only_context: bool = False,
|
|
326
346
|
context: Optional[Any] = None,
|
|
327
347
|
session_id: Optional[str] = None,
|
|
348
|
+
wide_search_top_k: Optional[int] = 100,
|
|
349
|
+
triplet_distance_penalty: Optional[float] = 3.5,
|
|
328
350
|
) -> List[Tuple[Any, Union[str, List[Edge]], List[Dataset]]]:
|
|
329
351
|
"""
|
|
330
352
|
Searches all provided datasets and handles setting up of appropriate database context based on permissions.
|
|
@@ -345,6 +367,8 @@ async def search_in_datasets_context(
|
|
|
345
367
|
only_context: bool = False,
|
|
346
368
|
context: Optional[Any] = None,
|
|
347
369
|
session_id: Optional[str] = None,
|
|
370
|
+
wide_search_top_k: Optional[int] = 100,
|
|
371
|
+
triplet_distance_penalty: Optional[float] = 3.5,
|
|
348
372
|
) -> Tuple[Any, Union[str, List[Edge]], List[Dataset]]:
|
|
349
373
|
# Set database configuration in async context for each dataset user has access for
|
|
350
374
|
await set_database_global_context_variables(dataset.id, dataset.owner_id)
|
|
@@ -378,6 +402,8 @@ async def search_in_datasets_context(
|
|
|
378
402
|
node_name=node_name,
|
|
379
403
|
save_interaction=save_interaction,
|
|
380
404
|
last_k=last_k,
|
|
405
|
+
wide_search_top_k=wide_search_top_k,
|
|
406
|
+
triplet_distance_penalty=triplet_distance_penalty,
|
|
381
407
|
)
|
|
382
408
|
search_tools = specific_search_tools
|
|
383
409
|
if len(search_tools) == 2:
|
|
@@ -413,6 +439,8 @@ async def search_in_datasets_context(
|
|
|
413
439
|
only_context=only_context,
|
|
414
440
|
context=context,
|
|
415
441
|
session_id=session_id,
|
|
442
|
+
wide_search_top_k=wide_search_top_k,
|
|
443
|
+
triplet_distance_penalty=triplet_distance_penalty,
|
|
416
444
|
)
|
|
417
445
|
)
|
|
418
446
|
|
|
@@ -5,9 +5,9 @@ class SearchType(Enum):
|
|
|
5
5
|
SUMMARIES = "SUMMARIES"
|
|
6
6
|
CHUNKS = "CHUNKS"
|
|
7
7
|
RAG_COMPLETION = "RAG_COMPLETION"
|
|
8
|
+
TRIPLET_COMPLETION = "TRIPLET_COMPLETION"
|
|
8
9
|
GRAPH_COMPLETION = "GRAPH_COMPLETION"
|
|
9
10
|
GRAPH_SUMMARY_COMPLETION = "GRAPH_SUMMARY_COMPLETION"
|
|
10
|
-
CODE = "CODE"
|
|
11
11
|
CYPHER = "CYPHER"
|
|
12
12
|
NATURAL_LANGUAGE = "NATURAL_LANGUAGE"
|
|
13
13
|
GRAPH_COMPLETION_COT = "GRAPH_COMPLETION_COT"
|