cognee 0.5.1__py3-none-any.whl → 0.5.1.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (59)
  1. cognee/api/v1/add/add.py +2 -1
  2. cognee/api/v1/datasets/routers/get_datasets_router.py +1 -0
  3. cognee/api/v1/memify/routers/get_memify_router.py +1 -0
  4. cognee/api/v1/search/search.py +0 -4
  5. cognee/infrastructure/databases/relational/config.py +16 -1
  6. cognee/infrastructure/databases/relational/create_relational_engine.py +13 -3
  7. cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py +24 -2
  8. cognee/infrastructure/databases/vector/create_vector_engine.py +9 -2
  9. cognee/infrastructure/llm/LLMGateway.py +0 -13
  10. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py +17 -12
  11. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py +31 -25
  12. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py +132 -7
  13. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py +5 -5
  14. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/llm_interface.py +2 -6
  15. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py +58 -13
  16. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/adapter.py +0 -1
  17. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py +25 -131
  18. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/types.py +10 -0
  19. cognee/modules/data/models/Data.py +2 -1
  20. cognee/modules/retrieval/triplet_retriever.py +1 -1
  21. cognee/modules/retrieval/utils/brute_force_triplet_search.py +0 -18
  22. cognee/modules/search/methods/search.py +18 -25
  23. cognee/tasks/ingestion/data_item.py +8 -0
  24. cognee/tasks/ingestion/ingest_data.py +12 -1
  25. cognee/tasks/ingestion/save_data_item_to_storage.py +5 -0
  26. cognee/tests/integration/retrieval/test_chunks_retriever.py +252 -0
  27. cognee/tests/integration/retrieval/test_graph_completion_retriever.py +268 -0
  28. cognee/tests/integration/retrieval/test_graph_completion_retriever_context_extension.py +226 -0
  29. cognee/tests/integration/retrieval/test_graph_completion_retriever_cot.py +218 -0
  30. cognee/tests/integration/retrieval/test_rag_completion_retriever.py +254 -0
  31. cognee/tests/{unit/modules/retrieval/structured_output_test.py → integration/retrieval/test_structured_output.py} +87 -77
  32. cognee/tests/integration/retrieval/test_summaries_retriever.py +184 -0
  33. cognee/tests/integration/retrieval/test_temporal_retriever.py +306 -0
  34. cognee/tests/integration/retrieval/test_triplet_retriever.py +35 -0
  35. cognee/tests/test_custom_data_label.py +68 -0
  36. cognee/tests/test_search_db.py +334 -181
  37. cognee/tests/unit/eval_framework/benchmark_adapters_test.py +25 -0
  38. cognee/tests/unit/eval_framework/corpus_builder_test.py +33 -4
  39. cognee/tests/unit/infrastructure/databases/relational/test_RelationalConfig.py +69 -0
  40. cognee/tests/unit/modules/retrieval/chunks_retriever_test.py +181 -199
  41. cognee/tests/unit/modules/retrieval/conversation_history_test.py +338 -0
  42. cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py +454 -162
  43. cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py +674 -156
  44. cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py +625 -200
  45. cognee/tests/unit/modules/retrieval/rag_completion_retriever_test.py +319 -203
  46. cognee/tests/unit/modules/retrieval/summaries_retriever_test.py +189 -155
  47. cognee/tests/unit/modules/retrieval/temporal_retriever_test.py +539 -58
  48. cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py +218 -9
  49. cognee/tests/unit/modules/retrieval/test_completion.py +343 -0
  50. cognee/tests/unit/modules/retrieval/test_graph_summary_completion_retriever.py +157 -0
  51. cognee/tests/unit/modules/retrieval/test_user_qa_feedback.py +312 -0
  52. cognee/tests/unit/modules/retrieval/triplet_retriever_test.py +246 -0
  53. {cognee-0.5.1.dist-info → cognee-0.5.1.dev0.dist-info}/METADATA +1 -1
  54. {cognee-0.5.1.dist-info → cognee-0.5.1.dev0.dist-info}/RECORD +58 -45
  55. cognee/tests/unit/modules/search/test_search.py +0 -100
  56. {cognee-0.5.1.dist-info → cognee-0.5.1.dev0.dist-info}/WHEEL +0 -0
  57. {cognee-0.5.1.dist-info → cognee-0.5.1.dev0.dist-info}/entry_points.txt +0 -0
  58. {cognee-0.5.1.dist-info → cognee-0.5.1.dev0.dist-info}/licenses/LICENSE +0 -0
  59. {cognee-0.5.1.dist-info → cognee-0.5.1.dev0.dist-info}/licenses/NOTICE.md +0 -0
@@ -1,12 +1,14 @@
1
1
  import pytest
2
- from unittest.mock import AsyncMock, patch
2
+ from unittest.mock import AsyncMock, patch, MagicMock
3
3
 
4
4
  from cognee.modules.retrieval.utils.brute_force_triplet_search import (
5
5
  brute_force_triplet_search,
6
6
  get_memory_fragment,
7
+ format_triplets,
7
8
  )
8
9
  from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph
9
10
  from cognee.modules.graph.exceptions.exceptions import EntityNotFoundError
11
+ from cognee.infrastructure.databases.vector.exceptions.exceptions import CollectionNotFoundError
10
12
 
11
13
 
12
14
  class MockScoredResult:
@@ -354,20 +356,30 @@ async def test_brute_force_triplet_search_passes_top_k_to_importance_calculation
354
356
 
355
357
  @pytest.mark.asyncio
356
358
  async def test_get_memory_fragment_returns_empty_graph_on_entity_not_found():
357
- """Test that get_memory_fragment returns empty graph when entity not found."""
359
+ """Test that get_memory_fragment returns empty graph when entity not found (line 85)."""
358
360
  mock_graph_engine = AsyncMock()
359
- mock_graph_engine.project_graph_from_db = AsyncMock(
361
+
362
+ # Create a mock fragment that will raise EntityNotFoundError when project_graph_from_db is called
363
+ mock_fragment = MagicMock(spec=CogneeGraph)
364
+ mock_fragment.project_graph_from_db = AsyncMock(
360
365
  side_effect=EntityNotFoundError("Entity not found")
361
366
  )
362
367
 
363
- with patch(
364
- "cognee.modules.retrieval.utils.brute_force_triplet_search.get_graph_engine",
365
- return_value=mock_graph_engine,
368
+ with (
369
+ patch(
370
+ "cognee.modules.retrieval.utils.brute_force_triplet_search.get_graph_engine",
371
+ return_value=mock_graph_engine,
372
+ ),
373
+ patch(
374
+ "cognee.modules.retrieval.utils.brute_force_triplet_search.CogneeGraph",
375
+ return_value=mock_fragment,
376
+ ),
366
377
  ):
367
- fragment = await get_memory_fragment()
378
+ result = await get_memory_fragment()
368
379
 
369
- assert isinstance(fragment, CogneeGraph)
370
- assert len(fragment.nodes) == 0
380
+ # Fragment should be returned even though EntityNotFoundError was raised (pass statement on line 85)
381
+ assert result == mock_fragment
382
+ mock_fragment.project_graph_from_db.assert_awaited_once()
371
383
 
372
384
 
373
385
  @pytest.mark.asyncio
@@ -606,3 +618,200 @@ async def test_brute_force_triplet_search_mixed_empty_collections():
606
618
 
607
619
  call_kwargs = mock_get_fragment_fn.call_args[1]
608
620
  assert set(call_kwargs["relevant_ids_to_filter"]) == {"node1", "node2"}
621
+
622
+
623
def test_format_triplets():
    """format_triplets renders node names, relationship name and edge text into a string."""
    source_node = MagicMock()
    source_node.attributes = {"name": "Node1", "type": "Entity", "id": "n1"}
    target_node = MagicMock()
    target_node.attributes = {"name": "Node2", "type": "Entity", "id": "n2"}

    edge = MagicMock()
    edge.attributes = {"relationship_name": "relates_to", "edge_text": "connects"}
    edge.node1, edge.node2 = source_node, target_node

    rendered = format_triplets([edge])

    assert isinstance(rendered, str)
    for expected in ("Node1", "Node2", "relates_to", "connects"):
        assert expected in rendered
643
+
644
+
645
def test_format_triplets_with_none_values():
    """format_triplets must not render None-valued attributes as the text "None".

    Node and edge attributes carry some ``None`` values (``type``, ``id``,
    ``edge_text``). The non-None values must still appear in the output, and
    no literal ``"None"`` may leak into the formatted string.
    """
    mock_edge = MagicMock()
    mock_node1 = MagicMock()
    mock_node2 = MagicMock()

    mock_node1.attributes = {"name": "Node1", "type": None, "id": "n1"}
    mock_node2.attributes = {"name": "Node2", "type": "Entity", "id": None}
    mock_edge.attributes = {"relationship_name": "relates_to", "edge_text": None}

    mock_edge.node1 = mock_node1
    mock_edge.node2 = mock_node2

    result = format_triplets([mock_edge])

    assert "Node1" in result
    assert "Node2" in result
    assert "relates_to" in result
    # A single membership check: ``"None" not in s`` is equivalent to
    # ``s.count("None") == 0``, so OR-ing the two only obscured the intent.
    assert "None" not in result
664
+
665
+
666
def test_format_triplets_with_nested_dict():
    """format_triplets copes with node attributes that contain a nested dict.

    Each node carries a ``metadata`` sub-dict. The test checks that formatting
    still yields a string containing the node names and the relationship name.
    """
    first = MagicMock()
    first.attributes = {"name": "Node1", "metadata": {"type": "Entity", "id": "n1"}}
    second = MagicMock()
    second.attributes = {"name": "Node2", "metadata": {"type": "Entity", "id": "n2"}}

    link = MagicMock()
    link.attributes = {"relationship_name": "relates_to"}
    link.node1 = first
    link.node2 = second

    output = format_triplets([link])

    assert isinstance(output, str)
    for fragment in ("Node1", "Node2", "relates_to"):
        assert fragment in output
685
+
686
+
687
@pytest.mark.asyncio
async def test_brute_force_triplet_search_vector_engine_init_error():
    """A failure while obtaining the vector engine surfaces as a RuntimeError.

    The RuntimeError's message must carry the original error text.
    """
    failing_engine_factory = patch(
        "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
        side_effect=Exception("Initialization error"),
    )

    with failing_engine_factory, pytest.raises(RuntimeError, match="Initialization error"):
        await brute_force_triplet_search(query="test query")
699
+
700
+
701
@pytest.mark.asyncio
async def test_brute_force_triplet_search_collection_not_found_error():
    """A CollectionNotFoundError from one collection's search is tolerated.

    The first vector search raises CollectionNotFoundError and the remaining
    searches return nothing, so the overall result must be an empty list.
    """
    embedder = AsyncMock()
    embedder.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])

    vector_engine = AsyncMock()
    vector_engine.embedding_engine = embedder
    # Successive search() calls: missing collection first, then empty hits.
    vector_engine.search = AsyncMock(
        side_effect=[CollectionNotFoundError("Collection not found"), [], []]
    )

    engine_patch = patch(
        "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
        return_value=vector_engine,
    )
    fragment_patch = patch(
        "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
        return_value=CogneeGraph(),
    )

    with engine_patch, fragment_patch:
        triplets = await brute_force_triplet_search(
            query="test query",
            collections=["missing_collection", "existing_collection"],
        )

    assert triplets == []
732
+
733
+
734
@pytest.mark.asyncio
async def test_brute_force_triplet_search_generic_exception():
    """An unexpected error from the vector search is not swallowed.

    The search mock raises a plain Exception. The call must raise an exception
    whose message contains the original text.
    """
    embedder = AsyncMock()
    embedder.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])

    vector_engine = AsyncMock()
    vector_engine.embedding_engine = embedder
    vector_engine.search = AsyncMock(side_effect=Exception("Generic error"))

    engine_patch = patch(
        "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
        return_value=vector_engine,
    )

    with engine_patch, pytest.raises(Exception, match="Generic error"):
        await brute_force_triplet_search(query="test query")
752
+
753
+
754
@pytest.mark.asyncio
async def test_brute_force_triplet_search_with_node_name_sets_relevant_ids_to_none():
    """Passing node_name disables ID-based pre-filtering of the memory fragment.

    When ``node_name`` is supplied, ``brute_force_triplet_search`` must call
    ``get_memory_fragment`` with the keyword ``relevant_ids_to_filter=None``.
    """
    mock_vector_engine = AsyncMock()
    mock_embedding_engine = AsyncMock()
    mock_vector_engine.embedding_engine = mock_embedding_engine
    mock_embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])

    mock_result = MockScoredResult(id="node1", score=0.8, payload={"id": "node1"})
    mock_vector_engine.search = AsyncMock(return_value=[mock_result])

    mock_fragment = AsyncMock()
    mock_fragment.map_vector_distances_to_graph_nodes = AsyncMock()
    mock_fragment.map_vector_distances_to_graph_edges = AsyncMock()
    mock_fragment.calculate_top_triplet_importances = AsyncMock(return_value=[])

    with (
        patch(
            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
            return_value=mock_vector_engine,
        ),
        patch(
            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
            return_value=mock_fragment,
        ) as mock_get_fragment,
    ):
        await brute_force_triplet_search(query="test query", node_name=["Node1"])

    assert mock_get_fragment.called
    call_kwargs = mock_get_fragment.call_args.kwargs
    # Require the keyword explicitly: ``.get()`` also returns None when the
    # argument was never passed, which would let this test pass vacuously.
    assert "relevant_ids_to_filter" in call_kwargs
    assert call_kwargs["relevant_ids_to_filter"] is None
785
+
786
+
787
@pytest.mark.asyncio
async def test_brute_force_triplet_search_collection_not_found_at_top_level():
    """CollectionNotFoundError from triplet-importance ranking yields an empty result."""
    embedder = AsyncMock()
    embedder.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])

    vector_engine = AsyncMock()
    vector_engine.embedding_engine = embedder
    vector_engine.search = AsyncMock(
        return_value=[MockScoredResult(id="node1", score=0.8, payload={"id": "node1"})]
    )

    fragment = AsyncMock()
    fragment.map_vector_distances_to_graph_nodes = AsyncMock()
    fragment.map_vector_distances_to_graph_edges = AsyncMock()
    fragment.calculate_top_triplet_importances = AsyncMock(
        side_effect=CollectionNotFoundError("Collection not found")
    )

    engine_patch = patch(
        "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
        return_value=vector_engine,
    )
    fragment_patch = patch(
        "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
        return_value=fragment,
    )

    with engine_patch, fragment_patch:
        triplets = await brute_force_triplet_search(query="test query")

    assert triplets == []
@@ -0,0 +1,343 @@
1
+ import pytest
2
+ from unittest.mock import AsyncMock, patch, MagicMock
3
+ from typing import Type
4
+
5
+
6
class TestGenerateCompletion:
    """Tests for ``generate_completion`` with prompt rendering and the LLM call mocked."""

    @pytest.mark.asyncio
    async def test_generate_completion_with_system_prompt(self):
        """An explicit system_prompt is forwarded to the LLM unchanged."""
        answer = "Generated answer"
        render_patch = patch(
            "cognee.modules.retrieval.utils.completion.render_prompt",
            return_value="User prompt text",
        )
        llm_patch = patch(
            "cognee.modules.retrieval.utils.completion.LLMGateway.acreate_structured_output",
            new_callable=AsyncMock,
            return_value=answer,
        )

        with render_patch, llm_patch as llm_call:
            from cognee.modules.retrieval.utils.completion import generate_completion

            completion = await generate_completion(
                query="What is AI?",
                context="AI is artificial intelligence",
                user_prompt_path="user_prompt.txt",
                system_prompt_path="system_prompt.txt",
                system_prompt="Custom system prompt",
            )

        assert completion == answer
        llm_call.assert_awaited_once_with(
            text_input="User prompt text",
            system_prompt="Custom system prompt",
            response_model=str,
        )

    @pytest.mark.asyncio
    async def test_generate_completion_without_system_prompt(self):
        """Without system_prompt, the system prompt text comes from read_query_prompt."""
        answer = "Generated answer"
        render_patch = patch(
            "cognee.modules.retrieval.utils.completion.render_prompt",
            return_value="User prompt text",
        )
        read_patch = patch(
            "cognee.modules.retrieval.utils.completion.read_query_prompt",
            return_value="System prompt from file",
        )
        llm_patch = patch(
            "cognee.modules.retrieval.utils.completion.LLMGateway.acreate_structured_output",
            new_callable=AsyncMock,
            return_value=answer,
        )

        with render_patch, read_patch, llm_patch as llm_call:
            from cognee.modules.retrieval.utils.completion import generate_completion

            completion = await generate_completion(
                query="What is AI?",
                context="AI is artificial intelligence",
                user_prompt_path="user_prompt.txt",
                system_prompt_path="system_prompt.txt",
            )

        assert completion == answer
        llm_call.assert_awaited_once_with(
            text_input="User prompt text",
            system_prompt="System prompt from file",
            response_model=str,
        )

    @pytest.mark.asyncio
    async def test_generate_completion_with_conversation_history(self):
        """conversation_history is prepended to the file prompt, joined by "\\nTASK:"."""
        answer = "Generated answer"
        history = "Previous conversation:\nQ: What is ML?\nA: ML is machine learning"
        render_patch = patch(
            "cognee.modules.retrieval.utils.completion.render_prompt",
            return_value="User prompt text",
        )
        read_patch = patch(
            "cognee.modules.retrieval.utils.completion.read_query_prompt",
            return_value="System prompt from file",
        )
        llm_patch = patch(
            "cognee.modules.retrieval.utils.completion.LLMGateway.acreate_structured_output",
            new_callable=AsyncMock,
            return_value=answer,
        )

        with render_patch, read_patch, llm_patch as llm_call:
            from cognee.modules.retrieval.utils.completion import generate_completion

            completion = await generate_completion(
                query="What is AI?",
                context="AI is artificial intelligence",
                user_prompt_path="user_prompt.txt",
                system_prompt_path="system_prompt.txt",
                conversation_history=history,
            )

        assert completion == answer
        llm_call.assert_awaited_once_with(
            text_input="User prompt text",
            system_prompt=history + "\nTASK:" + "System prompt from file",
            response_model=str,
        )

    @pytest.mark.asyncio
    async def test_generate_completion_with_conversation_history_and_custom_system_prompt(self):
        """conversation_history is prepended to an explicit system_prompt as well."""
        answer = "Generated answer"
        history = "Previous conversation:\nQ: What is ML?\nA: ML is machine learning"
        render_patch = patch(
            "cognee.modules.retrieval.utils.completion.render_prompt",
            return_value="User prompt text",
        )
        llm_patch = patch(
            "cognee.modules.retrieval.utils.completion.LLMGateway.acreate_structured_output",
            new_callable=AsyncMock,
            return_value=answer,
        )

        with render_patch, llm_patch as llm_call:
            from cognee.modules.retrieval.utils.completion import generate_completion

            completion = await generate_completion(
                query="What is AI?",
                context="AI is artificial intelligence",
                user_prompt_path="user_prompt.txt",
                system_prompt_path="system_prompt.txt",
                system_prompt="Custom system prompt",
                conversation_history=history,
            )

        assert completion == answer
        llm_call.assert_awaited_once_with(
            text_input="User prompt text",
            system_prompt=history + "\nTASK:" + "Custom system prompt",
            response_model=str,
        )

    @pytest.mark.asyncio
    async def test_generate_completion_with_response_model(self):
        """A caller-supplied response_model is passed through to the LLM call."""
        schema = MagicMock()
        answer = {"answer": "Generated answer"}
        render_patch = patch(
            "cognee.modules.retrieval.utils.completion.render_prompt",
            return_value="User prompt text",
        )
        read_patch = patch(
            "cognee.modules.retrieval.utils.completion.read_query_prompt",
            return_value="System prompt from file",
        )
        llm_patch = patch(
            "cognee.modules.retrieval.utils.completion.LLMGateway.acreate_structured_output",
            new_callable=AsyncMock,
            return_value=answer,
        )

        with render_patch, read_patch, llm_patch as llm_call:
            from cognee.modules.retrieval.utils.completion import generate_completion

            completion = await generate_completion(
                query="What is AI?",
                context="AI is artificial intelligence",
                user_prompt_path="user_prompt.txt",
                system_prompt_path="system_prompt.txt",
                response_model=schema,
            )

        assert completion == answer
        llm_call.assert_awaited_once_with(
            text_input="User prompt text",
            system_prompt="System prompt from file",
            response_model=schema,
        )

    @pytest.mark.asyncio
    async def test_generate_completion_render_prompt_args(self):
        """render_prompt receives the user prompt path and a question/context mapping."""
        render_patch = patch(
            "cognee.modules.retrieval.utils.completion.render_prompt",
            return_value="User prompt text",
        )
        read_patch = patch(
            "cognee.modules.retrieval.utils.completion.read_query_prompt",
            return_value="System prompt from file",
        )
        llm_patch = patch(
            "cognee.modules.retrieval.utils.completion.LLMGateway.acreate_structured_output",
            new_callable=AsyncMock,
            return_value="Generated answer",
        )

        with render_patch as render_call, read_patch, llm_patch:
            from cognee.modules.retrieval.utils.completion import generate_completion

            await generate_completion(
                query="What is AI?",
                context="AI is artificial intelligence",
                user_prompt_path="user_prompt.txt",
                system_prompt_path="system_prompt.txt",
            )

        render_call.assert_called_once_with(
            "user_prompt.txt",
            {"question": "What is AI?", "context": "AI is artificial intelligence"},
        )
228
+
229
+
230
class TestSummarizeText:
    """Tests for ``summarize_text`` with prompt loading and the LLM call mocked."""

    @pytest.mark.asyncio
    async def test_summarize_text_with_system_prompt(self):
        """An explicit system_prompt is forwarded to the LLM unchanged."""
        summary = "Summary text"
        llm_patch = patch(
            "cognee.modules.retrieval.utils.completion.LLMGateway.acreate_structured_output",
            new_callable=AsyncMock,
            return_value=summary,
        )

        with llm_patch as llm_call:
            from cognee.modules.retrieval.utils.completion import summarize_text

            produced = await summarize_text(
                text="Long text to summarize",
                system_prompt_path="summarize_search_results.txt",
                system_prompt="Custom summary prompt",
            )

        assert produced == summary
        llm_call.assert_awaited_once_with(
            text_input="Long text to summarize",
            system_prompt="Custom summary prompt",
            response_model=str,
        )

    @pytest.mark.asyncio
    async def test_summarize_text_without_system_prompt(self):
        """Without system_prompt, the prompt text comes from read_query_prompt."""
        summary = "Summary text"
        read_patch = patch(
            "cognee.modules.retrieval.utils.completion.read_query_prompt",
            return_value="System prompt from file",
        )
        llm_patch = patch(
            "cognee.modules.retrieval.utils.completion.LLMGateway.acreate_structured_output",
            new_callable=AsyncMock,
            return_value=summary,
        )

        with read_patch, llm_patch as llm_call:
            from cognee.modules.retrieval.utils.completion import summarize_text

            produced = await summarize_text(
                text="Long text to summarize",
                system_prompt_path="summarize_search_results.txt",
            )

        assert produced == summary
        llm_call.assert_awaited_once_with(
            text_input="Long text to summarize",
            system_prompt="System prompt from file",
            response_model=str,
        )

    @pytest.mark.asyncio
    async def test_summarize_text_default_prompt_path(self):
        """With no path given, the prompt is read from "summarize_search_results.txt"."""
        summary = "Summary text"
        read_patch = patch(
            "cognee.modules.retrieval.utils.completion.read_query_prompt",
            return_value="Default system prompt",
        )
        llm_patch = patch(
            "cognee.modules.retrieval.utils.completion.LLMGateway.acreate_structured_output",
            new_callable=AsyncMock,
            return_value=summary,
        )

        with read_patch as read_call, llm_patch as llm_call:
            from cognee.modules.retrieval.utils.completion import summarize_text

            produced = await summarize_text(text="Long text to summarize")

        assert produced == summary
        read_call.assert_called_once_with("summarize_search_results.txt")
        llm_call.assert_awaited_once_with(
            text_input="Long text to summarize",
            system_prompt="Default system prompt",
            response_model=str,
        )

    @pytest.mark.asyncio
    async def test_summarize_text_custom_prompt_path(self):
        """A caller-supplied system_prompt_path is the file read_query_prompt loads."""
        summary = "Summary text"
        read_patch = patch(
            "cognee.modules.retrieval.utils.completion.read_query_prompt",
            return_value="Custom system prompt",
        )
        llm_patch = patch(
            "cognee.modules.retrieval.utils.completion.LLMGateway.acreate_structured_output",
            new_callable=AsyncMock,
            return_value=summary,
        )

        with read_patch as read_call, llm_patch as llm_call:
            from cognee.modules.retrieval.utils.completion import summarize_text

            produced = await summarize_text(
                text="Long text to summarize",
                system_prompt_path="custom_prompt.txt",
            )

        assert produced == summary
        read_call.assert_called_once_with("custom_prompt.txt")
        llm_call.assert_awaited_once_with(
            text_input="Long text to summarize",
            system_prompt="Custom system prompt",
            response_model=str,
        )