cognee 0.5.0.dev1__py3-none-any.whl → 0.5.1.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56)
  1. cognee/api/v1/add/add.py +2 -1
  2. cognee/api/v1/datasets/routers/get_datasets_router.py +1 -0
  3. cognee/api/v1/memify/routers/get_memify_router.py +1 -0
  4. cognee/infrastructure/databases/relational/config.py +16 -1
  5. cognee/infrastructure/databases/relational/create_relational_engine.py +13 -3
  6. cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py +24 -2
  7. cognee/infrastructure/databases/vector/create_vector_engine.py +9 -2
  8. cognee/infrastructure/llm/LLMGateway.py +0 -13
  9. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py +17 -12
  10. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py +31 -25
  11. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py +132 -7
  12. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py +5 -5
  13. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/llm_interface.py +2 -6
  14. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py +58 -13
  15. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/adapter.py +0 -1
  16. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py +25 -131
  17. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/types.py +10 -0
  18. cognee/modules/data/models/Data.py +2 -1
  19. cognee/modules/retrieval/triplet_retriever.py +1 -1
  20. cognee/modules/retrieval/utils/brute_force_triplet_search.py +0 -18
  21. cognee/tasks/ingestion/data_item.py +8 -0
  22. cognee/tasks/ingestion/ingest_data.py +12 -1
  23. cognee/tasks/ingestion/save_data_item_to_storage.py +5 -0
  24. cognee/tests/integration/retrieval/test_chunks_retriever.py +252 -0
  25. cognee/tests/integration/retrieval/test_graph_completion_retriever.py +268 -0
  26. cognee/tests/integration/retrieval/test_graph_completion_retriever_context_extension.py +226 -0
  27. cognee/tests/integration/retrieval/test_graph_completion_retriever_cot.py +218 -0
  28. cognee/tests/integration/retrieval/test_rag_completion_retriever.py +254 -0
  29. cognee/tests/{unit/modules/retrieval/structured_output_test.py → integration/retrieval/test_structured_output.py} +87 -77
  30. cognee/tests/integration/retrieval/test_summaries_retriever.py +184 -0
  31. cognee/tests/integration/retrieval/test_temporal_retriever.py +306 -0
  32. cognee/tests/integration/retrieval/test_triplet_retriever.py +35 -0
  33. cognee/tests/test_custom_data_label.py +68 -0
  34. cognee/tests/test_search_db.py +334 -181
  35. cognee/tests/unit/eval_framework/benchmark_adapters_test.py +25 -0
  36. cognee/tests/unit/eval_framework/corpus_builder_test.py +33 -4
  37. cognee/tests/unit/infrastructure/databases/relational/test_RelationalConfig.py +69 -0
  38. cognee/tests/unit/modules/retrieval/chunks_retriever_test.py +181 -199
  39. cognee/tests/unit/modules/retrieval/conversation_history_test.py +338 -0
  40. cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py +454 -162
  41. cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py +674 -156
  42. cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py +625 -200
  43. cognee/tests/unit/modules/retrieval/rag_completion_retriever_test.py +319 -203
  44. cognee/tests/unit/modules/retrieval/summaries_retriever_test.py +189 -155
  45. cognee/tests/unit/modules/retrieval/temporal_retriever_test.py +539 -58
  46. cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py +218 -9
  47. cognee/tests/unit/modules/retrieval/test_completion.py +343 -0
  48. cognee/tests/unit/modules/retrieval/test_graph_summary_completion_retriever.py +157 -0
  49. cognee/tests/unit/modules/retrieval/test_user_qa_feedback.py +312 -0
  50. cognee/tests/unit/modules/retrieval/triplet_retriever_test.py +246 -0
  51. {cognee-0.5.0.dev1.dist-info → cognee-0.5.1.dev0.dist-info}/METADATA +1 -1
  52. {cognee-0.5.0.dev1.dist-info → cognee-0.5.1.dev0.dist-info}/RECORD +56 -42
  53. {cognee-0.5.0.dev1.dist-info → cognee-0.5.1.dev0.dist-info}/WHEEL +0 -0
  54. {cognee-0.5.0.dev1.dist-info → cognee-0.5.1.dev0.dist-info}/entry_points.txt +0 -0
  55. {cognee-0.5.0.dev1.dist-info → cognee-0.5.1.dev0.dist-info}/licenses/LICENSE +0 -0
  56. {cognee-0.5.0.dev1.dist-info → cognee-0.5.1.dev0.dist-info}/licenses/NOTICE.md +0 -0
@@ -1,5 +1,10 @@
1
1
  import pathlib
2
2
  import os
3
+ import asyncio
4
+ import pytest
5
+ import pytest_asyncio
6
+ from collections import Counter
7
+
3
8
  import cognee
4
9
  from cognee.infrastructure.databases.graph import get_graph_engine
5
10
  from cognee.infrastructure.databases.vector import get_vector_engine
@@ -13,127 +18,172 @@ from cognee.modules.retrieval.graph_completion_cot_retriever import GraphComplet
13
18
  from cognee.modules.retrieval.graph_summary_completion_retriever import (
14
19
  GraphSummaryCompletionRetriever,
15
20
  )
21
+ from cognee.modules.retrieval.chunks_retriever import ChunksRetriever
22
+ from cognee.modules.retrieval.summaries_retriever import SummariesRetriever
23
+ from cognee.modules.retrieval.completion_retriever import CompletionRetriever
24
+ from cognee.modules.retrieval.temporal_retriever import TemporalRetriever
16
25
  from cognee.modules.retrieval.triplet_retriever import TripletRetriever
17
26
  from cognee.shared.logging_utils import get_logger
18
27
  from cognee.modules.search.types import SearchType
19
28
  from cognee.modules.users.methods import get_default_user
20
- from collections import Counter
21
29
 
22
30
  logger = get_logger()
23
31
 
24
32
 
25
- async def main():
26
- # This test runs for multiple db settings, to run this locally set the corresponding db envs
33
+ async def _reset_engines_and_prune() -> None:
34
+ """Reset db engine caches and prune data/system.
35
+
36
+ Kept intentionally identical to the inlined setup logic to avoid event loop issues when
37
+ using deployed databases (Neo4j, PostgreSQL) and to ensure fresh instances per run.
38
+ """
39
+ # Dispose of existing engines and clear caches to ensure fresh instances for each test
40
+ try:
41
+ from cognee.infrastructure.databases.vector import get_vector_engine
42
+
43
+ vector_engine = get_vector_engine()
44
+ # Dispose SQLAlchemy engine connection pool if it exists
45
+ if hasattr(vector_engine, "engine") and hasattr(vector_engine.engine, "dispose"):
46
+ await vector_engine.engine.dispose(close=True)
47
+ except Exception:
48
+ # Engine might not exist yet
49
+ pass
50
+
51
+ from cognee.infrastructure.databases.graph.get_graph_engine import create_graph_engine
52
+ from cognee.infrastructure.databases.vector.create_vector_engine import create_vector_engine
53
+ from cognee.infrastructure.databases.relational.create_relational_engine import (
54
+ create_relational_engine,
55
+ )
56
+
57
+ create_graph_engine.cache_clear()
58
+ create_vector_engine.cache_clear()
59
+ create_relational_engine.cache_clear()
60
+
27
61
  await cognee.prune.prune_data()
28
62
  await cognee.prune.prune_system(metadata=True)
29
63
 
30
- dataset_name = "test_dataset"
31
64
 
65
+ async def _seed_default_dataset(dataset_name: str) -> dict:
66
+ """Add the shared test dataset contents and run cognify (same steps/order as before)."""
32
67
  text_1 = """Germany is located in europe right next to the Netherlands"""
68
+
69
+ logger.info(f"Adding text data to dataset: {dataset_name}")
33
70
  await cognee.add(text_1, dataset_name)
34
71
 
35
72
  explanation_file_path_quantum = os.path.join(
36
73
  pathlib.Path(__file__).parent, "test_data/Quantum_computers.txt"
37
74
  )
38
75
 
76
+ logger.info(f"Adding file data to dataset: {dataset_name}")
39
77
  await cognee.add([explanation_file_path_quantum], dataset_name)
40
78
 
79
+ logger.info(f"Running cognify on dataset: {dataset_name}")
41
80
  await cognee.cognify([dataset_name])
42
81
 
82
+ return {
83
+ "dataset_name": dataset_name,
84
+ "text_1": text_1,
85
+ "explanation_file_path_quantum": explanation_file_path_quantum,
86
+ }
87
+
88
+
89
+ @pytest.fixture(scope="session")
90
+ def event_loop():
91
+ """Use a single asyncio event loop for this test module.
92
+
93
+ This helps avoid "Future attached to a different loop" when running multiple async
94
+ tests that share clients/engines.
95
+ """
96
+ loop = asyncio.new_event_loop()
97
+ try:
98
+ yield loop
99
+ finally:
100
+ loop.close()
101
+
102
+
103
+ async def setup_test_environment():
104
+ """Helper function to set up test environment with data, cognify, and triplet embeddings."""
105
+ # This test runs for multiple db settings, to run this locally set the corresponding db envs
106
+
107
+ dataset_name = "test_dataset"
108
+ logger.info("Starting test setup: pruning data and system")
109
+ await _reset_engines_and_prune()
110
+ state = await _seed_default_dataset(dataset_name=dataset_name)
111
+
43
112
  user = await get_default_user()
44
113
  from cognee.memify_pipelines.create_triplet_embeddings import create_triplet_embeddings
45
114
 
115
+ logger.info("Creating triplet embeddings")
46
116
  await create_triplet_embeddings(user=user, dataset=dataset_name, triplets_batch_size=5)
47
117
 
48
- graph_engine = await get_graph_engine()
49
- nodes, edges = await graph_engine.get_graph_data()
50
-
118
+ # Check if Triplet_text collection was created
51
119
  vector_engine = get_vector_engine()
52
- collection = await vector_engine.search(
53
- query_text="Test", limit=None, collection_name="Triplet_text"
54
- )
120
+ has_collection = await vector_engine.has_collection(collection_name="Triplet_text")
121
+ logger.info(f"Triplet_text collection exists after creation: {has_collection}")
55
122
 
56
- assert len(edges) == len(collection), (
57
- f"Expected {len(edges)} edges but got {len(collection)} in Triplet_text collection"
58
- )
123
+ if has_collection:
124
+ collection = await vector_engine.get_collection("Triplet_text")
125
+ count = await collection.count_rows() if hasattr(collection, "count_rows") else "unknown"
126
+ logger.info(f"Triplet_text collection row count: {count}")
59
127
 
60
- context_gk = await GraphCompletionRetriever().get_context(
61
- query="Next to which country is Germany located?"
62
- )
63
- context_gk_cot = await GraphCompletionCotRetriever().get_context(
64
- query="Next to which country is Germany located?"
65
- )
66
- context_gk_ext = await GraphCompletionContextExtensionRetriever().get_context(
67
- query="Next to which country is Germany located?"
68
- )
69
- context_gk_sum = await GraphSummaryCompletionRetriever().get_context(
70
- query="Next to which country is Germany located?"
71
- )
72
- context_triplet = await TripletRetriever().get_context(
73
- query="Next to which country is Germany located?"
74
- )
128
+ return state
75
129
 
76
- for name, context in [
77
- ("GraphCompletionRetriever", context_gk),
78
- ("GraphCompletionCotRetriever", context_gk_cot),
79
- ("GraphCompletionContextExtensionRetriever", context_gk_ext),
80
- ("GraphSummaryCompletionRetriever", context_gk_sum),
81
- ]:
82
- assert isinstance(context, list), f"{name}: Context should be a list"
83
- assert len(context) > 0, f"{name}: Context should not be empty"
84
130
 
85
- context_text = await resolve_edges_to_text(context)
86
- lower = context_text.lower()
87
- assert "germany" in lower or "netherlands" in lower, (
88
- f"{name}: Context did not contain 'germany' or 'netherlands'; got: {context!r}"
89
- )
131
+ async def setup_test_environment_for_feedback():
132
+ """Helper function to set up test environment for feedback weight calculation test."""
133
+ dataset_name = "test_dataset"
134
+ await _reset_engines_and_prune()
135
+ return await _seed_default_dataset(dataset_name=dataset_name)
90
136
 
91
- assert isinstance(context_triplet, str), "TripletRetriever: Context should be a string"
92
- assert len(context_triplet) > 0, "TripletRetriever: Context should not be empty"
93
- lower_triplet = context_triplet.lower()
94
- assert "germany" in lower_triplet or "netherlands" in lower_triplet, (
95
- f"TripletRetriever: Context did not contain 'germany' or 'netherlands'; got: {context_triplet!r}"
96
- )
97
137
 
98
- triplets_gk = await GraphCompletionRetriever().get_triplets(
99
- query="Next to which country is Germany located?"
100
- )
101
- triplets_gk_cot = await GraphCompletionCotRetriever().get_triplets(
102
- query="Next to which country is Germany located?"
103
- )
104
- triplets_gk_ext = await GraphCompletionContextExtensionRetriever().get_triplets(
105
- query="Next to which country is Germany located?"
106
- )
107
- triplets_gk_sum = await GraphSummaryCompletionRetriever().get_triplets(
108
- query="Next to which country is Germany located?"
109
- )
138
+ @pytest_asyncio.fixture(scope="session")
139
+ async def e2e_state():
140
+ """Compute E2E artifacts once; tests only assert.
110
141
 
111
- for name, triplets in [
112
- ("GraphCompletionRetriever", triplets_gk),
113
- ("GraphCompletionCotRetriever", triplets_gk_cot),
114
- ("GraphCompletionContextExtensionRetriever", triplets_gk_ext),
115
- ("GraphSummaryCompletionRetriever", triplets_gk_sum),
116
- ]:
117
- assert isinstance(triplets, list), f"{name}: Triplets should be a list"
118
- assert triplets, f"{name}: Triplets list should not be empty"
119
- for edge in triplets:
120
- assert isinstance(edge, Edge), f"{name}: Elements should be Edge instances"
121
- distance = edge.attributes.get("vector_distance")
122
- node1_distance = edge.node1.attributes.get("vector_distance")
123
- node2_distance = edge.node2.attributes.get("vector_distance")
124
- assert isinstance(distance, float), (
125
- f"{name}: vector_distance should be float, got {type(distance)}"
126
- )
127
- assert 0 <= distance <= 1, (
128
- f"{name}: edge vector_distance {distance} out of [0,1], this shouldn't happen"
129
- )
130
- assert 0 <= node1_distance <= 1, (
131
- f"{name}: node_1 vector_distance {distance} out of [0,1], this shouldn't happen"
132
- )
133
- assert 0 <= node2_distance <= 1, (
134
- f"{name}: node_2 vector_distance {distance} out of [0,1], this shouldn't happen"
135
- )
142
+ This avoids repeating expensive setup and LLM calls across multiple tests.
143
+ """
144
+ await setup_test_environment()
145
+
146
+ # --- Graph/vector engine consistency ---
147
+ graph_engine = await get_graph_engine()
148
+ _nodes, edges = await graph_engine.get_graph_data()
136
149
 
150
+ vector_engine = get_vector_engine()
151
+ collection = await vector_engine.search(
152
+ collection_name="Triplet_text", query_text="Test", limit=None
153
+ )
154
+
155
+ # --- Retriever contexts ---
156
+ query = "Next to which country is Germany located?"
157
+
158
+ contexts = {
159
+ "graph_completion": await GraphCompletionRetriever().get_context(query=query),
160
+ "graph_completion_cot": await GraphCompletionCotRetriever().get_context(query=query),
161
+ "graph_completion_context_extension": await GraphCompletionContextExtensionRetriever().get_context(
162
+ query=query
163
+ ),
164
+ "graph_summary_completion": await GraphSummaryCompletionRetriever().get_context(
165
+ query=query
166
+ ),
167
+ "chunks": await ChunksRetriever(top_k=5).get_context(query=query),
168
+ "summaries": await SummariesRetriever(top_k=5).get_context(query=query),
169
+ "rag_completion": await CompletionRetriever(top_k=3).get_context(query=query),
170
+ "temporal": await TemporalRetriever(top_k=5).get_context(query=query),
171
+ "triplet": await TripletRetriever().get_context(query=query),
172
+ }
173
+
174
+ # --- Retriever triplets + vector distance validation ---
175
+ triplets = {
176
+ "graph_completion": await GraphCompletionRetriever().get_triplets(query=query),
177
+ "graph_completion_cot": await GraphCompletionCotRetriever().get_triplets(query=query),
178
+ "graph_completion_context_extension": await GraphCompletionContextExtensionRetriever().get_triplets(
179
+ query=query
180
+ ),
181
+ "graph_summary_completion": await GraphSummaryCompletionRetriever().get_triplets(
182
+ query=query
183
+ ),
184
+ }
185
+
186
+ # --- Search operations + graph side effects ---
137
187
  completion_gk = await cognee.search(
138
188
  query_type=SearchType.GRAPH_COMPLETION,
139
189
  query_text="Where is germany located, next to which country?",
@@ -164,6 +214,26 @@ async def main():
164
214
  query_text="Next to which country is Germany located?",
165
215
  save_interaction=True,
166
216
  )
217
+ completion_chunks = await cognee.search(
218
+ query_type=SearchType.CHUNKS,
219
+ query_text="Germany",
220
+ save_interaction=False,
221
+ )
222
+ completion_summaries = await cognee.search(
223
+ query_type=SearchType.SUMMARIES,
224
+ query_text="Germany",
225
+ save_interaction=False,
226
+ )
227
+ completion_rag = await cognee.search(
228
+ query_type=SearchType.RAG_COMPLETION,
229
+ query_text="Next to which country is Germany located?",
230
+ save_interaction=False,
231
+ )
232
+ completion_temporal = await cognee.search(
233
+ query_type=SearchType.TEMPORAL,
234
+ query_text="Next to which country is Germany located?",
235
+ save_interaction=False,
236
+ )
167
237
 
168
238
  await cognee.search(
169
239
  query_type=SearchType.FEEDBACK,
@@ -171,134 +241,217 @@ async def main():
171
241
  last_k=1,
172
242
  )
173
243
 
174
- for name, search_results in [
175
- ("GRAPH_COMPLETION", completion_gk),
176
- ("GRAPH_COMPLETION_COT", completion_cot),
177
- ("GRAPH_COMPLETION_CONTEXT_EXTENSION", completion_ext),
178
- ("GRAPH_SUMMARY_COMPLETION", completion_sum),
179
- ("TRIPLET_COMPLETION", completion_triplet),
244
+ # Snapshot after all E2E operations above (used by assertion-only tests).
245
+ graph_snapshot = await (await get_graph_engine()).get_graph_data()
246
+
247
+ return {
248
+ "graph_edges": edges,
249
+ "triplet_collection": collection,
250
+ "vector_collection_edges_count": len(collection),
251
+ "graph_edges_count": len(edges),
252
+ "contexts": contexts,
253
+ "triplets": triplets,
254
+ "search_results": {
255
+ "graph_completion": completion_gk,
256
+ "graph_completion_cot": completion_cot,
257
+ "graph_completion_context_extension": completion_ext,
258
+ "graph_summary_completion": completion_sum,
259
+ "triplet_completion": completion_triplet,
260
+ "chunks": completion_chunks,
261
+ "summaries": completion_summaries,
262
+ "rag_completion": completion_rag,
263
+ "temporal": completion_temporal,
264
+ },
265
+ "graph_snapshot": graph_snapshot,
266
+ }
267
+
268
+
269
+ @pytest_asyncio.fixture(scope="session")
270
+ async def feedback_state():
271
+ """Feedback-weight scenario computed once (fresh environment)."""
272
+ await setup_test_environment_for_feedback()
273
+
274
+ await cognee.search(
275
+ query_type=SearchType.GRAPH_COMPLETION,
276
+ query_text="Next to which country is Germany located?",
277
+ save_interaction=True,
278
+ )
279
+ await cognee.search(
280
+ query_type=SearchType.FEEDBACK,
281
+ query_text="This was the best answer I've ever seen",
282
+ last_k=1,
283
+ )
284
+ await cognee.search(
285
+ query_type=SearchType.FEEDBACK,
286
+ query_text="Wow the correctness of this answer blows my mind",
287
+ last_k=1,
288
+ )
289
+
290
+ graph_engine = await get_graph_engine()
291
+ graph = await graph_engine.get_graph_data()
292
+ return {"graph_snapshot": graph}
293
+
294
+
295
+ @pytest.mark.asyncio
296
+ async def test_e2e_graph_vector_consistency(e2e_state):
297
+ """Graph and vector stores contain the same triplet edges."""
298
+ assert e2e_state["graph_edges_count"] == e2e_state["vector_collection_edges_count"]
299
+
300
+
301
+ @pytest.mark.asyncio
302
+ async def test_e2e_retriever_contexts(e2e_state):
303
+ """All retrievers return non-empty, well-typed contexts."""
304
+ contexts = e2e_state["contexts"]
305
+
306
+ for name in [
307
+ "graph_completion",
308
+ "graph_completion_cot",
309
+ "graph_completion_context_extension",
310
+ "graph_summary_completion",
180
311
  ]:
181
- assert isinstance(search_results, list), f"{name}: should return a list"
182
- assert len(search_results) == 1, (
183
- f"{name}: expected single-element list, got {len(search_results)}"
312
+ ctx = contexts[name]
313
+ assert isinstance(ctx, list), f"{name}: Context should be a list"
314
+ assert ctx, f"{name}: Context should not be empty"
315
+ ctx_text = await resolve_edges_to_text(ctx)
316
+ lower = ctx_text.lower()
317
+ assert "germany" in lower or "netherlands" in lower, (
318
+ f"{name}: Context did not contain 'germany' or 'netherlands'; got: {ctx!r}"
184
319
  )
185
320
 
186
- from cognee.context_global_variables import backend_access_control_enabled
321
+ triplet_ctx = contexts["triplet"]
322
+ assert isinstance(triplet_ctx, str), "triplet: Context should be a string"
323
+ assert triplet_ctx.strip(), "triplet: Context should not be empty"
324
+
325
+ chunks_ctx = contexts["chunks"]
326
+ assert isinstance(chunks_ctx, list), "chunks: Context should be a list"
327
+ assert chunks_ctx, "chunks: Context should not be empty"
328
+ chunks_text = "\n".join(str(item.get("text", "")) for item in chunks_ctx).lower()
329
+ assert "germany" in chunks_text or "netherlands" in chunks_text
330
+
331
+ summaries_ctx = contexts["summaries"]
332
+ assert isinstance(summaries_ctx, list), "summaries: Context should be a list"
333
+ assert summaries_ctx, "summaries: Context should not be empty"
334
+ assert any(str(item.get("text", "")).strip() for item in summaries_ctx)
335
+
336
+ rag_ctx = contexts["rag_completion"]
337
+ assert isinstance(rag_ctx, str), "rag_completion: Context should be a string"
338
+ assert rag_ctx.strip(), "rag_completion: Context should not be empty"
339
+
340
+ temporal_ctx = contexts["temporal"]
341
+ assert isinstance(temporal_ctx, str), "temporal: Context should be a string"
342
+ assert temporal_ctx.strip(), "temporal: Context should not be empty"
343
+
344
+
345
+ @pytest.mark.asyncio
346
+ async def test_e2e_retriever_triplets_have_vector_distances(e2e_state):
347
+ """Graph retriever triplets include sane vector_distance metadata."""
348
+ for name, triplets in e2e_state["triplets"].items():
349
+ assert isinstance(triplets, list), f"{name}: Triplets should be a list"
350
+ assert triplets, f"{name}: Triplets list should not be empty"
351
+ for edge in triplets:
352
+ assert isinstance(edge, Edge), f"{name}: Elements should be Edge instances"
353
+ distance = edge.attributes.get("vector_distance")
354
+ node1_distance = edge.node1.attributes.get("vector_distance")
355
+ node2_distance = edge.node2.attributes.get("vector_distance")
356
+ assert isinstance(distance, float), f"{name}: vector_distance should be float"
357
+ assert 0 <= distance <= 1
358
+ assert 0 <= node1_distance <= 1
359
+ assert 0 <= node2_distance <= 1
360
+
361
+
362
+ @pytest.mark.asyncio
363
+ async def test_e2e_search_results_and_wrappers(e2e_state):
364
+ """Search returns expected shapes across search types and access modes."""
365
+ from cognee.context_global_variables import backend_access_control_enabled
366
+
367
+ sr = e2e_state["search_results"]
368
+
369
+ # Completion-like search types: validate wrapper + content
370
+ for name in [
371
+ "graph_completion",
372
+ "graph_completion_cot",
373
+ "graph_completion_context_extension",
374
+ "graph_summary_completion",
375
+ "triplet_completion",
376
+ "rag_completion",
377
+ "temporal",
378
+ ]:
379
+ search_results = sr[name]
380
+ assert isinstance(search_results, list), f"{name}: should return a list"
381
+ assert len(search_results) == 1, f"{name}: expected single-element list"
187
382
 
188
383
  if backend_access_control_enabled():
189
- text = search_results[0]["search_result"][0]
384
+ wrapper = search_results[0]
385
+ assert isinstance(wrapper, dict), (
386
+ f"{name}: expected wrapper dict in access control mode"
387
+ )
388
+ assert wrapper.get("dataset_id"), f"{name}: missing dataset_id in wrapper"
389
+ assert wrapper.get("dataset_name") == "test_dataset"
390
+ assert "graphs" in wrapper
391
+ text = wrapper["search_result"][0]
190
392
  else:
191
393
  text = search_results[0]
192
- assert isinstance(text, str), f"{name}: element should be a string"
193
- assert text.strip(), f"{name}: string should not be empty"
194
- assert "netherlands" in text.lower(), (
195
- f"{name}: expected 'netherlands' in result, got: {text!r}"
196
- )
197
394
 
198
- graph_engine = await get_graph_engine()
199
- graph = await graph_engine.get_graph_data()
395
+ assert isinstance(text, str) and text.strip()
396
+ assert "netherlands" in text.lower()
200
397
 
201
- type_counts = Counter(node_data[1].get("type", {}) for node_data in graph[0])
398
+ # Non-LLM search types: CHUNKS / SUMMARIES validate payload list + text
399
+ for name in ["chunks", "summaries"]:
400
+ search_results = sr[name]
401
+ assert isinstance(search_results, list), f"{name}: should return a list"
402
+ assert search_results, f"{name}: should not be empty"
202
403
 
203
- edge_type_counts = Counter(edge_type[2] for edge_type in graph[1])
404
+ first = search_results[0]
405
+ assert isinstance(first, dict), f"{name}: expected dict entries"
204
406
 
205
- # Assert there are exactly 4 CogneeUserInteraction nodes.
206
- assert type_counts.get("CogneeUserInteraction", 0) == 4, (
207
- f"Expected exactly four CogneeUserInteraction nodes, but found {type_counts.get('CogneeUserInteraction', 0)}"
208
- )
407
+ payloads = search_results
408
+ if "search_result" in first and "text" not in first:
409
+ payloads = (first.get("search_result") or [None])[0]
209
410
 
210
- # Assert there is exactly two CogneeUserFeedback nodes.
211
- assert type_counts.get("CogneeUserFeedback", 0) == 2, (
212
- f"Expected exactly two CogneeUserFeedback nodes, but found {type_counts.get('CogneeUserFeedback', 0)}"
213
- )
214
-
215
- # Assert there is exactly two NodeSet.
216
- assert type_counts.get("NodeSet", 0) == 2, (
217
- f"Expected exactly two NodeSet nodes, but found {type_counts.get('NodeSet', 0)}"
218
- )
411
+ assert isinstance(payloads, list) and payloads
412
+ assert isinstance(payloads[0], dict)
413
+ assert str(payloads[0].get("text", "")).strip()
219
414
 
220
- # Assert that there are at least 10 'used_graph_element_to_answer' edges.
221
- assert edge_type_counts.get("used_graph_element_to_answer", 0) >= 10, (
222
- f"Expected at least ten 'used_graph_element_to_answer' edges, but found {edge_type_counts.get('used_graph_element_to_answer', 0)}"
223
- )
224
415
 
225
- # Assert that there are exactly 2 'gives_feedback_to' edges.
226
- assert edge_type_counts.get("gives_feedback_to", 0) == 2, (
227
- f"Expected exactly two 'gives_feedback_to' edges, but found {edge_type_counts.get('gives_feedback_to', 0)}"
228
- )
416
+ @pytest.mark.asyncio
417
+ async def test_e2e_graph_side_effects_and_node_fields(e2e_state):
418
+ """Search interactions create expected graph nodes/edges and required fields."""
419
+ graph = e2e_state["graph_snapshot"]
420
+ nodes, edges = graph
229
421
 
230
- # Assert that there are at least 6 'belongs_to_set' edges.
231
- assert edge_type_counts.get("belongs_to_set", 0) == 6, (
232
- f"Expected at least six 'belongs_to_set' edges, but found {edge_type_counts.get('belongs_to_set', 0)}"
233
- )
422
+ type_counts = Counter(node_data[1].get("type", {}) for node_data in nodes)
423
+ edge_type_counts = Counter(edge_type[2] for edge_type in edges)
234
424
 
235
- nodes = graph[0]
425
+ assert type_counts.get("CogneeUserInteraction", 0) == 4
426
+ assert type_counts.get("CogneeUserFeedback", 0) == 2
427
+ assert type_counts.get("NodeSet", 0) == 2
428
+ assert edge_type_counts.get("used_graph_element_to_answer", 0) >= 10
429
+ assert edge_type_counts.get("gives_feedback_to", 0) == 2
430
+ assert edge_type_counts.get("belongs_to_set", 0) >= 6
236
431
 
237
432
  required_fields_user_interaction = {"question", "answer", "context"}
238
433
  required_fields_feedback = {"feedback", "sentiment"}
239
434
 
240
435
  for node_id, data in nodes:
241
436
  if data.get("type") == "CogneeUserInteraction":
242
- assert required_fields_user_interaction.issubset(data.keys()), (
243
- f"Node {node_id} is missing fields: {required_fields_user_interaction - set(data.keys())}"
244
- )
245
-
437
+ assert required_fields_user_interaction.issubset(data.keys())
246
438
  for field in required_fields_user_interaction:
247
439
  value = data[field]
248
- assert isinstance(value, str) and value.strip(), (
249
- f"Node {node_id} has invalid value for '{field}': {value!r}"
250
- )
440
+ assert isinstance(value, str) and value.strip()
251
441
 
252
442
  if data.get("type") == "CogneeUserFeedback":
253
- assert required_fields_feedback.issubset(data.keys()), (
254
- f"Node {node_id} is missing fields: {required_fields_feedback - set(data.keys())}"
255
- )
256
-
443
+ assert required_fields_feedback.issubset(data.keys())
257
444
  for field in required_fields_feedback:
258
445
  value = data[field]
259
- assert isinstance(value, str) and value.strip(), (
260
- f"Node {node_id} has invalid value for '{field}': {value!r}"
261
- )
446
+ assert isinstance(value, str) and value.strip()
262
447
 
263
- await cognee.prune.prune_data()
264
- await cognee.prune.prune_system(metadata=True)
265
-
266
- await cognee.add(text_1, dataset_name)
267
-
268
- await cognee.add([text], dataset_name)
269
-
270
- await cognee.cognify([dataset_name])
271
-
272
- await cognee.search(
273
- query_type=SearchType.GRAPH_COMPLETION,
274
- query_text="Next to which country is Germany located?",
275
- save_interaction=True,
276
- )
277
448
 
278
- await cognee.search(
279
- query_type=SearchType.FEEDBACK,
280
- query_text="This was the best answer I've ever seen",
281
- last_k=1,
282
- )
283
-
284
- await cognee.search(
285
- query_type=SearchType.FEEDBACK,
286
- query_text="Wow the correctness of this answer blows my mind",
287
- last_k=1,
288
- )
289
-
290
- graph = await graph_engine.get_graph_data()
291
-
292
- edges = graph[1]
293
-
294
- for from_node, to_node, relationship_name, properties in edges:
449
+ @pytest.mark.asyncio
450
+ async def test_e2e_feedback_weight_calculation(feedback_state):
451
+ """Positive feedback increases used_graph_element_to_answer feedback_weight."""
452
+ _nodes, edges = feedback_state["graph_snapshot"]
453
+ for _from_node, _to_node, relationship_name, properties in edges:
295
454
  if relationship_name == "used_graph_element_to_answer":
296
455
  assert properties["feedback_weight"] >= 6, (
297
456
  "Feedback weight calculation is not correct, it should be more then 6."
298
457
  )
299
-
300
-
301
- if __name__ == "__main__":
302
- import asyncio
303
-
304
- asyncio.run(main())
@@ -11,6 +11,22 @@ MOCK_JSONL_DATA = """\
11
11
  {"id": "2", "question": "What is ML?", "answer": "Machine Learning", "paragraphs": [{"paragraph_text": "ML is a subset of AI."}]}
12
12
  """
13
13
 
14
+ MOCK_HOTPOT_CORPUS = [
15
+ {
16
+ "_id": "1",
17
+ "question": "Next to which country is Germany located?",
18
+ "answer": "Netherlands",
19
+ # HotpotQA uses "level"; TwoWikiMultiHop uses "type".
20
+ "level": "easy",
21
+ "type": "comparison",
22
+ "context": [
23
+ ["Germany", ["Germany is in Europe."]],
24
+ ["Netherlands", ["The Netherlands borders Germany."]],
25
+ ],
26
+ "supporting_facts": [["Netherlands", 0]],
27
+ }
28
+ ]
29
+
14
30
 
15
31
  ADAPTER_CLASSES = [
16
32
  HotpotQAAdapter,
@@ -35,6 +51,11 @@ def test_adapter_can_instantiate_and_load(AdapterClass):
35
51
  adapter = AdapterClass()
36
52
  result = adapter.load_corpus()
37
53
 
54
+ elif AdapterClass in (HotpotQAAdapter, TwoWikiMultihopAdapter):
55
+ with patch.object(AdapterClass, "_get_raw_corpus", return_value=MOCK_HOTPOT_CORPUS):
56
+ adapter = AdapterClass()
57
+ result = adapter.load_corpus()
58
+
38
59
  else:
39
60
  adapter = AdapterClass()
40
61
  result = adapter.load_corpus()
@@ -64,6 +85,10 @@ def test_adapter_returns_some_content(AdapterClass):
64
85
  ):
65
86
  adapter = AdapterClass()
66
87
  corpus_list, qa_pairs = adapter.load_corpus(limit=limit)
88
+ elif AdapterClass in (HotpotQAAdapter, TwoWikiMultihopAdapter):
89
+ with patch.object(AdapterClass, "_get_raw_corpus", return_value=MOCK_HOTPOT_CORPUS):
90
+ adapter = AdapterClass()
91
+ corpus_list, qa_pairs = adapter.load_corpus(limit=limit)
67
92
  else:
68
93
  adapter = AdapterClass()
69
94
  corpus_list, qa_pairs = adapter.load_corpus(limit=limit)