cognee 0.5.0__py3-none-any.whl → 0.5.0.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cognee/api/client.py +5 -1
- cognee/api/v1/add/add.py +1 -2
- cognee/api/v1/cognify/code_graph_pipeline.py +119 -0
- cognee/api/v1/cognify/cognify.py +16 -24
- cognee/api/v1/cognify/routers/__init__.py +1 -0
- cognee/api/v1/cognify/routers/get_code_pipeline_router.py +90 -0
- cognee/api/v1/cognify/routers/get_cognify_router.py +1 -3
- cognee/api/v1/datasets/routers/get_datasets_router.py +3 -3
- cognee/api/v1/ontologies/ontologies.py +37 -12
- cognee/api/v1/ontologies/routers/get_ontology_router.py +25 -27
- cognee/api/v1/search/search.py +0 -4
- cognee/api/v1/ui/ui.py +68 -38
- cognee/context_global_variables.py +16 -61
- cognee/eval_framework/answer_generation/answer_generation_executor.py +0 -10
- cognee/eval_framework/answer_generation/run_question_answering_module.py +1 -1
- cognee/eval_framework/corpus_builder/task_getters/get_cascade_graph_tasks.py +2 -0
- cognee/eval_framework/corpus_builder/task_getters/get_default_tasks_by_indices.py +4 -4
- cognee/eval_framework/eval_config.py +2 -2
- cognee/eval_framework/modal_run_eval.py +28 -16
- cognee/infrastructure/databases/graph/config.py +0 -3
- cognee/infrastructure/databases/graph/get_graph_engine.py +0 -1
- cognee/infrastructure/databases/graph/graph_db_interface.py +0 -15
- cognee/infrastructure/databases/graph/kuzu/adapter.py +0 -228
- cognee/infrastructure/databases/graph/neo4j_driver/adapter.py +1 -80
- cognee/infrastructure/databases/utils/__init__.py +0 -3
- cognee/infrastructure/databases/utils/get_or_create_dataset_database.py +48 -62
- cognee/infrastructure/databases/vector/config.py +0 -2
- cognee/infrastructure/databases/vector/create_vector_engine.py +0 -1
- cognee/infrastructure/databases/vector/embeddings/FastembedEmbeddingEngine.py +6 -8
- cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py +7 -9
- cognee/infrastructure/databases/vector/embeddings/OllamaEmbeddingEngine.py +10 -11
- cognee/infrastructure/databases/vector/embeddings/embedding_rate_limiter.py +544 -0
- cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py +0 -2
- cognee/infrastructure/databases/vector/vector_db_interface.py +0 -35
- cognee/infrastructure/files/storage/s3_config.py +0 -2
- cognee/infrastructure/llm/LLMGateway.py +2 -5
- cognee/infrastructure/llm/config.py +0 -35
- cognee/infrastructure/llm/extraction/knowledge_graph/extract_content_graph.py +2 -2
- cognee/infrastructure/llm/structured_output_framework/baml/baml_src/extraction/acreate_structured_output.py +8 -23
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py +16 -17
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py +37 -40
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py +36 -39
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py +1 -19
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py +9 -11
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/adapter.py +21 -23
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py +34 -42
- cognee/modules/cognify/config.py +0 -2
- cognee/modules/data/deletion/prune_system.py +2 -52
- cognee/modules/data/methods/delete_dataset.py +0 -26
- cognee/modules/engine/models/__init__.py +0 -1
- cognee/modules/graph/cognee_graph/CogneeGraph.py +37 -85
- cognee/modules/graph/cognee_graph/CogneeGraphElements.py +3 -8
- cognee/modules/memify/memify.py +7 -1
- cognee/modules/pipelines/operations/pipeline.py +2 -18
- cognee/modules/retrieval/__init__.py +1 -1
- cognee/modules/retrieval/code_retriever.py +232 -0
- cognee/modules/retrieval/graph_completion_context_extension_retriever.py +0 -4
- cognee/modules/retrieval/graph_completion_cot_retriever.py +0 -4
- cognee/modules/retrieval/graph_completion_retriever.py +0 -10
- cognee/modules/retrieval/graph_summary_completion_retriever.py +0 -4
- cognee/modules/retrieval/temporal_retriever.py +0 -4
- cognee/modules/retrieval/utils/brute_force_triplet_search.py +10 -42
- cognee/modules/run_custom_pipeline/run_custom_pipeline.py +1 -8
- cognee/modules/search/methods/get_search_type_tools.py +8 -54
- cognee/modules/search/methods/no_access_control_search.py +0 -4
- cognee/modules/search/methods/search.py +0 -21
- cognee/modules/search/types/SearchType.py +1 -1
- cognee/modules/settings/get_settings.py +0 -19
- cognee/modules/users/methods/get_authenticated_user.py +2 -2
- cognee/modules/users/models/DatasetDatabase.py +3 -15
- cognee/shared/logging_utils.py +0 -4
- cognee/tasks/code/enrich_dependency_graph_checker.py +35 -0
- cognee/tasks/code/get_local_dependencies_checker.py +20 -0
- cognee/tasks/code/get_repo_dependency_graph_checker.py +35 -0
- cognee/tasks/documents/__init__.py +1 -0
- cognee/tasks/documents/check_permissions_on_dataset.py +26 -0
- cognee/tasks/graph/extract_graph_from_data.py +10 -9
- cognee/tasks/repo_processor/__init__.py +2 -0
- cognee/tasks/repo_processor/get_local_dependencies.py +335 -0
- cognee/tasks/repo_processor/get_non_code_files.py +158 -0
- cognee/tasks/repo_processor/get_repo_file_dependencies.py +243 -0
- cognee/tasks/storage/add_data_points.py +2 -142
- cognee/tests/test_cognee_server_start.py +4 -2
- cognee/tests/test_conversation_history.py +1 -23
- cognee/tests/test_delete_bmw_example.py +60 -0
- cognee/tests/test_search_db.py +1 -37
- cognee/tests/unit/api/test_ontology_endpoint.py +89 -77
- cognee/tests/unit/infrastructure/mock_embedding_engine.py +7 -3
- cognee/tests/unit/infrastructure/test_embedding_rate_limiting_realistic.py +5 -0
- cognee/tests/unit/modules/graph/cognee_graph_elements_test.py +2 -2
- cognee/tests/unit/modules/graph/cognee_graph_test.py +0 -406
- {cognee-0.5.0.dist-info → cognee-0.5.0.dev0.dist-info}/METADATA +89 -76
- {cognee-0.5.0.dist-info → cognee-0.5.0.dev0.dist-info}/RECORD +97 -118
- {cognee-0.5.0.dist-info → cognee-0.5.0.dev0.dist-info}/WHEEL +1 -1
- cognee/api/v1/ui/node_setup.py +0 -360
- cognee/api/v1/ui/npm_utils.py +0 -50
- cognee/eval_framework/Dockerfile +0 -29
- cognee/infrastructure/databases/dataset_database_handler/__init__.py +0 -3
- cognee/infrastructure/databases/dataset_database_handler/dataset_database_handler_interface.py +0 -80
- cognee/infrastructure/databases/dataset_database_handler/supported_dataset_database_handlers.py +0 -18
- cognee/infrastructure/databases/dataset_database_handler/use_dataset_database_handler.py +0 -10
- cognee/infrastructure/databases/graph/kuzu/KuzuDatasetDatabaseHandler.py +0 -81
- cognee/infrastructure/databases/graph/neo4j_driver/Neo4jAuraDevDatasetDatabaseHandler.py +0 -168
- cognee/infrastructure/databases/utils/get_graph_dataset_database_handler.py +0 -10
- cognee/infrastructure/databases/utils/get_vector_dataset_database_handler.py +0 -10
- cognee/infrastructure/databases/utils/resolve_dataset_database_connection_info.py +0 -30
- cognee/infrastructure/databases/vector/lancedb/LanceDBDatasetDatabaseHandler.py +0 -50
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/bedrock/__init__.py +0 -5
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/bedrock/adapter.py +0 -153
- cognee/memify_pipelines/create_triplet_embeddings.py +0 -53
- cognee/modules/engine/models/Triplet.py +0 -9
- cognee/modules/retrieval/register_retriever.py +0 -10
- cognee/modules/retrieval/registered_community_retrievers.py +0 -1
- cognee/modules/retrieval/triplet_retriever.py +0 -182
- cognee/shared/rate_limiting.py +0 -30
- cognee/tasks/memify/get_triplet_datapoints.py +0 -289
- cognee/tests/integration/retrieval/test_triplet_retriever.py +0 -84
- cognee/tests/integration/tasks/test_add_data_points.py +0 -139
- cognee/tests/integration/tasks/test_get_triplet_datapoints.py +0 -69
- cognee/tests/test_dataset_database_handler.py +0 -137
- cognee/tests/test_dataset_delete.py +0 -76
- cognee/tests/test_edge_centered_payload.py +0 -170
- cognee/tests/test_pipeline_cache.py +0 -164
- cognee/tests/unit/infrastructure/llm/test_llm_config.py +0 -46
- cognee/tests/unit/modules/memify_tasks/test_get_triplet_datapoints.py +0 -214
- cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py +0 -608
- cognee/tests/unit/modules/retrieval/triplet_retriever_test.py +0 -83
- cognee/tests/unit/tasks/storage/test_add_data_points.py +0 -288
- {cognee-0.5.0.dist-info → cognee-0.5.0.dev0.dist-info}/entry_points.txt +0 -0
- {cognee-0.5.0.dist-info → cognee-0.5.0.dev0.dist-info}/licenses/LICENSE +0 -0
- {cognee-0.5.0.dist-info → cognee-0.5.0.dev0.dist-info}/licenses/NOTICE.md +0 -0
cognee/modules/cognify/config.py
CHANGED
@@ -8,14 +8,12 @@ import os
 class CognifyConfig(BaseSettings):
     classification_model: object = DefaultContentPrediction
     summarization_model: object = SummarizedContent
-    triplet_embedding: bool = False
     model_config = SettingsConfigDict(env_file=".env", extra="allow")

     def to_dict(self) -> dict:
         return {
             "classification_model": self.classification_model,
             "summarization_model": self.summarization_model,
-            "triplet_embedding": self.triplet_embedding,
         }

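Since CognifyConfig is a pydantic BaseSettings class, dropping the triplet_embedding field also drops its environment/.env override. A minimal sketch of what consumers see after this change (direct instantiation shown for brevity; the module also provides an accessor for the cached config):

    from cognee.modules.cognify.config import CognifyConfig

    config = CognifyConfig()  # values resolve from env vars / .env, then class defaults
    # After this diff, the serialized config no longer carries the removed key:
    assert "triplet_embedding" not in config.to_dict()
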
cognee/modules/data/deletion/prune_system.py
CHANGED

@@ -1,67 +1,17 @@
-from sqlalchemy.exc import OperationalError
-
-from cognee.infrastructure.databases.exceptions import EntityNotFoundError
-from cognee.context_global_variables import backend_access_control_enabled
 from cognee.infrastructure.databases.vector import get_vector_engine
 from cognee.infrastructure.databases.graph.get_graph_engine import get_graph_engine
 from cognee.infrastructure.databases.relational import get_relational_engine
-from cognee.infrastructure.databases.utils import (
-    get_graph_dataset_database_handler,
-    get_vector_dataset_database_handler,
-)
 from cognee.shared.cache import delete_cache
-from cognee.modules.users.models import DatasetDatabase
-from cognee.shared.logging_utils import get_logger
-
-logger = get_logger()
-
-
-async def prune_graph_databases():
-    db_engine = get_relational_engine()
-    try:
-        dataset_databases = await db_engine.get_all_data_from_table("dataset_database")
-        # Go through each dataset database and delete the graph database
-        for dataset_database in dataset_databases:
-            handler = get_graph_dataset_database_handler(dataset_database)
-            await handler["handler_instance"].delete_dataset(dataset_database)
-    except (OperationalError, EntityNotFoundError) as e:
-        logger.debug(
-            "Skipping pruning of graph DB. Error when accessing dataset_database table: %s",
-            e,
-        )
-        return
-
-
-async def prune_vector_databases():
-    db_engine = get_relational_engine()
-    try:
-        dataset_databases = await db_engine.get_all_data_from_table("dataset_database")
-        # Go through each dataset database and delete the vector database
-        for dataset_database in dataset_databases:
-            handler = get_vector_dataset_database_handler(dataset_database)
-            await handler["handler_instance"].delete_dataset(dataset_database)
-    except (OperationalError, EntityNotFoundError) as e:
-        logger.debug(
-            "Skipping pruning of vector DB. Error when accessing dataset_database table: %s",
-            e,
-        )
-        return


 async def prune_system(graph=True, vector=True, metadata=True, cache=True):
-
-    # delete all graph and vector databases if called. It should only be used in development or testing environments.
-    if graph and not backend_access_control_enabled():
+    if graph:
         graph_engine = await get_graph_engine()
         await graph_engine.delete_graph()
-    elif graph and backend_access_control_enabled():
-        await prune_graph_databases()

-    if vector and not backend_access_control_enabled():
+    if vector:
         vector_engine = get_vector_engine()
         await vector_engine.prune()
-    elif vector and backend_access_control_enabled():
-        await prune_vector_databases()

     if metadata:
         db_engine = get_relational_engine()
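With the dataset-database handlers gone, prune_system always works through the single configured graph and vector engines. A minimal sketch of the dev/test reset flow this function backs, assuming the public cognee.prune wrappers; this is destructive and intended for development or testing only:

    import asyncio

    import cognee

    async def reset_everything():
        await cognee.prune.prune_data()  # remove ingested data files
        # Drops the graph, vector, and relational metadata stores.
        await cognee.prune.prune_system(graph=True, vector=True, metadata=True)

    asyncio.run(reset_everything())
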
cognee/modules/data/methods/delete_dataset.py
CHANGED

@@ -1,34 +1,8 @@
-from cognee.modules.users.models import DatasetDatabase
-from sqlalchemy import select
-
 from cognee.modules.data.models import Dataset
-from cognee.infrastructure.databases.utils.get_vector_dataset_database_handler import (
-    get_vector_dataset_database_handler,
-)
-from cognee.infrastructure.databases.utils.get_graph_dataset_database_handler import (
-    get_graph_dataset_database_handler,
-)
 from cognee.infrastructure.databases.relational import get_relational_engine


 async def delete_dataset(dataset: Dataset):
     db_engine = get_relational_engine()

-    async with db_engine.get_async_session() as session:
-        stmt = select(DatasetDatabase).where(
-            DatasetDatabase.dataset_id == dataset.id,
-        )
-        dataset_database: DatasetDatabase = await session.scalar(stmt)
-        if dataset_database:
-            graph_dataset_database_handler = get_graph_dataset_database_handler(dataset_database)
-            vector_dataset_database_handler = get_vector_dataset_database_handler(dataset_database)
-            await graph_dataset_database_handler["handler_instance"].delete_dataset(
-                dataset_database
-            )
-            await vector_dataset_database_handler["handler_instance"].delete_dataset(
-                dataset_database
-            )
-    # TODO: Remove dataset from pipeline_run_status in Data objects related to dataset as well
-    # This blocks recreation of the dataset with the same name and data after deletion as
-    # it's marked as completed and will be just skipped even though it's empty.
     return await db_engine.delete_entity_by_id(dataset.__tablename__, dataset.id)

cognee/modules/graph/cognee_graph/CogneeGraph.py
CHANGED

@@ -56,68 +56,6 @@ class CogneeGraph(CogneeAbstractGraph):
     def get_edges(self) -> List[Edge]:
         return self.edges

-    async def _get_nodeset_subgraph(
-        self,
-        adapter,
-        node_type,
-        node_name,
-    ):
-        """Retrieve subgraph based on node type and name."""
-        logger.info("Retrieving graph filtered by node type and node name (NodeSet).")
-        nodes_data, edges_data = await adapter.get_nodeset_subgraph(
-            node_type=node_type, node_name=node_name
-        )
-        if not nodes_data or not edges_data:
-            raise EntityNotFoundError(
-                message="Nodeset does not exist, or empty nodeset projected from the database."
-            )
-        return nodes_data, edges_data
-
-    async def _get_full_or_id_filtered_graph(
-        self,
-        adapter,
-        relevant_ids_to_filter,
-    ):
-        """Retrieve full or ID-filtered graph with fallback."""
-        if relevant_ids_to_filter is None:
-            logger.info("Retrieving full graph.")
-            nodes_data, edges_data = await adapter.get_graph_data()
-            if not nodes_data or not edges_data:
-                raise EntityNotFoundError(message="Empty graph projected from the database.")
-            return nodes_data, edges_data
-
-        get_graph_data_fn = getattr(adapter, "get_id_filtered_graph_data", adapter.get_graph_data)
-        if getattr(adapter.__class__, "get_id_filtered_graph_data", None):
-            logger.info("Retrieving ID-filtered graph from database.")
-            nodes_data, edges_data = await get_graph_data_fn(target_ids=relevant_ids_to_filter)
-        else:
-            logger.info("Retrieving full graph from database.")
-            nodes_data, edges_data = await get_graph_data_fn()
-        if hasattr(adapter, "get_id_filtered_graph_data") and (not nodes_data or not edges_data):
-            logger.warning(
-                "Id filtered graph returned empty, falling back to full graph retrieval."
-            )
-            logger.info("Retrieving full graph")
-            nodes_data, edges_data = await adapter.get_graph_data()
-
-        if not nodes_data or not edges_data:
-            raise EntityNotFoundError("Empty graph projected from the database.")
-        return nodes_data, edges_data
-
-    async def _get_filtered_graph(
-        self,
-        adapter,
-        memory_fragment_filter,
-    ):
-        """Retrieve graph filtered by attributes."""
-        logger.info("Retrieving graph filtered by memory fragment")
-        nodes_data, edges_data = await adapter.get_filtered_graph_data(
-            attribute_filters=memory_fragment_filter
-        )
-        if not nodes_data or not edges_data:
-            raise EntityNotFoundError(message="Empty filtered graph projected from the database.")
-        return nodes_data, edges_data
-
     async def project_graph_from_db(
         self,
         adapter: Union[GraphDBInterface],

@@ -129,39 +67,40 @@ class CogneeGraph(CogneeAbstractGraph):
         memory_fragment_filter=[],
         node_type: Optional[Type] = None,
         node_name: Optional[List[str]] = None,
-        relevant_ids_to_filter: Optional[List[str]] = None,
-        triplet_distance_penalty: float = 3.5,
     ) -> None:
         if node_dimension < 1 or edge_dimension < 1:
             raise InvalidDimensionsError()
         try:
+            import time
+
+            start_time = time.time()
+
+            # Determine projection strategy
             if node_type is not None and node_name not in [None, [], ""]:
-                nodes_data, edges_data = await self._get_nodeset_subgraph(
-                    adapter, node_type=node_type, node_name=node_name
+                nodes_data, edges_data = await adapter.get_nodeset_subgraph(
+                    node_type=node_type, node_name=node_name
                 )
+                if not nodes_data or not edges_data:
+                    raise EntityNotFoundError(
+                        message="Nodeset does not exist, or empty nodetes projected from the database."
+                    )
             elif len(memory_fragment_filter) == 0:
-                nodes_data, edges_data = await self._get_full_or_id_filtered_graph(
-                    adapter, relevant_ids_to_filter
-                )
+                nodes_data, edges_data = await adapter.get_graph_data()
+                if not nodes_data or not edges_data:
+                    raise EntityNotFoundError(message="Empty graph projected from the database.")
             else:
-                nodes_data, edges_data = await self._get_filtered_graph(
-                    adapter, memory_fragment_filter
+                nodes_data, edges_data = await adapter.get_filtered_graph_data(
+                    attribute_filters=memory_fragment_filter
                 )
+                if not nodes_data or not edges_data:
+                    raise EntityNotFoundError(
+                        message="Empty filtered graph projected from the database."
+                    )

-            import time
-
-            start_time = time.time()
             # Process nodes
             for node_id, properties in nodes_data:
                 node_attributes = {key: properties.get(key) for key in node_properties_to_project}
-                self.add_node(
-                    Node(
-                        str(node_id),
-                        node_attributes,
-                        dimension=node_dimension,
-                        node_penalty=triplet_distance_penalty,
-                    )
-                )
+                self.add_node(Node(str(node_id), node_attributes, dimension=node_dimension))

             # Process edges
             for source_id, target_id, relationship_type, properties in edges_data:

@@ -179,7 +118,6 @@ class CogneeGraph(CogneeAbstractGraph):
                 attributes=edge_attributes,
                 directed=directed,
                 dimension=edge_dimension,
-                edge_penalty=triplet_distance_penalty,
             )
             self.add_edge(edge)

@@ -211,10 +149,24 @@ class CogneeGraph(CogneeAbstractGraph):
                 node.add_attribute("vector_distance", score)
                 mapped_nodes += 1

-    async def map_vector_distances_to_graph_edges(
+    async def map_vector_distances_to_graph_edges(
+        self, vector_engine, query_vector, edge_distances
+    ) -> None:
         try:
+            if query_vector is None or len(query_vector) == 0:
+                raise ValueError("Failed to generate query embedding.")
+
             if edge_distances is None:
-
+                start_time = time.time()
+                edge_distances = await vector_engine.search(
+                    collection_name="EdgeType_relationship_name",
+                    query_vector=query_vector,
+                    limit=None,
+                )
+                projection_time = time.time() - start_time
+                logger.info(
+                    f"Edge collection distances were calculated separately from nodes in {projection_time:.2f}s"
+                )

             embedding_map = {result.payload["text"]: result.score for result in edge_distances}

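Condensed, the projection strategy that replaces the three deleted helper methods reduces to the branches below (taken from the added lines above; timing and error handling elided):

    # Inside project_graph_from_db, after this change:
    if node_type is not None and node_name not in [None, [], ""]:
        # A named node set was requested -> project only that subgraph.
        nodes_data, edges_data = await adapter.get_nodeset_subgraph(
            node_type=node_type, node_name=node_name
        )
    elif len(memory_fragment_filter) == 0:
        # No attribute filters -> project the full graph.
        nodes_data, edges_data = await adapter.get_graph_data()
    else:
        # Attribute filters -> project the filtered graph.
        nodes_data, edges_data = await adapter.get_filtered_graph_data(
            attribute_filters=memory_fragment_filter
        )
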
cognee/modules/graph/cognee_graph/CogneeGraphElements.py
CHANGED

@@ -20,17 +20,13 @@ class Node:
     status: np.ndarray

     def __init__(
-        self,
-        node_id: str,
-        attributes: Optional[Dict[str, Any]] = None,
-        dimension: int = 1,
-        node_penalty: float = 3.5,
+        self, node_id: str, attributes: Optional[Dict[str, Any]] = None, dimension: int = 1
     ):
         if dimension <= 0:
             raise InvalidDimensionsError()
         self.id = node_id
         self.attributes = attributes if attributes is not None else {}
-        self.attributes["vector_distance"] = node_penalty
+        self.attributes["vector_distance"] = float("inf")
         self.skeleton_neighbours = []
         self.skeleton_edges = []
         self.status = np.ones(dimension, dtype=int)

@@ -109,14 +105,13 @@ class Edge:
         attributes: Optional[Dict[str, Any]] = None,
         directed: bool = True,
         dimension: int = 1,
-        edge_penalty: float = 3.5,
     ):
         if dimension <= 0:
             raise InvalidDimensionsError()
         self.node1 = node1
         self.node2 = node2
         self.attributes = attributes if attributes is not None else {}
-        self.attributes["vector_distance"] = edge_penalty
+        self.attributes["vector_distance"] = float("inf")
         self.directed = directed
         self.status = np.ones(dimension, dtype=int)

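Both constructors now seed vector_distance with float("inf") instead of the removed tunable penalties (node_penalty/edge_penalty, default 3.5), so elements that never receive a similarity score sort strictly last. A small illustration; the node id and attributes are invented:

    from cognee.modules.graph.cognee_graph.CogneeGraphElements import Node

    node = Node("n1", {"name": "example"})  # hypothetical node
    assert node.attributes["vector_distance"] == float("inf")
    # map_vector_distances_to_graph_nodes/_edges later overwrite this attribute
    # with real similarity scores; anything unmatched keeps the infinite default.
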
cognee/modules/memify/memify.py
CHANGED
@@ -12,6 +12,9 @@ from cognee.modules.users.models import User
 from cognee.modules.pipelines.layers.resolve_authorized_user_datasets import (
     resolve_authorized_user_datasets,
 )
+from cognee.modules.pipelines.layers.reset_dataset_pipeline_run_status import (
+    reset_dataset_pipeline_run_status,
+)
 from cognee.modules.engine.operations.setup import setup
 from cognee.modules.pipelines.layers.pipeline_execution_mode import get_pipeline_executor
 from cognee.tasks.memify.extract_subgraph_chunks import extract_subgraph_chunks

@@ -94,6 +97,10 @@ async def memify(
         *enrichment_tasks,
     ]

+    await reset_dataset_pipeline_run_status(
+        authorized_dataset.id, user, pipeline_names=["memify_pipeline"]
+    )
+
     # By calling get pipeline executor we get a function that will have the run_pipeline run in the background or a function that we will need to wait for
     pipeline_executor_func = get_pipeline_executor(run_in_background=run_in_background)

@@ -106,7 +113,6 @@ async def memify(
         datasets=authorized_dataset.id,
         vector_db_config=vector_db_config,
         graph_db_config=graph_db_config,
-        use_pipeline_cache=False,
         incremental_loading=False,
         pipeline_name="memify_pipeline",
     )

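Together with the pipeline.py change below, this swaps the use_pipeline_cache flag for an explicit status reset before each memify run. A minimal sketch of the flow this affects, assuming the top-level wrappers (cognee.add, cognee.cognify, cognee.memify) documented for current releases; exact memify arguments may differ:

    import asyncio

    import cognee

    async def main():
        await cognee.add("Notes about the billing module.")
        await cognee.cognify()
        # With use_pipeline_cache removed, memify resets its own pipeline-run
        # status up front, so repeat calls re-process the dataset instead of
        # being skipped as "already completed".
        await cognee.memify()

    asyncio.run(main())
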
cognee/modules/pipelines/operations/pipeline.py
CHANGED

@@ -20,9 +20,6 @@ from cognee.modules.pipelines.layers.resolve_authorized_user_datasets import (
 from cognee.modules.pipelines.layers.check_pipeline_run_qualification import (
     check_pipeline_run_qualification,
 )
-from cognee.modules.pipelines.models.PipelineRunInfo import (
-    PipelineRunStarted,
-)
 from typing import Any

 logger = get_logger("cognee.pipeline")

@@ -38,7 +35,6 @@ async def run_pipeline(
     pipeline_name: str = "custom_pipeline",
     vector_db_config: dict = None,
     graph_db_config: dict = None,
-    use_pipeline_cache: bool = False,
     incremental_loading: bool = False,
     data_per_batch: int = 20,
 ):

@@ -55,7 +51,6 @@ async def run_pipeline(
         data=data,
         pipeline_name=pipeline_name,
         context={"dataset": dataset},
-        use_pipeline_cache=use_pipeline_cache,
         incremental_loading=incremental_loading,
         data_per_batch=data_per_batch,
     ):

@@ -69,7 +64,6 @@ async def run_pipeline_per_dataset(
     data=None,
     pipeline_name: str = "custom_pipeline",
     context: dict = None,
-    use_pipeline_cache=False,
     incremental_loading=False,
     data_per_batch: int = 20,
 ):

@@ -83,18 +77,8 @@ async def run_pipeline_per_dataset(
     if process_pipeline_status:
         # If pipeline was already processed or is currently being processed
         # return status information to async generator and finish execution
-
-        if use_pipeline_cache:
-            yield process_pipeline_status
-            return
-        else:
-            # If pipeline caching is disabled we always return pipeline started information and proceed with re-processing
-            yield PipelineRunStarted(
-                pipeline_run_id=process_pipeline_status.pipeline_run_id,
-                dataset_id=dataset.id,
-                dataset_name=dataset.name,
-                payload=data,
-            )
+        yield process_pipeline_status
+        return

     pipeline_run = run_tasks(
         tasks,

cognee/modules/retrieval/__init__.py
CHANGED

@@ -1 +1 @@
-
+from cognee.modules.retrieval.code_retriever import CodeRetriever

cognee/modules/retrieval/code_retriever.py
ADDED

@@ -0,0 +1,232 @@
+from typing import Any, Optional, List
+import asyncio
+import aiofiles
+from pydantic import BaseModel
+
+from cognee.shared.logging_utils import get_logger
+from cognee.modules.retrieval.base_retriever import BaseRetriever
+from cognee.infrastructure.databases.graph import get_graph_engine
+from cognee.infrastructure.databases.vector import get_vector_engine
+from cognee.infrastructure.llm.prompts import read_query_prompt
+from cognee.infrastructure.llm.LLMGateway import LLMGateway
+
+logger = get_logger("CodeRetriever")
+
+
+class CodeRetriever(BaseRetriever):
+    """Retriever for handling code-based searches."""
+
+    class CodeQueryInfo(BaseModel):
+        """
+        Model for representing the result of a query related to code files.
+
+        This class holds a list of filenames and the corresponding source code extracted from a
+        query. It is used to encapsulate response data in a structured format.
+        """
+
+        filenames: List[str] = []
+        sourcecode: str
+
+    def __init__(self, top_k: int = 3):
+        """Initialize retriever with search parameters."""
+        self.top_k = top_k
+        self.file_name_collections = ["CodeFile_name"]
+        self.classes_and_functions_collections = [
+            "ClassDefinition_source_code",
+            "FunctionDefinition_source_code",
+        ]
+
+    async def _process_query(self, query: str) -> "CodeRetriever.CodeQueryInfo":
+        """Process the query using LLM to extract file names and source code parts."""
+        logger.debug(
+            f"Processing query with LLM: '{query[:100]}{'...' if len(query) > 100 else ''}'"
+        )
+
+        system_prompt = read_query_prompt("codegraph_retriever_system.txt")
+
+        try:
+            result = await LLMGateway.acreate_structured_output(
+                text_input=query,
+                system_prompt=system_prompt,
+                response_model=self.CodeQueryInfo,
+            )
+            logger.info(
+                f"LLM extracted {len(result.filenames)} filenames and {len(result.sourcecode)} chars of source code"
+            )
+            return result
+        except Exception as e:
+            logger.error(f"Failed to retrieve structured output from LLM: {str(e)}")
+            raise RuntimeError("Failed to retrieve structured output from LLM") from e
+
+    async def get_context(self, query: str) -> Any:
+        """Find relevant code files based on the query."""
+        logger.info(
+            f"Starting code retrieval for query: '{query[:100]}{'...' if len(query) > 100 else ''}'"
+        )
+
+        if not query or not isinstance(query, str):
+            logger.error("Invalid query: must be a non-empty string")
+            raise ValueError("The query must be a non-empty string.")
+
+        try:
+            vector_engine = get_vector_engine()
+            graph_engine = await get_graph_engine()
+            logger.debug("Successfully initialized vector and graph engines")
+        except Exception as e:
+            logger.error(f"Database initialization error: {str(e)}")
+            raise RuntimeError("Database initialization error in code_graph_retriever, ") from e
+
+        files_and_codeparts = await self._process_query(query)
+
+        similar_filenames = []
+        similar_codepieces = []
+
+        if not files_and_codeparts.filenames or not files_and_codeparts.sourcecode:
+            logger.info("No specific files/code extracted from query, performing general search")
+
+            for collection in self.file_name_collections:
+                logger.debug(f"Searching {collection} collection with general query")
+                search_results_file = await vector_engine.search(
+                    collection, query, limit=self.top_k
+                )
+                logger.debug(f"Found {len(search_results_file)} results in {collection}")
+                for res in search_results_file:
+                    similar_filenames.append(
+                        {"id": res.id, "score": res.score, "payload": res.payload}
+                    )
+
+            existing_collection = []
+            for collection in self.classes_and_functions_collections:
+                if await vector_engine.has_collection(collection):
+                    existing_collection.append(collection)
+
+            if not existing_collection:
+                raise RuntimeError("No collection found for code retriever")
+
+            for collection in existing_collection:
+                logger.debug(f"Searching {collection} collection with general query")
+                search_results_code = await vector_engine.search(
+                    collection, query, limit=self.top_k
+                )
+                logger.debug(f"Found {len(search_results_code)} results in {collection}")
+                for res in search_results_code:
+                    similar_codepieces.append(
+                        {"id": res.id, "score": res.score, "payload": res.payload}
+                    )
+        else:
+            logger.info(
+                f"Using extracted filenames ({len(files_and_codeparts.filenames)}) and source code for targeted search"
+            )
+
+            for collection in self.file_name_collections:
+                for file_from_query in files_and_codeparts.filenames:
+                    logger.debug(f"Searching {collection} for specific file: {file_from_query}")
+                    search_results_file = await vector_engine.search(
+                        collection, file_from_query, limit=self.top_k
+                    )
+                    logger.debug(
+                        f"Found {len(search_results_file)} results for file {file_from_query}"
+                    )
+                    for res in search_results_file:
+                        similar_filenames.append(
+                            {"id": res.id, "score": res.score, "payload": res.payload}
+                        )
+
+            for collection in self.classes_and_functions_collections:
+                logger.debug(f"Searching {collection} with extracted source code")
+                search_results_code = await vector_engine.search(
+                    collection, files_and_codeparts.sourcecode, limit=self.top_k
+                )
+                logger.debug(f"Found {len(search_results_code)} results for source code search")
+                for res in search_results_code:
+                    similar_codepieces.append(
+                        {"id": res.id, "score": res.score, "payload": res.payload}
+                    )
+
+        total_items = len(similar_filenames) + len(similar_codepieces)
+        logger.info(
+            f"Total search results: {total_items} items ({len(similar_filenames)} filenames, {len(similar_codepieces)} code pieces)"
+        )
+
+        if total_items == 0:
+            logger.warning("No search results found, returning empty list")
+            return []
+
+        logger.debug("Getting graph connections for all search results")
+        relevant_triplets = await asyncio.gather(
+            *[
+                graph_engine.get_connections(similar_piece["id"])
+                for similar_piece in similar_filenames + similar_codepieces
+            ]
+        )
+        logger.info(f"Retrieved graph connections for {len(relevant_triplets)} items")
+
+        paths = set()
+        for i, sublist in enumerate(relevant_triplets):
+            logger.debug(f"Processing connections for item {i}: {len(sublist)} connections")
+            for tpl in sublist:
+                if isinstance(tpl, tuple) and len(tpl) >= 3:
+                    if "file_path" in tpl[0]:
+                        paths.add(tpl[0]["file_path"])
+                    if "file_path" in tpl[2]:
+                        paths.add(tpl[2]["file_path"])
+
+        logger.info(f"Found {len(paths)} unique file paths to read")
+
+        retrieved_files = {}
+        read_tasks = []
+        for file_path in paths:
+
+            async def read_file(fp):
+                try:
+                    logger.debug(f"Reading file: {fp}")
+                    async with aiofiles.open(fp, "r", encoding="utf-8") as f:
+                        content = await f.read()
+                    retrieved_files[fp] = content
+                    logger.debug(f"Successfully read {len(content)} characters from {fp}")
+                except Exception as e:
+                    logger.error(f"Error reading {fp}: {e}")
+                    retrieved_files[fp] = ""
+
+            read_tasks.append(read_file(file_path))
+
+        await asyncio.gather(*read_tasks)
+        logger.info(
+            f"Successfully read {len([f for f in retrieved_files.values() if f])} files (out of {len(paths)} total)"
+        )
+
+        result = [
+            {
+                "name": file_path,
+                "description": file_path,
+                "content": retrieved_files[file_path],
+            }
+            for file_path in paths
+        ]
+
+        logger.info(f"Returning {len(result)} code file contexts")
+        return result
+
+    async def get_completion(
+        self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None
+    ) -> Any:
+        """
+        Returns the code files context.
+
+        Parameters:
+        -----------
+
+        - query (str): The query string to retrieve code context for.
+        - context (Optional[Any]): Optional pre-fetched context; if None, it retrieves
+          the context for the query. (default None)
+        - session_id (Optional[str]): Optional session identifier for caching. If None,
+          defaults to 'default_session'. (default None)
+
+        Returns:
+        --------
+
+        - Any: The code files context, either provided or retrieved.
+        """
+        if context is None:
+            context = await self.get_context(query)
+        return context

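A possible way to exercise the new retriever end to end, assuming a code graph has already been built (e.g. via the code_graph_pipeline added in this release); the query string is invented for illustration:

    import asyncio

    from cognee.modules.retrieval.code_retriever import CodeRetriever

    async def main():
        retriever = CodeRetriever(top_k=3)
        # get_completion falls through to get_context when no context is given.
        files = await retriever.get_completion("Where is the dependency graph built?")
        for file in files:
            print(file["name"], len(file["content"]), "chars")

    asyncio.run(main())
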
cognee/modules/retrieval/graph_completion_context_extension_retriever.py
CHANGED

@@ -39,8 +39,6 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
         node_type: Optional[Type] = None,
         node_name: Optional[List[str]] = None,
         save_interaction: bool = False,
-        wide_search_top_k: Optional[int] = 100,
-        triplet_distance_penalty: Optional[float] = 3.5,
     ):
         super().__init__(
             user_prompt_path=user_prompt_path,

@@ -50,8 +48,6 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
             node_name=node_name,
             save_interaction=save_interaction,
             system_prompt=system_prompt,
-            wide_search_top_k=wide_search_top_k,
-            triplet_distance_penalty=triplet_distance_penalty,
         )

     async def get_completion(