cognee 0.5.0__py3-none-any.whl → 0.5.0.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cognee/api/client.py +5 -1
- cognee/api/v1/add/add.py +1 -2
- cognee/api/v1/cognify/code_graph_pipeline.py +119 -0
- cognee/api/v1/cognify/cognify.py +16 -24
- cognee/api/v1/cognify/routers/__init__.py +1 -0
- cognee/api/v1/cognify/routers/get_code_pipeline_router.py +90 -0
- cognee/api/v1/cognify/routers/get_cognify_router.py +1 -3
- cognee/api/v1/datasets/routers/get_datasets_router.py +3 -3
- cognee/api/v1/ontologies/ontologies.py +37 -12
- cognee/api/v1/ontologies/routers/get_ontology_router.py +25 -27
- cognee/api/v1/search/search.py +0 -4
- cognee/api/v1/ui/ui.py +68 -38
- cognee/context_global_variables.py +16 -61
- cognee/eval_framework/answer_generation/answer_generation_executor.py +0 -10
- cognee/eval_framework/answer_generation/run_question_answering_module.py +1 -1
- cognee/eval_framework/corpus_builder/task_getters/get_cascade_graph_tasks.py +2 -0
- cognee/eval_framework/corpus_builder/task_getters/get_default_tasks_by_indices.py +4 -4
- cognee/eval_framework/eval_config.py +2 -2
- cognee/eval_framework/modal_run_eval.py +28 -16
- cognee/infrastructure/databases/graph/config.py +0 -3
- cognee/infrastructure/databases/graph/get_graph_engine.py +0 -1
- cognee/infrastructure/databases/graph/graph_db_interface.py +0 -15
- cognee/infrastructure/databases/graph/kuzu/adapter.py +0 -228
- cognee/infrastructure/databases/graph/neo4j_driver/adapter.py +1 -80
- cognee/infrastructure/databases/utils/__init__.py +0 -3
- cognee/infrastructure/databases/utils/get_or_create_dataset_database.py +48 -62
- cognee/infrastructure/databases/vector/config.py +0 -2
- cognee/infrastructure/databases/vector/create_vector_engine.py +0 -1
- cognee/infrastructure/databases/vector/embeddings/FastembedEmbeddingEngine.py +6 -8
- cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py +7 -9
- cognee/infrastructure/databases/vector/embeddings/OllamaEmbeddingEngine.py +10 -11
- cognee/infrastructure/databases/vector/embeddings/embedding_rate_limiter.py +544 -0
- cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py +0 -2
- cognee/infrastructure/databases/vector/vector_db_interface.py +0 -35
- cognee/infrastructure/files/storage/s3_config.py +0 -2
- cognee/infrastructure/llm/LLMGateway.py +2 -5
- cognee/infrastructure/llm/config.py +0 -35
- cognee/infrastructure/llm/extraction/knowledge_graph/extract_content_graph.py +2 -2
- cognee/infrastructure/llm/structured_output_framework/baml/baml_src/extraction/acreate_structured_output.py +8 -23
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py +16 -17
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py +37 -40
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py +36 -39
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py +1 -19
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py +9 -11
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/adapter.py +21 -23
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py +34 -42
- cognee/modules/cognify/config.py +0 -2
- cognee/modules/data/deletion/prune_system.py +2 -52
- cognee/modules/data/methods/delete_dataset.py +0 -26
- cognee/modules/engine/models/__init__.py +0 -1
- cognee/modules/graph/cognee_graph/CogneeGraph.py +37 -85
- cognee/modules/graph/cognee_graph/CogneeGraphElements.py +3 -8
- cognee/modules/memify/memify.py +7 -1
- cognee/modules/pipelines/operations/pipeline.py +2 -18
- cognee/modules/retrieval/__init__.py +1 -1
- cognee/modules/retrieval/code_retriever.py +232 -0
- cognee/modules/retrieval/graph_completion_context_extension_retriever.py +0 -4
- cognee/modules/retrieval/graph_completion_cot_retriever.py +0 -4
- cognee/modules/retrieval/graph_completion_retriever.py +0 -10
- cognee/modules/retrieval/graph_summary_completion_retriever.py +0 -4
- cognee/modules/retrieval/temporal_retriever.py +0 -4
- cognee/modules/retrieval/utils/brute_force_triplet_search.py +10 -42
- cognee/modules/run_custom_pipeline/run_custom_pipeline.py +1 -8
- cognee/modules/search/methods/get_search_type_tools.py +8 -54
- cognee/modules/search/methods/no_access_control_search.py +0 -4
- cognee/modules/search/methods/search.py +0 -21
- cognee/modules/search/types/SearchType.py +1 -1
- cognee/modules/settings/get_settings.py +0 -19
- cognee/modules/users/methods/get_authenticated_user.py +2 -2
- cognee/modules/users/models/DatasetDatabase.py +3 -15
- cognee/shared/logging_utils.py +0 -4
- cognee/tasks/code/enrich_dependency_graph_checker.py +35 -0
- cognee/tasks/code/get_local_dependencies_checker.py +20 -0
- cognee/tasks/code/get_repo_dependency_graph_checker.py +35 -0
- cognee/tasks/documents/__init__.py +1 -0
- cognee/tasks/documents/check_permissions_on_dataset.py +26 -0
- cognee/tasks/graph/extract_graph_from_data.py +10 -9
- cognee/tasks/repo_processor/__init__.py +2 -0
- cognee/tasks/repo_processor/get_local_dependencies.py +335 -0
- cognee/tasks/repo_processor/get_non_code_files.py +158 -0
- cognee/tasks/repo_processor/get_repo_file_dependencies.py +243 -0
- cognee/tasks/storage/add_data_points.py +2 -142
- cognee/tests/test_cognee_server_start.py +4 -2
- cognee/tests/test_conversation_history.py +1 -23
- cognee/tests/test_delete_bmw_example.py +60 -0
- cognee/tests/test_search_db.py +1 -37
- cognee/tests/unit/api/test_ontology_endpoint.py +89 -77
- cognee/tests/unit/infrastructure/mock_embedding_engine.py +7 -3
- cognee/tests/unit/infrastructure/test_embedding_rate_limiting_realistic.py +5 -0
- cognee/tests/unit/modules/graph/cognee_graph_elements_test.py +2 -2
- cognee/tests/unit/modules/graph/cognee_graph_test.py +0 -406
- {cognee-0.5.0.dist-info → cognee-0.5.0.dev0.dist-info}/METADATA +89 -76
- {cognee-0.5.0.dist-info → cognee-0.5.0.dev0.dist-info}/RECORD +97 -118
- {cognee-0.5.0.dist-info → cognee-0.5.0.dev0.dist-info}/WHEEL +1 -1
- cognee/api/v1/ui/node_setup.py +0 -360
- cognee/api/v1/ui/npm_utils.py +0 -50
- cognee/eval_framework/Dockerfile +0 -29
- cognee/infrastructure/databases/dataset_database_handler/__init__.py +0 -3
- cognee/infrastructure/databases/dataset_database_handler/dataset_database_handler_interface.py +0 -80
- cognee/infrastructure/databases/dataset_database_handler/supported_dataset_database_handlers.py +0 -18
- cognee/infrastructure/databases/dataset_database_handler/use_dataset_database_handler.py +0 -10
- cognee/infrastructure/databases/graph/kuzu/KuzuDatasetDatabaseHandler.py +0 -81
- cognee/infrastructure/databases/graph/neo4j_driver/Neo4jAuraDevDatasetDatabaseHandler.py +0 -168
- cognee/infrastructure/databases/utils/get_graph_dataset_database_handler.py +0 -10
- cognee/infrastructure/databases/utils/get_vector_dataset_database_handler.py +0 -10
- cognee/infrastructure/databases/utils/resolve_dataset_database_connection_info.py +0 -30
- cognee/infrastructure/databases/vector/lancedb/LanceDBDatasetDatabaseHandler.py +0 -50
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/bedrock/__init__.py +0 -5
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/bedrock/adapter.py +0 -153
- cognee/memify_pipelines/create_triplet_embeddings.py +0 -53
- cognee/modules/engine/models/Triplet.py +0 -9
- cognee/modules/retrieval/register_retriever.py +0 -10
- cognee/modules/retrieval/registered_community_retrievers.py +0 -1
- cognee/modules/retrieval/triplet_retriever.py +0 -182
- cognee/shared/rate_limiting.py +0 -30
- cognee/tasks/memify/get_triplet_datapoints.py +0 -289
- cognee/tests/integration/retrieval/test_triplet_retriever.py +0 -84
- cognee/tests/integration/tasks/test_add_data_points.py +0 -139
- cognee/tests/integration/tasks/test_get_triplet_datapoints.py +0 -69
- cognee/tests/test_dataset_database_handler.py +0 -137
- cognee/tests/test_dataset_delete.py +0 -76
- cognee/tests/test_edge_centered_payload.py +0 -170
- cognee/tests/test_pipeline_cache.py +0 -164
- cognee/tests/unit/infrastructure/llm/test_llm_config.py +0 -46
- cognee/tests/unit/modules/memify_tasks/test_get_triplet_datapoints.py +0 -214
- cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py +0 -608
- cognee/tests/unit/modules/retrieval/triplet_retriever_test.py +0 -83
- cognee/tests/unit/tasks/storage/test_add_data_points.py +0 -288
- {cognee-0.5.0.dist-info → cognee-0.5.0.dev0.dist-info}/entry_points.txt +0 -0
- {cognee-0.5.0.dist-info → cognee-0.5.0.dev0.dist-info}/licenses/LICENSE +0 -0
- {cognee-0.5.0.dist-info → cognee-0.5.0.dev0.dist-info}/licenses/NOTICE.md +0 -0
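Reading the file list as a whole: the dev build restores the code-graph feature set (code_graph_pipeline.py, the repo_processor tasks, code_retriever.py) and removes the triplet-embedding and dataset-database-handler machinery. For orientation, a minimal sketch of driving the restored pipeline, assuming code_graph_pipeline.py again exposes a run_code_graph_pipeline coroutine generator as in earlier cognee releases; the function name, signature, and yielded values are assumptions, not shown in this diff:

import asyncio

# Assumed entry point: the module path matches the file list above, but the
# function name and its being an async generator are inferred from earlier
# cognee releases, not confirmed by this diff.
from cognee.api.v1.cognify.code_graph_pipeline import run_code_graph_pipeline


async def main():
    # Build the code graph for a local checkout; each yielded item is assumed
    # to report pipeline run status.
    async for run_status in run_code_graph_pipeline("/path/to/repo"):
        print(run_status)


asyncio.run(main())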
cognee/tasks/repo_processor/get_repo_file_dependencies.py
ADDED
@@ -0,0 +1,243 @@
+import asyncio
+import math
+import os
+from pathlib import Path
+from typing import Set
+from typing import AsyncGenerator, Optional, List
+from uuid import NAMESPACE_OID, uuid5
+
+from cognee.infrastructure.engine import DataPoint
+from cognee.shared.CodeGraphEntities import CodeFile, Repository
+
+# constant, declared only once
+EXCLUDED_DIRS: Set[str] = {
+    ".venv",
+    "venv",
+    "env",
+    ".env",
+    "site-packages",
+    "node_modules",
+    "dist",
+    "build",
+    ".git",
+    "tests",
+    "test",
+}
+
+
+async def get_source_code_files(
+    repo_path,
+    language_config: dict[str, list[str]] | None = None,
+    excluded_paths: Optional[List[str]] = None,
+):
+    """
+    Retrieve Python source code files from the specified repository path.
+
+    This function scans the given repository path for files that have the .py extension
+    while excluding test files and files within a virtual environment. It returns a list of
+    absolute paths to the source code files that are not empty.
+
+    Parameters:
+    -----------
+    - repo_path: Root path of the repository to search
+    - language_config: dict mapping language names to file extensions, e.g.,
+      {'python': ['.py'], 'javascript': ['.js', '.jsx'], ...}
+    - excluded_paths: Optional list of path fragments or glob patterns to exclude
+
+    Returns:
+    --------
+    A list of (absolute_path, language) tuples for source code files.
+    """
+
+    def _get_language_from_extension(file, language_config):
+        for lang, exts in language_config.items():
+            for ext in exts:
+                if file.endswith(ext):
+                    return lang
+        return None
+
+    # Default config if not provided
+    if language_config is None:
+        language_config = {
+            "python": [".py"],
+            "javascript": [".js", ".jsx"],
+            "typescript": [".ts", ".tsx"],
+            "java": [".java"],
+            "csharp": [".cs"],
+            "go": [".go"],
+            "rust": [".rs"],
+            "cpp": [".cpp", ".c", ".h", ".hpp"],
+        }
+
+    if not os.path.exists(repo_path):
+        return []
+
+    source_code_files = set()
+    for root, _, files in os.walk(repo_path):
+        for file in files:
+            lang = _get_language_from_extension(file, language_config)
+            if lang is None:
+                continue
+            # Exclude tests, common build/venv directories and files provided in exclude_paths
+            excluded_dirs = EXCLUDED_DIRS
+            excluded_paths = {Path(p).resolve() for p in (excluded_paths or [])}  # full paths
+
+            root_path = Path(root).resolve()
+            root_parts = set(root_path.parts)  # same as before
+            base_name, _ext = os.path.splitext(file)
+            if (
+                base_name.startswith("test_")
+                or base_name.endswith("_test")
+                or ".test." in file
+                or ".spec." in file
+                or (excluded_dirs & root_parts)  # name match
+                or any(
+                    root_path.is_relative_to(p)  # full-path match
+                    for p in excluded_paths
+                )
+            ):
+                continue
+            file_path = os.path.abspath(os.path.join(root, file))
+            if os.path.getsize(file_path) == 0:
+                continue
+            source_code_files.add((file_path, lang))
+
+    return sorted(list(source_code_files))
+
+
+def run_coroutine(coroutine_func, *args, **kwargs):
+    """
+    Run a coroutine function until it completes.
+
+    This function creates a new asyncio event loop, sets it as the current loop, and
+    executes the given coroutine function with the provided arguments. Once the coroutine
+    completes, the loop is closed. Intended for use in environments where an existing event
+    loop is not available or desirable.
+
+    Parameters:
+    -----------
+
+    - coroutine_func: The coroutine function to be run.
+    - *args: Positional arguments to pass to the coroutine function.
+    - **kwargs: Keyword arguments to pass to the coroutine function.
+
+    Returns:
+    --------
+
+    The result returned by the coroutine after completion.
+    """
+    loop = asyncio.new_event_loop()
+    asyncio.set_event_loop(loop)
+    result = loop.run_until_complete(coroutine_func(*args, **kwargs))
+    loop.close()
+    return result
+
+
+async def get_repo_file_dependencies(
+    repo_path: str,
+    detailed_extraction: bool = False,
+    supported_languages: list = None,
+    excluded_paths: Optional[List[str]] = None,
+) -> AsyncGenerator[DataPoint, None]:
+    """
+    Generate a dependency graph for source files (multi-language) in the given repository path.
+
+    Check the validity of the repository path and yield a repository object followed by the
+    dependencies of source files within that repository. Raise a FileNotFoundError if the
+    provided path does not exist. The extraction of detailed dependencies can be controlled
+    via the `detailed_extraction` argument. Languages considered can be restricted via
+    the `supported_languages` argument.
+
+    Parameters:
+    -----------
+
+    - repo_path (str): The file path to the repository to process.
+    - detailed_extraction (bool): Whether to perform a detailed extraction of code parts.
+    - supported_languages (list | None): Subset of languages to include; if None, use defaults.
+    """
+
+    if isinstance(repo_path, list) and len(repo_path) == 1:
+        repo_path = repo_path[0]
+
+    if not os.path.exists(repo_path):
+        raise FileNotFoundError(f"Repository path {repo_path} does not exist.")
+
+    # Build language config from supported_languages
+    default_language_config = {
+        "python": [".py"],
+        "javascript": [".js", ".jsx"],
+        "typescript": [".ts", ".tsx"],
+        "java": [".java"],
+        "csharp": [".cs"],
+        "go": [".go"],
+        "rust": [".rs"],
+        "cpp": [".cpp", ".c", ".h", ".hpp"],
+        "c": [".c", ".h"],
+    }
+    if supported_languages is not None:
+        language_config = {
+            k: v for k, v in default_language_config.items() if k in supported_languages
+        }
+    else:
+        language_config = default_language_config
+
+    source_code_files = await get_source_code_files(
+        repo_path, language_config=language_config, excluded_paths=excluded_paths
+    )
+
+    repo = Repository(
+        id=uuid5(NAMESPACE_OID, repo_path),
+        path=repo_path,
+    )
+
+    yield repo
+
+    chunk_size = 100
+    number_of_chunks = math.ceil(len(source_code_files) / chunk_size)
+    chunk_ranges = [
+        (
+            chunk_number * chunk_size,
+            min((chunk_number + 1) * chunk_size, len(source_code_files)) - 1,
+        )
+        for chunk_number in range(number_of_chunks)
+    ]
+
+    # Import dependency extractors for each language (Python for now, extend later)
+    from cognee.tasks.repo_processor.get_local_dependencies import get_local_script_dependencies
+    import aiofiles
+    # TODO: Add other language extractors here
+
+    for start_range, end_range in chunk_ranges:
+        tasks = []
+        for file_path, lang in source_code_files[start_range : end_range + 1]:
+            # For now, only Python is supported; extend with other languages
+            if lang == "python":
+                tasks.append(
+                    get_local_script_dependencies(repo_path, file_path, detailed_extraction)
+                )
+            else:
+                # Placeholder: create a minimal CodeFile for other languages
+                async def make_codefile_stub(file_path=file_path, lang=lang):
+                    async with aiofiles.open(
+                        file_path, "r", encoding="utf-8", errors="replace"
+                    ) as f:
+                        source = await f.read()
+                    return CodeFile(
+                        id=uuid5(NAMESPACE_OID, file_path),
+                        name=os.path.relpath(file_path, repo_path),
+                        file_path=file_path,
+                        language=lang,
+                        source_code=source,
+                    )
+
+                tasks.append(make_codefile_stub())
+
+        results: list[CodeFile] = await asyncio.gather(*tasks)
+
+        for source_code_file in results:
+            source_code_file.part_of = repo
+            if getattr(
+                source_code_file, "language", None
+            ) is None and source_code_file.file_path.endswith(".py"):
+                source_code_file.language = "python"
+            yield source_code_file
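A small usage sketch (not part of the diff) for the generator defined above; the repository path and excluded path are placeholders. Per the code, the Repository node is always yielded first, followed by one CodeFile per non-empty source file:

import asyncio

from cognee.tasks.repo_processor.get_repo_file_dependencies import (
    get_repo_file_dependencies,
)


async def collect(repo_path: str):
    # First yielded item is the Repository; the rest are CodeFile data points,
    # produced in chunks of 100 files gathered concurrently.
    data_points = []
    async for data_point in get_repo_file_dependencies(
        repo_path,
        detailed_extraction=False,
        supported_languages=["python"],
        excluded_paths=["/path/to/repo/examples"],  # placeholder
    ):
        data_points.append(data_point)
    return data_points


# asyncio.run(collect("/path/to/repo"))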
cognee/tasks/storage/add_data_points.py
CHANGED
@@ -1,23 +1,16 @@
 import asyncio
-from typing import List
+from typing import List
 from cognee.infrastructure.engine import DataPoint
 from cognee.infrastructure.databases.graph import get_graph_engine
 from cognee.modules.graph.utils import deduplicate_nodes_and_edges, get_graph_from_model
 from .index_data_points import index_data_points
 from .index_graph_edges import index_graph_edges
-from cognee.modules.engine.models import Triplet
-from cognee.shared.logging_utils import get_logger
 from cognee.tasks.storage.exceptions import (
     InvalidDataPointsInAddDataPointsError,
 )
-from ...modules.engine.utils import generate_node_id
 
-logger = get_logger("add_data_points")
 
-
-async def add_data_points(
-    data_points: List[DataPoint], custom_edges: Optional[List] = None, embed_triplets: bool = False
-) -> List[DataPoint]:
+async def add_data_points(data_points: List[DataPoint]) -> List[DataPoint]:
     """
     Add a batch of data points to the graph database by extracting nodes and edges,
     deduplicating them, and indexing them for retrieval.
@@ -30,10 +23,6 @@ async def add_data_points(
     Args:
         data_points (List[DataPoint]):
            A list of data points to process and insert into the graph.
-        custom_edges (List[tuple]): Custom edges between datapoints.
-        embed_triplets (bool):
-            If True, creates and indexes triplet embeddings from the graph structure.
-            Defaults to False.
 
     Returns:
         List[DataPoint]:
@@ -45,7 +34,6 @@ async def add_data_points(
        - Updates the node index via `index_data_points`.
        - Inserts nodes and edges into the graph engine.
        - Optionally updates the edge index via `index_graph_edges`.
-       - Optionally creates and indexes triplet embeddings if embed_triplets is True.
     """
 
     if not isinstance(data_points, list):
@@ -86,132 +74,4 @@ async def add_data_points(
     await graph_engine.add_edges(edges)
     await index_graph_edges(edges)
 
-    if isinstance(custom_edges, list) and custom_edges:
-        # This must be handled separately from datapoint edges, created a task in linear to dig deeper but (COG-3488)
-        await graph_engine.add_edges(custom_edges)
-        await index_graph_edges(custom_edges)
-        edges.extend(custom_edges)
-
-    if embed_triplets:
-        triplets = _create_triplets_from_graph(nodes, edges)
-        if triplets:
-            await index_data_points(triplets)
-            logger.info(f"Created and indexed {len(triplets)} triplets from graph structure")
-
     return data_points
-
-
-def _extract_embeddable_text_from_datapoint(data_point: DataPoint) -> str:
-    """
-    Extract embeddable text from a DataPoint using its index_fields metadata.
-    Uses the same approach as index_data_points.
-
-    Parameters:
-    -----------
-    - data_point (DataPoint): The data point to extract text from.
-
-    Returns:
-    --------
-    - str: Concatenated string of all embeddable property values, or empty string if none found.
-    """
-    if not data_point or not hasattr(data_point, "metadata"):
-        return ""
-
-    index_fields = data_point.metadata.get("index_fields", [])
-    if not index_fields:
-        return ""
-
-    embeddable_values = []
-    for field_name in index_fields:
-        field_value = getattr(data_point, field_name, None)
-        if field_value is not None:
-            field_value = str(field_value).strip()
-
-            if field_value:
-                embeddable_values.append(field_value)
-
-    return " ".join(embeddable_values) if embeddable_values else ""
-
-
-def _create_triplets_from_graph(nodes: List[DataPoint], edges: List[tuple]) -> List[Triplet]:
-    """
-    Create Triplet objects from graph nodes and edges.
-
-    This function processes graph edges and their corresponding nodes to create
-    triplet datapoints with embeddable text, similar to the triplet embeddings pipeline.
-
-    Parameters:
-    -----------
-    - nodes (List[DataPoint]): List of graph nodes extracted from data points
-    - edges (List[tuple]): List of edge tuples in format
-      (source_node_id, target_node_id, relationship_name, properties_dict)
-      Note: All edges including those from DocumentChunk.contains are already extracted
-      by get_graph_from_model and included in this list.
-
-    Returns:
-    --------
-    - List[Triplet]: List of Triplet objects ready for indexing
-    """
-    node_map: Dict[str, DataPoint] = {}
-    for node in nodes:
-        if hasattr(node, "id"):
-            node_id = str(node.id)
-            if node_id not in node_map:
-                node_map[node_id] = node
-
-    triplets = []
-    skipped_count = 0
-    seen_ids = set()
-
-    for edge_tuple in edges:
-        if len(edge_tuple) < 4:
-            continue
-
-        source_node_id, target_node_id, relationship_name, edge_properties = (
-            edge_tuple[0],
-            edge_tuple[1],
-            edge_tuple[2],
-            edge_tuple[3],
-        )
-
-        source_node = node_map.get(str(source_node_id))
-        target_node = node_map.get(str(target_node_id))
-
-        if not source_node or not target_node or relationship_name is None:
-            skipped_count += 1
-            continue
-
-        source_node_text = _extract_embeddable_text_from_datapoint(source_node)
-        target_node_text = _extract_embeddable_text_from_datapoint(target_node)
-
-        relationship_text = ""
-        if isinstance(edge_properties, dict):
-            edge_text = edge_properties.get("edge_text")
-            if edge_text and isinstance(edge_text, str) and edge_text.strip():
-                relationship_text = edge_text.strip()
-
-        if not relationship_text and relationship_name:
-            relationship_text = relationship_name
-
-        if not source_node_text and not relationship_text and not relationship_name:
-            skipped_count += 1
-            continue
-
-        embeddable_text = f"{source_node_text} -> {relationship_text}->{target_node_text}".strip()
-
-        triplet_id = generate_node_id(str(source_node_id) + relationship_name + str(target_node_id))
-
-        if triplet_id in seen_ids:
-            continue
-        seen_ids.add(triplet_id)
-
-        triplets.append(
-            Triplet(
-                id=triplet_id,
-                from_node_id=str(source_node_id),
-                to_node_id=str(target_node_id),
-                text=embeddable_text,
-            )
-        )
-
-    return triplets
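The net effect of this file's changes: add_data_points now accepts only the data points themselves, and the custom_edges and embed_triplets code paths plus both triplet helpers are gone. A migration sketch for callers; MyPoint is a hypothetical DataPoint subclass introduced purely for illustration:

from typing import List

from cognee.infrastructure.engine import DataPoint
from cognee.tasks.storage.add_data_points import add_data_points


class MyPoint(DataPoint):
    # Hypothetical subclass; index_fields marks the attributes that get
    # embedded, mirroring the metadata convention the removed helper relied on.
    name: str
    metadata: dict = {"index_fields": ["name"]}


async def store(points: List[MyPoint]) -> List[MyPoint]:
    # 0.5.0:       await add_data_points(points, custom_edges=None, embed_triplets=False)
    # 0.5.0.dev0:  only the data points are accepted.
    return await add_data_points(points)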
cognee/tests/test_cognee_server_start.py
CHANGED
@@ -25,6 +25,8 @@ class TestCogneeServerStart(unittest.TestCase):
                 "--port",
                 "8000",
             ],
+            stdout=subprocess.PIPE,
+            stderr=subprocess.PIPE,
             preexec_fn=os.setsid,
         )
         # Give the server some time to start
@@ -148,8 +150,8 @@ class TestCogneeServerStart(unittest.TestCase):
             headers=headers,
             files=[("ontology_file", ("test.owl", ontology_content, "application/xml"))],
             data={
-                "ontology_key": ontology_key,
-                "description": "Test ontology",
+                "ontology_key": json.dumps([ontology_key]),
+                "description": json.dumps(["Test ontology"]),
             },
         )
         self.assertEqual(ontology_response.status_code, 200)
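The second hunk reflects an API contract change: the ontology upload endpoint now takes list-valued form fields serialized as JSON strings (one entry per uploaded file) instead of bare strings. A client-side sketch under that contract; the URL, token, and payload values are placeholders mirroring the test, not confirmed by this diff:

import json

import requests

response = requests.post(
    "http://127.0.0.1:8000/api/v1/ontologies",  # assumed endpoint path
    headers={"Authorization": "Bearer <token>"},
    files=[("ontology_file", ("test.owl", b"<rdf:RDF/>", "application/xml"))],
    data={
        # Each field is a JSON-encoded list, parallel to the uploaded files.
        "ontology_key": json.dumps(["my_ontology"]),
        "description": json.dumps(["Test ontology"]),
    },
)
assert response.status_code == 200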
cognee/tests/test_conversation_history.py
CHANGED
@@ -8,10 +8,10 @@ Tests all retrievers that save conversation history to Redis cache:
 4. GRAPH_COMPLETION_CONTEXT_EXTENSION
 5. GRAPH_SUMMARY_COMPLETION
 6. TEMPORAL
-7. TRIPLET_COMPLETION
 """
 
 import os
+import shutil
 import cognee
 import pathlib
 
@@ -63,10 +63,6 @@ async def main():
 
     user = await get_default_user()
 
-    from cognee.memify_pipelines.create_triplet_embeddings import create_triplet_embeddings
-
-    await create_triplet_embeddings(user=user, dataset=dataset_name)
-
     cache_engine = get_cache_engine()
     assert cache_engine is not None, "Cache engine should be available for testing"
 
@@ -220,24 +216,6 @@ async def main():
     ]
     assert len(our_qa_temporal) == 1, "Should find Temporal question in history"
 
-    session_id_triplet = "test_session_triplet"
-
-    result_triplet = await cognee.search(
-        query_type=SearchType.TRIPLET_COMPLETION,
-        query_text="What companies are mentioned?",
-        session_id=session_id_triplet,
-    )
-
-    assert isinstance(result_triplet, list) and len(result_triplet) > 0, (
-        f"TRIPLET_COMPLETION should return non-empty list, got: {result_triplet!r}"
-    )
-
-    history_triplet = await cache_engine.get_latest_qa(str(user.id), session_id_triplet, last_n=10)
-    our_qa_triplet = [
-        h for h in history_triplet if h["question"] == "What companies are mentioned?"
-    ]
-    assert len(our_qa_triplet) == 1, "Should find Triplet question in history"
-
     from cognee.modules.retrieval.utils.session_cache import (
         get_conversation_history,
     )
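With TRIPLET_COMPLETION removed, the remaining retrievers still follow the session-cache pattern this test exercises. A condensed sketch of that pattern; cache_engine and user are obtained exactly as in the test (get_cache_engine() / get_default_user()), and the session id and question are placeholders:

import cognee
from cognee.modules.search.types import SearchType


async def assert_history_roundtrip(cache_engine, user, session_id: str, question: str):
    # Any of the remaining history-saving retrievers works here;
    # TRIPLET_COMPLETION no longer exists in 0.5.0.dev0.
    results = await cognee.search(
        query_type=SearchType.GRAPH_COMPLETION,
        query_text=question,
        session_id=session_id,
    )
    assert isinstance(results, list) and len(results) > 0

    # The answered question should appear in the per-session QA history.
    history = await cache_engine.get_latest_qa(str(user.id), session_id, last_n=10)
    assert any(h["question"] == question for h in history)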
cognee/tests/test_delete_bmw_example.py
ADDED
@@ -0,0 +1,60 @@
+import os
+import pathlib
+from uuid import UUID
+
+import cognee
+
+from cognee.api.v1.datasets import datasets
+from cognee.api.v1.visualize.visualize import visualize_graph
+from cognee.context_global_variables import set_database_global_context_variables
+from cognee.modules.engine.operations.setup import setup
+from cognee.modules.users.methods import get_default_user
+
+# from cognee.modules.engine.operations.setup import setup
+from cognee.shared.logging_utils import get_logger
+
+logger = get_logger()
+
+
+async def main():
+    data_directory_path = os.path.join(
+        pathlib.Path(__file__).parent, ".data_storage/test_delete_bmw_example"
+    )
+    cognee.config.data_root_directory(data_directory_path)
+
+    cognee_directory_path = os.path.join(
+        pathlib.Path(__file__).parent, ".cognee_system/test_delete_bmw_example"
+    )
+    cognee.config.system_root_directory(cognee_directory_path)
+
+    # await cognee.prune.prune_data()
+    # await cognee.prune.prune_system(metadata=True)
+    # await setup()
+
+    # add_result = await cognee.add("Bmw is a german carmanufacturer")
+    # add_result = await cognee.add("Germany is located next to the netherlands")
+    # data_id = add_result.data_ingestion_info[0]["data_id"]
+
+    # cognify_result: dict = await cognee.cognify()
+    # dataset_id = list(cognify_result.keys())[0]
+
+    user = await get_default_user()
+    await set_database_global_context_variables("main_dataset", user.id)
+
+    graph_file_path = os.path.join(data_directory_path, "artifacts/graph-before.html")
+    await visualize_graph(graph_file_path)
+
+    await datasets.delete_data(
+        UUID("b52be2e1-9fdb-5be0-a317-d3a56e9a34c6"),
+        UUID("fdae2cbd-61e1-5e99-93ca-4f3e32ed0d02"),
+        user,
+    )
+
+    graph_file_path = os.path.join(data_directory_path, "artifacts/graph-after.html")
+    await visualize_graph(graph_file_path)
+
+
+if __name__ == "__main__":
+    import asyncio
+
+    asyncio.run(main())
cognee/tests/test_search_db.py
CHANGED
@@ -2,7 +2,6 @@ import pathlib
 import os
 import cognee
 from cognee.infrastructure.databases.graph import get_graph_engine
-from cognee.infrastructure.databases.vector import get_vector_engine
 from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
 from cognee.modules.graph.utils import resolve_edges_to_text
 from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
@@ -13,10 +12,8 @@ from cognee.modules.retrieval.graph_completion_cot_retriever import GraphComplet
 from cognee.modules.retrieval.graph_summary_completion_retriever import (
     GraphSummaryCompletionRetriever,
 )
-from cognee.modules.retrieval.triplet_retriever import TripletRetriever
 from cognee.shared.logging_utils import get_logger
 from cognee.modules.search.types import SearchType
-from cognee.modules.users.methods import get_default_user
 from collections import Counter
 
 logger = get_logger()
@@ -40,23 +37,6 @@ async def main():
 
     await cognee.cognify([dataset_name])
 
-    user = await get_default_user()
-    from cognee.memify_pipelines.create_triplet_embeddings import create_triplet_embeddings
-
-    await create_triplet_embeddings(user=user, dataset=dataset_name, triplets_batch_size=5)
-
-    graph_engine = await get_graph_engine()
-    nodes, edges = await graph_engine.get_graph_data()
-
-    vector_engine = get_vector_engine()
-    collection = await vector_engine.search(
-        query_text="Test", limit=None, collection_name="Triplet_text"
-    )
-
-    assert len(edges) == len(collection), (
-        f"Expected {len(edges)} edges but got {len(collection)} in Triplet_text collection"
-    )
-
     context_gk = await GraphCompletionRetriever().get_context(
         query="Next to which country is Germany located?"
     )
@@ -69,9 +49,6 @@ async def main():
     context_gk_sum = await GraphSummaryCompletionRetriever().get_context(
         query="Next to which country is Germany located?"
     )
-    context_triplet = await TripletRetriever().get_context(
-        query="Next to which country is Germany located?"
-    )
 
     for name, context in [
         ("GraphCompletionRetriever", context_gk),
@@ -88,13 +65,6 @@ async def main():
             f"{name}: Context did not contain 'germany' or 'netherlands'; got: {context!r}"
         )
 
-    assert isinstance(context_triplet, str), "TripletRetriever: Context should be a string"
-    assert len(context_triplet) > 0, "TripletRetriever: Context should not be empty"
-    lower_triplet = context_triplet.lower()
-    assert "germany" in lower_triplet or "netherlands" in lower_triplet, (
-        f"TripletRetriever: Context did not contain 'germany' or 'netherlands'; got: {context_triplet!r}"
-    )
-
     triplets_gk = await GraphCompletionRetriever().get_triplets(
         query="Next to which country is Germany located?"
     )
@@ -159,11 +129,6 @@ async def main():
         query_text="Next to which country is Germany located?",
         save_interaction=True,
     )
-    completion_triplet = await cognee.search(
-        query_type=SearchType.TRIPLET_COMPLETION,
-        query_text="Next to which country is Germany located?",
-        save_interaction=True,
-    )
 
     await cognee.search(
         query_type=SearchType.FEEDBACK,
@@ -176,7 +141,6 @@ async def main():
         ("GRAPH_COMPLETION_COT", completion_cot),
         ("GRAPH_COMPLETION_CONTEXT_EXTENSION", completion_ext),
         ("GRAPH_SUMMARY_COMPLETION", completion_sum),
-        ("TRIPLET_COMPLETION", completion_triplet),
     ]:
         assert isinstance(search_results, list), f"{name}: should return a list"
         assert len(search_results) == 1, (
@@ -204,7 +168,7 @@ async def main():
 
     # Assert there are exactly 4 CogneeUserInteraction nodes.
     assert type_counts.get("CogneeUserInteraction", 0) == 4, (
-        f"Expected exactly four
+        f"Expected exactly four DCogneeUserInteraction nodes, but found {type_counts.get('CogneeUserInteraction', 0)}"
     )
 
     # Assert there is exactly two CogneeUserFeedback nodes.