cognee 0.5.0.dev1__py3-none-any.whl → 0.5.1.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cognee/api/v1/add/add.py +2 -1
- cognee/api/v1/datasets/routers/get_datasets_router.py +1 -0
- cognee/api/v1/memify/routers/get_memify_router.py +1 -0
- cognee/infrastructure/databases/relational/config.py +16 -1
- cognee/infrastructure/databases/relational/create_relational_engine.py +13 -3
- cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py +24 -2
- cognee/infrastructure/databases/vector/create_vector_engine.py +9 -2
- cognee/infrastructure/llm/LLMGateway.py +0 -13
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py +17 -12
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py +31 -25
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py +132 -7
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py +5 -5
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/llm_interface.py +2 -6
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py +58 -13
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/adapter.py +0 -1
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py +25 -131
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/types.py +10 -0
- cognee/modules/data/models/Data.py +2 -1
- cognee/modules/retrieval/triplet_retriever.py +1 -1
- cognee/modules/retrieval/utils/brute_force_triplet_search.py +0 -18
- cognee/tasks/ingestion/data_item.py +8 -0
- cognee/tasks/ingestion/ingest_data.py +12 -1
- cognee/tasks/ingestion/save_data_item_to_storage.py +5 -0
- cognee/tests/integration/retrieval/test_chunks_retriever.py +252 -0
- cognee/tests/integration/retrieval/test_graph_completion_retriever.py +268 -0
- cognee/tests/integration/retrieval/test_graph_completion_retriever_context_extension.py +226 -0
- cognee/tests/integration/retrieval/test_graph_completion_retriever_cot.py +218 -0
- cognee/tests/integration/retrieval/test_rag_completion_retriever.py +254 -0
- cognee/tests/{unit/modules/retrieval/structured_output_test.py → integration/retrieval/test_structured_output.py} +87 -77
- cognee/tests/integration/retrieval/test_summaries_retriever.py +184 -0
- cognee/tests/integration/retrieval/test_temporal_retriever.py +306 -0
- cognee/tests/integration/retrieval/test_triplet_retriever.py +35 -0
- cognee/tests/test_custom_data_label.py +68 -0
- cognee/tests/test_search_db.py +334 -181
- cognee/tests/unit/eval_framework/benchmark_adapters_test.py +25 -0
- cognee/tests/unit/eval_framework/corpus_builder_test.py +33 -4
- cognee/tests/unit/infrastructure/databases/relational/test_RelationalConfig.py +69 -0
- cognee/tests/unit/modules/retrieval/chunks_retriever_test.py +181 -199
- cognee/tests/unit/modules/retrieval/conversation_history_test.py +338 -0
- cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py +454 -162
- cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py +674 -156
- cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py +625 -200
- cognee/tests/unit/modules/retrieval/rag_completion_retriever_test.py +319 -203
- cognee/tests/unit/modules/retrieval/summaries_retriever_test.py +189 -155
- cognee/tests/unit/modules/retrieval/temporal_retriever_test.py +539 -58
- cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py +218 -9
- cognee/tests/unit/modules/retrieval/test_completion.py +343 -0
- cognee/tests/unit/modules/retrieval/test_graph_summary_completion_retriever.py +157 -0
- cognee/tests/unit/modules/retrieval/test_user_qa_feedback.py +312 -0
- cognee/tests/unit/modules/retrieval/triplet_retriever_test.py +246 -0
- {cognee-0.5.0.dev1.dist-info → cognee-0.5.1.dev0.dist-info}/METADATA +1 -1
- {cognee-0.5.0.dev1.dist-info → cognee-0.5.1.dev0.dist-info}/RECORD +56 -42
- {cognee-0.5.0.dev1.dist-info → cognee-0.5.1.dev0.dist-info}/WHEEL +0 -0
- {cognee-0.5.0.dev1.dist-info → cognee-0.5.1.dev0.dist-info}/entry_points.txt +0 -0
- {cognee-0.5.0.dev1.dist-info → cognee-0.5.1.dev0.dist-info}/licenses/LICENSE +0 -0
- {cognee-0.5.0.dev1.dist-info → cognee-0.5.1.dev0.dist-info}/licenses/NOTICE.md +0 -0
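
The excerpt below covers the rewritten chain-of-thought retriever tests. For orientation, here is a minimal, self-contained sketch of the mocking pattern those tests rely on (pytest with pytest-asyncio assumed; GraphCompletionCotRetriever, Edge, and the patched brute_force_triplet_search path are taken from the excerpt, the rest is illustrative):

import pytest
from unittest.mock import AsyncMock, MagicMock, patch

from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
from cognee.modules.retrieval.graph_completion_cot_retriever import GraphCompletionCotRetriever


@pytest.mark.asyncio
async def test_get_triplets_with_mocked_search():
    # A stand-in Edge, so no graph database is needed.
    mock_edge = MagicMock(spec=Edge)
    retriever = GraphCompletionCotRetriever()

    # Patch the triplet search where the parent retriever imports it,
    # so the call never reaches a graph store or an LLM.
    with patch(
        "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
        new=AsyncMock(return_value=[mock_edge]),
    ):
        triplets = await retriever.get_triplets("test query")

    assert triplets == [mock_edge]

Patching the search function in the module that imports it, rather than where it is defined, is what lets the tests below run without a database or an LLM behind them.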
cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py

@@ -1,170 +1,688 @@
-import os
 import pytest
-import
-from
-
-import
-
-
-
-from cognee.modules.
-
-
-
-
-
-
-
+from unittest.mock import AsyncMock, patch, MagicMock
+from uuid import UUID
+
+from cognee.modules.retrieval.graph_completion_cot_retriever import (
+    GraphCompletionCotRetriever,
+    _as_answer_text,
+)
+from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
+from cognee.infrastructure.llm.LLMGateway import LLMGateway
+
+
+@pytest.fixture
+def mock_edge():
+    """Create a mock edge."""
+    edge = MagicMock(spec=Edge)
+    return edge
+
+
+@pytest.mark.asyncio
+async def test_get_triplets_inherited(mock_edge):
+    """Test that get_triplets is inherited from parent class."""
+    retriever = GraphCompletionCotRetriever()
+
+    with patch(
+        "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
+        return_value=[mock_edge],
+    ):
+        triplets = await retriever.get_triplets("test query")
+
+    assert len(triplets) == 1
+    assert triplets[0] == mock_edge
+
+
+@pytest.mark.asyncio
+async def test_init_custom_params():
+    """Test GraphCompletionCotRetriever initialization with custom parameters."""
+    retriever = GraphCompletionCotRetriever(
+        top_k=10,
+        user_prompt_path="custom_user.txt",
+        system_prompt_path="custom_system.txt",
+        validation_user_prompt_path="custom_validation_user.txt",
+        validation_system_prompt_path="custom_validation_system.txt",
+        followup_system_prompt_path="custom_followup_system.txt",
+        followup_user_prompt_path="custom_followup_user.txt",
+    )
+
+    assert retriever.top_k == 10
+    assert retriever.user_prompt_path == "custom_user.txt"
+    assert retriever.system_prompt_path == "custom_system.txt"
+    assert retriever.validation_user_prompt_path == "custom_validation_user.txt"
+    assert retriever.validation_system_prompt_path == "custom_validation_system.txt"
+    assert retriever.followup_system_prompt_path == "custom_followup_system.txt"
+    assert retriever.followup_user_prompt_path == "custom_followup_user.txt"
+
+
+@pytest.mark.asyncio
+async def test_init_defaults():
+    """Test GraphCompletionCotRetriever initialization with defaults."""
+    retriever = GraphCompletionCotRetriever()
+
+    assert retriever.validation_user_prompt_path == "cot_validation_user_prompt.txt"
+    assert retriever.validation_system_prompt_path == "cot_validation_system_prompt.txt"
+    assert retriever.followup_system_prompt_path == "cot_followup_system_prompt.txt"
+    assert retriever.followup_user_prompt_path == "cot_followup_user_prompt.txt"
+
+
+@pytest.mark.asyncio
+async def test_run_cot_completion_round_zero_with_context(mock_edge):
+    """Test _run_cot_completion round 0 with provided context."""
+    retriever = GraphCompletionCotRetriever()
+
+    with (
+        patch(
+            "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
+            return_value="Resolved context",
+        ),
+        patch(
+            "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
+            return_value="Generated answer",
+        ),
+        patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]),
+        patch(
+            "cognee.modules.retrieval.graph_completion_cot_retriever._as_answer_text",
+            return_value="Generated answer",
+        ),
+        patch(
+            "cognee.modules.retrieval.graph_completion_cot_retriever.render_prompt",
+            return_value="Rendered prompt",
+        ),
+        patch(
+            "cognee.modules.retrieval.graph_completion_cot_retriever.read_query_prompt",
+            return_value="System prompt",
+        ),
+        patch.object(
+            LLMGateway,
+            "acreate_structured_output",
+            new_callable=AsyncMock,
+            side_effect=["validation_result", "followup_question"],
+        ),
+    ):
+        completion, context_text, triplets = await retriever._run_cot_completion(
+            query="test query",
+            context=[mock_edge],
+            max_iter=1,
         )
-    cognee.config.system_root_directory(system_directory_path)
-    data_directory_path = os.path.join(
-        pathlib.Path(__file__).parent, ".data_storage/test_graph_completion_cot_context_simple"
-    )
-    cognee.config.data_root_directory(data_directory_path)
-
-    await cognee.prune.prune_data()
-    await cognee.prune.prune_system(metadata=True)
-    await setup()
-
-    class Company(DataPoint):
-        name: str
-
-    class Person(DataPoint):
-        name: str
-        works_for: Company
-
-    company1 = Company(name="Figma")
-    company2 = Company(name="Canva")
-    person1 = Person(name="Steve Rodger", works_for=company1)
-    person2 = Person(name="Ike Loma", works_for=company1)
-    person3 = Person(name="Jason Statham", works_for=company1)
-    person4 = Person(name="Mike Broski", works_for=company2)
-    person5 = Person(name="Christina Mayer", works_for=company2)
-
-    entities = [company1, company2, person1, person2, person3, person4, person5]
-
-    await add_data_points(entities)
-
-    retriever = GraphCompletionCotRetriever()
 
-
-
-
-
-
-
-
-
-
-
+    assert completion == "Generated answer"
+    assert context_text == "Resolved context"
+    assert len(triplets) >= 1
+
+
+@pytest.mark.asyncio
+async def test_run_cot_completion_round_zero_without_context(mock_edge):
+    """Test _run_cot_completion round 0 without provided context."""
+    mock_graph_engine = AsyncMock()
+    mock_graph_engine.is_empty = AsyncMock(return_value=False)
+
+    retriever = GraphCompletionCotRetriever()
+
+    with (
+        patch(
+            "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
+            return_value=mock_graph_engine,
+        ),
+        patch(
+            "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
+            return_value=[mock_edge],
+        ),
+        patch(
+            "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
+            return_value="Resolved context",
+        ),
+        patch(
+            "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
+            return_value="Generated answer",
+        ),
+    ):
+        completion, context_text, triplets = await retriever._run_cot_completion(
+            query="test query",
+            context=None,
+            max_iter=1,
         )
 
-
-
-
-
-
+    assert completion == "Generated answer"
+    assert context_text == "Resolved context"
+    assert len(triplets) >= 1
+
+
+@pytest.mark.asyncio
+async def test_run_cot_completion_multiple_rounds(mock_edge):
+    """Test _run_cot_completion with multiple rounds."""
+    retriever = GraphCompletionCotRetriever()
+
+    mock_edge2 = MagicMock(spec=Edge)
+
+    with (
+        patch(
+            "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
+            return_value="Resolved context",
+        ),
+        patch(
+            "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
+            return_value="Generated answer",
+        ),
+        patch.object(
+            retriever,
+            "get_context",
+            new_callable=AsyncMock,
+            side_effect=[[mock_edge], [mock_edge2]],
+        ),
+        patch(
+            "cognee.modules.retrieval.graph_completion_cot_retriever.render_prompt",
+            return_value="Rendered prompt",
+        ),
+        patch(
+            "cognee.modules.retrieval.graph_completion_cot_retriever.read_query_prompt",
+            return_value="System prompt",
+        ),
+        patch.object(
+            LLMGateway,
+            "acreate_structured_output",
+            new_callable=AsyncMock,
+            side_effect=[
+                "validation_result",
+                "followup_question",
+                "validation_result2",
+                "followup_question2",
+            ],
+        ),
+        patch(
+            "cognee.modules.retrieval.graph_completion_cot_retriever._as_answer_text",
+            return_value="Generated answer",
+        ),
+    ):
+        completion, context_text, triplets = await retriever._run_cot_completion(
+            query="test query",
+            context=[mock_edge],
+            max_iter=2,
         )
-    cognee.config.system_root_directory(system_directory_path)
-    data_directory_path = os.path.join(
-        pathlib.Path(__file__).parent, ".data_storage/test_graph_completion_cot_context_complex"
-    )
-    cognee.config.data_root_directory(data_directory_path)
-
-    await cognee.prune.prune_data()
-    await cognee.prune.prune_system(metadata=True)
-    await setup()
-
-    class Company(DataPoint):
-        name: str
-        metadata: dict = {"index_fields": ["name"]}
-
-    class Car(DataPoint):
-        brand: str
-        model: str
-        year: int
-
-    class Location(DataPoint):
-        country: str
-        city: str
-
-    class Home(DataPoint):
-        location: Location
-        rooms: int
-        sqm: int
-
-    class Person(DataPoint):
-        name: str
-        works_for: Company
-        owns: Optional[list[Union[Car, Home]]] = None
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-    context = await resolve_edges_to_text(await retriever.get_context("Who works at Figma?"))
-
-    print(context)
-
-    assert "Mike Rodger --[works_for]--> Figma" in context, "Failed to get Mike Rodger"
-    assert "Ike Loma --[works_for]--> Figma" in context, "Failed to get Ike Loma"
-    assert "Jason Statham --[works_for]--> Figma" in context, "Failed to get Jason Statham"
-
-    answer = await retriever.get_completion("Who works at Figma?")
-
-    assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
-    assert all(isinstance(item, str) and item.strip() for item in answer), (
-        "Answer must contain only non-empty strings"
+    assert completion == "Generated answer"
+    assert context_text == "Resolved context"
+    assert len(triplets) >= 1
+
+
+@pytest.mark.asyncio
+async def test_run_cot_completion_with_conversation_history(mock_edge):
+    """Test _run_cot_completion with conversation history."""
+    retriever = GraphCompletionCotRetriever()
+
+    with (
+        patch(
+            "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
+            return_value="Resolved context",
+        ),
+        patch(
+            "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
+            return_value="Generated answer",
+        ) as mock_generate,
+    ):
+        completion, context_text, triplets = await retriever._run_cot_completion(
+            query="test query",
+            context=[mock_edge],
+            conversation_history="Previous conversation",
+            max_iter=1,
         )
 
-
-
-
-
-
+    assert completion == "Generated answer"
+    call_kwargs = mock_generate.call_args[1]
+    assert call_kwargs.get("conversation_history") == "Previous conversation"
+
+
+@pytest.mark.asyncio
+async def test_run_cot_completion_with_response_model(mock_edge):
+    """Test _run_cot_completion with custom response model."""
+    from pydantic import BaseModel
+
+    class TestModel(BaseModel):
+        answer: str
+
+    retriever = GraphCompletionCotRetriever()
+
+    with (
+        patch(
+            "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
+            return_value="Resolved context",
+        ),
+        patch(
+            "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
+            return_value=TestModel(answer="Test answer"),
+        ),
+    ):
+        completion, context_text, triplets = await retriever._run_cot_completion(
+            query="test query",
+            context=[mock_edge],
+            response_model=TestModel,
+            max_iter=1,
         )
-    cognee.config.system_root_directory(system_directory_path)
-    data_directory_path = os.path.join(
-        pathlib.Path(__file__).parent,
-        ".data_storage/test_get_graph_completion_cot_context_on_empty_graph",
-    )
-    cognee.config.data_root_directory(data_directory_path)
-
-    await cognee.prune.prune_data()
-    await cognee.prune.prune_system(metadata=True)
 
-
-
-
-
-
-
+    assert isinstance(completion, TestModel)
+    assert completion.answer == "Test answer"
+
+
+@pytest.mark.asyncio
+async def test_run_cot_completion_empty_conversation_history(mock_edge):
+    """Test _run_cot_completion with empty conversation history."""
+    retriever = GraphCompletionCotRetriever()
+
+    with (
+        patch(
+            "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
+            return_value="Resolved context",
+        ),
+        patch(
+            "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
+            return_value="Generated answer",
+        ) as mock_generate,
+    ):
+        completion, context_text, triplets = await retriever._run_cot_completion(
+            query="test query",
+            context=[mock_edge],
+            conversation_history="",
+            max_iter=1,
+        )
 
-
+    assert completion == "Generated answer"
+    # Verify conversation_history was passed as None when empty
+    call_kwargs = mock_generate.call_args[1]
+    assert call_kwargs.get("conversation_history") is None
+
+
+@pytest.mark.asyncio
+async def test_get_completion_without_context(mock_edge):
+    """Test get_completion retrieves context when not provided."""
+    mock_graph_engine = AsyncMock()
+    mock_graph_engine.is_empty = AsyncMock(return_value=False)
+
+    retriever = GraphCompletionCotRetriever()
+
+    with (
+        patch(
+            "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
+            return_value=mock_graph_engine,
+        ),
+        patch(
+            "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
+            return_value=[mock_edge],
+        ),
+        patch(
+            "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
+            return_value="Resolved context",
+        ),
+        patch(
+            "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
+            return_value="Generated answer",
+        ),
+        patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]),
+        patch(
+            "cognee.modules.retrieval.graph_completion_cot_retriever._as_answer_text",
+            return_value="Generated answer",
+        ),
+        patch(
+            "cognee.modules.retrieval.graph_completion_cot_retriever.render_prompt",
+            return_value="Rendered prompt",
+        ),
+        patch(
+            "cognee.modules.retrieval.graph_completion_cot_retriever.read_query_prompt",
+            return_value="System prompt",
+        ),
+        patch.object(
+            LLMGateway,
+            "acreate_structured_output",
+            new_callable=AsyncMock,
+            side_effect=["validation_result", "followup_question"],
+        ),
+        patch(
+            "cognee.modules.retrieval.graph_completion_cot_retriever.CacheConfig"
+        ) as mock_cache_config,
+    ):
+        mock_config = MagicMock()
+        mock_config.caching = False
+        mock_cache_config.return_value = mock_config
+
+        completion = await retriever.get_completion("test query", max_iter=1)
+
+    assert isinstance(completion, list)
+    assert len(completion) == 1
+    assert completion[0] == "Generated answer"
+
+
+@pytest.mark.asyncio
+async def test_get_completion_with_provided_context(mock_edge):
+    """Test get_completion uses provided context."""
+    retriever = GraphCompletionCotRetriever()
+
+    with (
+        patch(
+            "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
+            return_value="Resolved context",
+        ),
+        patch(
+            "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
+            return_value="Generated answer",
+        ),
+        patch(
+            "cognee.modules.retrieval.graph_completion_cot_retriever.CacheConfig"
+        ) as mock_cache_config,
+    ):
+        mock_config = MagicMock()
+        mock_config.caching = False
+        mock_cache_config.return_value = mock_config
+
+        completion = await retriever.get_completion("test query", context=[mock_edge], max_iter=1)
+
+    assert isinstance(completion, list)
+    assert len(completion) == 1
+    assert completion[0] == "Generated answer"
+
+
+@pytest.mark.asyncio
+async def test_get_completion_with_session(mock_edge):
+    """Test get_completion with session caching enabled."""
+    mock_graph_engine = AsyncMock()
+    mock_graph_engine.is_empty = AsyncMock(return_value=False)
+
+    retriever = GraphCompletionCotRetriever()
+
+    mock_user = MagicMock()
+    mock_user.id = "test-user-id"
+
+    with (
+        patch(
+            "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
+            return_value=mock_graph_engine,
+        ),
+        patch(
+            "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
+            return_value=[mock_edge],
+        ),
+        patch(
+            "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
+            return_value="Resolved context",
+        ),
+        patch(
+            "cognee.modules.retrieval.graph_completion_cot_retriever.get_conversation_history",
+            return_value="Previous conversation",
+        ),
+        patch(
+            "cognee.modules.retrieval.graph_completion_cot_retriever.summarize_text",
+            return_value="Context summary",
+        ),
+        patch(
+            "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
+            return_value="Generated answer",
+        ),
+        patch(
+            "cognee.modules.retrieval.graph_completion_cot_retriever.save_conversation_history",
+        ) as mock_save,
+        patch(
+            "cognee.modules.retrieval.graph_completion_cot_retriever.CacheConfig"
+        ) as mock_cache_config,
+        patch(
+            "cognee.modules.retrieval.graph_completion_cot_retriever.session_user"
+        ) as mock_session_user,
+    ):
+        mock_config = MagicMock()
+        mock_config.caching = True
+        mock_cache_config.return_value = mock_config
+        mock_session_user.get.return_value = mock_user
+
+        completion = await retriever.get_completion(
+            "test query", session_id="test_session", max_iter=1
+        )
 
-
-
-
+    assert isinstance(completion, list)
+    assert len(completion) == 1
+    assert completion[0] == "Generated answer"
+    mock_save.assert_awaited_once()
+
+
+@pytest.mark.asyncio
+async def test_get_completion_with_save_interaction(mock_edge):
+    """Test get_completion with save_interaction enabled."""
+    mock_graph_engine = AsyncMock()
+    mock_graph_engine.is_empty = AsyncMock(return_value=False)
+    mock_graph_engine.add_edges = AsyncMock()
+
+    retriever = GraphCompletionCotRetriever(save_interaction=True)
+
+    mock_node1 = MagicMock()
+    mock_node2 = MagicMock()
+    mock_edge.node1 = mock_node1
+    mock_edge.node2 = mock_node2
+
+    with (
+        patch(
+            "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
+            return_value="Resolved context",
+        ),
+        patch(
+            "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
+            return_value="Generated answer",
+        ),
+        patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]),
+        patch(
+            "cognee.modules.retrieval.graph_completion_cot_retriever._as_answer_text",
+            return_value="Generated answer",
+        ),
+        patch(
+            "cognee.modules.retrieval.graph_completion_cot_retriever.render_prompt",
+            return_value="Rendered prompt",
+        ),
+        patch(
+            "cognee.modules.retrieval.graph_completion_cot_retriever.read_query_prompt",
+            return_value="System prompt",
+        ),
+        patch.object(
+            LLMGateway,
+            "acreate_structured_output",
+            new_callable=AsyncMock,
+            side_effect=["validation_result", "followup_question"],
+        ),
+        patch(
+            "cognee.modules.retrieval.graph_completion_retriever.extract_uuid_from_node",
+            side_effect=[
+                UUID("550e8400-e29b-41d4-a716-446655440000"),
+                UUID("550e8400-e29b-41d4-a716-446655440001"),
+            ],
+        ),
+        patch(
+            "cognee.modules.retrieval.graph_completion_retriever.add_data_points",
+        ) as mock_add_data,
+        patch(
+            "cognee.modules.retrieval.graph_completion_cot_retriever.CacheConfig"
+        ) as mock_cache_config,
+    ):
+        mock_config = MagicMock()
+        mock_config.caching = False
+        mock_cache_config.return_value = mock_config
+
+        # Pass context so save_interaction condition is met
+        completion = await retriever.get_completion("test query", context=[mock_edge], max_iter=1)
+
+    assert isinstance(completion, list)
+    assert len(completion) == 1
+    mock_add_data.assert_awaited_once()
+
+
+@pytest.mark.asyncio
+async def test_get_completion_with_response_model(mock_edge):
+    """Test get_completion with custom response model."""
+    from pydantic import BaseModel
+
+    class TestModel(BaseModel):
+        answer: str
+
+    mock_graph_engine = AsyncMock()
+    mock_graph_engine.is_empty = AsyncMock(return_value=False)
+
+    retriever = GraphCompletionCotRetriever()
+
+    with (
+        patch(
+            "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
+            return_value=mock_graph_engine,
+        ),
+        patch(
+            "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
+            return_value=[mock_edge],
+        ),
+        patch(
+            "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
+            return_value="Resolved context",
+        ),
+        patch(
+            "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
+            return_value=TestModel(answer="Test answer"),
+        ),
+        patch(
+            "cognee.modules.retrieval.graph_completion_cot_retriever.CacheConfig"
+        ) as mock_cache_config,
+    ):
+        mock_config = MagicMock()
+        mock_config.caching = False
+        mock_cache_config.return_value = mock_config
+
+        completion = await retriever.get_completion(
+            "test query", response_model=TestModel, max_iter=1
         )
+
+    assert isinstance(completion, list)
+    assert len(completion) == 1
+    assert isinstance(completion[0], TestModel)
+
+
+@pytest.mark.asyncio
+async def test_get_completion_with_session_no_user_id(mock_edge):
+    """Test get_completion with session config but no user ID."""
+    mock_graph_engine = AsyncMock()
+    mock_graph_engine.is_empty = AsyncMock(return_value=False)
+
+    retriever = GraphCompletionCotRetriever()
+
+    with (
+        patch(
+            "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
+            return_value=mock_graph_engine,
+        ),
+        patch(
+            "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
+            return_value=[mock_edge],
+        ),
+        patch(
+            "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
+            return_value="Resolved context",
+        ),
+        patch(
+            "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
+            return_value="Generated answer",
+        ),
+        patch(
+            "cognee.modules.retrieval.graph_completion_cot_retriever.CacheConfig"
+        ) as mock_cache_config,
+        patch(
+            "cognee.modules.retrieval.graph_completion_cot_retriever.session_user"
+        ) as mock_session_user,
+    ):
+        mock_config = MagicMock()
+        mock_config.caching = True
+        mock_cache_config.return_value = mock_config
+        mock_session_user.get.return_value = None  # No user
+
+        completion = await retriever.get_completion("test query", max_iter=1)
+
+    assert isinstance(completion, list)
+    assert len(completion) == 1
+
+
+@pytest.mark.asyncio
+async def test_get_completion_with_save_interaction_no_context(mock_edge):
+    """Test get_completion with save_interaction but no context provided."""
+    retriever = GraphCompletionCotRetriever(save_interaction=True)
+
+    with (
+        patch(
+            "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
+            return_value="Resolved context",
+        ),
+        patch(
+            "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
+            return_value="Generated answer",
+        ),
+        patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]),
+        patch(
+            "cognee.modules.retrieval.graph_completion_cot_retriever._as_answer_text",
+            return_value="Generated answer",
+        ),
+        patch(
+            "cognee.modules.retrieval.graph_completion_cot_retriever.render_prompt",
+            return_value="Rendered prompt",
+        ),
+        patch(
+            "cognee.modules.retrieval.graph_completion_cot_retriever.read_query_prompt",
+            return_value="System prompt",
+        ),
+        patch.object(
+            LLMGateway,
+            "acreate_structured_output",
+            new_callable=AsyncMock,
+            side_effect=["validation_result", "followup_question"],
+        ),
+        patch(
+            "cognee.modules.retrieval.graph_completion_cot_retriever.CacheConfig"
+        ) as mock_cache_config,
+    ):
+        mock_config = MagicMock()
+        mock_config.caching = False
+        mock_cache_config.return_value = mock_config
+
+        completion = await retriever.get_completion("test query", context=None, max_iter=1)
+
+    assert isinstance(completion, list)
+    assert len(completion) == 1
+
+
+@pytest.mark.asyncio
+async def test_as_answer_text_with_typeerror():
+    """Test _as_answer_text handles TypeError when json.dumps fails."""
+    non_serializable = {1, 2, 3}
+
+    result = _as_answer_text(non_serializable)
+
+    assert isinstance(result, str)
+    assert result == str(non_serializable)
+
+
+@pytest.mark.asyncio
+async def test_as_answer_text_with_string():
+    """Test _as_answer_text with string input."""
+    result = _as_answer_text("test string")
+    assert result == "test string"
+
+
+@pytest.mark.asyncio
+async def test_as_answer_text_with_dict():
+    """Test _as_answer_text with dictionary input."""
+    test_dict = {"key": "value", "number": 42}
+    result = _as_answer_text(test_dict)
+    assert isinstance(result, str)
+    assert "key" in result
+    assert "value" in result
+
+
+@pytest.mark.asyncio
+async def test_as_answer_text_with_basemodel():
+    """Test _as_answer_text with Pydantic BaseModel input."""
+    from pydantic import BaseModel
+
+    class TestModel(BaseModel):
+        answer: str
+
+    test_model = TestModel(answer="test answer")
+    result = _as_answer_text(test_model)
+
+    assert isinstance(result, str)
+    assert "[Structured Response]" in result
+    assert "test answer" in result