cognee 0.5.0__py3-none-any.whl → 0.5.0.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (131)
  1. cognee/api/client.py +5 -1
  2. cognee/api/v1/add/add.py +1 -2
  3. cognee/api/v1/cognify/code_graph_pipeline.py +119 -0
  4. cognee/api/v1/cognify/cognify.py +16 -24
  5. cognee/api/v1/cognify/routers/__init__.py +1 -0
  6. cognee/api/v1/cognify/routers/get_code_pipeline_router.py +90 -0
  7. cognee/api/v1/cognify/routers/get_cognify_router.py +1 -3
  8. cognee/api/v1/datasets/routers/get_datasets_router.py +3 -3
  9. cognee/api/v1/ontologies/ontologies.py +37 -12
  10. cognee/api/v1/ontologies/routers/get_ontology_router.py +25 -27
  11. cognee/api/v1/search/search.py +0 -4
  12. cognee/api/v1/ui/ui.py +68 -38
  13. cognee/context_global_variables.py +16 -61
  14. cognee/eval_framework/answer_generation/answer_generation_executor.py +0 -10
  15. cognee/eval_framework/answer_generation/run_question_answering_module.py +1 -1
  16. cognee/eval_framework/corpus_builder/task_getters/get_cascade_graph_tasks.py +2 -0
  17. cognee/eval_framework/corpus_builder/task_getters/get_default_tasks_by_indices.py +4 -4
  18. cognee/eval_framework/eval_config.py +2 -2
  19. cognee/eval_framework/modal_run_eval.py +28 -16
  20. cognee/infrastructure/databases/graph/config.py +0 -3
  21. cognee/infrastructure/databases/graph/get_graph_engine.py +0 -1
  22. cognee/infrastructure/databases/graph/graph_db_interface.py +0 -15
  23. cognee/infrastructure/databases/graph/kuzu/adapter.py +0 -228
  24. cognee/infrastructure/databases/graph/neo4j_driver/adapter.py +1 -80
  25. cognee/infrastructure/databases/utils/__init__.py +0 -3
  26. cognee/infrastructure/databases/utils/get_or_create_dataset_database.py +48 -62
  27. cognee/infrastructure/databases/vector/config.py +0 -2
  28. cognee/infrastructure/databases/vector/create_vector_engine.py +0 -1
  29. cognee/infrastructure/databases/vector/embeddings/FastembedEmbeddingEngine.py +6 -8
  30. cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py +7 -9
  31. cognee/infrastructure/databases/vector/embeddings/OllamaEmbeddingEngine.py +10 -11
  32. cognee/infrastructure/databases/vector/embeddings/embedding_rate_limiter.py +544 -0
  33. cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py +0 -2
  34. cognee/infrastructure/databases/vector/vector_db_interface.py +0 -35
  35. cognee/infrastructure/files/storage/s3_config.py +0 -2
  36. cognee/infrastructure/llm/LLMGateway.py +2 -5
  37. cognee/infrastructure/llm/config.py +0 -35
  38. cognee/infrastructure/llm/extraction/knowledge_graph/extract_content_graph.py +2 -2
  39. cognee/infrastructure/llm/structured_output_framework/baml/baml_src/extraction/acreate_structured_output.py +8 -23
  40. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py +16 -17
  41. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py +37 -40
  42. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py +36 -39
  43. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py +1 -19
  44. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py +9 -11
  45. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/adapter.py +21 -23
  46. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py +34 -42
  47. cognee/modules/cognify/config.py +0 -2
  48. cognee/modules/data/deletion/prune_system.py +2 -52
  49. cognee/modules/data/methods/delete_dataset.py +0 -26
  50. cognee/modules/engine/models/__init__.py +0 -1
  51. cognee/modules/graph/cognee_graph/CogneeGraph.py +37 -85
  52. cognee/modules/graph/cognee_graph/CogneeGraphElements.py +3 -8
  53. cognee/modules/memify/memify.py +7 -1
  54. cognee/modules/pipelines/operations/pipeline.py +2 -18
  55. cognee/modules/retrieval/__init__.py +1 -1
  56. cognee/modules/retrieval/code_retriever.py +232 -0
  57. cognee/modules/retrieval/graph_completion_context_extension_retriever.py +0 -4
  58. cognee/modules/retrieval/graph_completion_cot_retriever.py +0 -4
  59. cognee/modules/retrieval/graph_completion_retriever.py +0 -10
  60. cognee/modules/retrieval/graph_summary_completion_retriever.py +0 -4
  61. cognee/modules/retrieval/temporal_retriever.py +0 -4
  62. cognee/modules/retrieval/utils/brute_force_triplet_search.py +10 -42
  63. cognee/modules/run_custom_pipeline/run_custom_pipeline.py +1 -8
  64. cognee/modules/search/methods/get_search_type_tools.py +8 -54
  65. cognee/modules/search/methods/no_access_control_search.py +0 -4
  66. cognee/modules/search/methods/search.py +0 -21
  67. cognee/modules/search/types/SearchType.py +1 -1
  68. cognee/modules/settings/get_settings.py +0 -19
  69. cognee/modules/users/methods/get_authenticated_user.py +2 -2
  70. cognee/modules/users/models/DatasetDatabase.py +3 -15
  71. cognee/shared/logging_utils.py +0 -4
  72. cognee/tasks/code/enrich_dependency_graph_checker.py +35 -0
  73. cognee/tasks/code/get_local_dependencies_checker.py +20 -0
  74. cognee/tasks/code/get_repo_dependency_graph_checker.py +35 -0
  75. cognee/tasks/documents/__init__.py +1 -0
  76. cognee/tasks/documents/check_permissions_on_dataset.py +26 -0
  77. cognee/tasks/graph/extract_graph_from_data.py +10 -9
  78. cognee/tasks/repo_processor/__init__.py +2 -0
  79. cognee/tasks/repo_processor/get_local_dependencies.py +335 -0
  80. cognee/tasks/repo_processor/get_non_code_files.py +158 -0
  81. cognee/tasks/repo_processor/get_repo_file_dependencies.py +243 -0
  82. cognee/tasks/storage/add_data_points.py +2 -142
  83. cognee/tests/test_cognee_server_start.py +4 -2
  84. cognee/tests/test_conversation_history.py +1 -23
  85. cognee/tests/test_delete_bmw_example.py +60 -0
  86. cognee/tests/test_search_db.py +1 -37
  87. cognee/tests/unit/api/test_ontology_endpoint.py +89 -77
  88. cognee/tests/unit/infrastructure/mock_embedding_engine.py +7 -3
  89. cognee/tests/unit/infrastructure/test_embedding_rate_limiting_realistic.py +5 -0
  90. cognee/tests/unit/modules/graph/cognee_graph_elements_test.py +2 -2
  91. cognee/tests/unit/modules/graph/cognee_graph_test.py +0 -406
  92. {cognee-0.5.0.dist-info → cognee-0.5.0.dev0.dist-info}/METADATA +89 -76
  93. {cognee-0.5.0.dist-info → cognee-0.5.0.dev0.dist-info}/RECORD +97 -118
  94. {cognee-0.5.0.dist-info → cognee-0.5.0.dev0.dist-info}/WHEEL +1 -1
  95. cognee/api/v1/ui/node_setup.py +0 -360
  96. cognee/api/v1/ui/npm_utils.py +0 -50
  97. cognee/eval_framework/Dockerfile +0 -29
  98. cognee/infrastructure/databases/dataset_database_handler/__init__.py +0 -3
  99. cognee/infrastructure/databases/dataset_database_handler/dataset_database_handler_interface.py +0 -80
  100. cognee/infrastructure/databases/dataset_database_handler/supported_dataset_database_handlers.py +0 -18
  101. cognee/infrastructure/databases/dataset_database_handler/use_dataset_database_handler.py +0 -10
  102. cognee/infrastructure/databases/graph/kuzu/KuzuDatasetDatabaseHandler.py +0 -81
  103. cognee/infrastructure/databases/graph/neo4j_driver/Neo4jAuraDevDatasetDatabaseHandler.py +0 -168
  104. cognee/infrastructure/databases/utils/get_graph_dataset_database_handler.py +0 -10
  105. cognee/infrastructure/databases/utils/get_vector_dataset_database_handler.py +0 -10
  106. cognee/infrastructure/databases/utils/resolve_dataset_database_connection_info.py +0 -30
  107. cognee/infrastructure/databases/vector/lancedb/LanceDBDatasetDatabaseHandler.py +0 -50
  108. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/bedrock/__init__.py +0 -5
  109. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/bedrock/adapter.py +0 -153
  110. cognee/memify_pipelines/create_triplet_embeddings.py +0 -53
  111. cognee/modules/engine/models/Triplet.py +0 -9
  112. cognee/modules/retrieval/register_retriever.py +0 -10
  113. cognee/modules/retrieval/registered_community_retrievers.py +0 -1
  114. cognee/modules/retrieval/triplet_retriever.py +0 -182
  115. cognee/shared/rate_limiting.py +0 -30
  116. cognee/tasks/memify/get_triplet_datapoints.py +0 -289
  117. cognee/tests/integration/retrieval/test_triplet_retriever.py +0 -84
  118. cognee/tests/integration/tasks/test_add_data_points.py +0 -139
  119. cognee/tests/integration/tasks/test_get_triplet_datapoints.py +0 -69
  120. cognee/tests/test_dataset_database_handler.py +0 -137
  121. cognee/tests/test_dataset_delete.py +0 -76
  122. cognee/tests/test_edge_centered_payload.py +0 -170
  123. cognee/tests/test_pipeline_cache.py +0 -164
  124. cognee/tests/unit/infrastructure/llm/test_llm_config.py +0 -46
  125. cognee/tests/unit/modules/memify_tasks/test_get_triplet_datapoints.py +0 -214
  126. cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py +0 -608
  127. cognee/tests/unit/modules/retrieval/triplet_retriever_test.py +0 -83
  128. cognee/tests/unit/tasks/storage/test_add_data_points.py +0 -288
  129. {cognee-0.5.0.dist-info → cognee-0.5.0.dev0.dist-info}/entry_points.txt +0 -0
  130. {cognee-0.5.0.dist-info → cognee-0.5.0.dev0.dist-info}/licenses/LICENSE +0 -0
  131. {cognee-0.5.0.dist-info → cognee-0.5.0.dev0.dist-info}/licenses/NOTICE.md +0 -0
cognee/modules/cognify/config.py
@@ -8,14 +8,12 @@ import os
 class CognifyConfig(BaseSettings):
     classification_model: object = DefaultContentPrediction
     summarization_model: object = SummarizedContent
-    triplet_embedding: bool = False
     model_config = SettingsConfigDict(env_file=".env", extra="allow")

     def to_dict(self) -> dict:
         return {
             "classification_model": self.classification_model,
             "summarization_model": self.summarization_model,
-            "triplet_embedding": self.triplet_embedding,
         }

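Removing the triplet_embedding flag also shrinks what CognifyConfig serializes. A minimal sketch of the resulting surface (assuming CognifyConfig is instantiated directly; environment-driven overrides omitted):

    from cognee.modules.cognify.config import CognifyConfig

    config = CognifyConfig()
    # to_dict() now carries only the two remaining model settings;
    # the triplet_embedding key no longer appears.
    assert set(config.to_dict()) == {"classification_model", "summarization_model"}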
cognee/modules/data/deletion/prune_system.py
@@ -1,67 +1,17 @@
-from sqlalchemy.exc import OperationalError
-
-from cognee.infrastructure.databases.exceptions import EntityNotFoundError
-from cognee.context_global_variables import backend_access_control_enabled
 from cognee.infrastructure.databases.vector import get_vector_engine
 from cognee.infrastructure.databases.graph.get_graph_engine import get_graph_engine
 from cognee.infrastructure.databases.relational import get_relational_engine
-from cognee.infrastructure.databases.utils import (
-    get_graph_dataset_database_handler,
-    get_vector_dataset_database_handler,
-)
 from cognee.shared.cache import delete_cache
-from cognee.modules.users.models import DatasetDatabase
-from cognee.shared.logging_utils import get_logger
-
-logger = get_logger()
-
-
-async def prune_graph_databases():
-    db_engine = get_relational_engine()
-    try:
-        dataset_databases = await db_engine.get_all_data_from_table("dataset_database")
-        # Go through each dataset database and delete the graph database
-        for dataset_database in dataset_databases:
-            handler = get_graph_dataset_database_handler(dataset_database)
-            await handler["handler_instance"].delete_dataset(dataset_database)
-    except (OperationalError, EntityNotFoundError) as e:
-        logger.debug(
-            "Skipping pruning of graph DB. Error when accessing dataset_database table: %s",
-            e,
-        )
-        return
-
-
-async def prune_vector_databases():
-    db_engine = get_relational_engine()
-    try:
-        dataset_databases = await db_engine.get_all_data_from_table("dataset_database")
-        # Go through each dataset database and delete the vector database
-        for dataset_database in dataset_databases:
-            handler = get_vector_dataset_database_handler(dataset_database)
-            await handler["handler_instance"].delete_dataset(dataset_database)
-    except (OperationalError, EntityNotFoundError) as e:
-        logger.debug(
-            "Skipping pruning of vector DB. Error when accessing dataset_database table: %s",
-            e,
-        )
-        return


 async def prune_system(graph=True, vector=True, metadata=True, cache=True):
-    # Note: prune system should not be available through the API, it has no permission checks and will
-    # delete all graph and vector databases if called. It should only be used in development or testing environments.
-    if graph and not backend_access_control_enabled():
+    if graph:
         graph_engine = await get_graph_engine()
         await graph_engine.delete_graph()
-    elif graph and backend_access_control_enabled():
-        await prune_graph_databases()

-    if vector and not backend_access_control_enabled():
+    if vector:
         vector_engine = get_vector_engine()
         await vector_engine.prune()
-    elif vector and backend_access_control_enabled():
-        await prune_vector_databases()

     if metadata:
         db_engine = get_relational_engine()
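prune_system no longer branches on backend access control or iterates per-dataset database handlers; it wipes the single configured graph and vector store directly, and it still has no permission checks. A minimal usage sketch (the cognee.prune entry point follows cognee's published examples, but treat the exact import path as an assumption), for development or testing only:

    import asyncio

    import cognee

    async def reset_local_state():
        # Deletes the configured graph store, vector store, relational
        # metadata, and cache outright; never expose this through an API.
        await cognee.prune.prune_system(graph=True, vector=True, metadata=True, cache=True)

    asyncio.run(reset_local_state())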
cognee/modules/data/methods/delete_dataset.py
@@ -1,34 +1,8 @@
-from cognee.modules.users.models import DatasetDatabase
-from sqlalchemy import select
-
 from cognee.modules.data.models import Dataset
-from cognee.infrastructure.databases.utils.get_vector_dataset_database_handler import (
-    get_vector_dataset_database_handler,
-)
-from cognee.infrastructure.databases.utils.get_graph_dataset_database_handler import (
-    get_graph_dataset_database_handler,
-)
 from cognee.infrastructure.databases.relational import get_relational_engine


 async def delete_dataset(dataset: Dataset):
     db_engine = get_relational_engine()

-    async with db_engine.get_async_session() as session:
-        stmt = select(DatasetDatabase).where(
-            DatasetDatabase.dataset_id == dataset.id,
-        )
-        dataset_database: DatasetDatabase = await session.scalar(stmt)
-        if dataset_database:
-            graph_dataset_database_handler = get_graph_dataset_database_handler(dataset_database)
-            vector_dataset_database_handler = get_vector_dataset_database_handler(dataset_database)
-            await graph_dataset_database_handler["handler_instance"].delete_dataset(
-                dataset_database
-            )
-            await vector_dataset_database_handler["handler_instance"].delete_dataset(
-                dataset_database
-            )
-    # TODO: Remove dataset from pipeline_run_status in Data objects related to dataset as well
-    # This blocks recreation of the dataset with the same name and data after deletion as
-    # it's marked as completed and will be just skipped even though it's empty.
     return await db_engine.delete_entity_by_id(dataset.__tablename__, dataset.id)
cognee/modules/engine/models/__init__.py
@@ -7,4 +7,3 @@ from .ColumnValue import ColumnValue
 from .Timestamp import Timestamp
 from .Interval import Interval
 from .Event import Event
-from .Triplet import Triplet
cognee/modules/graph/cognee_graph/CogneeGraph.py
@@ -56,68 +56,6 @@ class CogneeGraph(CogneeAbstractGraph):
     def get_edges(self) -> List[Edge]:
         return self.edges

-    async def _get_nodeset_subgraph(
-        self,
-        adapter,
-        node_type,
-        node_name,
-    ):
-        """Retrieve subgraph based on node type and name."""
-        logger.info("Retrieving graph filtered by node type and node name (NodeSet).")
-        nodes_data, edges_data = await adapter.get_nodeset_subgraph(
-            node_type=node_type, node_name=node_name
-        )
-        if not nodes_data or not edges_data:
-            raise EntityNotFoundError(
-                message="Nodeset does not exist, or empty nodeset projected from the database."
-            )
-        return nodes_data, edges_data
-
-    async def _get_full_or_id_filtered_graph(
-        self,
-        adapter,
-        relevant_ids_to_filter,
-    ):
-        """Retrieve full or ID-filtered graph with fallback."""
-        if relevant_ids_to_filter is None:
-            logger.info("Retrieving full graph.")
-            nodes_data, edges_data = await adapter.get_graph_data()
-            if not nodes_data or not edges_data:
-                raise EntityNotFoundError(message="Empty graph projected from the database.")
-            return nodes_data, edges_data
-
-        get_graph_data_fn = getattr(adapter, "get_id_filtered_graph_data", adapter.get_graph_data)
-        if getattr(adapter.__class__, "get_id_filtered_graph_data", None):
-            logger.info("Retrieving ID-filtered graph from database.")
-            nodes_data, edges_data = await get_graph_data_fn(target_ids=relevant_ids_to_filter)
-        else:
-            logger.info("Retrieving full graph from database.")
-            nodes_data, edges_data = await get_graph_data_fn()
-        if hasattr(adapter, "get_id_filtered_graph_data") and (not nodes_data or not edges_data):
-            logger.warning(
-                "Id filtered graph returned empty, falling back to full graph retrieval."
-            )
-            logger.info("Retrieving full graph")
-            nodes_data, edges_data = await adapter.get_graph_data()
-
-        if not nodes_data or not edges_data:
-            raise EntityNotFoundError("Empty graph projected from the database.")
-        return nodes_data, edges_data
-
-    async def _get_filtered_graph(
-        self,
-        adapter,
-        memory_fragment_filter,
-    ):
-        """Retrieve graph filtered by attributes."""
-        logger.info("Retrieving graph filtered by memory fragment")
-        nodes_data, edges_data = await adapter.get_filtered_graph_data(
-            attribute_filters=memory_fragment_filter
-        )
-        if not nodes_data or not edges_data:
-            raise EntityNotFoundError(message="Empty filtered graph projected from the database.")
-        return nodes_data, edges_data
-
     async def project_graph_from_db(
         self,
         adapter: Union[GraphDBInterface],
@@ -129,39 +67,40 @@
         memory_fragment_filter=[],
         node_type: Optional[Type] = None,
         node_name: Optional[List[str]] = None,
-        relevant_ids_to_filter: Optional[List[str]] = None,
-        triplet_distance_penalty: float = 3.5,
     ) -> None:
         if node_dimension < 1 or edge_dimension < 1:
             raise InvalidDimensionsError()
         try:
+            import time
+
+            start_time = time.time()
+
+            # Determine projection strategy
             if node_type is not None and node_name not in [None, [], ""]:
-                nodes_data, edges_data = await self._get_nodeset_subgraph(
-                    adapter, node_type, node_name
+                nodes_data, edges_data = await adapter.get_nodeset_subgraph(
+                    node_type=node_type, node_name=node_name
                 )
+                if not nodes_data or not edges_data:
+                    raise EntityNotFoundError(
+                        message="Nodeset does not exist, or empty nodeset projected from the database."
+                    )
             elif len(memory_fragment_filter) == 0:
-                nodes_data, edges_data = await self._get_full_or_id_filtered_graph(
-                    adapter, relevant_ids_to_filter
-                )
+                nodes_data, edges_data = await adapter.get_graph_data()
+                if not nodes_data or not edges_data:
+                    raise EntityNotFoundError(message="Empty graph projected from the database.")
             else:
-                nodes_data, edges_data = await self._get_filtered_graph(
-                    adapter, memory_fragment_filter
+                nodes_data, edges_data = await adapter.get_filtered_graph_data(
+                    attribute_filters=memory_fragment_filter
                 )
+                if not nodes_data or not edges_data:
+                    raise EntityNotFoundError(
+                        message="Empty filtered graph projected from the database."
+                    )

-            import time
-
-            start_time = time.time()
             # Process nodes
             for node_id, properties in nodes_data:
                 node_attributes = {key: properties.get(key) for key in node_properties_to_project}
-                self.add_node(
-                    Node(
-                        str(node_id),
-                        node_attributes,
-                        dimension=node_dimension,
-                        node_penalty=triplet_distance_penalty,
-                    )
-                )
+                self.add_node(Node(str(node_id), node_attributes, dimension=node_dimension))

             # Process edges
             for source_id, target_id, relationship_type, properties in edges_data:
@@ -179,7 +118,6 @@
                     attributes=edge_attributes,
                     directed=directed,
                     dimension=edge_dimension,
-                    edge_penalty=triplet_distance_penalty,
                 )
                 self.add_edge(edge)

@@ -211,10 +149,24 @@
                     node.add_attribute("vector_distance", score)
                     mapped_nodes += 1

-    async def map_vector_distances_to_graph_edges(self, edge_distances) -> None:
+    async def map_vector_distances_to_graph_edges(
+        self, vector_engine, query_vector, edge_distances
+    ) -> None:
         try:
+            if query_vector is None or len(query_vector) == 0:
+                raise ValueError("Failed to generate query embedding.")
+
             if edge_distances is None:
-                return
+                start_time = time.time()
+                edge_distances = await vector_engine.search(
+                    collection_name="EdgeType_relationship_name",
+                    query_vector=query_vector,
+                    limit=None,
+                )
+                projection_time = time.time() - start_time
+                logger.info(
+                    f"Edge collection distances were calculated separately from nodes in {projection_time:.2f}s"
+                )

             embedding_map = {result.payload["text"]: result.score for result in edge_distances}
cognee/modules/graph/cognee_graph/CogneeGraphElements.py
@@ -20,17 +20,13 @@ class Node:
     status: np.ndarray

     def __init__(
-        self,
-        node_id: str,
-        attributes: Optional[Dict[str, Any]] = None,
-        dimension: int = 1,
-        node_penalty: float = 3.5,
+        self, node_id: str, attributes: Optional[Dict[str, Any]] = None, dimension: int = 1
     ):
         if dimension <= 0:
             raise InvalidDimensionsError()
         self.id = node_id
         self.attributes = attributes if attributes is not None else {}
-        self.attributes["vector_distance"] = node_penalty
+        self.attributes["vector_distance"] = float("inf")
         self.skeleton_neighbours = []
         self.skeleton_edges = []
         self.status = np.ones(dimension, dtype=int)
@@ -109,14 +105,13 @@ class Edge:
         attributes: Optional[Dict[str, Any]] = None,
         directed: bool = True,
         dimension: int = 1,
-        edge_penalty: float = 3.5,
     ):
         if dimension <= 0:
             raise InvalidDimensionsError()
         self.node1 = node1
         self.node2 = node2
         self.attributes = attributes if attributes is not None else {}
-        self.attributes["vector_distance"] = edge_penalty
+        self.attributes["vector_distance"] = float("inf")
         self.directed = directed
         self.status = np.ones(dimension, dtype=int)

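One consequence of the float("inf") default: nodes and edges that never receive a mapped similarity score now sort strictly after anything that did, whereas the removed 3.5 penalty could outrank genuinely distant matches. A minimal sketch:

    from cognee.modules.graph.cognee_graph.CogneeGraphElements import Node

    node = Node("example-id", {"name": "example"})
    # vector_distance starts at infinity and stays there until
    # map_vector_distances_to_graph_nodes assigns a real score.
    print(node.attributes["vector_distance"])  # inf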
cognee/modules/memify/memify.py
@@ -12,6 +12,9 @@ from cognee.modules.users.models import User
 from cognee.modules.pipelines.layers.resolve_authorized_user_datasets import (
     resolve_authorized_user_datasets,
 )
+from cognee.modules.pipelines.layers.reset_dataset_pipeline_run_status import (
+    reset_dataset_pipeline_run_status,
+)
 from cognee.modules.engine.operations.setup import setup
 from cognee.modules.pipelines.layers.pipeline_execution_mode import get_pipeline_executor
 from cognee.tasks.memify.extract_subgraph_chunks import extract_subgraph_chunks
@@ -94,6 +97,10 @@ async def memify(
         *enrichment_tasks,
     ]

+    await reset_dataset_pipeline_run_status(
+        authorized_dataset.id, user, pipeline_names=["memify_pipeline"]
+    )
+
     # By calling get pipeline executor we get a function that will have the run_pipeline run in the background or a function that we will need to wait for
     pipeline_executor_func = get_pipeline_executor(run_in_background=run_in_background)

@@ -106,7 +113,6 @@ async def memify(
         datasets=authorized_dataset.id,
         vector_db_config=vector_db_config,
         graph_db_config=graph_db_config,
-        use_pipeline_cache=False,
         incremental_loading=False,
         pipeline_name="memify_pipeline",
     )
cognee/modules/pipelines/operations/pipeline.py
@@ -20,9 +20,6 @@ from cognee.modules.pipelines.layers.resolve_authorized_user_datasets import (
 from cognee.modules.pipelines.layers.check_pipeline_run_qualification import (
     check_pipeline_run_qualification,
 )
-from cognee.modules.pipelines.models.PipelineRunInfo import (
-    PipelineRunStarted,
-)
 from typing import Any

 logger = get_logger("cognee.pipeline")
@@ -38,7 +35,6 @@ async def run_pipeline(
     pipeline_name: str = "custom_pipeline",
     vector_db_config: dict = None,
     graph_db_config: dict = None,
-    use_pipeline_cache: bool = False,
     incremental_loading: bool = False,
     data_per_batch: int = 20,
 ):
@@ -55,7 +51,6 @@ async def run_pipeline(
         data=data,
         pipeline_name=pipeline_name,
         context={"dataset": dataset},
-        use_pipeline_cache=use_pipeline_cache,
         incremental_loading=incremental_loading,
         data_per_batch=data_per_batch,
     ):
@@ -69,7 +64,6 @@ async def run_pipeline_per_dataset(
     data=None,
     pipeline_name: str = "custom_pipeline",
     context: dict = None,
-    use_pipeline_cache=False,
     incremental_loading=False,
     data_per_batch: int = 20,
 ):
@@ -83,18 +77,8 @@ async def run_pipeline_per_dataset(
     if process_pipeline_status:
         # If pipeline was already processed or is currently being processed
         # return status information to async generator and finish execution
-        if use_pipeline_cache:
-            # If pipeline caching is enabled we do not proceed with re-processing
-            yield process_pipeline_status
-            return
-        else:
-            # If pipeline caching is disabled we always return pipeline started information and proceed with re-processing
-            yield PipelineRunStarted(
-                pipeline_run_id=process_pipeline_status.pipeline_run_id,
-                dataset_id=dataset.id,
-                dataset_name=dataset.name,
-                payload=data,
-            )
+        yield process_pipeline_status
+        return

     pipeline_run = run_tasks(
         tasks,
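With use_pipeline_cache removed, a dataset whose pipeline already ran (or is still running) always yields its existing status and stops; forcing a re-run now requires explicitly resetting the run status, as memify does above with reset_dataset_pipeline_run_status. A rough sketch of the caller's view (parameter names beyond those visible in this hunk, such as tasks and user, are assumptions):

    # Hypothetical call shape: run_pipeline streams run-status objects. If the
    # dataset already has a pipeline run recorded, the first yielded item is
    # that existing status and the generator finishes without re-processing.
    async for run_status in run_pipeline(
        tasks=tasks,  # assumed parameter
        user=user,  # assumed parameter
        pipeline_name="cognify_pipeline",
        incremental_loading=False,
    ):
        print(run_status)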
cognee/modules/retrieval/__init__.py
@@ -1 +1 @@
-
+from cognee.modules.retrieval.code_retriever import CodeRetriever
cognee/modules/retrieval/code_retriever.py
@@ -0,0 +1,232 @@
+from typing import Any, Optional, List
+import asyncio
+import aiofiles
+from pydantic import BaseModel
+
+from cognee.shared.logging_utils import get_logger
+from cognee.modules.retrieval.base_retriever import BaseRetriever
+from cognee.infrastructure.databases.graph import get_graph_engine
+from cognee.infrastructure.databases.vector import get_vector_engine
+from cognee.infrastructure.llm.prompts import read_query_prompt
+from cognee.infrastructure.llm.LLMGateway import LLMGateway
+
+logger = get_logger("CodeRetriever")
+
+
+class CodeRetriever(BaseRetriever):
+    """Retriever for handling code-based searches."""
+
+    class CodeQueryInfo(BaseModel):
+        """
+        Model for representing the result of a query related to code files.
+
+        This class holds a list of filenames and the corresponding source code extracted from a
+        query. It is used to encapsulate response data in a structured format.
+        """
+
+        filenames: List[str] = []
+        sourcecode: str
+
+    def __init__(self, top_k: int = 3):
+        """Initialize retriever with search parameters."""
+        self.top_k = top_k
+        self.file_name_collections = ["CodeFile_name"]
+        self.classes_and_functions_collections = [
+            "ClassDefinition_source_code",
+            "FunctionDefinition_source_code",
+        ]
+
+    async def _process_query(self, query: str) -> "CodeRetriever.CodeQueryInfo":
+        """Process the query using LLM to extract file names and source code parts."""
+        logger.debug(
+            f"Processing query with LLM: '{query[:100]}{'...' if len(query) > 100 else ''}'"
+        )
+
+        system_prompt = read_query_prompt("codegraph_retriever_system.txt")
+
+        try:
+            result = await LLMGateway.acreate_structured_output(
+                text_input=query,
+                system_prompt=system_prompt,
+                response_model=self.CodeQueryInfo,
+            )
+            logger.info(
+                f"LLM extracted {len(result.filenames)} filenames and {len(result.sourcecode)} chars of source code"
+            )
+            return result
+        except Exception as e:
+            logger.error(f"Failed to retrieve structured output from LLM: {str(e)}")
+            raise RuntimeError("Failed to retrieve structured output from LLM") from e
+
+    async def get_context(self, query: str) -> Any:
+        """Find relevant code files based on the query."""
+        logger.info(
+            f"Starting code retrieval for query: '{query[:100]}{'...' if len(query) > 100 else ''}'"
+        )
+
+        if not query or not isinstance(query, str):
+            logger.error("Invalid query: must be a non-empty string")
+            raise ValueError("The query must be a non-empty string.")
+
+        try:
+            vector_engine = get_vector_engine()
+            graph_engine = await get_graph_engine()
+            logger.debug("Successfully initialized vector and graph engines")
+        except Exception as e:
+            logger.error(f"Database initialization error: {str(e)}")
+            raise RuntimeError("Database initialization error in code_graph_retriever, ") from e
+
+        files_and_codeparts = await self._process_query(query)
+
+        similar_filenames = []
+        similar_codepieces = []
+
+        if not files_and_codeparts.filenames or not files_and_codeparts.sourcecode:
+            logger.info("No specific files/code extracted from query, performing general search")
+
+            for collection in self.file_name_collections:
+                logger.debug(f"Searching {collection} collection with general query")
+                search_results_file = await vector_engine.search(
+                    collection, query, limit=self.top_k
+                )
+                logger.debug(f"Found {len(search_results_file)} results in {collection}")
+                for res in search_results_file:
+                    similar_filenames.append(
+                        {"id": res.id, "score": res.score, "payload": res.payload}
+                    )
+
+            existing_collection = []
+            for collection in self.classes_and_functions_collections:
+                if await vector_engine.has_collection(collection):
+                    existing_collection.append(collection)
+
+            if not existing_collection:
+                raise RuntimeError("No collection found for code retriever")
+
+            for collection in existing_collection:
+                logger.debug(f"Searching {collection} collection with general query")
+                search_results_code = await vector_engine.search(
+                    collection, query, limit=self.top_k
+                )
+                logger.debug(f"Found {len(search_results_code)} results in {collection}")
+                for res in search_results_code:
+                    similar_codepieces.append(
+                        {"id": res.id, "score": res.score, "payload": res.payload}
+                    )
+        else:
+            logger.info(
+                f"Using extracted filenames ({len(files_and_codeparts.filenames)}) and source code for targeted search"
+            )
+
+            for collection in self.file_name_collections:
+                for file_from_query in files_and_codeparts.filenames:
+                    logger.debug(f"Searching {collection} for specific file: {file_from_query}")
+                    search_results_file = await vector_engine.search(
+                        collection, file_from_query, limit=self.top_k
+                    )
+                    logger.debug(
+                        f"Found {len(search_results_file)} results for file {file_from_query}"
+                    )
+                    for res in search_results_file:
+                        similar_filenames.append(
+                            {"id": res.id, "score": res.score, "payload": res.payload}
+                        )
+
+            for collection in self.classes_and_functions_collections:
+                logger.debug(f"Searching {collection} with extracted source code")
+                search_results_code = await vector_engine.search(
+                    collection, files_and_codeparts.sourcecode, limit=self.top_k
+                )
+                logger.debug(f"Found {len(search_results_code)} results for source code search")
+                for res in search_results_code:
+                    similar_codepieces.append(
+                        {"id": res.id, "score": res.score, "payload": res.payload}
+                    )
+
+        total_items = len(similar_filenames) + len(similar_codepieces)
+        logger.info(
+            f"Total search results: {total_items} items ({len(similar_filenames)} filenames, {len(similar_codepieces)} code pieces)"
+        )
+
+        if total_items == 0:
+            logger.warning("No search results found, returning empty list")
+            return []
+
+        logger.debug("Getting graph connections for all search results")
+        relevant_triplets = await asyncio.gather(
+            *[
+                graph_engine.get_connections(similar_piece["id"])
+                for similar_piece in similar_filenames + similar_codepieces
+            ]
+        )
+        logger.info(f"Retrieved graph connections for {len(relevant_triplets)} items")
+
+        paths = set()
+        for i, sublist in enumerate(relevant_triplets):
+            logger.debug(f"Processing connections for item {i}: {len(sublist)} connections")
+            for tpl in sublist:
+                if isinstance(tpl, tuple) and len(tpl) >= 3:
+                    if "file_path" in tpl[0]:
+                        paths.add(tpl[0]["file_path"])
+                    if "file_path" in tpl[2]:
+                        paths.add(tpl[2]["file_path"])
+
+        logger.info(f"Found {len(paths)} unique file paths to read")
+
+        retrieved_files = {}
+        read_tasks = []
+        for file_path in paths:
+
+            async def read_file(fp):
+                try:
+                    logger.debug(f"Reading file: {fp}")
+                    async with aiofiles.open(fp, "r", encoding="utf-8") as f:
+                        content = await f.read()
+                        retrieved_files[fp] = content
+                        logger.debug(f"Successfully read {len(content)} characters from {fp}")
+                except Exception as e:
+                    logger.error(f"Error reading {fp}: {e}")
+                    retrieved_files[fp] = ""
+
+            read_tasks.append(read_file(file_path))
+
+        await asyncio.gather(*read_tasks)
+        logger.info(
+            f"Successfully read {len([f for f in retrieved_files.values() if f])} files (out of {len(paths)} total)"
+        )
+
+        result = [
+            {
+                "name": file_path,
+                "description": file_path,
+                "content": retrieved_files[file_path],
+            }
+            for file_path in paths
+        ]
+
+        logger.info(f"Returning {len(result)} code file contexts")
+        return result
+
+    async def get_completion(
+        self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None
+    ) -> Any:
+        """
+        Returns the code files context.
+
+        Parameters:
+        -----------
+
+        - query (str): The query string to retrieve code context for.
+        - context (Optional[Any]): Optional pre-fetched context; if None, it retrieves
+          the context for the query. (default None)
+        - session_id (Optional[str]): Optional session identifier for caching. If None,
+          defaults to 'default_session'. (default None)
+
+        Returns:
+        --------
+
+        - Any: The code files context, either provided or retrieved.
+        """
+        if context is None:
+            context = await self.get_context(query)
+        return context
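A minimal usage sketch of the restored retriever, run against an already-built code graph (the query string is illustrative):

    import asyncio

    from cognee.modules.retrieval import CodeRetriever

    async def main():
        retriever = CodeRetriever(top_k=3)
        # get_completion returns the list of {"name", "description", "content"}
        # dicts that get_context assembles from the matched file paths.
        files = await retriever.get_completion("Where are repo dependencies extracted?")
        for file_info in files:
            print(file_info["name"], len(file_info["content"]))

    asyncio.run(main())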
cognee/modules/retrieval/graph_completion_context_extension_retriever.py
@@ -39,8 +39,6 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
         node_type: Optional[Type] = None,
         node_name: Optional[List[str]] = None,
         save_interaction: bool = False,
-        wide_search_top_k: Optional[int] = 100,
-        triplet_distance_penalty: Optional[float] = 3.5,
     ):
         super().__init__(
             user_prompt_path=user_prompt_path,
@@ -50,8 +48,6 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
             node_name=node_name,
             save_interaction=save_interaction,
             system_prompt=system_prompt,
-            wide_search_top_k=wide_search_top_k,
-            triplet_distance_penalty=triplet_distance_penalty,
         )

     async def get_completion(