cognee 0.5.1__py3-none-any.whl → 0.5.1.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cognee/api/v1/add/add.py +2 -1
- cognee/api/v1/datasets/routers/get_datasets_router.py +1 -0
- cognee/api/v1/memify/routers/get_memify_router.py +1 -0
- cognee/api/v1/search/search.py +0 -4
- cognee/infrastructure/databases/relational/config.py +16 -1
- cognee/infrastructure/databases/relational/create_relational_engine.py +13 -3
- cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py +24 -2
- cognee/infrastructure/databases/vector/create_vector_engine.py +9 -2
- cognee/infrastructure/llm/LLMGateway.py +0 -13
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py +17 -12
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py +31 -25
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py +132 -7
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py +5 -5
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/llm_interface.py +2 -6
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py +58 -13
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/adapter.py +0 -1
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py +25 -131
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/types.py +10 -0
- cognee/modules/data/models/Data.py +2 -1
- cognee/modules/retrieval/triplet_retriever.py +1 -1
- cognee/modules/retrieval/utils/brute_force_triplet_search.py +0 -18
- cognee/modules/search/methods/search.py +18 -25
- cognee/tasks/ingestion/data_item.py +8 -0
- cognee/tasks/ingestion/ingest_data.py +12 -1
- cognee/tasks/ingestion/save_data_item_to_storage.py +5 -0
- cognee/tests/integration/retrieval/test_chunks_retriever.py +252 -0
- cognee/tests/integration/retrieval/test_graph_completion_retriever.py +268 -0
- cognee/tests/integration/retrieval/test_graph_completion_retriever_context_extension.py +226 -0
- cognee/tests/integration/retrieval/test_graph_completion_retriever_cot.py +218 -0
- cognee/tests/integration/retrieval/test_rag_completion_retriever.py +254 -0
- cognee/tests/{unit/modules/retrieval/structured_output_test.py → integration/retrieval/test_structured_output.py} +87 -77
- cognee/tests/integration/retrieval/test_summaries_retriever.py +184 -0
- cognee/tests/integration/retrieval/test_temporal_retriever.py +306 -0
- cognee/tests/integration/retrieval/test_triplet_retriever.py +35 -0
- cognee/tests/test_custom_data_label.py +68 -0
- cognee/tests/test_search_db.py +334 -181
- cognee/tests/unit/eval_framework/benchmark_adapters_test.py +25 -0
- cognee/tests/unit/eval_framework/corpus_builder_test.py +33 -4
- cognee/tests/unit/infrastructure/databases/relational/test_RelationalConfig.py +69 -0
- cognee/tests/unit/modules/retrieval/chunks_retriever_test.py +181 -199
- cognee/tests/unit/modules/retrieval/conversation_history_test.py +338 -0
- cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py +454 -162
- cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py +674 -156
- cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py +625 -200
- cognee/tests/unit/modules/retrieval/rag_completion_retriever_test.py +319 -203
- cognee/tests/unit/modules/retrieval/summaries_retriever_test.py +189 -155
- cognee/tests/unit/modules/retrieval/temporal_retriever_test.py +539 -58
- cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py +218 -9
- cognee/tests/unit/modules/retrieval/test_completion.py +343 -0
- cognee/tests/unit/modules/retrieval/test_graph_summary_completion_retriever.py +157 -0
- cognee/tests/unit/modules/retrieval/test_user_qa_feedback.py +312 -0
- cognee/tests/unit/modules/retrieval/triplet_retriever_test.py +246 -0
- {cognee-0.5.1.dist-info → cognee-0.5.1.dev0.dist-info}/METADATA +1 -1
- {cognee-0.5.1.dist-info → cognee-0.5.1.dev0.dist-info}/RECORD +58 -45
- cognee/tests/unit/modules/search/test_search.py +0 -100
- {cognee-0.5.1.dist-info → cognee-0.5.1.dev0.dist-info}/WHEEL +0 -0
- {cognee-0.5.1.dist-info → cognee-0.5.1.dev0.dist-info}/entry_points.txt +0 -0
- {cognee-0.5.1.dist-info → cognee-0.5.1.dev0.dist-info}/licenses/LICENSE +0 -0
- {cognee-0.5.1.dist-info → cognee-0.5.1.dev0.dist-info}/licenses/NOTICE.md +0 -0
|
@@ -2,15 +2,38 @@ import pytest
|
|
|
2
2
|
from cognee.eval_framework.corpus_builder.corpus_builder_executor import CorpusBuilderExecutor
|
|
3
3
|
from cognee.infrastructure.databases.graph import get_graph_engine
|
|
4
4
|
from unittest.mock import AsyncMock, patch
|
|
5
|
+
from cognee.eval_framework.benchmark_adapters.hotpot_qa_adapter import HotpotQAAdapter
|
|
5
6
|
|
|
6
7
|
benchmark_options = ["HotPotQA", "Dummy", "TwoWikiMultiHop"]
|
|
7
8
|
|
|
9
|
+
MOCK_HOTPOT_CORPUS = [
|
|
10
|
+
{
|
|
11
|
+
"_id": "1",
|
|
12
|
+
"question": "Next to which country is Germany located?",
|
|
13
|
+
"answer": "Netherlands",
|
|
14
|
+
# HotpotQA uses "level"; TwoWikiMultiHop uses "type".
|
|
15
|
+
"level": "easy",
|
|
16
|
+
"type": "comparison",
|
|
17
|
+
"context": [
|
|
18
|
+
["Germany", ["Germany is in Europe."]],
|
|
19
|
+
["Netherlands", ["The Netherlands borders Germany."]],
|
|
20
|
+
],
|
|
21
|
+
"supporting_facts": [["Netherlands", 0]],
|
|
22
|
+
}
|
|
23
|
+
]
|
|
24
|
+
|
|
8
25
|
|
|
9
26
|
@pytest.mark.parametrize("benchmark", benchmark_options)
|
|
10
27
|
def test_corpus_builder_load_corpus(benchmark):
|
|
11
28
|
limit = 2
|
|
12
|
-
|
|
13
|
-
|
|
29
|
+
if benchmark in ("HotPotQA", "TwoWikiMultiHop"):
|
|
30
|
+
with patch.object(HotpotQAAdapter, "_get_raw_corpus", return_value=MOCK_HOTPOT_CORPUS):
|
|
31
|
+
corpus_builder = CorpusBuilderExecutor(benchmark, "Default")
|
|
32
|
+
raw_corpus, questions = corpus_builder.load_corpus(limit=limit)
|
|
33
|
+
else:
|
|
34
|
+
corpus_builder = CorpusBuilderExecutor(benchmark, "Default")
|
|
35
|
+
raw_corpus, questions = corpus_builder.load_corpus(limit=limit)
|
|
36
|
+
|
|
14
37
|
assert len(raw_corpus) > 0, f"Corpus builder loads empty corpus for {benchmark}"
|
|
15
38
|
assert len(questions) <= 2, (
|
|
16
39
|
f"Corpus builder loads {len(questions)} for {benchmark} when limit is {limit}"
|
|
@@ -22,8 +45,14 @@ def test_corpus_builder_load_corpus(benchmark):
|
|
|
22
45
|
@patch.object(CorpusBuilderExecutor, "run_cognee", new_callable=AsyncMock)
|
|
23
46
|
async def test_corpus_builder_build_corpus(mock_run_cognee, benchmark):
|
|
24
47
|
limit = 2
|
|
25
|
-
|
|
26
|
-
|
|
48
|
+
if benchmark in ("HotPotQA", "TwoWikiMultiHop"):
|
|
49
|
+
with patch.object(HotpotQAAdapter, "_get_raw_corpus", return_value=MOCK_HOTPOT_CORPUS):
|
|
50
|
+
corpus_builder = CorpusBuilderExecutor(benchmark, "Default")
|
|
51
|
+
questions = await corpus_builder.build_corpus(limit=limit)
|
|
52
|
+
else:
|
|
53
|
+
corpus_builder = CorpusBuilderExecutor(benchmark, "Default")
|
|
54
|
+
questions = await corpus_builder.build_corpus(limit=limit)
|
|
55
|
+
|
|
27
56
|
assert len(questions) <= 2, (
|
|
28
57
|
f"Corpus builder loads {len(questions)} for {benchmark} when limit is {limit}"
|
|
29
58
|
)
|
|
@@ -0,0 +1,69 @@
|
|
|
1
|
+
import os
|
|
2
|
+
from unittest.mock import patch
|
|
3
|
+
from cognee.infrastructure.databases.relational.config import RelationalConfig
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class TestRelationalConfig:
|
|
7
|
+
"""Test suite for RelationalConfig DATABASE_CONNECT_ARGS parsing."""
|
|
8
|
+
|
|
9
|
+
def test_database_connect_args_valid_json_dict(self):
|
|
10
|
+
"""Test that DATABASE_CONNECT_ARGS is parsed correctly when it's a valid JSON dict."""
|
|
11
|
+
with patch.dict(
|
|
12
|
+
os.environ, {"DATABASE_CONNECT_ARGS": '{"timeout": 60, "sslmode": "require"}'}
|
|
13
|
+
):
|
|
14
|
+
config = RelationalConfig()
|
|
15
|
+
assert config.database_connect_args == {"timeout": 60, "sslmode": "require"}
|
|
16
|
+
|
|
17
|
+
def test_database_connect_args_empty_string(self):
|
|
18
|
+
"""Test that empty DATABASE_CONNECT_ARGS is handled correctly."""
|
|
19
|
+
with patch.dict(os.environ, {"DATABASE_CONNECT_ARGS": ""}):
|
|
20
|
+
config = RelationalConfig()
|
|
21
|
+
assert config.database_connect_args == ""
|
|
22
|
+
|
|
23
|
+
def test_database_connect_args_not_set(self):
|
|
24
|
+
"""Test that missing DATABASE_CONNECT_ARGS results in None."""
|
|
25
|
+
with patch.dict(os.environ, {}, clear=True):
|
|
26
|
+
config = RelationalConfig()
|
|
27
|
+
assert config.database_connect_args is None
|
|
28
|
+
|
|
29
|
+
def test_database_connect_args_invalid_json(self):
|
|
30
|
+
"""Test that invalid JSON in DATABASE_CONNECT_ARGS results in empty dict."""
|
|
31
|
+
with patch.dict(os.environ, {"DATABASE_CONNECT_ARGS": '{"timeout": 60'}): # Invalid JSON
|
|
32
|
+
config = RelationalConfig()
|
|
33
|
+
assert config.database_connect_args == {}
|
|
34
|
+
|
|
35
|
+
def test_database_connect_args_non_dict_json(self):
|
|
36
|
+
"""Test that non-dict JSON in DATABASE_CONNECT_ARGS results in empty dict."""
|
|
37
|
+
with patch.dict(os.environ, {"DATABASE_CONNECT_ARGS": '["list", "instead", "of", "dict"]'}):
|
|
38
|
+
config = RelationalConfig()
|
|
39
|
+
assert config.database_connect_args == {}
|
|
40
|
+
|
|
41
|
+
def test_database_connect_args_to_dict(self):
|
|
42
|
+
"""Test that database_connect_args is included in to_dict() output."""
|
|
43
|
+
with patch.dict(os.environ, {"DATABASE_CONNECT_ARGS": '{"timeout": 60}'}):
|
|
44
|
+
config = RelationalConfig()
|
|
45
|
+
config_dict = config.to_dict()
|
|
46
|
+
assert "database_connect_args" in config_dict
|
|
47
|
+
assert config_dict["database_connect_args"] == {"timeout": 60}
|
|
48
|
+
|
|
49
|
+
def test_database_connect_args_integer_value(self):
|
|
50
|
+
"""Test that DATABASE_CONNECT_ARGS with integer values is parsed correctly."""
|
|
51
|
+
with patch.dict(os.environ, {"DATABASE_CONNECT_ARGS": '{"connect_timeout": 10}'}):
|
|
52
|
+
config = RelationalConfig()
|
|
53
|
+
assert config.database_connect_args == {"connect_timeout": 10}
|
|
54
|
+
|
|
55
|
+
def test_database_connect_args_mixed_types(self):
|
|
56
|
+
"""Test that DATABASE_CONNECT_ARGS with mixed value types is parsed correctly."""
|
|
57
|
+
with patch.dict(
|
|
58
|
+
os.environ,
|
|
59
|
+
{
|
|
60
|
+
"DATABASE_CONNECT_ARGS": '{"timeout": 60, "sslmode": "require", "retries": 3, "keepalive": true}'
|
|
61
|
+
},
|
|
62
|
+
):
|
|
63
|
+
config = RelationalConfig()
|
|
64
|
+
assert config.database_connect_args == {
|
|
65
|
+
"timeout": 60,
|
|
66
|
+
"sslmode": "require",
|
|
67
|
+
"retries": 3,
|
|
68
|
+
"keepalive": True,
|
|
69
|
+
}
|
|
@@ -1,201 +1,183 @@
|
|
|
1
|
-
import os
|
|
2
1
|
import pytest
|
|
3
|
-
import
|
|
4
|
-
|
|
5
|
-
import cognee
|
|
6
|
-
from cognee.low_level import setup
|
|
7
|
-
from cognee.tasks.storage import add_data_points
|
|
8
|
-
from cognee.infrastructure.databases.vector import get_vector_engine
|
|
9
|
-
from cognee.modules.chunking.models import DocumentChunk
|
|
10
|
-
from cognee.modules.data.processing.document_types import TextDocument
|
|
11
|
-
from cognee.modules.retrieval.exceptions.exceptions import NoDataError
|
|
2
|
+
from unittest.mock import AsyncMock, patch, MagicMock
|
|
3
|
+
|
|
12
4
|
from cognee.modules.retrieval.chunks_retriever import ChunksRetriever
|
|
13
|
-
from cognee.
|
|
14
|
-
from cognee.
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
await
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
)
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
cognee.
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
)
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
cognee.
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
await
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
with pytest.raises(NoDataError):
|
|
193
|
-
await retriever.get_context("Christina Mayer")
|
|
194
|
-
|
|
195
|
-
vector_engine = get_vector_engine()
|
|
196
|
-
await vector_engine.create_collection(
|
|
197
|
-
"DocumentChunk_text", payload_schema=DocumentChunkWithEntities
|
|
198
|
-
)
|
|
199
|
-
|
|
200
|
-
context = await retriever.get_context("Christina Mayer")
|
|
201
|
-
assert len(context) == 0, "Found chunks when none should exist"
|
|
5
|
+
from cognee.modules.retrieval.exceptions.exceptions import NoDataError
|
|
6
|
+
from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
@pytest.fixture
|
|
10
|
+
def mock_vector_engine():
|
|
11
|
+
"""Create a mock vector engine."""
|
|
12
|
+
engine = AsyncMock()
|
|
13
|
+
engine.search = AsyncMock()
|
|
14
|
+
return engine
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
@pytest.mark.asyncio
|
|
18
|
+
async def test_get_context_success(mock_vector_engine):
|
|
19
|
+
"""Test successful retrieval of chunk context."""
|
|
20
|
+
mock_result1 = MagicMock()
|
|
21
|
+
mock_result1.payload = {"text": "Steve Rodger", "chunk_index": 0}
|
|
22
|
+
mock_result2 = MagicMock()
|
|
23
|
+
mock_result2.payload = {"text": "Mike Broski", "chunk_index": 1}
|
|
24
|
+
|
|
25
|
+
mock_vector_engine.search.return_value = [mock_result1, mock_result2]
|
|
26
|
+
|
|
27
|
+
retriever = ChunksRetriever(top_k=5)
|
|
28
|
+
|
|
29
|
+
with patch(
|
|
30
|
+
"cognee.modules.retrieval.chunks_retriever.get_vector_engine",
|
|
31
|
+
return_value=mock_vector_engine,
|
|
32
|
+
):
|
|
33
|
+
context = await retriever.get_context("test query")
|
|
34
|
+
|
|
35
|
+
assert len(context) == 2
|
|
36
|
+
assert context[0]["text"] == "Steve Rodger"
|
|
37
|
+
assert context[1]["text"] == "Mike Broski"
|
|
38
|
+
mock_vector_engine.search.assert_awaited_once_with("DocumentChunk_text", "test query", limit=5)
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
@pytest.mark.asyncio
|
|
42
|
+
async def test_get_context_collection_not_found_error(mock_vector_engine):
|
|
43
|
+
"""Test that CollectionNotFoundError is converted to NoDataError."""
|
|
44
|
+
mock_vector_engine.search.side_effect = CollectionNotFoundError("Collection not found")
|
|
45
|
+
|
|
46
|
+
retriever = ChunksRetriever()
|
|
47
|
+
|
|
48
|
+
with patch(
|
|
49
|
+
"cognee.modules.retrieval.chunks_retriever.get_vector_engine",
|
|
50
|
+
return_value=mock_vector_engine,
|
|
51
|
+
):
|
|
52
|
+
with pytest.raises(NoDataError, match="No data found"):
|
|
53
|
+
await retriever.get_context("test query")
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
@pytest.mark.asyncio
|
|
57
|
+
async def test_get_context_empty_results(mock_vector_engine):
|
|
58
|
+
"""Test that empty list is returned when no chunks are found."""
|
|
59
|
+
mock_vector_engine.search.return_value = []
|
|
60
|
+
|
|
61
|
+
retriever = ChunksRetriever()
|
|
62
|
+
|
|
63
|
+
with patch(
|
|
64
|
+
"cognee.modules.retrieval.chunks_retriever.get_vector_engine",
|
|
65
|
+
return_value=mock_vector_engine,
|
|
66
|
+
):
|
|
67
|
+
context = await retriever.get_context("test query")
|
|
68
|
+
|
|
69
|
+
assert context == []
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
@pytest.mark.asyncio
|
|
73
|
+
async def test_get_context_top_k_limit(mock_vector_engine):
|
|
74
|
+
"""Test that top_k parameter limits the number of results."""
|
|
75
|
+
mock_results = [MagicMock() for _ in range(3)]
|
|
76
|
+
for i, result in enumerate(mock_results):
|
|
77
|
+
result.payload = {"text": f"Chunk {i}"}
|
|
78
|
+
|
|
79
|
+
mock_vector_engine.search.return_value = mock_results
|
|
80
|
+
|
|
81
|
+
retriever = ChunksRetriever(top_k=3)
|
|
82
|
+
|
|
83
|
+
with patch(
|
|
84
|
+
"cognee.modules.retrieval.chunks_retriever.get_vector_engine",
|
|
85
|
+
return_value=mock_vector_engine,
|
|
86
|
+
):
|
|
87
|
+
context = await retriever.get_context("test query")
|
|
88
|
+
|
|
89
|
+
assert len(context) == 3
|
|
90
|
+
mock_vector_engine.search.assert_awaited_once_with("DocumentChunk_text", "test query", limit=3)
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
@pytest.mark.asyncio
|
|
94
|
+
async def test_get_completion_with_context(mock_vector_engine):
|
|
95
|
+
"""Test get_completion returns provided context."""
|
|
96
|
+
retriever = ChunksRetriever()
|
|
97
|
+
|
|
98
|
+
provided_context = [{"text": "Steve Rodger"}, {"text": "Mike Broski"}]
|
|
99
|
+
completion = await retriever.get_completion("test query", context=provided_context)
|
|
100
|
+
|
|
101
|
+
assert completion == provided_context
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
@pytest.mark.asyncio
|
|
105
|
+
async def test_get_completion_without_context(mock_vector_engine):
|
|
106
|
+
"""Test get_completion retrieves context when not provided."""
|
|
107
|
+
mock_result = MagicMock()
|
|
108
|
+
mock_result.payload = {"text": "Steve Rodger"}
|
|
109
|
+
mock_vector_engine.search.return_value = [mock_result]
|
|
110
|
+
|
|
111
|
+
retriever = ChunksRetriever()
|
|
112
|
+
|
|
113
|
+
with patch(
|
|
114
|
+
"cognee.modules.retrieval.chunks_retriever.get_vector_engine",
|
|
115
|
+
return_value=mock_vector_engine,
|
|
116
|
+
):
|
|
117
|
+
completion = await retriever.get_completion("test query")
|
|
118
|
+
|
|
119
|
+
assert len(completion) == 1
|
|
120
|
+
assert completion[0]["text"] == "Steve Rodger"
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
@pytest.mark.asyncio
|
|
124
|
+
async def test_init_defaults():
|
|
125
|
+
"""Test ChunksRetriever initialization with defaults."""
|
|
126
|
+
retriever = ChunksRetriever()
|
|
127
|
+
|
|
128
|
+
assert retriever.top_k == 5
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
@pytest.mark.asyncio
|
|
132
|
+
async def test_init_custom_top_k():
|
|
133
|
+
"""Test ChunksRetriever initialization with custom top_k."""
|
|
134
|
+
retriever = ChunksRetriever(top_k=10)
|
|
135
|
+
|
|
136
|
+
assert retriever.top_k == 10
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
@pytest.mark.asyncio
|
|
140
|
+
async def test_init_none_top_k():
|
|
141
|
+
"""Test ChunksRetriever initialization with None top_k."""
|
|
142
|
+
retriever = ChunksRetriever(top_k=None)
|
|
143
|
+
|
|
144
|
+
assert retriever.top_k is None
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
@pytest.mark.asyncio
|
|
148
|
+
async def test_get_context_empty_payload(mock_vector_engine):
|
|
149
|
+
"""Test get_context handles empty payload."""
|
|
150
|
+
mock_result = MagicMock()
|
|
151
|
+
mock_result.payload = {}
|
|
152
|
+
|
|
153
|
+
mock_vector_engine.search.return_value = [mock_result]
|
|
154
|
+
|
|
155
|
+
retriever = ChunksRetriever()
|
|
156
|
+
|
|
157
|
+
with patch(
|
|
158
|
+
"cognee.modules.retrieval.chunks_retriever.get_vector_engine",
|
|
159
|
+
return_value=mock_vector_engine,
|
|
160
|
+
):
|
|
161
|
+
context = await retriever.get_context("test query")
|
|
162
|
+
|
|
163
|
+
assert len(context) == 1
|
|
164
|
+
assert context[0] == {}
|
|
165
|
+
|
|
166
|
+
|
|
167
|
+
@pytest.mark.asyncio
|
|
168
|
+
async def test_get_completion_with_session_id(mock_vector_engine):
|
|
169
|
+
"""Test get_completion with session_id parameter."""
|
|
170
|
+
mock_result = MagicMock()
|
|
171
|
+
mock_result.payload = {"text": "Steve Rodger"}
|
|
172
|
+
mock_vector_engine.search.return_value = [mock_result]
|
|
173
|
+
|
|
174
|
+
retriever = ChunksRetriever()
|
|
175
|
+
|
|
176
|
+
with patch(
|
|
177
|
+
"cognee.modules.retrieval.chunks_retriever.get_vector_engine",
|
|
178
|
+
return_value=mock_vector_engine,
|
|
179
|
+
):
|
|
180
|
+
completion = await retriever.get_completion("test query", session_id="test_session")
|
|
181
|
+
|
|
182
|
+
assert len(completion) == 1
|
|
183
|
+
assert completion[0]["text"] == "Steve Rodger"
|