cognee 0.4.1__py3-none-any.whl → 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cognee/__init__.py +1 -0
- cognee/api/client.py +9 -5
- cognee/api/v1/add/add.py +2 -1
- cognee/api/v1/add/routers/get_add_router.py +3 -1
- cognee/api/v1/cognify/cognify.py +24 -16
- cognee/api/v1/cognify/routers/__init__.py +0 -1
- cognee/api/v1/cognify/routers/get_cognify_router.py +30 -1
- cognee/api/v1/datasets/routers/get_datasets_router.py +3 -3
- cognee/api/v1/ontologies/__init__.py +4 -0
- cognee/api/v1/ontologies/ontologies.py +158 -0
- cognee/api/v1/ontologies/routers/__init__.py +0 -0
- cognee/api/v1/ontologies/routers/get_ontology_router.py +109 -0
- cognee/api/v1/permissions/routers/get_permissions_router.py +41 -1
- cognee/api/v1/search/search.py +4 -0
- cognee/api/v1/ui/node_setup.py +360 -0
- cognee/api/v1/ui/npm_utils.py +50 -0
- cognee/api/v1/ui/ui.py +38 -68
- cognee/cli/commands/cognify_command.py +8 -1
- cognee/cli/config.py +1 -1
- cognee/context_global_variables.py +86 -9
- cognee/eval_framework/Dockerfile +29 -0
- cognee/eval_framework/answer_generation/answer_generation_executor.py +10 -0
- cognee/eval_framework/answer_generation/run_question_answering_module.py +1 -1
- cognee/eval_framework/corpus_builder/task_getters/get_cascade_graph_tasks.py +0 -2
- cognee/eval_framework/corpus_builder/task_getters/get_default_tasks_by_indices.py +4 -4
- cognee/eval_framework/eval_config.py +2 -2
- cognee/eval_framework/modal_run_eval.py +16 -28
- cognee/infrastructure/databases/cache/config.py +3 -1
- cognee/infrastructure/databases/cache/fscache/FsCacheAdapter.py +151 -0
- cognee/infrastructure/databases/cache/get_cache_engine.py +20 -10
- cognee/infrastructure/databases/dataset_database_handler/__init__.py +3 -0
- cognee/infrastructure/databases/dataset_database_handler/dataset_database_handler_interface.py +80 -0
- cognee/infrastructure/databases/dataset_database_handler/supported_dataset_database_handlers.py +18 -0
- cognee/infrastructure/databases/dataset_database_handler/use_dataset_database_handler.py +10 -0
- cognee/infrastructure/databases/exceptions/exceptions.py +16 -0
- cognee/infrastructure/databases/graph/config.py +7 -0
- cognee/infrastructure/databases/graph/get_graph_engine.py +3 -0
- cognee/infrastructure/databases/graph/graph_db_interface.py +15 -0
- cognee/infrastructure/databases/graph/kuzu/KuzuDatasetDatabaseHandler.py +81 -0
- cognee/infrastructure/databases/graph/kuzu/adapter.py +228 -0
- cognee/infrastructure/databases/graph/neo4j_driver/Neo4jAuraDevDatasetDatabaseHandler.py +168 -0
- cognee/infrastructure/databases/graph/neo4j_driver/adapter.py +80 -1
- cognee/infrastructure/databases/hybrid/neptune_analytics/NeptuneAnalyticsAdapter.py +9 -0
- cognee/infrastructure/databases/utils/__init__.py +3 -0
- cognee/infrastructure/databases/utils/get_graph_dataset_database_handler.py +10 -0
- cognee/infrastructure/databases/utils/get_or_create_dataset_database.py +66 -18
- cognee/infrastructure/databases/utils/get_vector_dataset_database_handler.py +10 -0
- cognee/infrastructure/databases/utils/resolve_dataset_database_connection_info.py +30 -0
- cognee/infrastructure/databases/vector/config.py +5 -0
- cognee/infrastructure/databases/vector/create_vector_engine.py +6 -1
- cognee/infrastructure/databases/vector/embeddings/FastembedEmbeddingEngine.py +8 -6
- cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py +9 -7
- cognee/infrastructure/databases/vector/embeddings/OllamaEmbeddingEngine.py +11 -13
- cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py +2 -0
- cognee/infrastructure/databases/vector/lancedb/LanceDBDatasetDatabaseHandler.py +50 -0
- cognee/infrastructure/databases/vector/vector_db_interface.py +35 -0
- cognee/infrastructure/engine/models/Edge.py +13 -1
- cognee/infrastructure/files/storage/s3_config.py +2 -0
- cognee/infrastructure/files/utils/guess_file_type.py +4 -0
- cognee/infrastructure/llm/LLMGateway.py +5 -2
- cognee/infrastructure/llm/config.py +37 -0
- cognee/infrastructure/llm/extraction/knowledge_graph/extract_content_graph.py +2 -2
- cognee/infrastructure/llm/structured_output_framework/baml/baml_src/extraction/acreate_structured_output.py +23 -8
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py +22 -18
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/bedrock/__init__.py +5 -0
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/bedrock/adapter.py +153 -0
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py +47 -38
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py +46 -37
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py +20 -10
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py +23 -11
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/adapter.py +36 -23
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py +47 -36
- cognee/infrastructure/loaders/LoaderEngine.py +1 -0
- cognee/infrastructure/loaders/core/__init__.py +2 -1
- cognee/infrastructure/loaders/core/csv_loader.py +93 -0
- cognee/infrastructure/loaders/core/text_loader.py +1 -2
- cognee/infrastructure/loaders/external/advanced_pdf_loader.py +0 -9
- cognee/infrastructure/loaders/supported_loaders.py +2 -1
- cognee/memify_pipelines/create_triplet_embeddings.py +53 -0
- cognee/memify_pipelines/persist_sessions_in_knowledge_graph.py +55 -0
- cognee/modules/chunking/CsvChunker.py +35 -0
- cognee/modules/chunking/models/DocumentChunk.py +2 -1
- cognee/modules/chunking/text_chunker_with_overlap.py +124 -0
- cognee/modules/cognify/config.py +2 -0
- cognee/modules/data/deletion/prune_system.py +52 -2
- cognee/modules/data/methods/__init__.py +1 -0
- cognee/modules/data/methods/create_dataset.py +4 -2
- cognee/modules/data/methods/delete_dataset.py +26 -0
- cognee/modules/data/methods/get_dataset_ids.py +5 -1
- cognee/modules/data/methods/get_unique_data_id.py +68 -0
- cognee/modules/data/methods/get_unique_dataset_id.py +66 -4
- cognee/modules/data/models/Dataset.py +2 -0
- cognee/modules/data/processing/document_types/CsvDocument.py +33 -0
- cognee/modules/data/processing/document_types/__init__.py +1 -0
- cognee/modules/engine/models/Triplet.py +9 -0
- cognee/modules/engine/models/__init__.py +1 -0
- cognee/modules/graph/cognee_graph/CogneeGraph.py +89 -39
- cognee/modules/graph/cognee_graph/CogneeGraphElements.py +8 -3
- cognee/modules/graph/utils/expand_with_nodes_and_edges.py +19 -2
- cognee/modules/graph/utils/resolve_edges_to_text.py +48 -49
- cognee/modules/ingestion/identify.py +4 -4
- cognee/modules/memify/memify.py +1 -7
- cognee/modules/notebooks/operations/run_in_local_sandbox.py +3 -0
- cognee/modules/ontology/rdf_xml/RDFLibOntologyResolver.py +55 -23
- cognee/modules/pipelines/operations/pipeline.py +18 -2
- cognee/modules/pipelines/operations/run_tasks_data_item.py +1 -1
- cognee/modules/retrieval/EntityCompletionRetriever.py +10 -3
- cognee/modules/retrieval/__init__.py +1 -1
- cognee/modules/retrieval/base_graph_retriever.py +7 -3
- cognee/modules/retrieval/base_retriever.py +7 -3
- cognee/modules/retrieval/completion_retriever.py +11 -4
- cognee/modules/retrieval/graph_completion_context_extension_retriever.py +10 -2
- cognee/modules/retrieval/graph_completion_cot_retriever.py +18 -51
- cognee/modules/retrieval/graph_completion_retriever.py +14 -1
- cognee/modules/retrieval/graph_summary_completion_retriever.py +4 -0
- cognee/modules/retrieval/register_retriever.py +10 -0
- cognee/modules/retrieval/registered_community_retrievers.py +1 -0
- cognee/modules/retrieval/temporal_retriever.py +13 -2
- cognee/modules/retrieval/triplet_retriever.py +182 -0
- cognee/modules/retrieval/utils/brute_force_triplet_search.py +43 -11
- cognee/modules/retrieval/utils/completion.py +2 -22
- cognee/modules/run_custom_pipeline/__init__.py +1 -0
- cognee/modules/run_custom_pipeline/run_custom_pipeline.py +76 -0
- cognee/modules/search/methods/get_search_type_tools.py +54 -8
- cognee/modules/search/methods/no_access_control_search.py +4 -0
- cognee/modules/search/methods/search.py +26 -3
- cognee/modules/search/types/SearchType.py +1 -1
- cognee/modules/settings/get_settings.py +19 -0
- cognee/modules/users/methods/create_user.py +12 -27
- cognee/modules/users/methods/get_authenticated_user.py +3 -2
- cognee/modules/users/methods/get_default_user.py +4 -2
- cognee/modules/users/methods/get_user.py +1 -1
- cognee/modules/users/methods/get_user_by_email.py +1 -1
- cognee/modules/users/models/DatasetDatabase.py +24 -3
- cognee/modules/users/models/Tenant.py +6 -7
- cognee/modules/users/models/User.py +6 -5
- cognee/modules/users/models/UserTenant.py +12 -0
- cognee/modules/users/models/__init__.py +1 -0
- cognee/modules/users/permissions/methods/get_all_user_permission_datasets.py +13 -13
- cognee/modules/users/roles/methods/add_user_to_role.py +3 -1
- cognee/modules/users/tenants/methods/__init__.py +1 -0
- cognee/modules/users/tenants/methods/add_user_to_tenant.py +21 -12
- cognee/modules/users/tenants/methods/create_tenant.py +22 -8
- cognee/modules/users/tenants/methods/select_tenant.py +62 -0
- cognee/shared/logging_utils.py +6 -0
- cognee/shared/rate_limiting.py +30 -0
- cognee/tasks/chunks/__init__.py +1 -0
- cognee/tasks/chunks/chunk_by_row.py +94 -0
- cognee/tasks/documents/__init__.py +0 -1
- cognee/tasks/documents/classify_documents.py +2 -0
- cognee/tasks/feedback/generate_improved_answers.py +3 -3
- cognee/tasks/graph/extract_graph_from_data.py +9 -10
- cognee/tasks/ingestion/ingest_data.py +1 -1
- cognee/tasks/memify/__init__.py +2 -0
- cognee/tasks/memify/cognify_session.py +41 -0
- cognee/tasks/memify/extract_user_sessions.py +73 -0
- cognee/tasks/memify/get_triplet_datapoints.py +289 -0
- cognee/tasks/storage/add_data_points.py +142 -2
- cognee/tasks/storage/index_data_points.py +33 -22
- cognee/tasks/storage/index_graph_edges.py +37 -57
- cognee/tests/integration/documents/CsvDocument_test.py +70 -0
- cognee/tests/integration/retrieval/test_triplet_retriever.py +84 -0
- cognee/tests/integration/tasks/test_add_data_points.py +139 -0
- cognee/tests/integration/tasks/test_get_triplet_datapoints.py +69 -0
- cognee/tests/tasks/entity_extraction/entity_extraction_test.py +1 -1
- cognee/tests/test_add_docling_document.py +2 -2
- cognee/tests/test_cognee_server_start.py +84 -3
- cognee/tests/test_conversation_history.py +68 -5
- cognee/tests/test_data/example_with_header.csv +3 -0
- cognee/tests/test_dataset_database_handler.py +137 -0
- cognee/tests/test_dataset_delete.py +76 -0
- cognee/tests/test_edge_centered_payload.py +170 -0
- cognee/tests/test_edge_ingestion.py +27 -0
- cognee/tests/test_feedback_enrichment.py +1 -1
- cognee/tests/test_library.py +6 -4
- cognee/tests/test_load.py +62 -0
- cognee/tests/test_multi_tenancy.py +165 -0
- cognee/tests/test_parallel_databases.py +2 -0
- cognee/tests/test_pipeline_cache.py +164 -0
- cognee/tests/test_relational_db_migration.py +54 -2
- cognee/tests/test_search_db.py +44 -2
- cognee/tests/unit/api/test_conditional_authentication_endpoints.py +12 -3
- cognee/tests/unit/api/test_ontology_endpoint.py +252 -0
- cognee/tests/unit/infrastructure/databases/cache/test_cache_config.py +5 -0
- cognee/tests/unit/infrastructure/databases/test_index_data_points.py +27 -0
- cognee/tests/unit/infrastructure/databases/test_index_graph_edges.py +14 -16
- cognee/tests/unit/infrastructure/llm/test_llm_config.py +46 -0
- cognee/tests/unit/infrastructure/mock_embedding_engine.py +3 -7
- cognee/tests/unit/infrastructure/test_embedding_rate_limiting_realistic.py +0 -5
- cognee/tests/unit/modules/chunking/test_text_chunker.py +248 -0
- cognee/tests/unit/modules/chunking/test_text_chunker_with_overlap.py +324 -0
- cognee/tests/unit/modules/graph/cognee_graph_elements_test.py +2 -2
- cognee/tests/unit/modules/graph/cognee_graph_test.py +406 -0
- cognee/tests/unit/modules/memify_tasks/test_cognify_session.py +111 -0
- cognee/tests/unit/modules/memify_tasks/test_extract_user_sessions.py +175 -0
- cognee/tests/unit/modules/memify_tasks/test_get_triplet_datapoints.py +214 -0
- cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py +0 -51
- cognee/tests/unit/modules/retrieval/rag_completion_retriever_test.py +1 -0
- cognee/tests/unit/modules/retrieval/structured_output_test.py +204 -0
- cognee/tests/unit/modules/retrieval/summaries_retriever_test.py +1 -1
- cognee/tests/unit/modules/retrieval/temporal_retriever_test.py +0 -1
- cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py +608 -0
- cognee/tests/unit/modules/retrieval/triplet_retriever_test.py +83 -0
- cognee/tests/unit/modules/users/test_conditional_authentication.py +0 -63
- cognee/tests/unit/processing/chunks/chunk_by_row_test.py +52 -0
- cognee/tests/unit/tasks/storage/test_add_data_points.py +288 -0
- {cognee-0.4.1.dist-info → cognee-0.5.0.dist-info}/METADATA +11 -7
- {cognee-0.4.1.dist-info → cognee-0.5.0.dist-info}/RECORD +212 -160
- {cognee-0.4.1.dist-info → cognee-0.5.0.dist-info}/entry_points.txt +0 -1
- cognee/api/v1/cognify/code_graph_pipeline.py +0 -119
- cognee/api/v1/cognify/routers/get_code_pipeline_router.py +0 -90
- cognee/infrastructure/databases/vector/embeddings/embedding_rate_limiter.py +0 -544
- cognee/modules/retrieval/code_retriever.py +0 -232
- cognee/tasks/code/enrich_dependency_graph_checker.py +0 -35
- cognee/tasks/code/get_local_dependencies_checker.py +0 -20
- cognee/tasks/code/get_repo_dependency_graph_checker.py +0 -35
- cognee/tasks/documents/check_permissions_on_dataset.py +0 -26
- cognee/tasks/repo_processor/__init__.py +0 -2
- cognee/tasks/repo_processor/get_local_dependencies.py +0 -335
- cognee/tasks/repo_processor/get_non_code_files.py +0 -158
- cognee/tasks/repo_processor/get_repo_file_dependencies.py +0 -243
- {cognee-0.4.1.dist-info → cognee-0.5.0.dist-info}/WHEEL +0 -0
- {cognee-0.4.1.dist-info → cognee-0.5.0.dist-info}/licenses/LICENSE +0 -0
- {cognee-0.4.1.dist-info → cognee-0.5.0.dist-info}/licenses/NOTICE.md +0 -0
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
import pytest
|
|
2
|
+
from unittest.mock import AsyncMock
|
|
2
3
|
|
|
3
4
|
from cognee.modules.graph.exceptions import EntityNotFoundError, EntityAlreadyExistsError
|
|
4
5
|
from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph
|
|
@@ -11,6 +12,30 @@ def setup_graph():
|
|
|
11
12
|
return CogneeGraph()
|
|
12
13
|
|
|
13
14
|
|
|
15
|
+
@pytest.fixture
|
|
16
|
+
def mock_adapter():
|
|
17
|
+
"""Fixture to create a mock adapter for database operations."""
|
|
18
|
+
adapter = AsyncMock()
|
|
19
|
+
return adapter
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
@pytest.fixture
|
|
23
|
+
def mock_vector_engine():
|
|
24
|
+
"""Fixture to create a mock vector engine."""
|
|
25
|
+
engine = AsyncMock()
|
|
26
|
+
engine.search = AsyncMock()
|
|
27
|
+
return engine
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class MockScoredResult:
|
|
31
|
+
"""Mock class for vector search results."""
|
|
32
|
+
|
|
33
|
+
def __init__(self, id, score, payload=None):
|
|
34
|
+
self.id = id
|
|
35
|
+
self.score = score
|
|
36
|
+
self.payload = payload or {}
|
|
37
|
+
|
|
38
|
+
|
|
14
39
|
def test_add_node_success(setup_graph):
|
|
15
40
|
"""Test successful addition of a node."""
|
|
16
41
|
graph = setup_graph
|
|
@@ -73,3 +98,384 @@ def test_get_edges_nonexistent_node(setup_graph):
|
|
|
73
98
|
graph = setup_graph
|
|
74
99
|
with pytest.raises(EntityNotFoundError, match="Node with id nonexistent does not exist."):
|
|
75
100
|
graph.get_edges_from_node("nonexistent")
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
@pytest.mark.asyncio
|
|
104
|
+
async def test_project_graph_from_db_full_graph(setup_graph, mock_adapter):
|
|
105
|
+
"""Test projecting a full graph from database."""
|
|
106
|
+
graph = setup_graph
|
|
107
|
+
|
|
108
|
+
nodes_data = [
|
|
109
|
+
("1", {"name": "Node1", "description": "First node"}),
|
|
110
|
+
("2", {"name": "Node2", "description": "Second node"}),
|
|
111
|
+
]
|
|
112
|
+
edges_data = [
|
|
113
|
+
("1", "2", "CONNECTS_TO", {"relationship_name": "connects"}),
|
|
114
|
+
]
|
|
115
|
+
|
|
116
|
+
mock_adapter.get_graph_data = AsyncMock(return_value=(nodes_data, edges_data))
|
|
117
|
+
|
|
118
|
+
await graph.project_graph_from_db(
|
|
119
|
+
adapter=mock_adapter,
|
|
120
|
+
node_properties_to_project=["name", "description"],
|
|
121
|
+
edge_properties_to_project=["relationship_name"],
|
|
122
|
+
)
|
|
123
|
+
|
|
124
|
+
assert len(graph.nodes) == 2
|
|
125
|
+
assert len(graph.edges) == 1
|
|
126
|
+
assert graph.get_node("1") is not None
|
|
127
|
+
assert graph.get_node("2") is not None
|
|
128
|
+
assert graph.edges[0].node1.id == "1"
|
|
129
|
+
assert graph.edges[0].node2.id == "2"
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
@pytest.mark.asyncio
|
|
133
|
+
async def test_project_graph_from_db_id_filtered(setup_graph, mock_adapter):
|
|
134
|
+
"""Test projecting an ID-filtered graph from database."""
|
|
135
|
+
graph = setup_graph
|
|
136
|
+
|
|
137
|
+
nodes_data = [
|
|
138
|
+
("1", {"name": "Node1"}),
|
|
139
|
+
("2", {"name": "Node2"}),
|
|
140
|
+
]
|
|
141
|
+
edges_data = [
|
|
142
|
+
("1", "2", "CONNECTS_TO", {"relationship_name": "connects"}),
|
|
143
|
+
]
|
|
144
|
+
|
|
145
|
+
mock_adapter.get_id_filtered_graph_data = AsyncMock(return_value=(nodes_data, edges_data))
|
|
146
|
+
|
|
147
|
+
await graph.project_graph_from_db(
|
|
148
|
+
adapter=mock_adapter,
|
|
149
|
+
node_properties_to_project=["name"],
|
|
150
|
+
edge_properties_to_project=["relationship_name"],
|
|
151
|
+
relevant_ids_to_filter=["1", "2"],
|
|
152
|
+
)
|
|
153
|
+
|
|
154
|
+
assert len(graph.nodes) == 2
|
|
155
|
+
assert len(graph.edges) == 1
|
|
156
|
+
mock_adapter.get_id_filtered_graph_data.assert_called_once()
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
@pytest.mark.asyncio
|
|
160
|
+
async def test_project_graph_from_db_nodeset_subgraph(setup_graph, mock_adapter):
|
|
161
|
+
"""Test projecting a nodeset subgraph filtered by node type and name."""
|
|
162
|
+
graph = setup_graph
|
|
163
|
+
|
|
164
|
+
nodes_data = [
|
|
165
|
+
("1", {"name": "Alice", "type": "Person"}),
|
|
166
|
+
("2", {"name": "Bob", "type": "Person"}),
|
|
167
|
+
]
|
|
168
|
+
edges_data = [
|
|
169
|
+
("1", "2", "KNOWS", {"relationship_name": "knows"}),
|
|
170
|
+
]
|
|
171
|
+
|
|
172
|
+
mock_adapter.get_nodeset_subgraph = AsyncMock(return_value=(nodes_data, edges_data))
|
|
173
|
+
|
|
174
|
+
await graph.project_graph_from_db(
|
|
175
|
+
adapter=mock_adapter,
|
|
176
|
+
node_properties_to_project=["name", "type"],
|
|
177
|
+
edge_properties_to_project=["relationship_name"],
|
|
178
|
+
node_type="Person",
|
|
179
|
+
node_name=["Alice"],
|
|
180
|
+
)
|
|
181
|
+
|
|
182
|
+
assert len(graph.nodes) == 2
|
|
183
|
+
assert graph.get_node("1") is not None
|
|
184
|
+
assert len(graph.edges) == 1
|
|
185
|
+
mock_adapter.get_nodeset_subgraph.assert_called_once()
|
|
186
|
+
|
|
187
|
+
|
|
188
|
+
@pytest.mark.asyncio
|
|
189
|
+
async def test_project_graph_from_db_empty_graph(setup_graph, mock_adapter):
|
|
190
|
+
"""Test projecting empty graph raises EntityNotFoundError."""
|
|
191
|
+
graph = setup_graph
|
|
192
|
+
|
|
193
|
+
mock_adapter.get_graph_data = AsyncMock(return_value=([], []))
|
|
194
|
+
|
|
195
|
+
with pytest.raises(EntityNotFoundError, match="Empty graph projected from the database."):
|
|
196
|
+
await graph.project_graph_from_db(
|
|
197
|
+
adapter=mock_adapter,
|
|
198
|
+
node_properties_to_project=["name"],
|
|
199
|
+
edge_properties_to_project=[],
|
|
200
|
+
)
|
|
201
|
+
|
|
202
|
+
|
|
203
|
+
@pytest.mark.asyncio
|
|
204
|
+
async def test_project_graph_from_db_missing_nodes(setup_graph, mock_adapter):
|
|
205
|
+
"""Test that edges referencing missing nodes raise error."""
|
|
206
|
+
graph = setup_graph
|
|
207
|
+
|
|
208
|
+
nodes_data = [
|
|
209
|
+
("1", {"name": "Node1"}),
|
|
210
|
+
]
|
|
211
|
+
edges_data = [
|
|
212
|
+
("1", "999", "CONNECTS_TO", {"relationship_name": "connects"}),
|
|
213
|
+
]
|
|
214
|
+
|
|
215
|
+
mock_adapter.get_graph_data = AsyncMock(return_value=(nodes_data, edges_data))
|
|
216
|
+
|
|
217
|
+
with pytest.raises(EntityNotFoundError, match="Edge references nonexistent nodes"):
|
|
218
|
+
await graph.project_graph_from_db(
|
|
219
|
+
adapter=mock_adapter,
|
|
220
|
+
node_properties_to_project=["name"],
|
|
221
|
+
edge_properties_to_project=["relationship_name"],
|
|
222
|
+
)
|
|
223
|
+
|
|
224
|
+
|
|
225
|
+
@pytest.mark.asyncio
|
|
226
|
+
async def test_map_vector_distances_to_graph_nodes(setup_graph):
|
|
227
|
+
"""Test mapping vector distances to graph nodes."""
|
|
228
|
+
graph = setup_graph
|
|
229
|
+
|
|
230
|
+
node1 = Node("1", {"name": "Node1"})
|
|
231
|
+
node2 = Node("2", {"name": "Node2"})
|
|
232
|
+
graph.add_node(node1)
|
|
233
|
+
graph.add_node(node2)
|
|
234
|
+
|
|
235
|
+
node_distances = {
|
|
236
|
+
"Entity_name": [
|
|
237
|
+
MockScoredResult("1", 0.95),
|
|
238
|
+
MockScoredResult("2", 0.87),
|
|
239
|
+
]
|
|
240
|
+
}
|
|
241
|
+
|
|
242
|
+
await graph.map_vector_distances_to_graph_nodes(node_distances)
|
|
243
|
+
|
|
244
|
+
assert graph.get_node("1").attributes.get("vector_distance") == 0.95
|
|
245
|
+
assert graph.get_node("2").attributes.get("vector_distance") == 0.87
|
|
246
|
+
|
|
247
|
+
|
|
248
|
+
@pytest.mark.asyncio
|
|
249
|
+
async def test_map_vector_distances_partial_node_coverage(setup_graph):
|
|
250
|
+
"""Test mapping vector distances when only some nodes have results."""
|
|
251
|
+
graph = setup_graph
|
|
252
|
+
|
|
253
|
+
node1 = Node("1", {"name": "Node1"})
|
|
254
|
+
node2 = Node("2", {"name": "Node2"})
|
|
255
|
+
node3 = Node("3", {"name": "Node3"})
|
|
256
|
+
graph.add_node(node1)
|
|
257
|
+
graph.add_node(node2)
|
|
258
|
+
graph.add_node(node3)
|
|
259
|
+
|
|
260
|
+
node_distances = {
|
|
261
|
+
"Entity_name": [
|
|
262
|
+
MockScoredResult("1", 0.95),
|
|
263
|
+
MockScoredResult("2", 0.87),
|
|
264
|
+
]
|
|
265
|
+
}
|
|
266
|
+
|
|
267
|
+
await graph.map_vector_distances_to_graph_nodes(node_distances)
|
|
268
|
+
|
|
269
|
+
assert graph.get_node("1").attributes.get("vector_distance") == 0.95
|
|
270
|
+
assert graph.get_node("2").attributes.get("vector_distance") == 0.87
|
|
271
|
+
assert graph.get_node("3").attributes.get("vector_distance") == 3.5
|
|
272
|
+
|
|
273
|
+
|
|
274
|
+
@pytest.mark.asyncio
|
|
275
|
+
async def test_map_vector_distances_multiple_categories(setup_graph):
|
|
276
|
+
"""Test mapping vector distances from multiple collection categories."""
|
|
277
|
+
graph = setup_graph
|
|
278
|
+
|
|
279
|
+
# Create nodes
|
|
280
|
+
node1 = Node("1")
|
|
281
|
+
node2 = Node("2")
|
|
282
|
+
node3 = Node("3")
|
|
283
|
+
node4 = Node("4")
|
|
284
|
+
graph.add_node(node1)
|
|
285
|
+
graph.add_node(node2)
|
|
286
|
+
graph.add_node(node3)
|
|
287
|
+
graph.add_node(node4)
|
|
288
|
+
|
|
289
|
+
node_distances = {
|
|
290
|
+
"Entity_name": [
|
|
291
|
+
MockScoredResult("1", 0.95),
|
|
292
|
+
MockScoredResult("2", 0.87),
|
|
293
|
+
],
|
|
294
|
+
"TextSummary_text": [
|
|
295
|
+
MockScoredResult("3", 0.92),
|
|
296
|
+
],
|
|
297
|
+
}
|
|
298
|
+
|
|
299
|
+
await graph.map_vector_distances_to_graph_nodes(node_distances)
|
|
300
|
+
|
|
301
|
+
assert graph.get_node("1").attributes.get("vector_distance") == 0.95
|
|
302
|
+
assert graph.get_node("2").attributes.get("vector_distance") == 0.87
|
|
303
|
+
assert graph.get_node("3").attributes.get("vector_distance") == 0.92
|
|
304
|
+
assert graph.get_node("4").attributes.get("vector_distance") == 3.5
|
|
305
|
+
|
|
306
|
+
|
|
307
|
+
@pytest.mark.asyncio
|
|
308
|
+
async def test_map_vector_distances_to_graph_edges_with_payload(setup_graph):
|
|
309
|
+
"""Test mapping vector distances to edges when edge_distances provided."""
|
|
310
|
+
graph = setup_graph
|
|
311
|
+
|
|
312
|
+
node1 = Node("1")
|
|
313
|
+
node2 = Node("2")
|
|
314
|
+
graph.add_node(node1)
|
|
315
|
+
graph.add_node(node2)
|
|
316
|
+
|
|
317
|
+
edge = Edge(
|
|
318
|
+
node1,
|
|
319
|
+
node2,
|
|
320
|
+
attributes={"edge_text": "CONNECTS_TO", "relationship_type": "connects"},
|
|
321
|
+
)
|
|
322
|
+
graph.add_edge(edge)
|
|
323
|
+
|
|
324
|
+
edge_distances = [
|
|
325
|
+
MockScoredResult("e1", 0.92, payload={"text": "CONNECTS_TO"}),
|
|
326
|
+
]
|
|
327
|
+
|
|
328
|
+
await graph.map_vector_distances_to_graph_edges(edge_distances=edge_distances)
|
|
329
|
+
|
|
330
|
+
assert graph.edges[0].attributes.get("vector_distance") == 0.92
|
|
331
|
+
|
|
332
|
+
|
|
333
|
+
@pytest.mark.asyncio
|
|
334
|
+
async def test_map_vector_distances_partial_edge_coverage(setup_graph):
|
|
335
|
+
"""Test mapping edge distances when only some edges have results."""
|
|
336
|
+
graph = setup_graph
|
|
337
|
+
|
|
338
|
+
node1 = Node("1")
|
|
339
|
+
node2 = Node("2")
|
|
340
|
+
node3 = Node("3")
|
|
341
|
+
graph.add_node(node1)
|
|
342
|
+
graph.add_node(node2)
|
|
343
|
+
graph.add_node(node3)
|
|
344
|
+
|
|
345
|
+
edge1 = Edge(node1, node2, attributes={"edge_text": "CONNECTS_TO"})
|
|
346
|
+
edge2 = Edge(node2, node3, attributes={"edge_text": "DEPENDS_ON"})
|
|
347
|
+
graph.add_edge(edge1)
|
|
348
|
+
graph.add_edge(edge2)
|
|
349
|
+
|
|
350
|
+
edge_distances = [
|
|
351
|
+
MockScoredResult("e1", 0.92, payload={"text": "CONNECTS_TO"}),
|
|
352
|
+
]
|
|
353
|
+
|
|
354
|
+
await graph.map_vector_distances_to_graph_edges(edge_distances=edge_distances)
|
|
355
|
+
|
|
356
|
+
assert graph.edges[0].attributes.get("vector_distance") == 0.92
|
|
357
|
+
assert graph.edges[1].attributes.get("vector_distance") == 3.5
|
|
358
|
+
|
|
359
|
+
|
|
360
|
+
@pytest.mark.asyncio
|
|
361
|
+
async def test_map_vector_distances_edges_fallback_to_relationship_type(setup_graph):
|
|
362
|
+
"""Test that edge mapping falls back to relationship_type when edge_text is missing."""
|
|
363
|
+
graph = setup_graph
|
|
364
|
+
|
|
365
|
+
node1 = Node("1")
|
|
366
|
+
node2 = Node("2")
|
|
367
|
+
graph.add_node(node1)
|
|
368
|
+
graph.add_node(node2)
|
|
369
|
+
|
|
370
|
+
edge = Edge(
|
|
371
|
+
node1,
|
|
372
|
+
node2,
|
|
373
|
+
attributes={"relationship_type": "KNOWS"},
|
|
374
|
+
)
|
|
375
|
+
graph.add_edge(edge)
|
|
376
|
+
|
|
377
|
+
edge_distances = [
|
|
378
|
+
MockScoredResult("e1", 0.85, payload={"text": "KNOWS"}),
|
|
379
|
+
]
|
|
380
|
+
|
|
381
|
+
await graph.map_vector_distances_to_graph_edges(edge_distances=edge_distances)
|
|
382
|
+
|
|
383
|
+
assert graph.edges[0].attributes.get("vector_distance") == 0.85
|
|
384
|
+
|
|
385
|
+
|
|
386
|
+
@pytest.mark.asyncio
|
|
387
|
+
async def test_map_vector_distances_no_edge_matches(setup_graph):
|
|
388
|
+
"""Test edge mapping when no edges match the distance results."""
|
|
389
|
+
graph = setup_graph
|
|
390
|
+
|
|
391
|
+
node1 = Node("1")
|
|
392
|
+
node2 = Node("2")
|
|
393
|
+
graph.add_node(node1)
|
|
394
|
+
graph.add_node(node2)
|
|
395
|
+
|
|
396
|
+
edge = Edge(
|
|
397
|
+
node1,
|
|
398
|
+
node2,
|
|
399
|
+
attributes={"edge_text": "CONNECTS_TO", "relationship_type": "connects"},
|
|
400
|
+
)
|
|
401
|
+
graph.add_edge(edge)
|
|
402
|
+
|
|
403
|
+
edge_distances = [
|
|
404
|
+
MockScoredResult("e1", 0.92, payload={"text": "SOME_OTHER_EDGE"}),
|
|
405
|
+
]
|
|
406
|
+
|
|
407
|
+
await graph.map_vector_distances_to_graph_edges(edge_distances=edge_distances)
|
|
408
|
+
|
|
409
|
+
assert graph.edges[0].attributes.get("vector_distance") == 3.5
|
|
410
|
+
|
|
411
|
+
|
|
412
|
+
@pytest.mark.asyncio
|
|
413
|
+
async def test_map_vector_distances_none_returns_early(setup_graph):
|
|
414
|
+
"""Test that edge_distances=None returns early without error."""
|
|
415
|
+
graph = setup_graph
|
|
416
|
+
graph.add_node(Node("1"))
|
|
417
|
+
graph.add_node(Node("2"))
|
|
418
|
+
graph.add_edge(Edge(graph.get_node("1"), graph.get_node("2")))
|
|
419
|
+
|
|
420
|
+
await graph.map_vector_distances_to_graph_edges(edge_distances=None)
|
|
421
|
+
|
|
422
|
+
assert graph.edges[0].attributes.get("vector_distance") == 3.5
|
|
423
|
+
|
|
424
|
+
|
|
425
|
+
@pytest.mark.asyncio
|
|
426
|
+
async def test_calculate_top_triplet_importances(setup_graph):
|
|
427
|
+
"""Test calculating top triplet importances by score."""
|
|
428
|
+
graph = setup_graph
|
|
429
|
+
|
|
430
|
+
node1 = Node("1")
|
|
431
|
+
node2 = Node("2")
|
|
432
|
+
node3 = Node("3")
|
|
433
|
+
node4 = Node("4")
|
|
434
|
+
|
|
435
|
+
node1.add_attribute("vector_distance", 0.9)
|
|
436
|
+
node2.add_attribute("vector_distance", 0.8)
|
|
437
|
+
node3.add_attribute("vector_distance", 0.7)
|
|
438
|
+
node4.add_attribute("vector_distance", 0.6)
|
|
439
|
+
|
|
440
|
+
graph.add_node(node1)
|
|
441
|
+
graph.add_node(node2)
|
|
442
|
+
graph.add_node(node3)
|
|
443
|
+
graph.add_node(node4)
|
|
444
|
+
|
|
445
|
+
edge1 = Edge(node1, node2)
|
|
446
|
+
edge2 = Edge(node2, node3)
|
|
447
|
+
edge3 = Edge(node3, node4)
|
|
448
|
+
|
|
449
|
+
edge1.add_attribute("vector_distance", 0.85)
|
|
450
|
+
edge2.add_attribute("vector_distance", 0.75)
|
|
451
|
+
edge3.add_attribute("vector_distance", 0.65)
|
|
452
|
+
|
|
453
|
+
graph.add_edge(edge1)
|
|
454
|
+
graph.add_edge(edge2)
|
|
455
|
+
graph.add_edge(edge3)
|
|
456
|
+
|
|
457
|
+
top_triplets = await graph.calculate_top_triplet_importances(k=2)
|
|
458
|
+
|
|
459
|
+
assert len(top_triplets) == 2
|
|
460
|
+
|
|
461
|
+
assert top_triplets[0] == edge3
|
|
462
|
+
assert top_triplets[1] == edge2
|
|
463
|
+
|
|
464
|
+
|
|
465
|
+
@pytest.mark.asyncio
|
|
466
|
+
async def test_calculate_top_triplet_importances_default_distances(setup_graph):
|
|
467
|
+
"""Test calculating importances when nodes/edges have no vector distances."""
|
|
468
|
+
graph = setup_graph
|
|
469
|
+
|
|
470
|
+
node1 = Node("1")
|
|
471
|
+
node2 = Node("2")
|
|
472
|
+
graph.add_node(node1)
|
|
473
|
+
graph.add_node(node2)
|
|
474
|
+
|
|
475
|
+
edge = Edge(node1, node2)
|
|
476
|
+
graph.add_edge(edge)
|
|
477
|
+
|
|
478
|
+
top_triplets = await graph.calculate_top_triplet_importances(k=1)
|
|
479
|
+
|
|
480
|
+
assert len(top_triplets) == 1
|
|
481
|
+
assert top_triplets[0] == edge
|
|
@@ -0,0 +1,111 @@
|
|
|
1
|
+
import pytest
|
|
2
|
+
from unittest.mock import AsyncMock, patch
|
|
3
|
+
|
|
4
|
+
from cognee.tasks.memify.cognify_session import cognify_session
|
|
5
|
+
from cognee.exceptions import CogneeValidationError, CogneeSystemError
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
@pytest.mark.asyncio
|
|
9
|
+
async def test_cognify_session_success():
|
|
10
|
+
"""Test successful cognification of session data."""
|
|
11
|
+
session_data = (
|
|
12
|
+
"Session ID: test_session\n\nQuestion: What is AI?\n\nAnswer: AI is artificial intelligence"
|
|
13
|
+
)
|
|
14
|
+
|
|
15
|
+
with (
|
|
16
|
+
patch("cognee.add", new_callable=AsyncMock) as mock_add,
|
|
17
|
+
patch("cognee.cognify", new_callable=AsyncMock) as mock_cognify,
|
|
18
|
+
):
|
|
19
|
+
await cognify_session(session_data, dataset_id="123")
|
|
20
|
+
|
|
21
|
+
mock_add.assert_called_once_with(
|
|
22
|
+
session_data, dataset_id="123", node_set=["user_sessions_from_cache"]
|
|
23
|
+
)
|
|
24
|
+
mock_cognify.assert_called_once()
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
@pytest.mark.asyncio
|
|
28
|
+
async def test_cognify_session_empty_string():
|
|
29
|
+
"""Test cognification fails with empty string."""
|
|
30
|
+
with pytest.raises(CogneeValidationError) as exc_info:
|
|
31
|
+
await cognify_session("")
|
|
32
|
+
|
|
33
|
+
assert "Session data cannot be empty" in str(exc_info.value)
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
@pytest.mark.asyncio
async def test_cognify_session_whitespace_string():
    """Test cognification fails with whitespace-only string.

    ``cognee.add``/``cognee.cognify`` are patched so that a validation
    regression cannot reach the real pipeline. The test also asserts that
    neither was awaited.
    """
    with (
        patch("cognee.add", new_callable=AsyncMock) as mock_add,
        patch("cognee.cognify", new_callable=AsyncMock) as mock_cognify,
    ):
        with pytest.raises(CogneeValidationError) as exc_info:
            await cognify_session("   \n\t   ")

    assert "Session data cannot be empty" in str(exc_info.value)
    mock_add.assert_not_awaited()
    mock_cognify.assert_not_awaited()
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
@pytest.mark.asyncio
async def test_cognify_session_none_data():
    """Test cognification fails with None data.

    ``cognee.add``/``cognee.cognify`` are patched so that a validation
    regression cannot reach the real pipeline. The test also asserts that
    neither was awaited.
    """
    with (
        patch("cognee.add", new_callable=AsyncMock) as mock_add,
        patch("cognee.cognify", new_callable=AsyncMock) as mock_cognify,
    ):
        with pytest.raises(CogneeValidationError) as exc_info:
            await cognify_session(None)

    assert "Session data cannot be empty" in str(exc_info.value)
    mock_add.assert_not_awaited()
    mock_cognify.assert_not_awaited()
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
@pytest.mark.asyncio
async def test_cognify_session_add_failure():
    """Test cognification handles cognee.add failure.

    A generic exception from ``cognee.add`` is expected to surface as a
    ``CogneeSystemError`` that carries the original message. ``cognee.cognify``
    is expected not to be awaited once ``add`` has failed.
    """
    session_data = "Session ID: test\n\nQuestion: test?"

    with (
        patch("cognee.add", new_callable=AsyncMock) as mock_add,
        patch("cognee.cognify", new_callable=AsyncMock) as mock_cognify,
    ):
        mock_add.side_effect = Exception("Add operation failed")

        with pytest.raises(CogneeSystemError) as exc_info:
            await cognify_session(session_data)

        assert "Failed to cognify session data" in str(exc_info.value)
        assert "Add operation failed" in str(exc_info.value)
        # The failing step was actually reached, and nothing ran after it.
        mock_add.assert_awaited_once()
        mock_cognify.assert_not_awaited()
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
@pytest.mark.asyncio
async def test_cognify_session_cognify_failure():
    """Test cognification handles cognify failure.

    ``cognee.add`` succeeds, then ``cognee.cognify`` raises a generic
    exception. That exception is expected to surface as a
    ``CogneeSystemError`` carrying the original message.
    """
    session_data = "Session ID: test\n\nQuestion: test?"

    with (
        patch("cognee.add", new_callable=AsyncMock) as mock_add,
        patch("cognee.cognify", new_callable=AsyncMock) as mock_cognify,
    ):
        mock_cognify.side_effect = Exception("Cognify operation failed")

        with pytest.raises(CogneeSystemError) as exc_info:
            await cognify_session(session_data)

        assert "Failed to cognify session data" in str(exc_info.value)
        assert "Cognify operation failed" in str(exc_info.value)
        # Both steps were really awaited; the failure came from cognify, not add.
        mock_add.assert_awaited_once()
        mock_cognify.assert_awaited_once()
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
@pytest.mark.asyncio
async def test_cognify_session_re_raises_validation_error():
    """Test that CogneeValidationError is re-raised as-is.

    The validation error is expected to propagate unchanged. It must not be
    wrapped in the ``CogneeSystemError`` / "Failed to cognify session data"
    message used for downstream failures, and the pipeline must not be invoked.
    """
    with (
        patch("cognee.add", new_callable=AsyncMock) as mock_add,
        patch("cognee.cognify", new_callable=AsyncMock) as mock_cognify,
    ):
        with pytest.raises(CogneeValidationError) as exc_info:
            await cognify_session("")

    # "As-is": the original validation message survives and no wrapping occurred.
    assert "Session data cannot be empty" in str(exc_info.value)
    assert "Failed to cognify session data" not in str(exc_info.value)
    assert not isinstance(exc_info.value, CogneeSystemError)
    mock_add.assert_not_awaited()
    mock_cognify.assert_not_awaited()
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
@pytest.mark.asyncio
async def test_cognify_session_with_special_characters():
    """Test cognification with special characters.

    Non-ASCII session text is expected to be forwarded to ``cognee.add``
    unchanged. The ``assert_awaited_*`` helpers are used so the test also
    fails if a coroutine is created but never awaited.
    """
    session_data = "Session: test™ © Question: What's special? Answer: Cognee is special!"

    with (
        patch("cognee.add", new_callable=AsyncMock) as mock_add,
        patch("cognee.cognify", new_callable=AsyncMock) as mock_cognify,
    ):
        await cognify_session(session_data, dataset_id="123")

        mock_add.assert_awaited_once_with(
            session_data, dataset_id="123", node_set=["user_sessions_from_cache"]
        )
        mock_cognify.assert_awaited_once()
|
|
@@ -0,0 +1,175 @@
|
|
|
1
|
+
import sys
|
|
2
|
+
import pytest
|
|
3
|
+
from unittest.mock import AsyncMock, MagicMock, patch
|
|
4
|
+
|
|
5
|
+
from cognee.tasks.memify.extract_user_sessions import extract_user_sessions
|
|
6
|
+
from cognee.exceptions import CogneeSystemError
|
|
7
|
+
from cognee.modules.users.models import User
|
|
8
|
+
|
|
9
|
+
# Get the actual module object (not the function) for patching
|
|
10
|
+
# The `from cognee.tasks.memify.extract_user_sessions import ...` statement above has
# already imported this module, so it is present in sys.modules here. The tests pass it
# to patch.object to replace `session_user` and `get_cache_engine` on the module itself.
# NOTE(review): presumably sys.modules is used instead of attribute access on the package
# because `cognee.tasks.memify.extract_user_sessions` may resolve to the same-named
# function re-exported by the package; confirm.
extract_user_sessions_module = sys.modules["cognee.tasks.memify.extract_user_sessions"]
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
@pytest.fixture
def mock_user():
    """Provide a ``User``-specced mock whose ``id`` is ``"test-user-123"``."""
    return MagicMock(spec=User, id="test-user-123")
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
@pytest.fixture
def mock_qa_data():
    """Provide two canned Q&A cache entries (question, context, answer, time)."""
    rows = (
        (
            "What is cognee?",
            "context about cognee",
            "Cognee is a knowledge graph solution",
            "2025-01-01T12:00:00",
        ),
        (
            "How does it work?",
            "how it works context",
            "It processes data and creates graphs",
            "2025-01-01T12:05:00",
        ),
    )
    return [
        dict(question=question, context=context, answer=answer, time=timestamp)
        for question, context, answer, timestamp in rows
    ]
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
@pytest.mark.asyncio
async def test_extract_user_sessions_success(mock_user, mock_qa_data):
    """One cached session is rendered as a single text block containing every Q&A pair."""
    cache_engine = AsyncMock()
    cache_engine.get_all_qas.return_value = mock_qa_data

    with (
        patch.object(extract_user_sessions_module, "session_user") as user_ctx,
        patch.object(extract_user_sessions_module, "get_cache_engine", return_value=cache_engine),
    ):
        user_ctx.get.return_value = mock_user

        extracted = [
            chunk async for chunk in extract_user_sessions([{}], session_ids=["test_session"])
        ]

        assert len(extracted) == 1
        rendered = extracted[0]
        for fragment in (
            "Session ID: test_session",
            "Question: What is cognee?",
            "Answer: Cognee is a knowledge graph solution",
            "Question: How does it work?",
            "Answer: It processes data and creates graphs",
        ):
            assert fragment in rendered
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
@pytest.mark.asyncio
async def test_extract_user_sessions_multiple_sessions(mock_user, mock_qa_data):
    """Test extraction of multiple sessions.

    Each requested session ID is expected to trigger one awaited cache
    lookup and to produce its own text block. Each block is labelled with
    its session ID, and the blocks come back in request order.
    """
    mock_cache_engine = AsyncMock()
    mock_cache_engine.get_all_qas.return_value = mock_qa_data

    with (
        patch.object(extract_user_sessions_module, "session_user") as mock_session_user,
        patch.object(
            extract_user_sessions_module, "get_cache_engine", return_value=mock_cache_engine
        ),
    ):
        mock_session_user.get.return_value = mock_user

        sessions = []
        async for session in extract_user_sessions([{}], session_ids=["session1", "session2"]):
            sessions.append(session)

        assert len(sessions) == 2
        # await_count (not call_count) proves the lookups were actually awaited.
        assert mock_cache_engine.get_all_qas.await_count == 2
        assert "Session ID: session1" in sessions[0]
        assert "Session ID: session2" in sessions[1]
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
@pytest.mark.asyncio
async def test_extract_user_sessions_no_data(mock_user, mock_qa_data):
    """Passing ``None`` as ``data`` still yields the requested session."""
    fake_cache = AsyncMock()
    fake_cache.get_all_qas.return_value = mock_qa_data

    with (
        patch.object(extract_user_sessions_module, "session_user") as user_ctx,
        patch.object(extract_user_sessions_module, "get_cache_engine", return_value=fake_cache),
    ):
        user_ctx.get.return_value = mock_user

        extracted = [
            chunk async for chunk in extract_user_sessions(None, session_ids=["test_session"])
        ]

        assert len(extracted) == 1
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
@pytest.mark.asyncio
async def test_extract_user_sessions_no_session_ids(mock_user):
    """With ``session_ids=None`` nothing is yielded and the cache is never queried."""
    idle_cache = AsyncMock()

    with (
        patch.object(extract_user_sessions_module, "session_user") as user_ctx,
        patch.object(extract_user_sessions_module, "get_cache_engine", return_value=idle_cache),
    ):
        user_ctx.get.return_value = mock_user

        extracted = [chunk async for chunk in extract_user_sessions([{}], session_ids=None)]

        assert not extracted
        idle_cache.get_all_qas.assert_not_called()
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
@pytest.mark.asyncio
async def test_extract_user_sessions_empty_qa_data(mock_user):
    """A session whose cache returns no Q&A entries produces no output."""
    empty_cache = AsyncMock()
    empty_cache.get_all_qas.return_value = []

    with (
        patch.object(extract_user_sessions_module, "session_user") as user_ctx,
        patch.object(extract_user_sessions_module, "get_cache_engine", return_value=empty_cache),
    ):
        user_ctx.get.return_value = mock_user

        extracted = [
            chunk async for chunk in extract_user_sessions([{}], session_ids=["empty_session"])
        ]

        assert extracted == []
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
@pytest.mark.asyncio
async def test_extract_user_sessions_cache_error_handling(mock_user, mock_qa_data):
    """Test extraction continues on cache error for specific session.

    The cache lookup for ``session2`` raises. Extraction is expected to skip
    that session, still query ``session3``, and yield blocks only for
    ``session1`` and ``session3``, in request order.
    """
    mock_cache_engine = AsyncMock()
    # One side effect per lookup, consumed in call order.
    mock_cache_engine.get_all_qas.side_effect = [
        mock_qa_data,
        Exception("Cache error"),
        mock_qa_data,
    ]

    with (
        patch.object(extract_user_sessions_module, "session_user") as mock_session_user,
        patch.object(
            extract_user_sessions_module, "get_cache_engine", return_value=mock_cache_engine
        ),
    ):
        mock_session_user.get.return_value = mock_user

        sessions = []
        async for session in extract_user_sessions(
            [{}], session_ids=["session1", "session2", "session3"]
        ):
            sessions.append(session)

        assert len(sessions) == 2
        # All three sessions were looked up, so the failure did not abort the loop.
        assert mock_cache_engine.get_all_qas.await_count == 3
        assert "Session ID: session1" in sessions[0]
        assert "Session ID: session3" in sessions[1]
|