cognee 0.5.0.dev0__py3-none-any.whl → 0.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cognee/api/client.py +1 -5
- cognee/api/v1/add/add.py +2 -1
- cognee/api/v1/cognify/cognify.py +24 -16
- cognee/api/v1/cognify/routers/__init__.py +0 -1
- cognee/api/v1/cognify/routers/get_cognify_router.py +3 -1
- cognee/api/v1/datasets/routers/get_datasets_router.py +3 -3
- cognee/api/v1/ontologies/ontologies.py +12 -37
- cognee/api/v1/ontologies/routers/get_ontology_router.py +27 -25
- cognee/api/v1/search/search.py +8 -0
- cognee/api/v1/ui/node_setup.py +360 -0
- cognee/api/v1/ui/npm_utils.py +50 -0
- cognee/api/v1/ui/ui.py +38 -68
- cognee/context_global_variables.py +61 -16
- cognee/eval_framework/Dockerfile +29 -0
- cognee/eval_framework/answer_generation/answer_generation_executor.py +10 -0
- cognee/eval_framework/answer_generation/run_question_answering_module.py +1 -1
- cognee/eval_framework/corpus_builder/task_getters/get_cascade_graph_tasks.py +0 -2
- cognee/eval_framework/corpus_builder/task_getters/get_default_tasks_by_indices.py +4 -4
- cognee/eval_framework/eval_config.py +2 -2
- cognee/eval_framework/modal_run_eval.py +16 -28
- cognee/infrastructure/databases/dataset_database_handler/__init__.py +3 -0
- cognee/infrastructure/databases/dataset_database_handler/dataset_database_handler_interface.py +80 -0
- cognee/infrastructure/databases/dataset_database_handler/supported_dataset_database_handlers.py +18 -0
- cognee/infrastructure/databases/dataset_database_handler/use_dataset_database_handler.py +10 -0
- cognee/infrastructure/databases/graph/config.py +3 -0
- cognee/infrastructure/databases/graph/get_graph_engine.py +1 -0
- cognee/infrastructure/databases/graph/graph_db_interface.py +15 -0
- cognee/infrastructure/databases/graph/kuzu/KuzuDatasetDatabaseHandler.py +81 -0
- cognee/infrastructure/databases/graph/kuzu/adapter.py +228 -0
- cognee/infrastructure/databases/graph/neo4j_driver/Neo4jAuraDevDatasetDatabaseHandler.py +168 -0
- cognee/infrastructure/databases/graph/neo4j_driver/adapter.py +80 -1
- cognee/infrastructure/databases/utils/__init__.py +3 -0
- cognee/infrastructure/databases/utils/get_graph_dataset_database_handler.py +10 -0
- cognee/infrastructure/databases/utils/get_or_create_dataset_database.py +62 -48
- cognee/infrastructure/databases/utils/get_vector_dataset_database_handler.py +10 -0
- cognee/infrastructure/databases/utils/resolve_dataset_database_connection_info.py +30 -0
- cognee/infrastructure/databases/vector/config.py +2 -0
- cognee/infrastructure/databases/vector/create_vector_engine.py +1 -0
- cognee/infrastructure/databases/vector/embeddings/FastembedEmbeddingEngine.py +8 -6
- cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py +9 -7
- cognee/infrastructure/databases/vector/embeddings/OllamaEmbeddingEngine.py +11 -10
- cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py +2 -0
- cognee/infrastructure/databases/vector/lancedb/LanceDBDatasetDatabaseHandler.py +50 -0
- cognee/infrastructure/databases/vector/vector_db_interface.py +35 -0
- cognee/infrastructure/files/storage/s3_config.py +2 -0
- cognee/infrastructure/llm/LLMGateway.py +5 -2
- cognee/infrastructure/llm/config.py +35 -0
- cognee/infrastructure/llm/extraction/knowledge_graph/extract_content_graph.py +2 -2
- cognee/infrastructure/llm/structured_output_framework/baml/baml_src/extraction/acreate_structured_output.py +23 -8
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py +17 -16
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/bedrock/__init__.py +5 -0
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/bedrock/adapter.py +153 -0
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py +40 -37
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py +39 -36
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py +19 -1
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py +11 -9
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/adapter.py +23 -21
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py +42 -34
- cognee/memify_pipelines/create_triplet_embeddings.py +53 -0
- cognee/modules/cognify/config.py +2 -0
- cognee/modules/data/deletion/prune_system.py +52 -2
- cognee/modules/data/methods/delete_dataset.py +26 -0
- cognee/modules/engine/models/Triplet.py +9 -0
- cognee/modules/engine/models/__init__.py +1 -0
- cognee/modules/graph/cognee_graph/CogneeGraph.py +85 -37
- cognee/modules/graph/cognee_graph/CogneeGraphElements.py +8 -3
- cognee/modules/memify/memify.py +1 -7
- cognee/modules/pipelines/operations/pipeline.py +18 -2
- cognee/modules/retrieval/__init__.py +1 -1
- cognee/modules/retrieval/graph_completion_context_extension_retriever.py +4 -0
- cognee/modules/retrieval/graph_completion_cot_retriever.py +4 -0
- cognee/modules/retrieval/graph_completion_retriever.py +10 -0
- cognee/modules/retrieval/graph_summary_completion_retriever.py +4 -0
- cognee/modules/retrieval/register_retriever.py +10 -0
- cognee/modules/retrieval/registered_community_retrievers.py +1 -0
- cognee/modules/retrieval/temporal_retriever.py +4 -0
- cognee/modules/retrieval/triplet_retriever.py +182 -0
- cognee/modules/retrieval/utils/brute_force_triplet_search.py +42 -10
- cognee/modules/run_custom_pipeline/run_custom_pipeline.py +8 -1
- cognee/modules/search/methods/get_search_type_tools.py +54 -8
- cognee/modules/search/methods/no_access_control_search.py +4 -0
- cognee/modules/search/methods/search.py +46 -18
- cognee/modules/search/types/SearchType.py +1 -1
- cognee/modules/settings/get_settings.py +19 -0
- cognee/modules/users/methods/get_authenticated_user.py +2 -2
- cognee/modules/users/models/DatasetDatabase.py +15 -3
- cognee/shared/logging_utils.py +4 -0
- cognee/shared/rate_limiting.py +30 -0
- cognee/tasks/documents/__init__.py +0 -1
- cognee/tasks/graph/extract_graph_from_data.py +9 -10
- cognee/tasks/memify/get_triplet_datapoints.py +289 -0
- cognee/tasks/storage/add_data_points.py +142 -2
- cognee/tests/integration/retrieval/test_triplet_retriever.py +84 -0
- cognee/tests/integration/tasks/test_add_data_points.py +139 -0
- cognee/tests/integration/tasks/test_get_triplet_datapoints.py +69 -0
- cognee/tests/test_cognee_server_start.py +2 -4
- cognee/tests/test_conversation_history.py +23 -1
- cognee/tests/test_dataset_database_handler.py +137 -0
- cognee/tests/test_dataset_delete.py +76 -0
- cognee/tests/test_edge_centered_payload.py +170 -0
- cognee/tests/test_pipeline_cache.py +164 -0
- cognee/tests/test_search_db.py +37 -1
- cognee/tests/unit/api/test_ontology_endpoint.py +77 -89
- cognee/tests/unit/infrastructure/llm/test_llm_config.py +46 -0
- cognee/tests/unit/infrastructure/mock_embedding_engine.py +3 -7
- cognee/tests/unit/infrastructure/test_embedding_rate_limiting_realistic.py +0 -5
- cognee/tests/unit/modules/graph/cognee_graph_elements_test.py +2 -2
- cognee/tests/unit/modules/graph/cognee_graph_test.py +406 -0
- cognee/tests/unit/modules/memify_tasks/test_get_triplet_datapoints.py +214 -0
- cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py +608 -0
- cognee/tests/unit/modules/retrieval/triplet_retriever_test.py +83 -0
- cognee/tests/unit/modules/search/test_search.py +100 -0
- cognee/tests/unit/tasks/storage/test_add_data_points.py +288 -0
- {cognee-0.5.0.dev0.dist-info → cognee-0.5.1.dist-info}/METADATA +76 -89
- {cognee-0.5.0.dev0.dist-info → cognee-0.5.1.dist-info}/RECORD +119 -97
- {cognee-0.5.0.dev0.dist-info → cognee-0.5.1.dist-info}/WHEEL +1 -1
- cognee/api/v1/cognify/code_graph_pipeline.py +0 -119
- cognee/api/v1/cognify/routers/get_code_pipeline_router.py +0 -90
- cognee/infrastructure/databases/vector/embeddings/embedding_rate_limiter.py +0 -544
- cognee/modules/retrieval/code_retriever.py +0 -232
- cognee/tasks/code/enrich_dependency_graph_checker.py +0 -35
- cognee/tasks/code/get_local_dependencies_checker.py +0 -20
- cognee/tasks/code/get_repo_dependency_graph_checker.py +0 -35
- cognee/tasks/documents/check_permissions_on_dataset.py +0 -26
- cognee/tasks/repo_processor/__init__.py +0 -2
- cognee/tasks/repo_processor/get_local_dependencies.py +0 -335
- cognee/tasks/repo_processor/get_non_code_files.py +0 -158
- cognee/tasks/repo_processor/get_repo_file_dependencies.py +0 -243
- cognee/tests/test_delete_bmw_example.py +0 -60
- {cognee-0.5.0.dev0.dist-info → cognee-0.5.1.dist-info}/entry_points.txt +0 -0
- {cognee-0.5.0.dev0.dist-info → cognee-0.5.1.dist-info}/licenses/LICENSE +0 -0
- {cognee-0.5.0.dev0.dist-info → cognee-0.5.1.dist-info}/licenses/NOTICE.md +0 -0
|
@@ -0,0 +1,608 @@
|
|
|
1
|
+
import pytest
|
|
2
|
+
from unittest.mock import AsyncMock, patch
|
|
3
|
+
|
|
4
|
+
from cognee.modules.retrieval.utils.brute_force_triplet_search import (
|
|
5
|
+
brute_force_triplet_search,
|
|
6
|
+
get_memory_fragment,
|
|
7
|
+
)
|
|
8
|
+
from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph
|
|
9
|
+
from cognee.modules.graph.exceptions.exceptions import EntityNotFoundError
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class MockScoredResult:
    """Lightweight stand-in for a vector-engine hit exposing ``id``, ``score`` and ``payload``."""

    def __init__(self, id, score, payload=None):
        # Any falsy payload (None or an empty mapping) is replaced by a fresh dict.
        self.payload = payload if payload else {}
        self.id = id
        self.score = score
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
@pytest.mark.asyncio
async def test_brute_force_triplet_search_empty_query():
    """An empty query string must be rejected with a ValueError before any search runs."""
    expected_message = "The query must be a non-empty string."
    with pytest.raises(ValueError, match=expected_message):
        await brute_force_triplet_search(query="")
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
@pytest.mark.asyncio
async def test_brute_force_triplet_search_none_query():
    """A ``None`` query must be rejected with the same ValueError as an empty string."""
    expected_message = "The query must be a non-empty string."
    with pytest.raises(ValueError, match=expected_message):
        await brute_force_triplet_search(query=None)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
@pytest.mark.asyncio
async def test_brute_force_triplet_search_negative_top_k():
    """A negative ``top_k`` must be rejected with a ValueError."""
    invalid_top_k = -1
    with pytest.raises(ValueError, match="top_k must be a positive integer."):
        await brute_force_triplet_search(query="test query", top_k=invalid_top_k)
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
@pytest.mark.asyncio
async def test_brute_force_triplet_search_zero_top_k():
    """A ``top_k`` of zero is not positive and must be rejected with a ValueError."""
    invalid_top_k = 0
    with pytest.raises(ValueError, match="top_k must be a positive integer."):
        await brute_force_triplet_search(query="test query", top_k=invalid_top_k)
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
@pytest.mark.asyncio
async def test_brute_force_triplet_search_wide_search_limit_global_search():
    """Test that wide_search_limit is applied for global search (node_name=None).

    Every vector search issued during a global search must use ``wide_search_top_k``
    as its ``limit``.
    """
    mock_vector_engine = AsyncMock()
    mock_vector_engine.embedding_engine = AsyncMock()
    mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
    mock_vector_engine.search = AsyncMock(return_value=[])

    with patch(
        "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
        return_value=mock_vector_engine,
    ):
        await brute_force_triplet_search(
            query="test",
            node_name=None,  # Global search
            wide_search_top_k=75,
        )

    search_calls = mock_vector_engine.search.call_args_list
    # Without this guard the loop below passes vacuously if search is never called.
    assert search_calls, "expected at least one vector search call"
    for call in search_calls:
        assert call.kwargs["limit"] == 75
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
@pytest.mark.asyncio
async def test_brute_force_triplet_search_wide_search_limit_filtered_search():
    """Test that wide_search_limit is None for filtered search (node_name provided).

    When ``node_name`` restricts the search, every vector search must be issued with
    ``limit=None`` regardless of ``wide_search_top_k``.
    """
    mock_vector_engine = AsyncMock()
    mock_vector_engine.embedding_engine = AsyncMock()
    mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
    mock_vector_engine.search = AsyncMock(return_value=[])

    with patch(
        "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
        return_value=mock_vector_engine,
    ):
        await brute_force_triplet_search(
            query="test",
            node_name=["Node1"],
            wide_search_top_k=50,
        )

    search_calls = mock_vector_engine.search.call_args_list
    # Without this guard the loop below passes vacuously if search is never called.
    assert search_calls, "expected at least one vector search call"
    for call in search_calls:
        assert call.kwargs["limit"] is None
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
@pytest.mark.asyncio
async def test_brute_force_triplet_search_wide_search_default():
    """Test that wide_search_top_k defaults to 100.

    A global search that does not pass ``wide_search_top_k`` must issue every vector
    search with ``limit=100``.
    """
    mock_vector_engine = AsyncMock()
    mock_vector_engine.embedding_engine = AsyncMock()
    mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
    mock_vector_engine.search = AsyncMock(return_value=[])

    with patch(
        "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
        return_value=mock_vector_engine,
    ):
        await brute_force_triplet_search(query="test", node_name=None)

    search_calls = mock_vector_engine.search.call_args_list
    # Without this guard the loop below passes vacuously if search is never called.
    assert search_calls, "expected at least one vector search call"
    for call in search_calls:
        assert call.kwargs["limit"] == 100
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
@pytest.mark.asyncio
async def test_brute_force_triplet_search_default_collections():
    """Without explicit collections, the built-in default list is searched in order."""
    vector_engine = AsyncMock()
    vector_engine.embedding_engine = AsyncMock()
    vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
    vector_engine.search = AsyncMock(return_value=[])

    target = "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine"
    with patch(target, return_value=vector_engine):
        await brute_force_triplet_search(query="test")

    searched = [c[1]["collection_name"] for c in vector_engine.search.call_args_list]
    assert searched == [
        "Entity_name",
        "TextSummary_text",
        "EntityType_name",
        "DocumentChunk_text",
        "EdgeType_relationship_name",
    ]
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
@pytest.mark.asyncio
async def test_brute_force_triplet_search_custom_collections():
    """Caller-supplied collections are searched, plus the mandatory edge collection."""
    requested = ["CustomCol1", "CustomCol2"]

    vector_engine = AsyncMock()
    vector_engine.embedding_engine = AsyncMock()
    vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
    vector_engine.search = AsyncMock(return_value=[])

    target = "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine"
    with patch(target, return_value=vector_engine):
        await brute_force_triplet_search(query="test", collections=requested)

    searched = {c[1]["collection_name"] for c in vector_engine.search.call_args_list}
    assert searched == set(requested) | {"EdgeType_relationship_name"}
|
|
159
|
+
|
|
160
|
+
|
|
161
|
+
@pytest.mark.asyncio
async def test_brute_force_triplet_search_always_includes_edge_collection():
    """The edge collection is searched even when the caller's list omits it."""
    edge_collection = "EdgeType_relationship_name"
    collections_without_edge = ["Entity_name", "TextSummary_text"]

    vector_engine = AsyncMock()
    vector_engine.embedding_engine = AsyncMock()
    vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
    vector_engine.search = AsyncMock(return_value=[])

    target = "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine"
    with patch(target, return_value=vector_engine):
        await brute_force_triplet_search(query="test", collections=collections_without_edge)

    searched = [c[1]["collection_name"] for c in vector_engine.search.call_args_list]
    assert edge_collection in searched
    assert set(searched) == set(collections_without_edge) | {edge_collection}
|
|
184
|
+
|
|
185
|
+
|
|
186
|
+
@pytest.mark.asyncio
async def test_brute_force_triplet_search_all_collections_empty():
    """When no collection yields a hit, the search short-circuits to an empty list."""
    vector_engine = AsyncMock()
    vector_engine.embedding_engine = AsyncMock()
    vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
    vector_engine.search = AsyncMock(return_value=[])

    target = "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine"
    with patch(target, return_value=vector_engine):
        triplets = await brute_force_triplet_search(query="test")
        assert triplets == []
|
|
200
|
+
|
|
201
|
+
|
|
202
|
+
# Tests for query embedding, memory-fragment handling, and node-ID extraction
|
|
203
|
+
|
|
204
|
+
|
|
205
|
+
@pytest.mark.asyncio
async def test_brute_force_triplet_search_embeds_query():
    """Test that query is embedded before searching.

    The query text is embedded exactly once, as a single-element batch. The resulting
    vector is then passed to every collection search.
    """
    query_text = "test query"
    expected_vector = [0.1, 0.2, 0.3]

    mock_vector_engine = AsyncMock()
    mock_vector_engine.embedding_engine = AsyncMock()
    mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[expected_vector])
    mock_vector_engine.search = AsyncMock(return_value=[])

    with patch(
        "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
        return_value=mock_vector_engine,
    ):
        await brute_force_triplet_search(query=query_text)

    mock_vector_engine.embedding_engine.embed_text.assert_called_once_with([query_text])

    search_calls = mock_vector_engine.search.call_args_list
    # Without this guard the loop below passes vacuously if search is never called.
    assert search_calls, "expected at least one vector search call"
    for call in search_calls:
        assert call.kwargs["query_vector"] == expected_vector
|
|
226
|
+
|
|
227
|
+
|
|
228
|
+
@pytest.mark.asyncio
async def test_brute_force_triplet_search_extracts_node_ids_global_search():
    """In a global search, every hit id is forwarded to get_memory_fragment as a filter."""
    hits = [
        MockScoredResult(node_id, score)
        for node_id, score in (("node1", 0.95), ("node2", 0.87), ("node3", 0.92))
    ]

    vector_engine = AsyncMock()
    vector_engine.embedding_engine = AsyncMock()
    vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
    vector_engine.search = AsyncMock(return_value=hits)

    fragment = AsyncMock(
        map_vector_distances_to_graph_nodes=AsyncMock(),
        map_vector_distances_to_graph_edges=AsyncMock(),
        calculate_top_triplet_importances=AsyncMock(return_value=[]),
    )

    module_path = "cognee.modules.retrieval.utils.brute_force_triplet_search"
    with (
        patch(f"{module_path}.get_vector_engine", return_value=vector_engine),
        patch(f"{module_path}.get_memory_fragment", return_value=fragment) as fragment_factory,
    ):
        await brute_force_triplet_search(query="test", node_name=None)

    filter_ids = fragment_factory.call_args[1]["relevant_ids_to_filter"]
    assert set(filter_ids) == {"node1", "node2", "node3"}
|
|
262
|
+
|
|
263
|
+
|
|
264
|
+
@pytest.mark.asyncio
async def test_brute_force_triplet_search_reuses_provided_fragment():
    """A caller-supplied memory fragment is used as-is; no new fragment is built."""
    caller_fragment = AsyncMock(
        map_vector_distances_to_graph_nodes=AsyncMock(),
        map_vector_distances_to_graph_edges=AsyncMock(),
        calculate_top_triplet_importances=AsyncMock(return_value=[]),
    )

    vector_engine = AsyncMock()
    vector_engine.embedding_engine = AsyncMock()
    vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
    vector_engine.search = AsyncMock(return_value=[MockScoredResult("n1", 0.95)])

    module_path = "cognee.modules.retrieval.utils.brute_force_triplet_search"
    with (
        patch(f"{module_path}.get_vector_engine", return_value=vector_engine),
        patch(f"{module_path}.get_memory_fragment") as fragment_factory,
    ):
        await brute_force_triplet_search(
            query="test",
            memory_fragment=caller_fragment,
            node_name=["node"],
        )

    fragment_factory.assert_not_called()
|
|
294
|
+
|
|
295
|
+
|
|
296
|
+
@pytest.mark.asyncio
async def test_brute_force_triplet_search_creates_fragment_when_not_provided():
    """When no fragment is passed, get_memory_fragment is invoked exactly once."""
    vector_engine = AsyncMock()
    vector_engine.embedding_engine = AsyncMock()
    vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
    vector_engine.search = AsyncMock(return_value=[MockScoredResult("n1", 0.95)])

    built_fragment = AsyncMock(
        map_vector_distances_to_graph_nodes=AsyncMock(),
        map_vector_distances_to_graph_edges=AsyncMock(),
        calculate_top_triplet_importances=AsyncMock(return_value=[]),
    )

    module_path = "cognee.modules.retrieval.utils.brute_force_triplet_search"
    with (
        patch(f"{module_path}.get_vector_engine", return_value=vector_engine),
        patch(
            f"{module_path}.get_memory_fragment", return_value=built_fragment
        ) as fragment_factory,
    ):
        await brute_force_triplet_search(query="test", node_name=["node"])

    fragment_factory.assert_called_once()
|
|
323
|
+
|
|
324
|
+
|
|
325
|
+
@pytest.mark.asyncio
async def test_brute_force_triplet_search_passes_top_k_to_importance_calculation():
    """The caller's ``top_k`` is forwarded verbatim as ``k`` to the importance ranking."""
    custom_top_k = 15

    vector_engine = AsyncMock()
    vector_engine.embedding_engine = AsyncMock()
    vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
    vector_engine.search = AsyncMock(return_value=[MockScoredResult("n1", 0.95)])

    fragment = AsyncMock(
        map_vector_distances_to_graph_nodes=AsyncMock(),
        map_vector_distances_to_graph_edges=AsyncMock(),
        calculate_top_triplet_importances=AsyncMock(return_value=[]),
    )

    module_path = "cognee.modules.retrieval.utils.brute_force_triplet_search"
    with (
        patch(f"{module_path}.get_vector_engine", return_value=vector_engine),
        patch(f"{module_path}.get_memory_fragment", return_value=fragment),
    ):
        await brute_force_triplet_search(query="test", top_k=custom_top_k, node_name=["n"])

    fragment.calculate_top_triplet_importances.assert_called_once_with(k=custom_top_k)
|
|
353
|
+
|
|
354
|
+
|
|
355
|
+
@pytest.mark.asyncio
async def test_get_memory_fragment_returns_empty_graph_on_entity_not_found():
    """An EntityNotFoundError during projection yields an empty CogneeGraph, not a crash."""
    graph_engine = AsyncMock()
    graph_engine.project_graph_from_db = AsyncMock(
        side_effect=EntityNotFoundError("Entity not found")
    )

    target = "cognee.modules.retrieval.utils.brute_force_triplet_search.get_graph_engine"
    with patch(target, return_value=graph_engine):
        fragment = await get_memory_fragment()

    assert isinstance(fragment, CogneeGraph)
    assert len(fragment.nodes) == 0
|
|
371
|
+
|
|
372
|
+
|
|
373
|
+
@pytest.mark.asyncio
async def test_get_memory_fragment_returns_empty_graph_on_error():
    """Any generic exception during projection also degrades to an empty CogneeGraph."""
    failure = Exception("Generic error")
    graph_engine = AsyncMock()
    graph_engine.project_graph_from_db = AsyncMock(side_effect=failure)

    target = "cognee.modules.retrieval.utils.brute_force_triplet_search.get_graph_engine"
    with patch(target, return_value=graph_engine):
        fragment = await get_memory_fragment()

    assert isinstance(fragment, CogneeGraph)
    assert len(fragment.nodes) == 0
|
|
387
|
+
|
|
388
|
+
|
|
389
|
+
@pytest.mark.asyncio
async def test_brute_force_triplet_search_deduplicates_node_ids():
    """An id returned by several collections appears only once in the fragment filter."""
    # (id, score) pairs per collection; "node1" is deliberately returned twice.
    hits_by_collection = {
        "Entity_name": [("node1", 0.95), ("node2", 0.87)],
        "TextSummary_text": [("node1", 0.90), ("node3", 0.92)],
    }

    def search_side_effect(*args, **kwargs):
        # Build fresh result objects on every call; unknown collections yield nothing.
        pairs = hits_by_collection.get(kwargs.get("collection_name"), [])
        return [MockScoredResult(node_id, score) for node_id, score in pairs]

    vector_engine = AsyncMock()
    vector_engine.embedding_engine = AsyncMock()
    vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
    vector_engine.search = AsyncMock(side_effect=search_side_effect)

    fragment = AsyncMock(
        map_vector_distances_to_graph_nodes=AsyncMock(),
        map_vector_distances_to_graph_edges=AsyncMock(),
        calculate_top_triplet_importances=AsyncMock(return_value=[]),
    )

    module_path = "cognee.modules.retrieval.utils.brute_force_triplet_search"
    with (
        patch(f"{module_path}.get_vector_engine", return_value=vector_engine),
        patch(f"{module_path}.get_memory_fragment", return_value=fragment) as fragment_factory,
    ):
        await brute_force_triplet_search(query="test", node_name=None)

    filter_ids = fragment_factory.call_args[1]["relevant_ids_to_filter"]
    assert set(filter_ids) == {"node1", "node2", "node3"}
    assert len(filter_ids) == 3
|
|
434
|
+
|
|
435
|
+
|
|
436
|
+
@pytest.mark.asyncio
async def test_brute_force_triplet_search_excludes_edge_collection():
    """Hits from the edge collection are not treated as node ids for the fragment filter."""
    hits_by_collection = {
        "Entity_name": [("node1", 0.95)],
        "EdgeType_relationship_name": [("edge1", 0.88)],
    }

    def search_side_effect(*args, **kwargs):
        pairs = hits_by_collection.get(kwargs.get("collection_name"), [])
        return [MockScoredResult(result_id, score) for result_id, score in pairs]

    vector_engine = AsyncMock()
    vector_engine.embedding_engine = AsyncMock()
    vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
    vector_engine.search = AsyncMock(side_effect=search_side_effect)

    fragment = AsyncMock(
        map_vector_distances_to_graph_nodes=AsyncMock(),
        map_vector_distances_to_graph_edges=AsyncMock(),
        calculate_top_triplet_importances=AsyncMock(return_value=[]),
    )

    module_path = "cognee.modules.retrieval.utils.brute_force_triplet_search"
    with (
        patch(f"{module_path}.get_vector_engine", return_value=vector_engine),
        patch(f"{module_path}.get_memory_fragment", return_value=fragment) as fragment_factory,
    ):
        await brute_force_triplet_search(
            query="test",
            node_name=None,
            collections=["Entity_name", "EdgeType_relationship_name"],
        )

    assert fragment_factory.call_args[1]["relevant_ids_to_filter"] == ["node1"]
|
|
478
|
+
|
|
479
|
+
|
|
480
|
+
@pytest.mark.asyncio
async def test_brute_force_triplet_search_skips_nodes_without_ids():
    """Results lacking an ``id`` attribute are ignored when collecting filter ids."""

    class ScoredResultNoId:
        """Mock result without id attribute."""

        def __init__(self, score):
            self.score = score

    def search_side_effect(*args, **kwargs):
        # Only the entity collection returns anything; its middle hit has no id.
        if kwargs.get("collection_name") != "Entity_name":
            return []
        return [
            MockScoredResult("node1", 0.95),
            ScoredResultNoId(0.90),
            MockScoredResult("node2", 0.87),
        ]

    vector_engine = AsyncMock()
    vector_engine.embedding_engine = AsyncMock()
    vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
    vector_engine.search = AsyncMock(side_effect=search_side_effect)

    fragment = AsyncMock(
        map_vector_distances_to_graph_nodes=AsyncMock(),
        map_vector_distances_to_graph_edges=AsyncMock(),
        calculate_top_triplet_importances=AsyncMock(return_value=[]),
    )

    module_path = "cognee.modules.retrieval.utils.brute_force_triplet_search"
    with (
        patch(f"{module_path}.get_vector_engine", return_value=vector_engine),
        patch(f"{module_path}.get_memory_fragment", return_value=fragment) as fragment_factory,
    ):
        await brute_force_triplet_search(query="test", node_name=None)

    filter_ids = fragment_factory.call_args[1]["relevant_ids_to_filter"]
    assert set(filter_ids) == {"node1", "node2"}
|
|
526
|
+
|
|
527
|
+
|
|
528
|
+
@pytest.mark.asyncio
async def test_brute_force_triplet_search_handles_tuple_results():
    """Test that both list and tuple results are handled correctly.

    Entity_name answers with a *tuple* of hits and EntityType_name with a
    *list* of hits, so IDs coming from both container types must end up in
    the ``relevant_ids_to_filter`` passed to ``get_memory_fragment``.
    """

    def search_side_effect(*args, **kwargs):
        collection_name = kwargs.get("collection_name")
        if collection_name == "Entity_name":
            # Tuple-shaped result set.
            return (
                MockScoredResult("node1", 0.95),
                MockScoredResult("node2", 0.87),
            )
        if collection_name == "EntityType_name":
            # Non-empty list-shaped result set, so the list path is exercised
            # with real hits and not only with empty lists.
            return [MockScoredResult("node3", 0.92)]
        return []

    mock_vector_engine = AsyncMock()
    mock_vector_engine.embedding_engine = AsyncMock()
    mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
    mock_vector_engine.search = AsyncMock(side_effect=search_side_effect)

    mock_fragment = AsyncMock(
        map_vector_distances_to_graph_nodes=AsyncMock(),
        map_vector_distances_to_graph_edges=AsyncMock(),
        calculate_top_triplet_importances=AsyncMock(return_value=[]),
    )

    with (
        patch(
            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
            return_value=mock_vector_engine,
        ),
        patch(
            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
            return_value=mock_fragment,
        ) as mock_get_fragment_fn,
    ):
        await brute_force_triplet_search(query="test", node_name=None)

    call_kwargs = mock_get_fragment_fn.call_args.kwargs
    assert set(call_kwargs["relevant_ids_to_filter"]) == {"node1", "node2", "node3"}
|
|
567
|
+
|
|
568
|
+
|
|
569
|
+
@pytest.mark.asyncio
async def test_brute_force_triplet_search_mixed_empty_collections():
    """IDs are collected across collections even when some of them are empty.

    Entity_name and EntityType_name each contribute one hit while
    TextSummary_text (and every other collection) contributes none.
    """

    def search_side_effect(*args, **kwargs):
        name = kwargs.get("collection_name")
        if name == "Entity_name":
            return [MockScoredResult("node1", 0.95)]
        if name == "EntityType_name":
            return [MockScoredResult("node2", 0.92)]
        # TextSummary_text and any other collection yield no hits.
        return []

    embedding_engine = AsyncMock()
    embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])

    mock_vector_engine = AsyncMock()
    mock_vector_engine.embedding_engine = embedding_engine
    mock_vector_engine.search = AsyncMock(side_effect=search_side_effect)

    fragment = AsyncMock(
        map_vector_distances_to_graph_nodes=AsyncMock(),
        map_vector_distances_to_graph_edges=AsyncMock(),
        calculate_top_triplet_importances=AsyncMock(return_value=[]),
    )

    with (
        patch(
            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
            return_value=mock_vector_engine,
        ),
        patch(
            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
            return_value=fragment,
        ) as get_fragment,
    ):
        await brute_force_triplet_search(query="test", node_name=None)

    filtered_ids = get_fragment.call_args.kwargs["relevant_ids_to_filter"]
    assert set(filtered_ids) == {"node1", "node2"}
|
|
@@ -0,0 +1,83 @@
|
|
|
1
|
+
import pytest
|
|
2
|
+
from unittest.mock import AsyncMock, patch, MagicMock
|
|
3
|
+
|
|
4
|
+
from cognee.modules.retrieval.triplet_retriever import TripletRetriever
|
|
5
|
+
from cognee.modules.retrieval.exceptions.exceptions import NoDataError
|
|
6
|
+
from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
@pytest.fixture
def mock_vector_engine():
    """Provide an async vector-engine double.

    ``has_collection`` reports True by default; ``search`` is a bare AsyncMock
    whose return value each test configures itself.
    """
    return AsyncMock(
        has_collection=AsyncMock(return_value=True),
        search=AsyncMock(),
    )
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
@pytest.mark.asyncio
async def test_get_context_success(mock_vector_engine):
    """get_context joins the ``text`` payload of each hit with newlines.

    Also checks that the search targets Triplet_text with the retriever's top_k
    as the limit.
    """
    mock_vector_engine.search.return_value = [
        MagicMock(payload={"text": text})
        for text in ("Alice knows Bob", "Bob works at Tech Corp")
    ]

    triplet_retriever = TripletRetriever(top_k=5)

    with patch(
        "cognee.modules.retrieval.triplet_retriever.get_vector_engine",
        return_value=mock_vector_engine,
    ):
        result = await triplet_retriever.get_context("test query")

    assert result == "Alice knows Bob\nBob works at Tech Corp"
    mock_vector_engine.search.assert_awaited_once_with("Triplet_text", "test query", limit=5)
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
@pytest.mark.asyncio
async def test_get_context_no_collection(mock_vector_engine):
    """A missing Triplet_text collection raises NoDataError mentioning create_triplet_embeddings."""
    mock_vector_engine.has_collection.return_value = False
    triplet_retriever = TripletRetriever()

    # Patch is entered first and the raises-guard second, as in a nested pair of with-blocks.
    with (
        patch(
            "cognee.modules.retrieval.triplet_retriever.get_vector_engine",
            return_value=mock_vector_engine,
        ),
        pytest.raises(NoDataError, match="create_triplet_embeddings"),
    ):
        await triplet_retriever.get_context("test query")
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
@pytest.mark.asyncio
async def test_get_context_empty_results(mock_vector_engine):
    """When the search finds no triplets, get_context returns an empty string."""
    mock_vector_engine.search.return_value = []
    triplet_retriever = TripletRetriever()

    with patch(
        "cognee.modules.retrieval.triplet_retriever.get_vector_engine",
        return_value=mock_vector_engine,
    ):
        result = await triplet_retriever.get_context("test query")

    assert result == ""
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
@pytest.mark.asyncio
async def test_get_context_collection_not_found_error(mock_vector_engine):
    """A CollectionNotFoundError from the engine is surfaced as NoDataError."""
    mock_vector_engine.has_collection.side_effect = CollectionNotFoundError("Collection not found")
    triplet_retriever = TripletRetriever()

    # Patch is entered first and the raises-guard second, as in a nested pair of with-blocks.
    with (
        patch(
            "cognee.modules.retrieval.triplet_retriever.get_vector_engine",
            return_value=mock_vector_engine,
        ),
        pytest.raises(NoDataError, match="No data found"),
    ):
        await triplet_retriever.get_context("test query")
|