cognee 0.2.1.dev7__py3-none-any.whl → 0.2.2.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (30)
  1. cognee/api/client.py +3 -1
  2. cognee/api/v1/datasets/routers/get_datasets_router.py +2 -2
  3. cognee/infrastructure/databases/graph/kuzu/adapter.py +31 -9
  4. cognee/infrastructure/databases/graph/kuzu/kuzu_migrate.py +281 -0
  5. cognee/infrastructure/databases/graph/neo4j_driver/adapter.py +103 -64
  6. cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py +10 -3
  7. cognee/infrastructure/databases/vector/create_vector_engine.py +3 -11
  8. cognee/modules/data/models/Data.py +2 -2
  9. cognee/modules/data/processing/document_types/UnstructuredDocument.py +2 -5
  10. cognee/modules/graph/cognee_graph/CogneeGraph.py +39 -20
  11. cognee/modules/graph/methods/get_formatted_graph_data.py +1 -1
  12. cognee/modules/pipelines/operations/run_tasks.py +1 -1
  13. cognee/modules/pipelines/operations/run_tasks_distributed.py +1 -1
  14. cognee/modules/retrieval/chunks_retriever.py +23 -1
  15. cognee/modules/retrieval/code_retriever.py +64 -5
  16. cognee/modules/retrieval/completion_retriever.py +12 -10
  17. cognee/modules/retrieval/graph_completion_retriever.py +1 -1
  18. cognee/modules/retrieval/insights_retriever.py +4 -0
  19. cognee/modules/retrieval/natural_language_retriever.py +6 -10
  20. cognee/modules/retrieval/summaries_retriever.py +23 -1
  21. cognee/modules/retrieval/utils/brute_force_triplet_search.py +23 -4
  22. cognee/modules/settings/get_settings.py +0 -4
  23. cognee/modules/settings/save_vector_db_config.py +1 -1
  24. cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py +84 -9
  25. {cognee-0.2.1.dev7.dist-info → cognee-0.2.2.dev0.dist-info}/METADATA +5 -7
  26. {cognee-0.2.1.dev7.dist-info → cognee-0.2.2.dev0.dist-info}/RECORD +29 -29
  27. cognee/tests/test_weaviate.py +0 -94
  28. {cognee-0.2.1.dev7.dist-info → cognee-0.2.2.dev0.dist-info}/WHEEL +0 -0
  29. {cognee-0.2.1.dev7.dist-info → cognee-0.2.2.dev0.dist-info}/licenses/LICENSE +0 -0
  30. {cognee-0.2.1.dev7.dist-info → cognee-0.2.2.dev0.dist-info}/licenses/NOTICE.md +0 -0

cognee/modules/graph/cognee_graph/CogneeGraph.py

@@ -1,3 +1,4 @@
+import time
 from cognee.shared.logging_utils import get_logger
 from typing import List, Dict, Union, Optional, Type
 
@@ -8,7 +9,7 @@ from cognee.modules.graph.cognee_graph.CogneeGraphElements import Node, Edge
 from cognee.modules.graph.cognee_graph.CogneeAbstractGraph import CogneeAbstractGraph
 import heapq
 
-logger = get_logger()
+logger = get_logger("CogneeGraph")
 
 
 class CogneeGraph(CogneeAbstractGraph):
@@ -66,7 +67,13 @@ class CogneeGraph(CogneeAbstractGraph):
     ) -> None:
         if node_dimension < 1 or edge_dimension < 1:
             raise InvalidValueError(message="Dimensions must be positive integers")
+
         try:
+            import time
+
+            start_time = time.time()
+
+            # Determine projection strategy
             if node_type is not None and node_name is not None:
                 nodes_data, edges_data = await adapter.get_nodeset_subgraph(
                     node_type=node_type, node_name=node_name
@@ -83,16 +90,17 @@
                 nodes_data, edges_data = await adapter.get_filtered_graph_data(
                     attribute_filters=memory_fragment_filter
                 )
-
             if not nodes_data or not edges_data:
                 raise EntityNotFoundError(
                     message="Empty filtered graph projected from the database."
                 )
 
+            # Process nodes
             for node_id, properties in nodes_data:
                 node_attributes = {key: properties.get(key) for key in node_properties_to_project}
                 self.add_node(Node(str(node_id), node_attributes, dimension=node_dimension))
 
+            # Process edges
             for source_id, target_id, relationship_type, properties in edges_data:
                 source_node = self.get_node(str(source_id))
                 target_node = self.get_node(str(target_id))
@@ -113,17 +121,23 @@
 
                     source_node.add_skeleton_edge(edge)
                     target_node.add_skeleton_edge(edge)
-
                 else:
                     raise EntityNotFoundError(
                         message=f"Edge references nonexistent nodes: {source_id} -> {target_id}"
                     )
 
-        except (ValueError, TypeError) as e:
-            print(f"Error projecting graph: {e}")
-            raise e
+            # Final statistics
+            projection_time = time.time() - start_time
+            logger.info(
+                f"Graph projection completed: {len(self.nodes)} nodes, {len(self.edges)} edges in {projection_time:.2f}s"
+            )
+
+        except Exception as e:
+            logger.error(f"Error during graph projection: {str(e)}")
+            raise
 
     async def map_vector_distances_to_graph_nodes(self, node_distances) -> None:
+        mapped_nodes = 0
         for category, scored_results in node_distances.items():
             for scored_result in scored_results:
                 node_id = str(scored_result.id)
@@ -131,36 +145,41 @@
             node = self.get_node(node_id)
             if node:
                 node.add_attribute("vector_distance", score)
+                mapped_nodes += 1
 
-    async def map_vector_distances_to_graph_edges(self, vector_engine, query) -> None:
+    async def map_vector_distances_to_graph_edges(
+        self, vector_engine, query_vector, edge_distances
+    ) -> None:
         try:
-            query_vector = await vector_engine.embed_data([query])
-            query_vector = query_vector[0]
             if query_vector is None or len(query_vector) == 0:
                 raise ValueError("Failed to generate query embedding.")
 
-            edge_distances = await vector_engine.search(
-                collection_name="EdgeType_relationship_name",
-                query_text=query,
-                limit=0,
-            )
+            if edge_distances is None:
+                start_time = time.time()
+                edge_distances = await vector_engine.search(
+                    collection_name="EdgeType_relationship_name",
+                    query_vector=query_vector,
+                    limit=0,
+                )
+                projection_time = time.time() - start_time
+                logger.info(
+                    f"Edge collection distances were calculated separately from nodes in {projection_time:.2f}s"
+                )
 
             embedding_map = {result.payload["text"]: result.score for result in edge_distances}
 
             for edge in self.edges:
                 relationship_type = edge.attributes.get("relationship_type")
-                if not relationship_type or relationship_type not in embedding_map:
-                    print(f"Edge {edge} has an unknown or missing relationship type.")
-                    continue
-
-                edge.attributes["vector_distance"] = embedding_map[relationship_type]
+                if relationship_type and relationship_type in embedding_map:
+                    edge.attributes["vector_distance"] = embedding_map[relationship_type]
 
         except Exception as ex:
-            print(f"Error mapping vector distances to edges: {ex}")
+            logger.error(f"Error mapping vector distances to edges: {str(ex)}")
             raise ex
 
     async def calculate_top_triplet_importances(self, k: int) -> List:
         min_heap = []
+
         for i, edge in enumerate(self.edges):
             source_node = self.get_node(edge.node1.id)
             target_node = self.get_node(edge.node2.id)
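
Note the signature change above: map_vector_distances_to_graph_edges no longer embeds the query text itself or unconditionally re-queries the EdgeType_relationship_name collection; callers pass a precomputed query_vector and may hand in edge distances that were already fetched alongside the node collections. A minimal calling sketch under those assumptions (embed_data is the same vector-engine call the old implementation used internally; the fragment is assumed to be an already-projected CogneeGraph):

    from cognee.infrastructure.databases.vector import get_vector_engine
    from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph


    async def score_edges(fragment: CogneeGraph, query: str) -> None:
        vector_engine = get_vector_engine()

        # Embed the query once; the method no longer does this for you.
        query_vector = (await vector_engine.embed_data([query]))[0]

        # edge_distances=None makes the method run its own search against
        # "EdgeType_relationship_name"; passing previously fetched results
        # (as brute_force_search now does) skips that extra round trip.
        await fragment.map_vector_distances_to_graph_edges(
            vector_engine=vector_engine,
            query_vector=query_vector,
            edge_distances=None,
        )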

cognee/modules/graph/methods/get_formatted_graph_data.py

@@ -33,7 +33,7 @@ async def get_formatted_graph_data(dataset_id: UUID, user_id: UUID):
             lambda edge: {
                 "source": str(edge[0]),
                 "target": str(edge[1]),
-                "label": edge[2],
+                "label": str(edge[2]),
             },
             edges,
         )

cognee/modules/pipelines/operations/run_tasks.py

@@ -58,7 +58,7 @@ async def run_tasks(
     context: dict = None,
 ):
     if not user:
-        user = get_default_user()
+        user = await get_default_user()
 
     # Get Dataset object
     db_engine = get_relational_engine()

cognee/modules/pipelines/operations/run_tasks_distributed.py

@@ -44,7 +44,7 @@ if modal:
 
     async def run_tasks_distributed(tasks, dataset_id, data, user, pipeline_name, context):
         if not user:
-            user = get_default_user()
+            user = await get_default_user()
 
         db_engine = get_relational_engine()
         async with db_engine.get_async_session() as session:
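
Both pipeline entry points previously assigned the un-awaited coroutine returned by get_default_user; the fix awaits it so user is an actual user object before the session is opened. A sketch of the corrected pattern (the import path is an assumption; it is not shown in this diff):

    from cognee.modules.users.methods import get_default_user  # path assumed


    async def resolve_user(user=None):
        # get_default_user is a coroutine function; without "await" the
        # pipeline would carry a coroutine object instead of a user record.
        if not user:
            user = await get_default_user()
        return user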

cognee/modules/retrieval/chunks_retriever.py

@@ -1,10 +1,13 @@
 from typing import Any, Optional
 
+from cognee.shared.logging_utils import get_logger
 from cognee.infrastructure.databases.vector import get_vector_engine
 from cognee.modules.retrieval.base_retriever import BaseRetriever
 from cognee.modules.retrieval.exceptions.exceptions import NoDataError
 from cognee.infrastructure.databases.vector.exceptions.exceptions import CollectionNotFoundError
 
+logger = get_logger("ChunksRetriever")
+
 
 class ChunksRetriever(BaseRetriever):
     """
@@ -41,14 +44,22 @@ class ChunksRetriever(BaseRetriever):
 
         - Any: A list of document chunk payloads retrieved from the search.
         """
+        logger.info(
+            f"Starting chunk retrieval for query: '{query[:100]}{'...' if len(query) > 100 else ''}'"
+        )
+
         vector_engine = get_vector_engine()
 
         try:
             found_chunks = await vector_engine.search("DocumentChunk_text", query, limit=self.top_k)
+            logger.info(f"Found {len(found_chunks)} chunks from vector search")
         except CollectionNotFoundError as error:
+            logger.error("DocumentChunk_text collection not found in vector database")
             raise NoDataError("No data found in the system, please add data first.") from error
 
-        return [result.payload for result in found_chunks]
+        chunk_payloads = [result.payload for result in found_chunks]
+        logger.info(f"Returning {len(chunk_payloads)} chunk payloads")
+        return chunk_payloads
 
     async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
         """
@@ -70,6 +81,17 @@ class ChunksRetriever(BaseRetriever):
         - Any: The context used for the completion or the retrieved context if none was
           provided.
         """
+        logger.info(
+            f"Starting completion generation for query: '{query[:100]}{'...' if len(query) > 100 else ''}'"
+        )
+
         if context is None:
+            logger.debug("No context provided, retrieving context from vector database")
             context = await self.get_context(query)
+        else:
+            logger.debug("Using provided context")
+
+        logger.info(
+            f"Returning context with {len(context) if isinstance(context, list) else 1} item(s)"
+        )
         return context
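
The retriever's behavior is otherwise unchanged; the additions are the named "ChunksRetriever" logger and the info/debug lines around the vector search. A minimal usage sketch, assuming data has already been added and cognified so the DocumentChunk_text collection exists (constructor defaults assumed):

    import asyncio

    from cognee.modules.retrieval.chunks_retriever import ChunksRetriever


    async def main():
        retriever = ChunksRetriever()
        chunks = await retriever.get_context("How is the Kuzu graph migrated?")
        # Each entry is a chunk payload dict; the same count appears in the
        # "Returning N chunk payloads" log line added above.
        print(len(chunks))


    asyncio.run(main())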

cognee/modules/retrieval/code_retriever.py

@@ -3,12 +3,15 @@ import asyncio
 import aiofiles
 from pydantic import BaseModel
 
+from cognee.shared.logging_utils import get_logger
 from cognee.modules.retrieval.base_retriever import BaseRetriever
 from cognee.infrastructure.databases.graph import get_graph_engine
 from cognee.infrastructure.databases.vector import get_vector_engine
 from cognee.infrastructure.llm.get_llm_client import get_llm_client
 from cognee.infrastructure.llm.prompts import read_query_prompt
 
+logger = get_logger("CodeRetriever")
+
 
 class CodeRetriever(BaseRetriever):
     """Retriever for handling code-based searches."""
@@ -35,26 +38,43 @@ class CodeRetriever(BaseRetriever):
 
     async def _process_query(self, query: str) -> "CodeRetriever.CodeQueryInfo":
         """Process the query using LLM to extract file names and source code parts."""
+        logger.debug(
+            f"Processing query with LLM: '{query[:100]}{'...' if len(query) > 100 else ''}'"
+        )
+
         system_prompt = read_query_prompt("codegraph_retriever_system.txt")
         llm_client = get_llm_client()
+
         try:
-            return await llm_client.acreate_structured_output(
+            result = await llm_client.acreate_structured_output(
                 text_input=query,
                 system_prompt=system_prompt,
                 response_model=self.CodeQueryInfo,
             )
+            logger.info(
+                f"LLM extracted {len(result.filenames)} filenames and {len(result.sourcecode)} chars of source code"
+            )
+            return result
         except Exception as e:
+            logger.error(f"Failed to retrieve structured output from LLM: {str(e)}")
             raise RuntimeError("Failed to retrieve structured output from LLM") from e
 
     async def get_context(self, query: str) -> Any:
         """Find relevant code files based on the query."""
+        logger.info(
+            f"Starting code retrieval for query: '{query[:100]}{'...' if len(query) > 100 else ''}'"
+        )
+
         if not query or not isinstance(query, str):
+            logger.error("Invalid query: must be a non-empty string")
             raise ValueError("The query must be a non-empty string.")
 
         try:
             vector_engine = get_vector_engine()
             graph_engine = await get_graph_engine()
+            logger.debug("Successfully initialized vector and graph engines")
         except Exception as e:
+            logger.error(f"Database initialization error: {str(e)}")
             raise RuntimeError("Database initialization error in code_graph_retriever, ") from e
 
         files_and_codeparts = await self._process_query(query)
@@ -63,52 +83,80 @@ class CodeRetriever(BaseRetriever):
         similar_codepieces = []
 
         if not files_and_codeparts.filenames or not files_and_codeparts.sourcecode:
+            logger.info("No specific files/code extracted from query, performing general search")
+
             for collection in self.file_name_collections:
+                logger.debug(f"Searching {collection} collection with general query")
                 search_results_file = await vector_engine.search(
                     collection, query, limit=self.top_k
                 )
+                logger.debug(f"Found {len(search_results_file)} results in {collection}")
                 for res in search_results_file:
                     similar_filenames.append(
                         {"id": res.id, "score": res.score, "payload": res.payload}
                     )
 
             for collection in self.classes_and_functions_collections:
+                logger.debug(f"Searching {collection} collection with general query")
                 search_results_code = await vector_engine.search(
                     collection, query, limit=self.top_k
                 )
+                logger.debug(f"Found {len(search_results_code)} results in {collection}")
                 for res in search_results_code:
                     similar_codepieces.append(
                         {"id": res.id, "score": res.score, "payload": res.payload}
                     )
         else:
+            logger.info(
+                f"Using extracted filenames ({len(files_and_codeparts.filenames)}) and source code for targeted search"
+            )
+
             for collection in self.file_name_collections:
                 for file_from_query in files_and_codeparts.filenames:
+                    logger.debug(f"Searching {collection} for specific file: {file_from_query}")
                     search_results_file = await vector_engine.search(
                         collection, file_from_query, limit=self.top_k
                     )
+                    logger.debug(
+                        f"Found {len(search_results_file)} results for file {file_from_query}"
+                    )
                    for res in search_results_file:
                         similar_filenames.append(
                             {"id": res.id, "score": res.score, "payload": res.payload}
                         )
 
             for collection in self.classes_and_functions_collections:
+                logger.debug(f"Searching {collection} with extracted source code")
                 search_results_code = await vector_engine.search(
                     collection, files_and_codeparts.sourcecode, limit=self.top_k
                 )
+                logger.debug(f"Found {len(search_results_code)} results for source code search")
                 for res in search_results_code:
                     similar_codepieces.append(
                         {"id": res.id, "score": res.score, "payload": res.payload}
                     )
 
+        total_items = len(similar_filenames) + len(similar_codepieces)
+        logger.info(
+            f"Total search results: {total_items} items ({len(similar_filenames)} filenames, {len(similar_codepieces)} code pieces)"
+        )
+
+        if total_items == 0:
+            logger.warning("No search results found, returning empty list")
+            return []
+
+        logger.debug("Getting graph connections for all search results")
         relevant_triplets = await asyncio.gather(
             *[
                 graph_engine.get_connections(similar_piece["id"])
                 for similar_piece in similar_filenames + similar_codepieces
             ]
         )
+        logger.info(f"Retrieved graph connections for {len(relevant_triplets)} items")
 
         paths = set()
-        for sublist in relevant_triplets:
+        for i, sublist in enumerate(relevant_triplets):
+            logger.debug(f"Processing connections for item {i}: {len(sublist)} connections")
             for tpl in sublist:
                 if isinstance(tpl, tuple) and len(tpl) >= 3:
                     if "file_path" in tpl[0]:
@@ -116,23 +164,31 @@ class CodeRetriever(BaseRetriever):
                     if "file_path" in tpl[2]:
                         paths.add(tpl[2]["file_path"])
 
+        logger.info(f"Found {len(paths)} unique file paths to read")
+
         retrieved_files = {}
         read_tasks = []
         for file_path in paths:
 
             async def read_file(fp):
                 try:
+                    logger.debug(f"Reading file: {fp}")
                     async with aiofiles.open(fp, "r", encoding="utf-8") as f:
-                        retrieved_files[fp] = await f.read()
+                        content = await f.read()
+                        retrieved_files[fp] = content
+                        logger.debug(f"Successfully read {len(content)} characters from {fp}")
                 except Exception as e:
-                    print(f"Error reading {fp}: {e}")
+                    logger.error(f"Error reading {fp}: {e}")
                     retrieved_files[fp] = ""
 
             read_tasks.append(read_file(file_path))
 
         await asyncio.gather(*read_tasks)
+        logger.info(
+            f"Successfully read {len([f for f in retrieved_files.values() if f])} files (out of {len(paths)} total)"
+        )
 
-        return [
+        result = [
             {
                 "name": file_path,
                 "description": file_path,
@@ -141,6 +197,9 @@ class CodeRetriever(BaseRetriever):
             for file_path in paths
         ]
 
+        logger.info(f"Returning {len(result)} code file contexts")
+        return result
+
     async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
         """Returns the code files context."""
         if context is None:
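
Beyond the logging, get_context now short-circuits with an empty list when neither the filename nor the code-part searches return anything, instead of gathering graph connections for an empty result set. A usage sketch (constructor defaults and a populated code graph are assumed):

    import asyncio

    from cognee.modules.retrieval.code_retriever import CodeRetriever


    async def main():
        retriever = CodeRetriever()
        files = await retriever.get_context("Where is the Kuzu adapter initialized?")
        if not files:
            # New behavior: an empty list (plus a logged warning) when no
            # vector search results were found.
            print("no matching code files")
        for entry in files:
            print(entry["name"])


    asyncio.run(main())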

cognee/modules/retrieval/completion_retriever.py

@@ -1,11 +1,14 @@
 from typing import Any, Optional
 
+from cognee.shared.logging_utils import get_logger
 from cognee.infrastructure.databases.vector import get_vector_engine
 from cognee.modules.retrieval.utils.completion import generate_completion
 from cognee.modules.retrieval.base_retriever import BaseRetriever
 from cognee.modules.retrieval.exceptions.exceptions import NoDataError
 from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError
 
+logger = get_logger("CompletionRetriever")
+
 
 class CompletionRetriever(BaseRetriever):
     """
@@ -56,8 +59,10 @@ class CompletionRetriever(BaseRetriever):
 
             # Combine all chunks text returned from vector search (number of chunks is determined by top_k
             chunks_payload = [found_chunk.payload["text"] for found_chunk in found_chunks]
-            return "\n".join(chunks_payload)
+            combined_context = "\n".join(chunks_payload)
+            return combined_context
         except CollectionNotFoundError as error:
+            logger.error("DocumentChunk_text collection not found")
             raise NoDataError("No data found in the system, please add data first.") from error
 
     async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
@@ -70,22 +75,19 @@ class CompletionRetriever(BaseRetriever):
         Parameters:
         -----------
 
-        - query (str): The input query for which the completion is generated.
-        - context (Optional[Any]): Optional context to use for generating the completion; if
-          not provided, it will be retrieved using get_context. (default None)
+        - query (str): The query string to be used for generating a completion.
+        - context (Optional[Any]): Optional pre-fetched context to use for generating the
+          completion; if None, it retrieves the context for the query. (default None)
 
         Returns:
         --------
 
-        - Any: A list containing the generated completion from the LLM.
+        - Any: The generated completion based on the provided query and context.
         """
         if context is None:
             context = await self.get_context(query)
 
         completion = await generate_completion(
-            query=query,
-            context=context,
-            user_prompt_path=self.user_prompt_path,
-            system_prompt_path=self.system_prompt_path,
+            query, context, self.user_prompt_path, self.system_prompt_path
         )
-        return [completion]
+        return completion
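
Note the return-type change in get_completion: it used to wrap the LLM output in a single-element list and now returns the completion directly, so callers that indexed [0] need updating. A before/after sketch (constructor defaults assumed):

    import asyncio

    from cognee.modules.retrieval.completion_retriever import CompletionRetriever


    async def main():
        retriever = CompletionRetriever()

        # 0.2.1.dev7: the completion came back as a one-element list.
        # answer = (await retriever.get_completion("What is cognee?"))[0]

        # 0.2.2.dev0: the completion is returned as-is.
        answer = await retriever.get_completion("What is cognee?")
        print(answer)


    asyncio.run(main())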

cognee/modules/retrieval/graph_completion_retriever.py

@@ -10,7 +10,7 @@ from cognee.modules.retrieval.utils.completion import generate_completion
 from cognee.modules.retrieval.utils.stop_words import DEFAULT_STOP_WORDS
 from cognee.shared.logging_utils import get_logger
 
-logger = get_logger()
+logger = get_logger("GraphCompletionRetriever")
 
 
 class GraphCompletionRetriever(BaseRetriever):

cognee/modules/retrieval/insights_retriever.py

@@ -1,12 +1,15 @@
 import asyncio
 from typing import Any, Optional
 
+from cognee.shared.logging_utils import get_logger
 from cognee.infrastructure.databases.graph import get_graph_engine
 from cognee.infrastructure.databases.vector import get_vector_engine
 from cognee.modules.retrieval.base_retriever import BaseRetriever
 from cognee.modules.retrieval.exceptions.exceptions import NoDataError
 from cognee.infrastructure.databases.vector.exceptions.exceptions import CollectionNotFoundError
 
+logger = get_logger("InsightsRetriever")
+
 
 class InsightsRetriever(BaseRetriever):
     """
@@ -63,6 +66,7 @@ class InsightsRetriever(BaseRetriever):
                 vector_engine.search("EntityType_name", query_text=query, limit=self.top_k),
             )
         except CollectionNotFoundError as error:
+            logger.error("Entity collections not found")
             raise NoDataError("No data found in the system, please add data first.") from error
 
         results = [*results[0], *results[1]]

cognee/modules/retrieval/natural_language_retriever.py

@@ -1,5 +1,5 @@
 from typing import Any, Optional
-import logging
+from cognee.shared.logging_utils import get_logger
 from cognee.infrastructure.databases.graph import get_graph_engine
 from cognee.infrastructure.databases.graph.networkx.adapter import NetworkXAdapter
 from cognee.infrastructure.llm.get_llm_client import get_llm_client
@@ -8,7 +8,7 @@ from cognee.modules.retrieval.base_retriever import BaseRetriever
 from cognee.modules.retrieval.exceptions import SearchTypeNotSupported
 from cognee.infrastructure.databases.graph.graph_db_interface import GraphDBInterface
 
-logger = logging.getLogger("NaturalLanguageRetriever")
+logger = get_logger("NaturalLanguageRetriever")
 
 
 class NaturalLanguageRetriever(BaseRetriever):
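
This is the same logging cleanup applied across the retrievers in this release: stdlib logging.getLogger and unnamed get_logger() calls are replaced with cognee's logging helper plus an explicit component name. The pattern, for reference:

    from cognee.shared.logging_utils import get_logger

    # 0.2.1.dev7 variants:
    #   import logging
    #   logger = logging.getLogger("NaturalLanguageRetriever")
    # or, elsewhere, an unnamed logger:
    #   logger = get_logger()

    # 0.2.2.dev0: every retriever names its logger after the component.
    logger = get_logger("NaturalLanguageRetriever")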
@@ -123,16 +123,12 @@ class NaturalLanguageRetriever(BaseRetriever):
         - Optional[Any]: Returns the context retrieved from the graph database based on the
           query.
         """
-        try:
-            graph_engine = await get_graph_engine()
+        graph_engine = await get_graph_engine()
 
-            if isinstance(graph_engine, (NetworkXAdapter)):
-                raise SearchTypeNotSupported("Natural language search type not supported.")
+        if isinstance(graph_engine, (NetworkXAdapter)):
+            raise SearchTypeNotSupported("Natural language search type not supported.")
 
-            return await self._execute_cypher_query(query, graph_engine)
-        except Exception as e:
-            logger.error("Failed to execute natural language search retrieval: %s", str(e))
-            raise e
+        return await self._execute_cypher_query(query, graph_engine)
 
     async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
         """

cognee/modules/retrieval/summaries_retriever.py

@@ -1,10 +1,13 @@
 from typing import Any, Optional
 
+from cognee.shared.logging_utils import get_logger
 from cognee.infrastructure.databases.vector import get_vector_engine
 from cognee.modules.retrieval.base_retriever import BaseRetriever
 from cognee.modules.retrieval.exceptions.exceptions import NoDataError
 from cognee.infrastructure.databases.vector.exceptions.exceptions import CollectionNotFoundError
 
+logger = get_logger("SummariesRetriever")
+
 
 class SummariesRetriever(BaseRetriever):
     """
@@ -40,16 +43,24 @@ class SummariesRetriever(BaseRetriever):
 
         - Any: A list of payloads from the retrieved summaries.
         """
+        logger.info(
+            f"Starting summary retrieval for query: '{query[:100]}{'...' if len(query) > 100 else ''}'"
+        )
+
         vector_engine = get_vector_engine()
 
         try:
             summaries_results = await vector_engine.search(
                 "TextSummary_text", query, limit=self.top_k
             )
+            logger.info(f"Found {len(summaries_results)} summaries from vector search")
         except CollectionNotFoundError as error:
+            logger.error("TextSummary_text collection not found in vector database")
             raise NoDataError("No data found in the system, please add data first.") from error
 
-        return [summary.payload for summary in summaries_results]
+        summary_payloads = [summary.payload for summary in summaries_results]
+        logger.info(f"Returning {len(summary_payloads)} summary payloads")
+        return summary_payloads
 
     async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
         """
@@ -70,6 +81,17 @@ class SummariesRetriever(BaseRetriever):
 
         - Any: The generated completion context, which is either provided or retrieved.
         """
+        logger.info(
+            f"Starting completion generation for query: '{query[:100]}{'...' if len(query) > 100 else ''}'"
+        )
+
         if context is None:
+            logger.debug("No context provided, retrieving context from vector database")
             context = await self.get_context(query)
+        else:
+            logger.debug("Using provided context")
+
+        logger.info(
+            f"Returning context with {len(context) if isinstance(context, list) else 1} item(s)"
+        )
         return context

cognee/modules/retrieval/utils/brute_force_triplet_search.py

@@ -1,4 +1,5 @@
 import asyncio
+import time
 from typing import List, Optional, Type
 
 from cognee.shared.logging_utils import get_logger, ERROR
@@ -59,13 +60,13 @@ async def get_memory_fragment(
     node_name: Optional[List[str]] = None,
 ) -> CogneeGraph:
     """Creates and initializes a CogneeGraph memory fragment with optional property projections."""
-    graph_engine = await get_graph_engine()
-    memory_fragment = CogneeGraph()
-
     if properties_to_project is None:
         properties_to_project = ["id", "description", "name", "type", "text"]
 
     try:
+        graph_engine = await get_graph_engine()
+        memory_fragment = CogneeGraph()
+
         await memory_fragment.project_graph_from_db(
             graph_engine,
             node_properties_to_project=properties_to_project,
@@ -73,7 +74,13 @@
             node_type=node_type,
             node_name=node_name,
         )
+
     except EntityNotFoundError:
+        # This is expected behavior - continue with empty fragment
+        pass
+    except Exception as e:
+        logger.error(f"Error during memory fragment creation: {str(e)}")
+        # Still return the fragment even if projection failed
         pass
 
     return memory_fragment
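
get_memory_fragment now treats EntityNotFoundError as the expected empty-graph case and also swallows (but logs) unexpected projection errors, returning the possibly empty fragment either way. A minimal calling sketch (all arguments left at their defaults, which is an assumption about parameters not shown in this hunk):

    import asyncio

    from cognee.modules.retrieval.utils.brute_force_triplet_search import get_memory_fragment


    async def main():
        fragment = await get_memory_fragment()
        # With an empty or missing graph this no longer raises; the fragment
        # simply has nothing projected into it.
        print(f"{len(fragment.nodes)} nodes, {len(fragment.edges)} edges projected")


    asyncio.run(main())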
@@ -168,6 +175,8 @@ async def brute_force_search(
         return []
 
     try:
+        start_time = time.time()
+
         results = await asyncio.gather(
             *[search_in_collection(collection_name) for collection_name in collections]
         )
@@ -175,10 +184,20 @@
         if all(not item for item in results):
             return []
 
+        # Final statistics
+        projection_time = time.time() - start_time
+        logger.info(
+            f"Vector collection retrieval completed: Retrieved distances from {sum(1 for res in results if res)} collections in {projection_time:.2f}s"
+        )
+
         node_distances = {collection: result for collection, result in zip(collections, results)}
 
+        edge_distances = node_distances.get("EdgeType_relationship_name", None)
+
         await memory_fragment.map_vector_distances_to_graph_nodes(node_distances=node_distances)
-        await memory_fragment.map_vector_distances_to_graph_edges(vector_engine, query)
+        await memory_fragment.map_vector_distances_to_graph_edges(
+            vector_engine=vector_engine, query_vector=query_vector, edge_distances=edge_distances
+        )
 
         results = await memory_fragment.calculate_top_triplet_importances(k=top_k)
 

cognee/modules/settings/get_settings.py

@@ -43,10 +43,6 @@ def get_settings() -> SettingsDict:
     llm_config = get_llm_config()
 
     vector_dbs = [
-        {
-            "value": "weaviate",
-            "label": "Weaviate",
-        },
         {
             "value": "qdrant",
             "label": "Qdrant",