cognee 0.5.1__py3-none-any.whl → 0.5.1.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cognee/api/v1/add/add.py +2 -1
- cognee/api/v1/datasets/routers/get_datasets_router.py +1 -0
- cognee/api/v1/memify/routers/get_memify_router.py +1 -0
- cognee/api/v1/search/search.py +0 -4
- cognee/infrastructure/databases/relational/config.py +16 -1
- cognee/infrastructure/databases/relational/create_relational_engine.py +13 -3
- cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py +24 -2
- cognee/infrastructure/databases/vector/create_vector_engine.py +9 -2
- cognee/infrastructure/llm/LLMGateway.py +0 -13
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py +17 -12
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py +31 -25
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py +132 -7
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py +5 -5
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/llm_interface.py +2 -6
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py +58 -13
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/adapter.py +0 -1
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py +25 -131
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/types.py +10 -0
- cognee/modules/data/models/Data.py +2 -1
- cognee/modules/retrieval/triplet_retriever.py +1 -1
- cognee/modules/retrieval/utils/brute_force_triplet_search.py +0 -18
- cognee/modules/search/methods/search.py +18 -25
- cognee/tasks/ingestion/data_item.py +8 -0
- cognee/tasks/ingestion/ingest_data.py +12 -1
- cognee/tasks/ingestion/save_data_item_to_storage.py +5 -0
- cognee/tests/integration/retrieval/test_chunks_retriever.py +252 -0
- cognee/tests/integration/retrieval/test_graph_completion_retriever.py +268 -0
- cognee/tests/integration/retrieval/test_graph_completion_retriever_context_extension.py +226 -0
- cognee/tests/integration/retrieval/test_graph_completion_retriever_cot.py +218 -0
- cognee/tests/integration/retrieval/test_rag_completion_retriever.py +254 -0
- cognee/tests/{unit/modules/retrieval/structured_output_test.py → integration/retrieval/test_structured_output.py} +87 -77
- cognee/tests/integration/retrieval/test_summaries_retriever.py +184 -0
- cognee/tests/integration/retrieval/test_temporal_retriever.py +306 -0
- cognee/tests/integration/retrieval/test_triplet_retriever.py +35 -0
- cognee/tests/test_custom_data_label.py +68 -0
- cognee/tests/test_search_db.py +334 -181
- cognee/tests/unit/eval_framework/benchmark_adapters_test.py +25 -0
- cognee/tests/unit/eval_framework/corpus_builder_test.py +33 -4
- cognee/tests/unit/infrastructure/databases/relational/test_RelationalConfig.py +69 -0
- cognee/tests/unit/modules/retrieval/chunks_retriever_test.py +181 -199
- cognee/tests/unit/modules/retrieval/conversation_history_test.py +338 -0
- cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py +454 -162
- cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py +674 -156
- cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py +625 -200
- cognee/tests/unit/modules/retrieval/rag_completion_retriever_test.py +319 -203
- cognee/tests/unit/modules/retrieval/summaries_retriever_test.py +189 -155
- cognee/tests/unit/modules/retrieval/temporal_retriever_test.py +539 -58
- cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py +218 -9
- cognee/tests/unit/modules/retrieval/test_completion.py +343 -0
- cognee/tests/unit/modules/retrieval/test_graph_summary_completion_retriever.py +157 -0
- cognee/tests/unit/modules/retrieval/test_user_qa_feedback.py +312 -0
- cognee/tests/unit/modules/retrieval/triplet_retriever_test.py +246 -0
- {cognee-0.5.1.dist-info → cognee-0.5.1.dev0.dist-info}/METADATA +1 -1
- {cognee-0.5.1.dist-info → cognee-0.5.1.dev0.dist-info}/RECORD +58 -45
- cognee/tests/unit/modules/search/test_search.py +0 -100
- {cognee-0.5.1.dist-info → cognee-0.5.1.dev0.dist-info}/WHEEL +0 -0
- {cognee-0.5.1.dist-info → cognee-0.5.1.dev0.dist-info}/entry_points.txt +0 -0
- {cognee-0.5.1.dist-info → cognee-0.5.1.dev0.dist-info}/licenses/LICENSE +0 -0
- {cognee-0.5.1.dist-info → cognee-0.5.1.dev0.dist-info}/licenses/NOTICE.md +0 -0
|
@@ -1,12 +1,14 @@
|
|
|
1
1
|
import pytest
|
|
2
|
-
from unittest.mock import AsyncMock, patch
|
|
2
|
+
from unittest.mock import AsyncMock, patch, MagicMock
|
|
3
3
|
|
|
4
4
|
from cognee.modules.retrieval.utils.brute_force_triplet_search import (
|
|
5
5
|
brute_force_triplet_search,
|
|
6
6
|
get_memory_fragment,
|
|
7
|
+
format_triplets,
|
|
7
8
|
)
|
|
8
9
|
from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph
|
|
9
10
|
from cognee.modules.graph.exceptions.exceptions import EntityNotFoundError
|
|
11
|
+
from cognee.infrastructure.databases.vector.exceptions.exceptions import CollectionNotFoundError
|
|
10
12
|
|
|
11
13
|
|
|
12
14
|
class MockScoredResult:
|
|
@@ -354,20 +356,30 @@ async def test_brute_force_triplet_search_passes_top_k_to_importance_calculation
|
|
|
354
356
|
|
|
355
357
|
@pytest.mark.asyncio
|
|
356
358
|
async def test_get_memory_fragment_returns_empty_graph_on_entity_not_found():
|
|
357
|
-
"""Test that get_memory_fragment returns empty graph when entity not found."""
|
|
359
|
+
"""Test that get_memory_fragment returns empty graph when entity not found (line 85)."""
|
|
358
360
|
mock_graph_engine = AsyncMock()
|
|
359
|
-
|
|
361
|
+
|
|
362
|
+
# Create a mock fragment that will raise EntityNotFoundError when project_graph_from_db is called
|
|
363
|
+
mock_fragment = MagicMock(spec=CogneeGraph)
|
|
364
|
+
mock_fragment.project_graph_from_db = AsyncMock(
|
|
360
365
|
side_effect=EntityNotFoundError("Entity not found")
|
|
361
366
|
)
|
|
362
367
|
|
|
363
|
-
with
|
|
364
|
-
|
|
365
|
-
|
|
368
|
+
with (
|
|
369
|
+
patch(
|
|
370
|
+
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_graph_engine",
|
|
371
|
+
return_value=mock_graph_engine,
|
|
372
|
+
),
|
|
373
|
+
patch(
|
|
374
|
+
"cognee.modules.retrieval.utils.brute_force_triplet_search.CogneeGraph",
|
|
375
|
+
return_value=mock_fragment,
|
|
376
|
+
),
|
|
366
377
|
):
|
|
367
|
-
|
|
378
|
+
result = await get_memory_fragment()
|
|
368
379
|
|
|
369
|
-
|
|
370
|
-
assert
|
|
380
|
+
# Fragment should be returned even though EntityNotFoundError was raised (pass statement on line 85)
|
|
381
|
+
assert result == mock_fragment
|
|
382
|
+
mock_fragment.project_graph_from_db.assert_awaited_once()
|
|
371
383
|
|
|
372
384
|
|
|
373
385
|
@pytest.mark.asyncio
|
|
@@ -606,3 +618,200 @@ async def test_brute_force_triplet_search_mixed_empty_collections():
|
|
|
606
618
|
|
|
607
619
|
call_kwargs = mock_get_fragment_fn.call_args[1]
|
|
608
620
|
assert set(call_kwargs["relevant_ids_to_filter"]) == {"node1", "node2"}
|
|
621
|
+
|
|
622
|
+
|
|
623
|
+
def test_format_triplets():
|
|
624
|
+
"""Test format_triplets function."""
|
|
625
|
+
mock_edge = MagicMock()
|
|
626
|
+
mock_node1 = MagicMock()
|
|
627
|
+
mock_node2 = MagicMock()
|
|
628
|
+
|
|
629
|
+
mock_node1.attributes = {"name": "Node1", "type": "Entity", "id": "n1"}
|
|
630
|
+
mock_node2.attributes = {"name": "Node2", "type": "Entity", "id": "n2"}
|
|
631
|
+
mock_edge.attributes = {"relationship_name": "relates_to", "edge_text": "connects"}
|
|
632
|
+
|
|
633
|
+
mock_edge.node1 = mock_node1
|
|
634
|
+
mock_edge.node2 = mock_node2
|
|
635
|
+
|
|
636
|
+
result = format_triplets([mock_edge])
|
|
637
|
+
|
|
638
|
+
assert isinstance(result, str)
|
|
639
|
+
assert "Node1" in result
|
|
640
|
+
assert "Node2" in result
|
|
641
|
+
assert "relates_to" in result
|
|
642
|
+
assert "connects" in result
|
|
643
|
+
|
|
644
|
+
|
|
645
|
+
def test_format_triplets_with_none_values():
|
|
646
|
+
"""Test format_triplets filters out None values."""
|
|
647
|
+
mock_edge = MagicMock()
|
|
648
|
+
mock_node1 = MagicMock()
|
|
649
|
+
mock_node2 = MagicMock()
|
|
650
|
+
|
|
651
|
+
mock_node1.attributes = {"name": "Node1", "type": None, "id": "n1"}
|
|
652
|
+
mock_node2.attributes = {"name": "Node2", "type": "Entity", "id": None}
|
|
653
|
+
mock_edge.attributes = {"relationship_name": "relates_to", "edge_text": None}
|
|
654
|
+
|
|
655
|
+
mock_edge.node1 = mock_node1
|
|
656
|
+
mock_edge.node2 = mock_node2
|
|
657
|
+
|
|
658
|
+
result = format_triplets([mock_edge])
|
|
659
|
+
|
|
660
|
+
assert "Node1" in result
|
|
661
|
+
assert "Node2" in result
|
|
662
|
+
assert "relates_to" in result
|
|
663
|
+
assert "None" not in result or result.count("None") == 0
|
|
664
|
+
|
|
665
|
+
|
|
666
|
+
def test_format_triplets_with_nested_dict():
|
|
667
|
+
"""Test format_triplets handles nested dict attributes (lines 23-35)."""
|
|
668
|
+
mock_edge = MagicMock()
|
|
669
|
+
mock_node1 = MagicMock()
|
|
670
|
+
mock_node2 = MagicMock()
|
|
671
|
+
|
|
672
|
+
mock_node1.attributes = {"name": "Node1", "metadata": {"type": "Entity", "id": "n1"}}
|
|
673
|
+
mock_node2.attributes = {"name": "Node2", "metadata": {"type": "Entity", "id": "n2"}}
|
|
674
|
+
mock_edge.attributes = {"relationship_name": "relates_to"}
|
|
675
|
+
|
|
676
|
+
mock_edge.node1 = mock_node1
|
|
677
|
+
mock_edge.node2 = mock_node2
|
|
678
|
+
|
|
679
|
+
result = format_triplets([mock_edge])
|
|
680
|
+
|
|
681
|
+
assert isinstance(result, str)
|
|
682
|
+
assert "Node1" in result
|
|
683
|
+
assert "Node2" in result
|
|
684
|
+
assert "relates_to" in result
|
|
685
|
+
|
|
686
|
+
|
|
687
|
+
@pytest.mark.asyncio
|
|
688
|
+
async def test_brute_force_triplet_search_vector_engine_init_error():
|
|
689
|
+
"""Test brute_force_triplet_search handles vector engine initialization error (lines 145-147)."""
|
|
690
|
+
with (
|
|
691
|
+
patch(
|
|
692
|
+
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine"
|
|
693
|
+
) as mock_get_vector_engine,
|
|
694
|
+
):
|
|
695
|
+
mock_get_vector_engine.side_effect = Exception("Initialization error")
|
|
696
|
+
|
|
697
|
+
with pytest.raises(RuntimeError, match="Initialization error"):
|
|
698
|
+
await brute_force_triplet_search(query="test query")
|
|
699
|
+
|
|
700
|
+
|
|
701
|
+
@pytest.mark.asyncio
|
|
702
|
+
async def test_brute_force_triplet_search_collection_not_found_error():
|
|
703
|
+
"""Test brute_force_triplet_search handles CollectionNotFoundError in search (lines 156-157)."""
|
|
704
|
+
mock_vector_engine = AsyncMock()
|
|
705
|
+
mock_embedding_engine = AsyncMock()
|
|
706
|
+
mock_vector_engine.embedding_engine = mock_embedding_engine
|
|
707
|
+
mock_embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
|
|
708
|
+
|
|
709
|
+
mock_vector_engine.search = AsyncMock(
|
|
710
|
+
side_effect=[
|
|
711
|
+
CollectionNotFoundError("Collection not found"),
|
|
712
|
+
[],
|
|
713
|
+
[],
|
|
714
|
+
]
|
|
715
|
+
)
|
|
716
|
+
|
|
717
|
+
with (
|
|
718
|
+
patch(
|
|
719
|
+
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
|
|
720
|
+
return_value=mock_vector_engine,
|
|
721
|
+
),
|
|
722
|
+
patch(
|
|
723
|
+
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
|
|
724
|
+
return_value=CogneeGraph(),
|
|
725
|
+
),
|
|
726
|
+
):
|
|
727
|
+
result = await brute_force_triplet_search(
|
|
728
|
+
query="test query", collections=["missing_collection", "existing_collection"]
|
|
729
|
+
)
|
|
730
|
+
|
|
731
|
+
assert result == []
|
|
732
|
+
|
|
733
|
+
|
|
734
|
+
@pytest.mark.asyncio
|
|
735
|
+
async def test_brute_force_triplet_search_generic_exception():
|
|
736
|
+
"""Test brute_force_triplet_search handles generic exceptions (lines 209-217)."""
|
|
737
|
+
mock_vector_engine = AsyncMock()
|
|
738
|
+
mock_embedding_engine = AsyncMock()
|
|
739
|
+
mock_vector_engine.embedding_engine = mock_embedding_engine
|
|
740
|
+
mock_embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
|
|
741
|
+
|
|
742
|
+
mock_vector_engine.search = AsyncMock(side_effect=Exception("Generic error"))
|
|
743
|
+
|
|
744
|
+
with (
|
|
745
|
+
patch(
|
|
746
|
+
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
|
|
747
|
+
return_value=mock_vector_engine,
|
|
748
|
+
),
|
|
749
|
+
):
|
|
750
|
+
with pytest.raises(Exception, match="Generic error"):
|
|
751
|
+
await brute_force_triplet_search(query="test query")
|
|
752
|
+
|
|
753
|
+
|
|
754
|
+
@pytest.mark.asyncio
|
|
755
|
+
async def test_brute_force_triplet_search_with_node_name_sets_relevant_ids_to_none():
|
|
756
|
+
"""Test brute_force_triplet_search sets relevant_ids_to_filter to None when node_name is provided (line 191)."""
|
|
757
|
+
mock_vector_engine = AsyncMock()
|
|
758
|
+
mock_embedding_engine = AsyncMock()
|
|
759
|
+
mock_vector_engine.embedding_engine = mock_embedding_engine
|
|
760
|
+
mock_embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
|
|
761
|
+
|
|
762
|
+
mock_result = MockScoredResult(id="node1", score=0.8, payload={"id": "node1"})
|
|
763
|
+
mock_vector_engine.search = AsyncMock(return_value=[mock_result])
|
|
764
|
+
|
|
765
|
+
mock_fragment = AsyncMock()
|
|
766
|
+
mock_fragment.map_vector_distances_to_graph_nodes = AsyncMock()
|
|
767
|
+
mock_fragment.map_vector_distances_to_graph_edges = AsyncMock()
|
|
768
|
+
mock_fragment.calculate_top_triplet_importances = AsyncMock(return_value=[])
|
|
769
|
+
|
|
770
|
+
with (
|
|
771
|
+
patch(
|
|
772
|
+
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
|
|
773
|
+
return_value=mock_vector_engine,
|
|
774
|
+
),
|
|
775
|
+
patch(
|
|
776
|
+
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
|
|
777
|
+
return_value=mock_fragment,
|
|
778
|
+
) as mock_get_fragment,
|
|
779
|
+
):
|
|
780
|
+
await brute_force_triplet_search(query="test query", node_name=["Node1"])
|
|
781
|
+
|
|
782
|
+
assert mock_get_fragment.called
|
|
783
|
+
call_kwargs = mock_get_fragment.call_args.kwargs if mock_get_fragment.call_args else {}
|
|
784
|
+
assert call_kwargs.get("relevant_ids_to_filter") is None
|
|
785
|
+
|
|
786
|
+
|
|
787
|
+
@pytest.mark.asyncio
|
|
788
|
+
async def test_brute_force_triplet_search_collection_not_found_at_top_level():
|
|
789
|
+
"""Test brute_force_triplet_search handles CollectionNotFoundError at top level (line 210)."""
|
|
790
|
+
mock_vector_engine = AsyncMock()
|
|
791
|
+
mock_embedding_engine = AsyncMock()
|
|
792
|
+
mock_vector_engine.embedding_engine = mock_embedding_engine
|
|
793
|
+
mock_embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
|
|
794
|
+
|
|
795
|
+
mock_result = MockScoredResult(id="node1", score=0.8, payload={"id": "node1"})
|
|
796
|
+
mock_vector_engine.search = AsyncMock(return_value=[mock_result])
|
|
797
|
+
|
|
798
|
+
mock_fragment = AsyncMock()
|
|
799
|
+
mock_fragment.map_vector_distances_to_graph_nodes = AsyncMock()
|
|
800
|
+
mock_fragment.map_vector_distances_to_graph_edges = AsyncMock()
|
|
801
|
+
mock_fragment.calculate_top_triplet_importances = AsyncMock(
|
|
802
|
+
side_effect=CollectionNotFoundError("Collection not found")
|
|
803
|
+
)
|
|
804
|
+
|
|
805
|
+
with (
|
|
806
|
+
patch(
|
|
807
|
+
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
|
|
808
|
+
return_value=mock_vector_engine,
|
|
809
|
+
),
|
|
810
|
+
patch(
|
|
811
|
+
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
|
|
812
|
+
return_value=mock_fragment,
|
|
813
|
+
),
|
|
814
|
+
):
|
|
815
|
+
result = await brute_force_triplet_search(query="test query")
|
|
816
|
+
|
|
817
|
+
assert result == []
|
|
@@ -0,0 +1,343 @@
|
|
|
1
|
+
import pytest
|
|
2
|
+
from unittest.mock import AsyncMock, patch, MagicMock
|
|
3
|
+
from typing import Type
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class TestGenerateCompletion:
|
|
7
|
+
@pytest.mark.asyncio
|
|
8
|
+
async def test_generate_completion_with_system_prompt(self):
|
|
9
|
+
"""Test generate_completion with provided system_prompt."""
|
|
10
|
+
mock_llm_response = "Generated answer"
|
|
11
|
+
|
|
12
|
+
with (
|
|
13
|
+
patch(
|
|
14
|
+
"cognee.modules.retrieval.utils.completion.render_prompt",
|
|
15
|
+
return_value="User prompt text",
|
|
16
|
+
),
|
|
17
|
+
patch(
|
|
18
|
+
"cognee.modules.retrieval.utils.completion.LLMGateway.acreate_structured_output",
|
|
19
|
+
new_callable=AsyncMock,
|
|
20
|
+
return_value=mock_llm_response,
|
|
21
|
+
) as mock_llm,
|
|
22
|
+
):
|
|
23
|
+
from cognee.modules.retrieval.utils.completion import generate_completion
|
|
24
|
+
|
|
25
|
+
result = await generate_completion(
|
|
26
|
+
query="What is AI?",
|
|
27
|
+
context="AI is artificial intelligence",
|
|
28
|
+
user_prompt_path="user_prompt.txt",
|
|
29
|
+
system_prompt_path="system_prompt.txt",
|
|
30
|
+
system_prompt="Custom system prompt",
|
|
31
|
+
)
|
|
32
|
+
|
|
33
|
+
assert result == mock_llm_response
|
|
34
|
+
mock_llm.assert_awaited_once_with(
|
|
35
|
+
text_input="User prompt text",
|
|
36
|
+
system_prompt="Custom system prompt",
|
|
37
|
+
response_model=str,
|
|
38
|
+
)
|
|
39
|
+
|
|
40
|
+
@pytest.mark.asyncio
|
|
41
|
+
async def test_generate_completion_without_system_prompt(self):
|
|
42
|
+
"""Test generate_completion reads system_prompt from file when not provided."""
|
|
43
|
+
mock_llm_response = "Generated answer"
|
|
44
|
+
|
|
45
|
+
with (
|
|
46
|
+
patch(
|
|
47
|
+
"cognee.modules.retrieval.utils.completion.render_prompt",
|
|
48
|
+
return_value="User prompt text",
|
|
49
|
+
),
|
|
50
|
+
patch(
|
|
51
|
+
"cognee.modules.retrieval.utils.completion.read_query_prompt",
|
|
52
|
+
return_value="System prompt from file",
|
|
53
|
+
),
|
|
54
|
+
patch(
|
|
55
|
+
"cognee.modules.retrieval.utils.completion.LLMGateway.acreate_structured_output",
|
|
56
|
+
new_callable=AsyncMock,
|
|
57
|
+
return_value=mock_llm_response,
|
|
58
|
+
) as mock_llm,
|
|
59
|
+
):
|
|
60
|
+
from cognee.modules.retrieval.utils.completion import generate_completion
|
|
61
|
+
|
|
62
|
+
result = await generate_completion(
|
|
63
|
+
query="What is AI?",
|
|
64
|
+
context="AI is artificial intelligence",
|
|
65
|
+
user_prompt_path="user_prompt.txt",
|
|
66
|
+
system_prompt_path="system_prompt.txt",
|
|
67
|
+
)
|
|
68
|
+
|
|
69
|
+
assert result == mock_llm_response
|
|
70
|
+
mock_llm.assert_awaited_once_with(
|
|
71
|
+
text_input="User prompt text",
|
|
72
|
+
system_prompt="System prompt from file",
|
|
73
|
+
response_model=str,
|
|
74
|
+
)
|
|
75
|
+
|
|
76
|
+
@pytest.mark.asyncio
|
|
77
|
+
async def test_generate_completion_with_conversation_history(self):
|
|
78
|
+
"""Test generate_completion includes conversation_history in system_prompt."""
|
|
79
|
+
mock_llm_response = "Generated answer"
|
|
80
|
+
|
|
81
|
+
with (
|
|
82
|
+
patch(
|
|
83
|
+
"cognee.modules.retrieval.utils.completion.render_prompt",
|
|
84
|
+
return_value="User prompt text",
|
|
85
|
+
),
|
|
86
|
+
patch(
|
|
87
|
+
"cognee.modules.retrieval.utils.completion.read_query_prompt",
|
|
88
|
+
return_value="System prompt from file",
|
|
89
|
+
),
|
|
90
|
+
patch(
|
|
91
|
+
"cognee.modules.retrieval.utils.completion.LLMGateway.acreate_structured_output",
|
|
92
|
+
new_callable=AsyncMock,
|
|
93
|
+
return_value=mock_llm_response,
|
|
94
|
+
) as mock_llm,
|
|
95
|
+
):
|
|
96
|
+
from cognee.modules.retrieval.utils.completion import generate_completion
|
|
97
|
+
|
|
98
|
+
result = await generate_completion(
|
|
99
|
+
query="What is AI?",
|
|
100
|
+
context="AI is artificial intelligence",
|
|
101
|
+
user_prompt_path="user_prompt.txt",
|
|
102
|
+
system_prompt_path="system_prompt.txt",
|
|
103
|
+
conversation_history="Previous conversation:\nQ: What is ML?\nA: ML is machine learning",
|
|
104
|
+
)
|
|
105
|
+
|
|
106
|
+
assert result == mock_llm_response
|
|
107
|
+
expected_system_prompt = (
|
|
108
|
+
"Previous conversation:\nQ: What is ML?\nA: ML is machine learning"
|
|
109
|
+
+ "\nTASK:"
|
|
110
|
+
+ "System prompt from file"
|
|
111
|
+
)
|
|
112
|
+
mock_llm.assert_awaited_once_with(
|
|
113
|
+
text_input="User prompt text",
|
|
114
|
+
system_prompt=expected_system_prompt,
|
|
115
|
+
response_model=str,
|
|
116
|
+
)
|
|
117
|
+
|
|
118
|
+
@pytest.mark.asyncio
|
|
119
|
+
async def test_generate_completion_with_conversation_history_and_custom_system_prompt(self):
|
|
120
|
+
"""Test generate_completion includes conversation_history with custom system_prompt."""
|
|
121
|
+
mock_llm_response = "Generated answer"
|
|
122
|
+
|
|
123
|
+
with (
|
|
124
|
+
patch(
|
|
125
|
+
"cognee.modules.retrieval.utils.completion.render_prompt",
|
|
126
|
+
return_value="User prompt text",
|
|
127
|
+
),
|
|
128
|
+
patch(
|
|
129
|
+
"cognee.modules.retrieval.utils.completion.LLMGateway.acreate_structured_output",
|
|
130
|
+
new_callable=AsyncMock,
|
|
131
|
+
return_value=mock_llm_response,
|
|
132
|
+
) as mock_llm,
|
|
133
|
+
):
|
|
134
|
+
from cognee.modules.retrieval.utils.completion import generate_completion
|
|
135
|
+
|
|
136
|
+
result = await generate_completion(
|
|
137
|
+
query="What is AI?",
|
|
138
|
+
context="AI is artificial intelligence",
|
|
139
|
+
user_prompt_path="user_prompt.txt",
|
|
140
|
+
system_prompt_path="system_prompt.txt",
|
|
141
|
+
system_prompt="Custom system prompt",
|
|
142
|
+
conversation_history="Previous conversation:\nQ: What is ML?\nA: ML is machine learning",
|
|
143
|
+
)
|
|
144
|
+
|
|
145
|
+
assert result == mock_llm_response
|
|
146
|
+
expected_system_prompt = (
|
|
147
|
+
"Previous conversation:\nQ: What is ML?\nA: ML is machine learning"
|
|
148
|
+
+ "\nTASK:"
|
|
149
|
+
+ "Custom system prompt"
|
|
150
|
+
)
|
|
151
|
+
mock_llm.assert_awaited_once_with(
|
|
152
|
+
text_input="User prompt text",
|
|
153
|
+
system_prompt=expected_system_prompt,
|
|
154
|
+
response_model=str,
|
|
155
|
+
)
|
|
156
|
+
|
|
157
|
+
@pytest.mark.asyncio
|
|
158
|
+
async def test_generate_completion_with_response_model(self):
|
|
159
|
+
"""Test generate_completion with custom response_model."""
|
|
160
|
+
mock_response_model = MagicMock()
|
|
161
|
+
mock_llm_response = {"answer": "Generated answer"}
|
|
162
|
+
|
|
163
|
+
with (
|
|
164
|
+
patch(
|
|
165
|
+
"cognee.modules.retrieval.utils.completion.render_prompt",
|
|
166
|
+
return_value="User prompt text",
|
|
167
|
+
),
|
|
168
|
+
patch(
|
|
169
|
+
"cognee.modules.retrieval.utils.completion.read_query_prompt",
|
|
170
|
+
return_value="System prompt from file",
|
|
171
|
+
),
|
|
172
|
+
patch(
|
|
173
|
+
"cognee.modules.retrieval.utils.completion.LLMGateway.acreate_structured_output",
|
|
174
|
+
new_callable=AsyncMock,
|
|
175
|
+
return_value=mock_llm_response,
|
|
176
|
+
) as mock_llm,
|
|
177
|
+
):
|
|
178
|
+
from cognee.modules.retrieval.utils.completion import generate_completion
|
|
179
|
+
|
|
180
|
+
result = await generate_completion(
|
|
181
|
+
query="What is AI?",
|
|
182
|
+
context="AI is artificial intelligence",
|
|
183
|
+
user_prompt_path="user_prompt.txt",
|
|
184
|
+
system_prompt_path="system_prompt.txt",
|
|
185
|
+
response_model=mock_response_model,
|
|
186
|
+
)
|
|
187
|
+
|
|
188
|
+
assert result == mock_llm_response
|
|
189
|
+
mock_llm.assert_awaited_once_with(
|
|
190
|
+
text_input="User prompt text",
|
|
191
|
+
system_prompt="System prompt from file",
|
|
192
|
+
response_model=mock_response_model,
|
|
193
|
+
)
|
|
194
|
+
|
|
195
|
+
@pytest.mark.asyncio
|
|
196
|
+
async def test_generate_completion_render_prompt_args(self):
|
|
197
|
+
"""Test generate_completion passes correct args to render_prompt."""
|
|
198
|
+
mock_llm_response = "Generated answer"
|
|
199
|
+
|
|
200
|
+
with (
|
|
201
|
+
patch(
|
|
202
|
+
"cognee.modules.retrieval.utils.completion.render_prompt",
|
|
203
|
+
return_value="User prompt text",
|
|
204
|
+
) as mock_render,
|
|
205
|
+
patch(
|
|
206
|
+
"cognee.modules.retrieval.utils.completion.read_query_prompt",
|
|
207
|
+
return_value="System prompt from file",
|
|
208
|
+
),
|
|
209
|
+
patch(
|
|
210
|
+
"cognee.modules.retrieval.utils.completion.LLMGateway.acreate_structured_output",
|
|
211
|
+
new_callable=AsyncMock,
|
|
212
|
+
return_value=mock_llm_response,
|
|
213
|
+
),
|
|
214
|
+
):
|
|
215
|
+
from cognee.modules.retrieval.utils.completion import generate_completion
|
|
216
|
+
|
|
217
|
+
await generate_completion(
|
|
218
|
+
query="What is AI?",
|
|
219
|
+
context="AI is artificial intelligence",
|
|
220
|
+
user_prompt_path="user_prompt.txt",
|
|
221
|
+
system_prompt_path="system_prompt.txt",
|
|
222
|
+
)
|
|
223
|
+
|
|
224
|
+
mock_render.assert_called_once_with(
|
|
225
|
+
"user_prompt.txt",
|
|
226
|
+
{"question": "What is AI?", "context": "AI is artificial intelligence"},
|
|
227
|
+
)
|
|
228
|
+
|
|
229
|
+
|
|
230
|
+
class TestSummarizeText:
|
|
231
|
+
@pytest.mark.asyncio
|
|
232
|
+
async def test_summarize_text_with_system_prompt(self):
|
|
233
|
+
"""Test summarize_text with provided system_prompt."""
|
|
234
|
+
mock_llm_response = "Summary text"
|
|
235
|
+
|
|
236
|
+
with patch(
|
|
237
|
+
"cognee.modules.retrieval.utils.completion.LLMGateway.acreate_structured_output",
|
|
238
|
+
new_callable=AsyncMock,
|
|
239
|
+
return_value=mock_llm_response,
|
|
240
|
+
) as mock_llm:
|
|
241
|
+
from cognee.modules.retrieval.utils.completion import summarize_text
|
|
242
|
+
|
|
243
|
+
result = await summarize_text(
|
|
244
|
+
text="Long text to summarize",
|
|
245
|
+
system_prompt_path="summarize_search_results.txt",
|
|
246
|
+
system_prompt="Custom summary prompt",
|
|
247
|
+
)
|
|
248
|
+
|
|
249
|
+
assert result == mock_llm_response
|
|
250
|
+
mock_llm.assert_awaited_once_with(
|
|
251
|
+
text_input="Long text to summarize",
|
|
252
|
+
system_prompt="Custom summary prompt",
|
|
253
|
+
response_model=str,
|
|
254
|
+
)
|
|
255
|
+
|
|
256
|
+
@pytest.mark.asyncio
|
|
257
|
+
async def test_summarize_text_without_system_prompt(self):
|
|
258
|
+
"""Test summarize_text reads system_prompt from file when not provided."""
|
|
259
|
+
mock_llm_response = "Summary text"
|
|
260
|
+
|
|
261
|
+
with (
|
|
262
|
+
patch(
|
|
263
|
+
"cognee.modules.retrieval.utils.completion.read_query_prompt",
|
|
264
|
+
return_value="System prompt from file",
|
|
265
|
+
),
|
|
266
|
+
patch(
|
|
267
|
+
"cognee.modules.retrieval.utils.completion.LLMGateway.acreate_structured_output",
|
|
268
|
+
new_callable=AsyncMock,
|
|
269
|
+
return_value=mock_llm_response,
|
|
270
|
+
) as mock_llm,
|
|
271
|
+
):
|
|
272
|
+
from cognee.modules.retrieval.utils.completion import summarize_text
|
|
273
|
+
|
|
274
|
+
result = await summarize_text(
|
|
275
|
+
text="Long text to summarize",
|
|
276
|
+
system_prompt_path="summarize_search_results.txt",
|
|
277
|
+
)
|
|
278
|
+
|
|
279
|
+
assert result == mock_llm_response
|
|
280
|
+
mock_llm.assert_awaited_once_with(
|
|
281
|
+
text_input="Long text to summarize",
|
|
282
|
+
system_prompt="System prompt from file",
|
|
283
|
+
response_model=str,
|
|
284
|
+
)
|
|
285
|
+
|
|
286
|
+
@pytest.mark.asyncio
|
|
287
|
+
async def test_summarize_text_default_prompt_path(self):
|
|
288
|
+
"""Test summarize_text uses default prompt path when not provided."""
|
|
289
|
+
mock_llm_response = "Summary text"
|
|
290
|
+
|
|
291
|
+
with (
|
|
292
|
+
patch(
|
|
293
|
+
"cognee.modules.retrieval.utils.completion.read_query_prompt",
|
|
294
|
+
return_value="Default system prompt",
|
|
295
|
+
) as mock_read,
|
|
296
|
+
patch(
|
|
297
|
+
"cognee.modules.retrieval.utils.completion.LLMGateway.acreate_structured_output",
|
|
298
|
+
new_callable=AsyncMock,
|
|
299
|
+
return_value=mock_llm_response,
|
|
300
|
+
) as mock_llm,
|
|
301
|
+
):
|
|
302
|
+
from cognee.modules.retrieval.utils.completion import summarize_text
|
|
303
|
+
|
|
304
|
+
result = await summarize_text(text="Long text to summarize")
|
|
305
|
+
|
|
306
|
+
assert result == mock_llm_response
|
|
307
|
+
mock_read.assert_called_once_with("summarize_search_results.txt")
|
|
308
|
+
mock_llm.assert_awaited_once_with(
|
|
309
|
+
text_input="Long text to summarize",
|
|
310
|
+
system_prompt="Default system prompt",
|
|
311
|
+
response_model=str,
|
|
312
|
+
)
|
|
313
|
+
|
|
314
|
+
@pytest.mark.asyncio
|
|
315
|
+
async def test_summarize_text_custom_prompt_path(self):
|
|
316
|
+
"""Test summarize_text uses custom prompt path when provided."""
|
|
317
|
+
mock_llm_response = "Summary text"
|
|
318
|
+
|
|
319
|
+
with (
|
|
320
|
+
patch(
|
|
321
|
+
"cognee.modules.retrieval.utils.completion.read_query_prompt",
|
|
322
|
+
return_value="Custom system prompt",
|
|
323
|
+
) as mock_read,
|
|
324
|
+
patch(
|
|
325
|
+
"cognee.modules.retrieval.utils.completion.LLMGateway.acreate_structured_output",
|
|
326
|
+
new_callable=AsyncMock,
|
|
327
|
+
return_value=mock_llm_response,
|
|
328
|
+
) as mock_llm,
|
|
329
|
+
):
|
|
330
|
+
from cognee.modules.retrieval.utils.completion import summarize_text
|
|
331
|
+
|
|
332
|
+
result = await summarize_text(
|
|
333
|
+
text="Long text to summarize",
|
|
334
|
+
system_prompt_path="custom_prompt.txt",
|
|
335
|
+
)
|
|
336
|
+
|
|
337
|
+
assert result == mock_llm_response
|
|
338
|
+
mock_read.assert_called_once_with("custom_prompt.txt")
|
|
339
|
+
mock_llm.assert_awaited_once_with(
|
|
340
|
+
text_input="Long text to summarize",
|
|
341
|
+
system_prompt="Custom system prompt",
|
|
342
|
+
response_model=str,
|
|
343
|
+
)
|