cognee 0.5.1__py3-none-any.whl → 0.5.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cognee/__init__.py +2 -0
- cognee/alembic/README +1 -0
- cognee/alembic/env.py +107 -0
- cognee/alembic/script.py.mako +26 -0
- cognee/alembic/versions/1a58b986e6e1_enable_delete_for_old_tutorial_notebooks.py +52 -0
- cognee/alembic/versions/1d0bb7fede17_add_pipeline_run_status.py +33 -0
- cognee/alembic/versions/1daae0df1866_incremental_loading.py +48 -0
- cognee/alembic/versions/211ab850ef3d_add_sync_operations_table.py +118 -0
- cognee/alembic/versions/45957f0a9849_add_notebook_table.py +46 -0
- cognee/alembic/versions/46a6ce2bd2b2_expand_dataset_database_with_json_.py +333 -0
- cognee/alembic/versions/482cd6517ce4_add_default_user.py +30 -0
- cognee/alembic/versions/76625596c5c3_expand_dataset_database_for_multi_user.py +98 -0
- cognee/alembic/versions/8057ae7329c2_initial_migration.py +25 -0
- cognee/alembic/versions/9e7a3cb85175_loader_separation.py +104 -0
- cognee/alembic/versions/a1b2c3d4e5f6_add_label_column_to_data.py +38 -0
- cognee/alembic/versions/ab7e313804ae_permission_system_rework.py +236 -0
- cognee/alembic/versions/b9274c27a25a_kuzu_11_migration.py +75 -0
- cognee/alembic/versions/c946955da633_multi_tenant_support.py +137 -0
- cognee/alembic/versions/e1ec1dcb50b6_add_last_accessed_to_data.py +51 -0
- cognee/alembic/versions/e4ebee1091e7_expand_data_model_info.py +140 -0
- cognee/alembic.ini +117 -0
- cognee/api/v1/add/add.py +2 -1
- cognee/api/v1/add/routers/get_add_router.py +2 -0
- cognee/api/v1/cognify/cognify.py +11 -6
- cognee/api/v1/cognify/routers/get_cognify_router.py +8 -0
- cognee/api/v1/config/config.py +60 -0
- cognee/api/v1/datasets/routers/get_datasets_router.py +46 -3
- cognee/api/v1/memify/routers/get_memify_router.py +3 -0
- cognee/api/v1/search/routers/get_search_router.py +21 -6
- cognee/api/v1/search/search.py +21 -5
- cognee/api/v1/sync/routers/get_sync_router.py +3 -3
- cognee/cli/commands/add_command.py +1 -1
- cognee/cli/commands/cognify_command.py +6 -0
- cognee/cli/commands/config_command.py +1 -1
- cognee/context_global_variables.py +5 -1
- cognee/eval_framework/answer_generation/answer_generation_executor.py +7 -8
- cognee/infrastructure/databases/cache/cache_db_interface.py +38 -1
- cognee/infrastructure/databases/cache/config.py +6 -0
- cognee/infrastructure/databases/cache/fscache/FsCacheAdapter.py +21 -0
- cognee/infrastructure/databases/cache/get_cache_engine.py +9 -3
- cognee/infrastructure/databases/cache/redis/RedisAdapter.py +60 -1
- cognee/infrastructure/databases/dataset_database_handler/supported_dataset_database_handlers.py +7 -0
- cognee/infrastructure/databases/graph/get_graph_engine.py +29 -1
- cognee/infrastructure/databases/graph/neo4j_driver/Neo4jAuraDevDatasetDatabaseHandler.py +62 -27
- cognee/infrastructure/databases/hybrid/neptune_analytics/NeptuneAnalyticsAdapter.py +17 -4
- cognee/infrastructure/databases/relational/config.py +16 -1
- cognee/infrastructure/databases/relational/create_relational_engine.py +13 -3
- cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py +26 -3
- cognee/infrastructure/databases/vector/chromadb/ChromaDBAdapter.py +2 -0
- cognee/infrastructure/databases/vector/config.py +6 -0
- cognee/infrastructure/databases/vector/create_vector_engine.py +70 -16
- cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py +64 -9
- cognee/infrastructure/databases/vector/embeddings/OllamaEmbeddingEngine.py +13 -2
- cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py +16 -3
- cognee/infrastructure/databases/vector/models/ScoredResult.py +3 -3
- cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py +16 -3
- cognee/infrastructure/databases/vector/pgvector/PGVectorDatasetDatabaseHandler.py +86 -0
- cognee/infrastructure/databases/vector/pgvector/create_db_and_tables.py +81 -2
- cognee/infrastructure/databases/vector/vector_db_interface.py +8 -0
- cognee/infrastructure/files/utils/get_data_file_path.py +33 -27
- cognee/infrastructure/llm/LLMGateway.py +0 -13
- cognee/infrastructure/llm/prompts/extract_query_time.txt +1 -1
- cognee/infrastructure/llm/prompts/generate_event_entity_prompt.txt +1 -1
- cognee/infrastructure/llm/prompts/generate_event_graph_prompt.txt +1 -1
- cognee/infrastructure/llm/prompts/generate_graph_prompt.txt +2 -2
- cognee/infrastructure/llm/prompts/generate_graph_prompt_guided.txt +1 -1
- cognee/infrastructure/llm/prompts/generate_graph_prompt_oneshot.txt +2 -2
- cognee/infrastructure/llm/prompts/generate_graph_prompt_simple.txt +1 -1
- cognee/infrastructure/llm/prompts/generate_graph_prompt_strict.txt +1 -1
- cognee/infrastructure/llm/prompts/search_type_selector_prompt.txt +6 -6
- cognee/infrastructure/llm/prompts/test.txt +1 -1
- cognee/infrastructure/llm/prompts/translate_content.txt +19 -0
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py +17 -12
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py +31 -25
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py +132 -7
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py +29 -5
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/llama_cpp/adapter.py +191 -0
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/llm_interface.py +2 -6
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py +58 -13
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/adapter.py +0 -1
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py +25 -131
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/types.py +10 -0
- cognee/modules/chunking/models/DocumentChunk.py +0 -1
- cognee/modules/cognify/config.py +2 -0
- cognee/modules/data/models/Data.py +3 -1
- cognee/modules/engine/models/Entity.py +0 -1
- cognee/modules/engine/operations/setup.py +6 -0
- cognee/modules/graph/cognee_graph/CogneeGraph.py +150 -37
- cognee/modules/graph/cognee_graph/CogneeGraphElements.py +48 -2
- cognee/modules/graph/utils/__init__.py +1 -0
- cognee/modules/graph/utils/get_entity_nodes_from_triplets.py +12 -0
- cognee/modules/notebooks/methods/__init__.py +1 -0
- cognee/modules/notebooks/methods/create_notebook.py +0 -34
- cognee/modules/notebooks/methods/create_tutorial_notebooks.py +191 -0
- cognee/modules/notebooks/methods/get_notebooks.py +12 -8
- cognee/modules/notebooks/tutorials/cognee-basics/cell-1.md +3 -0
- cognee/modules/notebooks/tutorials/cognee-basics/cell-2.md +10 -0
- cognee/modules/notebooks/tutorials/cognee-basics/cell-3.md +7 -0
- cognee/modules/notebooks/tutorials/cognee-basics/cell-4.py +28 -0
- cognee/modules/notebooks/tutorials/cognee-basics/cell-5.py +3 -0
- cognee/modules/notebooks/tutorials/cognee-basics/cell-6.py +9 -0
- cognee/modules/notebooks/tutorials/cognee-basics/cell-7.py +17 -0
- cognee/modules/notebooks/tutorials/cognee-basics/config.json +4 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-1.md +3 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-10.md +3 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-11.md +3 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-12.py +3 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-13.md +7 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-14.py +6 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-15.md +3 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-16.py +7 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-2.md +9 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-3.md +7 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-4.md +9 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-5.md +5 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-6.py +13 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-7.md +3 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-8.md +3 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-9.py +31 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/config.json +4 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/data/copilot_conversations.json +107 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/data/guido_contributions.json +976 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/data/my_developer_rules.md +79 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/data/pep_style_guide.md +74 -0
- cognee/modules/notebooks/tutorials/python-development-with-cognee/data/zen_principles.md +74 -0
- cognee/modules/retrieval/EntityCompletionRetriever.py +51 -38
- cognee/modules/retrieval/__init__.py +0 -1
- cognee/modules/retrieval/base_retriever.py +66 -10
- cognee/modules/retrieval/chunks_retriever.py +57 -49
- cognee/modules/retrieval/coding_rules_retriever.py +12 -5
- cognee/modules/retrieval/completion_retriever.py +29 -28
- cognee/modules/retrieval/cypher_search_retriever.py +25 -20
- cognee/modules/retrieval/graph_completion_context_extension_retriever.py +42 -46
- cognee/modules/retrieval/graph_completion_cot_retriever.py +68 -51
- cognee/modules/retrieval/graph_completion_retriever.py +78 -63
- cognee/modules/retrieval/graph_summary_completion_retriever.py +2 -0
- cognee/modules/retrieval/lexical_retriever.py +34 -12
- cognee/modules/retrieval/natural_language_retriever.py +18 -15
- cognee/modules/retrieval/summaries_retriever.py +51 -34
- cognee/modules/retrieval/temporal_retriever.py +59 -49
- cognee/modules/retrieval/triplet_retriever.py +32 -33
- cognee/modules/retrieval/utils/access_tracking.py +88 -0
- cognee/modules/retrieval/utils/brute_force_triplet_search.py +99 -103
- cognee/modules/retrieval/utils/node_edge_vector_search.py +174 -0
- cognee/modules/search/methods/__init__.py +1 -0
- cognee/modules/search/methods/get_retriever_output.py +53 -0
- cognee/modules/search/methods/get_search_type_retriever_instance.py +252 -0
- cognee/modules/search/methods/search.py +90 -222
- cognee/modules/search/models/SearchResultPayload.py +67 -0
- cognee/modules/search/types/SearchResult.py +1 -8
- cognee/modules/search/types/SearchType.py +1 -2
- cognee/modules/search/types/__init__.py +1 -1
- cognee/modules/search/utils/__init__.py +1 -2
- cognee/modules/search/utils/transform_insights_to_graph.py +2 -2
- cognee/modules/search/utils/{transform_context_to_graph.py → transform_triplets_to_graph.py} +2 -2
- cognee/modules/users/authentication/default/default_transport.py +11 -1
- cognee/modules/users/authentication/get_api_auth_backend.py +2 -1
- cognee/modules/users/authentication/get_client_auth_backend.py +2 -1
- cognee/modules/users/methods/create_user.py +0 -9
- cognee/modules/users/permissions/methods/has_user_management_permission.py +29 -0
- cognee/modules/visualization/cognee_network_visualization.py +1 -1
- cognee/run_migrations.py +48 -0
- cognee/shared/exceptions/__init__.py +1 -3
- cognee/shared/exceptions/exceptions.py +11 -1
- cognee/shared/usage_logger.py +332 -0
- cognee/shared/utils.py +12 -5
- cognee/tasks/chunks/__init__.py +9 -0
- cognee/tasks/cleanup/cleanup_unused_data.py +172 -0
- cognee/tasks/graph/__init__.py +7 -0
- cognee/tasks/ingestion/data_item.py +8 -0
- cognee/tasks/ingestion/ingest_data.py +12 -1
- cognee/tasks/ingestion/save_data_item_to_storage.py +5 -0
- cognee/tasks/memify/__init__.py +8 -0
- cognee/tasks/memify/extract_usage_frequency.py +613 -0
- cognee/tasks/summarization/models.py +0 -2
- cognee/tasks/temporal_graph/__init__.py +0 -1
- cognee/tasks/translation/__init__.py +96 -0
- cognee/tasks/translation/config.py +110 -0
- cognee/tasks/translation/detect_language.py +190 -0
- cognee/tasks/translation/exceptions.py +62 -0
- cognee/tasks/translation/models.py +72 -0
- cognee/tasks/translation/providers/__init__.py +44 -0
- cognee/tasks/translation/providers/azure_provider.py +192 -0
- cognee/tasks/translation/providers/base.py +85 -0
- cognee/tasks/translation/providers/google_provider.py +158 -0
- cognee/tasks/translation/providers/llm_provider.py +143 -0
- cognee/tasks/translation/translate_content.py +282 -0
- cognee/tasks/web_scraper/default_url_crawler.py +6 -2
- cognee/tests/cli_tests/cli_unit_tests/test_cli_commands.py +1 -0
- cognee/tests/cli_tests/cli_unit_tests/test_cli_edge_cases.py +3 -0
- cognee/tests/integration/retrieval/test_brute_force_triplet_search_with_cognify.py +62 -0
- cognee/tests/integration/retrieval/test_chunks_retriever.py +351 -0
- cognee/tests/integration/retrieval/test_graph_completion_retriever.py +276 -0
- cognee/tests/integration/retrieval/test_graph_completion_retriever_context_extension.py +228 -0
- cognee/tests/integration/retrieval/test_graph_completion_retriever_cot.py +217 -0
- cognee/tests/integration/retrieval/test_rag_completion_retriever.py +319 -0
- cognee/tests/integration/retrieval/test_structured_output.py +258 -0
- cognee/tests/integration/retrieval/test_summaries_retriever.py +195 -0
- cognee/tests/integration/retrieval/test_temporal_retriever.py +336 -0
- cognee/tests/integration/retrieval/test_triplet_retriever.py +45 -1
- cognee/tests/integration/shared/test_usage_logger_integration.py +255 -0
- cognee/tests/tasks/translation/README.md +147 -0
- cognee/tests/tasks/translation/__init__.py +1 -0
- cognee/tests/tasks/translation/config_test.py +93 -0
- cognee/tests/tasks/translation/detect_language_test.py +118 -0
- cognee/tests/tasks/translation/providers_test.py +151 -0
- cognee/tests/tasks/translation/translate_content_test.py +213 -0
- cognee/tests/test_chromadb.py +1 -1
- cognee/tests/test_cleanup_unused_data.py +165 -0
- cognee/tests/test_custom_data_label.py +68 -0
- cognee/tests/test_delete_by_id.py +6 -6
- cognee/tests/test_extract_usage_frequency.py +308 -0
- cognee/tests/test_kuzu.py +17 -7
- cognee/tests/test_lancedb.py +3 -1
- cognee/tests/test_library.py +1 -1
- cognee/tests/test_neo4j.py +17 -7
- cognee/tests/test_neptune_analytics_vector.py +3 -1
- cognee/tests/test_permissions.py +172 -187
- cognee/tests/test_pgvector.py +3 -1
- cognee/tests/test_relational_db_migration.py +15 -1
- cognee/tests/test_remote_kuzu.py +3 -1
- cognee/tests/test_s3_file_storage.py +1 -1
- cognee/tests/test_search_db.py +345 -205
- cognee/tests/test_usage_logger_e2e.py +268 -0
- cognee/tests/unit/api/test_get_raw_data_endpoint.py +206 -0
- cognee/tests/unit/eval_framework/answer_generation_test.py +4 -3
- cognee/tests/unit/eval_framework/benchmark_adapters_test.py +25 -0
- cognee/tests/unit/eval_framework/corpus_builder_test.py +33 -4
- cognee/tests/unit/infrastructure/databases/cache/test_cache_config.py +2 -0
- cognee/tests/unit/infrastructure/databases/relational/test_RelationalConfig.py +69 -0
- cognee/tests/unit/modules/graph/cognee_graph_elements_test.py +42 -2
- cognee/tests/unit/modules/graph/cognee_graph_test.py +329 -31
- cognee/tests/unit/modules/retrieval/chunks_retriever_test.py +122 -168
- cognee/tests/unit/modules/retrieval/conversation_history_test.py +338 -0
- cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py +486 -157
- cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py +693 -155
- cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py +619 -200
- cognee/tests/unit/modules/retrieval/rag_completion_retriever_test.py +300 -171
- cognee/tests/unit/modules/retrieval/summaries_retriever_test.py +184 -155
- cognee/tests/unit/modules/retrieval/temporal_retriever_test.py +544 -79
- cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py +476 -28
- cognee/tests/unit/modules/retrieval/test_completion.py +343 -0
- cognee/tests/unit/modules/retrieval/test_graph_summary_completion_retriever.py +157 -0
- cognee/tests/unit/modules/retrieval/test_node_edge_vector_search.py +273 -0
- cognee/tests/unit/modules/retrieval/test_user_qa_feedback.py +312 -0
- cognee/tests/unit/modules/retrieval/triplet_retriever_test.py +267 -7
- cognee/tests/unit/modules/search/test_get_search_type_retriever_instance.py +125 -0
- cognee/tests/unit/modules/search/test_search.py +96 -20
- cognee/tests/unit/modules/search/test_search_prepare_search_result_contract.py +190 -0
- cognee/tests/unit/modules/users/test_tutorial_notebook_creation.py +511 -297
- cognee/tests/unit/shared/test_usage_logger.py +241 -0
- cognee/tests/unit/users/permissions/test_has_user_management_permission.py +46 -0
- {cognee-0.5.1.dist-info → cognee-0.5.2.dist-info}/METADATA +22 -17
- {cognee-0.5.1.dist-info → cognee-0.5.2.dist-info}/RECORD +258 -157
- cognee/api/.env.example +0 -5
- cognee/modules/retrieval/base_graph_retriever.py +0 -24
- cognee/modules/search/methods/get_search_type_tools.py +0 -223
- cognee/modules/search/methods/no_access_control_search.py +0 -62
- cognee/modules/search/utils/prepare_search_result.py +0 -63
- cognee/tests/test_feedback_enrichment.py +0 -174
- cognee/tests/unit/modules/retrieval/structured_output_test.py +0 -204
- {cognee-0.5.1.dist-info → cognee-0.5.2.dist-info}/WHEEL +0 -0
- {cognee-0.5.1.dist-info → cognee-0.5.2.dist-info}/entry_points.txt +0 -0
- {cognee-0.5.1.dist-info → cognee-0.5.2.dist-info}/licenses/LICENSE +0 -0
- {cognee-0.5.1.dist-info → cognee-0.5.2.dist-info}/licenses/NOTICE.md +0 -0
|
@@ -0,0 +1,268 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import pytest
|
|
3
|
+
import pytest_asyncio
|
|
4
|
+
import asyncio
|
|
5
|
+
from fastapi.testclient import TestClient
|
|
6
|
+
|
|
7
|
+
import cognee
|
|
8
|
+
from cognee.api.client import app
|
|
9
|
+
from cognee.modules.users.methods import get_default_user, get_authenticated_user
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
async def _reset_engines_and_prune():
    """Dispose the vector engine's inner engine (best effort), then prune data and system state.

    Failures while looking up or disposing the vector engine are intentionally
    ignored so that the two prune calls below always run.
    """
    try:
        from cognee.infrastructure.databases.vector import get_vector_engine

        # Not every vector adapter wraps an ``engine``; hasattr(None, ...) is False,
        # so adapters without one are simply skipped.
        inner_engine = getattr(get_vector_engine(), "engine", None)
        if hasattr(inner_engine, "dispose"):
            await inner_engine.dispose(close=True)
    except Exception:
        # Best effort only: a failed dispose must not prevent pruning.
        pass

    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
@pytest.fixture(scope="session")
def event_loop():
    """Create a single asyncio event loop for the whole test session and close it at teardown.

    NOTE(review): newer pytest-asyncio releases deprecate overriding ``event_loop``
    in favour of ``loop_scope``; confirm against the pinned pytest-asyncio version.
    """
    session_loop = asyncio.new_event_loop()
    try:
        yield session_loop
    finally:
        session_loop.close()
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
@pytest.fixture(scope="session")
def e2e_config():
    """Enable Redis-backed usage logging via environment variables for the session.

    The full environment is snapshotted first and restored verbatim at teardown.
    """
    env_snapshot = os.environ.copy()
    env_overrides = {
        "USAGE_LOGGING": "true",
        "CACHE_BACKEND": "redis",
        "CACHE_HOST": "localhost",
        "CACHE_PORT": "6379",
    }
    os.environ.update(env_overrides)

    yield

    # Drop anything set during the session, then put the original variables back.
    os.environ.clear()
    os.environ.update(env_snapshot)
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
@pytest.fixture(scope="session")
def authenticated_client(test_client):
    """Yield ``test_client`` with ``get_authenticated_user`` resolved to cognee's default user.

    The dependency override is installed on the shared ``app`` and removed at teardown.
    """

    async def _resolve_default_user():
        return await get_default_user()

    app.dependency_overrides[get_authenticated_user] = _resolve_default_user

    yield test_client

    app.dependency_overrides.pop(get_authenticated_user, None)
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
@pytest_asyncio.fixture(scope="session")
async def test_data_setup():
    """Session fixture: prune, then add a one-sentence text to a dataset and cognify it.

    Yields the dataset name; everything is pruned again at teardown.
    """
    await _reset_engines_and_prune()

    name = "test_e2e_dataset"
    sample_text = "Germany is located in Europe right next to the Netherlands."
    await cognee.add(sample_text, name)
    await cognee.cognify([name])

    yield name

    await _reset_engines_and_prune()
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
@pytest_asyncio.fixture
async def mcp_data_setup():
    """Per-test fixture for MCP tests: prune, add a one-sentence text, and cognify it.

    Yields the dataset name; everything is pruned again at teardown.
    """
    await _reset_engines_and_prune()

    name = "test_mcp_dataset"
    sample_text = "Germany is located in Europe right next to the Netherlands."
    await cognee.add(sample_text, name)
    await cognee.cognify([name])

    yield name

    await _reset_engines_and_prune()
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
@pytest.fixture(scope="session")
def test_client():
    """Session-wide ``TestClient`` for cognee's FastAPI ``app``.

    Entered as a context manager so the app's lifespan (startup/shutdown) runs.
    """
    with TestClient(app) as api_client:
        yield api_client
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
@pytest_asyncio.fixture
async def cache_engine(e2e_config):
    """Return a ``RedisAdapter`` built in the test's event loop for reading usage logs back.

    The test is skipped unless the cache config reports usage logging enabled with
    the ``redis`` backend.
    """
    from cognee.infrastructure.databases.cache.redis.RedisAdapter import RedisAdapter
    from cognee.infrastructure.databases.cache.config import get_cache_config

    cfg = get_cache_config()
    redis_logging_enabled = cfg.usage_logging and cfg.cache_backend == "redis"
    if not redis_logging_enabled:
        pytest.skip("Redis usage logging not configured")

    return RedisAdapter(
        host=cfg.cache_host,
        port=cfg.cache_port,
        username=cfg.cache_username,
        password=cfg.cache_password,
        log_key="usage_logs",
    )
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
@pytest.mark.asyncio
async def test_api_endpoint_logging(e2e_config, authenticated_client, cache_engine):
    """Call add, cognify and search over HTTP, then check each left a successful usage log."""
    user = await get_default_user()
    expected_user_id = str(user.id)
    dataset_name = "test_e2e_api_dataset"

    uploaded_file = (
        "test.txt",
        b"Germany is located in Europe right next to the Netherlands.",
        "text/plain",
    )
    add_response = authenticated_client.post(
        "/api/v1/add",
        data={"datasetName": dataset_name},
        files=[("data", uploaded_file)],
    )
    assert add_response.status_code in [200, 201], f"Add endpoint failed: {add_response.text}"

    cognify_response = authenticated_client.post(
        "/api/v1/cognify",
        json={"datasets": [dataset_name], "run_in_background": False},
    )
    assert cognify_response.status_code in [200, 201], (
        f"Cognify endpoint failed: {cognify_response.text}"
    )

    search_response = authenticated_client.post(
        "/api/v1/search",
        json={"query": "Germany", "search_type": "GRAPH_COMPLETION", "datasets": [dataset_name]},
    )
    assert search_response.status_code == 200, f"Search endpoint failed: {search_response.text}"

    logs = await cache_engine.get_usage_logs(expected_user_id, limit=20)

    # Endpoints are checked in call order; for each, the first matching entry
    # returned by get_usage_logs is inspected.
    for endpoint in ("POST /v1/add", "POST /v1/cognify", "POST /v1/search"):
        matching = [entry for entry in logs if entry.get("function_name") == endpoint]
        assert len(matching) > 0
        first_entry = matching[0]
        assert first_entry["type"] == "api_endpoint"
        assert first_entry["user_id"] == expected_user_id
        assert first_entry["success"] is True
|
|
176
|
+
|
|
177
|
+
|
|
178
|
+
@pytest.mark.asyncio
async def test_mcp_tool_logging(e2e_config, cache_engine):
    """Run each cognee-mcp tool once and check that each produced a successful usage log.

    Loads ``cognee-mcp/src/server.py`` from the repository checkout and skips if it is
    absent. Invokes cognify, list_data, search, save_interaction, cognify_status and
    prune, then reads the usage logs recorded under ``user_id="unknown"``.
    ``cognee-mcp/src`` is added to ``sys.path`` only for the duration of this test.
    """
    import importlib.util
    import sys
    from pathlib import Path

    await _reset_engines_and_prune()

    repo_root = Path(__file__).parent.parent.parent
    mcp_src_path = repo_root / "cognee-mcp" / "src"
    mcp_server_path = mcp_src_path / "server.py"

    if not mcp_server_path.exists():
        pytest.skip(f"MCP server not found at {mcp_server_path}")

    # Make modules next to server.py importable while it loads (presumably it imports
    # siblings such as cognee_client by bare name). Undone in the outer ``finally`` so
    # the path does not leak into other tests.
    mcp_src = str(mcp_src_path)
    added_to_sys_path = mcp_src not in sys.path
    if added_to_sys_path:
        sys.path.insert(0, mcp_src)

    try:
        spec = importlib.util.spec_from_file_location("mcp_server_module", mcp_server_path)
        mcp_server_module = importlib.util.module_from_spec(spec)

        # Execute server.py from its own directory; the working directory is restored
        # even if module execution raises.
        original_cwd = os.getcwd()
        try:
            os.chdir(mcp_src)
            spec.loader.exec_module(mcp_server_module)
        finally:
            os.chdir(original_cwd)

        if mcp_server_module.cognee_client is None:
            cognee_client_path = mcp_src_path / "cognee_client.py"
            if not cognee_client_path.exists():
                pytest.skip(f"CogneeClient not found at {cognee_client_path}")
            spec_client = importlib.util.spec_from_file_location(
                "cognee_client", cognee_client_path
            )
            cognee_client_module = importlib.util.module_from_spec(spec_client)
            spec_client.loader.exec_module(cognee_client_module)
            mcp_server_module.cognee_client = cognee_client_module.CogneeClient()

        test_text = "Germany is located in Europe right next to the Netherlands."
        await mcp_server_module.cognify(data=test_text)
        # Fixed wait, presumably to let background cognify processing finish.
        await asyncio.sleep(30.0)

        list_result = await mcp_server_module.list_data()
        assert list_result is not None, "List data should return results"

        search_result = await mcp_server_module.search(
            search_query="Germany", search_type="GRAPH_COMPLETION", top_k=5
        )
        assert search_result is not None, "Search should return results"

        interaction_data = "User: What is Germany?\nAgent: Germany is a country in Europe."
        await mcp_server_module.save_interaction(data=interaction_data)
        await asyncio.sleep(30.0)

        status_result = await mcp_server_module.cognify_status()
        assert status_result is not None, "Cognify status should return results"

        await mcp_server_module.prune()
        await asyncio.sleep(0.5)

        logs = await cache_engine.get_usage_logs("unknown", limit=50)
        mcp_logs = [log for log in logs if log.get("type") == "mcp_tool"]
        assert len(mcp_logs) > 0, (
            f"Should have MCP tool logs with user_id='unknown'. Found logs: {[log.get('function_name') for log in logs[:5]]}"
        )

        function_names = [log.get("function_name") for log in mcp_logs]
        # NOTE(review): an exact count assumes the "unknown" user's usage logs start
        # empty for this run; confirm they are not accumulated across test runs.
        assert len(mcp_logs) == 6, f"Expected exactly 6 MCP tool logs, found: {function_names}"
        expected_tools = [
            "MCP cognify",
            "MCP list_data",
            "MCP search",
            "MCP save_interaction",
            "MCP cognify_status",
            "MCP prune",
        ]

        for expected_tool in expected_tools:
            assert expected_tool in function_names, (
                f"Should have {expected_tool} log. Found: {function_names}"
            )

        for log in mcp_logs:
            assert log["type"] == "mcp_tool"
            assert log["user_id"] == "unknown"
            assert log["success"] is True
    finally:
        if added_to_sys_path and mcp_src in sys.path:
            sys.path.remove(mcp_src)
|
|
@@ -0,0 +1,206 @@
|
|
|
1
|
+
import io
|
|
2
|
+
import uuid
|
|
3
|
+
from contextlib import asynccontextmanager
|
|
4
|
+
from types import SimpleNamespace
|
|
5
|
+
from unittest.mock import AsyncMock
|
|
6
|
+
|
|
7
|
+
import pytest
|
|
8
|
+
from fastapi import FastAPI
|
|
9
|
+
from fastapi.testclient import TestClient
|
|
10
|
+
|
|
11
|
+
from cognee.modules.users.methods import get_authenticated_user
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
@pytest.fixture(scope="session")
def test_client():
    """Session-wide TestClient for a minimal FastAPI app.

    The app mounts only the datasets router, under ``/api/v1/datasets``.
    """
    from cognee.api.v1.datasets.routers.get_datasets_router import get_datasets_router

    datasets_app = FastAPI()
    datasets_app.include_router(get_datasets_router(), prefix="/api/v1/datasets")

    with TestClient(datasets_app) as http_client:
        yield http_client
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
@pytest.fixture
def client(test_client, monkeypatch):
    """Yield ``test_client`` authenticated as a synthetic user, with telemetry stubbed out.

    Both patches are undone at teardown, so neither leaks into later tests that share
    the session-scoped client or the router module:
    - ``send_telemetry`` on the datasets router module is replaced via ``monkeypatch``.
    - The ``get_authenticated_user`` dependency override is removed from the app.
    """
    import importlib

    async def override_get_authenticated_user():
        # New random ids are generated each time the dependency is resolved.
        return SimpleNamespace(
            id=str(uuid.uuid4()),
            email="default@example.com",
            is_active=True,
            tenant_id=str(uuid.uuid4()),
        )

    datasets_router_module = importlib.import_module(
        "cognee.api.v1.datasets.routers.get_datasets_router"
    )
    # Previously this was a bare assignment that was never reverted. raising=False
    # keeps the old tolerance for a module without ``send_telemetry``.
    monkeypatch.setattr(
        datasets_router_module,
        "send_telemetry",
        lambda *args, **kwargs: None,
        raising=False,
    )

    test_client.app.dependency_overrides[get_authenticated_user] = override_get_authenticated_user
    yield test_client
    test_client.app.dependency_overrides.pop(get_authenticated_user, None)
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def _patch_raw_download_dependencies(
    monkeypatch, *, dataset_id, data_id, raw_data_location, name, mime_type
):
    """Stub the lookups behind ``GET /datasets/{dataset_id}/data/{data_id}/raw``.

    The stubs authorize exactly ``dataset_id`` and list ``data_id`` in it. The data
    record is returned with the given location, name and MIME type. This keeps tests
    focused on how the response body is produced.

    NOTE(review): the data-method stubs are set on the ``cognee.modules.data.methods``
    package; they only take effect if the router resolves those names from the package
    at call time. Confirm this when the router's imports change.
    """
    import importlib

    router_module = importlib.import_module(
        "cognee.api.v1.datasets.routers.get_datasets_router"
    )
    authorized_datasets = AsyncMock(return_value=[SimpleNamespace(id=dataset_id)])
    monkeypatch.setattr(router_module, "get_authorized_existing_datasets", authorized_datasets)

    import cognee.modules.data.methods as data_methods_module

    stored_record = SimpleNamespace(
        id=data_id,
        raw_data_location=raw_data_location,
        name=name,
        mime_type=mime_type,
    )
    data_method_stubs = {
        "get_dataset_data": AsyncMock(return_value=[SimpleNamespace(id=data_id)]),
        "get_data": AsyncMock(return_value=stored_record),
    }
    for attribute, stub in data_method_stubs.items():
        monkeypatch.setattr(data_methods_module, attribute, stub)
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
def test_get_raw_data_local_file_downloads_bytes(client, monkeypatch, tmp_path):
    """A file:// raw_data_location is returned byte-for-byte."""
    payload = b"hello from disk"
    local_file = tmp_path / "example.txt"
    local_file.write_bytes(payload)

    dataset_id, data_id = uuid.uuid4(), uuid.uuid4()
    _patch_raw_download_dependencies(
        monkeypatch,
        dataset_id=dataset_id,
        data_id=data_id,
        raw_data_location=local_file.as_uri(),
        name="example.txt",
        mime_type="text/plain",
    )

    url = f"/api/v1/datasets/{dataset_id}/data/{data_id}/raw"
    response = client.get(url)

    assert response.status_code == 200
    assert response.content == payload
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
def test_get_raw_data_s3_streams_bytes_without_s3_dependency(client, monkeypatch):
    """
    An s3:// raw_data_location is streamed back as an attachment.

    ``open_data_file`` is replaced with an in-memory reader, so the test needs
    no S3 dependency.
    """
    dataset_id, data_id = uuid.uuid4(), uuid.uuid4()
    _patch_raw_download_dependencies(
        monkeypatch,
        dataset_id=dataset_id,
        data_id=data_id,
        raw_data_location="s3://bucket/path/to/file.txt",
        name="file.txt",
        mime_type="text/plain",
    )

    import cognee.infrastructure.files.utils.open_data_file as open_data_file_module

    streamed_bytes = b"hello from s3"

    @asynccontextmanager
    async def fake_open_data_file(_file_path: str, mode: str = "rb", **_kwargs):
        # The download endpoint is expected to open the object in binary mode.
        assert mode == "rb"
        yield io.BytesIO(streamed_bytes)

    monkeypatch.setattr(open_data_file_module, "open_data_file", fake_open_data_file)

    url = f"/api/v1/datasets/{dataset_id}/data/{data_id}/raw"
    response = client.get(url)

    assert response.status_code == 200
    assert response.content == streamed_bytes
    assert response.headers.get("content-disposition") == 'attachment; filename="file.txt"'
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
def test_get_raw_data_unsupported_scheme_returns_501(client, monkeypatch):
    """A raw_data_location with an unsupported scheme (http://) yields HTTP 501."""
    dataset_id, data_id = uuid.uuid4(), uuid.uuid4()
    _patch_raw_download_dependencies(
        monkeypatch,
        dataset_id=dataset_id,
        data_id=data_id,
        raw_data_location="http://example.com/file.txt",
        name="file.txt",
        mime_type="text/plain",
    )

    url = f"/api/v1/datasets/{dataset_id}/data/{data_id}/raw"
    response = client.get(url)

    assert response.status_code == 501
    error_detail = response.json()["detail"]
    assert "Storage scheme 'http' not supported" in error_detail
|
|
156
|
+
|
|
157
|
+
|
|
158
|
+
def test_get_raw_data_plain_path_downloads_bytes(client, monkeypatch, tmp_path):
    """A bare filesystem path (no URI scheme) is returned byte-for-byte."""
    payload = b"plain content"
    local_file = tmp_path / "plain.txt"
    local_file.write_bytes(payload)

    dataset_id, data_id = uuid.uuid4(), uuid.uuid4()
    _patch_raw_download_dependencies(
        monkeypatch,
        dataset_id=dataset_id,
        data_id=data_id,
        raw_data_location=str(local_file),
        name="plain.txt",
        mime_type="text/plain",
    )

    url = f"/api/v1/datasets/{dataset_id}/data/{data_id}/raw"
    response = client.get(url)

    assert response.status_code == 200
    assert response.content == payload
|
|
179
|
+
|
|
180
|
+
|
|
181
|
+
def test_get_raw_data_encoded_path_downloads_bytes(client, monkeypatch, tmp_path):
    """A percent-encoded file:// URI (a space stored as %20) is still served."""
    file_name = "example file.txt"
    payload = b"content with spaces"
    local_file = tmp_path / file_name
    local_file.write_bytes(payload)

    # Path.as_uri() percent-encodes the space, which is what this test relies on.
    encoded_uri = local_file.as_uri()
    assert "%20" in encoded_uri

    dataset_id, data_id = uuid.uuid4(), uuid.uuid4()
    _patch_raw_download_dependencies(
        monkeypatch,
        dataset_id=dataset_id,
        data_id=data_id,
        raw_data_location=encoded_uri,
        name=file_name,
        mime_type="text/plain",
    )

    url = f"/api/v1/datasets/{dataset_id}/data/{data_id}/raw"
    response = client.get(url)

    assert response.status_code == 200
    assert response.content == payload
|
|
@@ -12,8 +12,9 @@ async def test_answer_generation():
|
|
|
12
12
|
corpus_list, qa_pairs = DummyAdapter().load_corpus(limit=limit)
|
|
13
13
|
|
|
14
14
|
mock_retriever = AsyncMock()
|
|
15
|
-
mock_retriever.
|
|
16
|
-
mock_retriever.
|
|
15
|
+
mock_retriever.get_retrieved_objects = AsyncMock(return_value=[])
|
|
16
|
+
mock_retriever.get_context_from_objects = AsyncMock(return_value="Mocked retrieval context")
|
|
17
|
+
mock_retriever.get_completion_from_context = AsyncMock(return_value=["Mocked answer"])
|
|
17
18
|
|
|
18
19
|
answer_generator = AnswerGeneratorExecutor()
|
|
19
20
|
answers = await answer_generator.question_answering_non_parallel(
|
|
@@ -21,7 +22,7 @@ async def test_answer_generation():
|
|
|
21
22
|
retriever=mock_retriever,
|
|
22
23
|
)
|
|
23
24
|
|
|
24
|
-
mock_retriever.
|
|
25
|
+
mock_retriever.get_context_from_objects.assert_any_await(qa_pairs[0]["question"], [])
|
|
25
26
|
|
|
26
27
|
assert len(answers) == len(qa_pairs)
|
|
27
28
|
assert answers[0]["question"] == qa_pairs[0]["question"], (
|
|
@@ -11,6 +11,22 @@ MOCK_JSONL_DATA = """\
|
|
|
11
11
|
{"id": "2", "question": "What is ML?", "answer": "Machine Learning", "paragraphs": [{"paragraph_text": "ML is a subset of AI."}]}
|
|
12
12
|
"""
|
|
13
13
|
|
|
14
|
+
# One-record stand-in for the raw corpus returned by `_get_raw_corpus`.
MOCK_HOTPOT_CORPUS = [
    dict(
        _id="1",
        question="Next to which country is Germany located?",
        answer="Netherlands",
        # HotpotQA uses "level"; TwoWikiMultiHop uses "type".
        level="easy",
        type="comparison",
        context=[
            ["Germany", ["Germany is in Europe."]],
            ["Netherlands", ["The Netherlands borders Germany."]],
        ],
        supporting_facts=[["Netherlands", 0]],
    )
]
|
|
29
|
+
|
|
14
30
|
|
|
15
31
|
ADAPTER_CLASSES = [
|
|
16
32
|
HotpotQAAdapter,
|
|
@@ -35,6 +51,11 @@ def test_adapter_can_instantiate_and_load(AdapterClass):
|
|
|
35
51
|
adapter = AdapterClass()
|
|
36
52
|
result = adapter.load_corpus()
|
|
37
53
|
|
|
54
|
+
elif AdapterClass in (HotpotQAAdapter, TwoWikiMultihopAdapter):
|
|
55
|
+
with patch.object(AdapterClass, "_get_raw_corpus", return_value=MOCK_HOTPOT_CORPUS):
|
|
56
|
+
adapter = AdapterClass()
|
|
57
|
+
result = adapter.load_corpus()
|
|
58
|
+
|
|
38
59
|
else:
|
|
39
60
|
adapter = AdapterClass()
|
|
40
61
|
result = adapter.load_corpus()
|
|
@@ -64,6 +85,10 @@ def test_adapter_returns_some_content(AdapterClass):
|
|
|
64
85
|
):
|
|
65
86
|
adapter = AdapterClass()
|
|
66
87
|
corpus_list, qa_pairs = adapter.load_corpus(limit=limit)
|
|
88
|
+
elif AdapterClass in (HotpotQAAdapter, TwoWikiMultihopAdapter):
|
|
89
|
+
with patch.object(AdapterClass, "_get_raw_corpus", return_value=MOCK_HOTPOT_CORPUS):
|
|
90
|
+
adapter = AdapterClass()
|
|
91
|
+
corpus_list, qa_pairs = adapter.load_corpus(limit=limit)
|
|
67
92
|
else:
|
|
68
93
|
adapter = AdapterClass()
|
|
69
94
|
corpus_list, qa_pairs = adapter.load_corpus(limit=limit)
|
|
@@ -2,15 +2,38 @@ import pytest
|
|
|
2
2
|
from cognee.eval_framework.corpus_builder.corpus_builder_executor import CorpusBuilderExecutor
|
|
3
3
|
from cognee.infrastructure.databases.graph import get_graph_engine
|
|
4
4
|
from unittest.mock import AsyncMock, patch
|
|
5
|
+
from cognee.eval_framework.benchmark_adapters.hotpot_qa_adapter import HotpotQAAdapter
|
|
5
6
|
|
|
6
7
|
benchmark_options = "HotPotQA Dummy TwoWikiMultiHop".split()  # benchmark names parametrizing the tests
|
|
7
8
|
|
|
9
|
+
# One-record stand-in for the raw corpus returned by `_get_raw_corpus`.
MOCK_HOTPOT_CORPUS = [
    dict(
        _id="1",
        question="Next to which country is Germany located?",
        answer="Netherlands",
        # HotpotQA uses "level"; TwoWikiMultiHop uses "type".
        level="easy",
        type="comparison",
        context=[
            ["Germany", ["Germany is in Europe."]],
            ["Netherlands", ["The Netherlands borders Germany."]],
        ],
        supporting_facts=[["Netherlands", 0]],
    )
]
|
|
24
|
+
|
|
8
25
|
|
|
9
26
|
@pytest.mark.parametrize("benchmark", benchmark_options)
|
|
10
27
|
def test_corpus_builder_load_corpus(benchmark):
|
|
11
28
|
limit = 2
|
|
12
|
-
|
|
13
|
-
|
|
29
|
+
if benchmark in ("HotPotQA", "TwoWikiMultiHop"):
|
|
30
|
+
with patch.object(HotpotQAAdapter, "_get_raw_corpus", return_value=MOCK_HOTPOT_CORPUS):
|
|
31
|
+
corpus_builder = CorpusBuilderExecutor(benchmark, "Default")
|
|
32
|
+
raw_corpus, questions = corpus_builder.load_corpus(limit=limit)
|
|
33
|
+
else:
|
|
34
|
+
corpus_builder = CorpusBuilderExecutor(benchmark, "Default")
|
|
35
|
+
raw_corpus, questions = corpus_builder.load_corpus(limit=limit)
|
|
36
|
+
|
|
14
37
|
assert len(raw_corpus) > 0, f"Corpus builder loads empty corpus for {benchmark}"
|
|
15
38
|
assert len(questions) <= 2, (
|
|
16
39
|
f"Corpus builder loads {len(questions)} for {benchmark} when limit is {limit}"
|
|
@@ -22,8 +45,14 @@ def test_corpus_builder_load_corpus(benchmark):
|
|
|
22
45
|
@patch.object(CorpusBuilderExecutor, "run_cognee", new_callable=AsyncMock)
|
|
23
46
|
async def test_corpus_builder_build_corpus(mock_run_cognee, benchmark):
|
|
24
47
|
limit = 2
|
|
25
|
-
|
|
26
|
-
|
|
48
|
+
if benchmark in ("HotPotQA", "TwoWikiMultiHop"):
|
|
49
|
+
with patch.object(HotpotQAAdapter, "_get_raw_corpus", return_value=MOCK_HOTPOT_CORPUS):
|
|
50
|
+
corpus_builder = CorpusBuilderExecutor(benchmark, "Default")
|
|
51
|
+
questions = await corpus_builder.build_corpus(limit=limit)
|
|
52
|
+
else:
|
|
53
|
+
corpus_builder = CorpusBuilderExecutor(benchmark, "Default")
|
|
54
|
+
questions = await corpus_builder.build_corpus(limit=limit)
|
|
55
|
+
|
|
27
56
|
assert len(questions) <= 2, (
|
|
28
57
|
f"Corpus builder loads {len(questions)} for {benchmark} when limit is {limit}"
|
|
29
58
|
)
|
|
@@ -0,0 +1,69 @@
|
|
|
1
|
+
import os
|
|
2
|
+
from unittest.mock import patch
|
|
3
|
+
from cognee.infrastructure.databases.relational.config import RelationalConfig
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class TestRelationalConfig:
    """Checks how RelationalConfig interprets the DATABASE_CONNECT_ARGS variable."""

    def test_database_connect_args_valid_json_dict(self):
        """A JSON object in DATABASE_CONNECT_ARGS is decoded into a dict."""
        env = {"DATABASE_CONNECT_ARGS": '{"timeout": 60, "sslmode": "require"}'}
        expected = {"timeout": 60, "sslmode": "require"}
        with patch.dict(os.environ, env):
            cfg = RelationalConfig()
            assert cfg.database_connect_args == expected

    def test_database_connect_args_empty_string(self):
        """An empty DATABASE_CONNECT_ARGS is kept as an empty string."""
        env = {"DATABASE_CONNECT_ARGS": ""}
        with patch.dict(os.environ, env):
            cfg = RelationalConfig()
            assert cfg.database_connect_args == ""

    def test_database_connect_args_not_set(self):
        """With DATABASE_CONNECT_ARGS absent (environment cleared) the field is None."""
        with patch.dict(os.environ, {}, clear=True):
            cfg = RelationalConfig()
            assert cfg.database_connect_args is None

    def test_database_connect_args_invalid_json(self):
        """Malformed JSON in DATABASE_CONNECT_ARGS falls back to an empty dict."""
        truncated_json = '{"timeout": 60'  # closing brace deliberately missing
        with patch.dict(os.environ, {"DATABASE_CONNECT_ARGS": truncated_json}):
            cfg = RelationalConfig()
            assert cfg.database_connect_args == {}

    def test_database_connect_args_non_dict_json(self):
        """Valid JSON that is not an object (here a list) falls back to an empty dict."""
        json_list = '["list", "instead", "of", "dict"]'
        with patch.dict(os.environ, {"DATABASE_CONNECT_ARGS": json_list}):
            cfg = RelationalConfig()
            assert cfg.database_connect_args == {}

    def test_database_connect_args_to_dict(self):
        """The parsed value is exposed under "database_connect_args" by to_dict()."""
        env = {"DATABASE_CONNECT_ARGS": '{"timeout": 60}'}
        with patch.dict(os.environ, env):
            exported = RelationalConfig().to_dict()
            assert "database_connect_args" in exported
            assert exported["database_connect_args"] == {"timeout": 60}

    def test_database_connect_args_integer_value(self):
        """Integer values in DATABASE_CONNECT_ARGS stay integers after parsing."""
        env = {"DATABASE_CONNECT_ARGS": '{"connect_timeout": 10}'}
        with patch.dict(os.environ, env):
            cfg = RelationalConfig()
            assert cfg.database_connect_args == {"connect_timeout": 10}

    def test_database_connect_args_mixed_types(self):
        """Integers, strings and booleans in one JSON object all parse correctly."""
        raw = '{"timeout": 60, "sslmode": "require", "retries": 3, "keepalive": true}'
        expected = {
            "timeout": 60,
            "sslmode": "require",
            "retries": 3,
            "keepalive": True,  # JSON `true` becomes Python True
        }
        with patch.dict(os.environ, {"DATABASE_CONNECT_ARGS": raw}):
            cfg = RelationalConfig()
            assert cfg.database_connect_args == expected
|