cognee 0.5.1__py3-none-any.whl → 0.5.1.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cognee/api/v1/add/add.py +2 -1
- cognee/api/v1/datasets/routers/get_datasets_router.py +1 -0
- cognee/api/v1/memify/routers/get_memify_router.py +1 -0
- cognee/api/v1/search/search.py +0 -4
- cognee/infrastructure/databases/relational/config.py +16 -1
- cognee/infrastructure/databases/relational/create_relational_engine.py +13 -3
- cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py +24 -2
- cognee/infrastructure/databases/vector/create_vector_engine.py +9 -2
- cognee/infrastructure/llm/LLMGateway.py +0 -13
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py +17 -12
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py +31 -25
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py +132 -7
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py +5 -5
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/llm_interface.py +2 -6
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py +58 -13
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/adapter.py +0 -1
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py +25 -131
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/types.py +10 -0
- cognee/modules/data/models/Data.py +2 -1
- cognee/modules/retrieval/triplet_retriever.py +1 -1
- cognee/modules/retrieval/utils/brute_force_triplet_search.py +0 -18
- cognee/modules/search/methods/search.py +18 -25
- cognee/tasks/ingestion/data_item.py +8 -0
- cognee/tasks/ingestion/ingest_data.py +12 -1
- cognee/tasks/ingestion/save_data_item_to_storage.py +5 -0
- cognee/tests/integration/retrieval/test_chunks_retriever.py +252 -0
- cognee/tests/integration/retrieval/test_graph_completion_retriever.py +268 -0
- cognee/tests/integration/retrieval/test_graph_completion_retriever_context_extension.py +226 -0
- cognee/tests/integration/retrieval/test_graph_completion_retriever_cot.py +218 -0
- cognee/tests/integration/retrieval/test_rag_completion_retriever.py +254 -0
- cognee/tests/{unit/modules/retrieval/structured_output_test.py → integration/retrieval/test_structured_output.py} +87 -77
- cognee/tests/integration/retrieval/test_summaries_retriever.py +184 -0
- cognee/tests/integration/retrieval/test_temporal_retriever.py +306 -0
- cognee/tests/integration/retrieval/test_triplet_retriever.py +35 -0
- cognee/tests/test_custom_data_label.py +68 -0
- cognee/tests/test_search_db.py +334 -181
- cognee/tests/unit/eval_framework/benchmark_adapters_test.py +25 -0
- cognee/tests/unit/eval_framework/corpus_builder_test.py +33 -4
- cognee/tests/unit/infrastructure/databases/relational/test_RelationalConfig.py +69 -0
- cognee/tests/unit/modules/retrieval/chunks_retriever_test.py +181 -199
- cognee/tests/unit/modules/retrieval/conversation_history_test.py +338 -0
- cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py +454 -162
- cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py +674 -156
- cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py +625 -200
- cognee/tests/unit/modules/retrieval/rag_completion_retriever_test.py +319 -203
- cognee/tests/unit/modules/retrieval/summaries_retriever_test.py +189 -155
- cognee/tests/unit/modules/retrieval/temporal_retriever_test.py +539 -58
- cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py +218 -9
- cognee/tests/unit/modules/retrieval/test_completion.py +343 -0
- cognee/tests/unit/modules/retrieval/test_graph_summary_completion_retriever.py +157 -0
- cognee/tests/unit/modules/retrieval/test_user_qa_feedback.py +312 -0
- cognee/tests/unit/modules/retrieval/triplet_retriever_test.py +246 -0
- {cognee-0.5.1.dist-info → cognee-0.5.1.dev0.dist-info}/METADATA +1 -1
- {cognee-0.5.1.dist-info → cognee-0.5.1.dev0.dist-info}/RECORD +58 -45
- cognee/tests/unit/modules/search/test_search.py +0 -100
- {cognee-0.5.1.dist-info → cognee-0.5.1.dev0.dist-info}/WHEEL +0 -0
- {cognee-0.5.1.dist-info → cognee-0.5.1.dev0.dist-info}/entry_points.txt +0 -0
- {cognee-0.5.1.dist-info → cognee-0.5.1.dev0.dist-info}/licenses/LICENSE +0 -0
- {cognee-0.5.1.dist-info → cognee-0.5.1.dev0.dist-info}/licenses/NOTICE.md +0 -0
cognee/tests/test_search_db.py
CHANGED
|
@@ -1,5 +1,10 @@
|
|
|
1
1
|
import pathlib
|
|
2
2
|
import os
|
|
3
|
+
import asyncio
|
|
4
|
+
import pytest
|
|
5
|
+
import pytest_asyncio
|
|
6
|
+
from collections import Counter
|
|
7
|
+
|
|
3
8
|
import cognee
|
|
4
9
|
from cognee.infrastructure.databases.graph import get_graph_engine
|
|
5
10
|
from cognee.infrastructure.databases.vector import get_vector_engine
|
|
@@ -13,127 +18,172 @@ from cognee.modules.retrieval.graph_completion_cot_retriever import GraphComplet
|
|
|
13
18
|
from cognee.modules.retrieval.graph_summary_completion_retriever import (
|
|
14
19
|
GraphSummaryCompletionRetriever,
|
|
15
20
|
)
|
|
21
|
+
from cognee.modules.retrieval.chunks_retriever import ChunksRetriever
|
|
22
|
+
from cognee.modules.retrieval.summaries_retriever import SummariesRetriever
|
|
23
|
+
from cognee.modules.retrieval.completion_retriever import CompletionRetriever
|
|
24
|
+
from cognee.modules.retrieval.temporal_retriever import TemporalRetriever
|
|
16
25
|
from cognee.modules.retrieval.triplet_retriever import TripletRetriever
|
|
17
26
|
from cognee.shared.logging_utils import get_logger
|
|
18
27
|
from cognee.modules.search.types import SearchType
|
|
19
28
|
from cognee.modules.users.methods import get_default_user
|
|
20
|
-
from collections import Counter
|
|
21
29
|
|
|
22
30
|
logger = get_logger()
|
|
23
31
|
|
|
24
32
|
|
|
25
|
-
async def
|
|
26
|
-
|
|
33
|
+
async def _reset_engines_and_prune() -> None:
|
|
34
|
+
"""Reset db engine caches and prune data/system.
|
|
35
|
+
|
|
36
|
+
Kept intentionally identical to the inlined setup logic to avoid event loop issues when
|
|
37
|
+
using deployed databases (Neo4j, PostgreSQL) and to ensure fresh instances per run.
|
|
38
|
+
"""
|
|
39
|
+
# Dispose of existing engines and clear caches to ensure fresh instances for each test
|
|
40
|
+
try:
|
|
41
|
+
from cognee.infrastructure.databases.vector import get_vector_engine
|
|
42
|
+
|
|
43
|
+
vector_engine = get_vector_engine()
|
|
44
|
+
# Dispose SQLAlchemy engine connection pool if it exists
|
|
45
|
+
if hasattr(vector_engine, "engine") and hasattr(vector_engine.engine, "dispose"):
|
|
46
|
+
await vector_engine.engine.dispose(close=True)
|
|
47
|
+
except Exception:
|
|
48
|
+
# Engine might not exist yet
|
|
49
|
+
pass
|
|
50
|
+
|
|
51
|
+
from cognee.infrastructure.databases.graph.get_graph_engine import create_graph_engine
|
|
52
|
+
from cognee.infrastructure.databases.vector.create_vector_engine import create_vector_engine
|
|
53
|
+
from cognee.infrastructure.databases.relational.create_relational_engine import (
|
|
54
|
+
create_relational_engine,
|
|
55
|
+
)
|
|
56
|
+
|
|
57
|
+
create_graph_engine.cache_clear()
|
|
58
|
+
create_vector_engine.cache_clear()
|
|
59
|
+
create_relational_engine.cache_clear()
|
|
60
|
+
|
|
27
61
|
await cognee.prune.prune_data()
|
|
28
62
|
await cognee.prune.prune_system(metadata=True)
|
|
29
63
|
|
|
30
|
-
dataset_name = "test_dataset"
|
|
31
64
|
|
|
65
|
+
async def _seed_default_dataset(dataset_name: str) -> dict:
|
|
66
|
+
"""Add the shared test dataset contents and run cognify (same steps/order as before)."""
|
|
32
67
|
text_1 = """Germany is located in europe right next to the Netherlands"""
|
|
68
|
+
|
|
69
|
+
logger.info(f"Adding text data to dataset: {dataset_name}")
|
|
33
70
|
await cognee.add(text_1, dataset_name)
|
|
34
71
|
|
|
35
72
|
explanation_file_path_quantum = os.path.join(
|
|
36
73
|
pathlib.Path(__file__).parent, "test_data/Quantum_computers.txt"
|
|
37
74
|
)
|
|
38
75
|
|
|
76
|
+
logger.info(f"Adding file data to dataset: {dataset_name}")
|
|
39
77
|
await cognee.add([explanation_file_path_quantum], dataset_name)
|
|
40
78
|
|
|
79
|
+
logger.info(f"Running cognify on dataset: {dataset_name}")
|
|
41
80
|
await cognee.cognify([dataset_name])
|
|
42
81
|
|
|
82
|
+
return {
|
|
83
|
+
"dataset_name": dataset_name,
|
|
84
|
+
"text_1": text_1,
|
|
85
|
+
"explanation_file_path_quantum": explanation_file_path_quantum,
|
|
86
|
+
}
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
@pytest.fixture(scope="session")
|
|
90
|
+
def event_loop():
|
|
91
|
+
"""Use a single asyncio event loop for this test module.
|
|
92
|
+
|
|
93
|
+
This helps avoid "Future attached to a different loop" when running multiple async
|
|
94
|
+
tests that share clients/engines.
|
|
95
|
+
"""
|
|
96
|
+
loop = asyncio.new_event_loop()
|
|
97
|
+
try:
|
|
98
|
+
yield loop
|
|
99
|
+
finally:
|
|
100
|
+
loop.close()
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
async def setup_test_environment():
|
|
104
|
+
"""Helper function to set up test environment with data, cognify, and triplet embeddings."""
|
|
105
|
+
# This test runs for multiple db settings, to run this locally set the corresponding db envs
|
|
106
|
+
|
|
107
|
+
dataset_name = "test_dataset"
|
|
108
|
+
logger.info("Starting test setup: pruning data and system")
|
|
109
|
+
await _reset_engines_and_prune()
|
|
110
|
+
state = await _seed_default_dataset(dataset_name=dataset_name)
|
|
111
|
+
|
|
43
112
|
user = await get_default_user()
|
|
44
113
|
from cognee.memify_pipelines.create_triplet_embeddings import create_triplet_embeddings
|
|
45
114
|
|
|
115
|
+
logger.info("Creating triplet embeddings")
|
|
46
116
|
await create_triplet_embeddings(user=user, dataset=dataset_name, triplets_batch_size=5)
|
|
47
117
|
|
|
48
|
-
|
|
49
|
-
nodes, edges = await graph_engine.get_graph_data()
|
|
50
|
-
|
|
118
|
+
# Check if Triplet_text collection was created
|
|
51
119
|
vector_engine = get_vector_engine()
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
)
|
|
120
|
+
has_collection = await vector_engine.has_collection(collection_name="Triplet_text")
|
|
121
|
+
logger.info(f"Triplet_text collection exists after creation: {has_collection}")
|
|
55
122
|
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
123
|
+
if has_collection:
|
|
124
|
+
collection = await vector_engine.get_collection("Triplet_text")
|
|
125
|
+
count = await collection.count_rows() if hasattr(collection, "count_rows") else "unknown"
|
|
126
|
+
logger.info(f"Triplet_text collection row count: {count}")
|
|
59
127
|
|
|
60
|
-
|
|
61
|
-
query="Next to which country is Germany located?"
|
|
62
|
-
)
|
|
63
|
-
context_gk_cot = await GraphCompletionCotRetriever().get_context(
|
|
64
|
-
query="Next to which country is Germany located?"
|
|
65
|
-
)
|
|
66
|
-
context_gk_ext = await GraphCompletionContextExtensionRetriever().get_context(
|
|
67
|
-
query="Next to which country is Germany located?"
|
|
68
|
-
)
|
|
69
|
-
context_gk_sum = await GraphSummaryCompletionRetriever().get_context(
|
|
70
|
-
query="Next to which country is Germany located?"
|
|
71
|
-
)
|
|
72
|
-
context_triplet = await TripletRetriever().get_context(
|
|
73
|
-
query="Next to which country is Germany located?"
|
|
74
|
-
)
|
|
128
|
+
return state
|
|
75
129
|
|
|
76
|
-
for name, context in [
|
|
77
|
-
("GraphCompletionRetriever", context_gk),
|
|
78
|
-
("GraphCompletionCotRetriever", context_gk_cot),
|
|
79
|
-
("GraphCompletionContextExtensionRetriever", context_gk_ext),
|
|
80
|
-
("GraphSummaryCompletionRetriever", context_gk_sum),
|
|
81
|
-
]:
|
|
82
|
-
assert isinstance(context, list), f"{name}: Context should be a list"
|
|
83
|
-
assert len(context) > 0, f"{name}: Context should not be empty"
|
|
84
130
|
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
131
|
+
async def setup_test_environment_for_feedback():
|
|
132
|
+
"""Helper function to set up test environment for feedback weight calculation test."""
|
|
133
|
+
dataset_name = "test_dataset"
|
|
134
|
+
await _reset_engines_and_prune()
|
|
135
|
+
return await _seed_default_dataset(dataset_name=dataset_name)
|
|
90
136
|
|
|
91
|
-
assert isinstance(context_triplet, str), "TripletRetriever: Context should be a string"
|
|
92
|
-
assert len(context_triplet) > 0, "TripletRetriever: Context should not be empty"
|
|
93
|
-
lower_triplet = context_triplet.lower()
|
|
94
|
-
assert "germany" in lower_triplet or "netherlands" in lower_triplet, (
|
|
95
|
-
f"TripletRetriever: Context did not contain 'germany' or 'netherlands'; got: {context_triplet!r}"
|
|
96
|
-
)
|
|
97
137
|
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
triplets_gk_cot = await GraphCompletionCotRetriever().get_triplets(
|
|
102
|
-
query="Next to which country is Germany located?"
|
|
103
|
-
)
|
|
104
|
-
triplets_gk_ext = await GraphCompletionContextExtensionRetriever().get_triplets(
|
|
105
|
-
query="Next to which country is Germany located?"
|
|
106
|
-
)
|
|
107
|
-
triplets_gk_sum = await GraphSummaryCompletionRetriever().get_triplets(
|
|
108
|
-
query="Next to which country is Germany located?"
|
|
109
|
-
)
|
|
138
|
+
@pytest_asyncio.fixture(scope="session")
|
|
139
|
+
async def e2e_state():
|
|
140
|
+
"""Compute E2E artifacts once; tests only assert.
|
|
110
141
|
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
assert triplets, f"{name}: Triplets list should not be empty"
|
|
119
|
-
for edge in triplets:
|
|
120
|
-
assert isinstance(edge, Edge), f"{name}: Elements should be Edge instances"
|
|
121
|
-
distance = edge.attributes.get("vector_distance")
|
|
122
|
-
node1_distance = edge.node1.attributes.get("vector_distance")
|
|
123
|
-
node2_distance = edge.node2.attributes.get("vector_distance")
|
|
124
|
-
assert isinstance(distance, float), (
|
|
125
|
-
f"{name}: vector_distance should be float, got {type(distance)}"
|
|
126
|
-
)
|
|
127
|
-
assert 0 <= distance <= 1, (
|
|
128
|
-
f"{name}: edge vector_distance {distance} out of [0,1], this shouldn't happen"
|
|
129
|
-
)
|
|
130
|
-
assert 0 <= node1_distance <= 1, (
|
|
131
|
-
f"{name}: node_1 vector_distance {distance} out of [0,1], this shouldn't happen"
|
|
132
|
-
)
|
|
133
|
-
assert 0 <= node2_distance <= 1, (
|
|
134
|
-
f"{name}: node_2 vector_distance {distance} out of [0,1], this shouldn't happen"
|
|
135
|
-
)
|
|
142
|
+
This avoids repeating expensive setup and LLM calls across multiple tests.
|
|
143
|
+
"""
|
|
144
|
+
await setup_test_environment()
|
|
145
|
+
|
|
146
|
+
# --- Graph/vector engine consistency ---
|
|
147
|
+
graph_engine = await get_graph_engine()
|
|
148
|
+
_nodes, edges = await graph_engine.get_graph_data()
|
|
136
149
|
|
|
150
|
+
vector_engine = get_vector_engine()
|
|
151
|
+
collection = await vector_engine.search(
|
|
152
|
+
collection_name="Triplet_text", query_text="Test", limit=None
|
|
153
|
+
)
|
|
154
|
+
|
|
155
|
+
# --- Retriever contexts ---
|
|
156
|
+
query = "Next to which country is Germany located?"
|
|
157
|
+
|
|
158
|
+
contexts = {
|
|
159
|
+
"graph_completion": await GraphCompletionRetriever().get_context(query=query),
|
|
160
|
+
"graph_completion_cot": await GraphCompletionCotRetriever().get_context(query=query),
|
|
161
|
+
"graph_completion_context_extension": await GraphCompletionContextExtensionRetriever().get_context(
|
|
162
|
+
query=query
|
|
163
|
+
),
|
|
164
|
+
"graph_summary_completion": await GraphSummaryCompletionRetriever().get_context(
|
|
165
|
+
query=query
|
|
166
|
+
),
|
|
167
|
+
"chunks": await ChunksRetriever(top_k=5).get_context(query=query),
|
|
168
|
+
"summaries": await SummariesRetriever(top_k=5).get_context(query=query),
|
|
169
|
+
"rag_completion": await CompletionRetriever(top_k=3).get_context(query=query),
|
|
170
|
+
"temporal": await TemporalRetriever(top_k=5).get_context(query=query),
|
|
171
|
+
"triplet": await TripletRetriever().get_context(query=query),
|
|
172
|
+
}
|
|
173
|
+
|
|
174
|
+
# --- Retriever triplets + vector distance validation ---
|
|
175
|
+
triplets = {
|
|
176
|
+
"graph_completion": await GraphCompletionRetriever().get_triplets(query=query),
|
|
177
|
+
"graph_completion_cot": await GraphCompletionCotRetriever().get_triplets(query=query),
|
|
178
|
+
"graph_completion_context_extension": await GraphCompletionContextExtensionRetriever().get_triplets(
|
|
179
|
+
query=query
|
|
180
|
+
),
|
|
181
|
+
"graph_summary_completion": await GraphSummaryCompletionRetriever().get_triplets(
|
|
182
|
+
query=query
|
|
183
|
+
),
|
|
184
|
+
}
|
|
185
|
+
|
|
186
|
+
# --- Search operations + graph side effects ---
|
|
137
187
|
completion_gk = await cognee.search(
|
|
138
188
|
query_type=SearchType.GRAPH_COMPLETION,
|
|
139
189
|
query_text="Where is germany located, next to which country?",
|
|
@@ -164,6 +214,26 @@ async def main():
|
|
|
164
214
|
query_text="Next to which country is Germany located?",
|
|
165
215
|
save_interaction=True,
|
|
166
216
|
)
|
|
217
|
+
completion_chunks = await cognee.search(
|
|
218
|
+
query_type=SearchType.CHUNKS,
|
|
219
|
+
query_text="Germany",
|
|
220
|
+
save_interaction=False,
|
|
221
|
+
)
|
|
222
|
+
completion_summaries = await cognee.search(
|
|
223
|
+
query_type=SearchType.SUMMARIES,
|
|
224
|
+
query_text="Germany",
|
|
225
|
+
save_interaction=False,
|
|
226
|
+
)
|
|
227
|
+
completion_rag = await cognee.search(
|
|
228
|
+
query_type=SearchType.RAG_COMPLETION,
|
|
229
|
+
query_text="Next to which country is Germany located?",
|
|
230
|
+
save_interaction=False,
|
|
231
|
+
)
|
|
232
|
+
completion_temporal = await cognee.search(
|
|
233
|
+
query_type=SearchType.TEMPORAL,
|
|
234
|
+
query_text="Next to which country is Germany located?",
|
|
235
|
+
save_interaction=False,
|
|
236
|
+
)
|
|
167
237
|
|
|
168
238
|
await cognee.search(
|
|
169
239
|
query_type=SearchType.FEEDBACK,
|
|
@@ -171,134 +241,217 @@ async def main():
|
|
|
171
241
|
last_k=1,
|
|
172
242
|
)
|
|
173
243
|
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
244
|
+
# Snapshot after all E2E operations above (used by assertion-only tests).
|
|
245
|
+
graph_snapshot = await (await get_graph_engine()).get_graph_data()
|
|
246
|
+
|
|
247
|
+
return {
|
|
248
|
+
"graph_edges": edges,
|
|
249
|
+
"triplet_collection": collection,
|
|
250
|
+
"vector_collection_edges_count": len(collection),
|
|
251
|
+
"graph_edges_count": len(edges),
|
|
252
|
+
"contexts": contexts,
|
|
253
|
+
"triplets": triplets,
|
|
254
|
+
"search_results": {
|
|
255
|
+
"graph_completion": completion_gk,
|
|
256
|
+
"graph_completion_cot": completion_cot,
|
|
257
|
+
"graph_completion_context_extension": completion_ext,
|
|
258
|
+
"graph_summary_completion": completion_sum,
|
|
259
|
+
"triplet_completion": completion_triplet,
|
|
260
|
+
"chunks": completion_chunks,
|
|
261
|
+
"summaries": completion_summaries,
|
|
262
|
+
"rag_completion": completion_rag,
|
|
263
|
+
"temporal": completion_temporal,
|
|
264
|
+
},
|
|
265
|
+
"graph_snapshot": graph_snapshot,
|
|
266
|
+
}
|
|
267
|
+
|
|
268
|
+
|
|
269
|
+
@pytest_asyncio.fixture(scope="session")
|
|
270
|
+
async def feedback_state():
|
|
271
|
+
"""Feedback-weight scenario computed once (fresh environment)."""
|
|
272
|
+
await setup_test_environment_for_feedback()
|
|
273
|
+
|
|
274
|
+
await cognee.search(
|
|
275
|
+
query_type=SearchType.GRAPH_COMPLETION,
|
|
276
|
+
query_text="Next to which country is Germany located?",
|
|
277
|
+
save_interaction=True,
|
|
278
|
+
)
|
|
279
|
+
await cognee.search(
|
|
280
|
+
query_type=SearchType.FEEDBACK,
|
|
281
|
+
query_text="This was the best answer I've ever seen",
|
|
282
|
+
last_k=1,
|
|
283
|
+
)
|
|
284
|
+
await cognee.search(
|
|
285
|
+
query_type=SearchType.FEEDBACK,
|
|
286
|
+
query_text="Wow the correctness of this answer blows my mind",
|
|
287
|
+
last_k=1,
|
|
288
|
+
)
|
|
289
|
+
|
|
290
|
+
graph_engine = await get_graph_engine()
|
|
291
|
+
graph = await graph_engine.get_graph_data()
|
|
292
|
+
return {"graph_snapshot": graph}
|
|
293
|
+
|
|
294
|
+
|
|
295
|
+
@pytest.mark.asyncio
|
|
296
|
+
async def test_e2e_graph_vector_consistency(e2e_state):
|
|
297
|
+
"""Graph and vector stores contain the same triplet edges."""
|
|
298
|
+
assert e2e_state["graph_edges_count"] == e2e_state["vector_collection_edges_count"]
|
|
299
|
+
|
|
300
|
+
|
|
301
|
+
@pytest.mark.asyncio
|
|
302
|
+
async def test_e2e_retriever_contexts(e2e_state):
|
|
303
|
+
"""All retrievers return non-empty, well-typed contexts."""
|
|
304
|
+
contexts = e2e_state["contexts"]
|
|
305
|
+
|
|
306
|
+
for name in [
|
|
307
|
+
"graph_completion",
|
|
308
|
+
"graph_completion_cot",
|
|
309
|
+
"graph_completion_context_extension",
|
|
310
|
+
"graph_summary_completion",
|
|
180
311
|
]:
|
|
181
|
-
|
|
182
|
-
assert
|
|
183
|
-
|
|
312
|
+
ctx = contexts[name]
|
|
313
|
+
assert isinstance(ctx, list), f"{name}: Context should be a list"
|
|
314
|
+
assert ctx, f"{name}: Context should not be empty"
|
|
315
|
+
ctx_text = await resolve_edges_to_text(ctx)
|
|
316
|
+
lower = ctx_text.lower()
|
|
317
|
+
assert "germany" in lower or "netherlands" in lower, (
|
|
318
|
+
f"{name}: Context did not contain 'germany' or 'netherlands'; got: {ctx!r}"
|
|
184
319
|
)
|
|
185
320
|
|
|
186
|
-
|
|
321
|
+
triplet_ctx = contexts["triplet"]
|
|
322
|
+
assert isinstance(triplet_ctx, str), "triplet: Context should be a string"
|
|
323
|
+
assert triplet_ctx.strip(), "triplet: Context should not be empty"
|
|
324
|
+
|
|
325
|
+
chunks_ctx = contexts["chunks"]
|
|
326
|
+
assert isinstance(chunks_ctx, list), "chunks: Context should be a list"
|
|
327
|
+
assert chunks_ctx, "chunks: Context should not be empty"
|
|
328
|
+
chunks_text = "\n".join(str(item.get("text", "")) for item in chunks_ctx).lower()
|
|
329
|
+
assert "germany" in chunks_text or "netherlands" in chunks_text
|
|
330
|
+
|
|
331
|
+
summaries_ctx = contexts["summaries"]
|
|
332
|
+
assert isinstance(summaries_ctx, list), "summaries: Context should be a list"
|
|
333
|
+
assert summaries_ctx, "summaries: Context should not be empty"
|
|
334
|
+
assert any(str(item.get("text", "")).strip() for item in summaries_ctx)
|
|
335
|
+
|
|
336
|
+
rag_ctx = contexts["rag_completion"]
|
|
337
|
+
assert isinstance(rag_ctx, str), "rag_completion: Context should be a string"
|
|
338
|
+
assert rag_ctx.strip(), "rag_completion: Context should not be empty"
|
|
339
|
+
|
|
340
|
+
temporal_ctx = contexts["temporal"]
|
|
341
|
+
assert isinstance(temporal_ctx, str), "temporal: Context should be a string"
|
|
342
|
+
assert temporal_ctx.strip(), "temporal: Context should not be empty"
|
|
343
|
+
|
|
344
|
+
|
|
345
|
+
@pytest.mark.asyncio
|
|
346
|
+
async def test_e2e_retriever_triplets_have_vector_distances(e2e_state):
|
|
347
|
+
"""Graph retriever triplets include sane vector_distance metadata."""
|
|
348
|
+
for name, triplets in e2e_state["triplets"].items():
|
|
349
|
+
assert isinstance(triplets, list), f"{name}: Triplets should be a list"
|
|
350
|
+
assert triplets, f"{name}: Triplets list should not be empty"
|
|
351
|
+
for edge in triplets:
|
|
352
|
+
assert isinstance(edge, Edge), f"{name}: Elements should be Edge instances"
|
|
353
|
+
distance = edge.attributes.get("vector_distance")
|
|
354
|
+
node1_distance = edge.node1.attributes.get("vector_distance")
|
|
355
|
+
node2_distance = edge.node2.attributes.get("vector_distance")
|
|
356
|
+
assert isinstance(distance, float), f"{name}: vector_distance should be float"
|
|
357
|
+
assert 0 <= distance <= 1
|
|
358
|
+
assert 0 <= node1_distance <= 1
|
|
359
|
+
assert 0 <= node2_distance <= 1
|
|
360
|
+
|
|
361
|
+
|
|
362
|
+
@pytest.mark.asyncio
|
|
363
|
+
async def test_e2e_search_results_and_wrappers(e2e_state):
|
|
364
|
+
"""Search returns expected shapes across search types and access modes."""
|
|
365
|
+
from cognee.context_global_variables import backend_access_control_enabled
|
|
366
|
+
|
|
367
|
+
sr = e2e_state["search_results"]
|
|
368
|
+
|
|
369
|
+
# Completion-like search types: validate wrapper + content
|
|
370
|
+
for name in [
|
|
371
|
+
"graph_completion",
|
|
372
|
+
"graph_completion_cot",
|
|
373
|
+
"graph_completion_context_extension",
|
|
374
|
+
"graph_summary_completion",
|
|
375
|
+
"triplet_completion",
|
|
376
|
+
"rag_completion",
|
|
377
|
+
"temporal",
|
|
378
|
+
]:
|
|
379
|
+
search_results = sr[name]
|
|
380
|
+
assert isinstance(search_results, list), f"{name}: should return a list"
|
|
381
|
+
assert len(search_results) == 1, f"{name}: expected single-element list"
|
|
187
382
|
|
|
188
383
|
if backend_access_control_enabled():
|
|
189
|
-
|
|
384
|
+
wrapper = search_results[0]
|
|
385
|
+
assert isinstance(wrapper, dict), (
|
|
386
|
+
f"{name}: expected wrapper dict in access control mode"
|
|
387
|
+
)
|
|
388
|
+
assert wrapper.get("dataset_id"), f"{name}: missing dataset_id in wrapper"
|
|
389
|
+
assert wrapper.get("dataset_name") == "test_dataset"
|
|
390
|
+
assert "graphs" in wrapper
|
|
391
|
+
text = wrapper["search_result"][0]
|
|
190
392
|
else:
|
|
191
393
|
text = search_results[0]
|
|
192
|
-
assert isinstance(text, str), f"{name}: element should be a string"
|
|
193
|
-
assert text.strip(), f"{name}: string should not be empty"
|
|
194
|
-
assert "netherlands" in text.lower(), (
|
|
195
|
-
f"{name}: expected 'netherlands' in result, got: {text!r}"
|
|
196
|
-
)
|
|
197
394
|
|
|
198
|
-
|
|
199
|
-
|
|
395
|
+
assert isinstance(text, str) and text.strip()
|
|
396
|
+
assert "netherlands" in text.lower()
|
|
200
397
|
|
|
201
|
-
|
|
398
|
+
# Non-LLM search types: CHUNKS / SUMMARIES validate payload list + text
|
|
399
|
+
for name in ["chunks", "summaries"]:
|
|
400
|
+
search_results = sr[name]
|
|
401
|
+
assert isinstance(search_results, list), f"{name}: should return a list"
|
|
402
|
+
assert search_results, f"{name}: should not be empty"
|
|
202
403
|
|
|
203
|
-
|
|
404
|
+
first = search_results[0]
|
|
405
|
+
assert isinstance(first, dict), f"{name}: expected dict entries"
|
|
204
406
|
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
)
|
|
407
|
+
payloads = search_results
|
|
408
|
+
if "search_result" in first and "text" not in first:
|
|
409
|
+
payloads = (first.get("search_result") or [None])[0]
|
|
209
410
|
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
)
|
|
214
|
-
|
|
215
|
-
# Assert there is exactly two NodeSet.
|
|
216
|
-
assert type_counts.get("NodeSet", 0) == 2, (
|
|
217
|
-
f"Expected exactly two NodeSet nodes, but found {type_counts.get('NodeSet', 0)}"
|
|
218
|
-
)
|
|
411
|
+
assert isinstance(payloads, list) and payloads
|
|
412
|
+
assert isinstance(payloads[0], dict)
|
|
413
|
+
assert str(payloads[0].get("text", "")).strip()
|
|
219
414
|
|
|
220
|
-
# Assert that there are at least 10 'used_graph_element_to_answer' edges.
|
|
221
|
-
assert edge_type_counts.get("used_graph_element_to_answer", 0) >= 10, (
|
|
222
|
-
f"Expected at least ten 'used_graph_element_to_answer' edges, but found {edge_type_counts.get('used_graph_element_to_answer', 0)}"
|
|
223
|
-
)
|
|
224
415
|
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
|
|
416
|
+
@pytest.mark.asyncio
|
|
417
|
+
async def test_e2e_graph_side_effects_and_node_fields(e2e_state):
|
|
418
|
+
"""Search interactions create expected graph nodes/edges and required fields."""
|
|
419
|
+
graph = e2e_state["graph_snapshot"]
|
|
420
|
+
nodes, edges = graph
|
|
229
421
|
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
f"Expected at least six 'belongs_to_set' edges, but found {edge_type_counts.get('belongs_to_set', 0)}"
|
|
233
|
-
)
|
|
422
|
+
type_counts = Counter(node_data[1].get("type", {}) for node_data in nodes)
|
|
423
|
+
edge_type_counts = Counter(edge_type[2] for edge_type in edges)
|
|
234
424
|
|
|
235
|
-
|
|
425
|
+
assert type_counts.get("CogneeUserInteraction", 0) == 4
|
|
426
|
+
assert type_counts.get("CogneeUserFeedback", 0) == 2
|
|
427
|
+
assert type_counts.get("NodeSet", 0) == 2
|
|
428
|
+
assert edge_type_counts.get("used_graph_element_to_answer", 0) >= 10
|
|
429
|
+
assert edge_type_counts.get("gives_feedback_to", 0) == 2
|
|
430
|
+
assert edge_type_counts.get("belongs_to_set", 0) >= 6
|
|
236
431
|
|
|
237
432
|
required_fields_user_interaction = {"question", "answer", "context"}
|
|
238
433
|
required_fields_feedback = {"feedback", "sentiment"}
|
|
239
434
|
|
|
240
435
|
for node_id, data in nodes:
|
|
241
436
|
if data.get("type") == "CogneeUserInteraction":
|
|
242
|
-
assert required_fields_user_interaction.issubset(data.keys())
|
|
243
|
-
f"Node {node_id} is missing fields: {required_fields_user_interaction - set(data.keys())}"
|
|
244
|
-
)
|
|
245
|
-
|
|
437
|
+
assert required_fields_user_interaction.issubset(data.keys())
|
|
246
438
|
for field in required_fields_user_interaction:
|
|
247
439
|
value = data[field]
|
|
248
|
-
assert isinstance(value, str) and value.strip()
|
|
249
|
-
f"Node {node_id} has invalid value for '{field}': {value!r}"
|
|
250
|
-
)
|
|
440
|
+
assert isinstance(value, str) and value.strip()
|
|
251
441
|
|
|
252
442
|
if data.get("type") == "CogneeUserFeedback":
|
|
253
|
-
assert required_fields_feedback.issubset(data.keys())
|
|
254
|
-
f"Node {node_id} is missing fields: {required_fields_feedback - set(data.keys())}"
|
|
255
|
-
)
|
|
256
|
-
|
|
443
|
+
assert required_fields_feedback.issubset(data.keys())
|
|
257
444
|
for field in required_fields_feedback:
|
|
258
445
|
value = data[field]
|
|
259
|
-
assert isinstance(value, str) and value.strip()
|
|
260
|
-
f"Node {node_id} has invalid value for '{field}': {value!r}"
|
|
261
|
-
)
|
|
446
|
+
assert isinstance(value, str) and value.strip()
|
|
262
447
|
|
|
263
|
-
await cognee.prune.prune_data()
|
|
264
|
-
await cognee.prune.prune_system(metadata=True)
|
|
265
|
-
|
|
266
|
-
await cognee.add(text_1, dataset_name)
|
|
267
|
-
|
|
268
|
-
await cognee.add([text], dataset_name)
|
|
269
|
-
|
|
270
|
-
await cognee.cognify([dataset_name])
|
|
271
|
-
|
|
272
|
-
await cognee.search(
|
|
273
|
-
query_type=SearchType.GRAPH_COMPLETION,
|
|
274
|
-
query_text="Next to which country is Germany located?",
|
|
275
|
-
save_interaction=True,
|
|
276
|
-
)
|
|
277
448
|
|
|
278
|
-
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
|
|
284
|
-
await cognee.search(
|
|
285
|
-
query_type=SearchType.FEEDBACK,
|
|
286
|
-
query_text="Wow the correctness of this answer blows my mind",
|
|
287
|
-
last_k=1,
|
|
288
|
-
)
|
|
289
|
-
|
|
290
|
-
graph = await graph_engine.get_graph_data()
|
|
291
|
-
|
|
292
|
-
edges = graph[1]
|
|
293
|
-
|
|
294
|
-
for from_node, to_node, relationship_name, properties in edges:
|
|
449
|
+
@pytest.mark.asyncio
|
|
450
|
+
async def test_e2e_feedback_weight_calculation(feedback_state):
|
|
451
|
+
"""Positive feedback increases used_graph_element_to_answer feedback_weight."""
|
|
452
|
+
_nodes, edges = feedback_state["graph_snapshot"]
|
|
453
|
+
for _from_node, _to_node, relationship_name, properties in edges:
|
|
295
454
|
if relationship_name == "used_graph_element_to_answer":
|
|
296
455
|
assert properties["feedback_weight"] >= 6, (
|
|
297
456
|
"Feedback weight calculation is not correct, it should be more then 6."
|
|
298
457
|
)
|
|
299
|
-
|
|
300
|
-
|
|
301
|
-
if __name__ == "__main__":
|
|
302
|
-
import asyncio
|
|
303
|
-
|
|
304
|
-
asyncio.run(main())
|
|
@@ -11,6 +11,22 @@ MOCK_JSONL_DATA = """\
|
|
|
11
11
|
{"id": "2", "question": "What is ML?", "answer": "Machine Learning", "paragraphs": [{"paragraph_text": "ML is a subset of AI."}]}
|
|
12
12
|
"""
|
|
13
13
|
|
|
14
|
+
MOCK_HOTPOT_CORPUS = [
|
|
15
|
+
{
|
|
16
|
+
"_id": "1",
|
|
17
|
+
"question": "Next to which country is Germany located?",
|
|
18
|
+
"answer": "Netherlands",
|
|
19
|
+
# HotpotQA uses "level"; TwoWikiMultiHop uses "type".
|
|
20
|
+
"level": "easy",
|
|
21
|
+
"type": "comparison",
|
|
22
|
+
"context": [
|
|
23
|
+
["Germany", ["Germany is in Europe."]],
|
|
24
|
+
["Netherlands", ["The Netherlands borders Germany."]],
|
|
25
|
+
],
|
|
26
|
+
"supporting_facts": [["Netherlands", 0]],
|
|
27
|
+
}
|
|
28
|
+
]
|
|
29
|
+
|
|
14
30
|
|
|
15
31
|
ADAPTER_CLASSES = [
|
|
16
32
|
HotpotQAAdapter,
|
|
@@ -35,6 +51,11 @@ def test_adapter_can_instantiate_and_load(AdapterClass):
|
|
|
35
51
|
adapter = AdapterClass()
|
|
36
52
|
result = adapter.load_corpus()
|
|
37
53
|
|
|
54
|
+
elif AdapterClass in (HotpotQAAdapter, TwoWikiMultihopAdapter):
|
|
55
|
+
with patch.object(AdapterClass, "_get_raw_corpus", return_value=MOCK_HOTPOT_CORPUS):
|
|
56
|
+
adapter = AdapterClass()
|
|
57
|
+
result = adapter.load_corpus()
|
|
58
|
+
|
|
38
59
|
else:
|
|
39
60
|
adapter = AdapterClass()
|
|
40
61
|
result = adapter.load_corpus()
|
|
@@ -64,6 +85,10 @@ def test_adapter_returns_some_content(AdapterClass):
|
|
|
64
85
|
):
|
|
65
86
|
adapter = AdapterClass()
|
|
66
87
|
corpus_list, qa_pairs = adapter.load_corpus(limit=limit)
|
|
88
|
+
elif AdapterClass in (HotpotQAAdapter, TwoWikiMultihopAdapter):
|
|
89
|
+
with patch.object(AdapterClass, "_get_raw_corpus", return_value=MOCK_HOTPOT_CORPUS):
|
|
90
|
+
adapter = AdapterClass()
|
|
91
|
+
corpus_list, qa_pairs = adapter.load_corpus(limit=limit)
|
|
67
92
|
else:
|
|
68
93
|
adapter = AdapterClass()
|
|
69
94
|
corpus_list, qa_pairs = adapter.load_corpus(limit=limit)
|