cognee 0.5.1__py3-none-any.whl → 0.5.1.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cognee/api/v1/add/add.py +2 -1
- cognee/api/v1/datasets/routers/get_datasets_router.py +1 -0
- cognee/api/v1/memify/routers/get_memify_router.py +1 -0
- cognee/api/v1/search/search.py +0 -4
- cognee/infrastructure/databases/relational/config.py +16 -1
- cognee/infrastructure/databases/relational/create_relational_engine.py +13 -3
- cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py +24 -2
- cognee/infrastructure/databases/vector/create_vector_engine.py +9 -2
- cognee/infrastructure/llm/LLMGateway.py +0 -13
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py +17 -12
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py +31 -25
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py +132 -7
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py +5 -5
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/llm_interface.py +2 -6
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py +58 -13
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/adapter.py +0 -1
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py +25 -131
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/types.py +10 -0
- cognee/modules/data/models/Data.py +2 -1
- cognee/modules/retrieval/triplet_retriever.py +1 -1
- cognee/modules/retrieval/utils/brute_force_triplet_search.py +0 -18
- cognee/modules/search/methods/search.py +18 -25
- cognee/tasks/ingestion/data_item.py +8 -0
- cognee/tasks/ingestion/ingest_data.py +12 -1
- cognee/tasks/ingestion/save_data_item_to_storage.py +5 -0
- cognee/tests/integration/retrieval/test_chunks_retriever.py +252 -0
- cognee/tests/integration/retrieval/test_graph_completion_retriever.py +268 -0
- cognee/tests/integration/retrieval/test_graph_completion_retriever_context_extension.py +226 -0
- cognee/tests/integration/retrieval/test_graph_completion_retriever_cot.py +218 -0
- cognee/tests/integration/retrieval/test_rag_completion_retriever.py +254 -0
- cognee/tests/{unit/modules/retrieval/structured_output_test.py → integration/retrieval/test_structured_output.py} +87 -77
- cognee/tests/integration/retrieval/test_summaries_retriever.py +184 -0
- cognee/tests/integration/retrieval/test_temporal_retriever.py +306 -0
- cognee/tests/integration/retrieval/test_triplet_retriever.py +35 -0
- cognee/tests/test_custom_data_label.py +68 -0
- cognee/tests/test_search_db.py +334 -181
- cognee/tests/unit/eval_framework/benchmark_adapters_test.py +25 -0
- cognee/tests/unit/eval_framework/corpus_builder_test.py +33 -4
- cognee/tests/unit/infrastructure/databases/relational/test_RelationalConfig.py +69 -0
- cognee/tests/unit/modules/retrieval/chunks_retriever_test.py +181 -199
- cognee/tests/unit/modules/retrieval/conversation_history_test.py +338 -0
- cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py +454 -162
- cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py +674 -156
- cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py +625 -200
- cognee/tests/unit/modules/retrieval/rag_completion_retriever_test.py +319 -203
- cognee/tests/unit/modules/retrieval/summaries_retriever_test.py +189 -155
- cognee/tests/unit/modules/retrieval/temporal_retriever_test.py +539 -58
- cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py +218 -9
- cognee/tests/unit/modules/retrieval/test_completion.py +343 -0
- cognee/tests/unit/modules/retrieval/test_graph_summary_completion_retriever.py +157 -0
- cognee/tests/unit/modules/retrieval/test_user_qa_feedback.py +312 -0
- cognee/tests/unit/modules/retrieval/triplet_retriever_test.py +246 -0
- {cognee-0.5.1.dist-info → cognee-0.5.1.dev0.dist-info}/METADATA +1 -1
- {cognee-0.5.1.dist-info → cognee-0.5.1.dev0.dist-info}/RECORD +58 -45
- cognee/tests/unit/modules/search/test_search.py +0 -100
- {cognee-0.5.1.dist-info → cognee-0.5.1.dev0.dist-info}/WHEEL +0 -0
- {cognee-0.5.1.dist-info → cognee-0.5.1.dev0.dist-info}/entry_points.txt +0 -0
- {cognee-0.5.1.dist-info → cognee-0.5.1.dev0.dist-info}/licenses/LICENSE +0 -0
- {cognee-0.5.1.dist-info → cognee-0.5.1.dev0.dist-info}/licenses/NOTICE.md +0 -0
|
@@ -1,223 +1,648 @@
|
|
|
1
|
-
import os
|
|
2
1
|
import pytest
|
|
3
|
-
import
|
|
4
|
-
from
|
|
2
|
+
from unittest.mock import AsyncMock, patch, MagicMock
|
|
3
|
+
from uuid import UUID
|
|
5
4
|
|
|
6
|
-
import cognee
|
|
7
|
-
from cognee.low_level import setup, DataPoint
|
|
8
|
-
from cognee.modules.graph.utils import resolve_edges_to_text
|
|
9
|
-
from cognee.tasks.storage import add_data_points
|
|
10
5
|
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
|
|
6
|
+
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
|
|
11
7
|
|
|
12
8
|
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
)
|
|
19
|
-
cognee.config.system_root_directory(system_directory_path)
|
|
20
|
-
data_directory_path = os.path.join(
|
|
21
|
-
pathlib.Path(__file__).parent, ".data_storage/test_graph_completion_context_simple"
|
|
22
|
-
)
|
|
23
|
-
cognee.config.data_root_directory(data_directory_path)
|
|
24
|
-
|
|
25
|
-
await cognee.prune.prune_data()
|
|
26
|
-
await cognee.prune.prune_system(metadata=True)
|
|
27
|
-
await setup()
|
|
28
|
-
|
|
29
|
-
class Company(DataPoint):
|
|
30
|
-
name: str
|
|
31
|
-
description: str
|
|
32
|
-
|
|
33
|
-
class Person(DataPoint):
|
|
34
|
-
name: str
|
|
35
|
-
description: str
|
|
36
|
-
works_for: Company
|
|
37
|
-
|
|
38
|
-
company1 = Company(name="Figma", description="Figma is a company")
|
|
39
|
-
company2 = Company(name="Canva", description="Canvas is a company")
|
|
40
|
-
person1 = Person(
|
|
41
|
-
name="Steve Rodger",
|
|
42
|
-
description="This is description about Steve Rodger",
|
|
43
|
-
works_for=company1,
|
|
44
|
-
)
|
|
45
|
-
person2 = Person(
|
|
46
|
-
name="Ike Loma", description="This is description about Ike Loma", works_for=company1
|
|
47
|
-
)
|
|
48
|
-
person3 = Person(
|
|
49
|
-
name="Jason Statham",
|
|
50
|
-
description="This is description about Jason Statham",
|
|
51
|
-
works_for=company1,
|
|
52
|
-
)
|
|
53
|
-
person4 = Person(
|
|
54
|
-
name="Mike Broski",
|
|
55
|
-
description="This is description about Mike Broski",
|
|
56
|
-
works_for=company2,
|
|
57
|
-
)
|
|
58
|
-
person5 = Person(
|
|
59
|
-
name="Christina Mayer",
|
|
60
|
-
description="This is description about Christina Mayer",
|
|
61
|
-
works_for=company2,
|
|
62
|
-
)
|
|
9
|
+
@pytest.fixture
|
|
10
|
+
def mock_edge():
|
|
11
|
+
"""Create a mock edge."""
|
|
12
|
+
edge = MagicMock(spec=Edge)
|
|
13
|
+
return edge
|
|
63
14
|
|
|
64
|
-
entities = [company1, company2, person1, person2, person3, person4, person5]
|
|
65
15
|
|
|
66
|
-
|
|
16
|
+
@pytest.mark.asyncio
|
|
17
|
+
async def test_get_triplets_success(mock_edge):
|
|
18
|
+
"""Test successful retrieval of triplets."""
|
|
19
|
+
retriever = GraphCompletionRetriever(top_k=5)
|
|
67
20
|
|
|
68
|
-
|
|
21
|
+
with patch(
|
|
22
|
+
"cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
|
|
23
|
+
return_value=[mock_edge],
|
|
24
|
+
) as mock_search:
|
|
25
|
+
triplets = await retriever.get_triplets("test query")
|
|
69
26
|
|
|
70
|
-
|
|
27
|
+
assert len(triplets) == 1
|
|
28
|
+
assert triplets[0] == mock_edge
|
|
29
|
+
mock_search.assert_awaited_once()
|
|
71
30
|
|
|
72
|
-
# Ensure the top-level sections are present
|
|
73
|
-
assert "Nodes:" in context, "Missing 'Nodes:' section in context"
|
|
74
|
-
assert "Connections:" in context, "Missing 'Connections:' section in context"
|
|
75
31
|
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
assert "Node: Jason Statham" in context, "Missing node header for Jason Statham"
|
|
81
|
-
assert "Node: Mike Broski" in context, "Missing node header for Mike Broski"
|
|
82
|
-
assert "Node: Canva" in context, "Missing node header for Canva"
|
|
83
|
-
assert "Node: Christina Mayer" in context, "Missing node header for Christina Mayer"
|
|
32
|
+
@pytest.mark.asyncio
|
|
33
|
+
async def test_get_triplets_empty_results():
|
|
34
|
+
"""Test that empty list is returned when no triplets are found."""
|
|
35
|
+
retriever = GraphCompletionRetriever()
|
|
84
36
|
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
assert "__node_content_start__\nFigma is a company\n__node_content_end__" in context, (
|
|
91
|
-
"Description block for Figma altered"
|
|
92
|
-
)
|
|
93
|
-
assert (
|
|
94
|
-
"__node_content_start__\nThis is description about Ike Loma\n__node_content_end__"
|
|
95
|
-
in context
|
|
96
|
-
), "Description block for Ike Loma altered"
|
|
97
|
-
assert (
|
|
98
|
-
"__node_content_start__\nThis is description about Jason Statham\n__node_content_end__"
|
|
99
|
-
in context
|
|
100
|
-
), "Description block for Jason Statham altered"
|
|
101
|
-
assert (
|
|
102
|
-
"__node_content_start__\nThis is description about Mike Broski\n__node_content_end__"
|
|
103
|
-
in context
|
|
104
|
-
), "Description block for Mike Broski altered"
|
|
105
|
-
assert "__node_content_start__\nCanvas is a company\n__node_content_end__" in context, (
|
|
106
|
-
"Description block for Canva altered"
|
|
107
|
-
)
|
|
108
|
-
assert (
|
|
109
|
-
"__node_content_start__\nThis is description about Christina Mayer\n__node_content_end__"
|
|
110
|
-
in context
|
|
111
|
-
), "Description block for Christina Mayer altered"
|
|
112
|
-
|
|
113
|
-
# --- Connections ---
|
|
114
|
-
assert "Steve Rodger --[works_for]--> Figma" in context, (
|
|
115
|
-
"Connection Steve Rodger→Figma missing or changed"
|
|
116
|
-
)
|
|
117
|
-
assert "Ike Loma --[works_for]--> Figma" in context, (
|
|
118
|
-
"Connection Ike Loma→Figma missing or changed"
|
|
119
|
-
)
|
|
120
|
-
assert "Jason Statham --[works_for]--> Figma" in context, (
|
|
121
|
-
"Connection Jason Statham→Figma missing or changed"
|
|
122
|
-
)
|
|
123
|
-
assert "Mike Broski --[works_for]--> Canva" in context, (
|
|
124
|
-
"Connection Mike Broski→Canva missing or changed"
|
|
125
|
-
)
|
|
126
|
-
assert "Christina Mayer --[works_for]--> Canva" in context, (
|
|
127
|
-
"Connection Christina Mayer→Canva missing or changed"
|
|
128
|
-
)
|
|
129
|
-
|
|
130
|
-
@pytest.mark.asyncio
|
|
131
|
-
async def test_graph_completion_context_complex(self):
|
|
132
|
-
system_directory_path = os.path.join(
|
|
133
|
-
pathlib.Path(__file__).parent, ".cognee_system/test_graph_completion_context_complex"
|
|
134
|
-
)
|
|
135
|
-
cognee.config.system_root_directory(system_directory_path)
|
|
136
|
-
data_directory_path = os.path.join(
|
|
137
|
-
pathlib.Path(__file__).parent, ".data_storage/test_graph_completion_context_complex"
|
|
138
|
-
)
|
|
139
|
-
cognee.config.data_root_directory(data_directory_path)
|
|
37
|
+
with patch(
|
|
38
|
+
"cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
|
|
39
|
+
return_value=[],
|
|
40
|
+
):
|
|
41
|
+
triplets = await retriever.get_triplets("test query")
|
|
140
42
|
|
|
141
|
-
|
|
142
|
-
await cognee.prune.prune_system(metadata=True)
|
|
143
|
-
await setup()
|
|
43
|
+
assert triplets == []
|
|
144
44
|
|
|
145
|
-
class Company(DataPoint):
|
|
146
|
-
name: str
|
|
147
|
-
metadata: dict = {"index_fields": ["name"]}
|
|
148
45
|
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
46
|
+
@pytest.mark.asyncio
|
|
47
|
+
async def test_get_triplets_top_k_parameter():
|
|
48
|
+
"""Test that top_k parameter is passed to brute_force_triplet_search."""
|
|
49
|
+
retriever = GraphCompletionRetriever(top_k=10)
|
|
153
50
|
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
51
|
+
with patch(
|
|
52
|
+
"cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
|
|
53
|
+
return_value=[],
|
|
54
|
+
) as mock_search:
|
|
55
|
+
await retriever.get_triplets("test query")
|
|
157
56
|
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
rooms: int
|
|
161
|
-
sqm: int
|
|
57
|
+
call_kwargs = mock_search.call_args[1]
|
|
58
|
+
assert call_kwargs["top_k"] == 10
|
|
162
59
|
|
|
163
|
-
class Person(DataPoint):
|
|
164
|
-
name: str
|
|
165
|
-
works_for: Company
|
|
166
|
-
owns: Optional[list[Union[Car, Home]]] = None
|
|
167
60
|
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
61
|
+
@pytest.mark.asyncio
|
|
62
|
+
async def test_get_context_success(mock_edge):
|
|
63
|
+
"""Test successful retrieval of context."""
|
|
64
|
+
retriever = GraphCompletionRetriever()
|
|
65
|
+
|
|
66
|
+
mock_graph_engine = AsyncMock()
|
|
67
|
+
mock_graph_engine.is_empty = AsyncMock(return_value=False)
|
|
68
|
+
|
|
69
|
+
with (
|
|
70
|
+
patch(
|
|
71
|
+
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
|
|
72
|
+
return_value=mock_graph_engine,
|
|
73
|
+
),
|
|
74
|
+
patch(
|
|
75
|
+
"cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
|
|
76
|
+
return_value=[mock_edge],
|
|
77
|
+
),
|
|
78
|
+
):
|
|
79
|
+
context = await retriever.get_context("test query")
|
|
80
|
+
|
|
81
|
+
assert isinstance(context, list)
|
|
82
|
+
assert len(context) == 1
|
|
83
|
+
assert context[0] == mock_edge
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
@pytest.mark.asyncio
|
|
87
|
+
async def test_get_context_empty_results():
|
|
88
|
+
"""Test that empty list is returned when no context is found."""
|
|
89
|
+
retriever = GraphCompletionRetriever()
|
|
90
|
+
|
|
91
|
+
mock_graph_engine = AsyncMock()
|
|
92
|
+
mock_graph_engine.is_empty = AsyncMock(return_value=False)
|
|
93
|
+
|
|
94
|
+
with (
|
|
95
|
+
patch(
|
|
96
|
+
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
|
|
97
|
+
return_value=mock_graph_engine,
|
|
98
|
+
),
|
|
99
|
+
patch(
|
|
100
|
+
"cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
|
|
101
|
+
return_value=[],
|
|
102
|
+
),
|
|
103
|
+
):
|
|
104
|
+
context = await retriever.get_context("test query")
|
|
105
|
+
|
|
106
|
+
assert context == []
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
@pytest.mark.asyncio
|
|
110
|
+
async def test_get_context_empty_graph():
|
|
111
|
+
"""Test that empty list is returned when graph is empty."""
|
|
112
|
+
retriever = GraphCompletionRetriever()
|
|
113
|
+
|
|
114
|
+
mock_graph_engine = AsyncMock()
|
|
115
|
+
mock_graph_engine.is_empty = AsyncMock(return_value=True)
|
|
116
|
+
|
|
117
|
+
with patch(
|
|
118
|
+
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
|
|
119
|
+
return_value=mock_graph_engine,
|
|
120
|
+
):
|
|
121
|
+
context = await retriever.get_context("test query")
|
|
122
|
+
|
|
123
|
+
assert context == []
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
@pytest.mark.asyncio
|
|
127
|
+
async def test_resolve_edges_to_text(mock_edge):
|
|
128
|
+
"""Test resolve_edges_to_text method."""
|
|
129
|
+
retriever = GraphCompletionRetriever()
|
|
130
|
+
|
|
131
|
+
with patch(
|
|
132
|
+
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
|
|
133
|
+
return_value="Resolved text",
|
|
134
|
+
) as mock_resolve:
|
|
135
|
+
result = await retriever.resolve_edges_to_text([mock_edge])
|
|
136
|
+
|
|
137
|
+
assert result == "Resolved text"
|
|
138
|
+
mock_resolve.assert_awaited_once_with([mock_edge])
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
@pytest.mark.asyncio
|
|
142
|
+
async def test_init_defaults():
|
|
143
|
+
"""Test GraphCompletionRetriever initialization with defaults."""
|
|
144
|
+
retriever = GraphCompletionRetriever()
|
|
145
|
+
|
|
146
|
+
assert retriever.top_k == 5
|
|
147
|
+
assert retriever.user_prompt_path == "graph_context_for_question.txt"
|
|
148
|
+
assert retriever.system_prompt_path == "answer_simple_question.txt"
|
|
149
|
+
assert retriever.node_type is None
|
|
150
|
+
assert retriever.node_name is None
|
|
151
|
+
|
|
152
|
+
|
|
153
|
+
@pytest.mark.asyncio
|
|
154
|
+
async def test_init_custom_params():
|
|
155
|
+
"""Test GraphCompletionRetriever initialization with custom parameters."""
|
|
156
|
+
retriever = GraphCompletionRetriever(
|
|
157
|
+
top_k=10,
|
|
158
|
+
user_prompt_path="custom_user.txt",
|
|
159
|
+
system_prompt_path="custom_system.txt",
|
|
160
|
+
system_prompt="Custom prompt",
|
|
161
|
+
node_type=str,
|
|
162
|
+
node_name=["node1"],
|
|
163
|
+
save_interaction=True,
|
|
164
|
+
wide_search_top_k=200,
|
|
165
|
+
triplet_distance_penalty=5.0,
|
|
166
|
+
)
|
|
167
|
+
|
|
168
|
+
assert retriever.top_k == 10
|
|
169
|
+
assert retriever.user_prompt_path == "custom_user.txt"
|
|
170
|
+
assert retriever.system_prompt_path == "custom_system.txt"
|
|
171
|
+
assert retriever.system_prompt == "Custom prompt"
|
|
172
|
+
assert retriever.node_type is str
|
|
173
|
+
assert retriever.node_name == ["node1"]
|
|
174
|
+
assert retriever.save_interaction is True
|
|
175
|
+
assert retriever.wide_search_top_k == 200
|
|
176
|
+
assert retriever.triplet_distance_penalty == 5.0
|
|
177
|
+
|
|
178
|
+
|
|
179
|
+
@pytest.mark.asyncio
|
|
180
|
+
async def test_init_none_top_k():
|
|
181
|
+
"""Test GraphCompletionRetriever initialization with None top_k."""
|
|
182
|
+
retriever = GraphCompletionRetriever(top_k=None)
|
|
183
|
+
|
|
184
|
+
assert retriever.top_k == 5 # None defaults to 5
|
|
185
|
+
|
|
186
|
+
|
|
187
|
+
@pytest.mark.asyncio
|
|
188
|
+
async def test_convert_retrieved_objects_to_context(mock_edge):
|
|
189
|
+
"""Test convert_retrieved_objects_to_context method."""
|
|
190
|
+
retriever = GraphCompletionRetriever()
|
|
191
|
+
|
|
192
|
+
with patch(
|
|
193
|
+
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
|
|
194
|
+
return_value="Resolved text",
|
|
195
|
+
) as mock_resolve:
|
|
196
|
+
result = await retriever.convert_retrieved_objects_to_context([mock_edge])
|
|
197
|
+
|
|
198
|
+
assert result == "Resolved text"
|
|
199
|
+
mock_resolve.assert_awaited_once_with([mock_edge])
|
|
200
|
+
|
|
201
|
+
|
|
202
|
+
@pytest.mark.asyncio
|
|
203
|
+
async def test_get_completion_without_context(mock_edge):
|
|
204
|
+
"""Test get_completion retrieves context when not provided."""
|
|
205
|
+
mock_graph_engine = AsyncMock()
|
|
206
|
+
mock_graph_engine.is_empty = AsyncMock(return_value=False)
|
|
207
|
+
|
|
208
|
+
retriever = GraphCompletionRetriever()
|
|
209
|
+
|
|
210
|
+
with (
|
|
211
|
+
patch(
|
|
212
|
+
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
|
|
213
|
+
return_value=mock_graph_engine,
|
|
214
|
+
),
|
|
215
|
+
patch(
|
|
216
|
+
"cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
|
|
217
|
+
return_value=[mock_edge],
|
|
218
|
+
),
|
|
219
|
+
patch(
|
|
220
|
+
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
|
|
221
|
+
return_value="Resolved context",
|
|
222
|
+
),
|
|
223
|
+
patch(
|
|
224
|
+
"cognee.modules.retrieval.graph_completion_retriever.generate_completion",
|
|
225
|
+
return_value="Generated answer",
|
|
226
|
+
),
|
|
227
|
+
patch(
|
|
228
|
+
"cognee.modules.retrieval.graph_completion_retriever.CacheConfig"
|
|
229
|
+
) as mock_cache_config,
|
|
230
|
+
):
|
|
231
|
+
mock_config = MagicMock()
|
|
232
|
+
mock_config.caching = False
|
|
233
|
+
mock_cache_config.return_value = mock_config
|
|
234
|
+
|
|
235
|
+
completion = await retriever.get_completion("test query")
|
|
236
|
+
|
|
237
|
+
assert isinstance(completion, list)
|
|
238
|
+
assert len(completion) == 1
|
|
239
|
+
assert completion[0] == "Generated answer"
|
|
240
|
+
|
|
241
|
+
|
|
242
|
+
@pytest.mark.asyncio
|
|
243
|
+
async def test_get_completion_with_provided_context(mock_edge):
|
|
244
|
+
"""Test get_completion uses provided context."""
|
|
245
|
+
retriever = GraphCompletionRetriever()
|
|
246
|
+
|
|
247
|
+
with (
|
|
248
|
+
patch(
|
|
249
|
+
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
|
|
250
|
+
return_value="Resolved context",
|
|
251
|
+
),
|
|
252
|
+
patch(
|
|
253
|
+
"cognee.modules.retrieval.graph_completion_retriever.generate_completion",
|
|
254
|
+
return_value="Generated answer",
|
|
255
|
+
),
|
|
256
|
+
patch(
|
|
257
|
+
"cognee.modules.retrieval.graph_completion_retriever.CacheConfig"
|
|
258
|
+
) as mock_cache_config,
|
|
259
|
+
):
|
|
260
|
+
mock_config = MagicMock()
|
|
261
|
+
mock_config.caching = False
|
|
262
|
+
mock_cache_config.return_value = mock_config
|
|
263
|
+
|
|
264
|
+
completion = await retriever.get_completion("test query", context=[mock_edge])
|
|
265
|
+
|
|
266
|
+
assert isinstance(completion, list)
|
|
267
|
+
assert len(completion) == 1
|
|
268
|
+
assert completion[0] == "Generated answer"
|
|
269
|
+
|
|
270
|
+
|
|
271
|
+
@pytest.mark.asyncio
|
|
272
|
+
async def test_get_completion_with_session(mock_edge):
|
|
273
|
+
"""Test get_completion with session caching enabled."""
|
|
274
|
+
mock_graph_engine = AsyncMock()
|
|
275
|
+
mock_graph_engine.is_empty = AsyncMock(return_value=False)
|
|
276
|
+
|
|
277
|
+
retriever = GraphCompletionRetriever()
|
|
278
|
+
|
|
279
|
+
mock_user = MagicMock()
|
|
280
|
+
mock_user.id = "test-user-id"
|
|
281
|
+
|
|
282
|
+
with (
|
|
283
|
+
patch(
|
|
284
|
+
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
|
|
285
|
+
return_value=mock_graph_engine,
|
|
286
|
+
),
|
|
287
|
+
patch(
|
|
288
|
+
"cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
|
|
289
|
+
return_value=[mock_edge],
|
|
290
|
+
),
|
|
291
|
+
patch(
|
|
292
|
+
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
|
|
293
|
+
return_value="Resolved context",
|
|
294
|
+
),
|
|
295
|
+
patch(
|
|
296
|
+
"cognee.modules.retrieval.graph_completion_retriever.get_conversation_history",
|
|
297
|
+
return_value="Previous conversation",
|
|
298
|
+
),
|
|
299
|
+
patch(
|
|
300
|
+
"cognee.modules.retrieval.graph_completion_retriever.summarize_text",
|
|
301
|
+
return_value="Context summary",
|
|
302
|
+
),
|
|
303
|
+
patch(
|
|
304
|
+
"cognee.modules.retrieval.graph_completion_retriever.generate_completion",
|
|
305
|
+
return_value="Generated answer",
|
|
306
|
+
),
|
|
307
|
+
patch(
|
|
308
|
+
"cognee.modules.retrieval.graph_completion_retriever.save_conversation_history",
|
|
309
|
+
) as mock_save,
|
|
310
|
+
patch(
|
|
311
|
+
"cognee.modules.retrieval.graph_completion_retriever.CacheConfig"
|
|
312
|
+
) as mock_cache_config,
|
|
313
|
+
patch(
|
|
314
|
+
"cognee.modules.retrieval.graph_completion_retriever.session_user"
|
|
315
|
+
) as mock_session_user,
|
|
316
|
+
):
|
|
317
|
+
mock_config = MagicMock()
|
|
318
|
+
mock_config.caching = True
|
|
319
|
+
mock_cache_config.return_value = mock_config
|
|
320
|
+
mock_session_user.get.return_value = mock_user
|
|
321
|
+
|
|
322
|
+
completion = await retriever.get_completion("test query", session_id="test_session")
|
|
323
|
+
|
|
324
|
+
assert isinstance(completion, list)
|
|
325
|
+
assert len(completion) == 1
|
|
326
|
+
assert completion[0] == "Generated answer"
|
|
327
|
+
mock_save.assert_awaited_once()
|
|
328
|
+
|
|
329
|
+
|
|
330
|
+
@pytest.mark.asyncio
|
|
331
|
+
async def test_get_completion_with_response_model(mock_edge):
|
|
332
|
+
"""Test get_completion with custom response model."""
|
|
333
|
+
from pydantic import BaseModel
|
|
334
|
+
|
|
335
|
+
class TestModel(BaseModel):
|
|
336
|
+
answer: str
|
|
337
|
+
|
|
338
|
+
mock_graph_engine = AsyncMock()
|
|
339
|
+
mock_graph_engine.is_empty = AsyncMock(return_value=False)
|
|
340
|
+
|
|
341
|
+
retriever = GraphCompletionRetriever()
|
|
342
|
+
|
|
343
|
+
with (
|
|
344
|
+
patch(
|
|
345
|
+
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
|
|
346
|
+
return_value=mock_graph_engine,
|
|
347
|
+
),
|
|
348
|
+
patch(
|
|
349
|
+
"cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
|
|
350
|
+
return_value=[mock_edge],
|
|
351
|
+
),
|
|
352
|
+
patch(
|
|
353
|
+
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
|
|
354
|
+
return_value="Resolved context",
|
|
355
|
+
),
|
|
356
|
+
patch(
|
|
357
|
+
"cognee.modules.retrieval.graph_completion_retriever.generate_completion",
|
|
358
|
+
return_value=TestModel(answer="Test answer"),
|
|
359
|
+
),
|
|
360
|
+
patch(
|
|
361
|
+
"cognee.modules.retrieval.graph_completion_retriever.CacheConfig"
|
|
362
|
+
) as mock_cache_config,
|
|
363
|
+
):
|
|
364
|
+
mock_config = MagicMock()
|
|
365
|
+
mock_config.caching = False
|
|
366
|
+
mock_cache_config.return_value = mock_config
|
|
367
|
+
|
|
368
|
+
completion = await retriever.get_completion("test query", response_model=TestModel)
|
|
369
|
+
|
|
370
|
+
assert isinstance(completion, list)
|
|
371
|
+
assert len(completion) == 1
|
|
372
|
+
assert isinstance(completion[0], TestModel)
|
|
373
|
+
|
|
374
|
+
|
|
375
|
+
@pytest.mark.asyncio
|
|
376
|
+
async def test_get_completion_empty_context(mock_edge):
|
|
377
|
+
"""Test get_completion with empty context."""
|
|
378
|
+
mock_graph_engine = AsyncMock()
|
|
379
|
+
mock_graph_engine.is_empty = AsyncMock(return_value=False)
|
|
380
|
+
|
|
381
|
+
retriever = GraphCompletionRetriever()
|
|
382
|
+
|
|
383
|
+
with (
|
|
384
|
+
patch(
|
|
385
|
+
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
|
|
386
|
+
return_value=mock_graph_engine,
|
|
387
|
+
),
|
|
388
|
+
patch(
|
|
389
|
+
"cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
|
|
390
|
+
return_value=[],
|
|
391
|
+
),
|
|
392
|
+
patch(
|
|
393
|
+
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
|
|
394
|
+
return_value="",
|
|
395
|
+
),
|
|
396
|
+
patch(
|
|
397
|
+
"cognee.modules.retrieval.graph_completion_retriever.generate_completion",
|
|
398
|
+
return_value="Generated answer",
|
|
399
|
+
),
|
|
400
|
+
patch(
|
|
401
|
+
"cognee.modules.retrieval.graph_completion_retriever.CacheConfig"
|
|
402
|
+
) as mock_cache_config,
|
|
403
|
+
):
|
|
404
|
+
mock_config = MagicMock()
|
|
405
|
+
mock_config.caching = False
|
|
406
|
+
mock_cache_config.return_value = mock_config
|
|
407
|
+
|
|
408
|
+
completion = await retriever.get_completion("test query")
|
|
409
|
+
|
|
410
|
+
assert isinstance(completion, list)
|
|
411
|
+
assert len(completion) == 1
|
|
412
|
+
|
|
413
|
+
|
|
414
|
+
@pytest.mark.asyncio
|
|
415
|
+
async def test_save_qa(mock_edge):
|
|
416
|
+
"""Test save_qa method."""
|
|
417
|
+
mock_graph_engine = AsyncMock()
|
|
418
|
+
mock_graph_engine.add_edges = AsyncMock()
|
|
419
|
+
|
|
420
|
+
retriever = GraphCompletionRetriever()
|
|
421
|
+
|
|
422
|
+
mock_node1 = MagicMock()
|
|
423
|
+
mock_node2 = MagicMock()
|
|
424
|
+
mock_edge.node1 = mock_node1
|
|
425
|
+
mock_edge.node2 = mock_node2
|
|
426
|
+
|
|
427
|
+
with (
|
|
428
|
+
patch(
|
|
429
|
+
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
|
|
430
|
+
return_value=mock_graph_engine,
|
|
431
|
+
),
|
|
432
|
+
patch(
|
|
433
|
+
"cognee.modules.retrieval.graph_completion_retriever.extract_uuid_from_node",
|
|
434
|
+
side_effect=["uuid1", "uuid2"],
|
|
435
|
+
),
|
|
436
|
+
patch(
|
|
437
|
+
"cognee.modules.retrieval.graph_completion_retriever.add_data_points",
|
|
438
|
+
) as mock_add_data,
|
|
439
|
+
):
|
|
440
|
+
await retriever.save_qa(
|
|
441
|
+
question="Test question",
|
|
442
|
+
answer="Test answer",
|
|
443
|
+
context="Test context",
|
|
444
|
+
triplets=[mock_edge],
|
|
207
445
|
)
|
|
208
|
-
cognee.config.system_root_directory(system_directory_path)
|
|
209
|
-
data_directory_path = os.path.join(
|
|
210
|
-
pathlib.Path(__file__).parent,
|
|
211
|
-
".data_storage/test_get_graph_completion_context_on_empty_graph",
|
|
212
|
-
)
|
|
213
|
-
cognee.config.data_root_directory(data_directory_path)
|
|
214
|
-
|
|
215
|
-
await cognee.prune.prune_data()
|
|
216
|
-
await cognee.prune.prune_system(metadata=True)
|
|
217
446
|
|
|
218
|
-
|
|
447
|
+
mock_add_data.assert_awaited_once()
|
|
448
|
+
mock_graph_engine.add_edges.assert_awaited_once()
|
|
449
|
+
|
|
450
|
+
|
|
451
|
+
@pytest.mark.asyncio
|
|
452
|
+
async def test_save_qa_no_triplet_ids(mock_edge):
|
|
453
|
+
"""Test save_qa when triplets have no extractable IDs."""
|
|
454
|
+
mock_graph_engine = AsyncMock()
|
|
455
|
+
mock_graph_engine.add_edges = AsyncMock()
|
|
456
|
+
|
|
457
|
+
retriever = GraphCompletionRetriever()
|
|
458
|
+
|
|
459
|
+
mock_node1 = MagicMock()
|
|
460
|
+
mock_node2 = MagicMock()
|
|
461
|
+
mock_edge.node1 = mock_node1
|
|
462
|
+
mock_edge.node2 = mock_node2
|
|
463
|
+
|
|
464
|
+
with (
|
|
465
|
+
patch(
|
|
466
|
+
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
|
|
467
|
+
return_value=mock_graph_engine,
|
|
468
|
+
),
|
|
469
|
+
patch(
|
|
470
|
+
"cognee.modules.retrieval.graph_completion_retriever.extract_uuid_from_node",
|
|
471
|
+
return_value=None,
|
|
472
|
+
),
|
|
473
|
+
patch(
|
|
474
|
+
"cognee.modules.retrieval.graph_completion_retriever.add_data_points",
|
|
475
|
+
) as mock_add_data,
|
|
476
|
+
):
|
|
477
|
+
await retriever.save_qa(
|
|
478
|
+
question="Test question",
|
|
479
|
+
answer="Test answer",
|
|
480
|
+
context="Test context",
|
|
481
|
+
triplets=[mock_edge],
|
|
482
|
+
)
|
|
219
483
|
|
|
220
|
-
|
|
484
|
+
mock_add_data.assert_awaited_once()
|
|
485
|
+
mock_graph_engine.add_edges.assert_not_called()
|
|
486
|
+
|
|
487
|
+
|
|
488
|
+
@pytest.mark.asyncio
|
|
489
|
+
async def test_save_qa_empty_triplets():
|
|
490
|
+
"""Test save_qa with empty triplets list."""
|
|
491
|
+
mock_graph_engine = AsyncMock()
|
|
492
|
+
mock_graph_engine.add_edges = AsyncMock()
|
|
493
|
+
|
|
494
|
+
retriever = GraphCompletionRetriever()
|
|
495
|
+
|
|
496
|
+
with (
|
|
497
|
+
patch(
|
|
498
|
+
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
|
|
499
|
+
return_value=mock_graph_engine,
|
|
500
|
+
),
|
|
501
|
+
patch(
|
|
502
|
+
"cognee.modules.retrieval.graph_completion_retriever.add_data_points",
|
|
503
|
+
) as mock_add_data,
|
|
504
|
+
):
|
|
505
|
+
await retriever.save_qa(
|
|
506
|
+
question="Test question",
|
|
507
|
+
answer="Test answer",
|
|
508
|
+
context="Test context",
|
|
509
|
+
triplets=[],
|
|
510
|
+
)
|
|
221
511
|
|
|
222
|
-
|
|
223
|
-
|
|
512
|
+
mock_add_data.assert_awaited_once()
|
|
513
|
+
mock_graph_engine.add_edges.assert_not_called()
|
|
514
|
+
|
|
515
|
+
|
|
516
|
+
@pytest.mark.asyncio
async def test_get_completion_with_save_interaction_no_completion(mock_edge):
    """Run get_completion with save_interaction enabled while generation yields None.

    ``generate_completion`` is stubbed to return ``None``; the result is
    expected to be a one-element list whose only item is ``None``.
    """
    graph_engine = AsyncMock(is_empty=AsyncMock(return_value=False))
    interactive_retriever = GraphCompletionRetriever(save_interaction=True)

    # Stand-in for the CacheConfig() instance, with caching switched off.
    cache_settings = MagicMock(caching=False)

    with (
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
            return_value=graph_engine,
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
            return_value=[mock_edge],
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
            return_value="Resolved context",
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.generate_completion",
            return_value=None,  # No completion
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.CacheConfig",
            return_value=cache_settings,
        ),
    ):
        result = await interactive_retriever.get_completion("test query")

    assert isinstance(result, list)
    assert len(result) == 1
    assert result[0] is None
|
|
554
|
+
|
|
555
|
+
|
|
556
|
+
@pytest.mark.asyncio
async def test_get_completion_with_save_interaction_no_context(mock_edge):
    """Run get_completion with save_interaction enabled and ``context=None``.

    Triplet search, edge resolution, completion generation and CacheConfig
    are all stubbed. The result must be a one-element list holding the
    stubbed ``generate_completion`` value.
    """
    mock_graph_engine = AsyncMock()
    mock_graph_engine.is_empty = AsyncMock(return_value=False)

    retriever = GraphCompletionRetriever(save_interaction=True)

    with (
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
            return_value=mock_graph_engine,
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
            return_value=[mock_edge],
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
            return_value="Resolved context",
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.generate_completion",
            return_value="Generated answer",
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.CacheConfig"
        ) as mock_cache_config,
    ):
        # CacheConfig() must hand back a config object with caching disabled.
        mock_config = MagicMock()
        mock_config.caching = False
        mock_cache_config.return_value = mock_config

        completion = await retriever.get_completion("test query", context=None)

    assert isinstance(completion, list)
    assert len(completion) == 1
    # Pin the payload as well, matching the sibling save_interaction tests;
    # previously only the list shape was checked.
    assert completion[0] == "Generated answer"
|
|
593
|
+
|
|
594
|
+
|
|
595
|
+
@pytest.mark.asyncio
async def test_get_completion_with_save_interaction_all_conditions_met(mock_edge):
    """Run get_completion with save_interaction enabled and an explicit context.

    The edge carries two mock nodes and ``extract_uuid_from_node`` yields a
    UUID for each. Expects the stubbed completion to be returned in a
    one-element list and ``add_data_points`` to be awaited exactly once.
    """
    graph_engine = AsyncMock(is_empty=AsyncMock(return_value=False))
    interactive_retriever = GraphCompletionRetriever(save_interaction=True)

    # Give the fixture edge two distinct endpoint nodes.
    mock_edge.node1 = MagicMock()
    mock_edge.node2 = MagicMock()

    # Stand-in for the CacheConfig() instance, with caching switched off.
    cache_settings = MagicMock(caching=False)

    with (
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
            return_value=graph_engine,
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
            return_value=[mock_edge],
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
            return_value="Resolved context",
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.generate_completion",
            return_value="Generated answer",
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.extract_uuid_from_node",
            side_effect=[
                UUID("550e8400-e29b-41d4-a716-446655440000"),
                UUID("550e8400-e29b-41d4-a716-446655440001"),
            ],
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.add_data_points",
        ) as add_points_mock,
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.CacheConfig",
            return_value=cache_settings,
        ),
    ):
        result = await interactive_retriever.get_completion("test query", context=[mock_edge])

    assert isinstance(result, list)
    assert len(result) == 1
    assert result[0] == "Generated answer"
    add_points_mock.assert_awaited_once()
|