cognee 0.4.0__py3-none-any.whl → 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cognee/__init__.py +1 -0
- cognee/api/client.py +9 -5
- cognee/api/v1/add/add.py +2 -1
- cognee/api/v1/add/routers/get_add_router.py +3 -1
- cognee/api/v1/cognify/cognify.py +24 -16
- cognee/api/v1/cognify/routers/__init__.py +0 -1
- cognee/api/v1/cognify/routers/get_cognify_router.py +30 -1
- cognee/api/v1/datasets/routers/get_datasets_router.py +3 -3
- cognee/api/v1/ontologies/__init__.py +4 -0
- cognee/api/v1/ontologies/ontologies.py +158 -0
- cognee/api/v1/ontologies/routers/__init__.py +0 -0
- cognee/api/v1/ontologies/routers/get_ontology_router.py +109 -0
- cognee/api/v1/permissions/routers/get_permissions_router.py +41 -1
- cognee/api/v1/search/search.py +4 -0
- cognee/api/v1/ui/node_setup.py +360 -0
- cognee/api/v1/ui/npm_utils.py +50 -0
- cognee/api/v1/ui/ui.py +38 -68
- cognee/cli/commands/cognify_command.py +8 -1
- cognee/cli/config.py +1 -1
- cognee/context_global_variables.py +86 -9
- cognee/eval_framework/Dockerfile +29 -0
- cognee/eval_framework/answer_generation/answer_generation_executor.py +10 -0
- cognee/eval_framework/answer_generation/run_question_answering_module.py +1 -1
- cognee/eval_framework/corpus_builder/task_getters/get_cascade_graph_tasks.py +0 -2
- cognee/eval_framework/corpus_builder/task_getters/get_default_tasks_by_indices.py +4 -4
- cognee/eval_framework/eval_config.py +2 -2
- cognee/eval_framework/modal_run_eval.py +16 -28
- cognee/infrastructure/databases/cache/config.py +3 -1
- cognee/infrastructure/databases/cache/fscache/FsCacheAdapter.py +151 -0
- cognee/infrastructure/databases/cache/get_cache_engine.py +20 -10
- cognee/infrastructure/databases/dataset_database_handler/__init__.py +3 -0
- cognee/infrastructure/databases/dataset_database_handler/dataset_database_handler_interface.py +80 -0
- cognee/infrastructure/databases/dataset_database_handler/supported_dataset_database_handlers.py +18 -0
- cognee/infrastructure/databases/dataset_database_handler/use_dataset_database_handler.py +10 -0
- cognee/infrastructure/databases/exceptions/exceptions.py +16 -0
- cognee/infrastructure/databases/graph/config.py +7 -0
- cognee/infrastructure/databases/graph/get_graph_engine.py +3 -0
- cognee/infrastructure/databases/graph/graph_db_interface.py +15 -0
- cognee/infrastructure/databases/graph/kuzu/KuzuDatasetDatabaseHandler.py +81 -0
- cognee/infrastructure/databases/graph/kuzu/adapter.py +228 -0
- cognee/infrastructure/databases/graph/neo4j_driver/Neo4jAuraDevDatasetDatabaseHandler.py +168 -0
- cognee/infrastructure/databases/graph/neo4j_driver/adapter.py +80 -1
- cognee/infrastructure/databases/hybrid/neptune_analytics/NeptuneAnalyticsAdapter.py +9 -0
- cognee/infrastructure/databases/utils/__init__.py +3 -0
- cognee/infrastructure/databases/utils/get_graph_dataset_database_handler.py +10 -0
- cognee/infrastructure/databases/utils/get_or_create_dataset_database.py +66 -18
- cognee/infrastructure/databases/utils/get_vector_dataset_database_handler.py +10 -0
- cognee/infrastructure/databases/utils/resolve_dataset_database_connection_info.py +30 -0
- cognee/infrastructure/databases/vector/config.py +5 -0
- cognee/infrastructure/databases/vector/create_vector_engine.py +6 -1
- cognee/infrastructure/databases/vector/embeddings/FastembedEmbeddingEngine.py +8 -6
- cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py +9 -7
- cognee/infrastructure/databases/vector/embeddings/OllamaEmbeddingEngine.py +11 -10
- cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py +2 -0
- cognee/infrastructure/databases/vector/lancedb/LanceDBDatasetDatabaseHandler.py +50 -0
- cognee/infrastructure/databases/vector/vector_db_interface.py +35 -0
- cognee/infrastructure/engine/models/Edge.py +13 -1
- cognee/infrastructure/files/storage/s3_config.py +2 -0
- cognee/infrastructure/files/utils/guess_file_type.py +4 -0
- cognee/infrastructure/llm/LLMGateway.py +5 -2
- cognee/infrastructure/llm/config.py +37 -0
- cognee/infrastructure/llm/extraction/knowledge_graph/extract_content_graph.py +2 -2
- cognee/infrastructure/llm/structured_output_framework/baml/baml_src/extraction/acreate_structured_output.py +23 -8
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py +22 -18
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/bedrock/__init__.py +5 -0
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/bedrock/adapter.py +153 -0
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py +47 -38
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py +46 -37
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py +20 -10
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py +23 -11
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/adapter.py +36 -23
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py +47 -36
- cognee/infrastructure/loaders/LoaderEngine.py +1 -0
- cognee/infrastructure/loaders/core/__init__.py +2 -1
- cognee/infrastructure/loaders/core/csv_loader.py +93 -0
- cognee/infrastructure/loaders/core/text_loader.py +1 -2
- cognee/infrastructure/loaders/external/advanced_pdf_loader.py +0 -9
- cognee/infrastructure/loaders/supported_loaders.py +2 -1
- cognee/memify_pipelines/create_triplet_embeddings.py +53 -0
- cognee/memify_pipelines/persist_sessions_in_knowledge_graph.py +55 -0
- cognee/modules/chunking/CsvChunker.py +35 -0
- cognee/modules/chunking/models/DocumentChunk.py +2 -1
- cognee/modules/chunking/text_chunker_with_overlap.py +124 -0
- cognee/modules/cognify/config.py +2 -0
- cognee/modules/data/deletion/prune_system.py +52 -2
- cognee/modules/data/methods/__init__.py +1 -0
- cognee/modules/data/methods/create_dataset.py +4 -2
- cognee/modules/data/methods/delete_dataset.py +26 -0
- cognee/modules/data/methods/get_dataset_ids.py +5 -1
- cognee/modules/data/methods/get_unique_data_id.py +68 -0
- cognee/modules/data/methods/get_unique_dataset_id.py +66 -4
- cognee/modules/data/models/Dataset.py +2 -0
- cognee/modules/data/processing/document_types/CsvDocument.py +33 -0
- cognee/modules/data/processing/document_types/__init__.py +1 -0
- cognee/modules/engine/models/Triplet.py +9 -0
- cognee/modules/engine/models/__init__.py +1 -0
- cognee/modules/graph/cognee_graph/CogneeGraph.py +89 -39
- cognee/modules/graph/cognee_graph/CogneeGraphElements.py +8 -3
- cognee/modules/graph/utils/expand_with_nodes_and_edges.py +19 -2
- cognee/modules/graph/utils/resolve_edges_to_text.py +48 -49
- cognee/modules/ingestion/identify.py +4 -4
- cognee/modules/memify/memify.py +1 -7
- cognee/modules/notebooks/operations/run_in_local_sandbox.py +3 -0
- cognee/modules/ontology/rdf_xml/RDFLibOntologyResolver.py +55 -23
- cognee/modules/pipelines/operations/pipeline.py +18 -2
- cognee/modules/pipelines/operations/run_tasks_data_item.py +1 -1
- cognee/modules/retrieval/EntityCompletionRetriever.py +10 -3
- cognee/modules/retrieval/__init__.py +1 -1
- cognee/modules/retrieval/base_graph_retriever.py +7 -3
- cognee/modules/retrieval/base_retriever.py +7 -3
- cognee/modules/retrieval/completion_retriever.py +11 -4
- cognee/modules/retrieval/graph_completion_context_extension_retriever.py +10 -2
- cognee/modules/retrieval/graph_completion_cot_retriever.py +18 -51
- cognee/modules/retrieval/graph_completion_retriever.py +14 -1
- cognee/modules/retrieval/graph_summary_completion_retriever.py +4 -0
- cognee/modules/retrieval/register_retriever.py +10 -0
- cognee/modules/retrieval/registered_community_retrievers.py +1 -0
- cognee/modules/retrieval/temporal_retriever.py +13 -2
- cognee/modules/retrieval/triplet_retriever.py +182 -0
- cognee/modules/retrieval/utils/brute_force_triplet_search.py +43 -11
- cognee/modules/retrieval/utils/completion.py +2 -22
- cognee/modules/run_custom_pipeline/__init__.py +1 -0
- cognee/modules/run_custom_pipeline/run_custom_pipeline.py +76 -0
- cognee/modules/search/methods/get_search_type_tools.py +54 -8
- cognee/modules/search/methods/no_access_control_search.py +4 -0
- cognee/modules/search/methods/search.py +26 -3
- cognee/modules/search/types/SearchType.py +1 -1
- cognee/modules/settings/get_settings.py +19 -0
- cognee/modules/users/methods/create_user.py +12 -27
- cognee/modules/users/methods/get_authenticated_user.py +3 -2
- cognee/modules/users/methods/get_default_user.py +4 -2
- cognee/modules/users/methods/get_user.py +1 -1
- cognee/modules/users/methods/get_user_by_email.py +1 -1
- cognee/modules/users/models/DatasetDatabase.py +24 -3
- cognee/modules/users/models/Tenant.py +6 -7
- cognee/modules/users/models/User.py +6 -5
- cognee/modules/users/models/UserTenant.py +12 -0
- cognee/modules/users/models/__init__.py +1 -0
- cognee/modules/users/permissions/methods/get_all_user_permission_datasets.py +13 -13
- cognee/modules/users/roles/methods/add_user_to_role.py +3 -1
- cognee/modules/users/tenants/methods/__init__.py +1 -0
- cognee/modules/users/tenants/methods/add_user_to_tenant.py +21 -12
- cognee/modules/users/tenants/methods/create_tenant.py +22 -8
- cognee/modules/users/tenants/methods/select_tenant.py +62 -0
- cognee/shared/logging_utils.py +6 -0
- cognee/shared/rate_limiting.py +30 -0
- cognee/tasks/chunks/__init__.py +1 -0
- cognee/tasks/chunks/chunk_by_row.py +94 -0
- cognee/tasks/documents/__init__.py +0 -1
- cognee/tasks/documents/classify_documents.py +2 -0
- cognee/tasks/feedback/generate_improved_answers.py +3 -3
- cognee/tasks/graph/extract_graph_from_data.py +9 -10
- cognee/tasks/ingestion/ingest_data.py +1 -1
- cognee/tasks/memify/__init__.py +2 -0
- cognee/tasks/memify/cognify_session.py +41 -0
- cognee/tasks/memify/extract_user_sessions.py +73 -0
- cognee/tasks/memify/get_triplet_datapoints.py +289 -0
- cognee/tasks/storage/add_data_points.py +142 -2
- cognee/tasks/storage/index_data_points.py +33 -22
- cognee/tasks/storage/index_graph_edges.py +37 -57
- cognee/tests/integration/documents/CsvDocument_test.py +70 -0
- cognee/tests/integration/retrieval/test_triplet_retriever.py +84 -0
- cognee/tests/integration/tasks/test_add_data_points.py +139 -0
- cognee/tests/integration/tasks/test_get_triplet_datapoints.py +69 -0
- cognee/tests/integration/web_url_crawler/test_default_url_crawler.py +1 -1
- cognee/tests/integration/web_url_crawler/test_tavily_crawler.py +1 -1
- cognee/tests/integration/web_url_crawler/test_url_adding_e2e.py +13 -27
- cognee/tests/tasks/entity_extraction/entity_extraction_test.py +1 -1
- cognee/tests/test_add_docling_document.py +2 -2
- cognee/tests/test_cognee_server_start.py +84 -3
- cognee/tests/test_conversation_history.py +68 -5
- cognee/tests/test_data/example_with_header.csv +3 -0
- cognee/tests/test_dataset_database_handler.py +137 -0
- cognee/tests/test_dataset_delete.py +76 -0
- cognee/tests/test_edge_centered_payload.py +170 -0
- cognee/tests/test_edge_ingestion.py +27 -0
- cognee/tests/test_feedback_enrichment.py +1 -1
- cognee/tests/test_library.py +6 -4
- cognee/tests/test_load.py +62 -0
- cognee/tests/test_multi_tenancy.py +165 -0
- cognee/tests/test_parallel_databases.py +2 -0
- cognee/tests/test_pipeline_cache.py +164 -0
- cognee/tests/test_relational_db_migration.py +54 -2
- cognee/tests/test_search_db.py +44 -2
- cognee/tests/unit/api/test_conditional_authentication_endpoints.py +12 -3
- cognee/tests/unit/api/test_ontology_endpoint.py +252 -0
- cognee/tests/unit/infrastructure/databases/cache/test_cache_config.py +5 -0
- cognee/tests/unit/infrastructure/databases/test_index_data_points.py +27 -0
- cognee/tests/unit/infrastructure/databases/test_index_graph_edges.py +14 -16
- cognee/tests/unit/infrastructure/llm/test_llm_config.py +46 -0
- cognee/tests/unit/infrastructure/mock_embedding_engine.py +3 -7
- cognee/tests/unit/infrastructure/test_embedding_rate_limiting_realistic.py +0 -5
- cognee/tests/unit/modules/chunking/test_text_chunker.py +248 -0
- cognee/tests/unit/modules/chunking/test_text_chunker_with_overlap.py +324 -0
- cognee/tests/unit/modules/graph/cognee_graph_elements_test.py +2 -2
- cognee/tests/unit/modules/graph/cognee_graph_test.py +406 -0
- cognee/tests/unit/modules/memify_tasks/test_cognify_session.py +111 -0
- cognee/tests/unit/modules/memify_tasks/test_extract_user_sessions.py +175 -0
- cognee/tests/unit/modules/memify_tasks/test_get_triplet_datapoints.py +214 -0
- cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py +0 -51
- cognee/tests/unit/modules/retrieval/rag_completion_retriever_test.py +1 -0
- cognee/tests/unit/modules/retrieval/structured_output_test.py +204 -0
- cognee/tests/unit/modules/retrieval/summaries_retriever_test.py +1 -1
- cognee/tests/unit/modules/retrieval/temporal_retriever_test.py +0 -1
- cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py +608 -0
- cognee/tests/unit/modules/retrieval/triplet_retriever_test.py +83 -0
- cognee/tests/unit/modules/users/test_conditional_authentication.py +0 -63
- cognee/tests/unit/processing/chunks/chunk_by_row_test.py +52 -0
- cognee/tests/unit/tasks/storage/test_add_data_points.py +288 -0
- {cognee-0.4.0.dist-info → cognee-0.5.0.dist-info}/METADATA +11 -6
- {cognee-0.4.0.dist-info → cognee-0.5.0.dist-info}/RECORD +215 -163
- {cognee-0.4.0.dist-info → cognee-0.5.0.dist-info}/WHEEL +1 -1
- {cognee-0.4.0.dist-info → cognee-0.5.0.dist-info}/entry_points.txt +0 -1
- cognee/api/v1/cognify/code_graph_pipeline.py +0 -119
- cognee/api/v1/cognify/routers/get_code_pipeline_router.py +0 -90
- cognee/infrastructure/databases/vector/embeddings/embedding_rate_limiter.py +0 -544
- cognee/modules/retrieval/code_retriever.py +0 -232
- cognee/tasks/code/enrich_dependency_graph_checker.py +0 -35
- cognee/tasks/code/get_local_dependencies_checker.py +0 -20
- cognee/tasks/code/get_repo_dependency_graph_checker.py +0 -35
- cognee/tasks/documents/check_permissions_on_dataset.py +0 -26
- cognee/tasks/repo_processor/__init__.py +0 -2
- cognee/tasks/repo_processor/get_local_dependencies.py +0 -335
- cognee/tasks/repo_processor/get_non_code_files.py +0 -158
- cognee/tasks/repo_processor/get_repo_file_dependencies.py +0 -243
- {cognee-0.4.0.dist-info → cognee-0.5.0.dist-info}/licenses/LICENSE +0 -0
- {cognee-0.4.0.dist-info → cognee-0.5.0.dist-info}/licenses/NOTICE.md +0 -0
|
@@ -0,0 +1,214 @@
|
|
|
1
|
+
import sys
|
|
2
|
+
import pytest
|
|
3
|
+
from unittest.mock import AsyncMock, patch
|
|
4
|
+
|
|
5
|
+
from cognee.tasks.memify.get_triplet_datapoints import get_triplet_datapoints
|
|
6
|
+
from cognee.modules.engine.models import Triplet
|
|
7
|
+
from cognee.modules.engine.models.Entity import Entity
|
|
8
|
+
from cognee.infrastructure.engine import DataPoint
|
|
9
|
+
from cognee.modules.graph.models.EdgeType import EdgeType
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
get_triplet_datapoints_module = sys.modules["cognee.tasks.memify.get_triplet_datapoints"]
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
@pytest.fixture
|
|
16
|
+
def mock_graph_engine():
|
|
17
|
+
"""Create a mock graph engine with get_triplets_batch method."""
|
|
18
|
+
engine = AsyncMock()
|
|
19
|
+
engine.get_triplets_batch = AsyncMock()
|
|
20
|
+
return engine
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
@pytest.mark.asyncio
|
|
24
|
+
async def test_get_triplet_datapoints_success(mock_graph_engine):
|
|
25
|
+
"""Test successful extraction of triplet datapoints."""
|
|
26
|
+
mock_triplets_batch = [
|
|
27
|
+
{
|
|
28
|
+
"start_node": {
|
|
29
|
+
"id": "node1",
|
|
30
|
+
"type": "Entity",
|
|
31
|
+
"name": "Alice",
|
|
32
|
+
"description": "A person",
|
|
33
|
+
},
|
|
34
|
+
"end_node": {
|
|
35
|
+
"id": "node2",
|
|
36
|
+
"type": "Entity",
|
|
37
|
+
"name": "Bob",
|
|
38
|
+
"description": "Another person",
|
|
39
|
+
},
|
|
40
|
+
"relationship_properties": {
|
|
41
|
+
"relationship_name": "knows",
|
|
42
|
+
},
|
|
43
|
+
}
|
|
44
|
+
]
|
|
45
|
+
|
|
46
|
+
mock_graph_engine.get_triplets_batch.return_value = mock_triplets_batch
|
|
47
|
+
|
|
48
|
+
with (
|
|
49
|
+
patch.object(
|
|
50
|
+
get_triplet_datapoints_module, "get_graph_engine", return_value=mock_graph_engine
|
|
51
|
+
),
|
|
52
|
+
patch.object(get_triplet_datapoints_module, "get_all_subclasses") as mock_get_subclasses,
|
|
53
|
+
):
|
|
54
|
+
mock_get_subclasses.return_value = [Triplet, EdgeType, Entity]
|
|
55
|
+
|
|
56
|
+
triplets = []
|
|
57
|
+
async for triplet in get_triplet_datapoints([{}], triplets_batch_size=100):
|
|
58
|
+
triplets.append(triplet)
|
|
59
|
+
|
|
60
|
+
assert len(triplets) == 1
|
|
61
|
+
assert isinstance(triplets[0], Triplet)
|
|
62
|
+
assert triplets[0].from_node_id == "node1"
|
|
63
|
+
assert triplets[0].to_node_id == "node2"
|
|
64
|
+
assert "Alice" in triplets[0].text
|
|
65
|
+
assert "knows" in triplets[0].text
|
|
66
|
+
assert "Bob" in triplets[0].text
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
@pytest.mark.asyncio
|
|
70
|
+
async def test_get_triplet_datapoints_edge_text_priority_and_fallback(mock_graph_engine):
|
|
71
|
+
"""Test that edge_text is prioritized over relationship_name, and fallback works."""
|
|
72
|
+
|
|
73
|
+
class MockEntity(DataPoint):
|
|
74
|
+
name: str
|
|
75
|
+
metadata: dict = {"index_fields": ["name"]}
|
|
76
|
+
|
|
77
|
+
mock_triplets_batch = [
|
|
78
|
+
{
|
|
79
|
+
"start_node": {"id": "node1", "type": "Entity", "name": "Alice"},
|
|
80
|
+
"end_node": {"id": "node2", "type": "Entity", "name": "Bob"},
|
|
81
|
+
"relationship_properties": {
|
|
82
|
+
"relationship_name": "knows",
|
|
83
|
+
"edge_text": "has a close friendship with",
|
|
84
|
+
},
|
|
85
|
+
},
|
|
86
|
+
{
|
|
87
|
+
"start_node": {"id": "node3", "type": "Entity", "name": "Charlie"},
|
|
88
|
+
"end_node": {"id": "node4", "type": "Entity", "name": "Diana"},
|
|
89
|
+
"relationship_properties": {
|
|
90
|
+
"relationship_name": "works_with",
|
|
91
|
+
},
|
|
92
|
+
},
|
|
93
|
+
]
|
|
94
|
+
|
|
95
|
+
mock_graph_engine.get_triplets_batch.return_value = mock_triplets_batch
|
|
96
|
+
|
|
97
|
+
with (
|
|
98
|
+
patch.object(
|
|
99
|
+
get_triplet_datapoints_module, "get_graph_engine", return_value=mock_graph_engine
|
|
100
|
+
),
|
|
101
|
+
patch.object(get_triplet_datapoints_module, "get_all_subclasses") as mock_get_subclasses,
|
|
102
|
+
):
|
|
103
|
+
mock_get_subclasses.return_value = [Triplet, EdgeType, MockEntity]
|
|
104
|
+
|
|
105
|
+
triplets = []
|
|
106
|
+
async for triplet in get_triplet_datapoints([{}], triplets_batch_size=100):
|
|
107
|
+
triplets.append(triplet)
|
|
108
|
+
|
|
109
|
+
assert len(triplets) == 2
|
|
110
|
+
assert "has a close friendship with" in triplets[0].text
|
|
111
|
+
assert "knows" not in triplets[0].text
|
|
112
|
+
assert "works_with" in triplets[1].text
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
@pytest.mark.asyncio
|
|
116
|
+
async def test_get_triplet_datapoints_skips_missing_node_ids(mock_graph_engine):
|
|
117
|
+
"""Test that triplets with missing node IDs are skipped."""
|
|
118
|
+
|
|
119
|
+
class MockEntity(DataPoint):
|
|
120
|
+
name: str
|
|
121
|
+
metadata: dict = {"index_fields": ["name"]}
|
|
122
|
+
|
|
123
|
+
mock_triplets_batch = [
|
|
124
|
+
{
|
|
125
|
+
"start_node": {"id": "", "type": "Entity", "name": "Alice"},
|
|
126
|
+
"end_node": {"id": "node2", "type": "Entity", "name": "Bob"},
|
|
127
|
+
"relationship_properties": {"relationship_name": "knows"},
|
|
128
|
+
},
|
|
129
|
+
{
|
|
130
|
+
"start_node": {"id": "node3", "type": "Entity", "name": "Charlie"},
|
|
131
|
+
"end_node": {"id": "node4", "type": "Entity", "name": "Diana"},
|
|
132
|
+
"relationship_properties": {"relationship_name": "works_with"},
|
|
133
|
+
},
|
|
134
|
+
]
|
|
135
|
+
|
|
136
|
+
mock_graph_engine.get_triplets_batch.return_value = mock_triplets_batch
|
|
137
|
+
|
|
138
|
+
with (
|
|
139
|
+
patch.object(
|
|
140
|
+
get_triplet_datapoints_module, "get_graph_engine", return_value=mock_graph_engine
|
|
141
|
+
),
|
|
142
|
+
patch.object(get_triplet_datapoints_module, "get_all_subclasses") as mock_get_subclasses,
|
|
143
|
+
):
|
|
144
|
+
mock_get_subclasses.return_value = [Triplet, EdgeType, MockEntity]
|
|
145
|
+
|
|
146
|
+
triplets = []
|
|
147
|
+
async for triplet in get_triplet_datapoints([{}], triplets_batch_size=100):
|
|
148
|
+
triplets.append(triplet)
|
|
149
|
+
|
|
150
|
+
assert len(triplets) == 1
|
|
151
|
+
assert triplets[0].from_node_id == "node3"
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
@pytest.mark.asyncio
|
|
155
|
+
async def test_get_triplet_datapoints_error_handling(mock_graph_engine):
|
|
156
|
+
"""Test that errors are handled correctly - invalid data is skipped, query errors propagate."""
|
|
157
|
+
|
|
158
|
+
class MockEntity(DataPoint):
|
|
159
|
+
name: str
|
|
160
|
+
metadata: dict = {"index_fields": ["name"]}
|
|
161
|
+
|
|
162
|
+
mock_triplets_batch = [
|
|
163
|
+
{
|
|
164
|
+
"start_node": {"id": "node1", "type": "Entity", "name": "Alice"},
|
|
165
|
+
"end_node": {"id": "node2", "type": "Entity", "name": "Bob"},
|
|
166
|
+
"relationship_properties": {"relationship_name": "knows"},
|
|
167
|
+
},
|
|
168
|
+
{
|
|
169
|
+
"start_node": None,
|
|
170
|
+
"end_node": {"id": "node4", "type": "Entity", "name": "Diana"},
|
|
171
|
+
"relationship_properties": {"relationship_name": "works_with"},
|
|
172
|
+
},
|
|
173
|
+
]
|
|
174
|
+
|
|
175
|
+
mock_graph_engine.get_triplets_batch.return_value = mock_triplets_batch
|
|
176
|
+
|
|
177
|
+
with (
|
|
178
|
+
patch.object(
|
|
179
|
+
get_triplet_datapoints_module, "get_graph_engine", return_value=mock_graph_engine
|
|
180
|
+
),
|
|
181
|
+
patch.object(get_triplet_datapoints_module, "get_all_subclasses") as mock_get_subclasses,
|
|
182
|
+
):
|
|
183
|
+
mock_get_subclasses.return_value = [Triplet, EdgeType, MockEntity]
|
|
184
|
+
|
|
185
|
+
triplets = []
|
|
186
|
+
async for triplet in get_triplet_datapoints([{}], triplets_batch_size=100):
|
|
187
|
+
triplets.append(triplet)
|
|
188
|
+
|
|
189
|
+
assert len(triplets) == 1
|
|
190
|
+
assert triplets[0].from_node_id == "node1"
|
|
191
|
+
|
|
192
|
+
mock_graph_engine.get_triplets_batch.side_effect = Exception("Database connection error")
|
|
193
|
+
|
|
194
|
+
with patch.object(
|
|
195
|
+
get_triplet_datapoints_module, "get_graph_engine", return_value=mock_graph_engine
|
|
196
|
+
):
|
|
197
|
+
triplets = []
|
|
198
|
+
with pytest.raises(Exception, match="Database connection error"):
|
|
199
|
+
async for triplet in get_triplet_datapoints([{}], triplets_batch_size=100):
|
|
200
|
+
triplets.append(triplet)
|
|
201
|
+
|
|
202
|
+
|
|
203
|
+
@pytest.mark.asyncio
|
|
204
|
+
async def test_get_triplet_datapoints_no_get_triplets_batch_method(mock_graph_engine):
|
|
205
|
+
"""Test that NotImplementedError is raised when graph engine lacks get_triplets_batch."""
|
|
206
|
+
del mock_graph_engine.get_triplets_batch
|
|
207
|
+
|
|
208
|
+
with patch.object(
|
|
209
|
+
get_triplet_datapoints_module, "get_graph_engine", return_value=mock_graph_engine
|
|
210
|
+
):
|
|
211
|
+
triplets = []
|
|
212
|
+
with pytest.raises(NotImplementedError, match="does not support get_triplets_batch"):
|
|
213
|
+
async for triplet in get_triplet_datapoints([{}], triplets_batch_size=100):
|
|
214
|
+
triplets.append(triplet)
|
|
@@ -2,7 +2,6 @@ import os
|
|
|
2
2
|
import pytest
|
|
3
3
|
import pathlib
|
|
4
4
|
from typing import Optional, Union
|
|
5
|
-
from pydantic import BaseModel
|
|
6
5
|
|
|
7
6
|
import cognee
|
|
8
7
|
from cognee.low_level import setup, DataPoint
|
|
@@ -11,11 +10,6 @@ from cognee.tasks.storage import add_data_points
|
|
|
11
10
|
from cognee.modules.retrieval.graph_completion_cot_retriever import GraphCompletionCotRetriever
|
|
12
11
|
|
|
13
12
|
|
|
14
|
-
class TestAnswer(BaseModel):
|
|
15
|
-
answer: str
|
|
16
|
-
explanation: str
|
|
17
|
-
|
|
18
|
-
|
|
19
13
|
class TestGraphCompletionCoTRetriever:
|
|
20
14
|
@pytest.mark.asyncio
|
|
21
15
|
async def test_graph_completion_cot_context_simple(self):
|
|
@@ -174,48 +168,3 @@ class TestGraphCompletionCoTRetriever:
|
|
|
174
168
|
assert all(isinstance(item, str) and item.strip() for item in answer), (
|
|
175
169
|
"Answer must contain only non-empty strings"
|
|
176
170
|
)
|
|
177
|
-
|
|
178
|
-
@pytest.mark.asyncio
|
|
179
|
-
async def test_get_structured_completion(self):
|
|
180
|
-
system_directory_path = os.path.join(
|
|
181
|
-
pathlib.Path(__file__).parent, ".cognee_system/test_get_structured_completion"
|
|
182
|
-
)
|
|
183
|
-
cognee.config.system_root_directory(system_directory_path)
|
|
184
|
-
data_directory_path = os.path.join(
|
|
185
|
-
pathlib.Path(__file__).parent, ".data_storage/test_get_structured_completion"
|
|
186
|
-
)
|
|
187
|
-
cognee.config.data_root_directory(data_directory_path)
|
|
188
|
-
|
|
189
|
-
await cognee.prune.prune_data()
|
|
190
|
-
await cognee.prune.prune_system(metadata=True)
|
|
191
|
-
await setup()
|
|
192
|
-
|
|
193
|
-
class Company(DataPoint):
|
|
194
|
-
name: str
|
|
195
|
-
|
|
196
|
-
class Person(DataPoint):
|
|
197
|
-
name: str
|
|
198
|
-
works_for: Company
|
|
199
|
-
|
|
200
|
-
company1 = Company(name="Figma")
|
|
201
|
-
person1 = Person(name="Steve Rodger", works_for=company1)
|
|
202
|
-
|
|
203
|
-
entities = [company1, person1]
|
|
204
|
-
await add_data_points(entities)
|
|
205
|
-
|
|
206
|
-
retriever = GraphCompletionCotRetriever()
|
|
207
|
-
|
|
208
|
-
# Test with string response model (default)
|
|
209
|
-
string_answer = await retriever.get_structured_completion("Who works at Figma?")
|
|
210
|
-
assert isinstance(string_answer, str), f"Expected str, got {type(string_answer).__name__}"
|
|
211
|
-
assert string_answer.strip(), "Answer should not be empty"
|
|
212
|
-
|
|
213
|
-
# Test with structured response model
|
|
214
|
-
structured_answer = await retriever.get_structured_completion(
|
|
215
|
-
"Who works at Figma?", response_model=TestAnswer
|
|
216
|
-
)
|
|
217
|
-
assert isinstance(structured_answer, TestAnswer), (
|
|
218
|
-
f"Expected TestAnswer, got {type(structured_answer).__name__}"
|
|
219
|
-
)
|
|
220
|
-
assert structured_answer.answer.strip(), "Answer field should not be empty"
|
|
221
|
-
assert structured_answer.explanation.strip(), "Explanation field should not be empty"
|
|
@@ -0,0 +1,204 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
|
|
3
|
+
import pytest
|
|
4
|
+
import cognee
|
|
5
|
+
import pathlib
|
|
6
|
+
import os
|
|
7
|
+
|
|
8
|
+
from pydantic import BaseModel
|
|
9
|
+
from cognee.low_level import setup, DataPoint
|
|
10
|
+
from cognee.tasks.storage import add_data_points
|
|
11
|
+
from cognee.modules.chunking.models import DocumentChunk
|
|
12
|
+
from cognee.modules.data.processing.document_types import TextDocument
|
|
13
|
+
from cognee.modules.engine.models import Entity, EntityType
|
|
14
|
+
from cognee.modules.retrieval.entity_extractors.DummyEntityExtractor import DummyEntityExtractor
|
|
15
|
+
from cognee.modules.retrieval.context_providers.DummyContextProvider import DummyContextProvider
|
|
16
|
+
from cognee.modules.retrieval.graph_completion_cot_retriever import GraphCompletionCotRetriever
|
|
17
|
+
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
|
|
18
|
+
from cognee.modules.retrieval.graph_completion_context_extension_retriever import (
|
|
19
|
+
GraphCompletionContextExtensionRetriever,
|
|
20
|
+
)
|
|
21
|
+
from cognee.modules.retrieval.EntityCompletionRetriever import EntityCompletionRetriever
|
|
22
|
+
from cognee.modules.retrieval.temporal_retriever import TemporalRetriever
|
|
23
|
+
from cognee.modules.retrieval.completion_retriever import CompletionRetriever
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class TestAnswer(BaseModel):
|
|
27
|
+
answer: str
|
|
28
|
+
explanation: str
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def _assert_string_answer(answer: list[str]):
|
|
32
|
+
assert isinstance(answer, list), f"Expected str, got {type(answer).__name__}"
|
|
33
|
+
assert all(isinstance(item, str) and item.strip() for item in answer), "Items should be strings"
|
|
34
|
+
assert all(item.strip() for item in answer), "Items should not be empty"
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def _assert_structured_answer(answer: list[TestAnswer]):
|
|
38
|
+
assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
|
|
39
|
+
assert all(isinstance(x, TestAnswer) for x in answer), "Items should be TestAnswer"
|
|
40
|
+
assert all(x.answer.strip() for x in answer), "Answer text should not be empty"
|
|
41
|
+
assert all(x.explanation.strip() for x in answer), "Explanation should not be empty"
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
async def _test_get_structured_graph_completion_cot():
|
|
45
|
+
retriever = GraphCompletionCotRetriever()
|
|
46
|
+
|
|
47
|
+
# Test with string response model (default)
|
|
48
|
+
string_answer = await retriever.get_completion("Who works at Figma?")
|
|
49
|
+
_assert_string_answer(string_answer)
|
|
50
|
+
|
|
51
|
+
# Test with structured response model
|
|
52
|
+
structured_answer = await retriever.get_completion(
|
|
53
|
+
"Who works at Figma?", response_model=TestAnswer
|
|
54
|
+
)
|
|
55
|
+
_assert_structured_answer(structured_answer)
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
async def _test_get_structured_graph_completion():
|
|
59
|
+
retriever = GraphCompletionRetriever()
|
|
60
|
+
|
|
61
|
+
# Test with string response model (default)
|
|
62
|
+
string_answer = await retriever.get_completion("Who works at Figma?")
|
|
63
|
+
_assert_string_answer(string_answer)
|
|
64
|
+
|
|
65
|
+
# Test with structured response model
|
|
66
|
+
structured_answer = await retriever.get_completion(
|
|
67
|
+
"Who works at Figma?", response_model=TestAnswer
|
|
68
|
+
)
|
|
69
|
+
_assert_structured_answer(structured_answer)
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
async def _test_get_structured_graph_completion_temporal():
|
|
73
|
+
retriever = TemporalRetriever()
|
|
74
|
+
|
|
75
|
+
# Test with string response model (default)
|
|
76
|
+
string_answer = await retriever.get_completion("When did Steve start working at Figma?")
|
|
77
|
+
_assert_string_answer(string_answer)
|
|
78
|
+
|
|
79
|
+
# Test with structured response model
|
|
80
|
+
structured_answer = await retriever.get_completion(
|
|
81
|
+
"When did Steve start working at Figma??", response_model=TestAnswer
|
|
82
|
+
)
|
|
83
|
+
_assert_structured_answer(structured_answer)
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
async def _test_get_structured_graph_completion_rag():
|
|
87
|
+
retriever = CompletionRetriever()
|
|
88
|
+
|
|
89
|
+
# Test with string response model (default)
|
|
90
|
+
string_answer = await retriever.get_completion("Where does Steve work?")
|
|
91
|
+
_assert_string_answer(string_answer)
|
|
92
|
+
|
|
93
|
+
# Test with structured response model
|
|
94
|
+
structured_answer = await retriever.get_completion(
|
|
95
|
+
"Where does Steve work?", response_model=TestAnswer
|
|
96
|
+
)
|
|
97
|
+
_assert_structured_answer(structured_answer)
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
async def _test_get_structured_graph_completion_context_extension():
|
|
101
|
+
retriever = GraphCompletionContextExtensionRetriever()
|
|
102
|
+
|
|
103
|
+
# Test with string response model (default)
|
|
104
|
+
string_answer = await retriever.get_completion("Who works at Figma?")
|
|
105
|
+
_assert_string_answer(string_answer)
|
|
106
|
+
|
|
107
|
+
# Test with structured response model
|
|
108
|
+
structured_answer = await retriever.get_completion(
|
|
109
|
+
"Who works at Figma?", response_model=TestAnswer
|
|
110
|
+
)
|
|
111
|
+
_assert_structured_answer(structured_answer)
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
async def _test_get_structured_entity_completion():
|
|
115
|
+
retriever = EntityCompletionRetriever(DummyEntityExtractor(), DummyContextProvider())
|
|
116
|
+
|
|
117
|
+
# Test with string response model (default)
|
|
118
|
+
string_answer = await retriever.get_completion("Who is Albert Einstein?")
|
|
119
|
+
_assert_string_answer(string_answer)
|
|
120
|
+
|
|
121
|
+
# Test with structured response model
|
|
122
|
+
structured_answer = await retriever.get_completion(
|
|
123
|
+
"Who is Albert Einstein?", response_model=TestAnswer
|
|
124
|
+
)
|
|
125
|
+
_assert_structured_answer(structured_answer)
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
class TestStructuredOutputCompletion:
|
|
129
|
+
@pytest.mark.asyncio
|
|
130
|
+
async def test_get_structured_completion(self):
|
|
131
|
+
system_directory_path = os.path.join(
|
|
132
|
+
pathlib.Path(__file__).parent, ".cognee_system/test_get_structured_completion"
|
|
133
|
+
)
|
|
134
|
+
cognee.config.system_root_directory(system_directory_path)
|
|
135
|
+
data_directory_path = os.path.join(
|
|
136
|
+
pathlib.Path(__file__).parent, ".data_storage/test_get_structured_completion"
|
|
137
|
+
)
|
|
138
|
+
cognee.config.data_root_directory(data_directory_path)
|
|
139
|
+
|
|
140
|
+
await cognee.prune.prune_data()
|
|
141
|
+
await cognee.prune.prune_system(metadata=True)
|
|
142
|
+
await setup()
|
|
143
|
+
|
|
144
|
+
class Company(DataPoint):
|
|
145
|
+
name: str
|
|
146
|
+
|
|
147
|
+
class Person(DataPoint):
|
|
148
|
+
name: str
|
|
149
|
+
works_for: Company
|
|
150
|
+
works_since: int
|
|
151
|
+
|
|
152
|
+
company1 = Company(name="Figma")
|
|
153
|
+
person1 = Person(name="Steve Rodger", works_for=company1, works_since=2015)
|
|
154
|
+
|
|
155
|
+
entities = [company1, person1]
|
|
156
|
+
await add_data_points(entities)
|
|
157
|
+
|
|
158
|
+
document = TextDocument(
|
|
159
|
+
name="Steve Rodger's career",
|
|
160
|
+
raw_data_location="somewhere",
|
|
161
|
+
external_metadata="",
|
|
162
|
+
mime_type="text/plain",
|
|
163
|
+
)
|
|
164
|
+
|
|
165
|
+
chunk1 = DocumentChunk(
|
|
166
|
+
text="Steve Rodger",
|
|
167
|
+
chunk_size=2,
|
|
168
|
+
chunk_index=0,
|
|
169
|
+
cut_type="sentence_end",
|
|
170
|
+
is_part_of=document,
|
|
171
|
+
contains=[],
|
|
172
|
+
)
|
|
173
|
+
chunk2 = DocumentChunk(
|
|
174
|
+
text="Mike Broski",
|
|
175
|
+
chunk_size=2,
|
|
176
|
+
chunk_index=1,
|
|
177
|
+
cut_type="sentence_end",
|
|
178
|
+
is_part_of=document,
|
|
179
|
+
contains=[],
|
|
180
|
+
)
|
|
181
|
+
chunk3 = DocumentChunk(
|
|
182
|
+
text="Christina Mayer",
|
|
183
|
+
chunk_size=2,
|
|
184
|
+
chunk_index=2,
|
|
185
|
+
cut_type="sentence_end",
|
|
186
|
+
is_part_of=document,
|
|
187
|
+
contains=[],
|
|
188
|
+
)
|
|
189
|
+
|
|
190
|
+
entities = [chunk1, chunk2, chunk3]
|
|
191
|
+
await add_data_points(entities)
|
|
192
|
+
|
|
193
|
+
entity_type = EntityType(name="Person", description="A human individual")
|
|
194
|
+
entity = Entity(name="Albert Einstein", is_a=entity_type, description="A famous physicist")
|
|
195
|
+
|
|
196
|
+
entities = [entity]
|
|
197
|
+
await add_data_points(entities)
|
|
198
|
+
|
|
199
|
+
await _test_get_structured_graph_completion_cot()
|
|
200
|
+
await _test_get_structured_graph_completion()
|
|
201
|
+
await _test_get_structured_graph_completion_temporal()
|
|
202
|
+
await _test_get_structured_graph_completion_rag()
|
|
203
|
+
await _test_get_structured_graph_completion_context_extension()
|
|
204
|
+
await _test_get_structured_entity_completion()
|
|
@@ -13,7 +13,7 @@ from cognee.modules.retrieval.exceptions.exceptions import NoDataError
|
|
|
13
13
|
from cognee.modules.retrieval.summaries_retriever import SummariesRetriever
|
|
14
14
|
|
|
15
15
|
|
|
16
|
-
class
|
|
16
|
+
class TestSummariesRetriever:
|
|
17
17
|
@pytest.mark.asyncio
|
|
18
18
|
async def test_chunk_context(self):
|
|
19
19
|
system_directory_path = os.path.join(
|