cognee 0.5.0.dev1__py3-none-any.whl → 0.5.1.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56)
  1. cognee/api/v1/add/add.py +2 -1
  2. cognee/api/v1/datasets/routers/get_datasets_router.py +1 -0
  3. cognee/api/v1/memify/routers/get_memify_router.py +1 -0
  4. cognee/infrastructure/databases/relational/config.py +16 -1
  5. cognee/infrastructure/databases/relational/create_relational_engine.py +13 -3
  6. cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py +24 -2
  7. cognee/infrastructure/databases/vector/create_vector_engine.py +9 -2
  8. cognee/infrastructure/llm/LLMGateway.py +0 -13
  9. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py +17 -12
  10. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py +31 -25
  11. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py +132 -7
  12. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py +5 -5
  13. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/llm_interface.py +2 -6
  14. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py +58 -13
  15. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/adapter.py +0 -1
  16. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py +25 -131
  17. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/types.py +10 -0
  18. cognee/modules/data/models/Data.py +2 -1
  19. cognee/modules/retrieval/triplet_retriever.py +1 -1
  20. cognee/modules/retrieval/utils/brute_force_triplet_search.py +0 -18
  21. cognee/tasks/ingestion/data_item.py +8 -0
  22. cognee/tasks/ingestion/ingest_data.py +12 -1
  23. cognee/tasks/ingestion/save_data_item_to_storage.py +5 -0
  24. cognee/tests/integration/retrieval/test_chunks_retriever.py +252 -0
  25. cognee/tests/integration/retrieval/test_graph_completion_retriever.py +268 -0
  26. cognee/tests/integration/retrieval/test_graph_completion_retriever_context_extension.py +226 -0
  27. cognee/tests/integration/retrieval/test_graph_completion_retriever_cot.py +218 -0
  28. cognee/tests/integration/retrieval/test_rag_completion_retriever.py +254 -0
  29. cognee/tests/{unit/modules/retrieval/structured_output_test.py → integration/retrieval/test_structured_output.py} +87 -77
  30. cognee/tests/integration/retrieval/test_summaries_retriever.py +184 -0
  31. cognee/tests/integration/retrieval/test_temporal_retriever.py +306 -0
  32. cognee/tests/integration/retrieval/test_triplet_retriever.py +35 -0
  33. cognee/tests/test_custom_data_label.py +68 -0
  34. cognee/tests/test_search_db.py +334 -181
  35. cognee/tests/unit/eval_framework/benchmark_adapters_test.py +25 -0
  36. cognee/tests/unit/eval_framework/corpus_builder_test.py +33 -4
  37. cognee/tests/unit/infrastructure/databases/relational/test_RelationalConfig.py +69 -0
  38. cognee/tests/unit/modules/retrieval/chunks_retriever_test.py +181 -199
  39. cognee/tests/unit/modules/retrieval/conversation_history_test.py +338 -0
  40. cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py +454 -162
  41. cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py +674 -156
  42. cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py +625 -200
  43. cognee/tests/unit/modules/retrieval/rag_completion_retriever_test.py +319 -203
  44. cognee/tests/unit/modules/retrieval/summaries_retriever_test.py +189 -155
  45. cognee/tests/unit/modules/retrieval/temporal_retriever_test.py +539 -58
  46. cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py +218 -9
  47. cognee/tests/unit/modules/retrieval/test_completion.py +343 -0
  48. cognee/tests/unit/modules/retrieval/test_graph_summary_completion_retriever.py +157 -0
  49. cognee/tests/unit/modules/retrieval/test_user_qa_feedback.py +312 -0
  50. cognee/tests/unit/modules/retrieval/triplet_retriever_test.py +246 -0
  51. {cognee-0.5.0.dev1.dist-info → cognee-0.5.1.dev0.dist-info}/METADATA +1 -1
  52. {cognee-0.5.0.dev1.dist-info → cognee-0.5.1.dev0.dist-info}/RECORD +56 -42
  53. {cognee-0.5.0.dev1.dist-info → cognee-0.5.1.dev0.dist-info}/WHEEL +0 -0
  54. {cognee-0.5.0.dev1.dist-info → cognee-0.5.1.dev0.dist-info}/entry_points.txt +0 -0
  55. {cognee-0.5.0.dev1.dist-info → cognee-0.5.1.dev0.dist-info}/licenses/LICENSE +0 -0
  56. {cognee-0.5.0.dev1.dist-info → cognee-0.5.1.dev0.dist-info}/licenses/NOTICE.md +0 -0
@@ -2,15 +2,38 @@ import pytest
2
2
  from cognee.eval_framework.corpus_builder.corpus_builder_executor import CorpusBuilderExecutor
3
3
  from cognee.infrastructure.databases.graph import get_graph_engine
4
4
  from unittest.mock import AsyncMock, patch
5
+ from cognee.eval_framework.benchmark_adapters.hotpot_qa_adapter import HotpotQAAdapter
5
6
 
6
7
  benchmark_options = ["HotPotQA", "Dummy", "TwoWikiMultiHop"]
7
8
 
9
+ MOCK_HOTPOT_CORPUS = [
10
+ {
11
+ "_id": "1",
12
+ "question": "Next to which country is Germany located?",
13
+ "answer": "Netherlands",
14
+ # HotpotQA uses "level"; TwoWikiMultiHop uses "type".
15
+ "level": "easy",
16
+ "type": "comparison",
17
+ "context": [
18
+ ["Germany", ["Germany is in Europe."]],
19
+ ["Netherlands", ["The Netherlands borders Germany."]],
20
+ ],
21
+ "supporting_facts": [["Netherlands", 0]],
22
+ }
23
+ ]
24
+
8
25
 
9
26
  @pytest.mark.parametrize("benchmark", benchmark_options)
10
27
  def test_corpus_builder_load_corpus(benchmark):
11
28
  limit = 2
12
- corpus_builder = CorpusBuilderExecutor(benchmark, "Default")
13
- raw_corpus, questions = corpus_builder.load_corpus(limit=limit)
29
+ if benchmark in ("HotPotQA", "TwoWikiMultiHop"):
30
+ with patch.object(HotpotQAAdapter, "_get_raw_corpus", return_value=MOCK_HOTPOT_CORPUS):
31
+ corpus_builder = CorpusBuilderExecutor(benchmark, "Default")
32
+ raw_corpus, questions = corpus_builder.load_corpus(limit=limit)
33
+ else:
34
+ corpus_builder = CorpusBuilderExecutor(benchmark, "Default")
35
+ raw_corpus, questions = corpus_builder.load_corpus(limit=limit)
36
+
14
37
  assert len(raw_corpus) > 0, f"Corpus builder loads empty corpus for {benchmark}"
15
38
  assert len(questions) <= 2, (
16
39
  f"Corpus builder loads {len(questions)} for {benchmark} when limit is {limit}"
@@ -22,8 +45,14 @@ def test_corpus_builder_load_corpus(benchmark):
22
45
  @patch.object(CorpusBuilderExecutor, "run_cognee", new_callable=AsyncMock)
23
46
  async def test_corpus_builder_build_corpus(mock_run_cognee, benchmark):
24
47
  limit = 2
25
- corpus_builder = CorpusBuilderExecutor(benchmark, "Default")
26
- questions = await corpus_builder.build_corpus(limit=limit)
48
+ if benchmark in ("HotPotQA", "TwoWikiMultiHop"):
49
+ with patch.object(HotpotQAAdapter, "_get_raw_corpus", return_value=MOCK_HOTPOT_CORPUS):
50
+ corpus_builder = CorpusBuilderExecutor(benchmark, "Default")
51
+ questions = await corpus_builder.build_corpus(limit=limit)
52
+ else:
53
+ corpus_builder = CorpusBuilderExecutor(benchmark, "Default")
54
+ questions = await corpus_builder.build_corpus(limit=limit)
55
+
27
56
  assert len(questions) <= 2, (
28
57
  f"Corpus builder loads {len(questions)} for {benchmark} when limit is {limit}"
29
58
  )
@@ -0,0 +1,69 @@
1
+ import os
2
+ from unittest.mock import patch
3
+ from cognee.infrastructure.databases.relational.config import RelationalConfig
4
+
5
+
6
+ class TestRelationalConfig:
7
+ """Test suite for RelationalConfig DATABASE_CONNECT_ARGS parsing."""
8
+
9
+ def test_database_connect_args_valid_json_dict(self):
10
+ """Test that DATABASE_CONNECT_ARGS is parsed correctly when it's a valid JSON dict."""
11
+ with patch.dict(
12
+ os.environ, {"DATABASE_CONNECT_ARGS": '{"timeout": 60, "sslmode": "require"}'}
13
+ ):
14
+ config = RelationalConfig()
15
+ assert config.database_connect_args == {"timeout": 60, "sslmode": "require"}
16
+
17
+ def test_database_connect_args_empty_string(self):
18
+ """Test that empty DATABASE_CONNECT_ARGS is handled correctly."""
19
+ with patch.dict(os.environ, {"DATABASE_CONNECT_ARGS": ""}):
20
+ config = RelationalConfig()
21
+ assert config.database_connect_args == ""
22
+
23
+ def test_database_connect_args_not_set(self):
24
+ """Test that missing DATABASE_CONNECT_ARGS results in None."""
25
+ with patch.dict(os.environ, {}, clear=True):
26
+ config = RelationalConfig()
27
+ assert config.database_connect_args is None
28
+
29
+ def test_database_connect_args_invalid_json(self):
30
+ """Test that invalid JSON in DATABASE_CONNECT_ARGS results in empty dict."""
31
+ with patch.dict(os.environ, {"DATABASE_CONNECT_ARGS": '{"timeout": 60'}): # Invalid JSON
32
+ config = RelationalConfig()
33
+ assert config.database_connect_args == {}
34
+
35
+ def test_database_connect_args_non_dict_json(self):
36
+ """Test that non-dict JSON in DATABASE_CONNECT_ARGS results in empty dict."""
37
+ with patch.dict(os.environ, {"DATABASE_CONNECT_ARGS": '["list", "instead", "of", "dict"]'}):
38
+ config = RelationalConfig()
39
+ assert config.database_connect_args == {}
40
+
41
+ def test_database_connect_args_to_dict(self):
42
+ """Test that database_connect_args is included in to_dict() output."""
43
+ with patch.dict(os.environ, {"DATABASE_CONNECT_ARGS": '{"timeout": 60}'}):
44
+ config = RelationalConfig()
45
+ config_dict = config.to_dict()
46
+ assert "database_connect_args" in config_dict
47
+ assert config_dict["database_connect_args"] == {"timeout": 60}
48
+
49
+ def test_database_connect_args_integer_value(self):
50
+ """Test that DATABASE_CONNECT_ARGS with integer values is parsed correctly."""
51
+ with patch.dict(os.environ, {"DATABASE_CONNECT_ARGS": '{"connect_timeout": 10}'}):
52
+ config = RelationalConfig()
53
+ assert config.database_connect_args == {"connect_timeout": 10}
54
+
55
+ def test_database_connect_args_mixed_types(self):
56
+ """Test that DATABASE_CONNECT_ARGS with mixed value types is parsed correctly."""
57
+ with patch.dict(
58
+ os.environ,
59
+ {
60
+ "DATABASE_CONNECT_ARGS": '{"timeout": 60, "sslmode": "require", "retries": 3, "keepalive": true}'
61
+ },
62
+ ):
63
+ config = RelationalConfig()
64
+ assert config.database_connect_args == {
65
+ "timeout": 60,
66
+ "sslmode": "require",
67
+ "retries": 3,
68
+ "keepalive": True,
69
+ }
@@ -1,201 +1,183 @@
1
- import os
2
1
  import pytest
3
- import pathlib
4
- from typing import List
5
- import cognee
6
- from cognee.low_level import setup
7
- from cognee.tasks.storage import add_data_points
8
- from cognee.infrastructure.databases.vector import get_vector_engine
9
- from cognee.modules.chunking.models import DocumentChunk
10
- from cognee.modules.data.processing.document_types import TextDocument
11
- from cognee.modules.retrieval.exceptions.exceptions import NoDataError
2
+ from unittest.mock import AsyncMock, patch, MagicMock
3
+
12
4
  from cognee.modules.retrieval.chunks_retriever import ChunksRetriever
13
- from cognee.infrastructure.engine import DataPoint
14
- from cognee.modules.data.processing.document_types import Document
15
- from cognee.modules.engine.models import Entity
16
-
17
-
18
- class DocumentChunkWithEntities(DataPoint):
19
- text: str
20
- chunk_size: int
21
- chunk_index: int
22
- cut_type: str
23
- is_part_of: Document
24
- contains: List[Entity] = None
25
-
26
- metadata: dict = {"index_fields": ["text"]}
27
-
28
-
29
- class TestChunksRetriever:
30
- @pytest.mark.asyncio
31
- async def test_chunk_context_simple(self):
32
- system_directory_path = os.path.join(
33
- pathlib.Path(__file__).parent, ".cognee_system/test_chunk_context_simple"
34
- )
35
- cognee.config.system_root_directory(system_directory_path)
36
- data_directory_path = os.path.join(
37
- pathlib.Path(__file__).parent, ".data_storage/test_chunk_context_simple"
38
- )
39
- cognee.config.data_root_directory(data_directory_path)
40
-
41
- await cognee.prune.prune_data()
42
- await cognee.prune.prune_system(metadata=True)
43
- await setup()
44
-
45
- document = TextDocument(
46
- name="Steve Rodger's career",
47
- raw_data_location="somewhere",
48
- external_metadata="",
49
- mime_type="text/plain",
50
- )
51
-
52
- chunk1 = DocumentChunk(
53
- text="Steve Rodger",
54
- chunk_size=2,
55
- chunk_index=0,
56
- cut_type="sentence_end",
57
- is_part_of=document,
58
- contains=[],
59
- )
60
- chunk2 = DocumentChunk(
61
- text="Mike Broski",
62
- chunk_size=2,
63
- chunk_index=1,
64
- cut_type="sentence_end",
65
- is_part_of=document,
66
- contains=[],
67
- )
68
- chunk3 = DocumentChunk(
69
- text="Christina Mayer",
70
- chunk_size=2,
71
- chunk_index=2,
72
- cut_type="sentence_end",
73
- is_part_of=document,
74
- contains=[],
75
- )
76
-
77
- entities = [chunk1, chunk2, chunk3]
78
-
79
- await add_data_points(entities)
80
-
81
- retriever = ChunksRetriever()
82
-
83
- context = await retriever.get_context("Mike")
84
-
85
- assert context[0]["text"] == "Mike Broski", "Failed to get Mike Broski"
86
-
87
- @pytest.mark.asyncio
88
- async def test_chunk_context_complex(self):
89
- system_directory_path = os.path.join(
90
- pathlib.Path(__file__).parent, ".cognee_system/test_chunk_context_complex"
91
- )
92
- cognee.config.system_root_directory(system_directory_path)
93
- data_directory_path = os.path.join(
94
- pathlib.Path(__file__).parent, ".data_storage/test_chunk_context_complex"
95
- )
96
- cognee.config.data_root_directory(data_directory_path)
97
-
98
- await cognee.prune.prune_data()
99
- await cognee.prune.prune_system(metadata=True)
100
- await setup()
101
-
102
- document1 = TextDocument(
103
- name="Employee List",
104
- raw_data_location="somewhere",
105
- external_metadata="",
106
- mime_type="text/plain",
107
- )
108
-
109
- document2 = TextDocument(
110
- name="Car List",
111
- raw_data_location="somewhere",
112
- external_metadata="",
113
- mime_type="text/plain",
114
- )
115
-
116
- chunk1 = DocumentChunk(
117
- text="Steve Rodger",
118
- chunk_size=2,
119
- chunk_index=0,
120
- cut_type="sentence_end",
121
- is_part_of=document1,
122
- contains=[],
123
- )
124
- chunk2 = DocumentChunk(
125
- text="Mike Broski",
126
- chunk_size=2,
127
- chunk_index=1,
128
- cut_type="sentence_end",
129
- is_part_of=document1,
130
- contains=[],
131
- )
132
- chunk3 = DocumentChunk(
133
- text="Christina Mayer",
134
- chunk_size=2,
135
- chunk_index=2,
136
- cut_type="sentence_end",
137
- is_part_of=document1,
138
- contains=[],
139
- )
140
-
141
- chunk4 = DocumentChunk(
142
- text="Range Rover",
143
- chunk_size=2,
144
- chunk_index=0,
145
- cut_type="sentence_end",
146
- is_part_of=document2,
147
- contains=[],
148
- )
149
- chunk5 = DocumentChunk(
150
- text="Hyundai",
151
- chunk_size=2,
152
- chunk_index=1,
153
- cut_type="sentence_end",
154
- is_part_of=document2,
155
- contains=[],
156
- )
157
- chunk6 = DocumentChunk(
158
- text="Chrysler",
159
- chunk_size=2,
160
- chunk_index=2,
161
- cut_type="sentence_end",
162
- is_part_of=document2,
163
- contains=[],
164
- )
165
-
166
- entities = [chunk1, chunk2, chunk3, chunk4, chunk5, chunk6]
167
-
168
- await add_data_points(entities)
169
-
170
- retriever = ChunksRetriever(top_k=20)
171
-
172
- context = await retriever.get_context("Christina")
173
-
174
- assert context[0]["text"] == "Christina Mayer", "Failed to get Christina Mayer"
175
-
176
- @pytest.mark.asyncio
177
- async def test_chunk_context_on_empty_graph(self):
178
- system_directory_path = os.path.join(
179
- pathlib.Path(__file__).parent, ".cognee_system/test_chunk_context_on_empty_graph"
180
- )
181
- cognee.config.system_root_directory(system_directory_path)
182
- data_directory_path = os.path.join(
183
- pathlib.Path(__file__).parent, ".data_storage/test_chunk_context_on_empty_graph"
184
- )
185
- cognee.config.data_root_directory(data_directory_path)
186
-
187
- await cognee.prune.prune_data()
188
- await cognee.prune.prune_system(metadata=True)
189
-
190
- retriever = ChunksRetriever()
191
-
192
- with pytest.raises(NoDataError):
193
- await retriever.get_context("Christina Mayer")
194
-
195
- vector_engine = get_vector_engine()
196
- await vector_engine.create_collection(
197
- "DocumentChunk_text", payload_schema=DocumentChunkWithEntities
198
- )
199
-
200
- context = await retriever.get_context("Christina Mayer")
201
- assert len(context) == 0, "Found chunks when none should exist"
5
+ from cognee.modules.retrieval.exceptions.exceptions import NoDataError
6
+ from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError
7
+
8
+
9
+ @pytest.fixture
10
+ def mock_vector_engine():
11
+ """Create a mock vector engine."""
12
+ engine = AsyncMock()
13
+ engine.search = AsyncMock()
14
+ return engine
15
+
16
+
17
+ @pytest.mark.asyncio
18
+ async def test_get_context_success(mock_vector_engine):
19
+ """Test successful retrieval of chunk context."""
20
+ mock_result1 = MagicMock()
21
+ mock_result1.payload = {"text": "Steve Rodger", "chunk_index": 0}
22
+ mock_result2 = MagicMock()
23
+ mock_result2.payload = {"text": "Mike Broski", "chunk_index": 1}
24
+
25
+ mock_vector_engine.search.return_value = [mock_result1, mock_result2]
26
+
27
+ retriever = ChunksRetriever(top_k=5)
28
+
29
+ with patch(
30
+ "cognee.modules.retrieval.chunks_retriever.get_vector_engine",
31
+ return_value=mock_vector_engine,
32
+ ):
33
+ context = await retriever.get_context("test query")
34
+
35
+ assert len(context) == 2
36
+ assert context[0]["text"] == "Steve Rodger"
37
+ assert context[1]["text"] == "Mike Broski"
38
+ mock_vector_engine.search.assert_awaited_once_with("DocumentChunk_text", "test query", limit=5)
39
+
40
+
41
+ @pytest.mark.asyncio
42
+ async def test_get_context_collection_not_found_error(mock_vector_engine):
43
+ """Test that CollectionNotFoundError is converted to NoDataError."""
44
+ mock_vector_engine.search.side_effect = CollectionNotFoundError("Collection not found")
45
+
46
+ retriever = ChunksRetriever()
47
+
48
+ with patch(
49
+ "cognee.modules.retrieval.chunks_retriever.get_vector_engine",
50
+ return_value=mock_vector_engine,
51
+ ):
52
+ with pytest.raises(NoDataError, match="No data found"):
53
+ await retriever.get_context("test query")
54
+
55
+
56
+ @pytest.mark.asyncio
57
+ async def test_get_context_empty_results(mock_vector_engine):
58
+ """Test that empty list is returned when no chunks are found."""
59
+ mock_vector_engine.search.return_value = []
60
+
61
+ retriever = ChunksRetriever()
62
+
63
+ with patch(
64
+ "cognee.modules.retrieval.chunks_retriever.get_vector_engine",
65
+ return_value=mock_vector_engine,
66
+ ):
67
+ context = await retriever.get_context("test query")
68
+
69
+ assert context == []
70
+
71
+
72
+ @pytest.mark.asyncio
73
+ async def test_get_context_top_k_limit(mock_vector_engine):
74
+ """Test that top_k parameter limits the number of results."""
75
+ mock_results = [MagicMock() for _ in range(3)]
76
+ for i, result in enumerate(mock_results):
77
+ result.payload = {"text": f"Chunk {i}"}
78
+
79
+ mock_vector_engine.search.return_value = mock_results
80
+
81
+ retriever = ChunksRetriever(top_k=3)
82
+
83
+ with patch(
84
+ "cognee.modules.retrieval.chunks_retriever.get_vector_engine",
85
+ return_value=mock_vector_engine,
86
+ ):
87
+ context = await retriever.get_context("test query")
88
+
89
+ assert len(context) == 3
90
+ mock_vector_engine.search.assert_awaited_once_with("DocumentChunk_text", "test query", limit=3)
91
+
92
+
93
+ @pytest.mark.asyncio
94
+ async def test_get_completion_with_context(mock_vector_engine):
95
+ """Test get_completion returns provided context."""
96
+ retriever = ChunksRetriever()
97
+
98
+ provided_context = [{"text": "Steve Rodger"}, {"text": "Mike Broski"}]
99
+ completion = await retriever.get_completion("test query", context=provided_context)
100
+
101
+ assert completion == provided_context
102
+
103
+
104
+ @pytest.mark.asyncio
105
+ async def test_get_completion_without_context(mock_vector_engine):
106
+ """Test get_completion retrieves context when not provided."""
107
+ mock_result = MagicMock()
108
+ mock_result.payload = {"text": "Steve Rodger"}
109
+ mock_vector_engine.search.return_value = [mock_result]
110
+
111
+ retriever = ChunksRetriever()
112
+
113
+ with patch(
114
+ "cognee.modules.retrieval.chunks_retriever.get_vector_engine",
115
+ return_value=mock_vector_engine,
116
+ ):
117
+ completion = await retriever.get_completion("test query")
118
+
119
+ assert len(completion) == 1
120
+ assert completion[0]["text"] == "Steve Rodger"
121
+
122
+
123
+ @pytest.mark.asyncio
124
+ async def test_init_defaults():
125
+ """Test ChunksRetriever initialization with defaults."""
126
+ retriever = ChunksRetriever()
127
+
128
+ assert retriever.top_k == 5
129
+
130
+
131
+ @pytest.mark.asyncio
132
+ async def test_init_custom_top_k():
133
+ """Test ChunksRetriever initialization with custom top_k."""
134
+ retriever = ChunksRetriever(top_k=10)
135
+
136
+ assert retriever.top_k == 10
137
+
138
+
139
+ @pytest.mark.asyncio
140
+ async def test_init_none_top_k():
141
+ """Test ChunksRetriever initialization with None top_k."""
142
+ retriever = ChunksRetriever(top_k=None)
143
+
144
+ assert retriever.top_k is None
145
+
146
+
147
+ @pytest.mark.asyncio
148
+ async def test_get_context_empty_payload(mock_vector_engine):
149
+ """Test get_context handles empty payload."""
150
+ mock_result = MagicMock()
151
+ mock_result.payload = {}
152
+
153
+ mock_vector_engine.search.return_value = [mock_result]
154
+
155
+ retriever = ChunksRetriever()
156
+
157
+ with patch(
158
+ "cognee.modules.retrieval.chunks_retriever.get_vector_engine",
159
+ return_value=mock_vector_engine,
160
+ ):
161
+ context = await retriever.get_context("test query")
162
+
163
+ assert len(context) == 1
164
+ assert context[0] == {}
165
+
166
+
167
+ @pytest.mark.asyncio
168
+ async def test_get_completion_with_session_id(mock_vector_engine):
169
+ """Test get_completion with session_id parameter."""
170
+ mock_result = MagicMock()
171
+ mock_result.payload = {"text": "Steve Rodger"}
172
+ mock_vector_engine.search.return_value = [mock_result]
173
+
174
+ retriever = ChunksRetriever()
175
+
176
+ with patch(
177
+ "cognee.modules.retrieval.chunks_retriever.get_vector_engine",
178
+ return_value=mock_vector_engine,
179
+ ):
180
+ completion = await retriever.get_completion("test query", session_id="test_session")
181
+
182
+ assert len(completion) == 1
183
+ assert completion[0]["text"] == "Steve Rodger"