cognee 0.2.1.dev7__py3-none-any.whl → 0.2.2.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cognee/api/client.py +3 -1
- cognee/api/v1/datasets/routers/get_datasets_router.py +2 -2
- cognee/infrastructure/databases/graph/kuzu/adapter.py +31 -9
- cognee/infrastructure/databases/graph/kuzu/kuzu_migrate.py +281 -0
- cognee/infrastructure/databases/graph/neo4j_driver/adapter.py +103 -64
- cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py +10 -3
- cognee/infrastructure/databases/vector/create_vector_engine.py +3 -11
- cognee/modules/data/models/Data.py +2 -2
- cognee/modules/data/processing/document_types/UnstructuredDocument.py +2 -5
- cognee/modules/graph/cognee_graph/CogneeGraph.py +39 -20
- cognee/modules/graph/methods/get_formatted_graph_data.py +1 -1
- cognee/modules/pipelines/operations/run_tasks.py +1 -1
- cognee/modules/pipelines/operations/run_tasks_distributed.py +1 -1
- cognee/modules/retrieval/chunks_retriever.py +23 -1
- cognee/modules/retrieval/code_retriever.py +64 -5
- cognee/modules/retrieval/completion_retriever.py +12 -10
- cognee/modules/retrieval/graph_completion_retriever.py +1 -1
- cognee/modules/retrieval/insights_retriever.py +4 -0
- cognee/modules/retrieval/natural_language_retriever.py +6 -10
- cognee/modules/retrieval/summaries_retriever.py +23 -1
- cognee/modules/retrieval/utils/brute_force_triplet_search.py +23 -4
- cognee/modules/settings/get_settings.py +0 -4
- cognee/modules/settings/save_vector_db_config.py +1 -1
- cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py +84 -9
- {cognee-0.2.1.dev7.dist-info → cognee-0.2.2.dev0.dist-info}/METADATA +5 -7
- {cognee-0.2.1.dev7.dist-info → cognee-0.2.2.dev0.dist-info}/RECORD +29 -29
- cognee/tests/test_weaviate.py +0 -94
- {cognee-0.2.1.dev7.dist-info → cognee-0.2.2.dev0.dist-info}/WHEEL +0 -0
- {cognee-0.2.1.dev7.dist-info → cognee-0.2.2.dev0.dist-info}/licenses/LICENSE +0 -0
- {cognee-0.2.1.dev7.dist-info → cognee-0.2.2.dev0.dist-info}/licenses/NOTICE.md +0 -0
--- a/cognee/modules/graph/cognee_graph/CogneeGraph.py
+++ b/cognee/modules/graph/cognee_graph/CogneeGraph.py

```diff
@@ -1,3 +1,4 @@
+import time
 from cognee.shared.logging_utils import get_logger
 from typing import List, Dict, Union, Optional, Type
 
@@ -8,7 +9,7 @@ from cognee.modules.graph.cognee_graph.CogneeGraphElements import Node, Edge
 from cognee.modules.graph.cognee_graph.CogneeAbstractGraph import CogneeAbstractGraph
 import heapq
 
-logger = get_logger()
+logger = get_logger("CogneeGraph")
 
 
 class CogneeGraph(CogneeAbstractGraph):
@@ -66,7 +67,13 @@ class CogneeGraph(CogneeAbstractGraph):
     ) -> None:
         if node_dimension < 1 or edge_dimension < 1:
             raise InvalidValueError(message="Dimensions must be positive integers")
+
         try:
+            import time
+
+            start_time = time.time()
+
+            # Determine projection strategy
             if node_type is not None and node_name is not None:
                 nodes_data, edges_data = await adapter.get_nodeset_subgraph(
                     node_type=node_type, node_name=node_name
@@ -83,16 +90,17 @@ class CogneeGraph(CogneeAbstractGraph):
                 nodes_data, edges_data = await adapter.get_filtered_graph_data(
                     attribute_filters=memory_fragment_filter
                 )
-
             if not nodes_data or not edges_data:
                 raise EntityNotFoundError(
                     message="Empty filtered graph projected from the database."
                 )
 
+            # Process nodes
             for node_id, properties in nodes_data:
                 node_attributes = {key: properties.get(key) for key in node_properties_to_project}
                 self.add_node(Node(str(node_id), node_attributes, dimension=node_dimension))
 
+            # Process edges
             for source_id, target_id, relationship_type, properties in edges_data:
                 source_node = self.get_node(str(source_id))
                 target_node = self.get_node(str(target_id))
@@ -113,17 +121,23 @@ class CogneeGraph(CogneeAbstractGraph):
 
                     source_node.add_skeleton_edge(edge)
                     target_node.add_skeleton_edge(edge)
-
                 else:
                     raise EntityNotFoundError(
                         message=f"Edge references nonexistent nodes: {source_id} -> {target_id}"
                     )
 
-
-
-
+            # Final statistics
+            projection_time = time.time() - start_time
+            logger.info(
+                f"Graph projection completed: {len(self.nodes)} nodes, {len(self.edges)} edges in {projection_time:.2f}s"
+            )
+
+        except Exception as e:
+            logger.error(f"Error during graph projection: {str(e)}")
+            raise
 
     async def map_vector_distances_to_graph_nodes(self, node_distances) -> None:
+        mapped_nodes = 0
         for category, scored_results in node_distances.items():
             for scored_result in scored_results:
                 node_id = str(scored_result.id)
@@ -131,36 +145,41 @@ class CogneeGraph(CogneeAbstractGraph):
                 node = self.get_node(node_id)
                 if node:
                     node.add_attribute("vector_distance", score)
+                    mapped_nodes += 1
 
-    async def map_vector_distances_to_graph_edges(self, vector_engine, query) -> None:
+    async def map_vector_distances_to_graph_edges(
+        self, vector_engine, query_vector, edge_distances
+    ) -> None:
         try:
-            query_vector = await vector_engine.embed_data([query])
-            query_vector = query_vector[0]
             if query_vector is None or len(query_vector) == 0:
                 raise ValueError("Failed to generate query embedding.")
 
-            edge_distances = await vector_engine.search(
-                collection_name="EdgeType_relationship_name",
-                query_vector=query_vector,
-                limit=0,
-            )
+            if edge_distances is None:
+                start_time = time.time()
+                edge_distances = await vector_engine.search(
+                    collection_name="EdgeType_relationship_name",
+                    query_vector=query_vector,
+                    limit=0,
+                )
+                projection_time = time.time() - start_time
+                logger.info(
+                    f"Edge collection distances were calculated separately from nodes in {projection_time:.2f}s"
+                )
 
             embedding_map = {result.payload["text"]: result.score for result in edge_distances}
 
             for edge in self.edges:
                 relationship_type = edge.attributes.get("relationship_type")
-                if relationship_type is None or relationship_type not in embedding_map:
-
-                    continue
-
-                edge.attributes["vector_distance"] = embedding_map[relationship_type]
+                if relationship_type and relationship_type in embedding_map:
+                    edge.attributes["vector_distance"] = embedding_map[relationship_type]
 
         except Exception as ex:
-
+            logger.error(f"Error mapping vector distances to edges: {str(ex)}")
             raise ex
 
     async def calculate_top_triplet_importances(self, k: int) -> List:
         min_heap = []
+
         for i, edge in enumerate(self.edges):
             source_node = self.get_node(edge.node1.id)
             target_node = self.get_node(edge.node2.id)
```
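The tail of this hunk shows `calculate_top_triplet_importances` seeding a `min_heap` (the `heapq` import is kept in the second hunk above). For readers unfamiliar with the idiom, here is a minimal standalone sketch of bounded-heap top-k selection; the triplet strings and scores are invented for illustration, and this is not cognee's actual scoring code:

```python
import heapq

def top_k_triplets(scored_triplets, k):
    """Keep the k highest-scoring items with a heap bounded at k entries (O(n log k))."""
    min_heap = []
    for i, (score, triplet) in enumerate(scored_triplets):
        # The index i breaks ties so the payloads themselves are never compared.
        if len(min_heap) < k:
            heapq.heappush(min_heap, (score, i, triplet))
        else:
            heapq.heappushpop(min_heap, (score, i, triplet))  # drop the current minimum
    return [entry[2] for entry in sorted(min_heap, reverse=True)]

print(top_k_triplets([(0.2, "a-knows-b"), (0.9, "b-likes-c"), (0.5, "c-sees-d")], 2))
# ['b-likes-c', 'c-sees-d']
```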
--- a/cognee/modules/pipelines/operations/run_tasks_distributed.py
+++ b/cognee/modules/pipelines/operations/run_tasks_distributed.py

```diff
@@ -44,7 +44,7 @@ if modal:
 
     async def run_tasks_distributed(tasks, dataset_id, data, user, pipeline_name, context):
         if not user:
-            user = get_default_user()
+            user = await get_default_user()
 
         db_engine = get_relational_engine()
         async with db_engine.get_async_session() as session:
```
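This one-word fix addresses a classic asyncio bug: calling an async function without `await` returns a coroutine object, so the old code assigned a coroutine (always truthy, never a user record) to `user`. A minimal reproduction with a hypothetical `get_default_user`:

```python
import asyncio

async def get_default_user():
    # Hypothetical stand-in for cognee's helper of the same name.
    return {"id": "default_user"}

async def main():
    user = get_default_user()        # bug: no await, user is a coroutine object
    print(type(user).__name__)       # coroutine
    user.close()                     # suppress the "never awaited" warning
    user = await get_default_user()  # fixed, as in the diff
    print(user["id"])                # default_user

asyncio.run(main())
```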
--- a/cognee/modules/retrieval/chunks_retriever.py
+++ b/cognee/modules/retrieval/chunks_retriever.py

```diff
@@ -1,10 +1,13 @@
 from typing import Any, Optional
 
+from cognee.shared.logging_utils import get_logger
 from cognee.infrastructure.databases.vector import get_vector_engine
 from cognee.modules.retrieval.base_retriever import BaseRetriever
 from cognee.modules.retrieval.exceptions.exceptions import NoDataError
 from cognee.infrastructure.databases.vector.exceptions.exceptions import CollectionNotFoundError
 
+logger = get_logger("ChunksRetriever")
+
 
 class ChunksRetriever(BaseRetriever):
     """
@@ -41,14 +44,22 @@ class ChunksRetriever(BaseRetriever):
 
             - Any: A list of document chunk payloads retrieved from the search.
         """
+        logger.info(
+            f"Starting chunk retrieval for query: '{query[:100]}{'...' if len(query) > 100 else ''}'"
+        )
+
         vector_engine = get_vector_engine()
 
         try:
             found_chunks = await vector_engine.search("DocumentChunk_text", query, limit=self.top_k)
+            logger.info(f"Found {len(found_chunks)} chunks from vector search")
         except CollectionNotFoundError as error:
+            logger.error("DocumentChunk_text collection not found in vector database")
             raise NoDataError("No data found in the system, please add data first.") from error
 
-        return [result.payload for result in found_chunks]
+        chunk_payloads = [result.payload for result in found_chunks]
+        logger.info(f"Returning {len(chunk_payloads)} chunk payloads")
+        return chunk_payloads
 
     async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
         """
@@ -70,6 +81,17 @@ class ChunksRetriever(BaseRetriever):
             - Any: The context used for the completion or the retrieved context if none was
             provided.
         """
+        logger.info(
+            f"Starting completion generation for query: '{query[:100]}{'...' if len(query) > 100 else ''}'"
+        )
+
         if context is None:
+            logger.debug("No context provided, retrieving context from vector database")
             context = await self.get_context(query)
+        else:
+            logger.debug("Using provided context")
+
+        logger.info(
+            f"Returning context with {len(context) if isinstance(context, list) else 1} item(s)"
+        )
         return context
```
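The new log lines all repeat the same inline truncation expression, `query[:100]` plus a conditional ellipsis. As a sketch only (cognee inlines the expression rather than defining a helper), the pattern behaves like this:

```python
def preview(text: str, limit: int = 100) -> str:
    """Truncate text for log output, appending '...' only when something was cut."""
    return f"{text[:limit]}{'...' if len(text) > limit else ''}"

print(preview("short query"))  # printed unchanged: fewer than 100 chars
print(preview("x" * 150))      # first 100 chars followed by '...'
```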
--- a/cognee/modules/retrieval/code_retriever.py
+++ b/cognee/modules/retrieval/code_retriever.py

```diff
@@ -3,12 +3,15 @@ import asyncio
 import aiofiles
 from pydantic import BaseModel
 
+from cognee.shared.logging_utils import get_logger
 from cognee.modules.retrieval.base_retriever import BaseRetriever
 from cognee.infrastructure.databases.graph import get_graph_engine
 from cognee.infrastructure.databases.vector import get_vector_engine
 from cognee.infrastructure.llm.get_llm_client import get_llm_client
 from cognee.infrastructure.llm.prompts import read_query_prompt
 
+logger = get_logger("CodeRetriever")
+
 
 class CodeRetriever(BaseRetriever):
     """Retriever for handling code-based searches."""
@@ -35,26 +38,43 @@ class CodeRetriever(BaseRetriever):
 
     async def _process_query(self, query: str) -> "CodeRetriever.CodeQueryInfo":
         """Process the query using LLM to extract file names and source code parts."""
+        logger.debug(
+            f"Processing query with LLM: '{query[:100]}{'...' if len(query) > 100 else ''}'"
+        )
+
         system_prompt = read_query_prompt("codegraph_retriever_system.txt")
         llm_client = get_llm_client()
+
         try:
-            return await llm_client.acreate_structured_output(
+            result = await llm_client.acreate_structured_output(
                 text_input=query,
                 system_prompt=system_prompt,
                 response_model=self.CodeQueryInfo,
             )
+            logger.info(
+                f"LLM extracted {len(result.filenames)} filenames and {len(result.sourcecode)} chars of source code"
+            )
+            return result
         except Exception as e:
+            logger.error(f"Failed to retrieve structured output from LLM: {str(e)}")
             raise RuntimeError("Failed to retrieve structured output from LLM") from e
 
     async def get_context(self, query: str) -> Any:
         """Find relevant code files based on the query."""
+        logger.info(
+            f"Starting code retrieval for query: '{query[:100]}{'...' if len(query) > 100 else ''}'"
+        )
+
         if not query or not isinstance(query, str):
+            logger.error("Invalid query: must be a non-empty string")
             raise ValueError("The query must be a non-empty string.")
 
         try:
             vector_engine = get_vector_engine()
             graph_engine = await get_graph_engine()
+            logger.debug("Successfully initialized vector and graph engines")
         except Exception as e:
+            logger.error(f"Database initialization error: {str(e)}")
             raise RuntimeError("Database initialization error in code_graph_retriever, ") from e
 
         files_and_codeparts = await self._process_query(query)
@@ -63,52 +83,80 @@ class CodeRetriever(BaseRetriever):
         similar_codepieces = []
 
         if not files_and_codeparts.filenames or not files_and_codeparts.sourcecode:
+            logger.info("No specific files/code extracted from query, performing general search")
+
             for collection in self.file_name_collections:
+                logger.debug(f"Searching {collection} collection with general query")
                 search_results_file = await vector_engine.search(
                     collection, query, limit=self.top_k
                 )
+                logger.debug(f"Found {len(search_results_file)} results in {collection}")
                 for res in search_results_file:
                     similar_filenames.append(
                         {"id": res.id, "score": res.score, "payload": res.payload}
                     )
 
             for collection in self.classes_and_functions_collections:
+                logger.debug(f"Searching {collection} collection with general query")
                 search_results_code = await vector_engine.search(
                     collection, query, limit=self.top_k
                 )
+                logger.debug(f"Found {len(search_results_code)} results in {collection}")
                 for res in search_results_code:
                     similar_codepieces.append(
                         {"id": res.id, "score": res.score, "payload": res.payload}
                     )
         else:
+            logger.info(
+                f"Using extracted filenames ({len(files_and_codeparts.filenames)}) and source code for targeted search"
+            )
+
             for collection in self.file_name_collections:
                 for file_from_query in files_and_codeparts.filenames:
+                    logger.debug(f"Searching {collection} for specific file: {file_from_query}")
                     search_results_file = await vector_engine.search(
                         collection, file_from_query, limit=self.top_k
                     )
+                    logger.debug(
+                        f"Found {len(search_results_file)} results for file {file_from_query}"
+                    )
                    for res in search_results_file:
                         similar_filenames.append(
                             {"id": res.id, "score": res.score, "payload": res.payload}
                         )
 
             for collection in self.classes_and_functions_collections:
+                logger.debug(f"Searching {collection} with extracted source code")
                 search_results_code = await vector_engine.search(
                     collection, files_and_codeparts.sourcecode, limit=self.top_k
                 )
+                logger.debug(f"Found {len(search_results_code)} results for source code search")
                 for res in search_results_code:
                     similar_codepieces.append(
                         {"id": res.id, "score": res.score, "payload": res.payload}
                     )
 
+        total_items = len(similar_filenames) + len(similar_codepieces)
+        logger.info(
+            f"Total search results: {total_items} items ({len(similar_filenames)} filenames, {len(similar_codepieces)} code pieces)"
+        )
+
+        if total_items == 0:
+            logger.warning("No search results found, returning empty list")
+            return []
+
+        logger.debug("Getting graph connections for all search results")
         relevant_triplets = await asyncio.gather(
             *[
                 graph_engine.get_connections(similar_piece["id"])
                 for similar_piece in similar_filenames + similar_codepieces
             ]
         )
+        logger.info(f"Retrieved graph connections for {len(relevant_triplets)} items")
 
         paths = set()
-        for sublist in relevant_triplets:
+        for i, sublist in enumerate(relevant_triplets):
+            logger.debug(f"Processing connections for item {i}: {len(sublist)} connections")
             for tpl in sublist:
                 if isinstance(tpl, tuple) and len(tpl) >= 3:
                     if "file_path" in tpl[0]:
@@ -116,23 +164,31 @@ class CodeRetriever(BaseRetriever):
                     if "file_path" in tpl[2]:
                         paths.add(tpl[2]["file_path"])
 
+        logger.info(f"Found {len(paths)} unique file paths to read")
+
         retrieved_files = {}
         read_tasks = []
         for file_path in paths:
 
             async def read_file(fp):
                 try:
+                    logger.debug(f"Reading file: {fp}")
                     async with aiofiles.open(fp, "r", encoding="utf-8") as f:
-                        retrieved_files[fp] = await f.read()
+                        content = await f.read()
+                        retrieved_files[fp] = content
+                        logger.debug(f"Successfully read {len(content)} characters from {fp}")
                 except Exception as e:
-
+                    logger.error(f"Error reading {fp}: {e}")
                     retrieved_files[fp] = ""
 
             read_tasks.append(read_file(file_path))
 
         await asyncio.gather(*read_tasks)
+        logger.info(
+            f"Successfully read {len([f for f in retrieved_files.values() if f])} files (out of {len(paths)} total)"
+        )
 
-        return [
+        result = [
             {
                 "name": file_path,
                 "description": file_path,
@@ -141,6 +197,9 @@ class CodeRetriever(BaseRetriever):
             for file_path in paths
         ]
 
+        logger.info(f"Returning {len(result)} code file contexts")
+        return result
+
     async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
         """Returns the code files context."""
         if context is None:
```
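The file-reading section of this diff fans out one `read_file` coroutine per path, gathers them concurrently, and maps failures to empty strings. A self-contained sketch of that pattern, assuming only that `aiofiles` is installed:

```python
import asyncio
import aiofiles

async def read_files(paths):
    """Read many files concurrently; a failed read yields '' instead of raising."""
    retrieved = {}

    async def read_one(fp):
        try:
            async with aiofiles.open(fp, "r", encoding="utf-8") as f:
                retrieved[fp] = await f.read()
        except Exception:
            retrieved[fp] = ""  # swallow errors per-file, as the diff does

    await asyncio.gather(*(read_one(fp) for fp in paths))
    return retrieved

# contents = asyncio.run(read_files(["main.py", "missing.py"]))
```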
--- a/cognee/modules/retrieval/completion_retriever.py
+++ b/cognee/modules/retrieval/completion_retriever.py

```diff
@@ -1,11 +1,14 @@
 from typing import Any, Optional
 
+from cognee.shared.logging_utils import get_logger
 from cognee.infrastructure.databases.vector import get_vector_engine
 from cognee.modules.retrieval.utils.completion import generate_completion
 from cognee.modules.retrieval.base_retriever import BaseRetriever
 from cognee.modules.retrieval.exceptions.exceptions import NoDataError
 from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError
 
+logger = get_logger("CompletionRetriever")
+
 
 class CompletionRetriever(BaseRetriever):
     """
@@ -56,8 +59,10 @@ class CompletionRetriever(BaseRetriever):
 
             # Combine all chunks text returned from vector search (number of chunks is determined by top_k
             chunks_payload = [found_chunk.payload["text"] for found_chunk in found_chunks]
-            return "\n".join(chunks_payload)
+            combined_context = "\n".join(chunks_payload)
+            return combined_context
         except CollectionNotFoundError as error:
+            logger.error("DocumentChunk_text collection not found")
             raise NoDataError("No data found in the system, please add data first.") from error
 
     async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
@@ -70,22 +75,19 @@ class CompletionRetriever(BaseRetriever):
         Parameters:
         -----------
 
-            - query (str): The
-            - context (Optional[Any]): Optional context to use for generating the
-
+            - query (str): The query string to be used for generating a completion.
+            - context (Optional[Any]): Optional pre-fetched context to use for generating the
+            completion; if None, it retrieves the context for the query. (default None)
 
         Returns:
         --------
 
-            - Any:
+            - Any: The generated completion based on the provided query and context.
         """
         if context is None:
             context = await self.get_context(query)
 
         completion = await generate_completion(
-            query,
-            context=context,
-            user_prompt_path=self.user_prompt_path,
-            system_prompt_path=self.system_prompt_path,
+            query, context, self.user_prompt_path, self.system_prompt_path
         )
-        return
+        return completion
```
--- a/cognee/modules/retrieval/graph_completion_retriever.py
+++ b/cognee/modules/retrieval/graph_completion_retriever.py

```diff
@@ -10,7 +10,7 @@ from cognee.modules.retrieval.utils.completion import generate_completion
 from cognee.modules.retrieval.utils.stop_words import DEFAULT_STOP_WORDS
 from cognee.shared.logging_utils import get_logger
 
-logger = get_logger()
+logger = get_logger("GraphCompletionRetriever")
 
 
 class GraphCompletionRetriever(BaseRetriever):
```
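A recurring change in this release is replacing bare `get_logger()` calls with component-named loggers. Assuming `cognee.shared.logging_utils.get_logger` behaves like a thin wrapper over stdlib `logging.getLogger` (an assumption, not confirmed by this diff), named loggers allow per-component formatting and filtering:

```python
import logging

def get_logger(name: str = "root") -> logging.Logger:
    # Hypothetical stand-in for cognee.shared.logging_utils.get_logger.
    return logging.getLogger(name)

logging.basicConfig(format="%(name)s | %(levelname)s | %(message)s", level=logging.INFO)
get_logger("GraphCompletionRetriever").info("retrieving triplets")
# GraphCompletionRetriever | INFO | retrieving triplets
logging.getLogger("GraphCompletionRetriever").setLevel(logging.WARNING)  # silence one component
```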
--- a/cognee/modules/retrieval/insights_retriever.py
+++ b/cognee/modules/retrieval/insights_retriever.py

```diff
@@ -1,12 +1,15 @@
 import asyncio
 from typing import Any, Optional
 
+from cognee.shared.logging_utils import get_logger
 from cognee.infrastructure.databases.graph import get_graph_engine
 from cognee.infrastructure.databases.vector import get_vector_engine
 from cognee.modules.retrieval.base_retriever import BaseRetriever
 from cognee.modules.retrieval.exceptions.exceptions import NoDataError
 from cognee.infrastructure.databases.vector.exceptions.exceptions import CollectionNotFoundError
 
+logger = get_logger("InsightsRetriever")
+
 
 class InsightsRetriever(BaseRetriever):
     """
@@ -63,6 +66,7 @@ class InsightsRetriever(BaseRetriever):
                 vector_engine.search("EntityType_name", query_text=query, limit=self.top_k),
             )
         except CollectionNotFoundError as error:
+            logger.error("Entity collections not found")
             raise NoDataError("No data found in the system, please add data first.") from error
 
         results = [*results[0], *results[1]]
```
--- a/cognee/modules/retrieval/natural_language_retriever.py
+++ b/cognee/modules/retrieval/natural_language_retriever.py

```diff
@@ -1,5 +1,5 @@
 from typing import Any, Optional
-import logging
+from cognee.shared.logging_utils import get_logger
 from cognee.infrastructure.databases.graph import get_graph_engine
 from cognee.infrastructure.databases.graph.networkx.adapter import NetworkXAdapter
 from cognee.infrastructure.llm.get_llm_client import get_llm_client
@@ -8,7 +8,7 @@ from cognee.modules.retrieval.base_retriever import BaseRetriever
 from cognee.modules.retrieval.exceptions import SearchTypeNotSupported
 from cognee.infrastructure.databases.graph.graph_db_interface import GraphDBInterface
 
-logger = logging.getLogger(__name__)
+logger = get_logger("NaturalLanguageRetriever")
 
 
 class NaturalLanguageRetriever(BaseRetriever):
@@ -123,16 +123,12 @@ class NaturalLanguageRetriever(BaseRetriever):
             - Optional[Any]: Returns the context retrieved from the graph database based on the
             query.
         """
-
-        graph_engine = await get_graph_engine()
+        graph_engine = await get_graph_engine()
 
-
-
+        if isinstance(graph_engine, (NetworkXAdapter)):
+            raise SearchTypeNotSupported("Natural language search type not supported.")
 
-
-        except Exception as e:
-            logger.error("Failed to execute natural language search retrieval: %s", str(e))
-            raise e
+        return await self._execute_cypher_query(query, graph_engine)
 
     async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
         """
```
--- a/cognee/modules/retrieval/summaries_retriever.py
+++ b/cognee/modules/retrieval/summaries_retriever.py

```diff
@@ -1,10 +1,13 @@
 from typing import Any, Optional
 
+from cognee.shared.logging_utils import get_logger
 from cognee.infrastructure.databases.vector import get_vector_engine
 from cognee.modules.retrieval.base_retriever import BaseRetriever
 from cognee.modules.retrieval.exceptions.exceptions import NoDataError
 from cognee.infrastructure.databases.vector.exceptions.exceptions import CollectionNotFoundError
 
+logger = get_logger("SummariesRetriever")
+
 
 class SummariesRetriever(BaseRetriever):
     """
@@ -40,16 +43,24 @@ class SummariesRetriever(BaseRetriever):
 
             - Any: A list of payloads from the retrieved summaries.
         """
+        logger.info(
+            f"Starting summary retrieval for query: '{query[:100]}{'...' if len(query) > 100 else ''}'"
+        )
+
         vector_engine = get_vector_engine()
 
         try:
             summaries_results = await vector_engine.search(
                 "TextSummary_text", query, limit=self.top_k
             )
+            logger.info(f"Found {len(summaries_results)} summaries from vector search")
         except CollectionNotFoundError as error:
+            logger.error("TextSummary_text collection not found in vector database")
             raise NoDataError("No data found in the system, please add data first.") from error
 
-        return [summary.payload for summary in summaries_results]
+        summary_payloads = [summary.payload for summary in summaries_results]
+        logger.info(f"Returning {len(summary_payloads)} summary payloads")
+        return summary_payloads
 
     async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
         """
@@ -70,6 +81,17 @@ class SummariesRetriever(BaseRetriever):
 
             - Any: The generated completion context, which is either provided or retrieved.
         """
+        logger.info(
+            f"Starting completion generation for query: '{query[:100]}{'...' if len(query) > 100 else ''}'"
+        )
+
         if context is None:
+            logger.debug("No context provided, retrieving context from vector database")
             context = await self.get_context(query)
+        else:
+            logger.debug("Using provided context")
+
+        logger.info(
+            f"Returning context with {len(context) if isinstance(context, list) else 1} item(s)"
+        )
         return context
```
--- a/cognee/modules/retrieval/utils/brute_force_triplet_search.py
+++ b/cognee/modules/retrieval/utils/brute_force_triplet_search.py

```diff
@@ -1,4 +1,5 @@
 import asyncio
+import time
 from typing import List, Optional, Type
 
 from cognee.shared.logging_utils import get_logger, ERROR
@@ -59,13 +60,13 @@ async def get_memory_fragment(
     node_name: Optional[List[str]] = None,
 ) -> CogneeGraph:
     """Creates and initializes a CogneeGraph memory fragment with optional property projections."""
-    graph_engine = await get_graph_engine()
-    memory_fragment = CogneeGraph()
-
     if properties_to_project is None:
         properties_to_project = ["id", "description", "name", "type", "text"]
 
     try:
+        graph_engine = await get_graph_engine()
+        memory_fragment = CogneeGraph()
+
         await memory_fragment.project_graph_from_db(
             graph_engine,
             node_properties_to_project=properties_to_project,
@@ -73,7 +74,13 @@ async def get_memory_fragment(
             node_type=node_type,
             node_name=node_name,
         )
+
     except EntityNotFoundError:
+        # This is expected behavior - continue with empty fragment
+        pass
+    except Exception as e:
+        logger.error(f"Error during memory fragment creation: {str(e)}")
+        # Still return the fragment even if projection failed
         pass
 
     return memory_fragment
@@ -168,6 +175,8 @@ async def brute_force_search(
         return []
 
     try:
+        start_time = time.time()
+
         results = await asyncio.gather(
             *[search_in_collection(collection_name) for collection_name in collections]
         )
@@ -175,10 +184,20 @@ async def brute_force_search(
         if all(not item for item in results):
             return []
 
+        # Final statistics
+        projection_time = time.time() - start_time
+        logger.info(
+            f"Vector collection retrieval completed: Retrieved distances from {sum(1 for res in results if res)} collections in {projection_time:.2f}s"
+        )
+
         node_distances = {collection: result for collection, result in zip(collections, results)}
 
+        edge_distances = node_distances.get("EdgeType_relationship_name", None)
+
         await memory_fragment.map_vector_distances_to_graph_nodes(node_distances=node_distances)
-        await memory_fragment.map_vector_distances_to_graph_edges(vector_engine, query)
+        await memory_fragment.map_vector_distances_to_graph_edges(
+            vector_engine=vector_engine, query_vector=query_vector, edge_distances=edge_distances
+        )
 
         results = await memory_fragment.calculate_top_triplet_importances(k=top_k)
 
```
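Taken together with the `CogneeGraph` changes above, this refactor makes `brute_force_search` the single place that queries the vector collections: the `EdgeType_relationship_name` result is pulled out of the same `node_distances` dict and handed to `map_vector_distances_to_graph_edges`, so the query is no longer re-embedded inside the graph class. A hedged sketch of the resulting data flow, with collection names taken from this diff and the engine objects and `query_vector` assumed to exist:

```python
import asyncio

async def score_fragment(memory_fragment, vector_engine, collections, query_vector, top_k):
    # One vector search per collection, run concurrently.
    results = await asyncio.gather(
        *[
            vector_engine.search(collection_name=name, query_vector=query_vector, limit=0)
            for name in collections
        ]
    )
    node_distances = dict(zip(collections, results))

    # Reuse the edge-collection result instead of searching a second time.
    edge_distances = node_distances.get("EdgeType_relationship_name", None)

    await memory_fragment.map_vector_distances_to_graph_nodes(node_distances=node_distances)
    await memory_fragment.map_vector_distances_to_graph_edges(
        vector_engine=vector_engine, query_vector=query_vector, edge_distances=edge_distances
    )
    return await memory_fragment.calculate_top_triplet_importances(k=top_k)
```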