cognee 0.5.0__py3-none-any.whl → 0.5.0.dev0__py3-none-any.whl
This diff shows the changes between publicly available package versions as they appear in their respective public registries, and is provided for informational purposes only.
- cognee/api/client.py +5 -1
- cognee/api/v1/add/add.py +1 -2
- cognee/api/v1/cognify/code_graph_pipeline.py +119 -0
- cognee/api/v1/cognify/cognify.py +16 -24
- cognee/api/v1/cognify/routers/__init__.py +1 -0
- cognee/api/v1/cognify/routers/get_code_pipeline_router.py +90 -0
- cognee/api/v1/cognify/routers/get_cognify_router.py +1 -3
- cognee/api/v1/datasets/routers/get_datasets_router.py +3 -3
- cognee/api/v1/ontologies/ontologies.py +37 -12
- cognee/api/v1/ontologies/routers/get_ontology_router.py +25 -27
- cognee/api/v1/search/search.py +0 -4
- cognee/api/v1/ui/ui.py +68 -38
- cognee/context_global_variables.py +16 -61
- cognee/eval_framework/answer_generation/answer_generation_executor.py +0 -10
- cognee/eval_framework/answer_generation/run_question_answering_module.py +1 -1
- cognee/eval_framework/corpus_builder/task_getters/get_cascade_graph_tasks.py +2 -0
- cognee/eval_framework/corpus_builder/task_getters/get_default_tasks_by_indices.py +4 -4
- cognee/eval_framework/eval_config.py +2 -2
- cognee/eval_framework/modal_run_eval.py +28 -16
- cognee/infrastructure/databases/graph/config.py +0 -3
- cognee/infrastructure/databases/graph/get_graph_engine.py +0 -1
- cognee/infrastructure/databases/graph/graph_db_interface.py +0 -15
- cognee/infrastructure/databases/graph/kuzu/adapter.py +0 -228
- cognee/infrastructure/databases/graph/neo4j_driver/adapter.py +1 -80
- cognee/infrastructure/databases/utils/__init__.py +0 -3
- cognee/infrastructure/databases/utils/get_or_create_dataset_database.py +48 -62
- cognee/infrastructure/databases/vector/config.py +0 -2
- cognee/infrastructure/databases/vector/create_vector_engine.py +0 -1
- cognee/infrastructure/databases/vector/embeddings/FastembedEmbeddingEngine.py +6 -8
- cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py +7 -9
- cognee/infrastructure/databases/vector/embeddings/OllamaEmbeddingEngine.py +10 -11
- cognee/infrastructure/databases/vector/embeddings/embedding_rate_limiter.py +544 -0
- cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py +0 -2
- cognee/infrastructure/databases/vector/vector_db_interface.py +0 -35
- cognee/infrastructure/files/storage/s3_config.py +0 -2
- cognee/infrastructure/llm/LLMGateway.py +2 -5
- cognee/infrastructure/llm/config.py +0 -35
- cognee/infrastructure/llm/extraction/knowledge_graph/extract_content_graph.py +2 -2
- cognee/infrastructure/llm/structured_output_framework/baml/baml_src/extraction/acreate_structured_output.py +8 -23
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py +16 -17
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py +37 -40
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py +36 -39
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py +1 -19
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py +9 -11
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/adapter.py +21 -23
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py +34 -42
- cognee/modules/cognify/config.py +0 -2
- cognee/modules/data/deletion/prune_system.py +2 -52
- cognee/modules/data/methods/delete_dataset.py +0 -26
- cognee/modules/engine/models/__init__.py +0 -1
- cognee/modules/graph/cognee_graph/CogneeGraph.py +37 -85
- cognee/modules/graph/cognee_graph/CogneeGraphElements.py +3 -8
- cognee/modules/memify/memify.py +7 -1
- cognee/modules/pipelines/operations/pipeline.py +2 -18
- cognee/modules/retrieval/__init__.py +1 -1
- cognee/modules/retrieval/code_retriever.py +232 -0
- cognee/modules/retrieval/graph_completion_context_extension_retriever.py +0 -4
- cognee/modules/retrieval/graph_completion_cot_retriever.py +0 -4
- cognee/modules/retrieval/graph_completion_retriever.py +0 -10
- cognee/modules/retrieval/graph_summary_completion_retriever.py +0 -4
- cognee/modules/retrieval/temporal_retriever.py +0 -4
- cognee/modules/retrieval/utils/brute_force_triplet_search.py +10 -42
- cognee/modules/run_custom_pipeline/run_custom_pipeline.py +1 -8
- cognee/modules/search/methods/get_search_type_tools.py +8 -54
- cognee/modules/search/methods/no_access_control_search.py +0 -4
- cognee/modules/search/methods/search.py +0 -21
- cognee/modules/search/types/SearchType.py +1 -1
- cognee/modules/settings/get_settings.py +0 -19
- cognee/modules/users/methods/get_authenticated_user.py +2 -2
- cognee/modules/users/models/DatasetDatabase.py +3 -15
- cognee/shared/logging_utils.py +0 -4
- cognee/tasks/code/enrich_dependency_graph_checker.py +35 -0
- cognee/tasks/code/get_local_dependencies_checker.py +20 -0
- cognee/tasks/code/get_repo_dependency_graph_checker.py +35 -0
- cognee/tasks/documents/__init__.py +1 -0
- cognee/tasks/documents/check_permissions_on_dataset.py +26 -0
- cognee/tasks/graph/extract_graph_from_data.py +10 -9
- cognee/tasks/repo_processor/__init__.py +2 -0
- cognee/tasks/repo_processor/get_local_dependencies.py +335 -0
- cognee/tasks/repo_processor/get_non_code_files.py +158 -0
- cognee/tasks/repo_processor/get_repo_file_dependencies.py +243 -0
- cognee/tasks/storage/add_data_points.py +2 -142
- cognee/tests/test_cognee_server_start.py +4 -2
- cognee/tests/test_conversation_history.py +1 -23
- cognee/tests/test_delete_bmw_example.py +60 -0
- cognee/tests/test_search_db.py +1 -37
- cognee/tests/unit/api/test_ontology_endpoint.py +89 -77
- cognee/tests/unit/infrastructure/mock_embedding_engine.py +7 -3
- cognee/tests/unit/infrastructure/test_embedding_rate_limiting_realistic.py +5 -0
- cognee/tests/unit/modules/graph/cognee_graph_elements_test.py +2 -2
- cognee/tests/unit/modules/graph/cognee_graph_test.py +0 -406
- {cognee-0.5.0.dist-info → cognee-0.5.0.dev0.dist-info}/METADATA +89 -76
- {cognee-0.5.0.dist-info → cognee-0.5.0.dev0.dist-info}/RECORD +97 -118
- {cognee-0.5.0.dist-info → cognee-0.5.0.dev0.dist-info}/WHEEL +1 -1
- cognee/api/v1/ui/node_setup.py +0 -360
- cognee/api/v1/ui/npm_utils.py +0 -50
- cognee/eval_framework/Dockerfile +0 -29
- cognee/infrastructure/databases/dataset_database_handler/__init__.py +0 -3
- cognee/infrastructure/databases/dataset_database_handler/dataset_database_handler_interface.py +0 -80
- cognee/infrastructure/databases/dataset_database_handler/supported_dataset_database_handlers.py +0 -18
- cognee/infrastructure/databases/dataset_database_handler/use_dataset_database_handler.py +0 -10
- cognee/infrastructure/databases/graph/kuzu/KuzuDatasetDatabaseHandler.py +0 -81
- cognee/infrastructure/databases/graph/neo4j_driver/Neo4jAuraDevDatasetDatabaseHandler.py +0 -168
- cognee/infrastructure/databases/utils/get_graph_dataset_database_handler.py +0 -10
- cognee/infrastructure/databases/utils/get_vector_dataset_database_handler.py +0 -10
- cognee/infrastructure/databases/utils/resolve_dataset_database_connection_info.py +0 -30
- cognee/infrastructure/databases/vector/lancedb/LanceDBDatasetDatabaseHandler.py +0 -50
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/bedrock/__init__.py +0 -5
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/bedrock/adapter.py +0 -153
- cognee/memify_pipelines/create_triplet_embeddings.py +0 -53
- cognee/modules/engine/models/Triplet.py +0 -9
- cognee/modules/retrieval/register_retriever.py +0 -10
- cognee/modules/retrieval/registered_community_retrievers.py +0 -1
- cognee/modules/retrieval/triplet_retriever.py +0 -182
- cognee/shared/rate_limiting.py +0 -30
- cognee/tasks/memify/get_triplet_datapoints.py +0 -289
- cognee/tests/integration/retrieval/test_triplet_retriever.py +0 -84
- cognee/tests/integration/tasks/test_add_data_points.py +0 -139
- cognee/tests/integration/tasks/test_get_triplet_datapoints.py +0 -69
- cognee/tests/test_dataset_database_handler.py +0 -137
- cognee/tests/test_dataset_delete.py +0 -76
- cognee/tests/test_edge_centered_payload.py +0 -170
- cognee/tests/test_pipeline_cache.py +0 -164
- cognee/tests/unit/infrastructure/llm/test_llm_config.py +0 -46
- cognee/tests/unit/modules/memify_tasks/test_get_triplet_datapoints.py +0 -214
- cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py +0 -608
- cognee/tests/unit/modules/retrieval/triplet_retriever_test.py +0 -83
- cognee/tests/unit/tasks/storage/test_add_data_points.py +0 -288
- {cognee-0.5.0.dist-info → cognee-0.5.0.dev0.dist-info}/entry_points.txt +0 -0
- {cognee-0.5.0.dist-info → cognee-0.5.0.dev0.dist-info}/licenses/LICENSE +0 -0
- {cognee-0.5.0.dist-info → cognee-0.5.0.dev0.dist-info}/licenses/NOTICE.md +0 -0
cognee/modules/retrieval/triplet_retriever.py
DELETED
@@ -1,182 +0,0 @@
-import asyncio
-from typing import Any, Optional, Type, List
-
-from cognee.shared.logging_utils import get_logger
-from cognee.infrastructure.databases.vector import get_vector_engine
-from cognee.modules.retrieval.utils.completion import generate_completion, summarize_text
-from cognee.modules.retrieval.utils.session_cache import (
-    save_conversation_history,
-    get_conversation_history,
-)
-from cognee.modules.retrieval.base_retriever import BaseRetriever
-from cognee.modules.retrieval.exceptions.exceptions import NoDataError
-from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError
-from cognee.context_global_variables import session_user
-from cognee.infrastructure.databases.cache.config import CacheConfig
-
-logger = get_logger("TripletRetriever")
-
-
-class TripletRetriever(BaseRetriever):
-    """
-    Retriever for handling LLM-based completion searches using triplets.
-
-    Public methods:
-    - get_context(query: str) -> str
-    - get_completion(query: str, context: Optional[Any] = None) -> Any
-    """
-
-    def __init__(
-        self,
-        user_prompt_path: str = "context_for_question.txt",
-        system_prompt_path: str = "answer_simple_question.txt",
-        system_prompt: Optional[str] = None,
-        top_k: Optional[int] = 5,
-    ):
-        """Initialize retriever with optional custom prompt paths."""
-        self.user_prompt_path = user_prompt_path
-        self.system_prompt_path = system_prompt_path
-        self.top_k = top_k if top_k is not None else 1
-        self.system_prompt = system_prompt
-
-    async def get_context(self, query: str) -> str:
-        """
-        Retrieves relevant triplets as context.
-
-        Fetches triplets based on a query from a vector engine and combines their text.
-        Returns empty string if no triplets are found. Raises NoDataError if the collection is not
-        found.
-
-        Parameters:
-        -----------
-
-        - query (str): The query string used to search for relevant triplets.
-
-        Returns:
-        --------
-
-        - str: A string containing the combined text of the retrieved triplets, or an
-          empty string if none are found.
-        """
-        vector_engine = get_vector_engine()
-
-        try:
-            if not await vector_engine.has_collection(collection_name="Triplet_text"):
-                logger.error("Triplet_text collection not found")
-                raise NoDataError(
-                    "In order to use TRIPLET_COMPLETION first use the create_triplet_embeddings memify pipeline. "
-                )
-
-            found_triplets = await vector_engine.search("Triplet_text", query, limit=self.top_k)
-
-            if len(found_triplets) == 0:
-                return ""
-
-            triplets_payload = [found_triplet.payload["text"] for found_triplet in found_triplets]
-            combined_context = "\n".join(triplets_payload)
-            return combined_context
-        except CollectionNotFoundError as error:
-            logger.error("Triplet_text collection not found")
-            raise NoDataError("No data found in the system, please add data first.") from error
-
-    async def get_completion(
-        self,
-        query: str,
-        context: Optional[Any] = None,
-        session_id: Optional[str] = None,
-        response_model: Type = str,
-    ) -> List[Any]:
-        """
-        Generates an LLM completion using the context.
-
-        Retrieves context if not provided and generates a completion based on the query and
-        context using an external completion generator.
-
-        Parameters:
-        -----------
-
-        - query (str): The query string to be used for generating a completion.
-        - context (Optional[Any]): Optional pre-fetched context to use for generating the
-          completion; if None, it retrieves the context for the query. (default None)
-        - session_id (Optional[str]): Optional session identifier for caching. If None,
-          defaults to 'default_session'. (default None)
-        - response_model (Type): The Pydantic model type for structured output. (default str)
-
-        Returns:
-        --------
-
-        - Any: The generated completion based on the provided query and context.
-        """
-        if context is None:
-            context = await self.get_context(query)
-
-        cache_config = CacheConfig()
-        user = session_user.get()
-        user_id = getattr(user, "id", None)
-        session_save = user_id and cache_config.caching
-
-        if session_save:
-            completion = await self._get_completion_with_session(
-                query=query,
-                context=context,
-                session_id=session_id,
-                response_model=response_model,
-            )
-        else:
-            completion = await self._get_completion_without_session(
-                query=query,
-                context=context,
-                response_model=response_model,
-            )
-
-        return [completion]
-
-    async def _get_completion_with_session(
-        self,
-        query: str,
-        context: str,
-        session_id: Optional[str],
-        response_model: Type,
-    ) -> Any:
-        """Generate completion with session history and caching."""
-        conversation_history = await get_conversation_history(session_id=session_id)
-
-        context_summary, completion = await asyncio.gather(
-            summarize_text(context),
-            generate_completion(
-                query=query,
-                context=context,
-                user_prompt_path=self.user_prompt_path,
-                system_prompt_path=self.system_prompt_path,
-                system_prompt=self.system_prompt,
-                conversation_history=conversation_history,
-                response_model=response_model,
-            ),
-        )
-
-        await save_conversation_history(
-            query=query,
-            context_summary=context_summary,
-            answer=completion,
-            session_id=session_id,
-        )
-
-        return completion
-
-    async def _get_completion_without_session(
-        self,
-        query: str,
-        context: str,
-        response_model: Type,
-    ) -> Any:
-        """Generate completion without session history."""
-        completion = await generate_completion(
-            query=query,
-            context=context,
-            user_prompt_path=self.user_prompt_path,
-            system_prompt_path=self.system_prompt_path,
-            system_prompt=self.system_prompt,
-            response_model=response_model,
-        )
-
-        return completion
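
For orientation, a minimal sketch of how the removed TripletRetriever was driven, inferred only from the signatures and docstrings in the hunk above; the query strings are illustrative, and the import path no longer exists after this change.

```python
# Hypothetical usage of the removed TripletRetriever, reconstructed from the
# signatures above; the query strings are illustrative.
import asyncio

from cognee.modules.retrieval.triplet_retriever import TripletRetriever


async def main():
    retriever = TripletRetriever(top_k=5)

    # get_context searches the "Triplet_text" vector collection and joins the
    # payload text of the top-k hits with newlines (or raises NoDataError if
    # the collection was never built).
    context = await retriever.get_context("Who does Alice know?")

    # get_completion wraps the context in an LLM call and returns a
    # single-element list containing the completion.
    completions = await retriever.get_completion("Who does Alice know?", context=context)
    print(completions[0])


asyncio.run(main())
```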
cognee/shared/rate_limiting.py
DELETED
@@ -1,30 +0,0 @@
-from aiolimiter import AsyncLimiter
-from contextlib import nullcontext
-from cognee.infrastructure.llm.config import get_llm_config
-
-llm_config = get_llm_config()
-
-llm_rate_limiter = AsyncLimiter(
-    llm_config.llm_rate_limit_requests, llm_config.embedding_rate_limit_interval
-)
-embedding_rate_limiter = AsyncLimiter(
-    llm_config.embedding_rate_limit_requests, llm_config.embedding_rate_limit_interval
-)
-
-
-def llm_rate_limiter_context_manager():
-    global llm_rate_limiter
-    if llm_config.llm_rate_limit_enabled:
-        return llm_rate_limiter
-    else:
-        # Return a no-op context manager if rate limiting is disabled
-        return nullcontext()
-
-
-def embedding_rate_limiter_context_manager():
-    global embedding_rate_limiter
-    if llm_config.embedding_rate_limit_enabled:
-        return embedding_rate_limiter
-    else:
-        # Return a no-op context manager if rate limiting is disabled
-        return nullcontext()
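
The removed helpers returned either a shared AsyncLimiter or a nullcontext, both usable as async context managers (nullcontext gained async support in Python 3.10). A hypothetical call site, where the guarded coroutine is a stand-in:

```python
# Hypothetical call site for the removed rate-limiting helpers; some_llm_call
# is a stand-in for a real request.
import asyncio

from cognee.shared.rate_limiting import llm_rate_limiter_context_manager


async def some_llm_call(prompt: str) -> str:
    await asyncio.sleep(0)  # placeholder for network I/O
    return f"echo: {prompt}"


async def call_llm(prompt: str) -> str:
    # When rate limiting is enabled this awaits capacity on the shared
    # AsyncLimiter; when disabled, the nullcontext makes it a no-op.
    async with llm_rate_limiter_context_manager():
        return await some_llm_call(prompt)
```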
cognee/tasks/memify/get_triplet_datapoints.py
DELETED
@@ -1,289 +0,0 @@
-from typing import AsyncGenerator, Dict, Any, List, Optional
-from cognee.infrastructure.databases.graph.get_graph_engine import get_graph_engine
-from cognee.modules.engine.utils import generate_node_id
-from cognee.shared.logging_utils import get_logger
-from cognee.modules.graph.utils.convert_node_to_data_point import get_all_subclasses
-from cognee.infrastructure.engine import DataPoint
-from cognee.modules.engine.models import Triplet
-from cognee.tasks.storage import index_data_points
-
-logger = get_logger("get_triplet_datapoints")
-
-
-def _build_datapoint_type_index_mapping() -> Dict[str, List[str]]:
-    """
-    Build a mapping of DataPoint type names to their index_fields.
-
-    Returns:
-    --------
-    - Dict[str, List[str]]: Mapping of type name to list of index field names
-    """
-    logger.debug("Building DataPoint type to index_fields mapping")
-    subclasses = get_all_subclasses(DataPoint)
-    datapoint_type_index_property = {}
-
-    for subclass in subclasses:
-        if "metadata" in subclass.model_fields:
-            metadata_field = subclass.model_fields["metadata"]
-            default = getattr(metadata_field, "default", None)
-            if isinstance(default, dict):
-                index_fields = default.get("index_fields", [])
-                if index_fields:
-                    datapoint_type_index_property[subclass.__name__] = index_fields
-                    logger.debug(
-                        f"Registered {subclass.__name__} with index_fields: {index_fields}"
-                    )
-
-    logger.info(
-        f"Found {len(datapoint_type_index_property)} DataPoint types with index_fields: "
-        f"{list(datapoint_type_index_property.keys())}"
-    )
-    return datapoint_type_index_property
-
-
-def _extract_embeddable_text(node_or_edge: Dict[str, Any], index_fields: List[str]) -> str:
-    """
-    Extract and concatenate embeddable properties from a node or edge dictionary.
-
-    Parameters:
-    -----------
-    - node_or_edge (Dict[str, Any]): Dictionary containing node or edge properties.
-    - index_fields (List[str]): List of field names to extract and concatenate.
-
-    Returns:
-    --------
-    - str: Concatenated string of all embeddable property values, or empty string if none found.
-    """
-    if not node_or_edge or not index_fields:
-        return ""
-
-    embeddable_values = []
-    for field_name in index_fields:
-        field_value = node_or_edge.get(field_name)
-        if field_value is not None:
-            field_value = str(field_value).strip()
-
-            if field_value:
-                embeddable_values.append(field_value)
-
-    return " ".join(embeddable_values) if embeddable_values else ""
-
-
-def _extract_relationship_text(
-    relationship: Dict[str, Any], datapoint_type_index_property: Dict[str, List[str]]
-) -> str:
-    """
-    Extract relationship text from edge properties.
-
-    Parameters:
-    -----------
-    - relationship (Dict[str, Any]): Dictionary containing relationship properties
-    - datapoint_type_index_property (Dict[str, List[str]]): Mapping of type to index fields
-
-    Returns:
-    --------
-    - str: Extracted relationship text or empty string
-    """
-    if not relationship:
-        return ""
-
-    edge_text = relationship.get("edge_text")
-    if edge_text and isinstance(edge_text, str) and edge_text.strip():
-        return edge_text.strip()
-
-    # Fallback to extracting from EdgeType index_fields
-    edge_type_index_fields = datapoint_type_index_property.get("EdgeType", [])
-    return _extract_embeddable_text(relationship, edge_type_index_fields)
-
-
-def _process_single_triplet(
-    triplet_datapoint: Dict[str, Any],
-    datapoint_type_index_property: Dict[str, List[str]],
-    offset: int,
-    idx: int,
-) -> tuple[Optional[Triplet], Optional[str]]:
-    """
-    Process a single triplet and create a Triplet object.
-
-    Parameters:
-    -----------
-    - triplet_datapoint (Dict[str, Any]): Raw triplet data from graph engine
-    - datapoint_type_index_property (Dict[str, List[str]]): Type to index fields mapping
-    - offset (int): Current batch offset
-    - idx (int): Index within current batch
-
-    Returns:
-    --------
-    - tuple[Optional[Triplet], Optional[str]]: (Triplet object, error message if skipped)
-    """
-    start_node = triplet_datapoint.get("start_node", {})
-    end_node = triplet_datapoint.get("end_node", {})
-    relationship = triplet_datapoint.get("relationship_properties", {})
-
-    start_node_type = start_node.get("type")
-    end_node_type = end_node.get("type")
-
-    start_index_fields = datapoint_type_index_property.get(start_node_type, [])
-    end_index_fields = datapoint_type_index_property.get(end_node_type, [])
-
-    if not start_index_fields:
-        logger.debug(
-            f"No index_fields found for start_node type '{start_node_type}' in triplet {offset + idx}"
-        )
-    if not end_index_fields:
-        logger.debug(
-            f"No index_fields found for end_node type '{end_node_type}' in triplet {offset + idx}"
-        )
-
-    start_node_id = start_node.get("id", "")
-    end_node_id = end_node.get("id", "")
-
-    if not start_node_id or not end_node_id:
-        return None, (
-            f"Skipping triplet at offset {offset + idx}: missing node IDs "
-            f"(start: {start_node_id}, end: {end_node_id})"
-        )
-
-    relationship_text = _extract_relationship_text(relationship, datapoint_type_index_property)
-    start_node_text = _extract_embeddable_text(start_node, start_index_fields)
-    end_node_text = _extract_embeddable_text(end_node, end_index_fields)
-
-    if not start_node_text and not end_node_text and not relationship_text:
-        return None, (
-            f"Skipping triplet at offset {offset + idx}: empty embeddable text "
-            f"(start_node_id: {start_node_id}, end_node_id: {end_node_id})"
        )
-
-    embeddable_text = f"{start_node_text}-›{relationship_text}-›{end_node_text}".strip()
-
-    relationship_name = relationship.get("relationship_name", "")
-    triplet_id = generate_node_id(str(start_node_id) + str(relationship_name) + str(end_node_id))
-
-    triplet_obj = Triplet(
-        id=triplet_id, from_node_id=start_node_id, to_node_id=end_node_id, text=embeddable_text
-    )
-
-    return triplet_obj, None
-
-
-async def get_triplet_datapoints(
-    data,
-    triplets_batch_size: int = 100,
-) -> AsyncGenerator[Triplet, None]:
-    """
-    Async generator that yields batches of triplet datapoints with embeddable text extracted.
-
-    Each triplet in the batch includes:
-    - Original triplet structure (start_node, relationship_properties, end_node)
-    - Extracted embeddable text for each element based on index_fields
-
-    Parameters:
-    -----------
-    - triplets_batch_size (int): Number of triplets to retrieve per batch. Default is 100.
-
-    Yields:
-    -------
-    - List[Dict[str, Any]]: A batch of triplets, each enriched with embeddable text.
-    """
-    if not data or data == [{}]:
-        logger.info("Fetching graph data for current user")
-
-    logger.info(f"Starting triplet datapoints extraction with batch size: {triplets_batch_size}")
-
-    graph_engine = await get_graph_engine()
-    graph_engine_type = type(graph_engine).__name__
-    logger.debug(f"Using graph engine: {graph_engine_type}")
-
-    if not hasattr(graph_engine, "get_triplets_batch"):
-        error_msg = f"Graph adapter {graph_engine_type} does not support get_triplets_batch method"
-        logger.error(error_msg)
-        raise NotImplementedError(error_msg)
-
-    datapoint_type_index_property = _build_datapoint_type_index_mapping()
-
-    offset = 0
-    total_triplets_processed = 0
-    batch_number = 0
-
-    while True:
-        try:
-            batch_number += 1
-            logger.debug(
-                f"Fetching triplet batch {batch_number} (offset: {offset}, limit: {triplets_batch_size})"
-            )
-
-            triplets_batch = await graph_engine.get_triplets_batch(
-                offset=offset, limit=triplets_batch_size
-            )
-
-            if not triplets_batch:
-                logger.info(f"No more triplets found at offset {offset}. Processing complete.")
-                break
-
-            logger.debug(f"Retrieved {len(triplets_batch)} triplets in batch {batch_number}")
-
-            triplet_datapoints = []
-            skipped_count = 0
-
-            for idx, triplet_datapoint in enumerate(triplets_batch):
-                try:
-                    triplet_obj, error_msg = _process_single_triplet(
-                        triplet_datapoint, datapoint_type_index_property, offset, idx
-                    )
-
-                    if error_msg:
-                        logger.warning(error_msg)
-                        skipped_count += 1
-                        continue
-
-                    if triplet_obj:
-                        triplet_datapoints.append(triplet_obj)
-                        yield triplet_obj
-
-                except Exception as e:
-                    logger.warning(
-                        f"Error processing triplet at offset {offset + idx}: {e}. "
-                        f"Skipping this triplet and continuing."
-                    )
-                    skipped_count += 1
-                    continue
-
-            if skipped_count > 0:
-                logger.warning(
-                    f"Skipped {skipped_count} out of {len(triplets_batch)} triplets in batch {batch_number}"
-                )
-
-            if not triplet_datapoints:
-                logger.warning(
-                    f"No valid triplet datapoints in batch {batch_number} after processing"
-                )
-                offset += len(triplets_batch)
-                if len(triplets_batch) < triplets_batch_size:
-                    break
-                continue
-
-            total_triplets_processed += len(triplet_datapoints)
-            logger.info(
-                f"Batch {batch_number} complete: processed {len(triplet_datapoints)} triplets "
-                f"(total processed: {total_triplets_processed})"
-            )
-
-            offset += len(triplets_batch)
-            if len(triplets_batch) < triplets_batch_size:
-                logger.info(
-                    f"Last batch retrieved (got {len(triplets_batch)} < {triplets_batch_size} triplets). "
-                    f"Processing complete."
-                )
-                break
-
-        except Exception as e:
-            logger.error(
-                f"Error retrieving triplet batch {batch_number} at offset {offset}: {e}",
-                exc_info=True,
-            )
-            raise
-
-    logger.info(
-        f"Triplet datapoints extraction complete. "
-        f"Processed {total_triplets_processed} triplets across {batch_number} batch(es)."
-    )
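
A sketch of consuming the removed generator, assuming a graph adapter that implements get_triplets_batch; collecting into a list is illustrative, since in the memify pipeline the yielded Triplet datapoints flowed into downstream storage tasks.

```python
# Hypothetical driver for the removed get_triplet_datapoints generator.
import asyncio

from cognee.tasks.memify.get_triplet_datapoints import get_triplet_datapoints


async def collect_triplets() -> list:
    triplets = []
    # The generator pages through the graph via get_triplets_batch and yields
    # one Triplet datapoint at a time, skipping entries it cannot embed.
    async for triplet in get_triplet_datapoints(data=None, triplets_batch_size=100):
        triplets.append(triplet)
    return triplets


triplets = asyncio.run(collect_triplets())
```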
cognee/tests/integration/retrieval/test_triplet_retriever.py
DELETED
@@ -1,84 +0,0 @@
-import os
-import pytest
-import pathlib
-import pytest_asyncio
-import cognee
-
-from cognee.low_level import setup
-from cognee.tasks.storage import add_data_points
-from cognee.modules.retrieval.exceptions.exceptions import NoDataError
-from cognee.modules.retrieval.triplet_retriever import TripletRetriever
-from cognee.modules.engine.models import Triplet
-
-
-@pytest_asyncio.fixture
-async def setup_test_environment_with_triplets():
-    """Set up a clean test environment with triplets."""
-    base_dir = pathlib.Path(__file__).parent.parent.parent.parent
-    system_directory_path = str(base_dir / ".cognee_system/test_triplet_retriever_context_simple")
-    data_directory_path = str(base_dir / ".data_storage/test_triplet_retriever_context_simple")
-
-    cognee.config.system_root_directory(system_directory_path)
-    cognee.config.data_root_directory(data_directory_path)
-
-    await cognee.prune.prune_data()
-    await cognee.prune.prune_system(metadata=True)
-    await setup()
-
-    triplet1 = Triplet(
-        from_node_id="node1",
-        to_node_id="node2",
-        text="Alice knows Bob",
-    )
-    triplet2 = Triplet(
-        from_node_id="node2",
-        to_node_id="node3",
-        text="Bob works at Tech Corp",
-    )
-
-    triplets = [triplet1, triplet2]
-    await add_data_points(triplets)
-
-    yield
-
-    try:
-        await cognee.prune.prune_data()
-        await cognee.prune.prune_system(metadata=True)
-    except Exception:
-        pass
-
-
-@pytest_asyncio.fixture
-async def setup_test_environment_empty():
-    """Set up a clean test environment without triplets."""
-    base_dir = pathlib.Path(__file__).parent.parent.parent.parent
-    system_directory_path = str(
-        base_dir / ".cognee_system/test_triplet_retriever_context_empty_collection"
-    )
-    data_directory_path = str(
-        base_dir / ".data_storage/test_triplet_retriever_context_empty_collection"
-    )
-
-    cognee.config.system_root_directory(system_directory_path)
-    cognee.config.data_root_directory(data_directory_path)
-
-    await cognee.prune.prune_data()
-    await cognee.prune.prune_system(metadata=True)
-
-    yield
-
-    try:
-        await cognee.prune.prune_data()
-        await cognee.prune.prune_system(metadata=True)
-    except Exception:
-        pass
-
-
-@pytest.mark.asyncio
-async def test_triplet_retriever_context_simple(setup_test_environment_with_triplets):
-    """Integration test: verify TripletRetriever can retrieve triplet context."""
-    retriever = TripletRetriever(top_k=5)
-
-    context = await retriever.get_context("Alice")
-
-    assert "Alice knows Bob" in context, "Failed to get Alice triplet"