cognee 0.5.1__py3-none-any.whl → 0.5.1.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cognee/api/v1/add/add.py +2 -1
- cognee/api/v1/datasets/routers/get_datasets_router.py +1 -0
- cognee/api/v1/memify/routers/get_memify_router.py +1 -0
- cognee/api/v1/search/search.py +0 -4
- cognee/infrastructure/databases/relational/config.py +16 -1
- cognee/infrastructure/databases/relational/create_relational_engine.py +13 -3
- cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py +24 -2
- cognee/infrastructure/databases/vector/create_vector_engine.py +9 -2
- cognee/infrastructure/llm/LLMGateway.py +0 -13
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py +17 -12
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py +31 -25
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py +132 -7
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py +5 -5
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/llm_interface.py +2 -6
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py +58 -13
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/adapter.py +0 -1
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py +25 -131
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/types.py +10 -0
- cognee/modules/data/models/Data.py +2 -1
- cognee/modules/retrieval/triplet_retriever.py +1 -1
- cognee/modules/retrieval/utils/brute_force_triplet_search.py +0 -18
- cognee/modules/search/methods/search.py +18 -25
- cognee/tasks/ingestion/data_item.py +8 -0
- cognee/tasks/ingestion/ingest_data.py +12 -1
- cognee/tasks/ingestion/save_data_item_to_storage.py +5 -0
- cognee/tests/integration/retrieval/test_chunks_retriever.py +252 -0
- cognee/tests/integration/retrieval/test_graph_completion_retriever.py +268 -0
- cognee/tests/integration/retrieval/test_graph_completion_retriever_context_extension.py +226 -0
- cognee/tests/integration/retrieval/test_graph_completion_retriever_cot.py +218 -0
- cognee/tests/integration/retrieval/test_rag_completion_retriever.py +254 -0
- cognee/tests/{unit/modules/retrieval/structured_output_test.py → integration/retrieval/test_structured_output.py} +87 -77
- cognee/tests/integration/retrieval/test_summaries_retriever.py +184 -0
- cognee/tests/integration/retrieval/test_temporal_retriever.py +306 -0
- cognee/tests/integration/retrieval/test_triplet_retriever.py +35 -0
- cognee/tests/test_custom_data_label.py +68 -0
- cognee/tests/test_search_db.py +334 -181
- cognee/tests/unit/eval_framework/benchmark_adapters_test.py +25 -0
- cognee/tests/unit/eval_framework/corpus_builder_test.py +33 -4
- cognee/tests/unit/infrastructure/databases/relational/test_RelationalConfig.py +69 -0
- cognee/tests/unit/modules/retrieval/chunks_retriever_test.py +181 -199
- cognee/tests/unit/modules/retrieval/conversation_history_test.py +338 -0
- cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py +454 -162
- cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py +674 -156
- cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py +625 -200
- cognee/tests/unit/modules/retrieval/rag_completion_retriever_test.py +319 -203
- cognee/tests/unit/modules/retrieval/summaries_retriever_test.py +189 -155
- cognee/tests/unit/modules/retrieval/temporal_retriever_test.py +539 -58
- cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py +218 -9
- cognee/tests/unit/modules/retrieval/test_completion.py +343 -0
- cognee/tests/unit/modules/retrieval/test_graph_summary_completion_retriever.py +157 -0
- cognee/tests/unit/modules/retrieval/test_user_qa_feedback.py +312 -0
- cognee/tests/unit/modules/retrieval/triplet_retriever_test.py +246 -0
- {cognee-0.5.1.dist-info → cognee-0.5.1.dev0.dist-info}/METADATA +1 -1
- {cognee-0.5.1.dist-info → cognee-0.5.1.dev0.dist-info}/RECORD +58 -45
- cognee/tests/unit/modules/search/test_search.py +0 -100
- {cognee-0.5.1.dist-info → cognee-0.5.1.dev0.dist-info}/WHEEL +0 -0
- {cognee-0.5.1.dist-info → cognee-0.5.1.dev0.dist-info}/entry_points.txt +0 -0
- {cognee-0.5.1.dist-info → cognee-0.5.1.dev0.dist-info}/licenses/LICENSE +0 -0
- {cognee-0.5.1.dist-info → cognee-0.5.1.dev0.dist-info}/licenses/NOTICE.md +0 -0
|
@@ -1,205 +1,321 @@
|
|
|
1
|
-
import os
|
|
2
|
-
from typing import List
|
|
3
1
|
import pytest
|
|
4
|
-
import
|
|
5
|
-
|
|
6
|
-
|
|
7
|
-
from cognee.low_level import setup
|
|
8
|
-
from cognee.tasks.storage import add_data_points
|
|
9
|
-
from cognee.infrastructure.databases.vector import get_vector_engine
|
|
10
|
-
from cognee.modules.chunking.models import DocumentChunk
|
|
11
|
-
from cognee.modules.data.processing.document_types import TextDocument
|
|
12
|
-
from cognee.modules.retrieval.exceptions.exceptions import NoDataError
|
|
2
|
+
from unittest.mock import AsyncMock, patch, MagicMock
|
|
3
|
+
|
|
13
4
|
from cognee.modules.retrieval.completion_retriever import CompletionRetriever
|
|
14
|
-
from cognee.
|
|
15
|
-
from cognee.
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
await
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
)
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
)
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
5
|
+
from cognee.modules.retrieval.exceptions.exceptions import NoDataError
|
|
6
|
+
from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
@pytest.fixture
|
|
10
|
+
def mock_vector_engine():
|
|
11
|
+
"""Create a mock vector engine."""
|
|
12
|
+
engine = AsyncMock()
|
|
13
|
+
engine.search = AsyncMock()
|
|
14
|
+
return engine
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
@pytest.mark.asyncio
|
|
18
|
+
async def test_get_context_success(mock_vector_engine):
|
|
19
|
+
"""Test successful retrieval of context."""
|
|
20
|
+
mock_result1 = MagicMock()
|
|
21
|
+
mock_result1.payload = {"text": "Steve Rodger"}
|
|
22
|
+
mock_result2 = MagicMock()
|
|
23
|
+
mock_result2.payload = {"text": "Mike Broski"}
|
|
24
|
+
|
|
25
|
+
mock_vector_engine.search.return_value = [mock_result1, mock_result2]
|
|
26
|
+
|
|
27
|
+
retriever = CompletionRetriever(top_k=2)
|
|
28
|
+
|
|
29
|
+
with patch(
|
|
30
|
+
"cognee.modules.retrieval.completion_retriever.get_vector_engine",
|
|
31
|
+
return_value=mock_vector_engine,
|
|
32
|
+
):
|
|
33
|
+
context = await retriever.get_context("test query")
|
|
34
|
+
|
|
35
|
+
assert context == "Steve Rodger\nMike Broski"
|
|
36
|
+
mock_vector_engine.search.assert_awaited_once_with("DocumentChunk_text", "test query", limit=2)
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
@pytest.mark.asyncio
|
|
40
|
+
async def test_get_context_collection_not_found_error(mock_vector_engine):
|
|
41
|
+
"""Test that CollectionNotFoundError is converted to NoDataError."""
|
|
42
|
+
mock_vector_engine.search.side_effect = CollectionNotFoundError("Collection not found")
|
|
43
|
+
|
|
44
|
+
retriever = CompletionRetriever()
|
|
45
|
+
|
|
46
|
+
with patch(
|
|
47
|
+
"cognee.modules.retrieval.completion_retriever.get_vector_engine",
|
|
48
|
+
return_value=mock_vector_engine,
|
|
49
|
+
):
|
|
50
|
+
with pytest.raises(NoDataError, match="No data found"):
|
|
51
|
+
await retriever.get_context("test query")
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
@pytest.mark.asyncio
|
|
55
|
+
async def test_get_context_empty_results(mock_vector_engine):
|
|
56
|
+
"""Test that empty string is returned when no chunks are found."""
|
|
57
|
+
mock_vector_engine.search.return_value = []
|
|
58
|
+
|
|
59
|
+
retriever = CompletionRetriever()
|
|
60
|
+
|
|
61
|
+
with patch(
|
|
62
|
+
"cognee.modules.retrieval.completion_retriever.get_vector_engine",
|
|
63
|
+
return_value=mock_vector_engine,
|
|
64
|
+
):
|
|
65
|
+
context = await retriever.get_context("test query")
|
|
66
|
+
|
|
67
|
+
assert context == ""
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
@pytest.mark.asyncio
|
|
71
|
+
async def test_get_context_top_k_limit(mock_vector_engine):
|
|
72
|
+
"""Test that top_k parameter limits the number of results."""
|
|
73
|
+
mock_results = [MagicMock() for _ in range(2)]
|
|
74
|
+
for i, result in enumerate(mock_results):
|
|
75
|
+
result.payload = {"text": f"Chunk {i}"}
|
|
76
|
+
|
|
77
|
+
mock_vector_engine.search.return_value = mock_results
|
|
78
|
+
|
|
79
|
+
retriever = CompletionRetriever(top_k=2)
|
|
80
|
+
|
|
81
|
+
with patch(
|
|
82
|
+
"cognee.modules.retrieval.completion_retriever.get_vector_engine",
|
|
83
|
+
return_value=mock_vector_engine,
|
|
84
|
+
):
|
|
85
|
+
context = await retriever.get_context("test query")
|
|
86
|
+
|
|
87
|
+
assert context == "Chunk 0\nChunk 1"
|
|
88
|
+
mock_vector_engine.search.assert_awaited_once_with("DocumentChunk_text", "test query", limit=2)
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
@pytest.mark.asyncio
|
|
92
|
+
async def test_get_context_single_chunk(mock_vector_engine):
|
|
93
|
+
"""Test get_context with single chunk result."""
|
|
94
|
+
mock_result = MagicMock()
|
|
95
|
+
mock_result.payload = {"text": "Single chunk text"}
|
|
96
|
+
mock_vector_engine.search.return_value = [mock_result]
|
|
97
|
+
|
|
98
|
+
retriever = CompletionRetriever()
|
|
99
|
+
|
|
100
|
+
with patch(
|
|
101
|
+
"cognee.modules.retrieval.completion_retriever.get_vector_engine",
|
|
102
|
+
return_value=mock_vector_engine,
|
|
103
|
+
):
|
|
104
|
+
context = await retriever.get_context("test query")
|
|
105
|
+
|
|
106
|
+
assert context == "Single chunk text"
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
@pytest.mark.asyncio
|
|
110
|
+
async def test_get_completion_without_session(mock_vector_engine):
|
|
111
|
+
"""Test get_completion without session caching."""
|
|
112
|
+
mock_result = MagicMock()
|
|
113
|
+
mock_result.payload = {"text": "Chunk text"}
|
|
114
|
+
mock_vector_engine.search.return_value = [mock_result]
|
|
115
|
+
|
|
116
|
+
retriever = CompletionRetriever()
|
|
117
|
+
|
|
118
|
+
with (
|
|
119
|
+
patch(
|
|
120
|
+
"cognee.modules.retrieval.completion_retriever.get_vector_engine",
|
|
121
|
+
return_value=mock_vector_engine,
|
|
122
|
+
),
|
|
123
|
+
patch(
|
|
124
|
+
"cognee.modules.retrieval.completion_retriever.generate_completion",
|
|
125
|
+
return_value="Generated answer",
|
|
126
|
+
),
|
|
127
|
+
patch("cognee.modules.retrieval.completion_retriever.CacheConfig") as mock_cache_config,
|
|
128
|
+
):
|
|
129
|
+
mock_config = MagicMock()
|
|
130
|
+
mock_config.caching = False
|
|
131
|
+
mock_cache_config.return_value = mock_config
|
|
132
|
+
|
|
133
|
+
completion = await retriever.get_completion("test query")
|
|
134
|
+
|
|
135
|
+
assert isinstance(completion, list)
|
|
136
|
+
assert len(completion) == 1
|
|
137
|
+
assert completion[0] == "Generated answer"
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
@pytest.mark.asyncio
|
|
141
|
+
async def test_get_completion_with_provided_context(mock_vector_engine):
|
|
142
|
+
"""Test get_completion with provided context."""
|
|
143
|
+
retriever = CompletionRetriever()
|
|
144
|
+
|
|
145
|
+
with (
|
|
146
|
+
patch(
|
|
147
|
+
"cognee.modules.retrieval.completion_retriever.generate_completion",
|
|
148
|
+
return_value="Generated answer",
|
|
149
|
+
),
|
|
150
|
+
patch("cognee.modules.retrieval.completion_retriever.CacheConfig") as mock_cache_config,
|
|
151
|
+
):
|
|
152
|
+
mock_config = MagicMock()
|
|
153
|
+
mock_config.caching = False
|
|
154
|
+
mock_cache_config.return_value = mock_config
|
|
155
|
+
|
|
156
|
+
completion = await retriever.get_completion("test query", context="Provided context")
|
|
157
|
+
|
|
158
|
+
assert isinstance(completion, list)
|
|
159
|
+
assert len(completion) == 1
|
|
160
|
+
assert completion[0] == "Generated answer"
|
|
161
|
+
|
|
162
|
+
|
|
163
|
+
@pytest.mark.asyncio
|
|
164
|
+
async def test_get_completion_with_session(mock_vector_engine):
|
|
165
|
+
"""Test get_completion with session caching enabled."""
|
|
166
|
+
mock_result = MagicMock()
|
|
167
|
+
mock_result.payload = {"text": "Chunk text"}
|
|
168
|
+
mock_vector_engine.search.return_value = [mock_result]
|
|
169
|
+
|
|
170
|
+
retriever = CompletionRetriever()
|
|
171
|
+
|
|
172
|
+
mock_user = MagicMock()
|
|
173
|
+
mock_user.id = "test-user-id"
|
|
174
|
+
|
|
175
|
+
with (
|
|
176
|
+
patch(
|
|
177
|
+
"cognee.modules.retrieval.completion_retriever.get_vector_engine",
|
|
178
|
+
return_value=mock_vector_engine,
|
|
179
|
+
),
|
|
180
|
+
patch(
|
|
181
|
+
"cognee.modules.retrieval.completion_retriever.get_conversation_history",
|
|
182
|
+
return_value="Previous conversation",
|
|
183
|
+
),
|
|
184
|
+
patch(
|
|
185
|
+
"cognee.modules.retrieval.completion_retriever.summarize_text",
|
|
186
|
+
return_value="Context summary",
|
|
187
|
+
),
|
|
188
|
+
patch(
|
|
189
|
+
"cognee.modules.retrieval.completion_retriever.generate_completion",
|
|
190
|
+
return_value="Generated answer",
|
|
191
|
+
),
|
|
192
|
+
patch(
|
|
193
|
+
"cognee.modules.retrieval.completion_retriever.save_conversation_history",
|
|
194
|
+
) as mock_save,
|
|
195
|
+
patch("cognee.modules.retrieval.completion_retriever.CacheConfig") as mock_cache_config,
|
|
196
|
+
patch("cognee.modules.retrieval.completion_retriever.session_user") as mock_session_user,
|
|
197
|
+
):
|
|
198
|
+
mock_config = MagicMock()
|
|
199
|
+
mock_config.caching = True
|
|
200
|
+
mock_cache_config.return_value = mock_config
|
|
201
|
+
mock_session_user.get.return_value = mock_user
|
|
202
|
+
|
|
203
|
+
completion = await retriever.get_completion("test query", session_id="test_session")
|
|
204
|
+
|
|
205
|
+
assert isinstance(completion, list)
|
|
206
|
+
assert len(completion) == 1
|
|
207
|
+
assert completion[0] == "Generated answer"
|
|
208
|
+
mock_save.assert_awaited_once()
|
|
209
|
+
|
|
210
|
+
|
|
211
|
+
@pytest.mark.asyncio
|
|
212
|
+
async def test_get_completion_with_session_no_user_id(mock_vector_engine):
|
|
213
|
+
"""Test get_completion with session config but no user ID."""
|
|
214
|
+
mock_result = MagicMock()
|
|
215
|
+
mock_result.payload = {"text": "Chunk text"}
|
|
216
|
+
mock_vector_engine.search.return_value = [mock_result]
|
|
217
|
+
|
|
218
|
+
retriever = CompletionRetriever()
|
|
219
|
+
|
|
220
|
+
with (
|
|
221
|
+
patch(
|
|
222
|
+
"cognee.modules.retrieval.completion_retriever.get_vector_engine",
|
|
223
|
+
return_value=mock_vector_engine,
|
|
224
|
+
),
|
|
225
|
+
patch(
|
|
226
|
+
"cognee.modules.retrieval.completion_retriever.generate_completion",
|
|
227
|
+
return_value="Generated answer",
|
|
228
|
+
),
|
|
229
|
+
patch("cognee.modules.retrieval.completion_retriever.CacheConfig") as mock_cache_config,
|
|
230
|
+
patch("cognee.modules.retrieval.completion_retriever.session_user") as mock_session_user,
|
|
231
|
+
):
|
|
232
|
+
mock_config = MagicMock()
|
|
233
|
+
mock_config.caching = True
|
|
234
|
+
mock_cache_config.return_value = mock_config
|
|
235
|
+
mock_session_user.get.return_value = None # No user
|
|
236
|
+
|
|
237
|
+
completion = await retriever.get_completion("test query")
|
|
238
|
+
|
|
239
|
+
assert isinstance(completion, list)
|
|
240
|
+
assert len(completion) == 1
|
|
241
|
+
|
|
242
|
+
|
|
243
|
+
@pytest.mark.asyncio
|
|
244
|
+
async def test_get_completion_with_response_model(mock_vector_engine):
|
|
245
|
+
"""Test get_completion with custom response model."""
|
|
246
|
+
from pydantic import BaseModel
|
|
247
|
+
|
|
248
|
+
class TestModel(BaseModel):
|
|
249
|
+
answer: str
|
|
250
|
+
|
|
251
|
+
mock_result = MagicMock()
|
|
252
|
+
mock_result.payload = {"text": "Chunk text"}
|
|
253
|
+
mock_vector_engine.search.return_value = [mock_result]
|
|
254
|
+
|
|
255
|
+
retriever = CompletionRetriever()
|
|
256
|
+
|
|
257
|
+
with (
|
|
258
|
+
patch(
|
|
259
|
+
"cognee.modules.retrieval.completion_retriever.get_vector_engine",
|
|
260
|
+
return_value=mock_vector_engine,
|
|
261
|
+
),
|
|
262
|
+
patch(
|
|
263
|
+
"cognee.modules.retrieval.completion_retriever.generate_completion",
|
|
264
|
+
return_value=TestModel(answer="Test answer"),
|
|
265
|
+
),
|
|
266
|
+
patch("cognee.modules.retrieval.completion_retriever.CacheConfig") as mock_cache_config,
|
|
267
|
+
):
|
|
268
|
+
mock_config = MagicMock()
|
|
269
|
+
mock_config.caching = False
|
|
270
|
+
mock_cache_config.return_value = mock_config
|
|
271
|
+
|
|
272
|
+
completion = await retriever.get_completion("test query", response_model=TestModel)
|
|
273
|
+
|
|
274
|
+
assert isinstance(completion, list)
|
|
275
|
+
assert len(completion) == 1
|
|
276
|
+
assert isinstance(completion[0], TestModel)
|
|
277
|
+
|
|
278
|
+
|
|
279
|
+
@pytest.mark.asyncio
|
|
280
|
+
async def test_init_defaults():
|
|
281
|
+
"""Test CompletionRetriever initialization with defaults."""
|
|
282
|
+
retriever = CompletionRetriever()
|
|
283
|
+
|
|
284
|
+
assert retriever.user_prompt_path == "context_for_question.txt"
|
|
285
|
+
assert retriever.system_prompt_path == "answer_simple_question.txt"
|
|
286
|
+
assert retriever.top_k == 1
|
|
287
|
+
assert retriever.system_prompt is None
|
|
288
|
+
|
|
289
|
+
|
|
290
|
+
@pytest.mark.asyncio
|
|
291
|
+
async def test_init_custom_params():
|
|
292
|
+
"""Test CompletionRetriever initialization with custom parameters."""
|
|
293
|
+
retriever = CompletionRetriever(
|
|
294
|
+
user_prompt_path="custom_user.txt",
|
|
295
|
+
system_prompt_path="custom_system.txt",
|
|
296
|
+
system_prompt="Custom prompt",
|
|
297
|
+
top_k=10,
|
|
298
|
+
)
|
|
299
|
+
|
|
300
|
+
assert retriever.user_prompt_path == "custom_user.txt"
|
|
301
|
+
assert retriever.system_prompt_path == "custom_system.txt"
|
|
302
|
+
assert retriever.system_prompt == "Custom prompt"
|
|
303
|
+
assert retriever.top_k == 10
|
|
304
|
+
|
|
305
|
+
|
|
306
|
+
@pytest.mark.asyncio
|
|
307
|
+
async def test_get_context_missing_text_key(mock_vector_engine):
|
|
308
|
+
"""Test get_context handles missing text key in payload."""
|
|
309
|
+
mock_result = MagicMock()
|
|
310
|
+
mock_result.payload = {"other_key": "value"}
|
|
311
|
+
|
|
312
|
+
mock_vector_engine.search.return_value = [mock_result]
|
|
313
|
+
|
|
314
|
+
retriever = CompletionRetriever()
|
|
315
|
+
|
|
316
|
+
with patch(
|
|
317
|
+
"cognee.modules.retrieval.completion_retriever.get_vector_engine",
|
|
318
|
+
return_value=mock_vector_engine,
|
|
319
|
+
):
|
|
320
|
+
with pytest.raises(KeyError):
|
|
321
|
+
await retriever.get_context("test query")
|